diff --git a/.bazelrc b/.bazelrc index fb938169b3c0..5e7b6af95494 100644 --- a/.bazelrc +++ b/.bazelrc @@ -4,90 +4,108 @@ # TODO: Enable Bzlmod common --noenable_bzlmod -# TODO: Migrate for https://github.com/bazelbuild/bazel/issues/7260 -common --noincompatible_enable_cc_toolchain_resolution - # Make Bazel print out all options from rc files. common --announce_rc # By default, execute all actions locally. -build --spawn_strategy=local +common --spawn_strategy=local -# Enable host OS specific configs. For instance, "build:linux" will be used +# Enable host OS specific configs. For instance, "common:linux" will be used # automatically when building on Linux. -build --enable_platform_specific_config +common --enable_platform_specific_config common --experimental_cc_shared_library +common --incompatible_enable_cc_toolchain_resolution +common --repo_env USE_HERMETIC_CC_TOOLCHAIN=1 + +# TODO: Migrate for https://github.com/bazelbuild/bazel/issues/7260 +common:clang_local --noincompatible_enable_cc_toolchain_resolution +common:clang_local --@rules_ml_toolchain//common:enable_hermetic_cc=False +common:clang_local --repo_env USE_HERMETIC_CC_TOOLCHAIN=0 + +# Toolchains for CUDA non-hermetic builds. +common:cuda_clang_local --config=clang_local +common:cuda_clang_local --host_crosstool_top="@local_config_cuda//crosstool:toolchain" +common:cuda_clang_local --crosstool_top="@local_config_cuda//crosstool:toolchain" + # Do not use C-Ares when building gRPC. -build --define=grpc_no_ares=true +common --define=grpc_no_ares=true -build --define=tsl_link_protobuf=true +common --define=tsl_link_protobuf=true # Enable optimization. -build -c opt +common -c opt # Suppress all warning messages. -build --output_filter=DONT_MATCH_ANYTHING +common --output_filter=DONT_MATCH_ANYTHING + +common --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir. +common --copt=-DNB_DOMAIN=jax -build --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir. +common --legacy_external_runfiles=false # ############################################################################# # Platform Specific configs below. These are automatically picked up by Bazel # depending on the platform that is running the build. # ############################################################################# -build:linux --config=posix -build:linux --copt=-Wno-unknown-warning-option +common:linux --config=posix +common:linux --copt=-Wno-unknown-warning-option # Workaround for gcc 10+ warnings related to upb. # See https://github.com/tensorflow/tensorflow/issues/39467 -build:linux --copt=-Wno-stringop-truncation -build:linux --copt=-Wno-array-parameter +common:linux --copt=-Wno-stringop-truncation +common:linux --copt=-Wno-array-parameter +common:linux --copt=-Wno-deprecated-register +common:linux --copt=-Wno-register -build:macos --config=posix -build:macos --apple_platform_type=macos +common:macos --config=posix +common:macos --apple_platform_type=macos # Bazel 7.0.0 no longer supports dynamic symbol lookup on macOS. To resolve # undefined symbol errors in macOS arm64 builds, explicitly add the necessary # linker flags until dependencies are well defined. See # https://github.com/bazelbuild/bazel/issues/19730. -build:macos --linkopt=-Wl,-undefined,dynamic_lookup -build:macos --host_linkopt=-Wl,-undefined,dynamic_lookup +common:macos --linkopt=-Wl,-undefined,dynamic_lookup +common:macos --host_linkopt=-Wl,-undefined,dynamic_lookup # Use cc toolchains from apple_support for Apple builds. # https://github.com/bazelbuild/apple_support/tree/master?tab=readme-ov-file#bazel-6-setup -build:macos --apple_crosstool_top=@local_config_apple_cc//:toolchain -build:macos --crosstool_top=@local_config_apple_cc//:toolchain -build:macos --host_crosstool_top=@local_config_apple_cc//:toolchain +common:macos --config=clang_local +common:macos --apple_crosstool_top=@local_config_apple_cc//:toolchain +common:macos --crosstool_top=@local_config_apple_cc//:toolchain +common:macos --host_crosstool_top=@local_config_apple_cc//:toolchain + +common:windows --config=clang_local # Windows has a relatively short command line limit, which JAX has begun to hit. # See https://docs.bazel.build/versions/main/windows.html -build:windows --features=compiler_param_file -build:windows --features=archive_param_file +common:windows --features=compiler_param_file +common:windows --features=archive_param_file # XLA uses M_* math constants that only get defined by MSVC headers if # _USE_MATH_DEFINES is defined. -build:windows --copt=/D_USE_MATH_DEFINES -build:windows --host_copt=/D_USE_MATH_DEFINES +common:windows --copt=/D_USE_MATH_DEFINES +common:windows --host_copt=/D_USE_MATH_DEFINES # Make sure to include as little of windows.h as possible -build:windows --copt=-DWIN32_LEAN_AND_MEAN -build:windows --host_copt=-DWIN32_LEAN_AND_MEAN -build:windows --copt=-DNOGDI -build:windows --host_copt=-DNOGDI +common:windows --copt=-DWIN32_LEAN_AND_MEAN +common:windows --host_copt=-DWIN32_LEAN_AND_MEAN +common:windows --copt=-DNOGDI +common:windows --host_copt=-DNOGDI # https://devblogs.microsoft.com/cppblog/announcing-full-support-for-a-c-c-conformant-preprocessor-in-msvc/ # otherwise, there will be some compiling error due to preprocessing. -build:windows --copt=/Zc:preprocessor -build:windows --cxxopt=/std:c++17 -build:windows --host_cxxopt=/std:c++17 +common:windows --copt=/Zc:preprocessor +common:windows --cxxopt=/std:c++17 +common:windows --host_cxxopt=/std:c++17 # Generate PDB files, to generate useful PDBs, in opt compilation_mode # --copt /Z7 is needed. -build:windows --linkopt=/DEBUG -build:windows --host_linkopt=/DEBUG -build:windows --linkopt=/OPT:REF -build:windows --host_linkopt=/OPT:REF -build:windows --linkopt=/OPT:ICF -build:windows --host_linkopt=/OPT:ICF -build:windows --incompatible_strict_action_env=true +common:windows --linkopt=/DEBUG +common:windows --host_linkopt=/DEBUG +common:windows --linkopt=/OPT:REF +common:windows --host_linkopt=/OPT:REF +common:windows --linkopt=/OPT:ICF +common:windows --host_linkopt=/OPT:ICF +common:windows --incompatible_strict_action_env=true # ############################################################################# # Feature-specific configurations. These are used by the CI configs below @@ -95,58 +113,83 @@ build:windows --incompatible_strict_action_env=true # configs such as `avx_linux` and `mkl_open_source_only`, `ci_linux_x86_64_cuda` # inherits `cuda` and `build_cuda_with_nvcc`, etc. # ############################################################################# -build:nonccl --define=no_nccl_support=true +common --repo_env=USE_PYWRAP_RULES=True +common --copt=-DGRPC_BAZEL_BUILD +common --host_copt=-DGRPC_BAZEL_BUILD +common --action_env=GRPC_BAZEL_RUNTIME=1 +common --repo_env=PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=upb +common --action_env=PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=upb + +# Some targets have the same py source file, but use different +# configurations via `requires-` tags. This results in an action +# conflict when precompiling. Disable to avoid that problem. +# See https://github.com/bazel-contrib/rules_python/issues/2445 +common --@rules_python//python/config_settings:precompile=force_disabled -build:posix --copt=-fvisibility=hidden -build:posix --copt=-Wno-sign-compare -build:posix --cxxopt=-std=c++17 -build:posix --host_cxxopt=-std=c++17 +# Do not do this. If enabled protobuf's core internal target +# @com_google_protobuf//python:protobuf_python will start depending on a bunch +# of cc_binary shared libraries artifacts, which will mess with how we link +# protobuf dependencies ourselves. By default this value is false, but some +# projects enable it, which we don't want here. +# common --define=use_fast_cpp_protos=true -build:avx_posix --copt=-mavx -build:avx_posix --host_copt=-mavx +common:nonccl --define=no_nccl_support=true -build:native_arch_posix --copt=-march=native -build:native_arch_posix --host_copt=-march=native +common:posix --copt=-fvisibility=hidden +common:posix --copt=-Wno-sign-compare +common:posix --cxxopt=-std=c++17 +common:posix --host_cxxopt=-std=c++17 -build:avx_linux --copt=-mavx -build:avx_linux --host_copt=-mavx +common:avx_posix --copt=-mavx +common:avx_posix --host_copt=-mavx -build:avx_windows --copt=/arch:AVX +common:native_arch_posix --copt=-march=native +common:native_arch_posix --host_copt=-march=native -build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=1 +common:avx_linux --copt=-mavx +common:avx_linux --host_copt=-mavx + +common:avx_windows --copt=/arch:AVX + +common:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=1 # Config setting to build oneDNN with Compute Library for the Arm Architecture (ACL). -build:mkl_aarch64_threadpool --define=build_with_mkl_aarch64=true -build:mkl_aarch64_threadpool --@compute_library//:openmp=false -build:mkl_aarch64_threadpool -c opt +common:mkl_aarch64_threadpool --define=build_with_mkl_aarch64=true +common:mkl_aarch64_threadpool --@compute_library//:openmp=false +common:mkl_aarch64_threadpool -c opt # Disable clang extention that rejects type definitions within offsetof. # This was added in clang-16 by https://reviews.llvm.org/D133574. # Can be removed once upb is updated, since a type definition is used within # offset of in the current version of ubp. # See https://github.com/protocolbuffers/upb/blob/9effcbcb27f0a665f9f345030188c0b291e32482/upb/upb.c#L183. -build:clang --copt=-Wno-gnu-offsetof-extensions +common:clang --copt=-Wno-gnu-offsetof-extensions # Disable clang extention that rejects unknown arguments. -build:clang --copt=-Qunused-arguments +common:clang --copt=-Qunused-arguments # Error on struct/class mismatches, since this causes link failures on Windows. -build:clang --copt=-Werror=mismatched-tags +common:clang --copt=-Werror=mismatched-tags +# Required when building with clang>=19, see jax-ml/jax#27091 +common:clang --copt=-Wno-error=c23-extensions # Configs for CUDA -build:cuda --repo_env TF_NEED_CUDA=1 -build:cuda --repo_env TF_NCCL_USE_STUB=1 +common:cuda_v12 --repo_env=HERMETIC_CUDA_VERSION="12.9.1" +common:cuda_v12 --repo_env=HERMETIC_CUDNN_VERSION="9.8.0" +common:cuda_v12 --repo_env=HERMETIC_NVSHMEM_VERSION="3.3.9" +common:cuda_v12 --repo_env=HERMETIC_NCCL_VERSION="2.27.7" # "sm" means we emit only cubin, which is forward compatible within a GPU generation. # "compute" means we emit both cubin and PTX, which is larger but also forward compatible to future GPU generations. -build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90" -build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain -build:cuda --@local_config_cuda//:enable_cuda +common:cuda_v12 --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,sm_90,sm_100,compute_120" -# Default hermetic CUDA and CUDNN versions. -build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2" -build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1" -build:cuda --@local_config_cuda//cuda:include_cuda_libs=true +common:cuda_v13 --repo_env=HERMETIC_CUDA_VERSION="13.0.0" +common:cuda_v13 --repo_env=HERMETIC_CUDNN_VERSION="9.12.0" +common:cuda_v13 --repo_env=HERMETIC_NVSHMEM_VERSION="3.3.20" +common:cuda_v13 --repo_env=HERMETIC_NCCL_VERSION="2.27.7" +common:cuda_v13 --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_75,sm_80,sm_90,sm_100,compute_120" -# This config is used for building targets with CUDA libraries from stubs. -build:cuda_libraries_from_stubs --@local_config_cuda//cuda:include_cuda_libs=false +common:cuda_common --repo_env TF_NEED_CUDA=1 +common:cuda_common --repo_env TF_NCCL_USE_STUB=1 +common:cuda_common --@local_config_cuda//:enable_cuda +common:cuda_common --@local_config_cuda//cuda:include_cuda_libs=true # Force the linker to set RPATH, not RUNPATH. When resolving dynamic libraries, # ld.so prefers in order: RPATH, LD_LIBRARY_PATH, RUNPATH. JAX sets RPATH to @@ -160,40 +203,56 @@ build:cuda_libraries_from_stubs --@local_config_cuda//cuda:include_cuda_libs=fal # via LD_LIBRARY_PATH, if the nvidia-... pip packages are installed. This is # acceptable, because the workaround is "remove the nvidia-..." pip packages. # The list of CUDA pip packages that JAX depends on are present in setup.py. -build:cuda --linkopt=-Wl,--disable-new-dtags +common:cuda_common --linkopt=-Wl,--disable-new-dtags + +common:cuda12 --config=cuda_common +common:cuda12 --config=cuda_v12 + +common:cuda13 --config=cuda_common +common:cuda13 --config=cuda_v13 -# Build CUDA and other C++ targets with Clang -build:build_cuda_with_clang --@local_config_cuda//:cuda_compiler=clang +# Alias for backward compatibility. +common:cuda --config=cuda12 -# Build CUDA with NVCC and other C++ targets with Clang -build:build_cuda_with_nvcc --action_env=TF_NVCC_CLANG="1" -build:build_cuda_with_nvcc --@local_config_cuda//:cuda_compiler=nvcc +# This config is used for building targets with CUDA/NVSHMEM libraries from stubs. +common:cuda_libraries_from_stubs --@local_config_cuda//cuda:include_cuda_libs=false + +common:hermetic_cuda_umd --@cuda_driver//:include_cuda_umd_libs=true + +# common CUDA and other C++ targets with Clang +common:build_cuda_with_clang --@local_config_cuda//:cuda_compiler=clang + +# common CUDA with NVCC and other C++ targets with Clang +common:build_cuda_with_nvcc --action_env=TF_NVCC_CLANG="1" +common:build_cuda_with_nvcc --@local_config_cuda//:cuda_compiler=nvcc # Requires MSVC and LLVM to be installed -build:win_clang --extra_toolchains=@local_config_cc//:cc-toolchain-x64_windows-clang-cl -build:win_clang --extra_execution_platforms=//jax/tools/toolchains:x64_windows-clang-cl -build:win_clang --compiler=clang-cl +common:win_clang --config=clang_local +common:win_clang --extra_toolchains=@local_config_cc//:cc-toolchain-x64_windows-clang-cl +common:win_clang --extra_execution_platforms=//jax/tools/toolchains:x64_windows-clang-cl +common:win_clang --compiler=clang-cl -build:rocm_base --crosstool_top=@local_config_rocm//crosstool:toolchain -build:rocm_base --define=using_rocm=true --define=using_rocm_hipcc=true -build:rocm_base --repo_env TF_NEED_ROCM=1 -build:rocm_base --action_env TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908,gfx90a,gfx940,gfx941,gfx942,gfx1030,gfx1100,gfx1200,gfx1201" +common:rocm_base --config=clang_local +common:rocm_base --crosstool_top=@local_config_rocm//crosstool:toolchain +common:rocm_base --define=using_rocm=true --define=using_rocm_hipcc=true +common:rocm_base --repo_env TF_NEED_ROCM=1 +common:rocm_base --action_env TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908,gfx90a,gfx940,gfx941,gfx942,gfx1030,gfx1100,gfx1200,gfx1201" # Build with hipcc for ROCm and clang for the host. -build:rocm --config=rocm_base -build:rocm --action_env=TF_ROCM_CLANG="1" -build:rocm --action_env=CLANG_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" -build:rocm --copt=-Wno-gnu-offsetof-extensions -build:rocm --copt=-Qunused-arguments -build:rocm --action_env=TF_HIPCC_CLANG="1" +common:rocm --config=rocm_base +common:rocm --action_env=TF_ROCM_CLANG="1" +common:rocm --action_env=CLANG_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" +common:rocm --copt=-Wno-gnu-offsetof-extensions +common:rocm --copt=-Qunused-arguments +common:rocm --action_env=TF_HIPCC_CLANG="1" # ############################################################################# # Cache options below. # ############################################################################# # Public read-only cache -build:public_cache --remote_cache="https://storage.googleapis.com/jax-bazel-cache/" --remote_upload_local_results=false +common:public_cache --remote_cache="https://storage.googleapis.com/jax-bazel-cache/" --remote_upload_local_results=false # Cache pushes are limited to JAX's CI system. -build:public_cache_push --config=public_cache --remote_upload_local_results=true --google_default_credentials +common:public_cache_push --config=public_cache --remote_upload_local_results=true --google_default_credentials # Note: the following cache configs are deprecated and will be removed soon. # Public read-only cache for Mac builds. JAX uses a GCS bucket to store cache @@ -201,10 +260,10 @@ build:public_cache_push --config=public_cache --remote_upload_local_results=true # should be able to read from this cache and potentially see a speedup. The # "oct2023" in the URL is just the date when the bucket was created and can be # disregarded. It still contains the latest cache that is being used. -build:macos_cache --remote_cache="https://storage.googleapis.com/tensorflow-macos-bazel-cache/oct2023" --remote_upload_local_results=false +common:macos_cache --remote_cache="https://storage.googleapis.com/tensorflow-macos-bazel-cache/oct2023" --remote_upload_local_results=false # Cache pushes are limited to JAX's CI system. -build:macos_cache_push --config=macos_cache --remote_upload_local_results=true --google_default_credentials +common:macos_cache_push --config=macos_cache --remote_upload_local_results=true --google_default_credentials # ############################################################################# # CI Build config options below. @@ -212,58 +271,84 @@ build:macos_cache_push --config=macos_cache --remote_upload_local_results=true - # Bazel tests. # ############################################################################# # Linux x86 CI configs -build:ci_linux_x86_64 --config=avx_linux --config=avx_posix -build:ci_linux_x86_64 --config=mkl_open_source_only -build:ci_linux_x86_64 --config=clang --verbose_failures=true -build:ci_linux_x86_64 --color=yes +common:ci_linux_x86_64 --config=avx_linux --config=avx_posix +common:ci_linux_x86_64 --config=mkl_open_source_only +common:ci_linux_x86_64 --config=clang --verbose_failures=true +common:ci_linux_x86_64 --color=yes +# Deprecated CI config with non-hermetic toolchains. # TODO(b/356695103): We do not have a CPU only toolchain so we use the CUDA # toolchain for both CPU and GPU builds. -build:ci_linux_x86_64 --host_crosstool_top="@local_config_cuda//crosstool:toolchain" -build:ci_linux_x86_64 --crosstool_top="@local_config_cuda//crosstool:toolchain" -build:ci_linux_x86_64 --extra_toolchains="@local_config_cuda//crosstool:toolchain-linux-x86_64" -build:ci_linux_x86_64 --repo_env=TF_SYSROOT="/dt9" +common:ci_linux_x86_64_clang_local --config=ci_linux_x86_64 +common:ci_linux_x86_64_clang_local --config=clang_local +common:ci_linux_x86_64_clang_local --repo_env=TF_SYSROOT="/dt9" # Clang path needs to be set for remote toolchain to be configured correctly. -build:ci_linux_x86_64 --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" +common:ci_linux_x86_64_clang_local --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" + +common:ci_linux_x86_64_cuda_common --config=build_cuda_with_nvcc +common:ci_linux_x86_64_cuda_common --config=ci_linux_x86_64 # The toolchain in `--config=cuda` needs to be read before the toolchain in # `--config=ci_linux_x86_64`. Otherwise, we run into issues with manylinux # compliance. -build:ci_linux_x86_64_cuda --config=cuda --config=build_cuda_with_nvcc -build:ci_linux_x86_64_cuda --config=ci_linux_x86_64 +common:ci_linux_x86_64_cuda12 --config=cuda12 --config=ci_linux_x86_64_cuda_common +# Alias for backward compatibility. +common:ci_linux_x86_64_cuda --config=ci_linux_x86_64_cuda12 + +common:ci_linux_x86_64_cuda13 --config=cuda13 --config=ci_linux_x86_64_cuda_common # Linux Aarch64 CI configs -build:ci_linux_aarch64_base --config=clang --verbose_failures=true -build:ci_linux_aarch64_base --action_env=TF_SYSROOT="/dt10" -build:ci_linux_aarch64_base --color=yes - -build:ci_linux_aarch64 --config=ci_linux_aarch64_base -build:ci_linux_aarch64 --host_crosstool_top="@ml2014_clang_aarch64_config_aarch64//crosstool:toolchain" -build:ci_linux_aarch64 --crosstool_top="@ml2014_clang_aarch64_config_aarch64//crosstool:toolchain" - -# CUDA configs for Linux Aarch64 do not pass in the crosstool_top flag from -# above because the Aarch64 toolchain rule does not support building with NVCC. -# Instead, we use `@local_config_cuda//crosstool:toolchain` from --config=cuda -# and set `CLANG_CUDA_COMPILER_PATH` to define the toolchain so that we can -# use Clang for the C++ targets and NVCC to build CUDA targets. -build:ci_linux_aarch64_cuda --config=ci_linux_aarch64_base -build:ci_linux_aarch64_cuda --config=cuda --config=build_cuda_with_nvcc -build:ci_linux_aarch64_cuda --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" +common:ci_linux_aarch64_base --config=clang --verbose_failures=true +common:ci_linux_aarch64_base --color=yes + +# This appears to help avoid a timeout in CI for linalg_test. +common:ci_linux_aarch64_base --test_env=OMP_NUM_THREADS=8 + +common:ci_linux_aarch64 --config=ci_linux_aarch64_base + +common:ci_linux_aarch64_cuda_common --config=ci_linux_aarch64_base +common:ci_linux_aarch64_cuda_common --config=build_cuda_with_nvcc + +# Aarch64 builds use CLANG_CUDA_COMPILER_PATH, which allows Clang to compile C++ +# targets and NVCC to compile CUDA targets. +common:ci_linux_aarch64_cuda12 --config=ci_linux_aarch64_cuda_common +common:ci_linux_aarch64_cuda12 --config=cuda12 +# Alias for backward compatibility. +common:ci_linux_aarch64_cuda --config=ci_linux_aarch64_cuda12 + +common:ci_linux_aarch64_cuda13 --config=ci_linux_aarch64_cuda_common +common:ci_linux_aarch64_cuda13 --config=cuda13 + +# Deprecated CI config with non-hermetic toolchains. +common:ci_linux_aarch64_clang_local --config=clang_local +common:ci_linux_aarch64_clang_local --config=ci_linux_aarch64_base + +common:ci_linux_aarch64_cuda_common_clang_local --config=cuda_clang_local +common:ci_linux_aarch64_cuda_common_clang_local --config=ci_linux_aarch64_cuda_common +common:ci_linux_aarch64_cuda_common_clang_local --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" + +common:ci_linux_aarch64_cuda12_clang_local --config=ci_linux_aarch64_cuda_common_clang_local +common:ci_linux_aarch64_cuda12_clang_local --config=cuda12 + +common:ci_linux_aarch64_cuda13_clang_local --config=ci_linux_aarch64_cuda_common_clang_local +common:ci_linux_aarch64_cuda13_clang_local --config=cuda13 # Mac Arm64 CI configs -build:ci_darwin_arm64 --macos_minimum_os=11.0 -build:ci_darwin_arm64 --config=macos_cache_push -build:ci_darwin_arm64 --verbose_failures=true -build:ci_darwin_arm64 --color=yes +common:ci_darwin_arm64 --config=clang_local +common:ci_darwin_arm64 --macos_minimum_os=11.0 +common:ci_darwin_arm64 --config=macos_cache_push +common:ci_darwin_arm64 --verbose_failures=true +common:ci_darwin_arm64 --color=yes # Windows x86 CI configs -build:ci_windows_amd64 --config=avx_windows -build:ci_windows_amd64 --compiler=clang-cl --config=clang --verbose_failures=true -build:ci_windows_amd64 --crosstool_top="@xla//tools/toolchains/win/20240424:toolchain" -build:ci_windows_amd64 --extra_toolchains="@xla//tools/toolchains/win/20240424:cc-toolchain-x64_windows-clang-cl" -build:ci_windows_amd64 --host_linkopt=/FORCE:MULTIPLE --linkopt=/FORCE:MULTIPLE -build:ci_windows_amd64 --color=yes +common:ci_windows_amd64 --config=clang_local +common:ci_windows_amd64 --config=avx_windows +common:ci_windows_amd64 --compiler=clang-cl --config=clang --verbose_failures=true +common:ci_windows_amd64 --crosstool_top="@xla//tools/toolchains/win2022/20241118:toolchain" +common:ci_windows_amd64 --extra_toolchains="@xla//tools/toolchains/win2022/20241118:cc-toolchain-x64_windows-clang-cl" +common:ci_windows_amd64 --host_linkopt=/FORCE:MULTIPLE --linkopt=/FORCE:MULTIPLE +common:ci_windows_amd64 --color=yes # ############################################################################# # RBE config options below. These inherit the CI configs above and set the @@ -274,73 +359,98 @@ build:ci_windows_amd64 --color=yes common --experimental_repo_remote_exec # Allow creation of resultstore URLs for any bazel invocation -build:resultstore --google_default_credentials -build:resultstore --bes_backend=buildeventservice.googleapis.com -build:resultstore --bes_instance_name="tensorflow-testing" -build:resultstore --bes_results_url="https://source.cloud.google.com/results/invocations" -build:resultstore --bes_timeout=600s + +common:resultstore_base --google_default_credentials +common:resultstore_base --bes_backend=buildeventservice.googleapis.com +common:resultstore_base --bes_timeout=600s +common:resultstore_base --bes_results_url="https://source.cloud.google.com/results/invocations" + +common:resultstore --config=resultstore_base +common:resultstore --bes_instance_name="tensorflow-testing" # Configs for RBE cache. When using resultstore, we need to use these configs # as well to ensure that the logs that get uploaded to resultstore can be read -# without any errors. -build:rbe_cache --remote_cache=remotebuildexecution.googleapis.com -build:rbe_cache --remote_instance_name=projects/tensorflow-testing/instances/default_instance - -build:rbe --config=resultstore -build:rbe --repo_env=BAZEL_DO_NOT_DETECT_CPP_TOOLCHAIN=1 -build:rbe --define=EXECUTOR=remote -build:rbe --flaky_test_attempts=3 -build:rbe --jobs=200 -build:rbe --remote_executor=grpcs://remotebuildexecution.googleapis.com -build:rbe --remote_timeout=3600 -build:rbe --spawn_strategy=remote,worker,standalone,local +# without any errors. Write is limited to CI +common:ci_rbe_cache --config=resultstore_base +common:ci_rbe_cache --remote_upload_local_results=true +common:ci_rbe_cache --bes_instance_name="ml-oss-rbe-testing" +common:ci_rbe_cache --remote_cache="grpcs://remotebuildexecution.googleapis.com" +common:ci_rbe_cache --remote_instance_name=projects/ml-oss-rbe-testing/instances/default_instance + +common:use_tar_archive_files --repo_env=USE_CUDA_TAR_ARCHIVE_FILES=1 +common:use_tar_archive_files --repo_env=USE_NVSHMEM_TAR_ARCHIVE_FILES=1 +common:use_tar_archive_files --repo_env=USE_LLVM_TAR_ARCHIVE_FILES=1 +common:use_tar_archive_files --repo_env=USE_MIRRORED_TAR_ARCHIVE_FILES=1 + +common:rbe --config=resultstore +common:rbe --repo_env=BAZEL_DO_NOT_DETECT_CPP_TOOLCHAIN=1 +common:rbe --define=EXECUTOR=remote +common:rbe --flaky_test_attempts=3 +common:rbe --jobs=200 +common:rbe --remote_executor=grpcs://remotebuildexecution.googleapis.com +common:rbe --remote_timeout=3600 +common:rbe --spawn_strategy=remote,worker,standalone,local # Attempt to minimize the amount of data transfer between bazel and the remote # workers: -build:rbe --remote_download_toplevel +common:rbe --remote_download_toplevel test:rbe --test_env=USER=anon +test:rbe --test_env=IS_JAX_RBE_TESTING=1 + +common:rbe_cpu_pool --repo_env=REMOTE_GPU_TESTING=0 +common:rbe_gpu_pool --repo_env=REMOTE_GPU_TESTING=1 # RBE configs for Linux x86 # Set the remote worker pool common:rbe_linux_x86_64_base --remote_instance_name=projects/tensorflow-testing/instances/default_instance -build:rbe_linux_x86_64_base --config=rbe -build:rbe_linux_x86_64_base --action_env=PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/go/bin" -build:rbe_linux_x86_64_base --linkopt=-lrt -build:rbe_linux_x86_64_base --host_linkopt=-lrt -build:rbe_linux_x86_64_base --linkopt=-lm -build:rbe_linux_x86_64_base --host_linkopt=-lm +common:rbe_linux_x86_64_base --config=rbe +common:rbe_linux_x86_64_base --action_env=PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/go/bin" +common:rbe_linux_x86_64_base --linkopt=-lrt +common:rbe_linux_x86_64_base --host_linkopt=-lrt +common:rbe_linux_x86_64_base --linkopt=-lm +common:rbe_linux_x86_64_base --host_linkopt=-lm # Set the host, execution, and target platform -build:rbe_linux_x86_64_base --host_platform="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" -build:rbe_linux_x86_64_base --extra_execution_platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" -build:rbe_linux_x86_64_base --platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" +common:rbe_linux_x86_64_base --host_platform="@ml_build_config_platform//:platform" +common:rbe_linux_x86_64_base --extra_execution_platforms="@ml_build_config_platform//:platform" +common:rbe_linux_x86_64_base --platforms="@ml_build_config_platform//:platform" + +common:rbe_linux_x86_64 --config=rbe_linux_x86_64_base +common:rbe_linux_x86_64 --config=ci_linux_x86_64 -build:rbe_linux_x86_64 --config=rbe_linux_x86_64_base -build:rbe_linux_x86_64 --config=ci_linux_x86_64 +common:rbe_linux_x86_64_cuda_common --config=rbe_linux_x86_64_base +common:rbe_linux_x86_64_cuda_common --config=rbe_gpu_pool +# Update UMD version when RBE CUDA driver is updated. +common:rbe_linux_x86_64_cuda_common --repo_env=HERMETIC_CUDA_UMD_VERSION="13.0.1" -build:rbe_linux_x86_64_cuda --config=rbe_linux_x86_64_base -build:rbe_linux_x86_64_cuda --config=ci_linux_x86_64_cuda -build:rbe_linux_x86_64_cuda --repo_env=REMOTE_GPU_TESTING=1 +common:rbe_linux_x86_64_cuda12 --config=rbe_linux_x86_64_cuda_common +common:rbe_linux_x86_64_cuda12 --config=ci_linux_x86_64_cuda12 +# Alias for backward compatibility. +common:rbe_linux_x86_64_cuda --config=rbe_linux_x86_64_cuda12 + +common:rbe_linux_x86_64_cuda13 --config=rbe_linux_x86_64_cuda_common +common:rbe_linux_x86_64_cuda13 --config=ci_linux_x86_64_cuda13 # RBE configs for Windows # Set the remote worker pool common:rbe_windows_amd64 --remote_instance_name=projects/tensorflow-testing/instances/windows -build:rbe_windows_amd64 --config=rbe +common:rbe_windows_amd64 --config=clang_local +common:rbe_windows_amd64 --config=rbe # Set the host, execution, and target platform -build:rbe_windows_amd64 --host_platform="@xla//tools/toolchains/win:x64_windows-clang-cl" -build:rbe_windows_amd64 --extra_execution_platforms="@xla//tools/toolchains/win:x64_windows-clang-cl" -build:rbe_windows_amd64 --platforms="@xla//tools/toolchains/win:x64_windows-clang-cl" +common:rbe_windows_amd64 --host_platform="@xla//tools/toolchains/win2022:windows_ltsc2022_clang" +common:rbe_windows_amd64 --extra_execution_platforms="@xla//tools/toolchains/win2022:windows_ltsc2022_clang" +common:rbe_windows_amd64 --platforms="@xla//tools/toolchains/win2022:windows_ltsc2022_clang" -build:rbe_windows_amd64 --shell_executable=C:\\tools\\msys64\\usr\\bin\\bash.exe -build:rbe_windows_amd64 --enable_runfiles -build:rbe_windows_amd64 --define=override_eigen_strong_inline=true +common:rbe_windows_amd64 --shell_executable=C:\\tools\\msys64\\usr\\bin\\bash.exe +common:rbe_windows_amd64 --enable_runfiles +common:rbe_windows_amd64 --define=override_eigen_strong_inline=true # Don't build the python zip archive in the RBE build. -build:rbe_windows_amd64 --nobuild_python_zip +common:rbe_windows_amd64 --nobuild_python_zip -build:rbe_windows_amd64 --config=ci_windows_amd64 +common:rbe_windows_amd64 --config=ci_windows_amd64 # ############################################################################# # Cross-compile config options below. Native RBE support does not exist for @@ -352,47 +462,51 @@ build:rbe_windows_amd64 --config=ci_windows_amd64 # flags seem to be actually used to specify the execution platform details. It # seems it is this way because these flags are old and predate the distinction # between host and execution platform. -build:cross_compile_base --host_cpu=k8 -build:cross_compile_base --host_crosstool_top=@xla//tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite -build:cross_compile_base --extra_execution_platforms=@xla//tools/toolchains/cross_compile/config:linux_x86_64 +common:cross_compile_base --config=clang_local +common:cross_compile_base --host_cpu=k8 +common:cross_compile_base --host_crosstool_top=@xla//tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite +common:cross_compile_base --extra_execution_platforms=@xla//tools/toolchains/cross_compile/config:linux_x86_64 # Linux Aarch64 -build:cross_compile_linux_aarch64 --config=cross_compile_base +common:cross_compile_linux_aarch64 --config=cross_compile_base # Set the target CPU to Aarch64 -build:cross_compile_linux_aarch64 --platforms=@xla//tools/toolchains/cross_compile/config:linux_aarch64 -build:cross_compile_linux_aarch64 --cpu=aarch64 -build:cross_compile_linux_aarch64 --crosstool_top=@xla//tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite +common:cross_compile_linux_aarch64 --platforms=@xla//tools/toolchains/cross_compile/config:linux_aarch64 +common:cross_compile_linux_aarch64 --cpu=aarch64 +common:cross_compile_linux_aarch64 --crosstool_top=@xla//tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite -build:rbe_cross_compile_base --config=rbe -build:rbe_cross_compile_base --remote_instance_name=projects/tensorflow-testing/instances/default_instance +common:rbe_cross_compile_base --config=rbe +common:rbe_cross_compile_base --remote_instance_name=projects/tensorflow-testing/instances/default_instance # RBE cross-compile configs for Linux Aarch64 -build:rbe_cross_compile_linux_aarch64 --config=cross_compile_linux_aarch64 -build:rbe_cross_compile_linux_aarch64 --config=rbe_cross_compile_base +common:rbe_cross_compile_linux_aarch64 --config=cross_compile_linux_aarch64 +common:rbe_cross_compile_linux_aarch64 --config=rbe_cross_compile_base + +# Avoids a timeout in linalg_test on ARM. +common:rbe_cross_compile_linux_aarch64 --test_env=OMP_NUM_THREADS=8 # Mac x86 -build:cross_compile_darwin_x86_64 --config=cross_compile_base -build:cross_compile_darwin_x86_64 --config=nonccl +common:cross_compile_darwin_x86_64 --config=cross_compile_base +common:cross_compile_darwin_x86_64 --config=nonccl # Target Catalina (10.15) as the minimum supported OS -build:cross_compile_darwin_x86_64 --action_env MACOSX_DEPLOYMENT_TARGET=10.15 +common:cross_compile_darwin_x86_64 --action_env MACOSX_DEPLOYMENT_TARGET=10.15 # Set the target CPU to Darwin x86 -build:cross_compile_darwin_x86_64 --platforms=@xla//tools/toolchains/cross_compile/config:darwin_x86_64 -build:cross_compile_darwin_x86_64 --cpu=darwin -build:cross_compile_darwin_x86_64 --crosstool_top=@xla//tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite +common:cross_compile_darwin_x86_64 --platforms=@xla//tools/toolchains/cross_compile/config:darwin_x86_64 +common:cross_compile_darwin_x86_64 --cpu=darwin +common:cross_compile_darwin_x86_64 --crosstool_top=@xla//tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite # When RBE cross-compiling for macOS, we need to explicitly register the # toolchain. Otherwise, oddly, RBE complains that a "docker container must be # specified". -build:cross_compile_darwin_x86_64 --extra_toolchains=@xla//tools/toolchains/cross_compile/config:macos-x86-cross-compile-cc-toolchain +common:cross_compile_darwin_x86_64 --extra_toolchains=@xla//tools/toolchains/cross_compile/config:macos-x86-cross-compile-cc-toolchain # Map --platforms=darwin_x86_64 to --cpu=darwin and vice-versa to make selects() # and transistions that use these flags work. The flag --platform_mappings needs # to be set to a file that exists relative to the package path roots. -build:cross_compile_darwin_x86_64 --platform_mappings=platform_mappings +common:cross_compile_darwin_x86_64 --platform_mappings=platform_mappings # RBE cross-compile configs for Darwin x86 -build:rbe_cross_compile_darwin_x86_64 --config=cross_compile_darwin_x86_64 -build:rbe_cross_compile_darwin_x86_64 --config=rbe_cross_compile_base +common:rbe_cross_compile_darwin_x86_64 --config=cross_compile_darwin_x86_64 +common:rbe_cross_compile_darwin_x86_64 --config=rbe_cross_compile_base ############################################################################# # Some configs to make getting some forms of debug builds. In general, the @@ -401,16 +515,16 @@ build:rbe_cross_compile_darwin_x86_64 --config=rbe_cross_compile_base # Or try 'debug' to get a build with assertions enabled and minimal # optimizations. # Include these in a local .bazelrc.user file as: -# build --config=debug_symbols +# common --config=debug_symbols # Or: -# build --config=debug +# common --config=debug # # Additional files can be opted in for debug symbols by adding patterns # to a per_file_copt similar to below. ############################################################################# -build:debug_symbols --strip=never --per_file_copt="xla/pjrt|xla/python@-g3" -build:debug --config debug_symbols -c fastbuild +common:debug_symbols --strip=never --per_file_copt="xla/pjrt|xla/python@-g3" +common:debug --config=debug_symbols -c fastbuild # Load `.jax_configure.bazelrc` file written by build.py try-import %workspace%/.jax_configure.bazelrc diff --git a/.bazelversion b/.bazelversion index 815da58b7a9e..1985849fb589 100644 --- a/.bazelversion +++ b/.bazelversion @@ -1 +1 @@ -7.4.1 +7.7.0 diff --git a/.editorconfig b/.editorconfig index eb7b2cbd5e03..72666216605a 100644 --- a/.editorconfig +++ b/.editorconfig @@ -5,18 +5,7 @@ indent_style = space end_of_line = lf trim_trailing_whitespace = true insert_final_newline = true - -[*.py] -max_line_length = 79 -indent_size = 2 - -[*.rst] -max_line_length = 79 indent_size = 2 -[*.md] -max_line_length = 79 -indent_size = 2 - -[*.yml] -indent_size = 2 +[*.py] +max_line_length = 80 diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml index 628310519b66..1f8c2b2ac254 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.yml +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -24,7 +24,7 @@ body: [issue search]: https://github.com/jax-ml/jax/search?q=is%3Aissue&type=issues - [Raw report]: http://github.com/jax-ml/jax/issues/new + [Raw report]: https://github.com/jax-ml/jax/issues/new?template=none - type: textarea attributes: label: Description diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 000000000000..f3d0c8252821 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,11 @@ + \ No newline at end of file diff --git a/.github/actionlint.yaml b/.github/actionlint.yaml new file mode 100644 index 000000000000..6307be466509 --- /dev/null +++ b/.github/actionlint.yaml @@ -0,0 +1,20 @@ +# Configuration related to self-hosted runner. +self-hosted-runner: + labels: + - "linux-x86-n4-16" # Linux X86 runner using the 16 vcpu n4-standard-16 machine. + - "linux-x86-n4-32" # Linux X86 runner using the 32 vcpu n4-standard-32 machine. + - "linux-x86-n4-64" # Linux X86 runner using the 64 vcpu n2-standard-64 machine. + - "linux-x86-g2-16-l4-1gpu" # Linux X86 GPU runner using g2-standard-16 machine with 1 NVIDIA L4 GPU attached. + - "linux-x86-g2-48-l4-4gpu" # Linux X86 GPU runner using g2-standard-48 machine with 4 NVIDIA L4 GPUs attached. + - "linux-x86-ct5lp-224-8tpu" # Linux X86 TPU runner using ct5lp-hightpu-8t machine with 2x4 topology. + - "linux-arm64-c4a-16" # Linux ARM64 CPU Runner using the 16 vcpu c4a-standard-16 machine. + - "linux-arm64-c4a-64" # Linux ARM64 CPU Runner using the 64 vcpu c4a-standard-64 machine. + - "windows-x86-n2-16" # Windows X86 runner using n2-standard-16 machine. + - "windows-x86-n2-64" # Windows X86 runner using n2-standard-64 machine. + - "linux-x86-a4-224-b200-1gpu" # Linux X86 GPU runner using 1 B200 GPU and 1/8 the resources of a a4-highgpu-8g machine + - "linux-x86-a3-8g-h100-8gpu" # Linux X86 GPU runner using a3-highgpu-8g machine with 8 NVIDIA H100 GPUs attached. + - "linux-x86-ct6e-180-8tpu" # Linux X86 TPU runner using ct6e-hightpu-8t machine with 2x4 topology. + - "linux-x86-ct6e-180-4tpu" # Linux X86 TPU runner using ct6e-hightpu-4t machine with 2x2 topology. + - "linux-x86-ct4p-240-4tpu" # Linux X86 TPU runner using ct4p-hightpu-4t machine with 2x2x1 topology. + - "linux-x86-tpu7x-224-4tpu" # Linux X86 TPU runner using tpu7x-224 machine with 4 TPU chips (8 cores) and 2x2x1 topology. + - "linux-x86_64-cirrascale-64-8gpu-amd-mi250" # AMD runner diff --git a/.github/actions/download-jax-cpu-wheels/action.yml b/.github/actions/download-jax-cpu-wheels/action.yml new file mode 100644 index 000000000000..ef1d57fed5ca --- /dev/null +++ b/.github/actions/download-jax-cpu-wheels/action.yml @@ -0,0 +1,105 @@ +# Composite action to download the jax and jaxlib wheels +name: Download JAX CPU wheels + +inputs: + runner: + description: "Which runner type should the wheels be downloaded for?" + type: string + default: "linux-x86-n4-16" + python: + description: "Which python version should the artifact be downloaded for?" + required: true + type: string + jaxlib-version: + description: "Which jaxlib version to download? (head/pypi_latest)" + type: string + default: "head" + skip-download-jaxlib-from-gcs: + description: "Whether to skip downloading the jaxlib artifact from GCS (e.g for testing a jax only release)" + default: '0' + type: string + gcs_download_uri: + description: "GCS location prefix from where the artifacts should be downloaded" + default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + type: string +permissions: {} +runs: + using: "composite" + + steps: + # Note that certain envs such as JAXCI_HERMETIC_PYTHON_VERSION are set by the calling workflow. + - name: Set env vars for use in artifact download URL + shell: bash + run: | + os=$(uname -s | awk '{print tolower($0)}') + arch=$(uname -m) + + # Adjust os and arch for Windows + if [[ $os =~ "msys_nt" ]] && [[ $arch =~ "x86_64" ]]; then + os="win" + arch="amd64" + fi + + # Get the major and minor version of Python. + # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310 + # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.13-nogil, then python_major_minor=313t + python_major_minor=$(echo "${JAXCI_HERMETIC_PYTHON_VERSION//-nogil/t}" | tr -d '.') + + echo "OS=${os}" >> $GITHUB_ENV + echo "ARCH=${arch}" >> $GITHUB_ENV + # Python wheels follow a naming convention: standard wheels use the pattern + # `*-cp-cp-*`, while free-threaded wheels use + # `*-cp-cpt-*`. + echo "PYTHON_MAJOR_MINOR=cp${python_major_minor%t}-cp${python_major_minor}-" >> $GITHUB_ENV + - name: Download wheels from GCS (non-Windows runs) + shell: bash + id: download-wheel-artifacts-nw + # Set continue-on-error to true to prevent actions from failing the workflow if this step + # fails. Instead, we verify the outcome in the step below so that we can print a more + # informative error message. + continue-on-error: true + if: ${{ !contains(inputs.runner, 'windows-x86') }} + run: | + mkdir -p $(pwd)/dist + gcloud storage cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ + + if [[ "${{ inputs.skip-download-jaxlib-from-gcs }}" == "1" ]]; then + echo "JAX only release. Only downloading the jax wheel from the release bucket." + else + if [[ ${{ inputs.jaxlib-version }} == "head" ]]; then + gcloud storage cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ + elif [[ ${{ inputs.jaxlib-version }} == "pypi_latest" ]]; then + PYTHON=python${{ inputs.python }} + $PYTHON -m pip download jaxlib --dest $(pwd)/dist/ + else + echo "Invalid jaxlib version: ${{ inputs.jaxlib-version }}" + exit 1 + fi + fi + - name: Download wheels from GCS (Windows runs) + shell: cmd + id: download-wheel-artifacts-w + # Set continue-on-error to true to prevent actions from failing the workflow if this step + # fails. Instead, we verify the outcome in step below so that we can print a more + # informative error message. + continue-on-error: true + if: ${{ contains(inputs.runner, 'windows-x86') }} + run: | + mkdir dist + @REM Use `call` so that we can run sequential gcloud storage commands on Windows + @REM See https://github.com/GoogleCloudPlatform/gsutil/issues/233#issuecomment-196150652 + call gcloud storage cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl dist/ + + if "${{ inputs.skip-download-jaxlib-from-gcs }}"=="1" ( + echo "JAX only release. Only downloading the jax wheel from the release bucket." + ) else ( + call gcloud storage cp -r "${{ inputs.gcs_download_uri }}/jaxlib*%PYTHON_MAJOR_MINOR%*%OS%*%ARCH%*.whl" dist/ + ) + - name: Skip the test run if the wheel artifacts were not downloaded successfully + shell: bash + if: steps.download-wheel-artifacts-nw.outcome == 'failure' || steps.download-wheel-artifacts-w.outcome == 'failure' + run: | + echo "Failed to download wheel artifacts from GCS. Please check if the wheels were" + echo "built successfully by the artifact build jobs and are available in the GCS bucket." + echo "Skipping the test run." + exit 1 diff --git a/.github/actions/download-jax-cuda-wheels/action.yml b/.github/actions/download-jax-cuda-wheels/action.yml new file mode 100644 index 000000000000..f40bd18e31f5 --- /dev/null +++ b/.github/actions/download-jax-cuda-wheels/action.yml @@ -0,0 +1,100 @@ +# Composite action to download the jax, jaxlib, and the CUDA plugin wheels +name: Download JAX CUDA wheels + +inputs: + python: + description: "Which python version should the artifact be downloaded for?" + type: string + required: true + cuda-version: + description: "Which cuda version should the artifact be downloaded for?" + type: string + default: "12" + use-nvidia-pip-wheels: + description: "Whether to download Nvidia CUDA packages from PyPI?" + type: boolean + default: false + jaxlib-version: + description: "Which jaxlib version to download? (head/pypi_latest)" + type: string + default: "head" + download-jax-from-gcs: + description: "Whether to download the jax wheel from GCS" + default: '1' + type: string + skip-download-jaxlib-and-cuda-plugins-from-gcs: + description: "Whether to skip downloading the jaxlib and cuda plugins from GCS (e.g for testing a jax only release)" + default: '0' + type: string + gcs_download_uri: + description: "GCS location prefix from where the artifacts should be downloaded" + default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + type: string +permissions: {} +runs: + using: "composite" + + steps: + # Note that certain envs such as JAXCI_HERMETIC_PYTHON_VERSION are set by the calling workflow. + - name: Set env vars for use in artifact download URL + shell: bash + run: | + os=$(uname -s | awk '{print tolower($0)}') + arch=$(uname -m) + + # Get the major and minor version of Python. + # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.11, then python_major_minor=311 + # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.13-nogil, then python_major_minor=313t + python_major_minor=$(echo "${JAXCI_HERMETIC_PYTHON_VERSION//-nogil/t}" | tr -d '.') + + echo "OS=${os}" >> $GITHUB_ENV + echo "ARCH=${arch}" >> $GITHUB_ENV + # Python wheels follow a naming convention: standard wheels use the pattern + # `*-cp-cp-*`, while free-threaded wheels use + # `*-cp-cpt-*`. + echo "PYTHON_MAJOR_MINOR=cp${python_major_minor%t}-cp${python_major_minor}-" >> $GITHUB_ENV + + # Get the CUDA major version only + full_cuda_version="${{ inputs.cuda-version }}" + echo "JAXCI_CUDA_VERSION=${full_cuda_version%%.*}" >> $GITHUB_ENV + - name: Download wheels + shell: bash + id: download-wheel-artifacts + # Set continue-on-error to true to prevent actions from failing the workflow if this step + # fails. Instead, we verify the outcome in the next step so that we can print a more + # informative error message. + continue-on-error: true + run: | + mkdir -p $(pwd)/dist + if [[ "${{ inputs.download-jax-from-gcs }}" == "1" ]]; then + gcloud storage cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ + else + echo "JAX wheel won't be downloaded, only jaxlib pre-built wheel is tested." + fi + + # Do not download the jaxlib and CUDA plugin artifacts if we are testing a jax only + # release. + if [[ "${{ inputs.skip-download-jaxlib-and-cuda-plugins-from-gcs }}" == "1" ]]; then + echo "JAX only release. Only downloading the jax wheel from the release bucket." + else + if [[ ${{ inputs.jaxlib-version }} == "head" ]]; then + gcloud storage cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ + gcloud storage cp -r "${{ inputs.gcs_download_uri }}/jax*cuda${JAXCI_CUDA_VERSION}*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ + gcloud storage cp -r "${{ inputs.gcs_download_uri }}/jax*cuda${JAXCI_CUDA_VERSION}*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/ + elif [[ ${{ inputs.jaxlib-version }} == "pypi_latest" ]]; then + PYTHON=python${{ inputs.python }} + $PYTHON -m pip download jaxlib jax-cuda${JAXCI_CUDA_VERSION}-pjrt jax-cuda${JAXCI_CUDA_VERSION}-plugin --dest $(pwd)/dist/ + else + echo "Invalid jaxlib version: ${{ inputs.jaxlib-version }}" + exit 1 + fi + fi + - name: Skip the test run if the wheel artifacts were not downloaded successfully + shell: bash + if: steps.download-wheel-artifacts.outcome == 'failure' + run: | + echo "Failed to download wheel artifacts. Please check if the wheels were" + echo "built successfully by the artifact build jobs and are available in the GCS bucket if + echo "downloading from GCS." + echo "Skipping the test run." + exit 1 diff --git a/.github/workflows/asan.yaml b/.github/workflows/asan.yaml index ea69d92e552e..b4a06f08f741 100644 --- a/.github/workflows/asan.yaml +++ b/.github/workflows/asan.yaml @@ -13,12 +13,17 @@ on: - main paths: - '**/workflows/asan.yaml' +permissions: {} + +env: + UV_DEFAULT_INDEX: "https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple" + PIP_INDEX_URL: "https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple" jobs: asan: # Don't execute in fork due to runner type if: github.repository == 'jax-ml/jax' - runs-on: linux-x86-n2-64 + runs-on: linux-x86-n4-64 container: image: index.docker.io/library/ubuntu@sha256:b359f1067efa76f37863778f7b6d0e8d911e3ee8efa807ad01fbf5dc1ef9006b # ratchet:ubuntu:24.04 strategy: @@ -38,14 +43,16 @@ jobs: zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev curl git \ libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev \ libffi-dev liblzma-dev - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 with: path: jax - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + persist-credentials: false + - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 with: repository: python/cpython path: cpython ref: v3.13.0 + persist-credentials: false - name: Build CPython with ASAN enabled env: ASAN_OPTIONS: detect_leaks=0 @@ -60,6 +67,7 @@ jobs: env: ASAN_OPTIONS: detect_leaks=0 run: | + apt install -y xxd source ${GITHUB_WORKSPACE}/venv/bin/activate cd jax pip install uv~=0.5.30 @@ -72,8 +80,7 @@ jobs: cd jax python build/build.py build --wheels=jaxlib --verbose \ --bazel_options=--color=yes \ - --bazel_options=--copt=-fsanitize=address \ - --clang_path=/usr/bin/clang-18 + --bazel_options=--copt=-fsanitize=address uv pip install dist/jaxlib-*.whl \ -e . - name: Run tests diff --git a/.github/workflows/bazel_cpu.yml b/.github/workflows/bazel_cpu.yml new file mode 100644 index 000000000000..3d1d2381c38c --- /dev/null +++ b/.github/workflows/bazel_cpu.yml @@ -0,0 +1,96 @@ +# CI - Bazel CPU tests (RBE) +# +# This workflow runs the Bazel CPU tests with wheel dependencies. It can only be triggered by +# other workflows via `workflow_call`. It is used by the `CI - Wheel Tests (Continuous)` and +# `CI - Wheel Tests (Nightly/Release)` workflows to run the Bazel CPU tests. +# +# It consists of the following job: +# run-tests: +# - Downloads the jax, jaxlib from a GCS bucket if build_jaxlib is false. Otherwise, +# the artifacts are built from source. +# - Executes the `run_bazel_test_cpu_rbe.sh` script, which performs the following actions: +# - `build_jaxlib=wheel`: Runs the Bazel CPU tests with py_import dependencies. +# - `build_jaxlib=false`: Runs the Bazel CPU tests with downloaded wheel dependencies. +# - `build_jaxlib=true`: Runs the Bazel CPU tests with individual Bazel target dependencies. + +name: CI - Bazel CPU tests with wheel dependencies (RBE) +permissions: {} +on: + workflow_call: + inputs: + runner: + description: "Which runner should the workflow run on?" + type: string + default: "linux-x86-n4-16" + python: + description: "Which python version to test?" + type: string + default: "3.12" + enable-x64: + description: "Should x64 mode be enabled?" + type: string + default: "0" + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: string + default: 'no' + build_jaxlib: + description: 'Should jaxlib be built from source?' + required: true + type: string + build_jax: + description: 'Should jax be built from source?' + required: true + type: string + bazel_command: + description: "Whether to build or run test targets?" + required: false + type: string + default: "test" + gcs_download_uri: + description: "GCS location prefix from where the artifacts should be downloaded" + default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + type: string + +jobs: + run-tests: + defaults: + run: + # Explicitly set the shell to bash + shell: bash + runs-on: ${{ inputs.runner }} + container: ${{ (contains(inputs.runner, 'linux-x86') && 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest') || + (contains(inputs.runner, 'linux-arm64') && 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build-arm64:latest') || + (contains(inputs.runner, 'windows-x86') && null) }} + env: + JAXCI_HERMETIC_PYTHON_VERSION: ${{ inputs.python }} + JAXCI_ENABLE_X64: ${{ inputs.enable-x64 }} + JAXCI_BUILD_JAXLIB: ${{ inputs.build_jaxlib }} + JAXCI_BUILD_JAX: ${{ inputs.build_jax }} + JAXCI_BAZEL_CPU_RBE_MODE: "${{ inputs.bazel_command }}" + +# Begin Presubmit Naming Check - name modification requires internal check to be updated + name: "${{ (contains(inputs.runner, 'linux-x86') && 'linux x86') || + (contains(inputs.runner, 'linux-arm64') && 'linux arm64') || + (contains(inputs.runner, 'windows-x86') && 'windows x86') }}, Python=${{ inputs.python }}, x64=${{ inputs.enable-x64 }}, build_jax=${{ inputs.build_jax }}, build_jaxlib=${{ inputs.build_jaxlib }}" +# End Presubmit Naming Check github-cpu-presubmits + + steps: + - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 + with: + persist-credentials: false + - name: Download JAX CPU wheels + if: inputs.build_jaxlib == 'false' + uses: ./.github/actions/download-jax-cpu-wheels + with: + runner: ${{ inputs.runner }} + python: ${{ inputs.python }} + gcs_download_uri: ${{ inputs.gcs_download_uri }} + # Halt for testing + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: "Bazel CPU tests with build_jaxlib=${{ format('{0}', inputs.build_jaxlib) }}" + timeout-minutes: 60 + run: ./ci/run_bazel_test_cpu_rbe.sh diff --git a/.github/workflows/bazel_cpu_presubmit.yml b/.github/workflows/bazel_cpu_presubmit.yml new file mode 100644 index 000000000000..c931b4c1445a --- /dev/null +++ b/.github/workflows/bazel_cpu_presubmit.yml @@ -0,0 +1,88 @@ +name: CI - Bazel CPU tests + +on: + workflow_dispatch: + inputs: + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: choice + required: true + default: 'no' + options: + - 'yes' + - 'no' + pull_request: + branches: + - main + push: + branches: + - main + - 'release/**' +permissions: {} +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + # Don't cancel in-progress jobs for main/release branches. + cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }} + +jobs: + build-jax-artifact: + uses: ./.github/workflows/build_artifacts.yml + name: "Build jax artifact" + with: + # Note that since jax is a pure python package, the runner OS and Python values do not + # matter. In addition, cloning main XLA also has no effect. + runner: "linux-x86-n4-16" + artifact: "jax" + upload_artifacts_to_gcs: false + + build-jaxlib-artifact: + uses: ./.github/workflows/build_artifacts.yml + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + # Runner OS and Python values need to match the matrix stategy in the CPU tests job + runner: ["linux-x86-n4-16", "linux-arm64-c4a-16", "windows-x86-n2-16"] + artifact: ["jaxlib"] + python: ["3.14"] + # Note: For reasons unknown, Github actions groups jobs with the same top-level name in the + # dashboard only if we use an expression in the "name" field. Otherwise, it appends the matrix + # values to the name and creates a separate entry for each matrix combination. + name: "Build ${{ format('{0}', 'jaxlib') }} artifacts" + with: + runner: ${{ matrix.runner }} + artifact: ${{ matrix.artifact }} + python: ${{ matrix.python }} + clone_main_xla: 1 + upload_artifacts_to_gcs: false + + run_tests: + if: github.event.repository.fork == false + uses: ./.github/workflows/bazel_cpu.yml +# Begin Presubmit Naming Check - name modification requires internal check to be updated + strategy: + matrix: + python: ["3.11", "3.14"] + runner: ["linux-x86-n4-16", "linux-arm64-c4a-16", "windows-x86-n2-16"] + enable-x64: [1, 0] + exclude: + # Exclude x64=1 on the oldest Python and x64=0 on the newest Python. As long as we have + # coverage for one of each, we don't need to run both. + - python: "3.11" + enable-x64: 1 + - python: "3.14" + enable-x64: 0 + # Only test a single Python version on Arm64 and Windows as we don't run the tests. + - python: "3.11" + runner: "linux-arm64-c4a-16" + - python: "3.11" + runner: "windows-x86-n2-16" + name: "Bazel CPU ${{ (contains(matrix.runner, 'linux-x86') && 'tests' || 'build only') }}" +# End Presubmit Naming Check github-cpu-presubmits + with: + runner: ${{ matrix.runner }} + python: ${{ matrix.python }} + halt-for-connection: ${{ inputs.halt-for-connection }} + enable-x64: ${{ matrix.enable-x64 }} + build_jaxlib: 'true' + build_jax: 'true' + bazel_command: ${{ (contains(matrix.runner, 'linux-x86') && 'test') || 'build' }} \ No newline at end of file diff --git a/.github/workflows/bazel_cpu_rbe.yml b/.github/workflows/bazel_cpu_rbe.yml deleted file mode 100644 index d6816d492d1d..000000000000 --- a/.github/workflows/bazel_cpu_rbe.yml +++ /dev/null @@ -1,58 +0,0 @@ -name: CI - Bazel CPU tests (RBE) - -on: - workflow_dispatch: - inputs: - halt-for-connection: - description: 'Should this workflow run wait for a remote connection?' - type: choice - required: true - default: 'no' - options: - - 'yes' - - 'no' - pull_request: - branches: - - main - push: - branches: - - main - - 'release/**' - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} - # Don't cancel in-progress jobs for main/release branches. - cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }} - -jobs: - run_tests: - if: github.event.repository.fork == false - runs-on: ${{ matrix.runner }} - container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || - (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') }} - env: - JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }} - JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }} -# Begin Presubmit Naming Check - name modification requires internal check to be updated - strategy: - matrix: - python: ["3.10", "3.13"] - runner: ["linux-x86-n2-16", "linux-arm64-c4a-16"] - enable-x_64: [1, 0] - exclude: - # Exclude x64=1 on the oldest Python and x64=0 on the newest Python. As long as we have - # coverage for one of each, we don't need to run both. - - python: "3.10" - enable-x_64: 1 - - python: "3.13" - enable-x_64: 0 - name: "Bazel CPU tests (${{ matrix.runner }}, Python ${{ matrix.python }}, x64=${{ matrix.enable-x_64 }})" -# End Presubmit Naming Check github-cpu-presubmits - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main - with: - halt-dispatch-input: ${{ inputs.halt-for-connection }} - - name: Run Bazel CPU Tests with RBE - run: ./ci/run_bazel_test_cpu_rbe.sh \ No newline at end of file diff --git a/.github/workflows/bazel_cuda.yml b/.github/workflows/bazel_cuda.yml new file mode 100644 index 000000000000..6206d91d9f23 --- /dev/null +++ b/.github/workflows/bazel_cuda.yml @@ -0,0 +1,125 @@ +# CI - Bazel CUDA tests +# +# This workflow runs the CUDA tests with Bazel. It can only be triggered by other workflows via +# `workflow_call`. It is used by the `CI - Bazel CUDA tests (RBE)`,`CI - Wheel Tests (Continuous)` +# and `CI - Wheel Tests (Nightly/Release)` workflows to run the Bazel CUDA tests. +# +# It consists of the following job: +# run-tests: +# - Downloads the jaxlib and CUDA artifacts from a GCS bucket if build_jaxlib is `false`. +# Otherwise, the artifacts are built from source. +# - Downloads the jax artifact from a GCS bucket if build_jax is `false`. +# Otherwise, the artifact is built from source. +# - If `run_multiaccelerator_tests` is `false`, executes the `run_bazel_test_cuda_rbe.sh` script, +# which performs the following actions: +# - `build_jaxlib=wheel`: Runs the Bazel CPU tests with py_import dependencies. +# - `build_jaxlib=false`: Runs the Bazel CPU tests with downloaded wheel dependencies. +# - `build_jaxlib=true`: Runs the Bazel CPU tests with individual Bazel target dependencies. +# - If `run_multiaccelerator_tests` is `true`, executes the `run_bazel_test_cuda_non_rbe.sh` +# script, which performs the following actions: +# - `build_jaxlib=wheel`: Runs the Bazel CPU tests with py_import dependencies. +# - `build_jaxlib=false`: Runs the Bazel CPU tests with downloaded wheel dependencies. + +name: CI - Bazel CUDA tests + +on: + workflow_call: + inputs: + runner: + description: "Which runner should the workflow run on?" + type: string + default: "linux-x86-n4-16" + python: + description: "Which python version to test?" + type: string + default: "3.12" + cuda-version: + description: "Which CUDA version to test?" + type: string + default: "12" + enable-x64: + description: "Should x64 mode be enabled?" + type: string + default: "0" + jaxlib-version: + description: "Which jaxlib version to test? (head/pypi_latest)" + type: string + default: "head" + download-jax-from-gcs: + description: "Whether to download the jax wheel from GCS" + default: '1' + type: string + skip-download-jaxlib-and-cuda-plugins-from-gcs: + description: "Whether to skip downloading the jaxlib and cuda plugins from GCS (e.g for testing a jax only release)" + default: '0' + type: string + gcs_download_uri: + description: "GCS location URI from where the artifacts should be downloaded" + default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + type: string + build_jaxlib: + description: 'Should jaxlib be built from source?' + required: true + type: string + build_jax: + description: 'Should jax be built from source?' + required: true + type: string + write_to_bazel_remote_cache: + description: 'Whether to enable writing to the Bazel remote cache bucket' + required: false + default: '0' + type: string + run_multiaccelerator_tests: + description: 'Whether to run multi-accelerator tests' + required: false + default: 'false' + type: string + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: string + default: 'no' +permissions: {} +jobs: + run-tests: + defaults: + run: + # Explicitly set the shell to bash + shell: bash + runs-on: ${{ inputs.runner }} + container: "us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest" + + env: + JAXCI_HERMETIC_PYTHON_VERSION: ${{ inputs.python }} + JAXCI_ENABLE_X64: ${{ inputs.enable-x64 }} + JAXCI_CUDA_VERSION: ${{ inputs.cuda-version }} + JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE: ${{ inputs.write_to_bazel_remote_cache }} + JAXCI_BUILD_JAX: ${{ inputs.build_jax }} + JAXCI_BUILD_JAXLIB: ${{ inputs.build_jaxlib }} +# Begin Presubmit Naming Check - name modification requires internal check to be updated + name: "${{ (contains(inputs.runner, 'linux-x86') && 'linux x86') || + (contains(inputs.runner, 'linux-arm64') && 'linux arm64') || + (contains(inputs.runner, 'windows-x86') && 'windows x86') }}, jaxlib=${{ inputs.jaxlib-version }}, CUDA=${{ inputs.cuda-version }}, Python=${{ inputs.python }}, x64=${{ inputs.enable-x64 }}, build_jax=${{ inputs.build_jax }}, build_jaxlib=${{ inputs.build_jaxlib }}" +# End Presubmit Naming Check github-cuda-presubmits + steps: + - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 + with: + persist-credentials: false + - name: Download JAX CUDA wheels + if: inputs.build_jaxlib == 'false' + uses: ./.github/actions/download-jax-cuda-wheels + with: + python: ${{ inputs.python }} + cuda-version: ${{ inputs.cuda-version }} + download-jax-from-gcs: ${{ inputs.download-jax-from-gcs }} + skip-download-jaxlib-and-cuda-plugins-from-gcs: ${{ inputs.skip-download-jaxlib-and-cuda-plugins-from-gcs }} + jaxlib-version: ${{ inputs.jaxlib-version }} + gcs_download_uri: ${{ inputs.gcs_download_uri }} + # Halt for testing + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: "Bazel CUDA tests with build_jax=${{ inputs.build_jax }}, build_jaxlib=${{ inputs.build_jaxlib }}" + timeout-minutes: 60 + run: ${{ ((inputs.run_multiaccelerator_tests == 'false') && './ci/run_bazel_test_cuda_rbe.sh') || './ci/run_bazel_test_cuda_non_rbe.sh' }} diff --git a/.github/workflows/bazel_cuda_h100_b200.yml b/.github/workflows/bazel_cuda_h100_b200.yml new file mode 100644 index 000000000000..60cbb01ed203 --- /dev/null +++ b/.github/workflows/bazel_cuda_h100_b200.yml @@ -0,0 +1,149 @@ +name: CI - Bazel H100 and B200 CUDA tests +# This runs if any of the following conditions are met +# H100 and B200 on Workflow dispatch +# H100 and B200 on scheduled every two hours +# B200 on PR to main that modifies mosaic files or this file, see below for list +# H100 and B200 on PR to main that has the 'CI Optional GPU Presubmit' label +on: + # Runs on PR if label "CI Optional GPU Presubmit" is present. + workflow_dispatch: + inputs: + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: choice + required: true + default: 'no' + options: + - 'yes' + - 'no' + pull_request: + branches: + - main + types: [ labeled, synchronize, opened, reopened ] + schedule: + - cron: "0 */2 * * *" # Run once every 2 hours +permissions: {} +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + # Don't cancel in-progress jobs for main/release branches. + cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }} +jobs: + changed_files: + permissions: {} # No permissions given + runs-on: ubuntu-latest # Do not run tj-actions on self-hosted runners + steps: + - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 + with: + persist-credentials: false + - name: Get and filter changed files # We only run this if it is a pull request, do not run tj-actions on non PR event + if: ${{ github.event_name == 'pull_request' }} + id: changed-files + uses: tj-actions/changed-files@ed68ef82c095e0d48ec87eccea555d944a631a4c # v46 + with: + files: | + jax/_src/pallas/mosaic_gpu/** + jax/experimental/mosaic/gpu/** + jaxlib/mosaic/dialect/gpu/** + jaxlib/mosaic/gpu/** + .github/workflows/bazel_cuda_h100_b200.yml + - name: List all changed files + env: + ALL_CHANGED_FILES: ${{ steps.changed-files.outputs.all_changed_files }} + run: | + for file in ${ALL_CHANGED_FILES}; do + echo "$file was changed" + done + outputs: + any_changed: ${{ steps.changed-files.outputs.any_changed || 'false' }} + run_tests: + needs: changed_files + if: ${{ github.event.repository.fork == false && (github.event_name == 'schedule' || needs.changed_files.outputs.any_changed == 'true' || github.event_name == 'workflow_dispatch' || contains(github.event.pull_request.labels.*.name, 'CI Optional GPU Presubmit')) }} + runs-on: linux-x86-a4-224-b200-1gpu + container: 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest' + name: "Bazel single B200 CUDA tests" +# End Presubmit Naming Check github-cuda-presubmits + steps: + - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 + with: + persist-credentials: false + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Run Bazel single B200 CUDA Tests + run: | + nvidia-smi + bazel test \ + --config=ci_linux_x86_64_cuda \ + --config=ci_rbe_cache \ + --config=hermetic_cuda_umd \ + --repo_env=HERMETIC_PYTHON_VERSION="3.14" \ + --repo_env=HERMETIC_CUDNN_VERSION="9.11.0" \ + --repo_env=HERMETIC_CUDA_UMD_VERSION="13.0.0" \ + --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \ + --run_under "$(pwd)/build/parallel_accelerator_execute.sh" \ + --test_output=errors \ + --test_tag_filters=-multiaccelerator \ + --test_env=JAX_ACCELERATOR_COUNT=1 \ + --test_env=JAX_TESTS_PER_ACCELERATOR=8 \ + --strategy=TestRunner=local \ + --local_test_jobs=8 \ + --test_env=JAX_EXCLUDE_TEST_TARGETS='PmapTest.testSizeOverflow|.*InterpretTest.*' \ + --test_env=TF_CPP_MIN_LOG_LEVEL=0 \ + --test_env=JAX_SKIP_SLOW_TESTS=true \ + --action_env=JAX_ENABLE_X64="1" \ + --action_env=NCCL_DEBUG=WARN \ + --flaky_test_attempts=1 \ + --test_timeout=420 \ + --color=yes \ + //tests:cudnn_fusion_test_gpu \ + //tests:scaled_matmul_stablehlo_test_gpu \ + //tests:fused_attention_stablehlo_test_gpu \ + //tests:nn_test_gpu \ + //tests/pallas:gpu_tests \ + //tests/mosaic:gpu_tests + run_multiaccelerator_tests: + if: ${{ github.event.repository.fork == false && (github.event_name == 'schedule' || github.event_name == 'workflow_dispatch' || contains(github.event.pull_request.labels.*.name, 'CI Optional GPU Presubmit')) }} + runs-on: linux-x86-a3-8g-h100-8gpu + container: 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest' + name: "Bazel multiple H100 CUDA tests" + steps: + - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 + with: + persist-credentials: false + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Run Bazel multiple H100 CUDA Tests + run: | + nvidia-smi + bazel test \ + --config=ci_linux_x86_64_cuda \ + --config=ci_rbe_cache \ + --config=hermetic_cuda_umd \ + --repo_env=HERMETIC_PYTHON_VERSION="3.14" \ + --repo_env=HERMETIC_CUDNN_VERSION="9.11.0" \ + --repo_env=HERMETIC_CUDA_UMD_VERSION="13.0.0" \ + --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \ + --test_output=errors \ + --strategy=TestRunner=local \ + --local_test_jobs=8 \ + --test_env=JAX_EXCLUDE_TEST_TARGETS='PmapTest.testSizeOverflow|.*InterpretTest.*' \ + --test_tag_filters=multiaccelerator \ + --test_env=TF_CPP_MIN_LOG_LEVEL=0 \ + --test_env=JAX_SKIP_SLOW_TESTS=true \ + --action_env=JAX_ENABLE_X64="1" \ + --action_env=NCCL_DEBUG=WARN \ + --flaky_test_attempts=1 \ + --color=yes \ + //tests/mosaic:gpu_tests \ + //tests/pallas:gpu_tests \ + //tests:array_interoperability_test_gpu \ + //tests:cudnn_fusion_test_gpu \ + //tests:fused_attention_stablehlo_test_gpu \ + //tests:gpu_tests \ + //tests:python_callback_test_gpu \ + //tests:ragged_collective_test_gpu \ + //tests/multiprocess:gpu_tests \ + //jax/experimental/jax2tf/tests/multiprocess:gpu_tests \ No newline at end of file diff --git a/.github/workflows/bazel_cuda_non_rbe.yml b/.github/workflows/bazel_cuda_non_rbe.yml deleted file mode 100644 index 0b0e1cb62497..000000000000 --- a/.github/workflows/bazel_cuda_non_rbe.yml +++ /dev/null @@ -1,99 +0,0 @@ -# CI - Bazel CUDA tests (Non-RBE) -# -# This workflow runs the CUDA tests with Bazel. It can only be triggered by other workflows via -# `workflow_call`. It is used by the `CI - Wheel Tests` workflows to run the Bazel CUDA tests. -# -# It consists of the following job: -# run-tests: -# - Downloads the jaxlib and CUDA artifacts from a GCS bucket. -# - Executes the `run_bazel_test_cuda_non_rbe.sh` script, which performs the following actions: -# - Installs the downloaded wheel artifacts. -# - Runs the CUDA tests with Bazel. -name: CI - Bazel CUDA tests (Non-RBE) - -on: - workflow_call: - inputs: - runner: - description: "Which runner should the workflow run on?" - type: string - required: true - default: "linux-x86-n2-16" - python: - description: "Which python version to test?" - type: string - required: true - default: "3.12" - enable-x64: - description: "Should x64 mode be enabled?" - type: string - required: true - default: "0" - gcs_download_uri: - description: "GCS location URI from where the artifacts should be downloaded" - required: true - default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' - type: string - halt-for-connection: - description: 'Should this workflow run wait for a remote connection?' - type: boolean - required: false - default: false - -jobs: - run-tests: - defaults: - run: - # Explicitly set the shell to bash - shell: bash - runs-on: ${{ inputs.runner }} - container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.3-cudnn9.1:latest" - - env: - JAXCI_HERMETIC_PYTHON_VERSION: ${{ inputs.python }} - JAXCI_ENABLE_X64: ${{ inputs.enable-x64 }} - # Enable writing to the Bazel remote cache bucket. - JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE: "1" - - name: "Bazel single accelerator and multi-accelerator CUDA tests (${{ inputs.runner }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})" - - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Set env vars for use in artifact download URL - run: | - os=$(uname -s | awk '{print tolower($0)}') - arch=$(uname -m) - - # Get the major and minor version of Python. - # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310 - python_major_minor=$(echo "$JAXCI_HERMETIC_PYTHON_VERSION" | tr -d '.') - - echo "OS=${os}" >> $GITHUB_ENV - echo "ARCH=${arch}" >> $GITHUB_ENV - echo "PYTHON_MAJOR_MINOR=${python_major_minor}" >> $GITHUB_ENV - - name: Download the wheel artifacts from GCS - id: download-wheel-artifacts - # Set continue-on-error to true to prevent actions from failing the workflow if this step - # fails. Instead, we verify the outcome in the next step so that we can print a more - # informative error message. - continue-on-error: true - run: >- - mkdir -p $(pwd)/dist && - gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && - gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && - gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/ - - name: Skip the test run if the wheel artifacts were not downloaded successfully - if: steps.download-wheel-artifacts.outcome == 'failure' - run: | - echo "Failed to download wheel artifacts from GCS. Please check if the wheels were" - echo "built successfully by the artifact build jobs and are available in the GCS bucket." - echo "Skipping the test run." - exit 1 - # Halt for testing - - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main - with: - halt-dispatch-input: ${{ inputs.halt-for-connection }} - - name: Run Bazel CUDA tests (Non-RBE) - timeout-minutes: 60 - run: ./ci/run_bazel_test_cuda_non_rbe.sh diff --git a/.github/workflows/bazel_cuda_presubmit.yml b/.github/workflows/bazel_cuda_presubmit.yml new file mode 100644 index 000000000000..48eb4495f2a9 --- /dev/null +++ b/.github/workflows/bazel_cuda_presubmit.yml @@ -0,0 +1,92 @@ +name: CI - Bazel CUDA tests + +on: + workflow_dispatch: + inputs: + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: choice + required: true + default: 'no' + options: + - 'yes' + - 'no' + pull_request: + branches: + - main + push: + branches: + - main + - 'release/**' + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + # Don't cancel in-progress jobs for main/release branches. + cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }} +permissions: {} +jobs: + build-cuda-artifacts: + uses: ./.github/workflows/build_artifacts.yml + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + # Python values need to match the matrix stategy in the CUDA tests job below + runner: ["linux-x86-n4-16"] + artifact: ["jax-cuda-plugin", "jax-cuda-pjrt"] + python: ["3.14",] + cuda-version: ["13"] + name: "Build ${{ format('{0}', 'CUDA') }} artifacts" + with: + runner: ${{ matrix.runner }} + artifact: ${{ matrix.artifact }} + python: ${{ matrix.python }} + cuda-version: ${{ matrix.cuda-version }} + clone_main_xla: 1 + upload_artifacts_to_gcs: false + + run-tests: + if: github.event.repository.fork == false + uses: ./.github/workflows/bazel_cuda.yml +# Begin Presubmit Naming Check - name modification requires internal check to be updated + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + python: ["3.11", "3.14"] + runner: ["linux-x86-n4-16"] + jaxlib-version: ["head", "pypi_latest"] + enable-x64: [1, 0] + cuda-version: ["12", "13"] + exclude: + # Exclude x64=1 on the oldest Python and x64=0 on the newest Python. As long as we have + # coverage for one of each, we don't need to run both. + - python: "3.11" + enable-x64: 1 + - python: "3.11" + enable-x64: 0 + jaxlib-version: "head" + - python: "3.11" + enable-x64: 0 + jaxlib-version: "pypi_latest" + cuda-version: "13" + - python: "3.14" + enable-x64: 0 + - python: "3.14" + enable-x64: 1 + jaxlib-version: "pypi_latest" + # Exclude CUDA 12 on jaxlib head because it's too slow. + - cuda-version: "12" + jaxlib-version: "head" + name: "Bazel single accelerator ${{ format('{0}', 'CUDA tests') }}" +# End Presubmit Naming Check github-cuda-presubmits + with: + runner: ${{ matrix.runner }} + python: ${{ matrix.python }} + download-jax-from-gcs: 0 + skip-download-jaxlib-and-cuda-plugins-from-gcs: 0 + jaxlib-version: ${{ matrix.jaxlib-version }} + halt-for-connection: ${{ inputs.halt-for-connection }} + cuda-version: ${{ matrix.cuda-version }} + enable-x64: ${{ matrix.enable-x64 }} + build_jaxlib: ${{ (matrix.jaxlib-version == 'head' && true) || false }} + build_jax: 'true' + run_multiaccelerator_tests: 'false' diff --git a/.github/workflows/bazel_cuda_rbe.yml b/.github/workflows/bazel_cuda_rbe.yml deleted file mode 100644 index 5a2c94c4db47..000000000000 --- a/.github/workflows/bazel_cuda_rbe.yml +++ /dev/null @@ -1,57 +0,0 @@ -name: CI - Bazel CUDA tests (RBE) - -on: - workflow_dispatch: - inputs: - halt-for-connection: - description: 'Should this workflow run wait for a remote connection?' - type: choice - required: true - default: 'no' - options: - - 'yes' - - 'no' - pull_request: - branches: - - main - push: - branches: - - main - - 'release/**' - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} - # Don't cancel in-progress jobs for main/release branches. - cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }} - -jobs: - run_tests: - if: github.event.repository.fork == false - runs-on: ${{ matrix.runner }} - container: 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest' - env: - JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }} - JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }} -# Begin Presubmit Naming Check - name modification requires internal check to be updated - strategy: - matrix: - python: ["3.10", "3.13"] - runner: ["linux-x86-n2-16"] - enable-x_64: [1, 0] - exclude: - # Exclude x64=1 on the oldest Python and x64=0 on the newest Python. As long as we have - # coverage for one of each, we don't need to run both. - - python: "3.10" - enable-x_64: 1 - - python: "3.13" - enable-x_64: 0 - name: "Bazel single accelerator CUDA tests (${{ matrix.runner }}, Python ${{ matrix.python }}, x64=${{ matrix.enable-x_64 }})" -# End Presubmit Naming Check github-cuda-presubmits - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main - with: - halt-dispatch-input: ${{ inputs.halt-for-connection }} - - name: Run Bazel CUDA Tests with RBE - run: ./ci/run_bazel_test_cuda_rbe.sh \ No newline at end of file diff --git a/.github/workflows/bazel_test_tpu.yml b/.github/workflows/bazel_test_tpu.yml new file mode 100644 index 000000000000..6459c45475e0 --- /dev/null +++ b/.github/workflows/bazel_test_tpu.yml @@ -0,0 +1,147 @@ +# CI - Bazel test TPU +# +# This workflow runs the TPU tests with Bazel. It can only be triggered by other workflows via +# `workflow_call`. It is used by the "CI - Wheel Tests" workflows to run the Bazel TPU tests. +# +# It consists of the following job: +# run-tests: +# - Downloads the jaxlib wheel from a GCS bucket if build_jaxlib is false. +# - Downloads the libtpu wheels. +# - Executes the `run_bazel_test_tpu.sh` script, which performs the following actions: +# - Runs the TPU tests with Bazel. +name: CI - Bazel test TPU +on: + workflow_call: + inputs: + # Note that the values for runners, cores, and tpu-type are linked to each other. + # For example, the v5e-8 TPU type requires 8 cores. For ease of reference, we use the + # following mapping: + # {tpu-type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}, + # {tpu-type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"} + runner: + description: "Which runner should the workflow run on?" + type: string + default: "linux-x86-ct5lp-224-8tpu" + cores: + description: "How many TPU cores should the test use?" + type: string + default: "8" + tpu-type: + description: "Which TPU type is used for testing?" + type: string + default: "v5e-8" + python: + description: "Which Python version should be used for testing?" + type: string + default: "3.12" + run-full-tpu-test-suite: + description: "Should the full TPU test suite be run?" + type: string + default: "0" + libtpu-version-type: + description: "Which libtpu version should be used for testing?" + type: string + # Choices are: + # - "nightly": Use the nightly libtpu wheel. + # - "pypi_latest": Use the latest libtpu wheel from PyPI. + # - "oldest_supported_libtpu": Use the oldest supported libtpu wheel. + default: "nightly" + jaxlib-version: + description: "Which jaxlib version to test? (head/pypi_latest)" + required: false + type: string + default: "head" + skip-download-jaxlib-from-gcs: + description: "Whether to skip downloading the jaxlib artifact from GCS (e.g for testing a jax only release)" + default: '0' + type: string + gcs_download_uri: + description: "GCS location prefix from where the artifacts should be downloaded" + required: false + default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + type: string + clone_main_xla: + description: "Should latest XLA be used?" + type: string + default: "0" + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: string + default: 'no' + build_jaxlib: + description: 'Should jaxlib be built from source?' + required: false + default: 'false' + type: string + build_jax: + description: 'Should jax be built from source?' + required: false + default: 'false' + type: string +permissions: {} + +env: + PIP_INDEX_URL: "https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple" + +jobs: + run-bazel-tests: + defaults: + run: + shell: bash + runs-on: ${{ inputs.runner }} + container: "us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest" + # Begin Presubmit Naming Check - name modification requires internal check to be updated + name: "${{ inputs.tpu-type }}, py ${{ inputs.python }}, libtpu=${{ inputs.libtpu-version-type }}" + # End Presubmit Naming Check github-tpu-presubmits + + env: + LIBTPU_OLDEST_VERSION_DATE: 20250228 + JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" + JAXCI_PYTHON: "python${{ inputs.python }}" + JAXCI_RUN_FULL_TPU_TEST_SUITE: "${{ inputs.run-full-tpu-test-suite }}" + JAXCI_TPU_CORES: "${{ inputs.cores }}" + JAXCI_BUILD_JAXLIB: ${{ inputs.build_jaxlib }} + JAXCI_BUILD_JAX: ${{ inputs.build_jax }} + JAXCI_CLONE_MAIN_XLA: "${{ inputs.clone_main_xla }}" + + steps: + - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 + with: + persist-credentials: false + - name: Download JAX wheels + if: inputs.build_jaxlib == 'false' + uses: ./.github/actions/download-jax-cpu-wheels + with: + runner: ${{ inputs.runner }} + python: ${{ inputs.python }} + skip-download-jaxlib-from-gcs: ${{ inputs.skip-download-jaxlib-from-gcs }} + jaxlib-version: ${{ inputs.jaxlib-version }} + gcs_download_uri: ${{ inputs.gcs_download_uri }} + - name: Download libtpu wheels + run: | + mkdir -p $(pwd)/dist + $JAXCI_PYTHON -m pip install --upgrade pip + echo "Download the wheel into a local directory" + if [[ "${{ inputs.libtpu-version-type }}" == "nightly" ]]; then + $JAXCI_PYTHON -m pip download -d $(pwd)/dist --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + elif [[ "${{ inputs.libtpu-version-type }}" == "pypi_latest" ]]; then + echo "Using latest libtpu from PyPI" + $JAXCI_PYTHON -m pip download -d $(pwd)/dist libtpu + elif [[ "${{ inputs.libtpu-version-type }}" == "oldest_supported_libtpu" ]]; then + echo "Using oldest supported libtpu" + $JAXCI_PYTHON -m pip download -d $(pwd)/dist --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \ + -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + + echo "libtpu_version_type=oldest_supported_libtpu" >> $GITHUB_ENV + else + echo "Unknown libtpu version type: ${{ inputs.libtpu-version-type }}" + exit 1 + fi + # Halt for testing + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Run Bazel TPU tests with build_jaxlib=${{ format('{0}', inputs.build_jaxlib) }} + timeout-minutes: ${{ github.event_name == 'pull_request' && 20 || 210 }} + run: ./ci/run_bazel_test_tpu.sh diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index c2e7acb91f7a..dbd86a934d73 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -12,16 +12,15 @@ on: runner: description: "Which runner should the workflow run on?" type: choice - required: true - default: "linux-x86-n2-16" + default: "linux-x86-n4-16" options: - - "linux-x86-n2-16" - - "linux-arm64-c4a-64" - - "windows-x86-n2-64" + - "linux-x86-n4-16" + - "linux-arm64-t2a-48" + - "linux-arm64-c4a-16" + - "windows-x86-n2-16" artifact: description: "Which JAX artifact to build?" type: choice - required: true default: "jaxlib" options: - "jax" @@ -31,17 +30,24 @@ on: python: description: "Which python version should the artifact be built for?" type: choice - required: false default: "3.12" options: - - "3.10" - "3.11" - "3.12" - "3.13" + - "3.13-ft" + - "3.14" + - "3.14-ft" + cuda-version: + description: "Which cuda version should the artifact be built for?" + type: choice + default: "12" + options: + - "12" + - "13" clone_main_xla: description: "Should latest XLA be used?" type: choice - required: false default: "0" options: - "1" @@ -49,7 +55,6 @@ on: halt-for-connection: description: 'Should this workflow run wait for a remote connection?' type: choice - required: false default: 'no' options: - 'yes' @@ -59,41 +64,40 @@ on: runner: description: "Which runner should the workflow run on?" type: string - required: true - default: "linux-x86-n2-16" + default: "linux-x86-n4-16" artifact: description: "Which JAX artifact to build?" type: string - required: true default: "jaxlib" python: description: "Which python version should the artifact be built for?" type: string - required: false default: "3.12" + cuda-version: + description: "Which cuda version should the artifact be built for?" + type: string + default: "12" clone_main_xla: description: "Should latest XLA be used?" type: string - required: false default: "0" upload_artifacts_to_gcs: description: "Should the artifacts be uploaded to a GCS bucket?" - required: true default: true type: boolean gcs_upload_uri: description: "GCS location prefix to where the artifacts should be uploaded" - required: true default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' type: string + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: string + default: 'no' outputs: gcs_upload_uri: description: "GCS location prefix to where the artifacts were uploaded" value: ${{ jobs.build-artifacts.outputs.gcs_upload_uri }} - -permissions: - contents: read - +permissions: {} jobs: build-artifacts: defaults: @@ -103,31 +107,51 @@ jobs: runs-on: ${{ inputs.runner }} - container: ${{ (contains(inputs.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || - (contains(inputs.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || + container: ${{ (contains(inputs.runner, 'linux-x86') && 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest') || + (contains(inputs.runner, 'linux-arm64') && 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build-arm64:latest') || (contains(inputs.runner, 'windows-x86') && null) }} env: JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" + JAXCI_CUDA_VERSION: "${{ inputs.cuda-version }}" JAXCI_CLONE_MAIN_XLA: "${{ inputs.clone_main_xla }}" - name: Build ${{ inputs.artifact }} (${{ inputs.runner }}, Python ${{ inputs.python }}, clone main XLA=${{ inputs.clone_main_xla }}) + name: "${{ inputs.artifact }}, + ${{ (contains(inputs.runner, 'linux-x86') && 'linux x86') || + (contains(inputs.runner, 'linux-arm64') && 'linux arm64') || + (contains(inputs.runner, 'windows-x86') && 'windows x86') }}, py ${{ inputs.python }}${{ (contains(inputs.artifact, 'cuda') && format(', cuda {0}', inputs.cuda-version)) || '' }}, clone main XLA=${{ inputs.clone_main_xla }}" # Map the job outputs to step outputs outputs: gcs_upload_uri: ${{ steps.store-gcs-upload-uri.outputs.gcs_upload_uri }} steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Enable RBE if building on Linux x86 - if: contains(inputs.runner, 'linux-x86') - run: echo "JAXCI_BUILD_ARTIFACT_WITH_RBE=1" >> $GITHUB_ENV - - name: Enable Bazel remote cache (with writes enabled) if building on Linux Aarch64 or Windows x86 - if: contains(inputs.runner, 'linux-arm64') || contains(inputs.runner, 'windows-x86') - run: echo "JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE=1" >> $GITHUB_ENV + - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 + with: + persist-credentials: false + - name: Configure Build Environment + shell: bash + run: | + RUNNER_TYPE="${{ inputs.runner }}" + ARTIFACT_TYPE="${{ inputs.artifact }}" + + # Enable RBE if building on Linux or Windows x86 + if [[ "${RUNNER_TYPE}" == *linux-x86* || "${RUNNER_TYPE}" == *windows-x86* || ("${RUNNER_TYPE}" == *linux-arm64* && "${ARTIFACT_TYPE}" == "jaxlib") ]]; then + echo "JAXCI_BUILD_ARTIFACT_WITH_RBE=1" >> $GITHUB_ENV + fi + + # Set bazel output base on Windows runner machine + if [[ "${RUNNER_TYPE}" == *windows-x86* ]]; then + echo "JAXCI_BAZEL_OUTPUT_BASE=C:\actions-runner\_work\bazel_output_base" >> $GITHUB_ENV + fi + + # Enable Bazel remote cache (with writes enabled) if building on Linux Aarch64 + if [[ "${RUNNER_TYPE}" == *linux-arm64* ]]; then + echo "JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE=1" >> $GITHUB_ENV + fi # Halt for testing - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main + uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c with: halt-dispatch-input: ${{ inputs.halt-for-connection }} - name: Build ${{ inputs.artifact }} @@ -136,13 +160,13 @@ jobs: - name: Upload artifacts to a GCS bucket (non-Windows runs) if: >- ${{ inputs.upload_artifacts_to_gcs && !contains(inputs.runner, 'windows-x86') }} - run: gsutil -m cp -r "$(pwd)/dist/*.whl" "${{ inputs.gcs_upload_uri }}"/ + run: gcloud storage cp -r "$(pwd)/dist/*.whl" "${{ inputs.gcs_upload_uri }}"/ # Set shell to cmd to avoid path errors when using gcloud commands on Windows - name: Upload artifacts to a GCS bucket (Windows runs) if: >- ${{ inputs.upload_artifacts_to_gcs && contains(inputs.runner, 'windows-x86') }} shell: cmd - run: gsutil -m cp -r "dist/*.whl" "${{ inputs.gcs_upload_uri }}"/ + run: gcloud storage cp -r "dist/*.whl" "${{ inputs.gcs_upload_uri }}"/ - name: Store the GCS upload URI as an output id: store-gcs-upload-uri if: ${{ inputs.upload_artifacts_to_gcs }} diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index f43407af2ed9..684c52da6e65 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -1,11 +1,5 @@ name: CI -# We test all supported Python versions as follows: -# - 3.10 : Documentation build -# - 3.10 : Part of Matrix with NumPy dispatch -# - 3.10 : Part of Matrix -# - 3.11 : Part of Matrix - on: # Trigger the workflow on push or pull request, # but only for the main branch @@ -16,26 +10,30 @@ on: branches: - main -permissions: - contents: read # to fetch code - actions: write # to cancel previous workflows - +permissions: {} concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} - cancel-in-progress: true + # Don't cancel in-progress jobs for main branches. + cancel-in-progress: ${{ github.ref != 'main' }} + +env: + UV_DEFAULT_INDEX: "https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple" + PIP_INDEX_URL: "https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple" jobs: lint_and_typecheck: runs-on: ubuntu-latest timeout-minutes: 5 steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 + with: + persist-credentials: false - name: Set up Python 3.11 - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0 with: python-version: 3.11 - run: python -m pip install pre-commit - - uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 + - uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4.3.0 with: path: ~/.cache/pre-commit key: pre-commit-${{ env.pythonLocation }}-${{ hashFiles('.pre-commit-config.yaml', 'setup.py') }} @@ -44,33 +42,35 @@ jobs: build: # Don't execute in fork due to runner type if: github.repository == 'jax-ml/jax' - name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ubuntu-20.04, x64=${{ matrix.enable-x64}})" - runs-on: linux-x86-n2-32 + name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ubuntu-22.04, x64=${{ matrix.enable-x64}})" + runs-on: linux-x86-n4-32 container: - image: index.docker.io/library/ubuntu@sha256:6d8d9799fe6ab3221965efac00b4c34a2bcc102c086a58dff9e19a08b913c7ef # ratchet:ubuntu:20.04 + image: index.docker.io/library/ubuntu@sha256:4e0171b9275e12d375863f2b3ae9ce00a4c53ddda176bd55868df97ac6f21a6e # ratchet:ubuntu:22.04 timeout-minutes: 60 strategy: matrix: # Test the oldest and newest supported Python versions here. include: - - name-prefix: "with 3.10" - python-version: "3.10" + - name-prefix: "with 3.11" + python-version: "3.11" enable-x64: 1 prng-upgrade: 1 num_generated_cases: 1 - - name-prefix: "with 3.13" - python-version: "3.13" + - name-prefix: "with 3.14" + python-version: "3.14" enable-x64: 0 prng-upgrade: 0 num_generated_cases: 1 steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 + with: + persist-credentials: false - name: Image Setup run: | apt update apt install -y libssl-dev - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -88,7 +88,6 @@ jobs: JAX_SKIP_SLOW_TESTS: true PY_COLORS: 1 run: | - uv pip install --system -e . echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES" echo "JAX_ENABLE_X64=$JAX_ENABLE_X64" echo "JAX_ENABLE_CUSTOM_PRNG=$JAX_ENABLE_CUSTOM_PRNG" @@ -104,11 +103,13 @@ jobs: timeout-minutes: 10 strategy: matrix: - python-version: ['3.10'] + python-version: ['3.12'] steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 + with: + persist-credentials: false - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -123,26 +124,28 @@ jobs: PY_COLORS: 1 run: | pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md - pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/_src/lib/mosaic_gpu.py --ignore=jax/interpreters/mlir.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas --ignore=jax/lib/xla_extension.py + pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/_src/lib/mosaic_gpu.py --ignore=jax/interpreters/mlir.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas --ignore=jax/lib documentation_render: name: Documentation - render documentation - runs-on: linux-x86-n2-16 + runs-on: linux-x86-n4-16 container: - image: index.docker.io/library/ubuntu@sha256:6d8d9799fe6ab3221965efac00b4c34a2bcc102c086a58dff9e19a08b913c7ef # ratchet:ubuntu:20.04 + image: index.docker.io/library/ubuntu@sha256:4e0171b9275e12d375863f2b3ae9ce00a4c53ddda176bd55868df97ac6f21a6e # ratchet:ubuntu:22.04 timeout-minutes: 10 strategy: matrix: - python-version: ['3.10'] + python-version: ['3.11'] steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 + with: + persist-credentials: false - name: Image Setup run: | apt update - apt install -y libssl-dev libsqlite3-dev + apt install -y libssl-dev libsqlite3-dev build-essential - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -151,46 +154,7 @@ jobs: uv pip install --system -r docs/requirements.txt - name: Render documentation run: | - sphinx-build -j auto --color -W --keep-going -b html -D nb_execution_mode=off docs docs/build/html - - jax2tf_test: - name: "jax2tf_test (py ${{ matrix.python-version }} on ${{ matrix.os }}, x64=${{ matrix.enable-x64}})" - runs-on: ${{ matrix.os }} - timeout-minutes: 30 - strategy: - matrix: - # Test the oldest supported Python version here. - include: - - python-version: "3.10" - os: ubuntu-latest - enable-x64: 0 - num_generated_cases: 10 - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - pip install uv~=0.5.30 - uv pip install --system .[minimum-jaxlib] -r build/test-requirements.txt - uv pip install --system --pre tensorflow==2.19.0rc0 - - - name: Run tests - env: - JAX_NUM_GENERATED_CASES: ${{ matrix.num_generated_cases }} - JAX_ENABLE_X64: ${{ matrix.enable-x64 }} - JAX_ENABLE_CHECKS: true - JAX_SKIP_SLOW_TESTS: true - PY_COLORS: 1 - run: | - uv pip install --system -e . - echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES" - echo "JAX_ENABLE_X64=$JAX_ENABLE_X64" - echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS" - echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS" - pytest -n auto --tb=short --maxfail=20 jax/experimental/jax2tf/tests/jax2tf_test.py + sphinx-build -j auto --color -W --keep-going -b html docs docs/build/html ffi: name: FFI example @@ -199,9 +163,11 @@ jobs: image: index.docker.io/tensorflow/build:latest-python3.12@sha256:48e99608fe9434ada5b14e19fdfd8e64f4cfc83aacd328b9c2101b210e984295 # ratchet:index.docker.io/tensorflow/build:latest-python3.12 timeout-minutes: 30 steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 + with: + persist-credentials: false - name: Set up Python - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0 with: python-version: 3.12 - name: Install JAX diff --git a/.github/workflows/cloud-tpu-ci-nightly.yml b/.github/workflows/cloud-tpu-ci-nightly.yml index 099f4ad5c520..73ccb3d7b5b4 100644 --- a/.github/workflows/cloud-tpu-ci-nightly.yml +++ b/.github/workflows/cloud-tpu-ci-nightly.yml @@ -11,32 +11,45 @@ # Github Actions environment). name: CI - Cloud TPU (nightly) +# Disable the schedule; Slated for removal, the new test workflow is in +# "wheel_tests_nightly_release.yml" on: - schedule: - - cron: "0 2,14 * * *" # Run at 7am and 7pm PST +# schedule: +# - cron: "0 2,14 * * *" # Run at 7am and 7pm PST workflow_dispatch: # allows triggering the workflow run manually + # This should also be set to read-only in the project settings, but it's nice to # document and enforce the permissions here. permissions: contents: read + +env: + UV_DEFAULT_INDEX: "https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple" + jobs: cloud-tpu-test: strategy: fail-fast: false # don't cancel all jobs on failure matrix: - jaxlib-version: ["head", "pypi_latest", "nightly", "nightly+oldest_supported_libtpu"] + jaxlib-version: ["head", "pypi_latest", "nightly"] tpu: [ - # {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}, - {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"} + {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}, + {type: "v6e-8", cores: "8", runner: "linux-x86-ct6e-180-8tpu"}, + {type: "v7x-8", cores: "8", runner: "linux-x86-tpu7x-224-4tpu"} ] - python-version: ["3.10"] + python-version: ["3.11"] + # Exclude v6e-8 tests for pypi_latest for resource constraints. + exclude: + - tpu: + type: "v6e-8" + jaxlib-version: "pypi_latest" name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu.type }})" env: - LIBTPU_OLDEST_VERSION_DATE: 20241205 + LIBTPU_OLDEST_VERSION_DATE: 20250228 PYTHON: python${{ matrix.python-version }} runs-on: ${{ matrix.tpu.runner }} - container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" + container: "us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest" timeout-minutes: 180 defaults: run: @@ -45,14 +58,17 @@ jobs: # https://opensource.google/documentation/reference/github/services#actions # mandates using a specific commit for non-Google actions. We use # https://github.com/sethvargo/ratchet to pin specific versions. - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 + with: + persist-credentials: false # Checkout XLA at head, if we're building jaxlib at head. - name: Checkout XLA at head - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 if: ${{ matrix.jaxlib-version == 'head' }} with: repository: openxla/xla path: xla + persist-credentials: false # We need to mark the GitHub workspace as safe as otherwise git commands will fail. - name: Mark GitHub workspace as safe run: | @@ -80,14 +96,14 @@ jobs: elif [ "${{ matrix.jaxlib-version }}" == "nightly" ]; then $PYTHON -m uv pip install \ - --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ + --pre . -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html \ requests elif [ "${{ matrix.jaxlib-version }}" == "nightly+oldest_supported_libtpu" ]; then # TODO(phawkins): switch to libtpu, when the oldest release we support is a libtpu release. $PYTHON -m uv pip install \ - --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ + --pre . -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html \ requests @@ -100,8 +116,8 @@ jobs: $PYTHON -c 'import jax; print("jax version:", jax.__version__)' $PYTHON -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)' strings /usr/local/lib/"$PYTHON"/dist-packages/libtpu/libtpu.so | grep 'Built on' - $PYTHON -c 'import jax; print("libtpu version:", - jax.lib.xla_bridge.get_backend().platform_version)' + $PYTHON -c 'import jax.extend; print("libtpu version:", + jax.extend.backend.get_backend().platform_version)' - name: Run tests env: JAX_PLATFORMS: tpu,cpu @@ -116,11 +132,11 @@ jobs: fi # Run single-accelerator tests in parallel JAX_ENABLE_TPU_XDIST=true $PYTHON -m pytest -n=${{ matrix.tpu.cores }} --tb=short \ - --deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \ + --deselect=tests/pallas/tpu_pallas_call_print_test.py::PallasCallPrintTest \ --maxfail=20 -m "not multiaccelerator" $IGNORE_FLAGS tests examples # Run Pallas printing tests, which need to run with I/O capturing disabled. TPU_STDERR_LOG_LEVEL=0 $PYTHON -m pytest -s \ - tests/pallas/tpu_pallas_test.py::PallasCallPrintTest + tests/pallas/tpu_pallas_call_print_test.py::PallasCallPrintTest # Run multi-accelerator across all chips $PYTHON -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests - name: Send chat on failure diff --git a/.github/workflows/cloud-tpu-ci-presubmit.yml b/.github/workflows/cloud-tpu-ci-presubmit.yml index a92e3cc19313..0c9af09ab9e0 100644 --- a/.github/workflows/cloud-tpu-ci-presubmit.yml +++ b/.github/workflows/cloud-tpu-ci-presubmit.yml @@ -1,7 +1,5 @@ # Cloud TPU CI (presubmit) -# -# This job currently runs as a non-blocking presubmit. It is experimental and is currently being -# tested to get to a stable state before we enable it as a blocking presubmit. + name: CI - Cloud TPU (presubmit) on: @@ -25,41 +23,26 @@ on: # This should also be set to read-only in the project settings, but it's nice to # document and enforce the permissions here. -permissions: - contents: read - +permissions: {} concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} # Don't cancel in-progress jobs for main/release branches. cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }} jobs: - build-jax-artifacts: - if: github.event.repository.fork == false - uses: ./.github/workflows/build_artifacts.yml - strategy: - fail-fast: false # don't cancel all jobs on failure - matrix: - artifact: ["jax", "jaxlib"] - with: - runner: "linux-x86-n2-16" - artifact: ${{ matrix.artifact }} - python: "3.10" - clone_main_xla: 1 - upload_artifacts_to_gcs: true - gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' - - run-pytest-tpu: + run-bazel-test-tpu: if: github.event.repository.fork == false - needs: [build-jax-artifacts] - uses: ./.github/workflows/pytest_tpu.yml + uses: ./.github/workflows/bazel_test_tpu.yml # Begin Presubmit Naming Check - name modification requires internal check to be updated - name: "TPU test (jaxlib=head, v5e-8)" + name: "TPU test (jaxlib=head)" with: - runner: "linux-x86-ct5lp-224-8tpu" + runner: "linux-x86-ct6e-180-8tpu" cores: "8" - tpu-type: "v5e-8" - python: "3.10" + tpu-type: "v6e-8" + python: "3.14-nogil" libtpu-version-type: "nightly" - gcs_download_uri: ${{ needs.build-jax-artifacts.outputs.gcs_upload_uri }} - # End Presubmit Naming Check github-tpu-presubmits \ No newline at end of file + halt-for-connection: false + build_jaxlib: "true" + build_jax: "true" + clone_main_xla: 1 + # End Presubmit Naming Check github-tpu-presubmits diff --git a/.github/workflows/community_release_actions.yml b/.github/workflows/community_release_actions.yml new file mode 100644 index 000000000000..1110cbad9475 --- /dev/null +++ b/.github/workflows/community_release_actions.yml @@ -0,0 +1,34 @@ +name: Release Actions + +on: + release: + types: [published] + +permissions: + contents: read + +jobs: + discord_release: + if: github.repository_owner == 'jax-ml' + runs-on: ubuntu-latest + steps: + - name: Get release URL + id: get-release-url + run: | + URL="https://docs.jax.dev/en/latest/changelog.html" + echo "::set-output name=URL::$URL" + - name: Get content + uses: 2428392/gh-truncate-string-action@b3ff790d21cf42af3ca7579146eedb93c8fb0757 # v1.4.1 + id: get-content + with: + stringToTruncate: | + JAX [${{ github.event.release.tag_name }}](<${{ steps.get-release-url.outputs.URL }}>) was just released! + + ${{ github.event.release.body }} + maxLength: 2000 + truncationSymbol: "..." + - name: Discord Webhook Action + uses: tsickert/discord-webhook@b217a69502f52803de774ded2b1ab7c282e99645 # v7.0.0 + with: + webhook-url: ${{ secrets.DISCORD_WEBHOOK_URL }} + content: ${{ steps.get-content.outputs.string }} diff --git a/.github/workflows/jax-array-api.yml b/.github/workflows/jax-array-api.yml index 2b97c5a05c1c..870483bbce6a 100644 --- a/.github/workflows/jax-array-api.yml +++ b/.github/workflows/jax-array-api.yml @@ -10,39 +10,43 @@ on: concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} - cancel-in-progress: true + # Don't cancel in-progress jobs for main branches. + cancel-in-progress: ${{ github.ref != 'main' }} +permissions: {} + +env: + UV_DEFAULT_INDEX: "https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple" jobs: build: - - runs-on: ubuntu-latest + runs-on: linux-x86-n4-16 + container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest strategy: matrix: python-version: [3.11] - + env: + PYTHON: "python${{ matrix.python-version }}" steps: - name: Checkout jax - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 + with: + persist-credentials: false - name: Checkout array-api-tests - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 with: repository: data-apis/array-api-tests - # TODO(jakevdp) update this to a stable release/tag when available. - ref: '0b89c5268e4e4a352223a487b8f63dbd1023872d' # Latest commit as of 2025-03-04 + ref: '2025.05.23' submodules: 'true' path: 'array-api-tests' - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 - with: - python-version: ${{ matrix.python-version }} + persist-credentials: false - name: Install dependencies + # TODO(jakevdp) remove numpy pin once ml_dtypes 0.6 is released run: | - pip install uv~=0.5.30 - uv pip install --system .[ci] pytest-xdist -r array-api-tests/requirements.txt + $PYTHON -m uv pip install --system .[ci] "numpy<2.4.0" pytest-xdist -r array-api-tests/requirements.txt - name: Run the test suite env: ARRAY_API_TESTS_MODULE: jax.numpy JAX_ENABLE_X64: 'true' run: | cd ${GITHUB_WORKSPACE}/array-api-tests - pytest -n auto array_api_tests --derandomize --disable-deadline --skips-file ${GITHUB_WORKSPACE}/tests/array_api_skips.txt + $PYTHON -m pytest -n auto array_api_tests --derandomize --disable-deadline --skips-file ${GITHUB_WORKSPACE}/tests/array_api_skips.txt diff --git a/.github/workflows/k8s.yaml b/.github/workflows/k8s.yaml new file mode 100644 index 000000000000..1a06e3b27a4d --- /dev/null +++ b/.github/workflows/k8s.yaml @@ -0,0 +1,116 @@ +name: Multi-process run using K8s +on: + push: + branches: + - main + paths: + - '.github/workflows/k8s.yaml' + - 'ci/k8s/**' + - 'jax/distributed.py' + - 'jax/_src/distributed.py' + - 'jax/_src/clusters/**' + pull_request: + branches: + - main + paths: + - '.github/workflows/k8s.yaml' + - 'ci/k8s/**' + - 'jax/distributed.py' + - 'jax/_src/distributed.py' + - 'jax/_src/clusters/**' + +permissions: {} +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true +defaults: + run: + shell: bash -ex -o pipefail {0} +jobs: + distributed-initialize: + runs-on: ubuntu-22.04 + strategy: + fail-fast: false + matrix: + controller: [jobset, indexed-job] + steps: + - name: Checkout + uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v4 + with: + path: jax + persist-credentials: false + + - name: Start Minikube cluster + uses: medyagh/setup-minikube@cea33675329b799adccc9526aa5daccc26cd5052 # ratchet:medyagh/setup-minikube@v0.0.19 + + - name: Install K8s Jobset + if: matrix.controller == 'jobset' + run: | + kubectl apply --server-side -f https://github.com/kubernetes-sigs/jobset/releases/download/v0.8.0/manifests.yaml + kubectl wait --for=condition=established crd/jobsets.jobset.x-k8s.io --timeout=60s + kubectl rollout status -n jobset-system deploy/jobset-controller-manager --timeout=120s + + - name: Build image + run: | + cat > Dockerfile <> $GITHUB_ENV - echo "ARCH=${arch}" >> $GITHUB_ENV - # Python wheels follow a naming convention: standard wheels use the pattern - # `*-cp-cp-*`, while free-threaded wheels use - # `*-cp-cpt-*`. - echo "PYTHON_MAJOR_MINOR=cp${python_major_minor%t}-cp${python_major_minor}-" >> $GITHUB_ENV - - name: Download wheels from GCS (non-Windows runs) - id: download-wheel-artifacts-nw - # Set continue-on-error to true to prevent actions from failing the workflow if this step - # fails. Instead, we verify the outcome in the step below so that we can print a more - # informative error message. - continue-on-error: true - if: ${{ !contains(inputs.runner, 'windows-x86') }} - run: | - mkdir -p $(pwd)/dist - gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ - gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ - - name: Download wheels from GCS (Windows runs) - id: download-wheel-artifacts-w - # Set continue-on-error to true to prevent actions from failing the workflow if this step - # fails. Instead, we verify the outcome in step below so that we can print a more - # informative error message. - continue-on-error: true - if: ${{ contains(inputs.runner, 'windows-x86') }} - shell: cmd - run: | - mkdir dist - @REM Use `call` so that we can run sequential gsutil commands on Windows - @REM See https://github.com/GoogleCloudPlatform/gsutil/issues/233#issuecomment-196150652 - call gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl dist/ - call gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*%PYTHON_MAJOR_MINOR%*%OS%*%ARCH%*.whl" dist/ - - name: Skip the test run if the wheel artifacts were not downloaded successfully - if: steps.download-wheel-artifacts-nw.outcome == 'failure' || steps.download-wheel-artifacts-w.outcome == 'failure' - run: | - echo "Failed to download wheel artifacts from GCS. Please check if the wheels were" - echo "built successfully by the artifact build jobs and are available in the GCS bucket." - echo "Skipping the test run." - exit 1 + - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 + with: + persist-credentials: false + - name: Download JAX CPU wheels + uses: ./.github/actions/download-jax-cpu-wheels + with: + runner: ${{ inputs.runner }} + python: ${{ inputs.python }} + skip-download-jaxlib-from-gcs: ${{ inputs.skip-download-jaxlib-from-gcs }} + gcs_download_uri: ${{ inputs.gcs_download_uri }} - name: Install Python dependencies run: | $JAXCI_PYTHON -m pip install uv~=0.5.30 - $JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt - # CPU Pytests crash with NumPy 2.2+ on Linux Aarch64; b/399168632 - if [[ $OS == "linux" && $ARCH == "aarch64" ]]; then - $JAXCI_PYTHON -m uv pip install numpy~=2.1.0 - fi + $JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt # Halt for testing - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main + uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c with: halt-dispatch-input: ${{ inputs.halt-for-connection }} - name: Run Pytest CPU tests diff --git a/.github/workflows/pytest_cuda.yml b/.github/workflows/pytest_cuda.yml index ae74da53edcb..39064d6d4647 100644 --- a/.github/workflows/pytest_cuda.yml +++ b/.github/workflows/pytest_cuda.yml @@ -17,46 +17,55 @@ on: runner: description: "Which runner should the workflow run on?" type: string - required: true - default: "linux-x86-n2-16" + default: "linux-x86-n4-16" python: description: "Which python version to test?" type: string - required: true default: "3.12" - cuda: + cuda-version: description: "Which CUDA version to test?" type: string - required: true - default: "12.3" + default: "12.9" + use-nvidia-pip-wheels: + description: "Whether to download CUDA packages from PyPI?" + type: boolean + default: false enable-x64: description: "Should x64 mode be enabled?" type: string - required: true default: "0" + skip-download-jaxlib-and-cuda-plugins-from-gcs: + description: "Whether to skip downloading the jaxlib and cuda plugins from GCS (e.g for testing a jax only release)" + default: '0' + type: string gcs_download_uri: description: "GCS location prefix from where the artifacts should be downloaded" - required: true default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' type: string halt-for-connection: description: 'Should this workflow run wait for a remote connection?' - type: boolean - required: false - default: false + type: string + default: 'no' +permissions: {} + +env: + UV_DEFAULT_INDEX: "https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple" jobs: run-tests: defaults: run: - # Explicitly set the shell to bash + # Set the shell to bash as GitHub actions run with /bin/sh by default shell: bash runs-on: ${{ inputs.runner }} - # TODO: Update to the generic ML ecosystem test containers when they are ready. - container: ${{ (contains(inputs.cuda, '12.3') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.3-cudnn9.1:latest') || - (contains(inputs.cuda, '12.1') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.1-cudnn9.1:latest') || - (contains(inputs.cuda, '12.8') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.8-cudnn9.8:latest') }} - name: "Pytest CUDA (${{ inputs.runner }}, CUDA ${{ inputs.cuda }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})" + # Test the oldest and newest supported CUDA versions. + # If testing the CUDA packages from PyPI, then use the ml-build image which does not have any + # CUDA pckages installed on the system. + container: ${{ !inputs.use-nvidia-pip-wheels && (contains(inputs.cuda-version, '12.1') && 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build-cuda12.1-cudnn9.8:latest') || + inputs.use-nvidia-pip-wheels && 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest'}} + name: "${{ (contains(inputs.runner, 'h100') && 'h100') || + (contains(inputs.runner, 'b200') && 'b200') || + (contains(inputs.runner, 'l4') && 'l4') }}, CUDA ${{ inputs.cuda-version }}, py ${{ inputs.python }}, x64=${{ inputs.enable-x64 }}" env: JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" @@ -64,49 +73,39 @@ jobs: JAXCI_ENABLE_X64: "${{ inputs.enable-x64 }}" steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Set env vars for use in artifact download URL - run: | - os=$(uname -s | awk '{print tolower($0)}') - arch=$(uname -m) - - # Get the major and minor version of Python. - # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310 - # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.13-nogil, then python_major_minor=313t - python_major_minor=$(echo "${JAXCI_HERMETIC_PYTHON_VERSION//-nogil/t}" | tr -d '.') - - echo "OS=${os}" >> $GITHUB_ENV - echo "ARCH=${arch}" >> $GITHUB_ENV - # Python wheels follow a naming convention: standard wheels use the pattern - # `*-cp-cp-*`, while free-threaded wheels use - # `*-cp-cpt-*`. - echo "PYTHON_MAJOR_MINOR=cp${python_major_minor%t}-cp${python_major_minor}-" >> $GITHUB_ENV - - name: Download wheels from GCS - id: download-wheel-artifacts - # Set continue-on-error to true to prevent actions from failing the workflow if this step - # fails. Instead, we verify the outcome in the next step so that we can print a more - # informative error message. - continue-on-error: true - run: | - mkdir -p $(pwd)/dist && - gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ && - gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && - gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && - gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/ - - name: Skip the test run if the wheel artifacts were not downloaded successfully - if: steps.download-wheel-artifacts.outcome == 'failure' + - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 + with: + persist-credentials: false + - name: Download JAX CUDA wheels + uses: ./.github/actions/download-jax-cuda-wheels + with: + python: ${{ inputs.python }} + cuda-version: ${{ inputs.cuda-version }} + skip-download-jaxlib-and-cuda-plugins-from-gcs: ${{ inputs.skip-download-jaxlib-and-cuda-plugins-from-gcs }} + use-nvidia-pip-wheels: ${{ inputs.use-nvidia-pip-wheels }} + gcs_download_uri: ${{ inputs.gcs_download_uri }} + - name: Set jax PyPI extras run: | - echo "Failed to download wheel artifacts from GCS. Please check if the wheels were" - echo "built successfully by the artifact build jobs and are available in the GCS bucket." - echo "Skipping the test run." - exit 1 + # Set the "jax" wheel extra to configure whether to download the Nvidia Pip packages from + # PyPI. + if [[ "${{ inputs.use-nvidia-pip-wheels }}" == false ]] && [[ "${{ inputs.cuda-version }}" == 12* ]]; then + echo "JAXCI_JAX_PYPI_EXTRAS=cuda12-local">> $GITHUB_ENV + elif [[ "${{ inputs.use-nvidia-pip-wheels }}" == true ]] && [[ "${{ inputs.cuda-version }}" == 12* ]]; then + echo "JAXCI_JAX_PYPI_EXTRAS=cuda12">> $GITHUB_ENV + elif [[ "${{ inputs.use-nvidia-pip-wheels }}" == true ]] && [[ "${{ inputs.cuda-version }}" == 13* ]]; then + echo "JAXCI_JAX_PYPI_EXTRAS=cuda13">> $GITHUB_ENV + else + echo "Invalid CUDA version: ${{ inputs.cuda-version }}" + exit 1 + fi - name: Install Python dependencies - run: $JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt + run: | + $JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt # Halt for testing - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main + uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c with: halt-dispatch-input: ${{ inputs.halt-for-connection }} - name: Run Pytest CUDA tests - timeout-minutes: 60 + timeout-minutes: 120 run: ./ci/run_pytest_cuda.sh \ No newline at end of file diff --git a/.github/workflows/pytest_tpu.yml b/.github/workflows/pytest_tpu.yml index a105a2feb347..27b6f312a3a9 100644 --- a/.github/workflows/pytest_tpu.yml +++ b/.github/workflows/pytest_tpu.yml @@ -11,7 +11,6 @@ # - Installs the downloaded jaxlib wheel. # - Runs the TPU tests with Pytest. name: CI - Pytest TPU - on: workflow_call: inputs: @@ -23,47 +22,47 @@ on: runner: description: "Which runner should the workflow run on?" type: string - required: true default: "linux-x86-ct5lp-224-8tpu" cores: description: "How many TPU cores should the test use?" type: string - required: true default: "8" tpu-type: description: "Which TPU type is used for testing?" type: string - required: true default: "v5e-8" python: description: "Which Python version should be used for testing?" type: string - required: true default: "3.12" run-full-tpu-test-suite: description: "Should the full TPU test suite be run?" type: string - required: false default: "0" libtpu-version-type: description: "Which libtpu version should be used for testing?" type: string - required: false # Choices are: # - "nightly": Use the nightly libtpu wheel. # - "pypi_latest": Use the latest libtpu wheel from PyPI. # - "oldest_supported_libtpu": Use the oldest supported libtpu wheel. default: "nightly" + skip-download-jaxlib-from-gcs: + description: "Whether to skip downloading the jaxlib artifact from GCS (e.g for testing a jax only release)" + default: '0' + type: string gcs_download_uri: description: "GCS location prefix from where the artifacts should be downloaded" - required: true default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' type: string halt-for-connection: description: 'Should this workflow run wait for a remote connection?' - type: boolean - required: false - default: false + type: string + default: 'no' +permissions: {} + +env: + UV_DEFAULT_INDEX: "https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple" jobs: run-tests: @@ -71,53 +70,27 @@ jobs: run: shell: bash runs-on: ${{ inputs.runner }} - container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" - # Begin Presubmit Naming Check - name modification requires internal check to be updated - name: "Pytest TPU (${{ inputs.tpu-type }}, Python ${{ inputs.python }}, libtpu=${{ inputs.libtpu-version-type }})" - # End Presubmit Naming Check github-tpu-presubmits + container: "us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest" + name: "${{ inputs.tpu-type }}, py ${{ inputs.python }}, libtpu=${{ inputs.libtpu-version-type }}" env: - LIBTPU_OLDEST_VERSION_DATE: 20241205 + LIBTPU_OLDEST_VERSION_DATE: 20250228 JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" JAXCI_PYTHON: "python${{ inputs.python }}" JAXCI_RUN_FULL_TPU_TEST_SUITE: "${{ inputs.run-full-tpu-test-suite }}" JAXCI_TPU_CORES: "${{ inputs.cores }}" steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Set env vars for use in artifact download URL - run: | - os=$(uname -s | awk '{print tolower($0)}') - arch=$(uname -m) - - # Get the major and minor version of Python. - # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310 - # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.13-nogil, then python_major_minor=313t - python_major_minor=$(echo "${JAXCI_HERMETIC_PYTHON_VERSION//-nogil/t}" | tr -d '.') - - echo "OS=${os}" >> $GITHUB_ENV - echo "ARCH=${arch}" >> $GITHUB_ENV - # Python wheels follow a naming convention: standard wheels use the pattern - # `*-cp-cp-*`, while free-threaded wheels use - # `*-cp-cpt-*`. - echo "PYTHON_MAJOR_MINOR=cp${python_major_minor%t}-cp${python_major_minor}-" >> $GITHUB_ENV - - name: Download JAX wheels from GCS - id: download-wheel-artifacts - # Set continue-on-error to true to prevent actions from failing the workflow if this step - # fails. Instead, we verify the outcome in the step below so that we can print a more - # informative error message. - continue-on-error: true - run: | - mkdir -p $(pwd)/dist - gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ - gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ - - name: Skip the test run if the wheel artifacts were not downloaded successfully - if: steps.download-wheel-artifacts.outcome == 'failure' - run: | - echo "Failed to download wheel artifacts from GCS. Please check if the wheels were" - echo "built successfully by the artifact build jobs and are available in the GCS bucket." - echo "Skipping the test run." - exit 1 + - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 + with: + persist-credentials: false + - name: Download JAX CPU wheels + uses: ./.github/actions/download-jax-cpu-wheels + with: + runner: ${{ inputs.runner }} + python: ${{ inputs.python }} + skip-download-jaxlib-from-gcs: ${{ inputs.skip-download-jaxlib-from-gcs }} + gcs_download_uri: ${{ inputs.gcs_download_uri }} - name: Install Python dependencies run: | $JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt -r build/collect-profile-requirements.txt @@ -128,9 +101,9 @@ jobs: $JAXCI_PYTHON -m uv pip install --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html elif [[ "${{ inputs.libtpu-version-type }}" == "pypi_latest" ]]; then echo "Using latest libtpu from PyPI" - # Set JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI to "tpu_pypi". The `run_pytest_tpu.sh` - # script will install the latest libtpu wheel from PyPI. - echo "JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI=tpu_pypi" >> $GITHUB_ENV + # Set JAXCI_JAX_PYPI_EXTRAS to "tpu". The `run_pytest_tpu.sh` script will install the + # latest libtpu wheel from PyPI. + echo "JAXCI_JAX_PYPI_EXTRAS=tpu" >> $GITHUB_ENV elif [[ "${{ inputs.libtpu-version-type }}" == "oldest_supported_libtpu" ]]; then echo "Using oldest supported libtpu" $JAXCI_PYTHON -m uv pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \ @@ -143,9 +116,14 @@ jobs: fi # Halt for testing - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main + uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c with: halt-dispatch-input: ${{ inputs.halt-for-connection }} - name: Run Pytest TPU tests - timeout-minutes: ${{ github.event_name == 'pull_request' && 30 || 180 }} - run: ./ci/run_pytest_tpu.sh + timeout-minutes: ${{ github.event_name == 'pull_request' && 30 || 210 }} + run: | + if [[ ${{ inputs.python }} == "3.13-nogil" ]]; then + echo "Uninstalling xprof as it is not compatible with python 3.13t." + $JAXCI_PYTHON -m uv pip uninstall xprof + fi + ./ci/run_pytest_tpu.sh diff --git a/.github/workflows/release-notification.yml b/.github/workflows/release-notification.yml index a4a342ef6de7..6d68bf922655 100644 --- a/.github/workflows/release-notification.yml +++ b/.github/workflows/release-notification.yml @@ -2,14 +2,21 @@ name: Google Chat Release Notification on: release: types: [published] +permissions: {} jobs: build: + env: + WEBHOOK_URL: ${{ secrets.RELEASES_WEBHOOK }} + RELEASE_NAME: ${{github.event.release.name}} + PUBLISHED_AT: ${{github.event.release.published_at}} + AUTHOR_LOGIN: ${{github.event.release.author.login}} + RELEASE_URL: ${{github.event.release.url}} runs-on: ubuntu-latest steps: - name: Google Chat Notification run: | - curl --location --request POST '${{ secrets.RELEASES_WEBHOOK }}' \ + curl --location --request POST '${WEBHOOK_URL}' \ --header 'Content-Type: application/json' \ --data-raw '{ - "text": "Release ${{github.event.release.name}} at ${{github.event.release.published_at}} by ${{github.event.release.author.login}}. <${{github.event.release.url}}|[github]>" + "text": "Release $RELEASE_NAME at $PUBLISHED_AT by $AUTHOR_LOGIN. <$RELEASE_URL|[github]>" }' diff --git a/.github/workflows/requirements_lock_3_13_ft.patch b/.github/workflows/requirements_lock_3_13_ft.patch deleted file mode 100644 index 0b63cb5b8711..000000000000 --- a/.github/workflows/requirements_lock_3_13_ft.patch +++ /dev/null @@ -1,85 +0,0 @@ -diff --git a/build/requirements_lock_3_13_ft.txt b/build/requirements_lock_3_13_ft.txt -index e7a2968e9..d37e11ee3 100644 ---- a/build/requirements_lock_3_13_ft.txt -+++ b/build/requirements_lock_3_13_ft.txt -@@ -4,6 +4,11 @@ - # - # pip-compile --allow-unsafe --generate-hashes --output-file=build/requirements_lock_3_13_ft.txt build/requirements.in - # -+ -+--pre -+--extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple -+numpy -+ - absl-py==2.1.0 \ - --hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \ - --hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff -@@ -328,68 +333,6 @@ mpmath==1.3.0 \ - --hash=sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f \ - --hash=sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c - # via -r build/test-requirements.txt --numpy==2.2.1 ; python_version >= "3.13" \ -- --hash=sha256:059e6a747ae84fce488c3ee397cee7e5f905fd1bda5fb18c66bc41807ff119b2 \ -- --hash=sha256:08ef779aed40dbc52729d6ffe7dd51df85796a702afbf68a4f4e41fafdc8bda5 \ -- --hash=sha256:164a829b6aacf79ca47ba4814b130c4020b202522a93d7bff2202bfb33b61c60 \ -- --hash=sha256:26c9c4382b19fcfbbed3238a14abf7ff223890ea1936b8890f058e7ba35e8d71 \ -- --hash=sha256:27f5cdf9f493b35f7e41e8368e7d7b4bbafaf9660cba53fb21d2cd174ec09631 \ -- --hash=sha256:31b89fa67a8042e96715c68e071a1200c4e172f93b0fbe01a14c0ff3ff820fc8 \ -- --hash=sha256:32cb94448be47c500d2c7a95f93e2f21a01f1fd05dd2beea1ccd049bb6001cd2 \ -- --hash=sha256:360137f8fb1b753c5cde3ac388597ad680eccbbbb3865ab65efea062c4a1fd16 \ -- --hash=sha256:3683a8d166f2692664262fd4900f207791d005fb088d7fdb973cc8d663626faa \ -- --hash=sha256:38efc1e56b73cc9b182fe55e56e63b044dd26a72128fd2fbd502f75555d92591 \ -- --hash=sha256:3d03883435a19794e41f147612a77a8f56d4e52822337844fff3d4040a142964 \ -- --hash=sha256:3ecc47cd7f6ea0336042be87d9e7da378e5c7e9b3c8ad0f7c966f714fc10d821 \ -- --hash=sha256:40f9e544c1c56ba8f1cf7686a8c9b5bb249e665d40d626a23899ba6d5d9e1484 \ -- --hash=sha256:4250888bcb96617e00bfa28ac24850a83c9f3a16db471eca2ee1f1714df0f957 \ -- --hash=sha256:4511d9e6071452b944207c8ce46ad2f897307910b402ea5fa975da32e0102800 \ -- --hash=sha256:45681fd7128c8ad1c379f0ca0776a8b0c6583d2f69889ddac01559dfe4390918 \ -- --hash=sha256:48fd472630715e1c1c89bf1feab55c29098cb403cc184b4859f9c86d4fcb6a95 \ -- --hash=sha256:4c86e2a209199ead7ee0af65e1d9992d1dce7e1f63c4b9a616500f93820658d0 \ -- --hash=sha256:4dfda918a13cc4f81e9118dea249e192ab167a0bb1966272d5503e39234d694e \ -- --hash=sha256:5062dc1a4e32a10dc2b8b13cedd58988261416e811c1dc4dbdea4f57eea61b0d \ -- --hash=sha256:51faf345324db860b515d3f364eaa93d0e0551a88d6218a7d61286554d190d73 \ -- --hash=sha256:526fc406ab991a340744aad7e25251dd47a6720a685fa3331e5c59fef5282a59 \ -- --hash=sha256:53c09385ff0b72ba79d8715683c1168c12e0b6e84fb0372e97553d1ea91efe51 \ -- --hash=sha256:55ba24ebe208344aa7a00e4482f65742969a039c2acfcb910bc6fcd776eb4355 \ -- --hash=sha256:5b6c390bfaef8c45a260554888966618328d30e72173697e5cabe6b285fb2348 \ -- --hash=sha256:5c5cc0cbabe9452038ed984d05ac87910f89370b9242371bd9079cb4af61811e \ -- --hash=sha256:5edb4e4caf751c1518e6a26a83501fda79bff41cc59dac48d70e6d65d4ec4440 \ -- --hash=sha256:61048b4a49b1c93fe13426e04e04fdf5a03f456616f6e98c7576144677598675 \ -- --hash=sha256:676f4eebf6b2d430300f1f4f4c2461685f8269f94c89698d832cdf9277f30b84 \ -- --hash=sha256:67d4cda6fa6ffa073b08c8372aa5fa767ceb10c9a0587c707505a6d426f4e046 \ -- --hash=sha256:694f9e921a0c8f252980e85bce61ebbd07ed2b7d4fa72d0e4246f2f8aa6642ab \ -- --hash=sha256:733585f9f4b62e9b3528dd1070ec4f52b8acf64215b60a845fa13ebd73cd0712 \ -- --hash=sha256:7671dc19c7019103ca44e8d94917eba8534c76133523ca8406822efdd19c9308 \ -- --hash=sha256:780077d95eafc2ccc3ced969db22377b3864e5b9a0ea5eb347cc93b3ea900315 \ -- --hash=sha256:7ba9cc93a91d86365a5d270dee221fdc04fb68d7478e6bf6af650de78a8339e3 \ -- --hash=sha256:89b16a18e7bba224ce5114db863e7029803c179979e1af6ad6a6b11f70545008 \ -- --hash=sha256:9036d6365d13b6cbe8f27a0eaf73ddcc070cae584e5ff94bb45e3e9d729feab5 \ -- --hash=sha256:93cf4e045bae74c90ca833cba583c14b62cb4ba2cba0abd2b141ab52548247e2 \ -- --hash=sha256:9ad014faa93dbb52c80d8f4d3dcf855865c876c9660cb9bd7553843dd03a4b1e \ -- --hash=sha256:9b1d07b53b78bf84a96898c1bc139ad7f10fda7423f5fd158fd0f47ec5e01ac7 \ -- --hash=sha256:a7746f235c47abc72b102d3bce9977714c2444bdfaea7888d241b4c4bb6a78bf \ -- --hash=sha256:aa3017c40d513ccac9621a2364f939d39e550c542eb2a894b4c8da92b38896ab \ -- --hash=sha256:b34d87e8a3090ea626003f87f9392b3929a7bbf4104a05b6667348b6bd4bf1cd \ -- --hash=sha256:b541032178a718c165a49638d28272b771053f628382d5e9d1c93df23ff58dbf \ -- --hash=sha256:ba5511d8f31c033a5fcbda22dd5c813630af98c70b2661f2d2c654ae3cdfcfc8 \ -- --hash=sha256:bc8a37ad5b22c08e2dbd27df2b3ef7e5c0864235805b1e718a235bcb200cf1cb \ -- --hash=sha256:bff7d8ec20f5f42607599f9994770fa65d76edca264a87b5e4ea5629bce12268 \ -- --hash=sha256:c1ad395cf254c4fbb5b2132fee391f361a6e8c1adbd28f2cd8e79308a615fe9d \ -- --hash=sha256:f1d09e520217618e76396377c81fba6f290d5f926f50c35f3a5f72b01a0da780 \ -- --hash=sha256:f3eac17d9ec51be534685ba877b6ab5edc3ab7ec95c8f163e5d7b39859524716 \ -- --hash=sha256:f419290bc8968a46c4933158c91a0012b7a99bb2e465d5ef5293879742f8797e \ -- --hash=sha256:f62aa6ee4eb43b024b0e5a01cf65a0bb078ef8c395e8713c6e8a12a697144528 \ -- --hash=sha256:f74e6fdeb9a265624ec3a3918430205dff1df7e95a230779746a6af78bc615af \ -- --hash=sha256:f9b57eaa3b0cd8db52049ed0330747b0364e899e8a606a624813452b8203d5f7 \ -- --hash=sha256:fce4f615f8ca31b2e61aa0eb5865a21e14f5629515c9151850aa936c02a1ee51 -- # via -- # -r build/requirements.in -- # contourpy -- # matplotlib -- # ml-dtypes -- # scipy - nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \ - --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \ - --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ diff --git a/.github/workflows/rocm-ci.yml b/.github/workflows/rocm-ci.yml index 713e9099e381..4195ff1ef004 100644 --- a/.github/workflows/rocm-ci.yml +++ b/.github/workflows/rocm-ci.yml @@ -6,9 +6,7 @@ on: branches: - main -permissions: - contents: read - +permissions: {} concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} @@ -18,8 +16,8 @@ jobs: env: BASE_IMAGE: "ubuntu:22.04" TEST_IMAGE: ubuntu-jax-upstream-${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }} - PYTHON_VERSION: "3.10" - ROCM_VERSION: "6.2.4" + PYTHON_VERSION: "3.11" + ROCM_VERSION: "6.3.3" WORKSPACE_DIR: workdir_${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }} steps: - name: Clean up old runs @@ -33,9 +31,10 @@ jobs: - name: Print system info run: | rocm-smi - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 with: path: ${{ env.WORKSPACE_DIR }} + persist-credentials: false - name: Build JAX run: | pushd $WORKSPACE_DIR @@ -47,7 +46,7 @@ jobs: dist_docker \ --image-tag $TEST_IMAGE - name: Archive jax wheels - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0 with: name: rocm_jax_r${{ env.ROCM_VERSION }}_py${{ env.PYTHON_VERSION }}_id${{ github.run_id }} path: ${{ env.WORKSPACE_DIR }}/dist/*.whl diff --git a/.github/workflows/tsan.yaml b/.github/workflows/tsan.yaml index 7d93707e4e92..55f4bfeeeebd 100644 --- a/.github/workflows/tsan.yaml +++ b/.github/workflows/tsan.yaml @@ -3,67 +3,120 @@ name: CI - Free-threading and Thread Sanitizer (nightly) concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true - on: schedule: - - cron: "0 5 * * *" # Daily at 05:00 UTC == 00:00 EST == 21:00 PST + - cron: "0 5 * * *" # Daily at 05:00 UTC == 00:00 EST == 21:00 PST workflow_dispatch: # allows triggering the workflow run manually pull_request: # Automatically trigger on pull requests affecting this file branches: - main paths: - '**/workflows/tsan.yaml' + - '**/workflows/tsan-suppressions*.txt' +permissions: {} + +env: + UV_DEFAULT_INDEX: "https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple" + PIP_INDEX_URL: "https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple" jobs: tsan: - runs-on: linux-x86-n2-64 + runs-on: linux-x86-n4-64 container: - image: index.docker.io/library/ubuntu@sha256:b359f1067efa76f37863778f7b6d0e8d911e3ee8efa807ad01fbf5dc1ef9006b # ratchet:ubuntu:24.04 + image: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build@sha256:ea67e8453d8b09c2ba48853da5e79efef4b65804b4a48dfae4b4da89ffd38405 # ml-build container (based on Ubuntu 22.04) strategy: fail-fast: false + matrix: + include: + - name-prefix: "with 3.14" + python-version: "3.14" + github_branch: "3.14" + requirements_lock_name: "requirements_lock_3_14_ft" defaults: run: shell: bash -l {0} steps: # Install git before actions/checkout as otherwise it will download the code with the GitHub # REST API and therefore any subsequent git commands will fail. - - name: Install clang 18 + # Also install dependencies for Google Cloud SDK (curl, python3, etc) + - name: Install dependencies env: DEBIAN_FRONTEND: noninteractive run: | - apt update - apt install -y clang-18 libstdc++-14-dev build-essential libssl-dev \ - zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev curl git \ - libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev \ - libffi-dev liblzma-dev file zip - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + apt-get update + apt-get install -q -y \ + clang-18 \ + libclang-common-18-dev \ + libclang-rt-18-dev \ + libc++abi-18-dev \ + lld-18 \ + libstdc++-12-dev \ + libc++-18-dev \ + build-essential \ + libssl-dev \ + zlib1g-dev \ + libbz2-dev \ + libreadline-dev \ + libsqlite3-dev \ + curl \ + git \ + libncursesw5-dev \ + xz-utils \ + tk-dev \ + libxml2-dev \ + libxmlsec1-dev \ + libffi-dev \ + liblzma-dev \ + file \ + vim \ + wget \ + zip \ + zstd + - name: Install Google Cloud SDK + run: | + echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list + curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add - + apt update && apt install -y google-cloud-cli + + - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 with: path: jax - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - repository: python/cpython - path: cpython - ref: "3.13" - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + persist-credentials: false + - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 with: repository: numpy/numpy path: numpy submodules: true + persist-credentials: false + + - name: Get year & week number + id: get-date + run: echo "date=$(/bin/date "+%Y-%U")" >> $GITHUB_OUTPUT + shell: bash -l {0} - - name: Restore cached TSAN CPython + - name: Restore cached TSAN CPython ${{ matrix.python-version }} id: cache-cpython-tsan-restore - uses: actions/cache/restore@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 + uses: actions/cache/restore@0057852bfaa89a56745cba8c7296529d2fc39830 # v4.3.0 with: path: | ./python-tsan.tgz - key: ${{ runner.os }}-cpython-tsan-${{ hashFiles('cpython/configure.ac') }} + key: ${{ runner.os }}-cpython-tsan-${{ matrix.python-version }}-${{ steps.get-date.outputs.date }} - - name: Build CPython with enabled TSAN + - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 + if: steps.cache-cpython-tsan-restore.outputs.cache-hit != 'true' + with: + repository: python/cpython + path: cpython + ref: ${{ matrix.github_branch }} + persist-credentials: false + + + - name: Build TSAN CPython ${{ matrix.python-version }} if: steps.cache-cpython-tsan-restore.outputs.cache-hit != 'true' run: | cd cpython mkdir ${GITHUB_WORKSPACE}/cpython-tsan - CC=clang-18 CXX=clang++-18 ./configure --prefix ${GITHUB_WORKSPACE}/cpython-tsan --disable-gil --with-thread-sanitizer + CC=clang-18 CXX=clang++-18 CFLAGS=" -O0 -g" CXXFLAGS="-stdlib=libc++" ./configure --prefix ${GITHUB_WORKSPACE}/cpython-tsan --disable-gil --with-thread-sanitizer --with-mimalloc make -j64 make install -j64 # Check whether free-threading mode is enabled @@ -72,31 +125,43 @@ jobs: # Create archive to be used with bazel as hermetic python: cd ${GITHUB_WORKSPACE} && tar -czpf python-tsan.tgz cpython-tsan - - name: Save TSAN CPython + - name: Save TSAN CPython ${{ matrix.python-version }} id: cache-cpython-tsan-save if: steps.cache-cpython-tsan-restore.outputs.cache-hit != 'true' - uses: actions/cache/save@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 + uses: actions/cache/save@0057852bfaa89a56745cba8c7296529d2fc39830 # v4.3.0 with: path: | ./python-tsan.tgz - key: ${{ runner.os }}-cpython-tsan-${{ hashFiles('cpython/configure.ac') }} + key: ${{ runner.os }}-cpython-tsan-${{ matrix.python-version }}-${{ steps.get-date.outputs.date }} - - name: Get year & week number - id: get-date - run: echo "date=$(/bin/date "+%Y-%U")" >> $GITHUB_OUTPUT - shell: bash -l {0} + # Upload the Python tarball to GCS so RBE can access it. + - name: Upload TSAN CPython to GCS + run: | + GCS_DEST="gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}" + echo "Uploading python-tsan.tgz to $GCS_DEST" + gcloud storage cp python-tsan.tgz "$GCS_DEST/" + + # Output the HTTP URL for Bazel + BASE_URL="https://storage.googleapis.com/general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}/python-tsan.tgz" + # URL-encode the path (handles spaces in workflow name) + PUBLIC_URL=$(python3 -c "import urllib.parse; print(urllib.parse.quote('${BASE_URL}', safe=':/'))") + echo "HERMETIC_PYTHON_URL=$PUBLIC_URL" >> $GITHUB_ENV + + # --- Condtional NumPy steps for pre-release Python --- - name: Restore cached TSAN Numpy id: cache-numpy-tsan-restore - uses: actions/cache/restore@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 + if: matrix.python-version == '3.15' + uses: actions/cache/restore@0057852bfaa89a56745cba8c7296529d2fc39830 # v4.3.0 with: path: | ./wheelhouse - key: ${{ runner.os }}-numpy-tsan-${{ hashFiles('numpy/pyproject.toml') }}-${{ steps.get-date.outputs.date }} + key: ${{ runner.os }}-numpy-tsan-${{ matrix.python-version }}-${{ hashFiles('numpy/pyproject.toml') }}-${{ steps.get-date.outputs.date }} - name: Build TSAN Numpy wheel - if: steps.cache-numpy-tsan-restore.outputs.cache-hit != 'true' + if: matrix.python-version == '3.15' && steps.cache-numpy-tsan-restore.outputs.cache-hit != 'true' run: | + set -eux cd numpy # If we restored cpython from cache, we need to get python interpreter from python-tsan.tgz @@ -112,8 +177,7 @@ jobs: export PATH=${GITHUB_WORKSPACE}/cpython-tsan/bin/:$PATH python3 -m pip install uv~=0.5.30 - # Make sure to install a compatible Cython version (master branch is best for now) - python3 -m uv pip install -r requirements/build_requirements.txt -U git+https://github.com/cython/cython + python3 -m uv pip install -r requirements/build_requirements.txt CC=clang-18 CXX=clang++-18 python3 -m pip wheel --wheel-dir dist -v . --no-build-isolation -Csetup-args=-Db_sanitize=thread -Csetup-args=-Dbuildtype=debugoptimized @@ -141,76 +205,78 @@ jobs: - name: Save TSAN Numpy wheel id: cache-numpy-tsan-save - if: steps.cache-numpy-tsan-restore.outputs.cache-hit != 'true' - uses: actions/cache/save@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 + if: matrix.python-version == '3.15' && steps.cache-numpy-tsan-restore.outputs.cache-hit != 'true' + uses: actions/cache/save@0057852bfaa89a56745cba8c7296529d2fc39830 # v4.3.0 with: path: | ./wheelhouse - key: ${{ runner.os }}-numpy-tsan-${{ hashFiles('numpy/pyproject.toml') }}-${{ steps.get-date.outputs.date }} + key: ${{ runner.os }}-numpy-tsan-${{ matrix.python-version }}-${{ hashFiles('numpy/pyproject.toml') }}-${{ steps.get-date.outputs.date }} + + # --- End Conditional NumPy Steps --- - name: Build Jax and run tests - timeout-minutes: 120 + timeout-minutes: 180 env: JAX_NUM_GENERATED_CASES: 1 JAX_ENABLE_X64: true JAX_SKIP_SLOW_TESTS: true PY_COLORS: 1 + DEBIAN_FRONTEND: noninteractive run: | + set -x cd jax - export PYTHON_SHA256=($(sha256sum ${GITHUB_WORKSPACE}/python-tsan.tgz)) + # Calculate SHA256 of the Python tarball + export PYTHON_SHA256=($(sha256sum ${GITHUB_WORKSPACE}/python-tsan.tgz | awk '{ print $1 }')) echo "Python sha256: ${PYTHON_SHA256}" + echo "Python URL: $HERMETIC_PYTHON_URL" - python3 -VV + # Configure Bazel python3 build/build.py build --configure_only \ - --python_version=3.13-ft \ - --bazel_options=--repo_env=HERMETIC_PYTHON_URL="file://${GITHUB_WORKSPACE}/python-tsan.tgz" \ + --python_version=${{ matrix.python-version }}-ft \ + --bazel_options=--repo_env=HERMETIC_PYTHON_URL="${HERMETIC_PYTHON_URL}" \ --bazel_options=--repo_env=HERMETIC_PYTHON_SHA256=${PYTHON_SHA256} \ --bazel_options=--repo_env=HERMETIC_PYTHON_PREFIX="cpython-tsan/" \ --bazel_options=--color=yes \ --bazel_options=--copt=-fsanitize=thread \ --bazel_options=--linkopt="-fsanitize=thread" \ - --bazel_options=--copt=-g \ - --clang_path=/usr/bin/clang-18 - - # Patch build/requirements_lock_3_13_ft.txt to use TSAN instrumented NumPy - sed -i "s|+--extra-index-url.*|+--extra-index-url file://${GITHUB_WORKSPACE}/wheelhouse/|" .github/workflows/requirements_lock_3_13_ft.patch - cat .github/workflows/requirements_lock_3_13_ft.patch - git apply .github/workflows/requirements_lock_3_13_ft.patch || exit 1 - - # Display the content for debugging in logs - cat build/requirements_lock_3_13_ft.txt | head -15 - # Check the patch - cat build/requirements_lock_3_13_ft.txt | head -15 | grep -E "(--pre|.*${GITHUB_WORKSPACE}/wheelhouse/|numpy)" - if [ "$?" == "1" ]; then echo "Could not find the patch in the requirements_lock_3_13_ft.txt"; exit 1; fi - cat build/requirements_lock_3_13_ft.txt | grep -E "(numpy==)" - if [ "$?" == "0" ]; then "Found original numpy dependency in the requirements_lock_3_13_ft.txt"; exit 1; fi + --bazel_options=--copt=-g + + mkdir -p dist + + # Copy custom Numpy wheel only if using 3.15 (and if it exists) + if [ "${{ matrix.python-version }}" == "3.15" ]; then + ls ${GITHUB_WORKSPACE}/wheelhouse/numpy/*.whl || exit 1 + cp -v ${GITHUB_WORKSPACE}/wheelhouse/numpy/*.whl dist/ + + # Patch requirements lock to use TSAN-instrumented NumPy + sed -i "s|--extra-index-url.*|--extra-index-url file://${GITHUB_WORKSPACE}/wheelhouse/|" build/${{ matrix.requirements_lock_name }}.txt + fi echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES" echo "JAX_ENABLE_X64=$JAX_ENABLE_X64" echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS" - # Set symlink to the bazel executable - bazel_exec=($(ls bazel-*)) - ln -s ${bazel_exec} bazel # Check python version - ./bazel run --@rules_python//python/config_settings:py_freethreaded="yes" @python//:python3 -- -VV + python_version="${{ matrix.python-version }}" + bazel run --config=rbe_linux_x86_64 --@rules_python//python/config_settings:py_freethreaded="yes" @python_${python_version//./_}_host//:python -- -VV # Check numpy version - ./bazel cquery @pypi_numpy//:* | grep whl + bazel cquery @pypi_numpy//:* | grep whl # Build JAX and run tests - ./bazel test \ + bazel test \ --test_env=JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES \ --test_env=JAX_ENABLE_X64=$JAX_ENABLE_X64 \ --test_env=JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS \ --test_env=PYTHON_GIL=0 \ - --test_env=TSAN_OPTIONS=halt_on_error=1,suppressions=$PWD/.github/workflows/tsan-suppressions.txt \ + --test_env=TSAN_OPTIONS="halt_on_error=1,suppressions=tests/config/tsan-suppressions_${{ matrix.python-version }}.txt" \ --test_env=JAX_TEST_NUM_THREADS=8 \ + --@rules_python//python/config_settings:py_freethreaded="yes" \ --test_output=errors \ - --local_test_jobs=32 \ - --test_timeout=600 \ - --config=resultstore \ - --config=rbe_cache \ + --test_timeout=1800 \ + --config=rbe_linux_x86_64 \ + --test_tag_filters=-notsan \ + --run_under=//tests/config:oss_tsan_wrapper_sh \ //tests:cpu_tests diff --git a/.github/workflows/upstream-nightly.yml b/.github/workflows/upstream-nightly.yml index 5132a12cf16f..1af06df444bb 100644 --- a/.github/workflows/upstream-nightly.yml +++ b/.github/workflows/upstream-nightly.yml @@ -19,10 +19,16 @@ on: - main paths: - '**workflows/upstream-nightly.yml' +permissions: {} + +env: + UV_DEFAULT_INDEX: "https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple" + PIP_INDEX_URL: "https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple" jobs: upstream-dev: - runs-on: ubuntu-latest + runs-on: linux-x86-n4-64 + container: index.docker.io/library/ubuntu@sha256:b359f1067efa76f37863778f7b6d0e8d911e3ee8efa807ad01fbf5dc1ef9006b # ratchet:ubuntu:24.04 permissions: contents: read issues: write # for failed-build-issue @@ -31,9 +37,11 @@ jobs: matrix: python-version: ["3.13"] steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 + with: + persist-credentials: false - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0 with: python-version: ${{ matrix.python-version }} - name: Install JAX test requirements @@ -66,7 +74,7 @@ jobs: echo "JAX_ENABLE_X64=$JAX_ENABLE_X64" echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS" echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS" - pytest -n 2 --tb=short --maxfail=20 tests examples + pytest -n auto --tb=short --maxfail=20 tests examples - name: Notify failed build uses: jayqi/failed-build-issue-action@1a893bbf43ef1c2a8705e2b115cd4f0fe3c5649b # v1.2.0 if: failure() && github.event.pull_request == null diff --git a/.github/workflows/wheel_tests_continuous.yml b/.github/workflows/wheel_tests_continuous.yml index ecdf43b133cc..382b5b1b55f7 100644 --- a/.github/workflows/wheel_tests_continuous.yml +++ b/.github/workflows/wheel_tests_continuous.yml @@ -1,6 +1,6 @@ # CI - Wheel Tests (Continuous) # -# This workflow builds JAX artifacts and runs CPU/CUDA tests. +# This workflow builds JAX artifacts and runs CPU/TPU/CUDA tests. # # It orchestrates the following: # 1. build-jaxlib-artifact: Calls the `build_artifacts.yml` workflow to build jaxlib and @@ -9,17 +9,27 @@ # that was built in the previous step and runs CPU tests. # 3. build-cuda-artifacts: Calls the `build_artifacts.yml` workflow to build CUDA artifacts and # uploads them to a GCS bucket. -# 4. run-pytest-cuda: Calls the `pytest_cuda.yml` workflow which downloads the jaxlib and CUDA +# 4. run-bazel-test-cpu-py-import: Calls the `bazel_cpu_rbe.yml` workflow which +# runs Bazel CPU tests with py_import on RBE. +# 5. run-bazel-test-cuda-py-import: Calls the `bazel_cuda.yml` workflow which +# runs Bazel CUDA tests with py_import on non-RBE. +# 6. run-pytest-cuda: Calls the `pytest_cuda.yml` workflow which downloads the jaxlib and CUDA # artifacts that were built in the previous steps and runs the CUDA tests. -# 5. run-bazel-test-cuda: Calls the `bazel_cuda_non_rbe.yml` workflow which downloads the jaxlib +# 7. run-bazel-test-cuda: Calls the `bazel_cuda.yml` workflow which downloads the jaxlib # and CUDA artifacts that were built in the previous steps and runs the # CUDA tests using Bazel. +# 8. run-pytest-tpu: Calls the `pytest_tpu.yml` workflow which downloads the jaxlib wheel +# that was built in the previous step and runs TPU tests. +# 9. run-bazel-test-tpu: Calls the `bazel_test_tpu.yml` workflow which +# runs Bazel TPU tests with py_import. name: CI - Wheel Tests (Continuous) +permissions: + contents: read on: schedule: - - cron: "0 */2 * * *" # Run once every 2 hours + - cron: "0 */3 * * *" # Run once every 3 hours workflow_dispatch: # allows triggering the workflow run manually concurrency: @@ -33,7 +43,7 @@ jobs: with: # Note that since jax is a pure python package, the runner OS and Python values do not # matter. In addition, cloning main XLA also has no effect. - runner: "linux-x86-n2-16" + runner: "linux-x86-n4-16" artifact: "jax" upload_artifacts_to_gcs: true gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' @@ -44,9 +54,9 @@ jobs: fail-fast: false # don't cancel all jobs on failure matrix: # Runner OS and Python values need to match the matrix stategy in the CPU tests job - runner: ["linux-x86-n2-16", "linux-arm64-t2a-48", "windows-x86-n2-64"] + runner: ["linux-x86-n4-16", "linux-arm64-t2a-48", "windows-x86-n2-16"] artifact: ["jaxlib"] - python: ["3.10"] + python: ["3.11"] # Note: For reasons unknown, Github actions groups jobs with the same top-level name in the # dashboard only if we use an expression in the "name" field. Otherwise, it appends the matrix # values to the name and creates a separate entry for each matrix combination. @@ -65,14 +75,16 @@ jobs: fail-fast: false # don't cancel all jobs on failure matrix: # Python values need to match the matrix stategy in the CUDA tests job below - runner: ["linux-x86-n2-16"] + runner: ["linux-x86-n4-16"] artifact: ["jax-cuda-plugin", "jax-cuda-pjrt"] - python: ["3.10",] + python: ["3.11",] + cuda-version: ["12", "13"] name: "Build ${{ format('{0}', 'CUDA') }} artifacts" with: runner: ${{ matrix.runner }} artifact: ${{ matrix.artifact }} python: ${{ matrix.python }} + cuda-version: ${{ matrix.cuda-version }} clone_main_xla: 1 upload_artifacts_to_gcs: true gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' @@ -89,8 +101,8 @@ jobs: matrix: # Runner OS and Python values need to match the matrix stategy in the # build_jaxlib_artifact job above - runner: ["linux-x86-n2-64", "linux-arm64-t2a-48", "windows-x86-n2-64"] - python: ["3.10",] + runner: ["linux-x86-n4-64", "linux-arm64-t2a-48", "windows-x86-n2-64"] + python: ["3.11",] enable-x64: [1, 0] name: "Pytest CPU (JAX artifacts version = ${{ format('{0}', 'head') }})" with: @@ -111,59 +123,109 @@ jobs: matrix: # Python values need to match the matrix stategy in the artifact build jobs above # See exlusions for what is fully tested - runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu","linux-x86-a4-224-b200-1gpu"] - python: ["3.10",] - cuda: ["12.1","12.3","12.8"] + runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu", "linux-x86-a4-224-b200-1gpu"] + python: ["3.11",] + cuda: [ + {version: "12.1", use-nvidia-pip-wheels: false}, + {version: "12.9", use-nvidia-pip-wheels: true}, + {version: "13", use-nvidia-pip-wheels: true}, + ] enable-x64: [1, 0] exclude: - # L4 does not run on cuda 12.8 but tests other configs - - runner: "linux-x86-g2-48-l4-4gpu" - cuda: "12.8" - # H100 runs only a single config, CUDA 12.3 Enable x64 1 + # H100 runs only a single config, CUDA 12.9 Enable x64 1 - runner: "linux-x86-a3-8g-h100-8gpu" - cuda: "12.8" + cuda: + version: "12.1" - runner: "linux-x86-a3-8g-h100-8gpu" - cuda: "12.1" - - runner: "linux-x86-a3-8g-h100-8gpu" - enable-x64: "0" - # B200 runs only a single config, CUDA 12.8 Enable x64 1 - - runner: "linux-x86-a4-224-b200-1gpu" enable-x64: "0" + # B200 runs only a single config, CUDA 12.9 Enable x64 1 - runner: "linux-x86-a4-224-b200-1gpu" - cuda: "12.1" + cuda: + version: "12.1" - runner: "linux-x86-a4-224-b200-1gpu" - cuda: "12.3" + enable-x64: "0" - name: "Pytest CUDA (JAX artifacts version = ${{ format('{0}', 'head') }})" + name: "Pytest CUDA (JAX artifacts version = ${{ format('{0}', 'head') }}, CUDA Pip packages = ${{ matrix.cuda.use-nvidia-pip-wheels }})" with: runner: ${{ matrix.runner }} python: ${{ matrix.python }} - cuda: ${{ matrix.cuda }} + cuda-version: ${{ matrix.cuda.version }} + use-nvidia-pip-wheels: ${{ matrix.cuda.use-nvidia-pip-wheels }} enable-x64: ${{ matrix.enable-x64 }} # GCS upload URI is the same for both artifact build jobs gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }} + run-bazel-test-cpu-py-import: + uses: ./.github/workflows/bazel_cpu.yml + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + runner: ["linux-x86-n4-16", "linux-arm64-t2a-48", "windows-x86-n2-16"] + python: ["3.11",] + enable-x64: [1, 0] + name: "Bazel CPU tests with ${{ format('{0}', 'build_jaxlib=wheel') }}" + with: + runner: ${{ matrix.runner }} + python: ${{ matrix.python }} + enable-x64: ${{ matrix.enable-x64 }} + build_jaxlib: "wheel" + build_jax: "wheel" + run-bazel-test-cuda: # Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated # build job fails. E.g Windows build job fails but everything else succeeds. In this case, we # still want to run the tests for other platforms. if: ${{ !cancelled() }} - needs: [build-jaxlib-artifact, build-cuda-artifacts] - uses: ./.github/workflows/bazel_cuda_non_rbe.yml + needs: [build-jax-artifact, build-jaxlib-artifact, build-cuda-artifacts] + uses: ./.github/workflows/bazel_cuda.yml strategy: fail-fast: false # don't cancel all jobs on failure matrix: # Python values need to match the matrix stategy in the build artifacts job above runner: ["linux-x86-g2-48-l4-4gpu",] - python: ["3.10",] + python: ["3.11",] + cuda-version: ["12", "13"] + jaxlib-version: ["head", "pypi_latest"] enable-x64: [1, 0] - name: "Bazel CUDA Non-RBE (JAX artifacts version = ${{ format('{0}', 'head') }})" + name: "Bazel CUDA Non-RBE with build_jaxlib=false, (jax version = ${{ format('{0}', 'head') }})" with: runner: ${{ matrix.runner }} python: ${{ matrix.python }} + cuda-version: ${{ matrix.cuda-version }} enable-x64: ${{ matrix.enable-x64 }} + jaxlib-version: ${{ matrix.jaxlib-version }} # GCS upload URI is the same for both artifact build jobs gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }} + build_jaxlib: "false" + build_jax: "false" + write_to_bazel_remote_cache: 1 + run_multiaccelerator_tests: "true" + + run-bazel-test-cuda-py-import: + # Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated + # build job fails. E.g Windows build job fails but everything else succeeds. In this case, we + # still want to run the tests for other platforms. + if: ${{ !cancelled() }} + uses: ./.github/workflows/bazel_cuda.yml + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + # Python values need to match the matrix stategy in the build artifacts job above + runner: ["linux-x86-g2-48-l4-4gpu",] + python: ["3.11"] + cuda-version: ["12", "13"] + enable-x64: [1] + name: "Bazel CUDA Non-RBE with ${{ format('{0}', 'build_jaxlib=wheel') }}" + with: + runner: ${{ matrix.runner }} + python: ${{ matrix.python }} + cuda-version: ${{ matrix.cuda-version }} + enable-x64: ${{ matrix.enable-x64 }} + build_jaxlib: "wheel" + build_jax: "wheel" + jaxlib-version: "head" + write_to_bazel_remote_cache: 1 + run_multiaccelerator_tests: "true" run-pytest-tpu: # Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated @@ -175,12 +237,14 @@ jobs: strategy: fail-fast: false # don't cancel all jobs on failure matrix: - python: ["3.10",] + python: ["3.11"] tpu-specs: [ # {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available - {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}, - {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"} + {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}, + {type: "v6e-8", cores: "8", runner: "linux-x86-ct6e-180-8tpu"}, + {type: "v7x-8", cores: "4", runner: "linux-x86-tpu7x-224-4tpu"} ] + libtpu-version-type: ["nightly"] name: "Pytest TPU (JAX artifacts version = ${{ format('{0}', 'head') }})" with: runner: ${{ matrix.tpu-specs.runner }} @@ -188,5 +252,34 @@ jobs: tpu-type: ${{ matrix.tpu-specs.type }} python: ${{ matrix.python }} run-full-tpu-test-suite: "1" - libtpu-version-type: "nightly" - gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }} \ No newline at end of file + libtpu-version-type: ${{ matrix.libtpu-version-type }} + gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }} + + run-bazel-test-tpu: + # Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated + # build job fails. E.g Windows build job fails but everything else succeeds. In this case, we + # still want to run the tests for other platforms. + if: ${{ !cancelled() }} + uses: ./.github/workflows/bazel_test_tpu.yml + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + python: ["3.11"] + tpu-specs: [ + {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}, + {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}, + {type: "v7x-8", cores: "4", runner: "linux-x86-tpu7x-224-4tpu"}, + ] + libtpu-version-type: ["nightly"] + name: "Bazel tests TPU (JAX artifacts version = ${{ format('{0}', 'head') }})" + with: + runner: ${{ matrix.tpu-specs.runner }} + cores: ${{ matrix.tpu-specs.cores }} + tpu-type: ${{ matrix.tpu-specs.type }} + python: ${{ matrix.python }} + run-full-tpu-test-suite: "1" + libtpu-version-type: ${{ matrix.libtpu-version-type }} + gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }} + build_jaxlib: "wheel" + build_jax: "wheel" + clone_main_xla: 1 \ No newline at end of file diff --git a/.github/workflows/wheel_tests_nightly_release.yml b/.github/workflows/wheel_tests_nightly_release.yml index adb678be9d9d..f360cc3cfcba 100644 --- a/.github/workflows/wheel_tests_nightly_release.yml +++ b/.github/workflows/wheel_tests_nightly_release.yml @@ -1,12 +1,20 @@ # CI - Wheel Tests (Nightly/Release) # -# This workflow builds JAX artifacts and runs CPU/CUDA tests. +# This workflow is used to test the JAX wheels that were built by internal CI jobs. # -# It orchestrates the following: -# 1. run-pytest-cpu: Calls the `pytest_cpu.yml` workflow which downloads the jaxlib wheel that was +# 1. run-pytest-cpu: Calls the `pytest_cpu.yml` workflow which downloads the JAX wheels that were # built by internal CI jobs and runs CPU tests. -# 2. run-pytest-cuda: Calls the `pytest_cuda.yml` workflow which downloads the jaxlib and CUDA -# artifacts that were built by internal CI jobs and runs the CUDA tests. +# 2. run-bazel-test-cpu: Calls the `bazel_cpu.yml` workflow which downloads the +# JAX wheels that were built by internal CI jobs and runs Bazel CPU tests. +# 3. run-pytest-cuda: Calls the `pytest_cuda.yml` workflow which downloads the JAX wheels that were +# built by internal CI jobs and runs CUDA tests. +# 4. run-bazel-test-cuda: Calls the `bazel_cuda.yml` workflow which downloads the JAX wheels +# that were built by internal CI jobs and runs Bazel CUDA tests. +# 5. run-pytest-tpu: Calls the `pytest_tpu.yml` workflow which downloads the JAX wheels that were +# built by internal CI jobs and runs TPU tests. +# 6. run-bazel-test-tpu: Calls the `bazel_test_tpu.yml` workflow which downloads the JAX wheels that +# were built by internal CI jobs and runs Bazel TPU tests. +# 7. verify-release-wheels-install: Verifies that JAX's release wheels can be installed. name: CI - Wheel Tests (Nightly/Release) on: @@ -15,12 +23,26 @@ on: gcs_download_uri: description: "GCS location URI from where the artifacts should be downloaded" required: true - default: 'gs://jax-nightly-release-transient/nightly/latest' + default: 'gs://jax-nightly-artifacts/latest' + type: string + skip-download-jaxlib-and-cuda-plugins-from-gcs: + description: "Whether to download only the jax wheel from GCS (e.g for testing a jax only release)" + required: true + default: '0' + type: string + halt-for-connection: + description: 'Should this workflow run wait for a remote connection? (yes/no)' + required: false + default: 'no' type: string concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} cancel-in-progress: true +permissions: {} + +env: + UV_DEFAULT_INDEX: "https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple" jobs: run-pytest-cpu: @@ -30,67 +52,245 @@ jobs: matrix: # Runner OS and Python values need to match the matrix stategy of our internal CI jobs # that build the wheels. - runner: ["linux-x86-n2-64", "linux-arm64-t2a-48", "windows-x86-n2-64"] - python: ["3.10","3.11", "3.12", "3.13", "3.13-nogil"] + runner: ["linux-x86-n4-64", "linux-arm64-t2a-48", "windows-x86-n2-64"] + python: ["3.11", "3.12", "3.13", "3.13-nogil", "3.14", "3.14-nogil"] enable-x64: [0] exclude: - runner: "windows-x86-n2-64" python: "3.13-nogil" + - runner: "windows-x86-n2-64" + python: "3.14-nogil" name: "Pytest CPU (JAX artifacts version = ${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }})" with: runner: ${{ matrix.runner }} python: ${{ matrix.python }} enable-x64: ${{ matrix.enable-x64 }} + skip-download-jaxlib-from-gcs: ${{inputs.skip-download-jaxlib-and-cuda-plugins-from-gcs}} gcs_download_uri: ${{inputs.gcs_download_uri}} + halt-for-connection: ${{inputs.halt-for-connection}} + # TODO(b/456203132): Increase to more than 16 processes if this works + # (56 prior to this). + max-processes: ${{ contains(matrix.runner, 'windows-x86') && '16' || '' }} + + run-bazel-test-cpu: + uses: ./.github/workflows/bazel_cpu.yml + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + # Runner OS and Python values need to match the matrix stategy of our internal CI jobs + # that build the wheels. + runner: ["linux-x86-n4-64", "linux-arm64-t2a-48"] + python: ["3.11", "3.12", "3.13", "3.13-nogil", "3.14", "3.14-nogil"] + enable-x64: [0] + name: "Bazel CPU tests with ${{ format('{0}', 'build_jaxlib=false') }}" + with: + runner: ${{ matrix.runner }} + python: ${{ matrix.python }} + enable-x64: ${{ matrix.enable-x64 }} + gcs_download_uri: ${{inputs.gcs_download_uri}} + halt-for-connection: ${{inputs.halt-for-connection}} + build_jaxlib: "false" + build_jax: "false" run-pytest-cuda: uses: ./.github/workflows/pytest_cuda.yml + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + # Runner OS and Python values need to match the matrix stategy of our internal CI jobs + # that build the wheels. + runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu", "linux-x86-a4-224-b200-1gpu"] + python: ["3.11", "3.12", "3.13", "3.13-nogil", "3.14", "3.14-nogil"] + cuda: [ + {cuda-version: "12.1", use-nvidia-pip-wheels: false}, + {cuda-version: "12.9", use-nvidia-pip-wheels: true}, + {cuda-version: "13", use-nvidia-pip-wheels: true} + ] + enable-x64: [0] + exclude: + # H100 runs only CUDA 12.9 and min and max Python versions. + - runner: "linux-x86-a3-8g-h100-8gpu" + cuda: + cuda-version: "12.1" + - runner: "linux-x86-a3-8g-h100-8gpu" + python: "3.12" + - runner: "linux-x86-a3-8g-h100-8gpu" + python: "3.13" + - runner: "linux-x86-a3-8g-h100-8gpu" + python: "3.13-nogil" + - runner: "linux-x86-a3-8g-h100-8gpu" + python: "3.14-nogil" + # B200 runs only CUDA 12.9 and min and max Python versions. + - runner: "linux-x86-a4-224-b200-1gpu" + cuda: + cuda-version: "12.1" + - runner: "linux-x86-a4-224-b200-1gpu" + python: "3.12" + - runner: "linux-x86-a4-224-b200-1gpu" + python: "3.13" + - runner: "linux-x86-a4-224-b200-1gpu" + python: "3.13-nogil" + - runner: "linux-x86-a4-224-b200-1gpu" + python: "3.14-nogil" + name: "Pytest CUDA (JAX artifacts version = ${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }}, CUDA Pip packages = ${{ matrix.cuda.use-nvidia-pip-wheels }})" + with: + runner: ${{ matrix.runner }} + python: ${{ matrix.python }} + cuda-version: ${{ matrix.cuda.cuda-version }} + use-nvidia-pip-wheels: ${{ matrix.cuda.use-nvidia-pip-wheels }} + enable-x64: ${{ matrix.enable-x64 }} + skip-download-jaxlib-and-cuda-plugins-from-gcs: ${{inputs.skip-download-jaxlib-and-cuda-plugins-from-gcs}} + gcs_download_uri: ${{inputs.gcs_download_uri}} + halt-for-connection: ${{inputs.halt-for-connection}} + + run-bazel-test-cuda: + uses: ./.github/workflows/bazel_cuda.yml strategy: fail-fast: false # don't cancel all jobs on failure matrix: # Runner OS and Python values need to match the matrix stategy of our internal CI jobs # that build the wheels. runner: ["linux-x86-g2-48-l4-4gpu"] - python: ["3.10","3.11", "3.12", "3.13", "3.13-nogil"] - cuda: ["12.3", "12.1"] + python: ["3.11", "3.12", "3.13", "3.13-nogil", "3.14", "3.14-nogil"] + cuda-version: [12, 13] enable-x64: [0] - name: "Pytest CUDA (JAX artifacts version = ${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }})" + name: "Bazel CUDA Non-RBE with ${{ format('{0}', 'build_jaxlib=false') }}" with: runner: ${{ matrix.runner }} python: ${{ matrix.python }} - cuda: ${{ matrix.cuda }} + cuda-version: ${{ matrix.cuda-version }} enable-x64: ${{ matrix.enable-x64 }} gcs_download_uri: ${{inputs.gcs_download_uri}} + halt-for-connection: ${{inputs.halt-for-connection}} + build_jaxlib: "false" + build_jax: "false" + jaxlib-version: "head" + write_to_bazel_remote_cache: 1 + run_multiaccelerator_tests: "true" run-pytest-tpu: uses: ./.github/workflows/pytest_tpu.yml strategy: fail-fast: false # don't cancel all jobs on failure matrix: - # Skip Python 3.13 as it fails due to missing TensorFlow wheels (used for - # profiler_test.py, build/collect-profile-requirements.txt) for that version (b/402590302) - python: ["3.10", "3.11", "3.12"] + python: ["3.11", "3.12", "3.13", "3.13-nogil", "3.14", "3.14-nogil"] + tpu-specs: [ + # {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available + {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}, + {type: "v6e-8", cores: "8", runner: "linux-x86-ct6e-180-8tpu"}, + {type: "v7x-8", cores: "4", runner: "linux-x86-tpu7x-224-4tpu"} + ] + libtpu-version-type: ["pypi_latest", "nightly"] + exclude: + # Exclude nightly for releases + - libtpu-version-type: ${{ startsWith(github.ref_name, 'release/') && 'nightly' }} + # Exclude pypi_latest for nightly releases + - libtpu-version-type: ${{ !startsWith(github.ref_name, 'release/') && 'pypi_latest' }} + # Run Python versions in between min and max for v6e-8 + - tpu-specs: + type: "v6e-8" + python: "3.11" + - tpu-specs: + type: "v6e-8" + python: "3.14" + - tpu-specs: + type: "v6e-8" + python: "3.14-nogil" + # Run max Python versions for v5e-8 + - tpu-specs: + type: "v5e-8" + python: "3.11" + - tpu-specs: + type: "v5e-8" + python: "3.12" + - tpu-specs: + type: "v5e-8" + python: "3.13" + - tpu-specs: + type: "v5e-8" + python: "3.13-nogil" + # Run min and max Python versions for v7x-8 + - tpu-specs: + type: "v7x-8" + python: "3.12" + - tpu-specs: + type: "v7x-8" + python: "3.13" + - tpu-specs: + type: "v7x-8" + python: "3.13-nogil" + + name: "Pytest TPU (JAX artifacts version = ${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }})" + with: + runner: ${{ matrix.tpu-specs.runner }} + cores: ${{ matrix.tpu-specs.cores }} + tpu-type: ${{ matrix.tpu-specs.type }} + python: ${{ matrix.python }} + run-full-tpu-test-suite: "1" + libtpu-version-type: ${{ matrix.libtpu-version-type }} + skip-download-jaxlib-from-gcs: ${{inputs.skip-download-jaxlib-and-cuda-plugins-from-gcs}} + gcs_download_uri: ${{inputs.gcs_download_uri}} + halt-for-connection: ${{inputs.halt-for-connection}} + + run-bazel-test-tpu: + uses: ./.github/workflows/bazel_test_tpu.yml + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + python: ["3.11", "3.12", "3.13", "3.13-nogil", "3.14", "3.14-nogil"] tpu-specs: [ # {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}, - {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"} + {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}, + {type: "v7x-8", cores: "4", runner: "linux-x86-tpu7x-224-4tpu"}, ] - libtpu-version-type: ["pypi_latest", "nightly", "oldest_supported_libtpu"] + libtpu-version-type: ["pypi_latest", "nightly"] exclude: + # Exclude nightly for releases - libtpu-version-type: ${{ startsWith(github.ref_name, 'release/') && 'nightly' }} + # Exclude pypi_latest for nightly releases - libtpu-version-type: ${{ !startsWith(github.ref_name, 'release/') && 'pypi_latest' }} - # Run a single Python version for v4-8. + # Run a single Python version for v4-8 - tpu-specs: type: "v4-8" - python: "3.10" + python: "3.11" - tpu-specs: type: "v4-8" - python: "3.11" - # Run min and max Python versions for v5e-8 + python: "3.12" + - tpu-specs: + type: "v4-8" + python: "3.13" + - tpu-specs: + type: "v4-8" + python: "3.13-nogil" + - tpu-specs: + type: "v4-8" + python: "3.14-nogil" + # Run max Python versions for v5e-8 - tpu-specs: type: "v5e-8" python: "3.11" - name: "Pytest TPU (JAX artifacts version = ${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }})" + - tpu-specs: + type: "v5e-8" + python: "3.12" + - tpu-specs: + type: "v5e-8" + python: "3.13" + - tpu-specs: + type: "v5e-8" + python: "3.13-nogil" + # Run min and max Python versions for v7x-8 + - tpu-specs: + type: "v7x-8" + python: "3.12" + - tpu-specs: + type: "v7x-8" + python: "3.13" + - tpu-specs: + type: "v7x-8" + python: "3.13-nogil" + + name: "Bazel tests TPU (JAX artifacts version = ${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }})" with: runner: ${{ matrix.tpu-specs.runner }} cores: ${{ matrix.tpu-specs.cores }} @@ -98,4 +298,95 @@ jobs: python: ${{ matrix.python }} run-full-tpu-test-suite: "1" libtpu-version-type: ${{ matrix.libtpu-version-type }} - gcs_download_uri: ${{inputs.gcs_download_uri}} \ No newline at end of file + skip-download-jaxlib-from-gcs: ${{inputs.skip-download-jaxlib-and-cuda-plugins-from-gcs}} + gcs_download_uri: ${{inputs.gcs_download_uri}} + halt-for-connection: ${{inputs.halt-for-connection}} + build_jaxlib: "false" + build_jax: "false" + jaxlib-version: "head" + clone_main_xla: 0 + + verify-release-wheels-install: + if: ${{ startsWith(github.ref_name, 'release/')}} + defaults: + run: + # Set the shell to bash as GitHub actions runs with /bin/sh by default + shell: bash + runs-on: linux-x86-n4-16 + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + python: ["3.11", "3.13", "3.13-nogil"] + cuda-version: [12, 13] + container: "us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest" + + # Verifies that JAX's release wheels can be installed + name: "Verify release wheels install (Python ${{ matrix.python }})" + + env: + PYTHON: "python${{ matrix.python }}" + + steps: + - name: Download release wheels from GCS + run: | + mkdir -p $(pwd)/dist + final_gcs_download_uri=${{ inputs.gcs_download_uri }} + + # Get the major and minor version of Python. + # E.g if python=3.11, then python_major_minor=311 + # E.g if python=3.13-nogil, then python_major_minor=313t + python_major_minor=${{ matrix.python }} + python_major_minor=$(echo "${python_major_minor//-nogil/t}" | tr -d '.') + python_major_minor="cp${python_major_minor%t}-cp${python_major_minor}-" + + gcloud storage cp -r "${final_gcs_download_uri}"/jax*py3*none*any.whl $(pwd)/dist/ + + jax_wheel=$(ls dist/jax*py3*none*any.whl 2>/dev/null) + echo "JAX_WHEEL=$jax_wheel" >> $GITHUB_ENV + + if [[ "${{ inputs.skip-download-jaxlib-and-cuda-plugins-from-gcs }}" != "1" ]]; then + gcloud storage cp -r "${final_gcs_download_uri}/jaxlib*${python_major_minor}*linux*x86_64*.whl" $(pwd)/dist/ + gcloud storage cp -r "${final_gcs_download_uri}/jax*cuda${{ matrix.cuda-version }}*plugin*${python_major_minor}*linux*x86_64*.whl" $(pwd)/dist/ + gcloud storage cp -r "${final_gcs_download_uri}/jax*cuda${{ matrix.cuda-version }}*pjrt*linux*x86_64*.whl" $(pwd)/dist/ + + jaxlib_wheel=$(ls dist/jaxlib*${python_major_minor}*linux*x86_64*.whl 2>/dev/null) + jax_cuda_plugin_wheel=$(ls dist/jax*cuda${{ matrix.cuda-version }}*plugin*${python_major_minor}*linux*x86_64*.whl 2>/dev/null) + jax_cuda_pjrt_wheel=$(ls dist/jax*cuda${{ matrix.cuda-version }}*pjrt*linux*x86_64*.whl 2>/dev/null) + + echo "JAXLIB_WHEEL=$jaxlib_wheel" >> $GITHUB_ENV + echo "JAX_CUDA_PLUGIN_WHEEL=$jax_cuda_plugin_wheel" >> $GITHUB_ENV + echo "JAX_CUDA_PJRT_WHEEL=$jax_cuda_pjrt_wheel" >> $GITHUB_ENV + fi + - name: Verify JAX CPU packages can be installed + run: | + $PYTHON -m uv venv ~/test_cpu && source ~/test_cpu/bin/activate + if [[ "${{ inputs.skip-download-jaxlib-and-cuda-plugins-from-gcs }}" == "1" ]]; then + uv pip install $JAX_WHEEL + else + uv pip install $JAX_WHEEL $JAXLIB_WHEEL + fi + - name: Verify JAX TPU packages can be installed + run: | + $PYTHON -m uv venv ~/test_tpu && source ~/test_tpu/bin/activate + + if [[ "${{ inputs.skip-download-jaxlib-and-cuda-plugins-from-gcs }}" == "1" ]]; then + uv pip install $JAX_WHEEL[tpu] + else + uv pip install $JAX_WHEEL[tpu] $JAXLIB_WHEEL + fi + - name: Verify JAX CUDA packages can be installed (Nvidia Pip Packages) + run: | + $PYTHON -m uv venv ~/test_cuda_pip && source ~/test_cuda_pip/bin/activate + if [[ "${{ inputs.skip-download-jaxlib-and-cuda-plugins-from-gcs }}" == "1" ]]; then + uv pip install $JAX_WHEEL[cuda${{ matrix.cuda-version }}] + else + uv pip install $JAX_WHEEL[cuda${{ matrix.cuda-version }}] $JAXLIB_WHEEL $JAX_CUDA_PJRT_WHEEL $JAX_CUDA_PLUGIN_WHEEL[with-cuda] $EXTRA_INDEX + fi + - name: Verify JAX CUDA packages can be installed (CUDA local) + run: | + $PYTHON -m uv venv ~/test_cuda_local && source ~/test_cuda_local/bin/activate + if [[ "${{ inputs.skip-download-jaxlib-and-cuda-plugins-from-gcs }}" == "1" ]]; then + uv pip install $JAX_WHEEL[cuda${{ matrix.cuda-version }}-local] + else + uv pip install $JAX_WHEEL $JAXLIB_WHEEL $JAX_CUDA_PJRT_WHEEL $JAX_CUDA_PLUGIN_WHEEL + fi diff --git a/.github/workflows/wheel_win_x64.yml b/.github/workflows/wheel_win_x64.yml deleted file mode 100644 index 444bc83f2889..000000000000 --- a/.github/workflows/wheel_win_x64.yml +++ /dev/null @@ -1,64 +0,0 @@ -name: Wheel build - Windows CPU x86_64 -on: - workflow_dispatch: # allows triggering the workflow run manually - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} - cancel-in-progress: true - -env: - DISTUTILS_USE_SDK: 1 - MSSdk: 1 - -jobs: - win-wheels: - strategy: - fail-fast: false # Don't stop all wheel builds if one has a test failure. - matrix: - os: [windows-2019-32core] - arch: [AMD64] - pyver: ['3.10', '3.11', '3.12', '3.13'] - name: ${{ matrix.os }} ${{ matrix.pyver }} jaxlib wheel build - runs-on: ${{ matrix.os }} - - steps: - - name: Install LLVM/Clang - run: choco install llvm --version=18.1.4 --yes --no-progress --allow-downgrade - - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 - with: - python-version: ${{ matrix.pyver }} - cache: 'pip' - - - name: Build wheels - env: - BAZEL_VC: "C:\\Program Files (x86)\\Microsoft Visual Studio\\2019\\Enterprise\\VC" - JAXLIB_RELEASE: true - run: | - python -m pip install uv~=0.5.30 - python -m uv pip install -r build/test-requirements.txt \ - --upgrade numpy==2.0.0 scipy==1.13.1 - "C:\\msys64\\;C:\\msys64\\usr\\bin\\;" >> $env:GITHUB_PATH - python.exe build\build.py build --wheels=jaxlib ` - --bazel_options=--color=yes ` - --bazel_options=--config=win_clang ` - --verbose - - - uses: actions/upload-artifact@6f51ac03b9356f520e9adb1b1b7802705f340c2b # v4.5.0 - with: - name: wheels-${{ matrix.os }}-${{ matrix.pyver }} - path: ${{ github.workspace }}\dist\*.whl - retention-days: 5 - - - name: Run tests - env: - JAX_ENABLE_CHECKS: true - JAX_SKIP_SLOW_TESTS: true - PY_COLORS: 1 - run: | - python -m uv pip install --find-links ${{ github.workspace }}\dist jaxlib \ - -e ${{ github.workspace }} - echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS" - pytest -n auto --tb=short tests examples diff --git a/.github/workflows/windows_ci.yml b/.github/workflows/windows_ci.yml deleted file mode 100644 index fc2b63396f56..000000000000 --- a/.github/workflows/windows_ci.yml +++ /dev/null @@ -1,73 +0,0 @@ -name: CI - Windows CPU -on: - schedule: - - cron: "0 12 * * *" # Daily at 12:00 UTC - workflow_dispatch: # allows triggering the workflow run manually - pull_request: - types: [ labeled ] # allow force-windows-run label - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} - cancel-in-progress: true - -env: - DISTUTILS_USE_SDK: 1 - MSSdk: 1 - -jobs: - win-wheels: - if: ${{ (github.event.action != 'labeled') || (github.event.label.name == 'windows:force-run')}} - strategy: - fail-fast: true - matrix: - os: [windows-2019-32core] - arch: [AMD64] - pyver: ['3.10'] - name: Windows CI build - runs-on: ${{ matrix.os }} - - steps: - - - name: Install LLVM/Clang - run: choco install llvm --version=18.1.4 --yes --no-progress --allow-downgrade - - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - path: jax - - - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 - with: - python-version: ${{ matrix.pyver }} - cache: 'pip' - - - name: Build wheels - env: - BAZEL_VC: "C:\\Program Files (x86)\\Microsoft Visual Studio\\2019\\Enterprise\\VC" - JAXLIB_NIGHTLY: true # Tag the wheels as dev versions - run: | - cd jax - python -m pip install uv~=0.5.30 - python -m uv pip install -r build/test-requirements.txt --upgrade numpy==2.0.0 scipy==1.13.1 - "C:\\msys64\\;C:\\msys64\\usr\\bin\\;" >> $env:GITHUB_PATH - python.exe build\build.py build --wheels=jaxlib ` - --bazel_options=--color=yes ` - --bazel_options=--config=win_clang ` - --verbose - - - uses: actions/upload-artifact@6f51ac03b9356f520e9adb1b1b7802705f340c2b # v4.5.0 - with: - name: wheels - path: ${{ github.workspace }}\jax\dist\*.whl - retention-days: 5 - - - name: Run tests - env: - JAX_ENABLE_CHECKS: true - JAX_SKIP_SLOW_TESTS: true - PY_COLORS: 1 - run: | - cd jax - python -m uv pip install --pre --find-links ${{ github.workspace }}\jax\dist jaxlib ` - -e ${{ github.workspace }}\jax - echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS" - pytest -n auto --tb=short tests examples diff --git a/.gitignore b/.gitignore index 83f1780df946..fe90ba9297af 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ *.whl /build/lib /build/bazel* +/build/__editable__.jax-* /dist/ .ipynb_checkpoints /bazel-* @@ -29,3 +30,7 @@ jax.iml /include/ /lib/ /share/ + +/compile_commands.json +/strace.txt +/external diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 27ccc6d831f3..6e619a81854e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,12 +9,17 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: 2c9f875913ee60ca25ce70243dc24d5b6415598c # frozen: v4.6.0 + rev: 3e8a8703264a2f4a69428a0aa4dcb512790b2c8c # frozen: v6.0.0 hooks: - id: check-ast - id: check-merge-conflict - id: check-toml - id: check-yaml + exclude: | + (?x)^( + examples/k8s/svc-acct\.yaml | + ci/k8s/indexed-job\.yaml + )$ - id: end-of-file-fixer # only include python files files: \.py$ @@ -26,17 +31,17 @@ repos: files: \.py$ - repo: https://github.com/astral-sh/ruff-pre-commit - rev: 8983acb92ee4b01924893632cf90af926fa608f0 # frozen: v0.7.0 + rev: a113f03edeabb71305f025e6e14bd2cd68660e29 # frozen: v0.13.1 hooks: - id: ruff - repo: https://github.com/pre-commit/mirrors-mypy - rev: 'bbc3dc1f890007061f18f17e2334f216ea9e5df7' # frozen: v1.14.1 + rev: 9f70dc58c23dfcca1b97af99eaeee3140a807c7e # frozen: v1.18.2 hooks: - id: mypy files: (jax/|tests/typing_test\.py) - exclude: jax/_src/basearray.py|jax/numpy/__init__.py # Use pyi instead - additional_dependencies: [types-requests==2.31.0, jaxlib, numpy>=2.2.0] + exclude: jax/_src/basearray.py|jax/numpy/__init__.py|jax/nn/__init__.py|jaxlib/_jax/.* # Use pyi instead + additional_dependencies: [types-requests==2.31.0, numpy~=2.3.0, scipy-stubs] args: [--config=pyproject.toml] - repo: https://github.com/mwouts/jupytext diff --git a/.readthedocs.yml b/.readthedocs.yml index 6f807aa82377..0ac20301cee2 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -6,9 +6,23 @@ version: 2 build: - os: "ubuntu-22.04" + os: "ubuntu-24.04" tools: - python: "3.10" + python: "3.12" + jobs: + post_checkout: + # Skip building PRs unless tagged with the "documentation" label. + - | + [ "${READTHEDOCS_VERSION_TYPE}" != "external" ] && echo "Building latest" && exit 0 + (curl -sL https://api.github.com/repos/jax-ml/jax/issues/${READTHEDOCS_VERSION}/labels | grep -q "https://api.github.com/repos/jax-ml/jax/labels/documentation") && echo "Building PR with label" || exit 183 + create_environment: + - asdf plugin add uv + - asdf install uv latest + - asdf global uv latest + - uv venv $READTHEDOCS_VIRTUALENV_PATH + - UV_PROJECT_ENVIRONMENT=$READTHEDOCS_VIRTUALENV_PATH uv pip install -r docs/requirements.txt + install: + - "true" # skip # Build documentation in the docs/ directory with Sphinx sphinx: @@ -18,8 +32,3 @@ sphinx: # Optionally build your docs in additional formats such as PDF and ePub formats: - htmlzip - -# Optionally set the version of Python and requirements required to build your docs -python: - install: - - requirements: docs/requirements.txt diff --git a/BUILD.bazel b/BUILD.bazel index 33cbefd29f0b..bf0c17629f0d 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -12,38 +12,78 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("@xla//third_party/py:python_wheel.bzl", "collect_data_files", "transitive_py_deps") +load("@rules_python//python:py_binary.bzl", "py_binary") +load( + "@xla//third_party/py:py_import.bzl", + "py_import", +) load( "//jaxlib:jax.bzl", + "jax_source_package", "jax_wheel", + "pytype_test", + "wheel_sources", ) -collect_data_files( - name = "transitive_py_data", - deps = ["//jax"], -) - -transitive_py_deps( - name = "transitive_py_deps", - deps = [ +wheel_sources( + name = "jax_sources", + data_srcs = ["//jax"], + py_srcs = [ "//jax", - "//jax:compilation_cache", - "//jax:experimental", - "//jax:experimental_colocated_python", - "//jax:experimental_sparse", - "//jax:lax_reference", - "//jax:pallas_experimental_gpu_ops", - "//jax:pallas_gpu_ops", - "//jax:pallas_mosaic_gpu", - "//jax:pallas_tpu_ops", - "//jax:pallas_triton", - "//jax:source_mapper", - "//jax:sparse_test_util", - "//jax:test_util", + "//jax/example_libraries:example_libraries", + "//jax/example_libraries:optimizers", + "//jax/example_libraries:stax", + "//jax/experimental:buffer_callback", + "//jax/experimental:checkify", + "//jax/experimental:colocated_python", + "//jax/experimental:compilation_cache", + "//jax/experimental:compute_on", + "//jax/experimental:custom_dce", + "//jax/experimental:custom_partitioning", + "//jax/experimental:fused", + "//jax/experimental:hijax", + "//jax/experimental:jet", + "//jax/experimental:layout", + "//jax/experimental:mesh_utils", + "//jax/experimental:multihost_utils", + "//jax/experimental:ode", + "//jax/experimental:pallas_experimental_gpu_ops", + "//jax/experimental:pallas_fuser", + "//jax/experimental:pallas_gpu_ops", + "//jax/experimental:pallas_mosaic_gpu", + "//jax/experimental:pallas_tpu_ops", + "//jax/experimental:pallas_triton", + "//jax/experimental:pjit", + "//jax/experimental:profiler", + "//jax/experimental:random", + "//jax/experimental:rnn", + "//jax/experimental:scheduling_groups", + "//jax/experimental:serialize_executable", + "//jax/experimental:shard_alike", + "//jax/experimental:shard_map", + "//jax/experimental:source_mapper", + "//jax/experimental:sparse_test_util", + "//jax/experimental:sparse", + "//jax/experimental:topologies", + "//jax/experimental:transfer", + "//jax/experimental:xla_metadata", + "//jax/experimental", + "//jax/_src:lax_reference", + "//jax/_src:internal_export_back_compat_test_util", + "//jax/_src:internal_export_back_compat_test_data", + "//jax/_src:internal_test_harnesses", + "//jax/_src:internal_test_util", + "//jax/_src:test_multiprocess", + "//jax/_src:test_util", "//jax/_src/lib", + "//jax/_src/pallas/fuser", "//jax/_src/pallas/mosaic_gpu", + "//jax/_src/pallas/mosaic_gpu/interpret:interpret_pallas_call", "//jax/experimental/array_serialization:serialization", + "//jax/experimental/array_serialization:pytree_serialization", "//jax/experimental/jax2tf", + "//jax/experimental/mosaic/gpu/examples:flash_attention", + "//jax/experimental/mosaic/gpu/examples:matmul", "//jax/extend", "//jax/extend:ifrt_programs", "//jax/extend/mlir", @@ -52,6 +92,14 @@ transitive_py_deps( "//jax/tools:jax_to_ir", "//jax/tools:pgo_nsys_converter", ], + static_srcs = [ + "//jax:py.typed", + "AUTHORS", + "LICENSE", + "README.md", + "pyproject.toml", + "setup.py", + ], ) py_binary( @@ -59,26 +107,88 @@ py_binary( srcs = ["build_wheel.py"], deps = [ "//jaxlib/tools:build_utils", - "@pypi_build//:pkg", - "@pypi_setuptools//:pkg", - "@pypi_wheel//:pkg", + "@pypi//build", + "@pypi//setuptools", + "@pypi//wheel", ], ) jax_wheel( name = "jax_wheel", - build_wheel_only = False, platform_independent = True, - source_files = [ - ":transitive_py_data", - ":transitive_py_deps", - "//jax:py.typed", - "AUTHORS", - "LICENSE", - "README.md", - "pyproject.toml", - "setup.py", - ], + source_files = [":jax_sources"], wheel_binary = ":build_wheel", wheel_name = "jax", ) + +jax_wheel( + name = "jax_wheel_editable", + editable = True, + platform_independent = True, + source_files = [":jax_sources"], + wheel_binary = ":build_wheel", + wheel_name = "jax", +) + +jax_source_package( + name = "jax_source_package", + source_files = [":jax_sources"], + source_package_binary = ":build_wheel", + source_package_name = "jax", +) + +genrule( + name = "wheel_additives", + testonly = True, + srcs = [ + "//jax/_src:internal_test_harnesses", + "//jax/_src:internal_test_util", + "//jax/_src:internal_export_back_compat_test_util", + "//jax/_src:internal_export_back_compat_test_data", + "//jax/experimental/jax2tf/tests:jax2tf_tests", + "//jax/experimental/jax2tf/tests:tf_test_util", + "//jax/experimental/mosaic/gpu/examples:flash_attention.py", + "//jax/experimental/mosaic/gpu/examples:gpu_examples", + "//jax/experimental/mosaic/gpu/examples:matmul.py", + "//jax/experimental/mosaic/gpu/examples:matmul_blackwell", + "//jax/_src:test_multiprocess", + "//jax/_src/pallas:pallas_test_util", + "//jax/experimental:mosaic_gpu_test_util", + ], + outs = ["wheel_additives.zip"], + cmd = "$(location @bazel_tools//tools/zip:zipper) c $@ $(SRCS)", + tools = ["@bazel_tools//tools/zip:zipper"], + visibility = ["//jax:internal"], +) + +py_import( + name = "jax_py_import", + testonly = True, + wheel = ":jax_wheel", + zip_deps = [":wheel_additives"], +) + +# This target is used to add more sources to the jax wheel. +# This is needed for the tests that depend on jax and use modules that are not part of +# the jax wheel, but share the same package paths as the modules in the jax wheel. +py_import( + name = "jax_wheel_with_internal_test_util", + testonly = True, + wheel = "@pypi_jax//:whl", + zip_deps = [":wheel_additives"], +) + +pytype_test( + name = "jax_wheel_size_test", + srcs = ["//jaxlib/tools:wheel_size_test.py"], + args = [ + "--wheel-path=$(location :jax_wheel)", + "--max-size-mib=5", + ], + data = [":jax_wheel"], + main = "wheel_size_test.py", + tags = [ + "manual", + "notap", + ], +) diff --git a/CHANGELOG.md b/CHANGELOG.md index c30877ecae14..ba3d06d77e2e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ # Change log -Best viewed [here](https://jax.readthedocs.io/en/latest/changelog.html). +Best viewed [here](https://docs.jax.dev/en/latest/changelog.html). For the changes specific to the experimental Pallas APIs, see {ref}`pallas-changelog`. @@ -16,6 +16,434 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. ## Unreleased +## JAX 0.9.0 (TODO) + +* New features: + + * Added {func}`jax.thread_guard`, a context manager that detects when devices + are used by multiple threads in multi-controller JAX. + +* Bug fixes: + * Fixed a workspace size calculation error for pivoted QR (`magma_zgeqp3_gpu`) + in MAGMA 2.9.0 when using `use_magma=True` and `pivoting=True`. + ({jax-issue}`#34145`). + +* Deprecations: + * The flag `jax_collectives_common_channel_id` was removed. + * The `jax_pmap_no_rank_reduction` config state has been removed. The + no-rank-reduction behavior is now the only supported behavior: a + `jax.pmap`ped function `f` sees inputs of the same rank as the input to + `jax.pmap(f)`. For example, if `jax.pmap(f)` receives shape `(8, 128)` on + 8 devices, then `f` receives shape `(1, 128)`. + * Setting the `jax_pmap_shmap_merge` config state is deprecated in JAX v0.9.0 + and will be removed in JAX v0.10.0. + * {func}`jax.numpy.fix` is deprecated, anticipating the deprecation of + {func}`numpy.fix` in NumPy v2.5.0. {func}`jax.numpy.trunc` is a drop-in + replacement. + +* Changes: + * {func}`jax.export` now supports explicit sharding. This required a new + export serialization format version that includes the NamedSharding, + including the abstract mesh, and the partition spec. As part of this + change we have added a restriction in the use of exported modules: when + calling them the abstract mesh must match the one used at export time, + including the axis names. Previously, only the number of the devices + mattered. + +## JAX 0.8.2 (December 18, 2025) + +* Deprecations + * `jax.lax.pvary` has been deprecated. + Please use `jax.lax.pcast(..., to='varying')` as the replacement. + * Complex arguments passed to {func}`jax.numpy.arange` now result in a + deprecation warning, because the output is poorly-defined. + * From {mod}`jax.core` a number of symbols are newly deprecated including: + `call_impl`, `get_aval`, `mapped_aval`, `subjaxprs`, `set_current_trace`, + `take_current_trace`, `traverse_jaxpr_params`, `unmapped_aval`, + `AbstractToken`, and `TraceTag`. + * All symbols in {mod}`jax.interpreters.pxla` are deprecated. These are + primarily JAX internal APIs, and users should not rely on them. + +* Changes: + * jax's `Tracer` no longer inherits from `jax.Array` at runtime. However, + `jax.Array` now uses a custom metaclass such `isinstance(x, Array)` is true + if an object `x` represents a traced `Array`. Only some `Tracer`s represent + `Array`s, so it is not correct for `Tracer` to inherit from `Array`. + + For the moment, during Python type checking, we continue to declare `Tracer` + as a subclass of `Array`, however we expect to remove this in a future + release. + * `jax.experimental.si_vjp` has been deleted. + `jax.vjp` subsumes it's functionality. + +## JAX 0.8.1 (November 18, 2025) + +* New features: + + * {func}`jax.jit` now supports the decorator factory pattern; i.e instead of + writing + ``` + @functools.partial(jax.jit, static_argnames=['n']) + def f(x, n): + ... + ``` + you may write + ``` + @jax.jit(static_argnames=['n']) + def f(x, n): + ... + ``` + +* Changes: + + * {func}`jax.lax.linalg.eigh` now accepts an `implementation` argument to + select between QR (CPU/GPU), Jacobi (GPU/TPU), and QDWH (TPU) + implementations. The `EighImplementation` enum is publicly exported from + {mod}`jax.lax.linalg`. + + * {func}`jax.lax.linalg.svd` now implements an `algorithm` that uses the polar + decomposition on CUDA GPUs. This is also an alias for the existing algorithm + on TPUs. + +* Bug fixes: + + * Fixed a bug introduced in JAX 0.7.2 where eigh failed for large matrices on + GPU (({jax-issue}`#33062`). + +* Deprecations: + * `jax.sharding.PmapSharding` is now deprecated. Please use + `jax.NamedSharding` instead. + * `jx.device_put_replicated` is now deprecated. Please use `jax.device_put` + with the appropriate sharding instead. + * `jax.device_put_sharded` is now deprecated. Please use `jax.device_put` with + the appropriate sharding instead. + * Default `axis_types` of `jax.make_mesh` will change in JAX v0.9.0 to return + `jax.sharding.AxisType.Explicit`. Leaving axis_types unspecified will raise a + `DeprecationWarning`. + * {mod}`jax.cloud_tpu_init` and its contents were deprecated. There is no reason for a user to import or use the contents of this module; JAX handles this for you automatically if needed. + +## JAX 0.8.0 (October 15, 2025) + +* Breaking changes: + + * JAX is changing the default `jax.pmap` implementation to one implemented in + terms of `jax.jit` and `jax.shard_map`. `jax.pmap` is in maintenance mode + and we encourage all new code to use `jax.shard_map` directly. See the + [migration guide](https://docs.jax.dev/en/latest/migrate_pmap.html) for + more information. + * The `auto=` parameter of `jax.experimental.shard_map.shard_map` has been + removed. This means that `jax.experimental.shard_map.shard_map` no longer + supports nesting. If you want to nest shard_map calls, please use + `jax.shard_map`. + * JAX no longer allows passing objects that support `__jax_array__` directly + to, e.g. `jit`-ed functions. Call `jax.numpy.asarray` on them first. + * {func}`jax.numpy.cov` is now returns NaN for empty arrays ({jax-issue}`#32305`), + and matches NumPy 2.2 behavior for single-row design matrices ({jax-issue}`#32308`). + * JAX no longer accepts `Array` values where a `dtype` value is expected. Call + `.dtype` on these values first. + * The deprecated function {func}`jax.interpreters.mlir.custom_call` was + removed. + * The `jax.util`, `jax.extend.ffi`, and `jax.experimental.host_callback` + modules have been removed. All public APIs within these modules were + deprecated and removed in v0.7.0 or earlier. + * The deprecated symbol {obj}`jax.custom_derivatives.custom_jvp_call_jaxpr_p` + was removed. + * `jax.experimental.multihost_utils.process_allgather` raises an error when + the input is a jax.Array and not fully-addressable and `tiled=False`. To fix + this, pass `tiled=True` to your `process_allgather` invocation. + * from {mod}`jax.experimental.compilation_cache`, the deprecated symbols + `is_initialized` and `initialize_cache` were removed. + * The deprecated function {func}`jax.interpreters.xla.canonicalize_dtype` + was removed. + * {mod}`jaxlib.hlo_helpers` has been removed. Use {mod}`jax.ffi` instead. + * The option `jax_cpu_enable_gloo_collectives` has been removed. Use + `jax_cpu_collectives_implementation` instead. + * The previously-deprecated `interpolation` argument to + {func}`jax.numpy.percentile` and {func}`jax.numpy.quantile` has been + removed; use `method` instead. + * The JAX-internal `for_loop` primitive was removed. Its functionality, + reading from and writing to refs in the loop body, is now directly + supported by {func}`jax.lax.fori_loop`. If you need help updating your + code, please file a bug. + * {func}`jax.numpy.trimzeros` now errors for non-1D input. + * The `where` argument to {func}`jax.numpy.sum` and other reductions is now + required to be boolean. Non-boolean values have resulted in a + `DeprecationWarning` since JAX v0.5.0. + * The deprecated functions in {mod} `jax.dlpack`, {mod} `jax.errors`, {mod} + `jax.lib.xla_bridge`, {mod} `jax.lib.xla_client`, and {mod} + `jax.lib.xla_extension` were removed. + * `jax.interpreters.mlir.dense_bool_array` was removed. Use MLIR APIs to + construct attributes instead. + +* Changes + * {func}`jax.numpy.linalg.eig` now returns a namedtuple (with attributes + `eigenvalues` and `eigenvectors`) instead of a plain tuple. + * {func}`jax.grad` and {func}`jax.vjp` will now round always primals to + `float32` if `float64` mode is not enabled. + * {func}`jax.dlpack.from_dlpack` now accepts arrays with non-default layouts, + for example, transposed. + * The default nonsymmetric eigendecomposition on NVIDIA GPUs now uses + cusolver. The magma and LAPACK implementations are still available via the + new `implementation` argument to {func}`jax.lax.linalg.eig` + ({jax-issue}`#27265`). The `use_magma` argument is now deprecated in favor + of `implementation`. + * {func}`jax.numpy.trim_zeros` now follows NumPy 2.2 in supporting + multi-dimensional inputs. + +* Deprecations + * {func}`jax.experimental.enable_x64` and {func}`jax.experimental.disable_x64` + are deprecated in favor of the new non-experimental context manager + {func}`jax.enable_x64`. + * {func}`jax.experimental.shard_map.shard_map` is deprecated; going forward use + {func}`jax.shard_map`. + * {func}`jax.experimental.pjit.pjit` is deprecated; going forward use + {func}`jax.jit`. + +## JAX 0.7.2 (September 16, 2025) + +* Breaking changes: + + * {func}`jax.dlpack.from_dlpack` no longer accepts a DLPack capsule. This + behavior was deprecated and is now removed. The function must be called + with an array implementing `__dlpack__` and `__dlpack_device__`. + +* Changes + * The minimum supported NumPy version is now 2.0. Since SciPy 1.13 is required + for NumPy 2.0 support, the minimum supported SciPy version is now 1.13. + + * JAX now represents constants in its internal jaxpr representation as a + `TypedNdArray`, which is a private JAX type that duck types as a + `numpy.ndarray`. This type may be exposed to users via `custom_jvp` rules, + for example, and may break code that uses `isinstance(x, np.ndarray)`. If + this breaks your code, you may convert these arrays to classic NumPy arrays + using `np.asarray(x)`. + +* Bug fixes + * `arr.view(dtype=None)` now returns the array unchanged, matching NumPy's + semantics. Previously it returned the array with a float dtype. + * `jax.random.randint` now produces a less-biased distribution for 8-bit and + 16-bit integer types ({jax-issue}`#27742`). To restore the previous biased + behavior, you may temporarily set the `jax_safer_randint` configuration to + `False`, but note this is a temporary config that will be removed in a + future release. + +* Deprecations: + * The parameters `enable_xla` and `native_serialization` for `jax2tf.convert` + are deprecated and will be removed in a future version of JAX. These were + used for jax2tf with non-native serialization, which has been now removed. + * Setting the config state `jax_pmap_no_rank_reduction` to `False` is + deprecated. By default, `jax_pmap_no_rank_reduction` will be set to `True` + and `jax.pmap` shards will not have their rank reduced, keeping the same + rank as their enclosing array. + +## JAX 0.7.1 (August 20, 2025) + +* New features + * JAX now ships Python 3.14 and 3.14t wheels. + * JAX now ships Python 3.13t and 3.14t wheels on Mac. Previously we only + offered free-threading builds on Linux. + +* Changes + * Exposed `jax.set_mesh` which acts as a global setter and a context manager. + Removed `jax.sharding.use_mesh` in favor of `jax.set_mesh`. + * JAX is now built using CUDA 12.9. All versions of CUDA 12.1 or newer remain + supported. + * {func}`jax.lax.dot` now implements the general dot product via the optional + ``dimension_numbers`` argument. + +* Deprecations: + + * {func}`jax.lax.zeros_like_array` is deprecated. Please use + {func}`jax.numpy.zeros_like` instead. + * Attempting to import {mod}`jax.experimental.host_callback` now results in + a `DeprecationWarning`, and will result in an `ImportError` starting in JAX + v0.8.0. Its APIs have raised `NotImplementedError` since JAX version 0.4.35. + * In {func}`jax.lax.dot`, passing the ``precision`` and ``preferred_element_type`` + arguments by position is deprecated. Pass them by explicit keyword instead. + * Several dozen internal APIs have been deprecated from {mod}`jax.interpreters.ad`, + {mod}`jax.interpreters.batching`, and {mod}`jax.interpreters.partial_eval`; they + are used rarely if ever outside JAX itself, and most are deprecated without any + public replacement. + + +## JAX 0.7.0 (July 22, 2025) + +* New features: + * Added `jax.P` which is an alias for `jax.sharding.PartitionSpec`. + * Added {func}`jax.tree.reduce_associative`. + * The {attr}`jax.numpy.ndarray.at` indexing methods now support a `wrap_negative_indices` + argument, which defaults to `True` to match the current behavior ({jax-issue}`#29434`). + +* Breaking changes: + * JAX is migrating from GSPMD to Shardy by default. See the + [migration guide](https://docs.jax.dev/en/latest/shardy_jax_migration.html) + for more information. + * JAX autodiff is switching to using direct linearization by default (instead of + implementing linearization via JVP and partial eval). + See [migration guide](https://docs.jax.dev/en/latest/direct_linearize_migration.html) + for more information. + * `jax.stages.OutInfo` has been replaced with `jax.ShapeDtypeStruct`. + * {func}`jax.jit` now requires `fun` to be passed by position, and additional + arguments to be passed by keyword. Doing otherwise will result in an error + starting in v0.7.x. This raised a DeprecationWarning in v0.6.x. + * The minimum Python version is now 3.11. 3.11 will remain the minimum + supported version until July 2026. + * Layout API renames: + * `Layout`, `.layout`, `.input_layouts` and `.output_layouts` have been + renamed to `Format`, `.format`, `.input_formats` and `.output_formats` + * `DeviceLocalLayout`, `.device_local_layout` have been renamed to `Layout` + and `.layout` + * `jax.experimental.shard` module has been deleted and all the APIs have been + moved to the `jax.sharding` endpoint. So use `jax.sharding.reshard`, + `jax.sharding.auto_axes` and `jax.sharding.explicit_axes` instead of their + experimental endpoints. + * `lax.infeed` and `lax.outfeed` were removed, after being deprecated in + JAX 0.6. The `transfer_to_infeed` and `transfer_from_outfeed` methods were + also removed the `Device` objects. + * The `jax.extend.core.primitives.pjit_p` primitive has been renamed to + `jit_p`, and its `name` attribute has changed from `"pjit"` to `"jit"`. + This affects the string representations of jaxprs. The same primitive is no + longer exported from the `jax.experimental.pjit` module. + * The (undocumented) function `jax.extend.backend.add_clear_backends_callback` + has been removed. Users should use `jax.extend.backend.register_backend_cache` + instead. + * `out_sharding` arg added to `x.at[y].set` and `x.at[y].add`. Previous + behavior propagating operand sharding removed. Please use + `x.at[y].set/add(z, out_sharding=jax.typeof(x).sharding)` to retain previous + behavior if scatter op requires collectives. + +* Deprecations: + * {obj}`jax.dlpack.SUPPORTED_DTYPES` is deprecated; please use the new + {func}`jax.dlpack.is_supported_dtype` function. + * {func}`jax.scipy.special.sph_harm` has been deprecated following a similar + deprecation in SciPy; use {func}`jax.scipy.special.sph_harm_y` instead. + * From {mod}`jax.interpreters.xla`, the previously deprecated symbols + `abstractify` and `pytype_aval_mappings` have been removed. + * {func}`jax.interpreters.xla.canonicalize_dtype` is deprecated. For + canonicalizing dtypes, prefer {func}`jax.dtypes.canonicalize_dtype`. + For checking whether an object is a valid jax input, prefer + {func}`jax.core.valid_jaxtype`. + * From {mod}`jax.core`, the previously deprecated symbols `AxisName`, + `ConcretizationTypeError`, `axis_frame`, `call_p`, `closed_call_p`, + `get_type`, `trace_state_clean`, `typematch`, and `typecheck` have been + removed. + * From {mod}`jax.lib.xla_client`, the previously deprecated symbols + `DeviceAssignment`, `get_topology_for_devices`, and `mlir_api_version` + have been removed. + * `jax.extend.ffi` was removed after being deprecated in v0.5.0. + Use {mod}`jax.ffi` instead. + * {func}`jax.lib.xla_bridge.get_compile_options` is deprecated, and replaced by + {func}`jax.extend.backend.get_compile_options`. + +## JAX 0.6.2 (June 17, 2025) + +* New features: + * Added {func}`jax.tree.broadcast` which implements a pytree prefix broadcasting helper. + +* Changes + * The minimum NumPy version is 1.26 and the minimum SciPy version is 1.12. + +## JAX 0.6.1 (May 21, 2025) + +* New features: + * Added {func}`jax.lax.axis_size` which returns the size of the mapped axis + given its name. + +* Changes + * Additional checking for the versions of CUDA package dependencies was + re-enabled, having been accidentally disabled in a previous release. + * JAX nightly packages are now published to artifact registry. To install + these packages, see the [JAX installation guide](https://docs.jax.dev/en/latest/installation.html#jax-nightly-installation). + * `jax.sharding.PartitionSpec` no longer inherits from a tuple. + * `jax.ShapeDtypeStruct` is immutable now. Please use `.update` method to + update your `ShapeDtypeStruct` instead of doing in-place updates. + +* Deprecations + * `jax.custom_derivatives.custom_jvp_call_jaxpr_p` is deprecated, and will be + removed in JAX v0.7.0. + +## JAX 0.6.0 (April 16, 2025) + +* Breaking changes + + * {func}`jax.numpy.array` no longer accepts `None`. This behavior was + deprecated since November 2023 and is now removed. + * Removed the `config.jax_data_dependent_tracing_fallback` config option, + which was added temporarily in v0.4.36 to allow users to opt out of the + new "stackless" tracing machinery. + * Removed the `config.jax_eager_pmap` config option. + * Disallow the calling of `lower` and `trace` AOT APIs on the result + of `jax.jit` if there have been subsequent wrappers applied. + Previously this worked, but silently ignored the wrappers. + The workaround is to apply `jax.jit` last among the wrappers, + and similarly for `jax.pmap`. + See {jax-issue}`#27873`. + * The `cuda12_pip` extra for `jax` has been removed; use `pip install jax[cuda12]` + instead. + +* Changes + * The minimum CuDNN version is v9.8. + * JAX is now built using CUDA 12.8. All versions of CUDA 12.1 or newer remain + supported. + * JAX package extras are now updated to use dash instead of underscore to + align with PEP 685. For instance, if you were previously using `pip install jax[cuda12_local]` + to install JAX, run `pip install jax[cuda12-local]` instead. + * {func}`jax.jit` now requires `fun` to be passed by position, and additional + arguments to be passed by keyword. Doing otherwise will result in a + DeprecationWarning in v0.6.X, and an error in starting in v0.7.X. + +* Deprecations + + * {func}`jax.tree_util.build_tree` is deprecated. Use {func}`jax.tree.unflatten` + instead. + * Implemented host callback handlers for CPU and GPU devices using XLA's FFI + and removed existing CPU/GPU handlers using XLA's custom call. + * All APIs in `jax.lib.xla_extension` are now deprecated. + * `jax.interpreters.mlir.hlo` and `jax.interpreters.mlir.func_dialect`, + which were accidental exports, have been removed. If needed, they are + available from `jax.extend.mlir`. + * `jax.interpreters.mlir.custom_call` is deprecated. The APIs provided by + {mod}`jax.ffi` should be used instead. + * The deprecated use of {func}`jax.ffi.ffi_call` with inline arguments is no + longer supported. {func}`~jax.ffi.ffi_call` now unconditionally returns a + callable. + * The following exports in `jax.lib.xla_client` are deprecated: + `get_topology_for_devices`, `heap_profile`, `mlir_api_version`, `Client`, + `CompileOptions`, `DeviceAssignment`, `Frame`, `HloSharding`, `OpSharding`, + `Traceback`. + * The following internal APIs in `jax.util` are deprecated: + `HashableFunction`, `as_hashable_function`, `cache`, `safe_map`, `safe_zip`, + `split_dict`, `split_list`, `split_list_checked`, `split_merge`, `subvals`, + `toposort`, `unzip2`, `wrap_name`, and `wraps`. + * `jax.dlpack.to_dlpack` has been deprecated. You can usually pass a JAX + `Array` directly to the `from_dlpack` function of another framework. If you + need the functionality of `to_dlpack`, use the `__dlpack__` attribute of an + array. + * `jax.lax.infeed`, `jax.lax.infeed_p`, `jax.lax.outfeed`, and + `jax.lax.outfeed_p` are deprecated and will be removed in JAX v0.7.0. + * Several previously-deprecated APIs have been removed, including: + * From `jax.lib.xla_client`: `ArrayImpl`, `FftType`, `PaddingType`, + `PrimitiveType`, `XlaBuilder`, `dtype_to_etype`, + `ops`, `register_custom_call_target`, `shape_from_pyval`, `Shape`, + `XlaComputation`. + * From `jax.lib.xla_extension`: `ArrayImpl`, `XlaRuntimeError`. + * From `jax`: `jax.treedef_is_leaf`, `jax.tree_flatten`, `jax.tree_map`, + `jax.tree_leaves`, `jax.tree_structure`, `jax.tree_transpose`, and + `jax.tree_unflatten`. Replacements can be found in {mod}`jax.tree` or + {mod}`jax.tree_util`. + * From `jax.core`: `AxisSize`, `ClosedJaxpr`, `EvalTrace`, `InDBIdx`, `InputType`, + `Jaxpr`, `JaxprEqn`, `Literal`, `MapPrimitive`, `OpaqueTraceState`, `OutDBIdx`, + `Primitive`, `Token`, `TRACER_LEAK_DEBUGGER_WARNING`, `Var`, `concrete_aval`, + `dedup_referents`, `escaped_tracer_error`, `extend_axis_env_nd`, `full_lower`, `get_referent`, `jaxpr_as_fun`, `join_effects`, `lattice_join`, + `leaked_tracer_error`, `maybe_find_leaked_tracers`, `raise_to_shaped`, + `raise_to_shaped_mappings`, `reset_trace_state`, `str_eqn_compact`, + `substitute_vars_in_output_ty`, `typecompat`, and `used_axis_names_jaxpr`. Most + have no public replacement, though a few are available at {mod}`jax.extend.core`. + * The `vectorized` argument to {func}`~jax.pure_callback` and + {func}`~jax.ffi.ffi_call`. Use the `vmap_method` parameter instead. + +## jax 0.5.3 (Mar 19, 2025) + * New Features * Added a `allow_negative_indices` option to {func}`jax.lax.dynamic_slice`, @@ -34,6 +462,30 @@ Patch release of 0.5.1 ## jax 0.5.1 (Feb 24, 2025) +* Breaking changes + * The jit tracing cache now keys on input NamedShardings. Previously, the + tracing cache did not include sharding information at all + (although subsequent jit caches did like lowering and compilation caches), + so two equivalent shardings of different types would not retrace, + but now they do. For example: + ```python + @jax.jit + def f(x): + return x + + # inp1.sharding is of type SingleDeviceSharding + inp1 = jnp.arange(8) + f(inp1) + + mesh = jax.make_mesh((1,), ('x',)) + # inp2.sharding is of type NamedSharding + inp2 = jax.device_put(jnp.arange(8), NamedSharding(mesh, P('x'))) + f(inp2) # tracing cache miss + ``` + In the above example, calling `f(inp1)` and then `f(inp2)` will lead to a + tracing cache miss because the shardings have changed on the abstract values + while tracing. + * New Features * Added an experimental {func}`jax.experimental.custom_dce.custom_dce` decorator to support customizing the behavior of opaque functions under @@ -81,7 +533,7 @@ Patch release of 0.5.1 ## jax 0.5.0 (Jan 17, 2025) As of this release, JAX now uses -[effort-based versioning](https://jax.readthedocs.io/en/latest/jep/25516-effver.html). +[effort-based versioning](https://docs.jax.dev/en/latest/jep/25516-effver.html). Since this release makes a breaking change to PRNG key semantics that may require users to update their code, we are bumping the "meso" version of JAX to signify this. @@ -101,7 +553,7 @@ to signify this. developers at this point. So it is difficult for us to fix this kind of problem even if we wanted to. - We are open to readding support for Mac x86 if the community is willing + We are open to re-adding support for Mac x86 if the community is willing to help support that platform: in particular, we would need the JAX test suite to pass cleanly on Mac x86 before we could ship releases again. @@ -172,7 +624,7 @@ to signify this. * New Features * {func}`jax.export.export` can be used for device-polymorphic export with shardings constructed with {func}`jax.sharding.AbstractMesh`. - See the [jax.export documentation](https://jax.readthedocs.io/en/latest/export/export.html#device-polymorphic-export). + See the [jax.export documentation](https://docs.jax.dev/en/latest/export/export.html#device-polymorphic-export). * Added {func}`jax.lax.split`. This is a primitive version of {func}`jax.numpy.split`, added because it yields a more compact transpose during automatic differentiation. @@ -214,11 +666,11 @@ This is a patch release of jax 0.4.36. Only "jax" was released at this version. after being deprecated in JAX v0.4.31. Instead use `xb = jax.lib.xla_bridge`, `xc = jax.lib.xla_client`, and `xe = jax.lib.xla_extension`. * The deprecated module `jax.experimental.export` has been removed. It was replaced - by {mod}`jax.export` in JAX v0.4.30. See the [migration guide](https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export) + by {mod}`jax.export` in JAX v0.4.30. See the [migration guide](https://docs.jax.dev/en/latest/export/export.html#migration-guide-from-jax-experimental-export) for information on migrating to the new API. * The `initial` argument to {func}`jax.nn.softmax` and {func}`jax.nn.log_softmax` has been removed, after being deprecated in v0.4.27. - * Calling `np.asarray` on typed PRNG keys (i.e. keys produced by :func:`jax.random.key`) + * Calling `np.asarray` on typed PRNG keys (i.e. keys produced by {func}`jax.random.key`) now raises an error. Previously, this returned a scalar object array. * The following deprecated methods and functions in {mod}`jax.export` have been removed: @@ -252,7 +704,7 @@ This is a patch release of jax 0.4.36. Only "jax" was released at this version. call that we guarantee export stability. This is because this custom call relies on Triton IR, which is not guaranteed to be stable. If you need to export code that uses this custom call, you can use the `disabled_checks` - parameter. See more details in the [documentation](https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees-for-custom-calls). + parameter. See more details in the [documentation](https://docs.jax.dev/en/latest/export/export.html#compatibility-guarantees-for-custom-calls). * New Features * {func}`jax.jit` got a new `compiler_options: dict[str, Any]` argument, for @@ -326,7 +778,7 @@ This is a patch release of jax 0.4.36. Only "jax" was released at this version. * `jax_pmap_no_rank_reduction` flag is set to `True` by default. * array[0] on a pmap result now introduces a reshape (use array[0:1] instead). - * The per-shard shape (accessable via jax_array.addressable_shards or + * The per-shard shape (accessible via jax_array.addressable_shards or jax_array.addressable_data(0)) now has a leading (1, ...). Update code that directly accesses shards accordingly. The rank of the per-shard-shape now matches that of the global shape which is the same behavior as jit. @@ -532,7 +984,7 @@ See the 0.4.33 release notes for more details. * Added an API for exporting and serializing JAX functions. This used to exist in `jax.experimental.export` (which is being deprecated), and will now live in `jax.export`. - See the [documentation](https://jax.readthedocs.io/en/latest/export/index.html). + See the [documentation](https://docs.jax.dev/en/latest/export/index.html). * Deprecations * Internal pretty-printing tools `jax.core.pp_*` are deprecated, and will be removed @@ -541,7 +993,7 @@ See the 0.4.33 release notes for more details. release. This previously was the case, but there was an inadvertent regression in the last several JAX releases. * `jax.experimental.export` is deprecated. Use {mod}`jax.export` instead. - See the [migration guide](https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export). + See the [migration guide](https://docs.jax.dev/en/latest/export/export.html#migration-guide-from-jax-experimental-export). * Passing an array in place of a dtype is now deprecated in most cases; e.g. for arrays `x` and `y`, `x.astype(y)` will raise a warning. To silence it use `x.astype(y.dtype)`. * `jax.xla_computation` is deprecated and will be removed in a future release. @@ -709,7 +1161,7 @@ See the 0.4.33 release notes for more details. positional-only, following deprecation of the keywords in JAX v0.4.21. * Non-array arguments to functions in {mod}`jax.lax.linalg` now must be specified by keyword. Previously, this raised a DeprecationWarning. - * Array-like arguments are now required in several :func:`jax.numpy` APIs, + * Array-like arguments are now required in several {func}`jax.numpy` APIs, including {func}`~jax.numpy.apply_along_axis`, {func}`~jax.numpy.apply_over_axes`, {func}`~jax.numpy.inner`, {func}`~jax.numpy.outer`, {func}`~jax.numpy.cross`, @@ -753,7 +1205,7 @@ See the 0.4.33 release notes for more details. deprecated. Use `jax.experimental.shard_map` or `jax.vmap` with the `spmd_axis_name` argument for expressing SPMD device-parallel computations. * The `jax.experimental.host_callback` module is deprecated. - Use instead the [new JAX external callbacks](https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html). + Use instead the [new JAX external callbacks](https://docs.jax.dev/en/latest/notebooks/external_callbacks.html). Added `JAX_HOST_CALLBACK_LEGACY` flag to assist in the transition to the new callbacks. See {jax-issue}`#20385` for a discussion. * Passing arguments to {func}`jax.numpy.array_equal` and {func}`jax.numpy.array_equiv` @@ -1225,10 +1677,10 @@ See the 0.4.33 release notes for more details. * Deprecations * Python 3.8 support has been dropped as per - https://jax.readthedocs.io/en/latest/deprecation.html + https://docs.jax.dev/en/latest/deprecation.html * JAX now requires NumPy 1.22 or newer as per - https://jax.readthedocs.io/en/latest/deprecation.html - * Passing optional arguments to {func}`jax.numpy.ndarray.at` by position is + https://docs.jax.dev/en/latest/deprecation.html + * Passing optional arguments to {attr}`jax.numpy.ndarray.at` by position is no longer supported, after being deprecated in JAX version 0.4.7. For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)` * The following `jax.Array` methods have been removed, after being deprecated @@ -1272,7 +1724,7 @@ See the 0.4.33 release notes for more details. * Deprecations * Python 3.8 support has been dropped as per - https://jax.readthedocs.io/en/latest/deprecation.html + https://docs.jax.dev/en/latest/deprecation.html ## jax 0.4.13 (June 22, 2023) @@ -1328,7 +1780,7 @@ See the 0.4.33 release notes for more details. * Deprecations * `jax.abstract_arrays` and its contents are now deprecated. See related - functionality in :mod:`jax.core`. + functionality in {mod}`jax.core`. * `jax.numpy.alltrue`: use `jax.numpy.all`. This follows the deprecation of `numpy.alltrue` in NumPy version 1.25.0. * `jax.numpy.sometrue`: use `jax.numpy.any`. This follows the deprecation @@ -1382,7 +1834,7 @@ See the 0.4.33 release notes for more details. dict of string stat names with int values, e.g. `"bytes_in_use"`, or None if the platform doesn't support memory statistics. The exact stats returned may vary across platforms. Currently only implemented on Cloud TPU. - * Readded support for the Python buffer protocol (`memoryview`) on CPU + * Re-added support for the Python buffer protocol (`memoryview`) on CPU devices. ## jax 0.4.10 (May 11, 2023) @@ -1451,7 +1903,7 @@ See the 0.4.33 release notes for more details. ## jax 0.4.7 (March 27, 2023) * Changes - * As per https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration + * As per https://docs.jax.dev/en/latest/jax_array_migration.html#jax-array-migration `jax.config.jax_array` cannot be disabled anymore. * `jax.config.jax_jit_pjit_api_merge` cannot be disabled anymore. * {func}`jax.experimental.jax2tf.convert` now supports the `native_serialization` @@ -1473,7 +1925,7 @@ See the 0.4.33 release notes for more details. for which it is an alias. * The type `jax.interpreters.pxla.ShardedDeviceArray` is deprecated. Use `jax.Array` instead. - * Passing additional arguments to {func}`jax.numpy.ndarray.at` by position is deprecated. + * Passing additional arguments to {attr}`jax.numpy.ndarray.at` by position is deprecated. For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)` * `jax.interpreters.xla.device_put` is deprecated. Please use `jax.device_put`. * `jax.interpreters.pxla.device_put` is deprecated. Please use `jax.device_put`. @@ -1535,7 +1987,7 @@ Changes: on top of each other. With the `jit`-`pjit` implementation merge, `jit` becomes an initial style primitive which means that we trace to jaxpr as early as possible. For more information see - [this section in autodidax](https://jax.readthedocs.io/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing). + [this section in autodidax](https://docs.jax.dev/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing). Moving to initial style should simplify JAX's internals and make development of features like dynamic shapes, etc easier. You can disable it only via the environment variable i.e. @@ -1558,7 +2010,7 @@ Changes: * `jax.interpreters.pxla.Mesh`: use `jax.sharding.Mesh`. * `jax.interpreters.pxla.PartitionSpec`: use `jax.sharding.PartitionSpec`. * Breaking Changes - * the `initial` argument to reduction functions like :func:`jax.numpy.sum` + * the `initial` argument to reduction functions like {func}`jax.numpy.sum` is now required to be a scalar, consistent with the corresponding NumPy API. The previous behavior of broadcasting the output against non-scalar `initial` values was an unintentional implementation detail ({jax-issue}`#14446`). @@ -1620,9 +2072,9 @@ Changes: simplifies and unifies JAX internals, and allows us to unify `jit` and `pjit`. `jax.Array` has been enabled by default in JAX 0.4 and makes some breaking change to the `pjit` API. The [jax.Array migration - guide](https://jax.readthedocs.io/en/latest/jax_array_migration.html) can + guide](https://docs.jax.dev/en/latest/jax_array_migration.html) can help you migrate your codebase to `jax.Array`. You can also look at the - [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) + [Distributed arrays and automatic parallelization](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) tutorial to understand the new concepts. * `PartitionSpec` and `Mesh` are now out of experimental. The new API endpoints are `jax.sharding.PartitionSpec` and `jax.sharding.Mesh`. @@ -1651,7 +2103,7 @@ Changes: * The behavior of `XLA_PYTHON_CLIENT_MEM_FRACTION=.XX` has been changed to allocate XX% of the total GPU memory instead of the previous behavior of using currently available GPU memory to calculate preallocation. Please refer to - [GPU memory allocation](https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html) for + [GPU memory allocation](https://docs.jax.dev/en/latest/gpu_memory_allocation.html) for more details. * The deprecated method `.block_host_until_ready()` has been removed. Use `.block_until_ready()` instead. @@ -1765,7 +2217,7 @@ Changes: * Changes * Ahead-of-time lowering and compilation functionality (tracked in {jax-issue}`#7733`) is stable and public. See [the - overview](https://jax.readthedocs.io/en/latest/aot.html) and the API docs + overview](https://docs.jax.dev/en/latest/aot.html) and the API docs for {mod}`jax.stages`. * Introduced {class}`jax.Array`, intended to be used for both `isinstance` checks and type annotations for array types in JAX. Notice that this included some subtle @@ -1786,7 +2238,7 @@ Changes: * Breaking changes * {func}`jax.checkpoint`, also known as {func}`jax.remat`, no longer supports the `concrete` option, following the previous version's deprecation; see - [JEP 11830](https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html). + [JEP 11830](https://docs.jax.dev/en/latest/jep/11830-new-remat-checkpoint.html). * Changes * Added {func}`jax.pure_callback` that enables calling back to pure Python functions from compiled functions (e.g. functions decorated with `jax.jit` or `jax.pmap`). * Deprecations: @@ -1798,11 +2250,11 @@ Changes: * [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.15...main). * Breaking changes * Support for NumPy 1.19 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). Please upgrade to NumPy 1.20 or newer. * Changes * Added {mod}`jax.debug` that includes utilities for runtime value debugging such at {func}`jax.debug.print` and {func}`jax.debug.breakpoint`. - * Added new documentation for [runtime value debugging](debugging/index) + * Added new documentation for [runtime value debugging](https://github.com/jax-ml/jax/blob/7ac8181cce087d8bcd564d07e19f5067cb5d9d3b/docs/debugging/index.md) * Deprecations * {func}`jax.mask` {func}`jax.shapecheck` APIs have been removed. See {jax-issue}`#11557`. @@ -1816,7 +2268,7 @@ Changes: {mod}`jax.example_libraries.optimizers`. * {func}`jax.checkpoint`, also known as {func}`jax.remat`, has a new implementation switched on by default, meaning the old implementation is - deprecated; see [JEP 11830](https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html). + deprecated; see [JEP 11830](https://docs.jax.dev/en/latest/jep/11830-new-remat-checkpoint.html). ## jax 0.3.15 (July 22, 2022) * [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.14...jax-v0.3.15). @@ -1884,7 +2336,7 @@ Changes: traces as an alternative to the TensorBoard UI. * Added a `jax.named_scope` context manager that adds profiler metadata to Python programs (similar to `jax.named_call`). - * In scatter-update operations (i.e. :attr:`jax.numpy.ndarray.at`), unsafe implicit + * In scatter-update operations (i.e. {attr}`jax.numpy.ndarray.at`), unsafe implicit dtype casts are deprecated, and now result in a `FutureWarning`. In a future release, this will become an error. An example of an unsafe implicit cast is `jnp.zeros(4, dtype=int).at[0].set(1.5)`, in which `1.5` previously was @@ -1948,7 +2400,7 @@ Changes: * {func}`jax.numpy.linalg.matrix_rank` on TPUs now accepts complex input. * {func}`jax.scipy.cluster.vq.vq` has been added. * `jax.experimental.maps.mesh` has been deleted. - Please use `jax.experimental.maps.Mesh`. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh + Please use `jax.experimental.maps.Mesh`. Please see https://docs.jax.dev/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh for more information. * {func}`jax.scipy.linalg.qr` now returns a length-1 tuple rather than the raw array when `mode='r'`, in order to match the behavior of `scipy.linalg.qr` ({jax-issue}`#10452`) @@ -2064,7 +2516,7 @@ Changes: * Changes: * The functions `jax.ops.index_update`, `jax.ops.index_add`, which were deprecated in 0.2.22, have been removed. Please use - [the `.at` property on JAX arrays](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html) + [the `.at` property on JAX arrays](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html) instead, e.g., `x.at[idx].set(y)`. * Moved `jax.experimental.ann.approx_*_k` into `jax.lax`. These functions are optimized alternatives to `jax.lax.top_k`. @@ -2110,13 +2562,13 @@ Changes: commits](https://github.com/jax-ml/jax/compare/jax-v0.2.28...jax-v0.3.0). * Changes - * jax version has been bumped to 0.3.0. Please see the [design doc](https://jax.readthedocs.io/en/latest/design_notes/jax_versioning.html) + * jax version has been bumped to 0.3.0. Please see the [design doc](https://docs.jax.dev/en/latest/design_notes/jax_versioning.html) for the explanation. ## jaxlib 0.3.0 (Feb 10, 2022) * Changes * Bazel 5.0.0 is now required to build jaxlib. - * jaxlib version has been bumped to 0.3.0. Please see the [design doc](https://jax.readthedocs.io/en/latest/design_notes/jax_versioning.html) + * jaxlib version has been bumped to 0.3.0. Please see the [design doc](https://docs.jax.dev/en/latest/design_notes/jax_versioning.html) for the explanation. ## jax 0.2.28 (Feb 1, 2022) @@ -2138,7 +2590,7 @@ Changes: by default. * Breaking changes * Support for NumPy 1.18 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). Please upgrade to a supported NumPy version. * Bug fixes * Fixed a bug where apparently identical pytreedef objects constructed by different routes @@ -2150,7 +2602,7 @@ Changes: * Breaking changes: * Support for NumPy 1.18 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). Please upgrade to a supported NumPy version. * The host_callback primitives have been simplified to drop the special autodiff handling for hcb.id_tap and id_print. @@ -2277,7 +2729,7 @@ Changes: * Deprecations * The functions `jax.ops.index_update`, `jax.ops.index_add` etc. are deprecated and will be removed in a future JAX release. Please use - [the `.at` property on JAX arrays](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html) + [the `.at` property on JAX arrays](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html) instead, e.g., `x.at[idx].set(y)`. For now, these functions produce a `DeprecationWarning`. * New features: @@ -2341,7 +2793,7 @@ Changes: commits](https://github.com/jax-ml/jax/compare/jax-v0.2.18...jax-v0.2.19). * Breaking changes: * Support for NumPy 1.17 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). Please upgrade to a supported NumPy version. * The `jit` decorator has been added around the implementation of a number of operators on JAX arrays. This speeds up dispatch times for common @@ -2362,10 +2814,10 @@ Changes: ## jaxlib 0.1.70 (Aug 9, 2021) * Breaking changes: * Support for Python 3.6 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). Please upgrade to a supported Python version. * Support for NumPy 1.17 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). Please upgrade to a supported NumPy version. * The host_callback mechanism now uses one thread per local device for @@ -2379,7 +2831,7 @@ Changes: * Breaking changes: * Support for Python 3.6 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). Please upgrade to a supported Python version. * The minimum jaxlib version is now 0.1.69. * The `backend` argument to {py:func}`jax.dlpack.from_dlpack` has been @@ -2428,7 +2880,7 @@ Changes: * Breaking changes: * Support for NumPy 1.16 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). * Bug fixes: * Fixed bug that prevented round-tripping from JAX to TF and back: @@ -2968,7 +3420,7 @@ Changes: * Support for reduction over subsets of a pmapped axis using `axis_index_groups` {jax-issue}`#2382`. * Experimental support for printing and calling host-side Python function from - compiled code. See [id_print and id_tap](https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html) + compiled code. See [id_print and id_tap](https://docs.jax.dev/en/latest/jax.experimental.host_callback.html) ({jax-issue}`#3006`). * Notable changes: * The visibility of names exported from {mod}`jax.numpy` has been @@ -3040,7 +3492,7 @@ Changes: ## jax 0.1.63 (April 12, 2020) * [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.62...jax-v0.1.63). -* Added `jax.custom_jvp` and `jax.custom_vjp` from {jax-issue}`#2026`, see the [tutorial notebook](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html). Deprecated `jax.custom_transforms` and removed it from the docs (though it still works). +* Added `jax.custom_jvp` and `jax.custom_vjp` from {jax-issue}`#2026`, see the [tutorial notebook](https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html). Deprecated `jax.custom_transforms` and removed it from the docs (though it still works). * Add `scipy.sparse.linalg.cg` {jax-issue}`#2566`. * Changed how Tracers are printed to show more useful information for debugging {jax-issue}`#2591`. * Made `jax.numpy.isclose` handle `nan` and `inf` correctly {jax-issue}`#2501`. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 314d4387a044..046d3df3195c 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,4 +1,4 @@ # Contributing to JAX For information on how to contribute to JAX, see -[Contributing to JAX](https://jax.readthedocs.io/en/latest/contributing.html) +[Contributing to JAX](https://docs.jax.dev/en/latest/contributing.html) diff --git a/README.md b/README.md index 0aca7cf58e6e..91e9afb66a56 100644 --- a/README.md +++ b/README.md @@ -7,12 +7,11 @@ [![Continuous integration](https://github.com/jax-ml/jax/actions/workflows/ci-build.yaml/badge.svg)](https://github.com/jax-ml/jax/actions/workflows/ci-build.yaml) [![PyPI version](https://img.shields.io/pypi/v/jax)](https://pypi.org/project/jax/) -[**Quickstart**](#quickstart-colab-in-the-cloud) -| [**Transformations**](#transformations) +[**Transformations**](#transformations) +| [**Scaling**](#scaling) | [**Install guide**](#installation) -| [**Neural net libraries**](#neural-network-libraries) -| [**Change logs**](https://jax.readthedocs.io/en/latest/changelog.html) -| [**Reference docs**](https://jax.readthedocs.io/en/latest/) +| [**Change logs**](https://docs.jax.dev/en/latest/changelog.html) +| [**Reference docs**](https://docs.jax.dev/en/latest/) ## What is JAX? @@ -20,42 +19,29 @@ JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning. -With its updated version of [Autograd](https://github.com/hips/autograd), JAX can automatically differentiate native Python and NumPy functions. It can differentiate through loops, branches, recursion, and closures, and it can take derivatives of derivatives of derivatives. It supports reverse-mode differentiation (a.k.a. backpropagation) -via [`grad`](#automatic-differentiation-with-grad) as well as forward-mode differentiation, +via [`jax.grad`](#automatic-differentiation-with-grad) as well as forward-mode differentiation, and the two can be composed arbitrarily to any order. -What’s new is that JAX uses [XLA](https://www.tensorflow.org/xla) -to compile and run your NumPy programs on GPUs and TPUs. Compilation happens -under the hood by default, with library calls getting just-in-time compiled and -executed. But JAX also lets you just-in-time compile your own Python functions -into XLA-optimized kernels using a one-function API, -[`jit`](#compilation-with-jit). Compilation and automatic differentiation can be -composed arbitrarily, so you can express sophisticated algorithms and get -maximal performance without leaving Python. You can even program multiple GPUs -or TPU cores at once using [`pmap`](#spmd-programming-with-pmap), and -differentiate through the whole thing. +JAX uses [XLA](https://www.openxla.org/xla) +to compile and scale your NumPy programs on TPUs, GPUs, and other hardware accelerators. +You can compile your own pure functions with [`jax.jit`](#compilation-with-jit). +Compilation and automatic differentiation can be composed arbitrarily. Dig a little deeper, and you'll see that JAX is really an extensible system for -[composable function transformations](#transformations). Both -[`grad`](#automatic-differentiation-with-grad) and [`jit`](#compilation-with-jit) -are instances of such transformations. Others are -[`vmap`](#auto-vectorization-with-vmap) for automatic vectorization and -[`pmap`](#spmd-programming-with-pmap) for single-program multiple-data (SPMD) -parallel programming of multiple accelerators, with more to come. +[composable function transformations](#transformations) at [scale](#scaling). This is a research project, not an official Google product. Expect -[sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html). -Please help by trying it out, [reporting -bugs](https://github.com/jax-ml/jax/issues), and letting us know what you -think! +[sharp edges](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html). +Please help by trying it out, [reporting bugs](https://github.com/jax-ml/jax/issues), +and letting us know what you think! ```python +import jax import jax.numpy as jnp -from jax import grad, jit, vmap def predict(params, inputs): for W, b in params: @@ -67,85 +53,49 @@ def loss(params, inputs, targets): preds = predict(params, inputs) return jnp.sum((preds - targets)**2) -grad_loss = jit(grad(loss)) # compiled gradient evaluation function -perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0))) # fast per-example grads +grad_loss = jax.jit(jax.grad(loss)) # compiled gradient evaluation function +perex_grads = jax.jit(jax.vmap(grad_loss, in_axes=(None, 0, 0))) # fast per-example grads ``` ### Contents -* [Quickstart: Colab in the Cloud](#quickstart-colab-in-the-cloud) * [Transformations](#transformations) -* [Current gotchas](#current-gotchas) +* [Scaling](#scaling) +* [Current gotchas](#gotchas-and-sharp-bits) * [Installation](#installation) -* [Neural net libraries](#neural-network-libraries) * [Citing JAX](#citing-jax) * [Reference documentation](#reference-documentation) -## Quickstart: Colab in the Cloud -Jump right in using a notebook in your browser, connected to a Google Cloud GPU. -Here are some starter notebooks: -- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://jax.readthedocs.io/en/latest/quickstart.html) -- [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) - -**JAX now runs on Cloud TPUs.** To try out the preview, see the [Cloud TPU -Colabs](https://github.com/jax-ml/jax/tree/main/cloud_tpu_colabs). - -For a deeper dive into JAX: -- [The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html) -- [Common gotchas and sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) -- See the [full list of -notebooks](https://github.com/jax-ml/jax/tree/main/docs/notebooks). - ## Transformations At its core, JAX is an extensible system for transforming numerical functions. -Here are four transformations of primary interest: `grad`, `jit`, `vmap`, and -`pmap`. +Here are three: `jax.grad`, `jax.jit`, and `jax.vmap`. ### Automatic differentiation with `grad` -JAX has roughly the same API as [Autograd](https://github.com/hips/autograd). -The most popular function is -[`grad`](https://jax.readthedocs.io/en/latest/jax.html#jax.grad) -for reverse-mode gradients: +Use [`jax.grad`](https://docs.jax.dev/en/latest/jax.html#jax.grad) +to efficiently compute reverse-mode gradients: ```python -from jax import grad +import jax import jax.numpy as jnp -def tanh(x): # Define a function +def tanh(x): y = jnp.exp(-2.0 * x) return (1.0 - y) / (1.0 + y) -grad_tanh = grad(tanh) # Obtain its gradient function -print(grad_tanh(1.0)) # Evaluate it at x = 1.0 +grad_tanh = jax.grad(tanh) +print(grad_tanh(1.0)) # prints 0.4199743 ``` -You can differentiate to any order with `grad`. +You can differentiate to any order with `grad`: ```python -print(grad(grad(grad(tanh)))(1.0)) +print(jax.grad(jax.grad(jax.grad(tanh)))(1.0)) # prints 0.62162673 ``` -For more advanced autodiff, you can use -[`jax.vjp`](https://jax.readthedocs.io/en/latest/jax.html#jax.vjp) for -reverse-mode vector-Jacobian products and -[`jax.jvp`](https://jax.readthedocs.io/en/latest/jax.html#jax.jvp) for -forward-mode Jacobian-vector products. The two can be composed arbitrarily with -one another, and with other JAX transformations. Here's one way to compose those -to make a function that efficiently computes [full Hessian -matrices](https://jax.readthedocs.io/en/latest/_autosummary/jax.hessian.html#jax.hessian): - -```python -from jax import jit, jacfwd, jacrev - -def hessian(fun): - return jit(jacfwd(jacrev(fun))) -``` - -As with [Autograd](https://github.com/hips/autograd), you're free to use -differentiation with Python control structures: +You're free to use differentiation with Python control flow: ```python def abs_val(x): @@ -154,242 +104,134 @@ def abs_val(x): else: return -x -abs_val_grad = grad(abs_val) +abs_val_grad = jax.grad(abs_val) print(abs_val_grad(1.0)) # prints 1.0 print(abs_val_grad(-1.0)) # prints -1.0 (abs_val is re-evaluated) ``` -See the [reference docs on automatic -differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) -and the [JAX Autodiff -Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html) +See the [JAX Autodiff +Cookbook](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html) +and the [reference docs on automatic +differentiation](https://docs.jax.dev/en/latest/jax.html#automatic-differentiation) for more. ### Compilation with `jit` -You can use XLA to compile your functions end-to-end with -[`jit`](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit), +Use XLA to compile your functions end-to-end with +[`jit`](https://docs.jax.dev/en/latest/jax.html#just-in-time-compilation-jit), used either as an `@jit` decorator or as a higher-order function. ```python +import jax import jax.numpy as jnp -from jax import jit def slow_f(x): # Element-wise ops see a large benefit from fusion return x * x + x * 2.0 x = jnp.ones((5000, 5000)) -fast_f = jit(slow_f) -%timeit -n10 -r3 fast_f(x) # ~ 4.5 ms / loop on Titan X -%timeit -n10 -r3 slow_f(x) # ~ 14.5 ms / loop (also on GPU via JAX) +fast_f = jax.jit(slow_f) +%timeit -n10 -r3 fast_f(x) +%timeit -n10 -r3 slow_f(x) ``` -You can mix `jit` and `grad` and any other JAX transformation however you like. - -Using `jit` puts constraints on the kind of Python control flow +Using `jax.jit` constrains the kind of Python control flow the function can use; see -the tutorial on [Control Flow and Logical Operators with JIT](https://jax.readthedocs.io/en/latest/control-flow.html) +the tutorial on [Control Flow and Logical Operators with JIT](https://docs.jax.dev/en/latest/control-flow.html) for more. ### Auto-vectorization with `vmap` -[`vmap`](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) is -the vectorizing map. -It has the familiar semantics of mapping a function along array axes, but -instead of keeping the loop on the outside, it pushes the loop down into a -function’s primitive operations for better performance. +[`vmap`](https://docs.jax.dev/en/latest/jax.html#vectorization-vmap) maps +a function along array axes. +But instead of just looping over function applications, it pushes the loop down +onto the function’s primitive operations, e.g. turning matrix-vector multiplies into +matrix-matrix multiplies for better performance. Using `vmap` can save you from having to carry around batch dimensions in your -code. For example, consider this simple *unbatched* neural network prediction -function: - -```python -def predict(params, input_vec): - assert input_vec.ndim == 1 - activations = input_vec - for W, b in params: - outputs = jnp.dot(W, activations) + b # `activations` on the right-hand side! - activations = jnp.tanh(outputs) # inputs to the next layer - return outputs # no activation on last layer -``` - -We often instead write `jnp.dot(activations, W)` to allow for a batch dimension on the -left side of `activations`, but we’ve written this particular prediction function to -apply only to single input vectors. If we wanted to apply this function to a -batch of inputs at once, semantically we could just write - -```python -from functools import partial -predictions = jnp.stack(list(map(partial(predict, params), input_batch))) -``` - -But pushing one example through the network at a time would be slow! It’s better -to vectorize the computation, so that at every layer we’re doing matrix-matrix -multiplication rather than matrix-vector multiplication. - -The `vmap` function does that transformation for us. That is, if we write +code: ```python -from jax import vmap -predictions = vmap(partial(predict, params))(input_batch) -# or, alternatively -predictions = vmap(predict, in_axes=(None, 0))(params, input_batch) -``` +import jax +import jax.numpy as jnp -then the `vmap` function will push the outer loop inside the function, and our -machine will end up executing matrix-matrix multiplications exactly as if we’d -done the batching by hand. +def l1_distance(x, y): + assert x.ndim == y.ndim == 1 # only works on 1D inputs + return jnp.sum(jnp.abs(x - y)) -It’s easy enough to manually batch a simple neural network without `vmap`, but -in other cases manual vectorization can be impractical or impossible. Take the -problem of efficiently computing per-example gradients: that is, for a fixed set -of parameters, we want to compute the gradient of our loss function evaluated -separately at each example in a batch. With `vmap`, it’s easy: +def pairwise_distances(dist1D, xs): + return jax.vmap(jax.vmap(dist1D, (0, None)), (None, 0))(xs, xs) -```python -per_example_gradients = vmap(partial(grad(loss), params))(inputs, targets) +xs = jax.random.normal(jax.random.key(0), (100, 3)) +dists = pairwise_distances(l1_distance, xs) +dists.shape # (100, 100) ``` -Of course, `vmap` can be arbitrarily composed with `jit`, `grad`, and any other -JAX transformation! We use `vmap` with both forward- and reverse-mode automatic -differentiation for fast Jacobian and Hessian matrix calculations in -`jax.jacfwd`, `jax.jacrev`, and `jax.hessian`. - -### SPMD programming with `pmap` - -For parallel programming of multiple accelerators, like multiple GPUs, use -[`pmap`](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap). -With `pmap` you write single-program multiple-data (SPMD) programs, including -fast parallel collective communication operations. Applying `pmap` will mean -that the function you write is compiled by XLA (similarly to `jit`), then -replicated and executed in parallel across devices. - -Here's an example on an 8-GPU machine: +By composing `jax.vmap` with `jax.grad` and `jax.jit`, we can get efficient +Jacobian matrices, or per-example gradients: ```python -from jax import random, pmap -import jax.numpy as jnp - -# Create 8 random 5000 x 6000 matrices, one per GPU -keys = random.split(random.key(0), 8) -mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys) - -# Run a local matmul on each device in parallel (no data transfer) -result = pmap(lambda x: jnp.dot(x, x.T))(mats) # result.shape is (8, 5000, 5000) - -# Compute the mean on each device in parallel and print the result -print(pmap(jnp.mean)(result)) -# prints [1.1566595 1.1805978 ... 1.2321935 1.2015157] +per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))) ``` -In addition to expressing pure maps, you can use fast [collective communication -operations](https://jax.readthedocs.io/en/latest/jax.lax.html#parallel-operators) -between devices: +## Scaling + +To scale your computations across thousands of devices, you can use any +composition of these: +* [**Compiler-based automatic parallelization**](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) +where you program as if using a single global machine, and the compiler chooses +how to shard data and partition computation (with some user-provided constraints); +* [**Explicit sharding and automatic partitioning**](https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html) +where you still have a global view but data shardings are +explicit in JAX types, inspectable using `jax.typeof`; +* [**Manual per-device programming**](https://docs.jax.dev/en/latest/notebooks/shard_map.html) +where you have a per-device view of data +and computation, and can communicate with explicit collectives. + +| Mode | View? | Explicit sharding? | Explicit Collectives? | +|---|---|---|---| +| Auto | Global | ❌ | ❌ | +| Explicit | Global | ✅ | ❌ | +| Manual | Per-device | ✅ | ✅ | ```python -from functools import partial -from jax import lax +from jax.sharding import set_mesh, AxisType, PartitionSpec as P +mesh = jax.make_mesh((8,), ('data',), axis_types=(AxisType.Explicit,)) +set_mesh(mesh) -@partial(pmap, axis_name='i') -def normalize(x): - return x / lax.psum(x, 'i') +# parameters are sharded for FSDP: +for W, b in params: + print(f'{jax.typeof(W)}') # f32[512@data,512] + print(f'{jax.typeof(b)}') # f32[512] -print(normalize(jnp.arange(4.))) -# prints [0. 0.16666667 0.33333334 0.5 ] -``` +# shard data for batch parallelism: +inputs, targets = jax.device_put((inputs, targets), P('data')) -You can even [nest `pmap` functions](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb#scrollTo=MdRscR5MONuN) for more -sophisticated communication patterns. - -It all composes, so you're free to differentiate through parallel computations: - -```python -from jax import grad - -@pmap -def f(x): - y = jnp.sin(x) - @pmap - def g(z): - return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum() - return grad(lambda w: jnp.sum(g(w)))(x) - -print(f(x)) -# [[ 0. , -0.7170853 ], -# [-3.1085174 , -0.4824318 ], -# [10.366636 , 13.135289 ], -# [ 0.22163185, -0.52112055]] - -print(grad(lambda x: jnp.sum(f(x)))(x)) -# [[ -3.2369726, -1.6356447], -# [ 4.7572474, 11.606951 ], -# [-98.524414 , 42.76499 ], -# [ -1.6007166, -1.2568436]] +# evaluate gradients, automatically parallelized! +gradfun = jax.jit(jax.grad(loss)) +param_grads = gradfun(params, (inputs, targets)) ``` -When reverse-mode differentiating a `pmap` function (e.g. with `grad`), the -backward pass of the computation is parallelized just like the forward pass. +See the [tutorial](https://docs.jax.dev/en/latest/sharded-computation.html) and +[advanced guides](https://docs.jax.dev/en/latest/advanced_guide.html) for more. -See the [SPMD -Cookbook](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb) -and the [SPMD MNIST classifier from scratch -example](https://github.com/jax-ml/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py) -for more. +## Gotchas and sharp bits -## Current gotchas - -For a more thorough survey of current gotchas, with examples and explanations, -we highly recommend reading the [Gotchas -Notebook](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html). -Some standouts: - -1. JAX transformations only work on [pure functions](https://en.wikipedia.org/wiki/Pure_function), which don't have side-effects and respect [referential transparency](https://en.wikipedia.org/wiki/Referential_transparency) (i.e. object identity testing with `is` isn't preserved). If you use a JAX transformation on an impure Python function, you might see an error like `Exception: Can't lift Traced...` or `Exception: Different traces at same level`. -1. [In-place mutating updates of - arrays](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://jax.readthedocs.io/en/latest/jax.ops.html). Under a `jit`, those functional alternatives will reuse buffers in-place automatically. -1. [Random numbers are - different](https://jax.readthedocs.io/en/latest/random-numbers.html), but for [good reasons](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md). -1. If you're looking for [convolution - operators](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html), - they're in the `jax.lax` package. -1. JAX enforces single-precision (32-bit, e.g. `float32`) values by default, and - [to enable - double-precision](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision) - (64-bit, e.g. `float64`) one needs to set the `jax_enable_x64` variable at - startup (or set the environment variable `JAX_ENABLE_X64=True`). - On TPU, JAX uses 32-bit values by default for everything _except_ internal - temporary variables in 'matmul-like' operations, such as `jax.numpy.dot` and `lax.conv`. - Those ops have a `precision` parameter which can be used to approximate 32-bit operations - via three bfloat16 passes, with a cost of possibly slower runtime. - Non-matmul operations on TPU lower to implementations that often emphasize speed over - accuracy, so in practice computations on TPU will be less precise than similar - computations on other backends. -1. Some of NumPy's dtype promotion semantics involving a mix of Python scalars - and NumPy types aren't preserved, namely `np.add(1, np.array([2], - np.float32)).dtype` is `float64` rather than `float32`. -1. Some transformations, like `jit`, [constrain how you can use Python control - flow](https://jax.readthedocs.io/en/latest/control-flow.html). - You'll always get loud errors if something goes wrong. You might have to use - [`jit`'s `static_argnums` - parameter](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit), - [structured control flow - primitives](https://jax.readthedocs.io/en/latest/jax.lax.html#control-flow-operators) - like - [`lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html#jax.lax.scan), - or just use `jit` on smaller subfunctions. +See the [Gotchas +Notebook](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html). ## Installation ### Supported platforms -| | Linux x86_64 | Linux aarch64 | Mac x86_64 | Mac aarch64 | Windows x86_64 | Windows WSL2 x86_64 | -|------------|--------------|---------------|--------------|--------------|----------------|---------------------| -| CPU | yes | yes | yes | yes | yes | yes | -| NVIDIA GPU | yes | yes | no | n/a | no | experimental | -| Google TPU | yes | n/a | n/a | n/a | n/a | n/a | -| AMD GPU | yes | no | experimental | n/a | no | no | -| Apple GPU | n/a | no | n/a | experimental | n/a | n/a | -| Intel GPU | experimental | n/a | n/a | n/a | no | no | +| | Linux x86_64 | Linux aarch64 | Mac aarch64 | Windows x86_64 | Windows WSL2 x86_64 | +|------------|--------------|---------------|--------------|----------------|---------------------| +| CPU | yes | yes | yes | yes | yes | +| NVIDIA GPU | yes | yes | n/a | no | experimental | +| Google TPU | yes | n/a | n/a | n/a | n/a | +| AMD GPU | yes | no | n/a | no | experimental | +| Apple GPU | n/a | no | experimental | n/a | n/a | +| Intel GPU | experimental | n/a | n/a | no | no | ### Instructions @@ -397,34 +239,17 @@ Some standouts: | Platform | Instructions | |-----------------|-----------------------------------------------------------------------------------------------------------------| | CPU | `pip install -U jax` | -| NVIDIA GPU | `pip install -U "jax[cuda12]"` | +| NVIDIA GPU | `pip install -U "jax[cuda13]"` | | Google TPU | `pip install -U "jax[tpu]"` | | AMD GPU (Linux) | Follow [AMD's instructions](https://github.com/jax-ml/jax/blob/main/build/rocm/README.md). | | Mac GPU | Follow [Apple's instructions](https://developer.apple.com/metal/jax/). | | Intel GPU | Follow [Intel's instructions](https://github.com/intel/intel-extension-for-openxla/blob/main/docs/acc_jax.md). | -See [the documentation](https://jax.readthedocs.io/en/latest/installation.html) +See [the documentation](https://docs.jax.dev/en/latest/installation.html) for information on alternative installation strategies. These include compiling from source, installing with Docker, using other versions of CUDA, a community-supported conda build, and answers to some frequently-asked questions. - - -## Neural network libraries - -Multiple Google research groups at Google DeepMind and Alphabet develop and share libraries -for training neural networks in JAX. If you want a fully featured library for neural network -training with examples and how-to guides, try -[Flax](https://github.com/google/flax) and its [documentation site](https://flax.readthedocs.io/en/latest/nnx/index.html). - -Check out the [JAX Ecosystem section](https://jax.readthedocs.io/en/latest/#ecosystem) -on the JAX documentation site for a list of JAX-based network libraries, which includes -[Optax](https://github.com/deepmind/optax) for gradient processing and -optimization, [chex](https://github.com/deepmind/chex) for reliable code and testing, and -[Equinox](https://github.com/patrick-kidger/equinox) for neural networks. -(Watch the NeurIPS 2020 JAX Ecosystem at DeepMind talk -[here](https://www.youtube.com/watch?v=iDxJxIyzSiM) for additional details.) - ## Citing JAX To cite this repository: @@ -452,7 +277,7 @@ paper. ## Reference documentation For details about the JAX API, see the -[reference documentation](https://jax.readthedocs.io/). +[reference documentation](https://docs.jax.dev/). For getting started as a JAX developer, see the -[developer documentation](https://jax.readthedocs.io/en/latest/developer.html). +[developer documentation](https://docs.jax.dev/en/latest/developer.html). diff --git a/WORKSPACE b/WORKSPACE index 129488281ea9..8054ec1aad31 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -1,68 +1,112 @@ -# The XLA commit is determined by third_party/xla/workspace.bzl. +load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") + +# The XLA commit is determined by third_party/xla/revision.bzl. load("//third_party/xla:workspace.bzl", jax_xla_workspace = "repo") + jax_xla_workspace() +load("@xla//:workspace4.bzl", "xla_workspace4") + +xla_workspace4() + +load("@xla//:workspace3.bzl", "xla_workspace3") + +xla_workspace3() + +# Initialize Hermetic toolchains +# Details: https://github.com/google-ml-infra/rules_ml_toolchain +tf_http_archive( + name = "rules_ml_toolchain", + sha256 = "54c1a357f71f611efdb4891ebd4bcbe4aeb6dfa7e473f14fd7ecad5062096616", + strip_prefix = "rules_ml_toolchain-d8cb9c2c168cd64000eaa6eda0781a9615a26ffe", + urls = tf_mirror_urls( + "https://github.com/google-ml-infra/rules_ml_toolchain/archive/d8cb9c2c168cd64000eaa6eda0781a9615a26ffe.tar.gz", + ), +) + +load( + "@rules_ml_toolchain//cc/deps:cc_toolchain_deps.bzl", + "cc_toolchain_deps", +) + +cc_toolchain_deps() + +register_toolchains("@rules_ml_toolchain//cc:linux_x86_64_linux_x86_64") + +register_toolchains("@rules_ml_toolchain//cc:linux_x86_64_linux_x86_64_cuda") + +register_toolchains("@rules_ml_toolchain//cc:linux_aarch64_linux_aarch64") + +register_toolchains("@rules_ml_toolchain//cc:linux_aarch64_linux_aarch64_cuda") + # Initialize hermetic Python load("@xla//third_party/py:python_init_rules.bzl", "python_init_rules") + python_init_rules() load("@xla//third_party/py:python_init_repositories.bzl", "python_init_repositories") + python_init_repositories( - requirements = { - "3.10": "//build:requirements_lock_3_10.txt", - "3.11": "//build:requirements_lock_3_11.txt", - "3.12": "//build:requirements_lock_3_12.txt", - "3.13": "//build:requirements_lock_3_13.txt", - "3.13-ft": "//build:requirements_lock_3_13_ft.txt", - }, + default_python_version = "system", + local_wheel_dist_folder = "../dist", local_wheel_inclusion_list = [ + "libtpu*", + "ml_dtypes*", + "ml-dtypes*", + "numpy*", + "scipy*", + "jax-*", "jaxlib*", "jax_cuda*", "jax-cuda*", ], local_wheel_workspaces = ["//jaxlib:jax.bzl"], - local_wheel_dist_folder = "../dist", - default_python_version = "system", + requirements = { + "3.11": "//build:requirements_lock_3_11.txt", + "3.12": "//build:requirements_lock_3_12.txt", + "3.13": "//build:requirements_lock_3_13.txt", + "3.14": "//build:requirements_lock_3_14.txt", + "3.13-ft": "//build:requirements_lock_3_13_ft.txt", + "3.14-ft": "//build:requirements_lock_3_14_ft.txt", + }, ) load("@xla//third_party/py:python_init_toolchains.bzl", "python_init_toolchains") + python_init_toolchains() load("@xla//third_party/py:python_init_pip.bzl", "python_init_pip") + python_init_pip() load("@pypi//:requirements.bzl", "install_deps") -install_deps() - -# Optional, to facilitate testing against newest versions of Python -load("@xla//third_party/py:python_repo.bzl", "custom_python_interpreter") -custom_python_interpreter( - name = "python_dev", - urls = ["https://www.python.org/ftp/python/{version}/Python-{version_variant}.tgz"], - strip_prefix = "Python-{version_variant}", - version = "3.13.0", - version_variant = "3.13.0rc2", -) - -load("@xla//:workspace4.bzl", "xla_workspace4") -xla_workspace4() -load("@xla//:workspace3.bzl", "xla_workspace3") -xla_workspace3() +install_deps() load("@xla//:workspace2.bzl", "xla_workspace2") + xla_workspace2() load("@xla//:workspace1.bzl", "xla_workspace1") + xla_workspace1() load("@xla//:workspace0.bzl", "xla_workspace0") + xla_workspace0() load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo") + flatbuffers() +load("//:test_shard_count.bzl", "test_shard_count_repository") + +test_shard_count_repository( + name = "test_shard_count", +) + load("//jaxlib:jax_python_wheel.bzl", "jax_python_wheel_repository") + jax_python_wheel_repository( name = "jax_wheel", version_key = "_version", @@ -71,14 +115,21 @@ jax_python_wheel_repository( load( "@xla//third_party/py:python_wheel.bzl", + "nvidia_wheel_versions_repository", "python_wheel_version_suffix_repository", ) + +nvidia_wheel_versions_repository( + name = "nvidia_wheel_versions", + versions_source = "//build:nvidia-requirements.txt", +) + python_wheel_version_suffix_repository( name = "jax_wheel_version_suffix", ) load( - "@xla//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl", + "@rules_ml_toolchain//gpu/cuda:cuda_json_init_repository.bzl", "cuda_json_init_repository", ) @@ -90,7 +141,7 @@ load( "CUDNN_REDISTRIBUTIONS", ) load( - "@xla//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl", + "@rules_ml_toolchain//gpu/cuda:cuda_redist_init_repositories.bzl", "cuda_redist_init_repositories", "cudnn_redist_init_repository", ) @@ -104,22 +155,42 @@ cudnn_redist_init_repository( ) load( - "@xla//third_party/gpus/cuda/hermetic:cuda_configure.bzl", + "@rules_ml_toolchain//gpu/cuda:cuda_configure.bzl", "cuda_configure", ) cuda_configure(name = "local_config_cuda") load( - "@xla//third_party/nccl/hermetic:nccl_redist_init_repository.bzl", + "@rules_ml_toolchain//gpu/nccl:nccl_redist_init_repository.bzl", "nccl_redist_init_repository", ) nccl_redist_init_repository() load( - "@xla//third_party/nccl/hermetic:nccl_configure.bzl", + "@rules_ml_toolchain//gpu/nccl:nccl_configure.bzl", "nccl_configure", ) nccl_configure(name = "local_config_nccl") + +load( + "@rules_ml_toolchain//gpu/nvshmem:nvshmem_json_init_repository.bzl", + "nvshmem_json_init_repository", +) + +nvshmem_json_init_repository() + +load( + "@nvshmem_redist_json//:distributions.bzl", + "NVSHMEM_REDISTRIBUTIONS", +) +load( + "@rules_ml_toolchain//gpu/nvshmem:nvshmem_redist_init_repository.bzl", + "nvshmem_redist_init_repository", +) + +nvshmem_redist_init_repository( + nvshmem_redistributions = NVSHMEM_REDISTRIBUTIONS, +) diff --git a/benchmarks/api_benchmark.py b/benchmarks/api_benchmark.py index cabebce2227c..51b0b2c0ac7d 100644 --- a/benchmarks/api_benchmark.py +++ b/benchmarks/api_benchmark.py @@ -72,9 +72,10 @@ class AnEnum(enum.IntEnum): @google_benchmark.register def eager_unary_dispatch(state): a = jax.device_put(1) - lax.neg(a) + x = lax.neg(a) while state: - lax.neg(a) + x = lax.neg(a) + x.block_until_ready() @google_benchmark.register @@ -98,9 +99,10 @@ def eager_binary_dispatch(state): def eager_binary(state): a = jax.device_put(1) b = jax.device_put(2) - lax.add(a, b).block_until_ready() + x = lax.add(a, b).block_until_ready() while state: - lax.add(a, b).block_until_ready() + x = lax.add(a, b).block_until_ready() + x.block_until_ready() @google_benchmark.register @@ -131,10 +133,11 @@ def jit_simple_dispatch(state): a = jax.device_put(1) b = jax.device_put(2) f = jax.jit(operator.add) - f(a, b) + x = f(a, b) while state: - f(a, b) + x = f(a, b) + x.block_until_ready() @google_benchmark.register @@ -152,10 +155,11 @@ def jit_simple_dispatch_array(state): a = jax.device_put(1) b = jax.device_put(2) f = jax.jit(operator.add) - f(a, b) + x = f(a, b) while state: - f(a, b) + x = f(a, b) + x.block_until_ready() @google_benchmark.register @@ -205,7 +209,7 @@ def jit_big_matmul(state): @google_benchmark.option.args([2000]) def jit_simple_many_args_dispatch(state): args = [jax.device_put(i) for i in range(state.range(0))] - f = jax.jit(lambda xs: functools.reduce(operator.add, xs)) + f = jax.jit(sum) x = f(args) x.block_until_ready() @@ -225,7 +229,7 @@ def jit_simple_many_args_dispatch(state): @google_benchmark.option.args([2000]) def jit_simple_many_args(state): args = [jax.device_put(i) for i in range(state.range(0))] - f = jax.jit(lambda xs: functools.reduce(operator.add, xs)) + f = jax.jit(sum) f(args).block_until_ready() while state: @@ -269,10 +273,11 @@ def jit_dispatch_without_transfer(state): imgs = jax.device_put(imgs) f = jax.jit(lambda x: x+1) - f(imgs) + x = f(imgs) while state: - f(imgs) + x = f(imgs) + x.block_until_ready() @google_benchmark.register @@ -280,7 +285,7 @@ def jit_dispatch_with_transfer(state): imgs = np.ones((128, 224, 224), np.float32) f = jax.jit(lambda x: x+1) - f(imgs).block_until_ready() + x = f(imgs).block_until_ready() while state: x = f(imgs) @@ -308,6 +313,8 @@ def pmap_trivial_dispatch_8_devices(state): while state: a, b = f(a, b) + a.block_until_ready() + b.block_until_ready() @google_benchmark.register @@ -344,6 +351,8 @@ def pmap_simple_dispatch_8_devices(state): while state: a, b = f(a, b) + a.block_until_ready() + b.block_until_ready() @google_benchmark.register @@ -371,6 +380,7 @@ def pmap_simple_dispatch_8_devices_100_args(state): while state: args = f(*args) + args[0].block_until_ready() @google_benchmark.register @@ -395,6 +405,7 @@ def _run_sda_index_bench(state, num_devices): while state: for i in range(num_devices): _ = x[i] + x.block_until_ready() @google_benchmark.register @@ -450,7 +461,7 @@ def bench_xla_abstractify(): @google_benchmark.register @google_benchmark.option.unit(google_benchmark.kMicrosecond) -def bench_are_op_shardings_equal(state): +def bench_are_hlo_shardings_equal(state): op1 = xc.OpSharding() op1.type = xc.OpSharding.Type.OTHER op1.tile_assignment_dimensions = [4, 192, 16] @@ -461,8 +472,11 @@ def bench_are_op_shardings_equal(state): op2.tile_assignment_dimensions = [4, 192, 16] op2.tile_assignment_devices = list(range(12288)) + hs1 = xc.HloSharding.from_proto(op1) + hs2 = xc.HloSharding.from_proto(op2) + while state: - op_shardings.are_op_shardings_equal(op1, op2) + op_shardings.are_hlo_shardings_equal(hs1, hs2) @google_benchmark.register @@ -592,6 +606,7 @@ def pjit_simple_benchmark(state, num_devices, num_args, use_aot=False): while state: x = f(x) + x[0].block_until_ready() @google_benchmark.register @@ -689,9 +704,8 @@ def device_put_from_numpy_array(state): @google_benchmark.option.args([10]) @google_benchmark.option.args([100]) @google_benchmark.option.args([1000]) +@required_devices(2) def device_put_from_jax_array(state): - if len(jax.devices()) < 2: - state.skip_with_error('requires 2 devices') x = [np.array(1, np.int32)] * state.range(0) x = jax.block_until_ready(jax.device_put(x, device=jax.devices()[0])) d = jax.devices()[1] @@ -847,7 +861,7 @@ def safe_map(state): args = tuple(list(range(state.range(0))) for _ in range(state.range(1))) def f(*args): return tuple(args) while state: - jax.util.safe_map(f, *args) + jax._src.util.safe_map(f, *args) @google_benchmark.register @google_benchmark.option.arg_names(['arg_lengths', 'num_args']) @@ -855,7 +869,7 @@ def f(*args): return tuple(args) def safe_zip(state): args = tuple(list(range(state.range(0))) for _ in range(state.range(1))) while state: - jax.util.safe_zip(*args) + jax._src.util.safe_zip(*args) @google_benchmark.register diff --git a/benchmarks/mosaic/BUILD b/benchmarks/mosaic/BUILD index 39c7aa5f3395..0b14c147d571 100644 --- a/benchmarks/mosaic/BUILD +++ b/benchmarks/mosaic/BUILD @@ -35,7 +35,7 @@ jax_multiplatform_test( enable_configs = ["gpu_h100"], tags = ["notap"], deps = [ - "//jax:mosaic_gpu", + "//jax/experimental:mosaic_gpu", "//jax/experimental/mosaic/gpu/examples:matmul", "//third_party/py/google_benchmark", ] + py_deps("absl/testing") + py_deps("numpy"), diff --git a/benchmarks/mosaic/matmul_bench.py b/benchmarks/mosaic/matmul_bench.py index 32c147916407..fd3fcd6da315 100644 --- a/benchmarks/mosaic/matmul_bench.py +++ b/benchmarks/mosaic/matmul_bench.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Microbenchmarks for mosaic gpu matrix mutliplication.""" +"""Microbenchmarks for mosaic gpu matrix multiplication.""" import functools import sys diff --git a/benchmarks/sparse_benchmark.py b/benchmarks/sparse_benchmark.py index d6328881d5c6..0ffb2aed5125 100644 --- a/benchmarks/sparse_benchmark.py +++ b/benchmarks/sparse_benchmark.py @@ -21,7 +21,13 @@ import jax from jax.experimental import sparse -def _sparse_bcoo_fromdense(state, jit: bool = False, compile: bool = False): + +def _sparse_fromdense( + state, + bcsr: bool = False, + jit: bool = False, + compile: bool = False, +): shape = (2000, 2000) nse = 10000 size = math.prod(shape) @@ -32,7 +38,7 @@ def _sparse_bcoo_fromdense(state, jit: bool = False, compile: bool = False): ) mat = jnp.zeros(shape).at[indices].set(data) - f = sparse.BCOO.fromdense + f = sparse.BCSR.fromdense if bcsr else sparse.BCOO.fromdense if compile or jit: # Note: nse must be specified for JIT. f = jax.jit(partial(f, nse=nse)) @@ -49,22 +55,12 @@ def _sparse_bcoo_fromdense(state, jit: bool = False, compile: bool = False): f(mat).block_until_ready() -@google_benchmark.register -def sparse_bcoo_fromdense(state): - return _sparse_bcoo_fromdense(state) - - -@google_benchmark.register -def sparse_bcoo_fromdense_jit(state): - return _sparse_bcoo_fromdense(state, jit=True) - - -@google_benchmark.register -def sparse_bcoo_fromdense_compile(state): - return _sparse_bcoo_fromdense(state, compile=True) - - -def _sparse_bcoo_todense(state, jit: bool = False, compile: bool = False): +def _sparse_todense( + state, + bcsr: bool = False, + jit: bool = False, + compile: bool = False, +): shape = (2000, 2000) nse = 10000 size = math.prod(shape) @@ -74,6 +70,8 @@ def _sparse_bcoo_todense(state, jit: bool = False, compile: bool = False): rng.choice(size, size=nse, replace=False), shape=shape ) mat = sparse.BCOO((jnp.array(data), jnp.column_stack(indices)), shape=shape) + if bcsr: + mat = sparse.BCSR.from_bcoo(mat) f = lambda mat: mat.todense() if jit or compile: @@ -91,22 +89,12 @@ def _sparse_bcoo_todense(state, jit: bool = False, compile: bool = False): f(mat).block_until_ready() -@google_benchmark.register -def sparse_bcoo_todense(state): - return _sparse_bcoo_todense(state) - - -@google_benchmark.register -def sparse_bcoo_todense_jit(state): - return _sparse_bcoo_todense(state, jit=True) - - -@google_benchmark.register -def sparse_bcoo_todense_compile(state): - return _sparse_bcoo_todense(state, compile=True) - - -def _sparse_bcoo_matvec(state, jit: bool = False, compile: bool = False): +def _sparse_matvec( + state, + bcsr: bool = False, + jit: bool = False, + compile: bool = False, +): shape = (2000, 2000) nse = 10000 key = jax.random.key(1701) @@ -118,6 +106,9 @@ def _sparse_bcoo_matvec(state, jit: bool = False, compile: bool = False): indices_dtype=jnp.int32, sorted_indices=True, ) + if bcsr: + mat = sparse.BCSR.from_bcoo(mat) + vec = jax.random.uniform(key, shape=(shape[1],), dtype=jnp.float32) f = lambda mat, vec: mat @ vec @@ -136,19 +127,94 @@ def _sparse_bcoo_matvec(state, jit: bool = False, compile: bool = False): f(mat, vec).block_until_ready() +@google_benchmark.register +def sparse_bcoo_fromdense(state): + return _sparse_fromdense(state) + + +@google_benchmark.register +def sparse_bcoo_fromdense_jit(state): + return _sparse_fromdense(state, jit=True) + + +@google_benchmark.register +def sparse_bcoo_fromdense_compile(state): + return _sparse_fromdense(state, compile=True) + + +@google_benchmark.register +def sparse_bcoo_todense(state): + return _sparse_todense(state) + + +@google_benchmark.register +def sparse_bcoo_todense_jit(state): + return _sparse_todense(state, jit=True) + + +@google_benchmark.register +def sparse_bcoo_todense_compile(state): + return _sparse_todense(state, compile=True) + + @google_benchmark.register def sparse_bcoo_matvec(state): - return _sparse_bcoo_matvec(state) + return _sparse_matvec(state) @google_benchmark.register def sparse_bcoo_matvec_jit(state): - return _sparse_bcoo_matvec(state, jit=True) + return _sparse_matvec(state, jit=True) @google_benchmark.register def sparse_bcoo_matvec_compile(state): - return _sparse_bcoo_matvec(state, compile=True) + return _sparse_matvec(state, compile=True) + + +@google_benchmark.register +def sparse_bscr_fromdense(state): + return _sparse_fromdense(state, bcsr=True) + + +@google_benchmark.register +def sparse_bscr_fromdense_jit(state): + return _sparse_fromdense(state, bcsr=True, jit=True) + + +@google_benchmark.register +def sparse_bscr_fromdense_compile(state): + return _sparse_fromdense(state, bcsr=True, compile=True) + + +@google_benchmark.register +def sparse_bscr_todense(state): + return _sparse_todense(state, bcsr=True) + + +@google_benchmark.register +def sparse_bscr_todense_jit(state): + return _sparse_todense(state, bcsr=True, jit=True) + + +@google_benchmark.register +def sparse_bscr_todense_compile(state): + return _sparse_todense(state, bcsr=True, compile=True) + + +@google_benchmark.register +def sparse_bcsr_matvec(state): + return _sparse_matvec(state, bcsr=True) + + +@google_benchmark.register +def sparse_bcsr_matvec_jit(state): + return _sparse_matvec(state, bcsr=True, jit=True) + + +@google_benchmark.register +def sparse_bcsr_matvec_compile(state): + return _sparse_matvec(state, bcsr=True, compile=True) if __name__ == "__main__": diff --git a/benchmarks/tracing_benchmark.py b/benchmarks/tracing_benchmark.py new file mode 100644 index 000000000000..22095c2af25b --- /dev/null +++ b/benchmarks/tracing_benchmark.py @@ -0,0 +1,162 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Benchmarks for Jax tracing.""" +import functools + +import google_benchmark +import jax +from jax import random +from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel as splash +from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib +import jax.numpy as jnp +import numpy as np + + +def clear_caches(state): + state.pause_timing() + jax.clear_caches() + state.resume_timing() + + +def make_mqa_splash_attention_fn_and_args(): + seed = 0 + key = random.key(seed) + k1, k2, k3 = random.split(key, 3) + + q_seq_len = 1024 + kv_seq_len = 1024 + num_q_heads = 2 + head_dim_qk = 128 + head_dim_v = 128 + dtype = np.dtype("float32") + + q = random.uniform(k1, (num_q_heads, q_seq_len, head_dim_qk), dtype=dtype) + k = random.uniform(k2, (kv_seq_len, head_dim_qk), dtype=dtype) + v = random.uniform(k3, (kv_seq_len, head_dim_v), dtype=dtype) + + mask = mask_lib.NumpyMask( + mask_lib.make_random_mask((q_seq_len, kv_seq_len), sparsity=0.5, seed=0) + ) + mask = mask_lib.MultiHeadMask(tuple(mask for _ in range(num_q_heads))) + block_sizes = splash.BlockSizes.get_default() + + return ( + jax.jit( + splash.make_splash_mqa_single_device(mask, block_sizes=block_sizes) + ) + ), (q, k, v) + + +@google_benchmark.register +@google_benchmark.option.unit(google_benchmark.kMillisecond) +def test_pallas_mqa_splash_attention_trace(state): + attn, (q, k, v) = make_mqa_splash_attention_fn_and_args() + + while state: + _ = attn.trace(q, k, v) + clear_caches(state) + + +@google_benchmark.register +@google_benchmark.option.unit(google_benchmark.kMillisecond) +def test_pallas_mqa_splash_attention_trace_no_cache_clear(state): + attn, (q, k, v) = make_mqa_splash_attention_fn_and_args() + + while state: + _ = attn.trace(q, k, v) + + +@google_benchmark.register +@google_benchmark.option.unit(google_benchmark.kMillisecond) +def test_pallas_mqa_splash_attention_lower(state): + attn, (q, k, v) = make_mqa_splash_attention_fn_and_args() + traced = attn.trace(q, k, v) + + while state: + _ = traced.lower(lowering_platforms=("tpu",)) + clear_caches(state) + + +@google_benchmark.register +@google_benchmark.option.unit(google_benchmark.kMillisecond) +def test_pallas_mqa_splash_attention_lower_no_cache_clear(state): + attn, (q, k, v) = make_mqa_splash_attention_fn_and_args() + traced = attn.trace(q, k, v) + + while state: + _ = traced.lower(lowering_platforms=("tpu",)) + + +@google_benchmark.register +@google_benchmark.option.unit(google_benchmark.kMillisecond) +def test_jnp_dot_trace(state): + fn = jax.jit(jnp.dot) + while state: + _ = fn.trace(jnp.arange(1024), jnp.arange(1024)) + clear_caches(state) + + +@google_benchmark.register +@google_benchmark.option.unit(google_benchmark.kMillisecond) +def test_jnp_dot_trace_no_cache_clear(state): + fn = jax.jit(jnp.dot) + while state: + _ = fn.trace(jnp.arange(1024), jnp.arange(1024)) + + +@google_benchmark.register +@google_benchmark.option.unit(google_benchmark.kMillisecond) +def test_jnp_concat_trace(state): + fn = jax.jit(functools.partial(jnp.concat, axis=0)) + while state: + _ = fn.trace((jnp.ones((1024, 1)), jnp.ones((1024, 1)))) + clear_caches(state) + + +@google_benchmark.register +@google_benchmark.option.unit(google_benchmark.kMillisecond) +def test_jnp_concat_trace_no_cache_clear(state): + fn = jax.jit(functools.partial(jnp.concat, axis=0)) + while state: + _ = fn.trace((jnp.ones((1024, 1)), jnp.ones((1024, 1)))) + + +@google_benchmark.register +@google_benchmark.option.unit(google_benchmark.kMillisecond) +# NOTE(dsuo): Linear spacing so it's easier to eyeball historical plots. +@google_benchmark.option.arg(1) +@google_benchmark.option.dense_range(128, 896, 128) +def test_num_multiply_eqns_trace(state): + fns = [lambda x: x * x for _ in range(state.range(0))] + fn = jax.jit(functools.reduce(lambda a, b: (lambda x: a(b(x))), fns)) + while state: + _ = fn.trace(jnp.ones((1024,))) + clear_caches(state) + + +@google_benchmark.register +@google_benchmark.option.unit(google_benchmark.kMillisecond) +# NOTE(dsuo): Linear spacing so it's easier to eyeball historical plots. +@google_benchmark.option.arg(1) +@google_benchmark.option.dense_range(128, 896, 128) +def test_num_multiply_eqns_trace_no_cache_clear(state): + fns = [lambda x: x * x for _ in range(state.range(0))] + fn = jax.jit(functools.reduce(lambda a, b: (lambda x: a(b(x))), fns)) + while state: + _ = fn.trace(jnp.ones((1024,))) + + +if __name__ == "__main__": + google_benchmark.main() diff --git a/build/BUILD.bazel b/build/BUILD.bazel index f088cd58aa74..4d242f57ccb4 100644 --- a/build/BUILD.bazel +++ b/build/BUILD.bazel @@ -13,55 +13,90 @@ # limitations under the License. # ============================================================================== -licenses(["notice"]) - -load("@python//:defs.bzl", "compile_pip_requirements") load("@python_version_repo//:py_version.bzl", "REQUIREMENTS") +load("@rules_python//python:pip.bzl", "compile_pip_requirements") +load("@rules_python//python:py_library.bzl", "py_library") load("//jaxlib:jax.bzl", "all_py_deps") -compile_pip_requirements( - name = "requirements", - extra_args = [ - "--allow-unsafe", - "--build-isolation", - "--rebuild", - ], - requirements_in = "requirements.in", - requirements_txt = REQUIREMENTS, - generate_hashes = True, - data = ["test-requirements.txt", "gpu-test-requirements.txt"] -) +licenses(["notice"]) -compile_pip_requirements( - name = "requirements_nightly", - extra_args = [ - "--allow-unsafe", - "--build-isolation", - "--extra-index-url=https://pypi.anaconda.org/scientific-python-nightly-wheels/simple", - "--pre", - "--upgrade" - ], - requirements_in = "requirements.in", - requirements_txt = REQUIREMENTS, - generate_hashes = False, - data = ["test-requirements.txt", "gpu-test-requirements.txt"] -) +COMMON_REQUIREMENTS = [ + "requirements.in", + "test-requirements.txt", + "nvidia-requirements.txt", +] -compile_pip_requirements( - name = "requirements_dev", - extra_args = [ - "--allow-unsafe", - "--build-isolation", - "--upgrade", - "--rebuild", - ], - requirements_in = "requirements.in", - requirements_txt = REQUIREMENTS, - generate_hashes = False, - data = ["test-requirements.txt", "gpu-test-requirements.txt"] -) +# It isn't possible to constraint based on free-threaded vs non-free threaded +# in a requirements file. So we do it by having two separate sets of requirement +# files and two sets of build rules. +FREETHREADING_REQUIREMENTS = COMMON_REQUIREMENTS + [ + "freethreading-requirements.txt", +] + +NON_FREETHREADING_REQUIREMENTS = COMMON_REQUIREMENTS + [ + "nonfreethreading-requirements.txt", +] + +COMBOS = [ + ("", NON_FREETHREADING_REQUIREMENTS), + ("_ft", FREETHREADING_REQUIREMENTS), +] + +[ + compile_pip_requirements( + name = "requirements" + suffix, + srcs = requirements, + extra_args = [ + "--allow-unsafe", + "--build-isolation", + "--rebuild", + "--index-url=https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple", + ], + generate_hashes = True, + requirements_txt = REQUIREMENTS, + ) + for suffix, requirements in COMBOS +] + +[ + compile_pip_requirements( + name = "requirements_nightly" + suffix, + srcs = requirements, + extra_args = [ + "--allow-unsafe", + "--build-isolation", + "--index-url=https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple", + "--extra-index-url=https://pypi.anaconda.org/scientific-python-nightly-wheels/simple", + "--pre", + "--upgrade", + ], + generate_hashes = False, + requirements_txt = REQUIREMENTS, + ) + for suffix, requirements in COMBOS +] + +[ + compile_pip_requirements( + name = "requirements_dev" + suffix, + srcs = requirements, + extra_args = [ + "--allow-unsafe", + "--build-isolation", + "--upgrade", + "--rebuild", + "--index-url=https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple", + ], + generate_hashes = False, + requirements_txt = REQUIREMENTS, + ) + for suffix, requirements in COMBOS +] py_library( name = "all_py_deps", - deps = all_py_deps(["zstandard"]), -) \ No newline at end of file + deps = all_py_deps([ + "zstandard", + "tensorstore", + ]), +) diff --git a/build/build.py b/build/build.py index d38b911bb904..eacf3ae612f6 100755 --- a/build/build.py +++ b/build/build.py @@ -1,4 +1,4 @@ -#!/usr/bin/python +#!/usr/bin/env python3 # # Copyright 2018 The JAX Authors. # @@ -56,23 +56,18 @@ # Define the build target for each wheel. WHEEL_BUILD_TARGET_DICT = { - "jaxlib": "//jaxlib/tools:build_wheel", - "jax-cuda-plugin": "//jaxlib/tools:build_gpu_kernels_wheel", - "jax-cuda-pjrt": "//jaxlib/tools:build_gpu_plugin_wheel", - "jax-rocm-plugin": "//jaxlib/tools:build_gpu_kernels_wheel", - "jax-rocm-pjrt": "//jaxlib/tools:build_gpu_plugin_wheel", -} - -# Dictionary with the new wheel build rule. Note that when JAX migrates to the -# new wheel build rule fully, the build CLI will switch to the new wheel build -# rule as the default. -WHEEL_BUILD_TARGET_DICT_NEW = { "jax": "//:jax_wheel", + "jax_editable": "//:jax_wheel_editable", + "jax_source_package": "//:jax_source_package", "jaxlib": "//jaxlib/tools:jaxlib_wheel", - "jax-cuda-plugin": "//jaxlib/tools:jax_cuda_plugin_wheel", - "jax-cuda-pjrt": "//jaxlib/tools:jax_cuda_pjrt_wheel", + "jaxlib_editable": "//jaxlib/tools:jaxlib_wheel_editable", + "jax-cuda-plugin": "//jaxlib/tools:jax_cuda{cuda_major_version}_plugin_wheel", + "jax-cuda-plugin_editable": "//jaxlib/tools:jax_cuda{cuda_major_version}_plugin_wheel_editable", + "jax-cuda-pjrt": "//jaxlib/tools:jax_cuda{cuda_major_version}_pjrt_wheel", + "jax-cuda-pjrt_editable": "//jaxlib/tools:jax_cuda{cuda_major_version}_pjrt_wheel_editable", "jax-rocm-plugin": "//jaxlib/tools:jax_rocm_plugin_wheel", "jax-rocm-pjrt": "//jaxlib/tools:jax_rocm_pjrt_wheel", + "mosaic-gpu-cuda": "//jaxlib/tools:mosaic_gpu_wheel_cuda{cuda_major_version}", } def add_global_arguments(parser: argparse.ArgumentParser): @@ -163,8 +158,8 @@ def add_artifact_subcommand_arguments(parser: argparse.ArgumentParser): action="store_true", help= """ - Whether to use the new wheel build rule. Temporary flag and will be - removed once JAX migrates to the new wheel build rule fully. + DEPRECATED: Whether to use the new wheel build rule. Temporary flag and + will be removed once JAX migrates to the new wheel build rule fully. """, ) @@ -278,13 +273,12 @@ def add_artifact_subcommand_arguments(parser: argparse.ArgumentParser): compile_group.add_argument( "--use_clang", - type=utils._parse_string_as_bool, - default="true", - const=True, + type=str, + default="", nargs="?", help=""" - Whether to use Clang as the compiler. Not recommended to set this to - False as JAX uses Clang as the default compiler. + DEPRECATED: Whether to use Clang as the compiler. Not recommended to + set this flag because Clang is the default compiler. """, ) @@ -302,7 +296,7 @@ def add_artifact_subcommand_arguments(parser: argparse.ArgumentParser): type=str, default="", help=""" - Path to the GCC binary to use. + DEPRECATED: Path to the GCC binary to use. """, ) @@ -374,7 +368,7 @@ async def main(): # Artifact build subcommand build_artifact_parser = subparsers.add_parser( - "build", help="Builds the jaxlib, plugin, and pjrt artifact" + "build", help="Builds the jaxlib, plugin, mosaic, and pjrt artifact" ) add_artifact_subcommand_arguments(build_artifact_parser) add_global_arguments(build_artifact_parser) @@ -382,6 +376,11 @@ async def main(): arch = platform.machine() os_name = platform.system().lower() + custom_wheel_version_suffix = "" + wheel_build_date = "" + wheel_git_hash = "" + wheel_type = "snapshot" + args = parser.parse_args() logger.info("%s", BANNER) @@ -407,16 +406,17 @@ async def main(): for option in args.bazel_startup_options: bazel_command_base.append(option) - if not args.use_new_wheel_build_rule or args.command == "requirements_update": + if args.command == "requirements_update": bazel_command_base.append("run") else: bazel_command_base.append("build") + freethreaded = False if args.python_version: # Do not add --repo_env=HERMETIC_PYTHON_VERSION with default args.python_version # if bazel_options override it python_version_opt = "--repo_env=HERMETIC_PYTHON_VERSION=" - if any([python_version_opt in opt for opt in args.bazel_options]): + if any(python_version_opt in opt for opt in args.bazel_options): raise RuntimeError( "Please use python_version to set hermetic python version instead of " "setting --repo_env=HERMETIC_PYTHON_VERSION= bazel option" @@ -427,8 +427,9 @@ async def main(): ) # Let's interpret X.YY-ft version as free-threading python and set rules_python config flag: if args.python_version.endswith("-ft"): + freethreaded = True bazel_command_base.append( - "--@rules_python//python/config_settings:py_freethreaded='yes'" + "--@rules_python//python/config_settings:py_freethreaded=\"yes\"" ) # Enable verbose failures. @@ -444,14 +445,15 @@ async def main(): for option in args.bazel_options: requirements_command.append(option) + ft_suffix = "_ft" if freethreaded else "" if args.nightly_update: logging.info( "--nightly_update is set. Bazel will run" " //build:requirements_nightly.update" ) - requirements_command.append("//build:requirements_nightly.update") + requirements_command.append(f"//build:requirements{ft_suffix}_nightly.update") else: - requirements_command.append("//build:requirements.update") + requirements_command.append(f"//build:requirements{ft_suffix}.update") result = await executor.run(requirements_command.get_command_as_string(), args.dry_run, args.detailed_timestamped_log) if result.return_code != 0: @@ -459,6 +461,36 @@ async def main(): else: sys.exit(0) + if args.use_new_wheel_build_rule: + logger.warning( + "The --use_new_wheel_build_rule flag is deprecated and no-op. It will" + " be removed soon. Please remove it from your build scripts." + ) + if args.use_clang: + logger.warning( + "The --use_clang flag is deprecated and no-op. It will be removed soon." + " Please remove it from your build scripts. Clang is the only" + " acceptable compiler." + ) + if args.gcc_path: + logger.warning( + "The --gcc_path flag is deprecated and no-op. It will be removed soon." + " Please remove it from your build scripts. Clang is the only" + " acceptable compiler." + ) + + wheels = args.wheels.split(",") + for wheel in wheels: + if wheel not in WHEEL_BUILD_TARGET_DICT.keys(): + logging.error( + "Incorrect wheel name %s provided, valid choices are jaxlib," + " jax-cuda-plugin or cuda-plugin, jax-cuda-pjrt or cuda-pjrt," + " jax-rocm-plugin or rocm-plugin, jax-rocm-pjrt or rocm-pjrt," + " or mosaic-gpu", + wheel, + ) + sys.exit(1) + wheel_build_command_base = copy.deepcopy(bazel_command_base) wheel_cpus = { @@ -483,12 +515,18 @@ async def main(): logging.debug("Disabling NCCL") wheel_build_command_base.append("--config=nonccl") - git_hash = utils.get_githash() - clang_path = "" - if args.use_clang: + clang_local = args.clang_path or not (utils.is_linux_x86_64(arch, os_name) + or utils.is_linux_aarch64(arch, os_name)) + if clang_local: + if "cuda" in args.wheels and utils.is_linux(os_name): + wheel_build_command_base.append("--config=cuda_clang_local") + else: + wheel_build_command_base.append("--config=clang_local") + clang_path = args.clang_path or utils.get_clang_path_or_exit() clang_major_version = utils.get_clang_major_version(clang_path) + clangpp_path = utils.get_clangpp_path(clang_path) logging.debug( "Using Clang as the compiler, clang path: %s, clang version: %s", clang_path, @@ -498,6 +536,7 @@ async def main(): # Use double quotes around clang path to avoid path issues on Windows. wheel_build_command_base.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"") wheel_build_command_base.append(f"--repo_env=CC=\"{clang_path}\"") + wheel_build_command_base.append(f"--repo_env=CXX=\"{clangpp_path}\"") wheel_build_command_base.append(f"--repo_env=BAZEL_COMPILER=\"{clang_path}\"") if clang_major_version >= 16: @@ -506,19 +545,10 @@ async def main(): wheel_build_command_base.append("--config=clang") if clang_major_version < 19: wheel_build_command_base.append("--define=xnn_enable_avxvnniint8=false") - else: - gcc_path = args.gcc_path or utils.get_gcc_path_or_exit() - logging.debug( - "Using GCC as the compiler, gcc path: %s", - gcc_path, - ) - wheel_build_command_base.append(f"--repo_env=CC=\"{gcc_path}\"") - wheel_build_command_base.append(f"--repo_env=BAZEL_COMPILER=\"{gcc_path}\"") - - gcc_major_version = utils.get_gcc_major_version(gcc_path) - if gcc_major_version < 13: - wheel_build_command_base.append("--define=xnn_enable_avxvnniint8=false") + # TODO:(yuriit) Check version of Clang when it will be available outside + # of rules_ml_toolchain. Current hermetic Clang version is 18 + wheel_build_command_base.append("--define=xnn_enable_avxvnniint8=false") if not args.disable_mkl_dnn: logging.debug("Enabling MKL DNN") @@ -554,19 +584,21 @@ async def main(): logging.error("CUDA and ROCm cannot be enabled at the same time.") sys.exit(1) + if args.cuda_version: + cuda_major_version = args.cuda_version.split(".")[0] + else: + cuda_major_version = args.cuda_major_version + if "cuda" in args.wheels: - wheel_build_command_base.append("--config=cuda") - wheel_build_command_base.append("--config=cuda_libraries_from_stubs") - if args.use_clang: + wheel_build_command_base.append(f"--config=cuda{cuda_major_version}") + + if clang_local: wheel_build_command_base.append( f"--action_env=CLANG_CUDA_COMPILER_PATH=\"{clang_path}\"" ) - if args.build_cuda_with_clang: - logging.debug("Building CUDA with Clang") - wheel_build_command_base.append("--config=build_cuda_with_clang") - else: - logging.debug("Building CUDA with NVCC") - wheel_build_command_base.append("--config=build_cuda_with_nvcc") + if args.build_cuda_with_clang: + logging.debug("Building CUDA with Clang") + wheel_build_command_base.append("--config=build_cuda_with_clang") else: logging.debug("Building CUDA with NVCC") wheel_build_command_base.append("--config=build_cuda_with_nvcc") @@ -591,12 +623,16 @@ async def main(): ) if "rocm" in args.wheels: + if not args.configure_only: + print("ERROR: This repo is not used for building the ROCm JAX plugins. Please use the new plugin repo: https://github.com/ROCm/rocm-jax") + exit(1) + wheel_build_command_base.append("--config=rocm_base") - if args.use_clang: - wheel_build_command_base.append("--config=rocm") + wheel_build_command_base.append("--config=rocm") + if clang_local: wheel_build_command_base.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"") if args.rocm_path: - logging.debug("ROCm tookit path: %s", args.rocm_path) + logging.debug("ROCm toolkit path: %s", args.rocm_path) wheel_build_command_base.append(f"--action_env=ROCM_PATH=\"{args.rocm_path}\"") if args.rocm_amdgpu_targets: logging.debug("ROCm AMD GPU targets: %s", args.rocm_amdgpu_targets) @@ -612,28 +648,33 @@ async def main(): ) for option in args.bazel_options: wheel_build_command_base.append(option) - if "cuda" in args.wheels: - wheel_build_command_base.append("--config=cuda_libraries_from_stubs") - with open(".jax_configure.bazelrc", "w") as f: - jax_configure_options = utils.get_jax_configure_bazel_options(wheel_build_command_base.get_command_as_list(), args.use_new_wheel_build_rule) + # Parse the build options for the wheel version suffix. + if "ML_WHEEL_TYPE" in option: + wheel_type = option.split("=")[-1] + if "ML_WHEEL_VERSION_SUFFIX" in option: + custom_wheel_version_suffix = option.split("=")[-1].replace("-", "") + if "ML_WHEEL_BUILD_DATE" in option: + wheel_build_date = option.split("=")[-1].replace("-", "") + if "ML_WHEEL_GIT_HASH" in option: + # Strip leading zeros as they end up being stripped by setuptools, + # which leads to a mismatch between expected and actual wheel names + # https://peps.python.org/pep-0440/ + wheel_git_hash = option.split("=")[-1].lstrip('0')[:9] + + with open(".jax_configure.bazelrc", "w") as f: # noqa: ASYNC230 + jax_configure_options = utils.get_jax_configure_bazel_options(wheel_build_command_base.get_command_as_list()) if not jax_configure_options: logging.error("Error retrieving the Bazel options to be written to .jax_configure.bazelrc, exiting.") sys.exit(1) f.write(jax_configure_options) logging.info("Bazel options written to .jax_configure.bazelrc") - if args.use_new_wheel_build_rule: - logging.info("Using new wheel build rule") - wheel_build_targets = WHEEL_BUILD_TARGET_DICT_NEW - else: - wheel_build_targets = WHEEL_BUILD_TARGET_DICT - if args.configure_only: logging.info("--configure_only is set so not running any Bazel commands.") else: # Wheel build command execution - for wheel in args.wheels.split(","): + for wheel in wheels: output_path = args.output_path logger.debug("Artifacts output directory: %s", output_path) @@ -641,15 +682,10 @@ async def main(): if ("plugin" in wheel or "pjrt" in wheel) and "jax" not in wheel: wheel = "jax-" + wheel - if wheel not in wheel_build_targets.keys(): - logging.error( - "Incorrect wheel name provided, valid choices are jaxlib," - " jax-cuda-plugin or cuda-plugin, jax-cuda-pjrt or cuda-pjrt," - " jax-rocm-plugin or rocm-plugin, jax-rocm-pjrt or rocm-pjrt" - ) - sys.exit(1) + wheel_build_command = copy.deepcopy(bazel_command_base) + if "cuda" in args.wheels: + wheel_build_command.append("--config=cuda_libraries_from_stubs") - wheel_build_command = copy.deepcopy(wheel_build_command_base) print("\n") logger.info( "Building %s for %s %s...", @@ -659,39 +695,80 @@ async def main(): ) # Append the build target to the Bazel command. - build_target = wheel_build_targets[wheel] + if args.editable: + build_target = WHEEL_BUILD_TARGET_DICT[wheel + "_editable"] + else: + build_target = WHEEL_BUILD_TARGET_DICT[wheel] + build_target = build_target.format(cuda_major_version=cuda_major_version) wheel_build_command.append(build_target) + if wheel == "jax" and not args.editable: + wheel_build_command.append( + WHEEL_BUILD_TARGET_DICT["jax_source_package"] + ) - if not args.use_new_wheel_build_rule: - wheel_build_command.append("--") - - if args.editable: - logger.info("Building an editable build") - output_path = os.path.join(output_path, wheel) - wheel_build_command.append("--editable") - - wheel_build_command.append(f'--output_path="{output_path}"') - wheel_build_command.append(f"--cpu={target_cpu}") - - if "cuda" in wheel: - wheel_build_command.append("--enable-cuda=True") - if args.cuda_version: - cuda_major_version = args.cuda_version.split(".")[0] - else: - cuda_major_version = args.cuda_major_version - wheel_build_command.append(f"--platform_version={cuda_major_version}") - - if "rocm" in wheel: - wheel_build_command.append("--enable-rocm=True") - wheel_build_command.append(f"--platform_version={args.rocm_version}") - - wheel_build_command.append(f"--jaxlib_git_hash={git_hash}") + # If we build jax wheel, we don't need to build jaxlib targets. + if wheel == "jax": + wheel_build_command.append("--//jax:build_jaxlib=false") result = await executor.run(wheel_build_command.get_command_as_string(), args.dry_run, args.detailed_timestamped_log) # Exit with error if any wheel build fails. if result.return_code != 0: raise RuntimeError(f"Command failed with return code {result.return_code}") + output_path = args.output_path + jax_bazel_dir = os.path.join("bazel-bin", "dist") + jaxlib_and_plugins_bazel_dir = os.path.join( + "bazel-bin", "jaxlib", "tools", "dist" + ) + for wheel in wheels: + if wheel == "jax": + bazel_dir = jax_bazel_dir + else: + bazel_dir = jaxlib_and_plugins_bazel_dir + if "cuda" in wheel: + wheel_dir = wheel.replace("cuda", f"cuda{cuda_major_version}").replace( + "-", "_" + ) + elif "rocm" in wheel: + if args.editable: + # For editable builds, use the actual ROCm version since directory paths cannot contain wildcards + wheel_dir = wheel.replace("rocm", f"rocm{args.rocm_version}").replace("-", "_") + else: + # For non-editable builds, use wildcard pattern to match any ROCm version in glob patterns + wheel_dir = wheel.replace("rocm", "rocm*").replace("-", "_") + else: + wheel_dir = wheel + + if args.editable: + src_dir = os.path.join(bazel_dir, wheel_dir) + dst_dir = os.path.join(output_path, wheel_dir) + utils.copy_dir_recursively(src_dir, dst_dir) + else: + wheel_version_suffix = "dev0+selfbuilt" + if wheel_type == "release": + wheel_version_suffix = custom_wheel_version_suffix + elif wheel_type in ["nightly", "custom"]: + wheel_version_suffix = f".dev{wheel_build_date}" + if wheel_type == "custom": + wheel_version_suffix += ( + f"+{wheel_git_hash}{custom_wheel_version_suffix}" + ) + if wheel in ["jax", "jax-cuda-pjrt", "jax-rocm-pjrt"]: + python_tag = "py" + else: + python_tag = "cp" + utils.copy_individual_files( + bazel_dir, + output_path, + f"{wheel_dir}*{wheel_version_suffix}-{python_tag}*.whl", + ) + if wheel == "jax": + utils.copy_individual_files( + bazel_dir, + output_path, + f"{wheel_dir}*{wheel_version_suffix}.tar.gz", + ) + # Exit with success if all wheels in the list were built successfully. sys.exit(0) diff --git a/build/collect-profile-requirements.txt b/build/collect-profile-requirements.txt index da25d4b6ffe1..a334f408e271 100644 --- a/build/collect-profile-requirements.txt +++ b/build/collect-profile-requirements.txt @@ -1,4 +1,5 @@ -tensorflow -tensorboard-plugin-profile -# Needed for the profile plugin to work without error +# TF hasn't released 3.13 wheels yet (b/402590302) +tensorflow; python_version<"3.13" +xprof>=2.19.0 +# Needed for XProf to work without error protobuf diff --git a/build/freethreading-requirements.txt b/build/freethreading-requirements.txt new file mode 100644 index 000000000000..e3a2f722d2b1 --- /dev/null +++ b/build/freethreading-requirements.txt @@ -0,0 +1,3 @@ +# Under free-threading, we need an up-to-date numpy at least for the moment. +numpy~=2.2.6; python_version=="3.13" +numpy>=2.3.2; python_version>="3.14" \ No newline at end of file diff --git a/build/gpu-test-requirements.txt b/build/gpu-test-requirements.txt deleted file mode 100644 index ff43f91ba90f..000000000000 --- a/build/gpu-test-requirements.txt +++ /dev/null @@ -1,13 +0,0 @@ -# NVIDIA CUDA dependencies -# Note that the wheels are downloaded only when the targets in bazel command -# contain dependencies on these wheels. -nvidia-cublas-cu12>=12.1.3.1 ; sys_platform == "linux" -nvidia-cuda-cupti-cu12>=12.1.105 ; sys_platform == "linux" -nvidia-cuda-nvcc-cu12>=12.6.85 ; sys_platform == "linux" -nvidia-cuda-runtime-cu12>=12.1.105 ; sys_platform == "linux" -nvidia-cudnn-cu12>=9.1,<10.0 ; sys_platform == "linux" -nvidia-cufft-cu12>=11.0.2.54 ; sys_platform == "linux" -nvidia-cusolver-cu12>=11.4.5.107 ; sys_platform == "linux" -nvidia-cusparse-cu12>=12.1.0.106 ; sys_platform == "linux" -nvidia-nccl-cu12>=2.18.1 ; sys_platform == "linux" -nvidia-nvjitlink-cu12>=12.1.105 ; sys_platform == "linux" diff --git a/build/nonfreethreading-requirements.txt b/build/nonfreethreading-requirements.txt new file mode 100644 index 000000000000..42567bc5411d --- /dev/null +++ b/build/nonfreethreading-requirements.txt @@ -0,0 +1,14 @@ +numpy~=2.0.0; python_version<="3.12" +numpy~=2.1.0; python_version=="3.13" +numpy>=2.3.2; python_version>="3.14" + +# These packages have not released free-threaded wheels. + +# zstandard is available in the Python 3.14 standard library as +# compression.zstd, and we'll use that if available. +zstandard; python_version<"3.14" + +tensorstore + +# For jax2tf_test +tensorflow; python_version<"3.14" \ No newline at end of file diff --git a/build/nvidia-requirements.txt b/build/nvidia-requirements.txt new file mode 100644 index 000000000000..2c53d19ab467 --- /dev/null +++ b/build/nvidia-requirements.txt @@ -0,0 +1,30 @@ +nvidia-cublas-cu12>=12.1.3.1 ; sys_platform == "linux" +nvidia-cuda-cupti-cu12>=12.1.105 ; sys_platform == "linux" +nvidia-cuda-nvcc-cu12>=12.6.85 ; sys_platform == "linux" +nvidia-cuda-runtime-cu12>=12.1.105 ; sys_platform == "linux" +# The upper bound is set for the CUDNN API compatibility. +# See +# https://docs.nvidia.com/deeplearning/cudnn/backend/latest/developer/forward-compatibility.html#cudnn-api-compatibility +nvidia-cudnn-cu12>=9.8,<10.0 ; sys_platform == "linux" +nvidia-cufft-cu12>=11.0.2.54 ; sys_platform == "linux" +nvidia-cusolver-cu12>=11.4.5.107 ; sys_platform == "linux" +nvidia-cusparse-cu12>=12.1.0.106 ; sys_platform == "linux" +nvidia-nccl-cu12>=2.18.1 ; sys_platform == "linux" +nvidia-nvjitlink-cu12>=12.1.105 ; sys_platform == "linux" +nvidia-cuda-nvrtc-cu12>=12.1.55 ; sys_platform == "linux" +nvidia-nvshmem-cu12>=3.2.5 ; sys_platform == "linux" + +nvidia-nccl-cu13>=2.27.7 ; sys_platform == "linux" +nvidia-nvshmem-cu13>=3.3.20 ; sys_platform == "linux" +nvidia-cublas>=13.0.0.19 ; sys_platform == "linux" +nvidia-cuda-cupti>=13.0.48 ; sys_platform == "linux" +nvidia-cuda-nvcc>=13.0.48 ; sys_platform == "linux" +nvidia-cuda-runtime>=13.0.48 ; sys_platform == "linux" +nvidia-cudnn-cu13>=9.12.0.46,<10.0 ; sys_platform == "linux" +nvidia-cufft>=12.0.0.15 ; sys_platform == "linux" +nvidia-cusolver>=12.0.3.29 ; sys_platform == "linux" +nvidia-cusparse>=12.6.2.49 ; sys_platform == "linux" +nvidia-nvjitlink>=13.0.39 ; sys_platform == "linux" +nvidia-cuda-nvrtc>=13.0.48 ; sys_platform == "linux" +nvidia-nvvm>=13.0.48 ; sys_platform == "linux" +nvidia-cuda-crt>=13.0.48 ; sys_platform == "linux" diff --git a/build/requirements.in b/build/requirements.in index ec7fc71b07e1..e37b6cec1659 100644 --- a/build/requirements.in +++ b/build/requirements.in @@ -1,22 +1,32 @@ -# -# test deps -# --r test-requirements.txt --r gpu-test-requirements.txt - -# -# build deps -# -numpy~=2.0.0; python_version<="3.12" -numpy~=2.1.0; python_version>="3.13" - # # runtime deps # -scipy>=1.13.1 +scipy>=1.13.1; python_version<="3.12" +scipy>=1.15.2; python_version>="3.13" -ml_dtypes>=0.4.0 -opt_einsum -zstandard +ml_dtypes>=0.5.3 etils[epath] +opt-einsum + +# Needed to build wheels +build setuptools +wheel + +# JAX's own libraries. We include these in the requirements so you can +# bazel test without building jaxlib and without manually updating the +# the requirements files. +jaxlib + +jax-cuda12-plugin; sys_platform == "linux" +jax-cuda13-plugin +jax-cuda12-pjrt; sys_platform == "linux" +jax-cuda13-pjrt + +# TPU dependencies +libtpu; sys_platform == "linux" and platform_machine == "x86_64" + +# Colorama is only needed on windows. Pip-compile ignores platform-specific +# requirements therefore we include it unconditionally here so that the lockfile +# regeneration always includes it. +colorama>=0.4.4 diff --git a/build/requirements_lock_3_10.txt b/build/requirements_lock_3_10.txt deleted file mode 100644 index 6ed6b59aa584..000000000000 --- a/build/requirements_lock_3_10.txt +++ /dev/null @@ -1,707 +0,0 @@ -# -# This file is autogenerated by pip-compile with Python 3.10 -# by the following command: -# -# bazel run //build:requirements.update -# -absl-py==2.1.0 \ - --hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \ - --hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff - # via -r build/test-requirements.txt -attrs==23.2.0 \ - --hash=sha256:935dc3b529c262f6cf76e50877d35a4bd3c1de194fd41f47a2b7ae8f19971f30 \ - --hash=sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1 - # via hypothesis -auditwheel==6.1.0 \ - --hash=sha256:3bdc686e774cf9e355e924b0fe5a562d55caa385d72234ffe7b81b378dba360f \ - --hash=sha256:e52f734861859e3743eb29fcac7da9c4921a1e4bea58f954b52f2926f8e9e364 - # via -r build/test-requirements.txt -build==1.2.1 \ - --hash=sha256:526263f4870c26f26c433545579475377b2b7588b6f1eac76a001e873ae3e19d \ - --hash=sha256:75e10f767a433d9a86e50d83f418e83efc18ede923ee5ff7df93b6cb0306c5d4 - # via -r build/test-requirements.txt -cloudpickle==3.0.0 \ - --hash=sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7 \ - --hash=sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882 - # via -r build/test-requirements.txt -colorama==0.4.6 \ - --hash=sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44 \ - --hash=sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6 - # via -r build/test-requirements.txt -contourpy==1.2.1 \ - --hash=sha256:00e5388f71c1a0610e6fe56b5c44ab7ba14165cdd6d695429c5cd94021e390b2 \ - --hash=sha256:10a37ae557aabf2509c79715cd20b62e4c7c28b8cd62dd7d99e5ed3ce28c3fd9 \ - --hash=sha256:11959f0ce4a6f7b76ec578576a0b61a28bdc0696194b6347ba3f1c53827178b9 \ - --hash=sha256:187fa1d4c6acc06adb0fae5544c59898ad781409e61a926ac7e84b8f276dcef4 \ - --hash=sha256:1a07fc092a4088ee952ddae19a2b2a85757b923217b7eed584fdf25f53a6e7ce \ - --hash=sha256:1cac0a8f71a041aa587410424ad46dfa6a11f6149ceb219ce7dd48f6b02b87a7 \ - --hash=sha256:1d59e739ab0e3520e62a26c60707cc3ab0365d2f8fecea74bfe4de72dc56388f \ - --hash=sha256:2855c8b0b55958265e8b5888d6a615ba02883b225f2227461aa9127c578a4922 \ - --hash=sha256:2e785e0f2ef0d567099b9ff92cbfb958d71c2d5b9259981cd9bee81bd194c9a4 \ - --hash=sha256:309be79c0a354afff9ff7da4aaed7c3257e77edf6c1b448a779329431ee79d7e \ - --hash=sha256:39f3ecaf76cd98e802f094e0d4fbc6dc9c45a8d0c4d185f0f6c2234e14e5f75b \ - --hash=sha256:457499c79fa84593f22454bbd27670227874cd2ff5d6c84e60575c8b50a69619 \ - --hash=sha256:49e70d111fee47284d9dd867c9bb9a7058a3c617274900780c43e38d90fe1205 \ - --hash=sha256:4c75507d0a55378240f781599c30e7776674dbaf883a46d1c90f37e563453480 \ - --hash=sha256:4c863140fafc615c14a4bf4efd0f4425c02230eb8ef02784c9a156461e62c965 \ - --hash=sha256:4d8908b3bee1c889e547867ca4cdc54e5ab6be6d3e078556814a22457f49423c \ - --hash=sha256:5b9eb0ca724a241683c9685a484da9d35c872fd42756574a7cfbf58af26677fd \ - --hash=sha256:6022cecf8f44e36af10bd9118ca71f371078b4c168b6e0fab43d4a889985dbb5 \ - --hash=sha256:6150ffa5c767bc6332df27157d95442c379b7dce3a38dff89c0f39b63275696f \ - --hash=sha256:62828cada4a2b850dbef89c81f5a33741898b305db244904de418cc957ff05dc \ - --hash=sha256:7b4182299f251060996af5249c286bae9361fa8c6a9cda5efc29fe8bfd6062ec \ - --hash=sha256:94b34f32646ca0414237168d68a9157cb3889f06b096612afdd296003fdd32fd \ - --hash=sha256:9ce6889abac9a42afd07a562c2d6d4b2b7134f83f18571d859b25624a331c90b \ - --hash=sha256:9cffe0f850e89d7c0012a1fb8730f75edd4320a0a731ed0c183904fe6ecfc3a9 \ - --hash=sha256:a12a813949e5066148712a0626895c26b2578874e4cc63160bb007e6df3436fe \ - --hash=sha256:a1eea9aecf761c661d096d39ed9026574de8adb2ae1c5bd7b33558af884fb2ce \ - --hash=sha256:a31f94983fecbac95e58388210427d68cd30fe8a36927980fab9c20062645609 \ - --hash=sha256:ac58bdee53cbeba2ecad824fa8159493f0bf3b8ea4e93feb06c9a465d6c87da8 \ - --hash=sha256:af3f4485884750dddd9c25cb7e3915d83c2db92488b38ccb77dd594eac84c4a0 \ - --hash=sha256:b33d2bc4f69caedcd0a275329eb2198f560b325605810895627be5d4b876bf7f \ - --hash=sha256:b59c0ffceff8d4d3996a45f2bb6f4c207f94684a96bf3d9728dbb77428dd8cb8 \ - --hash=sha256:bb6834cbd983b19f06908b45bfc2dad6ac9479ae04abe923a275b5f48f1a186b \ - --hash=sha256:bd3db01f59fdcbce5b22afad19e390260d6d0222f35a1023d9adc5690a889364 \ - --hash=sha256:bd7c23df857d488f418439686d3b10ae2fbf9bc256cd045b37a8c16575ea1040 \ - --hash=sha256:c2528d60e398c7c4c799d56f907664673a807635b857df18f7ae64d3e6ce2d9f \ - --hash=sha256:d31a63bc6e6d87f77d71e1abbd7387ab817a66733734883d1fc0021ed9bfa083 \ - --hash=sha256:d4492d82b3bc7fbb7e3610747b159869468079fe149ec5c4d771fa1f614a14df \ - --hash=sha256:ddcb8581510311e13421b1f544403c16e901c4e8f09083c881fab2be80ee31ba \ - --hash=sha256:e1d59258c3c67c865435d8fbeb35f8c59b8bef3d6f46c1f29f6123556af28445 \ - --hash=sha256:eb3315a8a236ee19b6df481fc5f997436e8ade24a9f03dfdc6bd490fea20c6da \ - --hash=sha256:ef2b055471c0eb466033760a521efb9d8a32b99ab907fc8358481a1dd29e3bd3 \ - --hash=sha256:ef5adb9a3b1d0c645ff694f9bca7702ec2c70f4d734f9922ea34de02294fdf72 \ - --hash=sha256:f32c38afb74bd98ce26de7cc74a67b40afb7b05aae7b42924ea990d51e4dac02 \ - --hash=sha256:fe0ccca550bb8e5abc22f530ec0466136379c01321fd94f30a22231e8a48d985 - # via matplotlib -cycler==0.12.1 \ - --hash=sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30 \ - --hash=sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c - # via matplotlib -etils[epath,epy]==1.7.0 \ - --hash=sha256:61af8f7c242171de15e22e5da02d527cb9e677d11f8bcafe18fcc3548eee3e60 \ - --hash=sha256:97b68fd25e185683215286ef3a54e38199b6245f5fe8be6bedc1189be4256350 - # via -r build/requirements.in -exceptiongroup==1.2.1 \ - --hash=sha256:5258b9ed329c5bbdd31a309f53cbfb0b155341807f6ff7606a1e801a891b29ad \ - --hash=sha256:a4785e48b045528f5bfe627b6ad554ff32def154f42372786903b7abcfe1aa16 - # via - # hypothesis - # pytest -execnet==2.1.1 \ - --hash=sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc \ - --hash=sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3 - # via pytest-xdist -filelock==3.14.0 \ - --hash=sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f \ - --hash=sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a - # via -r build/test-requirements.txt -flatbuffers==24.3.25 \ - --hash=sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812 \ - --hash=sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4 - # via -r build/test-requirements.txt -fonttools==4.51.0 \ - --hash=sha256:0118ef998a0699a96c7b28457f15546815015a2710a1b23a7bf6c1be60c01636 \ - --hash=sha256:0d145976194a5242fdd22df18a1b451481a88071feadf251221af110ca8f00ce \ - --hash=sha256:0e19bd9e9964a09cd2433a4b100ca7f34e34731e0758e13ba9a1ed6e5468cc0f \ - --hash=sha256:0f08c901d3866a8905363619e3741c33f0a83a680d92a9f0e575985c2634fcc1 \ - --hash=sha256:1250e818b5f8a679ad79660855528120a8f0288f8f30ec88b83db51515411fcc \ - --hash=sha256:15c94eeef6b095831067f72c825eb0e2d48bb4cea0647c1b05c981ecba2bf39f \ - --hash=sha256:1621ee57da887c17312acc4b0e7ac30d3a4fb0fec6174b2e3754a74c26bbed1e \ - --hash=sha256:180194c7fe60c989bb627d7ed5011f2bef1c4d36ecf3ec64daec8302f1ae0716 \ - --hash=sha256:278e50f6b003c6aed19bae2242b364e575bcb16304b53f2b64f6551b9c000e15 \ - --hash=sha256:32b17504696f605e9e960647c5f64b35704782a502cc26a37b800b4d69ff3c77 \ - --hash=sha256:3bee3f3bd9fa1d5ee616ccfd13b27ca605c2b4270e45715bd2883e9504735034 \ - --hash=sha256:4060acc2bfa2d8e98117828a238889f13b6f69d59f4f2d5857eece5277b829ba \ - --hash=sha256:54dcf21a2f2d06ded676e3c3f9f74b2bafded3a8ff12f0983160b13e9f2fb4a7 \ - --hash=sha256:56fc244f2585d6c00b9bcc59e6593e646cf095a96fe68d62cd4da53dd1287b55 \ - --hash=sha256:599bdb75e220241cedc6faebfafedd7670335d2e29620d207dd0378a4e9ccc5a \ - --hash=sha256:5f6bc991d1610f5c3bbe997b0233cbc234b8e82fa99fc0b2932dc1ca5e5afec0 \ - --hash=sha256:60a3409c9112aec02d5fb546f557bca6efa773dcb32ac147c6baf5f742e6258b \ - --hash=sha256:68b3fb7775a923be73e739f92f7e8a72725fd333eab24834041365d2278c3671 \ - --hash=sha256:76f1777d8b3386479ffb4a282e74318e730014d86ce60f016908d9801af9ca2a \ - --hash=sha256:806e7912c32a657fa39d2d6eb1d3012d35f841387c8fc6cf349ed70b7c340039 \ - --hash=sha256:84d7751f4468dd8cdd03ddada18b8b0857a5beec80bce9f435742abc9a851a74 \ - --hash=sha256:865a58b6e60b0938874af0968cd0553bcd88e0b2cb6e588727117bd099eef836 \ - --hash=sha256:8ac27f436e8af7779f0bb4d5425aa3535270494d3bc5459ed27de3f03151e4c2 \ - --hash=sha256:8b4850fa2ef2cfbc1d1f689bc159ef0f45d8d83298c1425838095bf53ef46308 \ - --hash=sha256:8b5ad456813d93b9c4b7ee55302208db2b45324315129d85275c01f5cb7e61a2 \ - --hash=sha256:8e2f1a4499e3b5ee82c19b5ee57f0294673125c65b0a1ff3764ea1f9db2f9ef5 \ - --hash=sha256:9696fe9f3f0c32e9a321d5268208a7cc9205a52f99b89479d1b035ed54c923f1 \ - --hash=sha256:96a48e137c36be55e68845fc4284533bda2980f8d6f835e26bca79d7e2006438 \ - --hash=sha256:a8feca65bab31479d795b0d16c9a9852902e3a3c0630678efb0b2b7941ea9c74 \ - --hash=sha256:aefa011207ed36cd280babfaa8510b8176f1a77261833e895a9d96e57e44802f \ - --hash=sha256:b2b92381f37b39ba2fc98c3a45a9d6383bfc9916a87d66ccb6553f7bdd129097 \ - --hash=sha256:b3c61423f22165541b9403ee39874dcae84cd57a9078b82e1dce8cb06b07fa2e \ - --hash=sha256:b5b48a1121117047d82695d276c2af2ee3a24ffe0f502ed581acc2673ecf1037 \ - --hash=sha256:c18b49adc721a7d0b8dfe7c3130c89b8704baf599fb396396d07d4aa69b824a1 \ - --hash=sha256:c5b8cab0c137ca229433570151b5c1fc6af212680b58b15abd797dcdd9dd5051 \ - --hash=sha256:c7e91abdfae1b5c9e3a543f48ce96013f9a08c6c9668f1e6be0beabf0a569c1b \ - --hash=sha256:cadf4e12a608ef1d13e039864f484c8a968840afa0258b0b843a0556497ea9ed \ - --hash=sha256:dc0673361331566d7a663d7ce0f6fdcbfbdc1f59c6e3ed1165ad7202ca183c68 \ - --hash=sha256:de7c29bdbdd35811f14493ffd2534b88f0ce1b9065316433b22d63ca1cd21f14 \ - --hash=sha256:e9d9298be7a05bb4801f558522adbe2feea1b0b103d5294ebf24a92dd49b78e5 \ - --hash=sha256:ee1af4be1c5afe4c96ca23badd368d8dc75f611887fb0c0dac9f71ee5d6f110e \ - --hash=sha256:f7e89853d8bea103c8e3514b9f9dc86b5b4120afb4583b57eb10dfa5afbe0936 - # via matplotlib -fsspec==2024.5.0 \ - --hash=sha256:1d021b0b0f933e3b3029ed808eb400c08ba101ca2de4b3483fbc9ca23fcee94a \ - --hash=sha256:e0fdbc446d67e182f49a70b82cf7889028a63588fde6b222521f10937b2b670c - # via etils -hypothesis==6.102.4 \ - --hash=sha256:013df31b04a4daede13756f497e60e451963d86f426395a79f99c5d692919bbd \ - --hash=sha256:59b4d144346d5cffb482cc1bafbd21b13ff31608e8c4b3e4630339aee3e87763 - # via -r build/test-requirements.txt -importlib-resources==6.4.0 \ - --hash=sha256:50d10f043df931902d4194ea07ec57960f66a80449ff867bfe782b4c486ba78c \ - --hash=sha256:cdb2b453b8046ca4e3798eb1d84f3cce1446a0e8e7b5ef4efb600f19fc398145 - # via etils -iniconfig==2.0.0 \ - --hash=sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3 \ - --hash=sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374 - # via pytest -kiwisolver==1.4.5 \ - --hash=sha256:00bd361b903dc4bbf4eb165f24d1acbee754fce22ded24c3d56eec268658a5cf \ - --hash=sha256:040c1aebeda72197ef477a906782b5ab0d387642e93bda547336b8957c61022e \ - --hash=sha256:05703cf211d585109fcd72207a31bb170a0f22144d68298dc5e61b3c946518af \ - --hash=sha256:06f54715b7737c2fecdbf140d1afb11a33d59508a47bf11bb38ecf21dc9ab79f \ - --hash=sha256:0dc9db8e79f0036e8173c466d21ef18e1befc02de8bf8aa8dc0813a6dc8a7046 \ - --hash=sha256:0f114aa76dc1b8f636d077979c0ac22e7cd8f3493abbab152f20eb8d3cda71f3 \ - --hash=sha256:11863aa14a51fd6ec28688d76f1735f8f69ab1fabf388851a595d0721af042f5 \ - --hash=sha256:11c7de8f692fc99816e8ac50d1d1aef4f75126eefc33ac79aac02c099fd3db71 \ - --hash=sha256:11d011a7574eb3b82bcc9c1a1d35c1d7075677fdd15de527d91b46bd35e935ee \ - --hash=sha256:146d14bebb7f1dc4d5fbf74f8a6cb15ac42baadee8912eb84ac0b3b2a3dc6ac3 \ - --hash=sha256:15568384086b6df3c65353820a4473575dbad192e35010f622c6ce3eebd57af9 \ - --hash=sha256:19df6e621f6d8b4b9c4d45f40a66839294ff2bb235e64d2178f7522d9170ac5b \ - --hash=sha256:1b04139c4236a0f3aff534479b58f6f849a8b351e1314826c2d230849ed48985 \ - --hash=sha256:210ef2c3a1f03272649aff1ef992df2e724748918c4bc2d5a90352849eb40bea \ - --hash=sha256:2270953c0d8cdab5d422bee7d2007f043473f9d2999631c86a223c9db56cbd16 \ - --hash=sha256:2400873bccc260b6ae184b2b8a4fec0e4082d30648eadb7c3d9a13405d861e89 \ - --hash=sha256:2a40773c71d7ccdd3798f6489aaac9eee213d566850a9533f8d26332d626b82c \ - --hash=sha256:2c5674c4e74d939b9d91dda0fae10597ac7521768fec9e399c70a1f27e2ea2d9 \ - --hash=sha256:3195782b26fc03aa9c6913d5bad5aeb864bdc372924c093b0f1cebad603dd712 \ - --hash=sha256:31a82d498054cac9f6d0b53d02bb85811185bcb477d4b60144f915f3b3126342 \ - --hash=sha256:32d5cf40c4f7c7b3ca500f8985eb3fb3a7dfc023215e876f207956b5ea26632a \ - --hash=sha256:346f5343b9e3f00b8db8ba359350eb124b98c99efd0b408728ac6ebf38173958 \ - --hash=sha256:378a214a1e3bbf5ac4a8708304318b4f890da88c9e6a07699c4ae7174c09a68d \ - --hash=sha256:39b42c68602539407884cf70d6a480a469b93b81b7701378ba5e2328660c847a \ - --hash=sha256:3a2b053a0ab7a3960c98725cfb0bf5b48ba82f64ec95fe06f1d06c99b552e130 \ - --hash=sha256:3aba7311af82e335dd1e36ffff68aaca609ca6290c2cb6d821a39aa075d8e3ff \ - --hash=sha256:3cd32d6c13807e5c66a7cbb79f90b553642f296ae4518a60d8d76243b0ad2898 \ - --hash=sha256:3edd2fa14e68c9be82c5b16689e8d63d89fe927e56debd6e1dbce7a26a17f81b \ - --hash=sha256:4c380469bd3f970ef677bf2bcba2b6b0b4d5c75e7a020fb863ef75084efad66f \ - --hash=sha256:4e66e81a5779b65ac21764c295087de82235597a2293d18d943f8e9e32746265 \ - --hash=sha256:53abb58632235cd154176ced1ae8f0d29a6657aa1aa9decf50b899b755bc2b93 \ - --hash=sha256:5794cf59533bc3f1b1c821f7206a3617999db9fbefc345360aafe2e067514929 \ - --hash=sha256:59415f46a37f7f2efeec758353dd2eae1b07640d8ca0f0c42548ec4125492635 \ - --hash=sha256:59ec7b7c7e1a61061850d53aaf8e93db63dce0c936db1fda2658b70e4a1be709 \ - --hash=sha256:59edc41b24031bc25108e210c0def6f6c2191210492a972d585a06ff246bb79b \ - --hash=sha256:5a580c91d686376f0f7c295357595c5a026e6cbc3d77b7c36e290201e7c11ecb \ - --hash=sha256:5b94529f9b2591b7af5f3e0e730a4e0a41ea174af35a4fd067775f9bdfeee01a \ - --hash=sha256:5c7b3b3a728dc6faf3fc372ef24f21d1e3cee2ac3e9596691d746e5a536de920 \ - --hash=sha256:5c90ae8c8d32e472be041e76f9d2f2dbff4d0b0be8bd4041770eddb18cf49a4e \ - --hash=sha256:5e7139af55d1688f8b960ee9ad5adafc4ac17c1c473fe07133ac092310d76544 \ - --hash=sha256:5ff5cf3571589b6d13bfbfd6bcd7a3f659e42f96b5fd1c4830c4cf21d4f5ef45 \ - --hash=sha256:620ced262a86244e2be10a676b646f29c34537d0d9cc8eb26c08f53d98013390 \ - --hash=sha256:6512cb89e334e4700febbffaaa52761b65b4f5a3cf33f960213d5656cea36a77 \ - --hash=sha256:6c08e1312a9cf1074d17b17728d3dfce2a5125b2d791527f33ffbe805200a355 \ - --hash=sha256:6c3bd3cde54cafb87d74d8db50b909705c62b17c2099b8f2e25b461882e544ff \ - --hash=sha256:6ef7afcd2d281494c0a9101d5c571970708ad911d028137cd558f02b851c08b4 \ - --hash=sha256:7269d9e5f1084a653d575c7ec012ff57f0c042258bf5db0954bf551c158466e7 \ - --hash=sha256:72d40b33e834371fd330fb1472ca19d9b8327acb79a5821d4008391db8e29f20 \ - --hash=sha256:74d1b44c6cfc897df648cc9fdaa09bc3e7679926e6f96df05775d4fb3946571c \ - --hash=sha256:74db36e14a7d1ce0986fa104f7d5637aea5c82ca6326ed0ec5694280942d1162 \ - --hash=sha256:763773d53f07244148ccac5b084da5adb90bfaee39c197554f01b286cf869228 \ - --hash=sha256:76c6a5964640638cdeaa0c359382e5703e9293030fe730018ca06bc2010c4437 \ - --hash=sha256:76d9289ed3f7501012e05abb8358bbb129149dbd173f1f57a1bf1c22d19ab7cc \ - --hash=sha256:7931d8f1f67c4be9ba1dd9c451fb0eeca1a25b89e4d3f89e828fe12a519b782a \ - --hash=sha256:7b8b454bac16428b22560d0a1cf0a09875339cab69df61d7805bf48919415901 \ - --hash=sha256:7e5bab140c309cb3a6ce373a9e71eb7e4873c70c2dda01df6820474f9889d6d4 \ - --hash=sha256:83d78376d0d4fd884e2c114d0621624b73d2aba4e2788182d286309ebdeed770 \ - --hash=sha256:852542f9481f4a62dbb5dd99e8ab7aedfeb8fb6342349a181d4036877410f525 \ - --hash=sha256:85267bd1aa8880a9c88a8cb71e18d3d64d2751a790e6ca6c27b8ccc724bcd5ad \ - --hash=sha256:88a2df29d4724b9237fc0c6eaf2a1adae0cdc0b3e9f4d8e7dc54b16812d2d81a \ - --hash=sha256:88b9f257ca61b838b6f8094a62418421f87ac2a1069f7e896c36a7d86b5d4c29 \ - --hash=sha256:8ab3919a9997ab7ef2fbbed0cc99bb28d3c13e6d4b1ad36e97e482558a91be90 \ - --hash=sha256:92dea1ffe3714fa8eb6a314d2b3c773208d865a0e0d35e713ec54eea08a66250 \ - --hash=sha256:9407b6a5f0d675e8a827ad8742e1d6b49d9c1a1da5d952a67d50ef5f4170b18d \ - --hash=sha256:9408acf3270c4b6baad483865191e3e582b638b1654a007c62e3efe96f09a9a3 \ - --hash=sha256:955e8513d07a283056b1396e9a57ceddbd272d9252c14f154d450d227606eb54 \ - --hash=sha256:9db8ea4c388fdb0f780fe91346fd438657ea602d58348753d9fb265ce1bca67f \ - --hash=sha256:9eaa8b117dc8337728e834b9c6e2611f10c79e38f65157c4c38e9400286f5cb1 \ - --hash=sha256:a51a263952b1429e429ff236d2f5a21c5125437861baeed77f5e1cc2d2c7c6da \ - --hash=sha256:a6aa6315319a052b4ee378aa171959c898a6183f15c1e541821c5c59beaa0238 \ - --hash=sha256:aa12042de0171fad672b6c59df69106d20d5596e4f87b5e8f76df757a7c399aa \ - --hash=sha256:aaf7be1207676ac608a50cd08f102f6742dbfc70e8d60c4db1c6897f62f71523 \ - --hash=sha256:b0157420efcb803e71d1b28e2c287518b8808b7cf1ab8af36718fd0a2c453eb0 \ - --hash=sha256:b3f7e75f3015df442238cca659f8baa5f42ce2a8582727981cbfa15fee0ee205 \ - --hash=sha256:b9098e0049e88c6a24ff64545cdfc50807818ba6c1b739cae221bbbcbc58aad3 \ - --hash=sha256:ba55dce0a9b8ff59495ddd050a0225d58bd0983d09f87cfe2b6aec4f2c1234e4 \ - --hash=sha256:bb86433b1cfe686da83ce32a9d3a8dd308e85c76b60896d58f082136f10bffac \ - --hash=sha256:bbea0db94288e29afcc4c28afbf3a7ccaf2d7e027489c449cf7e8f83c6346eb9 \ - --hash=sha256:bbf1d63eef84b2e8c89011b7f2235b1e0bf7dacc11cac9431fc6468e99ac77fb \ - --hash=sha256:c7940c1dc63eb37a67721b10d703247552416f719c4188c54e04334321351ced \ - --hash=sha256:c9bf3325c47b11b2e51bca0824ea217c7cd84491d8ac4eefd1e409705ef092bd \ - --hash=sha256:cdc8a402aaee9a798b50d8b827d7ecf75edc5fb35ea0f91f213ff927c15f4ff0 \ - --hash=sha256:ceec1a6bc6cab1d6ff5d06592a91a692f90ec7505d6463a88a52cc0eb58545da \ - --hash=sha256:cfe6ab8da05c01ba6fbea630377b5da2cd9bcbc6338510116b01c1bc939a2c18 \ - --hash=sha256:d099e745a512f7e3bbe7249ca835f4d357c586d78d79ae8f1dcd4d8adeb9bda9 \ - --hash=sha256:d0ef46024e6a3d79c01ff13801cb19d0cad7fd859b15037aec74315540acc276 \ - --hash=sha256:d2e5a98f0ec99beb3c10e13b387f8db39106d53993f498b295f0c914328b1333 \ - --hash=sha256:da4cfb373035def307905d05041c1d06d8936452fe89d464743ae7fb8371078b \ - --hash=sha256:da802a19d6e15dffe4b0c24b38b3af68e6c1a68e6e1d8f30148c83864f3881db \ - --hash=sha256:dced8146011d2bc2e883f9bd68618b8247387f4bbec46d7392b3c3b032640126 \ - --hash=sha256:dfdd7c0b105af050eb3d64997809dc21da247cf44e63dc73ff0fd20b96be55a9 \ - --hash=sha256:e368f200bbc2e4f905b8e71eb38b3c04333bddaa6a2464a6355487b02bb7fb09 \ - --hash=sha256:e391b1f0a8a5a10ab3b9bb6afcfd74f2175f24f8975fb87ecae700d1503cdee0 \ - --hash=sha256:e57e563a57fb22a142da34f38acc2fc1a5c864bc29ca1517a88abc963e60d6ec \ - --hash=sha256:e5d706eba36b4c4d5bc6c6377bb6568098765e990cfc21ee16d13963fab7b3e7 \ - --hash=sha256:ec20916e7b4cbfb1f12380e46486ec4bcbaa91a9c448b97023fde0d5bbf9e4ff \ - --hash=sha256:f1d072c2eb0ad60d4c183f3fb44ac6f73fb7a8f16a2694a91f988275cbf352f9 \ - --hash=sha256:f846c260f483d1fd217fe5ed7c173fb109efa6b1fc8381c8b7552c5781756192 \ - --hash=sha256:f91de7223d4c7b793867797bacd1ee53bfe7359bd70d27b7b58a04efbb9436c8 \ - --hash=sha256:faae4860798c31530dd184046a900e652c95513796ef51a12bc086710c2eec4d \ - --hash=sha256:fc579bf0f502e54926519451b920e875f433aceb4624a3646b3252b5caa9e0b6 \ - --hash=sha256:fcc700eadbbccbf6bc1bcb9dbe0786b4b1cb91ca0dcda336eef5c2beed37b797 \ - --hash=sha256:fd32ea360bcbb92d28933fc05ed09bffcb1704ba3fc7942e81db0fd4f81a7892 \ - --hash=sha256:fdb7adb641a0d13bdcd4ef48e062363d8a9ad4a182ac7647ec88f695e719ae9f - # via matplotlib -markdown-it-py==3.0.0 \ - --hash=sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 \ - --hash=sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb - # via rich -matplotlib==3.8.4 ; python_version <= "3.10" \ - --hash=sha256:1c13f041a7178f9780fb61cc3a2b10423d5e125480e4be51beaf62b172413b67 \ - --hash=sha256:232ce322bfd020a434caaffbd9a95333f7c2491e59cfc014041d95e38ab90d1c \ - --hash=sha256:493e9f6aa5819156b58fce42b296ea31969f2aab71c5b680b4ea7a3cb5c07d94 \ - --hash=sha256:50bac6e4d77e4262c4340d7a985c30912054745ec99756ce213bfbc3cb3808eb \ - --hash=sha256:606e3b90897554c989b1e38a258c626d46c873523de432b1462f295db13de6f9 \ - --hash=sha256:6209e5c9aaccc056e63b547a8152661324404dd92340a6e479b3a7f24b42a5d0 \ - --hash=sha256:6485ac1f2e84676cff22e693eaa4fbed50ef5dc37173ce1f023daef4687df616 \ - --hash=sha256:6addbd5b488aedb7f9bc19f91cd87ea476206f45d7116fcfe3d31416702a82fa \ - --hash=sha256:72f9322712e4562e792b2961971891b9fbbb0e525011e09ea0d1f416c4645661 \ - --hash=sha256:7a6769f58ce51791b4cb8b4d7642489df347697cd3e23d88266aaaee93b41d9a \ - --hash=sha256:8080d5081a86e690d7688ffa542532e87f224c38a6ed71f8fbed34dd1d9fedae \ - --hash=sha256:843cbde2f0946dadd8c5c11c6d91847abd18ec76859dc319362a0964493f0ba6 \ - --hash=sha256:8aac397d5e9ec158960e31c381c5ffc52ddd52bd9a47717e2a694038167dffea \ - --hash=sha256:8f65c9f002d281a6e904976007b2d46a1ee2bcea3a68a8c12dda24709ddc9106 \ - --hash=sha256:90df07db7b599fe7035d2f74ab7e438b656528c68ba6bb59b7dc46af39ee48ef \ - --hash=sha256:9bb0189011785ea794ee827b68777db3ca3f93f3e339ea4d920315a0e5a78d54 \ - --hash=sha256:a0e47eda4eb2614300fc7bb4657fced3e83d6334d03da2173b09e447418d499f \ - --hash=sha256:abc9d838f93583650c35eca41cfcec65b2e7cb50fd486da6f0c49b5e1ed23014 \ - --hash=sha256:ac24233e8f2939ac4fd2919eed1e9c0871eac8057666070e94cbf0b33dd9c338 \ - --hash=sha256:b12ba985837e4899b762b81f5b2845bd1a28f4fdd1a126d9ace64e9c4eb2fb25 \ - --hash=sha256:b7a2a253d3b36d90c8993b4620183b55665a429da8357a4f621e78cd48b2b30b \ - --hash=sha256:c7064120a59ce6f64103c9cefba8ffe6fba87f2c61d67c401186423c9a20fd35 \ - --hash=sha256:c89ee9314ef48c72fe92ce55c4e95f2f39d70208f9f1d9db4e64079420d8d732 \ - --hash=sha256:cc4ccdc64e3039fc303defd119658148f2349239871db72cd74e2eeaa9b80b71 \ - --hash=sha256:ce1edd9f5383b504dbc26eeea404ed0a00656c526638129028b758fd43fc5f10 \ - --hash=sha256:ecd79298550cba13a43c340581a3ec9c707bd895a6a061a78fa2524660482fc0 \ - --hash=sha256:f51c4c869d4b60d769f7b4406eec39596648d9d70246428745a681c327a8ad30 \ - --hash=sha256:fb44f53af0a62dc80bba4443d9b27f2fde6acfdac281d95bc872dc148a6509cc - # via -r build/test-requirements.txt -mdurl==0.1.2 \ - --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ - --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba - # via markdown-it-py -ml-dtypes==0.5.1 \ - --hash=sha256:023ce2f502efd4d6c1e0472cc58ce3640d051d40e71e27386bed33901e201327 \ - --hash=sha256:05f23447a1c20ddf4dc7c2c661aa9ed93fcb2658f1017c204d1e758714dc28a8 \ - --hash=sha256:12651420130ee7cc13059fc56dac6ad300c3af3848b802d475148c9defd27c23 \ - --hash=sha256:141b2ea2f20bb10802ddca55d91fe21231ef49715cfc971998e8f2a9838f3dbe \ - --hash=sha256:15ad0f3b0323ce96c24637a88a6f44f6713c64032f27277b069f285c3cf66478 \ - --hash=sha256:1b7fbe5571fdf28fd3aaab3ef4aafc847de9ebf263be959958c1ca58ec8eadf5 \ - --hash=sha256:26ebcc69d7b779c8f129393e99732961b5cc33fcff84090451f448c89b0e01b4 \ - --hash=sha256:6f462f5eca22fb66d7ff9c4744a3db4463af06c49816c4b6ac89b16bfcdc592e \ - --hash=sha256:6f76232163b5b9c34291b54621ee60417601e2e4802a188a0ea7157cd9b323f4 \ - --hash=sha256:7000b6e4d8ef07542c05044ec5d8bbae1df083b3f56822c3da63993a113e716f \ - --hash=sha256:810512e2eccdfc3b41eefa3a27402371a3411453a1efc7e9c000318196140fed \ - --hash=sha256:8f2c028954f16ede77902b223a8da2d9cbb3892375b85809a5c3cfb1587960c4 \ - --hash=sha256:9626d0bca1fb387d5791ca36bacbba298c5ef554747b7ebeafefb4564fc83566 \ - --hash=sha256:ac5b58559bb84a95848ed6984eb8013249f90b6bab62aa5acbad876e256002c9 \ - --hash=sha256:ad4953c5eb9c25a56d11a913c2011d7e580a435ef5145f804d98efa14477d390 \ - --hash=sha256:aefedc579ece2f8fb38f876aa7698204ee4c372d0e54f1c1ffa8ca580b54cc60 \ - --hash=sha256:afb2009ac98da274e893e03162f6269398b2b00d947e7057ee2469a921d58135 \ - --hash=sha256:b8a9d46b4df5ae2135a8e8e72b465448ebbc1559997f4f9304a9ecc3413efb5b \ - --hash=sha256:bd73f51957949069573ff783563486339a9285d72e2f36c18e0c1aa9ca7eb190 \ - --hash=sha256:bf9975bda82a99dc935f2ae4c83846d86df8fd6ba179614acac8e686910851da \ - --hash=sha256:c09526488c3a9e8b7a23a388d4974b670a9a3dd40c5c8a61db5593ce9b725bab \ - --hash=sha256:c9945669d3dadf8acb40ec2e57d38c985d8c285ea73af57fc5b09872c516106d \ - --hash=sha256:d13755f8e8445b3870114e5b6240facaa7cb0c3361e54beba3e07fa912a6e12b \ - --hash=sha256:fd918d4e6a4e0c110e2e05be7a7814d10dc1b95872accbf6512b80a109b71ae1 - # via -r build/requirements.in -mpmath==1.4.0a1 \ - --hash=sha256:78884400f439f500fa76be0121a8f9598313d87664863a192e1185ddbd7ae97f \ - --hash=sha256:f8b7b5a3a1726ab6e8c898eb2157426b82c482ab1ab8ffed9f88bb9e07c6e9c1 - # via -r build/test-requirements.txt -numpy==2.0.0 ; python_version <= "3.12" \ - --hash=sha256:04494f6ec467ccb5369d1808570ae55f6ed9b5809d7f035059000a37b8d7e86f \ - --hash=sha256:0a43f0974d501842866cc83471bdb0116ba0dffdbaac33ec05e6afed5b615238 \ - --hash=sha256:0e50842b2295ba8414c8c1d9d957083d5dfe9e16828b37de883f51fc53c4016f \ - --hash=sha256:0ec84b9ba0654f3b962802edc91424331f423dcf5d5f926676e0150789cb3d95 \ - --hash=sha256:17067d097ed036636fa79f6a869ac26df7db1ba22039d962422506640314933a \ - --hash=sha256:1cde1753efe513705a0c6d28f5884e22bdc30438bf0085c5c486cdaff40cd67a \ - --hash=sha256:1e72728e7501a450288fc8e1f9ebc73d90cfd4671ebbd631f3e7857c39bd16f2 \ - --hash=sha256:2635dbd200c2d6faf2ef9a0d04f0ecc6b13b3cad54f7c67c61155138835515d2 \ - --hash=sha256:2ce46fd0b8a0c947ae047d222f7136fc4d55538741373107574271bc00e20e8f \ - --hash=sha256:34f003cb88b1ba38cb9a9a4a3161c1604973d7f9d5552c38bc2f04f829536609 \ - --hash=sha256:354f373279768fa5a584bac997de6a6c9bc535c482592d7a813bb0c09be6c76f \ - --hash=sha256:38ecb5b0582cd125f67a629072fed6f83562d9dd04d7e03256c9829bdec027ad \ - --hash=sha256:3e8e01233d57639b2e30966c63d36fcea099d17c53bf424d77f088b0f4babd86 \ - --hash=sha256:3f6bed7f840d44c08ebdb73b1825282b801799e325bcbdfa6bc5c370e5aecc65 \ - --hash=sha256:4554eb96f0fd263041baf16cf0881b3f5dafae7a59b1049acb9540c4d57bc8cb \ - --hash=sha256:46e161722e0f619749d1cd892167039015b2c2817296104487cd03ed4a955995 \ - --hash=sha256:49d9f7d256fbc804391a7f72d4a617302b1afac1112fac19b6c6cec63fe7fe8a \ - --hash=sha256:4d2f62e55a4cd9c58c1d9a1c9edaedcd857a73cb6fda875bf79093f9d9086f85 \ - --hash=sha256:5f64641b42b2429f56ee08b4f427a4d2daf916ec59686061de751a55aafa22e4 \ - --hash=sha256:63b92c512d9dbcc37f9d81b123dec99fdb318ba38c8059afc78086fe73820275 \ - --hash=sha256:6d7696c615765091cc5093f76fd1fa069870304beaccfd58b5dcc69e55ef49c1 \ - --hash=sha256:79e843d186c8fb1b102bef3e2bc35ef81160ffef3194646a7fdd6a73c6b97196 \ - --hash=sha256:821eedb7165ead9eebdb569986968b541f9908979c2da8a4967ecac4439bae3d \ - --hash=sha256:84554fc53daa8f6abf8e8a66e076aff6ece62de68523d9f665f32d2fc50fd66e \ - --hash=sha256:8d83bb187fb647643bd56e1ae43f273c7f4dbcdf94550d7938cfc32566756514 \ - --hash=sha256:903703372d46bce88b6920a0cd86c3ad82dae2dbef157b5fc01b70ea1cfc430f \ - --hash=sha256:9416a5c2e92ace094e9f0082c5fd473502c91651fb896bc17690d6fc475128d6 \ - --hash=sha256:9a1712c015831da583b21c5bfe15e8684137097969c6d22e8316ba66b5baabe4 \ - --hash=sha256:9c27f0946a3536403efb0e1c28def1ae6730a72cd0d5878db38824855e3afc44 \ - --hash=sha256:a356364941fb0593bb899a1076b92dfa2029f6f5b8ba88a14fd0984aaf76d0df \ - --hash=sha256:a7039a136017eaa92c1848152827e1424701532ca8e8967fe480fe1569dae581 \ - --hash=sha256:acd3a644e4807e73b4e1867b769fbf1ce8c5d80e7caaef0d90dcdc640dfc9787 \ - --hash=sha256:ad0c86f3455fbd0de6c31a3056eb822fc939f81b1618f10ff3406971893b62a5 \ - --hash=sha256:b4c76e3d4c56f145d41b7b6751255feefae92edbc9a61e1758a98204200f30fc \ - --hash=sha256:b6f6a8f45d0313db07d6d1d37bd0b112f887e1369758a5419c0370ba915b3871 \ - --hash=sha256:c5a59996dc61835133b56a32ebe4ef3740ea5bc19b3983ac60cc32be5a665d54 \ - --hash=sha256:c73aafd1afca80afecb22718f8700b40ac7cab927b8abab3c3e337d70e10e5a2 \ - --hash=sha256:cee6cc0584f71adefe2c908856ccc98702baf95ff80092e4ca46061538a2ba98 \ - --hash=sha256:cef04d068f5fb0518a77857953193b6bb94809a806bd0a14983a8f12ada060c9 \ - --hash=sha256:cf5d1c9e6837f8af9f92b6bd3e86d513cdc11f60fd62185cc49ec7d1aba34864 \ - --hash=sha256:e61155fae27570692ad1d327e81c6cf27d535a5d7ef97648a17d922224b216de \ - --hash=sha256:e7f387600d424f91576af20518334df3d97bc76a300a755f9a8d6e4f5cadd289 \ - --hash=sha256:ed08d2703b5972ec736451b818c2eb9da80d66c3e84aed1deeb0c345fefe461b \ - --hash=sha256:fbd6acc766814ea6443628f4e6751d0da6593dae29c08c0b2606164db026970c \ - --hash=sha256:feff59f27338135776f6d4e2ec7aeeac5d5f7a08a83e80869121ef8164b74af9 - # via - # -r build/requirements.in - # contourpy - # matplotlib - # ml-dtypes - # opt-einsum - # scipy -nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \ - --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \ - --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ - --hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9 - # via - # via -r build/test-requirements.txt - # nvidia-cudnn-cu12 - # nvidia-cusolver-cu12 -nvidia-cuda-cupti-cu12==12.8.57 ; sys_platform == "linux" \ - --hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \ - --hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \ - --hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317 - # via -r build/test-requirements.txt -nvidia-cuda-nvcc-cu12==12.8.61 ; sys_platform == "linux" \ - --hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \ - --hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ - --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b - # via -r build/test-requirements.txt -nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ - --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ - --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ - --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 - # via -r build/test-requirements.txt -nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \ - --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ - --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ - --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef - # via -r build/test-requirements.txt -nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ - --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ - --hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \ - --hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8 - # via -r build/test-requirements.txt -nvidia-cusolver-cu12==11.7.2.55 ; sys_platform == "linux" \ - --hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \ - --hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \ - --hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac - # via -r build/test-requirements.txt -nvidia-cusparse-cu12==12.5.7.53 ; sys_platform == "linux" \ - --hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \ - --hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \ - --hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1 - # via - # via -r build/test-requirements.txt - # nvidia-cusolver-cu12 -nvidia-nccl-cu12==2.25.1 ; sys_platform == "linux" \ - --hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ - --hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 - # via -r build/test-requirements.txt -nvidia-nvjitlink-cu12==12.8.61 ; sys_platform == "linux" \ - --hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \ - --hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \ - --hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0 - # via - # via -r build/test-requirements.txt - # nvidia-cufft-cu12 - # nvidia-cusolver-cu12 - # nvidia-cusparse-cu12 -opt-einsum==3.3.0 \ - --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ - --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 - # via - # -r build/requirements.in - # -r build/test-requirements.txt -packaging==24.0 \ - --hash=sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5 \ - --hash=sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9 - # via - # auditwheel - # build - # matplotlib - # pytest -pillow==11.0.0 \ - --hash=sha256:00177a63030d612148e659b55ba99527803288cea7c75fb05766ab7981a8c1b7 \ - --hash=sha256:006bcdd307cc47ba43e924099a038cbf9591062e6c50e570819743f5607404f5 \ - --hash=sha256:084a07ef0821cfe4858fe86652fffac8e187b6ae677e9906e192aafcc1b69903 \ - --hash=sha256:0ae08bd8ffc41aebf578c2af2f9d8749d91f448b3bfd41d7d9ff573d74f2a6b2 \ - --hash=sha256:0e038b0745997c7dcaae350d35859c9715c71e92ffb7e0f4a8e8a16732150f38 \ - --hash=sha256:1187739620f2b365de756ce086fdb3604573337cc28a0d3ac4a01ab6b2d2a6d2 \ - --hash=sha256:16095692a253047fe3ec028e951fa4221a1f3ed3d80c397e83541a3037ff67c9 \ - --hash=sha256:1a61b54f87ab5786b8479f81c4b11f4d61702830354520837f8cc791ebba0f5f \ - --hash=sha256:1c1d72714f429a521d8d2d018badc42414c3077eb187a59579f28e4270b4b0fc \ - --hash=sha256:1e2688958a840c822279fda0086fec1fdab2f95bf2b717b66871c4ad9859d7e8 \ - --hash=sha256:20ec184af98a121fb2da42642dea8a29ec80fc3efbaefb86d8fdd2606619045d \ - --hash=sha256:21a0d3b115009ebb8ac3d2ebec5c2982cc693da935f4ab7bb5c8ebe2f47d36f2 \ - --hash=sha256:224aaa38177597bb179f3ec87eeefcce8e4f85e608025e9cfac60de237ba6316 \ - --hash=sha256:2679d2258b7f1192b378e2893a8a0a0ca472234d4c2c0e6bdd3380e8dfa21b6a \ - --hash=sha256:27a7860107500d813fcd203b4ea19b04babe79448268403172782754870dac25 \ - --hash=sha256:290f2cc809f9da7d6d622550bbf4c1e57518212da51b6a30fe8e0a270a5b78bd \ - --hash=sha256:2e46773dc9f35a1dd28bd6981332fd7f27bec001a918a72a79b4133cf5291dba \ - --hash=sha256:3107c66e43bda25359d5ef446f59c497de2b5ed4c7fdba0894f8d6cf3822dafc \ - --hash=sha256:375b8dd15a1f5d2feafff536d47e22f69625c1aa92f12b339ec0b2ca40263273 \ - --hash=sha256:45c566eb10b8967d71bf1ab8e4a525e5a93519e29ea071459ce517f6b903d7fa \ - --hash=sha256:499c3a1b0d6fc8213519e193796eb1a86a1be4b1877d678b30f83fd979811d1a \ - --hash=sha256:4ad70c4214f67d7466bea6a08061eba35c01b1b89eaa098040a35272a8efb22b \ - --hash=sha256:4b60c9520f7207aaf2e1d94de026682fc227806c6e1f55bba7606d1c94dd623a \ - --hash=sha256:5178952973e588b3f1360868847334e9e3bf49d19e169bbbdfaf8398002419ae \ - --hash=sha256:52a2d8323a465f84faaba5236567d212c3668f2ab53e1c74c15583cf507a0291 \ - --hash=sha256:598b4e238f13276e0008299bd2482003f48158e2b11826862b1eb2ad7c768b97 \ - --hash=sha256:5bd2d3bdb846d757055910f0a59792d33b555800813c3b39ada1829c372ccb06 \ - --hash=sha256:5c39ed17edea3bc69c743a8dd3e9853b7509625c2462532e62baa0732163a904 \ - --hash=sha256:5d203af30149ae339ad1b4f710d9844ed8796e97fda23ffbc4cc472968a47d0b \ - --hash=sha256:5ddbfd761ee00c12ee1be86c9c0683ecf5bb14c9772ddbd782085779a63dd55b \ - --hash=sha256:607bbe123c74e272e381a8d1957083a9463401f7bd01287f50521ecb05a313f8 \ - --hash=sha256:61b887f9ddba63ddf62fd02a3ba7add935d053b6dd7d58998c630e6dbade8527 \ - --hash=sha256:6619654954dc4936fcff82db8eb6401d3159ec6be81e33c6000dfd76ae189947 \ - --hash=sha256:674629ff60030d144b7bca2b8330225a9b11c482ed408813924619c6f302fdbb \ - --hash=sha256:6ec0d5af64f2e3d64a165f490d96368bb5dea8b8f9ad04487f9ab60dc4bb6003 \ - --hash=sha256:6f4dba50cfa56f910241eb7f883c20f1e7b1d8f7d91c750cd0b318bad443f4d5 \ - --hash=sha256:70fbbdacd1d271b77b7721fe3cdd2d537bbbd75d29e6300c672ec6bb38d9672f \ - --hash=sha256:72bacbaf24ac003fea9bff9837d1eedb6088758d41e100c1552930151f677739 \ - --hash=sha256:7326a1787e3c7b0429659e0a944725e1b03eeaa10edd945a86dead1913383944 \ - --hash=sha256:73853108f56df97baf2bb8b522f3578221e56f646ba345a372c78326710d3830 \ - --hash=sha256:73e3a0200cdda995c7e43dd47436c1548f87a30bb27fb871f352a22ab8dcf45f \ - --hash=sha256:75acbbeb05b86bc53cbe7b7e6fe00fbcf82ad7c684b3ad82e3d711da9ba287d3 \ - --hash=sha256:8069c5179902dcdce0be9bfc8235347fdbac249d23bd90514b7a47a72d9fecf4 \ - --hash=sha256:846e193e103b41e984ac921b335df59195356ce3f71dcfd155aa79c603873b84 \ - --hash=sha256:8594f42df584e5b4bb9281799698403f7af489fba84c34d53d1c4bfb71b7c4e7 \ - --hash=sha256:86510e3f5eca0ab87429dd77fafc04693195eec7fd6a137c389c3eeb4cfb77c6 \ - --hash=sha256:8853a3bf12afddfdf15f57c4b02d7ded92c7a75a5d7331d19f4f9572a89c17e6 \ - --hash=sha256:88a58d8ac0cc0e7f3a014509f0455248a76629ca9b604eca7dc5927cc593c5e9 \ - --hash=sha256:8ba470552b48e5835f1d23ecb936bb7f71d206f9dfeee64245f30c3270b994de \ - --hash=sha256:8c676b587da5673d3c75bd67dd2a8cdfeb282ca38a30f37950511766b26858c4 \ - --hash=sha256:8ec4a89295cd6cd4d1058a5e6aec6bf51e0eaaf9714774e1bfac7cfc9051db47 \ - --hash=sha256:94f3e1780abb45062287b4614a5bc0874519c86a777d4a7ad34978e86428b8dd \ - --hash=sha256:9a0f748eaa434a41fccf8e1ee7a3eed68af1b690e75328fd7a60af123c193b50 \ - --hash=sha256:a5629742881bcbc1f42e840af185fd4d83a5edeb96475a575f4da50d6ede337c \ - --hash=sha256:a65149d8ada1055029fcb665452b2814fe7d7082fcb0c5bed6db851cb69b2086 \ - --hash=sha256:b3c5ac4bed7519088103d9450a1107f76308ecf91d6dabc8a33a2fcfb18d0fba \ - --hash=sha256:b4fd7bd29610a83a8c9b564d457cf5bd92b4e11e79a4ee4716a63c959699b306 \ - --hash=sha256:bcd1fb5bb7b07f64c15618c89efcc2cfa3e95f0e3bcdbaf4642509de1942a699 \ - --hash=sha256:c12b5ae868897c7338519c03049a806af85b9b8c237b7d675b8c5e089e4a618e \ - --hash=sha256:c26845094b1af3c91852745ae78e3ea47abf3dbcd1cf962f16b9a5fbe3ee8488 \ - --hash=sha256:c6a660307ca9d4867caa8d9ca2c2658ab685de83792d1876274991adec7b93fa \ - --hash=sha256:c809a70e43c7977c4a42aefd62f0131823ebf7dd73556fa5d5950f5b354087e2 \ - --hash=sha256:c8b2351c85d855293a299038e1f89db92a2f35e8d2f783489c6f0b2b5f3fe8a3 \ - --hash=sha256:cb929ca942d0ec4fac404cbf520ee6cac37bf35be479b970c4ffadf2b6a1cad9 \ - --hash=sha256:d2c0a187a92a1cb5ef2c8ed5412dd8d4334272617f532d4ad4de31e0495bd923 \ - --hash=sha256:d69bfd8ec3219ae71bcde1f942b728903cad25fafe3100ba2258b973bd2bc1b2 \ - --hash=sha256:daffdf51ee5db69a82dd127eabecce20729e21f7a3680cf7cbb23f0829189790 \ - --hash=sha256:e58876c91f97b0952eb766123bfef372792ab3f4e3e1f1a2267834c2ab131734 \ - --hash=sha256:eda2616eb2313cbb3eebbe51f19362eb434b18e3bb599466a1ffa76a033fb916 \ - --hash=sha256:ee217c198f2e41f184f3869f3e485557296d505b5195c513b2bfe0062dc537f1 \ - --hash=sha256:f02541ef64077f22bf4924f225c0fd1248c168f86e4b7abdedd87d6ebaceab0f \ - --hash=sha256:f1b82c27e89fffc6da125d5eb0ca6e68017faf5efc078128cfaa42cf5cb38798 \ - --hash=sha256:fba162b8872d30fea8c52b258a542c5dfd7b235fb5cb352240c8d63b414013eb \ - --hash=sha256:fbbcb7b57dc9c794843e3d1258c0fbf0f48656d46ffe9e09b63bbd6e8cd5d0a2 \ - --hash=sha256:fcb4621042ac4b7865c179bb972ed0da0218a076dc1820ffc48b1d74c1e37fe9 - # via - # -r build/test-requirements.txt - # matplotlib -pluggy==1.5.0 \ - --hash=sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1 \ - --hash=sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669 - # via pytest -portpicker==1.6.0 \ - --hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \ - --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa - # via -r build/test-requirements.txt -psutil==5.9.8 \ - --hash=sha256:02615ed8c5ea222323408ceba16c60e99c3f91639b07da6373fb7e6539abc56d \ - --hash=sha256:05806de88103b25903dff19bb6692bd2e714ccf9e668d050d144012055cbca73 \ - --hash=sha256:26bd09967ae00920df88e0352a91cff1a78f8d69b3ecabbfe733610c0af486c8 \ - --hash=sha256:27cc40c3493bb10de1be4b3f07cae4c010ce715290a5be22b98493509c6299e2 \ - --hash=sha256:36f435891adb138ed3c9e58c6af3e2e6ca9ac2f365efe1f9cfef2794e6c93b4e \ - --hash=sha256:50187900d73c1381ba1454cf40308c2bf6f34268518b3f36a9b663ca87e65e36 \ - --hash=sha256:611052c4bc70432ec770d5d54f64206aa7203a101ec273a0cd82418c86503bb7 \ - --hash=sha256:6be126e3225486dff286a8fb9a06246a5253f4c7c53b475ea5f5ac934e64194c \ - --hash=sha256:7d79560ad97af658a0f6adfef8b834b53f64746d45b403f225b85c5c2c140eee \ - --hash=sha256:8cb6403ce6d8e047495a701dc7c5bd788add903f8986d523e3e20b98b733e421 \ - --hash=sha256:8db4c1b57507eef143a15a6884ca10f7c73876cdf5d51e713151c1236a0e68cf \ - --hash=sha256:aee678c8720623dc456fa20659af736241f575d79429a0e5e9cf88ae0605cc81 \ - --hash=sha256:bc56c2a1b0d15aa3eaa5a60c9f3f8e3e565303b465dbf57a1b730e7a2b9844e0 \ - --hash=sha256:bd1184ceb3f87651a67b2708d4c3338e9b10c5df903f2e3776b62303b26cb631 \ - --hash=sha256:d06016f7f8625a1825ba3732081d77c94589dca78b7a3fc072194851e88461a4 \ - --hash=sha256:d16bbddf0693323b8c6123dd804100241da461e41d6e332fb0ba6058f630f8c8 - # via portpicker -pyelftools==0.31 \ - --hash=sha256:c774416b10310156879443b81187d182d8d9ee499660380e645918b50bc88f99 \ - --hash=sha256:f52de7b3c7e8c64c8abc04a79a1cf37ac5fb0b8a49809827130b858944840607 - # via auditwheel -pygments==2.18.0 \ - --hash=sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199 \ - --hash=sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a - # via rich -pyparsing==3.1.2 \ - --hash=sha256:a1bac0ce561155ecc3ed78ca94d3c9378656ad4c94c1270de543f621420f94ad \ - --hash=sha256:f9db75911801ed778fe61bb643079ff86601aca99fcae6345aa67292038fb742 - # via matplotlib -pyproject-hooks==1.1.0 \ - --hash=sha256:4b37730834edbd6bd37f26ece6b44802fb1c1ee2ece0e54ddff8bfc06db86965 \ - --hash=sha256:7ceeefe9aec63a1064c18d939bdc3adf2d8aa1988a510afec15151578b232aa2 - # via build -pytest==8.2.0 \ - --hash=sha256:1733f0620f6cda4095bbf0d9ff8022486e91892245bb9e7d5542c018f612f233 \ - --hash=sha256:d507d4482197eac0ba2bae2e9babf0672eb333017bcedaa5fb1a3d42c1174b3f - # via pytest-xdist -pytest-xdist==3.6.1 \ - --hash=sha256:9ed4adfb68a016610848639bb7e02c9352d5d9f03d04809919e2dafc3be4cca7 \ - --hash=sha256:ead156a4db231eec769737f57668ef58a2084a34b2e55c4a8fa20d861107300d - # via -r build/test-requirements.txt -python-dateutil==2.9.0.post0 \ - --hash=sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3 \ - --hash=sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427 - # via matplotlib -rich==13.7.1 \ - --hash=sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222 \ - --hash=sha256:9be308cb1fe2f1f57d67ce99e95af38a1e2bc71ad9813b0e247cf7ffbcc3a432 - # via -r build/test-requirements.txt -scipy==1.13.1 \ - --hash=sha256:017367484ce5498445aade74b1d5ab377acdc65e27095155e448c88497755a5d \ - --hash=sha256:095a87a0312b08dfd6a6155cbbd310a8c51800fc931b8c0b84003014b874ed3c \ - --hash=sha256:20335853b85e9a49ff7572ab453794298bcf0354d8068c5f6775a0eabf350aca \ - --hash=sha256:27e52b09c0d3a1d5b63e1105f24177e544a222b43611aaf5bc44d4a0979e32f9 \ - --hash=sha256:2831f0dc9c5ea9edd6e51e6e769b655f08ec6db6e2e10f86ef39bd32eb11da54 \ - --hash=sha256:2ac65fb503dad64218c228e2dc2d0a0193f7904747db43014645ae139c8fad16 \ - --hash=sha256:392e4ec766654852c25ebad4f64e4e584cf19820b980bc04960bca0b0cd6eaa2 \ - --hash=sha256:436bbb42a94a8aeef855d755ce5a465479c721e9d684de76bf61a62e7c2b81d5 \ - --hash=sha256:45484bee6d65633752c490404513b9ef02475b4284c4cfab0ef946def50b3f59 \ - --hash=sha256:54f430b00f0133e2224c3ba42b805bfd0086fe488835effa33fa291561932326 \ - --hash=sha256:5713f62f781eebd8d597eb3f88b8bf9274e79eeabf63afb4a737abc6c84ad37b \ - --hash=sha256:5d72782f39716b2b3509cd7c33cdc08c96f2f4d2b06d51e52fb45a19ca0c86a1 \ - --hash=sha256:637e98dcf185ba7f8e663e122ebf908c4702420477ae52a04f9908707456ba4d \ - --hash=sha256:8335549ebbca860c52bf3d02f80784e91a004b71b059e3eea9678ba994796a24 \ - --hash=sha256:949ae67db5fa78a86e8fa644b9a6b07252f449dcf74247108c50e1d20d2b4627 \ - --hash=sha256:a014c2b3697bde71724244f63de2476925596c24285c7a637364761f8710891c \ - --hash=sha256:a78b4b3345f1b6f68a763c6e25c0c9a23a9fd0f39f5f3d200efe8feda560a5fa \ - --hash=sha256:cdd7dacfb95fea358916410ec61bbc20440f7860333aee6d882bb8046264e949 \ - --hash=sha256:cfa31f1def5c819b19ecc3a8b52d28ffdcc7ed52bb20c9a7589669dd3c250989 \ - --hash=sha256:d533654b7d221a6a97304ab63c41c96473ff04459e404b83275b60aa8f4b7004 \ - --hash=sha256:d605e9c23906d1994f55ace80e0125c587f96c020037ea6aa98d01b4bd2e222f \ - --hash=sha256:de3ade0e53bc1f21358aa74ff4830235d716211d7d077e340c7349bc3542e884 \ - --hash=sha256:e89369d27f9e7b0884ae559a3a956e77c02114cc60a6058b4e5011572eea9299 \ - --hash=sha256:eccfa1906eacc02de42d70ef4aecea45415f5be17e72b61bafcfd329bdc52e94 \ - --hash=sha256:f26264b282b9da0952a024ae34710c2aff7d27480ee91a2e82b7b7073c24722f - # via -r build/requirements.in -six==1.16.0 \ - --hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \ - --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 - # via python-dateutil -sortedcontainers==2.4.0 \ - --hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \ - --hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0 - # via hypothesis -tomli==2.0.1 \ - --hash=sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc \ - --hash=sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f - # via - # build - # pytest -typing-extensions==4.12.0rc1 \ - --hash=sha256:be199d06d8f09ca2c9425e3aa04a9afba33e892fe079dea959e72df7f8442343 \ - --hash=sha256:f933a7b288a919ca97adbff656e52ff81f7ff25d98a2aabb9355ca4090f772fe - # via etils -wheel==0.43.0 \ - --hash=sha256:465ef92c69fa5c5da2d1cf8ac40559a8c940886afcef87dcf14b9470862f1d85 \ - --hash=sha256:55c570405f142630c6b9f72fe09d9b67cf1477fcf543ae5b8dcb1f5b7377da81 - # via -r build/test-requirements.txt -zipp==3.18.2 \ - --hash=sha256:6278d9ddbcfb1f1089a88fde84481528b07b0e10474e09dcfe53dad4069fa059 \ - --hash=sha256:dce197b859eb796242b0622af1b8beb0a722d52aa2f57133ead08edd5bf5374e - # via etils -zstandard==0.22.0 \ - --hash=sha256:11f0d1aab9516a497137b41e3d3ed4bbf7b2ee2abc79e5c8b010ad286d7464bd \ - --hash=sha256:1958100b8a1cc3f27fa21071a55cb2ed32e9e5df4c3c6e661c193437f171cba2 \ - --hash=sha256:1a90ba9a4c9c884bb876a14be2b1d216609385efb180393df40e5172e7ecf356 \ - --hash=sha256:1d43501f5f31e22baf822720d82b5547f8a08f5386a883b32584a185675c8fbf \ - --hash=sha256:23d2b3c2b8e7e5a6cb7922f7c27d73a9a615f0a5ab5d0e03dd533c477de23004 \ - --hash=sha256:2612e9bb4977381184bb2463150336d0f7e014d6bb5d4a370f9a372d21916f69 \ - --hash=sha256:275df437ab03f8c033b8a2c181e51716c32d831082d93ce48002a5227ec93019 \ - --hash=sha256:2ac9957bc6d2403c4772c890916bf181b2653640da98f32e04b96e4d6fb3252a \ - --hash=sha256:2b11ea433db22e720758cba584c9d661077121fcf60ab43351950ded20283440 \ - --hash=sha256:2fdd53b806786bd6112d97c1f1e7841e5e4daa06810ab4b284026a1a0e484c0b \ - --hash=sha256:33591d59f4956c9812f8063eff2e2c0065bc02050837f152574069f5f9f17775 \ - --hash=sha256:36a47636c3de227cd765e25a21dc5dace00539b82ddd99ee36abae38178eff9e \ - --hash=sha256:39b2853efc9403927f9065cc48c9980649462acbdf81cd4f0cb773af2fd734bc \ - --hash=sha256:3db41c5e49ef73641d5111554e1d1d3af106410a6c1fb52cf68912ba7a343a0d \ - --hash=sha256:445b47bc32de69d990ad0f34da0e20f535914623d1e506e74d6bc5c9dc40bb09 \ - --hash=sha256:466e6ad8caefb589ed281c076deb6f0cd330e8bc13c5035854ffb9c2014b118c \ - --hash=sha256:48f260e4c7294ef275744210a4010f116048e0c95857befb7462e033f09442fe \ - --hash=sha256:4ac59d5d6910b220141c1737b79d4a5aa9e57466e7469a012ed42ce2d3995e88 \ - --hash=sha256:53866a9d8ab363271c9e80c7c2e9441814961d47f88c9bc3b248142c32141d94 \ - --hash=sha256:589402548251056878d2e7c8859286eb91bd841af117dbe4ab000e6450987e08 \ - --hash=sha256:68953dc84b244b053c0d5f137a21ae8287ecf51b20872eccf8eaac0302d3e3b0 \ - --hash=sha256:6c25b8eb733d4e741246151d895dd0308137532737f337411160ff69ca24f93a \ - --hash=sha256:7034d381789f45576ec3f1fa0e15d741828146439228dc3f7c59856c5bcd3292 \ - --hash=sha256:73a1d6bd01961e9fd447162e137ed949c01bdb830dfca487c4a14e9742dccc93 \ - --hash=sha256:8226a33c542bcb54cd6bd0a366067b610b41713b64c9abec1bc4533d69f51e70 \ - --hash=sha256:888196c9c8893a1e8ff5e89b8f894e7f4f0e64a5af4d8f3c410f0319128bb2f8 \ - --hash=sha256:88c5b4b47a8a138338a07fc94e2ba3b1535f69247670abfe422de4e0b344aae2 \ - --hash=sha256:8a1b2effa96a5f019e72874969394edd393e2fbd6414a8208fea363a22803b45 \ - --hash=sha256:93e1856c8313bc688d5df069e106a4bc962eef3d13372020cc6e3ebf5e045202 \ - --hash=sha256:9501f36fac6b875c124243a379267d879262480bf85b1dbda61f5ad4d01b75a3 \ - --hash=sha256:959665072bd60f45c5b6b5d711f15bdefc9849dd5da9fb6c873e35f5d34d8cfb \ - --hash=sha256:a1d67d0d53d2a138f9e29d8acdabe11310c185e36f0a848efa104d4e40b808e4 \ - --hash=sha256:a493d470183ee620a3df1e6e55b3e4de8143c0ba1b16f3ded83208ea8ddfd91d \ - --hash=sha256:a7ccf5825fd71d4542c8ab28d4d482aace885f5ebe4b40faaa290eed8e095a4c \ - --hash=sha256:a88b7df61a292603e7cd662d92565d915796b094ffb3d206579aaebac6b85d5f \ - --hash=sha256:a97079b955b00b732c6f280d5023e0eefe359045e8b83b08cf0333af9ec78f26 \ - --hash=sha256:d22fdef58976457c65e2796e6730a3ea4a254f3ba83777ecfc8592ff8d77d303 \ - --hash=sha256:d75f693bb4e92c335e0645e8845e553cd09dc91616412d1d4650da835b5449df \ - --hash=sha256:d8593f8464fb64d58e8cb0b905b272d40184eac9a18d83cf8c10749c3eafcd7e \ - --hash=sha256:d8fff0f0c1d8bc5d866762ae95bd99d53282337af1be9dc0d88506b340e74b73 \ - --hash=sha256:de20a212ef3d00d609d0b22eb7cc798d5a69035e81839f549b538eff4105d01c \ - --hash=sha256:e9e9d4e2e336c529d4c435baad846a181e39a982f823f7e4495ec0b0ec8538d2 \ - --hash=sha256:f058a77ef0ece4e210bb0450e68408d4223f728b109764676e1a13537d056bb0 \ - --hash=sha256:f1a4b358947a65b94e2501ce3e078bbc929b039ede4679ddb0460829b12f7375 \ - --hash=sha256:f9b2cde1cd1b2a10246dbc143ba49d942d14fb3d2b4bccf4618d475c65464912 \ - --hash=sha256:fe3390c538f12437b859d815040763abc728955a52ca6ff9c5d4ac707c4ad98e - # via -r build/requirements.in - -# The following packages are considered to be unsafe in a requirements file: -setuptools==76.0.0 \ - --hash=sha256:199466a166ff664970d0ee145839f5582cb9bca7a0a3a2e795b6a9cb2308e9c6 \ - --hash=sha256:43b4ee60e10b0d0ee98ad11918e114c70701bc6051662a9a675a0496c1a158f4 - # via - # -r build/requirements.in - # -r build/test-requirements.txt diff --git a/build/requirements_lock_3_11.txt b/build/requirements_lock_3_11.txt index 8446e8361505..6e9a776a38c5 100644 --- a/build/requirements_lock_3_11.txt +++ b/build/requirements_lock_3_11.txt @@ -4,693 +4,1754 @@ # # bazel run //build:requirements.update # -absl-py==2.1.0 \ - --hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \ - --hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff - # via -r build/test-requirements.txt -attrs==23.2.0 \ - --hash=sha256:935dc3b529c262f6cf76e50877d35a4bd3c1de194fd41f47a2b7ae8f19971f30 \ - --hash=sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1 +--index-url https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple + +absl-py==2.3.1 \ + --hash=sha256:a97820526f7fbfd2ec1bce83f3f25e3a14840dac0d8e02a0b71cd75db3f77fc9 \ + --hash=sha256:eeecf07f0c2a93ace0772c92e596ace6d3d3996c042b2128459aaae2a76de11d + # via + # -r build/test-requirements.txt + # keras + # tensorboard + # tensorflow +astunparse==1.6.3 \ + --hash=sha256:5ad93a8456f0d084c3456d059fd9a92cce667963232cbf763eac3bc5b7940872 \ + --hash=sha256:c2652417f2c8b5bb325c885ae329bdf3f86424075c4fd1a128674bc6fba4b8e8 + # via tensorflow +attrs==25.4.0 \ + --hash=sha256:16d5969b87f0859ef33a48b35d55ac1be6e42ae49d5e853b597db70c35c57e11 \ + --hash=sha256:adcf7e2a1fb3b36ac48d97835bb6d8ade15b8dcce26aba8bf1d14847b57a3373 # via hypothesis -auditwheel==6.1.0 \ - --hash=sha256:3bdc686e774cf9e355e924b0fe5a562d55caa385d72234ffe7b81b378dba360f \ - --hash=sha256:e52f734861859e3743eb29fcac7da9c4921a1e4bea58f954b52f2926f8e9e364 +auditwheel==6.5.0 \ + --hash=sha256:4fbcbd5854054bb1dd7870db03727b871b96b18147db57259561c058603987d7 \ + --hash=sha256:e08d2eede0259be6feff597d041c06175026e93248a1a97143acc52c57714d80 # via -r build/test-requirements.txt -build==1.2.1 \ - --hash=sha256:526263f4870c26f26c433545579475377b2b7588b6f1eac76a001e873ae3e19d \ - --hash=sha256:75e10f767a433d9a86e50d83f418e83efc18ede923ee5ff7df93b6cb0306c5d4 - # via -r build/test-requirements.txt -cloudpickle==3.0.0 \ - --hash=sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7 \ - --hash=sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882 +build==1.3.0 \ + --hash=sha256:698edd0ea270bde950f53aed21f3a0135672206f3911e0176261a31e0e07b397 \ + --hash=sha256:7145f0b5061ba90a1500d60bd1b13ca0a8a4cebdd0cc16ed8adf1c0e739f43b4 + # via -r build/requirements.in +certifi==2025.11.12 \ + --hash=sha256:97de8790030bbd5c2d96b7ec782fc2f7820ef8dba6db909ccf95449f2d062d4b \ + --hash=sha256:d8ab5478f2ecd78af242878415affce761ca6bc54a22a27e026d7c25357c3316 + # via requests +charset-normalizer==3.4.4 \ + --hash=sha256:027f6de494925c0ab2a55eab46ae5129951638a49a34d87f4c3eda90f696b4ad \ + --hash=sha256:077fbb858e903c73f6c9db43374fd213b0b6a778106bc7032446a8e8b5b38b93 \ + --hash=sha256:0a98e6759f854bd25a58a73fa88833fba3b7c491169f86ce1180c948ab3fd394 \ + --hash=sha256:0d3d8f15c07f86e9ff82319b3d9ef6f4bf907608f53fe9d92b28ea9ae3d1fd89 \ + --hash=sha256:0f04b14ffe5fdc8c4933862d8306109a2c51e0704acfa35d51598eb45a1e89fc \ + --hash=sha256:11d694519d7f29d6cd09f6ac70028dba10f92f6cdd059096db198c283794ac86 \ + --hash=sha256:194f08cbb32dc406d6e1aea671a68be0823673db2832b38405deba2fb0d88f63 \ + --hash=sha256:1bee1e43c28aa63cb16e5c14e582580546b08e535299b8b6158a7c9c768a1f3d \ + --hash=sha256:21d142cc6c0ec30d2efee5068ca36c128a30b0f2c53c1c07bd78cb6bc1d3be5f \ + --hash=sha256:2437418e20515acec67d86e12bf70056a33abdacb5cb1655042f6538d6b085a8 \ + --hash=sha256:244bfb999c71b35de57821b8ea746b24e863398194a4014e4c76adc2bbdfeff0 \ + --hash=sha256:2677acec1a2f8ef614c6888b5b4ae4060cc184174a938ed4e8ef690e15d3e505 \ + --hash=sha256:277e970e750505ed74c832b4bf75dac7476262ee2a013f5574dd49075879e161 \ + --hash=sha256:2aaba3b0819274cc41757a1da876f810a3e4d7b6eb25699253a4effef9e8e4af \ + --hash=sha256:2b7d8f6c26245217bd2ad053761201e9f9680f8ce52f0fcd8d0755aeae5b2152 \ + --hash=sha256:2c9d3c380143a1fedbff95a312aa798578371eb29da42106a29019368a475318 \ + --hash=sha256:3162d5d8ce1bb98dd51af660f2121c55d0fa541b46dff7bb9b9f86ea1d87de72 \ + --hash=sha256:31fd66405eaf47bb62e8cd575dc621c56c668f27d46a61d975a249930dd5e2a4 \ + --hash=sha256:362d61fd13843997c1c446760ef36f240cf81d3ebf74ac62652aebaf7838561e \ + --hash=sha256:376bec83a63b8021bb5c8ea75e21c4ccb86e7e45ca4eb81146091b56599b80c3 \ + --hash=sha256:44c2a8734b333e0578090c4cd6b16f275e07aa6614ca8715e6c038e865e70576 \ + --hash=sha256:47cc91b2f4dd2833fddaedd2893006b0106129d4b94fdb6af1f4ce5a9965577c \ + --hash=sha256:4902828217069c3c5c71094537a8e623f5d097858ac6ca8252f7b4d10b7560f1 \ + --hash=sha256:4bd5d4137d500351a30687c2d3971758aac9a19208fc110ccb9d7188fbe709e8 \ + --hash=sha256:4fe7859a4e3e8457458e2ff592f15ccb02f3da787fcd31e0183879c3ad4692a1 \ + --hash=sha256:542d2cee80be6f80247095cc36c418f7bddd14f4a6de45af91dfad36d817bba2 \ + --hash=sha256:554af85e960429cf30784dd47447d5125aaa3b99a6f0683589dbd27e2f45da44 \ + --hash=sha256:5833d2c39d8896e4e19b689ffc198f08ea58116bee26dea51e362ecc7cd3ed26 \ + --hash=sha256:5947809c8a2417be3267efc979c47d76a079758166f7d43ef5ae8e9f92751f88 \ + --hash=sha256:5ae497466c7901d54b639cf42d5b8c1b6a4fead55215500d2f486d34db48d016 \ + --hash=sha256:5bd2293095d766545ec1a8f612559f6b40abc0eb18bb2f5d1171872d34036ede \ + --hash=sha256:5bfbb1b9acf3334612667b61bd3002196fe2a1eb4dd74d247e0f2a4d50ec9bbf \ + --hash=sha256:5cb4d72eea50c8868f5288b7f7f33ed276118325c1dfd3957089f6b519e1382a \ + --hash=sha256:5dbe56a36425d26d6cfb40ce79c314a2e4dd6211d51d6d2191c00bed34f354cc \ + --hash=sha256:5f819d5fe9234f9f82d75bdfa9aef3a3d72c4d24a6e57aeaebba32a704553aa0 \ + --hash=sha256:64b55f9dce520635f018f907ff1b0df1fdc31f2795a922fb49dd14fbcdf48c84 \ + --hash=sha256:6515f3182dbe4ea06ced2d9e8666d97b46ef4c75e326b79bb624110f122551db \ + --hash=sha256:65e2befcd84bc6f37095f5961e68a6f077bf44946771354a28ad434c2cce0ae1 \ + --hash=sha256:6aee717dcfead04c6eb1ce3bd29ac1e22663cdea57f943c87d1eab9a025438d7 \ + --hash=sha256:6b39f987ae8ccdf0d2642338faf2abb1862340facc796048b604ef14919e55ed \ + --hash=sha256:6e1fcf0720908f200cd21aa4e6750a48ff6ce4afe7ff5a79a90d5ed8a08296f8 \ + --hash=sha256:74018750915ee7ad843a774364e13a3db91682f26142baddf775342c3f5b1133 \ + --hash=sha256:74664978bb272435107de04e36db5a9735e78232b85b77d45cfb38f758efd33e \ + --hash=sha256:74bb723680f9f7a6234dcf67aea57e708ec1fbdf5699fb91dfd6f511b0a320ef \ + --hash=sha256:752944c7ffbfdd10c074dc58ec2d5a8a4cd9493b314d367c14d24c17684ddd14 \ + --hash=sha256:778d2e08eda00f4256d7f672ca9fef386071c9202f5e4607920b86d7803387f2 \ + --hash=sha256:780236ac706e66881f3b7f2f32dfe90507a09e67d1d454c762cf642e6e1586e0 \ + --hash=sha256:798d75d81754988d2565bff1b97ba5a44411867c0cf32b77a7e8f8d84796b10d \ + --hash=sha256:799a7a5e4fb2d5898c60b640fd4981d6a25f1c11790935a44ce38c54e985f828 \ + --hash=sha256:7a32c560861a02ff789ad905a2fe94e3f840803362c84fecf1851cb4cf3dc37f \ + --hash=sha256:7c308f7e26e4363d79df40ca5b2be1c6ba9f02bdbccfed5abddb7859a6ce72cf \ + --hash=sha256:7fa17817dc5625de8a027cb8b26d9fefa3ea28c8253929b8d6649e705d2835b6 \ + --hash=sha256:81d5eb2a312700f4ecaa977a8235b634ce853200e828fbadf3a9c50bab278328 \ + --hash=sha256:82004af6c302b5d3ab2cfc4cc5f29db16123b1a8417f2e25f9066f91d4411090 \ + --hash=sha256:837c2ce8c5a65a2035be9b3569c684358dfbf109fd3b6969630a87535495ceaa \ + --hash=sha256:840c25fb618a231545cbab0564a799f101b63b9901f2569faecd6b222ac72381 \ + --hash=sha256:8a6562c3700cce886c5be75ade4a5db4214fda19fede41d9792d100288d8f94c \ + --hash=sha256:8af65f14dc14a79b924524b1e7fffe304517b2bff5a58bf64f30b98bbc5079eb \ + --hash=sha256:8ef3c867360f88ac904fd3f5e1f902f13307af9052646963ee08ff4f131adafc \ + --hash=sha256:94537985111c35f28720e43603b8e7b43a6ecfb2ce1d3058bbe955b73404e21a \ + --hash=sha256:99ae2cffebb06e6c22bdc25801d7b30f503cc87dbd283479e7b606f70aff57ec \ + --hash=sha256:9a26f18905b8dd5d685d6d07b0cdf98a79f3c7a918906af7cc143ea2e164c8bc \ + --hash=sha256:9b35f4c90079ff2e2edc5b26c0c77925e5d2d255c42c74fdb70fb49b172726ac \ + --hash=sha256:9cd98cdc06614a2f768d2b7286d66805f94c48cde050acdbbb7db2600ab3197e \ + --hash=sha256:9d1bb833febdff5c8927f922386db610b49db6e0d4f4ee29601d71e7c2694313 \ + --hash=sha256:9f7fcd74d410a36883701fafa2482a6af2ff5ba96b9a620e9e0721e28ead5569 \ + --hash=sha256:a59cb51917aa591b1c4e6a43c132f0cdc3c76dbad6155df4e28ee626cc77a0a3 \ + --hash=sha256:a61900df84c667873b292c3de315a786dd8dac506704dea57bc957bd31e22c7d \ + --hash=sha256:a79cfe37875f822425b89a82333404539ae63dbdddf97f84dcbc3d339aae9525 \ + --hash=sha256:a8a8b89589086a25749f471e6a900d3f662d1d3b6e2e59dcecf787b1cc3a1894 \ + --hash=sha256:a8bf8d0f749c5757af2142fe7903a9df1d2e8aa3841559b2bad34b08d0e2bcf3 \ + --hash=sha256:a9768c477b9d7bd54bc0c86dbaebdec6f03306675526c9927c0e8a04e8f94af9 \ + --hash=sha256:ac1c4a689edcc530fc9d9aa11f5774b9e2f33f9a0c6a57864e90908f5208d30a \ + --hash=sha256:af2d8c67d8e573d6de5bc30cdb27e9b95e49115cd9baad5ddbd1a6207aaa82a9 \ + --hash=sha256:b435cba5f4f750aa6c0a0d92c541fb79f69a387c91e61f1795227e4ed9cece14 \ + --hash=sha256:b5b290ccc2a263e8d185130284f8501e3e36c5e02750fc6b6bdeb2e9e96f1e25 \ + --hash=sha256:b5d84d37db046c5ca74ee7bb47dd6cbc13f80665fdde3e8040bdd3fb015ecb50 \ + --hash=sha256:b7cf1017d601aa35e6bb650b6ad28652c9cd78ee6caff19f3c28d03e1c80acbf \ + --hash=sha256:bc7637e2f80d8530ee4a78e878bce464f70087ce73cf7c1caf142416923b98f1 \ + --hash=sha256:c0463276121fdee9c49b98908b3a89c39be45d86d1dbaa22957e38f6321d4ce3 \ + --hash=sha256:c4ef880e27901b6cc782f1b95f82da9313c0eb95c3af699103088fa0ac3ce9ac \ + --hash=sha256:c8ae8a0f02f57a6e61203a31428fa1d677cbe50c93622b4149d5c0f319c1d19e \ + --hash=sha256:ca5862d5b3928c4940729dacc329aa9102900382fea192fc5e52eb69d6093815 \ + --hash=sha256:cb01158d8b88ee68f15949894ccc6712278243d95f344770fa7593fa2d94410c \ + --hash=sha256:cb6254dc36b47a990e59e1068afacdcd02958bdcce30bb50cc1700a8b9d624a6 \ + --hash=sha256:cc00f04ed596e9dc0da42ed17ac5e596c6ccba999ba6bd92b0e0aef2f170f2d6 \ + --hash=sha256:cd09d08005f958f370f539f186d10aec3377d55b9eeb0d796025d4886119d76e \ + --hash=sha256:cd4b7ca9984e5e7985c12bc60a6f173f3c958eae74f3ef6624bb6b26e2abbae4 \ + --hash=sha256:ce8a0633f41a967713a59c4139d29110c07e826d131a316b50ce11b1d79b4f84 \ + --hash=sha256:cead0978fc57397645f12578bfd2d5ea9138ea0fac82b2f63f7f7c6877986a69 \ + --hash=sha256:d055ec1e26e441f6187acf818b73564e6e6282709e9bcb5b63f5b23068356a15 \ + --hash=sha256:d1f13550535ad8cff21b8d757a3257963e951d96e20ec82ab44bc64aeb62a191 \ + --hash=sha256:d9c7f57c3d666a53421049053eaacdd14bbd0a528e2186fcb2e672effd053bb0 \ + --hash=sha256:d9e45d7faa48ee908174d8fe84854479ef838fc6a705c9315372eacbc2f02897 \ + --hash=sha256:da3326d9e65ef63a817ecbcc0df6e94463713b754fe293eaa03da99befb9a5bd \ + --hash=sha256:de00632ca48df9daf77a2c65a484531649261ec9f25489917f09e455cb09ddb2 \ + --hash=sha256:e1f185f86a6f3403aa2420e815904c67b2f9ebc443f045edd0de921108345794 \ + --hash=sha256:e824f1492727fa856dd6eda4f7cee25f8518a12f3c4a56a74e8095695089cf6d \ + --hash=sha256:e912091979546adf63357d7e2ccff9b44f026c075aeaf25a52d0e95ad2281074 \ + --hash=sha256:eaabd426fe94daf8fd157c32e571c85cb12e66692f15516a83a03264b08d06c3 \ + --hash=sha256:ebf3e58c7ec8a8bed6d66a75d7fb37b55e5015b03ceae72a8e7c74495551e224 \ + --hash=sha256:ecaae4149d99b1c9e7b88bb03e3221956f68fd6d50be2ef061b2381b61d20838 \ + --hash=sha256:eecbc200c7fd5ddb9a7f16c7decb07b566c29fa2161a16cf67b8d068bd21690a \ + --hash=sha256:f155a433c2ec037d4e8df17d18922c3a0d9b3232a396690f17175d2946f0218d \ + --hash=sha256:f1e34719c6ed0b92f418c7c780480b26b5d9c50349e9a9af7d76bf757530350d \ + --hash=sha256:f34be2938726fc13801220747472850852fe6b1ea75869a048d6f896838c896f \ + --hash=sha256:f820802628d2694cb7e56db99213f930856014862f3fd943d290ea8438d07ca8 \ + --hash=sha256:f8bf04158c6b607d747e93949aa60618b61312fe647a6369f88ce2ff16043490 \ + --hash=sha256:f8e160feb2aed042cd657a72acc0b481212ed28b1b9a95c0cee1621b524e1966 \ + --hash=sha256:f9d332f8c2a2fcbffe1378594431458ddbef721c1769d78e2cbc06280d8155f9 \ + --hash=sha256:fa09f53c465e532f4d3db095e0c55b615f010ad81803d383195b6b5ca6cbf5f3 \ + --hash=sha256:faa3a41b2b66b6e50f84ae4a68c64fcd0c44355741c6374813a800cd6695db9e \ + --hash=sha256:fd44c878ea55ba351104cb93cc85e74916eb8fa440ca7903e57575e97394f608 + # via requests +cloudpickle==3.1.2 \ + --hash=sha256:7fda9eb655c9c230dab534f1983763de5835249750e85fbcef43aaa30a9a2414 \ + --hash=sha256:9acb47f6afd73f60dc1df93bb801b472f05ff42fa6c84167d25cb206be1fbf4a # via -r build/test-requirements.txt colorama==0.4.6 \ --hash=sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44 \ --hash=sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6 - # via -r build/test-requirements.txt -contourpy==1.2.1 \ - --hash=sha256:00e5388f71c1a0610e6fe56b5c44ab7ba14165cdd6d695429c5cd94021e390b2 \ - --hash=sha256:10a37ae557aabf2509c79715cd20b62e4c7c28b8cd62dd7d99e5ed3ce28c3fd9 \ - --hash=sha256:11959f0ce4a6f7b76ec578576a0b61a28bdc0696194b6347ba3f1c53827178b9 \ - --hash=sha256:187fa1d4c6acc06adb0fae5544c59898ad781409e61a926ac7e84b8f276dcef4 \ - --hash=sha256:1a07fc092a4088ee952ddae19a2b2a85757b923217b7eed584fdf25f53a6e7ce \ - --hash=sha256:1cac0a8f71a041aa587410424ad46dfa6a11f6149ceb219ce7dd48f6b02b87a7 \ - --hash=sha256:1d59e739ab0e3520e62a26c60707cc3ab0365d2f8fecea74bfe4de72dc56388f \ - --hash=sha256:2855c8b0b55958265e8b5888d6a615ba02883b225f2227461aa9127c578a4922 \ - --hash=sha256:2e785e0f2ef0d567099b9ff92cbfb958d71c2d5b9259981cd9bee81bd194c9a4 \ - --hash=sha256:309be79c0a354afff9ff7da4aaed7c3257e77edf6c1b448a779329431ee79d7e \ - --hash=sha256:39f3ecaf76cd98e802f094e0d4fbc6dc9c45a8d0c4d185f0f6c2234e14e5f75b \ - --hash=sha256:457499c79fa84593f22454bbd27670227874cd2ff5d6c84e60575c8b50a69619 \ - --hash=sha256:49e70d111fee47284d9dd867c9bb9a7058a3c617274900780c43e38d90fe1205 \ - --hash=sha256:4c75507d0a55378240f781599c30e7776674dbaf883a46d1c90f37e563453480 \ - --hash=sha256:4c863140fafc615c14a4bf4efd0f4425c02230eb8ef02784c9a156461e62c965 \ - --hash=sha256:4d8908b3bee1c889e547867ca4cdc54e5ab6be6d3e078556814a22457f49423c \ - --hash=sha256:5b9eb0ca724a241683c9685a484da9d35c872fd42756574a7cfbf58af26677fd \ - --hash=sha256:6022cecf8f44e36af10bd9118ca71f371078b4c168b6e0fab43d4a889985dbb5 \ - --hash=sha256:6150ffa5c767bc6332df27157d95442c379b7dce3a38dff89c0f39b63275696f \ - --hash=sha256:62828cada4a2b850dbef89c81f5a33741898b305db244904de418cc957ff05dc \ - --hash=sha256:7b4182299f251060996af5249c286bae9361fa8c6a9cda5efc29fe8bfd6062ec \ - --hash=sha256:94b34f32646ca0414237168d68a9157cb3889f06b096612afdd296003fdd32fd \ - --hash=sha256:9ce6889abac9a42afd07a562c2d6d4b2b7134f83f18571d859b25624a331c90b \ - --hash=sha256:9cffe0f850e89d7c0012a1fb8730f75edd4320a0a731ed0c183904fe6ecfc3a9 \ - --hash=sha256:a12a813949e5066148712a0626895c26b2578874e4cc63160bb007e6df3436fe \ - --hash=sha256:a1eea9aecf761c661d096d39ed9026574de8adb2ae1c5bd7b33558af884fb2ce \ - --hash=sha256:a31f94983fecbac95e58388210427d68cd30fe8a36927980fab9c20062645609 \ - --hash=sha256:ac58bdee53cbeba2ecad824fa8159493f0bf3b8ea4e93feb06c9a465d6c87da8 \ - --hash=sha256:af3f4485884750dddd9c25cb7e3915d83c2db92488b38ccb77dd594eac84c4a0 \ - --hash=sha256:b33d2bc4f69caedcd0a275329eb2198f560b325605810895627be5d4b876bf7f \ - --hash=sha256:b59c0ffceff8d4d3996a45f2bb6f4c207f94684a96bf3d9728dbb77428dd8cb8 \ - --hash=sha256:bb6834cbd983b19f06908b45bfc2dad6ac9479ae04abe923a275b5f48f1a186b \ - --hash=sha256:bd3db01f59fdcbce5b22afad19e390260d6d0222f35a1023d9adc5690a889364 \ - --hash=sha256:bd7c23df857d488f418439686d3b10ae2fbf9bc256cd045b37a8c16575ea1040 \ - --hash=sha256:c2528d60e398c7c4c799d56f907664673a807635b857df18f7ae64d3e6ce2d9f \ - --hash=sha256:d31a63bc6e6d87f77d71e1abbd7387ab817a66733734883d1fc0021ed9bfa083 \ - --hash=sha256:d4492d82b3bc7fbb7e3610747b159869468079fe149ec5c4d771fa1f614a14df \ - --hash=sha256:ddcb8581510311e13421b1f544403c16e901c4e8f09083c881fab2be80ee31ba \ - --hash=sha256:e1d59258c3c67c865435d8fbeb35f8c59b8bef3d6f46c1f29f6123556af28445 \ - --hash=sha256:eb3315a8a236ee19b6df481fc5f997436e8ade24a9f03dfdc6bd490fea20c6da \ - --hash=sha256:ef2b055471c0eb466033760a521efb9d8a32b99ab907fc8358481a1dd29e3bd3 \ - --hash=sha256:ef5adb9a3b1d0c645ff694f9bca7702ec2c70f4d734f9922ea34de02294fdf72 \ - --hash=sha256:f32c38afb74bd98ce26de7cc74a67b40afb7b05aae7b42924ea990d51e4dac02 \ - --hash=sha256:fe0ccca550bb8e5abc22f530ec0466136379c01321fd94f30a22231e8a48d985 + # via -r build/requirements.in +contourpy==1.3.3 \ + --hash=sha256:023b44101dfe49d7d53932be418477dba359649246075c996866106da069af69 \ + --hash=sha256:07ce5ed73ecdc4a03ffe3e1b3e3c1166db35ae7584be76f65dbbe28a7791b0cc \ + --hash=sha256:083e12155b210502d0bca491432bb04d56dc3432f95a979b429f2848c3dbe880 \ + --hash=sha256:0bf67e0e3f482cb69779dd3061b534eb35ac9b17f163d851e2a547d56dba0a3a \ + --hash=sha256:0c1fc238306b35f246d61a1d416a627348b5cf0648648a031e14bb8705fcdfe8 \ + --hash=sha256:13b68d6a62db8eafaebb8039218921399baf6e47bf85006fd8529f2a08ef33fc \ + --hash=sha256:15ff10bfada4bf92ec8b31c62bf7c1834c244019b4a33095a68000d7075df470 \ + --hash=sha256:177fb367556747a686509d6fef71d221a4b198a3905fe824430e5ea0fda54eb5 \ + --hash=sha256:1cadd8b8969f060ba45ed7c1b714fe69185812ab43bd6b86a9123fe8f99c3263 \ + --hash=sha256:1fd43c3be4c8e5fd6e4f2baeae35ae18176cf2e5cced681cca908addf1cdd53b \ + --hash=sha256:22e9b1bd7a9b1d652cd77388465dc358dafcd2e217d35552424aa4f996f524f5 \ + --hash=sha256:23416f38bfd74d5d28ab8429cc4d63fa67d5068bd711a85edb1c3fb0c3e2f381 \ + --hash=sha256:283edd842a01e3dcd435b1c5116798d661378d83d36d337b8dde1d16a5fc9ba3 \ + --hash=sha256:2a2a8b627d5cc6b7c41a4beff6c5ad5eb848c88255fda4a8745f7e901b32d8e4 \ + --hash=sha256:2b7e9480ffe2b0cd2e787e4df64270e3a0440d9db8dc823312e2c940c167df7e \ + --hash=sha256:322ab1c99b008dad206d406bb61d014cf0174df491ae9d9d0fac6a6fda4f977f \ + --hash=sha256:33c82d0138c0a062380332c861387650c82e4cf1747aaa6938b9b6516762e772 \ + --hash=sha256:348ac1f5d4f1d66d3322420f01d42e43122f43616e0f194fc1c9f5d830c5b286 \ + --hash=sha256:3519428f6be58431c56581f1694ba8e50626f2dd550af225f82fb5f5814d2a42 \ + --hash=sha256:3c30273eb2a55024ff31ba7d052dde990d7d8e5450f4bbb6e913558b3d6c2301 \ + --hash=sha256:3d1a3799d62d45c18bafd41c5fa05120b96a28079f2393af559b843d1a966a77 \ + --hash=sha256:451e71b5a7d597379ef572de31eeb909a87246974d960049a9848c3bc6c41bf7 \ + --hash=sha256:459c1f020cd59fcfe6650180678a9993932d80d44ccde1fa1868977438f0b411 \ + --hash=sha256:4d00e655fcef08aba35ec9610536bfe90267d7ab5ba944f7032549c55a146da1 \ + --hash=sha256:4debd64f124ca62069f313a9cb86656ff087786016d76927ae2cf37846b006c9 \ + --hash=sha256:4feffb6537d64b84877da813a5c30f1422ea5739566abf0bd18065ac040e120a \ + --hash=sha256:50ed930df7289ff2a8d7afeb9603f8289e5704755c7e5c3bbd929c90c817164b \ + --hash=sha256:51e79c1f7470158e838808d4a996fa9bac72c498e93d8ebe5119bc1e6becb0db \ + --hash=sha256:556dba8fb6f5d8742f2923fe9457dbdd51e1049c4a43fd3986a0b14a1d815fc6 \ + --hash=sha256:598c3aaece21c503615fd59c92a3598b428b2f01bfb4b8ca9c4edeecc2438620 \ + --hash=sha256:5ed3657edf08512fc3fe81b510e35c2012fbd3081d2e26160f27ca28affec989 \ + --hash=sha256:626d60935cf668e70a5ce6ff184fd713e9683fb458898e4249b63be9e28286ea \ + --hash=sha256:644a6853d15b2512d67881586bd03f462c7ab755db95f16f14d7e238f2852c67 \ + --hash=sha256:655456777ff65c2c548b7c454af9c6f33f16c8884f11083244b5819cc214f1b5 \ + --hash=sha256:66c8a43a4f7b8df8b71ee1840e4211a3c8d93b214b213f590e18a1beca458f7d \ + --hash=sha256:6afc576f7b33cf00996e5c1102dc2a8f7cc89e39c0b55df93a0b78c1bd992b36 \ + --hash=sha256:6c3d53c796f8647d6deb1abe867daeb66dcc8a97e8455efa729516b997b8ed99 \ + --hash=sha256:709a48ef9a690e1343202916450bc48b9e51c049b089c7f79a267b46cffcdaa1 \ + --hash=sha256:70f9aad7de812d6541d29d2bbf8feb22ff7e1c299523db288004e3157ff4674e \ + --hash=sha256:8153b8bfc11e1e4d75bcb0bff1db232f9e10b274e0929de9d608027e0d34ff8b \ + --hash=sha256:87acf5963fc2b34825e5b6b048f40e3635dd547f590b04d2ab317c2619ef7ae8 \ + --hash=sha256:88df9880d507169449d434c293467418b9f6cbe82edd19284aa0409e7fdb933d \ + --hash=sha256:929ddf8c4c7f348e4c0a5a3a714b5c8542ffaa8c22954862a46ca1813b667ee7 \ + --hash=sha256:92d9abc807cf7d0e047b95ca5d957cf4792fcd04e920ca70d48add15c1a90ea7 \ + --hash=sha256:95b181891b4c71de4bb404c6621e7e2390745f887f2a026b2d99e92c17892339 \ + --hash=sha256:9e999574eddae35f1312c2b4b717b7885d4edd6cb46700e04f7f02db454e67c1 \ + --hash=sha256:a15459b0f4615b00bbd1e91f1b9e19b7e63aea7483d03d804186f278c0af2659 \ + --hash=sha256:a22738912262aa3e254e4f3cb079a95a67132fc5a063890e224393596902f5a4 \ + --hash=sha256:ab2fd90904c503739a75b7c8c5c01160130ba67944a7b77bbf36ef8054576e7f \ + --hash=sha256:ab3074b48c4e2cf1a960e6bbeb7f04566bf36b1861d5c9d4d8ac04b82e38ba20 \ + --hash=sha256:afe5a512f31ee6bd7d0dda52ec9864c984ca3d66664444f2d72e0dc4eb832e36 \ + --hash=sha256:b08a32ea2f8e42cf1d4be3169a98dd4be32bafe4f22b6c4cb4ba810fa9e5d2cb \ + --hash=sha256:b20c7c9a3bf701366556e1b1984ed2d0cedf999903c51311417cf5f591d8c78d \ + --hash=sha256:b2e8faa0ed68cb29af51edd8e24798bb661eac3bd9f65420c1887b6ca89987c8 \ + --hash=sha256:b7301b89040075c30e5768810bc96a8e8d78085b47d8be6e4c3f5a0b4ed478a0 \ + --hash=sha256:b7448cb5a725bb1e35ce88771b86fba35ef418952474492cf7c764059933ff8b \ + --hash=sha256:ca0fdcd73925568ca027e0b17ab07aad764be4706d0a925b89227e447d9737b7 \ + --hash=sha256:ca658cd1a680a5c9ea96dc61cdbae1e85c8f25849843aa799dfd3cb370ad4fbe \ + --hash=sha256:cbedb772ed74ff5be440fa8eee9bd49f64f6e3fc09436d9c7d8f1c287b121d77 \ + --hash=sha256:cd5dfcaeb10f7b7f9dc8941717c6c2ade08f587be2226222c12b25f0483ed497 \ + --hash=sha256:cf9022ef053f2694e31d630feaacb21ea24224be1c3ad0520b13d844274614fd \ + --hash=sha256:d002b6f00d73d69333dac9d0b8d5e84d9724ff9ef044fd63c5986e62b7c9e1b1 \ + --hash=sha256:d06bb1f751ba5d417047db62bca3c8fde202b8c11fb50742ab3ab962c81e8216 \ + --hash=sha256:d304906ecc71672e9c89e87c4675dc5c2645e1f4269a5063b99b0bb29f232d13 \ + --hash=sha256:e4e6b05a45525357e382909a4c1600444e2a45b4795163d3b22669285591c1ae \ + --hash=sha256:e74a9a0f5e3fff48fb5a7f2fd2b9b70a3fe014a67522f79b7cca4c0c7e43c9ae \ + --hash=sha256:ea37e7b45949df430fe649e5de8351c423430046a2af20b1c1961cae3afcda77 \ + --hash=sha256:f64836de09927cba6f79dcd00fdd7d5329f3fccc633468507079c829ca4db4e3 \ + --hash=sha256:fd6ec6be509c787f1caf6b247f0b1ca598bef13f4ddeaa126b7658215529ba0f \ + --hash=sha256:fd907ae12cd483cd83e414b12941c632a969171bf90fc937d0c9f268a31cafff \ + --hash=sha256:fd914713266421b7536de2bfa8181aa8c699432b6763a0ea64195ebe28bff6a9 \ + --hash=sha256:fde6c716d51c04b1c25d0b90364d0be954624a0ee9d60e23e850e8d48353d07a # via matplotlib cycler==0.12.1 \ --hash=sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30 \ --hash=sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c # via matplotlib -etils[epath,epy]==1.8.0 \ - --hash=sha256:f31d7f27a889457eaa44eab18ce836d24fd6d40dbbb167d38879b7296f6456ea \ - --hash=sha256:fb478f57fec202e260e54c9192b317692fd63db2d11d993e70bcdffa29cccd58 +etils[epath,epy]==1.13.0 \ + --hash=sha256:a5b60c71f95bcd2d43d4e9fb3dc3879120c1f60472bb5ce19f7a860b1d44f607 \ + --hash=sha256:d9cd4f40fbe77ad6613b7348a18132cc511237b6c076dbb89105c0b520a4c6bb # via -r build/requirements.in -execnet==2.1.1 \ - --hash=sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc \ - --hash=sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3 +execnet==2.1.2 \ + --hash=sha256:63d83bfdd9a23e35b9c6a3261412324f964c2ec8dcd8d3c6916ee9373e0befcd \ + --hash=sha256:67fba928dd5a544b783f6056f449e5e3931a5c378b128bc18501f7ea79e296ec # via pytest-xdist -filelock==3.14.0 \ - --hash=sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f \ - --hash=sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a +filelock==3.20.1 \ + --hash=sha256:15d9e9a67306188a44baa72f569d2bfd803076269365fdea0934385da4dc361a \ + --hash=sha256:b8360948b351b80f420878d8516519a2204b07aefcdcfd24912a5d33127f188c # via -r build/test-requirements.txt -flatbuffers==24.3.25 \ - --hash=sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812 \ - --hash=sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4 - # via -r build/test-requirements.txt -fonttools==4.51.0 \ - --hash=sha256:0118ef998a0699a96c7b28457f15546815015a2710a1b23a7bf6c1be60c01636 \ - --hash=sha256:0d145976194a5242fdd22df18a1b451481a88071feadf251221af110ca8f00ce \ - --hash=sha256:0e19bd9e9964a09cd2433a4b100ca7f34e34731e0758e13ba9a1ed6e5468cc0f \ - --hash=sha256:0f08c901d3866a8905363619e3741c33f0a83a680d92a9f0e575985c2634fcc1 \ - --hash=sha256:1250e818b5f8a679ad79660855528120a8f0288f8f30ec88b83db51515411fcc \ - --hash=sha256:15c94eeef6b095831067f72c825eb0e2d48bb4cea0647c1b05c981ecba2bf39f \ - --hash=sha256:1621ee57da887c17312acc4b0e7ac30d3a4fb0fec6174b2e3754a74c26bbed1e \ - --hash=sha256:180194c7fe60c989bb627d7ed5011f2bef1c4d36ecf3ec64daec8302f1ae0716 \ - --hash=sha256:278e50f6b003c6aed19bae2242b364e575bcb16304b53f2b64f6551b9c000e15 \ - --hash=sha256:32b17504696f605e9e960647c5f64b35704782a502cc26a37b800b4d69ff3c77 \ - --hash=sha256:3bee3f3bd9fa1d5ee616ccfd13b27ca605c2b4270e45715bd2883e9504735034 \ - --hash=sha256:4060acc2bfa2d8e98117828a238889f13b6f69d59f4f2d5857eece5277b829ba \ - --hash=sha256:54dcf21a2f2d06ded676e3c3f9f74b2bafded3a8ff12f0983160b13e9f2fb4a7 \ - --hash=sha256:56fc244f2585d6c00b9bcc59e6593e646cf095a96fe68d62cd4da53dd1287b55 \ - --hash=sha256:599bdb75e220241cedc6faebfafedd7670335d2e29620d207dd0378a4e9ccc5a \ - --hash=sha256:5f6bc991d1610f5c3bbe997b0233cbc234b8e82fa99fc0b2932dc1ca5e5afec0 \ - --hash=sha256:60a3409c9112aec02d5fb546f557bca6efa773dcb32ac147c6baf5f742e6258b \ - --hash=sha256:68b3fb7775a923be73e739f92f7e8a72725fd333eab24834041365d2278c3671 \ - --hash=sha256:76f1777d8b3386479ffb4a282e74318e730014d86ce60f016908d9801af9ca2a \ - --hash=sha256:806e7912c32a657fa39d2d6eb1d3012d35f841387c8fc6cf349ed70b7c340039 \ - --hash=sha256:84d7751f4468dd8cdd03ddada18b8b0857a5beec80bce9f435742abc9a851a74 \ - --hash=sha256:865a58b6e60b0938874af0968cd0553bcd88e0b2cb6e588727117bd099eef836 \ - --hash=sha256:8ac27f436e8af7779f0bb4d5425aa3535270494d3bc5459ed27de3f03151e4c2 \ - --hash=sha256:8b4850fa2ef2cfbc1d1f689bc159ef0f45d8d83298c1425838095bf53ef46308 \ - --hash=sha256:8b5ad456813d93b9c4b7ee55302208db2b45324315129d85275c01f5cb7e61a2 \ - --hash=sha256:8e2f1a4499e3b5ee82c19b5ee57f0294673125c65b0a1ff3764ea1f9db2f9ef5 \ - --hash=sha256:9696fe9f3f0c32e9a321d5268208a7cc9205a52f99b89479d1b035ed54c923f1 \ - --hash=sha256:96a48e137c36be55e68845fc4284533bda2980f8d6f835e26bca79d7e2006438 \ - --hash=sha256:a8feca65bab31479d795b0d16c9a9852902e3a3c0630678efb0b2b7941ea9c74 \ - --hash=sha256:aefa011207ed36cd280babfaa8510b8176f1a77261833e895a9d96e57e44802f \ - --hash=sha256:b2b92381f37b39ba2fc98c3a45a9d6383bfc9916a87d66ccb6553f7bdd129097 \ - --hash=sha256:b3c61423f22165541b9403ee39874dcae84cd57a9078b82e1dce8cb06b07fa2e \ - --hash=sha256:b5b48a1121117047d82695d276c2af2ee3a24ffe0f502ed581acc2673ecf1037 \ - --hash=sha256:c18b49adc721a7d0b8dfe7c3130c89b8704baf599fb396396d07d4aa69b824a1 \ - --hash=sha256:c5b8cab0c137ca229433570151b5c1fc6af212680b58b15abd797dcdd9dd5051 \ - --hash=sha256:c7e91abdfae1b5c9e3a543f48ce96013f9a08c6c9668f1e6be0beabf0a569c1b \ - --hash=sha256:cadf4e12a608ef1d13e039864f484c8a968840afa0258b0b843a0556497ea9ed \ - --hash=sha256:dc0673361331566d7a663d7ce0f6fdcbfbdc1f59c6e3ed1165ad7202ca183c68 \ - --hash=sha256:de7c29bdbdd35811f14493ffd2534b88f0ce1b9065316433b22d63ca1cd21f14 \ - --hash=sha256:e9d9298be7a05bb4801f558522adbe2feea1b0b103d5294ebf24a92dd49b78e5 \ - --hash=sha256:ee1af4be1c5afe4c96ca23badd368d8dc75f611887fb0c0dac9f71ee5d6f110e \ - --hash=sha256:f7e89853d8bea103c8e3514b9f9dc86b5b4120afb4583b57eb10dfa5afbe0936 +flatbuffers==25.9.23 \ + --hash=sha256:255538574d6cb6d0a79a17ec8bc0d30985913b87513a01cce8bcdb6b4c44d0e2 \ + --hash=sha256:676f9fa62750bb50cf531b42a0a2a118ad8f7f797a511eda12881c016f093b12 + # via + # -r build/test-requirements.txt + # tensorflow +fonttools==4.61.1 \ + --hash=sha256:0de30bfe7745c0d1ffa2b0b7048fb7123ad0d71107e10ee090fa0b16b9452e87 \ + --hash=sha256:10d88e55330e092940584774ee5e8a6971b01fc2f4d3466a1d6c158230880796 \ + --hash=sha256:11f35ad7805edba3aac1a3710d104592df59f4b957e30108ae0ba6c10b11dd75 \ + --hash=sha256:15acc09befd16a0fb8a8f62bc147e1a82817542d72184acca9ce6e0aeda9fa6d \ + --hash=sha256:17d2bf5d541add43822bcf0c43d7d847b160c9bb01d15d5007d84e2217aaa371 \ + --hash=sha256:2180f14c141d2f0f3da43f3a81bc8aa4684860f6b0e6f9e165a4831f24e6a23b \ + --hash=sha256:21e7c8d76f62ab13c9472ccf74515ca5b9a761d1bde3265152a6dc58700d895b \ + --hash=sha256:41a7170d042e8c0024703ed13b71893519a1a6d6e18e933e3ec7507a2c26a4b2 \ + --hash=sha256:41ed4b5ec103bd306bb68f81dc166e77409e5209443e5773cb4ed837bcc9b0d3 \ + --hash=sha256:497c31ce314219888c0e2fce5ad9178ca83fe5230b01a5006726cdf3ac9f24d9 \ + --hash=sha256:4c1b526c8d3f615a7b1867f38a9410849c8f4aef078535742198e942fba0e9bd \ + --hash=sha256:4d7092bb38c53bbc78e9255a59158b150bcdc115a1e3b3ce0b5f267dc35dd63c \ + --hash=sha256:4f5686e1fe5fce75d82d93c47a438a25bf0d1319d2843a926f741140b2b16e0c \ + --hash=sha256:58b0ee0ab5b1fc9921eccfe11d1435added19d6494dde14e323f25ad2bc30c56 \ + --hash=sha256:5ce02f38a754f207f2f06557523cd39a06438ba3aafc0639c477ac409fc64e37 \ + --hash=sha256:5fade934607a523614726119164ff621e8c30e8fa1ffffbbd358662056ba69f0 \ + --hash=sha256:5fe9fd43882620017add5eabb781ebfbc6998ee49b35bd7f8f79af1f9f99a958 \ + --hash=sha256:64102ca87e84261419c3747a0d20f396eb024bdbeb04c2bfb37e2891f5fadcb5 \ + --hash=sha256:664c5a68ec406f6b1547946683008576ef8b38275608e1cee6c061828171c118 \ + --hash=sha256:6675329885c44657f826ef01d9e4fb33b9158e9d93c537d84ad8399539bc6f69 \ + --hash=sha256:75c1a6dfac6abd407634420c93864a1e274ebc1c7531346d9254c0d8f6ca00f9 \ + --hash=sha256:75da8f28eff26defba42c52986de97b22106cb8f26515b7c22443ebc9c2d3261 \ + --hash=sha256:77efb033d8d7ff233385f30c62c7c79271c8885d5c9657d967ede124671bbdfb \ + --hash=sha256:78a7d3ab09dc47ac1a363a493e6112d8cabed7ba7caad5f54dbe2f08676d1b47 \ + --hash=sha256:7c7db70d57e5e1089a274cbb2b1fd635c9a24de809a231b154965d415d6c6d24 \ + --hash=sha256:8c56c488ab471628ff3bfa80964372fc13504ece601e0d97a78ee74126b2045c \ + --hash=sha256:91669ccac46bbc1d09e9273546181919064e8df73488ea087dcac3e2968df9ba \ + --hash=sha256:9b666a475a65f4e839d3d10473fad6d47e0a9db14a2f4a224029c5bfde58ad2c \ + --hash=sha256:9cfef3ab326780c04d6646f68d4b4742aae222e8b8ea1d627c74e38afcbc9d91 \ + --hash=sha256:a13fc8aeb24bad755eea8f7f9d409438eb94e82cf86b08fe77a03fbc8f6a96b1 \ + --hash=sha256:a75c301f96db737e1c5ed5fd7d77d9c34466de16095a266509e13da09751bd19 \ + --hash=sha256:a76d4cb80f41ba94a6691264be76435e5f72f2cb3cab0b092a6212855f71c2f6 \ + --hash=sha256:aed04cabe26f30c1647ef0e8fbb207516fd40fe9472e9439695f5c6998e60ac5 \ + --hash=sha256:b148b56f5de675ee16d45e769e69f87623a4944f7443850bf9a9376e628a89d2 \ + --hash=sha256:b501c862d4901792adaec7c25b1ecc749e2662543f68bb194c42ba18d6eec98d \ + --hash=sha256:b846a1fcf8beadeb9ea4f44ec5bdde393e2f1569e17d700bfc49cd69bde75881 \ + --hash=sha256:b931ae8f62db78861b0ff1ac017851764602288575d65b8e8ff1963fed419063 \ + --hash=sha256:c33ab3ca9d3ccd581d58e989d67554e42d8d4ded94ab3ade3508455fe70e65f7 \ + --hash=sha256:c6604b735bb12fef8e0efd5578c9fb5d3d8532d5001ea13a19cddf295673ee09 \ + --hash=sha256:d8db08051fc9e7d8bc622f2112511b8107d8f27cd89e2f64ec45e9825e8288da \ + --hash=sha256:d9203500f7c63545b4ce3799319fe4d9feb1a1b89b28d3cb5abd11b9dd64147e \ + --hash=sha256:dc492779501fa723b04d0ab1f5be046797fee17d27700476edc7ee9ae535a61e \ + --hash=sha256:e6bcdf33aec38d16508ce61fd81838f24c83c90a1d1b8c68982857038673d6b8 \ + --hash=sha256:e76ce097e3c57c4bcb67c5aa24a0ecdbd9f74ea9219997a707a4061fbe2707aa \ + --hash=sha256:eff1ac3cc66c2ac7cda1e64b4e2f3ffef474b7335f92fc3833fc632d595fcee6 \ + --hash=sha256:f3cb4a569029b9f291f88aafc927dd53683757e640081ca8c412781ea144565e \ + --hash=sha256:f79b168428351d11e10c5aeb61a74e1851ec221081299f4cf56036a95431c43a \ + --hash=sha256:fa646ecec9528bef693415c79a86e733c70a4965dd938e9a226b0fc64c9d2e6c \ + --hash=sha256:fe2efccb324948a11dd09d22136fe2ac8a97d6c1347cf0b58a911dcd529f66b7 \ + --hash=sha256:fff4f534200a04b4a36e7ae3cb74493afe807b517a09e99cb4faa89a34ed6ecd # via matplotlib -fsspec==2024.5.0 \ - --hash=sha256:1d021b0b0f933e3b3029ed808eb400c08ba101ca2de4b3483fbc9ca23fcee94a \ - --hash=sha256:e0fdbc446d67e182f49a70b82cf7889028a63588fde6b222521f10937b2b670c +fsspec==2025.10.0 \ + --hash=sha256:7c7712353ae7d875407f97715f0e1ffcc21e33d5b24556cb1e090ae9409ec61d \ + --hash=sha256:b6789427626f068f9a83ca4e8a3cc050850b6c0f71f99ddb4f542b8266a26a59 # via etils -hypothesis==6.102.4 \ - --hash=sha256:013df31b04a4daede13756f497e60e451963d86f426395a79f99c5d692919bbd \ - --hash=sha256:59b4d144346d5cffb482cc1bafbd21b13ff31608e8c4b3e4630339aee3e87763 +gast==0.6.0 \ + --hash=sha256:52b182313f7330389f72b069ba00f174cfe2a06411099547288839c6cbafbd54 \ + --hash=sha256:88fc5300d32c7ac6ca7b515310862f71e6fdf2c029bbec7c66c0f5dd47b6b1fb + # via tensorflow +google-pasta==0.2.0 \ + --hash=sha256:4612951da876b1a10fe3960d7226f0c7682cf901e16ac06e473b267a5afa8954 \ + --hash=sha256:b32482794a366b5366a32c92a9a9201b107821889935a02b3e51f6b432ea84ed \ + --hash=sha256:c9f2c8dfc8f96d0d5808299920721be30c9eec37f2389f28904f454565c8a16e + # via tensorflow +grpcio==1.76.0 \ + --hash=sha256:035d90bc79eaa4bed83f524331d55e35820725c9fbb00ffa1904d5550ed7ede3 \ + --hash=sha256:04bbe1bfe3a68bbfd4e52402ab7d4eb59d72d02647ae2042204326cf4bbad280 \ + --hash=sha256:063065249d9e7e0782d03d2bca50787f53bd0fb89a67de9a7b521c4a01f1989b \ + --hash=sha256:06c3d6b076e7b593905d04fdba6a0525711b3466f43b3400266f04ff735de0cd \ + --hash=sha256:08caea849a9d3c71a542827d6df9d5a69067b0a1efbea8a855633ff5d9571465 \ + --hash=sha256:0aaa82d0813fd4c8e589fac9b65d7dd88702555f702fb10417f96e2a2a6d4c0f \ + --hash=sha256:0b7604868b38c1bfd5cf72d768aedd7db41d78cb6a4a18585e33fb0f9f2363fd \ + --hash=sha256:0c37db8606c258e2ee0c56b78c62fc9dee0e901b5dbdcf816c2dd4ad652b8b0c \ + --hash=sha256:1c9b93f79f48b03ada57ea24725d83a30284a012ec27eab2cf7e50a550cbbbcc \ + --hash=sha256:2107b0c024d1b35f4083f11245c0e23846ae64d02f40b2b226684840260ed054 \ + --hash=sha256:2229ae655ec4e8999599469559e97630185fdd53ae1e8997d147b7c9b2b72cba \ + --hash=sha256:25a18e9810fbc7e7f03ec2516addc116a957f8cbb8cbc95ccc80faa072743d03 \ + --hash=sha256:26ef06c73eb53267c2b319f43e6634c7556ea37672029241a056629af27c10e2 \ + --hash=sha256:2e1743fbd7f5fa713a1b0a8ac8ebabf0ec980b5d8809ec358d488e273b9cf02a \ + --hash=sha256:32483fe2aab2c3794101c2a159070584e5db11d0aa091b2c0ea9c4fc43d0d749 \ + --hash=sha256:3bf0f392c0b806905ed174dcd8bdd5e418a40d5567a05615a030a5aeddea692d \ + --hash=sha256:3e2a27c89eb9ac3d81ec8835e12414d73536c6e620355d65102503064a4ed6eb \ + --hash=sha256:40ad3afe81676fd9ec6d9d406eda00933f218038433980aa19d401490e46ecde \ + --hash=sha256:4215d3a102bd95e2e11b5395c78562967959824156af11fa93d18fdd18050990 \ + --hash=sha256:45d59a649a82df5718fd9527ce775fd66d1af35e6d31abdcdc906a49c6822958 \ + --hash=sha256:45e0111e73f43f735d70786557dc38141185072d7ff8dc1829d6a77ac1471468 \ + --hash=sha256:479496325ce554792dba6548fae3df31a72cef7bad71ca2e12b0e58f9b336bfc \ + --hash=sha256:490fa6d203992c47c7b9e4a9d39003a0c2bcc1c9aa3c058730884bbbb0ee9f09 \ + --hash=sha256:49ce47231818806067aea3324d4bf13825b658ad662d3b25fada0bdad9b8a6af \ + --hash=sha256:4baf3cbe2f0be3289eb68ac8ae771156971848bb8aaff60bad42005539431980 \ + --hash=sha256:522175aba7af9113c48ec10cc471b9b9bd4f6ceb36aeb4544a8e2c80ed9d252d \ + --hash=sha256:5e8571632780e08526f118f74170ad8d50fb0a48c23a746bef2a6ebade3abd6f \ + --hash=sha256:615ba64c208aaceb5ec83bfdce7728b80bfeb8be97562944836a7a0a9647d882 \ + --hash=sha256:61f69297cba3950a524f61c7c8ee12e55c486cb5f7db47ff9dcee33da6f0d3ae \ + --hash=sha256:65a20de41e85648e00305c1bb09a3598f840422e522277641145a32d42dcefcc \ + --hash=sha256:6a15c17af8839b6801d554263c546c69c4d7718ad4321e3166175b37eaacca77 \ + --hash=sha256:747fa73efa9b8b1488a95d0ba1039c8e2dca0f741612d80415b1e1c560febf4e \ + --hash=sha256:7be78388d6da1a25c0d5ec506523db58b18be22d9c37d8d3a32c08be4987bd73 \ + --hash=sha256:81fd9652b37b36f16138611c7e884eb82e0cec137c40d3ef7c3f9b3ed00f6ed8 \ + --hash=sha256:83d57312a58dcfe2a3a0f9d1389b299438909a02db60e2f2ea2ae2d8034909d3 \ + --hash=sha256:8843114c0cfce61b40ad48df65abcfc00d4dba82eae8718fab5352390848c5da \ + --hash=sha256:8cc3309d8e08fd79089e13ed4819d0af72aa935dd8f435a195fd152796752ff2 \ + --hash=sha256:8ebe63ee5f8fa4296b1b8cfc743f870d10e902ca18afc65c68cf46fd39bb0783 \ + --hash=sha256:8eddfb4d203a237da6f3cc8a540dad0517d274b5a1e9e636fd8d2c79b5c1d397 \ + --hash=sha256:922fa70ba549fce362d2e2871ab542082d66e2aaf0c19480ea453905b01f384e \ + --hash=sha256:931091142fd8cc14edccc0845a79248bc155425eee9a98b2db2ea4f00a235a42 \ + --hash=sha256:971fd5a1d6e62e00d945423a567e42eb1fa678ba89072832185ca836a94daaa6 \ + --hash=sha256:980a846182ce88c4f2f7e2c22c56aefd515daeb36149d1c897f83cf57999e0b6 \ + --hash=sha256:9d9adda641db7207e800a7f089068f6f645959f2df27e870ee81d44701dd9db3 \ + --hash=sha256:9f8f757bebaaea112c00dba718fc0d3260052ce714e25804a03f93f5d1c6cc11 \ + --hash=sha256:a6ae758eb08088d36812dd5d9af7a9859c05b1e0f714470ea243694b49278e7b \ + --hash=sha256:a8c2cf1209497cf659a667d7dea88985e834c24b7c3b605e6254cbb5076d985c \ + --hash=sha256:acab0277c40eff7143c2323190ea57b9ee5fd353d8190ee9652369fae735668a \ + --hash=sha256:b331680e46239e090f5b3cead313cc772f6caa7d0fc8de349337563125361a4a \ + --hash=sha256:c088e7a90b6017307f423efbb9d1ba97a22aa2170876223f9709e9d1de0b5347 \ + --hash=sha256:d099566accf23d21037f18a2a63d323075bebace807742e4b0ac210971d4dd70 \ + --hash=sha256:d388087771c837cdb6515539f43b9d4bf0b0f23593a24054ac16f7a960be16f4 \ + --hash=sha256:dcfe41187da8992c5f40aa8c5ec086fa3672834d2be57a32384c08d5a05b4c00 \ + --hash=sha256:e6d1db20594d9daba22f90da738b1a0441a7427552cc6e2e3d1297aeddc00378 \ + --hash=sha256:ebea5cc3aa8ea72e04df9913492f9a96d9348db876f9dda3ad729cfedf7ac416 \ + --hash=sha256:ebebf83299b0cb1721a8859ea98f3a77811e35dce7609c5c963b9ad90728f886 \ + --hash=sha256:f0e34c2079d47ae9f6188211db9e777c619a21d4faba6977774e8fa43b085e48 \ + --hash=sha256:f92f88e6c033db65a5ae3d97905c8fea9c725b63e28d5a75cb73b49bda5024d8 \ + --hash=sha256:f9f7bd5faab55f47231ad8dba7787866b69f5e93bc306e3915606779bbfb4ba8 \ + --hash=sha256:fd5ef5932f6475c436c4a55e4336ebbe47bd3272be04964a03d316bbf4afbcbc \ + --hash=sha256:ff8a59ea85a1f2191a0ffcc61298c571bc566332f82e5f5be1b83c9d8e668a62 + # via + # tensorboard + # tensorflow +h5py==3.15.1 \ + --hash=sha256:01f55111ca516f5568ae7a7fc8247dfce607de331b4467ee8a9a6ed14e5422c7 \ + --hash=sha256:0e2f471688402c3404fa4e13466e373e622fd4b74b47b56cfdff7cc688209422 \ + --hash=sha256:121b2b7a4c1915d63737483b7bff14ef253020f617c2fb2811f67a4bed9ac5e8 \ + --hash=sha256:25c8843fec43b2cc368aa15afa1cdf83fc5e17b1c4e10cd3771ef6c39b72e5ce \ + --hash=sha256:28a20e1a4082a479b3d7db2169f3a5034af010b90842e75ebbf2e9e49eb4183e \ + --hash=sha256:2cbc4104d3d4aca9d6db8c0c694555e255805bfeacf9eb1349bda871e26cacbe \ + --hash=sha256:316dd0f119734f324ca7ed10b5627a2de4ea42cc4dfbcedbee026aaa361c238c \ + --hash=sha256:4411c1867b9899a25e983fff56d820a66f52ac326bbe10c7cdf7d832c9dcd883 \ + --hash=sha256:4c45802bcb711e128a6839cb6c01e9ac648dc55df045c9542a675c771f15c8d5 \ + --hash=sha256:550e51131376889656feec4aff2170efc054a7fe79eb1da3bb92e1625d1ac878 \ + --hash=sha256:59b0d63b318bf3cc06687def2b45afd75926bbc006f7b8cd2b1a231299fc8599 \ + --hash=sha256:59b25cf02411bf12e14f803fef0b80886444c7fe21a5ad17c6a28d3f08098a1e \ + --hash=sha256:5aaa330bcbf2830150c50897ea5dcbed30b5b6d56897289846ac5b9e529ec243 \ + --hash=sha256:5b849ba619a066196169763c33f9f0f02e381156d61c03e000bb0100f9950faf \ + --hash=sha256:5f4fb0567eb8517c3ecd6b3c02c4f4e9da220c8932604960fd04e24ee1254763 \ + --hash=sha256:61d5a58a9851e01ee61c932bbbb1c98fe20aba0a5674776600fb9a361c0aa652 \ + --hash=sha256:64ce3f6470adb87c06e3a8dd1b90e973699f1759ad79bfa70c230939bff356c9 \ + --hash=sha256:67e59f6c2f19a32973a40f43d9a088ae324fe228c8366e25ebc57ceebf093a6b \ + --hash=sha256:80e5bb5b9508d5d9da09f81fd00abbb3f85da8143e56b1585d59bc8ceb1dba8b \ + --hash=sha256:8a33bfd5dfcea037196f7778534b1ff7e36a7f40a89e648c8f2967292eb6898e \ + --hash=sha256:954e480433e82d3872503104f9b285d369048c3a788b2b1a00e53d1c47c98dd2 \ + --hash=sha256:99d374a21f7321a4c6ab327c4ab23bd925ad69821aeb53a1e75dd809d19f67fa \ + --hash=sha256:9c73d1d7cdb97d5b17ae385153472ce118bed607e43be11e9a9deefaa54e0734 \ + --hash=sha256:a308fd8681a864c04423c0324527237a0484e2611e3441f8089fd00ed56a8171 \ + --hash=sha256:a6d8c5a05a76aca9a494b4c53ce8a9c29023b7f64f625c6ce1841e92a362ccdf \ + --hash=sha256:ab2219dbc6fcdb6932f76b548e2b16f34a1f52b7666e998157a4dfc02e2c4123 \ + --hash=sha256:b39239947cb36a819147fc19e86b618dcb0953d1cd969f5ed71fc0de60392427 \ + --hash=sha256:b51469890e58e85d5242e43aab29f5e9c7e526b951caab354f3ded4ac88e7b76 \ + --hash=sha256:c256254a8a81e2bddc0d376e23e2a6d2dc8a1e8a2261835ed8c1281a0744cd97 \ + --hash=sha256:c8440fd8bee9500c235ecb7aa1917a0389a2adb80c209fa1cc485bd70e0d94a5 \ + --hash=sha256:c86e3ed45c4473564de55aa83b6fc9e5ead86578773dfbd93047380042e26b69 \ + --hash=sha256:c970fb80001fffabb0109eaf95116c8e7c0d3ca2de854e0901e8a04c1f098509 \ + --hash=sha256:ca8a3a22458956ee7b40d8e39c9a9dc01f82933e4c030c964f8b875592f4d831 \ + --hash=sha256:d8cb02c3a96255149ed3ac811eeea25b655d959c6dd5ce702c9a95ff11859eb5 \ + --hash=sha256:dea78b092fd80a083563ed79a3171258d4a4d307492e7cf8b2313d464c82ba52 \ + --hash=sha256:e02fe77a03f652500d8bff288cbf3675f742fc0411f5a628fa37116507dc7cc0 \ + --hash=sha256:e7f6c841efd4e6e5b7e82222eaf90819927b6d256ab0f3aca29675601f654f3c \ + --hash=sha256:f4a016df3f4a8a14d573b496e4d1964deb380e26031fc85fb40e417e9131888a \ + --hash=sha256:fa8df5267f545b4946df8ca0d93d23382191018e4cda2deda4c2cedf9a010e13 \ + --hash=sha256:fd125c131889ebbef0849f4a0e29cf363b48aba42f228d08b4079913b576bb3a + # via + # keras + # tensorflow +hypothesis==6.142.1 \ + --hash=sha256:3179cb08756562c526aaf4a9871ebbff83d2d75c03896ed0bc9c1d14097a930c \ + --hash=sha256:95a7d38fcc58e697e3020665adcb951c630cdbc8065e4b4474949e486b06bd6d # via -r build/test-requirements.txt -importlib-resources==6.4.0 \ - --hash=sha256:50d10f043df931902d4194ea07ec57960f66a80449ff867bfe782b4c486ba78c \ - --hash=sha256:cdb2b453b8046ca4e3798eb1d84f3cce1446a0e8e7b5ef4efb600f19fc398145 +idna==3.11 \ + --hash=sha256:771a87f49d9defaf64091e6e6fe9c18d4833f140bd19464795bc32d966ca37ea \ + --hash=sha256:795dafcc9c04ed0c1fb032c2aa73654d8e8c5023a7df64a53f39190ada629902 + # via requests +importlib-resources==6.5.2 \ + --hash=sha256:185f87adef5bcc288449d98fb4fba07cea78bc036455dd44c5fc4a2fe78fed2c \ + --hash=sha256:789cfdc3ed28c78b67a06acb8126751ced69a3d5f79c095a98298cd8a760ccec # via etils -iniconfig==2.0.0 \ - --hash=sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3 \ - --hash=sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374 +iniconfig==2.3.0 \ + --hash=sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730 \ + --hash=sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12 # via pytest -kiwisolver==1.4.5 \ - --hash=sha256:00bd361b903dc4bbf4eb165f24d1acbee754fce22ded24c3d56eec268658a5cf \ - --hash=sha256:040c1aebeda72197ef477a906782b5ab0d387642e93bda547336b8957c61022e \ - --hash=sha256:05703cf211d585109fcd72207a31bb170a0f22144d68298dc5e61b3c946518af \ - --hash=sha256:06f54715b7737c2fecdbf140d1afb11a33d59508a47bf11bb38ecf21dc9ab79f \ - --hash=sha256:0dc9db8e79f0036e8173c466d21ef18e1befc02de8bf8aa8dc0813a6dc8a7046 \ - --hash=sha256:0f114aa76dc1b8f636d077979c0ac22e7cd8f3493abbab152f20eb8d3cda71f3 \ - --hash=sha256:11863aa14a51fd6ec28688d76f1735f8f69ab1fabf388851a595d0721af042f5 \ - --hash=sha256:11c7de8f692fc99816e8ac50d1d1aef4f75126eefc33ac79aac02c099fd3db71 \ - --hash=sha256:11d011a7574eb3b82bcc9c1a1d35c1d7075677fdd15de527d91b46bd35e935ee \ - --hash=sha256:146d14bebb7f1dc4d5fbf74f8a6cb15ac42baadee8912eb84ac0b3b2a3dc6ac3 \ - --hash=sha256:15568384086b6df3c65353820a4473575dbad192e35010f622c6ce3eebd57af9 \ - --hash=sha256:19df6e621f6d8b4b9c4d45f40a66839294ff2bb235e64d2178f7522d9170ac5b \ - --hash=sha256:1b04139c4236a0f3aff534479b58f6f849a8b351e1314826c2d230849ed48985 \ - --hash=sha256:210ef2c3a1f03272649aff1ef992df2e724748918c4bc2d5a90352849eb40bea \ - --hash=sha256:2270953c0d8cdab5d422bee7d2007f043473f9d2999631c86a223c9db56cbd16 \ - --hash=sha256:2400873bccc260b6ae184b2b8a4fec0e4082d30648eadb7c3d9a13405d861e89 \ - --hash=sha256:2a40773c71d7ccdd3798f6489aaac9eee213d566850a9533f8d26332d626b82c \ - --hash=sha256:2c5674c4e74d939b9d91dda0fae10597ac7521768fec9e399c70a1f27e2ea2d9 \ - --hash=sha256:3195782b26fc03aa9c6913d5bad5aeb864bdc372924c093b0f1cebad603dd712 \ - --hash=sha256:31a82d498054cac9f6d0b53d02bb85811185bcb477d4b60144f915f3b3126342 \ - --hash=sha256:32d5cf40c4f7c7b3ca500f8985eb3fb3a7dfc023215e876f207956b5ea26632a \ - --hash=sha256:346f5343b9e3f00b8db8ba359350eb124b98c99efd0b408728ac6ebf38173958 \ - --hash=sha256:378a214a1e3bbf5ac4a8708304318b4f890da88c9e6a07699c4ae7174c09a68d \ - --hash=sha256:39b42c68602539407884cf70d6a480a469b93b81b7701378ba5e2328660c847a \ - --hash=sha256:3a2b053a0ab7a3960c98725cfb0bf5b48ba82f64ec95fe06f1d06c99b552e130 \ - --hash=sha256:3aba7311af82e335dd1e36ffff68aaca609ca6290c2cb6d821a39aa075d8e3ff \ - --hash=sha256:3cd32d6c13807e5c66a7cbb79f90b553642f296ae4518a60d8d76243b0ad2898 \ - --hash=sha256:3edd2fa14e68c9be82c5b16689e8d63d89fe927e56debd6e1dbce7a26a17f81b \ - --hash=sha256:4c380469bd3f970ef677bf2bcba2b6b0b4d5c75e7a020fb863ef75084efad66f \ - --hash=sha256:4e66e81a5779b65ac21764c295087de82235597a2293d18d943f8e9e32746265 \ - --hash=sha256:53abb58632235cd154176ced1ae8f0d29a6657aa1aa9decf50b899b755bc2b93 \ - --hash=sha256:5794cf59533bc3f1b1c821f7206a3617999db9fbefc345360aafe2e067514929 \ - --hash=sha256:59415f46a37f7f2efeec758353dd2eae1b07640d8ca0f0c42548ec4125492635 \ - --hash=sha256:59ec7b7c7e1a61061850d53aaf8e93db63dce0c936db1fda2658b70e4a1be709 \ - --hash=sha256:59edc41b24031bc25108e210c0def6f6c2191210492a972d585a06ff246bb79b \ - --hash=sha256:5a580c91d686376f0f7c295357595c5a026e6cbc3d77b7c36e290201e7c11ecb \ - --hash=sha256:5b94529f9b2591b7af5f3e0e730a4e0a41ea174af35a4fd067775f9bdfeee01a \ - --hash=sha256:5c7b3b3a728dc6faf3fc372ef24f21d1e3cee2ac3e9596691d746e5a536de920 \ - --hash=sha256:5c90ae8c8d32e472be041e76f9d2f2dbff4d0b0be8bd4041770eddb18cf49a4e \ - --hash=sha256:5e7139af55d1688f8b960ee9ad5adafc4ac17c1c473fe07133ac092310d76544 \ - --hash=sha256:5ff5cf3571589b6d13bfbfd6bcd7a3f659e42f96b5fd1c4830c4cf21d4f5ef45 \ - --hash=sha256:620ced262a86244e2be10a676b646f29c34537d0d9cc8eb26c08f53d98013390 \ - --hash=sha256:6512cb89e334e4700febbffaaa52761b65b4f5a3cf33f960213d5656cea36a77 \ - --hash=sha256:6c08e1312a9cf1074d17b17728d3dfce2a5125b2d791527f33ffbe805200a355 \ - --hash=sha256:6c3bd3cde54cafb87d74d8db50b909705c62b17c2099b8f2e25b461882e544ff \ - --hash=sha256:6ef7afcd2d281494c0a9101d5c571970708ad911d028137cd558f02b851c08b4 \ - --hash=sha256:7269d9e5f1084a653d575c7ec012ff57f0c042258bf5db0954bf551c158466e7 \ - --hash=sha256:72d40b33e834371fd330fb1472ca19d9b8327acb79a5821d4008391db8e29f20 \ - --hash=sha256:74d1b44c6cfc897df648cc9fdaa09bc3e7679926e6f96df05775d4fb3946571c \ - --hash=sha256:74db36e14a7d1ce0986fa104f7d5637aea5c82ca6326ed0ec5694280942d1162 \ - --hash=sha256:763773d53f07244148ccac5b084da5adb90bfaee39c197554f01b286cf869228 \ - --hash=sha256:76c6a5964640638cdeaa0c359382e5703e9293030fe730018ca06bc2010c4437 \ - --hash=sha256:76d9289ed3f7501012e05abb8358bbb129149dbd173f1f57a1bf1c22d19ab7cc \ - --hash=sha256:7931d8f1f67c4be9ba1dd9c451fb0eeca1a25b89e4d3f89e828fe12a519b782a \ - --hash=sha256:7b8b454bac16428b22560d0a1cf0a09875339cab69df61d7805bf48919415901 \ - --hash=sha256:7e5bab140c309cb3a6ce373a9e71eb7e4873c70c2dda01df6820474f9889d6d4 \ - --hash=sha256:83d78376d0d4fd884e2c114d0621624b73d2aba4e2788182d286309ebdeed770 \ - --hash=sha256:852542f9481f4a62dbb5dd99e8ab7aedfeb8fb6342349a181d4036877410f525 \ - --hash=sha256:85267bd1aa8880a9c88a8cb71e18d3d64d2751a790e6ca6c27b8ccc724bcd5ad \ - --hash=sha256:88a2df29d4724b9237fc0c6eaf2a1adae0cdc0b3e9f4d8e7dc54b16812d2d81a \ - --hash=sha256:88b9f257ca61b838b6f8094a62418421f87ac2a1069f7e896c36a7d86b5d4c29 \ - --hash=sha256:8ab3919a9997ab7ef2fbbed0cc99bb28d3c13e6d4b1ad36e97e482558a91be90 \ - --hash=sha256:92dea1ffe3714fa8eb6a314d2b3c773208d865a0e0d35e713ec54eea08a66250 \ - --hash=sha256:9407b6a5f0d675e8a827ad8742e1d6b49d9c1a1da5d952a67d50ef5f4170b18d \ - --hash=sha256:9408acf3270c4b6baad483865191e3e582b638b1654a007c62e3efe96f09a9a3 \ - --hash=sha256:955e8513d07a283056b1396e9a57ceddbd272d9252c14f154d450d227606eb54 \ - --hash=sha256:9db8ea4c388fdb0f780fe91346fd438657ea602d58348753d9fb265ce1bca67f \ - --hash=sha256:9eaa8b117dc8337728e834b9c6e2611f10c79e38f65157c4c38e9400286f5cb1 \ - --hash=sha256:a51a263952b1429e429ff236d2f5a21c5125437861baeed77f5e1cc2d2c7c6da \ - --hash=sha256:a6aa6315319a052b4ee378aa171959c898a6183f15c1e541821c5c59beaa0238 \ - --hash=sha256:aa12042de0171fad672b6c59df69106d20d5596e4f87b5e8f76df757a7c399aa \ - --hash=sha256:aaf7be1207676ac608a50cd08f102f6742dbfc70e8d60c4db1c6897f62f71523 \ - --hash=sha256:b0157420efcb803e71d1b28e2c287518b8808b7cf1ab8af36718fd0a2c453eb0 \ - --hash=sha256:b3f7e75f3015df442238cca659f8baa5f42ce2a8582727981cbfa15fee0ee205 \ - --hash=sha256:b9098e0049e88c6a24ff64545cdfc50807818ba6c1b739cae221bbbcbc58aad3 \ - --hash=sha256:ba55dce0a9b8ff59495ddd050a0225d58bd0983d09f87cfe2b6aec4f2c1234e4 \ - --hash=sha256:bb86433b1cfe686da83ce32a9d3a8dd308e85c76b60896d58f082136f10bffac \ - --hash=sha256:bbea0db94288e29afcc4c28afbf3a7ccaf2d7e027489c449cf7e8f83c6346eb9 \ - --hash=sha256:bbf1d63eef84b2e8c89011b7f2235b1e0bf7dacc11cac9431fc6468e99ac77fb \ - --hash=sha256:c7940c1dc63eb37a67721b10d703247552416f719c4188c54e04334321351ced \ - --hash=sha256:c9bf3325c47b11b2e51bca0824ea217c7cd84491d8ac4eefd1e409705ef092bd \ - --hash=sha256:cdc8a402aaee9a798b50d8b827d7ecf75edc5fb35ea0f91f213ff927c15f4ff0 \ - --hash=sha256:ceec1a6bc6cab1d6ff5d06592a91a692f90ec7505d6463a88a52cc0eb58545da \ - --hash=sha256:cfe6ab8da05c01ba6fbea630377b5da2cd9bcbc6338510116b01c1bc939a2c18 \ - --hash=sha256:d099e745a512f7e3bbe7249ca835f4d357c586d78d79ae8f1dcd4d8adeb9bda9 \ - --hash=sha256:d0ef46024e6a3d79c01ff13801cb19d0cad7fd859b15037aec74315540acc276 \ - --hash=sha256:d2e5a98f0ec99beb3c10e13b387f8db39106d53993f498b295f0c914328b1333 \ - --hash=sha256:da4cfb373035def307905d05041c1d06d8936452fe89d464743ae7fb8371078b \ - --hash=sha256:da802a19d6e15dffe4b0c24b38b3af68e6c1a68e6e1d8f30148c83864f3881db \ - --hash=sha256:dced8146011d2bc2e883f9bd68618b8247387f4bbec46d7392b3c3b032640126 \ - --hash=sha256:dfdd7c0b105af050eb3d64997809dc21da247cf44e63dc73ff0fd20b96be55a9 \ - --hash=sha256:e368f200bbc2e4f905b8e71eb38b3c04333bddaa6a2464a6355487b02bb7fb09 \ - --hash=sha256:e391b1f0a8a5a10ab3b9bb6afcfd74f2175f24f8975fb87ecae700d1503cdee0 \ - --hash=sha256:e57e563a57fb22a142da34f38acc2fc1a5c864bc29ca1517a88abc963e60d6ec \ - --hash=sha256:e5d706eba36b4c4d5bc6c6377bb6568098765e990cfc21ee16d13963fab7b3e7 \ - --hash=sha256:ec20916e7b4cbfb1f12380e46486ec4bcbaa91a9c448b97023fde0d5bbf9e4ff \ - --hash=sha256:f1d072c2eb0ad60d4c183f3fb44ac6f73fb7a8f16a2694a91f988275cbf352f9 \ - --hash=sha256:f846c260f483d1fd217fe5ed7c173fb109efa6b1fc8381c8b7552c5781756192 \ - --hash=sha256:f91de7223d4c7b793867797bacd1ee53bfe7359bd70d27b7b58a04efbb9436c8 \ - --hash=sha256:faae4860798c31530dd184046a900e652c95513796ef51a12bc086710c2eec4d \ - --hash=sha256:fc579bf0f502e54926519451b920e875f433aceb4624a3646b3252b5caa9e0b6 \ - --hash=sha256:fcc700eadbbccbf6bc1bcb9dbe0786b4b1cb91ca0dcda336eef5c2beed37b797 \ - --hash=sha256:fd32ea360bcbb92d28933fc05ed09bffcb1704ba3fc7942e81db0fd4f81a7892 \ - --hash=sha256:fdb7adb641a0d13bdcd4ef48e062363d8a9ad4a182ac7647ec88f695e719ae9f +jax-cuda12-pjrt==0.8.2 ; sys_platform == "linux" \ + --hash=sha256:717a1b196a642409ce195ddf031c20bbeadcc886f55e49a1d3f4927373aeedae \ + --hash=sha256:e3bab41ca7c48e4163db9e7efd271b3aa85f0fe45f5ed0708d6bbed93a59f977 + # via + # -r build/requirements.in + # jax-cuda12-plugin +jax-cuda12-plugin==0.8.2 ; sys_platform == "linux" and python_version < "3.14" \ + --hash=sha256:0b0a3304ce7e494acd8d9c593490c112a32cdb6010fe1afc584d9e41fd863167 \ + --hash=sha256:1b4828242d57f233b394d17ebaa599c503c1fb9b7c754012a06eb84dbc935fc8 \ + --hash=sha256:20165861b3d3e66ebb2c0f63a547d1d5ee17ea44ac3be7153c7908c9ca8c88f3 \ + --hash=sha256:377e4be17e22dde0343b3f3c05bf69235b3dbf11d766cca9c5a93da47971dcb7 \ + --hash=sha256:403d5e07731b5cdac3bd9fb3f448bd8480062cb2c0ab61ea2ad23fcd0a65479a \ + --hash=sha256:58c51473fc622e03138035985f741833564d70a4bd5a2178f61b62cdaa32ff94 \ + --hash=sha256:637387dc3408cd204562668502f9e95f76c6edde0a6d2e48f055162dc2aebf0d \ + --hash=sha256:70d33222484ad5c375b8f8357b7c23cacb844f6ecfc39567f8dd47fde6e87858 \ + --hash=sha256:82c6798be66bf8c773386918e4c8e5cd8119753f3bfb3ca4bbc46818283750c6 \ + --hash=sha256:a5898bac1d8ab6020b54546440256409f2c66bcbbb3a1099ca473c84843addad \ + --hash=sha256:d68a6d8b4a45ee561746bac7a6468da8203832626b0b39ad4ac43011f61f875d \ + --hash=sha256:dd4f7c34d4512ff5a36fd1b01584ef7781cad615e3f9e71880eae2f4998e5108 + # via -r build/requirements.in +jax-cuda13-pjrt==0.8.2 \ + --hash=sha256:370e9cb9a76af6a6b00e8f6c6ae8e5a0158886b65994a114cf3acf56ea1570f3 \ + --hash=sha256:c0ebc284d4bea9c0b278017b4bb62eb871755fe582cbf22ce1fb770d30fb9af6 + # via + # -r build/requirements.in + # jax-cuda13-plugin +jax-cuda13-plugin==0.8.2 \ + --hash=sha256:0ab3a13f798f973416754cb00b1696a8e6d2425c39496e2145293661bcd28446 \ + --hash=sha256:2c7532f2a6c9dcf50ba4b8ebbe4c78988b8c260695e50f9fdbd68f9417a8bb51 \ + --hash=sha256:358f4b4d0d2b92f62fe7f07d355a9c57c5684da3d4e6f483afec38b23e66f02e \ + --hash=sha256:6703552ee10b6b41b43fe4fc62b4672bec13d271eb709a29bf14b0fa1ec805dd \ + --hash=sha256:8fe2c5874cbae18a01a0fa722dfbd547b5504e7bf42aaaae6ad3e1b676e8daa2 \ + --hash=sha256:950f2c1f2e0f5e13893c4c27a4df540a96a26ebd15d7954fa85ffb8cc8b1d9a1 \ + --hash=sha256:9835076d0a917345700370cfaee5cfd7876c0b0982e99206a8b4848af0ea6e5e \ + --hash=sha256:bb059d99eb960718358f82f3eaa0829f96fc7d7e978ded821f2c40256a8aee4b \ + --hash=sha256:d5f9470db46aee8a64e9ab7817927335b77c6d3698d5113e0da2dd962c1468c4 \ + --hash=sha256:f1643b216ccb41a2d7c0e240359c729c96c5eca76e4525ffcd6313049e348b9a \ + --hash=sha256:f35f687d9b839365a49c2b0052eaeb6b9f0a9879813b4fa020f8d9caa3ddde46 \ + --hash=sha256:fa6a6faf3130d6ef05f9ee6cbc27e0a2e4db0f23720136403991eccb3a0144af + # via -r build/requirements.in +jaxlib==0.8.2 \ + --hash=sha256:023de6f3f56da2af7037970996500586331fdb50b530ecbb54b9666da633bd00 \ + --hash=sha256:05b958f497e49824c432e734bb059723b7dfe69e2ad696a9f9c8ad82fff7c3f8 \ + --hash=sha256:1bfbcf6c3de221784fa4cdb6765a09d71cb4298b15626b3d0409b3dfcd8a8667 \ + --hash=sha256:28eec1a4e0639a0d8702cea3cb70dd3663053dbfa344452994ea48dc6ceadaa5 \ + --hash=sha256:2b9789bd08f8b0cc5a5c12ae896fe432d5942e32e417091b8b5a96a9a6fd5cf1 \ + --hash=sha256:3b16e50c5b730c9dd0a49e55f1acfaa722b00b1af0522a591558dcc0464252f2 \ + --hash=sha256:490bf0cb029c73c65c9431124b86cdc95082dbc1fb76fc549d24d75da33e5454 \ + --hash=sha256:4d006db96be020c8165212a1216372f8acac4ff4f8fb067743d694ef2b301ace \ + --hash=sha256:68108dff0de74adc468016be9a19f80efe48c660c0d5a122287094b44b092afc \ + --hash=sha256:7c304f3a016965b9d1f5239a8a0399a73925f5604fe914c5ca66ecf734bf6422 \ + --hash=sha256:7da8127557c786264049ae55460d1b8d04cc3cdf0403a087f2fc1e6d313ec722 \ + --hash=sha256:964626f581beab31ee6826b228fcc2ec5181b05cecf94a528dff97921c145dbc \ + --hash=sha256:a397ea7dcb37d689ce79173eeb99b2f1347637a36be9a27f20ae6848bfc58bfc \ + --hash=sha256:aa8701b6356f098e8452c3cec762fb5f706fcb8f67ffd65964f63982479aa23b \ + --hash=sha256:bb89be452b1b808d3f88fc01c415b364a260be4cc7ac120c038009f6150a32dc \ + --hash=sha256:beffb004e7eeb5c9afb24439e2b2cf45a4ee3e3e8adf45e355edf2af62acf8b8 \ + --hash=sha256:ccf77da917a20935247c990691decfcbdd06c25ef0ac94d914a04aadb22f714c \ + --hash=sha256:dffc22b5b732b9556d92c918b251c61bcc046617c4dbb51e1f7a656587fddffb \ + --hash=sha256:e6a97dfb0232eed9a2bb6e3828e4f682dbac1a7fea840bfda574cae2dbf5faf9 \ + --hash=sha256:f205e91c3a152a2a76c0bc59a6a2de03e87ec261b91e8812922777185e7b08f5 \ + --hash=sha256:f28edac8c226fc07fa3e8af6f9defede8ac2c307429e3291edce8739d39becc9 \ + --hash=sha256:f472cc72e3058e50b5f0230b236d5a1183bf6c3d5423d2a52eff07bcf34908de + # via -r build/requirements.in +keras==3.12.0 \ + --hash=sha256:02b69e007d5df8042286c3bcc2a888539e3e487590ffb08f6be1b4354df50aa8 \ + --hash=sha256:536e3f8385a05ae04e82e08715a1a59988578087e187b04cb0a6fad11743f07f + # via tensorflow +kiwisolver==1.4.9 \ + --hash=sha256:0749fd8f4218ad2e851e11cc4dc05c7cbc0cbc4267bdfdb31782e65aace4ee9c \ + --hash=sha256:0763515d4df10edf6d06a3c19734e2566368980d21ebec439f33f9eb936c07b7 \ + --hash=sha256:0856e241c2d3df4efef7c04a1e46b1936b6120c9bcf36dd216e3acd84bc4fb21 \ + --hash=sha256:0a590506f303f512dff6b7f75fd2fd18e16943efee932008fe7140e5fa91d80e \ + --hash=sha256:0ab74e19f6a2b027ea4f845a78827969af45ce790e6cb3e1ebab71bdf9f215ff \ + --hash=sha256:0ae37737256ba2de764ddc12aed4956460277f00c4996d51a197e72f62f5eec7 \ + --hash=sha256:0e4e2bf29574a6a7b7f6cb5fa69293b9f96c928949ac4a53ba3f525dffb87f9c \ + --hash=sha256:15163165efc2f627eb9687ea5f3a28137217d217ac4024893d753f46bce9de26 \ + --hash=sha256:17680d737d5335b552994a2008fab4c851bcd7de33094a82067ef3a576ff02fa \ + --hash=sha256:1a12cf6398e8a0a001a059747a1cbf24705e18fe413bc22de7b3d15c67cffe3f \ + --hash=sha256:1b11d6a633e4ed84fc0ddafd4ebfd8ea49b3f25082c04ad12b8315c11d504dc1 \ + --hash=sha256:1fa333e8b2ce4d9660f2cda9c0e1b6bafcfb2457a9d259faa82289e73ec24891 \ + --hash=sha256:2327a4a30d3ee07d2fbe2e7933e8a37c591663b96ce42a00bc67461a87d7df77 \ + --hash=sha256:2405a7d98604b87f3fc28b1716783534b1b4b8510d8142adca34ee0bc3c87543 \ + --hash=sha256:2489e4e5d7ef9a1c300a5e0196e43d9c739f066ef23270607d45aba368b91f2d \ + --hash=sha256:24c175051354f4a28c5d6a31c93906dc653e2bf234e8a4bbfb964892078898ce \ + --hash=sha256:2635d352d67458b66fd0667c14cb1d4145e9560d503219034a18a87e971ce4f3 \ + --hash=sha256:2c1a4f57df73965f3f14df20b80ee29e6a7930a57d2d9e8491a25f676e197c60 \ + --hash=sha256:2c93f00dcba2eea70af2be5f11a830a742fe6b579a1d4e00f47760ef13be247a \ + --hash=sha256:39a219e1c81ae3b103643d2aedb90f1ef22650deb266ff12a19e7773f3e5f089 \ + --hash=sha256:3b3115b2581ea35bb6d1f24a4c90af37e5d9b49dcff267eeed14c3893c5b86ab \ + --hash=sha256:40092754720b174e6ccf9e845d0d8c7d8e12c3d71e7fc35f55f3813e96376f78 \ + --hash=sha256:412f287c55a6f54b0650bd9b6dce5aceddb95864a1a90c87af16979d37c89771 \ + --hash=sha256:464415881e4801295659462c49461a24fb107c140de781d55518c4b80cb6790f \ + --hash=sha256:497d05f29a1300d14e02e6441cf0f5ee81c1ff5a304b0d9fb77423974684e08b \ + --hash=sha256:4a2899935e724dd1074cb568ce7ac0dce28b2cd6ab539c8e001a8578eb106d14 \ + --hash=sha256:4a48a2ce79d65d363597ef7b567ce3d14d68783d2b2263d98db3d9477805ba32 \ + --hash=sha256:4d1d9e582ad4d63062d34077a9a1e9f3c34088a2ec5135b1f7190c07cf366527 \ + --hash=sha256:52a15b0f35dad39862d376df10c5230155243a2c1a436e39eb55623ccbd68185 \ + --hash=sha256:540c7c72324d864406a009d72f5d6856f49693db95d1fbb46cf86febef873634 \ + --hash=sha256:5656aa670507437af0207645273ccdfee4f14bacd7f7c67a4306d0dcaeaf6eed \ + --hash=sha256:5a0f2724dfd4e3b3ac5a82436a8e6fd16baa7d507117e4279b660fe8ca38a3a1 \ + --hash=sha256:60c439763a969a6af93b4881db0eed8fadf93ee98e18cbc35bc8da868d0c4f0c \ + --hash=sha256:61874cdb0a36016354853593cffc38e56fc9ca5aa97d2c05d3dcf6922cd55a11 \ + --hash=sha256:67bb8b474b4181770f926f7b7d2f8c0248cbcb78b660fdd41a47054b28d2a752 \ + --hash=sha256:720e05574713db64c356e86732c0f3c5252818d05f9df320f0ad8380641acea5 \ + --hash=sha256:72d0eb9fba308b8311685c2268cf7d0a0639a6cd027d8128659f72bdd8a024b4 \ + --hash=sha256:767c23ad1c58c9e827b649a9ab7809fd5fd9db266a9cf02b0e926ddc2c680d58 \ + --hash=sha256:77937e5e2a38a7b48eef0585114fe7930346993a88060d0bf886086d2aa49ef5 \ + --hash=sha256:7a08b491ec91b1d5053ac177afe5290adacf1f0f6307d771ccac5de30592d198 \ + --hash=sha256:7b4da0d01ac866a57dd61ac258c5607b4cd677f63abaec7b148354d2b2cdd536 \ + --hash=sha256:7cf974dd4e35fa315563ac99d6287a1024e4dc2077b8a7d7cd3d2fb65d283134 \ + --hash=sha256:84fd60810829c27ae375114cd379da1fa65e6918e1da405f356a775d49a62bcf \ + --hash=sha256:858e4c22fb075920b96a291928cb7dea5644e94c0ee4fcd5af7e865655e4ccf2 \ + --hash=sha256:85b5352f94e490c028926ea567fc569c52ec79ce131dadb968d3853e809518c2 \ + --hash=sha256:85bd218b5ecfbee8c8a82e121802dcb519a86044c9c3b2e4aef02fa05c6da370 \ + --hash=sha256:8a1f570ce4d62d718dce3f179ee78dac3b545ac16c0c04bb363b7607a949c0d1 \ + --hash=sha256:8fdca1def57a2e88ef339de1737a1449d6dbf5fab184c54a1fca01d541317154 \ + --hash=sha256:90f47e70293fc3688b71271100a1a5453aa9944a81d27ff779c108372cf5567b \ + --hash=sha256:92a2f997387a1b79a75e7803aa7ded2cfbe2823852ccf1ba3bcf613b62ae3197 \ + --hash=sha256:9928fe1eb816d11ae170885a74d074f57af3a0d65777ca47e9aeb854a1fba386 \ + --hash=sha256:9af39d6551f97d31a4deebeac6f45b156f9755ddc59c07b402c148f5dbb6482a \ + --hash=sha256:9cf554f21be770f5111a1690d42313e140355e687e05cf82cb23d0a721a64a48 \ + --hash=sha256:a30fd6fdef1430fd9e1ba7b3398b5ee4e2887783917a687d86ba69985fb08748 \ + --hash=sha256:a31d512c812daea6d8b3be3b2bfcbeb091dbb09177706569bcfc6240dcf8b41c \ + --hash=sha256:a5d0432ccf1c7ab14f9949eec60c5d1f924f17c037e9f8b33352fa05799359b8 \ + --hash=sha256:a60ea74330b91bd22a29638940d115df9dc00af5035a9a2a6ad9399ffb4ceca5 \ + --hash=sha256:ac5a486ac389dddcc5bef4f365b6ae3ffff2c433324fb38dd35e3fab7c957999 \ + --hash=sha256:aedff62918805fb62d43a4aa2ecd4482c380dc76cd31bd7c8878588a61bd0369 \ + --hash=sha256:b34e51affded8faee0dfdb705416153819d8ea9250bbbf7ea1b249bdeb5f1122 \ + --hash=sha256:b4b4d74bda2b8ebf4da5bd42af11d02d04428b2c32846e4c2c93219df8a7987b \ + --hash=sha256:b67e6efbf68e077dd71d1a6b37e43e1a99d0bff1a3d51867d45ee8908b931098 \ + --hash=sha256:b78efa4c6e804ecdf727e580dbb9cba85624d2e1c6b5cb059c66290063bd99a9 \ + --hash=sha256:bb4ae2b57fc1d8cbd1cf7b1d9913803681ffa903e7488012be5b76dedf49297f \ + --hash=sha256:bdd1a81a1860476eb41ac4bc1e07b3f07259e6d55bbf739b79c8aaedcf512799 \ + --hash=sha256:bdee92c56a71d2b24c33a7d4c2856bd6419d017e08caa7802d2963870e315028 \ + --hash=sha256:be6a04e6c79819c9a8c2373317d19a96048e5a3f90bec587787e86a1153883c2 \ + --hash=sha256:bfc08add558155345129c7803b3671cf195e6a56e7a12f3dde7c57d9b417f525 \ + --hash=sha256:c3b22c26c6fd6811b0ae8363b95ca8ce4ea3c202d3d0975b2914310ceb1bcc4d \ + --hash=sha256:c9e7cdf45d594ee04d5be1b24dd9d49f3d1590959b2271fb30b5ca2b262c00fb \ + --hash=sha256:cb27e7b78d716c591e88e0a09a2139c6577865d7f2e152488c2cc6257f460872 \ + --hash=sha256:cc9617b46837c6468197b5945e196ee9ca43057bb7d9d1ae688101e4e1dddf64 \ + --hash=sha256:ccd09f20ccdbbd341b21a67ab50a119b64a403b09288c27481575105283c1586 \ + --hash=sha256:ce6a3a4e106cf35c2d9c4fa17c05ce0b180db622736845d4315519397a77beaf \ + --hash=sha256:d0005b053977e7b43388ddec89fa567f43d4f6d5c2c0affe57de5ebf290dc552 \ + --hash=sha256:d4188e73af84ca82468f09cadc5ac4db578109e52acb4518d8154698d3a87ca2 \ + --hash=sha256:d4efec7bcf21671db6a3294ff301d2fc861c31faa3c8740d1a94689234d1b415 \ + --hash=sha256:d75aa530ccfaa593da12834b86a0724f58bff12706659baa9227c2ccaa06264c \ + --hash=sha256:d84cd4061ae292d8ac367b2c3fa3aad11cb8625a95d135fe93f286f914f3f5a6 \ + --hash=sha256:d8aacd3d4b33b772542b2e01beb50187536967b514b00003bdda7589722d2a64 \ + --hash=sha256:d8fc5c867c22b828001b6a38d2eaeb88160bf5783c6cb4a5e440efc981ce286d \ + --hash=sha256:d976bbb382b202f71c67f77b0ac11244021cfa3f7dfd9e562eefcea2df711548 \ + --hash=sha256:dba5ee5d3981160c28d5490f0d1b7ed730c22470ff7f6cc26cfcfaacb9896a07 \ + --hash=sha256:dc1ae486f9abcef254b5618dfb4113dd49f94c68e3e027d03cf0143f3f772b61 \ + --hash=sha256:dd0a578400839256df88c16abddf9ba14813ec5f21362e1fe65022e00c883d4d \ + --hash=sha256:deed0c7258ceb4c44ad5ec7d9918f9f14fd05b2be86378d86cf50e63d1e7b771 \ + --hash=sha256:e09c2279a4d01f099f52d5c4b3d9e208e91edcbd1a175c9662a8b16e000fece9 \ + --hash=sha256:e2ea9f7ab7fbf18fffb1b5434ce7c69a07582f7acc7717720f1d69f3e806f90c \ + --hash=sha256:e6b93f13371d341afee3be9f7c5964e3fe61d5fa30f6a30eb49856935dfe4fc3 \ + --hash=sha256:eb14a5da6dc7642b0f3a18f13654847cd8b7a2550e2645a5bda677862b03ba16 \ + --hash=sha256:ed0fecd28cc62c54b262e3736f8bb2512d8dcfdc2bcf08be5f47f96bf405b145 \ + --hash=sha256:ede8c6d533bc6601a47ad4046080d36b8fc99f81e6f1c17b0ac3c2dc91ac7611 \ + --hash=sha256:efb3a45b35622bb6c16dbfab491a8f5a391fe0e9d45ef32f4df85658232ca0e2 \ + --hash=sha256:f117e1a089d9411663a3207ba874f31be9ac8eaa5b533787024dc07aeb74f464 \ + --hash=sha256:f2ba92255faa7309d06fe44c3a4a97efe1c8d640c2a79a5ef728b685762a6fd2 \ + --hash=sha256:f6008a4919fdbc0b0097089f67a1eb55d950ed7e90ce2cc3e640abadd2757a04 \ + --hash=sha256:f68208a520c3d86ea51acf688a3e3002615a7f0238002cccc17affecc86a8a54 \ + --hash=sha256:f68e4f3eeca8fb22cc3d731f9715a13b652795ef657a13df1ad0c7dc0e9731df \ + --hash=sha256:fb3b8132019ea572f4611d770991000d7f58127560c4889729248eb5852a102f \ + --hash=sha256:fb940820c63a9590d31d88b815e7a3aa5915cad3ce735ab45f0c730b39547de1 \ + --hash=sha256:fc1795ac5cd0510207482c3d1d3ed781143383b8cfd36f5c645f3897ce066220 # via matplotlib -markdown-it-py==3.0.0 \ - --hash=sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 \ - --hash=sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb +libclang==18.1.1 \ + --hash=sha256:0b2e143f0fac830156feb56f9231ff8338c20aecfe72b4ffe96f19e5a1dbb69a \ + --hash=sha256:3f0e1f49f04d3cd198985fea0511576b0aee16f9ff0e0f0cad7f9c57ec3c20e8 \ + --hash=sha256:4dd2d3b82fab35e2bf9ca717d7b63ac990a3519c7e312f19fa8e86dcc712f7fb \ + --hash=sha256:54dda940a4a0491a9d1532bf071ea3ef26e6dbaf03b5000ed94dd7174e8f9592 \ + --hash=sha256:69f8eb8f65c279e765ffd28aaa7e9e364c776c17618af8bff22a8df58677ff4f \ + --hash=sha256:6f14c3f194704e5d09769108f03185fce7acaf1d1ae4bbb2f30a72c2400cb7c5 \ + --hash=sha256:83ce5045d101b669ac38e6da8e58765f12da2d3aafb3b9b98d88b286a60964d8 \ + --hash=sha256:a1214966d08d73d971287fc3ead8dfaf82eb07fb197680d8b3859dbbbbf78250 \ + --hash=sha256:c533091d8a3bbf7460a00cb6c1a71da93bffe148f172c7d03b1c31fbf8aa2a0b \ + --hash=sha256:cf4a99b05376513717ab5d82a0db832c56ccea4fd61a69dbb7bccf2dfb207dbe + # via tensorflow +libtpu==0.0.32 ; sys_platform == "linux" and platform_machine == "x86_64" \ + --hash=sha256:37f6aefe6d69d24e5e268e74c1e90cd0ce0fa6a796720709445d8938605912ee \ + --hash=sha256:4a4ae6db0a90f0e33902cd345923a5d77dfc5bbab9a3e524c79bf876858104ee \ + --hash=sha256:6453995774b902e43d1caf0d82298a2e90f2f731ddd74d6bef9510d91732775f \ + --hash=sha256:7ffd9cdb89da22d32bb60ce70b88e9e76861c5d704ebf673f32987e19e9913f3 \ + --hash=sha256:98d9713a39d9581c9b10a7bdeaa575cc874a4a7f543cfb40d3a723ee27b428b5 \ + --hash=sha256:d2a0e7abe4a029b79049bc95da20e4c2a64f8ef09823f75286cecfbb4b0b6da1 + # via -r build/requirements.in +markdown==3.10 \ + --hash=sha256:37062d4f2aa4b2b6b32aefb80faa300f82cc790cb949a35b8caede34f2b68c0e \ + --hash=sha256:b5b99d6951e2e4948d939255596523444c0e677c669700b1d17aa4a8a464cb7c + # via tensorboard +markdown-it-py==4.0.0 \ + --hash=sha256:87327c59b172c5011896038353a81343b6754500a08cd7a4973bb48c6d578147 \ + --hash=sha256:cb0a2b4aa34f932c007117b194e945bd74e0ec24133ceb5bac59009cda1cb9f3 # via rich -matplotlib==3.9.0 ; python_version >= "3.11" \ - --hash=sha256:063af8587fceeac13b0936c42a2b6c732c2ab1c98d38abc3337e430e1ff75e38 \ - --hash=sha256:06a478f0d67636554fa78558cfbcd7b9dba85b51f5c3b5a0c9be49010cf5f321 \ - --hash=sha256:0a490715b3b9984fa609116481b22178348c1a220a4499cda79132000a79b4db \ - --hash=sha256:0fc51eaa5262553868461c083d9adadb11a6017315f3a757fc45ec6ec5f02888 \ - --hash=sha256:13beb4840317d45ffd4183a778685e215939be7b08616f431c7795276e067463 \ - --hash=sha256:290d304e59be2b33ef5c2d768d0237f5bd132986bdcc66f80bc9bcc300066a03 \ - --hash=sha256:2bcee1dffaf60fe7656183ac2190bd630842ff87b3153afb3e384d966b57fe56 \ - --hash=sha256:2e7f03e5cbbfacdd48c8ea394d365d91ee8f3cae7e6ec611409927b5ed997ee4 \ - --hash=sha256:3f988bafb0fa39d1074ddd5bacd958c853e11def40800c5824556eb630f94d3b \ - --hash=sha256:52146fc3bd7813cc784562cb93a15788be0b2875c4655e2cc6ea646bfa30344b \ - --hash=sha256:550cdda3adbd596078cca7d13ed50b77879104e2e46392dcd7c75259d8f00e85 \ - --hash=sha256:616fabf4981a3b3c5a15cd95eba359c8489c4e20e03717aea42866d8d0465956 \ - --hash=sha256:76cce0f31b351e3551d1f3779420cf8f6ec0d4a8cf9c0237a3b549fd28eb4abb \ - --hash=sha256:7ff2e239c26be4f24bfa45860c20ffccd118d270c5b5d081fa4ea409b5469fcd \ - --hash=sha256:8146ce83cbc5dc71c223a74a1996d446cd35cfb6a04b683e1446b7e6c73603b7 \ - --hash=sha256:81c40af649d19c85f8073e25e5806926986806fa6d54be506fbf02aef47d5a89 \ - --hash=sha256:9a2fa6d899e17ddca6d6526cf6e7ba677738bf2a6a9590d702c277204a7c6152 \ - --hash=sha256:a5be985db2596d761cdf0c2eaf52396f26e6a64ab46bd8cd810c48972349d1be \ - --hash=sha256:af4001b7cae70f7eaacfb063db605280058246de590fa7874f00f62259f2df7e \ - --hash=sha256:bd4f2831168afac55b881db82a7730992aa41c4f007f1913465fb182d6fb20c0 \ - --hash=sha256:bdd1ecbe268eb3e7653e04f451635f0fb0f77f07fd070242b44c076c9106da84 \ - --hash=sha256:c53aeb514ccbbcbab55a27f912d79ea30ab21ee0531ee2c09f13800efb272674 \ - --hash=sha256:c79f3a585f1368da6049318bdf1f85568d8d04b2e89fc24b7e02cc9b62017382 \ - --hash=sha256:cd53c79fd02f1c1808d2cfc87dd3cf4dbc63c5244a58ee7944497107469c8d8a \ - --hash=sha256:d38e85a1a6d732f645f1403ce5e6727fd9418cd4574521d5803d3d94911038e5 \ - --hash=sha256:d91a4ffc587bacf5c4ce4ecfe4bcd23a4b675e76315f2866e588686cc97fccdf \ - --hash=sha256:e6d29ea6c19e34b30fb7d88b7081f869a03014f66fe06d62cc77d5a6ea88ed7a \ - --hash=sha256:eaf3978060a106fab40c328778b148f590e27f6fa3cd15a19d6892575bce387d \ - --hash=sha256:fe428e191ea016bb278758c8ee82a8129c51d81d8c4bc0846c09e7e8e9057241 +markupsafe==3.0.3 \ + --hash=sha256:0303439a41979d9e74d18ff5e2dd8c43ed6c6001fd40e5bf2e43f7bd9bbc523f \ + --hash=sha256:068f375c472b3e7acbe2d5318dea141359e6900156b5b2ba06a30b169086b91a \ + --hash=sha256:0bf2a864d67e76e5c9a34dc26ec616a66b9888e25e7b9460e1c76d3293bd9dbf \ + --hash=sha256:0db14f5dafddbb6d9208827849fad01f1a2609380add406671a26386cdf15a19 \ + --hash=sha256:0eb9ff8191e8498cca014656ae6b8d61f39da5f95b488805da4bb029cccbfbaf \ + --hash=sha256:0f4b68347f8c5eab4a13419215bdfd7f8c9b19f2b25520968adfad23eb0ce60c \ + --hash=sha256:1085e7fbddd3be5f89cc898938f42c0b3c711fdcb37d75221de2666af647c175 \ + --hash=sha256:116bb52f642a37c115f517494ea5feb03889e04df47eeff5b130b1808ce7c219 \ + --hash=sha256:12c63dfb4a98206f045aa9563db46507995f7ef6d83b2f68eda65c307c6829eb \ + --hash=sha256:133a43e73a802c5562be9bbcd03d090aa5a1fe899db609c29e8c8d815c5f6de6 \ + --hash=sha256:1353ef0c1b138e1907ae78e2f6c63ff67501122006b0f9abad68fda5f4ffc6ab \ + --hash=sha256:15d939a21d546304880945ca1ecb8a039db6b4dc49b2c5a400387cdae6a62e26 \ + --hash=sha256:177b5253b2834fe3678cb4a5f0059808258584c559193998be2601324fdeafb1 \ + --hash=sha256:1872df69a4de6aead3491198eaf13810b565bdbeec3ae2dc8780f14458ec73ce \ + --hash=sha256:1b4b79e8ebf6b55351f0d91fe80f893b4743f104bff22e90697db1590e47a218 \ + --hash=sha256:1b52b4fb9df4eb9ae465f8d0c228a00624de2334f216f178a995ccdcf82c4634 \ + --hash=sha256:1ba88449deb3de88bd40044603fafffb7bc2b055d626a330323a9ed736661695 \ + --hash=sha256:1cc7ea17a6824959616c525620e387f6dd30fec8cb44f649e31712db02123dad \ + --hash=sha256:218551f6df4868a8d527e3062d0fb968682fe92054e89978594c28e642c43a73 \ + --hash=sha256:26a5784ded40c9e318cfc2bdb30fe164bdb8665ded9cd64d500a34fb42067b1c \ + --hash=sha256:2713baf880df847f2bece4230d4d094280f4e67b1e813eec43b4c0e144a34ffe \ + --hash=sha256:2a15a08b17dd94c53a1da0438822d70ebcd13f8c3a95abe3a9ef9f11a94830aa \ + --hash=sha256:2f981d352f04553a7171b8e44369f2af4055f888dfb147d55e42d29e29e74559 \ + --hash=sha256:32001d6a8fc98c8cb5c947787c5d08b0a50663d139f1305bac5885d98d9b40fa \ + --hash=sha256:3524b778fe5cfb3452a09d31e7b5adefeea8c5be1d43c4f810ba09f2ceb29d37 \ + --hash=sha256:3537e01efc9d4dccdf77221fb1cb3b8e1a38d5428920e0657ce299b20324d758 \ + --hash=sha256:35add3b638a5d900e807944a078b51922212fb3dedb01633a8defc4b01a3c85f \ + --hash=sha256:38664109c14ffc9e7437e86b4dceb442b0096dfe3541d7864d9cbe1da4cf36c8 \ + --hash=sha256:3a7e8ae81ae39e62a41ec302f972ba6ae23a5c5396c8e60113e9066ef893da0d \ + --hash=sha256:3b562dd9e9ea93f13d53989d23a7e775fdfd1066c33494ff43f5418bc8c58a5c \ + --hash=sha256:457a69a9577064c05a97c41f4e65148652db078a3a509039e64d3467b9e7ef97 \ + --hash=sha256:4bd4cd07944443f5a265608cc6aab442e4f74dff8088b0dfc8238647b8f6ae9a \ + --hash=sha256:4e885a3d1efa2eadc93c894a21770e4bc67899e3543680313b09f139e149ab19 \ + --hash=sha256:4faffd047e07c38848ce017e8725090413cd80cbc23d86e55c587bf979e579c9 \ + --hash=sha256:509fa21c6deb7a7a273d629cf5ec029bc209d1a51178615ddf718f5918992ab9 \ + --hash=sha256:5678211cb9333a6468fb8d8be0305520aa073f50d17f089b5b4b477ea6e67fdc \ + --hash=sha256:591ae9f2a647529ca990bc681daebdd52c8791ff06c2bfa05b65163e28102ef2 \ + --hash=sha256:5a7d5dc5140555cf21a6fefbdbf8723f06fcd2f63ef108f2854de715e4422cb4 \ + --hash=sha256:69c0b73548bc525c8cb9a251cddf1931d1db4d2258e9599c28c07ef3580ef354 \ + --hash=sha256:6b5420a1d9450023228968e7e6a9ce57f65d148ab56d2313fcd589eee96a7a50 \ + --hash=sha256:722695808f4b6457b320fdc131280796bdceb04ab50fe1795cd540799ebe1698 \ + --hash=sha256:729586769a26dbceff69f7a7dbbf59ab6572b99d94576a5592625d5b411576b9 \ + --hash=sha256:77f0643abe7495da77fb436f50f8dab76dbc6e5fd25d39589a0f1fe6548bfa2b \ + --hash=sha256:795e7751525cae078558e679d646ae45574b47ed6e7771863fcc079a6171a0fc \ + --hash=sha256:7be7b61bb172e1ed687f1754f8e7484f1c8019780f6f6b0786e76bb01c2ae115 \ + --hash=sha256:7c3fb7d25180895632e5d3148dbdc29ea38ccb7fd210aa27acbd1201a1902c6e \ + --hash=sha256:7e68f88e5b8799aa49c85cd116c932a1ac15caaa3f5db09087854d218359e485 \ + --hash=sha256:83891d0e9fb81a825d9a6d61e3f07550ca70a076484292a70fde82c4b807286f \ + --hash=sha256:8485f406a96febb5140bfeca44a73e3ce5116b2501ac54fe953e488fb1d03b12 \ + --hash=sha256:8709b08f4a89aa7586de0aadc8da56180242ee0ada3999749b183aa23df95025 \ + --hash=sha256:8f71bc33915be5186016f675cd83a1e08523649b0e33efdb898db577ef5bb009 \ + --hash=sha256:915c04ba3851909ce68ccc2b8e2cd691618c4dc4c4232fb7982bca3f41fd8c3d \ + --hash=sha256:949b8d66bc381ee8b007cd945914c721d9aba8e27f71959d750a46f7c282b20b \ + --hash=sha256:94c6f0bb423f739146aec64595853541634bde58b2135f27f61c1ffd1cd4d16a \ + --hash=sha256:9a1abfdc021a164803f4d485104931fb8f8c1efd55bc6b748d2f5774e78b62c5 \ + --hash=sha256:9b79b7a16f7fedff2495d684f2b59b0457c3b493778c9eed31111be64d58279f \ + --hash=sha256:a320721ab5a1aba0a233739394eb907f8c8da5c98c9181d1161e77a0c8e36f2d \ + --hash=sha256:a4afe79fb3de0b7097d81da19090f4df4f8d3a2b3adaa8764138aac2e44f3af1 \ + --hash=sha256:ad2cf8aa28b8c020ab2fc8287b0f823d0a7d8630784c31e9ee5edea20f406287 \ + --hash=sha256:b8512a91625c9b3da6f127803b166b629725e68af71f8184ae7e7d54686a56d6 \ + --hash=sha256:bc51efed119bc9cfdf792cdeaa4d67e8f6fcccab66ed4bfdd6bde3e59bfcbb2f \ + --hash=sha256:bdc919ead48f234740ad807933cdf545180bfbe9342c2bb451556db2ed958581 \ + --hash=sha256:bdd37121970bfd8be76c5fb069c7751683bdf373db1ed6c010162b2a130248ed \ + --hash=sha256:be8813b57049a7dc738189df53d69395eba14fb99345e0a5994914a3864c8a4b \ + --hash=sha256:c0c0b3ade1c0b13b936d7970b1d37a57acde9199dc2aecc4c336773e1d86049c \ + --hash=sha256:c47a551199eb8eb2121d4f0f15ae0f923d31350ab9280078d1e5f12b249e0026 \ + --hash=sha256:c4ffb7ebf07cfe8931028e3e4c85f0357459a3f9f9490886198848f4fa002ec8 \ + --hash=sha256:ccfcd093f13f0f0b7fdd0f198b90053bf7b2f02a3927a30e63f3ccc9df56b676 \ + --hash=sha256:d2ee202e79d8ed691ceebae8e0486bd9a2cd4794cec4824e1c99b6f5009502f6 \ + --hash=sha256:d53197da72cc091b024dd97249dfc7794d6a56530370992a5e1a08983ad9230e \ + --hash=sha256:d6dd0be5b5b189d31db7cda48b91d7e0a9795f31430b7f271219ab30f1d3ac9d \ + --hash=sha256:d88b440e37a16e651bda4c7c2b930eb586fd15ca7406cb39e211fcff3bf3017d \ + --hash=sha256:de8a88e63464af587c950061a5e6a67d3632e36df62b986892331d4620a35c01 \ + --hash=sha256:df2449253ef108a379b8b5d6b43f4b1a8e81a061d6537becd5582fba5f9196d7 \ + --hash=sha256:e1c1493fb6e50ab01d20a22826e57520f1284df32f2d8601fdd90b6304601419 \ + --hash=sha256:e1cf1972137e83c5d4c136c43ced9ac51d0e124706ee1c8aa8532c1287fa8795 \ + --hash=sha256:e2103a929dfa2fcaf9bb4e7c091983a49c9ac3b19c9061b6d5427dd7d14d81a1 \ + --hash=sha256:e56b7d45a839a697b5eb268c82a71bd8c7f6c94d6fd50c3d577fa39a9f1409f5 \ + --hash=sha256:e8afc3f2ccfa24215f8cb28dcf43f0113ac3c37c2f0f0806d8c70e4228c5cf4d \ + --hash=sha256:e8fc20152abba6b83724d7ff268c249fa196d8259ff481f3b1476383f8f24e42 \ + --hash=sha256:eaa9599de571d72e2daf60164784109f19978b327a3910d3e9de8c97b5b70cfe \ + --hash=sha256:ec15a59cf5af7be74194f7ab02d0f59a62bdcf1a537677ce67a2537c9b87fcda \ + --hash=sha256:f190daf01f13c72eac4efd5c430a8de82489d9cff23c364c3ea822545032993e \ + --hash=sha256:f34c41761022dd093b4b6896d4810782ffbabe30f2d443ff5f083e0cbbb8c737 \ + --hash=sha256:f3e98bb3798ead92273dc0e5fd0f31ade220f59a266ffd8a4f6065e0a3ce0523 \ + --hash=sha256:f42d0984e947b8adf7dd6dde396e720934d12c506ce84eea8476409563607591 \ + --hash=sha256:f71a396b3bf33ecaa1626c255855702aca4d3d9fea5e051b41ac59a9c1c41edc \ + --hash=sha256:f9e130248f4462aaa8e2552d547f36ddadbeaa573879158d721bbd33dfe4743a \ + --hash=sha256:fed51ac40f757d41b7c48425901843666a6677e3e8eb0abcff09e4ba6e664f50 + # via werkzeug +matplotlib==3.10.8 \ + --hash=sha256:00270d217d6b20d14b584c521f810d60c5c78406dc289859776550df837dcda7 \ + --hash=sha256:0a33deb84c15ede243aead39f77e990469fff93ad1521163305095b77b72ce4a \ + --hash=sha256:113bb52413ea508ce954a02c10ffd0d565f9c3bc7f2eddc27dfe1731e71c7b5f \ + --hash=sha256:12d90df9183093fcd479f4172ac26b322b1248b15729cb57f42f71f24c7e37a3 \ + --hash=sha256:15d30132718972c2c074cd14638c7f4592bd98719e2308bccea40e0538bc0cb5 \ + --hash=sha256:18821ace09c763ec93aef5eeff087ee493a24051936d7b9ebcad9662f66501f9 \ + --hash=sha256:1ae029229a57cd1e8fe542485f27e7ca7b23aa9e8944ddb4985d0bc444f1eca2 \ + --hash=sha256:2299372c19d56bcd35cf05a2738308758d32b9eaed2371898d8f5bd33f084aa3 \ + --hash=sha256:238b7ce5717600615c895050239ec955d91f321c209dd110db988500558e70d6 \ + --hash=sha256:24d50994d8c5816ddc35411e50a86ab05f575e2530c02752e02538122613371f \ + --hash=sha256:25d380fe8b1dc32cf8f0b1b448470a77afb195438bafdf1d858bfb876f3edf7b \ + --hash=sha256:2c1998e92cd5999e295a731bcb2911c75f597d937341f3030cc24ef2733d78a8 \ + --hash=sha256:2cf5bd12cecf46908f286d7838b2abc6c91cda506c0445b8223a7c19a00df008 \ + --hash=sha256:32f8dce744be5569bebe789e46727946041199030db8aeb2954d26013a0eb26b \ + --hash=sha256:37b3c1cc42aa184b3f738cfa18c1c1d72fd496d85467a6cf7b807936d39aa656 \ + --hash=sha256:3a48a78d2786784cc2413e57397981fb45c79e968d99656706018d6e62e57958 \ + --hash=sha256:3ab4aabc72de4ff77b3ec33a6d78a68227bf1123465887f9905ba79184a1cc04 \ + --hash=sha256:3c624e43ed56313651bc18a47f838b60d7b8032ed348911c54906b130b20071b \ + --hash=sha256:3f2e409836d7f5ac2f1c013110a4d50b9f7edc26328c108915f9075d7d7a91b6 \ + --hash=sha256:3f5c3e4da343bba819f0234186b9004faba952cc420fbc522dc4e103c1985908 \ + --hash=sha256:41703cc95688f2516b480f7f339d8851a6035f18e100ee6a32bc0b8536a12a9c \ + --hash=sha256:495672de149445ec1b772ff2c9ede9b769e3cb4f0d0aa7fa730d7f59e2d4e1c1 \ + --hash=sha256:4cf267add95b1c88300d96ca837833d4112756045364f5c734a2276038dae27d \ + --hash=sha256:56271f3dac49a88d7fca5060f004d9d22b865f743a12a23b1e937a0be4818ee1 \ + --hash=sha256:595ba4d8fe983b88f0eec8c26a241e16d6376fe1979086232f481f8f3f67494c \ + --hash=sha256:5f62550b9a30afde8c1c3ae450e5eb547d579dd69b25c2fc7a1c67f934c1717a \ + --hash=sha256:646d95230efb9ca614a7a594d4fcacde0ac61d25e37dd51710b36477594963ce \ + --hash=sha256:64fcc24778ca0404ce0cb7b6b77ae1f4c7231cdd60e6778f999ee05cbd581b9a \ + --hash=sha256:6be43b667360fef5c754dda5d25a32e6307a03c204f3c0fc5468b78fa87b4160 \ + --hash=sha256:6da7c2ce169267d0d066adcf63758f0604aa6c3eebf67458930f9d9b79ad1db1 \ + --hash=sha256:83d282364ea9f3e52363da262ce32a09dfe241e4080dcedda3c0db059d3c1f11 \ + --hash=sha256:9153c3292705be9f9c64498a8872118540c3f4123d1a1c840172edf262c8be4a \ + --hash=sha256:99eefd13c0dc3b3c1b4d561c1169e65fe47aab7b8158754d7c084088e2329466 \ + --hash=sha256:a0a7f52498f72f13d4a25ea70f35f4cb60642b466cbb0a9be951b5bc3f45a486 \ + --hash=sha256:a2b336e2d91a3d7006864e0990c83b216fcdca64b5a6484912902cef87313d78 \ + --hash=sha256:a48f2b74020919552ea25d222d5cc6af9ca3f4eb43a93e14d068457f545c2a17 \ + --hash=sha256:ad3d9833a64cf48cc4300f2b406c3d0f4f4724a91c0bd5640678a6ba7c102077 \ + --hash=sha256:b44d07310e404ba95f8c25aa5536f154c0a8ec473303535949e52eb71d0a1565 \ + --hash=sha256:b53285e65d4fa4c86399979e956235deb900be5baa7fc1218ea67fbfaeaadd6f \ + --hash=sha256:b5a2b97dbdc7d4f353ebf343744f1d1f1cca8aa8bfddb4262fcf4306c3761d50 \ + --hash=sha256:b9a5ca4ac220a0cdd1ba6bcba3608547117d30468fefce49bb26f55c1a3d5c58 \ + --hash=sha256:bab485bcf8b1c7d2060b4fcb6fc368a9e6f4cd754c9c2fea281f4be21df394a2 \ + --hash=sha256:c108a1d6fa78a50646029cb6d49808ff0fc1330fda87fa6f6250c6b5369b6645 \ + --hash=sha256:d56a1efd5bfd61486c8bc968fa18734464556f0fb8e51690f4ac25d85cbbbbc2 \ + --hash=sha256:d9050fee89a89ed57b4fb2c1bfac9a3d0c57a0d55aed95949eedbc42070fea39 \ + --hash=sha256:dd80ecb295460a5d9d260df63c43f4afbdd832d725a531f008dad1664f458adf \ + --hash=sha256:e8ea3e2d4066083e264e75c829078f9e149fa119d27e19acd503de65e0b13149 \ + --hash=sha256:eb3823f11823deade26ce3b9f40dcb4a213da7a670013929f31d5f5ed1055b22 \ + --hash=sha256:ee40c27c795bda6a5292e9cff9890189d32f7e3a0bf04e0e3c9430c4a00c37df \ + --hash=sha256:efb30e3baaea72ce5928e32bab719ab4770099079d66726a62b11b1ef7273be4 \ + --hash=sha256:f254d118d14a7f99d616271d6c3c27922c092dac11112670b157798b89bf4933 \ + --hash=sha256:f89c151aab2e2e23cb3fe0acad1e8b82841fd265379c4cecd0f3fcb34c15e0f6 \ + --hash=sha256:f97aeb209c3d2511443f8797e3e5a569aebb040d4f8bc79aa3ee78a8fb9e3dd8 \ + --hash=sha256:f9b587c9c7274c1613a30afabf65a272114cd6cdbe67b3406f818c79d7ab2e2a \ + --hash=sha256:fb061f596dad3a0f52b60dc6a5dec4a0c300dec41e058a7efe09256188d170b7 # via -r build/test-requirements.txt mdurl==0.1.2 \ --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba # via markdown-it-py -ml-dtypes==0.5.1 \ - --hash=sha256:023ce2f502efd4d6c1e0472cc58ce3640d051d40e71e27386bed33901e201327 \ - --hash=sha256:05f23447a1c20ddf4dc7c2c661aa9ed93fcb2658f1017c204d1e758714dc28a8 \ - --hash=sha256:12651420130ee7cc13059fc56dac6ad300c3af3848b802d475148c9defd27c23 \ - --hash=sha256:141b2ea2f20bb10802ddca55d91fe21231ef49715cfc971998e8f2a9838f3dbe \ - --hash=sha256:15ad0f3b0323ce96c24637a88a6f44f6713c64032f27277b069f285c3cf66478 \ - --hash=sha256:1b7fbe5571fdf28fd3aaab3ef4aafc847de9ebf263be959958c1ca58ec8eadf5 \ - --hash=sha256:26ebcc69d7b779c8f129393e99732961b5cc33fcff84090451f448c89b0e01b4 \ - --hash=sha256:6f462f5eca22fb66d7ff9c4744a3db4463af06c49816c4b6ac89b16bfcdc592e \ - --hash=sha256:6f76232163b5b9c34291b54621ee60417601e2e4802a188a0ea7157cd9b323f4 \ - --hash=sha256:7000b6e4d8ef07542c05044ec5d8bbae1df083b3f56822c3da63993a113e716f \ - --hash=sha256:810512e2eccdfc3b41eefa3a27402371a3411453a1efc7e9c000318196140fed \ - --hash=sha256:8f2c028954f16ede77902b223a8da2d9cbb3892375b85809a5c3cfb1587960c4 \ - --hash=sha256:9626d0bca1fb387d5791ca36bacbba298c5ef554747b7ebeafefb4564fc83566 \ - --hash=sha256:ac5b58559bb84a95848ed6984eb8013249f90b6bab62aa5acbad876e256002c9 \ - --hash=sha256:ad4953c5eb9c25a56d11a913c2011d7e580a435ef5145f804d98efa14477d390 \ - --hash=sha256:aefedc579ece2f8fb38f876aa7698204ee4c372d0e54f1c1ffa8ca580b54cc60 \ - --hash=sha256:afb2009ac98da274e893e03162f6269398b2b00d947e7057ee2469a921d58135 \ - --hash=sha256:b8a9d46b4df5ae2135a8e8e72b465448ebbc1559997f4f9304a9ecc3413efb5b \ - --hash=sha256:bd73f51957949069573ff783563486339a9285d72e2f36c18e0c1aa9ca7eb190 \ - --hash=sha256:bf9975bda82a99dc935f2ae4c83846d86df8fd6ba179614acac8e686910851da \ - --hash=sha256:c09526488c3a9e8b7a23a388d4974b670a9a3dd40c5c8a61db5593ce9b725bab \ - --hash=sha256:c9945669d3dadf8acb40ec2e57d38c985d8c285ea73af57fc5b09872c516106d \ - --hash=sha256:d13755f8e8445b3870114e5b6240facaa7cb0c3361e54beba3e07fa912a6e12b \ - --hash=sha256:fd918d4e6a4e0c110e2e05be7a7814d10dc1b95872accbf6512b80a109b71ae1 - # via -r build/requirements.in -mpmath==1.4.0a1 \ - --hash=sha256:78884400f439f500fa76be0121a8f9598313d87664863a192e1185ddbd7ae97f \ - --hash=sha256:f8b7b5a3a1726ab6e8c898eb2157426b82c482ab1ab8ffed9f88bb9e07c6e9c1 - # via -r build/test-requirements.txt -numpy==2.0.0 ; python_version <= "3.12" \ - --hash=sha256:04494f6ec467ccb5369d1808570ae55f6ed9b5809d7f035059000a37b8d7e86f \ - --hash=sha256:0a43f0974d501842866cc83471bdb0116ba0dffdbaac33ec05e6afed5b615238 \ - --hash=sha256:0e50842b2295ba8414c8c1d9d957083d5dfe9e16828b37de883f51fc53c4016f \ - --hash=sha256:0ec84b9ba0654f3b962802edc91424331f423dcf5d5f926676e0150789cb3d95 \ - --hash=sha256:17067d097ed036636fa79f6a869ac26df7db1ba22039d962422506640314933a \ - --hash=sha256:1cde1753efe513705a0c6d28f5884e22bdc30438bf0085c5c486cdaff40cd67a \ - --hash=sha256:1e72728e7501a450288fc8e1f9ebc73d90cfd4671ebbd631f3e7857c39bd16f2 \ - --hash=sha256:2635dbd200c2d6faf2ef9a0d04f0ecc6b13b3cad54f7c67c61155138835515d2 \ - --hash=sha256:2ce46fd0b8a0c947ae047d222f7136fc4d55538741373107574271bc00e20e8f \ - --hash=sha256:34f003cb88b1ba38cb9a9a4a3161c1604973d7f9d5552c38bc2f04f829536609 \ - --hash=sha256:354f373279768fa5a584bac997de6a6c9bc535c482592d7a813bb0c09be6c76f \ - --hash=sha256:38ecb5b0582cd125f67a629072fed6f83562d9dd04d7e03256c9829bdec027ad \ - --hash=sha256:3e8e01233d57639b2e30966c63d36fcea099d17c53bf424d77f088b0f4babd86 \ - --hash=sha256:3f6bed7f840d44c08ebdb73b1825282b801799e325bcbdfa6bc5c370e5aecc65 \ - --hash=sha256:4554eb96f0fd263041baf16cf0881b3f5dafae7a59b1049acb9540c4d57bc8cb \ - --hash=sha256:46e161722e0f619749d1cd892167039015b2c2817296104487cd03ed4a955995 \ - --hash=sha256:49d9f7d256fbc804391a7f72d4a617302b1afac1112fac19b6c6cec63fe7fe8a \ - --hash=sha256:4d2f62e55a4cd9c58c1d9a1c9edaedcd857a73cb6fda875bf79093f9d9086f85 \ - --hash=sha256:5f64641b42b2429f56ee08b4f427a4d2daf916ec59686061de751a55aafa22e4 \ - --hash=sha256:63b92c512d9dbcc37f9d81b123dec99fdb318ba38c8059afc78086fe73820275 \ - --hash=sha256:6d7696c615765091cc5093f76fd1fa069870304beaccfd58b5dcc69e55ef49c1 \ - --hash=sha256:79e843d186c8fb1b102bef3e2bc35ef81160ffef3194646a7fdd6a73c6b97196 \ - --hash=sha256:821eedb7165ead9eebdb569986968b541f9908979c2da8a4967ecac4439bae3d \ - --hash=sha256:84554fc53daa8f6abf8e8a66e076aff6ece62de68523d9f665f32d2fc50fd66e \ - --hash=sha256:8d83bb187fb647643bd56e1ae43f273c7f4dbcdf94550d7938cfc32566756514 \ - --hash=sha256:903703372d46bce88b6920a0cd86c3ad82dae2dbef157b5fc01b70ea1cfc430f \ - --hash=sha256:9416a5c2e92ace094e9f0082c5fd473502c91651fb896bc17690d6fc475128d6 \ - --hash=sha256:9a1712c015831da583b21c5bfe15e8684137097969c6d22e8316ba66b5baabe4 \ - --hash=sha256:9c27f0946a3536403efb0e1c28def1ae6730a72cd0d5878db38824855e3afc44 \ - --hash=sha256:a356364941fb0593bb899a1076b92dfa2029f6f5b8ba88a14fd0984aaf76d0df \ - --hash=sha256:a7039a136017eaa92c1848152827e1424701532ca8e8967fe480fe1569dae581 \ - --hash=sha256:acd3a644e4807e73b4e1867b769fbf1ce8c5d80e7caaef0d90dcdc640dfc9787 \ - --hash=sha256:ad0c86f3455fbd0de6c31a3056eb822fc939f81b1618f10ff3406971893b62a5 \ - --hash=sha256:b4c76e3d4c56f145d41b7b6751255feefae92edbc9a61e1758a98204200f30fc \ - --hash=sha256:b6f6a8f45d0313db07d6d1d37bd0b112f887e1369758a5419c0370ba915b3871 \ - --hash=sha256:c5a59996dc61835133b56a32ebe4ef3740ea5bc19b3983ac60cc32be5a665d54 \ - --hash=sha256:c73aafd1afca80afecb22718f8700b40ac7cab927b8abab3c3e337d70e10e5a2 \ - --hash=sha256:cee6cc0584f71adefe2c908856ccc98702baf95ff80092e4ca46061538a2ba98 \ - --hash=sha256:cef04d068f5fb0518a77857953193b6bb94809a806bd0a14983a8f12ada060c9 \ - --hash=sha256:cf5d1c9e6837f8af9f92b6bd3e86d513cdc11f60fd62185cc49ec7d1aba34864 \ - --hash=sha256:e61155fae27570692ad1d327e81c6cf27d535a5d7ef97648a17d922224b216de \ - --hash=sha256:e7f387600d424f91576af20518334df3d97bc76a300a755f9a8d6e4f5cadd289 \ - --hash=sha256:ed08d2703b5972ec736451b818c2eb9da80d66c3e84aed1deeb0c345fefe461b \ - --hash=sha256:fbd6acc766814ea6443628f4e6751d0da6593dae29c08c0b2606164db026970c \ - --hash=sha256:feff59f27338135776f6d4e2ec7aeeac5d5f7a08a83e80869121ef8164b74af9 +ml-dtypes==0.5.4 \ + --hash=sha256:0d2ffd05a2575b1519dc928c0b93c06339eb67173ff53acb00724502cda231cf \ + --hash=sha256:11942cbf2cf92157db91e5022633c0d9474d4dfd813a909383bd23ce828a4b7d \ + --hash=sha256:14a4fd3228af936461db66faccef6e4f41c1d82fcc30e9f8d58a08916b1d811f \ + --hash=sha256:19b9a53598f21e453ea2fbda8aa783c20faff8e1eeb0d7ab899309a0053f1483 \ + --hash=sha256:2314892cdc3fcf05e373d76d72aaa15fda9fb98625effa73c1d646f331fcecb7 \ + --hash=sha256:2b857d3af6ac0d39db1de7c706e69c7f9791627209c3d6dedbfca8c7e5faec22 \ + --hash=sha256:304ad47faa395415b9ccbcc06a0350800bc50eda70f0e45326796e27c62f18b6 \ + --hash=sha256:35f29491a3e478407f7047b8a4834e4640a77d2737e0b294d049746507af5175 \ + --hash=sha256:388d399a2152dd79a3f0456a952284a99ee5c93d3e2f8dfe25977511e0515270 \ + --hash=sha256:3bbbe120b915090d9dd1375e4684dd17a20a2491ef25d640a908281da85e73f1 \ + --hash=sha256:3d277bf3637f2a62176f4575512e9ff9ef51d00e39626d9fe4a161992f355af2 \ + --hash=sha256:4381fe2f2452a2d7589689693d3162e876b3ddb0a832cde7a414f8e1adf7eab1 \ + --hash=sha256:4ff7f3e7ca2972e7de850e7b8fcbb355304271e2933dd90814c1cb847414d6e2 \ + --hash=sha256:531eff30e4d368cb6255bc2328d070e35836aa4f282a0fb5f3a0cd7260257298 \ + --hash=sha256:533ce891ba774eabf607172254f2e7260ba5f57bdd64030c9a4fcfbd99815d0d \ + --hash=sha256:557a31a390b7e9439056644cb80ed0735a6e3e3bb09d67fd5687e4b04238d1de \ + --hash=sha256:5a0f68ca8fd8d16583dfa7793973feb86f2fbb56ce3966daf9c9f748f52a2049 \ + --hash=sha256:6a0df4223b514d799b8a1629c65ddc351b3efa833ccf7f8ea0cf654a61d1e35d \ + --hash=sha256:6c7ecb74c4bd71db68a6bea1edf8da8c34f3d9fe218f038814fd1d310ac76c90 \ + --hash=sha256:7c23c54a00ae43edf48d44066a7ec31e05fdc2eee0be2b8b50dd1903a1db94bb \ + --hash=sha256:805cef3a38f4eafae3a5bf9ebdcdb741d0bcfd9e1bd90eb54abd24f928cd2465 \ + --hash=sha256:88c982aac7cb1cbe8cbb4e7f253072b1df872701fcaf48d84ffbb433b6568f24 \ + --hash=sha256:8ab06a50fb9bf9666dd0fe5dfb4676fa2b0ac0f31ecff72a6c3af8e22c063453 \ + --hash=sha256:8c6a2dcebd6f3903e05d51960a8058d6e131fe69f952a5397e5dbabc841b6d56 \ + --hash=sha256:8c760d85a2f82e2bed75867079188c9d18dae2ee77c25a54d60e9cc79be1bc48 \ + --hash=sha256:9ad459e99793fa6e13bd5b7e6792c8f9190b4e5a1b45c63aba14a4d0a7f1d5ff \ + --hash=sha256:9bad06436568442575beb2d03389aa7456c690a5b05892c471215bfd8cf39460 \ + --hash=sha256:a174837a64f5b16cab6f368171a1a03a27936b31699d167684073ff1c4237dac \ + --hash=sha256:a7f7c643e8b1320fd958bf098aa7ecf70623a42ec5154e3be3be673f4c34d900 \ + --hash=sha256:a9b61c19040397970d18d7737375cffd83b1f36a11dd4ad19f83a016f736c3ef \ + --hash=sha256:b4b801ebe0b477be666696bda493a9be8356f1f0057a57f1e35cd26928823e5a \ + --hash=sha256:b95e97e470fe60ed493fd9ae3911d8da4ebac16bd21f87ffa2b7c588bf22ea2c \ + --hash=sha256:bc11d7e8c44a65115d05e2ab9989d1e045125d7be8e05a071a48bc76eb6d6040 \ + --hash=sha256:bfc534409c5d4b0bf945af29e5d0ab075eae9eecbb549ff8a29280db822f34f9 \ + --hash=sha256:c1a953995cccb9e25a4ae19e34316671e4e2edaebe4cf538229b1fc7109087b7 \ + --hash=sha256:cb73dccfc991691c444acc8c0012bee8f2470da826a92e3a20bb333b1a7894e6 \ + --hash=sha256:ce756d3a10d0c4067172804c9cc276ba9cc0ff47af9078ad439b075d1abdc29b \ + --hash=sha256:d81fdb088defa30eb37bf390bb7dde35d3a83ec112ac8e33d75ab28cc29dd8b0 \ + --hash=sha256:f21c9219ef48ca5ee78402d5cc831bd58ea27ce89beda894428bc67a52da5328 # via # -r build/requirements.in + # jaxlib + # keras + # tensorflow + # tensorstore +mpmath==1.3.0 \ + --hash=sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f \ + --hash=sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c + # via -r build/test-requirements.txt +namex==0.1.0 \ + --hash=sha256:117f03ccd302cc48e3f5c58a296838f6b89c83455ab8683a1e85f2a430aa4306 \ + --hash=sha256:e2012a474502f1e2251267062aae3114611f07df4224b6e06334c57b0f2ce87c + # via keras +numpy==2.0.2 ; python_version <= "3.12" \ + --hash=sha256:0123ffdaa88fa4ab64835dcbde75dcdf89c453c922f18dced6e27c90d1d0ec5a \ + --hash=sha256:11a76c372d1d37437857280aa142086476136a8c0f373b2e648ab2c8f18fb195 \ + --hash=sha256:13e689d772146140a252c3a28501da66dfecd77490b498b168b501835041f951 \ + --hash=sha256:1e795a8be3ddbac43274f18588329c72939870a16cae810c2b73461c40718ab1 \ + --hash=sha256:26df23238872200f63518dd2aa984cfca675d82469535dc7162dc2ee52d9dd5c \ + --hash=sha256:286cd40ce2b7d652a6f22efdfc6d1edf879440e53e76a75955bc0c826c7e64dc \ + --hash=sha256:2b2955fa6f11907cf7a70dab0d0755159bca87755e831e47932367fc8f2f2d0b \ + --hash=sha256:2da5960c3cf0df7eafefd806d4e612c5e19358de82cb3c343631188991566ccd \ + --hash=sha256:312950fdd060354350ed123c0e25a71327d3711584beaef30cdaa93320c392d4 \ + --hash=sha256:423e89b23490805d2a5a96fe40ec507407b8ee786d66f7328be214f9679df6dd \ + --hash=sha256:496f71341824ed9f3d2fd36cf3ac57ae2e0165c143b55c3a035ee219413f3318 \ + --hash=sha256:49ca4decb342d66018b01932139c0961a8f9ddc7589611158cb3c27cbcf76448 \ + --hash=sha256:51129a29dbe56f9ca83438b706e2e69a39892b5eda6cedcb6b0c9fdc9b0d3ece \ + --hash=sha256:5fec9451a7789926bcf7c2b8d187292c9f93ea30284802a0ab3f5be8ab36865d \ + --hash=sha256:671bec6496f83202ed2d3c8fdc486a8fc86942f2e69ff0e986140339a63bcbe5 \ + --hash=sha256:7f0a0c6f12e07fa94133c8a67404322845220c06a9e80e85999afe727f7438b8 \ + --hash=sha256:807ec44583fd708a21d4a11d94aedf2f4f3c3719035c76a2bbe1fe8e217bdc57 \ + --hash=sha256:883c987dee1880e2a864ab0dc9892292582510604156762362d9326444636e78 \ + --hash=sha256:8c5713284ce4e282544c68d1c3b2c7161d38c256d2eefc93c1d683cf47683e66 \ + --hash=sha256:8cafab480740e22f8d833acefed5cc87ce276f4ece12fdaa2e8903db2f82897a \ + --hash=sha256:8df823f570d9adf0978347d1f926b2a867d5608f434a7cff7f7908c6570dcf5e \ + --hash=sha256:9059e10581ce4093f735ed23f3b9d283b9d517ff46009ddd485f1747eb22653c \ + --hash=sha256:905d16e0c60200656500c95b6b8dca5d109e23cb24abc701d41c02d74c6b3afa \ + --hash=sha256:9189427407d88ff25ecf8f12469d4d39d35bee1db5d39fc5c168c6f088a6956d \ + --hash=sha256:96a55f64139912d61de9137f11bf39a55ec8faec288c75a54f93dfd39f7eb40c \ + --hash=sha256:97032a27bd9d8988b9a97a8c4d2c9f2c15a81f61e2f21404d7e8ef00cb5be729 \ + --hash=sha256:984d96121c9f9616cd33fbd0618b7f08e0cfc9600a7ee1d6fd9b239186d19d97 \ + --hash=sha256:9a92ae5c14811e390f3767053ff54eaee3bf84576d99a2456391401323f4ec2c \ + --hash=sha256:9ea91dfb7c3d1c56a0e55657c0afb38cf1eeae4544c208dc465c3c9f3a7c09f9 \ + --hash=sha256:a15f476a45e6e5a3a79d8a14e62161d27ad897381fecfa4a09ed5322f2085669 \ + --hash=sha256:a392a68bd329eafac5817e5aefeb39038c48b671afd242710b451e76090e81f4 \ + --hash=sha256:a3f4ab0caa7f053f6797fcd4e1e25caee367db3112ef2b6ef82d749530768c73 \ + --hash=sha256:a46288ec55ebbd58947d31d72be2c63cbf839f0a63b49cb755022310792a3385 \ + --hash=sha256:a61ec659f68ae254e4d237816e33171497e978140353c0c2038d46e63282d0c8 \ + --hash=sha256:a842d573724391493a97a62ebbb8e731f8a5dcc5d285dfc99141ca15a3302d0c \ + --hash=sha256:becfae3ddd30736fe1889a37f1f580e245ba79a5855bff5f2a29cb3ccc22dd7b \ + --hash=sha256:c05e238064fc0610c840d1cf6a13bf63d7e391717d247f1bf0318172e759e692 \ + --hash=sha256:c1c9307701fec8f3f7a1e6711f9089c06e6284b3afbbcd259f7791282d660a15 \ + --hash=sha256:c7b0be4ef08607dd04da4092faee0b86607f111d5ae68036f16cc787e250a131 \ + --hash=sha256:cfd41e13fdc257aa5778496b8caa5e856dc4896d4ccf01841daee1d96465467a \ + --hash=sha256:d731a1c6116ba289c1e9ee714b08a8ff882944d4ad631fd411106a30f083c326 \ + --hash=sha256:df55d490dea7934f330006d0f81e8551ba6010a5bf035a249ef61a94f21c500b \ + --hash=sha256:ec9852fb39354b5a45a80bdab5ac02dd02b15f44b3804e9f00c556bf24b4bded \ + --hash=sha256:f15975dfec0cf2239224d80e32c3170b1d168335eaedee69da84fbe9f1f9cd04 \ + --hash=sha256:f26b258c385842546006213344c50655ff1555a9338e2e5e02a0756dc3e803dd + # via + # -r build/nonfreethreading-requirements.txt # contourpy + # h5py + # jaxlib + # keras # matplotlib # ml-dtypes - # opt-einsum + # numpy-typing-compat + # optype # scipy -nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \ - --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \ - --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ - --hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9 + # tensorboard + # tensorflow + # tensorstore +numpy-typing-compat==20251206.2.0 \ + --hash=sha256:413171c4333c4175cbad4206c94e58422d291d20426c42581865380156715493 \ + --hash=sha256:7db9d5e991af03b2ade38f43253e4eb03ab88925230931bff7f559c020676fb1 + # via optype +nvidia-cublas==13.1.0.3 ; sys_platform == "linux" \ + --hash=sha256:2a3b94a37def342471c59fad7856caee4926809a72dd5270155d6a31b5b277be \ + --hash=sha256:c86fc7f7ae36d7528288c5d88098edcb7b02c633d262e7ddbb86b0ad91be5df2 \ + --hash=sha256:ee8722c1f0145ab246bccb9e452153b5e0515fd094c3678df50b2a0888b8b171 # via - # -r build/test-requirements.txt + # -r build/nvidia-requirements.txt + # nvidia-cudnn-cu13 + # nvidia-cusolver +nvidia-cublas-cu12==12.9.1.4 ; sys_platform == "linux" \ + --hash=sha256:1e5fee10662e6e52bd71dec533fbbd4971bb70a5f24f3bc3793e5c2e9dc640bf \ + --hash=sha256:453611eb21a7c1f2c2156ed9f3a45b691deda0440ec550860290dc901af5b4c2 \ + --hash=sha256:7a950dae01add3b415a5a5cdc4ec818fb5858263e9cca59004bb99fdbbd3a5d6 + # via + # -r build/nvidia-requirements.txt # nvidia-cudnn-cu12 # nvidia-cusolver-cu12 -nvidia-cuda-cupti-cu12==12.8.57 ; sys_platform == "linux" \ - --hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \ - --hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \ - --hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317 - # via -r build/test-requirements.txt -nvidia-cuda-nvcc-cu12==12.8.61 ; sys_platform == "linux" \ - --hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \ - --hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ - --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b - # via -r build/test-requirements.txt -nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ - --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ - --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ - --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 - # via -r build/test-requirements.txt -nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \ - --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ - --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ - --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef - # via -r build/test-requirements.txt -nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ - --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ - --hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \ - --hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8 - # via -r build/test-requirements.txt -nvidia-cusolver-cu12==11.7.2.55 ; sys_platform == "linux" \ - --hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \ - --hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \ - --hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac - # via -r build/test-requirements.txt -nvidia-cusparse-cu12==12.5.7.53 ; sys_platform == "linux" \ - --hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \ - --hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \ - --hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1 +nvidia-cuda-crt==13.0.88 ; sys_platform == "linux" \ + --hash=sha256:2c8043c7c9e02492716426e9919fc78d2c5b3b2a7a768a88e952676b08aa55a4 \ + --hash=sha256:31e02c52916804ca15e31f272a96181d8fadaf40c4c82a77a6f78071a22eccf3 \ + --hash=sha256:ee2ea2a97073e02ee62bb27841f437332be2c248e3eac013df07997ada39c003 # via - # -r build/test-requirements.txt + # -r build/nvidia-requirements.txt + # nvidia-cuda-nvcc +nvidia-cuda-cupti==13.0.85 ; sys_platform == "linux" \ + --hash=sha256:4eb01c08e859bf924d222250d2e8f8b8ff6d3db4721288cf35d14252a4d933c8 \ + --hash=sha256:683f58d301548deeefcb8f6fac1b8d907691b9d8b18eccab417f51e362102f00 \ + --hash=sha256:796bd679890ee55fb14a94629b698b6db54bcfd833d391d5e94017dd9d7d3151 + # via -r build/nvidia-requirements.txt +nvidia-cuda-cupti-cu12==12.9.79 ; sys_platform == "linux" \ + --hash=sha256:096bcf334f13e1984ba36685ad4c1d6347db214de03dbb6eebb237b41d9d934f \ + --hash=sha256:1848a9380067560d5bee10ed240eecc22991713e672c0515f9c3d9396adf93c8 \ + --hash=sha256:791853b030602c6a11d08b5578edfb957cadea06e9d3b26adbf8d036135a4afe + # via -r build/nvidia-requirements.txt +nvidia-cuda-nvcc==13.0.88 ; sys_platform == "linux" \ + --hash=sha256:56fe502eb77625a12f25172caa3cdddb4e4c8ba2c8c17dba44b164761b380f03 \ + --hash=sha256:7c3a32c8ca9866addfd784da363ddee2f6874d560027a296f583e86a61f2d543 \ + --hash=sha256:c7ff28f86a24effdc6c034fa15230c549a273e4771b10a7fec14996f8cf3307f + # via -r build/nvidia-requirements.txt +nvidia-cuda-nvcc-cu12==12.9.86 ; sys_platform == "linux" \ + --hash=sha256:44e1eca4d08926193a558d2434b1bf83d57b4d5743e0c431c0c83d51da1df62b \ + --hash=sha256:5d6a0d32fdc7ea39917c20065614ae93add6f577d840233237ff08e9a38f58f0 \ + --hash=sha256:8ed7f0b17dea662755395be029376db3b94fed5cbb17c2d35cc866c5b1b84099 + # via -r build/nvidia-requirements.txt +nvidia-cuda-nvrtc==13.0.88 ; sys_platform == "linux" \ + --hash=sha256:6bcd4e7f8e205cbe644f5a98f2f799bef9556fefc89dd786e79a16312ce49872 \ + --hash=sha256:ad9b6d2ead2435f11cbb6868809d2adeeee302e9bb94bcf0539c7a40d80e8575 \ + --hash=sha256:d27f20a0ca67a4bb34268a5e951033496c5b74870b868bacd046b1b8e0c3267b + # via -r build/nvidia-requirements.txt +nvidia-cuda-nvrtc-cu12==12.9.86 ; sys_platform == "linux" \ + --hash=sha256:096d4de6bda726415dfaf3198d4f5c522b8e70139c97feef5cd2ca6d4cd9cead \ + --hash=sha256:210cf05005a447e29214e9ce50851e83fc5f4358df8b453155d5e1918094dcb4 \ + --hash=sha256:72972ebdcf504d69462d3bcd67e7b81edd25d0fb85a2c46d3ea3517666636349 + # via -r build/nvidia-requirements.txt +nvidia-cuda-runtime==13.0.96 ; sys_platform == "linux" \ + --hash=sha256:7f82250d7782aa23b6cfe765ecc7db554bd3c2870c43f3d1821f1d18aebf0548 \ + --hash=sha256:ef9bcbe90493a2b9d810e43d249adb3d02e98dd30200d86607d8d02687c43f55 \ + --hash=sha256:f79298c8a098cec150a597c8eba58ecdab96e3bdc4b9bc4f9983635031740492 + # via + # -r build/nvidia-requirements.txt + # nvidia-cuda-nvcc +nvidia-cuda-runtime-cu12==12.9.79 ; sys_platform == "linux" \ + --hash=sha256:25bba2dfb01d48a9b59ca474a1ac43c6ebf7011f1b0b8cc44f54eb6ac48a96c3 \ + --hash=sha256:83469a846206f2a733db0c42e223589ab62fd2fabac4432d2f8802de4bded0a4 \ + --hash=sha256:8e018af8fa02363876860388bd10ccb89eb9ab8fb0aa749aaf58430a9f7c4891 + # via -r build/nvidia-requirements.txt +nvidia-cudnn-cu12==9.16.0.29 ; sys_platform == "linux" \ + --hash=sha256:142e2bd646a4573ab17d61a24c6359155cdfe1f34c67fc305b71222a7ae45b8e \ + --hash=sha256:4b09c43096db582f110c5572d0bcbd98b30d709e860a8f73c6c3846baa83b8d2 \ + --hash=sha256:78d05b4434dacc7dd9bc903d5c33a2f28a5f0064d02568ef7b2418f89f6c5922 + # via -r build/nvidia-requirements.txt +nvidia-cudnn-cu13==9.16.0.29 ; sys_platform == "linux" \ + --hash=sha256:6349bc8769369a91611c5e2ce5c2e510e61848c245099c31e870d2cdce0ab90d \ + --hash=sha256:79dc1bfe8c1a780cf4eb7b334d14d7927576d6dd8823f8e2769911af30fd4da3 \ + --hash=sha256:faafa46e2e7dd844bbcf06b6adec3fa66924987f2fb21bf67f5c6fd697c74a64 + # via -r build/nvidia-requirements.txt +nvidia-cufft==12.0.0.61 ; sys_platform == "linux" \ + --hash=sha256:2708c852ef8cd89d1d2068bdbece0aa188813a0c934db3779b9b1faa8442e5f5 \ + --hash=sha256:2abce5b39d2f5ae12730fb7e5db6696533e36c26e2d3e8fd1750bdd2853364eb \ + --hash=sha256:6c44f692dce8fd5ffd3e3df134b6cdb9c2f72d99cf40b62c32dde45eea9ddad3 + # via -r build/nvidia-requirements.txt +nvidia-cufft-cu12==11.4.1.4 ; sys_platform == "linux" \ + --hash=sha256:1a28c9b12260a1aa7a8fd12f5ebd82d027963d635ba82ff39a1acfa7c4c0fbcf \ + --hash=sha256:8e5bfaac795e93f80611f807d42844e8e27e340e0cde270dcb6c65386d795b80 \ + --hash=sha256:c67884f2a7d276b4b80eb56a79322a95df592ae5e765cf1243693365ccab4e28 + # via -r build/nvidia-requirements.txt +nvidia-cusolver==12.0.4.66 ; sys_platform == "linux" \ + --hash=sha256:02c2457eaa9e39de20f880f4bd8820e6a1cfb9f9a34f820eb12a155aa5bc92d2 \ + --hash=sha256:0a759da5dea5c0ea10fd307de75cdeb59e7ea4fcb8add0924859b944babf1112 \ + --hash=sha256:16515bd33a8e76bb54d024cfa068fa68d30e80fc34b9e1090813ea9362e0cb65 + # via -r build/nvidia-requirements.txt +nvidia-cusolver-cu12==11.7.5.82 ; sys_platform == "linux" \ + --hash=sha256:15da72d1340d29b5b3cf3fd100e3cd53421dde36002eda6ed93811af63c40d88 \ + --hash=sha256:62efa83e4ace59a4c734d052bb72158e888aa7b770e1a5f601682f16fe5b4fd2 \ + --hash=sha256:77666337237716783c6269a658dea310195cddbd80a5b2919b1ba8735cec8efd + # via -r build/nvidia-requirements.txt +nvidia-cusparse==12.6.3.3 ; sys_platform == "linux" \ + --hash=sha256:2b3c89c88d01ee0e477cb7f82ef60a11a4bcd57b6b87c33f789350b59759360b \ + --hash=sha256:80bcc4662f23f1054ee334a15c72b8940402975e0eab63178fc7e670aa59472c \ + --hash=sha256:cbcf42feb737bd7ec15b4c0a63e62351886bd3f975027b8815d7f720a2b5ea79 + # via + # -r build/nvidia-requirements.txt + # nvidia-cusolver +nvidia-cusparse-cu12==12.5.10.65 ; sys_platform == "linux" \ + --hash=sha256:221c73e7482dd93eda44e65ce567c031c07e2f93f6fa0ecd3ba876a195023e83 \ + --hash=sha256:73060ce019ac064a057267c585bf1fd5a353734151f87472ff02b2c5c9984e78 \ + --hash=sha256:9e487468a22a1eaf1fbd1d2035936a905feb79c4ce5c2f67626764ee4f90227c + # via + # -r build/nvidia-requirements.txt # nvidia-cusolver-cu12 -nvidia-nccl-cu12==2.25.1 ; sys_platform == "linux" \ - --hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ - --hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 - # via -r build/test-requirements.txt -nvidia-nvjitlink-cu12==12.8.61 ; sys_platform == "linux" \ - --hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \ - --hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \ - --hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0 +nvidia-nccl-cu12==2.28.9 ; sys_platform == "linux" \ + --hash=sha256:485776daa8447da5da39681af455aa3b2c2586ddcf4af8772495e7c532c7e5ab \ + --hash=sha256:50a36e01c4a090b9f9c47d92cec54964de6b9fcb3362d0e19b8ffc6323c21b60 + # via -r build/nvidia-requirements.txt +nvidia-nccl-cu13==2.28.9 ; sys_platform == "linux" \ + --hash=sha256:01c873ba1626b54caa12272ed228dc5b2781545e0ae8ba3f432a8ef1c6d78643 \ + --hash=sha256:e4553a30f34195f3fa1da02a6da3d6337d28f2003943aa0a3d247bbc25fefc42 + # via -r build/nvidia-requirements.txt +nvidia-nvjitlink==13.0.88 ; sys_platform == "linux" \ + --hash=sha256:13a74f429e23b921c1109976abefacc69835f2f433ebd323d3946e11d804e47b \ + --hash=sha256:634e96e3da9ef845ae744097a1f289238ecf946ce0b82e93cdce14b9782e682f \ + --hash=sha256:e931536ccc7d467a98ba1d8b89ff7fa7f1fa3b13f2b0069118cd7f47bff07d0c # via - # -r build/test-requirements.txt + # -r build/nvidia-requirements.txt + # nvidia-cufft + # nvidia-cusolver + # nvidia-cusparse +nvidia-nvjitlink-cu12==12.9.86 ; sys_platform == "linux" \ + --hash=sha256:994a05ef08ef4b0b299829cde613a424382aff7efb08a7172c1fa616cc3af2ca \ + --hash=sha256:cc6fcec260ca843c10e34c936921a1c426b351753587fdd638e8cff7b16bb9db \ + --hash=sha256:e3f1171dbdc83c5932a45f0f4c99180a70de9bd2718c1ab77d14104f6d7147f9 + # via + # -r build/nvidia-requirements.txt # nvidia-cufft-cu12 # nvidia-cusolver-cu12 # nvidia-cusparse-cu12 -opt-einsum==3.3.0 \ - --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ - --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 +nvidia-nvshmem-cu12==3.4.5 ; sys_platform == "linux" \ + --hash=sha256:042f2500f24c021db8a06c5eec2539027d57460e1c1a762055a6554f72c369bd \ + --hash=sha256:0b48363fc6964dede448029434c6abed6c5e37f823cb43c3bcde7ecfc0457e15 + # via -r build/nvidia-requirements.txt +nvidia-nvshmem-cu13==3.4.5 ; sys_platform == "linux" \ + --hash=sha256:290f0a2ee94c9f3687a02502f3b9299a9f9fe826e6d0287ee18482e78d495b80 \ + --hash=sha256:6dc2a197f38e5d0376ad52cd1a2a3617d3cdc150fd5966f4aee9bcebb1d68fe9 + # via -r build/nvidia-requirements.txt +nvidia-nvvm==13.0.88 ; sys_platform == "linux" \ + --hash=sha256:2ef0db7849e476d3b2fc3c09b27bdd79bd7ea8ce58cd9c86553d64ea40844ba0 \ + --hash=sha256:c4376a291d72d22a315d9d2f69bdae8f8cd83a627f75bad395cee49a0fe65dc1 \ + --hash=sha256:c5f41ffeb6466944a026dfa5317d7d85355c119bbec279205d22f1869d1054e0 + # via + # -r build/nvidia-requirements.txt + # nvidia-cuda-nvcc +opt-einsum==3.4.0 \ + --hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \ + --hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac # via # -r build/requirements.in - # -r build/test-requirements.txt -packaging==24.0 \ - --hash=sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5 \ - --hash=sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9 + # tensorflow +optree==0.18.0 \ + --hash=sha256:01b79aaee544adf5bfa573db32b943030dfeb9fd1c6e7a97aa417db56a8127e7 \ + --hash=sha256:02d9999840fabef85a6b22e757f336d5591f712f99c710d8b232d52e53115314 \ + --hash=sha256:056894ce6242cd1c7fed71325a7d9f633b2d3b4420c52af48f6a0c4560d74ca1 \ + --hash=sha256:057b983a9526645133553184bed7090bb07855df986abd9e99c456922045c6bc \ + --hash=sha256:07c5f64783ad0f0f80e61c25f276ce79b47deda83ed7956a4a9af6385fe8f60d \ + --hash=sha256:090a3f0ccafa0fe99d71e7d974ae52ff966ac26c409ec41f96556b96646054ef \ + --hash=sha256:0959bac58631e64e2ac6349cc284b37872c24f353b3d73b4682202a431f07d76 \ + --hash=sha256:0d25941de1acba176305dbdeb931dea6143b30d64ebdc5bfea2bfc12ef9e2b0a \ + --hash=sha256:0e0dbe995241efe70cfb522e89c1a7c968216926725a0e5e20cc72bd5d0311b1 \ + --hash=sha256:10f29662d637b80363dc620da46ddc58def7acf7935e20595b23e216ea912367 \ + --hash=sha256:1545c68299c0ce600e4ea1bc9112765dc4afe9a0b8ab43f955df6566bf78db42 \ + --hash=sha256:1b75e083137f361377ff8d70df885ab3a1cf8980e4019e3f311237579adadb64 \ + --hash=sha256:1db0a6497203a13063a8f044ae751dd5d8253cb815359270c38de0e4c9f8bed5 \ + --hash=sha256:1f19867b02a547fc9f11d27c0413e7483cef89699e16f3b9e8af73a9b25e6061 \ + --hash=sha256:1f674e34202383f8b42fa9335f13bedfb6b6f019c66e1f41034929e4be203423 \ + --hash=sha256:20536964ba2458f166c1e8ab25951e3fc0a5056b651bd08f16be99bb3ffed54a \ + --hash=sha256:27611c6c122745a003b5be7aedba49ef86e9fef46d743c234596de0bde6dc679 \ + --hash=sha256:27b1d0cadcf4627c98abbbdce912dbc2243f5687f3c7df39963b793c89321c65 \ + --hash=sha256:289b184cc41dfc400a30db6207ec997884d14540aae2cba10cb88dc7ebaae2a1 \ + --hash=sha256:2b5cfb5fc643f16d3a7d957807e55a937dce07566c49ccc4aa71b01064c56758 \ + --hash=sha256:3014537ff7e4e091ee46e57976f7d95c52f66a0e3eb5ebcbe0de0d924504b58e \ + --hash=sha256:30a2636279bdc805c8e154a0f346bcf704626b831ff44724d305fb72c90b7389 \ + --hash=sha256:30f95279188f6b9300e17c1557989baa991c2d6f519013bd8fea13462a0e6a45 \ + --hash=sha256:30fefc84975ac41d9075993196c64ce0c240510f0539cff121d63b709e03846f \ + --hash=sha256:31539dec60af84e16e99574634811d38e34e1fb381f40d6f489a2e582bf41f03 \ + --hash=sha256:328857d7a35129904b21164f6b0c2ff1d728ad1f5838589c5f437a16c94213c8 \ + --hash=sha256:3804fb6ddc923855db2dc4805b4524c66e00f1ef30b166be4aadd52822b13e06 \ + --hash=sha256:382e5ca02cbd5b20d713d4da189a8613f828832e2af57ccbe04a9c6b0bd9497e \ + --hash=sha256:385bd727cc7bd3c01bd6204028ac2adce8a8f622c296053d9df434aa0e30b01f \ + --hash=sha256:421b839c7ff30df5791e66c89b2e9c2f68191dd6a5d6927c32bcc6b887090df8 \ + --hash=sha256:446c46c53cb8f13abcc0d7dd1989d59bb059953c122fe9901ef53de7fb38b33e \ + --hash=sha256:4cc92339899acb685ee718fd22b25069dfa7be038c63274c54481d54ccc2f9e2 \ + --hash=sha256:4eb146711d4cd0876bf93e0118d3e74050b6f633d756c269ce7cda907281b499 \ + --hash=sha256:51e2cd9ac7fecfd5f6f56ce69f4f805553c226a2744810175959eb408101513c \ + --hash=sha256:55a2ccd121fccc9df961e982db2f4e8f2b4f7015e814ef70b1140514cdffe214 \ + --hash=sha256:56bb19ff827c9a443202b52bf103705ce96ef14d045e0a30d0d7ee7dbcef6a0d \ + --hash=sha256:571b732229d7b2e7a2215f57586f8ec0140e07c0faea916e456cbbfa819e56cb \ + --hash=sha256:5b126c34b459ef4f10f3a4d7d222416d9102b3c5a76b39f346c611792f144821 \ + --hash=sha256:5b75e32c191e4b8cf42a8aa854ed264df82936136c0bcad77be44605da41cdfc \ + --hash=sha256:5bc1221068a58175e0ad62afc199893f77c653206673a5552992a604c66fb77e \ + --hash=sha256:5e669f98b9af9f66144c7ae09912d0367ac3182abe016f67cdd15cb45e13c923 \ + --hash=sha256:66f142c743732cd4e630ea84415f654a00c792793c7f80d4511167f0f89796a6 \ + --hash=sha256:6fc9f8acde3bb561b2034e96079507fbe6d4624058fe204161eb8ef29f961296 \ + --hash=sha256:7172b16e87c87160475275e4bfaa6e4067ccde184d2cca65ba25a402a8ed7758 \ + --hash=sha256:71ca2fcad8972ba56d6cfffbcd962f45f5d4bc04182f23d66154b38c2eb37de3 \ + --hash=sha256:72fa79be4d6515682417f103ae759a22345439eb1319886be936029215ee00dc \ + --hash=sha256:7699957183f8d45402edd6266e175510317f5fcd7f0e623510f2eb7e1ebfc667 \ + --hash=sha256:79bbe14d6cad81f5840958589daa1b836864ada40031712a446dce8129917efd \ + --hash=sha256:7ae6945f68771b1389ee46a1778e779f4ad76bca9306f3e39eb397f9a0dd2753 \ + --hash=sha256:80d971060c888c3989132b7e75dfb50848636d41bc931af1b93fe2019fba469c \ + --hash=sha256:80f28e4666aad66e5e20bdc2c47b5bf320250bb5407b3a39dfb1772787a7068f \ + --hash=sha256:81e755124b77e766166c9d05206b90c68f234f425ad2e3c8a6c96f0db548c67b \ + --hash=sha256:86f5bf05ad236f666e5395e989d6ac2cbfd02556526703e6c6f0a594c7fa081f \ + --hash=sha256:895f23a4cd8aee2c2464efdad2d9bde28a2aaabee634c96423a933f40e74a67e \ + --hash=sha256:89d5156f8a0a3792701e1c31473eb307f0b45696f48dc51d721f1bfe0c3a950f \ + --hash=sha256:89e81afb11792d13d3777b503c6f21ec17b1a3b7de69cde1ae2c5471bcdcd4a0 \ + --hash=sha256:8a2003fab79694e04b5f260628511e441c248b46a9fc46138e2424038ac04ada \ + --hash=sha256:8a4ca121b6fc6b04300fa225fe6c31897e424db0d92691875af326f8c4e1cead \ + --hash=sha256:8a901666afc2d7a8d0c20decc8079763e3313457ee67210382162d90163c0007 \ + --hash=sha256:8b9ad4a01a1346b11acc574b7f932dea1a7c7ab31d93546a7540a1f02b3e724a \ + --hash=sha256:8d88c00c70b5914904feaf8f505f3512c2f3f4493dbbd93951fcdddc85dcfe8c \ + --hash=sha256:9104fc8915890e7292e5833fc677e4749607c67aa3cf8884677267078201c2f3 \ + --hash=sha256:9460cba62e941626beb75c99a803373b38a52136d5f1932fcdfdcede1df6f2ef \ + --hash=sha256:94983b3aa31ee401d2ac77ba570a3157d83f9508cfbb006095a48770e0a1c5ca \ + --hash=sha256:9b1e7e8f9ddc85f05d542b74157bdb73ed0e49aded67d1775f721fcd6eb9be94 \ + --hash=sha256:9d4b9d8c7e9335120ecf222d817699d17de743ad118080fb40467c367f009143 \ + --hash=sha256:a479fa25b6e2430e530d00f0c27a55e15ecb9de8ad2d0aec3d40b680e2d6df64 \ + --hash=sha256:a5c213a291c798139ed9ff80aec4bfcd2ac8f001bc015a9cdeb78457e9687dd3 \ + --hash=sha256:a63df296fec376c5cd08298a85109db4a130f4cc8df15916fc92d44ef6068937 \ + --hash=sha256:a74c45f04def041504bd21682eaf7f359f1a50dc7cf42b548b6f19aab50596bd \ + --hash=sha256:ad428ccdb2a40804919880dfe8d2a3021fd4418be15ea7ecb8434ab249badf9f \ + --hash=sha256:b0986ff1267a3b44d3ed76c3efb8b7239371444143f6e0d79f9dd23dbe02c7f9 \ + --hash=sha256:b45d7172c67fc8d2b69f77b384998b39793ee91f8b3b46c609297b781fb7eea5 \ + --hash=sha256:b4da3223c5b4cf694822752d0fbb6bf34c3f41648af1bd1b443cc3d68cc55106 \ + --hash=sha256:b7aa0de08bbbfcef6e49c107f9f397f5d4742548500f16e3e6c5e0b9e4ff0faa \ + --hash=sha256:b8adc912ecb6e4fd9df227ded66efaa6702f46a98e1403554be3c9c51d0ca920 \ + --hash=sha256:ba23caafd0e0c911bb7eab54e5cf69644af864d153e4b2abdab83ff0ef357ba1 \ + --hash=sha256:bda4572392ac1dff3fc67b6d9a4b1084e1637972e8135ad3788b4ce7cf0a90f5 \ + --hash=sha256:c017539e1196ea08f20aea3a4c473f758149b851678edd3d15773b4326decf83 \ + --hash=sha256:c1f20e8754abe312a701ee00d071ddd8502e9d97ca38fbc56204d14a9ffcb41c \ + --hash=sha256:c8841d44f3648b0662e99fc39ef8c248726ddfb4d1bfce4bdba982e51bb7e3f8 \ + --hash=sha256:cbb083a15ea968ad99e7da17d24632348d69e26534e83c69941f3020ed7536eb \ + --hash=sha256:cde70c97e4cc4e997e8fda2266e40a9bff7679c72ab4af6e15e81748a12882cc \ + --hash=sha256:cfa2e16993ba47e671a4e7ee1ad805f67b8d6744eb30a9d27ea0b07b3b7a22ed \ + --hash=sha256:d20765efa494a80a8fd91c4de8890f34de8e9f234da5516e8f34f55703cfb93d \ + --hash=sha256:d2844478690b5892159df0b2500e9d146dc8d3aa5b44e4564d05787b7330eca3 \ + --hash=sha256:d569730b2647c51a5ee68d67198aa9a78c7a55563d57b8cc1ca8d8c8377e7621 \ + --hash=sha256:daab231cf768937ce4675376ea3e214d399116d9867a6737372c31c58630bdfc \ + --hash=sha256:db00c604c1ae452f6092293bf230984d4f6cbb3ad905a9991e8cf680fd7d1523 \ + --hash=sha256:e058cc51d9d57b45801060af9f74765b95bedfc59fd6df1c7489ae0825126be5 \ + --hash=sha256:e28024e6e343353285cf99ae9c74210f0e89e47b2f0f3af7c72c4a9e89dc3ebc \ + --hash=sha256:e4a468ae1541614b5aa7b4f00254bce005ab7572fbb1fc764af4ee17d90fde7b \ + --hash=sha256:ea357657143f364a764b63b2b1ce12d77156d48a1f32def990b696d755acb629 \ + --hash=sha256:efd162e3bfc7812d75ebf2d0fb2783daee2407a92155af8a90650a6b0fa9342e \ + --hash=sha256:f02faeda66d531dc5f5356589afcf2a6bc41c8d00bc903efab60f9a2182b140d \ + --hash=sha256:f04286908654ffb05455254ebf72fe69473fc4560fc7ea49410df94dea6783a2 \ + --hash=sha256:f5197f864630162f008f5dfad3fceef32553c0fa7639eee1b8e280d924ed678e \ + --hash=sha256:f81f5340c8df50662abaf753ab07095901e40b934efb27da50032a4ae71c5a97 \ + --hash=sha256:fa8e3878a1857761d64f08a23b32140d29754a53f85f7c87186ced2b5b1b49cb \ + --hash=sha256:ff7326f36ed70d84c3fd62fb39bc6858f699640b8ab238c3cb8dafe1e200af59 + # via keras +optype[numpy]==0.14.0 \ + --hash=sha256:50d02edafd04edf2e5e27d6249760a51b2198adb9f6ffd778030b3d2806b026b \ + --hash=sha256:925cf060b7d1337647f880401f6094321e7d8e837533b8e159b9a92afa3157c6 + # via scipy-stubs +packaging==25.0 \ + --hash=sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484 \ + --hash=sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f # via # auditwheel # build + # keras # matplotlib # pytest -pillow==11.0.0 \ - --hash=sha256:00177a63030d612148e659b55ba99527803288cea7c75fb05766ab7981a8c1b7 \ - --hash=sha256:006bcdd307cc47ba43e924099a038cbf9591062e6c50e570819743f5607404f5 \ - --hash=sha256:084a07ef0821cfe4858fe86652fffac8e187b6ae677e9906e192aafcc1b69903 \ - --hash=sha256:0ae08bd8ffc41aebf578c2af2f9d8749d91f448b3bfd41d7d9ff573d74f2a6b2 \ - --hash=sha256:0e038b0745997c7dcaae350d35859c9715c71e92ffb7e0f4a8e8a16732150f38 \ - --hash=sha256:1187739620f2b365de756ce086fdb3604573337cc28a0d3ac4a01ab6b2d2a6d2 \ - --hash=sha256:16095692a253047fe3ec028e951fa4221a1f3ed3d80c397e83541a3037ff67c9 \ - --hash=sha256:1a61b54f87ab5786b8479f81c4b11f4d61702830354520837f8cc791ebba0f5f \ - --hash=sha256:1c1d72714f429a521d8d2d018badc42414c3077eb187a59579f28e4270b4b0fc \ - --hash=sha256:1e2688958a840c822279fda0086fec1fdab2f95bf2b717b66871c4ad9859d7e8 \ - --hash=sha256:20ec184af98a121fb2da42642dea8a29ec80fc3efbaefb86d8fdd2606619045d \ - --hash=sha256:21a0d3b115009ebb8ac3d2ebec5c2982cc693da935f4ab7bb5c8ebe2f47d36f2 \ - --hash=sha256:224aaa38177597bb179f3ec87eeefcce8e4f85e608025e9cfac60de237ba6316 \ - --hash=sha256:2679d2258b7f1192b378e2893a8a0a0ca472234d4c2c0e6bdd3380e8dfa21b6a \ - --hash=sha256:27a7860107500d813fcd203b4ea19b04babe79448268403172782754870dac25 \ - --hash=sha256:290f2cc809f9da7d6d622550bbf4c1e57518212da51b6a30fe8e0a270a5b78bd \ - --hash=sha256:2e46773dc9f35a1dd28bd6981332fd7f27bec001a918a72a79b4133cf5291dba \ - --hash=sha256:3107c66e43bda25359d5ef446f59c497de2b5ed4c7fdba0894f8d6cf3822dafc \ - --hash=sha256:375b8dd15a1f5d2feafff536d47e22f69625c1aa92f12b339ec0b2ca40263273 \ - --hash=sha256:45c566eb10b8967d71bf1ab8e4a525e5a93519e29ea071459ce517f6b903d7fa \ - --hash=sha256:499c3a1b0d6fc8213519e193796eb1a86a1be4b1877d678b30f83fd979811d1a \ - --hash=sha256:4ad70c4214f67d7466bea6a08061eba35c01b1b89eaa098040a35272a8efb22b \ - --hash=sha256:4b60c9520f7207aaf2e1d94de026682fc227806c6e1f55bba7606d1c94dd623a \ - --hash=sha256:5178952973e588b3f1360868847334e9e3bf49d19e169bbbdfaf8398002419ae \ - --hash=sha256:52a2d8323a465f84faaba5236567d212c3668f2ab53e1c74c15583cf507a0291 \ - --hash=sha256:598b4e238f13276e0008299bd2482003f48158e2b11826862b1eb2ad7c768b97 \ - --hash=sha256:5bd2d3bdb846d757055910f0a59792d33b555800813c3b39ada1829c372ccb06 \ - --hash=sha256:5c39ed17edea3bc69c743a8dd3e9853b7509625c2462532e62baa0732163a904 \ - --hash=sha256:5d203af30149ae339ad1b4f710d9844ed8796e97fda23ffbc4cc472968a47d0b \ - --hash=sha256:5ddbfd761ee00c12ee1be86c9c0683ecf5bb14c9772ddbd782085779a63dd55b \ - --hash=sha256:607bbe123c74e272e381a8d1957083a9463401f7bd01287f50521ecb05a313f8 \ - --hash=sha256:61b887f9ddba63ddf62fd02a3ba7add935d053b6dd7d58998c630e6dbade8527 \ - --hash=sha256:6619654954dc4936fcff82db8eb6401d3159ec6be81e33c6000dfd76ae189947 \ - --hash=sha256:674629ff60030d144b7bca2b8330225a9b11c482ed408813924619c6f302fdbb \ - --hash=sha256:6ec0d5af64f2e3d64a165f490d96368bb5dea8b8f9ad04487f9ab60dc4bb6003 \ - --hash=sha256:6f4dba50cfa56f910241eb7f883c20f1e7b1d8f7d91c750cd0b318bad443f4d5 \ - --hash=sha256:70fbbdacd1d271b77b7721fe3cdd2d537bbbd75d29e6300c672ec6bb38d9672f \ - --hash=sha256:72bacbaf24ac003fea9bff9837d1eedb6088758d41e100c1552930151f677739 \ - --hash=sha256:7326a1787e3c7b0429659e0a944725e1b03eeaa10edd945a86dead1913383944 \ - --hash=sha256:73853108f56df97baf2bb8b522f3578221e56f646ba345a372c78326710d3830 \ - --hash=sha256:73e3a0200cdda995c7e43dd47436c1548f87a30bb27fb871f352a22ab8dcf45f \ - --hash=sha256:75acbbeb05b86bc53cbe7b7e6fe00fbcf82ad7c684b3ad82e3d711da9ba287d3 \ - --hash=sha256:8069c5179902dcdce0be9bfc8235347fdbac249d23bd90514b7a47a72d9fecf4 \ - --hash=sha256:846e193e103b41e984ac921b335df59195356ce3f71dcfd155aa79c603873b84 \ - --hash=sha256:8594f42df584e5b4bb9281799698403f7af489fba84c34d53d1c4bfb71b7c4e7 \ - --hash=sha256:86510e3f5eca0ab87429dd77fafc04693195eec7fd6a137c389c3eeb4cfb77c6 \ - --hash=sha256:8853a3bf12afddfdf15f57c4b02d7ded92c7a75a5d7331d19f4f9572a89c17e6 \ - --hash=sha256:88a58d8ac0cc0e7f3a014509f0455248a76629ca9b604eca7dc5927cc593c5e9 \ - --hash=sha256:8ba470552b48e5835f1d23ecb936bb7f71d206f9dfeee64245f30c3270b994de \ - --hash=sha256:8c676b587da5673d3c75bd67dd2a8cdfeb282ca38a30f37950511766b26858c4 \ - --hash=sha256:8ec4a89295cd6cd4d1058a5e6aec6bf51e0eaaf9714774e1bfac7cfc9051db47 \ - --hash=sha256:94f3e1780abb45062287b4614a5bc0874519c86a777d4a7ad34978e86428b8dd \ - --hash=sha256:9a0f748eaa434a41fccf8e1ee7a3eed68af1b690e75328fd7a60af123c193b50 \ - --hash=sha256:a5629742881bcbc1f42e840af185fd4d83a5edeb96475a575f4da50d6ede337c \ - --hash=sha256:a65149d8ada1055029fcb665452b2814fe7d7082fcb0c5bed6db851cb69b2086 \ - --hash=sha256:b3c5ac4bed7519088103d9450a1107f76308ecf91d6dabc8a33a2fcfb18d0fba \ - --hash=sha256:b4fd7bd29610a83a8c9b564d457cf5bd92b4e11e79a4ee4716a63c959699b306 \ - --hash=sha256:bcd1fb5bb7b07f64c15618c89efcc2cfa3e95f0e3bcdbaf4642509de1942a699 \ - --hash=sha256:c12b5ae868897c7338519c03049a806af85b9b8c237b7d675b8c5e089e4a618e \ - --hash=sha256:c26845094b1af3c91852745ae78e3ea47abf3dbcd1cf962f16b9a5fbe3ee8488 \ - --hash=sha256:c6a660307ca9d4867caa8d9ca2c2658ab685de83792d1876274991adec7b93fa \ - --hash=sha256:c809a70e43c7977c4a42aefd62f0131823ebf7dd73556fa5d5950f5b354087e2 \ - --hash=sha256:c8b2351c85d855293a299038e1f89db92a2f35e8d2f783489c6f0b2b5f3fe8a3 \ - --hash=sha256:cb929ca942d0ec4fac404cbf520ee6cac37bf35be479b970c4ffadf2b6a1cad9 \ - --hash=sha256:d2c0a187a92a1cb5ef2c8ed5412dd8d4334272617f532d4ad4de31e0495bd923 \ - --hash=sha256:d69bfd8ec3219ae71bcde1f942b728903cad25fafe3100ba2258b973bd2bc1b2 \ - --hash=sha256:daffdf51ee5db69a82dd127eabecce20729e21f7a3680cf7cbb23f0829189790 \ - --hash=sha256:e58876c91f97b0952eb766123bfef372792ab3f4e3e1f1a2267834c2ab131734 \ - --hash=sha256:eda2616eb2313cbb3eebbe51f19362eb434b18e3bb599466a1ffa76a033fb916 \ - --hash=sha256:ee217c198f2e41f184f3869f3e485557296d505b5195c513b2bfe0062dc537f1 \ - --hash=sha256:f02541ef64077f22bf4924f225c0fd1248c168f86e4b7abdedd87d6ebaceab0f \ - --hash=sha256:f1b82c27e89fffc6da125d5eb0ca6e68017faf5efc078128cfaa42cf5cb38798 \ - --hash=sha256:fba162b8872d30fea8c52b258a542c5dfd7b235fb5cb352240c8d63b414013eb \ - --hash=sha256:fbbcb7b57dc9c794843e3d1258c0fbf0f48656d46ffe9e09b63bbd6e8cd5d0a2 \ - --hash=sha256:fcb4621042ac4b7865c179bb972ed0da0218a076dc1820ffc48b1d74c1e37fe9 + # tensorboard + # tensorflow + # wheel +pillow==12.0.0 \ + --hash=sha256:0869154a2d0546545cde61d1789a6524319fc1897d9ee31218eae7a60ccc5643 \ + --hash=sha256:09f2d0abef9e4e2f349305a4f8cc784a8a6c2f58a8c4892eea13b10a943bd26e \ + --hash=sha256:0b817e7035ea7f6b942c13aa03bb554fc44fea70838ea21f8eb31c638326584e \ + --hash=sha256:0fd00cac9c03256c8b2ff58f162ebcd2587ad3e1f2e397eab718c47e24d231cc \ + --hash=sha256:110486b79f2d112cf6add83b28b627e369219388f64ef2f960fef9ebaf54c642 \ + --hash=sha256:1979f4566bb96c1e50a62d9831e2ea2d1211761e5662afc545fa766f996632f6 \ + --hash=sha256:1ac11e8ea4f611c3c0147424eae514028b5e9077dd99ab91e1bd7bc33ff145e1 \ + --hash=sha256:1b1b133e6e16105f524a8dec491e0586d072948ce15c9b914e41cdadd209052b \ + --hash=sha256:1ee80a59f6ce048ae13cda1abf7fbd2a34ab9ee7d401c46be3ca685d1999a399 \ + --hash=sha256:21f241bdd5080a15bc86d3466a9f6074a9c2c2b314100dd896ac81ee6db2f1ba \ + --hash=sha256:266cd5f2b63ff316d5a1bba46268e603c9caf5606d44f38c2873c380950576ad \ + --hash=sha256:26d9f7d2b604cd23aba3e9faf795787456ac25634d82cd060556998e39c6fa47 \ + --hash=sha256:27f95b12453d165099c84f8a8bfdfd46b9e4bda9e0e4b65f0635430027f55739 \ + --hash=sha256:2c54c1a783d6d60595d3514f0efe9b37c8808746a66920315bfd34a938d7994b \ + --hash=sha256:2fa5f0b6716fc88f11380b88b31fe591a06c6315e955c096c35715788b339e3f \ + --hash=sha256:32ed80ea8a90ee3e6fa08c21e2e091bba6eda8eccc83dbc34c95169507a91f10 \ + --hash=sha256:3830c769decf88f1289680a59d4f4c46c72573446352e2befec9a8512104fa52 \ + --hash=sha256:38df9b4bfd3db902c9c2bd369bcacaf9d935b2fff73709429d95cc41554f7b3d \ + --hash=sha256:3adfb466bbc544b926d50fe8f4a4e6abd8c6bffd28a26177594e6e9b2b76572b \ + --hash=sha256:3e42edad50b6909089750e65c91aa09aaf1e0a71310d383f11321b27c224ed8a \ + --hash=sha256:4078242472387600b2ce8d93ade8899c12bf33fa89e55ec89fe126e9d6d5d9e9 \ + --hash=sha256:455247ac8a4cfb7b9bc45b7e432d10421aea9fc2e74d285ba4072688a74c2e9d \ + --hash=sha256:4cc6b3b2efff105c6a1656cfe59da4fdde2cda9af1c5e0b58529b24525d0a098 \ + --hash=sha256:4cf7fed4b4580601c4345ceb5d4cbf5a980d030fd5ad07c4d2ec589f95f09905 \ + --hash=sha256:5193fde9a5f23c331ea26d0cf171fbf67e3f247585f50c08b3e205c7aeb4589b \ + --hash=sha256:5269cc1caeedb67e6f7269a42014f381f45e2e7cd42d834ede3c703a1d915fe3 \ + --hash=sha256:53561a4ddc36facb432fae7a9d8afbfaf94795414f5cdc5fc52f28c1dca90371 \ + --hash=sha256:55f818bd74fe2f11d4d7cbc65880a843c4075e0ac7226bc1a23261dbea531953 \ + --hash=sha256:58eea5ebe51504057dd95c5b77d21700b77615ab0243d8152793dc00eb4faf01 \ + --hash=sha256:5d5c411a8eaa2299322b647cd932586b1427367fd3184ffbb8f7a219ea2041ca \ + --hash=sha256:6846bd2d116ff42cba6b646edf5bf61d37e5cbd256425fa089fee4ff5c07a99e \ + --hash=sha256:6ace95230bfb7cd79ef66caa064bbe2f2a1e63d93471c3a2e1f1348d9f22d6b7 \ + --hash=sha256:6e51b71417049ad6ab14c49608b4a24d8fb3fe605e5dfabfe523b58064dc3d27 \ + --hash=sha256:71db6b4c1653045dacc1585c1b0d184004f0d7e694c7b34ac165ca70c0838082 \ + --hash=sha256:7438839e9e053ef79f7112c881cef684013855016f928b168b81ed5835f3e75e \ + --hash=sha256:759de84a33be3b178a64c8ba28ad5c135900359e85fb662bc6e403ad4407791d \ + --hash=sha256:792a2c0be4dcc18af9d4a2dfd8a11a17d5e25274a1062b0ec1c2d79c76f3e7f8 \ + --hash=sha256:7d87ef5795da03d742bf49439f9ca4d027cde49c82c5371ba52464aee266699a \ + --hash=sha256:7dfb439562f234f7d57b1ac6bc8fe7f838a4bd49c79230e0f6a1da93e82f1fad \ + --hash=sha256:7fa22993bac7b77b78cae22bad1e2a987ddf0d9015c63358032f84a53f23cdc3 \ + --hash=sha256:805ebf596939e48dbb2e4922a1d3852cfc25c38160751ce02da93058b48d252a \ + --hash=sha256:82240051c6ca513c616f7f9da06e871f61bfd7805f566275841af15015b8f98d \ + --hash=sha256:87d4f8125c9988bfbed67af47dd7a953e2fc7b0cc1e7800ec6d2080d490bb353 \ + --hash=sha256:8d8ca2b210ada074d57fcee40c30446c9562e542fc46aedc19baf758a93532ee \ + --hash=sha256:8dc232e39d409036af549c86f24aed8273a40ffa459981146829a324e0848b4b \ + --hash=sha256:90387104ee8400a7b4598253b4c406f8958f59fcf983a6cea2b50d59f7d63d0b \ + --hash=sha256:905b0365b210c73afb0ebe9101a32572152dfd1c144c7e28968a331b9217b94a \ + --hash=sha256:99353a06902c2e43b43e8ff74ee65a7d90307d82370604746738a1e0661ccca7 \ + --hash=sha256:99a7f72fb6249302aa62245680754862a44179b545ded638cf1fef59befb57ef \ + --hash=sha256:9f0b04c6b8584c2c193babcccc908b38ed29524b29dd464bc8801bf10d746a3a \ + --hash=sha256:9fe611163f6303d1619bbcb653540a4d60f9e55e622d60a3108be0d5b441017a \ + --hash=sha256:a3475b96f5908b3b16c47533daaa87380c491357d197564e0ba34ae75c0f3257 \ + --hash=sha256:a6597ff2b61d121172f5844b53f21467f7082f5fb385a9a29c01414463f93b07 \ + --hash=sha256:a7921c5a6d31b3d756ec980f2f47c0cfdbce0fc48c22a39347a895f41f4a6ea4 \ + --hash=sha256:aa5129de4e174daccbc59d0a3b6d20eaf24417d59851c07ebb37aeb02947987c \ + --hash=sha256:aeaefa96c768fc66818730b952a862235d68825c178f1b3ffd4efd7ad2edcb7c \ + --hash=sha256:afbefa430092f71a9593a99ab6a4e7538bc9eabbf7bf94f91510d3503943edc4 \ + --hash=sha256:aff9e4d82d082ff9513bdd6acd4f5bd359f5b2c870907d2b0a9c5e10d40c88fe \ + --hash=sha256:b22bd8c974942477156be55a768f7aa37c46904c175be4e158b6a86e3a6b7ca8 \ + --hash=sha256:b290fd8aa38422444d4b50d579de197557f182ef1068b75f5aa8558638b8d0a5 \ + --hash=sha256:b2e4b27a6e15b04832fe9bf292b94b5ca156016bbc1ea9c2c20098a0320d6cf6 \ + --hash=sha256:b583dc9070312190192631373c6c8ed277254aa6e6084b74bdd0a6d3b221608e \ + --hash=sha256:b87843e225e74576437fd5b6a4c2205d422754f84a06942cfaf1dc32243e45a8 \ + --hash=sha256:bc91a56697869546d1b8f0a3ff35224557ae7f881050e99f615e0119bf934b4e \ + --hash=sha256:bd87e140e45399c818fac4247880b9ce719e4783d767e030a883a970be632275 \ + --hash=sha256:bde737cff1a975b70652b62d626f7785e0480918dece11e8fef3c0cf057351c3 \ + --hash=sha256:bdee52571a343d721fb2eb3b090a82d959ff37fc631e3f70422e0c2e029f3e76 \ + --hash=sha256:bee2a6db3a7242ea309aa7ee8e2780726fed67ff4e5b40169f2c940e7eb09227 \ + --hash=sha256:beeae3f27f62308f1ddbcfb0690bf44b10732f2ef43758f169d5e9303165d3f9 \ + --hash=sha256:c50f36a62a22d350c96e49ad02d0da41dbd17ddc2e29750dbdba4323f85eb4a5 \ + --hash=sha256:c607c90ba67533e1b2355b821fef6764d1dd2cbe26b8c1005ae84f7aea25ff79 \ + --hash=sha256:c7b2a63fd6d5246349f3d3f37b14430d73ee7e8173154461785e43036ffa96ca \ + --hash=sha256:c828a1ae702fc712978bda0320ba1b9893d99be0badf2647f693cc01cf0f04fa \ + --hash=sha256:c85de1136429c524e55cfa4e033b4a7940ac5c8ee4d9401cc2d1bf48154bbc7b \ + --hash=sha256:c98fa880d695de164b4135a52fd2e9cd7b7c90a9d8ac5e9e443a24a95ef9248e \ + --hash=sha256:cae81479f77420d217def5f54b5b9d279804d17e982e0f2fa19b1d1e14ab5197 \ + --hash=sha256:d034140032870024e6b9892c692fe2968493790dd57208b2c37e3fb35f6df3ab \ + --hash=sha256:d120c38a42c234dc9a8c5de7ceaaf899cf33561956acb4941653f8bdc657aa79 \ + --hash=sha256:d4827615da15cd59784ce39d3388275ec093ae3ee8d7f0c089b76fa87af756c2 \ + --hash=sha256:d49e2314c373f4c2b39446fb1a45ed333c850e09d0c59ac79b72eb3b95397363 \ + --hash=sha256:d52610d51e265a51518692045e372a4c363056130d922a7351429ac9f27e70b0 \ + --hash=sha256:d64317d2587c70324b79861babb9c09f71fbb780bad212018874b2c013d8600e \ + --hash=sha256:d77153e14b709fd8b8af6f66a3afbb9ed6e9fc5ccf0b6b7e1ced7b036a228782 \ + --hash=sha256:d7e091d464ac59d2c7ad8e7e08105eaf9dafbc3883fd7265ffccc2baad6ac925 \ + --hash=sha256:dd333073e0cacdc3089525c7df7d39b211bcdf31fc2824e49d01c6b6187b07d0 \ + --hash=sha256:e5d8efac84c9afcb40914ab49ba063d94f5dbdf5066db4482c66a992f47a3a3b \ + --hash=sha256:f135c702ac42262573fe9714dfe99c944b4ba307af5eb507abef1667e2cbbced \ + --hash=sha256:f13711b1a5ba512d647a0e4ba79280d3a9a045aaf7e0cc6fbe96b91d4cdf6b0c \ + --hash=sha256:f4f1231b7dec408e8670264ce63e9c71409d9583dd21d32c163e25213ee2a344 \ + --hash=sha256:fa3ed2a29a9e9d2d488b4da81dcb54720ac3104a20bf0bd273f1e4648aff5af9 \ + --hash=sha256:fb3096c30df99fd01c7bf8e544f392103d0795b9f98ba71a8054bcbf56b255f1 # via # -r build/test-requirements.txt # matplotlib -pluggy==1.5.0 \ - --hash=sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1 \ - --hash=sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669 + # tensorboard +pluggy==1.6.0 \ + --hash=sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3 \ + --hash=sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746 # via pytest portpicker==1.6.0 \ --hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \ --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa # via -r build/test-requirements.txt -psutil==5.9.8 \ - --hash=sha256:02615ed8c5ea222323408ceba16c60e99c3f91639b07da6373fb7e6539abc56d \ - --hash=sha256:05806de88103b25903dff19bb6692bd2e714ccf9e668d050d144012055cbca73 \ - --hash=sha256:26bd09967ae00920df88e0352a91cff1a78f8d69b3ecabbfe733610c0af486c8 \ - --hash=sha256:27cc40c3493bb10de1be4b3f07cae4c010ce715290a5be22b98493509c6299e2 \ - --hash=sha256:36f435891adb138ed3c9e58c6af3e2e6ca9ac2f365efe1f9cfef2794e6c93b4e \ - --hash=sha256:50187900d73c1381ba1454cf40308c2bf6f34268518b3f36a9b663ca87e65e36 \ - --hash=sha256:611052c4bc70432ec770d5d54f64206aa7203a101ec273a0cd82418c86503bb7 \ - --hash=sha256:6be126e3225486dff286a8fb9a06246a5253f4c7c53b475ea5f5ac934e64194c \ - --hash=sha256:7d79560ad97af658a0f6adfef8b834b53f64746d45b403f225b85c5c2c140eee \ - --hash=sha256:8cb6403ce6d8e047495a701dc7c5bd788add903f8986d523e3e20b98b733e421 \ - --hash=sha256:8db4c1b57507eef143a15a6884ca10f7c73876cdf5d51e713151c1236a0e68cf \ - --hash=sha256:aee678c8720623dc456fa20659af736241f575d79429a0e5e9cf88ae0605cc81 \ - --hash=sha256:bc56c2a1b0d15aa3eaa5a60c9f3f8e3e565303b465dbf57a1b730e7a2b9844e0 \ - --hash=sha256:bd1184ceb3f87651a67b2708d4c3338e9b10c5df903f2e3776b62303b26cb631 \ - --hash=sha256:d06016f7f8625a1825ba3732081d77c94589dca78b7a3fc072194851e88461a4 \ - --hash=sha256:d16bbddf0693323b8c6123dd804100241da461e41d6e332fb0ba6058f630f8c8 +protobuf==6.33.2 \ + --hash=sha256:1f8017c48c07ec5859106533b682260ba3d7c5567b1ca1f24297ce03384d1b4f \ + --hash=sha256:2981c58f582f44b6b13173e12bb8656711189c2a70250845f264b877f00b1913 \ + --hash=sha256:56dc370c91fbb8ac85bc13582c9e373569668a290aa2e66a590c2a0d35ddb9e4 \ + --hash=sha256:7109dcc38a680d033ffb8bf896727423528db9163be1b6a02d6a49606dcadbfe \ + --hash=sha256:7636aad9bb01768870266de5dc009de2d1b936771b38a793f73cbbf279c91c5c \ + --hash=sha256:87eb388bd2d0f78febd8f4c8779c79247b26a5befad525008e49a6955787ff3d \ + --hash=sha256:8cd7640aee0b7828b6d03ae518b5b4806fdfc1afe8de82f79c3454f8aef29872 \ + --hash=sha256:b5d3b5625192214066d99b2b605f5783483575656784de223f00a8d00754fc0e \ + --hash=sha256:d9b19771ca75935b3a4422957bc518b0cecb978b31d1dd12037b088f6bcc0e43 \ + --hash=sha256:fc2a0e8b05b180e5fc0dd1559fe8ebdae21a27e81ac77728fb6c42b12c7419b4 + # via + # tensorboard + # tensorflow +psutil==7.1.3 \ + --hash=sha256:0005da714eee687b4b8decd3d6cc7c6db36215c9e74e5ad2264b90c3df7d92dc \ + --hash=sha256:1068c303be3a72f8e18e412c5b2a8f6d31750fb152f9cb106b54090296c9d251 \ + --hash=sha256:18349c5c24b06ac5612c0428ec2a0331c26443d259e2a0144a9b24b4395b58fa \ + --hash=sha256:19644c85dcb987e35eeeaefdc3915d059dac7bd1167cdcdbf27e0ce2df0c08c0 \ + --hash=sha256:2bdbcd0e58ca14996a42adf3621a6244f1bb2e2e528886959c72cf1e326677ab \ + --hash=sha256:31d77fcedb7529f27bb3a0472bea9334349f9a04160e8e6e5020f22c59893264 \ + --hash=sha256:3792983e23b69843aea49c8f5b8f115572c5ab64c153bada5270086a2123c7e7 \ + --hash=sha256:3bb428f9f05c1225a558f53e30ccbad9930b11c3fc206836242de1091d3e7dd3 \ + --hash=sha256:56d974e02ca2c8eb4812c3f76c30e28836fffc311d55d979f1465c1feeb2b68b \ + --hash=sha256:6c86281738d77335af7aec228328e944b30930899ea760ecf33a4dba66be5e74 \ + --hash=sha256:8f33a3702e167783a9213db10ad29650ebf383946e91bc77f28a5eb083496bc9 \ + --hash=sha256:95ef04cf2e5ba0ab9eaafc4a11eaae91b44f4ef5541acd2ee91d9108d00d59a7 \ + --hash=sha256:ad81425efc5e75da3f39b3e636293360ad8d0b49bed7df824c79764fb4ba9b8b \ + --hash=sha256:b403da1df4d6d43973dc004d19cee3b848e998ae3154cc8097d139b77156c353 \ + --hash=sha256:bc31fa00f1fbc3c3802141eede66f3a2d51d89716a194bf2cd6fc68310a19880 \ + --hash=sha256:bd0d69cee829226a761e92f28140bec9a5ee9d5b4fb4b0cc589068dbfff559b1 \ + --hash=sha256:c525ffa774fe4496282fb0b1187725793de3e7c6b29e41562733cae9ada151ee \ + --hash=sha256:f39c2c19fe824b47484b96f9692932248a54c43799a84282cfe58d05a6449efd \ + --hash=sha256:fac9cd332c67f4422504297889da5ab7e05fd11e3c4392140f7370f4208ded1f # via portpicker -pyelftools==0.31 \ - --hash=sha256:c774416b10310156879443b81187d182d8d9ee499660380e645918b50bc88f99 \ - --hash=sha256:f52de7b3c7e8c64c8abc04a79a1cf37ac5fb0b8a49809827130b858944840607 +pyelftools==0.32 \ + --hash=sha256:013df952a006db5e138b1edf6d8a68ecc50630adbd0d83a2d41e7f846163d738 \ + --hash=sha256:6de90ee7b8263e740c8715a925382d4099b354f29ac48ea40d840cf7aa14ace5 # via auditwheel -pygments==2.18.0 \ - --hash=sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199 \ - --hash=sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a - # via rich -pyparsing==3.1.2 \ - --hash=sha256:a1bac0ce561155ecc3ed78ca94d3c9378656ad4c94c1270de543f621420f94ad \ - --hash=sha256:f9db75911801ed778fe61bb643079ff86601aca99fcae6345aa67292038fb742 +pygments==2.19.2 \ + --hash=sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887 \ + --hash=sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b + # via + # pytest + # rich +pyparsing==3.2.5 \ + --hash=sha256:2df8d5b7b2802ef88e8d016a2eb9c7aeaa923529cd251ed0fe4608275d4105b6 \ + --hash=sha256:e38a4f02064cf41fe6593d328d0512495ad1f3d8a91c4f73fc401b3079a59a5e # via matplotlib -pyproject-hooks==1.1.0 \ - --hash=sha256:4b37730834edbd6bd37f26ece6b44802fb1c1ee2ece0e54ddff8bfc06db86965 \ - --hash=sha256:7ceeefe9aec63a1064c18d939bdc3adf2d8aa1988a510afec15151578b232aa2 +pyproject-hooks==1.2.0 \ + --hash=sha256:1e859bd5c40fae9448642dd871adf459e5e2084186e8d2c2a79a824c970da1f8 \ + --hash=sha256:9e5c6bfa8dcc30091c74b0cf803c81fdd29d94f01992a7707bc97babb1141913 # via build -pytest==8.2.0 \ - --hash=sha256:1733f0620f6cda4095bbf0d9ff8022486e91892245bb9e7d5542c018f612f233 \ - --hash=sha256:d507d4482197eac0ba2bae2e9babf0672eb333017bcedaa5fb1a3d42c1174b3f - # via pytest-xdist -pytest-xdist==3.6.1 \ - --hash=sha256:9ed4adfb68a016610848639bb7e02c9352d5d9f03d04809919e2dafc3be4cca7 \ - --hash=sha256:ead156a4db231eec769737f57668ef58a2084a34b2e55c4a8fa20d861107300d +pytest==8.4.2 \ + --hash=sha256:86c0d0b93306b961d58d62a4db4879f27fe25513d4b969df351abdddb3c30e01 \ + --hash=sha256:872f880de3fc3a5bdc88a11b39c9710c3497a547cfa9320bc3c5e62fbf272e79 + # via + # -r build/test-requirements.txt + # pytest-xdist +pytest-xdist==3.8.0 \ + --hash=sha256:202ca578cfeb7370784a8c33d6d05bc6e13b4f25b5053c30a152269fd10f0b88 \ + --hash=sha256:7e578125ec9bc6050861aa93f2d59f1d8d085595d6551c2c90b6f4fad8d3a9f1 # via -r build/test-requirements.txt python-dateutil==2.9.0.post0 \ --hash=sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3 \ --hash=sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427 # via matplotlib -rich==13.7.1 \ - --hash=sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222 \ - --hash=sha256:9be308cb1fe2f1f57d67ce99e95af38a1e2bc71ad9813b0e247cf7ffbcc3a432 +requests==2.32.5 \ + --hash=sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6 \ + --hash=sha256:dbba0bac56e100853db0ea71b82b4dfd5fe2bf6d3754a8893c3af500cec7d7cf + # via tensorflow +rich==14.2.0 \ + --hash=sha256:73ff50c7c0c1c77c8243079283f4edb376f0f6442433aecb8ce7e6d0b92d1fe4 \ + --hash=sha256:76bc51fe2e57d2b1be1f96c524b890b816e334ab4c1e45888799bfaab0021edd + # via + # -r build/test-requirements.txt + # keras +scipy==1.16.3 ; python_version <= "3.12" \ + --hash=sha256:0151a0749efeaaab78711c78422d413c583b8cdd2011a3c1d6c794938ee9fdb2 \ + --hash=sha256:01e87659402762f43bd2fee13370553a17ada367d42e7487800bf2916535aecb \ + --hash=sha256:03192a35e661470197556de24e7cb1330d84b35b94ead65c46ad6f16f6b28f2a \ + --hash=sha256:0553371015692a898e1aa858fed67a3576c34edefa6b7ebdb4e9dde49ce5c203 \ + --hash=sha256:062246acacbe9f8210de8e751b16fc37458213f124bef161a5a02c7a39284304 \ + --hash=sha256:0c3b4dd3d9b08dbce0f3440032c52e9e2ab9f96ade2d3943313dfe51a7056959 \ + --hash=sha256:0c623a54f7b79dd88ef56da19bc2873afec9673a48f3b85b18e4d402bdd29a5a \ + --hash=sha256:16b8bc35a4cc24db80a0ec836a9286d0e31b2503cb2fd7ff7fb0e0374a97081d \ + --hash=sha256:1fb2472e72e24d1530debe6ae078db70fb1605350c88a3d14bc401d6306dbffe \ + --hash=sha256:21d9d6b197227a12dcbf9633320a4e34c6b0e51c57268df255a0942983bac562 \ + --hash=sha256:2a207a6ce9c24f1951241f4693ede2d393f59c07abc159b2cb2be980820e01fb \ + --hash=sha256:2b71d93c8a9936046866acebc915e2af2e292b883ed6e2cbe5c34beb094b82d9 \ + --hash=sha256:2d1ae2cf0c350e7705168ff2429962a89ad90c2d49d1dd300686d8b2a5af22fc \ + --hash=sha256:3a4c460301fb2cffb7f88528f30b3127742cff583603aa7dc964a52c463b385d \ + --hash=sha256:3d4a07a8e785d80289dfe66b7c27d8634a773020742ec7187b85ccc4b0e7b686 \ + --hash=sha256:40be6cf99e68b6c4321e9f8782e7d5ff8265af28ef2cd56e9c9b2638fa08ad97 \ + --hash=sha256:4aff59800a3b7f786b70bfd6ab551001cb553244988d7d6b8299cb1ea653b353 \ + --hash=sha256:50a3dbf286dbc7d84f176f9a1574c705f277cb6565069f88f60db9eafdbe3ee2 \ + --hash=sha256:532fb5ad6a87e9e9cd9c959b106b73145a03f04c7d57ea3e6f6bb60b86ab0876 \ + --hash=sha256:53c3844d527213631e886621df5695d35e4f6a75f620dca412bcd292f6b87d78 \ + --hash=sha256:56edc65510d1331dae01ef9b658d428e33ed48b4f77b1d51caf479a0253f96dc \ + --hash=sha256:57d01cb6f85e34f0946b33caa66e892aae072b64b034183f3d87c4025802a119 \ + --hash=sha256:5803c5fadd29de0cf27fa08ccbfe7a9e5d741bf63e4ab1085437266f12460ff9 \ + --hash=sha256:6020470b9d00245926f2d5bb93b119ca0340f0d564eb6fbaad843eaebf9d690f \ + --hash=sha256:63d3cdacb8a824a295191a723ee5e4ea7768ca5ca5f2838532d9f2e2b3ce2135 \ + --hash=sha256:663b8d66a8748051c3ee9c96465fb417509315b99c71550fda2591d7dd634234 \ + --hash=sha256:72d1717fd3b5e6ec747327ce9bda32d5463f472c9dce9f54499e81fbd50245a1 \ + --hash=sha256:7dc1360c06535ea6116a2220f760ae572db9f661aba2d88074fe30ec2aa1ff88 \ + --hash=sha256:7f68154688c515cdb541a31ef8eb66d8cd1050605be9dcd74199cbd22ac739bc \ + --hash=sha256:81fc5827606858cf71446a5e98715ba0e11f0dbc83d71c7409d05486592a45d6 \ + --hash=sha256:875555ce62743e1d54f06cdf22c1e0bc47b91130ac40fe5d783b6dfa114beeb6 \ + --hash=sha256:8b3c820ddb80029fe9f43d61b81d8b488d3ef8ca010d15122b152db77dc94c22 \ + --hash=sha256:8be1ca9170fcb6223cc7c27f4305d680ded114a1567c0bd2bfcbf947d1b17511 \ + --hash=sha256:8d09d72dc92742988b0e7750bddb8060b0c7079606c0d24a8cc8e9c9c11f9079 \ + --hash=sha256:9452781bd879b14b6f055b26643703551320aa8d79ae064a71df55c00286a184 \ + --hash=sha256:96491a6a54e995f00a28a3c3badfff58fd093bf26cd5fb34a2188c8c756a3a2c \ + --hash=sha256:9b9c9c07b6d56a35777a1b4cc8966118fb16cfd8daf6743867d17d36cfad2d40 \ + --hash=sha256:a8a26c78ef223d3e30920ef759e25625a0ecdd0d60e5a8818b7513c3e5384cf2 \ + --hash=sha256:aadd23f98f9cb069b3bd64ddc900c4d277778242e961751f77a8cb5c4b946fb0 \ + --hash=sha256:b7180967113560cca57418a7bc719e30366b47959dd845a93206fbed693c867e \ + --hash=sha256:b7c5f1bda1354d6a19bc6af73a649f8285ca63ac6b52e64e658a5a11d4d69800 \ + --hash=sha256:b81c27fc41954319a943d43b20e07c40bdcd3ff7cf013f4fb86286faefe546c4 \ + --hash=sha256:bb61878c18a470021fb515a843dc7a76961a8daceaaaa8bad1332f1bf4b54657 \ + --hash=sha256:bea0a62734d20d67608660f69dcda23e7f90fb4ca20974ab80b6ed40df87a005 \ + --hash=sha256:c5192722cffe15f9329a3948c4b1db789fbb1f05c97899187dcf009b283aea70 \ + --hash=sha256:c97176013d404c7346bf57874eaac5187d969293bf40497140b0a2b2b7482e07 \ + --hash=sha256:cd13e354df9938598af2be05822c323e97132d5e6306b83a3b4ee6724c6e522e \ + --hash=sha256:d2ec56337675e61b312179a1ad124f5f570c00f920cc75e1000025451b88241c \ + --hash=sha256:d3837938ae715fc0fe3c39c0202de3a8853aff22ca66781ddc2ade7554b7e2cc \ + --hash=sha256:d9f48cafc7ce94cf9b15c6bffdc443a81a27bf7075cf2dcd5c8b40f85d10c4e7 \ + --hash=sha256:da7763f55885045036fabcebd80144b757d3db06ab0861415d1c3b7c69042146 \ + --hash=sha256:deb3841c925eeddb6afc1e4e4a45e418d19ec7b87c5df177695224078e8ec733 \ + --hash=sha256:e1d27cbcb4602680a49d787d90664fa4974063ac9d4134813332a8c53dbe667c \ + --hash=sha256:e5d42a9472e7579e473879a1990327830493a7047506d58d73fc429b84c1d49d \ + --hash=sha256:e7efa2681ea410b10dde31a52b18b0154d66f2485328830e45fdf183af5aefc6 \ + --hash=sha256:eab43fae33a0c39006a88096cd7b4f4ef545ea0447d250d5ac18202d40b6611d \ + --hash=sha256:f2622206f5559784fa5c4b53a950c3c7c1cf3e84ca1b9c4b6c03f062f289ca26 \ + --hash=sha256:f379b54b77a597aa7ee5e697df0d66903e41b9c85a6dd7946159e356319158e8 \ + --hash=sha256:f667a4542cc8917af1db06366d3f78a5c8e83badd56409f94d1eac8d8d9133fa \ + --hash=sha256:fb4b29f4cf8cc5a8d628bc8d8e26d12d7278cd1f219f22698a378c3d67db5e4b \ + --hash=sha256:ffa6eea95283b2b8079b821dc11f50a17d0571c92b43e2b5b12764dc5f9b285d + # via + # -r build/requirements.in + # jaxlib +scipy-stubs==1.16.3.3 \ + --hash=sha256:af47578875d5557567225a16ec1b9b38a48c4c4377d92396413ebd65406c44ee \ + --hash=sha256:f6316b36cd0fb272c994ae5b10c4a73c644a7e156ed8d32bcd9c35303d0e1b7e # via -r build/test-requirements.txt -scipy==1.13.1 \ - --hash=sha256:017367484ce5498445aade74b1d5ab377acdc65e27095155e448c88497755a5d \ - --hash=sha256:095a87a0312b08dfd6a6155cbbd310a8c51800fc931b8c0b84003014b874ed3c \ - --hash=sha256:20335853b85e9a49ff7572ab453794298bcf0354d8068c5f6775a0eabf350aca \ - --hash=sha256:27e52b09c0d3a1d5b63e1105f24177e544a222b43611aaf5bc44d4a0979e32f9 \ - --hash=sha256:2831f0dc9c5ea9edd6e51e6e769b655f08ec6db6e2e10f86ef39bd32eb11da54 \ - --hash=sha256:2ac65fb503dad64218c228e2dc2d0a0193f7904747db43014645ae139c8fad16 \ - --hash=sha256:392e4ec766654852c25ebad4f64e4e584cf19820b980bc04960bca0b0cd6eaa2 \ - --hash=sha256:436bbb42a94a8aeef855d755ce5a465479c721e9d684de76bf61a62e7c2b81d5 \ - --hash=sha256:45484bee6d65633752c490404513b9ef02475b4284c4cfab0ef946def50b3f59 \ - --hash=sha256:54f430b00f0133e2224c3ba42b805bfd0086fe488835effa33fa291561932326 \ - --hash=sha256:5713f62f781eebd8d597eb3f88b8bf9274e79eeabf63afb4a737abc6c84ad37b \ - --hash=sha256:5d72782f39716b2b3509cd7c33cdc08c96f2f4d2b06d51e52fb45a19ca0c86a1 \ - --hash=sha256:637e98dcf185ba7f8e663e122ebf908c4702420477ae52a04f9908707456ba4d \ - --hash=sha256:8335549ebbca860c52bf3d02f80784e91a004b71b059e3eea9678ba994796a24 \ - --hash=sha256:949ae67db5fa78a86e8fa644b9a6b07252f449dcf74247108c50e1d20d2b4627 \ - --hash=sha256:a014c2b3697bde71724244f63de2476925596c24285c7a637364761f8710891c \ - --hash=sha256:a78b4b3345f1b6f68a763c6e25c0c9a23a9fd0f39f5f3d200efe8feda560a5fa \ - --hash=sha256:cdd7dacfb95fea358916410ec61bbc20440f7860333aee6d882bb8046264e949 \ - --hash=sha256:cfa31f1def5c819b19ecc3a8b52d28ffdcc7ed52bb20c9a7589669dd3c250989 \ - --hash=sha256:d533654b7d221a6a97304ab63c41c96473ff04459e404b83275b60aa8f4b7004 \ - --hash=sha256:d605e9c23906d1994f55ace80e0125c587f96c020037ea6aa98d01b4bd2e222f \ - --hash=sha256:de3ade0e53bc1f21358aa74ff4830235d716211d7d077e340c7349bc3542e884 \ - --hash=sha256:e89369d27f9e7b0884ae559a3a956e77c02114cc60a6058b4e5011572eea9299 \ - --hash=sha256:eccfa1906eacc02de42d70ef4aecea45415f5be17e72b61bafcfd329bdc52e94 \ - --hash=sha256:f26264b282b9da0952a024ae34710c2aff7d27480ee91a2e82b7b7073c24722f - # via -r build/requirements.in -six==1.16.0 \ - --hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \ - --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 - # via python-dateutil +six==1.17.0 \ + --hash=sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274 \ + --hash=sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81 + # via + # astunparse + # google-pasta + # python-dateutil + # tensorflow sortedcontainers==2.4.0 \ --hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \ --hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0 # via hypothesis -typing-extensions==4.12.0rc1 \ - --hash=sha256:be199d06d8f09ca2c9425e3aa04a9afba33e892fe079dea959e72df7f8442343 \ - --hash=sha256:f933a7b288a919ca97adbff656e52ff81f7ff25d98a2aabb9355ca4090f772fe - # via etils -wheel==0.43.0 \ - --hash=sha256:465ef92c69fa5c5da2d1cf8ac40559a8c940886afcef87dcf14b9470862f1d85 \ - --hash=sha256:55c570405f142630c6b9f72fe09d9b67cf1477fcf543ae5b8dcb1f5b7377da81 - # via -r build/test-requirements.txt -zipp==3.18.2 \ - --hash=sha256:6278d9ddbcfb1f1089a88fde84481528b07b0e10474e09dcfe53dad4069fa059 \ - --hash=sha256:dce197b859eb796242b0622af1b8beb0a722d52aa2f57133ead08edd5bf5374e +tensorboard==2.20.0 \ + --hash=sha256:9dc9f978cb84c0723acf9a345d96c184f0293d18f166bb8d59ee098e6cfaaba6 + # via tensorflow +tensorboard-data-server==0.7.2 \ + --hash=sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb \ + --hash=sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60 \ + --hash=sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530 + # via tensorboard +tensorflow==2.20.0 ; python_version < "3.14" \ + --hash=sha256:02a0293d94f5c8b7125b66abf622cc4854a33ae9d618a0d41309f95e091bbaea \ + --hash=sha256:0deb5c583dfc53b54fd158a194ce0087b406bb6518af400ca3809735e4548ec3 \ + --hash=sha256:1590cbf87b6bcbd34d8e9ad70d0c696135e0aa71be31803b27358cf7ed63f8fc \ + --hash=sha256:197f0b613b38c0da5c6a12a8295ad4a05c78b853835dae8e0f9dfae3ce9ce8a5 \ + --hash=sha256:25265b0bc527e0d54b1e9cc60c44a24f44a809fe27666b905f0466471f9c52ec \ + --hash=sha256:28bc33759249c98eabcee9debd24e74506bbe29ac139e050cf0c74aa9888ebdf \ + --hash=sha256:2bfbfb3dd0e22bffc45fe1e922390d27753e99261fab8a882e802cf98a0e078f \ + --hash=sha256:3e9568c8efcb05c0266be223e3269c62ebf7ad3498f156438311735f6fa5ced5 \ + --hash=sha256:47c88e05a07f1ead4977b4894b3ecd4d8075c40191065afc4fd9355c9db3d926 \ + --hash=sha256:481499fd0f824583de8945be61d5e827898cdaa4f5ea1bc2cc28ca2ccff8229e \ + --hash=sha256:4a69ac2c2ce20720abf3abf917b4e86376326c0976fcec3df330e184b81e4088 \ + --hash=sha256:52b122f0232fd7ab10f28d537ce08470d0b6dcac7fff9685432daac7f8a06c8f \ + --hash=sha256:5f964016c5035d09b85a246a6b739be89282a7839743f3ea63640224f0c63aee \ + --hash=sha256:5fa3729b0126f75a99882b89fb7d536515721eda8014a63e259e780ba0a37372 \ + --hash=sha256:7551558a48c2e2f6c32a1537f06c654a9df1408a1c18e7b99c3caafbd03edfe3 \ + --hash=sha256:7abd7f3a010e0d354dc804182372779a722d474c4d8a3db8f4a3f5baef2a591e \ + --hash=sha256:a66cbd1b19209d3fbc45cbea80de92514ba455434013937251d65d444779783c \ + --hash=sha256:c25edad45e8cb9e76366f7a8c835279f9169028d610f3b52ce92d332a1b05438 \ + --hash=sha256:dd71a7e7c3270239f4185915e8f2c5d39608c5e18973d6e1d101b153993841eb \ + --hash=sha256:e5f169f8f5130ab255bbe854c5f0ae152e93d3d1ac44f42cb1866003b81a5357 + # via -r build/nonfreethreading-requirements.txt +tensorstore==0.1.80 \ + --hash=sha256:04c29d979eb8b8ee48f873dc13d2701bfd49425500ffc5b848e4ec55b2548281 \ + --hash=sha256:07e4a84bacf70b78305831897068a9b5ad30326e63bbeb92c4bf7e565fcf5e9e \ + --hash=sha256:1113a6982fc0fa8dda8fcc0495715e647ac3360909a86ff13f2e04564f82d54a \ + --hash=sha256:189d924eaec394c9331e284a9c513ed583e336472a925823b5151cb26f41d091 \ + --hash=sha256:1b2b2ed0051dfab7e25295b14e6620520729e6e2ddf505f98c8d3917569614bf \ + --hash=sha256:246641a8780ee5e04e88bc95c8e31faac6471bab1180d1f5cdc9804b29a77c04 \ + --hash=sha256:4158fe76b96f62d12a37d7868150d836e089b5280b2bdd363c43c5d651f10e26 \ + --hash=sha256:46136fe42ee6dd835d957db37073058aea0b78fdfbe2975941640131b7740824 \ + --hash=sha256:4baee67fce95f29f593fbab4866119347115eaace887732aa92cfcbb9e6b0748 \ + --hash=sha256:53fd121ccd332bc4cc397f7af45889360c668b43dc3ff6bc3264df0f9886c11a \ + --hash=sha256:6b7c5dd434bba4ee08fe46bbbdb25c60dd3d47ccb4b8561a9751cf1526da52b8 \ + --hash=sha256:6c8dbbdd31cbb28eccfb23dbbd4218fe67bfc32e9cb452875a485b81031c949d \ + --hash=sha256:7451b30f99d9f31a2b9d70e6ef61815713dc782c58c6d817f91781341e4dac05 \ + --hash=sha256:8cd11027b5a8b66db8d344085a31a1666c78621dac27039c4d571bc4974804a1 \ + --hash=sha256:9c088e8c9f67c266ef4dae3703bd617f7c0cb0fd98e99c4500692e38a4328140 \ + --hash=sha256:a92505189731fcb03f1c69a84ea4460abb24204bfac1f339448a0621e7def77c \ + --hash=sha256:acb8d52fadcefafef4ef8ecca3fc99b1d0e3c5c5a888766484c3e39f050be7f5 \ + --hash=sha256:b193a7a1c4f455a61e60ed2dd67271a3daab0910ddb4bd9db51390d1b36d9996 \ + --hash=sha256:bc28a58c580253a526a4b6d239d18181ef96f1e285a502dbb03ff15eeec07a5b \ + --hash=sha256:c0529afab3800749dd245843d3bf0d061a109a8edb77fb345f476e8bccda51b8 \ + --hash=sha256:d2b353b0bd53fedd77fc5a12a1c1a91cacc3cf59e3dd785529c5a54b31d1c7b1 \ + --hash=sha256:de63843706fdfe9565a45567238c5b1e55a0b28bbde6524200b31d29043a9a16 \ + --hash=sha256:e93df6d34ff5f0f6be245f4d29b99a7c1eef8ad91b50686adf57a5eeea99cb74 \ + --hash=sha256:f65dfaf9e737a41389e29a5a2ea52ca5d14c8d6f48b402c723d800cd16d322b0 \ + --hash=sha256:f8b51d7e685bbb63f6becd7d2ac8634d5ab67ec7e53038e597182e2db2c7aa90 + # via -r build/nonfreethreading-requirements.txt +termcolor==3.2.0 \ + --hash=sha256:610e6456feec42c4bcd28934a8c87a06c3fa28b01561d46aa09a9881b8622c58 \ + --hash=sha256:a10343879eba4da819353c55cb8049b0933890c2ebf9ad5d3ecd2bb32ea96ea6 + # via tensorflow +typing-extensions==4.15.0 \ + --hash=sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466 \ + --hash=sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548 + # via + # etils + # grpcio + # optree + # optype + # tensorflow +urllib3==2.6.2 \ + --hash=sha256:016f9c98bb7e98085cb2b4b17b87d2c702975664e4f060c6532e64d1c1a5e797 \ + --hash=sha256:ec21cddfe7724fc7cb4ba4bea7aa8e2ef36f607a4bab81aa6ce42a13dc3f03dd + # via requests +werkzeug==3.1.4 \ + --hash=sha256:2ad50fb9ed09cc3af22c54698351027ace879a0b60a3b5edf5730b2f7d876905 \ + --hash=sha256:cd3cd98b1b92dc3b7b3995038826c68097dcb16f9baa63abe35f20eafeb9fe5e + # via tensorboard +wheel==0.46.1 \ + --hash=sha256:f796f65d72750ccde090663e466d0ca37cd72b62870f7520b96d34cdc07d86d8 \ + --hash=sha256:fd477efb5da0f7df1d3c76c73c14394002c844451bd63229d8570f376f5e6a38 + # via + # -r build/requirements.in + # astunparse +wrapt==2.0.1 \ + --hash=sha256:09c7476ab884b74dce081ad9bfd07fe5822d8600abade571cb1f66d5fc915af6 \ + --hash=sha256:0e17283f533a0d24d6e5429a7d11f250a58d28b4ae5186f8f47853e3e70d2590 \ + --hash=sha256:115cae4beed3542e37866469a8a1f2b9ec549b4463572b000611e9946b86e6f6 \ + --hash=sha256:1218573502a8235bb8a7ecaed12736213b22dcde9feab115fa2989d42b5ded45 \ + --hash=sha256:17fb85fa4abc26a5184d93b3efd2dcc14deb4b09edcdb3535a536ad34f0b4dba \ + --hash=sha256:1e9b121e9aeb15df416c2c960b8255a49d44b4038016ee17af03975992d03931 \ + --hash=sha256:1f186e26ea0a55f809f232e92cc8556a0977e00183c3ebda039a807a42be1494 \ + --hash=sha256:1fdbb34da15450f2b1d735a0e969c24bdb8d8924892380126e2a293d9902078c \ + --hash=sha256:23097ed8bc4c93b7bf36fa2113c6c733c976316ce0ee2c816f64ca06102034ef \ + --hash=sha256:2879af909312d0baf35f08edeea918ee3af7ab57c37fe47cb6a373c9f2749c7b \ + --hash=sha256:2afa23318136709c4b23d87d543b425c399887b4057936cd20386d5b1422b6fa \ + --hash=sha256:2da620b31a90cdefa9cd0c2b661882329e2e19d1d7b9b920189956b76c564d75 \ + --hash=sha256:35cdbd478607036fee40273be8ed54a451f5f23121bd9d4be515158f9498f7ad \ + --hash=sha256:36982b26f190f4d737f04a492a68accbfc6fa042c3f42326fdfbb6c5b7a20a31 \ + --hash=sha256:3793ac154afb0e5b45d1233cb94d354ef7a983708cc3bb12563853b1d8d53747 \ + --hash=sha256:386fb54d9cd903ee0012c09291336469eb7b244f7183d40dc3e86a16a4bace62 \ + --hash=sha256:3cd1a4bd9a7a619922a8557e1318232e7269b5fb69d4ba97b04d20450a6bf970 \ + --hash=sha256:3d32794fe940b7000f0519904e247f902f0149edbe6316c710a8562fb6738841 \ + --hash=sha256:3d366aa598d69416b5afedf1faa539fac40c1d80a42f6b236c88c73a3c8f2d41 \ + --hash=sha256:3e271346f01e9c8b1130a6a3b0e11908049fe5be2d365a5f402778049147e7e9 \ + --hash=sha256:3f373a4ab5dbc528a94334f9fe444395b23c2f5332adab9ff4ea82f5a9e33bc1 \ + --hash=sha256:3fa272ca34332581e00bf7773e993d4f632594eb2d1b0b162a9038df0fd971dd \ + --hash=sha256:47434236c396d04875180171ee1f3815ca1eada05e24a1ee99546320d54d1d1b \ + --hash=sha256:47b0f8bafe90f7736151f61482c583c86b0693d80f075a58701dd1549b0010a9 \ + --hash=sha256:4811e15d88ee62dbf5c77f2c3ff3932b1e3ac92323ba3912f51fc4016ce81ecf \ + --hash=sha256:49989061a9977a8cbd6d20f2efa813f24bf657c6990a42967019ce779a878dbf \ + --hash=sha256:4ae879acc449caa9ed43fc36ba08392b9412ee67941748d31d94e3cedb36628c \ + --hash=sha256:4b55cacc57e1dc2d0991dbe74c6419ffd415fb66474a02335cb10efd1aa3f84f \ + --hash=sha256:4d2ce1bf1a48c5277d7969259232b57645aae5686dba1eaeade39442277afbca \ + --hash=sha256:4da7384b0e5d4cae05c97cd6f94faaf78cc8b0f791fc63af43436d98c4ab37bb \ + --hash=sha256:4e54bbf554ee29fcceee24fa41c4d091398b911da6e7f5d7bffda963c9aed2e1 \ + --hash=sha256:50844efc8cdf63b2d90cd3d62d4947a28311e6266ce5235a219d21b195b4ec2c \ + --hash=sha256:5a4939eae35db6b6cec8e7aa0e833dcca0acad8231672c26c2a9ab7a0f8ac9c8 \ + --hash=sha256:5dc1b852337c6792aa111ca8becff5bacf576bf4a0255b0f05eb749da6a1643e \ + --hash=sha256:5e53b428f65ece6d9dad23cb87e64506392b720a0b45076c05354d27a13351a1 \ + --hash=sha256:61c4956171c7434634401db448371277d07032a81cc21c599c22953374781395 \ + --hash=sha256:641e94e789b5f6b4822bb8d8ebbdfc10f4e4eae7756d648b717d980f657a9eb9 \ + --hash=sha256:64b103acdaa53b7caf409e8d45d39a8442fe6dcfec6ba3f3d141e0cc2b5b4dbd \ + --hash=sha256:68424221a2dc00d634b54f92441914929c5ffb1c30b3b837343978343a3512a3 \ + --hash=sha256:6bd1a18f5a797fe740cb3d7a0e853a8ce6461cc62023b630caec80171a6b8097 \ + --hash=sha256:6c72328f668cf4c503ffcf9434c2b71fdd624345ced7941bc6693e61bbe36bef \ + --hash=sha256:6d2d947d266d99a1477cd005b23cbd09465276e302515e122df56bb9511aca1b \ + --hash=sha256:7164a55f5e83a9a0b031d3ffab4d4e36bbec42e7025db560f225489fa929e509 \ + --hash=sha256:7b219cb2182f230676308cdcacd428fa837987b89e4b7c5c9025088b8a6c9faf \ + --hash=sha256:7d539241e87b650cbc4c3ac9f32c8d1ac8a54e510f6dca3f6ab60dcfd48c9b10 \ + --hash=sha256:7de3cc939be0e1174969f943f3b44e0d79b6f9a82198133a5b7fc6cc92882f16 \ + --hash=sha256:8330b42d769965e96e01fa14034b28a2a7600fbf7e8f0cc90ebb36d492c993e4 \ + --hash=sha256:837e31620e06b16030b1d126ed78e9383815cbac914693f54926d816d35d8edf \ + --hash=sha256:83ce30937f0ba0d28818807b303a412440c4b63e39d3d8fc036a94764b728c92 \ + --hash=sha256:85df8d92158cb8f3965aecc27cf821461bb5f40b450b03facc5d9f0d4d6ddec6 \ + --hash=sha256:8639b843c9efd84675f1e100ed9e99538ebea7297b62c4b45a7042edb84db03e \ + --hash=sha256:89a82053b193837bf93c0f8a57ded6e4b6d88033a499dadff5067e912c2a41e9 \ + --hash=sha256:8bacfe6e001749a3b64db47bcf0341da757c95959f592823a93931a422395013 \ + --hash=sha256:8ec3303e8a81932171f455f792f8df500fc1a09f20069e5c16bd7049ab4e8e38 \ + --hash=sha256:90897ea1cf0679763b62e79657958cd54eae5659f6360fc7d2ccc6f906342183 \ + --hash=sha256:908f8c6c71557f4deaa280f55d0728c3bca0960e8c3dd5ceeeafb3c19942719d \ + --hash=sha256:91bcc576260a274b169c3098e9a3519fb01f2989f6d3d386ef9cbf8653de1374 \ + --hash=sha256:9219a1d946a9b32bb23ccae66bdb61e35c62773ce7ca6509ceea70f344656b7b \ + --hash=sha256:949520bccc1fa227274da7d03bf238be15389cd94e32e4297b92337df9b7a349 \ + --hash=sha256:98d873ed6c8b4ee2418f7afce666751854d6d03e3c0ec2a399bb039cd2ae89db \ + --hash=sha256:9c9c635e78497cacb81e84f8b11b23e0aacac7a136e73b8e5b2109a1d9fc468f \ + --hash=sha256:9ca66b38dd642bf90c59b6738af8070747b610115a39af2498535f62b5cdc1c3 \ + --hash=sha256:a453257f19c31b31ba593c30d997d6e5be39e3b5ad9148c2af5a7314061c63eb \ + --hash=sha256:a52f93d95c8d38fed0669da2ebdb0b0376e895d84596a976c15a9eb45e3eccb3 \ + --hash=sha256:a9a83618c4f0757557c077ef71d708ddd9847ed66b7cc63416632af70d3e2308 \ + --hash=sha256:ab594f346517010050126fcd822697b25a7031d815bb4fbc238ccbe568216489 \ + --hash=sha256:ad3ee9d0f254851c71780966eb417ef8e72117155cff04821ab9b60549694a55 \ + --hash=sha256:aea9c7224c302bc8bfc892b908537f56c430802560e827b75ecbde81b604598b \ + --hash=sha256:b4c2e3d777e38e913b8ce3a6257af72fb608f86a1df471cb1d4339755d0a807c \ + --hash=sha256:b667189cf8efe008f55bbda321890bef628a67ab4147ebf90d182f2dadc78790 \ + --hash=sha256:b89ef9223d665ab255ae42cc282d27d69704d94be0deffc8b9d919179a609684 \ + --hash=sha256:be9e84e91d6497ba62594158d3d31ec0486c60055c49179edc51ee43d095f79c \ + --hash=sha256:bf4cb76f36be5de950ce13e22e7fdf462b35b04665a12b64f3ac5c1bbbcf3728 \ + --hash=sha256:bfb5539005259f8127ea9c885bdc231978c06b7a980e63a8a61c8c4c979719d0 \ + --hash=sha256:c046781d422f0830de6329fa4b16796096f28a92c8aef3850674442cdcb87b7f \ + --hash=sha256:c1be685ac7700c966b8610ccc63c3187a72e33cab53526a27b2a285a662cd4f7 \ + --hash=sha256:c1c91405fcf1d501fa5d55df21e58ea49e6b879ae829f1039faaf7e5e509b41e \ + --hash=sha256:c235095d6d090aa903f1db61f892fffb779c1eaeb2a50e566b52001f7a0f66ed \ + --hash=sha256:c4012a2bd37059d04f8209916aa771dfb564cccb86079072bdcd48a308b6a5c5 \ + --hash=sha256:c5ef2f2b8a53b7caee2f797ef166a390fef73979b15778a4a153e4b5fedce8fa \ + --hash=sha256:c654eafb01afac55246053d67a4b9a984a3567c3808bb7df2f8de1c1caba2e1c \ + --hash=sha256:c8d60527d1ecfc131426b10d93ab5d53e08a09c5fa0175f6b21b3252080c70a9 \ + --hash=sha256:c9e850f5b7fc67af856ff054c71690d54fa940c3ef74209ad9f935b4f66a0233 \ + --hash=sha256:cbeb0971e13b4bd81d34169ed57a6dda017328d1a22b62fda45e1d21dd06148f \ + --hash=sha256:d1a8a09a004ef100e614beec82862d11fc17d601092c3599afd22b1f36e4137e \ + --hash=sha256:d67956c676be5a24102c7407a71f4126d30de2a569a1c7871c9f3cabc94225d7 \ + --hash=sha256:d6cc985b9c8b235bd933990cdbf0f891f8e010b65a3911f7a55179cd7b0fc57b \ + --hash=sha256:d7b822c61ed04ee6ad64bc90d13368ad6eb094db54883b5dde2182f67a7f22c0 \ + --hash=sha256:df0b6d3b95932809c5b3fecc18fda0f1e07452d05e2662a0b35548985f256e28 \ + --hash=sha256:e042d653a4745be832d5aa190ff80ee4f02c34b21f4b785745eceacd0907b815 \ + --hash=sha256:e2f84e9af2060e3904a32cea9bb6db23ce3f91cfd90c6b426757cf7cc01c45c7 \ + --hash=sha256:e3612dc06b436968dfb9142c62e5dfa9eb5924f91120b3c8ff501ad878f90eb3 \ + --hash=sha256:e505629359cb5f751e16e30cf3f91a1d3ddb4552480c205947da415d597f7ac2 \ + --hash=sha256:e60690ba71a57424c8d9ff28f8d006b7ad7772c22a4af432188572cd7fa004a1 \ + --hash=sha256:e76e3f91f864e89db8b8d2a8311d57df93f01ad6bb1e9b9976d1f2e83e18315c \ + --hash=sha256:eb7cffe572ad0a141a7886a1d2efa5bef0bf7fe021deeea76b3ab334d2c38218 \ + --hash=sha256:ec65a78fbd9d6f083a15d7613b2800d5663dbb6bb96003899c834beaa68b242c \ + --hash=sha256:eda8e4ecd662d48c28bb86be9e837c13e45c58b8300e43ba3c9b4fa9900302f7 \ + --hash=sha256:f26f8e2ca19564e2e1fdbb6a0e47f36e0efbab1acc31e15471fad88f828c75f6 \ + --hash=sha256:f49027b0b9503bf6c8cdc297ca55006b80c2f5dd36cecc72c6835ab6e10e8a25 \ + --hash=sha256:f73f9f7a0ebd0db139253d27e5fc8d2866ceaeef19c30ab5d69dcbe35e1a6981 \ + --hash=sha256:fa4184e74197af3adad3c889a1af95b53bb0466bced92ea99a0c014e48323eec \ + --hash=sha256:fb1a5b72cbd751813adc02ef01ada0b0d05d3dcbc32976ce189a1279d80ad4a2 \ + --hash=sha256:fb3a86e703868561c5cad155a15c36c716e1ab513b7065bd2ac8ed353c503333 \ + --hash=sha256:fc007fdf480c77301ab1afdbb6ab22a5deee8885f3b1ed7afcb7e5e84a0e27be \ + --hash=sha256:fe21b118b9f58859b5ebaa4b130dee18669df4bd111daad082b7beb8799ad16b \ + --hash=sha256:fec0d993ecba3991645b4857837277469c8cc4c554a7e24d064d1ca291cfb81f + # via tensorflow +zipp==3.23.0 \ + --hash=sha256:071652d6115ed432f5ce1d34c336c0adfd6a884660d1e9712a256d3d3bd4b14e \ + --hash=sha256:a07157588a12518c9d4034df3fbbee09c814741a33ff63c05fa29d26a2404166 # via etils -zstandard==0.22.0 \ - --hash=sha256:11f0d1aab9516a497137b41e3d3ed4bbf7b2ee2abc79e5c8b010ad286d7464bd \ - --hash=sha256:1958100b8a1cc3f27fa21071a55cb2ed32e9e5df4c3c6e661c193437f171cba2 \ - --hash=sha256:1a90ba9a4c9c884bb876a14be2b1d216609385efb180393df40e5172e7ecf356 \ - --hash=sha256:1d43501f5f31e22baf822720d82b5547f8a08f5386a883b32584a185675c8fbf \ - --hash=sha256:23d2b3c2b8e7e5a6cb7922f7c27d73a9a615f0a5ab5d0e03dd533c477de23004 \ - --hash=sha256:2612e9bb4977381184bb2463150336d0f7e014d6bb5d4a370f9a372d21916f69 \ - --hash=sha256:275df437ab03f8c033b8a2c181e51716c32d831082d93ce48002a5227ec93019 \ - --hash=sha256:2ac9957bc6d2403c4772c890916bf181b2653640da98f32e04b96e4d6fb3252a \ - --hash=sha256:2b11ea433db22e720758cba584c9d661077121fcf60ab43351950ded20283440 \ - --hash=sha256:2fdd53b806786bd6112d97c1f1e7841e5e4daa06810ab4b284026a1a0e484c0b \ - --hash=sha256:33591d59f4956c9812f8063eff2e2c0065bc02050837f152574069f5f9f17775 \ - --hash=sha256:36a47636c3de227cd765e25a21dc5dace00539b82ddd99ee36abae38178eff9e \ - --hash=sha256:39b2853efc9403927f9065cc48c9980649462acbdf81cd4f0cb773af2fd734bc \ - --hash=sha256:3db41c5e49ef73641d5111554e1d1d3af106410a6c1fb52cf68912ba7a343a0d \ - --hash=sha256:445b47bc32de69d990ad0f34da0e20f535914623d1e506e74d6bc5c9dc40bb09 \ - --hash=sha256:466e6ad8caefb589ed281c076deb6f0cd330e8bc13c5035854ffb9c2014b118c \ - --hash=sha256:48f260e4c7294ef275744210a4010f116048e0c95857befb7462e033f09442fe \ - --hash=sha256:4ac59d5d6910b220141c1737b79d4a5aa9e57466e7469a012ed42ce2d3995e88 \ - --hash=sha256:53866a9d8ab363271c9e80c7c2e9441814961d47f88c9bc3b248142c32141d94 \ - --hash=sha256:589402548251056878d2e7c8859286eb91bd841af117dbe4ab000e6450987e08 \ - --hash=sha256:68953dc84b244b053c0d5f137a21ae8287ecf51b20872eccf8eaac0302d3e3b0 \ - --hash=sha256:6c25b8eb733d4e741246151d895dd0308137532737f337411160ff69ca24f93a \ - --hash=sha256:7034d381789f45576ec3f1fa0e15d741828146439228dc3f7c59856c5bcd3292 \ - --hash=sha256:73a1d6bd01961e9fd447162e137ed949c01bdb830dfca487c4a14e9742dccc93 \ - --hash=sha256:8226a33c542bcb54cd6bd0a366067b610b41713b64c9abec1bc4533d69f51e70 \ - --hash=sha256:888196c9c8893a1e8ff5e89b8f894e7f4f0e64a5af4d8f3c410f0319128bb2f8 \ - --hash=sha256:88c5b4b47a8a138338a07fc94e2ba3b1535f69247670abfe422de4e0b344aae2 \ - --hash=sha256:8a1b2effa96a5f019e72874969394edd393e2fbd6414a8208fea363a22803b45 \ - --hash=sha256:93e1856c8313bc688d5df069e106a4bc962eef3d13372020cc6e3ebf5e045202 \ - --hash=sha256:9501f36fac6b875c124243a379267d879262480bf85b1dbda61f5ad4d01b75a3 \ - --hash=sha256:959665072bd60f45c5b6b5d711f15bdefc9849dd5da9fb6c873e35f5d34d8cfb \ - --hash=sha256:a1d67d0d53d2a138f9e29d8acdabe11310c185e36f0a848efa104d4e40b808e4 \ - --hash=sha256:a493d470183ee620a3df1e6e55b3e4de8143c0ba1b16f3ded83208ea8ddfd91d \ - --hash=sha256:a7ccf5825fd71d4542c8ab28d4d482aace885f5ebe4b40faaa290eed8e095a4c \ - --hash=sha256:a88b7df61a292603e7cd662d92565d915796b094ffb3d206579aaebac6b85d5f \ - --hash=sha256:a97079b955b00b732c6f280d5023e0eefe359045e8b83b08cf0333af9ec78f26 \ - --hash=sha256:d22fdef58976457c65e2796e6730a3ea4a254f3ba83777ecfc8592ff8d77d303 \ - --hash=sha256:d75f693bb4e92c335e0645e8845e553cd09dc91616412d1d4650da835b5449df \ - --hash=sha256:d8593f8464fb64d58e8cb0b905b272d40184eac9a18d83cf8c10749c3eafcd7e \ - --hash=sha256:d8fff0f0c1d8bc5d866762ae95bd99d53282337af1be9dc0d88506b340e74b73 \ - --hash=sha256:de20a212ef3d00d609d0b22eb7cc798d5a69035e81839f549b538eff4105d01c \ - --hash=sha256:e9e9d4e2e336c529d4c435baad846a181e39a982f823f7e4495ec0b0ec8538d2 \ - --hash=sha256:f058a77ef0ece4e210bb0450e68408d4223f728b109764676e1a13537d056bb0 \ - --hash=sha256:f1a4b358947a65b94e2501ce3e078bbc929b039ede4679ddb0460829b12f7375 \ - --hash=sha256:f9b2cde1cd1b2a10246dbc143ba49d942d14fb3d2b4bccf4618d475c65464912 \ - --hash=sha256:fe3390c538f12437b859d815040763abc728955a52ca6ff9c5d4ac707c4ad98e - # via -r build/requirements.in +zstandard==0.25.0 ; python_version < "3.14" \ + --hash=sha256:011d388c76b11a0c165374ce660ce2c8efa8e5d87f34996aa80f9c0816698b64 \ + --hash=sha256:01582723b3ccd6939ab7b3a78622c573799d5d8737b534b86d0e06ac18dbde4a \ + --hash=sha256:05353cef599a7b0b98baca9b068dd36810c3ef0f42bf282583f438caf6ddcee3 \ + --hash=sha256:05df5136bc5a011f33cd25bc9f506e7426c0c9b3f9954f056831ce68f3b6689f \ + --hash=sha256:06acb75eebeedb77b69048031282737717a63e71e4ae3f77cc0c3b9508320df6 \ + --hash=sha256:07b527a69c1e1c8b5ab1ab14e2afe0675614a09182213f21a0717b62027b5936 \ + --hash=sha256:0bbc9a0c65ce0eea3c34a691e3c4b6889f5f3909ba4822ab385fab9057099431 \ + --hash=sha256:0be7622c37c183406f3dbf0cba104118eb16a4ea7359eeb5752f0794882fc250 \ + --hash=sha256:106281ae350e494f4ac8a80470e66d1fe27e497052c8d9c3b95dc4cf1ade81aa \ + --hash=sha256:10ef2a79ab8e2974e2075fb984e5b9806c64134810fac21576f0668e7ea19f8f \ + --hash=sha256:1673b7199bbe763365b81a4f3252b8e80f44c9e323fc42940dc8843bfeaf9851 \ + --hash=sha256:172de1f06947577d3a3005416977cce6168f2261284c02080e7ad0185faeced3 \ + --hash=sha256:181eb40e0b6a29b3cd2849f825e0fa34397f649170673d385f3598ae17cca2e9 \ + --hash=sha256:1869da9571d5e94a85a5e8d57e4e8807b175c9e4a6294e3b66fa4efb074d90f6 \ + --hash=sha256:19796b39075201d51d5f5f790bf849221e58b48a39a5fc74837675d8bafc7362 \ + --hash=sha256:1cd5da4d8e8ee0e88be976c294db744773459d51bb32f707a0f166e5ad5c8649 \ + --hash=sha256:1f3689581a72eaba9131b1d9bdbfe520ccd169999219b41000ede2fca5c1bfdb \ + --hash=sha256:1f830a0dac88719af0ae43b8b2d6aef487d437036468ef3c2ea59c51f9d55fd5 \ + --hash=sha256:223415140608d0f0da010499eaa8ccdb9af210a543fac54bce15babbcfc78439 \ + --hash=sha256:22a06c5df3751bb7dc67406f5374734ccee8ed37fc5981bf1ad7041831fa1137 \ + --hash=sha256:22a086cff1b6ceca18a8dd6096ec631e430e93a8e70a9ca5efa7561a00f826fa \ + --hash=sha256:23ebc8f17a03133b4426bcc04aabd68f8236eb78c3760f12783385171b0fd8bd \ + --hash=sha256:25f8f3cd45087d089aef5ba3848cd9efe3ad41163d3400862fb42f81a3a46701 \ + --hash=sha256:2b6bd67528ee8b5c5f10255735abc21aa106931f0dbaf297c7be0c886353c3d0 \ + --hash=sha256:2e54296a283f3ab5a26fc9b8b5d4978ea0532f37b231644f367aa588930aa043 \ + --hash=sha256:3756b3e9da9b83da1796f8809dd57cb024f838b9eeafde28f3cb472012797ac1 \ + --hash=sha256:37daddd452c0ffb65da00620afb8e17abd4adaae6ce6310702841760c2c26860 \ + --hash=sha256:3a39c94ad7866160a4a46d772e43311a743c316942037671beb264e395bdd611 \ + --hash=sha256:3b870ce5a02d4b22286cf4944c628e0f0881b11b3f14667c1d62185a99e04f53 \ + --hash=sha256:3c83b0188c852a47cd13ef3bf9209fb0a77fa5374958b8c53aaa699398c6bd7b \ + --hash=sha256:4203ce3b31aec23012d3a4cf4a2ed64d12fea5269c49aed5e4c3611b938e4088 \ + --hash=sha256:457ed498fc58cdc12fc48f7950e02740d4f7ae9493dd4ab2168a47c93c31298e \ + --hash=sha256:474d2596a2dbc241a556e965fb76002c1ce655445e4e3bf38e5477d413165ffa \ + --hash=sha256:4b14abacf83dfb5c25eb4e4a79520de9e7e205f72c9ee7702f91233ae57d33a2 \ + --hash=sha256:4b6d83057e713ff235a12e73916b6d356e3084fd3d14ced499d84240f3eecee0 \ + --hash=sha256:4d441506e9b372386a5271c64125f72d5df6d2a8e8a2a45a0ae09b03cb781ef7 \ + --hash=sha256:4f187a0bb61b35119d1926aee039524d1f93aaf38a9916b8c4b78ac8514a0aaf \ + --hash=sha256:51526324f1b23229001eb3735bc8c94f9c578b1bd9e867a0a646a3b17109f388 \ + --hash=sha256:53e08b2445a6bc241261fea89d065536f00a581f02535f8122eba42db9375530 \ + --hash=sha256:53f94448fe5b10ee75d246497168e5825135d54325458c4bfffbaafabcc0a577 \ + --hash=sha256:5a56ba0db2d244117ed744dfa8f6f5b366e14148e00de44723413b2f3938a902 \ + --hash=sha256:5f1ad7bf88535edcf30038f6919abe087f606f62c00a87d7e33e7fc57cb69fcc \ + --hash=sha256:5f5e4c2a23ca271c218ac025bd7d635597048b366d6f31f420aaeb715239fc98 \ + --hash=sha256:6a573a35693e03cf1d67799fd01b50ff578515a8aeadd4595d2a7fa9f3ec002a \ + --hash=sha256:6c0e5a65158a7946e7a7affa6418878ef97ab66636f13353b8502d7ea03c8097 \ + --hash=sha256:6dffecc361d079bb48d7caef5d673c88c8988d3d33fb74ab95b7ee6da42652ea \ + --hash=sha256:7030defa83eef3e51ff26f0b7bfb229f0204b66fe18e04359ce3474ac33cbc09 \ + --hash=sha256:7149623bba7fdf7e7f24312953bcf73cae103db8cae49f8154dd1eadc8a29ecb \ + --hash=sha256:72d35d7aa0bba323965da807a462b0966c91608ef3a48ba761678cb20ce5d8b7 \ + --hash=sha256:75ffc32a569fb049499e63ce68c743155477610532da1eb38e7f24bf7cd29e74 \ + --hash=sha256:7713e1179d162cf5c7906da876ec2ccb9c3a9dcbdffef0cc7f70c3667a205f0b \ + --hash=sha256:78228d8a6a1c177a96b94f7e2e8d012c55f9c760761980da16ae7546a15a8e9b \ + --hash=sha256:7b3c3a3ab9daa3eed242d6ecceead93aebbb8f5f84318d82cee643e019c4b73b \ + --hash=sha256:809c5bcb2c67cd0ed81e9229d227d4ca28f82d0f778fc5fea624a9def3963f91 \ + --hash=sha256:81dad8d145d8fd981b2962b686b2241d3a1ea07733e76a2f15435dfb7fb60150 \ + --hash=sha256:85304a43f4d513f5464ceb938aa02c1e78c2943b29f44a750b48b25ac999a049 \ + --hash=sha256:89c4b48479a43f820b749df49cd7ba2dbc2b1b78560ecb5ab52985574fd40b27 \ + --hash=sha256:8e735494da3db08694d26480f1493ad2cf86e99bdd53e8e9771b2752a5c0246a \ + --hash=sha256:913cbd31a400febff93b564a23e17c3ed2d56c064006f54efec210d586171c00 \ + --hash=sha256:9174f4ed06f790a6869b41cba05b43eeb9a35f8993c4422ab853b705e8112bbd \ + --hash=sha256:9300d02ea7c6506f00e627e287e0492a5eb0371ec1670ae852fefffa6164b072 \ + --hash=sha256:933b65d7680ea337180733cf9e87293cc5500cc0eb3fc8769f4d3c88d724ec5c \ + --hash=sha256:9654dbc012d8b06fc3d19cc825af3f7bf8ae242226df5f83936cb39f5fdc846c \ + --hash=sha256:98750a309eb2f020da61e727de7d7ba3c57c97cf6213f6f6277bb7fb42a8e065 \ + --hash=sha256:99c0c846e6e61718715a3c9437ccc625de26593fea60189567f0118dc9db7512 \ + --hash=sha256:a1a4ae2dec3993a32247995bdfe367fc3266da832d82f8438c8570f989753de1 \ + --hash=sha256:a3f79487c687b1fc69f19e487cd949bf3aae653d181dfb5fde3bf6d18894706f \ + --hash=sha256:a4089a10e598eae6393756b036e0f419e8c1d60f44a831520f9af41c14216cf2 \ + --hash=sha256:a51ff14f8017338e2f2e5dab738ce1ec3b5a851f23b18c1ae1359b1eecbee6df \ + --hash=sha256:a5a419712cf88862a45a23def0ae063686db3d324cec7edbe40509d1a79a0aab \ + --hash=sha256:a9ec8c642d1ec73287ae3e726792dd86c96f5681eb8df274a757bf62b750eae7 \ + --hash=sha256:aaf21ba8fb76d102b696781bddaa0954b782536446083ae3fdaa6f16b25a1c4b \ + --hash=sha256:ab85470ab54c2cb96e176f40342d9ed41e58ca5733be6a893b730e7af9c40550 \ + --hash=sha256:b9af1fe743828123e12b41dd8091eca1074d0c1569cc42e6e1eee98027f2bbd0 \ + --hash=sha256:bfc4e20784722098822e3eee42b8e576b379ed72cca4a7cb856ae733e62192ea \ + --hash=sha256:bfd06b1c5584b657a2892a6014c2f4c20e0db0208c159148fa78c65f7e0b0277 \ + --hash=sha256:c19bcdd826e95671065f8692b5a4aa95c52dc7a02a4c5a0cac46deb879a017a2 \ + --hash=sha256:c2ba942c94e0691467ab901fc51b6f2085ff48f2eea77b1a48240f011e8247c7 \ + --hash=sha256:c8e167d5adf59476fa3e37bee730890e389410c354771a62e3c076c86f9f7778 \ + --hash=sha256:ca54090275939dc8ec5dea2d2afb400e0f83444b2fc24e07df7fdef677110859 \ + --hash=sha256:d7541afd73985c630bafcd6338d2518ae96060075f9463d7dc14cfb33514383d \ + --hash=sha256:d8c56bb4e6c795fc77d74d8e8b80846e1fb8292fc0b5060cd8131d522974b751 \ + --hash=sha256:da469dc041701583e34de852d8634703550348d5822e66a0c827d39b05365b12 \ + --hash=sha256:daab68faadb847063d0c56f361a289c4f268706b598afbf9ad113cbe5c38b6b2 \ + --hash=sha256:e05ab82ea7753354bb054b92e2f288afb750e6b439ff6ca78af52939ebbc476d \ + --hash=sha256:e09bb6252b6476d8d56100e8147b803befa9a12cea144bbe629dd508800d1ad0 \ + --hash=sha256:e29f0cf06974c899b2c188ef7f783607dbef36da4c242eb6c82dcd8b512855e3 \ + --hash=sha256:e59fdc271772f6686e01e1b3b74537259800f57e24280be3f29c8a0deb1904dd \ + --hash=sha256:e7360eae90809efd19b886e59a09dad07da4ca9ba096752e61a2e03c8aca188e \ + --hash=sha256:e96594a5537722fdfb79951672a2a63aec5ebfb823e7560586f7484819f2a08f \ + --hash=sha256:ea9d54cc3d8064260114a0bbf3479fc4a98b21dffc89b3459edd506b69262f6e \ + --hash=sha256:ec996f12524f88e151c339688c3897194821d7f03081ab35d31d1e12ec975e94 \ + --hash=sha256:f27662e4f7dbf9f9c12391cb37b4c4c3cb90ffbd3b1fb9284dadbbb8935fa708 \ + --hash=sha256:f373da2c1757bb7f1acaf09369cdc1d51d84131e50d5fa9863982fd626466313 \ + --hash=sha256:f5aeea11ded7320a84dcdd62a3d95b5186834224a9e55b92ccae35d21a8b63d4 \ + --hash=sha256:f604efd28f239cc21b3adb53eb061e2a205dc164be408e553b41ba2ffe0ca15c \ + --hash=sha256:f67e8f1a324a900e75b5e28ffb152bcac9fbed1cc7b43f99cd90f395c4375344 \ + --hash=sha256:fd7a5004eb1980d3cefe26b2685bcb0b17989901a70a1040d1ac86f1d898c551 \ + --hash=sha256:ffef5a74088f1e09947aecf91011136665152e0b4b359c42be3373897fb39b01 + # via -r build/nonfreethreading-requirements.txt # The following packages are considered to be unsafe in a requirements file: -setuptools==76.0.0 \ - --hash=sha256:199466a166ff664970d0ee145839f5582cb9bca7a0a3a2e795b6a9cb2308e9c6 \ - --hash=sha256:43b4ee60e10b0d0ee98ad11918e114c70701bc6051662a9a675a0496c1a158f4 +setuptools==80.9.0 \ + --hash=sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922 \ + --hash=sha256:f36b47402ecde768dbfafc46e8e4207b4360c654f1f3bb84475f0a28628fb19c # via # -r build/requirements.in - # -r build/test-requirements.txt + # tensorboard + # tensorflow diff --git a/build/requirements_lock_3_12.txt b/build/requirements_lock_3_12.txt index 0436ab6dd486..c6aa89903a08 100644 --- a/build/requirements_lock_3_12.txt +++ b/build/requirements_lock_3_12.txt @@ -4,693 +4,1754 @@ # # bazel run //build:requirements.update # -absl-py==2.1.0 \ - --hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \ - --hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff - # via -r build/test-requirements.txt -attrs==23.2.0 \ - --hash=sha256:935dc3b529c262f6cf76e50877d35a4bd3c1de194fd41f47a2b7ae8f19971f30 \ - --hash=sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1 +--index-url https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple + +absl-py==2.3.1 \ + --hash=sha256:a97820526f7fbfd2ec1bce83f3f25e3a14840dac0d8e02a0b71cd75db3f77fc9 \ + --hash=sha256:eeecf07f0c2a93ace0772c92e596ace6d3d3996c042b2128459aaae2a76de11d + # via + # -r build/test-requirements.txt + # keras + # tensorboard + # tensorflow +astunparse==1.6.3 \ + --hash=sha256:5ad93a8456f0d084c3456d059fd9a92cce667963232cbf763eac3bc5b7940872 \ + --hash=sha256:c2652417f2c8b5bb325c885ae329bdf3f86424075c4fd1a128674bc6fba4b8e8 + # via tensorflow +attrs==25.4.0 \ + --hash=sha256:16d5969b87f0859ef33a48b35d55ac1be6e42ae49d5e853b597db70c35c57e11 \ + --hash=sha256:adcf7e2a1fb3b36ac48d97835bb6d8ade15b8dcce26aba8bf1d14847b57a3373 # via hypothesis -auditwheel==6.1.0 \ - --hash=sha256:3bdc686e774cf9e355e924b0fe5a562d55caa385d72234ffe7b81b378dba360f \ - --hash=sha256:e52f734861859e3743eb29fcac7da9c4921a1e4bea58f954b52f2926f8e9e364 +auditwheel==6.5.0 \ + --hash=sha256:4fbcbd5854054bb1dd7870db03727b871b96b18147db57259561c058603987d7 \ + --hash=sha256:e08d2eede0259be6feff597d041c06175026e93248a1a97143acc52c57714d80 # via -r build/test-requirements.txt -build==1.2.1 \ - --hash=sha256:526263f4870c26f26c433545579475377b2b7588b6f1eac76a001e873ae3e19d \ - --hash=sha256:75e10f767a433d9a86e50d83f418e83efc18ede923ee5ff7df93b6cb0306c5d4 - # via -r build/test-requirements.txt -cloudpickle==3.0.0 \ - --hash=sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7 \ - --hash=sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882 +build==1.3.0 \ + --hash=sha256:698edd0ea270bde950f53aed21f3a0135672206f3911e0176261a31e0e07b397 \ + --hash=sha256:7145f0b5061ba90a1500d60bd1b13ca0a8a4cebdd0cc16ed8adf1c0e739f43b4 + # via -r build/requirements.in +certifi==2025.11.12 \ + --hash=sha256:97de8790030bbd5c2d96b7ec782fc2f7820ef8dba6db909ccf95449f2d062d4b \ + --hash=sha256:d8ab5478f2ecd78af242878415affce761ca6bc54a22a27e026d7c25357c3316 + # via requests +charset-normalizer==3.4.4 \ + --hash=sha256:027f6de494925c0ab2a55eab46ae5129951638a49a34d87f4c3eda90f696b4ad \ + --hash=sha256:077fbb858e903c73f6c9db43374fd213b0b6a778106bc7032446a8e8b5b38b93 \ + --hash=sha256:0a98e6759f854bd25a58a73fa88833fba3b7c491169f86ce1180c948ab3fd394 \ + --hash=sha256:0d3d8f15c07f86e9ff82319b3d9ef6f4bf907608f53fe9d92b28ea9ae3d1fd89 \ + --hash=sha256:0f04b14ffe5fdc8c4933862d8306109a2c51e0704acfa35d51598eb45a1e89fc \ + --hash=sha256:11d694519d7f29d6cd09f6ac70028dba10f92f6cdd059096db198c283794ac86 \ + --hash=sha256:194f08cbb32dc406d6e1aea671a68be0823673db2832b38405deba2fb0d88f63 \ + --hash=sha256:1bee1e43c28aa63cb16e5c14e582580546b08e535299b8b6158a7c9c768a1f3d \ + --hash=sha256:21d142cc6c0ec30d2efee5068ca36c128a30b0f2c53c1c07bd78cb6bc1d3be5f \ + --hash=sha256:2437418e20515acec67d86e12bf70056a33abdacb5cb1655042f6538d6b085a8 \ + --hash=sha256:244bfb999c71b35de57821b8ea746b24e863398194a4014e4c76adc2bbdfeff0 \ + --hash=sha256:2677acec1a2f8ef614c6888b5b4ae4060cc184174a938ed4e8ef690e15d3e505 \ + --hash=sha256:277e970e750505ed74c832b4bf75dac7476262ee2a013f5574dd49075879e161 \ + --hash=sha256:2aaba3b0819274cc41757a1da876f810a3e4d7b6eb25699253a4effef9e8e4af \ + --hash=sha256:2b7d8f6c26245217bd2ad053761201e9f9680f8ce52f0fcd8d0755aeae5b2152 \ + --hash=sha256:2c9d3c380143a1fedbff95a312aa798578371eb29da42106a29019368a475318 \ + --hash=sha256:3162d5d8ce1bb98dd51af660f2121c55d0fa541b46dff7bb9b9f86ea1d87de72 \ + --hash=sha256:31fd66405eaf47bb62e8cd575dc621c56c668f27d46a61d975a249930dd5e2a4 \ + --hash=sha256:362d61fd13843997c1c446760ef36f240cf81d3ebf74ac62652aebaf7838561e \ + --hash=sha256:376bec83a63b8021bb5c8ea75e21c4ccb86e7e45ca4eb81146091b56599b80c3 \ + --hash=sha256:44c2a8734b333e0578090c4cd6b16f275e07aa6614ca8715e6c038e865e70576 \ + --hash=sha256:47cc91b2f4dd2833fddaedd2893006b0106129d4b94fdb6af1f4ce5a9965577c \ + --hash=sha256:4902828217069c3c5c71094537a8e623f5d097858ac6ca8252f7b4d10b7560f1 \ + --hash=sha256:4bd5d4137d500351a30687c2d3971758aac9a19208fc110ccb9d7188fbe709e8 \ + --hash=sha256:4fe7859a4e3e8457458e2ff592f15ccb02f3da787fcd31e0183879c3ad4692a1 \ + --hash=sha256:542d2cee80be6f80247095cc36c418f7bddd14f4a6de45af91dfad36d817bba2 \ + --hash=sha256:554af85e960429cf30784dd47447d5125aaa3b99a6f0683589dbd27e2f45da44 \ + --hash=sha256:5833d2c39d8896e4e19b689ffc198f08ea58116bee26dea51e362ecc7cd3ed26 \ + --hash=sha256:5947809c8a2417be3267efc979c47d76a079758166f7d43ef5ae8e9f92751f88 \ + --hash=sha256:5ae497466c7901d54b639cf42d5b8c1b6a4fead55215500d2f486d34db48d016 \ + --hash=sha256:5bd2293095d766545ec1a8f612559f6b40abc0eb18bb2f5d1171872d34036ede \ + --hash=sha256:5bfbb1b9acf3334612667b61bd3002196fe2a1eb4dd74d247e0f2a4d50ec9bbf \ + --hash=sha256:5cb4d72eea50c8868f5288b7f7f33ed276118325c1dfd3957089f6b519e1382a \ + --hash=sha256:5dbe56a36425d26d6cfb40ce79c314a2e4dd6211d51d6d2191c00bed34f354cc \ + --hash=sha256:5f819d5fe9234f9f82d75bdfa9aef3a3d72c4d24a6e57aeaebba32a704553aa0 \ + --hash=sha256:64b55f9dce520635f018f907ff1b0df1fdc31f2795a922fb49dd14fbcdf48c84 \ + --hash=sha256:6515f3182dbe4ea06ced2d9e8666d97b46ef4c75e326b79bb624110f122551db \ + --hash=sha256:65e2befcd84bc6f37095f5961e68a6f077bf44946771354a28ad434c2cce0ae1 \ + --hash=sha256:6aee717dcfead04c6eb1ce3bd29ac1e22663cdea57f943c87d1eab9a025438d7 \ + --hash=sha256:6b39f987ae8ccdf0d2642338faf2abb1862340facc796048b604ef14919e55ed \ + --hash=sha256:6e1fcf0720908f200cd21aa4e6750a48ff6ce4afe7ff5a79a90d5ed8a08296f8 \ + --hash=sha256:74018750915ee7ad843a774364e13a3db91682f26142baddf775342c3f5b1133 \ + --hash=sha256:74664978bb272435107de04e36db5a9735e78232b85b77d45cfb38f758efd33e \ + --hash=sha256:74bb723680f9f7a6234dcf67aea57e708ec1fbdf5699fb91dfd6f511b0a320ef \ + --hash=sha256:752944c7ffbfdd10c074dc58ec2d5a8a4cd9493b314d367c14d24c17684ddd14 \ + --hash=sha256:778d2e08eda00f4256d7f672ca9fef386071c9202f5e4607920b86d7803387f2 \ + --hash=sha256:780236ac706e66881f3b7f2f32dfe90507a09e67d1d454c762cf642e6e1586e0 \ + --hash=sha256:798d75d81754988d2565bff1b97ba5a44411867c0cf32b77a7e8f8d84796b10d \ + --hash=sha256:799a7a5e4fb2d5898c60b640fd4981d6a25f1c11790935a44ce38c54e985f828 \ + --hash=sha256:7a32c560861a02ff789ad905a2fe94e3f840803362c84fecf1851cb4cf3dc37f \ + --hash=sha256:7c308f7e26e4363d79df40ca5b2be1c6ba9f02bdbccfed5abddb7859a6ce72cf \ + --hash=sha256:7fa17817dc5625de8a027cb8b26d9fefa3ea28c8253929b8d6649e705d2835b6 \ + --hash=sha256:81d5eb2a312700f4ecaa977a8235b634ce853200e828fbadf3a9c50bab278328 \ + --hash=sha256:82004af6c302b5d3ab2cfc4cc5f29db16123b1a8417f2e25f9066f91d4411090 \ + --hash=sha256:837c2ce8c5a65a2035be9b3569c684358dfbf109fd3b6969630a87535495ceaa \ + --hash=sha256:840c25fb618a231545cbab0564a799f101b63b9901f2569faecd6b222ac72381 \ + --hash=sha256:8a6562c3700cce886c5be75ade4a5db4214fda19fede41d9792d100288d8f94c \ + --hash=sha256:8af65f14dc14a79b924524b1e7fffe304517b2bff5a58bf64f30b98bbc5079eb \ + --hash=sha256:8ef3c867360f88ac904fd3f5e1f902f13307af9052646963ee08ff4f131adafc \ + --hash=sha256:94537985111c35f28720e43603b8e7b43a6ecfb2ce1d3058bbe955b73404e21a \ + --hash=sha256:99ae2cffebb06e6c22bdc25801d7b30f503cc87dbd283479e7b606f70aff57ec \ + --hash=sha256:9a26f18905b8dd5d685d6d07b0cdf98a79f3c7a918906af7cc143ea2e164c8bc \ + --hash=sha256:9b35f4c90079ff2e2edc5b26c0c77925e5d2d255c42c74fdb70fb49b172726ac \ + --hash=sha256:9cd98cdc06614a2f768d2b7286d66805f94c48cde050acdbbb7db2600ab3197e \ + --hash=sha256:9d1bb833febdff5c8927f922386db610b49db6e0d4f4ee29601d71e7c2694313 \ + --hash=sha256:9f7fcd74d410a36883701fafa2482a6af2ff5ba96b9a620e9e0721e28ead5569 \ + --hash=sha256:a59cb51917aa591b1c4e6a43c132f0cdc3c76dbad6155df4e28ee626cc77a0a3 \ + --hash=sha256:a61900df84c667873b292c3de315a786dd8dac506704dea57bc957bd31e22c7d \ + --hash=sha256:a79cfe37875f822425b89a82333404539ae63dbdddf97f84dcbc3d339aae9525 \ + --hash=sha256:a8a8b89589086a25749f471e6a900d3f662d1d3b6e2e59dcecf787b1cc3a1894 \ + --hash=sha256:a8bf8d0f749c5757af2142fe7903a9df1d2e8aa3841559b2bad34b08d0e2bcf3 \ + --hash=sha256:a9768c477b9d7bd54bc0c86dbaebdec6f03306675526c9927c0e8a04e8f94af9 \ + --hash=sha256:ac1c4a689edcc530fc9d9aa11f5774b9e2f33f9a0c6a57864e90908f5208d30a \ + --hash=sha256:af2d8c67d8e573d6de5bc30cdb27e9b95e49115cd9baad5ddbd1a6207aaa82a9 \ + --hash=sha256:b435cba5f4f750aa6c0a0d92c541fb79f69a387c91e61f1795227e4ed9cece14 \ + --hash=sha256:b5b290ccc2a263e8d185130284f8501e3e36c5e02750fc6b6bdeb2e9e96f1e25 \ + --hash=sha256:b5d84d37db046c5ca74ee7bb47dd6cbc13f80665fdde3e8040bdd3fb015ecb50 \ + --hash=sha256:b7cf1017d601aa35e6bb650b6ad28652c9cd78ee6caff19f3c28d03e1c80acbf \ + --hash=sha256:bc7637e2f80d8530ee4a78e878bce464f70087ce73cf7c1caf142416923b98f1 \ + --hash=sha256:c0463276121fdee9c49b98908b3a89c39be45d86d1dbaa22957e38f6321d4ce3 \ + --hash=sha256:c4ef880e27901b6cc782f1b95f82da9313c0eb95c3af699103088fa0ac3ce9ac \ + --hash=sha256:c8ae8a0f02f57a6e61203a31428fa1d677cbe50c93622b4149d5c0f319c1d19e \ + --hash=sha256:ca5862d5b3928c4940729dacc329aa9102900382fea192fc5e52eb69d6093815 \ + --hash=sha256:cb01158d8b88ee68f15949894ccc6712278243d95f344770fa7593fa2d94410c \ + --hash=sha256:cb6254dc36b47a990e59e1068afacdcd02958bdcce30bb50cc1700a8b9d624a6 \ + --hash=sha256:cc00f04ed596e9dc0da42ed17ac5e596c6ccba999ba6bd92b0e0aef2f170f2d6 \ + --hash=sha256:cd09d08005f958f370f539f186d10aec3377d55b9eeb0d796025d4886119d76e \ + --hash=sha256:cd4b7ca9984e5e7985c12bc60a6f173f3c958eae74f3ef6624bb6b26e2abbae4 \ + --hash=sha256:ce8a0633f41a967713a59c4139d29110c07e826d131a316b50ce11b1d79b4f84 \ + --hash=sha256:cead0978fc57397645f12578bfd2d5ea9138ea0fac82b2f63f7f7c6877986a69 \ + --hash=sha256:d055ec1e26e441f6187acf818b73564e6e6282709e9bcb5b63f5b23068356a15 \ + --hash=sha256:d1f13550535ad8cff21b8d757a3257963e951d96e20ec82ab44bc64aeb62a191 \ + --hash=sha256:d9c7f57c3d666a53421049053eaacdd14bbd0a528e2186fcb2e672effd053bb0 \ + --hash=sha256:d9e45d7faa48ee908174d8fe84854479ef838fc6a705c9315372eacbc2f02897 \ + --hash=sha256:da3326d9e65ef63a817ecbcc0df6e94463713b754fe293eaa03da99befb9a5bd \ + --hash=sha256:de00632ca48df9daf77a2c65a484531649261ec9f25489917f09e455cb09ddb2 \ + --hash=sha256:e1f185f86a6f3403aa2420e815904c67b2f9ebc443f045edd0de921108345794 \ + --hash=sha256:e824f1492727fa856dd6eda4f7cee25f8518a12f3c4a56a74e8095695089cf6d \ + --hash=sha256:e912091979546adf63357d7e2ccff9b44f026c075aeaf25a52d0e95ad2281074 \ + --hash=sha256:eaabd426fe94daf8fd157c32e571c85cb12e66692f15516a83a03264b08d06c3 \ + --hash=sha256:ebf3e58c7ec8a8bed6d66a75d7fb37b55e5015b03ceae72a8e7c74495551e224 \ + --hash=sha256:ecaae4149d99b1c9e7b88bb03e3221956f68fd6d50be2ef061b2381b61d20838 \ + --hash=sha256:eecbc200c7fd5ddb9a7f16c7decb07b566c29fa2161a16cf67b8d068bd21690a \ + --hash=sha256:f155a433c2ec037d4e8df17d18922c3a0d9b3232a396690f17175d2946f0218d \ + --hash=sha256:f1e34719c6ed0b92f418c7c780480b26b5d9c50349e9a9af7d76bf757530350d \ + --hash=sha256:f34be2938726fc13801220747472850852fe6b1ea75869a048d6f896838c896f \ + --hash=sha256:f820802628d2694cb7e56db99213f930856014862f3fd943d290ea8438d07ca8 \ + --hash=sha256:f8bf04158c6b607d747e93949aa60618b61312fe647a6369f88ce2ff16043490 \ + --hash=sha256:f8e160feb2aed042cd657a72acc0b481212ed28b1b9a95c0cee1621b524e1966 \ + --hash=sha256:f9d332f8c2a2fcbffe1378594431458ddbef721c1769d78e2cbc06280d8155f9 \ + --hash=sha256:fa09f53c465e532f4d3db095e0c55b615f010ad81803d383195b6b5ca6cbf5f3 \ + --hash=sha256:faa3a41b2b66b6e50f84ae4a68c64fcd0c44355741c6374813a800cd6695db9e \ + --hash=sha256:fd44c878ea55ba351104cb93cc85e74916eb8fa440ca7903e57575e97394f608 + # via requests +cloudpickle==3.1.2 \ + --hash=sha256:7fda9eb655c9c230dab534f1983763de5835249750e85fbcef43aaa30a9a2414 \ + --hash=sha256:9acb47f6afd73f60dc1df93bb801b472f05ff42fa6c84167d25cb206be1fbf4a # via -r build/test-requirements.txt colorama==0.4.6 \ --hash=sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44 \ --hash=sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6 - # via -r build/test-requirements.txt -contourpy==1.2.1 \ - --hash=sha256:00e5388f71c1a0610e6fe56b5c44ab7ba14165cdd6d695429c5cd94021e390b2 \ - --hash=sha256:10a37ae557aabf2509c79715cd20b62e4c7c28b8cd62dd7d99e5ed3ce28c3fd9 \ - --hash=sha256:11959f0ce4a6f7b76ec578576a0b61a28bdc0696194b6347ba3f1c53827178b9 \ - --hash=sha256:187fa1d4c6acc06adb0fae5544c59898ad781409e61a926ac7e84b8f276dcef4 \ - --hash=sha256:1a07fc092a4088ee952ddae19a2b2a85757b923217b7eed584fdf25f53a6e7ce \ - --hash=sha256:1cac0a8f71a041aa587410424ad46dfa6a11f6149ceb219ce7dd48f6b02b87a7 \ - --hash=sha256:1d59e739ab0e3520e62a26c60707cc3ab0365d2f8fecea74bfe4de72dc56388f \ - --hash=sha256:2855c8b0b55958265e8b5888d6a615ba02883b225f2227461aa9127c578a4922 \ - --hash=sha256:2e785e0f2ef0d567099b9ff92cbfb958d71c2d5b9259981cd9bee81bd194c9a4 \ - --hash=sha256:309be79c0a354afff9ff7da4aaed7c3257e77edf6c1b448a779329431ee79d7e \ - --hash=sha256:39f3ecaf76cd98e802f094e0d4fbc6dc9c45a8d0c4d185f0f6c2234e14e5f75b \ - --hash=sha256:457499c79fa84593f22454bbd27670227874cd2ff5d6c84e60575c8b50a69619 \ - --hash=sha256:49e70d111fee47284d9dd867c9bb9a7058a3c617274900780c43e38d90fe1205 \ - --hash=sha256:4c75507d0a55378240f781599c30e7776674dbaf883a46d1c90f37e563453480 \ - --hash=sha256:4c863140fafc615c14a4bf4efd0f4425c02230eb8ef02784c9a156461e62c965 \ - --hash=sha256:4d8908b3bee1c889e547867ca4cdc54e5ab6be6d3e078556814a22457f49423c \ - --hash=sha256:5b9eb0ca724a241683c9685a484da9d35c872fd42756574a7cfbf58af26677fd \ - --hash=sha256:6022cecf8f44e36af10bd9118ca71f371078b4c168b6e0fab43d4a889985dbb5 \ - --hash=sha256:6150ffa5c767bc6332df27157d95442c379b7dce3a38dff89c0f39b63275696f \ - --hash=sha256:62828cada4a2b850dbef89c81f5a33741898b305db244904de418cc957ff05dc \ - --hash=sha256:7b4182299f251060996af5249c286bae9361fa8c6a9cda5efc29fe8bfd6062ec \ - --hash=sha256:94b34f32646ca0414237168d68a9157cb3889f06b096612afdd296003fdd32fd \ - --hash=sha256:9ce6889abac9a42afd07a562c2d6d4b2b7134f83f18571d859b25624a331c90b \ - --hash=sha256:9cffe0f850e89d7c0012a1fb8730f75edd4320a0a731ed0c183904fe6ecfc3a9 \ - --hash=sha256:a12a813949e5066148712a0626895c26b2578874e4cc63160bb007e6df3436fe \ - --hash=sha256:a1eea9aecf761c661d096d39ed9026574de8adb2ae1c5bd7b33558af884fb2ce \ - --hash=sha256:a31f94983fecbac95e58388210427d68cd30fe8a36927980fab9c20062645609 \ - --hash=sha256:ac58bdee53cbeba2ecad824fa8159493f0bf3b8ea4e93feb06c9a465d6c87da8 \ - --hash=sha256:af3f4485884750dddd9c25cb7e3915d83c2db92488b38ccb77dd594eac84c4a0 \ - --hash=sha256:b33d2bc4f69caedcd0a275329eb2198f560b325605810895627be5d4b876bf7f \ - --hash=sha256:b59c0ffceff8d4d3996a45f2bb6f4c207f94684a96bf3d9728dbb77428dd8cb8 \ - --hash=sha256:bb6834cbd983b19f06908b45bfc2dad6ac9479ae04abe923a275b5f48f1a186b \ - --hash=sha256:bd3db01f59fdcbce5b22afad19e390260d6d0222f35a1023d9adc5690a889364 \ - --hash=sha256:bd7c23df857d488f418439686d3b10ae2fbf9bc256cd045b37a8c16575ea1040 \ - --hash=sha256:c2528d60e398c7c4c799d56f907664673a807635b857df18f7ae64d3e6ce2d9f \ - --hash=sha256:d31a63bc6e6d87f77d71e1abbd7387ab817a66733734883d1fc0021ed9bfa083 \ - --hash=sha256:d4492d82b3bc7fbb7e3610747b159869468079fe149ec5c4d771fa1f614a14df \ - --hash=sha256:ddcb8581510311e13421b1f544403c16e901c4e8f09083c881fab2be80ee31ba \ - --hash=sha256:e1d59258c3c67c865435d8fbeb35f8c59b8bef3d6f46c1f29f6123556af28445 \ - --hash=sha256:eb3315a8a236ee19b6df481fc5f997436e8ade24a9f03dfdc6bd490fea20c6da \ - --hash=sha256:ef2b055471c0eb466033760a521efb9d8a32b99ab907fc8358481a1dd29e3bd3 \ - --hash=sha256:ef5adb9a3b1d0c645ff694f9bca7702ec2c70f4d734f9922ea34de02294fdf72 \ - --hash=sha256:f32c38afb74bd98ce26de7cc74a67b40afb7b05aae7b42924ea990d51e4dac02 \ - --hash=sha256:fe0ccca550bb8e5abc22f530ec0466136379c01321fd94f30a22231e8a48d985 + # via -r build/requirements.in +contourpy==1.3.3 \ + --hash=sha256:023b44101dfe49d7d53932be418477dba359649246075c996866106da069af69 \ + --hash=sha256:07ce5ed73ecdc4a03ffe3e1b3e3c1166db35ae7584be76f65dbbe28a7791b0cc \ + --hash=sha256:083e12155b210502d0bca491432bb04d56dc3432f95a979b429f2848c3dbe880 \ + --hash=sha256:0bf67e0e3f482cb69779dd3061b534eb35ac9b17f163d851e2a547d56dba0a3a \ + --hash=sha256:0c1fc238306b35f246d61a1d416a627348b5cf0648648a031e14bb8705fcdfe8 \ + --hash=sha256:13b68d6a62db8eafaebb8039218921399baf6e47bf85006fd8529f2a08ef33fc \ + --hash=sha256:15ff10bfada4bf92ec8b31c62bf7c1834c244019b4a33095a68000d7075df470 \ + --hash=sha256:177fb367556747a686509d6fef71d221a4b198a3905fe824430e5ea0fda54eb5 \ + --hash=sha256:1cadd8b8969f060ba45ed7c1b714fe69185812ab43bd6b86a9123fe8f99c3263 \ + --hash=sha256:1fd43c3be4c8e5fd6e4f2baeae35ae18176cf2e5cced681cca908addf1cdd53b \ + --hash=sha256:22e9b1bd7a9b1d652cd77388465dc358dafcd2e217d35552424aa4f996f524f5 \ + --hash=sha256:23416f38bfd74d5d28ab8429cc4d63fa67d5068bd711a85edb1c3fb0c3e2f381 \ + --hash=sha256:283edd842a01e3dcd435b1c5116798d661378d83d36d337b8dde1d16a5fc9ba3 \ + --hash=sha256:2a2a8b627d5cc6b7c41a4beff6c5ad5eb848c88255fda4a8745f7e901b32d8e4 \ + --hash=sha256:2b7e9480ffe2b0cd2e787e4df64270e3a0440d9db8dc823312e2c940c167df7e \ + --hash=sha256:322ab1c99b008dad206d406bb61d014cf0174df491ae9d9d0fac6a6fda4f977f \ + --hash=sha256:33c82d0138c0a062380332c861387650c82e4cf1747aaa6938b9b6516762e772 \ + --hash=sha256:348ac1f5d4f1d66d3322420f01d42e43122f43616e0f194fc1c9f5d830c5b286 \ + --hash=sha256:3519428f6be58431c56581f1694ba8e50626f2dd550af225f82fb5f5814d2a42 \ + --hash=sha256:3c30273eb2a55024ff31ba7d052dde990d7d8e5450f4bbb6e913558b3d6c2301 \ + --hash=sha256:3d1a3799d62d45c18bafd41c5fa05120b96a28079f2393af559b843d1a966a77 \ + --hash=sha256:451e71b5a7d597379ef572de31eeb909a87246974d960049a9848c3bc6c41bf7 \ + --hash=sha256:459c1f020cd59fcfe6650180678a9993932d80d44ccde1fa1868977438f0b411 \ + --hash=sha256:4d00e655fcef08aba35ec9610536bfe90267d7ab5ba944f7032549c55a146da1 \ + --hash=sha256:4debd64f124ca62069f313a9cb86656ff087786016d76927ae2cf37846b006c9 \ + --hash=sha256:4feffb6537d64b84877da813a5c30f1422ea5739566abf0bd18065ac040e120a \ + --hash=sha256:50ed930df7289ff2a8d7afeb9603f8289e5704755c7e5c3bbd929c90c817164b \ + --hash=sha256:51e79c1f7470158e838808d4a996fa9bac72c498e93d8ebe5119bc1e6becb0db \ + --hash=sha256:556dba8fb6f5d8742f2923fe9457dbdd51e1049c4a43fd3986a0b14a1d815fc6 \ + --hash=sha256:598c3aaece21c503615fd59c92a3598b428b2f01bfb4b8ca9c4edeecc2438620 \ + --hash=sha256:5ed3657edf08512fc3fe81b510e35c2012fbd3081d2e26160f27ca28affec989 \ + --hash=sha256:626d60935cf668e70a5ce6ff184fd713e9683fb458898e4249b63be9e28286ea \ + --hash=sha256:644a6853d15b2512d67881586bd03f462c7ab755db95f16f14d7e238f2852c67 \ + --hash=sha256:655456777ff65c2c548b7c454af9c6f33f16c8884f11083244b5819cc214f1b5 \ + --hash=sha256:66c8a43a4f7b8df8b71ee1840e4211a3c8d93b214b213f590e18a1beca458f7d \ + --hash=sha256:6afc576f7b33cf00996e5c1102dc2a8f7cc89e39c0b55df93a0b78c1bd992b36 \ + --hash=sha256:6c3d53c796f8647d6deb1abe867daeb66dcc8a97e8455efa729516b997b8ed99 \ + --hash=sha256:709a48ef9a690e1343202916450bc48b9e51c049b089c7f79a267b46cffcdaa1 \ + --hash=sha256:70f9aad7de812d6541d29d2bbf8feb22ff7e1c299523db288004e3157ff4674e \ + --hash=sha256:8153b8bfc11e1e4d75bcb0bff1db232f9e10b274e0929de9d608027e0d34ff8b \ + --hash=sha256:87acf5963fc2b34825e5b6b048f40e3635dd547f590b04d2ab317c2619ef7ae8 \ + --hash=sha256:88df9880d507169449d434c293467418b9f6cbe82edd19284aa0409e7fdb933d \ + --hash=sha256:929ddf8c4c7f348e4c0a5a3a714b5c8542ffaa8c22954862a46ca1813b667ee7 \ + --hash=sha256:92d9abc807cf7d0e047b95ca5d957cf4792fcd04e920ca70d48add15c1a90ea7 \ + --hash=sha256:95b181891b4c71de4bb404c6621e7e2390745f887f2a026b2d99e92c17892339 \ + --hash=sha256:9e999574eddae35f1312c2b4b717b7885d4edd6cb46700e04f7f02db454e67c1 \ + --hash=sha256:a15459b0f4615b00bbd1e91f1b9e19b7e63aea7483d03d804186f278c0af2659 \ + --hash=sha256:a22738912262aa3e254e4f3cb079a95a67132fc5a063890e224393596902f5a4 \ + --hash=sha256:ab2fd90904c503739a75b7c8c5c01160130ba67944a7b77bbf36ef8054576e7f \ + --hash=sha256:ab3074b48c4e2cf1a960e6bbeb7f04566bf36b1861d5c9d4d8ac04b82e38ba20 \ + --hash=sha256:afe5a512f31ee6bd7d0dda52ec9864c984ca3d66664444f2d72e0dc4eb832e36 \ + --hash=sha256:b08a32ea2f8e42cf1d4be3169a98dd4be32bafe4f22b6c4cb4ba810fa9e5d2cb \ + --hash=sha256:b20c7c9a3bf701366556e1b1984ed2d0cedf999903c51311417cf5f591d8c78d \ + --hash=sha256:b2e8faa0ed68cb29af51edd8e24798bb661eac3bd9f65420c1887b6ca89987c8 \ + --hash=sha256:b7301b89040075c30e5768810bc96a8e8d78085b47d8be6e4c3f5a0b4ed478a0 \ + --hash=sha256:b7448cb5a725bb1e35ce88771b86fba35ef418952474492cf7c764059933ff8b \ + --hash=sha256:ca0fdcd73925568ca027e0b17ab07aad764be4706d0a925b89227e447d9737b7 \ + --hash=sha256:ca658cd1a680a5c9ea96dc61cdbae1e85c8f25849843aa799dfd3cb370ad4fbe \ + --hash=sha256:cbedb772ed74ff5be440fa8eee9bd49f64f6e3fc09436d9c7d8f1c287b121d77 \ + --hash=sha256:cd5dfcaeb10f7b7f9dc8941717c6c2ade08f587be2226222c12b25f0483ed497 \ + --hash=sha256:cf9022ef053f2694e31d630feaacb21ea24224be1c3ad0520b13d844274614fd \ + --hash=sha256:d002b6f00d73d69333dac9d0b8d5e84d9724ff9ef044fd63c5986e62b7c9e1b1 \ + --hash=sha256:d06bb1f751ba5d417047db62bca3c8fde202b8c11fb50742ab3ab962c81e8216 \ + --hash=sha256:d304906ecc71672e9c89e87c4675dc5c2645e1f4269a5063b99b0bb29f232d13 \ + --hash=sha256:e4e6b05a45525357e382909a4c1600444e2a45b4795163d3b22669285591c1ae \ + --hash=sha256:e74a9a0f5e3fff48fb5a7f2fd2b9b70a3fe014a67522f79b7cca4c0c7e43c9ae \ + --hash=sha256:ea37e7b45949df430fe649e5de8351c423430046a2af20b1c1961cae3afcda77 \ + --hash=sha256:f64836de09927cba6f79dcd00fdd7d5329f3fccc633468507079c829ca4db4e3 \ + --hash=sha256:fd6ec6be509c787f1caf6b247f0b1ca598bef13f4ddeaa126b7658215529ba0f \ + --hash=sha256:fd907ae12cd483cd83e414b12941c632a969171bf90fc937d0c9f268a31cafff \ + --hash=sha256:fd914713266421b7536de2bfa8181aa8c699432b6763a0ea64195ebe28bff6a9 \ + --hash=sha256:fde6c716d51c04b1c25d0b90364d0be954624a0ee9d60e23e850e8d48353d07a # via matplotlib cycler==0.12.1 \ --hash=sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30 \ --hash=sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c # via matplotlib -etils[epath,epy]==1.8.0 \ - --hash=sha256:f31d7f27a889457eaa44eab18ce836d24fd6d40dbbb167d38879b7296f6456ea \ - --hash=sha256:fb478f57fec202e260e54c9192b317692fd63db2d11d993e70bcdffa29cccd58 +etils[epath,epy]==1.13.0 \ + --hash=sha256:a5b60c71f95bcd2d43d4e9fb3dc3879120c1f60472bb5ce19f7a860b1d44f607 \ + --hash=sha256:d9cd4f40fbe77ad6613b7348a18132cc511237b6c076dbb89105c0b520a4c6bb # via -r build/requirements.in -execnet==2.1.1 \ - --hash=sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc \ - --hash=sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3 +execnet==2.1.2 \ + --hash=sha256:63d83bfdd9a23e35b9c6a3261412324f964c2ec8dcd8d3c6916ee9373e0befcd \ + --hash=sha256:67fba928dd5a544b783f6056f449e5e3931a5c378b128bc18501f7ea79e296ec # via pytest-xdist -filelock==3.14.0 \ - --hash=sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f \ - --hash=sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a +filelock==3.20.1 \ + --hash=sha256:15d9e9a67306188a44baa72f569d2bfd803076269365fdea0934385da4dc361a \ + --hash=sha256:b8360948b351b80f420878d8516519a2204b07aefcdcfd24912a5d33127f188c # via -r build/test-requirements.txt -flatbuffers==24.3.25 \ - --hash=sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812 \ - --hash=sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4 - # via -r build/test-requirements.txt -fonttools==4.51.0 \ - --hash=sha256:0118ef998a0699a96c7b28457f15546815015a2710a1b23a7bf6c1be60c01636 \ - --hash=sha256:0d145976194a5242fdd22df18a1b451481a88071feadf251221af110ca8f00ce \ - --hash=sha256:0e19bd9e9964a09cd2433a4b100ca7f34e34731e0758e13ba9a1ed6e5468cc0f \ - --hash=sha256:0f08c901d3866a8905363619e3741c33f0a83a680d92a9f0e575985c2634fcc1 \ - --hash=sha256:1250e818b5f8a679ad79660855528120a8f0288f8f30ec88b83db51515411fcc \ - --hash=sha256:15c94eeef6b095831067f72c825eb0e2d48bb4cea0647c1b05c981ecba2bf39f \ - --hash=sha256:1621ee57da887c17312acc4b0e7ac30d3a4fb0fec6174b2e3754a74c26bbed1e \ - --hash=sha256:180194c7fe60c989bb627d7ed5011f2bef1c4d36ecf3ec64daec8302f1ae0716 \ - --hash=sha256:278e50f6b003c6aed19bae2242b364e575bcb16304b53f2b64f6551b9c000e15 \ - --hash=sha256:32b17504696f605e9e960647c5f64b35704782a502cc26a37b800b4d69ff3c77 \ - --hash=sha256:3bee3f3bd9fa1d5ee616ccfd13b27ca605c2b4270e45715bd2883e9504735034 \ - --hash=sha256:4060acc2bfa2d8e98117828a238889f13b6f69d59f4f2d5857eece5277b829ba \ - --hash=sha256:54dcf21a2f2d06ded676e3c3f9f74b2bafded3a8ff12f0983160b13e9f2fb4a7 \ - --hash=sha256:56fc244f2585d6c00b9bcc59e6593e646cf095a96fe68d62cd4da53dd1287b55 \ - --hash=sha256:599bdb75e220241cedc6faebfafedd7670335d2e29620d207dd0378a4e9ccc5a \ - --hash=sha256:5f6bc991d1610f5c3bbe997b0233cbc234b8e82fa99fc0b2932dc1ca5e5afec0 \ - --hash=sha256:60a3409c9112aec02d5fb546f557bca6efa773dcb32ac147c6baf5f742e6258b \ - --hash=sha256:68b3fb7775a923be73e739f92f7e8a72725fd333eab24834041365d2278c3671 \ - --hash=sha256:76f1777d8b3386479ffb4a282e74318e730014d86ce60f016908d9801af9ca2a \ - --hash=sha256:806e7912c32a657fa39d2d6eb1d3012d35f841387c8fc6cf349ed70b7c340039 \ - --hash=sha256:84d7751f4468dd8cdd03ddada18b8b0857a5beec80bce9f435742abc9a851a74 \ - --hash=sha256:865a58b6e60b0938874af0968cd0553bcd88e0b2cb6e588727117bd099eef836 \ - --hash=sha256:8ac27f436e8af7779f0bb4d5425aa3535270494d3bc5459ed27de3f03151e4c2 \ - --hash=sha256:8b4850fa2ef2cfbc1d1f689bc159ef0f45d8d83298c1425838095bf53ef46308 \ - --hash=sha256:8b5ad456813d93b9c4b7ee55302208db2b45324315129d85275c01f5cb7e61a2 \ - --hash=sha256:8e2f1a4499e3b5ee82c19b5ee57f0294673125c65b0a1ff3764ea1f9db2f9ef5 \ - --hash=sha256:9696fe9f3f0c32e9a321d5268208a7cc9205a52f99b89479d1b035ed54c923f1 \ - --hash=sha256:96a48e137c36be55e68845fc4284533bda2980f8d6f835e26bca79d7e2006438 \ - --hash=sha256:a8feca65bab31479d795b0d16c9a9852902e3a3c0630678efb0b2b7941ea9c74 \ - --hash=sha256:aefa011207ed36cd280babfaa8510b8176f1a77261833e895a9d96e57e44802f \ - --hash=sha256:b2b92381f37b39ba2fc98c3a45a9d6383bfc9916a87d66ccb6553f7bdd129097 \ - --hash=sha256:b3c61423f22165541b9403ee39874dcae84cd57a9078b82e1dce8cb06b07fa2e \ - --hash=sha256:b5b48a1121117047d82695d276c2af2ee3a24ffe0f502ed581acc2673ecf1037 \ - --hash=sha256:c18b49adc721a7d0b8dfe7c3130c89b8704baf599fb396396d07d4aa69b824a1 \ - --hash=sha256:c5b8cab0c137ca229433570151b5c1fc6af212680b58b15abd797dcdd9dd5051 \ - --hash=sha256:c7e91abdfae1b5c9e3a543f48ce96013f9a08c6c9668f1e6be0beabf0a569c1b \ - --hash=sha256:cadf4e12a608ef1d13e039864f484c8a968840afa0258b0b843a0556497ea9ed \ - --hash=sha256:dc0673361331566d7a663d7ce0f6fdcbfbdc1f59c6e3ed1165ad7202ca183c68 \ - --hash=sha256:de7c29bdbdd35811f14493ffd2534b88f0ce1b9065316433b22d63ca1cd21f14 \ - --hash=sha256:e9d9298be7a05bb4801f558522adbe2feea1b0b103d5294ebf24a92dd49b78e5 \ - --hash=sha256:ee1af4be1c5afe4c96ca23badd368d8dc75f611887fb0c0dac9f71ee5d6f110e \ - --hash=sha256:f7e89853d8bea103c8e3514b9f9dc86b5b4120afb4583b57eb10dfa5afbe0936 +flatbuffers==25.9.23 \ + --hash=sha256:255538574d6cb6d0a79a17ec8bc0d30985913b87513a01cce8bcdb6b4c44d0e2 \ + --hash=sha256:676f9fa62750bb50cf531b42a0a2a118ad8f7f797a511eda12881c016f093b12 + # via + # -r build/test-requirements.txt + # tensorflow +fonttools==4.61.1 \ + --hash=sha256:0de30bfe7745c0d1ffa2b0b7048fb7123ad0d71107e10ee090fa0b16b9452e87 \ + --hash=sha256:10d88e55330e092940584774ee5e8a6971b01fc2f4d3466a1d6c158230880796 \ + --hash=sha256:11f35ad7805edba3aac1a3710d104592df59f4b957e30108ae0ba6c10b11dd75 \ + --hash=sha256:15acc09befd16a0fb8a8f62bc147e1a82817542d72184acca9ce6e0aeda9fa6d \ + --hash=sha256:17d2bf5d541add43822bcf0c43d7d847b160c9bb01d15d5007d84e2217aaa371 \ + --hash=sha256:2180f14c141d2f0f3da43f3a81bc8aa4684860f6b0e6f9e165a4831f24e6a23b \ + --hash=sha256:21e7c8d76f62ab13c9472ccf74515ca5b9a761d1bde3265152a6dc58700d895b \ + --hash=sha256:41a7170d042e8c0024703ed13b71893519a1a6d6e18e933e3ec7507a2c26a4b2 \ + --hash=sha256:41ed4b5ec103bd306bb68f81dc166e77409e5209443e5773cb4ed837bcc9b0d3 \ + --hash=sha256:497c31ce314219888c0e2fce5ad9178ca83fe5230b01a5006726cdf3ac9f24d9 \ + --hash=sha256:4c1b526c8d3f615a7b1867f38a9410849c8f4aef078535742198e942fba0e9bd \ + --hash=sha256:4d7092bb38c53bbc78e9255a59158b150bcdc115a1e3b3ce0b5f267dc35dd63c \ + --hash=sha256:4f5686e1fe5fce75d82d93c47a438a25bf0d1319d2843a926f741140b2b16e0c \ + --hash=sha256:58b0ee0ab5b1fc9921eccfe11d1435added19d6494dde14e323f25ad2bc30c56 \ + --hash=sha256:5ce02f38a754f207f2f06557523cd39a06438ba3aafc0639c477ac409fc64e37 \ + --hash=sha256:5fade934607a523614726119164ff621e8c30e8fa1ffffbbd358662056ba69f0 \ + --hash=sha256:5fe9fd43882620017add5eabb781ebfbc6998ee49b35bd7f8f79af1f9f99a958 \ + --hash=sha256:64102ca87e84261419c3747a0d20f396eb024bdbeb04c2bfb37e2891f5fadcb5 \ + --hash=sha256:664c5a68ec406f6b1547946683008576ef8b38275608e1cee6c061828171c118 \ + --hash=sha256:6675329885c44657f826ef01d9e4fb33b9158e9d93c537d84ad8399539bc6f69 \ + --hash=sha256:75c1a6dfac6abd407634420c93864a1e274ebc1c7531346d9254c0d8f6ca00f9 \ + --hash=sha256:75da8f28eff26defba42c52986de97b22106cb8f26515b7c22443ebc9c2d3261 \ + --hash=sha256:77efb033d8d7ff233385f30c62c7c79271c8885d5c9657d967ede124671bbdfb \ + --hash=sha256:78a7d3ab09dc47ac1a363a493e6112d8cabed7ba7caad5f54dbe2f08676d1b47 \ + --hash=sha256:7c7db70d57e5e1089a274cbb2b1fd635c9a24de809a231b154965d415d6c6d24 \ + --hash=sha256:8c56c488ab471628ff3bfa80964372fc13504ece601e0d97a78ee74126b2045c \ + --hash=sha256:91669ccac46bbc1d09e9273546181919064e8df73488ea087dcac3e2968df9ba \ + --hash=sha256:9b666a475a65f4e839d3d10473fad6d47e0a9db14a2f4a224029c5bfde58ad2c \ + --hash=sha256:9cfef3ab326780c04d6646f68d4b4742aae222e8b8ea1d627c74e38afcbc9d91 \ + --hash=sha256:a13fc8aeb24bad755eea8f7f9d409438eb94e82cf86b08fe77a03fbc8f6a96b1 \ + --hash=sha256:a75c301f96db737e1c5ed5fd7d77d9c34466de16095a266509e13da09751bd19 \ + --hash=sha256:a76d4cb80f41ba94a6691264be76435e5f72f2cb3cab0b092a6212855f71c2f6 \ + --hash=sha256:aed04cabe26f30c1647ef0e8fbb207516fd40fe9472e9439695f5c6998e60ac5 \ + --hash=sha256:b148b56f5de675ee16d45e769e69f87623a4944f7443850bf9a9376e628a89d2 \ + --hash=sha256:b501c862d4901792adaec7c25b1ecc749e2662543f68bb194c42ba18d6eec98d \ + --hash=sha256:b846a1fcf8beadeb9ea4f44ec5bdde393e2f1569e17d700bfc49cd69bde75881 \ + --hash=sha256:b931ae8f62db78861b0ff1ac017851764602288575d65b8e8ff1963fed419063 \ + --hash=sha256:c33ab3ca9d3ccd581d58e989d67554e42d8d4ded94ab3ade3508455fe70e65f7 \ + --hash=sha256:c6604b735bb12fef8e0efd5578c9fb5d3d8532d5001ea13a19cddf295673ee09 \ + --hash=sha256:d8db08051fc9e7d8bc622f2112511b8107d8f27cd89e2f64ec45e9825e8288da \ + --hash=sha256:d9203500f7c63545b4ce3799319fe4d9feb1a1b89b28d3cb5abd11b9dd64147e \ + --hash=sha256:dc492779501fa723b04d0ab1f5be046797fee17d27700476edc7ee9ae535a61e \ + --hash=sha256:e6bcdf33aec38d16508ce61fd81838f24c83c90a1d1b8c68982857038673d6b8 \ + --hash=sha256:e76ce097e3c57c4bcb67c5aa24a0ecdbd9f74ea9219997a707a4061fbe2707aa \ + --hash=sha256:eff1ac3cc66c2ac7cda1e64b4e2f3ffef474b7335f92fc3833fc632d595fcee6 \ + --hash=sha256:f3cb4a569029b9f291f88aafc927dd53683757e640081ca8c412781ea144565e \ + --hash=sha256:f79b168428351d11e10c5aeb61a74e1851ec221081299f4cf56036a95431c43a \ + --hash=sha256:fa646ecec9528bef693415c79a86e733c70a4965dd938e9a226b0fc64c9d2e6c \ + --hash=sha256:fe2efccb324948a11dd09d22136fe2ac8a97d6c1347cf0b58a911dcd529f66b7 \ + --hash=sha256:fff4f534200a04b4a36e7ae3cb74493afe807b517a09e99cb4faa89a34ed6ecd # via matplotlib -fsspec==2024.5.0 \ - --hash=sha256:1d021b0b0f933e3b3029ed808eb400c08ba101ca2de4b3483fbc9ca23fcee94a \ - --hash=sha256:e0fdbc446d67e182f49a70b82cf7889028a63588fde6b222521f10937b2b670c +fsspec==2025.10.0 \ + --hash=sha256:7c7712353ae7d875407f97715f0e1ffcc21e33d5b24556cb1e090ae9409ec61d \ + --hash=sha256:b6789427626f068f9a83ca4e8a3cc050850b6c0f71f99ddb4f542b8266a26a59 # via etils -hypothesis==6.102.4 \ - --hash=sha256:013df31b04a4daede13756f497e60e451963d86f426395a79f99c5d692919bbd \ - --hash=sha256:59b4d144346d5cffb482cc1bafbd21b13ff31608e8c4b3e4630339aee3e87763 +gast==0.6.0 \ + --hash=sha256:52b182313f7330389f72b069ba00f174cfe2a06411099547288839c6cbafbd54 \ + --hash=sha256:88fc5300d32c7ac6ca7b515310862f71e6fdf2c029bbec7c66c0f5dd47b6b1fb + # via tensorflow +google-pasta==0.2.0 \ + --hash=sha256:4612951da876b1a10fe3960d7226f0c7682cf901e16ac06e473b267a5afa8954 \ + --hash=sha256:b32482794a366b5366a32c92a9a9201b107821889935a02b3e51f6b432ea84ed \ + --hash=sha256:c9f2c8dfc8f96d0d5808299920721be30c9eec37f2389f28904f454565c8a16e + # via tensorflow +grpcio==1.76.0 \ + --hash=sha256:035d90bc79eaa4bed83f524331d55e35820725c9fbb00ffa1904d5550ed7ede3 \ + --hash=sha256:04bbe1bfe3a68bbfd4e52402ab7d4eb59d72d02647ae2042204326cf4bbad280 \ + --hash=sha256:063065249d9e7e0782d03d2bca50787f53bd0fb89a67de9a7b521c4a01f1989b \ + --hash=sha256:06c3d6b076e7b593905d04fdba6a0525711b3466f43b3400266f04ff735de0cd \ + --hash=sha256:08caea849a9d3c71a542827d6df9d5a69067b0a1efbea8a855633ff5d9571465 \ + --hash=sha256:0aaa82d0813fd4c8e589fac9b65d7dd88702555f702fb10417f96e2a2a6d4c0f \ + --hash=sha256:0b7604868b38c1bfd5cf72d768aedd7db41d78cb6a4a18585e33fb0f9f2363fd \ + --hash=sha256:0c37db8606c258e2ee0c56b78c62fc9dee0e901b5dbdcf816c2dd4ad652b8b0c \ + --hash=sha256:1c9b93f79f48b03ada57ea24725d83a30284a012ec27eab2cf7e50a550cbbbcc \ + --hash=sha256:2107b0c024d1b35f4083f11245c0e23846ae64d02f40b2b226684840260ed054 \ + --hash=sha256:2229ae655ec4e8999599469559e97630185fdd53ae1e8997d147b7c9b2b72cba \ + --hash=sha256:25a18e9810fbc7e7f03ec2516addc116a957f8cbb8cbc95ccc80faa072743d03 \ + --hash=sha256:26ef06c73eb53267c2b319f43e6634c7556ea37672029241a056629af27c10e2 \ + --hash=sha256:2e1743fbd7f5fa713a1b0a8ac8ebabf0ec980b5d8809ec358d488e273b9cf02a \ + --hash=sha256:32483fe2aab2c3794101c2a159070584e5db11d0aa091b2c0ea9c4fc43d0d749 \ + --hash=sha256:3bf0f392c0b806905ed174dcd8bdd5e418a40d5567a05615a030a5aeddea692d \ + --hash=sha256:3e2a27c89eb9ac3d81ec8835e12414d73536c6e620355d65102503064a4ed6eb \ + --hash=sha256:40ad3afe81676fd9ec6d9d406eda00933f218038433980aa19d401490e46ecde \ + --hash=sha256:4215d3a102bd95e2e11b5395c78562967959824156af11fa93d18fdd18050990 \ + --hash=sha256:45d59a649a82df5718fd9527ce775fd66d1af35e6d31abdcdc906a49c6822958 \ + --hash=sha256:45e0111e73f43f735d70786557dc38141185072d7ff8dc1829d6a77ac1471468 \ + --hash=sha256:479496325ce554792dba6548fae3df31a72cef7bad71ca2e12b0e58f9b336bfc \ + --hash=sha256:490fa6d203992c47c7b9e4a9d39003a0c2bcc1c9aa3c058730884bbbb0ee9f09 \ + --hash=sha256:49ce47231818806067aea3324d4bf13825b658ad662d3b25fada0bdad9b8a6af \ + --hash=sha256:4baf3cbe2f0be3289eb68ac8ae771156971848bb8aaff60bad42005539431980 \ + --hash=sha256:522175aba7af9113c48ec10cc471b9b9bd4f6ceb36aeb4544a8e2c80ed9d252d \ + --hash=sha256:5e8571632780e08526f118f74170ad8d50fb0a48c23a746bef2a6ebade3abd6f \ + --hash=sha256:615ba64c208aaceb5ec83bfdce7728b80bfeb8be97562944836a7a0a9647d882 \ + --hash=sha256:61f69297cba3950a524f61c7c8ee12e55c486cb5f7db47ff9dcee33da6f0d3ae \ + --hash=sha256:65a20de41e85648e00305c1bb09a3598f840422e522277641145a32d42dcefcc \ + --hash=sha256:6a15c17af8839b6801d554263c546c69c4d7718ad4321e3166175b37eaacca77 \ + --hash=sha256:747fa73efa9b8b1488a95d0ba1039c8e2dca0f741612d80415b1e1c560febf4e \ + --hash=sha256:7be78388d6da1a25c0d5ec506523db58b18be22d9c37d8d3a32c08be4987bd73 \ + --hash=sha256:81fd9652b37b36f16138611c7e884eb82e0cec137c40d3ef7c3f9b3ed00f6ed8 \ + --hash=sha256:83d57312a58dcfe2a3a0f9d1389b299438909a02db60e2f2ea2ae2d8034909d3 \ + --hash=sha256:8843114c0cfce61b40ad48df65abcfc00d4dba82eae8718fab5352390848c5da \ + --hash=sha256:8cc3309d8e08fd79089e13ed4819d0af72aa935dd8f435a195fd152796752ff2 \ + --hash=sha256:8ebe63ee5f8fa4296b1b8cfc743f870d10e902ca18afc65c68cf46fd39bb0783 \ + --hash=sha256:8eddfb4d203a237da6f3cc8a540dad0517d274b5a1e9e636fd8d2c79b5c1d397 \ + --hash=sha256:922fa70ba549fce362d2e2871ab542082d66e2aaf0c19480ea453905b01f384e \ + --hash=sha256:931091142fd8cc14edccc0845a79248bc155425eee9a98b2db2ea4f00a235a42 \ + --hash=sha256:971fd5a1d6e62e00d945423a567e42eb1fa678ba89072832185ca836a94daaa6 \ + --hash=sha256:980a846182ce88c4f2f7e2c22c56aefd515daeb36149d1c897f83cf57999e0b6 \ + --hash=sha256:9d9adda641db7207e800a7f089068f6f645959f2df27e870ee81d44701dd9db3 \ + --hash=sha256:9f8f757bebaaea112c00dba718fc0d3260052ce714e25804a03f93f5d1c6cc11 \ + --hash=sha256:a6ae758eb08088d36812dd5d9af7a9859c05b1e0f714470ea243694b49278e7b \ + --hash=sha256:a8c2cf1209497cf659a667d7dea88985e834c24b7c3b605e6254cbb5076d985c \ + --hash=sha256:acab0277c40eff7143c2323190ea57b9ee5fd353d8190ee9652369fae735668a \ + --hash=sha256:b331680e46239e090f5b3cead313cc772f6caa7d0fc8de349337563125361a4a \ + --hash=sha256:c088e7a90b6017307f423efbb9d1ba97a22aa2170876223f9709e9d1de0b5347 \ + --hash=sha256:d099566accf23d21037f18a2a63d323075bebace807742e4b0ac210971d4dd70 \ + --hash=sha256:d388087771c837cdb6515539f43b9d4bf0b0f23593a24054ac16f7a960be16f4 \ + --hash=sha256:dcfe41187da8992c5f40aa8c5ec086fa3672834d2be57a32384c08d5a05b4c00 \ + --hash=sha256:e6d1db20594d9daba22f90da738b1a0441a7427552cc6e2e3d1297aeddc00378 \ + --hash=sha256:ebea5cc3aa8ea72e04df9913492f9a96d9348db876f9dda3ad729cfedf7ac416 \ + --hash=sha256:ebebf83299b0cb1721a8859ea98f3a77811e35dce7609c5c963b9ad90728f886 \ + --hash=sha256:f0e34c2079d47ae9f6188211db9e777c619a21d4faba6977774e8fa43b085e48 \ + --hash=sha256:f92f88e6c033db65a5ae3d97905c8fea9c725b63e28d5a75cb73b49bda5024d8 \ + --hash=sha256:f9f7bd5faab55f47231ad8dba7787866b69f5e93bc306e3915606779bbfb4ba8 \ + --hash=sha256:fd5ef5932f6475c436c4a55e4336ebbe47bd3272be04964a03d316bbf4afbcbc \ + --hash=sha256:ff8a59ea85a1f2191a0ffcc61298c571bc566332f82e5f5be1b83c9d8e668a62 + # via + # tensorboard + # tensorflow +h5py==3.15.1 \ + --hash=sha256:01f55111ca516f5568ae7a7fc8247dfce607de331b4467ee8a9a6ed14e5422c7 \ + --hash=sha256:0e2f471688402c3404fa4e13466e373e622fd4b74b47b56cfdff7cc688209422 \ + --hash=sha256:121b2b7a4c1915d63737483b7bff14ef253020f617c2fb2811f67a4bed9ac5e8 \ + --hash=sha256:25c8843fec43b2cc368aa15afa1cdf83fc5e17b1c4e10cd3771ef6c39b72e5ce \ + --hash=sha256:28a20e1a4082a479b3d7db2169f3a5034af010b90842e75ebbf2e9e49eb4183e \ + --hash=sha256:2cbc4104d3d4aca9d6db8c0c694555e255805bfeacf9eb1349bda871e26cacbe \ + --hash=sha256:316dd0f119734f324ca7ed10b5627a2de4ea42cc4dfbcedbee026aaa361c238c \ + --hash=sha256:4411c1867b9899a25e983fff56d820a66f52ac326bbe10c7cdf7d832c9dcd883 \ + --hash=sha256:4c45802bcb711e128a6839cb6c01e9ac648dc55df045c9542a675c771f15c8d5 \ + --hash=sha256:550e51131376889656feec4aff2170efc054a7fe79eb1da3bb92e1625d1ac878 \ + --hash=sha256:59b0d63b318bf3cc06687def2b45afd75926bbc006f7b8cd2b1a231299fc8599 \ + --hash=sha256:59b25cf02411bf12e14f803fef0b80886444c7fe21a5ad17c6a28d3f08098a1e \ + --hash=sha256:5aaa330bcbf2830150c50897ea5dcbed30b5b6d56897289846ac5b9e529ec243 \ + --hash=sha256:5b849ba619a066196169763c33f9f0f02e381156d61c03e000bb0100f9950faf \ + --hash=sha256:5f4fb0567eb8517c3ecd6b3c02c4f4e9da220c8932604960fd04e24ee1254763 \ + --hash=sha256:61d5a58a9851e01ee61c932bbbb1c98fe20aba0a5674776600fb9a361c0aa652 \ + --hash=sha256:64ce3f6470adb87c06e3a8dd1b90e973699f1759ad79bfa70c230939bff356c9 \ + --hash=sha256:67e59f6c2f19a32973a40f43d9a088ae324fe228c8366e25ebc57ceebf093a6b \ + --hash=sha256:80e5bb5b9508d5d9da09f81fd00abbb3f85da8143e56b1585d59bc8ceb1dba8b \ + --hash=sha256:8a33bfd5dfcea037196f7778534b1ff7e36a7f40a89e648c8f2967292eb6898e \ + --hash=sha256:954e480433e82d3872503104f9b285d369048c3a788b2b1a00e53d1c47c98dd2 \ + --hash=sha256:99d374a21f7321a4c6ab327c4ab23bd925ad69821aeb53a1e75dd809d19f67fa \ + --hash=sha256:9c73d1d7cdb97d5b17ae385153472ce118bed607e43be11e9a9deefaa54e0734 \ + --hash=sha256:a308fd8681a864c04423c0324527237a0484e2611e3441f8089fd00ed56a8171 \ + --hash=sha256:a6d8c5a05a76aca9a494b4c53ce8a9c29023b7f64f625c6ce1841e92a362ccdf \ + --hash=sha256:ab2219dbc6fcdb6932f76b548e2b16f34a1f52b7666e998157a4dfc02e2c4123 \ + --hash=sha256:b39239947cb36a819147fc19e86b618dcb0953d1cd969f5ed71fc0de60392427 \ + --hash=sha256:b51469890e58e85d5242e43aab29f5e9c7e526b951caab354f3ded4ac88e7b76 \ + --hash=sha256:c256254a8a81e2bddc0d376e23e2a6d2dc8a1e8a2261835ed8c1281a0744cd97 \ + --hash=sha256:c8440fd8bee9500c235ecb7aa1917a0389a2adb80c209fa1cc485bd70e0d94a5 \ + --hash=sha256:c86e3ed45c4473564de55aa83b6fc9e5ead86578773dfbd93047380042e26b69 \ + --hash=sha256:c970fb80001fffabb0109eaf95116c8e7c0d3ca2de854e0901e8a04c1f098509 \ + --hash=sha256:ca8a3a22458956ee7b40d8e39c9a9dc01f82933e4c030c964f8b875592f4d831 \ + --hash=sha256:d8cb02c3a96255149ed3ac811eeea25b655d959c6dd5ce702c9a95ff11859eb5 \ + --hash=sha256:dea78b092fd80a083563ed79a3171258d4a4d307492e7cf8b2313d464c82ba52 \ + --hash=sha256:e02fe77a03f652500d8bff288cbf3675f742fc0411f5a628fa37116507dc7cc0 \ + --hash=sha256:e7f6c841efd4e6e5b7e82222eaf90819927b6d256ab0f3aca29675601f654f3c \ + --hash=sha256:f4a016df3f4a8a14d573b496e4d1964deb380e26031fc85fb40e417e9131888a \ + --hash=sha256:fa8df5267f545b4946df8ca0d93d23382191018e4cda2deda4c2cedf9a010e13 \ + --hash=sha256:fd125c131889ebbef0849f4a0e29cf363b48aba42f228d08b4079913b576bb3a + # via + # keras + # tensorflow +hypothesis==6.142.1 \ + --hash=sha256:3179cb08756562c526aaf4a9871ebbff83d2d75c03896ed0bc9c1d14097a930c \ + --hash=sha256:95a7d38fcc58e697e3020665adcb951c630cdbc8065e4b4474949e486b06bd6d # via -r build/test-requirements.txt -importlib-resources==6.4.0 \ - --hash=sha256:50d10f043df931902d4194ea07ec57960f66a80449ff867bfe782b4c486ba78c \ - --hash=sha256:cdb2b453b8046ca4e3798eb1d84f3cce1446a0e8e7b5ef4efb600f19fc398145 +idna==3.11 \ + --hash=sha256:771a87f49d9defaf64091e6e6fe9c18d4833f140bd19464795bc32d966ca37ea \ + --hash=sha256:795dafcc9c04ed0c1fb032c2aa73654d8e8c5023a7df64a53f39190ada629902 + # via requests +importlib-resources==6.5.2 \ + --hash=sha256:185f87adef5bcc288449d98fb4fba07cea78bc036455dd44c5fc4a2fe78fed2c \ + --hash=sha256:789cfdc3ed28c78b67a06acb8126751ced69a3d5f79c095a98298cd8a760ccec # via etils -iniconfig==2.0.0 \ - --hash=sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3 \ - --hash=sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374 +iniconfig==2.3.0 \ + --hash=sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730 \ + --hash=sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12 # via pytest -kiwisolver==1.4.5 \ - --hash=sha256:00bd361b903dc4bbf4eb165f24d1acbee754fce22ded24c3d56eec268658a5cf \ - --hash=sha256:040c1aebeda72197ef477a906782b5ab0d387642e93bda547336b8957c61022e \ - --hash=sha256:05703cf211d585109fcd72207a31bb170a0f22144d68298dc5e61b3c946518af \ - --hash=sha256:06f54715b7737c2fecdbf140d1afb11a33d59508a47bf11bb38ecf21dc9ab79f \ - --hash=sha256:0dc9db8e79f0036e8173c466d21ef18e1befc02de8bf8aa8dc0813a6dc8a7046 \ - --hash=sha256:0f114aa76dc1b8f636d077979c0ac22e7cd8f3493abbab152f20eb8d3cda71f3 \ - --hash=sha256:11863aa14a51fd6ec28688d76f1735f8f69ab1fabf388851a595d0721af042f5 \ - --hash=sha256:11c7de8f692fc99816e8ac50d1d1aef4f75126eefc33ac79aac02c099fd3db71 \ - --hash=sha256:11d011a7574eb3b82bcc9c1a1d35c1d7075677fdd15de527d91b46bd35e935ee \ - --hash=sha256:146d14bebb7f1dc4d5fbf74f8a6cb15ac42baadee8912eb84ac0b3b2a3dc6ac3 \ - --hash=sha256:15568384086b6df3c65353820a4473575dbad192e35010f622c6ce3eebd57af9 \ - --hash=sha256:19df6e621f6d8b4b9c4d45f40a66839294ff2bb235e64d2178f7522d9170ac5b \ - --hash=sha256:1b04139c4236a0f3aff534479b58f6f849a8b351e1314826c2d230849ed48985 \ - --hash=sha256:210ef2c3a1f03272649aff1ef992df2e724748918c4bc2d5a90352849eb40bea \ - --hash=sha256:2270953c0d8cdab5d422bee7d2007f043473f9d2999631c86a223c9db56cbd16 \ - --hash=sha256:2400873bccc260b6ae184b2b8a4fec0e4082d30648eadb7c3d9a13405d861e89 \ - --hash=sha256:2a40773c71d7ccdd3798f6489aaac9eee213d566850a9533f8d26332d626b82c \ - --hash=sha256:2c5674c4e74d939b9d91dda0fae10597ac7521768fec9e399c70a1f27e2ea2d9 \ - --hash=sha256:3195782b26fc03aa9c6913d5bad5aeb864bdc372924c093b0f1cebad603dd712 \ - --hash=sha256:31a82d498054cac9f6d0b53d02bb85811185bcb477d4b60144f915f3b3126342 \ - --hash=sha256:32d5cf40c4f7c7b3ca500f8985eb3fb3a7dfc023215e876f207956b5ea26632a \ - --hash=sha256:346f5343b9e3f00b8db8ba359350eb124b98c99efd0b408728ac6ebf38173958 \ - --hash=sha256:378a214a1e3bbf5ac4a8708304318b4f890da88c9e6a07699c4ae7174c09a68d \ - --hash=sha256:39b42c68602539407884cf70d6a480a469b93b81b7701378ba5e2328660c847a \ - --hash=sha256:3a2b053a0ab7a3960c98725cfb0bf5b48ba82f64ec95fe06f1d06c99b552e130 \ - --hash=sha256:3aba7311af82e335dd1e36ffff68aaca609ca6290c2cb6d821a39aa075d8e3ff \ - --hash=sha256:3cd32d6c13807e5c66a7cbb79f90b553642f296ae4518a60d8d76243b0ad2898 \ - --hash=sha256:3edd2fa14e68c9be82c5b16689e8d63d89fe927e56debd6e1dbce7a26a17f81b \ - --hash=sha256:4c380469bd3f970ef677bf2bcba2b6b0b4d5c75e7a020fb863ef75084efad66f \ - --hash=sha256:4e66e81a5779b65ac21764c295087de82235597a2293d18d943f8e9e32746265 \ - --hash=sha256:53abb58632235cd154176ced1ae8f0d29a6657aa1aa9decf50b899b755bc2b93 \ - --hash=sha256:5794cf59533bc3f1b1c821f7206a3617999db9fbefc345360aafe2e067514929 \ - --hash=sha256:59415f46a37f7f2efeec758353dd2eae1b07640d8ca0f0c42548ec4125492635 \ - --hash=sha256:59ec7b7c7e1a61061850d53aaf8e93db63dce0c936db1fda2658b70e4a1be709 \ - --hash=sha256:59edc41b24031bc25108e210c0def6f6c2191210492a972d585a06ff246bb79b \ - --hash=sha256:5a580c91d686376f0f7c295357595c5a026e6cbc3d77b7c36e290201e7c11ecb \ - --hash=sha256:5b94529f9b2591b7af5f3e0e730a4e0a41ea174af35a4fd067775f9bdfeee01a \ - --hash=sha256:5c7b3b3a728dc6faf3fc372ef24f21d1e3cee2ac3e9596691d746e5a536de920 \ - --hash=sha256:5c90ae8c8d32e472be041e76f9d2f2dbff4d0b0be8bd4041770eddb18cf49a4e \ - --hash=sha256:5e7139af55d1688f8b960ee9ad5adafc4ac17c1c473fe07133ac092310d76544 \ - --hash=sha256:5ff5cf3571589b6d13bfbfd6bcd7a3f659e42f96b5fd1c4830c4cf21d4f5ef45 \ - --hash=sha256:620ced262a86244e2be10a676b646f29c34537d0d9cc8eb26c08f53d98013390 \ - --hash=sha256:6512cb89e334e4700febbffaaa52761b65b4f5a3cf33f960213d5656cea36a77 \ - --hash=sha256:6c08e1312a9cf1074d17b17728d3dfce2a5125b2d791527f33ffbe805200a355 \ - --hash=sha256:6c3bd3cde54cafb87d74d8db50b909705c62b17c2099b8f2e25b461882e544ff \ - --hash=sha256:6ef7afcd2d281494c0a9101d5c571970708ad911d028137cd558f02b851c08b4 \ - --hash=sha256:7269d9e5f1084a653d575c7ec012ff57f0c042258bf5db0954bf551c158466e7 \ - --hash=sha256:72d40b33e834371fd330fb1472ca19d9b8327acb79a5821d4008391db8e29f20 \ - --hash=sha256:74d1b44c6cfc897df648cc9fdaa09bc3e7679926e6f96df05775d4fb3946571c \ - --hash=sha256:74db36e14a7d1ce0986fa104f7d5637aea5c82ca6326ed0ec5694280942d1162 \ - --hash=sha256:763773d53f07244148ccac5b084da5adb90bfaee39c197554f01b286cf869228 \ - --hash=sha256:76c6a5964640638cdeaa0c359382e5703e9293030fe730018ca06bc2010c4437 \ - --hash=sha256:76d9289ed3f7501012e05abb8358bbb129149dbd173f1f57a1bf1c22d19ab7cc \ - --hash=sha256:7931d8f1f67c4be9ba1dd9c451fb0eeca1a25b89e4d3f89e828fe12a519b782a \ - --hash=sha256:7b8b454bac16428b22560d0a1cf0a09875339cab69df61d7805bf48919415901 \ - --hash=sha256:7e5bab140c309cb3a6ce373a9e71eb7e4873c70c2dda01df6820474f9889d6d4 \ - --hash=sha256:83d78376d0d4fd884e2c114d0621624b73d2aba4e2788182d286309ebdeed770 \ - --hash=sha256:852542f9481f4a62dbb5dd99e8ab7aedfeb8fb6342349a181d4036877410f525 \ - --hash=sha256:85267bd1aa8880a9c88a8cb71e18d3d64d2751a790e6ca6c27b8ccc724bcd5ad \ - --hash=sha256:88a2df29d4724b9237fc0c6eaf2a1adae0cdc0b3e9f4d8e7dc54b16812d2d81a \ - --hash=sha256:88b9f257ca61b838b6f8094a62418421f87ac2a1069f7e896c36a7d86b5d4c29 \ - --hash=sha256:8ab3919a9997ab7ef2fbbed0cc99bb28d3c13e6d4b1ad36e97e482558a91be90 \ - --hash=sha256:92dea1ffe3714fa8eb6a314d2b3c773208d865a0e0d35e713ec54eea08a66250 \ - --hash=sha256:9407b6a5f0d675e8a827ad8742e1d6b49d9c1a1da5d952a67d50ef5f4170b18d \ - --hash=sha256:9408acf3270c4b6baad483865191e3e582b638b1654a007c62e3efe96f09a9a3 \ - --hash=sha256:955e8513d07a283056b1396e9a57ceddbd272d9252c14f154d450d227606eb54 \ - --hash=sha256:9db8ea4c388fdb0f780fe91346fd438657ea602d58348753d9fb265ce1bca67f \ - --hash=sha256:9eaa8b117dc8337728e834b9c6e2611f10c79e38f65157c4c38e9400286f5cb1 \ - --hash=sha256:a51a263952b1429e429ff236d2f5a21c5125437861baeed77f5e1cc2d2c7c6da \ - --hash=sha256:a6aa6315319a052b4ee378aa171959c898a6183f15c1e541821c5c59beaa0238 \ - --hash=sha256:aa12042de0171fad672b6c59df69106d20d5596e4f87b5e8f76df757a7c399aa \ - --hash=sha256:aaf7be1207676ac608a50cd08f102f6742dbfc70e8d60c4db1c6897f62f71523 \ - --hash=sha256:b0157420efcb803e71d1b28e2c287518b8808b7cf1ab8af36718fd0a2c453eb0 \ - --hash=sha256:b3f7e75f3015df442238cca659f8baa5f42ce2a8582727981cbfa15fee0ee205 \ - --hash=sha256:b9098e0049e88c6a24ff64545cdfc50807818ba6c1b739cae221bbbcbc58aad3 \ - --hash=sha256:ba55dce0a9b8ff59495ddd050a0225d58bd0983d09f87cfe2b6aec4f2c1234e4 \ - --hash=sha256:bb86433b1cfe686da83ce32a9d3a8dd308e85c76b60896d58f082136f10bffac \ - --hash=sha256:bbea0db94288e29afcc4c28afbf3a7ccaf2d7e027489c449cf7e8f83c6346eb9 \ - --hash=sha256:bbf1d63eef84b2e8c89011b7f2235b1e0bf7dacc11cac9431fc6468e99ac77fb \ - --hash=sha256:c7940c1dc63eb37a67721b10d703247552416f719c4188c54e04334321351ced \ - --hash=sha256:c9bf3325c47b11b2e51bca0824ea217c7cd84491d8ac4eefd1e409705ef092bd \ - --hash=sha256:cdc8a402aaee9a798b50d8b827d7ecf75edc5fb35ea0f91f213ff927c15f4ff0 \ - --hash=sha256:ceec1a6bc6cab1d6ff5d06592a91a692f90ec7505d6463a88a52cc0eb58545da \ - --hash=sha256:cfe6ab8da05c01ba6fbea630377b5da2cd9bcbc6338510116b01c1bc939a2c18 \ - --hash=sha256:d099e745a512f7e3bbe7249ca835f4d357c586d78d79ae8f1dcd4d8adeb9bda9 \ - --hash=sha256:d0ef46024e6a3d79c01ff13801cb19d0cad7fd859b15037aec74315540acc276 \ - --hash=sha256:d2e5a98f0ec99beb3c10e13b387f8db39106d53993f498b295f0c914328b1333 \ - --hash=sha256:da4cfb373035def307905d05041c1d06d8936452fe89d464743ae7fb8371078b \ - --hash=sha256:da802a19d6e15dffe4b0c24b38b3af68e6c1a68e6e1d8f30148c83864f3881db \ - --hash=sha256:dced8146011d2bc2e883f9bd68618b8247387f4bbec46d7392b3c3b032640126 \ - --hash=sha256:dfdd7c0b105af050eb3d64997809dc21da247cf44e63dc73ff0fd20b96be55a9 \ - --hash=sha256:e368f200bbc2e4f905b8e71eb38b3c04333bddaa6a2464a6355487b02bb7fb09 \ - --hash=sha256:e391b1f0a8a5a10ab3b9bb6afcfd74f2175f24f8975fb87ecae700d1503cdee0 \ - --hash=sha256:e57e563a57fb22a142da34f38acc2fc1a5c864bc29ca1517a88abc963e60d6ec \ - --hash=sha256:e5d706eba36b4c4d5bc6c6377bb6568098765e990cfc21ee16d13963fab7b3e7 \ - --hash=sha256:ec20916e7b4cbfb1f12380e46486ec4bcbaa91a9c448b97023fde0d5bbf9e4ff \ - --hash=sha256:f1d072c2eb0ad60d4c183f3fb44ac6f73fb7a8f16a2694a91f988275cbf352f9 \ - --hash=sha256:f846c260f483d1fd217fe5ed7c173fb109efa6b1fc8381c8b7552c5781756192 \ - --hash=sha256:f91de7223d4c7b793867797bacd1ee53bfe7359bd70d27b7b58a04efbb9436c8 \ - --hash=sha256:faae4860798c31530dd184046a900e652c95513796ef51a12bc086710c2eec4d \ - --hash=sha256:fc579bf0f502e54926519451b920e875f433aceb4624a3646b3252b5caa9e0b6 \ - --hash=sha256:fcc700eadbbccbf6bc1bcb9dbe0786b4b1cb91ca0dcda336eef5c2beed37b797 \ - --hash=sha256:fd32ea360bcbb92d28933fc05ed09bffcb1704ba3fc7942e81db0fd4f81a7892 \ - --hash=sha256:fdb7adb641a0d13bdcd4ef48e062363d8a9ad4a182ac7647ec88f695e719ae9f +jax-cuda12-pjrt==0.8.2 ; sys_platform == "linux" \ + --hash=sha256:717a1b196a642409ce195ddf031c20bbeadcc886f55e49a1d3f4927373aeedae \ + --hash=sha256:e3bab41ca7c48e4163db9e7efd271b3aa85f0fe45f5ed0708d6bbed93a59f977 + # via + # -r build/requirements.in + # jax-cuda12-plugin +jax-cuda12-plugin==0.8.2 ; sys_platform == "linux" and python_version < "3.14" \ + --hash=sha256:0b0a3304ce7e494acd8d9c593490c112a32cdb6010fe1afc584d9e41fd863167 \ + --hash=sha256:1b4828242d57f233b394d17ebaa599c503c1fb9b7c754012a06eb84dbc935fc8 \ + --hash=sha256:20165861b3d3e66ebb2c0f63a547d1d5ee17ea44ac3be7153c7908c9ca8c88f3 \ + --hash=sha256:377e4be17e22dde0343b3f3c05bf69235b3dbf11d766cca9c5a93da47971dcb7 \ + --hash=sha256:403d5e07731b5cdac3bd9fb3f448bd8480062cb2c0ab61ea2ad23fcd0a65479a \ + --hash=sha256:58c51473fc622e03138035985f741833564d70a4bd5a2178f61b62cdaa32ff94 \ + --hash=sha256:637387dc3408cd204562668502f9e95f76c6edde0a6d2e48f055162dc2aebf0d \ + --hash=sha256:70d33222484ad5c375b8f8357b7c23cacb844f6ecfc39567f8dd47fde6e87858 \ + --hash=sha256:82c6798be66bf8c773386918e4c8e5cd8119753f3bfb3ca4bbc46818283750c6 \ + --hash=sha256:a5898bac1d8ab6020b54546440256409f2c66bcbbb3a1099ca473c84843addad \ + --hash=sha256:d68a6d8b4a45ee561746bac7a6468da8203832626b0b39ad4ac43011f61f875d \ + --hash=sha256:dd4f7c34d4512ff5a36fd1b01584ef7781cad615e3f9e71880eae2f4998e5108 + # via -r build/requirements.in +jax-cuda13-pjrt==0.8.2 \ + --hash=sha256:370e9cb9a76af6a6b00e8f6c6ae8e5a0158886b65994a114cf3acf56ea1570f3 \ + --hash=sha256:c0ebc284d4bea9c0b278017b4bb62eb871755fe582cbf22ce1fb770d30fb9af6 + # via + # -r build/requirements.in + # jax-cuda13-plugin +jax-cuda13-plugin==0.8.2 \ + --hash=sha256:0ab3a13f798f973416754cb00b1696a8e6d2425c39496e2145293661bcd28446 \ + --hash=sha256:2c7532f2a6c9dcf50ba4b8ebbe4c78988b8c260695e50f9fdbd68f9417a8bb51 \ + --hash=sha256:358f4b4d0d2b92f62fe7f07d355a9c57c5684da3d4e6f483afec38b23e66f02e \ + --hash=sha256:6703552ee10b6b41b43fe4fc62b4672bec13d271eb709a29bf14b0fa1ec805dd \ + --hash=sha256:8fe2c5874cbae18a01a0fa722dfbd547b5504e7bf42aaaae6ad3e1b676e8daa2 \ + --hash=sha256:950f2c1f2e0f5e13893c4c27a4df540a96a26ebd15d7954fa85ffb8cc8b1d9a1 \ + --hash=sha256:9835076d0a917345700370cfaee5cfd7876c0b0982e99206a8b4848af0ea6e5e \ + --hash=sha256:bb059d99eb960718358f82f3eaa0829f96fc7d7e978ded821f2c40256a8aee4b \ + --hash=sha256:d5f9470db46aee8a64e9ab7817927335b77c6d3698d5113e0da2dd962c1468c4 \ + --hash=sha256:f1643b216ccb41a2d7c0e240359c729c96c5eca76e4525ffcd6313049e348b9a \ + --hash=sha256:f35f687d9b839365a49c2b0052eaeb6b9f0a9879813b4fa020f8d9caa3ddde46 \ + --hash=sha256:fa6a6faf3130d6ef05f9ee6cbc27e0a2e4db0f23720136403991eccb3a0144af + # via -r build/requirements.in +jaxlib==0.8.2 \ + --hash=sha256:023de6f3f56da2af7037970996500586331fdb50b530ecbb54b9666da633bd00 \ + --hash=sha256:05b958f497e49824c432e734bb059723b7dfe69e2ad696a9f9c8ad82fff7c3f8 \ + --hash=sha256:1bfbcf6c3de221784fa4cdb6765a09d71cb4298b15626b3d0409b3dfcd8a8667 \ + --hash=sha256:28eec1a4e0639a0d8702cea3cb70dd3663053dbfa344452994ea48dc6ceadaa5 \ + --hash=sha256:2b9789bd08f8b0cc5a5c12ae896fe432d5942e32e417091b8b5a96a9a6fd5cf1 \ + --hash=sha256:3b16e50c5b730c9dd0a49e55f1acfaa722b00b1af0522a591558dcc0464252f2 \ + --hash=sha256:490bf0cb029c73c65c9431124b86cdc95082dbc1fb76fc549d24d75da33e5454 \ + --hash=sha256:4d006db96be020c8165212a1216372f8acac4ff4f8fb067743d694ef2b301ace \ + --hash=sha256:68108dff0de74adc468016be9a19f80efe48c660c0d5a122287094b44b092afc \ + --hash=sha256:7c304f3a016965b9d1f5239a8a0399a73925f5604fe914c5ca66ecf734bf6422 \ + --hash=sha256:7da8127557c786264049ae55460d1b8d04cc3cdf0403a087f2fc1e6d313ec722 \ + --hash=sha256:964626f581beab31ee6826b228fcc2ec5181b05cecf94a528dff97921c145dbc \ + --hash=sha256:a397ea7dcb37d689ce79173eeb99b2f1347637a36be9a27f20ae6848bfc58bfc \ + --hash=sha256:aa8701b6356f098e8452c3cec762fb5f706fcb8f67ffd65964f63982479aa23b \ + --hash=sha256:bb89be452b1b808d3f88fc01c415b364a260be4cc7ac120c038009f6150a32dc \ + --hash=sha256:beffb004e7eeb5c9afb24439e2b2cf45a4ee3e3e8adf45e355edf2af62acf8b8 \ + --hash=sha256:ccf77da917a20935247c990691decfcbdd06c25ef0ac94d914a04aadb22f714c \ + --hash=sha256:dffc22b5b732b9556d92c918b251c61bcc046617c4dbb51e1f7a656587fddffb \ + --hash=sha256:e6a97dfb0232eed9a2bb6e3828e4f682dbac1a7fea840bfda574cae2dbf5faf9 \ + --hash=sha256:f205e91c3a152a2a76c0bc59a6a2de03e87ec261b91e8812922777185e7b08f5 \ + --hash=sha256:f28edac8c226fc07fa3e8af6f9defede8ac2c307429e3291edce8739d39becc9 \ + --hash=sha256:f472cc72e3058e50b5f0230b236d5a1183bf6c3d5423d2a52eff07bcf34908de + # via -r build/requirements.in +keras==3.12.0 \ + --hash=sha256:02b69e007d5df8042286c3bcc2a888539e3e487590ffb08f6be1b4354df50aa8 \ + --hash=sha256:536e3f8385a05ae04e82e08715a1a59988578087e187b04cb0a6fad11743f07f + # via tensorflow +kiwisolver==1.4.9 \ + --hash=sha256:0749fd8f4218ad2e851e11cc4dc05c7cbc0cbc4267bdfdb31782e65aace4ee9c \ + --hash=sha256:0763515d4df10edf6d06a3c19734e2566368980d21ebec439f33f9eb936c07b7 \ + --hash=sha256:0856e241c2d3df4efef7c04a1e46b1936b6120c9bcf36dd216e3acd84bc4fb21 \ + --hash=sha256:0a590506f303f512dff6b7f75fd2fd18e16943efee932008fe7140e5fa91d80e \ + --hash=sha256:0ab74e19f6a2b027ea4f845a78827969af45ce790e6cb3e1ebab71bdf9f215ff \ + --hash=sha256:0ae37737256ba2de764ddc12aed4956460277f00c4996d51a197e72f62f5eec7 \ + --hash=sha256:0e4e2bf29574a6a7b7f6cb5fa69293b9f96c928949ac4a53ba3f525dffb87f9c \ + --hash=sha256:15163165efc2f627eb9687ea5f3a28137217d217ac4024893d753f46bce9de26 \ + --hash=sha256:17680d737d5335b552994a2008fab4c851bcd7de33094a82067ef3a576ff02fa \ + --hash=sha256:1a12cf6398e8a0a001a059747a1cbf24705e18fe413bc22de7b3d15c67cffe3f \ + --hash=sha256:1b11d6a633e4ed84fc0ddafd4ebfd8ea49b3f25082c04ad12b8315c11d504dc1 \ + --hash=sha256:1fa333e8b2ce4d9660f2cda9c0e1b6bafcfb2457a9d259faa82289e73ec24891 \ + --hash=sha256:2327a4a30d3ee07d2fbe2e7933e8a37c591663b96ce42a00bc67461a87d7df77 \ + --hash=sha256:2405a7d98604b87f3fc28b1716783534b1b4b8510d8142adca34ee0bc3c87543 \ + --hash=sha256:2489e4e5d7ef9a1c300a5e0196e43d9c739f066ef23270607d45aba368b91f2d \ + --hash=sha256:24c175051354f4a28c5d6a31c93906dc653e2bf234e8a4bbfb964892078898ce \ + --hash=sha256:2635d352d67458b66fd0667c14cb1d4145e9560d503219034a18a87e971ce4f3 \ + --hash=sha256:2c1a4f57df73965f3f14df20b80ee29e6a7930a57d2d9e8491a25f676e197c60 \ + --hash=sha256:2c93f00dcba2eea70af2be5f11a830a742fe6b579a1d4e00f47760ef13be247a \ + --hash=sha256:39a219e1c81ae3b103643d2aedb90f1ef22650deb266ff12a19e7773f3e5f089 \ + --hash=sha256:3b3115b2581ea35bb6d1f24a4c90af37e5d9b49dcff267eeed14c3893c5b86ab \ + --hash=sha256:40092754720b174e6ccf9e845d0d8c7d8e12c3d71e7fc35f55f3813e96376f78 \ + --hash=sha256:412f287c55a6f54b0650bd9b6dce5aceddb95864a1a90c87af16979d37c89771 \ + --hash=sha256:464415881e4801295659462c49461a24fb107c140de781d55518c4b80cb6790f \ + --hash=sha256:497d05f29a1300d14e02e6441cf0f5ee81c1ff5a304b0d9fb77423974684e08b \ + --hash=sha256:4a2899935e724dd1074cb568ce7ac0dce28b2cd6ab539c8e001a8578eb106d14 \ + --hash=sha256:4a48a2ce79d65d363597ef7b567ce3d14d68783d2b2263d98db3d9477805ba32 \ + --hash=sha256:4d1d9e582ad4d63062d34077a9a1e9f3c34088a2ec5135b1f7190c07cf366527 \ + --hash=sha256:52a15b0f35dad39862d376df10c5230155243a2c1a436e39eb55623ccbd68185 \ + --hash=sha256:540c7c72324d864406a009d72f5d6856f49693db95d1fbb46cf86febef873634 \ + --hash=sha256:5656aa670507437af0207645273ccdfee4f14bacd7f7c67a4306d0dcaeaf6eed \ + --hash=sha256:5a0f2724dfd4e3b3ac5a82436a8e6fd16baa7d507117e4279b660fe8ca38a3a1 \ + --hash=sha256:60c439763a969a6af93b4881db0eed8fadf93ee98e18cbc35bc8da868d0c4f0c \ + --hash=sha256:61874cdb0a36016354853593cffc38e56fc9ca5aa97d2c05d3dcf6922cd55a11 \ + --hash=sha256:67bb8b474b4181770f926f7b7d2f8c0248cbcb78b660fdd41a47054b28d2a752 \ + --hash=sha256:720e05574713db64c356e86732c0f3c5252818d05f9df320f0ad8380641acea5 \ + --hash=sha256:72d0eb9fba308b8311685c2268cf7d0a0639a6cd027d8128659f72bdd8a024b4 \ + --hash=sha256:767c23ad1c58c9e827b649a9ab7809fd5fd9db266a9cf02b0e926ddc2c680d58 \ + --hash=sha256:77937e5e2a38a7b48eef0585114fe7930346993a88060d0bf886086d2aa49ef5 \ + --hash=sha256:7a08b491ec91b1d5053ac177afe5290adacf1f0f6307d771ccac5de30592d198 \ + --hash=sha256:7b4da0d01ac866a57dd61ac258c5607b4cd677f63abaec7b148354d2b2cdd536 \ + --hash=sha256:7cf974dd4e35fa315563ac99d6287a1024e4dc2077b8a7d7cd3d2fb65d283134 \ + --hash=sha256:84fd60810829c27ae375114cd379da1fa65e6918e1da405f356a775d49a62bcf \ + --hash=sha256:858e4c22fb075920b96a291928cb7dea5644e94c0ee4fcd5af7e865655e4ccf2 \ + --hash=sha256:85b5352f94e490c028926ea567fc569c52ec79ce131dadb968d3853e809518c2 \ + --hash=sha256:85bd218b5ecfbee8c8a82e121802dcb519a86044c9c3b2e4aef02fa05c6da370 \ + --hash=sha256:8a1f570ce4d62d718dce3f179ee78dac3b545ac16c0c04bb363b7607a949c0d1 \ + --hash=sha256:8fdca1def57a2e88ef339de1737a1449d6dbf5fab184c54a1fca01d541317154 \ + --hash=sha256:90f47e70293fc3688b71271100a1a5453aa9944a81d27ff779c108372cf5567b \ + --hash=sha256:92a2f997387a1b79a75e7803aa7ded2cfbe2823852ccf1ba3bcf613b62ae3197 \ + --hash=sha256:9928fe1eb816d11ae170885a74d074f57af3a0d65777ca47e9aeb854a1fba386 \ + --hash=sha256:9af39d6551f97d31a4deebeac6f45b156f9755ddc59c07b402c148f5dbb6482a \ + --hash=sha256:9cf554f21be770f5111a1690d42313e140355e687e05cf82cb23d0a721a64a48 \ + --hash=sha256:a30fd6fdef1430fd9e1ba7b3398b5ee4e2887783917a687d86ba69985fb08748 \ + --hash=sha256:a31d512c812daea6d8b3be3b2bfcbeb091dbb09177706569bcfc6240dcf8b41c \ + --hash=sha256:a5d0432ccf1c7ab14f9949eec60c5d1f924f17c037e9f8b33352fa05799359b8 \ + --hash=sha256:a60ea74330b91bd22a29638940d115df9dc00af5035a9a2a6ad9399ffb4ceca5 \ + --hash=sha256:ac5a486ac389dddcc5bef4f365b6ae3ffff2c433324fb38dd35e3fab7c957999 \ + --hash=sha256:aedff62918805fb62d43a4aa2ecd4482c380dc76cd31bd7c8878588a61bd0369 \ + --hash=sha256:b34e51affded8faee0dfdb705416153819d8ea9250bbbf7ea1b249bdeb5f1122 \ + --hash=sha256:b4b4d74bda2b8ebf4da5bd42af11d02d04428b2c32846e4c2c93219df8a7987b \ + --hash=sha256:b67e6efbf68e077dd71d1a6b37e43e1a99d0bff1a3d51867d45ee8908b931098 \ + --hash=sha256:b78efa4c6e804ecdf727e580dbb9cba85624d2e1c6b5cb059c66290063bd99a9 \ + --hash=sha256:bb4ae2b57fc1d8cbd1cf7b1d9913803681ffa903e7488012be5b76dedf49297f \ + --hash=sha256:bdd1a81a1860476eb41ac4bc1e07b3f07259e6d55bbf739b79c8aaedcf512799 \ + --hash=sha256:bdee92c56a71d2b24c33a7d4c2856bd6419d017e08caa7802d2963870e315028 \ + --hash=sha256:be6a04e6c79819c9a8c2373317d19a96048e5a3f90bec587787e86a1153883c2 \ + --hash=sha256:bfc08add558155345129c7803b3671cf195e6a56e7a12f3dde7c57d9b417f525 \ + --hash=sha256:c3b22c26c6fd6811b0ae8363b95ca8ce4ea3c202d3d0975b2914310ceb1bcc4d \ + --hash=sha256:c9e7cdf45d594ee04d5be1b24dd9d49f3d1590959b2271fb30b5ca2b262c00fb \ + --hash=sha256:cb27e7b78d716c591e88e0a09a2139c6577865d7f2e152488c2cc6257f460872 \ + --hash=sha256:cc9617b46837c6468197b5945e196ee9ca43057bb7d9d1ae688101e4e1dddf64 \ + --hash=sha256:ccd09f20ccdbbd341b21a67ab50a119b64a403b09288c27481575105283c1586 \ + --hash=sha256:ce6a3a4e106cf35c2d9c4fa17c05ce0b180db622736845d4315519397a77beaf \ + --hash=sha256:d0005b053977e7b43388ddec89fa567f43d4f6d5c2c0affe57de5ebf290dc552 \ + --hash=sha256:d4188e73af84ca82468f09cadc5ac4db578109e52acb4518d8154698d3a87ca2 \ + --hash=sha256:d4efec7bcf21671db6a3294ff301d2fc861c31faa3c8740d1a94689234d1b415 \ + --hash=sha256:d75aa530ccfaa593da12834b86a0724f58bff12706659baa9227c2ccaa06264c \ + --hash=sha256:d84cd4061ae292d8ac367b2c3fa3aad11cb8625a95d135fe93f286f914f3f5a6 \ + --hash=sha256:d8aacd3d4b33b772542b2e01beb50187536967b514b00003bdda7589722d2a64 \ + --hash=sha256:d8fc5c867c22b828001b6a38d2eaeb88160bf5783c6cb4a5e440efc981ce286d \ + --hash=sha256:d976bbb382b202f71c67f77b0ac11244021cfa3f7dfd9e562eefcea2df711548 \ + --hash=sha256:dba5ee5d3981160c28d5490f0d1b7ed730c22470ff7f6cc26cfcfaacb9896a07 \ + --hash=sha256:dc1ae486f9abcef254b5618dfb4113dd49f94c68e3e027d03cf0143f3f772b61 \ + --hash=sha256:dd0a578400839256df88c16abddf9ba14813ec5f21362e1fe65022e00c883d4d \ + --hash=sha256:deed0c7258ceb4c44ad5ec7d9918f9f14fd05b2be86378d86cf50e63d1e7b771 \ + --hash=sha256:e09c2279a4d01f099f52d5c4b3d9e208e91edcbd1a175c9662a8b16e000fece9 \ + --hash=sha256:e2ea9f7ab7fbf18fffb1b5434ce7c69a07582f7acc7717720f1d69f3e806f90c \ + --hash=sha256:e6b93f13371d341afee3be9f7c5964e3fe61d5fa30f6a30eb49856935dfe4fc3 \ + --hash=sha256:eb14a5da6dc7642b0f3a18f13654847cd8b7a2550e2645a5bda677862b03ba16 \ + --hash=sha256:ed0fecd28cc62c54b262e3736f8bb2512d8dcfdc2bcf08be5f47f96bf405b145 \ + --hash=sha256:ede8c6d533bc6601a47ad4046080d36b8fc99f81e6f1c17b0ac3c2dc91ac7611 \ + --hash=sha256:efb3a45b35622bb6c16dbfab491a8f5a391fe0e9d45ef32f4df85658232ca0e2 \ + --hash=sha256:f117e1a089d9411663a3207ba874f31be9ac8eaa5b533787024dc07aeb74f464 \ + --hash=sha256:f2ba92255faa7309d06fe44c3a4a97efe1c8d640c2a79a5ef728b685762a6fd2 \ + --hash=sha256:f6008a4919fdbc0b0097089f67a1eb55d950ed7e90ce2cc3e640abadd2757a04 \ + --hash=sha256:f68208a520c3d86ea51acf688a3e3002615a7f0238002cccc17affecc86a8a54 \ + --hash=sha256:f68e4f3eeca8fb22cc3d731f9715a13b652795ef657a13df1ad0c7dc0e9731df \ + --hash=sha256:fb3b8132019ea572f4611d770991000d7f58127560c4889729248eb5852a102f \ + --hash=sha256:fb940820c63a9590d31d88b815e7a3aa5915cad3ce735ab45f0c730b39547de1 \ + --hash=sha256:fc1795ac5cd0510207482c3d1d3ed781143383b8cfd36f5c645f3897ce066220 # via matplotlib -markdown-it-py==3.0.0 \ - --hash=sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 \ - --hash=sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb +libclang==18.1.1 \ + --hash=sha256:0b2e143f0fac830156feb56f9231ff8338c20aecfe72b4ffe96f19e5a1dbb69a \ + --hash=sha256:3f0e1f49f04d3cd198985fea0511576b0aee16f9ff0e0f0cad7f9c57ec3c20e8 \ + --hash=sha256:4dd2d3b82fab35e2bf9ca717d7b63ac990a3519c7e312f19fa8e86dcc712f7fb \ + --hash=sha256:54dda940a4a0491a9d1532bf071ea3ef26e6dbaf03b5000ed94dd7174e8f9592 \ + --hash=sha256:69f8eb8f65c279e765ffd28aaa7e9e364c776c17618af8bff22a8df58677ff4f \ + --hash=sha256:6f14c3f194704e5d09769108f03185fce7acaf1d1ae4bbb2f30a72c2400cb7c5 \ + --hash=sha256:83ce5045d101b669ac38e6da8e58765f12da2d3aafb3b9b98d88b286a60964d8 \ + --hash=sha256:a1214966d08d73d971287fc3ead8dfaf82eb07fb197680d8b3859dbbbbf78250 \ + --hash=sha256:c533091d8a3bbf7460a00cb6c1a71da93bffe148f172c7d03b1c31fbf8aa2a0b \ + --hash=sha256:cf4a99b05376513717ab5d82a0db832c56ccea4fd61a69dbb7bccf2dfb207dbe + # via tensorflow +libtpu==0.0.32 ; sys_platform == "linux" and platform_machine == "x86_64" \ + --hash=sha256:37f6aefe6d69d24e5e268e74c1e90cd0ce0fa6a796720709445d8938605912ee \ + --hash=sha256:4a4ae6db0a90f0e33902cd345923a5d77dfc5bbab9a3e524c79bf876858104ee \ + --hash=sha256:6453995774b902e43d1caf0d82298a2e90f2f731ddd74d6bef9510d91732775f \ + --hash=sha256:7ffd9cdb89da22d32bb60ce70b88e9e76861c5d704ebf673f32987e19e9913f3 \ + --hash=sha256:98d9713a39d9581c9b10a7bdeaa575cc874a4a7f543cfb40d3a723ee27b428b5 \ + --hash=sha256:d2a0e7abe4a029b79049bc95da20e4c2a64f8ef09823f75286cecfbb4b0b6da1 + # via -r build/requirements.in +markdown==3.10 \ + --hash=sha256:37062d4f2aa4b2b6b32aefb80faa300f82cc790cb949a35b8caede34f2b68c0e \ + --hash=sha256:b5b99d6951e2e4948d939255596523444c0e677c669700b1d17aa4a8a464cb7c + # via tensorboard +markdown-it-py==4.0.0 \ + --hash=sha256:87327c59b172c5011896038353a81343b6754500a08cd7a4973bb48c6d578147 \ + --hash=sha256:cb0a2b4aa34f932c007117b194e945bd74e0ec24133ceb5bac59009cda1cb9f3 # via rich -matplotlib==3.9.0 ; python_version >= "3.11" \ - --hash=sha256:063af8587fceeac13b0936c42a2b6c732c2ab1c98d38abc3337e430e1ff75e38 \ - --hash=sha256:06a478f0d67636554fa78558cfbcd7b9dba85b51f5c3b5a0c9be49010cf5f321 \ - --hash=sha256:0a490715b3b9984fa609116481b22178348c1a220a4499cda79132000a79b4db \ - --hash=sha256:0fc51eaa5262553868461c083d9adadb11a6017315f3a757fc45ec6ec5f02888 \ - --hash=sha256:13beb4840317d45ffd4183a778685e215939be7b08616f431c7795276e067463 \ - --hash=sha256:290d304e59be2b33ef5c2d768d0237f5bd132986bdcc66f80bc9bcc300066a03 \ - --hash=sha256:2bcee1dffaf60fe7656183ac2190bd630842ff87b3153afb3e384d966b57fe56 \ - --hash=sha256:2e7f03e5cbbfacdd48c8ea394d365d91ee8f3cae7e6ec611409927b5ed997ee4 \ - --hash=sha256:3f988bafb0fa39d1074ddd5bacd958c853e11def40800c5824556eb630f94d3b \ - --hash=sha256:52146fc3bd7813cc784562cb93a15788be0b2875c4655e2cc6ea646bfa30344b \ - --hash=sha256:550cdda3adbd596078cca7d13ed50b77879104e2e46392dcd7c75259d8f00e85 \ - --hash=sha256:616fabf4981a3b3c5a15cd95eba359c8489c4e20e03717aea42866d8d0465956 \ - --hash=sha256:76cce0f31b351e3551d1f3779420cf8f6ec0d4a8cf9c0237a3b549fd28eb4abb \ - --hash=sha256:7ff2e239c26be4f24bfa45860c20ffccd118d270c5b5d081fa4ea409b5469fcd \ - --hash=sha256:8146ce83cbc5dc71c223a74a1996d446cd35cfb6a04b683e1446b7e6c73603b7 \ - --hash=sha256:81c40af649d19c85f8073e25e5806926986806fa6d54be506fbf02aef47d5a89 \ - --hash=sha256:9a2fa6d899e17ddca6d6526cf6e7ba677738bf2a6a9590d702c277204a7c6152 \ - --hash=sha256:a5be985db2596d761cdf0c2eaf52396f26e6a64ab46bd8cd810c48972349d1be \ - --hash=sha256:af4001b7cae70f7eaacfb063db605280058246de590fa7874f00f62259f2df7e \ - --hash=sha256:bd4f2831168afac55b881db82a7730992aa41c4f007f1913465fb182d6fb20c0 \ - --hash=sha256:bdd1ecbe268eb3e7653e04f451635f0fb0f77f07fd070242b44c076c9106da84 \ - --hash=sha256:c53aeb514ccbbcbab55a27f912d79ea30ab21ee0531ee2c09f13800efb272674 \ - --hash=sha256:c79f3a585f1368da6049318bdf1f85568d8d04b2e89fc24b7e02cc9b62017382 \ - --hash=sha256:cd53c79fd02f1c1808d2cfc87dd3cf4dbc63c5244a58ee7944497107469c8d8a \ - --hash=sha256:d38e85a1a6d732f645f1403ce5e6727fd9418cd4574521d5803d3d94911038e5 \ - --hash=sha256:d91a4ffc587bacf5c4ce4ecfe4bcd23a4b675e76315f2866e588686cc97fccdf \ - --hash=sha256:e6d29ea6c19e34b30fb7d88b7081f869a03014f66fe06d62cc77d5a6ea88ed7a \ - --hash=sha256:eaf3978060a106fab40c328778b148f590e27f6fa3cd15a19d6892575bce387d \ - --hash=sha256:fe428e191ea016bb278758c8ee82a8129c51d81d8c4bc0846c09e7e8e9057241 +markupsafe==3.0.3 \ + --hash=sha256:0303439a41979d9e74d18ff5e2dd8c43ed6c6001fd40e5bf2e43f7bd9bbc523f \ + --hash=sha256:068f375c472b3e7acbe2d5318dea141359e6900156b5b2ba06a30b169086b91a \ + --hash=sha256:0bf2a864d67e76e5c9a34dc26ec616a66b9888e25e7b9460e1c76d3293bd9dbf \ + --hash=sha256:0db14f5dafddbb6d9208827849fad01f1a2609380add406671a26386cdf15a19 \ + --hash=sha256:0eb9ff8191e8498cca014656ae6b8d61f39da5f95b488805da4bb029cccbfbaf \ + --hash=sha256:0f4b68347f8c5eab4a13419215bdfd7f8c9b19f2b25520968adfad23eb0ce60c \ + --hash=sha256:1085e7fbddd3be5f89cc898938f42c0b3c711fdcb37d75221de2666af647c175 \ + --hash=sha256:116bb52f642a37c115f517494ea5feb03889e04df47eeff5b130b1808ce7c219 \ + --hash=sha256:12c63dfb4a98206f045aa9563db46507995f7ef6d83b2f68eda65c307c6829eb \ + --hash=sha256:133a43e73a802c5562be9bbcd03d090aa5a1fe899db609c29e8c8d815c5f6de6 \ + --hash=sha256:1353ef0c1b138e1907ae78e2f6c63ff67501122006b0f9abad68fda5f4ffc6ab \ + --hash=sha256:15d939a21d546304880945ca1ecb8a039db6b4dc49b2c5a400387cdae6a62e26 \ + --hash=sha256:177b5253b2834fe3678cb4a5f0059808258584c559193998be2601324fdeafb1 \ + --hash=sha256:1872df69a4de6aead3491198eaf13810b565bdbeec3ae2dc8780f14458ec73ce \ + --hash=sha256:1b4b79e8ebf6b55351f0d91fe80f893b4743f104bff22e90697db1590e47a218 \ + --hash=sha256:1b52b4fb9df4eb9ae465f8d0c228a00624de2334f216f178a995ccdcf82c4634 \ + --hash=sha256:1ba88449deb3de88bd40044603fafffb7bc2b055d626a330323a9ed736661695 \ + --hash=sha256:1cc7ea17a6824959616c525620e387f6dd30fec8cb44f649e31712db02123dad \ + --hash=sha256:218551f6df4868a8d527e3062d0fb968682fe92054e89978594c28e642c43a73 \ + --hash=sha256:26a5784ded40c9e318cfc2bdb30fe164bdb8665ded9cd64d500a34fb42067b1c \ + --hash=sha256:2713baf880df847f2bece4230d4d094280f4e67b1e813eec43b4c0e144a34ffe \ + --hash=sha256:2a15a08b17dd94c53a1da0438822d70ebcd13f8c3a95abe3a9ef9f11a94830aa \ + --hash=sha256:2f981d352f04553a7171b8e44369f2af4055f888dfb147d55e42d29e29e74559 \ + --hash=sha256:32001d6a8fc98c8cb5c947787c5d08b0a50663d139f1305bac5885d98d9b40fa \ + --hash=sha256:3524b778fe5cfb3452a09d31e7b5adefeea8c5be1d43c4f810ba09f2ceb29d37 \ + --hash=sha256:3537e01efc9d4dccdf77221fb1cb3b8e1a38d5428920e0657ce299b20324d758 \ + --hash=sha256:35add3b638a5d900e807944a078b51922212fb3dedb01633a8defc4b01a3c85f \ + --hash=sha256:38664109c14ffc9e7437e86b4dceb442b0096dfe3541d7864d9cbe1da4cf36c8 \ + --hash=sha256:3a7e8ae81ae39e62a41ec302f972ba6ae23a5c5396c8e60113e9066ef893da0d \ + --hash=sha256:3b562dd9e9ea93f13d53989d23a7e775fdfd1066c33494ff43f5418bc8c58a5c \ + --hash=sha256:457a69a9577064c05a97c41f4e65148652db078a3a509039e64d3467b9e7ef97 \ + --hash=sha256:4bd4cd07944443f5a265608cc6aab442e4f74dff8088b0dfc8238647b8f6ae9a \ + --hash=sha256:4e885a3d1efa2eadc93c894a21770e4bc67899e3543680313b09f139e149ab19 \ + --hash=sha256:4faffd047e07c38848ce017e8725090413cd80cbc23d86e55c587bf979e579c9 \ + --hash=sha256:509fa21c6deb7a7a273d629cf5ec029bc209d1a51178615ddf718f5918992ab9 \ + --hash=sha256:5678211cb9333a6468fb8d8be0305520aa073f50d17f089b5b4b477ea6e67fdc \ + --hash=sha256:591ae9f2a647529ca990bc681daebdd52c8791ff06c2bfa05b65163e28102ef2 \ + --hash=sha256:5a7d5dc5140555cf21a6fefbdbf8723f06fcd2f63ef108f2854de715e4422cb4 \ + --hash=sha256:69c0b73548bc525c8cb9a251cddf1931d1db4d2258e9599c28c07ef3580ef354 \ + --hash=sha256:6b5420a1d9450023228968e7e6a9ce57f65d148ab56d2313fcd589eee96a7a50 \ + --hash=sha256:722695808f4b6457b320fdc131280796bdceb04ab50fe1795cd540799ebe1698 \ + --hash=sha256:729586769a26dbceff69f7a7dbbf59ab6572b99d94576a5592625d5b411576b9 \ + --hash=sha256:77f0643abe7495da77fb436f50f8dab76dbc6e5fd25d39589a0f1fe6548bfa2b \ + --hash=sha256:795e7751525cae078558e679d646ae45574b47ed6e7771863fcc079a6171a0fc \ + --hash=sha256:7be7b61bb172e1ed687f1754f8e7484f1c8019780f6f6b0786e76bb01c2ae115 \ + --hash=sha256:7c3fb7d25180895632e5d3148dbdc29ea38ccb7fd210aa27acbd1201a1902c6e \ + --hash=sha256:7e68f88e5b8799aa49c85cd116c932a1ac15caaa3f5db09087854d218359e485 \ + --hash=sha256:83891d0e9fb81a825d9a6d61e3f07550ca70a076484292a70fde82c4b807286f \ + --hash=sha256:8485f406a96febb5140bfeca44a73e3ce5116b2501ac54fe953e488fb1d03b12 \ + --hash=sha256:8709b08f4a89aa7586de0aadc8da56180242ee0ada3999749b183aa23df95025 \ + --hash=sha256:8f71bc33915be5186016f675cd83a1e08523649b0e33efdb898db577ef5bb009 \ + --hash=sha256:915c04ba3851909ce68ccc2b8e2cd691618c4dc4c4232fb7982bca3f41fd8c3d \ + --hash=sha256:949b8d66bc381ee8b007cd945914c721d9aba8e27f71959d750a46f7c282b20b \ + --hash=sha256:94c6f0bb423f739146aec64595853541634bde58b2135f27f61c1ffd1cd4d16a \ + --hash=sha256:9a1abfdc021a164803f4d485104931fb8f8c1efd55bc6b748d2f5774e78b62c5 \ + --hash=sha256:9b79b7a16f7fedff2495d684f2b59b0457c3b493778c9eed31111be64d58279f \ + --hash=sha256:a320721ab5a1aba0a233739394eb907f8c8da5c98c9181d1161e77a0c8e36f2d \ + --hash=sha256:a4afe79fb3de0b7097d81da19090f4df4f8d3a2b3adaa8764138aac2e44f3af1 \ + --hash=sha256:ad2cf8aa28b8c020ab2fc8287b0f823d0a7d8630784c31e9ee5edea20f406287 \ + --hash=sha256:b8512a91625c9b3da6f127803b166b629725e68af71f8184ae7e7d54686a56d6 \ + --hash=sha256:bc51efed119bc9cfdf792cdeaa4d67e8f6fcccab66ed4bfdd6bde3e59bfcbb2f \ + --hash=sha256:bdc919ead48f234740ad807933cdf545180bfbe9342c2bb451556db2ed958581 \ + --hash=sha256:bdd37121970bfd8be76c5fb069c7751683bdf373db1ed6c010162b2a130248ed \ + --hash=sha256:be8813b57049a7dc738189df53d69395eba14fb99345e0a5994914a3864c8a4b \ + --hash=sha256:c0c0b3ade1c0b13b936d7970b1d37a57acde9199dc2aecc4c336773e1d86049c \ + --hash=sha256:c47a551199eb8eb2121d4f0f15ae0f923d31350ab9280078d1e5f12b249e0026 \ + --hash=sha256:c4ffb7ebf07cfe8931028e3e4c85f0357459a3f9f9490886198848f4fa002ec8 \ + --hash=sha256:ccfcd093f13f0f0b7fdd0f198b90053bf7b2f02a3927a30e63f3ccc9df56b676 \ + --hash=sha256:d2ee202e79d8ed691ceebae8e0486bd9a2cd4794cec4824e1c99b6f5009502f6 \ + --hash=sha256:d53197da72cc091b024dd97249dfc7794d6a56530370992a5e1a08983ad9230e \ + --hash=sha256:d6dd0be5b5b189d31db7cda48b91d7e0a9795f31430b7f271219ab30f1d3ac9d \ + --hash=sha256:d88b440e37a16e651bda4c7c2b930eb586fd15ca7406cb39e211fcff3bf3017d \ + --hash=sha256:de8a88e63464af587c950061a5e6a67d3632e36df62b986892331d4620a35c01 \ + --hash=sha256:df2449253ef108a379b8b5d6b43f4b1a8e81a061d6537becd5582fba5f9196d7 \ + --hash=sha256:e1c1493fb6e50ab01d20a22826e57520f1284df32f2d8601fdd90b6304601419 \ + --hash=sha256:e1cf1972137e83c5d4c136c43ced9ac51d0e124706ee1c8aa8532c1287fa8795 \ + --hash=sha256:e2103a929dfa2fcaf9bb4e7c091983a49c9ac3b19c9061b6d5427dd7d14d81a1 \ + --hash=sha256:e56b7d45a839a697b5eb268c82a71bd8c7f6c94d6fd50c3d577fa39a9f1409f5 \ + --hash=sha256:e8afc3f2ccfa24215f8cb28dcf43f0113ac3c37c2f0f0806d8c70e4228c5cf4d \ + --hash=sha256:e8fc20152abba6b83724d7ff268c249fa196d8259ff481f3b1476383f8f24e42 \ + --hash=sha256:eaa9599de571d72e2daf60164784109f19978b327a3910d3e9de8c97b5b70cfe \ + --hash=sha256:ec15a59cf5af7be74194f7ab02d0f59a62bdcf1a537677ce67a2537c9b87fcda \ + --hash=sha256:f190daf01f13c72eac4efd5c430a8de82489d9cff23c364c3ea822545032993e \ + --hash=sha256:f34c41761022dd093b4b6896d4810782ffbabe30f2d443ff5f083e0cbbb8c737 \ + --hash=sha256:f3e98bb3798ead92273dc0e5fd0f31ade220f59a266ffd8a4f6065e0a3ce0523 \ + --hash=sha256:f42d0984e947b8adf7dd6dde396e720934d12c506ce84eea8476409563607591 \ + --hash=sha256:f71a396b3bf33ecaa1626c255855702aca4d3d9fea5e051b41ac59a9c1c41edc \ + --hash=sha256:f9e130248f4462aaa8e2552d547f36ddadbeaa573879158d721bbd33dfe4743a \ + --hash=sha256:fed51ac40f757d41b7c48425901843666a6677e3e8eb0abcff09e4ba6e664f50 + # via werkzeug +matplotlib==3.10.8 \ + --hash=sha256:00270d217d6b20d14b584c521f810d60c5c78406dc289859776550df837dcda7 \ + --hash=sha256:0a33deb84c15ede243aead39f77e990469fff93ad1521163305095b77b72ce4a \ + --hash=sha256:113bb52413ea508ce954a02c10ffd0d565f9c3bc7f2eddc27dfe1731e71c7b5f \ + --hash=sha256:12d90df9183093fcd479f4172ac26b322b1248b15729cb57f42f71f24c7e37a3 \ + --hash=sha256:15d30132718972c2c074cd14638c7f4592bd98719e2308bccea40e0538bc0cb5 \ + --hash=sha256:18821ace09c763ec93aef5eeff087ee493a24051936d7b9ebcad9662f66501f9 \ + --hash=sha256:1ae029229a57cd1e8fe542485f27e7ca7b23aa9e8944ddb4985d0bc444f1eca2 \ + --hash=sha256:2299372c19d56bcd35cf05a2738308758d32b9eaed2371898d8f5bd33f084aa3 \ + --hash=sha256:238b7ce5717600615c895050239ec955d91f321c209dd110db988500558e70d6 \ + --hash=sha256:24d50994d8c5816ddc35411e50a86ab05f575e2530c02752e02538122613371f \ + --hash=sha256:25d380fe8b1dc32cf8f0b1b448470a77afb195438bafdf1d858bfb876f3edf7b \ + --hash=sha256:2c1998e92cd5999e295a731bcb2911c75f597d937341f3030cc24ef2733d78a8 \ + --hash=sha256:2cf5bd12cecf46908f286d7838b2abc6c91cda506c0445b8223a7c19a00df008 \ + --hash=sha256:32f8dce744be5569bebe789e46727946041199030db8aeb2954d26013a0eb26b \ + --hash=sha256:37b3c1cc42aa184b3f738cfa18c1c1d72fd496d85467a6cf7b807936d39aa656 \ + --hash=sha256:3a48a78d2786784cc2413e57397981fb45c79e968d99656706018d6e62e57958 \ + --hash=sha256:3ab4aabc72de4ff77b3ec33a6d78a68227bf1123465887f9905ba79184a1cc04 \ + --hash=sha256:3c624e43ed56313651bc18a47f838b60d7b8032ed348911c54906b130b20071b \ + --hash=sha256:3f2e409836d7f5ac2f1c013110a4d50b9f7edc26328c108915f9075d7d7a91b6 \ + --hash=sha256:3f5c3e4da343bba819f0234186b9004faba952cc420fbc522dc4e103c1985908 \ + --hash=sha256:41703cc95688f2516b480f7f339d8851a6035f18e100ee6a32bc0b8536a12a9c \ + --hash=sha256:495672de149445ec1b772ff2c9ede9b769e3cb4f0d0aa7fa730d7f59e2d4e1c1 \ + --hash=sha256:4cf267add95b1c88300d96ca837833d4112756045364f5c734a2276038dae27d \ + --hash=sha256:56271f3dac49a88d7fca5060f004d9d22b865f743a12a23b1e937a0be4818ee1 \ + --hash=sha256:595ba4d8fe983b88f0eec8c26a241e16d6376fe1979086232f481f8f3f67494c \ + --hash=sha256:5f62550b9a30afde8c1c3ae450e5eb547d579dd69b25c2fc7a1c67f934c1717a \ + --hash=sha256:646d95230efb9ca614a7a594d4fcacde0ac61d25e37dd51710b36477594963ce \ + --hash=sha256:64fcc24778ca0404ce0cb7b6b77ae1f4c7231cdd60e6778f999ee05cbd581b9a \ + --hash=sha256:6be43b667360fef5c754dda5d25a32e6307a03c204f3c0fc5468b78fa87b4160 \ + --hash=sha256:6da7c2ce169267d0d066adcf63758f0604aa6c3eebf67458930f9d9b79ad1db1 \ + --hash=sha256:83d282364ea9f3e52363da262ce32a09dfe241e4080dcedda3c0db059d3c1f11 \ + --hash=sha256:9153c3292705be9f9c64498a8872118540c3f4123d1a1c840172edf262c8be4a \ + --hash=sha256:99eefd13c0dc3b3c1b4d561c1169e65fe47aab7b8158754d7c084088e2329466 \ + --hash=sha256:a0a7f52498f72f13d4a25ea70f35f4cb60642b466cbb0a9be951b5bc3f45a486 \ + --hash=sha256:a2b336e2d91a3d7006864e0990c83b216fcdca64b5a6484912902cef87313d78 \ + --hash=sha256:a48f2b74020919552ea25d222d5cc6af9ca3f4eb43a93e14d068457f545c2a17 \ + --hash=sha256:ad3d9833a64cf48cc4300f2b406c3d0f4f4724a91c0bd5640678a6ba7c102077 \ + --hash=sha256:b44d07310e404ba95f8c25aa5536f154c0a8ec473303535949e52eb71d0a1565 \ + --hash=sha256:b53285e65d4fa4c86399979e956235deb900be5baa7fc1218ea67fbfaeaadd6f \ + --hash=sha256:b5a2b97dbdc7d4f353ebf343744f1d1f1cca8aa8bfddb4262fcf4306c3761d50 \ + --hash=sha256:b9a5ca4ac220a0cdd1ba6bcba3608547117d30468fefce49bb26f55c1a3d5c58 \ + --hash=sha256:bab485bcf8b1c7d2060b4fcb6fc368a9e6f4cd754c9c2fea281f4be21df394a2 \ + --hash=sha256:c108a1d6fa78a50646029cb6d49808ff0fc1330fda87fa6f6250c6b5369b6645 \ + --hash=sha256:d56a1efd5bfd61486c8bc968fa18734464556f0fb8e51690f4ac25d85cbbbbc2 \ + --hash=sha256:d9050fee89a89ed57b4fb2c1bfac9a3d0c57a0d55aed95949eedbc42070fea39 \ + --hash=sha256:dd80ecb295460a5d9d260df63c43f4afbdd832d725a531f008dad1664f458adf \ + --hash=sha256:e8ea3e2d4066083e264e75c829078f9e149fa119d27e19acd503de65e0b13149 \ + --hash=sha256:eb3823f11823deade26ce3b9f40dcb4a213da7a670013929f31d5f5ed1055b22 \ + --hash=sha256:ee40c27c795bda6a5292e9cff9890189d32f7e3a0bf04e0e3c9430c4a00c37df \ + --hash=sha256:efb30e3baaea72ce5928e32bab719ab4770099079d66726a62b11b1ef7273be4 \ + --hash=sha256:f254d118d14a7f99d616271d6c3c27922c092dac11112670b157798b89bf4933 \ + --hash=sha256:f89c151aab2e2e23cb3fe0acad1e8b82841fd265379c4cecd0f3fcb34c15e0f6 \ + --hash=sha256:f97aeb209c3d2511443f8797e3e5a569aebb040d4f8bc79aa3ee78a8fb9e3dd8 \ + --hash=sha256:f9b587c9c7274c1613a30afabf65a272114cd6cdbe67b3406f818c79d7ab2e2a \ + --hash=sha256:fb061f596dad3a0f52b60dc6a5dec4a0c300dec41e058a7efe09256188d170b7 # via -r build/test-requirements.txt mdurl==0.1.2 \ --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba # via markdown-it-py -ml-dtypes==0.5.1 \ - --hash=sha256:023ce2f502efd4d6c1e0472cc58ce3640d051d40e71e27386bed33901e201327 \ - --hash=sha256:05f23447a1c20ddf4dc7c2c661aa9ed93fcb2658f1017c204d1e758714dc28a8 \ - --hash=sha256:12651420130ee7cc13059fc56dac6ad300c3af3848b802d475148c9defd27c23 \ - --hash=sha256:141b2ea2f20bb10802ddca55d91fe21231ef49715cfc971998e8f2a9838f3dbe \ - --hash=sha256:15ad0f3b0323ce96c24637a88a6f44f6713c64032f27277b069f285c3cf66478 \ - --hash=sha256:1b7fbe5571fdf28fd3aaab3ef4aafc847de9ebf263be959958c1ca58ec8eadf5 \ - --hash=sha256:26ebcc69d7b779c8f129393e99732961b5cc33fcff84090451f448c89b0e01b4 \ - --hash=sha256:6f462f5eca22fb66d7ff9c4744a3db4463af06c49816c4b6ac89b16bfcdc592e \ - --hash=sha256:6f76232163b5b9c34291b54621ee60417601e2e4802a188a0ea7157cd9b323f4 \ - --hash=sha256:7000b6e4d8ef07542c05044ec5d8bbae1df083b3f56822c3da63993a113e716f \ - --hash=sha256:810512e2eccdfc3b41eefa3a27402371a3411453a1efc7e9c000318196140fed \ - --hash=sha256:8f2c028954f16ede77902b223a8da2d9cbb3892375b85809a5c3cfb1587960c4 \ - --hash=sha256:9626d0bca1fb387d5791ca36bacbba298c5ef554747b7ebeafefb4564fc83566 \ - --hash=sha256:ac5b58559bb84a95848ed6984eb8013249f90b6bab62aa5acbad876e256002c9 \ - --hash=sha256:ad4953c5eb9c25a56d11a913c2011d7e580a435ef5145f804d98efa14477d390 \ - --hash=sha256:aefedc579ece2f8fb38f876aa7698204ee4c372d0e54f1c1ffa8ca580b54cc60 \ - --hash=sha256:afb2009ac98da274e893e03162f6269398b2b00d947e7057ee2469a921d58135 \ - --hash=sha256:b8a9d46b4df5ae2135a8e8e72b465448ebbc1559997f4f9304a9ecc3413efb5b \ - --hash=sha256:bd73f51957949069573ff783563486339a9285d72e2f36c18e0c1aa9ca7eb190 \ - --hash=sha256:bf9975bda82a99dc935f2ae4c83846d86df8fd6ba179614acac8e686910851da \ - --hash=sha256:c09526488c3a9e8b7a23a388d4974b670a9a3dd40c5c8a61db5593ce9b725bab \ - --hash=sha256:c9945669d3dadf8acb40ec2e57d38c985d8c285ea73af57fc5b09872c516106d \ - --hash=sha256:d13755f8e8445b3870114e5b6240facaa7cb0c3361e54beba3e07fa912a6e12b \ - --hash=sha256:fd918d4e6a4e0c110e2e05be7a7814d10dc1b95872accbf6512b80a109b71ae1 - # via -r build/requirements.in -mpmath==1.4.0a1 \ - --hash=sha256:78884400f439f500fa76be0121a8f9598313d87664863a192e1185ddbd7ae97f \ - --hash=sha256:f8b7b5a3a1726ab6e8c898eb2157426b82c482ab1ab8ffed9f88bb9e07c6e9c1 - # via -r build/test-requirements.txt -numpy==2.0.0 ; python_version <= "3.12" \ - --hash=sha256:04494f6ec467ccb5369d1808570ae55f6ed9b5809d7f035059000a37b8d7e86f \ - --hash=sha256:0a43f0974d501842866cc83471bdb0116ba0dffdbaac33ec05e6afed5b615238 \ - --hash=sha256:0e50842b2295ba8414c8c1d9d957083d5dfe9e16828b37de883f51fc53c4016f \ - --hash=sha256:0ec84b9ba0654f3b962802edc91424331f423dcf5d5f926676e0150789cb3d95 \ - --hash=sha256:17067d097ed036636fa79f6a869ac26df7db1ba22039d962422506640314933a \ - --hash=sha256:1cde1753efe513705a0c6d28f5884e22bdc30438bf0085c5c486cdaff40cd67a \ - --hash=sha256:1e72728e7501a450288fc8e1f9ebc73d90cfd4671ebbd631f3e7857c39bd16f2 \ - --hash=sha256:2635dbd200c2d6faf2ef9a0d04f0ecc6b13b3cad54f7c67c61155138835515d2 \ - --hash=sha256:2ce46fd0b8a0c947ae047d222f7136fc4d55538741373107574271bc00e20e8f \ - --hash=sha256:34f003cb88b1ba38cb9a9a4a3161c1604973d7f9d5552c38bc2f04f829536609 \ - --hash=sha256:354f373279768fa5a584bac997de6a6c9bc535c482592d7a813bb0c09be6c76f \ - --hash=sha256:38ecb5b0582cd125f67a629072fed6f83562d9dd04d7e03256c9829bdec027ad \ - --hash=sha256:3e8e01233d57639b2e30966c63d36fcea099d17c53bf424d77f088b0f4babd86 \ - --hash=sha256:3f6bed7f840d44c08ebdb73b1825282b801799e325bcbdfa6bc5c370e5aecc65 \ - --hash=sha256:4554eb96f0fd263041baf16cf0881b3f5dafae7a59b1049acb9540c4d57bc8cb \ - --hash=sha256:46e161722e0f619749d1cd892167039015b2c2817296104487cd03ed4a955995 \ - --hash=sha256:49d9f7d256fbc804391a7f72d4a617302b1afac1112fac19b6c6cec63fe7fe8a \ - --hash=sha256:4d2f62e55a4cd9c58c1d9a1c9edaedcd857a73cb6fda875bf79093f9d9086f85 \ - --hash=sha256:5f64641b42b2429f56ee08b4f427a4d2daf916ec59686061de751a55aafa22e4 \ - --hash=sha256:63b92c512d9dbcc37f9d81b123dec99fdb318ba38c8059afc78086fe73820275 \ - --hash=sha256:6d7696c615765091cc5093f76fd1fa069870304beaccfd58b5dcc69e55ef49c1 \ - --hash=sha256:79e843d186c8fb1b102bef3e2bc35ef81160ffef3194646a7fdd6a73c6b97196 \ - --hash=sha256:821eedb7165ead9eebdb569986968b541f9908979c2da8a4967ecac4439bae3d \ - --hash=sha256:84554fc53daa8f6abf8e8a66e076aff6ece62de68523d9f665f32d2fc50fd66e \ - --hash=sha256:8d83bb187fb647643bd56e1ae43f273c7f4dbcdf94550d7938cfc32566756514 \ - --hash=sha256:903703372d46bce88b6920a0cd86c3ad82dae2dbef157b5fc01b70ea1cfc430f \ - --hash=sha256:9416a5c2e92ace094e9f0082c5fd473502c91651fb896bc17690d6fc475128d6 \ - --hash=sha256:9a1712c015831da583b21c5bfe15e8684137097969c6d22e8316ba66b5baabe4 \ - --hash=sha256:9c27f0946a3536403efb0e1c28def1ae6730a72cd0d5878db38824855e3afc44 \ - --hash=sha256:a356364941fb0593bb899a1076b92dfa2029f6f5b8ba88a14fd0984aaf76d0df \ - --hash=sha256:a7039a136017eaa92c1848152827e1424701532ca8e8967fe480fe1569dae581 \ - --hash=sha256:acd3a644e4807e73b4e1867b769fbf1ce8c5d80e7caaef0d90dcdc640dfc9787 \ - --hash=sha256:ad0c86f3455fbd0de6c31a3056eb822fc939f81b1618f10ff3406971893b62a5 \ - --hash=sha256:b4c76e3d4c56f145d41b7b6751255feefae92edbc9a61e1758a98204200f30fc \ - --hash=sha256:b6f6a8f45d0313db07d6d1d37bd0b112f887e1369758a5419c0370ba915b3871 \ - --hash=sha256:c5a59996dc61835133b56a32ebe4ef3740ea5bc19b3983ac60cc32be5a665d54 \ - --hash=sha256:c73aafd1afca80afecb22718f8700b40ac7cab927b8abab3c3e337d70e10e5a2 \ - --hash=sha256:cee6cc0584f71adefe2c908856ccc98702baf95ff80092e4ca46061538a2ba98 \ - --hash=sha256:cef04d068f5fb0518a77857953193b6bb94809a806bd0a14983a8f12ada060c9 \ - --hash=sha256:cf5d1c9e6837f8af9f92b6bd3e86d513cdc11f60fd62185cc49ec7d1aba34864 \ - --hash=sha256:e61155fae27570692ad1d327e81c6cf27d535a5d7ef97648a17d922224b216de \ - --hash=sha256:e7f387600d424f91576af20518334df3d97bc76a300a755f9a8d6e4f5cadd289 \ - --hash=sha256:ed08d2703b5972ec736451b818c2eb9da80d66c3e84aed1deeb0c345fefe461b \ - --hash=sha256:fbd6acc766814ea6443628f4e6751d0da6593dae29c08c0b2606164db026970c \ - --hash=sha256:feff59f27338135776f6d4e2ec7aeeac5d5f7a08a83e80869121ef8164b74af9 +ml-dtypes==0.5.4 \ + --hash=sha256:0d2ffd05a2575b1519dc928c0b93c06339eb67173ff53acb00724502cda231cf \ + --hash=sha256:11942cbf2cf92157db91e5022633c0d9474d4dfd813a909383bd23ce828a4b7d \ + --hash=sha256:14a4fd3228af936461db66faccef6e4f41c1d82fcc30e9f8d58a08916b1d811f \ + --hash=sha256:19b9a53598f21e453ea2fbda8aa783c20faff8e1eeb0d7ab899309a0053f1483 \ + --hash=sha256:2314892cdc3fcf05e373d76d72aaa15fda9fb98625effa73c1d646f331fcecb7 \ + --hash=sha256:2b857d3af6ac0d39db1de7c706e69c7f9791627209c3d6dedbfca8c7e5faec22 \ + --hash=sha256:304ad47faa395415b9ccbcc06a0350800bc50eda70f0e45326796e27c62f18b6 \ + --hash=sha256:35f29491a3e478407f7047b8a4834e4640a77d2737e0b294d049746507af5175 \ + --hash=sha256:388d399a2152dd79a3f0456a952284a99ee5c93d3e2f8dfe25977511e0515270 \ + --hash=sha256:3bbbe120b915090d9dd1375e4684dd17a20a2491ef25d640a908281da85e73f1 \ + --hash=sha256:3d277bf3637f2a62176f4575512e9ff9ef51d00e39626d9fe4a161992f355af2 \ + --hash=sha256:4381fe2f2452a2d7589689693d3162e876b3ddb0a832cde7a414f8e1adf7eab1 \ + --hash=sha256:4ff7f3e7ca2972e7de850e7b8fcbb355304271e2933dd90814c1cb847414d6e2 \ + --hash=sha256:531eff30e4d368cb6255bc2328d070e35836aa4f282a0fb5f3a0cd7260257298 \ + --hash=sha256:533ce891ba774eabf607172254f2e7260ba5f57bdd64030c9a4fcfbd99815d0d \ + --hash=sha256:557a31a390b7e9439056644cb80ed0735a6e3e3bb09d67fd5687e4b04238d1de \ + --hash=sha256:5a0f68ca8fd8d16583dfa7793973feb86f2fbb56ce3966daf9c9f748f52a2049 \ + --hash=sha256:6a0df4223b514d799b8a1629c65ddc351b3efa833ccf7f8ea0cf654a61d1e35d \ + --hash=sha256:6c7ecb74c4bd71db68a6bea1edf8da8c34f3d9fe218f038814fd1d310ac76c90 \ + --hash=sha256:7c23c54a00ae43edf48d44066a7ec31e05fdc2eee0be2b8b50dd1903a1db94bb \ + --hash=sha256:805cef3a38f4eafae3a5bf9ebdcdb741d0bcfd9e1bd90eb54abd24f928cd2465 \ + --hash=sha256:88c982aac7cb1cbe8cbb4e7f253072b1df872701fcaf48d84ffbb433b6568f24 \ + --hash=sha256:8ab06a50fb9bf9666dd0fe5dfb4676fa2b0ac0f31ecff72a6c3af8e22c063453 \ + --hash=sha256:8c6a2dcebd6f3903e05d51960a8058d6e131fe69f952a5397e5dbabc841b6d56 \ + --hash=sha256:8c760d85a2f82e2bed75867079188c9d18dae2ee77c25a54d60e9cc79be1bc48 \ + --hash=sha256:9ad459e99793fa6e13bd5b7e6792c8f9190b4e5a1b45c63aba14a4d0a7f1d5ff \ + --hash=sha256:9bad06436568442575beb2d03389aa7456c690a5b05892c471215bfd8cf39460 \ + --hash=sha256:a174837a64f5b16cab6f368171a1a03a27936b31699d167684073ff1c4237dac \ + --hash=sha256:a7f7c643e8b1320fd958bf098aa7ecf70623a42ec5154e3be3be673f4c34d900 \ + --hash=sha256:a9b61c19040397970d18d7737375cffd83b1f36a11dd4ad19f83a016f736c3ef \ + --hash=sha256:b4b801ebe0b477be666696bda493a9be8356f1f0057a57f1e35cd26928823e5a \ + --hash=sha256:b95e97e470fe60ed493fd9ae3911d8da4ebac16bd21f87ffa2b7c588bf22ea2c \ + --hash=sha256:bc11d7e8c44a65115d05e2ab9989d1e045125d7be8e05a071a48bc76eb6d6040 \ + --hash=sha256:bfc534409c5d4b0bf945af29e5d0ab075eae9eecbb549ff8a29280db822f34f9 \ + --hash=sha256:c1a953995cccb9e25a4ae19e34316671e4e2edaebe4cf538229b1fc7109087b7 \ + --hash=sha256:cb73dccfc991691c444acc8c0012bee8f2470da826a92e3a20bb333b1a7894e6 \ + --hash=sha256:ce756d3a10d0c4067172804c9cc276ba9cc0ff47af9078ad439b075d1abdc29b \ + --hash=sha256:d81fdb088defa30eb37bf390bb7dde35d3a83ec112ac8e33d75ab28cc29dd8b0 \ + --hash=sha256:f21c9219ef48ca5ee78402d5cc831bd58ea27ce89beda894428bc67a52da5328 # via # -r build/requirements.in + # jaxlib + # keras + # tensorflow + # tensorstore +mpmath==1.3.0 \ + --hash=sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f \ + --hash=sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c + # via -r build/test-requirements.txt +namex==0.1.0 \ + --hash=sha256:117f03ccd302cc48e3f5c58a296838f6b89c83455ab8683a1e85f2a430aa4306 \ + --hash=sha256:e2012a474502f1e2251267062aae3114611f07df4224b6e06334c57b0f2ce87c + # via keras +numpy==2.0.2 ; python_version <= "3.12" \ + --hash=sha256:0123ffdaa88fa4ab64835dcbde75dcdf89c453c922f18dced6e27c90d1d0ec5a \ + --hash=sha256:11a76c372d1d37437857280aa142086476136a8c0f373b2e648ab2c8f18fb195 \ + --hash=sha256:13e689d772146140a252c3a28501da66dfecd77490b498b168b501835041f951 \ + --hash=sha256:1e795a8be3ddbac43274f18588329c72939870a16cae810c2b73461c40718ab1 \ + --hash=sha256:26df23238872200f63518dd2aa984cfca675d82469535dc7162dc2ee52d9dd5c \ + --hash=sha256:286cd40ce2b7d652a6f22efdfc6d1edf879440e53e76a75955bc0c826c7e64dc \ + --hash=sha256:2b2955fa6f11907cf7a70dab0d0755159bca87755e831e47932367fc8f2f2d0b \ + --hash=sha256:2da5960c3cf0df7eafefd806d4e612c5e19358de82cb3c343631188991566ccd \ + --hash=sha256:312950fdd060354350ed123c0e25a71327d3711584beaef30cdaa93320c392d4 \ + --hash=sha256:423e89b23490805d2a5a96fe40ec507407b8ee786d66f7328be214f9679df6dd \ + --hash=sha256:496f71341824ed9f3d2fd36cf3ac57ae2e0165c143b55c3a035ee219413f3318 \ + --hash=sha256:49ca4decb342d66018b01932139c0961a8f9ddc7589611158cb3c27cbcf76448 \ + --hash=sha256:51129a29dbe56f9ca83438b706e2e69a39892b5eda6cedcb6b0c9fdc9b0d3ece \ + --hash=sha256:5fec9451a7789926bcf7c2b8d187292c9f93ea30284802a0ab3f5be8ab36865d \ + --hash=sha256:671bec6496f83202ed2d3c8fdc486a8fc86942f2e69ff0e986140339a63bcbe5 \ + --hash=sha256:7f0a0c6f12e07fa94133c8a67404322845220c06a9e80e85999afe727f7438b8 \ + --hash=sha256:807ec44583fd708a21d4a11d94aedf2f4f3c3719035c76a2bbe1fe8e217bdc57 \ + --hash=sha256:883c987dee1880e2a864ab0dc9892292582510604156762362d9326444636e78 \ + --hash=sha256:8c5713284ce4e282544c68d1c3b2c7161d38c256d2eefc93c1d683cf47683e66 \ + --hash=sha256:8cafab480740e22f8d833acefed5cc87ce276f4ece12fdaa2e8903db2f82897a \ + --hash=sha256:8df823f570d9adf0978347d1f926b2a867d5608f434a7cff7f7908c6570dcf5e \ + --hash=sha256:9059e10581ce4093f735ed23f3b9d283b9d517ff46009ddd485f1747eb22653c \ + --hash=sha256:905d16e0c60200656500c95b6b8dca5d109e23cb24abc701d41c02d74c6b3afa \ + --hash=sha256:9189427407d88ff25ecf8f12469d4d39d35bee1db5d39fc5c168c6f088a6956d \ + --hash=sha256:96a55f64139912d61de9137f11bf39a55ec8faec288c75a54f93dfd39f7eb40c \ + --hash=sha256:97032a27bd9d8988b9a97a8c4d2c9f2c15a81f61e2f21404d7e8ef00cb5be729 \ + --hash=sha256:984d96121c9f9616cd33fbd0618b7f08e0cfc9600a7ee1d6fd9b239186d19d97 \ + --hash=sha256:9a92ae5c14811e390f3767053ff54eaee3bf84576d99a2456391401323f4ec2c \ + --hash=sha256:9ea91dfb7c3d1c56a0e55657c0afb38cf1eeae4544c208dc465c3c9f3a7c09f9 \ + --hash=sha256:a15f476a45e6e5a3a79d8a14e62161d27ad897381fecfa4a09ed5322f2085669 \ + --hash=sha256:a392a68bd329eafac5817e5aefeb39038c48b671afd242710b451e76090e81f4 \ + --hash=sha256:a3f4ab0caa7f053f6797fcd4e1e25caee367db3112ef2b6ef82d749530768c73 \ + --hash=sha256:a46288ec55ebbd58947d31d72be2c63cbf839f0a63b49cb755022310792a3385 \ + --hash=sha256:a61ec659f68ae254e4d237816e33171497e978140353c0c2038d46e63282d0c8 \ + --hash=sha256:a842d573724391493a97a62ebbb8e731f8a5dcc5d285dfc99141ca15a3302d0c \ + --hash=sha256:becfae3ddd30736fe1889a37f1f580e245ba79a5855bff5f2a29cb3ccc22dd7b \ + --hash=sha256:c05e238064fc0610c840d1cf6a13bf63d7e391717d247f1bf0318172e759e692 \ + --hash=sha256:c1c9307701fec8f3f7a1e6711f9089c06e6284b3afbbcd259f7791282d660a15 \ + --hash=sha256:c7b0be4ef08607dd04da4092faee0b86607f111d5ae68036f16cc787e250a131 \ + --hash=sha256:cfd41e13fdc257aa5778496b8caa5e856dc4896d4ccf01841daee1d96465467a \ + --hash=sha256:d731a1c6116ba289c1e9ee714b08a8ff882944d4ad631fd411106a30f083c326 \ + --hash=sha256:df55d490dea7934f330006d0f81e8551ba6010a5bf035a249ef61a94f21c500b \ + --hash=sha256:ec9852fb39354b5a45a80bdab5ac02dd02b15f44b3804e9f00c556bf24b4bded \ + --hash=sha256:f15975dfec0cf2239224d80e32c3170b1d168335eaedee69da84fbe9f1f9cd04 \ + --hash=sha256:f26b258c385842546006213344c50655ff1555a9338e2e5e02a0756dc3e803dd + # via + # -r build/nonfreethreading-requirements.txt # contourpy + # h5py + # jaxlib + # keras # matplotlib # ml-dtypes - # opt-einsum + # numpy-typing-compat + # optype # scipy -nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \ - --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \ - --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ - --hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9 + # tensorboard + # tensorflow + # tensorstore +numpy-typing-compat==20251206.2.0 \ + --hash=sha256:413171c4333c4175cbad4206c94e58422d291d20426c42581865380156715493 \ + --hash=sha256:7db9d5e991af03b2ade38f43253e4eb03ab88925230931bff7f559c020676fb1 + # via optype +nvidia-cublas==13.1.0.3 ; sys_platform == "linux" \ + --hash=sha256:2a3b94a37def342471c59fad7856caee4926809a72dd5270155d6a31b5b277be \ + --hash=sha256:c86fc7f7ae36d7528288c5d88098edcb7b02c633d262e7ddbb86b0ad91be5df2 \ + --hash=sha256:ee8722c1f0145ab246bccb9e452153b5e0515fd094c3678df50b2a0888b8b171 # via - # -r build/test-requirements.txt + # -r build/nvidia-requirements.txt + # nvidia-cudnn-cu13 + # nvidia-cusolver +nvidia-cublas-cu12==12.9.1.4 ; sys_platform == "linux" \ + --hash=sha256:1e5fee10662e6e52bd71dec533fbbd4971bb70a5f24f3bc3793e5c2e9dc640bf \ + --hash=sha256:453611eb21a7c1f2c2156ed9f3a45b691deda0440ec550860290dc901af5b4c2 \ + --hash=sha256:7a950dae01add3b415a5a5cdc4ec818fb5858263e9cca59004bb99fdbbd3a5d6 + # via + # -r build/nvidia-requirements.txt # nvidia-cudnn-cu12 # nvidia-cusolver-cu12 -nvidia-cuda-cupti-cu12==12.8.57 ; sys_platform == "linux" \ - --hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \ - --hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \ - --hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317 - # via -r build/test-requirements.txt -nvidia-cuda-nvcc-cu12==12.8.61 ; sys_platform == "linux" \ - --hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \ - --hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ - --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b - # via -r build/test-requirements.txt -nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ - --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ - --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ - --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 - # via -r build/test-requirements.txt -nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \ - --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ - --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ - --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef - # via -r build/test-requirements.txt -nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ - --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ - --hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \ - --hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8 - # via -r build/test-requirements.txt -nvidia-cusolver-cu12==11.7.2.55 ; sys_platform == "linux" \ - --hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \ - --hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \ - --hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac - # via -r build/test-requirements.txt -nvidia-cusparse-cu12==12.5.7.53 ; sys_platform == "linux" \ - --hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \ - --hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \ - --hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1 +nvidia-cuda-crt==13.0.88 ; sys_platform == "linux" \ + --hash=sha256:2c8043c7c9e02492716426e9919fc78d2c5b3b2a7a768a88e952676b08aa55a4 \ + --hash=sha256:31e02c52916804ca15e31f272a96181d8fadaf40c4c82a77a6f78071a22eccf3 \ + --hash=sha256:ee2ea2a97073e02ee62bb27841f437332be2c248e3eac013df07997ada39c003 # via - # -r build/test-requirements.txt + # -r build/nvidia-requirements.txt + # nvidia-cuda-nvcc +nvidia-cuda-cupti==13.0.85 ; sys_platform == "linux" \ + --hash=sha256:4eb01c08e859bf924d222250d2e8f8b8ff6d3db4721288cf35d14252a4d933c8 \ + --hash=sha256:683f58d301548deeefcb8f6fac1b8d907691b9d8b18eccab417f51e362102f00 \ + --hash=sha256:796bd679890ee55fb14a94629b698b6db54bcfd833d391d5e94017dd9d7d3151 + # via -r build/nvidia-requirements.txt +nvidia-cuda-cupti-cu12==12.9.79 ; sys_platform == "linux" \ + --hash=sha256:096bcf334f13e1984ba36685ad4c1d6347db214de03dbb6eebb237b41d9d934f \ + --hash=sha256:1848a9380067560d5bee10ed240eecc22991713e672c0515f9c3d9396adf93c8 \ + --hash=sha256:791853b030602c6a11d08b5578edfb957cadea06e9d3b26adbf8d036135a4afe + # via -r build/nvidia-requirements.txt +nvidia-cuda-nvcc==13.0.88 ; sys_platform == "linux" \ + --hash=sha256:56fe502eb77625a12f25172caa3cdddb4e4c8ba2c8c17dba44b164761b380f03 \ + --hash=sha256:7c3a32c8ca9866addfd784da363ddee2f6874d560027a296f583e86a61f2d543 \ + --hash=sha256:c7ff28f86a24effdc6c034fa15230c549a273e4771b10a7fec14996f8cf3307f + # via -r build/nvidia-requirements.txt +nvidia-cuda-nvcc-cu12==12.9.86 ; sys_platform == "linux" \ + --hash=sha256:44e1eca4d08926193a558d2434b1bf83d57b4d5743e0c431c0c83d51da1df62b \ + --hash=sha256:5d6a0d32fdc7ea39917c20065614ae93add6f577d840233237ff08e9a38f58f0 \ + --hash=sha256:8ed7f0b17dea662755395be029376db3b94fed5cbb17c2d35cc866c5b1b84099 + # via -r build/nvidia-requirements.txt +nvidia-cuda-nvrtc==13.0.88 ; sys_platform == "linux" \ + --hash=sha256:6bcd4e7f8e205cbe644f5a98f2f799bef9556fefc89dd786e79a16312ce49872 \ + --hash=sha256:ad9b6d2ead2435f11cbb6868809d2adeeee302e9bb94bcf0539c7a40d80e8575 \ + --hash=sha256:d27f20a0ca67a4bb34268a5e951033496c5b74870b868bacd046b1b8e0c3267b + # via -r build/nvidia-requirements.txt +nvidia-cuda-nvrtc-cu12==12.9.86 ; sys_platform == "linux" \ + --hash=sha256:096d4de6bda726415dfaf3198d4f5c522b8e70139c97feef5cd2ca6d4cd9cead \ + --hash=sha256:210cf05005a447e29214e9ce50851e83fc5f4358df8b453155d5e1918094dcb4 \ + --hash=sha256:72972ebdcf504d69462d3bcd67e7b81edd25d0fb85a2c46d3ea3517666636349 + # via -r build/nvidia-requirements.txt +nvidia-cuda-runtime==13.0.96 ; sys_platform == "linux" \ + --hash=sha256:7f82250d7782aa23b6cfe765ecc7db554bd3c2870c43f3d1821f1d18aebf0548 \ + --hash=sha256:ef9bcbe90493a2b9d810e43d249adb3d02e98dd30200d86607d8d02687c43f55 \ + --hash=sha256:f79298c8a098cec150a597c8eba58ecdab96e3bdc4b9bc4f9983635031740492 + # via + # -r build/nvidia-requirements.txt + # nvidia-cuda-nvcc +nvidia-cuda-runtime-cu12==12.9.79 ; sys_platform == "linux" \ + --hash=sha256:25bba2dfb01d48a9b59ca474a1ac43c6ebf7011f1b0b8cc44f54eb6ac48a96c3 \ + --hash=sha256:83469a846206f2a733db0c42e223589ab62fd2fabac4432d2f8802de4bded0a4 \ + --hash=sha256:8e018af8fa02363876860388bd10ccb89eb9ab8fb0aa749aaf58430a9f7c4891 + # via -r build/nvidia-requirements.txt +nvidia-cudnn-cu12==9.16.0.29 ; sys_platform == "linux" \ + --hash=sha256:142e2bd646a4573ab17d61a24c6359155cdfe1f34c67fc305b71222a7ae45b8e \ + --hash=sha256:4b09c43096db582f110c5572d0bcbd98b30d709e860a8f73c6c3846baa83b8d2 \ + --hash=sha256:78d05b4434dacc7dd9bc903d5c33a2f28a5f0064d02568ef7b2418f89f6c5922 + # via -r build/nvidia-requirements.txt +nvidia-cudnn-cu13==9.16.0.29 ; sys_platform == "linux" \ + --hash=sha256:6349bc8769369a91611c5e2ce5c2e510e61848c245099c31e870d2cdce0ab90d \ + --hash=sha256:79dc1bfe8c1a780cf4eb7b334d14d7927576d6dd8823f8e2769911af30fd4da3 \ + --hash=sha256:faafa46e2e7dd844bbcf06b6adec3fa66924987f2fb21bf67f5c6fd697c74a64 + # via -r build/nvidia-requirements.txt +nvidia-cufft==12.0.0.61 ; sys_platform == "linux" \ + --hash=sha256:2708c852ef8cd89d1d2068bdbece0aa188813a0c934db3779b9b1faa8442e5f5 \ + --hash=sha256:2abce5b39d2f5ae12730fb7e5db6696533e36c26e2d3e8fd1750bdd2853364eb \ + --hash=sha256:6c44f692dce8fd5ffd3e3df134b6cdb9c2f72d99cf40b62c32dde45eea9ddad3 + # via -r build/nvidia-requirements.txt +nvidia-cufft-cu12==11.4.1.4 ; sys_platform == "linux" \ + --hash=sha256:1a28c9b12260a1aa7a8fd12f5ebd82d027963d635ba82ff39a1acfa7c4c0fbcf \ + --hash=sha256:8e5bfaac795e93f80611f807d42844e8e27e340e0cde270dcb6c65386d795b80 \ + --hash=sha256:c67884f2a7d276b4b80eb56a79322a95df592ae5e765cf1243693365ccab4e28 + # via -r build/nvidia-requirements.txt +nvidia-cusolver==12.0.4.66 ; sys_platform == "linux" \ + --hash=sha256:02c2457eaa9e39de20f880f4bd8820e6a1cfb9f9a34f820eb12a155aa5bc92d2 \ + --hash=sha256:0a759da5dea5c0ea10fd307de75cdeb59e7ea4fcb8add0924859b944babf1112 \ + --hash=sha256:16515bd33a8e76bb54d024cfa068fa68d30e80fc34b9e1090813ea9362e0cb65 + # via -r build/nvidia-requirements.txt +nvidia-cusolver-cu12==11.7.5.82 ; sys_platform == "linux" \ + --hash=sha256:15da72d1340d29b5b3cf3fd100e3cd53421dde36002eda6ed93811af63c40d88 \ + --hash=sha256:62efa83e4ace59a4c734d052bb72158e888aa7b770e1a5f601682f16fe5b4fd2 \ + --hash=sha256:77666337237716783c6269a658dea310195cddbd80a5b2919b1ba8735cec8efd + # via -r build/nvidia-requirements.txt +nvidia-cusparse==12.6.3.3 ; sys_platform == "linux" \ + --hash=sha256:2b3c89c88d01ee0e477cb7f82ef60a11a4bcd57b6b87c33f789350b59759360b \ + --hash=sha256:80bcc4662f23f1054ee334a15c72b8940402975e0eab63178fc7e670aa59472c \ + --hash=sha256:cbcf42feb737bd7ec15b4c0a63e62351886bd3f975027b8815d7f720a2b5ea79 + # via + # -r build/nvidia-requirements.txt + # nvidia-cusolver +nvidia-cusparse-cu12==12.5.10.65 ; sys_platform == "linux" \ + --hash=sha256:221c73e7482dd93eda44e65ce567c031c07e2f93f6fa0ecd3ba876a195023e83 \ + --hash=sha256:73060ce019ac064a057267c585bf1fd5a353734151f87472ff02b2c5c9984e78 \ + --hash=sha256:9e487468a22a1eaf1fbd1d2035936a905feb79c4ce5c2f67626764ee4f90227c + # via + # -r build/nvidia-requirements.txt # nvidia-cusolver-cu12 -nvidia-nccl-cu12==2.25.1 ; sys_platform == "linux" \ - --hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ - --hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 - # via -r build/test-requirements.txt -nvidia-nvjitlink-cu12==12.8.61 ; sys_platform == "linux" \ - --hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \ - --hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \ - --hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0 +nvidia-nccl-cu12==2.28.9 ; sys_platform == "linux" \ + --hash=sha256:485776daa8447da5da39681af455aa3b2c2586ddcf4af8772495e7c532c7e5ab \ + --hash=sha256:50a36e01c4a090b9f9c47d92cec54964de6b9fcb3362d0e19b8ffc6323c21b60 + # via -r build/nvidia-requirements.txt +nvidia-nccl-cu13==2.28.9 ; sys_platform == "linux" \ + --hash=sha256:01c873ba1626b54caa12272ed228dc5b2781545e0ae8ba3f432a8ef1c6d78643 \ + --hash=sha256:e4553a30f34195f3fa1da02a6da3d6337d28f2003943aa0a3d247bbc25fefc42 + # via -r build/nvidia-requirements.txt +nvidia-nvjitlink==13.0.88 ; sys_platform == "linux" \ + --hash=sha256:13a74f429e23b921c1109976abefacc69835f2f433ebd323d3946e11d804e47b \ + --hash=sha256:634e96e3da9ef845ae744097a1f289238ecf946ce0b82e93cdce14b9782e682f \ + --hash=sha256:e931536ccc7d467a98ba1d8b89ff7fa7f1fa3b13f2b0069118cd7f47bff07d0c # via - # -r build/test-requirements.txt + # -r build/nvidia-requirements.txt + # nvidia-cufft + # nvidia-cusolver + # nvidia-cusparse +nvidia-nvjitlink-cu12==12.9.86 ; sys_platform == "linux" \ + --hash=sha256:994a05ef08ef4b0b299829cde613a424382aff7efb08a7172c1fa616cc3af2ca \ + --hash=sha256:cc6fcec260ca843c10e34c936921a1c426b351753587fdd638e8cff7b16bb9db \ + --hash=sha256:e3f1171dbdc83c5932a45f0f4c99180a70de9bd2718c1ab77d14104f6d7147f9 + # via + # -r build/nvidia-requirements.txt # nvidia-cufft-cu12 # nvidia-cusolver-cu12 # nvidia-cusparse-cu12 -opt-einsum==3.3.0 \ - --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ - --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 +nvidia-nvshmem-cu12==3.4.5 ; sys_platform == "linux" \ + --hash=sha256:042f2500f24c021db8a06c5eec2539027d57460e1c1a762055a6554f72c369bd \ + --hash=sha256:0b48363fc6964dede448029434c6abed6c5e37f823cb43c3bcde7ecfc0457e15 + # via -r build/nvidia-requirements.txt +nvidia-nvshmem-cu13==3.4.5 ; sys_platform == "linux" \ + --hash=sha256:290f0a2ee94c9f3687a02502f3b9299a9f9fe826e6d0287ee18482e78d495b80 \ + --hash=sha256:6dc2a197f38e5d0376ad52cd1a2a3617d3cdc150fd5966f4aee9bcebb1d68fe9 + # via -r build/nvidia-requirements.txt +nvidia-nvvm==13.0.88 ; sys_platform == "linux" \ + --hash=sha256:2ef0db7849e476d3b2fc3c09b27bdd79bd7ea8ce58cd9c86553d64ea40844ba0 \ + --hash=sha256:c4376a291d72d22a315d9d2f69bdae8f8cd83a627f75bad395cee49a0fe65dc1 \ + --hash=sha256:c5f41ffeb6466944a026dfa5317d7d85355c119bbec279205d22f1869d1054e0 + # via + # -r build/nvidia-requirements.txt + # nvidia-cuda-nvcc +opt-einsum==3.4.0 \ + --hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \ + --hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac # via # -r build/requirements.in - # -r build/test-requirements.txt -packaging==24.0 \ - --hash=sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5 \ - --hash=sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9 + # tensorflow +optree==0.18.0 \ + --hash=sha256:01b79aaee544adf5bfa573db32b943030dfeb9fd1c6e7a97aa417db56a8127e7 \ + --hash=sha256:02d9999840fabef85a6b22e757f336d5591f712f99c710d8b232d52e53115314 \ + --hash=sha256:056894ce6242cd1c7fed71325a7d9f633b2d3b4420c52af48f6a0c4560d74ca1 \ + --hash=sha256:057b983a9526645133553184bed7090bb07855df986abd9e99c456922045c6bc \ + --hash=sha256:07c5f64783ad0f0f80e61c25f276ce79b47deda83ed7956a4a9af6385fe8f60d \ + --hash=sha256:090a3f0ccafa0fe99d71e7d974ae52ff966ac26c409ec41f96556b96646054ef \ + --hash=sha256:0959bac58631e64e2ac6349cc284b37872c24f353b3d73b4682202a431f07d76 \ + --hash=sha256:0d25941de1acba176305dbdeb931dea6143b30d64ebdc5bfea2bfc12ef9e2b0a \ + --hash=sha256:0e0dbe995241efe70cfb522e89c1a7c968216926725a0e5e20cc72bd5d0311b1 \ + --hash=sha256:10f29662d637b80363dc620da46ddc58def7acf7935e20595b23e216ea912367 \ + --hash=sha256:1545c68299c0ce600e4ea1bc9112765dc4afe9a0b8ab43f955df6566bf78db42 \ + --hash=sha256:1b75e083137f361377ff8d70df885ab3a1cf8980e4019e3f311237579adadb64 \ + --hash=sha256:1db0a6497203a13063a8f044ae751dd5d8253cb815359270c38de0e4c9f8bed5 \ + --hash=sha256:1f19867b02a547fc9f11d27c0413e7483cef89699e16f3b9e8af73a9b25e6061 \ + --hash=sha256:1f674e34202383f8b42fa9335f13bedfb6b6f019c66e1f41034929e4be203423 \ + --hash=sha256:20536964ba2458f166c1e8ab25951e3fc0a5056b651bd08f16be99bb3ffed54a \ + --hash=sha256:27611c6c122745a003b5be7aedba49ef86e9fef46d743c234596de0bde6dc679 \ + --hash=sha256:27b1d0cadcf4627c98abbbdce912dbc2243f5687f3c7df39963b793c89321c65 \ + --hash=sha256:289b184cc41dfc400a30db6207ec997884d14540aae2cba10cb88dc7ebaae2a1 \ + --hash=sha256:2b5cfb5fc643f16d3a7d957807e55a937dce07566c49ccc4aa71b01064c56758 \ + --hash=sha256:3014537ff7e4e091ee46e57976f7d95c52f66a0e3eb5ebcbe0de0d924504b58e \ + --hash=sha256:30a2636279bdc805c8e154a0f346bcf704626b831ff44724d305fb72c90b7389 \ + --hash=sha256:30f95279188f6b9300e17c1557989baa991c2d6f519013bd8fea13462a0e6a45 \ + --hash=sha256:30fefc84975ac41d9075993196c64ce0c240510f0539cff121d63b709e03846f \ + --hash=sha256:31539dec60af84e16e99574634811d38e34e1fb381f40d6f489a2e582bf41f03 \ + --hash=sha256:328857d7a35129904b21164f6b0c2ff1d728ad1f5838589c5f437a16c94213c8 \ + --hash=sha256:3804fb6ddc923855db2dc4805b4524c66e00f1ef30b166be4aadd52822b13e06 \ + --hash=sha256:382e5ca02cbd5b20d713d4da189a8613f828832e2af57ccbe04a9c6b0bd9497e \ + --hash=sha256:385bd727cc7bd3c01bd6204028ac2adce8a8f622c296053d9df434aa0e30b01f \ + --hash=sha256:421b839c7ff30df5791e66c89b2e9c2f68191dd6a5d6927c32bcc6b887090df8 \ + --hash=sha256:446c46c53cb8f13abcc0d7dd1989d59bb059953c122fe9901ef53de7fb38b33e \ + --hash=sha256:4cc92339899acb685ee718fd22b25069dfa7be038c63274c54481d54ccc2f9e2 \ + --hash=sha256:4eb146711d4cd0876bf93e0118d3e74050b6f633d756c269ce7cda907281b499 \ + --hash=sha256:51e2cd9ac7fecfd5f6f56ce69f4f805553c226a2744810175959eb408101513c \ + --hash=sha256:55a2ccd121fccc9df961e982db2f4e8f2b4f7015e814ef70b1140514cdffe214 \ + --hash=sha256:56bb19ff827c9a443202b52bf103705ce96ef14d045e0a30d0d7ee7dbcef6a0d \ + --hash=sha256:571b732229d7b2e7a2215f57586f8ec0140e07c0faea916e456cbbfa819e56cb \ + --hash=sha256:5b126c34b459ef4f10f3a4d7d222416d9102b3c5a76b39f346c611792f144821 \ + --hash=sha256:5b75e32c191e4b8cf42a8aa854ed264df82936136c0bcad77be44605da41cdfc \ + --hash=sha256:5bc1221068a58175e0ad62afc199893f77c653206673a5552992a604c66fb77e \ + --hash=sha256:5e669f98b9af9f66144c7ae09912d0367ac3182abe016f67cdd15cb45e13c923 \ + --hash=sha256:66f142c743732cd4e630ea84415f654a00c792793c7f80d4511167f0f89796a6 \ + --hash=sha256:6fc9f8acde3bb561b2034e96079507fbe6d4624058fe204161eb8ef29f961296 \ + --hash=sha256:7172b16e87c87160475275e4bfaa6e4067ccde184d2cca65ba25a402a8ed7758 \ + --hash=sha256:71ca2fcad8972ba56d6cfffbcd962f45f5d4bc04182f23d66154b38c2eb37de3 \ + --hash=sha256:72fa79be4d6515682417f103ae759a22345439eb1319886be936029215ee00dc \ + --hash=sha256:7699957183f8d45402edd6266e175510317f5fcd7f0e623510f2eb7e1ebfc667 \ + --hash=sha256:79bbe14d6cad81f5840958589daa1b836864ada40031712a446dce8129917efd \ + --hash=sha256:7ae6945f68771b1389ee46a1778e779f4ad76bca9306f3e39eb397f9a0dd2753 \ + --hash=sha256:80d971060c888c3989132b7e75dfb50848636d41bc931af1b93fe2019fba469c \ + --hash=sha256:80f28e4666aad66e5e20bdc2c47b5bf320250bb5407b3a39dfb1772787a7068f \ + --hash=sha256:81e755124b77e766166c9d05206b90c68f234f425ad2e3c8a6c96f0db548c67b \ + --hash=sha256:86f5bf05ad236f666e5395e989d6ac2cbfd02556526703e6c6f0a594c7fa081f \ + --hash=sha256:895f23a4cd8aee2c2464efdad2d9bde28a2aaabee634c96423a933f40e74a67e \ + --hash=sha256:89d5156f8a0a3792701e1c31473eb307f0b45696f48dc51d721f1bfe0c3a950f \ + --hash=sha256:89e81afb11792d13d3777b503c6f21ec17b1a3b7de69cde1ae2c5471bcdcd4a0 \ + --hash=sha256:8a2003fab79694e04b5f260628511e441c248b46a9fc46138e2424038ac04ada \ + --hash=sha256:8a4ca121b6fc6b04300fa225fe6c31897e424db0d92691875af326f8c4e1cead \ + --hash=sha256:8a901666afc2d7a8d0c20decc8079763e3313457ee67210382162d90163c0007 \ + --hash=sha256:8b9ad4a01a1346b11acc574b7f932dea1a7c7ab31d93546a7540a1f02b3e724a \ + --hash=sha256:8d88c00c70b5914904feaf8f505f3512c2f3f4493dbbd93951fcdddc85dcfe8c \ + --hash=sha256:9104fc8915890e7292e5833fc677e4749607c67aa3cf8884677267078201c2f3 \ + --hash=sha256:9460cba62e941626beb75c99a803373b38a52136d5f1932fcdfdcede1df6f2ef \ + --hash=sha256:94983b3aa31ee401d2ac77ba570a3157d83f9508cfbb006095a48770e0a1c5ca \ + --hash=sha256:9b1e7e8f9ddc85f05d542b74157bdb73ed0e49aded67d1775f721fcd6eb9be94 \ + --hash=sha256:9d4b9d8c7e9335120ecf222d817699d17de743ad118080fb40467c367f009143 \ + --hash=sha256:a479fa25b6e2430e530d00f0c27a55e15ecb9de8ad2d0aec3d40b680e2d6df64 \ + --hash=sha256:a5c213a291c798139ed9ff80aec4bfcd2ac8f001bc015a9cdeb78457e9687dd3 \ + --hash=sha256:a63df296fec376c5cd08298a85109db4a130f4cc8df15916fc92d44ef6068937 \ + --hash=sha256:a74c45f04def041504bd21682eaf7f359f1a50dc7cf42b548b6f19aab50596bd \ + --hash=sha256:ad428ccdb2a40804919880dfe8d2a3021fd4418be15ea7ecb8434ab249badf9f \ + --hash=sha256:b0986ff1267a3b44d3ed76c3efb8b7239371444143f6e0d79f9dd23dbe02c7f9 \ + --hash=sha256:b45d7172c67fc8d2b69f77b384998b39793ee91f8b3b46c609297b781fb7eea5 \ + --hash=sha256:b4da3223c5b4cf694822752d0fbb6bf34c3f41648af1bd1b443cc3d68cc55106 \ + --hash=sha256:b7aa0de08bbbfcef6e49c107f9f397f5d4742548500f16e3e6c5e0b9e4ff0faa \ + --hash=sha256:b8adc912ecb6e4fd9df227ded66efaa6702f46a98e1403554be3c9c51d0ca920 \ + --hash=sha256:ba23caafd0e0c911bb7eab54e5cf69644af864d153e4b2abdab83ff0ef357ba1 \ + --hash=sha256:bda4572392ac1dff3fc67b6d9a4b1084e1637972e8135ad3788b4ce7cf0a90f5 \ + --hash=sha256:c017539e1196ea08f20aea3a4c473f758149b851678edd3d15773b4326decf83 \ + --hash=sha256:c1f20e8754abe312a701ee00d071ddd8502e9d97ca38fbc56204d14a9ffcb41c \ + --hash=sha256:c8841d44f3648b0662e99fc39ef8c248726ddfb4d1bfce4bdba982e51bb7e3f8 \ + --hash=sha256:cbb083a15ea968ad99e7da17d24632348d69e26534e83c69941f3020ed7536eb \ + --hash=sha256:cde70c97e4cc4e997e8fda2266e40a9bff7679c72ab4af6e15e81748a12882cc \ + --hash=sha256:cfa2e16993ba47e671a4e7ee1ad805f67b8d6744eb30a9d27ea0b07b3b7a22ed \ + --hash=sha256:d20765efa494a80a8fd91c4de8890f34de8e9f234da5516e8f34f55703cfb93d \ + --hash=sha256:d2844478690b5892159df0b2500e9d146dc8d3aa5b44e4564d05787b7330eca3 \ + --hash=sha256:d569730b2647c51a5ee68d67198aa9a78c7a55563d57b8cc1ca8d8c8377e7621 \ + --hash=sha256:daab231cf768937ce4675376ea3e214d399116d9867a6737372c31c58630bdfc \ + --hash=sha256:db00c604c1ae452f6092293bf230984d4f6cbb3ad905a9991e8cf680fd7d1523 \ + --hash=sha256:e058cc51d9d57b45801060af9f74765b95bedfc59fd6df1c7489ae0825126be5 \ + --hash=sha256:e28024e6e343353285cf99ae9c74210f0e89e47b2f0f3af7c72c4a9e89dc3ebc \ + --hash=sha256:e4a468ae1541614b5aa7b4f00254bce005ab7572fbb1fc764af4ee17d90fde7b \ + --hash=sha256:ea357657143f364a764b63b2b1ce12d77156d48a1f32def990b696d755acb629 \ + --hash=sha256:efd162e3bfc7812d75ebf2d0fb2783daee2407a92155af8a90650a6b0fa9342e \ + --hash=sha256:f02faeda66d531dc5f5356589afcf2a6bc41c8d00bc903efab60f9a2182b140d \ + --hash=sha256:f04286908654ffb05455254ebf72fe69473fc4560fc7ea49410df94dea6783a2 \ + --hash=sha256:f5197f864630162f008f5dfad3fceef32553c0fa7639eee1b8e280d924ed678e \ + --hash=sha256:f81f5340c8df50662abaf753ab07095901e40b934efb27da50032a4ae71c5a97 \ + --hash=sha256:fa8e3878a1857761d64f08a23b32140d29754a53f85f7c87186ced2b5b1b49cb \ + --hash=sha256:ff7326f36ed70d84c3fd62fb39bc6858f699640b8ab238c3cb8dafe1e200af59 + # via keras +optype[numpy]==0.15.0 \ + --hash=sha256:457d6ca9e7da19967ec16d42bdf94e240b33b5d70a56fbbf5b427e5ea39cf41e \ + --hash=sha256:caba40ece9ea39b499fa76c036a82e0d452a432dd4dd3e8e0d30892be2e8c76c + # via scipy-stubs +packaging==25.0 \ + --hash=sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484 \ + --hash=sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f # via # auditwheel # build + # keras # matplotlib # pytest -pillow==11.0.0 \ - --hash=sha256:00177a63030d612148e659b55ba99527803288cea7c75fb05766ab7981a8c1b7 \ - --hash=sha256:006bcdd307cc47ba43e924099a038cbf9591062e6c50e570819743f5607404f5 \ - --hash=sha256:084a07ef0821cfe4858fe86652fffac8e187b6ae677e9906e192aafcc1b69903 \ - --hash=sha256:0ae08bd8ffc41aebf578c2af2f9d8749d91f448b3bfd41d7d9ff573d74f2a6b2 \ - --hash=sha256:0e038b0745997c7dcaae350d35859c9715c71e92ffb7e0f4a8e8a16732150f38 \ - --hash=sha256:1187739620f2b365de756ce086fdb3604573337cc28a0d3ac4a01ab6b2d2a6d2 \ - --hash=sha256:16095692a253047fe3ec028e951fa4221a1f3ed3d80c397e83541a3037ff67c9 \ - --hash=sha256:1a61b54f87ab5786b8479f81c4b11f4d61702830354520837f8cc791ebba0f5f \ - --hash=sha256:1c1d72714f429a521d8d2d018badc42414c3077eb187a59579f28e4270b4b0fc \ - --hash=sha256:1e2688958a840c822279fda0086fec1fdab2f95bf2b717b66871c4ad9859d7e8 \ - --hash=sha256:20ec184af98a121fb2da42642dea8a29ec80fc3efbaefb86d8fdd2606619045d \ - --hash=sha256:21a0d3b115009ebb8ac3d2ebec5c2982cc693da935f4ab7bb5c8ebe2f47d36f2 \ - --hash=sha256:224aaa38177597bb179f3ec87eeefcce8e4f85e608025e9cfac60de237ba6316 \ - --hash=sha256:2679d2258b7f1192b378e2893a8a0a0ca472234d4c2c0e6bdd3380e8dfa21b6a \ - --hash=sha256:27a7860107500d813fcd203b4ea19b04babe79448268403172782754870dac25 \ - --hash=sha256:290f2cc809f9da7d6d622550bbf4c1e57518212da51b6a30fe8e0a270a5b78bd \ - --hash=sha256:2e46773dc9f35a1dd28bd6981332fd7f27bec001a918a72a79b4133cf5291dba \ - --hash=sha256:3107c66e43bda25359d5ef446f59c497de2b5ed4c7fdba0894f8d6cf3822dafc \ - --hash=sha256:375b8dd15a1f5d2feafff536d47e22f69625c1aa92f12b339ec0b2ca40263273 \ - --hash=sha256:45c566eb10b8967d71bf1ab8e4a525e5a93519e29ea071459ce517f6b903d7fa \ - --hash=sha256:499c3a1b0d6fc8213519e193796eb1a86a1be4b1877d678b30f83fd979811d1a \ - --hash=sha256:4ad70c4214f67d7466bea6a08061eba35c01b1b89eaa098040a35272a8efb22b \ - --hash=sha256:4b60c9520f7207aaf2e1d94de026682fc227806c6e1f55bba7606d1c94dd623a \ - --hash=sha256:5178952973e588b3f1360868847334e9e3bf49d19e169bbbdfaf8398002419ae \ - --hash=sha256:52a2d8323a465f84faaba5236567d212c3668f2ab53e1c74c15583cf507a0291 \ - --hash=sha256:598b4e238f13276e0008299bd2482003f48158e2b11826862b1eb2ad7c768b97 \ - --hash=sha256:5bd2d3bdb846d757055910f0a59792d33b555800813c3b39ada1829c372ccb06 \ - --hash=sha256:5c39ed17edea3bc69c743a8dd3e9853b7509625c2462532e62baa0732163a904 \ - --hash=sha256:5d203af30149ae339ad1b4f710d9844ed8796e97fda23ffbc4cc472968a47d0b \ - --hash=sha256:5ddbfd761ee00c12ee1be86c9c0683ecf5bb14c9772ddbd782085779a63dd55b \ - --hash=sha256:607bbe123c74e272e381a8d1957083a9463401f7bd01287f50521ecb05a313f8 \ - --hash=sha256:61b887f9ddba63ddf62fd02a3ba7add935d053b6dd7d58998c630e6dbade8527 \ - --hash=sha256:6619654954dc4936fcff82db8eb6401d3159ec6be81e33c6000dfd76ae189947 \ - --hash=sha256:674629ff60030d144b7bca2b8330225a9b11c482ed408813924619c6f302fdbb \ - --hash=sha256:6ec0d5af64f2e3d64a165f490d96368bb5dea8b8f9ad04487f9ab60dc4bb6003 \ - --hash=sha256:6f4dba50cfa56f910241eb7f883c20f1e7b1d8f7d91c750cd0b318bad443f4d5 \ - --hash=sha256:70fbbdacd1d271b77b7721fe3cdd2d537bbbd75d29e6300c672ec6bb38d9672f \ - --hash=sha256:72bacbaf24ac003fea9bff9837d1eedb6088758d41e100c1552930151f677739 \ - --hash=sha256:7326a1787e3c7b0429659e0a944725e1b03eeaa10edd945a86dead1913383944 \ - --hash=sha256:73853108f56df97baf2bb8b522f3578221e56f646ba345a372c78326710d3830 \ - --hash=sha256:73e3a0200cdda995c7e43dd47436c1548f87a30bb27fb871f352a22ab8dcf45f \ - --hash=sha256:75acbbeb05b86bc53cbe7b7e6fe00fbcf82ad7c684b3ad82e3d711da9ba287d3 \ - --hash=sha256:8069c5179902dcdce0be9bfc8235347fdbac249d23bd90514b7a47a72d9fecf4 \ - --hash=sha256:846e193e103b41e984ac921b335df59195356ce3f71dcfd155aa79c603873b84 \ - --hash=sha256:8594f42df584e5b4bb9281799698403f7af489fba84c34d53d1c4bfb71b7c4e7 \ - --hash=sha256:86510e3f5eca0ab87429dd77fafc04693195eec7fd6a137c389c3eeb4cfb77c6 \ - --hash=sha256:8853a3bf12afddfdf15f57c4b02d7ded92c7a75a5d7331d19f4f9572a89c17e6 \ - --hash=sha256:88a58d8ac0cc0e7f3a014509f0455248a76629ca9b604eca7dc5927cc593c5e9 \ - --hash=sha256:8ba470552b48e5835f1d23ecb936bb7f71d206f9dfeee64245f30c3270b994de \ - --hash=sha256:8c676b587da5673d3c75bd67dd2a8cdfeb282ca38a30f37950511766b26858c4 \ - --hash=sha256:8ec4a89295cd6cd4d1058a5e6aec6bf51e0eaaf9714774e1bfac7cfc9051db47 \ - --hash=sha256:94f3e1780abb45062287b4614a5bc0874519c86a777d4a7ad34978e86428b8dd \ - --hash=sha256:9a0f748eaa434a41fccf8e1ee7a3eed68af1b690e75328fd7a60af123c193b50 \ - --hash=sha256:a5629742881bcbc1f42e840af185fd4d83a5edeb96475a575f4da50d6ede337c \ - --hash=sha256:a65149d8ada1055029fcb665452b2814fe7d7082fcb0c5bed6db851cb69b2086 \ - --hash=sha256:b3c5ac4bed7519088103d9450a1107f76308ecf91d6dabc8a33a2fcfb18d0fba \ - --hash=sha256:b4fd7bd29610a83a8c9b564d457cf5bd92b4e11e79a4ee4716a63c959699b306 \ - --hash=sha256:bcd1fb5bb7b07f64c15618c89efcc2cfa3e95f0e3bcdbaf4642509de1942a699 \ - --hash=sha256:c12b5ae868897c7338519c03049a806af85b9b8c237b7d675b8c5e089e4a618e \ - --hash=sha256:c26845094b1af3c91852745ae78e3ea47abf3dbcd1cf962f16b9a5fbe3ee8488 \ - --hash=sha256:c6a660307ca9d4867caa8d9ca2c2658ab685de83792d1876274991adec7b93fa \ - --hash=sha256:c809a70e43c7977c4a42aefd62f0131823ebf7dd73556fa5d5950f5b354087e2 \ - --hash=sha256:c8b2351c85d855293a299038e1f89db92a2f35e8d2f783489c6f0b2b5f3fe8a3 \ - --hash=sha256:cb929ca942d0ec4fac404cbf520ee6cac37bf35be479b970c4ffadf2b6a1cad9 \ - --hash=sha256:d2c0a187a92a1cb5ef2c8ed5412dd8d4334272617f532d4ad4de31e0495bd923 \ - --hash=sha256:d69bfd8ec3219ae71bcde1f942b728903cad25fafe3100ba2258b973bd2bc1b2 \ - --hash=sha256:daffdf51ee5db69a82dd127eabecce20729e21f7a3680cf7cbb23f0829189790 \ - --hash=sha256:e58876c91f97b0952eb766123bfef372792ab3f4e3e1f1a2267834c2ab131734 \ - --hash=sha256:eda2616eb2313cbb3eebbe51f19362eb434b18e3bb599466a1ffa76a033fb916 \ - --hash=sha256:ee217c198f2e41f184f3869f3e485557296d505b5195c513b2bfe0062dc537f1 \ - --hash=sha256:f02541ef64077f22bf4924f225c0fd1248c168f86e4b7abdedd87d6ebaceab0f \ - --hash=sha256:f1b82c27e89fffc6da125d5eb0ca6e68017faf5efc078128cfaa42cf5cb38798 \ - --hash=sha256:fba162b8872d30fea8c52b258a542c5dfd7b235fb5cb352240c8d63b414013eb \ - --hash=sha256:fbbcb7b57dc9c794843e3d1258c0fbf0f48656d46ffe9e09b63bbd6e8cd5d0a2 \ - --hash=sha256:fcb4621042ac4b7865c179bb972ed0da0218a076dc1820ffc48b1d74c1e37fe9 + # tensorboard + # tensorflow + # wheel +pillow==12.0.0 \ + --hash=sha256:0869154a2d0546545cde61d1789a6524319fc1897d9ee31218eae7a60ccc5643 \ + --hash=sha256:09f2d0abef9e4e2f349305a4f8cc784a8a6c2f58a8c4892eea13b10a943bd26e \ + --hash=sha256:0b817e7035ea7f6b942c13aa03bb554fc44fea70838ea21f8eb31c638326584e \ + --hash=sha256:0fd00cac9c03256c8b2ff58f162ebcd2587ad3e1f2e397eab718c47e24d231cc \ + --hash=sha256:110486b79f2d112cf6add83b28b627e369219388f64ef2f960fef9ebaf54c642 \ + --hash=sha256:1979f4566bb96c1e50a62d9831e2ea2d1211761e5662afc545fa766f996632f6 \ + --hash=sha256:1ac11e8ea4f611c3c0147424eae514028b5e9077dd99ab91e1bd7bc33ff145e1 \ + --hash=sha256:1b1b133e6e16105f524a8dec491e0586d072948ce15c9b914e41cdadd209052b \ + --hash=sha256:1ee80a59f6ce048ae13cda1abf7fbd2a34ab9ee7d401c46be3ca685d1999a399 \ + --hash=sha256:21f241bdd5080a15bc86d3466a9f6074a9c2c2b314100dd896ac81ee6db2f1ba \ + --hash=sha256:266cd5f2b63ff316d5a1bba46268e603c9caf5606d44f38c2873c380950576ad \ + --hash=sha256:26d9f7d2b604cd23aba3e9faf795787456ac25634d82cd060556998e39c6fa47 \ + --hash=sha256:27f95b12453d165099c84f8a8bfdfd46b9e4bda9e0e4b65f0635430027f55739 \ + --hash=sha256:2c54c1a783d6d60595d3514f0efe9b37c8808746a66920315bfd34a938d7994b \ + --hash=sha256:2fa5f0b6716fc88f11380b88b31fe591a06c6315e955c096c35715788b339e3f \ + --hash=sha256:32ed80ea8a90ee3e6fa08c21e2e091bba6eda8eccc83dbc34c95169507a91f10 \ + --hash=sha256:3830c769decf88f1289680a59d4f4c46c72573446352e2befec9a8512104fa52 \ + --hash=sha256:38df9b4bfd3db902c9c2bd369bcacaf9d935b2fff73709429d95cc41554f7b3d \ + --hash=sha256:3adfb466bbc544b926d50fe8f4a4e6abd8c6bffd28a26177594e6e9b2b76572b \ + --hash=sha256:3e42edad50b6909089750e65c91aa09aaf1e0a71310d383f11321b27c224ed8a \ + --hash=sha256:4078242472387600b2ce8d93ade8899c12bf33fa89e55ec89fe126e9d6d5d9e9 \ + --hash=sha256:455247ac8a4cfb7b9bc45b7e432d10421aea9fc2e74d285ba4072688a74c2e9d \ + --hash=sha256:4cc6b3b2efff105c6a1656cfe59da4fdde2cda9af1c5e0b58529b24525d0a098 \ + --hash=sha256:4cf7fed4b4580601c4345ceb5d4cbf5a980d030fd5ad07c4d2ec589f95f09905 \ + --hash=sha256:5193fde9a5f23c331ea26d0cf171fbf67e3f247585f50c08b3e205c7aeb4589b \ + --hash=sha256:5269cc1caeedb67e6f7269a42014f381f45e2e7cd42d834ede3c703a1d915fe3 \ + --hash=sha256:53561a4ddc36facb432fae7a9d8afbfaf94795414f5cdc5fc52f28c1dca90371 \ + --hash=sha256:55f818bd74fe2f11d4d7cbc65880a843c4075e0ac7226bc1a23261dbea531953 \ + --hash=sha256:58eea5ebe51504057dd95c5b77d21700b77615ab0243d8152793dc00eb4faf01 \ + --hash=sha256:5d5c411a8eaa2299322b647cd932586b1427367fd3184ffbb8f7a219ea2041ca \ + --hash=sha256:6846bd2d116ff42cba6b646edf5bf61d37e5cbd256425fa089fee4ff5c07a99e \ + --hash=sha256:6ace95230bfb7cd79ef66caa064bbe2f2a1e63d93471c3a2e1f1348d9f22d6b7 \ + --hash=sha256:6e51b71417049ad6ab14c49608b4a24d8fb3fe605e5dfabfe523b58064dc3d27 \ + --hash=sha256:71db6b4c1653045dacc1585c1b0d184004f0d7e694c7b34ac165ca70c0838082 \ + --hash=sha256:7438839e9e053ef79f7112c881cef684013855016f928b168b81ed5835f3e75e \ + --hash=sha256:759de84a33be3b178a64c8ba28ad5c135900359e85fb662bc6e403ad4407791d \ + --hash=sha256:792a2c0be4dcc18af9d4a2dfd8a11a17d5e25274a1062b0ec1c2d79c76f3e7f8 \ + --hash=sha256:7d87ef5795da03d742bf49439f9ca4d027cde49c82c5371ba52464aee266699a \ + --hash=sha256:7dfb439562f234f7d57b1ac6bc8fe7f838a4bd49c79230e0f6a1da93e82f1fad \ + --hash=sha256:7fa22993bac7b77b78cae22bad1e2a987ddf0d9015c63358032f84a53f23cdc3 \ + --hash=sha256:805ebf596939e48dbb2e4922a1d3852cfc25c38160751ce02da93058b48d252a \ + --hash=sha256:82240051c6ca513c616f7f9da06e871f61bfd7805f566275841af15015b8f98d \ + --hash=sha256:87d4f8125c9988bfbed67af47dd7a953e2fc7b0cc1e7800ec6d2080d490bb353 \ + --hash=sha256:8d8ca2b210ada074d57fcee40c30446c9562e542fc46aedc19baf758a93532ee \ + --hash=sha256:8dc232e39d409036af549c86f24aed8273a40ffa459981146829a324e0848b4b \ + --hash=sha256:90387104ee8400a7b4598253b4c406f8958f59fcf983a6cea2b50d59f7d63d0b \ + --hash=sha256:905b0365b210c73afb0ebe9101a32572152dfd1c144c7e28968a331b9217b94a \ + --hash=sha256:99353a06902c2e43b43e8ff74ee65a7d90307d82370604746738a1e0661ccca7 \ + --hash=sha256:99a7f72fb6249302aa62245680754862a44179b545ded638cf1fef59befb57ef \ + --hash=sha256:9f0b04c6b8584c2c193babcccc908b38ed29524b29dd464bc8801bf10d746a3a \ + --hash=sha256:9fe611163f6303d1619bbcb653540a4d60f9e55e622d60a3108be0d5b441017a \ + --hash=sha256:a3475b96f5908b3b16c47533daaa87380c491357d197564e0ba34ae75c0f3257 \ + --hash=sha256:a6597ff2b61d121172f5844b53f21467f7082f5fb385a9a29c01414463f93b07 \ + --hash=sha256:a7921c5a6d31b3d756ec980f2f47c0cfdbce0fc48c22a39347a895f41f4a6ea4 \ + --hash=sha256:aa5129de4e174daccbc59d0a3b6d20eaf24417d59851c07ebb37aeb02947987c \ + --hash=sha256:aeaefa96c768fc66818730b952a862235d68825c178f1b3ffd4efd7ad2edcb7c \ + --hash=sha256:afbefa430092f71a9593a99ab6a4e7538bc9eabbf7bf94f91510d3503943edc4 \ + --hash=sha256:aff9e4d82d082ff9513bdd6acd4f5bd359f5b2c870907d2b0a9c5e10d40c88fe \ + --hash=sha256:b22bd8c974942477156be55a768f7aa37c46904c175be4e158b6a86e3a6b7ca8 \ + --hash=sha256:b290fd8aa38422444d4b50d579de197557f182ef1068b75f5aa8558638b8d0a5 \ + --hash=sha256:b2e4b27a6e15b04832fe9bf292b94b5ca156016bbc1ea9c2c20098a0320d6cf6 \ + --hash=sha256:b583dc9070312190192631373c6c8ed277254aa6e6084b74bdd0a6d3b221608e \ + --hash=sha256:b87843e225e74576437fd5b6a4c2205d422754f84a06942cfaf1dc32243e45a8 \ + --hash=sha256:bc91a56697869546d1b8f0a3ff35224557ae7f881050e99f615e0119bf934b4e \ + --hash=sha256:bd87e140e45399c818fac4247880b9ce719e4783d767e030a883a970be632275 \ + --hash=sha256:bde737cff1a975b70652b62d626f7785e0480918dece11e8fef3c0cf057351c3 \ + --hash=sha256:bdee52571a343d721fb2eb3b090a82d959ff37fc631e3f70422e0c2e029f3e76 \ + --hash=sha256:bee2a6db3a7242ea309aa7ee8e2780726fed67ff4e5b40169f2c940e7eb09227 \ + --hash=sha256:beeae3f27f62308f1ddbcfb0690bf44b10732f2ef43758f169d5e9303165d3f9 \ + --hash=sha256:c50f36a62a22d350c96e49ad02d0da41dbd17ddc2e29750dbdba4323f85eb4a5 \ + --hash=sha256:c607c90ba67533e1b2355b821fef6764d1dd2cbe26b8c1005ae84f7aea25ff79 \ + --hash=sha256:c7b2a63fd6d5246349f3d3f37b14430d73ee7e8173154461785e43036ffa96ca \ + --hash=sha256:c828a1ae702fc712978bda0320ba1b9893d99be0badf2647f693cc01cf0f04fa \ + --hash=sha256:c85de1136429c524e55cfa4e033b4a7940ac5c8ee4d9401cc2d1bf48154bbc7b \ + --hash=sha256:c98fa880d695de164b4135a52fd2e9cd7b7c90a9d8ac5e9e443a24a95ef9248e \ + --hash=sha256:cae81479f77420d217def5f54b5b9d279804d17e982e0f2fa19b1d1e14ab5197 \ + --hash=sha256:d034140032870024e6b9892c692fe2968493790dd57208b2c37e3fb35f6df3ab \ + --hash=sha256:d120c38a42c234dc9a8c5de7ceaaf899cf33561956acb4941653f8bdc657aa79 \ + --hash=sha256:d4827615da15cd59784ce39d3388275ec093ae3ee8d7f0c089b76fa87af756c2 \ + --hash=sha256:d49e2314c373f4c2b39446fb1a45ed333c850e09d0c59ac79b72eb3b95397363 \ + --hash=sha256:d52610d51e265a51518692045e372a4c363056130d922a7351429ac9f27e70b0 \ + --hash=sha256:d64317d2587c70324b79861babb9c09f71fbb780bad212018874b2c013d8600e \ + --hash=sha256:d77153e14b709fd8b8af6f66a3afbb9ed6e9fc5ccf0b6b7e1ced7b036a228782 \ + --hash=sha256:d7e091d464ac59d2c7ad8e7e08105eaf9dafbc3883fd7265ffccc2baad6ac925 \ + --hash=sha256:dd333073e0cacdc3089525c7df7d39b211bcdf31fc2824e49d01c6b6187b07d0 \ + --hash=sha256:e5d8efac84c9afcb40914ab49ba063d94f5dbdf5066db4482c66a992f47a3a3b \ + --hash=sha256:f135c702ac42262573fe9714dfe99c944b4ba307af5eb507abef1667e2cbbced \ + --hash=sha256:f13711b1a5ba512d647a0e4ba79280d3a9a045aaf7e0cc6fbe96b91d4cdf6b0c \ + --hash=sha256:f4f1231b7dec408e8670264ce63e9c71409d9583dd21d32c163e25213ee2a344 \ + --hash=sha256:fa3ed2a29a9e9d2d488b4da81dcb54720ac3104a20bf0bd273f1e4648aff5af9 \ + --hash=sha256:fb3096c30df99fd01c7bf8e544f392103d0795b9f98ba71a8054bcbf56b255f1 # via # -r build/test-requirements.txt # matplotlib -pluggy==1.5.0 \ - --hash=sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1 \ - --hash=sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669 + # tensorboard +pluggy==1.6.0 \ + --hash=sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3 \ + --hash=sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746 # via pytest portpicker==1.6.0 \ --hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \ --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa # via -r build/test-requirements.txt -psutil==5.9.8 \ - --hash=sha256:02615ed8c5ea222323408ceba16c60e99c3f91639b07da6373fb7e6539abc56d \ - --hash=sha256:05806de88103b25903dff19bb6692bd2e714ccf9e668d050d144012055cbca73 \ - --hash=sha256:26bd09967ae00920df88e0352a91cff1a78f8d69b3ecabbfe733610c0af486c8 \ - --hash=sha256:27cc40c3493bb10de1be4b3f07cae4c010ce715290a5be22b98493509c6299e2 \ - --hash=sha256:36f435891adb138ed3c9e58c6af3e2e6ca9ac2f365efe1f9cfef2794e6c93b4e \ - --hash=sha256:50187900d73c1381ba1454cf40308c2bf6f34268518b3f36a9b663ca87e65e36 \ - --hash=sha256:611052c4bc70432ec770d5d54f64206aa7203a101ec273a0cd82418c86503bb7 \ - --hash=sha256:6be126e3225486dff286a8fb9a06246a5253f4c7c53b475ea5f5ac934e64194c \ - --hash=sha256:7d79560ad97af658a0f6adfef8b834b53f64746d45b403f225b85c5c2c140eee \ - --hash=sha256:8cb6403ce6d8e047495a701dc7c5bd788add903f8986d523e3e20b98b733e421 \ - --hash=sha256:8db4c1b57507eef143a15a6884ca10f7c73876cdf5d51e713151c1236a0e68cf \ - --hash=sha256:aee678c8720623dc456fa20659af736241f575d79429a0e5e9cf88ae0605cc81 \ - --hash=sha256:bc56c2a1b0d15aa3eaa5a60c9f3f8e3e565303b465dbf57a1b730e7a2b9844e0 \ - --hash=sha256:bd1184ceb3f87651a67b2708d4c3338e9b10c5df903f2e3776b62303b26cb631 \ - --hash=sha256:d06016f7f8625a1825ba3732081d77c94589dca78b7a3fc072194851e88461a4 \ - --hash=sha256:d16bbddf0693323b8c6123dd804100241da461e41d6e332fb0ba6058f630f8c8 +protobuf==6.33.2 \ + --hash=sha256:1f8017c48c07ec5859106533b682260ba3d7c5567b1ca1f24297ce03384d1b4f \ + --hash=sha256:2981c58f582f44b6b13173e12bb8656711189c2a70250845f264b877f00b1913 \ + --hash=sha256:56dc370c91fbb8ac85bc13582c9e373569668a290aa2e66a590c2a0d35ddb9e4 \ + --hash=sha256:7109dcc38a680d033ffb8bf896727423528db9163be1b6a02d6a49606dcadbfe \ + --hash=sha256:7636aad9bb01768870266de5dc009de2d1b936771b38a793f73cbbf279c91c5c \ + --hash=sha256:87eb388bd2d0f78febd8f4c8779c79247b26a5befad525008e49a6955787ff3d \ + --hash=sha256:8cd7640aee0b7828b6d03ae518b5b4806fdfc1afe8de82f79c3454f8aef29872 \ + --hash=sha256:b5d3b5625192214066d99b2b605f5783483575656784de223f00a8d00754fc0e \ + --hash=sha256:d9b19771ca75935b3a4422957bc518b0cecb978b31d1dd12037b088f6bcc0e43 \ + --hash=sha256:fc2a0e8b05b180e5fc0dd1559fe8ebdae21a27e81ac77728fb6c42b12c7419b4 + # via + # tensorboard + # tensorflow +psutil==7.1.3 \ + --hash=sha256:0005da714eee687b4b8decd3d6cc7c6db36215c9e74e5ad2264b90c3df7d92dc \ + --hash=sha256:1068c303be3a72f8e18e412c5b2a8f6d31750fb152f9cb106b54090296c9d251 \ + --hash=sha256:18349c5c24b06ac5612c0428ec2a0331c26443d259e2a0144a9b24b4395b58fa \ + --hash=sha256:19644c85dcb987e35eeeaefdc3915d059dac7bd1167cdcdbf27e0ce2df0c08c0 \ + --hash=sha256:2bdbcd0e58ca14996a42adf3621a6244f1bb2e2e528886959c72cf1e326677ab \ + --hash=sha256:31d77fcedb7529f27bb3a0472bea9334349f9a04160e8e6e5020f22c59893264 \ + --hash=sha256:3792983e23b69843aea49c8f5b8f115572c5ab64c153bada5270086a2123c7e7 \ + --hash=sha256:3bb428f9f05c1225a558f53e30ccbad9930b11c3fc206836242de1091d3e7dd3 \ + --hash=sha256:56d974e02ca2c8eb4812c3f76c30e28836fffc311d55d979f1465c1feeb2b68b \ + --hash=sha256:6c86281738d77335af7aec228328e944b30930899ea760ecf33a4dba66be5e74 \ + --hash=sha256:8f33a3702e167783a9213db10ad29650ebf383946e91bc77f28a5eb083496bc9 \ + --hash=sha256:95ef04cf2e5ba0ab9eaafc4a11eaae91b44f4ef5541acd2ee91d9108d00d59a7 \ + --hash=sha256:ad81425efc5e75da3f39b3e636293360ad8d0b49bed7df824c79764fb4ba9b8b \ + --hash=sha256:b403da1df4d6d43973dc004d19cee3b848e998ae3154cc8097d139b77156c353 \ + --hash=sha256:bc31fa00f1fbc3c3802141eede66f3a2d51d89716a194bf2cd6fc68310a19880 \ + --hash=sha256:bd0d69cee829226a761e92f28140bec9a5ee9d5b4fb4b0cc589068dbfff559b1 \ + --hash=sha256:c525ffa774fe4496282fb0b1187725793de3e7c6b29e41562733cae9ada151ee \ + --hash=sha256:f39c2c19fe824b47484b96f9692932248a54c43799a84282cfe58d05a6449efd \ + --hash=sha256:fac9cd332c67f4422504297889da5ab7e05fd11e3c4392140f7370f4208ded1f # via portpicker -pyelftools==0.31 \ - --hash=sha256:c774416b10310156879443b81187d182d8d9ee499660380e645918b50bc88f99 \ - --hash=sha256:f52de7b3c7e8c64c8abc04a79a1cf37ac5fb0b8a49809827130b858944840607 +pyelftools==0.32 \ + --hash=sha256:013df952a006db5e138b1edf6d8a68ecc50630adbd0d83a2d41e7f846163d738 \ + --hash=sha256:6de90ee7b8263e740c8715a925382d4099b354f29ac48ea40d840cf7aa14ace5 # via auditwheel -pygments==2.18.0 \ - --hash=sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199 \ - --hash=sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a - # via rich -pyparsing==3.1.2 \ - --hash=sha256:a1bac0ce561155ecc3ed78ca94d3c9378656ad4c94c1270de543f621420f94ad \ - --hash=sha256:f9db75911801ed778fe61bb643079ff86601aca99fcae6345aa67292038fb742 +pygments==2.19.2 \ + --hash=sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887 \ + --hash=sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b + # via + # pytest + # rich +pyparsing==3.2.5 \ + --hash=sha256:2df8d5b7b2802ef88e8d016a2eb9c7aeaa923529cd251ed0fe4608275d4105b6 \ + --hash=sha256:e38a4f02064cf41fe6593d328d0512495ad1f3d8a91c4f73fc401b3079a59a5e # via matplotlib -pyproject-hooks==1.1.0 \ - --hash=sha256:4b37730834edbd6bd37f26ece6b44802fb1c1ee2ece0e54ddff8bfc06db86965 \ - --hash=sha256:7ceeefe9aec63a1064c18d939bdc3adf2d8aa1988a510afec15151578b232aa2 +pyproject-hooks==1.2.0 \ + --hash=sha256:1e859bd5c40fae9448642dd871adf459e5e2084186e8d2c2a79a824c970da1f8 \ + --hash=sha256:9e5c6bfa8dcc30091c74b0cf803c81fdd29d94f01992a7707bc97babb1141913 # via build -pytest==8.2.0 \ - --hash=sha256:1733f0620f6cda4095bbf0d9ff8022486e91892245bb9e7d5542c018f612f233 \ - --hash=sha256:d507d4482197eac0ba2bae2e9babf0672eb333017bcedaa5fb1a3d42c1174b3f - # via pytest-xdist -pytest-xdist==3.6.1 \ - --hash=sha256:9ed4adfb68a016610848639bb7e02c9352d5d9f03d04809919e2dafc3be4cca7 \ - --hash=sha256:ead156a4db231eec769737f57668ef58a2084a34b2e55c4a8fa20d861107300d +pytest==8.4.2 \ + --hash=sha256:86c0d0b93306b961d58d62a4db4879f27fe25513d4b969df351abdddb3c30e01 \ + --hash=sha256:872f880de3fc3a5bdc88a11b39c9710c3497a547cfa9320bc3c5e62fbf272e79 + # via + # -r build/test-requirements.txt + # pytest-xdist +pytest-xdist==3.8.0 \ + --hash=sha256:202ca578cfeb7370784a8c33d6d05bc6e13b4f25b5053c30a152269fd10f0b88 \ + --hash=sha256:7e578125ec9bc6050861aa93f2d59f1d8d085595d6551c2c90b6f4fad8d3a9f1 # via -r build/test-requirements.txt python-dateutil==2.9.0.post0 \ --hash=sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3 \ --hash=sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427 # via matplotlib -rich==13.7.1 \ - --hash=sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222 \ - --hash=sha256:9be308cb1fe2f1f57d67ce99e95af38a1e2bc71ad9813b0e247cf7ffbcc3a432 +requests==2.32.5 \ + --hash=sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6 \ + --hash=sha256:dbba0bac56e100853db0ea71b82b4dfd5fe2bf6d3754a8893c3af500cec7d7cf + # via tensorflow +rich==14.2.0 \ + --hash=sha256:73ff50c7c0c1c77c8243079283f4edb376f0f6442433aecb8ce7e6d0b92d1fe4 \ + --hash=sha256:76bc51fe2e57d2b1be1f96c524b890b816e334ab4c1e45888799bfaab0021edd + # via + # -r build/test-requirements.txt + # keras +scipy==1.16.3 ; python_version <= "3.12" \ + --hash=sha256:0151a0749efeaaab78711c78422d413c583b8cdd2011a3c1d6c794938ee9fdb2 \ + --hash=sha256:01e87659402762f43bd2fee13370553a17ada367d42e7487800bf2916535aecb \ + --hash=sha256:03192a35e661470197556de24e7cb1330d84b35b94ead65c46ad6f16f6b28f2a \ + --hash=sha256:0553371015692a898e1aa858fed67a3576c34edefa6b7ebdb4e9dde49ce5c203 \ + --hash=sha256:062246acacbe9f8210de8e751b16fc37458213f124bef161a5a02c7a39284304 \ + --hash=sha256:0c3b4dd3d9b08dbce0f3440032c52e9e2ab9f96ade2d3943313dfe51a7056959 \ + --hash=sha256:0c623a54f7b79dd88ef56da19bc2873afec9673a48f3b85b18e4d402bdd29a5a \ + --hash=sha256:16b8bc35a4cc24db80a0ec836a9286d0e31b2503cb2fd7ff7fb0e0374a97081d \ + --hash=sha256:1fb2472e72e24d1530debe6ae078db70fb1605350c88a3d14bc401d6306dbffe \ + --hash=sha256:21d9d6b197227a12dcbf9633320a4e34c6b0e51c57268df255a0942983bac562 \ + --hash=sha256:2a207a6ce9c24f1951241f4693ede2d393f59c07abc159b2cb2be980820e01fb \ + --hash=sha256:2b71d93c8a9936046866acebc915e2af2e292b883ed6e2cbe5c34beb094b82d9 \ + --hash=sha256:2d1ae2cf0c350e7705168ff2429962a89ad90c2d49d1dd300686d8b2a5af22fc \ + --hash=sha256:3a4c460301fb2cffb7f88528f30b3127742cff583603aa7dc964a52c463b385d \ + --hash=sha256:3d4a07a8e785d80289dfe66b7c27d8634a773020742ec7187b85ccc4b0e7b686 \ + --hash=sha256:40be6cf99e68b6c4321e9f8782e7d5ff8265af28ef2cd56e9c9b2638fa08ad97 \ + --hash=sha256:4aff59800a3b7f786b70bfd6ab551001cb553244988d7d6b8299cb1ea653b353 \ + --hash=sha256:50a3dbf286dbc7d84f176f9a1574c705f277cb6565069f88f60db9eafdbe3ee2 \ + --hash=sha256:532fb5ad6a87e9e9cd9c959b106b73145a03f04c7d57ea3e6f6bb60b86ab0876 \ + --hash=sha256:53c3844d527213631e886621df5695d35e4f6a75f620dca412bcd292f6b87d78 \ + --hash=sha256:56edc65510d1331dae01ef9b658d428e33ed48b4f77b1d51caf479a0253f96dc \ + --hash=sha256:57d01cb6f85e34f0946b33caa66e892aae072b64b034183f3d87c4025802a119 \ + --hash=sha256:5803c5fadd29de0cf27fa08ccbfe7a9e5d741bf63e4ab1085437266f12460ff9 \ + --hash=sha256:6020470b9d00245926f2d5bb93b119ca0340f0d564eb6fbaad843eaebf9d690f \ + --hash=sha256:63d3cdacb8a824a295191a723ee5e4ea7768ca5ca5f2838532d9f2e2b3ce2135 \ + --hash=sha256:663b8d66a8748051c3ee9c96465fb417509315b99c71550fda2591d7dd634234 \ + --hash=sha256:72d1717fd3b5e6ec747327ce9bda32d5463f472c9dce9f54499e81fbd50245a1 \ + --hash=sha256:7dc1360c06535ea6116a2220f760ae572db9f661aba2d88074fe30ec2aa1ff88 \ + --hash=sha256:7f68154688c515cdb541a31ef8eb66d8cd1050605be9dcd74199cbd22ac739bc \ + --hash=sha256:81fc5827606858cf71446a5e98715ba0e11f0dbc83d71c7409d05486592a45d6 \ + --hash=sha256:875555ce62743e1d54f06cdf22c1e0bc47b91130ac40fe5d783b6dfa114beeb6 \ + --hash=sha256:8b3c820ddb80029fe9f43d61b81d8b488d3ef8ca010d15122b152db77dc94c22 \ + --hash=sha256:8be1ca9170fcb6223cc7c27f4305d680ded114a1567c0bd2bfcbf947d1b17511 \ + --hash=sha256:8d09d72dc92742988b0e7750bddb8060b0c7079606c0d24a8cc8e9c9c11f9079 \ + --hash=sha256:9452781bd879b14b6f055b26643703551320aa8d79ae064a71df55c00286a184 \ + --hash=sha256:96491a6a54e995f00a28a3c3badfff58fd093bf26cd5fb34a2188c8c756a3a2c \ + --hash=sha256:9b9c9c07b6d56a35777a1b4cc8966118fb16cfd8daf6743867d17d36cfad2d40 \ + --hash=sha256:a8a26c78ef223d3e30920ef759e25625a0ecdd0d60e5a8818b7513c3e5384cf2 \ + --hash=sha256:aadd23f98f9cb069b3bd64ddc900c4d277778242e961751f77a8cb5c4b946fb0 \ + --hash=sha256:b7180967113560cca57418a7bc719e30366b47959dd845a93206fbed693c867e \ + --hash=sha256:b7c5f1bda1354d6a19bc6af73a649f8285ca63ac6b52e64e658a5a11d4d69800 \ + --hash=sha256:b81c27fc41954319a943d43b20e07c40bdcd3ff7cf013f4fb86286faefe546c4 \ + --hash=sha256:bb61878c18a470021fb515a843dc7a76961a8daceaaaa8bad1332f1bf4b54657 \ + --hash=sha256:bea0a62734d20d67608660f69dcda23e7f90fb4ca20974ab80b6ed40df87a005 \ + --hash=sha256:c5192722cffe15f9329a3948c4b1db789fbb1f05c97899187dcf009b283aea70 \ + --hash=sha256:c97176013d404c7346bf57874eaac5187d969293bf40497140b0a2b2b7482e07 \ + --hash=sha256:cd13e354df9938598af2be05822c323e97132d5e6306b83a3b4ee6724c6e522e \ + --hash=sha256:d2ec56337675e61b312179a1ad124f5f570c00f920cc75e1000025451b88241c \ + --hash=sha256:d3837938ae715fc0fe3c39c0202de3a8853aff22ca66781ddc2ade7554b7e2cc \ + --hash=sha256:d9f48cafc7ce94cf9b15c6bffdc443a81a27bf7075cf2dcd5c8b40f85d10c4e7 \ + --hash=sha256:da7763f55885045036fabcebd80144b757d3db06ab0861415d1c3b7c69042146 \ + --hash=sha256:deb3841c925eeddb6afc1e4e4a45e418d19ec7b87c5df177695224078e8ec733 \ + --hash=sha256:e1d27cbcb4602680a49d787d90664fa4974063ac9d4134813332a8c53dbe667c \ + --hash=sha256:e5d42a9472e7579e473879a1990327830493a7047506d58d73fc429b84c1d49d \ + --hash=sha256:e7efa2681ea410b10dde31a52b18b0154d66f2485328830e45fdf183af5aefc6 \ + --hash=sha256:eab43fae33a0c39006a88096cd7b4f4ef545ea0447d250d5ac18202d40b6611d \ + --hash=sha256:f2622206f5559784fa5c4b53a950c3c7c1cf3e84ca1b9c4b6c03f062f289ca26 \ + --hash=sha256:f379b54b77a597aa7ee5e697df0d66903e41b9c85a6dd7946159e356319158e8 \ + --hash=sha256:f667a4542cc8917af1db06366d3f78a5c8e83badd56409f94d1eac8d8d9133fa \ + --hash=sha256:fb4b29f4cf8cc5a8d628bc8d8e26d12d7278cd1f219f22698a378c3d67db5e4b \ + --hash=sha256:ffa6eea95283b2b8079b821dc11f50a17d0571c92b43e2b5b12764dc5f9b285d + # via + # -r build/requirements.in + # jaxlib +scipy-stubs==1.16.3.3 \ + --hash=sha256:af47578875d5557567225a16ec1b9b38a48c4c4377d92396413ebd65406c44ee \ + --hash=sha256:f6316b36cd0fb272c994ae5b10c4a73c644a7e156ed8d32bcd9c35303d0e1b7e # via -r build/test-requirements.txt -scipy==1.13.1 \ - --hash=sha256:017367484ce5498445aade74b1d5ab377acdc65e27095155e448c88497755a5d \ - --hash=sha256:095a87a0312b08dfd6a6155cbbd310a8c51800fc931b8c0b84003014b874ed3c \ - --hash=sha256:20335853b85e9a49ff7572ab453794298bcf0354d8068c5f6775a0eabf350aca \ - --hash=sha256:27e52b09c0d3a1d5b63e1105f24177e544a222b43611aaf5bc44d4a0979e32f9 \ - --hash=sha256:2831f0dc9c5ea9edd6e51e6e769b655f08ec6db6e2e10f86ef39bd32eb11da54 \ - --hash=sha256:2ac65fb503dad64218c228e2dc2d0a0193f7904747db43014645ae139c8fad16 \ - --hash=sha256:392e4ec766654852c25ebad4f64e4e584cf19820b980bc04960bca0b0cd6eaa2 \ - --hash=sha256:436bbb42a94a8aeef855d755ce5a465479c721e9d684de76bf61a62e7c2b81d5 \ - --hash=sha256:45484bee6d65633752c490404513b9ef02475b4284c4cfab0ef946def50b3f59 \ - --hash=sha256:54f430b00f0133e2224c3ba42b805bfd0086fe488835effa33fa291561932326 \ - --hash=sha256:5713f62f781eebd8d597eb3f88b8bf9274e79eeabf63afb4a737abc6c84ad37b \ - --hash=sha256:5d72782f39716b2b3509cd7c33cdc08c96f2f4d2b06d51e52fb45a19ca0c86a1 \ - --hash=sha256:637e98dcf185ba7f8e663e122ebf908c4702420477ae52a04f9908707456ba4d \ - --hash=sha256:8335549ebbca860c52bf3d02f80784e91a004b71b059e3eea9678ba994796a24 \ - --hash=sha256:949ae67db5fa78a86e8fa644b9a6b07252f449dcf74247108c50e1d20d2b4627 \ - --hash=sha256:a014c2b3697bde71724244f63de2476925596c24285c7a637364761f8710891c \ - --hash=sha256:a78b4b3345f1b6f68a763c6e25c0c9a23a9fd0f39f5f3d200efe8feda560a5fa \ - --hash=sha256:cdd7dacfb95fea358916410ec61bbc20440f7860333aee6d882bb8046264e949 \ - --hash=sha256:cfa31f1def5c819b19ecc3a8b52d28ffdcc7ed52bb20c9a7589669dd3c250989 \ - --hash=sha256:d533654b7d221a6a97304ab63c41c96473ff04459e404b83275b60aa8f4b7004 \ - --hash=sha256:d605e9c23906d1994f55ace80e0125c587f96c020037ea6aa98d01b4bd2e222f \ - --hash=sha256:de3ade0e53bc1f21358aa74ff4830235d716211d7d077e340c7349bc3542e884 \ - --hash=sha256:e89369d27f9e7b0884ae559a3a956e77c02114cc60a6058b4e5011572eea9299 \ - --hash=sha256:eccfa1906eacc02de42d70ef4aecea45415f5be17e72b61bafcfd329bdc52e94 \ - --hash=sha256:f26264b282b9da0952a024ae34710c2aff7d27480ee91a2e82b7b7073c24722f - # via -r build/requirements.in -six==1.16.0 \ - --hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \ - --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 - # via python-dateutil +six==1.17.0 \ + --hash=sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274 \ + --hash=sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81 + # via + # astunparse + # google-pasta + # python-dateutil + # tensorflow sortedcontainers==2.4.0 \ --hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \ --hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0 # via hypothesis -typing-extensions==4.12.0rc1 \ - --hash=sha256:be199d06d8f09ca2c9425e3aa04a9afba33e892fe079dea959e72df7f8442343 \ - --hash=sha256:f933a7b288a919ca97adbff656e52ff81f7ff25d98a2aabb9355ca4090f772fe - # via etils -wheel==0.43.0 \ - --hash=sha256:465ef92c69fa5c5da2d1cf8ac40559a8c940886afcef87dcf14b9470862f1d85 \ - --hash=sha256:55c570405f142630c6b9f72fe09d9b67cf1477fcf543ae5b8dcb1f5b7377da81 - # via -r build/test-requirements.txt -zipp==3.18.2 \ - --hash=sha256:6278d9ddbcfb1f1089a88fde84481528b07b0e10474e09dcfe53dad4069fa059 \ - --hash=sha256:dce197b859eb796242b0622af1b8beb0a722d52aa2f57133ead08edd5bf5374e +tensorboard==2.20.0 \ + --hash=sha256:9dc9f978cb84c0723acf9a345d96c184f0293d18f166bb8d59ee098e6cfaaba6 + # via tensorflow +tensorboard-data-server==0.7.2 \ + --hash=sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb \ + --hash=sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60 \ + --hash=sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530 + # via tensorboard +tensorflow==2.20.0 ; python_version < "3.14" \ + --hash=sha256:02a0293d94f5c8b7125b66abf622cc4854a33ae9d618a0d41309f95e091bbaea \ + --hash=sha256:0deb5c583dfc53b54fd158a194ce0087b406bb6518af400ca3809735e4548ec3 \ + --hash=sha256:1590cbf87b6bcbd34d8e9ad70d0c696135e0aa71be31803b27358cf7ed63f8fc \ + --hash=sha256:197f0b613b38c0da5c6a12a8295ad4a05c78b853835dae8e0f9dfae3ce9ce8a5 \ + --hash=sha256:25265b0bc527e0d54b1e9cc60c44a24f44a809fe27666b905f0466471f9c52ec \ + --hash=sha256:28bc33759249c98eabcee9debd24e74506bbe29ac139e050cf0c74aa9888ebdf \ + --hash=sha256:2bfbfb3dd0e22bffc45fe1e922390d27753e99261fab8a882e802cf98a0e078f \ + --hash=sha256:3e9568c8efcb05c0266be223e3269c62ebf7ad3498f156438311735f6fa5ced5 \ + --hash=sha256:47c88e05a07f1ead4977b4894b3ecd4d8075c40191065afc4fd9355c9db3d926 \ + --hash=sha256:481499fd0f824583de8945be61d5e827898cdaa4f5ea1bc2cc28ca2ccff8229e \ + --hash=sha256:4a69ac2c2ce20720abf3abf917b4e86376326c0976fcec3df330e184b81e4088 \ + --hash=sha256:52b122f0232fd7ab10f28d537ce08470d0b6dcac7fff9685432daac7f8a06c8f \ + --hash=sha256:5f964016c5035d09b85a246a6b739be89282a7839743f3ea63640224f0c63aee \ + --hash=sha256:5fa3729b0126f75a99882b89fb7d536515721eda8014a63e259e780ba0a37372 \ + --hash=sha256:7551558a48c2e2f6c32a1537f06c654a9df1408a1c18e7b99c3caafbd03edfe3 \ + --hash=sha256:7abd7f3a010e0d354dc804182372779a722d474c4d8a3db8f4a3f5baef2a591e \ + --hash=sha256:a66cbd1b19209d3fbc45cbea80de92514ba455434013937251d65d444779783c \ + --hash=sha256:c25edad45e8cb9e76366f7a8c835279f9169028d610f3b52ce92d332a1b05438 \ + --hash=sha256:dd71a7e7c3270239f4185915e8f2c5d39608c5e18973d6e1d101b153993841eb \ + --hash=sha256:e5f169f8f5130ab255bbe854c5f0ae152e93d3d1ac44f42cb1866003b81a5357 + # via -r build/nonfreethreading-requirements.txt +tensorstore==0.1.80 \ + --hash=sha256:04c29d979eb8b8ee48f873dc13d2701bfd49425500ffc5b848e4ec55b2548281 \ + --hash=sha256:07e4a84bacf70b78305831897068a9b5ad30326e63bbeb92c4bf7e565fcf5e9e \ + --hash=sha256:1113a6982fc0fa8dda8fcc0495715e647ac3360909a86ff13f2e04564f82d54a \ + --hash=sha256:189d924eaec394c9331e284a9c513ed583e336472a925823b5151cb26f41d091 \ + --hash=sha256:1b2b2ed0051dfab7e25295b14e6620520729e6e2ddf505f98c8d3917569614bf \ + --hash=sha256:246641a8780ee5e04e88bc95c8e31faac6471bab1180d1f5cdc9804b29a77c04 \ + --hash=sha256:4158fe76b96f62d12a37d7868150d836e089b5280b2bdd363c43c5d651f10e26 \ + --hash=sha256:46136fe42ee6dd835d957db37073058aea0b78fdfbe2975941640131b7740824 \ + --hash=sha256:4baee67fce95f29f593fbab4866119347115eaace887732aa92cfcbb9e6b0748 \ + --hash=sha256:53fd121ccd332bc4cc397f7af45889360c668b43dc3ff6bc3264df0f9886c11a \ + --hash=sha256:6b7c5dd434bba4ee08fe46bbbdb25c60dd3d47ccb4b8561a9751cf1526da52b8 \ + --hash=sha256:6c8dbbdd31cbb28eccfb23dbbd4218fe67bfc32e9cb452875a485b81031c949d \ + --hash=sha256:7451b30f99d9f31a2b9d70e6ef61815713dc782c58c6d817f91781341e4dac05 \ + --hash=sha256:8cd11027b5a8b66db8d344085a31a1666c78621dac27039c4d571bc4974804a1 \ + --hash=sha256:9c088e8c9f67c266ef4dae3703bd617f7c0cb0fd98e99c4500692e38a4328140 \ + --hash=sha256:a92505189731fcb03f1c69a84ea4460abb24204bfac1f339448a0621e7def77c \ + --hash=sha256:acb8d52fadcefafef4ef8ecca3fc99b1d0e3c5c5a888766484c3e39f050be7f5 \ + --hash=sha256:b193a7a1c4f455a61e60ed2dd67271a3daab0910ddb4bd9db51390d1b36d9996 \ + --hash=sha256:bc28a58c580253a526a4b6d239d18181ef96f1e285a502dbb03ff15eeec07a5b \ + --hash=sha256:c0529afab3800749dd245843d3bf0d061a109a8edb77fb345f476e8bccda51b8 \ + --hash=sha256:d2b353b0bd53fedd77fc5a12a1c1a91cacc3cf59e3dd785529c5a54b31d1c7b1 \ + --hash=sha256:de63843706fdfe9565a45567238c5b1e55a0b28bbde6524200b31d29043a9a16 \ + --hash=sha256:e93df6d34ff5f0f6be245f4d29b99a7c1eef8ad91b50686adf57a5eeea99cb74 \ + --hash=sha256:f65dfaf9e737a41389e29a5a2ea52ca5d14c8d6f48b402c723d800cd16d322b0 \ + --hash=sha256:f8b51d7e685bbb63f6becd7d2ac8634d5ab67ec7e53038e597182e2db2c7aa90 + # via -r build/nonfreethreading-requirements.txt +termcolor==3.2.0 \ + --hash=sha256:610e6456feec42c4bcd28934a8c87a06c3fa28b01561d46aa09a9881b8622c58 \ + --hash=sha256:a10343879eba4da819353c55cb8049b0933890c2ebf9ad5d3ecd2bb32ea96ea6 + # via tensorflow +typing-extensions==4.15.0 \ + --hash=sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466 \ + --hash=sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548 + # via + # etils + # grpcio + # optree + # optype + # tensorflow +urllib3==2.6.2 \ + --hash=sha256:016f9c98bb7e98085cb2b4b17b87d2c702975664e4f060c6532e64d1c1a5e797 \ + --hash=sha256:ec21cddfe7724fc7cb4ba4bea7aa8e2ef36f607a4bab81aa6ce42a13dc3f03dd + # via requests +werkzeug==3.1.4 \ + --hash=sha256:2ad50fb9ed09cc3af22c54698351027ace879a0b60a3b5edf5730b2f7d876905 \ + --hash=sha256:cd3cd98b1b92dc3b7b3995038826c68097dcb16f9baa63abe35f20eafeb9fe5e + # via tensorboard +wheel==0.46.1 \ + --hash=sha256:f796f65d72750ccde090663e466d0ca37cd72b62870f7520b96d34cdc07d86d8 \ + --hash=sha256:fd477efb5da0f7df1d3c76c73c14394002c844451bd63229d8570f376f5e6a38 + # via + # -r build/requirements.in + # astunparse +wrapt==2.0.1 \ + --hash=sha256:09c7476ab884b74dce081ad9bfd07fe5822d8600abade571cb1f66d5fc915af6 \ + --hash=sha256:0e17283f533a0d24d6e5429a7d11f250a58d28b4ae5186f8f47853e3e70d2590 \ + --hash=sha256:115cae4beed3542e37866469a8a1f2b9ec549b4463572b000611e9946b86e6f6 \ + --hash=sha256:1218573502a8235bb8a7ecaed12736213b22dcde9feab115fa2989d42b5ded45 \ + --hash=sha256:17fb85fa4abc26a5184d93b3efd2dcc14deb4b09edcdb3535a536ad34f0b4dba \ + --hash=sha256:1e9b121e9aeb15df416c2c960b8255a49d44b4038016ee17af03975992d03931 \ + --hash=sha256:1f186e26ea0a55f809f232e92cc8556a0977e00183c3ebda039a807a42be1494 \ + --hash=sha256:1fdbb34da15450f2b1d735a0e969c24bdb8d8924892380126e2a293d9902078c \ + --hash=sha256:23097ed8bc4c93b7bf36fa2113c6c733c976316ce0ee2c816f64ca06102034ef \ + --hash=sha256:2879af909312d0baf35f08edeea918ee3af7ab57c37fe47cb6a373c9f2749c7b \ + --hash=sha256:2afa23318136709c4b23d87d543b425c399887b4057936cd20386d5b1422b6fa \ + --hash=sha256:2da620b31a90cdefa9cd0c2b661882329e2e19d1d7b9b920189956b76c564d75 \ + --hash=sha256:35cdbd478607036fee40273be8ed54a451f5f23121bd9d4be515158f9498f7ad \ + --hash=sha256:36982b26f190f4d737f04a492a68accbfc6fa042c3f42326fdfbb6c5b7a20a31 \ + --hash=sha256:3793ac154afb0e5b45d1233cb94d354ef7a983708cc3bb12563853b1d8d53747 \ + --hash=sha256:386fb54d9cd903ee0012c09291336469eb7b244f7183d40dc3e86a16a4bace62 \ + --hash=sha256:3cd1a4bd9a7a619922a8557e1318232e7269b5fb69d4ba97b04d20450a6bf970 \ + --hash=sha256:3d32794fe940b7000f0519904e247f902f0149edbe6316c710a8562fb6738841 \ + --hash=sha256:3d366aa598d69416b5afedf1faa539fac40c1d80a42f6b236c88c73a3c8f2d41 \ + --hash=sha256:3e271346f01e9c8b1130a6a3b0e11908049fe5be2d365a5f402778049147e7e9 \ + --hash=sha256:3f373a4ab5dbc528a94334f9fe444395b23c2f5332adab9ff4ea82f5a9e33bc1 \ + --hash=sha256:3fa272ca34332581e00bf7773e993d4f632594eb2d1b0b162a9038df0fd971dd \ + --hash=sha256:47434236c396d04875180171ee1f3815ca1eada05e24a1ee99546320d54d1d1b \ + --hash=sha256:47b0f8bafe90f7736151f61482c583c86b0693d80f075a58701dd1549b0010a9 \ + --hash=sha256:4811e15d88ee62dbf5c77f2c3ff3932b1e3ac92323ba3912f51fc4016ce81ecf \ + --hash=sha256:49989061a9977a8cbd6d20f2efa813f24bf657c6990a42967019ce779a878dbf \ + --hash=sha256:4ae879acc449caa9ed43fc36ba08392b9412ee67941748d31d94e3cedb36628c \ + --hash=sha256:4b55cacc57e1dc2d0991dbe74c6419ffd415fb66474a02335cb10efd1aa3f84f \ + --hash=sha256:4d2ce1bf1a48c5277d7969259232b57645aae5686dba1eaeade39442277afbca \ + --hash=sha256:4da7384b0e5d4cae05c97cd6f94faaf78cc8b0f791fc63af43436d98c4ab37bb \ + --hash=sha256:4e54bbf554ee29fcceee24fa41c4d091398b911da6e7f5d7bffda963c9aed2e1 \ + --hash=sha256:50844efc8cdf63b2d90cd3d62d4947a28311e6266ce5235a219d21b195b4ec2c \ + --hash=sha256:5a4939eae35db6b6cec8e7aa0e833dcca0acad8231672c26c2a9ab7a0f8ac9c8 \ + --hash=sha256:5dc1b852337c6792aa111ca8becff5bacf576bf4a0255b0f05eb749da6a1643e \ + --hash=sha256:5e53b428f65ece6d9dad23cb87e64506392b720a0b45076c05354d27a13351a1 \ + --hash=sha256:61c4956171c7434634401db448371277d07032a81cc21c599c22953374781395 \ + --hash=sha256:641e94e789b5f6b4822bb8d8ebbdfc10f4e4eae7756d648b717d980f657a9eb9 \ + --hash=sha256:64b103acdaa53b7caf409e8d45d39a8442fe6dcfec6ba3f3d141e0cc2b5b4dbd \ + --hash=sha256:68424221a2dc00d634b54f92441914929c5ffb1c30b3b837343978343a3512a3 \ + --hash=sha256:6bd1a18f5a797fe740cb3d7a0e853a8ce6461cc62023b630caec80171a6b8097 \ + --hash=sha256:6c72328f668cf4c503ffcf9434c2b71fdd624345ced7941bc6693e61bbe36bef \ + --hash=sha256:6d2d947d266d99a1477cd005b23cbd09465276e302515e122df56bb9511aca1b \ + --hash=sha256:7164a55f5e83a9a0b031d3ffab4d4e36bbec42e7025db560f225489fa929e509 \ + --hash=sha256:7b219cb2182f230676308cdcacd428fa837987b89e4b7c5c9025088b8a6c9faf \ + --hash=sha256:7d539241e87b650cbc4c3ac9f32c8d1ac8a54e510f6dca3f6ab60dcfd48c9b10 \ + --hash=sha256:7de3cc939be0e1174969f943f3b44e0d79b6f9a82198133a5b7fc6cc92882f16 \ + --hash=sha256:8330b42d769965e96e01fa14034b28a2a7600fbf7e8f0cc90ebb36d492c993e4 \ + --hash=sha256:837e31620e06b16030b1d126ed78e9383815cbac914693f54926d816d35d8edf \ + --hash=sha256:83ce30937f0ba0d28818807b303a412440c4b63e39d3d8fc036a94764b728c92 \ + --hash=sha256:85df8d92158cb8f3965aecc27cf821461bb5f40b450b03facc5d9f0d4d6ddec6 \ + --hash=sha256:8639b843c9efd84675f1e100ed9e99538ebea7297b62c4b45a7042edb84db03e \ + --hash=sha256:89a82053b193837bf93c0f8a57ded6e4b6d88033a499dadff5067e912c2a41e9 \ + --hash=sha256:8bacfe6e001749a3b64db47bcf0341da757c95959f592823a93931a422395013 \ + --hash=sha256:8ec3303e8a81932171f455f792f8df500fc1a09f20069e5c16bd7049ab4e8e38 \ + --hash=sha256:90897ea1cf0679763b62e79657958cd54eae5659f6360fc7d2ccc6f906342183 \ + --hash=sha256:908f8c6c71557f4deaa280f55d0728c3bca0960e8c3dd5ceeeafb3c19942719d \ + --hash=sha256:91bcc576260a274b169c3098e9a3519fb01f2989f6d3d386ef9cbf8653de1374 \ + --hash=sha256:9219a1d946a9b32bb23ccae66bdb61e35c62773ce7ca6509ceea70f344656b7b \ + --hash=sha256:949520bccc1fa227274da7d03bf238be15389cd94e32e4297b92337df9b7a349 \ + --hash=sha256:98d873ed6c8b4ee2418f7afce666751854d6d03e3c0ec2a399bb039cd2ae89db \ + --hash=sha256:9c9c635e78497cacb81e84f8b11b23e0aacac7a136e73b8e5b2109a1d9fc468f \ + --hash=sha256:9ca66b38dd642bf90c59b6738af8070747b610115a39af2498535f62b5cdc1c3 \ + --hash=sha256:a453257f19c31b31ba593c30d997d6e5be39e3b5ad9148c2af5a7314061c63eb \ + --hash=sha256:a52f93d95c8d38fed0669da2ebdb0b0376e895d84596a976c15a9eb45e3eccb3 \ + --hash=sha256:a9a83618c4f0757557c077ef71d708ddd9847ed66b7cc63416632af70d3e2308 \ + --hash=sha256:ab594f346517010050126fcd822697b25a7031d815bb4fbc238ccbe568216489 \ + --hash=sha256:ad3ee9d0f254851c71780966eb417ef8e72117155cff04821ab9b60549694a55 \ + --hash=sha256:aea9c7224c302bc8bfc892b908537f56c430802560e827b75ecbde81b604598b \ + --hash=sha256:b4c2e3d777e38e913b8ce3a6257af72fb608f86a1df471cb1d4339755d0a807c \ + --hash=sha256:b667189cf8efe008f55bbda321890bef628a67ab4147ebf90d182f2dadc78790 \ + --hash=sha256:b89ef9223d665ab255ae42cc282d27d69704d94be0deffc8b9d919179a609684 \ + --hash=sha256:be9e84e91d6497ba62594158d3d31ec0486c60055c49179edc51ee43d095f79c \ + --hash=sha256:bf4cb76f36be5de950ce13e22e7fdf462b35b04665a12b64f3ac5c1bbbcf3728 \ + --hash=sha256:bfb5539005259f8127ea9c885bdc231978c06b7a980e63a8a61c8c4c979719d0 \ + --hash=sha256:c046781d422f0830de6329fa4b16796096f28a92c8aef3850674442cdcb87b7f \ + --hash=sha256:c1be685ac7700c966b8610ccc63c3187a72e33cab53526a27b2a285a662cd4f7 \ + --hash=sha256:c1c91405fcf1d501fa5d55df21e58ea49e6b879ae829f1039faaf7e5e509b41e \ + --hash=sha256:c235095d6d090aa903f1db61f892fffb779c1eaeb2a50e566b52001f7a0f66ed \ + --hash=sha256:c4012a2bd37059d04f8209916aa771dfb564cccb86079072bdcd48a308b6a5c5 \ + --hash=sha256:c5ef2f2b8a53b7caee2f797ef166a390fef73979b15778a4a153e4b5fedce8fa \ + --hash=sha256:c654eafb01afac55246053d67a4b9a984a3567c3808bb7df2f8de1c1caba2e1c \ + --hash=sha256:c8d60527d1ecfc131426b10d93ab5d53e08a09c5fa0175f6b21b3252080c70a9 \ + --hash=sha256:c9e850f5b7fc67af856ff054c71690d54fa940c3ef74209ad9f935b4f66a0233 \ + --hash=sha256:cbeb0971e13b4bd81d34169ed57a6dda017328d1a22b62fda45e1d21dd06148f \ + --hash=sha256:d1a8a09a004ef100e614beec82862d11fc17d601092c3599afd22b1f36e4137e \ + --hash=sha256:d67956c676be5a24102c7407a71f4126d30de2a569a1c7871c9f3cabc94225d7 \ + --hash=sha256:d6cc985b9c8b235bd933990cdbf0f891f8e010b65a3911f7a55179cd7b0fc57b \ + --hash=sha256:d7b822c61ed04ee6ad64bc90d13368ad6eb094db54883b5dde2182f67a7f22c0 \ + --hash=sha256:df0b6d3b95932809c5b3fecc18fda0f1e07452d05e2662a0b35548985f256e28 \ + --hash=sha256:e042d653a4745be832d5aa190ff80ee4f02c34b21f4b785745eceacd0907b815 \ + --hash=sha256:e2f84e9af2060e3904a32cea9bb6db23ce3f91cfd90c6b426757cf7cc01c45c7 \ + --hash=sha256:e3612dc06b436968dfb9142c62e5dfa9eb5924f91120b3c8ff501ad878f90eb3 \ + --hash=sha256:e505629359cb5f751e16e30cf3f91a1d3ddb4552480c205947da415d597f7ac2 \ + --hash=sha256:e60690ba71a57424c8d9ff28f8d006b7ad7772c22a4af432188572cd7fa004a1 \ + --hash=sha256:e76e3f91f864e89db8b8d2a8311d57df93f01ad6bb1e9b9976d1f2e83e18315c \ + --hash=sha256:eb7cffe572ad0a141a7886a1d2efa5bef0bf7fe021deeea76b3ab334d2c38218 \ + --hash=sha256:ec65a78fbd9d6f083a15d7613b2800d5663dbb6bb96003899c834beaa68b242c \ + --hash=sha256:eda8e4ecd662d48c28bb86be9e837c13e45c58b8300e43ba3c9b4fa9900302f7 \ + --hash=sha256:f26f8e2ca19564e2e1fdbb6a0e47f36e0efbab1acc31e15471fad88f828c75f6 \ + --hash=sha256:f49027b0b9503bf6c8cdc297ca55006b80c2f5dd36cecc72c6835ab6e10e8a25 \ + --hash=sha256:f73f9f7a0ebd0db139253d27e5fc8d2866ceaeef19c30ab5d69dcbe35e1a6981 \ + --hash=sha256:fa4184e74197af3adad3c889a1af95b53bb0466bced92ea99a0c014e48323eec \ + --hash=sha256:fb1a5b72cbd751813adc02ef01ada0b0d05d3dcbc32976ce189a1279d80ad4a2 \ + --hash=sha256:fb3a86e703868561c5cad155a15c36c716e1ab513b7065bd2ac8ed353c503333 \ + --hash=sha256:fc007fdf480c77301ab1afdbb6ab22a5deee8885f3b1ed7afcb7e5e84a0e27be \ + --hash=sha256:fe21b118b9f58859b5ebaa4b130dee18669df4bd111daad082b7beb8799ad16b \ + --hash=sha256:fec0d993ecba3991645b4857837277469c8cc4c554a7e24d064d1ca291cfb81f + # via tensorflow +zipp==3.23.0 \ + --hash=sha256:071652d6115ed432f5ce1d34c336c0adfd6a884660d1e9712a256d3d3bd4b14e \ + --hash=sha256:a07157588a12518c9d4034df3fbbee09c814741a33ff63c05fa29d26a2404166 # via etils -zstandard==0.22.0 \ - --hash=sha256:11f0d1aab9516a497137b41e3d3ed4bbf7b2ee2abc79e5c8b010ad286d7464bd \ - --hash=sha256:1958100b8a1cc3f27fa21071a55cb2ed32e9e5df4c3c6e661c193437f171cba2 \ - --hash=sha256:1a90ba9a4c9c884bb876a14be2b1d216609385efb180393df40e5172e7ecf356 \ - --hash=sha256:1d43501f5f31e22baf822720d82b5547f8a08f5386a883b32584a185675c8fbf \ - --hash=sha256:23d2b3c2b8e7e5a6cb7922f7c27d73a9a615f0a5ab5d0e03dd533c477de23004 \ - --hash=sha256:2612e9bb4977381184bb2463150336d0f7e014d6bb5d4a370f9a372d21916f69 \ - --hash=sha256:275df437ab03f8c033b8a2c181e51716c32d831082d93ce48002a5227ec93019 \ - --hash=sha256:2ac9957bc6d2403c4772c890916bf181b2653640da98f32e04b96e4d6fb3252a \ - --hash=sha256:2b11ea433db22e720758cba584c9d661077121fcf60ab43351950ded20283440 \ - --hash=sha256:2fdd53b806786bd6112d97c1f1e7841e5e4daa06810ab4b284026a1a0e484c0b \ - --hash=sha256:33591d59f4956c9812f8063eff2e2c0065bc02050837f152574069f5f9f17775 \ - --hash=sha256:36a47636c3de227cd765e25a21dc5dace00539b82ddd99ee36abae38178eff9e \ - --hash=sha256:39b2853efc9403927f9065cc48c9980649462acbdf81cd4f0cb773af2fd734bc \ - --hash=sha256:3db41c5e49ef73641d5111554e1d1d3af106410a6c1fb52cf68912ba7a343a0d \ - --hash=sha256:445b47bc32de69d990ad0f34da0e20f535914623d1e506e74d6bc5c9dc40bb09 \ - --hash=sha256:466e6ad8caefb589ed281c076deb6f0cd330e8bc13c5035854ffb9c2014b118c \ - --hash=sha256:48f260e4c7294ef275744210a4010f116048e0c95857befb7462e033f09442fe \ - --hash=sha256:4ac59d5d6910b220141c1737b79d4a5aa9e57466e7469a012ed42ce2d3995e88 \ - --hash=sha256:53866a9d8ab363271c9e80c7c2e9441814961d47f88c9bc3b248142c32141d94 \ - --hash=sha256:589402548251056878d2e7c8859286eb91bd841af117dbe4ab000e6450987e08 \ - --hash=sha256:68953dc84b244b053c0d5f137a21ae8287ecf51b20872eccf8eaac0302d3e3b0 \ - --hash=sha256:6c25b8eb733d4e741246151d895dd0308137532737f337411160ff69ca24f93a \ - --hash=sha256:7034d381789f45576ec3f1fa0e15d741828146439228dc3f7c59856c5bcd3292 \ - --hash=sha256:73a1d6bd01961e9fd447162e137ed949c01bdb830dfca487c4a14e9742dccc93 \ - --hash=sha256:8226a33c542bcb54cd6bd0a366067b610b41713b64c9abec1bc4533d69f51e70 \ - --hash=sha256:888196c9c8893a1e8ff5e89b8f894e7f4f0e64a5af4d8f3c410f0319128bb2f8 \ - --hash=sha256:88c5b4b47a8a138338a07fc94e2ba3b1535f69247670abfe422de4e0b344aae2 \ - --hash=sha256:8a1b2effa96a5f019e72874969394edd393e2fbd6414a8208fea363a22803b45 \ - --hash=sha256:93e1856c8313bc688d5df069e106a4bc962eef3d13372020cc6e3ebf5e045202 \ - --hash=sha256:9501f36fac6b875c124243a379267d879262480bf85b1dbda61f5ad4d01b75a3 \ - --hash=sha256:959665072bd60f45c5b6b5d711f15bdefc9849dd5da9fb6c873e35f5d34d8cfb \ - --hash=sha256:a1d67d0d53d2a138f9e29d8acdabe11310c185e36f0a848efa104d4e40b808e4 \ - --hash=sha256:a493d470183ee620a3df1e6e55b3e4de8143c0ba1b16f3ded83208ea8ddfd91d \ - --hash=sha256:a7ccf5825fd71d4542c8ab28d4d482aace885f5ebe4b40faaa290eed8e095a4c \ - --hash=sha256:a88b7df61a292603e7cd662d92565d915796b094ffb3d206579aaebac6b85d5f \ - --hash=sha256:a97079b955b00b732c6f280d5023e0eefe359045e8b83b08cf0333af9ec78f26 \ - --hash=sha256:d22fdef58976457c65e2796e6730a3ea4a254f3ba83777ecfc8592ff8d77d303 \ - --hash=sha256:d75f693bb4e92c335e0645e8845e553cd09dc91616412d1d4650da835b5449df \ - --hash=sha256:d8593f8464fb64d58e8cb0b905b272d40184eac9a18d83cf8c10749c3eafcd7e \ - --hash=sha256:d8fff0f0c1d8bc5d866762ae95bd99d53282337af1be9dc0d88506b340e74b73 \ - --hash=sha256:de20a212ef3d00d609d0b22eb7cc798d5a69035e81839f549b538eff4105d01c \ - --hash=sha256:e9e9d4e2e336c529d4c435baad846a181e39a982f823f7e4495ec0b0ec8538d2 \ - --hash=sha256:f058a77ef0ece4e210bb0450e68408d4223f728b109764676e1a13537d056bb0 \ - --hash=sha256:f1a4b358947a65b94e2501ce3e078bbc929b039ede4679ddb0460829b12f7375 \ - --hash=sha256:f9b2cde1cd1b2a10246dbc143ba49d942d14fb3d2b4bccf4618d475c65464912 \ - --hash=sha256:fe3390c538f12437b859d815040763abc728955a52ca6ff9c5d4ac707c4ad98e - # via -r build/requirements.in +zstandard==0.25.0 ; python_version < "3.14" \ + --hash=sha256:011d388c76b11a0c165374ce660ce2c8efa8e5d87f34996aa80f9c0816698b64 \ + --hash=sha256:01582723b3ccd6939ab7b3a78622c573799d5d8737b534b86d0e06ac18dbde4a \ + --hash=sha256:05353cef599a7b0b98baca9b068dd36810c3ef0f42bf282583f438caf6ddcee3 \ + --hash=sha256:05df5136bc5a011f33cd25bc9f506e7426c0c9b3f9954f056831ce68f3b6689f \ + --hash=sha256:06acb75eebeedb77b69048031282737717a63e71e4ae3f77cc0c3b9508320df6 \ + --hash=sha256:07b527a69c1e1c8b5ab1ab14e2afe0675614a09182213f21a0717b62027b5936 \ + --hash=sha256:0bbc9a0c65ce0eea3c34a691e3c4b6889f5f3909ba4822ab385fab9057099431 \ + --hash=sha256:0be7622c37c183406f3dbf0cba104118eb16a4ea7359eeb5752f0794882fc250 \ + --hash=sha256:106281ae350e494f4ac8a80470e66d1fe27e497052c8d9c3b95dc4cf1ade81aa \ + --hash=sha256:10ef2a79ab8e2974e2075fb984e5b9806c64134810fac21576f0668e7ea19f8f \ + --hash=sha256:1673b7199bbe763365b81a4f3252b8e80f44c9e323fc42940dc8843bfeaf9851 \ + --hash=sha256:172de1f06947577d3a3005416977cce6168f2261284c02080e7ad0185faeced3 \ + --hash=sha256:181eb40e0b6a29b3cd2849f825e0fa34397f649170673d385f3598ae17cca2e9 \ + --hash=sha256:1869da9571d5e94a85a5e8d57e4e8807b175c9e4a6294e3b66fa4efb074d90f6 \ + --hash=sha256:19796b39075201d51d5f5f790bf849221e58b48a39a5fc74837675d8bafc7362 \ + --hash=sha256:1cd5da4d8e8ee0e88be976c294db744773459d51bb32f707a0f166e5ad5c8649 \ + --hash=sha256:1f3689581a72eaba9131b1d9bdbfe520ccd169999219b41000ede2fca5c1bfdb \ + --hash=sha256:1f830a0dac88719af0ae43b8b2d6aef487d437036468ef3c2ea59c51f9d55fd5 \ + --hash=sha256:223415140608d0f0da010499eaa8ccdb9af210a543fac54bce15babbcfc78439 \ + --hash=sha256:22a06c5df3751bb7dc67406f5374734ccee8ed37fc5981bf1ad7041831fa1137 \ + --hash=sha256:22a086cff1b6ceca18a8dd6096ec631e430e93a8e70a9ca5efa7561a00f826fa \ + --hash=sha256:23ebc8f17a03133b4426bcc04aabd68f8236eb78c3760f12783385171b0fd8bd \ + --hash=sha256:25f8f3cd45087d089aef5ba3848cd9efe3ad41163d3400862fb42f81a3a46701 \ + --hash=sha256:2b6bd67528ee8b5c5f10255735abc21aa106931f0dbaf297c7be0c886353c3d0 \ + --hash=sha256:2e54296a283f3ab5a26fc9b8b5d4978ea0532f37b231644f367aa588930aa043 \ + --hash=sha256:3756b3e9da9b83da1796f8809dd57cb024f838b9eeafde28f3cb472012797ac1 \ + --hash=sha256:37daddd452c0ffb65da00620afb8e17abd4adaae6ce6310702841760c2c26860 \ + --hash=sha256:3a39c94ad7866160a4a46d772e43311a743c316942037671beb264e395bdd611 \ + --hash=sha256:3b870ce5a02d4b22286cf4944c628e0f0881b11b3f14667c1d62185a99e04f53 \ + --hash=sha256:3c83b0188c852a47cd13ef3bf9209fb0a77fa5374958b8c53aaa699398c6bd7b \ + --hash=sha256:4203ce3b31aec23012d3a4cf4a2ed64d12fea5269c49aed5e4c3611b938e4088 \ + --hash=sha256:457ed498fc58cdc12fc48f7950e02740d4f7ae9493dd4ab2168a47c93c31298e \ + --hash=sha256:474d2596a2dbc241a556e965fb76002c1ce655445e4e3bf38e5477d413165ffa \ + --hash=sha256:4b14abacf83dfb5c25eb4e4a79520de9e7e205f72c9ee7702f91233ae57d33a2 \ + --hash=sha256:4b6d83057e713ff235a12e73916b6d356e3084fd3d14ced499d84240f3eecee0 \ + --hash=sha256:4d441506e9b372386a5271c64125f72d5df6d2a8e8a2a45a0ae09b03cb781ef7 \ + --hash=sha256:4f187a0bb61b35119d1926aee039524d1f93aaf38a9916b8c4b78ac8514a0aaf \ + --hash=sha256:51526324f1b23229001eb3735bc8c94f9c578b1bd9e867a0a646a3b17109f388 \ + --hash=sha256:53e08b2445a6bc241261fea89d065536f00a581f02535f8122eba42db9375530 \ + --hash=sha256:53f94448fe5b10ee75d246497168e5825135d54325458c4bfffbaafabcc0a577 \ + --hash=sha256:5a56ba0db2d244117ed744dfa8f6f5b366e14148e00de44723413b2f3938a902 \ + --hash=sha256:5f1ad7bf88535edcf30038f6919abe087f606f62c00a87d7e33e7fc57cb69fcc \ + --hash=sha256:5f5e4c2a23ca271c218ac025bd7d635597048b366d6f31f420aaeb715239fc98 \ + --hash=sha256:6a573a35693e03cf1d67799fd01b50ff578515a8aeadd4595d2a7fa9f3ec002a \ + --hash=sha256:6c0e5a65158a7946e7a7affa6418878ef97ab66636f13353b8502d7ea03c8097 \ + --hash=sha256:6dffecc361d079bb48d7caef5d673c88c8988d3d33fb74ab95b7ee6da42652ea \ + --hash=sha256:7030defa83eef3e51ff26f0b7bfb229f0204b66fe18e04359ce3474ac33cbc09 \ + --hash=sha256:7149623bba7fdf7e7f24312953bcf73cae103db8cae49f8154dd1eadc8a29ecb \ + --hash=sha256:72d35d7aa0bba323965da807a462b0966c91608ef3a48ba761678cb20ce5d8b7 \ + --hash=sha256:75ffc32a569fb049499e63ce68c743155477610532da1eb38e7f24bf7cd29e74 \ + --hash=sha256:7713e1179d162cf5c7906da876ec2ccb9c3a9dcbdffef0cc7f70c3667a205f0b \ + --hash=sha256:78228d8a6a1c177a96b94f7e2e8d012c55f9c760761980da16ae7546a15a8e9b \ + --hash=sha256:7b3c3a3ab9daa3eed242d6ecceead93aebbb8f5f84318d82cee643e019c4b73b \ + --hash=sha256:809c5bcb2c67cd0ed81e9229d227d4ca28f82d0f778fc5fea624a9def3963f91 \ + --hash=sha256:81dad8d145d8fd981b2962b686b2241d3a1ea07733e76a2f15435dfb7fb60150 \ + --hash=sha256:85304a43f4d513f5464ceb938aa02c1e78c2943b29f44a750b48b25ac999a049 \ + --hash=sha256:89c4b48479a43f820b749df49cd7ba2dbc2b1b78560ecb5ab52985574fd40b27 \ + --hash=sha256:8e735494da3db08694d26480f1493ad2cf86e99bdd53e8e9771b2752a5c0246a \ + --hash=sha256:913cbd31a400febff93b564a23e17c3ed2d56c064006f54efec210d586171c00 \ + --hash=sha256:9174f4ed06f790a6869b41cba05b43eeb9a35f8993c4422ab853b705e8112bbd \ + --hash=sha256:9300d02ea7c6506f00e627e287e0492a5eb0371ec1670ae852fefffa6164b072 \ + --hash=sha256:933b65d7680ea337180733cf9e87293cc5500cc0eb3fc8769f4d3c88d724ec5c \ + --hash=sha256:9654dbc012d8b06fc3d19cc825af3f7bf8ae242226df5f83936cb39f5fdc846c \ + --hash=sha256:98750a309eb2f020da61e727de7d7ba3c57c97cf6213f6f6277bb7fb42a8e065 \ + --hash=sha256:99c0c846e6e61718715a3c9437ccc625de26593fea60189567f0118dc9db7512 \ + --hash=sha256:a1a4ae2dec3993a32247995bdfe367fc3266da832d82f8438c8570f989753de1 \ + --hash=sha256:a3f79487c687b1fc69f19e487cd949bf3aae653d181dfb5fde3bf6d18894706f \ + --hash=sha256:a4089a10e598eae6393756b036e0f419e8c1d60f44a831520f9af41c14216cf2 \ + --hash=sha256:a51ff14f8017338e2f2e5dab738ce1ec3b5a851f23b18c1ae1359b1eecbee6df \ + --hash=sha256:a5a419712cf88862a45a23def0ae063686db3d324cec7edbe40509d1a79a0aab \ + --hash=sha256:a9ec8c642d1ec73287ae3e726792dd86c96f5681eb8df274a757bf62b750eae7 \ + --hash=sha256:aaf21ba8fb76d102b696781bddaa0954b782536446083ae3fdaa6f16b25a1c4b \ + --hash=sha256:ab85470ab54c2cb96e176f40342d9ed41e58ca5733be6a893b730e7af9c40550 \ + --hash=sha256:b9af1fe743828123e12b41dd8091eca1074d0c1569cc42e6e1eee98027f2bbd0 \ + --hash=sha256:bfc4e20784722098822e3eee42b8e576b379ed72cca4a7cb856ae733e62192ea \ + --hash=sha256:bfd06b1c5584b657a2892a6014c2f4c20e0db0208c159148fa78c65f7e0b0277 \ + --hash=sha256:c19bcdd826e95671065f8692b5a4aa95c52dc7a02a4c5a0cac46deb879a017a2 \ + --hash=sha256:c2ba942c94e0691467ab901fc51b6f2085ff48f2eea77b1a48240f011e8247c7 \ + --hash=sha256:c8e167d5adf59476fa3e37bee730890e389410c354771a62e3c076c86f9f7778 \ + --hash=sha256:ca54090275939dc8ec5dea2d2afb400e0f83444b2fc24e07df7fdef677110859 \ + --hash=sha256:d7541afd73985c630bafcd6338d2518ae96060075f9463d7dc14cfb33514383d \ + --hash=sha256:d8c56bb4e6c795fc77d74d8e8b80846e1fb8292fc0b5060cd8131d522974b751 \ + --hash=sha256:da469dc041701583e34de852d8634703550348d5822e66a0c827d39b05365b12 \ + --hash=sha256:daab68faadb847063d0c56f361a289c4f268706b598afbf9ad113cbe5c38b6b2 \ + --hash=sha256:e05ab82ea7753354bb054b92e2f288afb750e6b439ff6ca78af52939ebbc476d \ + --hash=sha256:e09bb6252b6476d8d56100e8147b803befa9a12cea144bbe629dd508800d1ad0 \ + --hash=sha256:e29f0cf06974c899b2c188ef7f783607dbef36da4c242eb6c82dcd8b512855e3 \ + --hash=sha256:e59fdc271772f6686e01e1b3b74537259800f57e24280be3f29c8a0deb1904dd \ + --hash=sha256:e7360eae90809efd19b886e59a09dad07da4ca9ba096752e61a2e03c8aca188e \ + --hash=sha256:e96594a5537722fdfb79951672a2a63aec5ebfb823e7560586f7484819f2a08f \ + --hash=sha256:ea9d54cc3d8064260114a0bbf3479fc4a98b21dffc89b3459edd506b69262f6e \ + --hash=sha256:ec996f12524f88e151c339688c3897194821d7f03081ab35d31d1e12ec975e94 \ + --hash=sha256:f27662e4f7dbf9f9c12391cb37b4c4c3cb90ffbd3b1fb9284dadbbb8935fa708 \ + --hash=sha256:f373da2c1757bb7f1acaf09369cdc1d51d84131e50d5fa9863982fd626466313 \ + --hash=sha256:f5aeea11ded7320a84dcdd62a3d95b5186834224a9e55b92ccae35d21a8b63d4 \ + --hash=sha256:f604efd28f239cc21b3adb53eb061e2a205dc164be408e553b41ba2ffe0ca15c \ + --hash=sha256:f67e8f1a324a900e75b5e28ffb152bcac9fbed1cc7b43f99cd90f395c4375344 \ + --hash=sha256:fd7a5004eb1980d3cefe26b2685bcb0b17989901a70a1040d1ac86f1d898c551 \ + --hash=sha256:ffef5a74088f1e09947aecf91011136665152e0b4b359c42be3373897fb39b01 + # via -r build/nonfreethreading-requirements.txt # The following packages are considered to be unsafe in a requirements file: -setuptools==76.0.0 \ - --hash=sha256:199466a166ff664970d0ee145839f5582cb9bca7a0a3a2e795b6a9cb2308e9c6 \ - --hash=sha256:43b4ee60e10b0d0ee98ad11918e114c70701bc6051662a9a675a0496c1a158f4 +setuptools==80.9.0 \ + --hash=sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922 \ + --hash=sha256:f36b47402ecde768dbfafc46e8e4207b4360c654f1f3bb84475f0a28628fb19c # via # -r build/requirements.in - # -r build/test-requirements.txt + # tensorboard + # tensorflow diff --git a/build/requirements_lock_3_13.txt b/build/requirements_lock_3_13.txt index e74d40b798f4..c6206076d2c1 100644 --- a/build/requirements_lock_3_13.txt +++ b/build/requirements_lock_3_13.txt @@ -4,813 +4,1763 @@ # # bazel run //build:requirements.update # -absl-py==2.1.0 \ - --hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \ - --hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff - # via -r build/test-requirements.txt -attrs==24.2.0 \ - --hash=sha256:5cfb1b9148b5b086569baec03f20d7b6bf3bcacc9a42bebf87ffaaca362f6346 \ - --hash=sha256:81921eb96de3191c8258c199618104dd27ac608d9366f5e35d011eae1867ede2 +--index-url https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple + +absl-py==2.3.1 \ + --hash=sha256:a97820526f7fbfd2ec1bce83f3f25e3a14840dac0d8e02a0b71cd75db3f77fc9 \ + --hash=sha256:eeecf07f0c2a93ace0772c92e596ace6d3d3996c042b2128459aaae2a76de11d + # via + # -r build/test-requirements.txt + # keras + # tensorboard + # tensorflow +astunparse==1.6.3 \ + --hash=sha256:5ad93a8456f0d084c3456d059fd9a92cce667963232cbf763eac3bc5b7940872 \ + --hash=sha256:c2652417f2c8b5bb325c885ae329bdf3f86424075c4fd1a128674bc6fba4b8e8 + # via tensorflow +attrs==25.4.0 \ + --hash=sha256:16d5969b87f0859ef33a48b35d55ac1be6e42ae49d5e853b597db70c35c57e11 \ + --hash=sha256:adcf7e2a1fb3b36ac48d97835bb6d8ade15b8dcce26aba8bf1d14847b57a3373 # via hypothesis -auditwheel==6.1.0 \ - --hash=sha256:3bdc686e774cf9e355e924b0fe5a562d55caa385d72234ffe7b81b378dba360f \ - --hash=sha256:e52f734861859e3743eb29fcac7da9c4921a1e4bea58f954b52f2926f8e9e364 +auditwheel==6.5.0 \ + --hash=sha256:4fbcbd5854054bb1dd7870db03727b871b96b18147db57259561c058603987d7 \ + --hash=sha256:e08d2eede0259be6feff597d041c06175026e93248a1a97143acc52c57714d80 # via -r build/test-requirements.txt -build==1.2.2.post1 \ - --hash=sha256:1d61c0887fa860c01971625baae8bdd338e517b836a2f70dd1f7aa3a6b2fc5b5 \ - --hash=sha256:b36993e92ca9375a219c99e606a122ff365a760a2d4bba0caa09bd5278b608b7 - # via -r build/test-requirements.txt -cloudpickle==3.0.0 \ - --hash=sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7 \ - --hash=sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882 +build==1.3.0 \ + --hash=sha256:698edd0ea270bde950f53aed21f3a0135672206f3911e0176261a31e0e07b397 \ + --hash=sha256:7145f0b5061ba90a1500d60bd1b13ca0a8a4cebdd0cc16ed8adf1c0e739f43b4 + # via -r build/requirements.in +certifi==2025.11.12 \ + --hash=sha256:97de8790030bbd5c2d96b7ec782fc2f7820ef8dba6db909ccf95449f2d062d4b \ + --hash=sha256:d8ab5478f2ecd78af242878415affce761ca6bc54a22a27e026d7c25357c3316 + # via requests +charset-normalizer==3.4.4 \ + --hash=sha256:027f6de494925c0ab2a55eab46ae5129951638a49a34d87f4c3eda90f696b4ad \ + --hash=sha256:077fbb858e903c73f6c9db43374fd213b0b6a778106bc7032446a8e8b5b38b93 \ + --hash=sha256:0a98e6759f854bd25a58a73fa88833fba3b7c491169f86ce1180c948ab3fd394 \ + --hash=sha256:0d3d8f15c07f86e9ff82319b3d9ef6f4bf907608f53fe9d92b28ea9ae3d1fd89 \ + --hash=sha256:0f04b14ffe5fdc8c4933862d8306109a2c51e0704acfa35d51598eb45a1e89fc \ + --hash=sha256:11d694519d7f29d6cd09f6ac70028dba10f92f6cdd059096db198c283794ac86 \ + --hash=sha256:194f08cbb32dc406d6e1aea671a68be0823673db2832b38405deba2fb0d88f63 \ + --hash=sha256:1bee1e43c28aa63cb16e5c14e582580546b08e535299b8b6158a7c9c768a1f3d \ + --hash=sha256:21d142cc6c0ec30d2efee5068ca36c128a30b0f2c53c1c07bd78cb6bc1d3be5f \ + --hash=sha256:2437418e20515acec67d86e12bf70056a33abdacb5cb1655042f6538d6b085a8 \ + --hash=sha256:244bfb999c71b35de57821b8ea746b24e863398194a4014e4c76adc2bbdfeff0 \ + --hash=sha256:2677acec1a2f8ef614c6888b5b4ae4060cc184174a938ed4e8ef690e15d3e505 \ + --hash=sha256:277e970e750505ed74c832b4bf75dac7476262ee2a013f5574dd49075879e161 \ + --hash=sha256:2aaba3b0819274cc41757a1da876f810a3e4d7b6eb25699253a4effef9e8e4af \ + --hash=sha256:2b7d8f6c26245217bd2ad053761201e9f9680f8ce52f0fcd8d0755aeae5b2152 \ + --hash=sha256:2c9d3c380143a1fedbff95a312aa798578371eb29da42106a29019368a475318 \ + --hash=sha256:3162d5d8ce1bb98dd51af660f2121c55d0fa541b46dff7bb9b9f86ea1d87de72 \ + --hash=sha256:31fd66405eaf47bb62e8cd575dc621c56c668f27d46a61d975a249930dd5e2a4 \ + --hash=sha256:362d61fd13843997c1c446760ef36f240cf81d3ebf74ac62652aebaf7838561e \ + --hash=sha256:376bec83a63b8021bb5c8ea75e21c4ccb86e7e45ca4eb81146091b56599b80c3 \ + --hash=sha256:44c2a8734b333e0578090c4cd6b16f275e07aa6614ca8715e6c038e865e70576 \ + --hash=sha256:47cc91b2f4dd2833fddaedd2893006b0106129d4b94fdb6af1f4ce5a9965577c \ + --hash=sha256:4902828217069c3c5c71094537a8e623f5d097858ac6ca8252f7b4d10b7560f1 \ + --hash=sha256:4bd5d4137d500351a30687c2d3971758aac9a19208fc110ccb9d7188fbe709e8 \ + --hash=sha256:4fe7859a4e3e8457458e2ff592f15ccb02f3da787fcd31e0183879c3ad4692a1 \ + --hash=sha256:542d2cee80be6f80247095cc36c418f7bddd14f4a6de45af91dfad36d817bba2 \ + --hash=sha256:554af85e960429cf30784dd47447d5125aaa3b99a6f0683589dbd27e2f45da44 \ + --hash=sha256:5833d2c39d8896e4e19b689ffc198f08ea58116bee26dea51e362ecc7cd3ed26 \ + --hash=sha256:5947809c8a2417be3267efc979c47d76a079758166f7d43ef5ae8e9f92751f88 \ + --hash=sha256:5ae497466c7901d54b639cf42d5b8c1b6a4fead55215500d2f486d34db48d016 \ + --hash=sha256:5bd2293095d766545ec1a8f612559f6b40abc0eb18bb2f5d1171872d34036ede \ + --hash=sha256:5bfbb1b9acf3334612667b61bd3002196fe2a1eb4dd74d247e0f2a4d50ec9bbf \ + --hash=sha256:5cb4d72eea50c8868f5288b7f7f33ed276118325c1dfd3957089f6b519e1382a \ + --hash=sha256:5dbe56a36425d26d6cfb40ce79c314a2e4dd6211d51d6d2191c00bed34f354cc \ + --hash=sha256:5f819d5fe9234f9f82d75bdfa9aef3a3d72c4d24a6e57aeaebba32a704553aa0 \ + --hash=sha256:64b55f9dce520635f018f907ff1b0df1fdc31f2795a922fb49dd14fbcdf48c84 \ + --hash=sha256:6515f3182dbe4ea06ced2d9e8666d97b46ef4c75e326b79bb624110f122551db \ + --hash=sha256:65e2befcd84bc6f37095f5961e68a6f077bf44946771354a28ad434c2cce0ae1 \ + --hash=sha256:6aee717dcfead04c6eb1ce3bd29ac1e22663cdea57f943c87d1eab9a025438d7 \ + --hash=sha256:6b39f987ae8ccdf0d2642338faf2abb1862340facc796048b604ef14919e55ed \ + --hash=sha256:6e1fcf0720908f200cd21aa4e6750a48ff6ce4afe7ff5a79a90d5ed8a08296f8 \ + --hash=sha256:74018750915ee7ad843a774364e13a3db91682f26142baddf775342c3f5b1133 \ + --hash=sha256:74664978bb272435107de04e36db5a9735e78232b85b77d45cfb38f758efd33e \ + --hash=sha256:74bb723680f9f7a6234dcf67aea57e708ec1fbdf5699fb91dfd6f511b0a320ef \ + --hash=sha256:752944c7ffbfdd10c074dc58ec2d5a8a4cd9493b314d367c14d24c17684ddd14 \ + --hash=sha256:778d2e08eda00f4256d7f672ca9fef386071c9202f5e4607920b86d7803387f2 \ + --hash=sha256:780236ac706e66881f3b7f2f32dfe90507a09e67d1d454c762cf642e6e1586e0 \ + --hash=sha256:798d75d81754988d2565bff1b97ba5a44411867c0cf32b77a7e8f8d84796b10d \ + --hash=sha256:799a7a5e4fb2d5898c60b640fd4981d6a25f1c11790935a44ce38c54e985f828 \ + --hash=sha256:7a32c560861a02ff789ad905a2fe94e3f840803362c84fecf1851cb4cf3dc37f \ + --hash=sha256:7c308f7e26e4363d79df40ca5b2be1c6ba9f02bdbccfed5abddb7859a6ce72cf \ + --hash=sha256:7fa17817dc5625de8a027cb8b26d9fefa3ea28c8253929b8d6649e705d2835b6 \ + --hash=sha256:81d5eb2a312700f4ecaa977a8235b634ce853200e828fbadf3a9c50bab278328 \ + --hash=sha256:82004af6c302b5d3ab2cfc4cc5f29db16123b1a8417f2e25f9066f91d4411090 \ + --hash=sha256:837c2ce8c5a65a2035be9b3569c684358dfbf109fd3b6969630a87535495ceaa \ + --hash=sha256:840c25fb618a231545cbab0564a799f101b63b9901f2569faecd6b222ac72381 \ + --hash=sha256:8a6562c3700cce886c5be75ade4a5db4214fda19fede41d9792d100288d8f94c \ + --hash=sha256:8af65f14dc14a79b924524b1e7fffe304517b2bff5a58bf64f30b98bbc5079eb \ + --hash=sha256:8ef3c867360f88ac904fd3f5e1f902f13307af9052646963ee08ff4f131adafc \ + --hash=sha256:94537985111c35f28720e43603b8e7b43a6ecfb2ce1d3058bbe955b73404e21a \ + --hash=sha256:99ae2cffebb06e6c22bdc25801d7b30f503cc87dbd283479e7b606f70aff57ec \ + --hash=sha256:9a26f18905b8dd5d685d6d07b0cdf98a79f3c7a918906af7cc143ea2e164c8bc \ + --hash=sha256:9b35f4c90079ff2e2edc5b26c0c77925e5d2d255c42c74fdb70fb49b172726ac \ + --hash=sha256:9cd98cdc06614a2f768d2b7286d66805f94c48cde050acdbbb7db2600ab3197e \ + --hash=sha256:9d1bb833febdff5c8927f922386db610b49db6e0d4f4ee29601d71e7c2694313 \ + --hash=sha256:9f7fcd74d410a36883701fafa2482a6af2ff5ba96b9a620e9e0721e28ead5569 \ + --hash=sha256:a59cb51917aa591b1c4e6a43c132f0cdc3c76dbad6155df4e28ee626cc77a0a3 \ + --hash=sha256:a61900df84c667873b292c3de315a786dd8dac506704dea57bc957bd31e22c7d \ + --hash=sha256:a79cfe37875f822425b89a82333404539ae63dbdddf97f84dcbc3d339aae9525 \ + --hash=sha256:a8a8b89589086a25749f471e6a900d3f662d1d3b6e2e59dcecf787b1cc3a1894 \ + --hash=sha256:a8bf8d0f749c5757af2142fe7903a9df1d2e8aa3841559b2bad34b08d0e2bcf3 \ + --hash=sha256:a9768c477b9d7bd54bc0c86dbaebdec6f03306675526c9927c0e8a04e8f94af9 \ + --hash=sha256:ac1c4a689edcc530fc9d9aa11f5774b9e2f33f9a0c6a57864e90908f5208d30a \ + --hash=sha256:af2d8c67d8e573d6de5bc30cdb27e9b95e49115cd9baad5ddbd1a6207aaa82a9 \ + --hash=sha256:b435cba5f4f750aa6c0a0d92c541fb79f69a387c91e61f1795227e4ed9cece14 \ + --hash=sha256:b5b290ccc2a263e8d185130284f8501e3e36c5e02750fc6b6bdeb2e9e96f1e25 \ + --hash=sha256:b5d84d37db046c5ca74ee7bb47dd6cbc13f80665fdde3e8040bdd3fb015ecb50 \ + --hash=sha256:b7cf1017d601aa35e6bb650b6ad28652c9cd78ee6caff19f3c28d03e1c80acbf \ + --hash=sha256:bc7637e2f80d8530ee4a78e878bce464f70087ce73cf7c1caf142416923b98f1 \ + --hash=sha256:c0463276121fdee9c49b98908b3a89c39be45d86d1dbaa22957e38f6321d4ce3 \ + --hash=sha256:c4ef880e27901b6cc782f1b95f82da9313c0eb95c3af699103088fa0ac3ce9ac \ + --hash=sha256:c8ae8a0f02f57a6e61203a31428fa1d677cbe50c93622b4149d5c0f319c1d19e \ + --hash=sha256:ca5862d5b3928c4940729dacc329aa9102900382fea192fc5e52eb69d6093815 \ + --hash=sha256:cb01158d8b88ee68f15949894ccc6712278243d95f344770fa7593fa2d94410c \ + --hash=sha256:cb6254dc36b47a990e59e1068afacdcd02958bdcce30bb50cc1700a8b9d624a6 \ + --hash=sha256:cc00f04ed596e9dc0da42ed17ac5e596c6ccba999ba6bd92b0e0aef2f170f2d6 \ + --hash=sha256:cd09d08005f958f370f539f186d10aec3377d55b9eeb0d796025d4886119d76e \ + --hash=sha256:cd4b7ca9984e5e7985c12bc60a6f173f3c958eae74f3ef6624bb6b26e2abbae4 \ + --hash=sha256:ce8a0633f41a967713a59c4139d29110c07e826d131a316b50ce11b1d79b4f84 \ + --hash=sha256:cead0978fc57397645f12578bfd2d5ea9138ea0fac82b2f63f7f7c6877986a69 \ + --hash=sha256:d055ec1e26e441f6187acf818b73564e6e6282709e9bcb5b63f5b23068356a15 \ + --hash=sha256:d1f13550535ad8cff21b8d757a3257963e951d96e20ec82ab44bc64aeb62a191 \ + --hash=sha256:d9c7f57c3d666a53421049053eaacdd14bbd0a528e2186fcb2e672effd053bb0 \ + --hash=sha256:d9e45d7faa48ee908174d8fe84854479ef838fc6a705c9315372eacbc2f02897 \ + --hash=sha256:da3326d9e65ef63a817ecbcc0df6e94463713b754fe293eaa03da99befb9a5bd \ + --hash=sha256:de00632ca48df9daf77a2c65a484531649261ec9f25489917f09e455cb09ddb2 \ + --hash=sha256:e1f185f86a6f3403aa2420e815904c67b2f9ebc443f045edd0de921108345794 \ + --hash=sha256:e824f1492727fa856dd6eda4f7cee25f8518a12f3c4a56a74e8095695089cf6d \ + --hash=sha256:e912091979546adf63357d7e2ccff9b44f026c075aeaf25a52d0e95ad2281074 \ + --hash=sha256:eaabd426fe94daf8fd157c32e571c85cb12e66692f15516a83a03264b08d06c3 \ + --hash=sha256:ebf3e58c7ec8a8bed6d66a75d7fb37b55e5015b03ceae72a8e7c74495551e224 \ + --hash=sha256:ecaae4149d99b1c9e7b88bb03e3221956f68fd6d50be2ef061b2381b61d20838 \ + --hash=sha256:eecbc200c7fd5ddb9a7f16c7decb07b566c29fa2161a16cf67b8d068bd21690a \ + --hash=sha256:f155a433c2ec037d4e8df17d18922c3a0d9b3232a396690f17175d2946f0218d \ + --hash=sha256:f1e34719c6ed0b92f418c7c780480b26b5d9c50349e9a9af7d76bf757530350d \ + --hash=sha256:f34be2938726fc13801220747472850852fe6b1ea75869a048d6f896838c896f \ + --hash=sha256:f820802628d2694cb7e56db99213f930856014862f3fd943d290ea8438d07ca8 \ + --hash=sha256:f8bf04158c6b607d747e93949aa60618b61312fe647a6369f88ce2ff16043490 \ + --hash=sha256:f8e160feb2aed042cd657a72acc0b481212ed28b1b9a95c0cee1621b524e1966 \ + --hash=sha256:f9d332f8c2a2fcbffe1378594431458ddbef721c1769d78e2cbc06280d8155f9 \ + --hash=sha256:fa09f53c465e532f4d3db095e0c55b615f010ad81803d383195b6b5ca6cbf5f3 \ + --hash=sha256:faa3a41b2b66b6e50f84ae4a68c64fcd0c44355741c6374813a800cd6695db9e \ + --hash=sha256:fd44c878ea55ba351104cb93cc85e74916eb8fa440ca7903e57575e97394f608 + # via requests +cloudpickle==3.1.2 \ + --hash=sha256:7fda9eb655c9c230dab534f1983763de5835249750e85fbcef43aaa30a9a2414 \ + --hash=sha256:9acb47f6afd73f60dc1df93bb801b472f05ff42fa6c84167d25cb206be1fbf4a # via -r build/test-requirements.txt colorama==0.4.6 \ --hash=sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44 \ --hash=sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6 - # via -r build/test-requirements.txt -contourpy==1.3.0 \ - --hash=sha256:00ccd0dbaad6d804ab259820fa7cb0b8036bda0686ef844d24125d8287178ce0 \ - --hash=sha256:0be4d8425bfa755e0fd76ee1e019636ccc7c29f77a7c86b4328a9eb6a26d0639 \ - --hash=sha256:0dce35502151b6bd35027ac39ba6e5a44be13a68f55735c3612c568cac3805fd \ - --hash=sha256:0fa4c02abe6c446ba70d96ece336e621efa4aecae43eaa9b030ae5fb92b309ad \ - --hash=sha256:14e262f67bd7e6eb6880bc564dcda30b15e351a594657e55b7eec94b6ef72843 \ - --hash=sha256:167d6c890815e1dac9536dca00828b445d5d0df4d6a8c6adb4a7ec3166812fa8 \ - --hash=sha256:1ec4dc6bf570f5b22ed0d7efba0dfa9c5b9e0431aeea7581aa217542d9e809a4 \ - --hash=sha256:303c252947ab4b14c08afeb52375b26781ccd6a5ccd81abcdfc1fafd14cf93c1 \ - --hash=sha256:31cd3a85dbdf1fc002280c65caa7e2b5f65e4a973fcdf70dd2fdcb9868069294 \ - --hash=sha256:32b238b3b3b649e09ce9aaf51f0c261d38644bdfa35cbaf7b263457850957a84 \ - --hash=sha256:33c92cdae89ec5135d036e7218e69b0bb2851206077251f04a6c4e0e21f03927 \ - --hash=sha256:345af746d7766821d05d72cb8f3845dfd08dd137101a2cb9b24de277d716def8 \ - --hash=sha256:3634b5385c6716c258d0419c46d05c8aa7dc8cb70326c9a4fb66b69ad2b52e09 \ - --hash=sha256:364174c2a76057feef647c802652f00953b575723062560498dc7930fc9b1cb7 \ - --hash=sha256:36e0cff201bcb17a0a8ecc7f454fe078437fa6bda730e695a92f2d9932bd507f \ - --hash=sha256:36f965570cff02b874773c49bfe85562b47030805d7d8360748f3eca570f4cab \ - --hash=sha256:3bb3808858a9dc68f6f03d319acd5f1b8a337e6cdda197f02f4b8ff67ad2057b \ - --hash=sha256:3e1c7fa44aaae40a2247e2e8e0627f4bea3dd257014764aa644f319a5f8600e3 \ - --hash=sha256:3faeb2998e4fcb256542e8a926d08da08977f7f5e62cf733f3c211c2a5586223 \ - --hash=sha256:420d39daa61aab1221567b42eecb01112908b2cab7f1b4106a52caaec8d36973 \ - --hash=sha256:4553c421929ec95fb07b3aaca0fae668b2eb5a5203d1217ca7c34c063c53d087 \ - --hash=sha256:4865cd1d419e0c7a7bf6de1777b185eebdc51470800a9f42b9e9decf17762081 \ - --hash=sha256:4cfb5c62ce023dfc410d6059c936dcf96442ba40814aefbfa575425a3a7f19dc \ - --hash=sha256:4d63ee447261e963af02642ffcb864e5a2ee4cbfd78080657a9880b8b1868e18 \ - --hash=sha256:570ef7cf892f0afbe5b2ee410c507ce12e15a5fa91017a0009f79f7d93a1268f \ - --hash=sha256:637f674226be46f6ba372fd29d9523dd977a291f66ab2a74fbeb5530bb3f445d \ - --hash=sha256:68a32389b06b82c2fdd68276148d7b9275b5f5cf13e5417e4252f6d1a34f72a2 \ - --hash=sha256:69375194457ad0fad3a839b9e29aa0b0ed53bb54db1bfb6c3ae43d111c31ce41 \ - --hash=sha256:6cb6cc968059db9c62cb35fbf70248f40994dfcd7aa10444bbf8b3faeb7c2d67 \ - --hash=sha256:710a26b3dc80c0e4febf04555de66f5fd17e9cf7170a7b08000601a10570bda6 \ - --hash=sha256:732896af21716b29ab3e988d4ce14bc5133733b85956316fb0c56355f398099b \ - --hash=sha256:75ee7cb1a14c617f34a51d11fa7524173e56551646828353c4af859c56b766e2 \ - --hash=sha256:76a896b2f195b57db25d6b44e7e03f221d32fe318d03ede41f8b4d9ba1bff53c \ - --hash=sha256:76c905ef940a4474a6289c71d53122a4f77766eef23c03cd57016ce19d0f7b42 \ - --hash=sha256:7a52040312b1a858b5e31ef28c2e865376a386c60c0e248370bbea2d3f3b760d \ - --hash=sha256:7ffa0db17717a8ffb127efd0c95a4362d996b892c2904db72428d5b52e1938a4 \ - --hash=sha256:81cb5ed4952aae6014bc9d0421dec7c5835c9c8c31cdf51910b708f548cf58e5 \ - --hash=sha256:834e0cfe17ba12f79963861e0f908556b2cedd52e1f75e6578801febcc6a9f49 \ - --hash=sha256:87ddffef1dbe5e669b5c2440b643d3fdd8622a348fe1983fad7a0f0ccb1cd67b \ - --hash=sha256:880ea32e5c774634f9fcd46504bf9f080a41ad855f4fef54f5380f5133d343c7 \ - --hash=sha256:8ca947601224119117f7c19c9cdf6b3ab54c5726ef1d906aa4a69dfb6dd58102 \ - --hash=sha256:90f73a5116ad1ba7174341ef3ea5c3150ddf20b024b98fb0c3b29034752c8aeb \ - --hash=sha256:92f8557cbb07415a4d6fa191f20fd9d2d9eb9c0b61d1b2f52a8926e43c6e9af7 \ - --hash=sha256:94e848a6b83da10898cbf1311a815f770acc9b6a3f2d646f330d57eb4e87592e \ - --hash=sha256:9c0da700bf58f6e0b65312d0a5e695179a71d0163957fa381bb3c1f72972537c \ - --hash=sha256:a11077e395f67ffc2c44ec2418cfebed032cd6da3022a94fc227b6faf8e2acb8 \ - --hash=sha256:aea348f053c645100612b333adc5983d87be69acdc6d77d3169c090d3b01dc35 \ - --hash=sha256:b11b39aea6be6764f84360fce6c82211a9db32a7c7de8fa6dd5397cf1d079c3b \ - --hash=sha256:c6c7c2408b7048082932cf4e641fa3b8ca848259212f51c8c59c45aa7ac18f14 \ - --hash=sha256:c6ec93afeb848a0845a18989da3beca3eec2c0f852322efe21af1931147d12cb \ - --hash=sha256:cacd81e2d4b6f89c9f8a5b69b86490152ff39afc58a95af002a398273e5ce589 \ - --hash=sha256:d402880b84df3bec6eab53cd0cf802cae6a2ef9537e70cf75e91618a3801c20c \ - --hash=sha256:d51fca85f9f7ad0b65b4b9fe800406d0d77017d7270d31ec3fb1cc07358fdea0 \ - --hash=sha256:d73f659398a0904e125280836ae6f88ba9b178b2fed6884f3b1f95b989d2c8da \ - --hash=sha256:d78ab28a03c854a873787a0a42254a0ccb3cb133c672f645c9f9c8f3ae9d0800 \ - --hash=sha256:da84c537cb8b97d153e9fb208c221c45605f73147bd4cadd23bdae915042aad6 \ - --hash=sha256:dbc4c3217eee163fa3984fd1567632b48d6dfd29216da3ded3d7b844a8014a66 \ - --hash=sha256:e12968fdfd5bb45ffdf6192a590bd8ddd3ba9e58360b29683c6bb71a7b41edca \ - --hash=sha256:e1fd23e9d01591bab45546c089ae89d926917a66dceb3abcf01f6105d927e2cb \ - --hash=sha256:e8134301d7e204c88ed7ab50028ba06c683000040ede1d617298611f9dc6240c \ - --hash=sha256:eb8b141bb00fa977d9122636b16aa67d37fd40a3d8b52dd837e536d64b9a4d06 \ - --hash=sha256:eca7e17a65f72a5133bdbec9ecf22401c62bcf4821361ef7811faee695799779 \ - --hash=sha256:f317576606de89da6b7e0861cf6061f6146ead3528acabff9236458a6ba467f8 \ - --hash=sha256:fd2a0fc506eccaaa7595b7e1418951f213cf8255be2600f1ea1b61e46a60c55f \ - --hash=sha256:fe41b41505a5a33aeaed2a613dccaeaa74e0e3ead6dd6fd3a118fb471644fd6c + # via -r build/requirements.in +contourpy==1.3.3 \ + --hash=sha256:023b44101dfe49d7d53932be418477dba359649246075c996866106da069af69 \ + --hash=sha256:07ce5ed73ecdc4a03ffe3e1b3e3c1166db35ae7584be76f65dbbe28a7791b0cc \ + --hash=sha256:083e12155b210502d0bca491432bb04d56dc3432f95a979b429f2848c3dbe880 \ + --hash=sha256:0bf67e0e3f482cb69779dd3061b534eb35ac9b17f163d851e2a547d56dba0a3a \ + --hash=sha256:0c1fc238306b35f246d61a1d416a627348b5cf0648648a031e14bb8705fcdfe8 \ + --hash=sha256:13b68d6a62db8eafaebb8039218921399baf6e47bf85006fd8529f2a08ef33fc \ + --hash=sha256:15ff10bfada4bf92ec8b31c62bf7c1834c244019b4a33095a68000d7075df470 \ + --hash=sha256:177fb367556747a686509d6fef71d221a4b198a3905fe824430e5ea0fda54eb5 \ + --hash=sha256:1cadd8b8969f060ba45ed7c1b714fe69185812ab43bd6b86a9123fe8f99c3263 \ + --hash=sha256:1fd43c3be4c8e5fd6e4f2baeae35ae18176cf2e5cced681cca908addf1cdd53b \ + --hash=sha256:22e9b1bd7a9b1d652cd77388465dc358dafcd2e217d35552424aa4f996f524f5 \ + --hash=sha256:23416f38bfd74d5d28ab8429cc4d63fa67d5068bd711a85edb1c3fb0c3e2f381 \ + --hash=sha256:283edd842a01e3dcd435b1c5116798d661378d83d36d337b8dde1d16a5fc9ba3 \ + --hash=sha256:2a2a8b627d5cc6b7c41a4beff6c5ad5eb848c88255fda4a8745f7e901b32d8e4 \ + --hash=sha256:2b7e9480ffe2b0cd2e787e4df64270e3a0440d9db8dc823312e2c940c167df7e \ + --hash=sha256:322ab1c99b008dad206d406bb61d014cf0174df491ae9d9d0fac6a6fda4f977f \ + --hash=sha256:33c82d0138c0a062380332c861387650c82e4cf1747aaa6938b9b6516762e772 \ + --hash=sha256:348ac1f5d4f1d66d3322420f01d42e43122f43616e0f194fc1c9f5d830c5b286 \ + --hash=sha256:3519428f6be58431c56581f1694ba8e50626f2dd550af225f82fb5f5814d2a42 \ + --hash=sha256:3c30273eb2a55024ff31ba7d052dde990d7d8e5450f4bbb6e913558b3d6c2301 \ + --hash=sha256:3d1a3799d62d45c18bafd41c5fa05120b96a28079f2393af559b843d1a966a77 \ + --hash=sha256:451e71b5a7d597379ef572de31eeb909a87246974d960049a9848c3bc6c41bf7 \ + --hash=sha256:459c1f020cd59fcfe6650180678a9993932d80d44ccde1fa1868977438f0b411 \ + --hash=sha256:4d00e655fcef08aba35ec9610536bfe90267d7ab5ba944f7032549c55a146da1 \ + --hash=sha256:4debd64f124ca62069f313a9cb86656ff087786016d76927ae2cf37846b006c9 \ + --hash=sha256:4feffb6537d64b84877da813a5c30f1422ea5739566abf0bd18065ac040e120a \ + --hash=sha256:50ed930df7289ff2a8d7afeb9603f8289e5704755c7e5c3bbd929c90c817164b \ + --hash=sha256:51e79c1f7470158e838808d4a996fa9bac72c498e93d8ebe5119bc1e6becb0db \ + --hash=sha256:556dba8fb6f5d8742f2923fe9457dbdd51e1049c4a43fd3986a0b14a1d815fc6 \ + --hash=sha256:598c3aaece21c503615fd59c92a3598b428b2f01bfb4b8ca9c4edeecc2438620 \ + --hash=sha256:5ed3657edf08512fc3fe81b510e35c2012fbd3081d2e26160f27ca28affec989 \ + --hash=sha256:626d60935cf668e70a5ce6ff184fd713e9683fb458898e4249b63be9e28286ea \ + --hash=sha256:644a6853d15b2512d67881586bd03f462c7ab755db95f16f14d7e238f2852c67 \ + --hash=sha256:655456777ff65c2c548b7c454af9c6f33f16c8884f11083244b5819cc214f1b5 \ + --hash=sha256:66c8a43a4f7b8df8b71ee1840e4211a3c8d93b214b213f590e18a1beca458f7d \ + --hash=sha256:6afc576f7b33cf00996e5c1102dc2a8f7cc89e39c0b55df93a0b78c1bd992b36 \ + --hash=sha256:6c3d53c796f8647d6deb1abe867daeb66dcc8a97e8455efa729516b997b8ed99 \ + --hash=sha256:709a48ef9a690e1343202916450bc48b9e51c049b089c7f79a267b46cffcdaa1 \ + --hash=sha256:70f9aad7de812d6541d29d2bbf8feb22ff7e1c299523db288004e3157ff4674e \ + --hash=sha256:8153b8bfc11e1e4d75bcb0bff1db232f9e10b274e0929de9d608027e0d34ff8b \ + --hash=sha256:87acf5963fc2b34825e5b6b048f40e3635dd547f590b04d2ab317c2619ef7ae8 \ + --hash=sha256:88df9880d507169449d434c293467418b9f6cbe82edd19284aa0409e7fdb933d \ + --hash=sha256:929ddf8c4c7f348e4c0a5a3a714b5c8542ffaa8c22954862a46ca1813b667ee7 \ + --hash=sha256:92d9abc807cf7d0e047b95ca5d957cf4792fcd04e920ca70d48add15c1a90ea7 \ + --hash=sha256:95b181891b4c71de4bb404c6621e7e2390745f887f2a026b2d99e92c17892339 \ + --hash=sha256:9e999574eddae35f1312c2b4b717b7885d4edd6cb46700e04f7f02db454e67c1 \ + --hash=sha256:a15459b0f4615b00bbd1e91f1b9e19b7e63aea7483d03d804186f278c0af2659 \ + --hash=sha256:a22738912262aa3e254e4f3cb079a95a67132fc5a063890e224393596902f5a4 \ + --hash=sha256:ab2fd90904c503739a75b7c8c5c01160130ba67944a7b77bbf36ef8054576e7f \ + --hash=sha256:ab3074b48c4e2cf1a960e6bbeb7f04566bf36b1861d5c9d4d8ac04b82e38ba20 \ + --hash=sha256:afe5a512f31ee6bd7d0dda52ec9864c984ca3d66664444f2d72e0dc4eb832e36 \ + --hash=sha256:b08a32ea2f8e42cf1d4be3169a98dd4be32bafe4f22b6c4cb4ba810fa9e5d2cb \ + --hash=sha256:b20c7c9a3bf701366556e1b1984ed2d0cedf999903c51311417cf5f591d8c78d \ + --hash=sha256:b2e8faa0ed68cb29af51edd8e24798bb661eac3bd9f65420c1887b6ca89987c8 \ + --hash=sha256:b7301b89040075c30e5768810bc96a8e8d78085b47d8be6e4c3f5a0b4ed478a0 \ + --hash=sha256:b7448cb5a725bb1e35ce88771b86fba35ef418952474492cf7c764059933ff8b \ + --hash=sha256:ca0fdcd73925568ca027e0b17ab07aad764be4706d0a925b89227e447d9737b7 \ + --hash=sha256:ca658cd1a680a5c9ea96dc61cdbae1e85c8f25849843aa799dfd3cb370ad4fbe \ + --hash=sha256:cbedb772ed74ff5be440fa8eee9bd49f64f6e3fc09436d9c7d8f1c287b121d77 \ + --hash=sha256:cd5dfcaeb10f7b7f9dc8941717c6c2ade08f587be2226222c12b25f0483ed497 \ + --hash=sha256:cf9022ef053f2694e31d630feaacb21ea24224be1c3ad0520b13d844274614fd \ + --hash=sha256:d002b6f00d73d69333dac9d0b8d5e84d9724ff9ef044fd63c5986e62b7c9e1b1 \ + --hash=sha256:d06bb1f751ba5d417047db62bca3c8fde202b8c11fb50742ab3ab962c81e8216 \ + --hash=sha256:d304906ecc71672e9c89e87c4675dc5c2645e1f4269a5063b99b0bb29f232d13 \ + --hash=sha256:e4e6b05a45525357e382909a4c1600444e2a45b4795163d3b22669285591c1ae \ + --hash=sha256:e74a9a0f5e3fff48fb5a7f2fd2b9b70a3fe014a67522f79b7cca4c0c7e43c9ae \ + --hash=sha256:ea37e7b45949df430fe649e5de8351c423430046a2af20b1c1961cae3afcda77 \ + --hash=sha256:f64836de09927cba6f79dcd00fdd7d5329f3fccc633468507079c829ca4db4e3 \ + --hash=sha256:fd6ec6be509c787f1caf6b247f0b1ca598bef13f4ddeaa126b7658215529ba0f \ + --hash=sha256:fd907ae12cd483cd83e414b12941c632a969171bf90fc937d0c9f268a31cafff \ + --hash=sha256:fd914713266421b7536de2bfa8181aa8c699432b6763a0ea64195ebe28bff6a9 \ + --hash=sha256:fde6c716d51c04b1c25d0b90364d0be954624a0ee9d60e23e850e8d48353d07a # via matplotlib cycler==0.12.1 \ --hash=sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30 \ --hash=sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c # via matplotlib -etils[epath,epy]==1.9.4 \ - --hash=sha256:4387e7a4911a3b5cc4b92b99a9211386d176b43bae1dac8e2fe345fc2cb95e4b \ - --hash=sha256:fad950414f0a1ca58c70c70915b0014f9953dd9bcf8aa951a0f75ff9becbeb24 +etils[epath,epy]==1.13.0 \ + --hash=sha256:a5b60c71f95bcd2d43d4e9fb3dc3879120c1f60472bb5ce19f7a860b1d44f607 \ + --hash=sha256:d9cd4f40fbe77ad6613b7348a18132cc511237b6c076dbb89105c0b520a4c6bb # via -r build/requirements.in -execnet==2.1.1 \ - --hash=sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc \ - --hash=sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3 +execnet==2.1.2 \ + --hash=sha256:63d83bfdd9a23e35b9c6a3261412324f964c2ec8dcd8d3c6916ee9373e0befcd \ + --hash=sha256:67fba928dd5a544b783f6056f449e5e3931a5c378b128bc18501f7ea79e296ec # via pytest-xdist -filelock==3.16.1 \ - --hash=sha256:2082e5703d51fbf98ea75855d9d5527e33d8ff23099bec374a134febee6946b0 \ - --hash=sha256:c249fbfcd5db47e5e2d6d62198e565475ee65e4831e2561c8e313fa7eb961435 +filelock==3.20.1 \ + --hash=sha256:15d9e9a67306188a44baa72f569d2bfd803076269365fdea0934385da4dc361a \ + --hash=sha256:b8360948b351b80f420878d8516519a2204b07aefcdcfd24912a5d33127f188c # via -r build/test-requirements.txt -flatbuffers==24.3.25 \ - --hash=sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812 \ - --hash=sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4 - # via -r build/test-requirements.txt -fonttools==4.54.1 \ - --hash=sha256:07e005dc454eee1cc60105d6a29593459a06321c21897f769a281ff2d08939f6 \ - --hash=sha256:0a911591200114969befa7f2cb74ac148bce5a91df5645443371aba6d222e263 \ - --hash=sha256:0d1d353ef198c422515a3e974a1e8d5b304cd54a4c2eebcae708e37cd9eeffb1 \ - --hash=sha256:0e88e3018ac809b9662615072dcd6b84dca4c2d991c6d66e1970a112503bba7e \ - --hash=sha256:1d152d1be65652fc65e695e5619e0aa0982295a95a9b29b52b85775243c06556 \ - --hash=sha256:262705b1663f18c04250bd1242b0515d3bbae177bee7752be67c979b7d47f43d \ - --hash=sha256:278913a168f90d53378c20c23b80f4e599dca62fbffae4cc620c8eed476b723e \ - --hash=sha256:301540e89cf4ce89d462eb23a89464fef50915255ece765d10eee8b2bf9d75b2 \ - --hash=sha256:31c32d7d4b0958600eac75eaf524b7b7cb68d3a8c196635252b7a2c30d80e986 \ - --hash=sha256:357cacb988a18aace66e5e55fe1247f2ee706e01debc4b1a20d77400354cddeb \ - --hash=sha256:37cddd62d83dc4f72f7c3f3c2bcf2697e89a30efb152079896544a93907733bd \ - --hash=sha256:41bb0b250c8132b2fcac148e2e9198e62ff06f3cc472065dff839327945c5882 \ - --hash=sha256:4aa4817f0031206e637d1e685251ac61be64d1adef111060df84fdcbc6ab6c44 \ - --hash=sha256:4e10d2e0a12e18f4e2dd031e1bf7c3d7017be5c8dbe524d07706179f355c5dac \ - --hash=sha256:5419771b64248484299fa77689d4f3aeed643ea6630b2ea750eeab219588ba20 \ - --hash=sha256:54471032f7cb5fca694b5f1a0aaeba4af6e10ae989df408e0216f7fd6cdc405d \ - --hash=sha256:58974b4987b2a71ee08ade1e7f47f410c367cdfc5a94fabd599c88165f56213a \ - --hash=sha256:58d29b9a294573d8319f16f2f79e42428ba9b6480442fa1836e4eb89c4d9d61c \ - --hash=sha256:5eb2474a7c5be8a5331146758debb2669bf5635c021aee00fd7c353558fc659d \ - --hash=sha256:6e37561751b017cf5c40fce0d90fd9e8274716de327ec4ffb0df957160be3bff \ - --hash=sha256:76ae5091547e74e7efecc3cbf8e75200bc92daaeb88e5433c5e3e95ea8ce5aa7 \ - --hash=sha256:7965af9b67dd546e52afcf2e38641b5be956d68c425bef2158e95af11d229f10 \ - --hash=sha256:7e3b7d44e18c085fd8c16dcc6f1ad6c61b71ff463636fcb13df7b1b818bd0c02 \ - --hash=sha256:7ed7ee041ff7b34cc62f07545e55e1468808691dddfd315d51dd82a6b37ddef2 \ - --hash=sha256:82834962b3d7c5ca98cb56001c33cf20eb110ecf442725dc5fdf36d16ed1ab07 \ - --hash=sha256:8583e563df41fdecef31b793b4dd3af8a9caa03397be648945ad32717a92885b \ - --hash=sha256:8fa92cb248e573daab8d032919623cc309c005086d743afb014c836636166f08 \ - --hash=sha256:93d458c8a6a354dc8b48fc78d66d2a8a90b941f7fec30e94c7ad9982b1fa6bab \ - --hash=sha256:957f669d4922f92c171ba01bef7f29410668db09f6c02111e22b2bce446f3285 \ - --hash=sha256:9dc080e5a1c3b2656caff2ac2633d009b3a9ff7b5e93d0452f40cd76d3da3b3c \ - --hash=sha256:9ef1b167e22709b46bf8168368b7b5d3efeaaa746c6d39661c1b4405b6352e58 \ - --hash=sha256:a7a310c6e0471602fe3bf8efaf193d396ea561486aeaa7adc1f132e02d30c4b9 \ - --hash=sha256:ab774fa225238986218a463f3fe151e04d8c25d7de09df7f0f5fce27b1243dbc \ - --hash=sha256:ada215fd079e23e060157aab12eba0d66704316547f334eee9ff26f8c0d7b8ab \ - --hash=sha256:c39287f5c8f4a0c5a55daf9eaf9ccd223ea59eed3f6d467133cc727d7b943a55 \ - --hash=sha256:c9c563351ddc230725c4bdf7d9e1e92cbe6ae8553942bd1fb2b2ff0884e8b714 \ - --hash=sha256:d26732ae002cc3d2ecab04897bb02ae3f11f06dd7575d1df46acd2f7c012a8d8 \ - --hash=sha256:d3b659d1029946f4ff9b6183984578041b520ce0f8fb7078bb37ec7445806b33 \ - --hash=sha256:dd9cc95b8d6e27d01e1e1f1fae8559ef3c02c76317da650a19047f249acd519d \ - --hash=sha256:e4564cf40cebcb53f3dc825e85910bf54835e8a8b6880d59e5159f0f325e637e \ - --hash=sha256:e7d82b9e56716ed32574ee106cabca80992e6bbdcf25a88d97d21f73a0aae664 \ - --hash=sha256:e8a4b261c1ef91e7188a30571be6ad98d1c6d9fa2427244c545e2fa0a2494dd7 \ - --hash=sha256:e96bc94c8cda58f577277d4a71f51c8e2129b8b36fd05adece6320dd3d57de8a \ - --hash=sha256:ed2f80ca07025551636c555dec2b755dd005e2ea8fbeb99fc5cdff319b70b23b \ - --hash=sha256:f5b8a096e649768c2f4233f947cf9737f8dbf8728b90e2771e2497c6e3d21d13 \ - --hash=sha256:f8e953cc0bddc2beaf3a3c3b5dd9ab7554677da72dfaf46951e193c9653e515a \ - --hash=sha256:fda582236fee135d4daeca056c8c88ec5f6f6d88a004a79b84a02547c8f57386 \ - --hash=sha256:fdb062893fd6d47b527d39346e0c5578b7957dcea6d6a3b6794569370013d9ac +flatbuffers==25.9.23 \ + --hash=sha256:255538574d6cb6d0a79a17ec8bc0d30985913b87513a01cce8bcdb6b4c44d0e2 \ + --hash=sha256:676f9fa62750bb50cf531b42a0a2a118ad8f7f797a511eda12881c016f093b12 + # via + # -r build/test-requirements.txt + # tensorflow +fonttools==4.61.1 \ + --hash=sha256:0de30bfe7745c0d1ffa2b0b7048fb7123ad0d71107e10ee090fa0b16b9452e87 \ + --hash=sha256:10d88e55330e092940584774ee5e8a6971b01fc2f4d3466a1d6c158230880796 \ + --hash=sha256:11f35ad7805edba3aac1a3710d104592df59f4b957e30108ae0ba6c10b11dd75 \ + --hash=sha256:15acc09befd16a0fb8a8f62bc147e1a82817542d72184acca9ce6e0aeda9fa6d \ + --hash=sha256:17d2bf5d541add43822bcf0c43d7d847b160c9bb01d15d5007d84e2217aaa371 \ + --hash=sha256:2180f14c141d2f0f3da43f3a81bc8aa4684860f6b0e6f9e165a4831f24e6a23b \ + --hash=sha256:21e7c8d76f62ab13c9472ccf74515ca5b9a761d1bde3265152a6dc58700d895b \ + --hash=sha256:41a7170d042e8c0024703ed13b71893519a1a6d6e18e933e3ec7507a2c26a4b2 \ + --hash=sha256:41ed4b5ec103bd306bb68f81dc166e77409e5209443e5773cb4ed837bcc9b0d3 \ + --hash=sha256:497c31ce314219888c0e2fce5ad9178ca83fe5230b01a5006726cdf3ac9f24d9 \ + --hash=sha256:4c1b526c8d3f615a7b1867f38a9410849c8f4aef078535742198e942fba0e9bd \ + --hash=sha256:4d7092bb38c53bbc78e9255a59158b150bcdc115a1e3b3ce0b5f267dc35dd63c \ + --hash=sha256:4f5686e1fe5fce75d82d93c47a438a25bf0d1319d2843a926f741140b2b16e0c \ + --hash=sha256:58b0ee0ab5b1fc9921eccfe11d1435added19d6494dde14e323f25ad2bc30c56 \ + --hash=sha256:5ce02f38a754f207f2f06557523cd39a06438ba3aafc0639c477ac409fc64e37 \ + --hash=sha256:5fade934607a523614726119164ff621e8c30e8fa1ffffbbd358662056ba69f0 \ + --hash=sha256:5fe9fd43882620017add5eabb781ebfbc6998ee49b35bd7f8f79af1f9f99a958 \ + --hash=sha256:64102ca87e84261419c3747a0d20f396eb024bdbeb04c2bfb37e2891f5fadcb5 \ + --hash=sha256:664c5a68ec406f6b1547946683008576ef8b38275608e1cee6c061828171c118 \ + --hash=sha256:6675329885c44657f826ef01d9e4fb33b9158e9d93c537d84ad8399539bc6f69 \ + --hash=sha256:75c1a6dfac6abd407634420c93864a1e274ebc1c7531346d9254c0d8f6ca00f9 \ + --hash=sha256:75da8f28eff26defba42c52986de97b22106cb8f26515b7c22443ebc9c2d3261 \ + --hash=sha256:77efb033d8d7ff233385f30c62c7c79271c8885d5c9657d967ede124671bbdfb \ + --hash=sha256:78a7d3ab09dc47ac1a363a493e6112d8cabed7ba7caad5f54dbe2f08676d1b47 \ + --hash=sha256:7c7db70d57e5e1089a274cbb2b1fd635c9a24de809a231b154965d415d6c6d24 \ + --hash=sha256:8c56c488ab471628ff3bfa80964372fc13504ece601e0d97a78ee74126b2045c \ + --hash=sha256:91669ccac46bbc1d09e9273546181919064e8df73488ea087dcac3e2968df9ba \ + --hash=sha256:9b666a475a65f4e839d3d10473fad6d47e0a9db14a2f4a224029c5bfde58ad2c \ + --hash=sha256:9cfef3ab326780c04d6646f68d4b4742aae222e8b8ea1d627c74e38afcbc9d91 \ + --hash=sha256:a13fc8aeb24bad755eea8f7f9d409438eb94e82cf86b08fe77a03fbc8f6a96b1 \ + --hash=sha256:a75c301f96db737e1c5ed5fd7d77d9c34466de16095a266509e13da09751bd19 \ + --hash=sha256:a76d4cb80f41ba94a6691264be76435e5f72f2cb3cab0b092a6212855f71c2f6 \ + --hash=sha256:aed04cabe26f30c1647ef0e8fbb207516fd40fe9472e9439695f5c6998e60ac5 \ + --hash=sha256:b148b56f5de675ee16d45e769e69f87623a4944f7443850bf9a9376e628a89d2 \ + --hash=sha256:b501c862d4901792adaec7c25b1ecc749e2662543f68bb194c42ba18d6eec98d \ + --hash=sha256:b846a1fcf8beadeb9ea4f44ec5bdde393e2f1569e17d700bfc49cd69bde75881 \ + --hash=sha256:b931ae8f62db78861b0ff1ac017851764602288575d65b8e8ff1963fed419063 \ + --hash=sha256:c33ab3ca9d3ccd581d58e989d67554e42d8d4ded94ab3ade3508455fe70e65f7 \ + --hash=sha256:c6604b735bb12fef8e0efd5578c9fb5d3d8532d5001ea13a19cddf295673ee09 \ + --hash=sha256:d8db08051fc9e7d8bc622f2112511b8107d8f27cd89e2f64ec45e9825e8288da \ + --hash=sha256:d9203500f7c63545b4ce3799319fe4d9feb1a1b89b28d3cb5abd11b9dd64147e \ + --hash=sha256:dc492779501fa723b04d0ab1f5be046797fee17d27700476edc7ee9ae535a61e \ + --hash=sha256:e6bcdf33aec38d16508ce61fd81838f24c83c90a1d1b8c68982857038673d6b8 \ + --hash=sha256:e76ce097e3c57c4bcb67c5aa24a0ecdbd9f74ea9219997a707a4061fbe2707aa \ + --hash=sha256:eff1ac3cc66c2ac7cda1e64b4e2f3ffef474b7335f92fc3833fc632d595fcee6 \ + --hash=sha256:f3cb4a569029b9f291f88aafc927dd53683757e640081ca8c412781ea144565e \ + --hash=sha256:f79b168428351d11e10c5aeb61a74e1851ec221081299f4cf56036a95431c43a \ + --hash=sha256:fa646ecec9528bef693415c79a86e733c70a4965dd938e9a226b0fc64c9d2e6c \ + --hash=sha256:fe2efccb324948a11dd09d22136fe2ac8a97d6c1347cf0b58a911dcd529f66b7 \ + --hash=sha256:fff4f534200a04b4a36e7ae3cb74493afe807b517a09e99cb4faa89a34ed6ecd # via matplotlib -fsspec==2024.9.0 \ - --hash=sha256:4b0afb90c2f21832df142f292649035d80b421f60a9e1c027802e5a0da2b04e8 \ - --hash=sha256:a0947d552d8a6efa72cc2c730b12c41d043509156966cca4fb157b0f2a0c574b +fsspec==2025.10.0 \ + --hash=sha256:7c7712353ae7d875407f97715f0e1ffcc21e33d5b24556cb1e090ae9409ec61d \ + --hash=sha256:b6789427626f068f9a83ca4e8a3cc050850b6c0f71f99ddb4f542b8266a26a59 # via etils -hypothesis==6.112.5 \ - --hash=sha256:82fbd28a92c4d88743740e3ec05415ea25119d825d1fdac9ab7bf717fe56297b \ - --hash=sha256:e6b7c8ba1126e07cfbf76b8bb544cedd89cb7f7bcf6c315bd759cd2efc2063ff +gast==0.6.0 \ + --hash=sha256:52b182313f7330389f72b069ba00f174cfe2a06411099547288839c6cbafbd54 \ + --hash=sha256:88fc5300d32c7ac6ca7b515310862f71e6fdf2c029bbec7c66c0f5dd47b6b1fb + # via tensorflow +google-pasta==0.2.0 \ + --hash=sha256:4612951da876b1a10fe3960d7226f0c7682cf901e16ac06e473b267a5afa8954 \ + --hash=sha256:b32482794a366b5366a32c92a9a9201b107821889935a02b3e51f6b432ea84ed \ + --hash=sha256:c9f2c8dfc8f96d0d5808299920721be30c9eec37f2389f28904f454565c8a16e + # via tensorflow +grpcio==1.76.0 \ + --hash=sha256:035d90bc79eaa4bed83f524331d55e35820725c9fbb00ffa1904d5550ed7ede3 \ + --hash=sha256:04bbe1bfe3a68bbfd4e52402ab7d4eb59d72d02647ae2042204326cf4bbad280 \ + --hash=sha256:063065249d9e7e0782d03d2bca50787f53bd0fb89a67de9a7b521c4a01f1989b \ + --hash=sha256:06c3d6b076e7b593905d04fdba6a0525711b3466f43b3400266f04ff735de0cd \ + --hash=sha256:08caea849a9d3c71a542827d6df9d5a69067b0a1efbea8a855633ff5d9571465 \ + --hash=sha256:0aaa82d0813fd4c8e589fac9b65d7dd88702555f702fb10417f96e2a2a6d4c0f \ + --hash=sha256:0b7604868b38c1bfd5cf72d768aedd7db41d78cb6a4a18585e33fb0f9f2363fd \ + --hash=sha256:0c37db8606c258e2ee0c56b78c62fc9dee0e901b5dbdcf816c2dd4ad652b8b0c \ + --hash=sha256:1c9b93f79f48b03ada57ea24725d83a30284a012ec27eab2cf7e50a550cbbbcc \ + --hash=sha256:2107b0c024d1b35f4083f11245c0e23846ae64d02f40b2b226684840260ed054 \ + --hash=sha256:2229ae655ec4e8999599469559e97630185fdd53ae1e8997d147b7c9b2b72cba \ + --hash=sha256:25a18e9810fbc7e7f03ec2516addc116a957f8cbb8cbc95ccc80faa072743d03 \ + --hash=sha256:26ef06c73eb53267c2b319f43e6634c7556ea37672029241a056629af27c10e2 \ + --hash=sha256:2e1743fbd7f5fa713a1b0a8ac8ebabf0ec980b5d8809ec358d488e273b9cf02a \ + --hash=sha256:32483fe2aab2c3794101c2a159070584e5db11d0aa091b2c0ea9c4fc43d0d749 \ + --hash=sha256:3bf0f392c0b806905ed174dcd8bdd5e418a40d5567a05615a030a5aeddea692d \ + --hash=sha256:3e2a27c89eb9ac3d81ec8835e12414d73536c6e620355d65102503064a4ed6eb \ + --hash=sha256:40ad3afe81676fd9ec6d9d406eda00933f218038433980aa19d401490e46ecde \ + --hash=sha256:4215d3a102bd95e2e11b5395c78562967959824156af11fa93d18fdd18050990 \ + --hash=sha256:45d59a649a82df5718fd9527ce775fd66d1af35e6d31abdcdc906a49c6822958 \ + --hash=sha256:45e0111e73f43f735d70786557dc38141185072d7ff8dc1829d6a77ac1471468 \ + --hash=sha256:479496325ce554792dba6548fae3df31a72cef7bad71ca2e12b0e58f9b336bfc \ + --hash=sha256:490fa6d203992c47c7b9e4a9d39003a0c2bcc1c9aa3c058730884bbbb0ee9f09 \ + --hash=sha256:49ce47231818806067aea3324d4bf13825b658ad662d3b25fada0bdad9b8a6af \ + --hash=sha256:4baf3cbe2f0be3289eb68ac8ae771156971848bb8aaff60bad42005539431980 \ + --hash=sha256:522175aba7af9113c48ec10cc471b9b9bd4f6ceb36aeb4544a8e2c80ed9d252d \ + --hash=sha256:5e8571632780e08526f118f74170ad8d50fb0a48c23a746bef2a6ebade3abd6f \ + --hash=sha256:615ba64c208aaceb5ec83bfdce7728b80bfeb8be97562944836a7a0a9647d882 \ + --hash=sha256:61f69297cba3950a524f61c7c8ee12e55c486cb5f7db47ff9dcee33da6f0d3ae \ + --hash=sha256:65a20de41e85648e00305c1bb09a3598f840422e522277641145a32d42dcefcc \ + --hash=sha256:6a15c17af8839b6801d554263c546c69c4d7718ad4321e3166175b37eaacca77 \ + --hash=sha256:747fa73efa9b8b1488a95d0ba1039c8e2dca0f741612d80415b1e1c560febf4e \ + --hash=sha256:7be78388d6da1a25c0d5ec506523db58b18be22d9c37d8d3a32c08be4987bd73 \ + --hash=sha256:81fd9652b37b36f16138611c7e884eb82e0cec137c40d3ef7c3f9b3ed00f6ed8 \ + --hash=sha256:83d57312a58dcfe2a3a0f9d1389b299438909a02db60e2f2ea2ae2d8034909d3 \ + --hash=sha256:8843114c0cfce61b40ad48df65abcfc00d4dba82eae8718fab5352390848c5da \ + --hash=sha256:8cc3309d8e08fd79089e13ed4819d0af72aa935dd8f435a195fd152796752ff2 \ + --hash=sha256:8ebe63ee5f8fa4296b1b8cfc743f870d10e902ca18afc65c68cf46fd39bb0783 \ + --hash=sha256:8eddfb4d203a237da6f3cc8a540dad0517d274b5a1e9e636fd8d2c79b5c1d397 \ + --hash=sha256:922fa70ba549fce362d2e2871ab542082d66e2aaf0c19480ea453905b01f384e \ + --hash=sha256:931091142fd8cc14edccc0845a79248bc155425eee9a98b2db2ea4f00a235a42 \ + --hash=sha256:971fd5a1d6e62e00d945423a567e42eb1fa678ba89072832185ca836a94daaa6 \ + --hash=sha256:980a846182ce88c4f2f7e2c22c56aefd515daeb36149d1c897f83cf57999e0b6 \ + --hash=sha256:9d9adda641db7207e800a7f089068f6f645959f2df27e870ee81d44701dd9db3 \ + --hash=sha256:9f8f757bebaaea112c00dba718fc0d3260052ce714e25804a03f93f5d1c6cc11 \ + --hash=sha256:a6ae758eb08088d36812dd5d9af7a9859c05b1e0f714470ea243694b49278e7b \ + --hash=sha256:a8c2cf1209497cf659a667d7dea88985e834c24b7c3b605e6254cbb5076d985c \ + --hash=sha256:acab0277c40eff7143c2323190ea57b9ee5fd353d8190ee9652369fae735668a \ + --hash=sha256:b331680e46239e090f5b3cead313cc772f6caa7d0fc8de349337563125361a4a \ + --hash=sha256:c088e7a90b6017307f423efbb9d1ba97a22aa2170876223f9709e9d1de0b5347 \ + --hash=sha256:d099566accf23d21037f18a2a63d323075bebace807742e4b0ac210971d4dd70 \ + --hash=sha256:d388087771c837cdb6515539f43b9d4bf0b0f23593a24054ac16f7a960be16f4 \ + --hash=sha256:dcfe41187da8992c5f40aa8c5ec086fa3672834d2be57a32384c08d5a05b4c00 \ + --hash=sha256:e6d1db20594d9daba22f90da738b1a0441a7427552cc6e2e3d1297aeddc00378 \ + --hash=sha256:ebea5cc3aa8ea72e04df9913492f9a96d9348db876f9dda3ad729cfedf7ac416 \ + --hash=sha256:ebebf83299b0cb1721a8859ea98f3a77811e35dce7609c5c963b9ad90728f886 \ + --hash=sha256:f0e34c2079d47ae9f6188211db9e777c619a21d4faba6977774e8fa43b085e48 \ + --hash=sha256:f92f88e6c033db65a5ae3d97905c8fea9c725b63e28d5a75cb73b49bda5024d8 \ + --hash=sha256:f9f7bd5faab55f47231ad8dba7787866b69f5e93bc306e3915606779bbfb4ba8 \ + --hash=sha256:fd5ef5932f6475c436c4a55e4336ebbe47bd3272be04964a03d316bbf4afbcbc \ + --hash=sha256:ff8a59ea85a1f2191a0ffcc61298c571bc566332f82e5f5be1b83c9d8e668a62 + # via + # tensorboard + # tensorflow +h5py==3.15.1 \ + --hash=sha256:01f55111ca516f5568ae7a7fc8247dfce607de331b4467ee8a9a6ed14e5422c7 \ + --hash=sha256:0e2f471688402c3404fa4e13466e373e622fd4b74b47b56cfdff7cc688209422 \ + --hash=sha256:121b2b7a4c1915d63737483b7bff14ef253020f617c2fb2811f67a4bed9ac5e8 \ + --hash=sha256:25c8843fec43b2cc368aa15afa1cdf83fc5e17b1c4e10cd3771ef6c39b72e5ce \ + --hash=sha256:28a20e1a4082a479b3d7db2169f3a5034af010b90842e75ebbf2e9e49eb4183e \ + --hash=sha256:2cbc4104d3d4aca9d6db8c0c694555e255805bfeacf9eb1349bda871e26cacbe \ + --hash=sha256:316dd0f119734f324ca7ed10b5627a2de4ea42cc4dfbcedbee026aaa361c238c \ + --hash=sha256:4411c1867b9899a25e983fff56d820a66f52ac326bbe10c7cdf7d832c9dcd883 \ + --hash=sha256:4c45802bcb711e128a6839cb6c01e9ac648dc55df045c9542a675c771f15c8d5 \ + --hash=sha256:550e51131376889656feec4aff2170efc054a7fe79eb1da3bb92e1625d1ac878 \ + --hash=sha256:59b0d63b318bf3cc06687def2b45afd75926bbc006f7b8cd2b1a231299fc8599 \ + --hash=sha256:59b25cf02411bf12e14f803fef0b80886444c7fe21a5ad17c6a28d3f08098a1e \ + --hash=sha256:5aaa330bcbf2830150c50897ea5dcbed30b5b6d56897289846ac5b9e529ec243 \ + --hash=sha256:5b849ba619a066196169763c33f9f0f02e381156d61c03e000bb0100f9950faf \ + --hash=sha256:5f4fb0567eb8517c3ecd6b3c02c4f4e9da220c8932604960fd04e24ee1254763 \ + --hash=sha256:61d5a58a9851e01ee61c932bbbb1c98fe20aba0a5674776600fb9a361c0aa652 \ + --hash=sha256:64ce3f6470adb87c06e3a8dd1b90e973699f1759ad79bfa70c230939bff356c9 \ + --hash=sha256:67e59f6c2f19a32973a40f43d9a088ae324fe228c8366e25ebc57ceebf093a6b \ + --hash=sha256:80e5bb5b9508d5d9da09f81fd00abbb3f85da8143e56b1585d59bc8ceb1dba8b \ + --hash=sha256:8a33bfd5dfcea037196f7778534b1ff7e36a7f40a89e648c8f2967292eb6898e \ + --hash=sha256:954e480433e82d3872503104f9b285d369048c3a788b2b1a00e53d1c47c98dd2 \ + --hash=sha256:99d374a21f7321a4c6ab327c4ab23bd925ad69821aeb53a1e75dd809d19f67fa \ + --hash=sha256:9c73d1d7cdb97d5b17ae385153472ce118bed607e43be11e9a9deefaa54e0734 \ + --hash=sha256:a308fd8681a864c04423c0324527237a0484e2611e3441f8089fd00ed56a8171 \ + --hash=sha256:a6d8c5a05a76aca9a494b4c53ce8a9c29023b7f64f625c6ce1841e92a362ccdf \ + --hash=sha256:ab2219dbc6fcdb6932f76b548e2b16f34a1f52b7666e998157a4dfc02e2c4123 \ + --hash=sha256:b39239947cb36a819147fc19e86b618dcb0953d1cd969f5ed71fc0de60392427 \ + --hash=sha256:b51469890e58e85d5242e43aab29f5e9c7e526b951caab354f3ded4ac88e7b76 \ + --hash=sha256:c256254a8a81e2bddc0d376e23e2a6d2dc8a1e8a2261835ed8c1281a0744cd97 \ + --hash=sha256:c8440fd8bee9500c235ecb7aa1917a0389a2adb80c209fa1cc485bd70e0d94a5 \ + --hash=sha256:c86e3ed45c4473564de55aa83b6fc9e5ead86578773dfbd93047380042e26b69 \ + --hash=sha256:c970fb80001fffabb0109eaf95116c8e7c0d3ca2de854e0901e8a04c1f098509 \ + --hash=sha256:ca8a3a22458956ee7b40d8e39c9a9dc01f82933e4c030c964f8b875592f4d831 \ + --hash=sha256:d8cb02c3a96255149ed3ac811eeea25b655d959c6dd5ce702c9a95ff11859eb5 \ + --hash=sha256:dea78b092fd80a083563ed79a3171258d4a4d307492e7cf8b2313d464c82ba52 \ + --hash=sha256:e02fe77a03f652500d8bff288cbf3675f742fc0411f5a628fa37116507dc7cc0 \ + --hash=sha256:e7f6c841efd4e6e5b7e82222eaf90819927b6d256ab0f3aca29675601f654f3c \ + --hash=sha256:f4a016df3f4a8a14d573b496e4d1964deb380e26031fc85fb40e417e9131888a \ + --hash=sha256:fa8df5267f545b4946df8ca0d93d23382191018e4cda2deda4c2cedf9a010e13 \ + --hash=sha256:fd125c131889ebbef0849f4a0e29cf363b48aba42f228d08b4079913b576bb3a + # via + # keras + # tensorflow +hypothesis==6.142.1 \ + --hash=sha256:3179cb08756562c526aaf4a9871ebbff83d2d75c03896ed0bc9c1d14097a930c \ + --hash=sha256:95a7d38fcc58e697e3020665adcb951c630cdbc8065e4b4474949e486b06bd6d # via -r build/test-requirements.txt -importlib-resources==6.4.5 \ - --hash=sha256:980862a1d16c9e147a59603677fa2aa5fd82b87f223b6cb870695bcfce830065 \ - --hash=sha256:ac29d5f956f01d5e4bb63102a5a19957f1b9175e45649977264a1416783bb717 +idna==3.11 \ + --hash=sha256:771a87f49d9defaf64091e6e6fe9c18d4833f140bd19464795bc32d966ca37ea \ + --hash=sha256:795dafcc9c04ed0c1fb032c2aa73654d8e8c5023a7df64a53f39190ada629902 + # via requests +importlib-resources==6.5.2 \ + --hash=sha256:185f87adef5bcc288449d98fb4fba07cea78bc036455dd44c5fc4a2fe78fed2c \ + --hash=sha256:789cfdc3ed28c78b67a06acb8126751ced69a3d5f79c095a98298cd8a760ccec # via etils -iniconfig==2.0.0 \ - --hash=sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3 \ - --hash=sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374 +iniconfig==2.3.0 \ + --hash=sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730 \ + --hash=sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12 # via pytest -kiwisolver==1.4.7 \ - --hash=sha256:073a36c8273647592ea332e816e75ef8da5c303236ec0167196793eb1e34657a \ - --hash=sha256:08471d4d86cbaec61f86b217dd938a83d85e03785f51121e791a6e6689a3be95 \ - --hash=sha256:0c18ec74c0472de033e1bebb2911c3c310eef5649133dd0bedf2a169a1b269e5 \ - --hash=sha256:0c6c43471bc764fad4bc99c5c2d6d16a676b1abf844ca7c8702bdae92df01ee0 \ - --hash=sha256:10849fb2c1ecbfae45a693c070e0320a91b35dd4bcf58172c023b994283a124d \ - --hash=sha256:18077b53dc3bb490e330669a99920c5e6a496889ae8c63b58fbc57c3d7f33a18 \ - --hash=sha256:18e0cca3e008e17fe9b164b55735a325140a5a35faad8de92dd80265cd5eb80b \ - --hash=sha256:22f499f6157236c19f4bbbd472fa55b063db77a16cd74d49afe28992dff8c258 \ - --hash=sha256:2a8781ac3edc42ea4b90bc23e7d37b665d89423818e26eb6df90698aa2287c95 \ - --hash=sha256:2e6039dcbe79a8e0f044f1c39db1986a1b8071051efba3ee4d74f5b365f5226e \ - --hash=sha256:34ea1de54beef1c104422d210c47c7d2a4999bdecf42c7b5718fbe59a4cac383 \ - --hash=sha256:3ab58c12a2cd0fc769089e6d38466c46d7f76aced0a1f54c77652446733d2d02 \ - --hash=sha256:3abc5b19d24af4b77d1598a585b8a719beb8569a71568b66f4ebe1fb0449460b \ - --hash=sha256:3bf1ed55088f214ba6427484c59553123fdd9b218a42bbc8c6496d6754b1e523 \ - --hash=sha256:3ce6b2b0231bda412463e152fc18335ba32faf4e8c23a754ad50ffa70e4091ee \ - --hash=sha256:3da53da805b71e41053dc670f9a820d1157aae77b6b944e08024d17bcd51ef88 \ - --hash=sha256:3f9362ecfca44c863569d3d3c033dbe8ba452ff8eed6f6b5806382741a1334bd \ - --hash=sha256:409afdfe1e2e90e6ee7fc896f3df9a7fec8e793e58bfa0d052c8a82f99c37abb \ - --hash=sha256:40fa14dbd66b8b8f470d5fc79c089a66185619d31645f9b0773b88b19f7223c4 \ - --hash=sha256:4322872d5772cae7369f8351da1edf255a604ea7087fe295411397d0cfd9655e \ - --hash=sha256:44756f9fd339de0fb6ee4f8c1696cfd19b2422e0d70b4cefc1cc7f1f64045a8c \ - --hash=sha256:46707a10836894b559e04b0fd143e343945c97fd170d69a2d26d640b4e297935 \ - --hash=sha256:48b571ecd8bae15702e4f22d3ff6a0f13e54d3d00cd25216d5e7f658242065ee \ - --hash=sha256:48be928f59a1f5c8207154f935334d374e79f2b5d212826307d072595ad76a2e \ - --hash=sha256:4bfa75a048c056a411f9705856abfc872558e33c055d80af6a380e3658766038 \ - --hash=sha256:4c00336b9dd5ad96d0a558fd18a8b6f711b7449acce4c157e7343ba92dd0cf3d \ - --hash=sha256:4c26ed10c4f6fa6ddb329a5120ba3b6db349ca192ae211e882970bfc9d91420b \ - --hash=sha256:4d05d81ecb47d11e7f8932bd8b61b720bf0b41199358f3f5e36d38e28f0532c5 \ - --hash=sha256:4e77f2126c3e0b0d055f44513ed349038ac180371ed9b52fe96a32aa071a5107 \ - --hash=sha256:5337ec7809bcd0f424c6b705ecf97941c46279cf5ed92311782c7c9c2026f07f \ - --hash=sha256:5360cc32706dab3931f738d3079652d20982511f7c0ac5711483e6eab08efff2 \ - --hash=sha256:58370b1ffbd35407444d57057b57da5d6549d2d854fa30249771775c63b5fe17 \ - --hash=sha256:58cb20602b18f86f83a5c87d3ee1c766a79c0d452f8def86d925e6c60fbf7bfb \ - --hash=sha256:599b5c873c63a1f6ed7eead644a8a380cfbdf5db91dcb6f85707aaab213b1674 \ - --hash=sha256:5b7dfa3b546da08a9f622bb6becdb14b3e24aaa30adba66749d38f3cc7ea9706 \ - --hash=sha256:5b9c3f4ee0b9a439d2415012bd1b1cc2df59e4d6a9939f4d669241d30b414327 \ - --hash=sha256:5d34eb8494bea691a1a450141ebb5385e4b69d38bb8403b5146ad279f4b30fa3 \ - --hash=sha256:5d5abf8f8ec1f4e22882273c423e16cae834c36856cac348cfbfa68e01c40f3a \ - --hash=sha256:5e3bc157fed2a4c02ec468de4ecd12a6e22818d4f09cde2c31ee3226ffbefab2 \ - --hash=sha256:612a10bdae23404a72941a0fc8fa2660c6ea1217c4ce0dbcab8a8f6543ea9e7f \ - --hash=sha256:657a05857bda581c3656bfc3b20e353c232e9193eb167766ad2dc58b56504948 \ - --hash=sha256:65e720d2ab2b53f1f72fb5da5fb477455905ce2c88aaa671ff0a447c2c80e8e3 \ - --hash=sha256:693902d433cf585133699972b6d7c42a8b9f8f826ebcaf0132ff55200afc599e \ - --hash=sha256:6af936f79086a89b3680a280c47ea90b4df7047b5bdf3aa5c524bbedddb9e545 \ - --hash=sha256:71bb308552200fb2c195e35ef05de12f0c878c07fc91c270eb3d6e41698c3bcc \ - --hash=sha256:764202cc7e70f767dab49e8df52c7455e8de0df5d858fa801a11aa0d882ccf3f \ - --hash=sha256:76c8094ac20ec259471ac53e774623eb62e6e1f56cd8690c67ce6ce4fcb05650 \ - --hash=sha256:78a42513018c41c2ffd262eb676442315cbfe3c44eed82385c2ed043bc63210a \ - --hash=sha256:79849239c39b5e1fd906556c474d9b0439ea6792b637511f3fe3a41158d89ca8 \ - --hash=sha256:7ab9ccab2b5bd5702ab0803676a580fffa2aa178c2badc5557a84cc943fcf750 \ - --hash=sha256:7bbfcb7165ce3d54a3dfbe731e470f65739c4c1f85bb1018ee912bae139e263b \ - --hash=sha256:7c06a4c7cf15ec739ce0e5971b26c93638730090add60e183530d70848ebdd34 \ - --hash=sha256:801fa7802e5cfabe3ab0c81a34c323a319b097dfb5004be950482d882f3d7225 \ - --hash=sha256:803b8e1459341c1bb56d1c5c010406d5edec8a0713a0945851290a7930679b51 \ - --hash=sha256:82a5c2f4b87c26bb1a0ef3d16b5c4753434633b83d365cc0ddf2770c93829e3c \ - --hash=sha256:84ec80df401cfee1457063732d90022f93951944b5b58975d34ab56bb150dfb3 \ - --hash=sha256:8705f17dfeb43139a692298cb6637ee2e59c0194538153e83e9ee0c75c2eddde \ - --hash=sha256:88a9ca9c710d598fd75ee5de59d5bda2684d9db36a9f50b6125eaea3969c2599 \ - --hash=sha256:88f17c5ffa8e9462fb79f62746428dd57b46eb931698e42e990ad63103f35e6c \ - --hash=sha256:8a3ec5aa8e38fc4c8af308917ce12c536f1c88452ce554027e55b22cbbfbff76 \ - --hash=sha256:8a9c83f75223d5e48b0bc9cb1bf2776cf01563e00ade8775ffe13b0b6e1af3a6 \ - --hash=sha256:8b01aac285f91ca889c800042c35ad3b239e704b150cfd3382adfc9dcc780e39 \ - --hash=sha256:8d53103597a252fb3ab8b5845af04c7a26d5e7ea8122303dd7a021176a87e8b9 \ - --hash=sha256:8e045731a5416357638d1700927529e2b8ab304811671f665b225f8bf8d8f933 \ - --hash=sha256:8f0ea6da6d393d8b2e187e6a5e3fb81f5862010a40c3945e2c6d12ae45cfb2ad \ - --hash=sha256:90da3b5f694b85231cf93586dad5e90e2d71b9428f9aad96952c99055582f520 \ - --hash=sha256:913983ad2deb14e66d83c28b632fd35ba2b825031f2fa4ca29675e665dfecbe1 \ - --hash=sha256:9242795d174daa40105c1d86aba618e8eab7bf96ba8c3ee614da8302a9f95503 \ - --hash=sha256:929e294c1ac1e9f615c62a4e4313ca1823ba37326c164ec720a803287c4c499b \ - --hash=sha256:933d4de052939d90afbe6e9d5273ae05fb836cc86c15b686edd4b3560cc0ee36 \ - --hash=sha256:942216596dc64ddb25adb215c3c783215b23626f8d84e8eff8d6d45c3f29f75a \ - --hash=sha256:94252291e3fe68001b1dd747b4c0b3be12582839b95ad4d1b641924d68fd4643 \ - --hash=sha256:9893ff81bd7107f7b685d3017cc6583daadb4fc26e4a888350df530e41980a60 \ - --hash=sha256:9e838bba3a3bac0fe06d849d29772eb1afb9745a59710762e4ba3f4cb8424483 \ - --hash=sha256:a0f64a48bb81af7450e641e3fe0b0394d7381e342805479178b3d335d60ca7cf \ - --hash=sha256:a17f6a29cf8935e587cc8a4dbfc8368c55edc645283db0ce9801016f83526c2d \ - --hash=sha256:a1ecf0ac1c518487d9d23b1cd7139a6a65bc460cd101ab01f1be82ecf09794b6 \ - --hash=sha256:a79ae34384df2b615eefca647a2873842ac3b596418032bef9a7283675962644 \ - --hash=sha256:a91b5f9f1205845d488c928e8570dcb62b893372f63b8b6e98b863ebd2368ff2 \ - --hash=sha256:aa0abdf853e09aff551db11fce173e2177d00786c688203f52c87ad7fcd91ef9 \ - --hash=sha256:ac542bf38a8a4be2dc6b15248d36315ccc65f0743f7b1a76688ffb6b5129a5c2 \ - --hash=sha256:ad42ba922c67c5f219097b28fae965e10045ddf145d2928bfac2eb2e17673640 \ - --hash=sha256:aeb3531b196ef6f11776c21674dba836aeea9d5bd1cf630f869e3d90b16cfade \ - --hash=sha256:b38ac83d5f04b15e515fd86f312479d950d05ce2368d5413d46c088dda7de90a \ - --hash=sha256:b7d755065e4e866a8086c9bdada157133ff466476a2ad7861828e17b6026e22c \ - --hash=sha256:bd3de6481f4ed8b734da5df134cd5a6a64fe32124fe83dde1e5b5f29fe30b1e6 \ - --hash=sha256:bfa1acfa0c54932d5607e19a2c24646fb4c1ae2694437789129cf099789a3b00 \ - --hash=sha256:c619b101e6de2222c1fcb0531e1b17bbffbe54294bfba43ea0d411d428618c27 \ - --hash=sha256:ce8be0466f4c0d585cdb6c1e2ed07232221df101a4c6f28821d2aa754ca2d9e2 \ - --hash=sha256:cf0438b42121a66a3a667de17e779330fc0f20b0d97d59d2f2121e182b0505e4 \ - --hash=sha256:cf8bcc23ceb5a1b624572a1623b9f79d2c3b337c8c455405ef231933a10da379 \ - --hash=sha256:d2b0e12a42fb4e72d509fc994713d099cbb15ebf1103545e8a45f14da2dfca54 \ - --hash=sha256:d83db7cde68459fc803052a55ace60bea2bae361fc3b7a6d5da07e11954e4b09 \ - --hash=sha256:dda56c24d869b1193fcc763f1284b9126550eaf84b88bbc7256e15028f19188a \ - --hash=sha256:dea0bf229319828467d7fca8c7c189780aa9ff679c94539eed7532ebe33ed37c \ - --hash=sha256:e1631290ee9271dffe3062d2634c3ecac02c83890ada077d225e081aca8aab89 \ - --hash=sha256:e28c7fea2196bf4c2f8d46a0415c77a1c480cc0724722f23d7410ffe9842c407 \ - --hash=sha256:e2e6c39bd7b9372b0be21456caab138e8e69cc0fc1190a9dfa92bd45a1e6e904 \ - --hash=sha256:e33e8fbd440c917106b237ef1a2f1449dfbb9b6f6e1ce17c94cd6a1e0d438376 \ - --hash=sha256:e8df2eb9b2bac43ef8b082e06f750350fbbaf2887534a5be97f6cf07b19d9583 \ - --hash=sha256:e968b84db54f9d42046cf154e02911e39c0435c9801681e3fc9ce8a3c4130278 \ - --hash=sha256:eb542fe7933aa09d8d8f9d9097ef37532a7df6497819d16efe4359890a2f417a \ - --hash=sha256:edcfc407e4eb17e037bca59be0e85a2031a2ac87e4fed26d3e9df88b4165f92d \ - --hash=sha256:eee3ea935c3d227d49b4eb85660ff631556841f6e567f0f7bda972df6c2c9935 \ - --hash=sha256:ef97b8df011141c9b0f6caf23b29379f87dd13183c978a30a3c546d2c47314cb \ - --hash=sha256:f106407dda69ae456dd1227966bf445b157ccc80ba0dff3802bb63f30b74e895 \ - --hash=sha256:f3160309af4396e0ed04db259c3ccbfdc3621b5559b5453075e5de555e1f3a1b \ - --hash=sha256:f32d6edbc638cde7652bd690c3e728b25332acbadd7cad670cc4a02558d9c417 \ - --hash=sha256:f37cfe618a117e50d8c240555331160d73d0411422b59b5ee217843d7b693608 \ - --hash=sha256:f4c9aee212bc89d4e13f58be11a56cc8036cabad119259d12ace14b34476fd07 \ - --hash=sha256:f4d742cb7af1c28303a51b7a27aaee540e71bb8e24f68c736f6f2ffc82f2bf05 \ - --hash=sha256:f5a8b53bdc0b3961f8b6125e198617c40aeed638b387913bf1ce78afb1b0be2a \ - --hash=sha256:f816dd2277f8d63d79f9c8473a79fe54047bc0467754962840782c575522224d \ - --hash=sha256:f9a9e8a507420fe35992ee9ecb302dab68550dedc0da9e2880dd88071c5fb052 +jax-cuda12-pjrt==0.8.2 ; sys_platform == "linux" \ + --hash=sha256:717a1b196a642409ce195ddf031c20bbeadcc886f55e49a1d3f4927373aeedae \ + --hash=sha256:e3bab41ca7c48e4163db9e7efd271b3aa85f0fe45f5ed0708d6bbed93a59f977 + # via + # -r build/requirements.in + # jax-cuda12-plugin +jax-cuda12-plugin==0.8.2 ; sys_platform == "linux" and python_version < "3.14" \ + --hash=sha256:0b0a3304ce7e494acd8d9c593490c112a32cdb6010fe1afc584d9e41fd863167 \ + --hash=sha256:1b4828242d57f233b394d17ebaa599c503c1fb9b7c754012a06eb84dbc935fc8 \ + --hash=sha256:20165861b3d3e66ebb2c0f63a547d1d5ee17ea44ac3be7153c7908c9ca8c88f3 \ + --hash=sha256:377e4be17e22dde0343b3f3c05bf69235b3dbf11d766cca9c5a93da47971dcb7 \ + --hash=sha256:403d5e07731b5cdac3bd9fb3f448bd8480062cb2c0ab61ea2ad23fcd0a65479a \ + --hash=sha256:58c51473fc622e03138035985f741833564d70a4bd5a2178f61b62cdaa32ff94 \ + --hash=sha256:637387dc3408cd204562668502f9e95f76c6edde0a6d2e48f055162dc2aebf0d \ + --hash=sha256:70d33222484ad5c375b8f8357b7c23cacb844f6ecfc39567f8dd47fde6e87858 \ + --hash=sha256:82c6798be66bf8c773386918e4c8e5cd8119753f3bfb3ca4bbc46818283750c6 \ + --hash=sha256:a5898bac1d8ab6020b54546440256409f2c66bcbbb3a1099ca473c84843addad \ + --hash=sha256:d68a6d8b4a45ee561746bac7a6468da8203832626b0b39ad4ac43011f61f875d \ + --hash=sha256:dd4f7c34d4512ff5a36fd1b01584ef7781cad615e3f9e71880eae2f4998e5108 + # via -r build/requirements.in +jax-cuda13-pjrt==0.8.2 \ + --hash=sha256:370e9cb9a76af6a6b00e8f6c6ae8e5a0158886b65994a114cf3acf56ea1570f3 \ + --hash=sha256:c0ebc284d4bea9c0b278017b4bb62eb871755fe582cbf22ce1fb770d30fb9af6 + # via + # -r build/requirements.in + # jax-cuda13-plugin +jax-cuda13-plugin==0.8.2 \ + --hash=sha256:0ab3a13f798f973416754cb00b1696a8e6d2425c39496e2145293661bcd28446 \ + --hash=sha256:2c7532f2a6c9dcf50ba4b8ebbe4c78988b8c260695e50f9fdbd68f9417a8bb51 \ + --hash=sha256:358f4b4d0d2b92f62fe7f07d355a9c57c5684da3d4e6f483afec38b23e66f02e \ + --hash=sha256:6703552ee10b6b41b43fe4fc62b4672bec13d271eb709a29bf14b0fa1ec805dd \ + --hash=sha256:8fe2c5874cbae18a01a0fa722dfbd547b5504e7bf42aaaae6ad3e1b676e8daa2 \ + --hash=sha256:950f2c1f2e0f5e13893c4c27a4df540a96a26ebd15d7954fa85ffb8cc8b1d9a1 \ + --hash=sha256:9835076d0a917345700370cfaee5cfd7876c0b0982e99206a8b4848af0ea6e5e \ + --hash=sha256:bb059d99eb960718358f82f3eaa0829f96fc7d7e978ded821f2c40256a8aee4b \ + --hash=sha256:d5f9470db46aee8a64e9ab7817927335b77c6d3698d5113e0da2dd962c1468c4 \ + --hash=sha256:f1643b216ccb41a2d7c0e240359c729c96c5eca76e4525ffcd6313049e348b9a \ + --hash=sha256:f35f687d9b839365a49c2b0052eaeb6b9f0a9879813b4fa020f8d9caa3ddde46 \ + --hash=sha256:fa6a6faf3130d6ef05f9ee6cbc27e0a2e4db0f23720136403991eccb3a0144af + # via -r build/requirements.in +jaxlib==0.8.2 \ + --hash=sha256:023de6f3f56da2af7037970996500586331fdb50b530ecbb54b9666da633bd00 \ + --hash=sha256:05b958f497e49824c432e734bb059723b7dfe69e2ad696a9f9c8ad82fff7c3f8 \ + --hash=sha256:1bfbcf6c3de221784fa4cdb6765a09d71cb4298b15626b3d0409b3dfcd8a8667 \ + --hash=sha256:28eec1a4e0639a0d8702cea3cb70dd3663053dbfa344452994ea48dc6ceadaa5 \ + --hash=sha256:2b9789bd08f8b0cc5a5c12ae896fe432d5942e32e417091b8b5a96a9a6fd5cf1 \ + --hash=sha256:3b16e50c5b730c9dd0a49e55f1acfaa722b00b1af0522a591558dcc0464252f2 \ + --hash=sha256:490bf0cb029c73c65c9431124b86cdc95082dbc1fb76fc549d24d75da33e5454 \ + --hash=sha256:4d006db96be020c8165212a1216372f8acac4ff4f8fb067743d694ef2b301ace \ + --hash=sha256:68108dff0de74adc468016be9a19f80efe48c660c0d5a122287094b44b092afc \ + --hash=sha256:7c304f3a016965b9d1f5239a8a0399a73925f5604fe914c5ca66ecf734bf6422 \ + --hash=sha256:7da8127557c786264049ae55460d1b8d04cc3cdf0403a087f2fc1e6d313ec722 \ + --hash=sha256:964626f581beab31ee6826b228fcc2ec5181b05cecf94a528dff97921c145dbc \ + --hash=sha256:a397ea7dcb37d689ce79173eeb99b2f1347637a36be9a27f20ae6848bfc58bfc \ + --hash=sha256:aa8701b6356f098e8452c3cec762fb5f706fcb8f67ffd65964f63982479aa23b \ + --hash=sha256:bb89be452b1b808d3f88fc01c415b364a260be4cc7ac120c038009f6150a32dc \ + --hash=sha256:beffb004e7eeb5c9afb24439e2b2cf45a4ee3e3e8adf45e355edf2af62acf8b8 \ + --hash=sha256:ccf77da917a20935247c990691decfcbdd06c25ef0ac94d914a04aadb22f714c \ + --hash=sha256:dffc22b5b732b9556d92c918b251c61bcc046617c4dbb51e1f7a656587fddffb \ + --hash=sha256:e6a97dfb0232eed9a2bb6e3828e4f682dbac1a7fea840bfda574cae2dbf5faf9 \ + --hash=sha256:f205e91c3a152a2a76c0bc59a6a2de03e87ec261b91e8812922777185e7b08f5 \ + --hash=sha256:f28edac8c226fc07fa3e8af6f9defede8ac2c307429e3291edce8739d39becc9 \ + --hash=sha256:f472cc72e3058e50b5f0230b236d5a1183bf6c3d5423d2a52eff07bcf34908de + # via -r build/requirements.in +keras==3.12.0 \ + --hash=sha256:02b69e007d5df8042286c3bcc2a888539e3e487590ffb08f6be1b4354df50aa8 \ + --hash=sha256:536e3f8385a05ae04e82e08715a1a59988578087e187b04cb0a6fad11743f07f + # via tensorflow +kiwisolver==1.4.9 \ + --hash=sha256:0749fd8f4218ad2e851e11cc4dc05c7cbc0cbc4267bdfdb31782e65aace4ee9c \ + --hash=sha256:0763515d4df10edf6d06a3c19734e2566368980d21ebec439f33f9eb936c07b7 \ + --hash=sha256:0856e241c2d3df4efef7c04a1e46b1936b6120c9bcf36dd216e3acd84bc4fb21 \ + --hash=sha256:0a590506f303f512dff6b7f75fd2fd18e16943efee932008fe7140e5fa91d80e \ + --hash=sha256:0ab74e19f6a2b027ea4f845a78827969af45ce790e6cb3e1ebab71bdf9f215ff \ + --hash=sha256:0ae37737256ba2de764ddc12aed4956460277f00c4996d51a197e72f62f5eec7 \ + --hash=sha256:0e4e2bf29574a6a7b7f6cb5fa69293b9f96c928949ac4a53ba3f525dffb87f9c \ + --hash=sha256:15163165efc2f627eb9687ea5f3a28137217d217ac4024893d753f46bce9de26 \ + --hash=sha256:17680d737d5335b552994a2008fab4c851bcd7de33094a82067ef3a576ff02fa \ + --hash=sha256:1a12cf6398e8a0a001a059747a1cbf24705e18fe413bc22de7b3d15c67cffe3f \ + --hash=sha256:1b11d6a633e4ed84fc0ddafd4ebfd8ea49b3f25082c04ad12b8315c11d504dc1 \ + --hash=sha256:1fa333e8b2ce4d9660f2cda9c0e1b6bafcfb2457a9d259faa82289e73ec24891 \ + --hash=sha256:2327a4a30d3ee07d2fbe2e7933e8a37c591663b96ce42a00bc67461a87d7df77 \ + --hash=sha256:2405a7d98604b87f3fc28b1716783534b1b4b8510d8142adca34ee0bc3c87543 \ + --hash=sha256:2489e4e5d7ef9a1c300a5e0196e43d9c739f066ef23270607d45aba368b91f2d \ + --hash=sha256:24c175051354f4a28c5d6a31c93906dc653e2bf234e8a4bbfb964892078898ce \ + --hash=sha256:2635d352d67458b66fd0667c14cb1d4145e9560d503219034a18a87e971ce4f3 \ + --hash=sha256:2c1a4f57df73965f3f14df20b80ee29e6a7930a57d2d9e8491a25f676e197c60 \ + --hash=sha256:2c93f00dcba2eea70af2be5f11a830a742fe6b579a1d4e00f47760ef13be247a \ + --hash=sha256:39a219e1c81ae3b103643d2aedb90f1ef22650deb266ff12a19e7773f3e5f089 \ + --hash=sha256:3b3115b2581ea35bb6d1f24a4c90af37e5d9b49dcff267eeed14c3893c5b86ab \ + --hash=sha256:40092754720b174e6ccf9e845d0d8c7d8e12c3d71e7fc35f55f3813e96376f78 \ + --hash=sha256:412f287c55a6f54b0650bd9b6dce5aceddb95864a1a90c87af16979d37c89771 \ + --hash=sha256:464415881e4801295659462c49461a24fb107c140de781d55518c4b80cb6790f \ + --hash=sha256:497d05f29a1300d14e02e6441cf0f5ee81c1ff5a304b0d9fb77423974684e08b \ + --hash=sha256:4a2899935e724dd1074cb568ce7ac0dce28b2cd6ab539c8e001a8578eb106d14 \ + --hash=sha256:4a48a2ce79d65d363597ef7b567ce3d14d68783d2b2263d98db3d9477805ba32 \ + --hash=sha256:4d1d9e582ad4d63062d34077a9a1e9f3c34088a2ec5135b1f7190c07cf366527 \ + --hash=sha256:52a15b0f35dad39862d376df10c5230155243a2c1a436e39eb55623ccbd68185 \ + --hash=sha256:540c7c72324d864406a009d72f5d6856f49693db95d1fbb46cf86febef873634 \ + --hash=sha256:5656aa670507437af0207645273ccdfee4f14bacd7f7c67a4306d0dcaeaf6eed \ + --hash=sha256:5a0f2724dfd4e3b3ac5a82436a8e6fd16baa7d507117e4279b660fe8ca38a3a1 \ + --hash=sha256:60c439763a969a6af93b4881db0eed8fadf93ee98e18cbc35bc8da868d0c4f0c \ + --hash=sha256:61874cdb0a36016354853593cffc38e56fc9ca5aa97d2c05d3dcf6922cd55a11 \ + --hash=sha256:67bb8b474b4181770f926f7b7d2f8c0248cbcb78b660fdd41a47054b28d2a752 \ + --hash=sha256:720e05574713db64c356e86732c0f3c5252818d05f9df320f0ad8380641acea5 \ + --hash=sha256:72d0eb9fba308b8311685c2268cf7d0a0639a6cd027d8128659f72bdd8a024b4 \ + --hash=sha256:767c23ad1c58c9e827b649a9ab7809fd5fd9db266a9cf02b0e926ddc2c680d58 \ + --hash=sha256:77937e5e2a38a7b48eef0585114fe7930346993a88060d0bf886086d2aa49ef5 \ + --hash=sha256:7a08b491ec91b1d5053ac177afe5290adacf1f0f6307d771ccac5de30592d198 \ + --hash=sha256:7b4da0d01ac866a57dd61ac258c5607b4cd677f63abaec7b148354d2b2cdd536 \ + --hash=sha256:7cf974dd4e35fa315563ac99d6287a1024e4dc2077b8a7d7cd3d2fb65d283134 \ + --hash=sha256:84fd60810829c27ae375114cd379da1fa65e6918e1da405f356a775d49a62bcf \ + --hash=sha256:858e4c22fb075920b96a291928cb7dea5644e94c0ee4fcd5af7e865655e4ccf2 \ + --hash=sha256:85b5352f94e490c028926ea567fc569c52ec79ce131dadb968d3853e809518c2 \ + --hash=sha256:85bd218b5ecfbee8c8a82e121802dcb519a86044c9c3b2e4aef02fa05c6da370 \ + --hash=sha256:8a1f570ce4d62d718dce3f179ee78dac3b545ac16c0c04bb363b7607a949c0d1 \ + --hash=sha256:8fdca1def57a2e88ef339de1737a1449d6dbf5fab184c54a1fca01d541317154 \ + --hash=sha256:90f47e70293fc3688b71271100a1a5453aa9944a81d27ff779c108372cf5567b \ + --hash=sha256:92a2f997387a1b79a75e7803aa7ded2cfbe2823852ccf1ba3bcf613b62ae3197 \ + --hash=sha256:9928fe1eb816d11ae170885a74d074f57af3a0d65777ca47e9aeb854a1fba386 \ + --hash=sha256:9af39d6551f97d31a4deebeac6f45b156f9755ddc59c07b402c148f5dbb6482a \ + --hash=sha256:9cf554f21be770f5111a1690d42313e140355e687e05cf82cb23d0a721a64a48 \ + --hash=sha256:a30fd6fdef1430fd9e1ba7b3398b5ee4e2887783917a687d86ba69985fb08748 \ + --hash=sha256:a31d512c812daea6d8b3be3b2bfcbeb091dbb09177706569bcfc6240dcf8b41c \ + --hash=sha256:a5d0432ccf1c7ab14f9949eec60c5d1f924f17c037e9f8b33352fa05799359b8 \ + --hash=sha256:a60ea74330b91bd22a29638940d115df9dc00af5035a9a2a6ad9399ffb4ceca5 \ + --hash=sha256:ac5a486ac389dddcc5bef4f365b6ae3ffff2c433324fb38dd35e3fab7c957999 \ + --hash=sha256:aedff62918805fb62d43a4aa2ecd4482c380dc76cd31bd7c8878588a61bd0369 \ + --hash=sha256:b34e51affded8faee0dfdb705416153819d8ea9250bbbf7ea1b249bdeb5f1122 \ + --hash=sha256:b4b4d74bda2b8ebf4da5bd42af11d02d04428b2c32846e4c2c93219df8a7987b \ + --hash=sha256:b67e6efbf68e077dd71d1a6b37e43e1a99d0bff1a3d51867d45ee8908b931098 \ + --hash=sha256:b78efa4c6e804ecdf727e580dbb9cba85624d2e1c6b5cb059c66290063bd99a9 \ + --hash=sha256:bb4ae2b57fc1d8cbd1cf7b1d9913803681ffa903e7488012be5b76dedf49297f \ + --hash=sha256:bdd1a81a1860476eb41ac4bc1e07b3f07259e6d55bbf739b79c8aaedcf512799 \ + --hash=sha256:bdee92c56a71d2b24c33a7d4c2856bd6419d017e08caa7802d2963870e315028 \ + --hash=sha256:be6a04e6c79819c9a8c2373317d19a96048e5a3f90bec587787e86a1153883c2 \ + --hash=sha256:bfc08add558155345129c7803b3671cf195e6a56e7a12f3dde7c57d9b417f525 \ + --hash=sha256:c3b22c26c6fd6811b0ae8363b95ca8ce4ea3c202d3d0975b2914310ceb1bcc4d \ + --hash=sha256:c9e7cdf45d594ee04d5be1b24dd9d49f3d1590959b2271fb30b5ca2b262c00fb \ + --hash=sha256:cb27e7b78d716c591e88e0a09a2139c6577865d7f2e152488c2cc6257f460872 \ + --hash=sha256:cc9617b46837c6468197b5945e196ee9ca43057bb7d9d1ae688101e4e1dddf64 \ + --hash=sha256:ccd09f20ccdbbd341b21a67ab50a119b64a403b09288c27481575105283c1586 \ + --hash=sha256:ce6a3a4e106cf35c2d9c4fa17c05ce0b180db622736845d4315519397a77beaf \ + --hash=sha256:d0005b053977e7b43388ddec89fa567f43d4f6d5c2c0affe57de5ebf290dc552 \ + --hash=sha256:d4188e73af84ca82468f09cadc5ac4db578109e52acb4518d8154698d3a87ca2 \ + --hash=sha256:d4efec7bcf21671db6a3294ff301d2fc861c31faa3c8740d1a94689234d1b415 \ + --hash=sha256:d75aa530ccfaa593da12834b86a0724f58bff12706659baa9227c2ccaa06264c \ + --hash=sha256:d84cd4061ae292d8ac367b2c3fa3aad11cb8625a95d135fe93f286f914f3f5a6 \ + --hash=sha256:d8aacd3d4b33b772542b2e01beb50187536967b514b00003bdda7589722d2a64 \ + --hash=sha256:d8fc5c867c22b828001b6a38d2eaeb88160bf5783c6cb4a5e440efc981ce286d \ + --hash=sha256:d976bbb382b202f71c67f77b0ac11244021cfa3f7dfd9e562eefcea2df711548 \ + --hash=sha256:dba5ee5d3981160c28d5490f0d1b7ed730c22470ff7f6cc26cfcfaacb9896a07 \ + --hash=sha256:dc1ae486f9abcef254b5618dfb4113dd49f94c68e3e027d03cf0143f3f772b61 \ + --hash=sha256:dd0a578400839256df88c16abddf9ba14813ec5f21362e1fe65022e00c883d4d \ + --hash=sha256:deed0c7258ceb4c44ad5ec7d9918f9f14fd05b2be86378d86cf50e63d1e7b771 \ + --hash=sha256:e09c2279a4d01f099f52d5c4b3d9e208e91edcbd1a175c9662a8b16e000fece9 \ + --hash=sha256:e2ea9f7ab7fbf18fffb1b5434ce7c69a07582f7acc7717720f1d69f3e806f90c \ + --hash=sha256:e6b93f13371d341afee3be9f7c5964e3fe61d5fa30f6a30eb49856935dfe4fc3 \ + --hash=sha256:eb14a5da6dc7642b0f3a18f13654847cd8b7a2550e2645a5bda677862b03ba16 \ + --hash=sha256:ed0fecd28cc62c54b262e3736f8bb2512d8dcfdc2bcf08be5f47f96bf405b145 \ + --hash=sha256:ede8c6d533bc6601a47ad4046080d36b8fc99f81e6f1c17b0ac3c2dc91ac7611 \ + --hash=sha256:efb3a45b35622bb6c16dbfab491a8f5a391fe0e9d45ef32f4df85658232ca0e2 \ + --hash=sha256:f117e1a089d9411663a3207ba874f31be9ac8eaa5b533787024dc07aeb74f464 \ + --hash=sha256:f2ba92255faa7309d06fe44c3a4a97efe1c8d640c2a79a5ef728b685762a6fd2 \ + --hash=sha256:f6008a4919fdbc0b0097089f67a1eb55d950ed7e90ce2cc3e640abadd2757a04 \ + --hash=sha256:f68208a520c3d86ea51acf688a3e3002615a7f0238002cccc17affecc86a8a54 \ + --hash=sha256:f68e4f3eeca8fb22cc3d731f9715a13b652795ef657a13df1ad0c7dc0e9731df \ + --hash=sha256:fb3b8132019ea572f4611d770991000d7f58127560c4889729248eb5852a102f \ + --hash=sha256:fb940820c63a9590d31d88b815e7a3aa5915cad3ce735ab45f0c730b39547de1 \ + --hash=sha256:fc1795ac5cd0510207482c3d1d3ed781143383b8cfd36f5c645f3897ce066220 # via matplotlib -markdown-it-py==3.0.0 \ - --hash=sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 \ - --hash=sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb +libclang==18.1.1 \ + --hash=sha256:0b2e143f0fac830156feb56f9231ff8338c20aecfe72b4ffe96f19e5a1dbb69a \ + --hash=sha256:3f0e1f49f04d3cd198985fea0511576b0aee16f9ff0e0f0cad7f9c57ec3c20e8 \ + --hash=sha256:4dd2d3b82fab35e2bf9ca717d7b63ac990a3519c7e312f19fa8e86dcc712f7fb \ + --hash=sha256:54dda940a4a0491a9d1532bf071ea3ef26e6dbaf03b5000ed94dd7174e8f9592 \ + --hash=sha256:69f8eb8f65c279e765ffd28aaa7e9e364c776c17618af8bff22a8df58677ff4f \ + --hash=sha256:6f14c3f194704e5d09769108f03185fce7acaf1d1ae4bbb2f30a72c2400cb7c5 \ + --hash=sha256:83ce5045d101b669ac38e6da8e58765f12da2d3aafb3b9b98d88b286a60964d8 \ + --hash=sha256:a1214966d08d73d971287fc3ead8dfaf82eb07fb197680d8b3859dbbbbf78250 \ + --hash=sha256:c533091d8a3bbf7460a00cb6c1a71da93bffe148f172c7d03b1c31fbf8aa2a0b \ + --hash=sha256:cf4a99b05376513717ab5d82a0db832c56ccea4fd61a69dbb7bccf2dfb207dbe + # via tensorflow +libtpu==0.0.32 ; sys_platform == "linux" and platform_machine == "x86_64" \ + --hash=sha256:37f6aefe6d69d24e5e268e74c1e90cd0ce0fa6a796720709445d8938605912ee \ + --hash=sha256:4a4ae6db0a90f0e33902cd345923a5d77dfc5bbab9a3e524c79bf876858104ee \ + --hash=sha256:6453995774b902e43d1caf0d82298a2e90f2f731ddd74d6bef9510d91732775f \ + --hash=sha256:7ffd9cdb89da22d32bb60ce70b88e9e76861c5d704ebf673f32987e19e9913f3 \ + --hash=sha256:98d9713a39d9581c9b10a7bdeaa575cc874a4a7f543cfb40d3a723ee27b428b5 \ + --hash=sha256:d2a0e7abe4a029b79049bc95da20e4c2a64f8ef09823f75286cecfbb4b0b6da1 + # via -r build/requirements.in +markdown==3.10 \ + --hash=sha256:37062d4f2aa4b2b6b32aefb80faa300f82cc790cb949a35b8caede34f2b68c0e \ + --hash=sha256:b5b99d6951e2e4948d939255596523444c0e677c669700b1d17aa4a8a464cb7c + # via tensorboard +markdown-it-py==4.0.0 \ + --hash=sha256:87327c59b172c5011896038353a81343b6754500a08cd7a4973bb48c6d578147 \ + --hash=sha256:cb0a2b4aa34f932c007117b194e945bd74e0ec24133ceb5bac59009cda1cb9f3 # via rich -matplotlib==3.9.2 ; python_version >= "3.11" \ - --hash=sha256:039082812cacd6c6bec8e17a9c1e6baca230d4116d522e81e1f63a74d01d2e21 \ - --hash=sha256:03ba9c1299c920964e8d3857ba27173b4dbb51ca4bab47ffc2c2ba0eb5e2cbc5 \ - --hash=sha256:050598c2b29e0b9832cde72bcf97627bf00262adbc4a54e2b856426bb2ef0697 \ - --hash=sha256:18128cc08f0d3cfff10b76baa2f296fc28c4607368a8402de61bb3f2eb33c7d9 \ - --hash=sha256:1cd93b91ab47a3616b4d3c42b52f8363b88ca021e340804c6ab2536344fad9ca \ - --hash=sha256:1d94ff717eb2bd0b58fe66380bd8b14ac35f48a98e7c6765117fe67fb7684e64 \ - --hash=sha256:306c8dfc73239f0e72ac50e5a9cf19cc4e8e331dd0c54f5e69ca8758550f1e1e \ - --hash=sha256:37e51dd1c2db16ede9cfd7b5cabdfc818b2c6397c83f8b10e0e797501c963a03 \ - --hash=sha256:3fd595f34aa8a55b7fc8bf9ebea8aa665a84c82d275190a61118d33fbc82ccae \ - --hash=sha256:4876d7d40219e8ae8bb70f9263bcbe5714415acfdf781086601211335e24f8aa \ - --hash=sha256:5413401594cfaff0052f9d8b1aafc6d305b4bd7c4331dccd18f561ff7e1d3bd3 \ - --hash=sha256:5816b1e1fe8c192cbc013f8f3e3368ac56fbecf02fb41b8f8559303f24c5015e \ - --hash=sha256:65aacf95b62272d568044531e41de26285d54aec8cb859031f511f84bd8b495a \ - --hash=sha256:6758baae2ed64f2331d4fd19be38b7b4eae3ecec210049a26b6a4f3ae1c85dcc \ - --hash=sha256:6d1ce5ed2aefcdce11904fc5bbea7d9c21fff3d5f543841edf3dea84451a09ea \ - --hash=sha256:6d9f07a80deab4bb0b82858a9e9ad53d1382fd122be8cde11080f4e7dfedb38b \ - --hash=sha256:7741f26a58a240f43bee74965c4882b6c93df3e7eb3de160126d8c8f53a6ae6e \ - --hash=sha256:8912ef7c2362f7193b5819d17dae8629b34a95c58603d781329712ada83f9447 \ - --hash=sha256:909645cce2dc28b735674ce0931a4ac94e12f5b13f6bb0b5a5e65e7cea2c192b \ - --hash=sha256:96ab43906269ca64a6366934106fa01534454a69e471b7bf3d79083981aaab92 \ - --hash=sha256:9d78bbc0cbc891ad55b4f39a48c22182e9bdaea7fc0e5dbd364f49f729ca1bbb \ - --hash=sha256:ab68d50c06938ef28681073327795c5db99bb4666214d2d5f880ed11aeaded66 \ - --hash=sha256:ac43031375a65c3196bee99f6001e7fa5bdfb00ddf43379d3c0609bdca042df9 \ - --hash=sha256:ae82a14dab96fbfad7965403c643cafe6515e386de723e498cf3eeb1e0b70cc7 \ - --hash=sha256:b2696efdc08648536efd4e1601b5fd491fd47f4db97a5fbfd175549a7365c1b2 \ - --hash=sha256:b82c5045cebcecd8496a4d694d43f9cc84aeeb49fe2133e036b207abe73f4d30 \ - --hash=sha256:be0fc24a5e4531ae4d8e858a1a548c1fe33b176bb13eff7f9d0d38ce5112a27d \ - --hash=sha256:bf81de2926c2db243c9b2cbc3917619a0fc85796c6ba4e58f541df814bbf83c7 \ - --hash=sha256:c375cc72229614632c87355366bdf2570c2dac01ac66b8ad048d2dabadf2d0d4 \ - --hash=sha256:c797dac8bb9c7a3fd3382b16fe8f215b4cf0f22adccea36f1545a6d7be310b41 \ - --hash=sha256:cef2a73d06601437be399908cf13aee74e86932a5ccc6ccdf173408ebc5f6bb2 \ - --hash=sha256:d52a3b618cb1cbb769ce2ee1dcdb333c3ab6e823944e9a2d36e37253815f9556 \ - --hash=sha256:d719465db13267bcef19ea8954a971db03b9f48b4647e3860e4bc8e6ed86610f \ - --hash=sha256:d8dd059447824eec055e829258ab092b56bb0579fc3164fa09c64f3acd478772 \ - --hash=sha256:dbe196377a8248972f5cede786d4c5508ed5f5ca4a1e09b44bda889958b33f8c \ - --hash=sha256:e0830e188029c14e891fadd99702fd90d317df294c3298aad682739c5533721a \ - --hash=sha256:f053c40f94bc51bc03832a41b4f153d83f2062d88c72b5e79997072594e97e51 \ - --hash=sha256:f32c7410c7f246838a77d6d1eff0c0f87f3cb0e7c4247aebea71a6d5a68cab49 \ - --hash=sha256:f6ee45bc4245533111ced13f1f2cace1e7f89d1c793390392a80c139d6cf0e6c \ - --hash=sha256:f7c0410f181a531ec4e93bbc27692f2c71a15c2da16766f5ba9761e7ae518413 +markupsafe==3.0.3 \ + --hash=sha256:0303439a41979d9e74d18ff5e2dd8c43ed6c6001fd40e5bf2e43f7bd9bbc523f \ + --hash=sha256:068f375c472b3e7acbe2d5318dea141359e6900156b5b2ba06a30b169086b91a \ + --hash=sha256:0bf2a864d67e76e5c9a34dc26ec616a66b9888e25e7b9460e1c76d3293bd9dbf \ + --hash=sha256:0db14f5dafddbb6d9208827849fad01f1a2609380add406671a26386cdf15a19 \ + --hash=sha256:0eb9ff8191e8498cca014656ae6b8d61f39da5f95b488805da4bb029cccbfbaf \ + --hash=sha256:0f4b68347f8c5eab4a13419215bdfd7f8c9b19f2b25520968adfad23eb0ce60c \ + --hash=sha256:1085e7fbddd3be5f89cc898938f42c0b3c711fdcb37d75221de2666af647c175 \ + --hash=sha256:116bb52f642a37c115f517494ea5feb03889e04df47eeff5b130b1808ce7c219 \ + --hash=sha256:12c63dfb4a98206f045aa9563db46507995f7ef6d83b2f68eda65c307c6829eb \ + --hash=sha256:133a43e73a802c5562be9bbcd03d090aa5a1fe899db609c29e8c8d815c5f6de6 \ + --hash=sha256:1353ef0c1b138e1907ae78e2f6c63ff67501122006b0f9abad68fda5f4ffc6ab \ + --hash=sha256:15d939a21d546304880945ca1ecb8a039db6b4dc49b2c5a400387cdae6a62e26 \ + --hash=sha256:177b5253b2834fe3678cb4a5f0059808258584c559193998be2601324fdeafb1 \ + --hash=sha256:1872df69a4de6aead3491198eaf13810b565bdbeec3ae2dc8780f14458ec73ce \ + --hash=sha256:1b4b79e8ebf6b55351f0d91fe80f893b4743f104bff22e90697db1590e47a218 \ + --hash=sha256:1b52b4fb9df4eb9ae465f8d0c228a00624de2334f216f178a995ccdcf82c4634 \ + --hash=sha256:1ba88449deb3de88bd40044603fafffb7bc2b055d626a330323a9ed736661695 \ + --hash=sha256:1cc7ea17a6824959616c525620e387f6dd30fec8cb44f649e31712db02123dad \ + --hash=sha256:218551f6df4868a8d527e3062d0fb968682fe92054e89978594c28e642c43a73 \ + --hash=sha256:26a5784ded40c9e318cfc2bdb30fe164bdb8665ded9cd64d500a34fb42067b1c \ + --hash=sha256:2713baf880df847f2bece4230d4d094280f4e67b1e813eec43b4c0e144a34ffe \ + --hash=sha256:2a15a08b17dd94c53a1da0438822d70ebcd13f8c3a95abe3a9ef9f11a94830aa \ + --hash=sha256:2f981d352f04553a7171b8e44369f2af4055f888dfb147d55e42d29e29e74559 \ + --hash=sha256:32001d6a8fc98c8cb5c947787c5d08b0a50663d139f1305bac5885d98d9b40fa \ + --hash=sha256:3524b778fe5cfb3452a09d31e7b5adefeea8c5be1d43c4f810ba09f2ceb29d37 \ + --hash=sha256:3537e01efc9d4dccdf77221fb1cb3b8e1a38d5428920e0657ce299b20324d758 \ + --hash=sha256:35add3b638a5d900e807944a078b51922212fb3dedb01633a8defc4b01a3c85f \ + --hash=sha256:38664109c14ffc9e7437e86b4dceb442b0096dfe3541d7864d9cbe1da4cf36c8 \ + --hash=sha256:3a7e8ae81ae39e62a41ec302f972ba6ae23a5c5396c8e60113e9066ef893da0d \ + --hash=sha256:3b562dd9e9ea93f13d53989d23a7e775fdfd1066c33494ff43f5418bc8c58a5c \ + --hash=sha256:457a69a9577064c05a97c41f4e65148652db078a3a509039e64d3467b9e7ef97 \ + --hash=sha256:4bd4cd07944443f5a265608cc6aab442e4f74dff8088b0dfc8238647b8f6ae9a \ + --hash=sha256:4e885a3d1efa2eadc93c894a21770e4bc67899e3543680313b09f139e149ab19 \ + --hash=sha256:4faffd047e07c38848ce017e8725090413cd80cbc23d86e55c587bf979e579c9 \ + --hash=sha256:509fa21c6deb7a7a273d629cf5ec029bc209d1a51178615ddf718f5918992ab9 \ + --hash=sha256:5678211cb9333a6468fb8d8be0305520aa073f50d17f089b5b4b477ea6e67fdc \ + --hash=sha256:591ae9f2a647529ca990bc681daebdd52c8791ff06c2bfa05b65163e28102ef2 \ + --hash=sha256:5a7d5dc5140555cf21a6fefbdbf8723f06fcd2f63ef108f2854de715e4422cb4 \ + --hash=sha256:69c0b73548bc525c8cb9a251cddf1931d1db4d2258e9599c28c07ef3580ef354 \ + --hash=sha256:6b5420a1d9450023228968e7e6a9ce57f65d148ab56d2313fcd589eee96a7a50 \ + --hash=sha256:722695808f4b6457b320fdc131280796bdceb04ab50fe1795cd540799ebe1698 \ + --hash=sha256:729586769a26dbceff69f7a7dbbf59ab6572b99d94576a5592625d5b411576b9 \ + --hash=sha256:77f0643abe7495da77fb436f50f8dab76dbc6e5fd25d39589a0f1fe6548bfa2b \ + --hash=sha256:795e7751525cae078558e679d646ae45574b47ed6e7771863fcc079a6171a0fc \ + --hash=sha256:7be7b61bb172e1ed687f1754f8e7484f1c8019780f6f6b0786e76bb01c2ae115 \ + --hash=sha256:7c3fb7d25180895632e5d3148dbdc29ea38ccb7fd210aa27acbd1201a1902c6e \ + --hash=sha256:7e68f88e5b8799aa49c85cd116c932a1ac15caaa3f5db09087854d218359e485 \ + --hash=sha256:83891d0e9fb81a825d9a6d61e3f07550ca70a076484292a70fde82c4b807286f \ + --hash=sha256:8485f406a96febb5140bfeca44a73e3ce5116b2501ac54fe953e488fb1d03b12 \ + --hash=sha256:8709b08f4a89aa7586de0aadc8da56180242ee0ada3999749b183aa23df95025 \ + --hash=sha256:8f71bc33915be5186016f675cd83a1e08523649b0e33efdb898db577ef5bb009 \ + --hash=sha256:915c04ba3851909ce68ccc2b8e2cd691618c4dc4c4232fb7982bca3f41fd8c3d \ + --hash=sha256:949b8d66bc381ee8b007cd945914c721d9aba8e27f71959d750a46f7c282b20b \ + --hash=sha256:94c6f0bb423f739146aec64595853541634bde58b2135f27f61c1ffd1cd4d16a \ + --hash=sha256:9a1abfdc021a164803f4d485104931fb8f8c1efd55bc6b748d2f5774e78b62c5 \ + --hash=sha256:9b79b7a16f7fedff2495d684f2b59b0457c3b493778c9eed31111be64d58279f \ + --hash=sha256:a320721ab5a1aba0a233739394eb907f8c8da5c98c9181d1161e77a0c8e36f2d \ + --hash=sha256:a4afe79fb3de0b7097d81da19090f4df4f8d3a2b3adaa8764138aac2e44f3af1 \ + --hash=sha256:ad2cf8aa28b8c020ab2fc8287b0f823d0a7d8630784c31e9ee5edea20f406287 \ + --hash=sha256:b8512a91625c9b3da6f127803b166b629725e68af71f8184ae7e7d54686a56d6 \ + --hash=sha256:bc51efed119bc9cfdf792cdeaa4d67e8f6fcccab66ed4bfdd6bde3e59bfcbb2f \ + --hash=sha256:bdc919ead48f234740ad807933cdf545180bfbe9342c2bb451556db2ed958581 \ + --hash=sha256:bdd37121970bfd8be76c5fb069c7751683bdf373db1ed6c010162b2a130248ed \ + --hash=sha256:be8813b57049a7dc738189df53d69395eba14fb99345e0a5994914a3864c8a4b \ + --hash=sha256:c0c0b3ade1c0b13b936d7970b1d37a57acde9199dc2aecc4c336773e1d86049c \ + --hash=sha256:c47a551199eb8eb2121d4f0f15ae0f923d31350ab9280078d1e5f12b249e0026 \ + --hash=sha256:c4ffb7ebf07cfe8931028e3e4c85f0357459a3f9f9490886198848f4fa002ec8 \ + --hash=sha256:ccfcd093f13f0f0b7fdd0f198b90053bf7b2f02a3927a30e63f3ccc9df56b676 \ + --hash=sha256:d2ee202e79d8ed691ceebae8e0486bd9a2cd4794cec4824e1c99b6f5009502f6 \ + --hash=sha256:d53197da72cc091b024dd97249dfc7794d6a56530370992a5e1a08983ad9230e \ + --hash=sha256:d6dd0be5b5b189d31db7cda48b91d7e0a9795f31430b7f271219ab30f1d3ac9d \ + --hash=sha256:d88b440e37a16e651bda4c7c2b930eb586fd15ca7406cb39e211fcff3bf3017d \ + --hash=sha256:de8a88e63464af587c950061a5e6a67d3632e36df62b986892331d4620a35c01 \ + --hash=sha256:df2449253ef108a379b8b5d6b43f4b1a8e81a061d6537becd5582fba5f9196d7 \ + --hash=sha256:e1c1493fb6e50ab01d20a22826e57520f1284df32f2d8601fdd90b6304601419 \ + --hash=sha256:e1cf1972137e83c5d4c136c43ced9ac51d0e124706ee1c8aa8532c1287fa8795 \ + --hash=sha256:e2103a929dfa2fcaf9bb4e7c091983a49c9ac3b19c9061b6d5427dd7d14d81a1 \ + --hash=sha256:e56b7d45a839a697b5eb268c82a71bd8c7f6c94d6fd50c3d577fa39a9f1409f5 \ + --hash=sha256:e8afc3f2ccfa24215f8cb28dcf43f0113ac3c37c2f0f0806d8c70e4228c5cf4d \ + --hash=sha256:e8fc20152abba6b83724d7ff268c249fa196d8259ff481f3b1476383f8f24e42 \ + --hash=sha256:eaa9599de571d72e2daf60164784109f19978b327a3910d3e9de8c97b5b70cfe \ + --hash=sha256:ec15a59cf5af7be74194f7ab02d0f59a62bdcf1a537677ce67a2537c9b87fcda \ + --hash=sha256:f190daf01f13c72eac4efd5c430a8de82489d9cff23c364c3ea822545032993e \ + --hash=sha256:f34c41761022dd093b4b6896d4810782ffbabe30f2d443ff5f083e0cbbb8c737 \ + --hash=sha256:f3e98bb3798ead92273dc0e5fd0f31ade220f59a266ffd8a4f6065e0a3ce0523 \ + --hash=sha256:f42d0984e947b8adf7dd6dde396e720934d12c506ce84eea8476409563607591 \ + --hash=sha256:f71a396b3bf33ecaa1626c255855702aca4d3d9fea5e051b41ac59a9c1c41edc \ + --hash=sha256:f9e130248f4462aaa8e2552d547f36ddadbeaa573879158d721bbd33dfe4743a \ + --hash=sha256:fed51ac40f757d41b7c48425901843666a6677e3e8eb0abcff09e4ba6e664f50 + # via werkzeug +matplotlib==3.10.8 \ + --hash=sha256:00270d217d6b20d14b584c521f810d60c5c78406dc289859776550df837dcda7 \ + --hash=sha256:0a33deb84c15ede243aead39f77e990469fff93ad1521163305095b77b72ce4a \ + --hash=sha256:113bb52413ea508ce954a02c10ffd0d565f9c3bc7f2eddc27dfe1731e71c7b5f \ + --hash=sha256:12d90df9183093fcd479f4172ac26b322b1248b15729cb57f42f71f24c7e37a3 \ + --hash=sha256:15d30132718972c2c074cd14638c7f4592bd98719e2308bccea40e0538bc0cb5 \ + --hash=sha256:18821ace09c763ec93aef5eeff087ee493a24051936d7b9ebcad9662f66501f9 \ + --hash=sha256:1ae029229a57cd1e8fe542485f27e7ca7b23aa9e8944ddb4985d0bc444f1eca2 \ + --hash=sha256:2299372c19d56bcd35cf05a2738308758d32b9eaed2371898d8f5bd33f084aa3 \ + --hash=sha256:238b7ce5717600615c895050239ec955d91f321c209dd110db988500558e70d6 \ + --hash=sha256:24d50994d8c5816ddc35411e50a86ab05f575e2530c02752e02538122613371f \ + --hash=sha256:25d380fe8b1dc32cf8f0b1b448470a77afb195438bafdf1d858bfb876f3edf7b \ + --hash=sha256:2c1998e92cd5999e295a731bcb2911c75f597d937341f3030cc24ef2733d78a8 \ + --hash=sha256:2cf5bd12cecf46908f286d7838b2abc6c91cda506c0445b8223a7c19a00df008 \ + --hash=sha256:32f8dce744be5569bebe789e46727946041199030db8aeb2954d26013a0eb26b \ + --hash=sha256:37b3c1cc42aa184b3f738cfa18c1c1d72fd496d85467a6cf7b807936d39aa656 \ + --hash=sha256:3a48a78d2786784cc2413e57397981fb45c79e968d99656706018d6e62e57958 \ + --hash=sha256:3ab4aabc72de4ff77b3ec33a6d78a68227bf1123465887f9905ba79184a1cc04 \ + --hash=sha256:3c624e43ed56313651bc18a47f838b60d7b8032ed348911c54906b130b20071b \ + --hash=sha256:3f2e409836d7f5ac2f1c013110a4d50b9f7edc26328c108915f9075d7d7a91b6 \ + --hash=sha256:3f5c3e4da343bba819f0234186b9004faba952cc420fbc522dc4e103c1985908 \ + --hash=sha256:41703cc95688f2516b480f7f339d8851a6035f18e100ee6a32bc0b8536a12a9c \ + --hash=sha256:495672de149445ec1b772ff2c9ede9b769e3cb4f0d0aa7fa730d7f59e2d4e1c1 \ + --hash=sha256:4cf267add95b1c88300d96ca837833d4112756045364f5c734a2276038dae27d \ + --hash=sha256:56271f3dac49a88d7fca5060f004d9d22b865f743a12a23b1e937a0be4818ee1 \ + --hash=sha256:595ba4d8fe983b88f0eec8c26a241e16d6376fe1979086232f481f8f3f67494c \ + --hash=sha256:5f62550b9a30afde8c1c3ae450e5eb547d579dd69b25c2fc7a1c67f934c1717a \ + --hash=sha256:646d95230efb9ca614a7a594d4fcacde0ac61d25e37dd51710b36477594963ce \ + --hash=sha256:64fcc24778ca0404ce0cb7b6b77ae1f4c7231cdd60e6778f999ee05cbd581b9a \ + --hash=sha256:6be43b667360fef5c754dda5d25a32e6307a03c204f3c0fc5468b78fa87b4160 \ + --hash=sha256:6da7c2ce169267d0d066adcf63758f0604aa6c3eebf67458930f9d9b79ad1db1 \ + --hash=sha256:83d282364ea9f3e52363da262ce32a09dfe241e4080dcedda3c0db059d3c1f11 \ + --hash=sha256:9153c3292705be9f9c64498a8872118540c3f4123d1a1c840172edf262c8be4a \ + --hash=sha256:99eefd13c0dc3b3c1b4d561c1169e65fe47aab7b8158754d7c084088e2329466 \ + --hash=sha256:a0a7f52498f72f13d4a25ea70f35f4cb60642b466cbb0a9be951b5bc3f45a486 \ + --hash=sha256:a2b336e2d91a3d7006864e0990c83b216fcdca64b5a6484912902cef87313d78 \ + --hash=sha256:a48f2b74020919552ea25d222d5cc6af9ca3f4eb43a93e14d068457f545c2a17 \ + --hash=sha256:ad3d9833a64cf48cc4300f2b406c3d0f4f4724a91c0bd5640678a6ba7c102077 \ + --hash=sha256:b44d07310e404ba95f8c25aa5536f154c0a8ec473303535949e52eb71d0a1565 \ + --hash=sha256:b53285e65d4fa4c86399979e956235deb900be5baa7fc1218ea67fbfaeaadd6f \ + --hash=sha256:b5a2b97dbdc7d4f353ebf343744f1d1f1cca8aa8bfddb4262fcf4306c3761d50 \ + --hash=sha256:b9a5ca4ac220a0cdd1ba6bcba3608547117d30468fefce49bb26f55c1a3d5c58 \ + --hash=sha256:bab485bcf8b1c7d2060b4fcb6fc368a9e6f4cd754c9c2fea281f4be21df394a2 \ + --hash=sha256:c108a1d6fa78a50646029cb6d49808ff0fc1330fda87fa6f6250c6b5369b6645 \ + --hash=sha256:d56a1efd5bfd61486c8bc968fa18734464556f0fb8e51690f4ac25d85cbbbbc2 \ + --hash=sha256:d9050fee89a89ed57b4fb2c1bfac9a3d0c57a0d55aed95949eedbc42070fea39 \ + --hash=sha256:dd80ecb295460a5d9d260df63c43f4afbdd832d725a531f008dad1664f458adf \ + --hash=sha256:e8ea3e2d4066083e264e75c829078f9e149fa119d27e19acd503de65e0b13149 \ + --hash=sha256:eb3823f11823deade26ce3b9f40dcb4a213da7a670013929f31d5f5ed1055b22 \ + --hash=sha256:ee40c27c795bda6a5292e9cff9890189d32f7e3a0bf04e0e3c9430c4a00c37df \ + --hash=sha256:efb30e3baaea72ce5928e32bab719ab4770099079d66726a62b11b1ef7273be4 \ + --hash=sha256:f254d118d14a7f99d616271d6c3c27922c092dac11112670b157798b89bf4933 \ + --hash=sha256:f89c151aab2e2e23cb3fe0acad1e8b82841fd265379c4cecd0f3fcb34c15e0f6 \ + --hash=sha256:f97aeb209c3d2511443f8797e3e5a569aebb040d4f8bc79aa3ee78a8fb9e3dd8 \ + --hash=sha256:f9b587c9c7274c1613a30afabf65a272114cd6cdbe67b3406f818c79d7ab2e2a \ + --hash=sha256:fb061f596dad3a0f52b60dc6a5dec4a0c300dec41e058a7efe09256188d170b7 # via -r build/test-requirements.txt mdurl==0.1.2 \ --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba # via markdown-it-py -ml-dtypes==0.5.1 \ - --hash=sha256:023ce2f502efd4d6c1e0472cc58ce3640d051d40e71e27386bed33901e201327 \ - --hash=sha256:05f23447a1c20ddf4dc7c2c661aa9ed93fcb2658f1017c204d1e758714dc28a8 \ - --hash=sha256:12651420130ee7cc13059fc56dac6ad300c3af3848b802d475148c9defd27c23 \ - --hash=sha256:141b2ea2f20bb10802ddca55d91fe21231ef49715cfc971998e8f2a9838f3dbe \ - --hash=sha256:15ad0f3b0323ce96c24637a88a6f44f6713c64032f27277b069f285c3cf66478 \ - --hash=sha256:1b7fbe5571fdf28fd3aaab3ef4aafc847de9ebf263be959958c1ca58ec8eadf5 \ - --hash=sha256:26ebcc69d7b779c8f129393e99732961b5cc33fcff84090451f448c89b0e01b4 \ - --hash=sha256:6f462f5eca22fb66d7ff9c4744a3db4463af06c49816c4b6ac89b16bfcdc592e \ - --hash=sha256:6f76232163b5b9c34291b54621ee60417601e2e4802a188a0ea7157cd9b323f4 \ - --hash=sha256:7000b6e4d8ef07542c05044ec5d8bbae1df083b3f56822c3da63993a113e716f \ - --hash=sha256:810512e2eccdfc3b41eefa3a27402371a3411453a1efc7e9c000318196140fed \ - --hash=sha256:8f2c028954f16ede77902b223a8da2d9cbb3892375b85809a5c3cfb1587960c4 \ - --hash=sha256:9626d0bca1fb387d5791ca36bacbba298c5ef554747b7ebeafefb4564fc83566 \ - --hash=sha256:ac5b58559bb84a95848ed6984eb8013249f90b6bab62aa5acbad876e256002c9 \ - --hash=sha256:ad4953c5eb9c25a56d11a913c2011d7e580a435ef5145f804d98efa14477d390 \ - --hash=sha256:aefedc579ece2f8fb38f876aa7698204ee4c372d0e54f1c1ffa8ca580b54cc60 \ - --hash=sha256:afb2009ac98da274e893e03162f6269398b2b00d947e7057ee2469a921d58135 \ - --hash=sha256:b8a9d46b4df5ae2135a8e8e72b465448ebbc1559997f4f9304a9ecc3413efb5b \ - --hash=sha256:bd73f51957949069573ff783563486339a9285d72e2f36c18e0c1aa9ca7eb190 \ - --hash=sha256:bf9975bda82a99dc935f2ae4c83846d86df8fd6ba179614acac8e686910851da \ - --hash=sha256:c09526488c3a9e8b7a23a388d4974b670a9a3dd40c5c8a61db5593ce9b725bab \ - --hash=sha256:c9945669d3dadf8acb40ec2e57d38c985d8c285ea73af57fc5b09872c516106d \ - --hash=sha256:d13755f8e8445b3870114e5b6240facaa7cb0c3361e54beba3e07fa912a6e12b \ - --hash=sha256:fd918d4e6a4e0c110e2e05be7a7814d10dc1b95872accbf6512b80a109b71ae1 - # via -r build/requirements.in +ml-dtypes==0.5.4 \ + --hash=sha256:0d2ffd05a2575b1519dc928c0b93c06339eb67173ff53acb00724502cda231cf \ + --hash=sha256:11942cbf2cf92157db91e5022633c0d9474d4dfd813a909383bd23ce828a4b7d \ + --hash=sha256:14a4fd3228af936461db66faccef6e4f41c1d82fcc30e9f8d58a08916b1d811f \ + --hash=sha256:19b9a53598f21e453ea2fbda8aa783c20faff8e1eeb0d7ab899309a0053f1483 \ + --hash=sha256:2314892cdc3fcf05e373d76d72aaa15fda9fb98625effa73c1d646f331fcecb7 \ + --hash=sha256:2b857d3af6ac0d39db1de7c706e69c7f9791627209c3d6dedbfca8c7e5faec22 \ + --hash=sha256:304ad47faa395415b9ccbcc06a0350800bc50eda70f0e45326796e27c62f18b6 \ + --hash=sha256:35f29491a3e478407f7047b8a4834e4640a77d2737e0b294d049746507af5175 \ + --hash=sha256:388d399a2152dd79a3f0456a952284a99ee5c93d3e2f8dfe25977511e0515270 \ + --hash=sha256:3bbbe120b915090d9dd1375e4684dd17a20a2491ef25d640a908281da85e73f1 \ + --hash=sha256:3d277bf3637f2a62176f4575512e9ff9ef51d00e39626d9fe4a161992f355af2 \ + --hash=sha256:4381fe2f2452a2d7589689693d3162e876b3ddb0a832cde7a414f8e1adf7eab1 \ + --hash=sha256:4ff7f3e7ca2972e7de850e7b8fcbb355304271e2933dd90814c1cb847414d6e2 \ + --hash=sha256:531eff30e4d368cb6255bc2328d070e35836aa4f282a0fb5f3a0cd7260257298 \ + --hash=sha256:533ce891ba774eabf607172254f2e7260ba5f57bdd64030c9a4fcfbd99815d0d \ + --hash=sha256:557a31a390b7e9439056644cb80ed0735a6e3e3bb09d67fd5687e4b04238d1de \ + --hash=sha256:5a0f68ca8fd8d16583dfa7793973feb86f2fbb56ce3966daf9c9f748f52a2049 \ + --hash=sha256:6a0df4223b514d799b8a1629c65ddc351b3efa833ccf7f8ea0cf654a61d1e35d \ + --hash=sha256:6c7ecb74c4bd71db68a6bea1edf8da8c34f3d9fe218f038814fd1d310ac76c90 \ + --hash=sha256:7c23c54a00ae43edf48d44066a7ec31e05fdc2eee0be2b8b50dd1903a1db94bb \ + --hash=sha256:805cef3a38f4eafae3a5bf9ebdcdb741d0bcfd9e1bd90eb54abd24f928cd2465 \ + --hash=sha256:88c982aac7cb1cbe8cbb4e7f253072b1df872701fcaf48d84ffbb433b6568f24 \ + --hash=sha256:8ab06a50fb9bf9666dd0fe5dfb4676fa2b0ac0f31ecff72a6c3af8e22c063453 \ + --hash=sha256:8c6a2dcebd6f3903e05d51960a8058d6e131fe69f952a5397e5dbabc841b6d56 \ + --hash=sha256:8c760d85a2f82e2bed75867079188c9d18dae2ee77c25a54d60e9cc79be1bc48 \ + --hash=sha256:9ad459e99793fa6e13bd5b7e6792c8f9190b4e5a1b45c63aba14a4d0a7f1d5ff \ + --hash=sha256:9bad06436568442575beb2d03389aa7456c690a5b05892c471215bfd8cf39460 \ + --hash=sha256:a174837a64f5b16cab6f368171a1a03a27936b31699d167684073ff1c4237dac \ + --hash=sha256:a7f7c643e8b1320fd958bf098aa7ecf70623a42ec5154e3be3be673f4c34d900 \ + --hash=sha256:a9b61c19040397970d18d7737375cffd83b1f36a11dd4ad19f83a016f736c3ef \ + --hash=sha256:b4b801ebe0b477be666696bda493a9be8356f1f0057a57f1e35cd26928823e5a \ + --hash=sha256:b95e97e470fe60ed493fd9ae3911d8da4ebac16bd21f87ffa2b7c588bf22ea2c \ + --hash=sha256:bc11d7e8c44a65115d05e2ab9989d1e045125d7be8e05a071a48bc76eb6d6040 \ + --hash=sha256:bfc534409c5d4b0bf945af29e5d0ab075eae9eecbb549ff8a29280db822f34f9 \ + --hash=sha256:c1a953995cccb9e25a4ae19e34316671e4e2edaebe4cf538229b1fc7109087b7 \ + --hash=sha256:cb73dccfc991691c444acc8c0012bee8f2470da826a92e3a20bb333b1a7894e6 \ + --hash=sha256:ce756d3a10d0c4067172804c9cc276ba9cc0ff47af9078ad439b075d1abdc29b \ + --hash=sha256:d81fdb088defa30eb37bf390bb7dde35d3a83ec112ac8e33d75ab28cc29dd8b0 \ + --hash=sha256:f21c9219ef48ca5ee78402d5cc831bd58ea27ce89beda894428bc67a52da5328 + # via + # -r build/requirements.in + # jaxlib + # keras + # tensorflow + # tensorstore mpmath==1.3.0 \ --hash=sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f \ --hash=sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c # via -r build/test-requirements.txt -numpy==2.1.2 ; python_version >= "3.13" \ - --hash=sha256:05b2d4e667895cc55e3ff2b56077e4c8a5604361fc21a042845ea3ad67465aa8 \ - --hash=sha256:12edb90831ff481f7ef5f6bc6431a9d74dc0e5ff401559a71e5e4611d4f2d466 \ - --hash=sha256:13311c2db4c5f7609b462bc0f43d3c465424d25c626d95040f073e30f7570e35 \ - --hash=sha256:13532a088217fa624c99b843eeb54640de23b3414b14aa66d023805eb731066c \ - --hash=sha256:13602b3174432a35b16c4cfb5de9a12d229727c3dd47a6ce35111f2ebdf66ff4 \ - --hash=sha256:1600068c262af1ca9580a527d43dc9d959b0b1d8e56f8a05d830eea39b7c8af6 \ - --hash=sha256:1b8cde4f11f0a975d1fd59373b32e2f5a562ade7cde4f85b7137f3de8fbb29a0 \ - --hash=sha256:1c193d0b0238638e6fc5f10f1b074a6993cb13b0b431f64079a509d63d3aa8b7 \ - --hash=sha256:1ebec5fd716c5a5b3d8dfcc439be82a8407b7b24b230d0ad28a81b61c2f4659a \ - --hash=sha256:242b39d00e4944431a3cd2db2f5377e15b5785920421993770cddb89992c3f3a \ - --hash=sha256:259ec80d54999cc34cd1eb8ded513cb053c3bf4829152a2e00de2371bd406f5e \ - --hash=sha256:2abbf905a0b568706391ec6fa15161fad0fb5d8b68d73c461b3c1bab6064dd62 \ - --hash=sha256:2cbba4b30bf31ddbe97f1c7205ef976909a93a66bb1583e983adbd155ba72ac2 \ - --hash=sha256:2ffef621c14ebb0188a8633348504a35c13680d6da93ab5cb86f4e54b7e922b5 \ - --hash=sha256:30d53720b726ec36a7f88dc873f0eec8447fbc93d93a8f079dfac2629598d6ee \ - --hash=sha256:32e16a03138cabe0cb28e1007ee82264296ac0983714094380b408097a418cfe \ - --hash=sha256:43cca367bf94a14aca50b89e9bc2061683116cfe864e56740e083392f533ce7a \ - --hash=sha256:456e3b11cb79ac9946c822a56346ec80275eaf2950314b249b512896c0d2505e \ - --hash=sha256:4d6ec0d4222e8ffdab1744da2560f07856421b367928026fb540e1945f2eeeaf \ - --hash=sha256:5006b13a06e0b38d561fab5ccc37581f23c9511879be7693bd33c7cd15ca227c \ - --hash=sha256:675c741d4739af2dc20cd6c6a5c4b7355c728167845e3c6b0e824e4e5d36a6c3 \ - --hash=sha256:6cdb606a7478f9ad91c6283e238544451e3a95f30fb5467fbf715964341a8a86 \ - --hash=sha256:6d95f286b8244b3649b477ac066c6906fbb2905f8ac19b170e2175d3d799f4df \ - --hash=sha256:76322dcdb16fccf2ac56f99048af32259dcc488d9b7e25b51e5eca5147a3fb98 \ - --hash=sha256:7c1c60328bd964b53f8b835df69ae8198659e2b9302ff9ebb7de4e5a5994db3d \ - --hash=sha256:860ec6e63e2c5c2ee5e9121808145c7bf86c96cca9ad396c0bd3e0f2798ccbe2 \ - --hash=sha256:8e00ea6fc82e8a804433d3e9cedaa1051a1422cb6e443011590c14d2dea59146 \ - --hash=sha256:9c6c754df29ce6a89ed23afb25550d1c2d5fdb9901d9c67a16e0b16eaf7e2550 \ - --hash=sha256:a26ae94658d3ba3781d5e103ac07a876b3e9b29db53f68ed7df432fd033358a8 \ - --hash=sha256:a65acfdb9c6ebb8368490dbafe83c03c7e277b37e6857f0caeadbbc56e12f4fb \ - --hash=sha256:a7d80b2e904faa63068ead63107189164ca443b42dd1930299e0d1cb041cec2e \ - --hash=sha256:a84498e0d0a1174f2b3ed769b67b656aa5460c92c9554039e11f20a05650f00d \ - --hash=sha256:ab4754d432e3ac42d33a269c8567413bdb541689b02d93788af4131018cbf366 \ - --hash=sha256:ad369ed238b1959dfbade9018a740fb9392c5ac4f9b5173f420bd4f37ba1f7a0 \ - --hash=sha256:b1d0fcae4f0949f215d4632be684a539859b295e2d0cb14f78ec231915d644db \ - --hash=sha256:b42a1a511c81cc78cbc4539675713bbcf9d9c3913386243ceff0e9429ca892fe \ - --hash=sha256:bd33f82e95ba7ad632bc57837ee99dba3d7e006536200c4e9124089e1bf42426 \ - --hash=sha256:bdd407c40483463898b84490770199d5714dcc9dd9b792f6c6caccc523c00952 \ - --hash=sha256:c6eef7a2dbd0abfb0d9eaf78b73017dbfd0b54051102ff4e6a7b2980d5ac1a03 \ - --hash=sha256:c82af4b2ddd2ee72d1fc0c6695048d457e00b3582ccde72d8a1c991b808bb20f \ - --hash=sha256:d666cb72687559689e9906197e3bec7b736764df6a2e58ee265e360663e9baf7 \ - --hash=sha256:d7bf0a4f9f15b32b5ba53147369e94296f5fffb783db5aacc1be15b4bf72f43b \ - --hash=sha256:d82075752f40c0ddf57e6e02673a17f6cb0f8eb3f587f63ca1eaab5594da5b17 \ - --hash=sha256:da65fb46d4cbb75cb417cddf6ba5e7582eb7bb0b47db4b99c9fe5787ce5d91f5 \ - --hash=sha256:e2b49c3c0804e8ecb05d59af8386ec2f74877f7ca8fd9c1e00be2672e4d399b1 \ - --hash=sha256:e585c8ae871fd38ac50598f4763d73ec5497b0de9a0ab4ef5b69f01c6a046142 \ - --hash=sha256:e8d3ca0a72dd8846eb6f7dfe8f19088060fcb76931ed592d29128e0219652884 \ - --hash=sha256:ef444c57d664d35cac4e18c298c47d7b504c66b17c2ea91312e979fcfbdfb08a \ - --hash=sha256:f1eb068ead09f4994dec71c24b2844f1e4e4e013b9629f812f292f04bd1510d9 \ - --hash=sha256:f2ded8d9b6f68cc26f8425eda5d3877b47343e68ca23d0d0846f4d312ecaa445 \ - --hash=sha256:f751ed0a2f250541e19dfca9f1eafa31a392c71c832b6bb9e113b10d050cb0f1 \ - --hash=sha256:faa88bc527d0f097abdc2c663cddf37c05a1c2f113716601555249805cf573f1 \ - --hash=sha256:fc44e3c68ff00fd991b59092a54350e6e4911152682b4782f68070985aa9e648 +namex==0.1.0 \ + --hash=sha256:117f03ccd302cc48e3f5c58a296838f6b89c83455ab8683a1e85f2a430aa4306 \ + --hash=sha256:e2012a474502f1e2251267062aae3114611f07df4224b6e06334c57b0f2ce87c + # via keras +numpy==2.1.3 ; python_version == "3.13" \ + --hash=sha256:016d0f6f5e77b0f0d45d77387ffa4bb89816b57c835580c3ce8e099ef830befe \ + --hash=sha256:02135ade8b8a84011cbb67dc44e07c58f28575cf9ecf8ab304e51c05528c19f0 \ + --hash=sha256:08788d27a5fd867a663f6fc753fd7c3ad7e92747efc73c53bca2f19f8bc06f48 \ + --hash=sha256:0d30c543f02e84e92c4b1f415b7c6b5326cbe45ee7882b6b77db7195fb971e3a \ + --hash=sha256:0fa14563cc46422e99daef53d725d0c326e99e468a9320a240affffe87852564 \ + --hash=sha256:13138eadd4f4da03074851a698ffa7e405f41a0845a6b1ad135b81596e4e9958 \ + --hash=sha256:14e253bd43fc6b37af4921b10f6add6925878a42a0c5fe83daee390bca80bc17 \ + --hash=sha256:15cb89f39fa6d0bdfb600ea24b250e5f1a3df23f901f51c8debaa6a5d122b2f0 \ + --hash=sha256:17ee83a1f4fef3c94d16dc1802b998668b5419362c8a4f4e8a491de1b41cc3ee \ + --hash=sha256:2312b2aa89e1f43ecea6da6ea9a810d06aae08321609d8dc0d0eda6d946a541b \ + --hash=sha256:2564fbdf2b99b3f815f2107c1bbc93e2de8ee655a69c261363a1172a79a257d4 \ + --hash=sha256:3522b0dfe983a575e6a9ab3a4a4dfe156c3e428468ff08ce582b9bb6bd1d71d4 \ + --hash=sha256:4394bc0dbd074b7f9b52024832d16e019decebf86caf909d94f6b3f77a8ee3b6 \ + --hash=sha256:45966d859916ad02b779706bb43b954281db43e185015df6eb3323120188f9e4 \ + --hash=sha256:4d1167c53b93f1f5d8a139a742b3c6f4d429b54e74e6b57d0eff40045187b15d \ + --hash=sha256:4f2015dfe437dfebbfce7c85c7b53d81ba49e71ba7eadbf1df40c915af75979f \ + --hash=sha256:50ca6aba6e163363f132b5c101ba078b8cbd3fa92c7865fd7d4d62d9779ac29f \ + --hash=sha256:50d18c4358a0a8a53f12a8ba9d772ab2d460321e6a93d6064fc22443d189853f \ + --hash=sha256:5641516794ca9e5f8a4d17bb45446998c6554704d888f86df9b200e66bdcce56 \ + --hash=sha256:576a1c1d25e9e02ed7fa5477f30a127fe56debd53b8d2c89d5578f9857d03ca9 \ + --hash=sha256:6a4825252fcc430a182ac4dee5a505053d262c807f8a924603d411f6718b88fd \ + --hash=sha256:72dcc4a35a8515d83e76b58fdf8113a5c969ccd505c8a946759b24e3182d1f23 \ + --hash=sha256:747641635d3d44bcb380d950679462fae44f54b131be347d5ec2bce47d3df9ed \ + --hash=sha256:762479be47a4863e261a840e8e01608d124ee1361e48b96916f38b119cfda04a \ + --hash=sha256:78574ac2d1a4a02421f25da9559850d59457bac82f2b8d7a44fe83a64f770098 \ + --hash=sha256:825656d0743699c529c5943554d223c021ff0494ff1442152ce887ef4f7561a1 \ + --hash=sha256:8637dcd2caa676e475503d1f8fdb327bc495554e10838019651b76d17b98e512 \ + --hash=sha256:96fe52fcdb9345b7cd82ecd34547fca4321f7656d500eca497eb7ea5a926692f \ + --hash=sha256:973faafebaae4c0aaa1a1ca1ce02434554d67e628b8d805e61f874b84e136b09 \ + --hash=sha256:996bb9399059c5b82f76b53ff8bb686069c05acc94656bb259b1d63d04a9506f \ + --hash=sha256:a38c19106902bb19351b83802531fea19dee18e5b37b36454f27f11ff956f7fc \ + --hash=sha256:a6b46587b14b888e95e4a24d7b13ae91fa22386c199ee7b418f449032b2fa3b8 \ + --hash=sha256:a9f7f672a3388133335589cfca93ed468509cb7b93ba3105fce780d04a6576a0 \ + --hash=sha256:aa08e04e08aaf974d4458def539dece0d28146d866a39da5639596f4921fd761 \ + --hash=sha256:b0df3635b9c8ef48bd3be5f862cf71b0a4716fa0e702155c45067c6b711ddcef \ + --hash=sha256:b47fbb433d3260adcd51eb54f92a2ffbc90a4595f8970ee00e064c644ac788f5 \ + --hash=sha256:baed7e8d7481bfe0874b566850cb0b85243e982388b7b23348c6db2ee2b2ae8e \ + --hash=sha256:bc6f24b3d1ecc1eebfbf5d6051faa49af40b03be1aaa781ebdadcbc090b4539b \ + --hash=sha256:c006b607a865b07cd981ccb218a04fc86b600411d83d6fc261357f1c0966755d \ + --hash=sha256:c181ba05ce8299c7aa3125c27b9c2167bca4a4445b7ce73d5febc411ca692e43 \ + --hash=sha256:c7662f0e3673fe4e832fe07b65c50342ea27d989f92c80355658c7f888fcc83c \ + --hash=sha256:c80e4a09b3d95b4e1cac08643f1152fa71a0a821a2d4277334c88d54b2219a41 \ + --hash=sha256:c894b4305373b9c5576d7a12b473702afdf48ce5369c074ba304cc5ad8730dff \ + --hash=sha256:d7aac50327da5d208db2eec22eb11e491e3fe13d22653dce51b0f4109101b408 \ + --hash=sha256:d89dd2b6da69c4fff5e39c28a382199ddedc3a5be5390115608345dec660b9e2 \ + --hash=sha256:d9beb777a78c331580705326d2367488d5bc473b49a9bc3036c154832520aca9 \ + --hash=sha256:dc258a761a16daa791081d026f0ed4399b582712e6fc887a95af09df10c5ca57 \ + --hash=sha256:e14e26956e6f1696070788252dcdff11b4aca4c3e8bd166e0df1bb8f315a67cb \ + --hash=sha256:e6988e90fcf617da2b5c78902fe8e668361b43b4fe26dbf2d7b0f8034d4cafb9 \ + --hash=sha256:e711e02f49e176a01d0349d82cb5f05ba4db7d5e7e0defd026328e5cfb3226d3 \ + --hash=sha256:ea4dedd6e394a9c180b33c2c872b92f7ce0f8e7ad93e9585312b0c5a04777a4a \ + --hash=sha256:ecc76a9ba2911d8d37ac01de72834d8849e55473457558e12995f4cd53e778e0 \ + --hash=sha256:f55ba01150f52b1027829b50d70ef1dafd9821ea82905b63936668403c3b471e \ + --hash=sha256:f653490b33e9c3a4c1c01d41bc2aef08f9475af51146e4a7710c450cf9761598 \ + --hash=sha256:fa2d1337dc61c8dc417fbccf20f6d1e139896a30721b7f1e832b2bb6ef4eb6c4 # via - # -r build/requirements.in + # -r build/nonfreethreading-requirements.txt # contourpy + # h5py + # jaxlib + # keras # matplotlib # ml-dtypes + # numpy-typing-compat + # optype # scipy -nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \ - --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \ - --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ - --hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9 + # tensorboard + # tensorflow + # tensorstore +numpy-typing-compat==20251206.2.1 \ + --hash=sha256:703ae61be7877ab0af562298776b89eae609a3985414d92011a39a42350b42e1 \ + --hash=sha256:8a868da29e8d076c2aaef8ea9ebb602af917c9752063cfe7c95d6cb60c7b9ea3 + # via optype +nvidia-cublas==13.1.0.3 ; sys_platform == "linux" \ + --hash=sha256:2a3b94a37def342471c59fad7856caee4926809a72dd5270155d6a31b5b277be \ + --hash=sha256:c86fc7f7ae36d7528288c5d88098edcb7b02c633d262e7ddbb86b0ad91be5df2 \ + --hash=sha256:ee8722c1f0145ab246bccb9e452153b5e0515fd094c3678df50b2a0888b8b171 # via - # -r build/test-requirements.txt + # -r build/nvidia-requirements.txt + # nvidia-cudnn-cu13 + # nvidia-cusolver +nvidia-cublas-cu12==12.9.1.4 ; sys_platform == "linux" \ + --hash=sha256:1e5fee10662e6e52bd71dec533fbbd4971bb70a5f24f3bc3793e5c2e9dc640bf \ + --hash=sha256:453611eb21a7c1f2c2156ed9f3a45b691deda0440ec550860290dc901af5b4c2 \ + --hash=sha256:7a950dae01add3b415a5a5cdc4ec818fb5858263e9cca59004bb99fdbbd3a5d6 + # via + # -r build/nvidia-requirements.txt # nvidia-cudnn-cu12 # nvidia-cusolver-cu12 -nvidia-cuda-cupti-cu12==12.8.57 ; sys_platform == "linux" \ - --hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \ - --hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \ - --hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317 - # via -r build/test-requirements.txt -nvidia-cuda-nvcc-cu12==12.8.61 ; sys_platform == "linux" \ - --hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \ - --hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ - --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b - # via -r build/test-requirements.txt -nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ - --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ - --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ - --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 - # via -r build/test-requirements.txt -nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \ - --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ - --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ - --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef - # via -r build/test-requirements.txt -nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ - --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ - --hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \ - --hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8 - # via -r build/test-requirements.txt -nvidia-cusolver-cu12==11.7.2.55 ; sys_platform == "linux" \ - --hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \ - --hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \ - --hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac - # via -r build/test-requirements.txt -nvidia-cusparse-cu12==12.5.7.53 ; sys_platform == "linux" \ - --hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \ - --hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \ - --hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1 +nvidia-cuda-crt==13.0.88 ; sys_platform == "linux" \ + --hash=sha256:2c8043c7c9e02492716426e9919fc78d2c5b3b2a7a768a88e952676b08aa55a4 \ + --hash=sha256:31e02c52916804ca15e31f272a96181d8fadaf40c4c82a77a6f78071a22eccf3 \ + --hash=sha256:ee2ea2a97073e02ee62bb27841f437332be2c248e3eac013df07997ada39c003 # via - # -r build/test-requirements.txt + # -r build/nvidia-requirements.txt + # nvidia-cuda-nvcc +nvidia-cuda-cupti==13.0.85 ; sys_platform == "linux" \ + --hash=sha256:4eb01c08e859bf924d222250d2e8f8b8ff6d3db4721288cf35d14252a4d933c8 \ + --hash=sha256:683f58d301548deeefcb8f6fac1b8d907691b9d8b18eccab417f51e362102f00 \ + --hash=sha256:796bd679890ee55fb14a94629b698b6db54bcfd833d391d5e94017dd9d7d3151 + # via -r build/nvidia-requirements.txt +nvidia-cuda-cupti-cu12==12.9.79 ; sys_platform == "linux" \ + --hash=sha256:096bcf334f13e1984ba36685ad4c1d6347db214de03dbb6eebb237b41d9d934f \ + --hash=sha256:1848a9380067560d5bee10ed240eecc22991713e672c0515f9c3d9396adf93c8 \ + --hash=sha256:791853b030602c6a11d08b5578edfb957cadea06e9d3b26adbf8d036135a4afe + # via -r build/nvidia-requirements.txt +nvidia-cuda-nvcc==13.0.88 ; sys_platform == "linux" \ + --hash=sha256:56fe502eb77625a12f25172caa3cdddb4e4c8ba2c8c17dba44b164761b380f03 \ + --hash=sha256:7c3a32c8ca9866addfd784da363ddee2f6874d560027a296f583e86a61f2d543 \ + --hash=sha256:c7ff28f86a24effdc6c034fa15230c549a273e4771b10a7fec14996f8cf3307f + # via -r build/nvidia-requirements.txt +nvidia-cuda-nvcc-cu12==12.9.86 ; sys_platform == "linux" \ + --hash=sha256:44e1eca4d08926193a558d2434b1bf83d57b4d5743e0c431c0c83d51da1df62b \ + --hash=sha256:5d6a0d32fdc7ea39917c20065614ae93add6f577d840233237ff08e9a38f58f0 \ + --hash=sha256:8ed7f0b17dea662755395be029376db3b94fed5cbb17c2d35cc866c5b1b84099 + # via -r build/nvidia-requirements.txt +nvidia-cuda-nvrtc==13.0.88 ; sys_platform == "linux" \ + --hash=sha256:6bcd4e7f8e205cbe644f5a98f2f799bef9556fefc89dd786e79a16312ce49872 \ + --hash=sha256:ad9b6d2ead2435f11cbb6868809d2adeeee302e9bb94bcf0539c7a40d80e8575 \ + --hash=sha256:d27f20a0ca67a4bb34268a5e951033496c5b74870b868bacd046b1b8e0c3267b + # via -r build/nvidia-requirements.txt +nvidia-cuda-nvrtc-cu12==12.9.86 ; sys_platform == "linux" \ + --hash=sha256:096d4de6bda726415dfaf3198d4f5c522b8e70139c97feef5cd2ca6d4cd9cead \ + --hash=sha256:210cf05005a447e29214e9ce50851e83fc5f4358df8b453155d5e1918094dcb4 \ + --hash=sha256:72972ebdcf504d69462d3bcd67e7b81edd25d0fb85a2c46d3ea3517666636349 + # via -r build/nvidia-requirements.txt +nvidia-cuda-runtime==13.0.96 ; sys_platform == "linux" \ + --hash=sha256:7f82250d7782aa23b6cfe765ecc7db554bd3c2870c43f3d1821f1d18aebf0548 \ + --hash=sha256:ef9bcbe90493a2b9d810e43d249adb3d02e98dd30200d86607d8d02687c43f55 \ + --hash=sha256:f79298c8a098cec150a597c8eba58ecdab96e3bdc4b9bc4f9983635031740492 + # via + # -r build/nvidia-requirements.txt + # nvidia-cuda-nvcc +nvidia-cuda-runtime-cu12==12.9.79 ; sys_platform == "linux" \ + --hash=sha256:25bba2dfb01d48a9b59ca474a1ac43c6ebf7011f1b0b8cc44f54eb6ac48a96c3 \ + --hash=sha256:83469a846206f2a733db0c42e223589ab62fd2fabac4432d2f8802de4bded0a4 \ + --hash=sha256:8e018af8fa02363876860388bd10ccb89eb9ab8fb0aa749aaf58430a9f7c4891 + # via -r build/nvidia-requirements.txt +nvidia-cudnn-cu12==9.16.0.29 ; sys_platform == "linux" \ + --hash=sha256:142e2bd646a4573ab17d61a24c6359155cdfe1f34c67fc305b71222a7ae45b8e \ + --hash=sha256:4b09c43096db582f110c5572d0bcbd98b30d709e860a8f73c6c3846baa83b8d2 \ + --hash=sha256:78d05b4434dacc7dd9bc903d5c33a2f28a5f0064d02568ef7b2418f89f6c5922 + # via -r build/nvidia-requirements.txt +nvidia-cudnn-cu13==9.16.0.29 ; sys_platform == "linux" \ + --hash=sha256:6349bc8769369a91611c5e2ce5c2e510e61848c245099c31e870d2cdce0ab90d \ + --hash=sha256:79dc1bfe8c1a780cf4eb7b334d14d7927576d6dd8823f8e2769911af30fd4da3 \ + --hash=sha256:faafa46e2e7dd844bbcf06b6adec3fa66924987f2fb21bf67f5c6fd697c74a64 + # via -r build/nvidia-requirements.txt +nvidia-cufft==12.0.0.61 ; sys_platform == "linux" \ + --hash=sha256:2708c852ef8cd89d1d2068bdbece0aa188813a0c934db3779b9b1faa8442e5f5 \ + --hash=sha256:2abce5b39d2f5ae12730fb7e5db6696533e36c26e2d3e8fd1750bdd2853364eb \ + --hash=sha256:6c44f692dce8fd5ffd3e3df134b6cdb9c2f72d99cf40b62c32dde45eea9ddad3 + # via -r build/nvidia-requirements.txt +nvidia-cufft-cu12==11.4.1.4 ; sys_platform == "linux" \ + --hash=sha256:1a28c9b12260a1aa7a8fd12f5ebd82d027963d635ba82ff39a1acfa7c4c0fbcf \ + --hash=sha256:8e5bfaac795e93f80611f807d42844e8e27e340e0cde270dcb6c65386d795b80 \ + --hash=sha256:c67884f2a7d276b4b80eb56a79322a95df592ae5e765cf1243693365ccab4e28 + # via -r build/nvidia-requirements.txt +nvidia-cusolver==12.0.4.66 ; sys_platform == "linux" \ + --hash=sha256:02c2457eaa9e39de20f880f4bd8820e6a1cfb9f9a34f820eb12a155aa5bc92d2 \ + --hash=sha256:0a759da5dea5c0ea10fd307de75cdeb59e7ea4fcb8add0924859b944babf1112 \ + --hash=sha256:16515bd33a8e76bb54d024cfa068fa68d30e80fc34b9e1090813ea9362e0cb65 + # via -r build/nvidia-requirements.txt +nvidia-cusolver-cu12==11.7.5.82 ; sys_platform == "linux" \ + --hash=sha256:15da72d1340d29b5b3cf3fd100e3cd53421dde36002eda6ed93811af63c40d88 \ + --hash=sha256:62efa83e4ace59a4c734d052bb72158e888aa7b770e1a5f601682f16fe5b4fd2 \ + --hash=sha256:77666337237716783c6269a658dea310195cddbd80a5b2919b1ba8735cec8efd + # via -r build/nvidia-requirements.txt +nvidia-cusparse==12.6.3.3 ; sys_platform == "linux" \ + --hash=sha256:2b3c89c88d01ee0e477cb7f82ef60a11a4bcd57b6b87c33f789350b59759360b \ + --hash=sha256:80bcc4662f23f1054ee334a15c72b8940402975e0eab63178fc7e670aa59472c \ + --hash=sha256:cbcf42feb737bd7ec15b4c0a63e62351886bd3f975027b8815d7f720a2b5ea79 + # via + # -r build/nvidia-requirements.txt + # nvidia-cusolver +nvidia-cusparse-cu12==12.5.10.65 ; sys_platform == "linux" \ + --hash=sha256:221c73e7482dd93eda44e65ce567c031c07e2f93f6fa0ecd3ba876a195023e83 \ + --hash=sha256:73060ce019ac064a057267c585bf1fd5a353734151f87472ff02b2c5c9984e78 \ + --hash=sha256:9e487468a22a1eaf1fbd1d2035936a905feb79c4ce5c2f67626764ee4f90227c + # via + # -r build/nvidia-requirements.txt # nvidia-cusolver-cu12 -nvidia-nccl-cu12==2.25.1 ; sys_platform == "linux" \ - --hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ - --hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 - # via -r build/test-requirements.txt -nvidia-nvjitlink-cu12==12.8.61 ; sys_platform == "linux" \ - --hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \ - --hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \ - --hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0 +nvidia-nccl-cu12==2.28.9 ; sys_platform == "linux" \ + --hash=sha256:485776daa8447da5da39681af455aa3b2c2586ddcf4af8772495e7c532c7e5ab \ + --hash=sha256:50a36e01c4a090b9f9c47d92cec54964de6b9fcb3362d0e19b8ffc6323c21b60 + # via -r build/nvidia-requirements.txt +nvidia-nccl-cu13==2.28.9 ; sys_platform == "linux" \ + --hash=sha256:01c873ba1626b54caa12272ed228dc5b2781545e0ae8ba3f432a8ef1c6d78643 \ + --hash=sha256:e4553a30f34195f3fa1da02a6da3d6337d28f2003943aa0a3d247bbc25fefc42 + # via -r build/nvidia-requirements.txt +nvidia-nvjitlink==13.0.88 ; sys_platform == "linux" \ + --hash=sha256:13a74f429e23b921c1109976abefacc69835f2f433ebd323d3946e11d804e47b \ + --hash=sha256:634e96e3da9ef845ae744097a1f289238ecf946ce0b82e93cdce14b9782e682f \ + --hash=sha256:e931536ccc7d467a98ba1d8b89ff7fa7f1fa3b13f2b0069118cd7f47bff07d0c # via - # -r build/test-requirements.txt + # -r build/nvidia-requirements.txt + # nvidia-cufft + # nvidia-cusolver + # nvidia-cusparse +nvidia-nvjitlink-cu12==12.9.86 ; sys_platform == "linux" \ + --hash=sha256:994a05ef08ef4b0b299829cde613a424382aff7efb08a7172c1fa616cc3af2ca \ + --hash=sha256:cc6fcec260ca843c10e34c936921a1c426b351753587fdd638e8cff7b16bb9db \ + --hash=sha256:e3f1171dbdc83c5932a45f0f4c99180a70de9bd2718c1ab77d14104f6d7147f9 + # via + # -r build/nvidia-requirements.txt # nvidia-cufft-cu12 # nvidia-cusolver-cu12 # nvidia-cusparse-cu12 +nvidia-nvshmem-cu12==3.4.5 ; sys_platform == "linux" \ + --hash=sha256:042f2500f24c021db8a06c5eec2539027d57460e1c1a762055a6554f72c369bd \ + --hash=sha256:0b48363fc6964dede448029434c6abed6c5e37f823cb43c3bcde7ecfc0457e15 + # via -r build/nvidia-requirements.txt +nvidia-nvshmem-cu13==3.4.5 ; sys_platform == "linux" \ + --hash=sha256:290f0a2ee94c9f3687a02502f3b9299a9f9fe826e6d0287ee18482e78d495b80 \ + --hash=sha256:6dc2a197f38e5d0376ad52cd1a2a3617d3cdc150fd5966f4aee9bcebb1d68fe9 + # via -r build/nvidia-requirements.txt +nvidia-nvvm==13.0.88 ; sys_platform == "linux" \ + --hash=sha256:2ef0db7849e476d3b2fc3c09b27bdd79bd7ea8ce58cd9c86553d64ea40844ba0 \ + --hash=sha256:c4376a291d72d22a315d9d2f69bdae8f8cd83a627f75bad395cee49a0fe65dc1 \ + --hash=sha256:c5f41ffeb6466944a026dfa5317d7d85355c119bbec279205d22f1869d1054e0 + # via + # -r build/nvidia-requirements.txt + # nvidia-cuda-nvcc opt-einsum==3.4.0 \ --hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \ --hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac # via # -r build/requirements.in - # -r build/test-requirements.txt -packaging==24.1 \ - --hash=sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002 \ - --hash=sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124 + # tensorflow +optree==0.18.0 \ + --hash=sha256:01b79aaee544adf5bfa573db32b943030dfeb9fd1c6e7a97aa417db56a8127e7 \ + --hash=sha256:02d9999840fabef85a6b22e757f336d5591f712f99c710d8b232d52e53115314 \ + --hash=sha256:056894ce6242cd1c7fed71325a7d9f633b2d3b4420c52af48f6a0c4560d74ca1 \ + --hash=sha256:057b983a9526645133553184bed7090bb07855df986abd9e99c456922045c6bc \ + --hash=sha256:07c5f64783ad0f0f80e61c25f276ce79b47deda83ed7956a4a9af6385fe8f60d \ + --hash=sha256:090a3f0ccafa0fe99d71e7d974ae52ff966ac26c409ec41f96556b96646054ef \ + --hash=sha256:0959bac58631e64e2ac6349cc284b37872c24f353b3d73b4682202a431f07d76 \ + --hash=sha256:0d25941de1acba176305dbdeb931dea6143b30d64ebdc5bfea2bfc12ef9e2b0a \ + --hash=sha256:0e0dbe995241efe70cfb522e89c1a7c968216926725a0e5e20cc72bd5d0311b1 \ + --hash=sha256:10f29662d637b80363dc620da46ddc58def7acf7935e20595b23e216ea912367 \ + --hash=sha256:1545c68299c0ce600e4ea1bc9112765dc4afe9a0b8ab43f955df6566bf78db42 \ + --hash=sha256:1b75e083137f361377ff8d70df885ab3a1cf8980e4019e3f311237579adadb64 \ + --hash=sha256:1db0a6497203a13063a8f044ae751dd5d8253cb815359270c38de0e4c9f8bed5 \ + --hash=sha256:1f19867b02a547fc9f11d27c0413e7483cef89699e16f3b9e8af73a9b25e6061 \ + --hash=sha256:1f674e34202383f8b42fa9335f13bedfb6b6f019c66e1f41034929e4be203423 \ + --hash=sha256:20536964ba2458f166c1e8ab25951e3fc0a5056b651bd08f16be99bb3ffed54a \ + --hash=sha256:27611c6c122745a003b5be7aedba49ef86e9fef46d743c234596de0bde6dc679 \ + --hash=sha256:27b1d0cadcf4627c98abbbdce912dbc2243f5687f3c7df39963b793c89321c65 \ + --hash=sha256:289b184cc41dfc400a30db6207ec997884d14540aae2cba10cb88dc7ebaae2a1 \ + --hash=sha256:2b5cfb5fc643f16d3a7d957807e55a937dce07566c49ccc4aa71b01064c56758 \ + --hash=sha256:3014537ff7e4e091ee46e57976f7d95c52f66a0e3eb5ebcbe0de0d924504b58e \ + --hash=sha256:30a2636279bdc805c8e154a0f346bcf704626b831ff44724d305fb72c90b7389 \ + --hash=sha256:30f95279188f6b9300e17c1557989baa991c2d6f519013bd8fea13462a0e6a45 \ + --hash=sha256:30fefc84975ac41d9075993196c64ce0c240510f0539cff121d63b709e03846f \ + --hash=sha256:31539dec60af84e16e99574634811d38e34e1fb381f40d6f489a2e582bf41f03 \ + --hash=sha256:328857d7a35129904b21164f6b0c2ff1d728ad1f5838589c5f437a16c94213c8 \ + --hash=sha256:3804fb6ddc923855db2dc4805b4524c66e00f1ef30b166be4aadd52822b13e06 \ + --hash=sha256:382e5ca02cbd5b20d713d4da189a8613f828832e2af57ccbe04a9c6b0bd9497e \ + --hash=sha256:385bd727cc7bd3c01bd6204028ac2adce8a8f622c296053d9df434aa0e30b01f \ + --hash=sha256:421b839c7ff30df5791e66c89b2e9c2f68191dd6a5d6927c32bcc6b887090df8 \ + --hash=sha256:446c46c53cb8f13abcc0d7dd1989d59bb059953c122fe9901ef53de7fb38b33e \ + --hash=sha256:4cc92339899acb685ee718fd22b25069dfa7be038c63274c54481d54ccc2f9e2 \ + --hash=sha256:4eb146711d4cd0876bf93e0118d3e74050b6f633d756c269ce7cda907281b499 \ + --hash=sha256:51e2cd9ac7fecfd5f6f56ce69f4f805553c226a2744810175959eb408101513c \ + --hash=sha256:55a2ccd121fccc9df961e982db2f4e8f2b4f7015e814ef70b1140514cdffe214 \ + --hash=sha256:56bb19ff827c9a443202b52bf103705ce96ef14d045e0a30d0d7ee7dbcef6a0d \ + --hash=sha256:571b732229d7b2e7a2215f57586f8ec0140e07c0faea916e456cbbfa819e56cb \ + --hash=sha256:5b126c34b459ef4f10f3a4d7d222416d9102b3c5a76b39f346c611792f144821 \ + --hash=sha256:5b75e32c191e4b8cf42a8aa854ed264df82936136c0bcad77be44605da41cdfc \ + --hash=sha256:5bc1221068a58175e0ad62afc199893f77c653206673a5552992a604c66fb77e \ + --hash=sha256:5e669f98b9af9f66144c7ae09912d0367ac3182abe016f67cdd15cb45e13c923 \ + --hash=sha256:66f142c743732cd4e630ea84415f654a00c792793c7f80d4511167f0f89796a6 \ + --hash=sha256:6fc9f8acde3bb561b2034e96079507fbe6d4624058fe204161eb8ef29f961296 \ + --hash=sha256:7172b16e87c87160475275e4bfaa6e4067ccde184d2cca65ba25a402a8ed7758 \ + --hash=sha256:71ca2fcad8972ba56d6cfffbcd962f45f5d4bc04182f23d66154b38c2eb37de3 \ + --hash=sha256:72fa79be4d6515682417f103ae759a22345439eb1319886be936029215ee00dc \ + --hash=sha256:7699957183f8d45402edd6266e175510317f5fcd7f0e623510f2eb7e1ebfc667 \ + --hash=sha256:79bbe14d6cad81f5840958589daa1b836864ada40031712a446dce8129917efd \ + --hash=sha256:7ae6945f68771b1389ee46a1778e779f4ad76bca9306f3e39eb397f9a0dd2753 \ + --hash=sha256:80d971060c888c3989132b7e75dfb50848636d41bc931af1b93fe2019fba469c \ + --hash=sha256:80f28e4666aad66e5e20bdc2c47b5bf320250bb5407b3a39dfb1772787a7068f \ + --hash=sha256:81e755124b77e766166c9d05206b90c68f234f425ad2e3c8a6c96f0db548c67b \ + --hash=sha256:86f5bf05ad236f666e5395e989d6ac2cbfd02556526703e6c6f0a594c7fa081f \ + --hash=sha256:895f23a4cd8aee2c2464efdad2d9bde28a2aaabee634c96423a933f40e74a67e \ + --hash=sha256:89d5156f8a0a3792701e1c31473eb307f0b45696f48dc51d721f1bfe0c3a950f \ + --hash=sha256:89e81afb11792d13d3777b503c6f21ec17b1a3b7de69cde1ae2c5471bcdcd4a0 \ + --hash=sha256:8a2003fab79694e04b5f260628511e441c248b46a9fc46138e2424038ac04ada \ + --hash=sha256:8a4ca121b6fc6b04300fa225fe6c31897e424db0d92691875af326f8c4e1cead \ + --hash=sha256:8a901666afc2d7a8d0c20decc8079763e3313457ee67210382162d90163c0007 \ + --hash=sha256:8b9ad4a01a1346b11acc574b7f932dea1a7c7ab31d93546a7540a1f02b3e724a \ + --hash=sha256:8d88c00c70b5914904feaf8f505f3512c2f3f4493dbbd93951fcdddc85dcfe8c \ + --hash=sha256:9104fc8915890e7292e5833fc677e4749607c67aa3cf8884677267078201c2f3 \ + --hash=sha256:9460cba62e941626beb75c99a803373b38a52136d5f1932fcdfdcede1df6f2ef \ + --hash=sha256:94983b3aa31ee401d2ac77ba570a3157d83f9508cfbb006095a48770e0a1c5ca \ + --hash=sha256:9b1e7e8f9ddc85f05d542b74157bdb73ed0e49aded67d1775f721fcd6eb9be94 \ + --hash=sha256:9d4b9d8c7e9335120ecf222d817699d17de743ad118080fb40467c367f009143 \ + --hash=sha256:a479fa25b6e2430e530d00f0c27a55e15ecb9de8ad2d0aec3d40b680e2d6df64 \ + --hash=sha256:a5c213a291c798139ed9ff80aec4bfcd2ac8f001bc015a9cdeb78457e9687dd3 \ + --hash=sha256:a63df296fec376c5cd08298a85109db4a130f4cc8df15916fc92d44ef6068937 \ + --hash=sha256:a74c45f04def041504bd21682eaf7f359f1a50dc7cf42b548b6f19aab50596bd \ + --hash=sha256:ad428ccdb2a40804919880dfe8d2a3021fd4418be15ea7ecb8434ab249badf9f \ + --hash=sha256:b0986ff1267a3b44d3ed76c3efb8b7239371444143f6e0d79f9dd23dbe02c7f9 \ + --hash=sha256:b45d7172c67fc8d2b69f77b384998b39793ee91f8b3b46c609297b781fb7eea5 \ + --hash=sha256:b4da3223c5b4cf694822752d0fbb6bf34c3f41648af1bd1b443cc3d68cc55106 \ + --hash=sha256:b7aa0de08bbbfcef6e49c107f9f397f5d4742548500f16e3e6c5e0b9e4ff0faa \ + --hash=sha256:b8adc912ecb6e4fd9df227ded66efaa6702f46a98e1403554be3c9c51d0ca920 \ + --hash=sha256:ba23caafd0e0c911bb7eab54e5cf69644af864d153e4b2abdab83ff0ef357ba1 \ + --hash=sha256:bda4572392ac1dff3fc67b6d9a4b1084e1637972e8135ad3788b4ce7cf0a90f5 \ + --hash=sha256:c017539e1196ea08f20aea3a4c473f758149b851678edd3d15773b4326decf83 \ + --hash=sha256:c1f20e8754abe312a701ee00d071ddd8502e9d97ca38fbc56204d14a9ffcb41c \ + --hash=sha256:c8841d44f3648b0662e99fc39ef8c248726ddfb4d1bfce4bdba982e51bb7e3f8 \ + --hash=sha256:cbb083a15ea968ad99e7da17d24632348d69e26534e83c69941f3020ed7536eb \ + --hash=sha256:cde70c97e4cc4e997e8fda2266e40a9bff7679c72ab4af6e15e81748a12882cc \ + --hash=sha256:cfa2e16993ba47e671a4e7ee1ad805f67b8d6744eb30a9d27ea0b07b3b7a22ed \ + --hash=sha256:d20765efa494a80a8fd91c4de8890f34de8e9f234da5516e8f34f55703cfb93d \ + --hash=sha256:d2844478690b5892159df0b2500e9d146dc8d3aa5b44e4564d05787b7330eca3 \ + --hash=sha256:d569730b2647c51a5ee68d67198aa9a78c7a55563d57b8cc1ca8d8c8377e7621 \ + --hash=sha256:daab231cf768937ce4675376ea3e214d399116d9867a6737372c31c58630bdfc \ + --hash=sha256:db00c604c1ae452f6092293bf230984d4f6cbb3ad905a9991e8cf680fd7d1523 \ + --hash=sha256:e058cc51d9d57b45801060af9f74765b95bedfc59fd6df1c7489ae0825126be5 \ + --hash=sha256:e28024e6e343353285cf99ae9c74210f0e89e47b2f0f3af7c72c4a9e89dc3ebc \ + --hash=sha256:e4a468ae1541614b5aa7b4f00254bce005ab7572fbb1fc764af4ee17d90fde7b \ + --hash=sha256:ea357657143f364a764b63b2b1ce12d77156d48a1f32def990b696d755acb629 \ + --hash=sha256:efd162e3bfc7812d75ebf2d0fb2783daee2407a92155af8a90650a6b0fa9342e \ + --hash=sha256:f02faeda66d531dc5f5356589afcf2a6bc41c8d00bc903efab60f9a2182b140d \ + --hash=sha256:f04286908654ffb05455254ebf72fe69473fc4560fc7ea49410df94dea6783a2 \ + --hash=sha256:f5197f864630162f008f5dfad3fceef32553c0fa7639eee1b8e280d924ed678e \ + --hash=sha256:f81f5340c8df50662abaf753ab07095901e40b934efb27da50032a4ae71c5a97 \ + --hash=sha256:fa8e3878a1857761d64f08a23b32140d29754a53f85f7c87186ced2b5b1b49cb \ + --hash=sha256:ff7326f36ed70d84c3fd62fb39bc6858f699640b8ab238c3cb8dafe1e200af59 + # via keras +optype[numpy]==0.15.0 \ + --hash=sha256:457d6ca9e7da19967ec16d42bdf94e240b33b5d70a56fbbf5b427e5ea39cf41e \ + --hash=sha256:caba40ece9ea39b499fa76c036a82e0d452a432dd4dd3e8e0d30892be2e8c76c + # via scipy-stubs +packaging==25.0 \ + --hash=sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484 \ + --hash=sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f # via # auditwheel # build + # keras # matplotlib # pytest -pillow==10.4.0 \ - --hash=sha256:02a2be69f9c9b8c1e97cf2713e789d4e398c751ecfd9967c18d0ce304efbf885 \ - --hash=sha256:030abdbe43ee02e0de642aee345efa443740aa4d828bfe8e2eb11922ea6a21ea \ - --hash=sha256:06b2f7898047ae93fad74467ec3d28fe84f7831370e3c258afa533f81ef7f3df \ - --hash=sha256:0755ffd4a0c6f267cccbae2e9903d95477ca2f77c4fcf3a3a09570001856c8a5 \ - --hash=sha256:0a9ec697746f268507404647e531e92889890a087e03681a3606d9b920fbee3c \ - --hash=sha256:0ae24a547e8b711ccaaf99c9ae3cd975470e1a30caa80a6aaee9a2f19c05701d \ - --hash=sha256:134ace6dc392116566980ee7436477d844520a26a4b1bd4053f6f47d096997fd \ - --hash=sha256:166c1cd4d24309b30d61f79f4a9114b7b2313d7450912277855ff5dfd7cd4a06 \ - --hash=sha256:1b5dea9831a90e9d0721ec417a80d4cbd7022093ac38a568db2dd78363b00908 \ - --hash=sha256:1d846aea995ad352d4bdcc847535bd56e0fd88d36829d2c90be880ef1ee4668a \ - --hash=sha256:1ef61f5dd14c300786318482456481463b9d6b91ebe5ef12f405afbba77ed0be \ - --hash=sha256:297e388da6e248c98bc4a02e018966af0c5f92dfacf5a5ca22fa01cb3179bca0 \ - --hash=sha256:298478fe4f77a4408895605f3482b6cc6222c018b2ce565c2b6b9c354ac3229b \ - --hash=sha256:29dbdc4207642ea6aad70fbde1a9338753d33fb23ed6956e706936706f52dd80 \ - --hash=sha256:2db98790afc70118bd0255c2eeb465e9767ecf1f3c25f9a1abb8ffc8cfd1fe0a \ - --hash=sha256:32cda9e3d601a52baccb2856b8ea1fc213c90b340c542dcef77140dfa3278a9e \ - --hash=sha256:37fb69d905be665f68f28a8bba3c6d3223c8efe1edf14cc4cfa06c241f8c81d9 \ - --hash=sha256:416d3a5d0e8cfe4f27f574362435bc9bae57f679a7158e0096ad2beb427b8696 \ - --hash=sha256:43efea75eb06b95d1631cb784aa40156177bf9dd5b4b03ff38979e048258bc6b \ - --hash=sha256:4b35b21b819ac1dbd1233317adeecd63495f6babf21b7b2512d244ff6c6ce309 \ - --hash=sha256:4d9667937cfa347525b319ae34375c37b9ee6b525440f3ef48542fcf66f2731e \ - --hash=sha256:5161eef006d335e46895297f642341111945e2c1c899eb406882a6c61a4357ab \ - --hash=sha256:543f3dc61c18dafb755773efc89aae60d06b6596a63914107f75459cf984164d \ - --hash=sha256:551d3fd6e9dc15e4c1eb6fc4ba2b39c0c7933fa113b220057a34f4bb3268a060 \ - --hash=sha256:59291fb29317122398786c2d44427bbd1a6d7ff54017075b22be9d21aa59bd8d \ - --hash=sha256:5b001114dd152cfd6b23befeb28d7aee43553e2402c9f159807bf55f33af8a8d \ - --hash=sha256:5b4815f2e65b30f5fbae9dfffa8636d992d49705723fe86a3661806e069352d4 \ - --hash=sha256:5dc6761a6efc781e6a1544206f22c80c3af4c8cf461206d46a1e6006e4429ff3 \ - --hash=sha256:5e84b6cc6a4a3d76c153a6b19270b3526a5a8ed6b09501d3af891daa2a9de7d6 \ - --hash=sha256:6209bb41dc692ddfee4942517c19ee81b86c864b626dbfca272ec0f7cff5d9fb \ - --hash=sha256:673655af3eadf4df6b5457033f086e90299fdd7a47983a13827acf7459c15d94 \ - --hash=sha256:6c762a5b0997f5659a5ef2266abc1d8851ad7749ad9a6a5506eb23d314e4f46b \ - --hash=sha256:7086cc1d5eebb91ad24ded9f58bec6c688e9f0ed7eb3dbbf1e4800280a896496 \ - --hash=sha256:73664fe514b34c8f02452ffb73b7a92c6774e39a647087f83d67f010eb9a0cf0 \ - --hash=sha256:76a911dfe51a36041f2e756b00f96ed84677cdeb75d25c767f296c1c1eda1319 \ - --hash=sha256:780c072c2e11c9b2c7ca37f9a2ee8ba66f44367ac3e5c7832afcfe5104fd6d1b \ - --hash=sha256:7928ecbf1ece13956b95d9cbcfc77137652b02763ba384d9ab508099a2eca856 \ - --hash=sha256:7970285ab628a3779aecc35823296a7869f889b8329c16ad5a71e4901a3dc4ef \ - --hash=sha256:7a8d4bade9952ea9a77d0c3e49cbd8b2890a399422258a77f357b9cc9be8d680 \ - --hash=sha256:7c1ee6f42250df403c5f103cbd2768a28fe1a0ea1f0f03fe151c8741e1469c8b \ - --hash=sha256:7dfecdbad5c301d7b5bde160150b4db4c659cee2b69589705b6f8a0c509d9f42 \ - --hash=sha256:812f7342b0eee081eaec84d91423d1b4650bb9828eb53d8511bcef8ce5aecf1e \ - --hash=sha256:866b6942a92f56300012f5fbac71f2d610312ee65e22f1aa2609e491284e5597 \ - --hash=sha256:86dcb5a1eb778d8b25659d5e4341269e8590ad6b4e8b44d9f4b07f8d136c414a \ - --hash=sha256:87dd88ded2e6d74d31e1e0a99a726a6765cda32d00ba72dc37f0651f306daaa8 \ - --hash=sha256:8bc1a764ed8c957a2e9cacf97c8b2b053b70307cf2996aafd70e91a082e70df3 \ - --hash=sha256:8d4d5063501b6dd4024b8ac2f04962d661222d120381272deea52e3fc52d3736 \ - --hash=sha256:8f0aef4ef59694b12cadee839e2ba6afeab89c0f39a3adc02ed51d109117b8da \ - --hash=sha256:930044bb7679ab003b14023138b50181899da3f25de50e9dbee23b61b4de2126 \ - --hash=sha256:950be4d8ba92aca4b2bb0741285a46bfae3ca699ef913ec8416c1b78eadd64cd \ - --hash=sha256:961a7293b2457b405967af9c77dcaa43cc1a8cd50d23c532e62d48ab6cdd56f5 \ - --hash=sha256:9b885f89040bb8c4a1573566bbb2f44f5c505ef6e74cec7ab9068c900047f04b \ - --hash=sha256:9f4727572e2918acaa9077c919cbbeb73bd2b3ebcfe033b72f858fc9fbef0026 \ - --hash=sha256:a02364621fe369e06200d4a16558e056fe2805d3468350df3aef21e00d26214b \ - --hash=sha256:a985e028fc183bf12a77a8bbf36318db4238a3ded7fa9df1b9a133f1cb79f8fc \ - --hash=sha256:ac1452d2fbe4978c2eec89fb5a23b8387aba707ac72810d9490118817d9c0b46 \ - --hash=sha256:b15e02e9bb4c21e39876698abf233c8c579127986f8207200bc8a8f6bb27acf2 \ - --hash=sha256:b2724fdb354a868ddf9a880cb84d102da914e99119211ef7ecbdc613b8c96b3c \ - --hash=sha256:bbc527b519bd3aa9d7f429d152fea69f9ad37c95f0b02aebddff592688998abe \ - --hash=sha256:bcd5e41a859bf2e84fdc42f4edb7d9aba0a13d29a2abadccafad99de3feff984 \ - --hash=sha256:bd2880a07482090a3bcb01f4265f1936a903d70bc740bfcb1fd4e8a2ffe5cf5a \ - --hash=sha256:bee197b30783295d2eb680b311af15a20a8b24024a19c3a26431ff83eb8d1f70 \ - --hash=sha256:bf2342ac639c4cf38799a44950bbc2dfcb685f052b9e262f446482afaf4bffca \ - --hash=sha256:c76e5786951e72ed3686e122d14c5d7012f16c8303a674d18cdcd6d89557fc5b \ - --hash=sha256:cbed61494057c0f83b83eb3a310f0bf774b09513307c434d4366ed64f4128a91 \ - --hash=sha256:cfdd747216947628af7b259d274771d84db2268ca062dd5faf373639d00113a3 \ - --hash=sha256:d7480af14364494365e89d6fddc510a13e5a2c3584cb19ef65415ca57252fb84 \ - --hash=sha256:dbc6ae66518ab3c5847659e9988c3b60dc94ffb48ef9168656e0019a93dbf8a1 \ - --hash=sha256:dc3e2db6ba09ffd7d02ae9141cfa0ae23393ee7687248d46a7507b75d610f4f5 \ - --hash=sha256:dfe91cb65544a1321e631e696759491ae04a2ea11d36715eca01ce07284738be \ - --hash=sha256:e4d49b85c4348ea0b31ea63bc75a9f3857869174e2bf17e7aba02945cd218e6f \ - --hash=sha256:e4db64794ccdf6cb83a59d73405f63adbe2a1887012e308828596100a0b2f6cc \ - --hash=sha256:e553cad5179a66ba15bb18b353a19020e73a7921296a7979c4a2b7f6a5cd57f9 \ - --hash=sha256:e88d5e6ad0d026fba7bdab8c3f225a69f063f116462c49892b0149e21b6c0a0e \ - --hash=sha256:ecd85a8d3e79cd7158dec1c9e5808e821feea088e2f69a974db5edf84dc53141 \ - --hash=sha256:f5b92f4d70791b4a67157321c4e8225d60b119c5cc9aee8ecf153aace4aad4ef \ - --hash=sha256:f5f0c3e969c8f12dd2bb7e0b15d5c468b51e5017e01e2e867335c81903046a22 \ - --hash=sha256:f7baece4ce06bade126fb84b8af1c33439a76d8a6fd818970215e0560ca28c27 \ - --hash=sha256:ff25afb18123cea58a591ea0244b92eb1e61a1fd497bf6d6384f09bc3262ec3e \ - --hash=sha256:ff337c552345e95702c5fde3158acb0625111017d0e5f24bf3acdb9cc16b90d1 + # tensorboard + # tensorflow + # wheel +pillow==12.0.0 \ + --hash=sha256:0869154a2d0546545cde61d1789a6524319fc1897d9ee31218eae7a60ccc5643 \ + --hash=sha256:09f2d0abef9e4e2f349305a4f8cc784a8a6c2f58a8c4892eea13b10a943bd26e \ + --hash=sha256:0b817e7035ea7f6b942c13aa03bb554fc44fea70838ea21f8eb31c638326584e \ + --hash=sha256:0fd00cac9c03256c8b2ff58f162ebcd2587ad3e1f2e397eab718c47e24d231cc \ + --hash=sha256:110486b79f2d112cf6add83b28b627e369219388f64ef2f960fef9ebaf54c642 \ + --hash=sha256:1979f4566bb96c1e50a62d9831e2ea2d1211761e5662afc545fa766f996632f6 \ + --hash=sha256:1ac11e8ea4f611c3c0147424eae514028b5e9077dd99ab91e1bd7bc33ff145e1 \ + --hash=sha256:1b1b133e6e16105f524a8dec491e0586d072948ce15c9b914e41cdadd209052b \ + --hash=sha256:1ee80a59f6ce048ae13cda1abf7fbd2a34ab9ee7d401c46be3ca685d1999a399 \ + --hash=sha256:21f241bdd5080a15bc86d3466a9f6074a9c2c2b314100dd896ac81ee6db2f1ba \ + --hash=sha256:266cd5f2b63ff316d5a1bba46268e603c9caf5606d44f38c2873c380950576ad \ + --hash=sha256:26d9f7d2b604cd23aba3e9faf795787456ac25634d82cd060556998e39c6fa47 \ + --hash=sha256:27f95b12453d165099c84f8a8bfdfd46b9e4bda9e0e4b65f0635430027f55739 \ + --hash=sha256:2c54c1a783d6d60595d3514f0efe9b37c8808746a66920315bfd34a938d7994b \ + --hash=sha256:2fa5f0b6716fc88f11380b88b31fe591a06c6315e955c096c35715788b339e3f \ + --hash=sha256:32ed80ea8a90ee3e6fa08c21e2e091bba6eda8eccc83dbc34c95169507a91f10 \ + --hash=sha256:3830c769decf88f1289680a59d4f4c46c72573446352e2befec9a8512104fa52 \ + --hash=sha256:38df9b4bfd3db902c9c2bd369bcacaf9d935b2fff73709429d95cc41554f7b3d \ + --hash=sha256:3adfb466bbc544b926d50fe8f4a4e6abd8c6bffd28a26177594e6e9b2b76572b \ + --hash=sha256:3e42edad50b6909089750e65c91aa09aaf1e0a71310d383f11321b27c224ed8a \ + --hash=sha256:4078242472387600b2ce8d93ade8899c12bf33fa89e55ec89fe126e9d6d5d9e9 \ + --hash=sha256:455247ac8a4cfb7b9bc45b7e432d10421aea9fc2e74d285ba4072688a74c2e9d \ + --hash=sha256:4cc6b3b2efff105c6a1656cfe59da4fdde2cda9af1c5e0b58529b24525d0a098 \ + --hash=sha256:4cf7fed4b4580601c4345ceb5d4cbf5a980d030fd5ad07c4d2ec589f95f09905 \ + --hash=sha256:5193fde9a5f23c331ea26d0cf171fbf67e3f247585f50c08b3e205c7aeb4589b \ + --hash=sha256:5269cc1caeedb67e6f7269a42014f381f45e2e7cd42d834ede3c703a1d915fe3 \ + --hash=sha256:53561a4ddc36facb432fae7a9d8afbfaf94795414f5cdc5fc52f28c1dca90371 \ + --hash=sha256:55f818bd74fe2f11d4d7cbc65880a843c4075e0ac7226bc1a23261dbea531953 \ + --hash=sha256:58eea5ebe51504057dd95c5b77d21700b77615ab0243d8152793dc00eb4faf01 \ + --hash=sha256:5d5c411a8eaa2299322b647cd932586b1427367fd3184ffbb8f7a219ea2041ca \ + --hash=sha256:6846bd2d116ff42cba6b646edf5bf61d37e5cbd256425fa089fee4ff5c07a99e \ + --hash=sha256:6ace95230bfb7cd79ef66caa064bbe2f2a1e63d93471c3a2e1f1348d9f22d6b7 \ + --hash=sha256:6e51b71417049ad6ab14c49608b4a24d8fb3fe605e5dfabfe523b58064dc3d27 \ + --hash=sha256:71db6b4c1653045dacc1585c1b0d184004f0d7e694c7b34ac165ca70c0838082 \ + --hash=sha256:7438839e9e053ef79f7112c881cef684013855016f928b168b81ed5835f3e75e \ + --hash=sha256:759de84a33be3b178a64c8ba28ad5c135900359e85fb662bc6e403ad4407791d \ + --hash=sha256:792a2c0be4dcc18af9d4a2dfd8a11a17d5e25274a1062b0ec1c2d79c76f3e7f8 \ + --hash=sha256:7d87ef5795da03d742bf49439f9ca4d027cde49c82c5371ba52464aee266699a \ + --hash=sha256:7dfb439562f234f7d57b1ac6bc8fe7f838a4bd49c79230e0f6a1da93e82f1fad \ + --hash=sha256:7fa22993bac7b77b78cae22bad1e2a987ddf0d9015c63358032f84a53f23cdc3 \ + --hash=sha256:805ebf596939e48dbb2e4922a1d3852cfc25c38160751ce02da93058b48d252a \ + --hash=sha256:82240051c6ca513c616f7f9da06e871f61bfd7805f566275841af15015b8f98d \ + --hash=sha256:87d4f8125c9988bfbed67af47dd7a953e2fc7b0cc1e7800ec6d2080d490bb353 \ + --hash=sha256:8d8ca2b210ada074d57fcee40c30446c9562e542fc46aedc19baf758a93532ee \ + --hash=sha256:8dc232e39d409036af549c86f24aed8273a40ffa459981146829a324e0848b4b \ + --hash=sha256:90387104ee8400a7b4598253b4c406f8958f59fcf983a6cea2b50d59f7d63d0b \ + --hash=sha256:905b0365b210c73afb0ebe9101a32572152dfd1c144c7e28968a331b9217b94a \ + --hash=sha256:99353a06902c2e43b43e8ff74ee65a7d90307d82370604746738a1e0661ccca7 \ + --hash=sha256:99a7f72fb6249302aa62245680754862a44179b545ded638cf1fef59befb57ef \ + --hash=sha256:9f0b04c6b8584c2c193babcccc908b38ed29524b29dd464bc8801bf10d746a3a \ + --hash=sha256:9fe611163f6303d1619bbcb653540a4d60f9e55e622d60a3108be0d5b441017a \ + --hash=sha256:a3475b96f5908b3b16c47533daaa87380c491357d197564e0ba34ae75c0f3257 \ + --hash=sha256:a6597ff2b61d121172f5844b53f21467f7082f5fb385a9a29c01414463f93b07 \ + --hash=sha256:a7921c5a6d31b3d756ec980f2f47c0cfdbce0fc48c22a39347a895f41f4a6ea4 \ + --hash=sha256:aa5129de4e174daccbc59d0a3b6d20eaf24417d59851c07ebb37aeb02947987c \ + --hash=sha256:aeaefa96c768fc66818730b952a862235d68825c178f1b3ffd4efd7ad2edcb7c \ + --hash=sha256:afbefa430092f71a9593a99ab6a4e7538bc9eabbf7bf94f91510d3503943edc4 \ + --hash=sha256:aff9e4d82d082ff9513bdd6acd4f5bd359f5b2c870907d2b0a9c5e10d40c88fe \ + --hash=sha256:b22bd8c974942477156be55a768f7aa37c46904c175be4e158b6a86e3a6b7ca8 \ + --hash=sha256:b290fd8aa38422444d4b50d579de197557f182ef1068b75f5aa8558638b8d0a5 \ + --hash=sha256:b2e4b27a6e15b04832fe9bf292b94b5ca156016bbc1ea9c2c20098a0320d6cf6 \ + --hash=sha256:b583dc9070312190192631373c6c8ed277254aa6e6084b74bdd0a6d3b221608e \ + --hash=sha256:b87843e225e74576437fd5b6a4c2205d422754f84a06942cfaf1dc32243e45a8 \ + --hash=sha256:bc91a56697869546d1b8f0a3ff35224557ae7f881050e99f615e0119bf934b4e \ + --hash=sha256:bd87e140e45399c818fac4247880b9ce719e4783d767e030a883a970be632275 \ + --hash=sha256:bde737cff1a975b70652b62d626f7785e0480918dece11e8fef3c0cf057351c3 \ + --hash=sha256:bdee52571a343d721fb2eb3b090a82d959ff37fc631e3f70422e0c2e029f3e76 \ + --hash=sha256:bee2a6db3a7242ea309aa7ee8e2780726fed67ff4e5b40169f2c940e7eb09227 \ + --hash=sha256:beeae3f27f62308f1ddbcfb0690bf44b10732f2ef43758f169d5e9303165d3f9 \ + --hash=sha256:c50f36a62a22d350c96e49ad02d0da41dbd17ddc2e29750dbdba4323f85eb4a5 \ + --hash=sha256:c607c90ba67533e1b2355b821fef6764d1dd2cbe26b8c1005ae84f7aea25ff79 \ + --hash=sha256:c7b2a63fd6d5246349f3d3f37b14430d73ee7e8173154461785e43036ffa96ca \ + --hash=sha256:c828a1ae702fc712978bda0320ba1b9893d99be0badf2647f693cc01cf0f04fa \ + --hash=sha256:c85de1136429c524e55cfa4e033b4a7940ac5c8ee4d9401cc2d1bf48154bbc7b \ + --hash=sha256:c98fa880d695de164b4135a52fd2e9cd7b7c90a9d8ac5e9e443a24a95ef9248e \ + --hash=sha256:cae81479f77420d217def5f54b5b9d279804d17e982e0f2fa19b1d1e14ab5197 \ + --hash=sha256:d034140032870024e6b9892c692fe2968493790dd57208b2c37e3fb35f6df3ab \ + --hash=sha256:d120c38a42c234dc9a8c5de7ceaaf899cf33561956acb4941653f8bdc657aa79 \ + --hash=sha256:d4827615da15cd59784ce39d3388275ec093ae3ee8d7f0c089b76fa87af756c2 \ + --hash=sha256:d49e2314c373f4c2b39446fb1a45ed333c850e09d0c59ac79b72eb3b95397363 \ + --hash=sha256:d52610d51e265a51518692045e372a4c363056130d922a7351429ac9f27e70b0 \ + --hash=sha256:d64317d2587c70324b79861babb9c09f71fbb780bad212018874b2c013d8600e \ + --hash=sha256:d77153e14b709fd8b8af6f66a3afbb9ed6e9fc5ccf0b6b7e1ced7b036a228782 \ + --hash=sha256:d7e091d464ac59d2c7ad8e7e08105eaf9dafbc3883fd7265ffccc2baad6ac925 \ + --hash=sha256:dd333073e0cacdc3089525c7df7d39b211bcdf31fc2824e49d01c6b6187b07d0 \ + --hash=sha256:e5d8efac84c9afcb40914ab49ba063d94f5dbdf5066db4482c66a992f47a3a3b \ + --hash=sha256:f135c702ac42262573fe9714dfe99c944b4ba307af5eb507abef1667e2cbbced \ + --hash=sha256:f13711b1a5ba512d647a0e4ba79280d3a9a045aaf7e0cc6fbe96b91d4cdf6b0c \ + --hash=sha256:f4f1231b7dec408e8670264ce63e9c71409d9583dd21d32c163e25213ee2a344 \ + --hash=sha256:fa3ed2a29a9e9d2d488b4da81dcb54720ac3104a20bf0bd273f1e4648aff5af9 \ + --hash=sha256:fb3096c30df99fd01c7bf8e544f392103d0795b9f98ba71a8054bcbf56b255f1 # via # -r build/test-requirements.txt # matplotlib -pluggy==1.5.0 \ - --hash=sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1 \ - --hash=sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669 + # tensorboard +pluggy==1.6.0 \ + --hash=sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3 \ + --hash=sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746 # via pytest portpicker==1.6.0 \ --hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \ --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa # via -r build/test-requirements.txt -psutil==6.0.0 \ - --hash=sha256:02b69001f44cc73c1c5279d02b30a817e339ceb258ad75997325e0e6169d8b35 \ - --hash=sha256:1287c2b95f1c0a364d23bc6f2ea2365a8d4d9b726a3be7294296ff7ba97c17f0 \ - --hash=sha256:1e7c870afcb7d91fdea2b37c24aeb08f98b6d67257a5cb0a8bc3ac68d0f1a68c \ - --hash=sha256:21f1fb635deccd510f69f485b87433460a603919b45e2a324ad65b0cc74f8fb1 \ - --hash=sha256:33ea5e1c975250a720b3a6609c490db40dae5d83a4eb315170c4fe0d8b1f34b3 \ - --hash=sha256:34859b8d8f423b86e4385ff3665d3f4d94be3cdf48221fbe476e883514fdb71c \ - --hash=sha256:5fd9a97c8e94059b0ef54a7d4baf13b405011176c3b6ff257c247cae0d560ecd \ - --hash=sha256:6ec7588fb3ddaec7344a825afe298db83fe01bfaaab39155fa84cf1c0d6b13c3 \ - --hash=sha256:6ed2440ada7ef7d0d608f20ad89a04ec47d2d3ab7190896cd62ca5fc4fe08bf0 \ - --hash=sha256:8faae4f310b6d969fa26ca0545338b21f73c6b15db7c4a8d934a5482faa818f2 \ - --hash=sha256:a021da3e881cd935e64a3d0a20983bda0bb4cf80e4f74fa9bfcb1bc5785360c6 \ - --hash=sha256:a495580d6bae27291324fe60cea0b5a7c23fa36a7cd35035a16d93bdcf076b9d \ - --hash=sha256:a9a3dbfb4de4f18174528d87cc352d1f788b7496991cca33c6996f40c9e3c92c \ - --hash=sha256:c588a7e9b1173b6e866756dde596fd4cad94f9399daf99ad8c3258b3cb2b47a0 \ - --hash=sha256:e2e8d0054fc88153ca0544f5c4d554d42e33df2e009c4ff42284ac9ebdef4132 \ - --hash=sha256:fc8c9510cde0146432bbdb433322861ee8c3efbf8589865c8bf8d21cb30c4d14 \ - --hash=sha256:ffe7fc9b6b36beadc8c322f84e1caff51e8703b88eee1da46d1e3a6ae11b4fd0 +protobuf==6.33.2 \ + --hash=sha256:1f8017c48c07ec5859106533b682260ba3d7c5567b1ca1f24297ce03384d1b4f \ + --hash=sha256:2981c58f582f44b6b13173e12bb8656711189c2a70250845f264b877f00b1913 \ + --hash=sha256:56dc370c91fbb8ac85bc13582c9e373569668a290aa2e66a590c2a0d35ddb9e4 \ + --hash=sha256:7109dcc38a680d033ffb8bf896727423528db9163be1b6a02d6a49606dcadbfe \ + --hash=sha256:7636aad9bb01768870266de5dc009de2d1b936771b38a793f73cbbf279c91c5c \ + --hash=sha256:87eb388bd2d0f78febd8f4c8779c79247b26a5befad525008e49a6955787ff3d \ + --hash=sha256:8cd7640aee0b7828b6d03ae518b5b4806fdfc1afe8de82f79c3454f8aef29872 \ + --hash=sha256:b5d3b5625192214066d99b2b605f5783483575656784de223f00a8d00754fc0e \ + --hash=sha256:d9b19771ca75935b3a4422957bc518b0cecb978b31d1dd12037b088f6bcc0e43 \ + --hash=sha256:fc2a0e8b05b180e5fc0dd1559fe8ebdae21a27e81ac77728fb6c42b12c7419b4 + # via + # tensorboard + # tensorflow +psutil==7.1.3 \ + --hash=sha256:0005da714eee687b4b8decd3d6cc7c6db36215c9e74e5ad2264b90c3df7d92dc \ + --hash=sha256:1068c303be3a72f8e18e412c5b2a8f6d31750fb152f9cb106b54090296c9d251 \ + --hash=sha256:18349c5c24b06ac5612c0428ec2a0331c26443d259e2a0144a9b24b4395b58fa \ + --hash=sha256:19644c85dcb987e35eeeaefdc3915d059dac7bd1167cdcdbf27e0ce2df0c08c0 \ + --hash=sha256:2bdbcd0e58ca14996a42adf3621a6244f1bb2e2e528886959c72cf1e326677ab \ + --hash=sha256:31d77fcedb7529f27bb3a0472bea9334349f9a04160e8e6e5020f22c59893264 \ + --hash=sha256:3792983e23b69843aea49c8f5b8f115572c5ab64c153bada5270086a2123c7e7 \ + --hash=sha256:3bb428f9f05c1225a558f53e30ccbad9930b11c3fc206836242de1091d3e7dd3 \ + --hash=sha256:56d974e02ca2c8eb4812c3f76c30e28836fffc311d55d979f1465c1feeb2b68b \ + --hash=sha256:6c86281738d77335af7aec228328e944b30930899ea760ecf33a4dba66be5e74 \ + --hash=sha256:8f33a3702e167783a9213db10ad29650ebf383946e91bc77f28a5eb083496bc9 \ + --hash=sha256:95ef04cf2e5ba0ab9eaafc4a11eaae91b44f4ef5541acd2ee91d9108d00d59a7 \ + --hash=sha256:ad81425efc5e75da3f39b3e636293360ad8d0b49bed7df824c79764fb4ba9b8b \ + --hash=sha256:b403da1df4d6d43973dc004d19cee3b848e998ae3154cc8097d139b77156c353 \ + --hash=sha256:bc31fa00f1fbc3c3802141eede66f3a2d51d89716a194bf2cd6fc68310a19880 \ + --hash=sha256:bd0d69cee829226a761e92f28140bec9a5ee9d5b4fb4b0cc589068dbfff559b1 \ + --hash=sha256:c525ffa774fe4496282fb0b1187725793de3e7c6b29e41562733cae9ada151ee \ + --hash=sha256:f39c2c19fe824b47484b96f9692932248a54c43799a84282cfe58d05a6449efd \ + --hash=sha256:fac9cd332c67f4422504297889da5ab7e05fd11e3c4392140f7370f4208ded1f # via portpicker -pyelftools==0.31 \ - --hash=sha256:c774416b10310156879443b81187d182d8d9ee499660380e645918b50bc88f99 \ - --hash=sha256:f52de7b3c7e8c64c8abc04a79a1cf37ac5fb0b8a49809827130b858944840607 +pyelftools==0.32 \ + --hash=sha256:013df952a006db5e138b1edf6d8a68ecc50630adbd0d83a2d41e7f846163d738 \ + --hash=sha256:6de90ee7b8263e740c8715a925382d4099b354f29ac48ea40d840cf7aa14ace5 # via auditwheel -pygments==2.18.0 \ - --hash=sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199 \ - --hash=sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a - # via rich -pyparsing==3.1.4 \ - --hash=sha256:a6a7ee4235a3f944aa1fa2249307708f893fe5717dc603503c6c7969c070fb7c \ - --hash=sha256:f86ec8d1a83f11977c9a6ea7598e8c27fc5cddfa5b07ea2241edbbde1d7bc032 +pygments==2.19.2 \ + --hash=sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887 \ + --hash=sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b + # via + # pytest + # rich +pyparsing==3.2.5 \ + --hash=sha256:2df8d5b7b2802ef88e8d016a2eb9c7aeaa923529cd251ed0fe4608275d4105b6 \ + --hash=sha256:e38a4f02064cf41fe6593d328d0512495ad1f3d8a91c4f73fc401b3079a59a5e # via matplotlib pyproject-hooks==1.2.0 \ --hash=sha256:1e859bd5c40fae9448642dd871adf459e5e2084186e8d2c2a79a824c970da1f8 \ --hash=sha256:9e5c6bfa8dcc30091c74b0cf803c81fdd29d94f01992a7707bc97babb1141913 # via build -pytest==8.3.3 \ - --hash=sha256:70b98107bd648308a7952b06e6ca9a50bc660be218d53c257cc1fc94fda10181 \ - --hash=sha256:a6853c7375b2663155079443d2e45de913a911a11d669df02a50814944db57b2 - # via pytest-xdist -pytest-xdist==3.6.1 \ - --hash=sha256:9ed4adfb68a016610848639bb7e02c9352d5d9f03d04809919e2dafc3be4cca7 \ - --hash=sha256:ead156a4db231eec769737f57668ef58a2084a34b2e55c4a8fa20d861107300d +pytest==8.4.2 \ + --hash=sha256:86c0d0b93306b961d58d62a4db4879f27fe25513d4b969df351abdddb3c30e01 \ + --hash=sha256:872f880de3fc3a5bdc88a11b39c9710c3497a547cfa9320bc3c5e62fbf272e79 + # via + # -r build/test-requirements.txt + # pytest-xdist +pytest-xdist==3.8.0 \ + --hash=sha256:202ca578cfeb7370784a8c33d6d05bc6e13b4f25b5053c30a152269fd10f0b88 \ + --hash=sha256:7e578125ec9bc6050861aa93f2d59f1d8d085595d6551c2c90b6f4fad8d3a9f1 # via -r build/test-requirements.txt python-dateutil==2.9.0.post0 \ --hash=sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3 \ --hash=sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427 # via matplotlib -rich==13.9.2 \ - --hash=sha256:51a2c62057461aaf7152b4d611168f93a9fc73068f8ded2790f29fe2b5366d0c \ - --hash=sha256:8c82a3d3f8dcfe9e734771313e606b39d8247bb6b826e196f4914b333b743cf1 +requests==2.32.5 \ + --hash=sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6 \ + --hash=sha256:dbba0bac56e100853db0ea71b82b4dfd5fe2bf6d3754a8893c3af500cec7d7cf + # via tensorflow +rich==14.2.0 \ + --hash=sha256:73ff50c7c0c1c77c8243079283f4edb376f0f6442433aecb8ce7e6d0b92d1fe4 \ + --hash=sha256:76bc51fe2e57d2b1be1f96c524b890b816e334ab4c1e45888799bfaab0021edd + # via + # -r build/test-requirements.txt + # keras +scipy==1.16.3 ; python_version >= "3.13" \ + --hash=sha256:0151a0749efeaaab78711c78422d413c583b8cdd2011a3c1d6c794938ee9fdb2 \ + --hash=sha256:01e87659402762f43bd2fee13370553a17ada367d42e7487800bf2916535aecb \ + --hash=sha256:03192a35e661470197556de24e7cb1330d84b35b94ead65c46ad6f16f6b28f2a \ + --hash=sha256:0553371015692a898e1aa858fed67a3576c34edefa6b7ebdb4e9dde49ce5c203 \ + --hash=sha256:062246acacbe9f8210de8e751b16fc37458213f124bef161a5a02c7a39284304 \ + --hash=sha256:0c3b4dd3d9b08dbce0f3440032c52e9e2ab9f96ade2d3943313dfe51a7056959 \ + --hash=sha256:0c623a54f7b79dd88ef56da19bc2873afec9673a48f3b85b18e4d402bdd29a5a \ + --hash=sha256:16b8bc35a4cc24db80a0ec836a9286d0e31b2503cb2fd7ff7fb0e0374a97081d \ + --hash=sha256:1fb2472e72e24d1530debe6ae078db70fb1605350c88a3d14bc401d6306dbffe \ + --hash=sha256:21d9d6b197227a12dcbf9633320a4e34c6b0e51c57268df255a0942983bac562 \ + --hash=sha256:2a207a6ce9c24f1951241f4693ede2d393f59c07abc159b2cb2be980820e01fb \ + --hash=sha256:2b71d93c8a9936046866acebc915e2af2e292b883ed6e2cbe5c34beb094b82d9 \ + --hash=sha256:2d1ae2cf0c350e7705168ff2429962a89ad90c2d49d1dd300686d8b2a5af22fc \ + --hash=sha256:3a4c460301fb2cffb7f88528f30b3127742cff583603aa7dc964a52c463b385d \ + --hash=sha256:3d4a07a8e785d80289dfe66b7c27d8634a773020742ec7187b85ccc4b0e7b686 \ + --hash=sha256:40be6cf99e68b6c4321e9f8782e7d5ff8265af28ef2cd56e9c9b2638fa08ad97 \ + --hash=sha256:4aff59800a3b7f786b70bfd6ab551001cb553244988d7d6b8299cb1ea653b353 \ + --hash=sha256:50a3dbf286dbc7d84f176f9a1574c705f277cb6565069f88f60db9eafdbe3ee2 \ + --hash=sha256:532fb5ad6a87e9e9cd9c959b106b73145a03f04c7d57ea3e6f6bb60b86ab0876 \ + --hash=sha256:53c3844d527213631e886621df5695d35e4f6a75f620dca412bcd292f6b87d78 \ + --hash=sha256:56edc65510d1331dae01ef9b658d428e33ed48b4f77b1d51caf479a0253f96dc \ + --hash=sha256:57d01cb6f85e34f0946b33caa66e892aae072b64b034183f3d87c4025802a119 \ + --hash=sha256:5803c5fadd29de0cf27fa08ccbfe7a9e5d741bf63e4ab1085437266f12460ff9 \ + --hash=sha256:6020470b9d00245926f2d5bb93b119ca0340f0d564eb6fbaad843eaebf9d690f \ + --hash=sha256:63d3cdacb8a824a295191a723ee5e4ea7768ca5ca5f2838532d9f2e2b3ce2135 \ + --hash=sha256:663b8d66a8748051c3ee9c96465fb417509315b99c71550fda2591d7dd634234 \ + --hash=sha256:72d1717fd3b5e6ec747327ce9bda32d5463f472c9dce9f54499e81fbd50245a1 \ + --hash=sha256:7dc1360c06535ea6116a2220f760ae572db9f661aba2d88074fe30ec2aa1ff88 \ + --hash=sha256:7f68154688c515cdb541a31ef8eb66d8cd1050605be9dcd74199cbd22ac739bc \ + --hash=sha256:81fc5827606858cf71446a5e98715ba0e11f0dbc83d71c7409d05486592a45d6 \ + --hash=sha256:875555ce62743e1d54f06cdf22c1e0bc47b91130ac40fe5d783b6dfa114beeb6 \ + --hash=sha256:8b3c820ddb80029fe9f43d61b81d8b488d3ef8ca010d15122b152db77dc94c22 \ + --hash=sha256:8be1ca9170fcb6223cc7c27f4305d680ded114a1567c0bd2bfcbf947d1b17511 \ + --hash=sha256:8d09d72dc92742988b0e7750bddb8060b0c7079606c0d24a8cc8e9c9c11f9079 \ + --hash=sha256:9452781bd879b14b6f055b26643703551320aa8d79ae064a71df55c00286a184 \ + --hash=sha256:96491a6a54e995f00a28a3c3badfff58fd093bf26cd5fb34a2188c8c756a3a2c \ + --hash=sha256:9b9c9c07b6d56a35777a1b4cc8966118fb16cfd8daf6743867d17d36cfad2d40 \ + --hash=sha256:a8a26c78ef223d3e30920ef759e25625a0ecdd0d60e5a8818b7513c3e5384cf2 \ + --hash=sha256:aadd23f98f9cb069b3bd64ddc900c4d277778242e961751f77a8cb5c4b946fb0 \ + --hash=sha256:b7180967113560cca57418a7bc719e30366b47959dd845a93206fbed693c867e \ + --hash=sha256:b7c5f1bda1354d6a19bc6af73a649f8285ca63ac6b52e64e658a5a11d4d69800 \ + --hash=sha256:b81c27fc41954319a943d43b20e07c40bdcd3ff7cf013f4fb86286faefe546c4 \ + --hash=sha256:bb61878c18a470021fb515a843dc7a76961a8daceaaaa8bad1332f1bf4b54657 \ + --hash=sha256:bea0a62734d20d67608660f69dcda23e7f90fb4ca20974ab80b6ed40df87a005 \ + --hash=sha256:c5192722cffe15f9329a3948c4b1db789fbb1f05c97899187dcf009b283aea70 \ + --hash=sha256:c97176013d404c7346bf57874eaac5187d969293bf40497140b0a2b2b7482e07 \ + --hash=sha256:cd13e354df9938598af2be05822c323e97132d5e6306b83a3b4ee6724c6e522e \ + --hash=sha256:d2ec56337675e61b312179a1ad124f5f570c00f920cc75e1000025451b88241c \ + --hash=sha256:d3837938ae715fc0fe3c39c0202de3a8853aff22ca66781ddc2ade7554b7e2cc \ + --hash=sha256:d9f48cafc7ce94cf9b15c6bffdc443a81a27bf7075cf2dcd5c8b40f85d10c4e7 \ + --hash=sha256:da7763f55885045036fabcebd80144b757d3db06ab0861415d1c3b7c69042146 \ + --hash=sha256:deb3841c925eeddb6afc1e4e4a45e418d19ec7b87c5df177695224078e8ec733 \ + --hash=sha256:e1d27cbcb4602680a49d787d90664fa4974063ac9d4134813332a8c53dbe667c \ + --hash=sha256:e5d42a9472e7579e473879a1990327830493a7047506d58d73fc429b84c1d49d \ + --hash=sha256:e7efa2681ea410b10dde31a52b18b0154d66f2485328830e45fdf183af5aefc6 \ + --hash=sha256:eab43fae33a0c39006a88096cd7b4f4ef545ea0447d250d5ac18202d40b6611d \ + --hash=sha256:f2622206f5559784fa5c4b53a950c3c7c1cf3e84ca1b9c4b6c03f062f289ca26 \ + --hash=sha256:f379b54b77a597aa7ee5e697df0d66903e41b9c85a6dd7946159e356319158e8 \ + --hash=sha256:f667a4542cc8917af1db06366d3f78a5c8e83badd56409f94d1eac8d8d9133fa \ + --hash=sha256:fb4b29f4cf8cc5a8d628bc8d8e26d12d7278cd1f219f22698a378c3d67db5e4b \ + --hash=sha256:ffa6eea95283b2b8079b821dc11f50a17d0571c92b43e2b5b12764dc5f9b285d + # via + # -r build/requirements.in + # jaxlib +scipy-stubs==1.16.3.3 \ + --hash=sha256:af47578875d5557567225a16ec1b9b38a48c4c4377d92396413ebd65406c44ee \ + --hash=sha256:f6316b36cd0fb272c994ae5b10c4a73c644a7e156ed8d32bcd9c35303d0e1b7e # via -r build/test-requirements.txt -scipy==1.14.1 \ - --hash=sha256:0c2f95de3b04e26f5f3ad5bb05e74ba7f68b837133a4492414b3afd79dfe540e \ - --hash=sha256:1729560c906963fc8389f6aac023739ff3983e727b1a4d87696b7bf108316a79 \ - --hash=sha256:278266012eb69f4a720827bdd2dc54b2271c97d84255b2faaa8f161a158c3b37 \ - --hash=sha256:2843f2d527d9eebec9a43e6b406fb7266f3af25a751aa91d62ff416f54170bc5 \ - --hash=sha256:2da0469a4ef0ecd3693761acbdc20f2fdeafb69e6819cc081308cc978153c675 \ - --hash=sha256:2ff0a7e01e422c15739ecd64432743cf7aae2b03f3084288f399affcefe5222d \ - --hash=sha256:2ff38e22128e6c03ff73b6bb0f85f897d2362f8c052e3b8ad00532198fbdae3f \ - --hash=sha256:30ac8812c1d2aab7131a79ba62933a2a76f582d5dbbc695192453dae67ad6310 \ - --hash=sha256:3a1b111fac6baec1c1d92f27e76511c9e7218f1695d61b59e05e0fe04dc59617 \ - --hash=sha256:4079b90df244709e675cdc8b93bfd8a395d59af40b72e339c2287c91860deb8e \ - --hash=sha256:5149e3fd2d686e42144a093b206aef01932a0059c2a33ddfa67f5f035bdfe13e \ - --hash=sha256:5a275584e726026a5699459aa72f828a610821006228e841b94275c4a7c08417 \ - --hash=sha256:631f07b3734d34aced009aaf6fedfd0eb3498a97e581c3b1e5f14a04164a456d \ - --hash=sha256:716e389b694c4bb564b4fc0c51bc84d381735e0d39d3f26ec1af2556ec6aad94 \ - --hash=sha256:8426251ad1e4ad903a4514712d2fa8fdd5382c978010d1c6f5f37ef286a713ad \ - --hash=sha256:8475230e55549ab3f207bff11ebfc91c805dc3463ef62eda3ccf593254524ce8 \ - --hash=sha256:8bddf15838ba768bb5f5083c1ea012d64c9a444e16192762bd858f1e126196d0 \ - --hash=sha256:8e32dced201274bf96899e6491d9ba3e9a5f6b336708656466ad0522d8528f69 \ - --hash=sha256:8f9ea80f2e65bdaa0b7627fb00cbeb2daf163caa015e59b7516395fe3bd1e066 \ - --hash=sha256:97c5dddd5932bd2a1a31c927ba5e1463a53b87ca96b5c9bdf5dfd6096e27efc3 \ - --hash=sha256:a49f6ed96f83966f576b33a44257d869756df6cf1ef4934f59dd58b25e0327e5 \ - --hash=sha256:af29a935803cc707ab2ed7791c44288a682f9c8107bc00f0eccc4f92c08d6e07 \ - --hash=sha256:b05d43735bb2f07d689f56f7b474788a13ed8adc484a85aa65c0fd931cf9ccd2 \ - --hash=sha256:b28d2ca4add7ac16ae8bb6632a3c86e4b9e4d52d3e34267f6e1b0c1f8d87e389 \ - --hash=sha256:b99722ea48b7ea25e8e015e8341ae74624f72e5f21fc2abd45f3a93266de4c5d \ - --hash=sha256:baff393942b550823bfce952bb62270ee17504d02a1801d7fd0719534dfb9c84 \ - --hash=sha256:c0ee987efa6737242745f347835da2cc5bb9f1b42996a4d97d5c7ff7928cb6f2 \ - --hash=sha256:d0d2821003174de06b69e58cef2316a6622b60ee613121199cb2852a873f8cf3 \ - --hash=sha256:e0cf28db0f24a38b2a0ca33a85a54852586e43cf6fd876365c86e0657cfe7d73 \ - --hash=sha256:e4f5a7c49323533f9103d4dacf4e4f07078f360743dec7f7596949149efeec06 \ - --hash=sha256:eb58ca0abd96911932f688528977858681a59d61a7ce908ffd355957f7025cfc \ - --hash=sha256:edaf02b82cd7639db00dbff629995ef185c8df4c3ffa71a5562a595765a06ce1 \ - --hash=sha256:fef8c87f8abfb884dac04e97824b61299880c43f4ce675dd2cbeadd3c9b466d2 - # via -r build/requirements.in -six==1.16.0 \ - --hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \ - --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 - # via python-dateutil +six==1.17.0 \ + --hash=sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274 \ + --hash=sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81 + # via + # astunparse + # google-pasta + # python-dateutil + # tensorflow sortedcontainers==2.4.0 \ --hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \ --hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0 # via hypothesis -typing-extensions==4.12.2 \ - --hash=sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d \ - --hash=sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8 - # via etils -wheel==0.44.0 \ - --hash=sha256:2376a90c98cc337d18623527a97c31797bd02bad0033d41547043a1cbfbe448f \ - --hash=sha256:a29c3f2817e95ab89aa4660681ad547c0e9547f20e75b0562fe7723c9a2a9d49 - # via -r build/test-requirements.txt -zipp==3.20.2 \ - --hash=sha256:a817ac80d6cf4b23bf7f2828b7cabf326f15a001bea8b1f9b49631780ba28350 \ - --hash=sha256:bc9eb26f4506fda01b81bcde0ca78103b6e62f991b381fec825435c836edbc29 +tensorboard==2.20.0 \ + --hash=sha256:9dc9f978cb84c0723acf9a345d96c184f0293d18f166bb8d59ee098e6cfaaba6 + # via tensorflow +tensorboard-data-server==0.7.2 \ + --hash=sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb \ + --hash=sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60 \ + --hash=sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530 + # via tensorboard +tensorflow==2.20.0 ; python_version < "3.14" \ + --hash=sha256:02a0293d94f5c8b7125b66abf622cc4854a33ae9d618a0d41309f95e091bbaea \ + --hash=sha256:0deb5c583dfc53b54fd158a194ce0087b406bb6518af400ca3809735e4548ec3 \ + --hash=sha256:1590cbf87b6bcbd34d8e9ad70d0c696135e0aa71be31803b27358cf7ed63f8fc \ + --hash=sha256:197f0b613b38c0da5c6a12a8295ad4a05c78b853835dae8e0f9dfae3ce9ce8a5 \ + --hash=sha256:25265b0bc527e0d54b1e9cc60c44a24f44a809fe27666b905f0466471f9c52ec \ + --hash=sha256:28bc33759249c98eabcee9debd24e74506bbe29ac139e050cf0c74aa9888ebdf \ + --hash=sha256:2bfbfb3dd0e22bffc45fe1e922390d27753e99261fab8a882e802cf98a0e078f \ + --hash=sha256:3e9568c8efcb05c0266be223e3269c62ebf7ad3498f156438311735f6fa5ced5 \ + --hash=sha256:47c88e05a07f1ead4977b4894b3ecd4d8075c40191065afc4fd9355c9db3d926 \ + --hash=sha256:481499fd0f824583de8945be61d5e827898cdaa4f5ea1bc2cc28ca2ccff8229e \ + --hash=sha256:4a69ac2c2ce20720abf3abf917b4e86376326c0976fcec3df330e184b81e4088 \ + --hash=sha256:52b122f0232fd7ab10f28d537ce08470d0b6dcac7fff9685432daac7f8a06c8f \ + --hash=sha256:5f964016c5035d09b85a246a6b739be89282a7839743f3ea63640224f0c63aee \ + --hash=sha256:5fa3729b0126f75a99882b89fb7d536515721eda8014a63e259e780ba0a37372 \ + --hash=sha256:7551558a48c2e2f6c32a1537f06c654a9df1408a1c18e7b99c3caafbd03edfe3 \ + --hash=sha256:7abd7f3a010e0d354dc804182372779a722d474c4d8a3db8f4a3f5baef2a591e \ + --hash=sha256:a66cbd1b19209d3fbc45cbea80de92514ba455434013937251d65d444779783c \ + --hash=sha256:c25edad45e8cb9e76366f7a8c835279f9169028d610f3b52ce92d332a1b05438 \ + --hash=sha256:dd71a7e7c3270239f4185915e8f2c5d39608c5e18973d6e1d101b153993841eb \ + --hash=sha256:e5f169f8f5130ab255bbe854c5f0ae152e93d3d1ac44f42cb1866003b81a5357 + # via -r build/nonfreethreading-requirements.txt +tensorstore==0.1.80 \ + --hash=sha256:04c29d979eb8b8ee48f873dc13d2701bfd49425500ffc5b848e4ec55b2548281 \ + --hash=sha256:07e4a84bacf70b78305831897068a9b5ad30326e63bbeb92c4bf7e565fcf5e9e \ + --hash=sha256:1113a6982fc0fa8dda8fcc0495715e647ac3360909a86ff13f2e04564f82d54a \ + --hash=sha256:189d924eaec394c9331e284a9c513ed583e336472a925823b5151cb26f41d091 \ + --hash=sha256:1b2b2ed0051dfab7e25295b14e6620520729e6e2ddf505f98c8d3917569614bf \ + --hash=sha256:246641a8780ee5e04e88bc95c8e31faac6471bab1180d1f5cdc9804b29a77c04 \ + --hash=sha256:4158fe76b96f62d12a37d7868150d836e089b5280b2bdd363c43c5d651f10e26 \ + --hash=sha256:46136fe42ee6dd835d957db37073058aea0b78fdfbe2975941640131b7740824 \ + --hash=sha256:4baee67fce95f29f593fbab4866119347115eaace887732aa92cfcbb9e6b0748 \ + --hash=sha256:53fd121ccd332bc4cc397f7af45889360c668b43dc3ff6bc3264df0f9886c11a \ + --hash=sha256:6b7c5dd434bba4ee08fe46bbbdb25c60dd3d47ccb4b8561a9751cf1526da52b8 \ + --hash=sha256:6c8dbbdd31cbb28eccfb23dbbd4218fe67bfc32e9cb452875a485b81031c949d \ + --hash=sha256:7451b30f99d9f31a2b9d70e6ef61815713dc782c58c6d817f91781341e4dac05 \ + --hash=sha256:8cd11027b5a8b66db8d344085a31a1666c78621dac27039c4d571bc4974804a1 \ + --hash=sha256:9c088e8c9f67c266ef4dae3703bd617f7c0cb0fd98e99c4500692e38a4328140 \ + --hash=sha256:a92505189731fcb03f1c69a84ea4460abb24204bfac1f339448a0621e7def77c \ + --hash=sha256:acb8d52fadcefafef4ef8ecca3fc99b1d0e3c5c5a888766484c3e39f050be7f5 \ + --hash=sha256:b193a7a1c4f455a61e60ed2dd67271a3daab0910ddb4bd9db51390d1b36d9996 \ + --hash=sha256:bc28a58c580253a526a4b6d239d18181ef96f1e285a502dbb03ff15eeec07a5b \ + --hash=sha256:c0529afab3800749dd245843d3bf0d061a109a8edb77fb345f476e8bccda51b8 \ + --hash=sha256:d2b353b0bd53fedd77fc5a12a1c1a91cacc3cf59e3dd785529c5a54b31d1c7b1 \ + --hash=sha256:de63843706fdfe9565a45567238c5b1e55a0b28bbde6524200b31d29043a9a16 \ + --hash=sha256:e93df6d34ff5f0f6be245f4d29b99a7c1eef8ad91b50686adf57a5eeea99cb74 \ + --hash=sha256:f65dfaf9e737a41389e29a5a2ea52ca5d14c8d6f48b402c723d800cd16d322b0 \ + --hash=sha256:f8b51d7e685bbb63f6becd7d2ac8634d5ab67ec7e53038e597182e2db2c7aa90 + # via -r build/nonfreethreading-requirements.txt +termcolor==3.2.0 \ + --hash=sha256:610e6456feec42c4bcd28934a8c87a06c3fa28b01561d46aa09a9881b8622c58 \ + --hash=sha256:a10343879eba4da819353c55cb8049b0933890c2ebf9ad5d3ecd2bb32ea96ea6 + # via tensorflow +typing-extensions==4.15.0 \ + --hash=sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466 \ + --hash=sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548 + # via + # etils + # grpcio + # optree + # tensorflow +urllib3==2.6.2 \ + --hash=sha256:016f9c98bb7e98085cb2b4b17b87d2c702975664e4f060c6532e64d1c1a5e797 \ + --hash=sha256:ec21cddfe7724fc7cb4ba4bea7aa8e2ef36f607a4bab81aa6ce42a13dc3f03dd + # via requests +werkzeug==3.1.4 \ + --hash=sha256:2ad50fb9ed09cc3af22c54698351027ace879a0b60a3b5edf5730b2f7d876905 \ + --hash=sha256:cd3cd98b1b92dc3b7b3995038826c68097dcb16f9baa63abe35f20eafeb9fe5e + # via tensorboard +wheel==0.46.1 \ + --hash=sha256:f796f65d72750ccde090663e466d0ca37cd72b62870f7520b96d34cdc07d86d8 \ + --hash=sha256:fd477efb5da0f7df1d3c76c73c14394002c844451bd63229d8570f376f5e6a38 + # via + # -r build/requirements.in + # astunparse +wrapt==2.0.1 \ + --hash=sha256:09c7476ab884b74dce081ad9bfd07fe5822d8600abade571cb1f66d5fc915af6 \ + --hash=sha256:0e17283f533a0d24d6e5429a7d11f250a58d28b4ae5186f8f47853e3e70d2590 \ + --hash=sha256:115cae4beed3542e37866469a8a1f2b9ec549b4463572b000611e9946b86e6f6 \ + --hash=sha256:1218573502a8235bb8a7ecaed12736213b22dcde9feab115fa2989d42b5ded45 \ + --hash=sha256:17fb85fa4abc26a5184d93b3efd2dcc14deb4b09edcdb3535a536ad34f0b4dba \ + --hash=sha256:1e9b121e9aeb15df416c2c960b8255a49d44b4038016ee17af03975992d03931 \ + --hash=sha256:1f186e26ea0a55f809f232e92cc8556a0977e00183c3ebda039a807a42be1494 \ + --hash=sha256:1fdbb34da15450f2b1d735a0e969c24bdb8d8924892380126e2a293d9902078c \ + --hash=sha256:23097ed8bc4c93b7bf36fa2113c6c733c976316ce0ee2c816f64ca06102034ef \ + --hash=sha256:2879af909312d0baf35f08edeea918ee3af7ab57c37fe47cb6a373c9f2749c7b \ + --hash=sha256:2afa23318136709c4b23d87d543b425c399887b4057936cd20386d5b1422b6fa \ + --hash=sha256:2da620b31a90cdefa9cd0c2b661882329e2e19d1d7b9b920189956b76c564d75 \ + --hash=sha256:35cdbd478607036fee40273be8ed54a451f5f23121bd9d4be515158f9498f7ad \ + --hash=sha256:36982b26f190f4d737f04a492a68accbfc6fa042c3f42326fdfbb6c5b7a20a31 \ + --hash=sha256:3793ac154afb0e5b45d1233cb94d354ef7a983708cc3bb12563853b1d8d53747 \ + --hash=sha256:386fb54d9cd903ee0012c09291336469eb7b244f7183d40dc3e86a16a4bace62 \ + --hash=sha256:3cd1a4bd9a7a619922a8557e1318232e7269b5fb69d4ba97b04d20450a6bf970 \ + --hash=sha256:3d32794fe940b7000f0519904e247f902f0149edbe6316c710a8562fb6738841 \ + --hash=sha256:3d366aa598d69416b5afedf1faa539fac40c1d80a42f6b236c88c73a3c8f2d41 \ + --hash=sha256:3e271346f01e9c8b1130a6a3b0e11908049fe5be2d365a5f402778049147e7e9 \ + --hash=sha256:3f373a4ab5dbc528a94334f9fe444395b23c2f5332adab9ff4ea82f5a9e33bc1 \ + --hash=sha256:3fa272ca34332581e00bf7773e993d4f632594eb2d1b0b162a9038df0fd971dd \ + --hash=sha256:47434236c396d04875180171ee1f3815ca1eada05e24a1ee99546320d54d1d1b \ + --hash=sha256:47b0f8bafe90f7736151f61482c583c86b0693d80f075a58701dd1549b0010a9 \ + --hash=sha256:4811e15d88ee62dbf5c77f2c3ff3932b1e3ac92323ba3912f51fc4016ce81ecf \ + --hash=sha256:49989061a9977a8cbd6d20f2efa813f24bf657c6990a42967019ce779a878dbf \ + --hash=sha256:4ae879acc449caa9ed43fc36ba08392b9412ee67941748d31d94e3cedb36628c \ + --hash=sha256:4b55cacc57e1dc2d0991dbe74c6419ffd415fb66474a02335cb10efd1aa3f84f \ + --hash=sha256:4d2ce1bf1a48c5277d7969259232b57645aae5686dba1eaeade39442277afbca \ + --hash=sha256:4da7384b0e5d4cae05c97cd6f94faaf78cc8b0f791fc63af43436d98c4ab37bb \ + --hash=sha256:4e54bbf554ee29fcceee24fa41c4d091398b911da6e7f5d7bffda963c9aed2e1 \ + --hash=sha256:50844efc8cdf63b2d90cd3d62d4947a28311e6266ce5235a219d21b195b4ec2c \ + --hash=sha256:5a4939eae35db6b6cec8e7aa0e833dcca0acad8231672c26c2a9ab7a0f8ac9c8 \ + --hash=sha256:5dc1b852337c6792aa111ca8becff5bacf576bf4a0255b0f05eb749da6a1643e \ + --hash=sha256:5e53b428f65ece6d9dad23cb87e64506392b720a0b45076c05354d27a13351a1 \ + --hash=sha256:61c4956171c7434634401db448371277d07032a81cc21c599c22953374781395 \ + --hash=sha256:641e94e789b5f6b4822bb8d8ebbdfc10f4e4eae7756d648b717d980f657a9eb9 \ + --hash=sha256:64b103acdaa53b7caf409e8d45d39a8442fe6dcfec6ba3f3d141e0cc2b5b4dbd \ + --hash=sha256:68424221a2dc00d634b54f92441914929c5ffb1c30b3b837343978343a3512a3 \ + --hash=sha256:6bd1a18f5a797fe740cb3d7a0e853a8ce6461cc62023b630caec80171a6b8097 \ + --hash=sha256:6c72328f668cf4c503ffcf9434c2b71fdd624345ced7941bc6693e61bbe36bef \ + --hash=sha256:6d2d947d266d99a1477cd005b23cbd09465276e302515e122df56bb9511aca1b \ + --hash=sha256:7164a55f5e83a9a0b031d3ffab4d4e36bbec42e7025db560f225489fa929e509 \ + --hash=sha256:7b219cb2182f230676308cdcacd428fa837987b89e4b7c5c9025088b8a6c9faf \ + --hash=sha256:7d539241e87b650cbc4c3ac9f32c8d1ac8a54e510f6dca3f6ab60dcfd48c9b10 \ + --hash=sha256:7de3cc939be0e1174969f943f3b44e0d79b6f9a82198133a5b7fc6cc92882f16 \ + --hash=sha256:8330b42d769965e96e01fa14034b28a2a7600fbf7e8f0cc90ebb36d492c993e4 \ + --hash=sha256:837e31620e06b16030b1d126ed78e9383815cbac914693f54926d816d35d8edf \ + --hash=sha256:83ce30937f0ba0d28818807b303a412440c4b63e39d3d8fc036a94764b728c92 \ + --hash=sha256:85df8d92158cb8f3965aecc27cf821461bb5f40b450b03facc5d9f0d4d6ddec6 \ + --hash=sha256:8639b843c9efd84675f1e100ed9e99538ebea7297b62c4b45a7042edb84db03e \ + --hash=sha256:89a82053b193837bf93c0f8a57ded6e4b6d88033a499dadff5067e912c2a41e9 \ + --hash=sha256:8bacfe6e001749a3b64db47bcf0341da757c95959f592823a93931a422395013 \ + --hash=sha256:8ec3303e8a81932171f455f792f8df500fc1a09f20069e5c16bd7049ab4e8e38 \ + --hash=sha256:90897ea1cf0679763b62e79657958cd54eae5659f6360fc7d2ccc6f906342183 \ + --hash=sha256:908f8c6c71557f4deaa280f55d0728c3bca0960e8c3dd5ceeeafb3c19942719d \ + --hash=sha256:91bcc576260a274b169c3098e9a3519fb01f2989f6d3d386ef9cbf8653de1374 \ + --hash=sha256:9219a1d946a9b32bb23ccae66bdb61e35c62773ce7ca6509ceea70f344656b7b \ + --hash=sha256:949520bccc1fa227274da7d03bf238be15389cd94e32e4297b92337df9b7a349 \ + --hash=sha256:98d873ed6c8b4ee2418f7afce666751854d6d03e3c0ec2a399bb039cd2ae89db \ + --hash=sha256:9c9c635e78497cacb81e84f8b11b23e0aacac7a136e73b8e5b2109a1d9fc468f \ + --hash=sha256:9ca66b38dd642bf90c59b6738af8070747b610115a39af2498535f62b5cdc1c3 \ + --hash=sha256:a453257f19c31b31ba593c30d997d6e5be39e3b5ad9148c2af5a7314061c63eb \ + --hash=sha256:a52f93d95c8d38fed0669da2ebdb0b0376e895d84596a976c15a9eb45e3eccb3 \ + --hash=sha256:a9a83618c4f0757557c077ef71d708ddd9847ed66b7cc63416632af70d3e2308 \ + --hash=sha256:ab594f346517010050126fcd822697b25a7031d815bb4fbc238ccbe568216489 \ + --hash=sha256:ad3ee9d0f254851c71780966eb417ef8e72117155cff04821ab9b60549694a55 \ + --hash=sha256:aea9c7224c302bc8bfc892b908537f56c430802560e827b75ecbde81b604598b \ + --hash=sha256:b4c2e3d777e38e913b8ce3a6257af72fb608f86a1df471cb1d4339755d0a807c \ + --hash=sha256:b667189cf8efe008f55bbda321890bef628a67ab4147ebf90d182f2dadc78790 \ + --hash=sha256:b89ef9223d665ab255ae42cc282d27d69704d94be0deffc8b9d919179a609684 \ + --hash=sha256:be9e84e91d6497ba62594158d3d31ec0486c60055c49179edc51ee43d095f79c \ + --hash=sha256:bf4cb76f36be5de950ce13e22e7fdf462b35b04665a12b64f3ac5c1bbbcf3728 \ + --hash=sha256:bfb5539005259f8127ea9c885bdc231978c06b7a980e63a8a61c8c4c979719d0 \ + --hash=sha256:c046781d422f0830de6329fa4b16796096f28a92c8aef3850674442cdcb87b7f \ + --hash=sha256:c1be685ac7700c966b8610ccc63c3187a72e33cab53526a27b2a285a662cd4f7 \ + --hash=sha256:c1c91405fcf1d501fa5d55df21e58ea49e6b879ae829f1039faaf7e5e509b41e \ + --hash=sha256:c235095d6d090aa903f1db61f892fffb779c1eaeb2a50e566b52001f7a0f66ed \ + --hash=sha256:c4012a2bd37059d04f8209916aa771dfb564cccb86079072bdcd48a308b6a5c5 \ + --hash=sha256:c5ef2f2b8a53b7caee2f797ef166a390fef73979b15778a4a153e4b5fedce8fa \ + --hash=sha256:c654eafb01afac55246053d67a4b9a984a3567c3808bb7df2f8de1c1caba2e1c \ + --hash=sha256:c8d60527d1ecfc131426b10d93ab5d53e08a09c5fa0175f6b21b3252080c70a9 \ + --hash=sha256:c9e850f5b7fc67af856ff054c71690d54fa940c3ef74209ad9f935b4f66a0233 \ + --hash=sha256:cbeb0971e13b4bd81d34169ed57a6dda017328d1a22b62fda45e1d21dd06148f \ + --hash=sha256:d1a8a09a004ef100e614beec82862d11fc17d601092c3599afd22b1f36e4137e \ + --hash=sha256:d67956c676be5a24102c7407a71f4126d30de2a569a1c7871c9f3cabc94225d7 \ + --hash=sha256:d6cc985b9c8b235bd933990cdbf0f891f8e010b65a3911f7a55179cd7b0fc57b \ + --hash=sha256:d7b822c61ed04ee6ad64bc90d13368ad6eb094db54883b5dde2182f67a7f22c0 \ + --hash=sha256:df0b6d3b95932809c5b3fecc18fda0f1e07452d05e2662a0b35548985f256e28 \ + --hash=sha256:e042d653a4745be832d5aa190ff80ee4f02c34b21f4b785745eceacd0907b815 \ + --hash=sha256:e2f84e9af2060e3904a32cea9bb6db23ce3f91cfd90c6b426757cf7cc01c45c7 \ + --hash=sha256:e3612dc06b436968dfb9142c62e5dfa9eb5924f91120b3c8ff501ad878f90eb3 \ + --hash=sha256:e505629359cb5f751e16e30cf3f91a1d3ddb4552480c205947da415d597f7ac2 \ + --hash=sha256:e60690ba71a57424c8d9ff28f8d006b7ad7772c22a4af432188572cd7fa004a1 \ + --hash=sha256:e76e3f91f864e89db8b8d2a8311d57df93f01ad6bb1e9b9976d1f2e83e18315c \ + --hash=sha256:eb7cffe572ad0a141a7886a1d2efa5bef0bf7fe021deeea76b3ab334d2c38218 \ + --hash=sha256:ec65a78fbd9d6f083a15d7613b2800d5663dbb6bb96003899c834beaa68b242c \ + --hash=sha256:eda8e4ecd662d48c28bb86be9e837c13e45c58b8300e43ba3c9b4fa9900302f7 \ + --hash=sha256:f26f8e2ca19564e2e1fdbb6a0e47f36e0efbab1acc31e15471fad88f828c75f6 \ + --hash=sha256:f49027b0b9503bf6c8cdc297ca55006b80c2f5dd36cecc72c6835ab6e10e8a25 \ + --hash=sha256:f73f9f7a0ebd0db139253d27e5fc8d2866ceaeef19c30ab5d69dcbe35e1a6981 \ + --hash=sha256:fa4184e74197af3adad3c889a1af95b53bb0466bced92ea99a0c014e48323eec \ + --hash=sha256:fb1a5b72cbd751813adc02ef01ada0b0d05d3dcbc32976ce189a1279d80ad4a2 \ + --hash=sha256:fb3a86e703868561c5cad155a15c36c716e1ab513b7065bd2ac8ed353c503333 \ + --hash=sha256:fc007fdf480c77301ab1afdbb6ab22a5deee8885f3b1ed7afcb7e5e84a0e27be \ + --hash=sha256:fe21b118b9f58859b5ebaa4b130dee18669df4bd111daad082b7beb8799ad16b \ + --hash=sha256:fec0d993ecba3991645b4857837277469c8cc4c554a7e24d064d1ca291cfb81f + # via tensorflow +zipp==3.23.0 \ + --hash=sha256:071652d6115ed432f5ce1d34c336c0adfd6a884660d1e9712a256d3d3bd4b14e \ + --hash=sha256:a07157588a12518c9d4034df3fbbee09c814741a33ff63c05fa29d26a2404166 # via etils -zstandard==0.23.0 \ - --hash=sha256:034b88913ecc1b097f528e42b539453fa82c3557e414b3de9d5632c80439a473 \ - --hash=sha256:0a7f0804bb3799414af278e9ad51be25edf67f78f916e08afdb983e74161b916 \ - --hash=sha256:11e3bf3c924853a2d5835b24f03eeba7fc9b07d8ca499e247e06ff5676461a15 \ - --hash=sha256:12a289832e520c6bd4dcaad68e944b86da3bad0d339ef7989fb7e88f92e96072 \ - --hash=sha256:1516c8c37d3a053b01c1c15b182f3b5f5eef19ced9b930b684a73bad121addf4 \ - --hash=sha256:157e89ceb4054029a289fb504c98c6a9fe8010f1680de0201b3eb5dc20aa6d9e \ - --hash=sha256:1bfe8de1da6d104f15a60d4a8a768288f66aa953bbe00d027398b93fb9680b26 \ - --hash=sha256:1e172f57cd78c20f13a3415cc8dfe24bf388614324d25539146594c16d78fcc8 \ - --hash=sha256:1fd7e0f1cfb70eb2f95a19b472ee7ad6d9a0a992ec0ae53286870c104ca939e5 \ - --hash=sha256:203d236f4c94cd8379d1ea61db2fce20730b4c38d7f1c34506a31b34edc87bdd \ - --hash=sha256:27d3ef2252d2e62476389ca8f9b0cf2bbafb082a3b6bfe9d90cbcbb5529ecf7c \ - --hash=sha256:29a2bc7c1b09b0af938b7a8343174b987ae021705acabcbae560166567f5a8db \ - --hash=sha256:2ef230a8fd217a2015bc91b74f6b3b7d6522ba48be29ad4ea0ca3a3775bf7dd5 \ - --hash=sha256:2ef3775758346d9ac6214123887d25c7061c92afe1f2b354f9388e9e4d48acfc \ - --hash=sha256:2f146f50723defec2975fb7e388ae3a024eb7151542d1599527ec2aa9cacb152 \ - --hash=sha256:2fb4535137de7e244c230e24f9d1ec194f61721c86ebea04e1581d9d06ea1269 \ - --hash=sha256:32ba3b5ccde2d581b1e6aa952c836a6291e8435d788f656fe5976445865ae045 \ - --hash=sha256:34895a41273ad33347b2fc70e1bff4240556de3c46c6ea430a7ed91f9042aa4e \ - --hash=sha256:379b378ae694ba78cef921581ebd420c938936a153ded602c4fea612b7eaa90d \ - --hash=sha256:38302b78a850ff82656beaddeb0bb989a0322a8bbb1bf1ab10c17506681d772a \ - --hash=sha256:3aa014d55c3af933c1315eb4bb06dd0459661cc0b15cd61077afa6489bec63bb \ - --hash=sha256:4051e406288b8cdbb993798b9a45c59a4896b6ecee2f875424ec10276a895740 \ - --hash=sha256:40b33d93c6eddf02d2c19f5773196068d875c41ca25730e8288e9b672897c105 \ - --hash=sha256:43da0f0092281bf501f9c5f6f3b4c975a8a0ea82de49ba3f7100e64d422a1274 \ - --hash=sha256:445e4cb5048b04e90ce96a79b4b63140e3f4ab5f662321975679b5f6360b90e2 \ - --hash=sha256:48ef6a43b1846f6025dde6ed9fee0c24e1149c1c25f7fb0a0585572b2f3adc58 \ - --hash=sha256:50a80baba0285386f97ea36239855f6020ce452456605f262b2d33ac35c7770b \ - --hash=sha256:519fbf169dfac1222a76ba8861ef4ac7f0530c35dd79ba5727014613f91613d4 \ - --hash=sha256:53dd9d5e3d29f95acd5de6802e909ada8d8d8cfa37a3ac64836f3bc4bc5512db \ - --hash=sha256:53ea7cdc96c6eb56e76bb06894bcfb5dfa93b7adcf59d61c6b92674e24e2dd5e \ - --hash=sha256:576856e8594e6649aee06ddbfc738fec6a834f7c85bf7cadd1c53d4a58186ef9 \ - --hash=sha256:59556bf80a7094d0cfb9f5e50bb2db27fefb75d5138bb16fb052b61b0e0eeeb0 \ - --hash=sha256:5d41d5e025f1e0bccae4928981e71b2334c60f580bdc8345f824e7c0a4c2a813 \ - --hash=sha256:61062387ad820c654b6a6b5f0b94484fa19515e0c5116faf29f41a6bc91ded6e \ - --hash=sha256:61f89436cbfede4bc4e91b4397eaa3e2108ebe96d05e93d6ccc95ab5714be512 \ - --hash=sha256:62136da96a973bd2557f06ddd4e8e807f9e13cbb0bfb9cc06cfe6d98ea90dfe0 \ - --hash=sha256:64585e1dba664dc67c7cdabd56c1e5685233fbb1fc1966cfba2a340ec0dfff7b \ - --hash=sha256:65308f4b4890aa12d9b6ad9f2844b7ee42c7f7a4fd3390425b242ffc57498f48 \ - --hash=sha256:66b689c107857eceabf2cf3d3fc699c3c0fe8ccd18df2219d978c0283e4c508a \ - --hash=sha256:6a41c120c3dbc0d81a8e8adc73312d668cd34acd7725f036992b1b72d22c1772 \ - --hash=sha256:6f77fa49079891a4aab203d0b1744acc85577ed16d767b52fc089d83faf8d8ed \ - --hash=sha256:72c68dda124a1a138340fb62fa21b9bf4848437d9ca60bd35db36f2d3345f373 \ - --hash=sha256:752bf8a74412b9892f4e5b58f2f890a039f57037f52c89a740757ebd807f33ea \ - --hash=sha256:76e79bc28a65f467e0409098fa2c4376931fd3207fbeb6b956c7c476d53746dd \ - --hash=sha256:774d45b1fac1461f48698a9d4b5fa19a69d47ece02fa469825b442263f04021f \ - --hash=sha256:77da4c6bfa20dd5ea25cbf12c76f181a8e8cd7ea231c673828d0386b1740b8dc \ - --hash=sha256:77ea385f7dd5b5676d7fd943292ffa18fbf5c72ba98f7d09fc1fb9e819b34c23 \ - --hash=sha256:80080816b4f52a9d886e67f1f96912891074903238fe54f2de8b786f86baded2 \ - --hash=sha256:80a539906390591dd39ebb8d773771dc4db82ace6372c4d41e2d293f8e32b8db \ - --hash=sha256:82d17e94d735c99621bf8ebf9995f870a6b3e6d14543b99e201ae046dfe7de70 \ - --hash=sha256:837bb6764be6919963ef41235fd56a6486b132ea64afe5fafb4cb279ac44f259 \ - --hash=sha256:84433dddea68571a6d6bd4fbf8ff398236031149116a7fff6f777ff95cad3df9 \ - --hash=sha256:8c24f21fa2af4bb9f2c492a86fe0c34e6d2c63812a839590edaf177b7398f700 \ - --hash=sha256:8ed7d27cb56b3e058d3cf684d7200703bcae623e1dcc06ed1e18ecda39fee003 \ - --hash=sha256:9206649ec587e6b02bd124fb7799b86cddec350f6f6c14bc82a2b70183e708ba \ - --hash=sha256:983b6efd649723474f29ed42e1467f90a35a74793437d0bc64a5bf482bedfa0a \ - --hash=sha256:98da17ce9cbf3bfe4617e836d561e433f871129e3a7ac16d6ef4c680f13a839c \ - --hash=sha256:9c236e635582742fee16603042553d276cca506e824fa2e6489db04039521e90 \ - --hash=sha256:9da6bc32faac9a293ddfdcb9108d4b20416219461e4ec64dfea8383cac186690 \ - --hash=sha256:a05e6d6218461eb1b4771d973728f0133b2a4613a6779995df557f70794fd60f \ - --hash=sha256:a0817825b900fcd43ac5d05b8b3079937073d2b1ff9cf89427590718b70dd840 \ - --hash=sha256:a4ae99c57668ca1e78597d8b06d5af837f377f340f4cce993b551b2d7731778d \ - --hash=sha256:a8c86881813a78a6f4508ef9daf9d4995b8ac2d147dcb1a450448941398091c9 \ - --hash=sha256:a8fffdbd9d1408006baaf02f1068d7dd1f016c6bcb7538682622c556e7b68e35 \ - --hash=sha256:a9b07268d0c3ca5c170a385a0ab9fb7fdd9f5fd866be004c4ea39e44edce47dd \ - --hash=sha256:ab19a2d91963ed9e42b4e8d77cd847ae8381576585bad79dbd0a8837a9f6620a \ - --hash=sha256:ac184f87ff521f4840e6ea0b10c0ec90c6b1dcd0bad2f1e4a9a1b4fa177982ea \ - --hash=sha256:b0e166f698c5a3e914947388c162be2583e0c638a4703fc6a543e23a88dea3c1 \ - --hash=sha256:b2170c7e0367dde86a2647ed5b6f57394ea7f53545746104c6b09fc1f4223573 \ - --hash=sha256:b2d8c62d08e7255f68f7a740bae85b3c9b8e5466baa9cbf7f57f1cde0ac6bc09 \ - --hash=sha256:b4567955a6bc1b20e9c31612e615af6b53733491aeaa19a6b3b37f3b65477094 \ - --hash=sha256:b69bb4f51daf461b15e7b3db033160937d3ff88303a7bc808c67bbc1eaf98c78 \ - --hash=sha256:b8c0bd73aeac689beacd4e7667d48c299f61b959475cdbb91e7d3d88d27c56b9 \ - --hash=sha256:be9b5b8659dff1f913039c2feee1aca499cfbc19e98fa12bc85e037c17ec6ca5 \ - --hash=sha256:bf0a05b6059c0528477fba9054d09179beb63744355cab9f38059548fedd46a9 \ - --hash=sha256:c16842b846a8d2a145223f520b7e18b57c8f476924bda92aeee3a88d11cfc391 \ - --hash=sha256:c363b53e257246a954ebc7c488304b5592b9c53fbe74d03bc1c64dda153fb847 \ - --hash=sha256:c7c517d74bea1a6afd39aa612fa025e6b8011982a0897768a2f7c8ab4ebb78a2 \ - --hash=sha256:d20fd853fbb5807c8e84c136c278827b6167ded66c72ec6f9a14b863d809211c \ - --hash=sha256:d2240ddc86b74966c34554c49d00eaafa8200a18d3a5b6ffbf7da63b11d74ee2 \ - --hash=sha256:d477ed829077cd945b01fc3115edd132c47e6540ddcd96ca169facff28173057 \ - --hash=sha256:d50d31bfedd53a928fed6707b15a8dbeef011bb6366297cc435accc888b27c20 \ - --hash=sha256:dc1d33abb8a0d754ea4763bad944fd965d3d95b5baef6b121c0c9013eaf1907d \ - --hash=sha256:dc5d1a49d3f8262be192589a4b72f0d03b72dcf46c51ad5852a4fdc67be7b9e4 \ - --hash=sha256:e2d1a054f8f0a191004675755448d12be47fa9bebbcffa3cdf01db19f2d30a54 \ - --hash=sha256:e7792606d606c8df5277c32ccb58f29b9b8603bf83b48639b7aedf6df4fe8171 \ - --hash=sha256:ed1708dbf4d2e3a1c5c69110ba2b4eb6678262028afd6c6fbcc5a8dac9cda68e \ - --hash=sha256:f2d4380bf5f62daabd7b751ea2339c1a21d1c9463f1feb7fc2bdcea2c29c3160 \ - --hash=sha256:f3513916e8c645d0610815c257cbfd3242adfd5c4cfa78be514e5a3ebb42a41b \ - --hash=sha256:f8346bfa098532bc1fb6c7ef06783e969d87a99dd1d2a5a18a892c1d7a643c58 \ - --hash=sha256:f83fa6cae3fff8e98691248c9320356971b59678a17f20656a9e59cd32cee6d8 \ - --hash=sha256:fa6ce8b52c5987b3e34d5674b0ab529a4602b632ebab0a93b07bfb4dfc8f8a33 \ - --hash=sha256:fb2b1ecfef1e67897d336de3a0e3f52478182d6a47eda86cbd42504c5cbd009a \ - --hash=sha256:fc9ca1c9718cb3b06634c7c8dec57d24e9438b2aa9a0f02b8bb36bf478538880 \ - --hash=sha256:fd30d9c67d13d891f2360b2a120186729c111238ac63b43dbd37a5a40670b8ca \ - --hash=sha256:fd7699e8fd9969f455ef2926221e0233f81a2542921471382e77a9e2f2b57f4b \ - --hash=sha256:fe3b385d996ee0822fd46528d9f0443b880d4d05528fd26a9119a54ec3f91c69 - # via -r build/requirements.in +zstandard==0.25.0 ; python_version < "3.14" \ + --hash=sha256:011d388c76b11a0c165374ce660ce2c8efa8e5d87f34996aa80f9c0816698b64 \ + --hash=sha256:01582723b3ccd6939ab7b3a78622c573799d5d8737b534b86d0e06ac18dbde4a \ + --hash=sha256:05353cef599a7b0b98baca9b068dd36810c3ef0f42bf282583f438caf6ddcee3 \ + --hash=sha256:05df5136bc5a011f33cd25bc9f506e7426c0c9b3f9954f056831ce68f3b6689f \ + --hash=sha256:06acb75eebeedb77b69048031282737717a63e71e4ae3f77cc0c3b9508320df6 \ + --hash=sha256:07b527a69c1e1c8b5ab1ab14e2afe0675614a09182213f21a0717b62027b5936 \ + --hash=sha256:0bbc9a0c65ce0eea3c34a691e3c4b6889f5f3909ba4822ab385fab9057099431 \ + --hash=sha256:0be7622c37c183406f3dbf0cba104118eb16a4ea7359eeb5752f0794882fc250 \ + --hash=sha256:106281ae350e494f4ac8a80470e66d1fe27e497052c8d9c3b95dc4cf1ade81aa \ + --hash=sha256:10ef2a79ab8e2974e2075fb984e5b9806c64134810fac21576f0668e7ea19f8f \ + --hash=sha256:1673b7199bbe763365b81a4f3252b8e80f44c9e323fc42940dc8843bfeaf9851 \ + --hash=sha256:172de1f06947577d3a3005416977cce6168f2261284c02080e7ad0185faeced3 \ + --hash=sha256:181eb40e0b6a29b3cd2849f825e0fa34397f649170673d385f3598ae17cca2e9 \ + --hash=sha256:1869da9571d5e94a85a5e8d57e4e8807b175c9e4a6294e3b66fa4efb074d90f6 \ + --hash=sha256:19796b39075201d51d5f5f790bf849221e58b48a39a5fc74837675d8bafc7362 \ + --hash=sha256:1cd5da4d8e8ee0e88be976c294db744773459d51bb32f707a0f166e5ad5c8649 \ + --hash=sha256:1f3689581a72eaba9131b1d9bdbfe520ccd169999219b41000ede2fca5c1bfdb \ + --hash=sha256:1f830a0dac88719af0ae43b8b2d6aef487d437036468ef3c2ea59c51f9d55fd5 \ + --hash=sha256:223415140608d0f0da010499eaa8ccdb9af210a543fac54bce15babbcfc78439 \ + --hash=sha256:22a06c5df3751bb7dc67406f5374734ccee8ed37fc5981bf1ad7041831fa1137 \ + --hash=sha256:22a086cff1b6ceca18a8dd6096ec631e430e93a8e70a9ca5efa7561a00f826fa \ + --hash=sha256:23ebc8f17a03133b4426bcc04aabd68f8236eb78c3760f12783385171b0fd8bd \ + --hash=sha256:25f8f3cd45087d089aef5ba3848cd9efe3ad41163d3400862fb42f81a3a46701 \ + --hash=sha256:2b6bd67528ee8b5c5f10255735abc21aa106931f0dbaf297c7be0c886353c3d0 \ + --hash=sha256:2e54296a283f3ab5a26fc9b8b5d4978ea0532f37b231644f367aa588930aa043 \ + --hash=sha256:3756b3e9da9b83da1796f8809dd57cb024f838b9eeafde28f3cb472012797ac1 \ + --hash=sha256:37daddd452c0ffb65da00620afb8e17abd4adaae6ce6310702841760c2c26860 \ + --hash=sha256:3a39c94ad7866160a4a46d772e43311a743c316942037671beb264e395bdd611 \ + --hash=sha256:3b870ce5a02d4b22286cf4944c628e0f0881b11b3f14667c1d62185a99e04f53 \ + --hash=sha256:3c83b0188c852a47cd13ef3bf9209fb0a77fa5374958b8c53aaa699398c6bd7b \ + --hash=sha256:4203ce3b31aec23012d3a4cf4a2ed64d12fea5269c49aed5e4c3611b938e4088 \ + --hash=sha256:457ed498fc58cdc12fc48f7950e02740d4f7ae9493dd4ab2168a47c93c31298e \ + --hash=sha256:474d2596a2dbc241a556e965fb76002c1ce655445e4e3bf38e5477d413165ffa \ + --hash=sha256:4b14abacf83dfb5c25eb4e4a79520de9e7e205f72c9ee7702f91233ae57d33a2 \ + --hash=sha256:4b6d83057e713ff235a12e73916b6d356e3084fd3d14ced499d84240f3eecee0 \ + --hash=sha256:4d441506e9b372386a5271c64125f72d5df6d2a8e8a2a45a0ae09b03cb781ef7 \ + --hash=sha256:4f187a0bb61b35119d1926aee039524d1f93aaf38a9916b8c4b78ac8514a0aaf \ + --hash=sha256:51526324f1b23229001eb3735bc8c94f9c578b1bd9e867a0a646a3b17109f388 \ + --hash=sha256:53e08b2445a6bc241261fea89d065536f00a581f02535f8122eba42db9375530 \ + --hash=sha256:53f94448fe5b10ee75d246497168e5825135d54325458c4bfffbaafabcc0a577 \ + --hash=sha256:5a56ba0db2d244117ed744dfa8f6f5b366e14148e00de44723413b2f3938a902 \ + --hash=sha256:5f1ad7bf88535edcf30038f6919abe087f606f62c00a87d7e33e7fc57cb69fcc \ + --hash=sha256:5f5e4c2a23ca271c218ac025bd7d635597048b366d6f31f420aaeb715239fc98 \ + --hash=sha256:6a573a35693e03cf1d67799fd01b50ff578515a8aeadd4595d2a7fa9f3ec002a \ + --hash=sha256:6c0e5a65158a7946e7a7affa6418878ef97ab66636f13353b8502d7ea03c8097 \ + --hash=sha256:6dffecc361d079bb48d7caef5d673c88c8988d3d33fb74ab95b7ee6da42652ea \ + --hash=sha256:7030defa83eef3e51ff26f0b7bfb229f0204b66fe18e04359ce3474ac33cbc09 \ + --hash=sha256:7149623bba7fdf7e7f24312953bcf73cae103db8cae49f8154dd1eadc8a29ecb \ + --hash=sha256:72d35d7aa0bba323965da807a462b0966c91608ef3a48ba761678cb20ce5d8b7 \ + --hash=sha256:75ffc32a569fb049499e63ce68c743155477610532da1eb38e7f24bf7cd29e74 \ + --hash=sha256:7713e1179d162cf5c7906da876ec2ccb9c3a9dcbdffef0cc7f70c3667a205f0b \ + --hash=sha256:78228d8a6a1c177a96b94f7e2e8d012c55f9c760761980da16ae7546a15a8e9b \ + --hash=sha256:7b3c3a3ab9daa3eed242d6ecceead93aebbb8f5f84318d82cee643e019c4b73b \ + --hash=sha256:809c5bcb2c67cd0ed81e9229d227d4ca28f82d0f778fc5fea624a9def3963f91 \ + --hash=sha256:81dad8d145d8fd981b2962b686b2241d3a1ea07733e76a2f15435dfb7fb60150 \ + --hash=sha256:85304a43f4d513f5464ceb938aa02c1e78c2943b29f44a750b48b25ac999a049 \ + --hash=sha256:89c4b48479a43f820b749df49cd7ba2dbc2b1b78560ecb5ab52985574fd40b27 \ + --hash=sha256:8e735494da3db08694d26480f1493ad2cf86e99bdd53e8e9771b2752a5c0246a \ + --hash=sha256:913cbd31a400febff93b564a23e17c3ed2d56c064006f54efec210d586171c00 \ + --hash=sha256:9174f4ed06f790a6869b41cba05b43eeb9a35f8993c4422ab853b705e8112bbd \ + --hash=sha256:9300d02ea7c6506f00e627e287e0492a5eb0371ec1670ae852fefffa6164b072 \ + --hash=sha256:933b65d7680ea337180733cf9e87293cc5500cc0eb3fc8769f4d3c88d724ec5c \ + --hash=sha256:9654dbc012d8b06fc3d19cc825af3f7bf8ae242226df5f83936cb39f5fdc846c \ + --hash=sha256:98750a309eb2f020da61e727de7d7ba3c57c97cf6213f6f6277bb7fb42a8e065 \ + --hash=sha256:99c0c846e6e61718715a3c9437ccc625de26593fea60189567f0118dc9db7512 \ + --hash=sha256:a1a4ae2dec3993a32247995bdfe367fc3266da832d82f8438c8570f989753de1 \ + --hash=sha256:a3f79487c687b1fc69f19e487cd949bf3aae653d181dfb5fde3bf6d18894706f \ + --hash=sha256:a4089a10e598eae6393756b036e0f419e8c1d60f44a831520f9af41c14216cf2 \ + --hash=sha256:a51ff14f8017338e2f2e5dab738ce1ec3b5a851f23b18c1ae1359b1eecbee6df \ + --hash=sha256:a5a419712cf88862a45a23def0ae063686db3d324cec7edbe40509d1a79a0aab \ + --hash=sha256:a9ec8c642d1ec73287ae3e726792dd86c96f5681eb8df274a757bf62b750eae7 \ + --hash=sha256:aaf21ba8fb76d102b696781bddaa0954b782536446083ae3fdaa6f16b25a1c4b \ + --hash=sha256:ab85470ab54c2cb96e176f40342d9ed41e58ca5733be6a893b730e7af9c40550 \ + --hash=sha256:b9af1fe743828123e12b41dd8091eca1074d0c1569cc42e6e1eee98027f2bbd0 \ + --hash=sha256:bfc4e20784722098822e3eee42b8e576b379ed72cca4a7cb856ae733e62192ea \ + --hash=sha256:bfd06b1c5584b657a2892a6014c2f4c20e0db0208c159148fa78c65f7e0b0277 \ + --hash=sha256:c19bcdd826e95671065f8692b5a4aa95c52dc7a02a4c5a0cac46deb879a017a2 \ + --hash=sha256:c2ba942c94e0691467ab901fc51b6f2085ff48f2eea77b1a48240f011e8247c7 \ + --hash=sha256:c8e167d5adf59476fa3e37bee730890e389410c354771a62e3c076c86f9f7778 \ + --hash=sha256:ca54090275939dc8ec5dea2d2afb400e0f83444b2fc24e07df7fdef677110859 \ + --hash=sha256:d7541afd73985c630bafcd6338d2518ae96060075f9463d7dc14cfb33514383d \ + --hash=sha256:d8c56bb4e6c795fc77d74d8e8b80846e1fb8292fc0b5060cd8131d522974b751 \ + --hash=sha256:da469dc041701583e34de852d8634703550348d5822e66a0c827d39b05365b12 \ + --hash=sha256:daab68faadb847063d0c56f361a289c4f268706b598afbf9ad113cbe5c38b6b2 \ + --hash=sha256:e05ab82ea7753354bb054b92e2f288afb750e6b439ff6ca78af52939ebbc476d \ + --hash=sha256:e09bb6252b6476d8d56100e8147b803befa9a12cea144bbe629dd508800d1ad0 \ + --hash=sha256:e29f0cf06974c899b2c188ef7f783607dbef36da4c242eb6c82dcd8b512855e3 \ + --hash=sha256:e59fdc271772f6686e01e1b3b74537259800f57e24280be3f29c8a0deb1904dd \ + --hash=sha256:e7360eae90809efd19b886e59a09dad07da4ca9ba096752e61a2e03c8aca188e \ + --hash=sha256:e96594a5537722fdfb79951672a2a63aec5ebfb823e7560586f7484819f2a08f \ + --hash=sha256:ea9d54cc3d8064260114a0bbf3479fc4a98b21dffc89b3459edd506b69262f6e \ + --hash=sha256:ec996f12524f88e151c339688c3897194821d7f03081ab35d31d1e12ec975e94 \ + --hash=sha256:f27662e4f7dbf9f9c12391cb37b4c4c3cb90ffbd3b1fb9284dadbbb8935fa708 \ + --hash=sha256:f373da2c1757bb7f1acaf09369cdc1d51d84131e50d5fa9863982fd626466313 \ + --hash=sha256:f5aeea11ded7320a84dcdd62a3d95b5186834224a9e55b92ccae35d21a8b63d4 \ + --hash=sha256:f604efd28f239cc21b3adb53eb061e2a205dc164be408e553b41ba2ffe0ca15c \ + --hash=sha256:f67e8f1a324a900e75b5e28ffb152bcac9fbed1cc7b43f99cd90f395c4375344 \ + --hash=sha256:fd7a5004eb1980d3cefe26b2685bcb0b17989901a70a1040d1ac86f1d898c551 \ + --hash=sha256:ffef5a74088f1e09947aecf91011136665152e0b4b359c42be3373897fb39b01 + # via -r build/nonfreethreading-requirements.txt # The following packages are considered to be unsafe in a requirements file: -setuptools==76.0.0 \ - --hash=sha256:199466a166ff664970d0ee145839f5582cb9bca7a0a3a2e795b6a9cb2308e9c6 \ - --hash=sha256:43b4ee60e10b0d0ee98ad11918e114c70701bc6051662a9a675a0496c1a158f4 +setuptools==80.9.0 \ + --hash=sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922 \ + --hash=sha256:f36b47402ecde768dbfafc46e8e4207b4360c654f1f3bb84475f0a28628fb19c # via # -r build/requirements.in - # -r build/test-requirements.txt + # tensorboard + # tensorflow diff --git a/build/requirements_lock_3_13_ft.txt b/build/requirements_lock_3_13_ft.txt index e7a2968e981e..815ab33cc09a 100644 --- a/build/requirements_lock_3_13_ft.txt +++ b/build/requirements_lock_3_13_ft.txt @@ -2,642 +2,948 @@ # This file is autogenerated by pip-compile with Python 3.13 # by the following command: # -# pip-compile --allow-unsafe --generate-hashes --output-file=build/requirements_lock_3_13_ft.txt build/requirements.in +# bazel run //build:requirements_ft.update # -absl-py==2.1.0 \ - --hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \ - --hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff +--index-url https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple + +absl-py==2.3.1 \ + --hash=sha256:a97820526f7fbfd2ec1bce83f3f25e3a14840dac0d8e02a0b71cd75db3f77fc9 \ + --hash=sha256:eeecf07f0c2a93ace0772c92e596ace6d3d3996c042b2128459aaae2a76de11d # via -r build/test-requirements.txt -attrs==24.3.0 \ - --hash=sha256:8f5c07333d543103541ba7be0e2ce16eeee8130cb0b3f9238ab904ce1e85baff \ - --hash=sha256:ac96cd038792094f438ad1f6ff80837353805ac950cd2aa0e0625ef19850c308 +attrs==25.4.0 \ + --hash=sha256:16d5969b87f0859ef33a48b35d55ac1be6e42ae49d5e853b597db70c35c57e11 \ + --hash=sha256:adcf7e2a1fb3b36ac48d97835bb6d8ade15b8dcce26aba8bf1d14847b57a3373 # via hypothesis -auditwheel==6.2.0 \ - --hash=sha256:4fc9f778cd81dac56820e8cdee9842dc44b8f435f8783606dabd4964d4638b30 \ - --hash=sha256:927e0fc9ab5b6040c1240c81dd7f50924c99c3ca876a776d52e042ba374f6604 - # via -r build/test-requirements.txt -build==1.2.2.post1 \ - --hash=sha256:1d61c0887fa860c01971625baae8bdd338e517b836a2f70dd1f7aa3a6b2fc5b5 \ - --hash=sha256:b36993e92ca9375a219c99e606a122ff365a760a2d4bba0caa09bd5278b608b7 +auditwheel==6.5.0 \ + --hash=sha256:4fbcbd5854054bb1dd7870db03727b871b96b18147db57259561c058603987d7 \ + --hash=sha256:e08d2eede0259be6feff597d041c06175026e93248a1a97143acc52c57714d80 # via -r build/test-requirements.txt -cloudpickle==3.1.0 \ - --hash=sha256:81a929b6e3c7335c863c771d673d105f02efdb89dfaba0c90495d1c64796601b \ - --hash=sha256:fe11acda67f61aaaec473e3afe030feb131d78a43461b718185363384f1ba12e +build==1.3.0 \ + --hash=sha256:698edd0ea270bde950f53aed21f3a0135672206f3911e0176261a31e0e07b397 \ + --hash=sha256:7145f0b5061ba90a1500d60bd1b13ca0a8a4cebdd0cc16ed8adf1c0e739f43b4 + # via -r build/requirements.in +cloudpickle==3.1.2 \ + --hash=sha256:7fda9eb655c9c230dab534f1983763de5835249750e85fbcef43aaa30a9a2414 \ + --hash=sha256:9acb47f6afd73f60dc1df93bb801b472f05ff42fa6c84167d25cb206be1fbf4a # via -r build/test-requirements.txt colorama==0.4.6 \ --hash=sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44 \ --hash=sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6 - # via -r build/test-requirements.txt -contourpy==1.3.1 \ - --hash=sha256:041b640d4ec01922083645a94bb3b2e777e6b626788f4095cf21abbe266413c1 \ - --hash=sha256:05e806338bfeaa006acbdeba0ad681a10be63b26e1b17317bfac3c5d98f36cda \ - --hash=sha256:08d9d449a61cf53033612cb368f3a1b26cd7835d9b8cd326647efe43bca7568d \ - --hash=sha256:0ffa84be8e0bd33410b17189f7164c3589c229ce5db85798076a3fa136d0e509 \ - --hash=sha256:113231fe3825ebf6f15eaa8bc1f5b0ddc19d42b733345eae0934cb291beb88b6 \ - --hash=sha256:14c102b0eab282427b662cb590f2e9340a9d91a1c297f48729431f2dcd16e14f \ - --hash=sha256:174e758c66bbc1c8576992cec9599ce8b6672b741b5d336b5c74e35ac382b18e \ - --hash=sha256:19c1555a6801c2f084c7ddc1c6e11f02eb6a6016ca1318dd5452ba3f613a1751 \ - --hash=sha256:19d40d37c1c3a4961b4619dd9d77b12124a453cc3d02bb31a07d58ef684d3d86 \ - --hash=sha256:1bf98051f1045b15c87868dbaea84f92408337d4f81d0e449ee41920ea121d3b \ - --hash=sha256:20914c8c973f41456337652a6eeca26d2148aa96dd7ac323b74516988bea89fc \ - --hash=sha256:287ccc248c9e0d0566934e7d606201abd74761b5703d804ff3df8935f523d546 \ - --hash=sha256:2ba94a401342fc0f8b948e57d977557fbf4d515f03c67682dd5c6191cb2d16ec \ - --hash=sha256:31c1b55c1f34f80557d3830d3dd93ba722ce7e33a0b472cba0ec3b6535684d8f \ - --hash=sha256:36987a15e8ace5f58d4d5da9dca82d498c2bbb28dff6e5d04fbfcc35a9cb3a82 \ - --hash=sha256:3a04ecd68acbd77fa2d39723ceca4c3197cb2969633836ced1bea14e219d077c \ - --hash=sha256:3e8b974d8db2c5610fb4e76307e265de0edb655ae8169e8b21f41807ccbeec4b \ - --hash=sha256:3ea9924d28fc5586bf0b42d15f590b10c224117e74409dd7a0be3b62b74a501c \ - --hash=sha256:4318af1c925fb9a4fb190559ef3eec206845f63e80fb603d47f2d6d67683901c \ - --hash=sha256:44a29502ca9c7b5ba389e620d44f2fbe792b1fb5734e8b931ad307071ec58c53 \ - --hash=sha256:47734d7073fb4590b4a40122b35917cd77be5722d80683b249dac1de266aac80 \ - --hash=sha256:4d76d5993a34ef3df5181ba3c92fabb93f1eaa5729504fb03423fcd9f3177242 \ - --hash=sha256:4dbbc03a40f916a8420e420d63e96a1258d3d1b58cbdfd8d1f07b49fcbd38e85 \ - --hash=sha256:500360b77259914f7805af7462e41f9cb7ca92ad38e9f94d6c8641b089338124 \ - --hash=sha256:523a8ee12edfa36f6d2a49407f705a6ef4c5098de4f498619787e272de93f2d5 \ - --hash=sha256:573abb30e0e05bf31ed067d2f82500ecfdaec15627a59d63ea2d95714790f5c2 \ - --hash=sha256:5b75aa69cb4d6f137b36f7eb2ace9280cfb60c55dc5f61c731fdf6f037f958a3 \ - --hash=sha256:61332c87493b00091423e747ea78200659dc09bdf7fd69edd5e98cef5d3e9a8d \ - --hash=sha256:805617228ba7e2cbbfb6c503858e626ab528ac2a32a04a2fe88ffaf6b02c32bc \ - --hash=sha256:841ad858cff65c2c04bf93875e384ccb82b654574a6d7f30453a04f04af71342 \ - --hash=sha256:89785bb2a1980c1bd87f0cb1517a71cde374776a5f150936b82580ae6ead44a1 \ - --hash=sha256:8eb96e79b9f3dcadbad2a3891672f81cdcab7f95b27f28f1c67d75f045b6b4f1 \ - --hash=sha256:974d8145f8ca354498005b5b981165b74a195abfae9a8129df3e56771961d595 \ - --hash=sha256:9ddeb796389dadcd884c7eb07bd14ef12408aaae358f0e2ae24114d797eede30 \ - --hash=sha256:a045f341a77b77e1c5de31e74e966537bba9f3c4099b35bf4c2e3939dd54cdab \ - --hash=sha256:a0cffcbede75c059f535725c1680dfb17b6ba8753f0c74b14e6a9c68c29d7ea3 \ - --hash=sha256:a761d9ccfc5e2ecd1bf05534eda382aa14c3e4f9205ba5b1684ecfe400716ef2 \ - --hash=sha256:a7895f46d47671fa7ceec40f31fae721da51ad34bdca0bee83e38870b1f47ffd \ - --hash=sha256:a9fa36448e6a3a1a9a2ba23c02012c43ed88905ec80163f2ffe2421c7192a5d7 \ - --hash=sha256:ab29962927945d89d9b293eabd0d59aea28d887d4f3be6c22deaefbb938a7277 \ - --hash=sha256:abbb49fb7dac584e5abc6636b7b2a7227111c4f771005853e7d25176daaf8453 \ - --hash=sha256:ac4578ac281983f63b400f7fe6c101bedc10651650eef012be1ccffcbacf3697 \ - --hash=sha256:adce39d67c0edf383647a3a007de0a45fd1b08dedaa5318404f1a73059c2512b \ - --hash=sha256:ade08d343436a94e633db932e7e8407fe7de8083967962b46bdfc1b0ced39454 \ - --hash=sha256:b2bdca22a27e35f16794cf585832e542123296b4687f9fd96822db6bae17bfc9 \ - --hash=sha256:b2f926efda994cdf3c8d3fdb40b9962f86edbc4457e739277b961eced3d0b4c1 \ - --hash=sha256:b457d6430833cee8e4b8e9b6f07aa1c161e5e0d52e118dc102c8f9bd7dd060d6 \ - --hash=sha256:c414fc1ed8ee1dbd5da626cf3710c6013d3d27456651d156711fa24f24bd1291 \ - --hash=sha256:cb76c1a154b83991a3cbbf0dfeb26ec2833ad56f95540b442c73950af2013750 \ - --hash=sha256:dfd97abd83335045a913e3bcc4a09c0ceadbe66580cf573fe961f4a825efa699 \ - --hash=sha256:e914a8cb05ce5c809dd0fe350cfbb4e881bde5e2a38dc04e3afe1b3e58bd158e \ - --hash=sha256:ece6df05e2c41bd46776fbc712e0996f7c94e0d0543af1656956d150c4ca7c81 \ - --hash=sha256:efa874e87e4a647fd2e4f514d5e91c7d493697127beb95e77d2f7561f6905bd9 \ - --hash=sha256:f611e628ef06670df83fce17805c344710ca5cde01edfdc72751311da8585375 + # via -r build/requirements.in +contourpy==1.3.3 \ + --hash=sha256:023b44101dfe49d7d53932be418477dba359649246075c996866106da069af69 \ + --hash=sha256:07ce5ed73ecdc4a03ffe3e1b3e3c1166db35ae7584be76f65dbbe28a7791b0cc \ + --hash=sha256:083e12155b210502d0bca491432bb04d56dc3432f95a979b429f2848c3dbe880 \ + --hash=sha256:0bf67e0e3f482cb69779dd3061b534eb35ac9b17f163d851e2a547d56dba0a3a \ + --hash=sha256:0c1fc238306b35f246d61a1d416a627348b5cf0648648a031e14bb8705fcdfe8 \ + --hash=sha256:13b68d6a62db8eafaebb8039218921399baf6e47bf85006fd8529f2a08ef33fc \ + --hash=sha256:15ff10bfada4bf92ec8b31c62bf7c1834c244019b4a33095a68000d7075df470 \ + --hash=sha256:177fb367556747a686509d6fef71d221a4b198a3905fe824430e5ea0fda54eb5 \ + --hash=sha256:1cadd8b8969f060ba45ed7c1b714fe69185812ab43bd6b86a9123fe8f99c3263 \ + --hash=sha256:1fd43c3be4c8e5fd6e4f2baeae35ae18176cf2e5cced681cca908addf1cdd53b \ + --hash=sha256:22e9b1bd7a9b1d652cd77388465dc358dafcd2e217d35552424aa4f996f524f5 \ + --hash=sha256:23416f38bfd74d5d28ab8429cc4d63fa67d5068bd711a85edb1c3fb0c3e2f381 \ + --hash=sha256:283edd842a01e3dcd435b1c5116798d661378d83d36d337b8dde1d16a5fc9ba3 \ + --hash=sha256:2a2a8b627d5cc6b7c41a4beff6c5ad5eb848c88255fda4a8745f7e901b32d8e4 \ + --hash=sha256:2b7e9480ffe2b0cd2e787e4df64270e3a0440d9db8dc823312e2c940c167df7e \ + --hash=sha256:322ab1c99b008dad206d406bb61d014cf0174df491ae9d9d0fac6a6fda4f977f \ + --hash=sha256:33c82d0138c0a062380332c861387650c82e4cf1747aaa6938b9b6516762e772 \ + --hash=sha256:348ac1f5d4f1d66d3322420f01d42e43122f43616e0f194fc1c9f5d830c5b286 \ + --hash=sha256:3519428f6be58431c56581f1694ba8e50626f2dd550af225f82fb5f5814d2a42 \ + --hash=sha256:3c30273eb2a55024ff31ba7d052dde990d7d8e5450f4bbb6e913558b3d6c2301 \ + --hash=sha256:3d1a3799d62d45c18bafd41c5fa05120b96a28079f2393af559b843d1a966a77 \ + --hash=sha256:451e71b5a7d597379ef572de31eeb909a87246974d960049a9848c3bc6c41bf7 \ + --hash=sha256:459c1f020cd59fcfe6650180678a9993932d80d44ccde1fa1868977438f0b411 \ + --hash=sha256:4d00e655fcef08aba35ec9610536bfe90267d7ab5ba944f7032549c55a146da1 \ + --hash=sha256:4debd64f124ca62069f313a9cb86656ff087786016d76927ae2cf37846b006c9 \ + --hash=sha256:4feffb6537d64b84877da813a5c30f1422ea5739566abf0bd18065ac040e120a \ + --hash=sha256:50ed930df7289ff2a8d7afeb9603f8289e5704755c7e5c3bbd929c90c817164b \ + --hash=sha256:51e79c1f7470158e838808d4a996fa9bac72c498e93d8ebe5119bc1e6becb0db \ + --hash=sha256:556dba8fb6f5d8742f2923fe9457dbdd51e1049c4a43fd3986a0b14a1d815fc6 \ + --hash=sha256:598c3aaece21c503615fd59c92a3598b428b2f01bfb4b8ca9c4edeecc2438620 \ + --hash=sha256:5ed3657edf08512fc3fe81b510e35c2012fbd3081d2e26160f27ca28affec989 \ + --hash=sha256:626d60935cf668e70a5ce6ff184fd713e9683fb458898e4249b63be9e28286ea \ + --hash=sha256:644a6853d15b2512d67881586bd03f462c7ab755db95f16f14d7e238f2852c67 \ + --hash=sha256:655456777ff65c2c548b7c454af9c6f33f16c8884f11083244b5819cc214f1b5 \ + --hash=sha256:66c8a43a4f7b8df8b71ee1840e4211a3c8d93b214b213f590e18a1beca458f7d \ + --hash=sha256:6afc576f7b33cf00996e5c1102dc2a8f7cc89e39c0b55df93a0b78c1bd992b36 \ + --hash=sha256:6c3d53c796f8647d6deb1abe867daeb66dcc8a97e8455efa729516b997b8ed99 \ + --hash=sha256:709a48ef9a690e1343202916450bc48b9e51c049b089c7f79a267b46cffcdaa1 \ + --hash=sha256:70f9aad7de812d6541d29d2bbf8feb22ff7e1c299523db288004e3157ff4674e \ + --hash=sha256:8153b8bfc11e1e4d75bcb0bff1db232f9e10b274e0929de9d608027e0d34ff8b \ + --hash=sha256:87acf5963fc2b34825e5b6b048f40e3635dd547f590b04d2ab317c2619ef7ae8 \ + --hash=sha256:88df9880d507169449d434c293467418b9f6cbe82edd19284aa0409e7fdb933d \ + --hash=sha256:929ddf8c4c7f348e4c0a5a3a714b5c8542ffaa8c22954862a46ca1813b667ee7 \ + --hash=sha256:92d9abc807cf7d0e047b95ca5d957cf4792fcd04e920ca70d48add15c1a90ea7 \ + --hash=sha256:95b181891b4c71de4bb404c6621e7e2390745f887f2a026b2d99e92c17892339 \ + --hash=sha256:9e999574eddae35f1312c2b4b717b7885d4edd6cb46700e04f7f02db454e67c1 \ + --hash=sha256:a15459b0f4615b00bbd1e91f1b9e19b7e63aea7483d03d804186f278c0af2659 \ + --hash=sha256:a22738912262aa3e254e4f3cb079a95a67132fc5a063890e224393596902f5a4 \ + --hash=sha256:ab2fd90904c503739a75b7c8c5c01160130ba67944a7b77bbf36ef8054576e7f \ + --hash=sha256:ab3074b48c4e2cf1a960e6bbeb7f04566bf36b1861d5c9d4d8ac04b82e38ba20 \ + --hash=sha256:afe5a512f31ee6bd7d0dda52ec9864c984ca3d66664444f2d72e0dc4eb832e36 \ + --hash=sha256:b08a32ea2f8e42cf1d4be3169a98dd4be32bafe4f22b6c4cb4ba810fa9e5d2cb \ + --hash=sha256:b20c7c9a3bf701366556e1b1984ed2d0cedf999903c51311417cf5f591d8c78d \ + --hash=sha256:b2e8faa0ed68cb29af51edd8e24798bb661eac3bd9f65420c1887b6ca89987c8 \ + --hash=sha256:b7301b89040075c30e5768810bc96a8e8d78085b47d8be6e4c3f5a0b4ed478a0 \ + --hash=sha256:b7448cb5a725bb1e35ce88771b86fba35ef418952474492cf7c764059933ff8b \ + --hash=sha256:ca0fdcd73925568ca027e0b17ab07aad764be4706d0a925b89227e447d9737b7 \ + --hash=sha256:ca658cd1a680a5c9ea96dc61cdbae1e85c8f25849843aa799dfd3cb370ad4fbe \ + --hash=sha256:cbedb772ed74ff5be440fa8eee9bd49f64f6e3fc09436d9c7d8f1c287b121d77 \ + --hash=sha256:cd5dfcaeb10f7b7f9dc8941717c6c2ade08f587be2226222c12b25f0483ed497 \ + --hash=sha256:cf9022ef053f2694e31d630feaacb21ea24224be1c3ad0520b13d844274614fd \ + --hash=sha256:d002b6f00d73d69333dac9d0b8d5e84d9724ff9ef044fd63c5986e62b7c9e1b1 \ + --hash=sha256:d06bb1f751ba5d417047db62bca3c8fde202b8c11fb50742ab3ab962c81e8216 \ + --hash=sha256:d304906ecc71672e9c89e87c4675dc5c2645e1f4269a5063b99b0bb29f232d13 \ + --hash=sha256:e4e6b05a45525357e382909a4c1600444e2a45b4795163d3b22669285591c1ae \ + --hash=sha256:e74a9a0f5e3fff48fb5a7f2fd2b9b70a3fe014a67522f79b7cca4c0c7e43c9ae \ + --hash=sha256:ea37e7b45949df430fe649e5de8351c423430046a2af20b1c1961cae3afcda77 \ + --hash=sha256:f64836de09927cba6f79dcd00fdd7d5329f3fccc633468507079c829ca4db4e3 \ + --hash=sha256:fd6ec6be509c787f1caf6b247f0b1ca598bef13f4ddeaa126b7658215529ba0f \ + --hash=sha256:fd907ae12cd483cd83e414b12941c632a969171bf90fc937d0c9f268a31cafff \ + --hash=sha256:fd914713266421b7536de2bfa8181aa8c699432b6763a0ea64195ebe28bff6a9 \ + --hash=sha256:fde6c716d51c04b1c25d0b90364d0be954624a0ee9d60e23e850e8d48353d07a # via matplotlib cycler==0.12.1 \ --hash=sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30 \ --hash=sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c # via matplotlib -etils[epath,epy]==1.11.0 \ - --hash=sha256:a394cf3476bcec51c221426a70c39cd1006e889456ba41e4d7f12fd6814be7a5 \ - --hash=sha256:aff3278a3be7fddf302dfd80335e9f924244666c71239cd91e836f3d055f1c4a +etils[epath,epy]==1.13.0 \ + --hash=sha256:a5b60c71f95bcd2d43d4e9fb3dc3879120c1f60472bb5ce19f7a860b1d44f607 \ + --hash=sha256:d9cd4f40fbe77ad6613b7348a18132cc511237b6c076dbb89105c0b520a4c6bb # via -r build/requirements.in -execnet==2.1.1 \ - --hash=sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc \ - --hash=sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3 +execnet==2.1.2 \ + --hash=sha256:63d83bfdd9a23e35b9c6a3261412324f964c2ec8dcd8d3c6916ee9373e0befcd \ + --hash=sha256:67fba928dd5a544b783f6056f449e5e3931a5c378b128bc18501f7ea79e296ec # via pytest-xdist -filelock==3.16.1 \ - --hash=sha256:2082e5703d51fbf98ea75855d9d5527e33d8ff23099bec374a134febee6946b0 \ - --hash=sha256:c249fbfcd5db47e5e2d6d62198e565475ee65e4831e2561c8e313fa7eb961435 +filelock==3.20.1 \ + --hash=sha256:15d9e9a67306188a44baa72f569d2bfd803076269365fdea0934385da4dc361a \ + --hash=sha256:b8360948b351b80f420878d8516519a2204b07aefcdcfd24912a5d33127f188c # via -r build/test-requirements.txt -flatbuffers==24.12.23 \ - --hash=sha256:2910b0bc6ae9b6db78dd2b18d0b7a0709ba240fb5585f286a3a2b30785c22dac \ - --hash=sha256:c418e0d48890f4142b92fd3e343e73a48f194e1f80075ddcc5793779b3585444 +flatbuffers==25.9.23 \ + --hash=sha256:255538574d6cb6d0a79a17ec8bc0d30985913b87513a01cce8bcdb6b4c44d0e2 \ + --hash=sha256:676f9fa62750bb50cf531b42a0a2a118ad8f7f797a511eda12881c016f093b12 # via -r build/test-requirements.txt -fonttools==4.55.3 \ - --hash=sha256:07f8288aacf0a38d174445fc78377a97fb0b83cfe352a90c9d9c1400571963c7 \ - --hash=sha256:11e5de1ee0d95af4ae23c1a138b184b7f06e0b6abacabf1d0db41c90b03d834b \ - --hash=sha256:1bc7ad24ff98846282eef1cbeac05d013c2154f977a79886bb943015d2b1b261 \ - --hash=sha256:1dcc07934a2165ccdc3a5a608db56fb3c24b609658a5b340aee4ecf3ba679dc0 \ - --hash=sha256:22f38464daa6cdb7b6aebd14ab06609328fe1e9705bb0fcc7d1e69de7109ee02 \ - --hash=sha256:27e4ae3592e62eba83cd2c4ccd9462dcfa603ff78e09110680a5444c6925d841 \ - --hash=sha256:3983313c2a04d6cc1fe9251f8fc647754cf49a61dac6cb1e7249ae67afaafc45 \ - --hash=sha256:529cef2ce91dc44f8e407cc567fae6e49a1786f2fefefa73a294704c415322a4 \ - --hash=sha256:5323a22eabddf4b24f66d26894f1229261021dacd9d29e89f7872dd8c63f0b8b \ - --hash=sha256:54153c49913f45065c8d9e6d0c101396725c5621c8aee744719300f79771d75a \ - --hash=sha256:546565028e244a701f73df6d8dd6be489d01617863ec0c6a42fa25bf45d43048 \ - --hash=sha256:5480673f599ad410695ca2ddef2dfefe9df779a9a5cda89503881e503c9c7d90 \ - --hash=sha256:5e8d657cd7326eeaba27de2740e847c6b39dde2f8d7cd7cc56f6aad404ddf0bd \ - --hash=sha256:62d65a3022c35e404d19ca14f291c89cc5890032ff04f6c17af0bd1927299674 \ - --hash=sha256:6314bf82c54c53c71805318fcf6786d986461622dd926d92a465199ff54b1b72 \ - --hash=sha256:7a8aa2c5e5b8b3bcb2e4538d929f6589a5c6bdb84fd16e2ed92649fb5454f11c \ - --hash=sha256:827e95fdbbd3e51f8b459af5ea10ecb4e30af50221ca103bea68218e9615de07 \ - --hash=sha256:859c358ebf41db18fb72342d3080bce67c02b39e86b9fbcf1610cca14984841b \ - --hash=sha256:86721fbc389ef5cc1e2f477019e5069e8e4421e8d9576e9c26f840dbb04678de \ - --hash=sha256:89bdc5d88bdeec1b15af790810e267e8332d92561dce4f0748c2b95c9bdf3926 \ - --hash=sha256:8c4491699bad88efe95772543cd49870cf756b019ad56294f6498982408ab03e \ - --hash=sha256:8c5ec45428edaa7022f1c949a632a6f298edc7b481312fc7dc258921e9399628 \ - --hash=sha256:8e75f12c82127486fac2d8bfbf5bf058202f54bf4f158d367e41647b972342ca \ - --hash=sha256:a430178ad3e650e695167cb53242dae3477b35c95bef6525b074d87493c4bf29 \ - --hash=sha256:a8c2794ded89399cc2169c4d0bf7941247b8d5932b2659e09834adfbb01589aa \ - --hash=sha256:aca318b77f23523309eec4475d1fbbb00a6b133eb766a8bdc401faba91261abe \ - --hash=sha256:ae3b6600565b2d80b7c05acb8e24d2b26ac407b27a3f2e078229721ba5698427 \ - --hash=sha256:aedbeb1db64496d098e6be92b2e63b5fac4e53b1b92032dfc6988e1ea9134a4d \ - --hash=sha256:aee3b57643827e237ff6ec6d28d9ff9766bd8b21e08cd13bff479e13d4b14765 \ - --hash=sha256:b54baf65c52952db65df39fcd4820668d0ef4766c0ccdf32879b77f7c804d5c5 \ - --hash=sha256:b586ab5b15b6097f2fb71cafa3c98edfd0dba1ad8027229e7b1e204a58b0e09d \ - --hash=sha256:b8d5e8916c0970fbc0f6f1bece0063363bb5857a7f170121a4493e31c3db3314 \ - --hash=sha256:bc5dbb4685e51235ef487e4bd501ddfc49be5aede5e40f4cefcccabc6e60fb4b \ - --hash=sha256:bdcc9f04b36c6c20978d3f060e5323a43f6222accc4e7fcbef3f428e216d96af \ - --hash=sha256:c3ca99e0d460eff46e033cd3992a969658c3169ffcd533e0a39c63a38beb6831 \ - --hash=sha256:caf8230f3e10f8f5d7593eb6d252a37caf58c480b19a17e250a63dad63834cf3 \ - --hash=sha256:cd70de1a52a8ee2d1877b6293af8a2484ac82514f10b1c67c1c5762d38073e56 \ - --hash=sha256:cf4fe7c124aa3f4e4c1940880156e13f2f4d98170d35c749e6b4f119a872551e \ - --hash=sha256:d342e88764fb201286d185093781bf6628bbe380a913c24adf772d901baa8276 \ - --hash=sha256:da9da6d65cd7aa6b0f806556f4985bcbf603bf0c5c590e61b43aa3e5a0f822d0 \ - --hash=sha256:dc5294a3d5c84226e3dbba1b6f61d7ad813a8c0238fceea4e09aa04848c3d851 \ - --hash=sha256:dd68c87a2bfe37c5b33bcda0fba39b65a353876d3b9006fde3adae31f97b3ef5 \ - --hash=sha256:e6e8766eeeb2de759e862004aa11a9ea3d6f6d5ec710551a88b476192b64fd54 \ - --hash=sha256:e894b5bd60d9f473bed7a8f506515549cc194de08064d829464088d23097331b \ - --hash=sha256:eb6ca911c4c17eb51853143624d8dc87cdcdf12a711fc38bf5bd21521e79715f \ - --hash=sha256:ed63959d00b61959b035c7d47f9313c2c1ece090ff63afea702fe86de00dbed4 \ - --hash=sha256:f412604ccbeee81b091b420272841e5ec5ef68967a9790e80bffd0e30b8e2977 \ - --hash=sha256:f7d66c15ba875432a2d2fb419523f5d3d347f91f48f57b8b08a2dfc3c39b8a3f \ - --hash=sha256:f9e736f60f4911061235603a6119e72053073a12c6d7904011df2d8fad2c0e35 \ - --hash=sha256:fb594b5a99943042c702c550d5494bdd7577f6ef19b0bc73877c948a63184a32 +fonttools==4.61.1 \ + --hash=sha256:0de30bfe7745c0d1ffa2b0b7048fb7123ad0d71107e10ee090fa0b16b9452e87 \ + --hash=sha256:10d88e55330e092940584774ee5e8a6971b01fc2f4d3466a1d6c158230880796 \ + --hash=sha256:11f35ad7805edba3aac1a3710d104592df59f4b957e30108ae0ba6c10b11dd75 \ + --hash=sha256:15acc09befd16a0fb8a8f62bc147e1a82817542d72184acca9ce6e0aeda9fa6d \ + --hash=sha256:17d2bf5d541add43822bcf0c43d7d847b160c9bb01d15d5007d84e2217aaa371 \ + --hash=sha256:2180f14c141d2f0f3da43f3a81bc8aa4684860f6b0e6f9e165a4831f24e6a23b \ + --hash=sha256:21e7c8d76f62ab13c9472ccf74515ca5b9a761d1bde3265152a6dc58700d895b \ + --hash=sha256:41a7170d042e8c0024703ed13b71893519a1a6d6e18e933e3ec7507a2c26a4b2 \ + --hash=sha256:41ed4b5ec103bd306bb68f81dc166e77409e5209443e5773cb4ed837bcc9b0d3 \ + --hash=sha256:497c31ce314219888c0e2fce5ad9178ca83fe5230b01a5006726cdf3ac9f24d9 \ + --hash=sha256:4c1b526c8d3f615a7b1867f38a9410849c8f4aef078535742198e942fba0e9bd \ + --hash=sha256:4d7092bb38c53bbc78e9255a59158b150bcdc115a1e3b3ce0b5f267dc35dd63c \ + --hash=sha256:4f5686e1fe5fce75d82d93c47a438a25bf0d1319d2843a926f741140b2b16e0c \ + --hash=sha256:58b0ee0ab5b1fc9921eccfe11d1435added19d6494dde14e323f25ad2bc30c56 \ + --hash=sha256:5ce02f38a754f207f2f06557523cd39a06438ba3aafc0639c477ac409fc64e37 \ + --hash=sha256:5fade934607a523614726119164ff621e8c30e8fa1ffffbbd358662056ba69f0 \ + --hash=sha256:5fe9fd43882620017add5eabb781ebfbc6998ee49b35bd7f8f79af1f9f99a958 \ + --hash=sha256:64102ca87e84261419c3747a0d20f396eb024bdbeb04c2bfb37e2891f5fadcb5 \ + --hash=sha256:664c5a68ec406f6b1547946683008576ef8b38275608e1cee6c061828171c118 \ + --hash=sha256:6675329885c44657f826ef01d9e4fb33b9158e9d93c537d84ad8399539bc6f69 \ + --hash=sha256:75c1a6dfac6abd407634420c93864a1e274ebc1c7531346d9254c0d8f6ca00f9 \ + --hash=sha256:75da8f28eff26defba42c52986de97b22106cb8f26515b7c22443ebc9c2d3261 \ + --hash=sha256:77efb033d8d7ff233385f30c62c7c79271c8885d5c9657d967ede124671bbdfb \ + --hash=sha256:78a7d3ab09dc47ac1a363a493e6112d8cabed7ba7caad5f54dbe2f08676d1b47 \ + --hash=sha256:7c7db70d57e5e1089a274cbb2b1fd635c9a24de809a231b154965d415d6c6d24 \ + --hash=sha256:8c56c488ab471628ff3bfa80964372fc13504ece601e0d97a78ee74126b2045c \ + --hash=sha256:91669ccac46bbc1d09e9273546181919064e8df73488ea087dcac3e2968df9ba \ + --hash=sha256:9b666a475a65f4e839d3d10473fad6d47e0a9db14a2f4a224029c5bfde58ad2c \ + --hash=sha256:9cfef3ab326780c04d6646f68d4b4742aae222e8b8ea1d627c74e38afcbc9d91 \ + --hash=sha256:a13fc8aeb24bad755eea8f7f9d409438eb94e82cf86b08fe77a03fbc8f6a96b1 \ + --hash=sha256:a75c301f96db737e1c5ed5fd7d77d9c34466de16095a266509e13da09751bd19 \ + --hash=sha256:a76d4cb80f41ba94a6691264be76435e5f72f2cb3cab0b092a6212855f71c2f6 \ + --hash=sha256:aed04cabe26f30c1647ef0e8fbb207516fd40fe9472e9439695f5c6998e60ac5 \ + --hash=sha256:b148b56f5de675ee16d45e769e69f87623a4944f7443850bf9a9376e628a89d2 \ + --hash=sha256:b501c862d4901792adaec7c25b1ecc749e2662543f68bb194c42ba18d6eec98d \ + --hash=sha256:b846a1fcf8beadeb9ea4f44ec5bdde393e2f1569e17d700bfc49cd69bde75881 \ + --hash=sha256:b931ae8f62db78861b0ff1ac017851764602288575d65b8e8ff1963fed419063 \ + --hash=sha256:c33ab3ca9d3ccd581d58e989d67554e42d8d4ded94ab3ade3508455fe70e65f7 \ + --hash=sha256:c6604b735bb12fef8e0efd5578c9fb5d3d8532d5001ea13a19cddf295673ee09 \ + --hash=sha256:d8db08051fc9e7d8bc622f2112511b8107d8f27cd89e2f64ec45e9825e8288da \ + --hash=sha256:d9203500f7c63545b4ce3799319fe4d9feb1a1b89b28d3cb5abd11b9dd64147e \ + --hash=sha256:dc492779501fa723b04d0ab1f5be046797fee17d27700476edc7ee9ae535a61e \ + --hash=sha256:e6bcdf33aec38d16508ce61fd81838f24c83c90a1d1b8c68982857038673d6b8 \ + --hash=sha256:e76ce097e3c57c4bcb67c5aa24a0ecdbd9f74ea9219997a707a4061fbe2707aa \ + --hash=sha256:eff1ac3cc66c2ac7cda1e64b4e2f3ffef474b7335f92fc3833fc632d595fcee6 \ + --hash=sha256:f3cb4a569029b9f291f88aafc927dd53683757e640081ca8c412781ea144565e \ + --hash=sha256:f79b168428351d11e10c5aeb61a74e1851ec221081299f4cf56036a95431c43a \ + --hash=sha256:fa646ecec9528bef693415c79a86e733c70a4965dd938e9a226b0fc64c9d2e6c \ + --hash=sha256:fe2efccb324948a11dd09d22136fe2ac8a97d6c1347cf0b58a911dcd529f66b7 \ + --hash=sha256:fff4f534200a04b4a36e7ae3cb74493afe807b517a09e99cb4faa89a34ed6ecd # via matplotlib -fsspec==2024.12.0 \ - --hash=sha256:670700c977ed2fb51e0d9f9253177ed20cbde4a3e5c0283cc5385b5870c8533f \ - --hash=sha256:b520aed47ad9804237ff878b504267a3b0b441e97508bd6d2d8774e3db85cee2 +fsspec==2025.10.0 \ + --hash=sha256:7c7712353ae7d875407f97715f0e1ffcc21e33d5b24556cb1e090ae9409ec61d \ + --hash=sha256:b6789427626f068f9a83ca4e8a3cc050850b6c0f71f99ddb4f542b8266a26a59 # via etils -hypothesis==6.123.9 \ - --hash=sha256:0f924bd9513daa9ecddbfe8abe8b3f7598d4d09234fe1027b19b4cd717adba05 \ - --hash=sha256:aca6a2f7aeef85e5201079ab93156fca137a8cabcf5cc39ea2a3b7147432fe89 +hypothesis==6.142.1 \ + --hash=sha256:3179cb08756562c526aaf4a9871ebbff83d2d75c03896ed0bc9c1d14097a930c \ + --hash=sha256:95a7d38fcc58e697e3020665adcb951c630cdbc8065e4b4474949e486b06bd6d # via -r build/test-requirements.txt importlib-resources==6.5.2 \ --hash=sha256:185f87adef5bcc288449d98fb4fba07cea78bc036455dd44c5fc4a2fe78fed2c \ --hash=sha256:789cfdc3ed28c78b67a06acb8126751ced69a3d5f79c095a98298cd8a760ccec # via etils -iniconfig==2.0.0 \ - --hash=sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3 \ - --hash=sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374 +iniconfig==2.3.0 \ + --hash=sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730 \ + --hash=sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12 # via pytest -kiwisolver==1.4.8 \ - --hash=sha256:01c3d31902c7db5fb6182832713d3b4122ad9317c2c5877d0539227d96bb2e50 \ - --hash=sha256:034d2c891f76bd3edbdb3ea11140d8510dca675443da7304205a2eaa45d8334c \ - --hash=sha256:085940635c62697391baafaaeabdf3dd7a6c3643577dde337f4d66eba021b2b8 \ - --hash=sha256:08e77738ed7538f036cd1170cbed942ef749137b1311fa2bbe2a7fda2f6bf3cc \ - --hash=sha256:111793b232842991be367ed828076b03d96202c19221b5ebab421ce8bcad016f \ - --hash=sha256:11e1022b524bd48ae56c9b4f9296bce77e15a2e42a502cceba602f804b32bb79 \ - --hash=sha256:151dffc4865e5fe6dafce5480fab84f950d14566c480c08a53c663a0020504b6 \ - --hash=sha256:16523b40aab60426ffdebe33ac374457cf62863e330a90a0383639ce14bf44b2 \ - --hash=sha256:1732e065704b47c9afca7ffa272f845300a4eb959276bf6970dc07265e73b605 \ - --hash=sha256:1c8ceb754339793c24aee1c9fb2485b5b1f5bb1c2c214ff13368431e51fc9a09 \ - --hash=sha256:23454ff084b07ac54ca8be535f4174170c1094a4cff78fbae4f73a4bcc0d4dab \ - --hash=sha256:23d5f023bdc8c7e54eb65f03ca5d5bb25b601eac4d7f1a042888a1f45237987e \ - --hash=sha256:257af1622860e51b1a9d0ce387bf5c2c4f36a90594cb9514f55b074bcc787cfc \ - --hash=sha256:286b18e86682fd2217a48fc6be6b0f20c1d0ed10958d8dc53453ad58d7be0bf8 \ - --hash=sha256:291331973c64bb9cce50bbe871fb2e675c4331dab4f31abe89f175ad7679a4d7 \ - --hash=sha256:2f0121b07b356a22fb0414cec4666bbe36fd6d0d759db3d37228f496ed67c880 \ - --hash=sha256:3452046c37c7692bd52b0e752b87954ef86ee2224e624ef7ce6cb21e8c41cc1b \ - --hash=sha256:34d142fba9c464bc3bbfeff15c96eab0e7310343d6aefb62a79d51421fcc5f1b \ - --hash=sha256:369b75d40abedc1da2c1f4de13f3482cb99e3237b38726710f4a793432b1c5ff \ - --hash=sha256:36dbbfd34838500a31f52c9786990d00150860e46cd5041386f217101350f0d3 \ - --hash=sha256:370fd2df41660ed4e26b8c9d6bbcad668fbe2560462cba151a721d49e5b6628c \ - --hash=sha256:3a96c0e790ee875d65e340ab383700e2b4891677b7fcd30a699146f9384a2bb0 \ - --hash=sha256:3b9b4d2892fefc886f30301cdd80debd8bb01ecdf165a449eb6e78f79f0fabd6 \ - --hash=sha256:3cd3bc628b25f74aedc6d374d5babf0166a92ff1317f46267f12d2ed54bc1d30 \ - --hash=sha256:3ddc373e0eef45b59197de815b1b28ef89ae3955e7722cc9710fb91cd77b7f47 \ - --hash=sha256:4191ee8dfd0be1c3666ccbac178c5a05d5f8d689bbe3fc92f3c4abec817f8fe0 \ - --hash=sha256:54a62808ac74b5e55a04a408cda6156f986cefbcf0ada13572696b507cc92fa1 \ - --hash=sha256:577facaa411c10421314598b50413aa1ebcf5126f704f1e5d72d7e4e9f020d90 \ - --hash=sha256:641f2ddf9358c80faa22e22eb4c9f54bd3f0e442e038728f500e3b978d00aa7d \ - --hash=sha256:65ea09a5a3faadd59c2ce96dc7bf0f364986a315949dc6374f04396b0d60e09b \ - --hash=sha256:68269e60ee4929893aad82666821aaacbd455284124817af45c11e50a4b42e3c \ - --hash=sha256:69b5637c3f316cab1ec1c9a12b8c5f4750a4c4b71af9157645bf32830e39c03a \ - --hash=sha256:7506488470f41169b86d8c9aeff587293f530a23a23a49d6bc64dab66bedc71e \ - --hash=sha256:768cade2c2df13db52475bd28d3a3fac8c9eff04b0e9e2fda0f3760f20b3f7fc \ - --hash=sha256:77e6f57a20b9bd4e1e2cedda4d0b986ebd0216236f0106e55c28aea3d3d69b16 \ - --hash=sha256:782bb86f245ec18009890e7cb8d13a5ef54dcf2ebe18ed65f795e635a96a1c6a \ - --hash=sha256:7a3ad337add5148cf51ce0b55642dc551c0b9d6248458a757f98796ca7348712 \ - --hash=sha256:7cd2785b9391f2873ad46088ed7599a6a71e762e1ea33e87514b1a441ed1da1c \ - --hash=sha256:7e9a60b50fe8b2ec6f448fe8d81b07e40141bfced7f896309df271a0b92f80f3 \ - --hash=sha256:84a2f830d42707de1d191b9490ac186bf7997a9495d4e9072210a1296345f7dc \ - --hash=sha256:856b269c4d28a5c0d5e6c1955ec36ebfd1651ac00e1ce0afa3e28da95293b561 \ - --hash=sha256:858416b7fb777a53f0c59ca08190ce24e9abbd3cffa18886a5781b8e3e26f65d \ - --hash=sha256:87b287251ad6488e95b4f0b4a79a6d04d3ea35fde6340eb38fbd1ca9cd35bbbc \ - --hash=sha256:88c6f252f6816a73b1f8c904f7bbe02fd67c09a69f7cb8a0eecdbf5ce78e63db \ - --hash=sha256:893f5525bb92d3d735878ec00f781b2de998333659507d29ea4466208df37bed \ - --hash=sha256:89c107041f7b27844179ea9c85d6da275aa55ecf28413e87624d033cf1f6b751 \ - --hash=sha256:918139571133f366e8362fa4a297aeba86c7816b7ecf0bc79168080e2bd79957 \ - --hash=sha256:99cea8b9dd34ff80c521aef46a1dddb0dcc0283cf18bde6d756f1e6f31772165 \ - --hash=sha256:a17b7c4f5b2c51bb68ed379defd608a03954a1845dfed7cc0117f1cc8a9b7fd2 \ - --hash=sha256:a3c44cb68861de93f0c4a8175fbaa691f0aa22550c331fefef02b618a9dcb476 \ - --hash=sha256:a4d3601908c560bdf880f07d94f31d734afd1bb71e96585cace0e38ef44c6d84 \ - --hash=sha256:a5ce1e481a74b44dd5e92ff03ea0cb371ae7a0268318e202be06c8f04f4f1246 \ - --hash=sha256:a66f60f8d0c87ab7f59b6fb80e642ebb29fec354a4dfad687ca4092ae69d04f4 \ - --hash=sha256:b21dbe165081142b1232a240fc6383fd32cdd877ca6cc89eab93e5f5883e1c25 \ - --hash=sha256:b47a465040146981dc9db8647981b8cb96366fbc8d452b031e4f8fdffec3f26d \ - --hash=sha256:b5773efa2be9eb9fcf5415ea3ab70fc785d598729fd6057bea38d539ead28271 \ - --hash=sha256:b83dc6769ddbc57613280118fb4ce3cd08899cc3369f7d0e0fab518a7cf37fdb \ - --hash=sha256:bade438f86e21d91e0cf5dd7c0ed00cda0f77c8c1616bd83f9fc157fa6760d31 \ - --hash=sha256:bcb1ebc3547619c3b58a39e2448af089ea2ef44b37988caf432447374941574e \ - --hash=sha256:be4816dc51c8a471749d664161b434912eee82f2ea66bd7628bd14583a833e85 \ - --hash=sha256:c07b29089b7ba090b6f1a669f1411f27221c3662b3a1b7010e67b59bb5a6f10b \ - --hash=sha256:c2b9a96e0f326205af81a15718a9073328df1173a2619a68553decb7097fd5d7 \ - --hash=sha256:c5020c83e8553f770cb3b5fc13faac40f17e0b205bd237aebd21d53d733adb03 \ - --hash=sha256:c72941acb7b67138f35b879bbe85be0f6c6a70cab78fe3ef6db9c024d9223e5b \ - --hash=sha256:c8bf637892dc6e6aad2bc6d4d69d08764166e5e3f69d469e55427b6ac001b19d \ - --hash=sha256:cc978a80a0db3a66d25767b03688f1147a69e6237175c0f4ffffaaedf744055a \ - --hash=sha256:ce2cf1e5688edcb727fdf7cd1bbd0b6416758996826a8be1d958f91880d0809d \ - --hash=sha256:d47b28d1dfe0793d5e96bce90835e17edf9a499b53969b03c6c47ea5985844c3 \ - --hash=sha256:d47cfb2650f0e103d4bf68b0b5804c68da97272c84bb12850d877a95c056bd67 \ - --hash=sha256:d5536185fce131780ebd809f8e623bf4030ce1b161353166c49a3c74c287897f \ - --hash=sha256:d561d2d8883e0819445cfe58d7ddd673e4015c3c57261d7bdcd3710d0d14005c \ - --hash=sha256:d6af5e8815fd02997cb6ad9bbed0ee1e60014438ee1a5c2444c96f87b8843502 \ - --hash=sha256:d6d6bd87df62c27d4185de7c511c6248040afae67028a8a22012b010bc7ad062 \ - --hash=sha256:dace81d28c787956bfbfbbfd72fdcef014f37d9b48830829e488fdb32b49d954 \ - --hash=sha256:e063ef9f89885a1d68dd8b2e18f5ead48653176d10a0e324e3b0030e3a69adeb \ - --hash=sha256:e7a019419b7b510f0f7c9dceff8c5eae2392037eae483a7f9162625233802b0a \ - --hash=sha256:eaa973f1e05131de5ff3569bbba7f5fd07ea0595d3870ed4a526d486fe57fa1b \ - --hash=sha256:eb158fe28ca0c29f2260cca8c43005329ad58452c36f0edf298204de32a9a3ed \ - --hash=sha256:ed33ca2002a779a2e20eeb06aea7721b6e47f2d4b8a8ece979d8ba9e2a167e34 \ - --hash=sha256:fc2ace710ba7c1dfd1a3b42530b62b9ceed115f19a1656adefce7b1782a37794 +jax-cuda12-pjrt==0.8.2 ; sys_platform == "linux" \ + --hash=sha256:717a1b196a642409ce195ddf031c20bbeadcc886f55e49a1d3f4927373aeedae \ + --hash=sha256:e3bab41ca7c48e4163db9e7efd271b3aa85f0fe45f5ed0708d6bbed93a59f977 + # via + # -r build/requirements.in + # jax-cuda12-plugin +jax-cuda12-plugin==0.8.2 ; sys_platform == "linux" and python_version < "3.14" \ + --hash=sha256:0b0a3304ce7e494acd8d9c593490c112a32cdb6010fe1afc584d9e41fd863167 \ + --hash=sha256:1b4828242d57f233b394d17ebaa599c503c1fb9b7c754012a06eb84dbc935fc8 \ + --hash=sha256:20165861b3d3e66ebb2c0f63a547d1d5ee17ea44ac3be7153c7908c9ca8c88f3 \ + --hash=sha256:377e4be17e22dde0343b3f3c05bf69235b3dbf11d766cca9c5a93da47971dcb7 \ + --hash=sha256:403d5e07731b5cdac3bd9fb3f448bd8480062cb2c0ab61ea2ad23fcd0a65479a \ + --hash=sha256:58c51473fc622e03138035985f741833564d70a4bd5a2178f61b62cdaa32ff94 \ + --hash=sha256:637387dc3408cd204562668502f9e95f76c6edde0a6d2e48f055162dc2aebf0d \ + --hash=sha256:70d33222484ad5c375b8f8357b7c23cacb844f6ecfc39567f8dd47fde6e87858 \ + --hash=sha256:82c6798be66bf8c773386918e4c8e5cd8119753f3bfb3ca4bbc46818283750c6 \ + --hash=sha256:a5898bac1d8ab6020b54546440256409f2c66bcbbb3a1099ca473c84843addad \ + --hash=sha256:d68a6d8b4a45ee561746bac7a6468da8203832626b0b39ad4ac43011f61f875d \ + --hash=sha256:dd4f7c34d4512ff5a36fd1b01584ef7781cad615e3f9e71880eae2f4998e5108 + # via -r build/requirements.in +jax-cuda13-pjrt==0.8.2 \ + --hash=sha256:370e9cb9a76af6a6b00e8f6c6ae8e5a0158886b65994a114cf3acf56ea1570f3 \ + --hash=sha256:c0ebc284d4bea9c0b278017b4bb62eb871755fe582cbf22ce1fb770d30fb9af6 + # via + # -r build/requirements.in + # jax-cuda13-plugin +jax-cuda13-plugin==0.8.2 \ + --hash=sha256:0ab3a13f798f973416754cb00b1696a8e6d2425c39496e2145293661bcd28446 \ + --hash=sha256:2c7532f2a6c9dcf50ba4b8ebbe4c78988b8c260695e50f9fdbd68f9417a8bb51 \ + --hash=sha256:358f4b4d0d2b92f62fe7f07d355a9c57c5684da3d4e6f483afec38b23e66f02e \ + --hash=sha256:6703552ee10b6b41b43fe4fc62b4672bec13d271eb709a29bf14b0fa1ec805dd \ + --hash=sha256:8fe2c5874cbae18a01a0fa722dfbd547b5504e7bf42aaaae6ad3e1b676e8daa2 \ + --hash=sha256:950f2c1f2e0f5e13893c4c27a4df540a96a26ebd15d7954fa85ffb8cc8b1d9a1 \ + --hash=sha256:9835076d0a917345700370cfaee5cfd7876c0b0982e99206a8b4848af0ea6e5e \ + --hash=sha256:bb059d99eb960718358f82f3eaa0829f96fc7d7e978ded821f2c40256a8aee4b \ + --hash=sha256:d5f9470db46aee8a64e9ab7817927335b77c6d3698d5113e0da2dd962c1468c4 \ + --hash=sha256:f1643b216ccb41a2d7c0e240359c729c96c5eca76e4525ffcd6313049e348b9a \ + --hash=sha256:f35f687d9b839365a49c2b0052eaeb6b9f0a9879813b4fa020f8d9caa3ddde46 \ + --hash=sha256:fa6a6faf3130d6ef05f9ee6cbc27e0a2e4db0f23720136403991eccb3a0144af + # via -r build/requirements.in +jaxlib==0.8.2 \ + --hash=sha256:023de6f3f56da2af7037970996500586331fdb50b530ecbb54b9666da633bd00 \ + --hash=sha256:05b958f497e49824c432e734bb059723b7dfe69e2ad696a9f9c8ad82fff7c3f8 \ + --hash=sha256:1bfbcf6c3de221784fa4cdb6765a09d71cb4298b15626b3d0409b3dfcd8a8667 \ + --hash=sha256:28eec1a4e0639a0d8702cea3cb70dd3663053dbfa344452994ea48dc6ceadaa5 \ + --hash=sha256:2b9789bd08f8b0cc5a5c12ae896fe432d5942e32e417091b8b5a96a9a6fd5cf1 \ + --hash=sha256:3b16e50c5b730c9dd0a49e55f1acfaa722b00b1af0522a591558dcc0464252f2 \ + --hash=sha256:490bf0cb029c73c65c9431124b86cdc95082dbc1fb76fc549d24d75da33e5454 \ + --hash=sha256:4d006db96be020c8165212a1216372f8acac4ff4f8fb067743d694ef2b301ace \ + --hash=sha256:68108dff0de74adc468016be9a19f80efe48c660c0d5a122287094b44b092afc \ + --hash=sha256:7c304f3a016965b9d1f5239a8a0399a73925f5604fe914c5ca66ecf734bf6422 \ + --hash=sha256:7da8127557c786264049ae55460d1b8d04cc3cdf0403a087f2fc1e6d313ec722 \ + --hash=sha256:964626f581beab31ee6826b228fcc2ec5181b05cecf94a528dff97921c145dbc \ + --hash=sha256:a397ea7dcb37d689ce79173eeb99b2f1347637a36be9a27f20ae6848bfc58bfc \ + --hash=sha256:aa8701b6356f098e8452c3cec762fb5f706fcb8f67ffd65964f63982479aa23b \ + --hash=sha256:bb89be452b1b808d3f88fc01c415b364a260be4cc7ac120c038009f6150a32dc \ + --hash=sha256:beffb004e7eeb5c9afb24439e2b2cf45a4ee3e3e8adf45e355edf2af62acf8b8 \ + --hash=sha256:ccf77da917a20935247c990691decfcbdd06c25ef0ac94d914a04aadb22f714c \ + --hash=sha256:dffc22b5b732b9556d92c918b251c61bcc046617c4dbb51e1f7a656587fddffb \ + --hash=sha256:e6a97dfb0232eed9a2bb6e3828e4f682dbac1a7fea840bfda574cae2dbf5faf9 \ + --hash=sha256:f205e91c3a152a2a76c0bc59a6a2de03e87ec261b91e8812922777185e7b08f5 \ + --hash=sha256:f28edac8c226fc07fa3e8af6f9defede8ac2c307429e3291edce8739d39becc9 \ + --hash=sha256:f472cc72e3058e50b5f0230b236d5a1183bf6c3d5423d2a52eff07bcf34908de + # via -r build/requirements.in +kiwisolver==1.4.9 \ + --hash=sha256:0749fd8f4218ad2e851e11cc4dc05c7cbc0cbc4267bdfdb31782e65aace4ee9c \ + --hash=sha256:0763515d4df10edf6d06a3c19734e2566368980d21ebec439f33f9eb936c07b7 \ + --hash=sha256:0856e241c2d3df4efef7c04a1e46b1936b6120c9bcf36dd216e3acd84bc4fb21 \ + --hash=sha256:0a590506f303f512dff6b7f75fd2fd18e16943efee932008fe7140e5fa91d80e \ + --hash=sha256:0ab74e19f6a2b027ea4f845a78827969af45ce790e6cb3e1ebab71bdf9f215ff \ + --hash=sha256:0ae37737256ba2de764ddc12aed4956460277f00c4996d51a197e72f62f5eec7 \ + --hash=sha256:0e4e2bf29574a6a7b7f6cb5fa69293b9f96c928949ac4a53ba3f525dffb87f9c \ + --hash=sha256:15163165efc2f627eb9687ea5f3a28137217d217ac4024893d753f46bce9de26 \ + --hash=sha256:17680d737d5335b552994a2008fab4c851bcd7de33094a82067ef3a576ff02fa \ + --hash=sha256:1a12cf6398e8a0a001a059747a1cbf24705e18fe413bc22de7b3d15c67cffe3f \ + --hash=sha256:1b11d6a633e4ed84fc0ddafd4ebfd8ea49b3f25082c04ad12b8315c11d504dc1 \ + --hash=sha256:1fa333e8b2ce4d9660f2cda9c0e1b6bafcfb2457a9d259faa82289e73ec24891 \ + --hash=sha256:2327a4a30d3ee07d2fbe2e7933e8a37c591663b96ce42a00bc67461a87d7df77 \ + --hash=sha256:2405a7d98604b87f3fc28b1716783534b1b4b8510d8142adca34ee0bc3c87543 \ + --hash=sha256:2489e4e5d7ef9a1c300a5e0196e43d9c739f066ef23270607d45aba368b91f2d \ + --hash=sha256:24c175051354f4a28c5d6a31c93906dc653e2bf234e8a4bbfb964892078898ce \ + --hash=sha256:2635d352d67458b66fd0667c14cb1d4145e9560d503219034a18a87e971ce4f3 \ + --hash=sha256:2c1a4f57df73965f3f14df20b80ee29e6a7930a57d2d9e8491a25f676e197c60 \ + --hash=sha256:2c93f00dcba2eea70af2be5f11a830a742fe6b579a1d4e00f47760ef13be247a \ + --hash=sha256:39a219e1c81ae3b103643d2aedb90f1ef22650deb266ff12a19e7773f3e5f089 \ + --hash=sha256:3b3115b2581ea35bb6d1f24a4c90af37e5d9b49dcff267eeed14c3893c5b86ab \ + --hash=sha256:40092754720b174e6ccf9e845d0d8c7d8e12c3d71e7fc35f55f3813e96376f78 \ + --hash=sha256:412f287c55a6f54b0650bd9b6dce5aceddb95864a1a90c87af16979d37c89771 \ + --hash=sha256:464415881e4801295659462c49461a24fb107c140de781d55518c4b80cb6790f \ + --hash=sha256:497d05f29a1300d14e02e6441cf0f5ee81c1ff5a304b0d9fb77423974684e08b \ + --hash=sha256:4a2899935e724dd1074cb568ce7ac0dce28b2cd6ab539c8e001a8578eb106d14 \ + --hash=sha256:4a48a2ce79d65d363597ef7b567ce3d14d68783d2b2263d98db3d9477805ba32 \ + --hash=sha256:4d1d9e582ad4d63062d34077a9a1e9f3c34088a2ec5135b1f7190c07cf366527 \ + --hash=sha256:52a15b0f35dad39862d376df10c5230155243a2c1a436e39eb55623ccbd68185 \ + --hash=sha256:540c7c72324d864406a009d72f5d6856f49693db95d1fbb46cf86febef873634 \ + --hash=sha256:5656aa670507437af0207645273ccdfee4f14bacd7f7c67a4306d0dcaeaf6eed \ + --hash=sha256:5a0f2724dfd4e3b3ac5a82436a8e6fd16baa7d507117e4279b660fe8ca38a3a1 \ + --hash=sha256:60c439763a969a6af93b4881db0eed8fadf93ee98e18cbc35bc8da868d0c4f0c \ + --hash=sha256:61874cdb0a36016354853593cffc38e56fc9ca5aa97d2c05d3dcf6922cd55a11 \ + --hash=sha256:67bb8b474b4181770f926f7b7d2f8c0248cbcb78b660fdd41a47054b28d2a752 \ + --hash=sha256:720e05574713db64c356e86732c0f3c5252818d05f9df320f0ad8380641acea5 \ + --hash=sha256:72d0eb9fba308b8311685c2268cf7d0a0639a6cd027d8128659f72bdd8a024b4 \ + --hash=sha256:767c23ad1c58c9e827b649a9ab7809fd5fd9db266a9cf02b0e926ddc2c680d58 \ + --hash=sha256:77937e5e2a38a7b48eef0585114fe7930346993a88060d0bf886086d2aa49ef5 \ + --hash=sha256:7a08b491ec91b1d5053ac177afe5290adacf1f0f6307d771ccac5de30592d198 \ + --hash=sha256:7b4da0d01ac866a57dd61ac258c5607b4cd677f63abaec7b148354d2b2cdd536 \ + --hash=sha256:7cf974dd4e35fa315563ac99d6287a1024e4dc2077b8a7d7cd3d2fb65d283134 \ + --hash=sha256:84fd60810829c27ae375114cd379da1fa65e6918e1da405f356a775d49a62bcf \ + --hash=sha256:858e4c22fb075920b96a291928cb7dea5644e94c0ee4fcd5af7e865655e4ccf2 \ + --hash=sha256:85b5352f94e490c028926ea567fc569c52ec79ce131dadb968d3853e809518c2 \ + --hash=sha256:85bd218b5ecfbee8c8a82e121802dcb519a86044c9c3b2e4aef02fa05c6da370 \ + --hash=sha256:8a1f570ce4d62d718dce3f179ee78dac3b545ac16c0c04bb363b7607a949c0d1 \ + --hash=sha256:8fdca1def57a2e88ef339de1737a1449d6dbf5fab184c54a1fca01d541317154 \ + --hash=sha256:90f47e70293fc3688b71271100a1a5453aa9944a81d27ff779c108372cf5567b \ + --hash=sha256:92a2f997387a1b79a75e7803aa7ded2cfbe2823852ccf1ba3bcf613b62ae3197 \ + --hash=sha256:9928fe1eb816d11ae170885a74d074f57af3a0d65777ca47e9aeb854a1fba386 \ + --hash=sha256:9af39d6551f97d31a4deebeac6f45b156f9755ddc59c07b402c148f5dbb6482a \ + --hash=sha256:9cf554f21be770f5111a1690d42313e140355e687e05cf82cb23d0a721a64a48 \ + --hash=sha256:a30fd6fdef1430fd9e1ba7b3398b5ee4e2887783917a687d86ba69985fb08748 \ + --hash=sha256:a31d512c812daea6d8b3be3b2bfcbeb091dbb09177706569bcfc6240dcf8b41c \ + --hash=sha256:a5d0432ccf1c7ab14f9949eec60c5d1f924f17c037e9f8b33352fa05799359b8 \ + --hash=sha256:a60ea74330b91bd22a29638940d115df9dc00af5035a9a2a6ad9399ffb4ceca5 \ + --hash=sha256:ac5a486ac389dddcc5bef4f365b6ae3ffff2c433324fb38dd35e3fab7c957999 \ + --hash=sha256:aedff62918805fb62d43a4aa2ecd4482c380dc76cd31bd7c8878588a61bd0369 \ + --hash=sha256:b34e51affded8faee0dfdb705416153819d8ea9250bbbf7ea1b249bdeb5f1122 \ + --hash=sha256:b4b4d74bda2b8ebf4da5bd42af11d02d04428b2c32846e4c2c93219df8a7987b \ + --hash=sha256:b67e6efbf68e077dd71d1a6b37e43e1a99d0bff1a3d51867d45ee8908b931098 \ + --hash=sha256:b78efa4c6e804ecdf727e580dbb9cba85624d2e1c6b5cb059c66290063bd99a9 \ + --hash=sha256:bb4ae2b57fc1d8cbd1cf7b1d9913803681ffa903e7488012be5b76dedf49297f \ + --hash=sha256:bdd1a81a1860476eb41ac4bc1e07b3f07259e6d55bbf739b79c8aaedcf512799 \ + --hash=sha256:bdee92c56a71d2b24c33a7d4c2856bd6419d017e08caa7802d2963870e315028 \ + --hash=sha256:be6a04e6c79819c9a8c2373317d19a96048e5a3f90bec587787e86a1153883c2 \ + --hash=sha256:bfc08add558155345129c7803b3671cf195e6a56e7a12f3dde7c57d9b417f525 \ + --hash=sha256:c3b22c26c6fd6811b0ae8363b95ca8ce4ea3c202d3d0975b2914310ceb1bcc4d \ + --hash=sha256:c9e7cdf45d594ee04d5be1b24dd9d49f3d1590959b2271fb30b5ca2b262c00fb \ + --hash=sha256:cb27e7b78d716c591e88e0a09a2139c6577865d7f2e152488c2cc6257f460872 \ + --hash=sha256:cc9617b46837c6468197b5945e196ee9ca43057bb7d9d1ae688101e4e1dddf64 \ + --hash=sha256:ccd09f20ccdbbd341b21a67ab50a119b64a403b09288c27481575105283c1586 \ + --hash=sha256:ce6a3a4e106cf35c2d9c4fa17c05ce0b180db622736845d4315519397a77beaf \ + --hash=sha256:d0005b053977e7b43388ddec89fa567f43d4f6d5c2c0affe57de5ebf290dc552 \ + --hash=sha256:d4188e73af84ca82468f09cadc5ac4db578109e52acb4518d8154698d3a87ca2 \ + --hash=sha256:d4efec7bcf21671db6a3294ff301d2fc861c31faa3c8740d1a94689234d1b415 \ + --hash=sha256:d75aa530ccfaa593da12834b86a0724f58bff12706659baa9227c2ccaa06264c \ + --hash=sha256:d84cd4061ae292d8ac367b2c3fa3aad11cb8625a95d135fe93f286f914f3f5a6 \ + --hash=sha256:d8aacd3d4b33b772542b2e01beb50187536967b514b00003bdda7589722d2a64 \ + --hash=sha256:d8fc5c867c22b828001b6a38d2eaeb88160bf5783c6cb4a5e440efc981ce286d \ + --hash=sha256:d976bbb382b202f71c67f77b0ac11244021cfa3f7dfd9e562eefcea2df711548 \ + --hash=sha256:dba5ee5d3981160c28d5490f0d1b7ed730c22470ff7f6cc26cfcfaacb9896a07 \ + --hash=sha256:dc1ae486f9abcef254b5618dfb4113dd49f94c68e3e027d03cf0143f3f772b61 \ + --hash=sha256:dd0a578400839256df88c16abddf9ba14813ec5f21362e1fe65022e00c883d4d \ + --hash=sha256:deed0c7258ceb4c44ad5ec7d9918f9f14fd05b2be86378d86cf50e63d1e7b771 \ + --hash=sha256:e09c2279a4d01f099f52d5c4b3d9e208e91edcbd1a175c9662a8b16e000fece9 \ + --hash=sha256:e2ea9f7ab7fbf18fffb1b5434ce7c69a07582f7acc7717720f1d69f3e806f90c \ + --hash=sha256:e6b93f13371d341afee3be9f7c5964e3fe61d5fa30f6a30eb49856935dfe4fc3 \ + --hash=sha256:eb14a5da6dc7642b0f3a18f13654847cd8b7a2550e2645a5bda677862b03ba16 \ + --hash=sha256:ed0fecd28cc62c54b262e3736f8bb2512d8dcfdc2bcf08be5f47f96bf405b145 \ + --hash=sha256:ede8c6d533bc6601a47ad4046080d36b8fc99f81e6f1c17b0ac3c2dc91ac7611 \ + --hash=sha256:efb3a45b35622bb6c16dbfab491a8f5a391fe0e9d45ef32f4df85658232ca0e2 \ + --hash=sha256:f117e1a089d9411663a3207ba874f31be9ac8eaa5b533787024dc07aeb74f464 \ + --hash=sha256:f2ba92255faa7309d06fe44c3a4a97efe1c8d640c2a79a5ef728b685762a6fd2 \ + --hash=sha256:f6008a4919fdbc0b0097089f67a1eb55d950ed7e90ce2cc3e640abadd2757a04 \ + --hash=sha256:f68208a520c3d86ea51acf688a3e3002615a7f0238002cccc17affecc86a8a54 \ + --hash=sha256:f68e4f3eeca8fb22cc3d731f9715a13b652795ef657a13df1ad0c7dc0e9731df \ + --hash=sha256:fb3b8132019ea572f4611d770991000d7f58127560c4889729248eb5852a102f \ + --hash=sha256:fb940820c63a9590d31d88b815e7a3aa5915cad3ce735ab45f0c730b39547de1 \ + --hash=sha256:fc1795ac5cd0510207482c3d1d3ed781143383b8cfd36f5c645f3897ce066220 # via matplotlib -markdown-it-py==3.0.0 \ - --hash=sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 \ - --hash=sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb +libtpu==0.0.32 ; sys_platform == "linux" and platform_machine == "x86_64" \ + --hash=sha256:37f6aefe6d69d24e5e268e74c1e90cd0ce0fa6a796720709445d8938605912ee \ + --hash=sha256:4a4ae6db0a90f0e33902cd345923a5d77dfc5bbab9a3e524c79bf876858104ee \ + --hash=sha256:6453995774b902e43d1caf0d82298a2e90f2f731ddd74d6bef9510d91732775f \ + --hash=sha256:7ffd9cdb89da22d32bb60ce70b88e9e76861c5d704ebf673f32987e19e9913f3 \ + --hash=sha256:98d9713a39d9581c9b10a7bdeaa575cc874a4a7f543cfb40d3a723ee27b428b5 \ + --hash=sha256:d2a0e7abe4a029b79049bc95da20e4c2a64f8ef09823f75286cecfbb4b0b6da1 + # via -r build/requirements.in +markdown-it-py==4.0.0 \ + --hash=sha256:87327c59b172c5011896038353a81343b6754500a08cd7a4973bb48c6d578147 \ + --hash=sha256:cb0a2b4aa34f932c007117b194e945bd74e0ec24133ceb5bac59009cda1cb9f3 # via rich -matplotlib==3.10.0 ; python_version >= "3.11" \ - --hash=sha256:01d2b19f13aeec2e759414d3bfe19ddfb16b13a1250add08d46d5ff6f9be83c6 \ - --hash=sha256:12eaf48463b472c3c0f8dbacdbf906e573013df81a0ab82f0616ea4b11281908 \ - --hash=sha256:2c5829a5a1dd5a71f0e31e6e8bb449bc0ee9dbfb05ad28fc0c6b55101b3a4be6 \ - --hash=sha256:2fbbabc82fde51391c4da5006f965e36d86d95f6ee83fb594b279564a4c5d0d2 \ - --hash=sha256:3547d153d70233a8496859097ef0312212e2689cdf8d7ed764441c77604095ae \ - --hash=sha256:359f87baedb1f836ce307f0e850d12bb5f1936f70d035561f90d41d305fdacea \ - --hash=sha256:3b427392354d10975c1d0f4ee18aa5844640b512d5311ef32efd4dd7db106ede \ - --hash=sha256:4659665bc7c9b58f8c00317c3c2a299f7f258eeae5a5d56b4c64226fca2f7c59 \ - --hash=sha256:4673ff67a36152c48ddeaf1135e74ce0d4bce1bbf836ae40ed39c29edf7e2765 \ - --hash=sha256:503feb23bd8c8acc75541548a1d709c059b7184cde26314896e10a9f14df5f12 \ - --hash=sha256:5439f4c5a3e2e8eab18e2f8c3ef929772fd5641876db71f08127eed95ab64683 \ - --hash=sha256:5cdbaf909887373c3e094b0318d7ff230b2ad9dcb64da7ade654182872ab2593 \ - --hash=sha256:5e6c6461e1fc63df30bf6f80f0b93f5b6784299f721bc28530477acd51bfc3d1 \ - --hash=sha256:5fd41b0ec7ee45cd960a8e71aea7c946a28a0b8a4dcee47d2856b2af051f334c \ - --hash=sha256:607b16c8a73943df110f99ee2e940b8a1cbf9714b65307c040d422558397dac5 \ - --hash=sha256:7e8632baebb058555ac0cde75db885c61f1212e47723d63921879806b40bec6a \ - --hash=sha256:81713dd0d103b379de4516b861d964b1d789a144103277769238c732229d7f03 \ - --hash=sha256:845d96568ec873be63f25fa80e9e7fae4be854a66a7e2f0c8ccc99e94a8bd4ef \ - --hash=sha256:95b710fea129c76d30be72c3b38f330269363fbc6e570a5dd43580487380b5ff \ - --hash=sha256:96f2886f5c1e466f21cc41b70c5a0cd47bfa0015eb2d5793c88ebce658600e25 \ - --hash=sha256:994c07b9d9fe8d25951e3202a68c17900679274dadfc1248738dcfa1bd40d7f3 \ - --hash=sha256:9ade1003376731a971e398cc4ef38bb83ee8caf0aee46ac6daa4b0506db1fd06 \ - --hash=sha256:9b0558bae37f154fffda54d779a592bc97ca8b4701f1c710055b609a3bac44c8 \ - --hash=sha256:a2a43cbefe22d653ab34bb55d42384ed30f611bcbdea1f8d7f431011a2e1c62e \ - --hash=sha256:a994f29e968ca002b50982b27168addfd65f0105610b6be7fa515ca4b5307c95 \ - --hash=sha256:ad2e15300530c1a94c63cfa546e3b7864bd18ea2901317bae8bbf06a5ade6dcf \ - --hash=sha256:ae80dc3a4add4665cf2faa90138384a7ffe2a4e37c58d83e115b54287c4f06ef \ - --hash=sha256:b886d02a581b96704c9d1ffe55709e49b4d2d52709ccebc4be42db856e511278 \ - --hash=sha256:c40ba2eb08b3f5de88152c2333c58cee7edcead0a2a0d60fcafa116b17117adc \ - --hash=sha256:c55b20591ced744aa04e8c3e4b7543ea4d650b6c3c4b208c08a05b4010e8b442 \ - --hash=sha256:c58a9622d5dbeb668f407f35f4e6bfac34bb9ecdcc81680c04d0258169747997 \ - --hash=sha256:d44cb942af1693cced2604c33a9abcef6205601c445f6d0dc531d813af8a2f5a \ - --hash=sha256:d907fddb39f923d011875452ff1eca29a9e7f21722b873e90db32e5d8ddff12e \ - --hash=sha256:fd44fc75522f58612ec4a33958a7e5552562b7705b42ef1b4f8c0818e304a363 +matplotlib==3.10.8 \ + --hash=sha256:00270d217d6b20d14b584c521f810d60c5c78406dc289859776550df837dcda7 \ + --hash=sha256:0a33deb84c15ede243aead39f77e990469fff93ad1521163305095b77b72ce4a \ + --hash=sha256:113bb52413ea508ce954a02c10ffd0d565f9c3bc7f2eddc27dfe1731e71c7b5f \ + --hash=sha256:12d90df9183093fcd479f4172ac26b322b1248b15729cb57f42f71f24c7e37a3 \ + --hash=sha256:15d30132718972c2c074cd14638c7f4592bd98719e2308bccea40e0538bc0cb5 \ + --hash=sha256:18821ace09c763ec93aef5eeff087ee493a24051936d7b9ebcad9662f66501f9 \ + --hash=sha256:1ae029229a57cd1e8fe542485f27e7ca7b23aa9e8944ddb4985d0bc444f1eca2 \ + --hash=sha256:2299372c19d56bcd35cf05a2738308758d32b9eaed2371898d8f5bd33f084aa3 \ + --hash=sha256:238b7ce5717600615c895050239ec955d91f321c209dd110db988500558e70d6 \ + --hash=sha256:24d50994d8c5816ddc35411e50a86ab05f575e2530c02752e02538122613371f \ + --hash=sha256:25d380fe8b1dc32cf8f0b1b448470a77afb195438bafdf1d858bfb876f3edf7b \ + --hash=sha256:2c1998e92cd5999e295a731bcb2911c75f597d937341f3030cc24ef2733d78a8 \ + --hash=sha256:2cf5bd12cecf46908f286d7838b2abc6c91cda506c0445b8223a7c19a00df008 \ + --hash=sha256:32f8dce744be5569bebe789e46727946041199030db8aeb2954d26013a0eb26b \ + --hash=sha256:37b3c1cc42aa184b3f738cfa18c1c1d72fd496d85467a6cf7b807936d39aa656 \ + --hash=sha256:3a48a78d2786784cc2413e57397981fb45c79e968d99656706018d6e62e57958 \ + --hash=sha256:3ab4aabc72de4ff77b3ec33a6d78a68227bf1123465887f9905ba79184a1cc04 \ + --hash=sha256:3c624e43ed56313651bc18a47f838b60d7b8032ed348911c54906b130b20071b \ + --hash=sha256:3f2e409836d7f5ac2f1c013110a4d50b9f7edc26328c108915f9075d7d7a91b6 \ + --hash=sha256:3f5c3e4da343bba819f0234186b9004faba952cc420fbc522dc4e103c1985908 \ + --hash=sha256:41703cc95688f2516b480f7f339d8851a6035f18e100ee6a32bc0b8536a12a9c \ + --hash=sha256:495672de149445ec1b772ff2c9ede9b769e3cb4f0d0aa7fa730d7f59e2d4e1c1 \ + --hash=sha256:4cf267add95b1c88300d96ca837833d4112756045364f5c734a2276038dae27d \ + --hash=sha256:56271f3dac49a88d7fca5060f004d9d22b865f743a12a23b1e937a0be4818ee1 \ + --hash=sha256:595ba4d8fe983b88f0eec8c26a241e16d6376fe1979086232f481f8f3f67494c \ + --hash=sha256:5f62550b9a30afde8c1c3ae450e5eb547d579dd69b25c2fc7a1c67f934c1717a \ + --hash=sha256:646d95230efb9ca614a7a594d4fcacde0ac61d25e37dd51710b36477594963ce \ + --hash=sha256:64fcc24778ca0404ce0cb7b6b77ae1f4c7231cdd60e6778f999ee05cbd581b9a \ + --hash=sha256:6be43b667360fef5c754dda5d25a32e6307a03c204f3c0fc5468b78fa87b4160 \ + --hash=sha256:6da7c2ce169267d0d066adcf63758f0604aa6c3eebf67458930f9d9b79ad1db1 \ + --hash=sha256:83d282364ea9f3e52363da262ce32a09dfe241e4080dcedda3c0db059d3c1f11 \ + --hash=sha256:9153c3292705be9f9c64498a8872118540c3f4123d1a1c840172edf262c8be4a \ + --hash=sha256:99eefd13c0dc3b3c1b4d561c1169e65fe47aab7b8158754d7c084088e2329466 \ + --hash=sha256:a0a7f52498f72f13d4a25ea70f35f4cb60642b466cbb0a9be951b5bc3f45a486 \ + --hash=sha256:a2b336e2d91a3d7006864e0990c83b216fcdca64b5a6484912902cef87313d78 \ + --hash=sha256:a48f2b74020919552ea25d222d5cc6af9ca3f4eb43a93e14d068457f545c2a17 \ + --hash=sha256:ad3d9833a64cf48cc4300f2b406c3d0f4f4724a91c0bd5640678a6ba7c102077 \ + --hash=sha256:b44d07310e404ba95f8c25aa5536f154c0a8ec473303535949e52eb71d0a1565 \ + --hash=sha256:b53285e65d4fa4c86399979e956235deb900be5baa7fc1218ea67fbfaeaadd6f \ + --hash=sha256:b5a2b97dbdc7d4f353ebf343744f1d1f1cca8aa8bfddb4262fcf4306c3761d50 \ + --hash=sha256:b9a5ca4ac220a0cdd1ba6bcba3608547117d30468fefce49bb26f55c1a3d5c58 \ + --hash=sha256:bab485bcf8b1c7d2060b4fcb6fc368a9e6f4cd754c9c2fea281f4be21df394a2 \ + --hash=sha256:c108a1d6fa78a50646029cb6d49808ff0fc1330fda87fa6f6250c6b5369b6645 \ + --hash=sha256:d56a1efd5bfd61486c8bc968fa18734464556f0fb8e51690f4ac25d85cbbbbc2 \ + --hash=sha256:d9050fee89a89ed57b4fb2c1bfac9a3d0c57a0d55aed95949eedbc42070fea39 \ + --hash=sha256:dd80ecb295460a5d9d260df63c43f4afbdd832d725a531f008dad1664f458adf \ + --hash=sha256:e8ea3e2d4066083e264e75c829078f9e149fa119d27e19acd503de65e0b13149 \ + --hash=sha256:eb3823f11823deade26ce3b9f40dcb4a213da7a670013929f31d5f5ed1055b22 \ + --hash=sha256:ee40c27c795bda6a5292e9cff9890189d32f7e3a0bf04e0e3c9430c4a00c37df \ + --hash=sha256:efb30e3baaea72ce5928e32bab719ab4770099079d66726a62b11b1ef7273be4 \ + --hash=sha256:f254d118d14a7f99d616271d6c3c27922c092dac11112670b157798b89bf4933 \ + --hash=sha256:f89c151aab2e2e23cb3fe0acad1e8b82841fd265379c4cecd0f3fcb34c15e0f6 \ + --hash=sha256:f97aeb209c3d2511443f8797e3e5a569aebb040d4f8bc79aa3ee78a8fb9e3dd8 \ + --hash=sha256:f9b587c9c7274c1613a30afabf65a272114cd6cdbe67b3406f818c79d7ab2e2a \ + --hash=sha256:fb061f596dad3a0f52b60dc6a5dec4a0c300dec41e058a7efe09256188d170b7 # via -r build/test-requirements.txt mdurl==0.1.2 \ --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba # via markdown-it-py -ml-dtypes==0.5.1 \ - --hash=sha256:023ce2f502efd4d6c1e0472cc58ce3640d051d40e71e27386bed33901e201327 \ - --hash=sha256:05f23447a1c20ddf4dc7c2c661aa9ed93fcb2658f1017c204d1e758714dc28a8 \ - --hash=sha256:12651420130ee7cc13059fc56dac6ad300c3af3848b802d475148c9defd27c23 \ - --hash=sha256:141b2ea2f20bb10802ddca55d91fe21231ef49715cfc971998e8f2a9838f3dbe \ - --hash=sha256:15ad0f3b0323ce96c24637a88a6f44f6713c64032f27277b069f285c3cf66478 \ - --hash=sha256:1b7fbe5571fdf28fd3aaab3ef4aafc847de9ebf263be959958c1ca58ec8eadf5 \ - --hash=sha256:26ebcc69d7b779c8f129393e99732961b5cc33fcff84090451f448c89b0e01b4 \ - --hash=sha256:6f462f5eca22fb66d7ff9c4744a3db4463af06c49816c4b6ac89b16bfcdc592e \ - --hash=sha256:6f76232163b5b9c34291b54621ee60417601e2e4802a188a0ea7157cd9b323f4 \ - --hash=sha256:7000b6e4d8ef07542c05044ec5d8bbae1df083b3f56822c3da63993a113e716f \ - --hash=sha256:810512e2eccdfc3b41eefa3a27402371a3411453a1efc7e9c000318196140fed \ - --hash=sha256:8f2c028954f16ede77902b223a8da2d9cbb3892375b85809a5c3cfb1587960c4 \ - --hash=sha256:9626d0bca1fb387d5791ca36bacbba298c5ef554747b7ebeafefb4564fc83566 \ - --hash=sha256:ac5b58559bb84a95848ed6984eb8013249f90b6bab62aa5acbad876e256002c9 \ - --hash=sha256:ad4953c5eb9c25a56d11a913c2011d7e580a435ef5145f804d98efa14477d390 \ - --hash=sha256:aefedc579ece2f8fb38f876aa7698204ee4c372d0e54f1c1ffa8ca580b54cc60 \ - --hash=sha256:afb2009ac98da274e893e03162f6269398b2b00d947e7057ee2469a921d58135 \ - --hash=sha256:b8a9d46b4df5ae2135a8e8e72b465448ebbc1559997f4f9304a9ecc3413efb5b \ - --hash=sha256:bd73f51957949069573ff783563486339a9285d72e2f36c18e0c1aa9ca7eb190 \ - --hash=sha256:bf9975bda82a99dc935f2ae4c83846d86df8fd6ba179614acac8e686910851da \ - --hash=sha256:c09526488c3a9e8b7a23a388d4974b670a9a3dd40c5c8a61db5593ce9b725bab \ - --hash=sha256:c9945669d3dadf8acb40ec2e57d38c985d8c285ea73af57fc5b09872c516106d \ - --hash=sha256:d13755f8e8445b3870114e5b6240facaa7cb0c3361e54beba3e07fa912a6e12b \ - --hash=sha256:fd918d4e6a4e0c110e2e05be7a7814d10dc1b95872accbf6512b80a109b71ae1 - # via -r build/requirements.in +ml-dtypes==0.5.4 \ + --hash=sha256:0d2ffd05a2575b1519dc928c0b93c06339eb67173ff53acb00724502cda231cf \ + --hash=sha256:11942cbf2cf92157db91e5022633c0d9474d4dfd813a909383bd23ce828a4b7d \ + --hash=sha256:14a4fd3228af936461db66faccef6e4f41c1d82fcc30e9f8d58a08916b1d811f \ + --hash=sha256:19b9a53598f21e453ea2fbda8aa783c20faff8e1eeb0d7ab899309a0053f1483 \ + --hash=sha256:2314892cdc3fcf05e373d76d72aaa15fda9fb98625effa73c1d646f331fcecb7 \ + --hash=sha256:2b857d3af6ac0d39db1de7c706e69c7f9791627209c3d6dedbfca8c7e5faec22 \ + --hash=sha256:304ad47faa395415b9ccbcc06a0350800bc50eda70f0e45326796e27c62f18b6 \ + --hash=sha256:35f29491a3e478407f7047b8a4834e4640a77d2737e0b294d049746507af5175 \ + --hash=sha256:388d399a2152dd79a3f0456a952284a99ee5c93d3e2f8dfe25977511e0515270 \ + --hash=sha256:3bbbe120b915090d9dd1375e4684dd17a20a2491ef25d640a908281da85e73f1 \ + --hash=sha256:3d277bf3637f2a62176f4575512e9ff9ef51d00e39626d9fe4a161992f355af2 \ + --hash=sha256:4381fe2f2452a2d7589689693d3162e876b3ddb0a832cde7a414f8e1adf7eab1 \ + --hash=sha256:4ff7f3e7ca2972e7de850e7b8fcbb355304271e2933dd90814c1cb847414d6e2 \ + --hash=sha256:531eff30e4d368cb6255bc2328d070e35836aa4f282a0fb5f3a0cd7260257298 \ + --hash=sha256:533ce891ba774eabf607172254f2e7260ba5f57bdd64030c9a4fcfbd99815d0d \ + --hash=sha256:557a31a390b7e9439056644cb80ed0735a6e3e3bb09d67fd5687e4b04238d1de \ + --hash=sha256:5a0f68ca8fd8d16583dfa7793973feb86f2fbb56ce3966daf9c9f748f52a2049 \ + --hash=sha256:6a0df4223b514d799b8a1629c65ddc351b3efa833ccf7f8ea0cf654a61d1e35d \ + --hash=sha256:6c7ecb74c4bd71db68a6bea1edf8da8c34f3d9fe218f038814fd1d310ac76c90 \ + --hash=sha256:7c23c54a00ae43edf48d44066a7ec31e05fdc2eee0be2b8b50dd1903a1db94bb \ + --hash=sha256:805cef3a38f4eafae3a5bf9ebdcdb741d0bcfd9e1bd90eb54abd24f928cd2465 \ + --hash=sha256:88c982aac7cb1cbe8cbb4e7f253072b1df872701fcaf48d84ffbb433b6568f24 \ + --hash=sha256:8ab06a50fb9bf9666dd0fe5dfb4676fa2b0ac0f31ecff72a6c3af8e22c063453 \ + --hash=sha256:8c6a2dcebd6f3903e05d51960a8058d6e131fe69f952a5397e5dbabc841b6d56 \ + --hash=sha256:8c760d85a2f82e2bed75867079188c9d18dae2ee77c25a54d60e9cc79be1bc48 \ + --hash=sha256:9ad459e99793fa6e13bd5b7e6792c8f9190b4e5a1b45c63aba14a4d0a7f1d5ff \ + --hash=sha256:9bad06436568442575beb2d03389aa7456c690a5b05892c471215bfd8cf39460 \ + --hash=sha256:a174837a64f5b16cab6f368171a1a03a27936b31699d167684073ff1c4237dac \ + --hash=sha256:a7f7c643e8b1320fd958bf098aa7ecf70623a42ec5154e3be3be673f4c34d900 \ + --hash=sha256:a9b61c19040397970d18d7737375cffd83b1f36a11dd4ad19f83a016f736c3ef \ + --hash=sha256:b4b801ebe0b477be666696bda493a9be8356f1f0057a57f1e35cd26928823e5a \ + --hash=sha256:b95e97e470fe60ed493fd9ae3911d8da4ebac16bd21f87ffa2b7c588bf22ea2c \ + --hash=sha256:bc11d7e8c44a65115d05e2ab9989d1e045125d7be8e05a071a48bc76eb6d6040 \ + --hash=sha256:bfc534409c5d4b0bf945af29e5d0ab075eae9eecbb549ff8a29280db822f34f9 \ + --hash=sha256:c1a953995cccb9e25a4ae19e34316671e4e2edaebe4cf538229b1fc7109087b7 \ + --hash=sha256:cb73dccfc991691c444acc8c0012bee8f2470da826a92e3a20bb333b1a7894e6 \ + --hash=sha256:ce756d3a10d0c4067172804c9cc276ba9cc0ff47af9078ad439b075d1abdc29b \ + --hash=sha256:d81fdb088defa30eb37bf390bb7dde35d3a83ec112ac8e33d75ab28cc29dd8b0 \ + --hash=sha256:f21c9219ef48ca5ee78402d5cc831bd58ea27ce89beda894428bc67a52da5328 + # via + # -r build/requirements.in + # jaxlib mpmath==1.3.0 \ --hash=sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f \ --hash=sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c # via -r build/test-requirements.txt -numpy==2.2.1 ; python_version >= "3.13" \ - --hash=sha256:059e6a747ae84fce488c3ee397cee7e5f905fd1bda5fb18c66bc41807ff119b2 \ - --hash=sha256:08ef779aed40dbc52729d6ffe7dd51df85796a702afbf68a4f4e41fafdc8bda5 \ - --hash=sha256:164a829b6aacf79ca47ba4814b130c4020b202522a93d7bff2202bfb33b61c60 \ - --hash=sha256:26c9c4382b19fcfbbed3238a14abf7ff223890ea1936b8890f058e7ba35e8d71 \ - --hash=sha256:27f5cdf9f493b35f7e41e8368e7d7b4bbafaf9660cba53fb21d2cd174ec09631 \ - --hash=sha256:31b89fa67a8042e96715c68e071a1200c4e172f93b0fbe01a14c0ff3ff820fc8 \ - --hash=sha256:32cb94448be47c500d2c7a95f93e2f21a01f1fd05dd2beea1ccd049bb6001cd2 \ - --hash=sha256:360137f8fb1b753c5cde3ac388597ad680eccbbbb3865ab65efea062c4a1fd16 \ - --hash=sha256:3683a8d166f2692664262fd4900f207791d005fb088d7fdb973cc8d663626faa \ - --hash=sha256:38efc1e56b73cc9b182fe55e56e63b044dd26a72128fd2fbd502f75555d92591 \ - --hash=sha256:3d03883435a19794e41f147612a77a8f56d4e52822337844fff3d4040a142964 \ - --hash=sha256:3ecc47cd7f6ea0336042be87d9e7da378e5c7e9b3c8ad0f7c966f714fc10d821 \ - --hash=sha256:40f9e544c1c56ba8f1cf7686a8c9b5bb249e665d40d626a23899ba6d5d9e1484 \ - --hash=sha256:4250888bcb96617e00bfa28ac24850a83c9f3a16db471eca2ee1f1714df0f957 \ - --hash=sha256:4511d9e6071452b944207c8ce46ad2f897307910b402ea5fa975da32e0102800 \ - --hash=sha256:45681fd7128c8ad1c379f0ca0776a8b0c6583d2f69889ddac01559dfe4390918 \ - --hash=sha256:48fd472630715e1c1c89bf1feab55c29098cb403cc184b4859f9c86d4fcb6a95 \ - --hash=sha256:4c86e2a209199ead7ee0af65e1d9992d1dce7e1f63c4b9a616500f93820658d0 \ - --hash=sha256:4dfda918a13cc4f81e9118dea249e192ab167a0bb1966272d5503e39234d694e \ - --hash=sha256:5062dc1a4e32a10dc2b8b13cedd58988261416e811c1dc4dbdea4f57eea61b0d \ - --hash=sha256:51faf345324db860b515d3f364eaa93d0e0551a88d6218a7d61286554d190d73 \ - --hash=sha256:526fc406ab991a340744aad7e25251dd47a6720a685fa3331e5c59fef5282a59 \ - --hash=sha256:53c09385ff0b72ba79d8715683c1168c12e0b6e84fb0372e97553d1ea91efe51 \ - --hash=sha256:55ba24ebe208344aa7a00e4482f65742969a039c2acfcb910bc6fcd776eb4355 \ - --hash=sha256:5b6c390bfaef8c45a260554888966618328d30e72173697e5cabe6b285fb2348 \ - --hash=sha256:5c5cc0cbabe9452038ed984d05ac87910f89370b9242371bd9079cb4af61811e \ - --hash=sha256:5edb4e4caf751c1518e6a26a83501fda79bff41cc59dac48d70e6d65d4ec4440 \ - --hash=sha256:61048b4a49b1c93fe13426e04e04fdf5a03f456616f6e98c7576144677598675 \ - --hash=sha256:676f4eebf6b2d430300f1f4f4c2461685f8269f94c89698d832cdf9277f30b84 \ - --hash=sha256:67d4cda6fa6ffa073b08c8372aa5fa767ceb10c9a0587c707505a6d426f4e046 \ - --hash=sha256:694f9e921a0c8f252980e85bce61ebbd07ed2b7d4fa72d0e4246f2f8aa6642ab \ - --hash=sha256:733585f9f4b62e9b3528dd1070ec4f52b8acf64215b60a845fa13ebd73cd0712 \ - --hash=sha256:7671dc19c7019103ca44e8d94917eba8534c76133523ca8406822efdd19c9308 \ - --hash=sha256:780077d95eafc2ccc3ced969db22377b3864e5b9a0ea5eb347cc93b3ea900315 \ - --hash=sha256:7ba9cc93a91d86365a5d270dee221fdc04fb68d7478e6bf6af650de78a8339e3 \ - --hash=sha256:89b16a18e7bba224ce5114db863e7029803c179979e1af6ad6a6b11f70545008 \ - --hash=sha256:9036d6365d13b6cbe8f27a0eaf73ddcc070cae584e5ff94bb45e3e9d729feab5 \ - --hash=sha256:93cf4e045bae74c90ca833cba583c14b62cb4ba2cba0abd2b141ab52548247e2 \ - --hash=sha256:9ad014faa93dbb52c80d8f4d3dcf855865c876c9660cb9bd7553843dd03a4b1e \ - --hash=sha256:9b1d07b53b78bf84a96898c1bc139ad7f10fda7423f5fd158fd0f47ec5e01ac7 \ - --hash=sha256:a7746f235c47abc72b102d3bce9977714c2444bdfaea7888d241b4c4bb6a78bf \ - --hash=sha256:aa3017c40d513ccac9621a2364f939d39e550c542eb2a894b4c8da92b38896ab \ - --hash=sha256:b34d87e8a3090ea626003f87f9392b3929a7bbf4104a05b6667348b6bd4bf1cd \ - --hash=sha256:b541032178a718c165a49638d28272b771053f628382d5e9d1c93df23ff58dbf \ - --hash=sha256:ba5511d8f31c033a5fcbda22dd5c813630af98c70b2661f2d2c654ae3cdfcfc8 \ - --hash=sha256:bc8a37ad5b22c08e2dbd27df2b3ef7e5c0864235805b1e718a235bcb200cf1cb \ - --hash=sha256:bff7d8ec20f5f42607599f9994770fa65d76edca264a87b5e4ea5629bce12268 \ - --hash=sha256:c1ad395cf254c4fbb5b2132fee391f361a6e8c1adbd28f2cd8e79308a615fe9d \ - --hash=sha256:f1d09e520217618e76396377c81fba6f290d5f926f50c35f3a5f72b01a0da780 \ - --hash=sha256:f3eac17d9ec51be534685ba877b6ab5edc3ab7ec95c8f163e5d7b39859524716 \ - --hash=sha256:f419290bc8968a46c4933158c91a0012b7a99bb2e465d5ef5293879742f8797e \ - --hash=sha256:f62aa6ee4eb43b024b0e5a01cf65a0bb078ef8c395e8713c6e8a12a697144528 \ - --hash=sha256:f74e6fdeb9a265624ec3a3918430205dff1df7e95a230779746a6af78bc615af \ - --hash=sha256:f9b57eaa3b0cd8db52049ed0330747b0364e899e8a606a624813452b8203d5f7 \ - --hash=sha256:fce4f615f8ca31b2e61aa0eb5865a21e14f5629515c9151850aa936c02a1ee51 +numpy==2.2.6 ; python_version == "3.13" \ + --hash=sha256:038613e9fb8c72b0a41f025a7e4c3f0b7a1b5d768ece4796b674c8f3fe13efff \ + --hash=sha256:0678000bb9ac1475cd454c6b8c799206af8107e310843532b04d49649c717a47 \ + --hash=sha256:0811bb762109d9708cca4d0b13c4f67146e3c3b7cf8d34018c722adb2d957c84 \ + --hash=sha256:0b605b275d7bd0c640cad4e5d30fa701a8d59302e127e5f79138ad62762c3e3d \ + --hash=sha256:0bca768cd85ae743b2affdc762d617eddf3bcf8724435498a1e80132d04879e6 \ + --hash=sha256:1bc23a79bfabc5d056d106f9befb8d50c31ced2fbc70eedb8155aec74a45798f \ + --hash=sha256:287cc3162b6f01463ccd86be154f284d0893d2b3ed7292439ea97eafa8170e0b \ + --hash=sha256:37c0ca431f82cd5fa716eca9506aefcabc247fb27ba69c5062a6d3ade8cf8f49 \ + --hash=sha256:37e990a01ae6ec7fe7fa1c26c55ecb672dd98b19c3d0e1d1f326fa13cb38d163 \ + --hash=sha256:389d771b1623ec92636b0786bc4ae56abafad4a4c513d36a55dce14bd9ce8571 \ + --hash=sha256:3d70692235e759f260c3d837193090014aebdf026dfd167834bcba43e30c2a42 \ + --hash=sha256:41c5a21f4a04fa86436124d388f6ed60a9343a6f767fced1a8a71c3fbca038ff \ + --hash=sha256:481b49095335f8eed42e39e8041327c05b0f6f4780488f61286ed3c01368d491 \ + --hash=sha256:4eeaae00d789f66c7a25ac5f34b71a7035bb474e679f410e5e1a94deb24cf2d4 \ + --hash=sha256:55a4d33fa519660d69614a9fad433be87e5252f4b03850642f88993f7b2ca566 \ + --hash=sha256:5a6429d4be8ca66d889b7cf70f536a397dc45ba6faeb5f8c5427935d9592e9cf \ + --hash=sha256:5bd4fc3ac8926b3819797a7c0e2631eb889b4118a9898c84f585a54d475b7e40 \ + --hash=sha256:5beb72339d9d4fa36522fc63802f469b13cdbe4fdab4a288f0c441b74272ebfd \ + --hash=sha256:6031dd6dfecc0cf9f668681a37648373bddd6421fff6c66ec1624eed0180ee06 \ + --hash=sha256:71594f7c51a18e728451bb50cc60a3ce4e6538822731b2933209a1f3614e9282 \ + --hash=sha256:74d4531beb257d2c3f4b261bfb0fc09e0f9ebb8842d82a7b4209415896adc680 \ + --hash=sha256:7befc596a7dc9da8a337f79802ee8adb30a552a94f792b9c9d18c840055907db \ + --hash=sha256:894b3a42502226a1cac872f840030665f33326fc3dac8e57c607905773cdcde3 \ + --hash=sha256:8e41fd67c52b86603a91c1a505ebaef50b3314de0213461c7a6e99c9a3beff90 \ + --hash=sha256:8e9ace4a37db23421249ed236fdcdd457d671e25146786dfc96835cd951aa7c1 \ + --hash=sha256:8fc377d995680230e83241d8a96def29f204b5782f371c532579b4f20607a289 \ + --hash=sha256:9551a499bf125c1d4f9e250377c1ee2eddd02e01eac6644c080162c0c51778ab \ + --hash=sha256:b0544343a702fa80c95ad5d3d608ea3599dd54d4632df855e4c8d24eb6ecfa1c \ + --hash=sha256:b093dd74e50a8cba3e873868d9e93a85b78e0daf2e98c6797566ad8044e8363d \ + --hash=sha256:b412caa66f72040e6d268491a59f2c43bf03eb6c96dd8f0307829feb7fa2b6fb \ + --hash=sha256:b4f13750ce79751586ae2eb824ba7e1e8dba64784086c98cdbbcc6a42112ce0d \ + --hash=sha256:b64d8d4d17135e00c8e346e0a738deb17e754230d7e0810ac5012750bbd85a5a \ + --hash=sha256:ba10f8411898fc418a521833e014a77d3ca01c15b0c6cdcce6a0d2897e6dbbdf \ + --hash=sha256:bd48227a919f1bafbdda0583705e547892342c26fb127219d60a5c36882609d1 \ + --hash=sha256:c1f9540be57940698ed329904db803cf7a402f3fc200bfe599334c9bd84a40b2 \ + --hash=sha256:c820a93b0255bc360f53eca31a0e676fd1101f673dda8da93454a12e23fc5f7a \ + --hash=sha256:ce47521a4754c8f4593837384bd3424880629f718d87c5d44f8ed763edd63543 \ + --hash=sha256:d042d24c90c41b54fd506da306759e06e568864df8ec17ccc17e9e884634fd00 \ + --hash=sha256:de749064336d37e340f640b05f24e9e3dd678c57318c7289d222a8a2f543e90c \ + --hash=sha256:e1dda9c7e08dc141e0247a5b8f49cf05984955246a327d4c48bda16821947b2f \ + --hash=sha256:e29554e2bef54a90aa5cc07da6ce955accb83f21ab5de01a62c8478897b264fd \ + --hash=sha256:e3143e4451880bed956e706a3220b4e5cf6172ef05fcc397f6f36a550b1dd868 \ + --hash=sha256:e8213002e427c69c45a52bbd94163084025f533a55a59d6f9c5b820774ef3303 \ + --hash=sha256:efd28d4e9cd7d7a8d39074a4d44c63eda73401580c5c76acda2ce969e0a38e83 \ + --hash=sha256:f0fd6321b839904e15c46e0d257fdd101dd7f530fe03fd6359c1ea63738703f3 \ + --hash=sha256:f1372f041402e37e5e633e586f62aa53de2eac8d98cbfb822806ce4bbefcb74d \ + --hash=sha256:f2618db89be1b4e05f7a1a847a9c1c0abd63e63a1607d892dd54668dd92faf87 \ + --hash=sha256:f447e6acb680fd307f40d3da4852208af94afdfab89cf850986c3ca00562f4fa \ + --hash=sha256:f92729c95468a2f4f15e9bb94c432a9229d0d50de67304399627a943201baa2f \ + --hash=sha256:f9f1adb22318e121c5c69a09142811a201ef17ab257a1e66ca3025065b7f53ae \ + --hash=sha256:fc0c5673685c508a142ca65209b4e79ed6740a4ed6b2267dbba90f34b0b3cfda \ + --hash=sha256:fc7b73d02efb0e18c000e9ad8b83480dfcd5dfd11065997ed4c6747470ae8915 \ + --hash=sha256:fd83c01228a688733f1ded5201c678f0c53ecc1006ffbc404db9f7a899ac6249 \ + --hash=sha256:fe27749d33bb772c80dcd84ae7e8df2adc920ae8297400dabec45f0dedb3f6de \ + --hash=sha256:fee4236c876c4e8369388054d02d0e9bb84821feb1a64dd59e137e6511a551f8 # via - # -r build/requirements.in + # -r build/freethreading-requirements.txt # contourpy + # jaxlib # matplotlib # ml-dtypes + # numpy-typing-compat + # optype # scipy -nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \ - --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \ - --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ - --hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9 +numpy-typing-compat==20251206.2.2 \ + --hash=sha256:93c9442985ef73dc5a18d29d6bc0f7d47a9afe95372d0a9fc68ca4802ea7ad86 \ + --hash=sha256:9d5bf8bca75a27ee1254fea5a2a783b5c862dd9f3e726d12bd4b6143932effd2 + # via optype +nvidia-cublas==13.1.0.3 ; sys_platform == "linux" \ + --hash=sha256:2a3b94a37def342471c59fad7856caee4926809a72dd5270155d6a31b5b277be \ + --hash=sha256:c86fc7f7ae36d7528288c5d88098edcb7b02c633d262e7ddbb86b0ad91be5df2 \ + --hash=sha256:ee8722c1f0145ab246bccb9e452153b5e0515fd094c3678df50b2a0888b8b171 # via - # -r build/test-requirements.txt + # -r build/nvidia-requirements.txt + # nvidia-cudnn-cu13 + # nvidia-cusolver +nvidia-cublas-cu12==12.9.1.4 ; sys_platform == "linux" \ + --hash=sha256:1e5fee10662e6e52bd71dec533fbbd4971bb70a5f24f3bc3793e5c2e9dc640bf \ + --hash=sha256:453611eb21a7c1f2c2156ed9f3a45b691deda0440ec550860290dc901af5b4c2 \ + --hash=sha256:7a950dae01add3b415a5a5cdc4ec818fb5858263e9cca59004bb99fdbbd3a5d6 + # via + # -r build/nvidia-requirements.txt # nvidia-cudnn-cu12 # nvidia-cusolver-cu12 -nvidia-cuda-cupti-cu12==12.8.57 ; sys_platform == "linux" \ - --hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \ - --hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \ - --hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317 - # via -r build/test-requirements.txt -nvidia-cuda-nvcc-cu12==12.8.61 ; sys_platform == "linux" \ - --hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \ - --hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ - --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b - # via -r build/test-requirements.txt -nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ - --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ - --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ - --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 - # via -r build/test-requirements.txt -nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \ - --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ - --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ - --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef - # via -r build/test-requirements.txt -nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ - --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ - --hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \ - --hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8 - # via -r build/test-requirements.txt -nvidia-cusolver-cu12==11.7.2.55 ; sys_platform == "linux" \ - --hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \ - --hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \ - --hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac - # via -r build/test-requirements.txt -nvidia-cusparse-cu12==12.5.7.53 ; sys_platform == "linux" \ - --hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \ - --hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \ - --hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1 +nvidia-cuda-crt==13.0.88 ; sys_platform == "linux" \ + --hash=sha256:2c8043c7c9e02492716426e9919fc78d2c5b3b2a7a768a88e952676b08aa55a4 \ + --hash=sha256:31e02c52916804ca15e31f272a96181d8fadaf40c4c82a77a6f78071a22eccf3 \ + --hash=sha256:ee2ea2a97073e02ee62bb27841f437332be2c248e3eac013df07997ada39c003 # via - # -r build/test-requirements.txt + # -r build/nvidia-requirements.txt + # nvidia-cuda-nvcc +nvidia-cuda-cupti==13.0.85 ; sys_platform == "linux" \ + --hash=sha256:4eb01c08e859bf924d222250d2e8f8b8ff6d3db4721288cf35d14252a4d933c8 \ + --hash=sha256:683f58d301548deeefcb8f6fac1b8d907691b9d8b18eccab417f51e362102f00 \ + --hash=sha256:796bd679890ee55fb14a94629b698b6db54bcfd833d391d5e94017dd9d7d3151 + # via -r build/nvidia-requirements.txt +nvidia-cuda-cupti-cu12==12.9.79 ; sys_platform == "linux" \ + --hash=sha256:096bcf334f13e1984ba36685ad4c1d6347db214de03dbb6eebb237b41d9d934f \ + --hash=sha256:1848a9380067560d5bee10ed240eecc22991713e672c0515f9c3d9396adf93c8 \ + --hash=sha256:791853b030602c6a11d08b5578edfb957cadea06e9d3b26adbf8d036135a4afe + # via -r build/nvidia-requirements.txt +nvidia-cuda-nvcc==13.0.88 ; sys_platform == "linux" \ + --hash=sha256:56fe502eb77625a12f25172caa3cdddb4e4c8ba2c8c17dba44b164761b380f03 \ + --hash=sha256:7c3a32c8ca9866addfd784da363ddee2f6874d560027a296f583e86a61f2d543 \ + --hash=sha256:c7ff28f86a24effdc6c034fa15230c549a273e4771b10a7fec14996f8cf3307f + # via -r build/nvidia-requirements.txt +nvidia-cuda-nvcc-cu12==12.9.86 ; sys_platform == "linux" \ + --hash=sha256:44e1eca4d08926193a558d2434b1bf83d57b4d5743e0c431c0c83d51da1df62b \ + --hash=sha256:5d6a0d32fdc7ea39917c20065614ae93add6f577d840233237ff08e9a38f58f0 \ + --hash=sha256:8ed7f0b17dea662755395be029376db3b94fed5cbb17c2d35cc866c5b1b84099 + # via -r build/nvidia-requirements.txt +nvidia-cuda-nvrtc==13.0.88 ; sys_platform == "linux" \ + --hash=sha256:6bcd4e7f8e205cbe644f5a98f2f799bef9556fefc89dd786e79a16312ce49872 \ + --hash=sha256:ad9b6d2ead2435f11cbb6868809d2adeeee302e9bb94bcf0539c7a40d80e8575 \ + --hash=sha256:d27f20a0ca67a4bb34268a5e951033496c5b74870b868bacd046b1b8e0c3267b + # via -r build/nvidia-requirements.txt +nvidia-cuda-nvrtc-cu12==12.9.86 ; sys_platform == "linux" \ + --hash=sha256:096d4de6bda726415dfaf3198d4f5c522b8e70139c97feef5cd2ca6d4cd9cead \ + --hash=sha256:210cf05005a447e29214e9ce50851e83fc5f4358df8b453155d5e1918094dcb4 \ + --hash=sha256:72972ebdcf504d69462d3bcd67e7b81edd25d0fb85a2c46d3ea3517666636349 + # via -r build/nvidia-requirements.txt +nvidia-cuda-runtime==13.0.96 ; sys_platform == "linux" \ + --hash=sha256:7f82250d7782aa23b6cfe765ecc7db554bd3c2870c43f3d1821f1d18aebf0548 \ + --hash=sha256:ef9bcbe90493a2b9d810e43d249adb3d02e98dd30200d86607d8d02687c43f55 \ + --hash=sha256:f79298c8a098cec150a597c8eba58ecdab96e3bdc4b9bc4f9983635031740492 + # via + # -r build/nvidia-requirements.txt + # nvidia-cuda-nvcc +nvidia-cuda-runtime-cu12==12.9.79 ; sys_platform == "linux" \ + --hash=sha256:25bba2dfb01d48a9b59ca474a1ac43c6ebf7011f1b0b8cc44f54eb6ac48a96c3 \ + --hash=sha256:83469a846206f2a733db0c42e223589ab62fd2fabac4432d2f8802de4bded0a4 \ + --hash=sha256:8e018af8fa02363876860388bd10ccb89eb9ab8fb0aa749aaf58430a9f7c4891 + # via -r build/nvidia-requirements.txt +nvidia-cudnn-cu12==9.16.0.29 ; sys_platform == "linux" \ + --hash=sha256:142e2bd646a4573ab17d61a24c6359155cdfe1f34c67fc305b71222a7ae45b8e \ + --hash=sha256:4b09c43096db582f110c5572d0bcbd98b30d709e860a8f73c6c3846baa83b8d2 \ + --hash=sha256:78d05b4434dacc7dd9bc903d5c33a2f28a5f0064d02568ef7b2418f89f6c5922 + # via -r build/nvidia-requirements.txt +nvidia-cudnn-cu13==9.16.0.29 ; sys_platform == "linux" \ + --hash=sha256:6349bc8769369a91611c5e2ce5c2e510e61848c245099c31e870d2cdce0ab90d \ + --hash=sha256:79dc1bfe8c1a780cf4eb7b334d14d7927576d6dd8823f8e2769911af30fd4da3 \ + --hash=sha256:faafa46e2e7dd844bbcf06b6adec3fa66924987f2fb21bf67f5c6fd697c74a64 + # via -r build/nvidia-requirements.txt +nvidia-cufft==12.0.0.61 ; sys_platform == "linux" \ + --hash=sha256:2708c852ef8cd89d1d2068bdbece0aa188813a0c934db3779b9b1faa8442e5f5 \ + --hash=sha256:2abce5b39d2f5ae12730fb7e5db6696533e36c26e2d3e8fd1750bdd2853364eb \ + --hash=sha256:6c44f692dce8fd5ffd3e3df134b6cdb9c2f72d99cf40b62c32dde45eea9ddad3 + # via -r build/nvidia-requirements.txt +nvidia-cufft-cu12==11.4.1.4 ; sys_platform == "linux" \ + --hash=sha256:1a28c9b12260a1aa7a8fd12f5ebd82d027963d635ba82ff39a1acfa7c4c0fbcf \ + --hash=sha256:8e5bfaac795e93f80611f807d42844e8e27e340e0cde270dcb6c65386d795b80 \ + --hash=sha256:c67884f2a7d276b4b80eb56a79322a95df592ae5e765cf1243693365ccab4e28 + # via -r build/nvidia-requirements.txt +nvidia-cusolver==12.0.4.66 ; sys_platform == "linux" \ + --hash=sha256:02c2457eaa9e39de20f880f4bd8820e6a1cfb9f9a34f820eb12a155aa5bc92d2 \ + --hash=sha256:0a759da5dea5c0ea10fd307de75cdeb59e7ea4fcb8add0924859b944babf1112 \ + --hash=sha256:16515bd33a8e76bb54d024cfa068fa68d30e80fc34b9e1090813ea9362e0cb65 + # via -r build/nvidia-requirements.txt +nvidia-cusolver-cu12==11.7.5.82 ; sys_platform == "linux" \ + --hash=sha256:15da72d1340d29b5b3cf3fd100e3cd53421dde36002eda6ed93811af63c40d88 \ + --hash=sha256:62efa83e4ace59a4c734d052bb72158e888aa7b770e1a5f601682f16fe5b4fd2 \ + --hash=sha256:77666337237716783c6269a658dea310195cddbd80a5b2919b1ba8735cec8efd + # via -r build/nvidia-requirements.txt +nvidia-cusparse==12.6.3.3 ; sys_platform == "linux" \ + --hash=sha256:2b3c89c88d01ee0e477cb7f82ef60a11a4bcd57b6b87c33f789350b59759360b \ + --hash=sha256:80bcc4662f23f1054ee334a15c72b8940402975e0eab63178fc7e670aa59472c \ + --hash=sha256:cbcf42feb737bd7ec15b4c0a63e62351886bd3f975027b8815d7f720a2b5ea79 + # via + # -r build/nvidia-requirements.txt + # nvidia-cusolver +nvidia-cusparse-cu12==12.5.10.65 ; sys_platform == "linux" \ + --hash=sha256:221c73e7482dd93eda44e65ce567c031c07e2f93f6fa0ecd3ba876a195023e83 \ + --hash=sha256:73060ce019ac064a057267c585bf1fd5a353734151f87472ff02b2c5c9984e78 \ + --hash=sha256:9e487468a22a1eaf1fbd1d2035936a905feb79c4ce5c2f67626764ee4f90227c + # via + # -r build/nvidia-requirements.txt # nvidia-cusolver-cu12 -nvidia-nccl-cu12==2.25.1 ; sys_platform == "linux" \ - --hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ - --hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 - # via -r build/test-requirements.txt -nvidia-nvjitlink-cu12==12.8.61 ; sys_platform == "linux" \ - --hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \ - --hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \ - --hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0 +nvidia-nccl-cu12==2.28.9 ; sys_platform == "linux" \ + --hash=sha256:485776daa8447da5da39681af455aa3b2c2586ddcf4af8772495e7c532c7e5ab \ + --hash=sha256:50a36e01c4a090b9f9c47d92cec54964de6b9fcb3362d0e19b8ffc6323c21b60 + # via -r build/nvidia-requirements.txt +nvidia-nccl-cu13==2.28.9 ; sys_platform == "linux" \ + --hash=sha256:01c873ba1626b54caa12272ed228dc5b2781545e0ae8ba3f432a8ef1c6d78643 \ + --hash=sha256:e4553a30f34195f3fa1da02a6da3d6337d28f2003943aa0a3d247bbc25fefc42 + # via -r build/nvidia-requirements.txt +nvidia-nvjitlink==13.0.88 ; sys_platform == "linux" \ + --hash=sha256:13a74f429e23b921c1109976abefacc69835f2f433ebd323d3946e11d804e47b \ + --hash=sha256:634e96e3da9ef845ae744097a1f289238ecf946ce0b82e93cdce14b9782e682f \ + --hash=sha256:e931536ccc7d467a98ba1d8b89ff7fa7f1fa3b13f2b0069118cd7f47bff07d0c # via - # -r build/test-requirements.txt + # -r build/nvidia-requirements.txt + # nvidia-cufft + # nvidia-cusolver + # nvidia-cusparse +nvidia-nvjitlink-cu12==12.9.86 ; sys_platform == "linux" \ + --hash=sha256:994a05ef08ef4b0b299829cde613a424382aff7efb08a7172c1fa616cc3af2ca \ + --hash=sha256:cc6fcec260ca843c10e34c936921a1c426b351753587fdd638e8cff7b16bb9db \ + --hash=sha256:e3f1171dbdc83c5932a45f0f4c99180a70de9bd2718c1ab77d14104f6d7147f9 + # via + # -r build/nvidia-requirements.txt # nvidia-cufft-cu12 # nvidia-cusolver-cu12 # nvidia-cusparse-cu12 +nvidia-nvshmem-cu12==3.4.5 ; sys_platform == "linux" \ + --hash=sha256:042f2500f24c021db8a06c5eec2539027d57460e1c1a762055a6554f72c369bd \ + --hash=sha256:0b48363fc6964dede448029434c6abed6c5e37f823cb43c3bcde7ecfc0457e15 + # via -r build/nvidia-requirements.txt +nvidia-nvshmem-cu13==3.4.5 ; sys_platform == "linux" \ + --hash=sha256:290f0a2ee94c9f3687a02502f3b9299a9f9fe826e6d0287ee18482e78d495b80 \ + --hash=sha256:6dc2a197f38e5d0376ad52cd1a2a3617d3cdc150fd5966f4aee9bcebb1d68fe9 + # via -r build/nvidia-requirements.txt +nvidia-nvvm==13.0.88 ; sys_platform == "linux" \ + --hash=sha256:2ef0db7849e476d3b2fc3c09b27bdd79bd7ea8ce58cd9c86553d64ea40844ba0 \ + --hash=sha256:c4376a291d72d22a315d9d2f69bdae8f8cd83a627f75bad395cee49a0fe65dc1 \ + --hash=sha256:c5f41ffeb6466944a026dfa5317d7d85355c119bbec279205d22f1869d1054e0 + # via + # -r build/nvidia-requirements.txt + # nvidia-cuda-nvcc opt-einsum==3.4.0 \ --hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \ --hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac - # via - # -r build/test-requirements.txt - # -r build/requirements.in -packaging==24.2 \ - --hash=sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759 \ - --hash=sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f + # via -r build/requirements.in +optype[numpy]==0.15.0 \ + --hash=sha256:457d6ca9e7da19967ec16d42bdf94e240b33b5d70a56fbbf5b427e5ea39cf41e \ + --hash=sha256:caba40ece9ea39b499fa76c036a82e0d452a432dd4dd3e8e0d30892be2e8c76c + # via scipy-stubs +packaging==25.0 \ + --hash=sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484 \ + --hash=sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f # via # auditwheel # build # matplotlib # pytest -pillow==11.1.0 \ - --hash=sha256:015c6e863faa4779251436db398ae75051469f7c903b043a48f078e437656f83 \ - --hash=sha256:0a2f91f8a8b367e7a57c6e91cd25af510168091fb89ec5146003e424e1558a96 \ - --hash=sha256:11633d58b6ee5733bde153a8dafd25e505ea3d32e261accd388827ee987baf65 \ - --hash=sha256:2062ffb1d36544d42fcaa277b069c88b01bb7298f4efa06731a7fd6cc290b81a \ - --hash=sha256:31eba6bbdd27dde97b0174ddf0297d7a9c3a507a8a1480e1e60ef914fe23d352 \ - --hash=sha256:3362c6ca227e65c54bf71a5f88b3d4565ff1bcbc63ae72c34b07bbb1cc59a43f \ - --hash=sha256:368da70808b36d73b4b390a8ffac11069f8a5c85f29eff1f1b01bcf3ef5b2a20 \ - --hash=sha256:36ba10b9cb413e7c7dfa3e189aba252deee0602c86c309799da5a74009ac7a1c \ - --hash=sha256:3764d53e09cdedd91bee65c2527815d315c6b90d7b8b79759cc48d7bf5d4f114 \ - --hash=sha256:3a5fe20a7b66e8135d7fd617b13272626a28278d0e578c98720d9ba4b2439d49 \ - --hash=sha256:3cdcdb0b896e981678eee140d882b70092dac83ac1cdf6b3a60e2216a73f2b91 \ - --hash=sha256:4637b88343166249fe8aa94e7c4a62a180c4b3898283bb5d3d2fd5fe10d8e4e0 \ - --hash=sha256:4db853948ce4e718f2fc775b75c37ba2efb6aaea41a1a5fc57f0af59eee774b2 \ - --hash=sha256:4dd43a78897793f60766563969442020e90eb7847463eca901e41ba186a7d4a5 \ - --hash=sha256:54251ef02a2309b5eec99d151ebf5c9904b77976c8abdcbce7891ed22df53884 \ - --hash=sha256:54ce1c9a16a9561b6d6d8cb30089ab1e5eb66918cb47d457bd996ef34182922e \ - --hash=sha256:593c5fd6be85da83656b93ffcccc2312d2d149d251e98588b14fbc288fd8909c \ - --hash=sha256:5bb94705aea800051a743aa4874bb1397d4695fb0583ba5e425ee0328757f196 \ - --hash=sha256:67cd427c68926108778a9005f2a04adbd5e67c442ed21d95389fe1d595458756 \ - --hash=sha256:70ca5ef3b3b1c4a0812b5c63c57c23b63e53bc38e758b37a951e5bc466449861 \ - --hash=sha256:73ddde795ee9b06257dac5ad42fcb07f3b9b813f8c1f7f870f402f4dc54b5269 \ - --hash=sha256:758e9d4ef15d3560214cddbc97b8ef3ef86ce04d62ddac17ad39ba87e89bd3b1 \ - --hash=sha256:7d33d2fae0e8b170b6a6c57400e077412240f6f5bb2a342cf1ee512a787942bb \ - --hash=sha256:7fdadc077553621911f27ce206ffcbec7d3f8d7b50e0da39f10997e8e2bb7f6a \ - --hash=sha256:8000376f139d4d38d6851eb149b321a52bb8893a88dae8ee7d95840431977081 \ - --hash=sha256:837060a8599b8f5d402e97197d4924f05a2e0d68756998345c829c33186217b1 \ - --hash=sha256:89dbdb3e6e9594d512780a5a1c42801879628b38e3efc7038094430844e271d8 \ - --hash=sha256:8c730dc3a83e5ac137fbc92dfcfe1511ce3b2b5d7578315b63dbbb76f7f51d90 \ - --hash=sha256:8e275ee4cb11c262bd108ab2081f750db2a1c0b8c12c1897f27b160c8bd57bbc \ - --hash=sha256:9044b5e4f7083f209c4e35aa5dd54b1dd5b112b108648f5c902ad586d4f945c5 \ - --hash=sha256:93a18841d09bcdd774dcdc308e4537e1f867b3dec059c131fde0327899734aa1 \ - --hash=sha256:9409c080586d1f683df3f184f20e36fb647f2e0bc3988094d4fd8c9f4eb1b3b3 \ - --hash=sha256:96f82000e12f23e4f29346e42702b6ed9a2f2fea34a740dd5ffffcc8c539eb35 \ - --hash=sha256:9aa9aeddeed452b2f616ff5507459e7bab436916ccb10961c4a382cd3e03f47f \ - --hash=sha256:9ee85f0696a17dd28fbcfceb59f9510aa71934b483d1f5601d1030c3c8304f3c \ - --hash=sha256:a07dba04c5e22824816b2615ad7a7484432d7f540e6fa86af60d2de57b0fcee2 \ - --hash=sha256:a3cd561ded2cf2bbae44d4605837221b987c216cff94f49dfeed63488bb228d2 \ - --hash=sha256:a697cd8ba0383bba3d2d3ada02b34ed268cb548b369943cd349007730c92bddf \ - --hash=sha256:a76da0a31da6fcae4210aa94fd779c65c75786bc9af06289cd1c184451ef7a65 \ - --hash=sha256:a85b653980faad27e88b141348707ceeef8a1186f75ecc600c395dcac19f385b \ - --hash=sha256:a8d65b38173085f24bc07f8b6c505cbb7418009fa1a1fcb111b1f4961814a442 \ - --hash=sha256:aa8dd43daa836b9a8128dbe7d923423e5ad86f50a7a14dc688194b7be5c0dea2 \ - --hash=sha256:ab8a209b8485d3db694fa97a896d96dd6533d63c22829043fd9de627060beade \ - --hash=sha256:abc56501c3fd148d60659aae0af6ddc149660469082859fa7b066a298bde9482 \ - --hash=sha256:ad5db5781c774ab9a9b2c4302bbf0c1014960a0a7be63278d13ae6fdf88126fe \ - --hash=sha256:ae98e14432d458fc3de11a77ccb3ae65ddce70f730e7c76140653048c71bfcbc \ - --hash=sha256:b20be51b37a75cc54c2c55def3fa2c65bb94ba859dde241cd0a4fd302de5ae0a \ - --hash=sha256:b523466b1a31d0dcef7c5be1f20b942919b62fd6e9a9be199d035509cbefc0ec \ - --hash=sha256:b5d658fbd9f0d6eea113aea286b21d3cd4d3fd978157cbf2447a6035916506d3 \ - --hash=sha256:b6123aa4a59d75f06e9dd3dac5bf8bc9aa383121bb3dd9a7a612e05eabc9961a \ - --hash=sha256:bd165131fd51697e22421d0e467997ad31621b74bfc0b75956608cb2906dda07 \ - --hash=sha256:bf902d7413c82a1bfa08b06a070876132a5ae6b2388e2712aab3a7cbc02205c6 \ - --hash=sha256:c12fc111ef090845de2bb15009372175d76ac99969bdf31e2ce9b42e4b8cd88f \ - --hash=sha256:c1eec9d950b6fe688edee07138993e54ee4ae634c51443cfb7c1e7613322718e \ - --hash=sha256:c640e5a06869c75994624551f45e5506e4256562ead981cce820d5ab39ae2192 \ - --hash=sha256:cc1331b6d5a6e144aeb5e626f4375f5b7ae9934ba620c0ac6b3e43d5e683a0f0 \ - --hash=sha256:cfd5cd998c2e36a862d0e27b2df63237e67273f2fc78f47445b14e73a810e7e6 \ - --hash=sha256:d3d8da4a631471dfaf94c10c85f5277b1f8e42ac42bade1ac67da4b4a7359b73 \ - --hash=sha256:d44ff19eea13ae4acdaaab0179fa68c0c6f2f45d66a4d8ec1eda7d6cecbcc15f \ - --hash=sha256:dd0052e9db3474df30433f83a71b9b23bd9e4ef1de13d92df21a52c0303b8ab6 \ - --hash=sha256:dd0e081319328928531df7a0e63621caf67652c8464303fd102141b785ef9547 \ - --hash=sha256:dda60aa465b861324e65a78c9f5cf0f4bc713e4309f83bc387be158b077963d9 \ - --hash=sha256:e06695e0326d05b06833b40b7ef477e475d0b1ba3a6d27da1bb48c23209bf457 \ - --hash=sha256:e1abe69aca89514737465752b4bcaf8016de61b3be1397a8fc260ba33321b3a8 \ - --hash=sha256:e267b0ed063341f3e60acd25c05200df4193e15a4a5807075cd71225a2386e26 \ - --hash=sha256:e5449ca63da169a2e6068dd0e2fcc8d91f9558aba89ff6d02121ca8ab11e79e5 \ - --hash=sha256:e63e4e5081de46517099dc30abe418122f54531a6ae2ebc8680bcd7096860eab \ - --hash=sha256:f189805c8be5ca5add39e6f899e6ce2ed824e65fb45f3c28cb2841911da19070 \ - --hash=sha256:f7955ecf5609dee9442cbface754f2c6e541d9e6eda87fad7f7a989b0bdb9d71 \ - --hash=sha256:f86d3a7a9af5d826744fabf4afd15b9dfef44fe69a98541f666f66fbb8d3fef9 \ - --hash=sha256:fbd43429d0d7ed6533b25fc993861b8fd512c42d04514a0dd6337fb3ccf22761 + # wheel +pillow==12.0.0 \ + --hash=sha256:0869154a2d0546545cde61d1789a6524319fc1897d9ee31218eae7a60ccc5643 \ + --hash=sha256:09f2d0abef9e4e2f349305a4f8cc784a8a6c2f58a8c4892eea13b10a943bd26e \ + --hash=sha256:0b817e7035ea7f6b942c13aa03bb554fc44fea70838ea21f8eb31c638326584e \ + --hash=sha256:0fd00cac9c03256c8b2ff58f162ebcd2587ad3e1f2e397eab718c47e24d231cc \ + --hash=sha256:110486b79f2d112cf6add83b28b627e369219388f64ef2f960fef9ebaf54c642 \ + --hash=sha256:1979f4566bb96c1e50a62d9831e2ea2d1211761e5662afc545fa766f996632f6 \ + --hash=sha256:1ac11e8ea4f611c3c0147424eae514028b5e9077dd99ab91e1bd7bc33ff145e1 \ + --hash=sha256:1b1b133e6e16105f524a8dec491e0586d072948ce15c9b914e41cdadd209052b \ + --hash=sha256:1ee80a59f6ce048ae13cda1abf7fbd2a34ab9ee7d401c46be3ca685d1999a399 \ + --hash=sha256:21f241bdd5080a15bc86d3466a9f6074a9c2c2b314100dd896ac81ee6db2f1ba \ + --hash=sha256:266cd5f2b63ff316d5a1bba46268e603c9caf5606d44f38c2873c380950576ad \ + --hash=sha256:26d9f7d2b604cd23aba3e9faf795787456ac25634d82cd060556998e39c6fa47 \ + --hash=sha256:27f95b12453d165099c84f8a8bfdfd46b9e4bda9e0e4b65f0635430027f55739 \ + --hash=sha256:2c54c1a783d6d60595d3514f0efe9b37c8808746a66920315bfd34a938d7994b \ + --hash=sha256:2fa5f0b6716fc88f11380b88b31fe591a06c6315e955c096c35715788b339e3f \ + --hash=sha256:32ed80ea8a90ee3e6fa08c21e2e091bba6eda8eccc83dbc34c95169507a91f10 \ + --hash=sha256:3830c769decf88f1289680a59d4f4c46c72573446352e2befec9a8512104fa52 \ + --hash=sha256:38df9b4bfd3db902c9c2bd369bcacaf9d935b2fff73709429d95cc41554f7b3d \ + --hash=sha256:3adfb466bbc544b926d50fe8f4a4e6abd8c6bffd28a26177594e6e9b2b76572b \ + --hash=sha256:3e42edad50b6909089750e65c91aa09aaf1e0a71310d383f11321b27c224ed8a \ + --hash=sha256:4078242472387600b2ce8d93ade8899c12bf33fa89e55ec89fe126e9d6d5d9e9 \ + --hash=sha256:455247ac8a4cfb7b9bc45b7e432d10421aea9fc2e74d285ba4072688a74c2e9d \ + --hash=sha256:4cc6b3b2efff105c6a1656cfe59da4fdde2cda9af1c5e0b58529b24525d0a098 \ + --hash=sha256:4cf7fed4b4580601c4345ceb5d4cbf5a980d030fd5ad07c4d2ec589f95f09905 \ + --hash=sha256:5193fde9a5f23c331ea26d0cf171fbf67e3f247585f50c08b3e205c7aeb4589b \ + --hash=sha256:5269cc1caeedb67e6f7269a42014f381f45e2e7cd42d834ede3c703a1d915fe3 \ + --hash=sha256:53561a4ddc36facb432fae7a9d8afbfaf94795414f5cdc5fc52f28c1dca90371 \ + --hash=sha256:55f818bd74fe2f11d4d7cbc65880a843c4075e0ac7226bc1a23261dbea531953 \ + --hash=sha256:58eea5ebe51504057dd95c5b77d21700b77615ab0243d8152793dc00eb4faf01 \ + --hash=sha256:5d5c411a8eaa2299322b647cd932586b1427367fd3184ffbb8f7a219ea2041ca \ + --hash=sha256:6846bd2d116ff42cba6b646edf5bf61d37e5cbd256425fa089fee4ff5c07a99e \ + --hash=sha256:6ace95230bfb7cd79ef66caa064bbe2f2a1e63d93471c3a2e1f1348d9f22d6b7 \ + --hash=sha256:6e51b71417049ad6ab14c49608b4a24d8fb3fe605e5dfabfe523b58064dc3d27 \ + --hash=sha256:71db6b4c1653045dacc1585c1b0d184004f0d7e694c7b34ac165ca70c0838082 \ + --hash=sha256:7438839e9e053ef79f7112c881cef684013855016f928b168b81ed5835f3e75e \ + --hash=sha256:759de84a33be3b178a64c8ba28ad5c135900359e85fb662bc6e403ad4407791d \ + --hash=sha256:792a2c0be4dcc18af9d4a2dfd8a11a17d5e25274a1062b0ec1c2d79c76f3e7f8 \ + --hash=sha256:7d87ef5795da03d742bf49439f9ca4d027cde49c82c5371ba52464aee266699a \ + --hash=sha256:7dfb439562f234f7d57b1ac6bc8fe7f838a4bd49c79230e0f6a1da93e82f1fad \ + --hash=sha256:7fa22993bac7b77b78cae22bad1e2a987ddf0d9015c63358032f84a53f23cdc3 \ + --hash=sha256:805ebf596939e48dbb2e4922a1d3852cfc25c38160751ce02da93058b48d252a \ + --hash=sha256:82240051c6ca513c616f7f9da06e871f61bfd7805f566275841af15015b8f98d \ + --hash=sha256:87d4f8125c9988bfbed67af47dd7a953e2fc7b0cc1e7800ec6d2080d490bb353 \ + --hash=sha256:8d8ca2b210ada074d57fcee40c30446c9562e542fc46aedc19baf758a93532ee \ + --hash=sha256:8dc232e39d409036af549c86f24aed8273a40ffa459981146829a324e0848b4b \ + --hash=sha256:90387104ee8400a7b4598253b4c406f8958f59fcf983a6cea2b50d59f7d63d0b \ + --hash=sha256:905b0365b210c73afb0ebe9101a32572152dfd1c144c7e28968a331b9217b94a \ + --hash=sha256:99353a06902c2e43b43e8ff74ee65a7d90307d82370604746738a1e0661ccca7 \ + --hash=sha256:99a7f72fb6249302aa62245680754862a44179b545ded638cf1fef59befb57ef \ + --hash=sha256:9f0b04c6b8584c2c193babcccc908b38ed29524b29dd464bc8801bf10d746a3a \ + --hash=sha256:9fe611163f6303d1619bbcb653540a4d60f9e55e622d60a3108be0d5b441017a \ + --hash=sha256:a3475b96f5908b3b16c47533daaa87380c491357d197564e0ba34ae75c0f3257 \ + --hash=sha256:a6597ff2b61d121172f5844b53f21467f7082f5fb385a9a29c01414463f93b07 \ + --hash=sha256:a7921c5a6d31b3d756ec980f2f47c0cfdbce0fc48c22a39347a895f41f4a6ea4 \ + --hash=sha256:aa5129de4e174daccbc59d0a3b6d20eaf24417d59851c07ebb37aeb02947987c \ + --hash=sha256:aeaefa96c768fc66818730b952a862235d68825c178f1b3ffd4efd7ad2edcb7c \ + --hash=sha256:afbefa430092f71a9593a99ab6a4e7538bc9eabbf7bf94f91510d3503943edc4 \ + --hash=sha256:aff9e4d82d082ff9513bdd6acd4f5bd359f5b2c870907d2b0a9c5e10d40c88fe \ + --hash=sha256:b22bd8c974942477156be55a768f7aa37c46904c175be4e158b6a86e3a6b7ca8 \ + --hash=sha256:b290fd8aa38422444d4b50d579de197557f182ef1068b75f5aa8558638b8d0a5 \ + --hash=sha256:b2e4b27a6e15b04832fe9bf292b94b5ca156016bbc1ea9c2c20098a0320d6cf6 \ + --hash=sha256:b583dc9070312190192631373c6c8ed277254aa6e6084b74bdd0a6d3b221608e \ + --hash=sha256:b87843e225e74576437fd5b6a4c2205d422754f84a06942cfaf1dc32243e45a8 \ + --hash=sha256:bc91a56697869546d1b8f0a3ff35224557ae7f881050e99f615e0119bf934b4e \ + --hash=sha256:bd87e140e45399c818fac4247880b9ce719e4783d767e030a883a970be632275 \ + --hash=sha256:bde737cff1a975b70652b62d626f7785e0480918dece11e8fef3c0cf057351c3 \ + --hash=sha256:bdee52571a343d721fb2eb3b090a82d959ff37fc631e3f70422e0c2e029f3e76 \ + --hash=sha256:bee2a6db3a7242ea309aa7ee8e2780726fed67ff4e5b40169f2c940e7eb09227 \ + --hash=sha256:beeae3f27f62308f1ddbcfb0690bf44b10732f2ef43758f169d5e9303165d3f9 \ + --hash=sha256:c50f36a62a22d350c96e49ad02d0da41dbd17ddc2e29750dbdba4323f85eb4a5 \ + --hash=sha256:c607c90ba67533e1b2355b821fef6764d1dd2cbe26b8c1005ae84f7aea25ff79 \ + --hash=sha256:c7b2a63fd6d5246349f3d3f37b14430d73ee7e8173154461785e43036ffa96ca \ + --hash=sha256:c828a1ae702fc712978bda0320ba1b9893d99be0badf2647f693cc01cf0f04fa \ + --hash=sha256:c85de1136429c524e55cfa4e033b4a7940ac5c8ee4d9401cc2d1bf48154bbc7b \ + --hash=sha256:c98fa880d695de164b4135a52fd2e9cd7b7c90a9d8ac5e9e443a24a95ef9248e \ + --hash=sha256:cae81479f77420d217def5f54b5b9d279804d17e982e0f2fa19b1d1e14ab5197 \ + --hash=sha256:d034140032870024e6b9892c692fe2968493790dd57208b2c37e3fb35f6df3ab \ + --hash=sha256:d120c38a42c234dc9a8c5de7ceaaf899cf33561956acb4941653f8bdc657aa79 \ + --hash=sha256:d4827615da15cd59784ce39d3388275ec093ae3ee8d7f0c089b76fa87af756c2 \ + --hash=sha256:d49e2314c373f4c2b39446fb1a45ed333c850e09d0c59ac79b72eb3b95397363 \ + --hash=sha256:d52610d51e265a51518692045e372a4c363056130d922a7351429ac9f27e70b0 \ + --hash=sha256:d64317d2587c70324b79861babb9c09f71fbb780bad212018874b2c013d8600e \ + --hash=sha256:d77153e14b709fd8b8af6f66a3afbb9ed6e9fc5ccf0b6b7e1ced7b036a228782 \ + --hash=sha256:d7e091d464ac59d2c7ad8e7e08105eaf9dafbc3883fd7265ffccc2baad6ac925 \ + --hash=sha256:dd333073e0cacdc3089525c7df7d39b211bcdf31fc2824e49d01c6b6187b07d0 \ + --hash=sha256:e5d8efac84c9afcb40914ab49ba063d94f5dbdf5066db4482c66a992f47a3a3b \ + --hash=sha256:f135c702ac42262573fe9714dfe99c944b4ba307af5eb507abef1667e2cbbced \ + --hash=sha256:f13711b1a5ba512d647a0e4ba79280d3a9a045aaf7e0cc6fbe96b91d4cdf6b0c \ + --hash=sha256:f4f1231b7dec408e8670264ce63e9c71409d9583dd21d32c163e25213ee2a344 \ + --hash=sha256:fa3ed2a29a9e9d2d488b4da81dcb54720ac3104a20bf0bd273f1e4648aff5af9 \ + --hash=sha256:fb3096c30df99fd01c7bf8e544f392103d0795b9f98ba71a8054bcbf56b255f1 # via # -r build/test-requirements.txt # matplotlib -pluggy==1.5.0 \ - --hash=sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1 \ - --hash=sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669 +pluggy==1.6.0 \ + --hash=sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3 \ + --hash=sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746 # via pytest portpicker==1.6.0 \ --hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \ --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa # via -r build/test-requirements.txt -psutil==6.1.1 \ - --hash=sha256:018aeae2af92d943fdf1da6b58665124897cfc94faa2ca92098838f83e1b1bca \ - --hash=sha256:0bdd4eab935276290ad3cb718e9809412895ca6b5b334f5a9111ee6d9aff9377 \ - --hash=sha256:1924e659d6c19c647e763e78670a05dbb7feaf44a0e9c94bf9e14dfc6ba50468 \ - --hash=sha256:33431e84fee02bc84ea36d9e2c4a6d395d479c9dd9bba2376c1f6ee8f3a4e0b3 \ - --hash=sha256:384636b1a64b47814437d1173be1427a7c83681b17a450bfc309a1953e329603 \ - --hash=sha256:6d4281f5bbca041e2292be3380ec56a9413b790579b8e593b1784499d0005dac \ - --hash=sha256:8be07491f6ebe1a693f17d4f11e69d0dc1811fa082736500f649f79df7735303 \ - --hash=sha256:8df0178ba8a9e5bc84fed9cfa61d54601b371fbec5c8eebad27575f1e105c0d4 \ - --hash=sha256:97f7cb9921fbec4904f522d972f0c0e1f4fabbdd4e0287813b21215074a0f160 \ - --hash=sha256:9ccc4316f24409159897799b83004cb1e24f9819b0dcf9c0b68bdcb6cefee6a8 \ - --hash=sha256:b6e06c20c05fe95a3d7302d74e7097756d4ba1247975ad6905441ae1b5b66003 \ - --hash=sha256:c777eb75bb33c47377c9af68f30e9f11bc78e0f07fbf907be4a5d70b2fe5f030 \ - --hash=sha256:ca9609c77ea3b8481ab005da74ed894035936223422dc591d6772b147421f777 \ - --hash=sha256:cf8496728c18f2d0b45198f06895be52f36611711746b7f30c464b422b50e2f5 \ - --hash=sha256:eaa912e0b11848c4d9279a93d7e2783df352b082f40111e078388701fd479e53 \ - --hash=sha256:f35cfccb065fff93529d2afb4a2e89e363fe63ca1e4a5da22b603a85833c2649 \ - --hash=sha256:fc0ed7fe2231a444fc219b9c42d0376e0a9a1a72f16c5cfa0f68d19f1a0663e8 +psutil==7.1.3 \ + --hash=sha256:0005da714eee687b4b8decd3d6cc7c6db36215c9e74e5ad2264b90c3df7d92dc \ + --hash=sha256:1068c303be3a72f8e18e412c5b2a8f6d31750fb152f9cb106b54090296c9d251 \ + --hash=sha256:18349c5c24b06ac5612c0428ec2a0331c26443d259e2a0144a9b24b4395b58fa \ + --hash=sha256:19644c85dcb987e35eeeaefdc3915d059dac7bd1167cdcdbf27e0ce2df0c08c0 \ + --hash=sha256:2bdbcd0e58ca14996a42adf3621a6244f1bb2e2e528886959c72cf1e326677ab \ + --hash=sha256:31d77fcedb7529f27bb3a0472bea9334349f9a04160e8e6e5020f22c59893264 \ + --hash=sha256:3792983e23b69843aea49c8f5b8f115572c5ab64c153bada5270086a2123c7e7 \ + --hash=sha256:3bb428f9f05c1225a558f53e30ccbad9930b11c3fc206836242de1091d3e7dd3 \ + --hash=sha256:56d974e02ca2c8eb4812c3f76c30e28836fffc311d55d979f1465c1feeb2b68b \ + --hash=sha256:6c86281738d77335af7aec228328e944b30930899ea760ecf33a4dba66be5e74 \ + --hash=sha256:8f33a3702e167783a9213db10ad29650ebf383946e91bc77f28a5eb083496bc9 \ + --hash=sha256:95ef04cf2e5ba0ab9eaafc4a11eaae91b44f4ef5541acd2ee91d9108d00d59a7 \ + --hash=sha256:ad81425efc5e75da3f39b3e636293360ad8d0b49bed7df824c79764fb4ba9b8b \ + --hash=sha256:b403da1df4d6d43973dc004d19cee3b848e998ae3154cc8097d139b77156c353 \ + --hash=sha256:bc31fa00f1fbc3c3802141eede66f3a2d51d89716a194bf2cd6fc68310a19880 \ + --hash=sha256:bd0d69cee829226a761e92f28140bec9a5ee9d5b4fb4b0cc589068dbfff559b1 \ + --hash=sha256:c525ffa774fe4496282fb0b1187725793de3e7c6b29e41562733cae9ada151ee \ + --hash=sha256:f39c2c19fe824b47484b96f9692932248a54c43799a84282cfe58d05a6449efd \ + --hash=sha256:fac9cd332c67f4422504297889da5ab7e05fd11e3c4392140f7370f4208ded1f # via portpicker -pyelftools==0.31 \ - --hash=sha256:c774416b10310156879443b81187d182d8d9ee499660380e645918b50bc88f99 \ - --hash=sha256:f52de7b3c7e8c64c8abc04a79a1cf37ac5fb0b8a49809827130b858944840607 +pyelftools==0.32 \ + --hash=sha256:013df952a006db5e138b1edf6d8a68ecc50630adbd0d83a2d41e7f846163d738 \ + --hash=sha256:6de90ee7b8263e740c8715a925382d4099b354f29ac48ea40d840cf7aa14ace5 # via auditwheel -pygments==2.19.1 \ - --hash=sha256:61c16d2a8576dc0649d9f39e089b5f02bcd27fba10d8fb4dcc28173f7a45151f \ - --hash=sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c - # via rich -pyparsing==3.2.1 \ - --hash=sha256:506ff4f4386c4cec0590ec19e6302d3aedb992fdc02c761e90416f158dacf8e1 \ - --hash=sha256:61980854fd66de3a90028d679a954d5f2623e83144b5afe5ee86f43d762e5f0a +pygments==2.19.2 \ + --hash=sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887 \ + --hash=sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b + # via + # pytest + # rich +pyparsing==3.2.5 \ + --hash=sha256:2df8d5b7b2802ef88e8d016a2eb9c7aeaa923529cd251ed0fe4608275d4105b6 \ + --hash=sha256:e38a4f02064cf41fe6593d328d0512495ad1f3d8a91c4f73fc401b3079a59a5e # via matplotlib pyproject-hooks==1.2.0 \ --hash=sha256:1e859bd5c40fae9448642dd871adf459e5e2084186e8d2c2a79a824c970da1f8 \ --hash=sha256:9e5c6bfa8dcc30091c74b0cf803c81fdd29d94f01992a7707bc97babb1141913 # via build -pytest==8.3.4 \ - --hash=sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6 \ - --hash=sha256:965370d062bce11e73868e0335abac31b4d3de0e82f4007408d242b4f8610761 - # via pytest-xdist -pytest-xdist==3.6.1 \ - --hash=sha256:9ed4adfb68a016610848639bb7e02c9352d5d9f03d04809919e2dafc3be4cca7 \ - --hash=sha256:ead156a4db231eec769737f57668ef58a2084a34b2e55c4a8fa20d861107300d +pytest==8.4.2 \ + --hash=sha256:86c0d0b93306b961d58d62a4db4879f27fe25513d4b969df351abdddb3c30e01 \ + --hash=sha256:872f880de3fc3a5bdc88a11b39c9710c3497a547cfa9320bc3c5e62fbf272e79 + # via + # -r build/test-requirements.txt + # pytest-xdist +pytest-xdist==3.8.0 \ + --hash=sha256:202ca578cfeb7370784a8c33d6d05bc6e13b4f25b5053c30a152269fd10f0b88 \ + --hash=sha256:7e578125ec9bc6050861aa93f2d59f1d8d085595d6551c2c90b6f4fad8d3a9f1 # via -r build/test-requirements.txt python-dateutil==2.9.0.post0 \ --hash=sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3 \ --hash=sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427 # via matplotlib -rich==13.9.4 \ - --hash=sha256:439594978a49a09530cff7ebc4b5c7103ef57baf48d5ea3184f21d9a2befa098 \ - --hash=sha256:6049d5e6ec054bf2779ab3358186963bac2ea89175919d699e378b99738c2a90 +rich==14.2.0 \ + --hash=sha256:73ff50c7c0c1c77c8243079283f4edb376f0f6442433aecb8ce7e6d0b92d1fe4 \ + --hash=sha256:76bc51fe2e57d2b1be1f96c524b890b816e334ab4c1e45888799bfaab0021edd + # via -r build/test-requirements.txt +scipy==1.16.3 ; python_version >= "3.13" \ + --hash=sha256:0151a0749efeaaab78711c78422d413c583b8cdd2011a3c1d6c794938ee9fdb2 \ + --hash=sha256:01e87659402762f43bd2fee13370553a17ada367d42e7487800bf2916535aecb \ + --hash=sha256:03192a35e661470197556de24e7cb1330d84b35b94ead65c46ad6f16f6b28f2a \ + --hash=sha256:0553371015692a898e1aa858fed67a3576c34edefa6b7ebdb4e9dde49ce5c203 \ + --hash=sha256:062246acacbe9f8210de8e751b16fc37458213f124bef161a5a02c7a39284304 \ + --hash=sha256:0c3b4dd3d9b08dbce0f3440032c52e9e2ab9f96ade2d3943313dfe51a7056959 \ + --hash=sha256:0c623a54f7b79dd88ef56da19bc2873afec9673a48f3b85b18e4d402bdd29a5a \ + --hash=sha256:16b8bc35a4cc24db80a0ec836a9286d0e31b2503cb2fd7ff7fb0e0374a97081d \ + --hash=sha256:1fb2472e72e24d1530debe6ae078db70fb1605350c88a3d14bc401d6306dbffe \ + --hash=sha256:21d9d6b197227a12dcbf9633320a4e34c6b0e51c57268df255a0942983bac562 \ + --hash=sha256:2a207a6ce9c24f1951241f4693ede2d393f59c07abc159b2cb2be980820e01fb \ + --hash=sha256:2b71d93c8a9936046866acebc915e2af2e292b883ed6e2cbe5c34beb094b82d9 \ + --hash=sha256:2d1ae2cf0c350e7705168ff2429962a89ad90c2d49d1dd300686d8b2a5af22fc \ + --hash=sha256:3a4c460301fb2cffb7f88528f30b3127742cff583603aa7dc964a52c463b385d \ + --hash=sha256:3d4a07a8e785d80289dfe66b7c27d8634a773020742ec7187b85ccc4b0e7b686 \ + --hash=sha256:40be6cf99e68b6c4321e9f8782e7d5ff8265af28ef2cd56e9c9b2638fa08ad97 \ + --hash=sha256:4aff59800a3b7f786b70bfd6ab551001cb553244988d7d6b8299cb1ea653b353 \ + --hash=sha256:50a3dbf286dbc7d84f176f9a1574c705f277cb6565069f88f60db9eafdbe3ee2 \ + --hash=sha256:532fb5ad6a87e9e9cd9c959b106b73145a03f04c7d57ea3e6f6bb60b86ab0876 \ + --hash=sha256:53c3844d527213631e886621df5695d35e4f6a75f620dca412bcd292f6b87d78 \ + --hash=sha256:56edc65510d1331dae01ef9b658d428e33ed48b4f77b1d51caf479a0253f96dc \ + --hash=sha256:57d01cb6f85e34f0946b33caa66e892aae072b64b034183f3d87c4025802a119 \ + --hash=sha256:5803c5fadd29de0cf27fa08ccbfe7a9e5d741bf63e4ab1085437266f12460ff9 \ + --hash=sha256:6020470b9d00245926f2d5bb93b119ca0340f0d564eb6fbaad843eaebf9d690f \ + --hash=sha256:63d3cdacb8a824a295191a723ee5e4ea7768ca5ca5f2838532d9f2e2b3ce2135 \ + --hash=sha256:663b8d66a8748051c3ee9c96465fb417509315b99c71550fda2591d7dd634234 \ + --hash=sha256:72d1717fd3b5e6ec747327ce9bda32d5463f472c9dce9f54499e81fbd50245a1 \ + --hash=sha256:7dc1360c06535ea6116a2220f760ae572db9f661aba2d88074fe30ec2aa1ff88 \ + --hash=sha256:7f68154688c515cdb541a31ef8eb66d8cd1050605be9dcd74199cbd22ac739bc \ + --hash=sha256:81fc5827606858cf71446a5e98715ba0e11f0dbc83d71c7409d05486592a45d6 \ + --hash=sha256:875555ce62743e1d54f06cdf22c1e0bc47b91130ac40fe5d783b6dfa114beeb6 \ + --hash=sha256:8b3c820ddb80029fe9f43d61b81d8b488d3ef8ca010d15122b152db77dc94c22 \ + --hash=sha256:8be1ca9170fcb6223cc7c27f4305d680ded114a1567c0bd2bfcbf947d1b17511 \ + --hash=sha256:8d09d72dc92742988b0e7750bddb8060b0c7079606c0d24a8cc8e9c9c11f9079 \ + --hash=sha256:9452781bd879b14b6f055b26643703551320aa8d79ae064a71df55c00286a184 \ + --hash=sha256:96491a6a54e995f00a28a3c3badfff58fd093bf26cd5fb34a2188c8c756a3a2c \ + --hash=sha256:9b9c9c07b6d56a35777a1b4cc8966118fb16cfd8daf6743867d17d36cfad2d40 \ + --hash=sha256:a8a26c78ef223d3e30920ef759e25625a0ecdd0d60e5a8818b7513c3e5384cf2 \ + --hash=sha256:aadd23f98f9cb069b3bd64ddc900c4d277778242e961751f77a8cb5c4b946fb0 \ + --hash=sha256:b7180967113560cca57418a7bc719e30366b47959dd845a93206fbed693c867e \ + --hash=sha256:b7c5f1bda1354d6a19bc6af73a649f8285ca63ac6b52e64e658a5a11d4d69800 \ + --hash=sha256:b81c27fc41954319a943d43b20e07c40bdcd3ff7cf013f4fb86286faefe546c4 \ + --hash=sha256:bb61878c18a470021fb515a843dc7a76961a8daceaaaa8bad1332f1bf4b54657 \ + --hash=sha256:bea0a62734d20d67608660f69dcda23e7f90fb4ca20974ab80b6ed40df87a005 \ + --hash=sha256:c5192722cffe15f9329a3948c4b1db789fbb1f05c97899187dcf009b283aea70 \ + --hash=sha256:c97176013d404c7346bf57874eaac5187d969293bf40497140b0a2b2b7482e07 \ + --hash=sha256:cd13e354df9938598af2be05822c323e97132d5e6306b83a3b4ee6724c6e522e \ + --hash=sha256:d2ec56337675e61b312179a1ad124f5f570c00f920cc75e1000025451b88241c \ + --hash=sha256:d3837938ae715fc0fe3c39c0202de3a8853aff22ca66781ddc2ade7554b7e2cc \ + --hash=sha256:d9f48cafc7ce94cf9b15c6bffdc443a81a27bf7075cf2dcd5c8b40f85d10c4e7 \ + --hash=sha256:da7763f55885045036fabcebd80144b757d3db06ab0861415d1c3b7c69042146 \ + --hash=sha256:deb3841c925eeddb6afc1e4e4a45e418d19ec7b87c5df177695224078e8ec733 \ + --hash=sha256:e1d27cbcb4602680a49d787d90664fa4974063ac9d4134813332a8c53dbe667c \ + --hash=sha256:e5d42a9472e7579e473879a1990327830493a7047506d58d73fc429b84c1d49d \ + --hash=sha256:e7efa2681ea410b10dde31a52b18b0154d66f2485328830e45fdf183af5aefc6 \ + --hash=sha256:eab43fae33a0c39006a88096cd7b4f4ef545ea0447d250d5ac18202d40b6611d \ + --hash=sha256:f2622206f5559784fa5c4b53a950c3c7c1cf3e84ca1b9c4b6c03f062f289ca26 \ + --hash=sha256:f379b54b77a597aa7ee5e697df0d66903e41b9c85a6dd7946159e356319158e8 \ + --hash=sha256:f667a4542cc8917af1db06366d3f78a5c8e83badd56409f94d1eac8d8d9133fa \ + --hash=sha256:fb4b29f4cf8cc5a8d628bc8d8e26d12d7278cd1f219f22698a378c3d67db5e4b \ + --hash=sha256:ffa6eea95283b2b8079b821dc11f50a17d0571c92b43e2b5b12764dc5f9b285d + # via + # -r build/requirements.in + # jaxlib +scipy-stubs==1.16.3.3 \ + --hash=sha256:af47578875d5557567225a16ec1b9b38a48c4c4377d92396413ebd65406c44ee \ + --hash=sha256:f6316b36cd0fb272c994ae5b10c4a73c644a7e156ed8d32bcd9c35303d0e1b7e # via -r build/test-requirements.txt -scipy==1.15.0 \ - --hash=sha256:0e5b34f8894f9904cc578008d1a9467829c1817e9f9cb45e6d6eeb61d2ab7731 \ - --hash=sha256:0fcb16eb04d84670722ce8d93b05257df471704c913cb0ff9dc5a1c31d1e9422 \ - --hash=sha256:129f899ed275c0515d553b8d31696924e2ca87d1972421e46c376b9eb87de3d2 \ - --hash=sha256:161f80a98047c219c257bf5ce1777c574bde36b9d962a46b20d0d7e531f86863 \ - --hash=sha256:1b29e4fc02e155a5fd1165f1e6a73edfdd110470736b0f48bcbe48083f0eee37 \ - --hash=sha256:1e2448acd79c6374583581a1ded32ac71a00c2b9c62dfa87a40e1dd2520be111 \ - --hash=sha256:2240e1fd0782e62e1aacdc7234212ee271d810f67e9cd3b8d521003a82603ef8 \ - --hash=sha256:300742e2cc94e36a2880ebe464a1c8b4352a7b0f3e36ec3d2ac006cdbe0219ac \ - --hash=sha256:327163ad73e54541a675240708244644294cb0a65cca420c9c79baeb9648e479 \ - --hash=sha256:351899dd2a801edd3691622172bc8ea01064b1cada794f8641b89a7dc5418db6 \ - --hash=sha256:35c68f7044b4e7ad73a3e68e513dda946989e523df9b062bd3cf401a1a882192 \ - --hash=sha256:36be480e512d38db67f377add5b759fb117edd987f4791cdf58e59b26962bee4 \ - --hash=sha256:37ce9394cdcd7c5f437583fc6ef91bd290014993900643fdfc7af9b052d1613b \ - --hash=sha256:46e91b5b16909ff79224b56e19cbad65ca500b3afda69225820aa3afbf9ec020 \ - --hash=sha256:4e08c6a36f46abaedf765dd2dfcd3698fa4bd7e311a9abb2d80e33d9b2d72c34 \ - --hash=sha256:52475011be29dfcbecc3dfe3060e471ac5155d72e9233e8d5616b84e2b542054 \ - --hash=sha256:5972e3f96f7dda4fd3bb85906a17338e65eaddfe47f750e240f22b331c08858e \ - --hash=sha256:5abbdc6ede5c5fed7910cf406a948e2c0869231c0db091593a6b2fa78be77e5d \ - --hash=sha256:5beb0a2200372b7416ec73fdae94fe81a6e85e44eb49c35a11ac356d2b8eccc6 \ - --hash=sha256:61513b989ee8d5218fbeb178b2d51534ecaddba050db949ae99eeb3d12f6825d \ - --hash=sha256:6d26f17c64abd6c6c2dfb39920f61518cc9e213d034b45b2380e32ba78fde4c0 \ - --hash=sha256:6f376d7c767731477bac25a85d0118efdc94a572c6b60decb1ee48bf2391a73b \ - --hash=sha256:767e8cf6562931f8312f4faa7ddea412cb783d8df49e62c44d00d89f41f9bbe8 \ - --hash=sha256:82bff2eb01ccf7cea8b6ee5274c2dbeadfdac97919da308ee6d8e5bcbe846443 \ - --hash=sha256:952d2e9eaa787f0a9e95b6e85da3654791b57a156c3e6609e65cc5176ccfe6f2 \ - --hash=sha256:9c8254fe21dd2c6c8f7757035ec0c31daecf3bb3cffd93bc1ca661b731d28136 \ - --hash=sha256:aeac60d3562a7bf2f35549bdfdb6b1751c50590f55ce7322b4b2fc821dc27fca \ - --hash=sha256:b1432102254b6dc7766d081fa92df87832ac25ff0b3d3a940f37276e63eb74ff \ - --hash=sha256:bdca4c7bb8dc41307e5f39e9e5d19c707d8e20a29845e7533b3bb20a9d4ccba0 \ - --hash=sha256:c9624eeae79b18cab1a31944b5ef87aa14b125d6ab69b71db22f0dbd962caf1e \ - --hash=sha256:ccb6248a9987193fe74363a2d73b93bc2c546e0728bd786050b7aef6e17db03c \ - --hash=sha256:cd9d9198a7fd9a77f0eb5105ea9734df26f41faeb2a88a0e62e5245506f7b6df \ - --hash=sha256:d13bbc0658c11f3d19df4138336e4bce2c4fbd78c2755be4bf7b8e235481557f \ - --hash=sha256:d35aef233b098e4de88b1eac29f0df378278e7e250a915766786b773309137c4 \ - --hash=sha256:de112c2dae53107cfeaf65101419662ac0a54e9a088c17958b51c95dac5de56d \ - --hash=sha256:e9baff912ea4f78a543d183ed6f5b3bea9784509b948227daaf6f10727a0e2e5 \ - --hash=sha256:eb1533c59f0ec6c55871206f15a5c72d1fae7ad3c0a8ca33ca88f7c309bbbf8c \ - --hash=sha256:ec915cd26d76f6fc7ae8522f74f5b2accf39546f341c771bb2297f3871934a52 \ - --hash=sha256:fde0f3104dfa1dfbc1f230f65506532d0558d43188789eaf68f97e106249a913 \ - --hash=sha256:fe00169cf875bed0b3c40e4da45b57037dc21d7c7bf0c85ed75f210c281488f1 - # via -r build/requirements.in six==1.17.0 \ --hash=sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274 \ --hash=sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81 @@ -646,124 +952,21 @@ sortedcontainers==2.4.0 \ --hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \ --hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0 # via hypothesis -typing-extensions==4.12.2 \ - --hash=sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d \ - --hash=sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8 +typing-extensions==4.15.0 \ + --hash=sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466 \ + --hash=sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548 # via etils -wheel==0.45.1 \ - --hash=sha256:661e1abd9198507b1409a20c02106d9670b2576e916d58f520316666abca6729 \ - --hash=sha256:708e7481cc80179af0e556bbf0cc00b8444c7321e2700b8d8580231d13017248 - # via -r build/test-requirements.txt -zipp==3.21.0 \ - --hash=sha256:2c9958f6430a2040341a52eb608ed6dd93ef4392e02ffe219417c1b28b5dd1f4 \ - --hash=sha256:ac1bbe05fd2991f160ebce24ffbac5f6d11d83dc90891255885223d42b3cd931 +wheel==0.46.1 \ + --hash=sha256:f796f65d72750ccde090663e466d0ca37cd72b62870f7520b96d34cdc07d86d8 \ + --hash=sha256:fd477efb5da0f7df1d3c76c73c14394002c844451bd63229d8570f376f5e6a38 + # via -r build/requirements.in +zipp==3.23.0 \ + --hash=sha256:071652d6115ed432f5ce1d34c336c0adfd6a884660d1e9712a256d3d3bd4b14e \ + --hash=sha256:a07157588a12518c9d4034df3fbbee09c814741a33ff63c05fa29d26a2404166 # via etils -# python 3.13t can compile 0.23.0 -# due to https://github.com/indygreg/python-zstandard/issues/231 -# zstandard==0.23.0 \ -# --hash=sha256:034b88913ecc1b097f528e42b539453fa82c3557e414b3de9d5632c80439a473 \ -# --hash=sha256:0a7f0804bb3799414af278e9ad51be25edf67f78f916e08afdb983e74161b916 \ -# --hash=sha256:11e3bf3c924853a2d5835b24f03eeba7fc9b07d8ca499e247e06ff5676461a15 \ -# --hash=sha256:12a289832e520c6bd4dcaad68e944b86da3bad0d339ef7989fb7e88f92e96072 \ -# --hash=sha256:1516c8c37d3a053b01c1c15b182f3b5f5eef19ced9b930b684a73bad121addf4 \ -# --hash=sha256:157e89ceb4054029a289fb504c98c6a9fe8010f1680de0201b3eb5dc20aa6d9e \ -# --hash=sha256:1bfe8de1da6d104f15a60d4a8a768288f66aa953bbe00d027398b93fb9680b26 \ -# --hash=sha256:1e172f57cd78c20f13a3415cc8dfe24bf388614324d25539146594c16d78fcc8 \ -# --hash=sha256:1fd7e0f1cfb70eb2f95a19b472ee7ad6d9a0a992ec0ae53286870c104ca939e5 \ -# --hash=sha256:203d236f4c94cd8379d1ea61db2fce20730b4c38d7f1c34506a31b34edc87bdd \ -# --hash=sha256:27d3ef2252d2e62476389ca8f9b0cf2bbafb082a3b6bfe9d90cbcbb5529ecf7c \ -# --hash=sha256:29a2bc7c1b09b0af938b7a8343174b987ae021705acabcbae560166567f5a8db \ -# --hash=sha256:2ef230a8fd217a2015bc91b74f6b3b7d6522ba48be29ad4ea0ca3a3775bf7dd5 \ -# --hash=sha256:2ef3775758346d9ac6214123887d25c7061c92afe1f2b354f9388e9e4d48acfc \ -# --hash=sha256:2f146f50723defec2975fb7e388ae3a024eb7151542d1599527ec2aa9cacb152 \ -# --hash=sha256:2fb4535137de7e244c230e24f9d1ec194f61721c86ebea04e1581d9d06ea1269 \ -# --hash=sha256:32ba3b5ccde2d581b1e6aa952c836a6291e8435d788f656fe5976445865ae045 \ -# --hash=sha256:34895a41273ad33347b2fc70e1bff4240556de3c46c6ea430a7ed91f9042aa4e \ -# --hash=sha256:379b378ae694ba78cef921581ebd420c938936a153ded602c4fea612b7eaa90d \ -# --hash=sha256:38302b78a850ff82656beaddeb0bb989a0322a8bbb1bf1ab10c17506681d772a \ -# --hash=sha256:3aa014d55c3af933c1315eb4bb06dd0459661cc0b15cd61077afa6489bec63bb \ -# --hash=sha256:4051e406288b8cdbb993798b9a45c59a4896b6ecee2f875424ec10276a895740 \ -# --hash=sha256:40b33d93c6eddf02d2c19f5773196068d875c41ca25730e8288e9b672897c105 \ -# --hash=sha256:43da0f0092281bf501f9c5f6f3b4c975a8a0ea82de49ba3f7100e64d422a1274 \ -# --hash=sha256:445e4cb5048b04e90ce96a79b4b63140e3f4ab5f662321975679b5f6360b90e2 \ -# --hash=sha256:48ef6a43b1846f6025dde6ed9fee0c24e1149c1c25f7fb0a0585572b2f3adc58 \ -# --hash=sha256:50a80baba0285386f97ea36239855f6020ce452456605f262b2d33ac35c7770b \ -# --hash=sha256:519fbf169dfac1222a76ba8861ef4ac7f0530c35dd79ba5727014613f91613d4 \ -# --hash=sha256:53dd9d5e3d29f95acd5de6802e909ada8d8d8cfa37a3ac64836f3bc4bc5512db \ -# --hash=sha256:53ea7cdc96c6eb56e76bb06894bcfb5dfa93b7adcf59d61c6b92674e24e2dd5e \ -# --hash=sha256:576856e8594e6649aee06ddbfc738fec6a834f7c85bf7cadd1c53d4a58186ef9 \ -# --hash=sha256:59556bf80a7094d0cfb9f5e50bb2db27fefb75d5138bb16fb052b61b0e0eeeb0 \ -# --hash=sha256:5d41d5e025f1e0bccae4928981e71b2334c60f580bdc8345f824e7c0a4c2a813 \ -# --hash=sha256:61062387ad820c654b6a6b5f0b94484fa19515e0c5116faf29f41a6bc91ded6e \ -# --hash=sha256:61f89436cbfede4bc4e91b4397eaa3e2108ebe96d05e93d6ccc95ab5714be512 \ -# --hash=sha256:62136da96a973bd2557f06ddd4e8e807f9e13cbb0bfb9cc06cfe6d98ea90dfe0 \ -# --hash=sha256:64585e1dba664dc67c7cdabd56c1e5685233fbb1fc1966cfba2a340ec0dfff7b \ -# --hash=sha256:65308f4b4890aa12d9b6ad9f2844b7ee42c7f7a4fd3390425b242ffc57498f48 \ -# --hash=sha256:66b689c107857eceabf2cf3d3fc699c3c0fe8ccd18df2219d978c0283e4c508a \ -# --hash=sha256:6a41c120c3dbc0d81a8e8adc73312d668cd34acd7725f036992b1b72d22c1772 \ -# --hash=sha256:6f77fa49079891a4aab203d0b1744acc85577ed16d767b52fc089d83faf8d8ed \ -# --hash=sha256:72c68dda124a1a138340fb62fa21b9bf4848437d9ca60bd35db36f2d3345f373 \ -# --hash=sha256:752bf8a74412b9892f4e5b58f2f890a039f57037f52c89a740757ebd807f33ea \ -# --hash=sha256:76e79bc28a65f467e0409098fa2c4376931fd3207fbeb6b956c7c476d53746dd \ -# --hash=sha256:774d45b1fac1461f48698a9d4b5fa19a69d47ece02fa469825b442263f04021f \ -# --hash=sha256:77da4c6bfa20dd5ea25cbf12c76f181a8e8cd7ea231c673828d0386b1740b8dc \ -# --hash=sha256:77ea385f7dd5b5676d7fd943292ffa18fbf5c72ba98f7d09fc1fb9e819b34c23 \ -# --hash=sha256:80080816b4f52a9d886e67f1f96912891074903238fe54f2de8b786f86baded2 \ -# --hash=sha256:80a539906390591dd39ebb8d773771dc4db82ace6372c4d41e2d293f8e32b8db \ -# --hash=sha256:82d17e94d735c99621bf8ebf9995f870a6b3e6d14543b99e201ae046dfe7de70 \ -# --hash=sha256:837bb6764be6919963ef41235fd56a6486b132ea64afe5fafb4cb279ac44f259 \ -# --hash=sha256:84433dddea68571a6d6bd4fbf8ff398236031149116a7fff6f777ff95cad3df9 \ -# --hash=sha256:8c24f21fa2af4bb9f2c492a86fe0c34e6d2c63812a839590edaf177b7398f700 \ -# --hash=sha256:8ed7d27cb56b3e058d3cf684d7200703bcae623e1dcc06ed1e18ecda39fee003 \ -# --hash=sha256:9206649ec587e6b02bd124fb7799b86cddec350f6f6c14bc82a2b70183e708ba \ -# --hash=sha256:983b6efd649723474f29ed42e1467f90a35a74793437d0bc64a5bf482bedfa0a \ -# --hash=sha256:98da17ce9cbf3bfe4617e836d561e433f871129e3a7ac16d6ef4c680f13a839c \ -# --hash=sha256:9c236e635582742fee16603042553d276cca506e824fa2e6489db04039521e90 \ -# --hash=sha256:9da6bc32faac9a293ddfdcb9108d4b20416219461e4ec64dfea8383cac186690 \ -# --hash=sha256:a05e6d6218461eb1b4771d973728f0133b2a4613a6779995df557f70794fd60f \ -# --hash=sha256:a0817825b900fcd43ac5d05b8b3079937073d2b1ff9cf89427590718b70dd840 \ -# --hash=sha256:a4ae99c57668ca1e78597d8b06d5af837f377f340f4cce993b551b2d7731778d \ -# --hash=sha256:a8c86881813a78a6f4508ef9daf9d4995b8ac2d147dcb1a450448941398091c9 \ -# --hash=sha256:a8fffdbd9d1408006baaf02f1068d7dd1f016c6bcb7538682622c556e7b68e35 \ -# --hash=sha256:a9b07268d0c3ca5c170a385a0ab9fb7fdd9f5fd866be004c4ea39e44edce47dd \ -# --hash=sha256:ab19a2d91963ed9e42b4e8d77cd847ae8381576585bad79dbd0a8837a9f6620a \ -# --hash=sha256:ac184f87ff521f4840e6ea0b10c0ec90c6b1dcd0bad2f1e4a9a1b4fa177982ea \ -# --hash=sha256:b0e166f698c5a3e914947388c162be2583e0c638a4703fc6a543e23a88dea3c1 \ -# --hash=sha256:b2170c7e0367dde86a2647ed5b6f57394ea7f53545746104c6b09fc1f4223573 \ -# --hash=sha256:b2d8c62d08e7255f68f7a740bae85b3c9b8e5466baa9cbf7f57f1cde0ac6bc09 \ -# --hash=sha256:b4567955a6bc1b20e9c31612e615af6b53733491aeaa19a6b3b37f3b65477094 \ -# --hash=sha256:b69bb4f51daf461b15e7b3db033160937d3ff88303a7bc808c67bbc1eaf98c78 \ -# --hash=sha256:b8c0bd73aeac689beacd4e7667d48c299f61b959475cdbb91e7d3d88d27c56b9 \ -# --hash=sha256:be9b5b8659dff1f913039c2feee1aca499cfbc19e98fa12bc85e037c17ec6ca5 \ -# --hash=sha256:bf0a05b6059c0528477fba9054d09179beb63744355cab9f38059548fedd46a9 \ -# --hash=sha256:c16842b846a8d2a145223f520b7e18b57c8f476924bda92aeee3a88d11cfc391 \ -# --hash=sha256:c363b53e257246a954ebc7c488304b5592b9c53fbe74d03bc1c64dda153fb847 \ -# --hash=sha256:c7c517d74bea1a6afd39aa612fa025e6b8011982a0897768a2f7c8ab4ebb78a2 \ -# --hash=sha256:d20fd853fbb5807c8e84c136c278827b6167ded66c72ec6f9a14b863d809211c \ -# --hash=sha256:d2240ddc86b74966c34554c49d00eaafa8200a18d3a5b6ffbf7da63b11d74ee2 \ -# --hash=sha256:d477ed829077cd945b01fc3115edd132c47e6540ddcd96ca169facff28173057 \ -# --hash=sha256:d50d31bfedd53a928fed6707b15a8dbeef011bb6366297cc435accc888b27c20 \ -# --hash=sha256:dc1d33abb8a0d754ea4763bad944fd965d3d95b5baef6b121c0c9013eaf1907d \ -# --hash=sha256:dc5d1a49d3f8262be192589a4b72f0d03b72dcf46c51ad5852a4fdc67be7b9e4 \ -# --hash=sha256:e2d1a054f8f0a191004675755448d12be47fa9bebbcffa3cdf01db19f2d30a54 \ -# --hash=sha256:e7792606d606c8df5277c32ccb58f29b9b8603bf83b48639b7aedf6df4fe8171 \ -# --hash=sha256:ed1708dbf4d2e3a1c5c69110ba2b4eb6678262028afd6c6fbcc5a8dac9cda68e \ -# --hash=sha256:f2d4380bf5f62daabd7b751ea2339c1a21d1c9463f1feb7fc2bdcea2c29c3160 \ -# --hash=sha256:f3513916e8c645d0610815c257cbfd3242adfd5c4cfa78be514e5a3ebb42a41b \ -# --hash=sha256:f8346bfa098532bc1fb6c7ef06783e969d87a99dd1d2a5a18a892c1d7a643c58 \ -# --hash=sha256:f83fa6cae3fff8e98691248c9320356971b59678a17f20656a9e59cd32cee6d8 \ -# --hash=sha256:fa6ce8b52c5987b3e34d5674b0ab529a4602b632ebab0a93b07bfb4dfc8f8a33 \ -# --hash=sha256:fb2b1ecfef1e67897d336de3a0e3f52478182d6a47eda86cbd42504c5cbd009a \ -# --hash=sha256:fc9ca1c9718cb3b06634c7c8dec57d24e9438b2aa9a0f02b8bb36bf478538880 \ -# --hash=sha256:fd30d9c67d13d891f2360b2a120186729c111238ac63b43dbd37a5a40670b8ca \ -# --hash=sha256:fd7699e8fd9969f455ef2926221e0233f81a2542921471382e77a9e2f2b57f4b \ -# --hash=sha256:fe3b385d996ee0822fd46528d9f0443b880d4d05528fd26a9119a54ec3f91c69 -# # via -r build/requirements.in # The following packages are considered to be unsafe in a requirements file: -setuptools==70.3.0 \ - --hash=sha256:f171bab1dfbc86b132997f26a119f6056a57950d058587841a0082e8830f9dc5 \ - --hash=sha256:fe384da74336c398e0d956d1cae0669bc02eed936cdb1d49b57de1990dc11ffc - # via - # -r build/test-requirements.txt - # -r build/requirements.in +setuptools==80.9.0 \ + --hash=sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922 \ + --hash=sha256:f36b47402ecde768dbfafc46e8e4207b4360c654f1f3bb84475f0a28628fb19c + # via -r build/requirements.in diff --git a/build/requirements_lock_3_14.txt b/build/requirements_lock_3_14.txt new file mode 100644 index 000000000000..24ed1fec57a6 --- /dev/null +++ b/build/requirements_lock_3_14.txt @@ -0,0 +1,1004 @@ +# +# This file is autogenerated by pip-compile with Python 3.14 +# by the following command: +# +# bazel run //build:requirements.update +# +--index-url https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple + +absl-py==2.3.1 \ + --hash=sha256:a97820526f7fbfd2ec1bce83f3f25e3a14840dac0d8e02a0b71cd75db3f77fc9 \ + --hash=sha256:eeecf07f0c2a93ace0772c92e596ace6d3d3996c042b2128459aaae2a76de11d + # via -r build/test-requirements.txt +attrs==25.4.0 \ + --hash=sha256:16d5969b87f0859ef33a48b35d55ac1be6e42ae49d5e853b597db70c35c57e11 \ + --hash=sha256:adcf7e2a1fb3b36ac48d97835bb6d8ade15b8dcce26aba8bf1d14847b57a3373 + # via hypothesis +auditwheel==6.5.0 \ + --hash=sha256:4fbcbd5854054bb1dd7870db03727b871b96b18147db57259561c058603987d7 \ + --hash=sha256:e08d2eede0259be6feff597d041c06175026e93248a1a97143acc52c57714d80 + # via -r build/test-requirements.txt +build==1.3.0 \ + --hash=sha256:698edd0ea270bde950f53aed21f3a0135672206f3911e0176261a31e0e07b397 \ + --hash=sha256:7145f0b5061ba90a1500d60bd1b13ca0a8a4cebdd0cc16ed8adf1c0e739f43b4 + # via -r build/requirements.in +cloudpickle==3.1.2 \ + --hash=sha256:7fda9eb655c9c230dab534f1983763de5835249750e85fbcef43aaa30a9a2414 \ + --hash=sha256:9acb47f6afd73f60dc1df93bb801b472f05ff42fa6c84167d25cb206be1fbf4a + # via -r build/test-requirements.txt +colorama==0.4.6 \ + --hash=sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44 \ + --hash=sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6 + # via -r build/requirements.in +contourpy==1.3.3 \ + --hash=sha256:023b44101dfe49d7d53932be418477dba359649246075c996866106da069af69 \ + --hash=sha256:07ce5ed73ecdc4a03ffe3e1b3e3c1166db35ae7584be76f65dbbe28a7791b0cc \ + --hash=sha256:083e12155b210502d0bca491432bb04d56dc3432f95a979b429f2848c3dbe880 \ + --hash=sha256:0bf67e0e3f482cb69779dd3061b534eb35ac9b17f163d851e2a547d56dba0a3a \ + --hash=sha256:0c1fc238306b35f246d61a1d416a627348b5cf0648648a031e14bb8705fcdfe8 \ + --hash=sha256:13b68d6a62db8eafaebb8039218921399baf6e47bf85006fd8529f2a08ef33fc \ + --hash=sha256:15ff10bfada4bf92ec8b31c62bf7c1834c244019b4a33095a68000d7075df470 \ + --hash=sha256:177fb367556747a686509d6fef71d221a4b198a3905fe824430e5ea0fda54eb5 \ + --hash=sha256:1cadd8b8969f060ba45ed7c1b714fe69185812ab43bd6b86a9123fe8f99c3263 \ + --hash=sha256:1fd43c3be4c8e5fd6e4f2baeae35ae18176cf2e5cced681cca908addf1cdd53b \ + --hash=sha256:22e9b1bd7a9b1d652cd77388465dc358dafcd2e217d35552424aa4f996f524f5 \ + --hash=sha256:23416f38bfd74d5d28ab8429cc4d63fa67d5068bd711a85edb1c3fb0c3e2f381 \ + --hash=sha256:283edd842a01e3dcd435b1c5116798d661378d83d36d337b8dde1d16a5fc9ba3 \ + --hash=sha256:2a2a8b627d5cc6b7c41a4beff6c5ad5eb848c88255fda4a8745f7e901b32d8e4 \ + --hash=sha256:2b7e9480ffe2b0cd2e787e4df64270e3a0440d9db8dc823312e2c940c167df7e \ + --hash=sha256:322ab1c99b008dad206d406bb61d014cf0174df491ae9d9d0fac6a6fda4f977f \ + --hash=sha256:33c82d0138c0a062380332c861387650c82e4cf1747aaa6938b9b6516762e772 \ + --hash=sha256:348ac1f5d4f1d66d3322420f01d42e43122f43616e0f194fc1c9f5d830c5b286 \ + --hash=sha256:3519428f6be58431c56581f1694ba8e50626f2dd550af225f82fb5f5814d2a42 \ + --hash=sha256:3c30273eb2a55024ff31ba7d052dde990d7d8e5450f4bbb6e913558b3d6c2301 \ + --hash=sha256:3d1a3799d62d45c18bafd41c5fa05120b96a28079f2393af559b843d1a966a77 \ + --hash=sha256:451e71b5a7d597379ef572de31eeb909a87246974d960049a9848c3bc6c41bf7 \ + --hash=sha256:459c1f020cd59fcfe6650180678a9993932d80d44ccde1fa1868977438f0b411 \ + --hash=sha256:4d00e655fcef08aba35ec9610536bfe90267d7ab5ba944f7032549c55a146da1 \ + --hash=sha256:4debd64f124ca62069f313a9cb86656ff087786016d76927ae2cf37846b006c9 \ + --hash=sha256:4feffb6537d64b84877da813a5c30f1422ea5739566abf0bd18065ac040e120a \ + --hash=sha256:50ed930df7289ff2a8d7afeb9603f8289e5704755c7e5c3bbd929c90c817164b \ + --hash=sha256:51e79c1f7470158e838808d4a996fa9bac72c498e93d8ebe5119bc1e6becb0db \ + --hash=sha256:556dba8fb6f5d8742f2923fe9457dbdd51e1049c4a43fd3986a0b14a1d815fc6 \ + --hash=sha256:598c3aaece21c503615fd59c92a3598b428b2f01bfb4b8ca9c4edeecc2438620 \ + --hash=sha256:5ed3657edf08512fc3fe81b510e35c2012fbd3081d2e26160f27ca28affec989 \ + --hash=sha256:626d60935cf668e70a5ce6ff184fd713e9683fb458898e4249b63be9e28286ea \ + --hash=sha256:644a6853d15b2512d67881586bd03f462c7ab755db95f16f14d7e238f2852c67 \ + --hash=sha256:655456777ff65c2c548b7c454af9c6f33f16c8884f11083244b5819cc214f1b5 \ + --hash=sha256:66c8a43a4f7b8df8b71ee1840e4211a3c8d93b214b213f590e18a1beca458f7d \ + --hash=sha256:6afc576f7b33cf00996e5c1102dc2a8f7cc89e39c0b55df93a0b78c1bd992b36 \ + --hash=sha256:6c3d53c796f8647d6deb1abe867daeb66dcc8a97e8455efa729516b997b8ed99 \ + --hash=sha256:709a48ef9a690e1343202916450bc48b9e51c049b089c7f79a267b46cffcdaa1 \ + --hash=sha256:70f9aad7de812d6541d29d2bbf8feb22ff7e1c299523db288004e3157ff4674e \ + --hash=sha256:8153b8bfc11e1e4d75bcb0bff1db232f9e10b274e0929de9d608027e0d34ff8b \ + --hash=sha256:87acf5963fc2b34825e5b6b048f40e3635dd547f590b04d2ab317c2619ef7ae8 \ + --hash=sha256:88df9880d507169449d434c293467418b9f6cbe82edd19284aa0409e7fdb933d \ + --hash=sha256:929ddf8c4c7f348e4c0a5a3a714b5c8542ffaa8c22954862a46ca1813b667ee7 \ + --hash=sha256:92d9abc807cf7d0e047b95ca5d957cf4792fcd04e920ca70d48add15c1a90ea7 \ + --hash=sha256:95b181891b4c71de4bb404c6621e7e2390745f887f2a026b2d99e92c17892339 \ + --hash=sha256:9e999574eddae35f1312c2b4b717b7885d4edd6cb46700e04f7f02db454e67c1 \ + --hash=sha256:a15459b0f4615b00bbd1e91f1b9e19b7e63aea7483d03d804186f278c0af2659 \ + --hash=sha256:a22738912262aa3e254e4f3cb079a95a67132fc5a063890e224393596902f5a4 \ + --hash=sha256:ab2fd90904c503739a75b7c8c5c01160130ba67944a7b77bbf36ef8054576e7f \ + --hash=sha256:ab3074b48c4e2cf1a960e6bbeb7f04566bf36b1861d5c9d4d8ac04b82e38ba20 \ + --hash=sha256:afe5a512f31ee6bd7d0dda52ec9864c984ca3d66664444f2d72e0dc4eb832e36 \ + --hash=sha256:b08a32ea2f8e42cf1d4be3169a98dd4be32bafe4f22b6c4cb4ba810fa9e5d2cb \ + --hash=sha256:b20c7c9a3bf701366556e1b1984ed2d0cedf999903c51311417cf5f591d8c78d \ + --hash=sha256:b2e8faa0ed68cb29af51edd8e24798bb661eac3bd9f65420c1887b6ca89987c8 \ + --hash=sha256:b7301b89040075c30e5768810bc96a8e8d78085b47d8be6e4c3f5a0b4ed478a0 \ + --hash=sha256:b7448cb5a725bb1e35ce88771b86fba35ef418952474492cf7c764059933ff8b \ + --hash=sha256:ca0fdcd73925568ca027e0b17ab07aad764be4706d0a925b89227e447d9737b7 \ + --hash=sha256:ca658cd1a680a5c9ea96dc61cdbae1e85c8f25849843aa799dfd3cb370ad4fbe \ + --hash=sha256:cbedb772ed74ff5be440fa8eee9bd49f64f6e3fc09436d9c7d8f1c287b121d77 \ + --hash=sha256:cd5dfcaeb10f7b7f9dc8941717c6c2ade08f587be2226222c12b25f0483ed497 \ + --hash=sha256:cf9022ef053f2694e31d630feaacb21ea24224be1c3ad0520b13d844274614fd \ + --hash=sha256:d002b6f00d73d69333dac9d0b8d5e84d9724ff9ef044fd63c5986e62b7c9e1b1 \ + --hash=sha256:d06bb1f751ba5d417047db62bca3c8fde202b8c11fb50742ab3ab962c81e8216 \ + --hash=sha256:d304906ecc71672e9c89e87c4675dc5c2645e1f4269a5063b99b0bb29f232d13 \ + --hash=sha256:e4e6b05a45525357e382909a4c1600444e2a45b4795163d3b22669285591c1ae \ + --hash=sha256:e74a9a0f5e3fff48fb5a7f2fd2b9b70a3fe014a67522f79b7cca4c0c7e43c9ae \ + --hash=sha256:ea37e7b45949df430fe649e5de8351c423430046a2af20b1c1961cae3afcda77 \ + --hash=sha256:f64836de09927cba6f79dcd00fdd7d5329f3fccc633468507079c829ca4db4e3 \ + --hash=sha256:fd6ec6be509c787f1caf6b247f0b1ca598bef13f4ddeaa126b7658215529ba0f \ + --hash=sha256:fd907ae12cd483cd83e414b12941c632a969171bf90fc937d0c9f268a31cafff \ + --hash=sha256:fd914713266421b7536de2bfa8181aa8c699432b6763a0ea64195ebe28bff6a9 \ + --hash=sha256:fde6c716d51c04b1c25d0b90364d0be954624a0ee9d60e23e850e8d48353d07a + # via matplotlib +cycler==0.12.1 \ + --hash=sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30 \ + --hash=sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c + # via matplotlib +etils[epath,epy]==1.13.0 \ + --hash=sha256:a5b60c71f95bcd2d43d4e9fb3dc3879120c1f60472bb5ce19f7a860b1d44f607 \ + --hash=sha256:d9cd4f40fbe77ad6613b7348a18132cc511237b6c076dbb89105c0b520a4c6bb + # via -r build/requirements.in +execnet==2.1.2 \ + --hash=sha256:63d83bfdd9a23e35b9c6a3261412324f964c2ec8dcd8d3c6916ee9373e0befcd \ + --hash=sha256:67fba928dd5a544b783f6056f449e5e3931a5c378b128bc18501f7ea79e296ec + # via pytest-xdist +filelock==3.20.1 \ + --hash=sha256:15d9e9a67306188a44baa72f569d2bfd803076269365fdea0934385da4dc361a \ + --hash=sha256:b8360948b351b80f420878d8516519a2204b07aefcdcfd24912a5d33127f188c + # via -r build/test-requirements.txt +flatbuffers==25.9.23 \ + --hash=sha256:255538574d6cb6d0a79a17ec8bc0d30985913b87513a01cce8bcdb6b4c44d0e2 \ + --hash=sha256:676f9fa62750bb50cf531b42a0a2a118ad8f7f797a511eda12881c016f093b12 + # via -r build/test-requirements.txt +fonttools==4.61.1 \ + --hash=sha256:0de30bfe7745c0d1ffa2b0b7048fb7123ad0d71107e10ee090fa0b16b9452e87 \ + --hash=sha256:10d88e55330e092940584774ee5e8a6971b01fc2f4d3466a1d6c158230880796 \ + --hash=sha256:11f35ad7805edba3aac1a3710d104592df59f4b957e30108ae0ba6c10b11dd75 \ + --hash=sha256:15acc09befd16a0fb8a8f62bc147e1a82817542d72184acca9ce6e0aeda9fa6d \ + --hash=sha256:17d2bf5d541add43822bcf0c43d7d847b160c9bb01d15d5007d84e2217aaa371 \ + --hash=sha256:2180f14c141d2f0f3da43f3a81bc8aa4684860f6b0e6f9e165a4831f24e6a23b \ + --hash=sha256:21e7c8d76f62ab13c9472ccf74515ca5b9a761d1bde3265152a6dc58700d895b \ + --hash=sha256:41a7170d042e8c0024703ed13b71893519a1a6d6e18e933e3ec7507a2c26a4b2 \ + --hash=sha256:41ed4b5ec103bd306bb68f81dc166e77409e5209443e5773cb4ed837bcc9b0d3 \ + --hash=sha256:497c31ce314219888c0e2fce5ad9178ca83fe5230b01a5006726cdf3ac9f24d9 \ + --hash=sha256:4c1b526c8d3f615a7b1867f38a9410849c8f4aef078535742198e942fba0e9bd \ + --hash=sha256:4d7092bb38c53bbc78e9255a59158b150bcdc115a1e3b3ce0b5f267dc35dd63c \ + --hash=sha256:4f5686e1fe5fce75d82d93c47a438a25bf0d1319d2843a926f741140b2b16e0c \ + --hash=sha256:58b0ee0ab5b1fc9921eccfe11d1435added19d6494dde14e323f25ad2bc30c56 \ + --hash=sha256:5ce02f38a754f207f2f06557523cd39a06438ba3aafc0639c477ac409fc64e37 \ + --hash=sha256:5fade934607a523614726119164ff621e8c30e8fa1ffffbbd358662056ba69f0 \ + --hash=sha256:5fe9fd43882620017add5eabb781ebfbc6998ee49b35bd7f8f79af1f9f99a958 \ + --hash=sha256:64102ca87e84261419c3747a0d20f396eb024bdbeb04c2bfb37e2891f5fadcb5 \ + --hash=sha256:664c5a68ec406f6b1547946683008576ef8b38275608e1cee6c061828171c118 \ + --hash=sha256:6675329885c44657f826ef01d9e4fb33b9158e9d93c537d84ad8399539bc6f69 \ + --hash=sha256:75c1a6dfac6abd407634420c93864a1e274ebc1c7531346d9254c0d8f6ca00f9 \ + --hash=sha256:75da8f28eff26defba42c52986de97b22106cb8f26515b7c22443ebc9c2d3261 \ + --hash=sha256:77efb033d8d7ff233385f30c62c7c79271c8885d5c9657d967ede124671bbdfb \ + --hash=sha256:78a7d3ab09dc47ac1a363a493e6112d8cabed7ba7caad5f54dbe2f08676d1b47 \ + --hash=sha256:7c7db70d57e5e1089a274cbb2b1fd635c9a24de809a231b154965d415d6c6d24 \ + --hash=sha256:8c56c488ab471628ff3bfa80964372fc13504ece601e0d97a78ee74126b2045c \ + --hash=sha256:91669ccac46bbc1d09e9273546181919064e8df73488ea087dcac3e2968df9ba \ + --hash=sha256:9b666a475a65f4e839d3d10473fad6d47e0a9db14a2f4a224029c5bfde58ad2c \ + --hash=sha256:9cfef3ab326780c04d6646f68d4b4742aae222e8b8ea1d627c74e38afcbc9d91 \ + --hash=sha256:a13fc8aeb24bad755eea8f7f9d409438eb94e82cf86b08fe77a03fbc8f6a96b1 \ + --hash=sha256:a75c301f96db737e1c5ed5fd7d77d9c34466de16095a266509e13da09751bd19 \ + --hash=sha256:a76d4cb80f41ba94a6691264be76435e5f72f2cb3cab0b092a6212855f71c2f6 \ + --hash=sha256:aed04cabe26f30c1647ef0e8fbb207516fd40fe9472e9439695f5c6998e60ac5 \ + --hash=sha256:b148b56f5de675ee16d45e769e69f87623a4944f7443850bf9a9376e628a89d2 \ + --hash=sha256:b501c862d4901792adaec7c25b1ecc749e2662543f68bb194c42ba18d6eec98d \ + --hash=sha256:b846a1fcf8beadeb9ea4f44ec5bdde393e2f1569e17d700bfc49cd69bde75881 \ + --hash=sha256:b931ae8f62db78861b0ff1ac017851764602288575d65b8e8ff1963fed419063 \ + --hash=sha256:c33ab3ca9d3ccd581d58e989d67554e42d8d4ded94ab3ade3508455fe70e65f7 \ + --hash=sha256:c6604b735bb12fef8e0efd5578c9fb5d3d8532d5001ea13a19cddf295673ee09 \ + --hash=sha256:d8db08051fc9e7d8bc622f2112511b8107d8f27cd89e2f64ec45e9825e8288da \ + --hash=sha256:d9203500f7c63545b4ce3799319fe4d9feb1a1b89b28d3cb5abd11b9dd64147e \ + --hash=sha256:dc492779501fa723b04d0ab1f5be046797fee17d27700476edc7ee9ae535a61e \ + --hash=sha256:e6bcdf33aec38d16508ce61fd81838f24c83c90a1d1b8c68982857038673d6b8 \ + --hash=sha256:e76ce097e3c57c4bcb67c5aa24a0ecdbd9f74ea9219997a707a4061fbe2707aa \ + --hash=sha256:eff1ac3cc66c2ac7cda1e64b4e2f3ffef474b7335f92fc3833fc632d595fcee6 \ + --hash=sha256:f3cb4a569029b9f291f88aafc927dd53683757e640081ca8c412781ea144565e \ + --hash=sha256:f79b168428351d11e10c5aeb61a74e1851ec221081299f4cf56036a95431c43a \ + --hash=sha256:fa646ecec9528bef693415c79a86e733c70a4965dd938e9a226b0fc64c9d2e6c \ + --hash=sha256:fe2efccb324948a11dd09d22136fe2ac8a97d6c1347cf0b58a911dcd529f66b7 \ + --hash=sha256:fff4f534200a04b4a36e7ae3cb74493afe807b517a09e99cb4faa89a34ed6ecd + # via matplotlib +fsspec==2025.10.0 \ + --hash=sha256:7c7712353ae7d875407f97715f0e1ffcc21e33d5b24556cb1e090ae9409ec61d \ + --hash=sha256:b6789427626f068f9a83ca4e8a3cc050850b6c0f71f99ddb4f542b8266a26a59 + # via etils +hypothesis==6.142.1 \ + --hash=sha256:3179cb08756562c526aaf4a9871ebbff83d2d75c03896ed0bc9c1d14097a930c \ + --hash=sha256:95a7d38fcc58e697e3020665adcb951c630cdbc8065e4b4474949e486b06bd6d + # via -r build/test-requirements.txt +importlib-resources==6.5.2 \ + --hash=sha256:185f87adef5bcc288449d98fb4fba07cea78bc036455dd44c5fc4a2fe78fed2c \ + --hash=sha256:789cfdc3ed28c78b67a06acb8126751ced69a3d5f79c095a98298cd8a760ccec + # via etils +iniconfig==2.3.0 \ + --hash=sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730 \ + --hash=sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12 + # via pytest +jax-cuda12-pjrt==0.8.2 ; sys_platform == "linux" \ + --hash=sha256:717a1b196a642409ce195ddf031c20bbeadcc886f55e49a1d3f4927373aeedae \ + --hash=sha256:e3bab41ca7c48e4163db9e7efd271b3aa85f0fe45f5ed0708d6bbed93a59f977 + # via -r build/requirements.in +jax-cuda13-pjrt==0.8.2 \ + --hash=sha256:370e9cb9a76af6a6b00e8f6c6ae8e5a0158886b65994a114cf3acf56ea1570f3 \ + --hash=sha256:c0ebc284d4bea9c0b278017b4bb62eb871755fe582cbf22ce1fb770d30fb9af6 + # via + # -r build/requirements.in + # jax-cuda13-plugin +jax-cuda13-plugin==0.8.2 \ + --hash=sha256:0ab3a13f798f973416754cb00b1696a8e6d2425c39496e2145293661bcd28446 \ + --hash=sha256:2c7532f2a6c9dcf50ba4b8ebbe4c78988b8c260695e50f9fdbd68f9417a8bb51 \ + --hash=sha256:358f4b4d0d2b92f62fe7f07d355a9c57c5684da3d4e6f483afec38b23e66f02e \ + --hash=sha256:6703552ee10b6b41b43fe4fc62b4672bec13d271eb709a29bf14b0fa1ec805dd \ + --hash=sha256:8fe2c5874cbae18a01a0fa722dfbd547b5504e7bf42aaaae6ad3e1b676e8daa2 \ + --hash=sha256:950f2c1f2e0f5e13893c4c27a4df540a96a26ebd15d7954fa85ffb8cc8b1d9a1 \ + --hash=sha256:9835076d0a917345700370cfaee5cfd7876c0b0982e99206a8b4848af0ea6e5e \ + --hash=sha256:bb059d99eb960718358f82f3eaa0829f96fc7d7e978ded821f2c40256a8aee4b \ + --hash=sha256:d5f9470db46aee8a64e9ab7817927335b77c6d3698d5113e0da2dd962c1468c4 \ + --hash=sha256:f1643b216ccb41a2d7c0e240359c729c96c5eca76e4525ffcd6313049e348b9a \ + --hash=sha256:f35f687d9b839365a49c2b0052eaeb6b9f0a9879813b4fa020f8d9caa3ddde46 \ + --hash=sha256:fa6a6faf3130d6ef05f9ee6cbc27e0a2e4db0f23720136403991eccb3a0144af + # via -r build/requirements.in +jaxlib==0.8.2 \ + --hash=sha256:023de6f3f56da2af7037970996500586331fdb50b530ecbb54b9666da633bd00 \ + --hash=sha256:05b958f497e49824c432e734bb059723b7dfe69e2ad696a9f9c8ad82fff7c3f8 \ + --hash=sha256:1bfbcf6c3de221784fa4cdb6765a09d71cb4298b15626b3d0409b3dfcd8a8667 \ + --hash=sha256:28eec1a4e0639a0d8702cea3cb70dd3663053dbfa344452994ea48dc6ceadaa5 \ + --hash=sha256:2b9789bd08f8b0cc5a5c12ae896fe432d5942e32e417091b8b5a96a9a6fd5cf1 \ + --hash=sha256:3b16e50c5b730c9dd0a49e55f1acfaa722b00b1af0522a591558dcc0464252f2 \ + --hash=sha256:490bf0cb029c73c65c9431124b86cdc95082dbc1fb76fc549d24d75da33e5454 \ + --hash=sha256:4d006db96be020c8165212a1216372f8acac4ff4f8fb067743d694ef2b301ace \ + --hash=sha256:68108dff0de74adc468016be9a19f80efe48c660c0d5a122287094b44b092afc \ + --hash=sha256:7c304f3a016965b9d1f5239a8a0399a73925f5604fe914c5ca66ecf734bf6422 \ + --hash=sha256:7da8127557c786264049ae55460d1b8d04cc3cdf0403a087f2fc1e6d313ec722 \ + --hash=sha256:964626f581beab31ee6826b228fcc2ec5181b05cecf94a528dff97921c145dbc \ + --hash=sha256:a397ea7dcb37d689ce79173eeb99b2f1347637a36be9a27f20ae6848bfc58bfc \ + --hash=sha256:aa8701b6356f098e8452c3cec762fb5f706fcb8f67ffd65964f63982479aa23b \ + --hash=sha256:bb89be452b1b808d3f88fc01c415b364a260be4cc7ac120c038009f6150a32dc \ + --hash=sha256:beffb004e7eeb5c9afb24439e2b2cf45a4ee3e3e8adf45e355edf2af62acf8b8 \ + --hash=sha256:ccf77da917a20935247c990691decfcbdd06c25ef0ac94d914a04aadb22f714c \ + --hash=sha256:dffc22b5b732b9556d92c918b251c61bcc046617c4dbb51e1f7a656587fddffb \ + --hash=sha256:e6a97dfb0232eed9a2bb6e3828e4f682dbac1a7fea840bfda574cae2dbf5faf9 \ + --hash=sha256:f205e91c3a152a2a76c0bc59a6a2de03e87ec261b91e8812922777185e7b08f5 \ + --hash=sha256:f28edac8c226fc07fa3e8af6f9defede8ac2c307429e3291edce8739d39becc9 \ + --hash=sha256:f472cc72e3058e50b5f0230b236d5a1183bf6c3d5423d2a52eff07bcf34908de + # via -r build/requirements.in +kiwisolver==1.4.9 \ + --hash=sha256:0749fd8f4218ad2e851e11cc4dc05c7cbc0cbc4267bdfdb31782e65aace4ee9c \ + --hash=sha256:0763515d4df10edf6d06a3c19734e2566368980d21ebec439f33f9eb936c07b7 \ + --hash=sha256:0856e241c2d3df4efef7c04a1e46b1936b6120c9bcf36dd216e3acd84bc4fb21 \ + --hash=sha256:0a590506f303f512dff6b7f75fd2fd18e16943efee932008fe7140e5fa91d80e \ + --hash=sha256:0ab74e19f6a2b027ea4f845a78827969af45ce790e6cb3e1ebab71bdf9f215ff \ + --hash=sha256:0ae37737256ba2de764ddc12aed4956460277f00c4996d51a197e72f62f5eec7 \ + --hash=sha256:0e4e2bf29574a6a7b7f6cb5fa69293b9f96c928949ac4a53ba3f525dffb87f9c \ + --hash=sha256:15163165efc2f627eb9687ea5f3a28137217d217ac4024893d753f46bce9de26 \ + --hash=sha256:17680d737d5335b552994a2008fab4c851bcd7de33094a82067ef3a576ff02fa \ + --hash=sha256:1a12cf6398e8a0a001a059747a1cbf24705e18fe413bc22de7b3d15c67cffe3f \ + --hash=sha256:1b11d6a633e4ed84fc0ddafd4ebfd8ea49b3f25082c04ad12b8315c11d504dc1 \ + --hash=sha256:1fa333e8b2ce4d9660f2cda9c0e1b6bafcfb2457a9d259faa82289e73ec24891 \ + --hash=sha256:2327a4a30d3ee07d2fbe2e7933e8a37c591663b96ce42a00bc67461a87d7df77 \ + --hash=sha256:2405a7d98604b87f3fc28b1716783534b1b4b8510d8142adca34ee0bc3c87543 \ + --hash=sha256:2489e4e5d7ef9a1c300a5e0196e43d9c739f066ef23270607d45aba368b91f2d \ + --hash=sha256:24c175051354f4a28c5d6a31c93906dc653e2bf234e8a4bbfb964892078898ce \ + --hash=sha256:2635d352d67458b66fd0667c14cb1d4145e9560d503219034a18a87e971ce4f3 \ + --hash=sha256:2c1a4f57df73965f3f14df20b80ee29e6a7930a57d2d9e8491a25f676e197c60 \ + --hash=sha256:2c93f00dcba2eea70af2be5f11a830a742fe6b579a1d4e00f47760ef13be247a \ + --hash=sha256:39a219e1c81ae3b103643d2aedb90f1ef22650deb266ff12a19e7773f3e5f089 \ + --hash=sha256:3b3115b2581ea35bb6d1f24a4c90af37e5d9b49dcff267eeed14c3893c5b86ab \ + --hash=sha256:40092754720b174e6ccf9e845d0d8c7d8e12c3d71e7fc35f55f3813e96376f78 \ + --hash=sha256:412f287c55a6f54b0650bd9b6dce5aceddb95864a1a90c87af16979d37c89771 \ + --hash=sha256:464415881e4801295659462c49461a24fb107c140de781d55518c4b80cb6790f \ + --hash=sha256:497d05f29a1300d14e02e6441cf0f5ee81c1ff5a304b0d9fb77423974684e08b \ + --hash=sha256:4a2899935e724dd1074cb568ce7ac0dce28b2cd6ab539c8e001a8578eb106d14 \ + --hash=sha256:4a48a2ce79d65d363597ef7b567ce3d14d68783d2b2263d98db3d9477805ba32 \ + --hash=sha256:4d1d9e582ad4d63062d34077a9a1e9f3c34088a2ec5135b1f7190c07cf366527 \ + --hash=sha256:52a15b0f35dad39862d376df10c5230155243a2c1a436e39eb55623ccbd68185 \ + --hash=sha256:540c7c72324d864406a009d72f5d6856f49693db95d1fbb46cf86febef873634 \ + --hash=sha256:5656aa670507437af0207645273ccdfee4f14bacd7f7c67a4306d0dcaeaf6eed \ + --hash=sha256:5a0f2724dfd4e3b3ac5a82436a8e6fd16baa7d507117e4279b660fe8ca38a3a1 \ + --hash=sha256:60c439763a969a6af93b4881db0eed8fadf93ee98e18cbc35bc8da868d0c4f0c \ + --hash=sha256:61874cdb0a36016354853593cffc38e56fc9ca5aa97d2c05d3dcf6922cd55a11 \ + --hash=sha256:67bb8b474b4181770f926f7b7d2f8c0248cbcb78b660fdd41a47054b28d2a752 \ + --hash=sha256:720e05574713db64c356e86732c0f3c5252818d05f9df320f0ad8380641acea5 \ + --hash=sha256:72d0eb9fba308b8311685c2268cf7d0a0639a6cd027d8128659f72bdd8a024b4 \ + --hash=sha256:767c23ad1c58c9e827b649a9ab7809fd5fd9db266a9cf02b0e926ddc2c680d58 \ + --hash=sha256:77937e5e2a38a7b48eef0585114fe7930346993a88060d0bf886086d2aa49ef5 \ + --hash=sha256:7a08b491ec91b1d5053ac177afe5290adacf1f0f6307d771ccac5de30592d198 \ + --hash=sha256:7b4da0d01ac866a57dd61ac258c5607b4cd677f63abaec7b148354d2b2cdd536 \ + --hash=sha256:7cf974dd4e35fa315563ac99d6287a1024e4dc2077b8a7d7cd3d2fb65d283134 \ + --hash=sha256:84fd60810829c27ae375114cd379da1fa65e6918e1da405f356a775d49a62bcf \ + --hash=sha256:858e4c22fb075920b96a291928cb7dea5644e94c0ee4fcd5af7e865655e4ccf2 \ + --hash=sha256:85b5352f94e490c028926ea567fc569c52ec79ce131dadb968d3853e809518c2 \ + --hash=sha256:85bd218b5ecfbee8c8a82e121802dcb519a86044c9c3b2e4aef02fa05c6da370 \ + --hash=sha256:8a1f570ce4d62d718dce3f179ee78dac3b545ac16c0c04bb363b7607a949c0d1 \ + --hash=sha256:8fdca1def57a2e88ef339de1737a1449d6dbf5fab184c54a1fca01d541317154 \ + --hash=sha256:90f47e70293fc3688b71271100a1a5453aa9944a81d27ff779c108372cf5567b \ + --hash=sha256:92a2f997387a1b79a75e7803aa7ded2cfbe2823852ccf1ba3bcf613b62ae3197 \ + --hash=sha256:9928fe1eb816d11ae170885a74d074f57af3a0d65777ca47e9aeb854a1fba386 \ + --hash=sha256:9af39d6551f97d31a4deebeac6f45b156f9755ddc59c07b402c148f5dbb6482a \ + --hash=sha256:9cf554f21be770f5111a1690d42313e140355e687e05cf82cb23d0a721a64a48 \ + --hash=sha256:a30fd6fdef1430fd9e1ba7b3398b5ee4e2887783917a687d86ba69985fb08748 \ + --hash=sha256:a31d512c812daea6d8b3be3b2bfcbeb091dbb09177706569bcfc6240dcf8b41c \ + --hash=sha256:a5d0432ccf1c7ab14f9949eec60c5d1f924f17c037e9f8b33352fa05799359b8 \ + --hash=sha256:a60ea74330b91bd22a29638940d115df9dc00af5035a9a2a6ad9399ffb4ceca5 \ + --hash=sha256:ac5a486ac389dddcc5bef4f365b6ae3ffff2c433324fb38dd35e3fab7c957999 \ + --hash=sha256:aedff62918805fb62d43a4aa2ecd4482c380dc76cd31bd7c8878588a61bd0369 \ + --hash=sha256:b34e51affded8faee0dfdb705416153819d8ea9250bbbf7ea1b249bdeb5f1122 \ + --hash=sha256:b4b4d74bda2b8ebf4da5bd42af11d02d04428b2c32846e4c2c93219df8a7987b \ + --hash=sha256:b67e6efbf68e077dd71d1a6b37e43e1a99d0bff1a3d51867d45ee8908b931098 \ + --hash=sha256:b78efa4c6e804ecdf727e580dbb9cba85624d2e1c6b5cb059c66290063bd99a9 \ + --hash=sha256:bb4ae2b57fc1d8cbd1cf7b1d9913803681ffa903e7488012be5b76dedf49297f \ + --hash=sha256:bdd1a81a1860476eb41ac4bc1e07b3f07259e6d55bbf739b79c8aaedcf512799 \ + --hash=sha256:bdee92c56a71d2b24c33a7d4c2856bd6419d017e08caa7802d2963870e315028 \ + --hash=sha256:be6a04e6c79819c9a8c2373317d19a96048e5a3f90bec587787e86a1153883c2 \ + --hash=sha256:bfc08add558155345129c7803b3671cf195e6a56e7a12f3dde7c57d9b417f525 \ + --hash=sha256:c3b22c26c6fd6811b0ae8363b95ca8ce4ea3c202d3d0975b2914310ceb1bcc4d \ + --hash=sha256:c9e7cdf45d594ee04d5be1b24dd9d49f3d1590959b2271fb30b5ca2b262c00fb \ + --hash=sha256:cb27e7b78d716c591e88e0a09a2139c6577865d7f2e152488c2cc6257f460872 \ + --hash=sha256:cc9617b46837c6468197b5945e196ee9ca43057bb7d9d1ae688101e4e1dddf64 \ + --hash=sha256:ccd09f20ccdbbd341b21a67ab50a119b64a403b09288c27481575105283c1586 \ + --hash=sha256:ce6a3a4e106cf35c2d9c4fa17c05ce0b180db622736845d4315519397a77beaf \ + --hash=sha256:d0005b053977e7b43388ddec89fa567f43d4f6d5c2c0affe57de5ebf290dc552 \ + --hash=sha256:d4188e73af84ca82468f09cadc5ac4db578109e52acb4518d8154698d3a87ca2 \ + --hash=sha256:d4efec7bcf21671db6a3294ff301d2fc861c31faa3c8740d1a94689234d1b415 \ + --hash=sha256:d75aa530ccfaa593da12834b86a0724f58bff12706659baa9227c2ccaa06264c \ + --hash=sha256:d84cd4061ae292d8ac367b2c3fa3aad11cb8625a95d135fe93f286f914f3f5a6 \ + --hash=sha256:d8aacd3d4b33b772542b2e01beb50187536967b514b00003bdda7589722d2a64 \ + --hash=sha256:d8fc5c867c22b828001b6a38d2eaeb88160bf5783c6cb4a5e440efc981ce286d \ + --hash=sha256:d976bbb382b202f71c67f77b0ac11244021cfa3f7dfd9e562eefcea2df711548 \ + --hash=sha256:dba5ee5d3981160c28d5490f0d1b7ed730c22470ff7f6cc26cfcfaacb9896a07 \ + --hash=sha256:dc1ae486f9abcef254b5618dfb4113dd49f94c68e3e027d03cf0143f3f772b61 \ + --hash=sha256:dd0a578400839256df88c16abddf9ba14813ec5f21362e1fe65022e00c883d4d \ + --hash=sha256:deed0c7258ceb4c44ad5ec7d9918f9f14fd05b2be86378d86cf50e63d1e7b771 \ + --hash=sha256:e09c2279a4d01f099f52d5c4b3d9e208e91edcbd1a175c9662a8b16e000fece9 \ + --hash=sha256:e2ea9f7ab7fbf18fffb1b5434ce7c69a07582f7acc7717720f1d69f3e806f90c \ + --hash=sha256:e6b93f13371d341afee3be9f7c5964e3fe61d5fa30f6a30eb49856935dfe4fc3 \ + --hash=sha256:eb14a5da6dc7642b0f3a18f13654847cd8b7a2550e2645a5bda677862b03ba16 \ + --hash=sha256:ed0fecd28cc62c54b262e3736f8bb2512d8dcfdc2bcf08be5f47f96bf405b145 \ + --hash=sha256:ede8c6d533bc6601a47ad4046080d36b8fc99f81e6f1c17b0ac3c2dc91ac7611 \ + --hash=sha256:efb3a45b35622bb6c16dbfab491a8f5a391fe0e9d45ef32f4df85658232ca0e2 \ + --hash=sha256:f117e1a089d9411663a3207ba874f31be9ac8eaa5b533787024dc07aeb74f464 \ + --hash=sha256:f2ba92255faa7309d06fe44c3a4a97efe1c8d640c2a79a5ef728b685762a6fd2 \ + --hash=sha256:f6008a4919fdbc0b0097089f67a1eb55d950ed7e90ce2cc3e640abadd2757a04 \ + --hash=sha256:f68208a520c3d86ea51acf688a3e3002615a7f0238002cccc17affecc86a8a54 \ + --hash=sha256:f68e4f3eeca8fb22cc3d731f9715a13b652795ef657a13df1ad0c7dc0e9731df \ + --hash=sha256:fb3b8132019ea572f4611d770991000d7f58127560c4889729248eb5852a102f \ + --hash=sha256:fb940820c63a9590d31d88b815e7a3aa5915cad3ce735ab45f0c730b39547de1 \ + --hash=sha256:fc1795ac5cd0510207482c3d1d3ed781143383b8cfd36f5c645f3897ce066220 + # via matplotlib +libtpu==0.0.32 ; sys_platform == "linux" and platform_machine == "x86_64" \ + --hash=sha256:37f6aefe6d69d24e5e268e74c1e90cd0ce0fa6a796720709445d8938605912ee \ + --hash=sha256:4a4ae6db0a90f0e33902cd345923a5d77dfc5bbab9a3e524c79bf876858104ee \ + --hash=sha256:6453995774b902e43d1caf0d82298a2e90f2f731ddd74d6bef9510d91732775f \ + --hash=sha256:7ffd9cdb89da22d32bb60ce70b88e9e76861c5d704ebf673f32987e19e9913f3 \ + --hash=sha256:98d9713a39d9581c9b10a7bdeaa575cc874a4a7f543cfb40d3a723ee27b428b5 \ + --hash=sha256:d2a0e7abe4a029b79049bc95da20e4c2a64f8ef09823f75286cecfbb4b0b6da1 + # via -r build/requirements.in +markdown-it-py==4.0.0 \ + --hash=sha256:87327c59b172c5011896038353a81343b6754500a08cd7a4973bb48c6d578147 \ + --hash=sha256:cb0a2b4aa34f932c007117b194e945bd74e0ec24133ceb5bac59009cda1cb9f3 + # via rich +matplotlib==3.10.8 \ + --hash=sha256:00270d217d6b20d14b584c521f810d60c5c78406dc289859776550df837dcda7 \ + --hash=sha256:0a33deb84c15ede243aead39f77e990469fff93ad1521163305095b77b72ce4a \ + --hash=sha256:113bb52413ea508ce954a02c10ffd0d565f9c3bc7f2eddc27dfe1731e71c7b5f \ + --hash=sha256:12d90df9183093fcd479f4172ac26b322b1248b15729cb57f42f71f24c7e37a3 \ + --hash=sha256:15d30132718972c2c074cd14638c7f4592bd98719e2308bccea40e0538bc0cb5 \ + --hash=sha256:18821ace09c763ec93aef5eeff087ee493a24051936d7b9ebcad9662f66501f9 \ + --hash=sha256:1ae029229a57cd1e8fe542485f27e7ca7b23aa9e8944ddb4985d0bc444f1eca2 \ + --hash=sha256:2299372c19d56bcd35cf05a2738308758d32b9eaed2371898d8f5bd33f084aa3 \ + --hash=sha256:238b7ce5717600615c895050239ec955d91f321c209dd110db988500558e70d6 \ + --hash=sha256:24d50994d8c5816ddc35411e50a86ab05f575e2530c02752e02538122613371f \ + --hash=sha256:25d380fe8b1dc32cf8f0b1b448470a77afb195438bafdf1d858bfb876f3edf7b \ + --hash=sha256:2c1998e92cd5999e295a731bcb2911c75f597d937341f3030cc24ef2733d78a8 \ + --hash=sha256:2cf5bd12cecf46908f286d7838b2abc6c91cda506c0445b8223a7c19a00df008 \ + --hash=sha256:32f8dce744be5569bebe789e46727946041199030db8aeb2954d26013a0eb26b \ + --hash=sha256:37b3c1cc42aa184b3f738cfa18c1c1d72fd496d85467a6cf7b807936d39aa656 \ + --hash=sha256:3a48a78d2786784cc2413e57397981fb45c79e968d99656706018d6e62e57958 \ + --hash=sha256:3ab4aabc72de4ff77b3ec33a6d78a68227bf1123465887f9905ba79184a1cc04 \ + --hash=sha256:3c624e43ed56313651bc18a47f838b60d7b8032ed348911c54906b130b20071b \ + --hash=sha256:3f2e409836d7f5ac2f1c013110a4d50b9f7edc26328c108915f9075d7d7a91b6 \ + --hash=sha256:3f5c3e4da343bba819f0234186b9004faba952cc420fbc522dc4e103c1985908 \ + --hash=sha256:41703cc95688f2516b480f7f339d8851a6035f18e100ee6a32bc0b8536a12a9c \ + --hash=sha256:495672de149445ec1b772ff2c9ede9b769e3cb4f0d0aa7fa730d7f59e2d4e1c1 \ + --hash=sha256:4cf267add95b1c88300d96ca837833d4112756045364f5c734a2276038dae27d \ + --hash=sha256:56271f3dac49a88d7fca5060f004d9d22b865f743a12a23b1e937a0be4818ee1 \ + --hash=sha256:595ba4d8fe983b88f0eec8c26a241e16d6376fe1979086232f481f8f3f67494c \ + --hash=sha256:5f62550b9a30afde8c1c3ae450e5eb547d579dd69b25c2fc7a1c67f934c1717a \ + --hash=sha256:646d95230efb9ca614a7a594d4fcacde0ac61d25e37dd51710b36477594963ce \ + --hash=sha256:64fcc24778ca0404ce0cb7b6b77ae1f4c7231cdd60e6778f999ee05cbd581b9a \ + --hash=sha256:6be43b667360fef5c754dda5d25a32e6307a03c204f3c0fc5468b78fa87b4160 \ + --hash=sha256:6da7c2ce169267d0d066adcf63758f0604aa6c3eebf67458930f9d9b79ad1db1 \ + --hash=sha256:83d282364ea9f3e52363da262ce32a09dfe241e4080dcedda3c0db059d3c1f11 \ + --hash=sha256:9153c3292705be9f9c64498a8872118540c3f4123d1a1c840172edf262c8be4a \ + --hash=sha256:99eefd13c0dc3b3c1b4d561c1169e65fe47aab7b8158754d7c084088e2329466 \ + --hash=sha256:a0a7f52498f72f13d4a25ea70f35f4cb60642b466cbb0a9be951b5bc3f45a486 \ + --hash=sha256:a2b336e2d91a3d7006864e0990c83b216fcdca64b5a6484912902cef87313d78 \ + --hash=sha256:a48f2b74020919552ea25d222d5cc6af9ca3f4eb43a93e14d068457f545c2a17 \ + --hash=sha256:ad3d9833a64cf48cc4300f2b406c3d0f4f4724a91c0bd5640678a6ba7c102077 \ + --hash=sha256:b44d07310e404ba95f8c25aa5536f154c0a8ec473303535949e52eb71d0a1565 \ + --hash=sha256:b53285e65d4fa4c86399979e956235deb900be5baa7fc1218ea67fbfaeaadd6f \ + --hash=sha256:b5a2b97dbdc7d4f353ebf343744f1d1f1cca8aa8bfddb4262fcf4306c3761d50 \ + --hash=sha256:b9a5ca4ac220a0cdd1ba6bcba3608547117d30468fefce49bb26f55c1a3d5c58 \ + --hash=sha256:bab485bcf8b1c7d2060b4fcb6fc368a9e6f4cd754c9c2fea281f4be21df394a2 \ + --hash=sha256:c108a1d6fa78a50646029cb6d49808ff0fc1330fda87fa6f6250c6b5369b6645 \ + --hash=sha256:d56a1efd5bfd61486c8bc968fa18734464556f0fb8e51690f4ac25d85cbbbbc2 \ + --hash=sha256:d9050fee89a89ed57b4fb2c1bfac9a3d0c57a0d55aed95949eedbc42070fea39 \ + --hash=sha256:dd80ecb295460a5d9d260df63c43f4afbdd832d725a531f008dad1664f458adf \ + --hash=sha256:e8ea3e2d4066083e264e75c829078f9e149fa119d27e19acd503de65e0b13149 \ + --hash=sha256:eb3823f11823deade26ce3b9f40dcb4a213da7a670013929f31d5f5ed1055b22 \ + --hash=sha256:ee40c27c795bda6a5292e9cff9890189d32f7e3a0bf04e0e3c9430c4a00c37df \ + --hash=sha256:efb30e3baaea72ce5928e32bab719ab4770099079d66726a62b11b1ef7273be4 \ + --hash=sha256:f254d118d14a7f99d616271d6c3c27922c092dac11112670b157798b89bf4933 \ + --hash=sha256:f89c151aab2e2e23cb3fe0acad1e8b82841fd265379c4cecd0f3fcb34c15e0f6 \ + --hash=sha256:f97aeb209c3d2511443f8797e3e5a569aebb040d4f8bc79aa3ee78a8fb9e3dd8 \ + --hash=sha256:f9b587c9c7274c1613a30afabf65a272114cd6cdbe67b3406f818c79d7ab2e2a \ + --hash=sha256:fb061f596dad3a0f52b60dc6a5dec4a0c300dec41e058a7efe09256188d170b7 + # via -r build/test-requirements.txt +mdurl==0.1.2 \ + --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ + --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba + # via markdown-it-py +ml-dtypes==0.5.4 \ + --hash=sha256:0d2ffd05a2575b1519dc928c0b93c06339eb67173ff53acb00724502cda231cf \ + --hash=sha256:11942cbf2cf92157db91e5022633c0d9474d4dfd813a909383bd23ce828a4b7d \ + --hash=sha256:14a4fd3228af936461db66faccef6e4f41c1d82fcc30e9f8d58a08916b1d811f \ + --hash=sha256:19b9a53598f21e453ea2fbda8aa783c20faff8e1eeb0d7ab899309a0053f1483 \ + --hash=sha256:2314892cdc3fcf05e373d76d72aaa15fda9fb98625effa73c1d646f331fcecb7 \ + --hash=sha256:2b857d3af6ac0d39db1de7c706e69c7f9791627209c3d6dedbfca8c7e5faec22 \ + --hash=sha256:304ad47faa395415b9ccbcc06a0350800bc50eda70f0e45326796e27c62f18b6 \ + --hash=sha256:35f29491a3e478407f7047b8a4834e4640a77d2737e0b294d049746507af5175 \ + --hash=sha256:388d399a2152dd79a3f0456a952284a99ee5c93d3e2f8dfe25977511e0515270 \ + --hash=sha256:3bbbe120b915090d9dd1375e4684dd17a20a2491ef25d640a908281da85e73f1 \ + --hash=sha256:3d277bf3637f2a62176f4575512e9ff9ef51d00e39626d9fe4a161992f355af2 \ + --hash=sha256:4381fe2f2452a2d7589689693d3162e876b3ddb0a832cde7a414f8e1adf7eab1 \ + --hash=sha256:4ff7f3e7ca2972e7de850e7b8fcbb355304271e2933dd90814c1cb847414d6e2 \ + --hash=sha256:531eff30e4d368cb6255bc2328d070e35836aa4f282a0fb5f3a0cd7260257298 \ + --hash=sha256:533ce891ba774eabf607172254f2e7260ba5f57bdd64030c9a4fcfbd99815d0d \ + --hash=sha256:557a31a390b7e9439056644cb80ed0735a6e3e3bb09d67fd5687e4b04238d1de \ + --hash=sha256:5a0f68ca8fd8d16583dfa7793973feb86f2fbb56ce3966daf9c9f748f52a2049 \ + --hash=sha256:6a0df4223b514d799b8a1629c65ddc351b3efa833ccf7f8ea0cf654a61d1e35d \ + --hash=sha256:6c7ecb74c4bd71db68a6bea1edf8da8c34f3d9fe218f038814fd1d310ac76c90 \ + --hash=sha256:7c23c54a00ae43edf48d44066a7ec31e05fdc2eee0be2b8b50dd1903a1db94bb \ + --hash=sha256:805cef3a38f4eafae3a5bf9ebdcdb741d0bcfd9e1bd90eb54abd24f928cd2465 \ + --hash=sha256:88c982aac7cb1cbe8cbb4e7f253072b1df872701fcaf48d84ffbb433b6568f24 \ + --hash=sha256:8ab06a50fb9bf9666dd0fe5dfb4676fa2b0ac0f31ecff72a6c3af8e22c063453 \ + --hash=sha256:8c6a2dcebd6f3903e05d51960a8058d6e131fe69f952a5397e5dbabc841b6d56 \ + --hash=sha256:8c760d85a2f82e2bed75867079188c9d18dae2ee77c25a54d60e9cc79be1bc48 \ + --hash=sha256:9ad459e99793fa6e13bd5b7e6792c8f9190b4e5a1b45c63aba14a4d0a7f1d5ff \ + --hash=sha256:9bad06436568442575beb2d03389aa7456c690a5b05892c471215bfd8cf39460 \ + --hash=sha256:a174837a64f5b16cab6f368171a1a03a27936b31699d167684073ff1c4237dac \ + --hash=sha256:a7f7c643e8b1320fd958bf098aa7ecf70623a42ec5154e3be3be673f4c34d900 \ + --hash=sha256:a9b61c19040397970d18d7737375cffd83b1f36a11dd4ad19f83a016f736c3ef \ + --hash=sha256:b4b801ebe0b477be666696bda493a9be8356f1f0057a57f1e35cd26928823e5a \ + --hash=sha256:b95e97e470fe60ed493fd9ae3911d8da4ebac16bd21f87ffa2b7c588bf22ea2c \ + --hash=sha256:bc11d7e8c44a65115d05e2ab9989d1e045125d7be8e05a071a48bc76eb6d6040 \ + --hash=sha256:bfc534409c5d4b0bf945af29e5d0ab075eae9eecbb549ff8a29280db822f34f9 \ + --hash=sha256:c1a953995cccb9e25a4ae19e34316671e4e2edaebe4cf538229b1fc7109087b7 \ + --hash=sha256:cb73dccfc991691c444acc8c0012bee8f2470da826a92e3a20bb333b1a7894e6 \ + --hash=sha256:ce756d3a10d0c4067172804c9cc276ba9cc0ff47af9078ad439b075d1abdc29b \ + --hash=sha256:d81fdb088defa30eb37bf390bb7dde35d3a83ec112ac8e33d75ab28cc29dd8b0 \ + --hash=sha256:f21c9219ef48ca5ee78402d5cc831bd58ea27ce89beda894428bc67a52da5328 + # via + # -r build/requirements.in + # jaxlib + # tensorstore +mpmath==1.3.0 \ + --hash=sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f \ + --hash=sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c + # via -r build/test-requirements.txt +numpy==2.3.5 ; python_version >= "3.14" \ + --hash=sha256:00dc4e846108a382c5869e77c6ed514394bdeb3403461d25a829711041217d5b \ + --hash=sha256:0472f11f6ec23a74a906a00b48a4dcf3849209696dff7c189714511268d103ae \ + --hash=sha256:04822c00b5fd0323c8166d66c701dc31b7fbd252c100acd708c48f763968d6a3 \ + --hash=sha256:052e8c42e0c49d2575621c158934920524f6c5da05a1d3b9bab5d8e259e045f0 \ + --hash=sha256:09a1bea522b25109bf8e6f3027bd810f7c1085c64a0c7ce050c1676ad0ba010b \ + --hash=sha256:0cd00b7b36e35398fa2d16af7b907b65304ef8bb4817a550e06e5012929830fa \ + --hash=sha256:0d8163f43acde9a73c2a33605353a4f1bc4798745a8b1d73183b28e5b435ae28 \ + --hash=sha256:1062fde1dcf469571705945b0f221b73928f34a20c904ffb45db101907c3454e \ + --hash=sha256:11e06aa0af8c0f05104d56450d6093ee639e15f24ecf62d417329d06e522e017 \ + --hash=sha256:17531366a2e3a9e30762c000f2c43a9aaa05728712e25c11ce1dbe700c53ad41 \ + --hash=sha256:1978155dd49972084bd6ef388d66ab70f0c323ddee6f693d539376498720fb7e \ + --hash=sha256:1ed1ec893cff7040a02c8aa1c8611b94d395590d553f6b53629a4461dc7f7b63 \ + --hash=sha256:2dcd0808a421a482a080f89859a18beb0b3d1e905b81e617a188bd80422d62e9 \ + --hash=sha256:2e2eb32ddb9ccb817d620ac1d8dae7c3f641c1e5f55f531a33e8ab97960a75b8 \ + --hash=sha256:2feae0d2c91d46e59fcd62784a3a83b3fb677fead592ce51b5a6fbb4f95965ff \ + --hash=sha256:3095bdb8dd297e5920b010e96134ed91d852d81d490e787beca7e35ae1d89cf7 \ + --hash=sha256:30bc11310e8153ca664b14c5f1b73e94bd0503681fcf136a163de856f3a50139 \ + --hash=sha256:3101e5177d114a593d79dd79658650fe28b5a0d8abeb8ce6f437c0e6df5be1a4 \ + --hash=sha256:396084a36abdb603546b119d96528c2f6263921c50df3c8fd7cb28873a237748 \ + --hash=sha256:3997b5b3c9a771e157f9aae01dd579ee35ad7109be18db0e85dbdbe1de06e952 \ + --hash=sha256:414802f3b97f3c1eef41e530aaba3b3c1620649871d8cb38c6eaff034c2e16bd \ + --hash=sha256:51c1e14eb1e154ebd80e860722f9e6ed6ec89714ad2db2d3aa33c31d7c12179b \ + --hash=sha256:51c55fe3451421f3a6ef9a9c1439e82101c57a2c9eab9feb196a62b1a10b58ce \ + --hash=sha256:5ee6609ac3604fa7780e30a03e5e241a7956f8e2fcfe547d51e3afa5247ac47f \ + --hash=sha256:612a95a17655e213502f60cfb9bf9408efdc9eb1d5f50535cc6eb365d11b42b5 \ + --hash=sha256:6203fdf9f3dc5bdaed7319ad8698e685c7a3be10819f41d32a0723e611733b42 \ + --hash=sha256:63c0e9e7eea69588479ebf4a8a270d5ac22763cc5854e9a7eae952a3908103f7 \ + --hash=sha256:66f85ce62c70b843bab1fb14a05d5737741e74e28c7b8b5a064de10142fad248 \ + --hash=sha256:6cf9b429b21df6b99f4dee7a1218b8b7ffbbe7df8764dc0bd60ce8a0708fed1e \ + --hash=sha256:70b37199913c1bd300ff6e2693316c6f869c7ee16378faf10e4f5e3275b299c3 \ + --hash=sha256:727fd05b57df37dc0bcf1a27767a3d9a78cbbc92822445f32cc3436ba797337b \ + --hash=sha256:74ae7b798248fe62021dbf3c914245ad45d1a6b0cb4a29ecb4b31d0bfbc4cc3e \ + --hash=sha256:784db1dcdab56bf0517743e746dfb0f885fc68d948aba86eeec2cba234bdf1c0 \ + --hash=sha256:86945f2ee6d10cdfd67bcb4069c1662dd711f7e2a4343db5cecec06b87cf31aa \ + --hash=sha256:86d835afea1eaa143012a2d7a3f45a3adce2d7adc8b4961f0b362214d800846a \ + --hash=sha256:872a5cf366aec6bb1147336480fef14c9164b154aeb6542327de4970282cd2f5 \ + --hash=sha256:8b973c57ff8e184109db042c842423ff4f60446239bd585a5131cc47f06f789d \ + --hash=sha256:8cba086a43d54ca804ce711b2a940b16e452807acebe7852ff327f1ecd49b0d4 \ + --hash=sha256:8f7f0e05112916223d3f438f293abf0727e1181b5983f413dfa2fefc4098245c \ + --hash=sha256:900218e456384ea676e24ea6a0417f030a3b07306d29d7ad843957b40a9d8d52 \ + --hash=sha256:93eebbcf1aafdf7e2ddd44c2923e2672e1010bddc014138b229e49725b4d6be5 \ + --hash=sha256:9c75442b2209b8470d6d5d8b1c25714270686f14c749028d2199c54e29f20b4d \ + --hash=sha256:9ee2197ef8c4f0dfe405d835f3b6a14f5fee7782b5de51ba06fb65fc9b36e9f1 \ + --hash=sha256:a414504bef8945eae5f2d7cb7be2d4af77c5d1cb5e20b296c2c25b61dff2900c \ + --hash=sha256:a4b9159734b326535f4dd01d947f919c6eefd2d9827466a696c44ced82dfbc18 \ + --hash=sha256:a80afd79f45f3c4a7d341f13acbe058d1ca8ac017c165d3fa0d3de6bc1a079d7 \ + --hash=sha256:aa5bc7c5d59d831d9773d1170acac7893ce3a5e130540605770ade83280e7188 \ + --hash=sha256:acfd89508504a19ed06ef963ad544ec6664518c863436306153e13e94605c218 \ + --hash=sha256:aeffcab3d4b43712bb7a60b65f6044d444e75e563ff6180af8f98dd4b905dfd2 \ + --hash=sha256:afaffc4393205524af9dfa400fa250143a6c3bc646c08c9f5e25a9f4b4d6a903 \ + --hash=sha256:b0c7088a73aef3d687c4deef8452a3ac7c1be4e29ed8bf3b366c8111128ac60c \ + --hash=sha256:b46b4ec24f7293f23adcd2d146960559aaf8020213de8ad1909dba6c013bf89c \ + --hash=sha256:b501b5fa195cc9e24fe102f21ec0a44dffc231d2af79950b451e0d99cea02234 \ + --hash=sha256:bf06bc2af43fa8d32d30fae16ad965663e966b1a3202ed407b84c989c3221e82 \ + --hash=sha256:c804e3a5aba5460c73955c955bdbd5c08c354954e9270a2c1565f62e866bdc39 \ + --hash=sha256:c8a9958e88b65c3b27e22ca2a076311636850b612d6bbfb76e8d156aacde2aaf \ + --hash=sha256:cc0a57f895b96ec78969c34f682c602bf8da1a0270b09bc65673df2e7638ec20 \ + --hash=sha256:cc8920d2ec5fa99875b670bb86ddeb21e295cb07aa331810d9e486e0b969d946 \ + --hash=sha256:ccc933afd4d20aad3c00bcef049cb40049f7f196e0397f1109dba6fed63267b0 \ + --hash=sha256:ce581db493ea1a96c0556360ede6607496e8bf9b3a8efa66e06477267bc831e9 \ + --hash=sha256:d0f23b44f57077c1ede8c5f26b30f706498b4862d3ff0a7298b8411dd2f043ff \ + --hash=sha256:d21644de1b609825ede2f48be98dfde4656aefc713654eeee280e37cadc4e0ad \ + --hash=sha256:d6889ec4ec662a1a37eb4b4fb26b6100841804dac55bd9df579e326cdc146227 \ + --hash=sha256:de5672f4a7b200c15a4127042170a694d4df43c992948f5e1af57f0174beed10 \ + --hash=sha256:e6a0bc88393d65807d751a614207b7129a310ca4fe76a74e5c7da5fa5671417e \ + --hash=sha256:ed89927b86296067b4f81f108a2271d8926467a8868e554eaf370fc27fa3ccaf \ + --hash=sha256:ee3888d9ff7c14604052b2ca5535a30216aa0a58e948cdd3eeb8d3415f638769 \ + --hash=sha256:f0963b55cdd70fad460fa4c1341f12f976bb26cb66021a5580329bd498988310 \ + --hash=sha256:f16417ec91f12f814b10bafe79ef77e70113a2f5f7018640e7425ff979253425 \ + --hash=sha256:f28620fe26bee16243be2b7b874da327312240a7cdc38b769a697578d2100013 \ + --hash=sha256:f4255143f5160d0de972d28c8f9665d882b5f61309d8362fdd3e103cf7bf010c \ + --hash=sha256:ffac52f28a7849ad7576293c0cb7b9f08304e8f7d738a8cb8a90ec4c55a998eb \ + --hash=sha256:ffe22d2b05504f786c867c8395de703937f934272eb67586817b46188b4ded6d \ + --hash=sha256:fffe29a1ef00883599d1dc2c51aa2e5d80afe49523c261a74933df395c15c520 + # via + # -r build/nonfreethreading-requirements.txt + # contourpy + # jaxlib + # matplotlib + # ml-dtypes + # numpy-typing-compat + # optype + # scipy + # tensorstore +numpy-typing-compat==20251206.2.4 \ + --hash=sha256:59882d23aaff054a2536da80564012cdce33487657be4d79c5925bb8705fcabc \ + --hash=sha256:a82e723bd20efaa4cf2886709d4264c144f1f2b609bda83d1545113b7e47a5b5 + # via optype +nvidia-cublas==13.1.0.3 ; sys_platform == "linux" \ + --hash=sha256:2a3b94a37def342471c59fad7856caee4926809a72dd5270155d6a31b5b277be \ + --hash=sha256:c86fc7f7ae36d7528288c5d88098edcb7b02c633d262e7ddbb86b0ad91be5df2 \ + --hash=sha256:ee8722c1f0145ab246bccb9e452153b5e0515fd094c3678df50b2a0888b8b171 + # via + # -r build/nvidia-requirements.txt + # nvidia-cudnn-cu13 + # nvidia-cusolver +nvidia-cublas-cu12==12.9.1.4 ; sys_platform == "linux" \ + --hash=sha256:1e5fee10662e6e52bd71dec533fbbd4971bb70a5f24f3bc3793e5c2e9dc640bf \ + --hash=sha256:453611eb21a7c1f2c2156ed9f3a45b691deda0440ec550860290dc901af5b4c2 \ + --hash=sha256:7a950dae01add3b415a5a5cdc4ec818fb5858263e9cca59004bb99fdbbd3a5d6 + # via + # -r build/nvidia-requirements.txt + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 +nvidia-cuda-crt==13.0.88 ; sys_platform == "linux" \ + --hash=sha256:2c8043c7c9e02492716426e9919fc78d2c5b3b2a7a768a88e952676b08aa55a4 \ + --hash=sha256:31e02c52916804ca15e31f272a96181d8fadaf40c4c82a77a6f78071a22eccf3 \ + --hash=sha256:ee2ea2a97073e02ee62bb27841f437332be2c248e3eac013df07997ada39c003 + # via + # -r build/nvidia-requirements.txt + # nvidia-cuda-nvcc +nvidia-cuda-cupti==13.0.85 ; sys_platform == "linux" \ + --hash=sha256:4eb01c08e859bf924d222250d2e8f8b8ff6d3db4721288cf35d14252a4d933c8 \ + --hash=sha256:683f58d301548deeefcb8f6fac1b8d907691b9d8b18eccab417f51e362102f00 \ + --hash=sha256:796bd679890ee55fb14a94629b698b6db54bcfd833d391d5e94017dd9d7d3151 + # via -r build/nvidia-requirements.txt +nvidia-cuda-cupti-cu12==12.9.79 ; sys_platform == "linux" \ + --hash=sha256:096bcf334f13e1984ba36685ad4c1d6347db214de03dbb6eebb237b41d9d934f \ + --hash=sha256:1848a9380067560d5bee10ed240eecc22991713e672c0515f9c3d9396adf93c8 \ + --hash=sha256:791853b030602c6a11d08b5578edfb957cadea06e9d3b26adbf8d036135a4afe + # via -r build/nvidia-requirements.txt +nvidia-cuda-nvcc==13.0.88 ; sys_platform == "linux" \ + --hash=sha256:56fe502eb77625a12f25172caa3cdddb4e4c8ba2c8c17dba44b164761b380f03 \ + --hash=sha256:7c3a32c8ca9866addfd784da363ddee2f6874d560027a296f583e86a61f2d543 \ + --hash=sha256:c7ff28f86a24effdc6c034fa15230c549a273e4771b10a7fec14996f8cf3307f + # via -r build/nvidia-requirements.txt +nvidia-cuda-nvcc-cu12==12.9.86 ; sys_platform == "linux" \ + --hash=sha256:44e1eca4d08926193a558d2434b1bf83d57b4d5743e0c431c0c83d51da1df62b \ + --hash=sha256:5d6a0d32fdc7ea39917c20065614ae93add6f577d840233237ff08e9a38f58f0 \ + --hash=sha256:8ed7f0b17dea662755395be029376db3b94fed5cbb17c2d35cc866c5b1b84099 + # via -r build/nvidia-requirements.txt +nvidia-cuda-nvrtc==13.0.88 ; sys_platform == "linux" \ + --hash=sha256:6bcd4e7f8e205cbe644f5a98f2f799bef9556fefc89dd786e79a16312ce49872 \ + --hash=sha256:ad9b6d2ead2435f11cbb6868809d2adeeee302e9bb94bcf0539c7a40d80e8575 \ + --hash=sha256:d27f20a0ca67a4bb34268a5e951033496c5b74870b868bacd046b1b8e0c3267b + # via -r build/nvidia-requirements.txt +nvidia-cuda-nvrtc-cu12==12.9.86 ; sys_platform == "linux" \ + --hash=sha256:096d4de6bda726415dfaf3198d4f5c522b8e70139c97feef5cd2ca6d4cd9cead \ + --hash=sha256:210cf05005a447e29214e9ce50851e83fc5f4358df8b453155d5e1918094dcb4 \ + --hash=sha256:72972ebdcf504d69462d3bcd67e7b81edd25d0fb85a2c46d3ea3517666636349 + # via -r build/nvidia-requirements.txt +nvidia-cuda-runtime==13.0.96 ; sys_platform == "linux" \ + --hash=sha256:7f82250d7782aa23b6cfe765ecc7db554bd3c2870c43f3d1821f1d18aebf0548 \ + --hash=sha256:ef9bcbe90493a2b9d810e43d249adb3d02e98dd30200d86607d8d02687c43f55 \ + --hash=sha256:f79298c8a098cec150a597c8eba58ecdab96e3bdc4b9bc4f9983635031740492 + # via + # -r build/nvidia-requirements.txt + # nvidia-cuda-nvcc +nvidia-cuda-runtime-cu12==12.9.79 ; sys_platform == "linux" \ + --hash=sha256:25bba2dfb01d48a9b59ca474a1ac43c6ebf7011f1b0b8cc44f54eb6ac48a96c3 \ + --hash=sha256:83469a846206f2a733db0c42e223589ab62fd2fabac4432d2f8802de4bded0a4 \ + --hash=sha256:8e018af8fa02363876860388bd10ccb89eb9ab8fb0aa749aaf58430a9f7c4891 + # via -r build/nvidia-requirements.txt +nvidia-cudnn-cu12==9.16.0.29 ; sys_platform == "linux" \ + --hash=sha256:142e2bd646a4573ab17d61a24c6359155cdfe1f34c67fc305b71222a7ae45b8e \ + --hash=sha256:4b09c43096db582f110c5572d0bcbd98b30d709e860a8f73c6c3846baa83b8d2 \ + --hash=sha256:78d05b4434dacc7dd9bc903d5c33a2f28a5f0064d02568ef7b2418f89f6c5922 + # via -r build/nvidia-requirements.txt +nvidia-cudnn-cu13==9.16.0.29 ; sys_platform == "linux" \ + --hash=sha256:6349bc8769369a91611c5e2ce5c2e510e61848c245099c31e870d2cdce0ab90d \ + --hash=sha256:79dc1bfe8c1a780cf4eb7b334d14d7927576d6dd8823f8e2769911af30fd4da3 \ + --hash=sha256:faafa46e2e7dd844bbcf06b6adec3fa66924987f2fb21bf67f5c6fd697c74a64 + # via -r build/nvidia-requirements.txt +nvidia-cufft==12.0.0.61 ; sys_platform == "linux" \ + --hash=sha256:2708c852ef8cd89d1d2068bdbece0aa188813a0c934db3779b9b1faa8442e5f5 \ + --hash=sha256:2abce5b39d2f5ae12730fb7e5db6696533e36c26e2d3e8fd1750bdd2853364eb \ + --hash=sha256:6c44f692dce8fd5ffd3e3df134b6cdb9c2f72d99cf40b62c32dde45eea9ddad3 + # via -r build/nvidia-requirements.txt +nvidia-cufft-cu12==11.4.1.4 ; sys_platform == "linux" \ + --hash=sha256:1a28c9b12260a1aa7a8fd12f5ebd82d027963d635ba82ff39a1acfa7c4c0fbcf \ + --hash=sha256:8e5bfaac795e93f80611f807d42844e8e27e340e0cde270dcb6c65386d795b80 \ + --hash=sha256:c67884f2a7d276b4b80eb56a79322a95df592ae5e765cf1243693365ccab4e28 + # via -r build/nvidia-requirements.txt +nvidia-cusolver==12.0.4.66 ; sys_platform == "linux" \ + --hash=sha256:02c2457eaa9e39de20f880f4bd8820e6a1cfb9f9a34f820eb12a155aa5bc92d2 \ + --hash=sha256:0a759da5dea5c0ea10fd307de75cdeb59e7ea4fcb8add0924859b944babf1112 \ + --hash=sha256:16515bd33a8e76bb54d024cfa068fa68d30e80fc34b9e1090813ea9362e0cb65 + # via -r build/nvidia-requirements.txt +nvidia-cusolver-cu12==11.7.5.82 ; sys_platform == "linux" \ + --hash=sha256:15da72d1340d29b5b3cf3fd100e3cd53421dde36002eda6ed93811af63c40d88 \ + --hash=sha256:62efa83e4ace59a4c734d052bb72158e888aa7b770e1a5f601682f16fe5b4fd2 \ + --hash=sha256:77666337237716783c6269a658dea310195cddbd80a5b2919b1ba8735cec8efd + # via -r build/nvidia-requirements.txt +nvidia-cusparse==12.6.3.3 ; sys_platform == "linux" \ + --hash=sha256:2b3c89c88d01ee0e477cb7f82ef60a11a4bcd57b6b87c33f789350b59759360b \ + --hash=sha256:80bcc4662f23f1054ee334a15c72b8940402975e0eab63178fc7e670aa59472c \ + --hash=sha256:cbcf42feb737bd7ec15b4c0a63e62351886bd3f975027b8815d7f720a2b5ea79 + # via + # -r build/nvidia-requirements.txt + # nvidia-cusolver +nvidia-cusparse-cu12==12.5.10.65 ; sys_platform == "linux" \ + --hash=sha256:221c73e7482dd93eda44e65ce567c031c07e2f93f6fa0ecd3ba876a195023e83 \ + --hash=sha256:73060ce019ac064a057267c585bf1fd5a353734151f87472ff02b2c5c9984e78 \ + --hash=sha256:9e487468a22a1eaf1fbd1d2035936a905feb79c4ce5c2f67626764ee4f90227c + # via + # -r build/nvidia-requirements.txt + # nvidia-cusolver-cu12 +nvidia-nccl-cu12==2.28.9 ; sys_platform == "linux" \ + --hash=sha256:485776daa8447da5da39681af455aa3b2c2586ddcf4af8772495e7c532c7e5ab \ + --hash=sha256:50a36e01c4a090b9f9c47d92cec54964de6b9fcb3362d0e19b8ffc6323c21b60 + # via -r build/nvidia-requirements.txt +nvidia-nccl-cu13==2.28.9 ; sys_platform == "linux" \ + --hash=sha256:01c873ba1626b54caa12272ed228dc5b2781545e0ae8ba3f432a8ef1c6d78643 \ + --hash=sha256:e4553a30f34195f3fa1da02a6da3d6337d28f2003943aa0a3d247bbc25fefc42 + # via -r build/nvidia-requirements.txt +nvidia-nvjitlink==13.0.88 ; sys_platform == "linux" \ + --hash=sha256:13a74f429e23b921c1109976abefacc69835f2f433ebd323d3946e11d804e47b \ + --hash=sha256:634e96e3da9ef845ae744097a1f289238ecf946ce0b82e93cdce14b9782e682f \ + --hash=sha256:e931536ccc7d467a98ba1d8b89ff7fa7f1fa3b13f2b0069118cd7f47bff07d0c + # via + # -r build/nvidia-requirements.txt + # nvidia-cufft + # nvidia-cusolver + # nvidia-cusparse +nvidia-nvjitlink-cu12==12.9.86 ; sys_platform == "linux" \ + --hash=sha256:994a05ef08ef4b0b299829cde613a424382aff7efb08a7172c1fa616cc3af2ca \ + --hash=sha256:cc6fcec260ca843c10e34c936921a1c426b351753587fdd638e8cff7b16bb9db \ + --hash=sha256:e3f1171dbdc83c5932a45f0f4c99180a70de9bd2718c1ab77d14104f6d7147f9 + # via + # -r build/nvidia-requirements.txt + # nvidia-cufft-cu12 + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 +nvidia-nvshmem-cu12==3.4.5 ; sys_platform == "linux" \ + --hash=sha256:042f2500f24c021db8a06c5eec2539027d57460e1c1a762055a6554f72c369bd \ + --hash=sha256:0b48363fc6964dede448029434c6abed6c5e37f823cb43c3bcde7ecfc0457e15 + # via -r build/nvidia-requirements.txt +nvidia-nvshmem-cu13==3.4.5 ; sys_platform == "linux" \ + --hash=sha256:290f0a2ee94c9f3687a02502f3b9299a9f9fe826e6d0287ee18482e78d495b80 \ + --hash=sha256:6dc2a197f38e5d0376ad52cd1a2a3617d3cdc150fd5966f4aee9bcebb1d68fe9 + # via -r build/nvidia-requirements.txt +nvidia-nvvm==13.0.88 ; sys_platform == "linux" \ + --hash=sha256:2ef0db7849e476d3b2fc3c09b27bdd79bd7ea8ce58cd9c86553d64ea40844ba0 \ + --hash=sha256:c4376a291d72d22a315d9d2f69bdae8f8cd83a627f75bad395cee49a0fe65dc1 \ + --hash=sha256:c5f41ffeb6466944a026dfa5317d7d85355c119bbec279205d22f1869d1054e0 + # via + # -r build/nvidia-requirements.txt + # nvidia-cuda-nvcc +opt-einsum==3.4.0 \ + --hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \ + --hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac + # via -r build/requirements.in +optype[numpy]==0.15.0 \ + --hash=sha256:457d6ca9e7da19967ec16d42bdf94e240b33b5d70a56fbbf5b427e5ea39cf41e \ + --hash=sha256:caba40ece9ea39b499fa76c036a82e0d452a432dd4dd3e8e0d30892be2e8c76c + # via scipy-stubs +packaging==25.0 \ + --hash=sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484 \ + --hash=sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f + # via + # auditwheel + # build + # matplotlib + # pytest + # wheel +pillow==12.0.0 \ + --hash=sha256:0869154a2d0546545cde61d1789a6524319fc1897d9ee31218eae7a60ccc5643 \ + --hash=sha256:09f2d0abef9e4e2f349305a4f8cc784a8a6c2f58a8c4892eea13b10a943bd26e \ + --hash=sha256:0b817e7035ea7f6b942c13aa03bb554fc44fea70838ea21f8eb31c638326584e \ + --hash=sha256:0fd00cac9c03256c8b2ff58f162ebcd2587ad3e1f2e397eab718c47e24d231cc \ + --hash=sha256:110486b79f2d112cf6add83b28b627e369219388f64ef2f960fef9ebaf54c642 \ + --hash=sha256:1979f4566bb96c1e50a62d9831e2ea2d1211761e5662afc545fa766f996632f6 \ + --hash=sha256:1ac11e8ea4f611c3c0147424eae514028b5e9077dd99ab91e1bd7bc33ff145e1 \ + --hash=sha256:1b1b133e6e16105f524a8dec491e0586d072948ce15c9b914e41cdadd209052b \ + --hash=sha256:1ee80a59f6ce048ae13cda1abf7fbd2a34ab9ee7d401c46be3ca685d1999a399 \ + --hash=sha256:21f241bdd5080a15bc86d3466a9f6074a9c2c2b314100dd896ac81ee6db2f1ba \ + --hash=sha256:266cd5f2b63ff316d5a1bba46268e603c9caf5606d44f38c2873c380950576ad \ + --hash=sha256:26d9f7d2b604cd23aba3e9faf795787456ac25634d82cd060556998e39c6fa47 \ + --hash=sha256:27f95b12453d165099c84f8a8bfdfd46b9e4bda9e0e4b65f0635430027f55739 \ + --hash=sha256:2c54c1a783d6d60595d3514f0efe9b37c8808746a66920315bfd34a938d7994b \ + --hash=sha256:2fa5f0b6716fc88f11380b88b31fe591a06c6315e955c096c35715788b339e3f \ + --hash=sha256:32ed80ea8a90ee3e6fa08c21e2e091bba6eda8eccc83dbc34c95169507a91f10 \ + --hash=sha256:3830c769decf88f1289680a59d4f4c46c72573446352e2befec9a8512104fa52 \ + --hash=sha256:38df9b4bfd3db902c9c2bd369bcacaf9d935b2fff73709429d95cc41554f7b3d \ + --hash=sha256:3adfb466bbc544b926d50fe8f4a4e6abd8c6bffd28a26177594e6e9b2b76572b \ + --hash=sha256:3e42edad50b6909089750e65c91aa09aaf1e0a71310d383f11321b27c224ed8a \ + --hash=sha256:4078242472387600b2ce8d93ade8899c12bf33fa89e55ec89fe126e9d6d5d9e9 \ + --hash=sha256:455247ac8a4cfb7b9bc45b7e432d10421aea9fc2e74d285ba4072688a74c2e9d \ + --hash=sha256:4cc6b3b2efff105c6a1656cfe59da4fdde2cda9af1c5e0b58529b24525d0a098 \ + --hash=sha256:4cf7fed4b4580601c4345ceb5d4cbf5a980d030fd5ad07c4d2ec589f95f09905 \ + --hash=sha256:5193fde9a5f23c331ea26d0cf171fbf67e3f247585f50c08b3e205c7aeb4589b \ + --hash=sha256:5269cc1caeedb67e6f7269a42014f381f45e2e7cd42d834ede3c703a1d915fe3 \ + --hash=sha256:53561a4ddc36facb432fae7a9d8afbfaf94795414f5cdc5fc52f28c1dca90371 \ + --hash=sha256:55f818bd74fe2f11d4d7cbc65880a843c4075e0ac7226bc1a23261dbea531953 \ + --hash=sha256:58eea5ebe51504057dd95c5b77d21700b77615ab0243d8152793dc00eb4faf01 \ + --hash=sha256:5d5c411a8eaa2299322b647cd932586b1427367fd3184ffbb8f7a219ea2041ca \ + --hash=sha256:6846bd2d116ff42cba6b646edf5bf61d37e5cbd256425fa089fee4ff5c07a99e \ + --hash=sha256:6ace95230bfb7cd79ef66caa064bbe2f2a1e63d93471c3a2e1f1348d9f22d6b7 \ + --hash=sha256:6e51b71417049ad6ab14c49608b4a24d8fb3fe605e5dfabfe523b58064dc3d27 \ + --hash=sha256:71db6b4c1653045dacc1585c1b0d184004f0d7e694c7b34ac165ca70c0838082 \ + --hash=sha256:7438839e9e053ef79f7112c881cef684013855016f928b168b81ed5835f3e75e \ + --hash=sha256:759de84a33be3b178a64c8ba28ad5c135900359e85fb662bc6e403ad4407791d \ + --hash=sha256:792a2c0be4dcc18af9d4a2dfd8a11a17d5e25274a1062b0ec1c2d79c76f3e7f8 \ + --hash=sha256:7d87ef5795da03d742bf49439f9ca4d027cde49c82c5371ba52464aee266699a \ + --hash=sha256:7dfb439562f234f7d57b1ac6bc8fe7f838a4bd49c79230e0f6a1da93e82f1fad \ + --hash=sha256:7fa22993bac7b77b78cae22bad1e2a987ddf0d9015c63358032f84a53f23cdc3 \ + --hash=sha256:805ebf596939e48dbb2e4922a1d3852cfc25c38160751ce02da93058b48d252a \ + --hash=sha256:82240051c6ca513c616f7f9da06e871f61bfd7805f566275841af15015b8f98d \ + --hash=sha256:87d4f8125c9988bfbed67af47dd7a953e2fc7b0cc1e7800ec6d2080d490bb353 \ + --hash=sha256:8d8ca2b210ada074d57fcee40c30446c9562e542fc46aedc19baf758a93532ee \ + --hash=sha256:8dc232e39d409036af549c86f24aed8273a40ffa459981146829a324e0848b4b \ + --hash=sha256:90387104ee8400a7b4598253b4c406f8958f59fcf983a6cea2b50d59f7d63d0b \ + --hash=sha256:905b0365b210c73afb0ebe9101a32572152dfd1c144c7e28968a331b9217b94a \ + --hash=sha256:99353a06902c2e43b43e8ff74ee65a7d90307d82370604746738a1e0661ccca7 \ + --hash=sha256:99a7f72fb6249302aa62245680754862a44179b545ded638cf1fef59befb57ef \ + --hash=sha256:9f0b04c6b8584c2c193babcccc908b38ed29524b29dd464bc8801bf10d746a3a \ + --hash=sha256:9fe611163f6303d1619bbcb653540a4d60f9e55e622d60a3108be0d5b441017a \ + --hash=sha256:a3475b96f5908b3b16c47533daaa87380c491357d197564e0ba34ae75c0f3257 \ + --hash=sha256:a6597ff2b61d121172f5844b53f21467f7082f5fb385a9a29c01414463f93b07 \ + --hash=sha256:a7921c5a6d31b3d756ec980f2f47c0cfdbce0fc48c22a39347a895f41f4a6ea4 \ + --hash=sha256:aa5129de4e174daccbc59d0a3b6d20eaf24417d59851c07ebb37aeb02947987c \ + --hash=sha256:aeaefa96c768fc66818730b952a862235d68825c178f1b3ffd4efd7ad2edcb7c \ + --hash=sha256:afbefa430092f71a9593a99ab6a4e7538bc9eabbf7bf94f91510d3503943edc4 \ + --hash=sha256:aff9e4d82d082ff9513bdd6acd4f5bd359f5b2c870907d2b0a9c5e10d40c88fe \ + --hash=sha256:b22bd8c974942477156be55a768f7aa37c46904c175be4e158b6a86e3a6b7ca8 \ + --hash=sha256:b290fd8aa38422444d4b50d579de197557f182ef1068b75f5aa8558638b8d0a5 \ + --hash=sha256:b2e4b27a6e15b04832fe9bf292b94b5ca156016bbc1ea9c2c20098a0320d6cf6 \ + --hash=sha256:b583dc9070312190192631373c6c8ed277254aa6e6084b74bdd0a6d3b221608e \ + --hash=sha256:b87843e225e74576437fd5b6a4c2205d422754f84a06942cfaf1dc32243e45a8 \ + --hash=sha256:bc91a56697869546d1b8f0a3ff35224557ae7f881050e99f615e0119bf934b4e \ + --hash=sha256:bd87e140e45399c818fac4247880b9ce719e4783d767e030a883a970be632275 \ + --hash=sha256:bde737cff1a975b70652b62d626f7785e0480918dece11e8fef3c0cf057351c3 \ + --hash=sha256:bdee52571a343d721fb2eb3b090a82d959ff37fc631e3f70422e0c2e029f3e76 \ + --hash=sha256:bee2a6db3a7242ea309aa7ee8e2780726fed67ff4e5b40169f2c940e7eb09227 \ + --hash=sha256:beeae3f27f62308f1ddbcfb0690bf44b10732f2ef43758f169d5e9303165d3f9 \ + --hash=sha256:c50f36a62a22d350c96e49ad02d0da41dbd17ddc2e29750dbdba4323f85eb4a5 \ + --hash=sha256:c607c90ba67533e1b2355b821fef6764d1dd2cbe26b8c1005ae84f7aea25ff79 \ + --hash=sha256:c7b2a63fd6d5246349f3d3f37b14430d73ee7e8173154461785e43036ffa96ca \ + --hash=sha256:c828a1ae702fc712978bda0320ba1b9893d99be0badf2647f693cc01cf0f04fa \ + --hash=sha256:c85de1136429c524e55cfa4e033b4a7940ac5c8ee4d9401cc2d1bf48154bbc7b \ + --hash=sha256:c98fa880d695de164b4135a52fd2e9cd7b7c90a9d8ac5e9e443a24a95ef9248e \ + --hash=sha256:cae81479f77420d217def5f54b5b9d279804d17e982e0f2fa19b1d1e14ab5197 \ + --hash=sha256:d034140032870024e6b9892c692fe2968493790dd57208b2c37e3fb35f6df3ab \ + --hash=sha256:d120c38a42c234dc9a8c5de7ceaaf899cf33561956acb4941653f8bdc657aa79 \ + --hash=sha256:d4827615da15cd59784ce39d3388275ec093ae3ee8d7f0c089b76fa87af756c2 \ + --hash=sha256:d49e2314c373f4c2b39446fb1a45ed333c850e09d0c59ac79b72eb3b95397363 \ + --hash=sha256:d52610d51e265a51518692045e372a4c363056130d922a7351429ac9f27e70b0 \ + --hash=sha256:d64317d2587c70324b79861babb9c09f71fbb780bad212018874b2c013d8600e \ + --hash=sha256:d77153e14b709fd8b8af6f66a3afbb9ed6e9fc5ccf0b6b7e1ced7b036a228782 \ + --hash=sha256:d7e091d464ac59d2c7ad8e7e08105eaf9dafbc3883fd7265ffccc2baad6ac925 \ + --hash=sha256:dd333073e0cacdc3089525c7df7d39b211bcdf31fc2824e49d01c6b6187b07d0 \ + --hash=sha256:e5d8efac84c9afcb40914ab49ba063d94f5dbdf5066db4482c66a992f47a3a3b \ + --hash=sha256:f135c702ac42262573fe9714dfe99c944b4ba307af5eb507abef1667e2cbbced \ + --hash=sha256:f13711b1a5ba512d647a0e4ba79280d3a9a045aaf7e0cc6fbe96b91d4cdf6b0c \ + --hash=sha256:f4f1231b7dec408e8670264ce63e9c71409d9583dd21d32c163e25213ee2a344 \ + --hash=sha256:fa3ed2a29a9e9d2d488b4da81dcb54720ac3104a20bf0bd273f1e4648aff5af9 \ + --hash=sha256:fb3096c30df99fd01c7bf8e544f392103d0795b9f98ba71a8054bcbf56b255f1 + # via + # -r build/test-requirements.txt + # matplotlib +pluggy==1.6.0 \ + --hash=sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3 \ + --hash=sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746 + # via pytest +portpicker==1.6.0 \ + --hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \ + --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa + # via -r build/test-requirements.txt +psutil==7.1.3 \ + --hash=sha256:0005da714eee687b4b8decd3d6cc7c6db36215c9e74e5ad2264b90c3df7d92dc \ + --hash=sha256:1068c303be3a72f8e18e412c5b2a8f6d31750fb152f9cb106b54090296c9d251 \ + --hash=sha256:18349c5c24b06ac5612c0428ec2a0331c26443d259e2a0144a9b24b4395b58fa \ + --hash=sha256:19644c85dcb987e35eeeaefdc3915d059dac7bd1167cdcdbf27e0ce2df0c08c0 \ + --hash=sha256:2bdbcd0e58ca14996a42adf3621a6244f1bb2e2e528886959c72cf1e326677ab \ + --hash=sha256:31d77fcedb7529f27bb3a0472bea9334349f9a04160e8e6e5020f22c59893264 \ + --hash=sha256:3792983e23b69843aea49c8f5b8f115572c5ab64c153bada5270086a2123c7e7 \ + --hash=sha256:3bb428f9f05c1225a558f53e30ccbad9930b11c3fc206836242de1091d3e7dd3 \ + --hash=sha256:56d974e02ca2c8eb4812c3f76c30e28836fffc311d55d979f1465c1feeb2b68b \ + --hash=sha256:6c86281738d77335af7aec228328e944b30930899ea760ecf33a4dba66be5e74 \ + --hash=sha256:8f33a3702e167783a9213db10ad29650ebf383946e91bc77f28a5eb083496bc9 \ + --hash=sha256:95ef04cf2e5ba0ab9eaafc4a11eaae91b44f4ef5541acd2ee91d9108d00d59a7 \ + --hash=sha256:ad81425efc5e75da3f39b3e636293360ad8d0b49bed7df824c79764fb4ba9b8b \ + --hash=sha256:b403da1df4d6d43973dc004d19cee3b848e998ae3154cc8097d139b77156c353 \ + --hash=sha256:bc31fa00f1fbc3c3802141eede66f3a2d51d89716a194bf2cd6fc68310a19880 \ + --hash=sha256:bd0d69cee829226a761e92f28140bec9a5ee9d5b4fb4b0cc589068dbfff559b1 \ + --hash=sha256:c525ffa774fe4496282fb0b1187725793de3e7c6b29e41562733cae9ada151ee \ + --hash=sha256:f39c2c19fe824b47484b96f9692932248a54c43799a84282cfe58d05a6449efd \ + --hash=sha256:fac9cd332c67f4422504297889da5ab7e05fd11e3c4392140f7370f4208ded1f + # via portpicker +pyelftools==0.32 \ + --hash=sha256:013df952a006db5e138b1edf6d8a68ecc50630adbd0d83a2d41e7f846163d738 \ + --hash=sha256:6de90ee7b8263e740c8715a925382d4099b354f29ac48ea40d840cf7aa14ace5 + # via auditwheel +pygments==2.19.2 \ + --hash=sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887 \ + --hash=sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b + # via + # pytest + # rich +pyparsing==3.2.5 \ + --hash=sha256:2df8d5b7b2802ef88e8d016a2eb9c7aeaa923529cd251ed0fe4608275d4105b6 \ + --hash=sha256:e38a4f02064cf41fe6593d328d0512495ad1f3d8a91c4f73fc401b3079a59a5e + # via matplotlib +pyproject-hooks==1.2.0 \ + --hash=sha256:1e859bd5c40fae9448642dd871adf459e5e2084186e8d2c2a79a824c970da1f8 \ + --hash=sha256:9e5c6bfa8dcc30091c74b0cf803c81fdd29d94f01992a7707bc97babb1141913 + # via build +pytest==8.4.2 \ + --hash=sha256:86c0d0b93306b961d58d62a4db4879f27fe25513d4b969df351abdddb3c30e01 \ + --hash=sha256:872f880de3fc3a5bdc88a11b39c9710c3497a547cfa9320bc3c5e62fbf272e79 + # via + # -r build/test-requirements.txt + # pytest-xdist +pytest-xdist==3.8.0 \ + --hash=sha256:202ca578cfeb7370784a8c33d6d05bc6e13b4f25b5053c30a152269fd10f0b88 \ + --hash=sha256:7e578125ec9bc6050861aa93f2d59f1d8d085595d6551c2c90b6f4fad8d3a9f1 + # via -r build/test-requirements.txt +python-dateutil==2.9.0.post0 \ + --hash=sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3 \ + --hash=sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427 + # via matplotlib +rich==14.2.0 \ + --hash=sha256:73ff50c7c0c1c77c8243079283f4edb376f0f6442433aecb8ce7e6d0b92d1fe4 \ + --hash=sha256:76bc51fe2e57d2b1be1f96c524b890b816e334ab4c1e45888799bfaab0021edd + # via -r build/test-requirements.txt +scipy==1.16.3 ; python_version >= "3.13" \ + --hash=sha256:0151a0749efeaaab78711c78422d413c583b8cdd2011a3c1d6c794938ee9fdb2 \ + --hash=sha256:01e87659402762f43bd2fee13370553a17ada367d42e7487800bf2916535aecb \ + --hash=sha256:03192a35e661470197556de24e7cb1330d84b35b94ead65c46ad6f16f6b28f2a \ + --hash=sha256:0553371015692a898e1aa858fed67a3576c34edefa6b7ebdb4e9dde49ce5c203 \ + --hash=sha256:062246acacbe9f8210de8e751b16fc37458213f124bef161a5a02c7a39284304 \ + --hash=sha256:0c3b4dd3d9b08dbce0f3440032c52e9e2ab9f96ade2d3943313dfe51a7056959 \ + --hash=sha256:0c623a54f7b79dd88ef56da19bc2873afec9673a48f3b85b18e4d402bdd29a5a \ + --hash=sha256:16b8bc35a4cc24db80a0ec836a9286d0e31b2503cb2fd7ff7fb0e0374a97081d \ + --hash=sha256:1fb2472e72e24d1530debe6ae078db70fb1605350c88a3d14bc401d6306dbffe \ + --hash=sha256:21d9d6b197227a12dcbf9633320a4e34c6b0e51c57268df255a0942983bac562 \ + --hash=sha256:2a207a6ce9c24f1951241f4693ede2d393f59c07abc159b2cb2be980820e01fb \ + --hash=sha256:2b71d93c8a9936046866acebc915e2af2e292b883ed6e2cbe5c34beb094b82d9 \ + --hash=sha256:2d1ae2cf0c350e7705168ff2429962a89ad90c2d49d1dd300686d8b2a5af22fc \ + --hash=sha256:3a4c460301fb2cffb7f88528f30b3127742cff583603aa7dc964a52c463b385d \ + --hash=sha256:3d4a07a8e785d80289dfe66b7c27d8634a773020742ec7187b85ccc4b0e7b686 \ + --hash=sha256:40be6cf99e68b6c4321e9f8782e7d5ff8265af28ef2cd56e9c9b2638fa08ad97 \ + --hash=sha256:4aff59800a3b7f786b70bfd6ab551001cb553244988d7d6b8299cb1ea653b353 \ + --hash=sha256:50a3dbf286dbc7d84f176f9a1574c705f277cb6565069f88f60db9eafdbe3ee2 \ + --hash=sha256:532fb5ad6a87e9e9cd9c959b106b73145a03f04c7d57ea3e6f6bb60b86ab0876 \ + --hash=sha256:53c3844d527213631e886621df5695d35e4f6a75f620dca412bcd292f6b87d78 \ + --hash=sha256:56edc65510d1331dae01ef9b658d428e33ed48b4f77b1d51caf479a0253f96dc \ + --hash=sha256:57d01cb6f85e34f0946b33caa66e892aae072b64b034183f3d87c4025802a119 \ + --hash=sha256:5803c5fadd29de0cf27fa08ccbfe7a9e5d741bf63e4ab1085437266f12460ff9 \ + --hash=sha256:6020470b9d00245926f2d5bb93b119ca0340f0d564eb6fbaad843eaebf9d690f \ + --hash=sha256:63d3cdacb8a824a295191a723ee5e4ea7768ca5ca5f2838532d9f2e2b3ce2135 \ + --hash=sha256:663b8d66a8748051c3ee9c96465fb417509315b99c71550fda2591d7dd634234 \ + --hash=sha256:72d1717fd3b5e6ec747327ce9bda32d5463f472c9dce9f54499e81fbd50245a1 \ + --hash=sha256:7dc1360c06535ea6116a2220f760ae572db9f661aba2d88074fe30ec2aa1ff88 \ + --hash=sha256:7f68154688c515cdb541a31ef8eb66d8cd1050605be9dcd74199cbd22ac739bc \ + --hash=sha256:81fc5827606858cf71446a5e98715ba0e11f0dbc83d71c7409d05486592a45d6 \ + --hash=sha256:875555ce62743e1d54f06cdf22c1e0bc47b91130ac40fe5d783b6dfa114beeb6 \ + --hash=sha256:8b3c820ddb80029fe9f43d61b81d8b488d3ef8ca010d15122b152db77dc94c22 \ + --hash=sha256:8be1ca9170fcb6223cc7c27f4305d680ded114a1567c0bd2bfcbf947d1b17511 \ + --hash=sha256:8d09d72dc92742988b0e7750bddb8060b0c7079606c0d24a8cc8e9c9c11f9079 \ + --hash=sha256:9452781bd879b14b6f055b26643703551320aa8d79ae064a71df55c00286a184 \ + --hash=sha256:96491a6a54e995f00a28a3c3badfff58fd093bf26cd5fb34a2188c8c756a3a2c \ + --hash=sha256:9b9c9c07b6d56a35777a1b4cc8966118fb16cfd8daf6743867d17d36cfad2d40 \ + --hash=sha256:a8a26c78ef223d3e30920ef759e25625a0ecdd0d60e5a8818b7513c3e5384cf2 \ + --hash=sha256:aadd23f98f9cb069b3bd64ddc900c4d277778242e961751f77a8cb5c4b946fb0 \ + --hash=sha256:b7180967113560cca57418a7bc719e30366b47959dd845a93206fbed693c867e \ + --hash=sha256:b7c5f1bda1354d6a19bc6af73a649f8285ca63ac6b52e64e658a5a11d4d69800 \ + --hash=sha256:b81c27fc41954319a943d43b20e07c40bdcd3ff7cf013f4fb86286faefe546c4 \ + --hash=sha256:bb61878c18a470021fb515a843dc7a76961a8daceaaaa8bad1332f1bf4b54657 \ + --hash=sha256:bea0a62734d20d67608660f69dcda23e7f90fb4ca20974ab80b6ed40df87a005 \ + --hash=sha256:c5192722cffe15f9329a3948c4b1db789fbb1f05c97899187dcf009b283aea70 \ + --hash=sha256:c97176013d404c7346bf57874eaac5187d969293bf40497140b0a2b2b7482e07 \ + --hash=sha256:cd13e354df9938598af2be05822c323e97132d5e6306b83a3b4ee6724c6e522e \ + --hash=sha256:d2ec56337675e61b312179a1ad124f5f570c00f920cc75e1000025451b88241c \ + --hash=sha256:d3837938ae715fc0fe3c39c0202de3a8853aff22ca66781ddc2ade7554b7e2cc \ + --hash=sha256:d9f48cafc7ce94cf9b15c6bffdc443a81a27bf7075cf2dcd5c8b40f85d10c4e7 \ + --hash=sha256:da7763f55885045036fabcebd80144b757d3db06ab0861415d1c3b7c69042146 \ + --hash=sha256:deb3841c925eeddb6afc1e4e4a45e418d19ec7b87c5df177695224078e8ec733 \ + --hash=sha256:e1d27cbcb4602680a49d787d90664fa4974063ac9d4134813332a8c53dbe667c \ + --hash=sha256:e5d42a9472e7579e473879a1990327830493a7047506d58d73fc429b84c1d49d \ + --hash=sha256:e7efa2681ea410b10dde31a52b18b0154d66f2485328830e45fdf183af5aefc6 \ + --hash=sha256:eab43fae33a0c39006a88096cd7b4f4ef545ea0447d250d5ac18202d40b6611d \ + --hash=sha256:f2622206f5559784fa5c4b53a950c3c7c1cf3e84ca1b9c4b6c03f062f289ca26 \ + --hash=sha256:f379b54b77a597aa7ee5e697df0d66903e41b9c85a6dd7946159e356319158e8 \ + --hash=sha256:f667a4542cc8917af1db06366d3f78a5c8e83badd56409f94d1eac8d8d9133fa \ + --hash=sha256:fb4b29f4cf8cc5a8d628bc8d8e26d12d7278cd1f219f22698a378c3d67db5e4b \ + --hash=sha256:ffa6eea95283b2b8079b821dc11f50a17d0571c92b43e2b5b12764dc5f9b285d + # via + # -r build/requirements.in + # jaxlib +scipy-stubs==1.16.3.3 \ + --hash=sha256:af47578875d5557567225a16ec1b9b38a48c4c4377d92396413ebd65406c44ee \ + --hash=sha256:f6316b36cd0fb272c994ae5b10c4a73c644a7e156ed8d32bcd9c35303d0e1b7e + # via -r build/test-requirements.txt +six==1.17.0 \ + --hash=sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274 \ + --hash=sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81 + # via python-dateutil +sortedcontainers==2.4.0 \ + --hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \ + --hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0 + # via hypothesis +tensorstore==0.1.80 \ + --hash=sha256:04c29d979eb8b8ee48f873dc13d2701bfd49425500ffc5b848e4ec55b2548281 \ + --hash=sha256:07e4a84bacf70b78305831897068a9b5ad30326e63bbeb92c4bf7e565fcf5e9e \ + --hash=sha256:1113a6982fc0fa8dda8fcc0495715e647ac3360909a86ff13f2e04564f82d54a \ + --hash=sha256:189d924eaec394c9331e284a9c513ed583e336472a925823b5151cb26f41d091 \ + --hash=sha256:1b2b2ed0051dfab7e25295b14e6620520729e6e2ddf505f98c8d3917569614bf \ + --hash=sha256:246641a8780ee5e04e88bc95c8e31faac6471bab1180d1f5cdc9804b29a77c04 \ + --hash=sha256:4158fe76b96f62d12a37d7868150d836e089b5280b2bdd363c43c5d651f10e26 \ + --hash=sha256:46136fe42ee6dd835d957db37073058aea0b78fdfbe2975941640131b7740824 \ + --hash=sha256:4baee67fce95f29f593fbab4866119347115eaace887732aa92cfcbb9e6b0748 \ + --hash=sha256:53fd121ccd332bc4cc397f7af45889360c668b43dc3ff6bc3264df0f9886c11a \ + --hash=sha256:6b7c5dd434bba4ee08fe46bbbdb25c60dd3d47ccb4b8561a9751cf1526da52b8 \ + --hash=sha256:6c8dbbdd31cbb28eccfb23dbbd4218fe67bfc32e9cb452875a485b81031c949d \ + --hash=sha256:7451b30f99d9f31a2b9d70e6ef61815713dc782c58c6d817f91781341e4dac05 \ + --hash=sha256:8cd11027b5a8b66db8d344085a31a1666c78621dac27039c4d571bc4974804a1 \ + --hash=sha256:9c088e8c9f67c266ef4dae3703bd617f7c0cb0fd98e99c4500692e38a4328140 \ + --hash=sha256:a92505189731fcb03f1c69a84ea4460abb24204bfac1f339448a0621e7def77c \ + --hash=sha256:acb8d52fadcefafef4ef8ecca3fc99b1d0e3c5c5a888766484c3e39f050be7f5 \ + --hash=sha256:b193a7a1c4f455a61e60ed2dd67271a3daab0910ddb4bd9db51390d1b36d9996 \ + --hash=sha256:bc28a58c580253a526a4b6d239d18181ef96f1e285a502dbb03ff15eeec07a5b \ + --hash=sha256:c0529afab3800749dd245843d3bf0d061a109a8edb77fb345f476e8bccda51b8 \ + --hash=sha256:d2b353b0bd53fedd77fc5a12a1c1a91cacc3cf59e3dd785529c5a54b31d1c7b1 \ + --hash=sha256:de63843706fdfe9565a45567238c5b1e55a0b28bbde6524200b31d29043a9a16 \ + --hash=sha256:e93df6d34ff5f0f6be245f4d29b99a7c1eef8ad91b50686adf57a5eeea99cb74 \ + --hash=sha256:f65dfaf9e737a41389e29a5a2ea52ca5d14c8d6f48b402c723d800cd16d322b0 \ + --hash=sha256:f8b51d7e685bbb63f6becd7d2ac8634d5ab67ec7e53038e597182e2db2c7aa90 + # via -r build/nonfreethreading-requirements.txt +typing-extensions==4.15.0 \ + --hash=sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466 \ + --hash=sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548 + # via etils +wheel==0.46.1 \ + --hash=sha256:f796f65d72750ccde090663e466d0ca37cd72b62870f7520b96d34cdc07d86d8 \ + --hash=sha256:fd477efb5da0f7df1d3c76c73c14394002c844451bd63229d8570f376f5e6a38 + # via -r build/requirements.in +zipp==3.23.0 \ + --hash=sha256:071652d6115ed432f5ce1d34c336c0adfd6a884660d1e9712a256d3d3bd4b14e \ + --hash=sha256:a07157588a12518c9d4034df3fbbee09c814741a33ff63c05fa29d26a2404166 + # via etils + +# The following packages are considered to be unsafe in a requirements file: +setuptools==80.9.0 \ + --hash=sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922 \ + --hash=sha256:f36b47402ecde768dbfafc46e8e4207b4360c654f1f3bb84475f0a28628fb19c + # via -r build/requirements.in diff --git a/build/requirements_lock_3_14_ft.txt b/build/requirements_lock_3_14_ft.txt new file mode 100644 index 000000000000..a5d099e40ff0 --- /dev/null +++ b/build/requirements_lock_3_14_ft.txt @@ -0,0 +1,975 @@ +# +# This file is autogenerated by pip-compile with Python 3.14 +# by the following command: +# +# bazel run //build:requirements_ft.update +# +--index-url https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple + +absl-py==2.3.1 \ + --hash=sha256:a97820526f7fbfd2ec1bce83f3f25e3a14840dac0d8e02a0b71cd75db3f77fc9 \ + --hash=sha256:eeecf07f0c2a93ace0772c92e596ace6d3d3996c042b2128459aaae2a76de11d + # via -r build/test-requirements.txt +attrs==25.4.0 \ + --hash=sha256:16d5969b87f0859ef33a48b35d55ac1be6e42ae49d5e853b597db70c35c57e11 \ + --hash=sha256:adcf7e2a1fb3b36ac48d97835bb6d8ade15b8dcce26aba8bf1d14847b57a3373 + # via hypothesis +auditwheel==6.5.0 \ + --hash=sha256:4fbcbd5854054bb1dd7870db03727b871b96b18147db57259561c058603987d7 \ + --hash=sha256:e08d2eede0259be6feff597d041c06175026e93248a1a97143acc52c57714d80 + # via -r build/test-requirements.txt +build==1.3.0 \ + --hash=sha256:698edd0ea270bde950f53aed21f3a0135672206f3911e0176261a31e0e07b397 \ + --hash=sha256:7145f0b5061ba90a1500d60bd1b13ca0a8a4cebdd0cc16ed8adf1c0e739f43b4 + # via -r build/requirements.in +cloudpickle==3.1.2 \ + --hash=sha256:7fda9eb655c9c230dab534f1983763de5835249750e85fbcef43aaa30a9a2414 \ + --hash=sha256:9acb47f6afd73f60dc1df93bb801b472f05ff42fa6c84167d25cb206be1fbf4a + # via -r build/test-requirements.txt +colorama==0.4.6 \ + --hash=sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44 \ + --hash=sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6 + # via -r build/requirements.in +contourpy==1.3.3 \ + --hash=sha256:023b44101dfe49d7d53932be418477dba359649246075c996866106da069af69 \ + --hash=sha256:07ce5ed73ecdc4a03ffe3e1b3e3c1166db35ae7584be76f65dbbe28a7791b0cc \ + --hash=sha256:083e12155b210502d0bca491432bb04d56dc3432f95a979b429f2848c3dbe880 \ + --hash=sha256:0bf67e0e3f482cb69779dd3061b534eb35ac9b17f163d851e2a547d56dba0a3a \ + --hash=sha256:0c1fc238306b35f246d61a1d416a627348b5cf0648648a031e14bb8705fcdfe8 \ + --hash=sha256:13b68d6a62db8eafaebb8039218921399baf6e47bf85006fd8529f2a08ef33fc \ + --hash=sha256:15ff10bfada4bf92ec8b31c62bf7c1834c244019b4a33095a68000d7075df470 \ + --hash=sha256:177fb367556747a686509d6fef71d221a4b198a3905fe824430e5ea0fda54eb5 \ + --hash=sha256:1cadd8b8969f060ba45ed7c1b714fe69185812ab43bd6b86a9123fe8f99c3263 \ + --hash=sha256:1fd43c3be4c8e5fd6e4f2baeae35ae18176cf2e5cced681cca908addf1cdd53b \ + --hash=sha256:22e9b1bd7a9b1d652cd77388465dc358dafcd2e217d35552424aa4f996f524f5 \ + --hash=sha256:23416f38bfd74d5d28ab8429cc4d63fa67d5068bd711a85edb1c3fb0c3e2f381 \ + --hash=sha256:283edd842a01e3dcd435b1c5116798d661378d83d36d337b8dde1d16a5fc9ba3 \ + --hash=sha256:2a2a8b627d5cc6b7c41a4beff6c5ad5eb848c88255fda4a8745f7e901b32d8e4 \ + --hash=sha256:2b7e9480ffe2b0cd2e787e4df64270e3a0440d9db8dc823312e2c940c167df7e \ + --hash=sha256:322ab1c99b008dad206d406bb61d014cf0174df491ae9d9d0fac6a6fda4f977f \ + --hash=sha256:33c82d0138c0a062380332c861387650c82e4cf1747aaa6938b9b6516762e772 \ + --hash=sha256:348ac1f5d4f1d66d3322420f01d42e43122f43616e0f194fc1c9f5d830c5b286 \ + --hash=sha256:3519428f6be58431c56581f1694ba8e50626f2dd550af225f82fb5f5814d2a42 \ + --hash=sha256:3c30273eb2a55024ff31ba7d052dde990d7d8e5450f4bbb6e913558b3d6c2301 \ + --hash=sha256:3d1a3799d62d45c18bafd41c5fa05120b96a28079f2393af559b843d1a966a77 \ + --hash=sha256:451e71b5a7d597379ef572de31eeb909a87246974d960049a9848c3bc6c41bf7 \ + --hash=sha256:459c1f020cd59fcfe6650180678a9993932d80d44ccde1fa1868977438f0b411 \ + --hash=sha256:4d00e655fcef08aba35ec9610536bfe90267d7ab5ba944f7032549c55a146da1 \ + --hash=sha256:4debd64f124ca62069f313a9cb86656ff087786016d76927ae2cf37846b006c9 \ + --hash=sha256:4feffb6537d64b84877da813a5c30f1422ea5739566abf0bd18065ac040e120a \ + --hash=sha256:50ed930df7289ff2a8d7afeb9603f8289e5704755c7e5c3bbd929c90c817164b \ + --hash=sha256:51e79c1f7470158e838808d4a996fa9bac72c498e93d8ebe5119bc1e6becb0db \ + --hash=sha256:556dba8fb6f5d8742f2923fe9457dbdd51e1049c4a43fd3986a0b14a1d815fc6 \ + --hash=sha256:598c3aaece21c503615fd59c92a3598b428b2f01bfb4b8ca9c4edeecc2438620 \ + --hash=sha256:5ed3657edf08512fc3fe81b510e35c2012fbd3081d2e26160f27ca28affec989 \ + --hash=sha256:626d60935cf668e70a5ce6ff184fd713e9683fb458898e4249b63be9e28286ea \ + --hash=sha256:644a6853d15b2512d67881586bd03f462c7ab755db95f16f14d7e238f2852c67 \ + --hash=sha256:655456777ff65c2c548b7c454af9c6f33f16c8884f11083244b5819cc214f1b5 \ + --hash=sha256:66c8a43a4f7b8df8b71ee1840e4211a3c8d93b214b213f590e18a1beca458f7d \ + --hash=sha256:6afc576f7b33cf00996e5c1102dc2a8f7cc89e39c0b55df93a0b78c1bd992b36 \ + --hash=sha256:6c3d53c796f8647d6deb1abe867daeb66dcc8a97e8455efa729516b997b8ed99 \ + --hash=sha256:709a48ef9a690e1343202916450bc48b9e51c049b089c7f79a267b46cffcdaa1 \ + --hash=sha256:70f9aad7de812d6541d29d2bbf8feb22ff7e1c299523db288004e3157ff4674e \ + --hash=sha256:8153b8bfc11e1e4d75bcb0bff1db232f9e10b274e0929de9d608027e0d34ff8b \ + --hash=sha256:87acf5963fc2b34825e5b6b048f40e3635dd547f590b04d2ab317c2619ef7ae8 \ + --hash=sha256:88df9880d507169449d434c293467418b9f6cbe82edd19284aa0409e7fdb933d \ + --hash=sha256:929ddf8c4c7f348e4c0a5a3a714b5c8542ffaa8c22954862a46ca1813b667ee7 \ + --hash=sha256:92d9abc807cf7d0e047b95ca5d957cf4792fcd04e920ca70d48add15c1a90ea7 \ + --hash=sha256:95b181891b4c71de4bb404c6621e7e2390745f887f2a026b2d99e92c17892339 \ + --hash=sha256:9e999574eddae35f1312c2b4b717b7885d4edd6cb46700e04f7f02db454e67c1 \ + --hash=sha256:a15459b0f4615b00bbd1e91f1b9e19b7e63aea7483d03d804186f278c0af2659 \ + --hash=sha256:a22738912262aa3e254e4f3cb079a95a67132fc5a063890e224393596902f5a4 \ + --hash=sha256:ab2fd90904c503739a75b7c8c5c01160130ba67944a7b77bbf36ef8054576e7f \ + --hash=sha256:ab3074b48c4e2cf1a960e6bbeb7f04566bf36b1861d5c9d4d8ac04b82e38ba20 \ + --hash=sha256:afe5a512f31ee6bd7d0dda52ec9864c984ca3d66664444f2d72e0dc4eb832e36 \ + --hash=sha256:b08a32ea2f8e42cf1d4be3169a98dd4be32bafe4f22b6c4cb4ba810fa9e5d2cb \ + --hash=sha256:b20c7c9a3bf701366556e1b1984ed2d0cedf999903c51311417cf5f591d8c78d \ + --hash=sha256:b2e8faa0ed68cb29af51edd8e24798bb661eac3bd9f65420c1887b6ca89987c8 \ + --hash=sha256:b7301b89040075c30e5768810bc96a8e8d78085b47d8be6e4c3f5a0b4ed478a0 \ + --hash=sha256:b7448cb5a725bb1e35ce88771b86fba35ef418952474492cf7c764059933ff8b \ + --hash=sha256:ca0fdcd73925568ca027e0b17ab07aad764be4706d0a925b89227e447d9737b7 \ + --hash=sha256:ca658cd1a680a5c9ea96dc61cdbae1e85c8f25849843aa799dfd3cb370ad4fbe \ + --hash=sha256:cbedb772ed74ff5be440fa8eee9bd49f64f6e3fc09436d9c7d8f1c287b121d77 \ + --hash=sha256:cd5dfcaeb10f7b7f9dc8941717c6c2ade08f587be2226222c12b25f0483ed497 \ + --hash=sha256:cf9022ef053f2694e31d630feaacb21ea24224be1c3ad0520b13d844274614fd \ + --hash=sha256:d002b6f00d73d69333dac9d0b8d5e84d9724ff9ef044fd63c5986e62b7c9e1b1 \ + --hash=sha256:d06bb1f751ba5d417047db62bca3c8fde202b8c11fb50742ab3ab962c81e8216 \ + --hash=sha256:d304906ecc71672e9c89e87c4675dc5c2645e1f4269a5063b99b0bb29f232d13 \ + --hash=sha256:e4e6b05a45525357e382909a4c1600444e2a45b4795163d3b22669285591c1ae \ + --hash=sha256:e74a9a0f5e3fff48fb5a7f2fd2b9b70a3fe014a67522f79b7cca4c0c7e43c9ae \ + --hash=sha256:ea37e7b45949df430fe649e5de8351c423430046a2af20b1c1961cae3afcda77 \ + --hash=sha256:f64836de09927cba6f79dcd00fdd7d5329f3fccc633468507079c829ca4db4e3 \ + --hash=sha256:fd6ec6be509c787f1caf6b247f0b1ca598bef13f4ddeaa126b7658215529ba0f \ + --hash=sha256:fd907ae12cd483cd83e414b12941c632a969171bf90fc937d0c9f268a31cafff \ + --hash=sha256:fd914713266421b7536de2bfa8181aa8c699432b6763a0ea64195ebe28bff6a9 \ + --hash=sha256:fde6c716d51c04b1c25d0b90364d0be954624a0ee9d60e23e850e8d48353d07a + # via matplotlib +cycler==0.12.1 \ + --hash=sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30 \ + --hash=sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c + # via matplotlib +etils[epath,epy]==1.13.0 \ + --hash=sha256:a5b60c71f95bcd2d43d4e9fb3dc3879120c1f60472bb5ce19f7a860b1d44f607 \ + --hash=sha256:d9cd4f40fbe77ad6613b7348a18132cc511237b6c076dbb89105c0b520a4c6bb + # via -r build/requirements.in +execnet==2.1.2 \ + --hash=sha256:63d83bfdd9a23e35b9c6a3261412324f964c2ec8dcd8d3c6916ee9373e0befcd \ + --hash=sha256:67fba928dd5a544b783f6056f449e5e3931a5c378b128bc18501f7ea79e296ec + # via pytest-xdist +filelock==3.20.1 \ + --hash=sha256:15d9e9a67306188a44baa72f569d2bfd803076269365fdea0934385da4dc361a \ + --hash=sha256:b8360948b351b80f420878d8516519a2204b07aefcdcfd24912a5d33127f188c + # via -r build/test-requirements.txt +flatbuffers==25.9.23 \ + --hash=sha256:255538574d6cb6d0a79a17ec8bc0d30985913b87513a01cce8bcdb6b4c44d0e2 \ + --hash=sha256:676f9fa62750bb50cf531b42a0a2a118ad8f7f797a511eda12881c016f093b12 + # via -r build/test-requirements.txt +fonttools==4.61.1 \ + --hash=sha256:0de30bfe7745c0d1ffa2b0b7048fb7123ad0d71107e10ee090fa0b16b9452e87 \ + --hash=sha256:10d88e55330e092940584774ee5e8a6971b01fc2f4d3466a1d6c158230880796 \ + --hash=sha256:11f35ad7805edba3aac1a3710d104592df59f4b957e30108ae0ba6c10b11dd75 \ + --hash=sha256:15acc09befd16a0fb8a8f62bc147e1a82817542d72184acca9ce6e0aeda9fa6d \ + --hash=sha256:17d2bf5d541add43822bcf0c43d7d847b160c9bb01d15d5007d84e2217aaa371 \ + --hash=sha256:2180f14c141d2f0f3da43f3a81bc8aa4684860f6b0e6f9e165a4831f24e6a23b \ + --hash=sha256:21e7c8d76f62ab13c9472ccf74515ca5b9a761d1bde3265152a6dc58700d895b \ + --hash=sha256:41a7170d042e8c0024703ed13b71893519a1a6d6e18e933e3ec7507a2c26a4b2 \ + --hash=sha256:41ed4b5ec103bd306bb68f81dc166e77409e5209443e5773cb4ed837bcc9b0d3 \ + --hash=sha256:497c31ce314219888c0e2fce5ad9178ca83fe5230b01a5006726cdf3ac9f24d9 \ + --hash=sha256:4c1b526c8d3f615a7b1867f38a9410849c8f4aef078535742198e942fba0e9bd \ + --hash=sha256:4d7092bb38c53bbc78e9255a59158b150bcdc115a1e3b3ce0b5f267dc35dd63c \ + --hash=sha256:4f5686e1fe5fce75d82d93c47a438a25bf0d1319d2843a926f741140b2b16e0c \ + --hash=sha256:58b0ee0ab5b1fc9921eccfe11d1435added19d6494dde14e323f25ad2bc30c56 \ + --hash=sha256:5ce02f38a754f207f2f06557523cd39a06438ba3aafc0639c477ac409fc64e37 \ + --hash=sha256:5fade934607a523614726119164ff621e8c30e8fa1ffffbbd358662056ba69f0 \ + --hash=sha256:5fe9fd43882620017add5eabb781ebfbc6998ee49b35bd7f8f79af1f9f99a958 \ + --hash=sha256:64102ca87e84261419c3747a0d20f396eb024bdbeb04c2bfb37e2891f5fadcb5 \ + --hash=sha256:664c5a68ec406f6b1547946683008576ef8b38275608e1cee6c061828171c118 \ + --hash=sha256:6675329885c44657f826ef01d9e4fb33b9158e9d93c537d84ad8399539bc6f69 \ + --hash=sha256:75c1a6dfac6abd407634420c93864a1e274ebc1c7531346d9254c0d8f6ca00f9 \ + --hash=sha256:75da8f28eff26defba42c52986de97b22106cb8f26515b7c22443ebc9c2d3261 \ + --hash=sha256:77efb033d8d7ff233385f30c62c7c79271c8885d5c9657d967ede124671bbdfb \ + --hash=sha256:78a7d3ab09dc47ac1a363a493e6112d8cabed7ba7caad5f54dbe2f08676d1b47 \ + --hash=sha256:7c7db70d57e5e1089a274cbb2b1fd635c9a24de809a231b154965d415d6c6d24 \ + --hash=sha256:8c56c488ab471628ff3bfa80964372fc13504ece601e0d97a78ee74126b2045c \ + --hash=sha256:91669ccac46bbc1d09e9273546181919064e8df73488ea087dcac3e2968df9ba \ + --hash=sha256:9b666a475a65f4e839d3d10473fad6d47e0a9db14a2f4a224029c5bfde58ad2c \ + --hash=sha256:9cfef3ab326780c04d6646f68d4b4742aae222e8b8ea1d627c74e38afcbc9d91 \ + --hash=sha256:a13fc8aeb24bad755eea8f7f9d409438eb94e82cf86b08fe77a03fbc8f6a96b1 \ + --hash=sha256:a75c301f96db737e1c5ed5fd7d77d9c34466de16095a266509e13da09751bd19 \ + --hash=sha256:a76d4cb80f41ba94a6691264be76435e5f72f2cb3cab0b092a6212855f71c2f6 \ + --hash=sha256:aed04cabe26f30c1647ef0e8fbb207516fd40fe9472e9439695f5c6998e60ac5 \ + --hash=sha256:b148b56f5de675ee16d45e769e69f87623a4944f7443850bf9a9376e628a89d2 \ + --hash=sha256:b501c862d4901792adaec7c25b1ecc749e2662543f68bb194c42ba18d6eec98d \ + --hash=sha256:b846a1fcf8beadeb9ea4f44ec5bdde393e2f1569e17d700bfc49cd69bde75881 \ + --hash=sha256:b931ae8f62db78861b0ff1ac017851764602288575d65b8e8ff1963fed419063 \ + --hash=sha256:c33ab3ca9d3ccd581d58e989d67554e42d8d4ded94ab3ade3508455fe70e65f7 \ + --hash=sha256:c6604b735bb12fef8e0efd5578c9fb5d3d8532d5001ea13a19cddf295673ee09 \ + --hash=sha256:d8db08051fc9e7d8bc622f2112511b8107d8f27cd89e2f64ec45e9825e8288da \ + --hash=sha256:d9203500f7c63545b4ce3799319fe4d9feb1a1b89b28d3cb5abd11b9dd64147e \ + --hash=sha256:dc492779501fa723b04d0ab1f5be046797fee17d27700476edc7ee9ae535a61e \ + --hash=sha256:e6bcdf33aec38d16508ce61fd81838f24c83c90a1d1b8c68982857038673d6b8 \ + --hash=sha256:e76ce097e3c57c4bcb67c5aa24a0ecdbd9f74ea9219997a707a4061fbe2707aa \ + --hash=sha256:eff1ac3cc66c2ac7cda1e64b4e2f3ffef474b7335f92fc3833fc632d595fcee6 \ + --hash=sha256:f3cb4a569029b9f291f88aafc927dd53683757e640081ca8c412781ea144565e \ + --hash=sha256:f79b168428351d11e10c5aeb61a74e1851ec221081299f4cf56036a95431c43a \ + --hash=sha256:fa646ecec9528bef693415c79a86e733c70a4965dd938e9a226b0fc64c9d2e6c \ + --hash=sha256:fe2efccb324948a11dd09d22136fe2ac8a97d6c1347cf0b58a911dcd529f66b7 \ + --hash=sha256:fff4f534200a04b4a36e7ae3cb74493afe807b517a09e99cb4faa89a34ed6ecd + # via matplotlib +fsspec==2025.10.0 \ + --hash=sha256:7c7712353ae7d875407f97715f0e1ffcc21e33d5b24556cb1e090ae9409ec61d \ + --hash=sha256:b6789427626f068f9a83ca4e8a3cc050850b6c0f71f99ddb4f542b8266a26a59 + # via etils +hypothesis==6.142.1 \ + --hash=sha256:3179cb08756562c526aaf4a9871ebbff83d2d75c03896ed0bc9c1d14097a930c \ + --hash=sha256:95a7d38fcc58e697e3020665adcb951c630cdbc8065e4b4474949e486b06bd6d + # via -r build/test-requirements.txt +importlib-resources==6.5.2 \ + --hash=sha256:185f87adef5bcc288449d98fb4fba07cea78bc036455dd44c5fc4a2fe78fed2c \ + --hash=sha256:789cfdc3ed28c78b67a06acb8126751ced69a3d5f79c095a98298cd8a760ccec + # via etils +iniconfig==2.3.0 \ + --hash=sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730 \ + --hash=sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12 + # via pytest +jax-cuda12-pjrt==0.8.2 ; sys_platform == "linux" \ + --hash=sha256:717a1b196a642409ce195ddf031c20bbeadcc886f55e49a1d3f4927373aeedae \ + --hash=sha256:e3bab41ca7c48e4163db9e7efd271b3aa85f0fe45f5ed0708d6bbed93a59f977 + # via -r build/requirements.in +jax-cuda13-pjrt==0.8.2 \ + --hash=sha256:370e9cb9a76af6a6b00e8f6c6ae8e5a0158886b65994a114cf3acf56ea1570f3 \ + --hash=sha256:c0ebc284d4bea9c0b278017b4bb62eb871755fe582cbf22ce1fb770d30fb9af6 + # via + # -r build/requirements.in + # jax-cuda13-plugin +jax-cuda13-plugin==0.8.2 \ + --hash=sha256:0ab3a13f798f973416754cb00b1696a8e6d2425c39496e2145293661bcd28446 \ + --hash=sha256:2c7532f2a6c9dcf50ba4b8ebbe4c78988b8c260695e50f9fdbd68f9417a8bb51 \ + --hash=sha256:358f4b4d0d2b92f62fe7f07d355a9c57c5684da3d4e6f483afec38b23e66f02e \ + --hash=sha256:6703552ee10b6b41b43fe4fc62b4672bec13d271eb709a29bf14b0fa1ec805dd \ + --hash=sha256:8fe2c5874cbae18a01a0fa722dfbd547b5504e7bf42aaaae6ad3e1b676e8daa2 \ + --hash=sha256:950f2c1f2e0f5e13893c4c27a4df540a96a26ebd15d7954fa85ffb8cc8b1d9a1 \ + --hash=sha256:9835076d0a917345700370cfaee5cfd7876c0b0982e99206a8b4848af0ea6e5e \ + --hash=sha256:bb059d99eb960718358f82f3eaa0829f96fc7d7e978ded821f2c40256a8aee4b \ + --hash=sha256:d5f9470db46aee8a64e9ab7817927335b77c6d3698d5113e0da2dd962c1468c4 \ + --hash=sha256:f1643b216ccb41a2d7c0e240359c729c96c5eca76e4525ffcd6313049e348b9a \ + --hash=sha256:f35f687d9b839365a49c2b0052eaeb6b9f0a9879813b4fa020f8d9caa3ddde46 \ + --hash=sha256:fa6a6faf3130d6ef05f9ee6cbc27e0a2e4db0f23720136403991eccb3a0144af + # via -r build/requirements.in +jaxlib==0.8.2 \ + --hash=sha256:023de6f3f56da2af7037970996500586331fdb50b530ecbb54b9666da633bd00 \ + --hash=sha256:05b958f497e49824c432e734bb059723b7dfe69e2ad696a9f9c8ad82fff7c3f8 \ + --hash=sha256:1bfbcf6c3de221784fa4cdb6765a09d71cb4298b15626b3d0409b3dfcd8a8667 \ + --hash=sha256:28eec1a4e0639a0d8702cea3cb70dd3663053dbfa344452994ea48dc6ceadaa5 \ + --hash=sha256:2b9789bd08f8b0cc5a5c12ae896fe432d5942e32e417091b8b5a96a9a6fd5cf1 \ + --hash=sha256:3b16e50c5b730c9dd0a49e55f1acfaa722b00b1af0522a591558dcc0464252f2 \ + --hash=sha256:490bf0cb029c73c65c9431124b86cdc95082dbc1fb76fc549d24d75da33e5454 \ + --hash=sha256:4d006db96be020c8165212a1216372f8acac4ff4f8fb067743d694ef2b301ace \ + --hash=sha256:68108dff0de74adc468016be9a19f80efe48c660c0d5a122287094b44b092afc \ + --hash=sha256:7c304f3a016965b9d1f5239a8a0399a73925f5604fe914c5ca66ecf734bf6422 \ + --hash=sha256:7da8127557c786264049ae55460d1b8d04cc3cdf0403a087f2fc1e6d313ec722 \ + --hash=sha256:964626f581beab31ee6826b228fcc2ec5181b05cecf94a528dff97921c145dbc \ + --hash=sha256:a397ea7dcb37d689ce79173eeb99b2f1347637a36be9a27f20ae6848bfc58bfc \ + --hash=sha256:aa8701b6356f098e8452c3cec762fb5f706fcb8f67ffd65964f63982479aa23b \ + --hash=sha256:bb89be452b1b808d3f88fc01c415b364a260be4cc7ac120c038009f6150a32dc \ + --hash=sha256:beffb004e7eeb5c9afb24439e2b2cf45a4ee3e3e8adf45e355edf2af62acf8b8 \ + --hash=sha256:ccf77da917a20935247c990691decfcbdd06c25ef0ac94d914a04aadb22f714c \ + --hash=sha256:dffc22b5b732b9556d92c918b251c61bcc046617c4dbb51e1f7a656587fddffb \ + --hash=sha256:e6a97dfb0232eed9a2bb6e3828e4f682dbac1a7fea840bfda574cae2dbf5faf9 \ + --hash=sha256:f205e91c3a152a2a76c0bc59a6a2de03e87ec261b91e8812922777185e7b08f5 \ + --hash=sha256:f28edac8c226fc07fa3e8af6f9defede8ac2c307429e3291edce8739d39becc9 \ + --hash=sha256:f472cc72e3058e50b5f0230b236d5a1183bf6c3d5423d2a52eff07bcf34908de + # via -r build/requirements.in +kiwisolver==1.4.9 \ + --hash=sha256:0749fd8f4218ad2e851e11cc4dc05c7cbc0cbc4267bdfdb31782e65aace4ee9c \ + --hash=sha256:0763515d4df10edf6d06a3c19734e2566368980d21ebec439f33f9eb936c07b7 \ + --hash=sha256:0856e241c2d3df4efef7c04a1e46b1936b6120c9bcf36dd216e3acd84bc4fb21 \ + --hash=sha256:0a590506f303f512dff6b7f75fd2fd18e16943efee932008fe7140e5fa91d80e \ + --hash=sha256:0ab74e19f6a2b027ea4f845a78827969af45ce790e6cb3e1ebab71bdf9f215ff \ + --hash=sha256:0ae37737256ba2de764ddc12aed4956460277f00c4996d51a197e72f62f5eec7 \ + --hash=sha256:0e4e2bf29574a6a7b7f6cb5fa69293b9f96c928949ac4a53ba3f525dffb87f9c \ + --hash=sha256:15163165efc2f627eb9687ea5f3a28137217d217ac4024893d753f46bce9de26 \ + --hash=sha256:17680d737d5335b552994a2008fab4c851bcd7de33094a82067ef3a576ff02fa \ + --hash=sha256:1a12cf6398e8a0a001a059747a1cbf24705e18fe413bc22de7b3d15c67cffe3f \ + --hash=sha256:1b11d6a633e4ed84fc0ddafd4ebfd8ea49b3f25082c04ad12b8315c11d504dc1 \ + --hash=sha256:1fa333e8b2ce4d9660f2cda9c0e1b6bafcfb2457a9d259faa82289e73ec24891 \ + --hash=sha256:2327a4a30d3ee07d2fbe2e7933e8a37c591663b96ce42a00bc67461a87d7df77 \ + --hash=sha256:2405a7d98604b87f3fc28b1716783534b1b4b8510d8142adca34ee0bc3c87543 \ + --hash=sha256:2489e4e5d7ef9a1c300a5e0196e43d9c739f066ef23270607d45aba368b91f2d \ + --hash=sha256:24c175051354f4a28c5d6a31c93906dc653e2bf234e8a4bbfb964892078898ce \ + --hash=sha256:2635d352d67458b66fd0667c14cb1d4145e9560d503219034a18a87e971ce4f3 \ + --hash=sha256:2c1a4f57df73965f3f14df20b80ee29e6a7930a57d2d9e8491a25f676e197c60 \ + --hash=sha256:2c93f00dcba2eea70af2be5f11a830a742fe6b579a1d4e00f47760ef13be247a \ + --hash=sha256:39a219e1c81ae3b103643d2aedb90f1ef22650deb266ff12a19e7773f3e5f089 \ + --hash=sha256:3b3115b2581ea35bb6d1f24a4c90af37e5d9b49dcff267eeed14c3893c5b86ab \ + --hash=sha256:40092754720b174e6ccf9e845d0d8c7d8e12c3d71e7fc35f55f3813e96376f78 \ + --hash=sha256:412f287c55a6f54b0650bd9b6dce5aceddb95864a1a90c87af16979d37c89771 \ + --hash=sha256:464415881e4801295659462c49461a24fb107c140de781d55518c4b80cb6790f \ + --hash=sha256:497d05f29a1300d14e02e6441cf0f5ee81c1ff5a304b0d9fb77423974684e08b \ + --hash=sha256:4a2899935e724dd1074cb568ce7ac0dce28b2cd6ab539c8e001a8578eb106d14 \ + --hash=sha256:4a48a2ce79d65d363597ef7b567ce3d14d68783d2b2263d98db3d9477805ba32 \ + --hash=sha256:4d1d9e582ad4d63062d34077a9a1e9f3c34088a2ec5135b1f7190c07cf366527 \ + --hash=sha256:52a15b0f35dad39862d376df10c5230155243a2c1a436e39eb55623ccbd68185 \ + --hash=sha256:540c7c72324d864406a009d72f5d6856f49693db95d1fbb46cf86febef873634 \ + --hash=sha256:5656aa670507437af0207645273ccdfee4f14bacd7f7c67a4306d0dcaeaf6eed \ + --hash=sha256:5a0f2724dfd4e3b3ac5a82436a8e6fd16baa7d507117e4279b660fe8ca38a3a1 \ + --hash=sha256:60c439763a969a6af93b4881db0eed8fadf93ee98e18cbc35bc8da868d0c4f0c \ + --hash=sha256:61874cdb0a36016354853593cffc38e56fc9ca5aa97d2c05d3dcf6922cd55a11 \ + --hash=sha256:67bb8b474b4181770f926f7b7d2f8c0248cbcb78b660fdd41a47054b28d2a752 \ + --hash=sha256:720e05574713db64c356e86732c0f3c5252818d05f9df320f0ad8380641acea5 \ + --hash=sha256:72d0eb9fba308b8311685c2268cf7d0a0639a6cd027d8128659f72bdd8a024b4 \ + --hash=sha256:767c23ad1c58c9e827b649a9ab7809fd5fd9db266a9cf02b0e926ddc2c680d58 \ + --hash=sha256:77937e5e2a38a7b48eef0585114fe7930346993a88060d0bf886086d2aa49ef5 \ + --hash=sha256:7a08b491ec91b1d5053ac177afe5290adacf1f0f6307d771ccac5de30592d198 \ + --hash=sha256:7b4da0d01ac866a57dd61ac258c5607b4cd677f63abaec7b148354d2b2cdd536 \ + --hash=sha256:7cf974dd4e35fa315563ac99d6287a1024e4dc2077b8a7d7cd3d2fb65d283134 \ + --hash=sha256:84fd60810829c27ae375114cd379da1fa65e6918e1da405f356a775d49a62bcf \ + --hash=sha256:858e4c22fb075920b96a291928cb7dea5644e94c0ee4fcd5af7e865655e4ccf2 \ + --hash=sha256:85b5352f94e490c028926ea567fc569c52ec79ce131dadb968d3853e809518c2 \ + --hash=sha256:85bd218b5ecfbee8c8a82e121802dcb519a86044c9c3b2e4aef02fa05c6da370 \ + --hash=sha256:8a1f570ce4d62d718dce3f179ee78dac3b545ac16c0c04bb363b7607a949c0d1 \ + --hash=sha256:8fdca1def57a2e88ef339de1737a1449d6dbf5fab184c54a1fca01d541317154 \ + --hash=sha256:90f47e70293fc3688b71271100a1a5453aa9944a81d27ff779c108372cf5567b \ + --hash=sha256:92a2f997387a1b79a75e7803aa7ded2cfbe2823852ccf1ba3bcf613b62ae3197 \ + --hash=sha256:9928fe1eb816d11ae170885a74d074f57af3a0d65777ca47e9aeb854a1fba386 \ + --hash=sha256:9af39d6551f97d31a4deebeac6f45b156f9755ddc59c07b402c148f5dbb6482a \ + --hash=sha256:9cf554f21be770f5111a1690d42313e140355e687e05cf82cb23d0a721a64a48 \ + --hash=sha256:a30fd6fdef1430fd9e1ba7b3398b5ee4e2887783917a687d86ba69985fb08748 \ + --hash=sha256:a31d512c812daea6d8b3be3b2bfcbeb091dbb09177706569bcfc6240dcf8b41c \ + --hash=sha256:a5d0432ccf1c7ab14f9949eec60c5d1f924f17c037e9f8b33352fa05799359b8 \ + --hash=sha256:a60ea74330b91bd22a29638940d115df9dc00af5035a9a2a6ad9399ffb4ceca5 \ + --hash=sha256:ac5a486ac389dddcc5bef4f365b6ae3ffff2c433324fb38dd35e3fab7c957999 \ + --hash=sha256:aedff62918805fb62d43a4aa2ecd4482c380dc76cd31bd7c8878588a61bd0369 \ + --hash=sha256:b34e51affded8faee0dfdb705416153819d8ea9250bbbf7ea1b249bdeb5f1122 \ + --hash=sha256:b4b4d74bda2b8ebf4da5bd42af11d02d04428b2c32846e4c2c93219df8a7987b \ + --hash=sha256:b67e6efbf68e077dd71d1a6b37e43e1a99d0bff1a3d51867d45ee8908b931098 \ + --hash=sha256:b78efa4c6e804ecdf727e580dbb9cba85624d2e1c6b5cb059c66290063bd99a9 \ + --hash=sha256:bb4ae2b57fc1d8cbd1cf7b1d9913803681ffa903e7488012be5b76dedf49297f \ + --hash=sha256:bdd1a81a1860476eb41ac4bc1e07b3f07259e6d55bbf739b79c8aaedcf512799 \ + --hash=sha256:bdee92c56a71d2b24c33a7d4c2856bd6419d017e08caa7802d2963870e315028 \ + --hash=sha256:be6a04e6c79819c9a8c2373317d19a96048e5a3f90bec587787e86a1153883c2 \ + --hash=sha256:bfc08add558155345129c7803b3671cf195e6a56e7a12f3dde7c57d9b417f525 \ + --hash=sha256:c3b22c26c6fd6811b0ae8363b95ca8ce4ea3c202d3d0975b2914310ceb1bcc4d \ + --hash=sha256:c9e7cdf45d594ee04d5be1b24dd9d49f3d1590959b2271fb30b5ca2b262c00fb \ + --hash=sha256:cb27e7b78d716c591e88e0a09a2139c6577865d7f2e152488c2cc6257f460872 \ + --hash=sha256:cc9617b46837c6468197b5945e196ee9ca43057bb7d9d1ae688101e4e1dddf64 \ + --hash=sha256:ccd09f20ccdbbd341b21a67ab50a119b64a403b09288c27481575105283c1586 \ + --hash=sha256:ce6a3a4e106cf35c2d9c4fa17c05ce0b180db622736845d4315519397a77beaf \ + --hash=sha256:d0005b053977e7b43388ddec89fa567f43d4f6d5c2c0affe57de5ebf290dc552 \ + --hash=sha256:d4188e73af84ca82468f09cadc5ac4db578109e52acb4518d8154698d3a87ca2 \ + --hash=sha256:d4efec7bcf21671db6a3294ff301d2fc861c31faa3c8740d1a94689234d1b415 \ + --hash=sha256:d75aa530ccfaa593da12834b86a0724f58bff12706659baa9227c2ccaa06264c \ + --hash=sha256:d84cd4061ae292d8ac367b2c3fa3aad11cb8625a95d135fe93f286f914f3f5a6 \ + --hash=sha256:d8aacd3d4b33b772542b2e01beb50187536967b514b00003bdda7589722d2a64 \ + --hash=sha256:d8fc5c867c22b828001b6a38d2eaeb88160bf5783c6cb4a5e440efc981ce286d \ + --hash=sha256:d976bbb382b202f71c67f77b0ac11244021cfa3f7dfd9e562eefcea2df711548 \ + --hash=sha256:dba5ee5d3981160c28d5490f0d1b7ed730c22470ff7f6cc26cfcfaacb9896a07 \ + --hash=sha256:dc1ae486f9abcef254b5618dfb4113dd49f94c68e3e027d03cf0143f3f772b61 \ + --hash=sha256:dd0a578400839256df88c16abddf9ba14813ec5f21362e1fe65022e00c883d4d \ + --hash=sha256:deed0c7258ceb4c44ad5ec7d9918f9f14fd05b2be86378d86cf50e63d1e7b771 \ + --hash=sha256:e09c2279a4d01f099f52d5c4b3d9e208e91edcbd1a175c9662a8b16e000fece9 \ + --hash=sha256:e2ea9f7ab7fbf18fffb1b5434ce7c69a07582f7acc7717720f1d69f3e806f90c \ + --hash=sha256:e6b93f13371d341afee3be9f7c5964e3fe61d5fa30f6a30eb49856935dfe4fc3 \ + --hash=sha256:eb14a5da6dc7642b0f3a18f13654847cd8b7a2550e2645a5bda677862b03ba16 \ + --hash=sha256:ed0fecd28cc62c54b262e3736f8bb2512d8dcfdc2bcf08be5f47f96bf405b145 \ + --hash=sha256:ede8c6d533bc6601a47ad4046080d36b8fc99f81e6f1c17b0ac3c2dc91ac7611 \ + --hash=sha256:efb3a45b35622bb6c16dbfab491a8f5a391fe0e9d45ef32f4df85658232ca0e2 \ + --hash=sha256:f117e1a089d9411663a3207ba874f31be9ac8eaa5b533787024dc07aeb74f464 \ + --hash=sha256:f2ba92255faa7309d06fe44c3a4a97efe1c8d640c2a79a5ef728b685762a6fd2 \ + --hash=sha256:f6008a4919fdbc0b0097089f67a1eb55d950ed7e90ce2cc3e640abadd2757a04 \ + --hash=sha256:f68208a520c3d86ea51acf688a3e3002615a7f0238002cccc17affecc86a8a54 \ + --hash=sha256:f68e4f3eeca8fb22cc3d731f9715a13b652795ef657a13df1ad0c7dc0e9731df \ + --hash=sha256:fb3b8132019ea572f4611d770991000d7f58127560c4889729248eb5852a102f \ + --hash=sha256:fb940820c63a9590d31d88b815e7a3aa5915cad3ce735ab45f0c730b39547de1 \ + --hash=sha256:fc1795ac5cd0510207482c3d1d3ed781143383b8cfd36f5c645f3897ce066220 + # via matplotlib +libtpu==0.0.32 ; sys_platform == "linux" and platform_machine == "x86_64" \ + --hash=sha256:37f6aefe6d69d24e5e268e74c1e90cd0ce0fa6a796720709445d8938605912ee \ + --hash=sha256:4a4ae6db0a90f0e33902cd345923a5d77dfc5bbab9a3e524c79bf876858104ee \ + --hash=sha256:6453995774b902e43d1caf0d82298a2e90f2f731ddd74d6bef9510d91732775f \ + --hash=sha256:7ffd9cdb89da22d32bb60ce70b88e9e76861c5d704ebf673f32987e19e9913f3 \ + --hash=sha256:98d9713a39d9581c9b10a7bdeaa575cc874a4a7f543cfb40d3a723ee27b428b5 \ + --hash=sha256:d2a0e7abe4a029b79049bc95da20e4c2a64f8ef09823f75286cecfbb4b0b6da1 + # via -r build/requirements.in +markdown-it-py==4.0.0 \ + --hash=sha256:87327c59b172c5011896038353a81343b6754500a08cd7a4973bb48c6d578147 \ + --hash=sha256:cb0a2b4aa34f932c007117b194e945bd74e0ec24133ceb5bac59009cda1cb9f3 + # via rich +matplotlib==3.10.8 \ + --hash=sha256:00270d217d6b20d14b584c521f810d60c5c78406dc289859776550df837dcda7 \ + --hash=sha256:0a33deb84c15ede243aead39f77e990469fff93ad1521163305095b77b72ce4a \ + --hash=sha256:113bb52413ea508ce954a02c10ffd0d565f9c3bc7f2eddc27dfe1731e71c7b5f \ + --hash=sha256:12d90df9183093fcd479f4172ac26b322b1248b15729cb57f42f71f24c7e37a3 \ + --hash=sha256:15d30132718972c2c074cd14638c7f4592bd98719e2308bccea40e0538bc0cb5 \ + --hash=sha256:18821ace09c763ec93aef5eeff087ee493a24051936d7b9ebcad9662f66501f9 \ + --hash=sha256:1ae029229a57cd1e8fe542485f27e7ca7b23aa9e8944ddb4985d0bc444f1eca2 \ + --hash=sha256:2299372c19d56bcd35cf05a2738308758d32b9eaed2371898d8f5bd33f084aa3 \ + --hash=sha256:238b7ce5717600615c895050239ec955d91f321c209dd110db988500558e70d6 \ + --hash=sha256:24d50994d8c5816ddc35411e50a86ab05f575e2530c02752e02538122613371f \ + --hash=sha256:25d380fe8b1dc32cf8f0b1b448470a77afb195438bafdf1d858bfb876f3edf7b \ + --hash=sha256:2c1998e92cd5999e295a731bcb2911c75f597d937341f3030cc24ef2733d78a8 \ + --hash=sha256:2cf5bd12cecf46908f286d7838b2abc6c91cda506c0445b8223a7c19a00df008 \ + --hash=sha256:32f8dce744be5569bebe789e46727946041199030db8aeb2954d26013a0eb26b \ + --hash=sha256:37b3c1cc42aa184b3f738cfa18c1c1d72fd496d85467a6cf7b807936d39aa656 \ + --hash=sha256:3a48a78d2786784cc2413e57397981fb45c79e968d99656706018d6e62e57958 \ + --hash=sha256:3ab4aabc72de4ff77b3ec33a6d78a68227bf1123465887f9905ba79184a1cc04 \ + --hash=sha256:3c624e43ed56313651bc18a47f838b60d7b8032ed348911c54906b130b20071b \ + --hash=sha256:3f2e409836d7f5ac2f1c013110a4d50b9f7edc26328c108915f9075d7d7a91b6 \ + --hash=sha256:3f5c3e4da343bba819f0234186b9004faba952cc420fbc522dc4e103c1985908 \ + --hash=sha256:41703cc95688f2516b480f7f339d8851a6035f18e100ee6a32bc0b8536a12a9c \ + --hash=sha256:495672de149445ec1b772ff2c9ede9b769e3cb4f0d0aa7fa730d7f59e2d4e1c1 \ + --hash=sha256:4cf267add95b1c88300d96ca837833d4112756045364f5c734a2276038dae27d \ + --hash=sha256:56271f3dac49a88d7fca5060f004d9d22b865f743a12a23b1e937a0be4818ee1 \ + --hash=sha256:595ba4d8fe983b88f0eec8c26a241e16d6376fe1979086232f481f8f3f67494c \ + --hash=sha256:5f62550b9a30afde8c1c3ae450e5eb547d579dd69b25c2fc7a1c67f934c1717a \ + --hash=sha256:646d95230efb9ca614a7a594d4fcacde0ac61d25e37dd51710b36477594963ce \ + --hash=sha256:64fcc24778ca0404ce0cb7b6b77ae1f4c7231cdd60e6778f999ee05cbd581b9a \ + --hash=sha256:6be43b667360fef5c754dda5d25a32e6307a03c204f3c0fc5468b78fa87b4160 \ + --hash=sha256:6da7c2ce169267d0d066adcf63758f0604aa6c3eebf67458930f9d9b79ad1db1 \ + --hash=sha256:83d282364ea9f3e52363da262ce32a09dfe241e4080dcedda3c0db059d3c1f11 \ + --hash=sha256:9153c3292705be9f9c64498a8872118540c3f4123d1a1c840172edf262c8be4a \ + --hash=sha256:99eefd13c0dc3b3c1b4d561c1169e65fe47aab7b8158754d7c084088e2329466 \ + --hash=sha256:a0a7f52498f72f13d4a25ea70f35f4cb60642b466cbb0a9be951b5bc3f45a486 \ + --hash=sha256:a2b336e2d91a3d7006864e0990c83b216fcdca64b5a6484912902cef87313d78 \ + --hash=sha256:a48f2b74020919552ea25d222d5cc6af9ca3f4eb43a93e14d068457f545c2a17 \ + --hash=sha256:ad3d9833a64cf48cc4300f2b406c3d0f4f4724a91c0bd5640678a6ba7c102077 \ + --hash=sha256:b44d07310e404ba95f8c25aa5536f154c0a8ec473303535949e52eb71d0a1565 \ + --hash=sha256:b53285e65d4fa4c86399979e956235deb900be5baa7fc1218ea67fbfaeaadd6f \ + --hash=sha256:b5a2b97dbdc7d4f353ebf343744f1d1f1cca8aa8bfddb4262fcf4306c3761d50 \ + --hash=sha256:b9a5ca4ac220a0cdd1ba6bcba3608547117d30468fefce49bb26f55c1a3d5c58 \ + --hash=sha256:bab485bcf8b1c7d2060b4fcb6fc368a9e6f4cd754c9c2fea281f4be21df394a2 \ + --hash=sha256:c108a1d6fa78a50646029cb6d49808ff0fc1330fda87fa6f6250c6b5369b6645 \ + --hash=sha256:d56a1efd5bfd61486c8bc968fa18734464556f0fb8e51690f4ac25d85cbbbbc2 \ + --hash=sha256:d9050fee89a89ed57b4fb2c1bfac9a3d0c57a0d55aed95949eedbc42070fea39 \ + --hash=sha256:dd80ecb295460a5d9d260df63c43f4afbdd832d725a531f008dad1664f458adf \ + --hash=sha256:e8ea3e2d4066083e264e75c829078f9e149fa119d27e19acd503de65e0b13149 \ + --hash=sha256:eb3823f11823deade26ce3b9f40dcb4a213da7a670013929f31d5f5ed1055b22 \ + --hash=sha256:ee40c27c795bda6a5292e9cff9890189d32f7e3a0bf04e0e3c9430c4a00c37df \ + --hash=sha256:efb30e3baaea72ce5928e32bab719ab4770099079d66726a62b11b1ef7273be4 \ + --hash=sha256:f254d118d14a7f99d616271d6c3c27922c092dac11112670b157798b89bf4933 \ + --hash=sha256:f89c151aab2e2e23cb3fe0acad1e8b82841fd265379c4cecd0f3fcb34c15e0f6 \ + --hash=sha256:f97aeb209c3d2511443f8797e3e5a569aebb040d4f8bc79aa3ee78a8fb9e3dd8 \ + --hash=sha256:f9b587c9c7274c1613a30afabf65a272114cd6cdbe67b3406f818c79d7ab2e2a \ + --hash=sha256:fb061f596dad3a0f52b60dc6a5dec4a0c300dec41e058a7efe09256188d170b7 + # via -r build/test-requirements.txt +mdurl==0.1.2 \ + --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ + --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba + # via markdown-it-py +ml-dtypes==0.5.4 \ + --hash=sha256:0d2ffd05a2575b1519dc928c0b93c06339eb67173ff53acb00724502cda231cf \ + --hash=sha256:11942cbf2cf92157db91e5022633c0d9474d4dfd813a909383bd23ce828a4b7d \ + --hash=sha256:14a4fd3228af936461db66faccef6e4f41c1d82fcc30e9f8d58a08916b1d811f \ + --hash=sha256:19b9a53598f21e453ea2fbda8aa783c20faff8e1eeb0d7ab899309a0053f1483 \ + --hash=sha256:2314892cdc3fcf05e373d76d72aaa15fda9fb98625effa73c1d646f331fcecb7 \ + --hash=sha256:2b857d3af6ac0d39db1de7c706e69c7f9791627209c3d6dedbfca8c7e5faec22 \ + --hash=sha256:304ad47faa395415b9ccbcc06a0350800bc50eda70f0e45326796e27c62f18b6 \ + --hash=sha256:35f29491a3e478407f7047b8a4834e4640a77d2737e0b294d049746507af5175 \ + --hash=sha256:388d399a2152dd79a3f0456a952284a99ee5c93d3e2f8dfe25977511e0515270 \ + --hash=sha256:3bbbe120b915090d9dd1375e4684dd17a20a2491ef25d640a908281da85e73f1 \ + --hash=sha256:3d277bf3637f2a62176f4575512e9ff9ef51d00e39626d9fe4a161992f355af2 \ + --hash=sha256:4381fe2f2452a2d7589689693d3162e876b3ddb0a832cde7a414f8e1adf7eab1 \ + --hash=sha256:4ff7f3e7ca2972e7de850e7b8fcbb355304271e2933dd90814c1cb847414d6e2 \ + --hash=sha256:531eff30e4d368cb6255bc2328d070e35836aa4f282a0fb5f3a0cd7260257298 \ + --hash=sha256:533ce891ba774eabf607172254f2e7260ba5f57bdd64030c9a4fcfbd99815d0d \ + --hash=sha256:557a31a390b7e9439056644cb80ed0735a6e3e3bb09d67fd5687e4b04238d1de \ + --hash=sha256:5a0f68ca8fd8d16583dfa7793973feb86f2fbb56ce3966daf9c9f748f52a2049 \ + --hash=sha256:6a0df4223b514d799b8a1629c65ddc351b3efa833ccf7f8ea0cf654a61d1e35d \ + --hash=sha256:6c7ecb74c4bd71db68a6bea1edf8da8c34f3d9fe218f038814fd1d310ac76c90 \ + --hash=sha256:7c23c54a00ae43edf48d44066a7ec31e05fdc2eee0be2b8b50dd1903a1db94bb \ + --hash=sha256:805cef3a38f4eafae3a5bf9ebdcdb741d0bcfd9e1bd90eb54abd24f928cd2465 \ + --hash=sha256:88c982aac7cb1cbe8cbb4e7f253072b1df872701fcaf48d84ffbb433b6568f24 \ + --hash=sha256:8ab06a50fb9bf9666dd0fe5dfb4676fa2b0ac0f31ecff72a6c3af8e22c063453 \ + --hash=sha256:8c6a2dcebd6f3903e05d51960a8058d6e131fe69f952a5397e5dbabc841b6d56 \ + --hash=sha256:8c760d85a2f82e2bed75867079188c9d18dae2ee77c25a54d60e9cc79be1bc48 \ + --hash=sha256:9ad459e99793fa6e13bd5b7e6792c8f9190b4e5a1b45c63aba14a4d0a7f1d5ff \ + --hash=sha256:9bad06436568442575beb2d03389aa7456c690a5b05892c471215bfd8cf39460 \ + --hash=sha256:a174837a64f5b16cab6f368171a1a03a27936b31699d167684073ff1c4237dac \ + --hash=sha256:a7f7c643e8b1320fd958bf098aa7ecf70623a42ec5154e3be3be673f4c34d900 \ + --hash=sha256:a9b61c19040397970d18d7737375cffd83b1f36a11dd4ad19f83a016f736c3ef \ + --hash=sha256:b4b801ebe0b477be666696bda493a9be8356f1f0057a57f1e35cd26928823e5a \ + --hash=sha256:b95e97e470fe60ed493fd9ae3911d8da4ebac16bd21f87ffa2b7c588bf22ea2c \ + --hash=sha256:bc11d7e8c44a65115d05e2ab9989d1e045125d7be8e05a071a48bc76eb6d6040 \ + --hash=sha256:bfc534409c5d4b0bf945af29e5d0ab075eae9eecbb549ff8a29280db822f34f9 \ + --hash=sha256:c1a953995cccb9e25a4ae19e34316671e4e2edaebe4cf538229b1fc7109087b7 \ + --hash=sha256:cb73dccfc991691c444acc8c0012bee8f2470da826a92e3a20bb333b1a7894e6 \ + --hash=sha256:ce756d3a10d0c4067172804c9cc276ba9cc0ff47af9078ad439b075d1abdc29b \ + --hash=sha256:d81fdb088defa30eb37bf390bb7dde35d3a83ec112ac8e33d75ab28cc29dd8b0 \ + --hash=sha256:f21c9219ef48ca5ee78402d5cc831bd58ea27ce89beda894428bc67a52da5328 + # via + # -r build/requirements.in + # jaxlib +mpmath==1.3.0 \ + --hash=sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f \ + --hash=sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c + # via -r build/test-requirements.txt +numpy==2.3.5 ; python_version >= "3.14" \ + --hash=sha256:00dc4e846108a382c5869e77c6ed514394bdeb3403461d25a829711041217d5b \ + --hash=sha256:0472f11f6ec23a74a906a00b48a4dcf3849209696dff7c189714511268d103ae \ + --hash=sha256:04822c00b5fd0323c8166d66c701dc31b7fbd252c100acd708c48f763968d6a3 \ + --hash=sha256:052e8c42e0c49d2575621c158934920524f6c5da05a1d3b9bab5d8e259e045f0 \ + --hash=sha256:09a1bea522b25109bf8e6f3027bd810f7c1085c64a0c7ce050c1676ad0ba010b \ + --hash=sha256:0cd00b7b36e35398fa2d16af7b907b65304ef8bb4817a550e06e5012929830fa \ + --hash=sha256:0d8163f43acde9a73c2a33605353a4f1bc4798745a8b1d73183b28e5b435ae28 \ + --hash=sha256:1062fde1dcf469571705945b0f221b73928f34a20c904ffb45db101907c3454e \ + --hash=sha256:11e06aa0af8c0f05104d56450d6093ee639e15f24ecf62d417329d06e522e017 \ + --hash=sha256:17531366a2e3a9e30762c000f2c43a9aaa05728712e25c11ce1dbe700c53ad41 \ + --hash=sha256:1978155dd49972084bd6ef388d66ab70f0c323ddee6f693d539376498720fb7e \ + --hash=sha256:1ed1ec893cff7040a02c8aa1c8611b94d395590d553f6b53629a4461dc7f7b63 \ + --hash=sha256:2dcd0808a421a482a080f89859a18beb0b3d1e905b81e617a188bd80422d62e9 \ + --hash=sha256:2e2eb32ddb9ccb817d620ac1d8dae7c3f641c1e5f55f531a33e8ab97960a75b8 \ + --hash=sha256:2feae0d2c91d46e59fcd62784a3a83b3fb677fead592ce51b5a6fbb4f95965ff \ + --hash=sha256:3095bdb8dd297e5920b010e96134ed91d852d81d490e787beca7e35ae1d89cf7 \ + --hash=sha256:30bc11310e8153ca664b14c5f1b73e94bd0503681fcf136a163de856f3a50139 \ + --hash=sha256:3101e5177d114a593d79dd79658650fe28b5a0d8abeb8ce6f437c0e6df5be1a4 \ + --hash=sha256:396084a36abdb603546b119d96528c2f6263921c50df3c8fd7cb28873a237748 \ + --hash=sha256:3997b5b3c9a771e157f9aae01dd579ee35ad7109be18db0e85dbdbe1de06e952 \ + --hash=sha256:414802f3b97f3c1eef41e530aaba3b3c1620649871d8cb38c6eaff034c2e16bd \ + --hash=sha256:51c1e14eb1e154ebd80e860722f9e6ed6ec89714ad2db2d3aa33c31d7c12179b \ + --hash=sha256:51c55fe3451421f3a6ef9a9c1439e82101c57a2c9eab9feb196a62b1a10b58ce \ + --hash=sha256:5ee6609ac3604fa7780e30a03e5e241a7956f8e2fcfe547d51e3afa5247ac47f \ + --hash=sha256:612a95a17655e213502f60cfb9bf9408efdc9eb1d5f50535cc6eb365d11b42b5 \ + --hash=sha256:6203fdf9f3dc5bdaed7319ad8698e685c7a3be10819f41d32a0723e611733b42 \ + --hash=sha256:63c0e9e7eea69588479ebf4a8a270d5ac22763cc5854e9a7eae952a3908103f7 \ + --hash=sha256:66f85ce62c70b843bab1fb14a05d5737741e74e28c7b8b5a064de10142fad248 \ + --hash=sha256:6cf9b429b21df6b99f4dee7a1218b8b7ffbbe7df8764dc0bd60ce8a0708fed1e \ + --hash=sha256:70b37199913c1bd300ff6e2693316c6f869c7ee16378faf10e4f5e3275b299c3 \ + --hash=sha256:727fd05b57df37dc0bcf1a27767a3d9a78cbbc92822445f32cc3436ba797337b \ + --hash=sha256:74ae7b798248fe62021dbf3c914245ad45d1a6b0cb4a29ecb4b31d0bfbc4cc3e \ + --hash=sha256:784db1dcdab56bf0517743e746dfb0f885fc68d948aba86eeec2cba234bdf1c0 \ + --hash=sha256:86945f2ee6d10cdfd67bcb4069c1662dd711f7e2a4343db5cecec06b87cf31aa \ + --hash=sha256:86d835afea1eaa143012a2d7a3f45a3adce2d7adc8b4961f0b362214d800846a \ + --hash=sha256:872a5cf366aec6bb1147336480fef14c9164b154aeb6542327de4970282cd2f5 \ + --hash=sha256:8b973c57ff8e184109db042c842423ff4f60446239bd585a5131cc47f06f789d \ + --hash=sha256:8cba086a43d54ca804ce711b2a940b16e452807acebe7852ff327f1ecd49b0d4 \ + --hash=sha256:8f7f0e05112916223d3f438f293abf0727e1181b5983f413dfa2fefc4098245c \ + --hash=sha256:900218e456384ea676e24ea6a0417f030a3b07306d29d7ad843957b40a9d8d52 \ + --hash=sha256:93eebbcf1aafdf7e2ddd44c2923e2672e1010bddc014138b229e49725b4d6be5 \ + --hash=sha256:9c75442b2209b8470d6d5d8b1c25714270686f14c749028d2199c54e29f20b4d \ + --hash=sha256:9ee2197ef8c4f0dfe405d835f3b6a14f5fee7782b5de51ba06fb65fc9b36e9f1 \ + --hash=sha256:a414504bef8945eae5f2d7cb7be2d4af77c5d1cb5e20b296c2c25b61dff2900c \ + --hash=sha256:a4b9159734b326535f4dd01d947f919c6eefd2d9827466a696c44ced82dfbc18 \ + --hash=sha256:a80afd79f45f3c4a7d341f13acbe058d1ca8ac017c165d3fa0d3de6bc1a079d7 \ + --hash=sha256:aa5bc7c5d59d831d9773d1170acac7893ce3a5e130540605770ade83280e7188 \ + --hash=sha256:acfd89508504a19ed06ef963ad544ec6664518c863436306153e13e94605c218 \ + --hash=sha256:aeffcab3d4b43712bb7a60b65f6044d444e75e563ff6180af8f98dd4b905dfd2 \ + --hash=sha256:afaffc4393205524af9dfa400fa250143a6c3bc646c08c9f5e25a9f4b4d6a903 \ + --hash=sha256:b0c7088a73aef3d687c4deef8452a3ac7c1be4e29ed8bf3b366c8111128ac60c \ + --hash=sha256:b46b4ec24f7293f23adcd2d146960559aaf8020213de8ad1909dba6c013bf89c \ + --hash=sha256:b501b5fa195cc9e24fe102f21ec0a44dffc231d2af79950b451e0d99cea02234 \ + --hash=sha256:bf06bc2af43fa8d32d30fae16ad965663e966b1a3202ed407b84c989c3221e82 \ + --hash=sha256:c804e3a5aba5460c73955c955bdbd5c08c354954e9270a2c1565f62e866bdc39 \ + --hash=sha256:c8a9958e88b65c3b27e22ca2a076311636850b612d6bbfb76e8d156aacde2aaf \ + --hash=sha256:cc0a57f895b96ec78969c34f682c602bf8da1a0270b09bc65673df2e7638ec20 \ + --hash=sha256:cc8920d2ec5fa99875b670bb86ddeb21e295cb07aa331810d9e486e0b969d946 \ + --hash=sha256:ccc933afd4d20aad3c00bcef049cb40049f7f196e0397f1109dba6fed63267b0 \ + --hash=sha256:ce581db493ea1a96c0556360ede6607496e8bf9b3a8efa66e06477267bc831e9 \ + --hash=sha256:d0f23b44f57077c1ede8c5f26b30f706498b4862d3ff0a7298b8411dd2f043ff \ + --hash=sha256:d21644de1b609825ede2f48be98dfde4656aefc713654eeee280e37cadc4e0ad \ + --hash=sha256:d6889ec4ec662a1a37eb4b4fb26b6100841804dac55bd9df579e326cdc146227 \ + --hash=sha256:de5672f4a7b200c15a4127042170a694d4df43c992948f5e1af57f0174beed10 \ + --hash=sha256:e6a0bc88393d65807d751a614207b7129a310ca4fe76a74e5c7da5fa5671417e \ + --hash=sha256:ed89927b86296067b4f81f108a2271d8926467a8868e554eaf370fc27fa3ccaf \ + --hash=sha256:ee3888d9ff7c14604052b2ca5535a30216aa0a58e948cdd3eeb8d3415f638769 \ + --hash=sha256:f0963b55cdd70fad460fa4c1341f12f976bb26cb66021a5580329bd498988310 \ + --hash=sha256:f16417ec91f12f814b10bafe79ef77e70113a2f5f7018640e7425ff979253425 \ + --hash=sha256:f28620fe26bee16243be2b7b874da327312240a7cdc38b769a697578d2100013 \ + --hash=sha256:f4255143f5160d0de972d28c8f9665d882b5f61309d8362fdd3e103cf7bf010c \ + --hash=sha256:ffac52f28a7849ad7576293c0cb7b9f08304e8f7d738a8cb8a90ec4c55a998eb \ + --hash=sha256:ffe22d2b05504f786c867c8395de703937f934272eb67586817b46188b4ded6d \ + --hash=sha256:fffe29a1ef00883599d1dc2c51aa2e5d80afe49523c261a74933df395c15c520 + # via + # -r build/freethreading-requirements.txt + # contourpy + # jaxlib + # matplotlib + # ml-dtypes + # numpy-typing-compat + # optype + # scipy +numpy-typing-compat==20251206.2.4 \ + --hash=sha256:59882d23aaff054a2536da80564012cdce33487657be4d79c5925bb8705fcabc \ + --hash=sha256:a82e723bd20efaa4cf2886709d4264c144f1f2b609bda83d1545113b7e47a5b5 + # via optype +nvidia-cublas==13.1.0.3 ; sys_platform == "linux" \ + --hash=sha256:2a3b94a37def342471c59fad7856caee4926809a72dd5270155d6a31b5b277be \ + --hash=sha256:c86fc7f7ae36d7528288c5d88098edcb7b02c633d262e7ddbb86b0ad91be5df2 \ + --hash=sha256:ee8722c1f0145ab246bccb9e452153b5e0515fd094c3678df50b2a0888b8b171 + # via + # -r build/nvidia-requirements.txt + # nvidia-cudnn-cu13 + # nvidia-cusolver +nvidia-cublas-cu12==12.9.1.4 ; sys_platform == "linux" \ + --hash=sha256:1e5fee10662e6e52bd71dec533fbbd4971bb70a5f24f3bc3793e5c2e9dc640bf \ + --hash=sha256:453611eb21a7c1f2c2156ed9f3a45b691deda0440ec550860290dc901af5b4c2 \ + --hash=sha256:7a950dae01add3b415a5a5cdc4ec818fb5858263e9cca59004bb99fdbbd3a5d6 + # via + # -r build/nvidia-requirements.txt + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 +nvidia-cuda-crt==13.0.88 ; sys_platform == "linux" \ + --hash=sha256:2c8043c7c9e02492716426e9919fc78d2c5b3b2a7a768a88e952676b08aa55a4 \ + --hash=sha256:31e02c52916804ca15e31f272a96181d8fadaf40c4c82a77a6f78071a22eccf3 \ + --hash=sha256:ee2ea2a97073e02ee62bb27841f437332be2c248e3eac013df07997ada39c003 + # via + # -r build/nvidia-requirements.txt + # nvidia-cuda-nvcc +nvidia-cuda-cupti==13.0.85 ; sys_platform == "linux" \ + --hash=sha256:4eb01c08e859bf924d222250d2e8f8b8ff6d3db4721288cf35d14252a4d933c8 \ + --hash=sha256:683f58d301548deeefcb8f6fac1b8d907691b9d8b18eccab417f51e362102f00 \ + --hash=sha256:796bd679890ee55fb14a94629b698b6db54bcfd833d391d5e94017dd9d7d3151 + # via -r build/nvidia-requirements.txt +nvidia-cuda-cupti-cu12==12.9.79 ; sys_platform == "linux" \ + --hash=sha256:096bcf334f13e1984ba36685ad4c1d6347db214de03dbb6eebb237b41d9d934f \ + --hash=sha256:1848a9380067560d5bee10ed240eecc22991713e672c0515f9c3d9396adf93c8 \ + --hash=sha256:791853b030602c6a11d08b5578edfb957cadea06e9d3b26adbf8d036135a4afe + # via -r build/nvidia-requirements.txt +nvidia-cuda-nvcc==13.0.88 ; sys_platform == "linux" \ + --hash=sha256:56fe502eb77625a12f25172caa3cdddb4e4c8ba2c8c17dba44b164761b380f03 \ + --hash=sha256:7c3a32c8ca9866addfd784da363ddee2f6874d560027a296f583e86a61f2d543 \ + --hash=sha256:c7ff28f86a24effdc6c034fa15230c549a273e4771b10a7fec14996f8cf3307f + # via -r build/nvidia-requirements.txt +nvidia-cuda-nvcc-cu12==12.9.86 ; sys_platform == "linux" \ + --hash=sha256:44e1eca4d08926193a558d2434b1bf83d57b4d5743e0c431c0c83d51da1df62b \ + --hash=sha256:5d6a0d32fdc7ea39917c20065614ae93add6f577d840233237ff08e9a38f58f0 \ + --hash=sha256:8ed7f0b17dea662755395be029376db3b94fed5cbb17c2d35cc866c5b1b84099 + # via -r build/nvidia-requirements.txt +nvidia-cuda-nvrtc==13.0.88 ; sys_platform == "linux" \ + --hash=sha256:6bcd4e7f8e205cbe644f5a98f2f799bef9556fefc89dd786e79a16312ce49872 \ + --hash=sha256:ad9b6d2ead2435f11cbb6868809d2adeeee302e9bb94bcf0539c7a40d80e8575 \ + --hash=sha256:d27f20a0ca67a4bb34268a5e951033496c5b74870b868bacd046b1b8e0c3267b + # via -r build/nvidia-requirements.txt +nvidia-cuda-nvrtc-cu12==12.9.86 ; sys_platform == "linux" \ + --hash=sha256:096d4de6bda726415dfaf3198d4f5c522b8e70139c97feef5cd2ca6d4cd9cead \ + --hash=sha256:210cf05005a447e29214e9ce50851e83fc5f4358df8b453155d5e1918094dcb4 \ + --hash=sha256:72972ebdcf504d69462d3bcd67e7b81edd25d0fb85a2c46d3ea3517666636349 + # via -r build/nvidia-requirements.txt +nvidia-cuda-runtime==13.0.96 ; sys_platform == "linux" \ + --hash=sha256:7f82250d7782aa23b6cfe765ecc7db554bd3c2870c43f3d1821f1d18aebf0548 \ + --hash=sha256:ef9bcbe90493a2b9d810e43d249adb3d02e98dd30200d86607d8d02687c43f55 \ + --hash=sha256:f79298c8a098cec150a597c8eba58ecdab96e3bdc4b9bc4f9983635031740492 + # via + # -r build/nvidia-requirements.txt + # nvidia-cuda-nvcc +nvidia-cuda-runtime-cu12==12.9.79 ; sys_platform == "linux" \ + --hash=sha256:25bba2dfb01d48a9b59ca474a1ac43c6ebf7011f1b0b8cc44f54eb6ac48a96c3 \ + --hash=sha256:83469a846206f2a733db0c42e223589ab62fd2fabac4432d2f8802de4bded0a4 \ + --hash=sha256:8e018af8fa02363876860388bd10ccb89eb9ab8fb0aa749aaf58430a9f7c4891 + # via -r build/nvidia-requirements.txt +nvidia-cudnn-cu12==9.16.0.29 ; sys_platform == "linux" \ + --hash=sha256:142e2bd646a4573ab17d61a24c6359155cdfe1f34c67fc305b71222a7ae45b8e \ + --hash=sha256:4b09c43096db582f110c5572d0bcbd98b30d709e860a8f73c6c3846baa83b8d2 \ + --hash=sha256:78d05b4434dacc7dd9bc903d5c33a2f28a5f0064d02568ef7b2418f89f6c5922 + # via -r build/nvidia-requirements.txt +nvidia-cudnn-cu13==9.16.0.29 ; sys_platform == "linux" \ + --hash=sha256:6349bc8769369a91611c5e2ce5c2e510e61848c245099c31e870d2cdce0ab90d \ + --hash=sha256:79dc1bfe8c1a780cf4eb7b334d14d7927576d6dd8823f8e2769911af30fd4da3 \ + --hash=sha256:faafa46e2e7dd844bbcf06b6adec3fa66924987f2fb21bf67f5c6fd697c74a64 + # via -r build/nvidia-requirements.txt +nvidia-cufft==12.0.0.61 ; sys_platform == "linux" \ + --hash=sha256:2708c852ef8cd89d1d2068bdbece0aa188813a0c934db3779b9b1faa8442e5f5 \ + --hash=sha256:2abce5b39d2f5ae12730fb7e5db6696533e36c26e2d3e8fd1750bdd2853364eb \ + --hash=sha256:6c44f692dce8fd5ffd3e3df134b6cdb9c2f72d99cf40b62c32dde45eea9ddad3 + # via -r build/nvidia-requirements.txt +nvidia-cufft-cu12==11.4.1.4 ; sys_platform == "linux" \ + --hash=sha256:1a28c9b12260a1aa7a8fd12f5ebd82d027963d635ba82ff39a1acfa7c4c0fbcf \ + --hash=sha256:8e5bfaac795e93f80611f807d42844e8e27e340e0cde270dcb6c65386d795b80 \ + --hash=sha256:c67884f2a7d276b4b80eb56a79322a95df592ae5e765cf1243693365ccab4e28 + # via -r build/nvidia-requirements.txt +nvidia-cusolver==12.0.4.66 ; sys_platform == "linux" \ + --hash=sha256:02c2457eaa9e39de20f880f4bd8820e6a1cfb9f9a34f820eb12a155aa5bc92d2 \ + --hash=sha256:0a759da5dea5c0ea10fd307de75cdeb59e7ea4fcb8add0924859b944babf1112 \ + --hash=sha256:16515bd33a8e76bb54d024cfa068fa68d30e80fc34b9e1090813ea9362e0cb65 + # via -r build/nvidia-requirements.txt +nvidia-cusolver-cu12==11.7.5.82 ; sys_platform == "linux" \ + --hash=sha256:15da72d1340d29b5b3cf3fd100e3cd53421dde36002eda6ed93811af63c40d88 \ + --hash=sha256:62efa83e4ace59a4c734d052bb72158e888aa7b770e1a5f601682f16fe5b4fd2 \ + --hash=sha256:77666337237716783c6269a658dea310195cddbd80a5b2919b1ba8735cec8efd + # via -r build/nvidia-requirements.txt +nvidia-cusparse==12.6.3.3 ; sys_platform == "linux" \ + --hash=sha256:2b3c89c88d01ee0e477cb7f82ef60a11a4bcd57b6b87c33f789350b59759360b \ + --hash=sha256:80bcc4662f23f1054ee334a15c72b8940402975e0eab63178fc7e670aa59472c \ + --hash=sha256:cbcf42feb737bd7ec15b4c0a63e62351886bd3f975027b8815d7f720a2b5ea79 + # via + # -r build/nvidia-requirements.txt + # nvidia-cusolver +nvidia-cusparse-cu12==12.5.10.65 ; sys_platform == "linux" \ + --hash=sha256:221c73e7482dd93eda44e65ce567c031c07e2f93f6fa0ecd3ba876a195023e83 \ + --hash=sha256:73060ce019ac064a057267c585bf1fd5a353734151f87472ff02b2c5c9984e78 \ + --hash=sha256:9e487468a22a1eaf1fbd1d2035936a905feb79c4ce5c2f67626764ee4f90227c + # via + # -r build/nvidia-requirements.txt + # nvidia-cusolver-cu12 +nvidia-nccl-cu12==2.28.9 ; sys_platform == "linux" \ + --hash=sha256:485776daa8447da5da39681af455aa3b2c2586ddcf4af8772495e7c532c7e5ab \ + --hash=sha256:50a36e01c4a090b9f9c47d92cec54964de6b9fcb3362d0e19b8ffc6323c21b60 + # via -r build/nvidia-requirements.txt +nvidia-nccl-cu13==2.28.9 ; sys_platform == "linux" \ + --hash=sha256:01c873ba1626b54caa12272ed228dc5b2781545e0ae8ba3f432a8ef1c6d78643 \ + --hash=sha256:e4553a30f34195f3fa1da02a6da3d6337d28f2003943aa0a3d247bbc25fefc42 + # via -r build/nvidia-requirements.txt +nvidia-nvjitlink==13.0.88 ; sys_platform == "linux" \ + --hash=sha256:13a74f429e23b921c1109976abefacc69835f2f433ebd323d3946e11d804e47b \ + --hash=sha256:634e96e3da9ef845ae744097a1f289238ecf946ce0b82e93cdce14b9782e682f \ + --hash=sha256:e931536ccc7d467a98ba1d8b89ff7fa7f1fa3b13f2b0069118cd7f47bff07d0c + # via + # -r build/nvidia-requirements.txt + # nvidia-cufft + # nvidia-cusolver + # nvidia-cusparse +nvidia-nvjitlink-cu12==12.9.86 ; sys_platform == "linux" \ + --hash=sha256:994a05ef08ef4b0b299829cde613a424382aff7efb08a7172c1fa616cc3af2ca \ + --hash=sha256:cc6fcec260ca843c10e34c936921a1c426b351753587fdd638e8cff7b16bb9db \ + --hash=sha256:e3f1171dbdc83c5932a45f0f4c99180a70de9bd2718c1ab77d14104f6d7147f9 + # via + # -r build/nvidia-requirements.txt + # nvidia-cufft-cu12 + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 +nvidia-nvshmem-cu12==3.4.5 ; sys_platform == "linux" \ + --hash=sha256:042f2500f24c021db8a06c5eec2539027d57460e1c1a762055a6554f72c369bd \ + --hash=sha256:0b48363fc6964dede448029434c6abed6c5e37f823cb43c3bcde7ecfc0457e15 + # via -r build/nvidia-requirements.txt +nvidia-nvshmem-cu13==3.4.5 ; sys_platform == "linux" \ + --hash=sha256:290f0a2ee94c9f3687a02502f3b9299a9f9fe826e6d0287ee18482e78d495b80 \ + --hash=sha256:6dc2a197f38e5d0376ad52cd1a2a3617d3cdc150fd5966f4aee9bcebb1d68fe9 + # via -r build/nvidia-requirements.txt +nvidia-nvvm==13.0.88 ; sys_platform == "linux" \ + --hash=sha256:2ef0db7849e476d3b2fc3c09b27bdd79bd7ea8ce58cd9c86553d64ea40844ba0 \ + --hash=sha256:c4376a291d72d22a315d9d2f69bdae8f8cd83a627f75bad395cee49a0fe65dc1 \ + --hash=sha256:c5f41ffeb6466944a026dfa5317d7d85355c119bbec279205d22f1869d1054e0 + # via + # -r build/nvidia-requirements.txt + # nvidia-cuda-nvcc +opt-einsum==3.4.0 \ + --hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \ + --hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac + # via -r build/requirements.in +optype[numpy]==0.15.0 \ + --hash=sha256:457d6ca9e7da19967ec16d42bdf94e240b33b5d70a56fbbf5b427e5ea39cf41e \ + --hash=sha256:caba40ece9ea39b499fa76c036a82e0d452a432dd4dd3e8e0d30892be2e8c76c + # via scipy-stubs +packaging==25.0 \ + --hash=sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484 \ + --hash=sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f + # via + # auditwheel + # build + # matplotlib + # pytest + # wheel +pillow==12.0.0 \ + --hash=sha256:0869154a2d0546545cde61d1789a6524319fc1897d9ee31218eae7a60ccc5643 \ + --hash=sha256:09f2d0abef9e4e2f349305a4f8cc784a8a6c2f58a8c4892eea13b10a943bd26e \ + --hash=sha256:0b817e7035ea7f6b942c13aa03bb554fc44fea70838ea21f8eb31c638326584e \ + --hash=sha256:0fd00cac9c03256c8b2ff58f162ebcd2587ad3e1f2e397eab718c47e24d231cc \ + --hash=sha256:110486b79f2d112cf6add83b28b627e369219388f64ef2f960fef9ebaf54c642 \ + --hash=sha256:1979f4566bb96c1e50a62d9831e2ea2d1211761e5662afc545fa766f996632f6 \ + --hash=sha256:1ac11e8ea4f611c3c0147424eae514028b5e9077dd99ab91e1bd7bc33ff145e1 \ + --hash=sha256:1b1b133e6e16105f524a8dec491e0586d072948ce15c9b914e41cdadd209052b \ + --hash=sha256:1ee80a59f6ce048ae13cda1abf7fbd2a34ab9ee7d401c46be3ca685d1999a399 \ + --hash=sha256:21f241bdd5080a15bc86d3466a9f6074a9c2c2b314100dd896ac81ee6db2f1ba \ + --hash=sha256:266cd5f2b63ff316d5a1bba46268e603c9caf5606d44f38c2873c380950576ad \ + --hash=sha256:26d9f7d2b604cd23aba3e9faf795787456ac25634d82cd060556998e39c6fa47 \ + --hash=sha256:27f95b12453d165099c84f8a8bfdfd46b9e4bda9e0e4b65f0635430027f55739 \ + --hash=sha256:2c54c1a783d6d60595d3514f0efe9b37c8808746a66920315bfd34a938d7994b \ + --hash=sha256:2fa5f0b6716fc88f11380b88b31fe591a06c6315e955c096c35715788b339e3f \ + --hash=sha256:32ed80ea8a90ee3e6fa08c21e2e091bba6eda8eccc83dbc34c95169507a91f10 \ + --hash=sha256:3830c769decf88f1289680a59d4f4c46c72573446352e2befec9a8512104fa52 \ + --hash=sha256:38df9b4bfd3db902c9c2bd369bcacaf9d935b2fff73709429d95cc41554f7b3d \ + --hash=sha256:3adfb466bbc544b926d50fe8f4a4e6abd8c6bffd28a26177594e6e9b2b76572b \ + --hash=sha256:3e42edad50b6909089750e65c91aa09aaf1e0a71310d383f11321b27c224ed8a \ + --hash=sha256:4078242472387600b2ce8d93ade8899c12bf33fa89e55ec89fe126e9d6d5d9e9 \ + --hash=sha256:455247ac8a4cfb7b9bc45b7e432d10421aea9fc2e74d285ba4072688a74c2e9d \ + --hash=sha256:4cc6b3b2efff105c6a1656cfe59da4fdde2cda9af1c5e0b58529b24525d0a098 \ + --hash=sha256:4cf7fed4b4580601c4345ceb5d4cbf5a980d030fd5ad07c4d2ec589f95f09905 \ + --hash=sha256:5193fde9a5f23c331ea26d0cf171fbf67e3f247585f50c08b3e205c7aeb4589b \ + --hash=sha256:5269cc1caeedb67e6f7269a42014f381f45e2e7cd42d834ede3c703a1d915fe3 \ + --hash=sha256:53561a4ddc36facb432fae7a9d8afbfaf94795414f5cdc5fc52f28c1dca90371 \ + --hash=sha256:55f818bd74fe2f11d4d7cbc65880a843c4075e0ac7226bc1a23261dbea531953 \ + --hash=sha256:58eea5ebe51504057dd95c5b77d21700b77615ab0243d8152793dc00eb4faf01 \ + --hash=sha256:5d5c411a8eaa2299322b647cd932586b1427367fd3184ffbb8f7a219ea2041ca \ + --hash=sha256:6846bd2d116ff42cba6b646edf5bf61d37e5cbd256425fa089fee4ff5c07a99e \ + --hash=sha256:6ace95230bfb7cd79ef66caa064bbe2f2a1e63d93471c3a2e1f1348d9f22d6b7 \ + --hash=sha256:6e51b71417049ad6ab14c49608b4a24d8fb3fe605e5dfabfe523b58064dc3d27 \ + --hash=sha256:71db6b4c1653045dacc1585c1b0d184004f0d7e694c7b34ac165ca70c0838082 \ + --hash=sha256:7438839e9e053ef79f7112c881cef684013855016f928b168b81ed5835f3e75e \ + --hash=sha256:759de84a33be3b178a64c8ba28ad5c135900359e85fb662bc6e403ad4407791d \ + --hash=sha256:792a2c0be4dcc18af9d4a2dfd8a11a17d5e25274a1062b0ec1c2d79c76f3e7f8 \ + --hash=sha256:7d87ef5795da03d742bf49439f9ca4d027cde49c82c5371ba52464aee266699a \ + --hash=sha256:7dfb439562f234f7d57b1ac6bc8fe7f838a4bd49c79230e0f6a1da93e82f1fad \ + --hash=sha256:7fa22993bac7b77b78cae22bad1e2a987ddf0d9015c63358032f84a53f23cdc3 \ + --hash=sha256:805ebf596939e48dbb2e4922a1d3852cfc25c38160751ce02da93058b48d252a \ + --hash=sha256:82240051c6ca513c616f7f9da06e871f61bfd7805f566275841af15015b8f98d \ + --hash=sha256:87d4f8125c9988bfbed67af47dd7a953e2fc7b0cc1e7800ec6d2080d490bb353 \ + --hash=sha256:8d8ca2b210ada074d57fcee40c30446c9562e542fc46aedc19baf758a93532ee \ + --hash=sha256:8dc232e39d409036af549c86f24aed8273a40ffa459981146829a324e0848b4b \ + --hash=sha256:90387104ee8400a7b4598253b4c406f8958f59fcf983a6cea2b50d59f7d63d0b \ + --hash=sha256:905b0365b210c73afb0ebe9101a32572152dfd1c144c7e28968a331b9217b94a \ + --hash=sha256:99353a06902c2e43b43e8ff74ee65a7d90307d82370604746738a1e0661ccca7 \ + --hash=sha256:99a7f72fb6249302aa62245680754862a44179b545ded638cf1fef59befb57ef \ + --hash=sha256:9f0b04c6b8584c2c193babcccc908b38ed29524b29dd464bc8801bf10d746a3a \ + --hash=sha256:9fe611163f6303d1619bbcb653540a4d60f9e55e622d60a3108be0d5b441017a \ + --hash=sha256:a3475b96f5908b3b16c47533daaa87380c491357d197564e0ba34ae75c0f3257 \ + --hash=sha256:a6597ff2b61d121172f5844b53f21467f7082f5fb385a9a29c01414463f93b07 \ + --hash=sha256:a7921c5a6d31b3d756ec980f2f47c0cfdbce0fc48c22a39347a895f41f4a6ea4 \ + --hash=sha256:aa5129de4e174daccbc59d0a3b6d20eaf24417d59851c07ebb37aeb02947987c \ + --hash=sha256:aeaefa96c768fc66818730b952a862235d68825c178f1b3ffd4efd7ad2edcb7c \ + --hash=sha256:afbefa430092f71a9593a99ab6a4e7538bc9eabbf7bf94f91510d3503943edc4 \ + --hash=sha256:aff9e4d82d082ff9513bdd6acd4f5bd359f5b2c870907d2b0a9c5e10d40c88fe \ + --hash=sha256:b22bd8c974942477156be55a768f7aa37c46904c175be4e158b6a86e3a6b7ca8 \ + --hash=sha256:b290fd8aa38422444d4b50d579de197557f182ef1068b75f5aa8558638b8d0a5 \ + --hash=sha256:b2e4b27a6e15b04832fe9bf292b94b5ca156016bbc1ea9c2c20098a0320d6cf6 \ + --hash=sha256:b583dc9070312190192631373c6c8ed277254aa6e6084b74bdd0a6d3b221608e \ + --hash=sha256:b87843e225e74576437fd5b6a4c2205d422754f84a06942cfaf1dc32243e45a8 \ + --hash=sha256:bc91a56697869546d1b8f0a3ff35224557ae7f881050e99f615e0119bf934b4e \ + --hash=sha256:bd87e140e45399c818fac4247880b9ce719e4783d767e030a883a970be632275 \ + --hash=sha256:bde737cff1a975b70652b62d626f7785e0480918dece11e8fef3c0cf057351c3 \ + --hash=sha256:bdee52571a343d721fb2eb3b090a82d959ff37fc631e3f70422e0c2e029f3e76 \ + --hash=sha256:bee2a6db3a7242ea309aa7ee8e2780726fed67ff4e5b40169f2c940e7eb09227 \ + --hash=sha256:beeae3f27f62308f1ddbcfb0690bf44b10732f2ef43758f169d5e9303165d3f9 \ + --hash=sha256:c50f36a62a22d350c96e49ad02d0da41dbd17ddc2e29750dbdba4323f85eb4a5 \ + --hash=sha256:c607c90ba67533e1b2355b821fef6764d1dd2cbe26b8c1005ae84f7aea25ff79 \ + --hash=sha256:c7b2a63fd6d5246349f3d3f37b14430d73ee7e8173154461785e43036ffa96ca \ + --hash=sha256:c828a1ae702fc712978bda0320ba1b9893d99be0badf2647f693cc01cf0f04fa \ + --hash=sha256:c85de1136429c524e55cfa4e033b4a7940ac5c8ee4d9401cc2d1bf48154bbc7b \ + --hash=sha256:c98fa880d695de164b4135a52fd2e9cd7b7c90a9d8ac5e9e443a24a95ef9248e \ + --hash=sha256:cae81479f77420d217def5f54b5b9d279804d17e982e0f2fa19b1d1e14ab5197 \ + --hash=sha256:d034140032870024e6b9892c692fe2968493790dd57208b2c37e3fb35f6df3ab \ + --hash=sha256:d120c38a42c234dc9a8c5de7ceaaf899cf33561956acb4941653f8bdc657aa79 \ + --hash=sha256:d4827615da15cd59784ce39d3388275ec093ae3ee8d7f0c089b76fa87af756c2 \ + --hash=sha256:d49e2314c373f4c2b39446fb1a45ed333c850e09d0c59ac79b72eb3b95397363 \ + --hash=sha256:d52610d51e265a51518692045e372a4c363056130d922a7351429ac9f27e70b0 \ + --hash=sha256:d64317d2587c70324b79861babb9c09f71fbb780bad212018874b2c013d8600e \ + --hash=sha256:d77153e14b709fd8b8af6f66a3afbb9ed6e9fc5ccf0b6b7e1ced7b036a228782 \ + --hash=sha256:d7e091d464ac59d2c7ad8e7e08105eaf9dafbc3883fd7265ffccc2baad6ac925 \ + --hash=sha256:dd333073e0cacdc3089525c7df7d39b211bcdf31fc2824e49d01c6b6187b07d0 \ + --hash=sha256:e5d8efac84c9afcb40914ab49ba063d94f5dbdf5066db4482c66a992f47a3a3b \ + --hash=sha256:f135c702ac42262573fe9714dfe99c944b4ba307af5eb507abef1667e2cbbced \ + --hash=sha256:f13711b1a5ba512d647a0e4ba79280d3a9a045aaf7e0cc6fbe96b91d4cdf6b0c \ + --hash=sha256:f4f1231b7dec408e8670264ce63e9c71409d9583dd21d32c163e25213ee2a344 \ + --hash=sha256:fa3ed2a29a9e9d2d488b4da81dcb54720ac3104a20bf0bd273f1e4648aff5af9 \ + --hash=sha256:fb3096c30df99fd01c7bf8e544f392103d0795b9f98ba71a8054bcbf56b255f1 + # via + # -r build/test-requirements.txt + # matplotlib +pluggy==1.6.0 \ + --hash=sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3 \ + --hash=sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746 + # via pytest +portpicker==1.6.0 \ + --hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \ + --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa + # via -r build/test-requirements.txt +psutil==7.1.3 \ + --hash=sha256:0005da714eee687b4b8decd3d6cc7c6db36215c9e74e5ad2264b90c3df7d92dc \ + --hash=sha256:1068c303be3a72f8e18e412c5b2a8f6d31750fb152f9cb106b54090296c9d251 \ + --hash=sha256:18349c5c24b06ac5612c0428ec2a0331c26443d259e2a0144a9b24b4395b58fa \ + --hash=sha256:19644c85dcb987e35eeeaefdc3915d059dac7bd1167cdcdbf27e0ce2df0c08c0 \ + --hash=sha256:2bdbcd0e58ca14996a42adf3621a6244f1bb2e2e528886959c72cf1e326677ab \ + --hash=sha256:31d77fcedb7529f27bb3a0472bea9334349f9a04160e8e6e5020f22c59893264 \ + --hash=sha256:3792983e23b69843aea49c8f5b8f115572c5ab64c153bada5270086a2123c7e7 \ + --hash=sha256:3bb428f9f05c1225a558f53e30ccbad9930b11c3fc206836242de1091d3e7dd3 \ + --hash=sha256:56d974e02ca2c8eb4812c3f76c30e28836fffc311d55d979f1465c1feeb2b68b \ + --hash=sha256:6c86281738d77335af7aec228328e944b30930899ea760ecf33a4dba66be5e74 \ + --hash=sha256:8f33a3702e167783a9213db10ad29650ebf383946e91bc77f28a5eb083496bc9 \ + --hash=sha256:95ef04cf2e5ba0ab9eaafc4a11eaae91b44f4ef5541acd2ee91d9108d00d59a7 \ + --hash=sha256:ad81425efc5e75da3f39b3e636293360ad8d0b49bed7df824c79764fb4ba9b8b \ + --hash=sha256:b403da1df4d6d43973dc004d19cee3b848e998ae3154cc8097d139b77156c353 \ + --hash=sha256:bc31fa00f1fbc3c3802141eede66f3a2d51d89716a194bf2cd6fc68310a19880 \ + --hash=sha256:bd0d69cee829226a761e92f28140bec9a5ee9d5b4fb4b0cc589068dbfff559b1 \ + --hash=sha256:c525ffa774fe4496282fb0b1187725793de3e7c6b29e41562733cae9ada151ee \ + --hash=sha256:f39c2c19fe824b47484b96f9692932248a54c43799a84282cfe58d05a6449efd \ + --hash=sha256:fac9cd332c67f4422504297889da5ab7e05fd11e3c4392140f7370f4208ded1f + # via portpicker +pyelftools==0.32 \ + --hash=sha256:013df952a006db5e138b1edf6d8a68ecc50630adbd0d83a2d41e7f846163d738 \ + --hash=sha256:6de90ee7b8263e740c8715a925382d4099b354f29ac48ea40d840cf7aa14ace5 + # via auditwheel +pygments==2.19.2 \ + --hash=sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887 \ + --hash=sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b + # via + # pytest + # rich +pyparsing==3.2.5 \ + --hash=sha256:2df8d5b7b2802ef88e8d016a2eb9c7aeaa923529cd251ed0fe4608275d4105b6 \ + --hash=sha256:e38a4f02064cf41fe6593d328d0512495ad1f3d8a91c4f73fc401b3079a59a5e + # via matplotlib +pyproject-hooks==1.2.0 \ + --hash=sha256:1e859bd5c40fae9448642dd871adf459e5e2084186e8d2c2a79a824c970da1f8 \ + --hash=sha256:9e5c6bfa8dcc30091c74b0cf803c81fdd29d94f01992a7707bc97babb1141913 + # via build +pytest==8.4.2 \ + --hash=sha256:86c0d0b93306b961d58d62a4db4879f27fe25513d4b969df351abdddb3c30e01 \ + --hash=sha256:872f880de3fc3a5bdc88a11b39c9710c3497a547cfa9320bc3c5e62fbf272e79 + # via + # -r build/test-requirements.txt + # pytest-xdist +pytest-xdist==3.8.0 \ + --hash=sha256:202ca578cfeb7370784a8c33d6d05bc6e13b4f25b5053c30a152269fd10f0b88 \ + --hash=sha256:7e578125ec9bc6050861aa93f2d59f1d8d085595d6551c2c90b6f4fad8d3a9f1 + # via -r build/test-requirements.txt +python-dateutil==2.9.0.post0 \ + --hash=sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3 \ + --hash=sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427 + # via matplotlib +rich==14.2.0 \ + --hash=sha256:73ff50c7c0c1c77c8243079283f4edb376f0f6442433aecb8ce7e6d0b92d1fe4 \ + --hash=sha256:76bc51fe2e57d2b1be1f96c524b890b816e334ab4c1e45888799bfaab0021edd + # via -r build/test-requirements.txt +scipy==1.16.3 ; python_version >= "3.13" \ + --hash=sha256:0151a0749efeaaab78711c78422d413c583b8cdd2011a3c1d6c794938ee9fdb2 \ + --hash=sha256:01e87659402762f43bd2fee13370553a17ada367d42e7487800bf2916535aecb \ + --hash=sha256:03192a35e661470197556de24e7cb1330d84b35b94ead65c46ad6f16f6b28f2a \ + --hash=sha256:0553371015692a898e1aa858fed67a3576c34edefa6b7ebdb4e9dde49ce5c203 \ + --hash=sha256:062246acacbe9f8210de8e751b16fc37458213f124bef161a5a02c7a39284304 \ + --hash=sha256:0c3b4dd3d9b08dbce0f3440032c52e9e2ab9f96ade2d3943313dfe51a7056959 \ + --hash=sha256:0c623a54f7b79dd88ef56da19bc2873afec9673a48f3b85b18e4d402bdd29a5a \ + --hash=sha256:16b8bc35a4cc24db80a0ec836a9286d0e31b2503cb2fd7ff7fb0e0374a97081d \ + --hash=sha256:1fb2472e72e24d1530debe6ae078db70fb1605350c88a3d14bc401d6306dbffe \ + --hash=sha256:21d9d6b197227a12dcbf9633320a4e34c6b0e51c57268df255a0942983bac562 \ + --hash=sha256:2a207a6ce9c24f1951241f4693ede2d393f59c07abc159b2cb2be980820e01fb \ + --hash=sha256:2b71d93c8a9936046866acebc915e2af2e292b883ed6e2cbe5c34beb094b82d9 \ + --hash=sha256:2d1ae2cf0c350e7705168ff2429962a89ad90c2d49d1dd300686d8b2a5af22fc \ + --hash=sha256:3a4c460301fb2cffb7f88528f30b3127742cff583603aa7dc964a52c463b385d \ + --hash=sha256:3d4a07a8e785d80289dfe66b7c27d8634a773020742ec7187b85ccc4b0e7b686 \ + --hash=sha256:40be6cf99e68b6c4321e9f8782e7d5ff8265af28ef2cd56e9c9b2638fa08ad97 \ + --hash=sha256:4aff59800a3b7f786b70bfd6ab551001cb553244988d7d6b8299cb1ea653b353 \ + --hash=sha256:50a3dbf286dbc7d84f176f9a1574c705f277cb6565069f88f60db9eafdbe3ee2 \ + --hash=sha256:532fb5ad6a87e9e9cd9c959b106b73145a03f04c7d57ea3e6f6bb60b86ab0876 \ + --hash=sha256:53c3844d527213631e886621df5695d35e4f6a75f620dca412bcd292f6b87d78 \ + --hash=sha256:56edc65510d1331dae01ef9b658d428e33ed48b4f77b1d51caf479a0253f96dc \ + --hash=sha256:57d01cb6f85e34f0946b33caa66e892aae072b64b034183f3d87c4025802a119 \ + --hash=sha256:5803c5fadd29de0cf27fa08ccbfe7a9e5d741bf63e4ab1085437266f12460ff9 \ + --hash=sha256:6020470b9d00245926f2d5bb93b119ca0340f0d564eb6fbaad843eaebf9d690f \ + --hash=sha256:63d3cdacb8a824a295191a723ee5e4ea7768ca5ca5f2838532d9f2e2b3ce2135 \ + --hash=sha256:663b8d66a8748051c3ee9c96465fb417509315b99c71550fda2591d7dd634234 \ + --hash=sha256:72d1717fd3b5e6ec747327ce9bda32d5463f472c9dce9f54499e81fbd50245a1 \ + --hash=sha256:7dc1360c06535ea6116a2220f760ae572db9f661aba2d88074fe30ec2aa1ff88 \ + --hash=sha256:7f68154688c515cdb541a31ef8eb66d8cd1050605be9dcd74199cbd22ac739bc \ + --hash=sha256:81fc5827606858cf71446a5e98715ba0e11f0dbc83d71c7409d05486592a45d6 \ + --hash=sha256:875555ce62743e1d54f06cdf22c1e0bc47b91130ac40fe5d783b6dfa114beeb6 \ + --hash=sha256:8b3c820ddb80029fe9f43d61b81d8b488d3ef8ca010d15122b152db77dc94c22 \ + --hash=sha256:8be1ca9170fcb6223cc7c27f4305d680ded114a1567c0bd2bfcbf947d1b17511 \ + --hash=sha256:8d09d72dc92742988b0e7750bddb8060b0c7079606c0d24a8cc8e9c9c11f9079 \ + --hash=sha256:9452781bd879b14b6f055b26643703551320aa8d79ae064a71df55c00286a184 \ + --hash=sha256:96491a6a54e995f00a28a3c3badfff58fd093bf26cd5fb34a2188c8c756a3a2c \ + --hash=sha256:9b9c9c07b6d56a35777a1b4cc8966118fb16cfd8daf6743867d17d36cfad2d40 \ + --hash=sha256:a8a26c78ef223d3e30920ef759e25625a0ecdd0d60e5a8818b7513c3e5384cf2 \ + --hash=sha256:aadd23f98f9cb069b3bd64ddc900c4d277778242e961751f77a8cb5c4b946fb0 \ + --hash=sha256:b7180967113560cca57418a7bc719e30366b47959dd845a93206fbed693c867e \ + --hash=sha256:b7c5f1bda1354d6a19bc6af73a649f8285ca63ac6b52e64e658a5a11d4d69800 \ + --hash=sha256:b81c27fc41954319a943d43b20e07c40bdcd3ff7cf013f4fb86286faefe546c4 \ + --hash=sha256:bb61878c18a470021fb515a843dc7a76961a8daceaaaa8bad1332f1bf4b54657 \ + --hash=sha256:bea0a62734d20d67608660f69dcda23e7f90fb4ca20974ab80b6ed40df87a005 \ + --hash=sha256:c5192722cffe15f9329a3948c4b1db789fbb1f05c97899187dcf009b283aea70 \ + --hash=sha256:c97176013d404c7346bf57874eaac5187d969293bf40497140b0a2b2b7482e07 \ + --hash=sha256:cd13e354df9938598af2be05822c323e97132d5e6306b83a3b4ee6724c6e522e \ + --hash=sha256:d2ec56337675e61b312179a1ad124f5f570c00f920cc75e1000025451b88241c \ + --hash=sha256:d3837938ae715fc0fe3c39c0202de3a8853aff22ca66781ddc2ade7554b7e2cc \ + --hash=sha256:d9f48cafc7ce94cf9b15c6bffdc443a81a27bf7075cf2dcd5c8b40f85d10c4e7 \ + --hash=sha256:da7763f55885045036fabcebd80144b757d3db06ab0861415d1c3b7c69042146 \ + --hash=sha256:deb3841c925eeddb6afc1e4e4a45e418d19ec7b87c5df177695224078e8ec733 \ + --hash=sha256:e1d27cbcb4602680a49d787d90664fa4974063ac9d4134813332a8c53dbe667c \ + --hash=sha256:e5d42a9472e7579e473879a1990327830493a7047506d58d73fc429b84c1d49d \ + --hash=sha256:e7efa2681ea410b10dde31a52b18b0154d66f2485328830e45fdf183af5aefc6 \ + --hash=sha256:eab43fae33a0c39006a88096cd7b4f4ef545ea0447d250d5ac18202d40b6611d \ + --hash=sha256:f2622206f5559784fa5c4b53a950c3c7c1cf3e84ca1b9c4b6c03f062f289ca26 \ + --hash=sha256:f379b54b77a597aa7ee5e697df0d66903e41b9c85a6dd7946159e356319158e8 \ + --hash=sha256:f667a4542cc8917af1db06366d3f78a5c8e83badd56409f94d1eac8d8d9133fa \ + --hash=sha256:fb4b29f4cf8cc5a8d628bc8d8e26d12d7278cd1f219f22698a378c3d67db5e4b \ + --hash=sha256:ffa6eea95283b2b8079b821dc11f50a17d0571c92b43e2b5b12764dc5f9b285d + # via + # -r build/requirements.in + # jaxlib +scipy-stubs==1.16.3.3 \ + --hash=sha256:af47578875d5557567225a16ec1b9b38a48c4c4377d92396413ebd65406c44ee \ + --hash=sha256:f6316b36cd0fb272c994ae5b10c4a73c644a7e156ed8d32bcd9c35303d0e1b7e + # via -r build/test-requirements.txt +six==1.17.0 \ + --hash=sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274 \ + --hash=sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81 + # via python-dateutil +sortedcontainers==2.4.0 \ + --hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \ + --hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0 + # via hypothesis +typing-extensions==4.15.0 \ + --hash=sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466 \ + --hash=sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548 + # via etils +wheel==0.46.1 \ + --hash=sha256:f796f65d72750ccde090663e466d0ca37cd72b62870f7520b96d34cdc07d86d8 \ + --hash=sha256:fd477efb5da0f7df1d3c76c73c14394002c844451bd63229d8570f376f5e6a38 + # via -r build/requirements.in +zipp==3.23.0 \ + --hash=sha256:071652d6115ed432f5ce1d34c336c0adfd6a884660d1e9712a256d3d3bd4b14e \ + --hash=sha256:a07157588a12518c9d4034df3fbbee09c814741a33ff63c05fa29d26a2404166 + # via etils + +# The following packages are considered to be unsafe in a requirements file: +setuptools==80.9.0 \ + --hash=sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922 \ + --hash=sha256:f36b47402ecde768dbfafc46e8e4207b4360c654f1f3bb84475f0a28628fb19c + # via -r build/requirements.in diff --git a/build/rocm-test-requirements.txt b/build/rocm-test-requirements.txt new file mode 100644 index 000000000000..399237175957 --- /dev/null +++ b/build/rocm-test-requirements.txt @@ -0,0 +1,24 @@ +absl-py +build +cloudpickle +colorama>=0.4.4 +filelock +flatbuffers +hypothesis +mpmath>=1.3 +pillow>=10.4.0 +# TODO(kanglan): Remove once psutil from portpicker supports python 3.13t +portpicker; python_version<"3.13" +pytest-xdist +pytest-json-report +pytest-html +pytest-csv +pytest-rerunfailures +pytest-html-merger +pytest-reportlog +wheel +rich +setuptools +matplotlib +opt-einsum +auditwheel diff --git a/build/rocm/Dockerfile.ms b/build/rocm/Dockerfile.ms index a084045256de..40b4decaafb4 100644 --- a/build/rocm/Dockerfile.ms +++ b/build/rocm/Dockerfile.ms @@ -40,7 +40,7 @@ RUN --mount=type=cache,target=/var/cache/apt \ liblzma-dev # Install pyenv with different python versions -ARG PYTHON_VERSION=3.10.14 +ARG PYTHON_VERSION=3.11.13 RUN git clone https://github.com/pyenv/pyenv.git /pyenv ENV PYENV_ROOT /pyenv ENV PATH $PYENV_ROOT/shims:$PYENV_ROOT/bin:$PATH diff --git a/build/rocm/README.md b/build/rocm/README.md index 58427826f73f..f244df45f2a3 100644 --- a/build/rocm/README.md +++ b/build/rocm/README.md @@ -8,13 +8,13 @@ The ROCm JAX team provides prebuilt Docker images, which the simplest way to use To pull the latest ROCm JAX Docker image, run: ```Bash -> docker pull rocm/jax-community:latest +> docker pull rocm/jax:latest ``` Once the image is downloaded, launch a container using the following command: ```Bash -> docker run -it -d --network=host --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size 64G --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v $(pwd):/jax_dir --name rocm_jax rocm/jax-community:latest /bin/bash +> docker run -it -d --network=host --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size 64G --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v $(pwd):/jax_dir --name rocm_jax rocm/jax:latest /bin/bash > docker attach rocm_jax ``` @@ -24,7 +24,7 @@ Once the image is downloaded, launch a container using the following command: 2. Replace `$(pwd)` with the absolute path to the directory you want to mount inside the container. ***For older versions please review the periodically pushed docker images at: -[ROCm JAX Community DockerHub](https://hub.docker.com/r/rocm/jax-community/tags).*** +[ROCm JAX DockerHub](https://hub.docker.com/r/rocm/jax/tags).*** ### Testing your ROCm environment with JAX: @@ -32,7 +32,7 @@ After launching the container, test whether JAX detects ROCm devices as expected ```Bash > python -c "import jax; print(jax.devices())" -[RocmDevice(id=0), RocmDevice(id=1), RocmDevice(id=2), RocmDevice(id=3)] +[RocmDevice(id=0), RocmDevice(id=1), RocmDevice(id=2), RocmDevice(id=3), RocmDevice(id=4), RocmDevice(id=5), RocmDevice(id=6), RocmDevice(id=7)] ``` If the setup is successful, the output should list all available ROCm devices. @@ -46,7 +46,7 @@ If you prefer to use the ROCm Ubuntu image or already have a ROCm Ubuntu contain For example, use the following command to pull the ROCm Ubuntu image: ```Bash -> docker pull rocm/dev-ubuntu-22.04:6.3-complete +> docker pull rocm/dev-ubuntu-24.04:7.0.2-complete ``` ### Step 2: Launch the Docker Container @@ -54,16 +54,23 @@ For example, use the following command to pull the ROCm Ubuntu image: After pulling the image, launch a container using this command: ```Bash -> docker run -it -d --network=host --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size 64G --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v $(pwd):/jax_dir --name rocm_jax rocm/dev-ubuntu-22.04:6.3-complete /bin/bash +> docker run -it -d --network=host --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size 64G --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v $(pwd):/jax_dir --name rocm_jax rocm/dev-ubuntu-24.04:7.0.2-complete /bin/bash > docker attach rocm_jax ``` ### Step 3: Install the Latest Version of JAX -Inside the running container, install the required version of JAX with ROCm support using pip: +Install the required version of JAX and the ROCm plugins using pip. Follow the +instructions for the [latest +release](https://github.com/ROCm/rocm-jax/releases). For example, on a system +with python 3.12, you will need to run the following to install `jax 0.6.2`: ```Bash -> pip3 install jax[rocm] +> pip3 install \ + https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.6.0/jax_rocm7_pjrt-0.6.0-py3-none-manylinux_2_28_x86_64.whl \ + https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.6.0/jax_rocm7_plugin-0.6.0-cp312-cp312-manylinux_2_28_x86_64.whl \ + https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.6.0/jaxlib-0.6.2-cp312-cp312-manylinux2014_x86_64.whl \ + jax==0.6.2 ``` ### Step 4: Verify the Installed JAX Version @@ -72,10 +79,10 @@ Check whether the correct version of JAX and its ROCm plugins are installed: ```Bash > pip3 freeze | grep jax -jax==0.4.35 -jax-rocm60-pjrt==0.4.35 -jax-rocm60-plugin==0.4.35 -jaxlib==0.4.35 +jax==0.6.2 +jax-rocm7-pjrt @ https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.6.0/jax_rocm7_pjrt-0.6.0-py3-none-manylinux_2_28_x86_64.whl#sha256=b20b6820d4701a8edd83509dcbc8dc4fb712f40eab873668ae0dd17f5194c2d6 +jax-rocm7-plugin @ https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.6.0/jax_rocm7_plugin-0.6.0-cp312-cp312-manylinux_2_28_x86_64.whl#sha256=cfecc2865ed450f996608b13af04189a2f9c1328ed896d71be0872d0e7d78389 +jaxlib @ https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.6.0/jaxlib-0.6.2-cp312-cp312-manylinux2014_x86_64.whl#sha256=739fc2ebe28399f551a5c6daf529baae1637546a9a2a93789e3afd7ef0444e66 ``` ### Step 5: Set the `LLVM_PATH` Environment Variable @@ -92,7 +99,7 @@ Run the following command to verify that ROCm JAX is installed correctly: ```Bash > python3 -c "import jax; print(jax.devices())" -[RocmDevice(id=0), RocmDevice(id=1), RocmDevice(id=2), RocmDevice(id=3)] +[RocmDevice(id=0), RocmDevice(id=1), RocmDevice(id=2), RocmDevice(id=3), RocmDevice(id=4), RocmDevice(id=5), RocmDevice(id=6), RocmDevice(id=7)] > python3 -c "import jax.numpy as jnp; x = jnp.arange(5); print(x)" [0 1 2 3 4] @@ -112,30 +119,36 @@ Once installed, verify ROCm installation using: ```Bash > rocm-smi - -========================================== ROCm System Management Interface ========================================== -==================================================== Concise Info ==================================================== -Device [Model : Revision] Temp Power Partitions SCLK MCLK Fan Perf PwrCap VRAM% GPU% - Name (20 chars) (Junction) (Socket) (Mem, Compute) -====================================================================================================================== -0 [0x74a1 : 0x00] 50.0°C 170.0W NPS1, SPX 131Mhz 900Mhz 0% auto 750.0W 0% 0% - AMD Instinct MI300X -1 [0x74a1 : 0x00] 51.0°C 176.0W NPS1, SPX 132Mhz 900Mhz 0% auto 750.0W 0% 0% - AMD Instinct MI300X -2 [0x74a1 : 0x00] 50.0°C 177.0W NPS1, SPX 132Mhz 900Mhz 0% auto 750.0W 0% 0% - AMD Instinct MI300X -3 [0x74a1 : 0x00] 53.0°C 176.0W NPS1, SPX 132Mhz 900Mhz 0% auto 750.0W 0% 0% - AMD Instinct MI300X -====================================================================================================================== -================================================ End of ROCm SMI Log ================================================= +============================================ ROCm System Management Interface ============================================ +====================================================== Concise Info ====================================================== +Device Node IDs Temp Power Partitions SCLK MCLK Fan Perf PwrCap VRAM% GPU% + (DID, GUID) (Junction) (Socket) (Mem, Compute, ID) +========================================================================================================================== +0 2 0x74a1, 28851 43.0°C 142.0W NPS1, SPX, 0 133Mhz 900Mhz 0% auto 750.0W 0% 0% +1 3 0x74a1, 23018 37.0°C 137.0W NPS1, SPX, 0 134Mhz 900Mhz 0% auto 750.0W 0% 0% +2 4 0x74a1, 29122 44.0°C 140.0W NPS1, SPX, 0 134Mhz 900Mhz 0% auto 750.0W 0% 0% +3 5 0x74a1, 22683 38.0°C 138.0W NPS1, SPX, 0 133Mhz 900Mhz 0% auto 750.0W 0% 0% +4 6 0x74a1, 53458 42.0°C 143.0W NPS1, SPX, 0 133Mhz 900Mhz 0% auto 750.0W 0% 0% +5 7 0x74a1, 63883 39.0°C 138.0W NPS1, SPX, 0 134Mhz 900Mhz 0% auto 750.0W 0% 0% +6 8 0x74a1, 53667 42.0°C 140.0W NPS1, SPX, 0 134Mhz 900Mhz 0% auto 750.0W 0% 0% +7 9 0x74a1, 63738 38.0°C 135.0W NPS1, SPX, 0 133Mhz 900Mhz 0% auto 750.0W 0% 0% +========================================================================================================================== +================================================== End of ROCm SMI Log =================================================== ``` ### Step 2: Install the Latest Version of JAX -Install the required version of JAX with ROCm support using pip: +Install the required version of JAX and the ROCm plugins using pip. Follow the +instructions for the [latest +release](https://github.com/ROCm/rocm-jax/releases). For example, on a system +with python 3.12, you will need to run the following to install `jax 0.6.2`: ```Bash -> pip3 install jax[rocm] +> pip3 install \ + https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.6.0/jax_rocm7_pjrt-0.6.0-py3-none-manylinux_2_28_x86_64.whl \ + https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.6.0/jax_rocm7_plugin-0.6.0-cp312-cp312-manylinux_2_28_x86_64.whl \ + https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.6.0/jaxlib-0.6.2-cp312-cp312-manylinux2014_x86_64.whl \ + jax==0.6.2 ``` ### Step 3: Verify the Installed JAX Version @@ -144,10 +157,10 @@ Check whether the correct version of JAX and its ROCm plugins are installed: ```Bash > pip3 freeze | grep jax -jax==0.4.35 -jax-rocm60-pjrt==0.4.35 -jax-rocm60-plugin==0.4.35 -jaxlib==0.4.35 +jax==0.6.2 +jax-rocm7-pjrt @ https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.6.0/jax_rocm7_pjrt-0.6.0-py3-none-manylinux_2_28_x86_64.whl#sha256=b20b6820d4701a8edd83509dcbc8dc4fb712f40eab873668ae0dd17f5194c2d6 +jax-rocm7-plugin @ https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.6.0/jax_rocm7_plugin-0.6.0-cp312-cp312-manylinux_2_28_x86_64.whl#sha256=cfecc2865ed450f996608b13af04189a2f9c1328ed896d71be0872d0e7d78389 +jaxlib @ https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.6.0/jaxlib-0.6.2-cp312-cp312-manylinux2014_x86_64.whl#sha256=739fc2ebe28399f551a5c6daf529baae1637546a9a2a93789e3afd7ef0444e66 ``` ### Step 4: Set the `LLVM_PATH` Environment Variable @@ -164,7 +177,7 @@ Run the following command to verify that ROCm JAX is installed correctly: ```Bash > python3 -c "import jax; print(jax.devices())" -[RocmDevice(id=0), RocmDevice(id=1), RocmDevice(id=2), RocmDevice(id=3)] +[RocmDevice(id=0), RocmDevice(id=1), RocmDevice(id=2), RocmDevice(id=3), RocmDevice(id=4), RocmDevice(id=5), RocmDevice(id=6), RocmDevice(id=7)] > python3 -c "import jax.numpy as jnp; x = jnp.arange(5); print(x)" [0 1 2 3 4] @@ -174,7 +187,31 @@ Run the following command to verify that ROCm JAX is installed correctly: Follow these steps to build JAX with ROCm support from source: -### Step 1: Clone the Repository +### Step 1: Build the ROCm specific wheels from `rocm-jax` + +Clone the `rocm-jax` repository for the desired branch: + +```Bash +> git clone https://github.com/ROCm/rocm-jax.git -b +> cd rocm-jax +``` +From the `rocm-jax` directory run: +```Bash +> python3 build/ci_build \ + --python-version $PYTHON_VERSION \ + --rocm_version $ROCM_VERSION \ + dist_wheels +> pip3 install jax_rocm_plugin/wheelhouse/*.whl +``` +The build will produce two wheels: + +* `jax-rocm-plugin` (ROCm-specific plugin) +* `jax-rocm-pjrt` (ROCm-specific runtime) + +Detailed build instructions can be found +[here](https://github.com/ROCm/rocm-jax/blob/master/BUILDING.md). + +### Step 2: Build `jaxlib` from the JAX Repository Clone the ROCm-specific fork of JAX for the desired branch: @@ -183,20 +220,15 @@ Clone the ROCm-specific fork of JAX for the desired branch: > cd jax ``` -### Step 2: Build the Wheels - -Run the following command to build the necessary wheels: +Run the following command to build the `jaxlib` wheel: ```Bash -> python3 ./build/build.py build --wheels=jaxlib,jax-rocm-plugin,jax-rocm-pjrt \ - --rocm_version=60 --rocm_path=/opt/rocm-[version] +> python3 ./build/build.py build --wheels=jaxlib \ + --rocm_version=7 --rocm_path=/opt/rocm-[version] ``` -This will generate three wheels in the `dist/` directory: - -* jaxlib (generic, device agnostic library) -* jax-rocm-plugin (ROCm-specific plugin) -* jax-rocm-pjrt (ROCm-specific runtime) +This will generate the `jaxlib` wheel in the `dist/` directory. `jaxlib` is a +device agnostic library. ### Step 3: Then install custom JAX using: diff --git a/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm b/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm index 08b6bd3ff8d6..3ca491568911 100644 --- a/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm +++ b/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm @@ -9,7 +9,7 @@ ARG ROCM_BUILD_NUM # manylinux base image. However, adding this does fix an issue where Bazel isn't able # to find them. RUN --mount=type=cache,target=/var/cache/dnf \ - dnf install -y gcc-c++-8.5.0-22.el8_10.x86_64 + dnf install -y numactl-devel RUN --mount=type=cache,target=/var/cache/dnf \ --mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \ @@ -25,5 +25,11 @@ RUN mkdir /tmp/llvm-project && wget -qO - https://github.com/llvm/llvm-project/a mkdir /tmp/llvm-project/build && cd /tmp/llvm-project/build && cmake -DLLVM_ENABLE_PROJECTS='clang;lld' -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=/usr/lib/llvm-18/ ../llvm && \ make -j$(nproc) && make -j$(nproc) install && rm -rf /tmp/llvm-project +# Set some clang config +COPY ./build/rocm/build_wheels/clang.cfg /usr/lib/llvm-18/bin/clang++.cfg +COPY ./build/rocm/build_wheels/clang.cfg /usr/lib/llvm-18/bin/clang.cfg +COPY ./build/rocm/build_wheels/clang.cfg /opt/rocm/llvm/bin/clang++.cfg +COPY ./build/rocm/build_wheels/clang.cfg /opt/rocm/llvm/bin/clang.cfg + # Stop git from erroring out when we don't own the repo RUN git config --global --add safe.directory '*' diff --git a/build/rocm/build_wheels/clang.cfg b/build/rocm/build_wheels/clang.cfg new file mode 100644 index 000000000000..767c04c03ae7 --- /dev/null +++ b/build/rocm/build_wheels/clang.cfg @@ -0,0 +1,3 @@ +# Tell clang where it can find gcc so that it can use gcc's standard libraries +--gcc-toolchain=/opt/rh/gcc-toolset-14/root/usr/ + diff --git a/build/rocm/ci_build b/build/rocm/ci_build index ef43a95044d8..71ce747d7e86 100755 --- a/build/rocm/ci_build +++ b/build/rocm/ci_build @@ -98,7 +98,10 @@ def dist_wheels( bw_cmd.append("/jax") - cmd = ["docker", "run"] + cmd = [ + "docker", + "run", + ] mounts = [ "-v", diff --git a/build/rocm/ci_build.sh b/build/rocm/ci_build.sh index 386f70ee1a96..847d4e9b4b93 100755 --- a/build/rocm/ci_build.sh +++ b/build/rocm/ci_build.sh @@ -44,7 +44,7 @@ CONTAINER_TYPE="rocm" DOCKERFILE_PATH="${SCRIPT_DIR}/Dockerfile.ms" DOCKER_CONTEXT_PATH="${SCRIPT_DIR}" KEEP_IMAGE="--rm" -PYTHON_VERSION="3.10" +PYTHON_VERSION="3.11" ROCM_VERSION="6.1.3" ROCM_BUILD_JOB="" ROCM_BUILD_NUM="" diff --git a/build/rocm/docker/Dockerfile.jax-ubu22 b/build/rocm/docker/Dockerfile.jax-ubu22 index 70b16f9e9677..b6e90f2183d2 100644 --- a/build/rocm/docker/Dockerfile.jax-ubu22 +++ b/build/rocm/docker/Dockerfile.jax-ubu22 @@ -60,7 +60,7 @@ ARG JAX_COMMIT ARG XLA_COMMIT LABEL com.amdgpu.rocm_version="$ROCM_VERSION" \ - com.amdgpu.python_version="3.10" \ + com.amdgpu.python_version="3.11" \ com.amdgpu.jax_version="$JAX_VERSION" \ com.amdgpu.jax_commit="$JAX_COMMIT" \ com.amdgpu.xla_commit="$XLA_COMMIT" diff --git a/build/rocm/setup.rocm.sh b/build/rocm/setup.rocm.sh index 3893d817e3a8..faa79d2ce1fd 100755 --- a/build/rocm/setup.rocm.sh +++ b/build/rocm/setup.rocm.sh @@ -13,7 +13,7 @@ ROCM_BUILD_NAME=ubuntu ROCM_BUILD_NUM=main # Adjust the ROCM repo location -# Intial release don't have the trialing '.0' +# Initial release don't have the trialing '.0' # For example ROCM 5.7.0 is at https://repo.radeon.com/rocm/apt/5.7/ if [ ${ROCM_VERSION##*[^0-9]} -eq '0' ]; then ROCM_VERS=${ROCM_VERSION%.*} diff --git a/build/rocm/tools/build_wheels.py b/build/rocm/tools/build_wheels.py index fd98bbb8ec04..9fdffe6cfa03 100644 --- a/build/rocm/tools/build_wheels.py +++ b/build/rocm/tools/build_wheels.py @@ -226,7 +226,10 @@ def fix_wheel(path, jax_path): py_bin = "/opt/python/cp310-cp310/bin" env["PATH"] = "%s:%s" % (py_bin, env["PATH"]) - cmd = ["pip", "install", "auditwheel>=6"] + # NOTE(mrodden): auditwheel 6.0 added lddtree module, but 6.3.0 changed + # the function to ldd and also changed its behavior + # constrain range to 6.0 to 6.2.x + cmd = ["pip", "install", "auditwheel>=6,<6.3"] subprocess.run(cmd, check=True, env=env) fixwheel_path = os.path.join(jax_path, "build/rocm/tools/fixwheel.py") @@ -248,7 +251,7 @@ def parse_args(): ) p.add_argument( "--python-versions", - default=["3.10.19,3.12"], + default=["3.11.13,3.12"], help="Comma separated CPython versions that wheels will be built and output for", ) p.add_argument( @@ -322,7 +325,7 @@ def main(): shutil.rmtree(os.path.join(args.jax_path, "jax.egg-info")) shutil.rmtree(os.path.join(args.jax_path, "jax", "__pycache__")) - # Make the wheels deleteable by the runner + # Make the wheels deletable by the runner whl_house = os.path.join(args.jax_path, "wheelhouse") logging.info("Changing permissions for %s" % whl_house) mode = 0o664 diff --git a/build/rocm/tools/fixwheel.py b/build/rocm/tools/fixwheel.py index ea77162728d5..90b5ae269627 100644 --- a/build/rocm/tools/fixwheel.py +++ b/build/rocm/tools/fixwheel.py @@ -23,7 +23,6 @@ import argparse import logging import os -from pprint import pprint import subprocess from auditwheel.lddtree import lddtree @@ -87,7 +86,7 @@ def fix_wheel(path): exclude = list(ext_libs.keys()) # call auditwheel repair with excludes - cmd = ["auditwheel", "repair", "--plat", plat, "--only-plat"] + cmd = ["auditwheel", "-v", "repair", "--plat", plat, "--only-plat"] for ex in exclude: cmd.append("--exclude") diff --git a/build/rocm/tools/symbols.py b/build/rocm/tools/symbols.py index 2982bb187c9e..b33f5e77bbbc 100644 --- a/build/rocm/tools/symbols.py +++ b/build/rocm/tools/symbols.py @@ -19,8 +19,6 @@ # needs be compatible with Python 3.6. Please do not include these # in any "upgrade" scripts - -import pprint import re import sys import subprocess diff --git a/build/test-requirements.txt b/build/test-requirements.txt index f0b315771cbb..453ebe4e18ae 100644 --- a/build/test-requirements.txt +++ b/build/test-requirements.txt @@ -1,21 +1,14 @@ absl-py -build cloudpickle -colorama>=0.4.4 filelock flatbuffers -hypothesis +hypothesis==6.142.1 # TODO(justinfu): Fix test failures surfaced by Hypothesis 6.147.0 and remove pin. mpmath>=1.3 -pillow>=10.4.0 -# TODO(kanglan): Remove once psutil from portpicker supports python 3.13t -portpicker; python_version<"3.13" +pillow>=11.3 +portpicker +pytest<9.0 # Works around https://github.com/pytest-dev/pytest/issues/13895 pytest-xdist -wheel rich -setuptools -# matplotlib 3.9.0 pins NumPy 1.23, which is incompatible with the requirement -# below. -matplotlib~=3.8.4; python_version=="3.10" -matplotlib; python_version>="3.11" -opt-einsum -auditwheel \ No newline at end of file +matplotlib +auditwheel +scipy-stubs diff --git a/build/tools/utils.py b/build/tools/utils.py index 7e375169827b..45bb2e1f1531 100644 --- a/build/tools/utils.py +++ b/build/tools/utils.py @@ -14,6 +14,7 @@ # ============================================================================== # Helper script for tools/utilities used by the JAX build CLI. import collections +import glob import hashlib import logging import os @@ -183,9 +184,6 @@ def get_compiler_path_or_exit(compiler_path_flag, compiler_name): ) sys.exit(-1) -def get_gcc_path_or_exit(): - return get_compiler_path_or_exit("gcc_path", "gcc") - def get_clang_path_or_exit(): return get_compiler_path_or_exit("clang_path", "clang") @@ -201,27 +199,31 @@ def get_clang_major_version(clang_path): return major_version -def get_gcc_major_version(gcc_path: str): - gcc_version_proc = subprocess.run( - [gcc_path, "-dumpversion"], - check=True, - capture_output=True, - text=True, - ) - major_version = int(gcc_version_proc.stdout.split(".")[0]) - - return major_version +def get_clangpp_path(clang_path): + clang_path = pathlib.Path(clang_path) + clang_exec_name = clang_path.name + clangpp_exec_name = clang_exec_name + clangpp_path = clang_path.parent / clang_exec_name + # Try and match what the user passed in (either clang-18 or clang) + if "clang++" not in clangpp_exec_name: + clangpp_exec_name = clangpp_exec_name.replace("clang", "clang++") + clangpp_path = clang_path.parent / clangpp_exec_name + if not clangpp_path.exists(): + clangpp_exec_name = "clang++" + clangpp_path = clang_path.parent / clangpp_exec_name + if not clangpp_path.exists(): + raise FileNotFoundError( + f"Failed to get clang++ path from clang path: '{clang_path!s}'. " + f"Tried the path: '{clangpp_path!s}'." + ) + return str(clangpp_path) -def get_jax_configure_bazel_options(bazel_command: list[str], use_new_wheel_build_rule: bool): +def get_jax_configure_bazel_options(bazel_command: list[str]): """Returns the bazel options to be written to .jax_configure.bazelrc.""" - # Get the index of the "run" parameter. Build options will come after "run" so - # we find the index of "run" and filter everything after it. If we are using - # the new wheel build rule, we will find the index of "build" instead. - if use_new_wheel_build_rule: - start = bazel_command.index("build") - else: - start = bazel_command.index("run") + # Get the index of the "build" parameter. Build options will come after + # "build" so we find the index of "build" and filter everything after it. + start = bazel_command.index("build") jax_configure_bazel_options = "" try: for i in range(start + 1, len(bazel_command)): @@ -233,19 +235,9 @@ def get_jax_configure_bazel_options(bazel_command: list[str], use_new_wheel_buil jax_configure_bazel_options += f"build {bazel_flag}\n" return jax_configure_bazel_options except ValueError: - logging.error("Unable to find index for 'run' in the Bazel command") + logging.error("Unable to find index for 'build' in the Bazel command") return "" -def get_githash(): - try: - return subprocess.run( - ["git", "rev-parse", "HEAD"], - encoding="utf-8", - capture_output=True, - check=True, - ).stdout.strip() - except (subprocess.CalledProcessError, OSError): - return "" def _parse_string_as_bool(s): """Parses a string as a boolean value.""" @@ -256,3 +248,43 @@ def _parse_string_as_bool(s): return False else: raise ValueError(f"Expected either 'true' or 'false'; got {s}") + + +def copy_dir_recursively(src, dst): + if os.path.exists(dst): + shutil.rmtree(dst) + os.makedirs(dst, exist_ok=True) + for root, dirs, files in os.walk(src): + relative_path = os.path.relpath(root, src) + dst_dir = os.path.join(dst, relative_path) + os.makedirs(dst_dir, exist_ok=True) + for f in files: + src_file = os.path.join(root, f) + dst_file = os.path.join(dst_dir, f) + shutil.copy2(src_file, dst_file) + logging.info("Editable wheel path: %s" % dst) + + +def copy_individual_files(src: str, dst: str, glob_pattern: str): + os.makedirs(dst, exist_ok=True) + logging.debug( + f"Copying files matching pattern {glob_pattern!r} from {src!r} to {dst!r}" + ) + for f in glob.glob(os.path.join(src, glob_pattern)): + dst_file = os.path.join(dst, os.path.basename(f)) + if os.path.exists(dst_file): + os.remove(dst_file) + shutil.copy2(f, dst_file) + logging.info("Distribution path: %s" % dst_file) + +def is_linux(os_name: str): + """Returns true if OS is Linux.""" + return os_name == "linux" + +def is_linux_x86_64(arch: str, os_name: str): + """Returns true if the architecture is Linux x86_64.""" + return arch == "x86_64" and os_name == "linux" + +def is_linux_aarch64(arch: str, os_name: str): + """Returns true if the architecture is Linux aarch64.""" + return arch == "aarch64" and os_name == "linux" diff --git a/build_wheel.py b/build_wheel.py index f8e1595d3c3a..793523e8e3b2 100644 --- a/build_wheel.py +++ b/build_wheel.py @@ -47,6 +47,25 @@ parser.add_argument( "--srcs", help="source files for the wheel", action="append" ) +parser.add_argument( + "--build-wheel-only", + default=False, + help=( + "Whether to build the wheel only. Optional." + ), +) +parser.add_argument( + "--build-source-package-only", + default=False, + help=( + "Whether to build the source package only. Optional." + ), +) +parser.add_argument( + "--editable", + action="store_true", + help="Create an 'editable' jax build instead of a wheel.", +) args = parser.parse_args() @@ -76,7 +95,11 @@ def prepare_srcs(deps: list[str], srcs_dir: str) -> None: """ for file in deps: - if not (file.startswith("bazel-out") or file.startswith("external")): + if not ( + file.startswith("bazel-out") + or file.startswith("external") + or file.startswith("jaxlib") + ): copy_file(file, srcs_dir) @@ -89,13 +112,18 @@ def prepare_srcs(deps: list[str], srcs_dir: str) -> None: try: os.makedirs(args.output_path, exist_ok=True) prepare_srcs(args.srcs, pathlib.Path(sources_path)) - build_utils.build_wheel( - sources_path, - args.output_path, - package_name="jax", - git_hash=args.jaxlib_git_hash, - build_wheel_only=False, - ) + package_name = "jax" + if args.editable: + build_utils.build_editable(sources_path, args.output_path, package_name) + else: + build_utils.build_wheel( + sources_path, + args.output_path, + package_name, + git_hash=args.jaxlib_git_hash, + build_wheel_only=args.build_wheel_only, + build_source_package_only=args.build_source_package_only, + ) finally: if tmpdir: tmpdir.cleanup() diff --git a/ci/CONTRIBUTING.md b/ci/CONTRIBUTING.md new file mode 100644 index 000000000000..e140bd38e02e --- /dev/null +++ b/ci/CONTRIBUTING.md @@ -0,0 +1,80 @@ +# Contributing to the JAX CI System + +Our CI is a hybrid system using both GitHub Actions and an internal CI for +different tasks (presubmits, continuous, nightly, and release builds). The core +logic for building and testing resides in shell scripts within the [`ci/`](https://github.com/google/jax/tree/main/ci) +directory. This ensures that the same logic can be run locally, in GitHub +Actions, and in our internal CI. GitHub Actions workflows ([`.github/workflows/`](https://github.com/google/jax/tree/main/.github/workflows)) +are primarily used to orchestrate and call these scripts with different +parameters. Configuration is managed through `JAXCI_` environment variables, +with defaults defined in [`ci/envs/default.env`](https://github.com/google/jax/blob/main/ci/envs/default.env). + +## General Principles + +* **Keep it DRY (Don't Repeat Yourself):** For common test patterns (e.g., + running CPU tests), use reusable GitHub workflows (`on: workflow_call`). We + have existing workflows like `pytest_cpu.yml` and `bazel_cpu.yml` for this. +* **Isolate Logic from Orchestration:** Complex build and test logic should be + in the [`ci/`](https://github.com/google/jax/tree/main/ci) scripts. The GitHub Actions YAML files should focus on + orchestrating the calls to these scripts, not implementing the logic itself. +* **Prioritize Presubmit Speed:** Presubmit checks, which run on every PR, + should be fast (target < 10 minutes). Offload longer-running, more + comprehensive tests to the continuous (every 3 hours) or nightly jobs. +* **Run Locally First:** Before pushing changes to CI-related files, run the + scripts locally to catch simple errors. See the "Running These Scripts + Locally on Your Machine" section in [`ci/README.md`](https://github.com/google/jax/blob/main/ci/README.md). + +## Modifying GitHub Actions Workflows + +* **Pin Actions to a Commit Hash:** All third-party GitHub Actions **must** be + pinned to a specific commit hash, not a tag or branch. This ensures our + workflows are deterministic. Use the **ratchet** tool as mentioned in + [`.github/workflows/README.md`](https://github.com/google/jax/blob/main/.github/workflows/README.md) to manage this. Example: `uses: + actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0` +* **Use Matrix Strategies:** To test across different configurations (e.g., + Python versions, runners, CUDA versions), use a `strategy: matrix`. Use + `exclude` to prune unnecessary or redundant combinations from the matrix. +* **Select Runners Carefully:** We use self-hosted runners for specific + hardware (CPU, GPU, TPU). Runner names are descriptive (e.g., + `linux-x86-n4-16`, `linux-x86-g2-48-l4-4gpu`). See + [here](https://github.com/jax-ml/jax/blob/main/.github/actionlint.yaml) for + a list of available runners; choose the most appropriate one for the task. +* **Set Permissions:** To enhance security, all workflows should explicitly + define the permissions for the `GITHUB_TOKEN`. Default to the most + restrictive permissions possible. + * For workflows that don't need to access repository contents, use + `permissions: {}`. + * Only grant write permissions when absolutely necessary. Granting write + permissions should be done with great care, and a discussion should be + had before attempting to add an action with write permissions. + * More information about GitHub Token permissions can be found + [here](https://docs.github.com/en/actions/security-guides/automatic-token-authentication#permissions-for-the-github_token). + +## Working with CI Scripts + +* **Configuration via Environment Variables:** Control script behavior using + `JAXCI_` environment variables. If you add a new variable, be sure to add a + default in [`ci/envs/default.env`](https://github.com/google/jax/blob/main/ci/envs/default.env) and document it in [`ci/envs/README.md`](https://github.com/google/jax/blob/main/ci/envs/README.md). +* **Local Execution with Docker:** For a consistent environment that mirrors + CI, use the [`ci/utilities/run_docker_container.sh`](https://github.com/google/jax/blob/main/ci/utilities/run_docker_container.sh) script. This is the + recommended way to test changes locally for Linux and Windows. + +## Dependencies + +* **XLA:** Presubmit and continuous builds use XLA at HEAD to catch + integration issues early. Nightly and release builds use a pinned XLA + version for stability. This is controlled by the `JAXCI_CLONE_MAIN_XLA` + variable. +* **Python:** Use `uv` for installing Python packages where possible, as it is + much faster than `pip`. + +## Linting and Code Style + +For any changes to GitHub Actions workflow files, it is recommended to run +[actionlint](https://github.com/rhysd/actionlint) to verify correctness and best +practices. + +## Further Reading + +For a more detailed overview of the JAX CI system, including its architecture +and workflows, please see the [`ci/README.md`](https://github.com/google/jax/blob/main/ci/README.md) file. diff --git a/ci/README.md b/ci/README.md index ea867df52f97..c9c11f4f1b53 100644 --- a/ci/README.md +++ b/ci/README.md @@ -1,10 +1,274 @@ -# JAX continuous integration +# JAX Continuous Integration -> [!WARNING] -> This folder is still under construction. It is part of an ongoing -> effort to improve the structure of CI and build related files within the -> JAX repo. This warning will be removed when the contents of this -> directory are stable and appropriate documentation around its usage is in -> place. +This folder contains the configuration files and scripts used to build and test +JAX. It is typically used by continuous integration (CI) jobs to automate builds +and run comprehensive tests across various platforms and configurations. This +page provides an overview of the JAX CI system, its components, and the +different workflows it supports. -******************************************************************************** \ No newline at end of file +******************************************************************************** + +## JAX's CI System + +![Overview of JAX's CI System](jax_ci_system.png) + +JAX's CI system is composed of several interacting components and orchestrates +builds and tests using a hybrid approach, leveraging both an internal CI system +and GitHub Actions as well as an internal build orchestrator for managing +nightly and release flows. It encompasses several distinct workflows, including +comprehensive presubmit checks triggered on pull requests and branch pushes, +bi-hourly continuous builds, extensive nightly builds with broad platform +coverage, and a controlled release process that culminates in PyPI publication. + +These flows build four packages: `jax`, `jaxlib`, `jax-cuda-plugin`, +`jax-cuda-pjrt` and support a range of environments, including: + +* **Linux x86:** CPU, TPU, CUDA +* **Linux aarch64:** CPU, CUDA +* **Windows x86:** CPU +* **Mac Arm64:** CPU + +### Architecture Overview + +1. **Internal CI System:** An internal CI system is used for specific build and + test tasks, such as nightly builds, release candidate (RC) builds, and + Mac-specific testing. + +2. **GitHub Actions:** Used for presubmit checks, continuous integration builds + and tests, and nightly/release artifact testing. + +3. **Build Orchestrator:** An internal tool used to manage complex workflows + such as nightly / release flows, promoting RC builds to release, etc. + +4. **Artifact Storage:** + +* Google Cloud Storage (GCS) Buckets: Used for temporary storage of artifacts + between jobs in GitHub Actions workflows and for storing packages built + during nightly and release flows before testing. +* Artifact Registry: Used to store nightly packages, RC packages and final + releases. +* PyPI: Where final releases are published. + +### CI Workflows and Where They Run + +JAX's CI system consists of the following workflows: + +1. **Presubmits:** Presubmits are run in GitHub actions and are triggered on + pull requests that target the `main` branch and on pushes to the `main` and + `release` branch. JAX's presubmit run time SLO is about 10 minutes so these + are typically run using Bazel with remote build execution + ([RBE](https://bazel.build/remote/rbe)). RBE allows us to execute build and + test actions on a distributed system, separate from the local machine, + instead of solely on the local machine. This enables faster build and test + times by utilizing parallel computing resources and caching across a cluster + of machines. However, we also use Pytest in workflows where we are not able + to use RBE such as the TPU presubmit. In such presubmits, we usually run a + subset of tests to be able to satisfy the presubmit run time SLO. To see the + list of the presubmit workflows, + [click here](https://github.com/search?q=repo%3Ajax-ml%2Fjax+path%3A.github%2Fworkflows%2F+%28path%3A**%2F*.yml+OR+path%3A**%2F*.yaml%29+%22pull_request%22&type=code). + +2. **Continuous:** These jobs are run in GitHub actions and are scheduled to + run once every 2 hours on the `main` branch. It builds JAX packages and runs + a wide range of tests targeting different environments such as CPU, CUDA + (L4, H100, B200, etc), and TPU (v4-8, v5e-8, etc.). For more information, + see + [wheel_tests_continuous.yml](https://github.com/jax-ml/jax/blob/main/.github/workflows/wheel_tests_continuous.yml) + ([An example run](https://github.com/jax-ml/jax/actions/workflows/wheel_tests_continuous.yml).) + +3. **Nightly Builds and Tests:** These jobs use an hybrid approach of both the + internal CI system and GitHub actions. The jobs are triggered once every + night by the internal build orchestrator tool. It first triggers the jobs in + the internal CI system to build the JAX packages for different + configurations (Python versions, CUDA versions, etc) and uploads them to a + staging bucket in GCS as well as to the nightly artifact registry. Next, + testing jobs are triggered that download the artifacts from the staging + bucket and run tests. Mac testing jobs are run in the internal CI system. + For non-Mac testing, a trigger job is run that invokes the + [wheel_tests_nightly_release.yml](https://github.com/jax-ml/jax/blob/main/.github/workflows/wheel_tests_nightly_release.yml) + workflow in GitHub Actions. JAX's nightly artifacts can be found here: + [jax](https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/jax), + [jaxlib](https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/jaxlib), + [jax-cuda-plugin](https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/jax-cuda12-plugin), + [jax-cuda-pjrt](https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/jax-cuda12-pjrt). + +4. **Release Builds and Tests:** Release flow is similar to the nightly flow + except for few differences. First, release process has to be triggered + manually in the internal build orchestrator and should be done only after a + release branch (E.g `release/0.5.3`) has been created. The build jobs build + two sets of artifacts for each package: 1. RC wheels 2. Final version + wheels. These two sets are pretty much the same package except for their + metadata and wheel tags. The RC wheels are then uploaded to the staging + bucket and release artifact registry. After the uploads are done, the test + jobs are triggered. As with the nightly flow, Mac test jobs are run in the + internal CI system while non-Mac test jobs are run in GitHub actions. To see + the GitHub actions run for a particular release, filter the workflow runs by + its branch name. + + +5. **Promote RC to Final and Publish to PyPI:** If the RC wheels pass all + testing, then we are ready to promote it as the final version and publish it + to PyPI. This entire flow is internal and is run in our internal CI system. + Final version of the packages are published to PyPI and JAX's release + artifact registry. JAX's release artifacts (RC and final versions) can be + found here: + [jax](https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-release-artifacts-registry/simple/jax), + [jaxlib](https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-release-artifacts-registry/simple/jaxlib), + [jax-cuda-plugin](https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-release-artifacts-registry/simple/jax-cuda12-plugin), + [jax-cuda-pjrt](https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-release-artifacts-registry/simple/jax-cuda12-pjrt). + +### JAX's Official CI and Build/Test Scripts + +JAX's CI jobs (both internal and those on GitHub actions) run the scripts in +this folder. An overview of the different folders and their purpose is given +below: + +- **ci/**: Contains all build scripts, environment files, and utility scripts. +- **ci/utilities/**: Contains helper scripts used throughout the build/test + process. See + [README.md](https://github.com/jax-ml/jax/blob/main/ci/utilities/README.md) + for a brief overview of these utility scripts and their behavior. +- **ci/envs/**: Holds environment files that set `JAXCI` environment variables + that control build and test configurations. see + [README.md](https://github.com/jax-ml/jax/blob/main/ci/envs/README.md) to + see the complete list of these variables and their behavior. + +Every build script in this folder first source the `JAXCI` envs in +[default.env](https://github.com/jax-ml/jax/blob/main/ci/envs/default.env) and +then run the +[setup_build_environment.sh](https://github.com/jax-ml/jax/blob/main/ci/utilities/setup_build_environment.sh) +script to set up the build environment. + +A brief overview of each build script in this folder is given below: + +> [!NOTE] +> Both internal and GitHub action jobs run under the +> [ml-build](https://github.com/tensorflow/tensorflow/tree/master/ci/official/containers) +> Docker image which contains build tools such as Python, Bazelisk, LLVM/Clang, +> manylinux compliant libraries (in Linux images), etc. + +- **build_artifacts.sh:** These build the various JAX artifacts. We build + three different type of artifacts based on the type of job: Nightly, + RC/Release, or at HEAD. +- **run_bazel_test_cpu_rbe.sh/run_bazel_test_cuda_rbe.sh**: These run Bazel + tests with RBE on every GitHub PR. We test compatibility with both CPU and + CUDA. On platforms where RBE is not natively supported (e.g Linux Arm64), we + cross-compile the test targets for Linux Aarch64 on Linux x86. As the tests + still need to be run on the host machines and because running the tests on a + single machine can take a long time, we skip running them on these + platforms. + Note for `run_bazel_test_cpu_rbe.sh`: + - If `$JAXCI_BUILD_JAXLIB=false` and `$JAXCI_BUILD_JAX=false`, these jobs + depend on local JAX wheels and therefore require that the following wheels + to be present in the `../dist` folder: `jax`, and `jaxlib` wheels. In CI + builds, we first build these wheels from source and then run the + `bazel test` command. + - If `$JAXCI_BUILD_JAXLIB=false` and `$JAXCI_BUILD_JAX=true`, CPU jobs + depend on local jaxlib wheels and therefore require that `jaxlib` wheel to + be present in the `../dist` folder. GPU obs + depend on local jaxlib and CUDA wheels, and therefore require that the + following wheels to be present in the `../dist` folder: `jaxlib`, + `jax-cuda-plugin`, and `jax-cuda-pjrt` wheels. In CI builds, we first + build these wheels from source and then run the `bazel test` command. + - If `$JAXCI_BUILD_JAXLIB=wheel` and `$JAXCI_BUILD_JAX=wheel`, the Bazel + tests use + [py_import](https://github.com/openxla/xla/blob/8190847008eddd4c7f3e57449e16d28631770823/third_party/py/py_import.bzl#L47). + - If `$JAXCI_BUILD_JAXLIB=true` and `$JAXCI_BUILD_JAX=true`, Bazel will use + individual targets in the test dependencies. +- **run_bazel_test_cuda_non_rbe.sh**: These run the following Bazel CUDA + tests: Single accelerator tests with one GPU apiece and Multi-accelerator + tests with all GPUs. + - If `$JAXCI_BUILD_JAXLIB=false` and `$JAXCI_BUILD_JAX=false`, these jobs + depend on local JAX wheels and therefore require that the following wheels + to be present in the `../dist` folder: `jax`, `jaxlib`, `jax-cuda-plugin`, + and `jax-cuda-pjrt` wheels. In CI builds, we first build these wheels from + source and then run the `bazel test` command. + - If `$JAXCI_BUILD_JAXLIB=wheel` and `$JAXCI_BUILD_JAX=wheel`, the Bazel + tests use [py_import](https://github.com/openxla/xla/blob/8190847008eddd4c7f3e57449e16d28631770823/third_party/py/py_import.bzl#L47). +- **run_pytest_*.sh**: These run tests with Pytests and use the JAX wheel + packages installed on the system. In CI builds, we build the wheels first + from source and then run the `pytest` commands. We test compatibility with + CPU, CUDA, and TPU. These are primarily run as part of the continuous and + nightly/release test jobs except for TPU which is also run as a presubmit + testing a subset of the tests. + +## Different Test Configurations + +JAX's CI Test jobs run under different test configurations. These configurations +are described briefly in the sections below. + +### XLA Versions + +JAX's CI builds rely on XLA, but use different versions depending on the type of +build. To ensure stability and reproducibility, nightly and release builds use a +pinned XLA version specified in the JAX workspace defined in [revision.bzl](https://github.com/jax-ml/jax/blob/b8b8c308a88060a3db63fa69c5cb7d8d7f1c5078/third_party/xla/revision.bzl#L23-L24). + +However, to keep JAX compatible with the latest XLA developments, presubmit and +postsubmit builds utilize the most recent XLA version. This is done by +overriding the default XLA dependency with a local copy of the XLA repository. +We do this by passing `--override_repository=xla=/path/to/local/xla` which +instructs Bazel to depend on the XLA in the local system instead of the version +in the workspace. + +The CI system uses the `JAXCI` environment variables to manage this process. +When running jobs that need to use XLA at head, we set `JAXCI_CLONE_MAIN_XLA=1`. +This clones the XLA repository at head and sets `JAXCI_XLA_GIT_DIR` to its path. +[JAX build CLI](https://github.com/jax-ml/jax/blob/main/build/build.py) +automatically adds the necessary Bazel flag (`--override_repository`) to point +to this local XLA version during the build process if `JAXCI_XLA_GIT_DIR` is +set. In jobs where the build CLI is not used such as the RBE presubmits, we +explicitly include `--override_repository=xla="${JAXCI_XLA_GIT_DIR}"` as part +of the test command. + +### Enabling/Disabling 64-bit Data Types + +By default, JAX enforces single-precision numbers to mitigate the Numpy API’s +tendency to aggressively promote operands to `double`. In order to use +double-precision numbers, we need to set the `JAX_ENABLE_X64` environment +variable. In CI, we test both configurations in presubmits and postsubmits by +using the `JAXCI_ENABLE_X64` environment variable. + + + +## [Googlers Only] Connecting to CI Runners for Debugging + +If you are a Googler, you can connect to one of the self-hosted runners we have +on GitHub to debug your workflow. For more information, see +go/ml-github-actions:connect. + +## Running These Scripts Locally on Your Machine + +> [!IMPORTANT] +> If you are a Linux / Windows user, you need to have Docker installed as a +> prerequisite. Additionally, if running on Windows, please run these commands +> in a bash environment as all the scripts are written in Shell. + +Follow the steps below to run a CI script locally on your machine. + +1. [Optional] Set `JAXCI` variables in your shell environment. See + [ci/envs/README.md](https://github.com/jax-ml/jax/blob/main/ci/envs/README.md) + for the list of `JAXCI` variables and their behavior. + +2. [Linux/Windows] + + Start the Docker container by running: + + ```bash + ./ci/utilities/run_docker_container.sh + ``` + + This will start a Docker container named "jax". Note that if you set any + `JAXCI` variables in step 1, they will also be be set in the container. + + Run the script under the Docker container. + + ```bash + # docker exec jax + docker exec jax ./ci/build_artifacts.sh jaxlib + ``` + +3. [Mac] Execute the build script directly. + + ```bash + # ./ + ./ci/build_artifacts.sh jaxlib + ``` diff --git a/ci/build_artifacts.sh b/ci/build_artifacts.sh index 84b8d35a2a50..3a9d3dcf8c6a 100755 --- a/ci/build_artifacts.sh +++ b/ci/build_artifacts.sh @@ -38,6 +38,11 @@ allowed_artifacts=("jax" "jaxlib" "jax-cuda-plugin" "jax-cuda-pjrt") os=$(uname -s | awk '{print tolower($0)}') arch=$(uname -m) +bazel_startup_options="" +if [[ -n "${JAXCI_BAZEL_OUTPUT_BASE}" ]]; then + bazel_startup_options="--output_base=${JAXCI_BAZEL_OUTPUT_BASE}" +fi + # Adjust the values when running on Windows x86 to match the config in # .bazelrc if [[ $os =~ "msys_nt" && $arch == "x86_64" ]]; then @@ -52,10 +57,10 @@ fi # the git commit hash of the HEAD of the current branch and the date of the # commit (e.g. 0.5.1.dev20250128+3e75e20c7). if [[ "$JAXCI_ARTIFACT_TYPE" == "release" ]]; then - artifact_tag_flags="--bazel_options=--repo_env=ML_WHEEL_TYPE=release" + artifact_tag_flags="--bazel_options=--repo_env=ML_WHEEL_TYPE=release --bazel_options=--//jaxlib/tools:jaxlib_git_hash=$(git rev-parse HEAD)" elif [[ "$JAXCI_ARTIFACT_TYPE" == "nightly" ]]; then current_date=$(date +%Y%m%d) - artifact_tag_flags="--bazel_options=--repo_env=ML_WHEEL_BUILD_DATE=${current_date} --bazel_options=--repo_env=ML_WHEEL_TYPE=nightly" + artifact_tag_flags="--bazel_options=--repo_env=ML_WHEEL_BUILD_DATE=${current_date} --bazel_options=--repo_env=ML_WHEEL_TYPE=nightly --bazel_options=--//jaxlib/tools:jaxlib_git_hash=$(git rev-parse HEAD)" elif [[ "$JAXCI_ARTIFACT_TYPE" == "default" ]]; then artifact_tag_flags="--bazel_options=--repo_env=ML_WHEEL_TYPE=custom --bazel_options=--repo_env=ML_WHEEL_BUILD_DATE=$(git show -s --format=%as HEAD) --bazel_options=--repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD) --bazel_options=--//jaxlib/tools:jaxlib_git_hash=$(git rev-parse HEAD)" else @@ -63,6 +68,10 @@ else exit 1 fi +if [[ "$JAXCI_HERMETIC_PYTHON_VERSION" == *"-nogil" ]]; then + JAXCI_HERMETIC_PYTHON_VERSION=${JAXCI_HERMETIC_PYTHON_VERSION%-nogil}-ft +fi + if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then # Figure out the bazelrc config to use. We will use one of the "rbe_"/"ci_" # flags in the .bazelrc depending upon the platform we are building for. @@ -73,7 +82,11 @@ if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then bazel_remote_cache="" if [[ "$JAXCI_BUILD_ARTIFACT_WITH_RBE" == 1 ]]; then - bazelrc_config="rbe_${bazelrc_config}" + if [[ "$os" == "linux" && "$arch" == "aarch64" && "$artifact" == "jaxlib" ]]; then + bazelrc_config="rbe_cross_compile_${bazelrc_config}" + else + bazelrc_config="rbe_${bazelrc_config}" + fi else bazelrc_config="ci_${bazelrc_config}" @@ -86,16 +99,22 @@ if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then fi fi + cuda_version_flag="" # Use the "_cuda" configs when building the CUDA artifacts. if [[ ("$artifact" == "jax-cuda-plugin") || ("$artifact" == "jax-cuda-pjrt") ]]; then - bazelrc_config="${bazelrc_config}_cuda" + bazelrc_config="${bazelrc_config}_cuda${JAXCI_CUDA_VERSION}" + cuda_version_flag="--cuda_major_version=$JAXCI_CUDA_VERSION" fi # Build the artifact. python build/build.py build --wheels="$artifact" \ --bazel_options=--config="$bazelrc_config" $bazel_remote_cache \ + --bazel_options=--config=rbe_cpu_pool \ + --bazel_startup_options="$bazel_startup_options" \ --python_version=$JAXCI_HERMETIC_PYTHON_VERSION \ - --verbose --detailed_timestamped_log --use_new_wheel_build_rule \ + $cuda_version_flag \ + --verbose --detailed_timestamped_log \ + --output_path="$JAXCI_OUTPUT_DIR" \ $artifact_tag_flags # If building release artifacts, we also build a release candidate ("rc") @@ -103,20 +122,15 @@ if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then if [[ "$JAXCI_ARTIFACT_TYPE" == "release" ]]; then python build/build.py build --wheels="$artifact" \ --bazel_options=--config="$bazelrc_config" $bazel_remote_cache \ + --bazel_options=--config=rbe_cpu_pool \ + --bazel_startup_options="$bazel_startup_options" \ --python_version=$JAXCI_HERMETIC_PYTHON_VERSION \ - --verbose --detailed_timestamped_log --use_new_wheel_build_rule \ + $cuda_version_flag \ + --verbose --detailed_timestamped_log \ + --output_path="$JAXCI_OUTPUT_DIR" \ $artifact_tag_flags --bazel_options=--repo_env=ML_WHEEL_VERSION_SUFFIX="$JAXCI_WHEEL_RC_VERSION" fi - # Move the built artifacts from the Bazel cache directory to the output - # directory. - if [[ "$artifact" == "jax" ]]; then - mv bazel-bin/dist/*.whl "$JAXCI_OUTPUT_DIR" - mv bazel-bin/dist/*.tar.gz "$JAXCI_OUTPUT_DIR" - else - mv bazel-bin/jaxlib/tools/dist/*.whl "$JAXCI_OUTPUT_DIR" - fi - # If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we # run `auditwheel show` to verify manylinux compliance. if [[ "$os" == "linux" ]] && [[ "$artifact" != "jax" ]]; then diff --git a/ci/envs/README.md b/ci/envs/README.md new file mode 100644 index 000000000000..bf1ba648052d --- /dev/null +++ b/ci/envs/README.md @@ -0,0 +1,46 @@ +# JAXCI Environment Variables + +This docpage describes the various `JAXCI` environment variables that are used +in the CI scripts and their behaviors. These variables are used to control the +behavior of the CI scripts such as the Python version used, path to JAX/XLA +repo, if to clone XLA repo, etc. + +Name | Default Value | Behavior | Usage +------------------------------------------- | ---------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ----- +`JAXCI_JAX_GIT_DIR` | Present working directory: `$(pwd)` | Path to the JAX's Git directory. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_JAX_GIT_DIR&type=code) +`JAXCI_HERMETIC_PYTHON_VERSION` | System default | Controls the version of hermetic Python to use. This affects the Bazel commands only such as when building artifacts or when running the Bazel test scripts. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_HERMETIC_PYTHON_VERSION&type=code) +`JAXCI_CUDA_VERSION` | 12 | Controls the CUDA version to use when building the JAX artifacts or running the tests. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_CUDA_VERSION&type=code) +`JAXCI_XLA_GIT_DIR` | Unset | When using a local copy of XLA, this points to the root of the XLA git repository. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_XLA_GIT_DIR&type=code) +`JAXCI_CLONE_MAIN_XLA` | 0 | If set to 1, the XLA repository is cloned at HEAD and its path is set in `JAXCI_XLA_GIT_DIR` | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_CLONE_MAIN_XLA&type=code) +`JAXCI_XLA_COMMIT` | Unset | Allows overriding the XLA commit that is used when using a local copy of XLA. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_XLA_COMMIT&type=code) +`JAXCI_OUTPUT_DIR` | `$(pwd)/dist` | Controls the location where the artifacts are written to. The directory will be automatically created if it does not exist. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_OUTPUT_DIR&type=code) +`JAXCI_BUILD_ARTIFACT_WITH_RBE` | 0 | When set to 1, Bazel will use RBE to build the artifacts. Requires gcloud authentication and only certain platforms support RBE so this typically only set in CI builds | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_BUILD_ARTIFACT_WITH_RBE&type=code) +`JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE` | 0 | When set to 1, Bazel will also try to push new cache entries to the cache bucket. Since writes to the bucket require authentication, this flag is enabled only for CI builds. Note that the builds using RBE use the RBE cache and not Bazel's remote cache, therefore this variable is a no-op if `JAXCI_BUILD_ARTIFACT_WITH_RBE` is set to 1. When `JAXCI_BUILD_ARTIFACT_WITH_RBE` and `JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE` are both not set, Bazel will still read from the public cache bucket to try to speed up the build. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE&type=code) +`JAXCI_ARTIFACT_TYPE` | "default" | Controls the type of artifacts to build. Valid values are "default", "release", "nightly". This affects the wheel tag and metadata, see [ci/build_artifacts.sh](https://github.com/jax-ml/jax/blob/main/ci/build_artifacts.sh) to understand how. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_ARTIFACT_TYPE&type=code) +`JAXCI_WHEEL_RC_VERSION` | Unset | During the release process, we build a Release Candidate (RC) wheel in addition to the release wheel. This environment variable sets the version of the RC wheel to build. Values are set internally. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_WHEEL_RC_VERSION&type=code) +`JAXCI_PYTHON` | `python${JAXCI_HERMETIC_PYTHON_VERSION}` | Points to the system Python binary to use. It used by scripts that make use of the system Python such as the Pytest scripts. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_PYTHON&type=code) +`JAXCI_ENABLE_X64` | 0 | By default, JAX enforces single-precision numbers to mitigate the Numpy API’s tendency to aggressively promote operands to `double`. When set to 1, the tests will use double-precision numbers. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_ENABLE_X64&type=code) +`JAXCI_TPU_CORES` | Unset | Sets the number of TPU cores for the TPU machine type. Values are set in the workflow files. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_TPU_CORES&type=code) +`JAXCI_RUN_FULL_TPU_TEST_SUITE` | 0 | When set to 1, the full TPU test suite is run. Otherwise, a subset of tests is run. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_RUN_FULL_TPU_TEST_SUITE&type=code) +`JAXCI_JAX_PYPI_EXTRAS` | Unset | Used to control the installation of JAX extras from PyPI. See JAX's [setup.py](https://github.com/jax-ml/jax/blob/c9934912885bb7c4b72c5a9271598235a6789a81/setup.py#L71) for the list of valid values. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_JAX_PYPI_EXTRAS&type=code) +`JAXCI_BUILD_JAXLIB` | true | Used to control the value of [build_jaxlib](https://github.com/jax-ml/jax/blob/338b4ebc8a5478e3d22efc9530be71d69c3bb993/jax/BUILD#L55-L63) flag. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_BUILD_JAXLIB&type=code) +`JAXCI_BUILD_JAX` | true | Used to control the value of [build_jax](https://github.com/jax-ml/jax/blob/338b4ebc8a5478e3d22efc9530be71d69c3bb993/jax/BUILD#L92-L100) flag. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_BUILD_JAX&type=code) +`JAXCI_BAZEL_OUTPUT_BASE` | Unset | Used to control the output base for Bazel builds. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_BAZEL_OUTPUT_BASE&type=code) +`JAXCI_BAZEL_CPU_RBE_MODE` | test | Used to control whether to run or build the CPU test targets. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_BAZEL_CPU_RBE_MODE&type=code) + +## Docker Specific Environment Variables + +> [!NOTE] +> The following environment variables only affect the build if the +> [run_docker_container.sh](https://github.com/jax-ml/jax/blob/main/ci/utilities/run_docker_container.sh) +> script was invoked to start a Docker container and the build is running inside +> that container. Typically, this would be the internal CI builds and local +> builds. Note that while GitHub actions use the same Docker images, they do not +> invoke "run_docker_container.sh" as they leverage built-in containerization +> features to run jobs within a container. + +Name | Default Value | Behavior | Usage +----------------------- | ------------------------------------------------------------------------------------------------------------ | ---------------------------------------------------------------------------------------------------- | ----- +`JAXCI_DOCKER_WORK_DIR` | "/jax" | The path on the container where the JAX Git repository is mounted to. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_DOCKER_WORK_DIR&type=code) +`JAXCI_DOCKER_ARGS` | Empty String | Space separated string of additional arguments that will be passed when starting the Docker container | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_DOCKER_ARGS&type=code) +`JAXCI_DOCKER_IMAGE` | Depends on the system (see [ci/envs/docker.env](https://github.com/jax-ml/jax/blob/main/ci/envs/docker.env)) | Docker image to pull | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_DOCKER_IMAGE&type=code) diff --git a/ci/envs/default.env b/ci/envs/default.env index a5a5d56eb8b3..c35ec991c276 100644 --- a/ci/envs/default.env +++ b/ci/envs/default.env @@ -13,9 +13,8 @@ # limitations under the License. # ============================================================================== # This file contains all the default values for the "JAXCI_" environment -# variables used in the CI scripts. These variables are used to control the -# behavior of the CI scripts such as the Python version used, path to JAX/XLA -# repo, if to clone XLA repo, etc. +# variables used in the CI scripts. See ci/envs/README.md for more details on +# the behavior of these variables and their usage in the CI scripts. # The path to the JAX git repository. export JAXCI_JAX_GIT_DIR=$(pwd) @@ -24,13 +23,15 @@ export JAXCI_JAX_GIT_DIR=$(pwd) # set. export JAXCI_HERMETIC_PYTHON_VERSION=${JAXCI_HERMETIC_PYTHON_VERSION:-$(python3 -V | awk '{print $2}' | awk -F. '{print $1"."$2}')} +# Controls the CUDA version to use when building the JAX artifacts or +# running the tests. +export JAXCI_CUDA_VERSION=${JAXCI_CUDA_VERSION:-12} + # Set JAXCI_XLA_GIT_DIR to the root of the XLA git repository to use a local -# copy of XLA instead of the pinned version in the WORKSPACE. When -# JAXCI_CLONE_MAIN_XLA=1, this gets set automatically. +# copy of XLA instead of the pinned version in the WORKSPACE. export JAXCI_XLA_GIT_DIR=${JAXCI_XLA_GIT_DIR:-} -# If set to 1, the builds will clone the XLA repository at HEAD and set its -# path in JAXCI_XLA_GIT_DIR. +# If set to 1, the builds will clone the XLA repository at HEAD. export JAXCI_CLONE_MAIN_XLA=${JAXCI_CLONE_MAIN_XLA:-0} # Allows overriding the XLA commit that is used. @@ -39,49 +40,53 @@ export JAXCI_XLA_COMMIT=${JAXCI_XLA_COMMIT:-} # Controls the location where the artifacts are written to. export JAXCI_OUTPUT_DIR="$(pwd)/dist" -# When enabled, artifacts will be built with RBE. Requires gcloud authentication -# and only certain platforms support RBE. Therefore, this flag is enabled only -# for CI builds where RBE is supported. +# Whether to use RBE to build the artifacts. export JAXCI_BUILD_ARTIFACT_WITH_RBE=${JAXCI_BUILD_ARTIFACT_WITH_RBE:-0} -# On platforms where RBE is not supported, we use Bazel remote cache to speed up -# builds. When this flag is enabled, Bazel will also try to push new cache -# entries to the bucket. Since writes to the bucket require authentication, this -# flag is enabled only for CI builds. +# Whether to write new cache entries to the remote cache bucket. export JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE=${JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE:-0} -# Type of artifacts to build. Valid values are "default", "release", "nightly". -# This affects the wheel naming/tag. +# Controls the type of artifacts to build. Valid values are "default", "release", "nightly". export JAXCI_ARTIFACT_TYPE=${JAXCI_ARTIFACT_TYPE:-"default"} -# When building release artifacts, we build a release candidate wheel ("rc" -# tagged wheel) in addition to the release wheel. This environment variable -# sets the version of the release candidate ("RC") artifact to build. +# Controls the version of the Release Candidate wheel to build during the +# release process. export JAXCI_WHEEL_RC_VERSION=${JAXCI_WHEEL_RC_VERSION:-} # ############################################################################# # Test script specific environment variables. # ############################################################################# -# Sets the value of `JAX_ENABLE_X64` in the test scripts. CI builds override -# this value in the Github action workflow files. +# Whether to use double-precision numbers in the tests. export JAXCI_ENABLE_X64=${JAXCI_ENABLE_X64:-0} -# Pytest specific environment variables below. Used in run_pytest_*.sh scripts. -# Sets the number of TPU cores for the TPU machine type. These values are -# defined in the TPU GitHub Actions workflow. +# Sets the number of TPU cores for the TPU machine type. export JAXCI_TPU_CORES=${JAXCI_TPU_CORES:-} -# JAXCI_PYTHON points to the Python interpreter to use for installing JAX wheels -# on the system. By default, it is set to match the version of the hermetic -# Python used by Bazel for building the wheels. +# JAXCI_PYTHON points to the Python binary on the system that should be used +# for installing the JAX wheels on the system and running Pytest scripts. export JAXCI_PYTHON=${JAXCI_PYTHON:-python${JAXCI_HERMETIC_PYTHON_VERSION}} # When set to 1, the full TPU test suite is run. Otherwise, a subset of tests # is run. export JAXCI_RUN_FULL_TPU_TEST_SUITE=${JAXCI_RUN_FULL_TPU_TEST_SUITE:-0} -# We use this environment variable to control which additional wheels to install -# from PyPI. For instance, it can be set to "tpu_pypi" to install the latest -# libtpu wheel from PyPI. See ci/utilities/install_wheels_locally.sh for the -# list of valid values and their behavior. -export JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI=${JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI:-""} \ No newline at end of file +# Controls which additional extras for JAX to install from PyPI. +export JAXCI_JAX_PYPI_EXTRAS=${JAXCI_JAX_PYPI_EXTRAS:-""} + +# Controls the value of --//jax:build_jaxlib. +# true: add individual jaxlib and CUDA plugin targets in the test dependencies. +# false: add pre-built jaxlib and CUDA plugin wheels in the test dependencies. +# wheel: add jaxlib and CUDA plugin py_import targets in the test dependencies. +export JAXCI_BUILD_JAXLIB=${JAXCI_BUILD_JAXLIB:-true} + +# Controls the value of --//jax:build_jax flag. +# true: add individual jax targets in the test dependencies. +# false: add pre-built jax wheel in the test dependencies. +# wheel: add jax py_import target in the test dependencies. +export JAXCI_BUILD_JAX=${JAXCI_BUILD_JAX:-true} + +# Controls the output base for Bazel builds. +export JAXCI_BAZEL_OUTPUT_BASE=${JAXCI_BAZEL_OUTPUT_BASE:-} + +# Controls whether to build or run CPU test targets. +export JAXCI_BAZEL_CPU_RBE_MODE=${JAXCI_BAZEL_CPU_RBE_MODE:-"test"} \ No newline at end of file diff --git a/ci/envs/docker.env b/ci/envs/docker.env index 82a76d33350c..cef2cda27bf4 100644 --- a/ci/envs/docker.env +++ b/ci/envs/docker.env @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -# This file contains all the docker specifc envs that are needed by the +# This file contains all the docker specific envs that are needed by the # ci/utilities/run_docker_container.sh script. os=$(uname -s | awk '{print tolower($0)}') @@ -29,17 +29,17 @@ export JAXCI_DOCKER_ARGS="" # Linux x86 image for building JAX artifacts, running Pytests CPU/TPU tests, and # Bazel tests if [[ $os == "linux" ]] && [[ $arch == "x86_64" ]]; then - export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" + export JAXCI_DOCKER_IMAGE="us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest" fi # Linux Aarch64 image for building JAX artifacts, running Pytests CPU tests, and # Bazel tests if [[ $os == "linux" ]] && [[ $arch == "aarch64" ]]; then - export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest" + export JAXCI_DOCKER_IMAGE="us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build-arm64:latest" fi # Windows image for building JAX artifacts, running Pytests CPU tests, and Bazel # tests if [[ $os =~ "msys_nt" ]]; then - export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/tf-test-windows@sha256:6e2b299f12418d70ea522646b3dd618042a102f2ac2e4f8b1e423638549ea801" + export JAXCI_DOCKER_IMAGE="us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/tf-test-windows:latest" fi \ No newline at end of file diff --git a/ci/jax_ci_system.png b/ci/jax_ci_system.png new file mode 100644 index 000000000000..19efe62ae59e Binary files /dev/null and b/ci/jax_ci_system.png differ diff --git a/ci/k8s/indexed-job.yaml b/ci/k8s/indexed-job.yaml new file mode 100644 index 000000000000..c38a8c9991a2 --- /dev/null +++ b/ci/k8s/indexed-job.yaml @@ -0,0 +1,42 @@ +apiVersion: v1 +kind: Service +metadata: + name: jaxpods +spec: + publishNotReadyAddresses: true + clusterIP: None + selector: + job-name: jaxjob +--- +apiVersion: batch/v1 +kind: Job +metadata: + name: jaxjob +spec: + parallelism: 8 + completions: 8 + completionMode: Indexed + backoffLimit: 0 + template: + spec: + subdomain: jaxpods # must match headless service name + serviceAccountName: jax-job-sa + restartPolicy: Never + containers: + - name: main + image: local/jax:latest + imagePullPolicy: IfNotPresent + resources: + limits: + cpu: 100m + command: + - python + args: + - -c + - | + import jax + jax.distributed.initialize() + print(jax.devices()) + print(jax.local_devices()) + assert jax.process_count() > 1 + assert len(jax.devices()) > len(jax.local_devices()) diff --git a/ci/k8s/jobset.yaml b/ci/k8s/jobset.yaml new file mode 100644 index 000000000000..00150d0a9095 --- /dev/null +++ b/ci/k8s/jobset.yaml @@ -0,0 +1,34 @@ +apiVersion: jobset.x-k8s.io/v1alpha2 +kind: JobSet +metadata: + name: jaxjob +spec: + replicatedJobs: + - name: workers + template: + spec: + parallelism: 8 + completions: 8 + backoffLimit: 0 + template: + spec: + serviceAccountName: jax-job-sa + restartPolicy: Never + containers: + - name: main + image: local/jax:latest + imagePullPolicy: Never + resources: + limits: + cpu: 100m + command: + - python + args: + - -c + - | + import jax + jax.distributed.initialize() + print(jax.devices()) + print(jax.local_devices()) + assert jax.process_count() > 1 + assert len(jax.devices()) > len(jax.local_devices()) diff --git a/ci/run_bazel_test_cpu_rbe.sh b/ci/run_bazel_test_cpu_rbe.sh index 248111e0247a..fd030d90f548 100755 --- a/ci/run_bazel_test_cpu_rbe.sh +++ b/ci/run_bazel_test_cpu_rbe.sh @@ -37,32 +37,69 @@ source "ci/utilities/setup_build_environment.sh" os=$(uname -s | awk '{print tolower($0)}') arch=$(uname -m) +bazel_output_base="" +# Adjust os and arch for Windows +if [[ $os =~ "msys_nt" ]] && [[ $arch =~ "x86_64" ]]; then + os="windows" + arch="amd64" + bazel_output_base="--output_base=C:\actions-runner\_work\bazel_output_base" +fi + +if [[ "$JAXCI_HERMETIC_PYTHON_VERSION" == *"-nogil" ]]; then + JAXCI_HERMETIC_PYTHON_VERSION=${JAXCI_HERMETIC_PYTHON_VERSION%-nogil}-ft + FREETHREADED_FLAG_VALUE="yes" +else + FREETHREADED_FLAG_VALUE="no" +fi + + # TODO(b/446172564): Remove this condition when the test is fixed on all + # platforms. +if [[ $os == "linux" ]] && [[ $arch == "x86_64" ]]; then + IGNORE_TESTS="" +else + IGNORE_TESTS="-//tests/multiprocess:array_test_cpu" +fi + +if [[ "$JAXCI_BAZEL_CPU_RBE_MODE" == 'build' ]]; then + echo "Building RBE CPU tests..." +else + echo "Running RBE CPU tests..." +fi + +test_strategy="" # When running on Mac or Linux Aarch64, we only build the test targets and # not run them. These platforms do not have native RBE support so we # RBE cross-compile them on remote Linux x86 machines. As the tests still # need to be run on the host machine and because running the tests on a # single machine can take a long time, we skip running them on these -# platforms. +# platforms in the presubmit jobs. if [[ $os == "darwin" ]] || ( [[ $os == "linux" ]] && [[ $arch == "aarch64" ]] ); then - echo "Building RBE CPU tests..." - bazel build --config=rbe_cross_compile_${os}_${arch} \ - --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ - --override_repository=xla="${JAXCI_XLA_GIT_DIR}" \ - --test_env=JAX_NUM_GENERATED_CASES=25 \ - --test_env=JAX_SKIP_SLOW_TESTS=true \ - --action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \ - --test_output=errors \ - --color=yes \ - //tests:cpu_tests //tests:backend_independent_tests + rbe_config=rbe_cross_compile_${os}_${arch} + if [[ "$JAXCI_BAZEL_CPU_RBE_MODE" == 'test' ]]; then + test_strategy="--strategy=TestRunner=local" + fi else - echo "Running RBE CPU tests..." - bazel test --config=rbe_${os}_${arch} \ - --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ - --override_repository=xla="${JAXCI_XLA_GIT_DIR}" \ - --test_env=JAX_NUM_GENERATED_CASES=25 \ - --test_env=JAX_SKIP_SLOW_TESTS=true \ - --action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \ - --test_output=errors \ - --color=yes \ - //tests:cpu_tests //tests:backend_independent_tests -fi \ No newline at end of file + rbe_config=rbe_${os}_${arch} +fi + +bazel $bazel_output_base $JAXCI_BAZEL_CPU_RBE_MODE \ + --build_runfile_links=false \ + --config=$rbe_config \ + --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ + --@rules_python//python/config_settings:py_freethreaded="$FREETHREADED_FLAG_VALUE" \ + --override_repository=xla="${JAXCI_XLA_GIT_DIR}" \ + --//jax:build_jaxlib=$JAXCI_BUILD_JAXLIB \ + --//jax:build_jax=$JAXCI_BUILD_JAX \ + $test_strategy \ + --test_env=JAX_NUM_GENERATED_CASES=25 \ + --test_env=JAX_SKIP_SLOW_TESTS=true \ + --action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \ + --test_output=errors \ + --color=yes \ + -- \ + //tests:cpu_tests //tests:backend_independent_tests \ + //jax/experimental/jax2tf/tests:jax2tf_test_cpu \ + //tests/multiprocess:cpu_tests \ + //jax/experimental/jax2tf/tests/multiprocess:cpu_tests \ + //jaxlib/tools:check_cpu_wheel_sources_test \ + $IGNORE_TESTS \ No newline at end of file diff --git a/ci/run_bazel_test_cuda_non_rbe.sh b/ci/run_bazel_test_cuda_non_rbe.sh index 176efd3444c9..98d5955d60cb 100755 --- a/ci/run_bazel_test_cuda_non_rbe.sh +++ b/ci/run_bazel_test_cuda_non_rbe.sh @@ -15,8 +15,8 @@ # ============================================================================== # Run Bazel GPU tests without RBE. This runs two commands: single accelerator # tests with one GPU a piece, multiaccelerator tests with all GPUS. -# Requires that jaxlib, jax-cuda-plugin, and jax-cuda-pjrt wheels are stored -# inside the ../dist folder +# If $JAXCI_BUILD_JAXLIB=false, the job requires that jaxlib, jax-cuda-plugin, +# and jax-cuda-pjrt wheels are stored inside the ../dist folder # # -e: abort script if one command fails # -u: error if undefined variable used @@ -63,6 +63,36 @@ if [[ $host_memory_limit -lt $num_test_jobs ]]; then fi # End of test environment variables setup. +if [[ "$JAXCI_HERMETIC_PYTHON_VERSION" == *"-nogil" ]]; then + JAXCI_HERMETIC_PYTHON_VERSION=${JAXCI_HERMETIC_PYTHON_VERSION%-nogil}-ft + FREETHREADED_FLAG_VALUE="yes" +else + FREETHREADED_FLAG_VALUE="no" +fi + +# Get the CUDA major version only +cuda_major_version="${JAXCI_CUDA_VERSION%%.*}" + +if [[ "$JAXCI_BUILD_JAXLIB" == "wheel" ]]; then + TEST_CONFIG="rbe_linux_x86_64_cuda$cuda_major_version" + TEST_STRATEGY="--strategy=TestRunner=local" + CACHE_OPTION="" +else + TEST_CONFIG="ci_linux_x86_64_cuda$cuda_major_version" + CACHE_OPTION="--config=ci_rbe_cache" + TEST_STRATEGY="" +fi + +# Enable hermetic UMD 13.0 for NVIDIA drivers older than 580. +driver_version=$(nvidia-smi --query-gpu=driver_version --format=csv,noheader | head -n 1) +driver_major_version=${driver_version%%.*} +if [[ "$driver_major_version" -lt "580" ]]; then + echo "NVIDIA driver version ($driver_version) is older than 580." + echo "Enabling hermetic UMD 13.0." + TEST_CONFIG="$TEST_CONFIG --repo_env=HERMETIC_CUDA_UMD_VERSION=13.0.0" +fi + + # Don't abort the script if one command fails to ensure we run both test # commands below. set +e @@ -71,16 +101,18 @@ set +e # It appears --run_under needs an absolute path. # The product of the `JAX_ACCELERATOR_COUNT`` and `JAX_TESTS_PER_ACCELERATOR` # should match the VM's CPU core count (set in `--local_test_jobs`). -bazel test --config=ci_linux_x86_64_cuda \ - --config=resultstore \ - --config=rbe_cache \ +bazel test --config=$TEST_CONFIG \ + $CACHE_OPTION \ --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ - --//jax:build_jaxlib=false \ + --@rules_python//python/config_settings:py_freethreaded="$FREETHREADED_FLAG_VALUE" \ + --//jax:build_jaxlib=$JAXCI_BUILD_JAXLIB \ + --//jax:build_jax=$JAXCI_BUILD_JAX \ --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \ --run_under "$(pwd)/build/parallel_accelerator_execute.sh" \ --test_output=errors \ --test_env=JAX_ACCELERATOR_COUNT=$gpu_count \ --test_env=JAX_TESTS_PER_ACCELERATOR=$max_tests_per_gpu \ + $TEST_STRATEGY \ --local_test_jobs=$num_test_jobs \ --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow \ --test_tag_filters=-multiaccelerator \ @@ -89,6 +121,8 @@ bazel test --config=ci_linux_x86_64_cuda \ --action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \ --action_env=NCCL_DEBUG=WARN \ --color=yes \ + --config=cuda_libraries_from_stubs \ + --config=hermetic_cuda_umd \ //tests:gpu_tests //tests:backend_independent_tests \ //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests @@ -97,21 +131,26 @@ first_bazel_cmd_retval=$? echo "Running multi-accelerator tests (without RBE)..." # Runs multiaccelerator tests with all GPUs directly on the VM without RBE.. -bazel test --config=ci_linux_x86_64_cuda \ - --config=resultstore \ - --config=rbe_cache \ +bazel test --config=$TEST_CONFIG \ + $CACHE_OPTION \ --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ - --//jax:build_jaxlib=false \ + --@rules_python//python/config_settings:py_freethreaded="$FREETHREADED_FLAG_VALUE" \ + --//jax:build_jaxlib=$JAXCI_BUILD_JAXLIB \ + --//jax:build_jax=$JAXCI_BUILD_JAX \ --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \ --test_output=errors \ - --jobs=8 \ + $TEST_STRATEGY \ + --local_test_jobs=8 \ --test_tag_filters=multiaccelerator \ --test_env=TF_CPP_MIN_LOG_LEVEL=0 \ --test_env=JAX_SKIP_SLOW_TESTS=true \ --action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \ --action_env=NCCL_DEBUG=WARN \ --color=yes \ - //tests:gpu_tests //tests/pallas:gpu_tests + --config=cuda_libraries_from_stubs \ + --config=hermetic_cuda_umd \ + //tests:gpu_tests //tests/pallas:gpu_tests \ + //tests/multiprocess:gpu_tests # Store the return value of the second bazel command. second_bazel_cmd_retval=$? diff --git a/ci/run_bazel_test_cuda_rbe.sh b/ci/run_bazel_test_cuda_rbe.sh index 17bd8d9db4f8..d1ebe748ad57 100755 --- a/ci/run_bazel_test_cuda_rbe.sh +++ b/ci/run_bazel_test_cuda_rbe.sh @@ -34,10 +34,16 @@ fi # Set up the build environment. source "ci/utilities/setup_build_environment.sh" +if [[ "$JAXCI_BUILD_JAXLIB" != "true" ]]; then + cuda_libs_flag="--config=cuda_libraries_from_stubs" +else + cuda_libs_flag="--@local_config_cuda//cuda:override_include_cuda_libs=true" +fi + # Run Bazel GPU tests with RBE (single accelerator tests with one GPU apiece). echo "Running RBE GPU tests..." -bazel test --config=rbe_linux_x86_64_cuda \ +bazel test --config=rbe_linux_x86_64_cuda${JAXCI_CUDA_VERSION} \ --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ --override_repository=xla="${JAXCI_XLA_GIT_DIR}" \ --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \ @@ -48,4 +54,10 @@ bazel test --config=rbe_linux_x86_64_cuda \ --test_env=JAX_SKIP_SLOW_TESTS=true \ --action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \ --color=yes \ - //tests:gpu_tests //tests:backend_independent_tests //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests \ No newline at end of file + $cuda_libs_flag \ + --config=hermetic_cuda_umd \ + --//jax:build_jaxlib=$JAXCI_BUILD_JAXLIB \ + --//jax:build_jax=$JAXCI_BUILD_JAX \ + //tests:gpu_tests //tests:backend_independent_tests \ + //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests \ + //jaxlib/tools:check_gpu_wheel_sources_test diff --git a/ci/run_bazel_test_tpu.sh b/ci/run_bazel_test_tpu.sh new file mode 100755 index 000000000000..f0063109a57b --- /dev/null +++ b/ci/run_bazel_test_tpu.sh @@ -0,0 +1,240 @@ +#!/bin/bash +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Runs Bazel TPU tests. If $JAXCI_BUILD_JAXLIB=false and $JAXCI_BUILD_JAX=false, +# the job requires that jax and jaxlib wheels are stored inside the ../dist +# folder. +# +# -e: abort script if one command fails +# -u: error if undefined variable used +# -x: log all commands +# -o history: record shell history +# -o allexport: export all functions and variables to be available to subscripts +set -exu -o history -o allexport + +# Source default JAXCI environment variables. +source ci/envs/default.env + +# Clone XLA at HEAD if path to local XLA is not provided +if [[ -z "$JAXCI_XLA_GIT_DIR" ]]; then + export JAXCI_CLONE_MAIN_XLA=1 +fi + +# Set up the build environment. +source "ci/utilities/setup_build_environment.sh" + +if [[ "$JAXCI_HERMETIC_PYTHON_VERSION" == *"-nogil" ]]; then + JAXCI_HERMETIC_PYTHON_VERSION=${JAXCI_HERMETIC_PYTHON_VERSION%-nogil}-ft + FREETHREADED_FLAG_VALUE="yes" +else + FREETHREADED_FLAG_VALUE="no" +fi + +OVERRIDE_XLA_REPO="" +if [[ "$JAXCI_CLONE_MAIN_XLA" == 1 ]]; then + OVERRIDE_XLA_REPO="--override_repository=xla=${JAXCI_XLA_GIT_DIR}" +fi + +NB_TPUS=$JAXCI_TPU_CORES +JOBS_PER_ACC=1 +J=$((NB_TPUS * JOBS_PER_ACC)) + +# TODO(ybaturina): Bazel cache shouldn't be invalidated when +# `VBAR_CONTROL_SERVICE_URL` changes. +COMMON_TPU_TEST_ENV_VARS="--test_env=TPU_SKIP_MDS_QUERY=true \ + --test_env=TPU_TOPOLOGY \ + --test_env=TPU_WORKER_ID \ + --test_env=TPU_TOPOLOGY_WRAP \ + --test_env=TPU_CHIPS_PER_HOST_BOUNDS \ + --test_env=TPU_ACCELERATOR_TYPE \ + --test_env=TPU_RUNTIME_METRICS_PORTS \ + --test_env=TPU_TOPOLOGY_ALT \ + --test_env=TPU_HOST_BOUNDS \ + --test_env=TPU_WORKER_HOSTNAMES \ + --test_env=CHIPS_PER_HOST_BOUNDS \ + --test_env=HOST_BOUNDS \ + --test_env=VBAR_CONTROL_SERVICE_URL" + +echo "Running Bazel TPU tests..." + +# Don't abort the script if one command fails to ensure we run both test +# commands below. +set +e + +# TODO(emilyaf): Debug and re-enable this test. +IGNORE_TESTS_MULTIACCELERATOR="-//tests/multiprocess:array_test_tpu" + +if [[ "$JAXCI_RUN_FULL_TPU_TEST_SUITE" == "1" ]]; then + # We're deselecting all Pallas TPU tests in the oldest libtpu build. Mosaic + # TPU does not guarantee anything about forward compatibility (unless + # jax.export is used) and the 12 week compatibility window accumulates way + # too many failures. + IGNORE_TESTS="" + if [ "${libtpu_version_type:-""}" == "oldest_supported_libtpu" ]; then + IGNORE_TESTS="-//tests/pallas/..." + else + IGNORE_TESTS="-//tests/pallas:tpu_pallas_interpret_thread_map_test_tpu" + fi + + # Run single-accelerator tests in parallel + bazel test \ + --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ + --@rules_python//python/config_settings:py_freethreaded="$FREETHREADED_FLAG_VALUE" \ + $OVERRIDE_XLA_REPO \ + --config=ci_linux_x86_64 \ + --config=ci_rbe_cache \ + --//jax:build_jaxlib=$JAXCI_BUILD_JAXLIB \ + --//jax:build_jax=$JAXCI_BUILD_JAX \ + --run_under="$(pwd)/build/parallel_accelerator_execute.sh" \ + --test_env=JAX_ACCELERATOR_COUNT=${NB_TPUS} \ + --test_env=JAX_TESTS_PER_ACCELERATOR=${JOBS_PER_ACC} \ + --strategy=TestRunner=local \ + --local_test_jobs=$J \ + --test_env=JAX_TEST_NUM_THREADS=$J \ + --test_env=ALLOW_MULTIPLE_LIBTPU_LOAD=true \ + --test_env=JAX_SKIP_SLOW_TESTS=1 \ + --test_env=JAX_ENABLE_TPU_XDIST=1 \ + --test_env=JAX_PLATFORMS=tpu,cpu \ + --repo_env=USE_MINIMAL_SHARD_COUNT=True \ + $COMMON_TPU_TEST_ENV_VARS \ + --test_tag_filters=-multiaccelerator \ + --verbose_failures \ + --test_output=errors \ + -- \ + //tests:tpu_tests \ + //tests/pallas:tpu_tests \ + $IGNORE_TESTS + + # Store the return value of the first bazel command. + first_bazel_cmd_retval=$? + + # Run multi-accelerator across all chips + bazel test \ + --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ + --@rules_python//python/config_settings:py_freethreaded="$FREETHREADED_FLAG_VALUE" \ + $OVERRIDE_XLA_REPO \ + --config=ci_linux_x86_64 \ + --config=ci_rbe_cache \ + --//jax:build_jaxlib=$JAXCI_BUILD_JAXLIB \ + --//jax:build_jax=$JAXCI_BUILD_JAXLIB \ + --test_env=ALLOW_MULTIPLE_LIBTPU_LOAD=true \ + --strategy=TestRunner=local \ + --local_test_jobs=1 \ + --repo_env=USE_MINIMAL_SHARD_COUNT=True \ + --test_env=JAX_SKIP_SLOW_TESTS=1 \ + --test_env=JAX_PLATFORMS=tpu,cpu \ + $COMMON_TPU_TEST_ENV_VARS \ + --test_tag_filters=multiaccelerator \ + --verbose_failures \ + --test_output=errors \ + -- \ + //tests:tpu_tests \ + //tests/pallas:tpu_tests \ + //tests/multiprocess:tpu_tests \ + $IGNORE_TESTS_MULTIACCELERATOR + + # Store the return value of the second bazel command. + second_bazel_cmd_retval=$? +else + + # Run single-accelerator tests in parallel + bazel test \ + --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ + --@rules_python//python/config_settings:py_freethreaded="$FREETHREADED_FLAG_VALUE" \ + $OVERRIDE_XLA_REPO \ + --config=ci_linux_x86_64 \ + --config=ci_rbe_cache \ + --//jax:build_jaxlib=$JAXCI_BUILD_JAXLIB \ + --//jax:build_jax=$JAXCI_BUILD_JAXLIB \ + --run_under="$(pwd)/build/parallel_accelerator_execute.sh" \ + --test_env=JAX_ACCELERATOR_COUNT=${NB_TPUS} \ + --test_env=JAX_TESTS_PER_ACCELERATOR=${JOBS_PER_ACC} \ + --strategy=TestRunner=local \ + --local_test_jobs=$J \ + --test_env=JAX_TEST_NUM_THREADS=$J \ + --test_env=ALLOW_MULTIPLE_LIBTPU_LOAD=true \ + --test_env=JAX_SKIP_SLOW_TESTS=1 \ + --test_env=JAX_ENABLE_TPU_XDIST=1 \ + --test_env=JAX_PLATFORMS=tpu,cpu \ + --repo_env=USE_MINIMAL_SHARD_COUNT=True \ + $COMMON_TPU_TEST_ENV_VARS \ + --test_tag_filters=-multiaccelerator \ + --verbose_failures \ + --test_output=errors \ + -- \ + //jaxlib/tools:check_tpu_wheel_sources_test \ + //tests/pallas:ops_test_tpu \ + //tests/pallas:export_back_compat_pallas_test_tpu \ + //tests/pallas:export_pallas_test_tpu \ + //tests/pallas:tpu_ops_test_tpu \ + //tests/pallas:tpu_pallas_random_test_tpu \ + //tests/pallas:tpu_pallas_async_test_tpu \ + //tests/pallas:tpu_pallas_state_test_tpu \ + //tests/pallas:tpu_pallas_test_tpu \ + //tests/pallas:tpu_pallas_call_print_test_tpu \ + //tests/pallas:indexing_test_tpu \ + //tests/pallas:pallas_error_handling_test_tpu \ + //tests/pallas:pallas_shape_poly_test_tpu \ + //tests/pallas:tpu_all_gather_test_tpu \ + //tests/pallas:tpu_fusible_matmul_test_tpu \ + //tests/pallas:tpu_pallas_distributed_test_tpu \ + //tests/pallas:tpu_pallas_memory_space_test_tpu \ + //tests/pallas:tpu_splash_attention_kernel_sharded_test_tpu \ + //tests/pallas:tpu_sparsecore_pallas_test_tpu + + # Store the return value of the first bazel command. + first_bazel_cmd_retval=$? + + # Run multi-accelerator across all chips + bazel test \ + --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ + --@rules_python//python/config_settings:py_freethreaded="$FREETHREADED_FLAG_VALUE" \ + $OVERRIDE_XLA_REPO \ + --config=ci_linux_x86_64 \ + --config=ci_rbe_cache \ + --//jax:build_jaxlib=$JAXCI_BUILD_JAXLIB \ + --//jax:build_jax=$JAXCI_BUILD_JAXLIB \ + --test_env=ALLOW_MULTIPLE_LIBTPU_LOAD=true \ + --strategy=TestRunner=local \ + --local_test_jobs=1 \ + --test_env=JAX_ACCELERATOR_COUNT=${NB_TPUS} \ + --repo_env=USE_MINIMAL_SHARD_COUNT=True \ + --test_env=JAX_SKIP_SLOW_TESTS=1 \ + --test_env=JAX_PLATFORMS=tpu,cpu \ + $COMMON_TPU_TEST_ENV_VARS \ + --test_tag_filters=multiaccelerator \ + --verbose_failures \ + --test_output=errors \ + -- \ + //tests:aot_test_tpu \ + //tests:array_test_tpu \ + //tests:jaxpr_effects_test_tpu \ + //tests:layout_test_tpu \ + //tests:pjit_test_tpu \ + //tests:python_callback_test_tpu \ + //tests:ragged_collective_test_tpu + + # Store the return value of the second bazel command. + second_bazel_cmd_retval=$? +fi + +# Exit with failure if either command fails. +if [[ $first_bazel_cmd_retval -ne 0 ]]; then + exit $first_bazel_cmd_retval +elif [[ $second_bazel_cmd_retval -ne 0 ]]; then + exit $second_bazel_cmd_retval +else + exit 0 +fi diff --git a/ci/run_pytest_cpu.sh b/ci/run_pytest_cpu.sh index 43581ef2c96c..7da9504d459e 100755 --- a/ci/run_pytest_cpu.sh +++ b/ci/run_pytest_cpu.sh @@ -26,16 +26,16 @@ set -exu -o history -o allexport # Source default JAXCI environment variables. source ci/envs/default.env +# Set up the build environment. +source "ci/utilities/setup_build_environment.sh" + # Install jaxlib wheel inside the $JAXCI_OUTPUT_DIR directory on the system. echo "Installing wheels locally..." source ./ci/utilities/install_wheels_locally.sh -# Set up the build environment. -source "ci/utilities/setup_build_environment.sh" - # Print all the installed packages echo "Installed packages:" -"$JAXCI_PYTHON" -m uv pip list +"$JAXCI_PYTHON" -m uv pip freeze "$JAXCI_PYTHON" -c "import jax; print(jax.default_backend()); print(jax.devices()); print(len(jax.devices()))" @@ -43,8 +43,17 @@ echo "Installed packages:" export PY_COLORS=1 export JAX_SKIP_SLOW_TESTS=true export TF_CPP_MIN_LOG_LEVEL=0 -export JAX_ENABLE_64="$JAXCI_ENABLE_X64" +export JAX_ENABLE_X64="$JAXCI_ENABLE_X64" + +MAX_PROCESSES=${MAX_PROCESSES:-} +MAX_PROCESSES_ARG="" +if [[ -n "${MAX_PROCESSES}" ]]; then + MAX_PROCESSES_ARG="--maxprocesses=${MAX_PROCESSES}" +elif [[ "$(uname -s)" == *"MSYS"* ]]; then + MAX_PROCESSES_ARG="--maxprocesses=32" # Tests OOM on Windows sometimes. +fi # End of test environment variable setup echo "Running CPU tests..." -"$JAXCI_PYTHON" -m pytest -n auto --tb=short --maxfail=20 tests examples \ No newline at end of file +"$JAXCI_PYTHON" -m pytest -n auto --tb=short $MAX_PROCESSES_ARG \ + --maxfail=20 tests examples diff --git a/ci/run_pytest_cuda.sh b/ci/run_pytest_cuda.sh index 45020542b34b..95fb5f100b22 100755 --- a/ci/run_pytest_cuda.sh +++ b/ci/run_pytest_cuda.sh @@ -36,47 +36,77 @@ source "ci/utilities/setup_build_environment.sh" # Print all the installed packages echo "Installed packages:" -"$JAXCI_PYTHON" -m uv pip list +"$JAXCI_PYTHON" -m uv pip freeze "$JAXCI_PYTHON" -c "import jax; print(jax.default_backend()); print(jax.devices()); print(len(jax.devices()))" nvidia-smi -# Set up all test environment variables +# ============================================================================== +# Set up the generic test environment variables +# ============================================================================== export PY_COLORS=1 export JAX_SKIP_SLOW_TESTS=true export NCCL_DEBUG=WARN export TF_CPP_MIN_LOG_LEVEL=0 -export JAX_ENABLE_64="$JAXCI_ENABLE_X64" +export JAX_ENABLE_X64="$JAXCI_ENABLE_X64" + +# ============================================================================== +# Calculate the optimal number of parallel processes for pytest +# This will be the minimum of: GPU capacity, CPU core count, and a system RAM limit. +# ============================================================================== -# Set the number of processes to min(num_cpu_cores, gpu_count * $max_tests_per_gpu, total_ram_gb / 6) -# We calculate max_tests_per_gpu as memory_per_gpu_gb / 2gb -# Calculate gpu_count * max_tests_per_gpu export gpu_count=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) -export memory_per_gpu_gb=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits --id=0) -export memory_per_gpu_gb=$((memory_per_gpu_gb / 1024)) -# Allow 2 GB of GPU RAM per test -export max_tests_per_gpu=$((memory_per_gpu_gb / 2)) +echo "Number of GPUs detected: $gpu_count" + +echo "Assuming all GPUs are the same model and have the same amount of memory" +export gpu_name=$(nvidia-smi --query-gpu=name --format=csv,noheader --id=0) +echo "Detected GPU type: $gpu_name" + +export memory_per_gpu_mib=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits --id=0) +echo "Reported memory per GPU: $memory_per_gpu_mib MiB" + +# Convert effective memory from MiB to GiB. +export memory_per_gpu_gib=$((memory_per_gpu_mib / 1024)) +echo "Effective memory per GPU: $memory_per_gpu_gib GiB" + +# Allow 2 GiB of GPU RAM per test. +export max_tests_per_gpu=$((memory_per_gpu_gib / 2)) +echo "Max tests per GPU (assuming 2GiB/test): $max_tests_per_gpu" + export num_processes=$((gpu_count * max_tests_per_gpu)) +echo "Initial number of processes based on GPU capacity: $num_processes" -# Calculate num_cpu_cores export num_cpu_cores=$(nproc) +echo "Number of CPU cores available: $num_cpu_cores" -# Calculate total_ram_gb / 6 -export total_ram_gb=$(awk '/MemTotal/ {printf "%.0f", $2/1048576}' /proc/meminfo) -export host_memory_limit=$((total_ram_gb / 6)) +# Reads total memory from /proc/meminfo (in KiB) and converts to GiB. +export total_ram_gib=$(awk '/MemTotal/ {printf "%.0f", $2/1048576}' /proc/meminfo) +echo "Total system RAM: $total_ram_gib GiB" + +# Set a safety limit for system RAM usage, e.g., 1/6th of total. +export host_memory_limit=$((total_ram_gib / 6)) +echo "Host memory process limit (1/6th of total RAM): $host_memory_limit" if [[ $num_cpu_cores -lt $num_processes ]]; then num_processes=$num_cpu_cores + echo "Adjusting num_processes to match CPU core count: $num_processes" fi if [[ $host_memory_limit -lt $num_processes ]]; then num_processes=$host_memory_limit + echo "Adjusting num_processes to match host memory limit: $num_processes" fi +echo "Final number of processes to run: $num_processes" + +export JAX_ENABLE_CUDA_XDIST="$gpu_count" export XLA_PYTHON_CLIENT_ALLOCATOR=platform export XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1 -# End of test environment variable setup + +# ============================================================================== +# Run tests +# ============================================================================== echo "Running CUDA tests..." "$JAXCI_PYTHON" -m pytest -n $num_processes --tb=short --maxfail=20 \ diff --git a/ci/run_pytest_tpu.sh b/ci/run_pytest_tpu.sh index 5d8aa9ed648f..95140d7e5133 100755 --- a/ci/run_pytest_tpu.sh +++ b/ci/run_pytest_tpu.sh @@ -35,14 +35,16 @@ source "ci/utilities/setup_build_environment.sh" # Print all the installed packages echo "Installed packages:" -"$JAXCI_PYTHON" -m uv pip list +"$JAXCI_PYTHON" -m uv pip freeze "$JAXCI_PYTHON" -c "import jax; print(jax.default_backend()); print(jax.devices()); print(len(jax.devices()))" "$JAXCI_PYTHON" -c 'import sys; print("python version:", sys.version)' "$JAXCI_PYTHON" -c 'import jax; print("jax version:", jax.__version__)' "$JAXCI_PYTHON" -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)' -strings /usr/local/lib/"$JAXCI_PYTHON"/dist-packages/libtpu/libtpu.so | grep 'Built on' -"$JAXCI_PYTHON" -c 'import jax; print("libtpu version:",jax.lib.xla_bridge.get_backend().platform_version)' +# Free-threaded builds use "-nogil" as the suffix for the binary and "t" for its +# dist-packages path +strings /usr/local/lib/"${JAXCI_PYTHON//-nogil/t}"/dist-packages/libtpu/libtpu.so | grep 'Built on' +"$JAXCI_PYTHON" -c 'import jax.extend; print("libtpu version:",jax.extend.backend.get_backend().platform_version)' # Set up all common test environment variables export PY_COLORS=1 @@ -52,6 +54,10 @@ export JAX_SKIP_SLOW_TESTS=true echo "Running TPU tests..." +# Don't abort the script if one command fails to ensure we run both test +# commands below. +set +e + if [[ "$JAXCI_RUN_FULL_TPU_TEST_SUITE" == "1" ]]; then # We're deselecting all Pallas TPU tests in the oldest libtpu build. Mosaic # TPU does not guarantee anything about forward compatibility (unless @@ -64,19 +70,22 @@ if [[ "$JAXCI_RUN_FULL_TPU_TEST_SUITE" == "1" ]]; then # Run single-accelerator tests in parallel JAX_ENABLE_TPU_XDIST=true "$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \ - --deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \ - --maxfail=20 -m "not multiaccelerator" $IGNORE_FLAGS tests examples + --deselect=tests/pallas/tpu_pallas_call_print_test.py::PallasCallPrintTest \ + --deselect=tests/pallas/tpu_sparsecore_pallas_test.py::DebugPrintTest \ + --deselect=tests/pallas/tpu_pallas_interpret_thread_map_test.py::InterpretThreadMapTest::test_thread_map \ + --dist=loadfile --maxfail=20 -m "not multiaccelerator" $IGNORE_FLAGS tests examples - # Run Pallas printing tests, which need to run with I/O capturing disabled. - TPU_STDERR_LOG_LEVEL=0 "$JAXCI_PYTHON" -m pytest -s \ - tests/pallas/tpu_pallas_test.py::PallasCallPrintTest + # Store the return value of the first command. + first_cmd_retval=$? # Run multi-accelerator across all chips "$JAXCI_PYTHON" -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests + + # Store the return value of the second command. + second_cmd_retval=$? else # Run single-accelerator tests in parallel JAX_ENABLE_TPU_XDIST=true "$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \ - --deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \ --maxfail=20 -m "not multiaccelerator" \ tests/pallas/ops_test.py \ tests/pallas/export_back_compat_pallas_test.py \ @@ -87,11 +96,33 @@ else tests/pallas/tpu_pallas_async_test.py \ tests/pallas/tpu_pallas_state_test.py - # Run Pallas printing tests, which need to run with I/O capturing disabled. - TPU_STDERR_LOG_LEVEL=0 "$JAXCI_PYTHON" -m pytest -s tests/pallas/tpu_pallas_test.py::PallasCallPrintTest + # Store the return value of the first command. + first_cmd_retval=$? # Run multi-accelerator across all chips "$JAXCI_PYTHON" -m pytest --tb=short --maxfail=20 -m "multiaccelerator" \ tests/pjit_test.py \ tests/pallas/tpu_pallas_distributed_test.py + + # Store the return value of the second command. + second_cmd_retval=$? +fi + +# Run Pallas printing tests, which need to run with I/O capturing disabled. +TPU_STDERR_LOG_LEVEL=0 "$JAXCI_PYTHON" -m pytest \ + -s tests/pallas/tpu_pallas_call_print_test.py::PallasCallPrintTest \ + -s tests/pallas/tpu_sparsecore_pallas_test.py::DebugPrintTest + +# Store the return value of the third command. +third_cmd_retval=$? + +# Exit with failure if either command fails. +if [[ $first_cmd_retval -ne 0 ]]; then + exit $first_cmd_retval +elif [[ $second_cmd_retval -ne 0 ]]; then + exit $second_cmd_retval +elif [[ $third_cmd_retval -ne 0 ]]; then + exit $third_cmd_retval +else + exit 0 fi \ No newline at end of file diff --git a/ci/utilities/README.md b/ci/utilities/README.md new file mode 100644 index 000000000000..35af5241767b --- /dev/null +++ b/ci/utilities/README.md @@ -0,0 +1,16 @@ +# JAX CI Utility Scripts + +This docpage gives a brief overview of the different utility scripts and what +they are used for. + +- **setup_build_environment.sh**: Sets up the build environment such as + cloning the latest XLA, adjusting file paths (for Windows), etc. +- **convert_msys_paths_to_win_paths.py**: Converts MSYS Linux-like paths + stored in env variables to Windows paths. +- **install_wheels_locally.sh**: Used by Pytest scripts to install JAX wheels + and any additional extras on the system. +- **run_auditwheel.sh**: Verifies that the Linux artifacts are "manylinux" + compliant. +- **run_docker_container.sh**: Runs a Docker container called "jax". Images + are read from the `JAXCI_DOCKER_IMAGE` environment variable in + [ci/envs/docker.env](https://github.com/jax-ml/jax/blob/main/ci/envs/docker.env). diff --git a/ci/utilities/install_wheels_locally.sh b/ci/utilities/install_wheels_locally.sh index f98f7658ad18..d66e1fea967b 100644 --- a/ci/utilities/install_wheels_locally.sh +++ b/ci/utilities/install_wheels_locally.sh @@ -22,31 +22,34 @@ WHEELS=( $(/usr/bin/find "$JAXCI_OUTPUT_DIR/" -type f \( -name "*jax*py3*" -o - for i in "${!WHEELS[@]}"; do if [[ "${WHEELS[$i]}" == *jax*py3*none*any.whl ]]; then - if [[ "$JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI" == "tpu_pypi" ]]; then - # Append [tpu] to the jax wheel name to download the latest libtpu wheel - # from PyPI. - WHEELS[$i]="${WHEELS[$i]}[tpu]" + # Append an extra to the end of the JAX wheel path to install those + # packages as well from PyPI. E.g. jax[tpu] will install the libtpu package + # from PyPI. See ci/envs/README.md for more details. + if [[ -n "$JAXCI_JAX_PYPI_EXTRAS" ]]; then + WHEELS[$i]="${WHEELS[$i]}[$JAXCI_JAX_PYPI_EXTRAS]" fi fi done -if [[ -z "${WHEELS[@]}" ]]; then - echo "ERROR: No wheels found under $JAXCI_OUTPUT_DIR" - exit 1 -fi +if [[ -n "${WHEELS[@]}" ]]; then + echo "Installing the following wheels:" + echo "${WHEELS[@]}" -echo "Installing the following wheels:" -echo "${WHEELS[@]}" - -# Install `uv` if it's not already installed. `uv` is much faster than pip for -# installing Python packages. -if ! command -v uv >/dev/null 2>&1; then - pip install uv~=0.5.30 -fi + # Install `uv` if it's not already installed. `uv` is much faster than pip for + # installing Python packages. + if ! command -v uv >/dev/null 2>&1; then + pip install uv~=0.5.30 + fi -# On Windows, convert MSYS Linux-like paths to Windows paths. -if [[ $(uname -s) =~ "MSYS_NT" ]]; then - "$JAXCI_PYTHON" -m uv pip install $(cygpath -w "${WHEELS[@]}") + # On Windows, convert MSYS Linux-like paths to Windows paths. + if [[ $(uname -s) =~ "MSYS_NT" ]]; then + "$JAXCI_PYTHON" -m uv pip install $(cygpath -w "${WHEELS[@]}") + else + "$JAXCI_PYTHON" -m uv pip install "${WHEELS[@]}" + fi else - "$JAXCI_PYTHON" -m uv pip install "${WHEELS[@]}" + # Note that we don't exit here because the wheels may have been installed + # earlier in a different step in the CI job. + echo "INFO: No wheels found under $JAXCI_OUTPUT_DIR" + echo "INFO: Skipping local wheel installation." fi \ No newline at end of file diff --git a/ci/utilities/run_auditwheel.sh b/ci/utilities/run_auditwheel.sh index 30b6a3b51865..304dd1ab1792 100755 --- a/ci/utilities/run_auditwheel.sh +++ b/ci/utilities/run_auditwheel.sh @@ -26,6 +26,10 @@ if [[ -z "$WHEELS" ]]; then fi for wheel in $WHEELS; do + # Skip checking manylinux compliance for jax wheel. + if [[ "$wheel" =~ 'jax-' ]]; then + continue + fi printf "\nRunning auditwheel on the following wheel:" ls $wheel OUTPUT_FULL=$(python -m auditwheel show $wheel) @@ -33,14 +37,19 @@ for wheel in $WHEELS; do wheel_name=$(basename $wheel) OUTPUT=${OUTPUT_FULL//${wheel_name}/} - # If a wheel is manylinux2014 compliant, `auditwheel show` will return the - # platform tag as manylinux_2_17. manylinux2014 is an alias for - # manylinux_2_17. - if echo "$OUTPUT" | grep -q "manylinux_2_17"; then + # If a wheel is manylinux_2_27 or manylinux2014 compliant, `auditwheel show` + # will return platform tag as manylinux_2_27 or manylinux_2_17 respectively. + # manylinux2014 is an alias for manylinux_2_17. + if echo "$OUTPUT" | grep -q "manylinux_2_27"; then + printf "\n$wheel_name is manylinux_2_27 compliant.\n" + # jax_cudaX_plugin...aarch64.whl is consistent with tag: manylinux_2_26_aarch64" + elif echo "$OUTPUT" | grep -q "manylinux_2_26"; then + printf "\n$wheel_name is manylinux_2_26 compliant.\n" + elif echo "$OUTPUT" | grep -q "manylinux_2_17"; then printf "\n$wheel_name is manylinux2014 compliant.\n" else echo "$OUTPUT_FULL" - printf "\n$wheel_name is NOT manylinux2014 compliant.\n" + printf "\n$wheel_name is NOT manylinux_2_27 or manylinux2014 compliant.\n" exit 1 fi done \ No newline at end of file diff --git a/ci/utilities/setup_build_environment.sh b/ci/utilities/setup_build_environment.sh index 114acf2479ff..d7cf81066817 100644 --- a/ci/utilities/setup_build_environment.sh +++ b/ci/utilities/setup_build_environment.sh @@ -16,7 +16,7 @@ # Set up the build environment for JAX CI jobs. This script depends on the # "JAXCI_" environment variables set or sourced in the build script. -# Pre-emptively mark the JAX git directory as safe. This is necessary for JAX CI +# Preemptively mark the JAX git directory as safe. This is necessary for JAX CI # jobs running on Linux runners in GitHub Actions. Without this, git complains # that the directory has dubious ownership and refuses to run any commands. # Avoid running on Windows runners as git runs into issues with not being able @@ -31,6 +31,9 @@ fi function clone_main_xla() { echo "Cloning XLA at HEAD to $(pwd)/xla" git clone --depth=1 https://github.com/openxla/xla.git $(pwd)/xla + cd $(pwd)/xla + echo "XLA commit: $(git log -1 --format=%H)" + cd .. export JAXCI_XLA_GIT_DIR=$(pwd)/xla } diff --git a/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb b/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb index edaa71b93e85..5bc045d0f606 100644 --- a/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb +++ b/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb @@ -225,7 +225,7 @@ "* Jacobian pre-accumulation for elementwise operations (like `gelu`)\n", "\n", "\n", - "For much more, see the [JAX Autodiff Cookbook (Part 1)](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)." + "For much more, see the [JAX Autodiff Cookbook (Part 1)](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html)." ] }, { diff --git a/cloud_tpu_colabs/JAX_demo.ipynb b/cloud_tpu_colabs/JAX_demo.ipynb index d7ba5ed334f4..b69246c57e0b 100644 --- a/cloud_tpu_colabs/JAX_demo.ipynb +++ b/cloud_tpu_colabs/JAX_demo.ipynb @@ -315,7 +315,7 @@ "* Jacobian pre-accumulation for elementwise operations (like `gelu`)\n", "\n", "\n", - "For much more, see the [JAX Autodiff Cookbook (Part 1)](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)." + "For much more, see the [JAX Autodiff Cookbook (Part 1)](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html)." ] }, { diff --git a/cloud_tpu_colabs/Pmap_Cookbook.ipynb b/cloud_tpu_colabs/Pmap_Cookbook.ipynb index ea126ac4f1e7..8b16cd7694eb 100644 --- a/cloud_tpu_colabs/Pmap_Cookbook.ipynb +++ b/cloud_tpu_colabs/Pmap_Cookbook.ipynb @@ -59,7 +59,7 @@ "id": "2e_06-OAJNyi" }, "source": [ - "A basic starting point is expressing parallel maps with [`pmap`](https://jax.readthedocs.io/en/latest/jax.html#jax.pmap):" + "A basic starting point is expressing parallel maps with [`pmap`](https://docs.jax.dev/en/latest/jax.html#jax.pmap):" ] }, { @@ -407,7 +407,7 @@ "source": [ "When writing nested `pmap` functions in the decorator style, axis names are resolved according to lexical scoping.\n", "\n", - "Check [the JAX reference documentation](https://jax.readthedocs.io/en/latest/jax.lax.html#parallel-operators) for a complete list of the parallel operators. More are being added!\n", + "Check [the JAX reference documentation](https://docs.jax.dev/en/latest/jax.lax.html#parallel-operators) for a complete list of the parallel operators. More are being added!\n", "\n", "Here's how to use `lax.ppermute` to implement a simple halo exchange for a [Rule 30](https://en.wikipedia.org/wiki/Rule_30) simulation:" ] diff --git a/cloud_tpu_colabs/README.md b/cloud_tpu_colabs/README.md index db3dc5f30814..6e5501584da0 100644 --- a/cloud_tpu_colabs/README.md +++ b/cloud_tpu_colabs/README.md @@ -4,7 +4,7 @@ The same JAX code that runs on CPU and GPU can also be run on TPU. Cloud TPUs have the advantage of quickly giving you access to multiple TPU accelerators, including in [Colab](https://research.google.com/colaboratory/). All of the example notebooks here use -[`jax.pmap`](https://jax.readthedocs.io/en/latest/jax.html#jax.pmap) to run JAX +[`jax.pmap`](https://docs.jax.dev/en/latest/jax.html#jax.pmap) to run JAX computation across multiple TPU cores from Colab. You can also run the same code directly on a [Cloud TPU VM](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm). diff --git a/conftest.py b/conftest.py index fed4564bbc1c..72b4b598891c 100644 --- a/conftest.py +++ b/conftest.py @@ -15,12 +15,16 @@ import os import pytest - +import json +import threading +import shutil +from datetime import datetime @pytest.fixture(autouse=True) def add_imports(doctest_namespace): import jax import numpy + doctest_namespace["jax"] = jax doctest_namespace["lax"] = jax.lax doctest_namespace["jnp"] = jax.numpy @@ -29,8 +33,8 @@ def add_imports(doctest_namespace): # A pytest hook that runs immediately before test collection (i.e. when pytest # loads all the test cases to run). When running parallel tests via xdist on -# Cloud TPU, we use this hook to set the env vars needed to run multiple test -# processes across different TPU chips. +# GPU or Cloud TPU, we use this hook to set the env vars needed to run multiple +# test processes across different chips. # # It's important that the hook runs before test collection, since jax tests end # up initializing the TPU runtime on import (e.g. to query supported test @@ -43,17 +47,203 @@ def add_imports(doctest_namespace): # https://docs.pytest.org/en/latest/how-to/writing_hook_functions.html#firstresult-stop-at-first-non-none-result # for details. # -# The env var JAX_ENABLE_TPU_XDIST must be set for this hook to have an +# For TPU, the env var JAX_ENABLE_TPU_XDIST must be set for this hook to have an # effect. We do this to minimize any effect on non-TPU tests, and as a pointer # in test code to this "magic" hook. TPU tests should not specify more xdist # workers than the number of TPU chips. +# +# For GPU, the env var JAX_ENABLE_CUDA_XDIST must be set equal to the number of +# CUDA devices. Test processes will be assigned in round robin fashion across +# the devices. def pytest_collection() -> None: - if not os.environ.get("JAX_ENABLE_TPU_XDIST", None): - return - # When running as an xdist worker, will be something like "gw0" - xdist_worker_name = os.environ.get("PYTEST_XDIST_WORKER", "") - if not xdist_worker_name.startswith("gw"): - return - xdist_worker_number = int(xdist_worker_name[len("gw"):]) - os.environ.setdefault("TPU_VISIBLE_CHIPS", str(xdist_worker_number)) - os.environ.setdefault("ALLOW_MULTIPLE_LIBTPU_LOAD", "true") + if os.environ.get("JAX_ENABLE_TPU_XDIST", None): + # When running as an xdist worker, will be something like "gw0" + xdist_worker_name = os.environ.get("PYTEST_XDIST_WORKER", "") + if not xdist_worker_name.startswith("gw"): + return + xdist_worker_number = int(xdist_worker_name[len("gw") :]) + os.environ.setdefault("TPU_VISIBLE_CHIPS", str(xdist_worker_number)) + os.environ.setdefault("ALLOW_MULTIPLE_LIBTPU_LOAD", "true") + + elif num_cuda_devices := os.environ.get("JAX_ENABLE_CUDA_XDIST", None): + num_cuda_devices = int(num_cuda_devices) + # When running as an xdist worker, will be something like "gw0" + xdist_worker_name = os.environ.get("PYTEST_XDIST_WORKER", "") + if not xdist_worker_name.startswith("gw"): + return + xdist_worker_number = int(xdist_worker_name[len("gw") :]) + os.environ.setdefault( + "CUDA_VISIBLE_DEVICES", str(xdist_worker_number % num_cuda_devices) + ) + +class ThreadSafeTestLogger: + """Thread-safe logging for parallel test execution and abort detection""" + def __init__(self): + self.locks = {} + self.global_lock = threading.Lock() + self.base_dir = os.path.abspath("./logs") + + # Create logs directory (archiving is handled by test runner scripts) + try: + os.makedirs(self.base_dir, exist_ok=True) + print(f"[TestLogger] Initialized log directory: {self.base_dir}") + except Exception as e: + print(f"[TestLogger] ERROR: Failed to create log directory {self.base_dir}: {e}") + # Fallback to temp directory if logs dir creation fails + import tempfile + self.base_dir = os.path.join(tempfile.gettempdir(), "jax_test_logs") + os.makedirs(self.base_dir, exist_ok=True) + print(f"[TestLogger] Using fallback directory: {self.base_dir}") + + def get_file_lock(self, test_file): + """Get or create a lock for a specific test file""" + with self.global_lock: + if test_file not in self.locks: + self.locks[test_file] = threading.Lock() + return self.locks[test_file] + + def get_test_file_name(self, session): + """Extract the test file name from the session""" + # Try to get from session config args + if hasattr(session, "config") and hasattr(session.config, "args"): + for arg in session.config.args: + # Handle full nodeid like "jax/tests/foo_test.py::TestClass::test_method" + if "tests/" in arg: + # Split on :: to get just the file path + file_path = arg.split("::")[0] + if file_path.endswith(".py"): + return os.path.basename(file_path).replace(".py", "") + + # Try to get from invocation params + if hasattr(session, "config") and hasattr(session.config, "invocation_params"): + invocation_dir = getattr(session.config.invocation_params, "dir", None) + if invocation_dir: + dir_name = os.path.basename(str(invocation_dir)) + if dir_name: + print(f"[TestLogger] Using invocation directory as test name: {dir_name}") + return dir_name + + # Last resort: try to get from session items + if hasattr(session, "items") and session.items: + first_item = session.items[0] + if hasattr(first_item, "fspath"): + fspath = str(first_item.fspath) + if ".py" in fspath: + return os.path.basename(fspath).replace(".py", "") + + print(f"[TestLogger] WARNING: Could not determine test file name, using 'unknown_test'") + print(f"[TestLogger] Session config args: {getattr(session.config, 'args', 'N/A')}") + return "unknown_test" + + def log_running_test(self, test_file, test_name, nodeid, start_time): + """Log the currently running test for abort detection""" + lock = self.get_file_lock(test_file) + with lock: + log_data = { + "test_file": test_file, + "test_name": test_name, + "nodeid": nodeid, + "start_time": start_time, + "status": "running", + "pid": os.getpid(), + "gpu_id": os.environ.get("HIP_VISIBLE_DEVICES", "unknown"), + } + + log_file = f"{self.base_dir}/{test_file}_last_running.json" + try: + # Ensure directory still exists (might have been deleted) + os.makedirs(self.base_dir, exist_ok=True) + with open(log_file, "w") as f: + json.dump(log_data, f, indent=2) + except Exception as e: + print(f"[TestLogger] ERROR: Failed to write running test log to {log_file}: {e}") + print(f"[TestLogger] Current working directory: {os.getcwd()}") + print(f"[TestLogger] Base directory: {self.base_dir}") + print(f"[TestLogger] Base directory exists: {os.path.exists(self.base_dir)}") + raise + + def clear_running_test(self, test_file): + """Clear the running test log when test completes successfully""" + lock = self.get_file_lock(test_file) + with lock: + log_file = f"{self.base_dir}/{test_file}_last_running.json" + if os.path.exists(log_file): + os.remove(log_file) + + +# Global logger instance +test_logger = ThreadSafeTestLogger() + + +@pytest.hookimpl(hookwrapper=True) +def pytest_runtest_protocol(item, nextitem): + """Hook that wraps around each test to track running tests for crash detection. + + This creates a "last_running" file before each test starts and deletes it + when the test completes successfully. If the test crashes, the file remains + and can be detected by the test runner. + """ + test_file = test_logger.get_test_file_name(item.session) + test_name = item.name + nodeid = item.nodeid + start_time = datetime.now().isoformat() + + # Log that this test is starting + try: + test_logger.log_running_test(test_file, test_name, nodeid, start_time) + except Exception as e: + print(f"[TestLogger] WARNING: Failed to log running test: {e}") + # Continue anyway - not critical for test execution + + test_completed = False + try: + outcome = yield + # Test completed (successfully or with normal failure) + test_completed = True + + # Clear the crash detection file + try: + test_logger.clear_running_test(test_file) + except Exception as e: + print(f"[TestLogger] WARNING: Failed to clear running test log: {e}") + + except Exception as e: + # Test raised exception (might be crash, might be normal exception) + print(f"[TestLogger] Test {test_name} exception: {e}") + if not test_completed: + # Don't clear the file - this might be a crash + print(f"[TestLogger] Leaving crash file for detection") + raise + + +@pytest.hookimpl(tryfirst=True) +def pytest_sessionstart(session): + """Called after the Session object has been created""" + gpu = os.environ.get('HIP_VISIBLE_DEVICES', '?') + print(f"Test session starting on GPU {gpu}") + + +@pytest.hookimpl(trylast=True) +def pytest_sessionfinish(session, exitstatus): + """Called after test run finished. + + If a crash file still exists, it means a test crashed and the runner + will detect it. We just report it here for visibility. + """ + test_file = test_logger.get_test_file_name(session) + log_file = f"{test_logger.base_dir}/{test_file}_last_running.json" + + if os.path.exists(log_file): + try: + with open(log_file, "r") as f: + abort_data = json.load(f) + print( + f"\n[CRASH DETECTED] {abort_data.get('nodeid', abort_data.get('test_name', 'unknown'))} " + f"(GPU: {abort_data.get('gpu_id', '?')}, PID: {abort_data.get('pid', '?')})" + ) + print(f"[CRASH DETECTED] Crash file will be processed by test runner") + except Exception as e: + print(f"[TestLogger] WARNING: Crash file exists but unreadable: {e}") + else: + # Normal completion - no crash + pass diff --git a/docs/README.md b/docs/README.md index 12e00425592f..54b8a67477b0 100644 --- a/docs/README.md +++ b/docs/README.md @@ -1,2 +1,2 @@ To rebuild the documentation, -see [Update Documentation](https://jax.readthedocs.io/en/latest/developer.html#update-documentation). +see [Update Documentation](https://docs.jax.dev/en/latest/developer.html#update-documentation). diff --git a/jax/_src/scipy/interpolate/__init__.py b/docs/__init__.py similarity index 94% rename from jax/_src/scipy/interpolate/__init__.py rename to docs/__init__.py index 37ee0c309fae..1337256a5074 100644 --- a/jax/_src/scipy/interpolate/__init__.py +++ b/docs/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 The JAX Authors. +# Copyright 2025 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/docs/_static/fault_tolerance/cancel_collectives.py b/docs/_static/fault_tolerance/cancel_collectives.py new file mode 100644 index 000000000000..42c75c015112 --- /dev/null +++ b/docs/_static/fault_tolerance/cancel_collectives.py @@ -0,0 +1,60 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +os.environ['XLA_FLAGS'] = ' '.join([ + '--xla_gpu_nccl_terminate_on_error=false', + '--xla_gpu_nccl_async_execution=true', + '--xla_gpu_nccl_blocking_communicators=false', +]) +os.environ['XLA_PYTHON_CLIENT_ABORT_COLLECTIVES_ON_FAILURE'] = '1' +os.environ['XLA_PYTHON_CLIENT_USE_TFRT_GPU_CLIENT'] = '1' + +from absl import app +from absl import flags +from collections.abc import Sequence +import jax +import jax.numpy as jnp +import time + +_PROCESS_ID = flags.DEFINE_integer("i", -1, "Process id") +_NUM_PROCESSES = flags.DEFINE_integer("n", -1, "Number of processes") + + +def main(_: Sequence[str]) -> None: + jax.config.update("jax_enable_recoverability", True) + jax.distributed.initialize( + coordinator_address="localhost:9000", + num_processes=_NUM_PROCESSES.value, + process_id=_PROCESS_ID.value, + local_device_ids=[_PROCESS_ID.value], + heartbeat_timeout_seconds=10, + ) + print(f'{jax.devices()=}') + print(f'{jax.local_devices()=}') + + # Don't do this. Use live_devices instead. + from jax.experimental.multihost_utils import _live_devices + _live_devices(jax._src.distributed.global_state.client, jax.devices()) + + n = jax.device_count() + jax.set_mesh(jax.make_mesh((n,), ("i",))) + x = jax.device_put(jnp.arange(n), jax.P("i")) + while True: + print(jnp.sum(x)) + time.sleep(1) + + +if __name__ == "__main__": + app.run(main) diff --git a/docs/_static/fault_tolerance/collectives.py b/docs/_static/fault_tolerance/collectives.py new file mode 100644 index 000000000000..0f120f47271f --- /dev/null +++ b/docs/_static/fault_tolerance/collectives.py @@ -0,0 +1,50 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +os.environ['XLA_FLAGS'] = '--xla_gpu_nccl_terminate_on_error=false' + +from absl import app +from absl import flags +from collections.abc import Sequence +import jax +import jax.numpy as jnp +import time + +_PROCESS_ID = flags.DEFINE_integer("i", -1, "Process id") +_NUM_PROCESSES = flags.DEFINE_integer("n", -1, "Number of processes") + + +def main(_: Sequence[str]) -> None: + jax.config.update("jax_enable_recoverability", True) + jax.distributed.initialize( + coordinator_address="localhost:9000", + num_processes=_NUM_PROCESSES.value, + process_id=_PROCESS_ID.value, + local_device_ids=[_PROCESS_ID.value], + heartbeat_timeout_seconds=10, + ) + print(f'{jax.devices()=}') + print(f'{jax.local_devices()=}') + + n = jax.device_count() + jax.set_mesh(jax.make_mesh((n,), ("i",))) + x = jax.device_put(jnp.arange(n), jax.P("i")) + while True: + print(jnp.sum(x)) + time.sleep(1) + + +if __name__ == "__main__": + app.run(main) diff --git a/docs/_static/fault_tolerance/data_parallelism.py b/docs/_static/fault_tolerance/data_parallelism.py new file mode 100644 index 000000000000..c70d52751ecb --- /dev/null +++ b/docs/_static/fault_tolerance/data_parallelism.py @@ -0,0 +1,129 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +os.environ['XLA_FLAGS'] = ' '.join([ + '--xla_gpu_nccl_terminate_on_error=false', + '--xla_gpu_nccl_async_execution=true', + '--xla_gpu_nccl_blocking_communicators=false', +]) +os.environ['XLA_PYTHON_CLIENT_ABORT_COLLECTIVES_ON_FAILURE'] = '1' +os.environ['XLA_PYTHON_CLIENT_USE_TFRT_GPU_CLIENT'] = '1' + +from absl import app +from absl import flags +from collections.abc import Sequence +from jax.experimental.multihost_utils import live_devices +import jax +import jax.numpy as jnp +import time + +_PROCESS_ID = flags.DEFINE_integer("i", -1, "Process id") +_NUM_PROCESSES = flags.DEFINE_integer("n", -1, "Number of processes") + +def replicated(x: jax.Array, devices: list[jax.Device]): + """Return x replicated across the provided devices. + + Note that replicated(x) doesn't actually move any data. It simply creates a + logically replicated array with x as the local replica. + """ + n = len(devices) + mesh = jax.make_mesh((n, ), ("i", ), devices=devices) + spec = jax.sharding.PartitionSpec(None) + sharding = jax.sharding.NamedSharding(mesh, spec) + shards = [ + jax.device_put(x.addressable_shards[0].data, d) for d in devices + if d.process_index == jax.process_index() + ] + return jax.make_array_from_single_device_arrays(x.shape, sharding, shards) + + +def sharded(x: jax.Array, devices: list[jax.Device]): + """Return x sharded across the provided devices. + + Note that sharded(x) doesn't actually move any data. It simply creates a + logically sharded array. x should have the same shape as the global array. + """ + n = len(devices) + mesh = jax.make_mesh((n, ), ("i", ), devices=devices) + spec = jax.sharding.PartitionSpec("i") + sharding = jax.sharding.NamedSharding(mesh, spec) + m = sharding.addressable_devices_indices_map(x.shape) + shards = [jax.device_put(x[m[d]], d) for d in jax.local_devices()] + return jax.make_array_from_single_device_arrays(x.shape, sharding, shards) + + +def main(_: Sequence[str]) -> None: + # Parse command line arguments and initialize multi-controller JAX. + jax.config.update("jax_enable_recoverability", True) + jax.distributed.initialize(coordinator_address="localhost:8000", + process_id=_PROCESS_ID.value, + num_processes=_NUM_PROCESSES.value, + local_device_ids=[_PROCESS_ID.value], + heartbeat_timeout_seconds=10) + print(f'{jax.devices()=}') + print(f'{jax.local_devices()=}') + + # Initialize the model's weights. + keys = iter(jax.random.split(jax.random.key(seed=42), num=3)) + weights = jax.random.normal(next(keys), shape=(1, )) + + # We'll learn a trivial linear model: a*x. + def predict(weights, X): + return weights * X + + # We'll use mean squared error loss. + def loss(weights, X, Y): + return jnp.mean((predict(weights, X) - Y)**2) + + # Initialize the (noisy) training data with a=10. + X = jax.random.permutation(next(keys), jnp.arange(-300., 300.)) + Y = 10 * X + jax.random.normal(next(keys), X.shape) + + # Hyperparameters. + loss_and_grad = jax.jit(jax.value_and_grad(loss)) + learning_rate = 1e-6 + device_batch_size = 10 + + step = 0 + while True: + try: + with live_devices(jax.devices()) as devices: + print(f'=== Running step {step} with live devices = {devices} ===') + + # Replicate the model weights. + weights = replicated(weights, devices) + + # Shard the batch. + batch_size = device_batch_size * len(devices) + start = (step * batch_size) % len(X) + stop = start + batch_size + X_batch = sharded(X[start:stop], devices) + Y_batch = sharded(Y[start:stop], devices) + + # Compute gradients and update weights. + l, grad = loss_and_grad(weights, X_batch, Y_batch) + new_weights = jax.block_until_ready(weights - learning_rate * grad) + except Exception as e: + print(f'Step {step} failed: {e}') + else: + print(f'Step {step} succeeded: loss = {l}') + step += 1 + weights = new_weights + + time.sleep(1) + + +if __name__ == "__main__": + app.run(main) diff --git a/docs/_static/fault_tolerance/data_parallelism_with_recovery.py b/docs/_static/fault_tolerance/data_parallelism_with_recovery.py new file mode 100644 index 000000000000..b97461a20773 --- /dev/null +++ b/docs/_static/fault_tolerance/data_parallelism_with_recovery.py @@ -0,0 +1,182 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +os.environ['XLA_FLAGS'] = ' '.join([ + '--xla_gpu_nccl_terminate_on_error=false', + '--xla_gpu_nccl_async_execution=true', + '--xla_gpu_nccl_blocking_communicators=false', +]) +os.environ['XLA_PYTHON_CLIENT_ABORT_COLLECTIVES_ON_FAILURE'] = '1' +os.environ['XLA_PYTHON_CLIENT_USE_TFRT_GPU_CLIENT'] = '1' + +from absl import app +from absl import flags +from collections.abc import Sequence +from jax.experimental.multihost_utils import live_devices +from jax.experimental import shard_map +import jax +import jax.numpy as jnp +import numpy as np +import time + +_PROCESS_ID = flags.DEFINE_integer("i", -1, "Process id") +_NUM_PROCESSES = flags.DEFINE_integer("n", -1, "Number of processes") + +def replicated(x: jax.Array, devices: list[jax.Device]): + """Return x replicated across the provided devices. + + Note that replicated(x) doesn't actually move any data. It simply creates a + logically replicated array with x as the local replica. + """ + n = len(devices) + mesh = jax.make_mesh((n, ), ("i", ), devices=devices) + spec = jax.sharding.PartitionSpec(None) + sharding = jax.sharding.NamedSharding(mesh, spec) + shards = [ + jax.device_put(x.addressable_shards[0].data, d) for d in devices + if d.process_index == jax.process_index() + ] + return jax.make_array_from_single_device_arrays(x.shape, sharding, shards) + + +def sharded(x: jax.Array, devices: list[jax.Device]): + """Return x sharded across the provided devices. + + Note that sharded(x) doesn't actually move any data. It simply creates a + logically sharded array. x should have the same shape as the global array. + """ + n = len(devices) + mesh = jax.make_mesh((n, ), ("i", ), devices=devices) + spec = jax.sharding.PartitionSpec("i") + sharding = jax.sharding.NamedSharding(mesh, spec) + m = sharding.addressable_devices_indices_map(x.shape) + shards = [jax.device_put(x[m[d]], d) for d in jax.local_devices()] + return jax.make_array_from_single_device_arrays(x.shape, sharding, shards) + + +def send(x: jax.Array, from_device: jax.Device, to_device: jax.Device): + """Sends x from one device to another.""" + assert isinstance(x, jax.Array) + devices = [from_device, to_device] + psum = lambda x: jax.lax.psum(x, "i") + mesh = jax.make_mesh((2, ), ("i", ), devices=devices) + spec = jax.sharding.PartitionSpec(None) + x = replicated(x, [from_device, to_device]) + shard_map.shard_map(psum, mesh=mesh, in_specs=spec, out_specs=spec)(x) + + +def recv(x: jax.Array, from_device: jax.Device, to_device: jax.Device): + """Receives x from a matching send.""" + assert isinstance(x, jax.Array) + to_device = jax.local_devices()[0] + devices = [from_device, to_device] + psum = lambda x: jax.lax.psum(x, "i") + mesh = jax.make_mesh((2, ), ("i", ), devices=devices) + spec = jax.sharding.PartitionSpec(None) + x = jnp.zeros_like(x) + x = replicated(x, [from_device, to_device]) + return shard_map.shard_map(psum, mesh=mesh, in_specs=spec, out_specs=spec)(x) + + +def allgather(x: float, devices: list[jax.Device]) -> list[float]: + """Performs an AllGather across the provided devices.""" + n = len(devices) + mesh = jax.make_mesh((n, ), ("i", ), devices=devices) + spec = jax.sharding.PartitionSpec('i') + p = lambda x: jax.lax.all_gather(x, "i", tiled=True) + f = jax.shard_map(p, mesh=mesh, in_specs=spec, out_specs=spec) + return jax.block_until_ready(f(np.array([x] * len(devices)))).addressable_shards[0].data + + +def main(_: Sequence[str]) -> None: + # Parse command line arguments and initialize multi-controller JAX. + jax.config.update("jax_enable_recoverability", True) + jax.distributed.initialize(coordinator_address="localhost:8000", + process_id=_PROCESS_ID.value, + num_processes=_NUM_PROCESSES.value, + local_device_ids=[_PROCESS_ID.value], + heartbeat_timeout_seconds=10) + print(f'{jax.devices()=}') + print(f'{jax.local_devices()=}') + + # Initialize the model's weights. + keys = iter(jax.random.split(jax.random.key(seed=42), num=3)) + weights = jax.random.normal(next(keys), shape=(1, )) + + # We'll learn a trivial linear model: a*x. + def predict(weights, X): + return weights * X + + # We'll use mean squared error loss. + def loss(weights, X, Y): + return jnp.mean((predict(weights, X) - Y)**2) + + # Initialize the (noisy) training data with a=10. + X = jax.random.permutation(next(keys), jnp.arange(-300., 300.)) + Y = 10 * X + jax.random.normal(next(keys), X.shape) + + # Hyperparameters. + loss_and_grad = jax.jit(jax.value_and_grad(loss)) + learning_rate = 1e-6 + device_batch_size = 10 + + step = 0 + while True: + try: + with live_devices(jax.devices()) as devices: + print(f'=== Running step {step} with live devices = {devices} ===') + + # Handle recovering devices. A device is recovering if its step doesn't + # match process 0's step. We assume process 0 never fails. + print('all gathering steps...') + steps = allgather(step, devices) + print(f'{steps=}') + recovering = [d for d, s in zip(devices, steps) if s != steps[0]] + for d in recovering: + # Process 0 sends weights and step to the recovering devices. + if jax.process_index() == 0: + print('sending...') + send(weights, jax.devices()[0], d) + send(jnp.array([step]), jax.devices()[0], d) + elif d.process_index == jax.process_index(): + print('receiving...') + weights = recv(weights, jax.devices()[0], d) + step = recv(jnp.array([step]), jax.devices()[0], d)[0] + + # Replicate the model weights. + weights = replicated(weights, devices) + + # Shard the batch. + batch_size = device_batch_size * len(devices) + start = (step * batch_size) % len(X) + stop = start + batch_size + X_batch = sharded(X[start:stop], devices) + Y_batch = sharded(Y[start:stop], devices) + + # Compute gradients and update weights. + l, grad = loss_and_grad(weights, X_batch, Y_batch) + new_weights = jax.block_until_ready(weights - learning_rate * grad) + except Exception as e: + print(f'Step {step} failed: {e}') + else: + print(f'Step {step} succeeded: loss = {l}') + step += 1 + weights = new_weights + + time.sleep(1) + + +if __name__ == "__main__": + app.run(main) diff --git a/docs/_static/fault_tolerance/dont_fail.py b/docs/_static/fault_tolerance/dont_fail.py new file mode 100644 index 000000000000..a44514d65c71 --- /dev/null +++ b/docs/_static/fault_tolerance/dont_fail.py @@ -0,0 +1,45 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +os.environ['XLA_FLAGS'] = '--xla_gpu_nccl_terminate_on_error=false' + +from absl import app +from absl import flags +from collections.abc import Sequence +import jax +import time + +_PROCESS_ID = flags.DEFINE_integer("i", -1, "Process id") +_NUM_PROCESSES = flags.DEFINE_integer("n", -1, "Number of processes") + + +def main(_: Sequence[str]) -> None: + jax.config.update("jax_enable_recoverability", True) + jax.distributed.initialize( + coordinator_address="localhost:9000", + num_processes=_NUM_PROCESSES.value, + process_id=_PROCESS_ID.value, + local_device_ids=[_PROCESS_ID.value], + heartbeat_timeout_seconds=10, + ) + print(f'{jax.devices()=}') + print(f'{jax.local_devices()=}') + while True: + print(time.time()) + time.sleep(1) + + +if __name__ == "__main__": + app.run(main) diff --git a/docs/_static/fault_tolerance/fault_tolerance.css b/docs/_static/fault_tolerance/fault_tolerance.css new file mode 100644 index 000000000000..86f26e5e842a --- /dev/null +++ b/docs/_static/fault_tolerance/fault_tolerance.css @@ -0,0 +1,283 @@ +.cluster { + margin: 1em; + font-size: smaller; + position: relative; + height: 20em; +} + +.server-box { + position: absolute; + top: 0%; + left: 0%; + width: 100%; + display: flex; + align-items: flex-start; +} + +.proc-box { + display: flex; + flex-direction: column; + max-width: 33%; + transform: translate(0%, -100%); +} + +.p0-box { + position: absolute; + top: 100%; + left: 0%; +} + +.p1-box { + position: absolute; + top: 100%; + left: 35%; +} + +.p2-box { + position: absolute; + top: 100%; + left: 70%; +} + +.proc-box div { + margin: 1pt; +} + +.proc-box button { + margin: 1pt; +} + +.server { + padding: 0.5em; + border: 2pt solid black; + background-color: #FEEFC3; + display: flex; + justify-content: center; + align-items: center; + text-align: center; + font-weight: bold; + z-index: 99; +} + +.proc { + border: 2pt solid black; + border-radius: 50%; + width: 2rem; + height: 2rem; + display: flex; + align-items: center; + justify-content: center; + font-weight: bold; + z-index: 99; +} + +.p0 { + background-color: #D2E3FC; +} + +.p1 { + background-color: #FAD2CF; +} + +.p2 { + background-color: #CEEAD6; +} + +.alive { + color: green; +} + +.dead { + color: red; +} + +.failed { + background-color: gray; +} + +.msg { + position: absolute; + animation-timing-function: linear; + font-size: large; + z-index: -1; +} + +.p0_to_pserver { + animation-duration: 1s; + animation-name: p0_to_pserver_keyframes; +} + +.p1_to_pserver { + animation-duration: 1.5s; + animation-name: p1_to_pserver_keyframes; +} + +.p2_to_pserver { + animation-duration: 2.0s; + animation-name: p2_to_pserver_keyframes; +} + +.pserver_to_p0 { + animation-duration: 1s; + animation-name: pserver_to_p0_keyframes; +} + +.pserver_to_p1 { + animation-duration: 1.5s; + animation-name: pserver_to_p1_keyframes; +} + +.pserver_to_p2 { + animation-duration: 2.0s; + animation-name: pserver_to_p2_keyframes; +} + +@keyframes p0_to_pserver_keyframes { + from { top: 75%; left: 1%; } + to { top: 0%; left: 1%; } +} + +@keyframes p1_to_pserver_keyframes { + from { top: 75%; left: 36%; } + to { top: 0%; left: 1%; } +} + +@keyframes p2_to_pserver_keyframes { + from { top: 75%; left: 71%; } + to { top: 0%; left: 1%; } +} + +@keyframes pserver_to_p0_keyframes { + from { top: 0%; left: 1%; } + to { top: 75%; left: 1%; } +} + +@keyframes pserver_to_p1_keyframes { + from { top: 0%; left: 1%; } + to { top: 75%; left: 36%; } +} + +@keyframes pserver_to_p2_keyframes { + from { top: 0%; left: 1%; } + to { top: 75%; left: 71%; } +} + +.p0_to_pserver_tall { + animation-duration: 1s; + animation-name: p0_to_pserver_keyframes_tall; +} + +.p1_to_pserver_tall { + animation-duration: 1.5s; + animation-name: p1_to_pserver_keyframes_tall; +} + +.p2_to_pserver_tall { + animation-duration: 2.0s; + animation-name: p2_to_pserver_keyframes_tall; +} + +.pserver_to_p0_tall { + animation-duration: 1s; + animation-name: pserver_to_p0_keyframes_tall; +} + +.pserver_to_p1_tall { + animation-duration: 1.5s; + animation-name: pserver_to_p1_keyframes_tall; +} + +.pserver_to_p2_tall { + animation-duration: 2.0s; + animation-name: pserver_to_p2_keyframes_tall; +} + +@keyframes p0_to_pserver_keyframes_tall { + from { top: 55%; left: 1%; } + to { top: 0%; left: 1%; } +} + +@keyframes p1_to_pserver_keyframes_tall { + from { top: 55%; left: 36%; } + to { top: 0%; left: 1%; } +} + +@keyframes p2_to_pserver_keyframes_tall { + from { top: 55%; left: 71%; } + to { top: 0%; left: 1%; } +} + +@keyframes pserver_to_p0_keyframes_tall { + from { top: 0%; left: 1%; } + to { top: 55%; left: 1%; } +} + +@keyframes pserver_to_p1_keyframes_tall { + from { top: 0%; left: 1%; } + to { top: 55%; left: 36%; } +} + +@keyframes pserver_to_p2_keyframes_tall { + from { top: 0%; left: 1%; } + to { top: 55%; left: 71%; } +} + + +.svgbox { + margin-bottom: 1.15em; +} + +.svgbox svg { + margin-left: auto; + margin-right: auto; + display: block; + width: 100%; + height: 100%; +} + +.svgbox svg .proc { + font-family: monospace; + dominant-baseline: middle; + text-anchor: middle; +} + +.svgbox svg .proc-axis { + stroke: gray; + stroke-width: 0.5; + stroke-linecap: round; +} + +.svgbox svg .event { + dominant-baseline: middle; + text-anchor: middle; + stroke: black; +} + +.svgbox svg .p0-color { + stroke: #D2E3FC; +} + +.svgbox svg .p1-color { + stroke: #FAD2CF; +} + +.svgbox svg .p2-color { + stroke: #CEEAD7; +} + +.svgbox svg .rpc { + stroke-width: 12; + stroke-linecap: round; +} + +.svgbox svg .reply { + font-family: monospace; + font-size: smaller; + dominant-baseline: middle; + text-anchor: middle; +} + +.svgbox svg .snapshot { + stroke-width: 2; + stroke: red; +} diff --git a/docs/_static/fault_tolerance/fault_tolerance.js b/docs/_static/fault_tolerance/fault_tolerance.js new file mode 100644 index 000000000000..25a10b23c67e --- /dev/null +++ b/docs/_static/fault_tolerance/fault_tolerance.js @@ -0,0 +1,377 @@ +// Helpers ///////////////////////////////////////////////////////////////////// + +// Returns a random float between min and max. +function rand(min, max) { + return Math.random() * (max - min) + min; +} + +// Formats the provided time as hh:mm:ss. +function formatTime(date) { + // https://stackoverflow.com/a/25279399 + return date.toISOString().substring(11, 19); +} + +// Periodically runs f with a delay between min_delay and max_delay. +// setIntervalWithJitter returns a cancel function that, when called, cancels +// the interval. +function setIntervalWithJitter(f, min_delay, max_delay) { + let handle = null; + + f(); + const helper = () => { + const g = () => { + f(); + helper(); + }; + handle = setTimeout(g, rand(min_delay, max_delay)); + return () => { + clearTimeout(handle); + }; + }; + + return helper(); +} + +// Coordination Service //////////////////////////////////////////////////////// + +class CoordinationService { + constructor(network, options) { + const now = new Date(); + this.network = network; + this.options = options; + this.heartbeats = [now, now, now]; + this.alive = [true, true, true]; + this.in_barrier = []; + + // Periodically refresh state. + setInterval(() => this.refresh(), 100); + } + + receive(msg) { + const {src, dst, type, payload} = msg; + switch (type) { + case 'heartbeat': + this.heartbeats[src] = new Date(); + return []; + case 'live_devices': + if (this.options.barrier) { + if (!this.in_barrier.includes(src)) { + this.in_barrier.push(src); + this.refresh_live_devices(); + } + } else { + this.network.push({ + src: 'server', + dst: msg.src, + type: 'live_devices', + payload: this.live_devices(), + }) + } + break; + default: + console.log(`Unknown message type ${type}`) + } + } + + time_since_heartbeat(i) { + return (new Date() - this.heartbeats[i]) / 1000; + } + + detect_failures() { + let something_failed = false; + for (let i = 0; i < 3; ++i) { + if (this.time_since_heartbeat(i) > 6) { + if (this.alive[i]) { + something_failed = true; + } + this.alive[i] = false; + } else { + this.alive[i] = true; + } + } + + if (something_failed && this.options.share_fate) { + for (let i = 0; i < 3; ++i) { + if (this.alive[i]) { + this.network.push({ + src: 'server', + dst: i, + type: 'fail', + payload: '💀', + }) + } + } + } + } + + live_devices() { + let devices = []; + for (let i = 0; i < 3; ++i) { + if (this.alive[i]) { + devices.push(i); + } + } + return devices; + } + + refresh_live_devices() { + // Check dst see if the live_devices barrier is done. + for (let i = 0; i < 3; ++i) { + if (this.alive[i] && !this.in_barrier.includes(i)) { + // The barrier isn't done. + return; + } + } + + // The barrier is done! Send the set of live devices dst all live devices. + let live = this.live_devices(); + for (let i of live) { + this.network.push({ + src: 'server', + dst: i, + type: 'live_devices', + payload: live, + }) + } + this.in_barrier = []; + } + + refresh() { + this.detect_failures(); + this.refresh_live_devices(); + } + + update_html(container) { + for (let i = 0; i < 3; ++i) { + // Update time since last heartbeat. + const now = new Date(); + const time_since = + container.getElementsByClassName(`p${i}-time-since-heartbeat`)[0]; + time_since.textContent = + ((now - this.heartbeats[i]) / 1000).toFixed(1) + ' s'; + + // Update health. + const health = container.getElementsByClassName(`p${i}-health`)[0]; + if (this.alive[i]) { + health.textContent = 'alive'; + health.classList.add('alive'); + time_since.classList.add('alive'); + health.classList.remove('dead'); + time_since.classList.remove('dead'); + } else { + health.textContent = 'dead'; + health.classList.add('dead'); + time_since.classList.add('dead'); + health.classList.remove('alive'); + time_since.classList.remove('alive'); + } + + } + + // Update processes in barrier. + const in_barrier = container.getElementsByClassName('in-barrier')[0]; + if (in_barrier) { + in_barrier.textContent = `In barrier = [${this.in_barrier}]`; + } + } +} + +// Process + +class Process { + constructor(network, options, i) { + this.network = network; + this.options = options; + this.i = i; + this.alive = true; + this.live_devices = null; + this.heartbeat_cancel = + setIntervalWithJitter(() => this.send_heartbeat(), 3000, 4000); + } + + receive(msg) { + const {src, dst, type, payload} = msg; + switch (type) { + case 'live_devices': + if (this.alive) { + this.live_devices = payload; + } + break; + case 'fail': + this.fail(); + break; + default: + console.log(`Unknown message type ${type}`) + } + } + + send_heartbeat() { + this.network.push({ + src: this.i, + dst: 'server', + type: 'heartbeat', + payload: '❤️', + }) + } + + send_live_devices() { + this.network.push({ + src: this.i, + dst: 'server', + type: 'live_devices', + payload: '⚫', + }) + } + + fail() { + this.alive = false; + this.live_devices = null; + this.heartbeat_cancel(); + } + + update_html(container) { + const live_devices = + container.getElementsByClassName(`p${this.i}-live-devices`)[0]; + if (this.options.live_devices) { + if (this.live_devices == null) { + live_devices.textContent = 'live processes = 0,1,2'; + } else { + live_devices.textContent = `live processes = ${this.live_devices}`; + } + } + + if (!this.alive) { + const node = container.getElementsByClassName(`p${this.i}`)[0]; + node.classList.add('failed'); + + const ld_button = + container.getElementsByClassName(`p${this.i}-ld-button`)[0]; + if (ld_button) { + ld_button.disabled = true; + } + + const fail_button = + container.getElementsByClassName(`p${this.i}-fail-button`)[0]; + if (fail_button) { + fail_button.disabled = true; + } + } + } +} + + +// Network communication. + +function send(container, tall, text, src, dst, after) { + const msg = document.createElement('div'); + msg.textContent = text; + msg.classList.add('msg'); + if (tall) { + msg.classList.add(`${src}_to_${dst}_tall`); + } else { + msg.classList.add(`${src}_to_${dst}`); + } + msg.addEventListener('animationend', (_) => { + msg.remove(); + after(); + }); + container.appendChild(msg); +} + +// { +// share_fate: false, +// live_devices: false, +// barrier: false, +// } +function init_cluster(id, options) { + const container = document.getElementById(id); + container.innerHTML = ` +
+
Coordination Service
+
+
    +
  • Process 0: 0s (alive)
  • +
  • Process 1: 0s (alive)
  • +
  • Process 2: 0s (alive)
  • +
  • In barrier: []
  • +
+
+
+ +
+
0
+
live processes = 0,1,2
+ + +
+ +
+
1
+
live processes = 0,1,2
+ + +
+ +
+
2
+
live processes = 0,1,2
+ + +
+ `; + + // Create the cluster. + let network = []; + let server = new CoordinationService(network, options); + const processes = [ + new Process(network, options, 0), new Process(network, options, 1), + new Process(network, options, 2) + ]; + + // Set up the live_devices button. + for (let i = 0; i < 3; ++i) { + const button = container.getElementsByClassName(`p${i}-ld-button`)[0]; + if (options.live_devices) { + button.addEventListener('click', () => processes[i].send_live_devices()); + } else { + button.remove(); + } + } + + // Set up the fail button. + const button = container.querySelectorAll('.p2-fail-button')[0]; + button.addEventListener('click', () => processes[2].fail()); + + // Remove live_devices display if needed. + if (!options.live_devices) { + for (let i = 0; i < 3; ++i) { + container.getElementsByClassName(`p${i}-live-devices`)[0].remove(); + } + } + if (!options.barrier) { + container.getElementsByClassName('in-barrier')[0].remove(); + } + + // Periodically process network messages. + setInterval(() => { + while (network.length > 0) { + const msg = network.shift(); + const tall = options.live_devices; + send(container, tall, msg.payload, `p${msg.src}`, `p${msg.dst}`, () => { + if (msg.dst == 'server') { + server.receive(msg); + } else { + processes[msg.dst].receive(msg); + } + }); + } + }, 10) + + // Periodically update HTML. + setInterval(() => { + server.update_html(container); + for (let proc of processes) { + proc.update_html(container); + } + }, 50); +} diff --git a/docs/_static/fault_tolerance/live_devices.py b/docs/_static/fault_tolerance/live_devices.py new file mode 100644 index 000000000000..9f41a2bdac6a --- /dev/null +++ b/docs/_static/fault_tolerance/live_devices.py @@ -0,0 +1,64 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +os.environ['XLA_FLAGS'] = ' '.join([ + '--xla_gpu_nccl_terminate_on_error=false', + '--xla_gpu_nccl_async_execution=true', + '--xla_gpu_nccl_blocking_communicators=false', +]) +os.environ['XLA_PYTHON_CLIENT_ABORT_COLLECTIVES_ON_FAILURE'] = '1' +os.environ['XLA_PYTHON_CLIENT_USE_TFRT_GPU_CLIENT'] = '1' + +from absl import app +from absl import flags +from collections.abc import Sequence +from jax.experimental.multihost_utils import live_devices +import jax +import jax.numpy as jnp +import time + +_PROCESS_ID = flags.DEFINE_integer("i", -1, "Process id") +_NUM_PROCESSES = flags.DEFINE_integer("n", -1, "Number of processes") + + +def main(_: Sequence[str]) -> None: + jax.config.update("jax_enable_recoverability", True) + jax.distributed.initialize( + coordinator_address="localhost:9000", + num_processes=_NUM_PROCESSES.value, + process_id=_PROCESS_ID.value, + local_device_ids=[_PROCESS_ID.value], + heartbeat_timeout_seconds=10, + ) + print(f'{jax.devices()=}') + print(f'{jax.local_devices()=}') + + while True: + try: + with live_devices(jax.devices()) as devices: + print(f'{devices=}') + n = len(devices) + jax.set_mesh(jax.make_mesh((n,), ("i",), devices=devices)) + x = jax.device_put(jnp.arange(n), jax.P("i")) + print(jnp.sum(x)) + except Exception as e: + print('FAIL:', e) + else: + print('PASS') + time.sleep(1) + + +if __name__ == "__main__": + app.run(main) diff --git a/docs/_static/fault_tolerance/while_loop.py b/docs/_static/fault_tolerance/while_loop.py new file mode 100644 index 000000000000..0dbac58b528d --- /dev/null +++ b/docs/_static/fault_tolerance/while_loop.py @@ -0,0 +1,41 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl import app +from absl import flags +from collections.abc import Sequence +import jax +import time + +_PROCESS_ID = flags.DEFINE_integer("i", -1, "Process id") +_NUM_PROCESSES = flags.DEFINE_integer("n", -1, "Number of processes") + + +def main(_: Sequence[str]) -> None: + jax.distributed.initialize( + coordinator_address="localhost:9000", + num_processes=_NUM_PROCESSES.value, + process_id=_PROCESS_ID.value, + local_device_ids=[_PROCESS_ID.value], + heartbeat_timeout_seconds=10, + ) + print(f'{jax.devices()=}') + print(f'{jax.local_devices()=}') + while True: + print(time.time()) + time.sleep(1) + + +if __name__ == "__main__": + app.run(main) diff --git a/docs/_static/multi_process/controller_and_local_devices.png b/docs/_static/multi_process/controller_and_local_devices.png new file mode 100644 index 000000000000..ad74cad65417 Binary files /dev/null and b/docs/_static/multi_process/controller_and_local_devices.png differ diff --git a/docs/_static/multi_process/mcjax_overview.png b/docs/_static/multi_process/mcjax_overview.png new file mode 100644 index 000000000000..dae947ff9df7 Binary files /dev/null and b/docs/_static/multi_process/mcjax_overview.png differ diff --git a/docs/_static/pallas/distributed/collective_all_gather_operands.svg b/docs/_static/pallas/distributed/collective_all_gather_operands.svg new file mode 100644 index 000000000000..a47610d754f3 --- /dev/null +++ b/docs/_static/pallas/distributed/collective_all_gather_operands.svg @@ -0,0 +1,25 @@ + + + + Activations + + + + 0 + 1 + + + Weights + + + + 0 + 1 + + diff --git a/docs/_static/pallas/gpu/collective_mma.svg b/docs/_static/pallas/gpu/collective_mma.svg new file mode 100644 index 000000000000..cf14d7343e2f --- /dev/null +++ b/docs/_static/pallas/gpu/collective_mma.svg @@ -0,0 +1,86 @@ + + + + + + + + + + + + Operand A + + + A₀ + (M/2, K) + + + A₁ + (M/2, K) + + + + + Accumulator D + + + D₀ + (M/2, N) + + + D₁ + (M/2, N) + + + + + Operand B (Shared) + + + B₀ + (K, N/2) + + + B₁ + (K, N/2) + + + + + + + + + + + + + + + + + + + + + + + + + + + SM₀ computes: D₀ = A₀ @ [B₀ B₁] + + + + + SM₁ computes: D₁ = A₁ @ [B₀ B₁] + + + diff --git a/docs/_static/pallas/gpu/grid_tiling_off.svg b/docs/_static/pallas/gpu/grid_tiling_off.svg new file mode 100644 index 000000000000..b11d85759ce4 --- /dev/null +++ b/docs/_static/pallas/gpu/grid_tiling_off.svg @@ -0,0 +1,175 @@ + + + + + A (6x16 tiles) + B (16x16 tiles) + C = A @ B (6x16 tiles) + + + + + + + + diff --git a/docs/_static/pallas/gpu/grid_tiling_on.svg b/docs/_static/pallas/gpu/grid_tiling_on.svg new file mode 100644 index 000000000000..9d24a8187179 --- /dev/null +++ b/docs/_static/pallas/gpu/grid_tiling_on.svg @@ -0,0 +1,183 @@ + + + + + A (6x16 tiles) + B (16x16 tiles) + C = A @ B (6x16 tiles) + + + + + + + + diff --git a/docs/_static/pallas/gpu/memory_spaces.svg b/docs/_static/pallas/gpu/memory_spaces.svg new file mode 100644 index 000000000000..73dc31a12406 --- /dev/null +++ b/docs/_static/pallas/gpu/memory_spaces.svg @@ -0,0 +1,96 @@ + + + + + + Faster / Smaller Capacity + + + Slower / Larger Capacity + + + + + + Registers (RMEM) + Fastest Latency & BW + Smallest Capacity + + Holds arrays (in Pallas). + Spills if full! + + + + + Tensor Memory (TMEM) + Fastest Latency & BW + Smallest Capacity + + Explicitly managed. + Blackwell specific. + + + + + + Shared Memory (SMEM) + Fast (close to compute) + Small Capacity (per SM) + Partitioned into private slices for each CUDA block/cluster. + + + + L2 Cache + Moderate Speed + Moderate Capacity (~100MBs) + Shared betwen SMs, not directly programmable. + + + + Global Memory (GMEM) + Slowest Latency & Bandwidth + Largest Capacity (GBs) + Main GPU memory (HBM/GDDR technology). + + + + + diff --git a/docs/_static/pallas/gpu/nvidia_sm.svg b/docs/_static/pallas/gpu/nvidia_sm.svg new file mode 100644 index 000000000000..76b4edb2afad --- /dev/null +++ b/docs/_static/pallas/gpu/nvidia_sm.svg @@ -0,0 +1,99 @@ + + + + + Streaming Multiprocessor + + + + + + Warp Scheduler + + TensorCore + + ALU + (Float/Int) + + Load/Store + + + Special + Functions + + + + + + + Warp Scheduler + + TensorCore + + ALU + (Float/Int) + + Load/Store + + + Special + Functions + + + + + + + Warp Scheduler + + TensorCore + + ALU + (Float/Int) + + Load/Store + + + Special + Functions + + + + + + + Warp Scheduler + + TensorCore + + ALU + (Float/Int) + + Load/Store + + + Special + Functions + + + + + + Shared Memory / L1 Cache + + + diff --git a/docs/_static/pallas/gpu/pipeline_matmul.svg b/docs/_static/pallas/gpu/pipeline_matmul.svg new file mode 100644 index 000000000000..7037695e33e9 --- /dev/null +++ b/docs/_static/pallas/gpu/pipeline_matmul.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/gpu/pipeline_matmul_ws.svg b/docs/_static/pallas/gpu/pipeline_matmul_ws.svg new file mode 100644 index 000000000000..3a07ba7e9ece --- /dev/null +++ b/docs/_static/pallas/gpu/pipeline_matmul_ws.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/gpu/warp_specialization.svg b/docs/_static/pallas/gpu/warp_specialization.svg new file mode 100644 index 000000000000..85fbce49fa0b --- /dev/null +++ b/docs/_static/pallas/gpu/warp_specialization.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/pipelining_bandwidth_bound.svg b/docs/_static/pallas/pipelining_bandwidth_bound.svg new file mode 100644 index 000000000000..45b78a7ce35e --- /dev/null +++ b/docs/_static/pallas/pipelining_bandwidth_bound.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/pipelining_compute_bound.svg b/docs/_static/pallas/pipelining_compute_bound.svg new file mode 100644 index 000000000000..cb3b58eaef99 --- /dev/null +++ b/docs/_static/pallas/pipelining_compute_bound.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/pipelining_example.svg b/docs/_static/pallas/pipelining_example.svg new file mode 100644 index 000000000000..59ca5b433b11 --- /dev/null +++ b/docs/_static/pallas/pipelining_example.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/pipelining_latency_multistage.svg b/docs/_static/pallas/pipelining_latency_multistage.svg new file mode 100644 index 000000000000..2c40f1692b9a --- /dev/null +++ b/docs/_static/pallas/pipelining_latency_multistage.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/pipelining_mem_hierarchy.svg b/docs/_static/pallas/pipelining_mem_hierarchy.svg new file mode 100644 index 000000000000..d7a2e6cbabd8 --- /dev/null +++ b/docs/_static/pallas/pipelining_mem_hierarchy.svg @@ -0,0 +1,30 @@ + + + + + + + + + + + + Registers + SRAM/Caches + DRAM/HBM + Network + + Fastest + Fast + Slow + Slowest + + Lowest Capacity + Low Capacity + High Capacity + Highest Capacity + + diff --git a/docs/_static/style.css b/docs/_static/style.css index 51ab72d7153e..fb0cdf3d571f 100644 --- a/docs/_static/style.css +++ b/docs/_static/style.css @@ -81,7 +81,7 @@ body:has(.hero) .bd-container { } .getting-started, -.user-guides, +.jax-101, .installation { background: #3C4043; color: white; @@ -91,7 +91,7 @@ body:has(.hero) .bd-container { } .getting-started:hover, -.user-guides:hover, +.jax-101:hover, .installation:hover { background: #AECBFA; color: #202124; @@ -99,7 +99,7 @@ body:has(.hero) .bd-container { } .getting-started .sd-card-body, -.user-guides .sd-card-body, +.jax-101 .sd-card-body, .installation .sd-card-body { display: flex; align-items: center; @@ -108,7 +108,7 @@ body:has(.hero) .bd-container { } .getting-started .sd-card-title, -.user-guides .sd-card-title, +.jax-101 .sd-card-title, .installation .sd-card-title { display: flex; flex-direction: column; @@ -117,13 +117,13 @@ body:has(.hero) .bd-container { } .getting-started svg, -.user-guides svg, +.jax-101 svg, .installation svg { color: #8AB4F8; } .getting-started:hover svg, -.user-guides:hover svg, +.jax-101:hover svg, .installation:hover svg { color: #3C4043; } @@ -218,7 +218,8 @@ body:has(.hero) .bd-container { .color-cards + p { background: #E8EAED; - padding: 24px 12px 48px 12px; + padding: 24px 12px 25px 12px; + text-align: center; font-weight: 600; color: #222832; border-radius: 0 0 24px 24px; diff --git a/docs/_tutorials/advanced-debugging.md b/docs/_tutorials/advanced-debugging.md index d4462feaf829..6a0758a1be81 100644 --- a/docs/_tutorials/advanced-debugging.md +++ b/docs/_tutorials/advanced-debugging.md @@ -21,5 +21,5 @@ kernelspec: This is a placeholder for a section in the new {ref}`jax-tutorials-draft`. For the time being, you may find some related content in the old documentation: -- {doc}`../debugging/index` +- {doc}`../debugging` ``` diff --git a/docs/_tutorials/index.rst b/docs/_tutorials/index.rst index 0e5a6a16dcfc..497bed519662 100644 --- a/docs/_tutorials/index.rst +++ b/docs/_tutorials/index.rst @@ -9,23 +9,22 @@ JAX tutorials draft The tutorials below are a work in progress; for the time being, please refer to the older tutorial content, including :ref:`beginner-guide`, - :ref:`user-guides`, and the now-deleted *JAX 101* tutorials. + :ref:`jax-101`, and the now-deleted *JAX 101* tutorials. JAX 101 ------- -Mostly finalized at :ref:`jax-tutorials`! +Mostly finalized at :ref:`jax-101`! .. toctree:: :maxdepth: 1 - ../quickstart ../key-concepts ../jit-compilation ../automatic-vectorization ../automatic-differentiation ../debugging ../random-numbers - ../working-with-pytrees + ../pytrees ../sharded-computation ../stateful-computations simple-neural-network diff --git a/docs/about.md b/docs/about.md index 58e1703842b9..d427e2ac0771 100644 --- a/docs/about.md +++ b/docs/about.md @@ -10,7 +10,7 @@ DeepMind](https://deepmind.google/), Alphabet more broadly, and elsewhere. At the heart of the project is the [JAX -core](http://github.com/jax-ml/jax) library, which focuses on the +core](https://github.com/jax-ml/jax) library, which focuses on the fundamentals of machine learning and numerical computing, at scale. When [developing](#development) the core, we want to maintain agility @@ -19,7 +19,7 @@ technology stack](#components). First, we design the `jax` module to be [composable](https://github.com/jax-ml/jax?tab=readme-ov-file#transformations) and -[extensible](https://jax.readthedocs.io/en/latest/jax.extend.html), so +[extensible](https://docs.jax.dev/en/latest/jax.extend.html), so that a wide variety of domain-specific libraries can thrive outside of it in a decentralized manner. Second, we lean heavily on a modular backend stack (compiler and runtime) to target different @@ -42,10 +42,10 @@ scale. JAX's day-to-day development takes place in the open on GitHub, using pull requests, the issue tracker, discussions, and [JAX Enhancement Proposals -(JEPs)](https://jax.readthedocs.io/en/latest/jep/index.html). Reading +(JEPs)](https://docs.jax.dev/en/latest/jep/index.html). Reading and participating in these is a good way to get involved. We also maintain [developer -notes](https://jax.readthedocs.io/en/latest/contributor_guide.html) +notes](https://docs.jax.dev/en/latest/contributor_guide.html) that cover JAX's internal design. The JAX core team determines whether to accept changes and @@ -56,7 +56,7 @@ intricate decision structure over time (e.g. with designated area owners) if/when it becomes useful to do so. For more see [contributing to -JAX](https://jax.readthedocs.io/en/latest/contributing.html). +JAX](https://docs.jax.dev/en/latest/contributing.html). (components)= ## A modular stack @@ -71,7 +71,7 @@ and (b) an advancing hardware landscape, we lean heavily on While the JAX core library focuses on the fundamentals, we want to encourage domain-specific libraries and tools to be built on top of JAX. Indeed, [many -libraries](https://jax.readthedocs.io/en/latest/#ecosystem) have +libraries](https://docs.jax.dev/en/latest/#ecosystem) have emerged around JAX to offer higher-level features and extensions. How do we encourage such decentralized development? We guide it with @@ -80,11 +80,11 @@ building blocks (e.g. numerical primitives, NumPy operations, arrays, and transformations), encouraging auxiliary libraries to develop utilities as needed for their domain. In addition, JAX exposes a handful of more advanced APIs for -[customization](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) +[customization](https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) and -[extensibility](https://jax.readthedocs.io/en/latest/jax.extend.html). Libraries +[extensibility](https://docs.jax.dev/en/latest/jax.extend.html). Libraries can [lean on these -APIs](https://jax.readthedocs.io/en/latest/building_on_jax.html) in +APIs](https://docs.jax.dev/en/latest/building_on_jax.html) in order to use JAX as an internal means of implementation, to integrate more with its transformations like autodiff, and more. diff --git a/docs/advanced-autodiff.md b/docs/advanced-autodiff.md deleted file mode 100644 index eaa3bc7317c8..000000000000 --- a/docs/advanced-autodiff.md +++ /dev/null @@ -1,1777 +0,0 @@ ---- -jupytext: - formats: md:myst - text_representation: - extension: .md - format_name: myst - format_version: 0.13 - jupytext_version: 1.16.4 -kernelspec: - display_name: Python 3 - language: python - name: python3 ---- - -(advanced-autodiff)= -# Advanced automatic differentiation - - - -In this tutorial, you will learn about complex applications of automatic differentiation (autodiff) in JAX and gain a better understanding of how taking derivatives in JAX can be both easy and powerful. - -Make sure to check out the {ref}`automatic-differentiation` tutorial to go over the JAX autodiff basics, if you haven't already. - -## Setup - -```{code-cell} -import jax -import jax.numpy as jnp -from jax import grad, jit, vmap -from jax import random - -key = random.key(0) -``` - -## Taking gradients (part 2) - -### Higher-order derivatives - -JAX's autodiff makes it easy to compute higher-order derivatives, because the functions that compute derivatives are themselves differentiable. Thus, higher-order derivatives are as easy as stacking transformations. - -The single-variable case was covered in the {ref}`automatic-differentiation` tutorial, where the example showed how to use {func}`jax.grad` to compute the the derivative of $f(x) = x^3 + 2x^2 - 3x + 1$. - -In the multivariable case, higher-order derivatives are more complicated. The second-order derivative of a function is represented by its [Hessian matrix](https://en.wikipedia.org/wiki/Hessian_matrix), defined according to: - -$$(\mathbf{H}f)_{i,j} = \frac{\partial^2 f}{\partial_i\partial_j}.$$ - -The Hessian of a real-valued function of several variables, $f: \mathbb R^n\to\mathbb R$, can be identified with the [Jacobian](https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant) of its gradient. - -JAX provides two transformations for computing the Jacobian of a function, {func}`jax.jacfwd` and {func}`jax.jacrev`, corresponding to forward- and reverse-mode autodiff. They give the same answer, but one can be more efficient than the other in different circumstances – refer to the [video about autodiff](https://www.youtube.com/watch?v=wG_nF1awSSY). - -```{code-cell} -def hessian(f): - return jax.jacfwd(jax.grad(f)) -``` - -Let's double check this is correct on the dot-product $f: \mathbf{x} \mapsto \mathbf{x} ^\top \mathbf{x}$. - -if $i=j$, $\frac{\partial^2 f}{\partial_i\partial_j}(\mathbf{x}) = 2$. Otherwise, $\frac{\partial^2 f}{\partial_i\partial_j}(\mathbf{x}) = 0$. - -```{code-cell} -def f(x): - return jnp.dot(x, x) - -hessian(f)(jnp.array([1., 2., 3.])) -``` - -## Higher-order optimization - -Some meta-learning techniques, such as Model-Agnostic Meta-Learning ([MAML](https://arxiv.org/abs/1703.03400)), require differentiating through gradient updates. In other frameworks this can be quite cumbersome, but in JAX it's much easier: - -```python -def meta_loss_fn(params, data): - """Computes the loss after one step of SGD.""" - grads = jax.grad(loss_fn)(params, data) - return loss_fn(params - lr * grads, data) - -meta_grads = jax.grad(meta_loss_fn)(params, data) -``` - -(stopping-gradients)= -### Stopping gradients - -Autodiff enables automatic computation of the gradient of a function with respect to its inputs. Sometimes, however, you might want some additional control: for instance, you might want to avoid backpropagating gradients through some subset of the computational graph. - -Consider for instance the TD(0) ([temporal difference](https://en.wikipedia.org/wiki/Temporal_difference_learning)) reinforcement learning update. This is used to learn to estimate the *value* of a state in an environment from experience of interacting with the environment. Let's assume the value estimate $v_{\theta}(s_{t-1}$) in a state $s_{t-1}$ is parameterised by a linear function. - -```{code-cell} -# Value function and initial parameters -value_fn = lambda theta, state: jnp.dot(theta, state) -theta = jnp.array([0.1, -0.1, 0.]) -``` - -Consider a transition from a state $s_{t-1}$ to a state $s_t$ during which you observed the reward $r_t$ - -```{code-cell} -# An example transition. -s_tm1 = jnp.array([1., 2., -1.]) -r_t = jnp.array(1.) -s_t = jnp.array([2., 1., 0.]) -``` - -The TD(0) update to the network parameters is: - -$$ -\Delta \theta = (r_t + v_{\theta}(s_t) - v_{\theta}(s_{t-1})) \nabla v_{\theta}(s_{t-1}) -$$ - -This update is not the gradient of any loss function. - -However, it can be **written** as the gradient of the pseudo loss function - -$$ -L(\theta) = - \frac{1}{2} [r_t + v_{\theta}(s_t) - v_{\theta}(s_{t-1})]^2 -$$ - -if the dependency of the target $r_t + v_{\theta}(s_t)$ on the parameter $\theta$ is ignored. - -How can you implement this in JAX? If you write the pseudo loss naively, you get: - -```{code-cell} -def td_loss(theta, s_tm1, r_t, s_t): - v_tm1 = value_fn(theta, s_tm1) - target = r_t + value_fn(theta, s_t) - return -0.5 * ((target - v_tm1) ** 2) - -td_update = jax.grad(td_loss) -delta_theta = td_update(theta, s_tm1, r_t, s_t) - -delta_theta -``` - -But `td_update` will **not** compute a TD(0) update, because the gradient computation will include the dependency of `target` on $\theta$. - -You can use {func}`jax.lax.stop_gradient` to force JAX to ignore the dependency of the target on $\theta$: - -```{code-cell} -def td_loss(theta, s_tm1, r_t, s_t): - v_tm1 = value_fn(theta, s_tm1) - target = r_t + value_fn(theta, s_t) - return -0.5 * ((jax.lax.stop_gradient(target) - v_tm1) ** 2) - -td_update = jax.grad(td_loss) -delta_theta = td_update(theta, s_tm1, r_t, s_t) - -delta_theta -``` - -This will treat `target` as if it did **not** depend on the parameters $\theta$ and compute the correct update to the parameters. - -Now, let's also calculate $\Delta \theta$ using the original TD(0) update expression, to cross-check our work. You may wish to try and implement this yourself using {func}`jax.grad` and your knowledge so far. Here's our solution: - -```{code-cell} -s_grad = jax.grad(value_fn)(theta, s_tm1) -delta_theta_original_calculation = (r_t + value_fn(theta, s_t) - value_fn(theta, s_tm1)) * s_grad - -delta_theta_original_calculation # [1.2, 2.4, -1.2], same as `delta_theta` -``` - -`jax.lax.stop_gradient` may also be useful in other settings, for instance if you want the gradient from some loss to only affect a subset of the parameters of the neural network (because, for instance, the other parameters are trained using a different loss). - - -### Straight-through estimator using `stop_gradient` - -The straight-through estimator is a trick for defining a 'gradient' of a function that is otherwise non-differentiable. Given a non-differentiable function $f : \mathbb{R}^n \to \mathbb{R}^n$ that is used as part of a larger function that we wish to find a gradient of, we simply pretend during the backward pass that $f$ is the identity function. This can be implemented neatly using `jax.lax.stop_gradient`: - -```{code-cell} -def f(x): - return jnp.round(x) # non-differentiable - -def straight_through_f(x): - # Create an exactly-zero expression with Sterbenz lemma that has - # an exactly-one gradient. - zero = x - jax.lax.stop_gradient(x) - return zero + jax.lax.stop_gradient(f(x)) - -print("f(x): ", f(3.2)) -print("straight_through_f(x):", straight_through_f(3.2)) - -print("grad(f)(x):", jax.grad(f)(3.2)) -print("grad(straight_through_f)(x):", jax.grad(straight_through_f)(3.2)) -``` - -### Per-example gradients - -While most ML systems compute gradients and updates from batches of data, for reasons of computational efficiency and/or variance reduction, it is sometimes necessary to have access to the gradient/update associated with each specific sample in the batch. - -For instance, this is needed to prioritize data based on gradient magnitude, or to apply clipping / normalisations on a sample by sample basis. - -In many frameworks (PyTorch, TF, Theano) it is often not trivial to compute per-example gradients, because the library directly accumulates the gradient over the batch. Naive workarounds, such as computing a separate loss per example and then aggregating the resulting gradients are typically very inefficient. - -In JAX, you can define the code to compute the gradient per-sample in an easy but efficient way. - -Just combine the {func}`jax.jit`, {func}`jax.vmap` and {func}`jax.grad` transformations together: - -```{code-cell} -perex_grads = jax.jit(jax.vmap(jax.grad(td_loss), in_axes=(None, 0, 0, 0))) - -# Test it: -batched_s_tm1 = jnp.stack([s_tm1, s_tm1]) -batched_r_t = jnp.stack([r_t, r_t]) -batched_s_t = jnp.stack([s_t, s_t]) - -perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t) -``` - -Let's go through this one transformation at a time. - -First, you apply {func}`jax.grad` to `td_loss` to obtain a function that computes the gradient of the loss w.r.t. the parameters on single (unbatched) inputs: - -```{code-cell} -dtdloss_dtheta = jax.grad(td_loss) - -dtdloss_dtheta(theta, s_tm1, r_t, s_t) -``` - -This function computes one row of the array above. - -Then, you vectorise this function using {func}`jax.vmap`. This adds a batch dimension to all inputs and outputs. Now, given a batch of inputs, you produce a batch of outputs — each output in the batch corresponds to the gradient for the corresponding member of the input batch. - -```{code-cell} -almost_perex_grads = jax.vmap(dtdloss_dtheta) - -batched_theta = jnp.stack([theta, theta]) -almost_perex_grads(batched_theta, batched_s_tm1, batched_r_t, batched_s_t) -``` - -This isn't quite what we want, because we have to manually feed this function a batch of `theta`s, whereas we actually want to use a single `theta`. We fix this by adding `in_axes` to the {func}`jax.vmap`, specifying theta as `None`, and the other args as `0`. This makes the resulting function add an extra axis only to the other arguments, leaving `theta` unbatched, as we want: - -```{code-cell} -inefficient_perex_grads = jax.vmap(dtdloss_dtheta, in_axes=(None, 0, 0, 0)) - -inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t) -``` - -This does what we want, but is slower than it has to be. Now, you wrap the whole thing in a {func}`jax.jit` to get the compiled, efficient version of the same function: - -```{code-cell} -perex_grads = jax.jit(inefficient_perex_grads) - -perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t) -``` - -```{code-cell} -%timeit inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready() -%timeit perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready() -``` - -### Hessian-vector products with `jax.grad`-of-`jax.grad` - -One thing you can do with higher-order {func}`jax.grad` is build a Hessian-vector product function. (Later on you'll write an even more efficient implementation that mixes both forward- and reverse-mode, but this one will use pure reverse-mode.) - -A Hessian-vector product function can be useful in a [truncated Newton Conjugate-Gradient algorithm](https://en.wikipedia.org/wiki/Truncated_Newton_method) for minimizing smooth convex functions, or for studying the curvature of neural network training objectives (e.g. [1](https://arxiv.org/abs/1406.2572), [2](https://arxiv.org/abs/1811.07062), [3](https://arxiv.org/abs/1706.04454), [4](https://arxiv.org/abs/1802.03451)). - -For a scalar-valued function $f : \mathbb{R}^n \to \mathbb{R}$ with continuous second derivatives (so that the Hessian matrix is symmetric), the Hessian at a point $x \in \mathbb{R}^n$ is written as $\partial^2 f(x)$. A Hessian-vector product function is then able to evaluate - -$\qquad v \mapsto \partial^2 f(x) \cdot v$ - -for any $v \in \mathbb{R}^n$. - -The trick is not to instantiate the full Hessian matrix: if $n$ is large, perhaps in the millions or billions in the context of neural networks, then that might be impossible to store. - -Luckily, {func}`jax.grad` already gives us a way to write an efficient Hessian-vector product function. You just have to use the identity: - -$\qquad \partial^2 f (x) v = \partial [x \mapsto \partial f(x) \cdot v] = \partial g(x)$, - -where $g(x) = \partial f(x) \cdot v$ is a new scalar-valued function that dots the gradient of $f$ at $x$ with the vector $v$. Notice that you're only ever differentiating scalar-valued functions of vector-valued arguments, which is exactly where you know {func}`jax.grad` is efficient. - -In JAX code, you can just write this: - -```{code-cell} -def hvp(f, x, v): - return grad(lambda x: jnp.vdot(grad(f)(x), v))(x) -``` - -This example shows that you can freely use lexical closure, and JAX will never get perturbed or confused. - -You will check this implementation a few cells down, once you learn how to compute dense Hessian matrices. You'll also write an even better version that uses both forward-mode and reverse-mode. - - -### Jacobians and Hessians using `jax.jacfwd` and `jax.jacrev` - -You can compute full Jacobian matrices using the {func}`jax.jacfwd` and {func}`jax.jacrev` functions: - -```{code-cell} -from jax import jacfwd, jacrev - -# Define a sigmoid function. -def sigmoid(x): - return 0.5 * (jnp.tanh(x / 2) + 1) - -# Outputs probability of a label being true. -def predict(W, b, inputs): - return sigmoid(jnp.dot(inputs, W) + b) - -# Build a toy dataset. -inputs = jnp.array([[0.52, 1.12, 0.77], - [0.88, -1.08, 0.15], - [0.52, 0.06, -1.30], - [0.74, -2.49, 1.39]]) - -# Initialize random model coefficients -key, W_key, b_key = random.split(key, 3) -W = random.normal(W_key, (3,)) -b = random.normal(b_key, ()) - -# Isolate the function from the weight matrix to the predictions -f = lambda W: predict(W, b, inputs) - -J = jacfwd(f)(W) -print("jacfwd result, with shape", J.shape) -print(J) - -J = jacrev(f)(W) -print("jacrev result, with shape", J.shape) -print(J) -``` - -These two functions compute the same values (up to machine numerics), but differ in their implementation: {func}`jax.jacfwd` uses forward-mode automatic differentiation, which is more efficient for "tall" Jacobian matrices (more outputs than inputs), while {func}`jax.jacrev` uses reverse-mode, which is more efficient for "wide" Jacobian matrices (more inputs than outputs). For matrices that are near-square, {func}`jax.jacfwd` probably has an edge over {func}`jax.jacrev`. - -You can also use {func}`jax.jacfwd` and {func}`jax.jacrev` with container types: - -```{code-cell} -def predict_dict(params, inputs): - return predict(params['W'], params['b'], inputs) - -J_dict = jacrev(predict_dict)({'W': W, 'b': b}, inputs) -for k, v in J_dict.items(): - print("Jacobian from {} to logits is".format(k)) - print(v) -``` - -For more details on forward- and reverse-mode, as well as how to implement {func}`jax.jacfwd` and {func}`jax.jacrev` as efficiently as possible, read on! - -Using a composition of two of these functions gives us a way to compute dense Hessian matrices: - -```{code-cell} -def hessian(f): - return jacfwd(jacrev(f)) - -H = hessian(f)(W) -print("hessian, with shape", H.shape) -print(H) -``` - -This shape makes sense: if you start with a function $f : \mathbb{R}^n \to \mathbb{R}^m$, then at a point $x \in \mathbb{R}^n$ you expect to get the shapes: - -* $f(x) \in \mathbb{R}^m$, the value of $f$ at $x$, -* $\partial f(x) \in \mathbb{R}^{m \times n}$, the Jacobian matrix at $x$, -* $\partial^2 f(x) \in \mathbb{R}^{m \times n \times n}$, the Hessian at $x$, - -and so on. - -To implement `hessian`, you could have used `jacfwd(jacrev(f))` or `jacrev(jacfwd(f))` or any other composition of these two. But forward-over-reverse is typically the most efficient. That's because in the inner Jacobian computation we're often differentiating a function wide Jacobian (maybe like a loss function $f : \mathbb{R}^n \to \mathbb{R}$), while in the outer Jacobian computation we're differentiating a function with a square Jacobian (since $\nabla f : \mathbb{R}^n \to \mathbb{R}^n$), which is where forward-mode wins out. - - -## How it's made: Two foundational autodiff functions - -### Jacobian-Vector products (JVPs, a.k.a. forward-mode autodiff) - -JAX includes efficient and general implementations of both forward- and reverse-mode automatic differentiation. The familiar {func}`jax.grad` function is built on reverse-mode, but to explain the difference between the two modes, and when each can be useful, you need a bit of math background. - - -#### JVPs in math - -Mathematically, given a function $f : \mathbb{R}^n \to \mathbb{R}^m$, the Jacobian of $f$ evaluated at an input point $x \in \mathbb{R}^n$, denoted $\partial f(x)$, is often thought of as a matrix in $\mathbb{R}^m \times \mathbb{R}^n$: - -$\qquad \partial f(x) \in \mathbb{R}^{m \times n}$. - -But you can also think of $\partial f(x)$ as a linear map, which maps the tangent space of the domain of $f$ at the point $x$ (which is just another copy of $\mathbb{R}^n$) to the tangent space of the codomain of $f$ at the point $f(x)$ (a copy of $\mathbb{R}^m$): - -$\qquad \partial f(x) : \mathbb{R}^n \to \mathbb{R}^m$. - -This map is called the [pushforward map](https://en.wikipedia.org/wiki/Pushforward_(differential)) of $f$ at $x$. The Jacobian matrix is just the matrix for this linear map on a standard basis. - -If you don't commit to one specific input point $x$, then you can think of the function $\partial f$ as first taking an input point and returning the Jacobian linear map at that input point: - -$\qquad \partial f : \mathbb{R}^n \to \mathbb{R}^n \to \mathbb{R}^m$. - -In particular, you can uncurry things so that given input point $x \in \mathbb{R}^n$ and a tangent vector $v \in \mathbb{R}^n$, you get back an output tangent vector in $\mathbb{R}^m$. We call that mapping, from $(x, v)$ pairs to output tangent vectors, the *Jacobian-vector product*, and write it as: - -$\qquad (x, v) \mapsto \partial f(x) v$ - - -#### JVPs in JAX code - -Back in Python code, JAX's {func}`jax.jvp` function models this transformation. Given a Python function that evaluates $f$, JAX's {func}`jax.jvp` is a way to get a Python function for evaluating $(x, v) \mapsto (f(x), \partial f(x) v)$. - -```{code-cell} -from jax import jvp - -# Isolate the function from the weight matrix to the predictions -f = lambda W: predict(W, b, inputs) - -key, subkey = random.split(key) -v = random.normal(subkey, W.shape) - -# Push forward the vector `v` along `f` evaluated at `W` -y, u = jvp(f, (W,), (v,)) -``` - -In terms of [Haskell-like type signatures](https://wiki.haskell.org/Type_signature), you could write: - -```haskell -jvp :: (a -> b) -> a -> T a -> (b, T b) -``` - -where `T a` is used to denote the type of the tangent space for `a`. - -In other words, `jvp` takes as arguments a function of type `a -> b`, a value of type `a`, and a tangent vector value of type `T a`. It gives back a pair consisting of a value of type `b` and an output tangent vector of type `T b`. - -The `jvp`-transformed function is evaluated much like the original function, but paired up with each primal value of type `a` it pushes along tangent values of type `T a`. For each primitive numerical operation that the original function would have applied, the `jvp`-transformed function executes a "JVP rule" for that primitive that both evaluates the primitive on the primals and applies the primitive's JVP at those primal values. - -That evaluation strategy has some immediate implications about computational complexity. Since we evaluate JVPs as we go, we don't need to store anything for later, and so the memory cost is independent of the depth of the computation. In addition, the FLOP cost of the `jvp`-transformed function is about 3x the cost of just evaluating the function (one unit of work for evaluating the original function, for example `sin(x)`; one unit for linearizing, like `cos(x)`; and one unit for applying the linearized function to a vector, like `cos_x * v`). Put another way, for a fixed primal point $x$, we can evaluate $v \mapsto \partial f(x) \cdot v$ for about the same marginal cost as evaluating $f$. - -That memory complexity sounds pretty compelling! So why don't we see forward-mode very often in machine learning? - -To answer that, first think about how you could use a JVP to build a full Jacobian matrix. If we apply a JVP to a one-hot tangent vector, it reveals one column of the Jacobian matrix, corresponding to the nonzero entry we fed in. So we can build a full Jacobian one column at a time, and to get each column costs about the same as one function evaluation. That will be efficient for functions with "tall" Jacobians, but inefficient for "wide" Jacobians. - -If you're doing gradient-based optimization in machine learning, you probably want to minimize a loss function from parameters in $\mathbb{R}^n$ to a scalar loss value in $\mathbb{R}$. That means the Jacobian of this function is a very wide matrix: $\partial f(x) \in \mathbb{R}^{1 \times n}$, which we often identify with the Gradient vector $\nabla f(x) \in \mathbb{R}^n$. Building that matrix one column at a time, with each call taking a similar number of FLOPs to evaluate the original function, sure seems inefficient! In particular, for training neural networks, where $f$ is a training loss function and $n$ can be in the millions or billions, this approach just won't scale. - -To do better for functions like this, you just need to use reverse-mode. - - -### Vector-Jacobian products (VJPs, a.k.a. reverse-mode autodiff) - -Where forward-mode gives us back a function for evaluating Jacobian-vector products, which we can then use to build Jacobian matrices one column at a time, reverse-mode is a way to get back a function for evaluating vector-Jacobian products (equivalently Jacobian-transpose-vector products), which we can use to build Jacobian matrices one row at a time. - - -#### VJPs in math - -Let's again consider a function $f : \mathbb{R}^n \to \mathbb{R}^m$. -Starting from our notation for JVPs, the notation for VJPs is pretty simple: - -$\qquad (x, v) \mapsto v \partial f(x)$, - -where $v$ is an element of the cotangent space of $f$ at $x$ (isomorphic to another copy of $\mathbb{R}^m$). When being rigorous, we should think of $v$ as a linear map $v : \mathbb{R}^m \to \mathbb{R}$, and when we write $v \partial f(x)$ we mean function composition $v \circ \partial f(x)$, where the types work out because $\partial f(x) : \mathbb{R}^n \to \mathbb{R}^m$. But in the common case we can identify $v$ with a vector in $\mathbb{R}^m$ and use the two almost interchangeably, just like we might sometimes flip between "column vectors" and "row vectors" without much comment. - -With that identification, we can alternatively think of the linear part of a VJP as the transpose (or adjoint conjugate) of the linear part of a JVP: - -$\qquad (x, v) \mapsto \partial f(x)^\mathsf{T} v$. - -For a given point $x$, we can write the signature as - -$\qquad \partial f(x)^\mathsf{T} : \mathbb{R}^m \to \mathbb{R}^n$. - -The corresponding map on cotangent spaces is often called the [pullback](https://en.wikipedia.org/wiki/Pullback_(differential_geometry)) -of $f$ at $x$. The key for our purposes is that it goes from something that looks like the output of $f$ to something that looks like the input of $f$, just like we might expect from a transposed linear function. - -#### VJPs in JAX code - -Switching from math back to Python, the JAX function `vjp` can take a Python function for evaluating $f$ and give us back a Python function for evaluating the VJP $(x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))$. - -```{code-cell} -from jax import vjp - -# Isolate the function from the weight matrix to the predictions -f = lambda W: predict(W, b, inputs) - -y, vjp_fun = vjp(f, W) - -key, subkey = random.split(key) -u = random.normal(subkey, y.shape) - -# Pull back the covector `u` along `f` evaluated at `W` -v = vjp_fun(u) -``` - -In terms of [Haskell-like type signatures](https://wiki.haskell.org/Type_signature), we could write - -```haskell -vjp :: (a -> b) -> a -> (b, CT b -> CT a) -``` - -where we use `CT a` to denote the type for the cotangent space for `a`. In words, `vjp` takes as arguments a function of type `a -> b` and a point of type `a`, and gives back a pair consisting of a value of type `b` and a linear map of type `CT b -> CT a`. - -This is great because it lets us build Jacobian matrices one row at a time, and the FLOP cost for evaluating $(x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))$ is only about three times the cost of evaluating $f$. In particular, if we want the gradient of a function $f : \mathbb{R}^n \to \mathbb{R}$, we can do it in just one call. That's how {func}`jax.grad` is efficient for gradient-based optimization, even for objectives like neural network training loss functions on millions or billions of parameters. - -There's a cost, though the FLOPs are friendly, memory scales with the depth of the computation. Also, the implementation is traditionally more complex than that of forward-mode, though JAX has some tricks up its sleeve (that's a story for a future notebook!). - -For more on how reverse-mode works, check out [this tutorial video from the Deep Learning Summer School in 2017](http://videolectures.net/deeplearning2017_johnson_automatic_differentiation/). - - -### Vector-valued gradients with VJPs - -If you're interested in taking vector-valued gradients (like `tf.gradients`): - -```{code-cell} -def vgrad(f, x): - y, vjp_fn = vjp(f, x) - return vjp_fn(jnp.ones(y.shape))[0] - -print(vgrad(lambda x: 3*x**2, jnp.ones((2, 2)))) -``` - -### Hessian-vector products using both forward- and reverse-mode - -In a previous section, you implemented a Hessian-vector product function just using reverse-mode (assuming continuous second derivatives): - -```{code-cell} -def hvp(f, x, v): - return grad(lambda x: jnp.vdot(grad(f)(x), v))(x) -``` - -That's efficient, but you can do even better and save some memory by using forward-mode together with reverse-mode. - -Mathematically, given a function $f : \mathbb{R}^n \to \mathbb{R}$ to differentiate, a point $x \in \mathbb{R}^n$ at which to linearize the function, and a vector $v \in \mathbb{R}^n$, the Hessian-vector product function we want is: - -$(x, v) \mapsto \partial^2 f(x) v$ - -Consider the helper function $g : \mathbb{R}^n \to \mathbb{R}^n$ defined to be the derivative (or gradient) of $f$, namely $g(x) = \partial f(x)$. All you need is its JVP, since that will give us: - -$(x, v) \mapsto \partial g(x) v = \partial^2 f(x) v$. - -We can translate that almost directly into code: - -```{code-cell} -# forward-over-reverse -def hvp(f, primals, tangents): - return jvp(grad(f), primals, tangents)[1] -``` - -Even better, since you didn't have to call {func}`jnp.dot` directly, this `hvp` function works with arrays of any shape and with arbitrary container types (like vectors stored as nested lists/dicts/tuples), and doesn't even have a dependence on {mod}`jax.numpy`. - -Here's an example of how to use it: - -```{code-cell} -def f(X): - return jnp.sum(jnp.tanh(X)**2) - -key, subkey1, subkey2 = random.split(key, 3) -X = random.normal(subkey1, (30, 40)) -V = random.normal(subkey2, (30, 40)) - -ans1 = hvp(f, (X,), (V,)) -ans2 = jnp.tensordot(hessian(f)(X), V, 2) - -print(jnp.allclose(ans1, ans2, 1e-4, 1e-4)) -``` - -Another way you might consider writing this is using reverse-over-forward: - -```{code-cell} -# Reverse-over-forward -def hvp_revfwd(f, primals, tangents): - g = lambda primals: jvp(f, primals, tangents)[1] - return grad(g)(primals) -``` - -That's not quite as good, though, because forward-mode has less overhead than reverse-mode, and since the outer differentiation operator here has to differentiate a larger computation than the inner one, keeping forward-mode on the outside works best: - -```{code-cell} -# Reverse-over-reverse, only works for single arguments -def hvp_revrev(f, primals, tangents): - x, = primals - v, = tangents - return grad(lambda x: jnp.vdot(grad(f)(x), v))(x) - - -print("Forward over reverse") -%timeit -n10 -r3 hvp(f, (X,), (V,)) -print("Reverse over forward") -%timeit -n10 -r3 hvp_revfwd(f, (X,), (V,)) -print("Reverse over reverse") -%timeit -n10 -r3 hvp_revrev(f, (X,), (V,)) - -print("Naive full Hessian materialization") -%timeit -n10 -r3 jnp.tensordot(hessian(f)(X), V, 2) -``` - -## Composing VJPs, JVPs, and `jax.vmap` - -### Jacobian-Matrix and Matrix-Jacobian products - -Now that you have {func}`jax.jvp` and {func}`jax.vjp` transformations that give you functions to push-forward or pull-back single vectors at a time, you can use JAX's {func}`jax.vmap` [transformation](https://github.com/jax-ml/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, you can use that to write fast matrix-Jacobian and Jacobian-matrix products: - -```{code-cell} -# Isolate the function from the weight matrix to the predictions -f = lambda W: predict(W, b, inputs) - -# Pull back the covectors `m_i` along `f`, evaluated at `W`, for all `i`. -# First, use a list comprehension to loop over rows in the matrix M. -def loop_mjp(f, x, M): - y, vjp_fun = vjp(f, x) - return jnp.vstack([vjp_fun(mi) for mi in M]) - -# Now, use vmap to build a computation that does a single fast matrix-matrix -# multiply, rather than an outer loop over vector-matrix multiplies. -def vmap_mjp(f, x, M): - y, vjp_fun = vjp(f, x) - outs, = vmap(vjp_fun)(M) - return outs - -key = random.key(0) -num_covecs = 128 -U = random.normal(key, (num_covecs,) + y.shape) - -loop_vs = loop_mjp(f, W, M=U) -print('Non-vmapped Matrix-Jacobian product') -%timeit -n10 -r3 loop_mjp(f, W, M=U) - -print('\nVmapped Matrix-Jacobian product') -vmap_vs = vmap_mjp(f, W, M=U) -%timeit -n10 -r3 vmap_mjp(f, W, M=U) - -assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Matrix-Jacobian Products should be identical' -``` - -```{code-cell} -def loop_jmp(f, W, M): - # jvp immediately returns the primal and tangent values as a tuple, - # so we'll compute and select the tangents in a list comprehension - return jnp.vstack([jvp(f, (W,), (mi,))[1] for mi in M]) - -def vmap_jmp(f, W, M): - _jvp = lambda s: jvp(f, (W,), (s,))[1] - return vmap(_jvp)(M) - -num_vecs = 128 -S = random.normal(key, (num_vecs,) + W.shape) - -loop_vs = loop_jmp(f, W, M=S) -print('Non-vmapped Jacobian-Matrix product') -%timeit -n10 -r3 loop_jmp(f, W, M=S) -vmap_vs = vmap_jmp(f, W, M=S) -print('\nVmapped Jacobian-Matrix product') -%timeit -n10 -r3 vmap_jmp(f, W, M=S) - -assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Jacobian-Matrix products should be identical' -``` - -### The implementation of `jax.jacfwd` and `jax.jacrev` - -Now that we've seen fast Jacobian-matrix and matrix-Jacobian products, it's not hard to guess how to write {func}`jax.jacfwd` and {func}`jax.jacrev`. We just use the same technique to push-forward or pull-back an entire standard basis (isomorphic to an identity matrix) at once. - -```{code-cell} -from jax import jacrev as builtin_jacrev - -def our_jacrev(f): - def jacfun(x): - y, vjp_fun = vjp(f, x) - # Use vmap to do a matrix-Jacobian product. - # Here, the matrix is the Euclidean basis, so we get all - # entries in the Jacobian at once. - J, = vmap(vjp_fun, in_axes=0)(jnp.eye(len(y))) - return J - return jacfun - -assert jnp.allclose(builtin_jacrev(f)(W), our_jacrev(f)(W)), 'Incorrect reverse-mode Jacobian results!' -``` - -```{code-cell} -from jax import jacfwd as builtin_jacfwd - -def our_jacfwd(f): - def jacfun(x): - _jvp = lambda s: jvp(f, (x,), (s,))[1] - Jt = vmap(_jvp, in_axes=1)(jnp.eye(len(x))) - return jnp.transpose(Jt) - return jacfun - -assert jnp.allclose(builtin_jacfwd(f)(W), our_jacfwd(f)(W)), 'Incorrect forward-mode Jacobian results!' -``` - -Interestingly, the [Autograd](https://github.com/hips/autograd) library couldn't do this. The [implementation](https://github.com/HIPS/autograd/blob/96a03f44da43cd7044c61ac945c483955deba957/autograd/differential_operators.py#L60) of reverse-mode `jacobian` in Autograd had to pull back one vector at a time with an outer-loop `map`. Pushing one vector at a time through the computation is much less efficient than batching it all together with {func}`jax.vmap`. - -Another thing that Autograd couldn't do is {func}`jax.jit`. Interestingly, no matter how much Python dynamism you use in your function to be differentiated, we could always use {func}`jax.jit` on the linear part of the computation. For example: - -```{code-cell} -def f(x): - try: - if x < 3: - return 2 * x ** 3 - else: - raise ValueError - except ValueError: - return jnp.pi * x - -y, f_vjp = vjp(f, 4.) -print(jit(f_vjp)(1.)) -``` - -## Complex numbers and differentiation - -JAX is great at complex numbers and differentiation. To support both [holomorphic and non-holomorphic differentiation](https://en.wikipedia.org/wiki/Holomorphic_function), it helps to think in terms of JVPs and VJPs. - -Consider a complex-to-complex function $f: \mathbb{C} \to \mathbb{C}$ and identify it with a corresponding function $g: \mathbb{R}^2 \to \mathbb{R}^2$, - -```{code-cell} -def f(z): - x, y = jnp.real(z), jnp.imag(z) - return u(x, y) + v(x, y) * 1j - -def g(x, y): - return (u(x, y), v(x, y)) -``` - -That is, we've decomposed $f(z) = u(x, y) + v(x, y) i$ where $z = x + y i$, and identified $\mathbb{C}$ with $\mathbb{R}^2$ to get $g$. - -Since $g$ only involves real inputs and outputs, we already know how to write a Jacobian-vector product for it, say given a tangent vector $(c, d) \in \mathbb{R}^2$, namely: - -$\begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} -\begin{bmatrix} c \\ d \end{bmatrix}$. - -To get a JVP for the original function $f$ applied to a tangent vector $c + di \in \mathbb{C}$, we just use the same definition and identify the result as another complex number, - -$\partial f(x + y i)(c + d i) = -\begin{matrix} \begin{bmatrix} 1 & i \end{bmatrix} \\ ~ \end{matrix} -\begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} -\begin{bmatrix} c \\ d \end{bmatrix}$. - -That's our definition of the JVP of a $\mathbb{C} \to \mathbb{C}$ function! Notice it doesn't matter whether or not $f$ is holomorphic: the JVP is unambiguous. - -Here's a check: - -```{code-cell} -def check(seed): - key = random.key(seed) - - # random coeffs for u and v - key, subkey = random.split(key) - a, b, c, d = random.uniform(subkey, (4,)) - - def fun(z): - x, y = jnp.real(z), jnp.imag(z) - return u(x, y) + v(x, y) * 1j - - def u(x, y): - return a * x + b * y - - def v(x, y): - return c * x + d * y - - # primal point - key, subkey = random.split(key) - x, y = random.uniform(subkey, (2,)) - z = x + y * 1j - - # tangent vector - key, subkey = random.split(key) - c, d = random.uniform(subkey, (2,)) - z_dot = c + d * 1j - - # check jvp - _, ans = jvp(fun, (z,), (z_dot,)) - expected = (grad(u, 0)(x, y) * c + - grad(u, 1)(x, y) * d + - grad(v, 0)(x, y) * c * 1j+ - grad(v, 1)(x, y) * d * 1j) - print(jnp.allclose(ans, expected)) -``` - -```{code-cell} -check(0) -check(1) -check(2) -``` - -What about VJPs? We do something pretty similar: for a cotangent vector $c + di \in \mathbb{C}$ we define the VJP of $f$ as - -$(c + di)^* \; \partial f(x + y i) = -\begin{matrix} \begin{bmatrix} c & -d \end{bmatrix} \\ ~ \end{matrix} -\begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} -\begin{bmatrix} 1 \\ -i \end{bmatrix}$. - -What's with the negatives? They're just to take care of complex conjugation, and the fact that we're working with covectors. - -Here's a check of the VJP rules: - -```{code-cell} -def check(seed): - key = random.key(seed) - - # random coeffs for u and v - key, subkey = random.split(key) - a, b, c, d = random.uniform(subkey, (4,)) - - def fun(z): - x, y = jnp.real(z), jnp.imag(z) - return u(x, y) + v(x, y) * 1j - - def u(x, y): - return a * x + b * y - - def v(x, y): - return c * x + d * y - - # primal point - key, subkey = random.split(key) - x, y = random.uniform(subkey, (2,)) - z = x + y * 1j - - # cotangent vector - key, subkey = random.split(key) - c, d = random.uniform(subkey, (2,)) - z_bar = jnp.array(c + d * 1j) # for dtype control - - # check vjp - _, fun_vjp = vjp(fun, z) - ans, = fun_vjp(z_bar) - expected = (grad(u, 0)(x, y) * c + - grad(v, 0)(x, y) * (-d) + - grad(u, 1)(x, y) * c * (-1j) + - grad(v, 1)(x, y) * (-d) * (-1j)) - assert jnp.allclose(ans, expected, atol=1e-5, rtol=1e-5) -``` - -```{code-cell} -check(0) -check(1) -check(2) -``` - -What about convenience wrappers like {func}`jax.grad`, {func}`jax.jacfwd`, and {func}`jax.jacrev`? - -For $\mathbb{R} \to \mathbb{R}$ functions, recall we defined `grad(f)(x)` as being `vjp(f, x)[1](1.0)`, which works because applying a VJP to a `1.0` value reveals the gradient (i.e. Jacobian, or derivative). We can do the same thing for $\mathbb{C} \to \mathbb{R}$ functions: we can still use `1.0` as the cotangent vector, and we just get out a complex number result summarizing the full Jacobian: - -```{code-cell} -def f(z): - x, y = jnp.real(z), jnp.imag(z) - return x**2 + y**2 - -z = 3. + 4j -grad(f)(z) -``` - -For general $\mathbb{C} \to \mathbb{C}$ functions, the Jacobian has 4 real-valued degrees of freedom (as in the 2x2 Jacobian matrices above), so we can't hope to represent all of them within a complex number. But we can for holomorphic functions! A holomorphic function is precisely a $\mathbb{C} \to \mathbb{C}$ function with the special property that its derivative can be represented as a single complex number. (The [Cauchy-Riemann equations](https://en.wikipedia.org/wiki/Cauchy%E2%80%93Riemann_equations) ensure that the above 2x2 Jacobians have the special form of a scale-and-rotate matrix in the complex plane, i.e. the action of a single complex number under multiplication.) And we can reveal that one complex number using a single call to `vjp` with a covector of `1.0`. - -Because this only works for holomorphic functions, to use this trick we need to promise JAX that our function is holomorphic; otherwise, JAX will raise an error when {func}`jax.grad` is used for a complex-output function: - -```{code-cell} -def f(z): - return jnp.sin(z) - -z = 3. + 4j -grad(f, holomorphic=True)(z) -``` - -All the `holomorphic=True` promise does is disable the error when the output is complex-valued. We can still write `holomorphic=True` when the function isn't holomorphic, but the answer we get out won't represent the full Jacobian. Instead, it'll be the Jacobian of the function where we just discard the imaginary part of the output: - -```{code-cell} -def f(z): - return jnp.conjugate(z) - -z = 3. + 4j -grad(f, holomorphic=True)(z) # f is not actually holomorphic! -``` - -There are some useful upshots for how {func}`jax.grad` works here: - -1. We can use {func}`jax.grad` on holomorphic $\mathbb{C} \to \mathbb{C}$ functions. -2. We can use {func}`jax.grad` to optimize $f : \mathbb{C} \to \mathbb{R}$ functions, like real-valued loss functions of complex parameters `x`, by taking steps in the direction of the conjugate of `grad(f)(x)`. -3. If we have an $\mathbb{R} \to \mathbb{R}$ function that just happens to use some complex-valued operations internally (some of which must be non-holomorphic, e.g. FFTs used in convolutions) then {func}`jax.grad` still works and we get the same result that an implementation using only real values would have given. - -In any case, JVPs and VJPs are always unambiguous. And if we wanted to compute the full Jacobian matrix of a non-holomorphic $\mathbb{C} \to \mathbb{C}$ function, we can do it with JVPs or VJPs! - - -You should expect complex numbers to work everywhere in JAX. Here's differentiating through a Cholesky decomposition of a complex matrix: - -```{code-cell} -A = jnp.array([[5., 2.+3j, 5j], - [2.-3j, 7., 1.+7j], - [-5j, 1.-7j, 12.]]) - -def f(X): - L = jnp.linalg.cholesky(X) - return jnp.sum((L - jnp.sin(L))**2) - -grad(f, holomorphic=True)(A) -``` - -(advanced-autodiff-custom-derivative-rules)= -## Custom derivative rules for JAX-transformable Python functions - -There are two ways to define differentiation rules in JAX: - -1. Using {func}`jax.custom_jvp` and {func}`jax.custom_vjp` to define custom differentiation rules for Python functions that are already JAX-transformable; and -2. Defining new `core.Primitive` instances along with all their transformation rules, for example to call into functions from other systems like solvers, simulators, or general numerical computing systems. - -This notebook is about #1. To read instead about #2, refer to the [notebook on adding primitives](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html). - - -### TL;DR: Custom JVPs with {func}`jax.custom_jvp` - -```{code-cell} -from jax import custom_jvp - -@custom_jvp -def f(x, y): - return jnp.sin(x) * y - -@f.defjvp -def f_jvp(primals, tangents): - x, y = primals - x_dot, y_dot = tangents - primal_out = f(x, y) - tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot - return primal_out, tangent_out -``` - -```{code-cell} -print(f(2., 3.)) -y, y_dot = jvp(f, (2., 3.), (1., 0.)) -print(y) -print(y_dot) -print(grad(f)(2., 3.)) -``` - -```{code-cell} -# Equivalent alternative using the `defjvps` convenience wrapper - -@custom_jvp -def f(x, y): - return jnp.sin(x) * y - -f.defjvps(lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y, - lambda y_dot, primal_out, x, y: jnp.sin(x) * y_dot) -``` - -```{code-cell} -print(f(2., 3.)) -y, y_dot = jvp(f, (2., 3.), (1., 0.)) -print(y) -print(y_dot) -print(grad(f)(2., 3.)) -``` - -### TL;DR: Custom VJPs with `jax.custom_vjp` - -```{code-cell} -from jax import custom_vjp - -@custom_vjp -def f(x, y): - return jnp.sin(x) * y - -def f_fwd(x, y): -# Returns primal output and residuals to be used in backward pass by `f_bwd`. - return f(x, y), (jnp.cos(x), jnp.sin(x), y) - -def f_bwd(res, g): - cos_x, sin_x, y = res # Gets residuals computed in `f_fwd` - return (cos_x * g * y, sin_x * g) - -f.defvjp(f_fwd, f_bwd) -``` - -```{code-cell} -print(grad(f)(2., 3.)) -``` - -### Example problems - -To get an idea of what problems {func}`jax.custom_jvp` and {func}`jax.custom_vjp` are meant to solve, let's go over a few examples. A more thorough introduction to the {func}`jax.custom_jvp` and {func}`jax.custom_vjp` APIs is in the next section. - - -#### Example: Numerical stability - -One application of {func}`jax.custom_jvp` is to improve the numerical stability of differentiation. - -Say we want to write a function called `log1pexp`, which computes $x \mapsto \log ( 1 + e^x )$. We can write that using `jax.numpy`: - -```{code-cell} -def log1pexp(x): - return jnp.log(1. + jnp.exp(x)) - -log1pexp(3.) -``` - -Since it's written in terms of `jax.numpy`, it's JAX-transformable: - -```{code-cell} -print(jit(log1pexp)(3.)) -print(jit(grad(log1pexp))(3.)) -print(vmap(jit(grad(log1pexp)))(jnp.arange(3.))) -``` - -But there's a numerical stability problem lurking here: - -```{code-cell} -print(grad(log1pexp)(100.)) -``` - -That doesn't seem right! After all, the derivative of $x \mapsto \log (1 + e^x)$ is $x \mapsto \frac{e^x}{1 + e^x}$, and so for large values of $x$ we'd expect the value to be about 1. - -We can get a bit more insight into what's going on by looking at the jaxpr for the gradient computation: - -```{code-cell} -from jax import make_jaxpr - -make_jaxpr(grad(log1pexp))(100.) -``` - -Stepping through how the jaxpr would be evaluated, notice that the last line would involve multiplying values that floating point math will round to 0 and $\infty$, respectively, which is never a good idea. That is, we're effectively evaluating `lambda x: (1 / (1 + jnp.exp(x))) * jnp.exp(x)` for large `x`, which effectively turns into `0. * jnp.inf`. - -Instead of generating such large and small values, hoping for a cancellation that floats can't always provide, we'd rather just express the derivative function as a more numerically stable program. In particular, we can write a program that more closely evaluates the equal mathematical expression $1 - \frac{1}{1 + e^x}$, with no cancellation in sight. - -This problem is interesting because even though our definition of `log1pexp` could already be JAX-differentiated (and transformed with {func}`jax.jit`, {func}`jax.vmap`, ...), we're not happy with the result of applying standard autodiff rules to the primitives comprising `log1pexp` and composing the result. Instead, we'd like to specify how the whole function `log1pexp` should be differentiated, as a unit, and thus arrange those exponentials better. - -This is one application of custom derivative rules for Python functions that are already JAX transformable: specifying how a composite function should be differentiated, while still using its original Python definition for other transformations (like {func}`jax.jit`, {func}`jax.vmap`, ...). - -Here's a solution using {func}`jax.custom_jvp`: - -```{code-cell} -@custom_jvp -def log1pexp(x): - return jnp.log(1. + jnp.exp(x)) - -@log1pexp.defjvp -def log1pexp_jvp(primals, tangents): - x, = primals - x_dot, = tangents - ans = log1pexp(x) - ans_dot = (1 - 1/(1 + jnp.exp(x))) * x_dot - return ans, ans_dot -``` - -```{code-cell} -print(grad(log1pexp)(100.)) -``` - -```{code-cell} -print(jit(log1pexp)(3.)) -print(jit(grad(log1pexp))(3.)) -print(vmap(jit(grad(log1pexp)))(jnp.arange(3.))) -``` - -Here's a `defjvps` convenience wrapper to express the same thing: - -```{code-cell} -@custom_jvp -def log1pexp(x): - return jnp.log(1. + jnp.exp(x)) - -log1pexp.defjvps(lambda t, ans, x: (1 - 1/(1 + jnp.exp(x))) * t) -``` - -```{code-cell} -print(grad(log1pexp)(100.)) -print(jit(log1pexp)(3.)) -print(jit(grad(log1pexp))(3.)) -print(vmap(jit(grad(log1pexp)))(jnp.arange(3.))) -``` - -#### Example: Enforcing a differentiation convention - -A related application is to enforce a differentiation convention, perhaps at a boundary. - -Consider the function $f : \mathbb{R}_+ \to \mathbb{R}_+$ with $f(x) = \frac{x}{1 + \sqrt{x}}$, where we take $\mathbb{R}_+ = [0, \infty)$. We might implement $f$ as a program like this: - -```{code-cell} -def f(x): - return x / (1 + jnp.sqrt(x)) -``` - -As a mathematical function on $\mathbb{R}$ (the full real line), $f$ is not differentiable at zero (because the limit defining the derivative doesn't exist from the left). Correspondingly, autodiff produces a `nan` value: - -```{code-cell} -print(grad(f)(0.)) -``` - -But mathematically if we think of $f$ as a function on $\mathbb{R}_+$ then it is differentiable at 0 [Rudin's Principles of Mathematical Analysis Definition 5.1, or Tao's Analysis I 3rd ed. Definition 10.1.1 and Example 10.1.6]. Alternatively, we might say as a convention we want to consider the directional derivative from the right. So there is a sensible value for the Python function `grad(f)` to return at `0.0`, namely `1.0`. By default, JAX's machinery for differentiation assumes all functions are defined over $\mathbb{R}$ and thus doesn't produce `1.0` here. - -We can use a custom JVP rule! In particular, we can define the JVP rule in terms of the derivative function $x \mapsto \frac{\sqrt{x} + 2}{2(\sqrt{x} + 1)^2}$ on $\mathbb{R}_+$, - -```{code-cell} -@custom_jvp -def f(x): - return x / (1 + jnp.sqrt(x)) - -@f.defjvp -def f_jvp(primals, tangents): - x, = primals - x_dot, = tangents - ans = f(x) - ans_dot = ((jnp.sqrt(x) + 2) / (2 * (jnp.sqrt(x) + 1)**2)) * x_dot - return ans, ans_dot -``` - -```{code-cell} -print(grad(f)(0.)) -``` - -Here's the convenience wrapper version: - -```{code-cell} -@custom_jvp -def f(x): - return x / (1 + jnp.sqrt(x)) - -f.defjvps(lambda t, ans, x: ((jnp.sqrt(x) + 2) / (2 * (jnp.sqrt(x) + 1)**2)) * t) -``` - -```{code-cell} -print(grad(f)(0.)) -``` - -#### Example: Gradient clipping - -While in some cases we want to express a mathematical differentiation computation, in other cases we may even want to take a step away from mathematics to adjust the computation autodiff performs. One canonical example is reverse-mode gradient clipping. - -For gradient clipping, we can use {func}`jnp.clip` together with a {func}`jax.custom_vjp` reverse-mode-only rule: - -```{code-cell} -from functools import partial - -@custom_vjp -def clip_gradient(lo, hi, x): - return x # identity function - -def clip_gradient_fwd(lo, hi, x): - return x, (lo, hi) # save bounds as residuals - -def clip_gradient_bwd(res, g): - lo, hi = res - return (None, None, jnp.clip(g, lo, hi)) # use None to indicate zero cotangents for lo and hi - -clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd) -``` - -```{code-cell} -import matplotlib.pyplot as plt - -t = jnp.linspace(0, 10, 1000) - -plt.plot(jnp.sin(t)) -plt.plot(vmap(grad(jnp.sin))(t)) -``` - -```{code-cell} -def clip_sin(x): - x = clip_gradient(-0.75, 0.75, x) - return jnp.sin(x) - -plt.plot(clip_sin(t)) -plt.plot(vmap(grad(clip_sin))(t)) -``` - -#### Example: Python debugging - -Another application that is motivated by development workflow rather than numerics is to set a `pdb` debugger trace in the backward pass of reverse-mode autodiff. - -When trying to track down the source of a `nan` runtime error, or just examine carefully the cotangent (gradient) values being propagated, it can be useful to insert a debugger at a point in the backward pass that corresponds to a specific point in the primal computation. You can do that with {func}`jax.custom_vjp`. - -We'll defer an example until the next section. - - - -#### Example: Implicit function differentiation of iterative implementations - -This example gets pretty deep in the mathematical weeds! - -Another application for {func}`jax.custom_vjp` is reverse-mode differentiation of functions that are JAX-transformable (by {func}`jax.jit`, {func}`jax.vmap`, ...) but not efficiently JAX-differentiable for some reason, perhaps because they involve {func}`jax.lax.while_loop`. (It's not possible to produce an XLA HLO program that efficiently computes the reverse-mode derivative of an XLA HLO While loop because that would require a program with unbounded memory use, which isn't possible to express in XLA HLO, at least without "side-effecting" interactions through infeed/outfeed.) - -For example, consider this `fixed_point` routine which computes a fixed point by iteratively applying a function in a `while_loop`: - -```{code-cell} -from jax.lax import while_loop - -def fixed_point(f, a, x_guess): - def cond_fun(carry): - x_prev, x = carry - return jnp.abs(x_prev - x) > 1e-6 - - def body_fun(carry): - _, x = carry - return x, f(a, x) - - _, x_star = while_loop(cond_fun, body_fun, (x_guess, f(a, x_guess))) - return x_star -``` - -This is an iterative procedure for numerically solving the equation $x = f(a, x)$ for $x$, by iterating $x_{t+1} = f(a, x_t)$ until $x_{t+1}$ is sufficiently close to $x_t$. The result $x^*$ depends on the parameters $a$, and so we can think of there being a function $a \mapsto x^*(a)$ that is implicitly defined by equation $x = f(a, x)$. - -We can use `fixed_point` to run iterative procedures to convergence, for example running Newton's method to calculate square roots while only executing adds, multiplies, and divides: - -```{code-cell} -def newton_sqrt(a): - update = lambda a, x: 0.5 * (x + a / x) - return fixed_point(update, a, a) -``` - -```{code-cell} -print(newton_sqrt(2.)) -``` - -We can {func}`jax.vmap` or {func}`jax.jit` the function as well: - -```{code-cell} -print(jit(vmap(newton_sqrt))(jnp.array([1., 2., 3., 4.]))) -``` - -We can't apply reverse-mode automatic differentiation because of the `while_loop`, but it turns out we wouldn't want to anyway: instead of differentiating through the implementation of `fixed_point` and all its iterations, we can exploit the mathematical structure to do something that is much more memory-efficient (and FLOP-efficient in this case, too!). We can instead use the implicit function theorem [Prop A.25 of Bertsekas's Nonlinear Programming, 2nd ed.], which guarantees (under some conditions) the existence of the mathematical objects we're about to use. In essence, we linearize the solution and solve those linear equations iteratively to compute the derivatives we want. - -Consider again the equation $x = f(a, x)$ and the function $x^*$. We want to evaluate vector-Jacobian products like $v^\mathsf{T} \mapsto v^\mathsf{T} \partial x^*(a_0)$. - -At least in an open neighborhood around the point $a_0$ at which we want to differentiate, let's assume that the equation $x^*(a) = f(a, x^*(a))$ holds for all $a$. Since the two sides are equal as functions of $a$, their derivatives must be equal as well, so let's differentiate both sides: - -$\qquad \partial x^*(a) = \partial_0 f(a, x^*(a)) + \partial_1 f(a, x^*(a)) \partial x^*(a)$. - -Setting $A = \partial_1 f(a_0, x^*(a_0))$ and $B = \partial_0 f(a_0, x^*(a_0))$, we can write the quantity we're after more simply as: - -$\qquad \partial x^*(a_0) = B + A \partial x^*(a_0)$, - -or, by rearranging, - -$\qquad \partial x^*(a_0) = (I - A)^{-1} B$. - -That means we can evaluate vector-Jacobian products, such as: - -$\qquad v^\mathsf{T} \partial x^*(a_0) = v^\mathsf{T} (I - A)^{-1} B = w^\mathsf{T} B$, - -where $w^\mathsf{T} = v^\mathsf{T} (I - A)^{-1}$, or equivalently $w^\mathsf{T} = v^\mathsf{T} + w^\mathsf{T} A$, or equivalently $w^\mathsf{T}$ is the fixed point of the map $u^\mathsf{T} \mapsto v^\mathsf{T} + u^\mathsf{T} A$. That last characterization gives us a way to write the VJP for `fixed_point` in terms of a call to `fixed_point`! Moreover, after expanding $A$ and $B$ back out, you can conclude you need only to evaluate VJPs of $f$ at $(a_0, x^*(a_0))$. - -Here's the upshot: - -```{code-cell} -@partial(custom_vjp, nondiff_argnums=(0,)) -def fixed_point(f, a, x_guess): - def cond_fun(carry): - x_prev, x = carry - return jnp.abs(x_prev - x) > 1e-6 - - def body_fun(carry): - _, x = carry - return x, f(a, x) - - _, x_star = while_loop(cond_fun, body_fun, (x_guess, f(a, x_guess))) - return x_star - -def fixed_point_fwd(f, a, x_init): - x_star = fixed_point(f, a, x_init) - return x_star, (a, x_star) - -def fixed_point_rev(f, res, x_star_bar): - a, x_star = res - _, vjp_a = vjp(lambda a: f(a, x_star), a) - a_bar, = vjp_a(fixed_point(partial(rev_iter, f), - (a, x_star, x_star_bar), - x_star_bar)) - return a_bar, jnp.zeros_like(x_star) - -def rev_iter(f, packed, u): - a, x_star, x_star_bar = packed - _, vjp_x = vjp(lambda x: f(a, x), x_star) - return x_star_bar + vjp_x(u)[0] - -fixed_point.defvjp(fixed_point_fwd, fixed_point_rev) -``` - -```{code-cell} -print(newton_sqrt(2.)) -``` - -```{code-cell} -print(grad(newton_sqrt)(2.)) -print(grad(grad(newton_sqrt))(2.)) -``` - -We can check our answers by differentiating {func}`jnp.sqrt`, which uses a totally different implementation: - -```{code-cell} -print(grad(jnp.sqrt)(2.)) -print(grad(grad(jnp.sqrt))(2.)) -``` - -A limitation to this approach is that the argument `f` can't close over any values involved in differentiation. That is, you might notice that we kept the parameter `a` explicit in the argument list of `fixed_point`. For this use case, consider using the low-level primitive `lax.custom_root`, which allows for derivatives in closed-over variables with custom root-finding functions. - - -### Basic usage of `jax.custom_jvp` and `jax.custom_vjp` APIs - -#### Use `jax.custom_jvp` to define forward-mode (and, indirectly, reverse-mode) rules - -Here's a canonical basic example of using {func}`jax.custom_jvp`, where the comments use -[Haskell-like type signatures](https://wiki.haskell.org/Type_signature): - -```{code-cell} -# f :: a -> b -@custom_jvp -def f(x): - return jnp.sin(x) - -# f_jvp :: (a, T a) -> (b, T b) -def f_jvp(primals, tangents): - x, = primals - t, = tangents - return f(x), jnp.cos(x) * t - -f.defjvp(f_jvp) -``` - -```{code-cell} -print(f(3.)) - -y, y_dot = jvp(f, (3.,), (1.,)) -print(y) -print(y_dot) -``` - -In other words, we start with a primal function `f` that takes inputs of type `a` and produces outputs of type `b`. We associate with it a JVP rule function `f_jvp` that takes a pair of inputs representing the primal inputs of type `a` and the corresponding tangent inputs of type `T a`, and produces a pair of outputs representing the primal outputs of type `b` and tangent outputs of type `T b`. The tangent outputs should be a linear function of the tangent inputs. - -You can also use `f.defjvp` as a decorator, as in - -```python -@custom_jvp -def f(x): - ... - -@f.defjvp -def f_jvp(primals, tangents): - ... -``` - -Even though we defined only a JVP rule and no VJP rule, we can use both forward- and reverse-mode differentiation on `f`. JAX will automatically transpose the linear computation on tangent values from our custom JVP rule, computing the VJP as efficiently as if we had written the rule by hand: - -```{code-cell} -print(grad(f)(3.)) -print(grad(grad(f))(3.)) -``` - -For automatic transposition to work, the JVP rule's output tangents must be linear as a function of the input tangents. Otherwise a transposition error is raised. - -Multiple arguments work like this: - -```{code-cell} -@custom_jvp -def f(x, y): - return x ** 2 * y - -@f.defjvp -def f_jvp(primals, tangents): - x, y = primals - x_dot, y_dot = tangents - primal_out = f(x, y) - tangent_out = 2 * x * y * x_dot + x ** 2 * y_dot - return primal_out, tangent_out -``` - -```{code-cell} -print(grad(f)(2., 3.)) -``` - -The `defjvps` convenience wrapper lets us define a JVP for each argument separately, and the results are computed separately then summed: - -```{code-cell} -@custom_jvp -def f(x): - return jnp.sin(x) - -f.defjvps(lambda t, ans, x: jnp.cos(x) * t) -``` - -```{code-cell} -print(grad(f)(3.)) -``` - -Here's a `defjvps` example with multiple arguments: - -```{code-cell} -@custom_jvp -def f(x, y): - return x ** 2 * y - -f.defjvps(lambda x_dot, primal_out, x, y: 2 * x * y * x_dot, - lambda y_dot, primal_out, x, y: x ** 2 * y_dot) -``` - -```{code-cell} -print(grad(f)(2., 3.)) -print(grad(f, 0)(2., 3.)) # same as above -print(grad(f, 1)(2., 3.)) -``` - -As a shorthand, with `defjvps` you can pass a `None` value to indicate that the JVP for a particular argument is zero: - -```{code-cell} -@custom_jvp -def f(x, y): - return x ** 2 * y - -f.defjvps(lambda x_dot, primal_out, x, y: 2 * x * y * x_dot, - None) -``` - -```{code-cell} -print(grad(f)(2., 3.)) -print(grad(f, 0)(2., 3.)) # same as above -print(grad(f, 1)(2., 3.)) -``` - -Calling a {func}`jax.custom_jvp` function with keyword arguments, or writing a {func}`jax.custom_jvp` function definition with default arguments, are both allowed so long as they can be unambiguously mapped to positional arguments based on the function signature retrieved by the standard library `inspect.signature` mechanism. - -When you're not performing differentiation, the function `f` is called just as if it weren't decorated by {func}`jax.custom_jvp`: - -```{code-cell} -@custom_jvp -def f(x): - print('called f!') # a harmless side-effect - return jnp.sin(x) - -@f.defjvp -def f_jvp(primals, tangents): - print('called f_jvp!') # a harmless side-effect - x, = primals - t, = tangents - return f(x), jnp.cos(x) * t -``` - -```{code-cell} -print(f(3.)) -``` - -```{code-cell} -print(vmap(f)(jnp.arange(3.))) -print(jit(f)(3.)) -``` - -The custom JVP rule is invoked during differentiation, whether forward or reverse: - -```{code-cell} -y, y_dot = jvp(f, (3.,), (1.,)) -print(y_dot) -``` - -```{code-cell} -print(grad(f)(3.)) -``` - -Notice that `f_jvp` calls `f` to compute the primal outputs. In the context of higher-order differentiation, each application of a differentiation transform will use the custom JVP rule if and only if the rule calls the original `f` to compute the primal outputs. (This represents a kind of fundamental tradeoff, where we can't make use of intermediate values from the evaluation of `f` in our rule _and also_ have the rule apply in all orders of higher-order differentiation.) - -```{code-cell} -grad(grad(f))(3.) -``` - -You can use Python control flow with {func}`jax.custom_jvp`: - -```{code-cell} -@custom_jvp -def f(x): - if x > 0: - return jnp.sin(x) - else: - return jnp.cos(x) - -@f.defjvp -def f_jvp(primals, tangents): - x, = primals - x_dot, = tangents - ans = f(x) - if x > 0: - return ans, 2 * x_dot - else: - return ans, 3 * x_dot -``` - -```{code-cell} -print(grad(f)(1.)) -print(grad(f)(-1.)) -``` - -#### Use `jax.custom_vjp` to define custom reverse-mode-only rules - -While {func}`jax.custom_jvp` suffices for controlling both forward- and, via JAX's automatic transposition, reverse-mode differentiation behavior, in some cases we may want to directly control a VJP rule, for example in the latter two example problems presented above. We can do that with {func}`jax.custom_vjp`: - -```{code-cell} -from jax import custom_vjp - -# f :: a -> b -@custom_vjp -def f(x): - return jnp.sin(x) - -# f_fwd :: a -> (b, c) -def f_fwd(x): - return f(x), jnp.cos(x) - -# f_bwd :: (c, CT b) -> CT a -def f_bwd(cos_x, y_bar): - return (cos_x * y_bar,) - -f.defvjp(f_fwd, f_bwd) -``` - -```{code-cell} -print(f(3.)) -print(grad(f)(3.)) -``` - -In other words, we again start with a primal function `f` that takes inputs of type `a` and produces outputs of type `b`. We associate with it two functions, `f_fwd` and `f_bwd`, which describe how to perform the forward- and backward-passes of reverse-mode autodiff, respectively. - -The function `f_fwd` describes the forward pass, not only the primal computation but also what values to save for use on the backward pass. Its input signature is just like that of the primal function `f`, in that it takes a primal input of type `a`. But as output it produces a pair, where the first element is the primal output `b` and the second element is any "residual" data of type `c` to be stored for use by the backward pass. (This second output is analogous to [PyTorch's save_for_backward mechanism](https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html).) - -The function `f_bwd` describes the backward pass. It takes two inputs, where the first is the residual data of type `c` produced by `f_fwd` and the second is the output cotangents of type `CT b` corresponding to the output of the primal function. It produces an output of type `CT a` representing the cotangents corresponding to the input of the primal function. In particular, the output of `f_bwd` must be a sequence (e.g. a tuple) of length equal to the number of arguments to the primal function. - -So multiple arguments work like this: - -```{code-cell} -@custom_vjp -def f(x, y): - return jnp.sin(x) * y - -def f_fwd(x, y): - return f(x, y), (jnp.cos(x), jnp.sin(x), y) - -def f_bwd(res, g): - cos_x, sin_x, y = res - return (cos_x * g * y, sin_x * g) - -f.defvjp(f_fwd, f_bwd) -``` - -```{code-cell} -print(grad(f)(2., 3.)) -``` - -Calling a {func}`jax.custom_vjp` function with keyword arguments, or writing a {func}`jax.custom_vjp` function definition with default arguments, are both allowed so long as they can be unambiguously mapped to positional arguments based on the function signature retrieved by the standard library `inspect.signature` mechanism. - -As with {func}`jax.custom_jvp`, the custom VJP rule composed of `f_fwd` and `f_bwd` is not invoked if differentiation is not applied. If the function is evaluated, or transformed with {func}`jax.jit`, {func}`jax.vmap`, or other non-differentiation transformations, then only `f` is called. - -```{code-cell} -@custom_vjp -def f(x): - print("called f!") - return jnp.sin(x) - -def f_fwd(x): - print("called f_fwd!") - return f(x), jnp.cos(x) - -def f_bwd(cos_x, y_bar): - print("called f_bwd!") - return (cos_x * y_bar,) - -f.defvjp(f_fwd, f_bwd) -``` - -```{code-cell} -print(f(3.)) -``` - -```{code-cell} -print(grad(f)(3.)) -``` - -```{code-cell} -y, f_vjp = vjp(f, 3.) -print(y) -``` - -```{code-cell} -print(f_vjp(1.)) -``` - -**Forward-mode autodiff cannot be used on the** {func}`jax.custom_vjp` **function** and will raise an error: - -```{code-cell} -:tags: [raises-exception] - -from jax import jvp - -try: - jvp(f, (3.,), (1.,)) -except TypeError as e: - print('ERROR! {}'.format(e)) -``` - -If you want to use both forward- and reverse-mode, use {func}`jax.custom_jvp` instead. - -We can use {func}`jax.custom_vjp` together with `pdb` to insert a debugger trace in the backward pass: - -```{code-cell} -import pdb - -@custom_vjp -def debug(x): - return x # acts like identity - -def debug_fwd(x): - return x, x - -def debug_bwd(x, g): - import pdb; pdb.set_trace() - return g - -debug.defvjp(debug_fwd, debug_bwd) -``` - -```{code-cell} -def foo(x): - y = x ** 2 - y = debug(y) # insert pdb in corresponding backward pass step - return jnp.sin(y) -``` - -```python -jax.grad(foo)(3.) - -> (12)debug_bwd() --> return g -(Pdb) p x -Array(9., dtype=float32) -(Pdb) p g -Array(-0.91113025, dtype=float32) -(Pdb) q -``` - - -### More features and details - -#### Working with `list` / `tuple` / `dict` containers (and other pytrees) - -You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any [pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) are permissible, so long as their structures are consistent according to the type constraints. - -Here's a contrived example with {func}`jax.custom_jvp`: - -```{code-cell} -from collections import namedtuple -Point = namedtuple("Point", ["x", "y"]) - -@custom_jvp -def f(pt): - x, y = pt.x, pt.y - return {'a': x ** 2, - 'b': (jnp.sin(x), jnp.cos(y))} - -@f.defjvp -def f_jvp(primals, tangents): - pt, = primals - pt_dot, = tangents - ans = f(pt) - ans_dot = {'a': 2 * pt.x * pt_dot.x, - 'b': (jnp.cos(pt.x) * pt_dot.x, -jnp.sin(pt.y) * pt_dot.y)} - return ans, ans_dot - -def fun(pt): - dct = f(pt) - return dct['a'] + dct['b'][0] -``` - -```{code-cell} -pt = Point(1., 2.) - -print(f(pt)) -``` - -```{code-cell} -print(grad(fun)(pt)) -``` - -And an analogous contrived example with {func}`jax.custom_vjp`: - -```{code-cell} -@custom_vjp -def f(pt): - x, y = pt.x, pt.y - return {'a': x ** 2, - 'b': (jnp.sin(x), jnp.cos(y))} - -def f_fwd(pt): - return f(pt), pt - -def f_bwd(pt, g): - a_bar, (b0_bar, b1_bar) = g['a'], g['b'] - x_bar = 2 * pt.x * a_bar + jnp.cos(pt.x) * b0_bar - y_bar = -jnp.sin(pt.y) * b1_bar - return (Point(x_bar, y_bar),) - -f.defvjp(f_fwd, f_bwd) - -def fun(pt): - dct = f(pt) - return dct['a'] + dct['b'][0] -``` - -```{code-cell} -pt = Point(1., 2.) - -print(f(pt)) -``` - -```{code-cell} -print(grad(fun)(pt)) -``` - -#### Handling non-differentiable arguments - -Some use cases, like the final example problem, call for non-differentiable arguments like function-valued arguments to be passed to functions with custom differentiation rules, and for those arguments to also be passed to the rules themselves. In the case of `fixed_point`, the function argument `f` was such a non-differentiable argument. A similar situation arises with `jax.experimental.odeint`. - -##### `jax.custom_jvp` with `nondiff_argnums` - -Use the optional `nondiff_argnums` parameter to {func}`jax.custom_jvp` to indicate arguments like these. Here's an example with {func}`jax.custom_jvp`: - -```{code-cell} -from functools import partial - -@partial(custom_jvp, nondiff_argnums=(0,)) -def app(f, x): - return f(x) - -@app.defjvp -def app_jvp(f, primals, tangents): - x, = primals - x_dot, = tangents - return f(x), 2. * x_dot -``` - -```{code-cell} -print(app(lambda x: x ** 3, 3.)) -``` - -```{code-cell} -print(grad(app, 1)(lambda x: x ** 3, 3.)) -``` - -Notice the gotcha here: no matter where in the argument list these parameters appear, they're placed at the *start* of the signature of the corresponding JVP rule. Here's another example: - -```{code-cell} -@partial(custom_jvp, nondiff_argnums=(0, 2)) -def app2(f, x, g): - return f(g((x))) - -@app2.defjvp -def app2_jvp(f, g, primals, tangents): - x, = primals - x_dot, = tangents - return f(g(x)), 3. * x_dot -``` - -```{code-cell} -print(app2(lambda x: x ** 3, 3., lambda y: 5 * y)) -``` - -```{code-cell} -print(grad(app2, 1)(lambda x: x ** 3, 3., lambda y: 5 * y)) -``` - -##### `jax.custom_vjp` with `nondiff_argnums` - -A similar option exists for {func}`jax.custom_vjp`, and, similarly, the convention is that the non-differentiable arguments are passed as the first arguments to the `_bwd` rule, no matter where they appear in the signature of the original function. The signature of the `_fwd` rule remains unchanged - it is the same as the signature of the primal function. Here's an example: - -```{code-cell} -@partial(custom_vjp, nondiff_argnums=(0,)) -def app(f, x): - return f(x) - -def app_fwd(f, x): - return f(x), x - -def app_bwd(f, x, g): - return (5 * g,) - -app.defvjp(app_fwd, app_bwd) -``` - -```{code-cell} -print(app(lambda x: x ** 2, 4.)) -``` - -```{code-cell} -print(grad(app, 1)(lambda x: x ** 2, 4.)) -``` - -Refer to `fixed_point` above for another usage example. - -**You don't need to use** `nondiff_argnums` **with array-valued arguments**, such as, for example, ones with the integer dtype. Instead, `nondiff_argnums` should only be used for argument values that don't correspond to JAX types (essentially don't correspond to array types), like Python callables or strings. If JAX detects that an argument indicated by `nondiff_argnums` contains a JAX Tracer, then an error is raised. The `clip_gradient` function above is a good example of not using `nondiff_argnums` for integer-dtype array arguments. - -## Next steps - -There's a whole world of other autodiff tricks and functionality out there. Topics that weren't covered in this tutorial but can be worth pursuing include: - - - Gauss-Newton Vector Products, linearizing once - - Custom VJPs and JVPs - - Efficient derivatives at fixed-points - - Estimating the trace of a Hessian using random Hessian-vector products - - Forward-mode autodiff using only reverse-mode autodiff - - Taking derivatives with respect to custom data types - - Checkpointing (binomial checkpointing for efficient reverse-mode, not model snapshotting) - - Optimizing VJPs with Jacobian pre-accumulation diff --git a/docs/advanced_autodiff.md b/docs/advanced_autodiff.md new file mode 100644 index 000000000000..43b62b0e0d5c --- /dev/null +++ b/docs/advanced_autodiff.md @@ -0,0 +1,11 @@ +# Advanced Automatic Differentiation + +```{toctree} +:caption: Advanced automatic differentiation +:maxdepth: 1 + +higher-order +jacobian-vector-products +complex-differentiation +notebooks/Custom_derivative_rules_for_Python_code +``` diff --git a/docs/advanced_guide.rst b/docs/advanced_guide.rst deleted file mode 100644 index db2e83ae2720..000000000000 --- a/docs/advanced_guide.rst +++ /dev/null @@ -1,33 +0,0 @@ -.. _advanced_guide: - -Advanced guides -=============== - -This section contains examples and tutorials on more advanced topics, -such as multi-core computation, automatic differentiation, and custom -operations. - -.. toctree:: - :caption: Parallel computation - :maxdepth: 1 - - notebooks/Distributed_arrays_and_automatic_parallelization - notebooks/explicit-sharding - notebooks/shard_map - multi_process - distributed_data_loading - -.. toctree:: - :caption: Automatic differentiation - :maxdepth: 1 - - notebooks/autodiff_cookbook - notebooks/Custom_derivative_rules_for_Python_code - notebooks/autodiff_remat - -.. toctree:: - :caption: Deep dives - :maxdepth: 1 - - notebooks/convolutions - xla_flags diff --git a/docs/advanced_guides.rst b/docs/advanced_guides.rst new file mode 100644 index 000000000000..4a7624e08262 --- /dev/null +++ b/docs/advanced_guides.rst @@ -0,0 +1,110 @@ +.. _advanced_guides: + +Resources and Advanced Guides +============================= + +This section contains examples and tutorials on more advanced topics, +such as multi-core computation, automatic differentiation, and custom +operations. + +.. toctree:: + :caption: Parallel computation + :maxdepth: 1 + + notebooks/Distributed_arrays_and_automatic_parallelization + notebooks/explicit-sharding + notebooks/shard_map + notebooks/layout + notebooks/host-offloading + multi_process + fault_tolerance + distributed_data_loading + notebooks/colocated-python + +.. toctree:: + :caption: Machine learning + :maxdepth: 1 + + the-training-cookbook + +.. toctree:: + :caption: Automatic differentiation + :maxdepth: 1 + + notebooks/autodiff_cookbook + notebooks/autodiff_remat + advanced_autodiff + +.. toctree:: + :maxdepth: 1 + :caption: Errors and debugging + + errors + debugging + debugging/index + transfer_guard + +.. toctree:: + :maxdepth: 1 + :caption: Pytrees + + custom_pytrees + +.. toctree:: + :maxdepth: 1 + :caption: Performance optimizations + + persistent_compilation_cache + buffer_donation + gpu_performance_tips + +.. toctree:: + :maxdepth: 1 + :caption: Performance benchmarking and profiling + + benchmarking + profiling + device_memory_profiling + +.. toctree:: + :caption: Non-functional programming + :maxdepth: 1 + + array_refs + +.. toctree:: + :caption: External Callbacks + :maxdepth: 1 + + external-callbacks + +.. toctree:: + :caption: FFI + :maxdepth: 1 + + ffi + +.. toctree:: + :caption: Modeling workflows + :maxdepth: 1 + + gradient-checkpointing + aot + export/index + +.. toctree:: + :caption: Example applications + :maxdepth: 1 + + notebooks/neural_network_with_tfds_data + notebooks/Neural_Network_and_Data_Loading + notebooks/vmapped_log_probs + +.. toctree:: + :caption: Deep dives + :maxdepth: 1 + + notebooks/convolutions + xla_flags + jax-primitives + jaxpr diff --git a/docs/aot.md b/docs/aot.md index 1fcf11ab945d..1870f8c55093 100644 --- a/docs/aot.md +++ b/docs/aot.md @@ -26,7 +26,7 @@ are arrays, JAX does the following in order: carries out this specialization by a process that we call _tracing_. During tracing, JAX stages the specialization of `F` to a jaxpr, which is a function in the [Jaxpr intermediate - language](https://jax.readthedocs.io/en/latest/jaxpr.html). + language](https://docs.jax.dev/en/latest/jaxpr.html). 2. **Lower** this specialized, staged-out computation to the XLA compiler's input language, StableHLO. @@ -49,7 +49,10 @@ some other features along the way. An example: >>> # Print the specialized, staged-out representation (as Jaxpr IR) >>> print(traced.jaxpr) -{ lambda ; a:i32[] b:i32[]. let c:i32[] = mul 2 a; d:i32[] = add c b in (d,) } +{ lambda ; a:i32[] b:i32[]. let + c:i32[] = mul 2:i32[] a + d:i32[] = add c b + in (d,) } >>> lowered = traced.lower() diff --git a/docs/api_compatibility.md b/docs/api_compatibility.md index 749c5907bc6b..985b2145c5c4 100644 --- a/docs/api_compatibility.md +++ b/docs/api_compatibility.md @@ -59,6 +59,11 @@ Any API or import path prefixed with an underscore is explicitly private, and may change without warning between JAX releases. We are working to move all private APIs into `jax._src` to make these expectations more clear. +### jaxlib +Any import path in the `jaxlib` package is considered private, and may change +without warning between releases. Some APIs defined in `jaxlib` have public +aliases in the `jax` package. + ### Legacy internal APIs In addition, there are several legacy modules that currently expose some private APIs without an underscore, including: @@ -91,7 +96,7 @@ guarantees of the main JAX package. If you have code that uses `jax.extend`, we would strongly recommend CI tests against JAX's nightly releases, so as to catch potential changes before they are released. -For details on `jax.extend`, see the [`jax.extend` module docuementation](https://jax.readthedocs.io/en/latest/jax.extend.html), or the design document, {ref}`jax-extend-jep`. +For details on `jax.extend`, see the [`jax.extend` module documentation](https://docs.jax.dev/en/latest/jax.extend.html), or the design document, {ref}`jax-extend-jep`. ## Numerics and randomness diff --git a/docs/array_refs.ipynb b/docs/array_refs.ipynb new file mode 100644 index 000000000000..a735e8573946 --- /dev/null +++ b/docs/array_refs.ipynb @@ -0,0 +1,655 @@ +{ + "cells": [ + { + "cell_type": "raw", + "id": "b32297a4", + "metadata": {}, + "source": [ + "---\n", + "Copyright 2025 The JAX Authors.\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "you may not use this file except in compliance with the License.\n", + "You may obtain a copy of the License at\n", + "\n", + " https://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "Unless required by applicable law or agreed to in writing, software\n", + "distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "See the License for the specific language governing permissions and\n", + "limitations under the License.\n", + "\n", + "---" + ] + }, + { + "cell_type": "markdown", + "id": "380b6c4e", + "metadata": {}, + "source": [ + "# `Ref`: mutable arrays for data plumbing and memory control\n", + "\n", + "JAX `Array`s are immutable, representing mathematical values. Immutability can\n", + "make code easier to reason about, and is useful for optimized compilation,\n", + "parallelization, rematerialization, and transformations like autodiff.\n", + "\n", + "But immutability is constraining too:\n", + "* **expressiveness** --- plumbing out intermediate data or maintaining state,\n", + " e.g. for normalization statistics or metrics, can feel heavyweight;\n", + "* **performance** --- it's more difficult to reason about performance, like\n", + " memory lifetimes and in-place updates.\n", + "\n", + "`Ref`s can help! They represent mutable arrays that can be read and written\n", + "in-place. These array references are compatible with JAX transformations, like\n", + "`jax.jit` and `jax.grad`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7909a3e2", + "metadata": {}, + "outputs": [], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "\n", + "x_ref = jax.new_ref(jnp.zeros(3)) # new array ref, with initial value [0., 0., 0.]\n", + "\n", + "@jax.jit\n", + "def f():\n", + " x_ref[1] += 1. # indexed add-update\n", + "\n", + "print(x_ref) # Ref([0., 0., 0.])\n", + "f()\n", + "f()\n", + "print(x_ref) # Ref([0., 2., 0.])" + ] + }, + { + "cell_type": "markdown", + "id": "667af649", + "metadata": {}, + "source": [ + "The indexing syntax follows NumPy's. For a `Ref` called `x_ref`, we can\n", + "read its entire value into an `Array` by writing `x_ref[...]`, and write its\n", + "entire value using `x_ref[...] = A` for some `Array`-valued expression `A`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f1824d27", + "metadata": {}, + "outputs": [], + "source": [ + "def g(x):\n", + " x_ref = jax.new_ref(0.)\n", + " x_ref[...] = jnp.sin(x)\n", + " return x_ref[...]\n", + "\n", + "print(jax.grad(g)(1.0)) # 0.54" + ] + }, + { + "cell_type": "markdown", + "id": "ff8dc074", + "metadata": {}, + "source": [ + "`Ref` is a distinct type from `Array`, and it comes with some important\n", + "constraints and limitations. In particular, indexed reading and writing is just\n", + "about the *only* thing you can do with an `Ref`. References can't be passed\n", + "where `Array`s are expected:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c2191893", + "metadata": {}, + "outputs": [], + "source": [ + "x_ref = jax.new_ref(1.0)\n", + "try:\n", + " jnp.sin(x_ref) # error! can't do math on refs\n", + "except Exception as e:\n", + " print(e)" + ] + }, + { + "cell_type": "markdown", + "id": "2ab77be5", + "metadata": {}, + "source": [ + "To do math, you need to read the ref's value first, like `jnp.sin(x_ref[...])`.\n", + "\n", + "So what _can_ you do with `Ref`? Read on for the details, and some useful\n", + "recipes.\n", + "\n", + "### API\n", + "\n", + "If you've ever used\n", + "[Pallas](https://docs.jax.dev/en/latest/pallas/quickstart.html), then `Ref`\n", + "should look familiar. A big difference is that you can create new `Ref`s\n", + "yourself directly using `jax.new_ref`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8cc852f2", + "metadata": {}, + "outputs": [], + "source": [ + "from jax import Array, Ref\n", + "\n", + "def array_ref(init_val: Array) -> Ref:\n", + " \"\"\"Introduce a new reference with given initial value.\"\"\"" + ] + }, + { + "cell_type": "markdown", + "id": "f4565356", + "metadata": {}, + "source": [ + "`jax.freeze` is its antithesis, invalidating the given ref (so that accessing it\n", + "afterwards is an error) and producing its final value:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "049048ed", + "metadata": {}, + "outputs": [], + "source": [ + "def freeze(ref: Ref) -> Array:\n", + " \"\"\"Invalidate given reference and produce its final value.\"\"\"" + ] + }, + { + "cell_type": "markdown", + "id": "62dd629d", + "metadata": {}, + "source": [ + "In between creating and destroying them, you can perform indexed reads and\n", + "writes on refs. You can read and write using the functions `jax.ref.get` and\n", + "`jax.ref.swap`, but usually you'd just use NumPy-style array indexing syntax:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "61b34483", + "metadata": {}, + "outputs": [], + "source": [ + "import types\n", + "Index = int | slice | Array | types.EllipsisType\n", + "Indexer = Index | tuple[Index, ...]\n", + "\n", + "def get(ref: Ref, idx: Indexer) -> Array:\n", + " \"\"\"Returns `ref[idx]` for NumPy-style indexer `idx`.\"\"\"\n", + "\n", + "def swap(ref: Ref, idx: Indexer, val: Array) -> Array:\n", + " \"\"\"Performs `newval, ref[idx] = ref[idx], val` and returns `newval`.\"\"\"" + ] + }, + { + "cell_type": "markdown", + "id": "a0ae59b8", + "metadata": {}, + "source": [ + "Here, `Indexer` can be any NumPy indexing expression:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b5f080d5", + "metadata": {}, + "outputs": [], + "source": [ + "x_ref = jax.new_ref(jnp.arange(12.).reshape(3, 4))\n", + "\n", + "# int indexing\n", + "row = x_ref[0]\n", + "x_ref[1] = row\n", + "\n", + "# tuple indexing\n", + "val = x_ref[1, 2]\n", + "x_ref[2, 3] = val\n", + "\n", + "# slice indexing\n", + "col = x_ref[:, 1]\n", + "x_ref[0, :3] = col\n", + "\n", + "# advanced int array indexing\n", + "vals = x_ref[jnp.array([0, 0, 1]), jnp.array([1, 2, 3])]\n", + "x_ref[jnp.array([1, 2, 1]), jnp.array([0, 0, 1])] = vals" + ] + }, + { + "cell_type": "markdown", + "id": "bd3edc22", + "metadata": {}, + "source": [ + "As with `Array`s, indexing mostly follows NumPy behavior, except for\n", + "out-of-bounds indexing which [behaves in the usual way for JAX\n", + "`Array`s](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#out-of-bounds-indexing).\n", + "\n", + "### Pure and impure functions\n", + "\n", + "A function that takes a ref as an argument (either explicitly or by lexical\n", + "closure) is considered _impure_. For example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f2841ccb", + "metadata": {}, + "outputs": [], + "source": [ + "# takes ref as an argument => impure\n", + "@jax.jit\n", + "def impure1(x_ref, y_ref):\n", + " x_ref[...] = y_ref[...]\n", + "\n", + "# closes over ref => impure\n", + "y_ref = jax.new_ref(0)\n", + "\n", + "@jax.jit\n", + "def impure2(x):\n", + " y_ref[...] = x" + ] + }, + { + "cell_type": "markdown", + "id": "c6b946f6", + "metadata": {}, + "source": [ + "If a function only uses refs internally, it is still considered _pure_. Purity\n", + "is in the eye of the caller. For example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cf8fb062", + "metadata": {}, + "outputs": [], + "source": [ + "# internal refs => still pure\n", + "@jax.jit\n", + "def pure1(x):\n", + " ref = jax.new_ref(x)\n", + " ref[...] = ref[...] + ref[...]\n", + " return ref[...]" + ] + }, + { + "cell_type": "markdown", + "id": "09ed144d", + "metadata": {}, + "source": [ + "Pure functions, even those that use refs internally, are familiar: for example,\n", + "they work with transformations like `jax.grad`, `jax.vmap`, `jax.shard_map`, and\n", + "others in the usual way.\n", + "\n", + "Impure functions are sequenced in Python program order.\n", + "\n", + "### Restrictions\n", + "\n", + "`Ref`s are second-class, in the sense that there are restrictions on their\n", + "use:\n", + "\n", + "* **Can't return refs** from `jit`\\-decorated functions or the bodies of\n", + " higher-order primitives like `jax.lax.scan`, `jax.lax.while_loop`, or\n", + " `jax.lax.cond`\n", + "* **Can't pass a ref as an argument more than once** to `jit`\\-decorated\n", + " functions or higher-order primitives\n", + "* **Can only `freeze` in creation scope**\n", + "* **No higher-order refs** (refs-to-refs)\n", + "\n", + "For example, these are errors:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "61a4e501", + "metadata": {}, + "outputs": [], + "source": [ + "x_ref = jax.new_ref(0.)\n", + "\n", + "# can't return refs\n", + "@jax.jit\n", + "def err1(x_ref):\n", + " x_ref[...] = 5.\n", + " return x_ref # error!\n", + "try:\n", + " err1(x_ref)\n", + "except Exception as e:\n", + " print(e)\n", + "\n", + "# can't pass a ref as an argument more than once\n", + "@jax.jit\n", + "def err2(x_ref, y_ref):\n", + " ...\n", + "try:\n", + " err2(x_ref, x_ref) # error!\n", + "except Exception as e:\n", + " print(e)\n", + "\n", + "# can't pass and close over the same ref\n", + "@jax.jit\n", + "def err3(y_ref):\n", + " y_ref[...] = x_ref[...]\n", + "try:\n", + " err3(x_ref) # error!\n", + "except Exception as e:\n", + " print(e)\n", + "\n", + "# can only freeze in creation scope\n", + "@jax.jit\n", + "def err4(x_ref):\n", + " jax.freeze(x_ref)\n", + "try:\n", + " err4(x_ref) # error!\n", + "except Exception as e:\n", + " print(e)" + ] + }, + { + "cell_type": "markdown", + "id": "fc360213", + "metadata": {}, + "source": [ + "These restrictions exist to rule out aliasing, where two refs might refer to the\n", + "same mutable memory, making programs harder to reason about and transform.\n", + "Weaker restrictions would also suffice, so some of these restrictions may be\n", + "lifted as we improve JAX's ability to verify that no aliasing is present.\n", + "\n", + "There are also restrictions stemming from undefined semantics, e.g. in the\n", + "presence of parallelism or rematerialization:\n", + "\n", + "* **Can't `vmap` or `shard_map` a function that closes over refs**\n", + "* **Can't apply `jax.remat`/`jax.checkpoint` to an impure function**\n", + "\n", + "For example, here are ways you can and can't use `vmap` with impure functions:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5701f96e", + "metadata": {}, + "outputs": [], + "source": [ + "# vmap over ref args is okay\n", + "def dist(x, y, out_ref):\n", + " assert x.ndim == y.ndim == 1\n", + " assert out_ref.ndim == 0\n", + " out_ref[...] = jnp.sum((x - y) ** 2)\n", + "\n", + "vecs = jnp.arange(12.).reshape(3, 4)\n", + "out_ref = jax.new_ref(jnp.zeros((3, 3)))\n", + "jax.vmap(jax.vmap(dist, (0, None, 0)), (None, 0, 0))(vecs, vecs, out_ref) # ok!\n", + "print(out_ref)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d94d08be", + "metadata": {}, + "outputs": [], + "source": [ + "# vmap with a closed-over ref is not\n", + "x_ref = jax.new_ref(0.)\n", + "\n", + "def err5(x):\n", + " x_ref[...] = x\n", + "\n", + "try:\n", + " jax.vmap(err5)(jnp.arange(3.)) # error!\n", + "except Exception as e:\n", + " print(e)" + ] + }, + { + "cell_type": "markdown", + "id": "33e635e5", + "metadata": {}, + "source": [ + "The latter is an error because it's not clear which value `x_ref` should be\n", + "after we run `jax.vmap(err5)`.\n", + "\n", + "### `Ref`s and automatic differentiation\n", + "\n", + "Autodiff can be applied to pure functions as before, even if they use array refs\n", + "internally. For example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7b5d32e1", + "metadata": {}, + "outputs": [], + "source": [ + "@jax.jit\n", + "def pure2(x):\n", + " ref = jax.new_ref(x)\n", + " ref[...] = ref[...] + ref[...]\n", + " return ref[...]\n", + "\n", + "print(jax.grad(pure1)(3.0)) # 2.0" + ] + }, + { + "cell_type": "markdown", + "id": "801c3b60", + "metadata": {}, + "source": [ + "Autodiff can also be applied to functions that take array refs as arguments, if\n", + "those arguments are only used for plumbing and not involved in differentiation:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c6cd5576", + "metadata": {}, + "outputs": [], + "source": [ + "# error\n", + "def err6(x, some_plumbing_ref):\n", + " y = x + x\n", + " some_plumbing_ref[...] += y\n", + " return y\n", + "\n", + "# fine\n", + "def foo(x, some_plumbing_ref):\n", + " y = x + x\n", + " some_plumbing_ref[...] += jax.lax.stop_gradient(y)\n", + " return y" + ] + }, + { + "cell_type": "markdown", + "id": "86622dd6", + "metadata": {}, + "source": [ + "You can combine plumbing refs with `custom_vjp` to plumb data out of the\n", + "backward pass of a differentiated function:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b1f17fbc", + "metadata": {}, + "outputs": [], + "source": [ + "# First, define the helper `stash_grads`:\n", + "\n", + "@jax.custom_vjp\n", + "def stash_grads(grads_ref, x):\n", + " return x\n", + "\n", + "def stash_grads_fwd(grads_ref, x):\n", + " return x, grads_ref\n", + "\n", + "def stash_grads_bwd(grads_ref, g):\n", + " grads_ref[...] = g\n", + " return None, g\n", + "\n", + "stash_grads.defvjp(stash_grads_fwd, stash_grads_bwd)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c0c5842e", + "metadata": {}, + "outputs": [], + "source": [ + "# Now, use `stash_grads` to stash intermediate gradients:\n", + "\n", + "def f(x, grads_ref):\n", + " x = jnp.sin(x)\n", + " x = stash_grads(grads_ref, x)\n", + " return x\n", + "\n", + "grads_ref = jax.new_ref(0.)\n", + "f(1., grads_ref)\n", + "print(grads_ref)" + ] + }, + { + "cell_type": "markdown", + "id": "4d8518e7", + "metadata": {}, + "source": [ + "Notice `stash_grads_fwd` is returning a `Ref` here. That's a special\n", + "allowance for `custom_vjp` fwd rules: it's really syntax for indicating which\n", + "ref arguments should be shared by both the fwd and bwd rules. So any refs\n", + "returned by a fwd rule must be arguments to that fwd rule.\n", + "\n", + "### `Ref`s and performance\n", + "\n", + "At the top level, when calling `jit`\\-decorated functions, `Ref`s obviate\n", + "the need for donation, since they are effectively always donated:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "64f3655c", + "metadata": {}, + "outputs": [], + "source": [ + "@jax.jit\n", + "def sin_inplace(x_ref):\n", + " x_ref[...] = jnp.sin(x_ref[...])\n", + "\n", + "x_ref = jax.new_ref(jnp.arange(3.))\n", + "print(x_ref.unsafe_buffer_pointer(), x_ref)\n", + "sin_inplace(x_ref)\n", + "print(x_ref.unsafe_buffer_pointer(), x_ref)" + ] + }, + { + "cell_type": "markdown", + "id": "a758adac", + "metadata": {}, + "source": [ + "Here `sin_inplace` operates in-place, updating the buffer backing `x_ref` so\n", + "that its address stays the same.\n", + "\n", + "Under a `jit`, you should expect array references to point to fixed buffer\n", + "addresses, and for indexed updates to be performed in-place.\n", + "\n", + "**Temporary caveat:** dispatch from Python to impure `jit`\\-compiled functions\n", + "that take `Ref` inputs is currently slower than dispatch to pure\n", + "`jit`\\-compiled functions, since it takes a less optimized path.\n", + "\n", + "### `foreach`, a new way to write `scan`\n", + "\n", + "As you may know, `jax.lax.scan` is a loop construct with a built-in fixed access\n", + "pattern for scanned-over inputs and outputs. The access pattern is built in for\n", + "autodiff reasons: if we were instead to slice into immutable inputs directly,\n", + "reverse-mode autodiff would end up creating one-hot gradients and summing them\n", + "up, which can be asymptotically inefficient. See [Sec 5.3.3 of the Dex\n", + "paper](https://arxiv.org/pdf/2104.05372).\n", + "\n", + "But reading slices of `Ref`s doesn't have this efficiency problem: when we\n", + "apply reverse-mode autodiff, we always generate in-place accumulation\n", + "operations. As a result, we no longer need to be constrained by `scan`'s fixed\n", + "access pattern. We can write more flexible loops, e.g. with non-sequential\n", + "access.\n", + "\n", + "Moreover, having mutation available allows for some syntax tricks, like in this\n", + "recipe for a `foreach` decorator:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11753b6e", + "metadata": {}, + "outputs": [], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "from jax.lax import scan\n", + "\n", + "def foreach(*args):\n", + " def decorator(body):\n", + " return scan(lambda _, elts: (None, body(*elts)), None, args)[1]\n", + " return decorator" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3ddc7abe", + "metadata": {}, + "outputs": [], + "source": [ + "r = jax.new_ref(0)\n", + "xs = jnp.arange(10)\n", + "\n", + "@foreach(xs)\n", + "def ys(x):\n", + " r[...] += x\n", + " return x * 2\n", + "\n", + "print(r) # Ref(45, dtype=int32)\n", + "print(ys) # [ 0 2 4 6 8 10 12 14 16 18]" + ] + }, + { + "cell_type": "markdown", + "id": "570970cd", + "metadata": {}, + "source": [ + "Here, the loop runs immediately, updating `r` in-place and binding `ys` to be\n", + "the mapped result." + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "formats": "ipynb,md:myst,py", + "main_language": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/array_refs.md b/docs/array_refs.md new file mode 100644 index 000000000000..7aabb8350ea1 --- /dev/null +++ b/docs/array_refs.md @@ -0,0 +1,434 @@ +--- +jupytext: + cell_metadata_filter: -all + formats: ipynb,md:myst,py + main_language: python + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.16.4 +--- + +```{raw-cell} + +--- +Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +--- +``` + +# `Ref`: mutable arrays for data plumbing and memory control + +JAX `Array`s are immutable, representing mathematical values. Immutability can +make code easier to reason about, and is useful for optimized compilation, +parallelization, rematerialization, and transformations like autodiff. + +But immutability is constraining too: +* **expressiveness** --- plumbing out intermediate data or maintaining state, + e.g. for normalization statistics or metrics, can feel heavyweight; +* **performance** --- it's more difficult to reason about performance, like + memory lifetimes and in-place updates. + +`Ref`s can help! They represent mutable arrays that can be read and written +in-place. These array references are compatible with JAX transformations, like +`jax.jit` and `jax.grad`: + +```{code-cell} +import jax +import jax.numpy as jnp + +x_ref = jax.new_ref(jnp.zeros(3)) # new array ref, with initial value [0., 0., 0.] + +@jax.jit +def f(): + x_ref[1] += 1. # indexed add-update + +print(x_ref) # Ref([0., 0., 0.]) +f() +f() +print(x_ref) # Ref([0., 2., 0.]) +``` + +The indexing syntax follows NumPy's. For a `Ref` called `x_ref`, we can +read its entire value into an `Array` by writing `x_ref[...]`, and write its +entire value using `x_ref[...] = A` for some `Array`-valued expression `A`: + +```{code-cell} +def g(x): + x_ref = jax.new_ref(0.) + x_ref[...] = jnp.sin(x) + return x_ref[...] + +print(jax.grad(g)(1.0)) # 0.54 +``` + +`Ref` is a distinct type from `Array`, and it comes with some important +constraints and limitations. In particular, indexed reading and writing is just +about the *only* thing you can do with an `Ref`. References can't be passed +where `Array`s are expected: + +```{code-cell} +x_ref = jax.new_ref(1.0) +try: + jnp.sin(x_ref) # error! can't do math on refs +except Exception as e: + print(e) +``` + +To do math, you need to read the ref's value first, like `jnp.sin(x_ref[...])`. + +So what _can_ you do with `Ref`? Read on for the details, and some useful +recipes. + +### API + +If you've ever used +[Pallas](https://docs.jax.dev/en/latest/pallas/quickstart.html), then `Ref` +should look familiar. A big difference is that you can create new `Ref`s +yourself directly using `jax.new_ref`: + +```{code-cell} +from jax import Array, Ref + +def array_ref(init_val: Array) -> Ref: + """Introduce a new reference with given initial value.""" +``` + +`jax.freeze` is its antithesis, invalidating the given ref (so that accessing it +afterwards is an error) and producing its final value: + +```{code-cell} +def freeze(ref: Ref) -> Array: + """Invalidate given reference and produce its final value.""" +``` + +In between creating and destroying them, you can perform indexed reads and +writes on refs. You can read and write using the functions `jax.ref.get` and +`jax.ref.swap`, but usually you'd just use NumPy-style array indexing syntax: + +```{code-cell} +import types +Index = int | slice | Array | types.EllipsisType +Indexer = Index | tuple[Index, ...] + +def get(ref: Ref, idx: Indexer) -> Array: + """Returns `ref[idx]` for NumPy-style indexer `idx`.""" + +def swap(ref: Ref, idx: Indexer, val: Array) -> Array: + """Performs `newval, ref[idx] = ref[idx], val` and returns `newval`.""" +``` + +Here, `Indexer` can be any NumPy indexing expression: + +```{code-cell} +x_ref = jax.new_ref(jnp.arange(12.).reshape(3, 4)) + +# int indexing +row = x_ref[0] +x_ref[1] = row + +# tuple indexing +val = x_ref[1, 2] +x_ref[2, 3] = val + +# slice indexing +col = x_ref[:, 1] +x_ref[0, :3] = col + +# advanced int array indexing +vals = x_ref[jnp.array([0, 0, 1]), jnp.array([1, 2, 3])] +x_ref[jnp.array([1, 2, 1]), jnp.array([0, 0, 1])] = vals +``` + +As with `Array`s, indexing mostly follows NumPy behavior, except for +out-of-bounds indexing which [behaves in the usual way for JAX +`Array`s](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#out-of-bounds-indexing). + +### Pure and impure functions + +A function that takes a ref as an argument (either explicitly or by lexical +closure) is considered _impure_. For example: + +```{code-cell} +# takes ref as an argument => impure +@jax.jit +def impure1(x_ref, y_ref): + x_ref[...] = y_ref[...] + +# closes over ref => impure +y_ref = jax.new_ref(0) + +@jax.jit +def impure2(x): + y_ref[...] = x +``` + +If a function only uses refs internally, it is still considered _pure_. Purity +is in the eye of the caller. For example: + +```{code-cell} +# internal refs => still pure +@jax.jit +def pure1(x): + ref = jax.new_ref(x) + ref[...] = ref[...] + ref[...] + return ref[...] +``` + +Pure functions, even those that use refs internally, are familiar: for example, +they work with transformations like `jax.grad`, `jax.vmap`, `jax.shard_map`, and +others in the usual way. + +Impure functions are sequenced in Python program order. + +### Restrictions + +`Ref`s are second-class, in the sense that there are restrictions on their +use: + +* **Can't return refs** from `jit`\-decorated functions or the bodies of + higher-order primitives like `jax.lax.scan`, `jax.lax.while_loop`, or + `jax.lax.cond` +* **Can't pass a ref as an argument more than once** to `jit`\-decorated + functions or higher-order primitives +* **Can only `freeze` in creation scope** +* **No higher-order refs** (refs-to-refs) + +For example, these are errors: + +```{code-cell} +x_ref = jax.new_ref(0.) + +# can't return refs +@jax.jit +def err1(x_ref): + x_ref[...] = 5. + return x_ref # error! +try: + err1(x_ref) +except Exception as e: + print(e) + +# can't pass a ref as an argument more than once +@jax.jit +def err2(x_ref, y_ref): + ... +try: + err2(x_ref, x_ref) # error! +except Exception as e: + print(e) + +# can't pass and close over the same ref +@jax.jit +def err3(y_ref): + y_ref[...] = x_ref[...] +try: + err3(x_ref) # error! +except Exception as e: + print(e) + +# can only freeze in creation scope +@jax.jit +def err4(x_ref): + jax.freeze(x_ref) +try: + err4(x_ref) # error! +except Exception as e: + print(e) +``` + +These restrictions exist to rule out aliasing, where two refs might refer to the +same mutable memory, making programs harder to reason about and transform. +Weaker restrictions would also suffice, so some of these restrictions may be +lifted as we improve JAX's ability to verify that no aliasing is present. + +There are also restrictions stemming from undefined semantics, e.g. in the +presence of parallelism or rematerialization: + +* **Can't `vmap` or `shard_map` a function that closes over refs** +* **Can't apply `jax.remat`/`jax.checkpoint` to an impure function** + +For example, here are ways you can and can't use `vmap` with impure functions: + +```{code-cell} +# vmap over ref args is okay +def dist(x, y, out_ref): + assert x.ndim == y.ndim == 1 + assert out_ref.ndim == 0 + out_ref[...] = jnp.sum((x - y) ** 2) + +vecs = jnp.arange(12.).reshape(3, 4) +out_ref = jax.new_ref(jnp.zeros((3, 3))) +jax.vmap(jax.vmap(dist, (0, None, 0)), (None, 0, 0))(vecs, vecs, out_ref) # ok! +print(out_ref) +``` + +```{code-cell} +# vmap with a closed-over ref is not +x_ref = jax.new_ref(0.) + +def err5(x): + x_ref[...] = x + +try: + jax.vmap(err5)(jnp.arange(3.)) # error! +except Exception as e: + print(e) +``` + +The latter is an error because it's not clear which value `x_ref` should be +after we run `jax.vmap(err5)`. + +### `Ref`s and automatic differentiation + +Autodiff can be applied to pure functions as before, even if they use array refs +internally. For example: + +```{code-cell} +@jax.jit +def pure2(x): + ref = jax.new_ref(x) + ref[...] = ref[...] + ref[...] + return ref[...] + +print(jax.grad(pure1)(3.0)) # 2.0 +``` + +Autodiff can also be applied to functions that take array refs as arguments, if +those arguments are only used for plumbing and not involved in differentiation: + +```{code-cell} +# error +def err6(x, some_plumbing_ref): + y = x + x + some_plumbing_ref[...] += y + return y + +# fine +def foo(x, some_plumbing_ref): + y = x + x + some_plumbing_ref[...] += jax.lax.stop_gradient(y) + return y +``` + +You can combine plumbing refs with `custom_vjp` to plumb data out of the +backward pass of a differentiated function: + +```{code-cell} +# First, define the helper `stash_grads`: + +@jax.custom_vjp +def stash_grads(grads_ref, x): + return x + +def stash_grads_fwd(grads_ref, x): + return x, grads_ref + +def stash_grads_bwd(grads_ref, g): + grads_ref[...] = g + return None, g + +stash_grads.defvjp(stash_grads_fwd, stash_grads_bwd) +``` + +```{code-cell} +# Now, use `stash_grads` to stash intermediate gradients: + +def f(x, grads_ref): + x = jnp.sin(x) + x = stash_grads(grads_ref, x) + return x + +grads_ref = jax.new_ref(0.) +f(1., grads_ref) +print(grads_ref) +``` + +Notice `stash_grads_fwd` is returning a `Ref` here. That's a special +allowance for `custom_vjp` fwd rules: it's really syntax for indicating which +ref arguments should be shared by both the fwd and bwd rules. So any refs +returned by a fwd rule must be arguments to that fwd rule. + +### `Ref`s and performance + +At the top level, when calling `jit`\-decorated functions, `Ref`s obviate +the need for donation, since they are effectively always donated: + +```{code-cell} +@jax.jit +def sin_inplace(x_ref): + x_ref[...] = jnp.sin(x_ref[...]) + +x_ref = jax.new_ref(jnp.arange(3.)) +print(x_ref.unsafe_buffer_pointer(), x_ref) +sin_inplace(x_ref) +print(x_ref.unsafe_buffer_pointer(), x_ref) +``` + +Here `sin_inplace` operates in-place, updating the buffer backing `x_ref` so +that its address stays the same. + +Under a `jit`, you should expect array references to point to fixed buffer +addresses, and for indexed updates to be performed in-place. + +**Temporary caveat:** dispatch from Python to impure `jit`\-compiled functions +that take `Ref` inputs is currently slower than dispatch to pure +`jit`\-compiled functions, since it takes a less optimized path. + +### `foreach`, a new way to write `scan` + +As you may know, `jax.lax.scan` is a loop construct with a built-in fixed access +pattern for scanned-over inputs and outputs. The access pattern is built in for +autodiff reasons: if we were instead to slice into immutable inputs directly, +reverse-mode autodiff would end up creating one-hot gradients and summing them +up, which can be asymptotically inefficient. See [Sec 5.3.3 of the Dex +paper](https://arxiv.org/pdf/2104.05372). + +But reading slices of `Ref`s doesn't have this efficiency problem: when we +apply reverse-mode autodiff, we always generate in-place accumulation +operations. As a result, we no longer need to be constrained by `scan`'s fixed +access pattern. We can write more flexible loops, e.g. with non-sequential +access. + +Moreover, having mutation available allows for some syntax tricks, like in this +recipe for a `foreach` decorator: + +```{code-cell} +import jax +import jax.numpy as jnp +from jax.lax import scan + +def foreach(*args): + def decorator(body): + return scan(lambda _, elts: (None, body(*elts)), None, args)[1] + return decorator +``` + +```{code-cell} +r = jax.new_ref(0) +xs = jnp.arange(10) + +@foreach(xs) +def ys(x): + r[...] += x + return x * 2 + +print(r) # Ref(45, dtype=int32) +print(ys) # [ 0 2 4 6 8 10 12 14 16 18] +``` + +Here, the loop runs immediately, updating `r` in-place and binding `ys` to be +the mapped result. diff --git a/docs/array_refs.py b/docs/array_refs.py new file mode 100644 index 000000000000..001c81f86334 --- /dev/null +++ b/docs/array_refs.py @@ -0,0 +1,443 @@ +# --- +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# jupyter: +# jupytext: +# cell_metadata_filter: -all +# formats: ipynb,md:myst,py +# text_representation: +# extension: .py +# format_name: light +# format_version: '1.5' +# jupytext_version: 1.16.4 +# --- + +# # `Ref`: mutable arrays for data plumbing and memory control +# +# JAX `Array`s are immutable, representing mathematical values. Immutability can +# make code easier to reason about, and is useful for optimized compilation, +# parallelization, rematerialization, and transformations like autodiff. +# +# But immutability is constraining too: +# * **expressiveness** --- plumbing out intermediate data or maintaining state, +# e.g. for normalization statistics or metrics, can feel heavyweight; +# * **performance** --- it's more difficult to reason about performance, like +# memory lifetimes and in-place updates. +# +# `Ref`s can help! They represent mutable arrays that can be read and written +# in-place. These array references are compatible with JAX transformations, like +# `jax.jit` and `jax.grad`: + +# + +import jax +import jax.numpy as jnp + +x_ref = jax.new_ref(jnp.zeros(3)) # new array ref, with initial value [0., 0., 0.] + +@jax.jit +def f(): + x_ref[1] += 1. # indexed add-update + +print(x_ref) # Ref([0., 0., 0.]) +f() +f() +print(x_ref) # Ref([0., 2., 0.]) + + +# - + +# The indexing syntax follows NumPy's. For a `Ref` called `x_ref`, we can +# read its entire value into an `Array` by writing `x_ref[...]`, and write its +# entire value using `x_ref[...] = A` for some `Array`-valued expression `A`: + +# + +def g(x): + x_ref = jax.new_ref(0.) + x_ref[...] = jnp.sin(x) + return x_ref[...] + +print(jax.grad(g)(1.0)) # 0.54 +# - + +# `Ref` is a distinct type from `Array`, and it comes with some important +# constraints and limitations. In particular, indexed reading and writing is just +# about the *only* thing you can do with an `Ref`. References can't be passed +# where `Array`s are expected: + +x_ref = jax.new_ref(1.0) +try: + jnp.sin(x_ref) # error! can't do math on refs +except Exception as e: + print(e) + +# To do math, you need to read the ref's value first, like `jnp.sin(x_ref[...])`. +# +# So what _can_ you do with `Ref`? Read on for the details, and some useful +# recipes. +# +# ### API +# +# If you've ever used +# [Pallas](https://docs.jax.dev/en/latest/pallas/quickstart.html), then `Ref` +# should look familiar. A big difference is that you can create new `Ref`s +# yourself directly using `jax.new_ref`: + +# + +from jax import Array, Ref + +def array_ref(init_val: Array) -> Ref: + """Introduce a new reference with given initial value.""" + + +# - + +# `jax.freeze` is its antithesis, invalidating the given ref (so that accessing it +# afterwards is an error) and producing its final value: + +def freeze(ref: Ref) -> Array: + """Invalidate given reference and produce its final value.""" + + +# In between creating and destroying them, you can perform indexed reads and +# writes on refs. You can read and write using the functions `jax.ref.get` and +# `jax.ref.swap`, but usually you'd just use NumPy-style array indexing syntax: + +# + +import types +Index = int | slice | Array | types.EllipsisType +Indexer = Index | tuple[Index, ...] + +def get(ref: Ref, idx: Indexer) -> Array: + """Returns `ref[idx]` for NumPy-style indexer `idx`.""" + +def swap(ref: Ref, idx: Indexer, val: Array) -> Array: + """Performs `newval, ref[idx] = ref[idx], val` and returns `newval`.""" + + +# - + +# Here, `Indexer` can be any NumPy indexing expression: + +# + +x_ref = jax.new_ref(jnp.arange(12.).reshape(3, 4)) + +# int indexing +row = x_ref[0] +x_ref[1] = row + +# tuple indexing +val = x_ref[1, 2] +x_ref[2, 3] = val + +# slice indexing +col = x_ref[:, 1] +x_ref[0, :3] = col + +# advanced int array indexing +vals = x_ref[jnp.array([0, 0, 1]), jnp.array([1, 2, 3])] +x_ref[jnp.array([1, 2, 1]), jnp.array([0, 0, 1])] = vals + + +# - + +# As with `Array`s, indexing mostly follows NumPy behavior, except for +# out-of-bounds indexing which [behaves in the usual way for JAX +# `Array`s](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#out-of-bounds-indexing). +# +# ### Pure and impure functions +# +# A function that takes a ref as an argument (either explicitly or by lexical +# closure) is considered _impure_. For example: + +# + +# takes ref as an argument => impure +@jax.jit +def impure1(x_ref, y_ref): + x_ref[...] = y_ref[...] + +# closes over ref => impure +y_ref = jax.new_ref(0) + +@jax.jit +def impure2(x): + y_ref[...] = x + + +# - + +# If a function only uses refs internally, it is still considered _pure_. Purity +# is in the eye of the caller. For example: + +# internal refs => still pure +@jax.jit +def pure1(x): + ref = jax.new_ref(x) + ref[...] = ref[...] + ref[...] + return ref[...] + + +# Pure functions, even those that use refs internally, are familiar: for example, +# they work with transformations like `jax.grad`, `jax.vmap`, `jax.shard_map`, and +# others in the usual way. +# +# Impure functions are sequenced in Python program order. +# +# ### Restrictions +# +# `Ref`s are second-class, in the sense that there are restrictions on their +# use: +# +# * **Can't return refs** from `jit`\-decorated functions or the bodies of +# higher-order primitives like `jax.lax.scan`, `jax.lax.while_loop`, or +# `jax.lax.cond` +# * **Can't pass a ref as an argument more than once** to `jit`\-decorated +# functions or higher-order primitives +# * **Can only `freeze` in creation scope** +# * **No higher-order refs** (refs-to-refs) +# +# For example, these are errors: + +# + +x_ref = jax.new_ref(0.) + +# can't return refs +@jax.jit +def err1(x_ref): + x_ref[...] = 5. + return x_ref # error! +try: + err1(x_ref) +except Exception as e: + print(e) + +# can't pass a ref as an argument more than once +@jax.jit +def err2(x_ref, y_ref): + ... +try: + err2(x_ref, x_ref) # error! +except Exception as e: + print(e) + +# can't pass and close over the same ref +@jax.jit +def err3(y_ref): + y_ref[...] = x_ref[...] +try: + err3(x_ref) # error! +except Exception as e: + print(e) + +# can only freeze in creation scope +@jax.jit +def err4(x_ref): + jax.freeze(x_ref) +try: + err4(x_ref) # error! +except Exception as e: + print(e) + + +# - + +# These restrictions exist to rule out aliasing, where two refs might refer to the +# same mutable memory, making programs harder to reason about and transform. +# Weaker restrictions would also suffice, so some of these restrictions may be +# lifted as we improve JAX's ability to verify that no aliasing is present. +# +# There are also restrictions stemming from undefined semantics, e.g. in the +# presence of parallelism or rematerialization: +# +# * **Can't `vmap` or `shard_map` a function that closes over refs** +# * **Can't apply `jax.remat`/`jax.checkpoint` to an impure function** +# +# For example, here are ways you can and can't use `vmap` with impure functions: + +# + +# vmap over ref args is okay +def dist(x, y, out_ref): + assert x.ndim == y.ndim == 1 + assert out_ref.ndim == 0 + out_ref[...] = jnp.sum((x - y) ** 2) + +vecs = jnp.arange(12.).reshape(3, 4) +out_ref = jax.new_ref(jnp.zeros((3, 3))) +jax.vmap(jax.vmap(dist, (0, None, 0)), (None, 0, 0))(vecs, vecs, out_ref) # ok! +print(out_ref) + +# + +# vmap with a closed-over ref is not +x_ref = jax.new_ref(0.) + +def err5(x): + x_ref[...] = x + +try: + jax.vmap(err5)(jnp.arange(3.)) # error! +except Exception as e: + print(e) + + +# - + +# The latter is an error because it's not clear which value `x_ref` should be +# after we run `jax.vmap(err5)`. +# +# ### `Ref`s and automatic differentiation +# +# Autodiff can be applied to pure functions as before, even if they use array refs +# internally. For example: + +# + +@jax.jit +def pure2(x): + ref = jax.new_ref(x) + ref[...] = ref[...] + ref[...] + return ref[...] + +print(jax.grad(pure1)(3.0)) # 2.0 + + +# - + +# Autodiff can also be applied to functions that take array refs as arguments, if +# those arguments are only used for plumbing and not involved in differentiation: + +# + +# error +def err6(x, some_plumbing_ref): + y = x + x + some_plumbing_ref[...] += y + return y + +# fine +def foo(x, some_plumbing_ref): + y = x + x + some_plumbing_ref[...] += jax.lax.stop_gradient(y) + return y + + +# - + +# You can combine plumbing refs with `custom_vjp` to plumb data out of the +# backward pass of a differentiated function: + +# + +# First, define the helper `stash_grads`: + +@jax.custom_vjp +def stash_grads(grads_ref, x): + return x + +def stash_grads_fwd(grads_ref, x): + return x, grads_ref + +def stash_grads_bwd(grads_ref, g): + grads_ref[...] = g + return None, g + +stash_grads.defvjp(stash_grads_fwd, stash_grads_bwd) + + +# + +# Now, use `stash_grads` to stash intermediate gradients: + +def f(x, grads_ref): + x = jnp.sin(x) + x = stash_grads(grads_ref, x) + return x + +grads_ref = jax.new_ref(0.) +f(1., grads_ref) +print(grads_ref) + + +# - + +# Notice `stash_grads_fwd` is returning a `Ref` here. That's a special +# allowance for `custom_vjp` fwd rules: it's really syntax for indicating which +# ref arguments should be shared by both the fwd and bwd rules. So any refs +# returned by a fwd rule must be arguments to that fwd rule. +# +# ### `Ref`s and performance +# +# At the top level, when calling `jit`\-decorated functions, `Ref`s obviate +# the need for donation, since they are effectively always donated: + +# + +@jax.jit +def sin_inplace(x_ref): + x_ref[...] = jnp.sin(x_ref[...]) + +x_ref = jax.new_ref(jnp.arange(3.)) +print(x_ref.unsafe_buffer_pointer(), x_ref) +sin_inplace(x_ref) +print(x_ref.unsafe_buffer_pointer(), x_ref) +# - + +# Here `sin_inplace` operates in-place, updating the buffer backing `x_ref` so +# that its address stays the same. +# +# Under a `jit`, you should expect array references to point to fixed buffer +# addresses, and for indexed updates to be performed in-place. +# +# **Temporary caveat:** dispatch from Python to impure `jit`\-compiled functions +# that take `Ref` inputs is currently slower than dispatch to pure +# `jit`\-compiled functions, since it takes a less optimized path. +# +# ### `foreach`, a new way to write `scan` +# +# As you may know, `jax.lax.scan` is a loop construct with a built-in fixed access +# pattern for scanned-over inputs and outputs. The access pattern is built in for +# autodiff reasons: if we were instead to slice into immutable inputs directly, +# reverse-mode autodiff would end up creating one-hot gradients and summing them +# up, which can be asymptotically inefficient. See [Sec 5.3.3 of the Dex +# paper](https://arxiv.org/pdf/2104.05372). +# +# But reading slices of `Ref`s doesn't have this efficiency problem: when we +# apply reverse-mode autodiff, we always generate in-place accumulation +# operations. As a result, we no longer need to be constrained by `scan`'s fixed +# access pattern. We can write more flexible loops, e.g. with non-sequential +# access. +# +# Moreover, having mutation available allows for some syntax tricks, like in this +# recipe for a `foreach` decorator: + +# + +import jax +import jax.numpy as jnp +from jax.lax import scan + +def foreach(*args): + def decorator(body): + return scan(lambda _, elts: (None, body(*elts)), None, args)[1] + return decorator + + +# + +r = jax.new_ref(0) +xs = jnp.arange(10) + +@foreach(xs) +def ys(x): + r[...] += x + return x * 2 + +print(r) # Ref(45, dtype=int32) +print(ys) # [ 0 2 4 6 8 10 12 14 16 18] +# - + +# Here, the loop runs immediately, updating `r` in-place and binding `ys` to be +# the mapped result. diff --git a/docs/autodidax.ipynb b/docs/autodidax.ipynb index 7ec91affa05d..f57ce09e0bf6 100644 --- a/docs/autodidax.ipynb +++ b/docs/autodidax.ipynb @@ -72,7 +72,7 @@ "outputs, we want to override primitive application and let different values\n", "flow through our program. For example, we might want to replace the\n", "application of every primitive with an application of [its JVP\n", - "rule](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html),\n", + "rule](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html),\n", "and let primal-tangent pairs flow through our program. Moreover, we want to be\n", "able to compose multiple transformations, leading to stacks of interpreters." ] @@ -2019,7 +2019,8 @@ "\n", " output = io.StringIO()\n", " c.module.operation.print(file=output)\n", - " compiled = xb.get_backend(None).compile(output.getvalue())\n", + " backend = xb.get_backend(None)\n", + " compiled = backend.compile_and_load(output.getvalue(), backend.devices()[:1])\n", " return partial(execute_compiled, compiled, [v.aval for v in jaxpr.outs])\n", "\n", "def _mlir_dtype(dtype: np.dtype) -> ir.Type:\n", @@ -3620,7 +3621,7 @@ "source": [ "Notice that we're not currently supporting the case where the predicate value\n", "itself is batched. In mainline JAX, we handle this case by transforming the\n", - "conditional to a [select primitive](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html).\n", + "conditional to a [select primitive](https://docs.jax.dev/en/latest/_autosummary/jax.lax.select.html).\n", "That transformation is semantically correct so long as `true_fun` and\n", "`false_fun` do not involve any side-effecting primitives.\n", "\n", diff --git a/docs/autodidax.md b/docs/autodidax.md index 2d4d6cd528af..5bf0e8f78e12 100644 --- a/docs/autodidax.md +++ b/docs/autodidax.md @@ -72,7 +72,7 @@ where we apply primitive operations to numerical inputs to produce numerical outputs, we want to override primitive application and let different values flow through our program. For example, we might want to replace the application of every primitive with an application of [its JVP -rule](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html), +rule](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html), and let primal-tangent pairs flow through our program. Moreover, we want to be able to compose multiple transformations, leading to stacks of interpreters. @@ -1589,7 +1589,8 @@ def xla_callable(hashable_jaxpr: IDHashable, output = io.StringIO() c.module.operation.print(file=output) - compiled = xb.get_backend(None).compile(output.getvalue()) + backend = xb.get_backend(None) + compiled = backend.compile_and_load(output.getvalue(), backend.devices()[:1]) return partial(execute_compiled, compiled, [v.aval for v in jaxpr.outs]) def _mlir_dtype(dtype: np.dtype) -> ir.Type: @@ -2843,7 +2844,7 @@ print(out) Notice that we're not currently supporting the case where the predicate value itself is batched. In mainline JAX, we handle this case by transforming the -conditional to a [select primitive](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html). +conditional to a [select primitive](https://docs.jax.dev/en/latest/_autosummary/jax.lax.select.html). That transformation is semantically correct so long as `true_fun` and `false_fun` do not involve any side-effecting primitives. diff --git a/docs/autodidax.py b/docs/autodidax.py index f8c6372fe30d..695fc9993df5 100644 --- a/docs/autodidax.py +++ b/docs/autodidax.py @@ -62,7 +62,7 @@ # outputs, we want to override primitive application and let different values # flow through our program. For example, we might want to replace the # application of every primitive with an application of [its JVP -# rule](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html), +# rule](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html), # and let primal-tangent pairs flow through our program. Moreover, we want to be # able to compose multiple transformations, leading to stacks of interpreters. @@ -1581,7 +1581,8 @@ def main(*params): output = io.StringIO() c.module.operation.print(file=output) - compiled = xb.get_backend(None).compile(output.getvalue()) + backend = xb.get_backend(None) + compiled = backend.compile_and_load(output.getvalue(), backend.devices()[:1]) return partial(execute_compiled, compiled, [v.aval for v in jaxpr.outs]) def _mlir_dtype(dtype: np.dtype) -> ir.Type: @@ -2837,7 +2838,7 @@ def cond_vmap_rule(axis_size, vals_in, dims_in, *, true_jaxpr, false_jaxpr): # Notice that we're not currently supporting the case where the predicate value # itself is batched. In mainline JAX, we handle this case by transforming the -# conditional to a [select primitive](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html). +# conditional to a [select primitive](https://docs.jax.dev/en/latest/_autosummary/jax.lax.select.html). # That transformation is semantically correct so long as `true_fun` and # `false_fun` do not involve any side-effecting primitives. # diff --git a/docs/autodidax2_part1.ipynb b/docs/autodidax2_part1.ipynb index 0a5a89c8ed98..7a58f54b16c8 100644 --- a/docs/autodidax2_part1.ipynb +++ b/docs/autodidax2_part1.ipynb @@ -674,7 +674,7 @@ "something is constant with respect to differentiation? It's tempting to say\n", "\"it's a constant if and only if it's not a dual number\". But actually dual\n", "numbers created by a *different* JVPInterpreter also need to be considered\n", - "constants with resepect to the JVPInterpreter we're currently handling. That's\n", + "constants with respect to the JVPInterpreter we're currently handling. That's\n", "why we need the `x.interpreter is self` check in `JVPInterpreter.lift`. This\n", "comes up in higher order differentiation when there are multiple JVPInterprers\n", "in scope. The sort of bug where you accidentally interpret a dual number from\n", @@ -1046,7 +1046,7 @@ "That's it for part one of this tutorial. We've done two primitives, three\n", "interpreters and the tracing mechanism that weaves them together. In the next\n", "part we'll add types other than floats, error handling, compilation,\n", - "reverse-mode AD and higher-order primtives. Note that the second part is\n", + "reverse-mode AD and higher-order primitives. Note that the second part is\n", "structured differently. Rather than trying to have a top-to-bottom order that\n", "obeys both code dependencies (e.g. data structures need to be defined before\n", "they're used) and pedagogical dependencies (concepts need to be introduced\n", diff --git a/docs/autodidax2_part1.md b/docs/autodidax2_part1.md index 70dd0e4b696b..a4af594fb253 100644 --- a/docs/autodidax2_part1.md +++ b/docs/autodidax2_part1.md @@ -348,7 +348,7 @@ There are some subtleties worth discussing. First, how do you tell if something is constant with respect to differentiation? It's tempting to say "it's a constant if and only if it's not a dual number". But actually dual numbers created by a *different* JVPInterpreter also need to be considered -constants with resepect to the JVPInterpreter we're currently handling. That's +constants with respect to the JVPInterpreter we're currently handling. That's why we need the `x.interpreter is self` check in `JVPInterpreter.lift`. This comes up in higher order differentiation when there are multiple JVPInterprers in scope. The sort of bug where you accidentally interpret a dual number from @@ -539,7 +539,7 @@ print(jvp(lambda x: eval_jaxpr(build_jaxpr(foo, 1), (x,)), 2.0, 1.0)) That's it for part one of this tutorial. We've done two primitives, three interpreters and the tracing mechanism that weaves them together. In the next part we'll add types other than floats, error handling, compilation, -reverse-mode AD and higher-order primtives. Note that the second part is +reverse-mode AD and higher-order primitives. Note that the second part is structured differently. Rather than trying to have a top-to-bottom order that obeys both code dependencies (e.g. data structures need to be defined before they're used) and pedagogical dependencies (concepts need to be introduced diff --git a/docs/autodidax2_part1.py b/docs/autodidax2_part1.py index bfe59df359d3..44bf843c91b3 100644 --- a/docs/autodidax2_part1.py +++ b/docs/autodidax2_part1.py @@ -307,7 +307,7 @@ def nth_order_derivative(n, f, x): # something is constant with respect to differentiation? It's tempting to say # "it's a constant if and only if it's not a dual number". But actually dual # numbers created by a *different* JVPInterpreter also need to be considered -# constants with resepect to the JVPInterpreter we're currently handling. That's +# constants with respect to the JVPInterpreter we're currently handling. That's # why we need the `x.interpreter is self` check in `JVPInterpreter.lift`. This # comes up in higher order differentiation when there are multiple JVPInterprers # in scope. The sort of bug where you accidentally interpret a dual number from @@ -483,7 +483,7 @@ def eval_atom(x): return env[x] if isinstance(x, Var) else x # That's it for part one of this tutorial. We've done two primitives, three # interpreters and the tracing mechanism that weaves them together. In the next # part we'll add types other than floats, error handling, compilation, -# reverse-mode AD and higher-order primtives. Note that the second part is +# reverse-mode AD and higher-order primitives. Note that the second part is # structured differently. Rather than trying to have a top-to-bottom order that # obeys both code dependencies (e.g. data structures need to be defined before # they're used) and pedagogical dependencies (concepts need to be introduced diff --git a/docs/automatic-differentiation.md b/docs/automatic-differentiation.md index 07af05e3d973..221dd19c5121 100644 --- a/docs/automatic-differentiation.md +++ b/docs/automatic-differentiation.md @@ -26,7 +26,7 @@ Computing gradients is a critical part of modern machine learning methods, and t - {ref}`automatic-differentiation-evaluating-using-jax-value_and_grad` - {ref}`automatic-differentiation-checking-against-numerical-differences` -Make sure to also check out the {ref}`advanced-autodiff` tutorial for more advanced topics. +Make sure to also check out the {ref}`"Advanced automatic differentiation" guides ` for more advanced topics. While understanding how automatic differentiation works "under the hood" isn't crucial for using JAX in most contexts, you are encouraged to check out this quite accessible [video](https://www.youtube.com/watch?v=wG_nF1awSSY) to get a deeper sense of what's going on. @@ -230,4 +230,4 @@ check_grads(loss, (W, b), order=2) # check up to 2nd order derivatives ## Next steps -The {ref}`advanced-autodiff` tutorial provides more advanced and detailed explanations of how the ideas covered in this document are implemented in the JAX backend. Some features, such as {ref}`advanced-autodiff-custom-derivative-rules`, depend on understanding advanced automatic differentiation, so do check out that section in the {ref}`advanced-autodiff` tutorial if you are interested. +The {ref}`"Advanced automatic differentiation" guides ` provide more advanced and detailed explanations of how the ideas covered in this document are implemented in the JAX backend. Some features, such as {ref}`advanced-autodiff-custom-derivative-rules`, depend on understanding advanced automatic differentiation, so do check out that section if you are interested. diff --git a/docs/beginner_guide.rst b/docs/beginner_guide.rst index 783d3b49ae52..cdab6b2a8663 100644 --- a/docs/beginner_guide.rst +++ b/docs/beginner_guide.rst @@ -5,22 +5,23 @@ Getting Started with JAX ======================== Welcome to JAX! The JAX documentation contains a number of useful resources for getting started. -:doc:`quickstart` is the easiest place to jump-in and get an overview of the JAX project. +:doc:`notebooks/thinking_in_jax` is the easiest place to jump in and get an overview of the JAX project, its execution +model, and differences with NumPy. -If you're accustomed to writing NumPy code and are starting to explore JAX, you might find the following resources helpful: +If you're starting to explore JAX, you might also find the following resources helpful: -- :doc:`notebooks/thinking_in_jax` is a conceptual walkthrough of JAX's execution model. +- :doc:`key-concepts` introduces the key concepts of JAX, such as transformations, tracing, jaxprs and pytrees. - :doc:`notebooks/Common_Gotchas_in_JAX` lists some of JAX's sharp corners. -- :doc:`faq` answers some frequent jax questions. +- :doc:`faq` answers some frequent JAX questions. -Tutorials ---------- -If you're ready to explore JAX more deeply, the JAX tutorials go into much more detail: +JAX 101 +------- +If you're ready to explore JAX more deeply, the JAX 101 tutorials go into much more detail: .. toctree:: :maxdepth: 2 - tutorials + jax-101 If you prefer a video introduction here is one from JAX contributor Jake VanderPlas: diff --git a/docs/benchmarking.md b/docs/benchmarking.md new file mode 100644 index 000000000000..e00fb2760d22 --- /dev/null +++ b/docs/benchmarking.md @@ -0,0 +1,70 @@ +(benchmarking-jax-code)= +# Benchmarking JAX code + +You just ported a tricky function from NumPy/SciPy to JAX. Did that actually +speed things up? + +Keep in mind these important differences from NumPy when measuring the +speed of code using JAX: + +1. **JAX code is Just-In-Time (JIT) compiled.** Most code written in JAX can be + written in such a way that it supports JIT compilation, which can make it run + *much faster* (see + [To JIT or not to JIT](https://docs.jax.dev/en/latest/notebooks/thinking_in_jax.html#to-jit-or-not-to-jit)). + To get maximum performance from JAX, you should apply {func}`jax.jit` on your + outer-most function calls. + + Keep in mind that the first time you run JAX code, it will be slower because + it is being compiled. This is true even if you don't use `jit` in your own + code, because JAX's builtin functions are also JIT compiled. +2. **JAX has asynchronous dispatch.** This means that you need to call + `.block_until_ready()` to ensure that computation has actually happened + (see {ref}`async-dispatch`). +3. **JAX by default only uses 32-bit dtypes.** You may want to either explicitly + use 32-bit dtypes in NumPy or enable 64-bit dtypes in JAX (see + [Double (64 bit) precision](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision)) + for a fair comparison. +4. **Transferring data between CPUs and accelerators takes time.** If you only + want to measure how long it takes to evaluate a function, you may want to + transfer data to the device on which you want to run it first (see + {ref}`faq-data-placement`). + +Here's an example of how to put together all these tricks into a microbenchmark +for comparing JAX versus NumPy, making using of IPython's convenient +[`%time` and `%timeit` magics](https://ipython.readthedocs.io/en/stable/interactive/magics.html#magic-time): + +```python +import numpy as np +import jax + +def f(x): # function we're benchmarking (works in both NumPy & JAX) + return x.T @ (x - x.mean(axis=0)) + +x_np = np.ones((1000, 1000), dtype=np.float32) # same as JAX default dtype +%timeit f(x_np) # measure NumPy runtime + +# measure JAX device transfer time +%time x_jax = jax.device_put(x_np).block_until_ready() + +f_jit = jax.jit(f) +%time f_jit(x_jax).block_until_ready() # measure JAX compilation time +%timeit f_jit(x_jax).block_until_ready() # measure JAX runtime +``` + +When run with a GPU in [Colab](https://colab.research.google.com/), we see: + +- NumPy takes 16.2 ms per evaluation on the CPU +- JAX takes 1.26 ms to copy the NumPy arrays onto the GPU +- JAX takes 193 ms to compile the function +- JAX takes 485 µs per evaluation on the GPU + +In this case, we see that once the data is transferred and the function is +compiled, JAX on the GPU is about 30x faster for repeated evaluations. + +Is this a fair comparison? Maybe. The performance that ultimately matters is for +running full applications, which inevitably include some amount of both data +transfer and compilation. Also, we were careful to pick large enough arrays +(1000x1000) and an intensive enough computation (the `@` operator is +performing matrix-matrix multiplication) to amortize the increased overhead of +JAX/accelerators vs NumPy/CPU. For example, if we switch this example to use +10x10 input instead, JAX/GPU runs 10x slower than NumPy/CPU (100 µs vs 10 µs). diff --git a/docs/buffer_donation.md b/docs/buffer_donation.md new file mode 100644 index 000000000000..a170d5e64939 --- /dev/null +++ b/docs/buffer_donation.md @@ -0,0 +1,88 @@ +(buffer-donation)= +# Buffer donation + +When JAX executes a computation it uses buffers on the device for all inputs and outputs. +If you know that one of the inputs is not needed after the computation, and if it +matches the shape and element type of one of the outputs, you can specify that you +want the corresponding input buffer to be donated to hold an output. This will reduce +the memory required for the execution by the size of the donated buffer. + +If you have something like the following pattern, you can use buffer donation: + +```python +params, state = jax.pmap(update_fn, donate_argnums=(0, 1))(params, state) +``` + +You can think of this as a way to do a memory-efficient functional update +on your immutable JAX arrays. Within the boundaries of a computation XLA can +make this optimization for you, but at the jit/pmap boundary you need to +guarantee to XLA that you will not use the donated input buffer after calling +the donating function. + +You achieve this by using the `donate_argnums` parameter to the functions {func}`jax.jit`, +{func}`jax.pjit`, and {func}`jax.pmap`. This parameter is a sequence of indices (0 based) into +the positional argument list: + +```python +def add(x, y): + return x + y + +x = jax.device_put(np.ones((2, 3))) +y = jax.device_put(np.ones((2, 3))) +# Execute `add` with donation of the buffer for `y`. The result has +# the same shape and type as `y`, so it will share its buffer. +z = jax.jit(add, donate_argnums=(1,))(x, y) +``` + +Note that this currently does not work when calling your function with key-word arguments! +The following code will not donate any buffers: + +```python +params, state = jax.pmap(update_fn, donate_argnums=(0, 1))(params=params, state=state) +``` + +If an argument whose buffer is donated is a pytree, then all the buffers +for its components are donated: + +```python +def add_ones(xs: List[Array]): + return [x + 1 for x in xs] + +xs = [jax.device_put(np.ones((2, 3))), jax.device_put(np.ones((3, 4)))] +# Execute `add_ones` with donation of all the buffers for `xs`. +# The outputs have the same shape and type as the elements of `xs`, +# so they will share those buffers. +z = jax.jit(add_ones, donate_argnums=0)(xs) +``` + +It is not allowed to donate a buffer that is used subsequently in the computation, +and JAX will give an error because the buffer for `y` has become invalid +after it was donated: + +```python +# Donate the buffer for `y` +z = jax.jit(add, donate_argnums=(1,))(x, y) +w = y + 1 # Reuses `y` whose buffer was donated above +# >> RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer +``` + +You will get a warning if the donated buffer is not used, e.g., because +there are more donated buffers than can be used for the outputs: + +```python +# Execute `add` with donation of the buffers for both `x` and `y`. +# One of those buffers will be used for the result, but the other will +# not be used. +z = jax.jit(add, donate_argnums=(0, 1))(x, y) +# >> UserWarning: Some donated buffers were not usable: f32[2,3]{1,0} +``` + +The donation may also be unused if there is no output whose shape matches +the donation: + +```python +y = jax.device_put(np.ones((1, 3))) # `y` has different shape than the output +# Execute `add` with donation of the buffer for `y`. +z = jax.jit(add, donate_argnums=(1,))(x, y) +# >> UserWarning: Some donated buffers were not usable: f32[1,3]{1,0} +``` diff --git a/docs/building_on_jax.md b/docs/building_on_jax.md index 9416b16cde10..0f968b2afe9d 100644 --- a/docs/building_on_jax.md +++ b/docs/building_on_jax.md @@ -45,8 +45,8 @@ Here are more specific examples of each pattern. ### Direct usage Jax can be directly imported and utilized to build models “from scratch” as shown across this website, -for example in [JAX Tutorials](https://jax.readthedocs.io/en/latest/tutorials.html) -or [Neural Network with JAX](https://jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html). +for example in {ref}`JAX 101 Tutorials ` +or [Neural Network with JAX](https://docs.jax.dev/en/latest/notebooks/neural_network_with_tfds_data.html). This may be the best option if you are unable to find prebuilt code for your particular challenge, or if you're looking to reduce the number of dependencies in your codebase. diff --git a/docs/complex-differentiation.md b/docs/complex-differentiation.md new file mode 100644 index 000000000000..cf31b90a45ef --- /dev/null +++ b/docs/complex-differentiation.md @@ -0,0 +1,207 @@ +--- +jupytext: + formats: md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.16.4 +kernelspec: + display_name: Python 3 + name: python3 +--- + +# Complex numbers and differentiation + +JAX is great at complex numbers and differentiation. To support both [holomorphic and non-holomorphic differentiation](https://en.wikipedia.org/wiki/Holomorphic_function), it helps to think in terms of JVPs and VJPs. + +Consider a complex-to-complex function $f: \mathbb{C} \to \mathbb{C}$ and identify it with a corresponding function $g: \mathbb{R}^2 \to \mathbb{R}^2$, + +```{code-cell} +import jax.numpy as jnp + +def f(z): + x, y = jnp.real(z), jnp.imag(z) + return u(x, y) + v(x, y) * 1j + +def g(x, y): + return (u(x, y), v(x, y)) +``` + +That is, we've decomposed $f(z) = u(x, y) + v(x, y) i$ where $z = x + y i$, and identified $\mathbb{C}$ with $\mathbb{R}^2$ to get $g$. + +Since $g$ only involves real inputs and outputs, we already know how to write a Jacobian-vector product for it, say given a tangent vector $(c, d) \in \mathbb{R}^2$, namely: + +$\begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} +\begin{bmatrix} c \\ d \end{bmatrix}$. + +To get a JVP for the original function $f$ applied to a tangent vector $c + di \in \mathbb{C}$, we just use the same definition and identify the result as another complex number, + +$\partial f(x + y i)(c + d i) = +\begin{matrix} \begin{bmatrix} 1 & i \end{bmatrix} \\ ~ \end{matrix} +\begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} +\begin{bmatrix} c \\ d \end{bmatrix}$. + +That's our definition of the JVP of a $\mathbb{C} \to \mathbb{C}$ function! Notice it doesn't matter whether or not $f$ is holomorphic: the JVP is unambiguous. + +Here's a check: + +```{code-cell} +from jax import random, grad, jvp + +def check(seed): + key = random.key(seed) + + # random coeffs for u and v + key, subkey = random.split(key) + a, b, c, d = random.uniform(subkey, (4,)) + + def fun(z): + x, y = jnp.real(z), jnp.imag(z) + return u(x, y) + v(x, y) * 1j + + def u(x, y): + return a * x + b * y + + def v(x, y): + return c * x + d * y + + # primal point + key, subkey = random.split(key) + x, y = random.uniform(subkey, (2,)) + z = x + y * 1j + + # tangent vector + key, subkey = random.split(key) + c, d = random.uniform(subkey, (2,)) + z_dot = c + d * 1j + + # check jvp + _, ans = jvp(fun, (z,), (z_dot,)) + expected = (grad(u, 0)(x, y) * c + + grad(u, 1)(x, y) * d + + grad(v, 0)(x, y) * c * 1j+ + grad(v, 1)(x, y) * d * 1j) + print(jnp.allclose(ans, expected)) +``` + +```{code-cell} +check(0) +check(1) +check(2) +``` + +What about VJPs? We do something pretty similar: for a cotangent vector $c + di \in \mathbb{C}$ we define the VJP of $f$ as + +$(c + di)^* \; \partial f(x + y i) = +\begin{matrix} \begin{bmatrix} c & -d \end{bmatrix} \\ ~ \end{matrix} +\begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} +\begin{bmatrix} 1 \\ -i \end{bmatrix}$. + +What's with the negatives? They're just to take care of complex conjugation, and the fact that we're working with covectors. + +Here's a check of the VJP rules: + +```{code-cell} +from jax import vjp + +def check(seed): + key = random.key(seed) + + # random coeffs for u and v + key, subkey = random.split(key) + a, b, c, d = random.uniform(subkey, (4,)) + + def fun(z): + x, y = jnp.real(z), jnp.imag(z) + return u(x, y) + v(x, y) * 1j + + def u(x, y): + return a * x + b * y + + def v(x, y): + return c * x + d * y + + # primal point + key, subkey = random.split(key) + x, y = random.uniform(subkey, (2,)) + z = x + y * 1j + + # cotangent vector + key, subkey = random.split(key) + c, d = random.uniform(subkey, (2,)) + z_bar = jnp.array(c + d * 1j) # for dtype control + + # check vjp + _, fun_vjp = vjp(fun, z) + ans, = fun_vjp(z_bar) + expected = (grad(u, 0)(x, y) * c + + grad(v, 0)(x, y) * (-d) + + grad(u, 1)(x, y) * c * (-1j) + + grad(v, 1)(x, y) * (-d) * (-1j)) + assert jnp.allclose(ans, expected, atol=1e-5, rtol=1e-5) +``` + +```{code-cell} +check(0) +check(1) +check(2) +``` + +What about convenience wrappers like {func}`jax.grad`, {func}`jax.jacfwd`, and {func}`jax.jacrev`? + +For $\mathbb{R} \to \mathbb{R}$ functions, recall we defined `grad(f)(x)` as being `vjp(f, x)[1](1.0)`, which works because applying a VJP to a `1.0` value reveals the gradient (i.e. Jacobian, or derivative). We can do the same thing for $\mathbb{C} \to \mathbb{R}$ functions: we can still use `1.0` as the cotangent vector, and we just get out a complex number result summarizing the full Jacobian: + +```{code-cell} +def f(z): + x, y = jnp.real(z), jnp.imag(z) + return x**2 + y**2 + +z = 3. + 4j +grad(f)(z) +``` + +For general $\mathbb{C} \to \mathbb{C}$ functions, the Jacobian has 4 real-valued degrees of freedom (as in the 2x2 Jacobian matrices above), so we can't hope to represent all of them within a complex number. But we can for holomorphic functions! A holomorphic function is precisely a $\mathbb{C} \to \mathbb{C}$ function with the special property that its derivative can be represented as a single complex number. (The [Cauchy-Riemann equations](https://en.wikipedia.org/wiki/Cauchy%E2%80%93Riemann_equations) ensure that the above 2x2 Jacobians have the special form of a scale-and-rotate matrix in the complex plane, i.e. the action of a single complex number under multiplication.) And we can reveal that one complex number using a single call to `vjp` with a covector of `1.0`. + +Because this only works for holomorphic functions, to use this trick we need to promise JAX that our function is holomorphic; otherwise, JAX will raise an error when {func}`jax.grad` is used for a complex-output function: + +```{code-cell} +def f(z): + return jnp.sin(z) + +z = 3. + 4j +grad(f, holomorphic=True)(z) +``` + +All the `holomorphic=True` promise does is disable the error when the output is complex-valued. We can still write `holomorphic=True` when the function isn't holomorphic, but the answer we get out won't represent the full Jacobian. Instead, it'll be the Jacobian of the function where we just discard the imaginary part of the output: + +```{code-cell} +def f(z): + return jnp.conjugate(z) + +z = 3. + 4j +grad(f, holomorphic=True)(z) # f is not actually holomorphic! +``` + +There are some useful upshots for how {func}`jax.grad` works here: + +1. We can use {func}`jax.grad` on holomorphic $\mathbb{C} \to \mathbb{C}$ functions. +2. We can use {func}`jax.grad` to optimize $f : \mathbb{C} \to \mathbb{R}$ functions, like real-valued loss functions of complex parameters `x`, by taking steps in the direction of the conjugate of `grad(f)(x)`. +3. If we have an $\mathbb{R} \to \mathbb{R}$ function that just happens to use some complex-valued operations internally (some of which must be non-holomorphic, e.g. FFTs used in convolutions) then {func}`jax.grad` still works and we get the same result that an implementation using only real values would have given. + +In any case, JVPs and VJPs are always unambiguous. And if we wanted to compute the full Jacobian matrix of a non-holomorphic $\mathbb{C} \to \mathbb{C}$ function, we can do it with JVPs or VJPs! + + +You should expect complex numbers to work everywhere in JAX. Here's differentiating through a Cholesky decomposition of a complex matrix: + +```{code-cell} +A = jnp.array([[5., 2.+3j, 5j], + [2.-3j, 7., 1.+7j], + [-5j, 1.-7j, 12.]]) + +def f(X): + L = jnp.linalg.cholesky(X) + return jnp.sum((L - jnp.sin(L))**2) + +grad(f, holomorphic=True)(A) +``` diff --git a/docs/concurrency.rst b/docs/concurrency.rst index 61b2b03fcb34..fcf47305d181 100644 --- a/docs/concurrency.rst +++ b/docs/concurrency.rst @@ -12,3 +12,11 @@ tracing (e.g., :func:`~jax.jit`) from multiple threads, you must not use threading to manipulate JAX values inside the implementation of the function `f` that is passed to :func:`~jax.jit`. The most likely outcome if you do this is a mysterious error from JAX. + +In multi-controller JAX, different processes must apply the same JAX operations +in the same order on a given device. If you are using threads with +multi-controller JAX, you can use the :func:`~jax.thread_guard` context manager +to detect cases where threads may schedule operations in different orders in +different processes, leading to non-deterministic crashes. When the thread guard +is set, an error will be raised at runtime if a JAX operation is called from a +thread other than the one in which the thread guard was set. diff --git a/docs/conf.py b/docs/conf.py index 45964b6d8d7e..ee7cafdf2aaa 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -29,6 +29,7 @@ import inspect import operator import os +from pathlib import Path import sys sys.path.insert(0, os.path.abspath('..')) @@ -38,11 +39,11 @@ from typing import ForwardRef def _do_not_evaluate_in_jax( - self, globalns, *args, _evaluate=ForwardRef._evaluate, + self, globalns, *args, _evaluate=ForwardRef._evaluate, **kwargs, ): if globalns.get('__name__', '').startswith('jax'): return self - return _evaluate(self, globalns, *args) + return _evaluate(self, globalns, *args, **kwargs) ForwardRef._evaluate = _do_not_evaluate_in_jax @@ -80,15 +81,19 @@ def _do_not_evaluate_in_jax( "sphinx_remove_toctrees", 'sphinx_copybutton', 'jax_extensions', + 'jax_list_config_options', 'sphinx_design', 'sphinxext.rediraffe', + 'source_include', + 'sphinxcontrib.mermaid' ] intersphinx_mapping = { 'array_api': ('https://data-apis.org/array-api/2023.12/', None), 'python': ('https://docs.python.org/3/', None), 'numpy': ('https://numpy.org/doc/stable/', None), - 'scipy': ('https://docs.scipy.org/doc/scipy/reference/', None), + # TODO(phawkins,jakevdp): revert to stable scipy docs when it is up again. + 'scipy': ('https://scipy.github.io/devdocs/', None), } suppress_warnings = [ @@ -132,13 +137,17 @@ def _do_not_evaluate_in_jax( # These are kept in sync using the jupytext pre-commit hook. 'notebooks/*.md', 'pallas/quickstart.md', + 'pallas/pipelining.md', + 'pallas/gpu/pipelining.md', 'pallas/tpu/pipelining.md', 'pallas/tpu/distributed.md', 'pallas/tpu/sparse.md', 'pallas/tpu/matmul.md', + 'pallas/tpu/core_map.md', 'jep/9407-type-promotion.md', 'autodidax.md', 'autodidax2_part1.md', + 'array_refs.md', 'sharded-computation.md', 'ffi.ipynb', ] @@ -203,6 +212,8 @@ def _do_not_evaluate_in_jax( # -- Options for myst ---------------------------------------------- myst_heading_anchors = 3 # auto-generate 3 levels of heading anchors myst_enable_extensions = ['dollarmath'] +myst_ref_domains = ["py"] +myst_all_links_external = False nb_execution_mode = "force" nb_execution_allow_errors = False nb_merge_streams = True @@ -222,18 +233,22 @@ def _do_not_evaluate_in_jax( 'jep/9407-type-promotion.*', # TODO(jakevdp): enable execution on the following if possible: 'notebooks/Distributed_arrays_and_automatic_parallelization.*', - 'notebooks/explicit-sharding.*', 'notebooks/autodiff_remat.*', + # Example only gives the specific output demonstrated on some platforms + 'notebooks/layout.*', # Fails on readthedocs with Kernel Died 'notebooks/convolutions.ipynb', # Requires accelerators 'pallas/quickstart.*', + 'pallas/pipelining.*', + 'pallas/gpu/pipelining.*', 'pallas/tpu/pipelining.*', 'pallas/tpu/distributed.*', 'pallas/tpu/sparse.*', 'pallas/tpu/matmul.*', - 'sharded-computation.*', - 'distributed_data_loading.*' + 'pallas/tpu/core_map.*', + 'distributed_data_loading.*', + 'notebooks/host-offloading.*', ] # -- Options for HTMLHelp output --------------------------------------------- @@ -352,25 +367,34 @@ def linkcode_resolve(domain, info): source, linenum = inspect.getsourcelines(obj) except: return None - filename = os.path.relpath(filename, start=os.path.dirname(jax.__file__)) + try: + filename = Path(filename).relative_to(Path(jax.__file__).parent) + except ValueError: + # Source file is not a relative to jax; this must be a re-exported function. + return None lines = f"#L{linenum}-L{linenum + len(source)}" if linenum else "" return f"https://github.com/jax-ml/jax/blob/main/jax/{filename}{lines}" # Generate redirects from deleted files to new sources rediraffe_redirects = { - 'notebooks/quickstart.md': 'quickstart.md', - 'jax-101/01-jax-basics.md': 'key-concepts.md', - 'jax-101/02-jitting.md': 'jit-compilation.md', - 'jax-101/03-vectorization.md': 'automatic-vectorization.md', - 'jax-101/04-advanced-autodiff.md': 'automatic-differentiation.md', - 'jax-101/05-random-numbers.md': 'random-numbers.md', - 'jax-101/05.1-pytrees.md': 'working-with-pytrees.md', - 'jax-101/06-parallelism.md': 'sharded-computation.md', - 'jax-101/07-state.md': 'stateful-computations.md', - 'jax-101/08-pjit.rst': 'sharded-computation.md', - 'jax-101/index.rst': 'tutorials.rst', - 'notebooks/external_callbacks.md': 'external-callbacks.md', - 'notebooks/How_JAX_primitives_work.md': 'jax-primitives.md', - 'jax.extend.ffi.rst': 'jax.ffi.rst', - 'Custom_Operation_for_GPUs.md': 'ffi.md', + "jax-101/01-jax-basics.md": "key-concepts.md", + "jax-101/02-jitting.md": "jit-compilation.md", + "jax-101/03-vectorization.md": "automatic-vectorization.md", + "jax-101/04-advanced-autodiff.md": "automatic-differentiation.md", + "jax-101/05-random-numbers.md": "random-numbers.md", + "jax-101/05.1-pytrees.md": "pytrees.md", + "jax-101/06-parallelism.md": "sharded-computation.md", + "jax-101/07-state.md": "stateful-computations.md", + "jax-101/08-pjit.rst": "sharded-computation.md", + "jax-101/index.rst": "jax-101.rst", + "tutorials.rst": "jax-101.rst", + "notebooks/external_callbacks.md": "external-callbacks.md", + "notebooks/How_JAX_primitives_work.md": "jax-primitives.md", + "jax.extend.ffi.rst": "jax.ffi.rst", + "Custom_Operation_for_GPUs.md": "ffi.md", + "notebooks/quickstart.md": "quickstart.md", + "quickstart.md": "notebooks/thinking_in_jax.md", + "advanced_guide.rst": "advanced_guides.rst", + "user_guides.rst": "advanced_guides.rst", + "working_with_pytrees.md": "pytrees.md", } diff --git a/docs/config_options.rst b/docs/config_options.rst new file mode 100644 index 000000000000..a8ef4e93a834 --- /dev/null +++ b/docs/config_options.rst @@ -0,0 +1,66 @@ +.. _jax: + +.. This target is required to prevent the Sphinx build error "Unknown target name: jax". +.. The custom directive list_config_options imports JAX to extract real configuration +.. data, which causes Sphinx to look for a target named "jax". This dummy target +.. satisfies that requirement while allowing the actual JAX import to work. + +Configuration Options +===================== + +JAX provides various configuration options to customize its behavior. These options control everything from numerical precision to debugging features. + +How to Use Configuration Options +-------------------------------- + +JAX configuration options can be set in several ways: + +1. **Environment variables** (set before running your program): + + .. code-block:: bash + + export JAX_ENABLE_X64=True + python my_program.py + +2. **Runtime configuration** (in your Python code): + + .. code-block:: python + + import jax + jax.config.update("jax_enable_x64", True) + +3. **Command-line flags** (using Abseil): + + .. code-block:: python + + # In your code: + import jax + jax.config.parse_flags_with_absl() + + .. code-block:: bash + + # When running: + python my_program.py --jax_enable_x64=True + +Common Configuration Options +---------------------------- + +Here are some of the most frequently used configuration options: + +- ``jax_enable_x64`` -- Enable 64-bit floating-point precision +- ``jax_disable_jit`` -- Disable JIT compilation for debugging +- ``jax_debug_nans`` -- Check for and raise errors on NaNs +- ``jax_platforms`` -- Control which backends (CPU/GPU/TPU) JAX will initialize +- ``jax_numpy_rank_promotion`` -- Control automatic rank promotion behavior +- ``jax_default_matmul_precision`` -- Set default precision for matrix multiplication operations + +.. raw:: html + +
+ +All Configuration Options +------------------------- + +Below is a complete list of all available JAX configuration options: + +.. list_config_options:: diff --git a/docs/contributing.md b/docs/contributing.md index 99d78453c436..40334bb9599a 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -6,8 +6,8 @@ Everyone can contribute to JAX, and we value everyone's contributions. There are ways to contribute, including: - Answering questions on JAX's [discussions page](https://github.com/jax-ml/jax/discussions) -- Improving or expanding JAX's [documentation](http://jax.readthedocs.io/) -- Contributing to JAX's [code-base](http://github.com/jax-ml/jax/) +- Improving or expanding JAX's [documentation](https://docs.jax.dev) +- Contributing to JAX's [code-base](https://github.com/jax-ml/jax) - Contributing in any of the above ways to the broader ecosystem of [libraries built on JAX](https://github.com/jax-ml/jax#neural-network-libraries) The JAX project follows [Google's Open Source Community Guidelines](https://opensource.google/conduct/). @@ -23,6 +23,22 @@ For other proposals, we ask that you first open a GitHub [Discussion](https://github.com/jax-ml/jax/discussions) to seek feedback on your planned contribution. +## Can I contribute AI generated code? + +All submissions to Google Open Source projects need to follow Google's [Contributor License +Agreement (CLA)](https://cla.developers.google.com/), in which contributors agree that their +contribution is an original work of authorship. This doesn’t prohibit the use of coding +assistance tools, but what’s submitted does need to be a contributor's original creation. + +In the JAX project, a main concern with AI-generated contributions is that +**low-quality AI-generated code imposes a disproportionate review cost**. +Since the team's capacity for code review is limited, we have a higher bar +for accepting AI-generated contributions compared to those written by a human. + +A loose rule of thumb: if the team needs to spend more time reviewing a +contribution than the contributor spends generating it, then the contribution +is probably not helpful to the project, and we will likely reject it. + ## Contributing code using pull requests We do all of our development using git, so basic knowledge is assumed. @@ -30,13 +46,13 @@ We do all of our development using git, so basic knowledge is assumed. Follow these steps to contribute code: 1. Sign the [Google Contributor License Agreement (CLA)](https://cla.developers.google.com/). - For more information, see the Pull Request Checklist below. + For more information, see the {ref}`pr-checklist` below. 2. Fork the JAX repository by clicking the **Fork** button on the - [repository page](http://www.github.com/jax-ml/jax). This creates + [repository page](https://github.com/jax-ml/jax). This creates a copy of the JAX repository in your own account. -3. Install Python >= 3.10 locally in order to run tests. +3. Install Python >= 3.11 locally in order to run tests. 4. `pip` installing your fork from source. This allows you to modify the code and immediately test it out: @@ -52,7 +68,7 @@ Follow these steps to contribute code: changes. ```bash - git remote add upstream https://www.github.com/jax-ml/jax + git remote add upstream https://github.com/jax-ml/jax.git ``` 6. Create a branch where you will develop from: @@ -81,6 +97,12 @@ Follow these steps to contribute code: pytest -n auto tests/ ``` + Run them in 64-bit mode as well, by setting the environment variable `JAX_ENABLE_X64=True`: + + ```bash + JAX_ENABLE_X64=True pytest -n auto tests/ + ``` + JAX's test suite is quite large, so if you know the specific test file that covers your changes, you can limit the tests to that; for example: @@ -192,3 +214,109 @@ not available via standard GitHub CI. Detailed results of these tests are not pu viewable, but the JAX maintainer assigned to your PR will communicate with you regarding any failures these might uncover; it's not uncommon, for example, that numerical tests need different tolerances on TPU than on CPU. + +### Wheel sources update + +If a new python package or a new file is added to the wheel, one of the +following Bazel targets should be updated: + +[jax wheel sources](https://github.com/jax-ml/jax/blob/0080c5c934d9e3668c93d88e2b94f66e05f9d8d8/BUILD.bazel#L29) + +[jaxlib wheel sources](https://github.com/jax-ml/jax/blob/0080c5c934d9e3668c93d88e2b94f66e05f9d8d8/jaxlib/tools/BUILD.bazel#L151) + +[jax CUDA plugin wheel sources](https://github.com/jax-ml/jax/blob/0080c5c934d9e3668c93d88e2b94f66e05f9d8d8/jaxlib/tools/BUILD.bazel#L210) + +[jax CUDA pjrt wheel sources](https://github.com/jax-ml/jax/blob/0080c5c934d9e3668c93d88e2b94f66e05f9d8d8/jaxlib/tools/BUILD.bazel#L318) + +1. A static source addition: add to `static_srcs` list. + + Example: add `//:file.txt` to `jax` wheel. + + ``` + wheel_sources( + name = "jax_sources", + data_srcs = [...], + py_srcs = [...], + static_srcs = [ + ... + "//:file.txt" + ], + ) + ``` + +2. A platform-dependent source addition: add to `data_srcs` list. + + Example: add a `cc_library` target `//:cc_target` to `jax` wheel. + + ``` + wheel_sources( + name = "jax_sources", + data_srcs = [ + ... + "//:cc_target" + ], + py_srcs = [...], + static_srcs = [...], + ) + ``` + + If the existing targets in `data_srcs` already have a transitive + dependency on `//:cc_target`, you don't need to add it explicitly. + +3. A new python package addition: create `__init__.py` file and Bazel python +rule target with `__init__.py` in sources, add it to `py_srcs` list. + + Example: add a new package `jax.test_package` to `jax` wheel: + + The content of the file `jax/test_package/BUILD`: + + ``` + pytype_strict_library( + name = "init", + srcs = ["__init__.py"], + visibility = ["//visibility:public"], + ) + ``` + + ``` + wheel_sources( + name = "jax_sources", + data_srcs = [...], + py_srcs = [ + ... + "//jax/test_package:init", + ], + static_srcs = [...], + ) + ``` + +4. A new python source addition to existing package: create/update Bazel python +rule target with the new file in sources, add it to `py_srcs` list. + + Example: add a new file `jax/test_package/example.py` to `jax` wheel: + + The content of the file `jax/test_package/BUILD`: + + ``` + pytype_strict_library( + name = "example", + srcs = ["__init__.py", + "example.py"], + visibility = ["//visibility:public"], + ) + ``` + + ``` + wheel_sources( + name = "jax_sources", + data_srcs = [...], + py_srcs = [ + ... + "//jax/test_package:example", + ], + static_srcs = [...], + ) + ``` + + If the existing targets in `py_srcs` already have a transitive + dependency on `example.py`, you don't need to add it explicitly. diff --git a/docs/contributor_guide.rst b/docs/contributor_guide.rst index 81e1f5c99135..d5b4a6eea691 100644 --- a/docs/contributor_guide.rst +++ b/docs/contributor_guide.rst @@ -26,3 +26,4 @@ some of JAX's (extensible) internals. autodidax autodidax2_part1 jep/index + internals/index diff --git a/docs/control-flow.md b/docs/control-flow.md index 7cb959f3e434..8f59bd92add7 100644 --- a/docs/control-flow.md +++ b/docs/control-flow.md @@ -244,19 +244,19 @@ lax.cond(False, lambda x: x+1, lambda x: x-1, operand) `jax.lax` provides two other functions that allow branching on dynamic predicates: -- [`lax.select`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html) is +- [`lax.select`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.select.html) is like a batched version of `lax.cond`, with the choices expressed as pre-computed arrays rather than as functions. -- [`lax.switch`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.switch.html) is +- [`lax.switch`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.switch.html) is like `lax.cond`, but allows switching between any number of callable choices. In addition, `jax.numpy` provides several numpy-style interfaces to these functions: -- [`jnp.where`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.where.html) with +- [`jnp.where`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.where.html) with three arguments is the numpy-style wrapper of `lax.select`. -- [`jnp.piecewise`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.piecewise.html) +- [`jnp.piecewise`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.piecewise.html) is a numpy-style wrapper of `lax.switch`, but switches on a list of boolean conditions rather than a single scalar index. -- [`jnp.select`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.select.html) has +- [`jnp.select`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.select.html) has an API similar to `jnp.piecewise`, but the choices are given as pre-computed arrays rather than as functions. It is implemented in terms of multiple calls to `lax.select`. diff --git a/docs/custom_pytrees.md b/docs/custom_pytrees.md new file mode 100644 index 000000000000..1446b4004d32 --- /dev/null +++ b/docs/custom_pytrees.md @@ -0,0 +1,335 @@ +--- +jupytext: + formats: md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.16.4 +kernelspec: + display_name: Python 3 + language: python + name: python3 +--- + +(pytrees-custom-pytree-nodes)= +# Custom pytree nodes + +This section explains how in JAX you can extend the set of Python types that will be considered _internal nodes_ in pytrees (pytree nodes) by using {func}`jax.tree_util.register_pytree_node` with {func}`jax.tree.map`. + +Why would you need this? In the previous examples, pytrees were shown as lists, tuples, and dicts, with everything else as pytree leaves. This is because if you define your own container class, it will be considered to be a pytree leaf unless you _register_ it with JAX. This is also the case even if your container class has trees inside it. For example: + +```{code-cell} +import jax + +class Special(object): + def __init__(self, x, y): + self.x = x + self.y = y + +jax.tree.leaves([ + Special(0, 1), + Special(2, 4), +]) +``` + +Accordingly, if you try to use a {func}`jax.tree.map` expecting the leaves to be elements inside the container, you will get an error: + +```{code-cell} +:tags: [raises-exception] + +jax.tree.map(lambda x: x + 1, + [ + Special(0, 1), + Special(2, 4) + ]) +``` + +As a solution, JAX allows to extend the set of types to be considered internal pytree nodes through a global registry of types. Additionally, the values of registered types are traversed recursively. + +First, register a new type using {func}`jax.tree_util.register_pytree_node`: + +```{code-cell} +from jax.tree_util import register_pytree_node + +class RegisteredSpecial(Special): + def __repr__(self): + return "RegisteredSpecial(x={}, y={})".format(self.x, self.y) + +def special_flatten(v): + """Specifies a flattening recipe. + + Params: + v: The value of the registered type to flatten. + Returns: + A pair of an iterable with the children to be flattened recursively, + and some opaque auxiliary data to pass back to the unflattening recipe. + The auxiliary data is stored in the treedef for use during unflattening. + The auxiliary data could be used, for example, for dictionary keys. + """ + children = (v.x, v.y) + aux_data = None + return (children, aux_data) + +def special_unflatten(aux_data, children): + """Specifies an unflattening recipe. + + Params: + aux_data: The opaque data that was specified during flattening of the + current tree definition. + children: The unflattened children + + Returns: + A reconstructed object of the registered type, using the specified + children and auxiliary data. + """ + return RegisteredSpecial(*children) + +# Global registration +register_pytree_node( + RegisteredSpecial, + special_flatten, # Instruct JAX what are the children nodes. + special_unflatten # Instruct JAX how to pack back into a `RegisteredSpecial`. +) +``` + +Now you can traverse the special container structure: + +```{code-cell} +jax.tree.map(lambda x: x + 1, + [ + RegisteredSpecial(0, 1), + RegisteredSpecial(2, 4), + ]) +``` + +Alternatively, you can define appropriate `tree_flatten` and `tree_unflatten` methods +on your class and decorate it with {func}`~jax.tree_util.register_pytree_node_class`: + +```{code-cell} +from jax.tree_util import register_pytree_node_class + +@register_pytree_node_class +class RegisteredSpecial2(Special): + def __repr__(self): + return "RegisteredSpecial2(x={}, y={})".format(self.x, self.y) + + def tree_flatten(self): + children = (self.x, self.y) + aux_data = None + return (children, aux_data) + + @classmethod + def tree_unflatten(cls, aux_data, children): + return cls(*children) + + +def show_example(structured): + flat, tree = structured.tree_flatten() + unflattened = RegisteredSpecial2.tree_unflatten(tree, flat) + print(f"{structured=}\n {flat=}\n {tree=}\n {unflattened=}") + + +show_example(RegisteredSpecial2(1., 2.)) +``` + +Modern Python comes equipped with helpful tools to make defining containers easier. Some will work with JAX out-of-the-box, but others require more care. + +For instance, a Python `NamedTuple` subclass doesn't need to be registered to be considered a pytree node type: + +```{code-cell} +from typing import NamedTuple, Any + +class MyOtherContainer(NamedTuple): + name: str + a: Any + b: Any + c: Any + +# NamedTuple subclasses are handled as pytree nodes, so +# this will work out-of-the-box. +jax.tree.leaves([ + MyOtherContainer('Alice', 1, 2, 3), + MyOtherContainer('Bob', 4, 5, 6) +]) +``` + +Notice that the `name` field now appears as a leaf, because all tuple elements are children. This is what happens when you don't have to register the class the hard way. + +When defining unflattening functions, in general `children` should contain all the +dynamic elements of the data structure (arrays, dynamic scalars, and pytrees), while +`aux_data` should contain all the static elements that will be rolled into the `treedef` +structure. JAX sometimes needs to compare `treedef` for equality, or compute its hash +for use in the JIT cache, and so care must be taken to ensure that the auxiliary data +specified in the flattening recipe supports meaningful hashing and equality comparisons. + +Unlike `NamedTuple` subclasses, classes decorated with `@dataclass` are not automatically pytrees. However, they can be registered as pytrees using the {func}`jax.tree_util.register_dataclass` decorator: + +```{code-cell} +from dataclasses import dataclass +import jax.numpy as jnp +import numpy as np +import functools + +@functools.partial(jax.tree_util.register_dataclass, + data_fields=['a', 'b', 'c'], + meta_fields=['name']) +@dataclass +class MyDataclassContainer(object): + name: str + a: Any + b: Any + c: Any + +# MyDataclassContainer is now a pytree node. +jax.tree.leaves([ + MyDataclassContainer('apple', 5.3, 1.2, jnp.zeros([4])), + MyDataclassContainer('banana', np.array([3, 4]), -1., 0.) +]) +``` + +Notice that the `name` field does not appear as a leaf. This is because we included it in the `meta_fields` argument to {func}`jax.tree_util.register_dataclass`, indicating that it should be treated as metadata/auxiliary data, just like `aux_data` in `RegisteredSpecial` above. Now instances of `MyDataclassContainer` can be passed into JIT-ed functions, and `name` will be treated as static (see {ref}`jit-marking-arguments-as-static` for more information on static args): + +```{code-cell} +@jax.jit +def f(x: MyDataclassContainer | MyOtherContainer): + return x.a + x.b + +# Works fine! `mdc.name` is static. +mdc = MyDataclassContainer('mdc', 1, 2, 3) +y = f(mdc) +``` + +Contrast this with `MyOtherContainer`, the `NamedTuple` subclass. Since the `name` field is a pytree leaf, JIT expects it to be convertible to {class}`jax.Array`, and the following raises an error: + +```{code-cell} +:tags: [raises-exception] + +moc = MyOtherContainer('moc', 1, 2, 3) +y = f(moc) +``` + +The whole set of functions for operating on pytrees are in {mod}`jax.tree_util`. + +## Custom pytrees and initialization with unexpected values + +Another common gotcha with user-defined pytree objects is that JAX transformations occasionally initialize them with unexpected values, so that any input validation done at initialization may fail. For example: + +```{code-cell} +:tags: [raises-exception] + +class MyTree: + def __init__(self, a): + self.a = jnp.asarray(a) + +register_pytree_node(MyTree, lambda tree: ((tree.a,), None), + lambda _, args: MyTree(*args)) + +tree = MyTree(jnp.arange(5.0)) + +jax.vmap(lambda x: x)(tree) # Error because object() is passed to `MyTree`. +``` + +```{code-cell} +:tags: [raises-exception] + +jax.jacobian(lambda x: x)(tree) # Error because MyTree(...) is passed to `MyTree`. +``` + +- In the first case with `jax.vmap(...)(tree)`, JAX’s internals use arrays of `object()` values to infer the structure of the tree +- In the second case with `jax.jacobian(...)(tree)`, the Jacobian of a function mapping a tree to a tree is defined as a tree of trees. + +**Potential solution 1:** + +- The `__init__` and `__new__` methods of custom pytree classes should generally avoid doing any array conversion or other input validation, or else anticipate and handle these special cases. For example: + +```{code-cell} +class MyTree: + def __init__(self, a): + if not (type(a) is object or a is None or isinstance(a, MyTree)): + a = jnp.asarray(a) + self.a = a +``` + +**Potential solution 2:** + +- Structure your custom `tree_unflatten` function so that it avoids calling `__init__`. If you choose this route, make sure that your `tree_unflatten` function stays in sync with `__init__` if and when the code is updated. Example: + +```{code-cell} +def tree_unflatten(aux_data, children): + del aux_data # Unused in this class. + obj = object.__new__(MyTree) + obj.a = children[0] + return obj +``` + +## Internal pytree handling + +JAX flattens pytrees into lists of leaves at the `api.py` boundary (and also +in control flow primitives). This keeps downstream JAX internals simpler: +transformations like {func}`~jax.grad`, {func}`~jax.jit`, and {func}`~jax.vmap` +can handle user functions that accept and return the myriad different Python +containers, while all the other parts of the system can operate on functions +that only take (multiple) array arguments and always return a flat list of arrays. + +When JAX flattens a pytree it will produce a list of leaves and a `treedef` +object that encodes the structure of the original value. The `treedef` can +then be used to construct a matching structured value after transforming the +leaves. Pytrees are tree-like, rather than DAG-like or graph-like, in that we +handle them assuming referential transparency and that they can't contain +reference cycles. + +Here is a simple example: + +```{code-cell} +:tags: [remove-cell] + +# Execute this to consume & hide the GPU warning. +import jax.numpy as _jnp +_jnp.arange(10) +``` + +```{code-cell} +from jax.tree_util import tree_flatten, tree_unflatten +import jax.numpy as jnp + +# The structured value to be transformed +value_structured = [1., (2., 3.)] + +# The leaves in value_flat correspond to the `*` markers in value_tree +value_flat, value_tree = tree_flatten(value_structured) +print(f"{value_flat=}\n{value_tree=}") + +# Transform the flat value list using an element-wise numeric transformer +transformed_flat = list(map(lambda v: v * 2., value_flat)) +print(f"{transformed_flat=}") + +# Reconstruct the structured output, using the original +transformed_structured = tree_unflatten(value_tree, transformed_flat) +print(f"{transformed_structured=}") +``` + +By default, pytree containers can be lists, tuples, dicts, namedtuple, None, +OrderedDict. Other types of values, including numeric and ndarray values, are +treated as leaves: + +```{code-cell} +from collections import namedtuple +Point = namedtuple('Point', ['x', 'y']) + +example_containers = [ + (1., [2., 3.]), + (1., {'b': 2., 'a': 3.}), + 1., + None, + jnp.zeros(2), + Point(1., 2.) +] +def show_example(structured): + flat, tree = tree_flatten(structured) + unflattened = tree_unflatten(tree, flat) + print(f"{structured=}\n {flat=}\n {tree=}\n {unflattened=}") + +for structured in example_containers: + show_example(structured) +``` diff --git a/docs/debugging.md b/docs/debugging.md index d07f42da5c85..9aa646f1ecb3 100644 --- a/docs/debugging.md +++ b/docs/debugging.md @@ -17,9 +17,20 @@ kernelspec: -This section introduces you to a set of built-in JAX debugging methods — {func}`jax.debug.print`, {func}`jax.debug.breakpoint`, and {func}`jax.debug.callback` — that you can use with various JAX transformations. +Do you have exploding gradients? Are NaNs making you gnash your teeth? Just want +to poke around the intermediate values in your computation? This section +introduces you to a set of built-in JAX debugging methods that you can use with +various JAX transformations. -Let's begin with {func}`jax.debug.print`. +**Summary:** + +- Use {func}`jax.debug.print` to print values to stdout in `jax.jit`-,`jax.pmap`-, and `pjit`-decorated functions, + and {func}`jax.debug.breakpoint` to pause execution of your compiled function to inspect values in the call stack. +- {mod}`jax.experimental.checkify` lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX + code. +- JAX offers config flags and context managers that enable catching errors more easily. For example, enable the + `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code and enable the + `jax_disable_jit` flag to disable JIT-compilation. ## `jax.debug.print` for simple inspection @@ -110,7 +121,6 @@ f(1, 2) To learn more about {func}`jax.debug.print` and its Sharp Bits, refer to {ref}`advanced-debugging`. - ## `jax.debug.breakpoint` for `pdb`-like debugging **Summary:** Use {func}`jax.debug.breakpoint` to pause the execution of your JAX program to inspect values. @@ -129,6 +139,7 @@ def f(x): y, z = jnp.sin(x), jnp.cos(x) jax.debug.breakpoint() return y * z + f(2.) # ==> Pauses during execution ``` @@ -200,6 +211,67 @@ This can make {func}`jax.debug.callback` useful for general-purpose debugging. You can learn more about {func}`jax.debug.callback` and other kinds of JAX callbacks in {ref}`external-callbacks`. +Read more in [](debugging/print_breakpoint). + +## Functional error checks with `jax.experimental.checkify` + +**Summary:** Checkify lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code: + +```python +from jax.experimental import checkify +import jax +import jax.numpy as jnp + +def f(x, i): + checkify.check(i >= 0, "index needs to be non-negative!") + y = x[i] + z = jnp.sin(y) + return z + +jittable_f = checkify.checkify(f) + +err, z = jax.jit(jittable_f)(jnp.ones((5,)), -1) +print(err.get()) +# >> index needs to be non-negative! (check failed at <...>:6 (f)) +``` + +You can also use checkify to automatically add common checks: + +```python +errors = checkify.user_checks | checkify.index_checks | checkify.float_checks +checked_f = checkify.checkify(f, errors=errors) + +err, z = checked_f(jnp.ones((5,)), 100) +err.throw() +# ValueError: out-of-bounds indexing at <..>:7 (f) + +err, z = checked_f(jnp.ones((5,)), -1) +err.throw() +# ValueError: index needs to be non-negative! (check failed at <…>:6 (f)) + +err, z = checked_f(jnp.array([jnp.inf, 1]), 0) +err.throw() +# ValueError: nan generated by primitive sin at <...>:8 (f) +``` + +Read more in [](debugging/checkify_guide). + +### Throwing Python errors with JAX's debug flags + +**Summary:** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code) and enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`. + +```python +import jax +jax.config.update("jax_debug_nans", True) + +def f(x, y): + return x / y + +jax.jit(f)(0., 0.) # ==> raises FloatingPointError exception! +``` + +Read more in [](debugging/flags). + ## Next steps Check out the {ref}`advanced-debugging` to learn more about debugging in JAX. diff --git a/docs/debugging/flags.md b/docs/debugging/flags.md index 13e34a6c3ac4..a879fb69e16e 100644 --- a/docs/debugging/flags.md +++ b/docs/debugging/flags.md @@ -1,31 +1,78 @@ -# JAX debugging flags - - +--- +jupytext: + formats: md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.16.4 +kernelspec: + display_name: Python 3 + language: python + name: python3 +--- + +(debugging-flags)= +# JAX debugging flags + + JAX offers flags and context managers that enable catching errors more easily. - + ## `jax_debug_nans` configuration option and context manager -**Summary:** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code). +**Summary:** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code. + +`jax_debug_nans` is a JAX flag that when enabled, will cause computations to error-out immediately on production of a NaN. Switching this option on adds a NaN check to every floating point type value produced by XLA. That means values are pulled back to the host and checked as ndarrays for every primitive operation not under an `@jax.jit`. + +For code under an `@jax.jit`, the output of every `@jax.jit` function is checked and if a NaN is present it will re-run the function in de-optimized op-by-op mode, effectively removing one level of `@jax.jit` at a time. + +There could be tricky situations that arise, like NaNs that only occur under a `@jax.jit` but don't get produced in de-optimized mode. In that case you'll see a warning message print out but your code will continue to execute. -`jax_debug_nans` is a JAX flag that when enabled, automatically raises an error when a NaN is detected. It has special handling for JIT-compiled -- when a NaN output is detected from a JIT-ted function, the function is re-run eagerly (i.e. without compilation) and will throw an error at the specific primitive that produced the NaN. +If the NaNs are being produced in the backward pass of a gradient evaluation, when an exception is raised several frames up in the stack trace you will be in the backward_pass function, which is essentially a simple jaxpr interpreter that walks the sequence of primitive operations in reverse. ### Usage -If you want to trace where NaNs are occurring in your functions or gradients, you can turn on the NaN-checker by: +If you want to trace where NaNs are occurring in your functions or gradients, you can turn on the NaN-checker by doing one of: +* running your code inside the `jax.debug_nans` context manager, using `with jax.debug_nans(True):`; * setting the `JAX_DEBUG_NANS=True` environment variable; * adding `jax.config.update("jax_debug_nans", True)` near the top of your main file; * adding `jax.config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`; ### Example(s) -```python +```{code-cell} import jax +import jax.numpy as jnp +import traceback jax.config.update("jax_debug_nans", True) -def f(x, y): - return x / y -jax.jit(f)(0., 0.) # ==> raises FloatingPointError exception! +def f(x): + w = 3 * jnp.square(x) + return jnp.log(-w) + +# The stack trace is very long so only print a couple lines. +try: + f(5.) +except FloatingPointError as e: + print(traceback.format_exc(limit=2)) +``` + +The NaN generated was caught. By running `%debug`, we can get a post-mortem debugger. This also works with functions under `@jax.jit`, as the example below shows. + +```{code-cell} +:tags: [raises-exception] + +jax.jit(f)(5.) +``` + +When this code sees a NaN in the output of an `@jax.jit` function, it calls into the de-optimized code, so we still get a clear stack trace. And we can run a post-mortem debugger with `%debug` to inspect all the values to figure out the error. + +The `jax.debug_nans` context manager can be used to activate/deactivate NaN debugging. Since we activated it above with `jax.config.update`, let's deactivate it: + +```{code-cell} +with jax.debug_nans(False): + print(jax.jit(f)(5.)) ``` #### Strengths and limitations of `jax_debug_nans` @@ -35,10 +82,16 @@ jax.jit(f)(0., 0.) # ==> raises FloatingPointError exception! * Throws a standard Python exception and is compatible with PDB postmortem ##### Limitations -* Not compatible with `jax.pmap` or `jax.pjit` -* Re-running functions eagerly can be slow +* Re-running functions eagerly can be slow. You shouldn't have the NaN-checker on if you're not debugging, as it can introduce lots of device-host round-trips and performance regressions. * Errors on false positives (e.g. intentionally created NaNs) +## `jax_debug_infs` configuration option and context manager + +`jax_debug_infs` works similarly to `jax_debug_nans`. `jax_debug_infs` often needs to be combined with `jax_disable_jit`, since Infs might not cascade to the output like NaNs. Alternatively, `jax.experimental.checkify` may be used to find Infs in intermediates. + +Full documentation of `jax_debug_infs` is forthcoming. + + ## `jax_disable_jit` configuration option and context manager **Summary:** Enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb` @@ -74,5 +127,4 @@ jax.jit(f)(-2.) # ==> Enters PDB breakpoint! * Throws standard Python exceptions and is compatible with PDB postmortem ##### Limitations -* Not compatible with `jax.pmap` or `jax.pjit` * Running functions without JIT-compilation can be slow diff --git a/docs/debugging/index.md b/docs/debugging/index.md index bcf561d06807..724d29af34b6 100644 --- a/docs/debugging/index.md +++ b/docs/debugging/index.md @@ -8,7 +8,8 @@ Table of contents: * [Interactive inspection with `jax.debug`](print_breakpoint) * [Functional error checks with jax.experimental.checkify](checkify_guide) -* [Throwing Python errors with JAX’s debug flags](flags) +* [Throwing Python errors with JAX’s debug flags](./flags) +* [Attaching XLA metadata with `set_xla_metadata`](xla_metadata) ## Interactive inspection with `jax.debug` @@ -85,7 +86,7 @@ Complete guide [here](checkify_guide) ## Throwing Python errors with JAX's debug flags -Complete guide [here](flags) +Complete guide [here](./flags) **Summary:** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code) and enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`. @@ -98,14 +99,47 @@ def f(x, y): jax.jit(f)(0., 0.) # ==> raises FloatingPointError exception! ``` -[Read more](flags). +[Read more](./flags). + +## Attaching XLA Metadata with `set_xla_metadata` + +Complete guide [here](xla_metadata) + +**Summary:** `set_xla_metadata` allows you to attach metadata to operations in your JAX code. This metadata is passed down to the XLA compiler as `frontend_attributes` and can be used to enable compiler-level debugging tools, such as the XLA-TPU debugger. + +**Note:** `set_xla_metadata` is an experimental feature and its API is subject to change. + +```python +import jax +import jax.numpy as jnp +from jax.experimental.xla_metadata import set_xla_metadata + +# Tagging an individual operation +def value_tagging(x): + y = jnp.sin(x) + z = jnp.cos(x) + return set_xla_metadata(y * z, breakpoint=True) + +print(jax.jit(value_tagging).lower(1.0).as_text("hlo")) +``` +Results in: +``` +ENTRY main.5 { + x.1 = f32[] parameter(0) + sin.2 = f32[] sine(x.1) + cos.3 = f32[] cosine(x.1) + ROOT mul.4 = f32[] multiply(sin.2, cos.3), frontend_attributes={breakpoint="true"} +} +``` + +[Read more](xla_metadata). ```{toctree} :caption: Read more :maxdepth: 1 +flags print_breakpoint checkify_guide -flags +xla_metadata ``` - diff --git a/docs/debugging/print_breakpoint.md b/docs/debugging/print_breakpoint.md index 85580120c0a9..8e9ded01c537 100644 --- a/docs/debugging/print_breakpoint.md +++ b/docs/debugging/print_breakpoint.md @@ -179,6 +179,10 @@ Why? Under the hood, the compiler gets a functional representation of the staged To preserve the original order of `jax.debug.print`s as written in your Python function, you can use `jax.debug.print(..., ordered=True)`, which will ensure the relative order of prints is preserved. But using `ordered=True` will raise an error under `jax.pmap` and other JAX transformations involving parallelism, since ordering can't be guaranteed under parallel execution. +#### Computation perturbation + +Adding `jax.debug.print` or `jax.debug.breakpoint` statements will change the computation that XLA is asked to compile. This can potentially result in numeric discrepancies compared to the same code without debug statements because XLA might perform different operation fusions during compilation. Keep this in mind when debugging numerical issues, as the act of adding debug statements might affect the behavior you're trying to investigate. + #### Asynchronous callbacks Depending on the backend, `jax.debug.print`s may happen asynchronously, i.e. not in your main program thread. This means that values could be printed to your screen even after your JAX function has returned a value. diff --git a/docs/debugging/xla_metadata.md b/docs/debugging/xla_metadata.md new file mode 100644 index 000000000000..807da5ae2d19 --- /dev/null +++ b/docs/debugging/xla_metadata.md @@ -0,0 +1,153 @@ +# Attaching XLA Metadata with `set_xla_metadata` + + + +**Summary:** `set_xla_metadata` allows you to attach metadata to operations in your JAX code. This metadata is passed down to the XLA compiler as `frontend_attributes` and can be used to enable compiler-level debugging tools, such as the XLA-TPU debugger. + +You can use it in three ways: + +1. Tag an individual operation by wrapping its output value +2. Tag a block of operations using a context manager +3. Tag all operations in a function using a decorator + +**Warning:** `set_xla_metadata` is an experimental feature and its API is subject to change. + +# What is XLA Metadata? +When JAX transforms and compiles your code, it ultimately generates an XLA (Accelerated Linear Algebra) computation graph. Each operation in this graph can have associated metadata, specifically `frontend_attributes`. This metadata doesn't change the numerical result of the operation, but it can be used to signal special behavior to the compiler or runtime. + +`set_xla_metadata` provides a way to attach this metadata directly from your JAX code. This is a powerful feature for low-level debugging and profiling. + +# Usage +## Tagging Individual Operations +Tagging an individual operation gives you precise control over which parts of your computation you want to inspect. To do this, you wrap the output (value) of an operation with `set_xla_metadata`. When wrapping a function with multiple operations within, only the final operation of said function will be tagged. + +```python +import jax +import jax.numpy as jnp +from jax.experimental.xla_metadata import set_xla_metadata + +# Tagging an individual operation +def value_tagging(x): + y = jnp.sin(x) + z = jnp.cos(x) + return set_xla_metadata(y * z, breakpoint=True) + +print(jax.jit(value_tagging).lower(1.0).as_text("hlo")) +``` +Results in: +``` +ENTRY main.5 { + x.1 = f32[] parameter(0) + sin.2 = f32[] sine(x.1) + cos.3 = f32[] cosine(x.1) + ROOT mul.4 = f32[] multiply(sin.2, cos.3), frontend_attributes={breakpoint="true"} +} +``` +## Tagging a Block of Code with a Context Manager or Decorator +If you want to apply the same metadata to a larger section of code, you can use `set_xla_metadata` as a context manager. All JAX operations within the `with` block will have the specified metadata attached. + +```python +import jax +import jax.numpy as jnp +from jax.experimental.xla_metadata import set_xla_metadata + +# Tagging a block of code +def context_tagging(x): + with set_xla_metadata(_xla_log=True): + y = jnp.sin(x) + z = jnp.cos(y) + return y * z + +print(jax.jit(context_tagging).lower(1.0).as_text("hlo")) +``` +Results in: +``` +ENTRY main.5 { + x.1 = f32[] parameter(0) + sin.2 = f32[] sine(x.1), frontend_attributes={_xla_log="true"} + cos.3 = f32[] cosine(sin.2), frontend_attributes={_xla_log="true"} + ROOT mul.4 = f32[] multiply(sin.2, cos.3), frontend_attributes={_xla_log="true"} +} +``` + +If you want to tag all operations in a function, you can also use `set_xla_metadata` as a decorator: + +```python +import jax +import jax.numpy as jnp +from jax.experimental.xla_metadata import set_xla_metadata + +# Tagging with a decorator +@set_xla_metadata(_xla_log=True) +@jax.jit +def decorator_tagging(x): + y = jnp.sin(x) + z = jnp.cos(y) + return y * z + +print(decorator_tagging.lower(1.0).as_text("hlo")) +``` +This will result in the same HLO as above. + +# Interaction with JAX Transformations +`set_xla_metadata` utilizes either a `XlaMetadataContextManager` or JAX `primitive` depending on use-case and is compatible with JAX's transformations like `jit`, `vmap`, and `grad`. +* **`vmap`**: When you `vmap` a function containing `set_xla_metadata`, the metadata will be applied to all of the relevant batched operations. +* **`grad`**: + 1. When tagging a block of operations with the **context manager** `with set_xla_metadata(...):`, the metadata is applied to both the forward pass and backward pass of the operations within it. + 2. Tagging **individual ops** with `set_xla_metadata()` currently only applies to the forward pass of a function. To tag individual operations generated by the backward pass (i.e., the gradient computation), a simple `custom_vjp` can be used: + ```python + import jax + import jax.numpy as jnp + from jax.experimental.xla_metadata import set_xla_metadata + + def fn(x): + y = jnp.sin(x) + z = jnp.cos(x) + return y * z + + metadata = {"example": "grad_tagging"} + + # --- Define Custom VJP to tag gradients --- + @jax.custom_vjp + def wrapped_fn(x): + return fn(x) + + def fwd(*args): + primal_out, vjp_fn = jax.vjp(fn, *args) + return primal_out, vjp_fn + + def bwd(vjp_fn, cts_in): + cts_out = vjp_fn(cts_in) + cts_out = set_xla_metadata(cts_out, **metadata) + return cts_out + + wrapped_fn.defvjp(fwd, bwd) + # ------ + + print(jax.jit(jax.grad(wrapped_fn)).lower(jnp.array(3.0)).as_text("hlo")) + ``` + Results in: + ``` + ENTRY main.10 { + x.1 = f32[] parameter(0) + sin.2 = f32[] sine(x.1) + neg.6 = f32[] negate(sin.2) + sin.5 = f32[] sine(x.1) + mul.7 = f32[] multiply(neg.6, sin.5) + cos.4 = f32[] cosine(x.1) + cos.3 = f32[] cosine(x.1) + mul.8 = f32[] multiply(cos.4, cos.3) + ROOT add_any.9 = f32[] add(mul.7, mul.8), frontend_attributes={example="grad_tagging"} + } + ``` +### Strengths and Limitations of `set_xla_metadata` + +#### Strengths +* **Variable Control:** Allows you to target individual operations or blocks of operations. +* **Non-Intrusive:** Does not change the numerical output or fusion behavior of your program. +* **Enables Powerful Tooling:** Unlocks the potential for sophisticated debugging and analysis at the compiler level. + +#### Limitations +* **Attributes may be lost:** While it's intended for XLA metadata to be maintained throughout transformations and HLO optimizations, certain edge-cases may result in the metadata being lost. +* **Forward-pass only:** Metadata is not currently automatically propagated to gradients when tagging individual operations in the backward pass. A `custom_vjp` must be used in order to tag gradients in this case. See above for an example. +* **Liable to change**: `set_xla_metadata` is an experimental feature and its API is subject to change. diff --git a/docs/default_dtypes.md b/docs/default_dtypes.md new file mode 100644 index 000000000000..629f7fb5c314 --- /dev/null +++ b/docs/default_dtypes.md @@ -0,0 +1,82 @@ +(default-dtypes)= +# Default dtypes and the X64 flag +JAX strives to meet the needs of a range of numerical computing practitioners, who +sometimes have conflicting preferences. When it comes to default dtypes, there are +two different camps: + +- Classic scientific computing practitioners (i.e. users of tools like {mod}`numpy` or + {mod}`scipy`) tend to value accuracy of computations foremost: such users would + prefer that computations default to the **widest available representation**: e.g. + floating point values should default to `float64`, integers to `int64`, etc. +- AI researchers (i.e. folks implementing and training neural networks) tend to value + speed over accuracy, to the point where they have developed special data types like + [bfloat16](https://en.wikipedia.org/wiki/Bfloat16_floating-point_format) and others + which deliberately discard the least significant bits in order to speed up computation. + For these users, the mere presence of a float64 value in their computation can lead + to programs that are slow at best, and incompatible with their hardware at worst! + These users would prefer that computations default to `float32` or `int32`. + +The main mechanism JAX offers for this is the `jax_enable_x64` flag, which controls +whether 64-bit values can be created at all. By default this flag is set to `False` +(serving the needs of AI researchers and practitioners), but can be set to `True` +by users who value accuracy over computational speed. + +## Default setting: 32-bits everywhere +By default `jax_enable_x64` is set to False, and so {mod}`jax.numpy` array creation +functions will default to returning 32-bit values. + +For example: +```python +>>> import jax.numpy as jnp + +>>> jnp.arange(5) +Array([0, 1, 2, 3, 4], dtype=int32) + +>>> jnp.zeros(5) +Array([0., 0., 0., 0., 0.], dtype=float32) + +>>> jnp.ones(5, dtype=int) +Array([1, 1, 1, 1, 1], dtype=int32) + +``` + +Beyond defaults, because 64-bit values can be so poisonous to AI workflows, having +this flag set to False prevents you from creating 64-bit arrays at all! For example: +``` +>>> jnp.arange(5, dtype='float64') # doctest: +SKIP +UserWarning: Explicitly requested dtype float64 requested in arange is not available, and will be +truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the +JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more. +Array([0., 1., 2., 3., 4.], dtype=float32) +``` + +## The X64 flag: enabling 64-bit values +To work in the "other mode" where functions default to producing 64-bit values, you can set the +`jax_enable_x64` flag to `True`: +```python +import jax +import jax.numpy as jnp + +jax.config.update('jax_enable_x64', True) + +print(repr(jnp.arange(5))) +print(repr(jnp.zeros(5))) +print(repr(jnp.ones(5, dtype=int))) +``` +``` +Array([0, 1, 2, 3, 4], dtype=int64) +Array([0., 0., 0., 0., 0.], dtype=float64) +Array([1, 1, 1, 1, 1], dtype=int64) +``` + +The X64 configuration can also be set via the `JAX_ENABLE_X64` shell environment variable, +for example: +```bash +$ JAX_ENABLE_X64=1 python main.py +``` +The X64 flag is intended as a **global setting** that should have one value for your whole +program, set at the top of your main file. A common feature request is for the flag to +be contextually configurable (e.g. enabling X64 just for one section of a long program): +this turns out to be difficult to implement within JAX's programming model, where code +execution may happen in a different context than code compilation. There is ongoing work +exploring the feasibility of relaxing this constraint, so stay tuned! diff --git a/docs/deprecation.md b/docs/deprecation.md index 7f8634d2b064..7375f6650ee4 100644 --- a/docs/deprecation.md +++ b/docs/deprecation.md @@ -15,26 +15,25 @@ This means we support at least: * All Python feature releases in the 45 months prior to each JAX release. For example: - * **Python 3.10** was released October 2021, and will be supported in new JAX releases at least until **July 2025**. * **Python 3.11** was released October 2022, and will be supported in new JAX releases at least until **July 2026**. * **Python 3.12** was released October 2023, and will be supported in new JAX releases at least until **July 2027**. * **Python 3.13** was released October 2024, and will be supported in new JAX releases at least until **July 2028**. + * **Python 3.14** was released October 2025, and will be supported in new JAX releases at least until **July 2029**. * All NumPy feature releases in the 24 months prior to each JAX release. For example: - * **NumPy 1.25** was released June 2023, and will be supported in new JAX releases at least until **June 2025** - * **NumPy 1.26** was released September 2023, and will be supported in new JAX releases at least until **September 2025** - * **NumPy 2.0** was released June 2024, and will be supported in new JAX releases at least until **June 2026** - * **NumPy 2.1** was released August 2024, and will be supported in new JAX releases at least until **August 2026** + * **NumPy 2.0** was released June 2024, and will be supported in new JAX releases at least until **June 2026**. + * **NumPy 2.1** was released August 2024, and will be supported in new JAX releases at least until **August 2026**. * **NumPy 2.2** was released December 2024, and will be supported in new JAX releases at least until **December 2026**. + * **NumPy 2.3** was released June 2025, and will be supported in new JAX releases at least until **June 2027**. * All SciPy feature releases in the 24 months prior to each JAX release. For example: - * **Scipy 1.11** was released June 2023, and will be supported in new JAX releases at least until **June 2025**. - * **Scipy 1.12** was released January 2024, and will be supported in new JAX releases at least until **January 2026**. - * **Scipy 1.13** was released April 2024, and will be supported in new JAX releases at least until **April 2026**. - * **Scipy 1.14** was released June 2024, and will be supported in new JAX releases at least until **June 2026**. - * **Scipy 1.15** was released January 2025, and will be supported in new JAX releases at least until **January 2027**. + * **SciPy 1.12** was released January 2024, and would normally be supported in new JAX releases at least until **January 2026**. However, we dropped SciPy 1.12 support in September 2025, because NumPy 2.0 support requires SciPy 1.13. + * **SciPy 1.13** was released April 2024, and will be supported in new JAX releases at least until **April 2026**. + * **SciPy 1.14** was released June 2024, and will be supported in new JAX releases at least until **June 2026**. + * **SciPy 1.15** was released January 2025, and will be supported in new JAX releases at least until **January 2027**. + * **SciPy 1.16** was released June 2025, and will be supported in new JAX releases at least until **June 2027**. JAX releases may support older versions of Python, NumPy, and SciPy than strictly required by this policy, but support for older versions may be dropped at any time beyond the listed diff --git a/docs/developer.md b/docs/developer.md index 0affbba9ed36..2fae4eae5a6f 100644 --- a/docs/developer.md +++ b/docs/developer.md @@ -1,7 +1,7 @@ (building-from-source)= # Building from source - + First, obtain the JAX source code: @@ -129,7 +129,7 @@ current directory. --bazel_options=--repo_env=LOCAL_NCCL_PATH="/foo/bar/nvidia/nccl" ``` - Please see the full list of instructions in [XLA documentation](https://github.com/openxla/xla/blob/main/docs/hermetic_cuda.md). + Please see the full list of instructions in [XLA documentation](https://github.com/google-ml-infra/rules_ml_toolchain/blob/main/gpu). * JAX versions prior v.0.4.32: you must have CUDA and CUDNN installed and provide paths to them using configuration options. @@ -263,10 +263,12 @@ Alternatively, if you need more control, you may run the bazel command directly (the two commands are equivalent): ``` +# Regular Python bazel run //build:requirements.update --repo_env=HERMETIC_PYTHON_VERSION=3.12 -``` -where `3.12` is the `Python` version you wish to update. +# Free-threaded Python +bazel run //build:requirements_ft.update --repo_env=HERMETIC_PYTHON_VERSION=3.13-ft +``` Note, since it is still `pip` and `pip-compile` tools used under the hood, so most of the command line arguments and features supported by those tools will be @@ -299,7 +301,7 @@ and re-run the requirements updater command for a selected version of Python. For example: ``` -echo -e "\n$(realpath jaxlib-0.4.27.dev20240416-cp312-cp312-manylinux2014_x86_64.whl)" >> build/requirements.in +echo -e "\n$(realpath jaxlib-0.4.27.dev20240416-cp312-cp312-manylinux_2_27_x86_64.whl)" >> build/requirements.in python build/build.py requirements_update --python_version=3.12 ``` @@ -374,7 +376,7 @@ in terms of files, not installations): --repo_env=HERMETIC_PYTHON_URL="https://remote/url/to/my_python.tgz" --repo_env=HERMETIC_PYTHON_SHA256= - # We assume that top-level folder in the tarbal is called "python", if it is + # We assume that top-level folder in the tarball is called "python", if it is # something different just pass additional HERMETIC_PYTHON_PREFIX parameter --repo_env=HERMETIC_PYTHON_URL="https://remote/url/to/my_python.tgz" --repo_env=HERMETIC_PYTHON_SHA256= @@ -455,7 +457,6 @@ which one is selected by specifying `HERMETIC_PYTHON_VERSION`. For example in `WORKSPACE` file: ``` requirements = { - "3.10": "//build:requirements_lock_3_10.txt", "3.11": "//build:requirements_lock_3_11.txt", "3.12": "//build:requirements_lock_3_12.txt", "3.13": "//build:requirements_lock_3_13.txt", @@ -466,16 +467,16 @@ requirements = { Then you can build and test different combinations of stuff without changing anything in your environment: ``` -# To build with scenario1 dependendencies: +# To build with scenario1 dependencies: bazel test --repo_env=HERMETIC_PYTHON_VERSION=3.13-scenario1 -# To build with scenario2 dependendencies: +# To build with scenario2 dependencies: bazel test --repo_env=HERMETIC_PYTHON_VERSION=3.13-scenario2 -# To build with default dependendencies: +# To build with default dependencies: bazel test --repo_env=HERMETIC_PYTHON_VERSION=3.13 -# To build with scenario1 dependendencies and custom Python 3.13 interpreter: +# To build with scenario1 dependencies and custom Python 3.13 interpreter: bazel test --repo_env=HERMETIC_PYTHON_VERSION=3.13-scenario1 --repo_env=HERMETIC_PYTHON_URL="file:///path/to/cpython.tar.gz" @@ -526,6 +527,11 @@ bazel test //tests:cpu_tests //tests:backend_independent_tests `//tests:gpu_tests` and `//tests:tpu_tests` are also available, if you have the necessary hardware. +You need to configure `cuda` to run `gpu` tests: +``` +python build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt --configure_only +``` + To use a preinstalled `jaxlib` instead of building it you first need to make it available in the hermetic Python. To install a specific version of `jaxlib` within hermetic Python run (using `jaxlib >= 0.4.26` as an example): @@ -538,7 +544,7 @@ python build/build.py requirements_update Alternatively, to install `jaxlib` from a local wheel (assuming Python 3.12): ``` -echo -e "\n$(realpath jaxlib-0.4.26-cp312-cp312-manylinux2014_x86_64.whl)" >> build/requirements.in +echo -e "\n$(realpath jaxlib-0.4.26-cp312-cp312-manylinux_2_27_x86_64.whl)" >> build/requirements.in python build/build.py requirements_update --python_version=3.12 ``` @@ -785,7 +791,7 @@ desired formats, and which the `jupytext --sync` command recognizes when invoked #### Notebooks within the Sphinx build Some of the notebooks are built automatically as part of the pre-submit checks and -as part of the [Read the docs](https://jax.readthedocs.io/en/latest) build. +as part of the [Read the docs](https://docs.jax.dev/en/latest) build. The build will fail if cells raise errors. If the errors are intentional, you can either catch them, or tag the cell with `raises-exceptions` metadata ([example PR](https://github.com/jax-ml/jax/pull/2402/files)). You have to add this metadata by hand in the `.ipynb` file. It will be preserved when somebody else @@ -796,7 +802,7 @@ See `exclude_patterns` in [conf.py](https://github.com/jax-ml/jax/blob/main/docs ### Documentation building on `readthedocs.io` -JAX's auto-generated documentation is at . +JAX's auto-generated documentation is at . The documentation building is controlled for the entire project by the [readthedocs JAX settings](https://readthedocs.org/dashboard/jax). The current settings @@ -807,10 +813,10 @@ For each code version, the building process is driven by the For each automated documentation build you can see the [documentation build logs](https://readthedocs.org/projects/jax/builds/). -If you want to test the documentation generation on Readthedocs, you can push code to the `test-docs` -branch. That branch is also built automatically, and you can -see the generated documentation [here](https://jax.readthedocs.io/en/test-docs/). If the documentation build -fails you may want to [wipe the build environment for test-docs](https://docs.readthedocs.io/en/stable/guides/wipe-environment.html). +If you want to test the documentation generation on Readthedocs, +you can add the "documentation" GitHub label to your PR. You will then +be able to view the docs from the link for the "docs/readthedocs.org:jax" +GitHub check. For a local test, I was able to do it in a fresh directory by replaying the commands I saw in the Readthedocs logs: diff --git a/docs/device_memory_profiling.md b/docs/device_memory_profiling.md index a2fd3f68780c..e2fa67c2f1a2 100644 --- a/docs/device_memory_profiling.md +++ b/docs/device_memory_profiling.md @@ -3,8 +3,8 @@ ```{note} -May 2023 update: we recommend using [Tensorboard -profiling](tensorboard-profiling) for device memory analysis. After taking a +June 2025 update: we recommend using [XProf +profiling](xprof-profiling) for device memory analysis. After taking a profile, open the `memory_viewer` tab of the Tensorboard profiler for more detailed and understandable device memory usage. ``` diff --git a/docs/direct_linearize_migration.md b/docs/direct_linearize_migration.md new file mode 100644 index 000000000000..556f740bbd65 --- /dev/null +++ b/docs/direct_linearize_migration.md @@ -0,0 +1,43 @@ +--- +orphan: true +--- +(direct-linearize-migration)= +# JAX direct linearize + + + +### What’s going on? + +We're changing the way JAX implements autodiff internally. Previously grad was +done by a three-stage process: JVP, partial eval, transposition. With this +change we've bundled together the first two steps, JVP and partial eval, into a +new transformation: linearization. + +This should mostly not change user-visible behavior. Some exceptions: + + * you'll see LinearizeTracer instead of JVPTracer if you print out traced values during autodiff. + + * It's possible that some numerics will change, just for the usual reason that any perturbation to programs can slightly alter numerical results. + + +### Why? + +The upgrade unlocks several new features, like: + + * differentiation involving Pallas-style mutable array references; + + * simpler and more flexible user-defined autodiff rules, like custom_vjp/jvp; + + * controlling the autodiff behavior on user-defined types. + +### This change broke my stuff! + +For now, you can still get the old behavior by unsetting the use_direct_linearize config option: + + * set the shell environment variable to something falsey, e.g. JAX_USE_DIRECT_LINEARIZE=0 + + * set the config option jax.config.update('jax_use_direct_linearize', False) + + * if you parse flags with absl, you can pass the command-line flag --jax_use_direct_linearize=false + +We plan to remove the config option on August 16th 2025. diff --git a/docs/errors.rst b/docs/errors.rst index 9965d6698bd4..9e27fe7c4a0b 100644 --- a/docs/errors.rst +++ b/docs/errors.rst @@ -7,9 +7,11 @@ This page lists a few of the errors you might encounter when using JAX, along with representative examples of how one might fix them. .. currentmodule:: jax.errors +.. autoclass:: JaxRuntimeError +.. autoclass:: JAXTypeError +.. autoclass:: JAXIndexError .. autoclass:: ConcretizationTypeError .. autoclass:: KeyReuseError -.. autoclass:: JaxRuntimeError .. autoclass:: NonConcreteBooleanIndexError .. autoclass:: TracerArrayConversionError .. autoclass:: TracerBoolConversionError diff --git a/docs/export/export.md b/docs/export/export.md index 18cdcc6c51d0..eec4a204ba31 100644 --- a/docs/export/export.md +++ b/docs/export/export.md @@ -69,7 +69,7 @@ Serialization is broken down into two stages: call it from another JAX function. We have plans to add code to generate `Exported` objects from TensorFlow, and to use `Exported` objects from TensorFlow and PyTorch. - 2. the actual serialization to a byte array using the flatbuffers format. + 2. the actual serialization to a byte array using the flatbuffers format. See {ref}`jax2tf` for an alternative serialization to TensorFlow graph that can be used for interoperation with TensorFlow. @@ -161,7 +161,7 @@ e.g., the inference system.) What **matters is when the exporting and consuming components were built**, not the time when the exporting and the compilation happen. For external JAX users, it is -[possible to run JAX and jaxlib at different versions](https://jax.readthedocs.io/en/latest/jep/9419-jax-versioning.html#how-are-jax-and-jaxlib-versioned); +[possible to run JAX and jaxlib at different versions](https://docs.jax.dev/en/latest/jep/9419-jax-versioning.html#how-are-jax-and-jaxlib-versioned); what matters is when the jaxlib release was built. To reduce chances of incompatibility, internal JAX users should: @@ -448,7 +448,7 @@ artifacts using a new mesh constructed at the call site: ... sharding=NamedSharding(export_mesh, P("a")))) >>> # Prepare the mesh for calling `exp`. ->>> calling_mesh = Mesh(np.array(export_devices[::-1]), ("b",)) +>>> calling_mesh = Mesh(np.array(export_devices[::-1]), ("a",)) >>> # Shard the arg according to what `exp` expects. >>> arg = jnp.arange(4 * len(export_devices)) @@ -501,7 +501,7 @@ As of June 2024, all function exported with version 9 >>> from jax import export >>> exp: export.Exported = export.export(jnp.cos)(1.) >>> exp.calling_convention_version -9 +10 ``` @@ -513,13 +513,13 @@ or the `JAX_EXPORT_CALLING_CONVENTION_VERSION` environment variable: ```python >>> from jax import export >>> (export.minimum_supported_calling_convention_version, export.maximum_supported_calling_convention_version) -(9, 9) +(9, 10) >>> from jax._src import config ->>> with config.jax_export_calling_convention_version(9): +>>> with config.jax_export_calling_convention_version(10): ... exp = export.export(jnp.cos)(1.) ... exp.calling_convention_version -9 +10 ``` @@ -668,6 +668,9 @@ We list here a history of the calling convention version numbers: available in JAX since October 20th, 2023 (JAX 0.4.20), and the default since February 1st, 2024 (JAX 0.4.24). This is the only supported version as of 27th of March, 2024. + * Version 10 propagate the `jax.config.use_shardy_partitioner` value to + XlaCallModule. Supported by XlaCallModule since May 20th, 2025, and + the default in JAX since July 14th, 2025 (JAX 0.7.0). ## Developer documentation @@ -710,10 +713,7 @@ total 32 -rw-rw-r--@ 1 necula wheel 2333 Jun 19 11:04 jax_ir3_jit_my_fun_export.mlir ``` -Inside Google, you can turn on logging by using the `--vmodule` argument to -specify the logging levels for different modules, -e.g., `--vmodule=_export=3`. - +Set [`JAX_DEBUG_LOG_MODULES=jax._src.export`](https://docs.jax.dev/en/latest/config_options.html#jax_debug_log_modules) to enable extra debugging logging. (export_ensuring_compat)= ### Ensuring forward and backward compatibility @@ -771,14 +771,23 @@ that live in jaxlib): ``` * Note that the forward compatibility mode is always false in JIT mode or if the user passes `--jax_export_ignore_forward_compatibility=true` - * We add `T_NEW` to the list of - [`_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE`](https://github.com/search?q=repo%3Ajax-ml%2Fjax++%22_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE+%3D%22+path%3A_export.py&%3Btype=code&type=code) - in `_export.py`. - 3. Day “D + 21” (end of forward compatibility window; can be even later than 21 days): + * Note that at this point the exports will still not use `T_NEW`. + 3. This can be done at any time after the previous step, and before + the next step: Add a backward compatibility test for `T_NEW`, + and add `T_NEW` to the list of + [`_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE`](https://github.com/search?q=repo%3Ajax-ml%2Fjax++%22_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE+%3D%22+path%3A_export.py&%3Btype=code&type=code) in `_export.py`. + * Instructions for adding backwards compatibility tests are at the top of + [export_back_compat_test_util.py](https://github.com/search?q=repo%3Ajax-ml%2Fjax+++path%3Aexport_back_compat_test_util.py&%3Btype=code&type=code). + * An example is in [PR #29488](https://github.com/jax-ml/jax/pull/29488). + * Note that if you do this before the next step, the exporting will still not + use the `T_NEW` lowering, and you have to add + `with config.export_ignore_forward_compatibility(True):` around the call to + `self.run_one_test`. This can be removed when you actually get to step 4. + * You may also need to enable the test only for new versions of jaxlib. + 4. Day “D + 21” (end of forward compatibility window; can be even later than 21 days): We remove the `forward_compat_mode` in the lowering code, so now exporting will start using the new custom call target `T_NEW` as long as we are using a new `jaxlib`. - * We add a backwards compatibility test for `T_NEW`. - 4. Day "RELEASE > D" (the first JAX release date after `D`, when we release version `0.4.31`): + 5. Day "RELEASE > D" (the first JAX release date after `D`, when we release version `0.4.31`): we start the clock for the 6 months backwards compatibility. Note that this is relevant only if `T` is among the custom call targets for which we already guarantee stability, i.e., are listed in @@ -787,7 +796,7 @@ that live in jaxlib): we make `RELEASE` the minimum allowed jaxlib version then we can remove the `jaxlib_version < (0, 4, 31)` conditional in the JIT branch. - 5. Day “RELEASE + 180” (end of backward compatibility window, + 6. Day “RELEASE + 180” (end of backward compatibility window, can be even later than 180 days): By now, we must have bumped the minimum jaxlib so that the lowering conditional `jaxlib_version < (0, 4, 31)` was already removed and JAX lowering cannot generate custom calls to `T`. diff --git a/docs/export/shape_poly.md b/docs/export/shape_poly.md index 9254030a4e1c..121752bd2ef4 100644 --- a/docs/export/shape_poly.md +++ b/docs/export/shape_poly.md @@ -86,7 +86,7 @@ matching the structure of the arguments passed to it. The polymorphic shapes specification can be a pytree prefix in cases where one specification should apply to multiple arguments, as in the above example. -See [how optional parameters are matched to arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees). +See [how optional parameters are matched to arguments](https://docs.jax.dev/en/latest/pytrees.html#applying-optional-parameters-to-pytrees). A few examples of shape specifications: @@ -221,14 +221,14 @@ JAX shape check errors: ```python >>> v, = export.symbolic_shape("v,") ->>> export.export(jax.jit(lambda x, y: x + y))( -... jax.ShapeDtypeStruct((v,), dtype=np.int32), -... jax.ShapeDtypeStruct((4,), dtype=np.int32)) +>>> export.export(jax.jit(lambda x, y: x + y))( # doctest: +IGNORE_EXCEPTION_DETAIL +... jax.ShapeDtypeStruct((v,), dtype=np.int32), # doctest: +IGNORE_EXCEPTION_DETAIL +... jax.ShapeDtypeStruct((4,), dtype=np.int32)) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): TypeError: add got incompatible shapes for broadcasting: (v,), (4,). ->>> export.export(jax.jit(lambda x: jnp.matmul(x, x)))( -... jax.ShapeDtypeStruct((v, 4), dtype=np.int32)) +>>> export.export(jax.jit(lambda x: jnp.matmul(x, x)))( # doctest: +IGNORE_EXCEPTION_DETAIL +... jax.ShapeDtypeStruct((v, 4), dtype=np.int32)) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): TypeError: dot_general requires contracting dimensions to have the same shape, got (4,) and (v,). @@ -441,7 +441,7 @@ to {func}`jax.export.symbolic_shape` share a scope and can be mixed up in arithmetic operations. The result would also share the same scope. -You can re-use scopes: +You can reuse scopes: ```python >>> a, = export.symbolic_shape("a,", constraints=("a >= 8",)) @@ -609,7 +609,7 @@ Division had remainder 1 when computing the value of 'd'. Using the following polymorphic shapes specifications: args[0].shape = (b, b, 2*d). Obtained dimension variables: 'b' = 3 from specification 'b' for dimension args[0].shape[0] (= 3), . -Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details. +Please see https://docs.jax.dev/en/latest/export/shape_poly.html#shape-assertion-errors for more details. ``` diff --git a/docs/faq.rst b/docs/faq.rst index 44267f6f5f7d..5653ff1cbb26 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -4,7 +4,7 @@ Frequently asked questions (FAQ) .. comment RST primer for Sphinx: https://thomas-cokelaer.info/tutorials/sphinx/rest_syntax.html .. comment Some links referenced here. Use `JAX - The Sharp Bits`_ (underscore at the end) to reference -.. _JAX - The Sharp Bits: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html +.. _JAX - The Sharp Bits: https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html We are collecting answers to frequently asked questions here. Contributions welcome! @@ -116,7 +116,7 @@ code in JAX's internal representation, typically because it makes heavy use of Python control flow such as ``for`` loops. For a handful of loop iterations, Python is OK, but if you need *many* loop iterations, you should rewrite your code to make use of JAX's -`structured control flow primitives `_ +`structured control flow primitives `_ (such as :func:`lax.scan`) or avoid wrapping the loop with ``jit`` (you can still use ``jit`` decorated functions *inside* the loop). @@ -137,332 +137,14 @@ on GitHub. How to use ``jit`` with methods? -------------------------------- -Most examples of :func:`jax.jit` concern decorating stand-alone Python functions, -but decorating a method within a class introduces some complication. For example, -consider the following simple class, where we've used a standard :func:`~jax.jit` -annotation on a method:: - - >>> import jax.numpy as jnp - >>> from jax import jit - - >>> class CustomClass: - ... def __init__(self, x: jnp.ndarray, mul: bool): - ... self.x = x - ... self.mul = mul - ... - ... @jit # <---- How to do this correctly? - ... def calc(self, y): - ... if self.mul: - ... return self.x * y - ... return y - -However, this approach will result in an error when you attempt to call this method:: - - >>> c = CustomClass(2, True) - >>> c.calc(3) # doctest: +SKIP - --------------------------------------------------------------------------- - TypeError Traceback (most recent call last) - File "", line 1, in ' of type is not a valid JAX type. - -The problem is that the first argument to the function is ``self``, which has type -``CustomClass``, and JAX does not know how to handle this type. -There are three basic strategies we might use in this case, and we'll discuss -them below. - -Strategy 1: JIT-compiled helper function -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -The most straightforward approach is to create a helper function external to the class -that can be JIT-decorated in the normal way. For example:: - - >>> from functools import partial - - >>> class CustomClass: - ... def __init__(self, x: jnp.ndarray, mul: bool): - ... self.x = x - ... self.mul = mul - ... - ... def calc(self, y): - ... return _calc(self.mul, self.x, y) - - >>> @partial(jit, static_argnums=0) - ... def _calc(mul, x, y): - ... if mul: - ... return x * y - ... return y - -The result will work as expected:: - - >>> c = CustomClass(2, True) - >>> print(c.calc(3)) - 6 - -The benefit of such an approach is that it is simple, explicit, and it avoids the need -to teach JAX how to handle objects of type ``CustomClass``. However, you may wish to -keep all the method logic in the same place. - -Strategy 2: Marking ``self`` as static -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Another common pattern is to use ``static_argnums`` to mark the ``self`` argument as static. -But this must be done with care to avoid unexpected results. -You may be tempted to simply do this:: - - >>> class CustomClass: - ... def __init__(self, x: jnp.ndarray, mul: bool): - ... self.x = x - ... self.mul = mul - ... - ... # WARNING: this example is broken, as we'll see below. Don't copy & paste! - ... @partial(jit, static_argnums=0) - ... def calc(self, y): - ... if self.mul: - ... return self.x * y - ... return y - -If you call the method, it will no longer raise an error:: - - >>> c = CustomClass(2, True) - >>> print(c.calc(3)) - 6 - -However, there is a catch: if you mutate the object after the first method call, the -subsequent method call may return an incorrect result:: - - >>> c.mul = False - >>> print(c.calc(3)) # Should print 3 - 6 - -Why is this? When you mark an object as static, it will effectively be used as a dictionary -key in JIT's internal compilation cache, meaning its hash (i.e. ``hash(obj)``) equality -(i.e. ``obj1 == obj2``) and object identity (i.e. ``obj1 is obj2``) will be assumed to have -consistent behavior. The default ``__hash__`` for a custom object is its object ID, and so -JAX has no way of knowing that a mutated object should trigger a re-compilation. - -You can partially address this by defining an appropriate ``__hash__`` and ``__eq__`` methods -for your object; for example:: - - >>> class CustomClass: - ... def __init__(self, x: jnp.ndarray, mul: bool): - ... self.x = x - ... self.mul = mul - ... - ... @partial(jit, static_argnums=0) - ... def calc(self, y): - ... if self.mul: - ... return self.x * y - ... return y - ... - ... def __hash__(self): - ... return hash((self.x, self.mul)) - ... - ... def __eq__(self, other): - ... return (isinstance(other, CustomClass) and - ... (self.x, self.mul) == (other.x, other.mul)) - -(see the :meth:`object.__hash__` documentation for more discussion of the requirements -when overriding ``__hash__``). - -This should work correctly with JIT and other transforms **so long as you never mutate -your object**. Mutations of objects used as hash keys lead to several subtle problems, -which is why for example mutable Python containers (e.g. :class:`dict`, :class:`list`) -don't define ``__hash__``, while their immutable counterparts (e.g. :class:`tuple`) do. - -If your class relies on in-place mutations (such as setting ``self.attr = ...`` within its -methods), then your object is not really "static" and marking it as such may lead to problems. -Fortunately, there's another option for this case. - -Strategy 3: Making ``CustomClass`` a PyTree -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -The most flexible approach to correctly JIT-compiling a class method is to register the -type as a custom PyTree object; see :ref:`extending-pytrees`. This lets you specify -exactly which components of the class should be treated as static and which should be -treated as dynamic. Here's how it might look:: - - >>> class CustomClass: - ... def __init__(self, x: jnp.ndarray, mul: bool): - ... self.x = x - ... self.mul = mul - ... - ... @jit - ... def calc(self, y): - ... if self.mul: - ... return self.x * y - ... return y - ... - ... def _tree_flatten(self): - ... children = (self.x,) # arrays / dynamic values - ... aux_data = {'mul': self.mul} # static values - ... return (children, aux_data) - ... - ... @classmethod - ... def _tree_unflatten(cls, aux_data, children): - ... return cls(*children, **aux_data) - - >>> from jax import tree_util - >>> tree_util.register_pytree_node(CustomClass, - ... CustomClass._tree_flatten, - ... CustomClass._tree_unflatten) - -This is certainly more involved, but it solves all the issues associated with the simpler -approaches used above:: - - >>> c = CustomClass(2, True) - >>> print(c.calc(3)) - 6 - - >>> c.mul = False # mutation is detected - >>> print(c.calc(3)) - 3 - - >>> c = CustomClass(jnp.array(2), True) # non-hashable x is supported - >>> print(c.calc(3)) - 6 - -So long as your ``tree_flatten`` and ``tree_unflatten`` functions correctly handle all -relevant attributes in the class, you should be able to use objects of this type directly -as arguments to JIT-compiled functions, without any special annotations. -.. _faq-data-placement: - -Controlling data and computation placement on devices ------------------------------------------------------ - -Let's first look at the principles of data and computation placement in JAX. - -In JAX, the computation follows data placement. JAX arrays -have two placement properties: 1) the device where the data resides; -and 2) whether it is **committed** to the device or not (the data is sometimes -referred to as being *sticky* to the device). - -By default, JAX arrays are placed uncommitted on the default device -(``jax.devices()[0]``), which is the first GPU or TPU by default. If no GPU or -TPU is present, ``jax.devices()[0]`` is the CPU. The default device can -be temporarily overridden with the :func:`jax.default_device` context manager, or -set for the whole process by setting the environment variable ``JAX_PLATFORMS`` -or the absl flag ``--jax_platforms`` to "cpu", "gpu", or "tpu" -(``JAX_PLATFORMS`` can also be a list of platforms, which determines which -platforms are available in priority order). - ->>> from jax import numpy as jnp ->>> print(jnp.ones(3).devices()) # doctest: +SKIP -{CudaDevice(id=0)} - -Computations involving uncommitted data are performed on the default -device and the results are uncommitted on the default device. - -Data can also be placed explicitly on a device using :func:`jax.device_put` -with a ``device`` parameter, in which case the data becomes **committed** to the device: - ->>> import jax ->>> from jax import device_put ->>> arr = device_put(1, jax.devices()[2]) # doctest: +SKIP ->>> print(arr.devices()) # doctest: +SKIP -{CudaDevice(id=2)} - -Computations involving some committed inputs will happen on the -committed device and the result will be committed on the -same device. Invoking an operation on arguments that are committed -to more than one device will raise an error. - -You can also use :func:`jax.device_put` without a ``device`` parameter. If the data -is already on a device (committed or not), it's left as-is. If the data isn't on any -device—that is, it's a regular Python or NumPy value—it's placed uncommitted on the default -device. - -Jitted functions behave like any other primitive operations—they will follow the -data and will show errors if invoked on data committed on more than one device. - -(Before `PR #6002 `_ in March 2021 -there was some laziness in creation of array constants, so that -``jax.device_put(jnp.zeros(...), jax.devices()[1])`` or similar would actually -create the array of zeros on ``jax.devices()[1]``, instead of creating the -array on the default device then moving it. But this optimization was removed -so as to simplify the implementation.) - -(As of April 2020, :func:`jax.jit` has a `device` parameter that affects the device -placement. That parameter is experimental, is likely to be removed or changed, -and its use is not recommended.) - -For a worked-out example, we recommend reading through -``test_computation_follows_data`` in -`multi_device_test.py `_. - -.. _faq-benchmark: - -Benchmarking JAX code ---------------------- - -You just ported a tricky function from NumPy/SciPy to JAX. Did that actually -speed things up? - -Keep in mind these important differences from NumPy when measuring the -speed of code using JAX: - -1. **JAX code is Just-In-Time (JIT) compiled.** Most code written in JAX can be - written in such a way that it supports JIT compilation, which can make it run - *much faster* (see `To JIT or not to JIT`_). To get maximum performance from - JAX, you should apply :func:`jax.jit` on your outer-most function calls. - - Keep in mind that the first time you run JAX code, it will be slower because - it is being compiled. This is true even if you don't use ``jit`` in your own - code, because JAX's builtin functions are also JIT compiled. -2. **JAX has asynchronous dispatch.** This means that you need to call - ``.block_until_ready()`` to ensure that computation has actually happened - (see :ref:`async-dispatch`). -3. **JAX by default only uses 32-bit dtypes.** You may want to either explicitly - use 32-bit dtypes in NumPy or enable 64-bit dtypes in JAX (see - `Double (64 bit) precision`_) for a fair comparison. -4. **Transferring data between CPUs and accelerators takes time.** If you only - want to measure how long it takes to evaluate a function, you may want to - transfer data to the device on which you want to run it first (see - :ref:`faq-data-placement`). - -Here's an example of how to put together all these tricks into a microbenchmark -for comparing JAX versus NumPy, making using of IPython's convenient -`%time and %timeit magics`_:: - - import numpy as np - import jax.numpy as jnp - import jax - - def f(x): # function we're benchmarking (works in both NumPy & JAX) - return x.T @ (x - x.mean(axis=0)) - - x_np = np.ones((1000, 1000), dtype=np.float32) # same as JAX default dtype - %timeit f(x_np) # measure NumPy runtime - - %time x_jax = jax.device_put(x_np) # measure JAX device transfer time - f_jit = jax.jit(f) - %time f_jit(x_jax).block_until_ready() # measure JAX compilation time - %timeit f_jit(x_jax).block_until_ready() # measure JAX runtime - -When run with a GPU in Colab_, we see: - -- NumPy takes 16.2 ms per evaluation on the CPU -- JAX takes 1.26 ms to copy the NumPy arrays onto the GPU -- JAX takes 193 ms to compile the function -- JAX takes 485 µs per evaluation on the GPU - -In this case, we see that once the data is transferred and the function is -compiled, JAX on the GPU is about 30x faster for repeated evaluations. - -Is this a fair comparison? Maybe. The performance that ultimately matters is for -running full applications, which inevitably include some amount of both data -transfer and compilation. Also, we were careful to pick large enough arrays -(1000x1000) and an intensive enough computation (the ``@`` operator is -performing matrix-matrix multiplication) to amortize the increased overhead of -JAX/accelerators vs NumPy/CPU. For example, if we switch this example to use -10x10 input instead, JAX/GPU runs 10x slower than NumPy/CPU (100 µs vs 10 µs). - -.. _To JIT or not to JIT: https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#to-jit-or-not-to-jit -.. _Double (64 bit) precision: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision -.. _`%time and %timeit magics`: https://ipython.readthedocs.io/en/stable/interactive/magics.html#magic-time -.. _Colab: https://colab.research.google.com/ +Moved to :ref:`jax-jit-class-methods`. .. _faq-jax-vs-numpy: Is JAX faster than NumPy? -~~~~~~~~~~~~~~~~~~~~~~~~~ +------------------------- + One question users frequently attempt to answer with such benchmarks is whether JAX is faster than NumPy; due to the difference in the two packages, there is not a simple answer. @@ -492,172 +174,6 @@ lower per-operation dispatch overhead. If you're running your code on GPU or TPU or are benchmarking more complicated JIT-compiled sequences of operations on CPU, you can generally expect JAX to outperform NumPy. -.. _faq-different-kinds-of-jax-values: - -Different kinds of JAX values ------------------------------ - -In the process of transforming functions, JAX replaces some function -arguments with special tracer values. - -You could see this if you use a ``print`` statement:: - - def func(x): - print(x) - return jnp.cos(x) - - res = jax.jit(func)(0.) - -The above code does return the correct value ``1.`` but it also prints -``Traced`` for the value of ``x``. Normally, JAX -handles these tracer values internally in a transparent way, e.g., -in the numeric JAX primitives that are used to implement the -``jax.numpy`` functions. This is why ``jnp.cos`` works in the example above. - -More precisely, a **tracer** value is introduced for the argument of -a JAX-transformed function, except the arguments identified by special -parameters such as ``static_argnums`` for :func:`jax.jit` or -``static_broadcasted_argnums`` for :func:`jax.pmap`. Typically, computations -that involve at least a tracer value will produce a tracer value. Besides tracer -values, there are **regular** Python values: values that are computed outside JAX -transformations, or arise from above-mentioned static arguments of certain JAX -transformations, or computed solely from other regular Python values. -These are the values that are used everywhere in absence of JAX transformations. - -A tracer value carries an **abstract** value, e.g., ``ShapedArray`` with information -about the shape and dtype of an array. We will refer here to such tracers as -**abstract tracers**. Some tracers, e.g., those that are -introduced for arguments of autodiff transformations, carry ``ConcreteArray`` -abstract values that actually include the regular array data, and are used, -e.g., for resolving conditionals. We will refer here to such tracers -as **concrete tracers**. Tracer values computed from these concrete tracers, -perhaps in combination with regular values, result in concrete tracers. -A **concrete value** is either a regular value or a concrete tracer. - -Most often values computed from tracer values are themselves tracer values. -There are very few exceptions, when a computation can be entirely done -using the abstract value carried by a tracer, in which case the result -can be a regular value. For example, getting the shape of a tracer -with ``ShapedArray`` abstract value. Another example is when explicitly -casting a concrete tracer value to a regular type, e.g., ``int(x)`` or -``x.astype(float)``. -Another such situation is for ``bool(x)``, which produces a Python bool when -concreteness makes it possible. That case is especially salient because -of how often it arises in control flow. - -Here is how the transformations introduce abstract or concrete tracers: - -* :func:`jax.jit`: introduces **abstract tracers** for all positional arguments - except those denoted by ``static_argnums``, which remain regular - values. -* :func:`jax.pmap`: introduces **abstract tracers** for all positional arguments - except those denoted by ``static_broadcasted_argnums``. -* :func:`jax.vmap`, :func:`jax.make_jaxpr`, :func:`xla_computation`: - introduce **abstract tracers** for all positional arguments. -* :func:`jax.jvp` and :func:`jax.grad` introduce **concrete tracers** - for all positional arguments. An exception is when these transformations - are within an outer transformation and the actual arguments are - themselves abstract tracers; in that case, the tracers introduced - by the autodiff transformations are also abstract tracers. -* All higher-order control-flow primitives (:func:`lax.cond`, :func:`lax.while_loop`, - :func:`lax.fori_loop`, :func:`lax.scan`) when they process the functionals - introduce **abstract tracers**, whether or not there is a JAX transformation - in progress. - -All of this is relevant when you have code that can operate -only on regular Python values, such as code that has conditional -control-flow based on data:: - - def divide(x, y): - return x / y if y >= 1. else 0. - -If we want to apply :func:`jax.jit`, we must ensure to specify ``static_argnums=1`` -to ensure ``y`` stays a regular value. This is due to the boolean expression -``y >= 1.``, which requires concrete values (regular or tracers). The -same would happen if we write explicitly ``bool(y >= 1.)``, or ``int(y)``, -or ``float(y)``. - -Interestingly, ``jax.grad(divide)(3., 2.)``, works because :func:`jax.grad` -uses concrete tracers, and resolves the conditional using the concrete -value of ``y``. - -.. _faq-donation: - -Buffer donation ---------------- - -When JAX executes a computation it uses buffers on the device for all inputs and outputs. -If you know that one of the inputs is not needed after the computation, and if it -matches the shape and element type of one of the outputs, you can specify that you -want the corresponding input buffer to be donated to hold an output. This will reduce -the memory required for the execution by the size of the donated buffer. - -If you have something like the following pattern, you can use buffer donation:: - - params, state = jax.pmap(update_fn, donate_argnums=(0, 1))(params, state) - -You can think of this as a way to do a memory-efficient functional update -on your immutable JAX arrays. Within the boundaries of a computation XLA can -make this optimization for you, but at the jit/pmap boundary you need to -guarantee to XLA that you will not use the donated input buffer after calling -the donating function. - -You achieve this by using the `donate_argnums` parameter to the functions :func:`jax.jit`, -:func:`jax.pjit`, and :func:`jax.pmap`. This parameter is a sequence of indices (0 based) into -the positional argument list:: - - def add(x, y): - return x + y - - x = jax.device_put(np.ones((2, 3))) - y = jax.device_put(np.ones((2, 3))) - # Execute `add` with donation of the buffer for `y`. The result has - # the same shape and type as `y`, so it will share its buffer. - z = jax.jit(add, donate_argnums=(1,))(x, y) - -Note that this currently does not work when calling your function with key-word arguments! -The following code will not donate any buffers:: - - params, state = jax.pmap(update_fn, donate_argnums=(0, 1))(params=params, state=state) - -If an argument whose buffer is donated is a pytree, then all the buffers -for its components are donated:: - - def add_ones(xs: List[Array]): - return [x + 1 for x in xs] - - xs = [jax.device_put(np.ones((2, 3))), jax.device_put(np.ones((3, 4)))] - # Execute `add_ones` with donation of all the buffers for `xs`. - # The outputs have the same shape and type as the elements of `xs`, - # so they will share those buffers. - z = jax.jit(add_ones, donate_argnums=0)(xs) - -It is not allowed to donate a buffer that is used subsequently in the computation, -and JAX will give an error because the buffer for `y` has become invalid -after it was donated:: - - # Donate the buffer for `y` - z = jax.jit(add, donate_argnums=(1,))(x, y) - w = y + 1 # Reuses `y` whose buffer was donated above - # >> RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer - -You will get a warning if the donated buffer is not used, e.g., because -there are more donated buffers than can be used for the outputs:: - - # Execute `add` with donation of the buffers for both `x` and `y`. - # One of those buffers will be used for the result, but the other will - # not be used. - z = jax.jit(add, donate_argnums=(0, 1))(x, y) - # >> UserWarning: Some donated buffers were not usable: f32[2,3]{1,0} - -The donation may also be unused if there is no output whose shape matches -the donation:: - - y = jax.device_put(np.ones((1, 3))) # `y` has different shape than the output - # Execute `add` with donation of the buffer for `y`. - z = jax.jit(add, donate_argnums=(1,))(x, y) - # >> UserWarning: Some donated buffers were not usable: f32[1,3]{1,0} - Gradients contain `NaN` where using ``where`` ------------------------------------------------ @@ -777,7 +293,7 @@ can replace uses of :func:`jax.nn.relu`, etc. How can I convert a JAX Tracer to a NumPy array? ------------------------------------------------ When inspecting a transformed JAX function at runtime, you'll find that array -values are replaced by :class:`~jax.core.Tracer` objects:: +values are replaced by `jax.core.Tracer` objects:: @jax.jit def f(x): @@ -841,12 +357,34 @@ reducing :code:`XLA_PYTHON_CLIENT_MEM_FRACTION` from the default of :code:`.75`, or setting :code:`XLA_PYTHON_CLIENT_PREALLOCATE=false`. For more details, please see the page on `JAX GPU memory allocation`_. -.. _JIT mechanics: https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#jit-mechanics-tracing-and-static-variables -.. _External callbacks in JAX: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html -.. _Pure callback example: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html#example-pure-callback-with-custom-jvp -.. _IO callback example: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html#exploring-jax-experimental-io-callback +.. _faq-data-placement: + +Controlling data and computation placement on devices +----------------------------------------------------- + +Moved to :ref:`sharded-data-placement`. + +.. _faq-benchmark: + +Benchmarking JAX code +--------------------- + +Moved to :ref:`benchmarking-jax-code`. + +.. _faq-donation: + +Buffer donation +--------------- + +Moved to :ref:`buffer-donation`. + + +.. _JIT mechanics: https://docs.jax.dev/en/latest/notebooks/thinking_in_jax.html#jit-mechanics-tracing-and-static-variables +.. _External callbacks in JAX: https://docs.jax.dev/en/latest/notebooks/external_callbacks.html +.. _Pure callback example: https://docs.jax.dev/en/latest/notebooks/external_callbacks.html#example-pure-callback-with-custom-jvp +.. _IO callback example: https://docs.jax.dev/en/latest/notebooks/external_callbacks.html#exploring-jax-experimental-io-callback .. _Heaviside Step Function: https://en.wikipedia.org/wiki/Heaviside_step_function .. _Sigmoid Function: https://en.wikipedia.org/wiki/Sigmoid_function .. _algebraic_simplifier.cc: https://github.com/openxla/xla/blob/33f815e190982dac4f20d1f35adb98497a382377/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc#L4851 -.. _JAX GPU memory allocation: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html +.. _JAX GPU memory allocation: https://docs.jax.dev/en/latest/gpu_memory_allocation.html .. _dynamic linker search pattern: https://man7.org/linux/man-pages/man8/ld.so.8.html diff --git a/docs/fault_tolerance.rst b/docs/fault_tolerance.rst new file mode 100644 index 000000000000..153b3c159399 --- /dev/null +++ b/docs/fault_tolerance.rst @@ -0,0 +1,1524 @@ +.. raw:: html + + + + + +Fault Tolerant Distributed JAX +============================== + +Recall that `multi-controller JAX`_ allows you to run a JAX program distributed +across multiple machines. By default, if *any* of these machines fail, then +*every* machine will fail. That is, multi-controller JAX is not +**fault-tolerant** by default. + +This article has three parts. In the first part, we'll explain the basics of +how to write fault tolerant multi-controller JAX programs. In the second part, +we'll show some example fault-tolerant multi-controller JAX programs. In the +third part, we'll take a look under the covers at how multi-controller JAX +implements fault tolerance. + +.. warning:: + + JAX's support for fault tolerance is still experimental. It currently only + works fully on GPUs. It has rough edges, is probably buggy, and is subject + to change. Use at your own risk. + + +.. _part1: + +Part 1: Fault Tolerance Basics +------------------------------ + +Fault Intolerant By Default +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +By default, multi-controller JAX programs are not fault tolerant. If *any* +process crashes, then *all* other processes will also intentionally crash. To +make this concrete, consider the following trivial script, ``example.py``, that +initializes multi-controller JAX by calling ``jax.distributed.initialize`` and +then enters an infinite loop. + +.. literalinclude:: _static/fault_tolerance/while_loop.py + :language: python + :emphasize-lines: 12-18 + :lines: 15- + :linenos: + :caption: ``example.py`` + +Run ``example.py`` across four processes on a VM with four GPUs by running +the following four commands, each in a different terminal. The +``local_device_ids`` argument to ``jax.distributed.initialize`` ensures each +process is assigned only one of the four GPUs. We'll explain the +``heartbeat_timeout_seconds`` argument in just a second. + +.. code-block:: shell + + python example.py --i=0 --n=4 # in terminal 1 + python example.py --i=1 --n=4 # in terminal 2 + python example.py --i=2 --n=4 # in terminal 3 + python example.py --i=3 --n=4 # in terminal 4 + +When you run these commands, you'll see the processes dutifully printing out +the current time every second. Next, fail the fourth process: ``pkill -9 -f +'python example.py --i=3 --n=4'``. After about ten seconds, the other +processes will also terminate and spit out error messages that look something +like this: + +.. code-block:: + + E0926 17:26:32.075402 157988 coordination_service_agent.cc:332] Polled an error from coordination service (this can be an error from this or another task). + F0926 17:26:32.075587 157988 client.h:77] Terminating process because the JAX distributed service detected fatal errors. This most likely indicates that another task died; see the other task logs for more details. Disable Python buffering, i.e. `python -u`, to be sure to see all the previous output. absl::Status: UNAVAILABLE: The following tasks are unhealthy (stopped sending heartbeats): + /job:jax_worker/replica:0/task:3 + The tasks have crashed. Check the task logs for an earlier error, or scheduler events (e.g. preemption, eviction) to debug further. + + RPC: /tensorflow.CoordinationService/PollForError [type.googleapis.com/tensorflow.CoordinationServiceError=''] + +When a process in a multi-controller JAX program notices that a peer process +has crashed, it decides to crash as well. The processes `share fate`_. The +``heartbeat_timeout_seconds`` argument to ``jax.distributed.initialize`` +determines how long a process waits before concluding a peer process has died. +The first three processes crash about ten seconds after you kill the fourth +because we passed ``heartbeat_timeout_seconds=10`` as an argument to +``jax.distributed.initialize``. + +Surviving Faults +^^^^^^^^^^^^^^^^ + +We can disable fate-sharing by adding the +``--xla_gpu_nccl_terminate_on_error=false`` flag and the +``jax_enable_recoverability`` configuration option to ``example.py``, as shown +below: + +.. literalinclude:: _static/fault_tolerance/dont_fail.py + :language: python + :emphasize-lines: 1-2,15 + :linenos: + :lines: 15- + +Again run the script across four processes and then kill the fourth. Notice +that now, the other three processes happily continue executing. + +Next try failing process 0. Notice that all four processes terminate with +error messages that look something like the following: + +.. code-block:: + + E0929 17:42:48.594192 1044529 coordination_service_agent.cc:332] Polled an error from coordination service (this can be an error from this or another task). + F0929 17:42:48.594200 1044529 client.h:77] Terminating process because the JAX distributed service detected fatal errors. This most likely indicates that another task died; see the other task logs for more details. Disable Python buffering, i.e. `python -u`, to be sure to see all the previous output. absl::Status: UNAVAILABLE: Failed to send RPC to coordination service. Either the leader task was preempted/died/restarted unexpectedly or this task is experiencing network issues. Check earlier logs from 1) this task, 2) the leader (usually slice 0 task 0), and 3) cluster scheduler to debug further. + Additional GRPC error information from remote target coordination_service while calling /tensorflow.CoordinationService/PollForError: + :UNKNOWN:Error received from peer {grpc_message:"Socket closed", grpc_status:14} + +Process 0 is special. If process 0 fails, every process will fail, even with +fate-sharing disabled. Why? Process 0 runs an RPC service called the +coordination service that all processes use to coordination with each other. If +the coordination service fails, all other processes have no choice but to fail. +See :ref:`part3` for more details. + +Getting Stuck in Collectives +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +``example.py`` is now able to survive faults, but the processes do not +communicate with each other at all. Any realistic multi-controller JAX program +would involve communication between the processes (otherwise, what's the point +of using multi-controller JAX?). Let's edit ``example.py`` so that the +processes perform a collective ``jnp.sum`` in every iteration of the loop. + +.. literalinclude:: _static/fault_tolerance/collectives.py + :language: python + :emphasize-lines: 27-32 + :linenos: + :lines: 15- + +In the highlighted code above, the processes create an array ``x`` sharded +across the four processes and then perform a distributed ``jnp.sum``. Again run +the program and fail the fourth process. You'll notice that the first three +process do not crash, but they do get *stuck*. By default, if a process fails +while participating in a distributed computation (like ``jnp.sum``), then the +rest of the processes participating in the computation will get stuck +*forever*. + +.. _`canceling_collectives`: + +Cancelling Collectives +^^^^^^^^^^^^^^^^^^^^^^ + +We can avoid getting stuck by cancelling collectives with a failed participant. +We can enable collective cancelling by providing a few more flags and +environment variables, highlighted below. + +.. literalinclude:: _static/fault_tolerance/cancel_collectives.py + :language: python + :emphasize-lines: 1-8,22,33-35 + :linenos: + :lines: 15- + +We also need to insert a call to +``jax.experimental.multihost_utils._live_devices`` to make the script work. You +should normally not do this. You should instead use the ``live_devices`` API +that we'll introduce momentarily. For now, ``_live_devices`` is a hack to get +the script working before we explain the proper API. + +Again run the script and fail the fourth process. The first three processes +will be stuck in their call to ``jnp.sum``, but after about ten seconds, the +call will be cancelled and ``jnp.sum`` will raise an exception that looks +something like this: + +.. code-block:: + + jaxlib._jax.XlaRuntimeError: FAILED_PRECONDITION: Task with incarnation id 3446767950926952685 is not connected + + +Knowing Who's Alive +^^^^^^^^^^^^^^^^^^^ + +After a process dies, the remaining *alive* procesess need to learn who is dead +and who is alive. For this, we can use the core JAX fault tolerance API: +``live_devices``. ``live_devices`` is a context manager that takes a list of +devices as an argument and returns the subset of these devices that are alive. +Below, we edit ``example.py`` to call ``live_devices``. + +.. literalinclude:: _static/fault_tolerance/live_devices.py + :language: python + :emphasize-lines: 34-46 + :linenos: + :lines: 15- + +In the highlighted code above, we call ``live_devices`` with all devices +(``jax.devices()``) to get the set ``devices`` of live devices. We then shard +array ``x`` over these devices and perform a ``jnp.sum``. If a process fails +while executing the ``jnp.sum``, then ``jnp.sum`` will be cancelled and raise +an exception on the remaining live devices. Technically, the collective is not +guaranteed to fail. We'll revisit this in :ref:`atomicity`. For now, assume it +will fail. + +.. note:: + + ``jax.devices()`` always returns the set of *all* devices, even if some of + these devices are on failed processes. Use + ``jax.experimental.multihost_utils.live_devices`` to learn which of these + devices are live. + +Again run the script and fail the fourth process. Notice that the remaining +three alive processes catch the exception raised by ``jnp.sum`` and continue to +the next iteration of the while loop. In this next iteration, ``devices`` does +not include the device on the failed fourth process. The three alive processes +continue to execute correctly even though the fourth process is dead. + +Next, restart the fourth process. Notice that after the fourth process +restarts, its device is again included in the set of alive devices returned by +``live_devices``. All four processes then continue executing normally. + +At first blush, ``live_devices`` seems trivial. You give it a list of devices, +and it returns the ones that are alive. How complicated can that be? +Unfortunately, as with `many things in distributed systems`_, there are a lot +subtleties to iron out. Next, we explain the **barrier** semantics and +**atomicity** properties of ``live_devices``. + +Barrier Semantics +^^^^^^^^^^^^^^^^^ + +Recall that every process in a `multi-controller JAX`_ program should run in +lockstep. The processes should execute the same instructions in the same order. +Failing to do so will *almost certainly* lead to deadlocks, crashes, or +anomalous behavior. + +In the context of ``live_devices``, we need to ensure that every process agrees +on which processes are currently alive. This is difficult to ensure because +every process is executing independently at potentially different speeds and +processes can fail at any time. Consider again the ``example.py`` script from +above running on four processes. Imagine process 1 and 2 call ``live_devices``, +then process 4 fails, and then process 3 calls ``live_devices``. Process 1 and +2 might think process 4 is alive while process 3 thinks it is dead. + +To avoid situations like these, ``live_devices`` guarantees that it returns the +same set of live devices to every process. It accomplishes this using a +barrier. A call to ``live_devicess(devices)`` blocks until every live process +hosting a device in ``devices`` has also called ``live_devices``. Once every +live process is in the ``live_devices`` barrier, ``live_devices`` returns the +same set of live devices to every process. + +.. important:: + + ``live_devices`` uses a barrier to ensure that it will *always* return the + same set of live devices to every live process. + +Because ``live_devices`` implements a barrier it is susceptible to deadlock if +used improperly. We recommend only having a single ``with live_devices`` block +in a program. Multiple calls to ``live_devices`` is hard to reason about and +can lead to deadlock. + +See :ref:`part3` for details on how the ``live_devices`` barrier is implemented +as well as a formal semantics based on `linearizability`_. + +.. _atomicity: + +Atomicity +^^^^^^^^^ + +A distributed computation is **atomic** if every participant in the computation +agrees on whether the operation succeeds or fails. In the ``example.py`` script +above, we saw that when a process failed during the execution of a ``jnp.sum``, +then ``jnp.sum`` would abort and raise an exception on the remaining live +processes. So ``jnp.sum`` is atomic? + +Unfortunately, it's not. + +When a process fails during the execution of a collective operation (like +``jnp.sum``), the remaining processes may cancel the operation and raise an +exception or they may complete the operation successfully. Collective +operations in JAX do not have any inherent atomicity properties. + +If collective operations are not atomic, however, then multi-controller JAX +processes might diverge. For example, if a process fails during a training step +of a machine learning model, some processes might detect the failure and roll +the model back to a checkpoint while other processes might think the step +succeeded and keep training. + +To avoid the complexities of non-atomic execution, ``live_devices`` provides +its own atomicity guarantees despite the fact that collectives are not atomic. +Specifically, the body of a ``with live_devices`` block is guaranteed to either +complete successfully on all processes or raise an exception on all processes. +More concretely, if we consider the code snippet below, either every process +executes branch A or every process executes branch B. It is impossible for some +processes to execute A while others execute B. + +.. code-block:: python + + try: + with live_devices(jax.live_devices()) as devices: + ... + except Exception as e: + ... # Branch A + else: + ... # Branch B + +.. warning:: + + A ``with live_devices`` block does not guarantee atomicity if the code + block non-deterministically raises exceptions for reasons other than + collectives that fail because of a crashed process. For example, if one + process raises an exception because it runs out of memory, this exception + will not be propagated to the other processes. + +Recall that JAX uses `asynchronous dispatch`_. Operations like ``jnp.sum`` do +not block until the operation is complete. Instead, they return ``jax.Arrays`` +that act as futures. This asynchrony can interact with ``live_devices`` in +unexpected ways. For example, consider the following code that performs a +``jnp.sum``, assigns the result to ``y``, and then prints ``y``: + +.. code-block:: python + + x = ... + y = ... + try: + with live_devices(jax.live_devices()) as devices: + y = jnp.sum(x) + except Exception as e: + ... # Branch A + else: + ... # Branch B + print(y) + +Imagine that the ``with live_devices`` block executes successfully on all +processes. That is, all processes execute branch B. This only guarantees that +every process successfully created a future and assigned it to ``y``. The +actual computation of the ``jnp.sum`` may be delayed until outside the block. +Thus, some processes might successfully complete the ``jnp.sum`` and print the +value of ``y`` while other processes fail to complete the ``jnp.sum`` and raise +an exception when trying to print ``y``. + +To avoid this, use ``jax.block_until_ready`` to ensure that computations are +performed within the ``with live_devices`` block. The code snippet below, which +now calls ``jax.block_until_ready`` when assigning to ``y``, guarantees that +every process will successfully execute the ``jnp.sum`` or every process will +raise an exception. + +.. code-block:: python + + x = ... + y = ... + try: + with live_devices(jax.live_devices()) as devices: + y = jax.block_until_ready(jnp.sum(x)) + except Exception as e: + ... # Branch A + else: + ... # Branch B + print(y) + +See :ref:`part3` for details on how atomicity is implemented. + +Part 2: Examples +---------------- + +``live_devices`` is not a panacea; it is a tool. It does not magically make +multi-controller JAX programs fault tolerant. Rather, it allows you to +implement fault tolerance yourself in the way that is best for your +application. + +The exact details of how you implement fault-tolerance will vary greatly based +on the nature of your application. In this section, we present some examples of +how to use ``live_devices``. The examples are meant to be illustrative but not +prescriptive. There are many other ways to implement fault tolerance. + +Example 1: Fault Tolerant Data Parallel Training +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In this example, we train a trivial single-parameter linear model (:math:`y = +\alpha x`) with data parallelism across four processes. The example is +contrived---you would never train a model with a single parameter across four +machines---but we intentionally keep the model simple to focus on fault +tolerance. + +Data parallelism makes implementing fault tolerance relatively straightforward. +Because every process has a full copy of the model weights, if a process fails, +we can simply ignore it and continue training. This example tolerates an +arbitrary number of process failures (excluding process 0), but once a process +fails, we assume it does not recover. The next example shows how to handle +process recovery. + +First, we set some flags to disable fate-sharing and enable collective +cancelling. We also make the necessary imports and define some flags. + +.. literalinclude:: _static/fault_tolerance/data_parallelism.py + :language: python + :lines: 15-33 + :lineno-start: 1 + +Next, we define a ``replicated`` function that returns an array replicated +across a set of devices. Note that ``replicated`` doesn't actually move any +data. It assumes the argument ``x`` already has equal value across all +processes. It simply returns a new view of that data, in a process-spanning +`jax.Array` with a replicated sharding. + +.. literalinclude:: _static/fault_tolerance/data_parallelism.py + :language: python + :lines: 35-49 + :lineno-start: 21 + +We define a similar ``sharded`` function that returns an array sharded across a +set of devices. Again, ``sharded`` is not actually moving any data between +processes. + +.. literalinclude:: _static/fault_tolerance/data_parallelism.py + :language: python + :lines: 52-64 + :lineno-start: 38 + +Now, we're ready to start writing our training loop. We begin by initializing +multi-controller JAX by calling ``jax.distributed.initialize``. + +.. literalinclude:: _static/fault_tolerance/data_parallelism.py + :language: python + :lines: 67-76 + :lineno-start: 53 + +Then, we define our simple linear model, generate some random training data, +and initialize some basic hyperparameters. + +.. literalinclude:: _static/fault_tolerance/data_parallelism.py + :language: python + :lines: 78-97 + :lineno-start: 64 + +Finally, we enter the main training loop. + +.. literalinclude:: _static/fault_tolerance/data_parallelism.py + :language: python + :lines: 99-125 + :lineno-start: 85 + +- Every iteration of the loop, we call ``live_devices`` to learn which devices + are currently alive. +- We then ensure that the model weights are replicated across these devices and + ensure that the training data is sharded across these devices. Note that this + doesn't actually move any data between the devices; it simply creates JAX + arrays with the appropriate replication and sharding metadata. +- We call ``loss_and_grad`` to compute the gradient of the weights with respect + to the current batch of data and then compute the new weights. Notice that we + assign the new weights to ``new_weights`` rather than assigning to + ``weights`` in case the training step fails. We also call + ``jax.block_until_ready`` to ensure that every process has computed the new + weights when we exit the ``live_devices`` block. +- If no processes failed during the execution of the training step, then the + ``else`` branch is taken. The step is incremented, and ``weights`` is + updated. Otherwise, an exception will be raised and the ``except`` branch is + taken. In this case, we do not update ``step`` or ``weights`` and retry the + step on the next iteration with the new set of live devices. + +Here is the full example: + +.. literalinclude:: _static/fault_tolerance/data_parallelism.py + :language: python + :linenos: + :lines: 15- + +Example 2: Fault Tolerant Data Parallel Training With Recovery +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Now, we modify the example above to allow failed processes to recover. When a +process recovers, it needs to receive the current step and model weights. +Because we assume process 0 never fails---recall that if process 0 fails, every +process will fail---we have process 0 send the current step and weights to +recovering processes. + +First, we define ``send`` and ``recv`` functions that use a ``shard_map`` to +send data from one device to another. The sender calls ``send``, and the +receiver calls ``recv``. + +.. literalinclude:: _static/fault_tolerance/data_parallelism_with_recovery.py + :language: python + :lines: 69-90 + :lineno-start: 55 + +``allgather`` performs an AllGather of a single float across a set of devices. + +.. literalinclude:: _static/fault_tolerance/data_parallelism_with_recovery.py + :language: python + :lines: 93-100 + :lineno-start: 79 + +Finally, we modify the training loop to handle recovering processes, as shown +in the highlighted code below. + +.. literalinclude:: _static/fault_tolerance/data_parallelism_with_recovery.py + :language: python + :lines: 135-178 + :lineno-start: 121 + :emphasize-lines: 7-22 + +Recovery is a two-step process. First, we need to detect which processes are +recovering. Second, we need process 0 to send the step and weights to the +recovering processes. + +1. To detect which processes are recovering, we perform an AllGather on all + live processes' steps. When a failed process recovers, its ``step`` will be + ``0``, while the ``step`` on process ``0`` will be some positive number, so + if a process' step is not equal to process 0's step, then it is recovering. +2. Then, we call the ``send`` and ``recv`` functions we defined above to + transfer the current step and model weights from process 0 to the recovering + processes. + +Here is the full example: + +.. literalinclude:: _static/fault_tolerance/data_parallelism_with_recovery.py + :language: python + :linenos: + :lines: 15- + +.. _part3: + + +Part 3: Implementation Details +------------------------------ + +We now take a deep dive into the architecture of multi-controller JAX and the +semantics and implementation of ``live_devices``. If you're only interested in +writing fault-tolerant multi-controller JAX programs, the first two parts of +this article suffice. + +The Coordination Service +^^^^^^^^^^^^^^^^^^^^^^^^ + +When you launch a multi-controller JAX program, the first process (i.e. process +0) runs a standalone RPC server called the **coordination service**. Moreover, +all processes (including process 0) create an RPC client to the coordination +service. Concretely, the ``coordinator_address`` argument of +:func:`jax.distributed.initialize` is the address of the coordination service. +This argument lets process 0 know on what address to run the server, and it +lets all processes know which address to connect to. + +The coordination service implements the multi-controller JAX **control plane**. +For example, it can perform a distributed barrier across all processes, and it +implements a key-value store that processes can use to exchange small amounts +of metadata. Note, however, that the **data plane** (e.g., all collective +operations on program data) is implemented directly between the processes and +does not involve the coordination service. + +One of the most important functionalities of the coordination service is health +checking. Every process periodically sends a heartbeat to the coordination +service. If a process fails, it stops sending heartbeats. If the coordination +service hasn't received a heartbeat from a process for a while, it assumes the +process has failed. + +This is shown in the interactive visualization below. The coordination service +is shown at the top and three multi-controller JAX processes are shown at the +bottom. Note how the processes periodically send heartbeats to the controller, +and the controller keeps track of the health of each process based on when it +last received a heartbeat. Try failing process 2 by clicking the "Fail" button. +Observe how the process stops sending heartbeats and the coordination service +eventually considers the process dead. + +.. raw:: html + +
+ + +By default, when the coordination service detects that a process has failed, it +sends a message to all other processes requesting that they self-terminate. In +other words, all processes in a multi-controller JAX program `share fate`_. +Again fail process 2 in the visualization below by clicking the "Fail" button +and observe how the coordination service notifies the other processes to fail. + +.. raw:: html + +
+ + +This fate sharing means that multi-controller JAX programs are not at all +fault-tolerant. They are fault-*intolerant*. To enable fault-tolerance, we +need to do two things: + +- First, we need to remove fate sharing and allow processes to continue + executing even when a peer process has died. This can be enabled using the + ``jax_enable_recoverability`` option, as described in :ref:`part1`. We'll + assume that this option is set. +- Second, we need to provide an API that processes can use to learn which + processes are alive and which have failed. This is the ``live_devices`` API + introduced in :ref:`part1`. + +There is a surprising amount of technical depth and subtlety in implementing +the ``live_devices`` API. We'll walk through the design and implementation of +the API step-by-step. We'll begin by introducing a simpler ``live_processes`` +API and slowly improve it until we arrive at the ``live_devices`` API. + +Live Processes +^^^^^^^^^^^^^^ + +Let's try to design a new hypothetical JAX API: ``jax.live_processes``. As the +name suggests, we want ``jax.live_processes()`` to return the set of all +currently alive processes. Here is a naive but (as we'll see momentarily) +incorrect implementation. When a process calls ``jax.live_processes()``, it +sends an RPC request to the coordination service. Remember that the +coordination service already uses heartbeats to keep track of which processes +are dead and which are alive, so when it receives a ``jax.live_processes`` +request, it responds with the set of processes it thinks are alive. + +This is illustrated below. Below each process is a "Call live_processes" +button. You can click this button to make the process call +``jax.live_processes``. Note how the coordination service replies to a +``live_processess`` request with the set of alive processes. Fail process 2 by +clicking the "Fail" button and see how it affects later calls to +``jax.live_processes``. + +.. raw:: html + +
+ + +This naive implementation is simple but incorrect. It is crucial that all +processes in a multi-controller JAX job execute the same instructions in the +same order. If the processes start to diverge, by executing different code +paths in the JAX program, the job will behave erratically. Most likely, it will +crash or hang or produce garbage values, and most certainly it will be very +hard to reason about. + +Our naive implementation of ``jax.live_processes`` can very easily lead to +divergence. For example, consider a multi-controller JAX job with three +processes. If process 0 and 1 both call ``jax.live_processes`` around the same +time that process 2 fails, the coordination service might report to process 0 +that all processes are alive but report to process 1 that only processes 0 and +1 are alive. Try to produce this scenario in the visualization below: + +.. raw:: html + +
+ + +If processes disagree on which processes are alive, they will almost certainly +diverge. Thankfully, we can avoid this divergence by augmenting +``jax.live_processes`` with barrier semantics. + +Barrier Semantics +^^^^^^^^^^^^^^^^^ + +Let's change the implementation of ``jax.live_processes`` so that when the +coordination service receives a ``jax.live_processes()`` request, it does not +reply right away. Instead, the coordination service only replies once *every* +live process has called ``jax.live_processes()``. Once every alive process has +entered the ``jax.live_processess()`` barrier, the coordination service returns +the set of live processes. Crucially, the coordination service returns the +*same* set of live processes to all processes, which prevents the processes +from diverging. + +This is illustrated below. Note that coordination server now keeps track of +which devices are in the ``live_processes`` barrier. Try calling +``live_processes`` from every process. Notice how the coordination service +doesn't respond until every process has entered the barrier. Then fail process +2 and call ``live_processes`` from process 0 and process 1. + +.. raw:: html + +
+ + +Formal Semantics +^^^^^^^^^^^^^^^^ + +Distributed systems are notoriously complex. Machines can fail at arbitrary +times, and network messages can be dropped, delayed, and reordered. In this +section, we introduce a formal semantics of the ``jax.live_processes`` API to +help tame this complexity. Thinking rigorously about the semantics of +``jax.live_processes`` will help us understand the behavior of the API even in +pathological executions. + +We'll base the formal semantics of ``jax.live_processes`` on +`linearizability`_: a popular formalism used to define the semantics of many +distributed APIs. Concretely, we model our distributed system as a number of +processes. Each process serially performs a number of events. There are four +types of events: + +1. A process can **start** (👶). We'll assume that when a process starts, it + connects to the coordination service, so the coordination service is aware + that is has started. +2. A process can **fail** (💀). Unlike starting, the coordination service may + not immediately be aware that a process has failed. +3. A process can **send** a ``jax.live_processes`` request to the coordination + service. +4. A process can **receive** a reply to a ``jax.live_processes`` request from + the coordination service. + +Below is a diagram of an execution of three processes: 0, 1, and 2. Time +progresses from left to right. First, all three processes start. This is shown +with the baby emojis. Then all three processes send ``jax.live_processes`` +requests to the coordination service. This is shown as the start of the thick +colored regions. Later, all three processes receive a reply from the +coordination service with ``0,1,2`` as the set of live devices. + +.. raw:: html + +
+ + + 0 + 1 + 2 + + + + + + + + 👶 + + 0,1,2 + + + 👶 + + 0,1,2 + + + 👶 + + 0,1,2 + +
+ +In this simple execution, it is clear that ``jax.live_processes`` is behaving +correctly. We can formalize this intuition with the following formal semantics. + +.. attention:: + + An execution is valid if whenever ``jax.live_processes`` returns a set ``P`` + of live processes, there exists an instantaneous moment in time at which + every process in ``P`` was in the ``live_processes`` barrier and every other + process was dead. An implementation of ``live_processes`` is correct if + it only allows for valid executions. + +Later, we will amend these formal semantics to cover some subtle corner cases, +but assume this simplified semantics for now. + +In the example above, ``live_processes`` returns ``0,1,2``. In the +visualization below, we show that there does exist an instantaneous moment of +time in which processes 0, 1, and 2 are all in the barrier and all other +processes (there are none) are dead. The moment in time is drawn as a vertical +red bar. + +.. raw:: html + +
+ + + 0 + 1 + 2 + + + + + + + + 👶 + + 0,1,2 + + + 👶 + + 0,1,2 + + + 👶 + + 0,1,2 + + + + +
+ +There is nothing special about the specific moment in time we chose in the +visualization above. All that's important is that *there exists some* moment in +time where all processes in `P` are in the barrier and all other processes are +dead. There are many moments in time that satisfy this property, as shown +below. + +.. raw:: html + +
+ + + 0 + 1 + 2 + + + + + + + + 👶 + + 0,1,2 + + + 👶 + + 0,1,2 + + + 👶 + + 0,1,2 + + + + + + + +
+ +In the next example, processes 0 and 1 start, call ``jax.live_devices``, and +receive ``0,1`` as a reply. Process 2 is dead throughout the execution. + +.. raw:: html + +
+ + + 0 + 1 + 2 + + + + + + + + 👶 + + 0,1 + + + 👶 + + 0,1 + + + 💀 + +
+ +This is a valid execution under our formal semantics because there exists a +moment a time in which processes 0 and 1 are in the barrier and process 2 is +dead. + +.. raw:: html + +
+ + + 0 + 1 + 2 + + + + + + + + 👶 + + 0,1 + + + 👶 + + 0,1 + + + 💀 + + + + +
+ +In the following execution, process 0 calls ``jax.live_processes`` and receives +a reply of ``0``. Process 1 calls ``jax.live_processes``, but dies before +receiving a reply. + +.. raw:: html + +
+ + + 0 + 1 + + + + + + + 👶 + + 0 + + + 👶 + + 💀 + +
+ +Is this a valid execution? Yes. There exists a moment in time at which process +0 is in the barrier and process 1 is dead, as shown below. Even though process +1 called ``jax.live_processes``, it is not guaranteed that process 1 will be +included in the coordination service's response. + +For example, process 1's ``jax.live_processes`` request may have been dropped +by the network and never received by the coordination service. So from the +coordination service's perspective, process 1 is thoroughly dead and never even +entered the ``live_processes`` barrier. + +.. raw:: html + +
+ + + 0 + 1 + + + + + + + 👶 + + 0 + + + 👶 + + 💀 + + + + +
+ +What about the same exact execution, except that process 0 now receives the +reply ``0,1`` from the coordination service? + +.. raw:: html + +
+ + + 0 + 1 + + + + + + + 👶 + + 0,1 + + + 👶 + + 💀 + +
+ +Again, this is a valid execution, as witnessed below. Intuitively, the +coordination service could have received ``jax.live_processes`` requests from +both processes 0 and 1 and sent the reply ``0,1`` to both. While this reply was +in the network, process 1 failed. Thus, even though process 1 is dead when +process 0 receives a reply, the execution is still valid. + +.. raw:: html + +
+ + + 0 + 1 + + + + + + + 👶 + + 0,1 + + + 👶 + + 💀 + + + + +
+ +This point bears repeating. If ``jax.live_processes`` returns a set ``P`` of +processes, it does not mean that all processes in ``P`` are *currently* alive +and all other processes are *currently* dead. It only means that *there existed +a point in time* when this was true. + +In the following execution, process 1 calls ``jax.live_processes`` and fails. +Later, process 0 starts, calls ``jax.live_processes``, and receives ``0,1`` as +a reply. + +.. raw:: html + +
+ + + 0 + 1 + + + + + + + 👶 + + 0,1 + + + 👶 + + 💀 + +
+ +Using the formal semantics described thus far, this is *not* a valid execution. +There is never a point in time where process 0 and 1 are both alive. However, +this *should* be a valid execution. + +The reason has to do with the unavoidable fact that in a distributed system, it +is impossible to detect failures with 100% accuracy. If the coordination +service hasn't received heartbeats from a process in a while, it considers the +process dead. But, the coordination service cannot determine with 100% +certainty when the process died or if the process is actually dead at all. +Maybe the process died a long time ago, or maybe it died very recently, or +maybe it is alive but on the other side of a network partition. + +Let's return to the execution above for a concrete example. Imagine the +coordination service successfully received process 1's ``live_processes`` +request. Then, process 1 failed but the coordination service didn't detect the +failure immediately. In the meantime, the coordination service received process +0's ``live_processes`` request. At this point, the coordination service thought +both processes were alive and saw that both processes were in the barrier, so +it naturally returned ``0,1`` to both processes (though only process 0 received +the reply because process 1 was dead). + +The coordination service thought process 1 was alive when it was dead. And +sometimes the coordination service might think a process is dead when it is +alive. Though not ideal, we need to accommodate executions like this because +they are unavoidable. + +We amend our formal semantics and allow ourselves to move a failure either +earlier or later in time, though we cannot move a failure past a different +event from the same process. Intuitively, we can move a failure from when it +actually happened to the point in time when the coordination service thought it +happened. Continuing the example above, we can delay the failure of process 1 +to create a moment in time in which both processes 0 and 1 are in the barrier, +witnessing the fact that the execution is valid. + +.. raw:: html + +
+ + + 0 + 1 + + + + + + + 👶 + + 0,1 + + + 👶 + + + + + + 💀 + + + + + + + +
+ +Consider a similar execution below. + +.. raw:: html + +
+ + + 0 + 1 + + + + + + + 👶 + + 0 + + + 👶 + + 💀 + +
+ +As is, there is no moment in time in which process 0 is alive and process 1 is +dead. However, if we move the failure of process 1 leftwards, there is. How +might such an execution arise? Imagine process 1 is partitioned from the +coordination service. The coordination service doesn't receive any messages +from process 1, including its heartbeats. This leads the coordination service +to conclude that process 1 is dead, even though it isn't. Then, the +coordination service receives process 0's ``live_processes`` request and +responds with ``0``. + +.. raw:: html + +
+ + + 0 + 1 + + + + + + + 👶 + + 0 + + + 👶 + + + + + + 💀 + + + + + + + +
+ +We cannot move a process failure past the process' other events, however. For +example, the following execution is *invalid* because no matter where we move +the failure of process 1, there is never a moment in time where both processes +are in the barrier. + +.. raw:: html + +
+ + + 0 + 1 + + + + + + + 👶 + + 0,1 + + + 👶 + 👶 + + + + + + 💀 + + +
+ +With these formal semantics, we can make sense of even complex executions. For +example, consider the following execution. + +.. raw:: html + +
+ + + 0 + 1 + 2 + + + + + + + + 👶 + + 0 + + 0,2 + + + 👶 + + 💀 + 👶 + + 💀 + + + 👶 + + 💀 + +
+ + +After moving some process failures, we see the execution is valid. + +.. raw:: html + +
+ + + 0 + 1 + 2 + + + + + + + + 👶 + + 0 + + 0,2 + + + 👶 + + 💀 + 👶 + + 💀 + + + 👶 + + 💀 + + + + + +
+ +The following execution, on the other hand, is invalid. + +.. raw:: html + +
+ + + 0 + 1 + 2 + + + + + + + + 👶 + + 0,2 + + + 👶 + + 1 + 💀 + + + 👶 + + 💀 + +
+ + +Atomicity +^^^^^^^^^ + +Equipped with ``jax.live_processes``, let's try to write some fault-tolerant +multi-controller JAX code. + +.. code-block:: python + + step = 0 + while True: + # Get the devices on all live processes. + procs = jax.live_processes() + devices = [d for d in jax.devices() if d.process_index in procs] + + # Shard array x over these devices. + mesh = jax.make_mesh((len(devices),), ("i",), devices=devices) + spec = jax.sharding.PartitionSpec("i") + sharding = jax.sharding.NamedSharding(mesh, spec) + x = jax.make_array_from_process_local_data(sharding, np.ones(1)) + + # Try to perform a jnp.sum. + try: + print(jnp.sum(x)) + except: + # jnp.sum failed. + pass + else: + # jnp.sum succeeded. + step += 1 + +The code repeatedly + +- calls ``jax.live_processes`` to learn which processes are alive, +- computes the set of devices on the healthy processes, +- shards an array across these healthy devices, +- performs a ``jnp.sum`` (i.e. AllReduce) on the array, and +- increments ``step`` if the ``jnp.sum`` succeeds. + +This code *looks* correct, but it has a very subtle bug. Assume the ``jnp.sum`` +is being performed across a set of processes ``P``. If one (or more) of the +processes in ``P`` fails during the execution of the ``jnp.sum``, then +``jnp.sum`` can behave differently on different processes. Some processes in +``P`` might see ``jnp.sum`` return the correct result. Other processes might +see ``jnp.sum`` raise an exception. Others might see ``jnp.sum`` return an +incorrect result. + +.. warning:: + + If a process fails during a collective operation, the operation may behave + differently on different processes. + +This means that the processes executing the code example above might diverge. +Some might increment ``step``, and some might not. In the trivial code example +above, this divergence is benign, but in a real program, the divergence would +likely lead to a crash, a deadlock, or garbage outputs. For example, if a +multi-controller JAX program is training a model with data parallelism and +starts to diverge, some processes might roll back their model weights to a +previous checkpoint while others continue training, leading to a +"franken-model" where nobody agrees on what the model weights are supposed to +be. + +To write fault-tolerant code that does not diverge, we want **atomicity**. When +executing a block of code (like the ``jnp.sum`` above), we either want *every* +process to run the code successfully, or *every* process to learn that the code +failed to execute successfully. We don't want some processes succeeding and +others failing. + +Thankfully, we can achieve atomicity with a very simple trick: call +``live_processes`` twice, once before a code block and once after. If all the +processes that were alive before the block are also alive after the block, then +the code block executed successfully on all live processes. On the other hand, +if any process died, then all remaining processes can agree the code block +failed to execute properly. Here's a sketch of what that might look like: + +.. code-block:: python + + # Get the set of live processes before the code block. + procs_before = jax.live_processes() + + # Execute the code block. + ... + + # Get the set of live processes after the code block + procs_after = jax.live_processes() + if procs_before == procs_after: + # The code block executed successfully on all processes in + # procs_before. + pass + else: + # The code block did not execute successfully. All processes will + # agree it failed. + pass + +The code above should give you a rough idea of how to use two calls to +``live_processes`` to achieve atomicity, but there are still a handful of small +issues we need to address before it is fully correct. For example, + +- What if the code block throws an exception? We need to catch the exception + and still call ``live_processess`` the second time and then re-raise the + exception. +- What if a process fails after the first call to ``live_processes`` and + recovers before the second call? Wouldn't the code block fail but the + processes before and after be the same? Every time a process starts, it + generates a random **incarnation id**. In addition to checking that the set + of processes hasn't changed, we also check that their incarnation ids haven't + changed. +- What if a process recovers and its first call to ``live_processes`` matches + up with a different process' second call to ``live_processes``? Couldn't this + lead to a deadlock? Yes. We can avoid the problem by only calling + ``live_processes`` at a single program point. We can be clever and use a + single call to ``live_processes`` for two purposes. It can be used to check + that the set of processes hasn't changed since the previous call to + ``live_processes``, and it can be used to generate the set of live processes + that should be used the next time the atomic code block is executed. + +All these details are handled and abstracted away by the ``jax.live_devices`` +API introduced in :ref:`part1`. ``jax.live_devices`` is a context manager that +guarantees the atomic execution of a block of code. In the code snippet below, +``devices`` is a list of the devices on all live processes. The code block +``A`` will execute atomically across these processes. That is, either every +process will see the code raise an exception (branch ``B``) or every process +will see the code succeed (branch ``C``). + +.. code-block:: python + + try: + with live_devices() as devices: + pass # A + except Exception as e: + pass # B + else: + pass # C + +Cancelling Collectives +^^^^^^^^^^^^^^^^^^^^^^ + +As mentioned in :ref:`canceling_collectives`, if a process participating in a +collective fails, then the other participating processes get stuck forever. We +need to explicitly cancel these collectives to allow the alive participants to +make progress. While the ``live_devices`` API is supported on all JAX backends +(i.e. CPU, GPU, TPU), cancelling collectives is only supported by the GPU +backend. Here, we briefly explain some of the implementation details behind +collective cancelling. + +The GPU backend implements collectives using `NCCL`_, NVIDIA's collective +communication library. When a set of processes wants to perform a collective, +they form a **NCCL communicator**. Processes can then repeatedly perform +collectives using this communicator. Creating a communicator is expensive---it +requires network communication---so the JAX backend caches communicators keyed +by the set of participating processes and their incarnation ids. + +Internally, a JAX client polls the coordination service for the current status +of every process. If a client ever detects that a process is dead or has +restarted with a new incarnation id, then the client aborts all communicators +with the failed incarnation id in its cache key. + +.. _asynchronous dispatch: https://docs.jax.dev/en/latest/async_dispatch.html +.. _linearizability: https://cs.brown.edu/~mph/HerlihyW90/p463-herlihy.pdf +.. _many things in distributed systems: https://en.wikipedia.org/wiki/Fallacies_of_distributed_computing +.. _multi-controller JAX: https://docs.jax.dev/en/latest/multi_process.html +.. _NCCL: https://developer.nvidia.com/nccl +.. _reference: https://docs.jax.dev/en/latest/config_options.html#jax_enable_recoverability +.. _share fate: https://en.wikipedia.org/wiki/Fate-sharing diff --git a/docs/ffi.ipynb b/docs/ffi.ipynb index b622fba9d5bc..ceb5a0f8dd36 100644 --- a/docs/ffi.ipynb +++ b/docs/ffi.ipynb @@ -439,7 +439,7 @@ "As far as JAX is concerned, the foreign function is a black box that can't be inspected to determine the appropriate behavior when differentiated.\n", "Therefore, it is the {func}`~jax.ffi.ffi_call` user's responsibility to define a custom derivative rule.\n", "\n", - "More details about custom derivative rules can be found in the [custom derivatives tutorial](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html), but the most common pattern used for implementing differentiation for foreign functions is to define a {func}`~jax.custom_vjp` which itself calls a foreign function.\n", + "More details about custom derivative rules can be found in the [custom derivatives tutorial](https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html), but the most common pattern used for implementing differentiation for foreign functions is to define a {func}`~jax.custom_vjp` which itself calls a foreign function.\n", "In this case, we actually define two new FFI calls:\n", "\n", "1. `rms_norm_fwd` returns two outputs: (a) the \"primal\" result, and (b) the \"residuals\" which are used in the backwards pass.\n", @@ -730,13 +730,13 @@ "source": [ "This clearly (to us!) isn't the optimal partitioning of this function, but it's the best that JAX/XLA can do with the information given.\n", "\n", - "To generate better partitioning logic, we can use {func}`~jax.experimental.shard_map.shard_map` or {func}`~jax.experimental.custom_partitioning.custom_partitioning`, and we discuss both options here.\n", + "To generate better partitioning logic, we can use {func}`~jax.shard_map` or {func}`~jax.experimental.custom_partitioning.custom_partitioning`, and we discuss both options here.\n", "That being said, it's not straightforward to generate _optimal_ partitioning for all inputs, because sometimes this would require algorithmic changes.\n", "Specifically, let's add support for \"batch partitioning\", which handles the case where the data are sharded on batch dimensions, but sharding on the last dimension will always require in re-sharding.\n", "\n", "### Using `shard_map`\n", "\n", - "If you are using manual sharding control via {func}`~jax.experimental.shard_map.shard_map`, any FFI calls in your program should already partition appropriately:" + "If you are using manual sharding control via {func}`~jax.shard_map`, any FFI calls in your program should already partition appropriately:" ] }, { @@ -746,9 +746,8 @@ "outputs": [], "source": [ "from functools import partial\n", - "from jax.experimental.shard_map import shard_map\n", "\n", - "@partial(shard_map, mesh=mesh, in_specs=P(\"x\", None), out_specs=P(\"x\", None))\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=P(\"x\", None), out_specs=P(\"x\", None))\n", "def rms_norm_shmap(x):\n", " return rms_norm(x)\n", "\n", @@ -781,11 +780,11 @@ "source": [ "### Using `custom partitioning`\n", "\n", - "If you can't use {func}`~jax.experimental.shard_map.shard_map`, an alternative approach is to use {func}`~jax.experimental.custom_partitioning.custom_partitioning`, which supports automatic parallelization via {func}`jax.jit`.\n", + "If you can't use {func}`~jax.shard_map`, an alternative approach is to use {func}`~jax.experimental.custom_partitioning.custom_partitioning`, which supports automatic parallelization via {func}`jax.jit`.\n", "{func}`~jax.experimental.custom_partitioning.custom_partitioning` works by adding Python callbacks into the XLA compiler's partitioning pass, which allows very flexible logic, but also comes with some rough edges.\n", "We won't go into too much detail on the caveats here, but the main issues that you should be aware of are:\n", "\n", - "1. `custom_partitioning` can cause unexpected cache misses when used with the JAX's [Persistent compilation cache](https://jax.readthedocs.io/en/latest/persistent_compilation_cache.html). This can be mitigated using the `jax_remove_custom_partitioning_ptr_from_cache_key` configuration flag, but that isn't always appropriate either.\n", + "1. `custom_partitioning` can cause unexpected cache misses when used with the JAX's [Persistent compilation cache](https://docs.jax.dev/en/latest/persistent_compilation_cache.html). This can be mitigated using the `jax_remove_custom_partitioning_ptr_from_cache_key` configuration flag, but that isn't always appropriate either.\n", "2. Debugging `custom_partitioning` logic can be tedious because Python errors don't always get propagated, instead causing your Python process to exit. That being said, any exceptions will show up in the process logs, so you should be able to track them down there.\n", "\n", "All that being said, here's how we can wrap our FFI implementation of `rms_norm` using {func}`~jax.experimental.custom_partitioning.custom_partitioning`:" @@ -843,6 +842,7 @@ "rms_norm_partitioned.def_partition(\n", " infer_sharding_from_operands=rms_norm_infer_sharding_from_operands,\n", " partition=rms_norm_partition,\n", + " sharding_rule=\"... i -> ... j\",\n", ")\n", "\n", "output = jax.jit(rms_norm_partitioned, out_shardings=batch_shd)(x_batch_shd)\n", diff --git a/docs/ffi.md b/docs/ffi.md index 4aa03c217855..2da124c3c707 100644 --- a/docs/ffi.md +++ b/docs/ffi.md @@ -353,7 +353,7 @@ Unlike with batching, {func}`~jax.ffi.ffi_call` doesn't provide any default supp As far as JAX is concerned, the foreign function is a black box that can't be inspected to determine the appropriate behavior when differentiated. Therefore, it is the {func}`~jax.ffi.ffi_call` user's responsibility to define a custom derivative rule. -More details about custom derivative rules can be found in the [custom derivatives tutorial](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html), but the most common pattern used for implementing differentiation for foreign functions is to define a {func}`~jax.custom_vjp` which itself calls a foreign function. +More details about custom derivative rules can be found in the [custom derivatives tutorial](https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html), but the most common pattern used for implementing differentiation for foreign functions is to define a {func}`~jax.custom_vjp` which itself calls a foreign function. In this case, we actually define two new FFI calls: 1. `rms_norm_fwd` returns two outputs: (a) the "primal" result, and (b) the "residuals" which are used in the backwards pass. @@ -556,19 +556,18 @@ print(hlo.split("\n\n")[-1]) This clearly (to us!) isn't the optimal partitioning of this function, but it's the best that JAX/XLA can do with the information given. -To generate better partitioning logic, we can use {func}`~jax.experimental.shard_map.shard_map` or {func}`~jax.experimental.custom_partitioning.custom_partitioning`, and we discuss both options here. +To generate better partitioning logic, we can use {func}`~jax.shard_map` or {func}`~jax.experimental.custom_partitioning.custom_partitioning`, and we discuss both options here. That being said, it's not straightforward to generate _optimal_ partitioning for all inputs, because sometimes this would require algorithmic changes. Specifically, let's add support for "batch partitioning", which handles the case where the data are sharded on batch dimensions, but sharding on the last dimension will always require in re-sharding. ### Using `shard_map` -If you are using manual sharding control via {func}`~jax.experimental.shard_map.shard_map`, any FFI calls in your program should already partition appropriately: +If you are using manual sharding control via {func}`~jax.shard_map`, any FFI calls in your program should already partition appropriately: ```{code-cell} ipython3 from functools import partial -from jax.experimental.shard_map import shard_map -@partial(shard_map, mesh=mesh, in_specs=P("x", None), out_specs=P("x", None)) +@partial(jax.shard_map, mesh=mesh, in_specs=P("x", None), out_specs=P("x", None)) def rms_norm_shmap(x): return rms_norm(x) @@ -587,11 +586,11 @@ assert "all-to-all" in hlo_data_shmap ### Using `custom partitioning` -If you can't use {func}`~jax.experimental.shard_map.shard_map`, an alternative approach is to use {func}`~jax.experimental.custom_partitioning.custom_partitioning`, which supports automatic parallelization via {func}`jax.jit`. +If you can't use {func}`~jax.shard_map`, an alternative approach is to use {func}`~jax.experimental.custom_partitioning.custom_partitioning`, which supports automatic parallelization via {func}`jax.jit`. {func}`~jax.experimental.custom_partitioning.custom_partitioning` works by adding Python callbacks into the XLA compiler's partitioning pass, which allows very flexible logic, but also comes with some rough edges. We won't go into too much detail on the caveats here, but the main issues that you should be aware of are: -1. `custom_partitioning` can cause unexpected cache misses when used with the JAX's [Persistent compilation cache](https://jax.readthedocs.io/en/latest/persistent_compilation_cache.html). This can be mitigated using the `jax_remove_custom_partitioning_ptr_from_cache_key` configuration flag, but that isn't always appropriate either. +1. `custom_partitioning` can cause unexpected cache misses when used with the JAX's [Persistent compilation cache](https://docs.jax.dev/en/latest/persistent_compilation_cache.html). This can be mitigated using the `jax_remove_custom_partitioning_ptr_from_cache_key` configuration flag, but that isn't always appropriate either. 2. Debugging `custom_partitioning` logic can be tedious because Python errors don't always get propagated, instead causing your Python process to exit. That being said, any exceptions will show up in the process logs, so you should be able to track them down there. All that being said, here's how we can wrap our FFI implementation of `rms_norm` using {func}`~jax.experimental.custom_partitioning.custom_partitioning`: @@ -643,6 +642,7 @@ def rms_norm_partition(eps, mesh, args_info, result_info): rms_norm_partitioned.def_partition( infer_sharding_from_operands=rms_norm_infer_sharding_from_operands, partition=rms_norm_partition, + sharding_rule="... i -> ... j", ) output = jax.jit(rms_norm_partitioned, out_shardings=batch_shd)(x_batch_shd) diff --git a/docs/ffi/CMakeLists.txt b/docs/ffi/CMakeLists.txt index 9d3e9df7d3bf..b7f1af5c1a1b 100644 --- a/docs/ffi/CMakeLists.txt +++ b/docs/ffi/CMakeLists.txt @@ -4,7 +4,7 @@ project(rms_norm LANGUAGES CXX) find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED) execute_process( COMMAND "${Python_EXECUTABLE}" - "-c" "from jax.extend import ffi; print(ffi.include_dir())" + "-c" "from jax import ffi; print(ffi.include_dir())" OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE XLA_DIR) message(STATUS "XLA include directory: ${XLA_DIR}") diff --git a/docs/glossary.rst b/docs/glossary.rst index 286b07e21a66..0c36792ad146 100644 --- a/docs/glossary.rst +++ b/docs/glossary.rst @@ -77,7 +77,7 @@ Glossary of terms Tracer An object used as a standin for a JAX :term:`Array` in order to determine the sequence of operations performed by a Python function. Internally, JAX implements this - via the :class:`jax.core.Tracer` class. + via the `jax.core.Tracer` class. transformation A higher-order function: that is, a function that takes a function as input and outputs @@ -92,7 +92,7 @@ Glossary of terms XLA Short for *Accelerated Linear Algebra*, XLA is a domain-specific compiler for linear algebra operations that is the primary backend for :term:`JIT`-compiled JAX code. - See https://www.tensorflow.org/xla/. + See https://www.openxla.org/xla/. weak type A JAX data type that has the same type promotion semantics as Python scalars; diff --git a/docs/gpu_memory_allocation.rst b/docs/gpu_memory_allocation.rst index 6667589e7b72..9bec686a8675 100644 --- a/docs/gpu_memory_allocation.rst +++ b/docs/gpu_memory_allocation.rst @@ -64,12 +64,12 @@ Common causes of OOM failures **Disabling rematerialization HLO pass** Sometimes disabling the automatic rematerialization HLO pass is favorable to avoid poor remat choices by the compiler. The pass can be enable/disable by setting - :code:`jax.config.update('enable_remat_opt_pass', True)` or - :code:`jax.config.update('enable_remat_opt_pass', False)` respectively. Enabling or + :code:`jax.config.update('jax_compiler_enable_remat_pass', True)` or + :code:`jax.config.update('jax_compiler_enable_remat_pass', False)` respectively. Enabling or disabling the automatic remat pass produces different trade-offs between compute and memory. Note however, that the algorithm is basic and you can often get better trade-off between compute and memory by disabling the automatic remat pass and doing - it manually with `the jax.remat API `_ + it manually with `the jax.remat API `_ Experimental features diff --git a/docs/gpu_performance_tips.md b/docs/gpu_performance_tips.md index bf032dccff88..219051a0acdc 100644 --- a/docs/gpu_performance_tips.md +++ b/docs/gpu_performance_tips.md @@ -1,6 +1,6 @@ # GPU performance tips - + This document focuses on performance tips for neural network workloads @@ -58,7 +58,173 @@ training on Nvidia GPUs](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta * **--xla_gpu_triton_gemm_any** Use the Triton-based GEMM (matmul) emitter for any GEMM that it supports. The default value is False. -### Communication flags +## Communication tips + +### Auto and manual PGLE + +The Profile Guided Latency Estimator (PGLE) workflow measures the actual running time +of compute and collectives, the the profile information is fed back into XLA compiler +for a better scheduling decision. + +The Profile Guided Latency Estimator can be used manually or automatically. In the auto mode +JAX will collect profile information and recompile a module in a single run. While +in manual mode you need to run a task twice, the first time to collect and save profiles +and the second to compile and run with provided data. + +**Important**: the JAX profiler, which is used by both of the PGLE workflows documented +below, cannot co-exist with the NVIDIA Nsight Systems profiler. This limitation can be +avoided by using the JAX compilation cache, as described below. + +### Auto PGLE +The auto PGLE can be turned on by setting the following environment variables: + +Mandatory: +```bash +export JAX_ENABLE_PGLE=true + +# For JAX version <= 0.5.0 make sure to include: +export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true" +``` + +Optional: +```bash +export JAX_PGLE_PROFILING_RUNS=3 +export JAX_PGLE_AGGREGATION_PERCENTILE=85 + +# Right now the auto PGLE profile collection doesn't work with command buffer. +# If the command buffer is enabled, Auto PGLE will disable it during profile +# collection and enable it back after the recompilation. If you need to have a +# consistent command buffer logic with and with PGLE profile you can disable it +# manually: +export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_command_buffer=''" +``` + +Or in the JAX this can be set as the following: + +``` +import jax +from jax._src import config + +with config.enable_pgle(True), config.pgle_profiling_runs(1): + # Run with the profiler collecting performance information. + train_step() + # Automatically re-compile with PGLE profile results + train_step() + ... +``` + +You can control amount of reruns used to collect profile data by changing `JAX_PGLE_PROFILING_RUNS`. +Increasing this parameter would lead to better profile information, but it will also increase the +amount of non-optimized training steps. + +Decreasing the `JAX_PGLE_AGGREGATION_PERCENTILE` parameter might help in case when performance between steps is too noisy to filter out a non-relevant measures. + +**Attention:** Auto PGLE doesn't work for pre-compiled modules. Since JAX need to recompile the module during execution the auto PGLE will not work neither for AoT nor for the following case: + +``` +import jax +from jax._src import config + +train_step_compiled = train_step().lower().compile() + +with config.enable_pgle(True), config.pgle_profiling_runs(1): + train_step_compiled() + # No effect since module was pre-compiled. + train_step_compiled() +``` + +#### Collecting NVIDIA Nsight Systems profiles when using AutoPGLE +[jax#24910](https://github.com/jax-ml/jax/pull/24910) (JAX v0.5.1 and newer) added a +new JAX configuration option, `JAX_COMPILATION_CACHE_EXPECT_PGLE`, which tells JAX to +attempt to load PGLE-optimized compiled functions from the persistent compilation +cache. + +This allows a two-step process, where the first step writes a PGLE-optimized function +to the cache: +```bash +export JAX_ENABLE_COMPILATION_CACHE=yes # not strictly needed, on by default +export JAX_COMPILATION_CACHE_DIR=/root/jax_cache +JAX_ENABLE_PGLE=yes python my-model.py +``` +And the second step uses Nsight Systems and loads the PGLE-optimized function from the +cache: +```bash +JAX_COMPILATION_CACHE_EXPECT_PGLE=yes nsys profile python my-model.py +``` +See also [this page]( +https://docs.jax.dev/en/latest/persistent_compilation_cache.html#pitfalls) for more +information about the persistent compilation cache and possible pitfalls. + +### Manual PGLE + +If you still want to use a manual Profile Guided Latency Estimator the workflow in XLA/GPU is: + +- 1. Run your workload once, with async collectives and latency hiding scheduler enabled. + +You could do so by setting: + +```bash +export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true" +``` + +- 2. Collect and post process a profile by using JAX profiler, saving the extracted instruction latencies into a binary protobuf file. + +```python +import os +from etils import epath +import jax +from jax.experimental import profiler as exp_profiler + +# Define your profile directory +profile_dir = 'gs://my_bucket/profile' +jax.profiler.start_trace(profile_dir) + +# run your workflow +# for i in range(10): +# train_step() + +# Stop trace +jax.profiler.stop_trace() +profile_dir = epath.Path(profile_dir) +directories = profile_dir.glob('plugins/profile/*/') +directories = [d for d in directories if d.is_dir()] +rundir = directories[-1] +logging.info('rundir: %s', rundir) + +# Post process the profile +fdo_profile = exp_profiler.get_profiled_instructions_proto(os.fspath(rundir)) + +# Save the profile proto to a file. +dump_dir = rundir / 'profile.pb' +dump_dir.parent.mkdir(parents=True, exist_ok=True) +dump_dir.write_bytes(fdo_profile) + +``` + +After this step, you will get a `profile.pb` file under the `rundir` printed in the code. + +- 3. Run the workload again feeding that file into the compilation. + +You need to pass the `profile.pb` file to the `--xla_gpu_pgle_profile_file_or_directory_path` flag. + +```bash + export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_pgle_profile_file_or_directory_path=/path/to/profile/profile.pb" +``` + +To enable logging in the XLA and check if the profile is good, set the logging level to include `INFO`: + +```bash +export TF_CPP_MIN_LOG_LEVEL=0 +``` + +Run the real workflow, if you found these loggings in the running log, it means the profiler is used in the latency hiding scheduler: + +``` +2023-07-21 16:09:43.551600: I external/xla/xla/service/gpu/gpu_hlo_schedule.cc:478] Using PGLE profile from /tmp/profile/plugins/profile/2023_07_20_18_29_30/profile.pb +2023-07-21 16:09:43.551741: I external/xla/xla/service/gpu/gpu_hlo_schedule.cc:573] Found profile, using profile guided latency estimator +``` + +#### Flags * **--xla_gpu_enable_latency_hiding_scheduler** This flag enables latency hiding schedulers to overlap asynchronous communication with computation efficiently. @@ -77,20 +243,6 @@ training on Nvidia GPUs](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta By adjusting this factor, users can fine-tune the trade-off between memory efficiency and performance optimizations. -* **--xla_gpu_enable_pipelined_collectives** When using pipeline parallelism, - this flag enables overlapping the (i+1)-th layer weight `AllGather` with the - i-th layer computation. It also enables overlapping (i+1)-th layer - weight `Reduce`/`ReduceScatter` with i-th layer's computation. The default - value is False. **There are some bugs when this flag is turned on.** -* **--xla_gpu_collective_permute_decomposer_threshold** This flag is useful when - performing [GSPMD pipelining](https://arxiv.org/abs/2105.04663). Setting a - nonzero threshold decomposes `CollectivePermute`s into - `CollectivePermuteReceiveDone` and `CollectivePermuteSendDone` pairs, so that - computation can be performed between each corresponding - `ReceiveDone`/`SendDone` pair and hence achieve more overlap. By default the - threshold is 0 and there is no decomposition. Setting it to threshold > 0 such - as `--xla_gpu_collective_permute_decomposer_threshold=1024` can enable this - feature. * **--xla_gpu_all_gather_combine_threshold_bytes** **--xla_gpu_reduce_scatter_combine_threshold_bytes** **--xla_gpu_all_reduce_combine_threshold_bytes** @@ -102,6 +254,449 @@ training on Nvidia GPUs](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta combine at least a Transformer Layer's weight `AllGather`/`ReduceScatter`. By default, the `combine_threshold_bytes` is set to 256. +### Pipeline Parallelism on GPU + +#### Using XLA Flags +XLA implements SPMD-based pipeline parallelism optimizations. This is a scaling +technique where the forward and backward pass are split into multiple pipeline +stages. Each device (or device group) processes the result of the previous +pipeline stage (or the pipeline input) and sends its partial result to the next +stage until the end of the pipeline is reached. This optimization works best +when the latency of the computation is larger than communication. At compile +time, the operations will be rearranged to overlap communication with +computation. + +For an optimized schedule, we recommend these XLA flags: +``` +--xla_gpu_enable_latency_hiding_scheduler=true +--xla_gpu_enable_command_buffer='' +--xla_disable_hlo_passes=collective-permute-motion +--xla_gpu_experimental_pipeline_parallelism_opt_level=PIPELINE_PARALLELISM_OPT_LEVEL_ENABLE +``` + +The following JAX example demonstrates a pattern where communication operations +are scheduled to overlap with computations. In this example we will illustrate +how to set up an optimized pipeline parallelism scheduling using 4 GPUs that +form a communication ring (device 0 -> device 1 -> device 2 -> device 3 -> +device 0). We refer to the pattern `0 -> 1 -> 2 -> 3` as the forward edge, and +`3 -> 0` as the back edge. + +``` +# Imports and setup +import functools +import jax +from jax import sharding +from jax.experimental import mesh_utils +import jax.numpy as jnp +import jax.random + +NUM_DEVICES = 4 +NUM_MICROBATCHES = 5 +NUM_CIRC_REPEATS = 2 +CONTRACTING_DIM_SIZE = 4096 +NON_CONTRACTING_DIM_SIZE = 8192 +COMPUTE_INTENSITY = 32 + +# Creates a collective permute for the "forward edge". +# 0->1, 1->2, ... (N-2)->(N-1) +def shift_right(arr): + padding = [[1, 0]] + [[0, 0]] * (arr.ndim - 1) + # Use lax.slice to guarantee the gradient is a pad. + return jax.lax.slice(jnp.pad(arr, padding), [0] * arr.ndim, arr.shape) + + +# Creates a collective permute for the "back edge". +# (N-1)->0 +def cycle_back(arr): + padding = [[0, NUM_DEVICES - 1]] + [[0, 0]] * (arr.ndim - 1) + return jax.lax.slice( + jnp.pad(arr, padding), + [NUM_DEVICES - 1] + [0] * (arr.ndim - 1), + (NUM_DEVICES - 1 + arr.shape[0],) + arr.shape[1:], + ) + + +def select_on_first_device(then_value, else_value): + assert then_value.shape == else_value.shape + is_first_device = jax.lax.broadcasted_iota("int32", then_value.shape, 0) == 0 + return jnp.where(is_first_device, then_value, else_value) + + +def select_on_last_device(then_value, else_value): + assert then_value.shape == else_value.shape + is_last_device = ( + jax.lax.broadcasted_iota("int32", then_value.shape, 0) == NUM_DEVICES - 1 + ) + return jnp.where(is_last_device, then_value, else_value) + + +def select_on_first_cycle(i, then_value, else_value): + assert then_value.shape == else_value.shape + is_first_cycle = i < NUM_MICROBATCHES + return jnp.where(is_first_cycle, then_value, else_value) + + +def while_body(carry, i): + """Body of the pipeline while loop.""" + weights, input_buffer, output_buffer, fwd_edge_data, bwd_edge_data = carry + + # Read input data from input buffer. + input_data = jax.lax.dynamic_slice( + input_buffer, + (0, (i + 0) % NUM_MICROBATCHES, 0, 0), + (NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE), + ) + + # Collective permute on the "forward edge" shifts data to the next stage. + fwd_edge_data = shift_right(fwd_edge_data) + + # Select compute argument based on device and pipeline cycle. + compute_argument = select_on_first_device( + select_on_first_cycle(i, input_data, bwd_edge_data), + fwd_edge_data, + ).reshape((NUM_DEVICES, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE)) + + # A few matmuls to simulate compute. + tmp = compute_argument + for _ in range(COMPUTE_INTENSITY): + tmp = jax.lax.dot_general(weights, tmp, (((2,), (1,)), ((0,), (0,)))) + compute_result = tmp.reshape( + (NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE) + ) + + # Read data from buffer to pass it to the first device of the pipeline on the + # "back edge". + bwd_edge_data = jax.lax.dynamic_slice( + output_buffer, + (0, (1 + i) % NUM_MICROBATCHES, 0, 0), + (NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE), + ) + + # Collective permute on the "back edge" passes data to the first device. + bwd_edge_data = cycle_back(bwd_edge_data) + + # Update output buffer. We do this after reading from it to avoid the data + # dependency. + output_buffer = jax.lax.dynamic_update_slice( + output_buffer, + compute_result, + (0, (2 + i) % NUM_MICROBATCHES, 0, 0), + ) + + fwd_edge_data = compute_result + carry = ( + weights, + input_buffer, + output_buffer, + fwd_edge_data, + bwd_edge_data, + ) + return carry, i + + +@functools.partial(jax.jit, static_argnames=["mesh"]) +def entry_computation(weights, input_buffer, mesh): + + # Init output buffer. + output_buffer = jnp.zeros_like(input_buffer) + + # Init dummy data for forward and backward edge passed through the while loop. + dummy_data = jnp.zeros( + shape=(NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE) + ).astype(jnp.float32) + dummy_data = jax.device_put( + dummy_data, + sharding.NamedSharding( + mesh, sharding.PartitionSpec("x") + ), + ) + + # Start pipeline. + carry = weights, input_buffer, output_buffer, dummy_data, dummy_data + num_iterations = NUM_CIRC_REPEATS * NUM_MICROBATCHES + NUM_DEVICES - 1 + carry, _ = jax.lax.scan(while_body, carry, xs=jnp.arange(num_iterations)) + _, _, output_buffer, _, _ = carry + + return output_buffer + + +def main(_): + + # Expect constant number of devices. + assert NUM_DEVICES == jax.local_device_count() + + # Create mesh. + mesh = sharding.Mesh( + mesh_utils.create_device_mesh([NUM_DEVICES]), + axis_names=["x"], + ) + + # Init weights. + weights = 1.0 / CONTRACTING_DIM_SIZE + weights = jax.lax.broadcast_in_dim( + weights, + shape=(NUM_DEVICES, CONTRACTING_DIM_SIZE, CONTRACTING_DIM_SIZE), + broadcast_dimensions=(), + ) + weights = jax.device_put( + weights, + sharding.NamedSharding( + mesh, sharding.PartitionSpec("x") + ), + ) + + # Init random input and replicate it across all devices. + random_key = jax.random.key(0) + input_buffer = jax.random.uniform( + random_key, + shape=( + NUM_MICROBATCHES, + CONTRACTING_DIM_SIZE, + NON_CONTRACTING_DIM_SIZE, + ), + ) + input_buffer = jax.lax.broadcast_in_dim( + input_buffer, + shape=( + NUM_DEVICES, + NUM_MICROBATCHES, + CONTRACTING_DIM_SIZE, + NON_CONTRACTING_DIM_SIZE, + ), + broadcast_dimensions=[1, 2, 3], + ) + input_buffer = jax.device_put( + input_buffer, + sharding.NamedSharding( + mesh, sharding.PartitionSpec("x") + ), + ) + + # Run computation. + output_buffer = entry_computation(weights, input_buffer, mesh) + print(f"output_buffer = \n{output_buffer}") +``` + +#### Using `psend` and `precv` + +The JAX example above lowers to `collective-permute` HLO instructions, which are +are implemented through `ncclSend` and `ncclRecv` on GPU. For users who want +more granular control over the ordering of collectives, they can use +`jax.lax.psend` and `jax.lax.precv` directly. Syntactically, these two functions +are analogous to their HLO counterparts. Users should keep in mind that their +program will deadlock when the source-target pairs in a *single* `psend` or +`precv` form a cycle, and when `psend` is not matched by `precv` and +vice-versa. + +If cycles are required in the device communication pattern, deadlocks can be +avoided by making sure that (1) no single `psend` or `precv` function's +source-target pairs contain a cycle, and that (2) a fake data dependency +is inserted to sequentialize the send/recv pairs. No collective can be scheduled +between `psend`/`precv` paris, which can only be controlled through +`jax.lax.optimization_barrier` at the JAX level. The test case +`test_psend_precv_basic_with_no_deadlock_cycle` in the file +[`shard_map_test.py`](https://github.com/jax-ml/jax/blob/main/tests/shard_map_test.py) is one such example. + +The pipeline parallelism example in the previous section uses the +`--xla_gpu_experimental_pipeline_parallelism_opt_level` XLA flag. The same +program can be rewritten using `psend` and `precv` without the flag, if manually +pipelined. + +``` +## same setup and imports +def while_body(carry, i): + ( + weights, + input_buffer, + output_buffer, + prev_compute_res, + prev_stage_slice_fwd, + prev_stage_slice_bwd, + ) = carry + + # Read input data from input buffer. + input_slice = jax.lax.dynamic_slice( + input_buffer, + (0, (i + 0) % NUM_MICROBATCHES, 0, 0), + (1, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE), + ) + + # send_fwd + fwd_send_token = jax.lax.psend( + prev_compute_res, + axis_name="x", + perm=[(0, 1), (1, 2), (2, 3)], + ) + + # Select compute argument based on device and pipeline cycle + compute_argument = select_on_first_device( + select_on_first_cycle(i, input_slice, prev_stage_slice_bwd), + prev_stage_slice_fwd, + ).reshape((1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE)) + + tmp = compute_argument + for _ in range(COMPUTE_INTENSITY): + tmp = jax.lax.dot_general(weights, tmp, (((2,), (1,)), ((0,), (0,)))) + compute_result = tmp.reshape( + (1, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE) + ) + + buffer_slice_for_bwd_ppermute = jax.lax.dynamic_slice( + output_buffer, + (0, (i + 1) % NUM_MICROBATCHES, 0, 0), + (1, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE), + ) + + # make sure ppermute is scheduled after send_fwd + buffer_slice_for_bwd_ppermute_after_send_fwd, _ = ( + jax.lax.optimization_barrier( + (buffer_slice_for_bwd_ppermute, fwd_send_token) + ) + ) + # ppermute_bwd + ppermute_bwd_data = jax.lax.ppermute( + buffer_slice_for_bwd_ppermute_after_send_fwd, + axis_name="x", + perm=[(3, 0)], + ) + + # make sure recv is scheduled after ppermute + precv_token, _ = jax.lax.optimization_barrier( + (jax.lax.create_token(), ppermute_bwd_data) + ) + + # recv_fwd, matches the send_fwd in the next iteration + fwd_recv_data = jax.lax.precv( + precv_token, + out_shape=jax.ShapeDtypeStruct( + input_slice.shape, input_slice.dtype + ), + axis_name="x", + perm=[(0, 1), (1, 2), (2, 3)], + ) + update_output_buffer = jax.lax.dynamic_update_slice( + output_buffer, + compute_result, + (0, (i + 2) % NUM_MICROBATCHES, 0, 0), + ) + carry = ( + weights, + input_buffer, + update_output_buffer, + compute_result, + fwd_recv_data, + ppermute_bwd_data, + ) + return carry, i + + +def entry_computation( + weights, input_buffer, dummy_data, mesh +): + + # Init output buffer. + output_buffer = jnp.zeros_like(input_buffer) + + # Start pipeline. + dummy_slice_fwd = jax.lax.precv( + jax.lax.create_token(), + jax.ShapeDtypeStruct(dummy_data.shape, dummy_data.dtype), + axis_name="x", + perm=[(0, 1), (1, 2), (2, 3)], + ) + + carry = ( + weights, + input_buffer, + output_buffer, + dummy_slice_fwd, + dummy_data, + dummy_data, + ) + + num_iterations = NUM_CIRC_REPEATS * NUM_MICROBATCHES + NUM_DEVICES - 1 + carry, _ = jax.lax.scan(while_body, carry, xs=jnp.arange(num_iterations)) + + _ = jax.lax.psend( + carry[3], + axis_name="x", + perm=[(0, 1), (1, 2), (2, 3)], + ) + + _, _, output_buffer, _, _, _ = carry + + return output_buffer + + +def main(_): + + # Expect constant number of devices. + assert NUM_DEVICES == jax.local_device_count() + + # Create mesh. + mesh = Mesh( + mesh_utils.create_device_mesh([NUM_DEVICES]), + axis_names=["x"], + ) + # Init weights. + weights = 1.0 / CONTRACTING_DIM_SIZE + weights = jax.lax.broadcast_in_dim( + weights, + shape=(NUM_DEVICES, CONTRACTING_DIM_SIZE, CONTRACTING_DIM_SIZE), + broadcast_dimensions=(), + ) + weights = jax.device_put( + weights, NamedSharding(mesh, P("x")) + ) + # Init input. + random_key = jax.random.key(0) + input_buffer = jax.random.uniform( + random_key, + shape=( + NUM_MICROBATCHES, + CONTRACTING_DIM_SIZE, + NON_CONTRACTING_DIM_SIZE, + ), + ) + input_buffer = jax.lax.broadcast_in_dim( + input_buffer, + shape=( + NUM_DEVICES, + NUM_MICROBATCHES, + CONTRACTING_DIM_SIZE, + NON_CONTRACTING_DIM_SIZE, + ), + broadcast_dimensions=[1, 2, 3], + ) + + input_buffer = jax.device_put( + input_buffer, + NamedSharding(mesh, P("x")), + ) + # Init dummy data for forward and backward edge passed through the while + # loop. + dummy_slice = jnp.zeros( + shape=(NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE) + ).astype(jnp.float32) + dummy_data = jax.device_put( + dummy_slice, + NamedSharding(mesh, P("x")), + ) + + entry = partial(entry_computation, mesh=mesh) + + output_buffer = jax.jit( + jax.shard_map( + entry, + mesh=mesh, + in_specs=P("x"), + out_specs=P("x"), + check_vma=False, + ) + )(weights, input_buffer, dummy_data) + print(f"output_buffer = \n{output_buffer}") +``` + ## NCCL flags These Nvidia NCCL flag values may be useful for single-host multi-device diff --git a/docs/gradient-checkpointing.md b/docs/gradient-checkpointing.md index 0938a5da944f..1b6463f65024 100644 --- a/docs/gradient-checkpointing.md +++ b/docs/gradient-checkpointing.md @@ -19,7 +19,7 @@ kernelspec: In this tutorial, you will learn how to control JAX automatic differentiation's saved values using {func}`jax.checkpoint` (also known as {func}`jax.remat`), which can be particularly helpful in machine learning. -If you are new to automatic differentiation (autodiff) or need to refresh your memory, JAX has {ref}`automatic-differentiation` and {ref}`advanced-autodiff` tutorials. +If you are new to automatic differentiation (autodiff) or need to refresh your memory, JAX has an {ref}`automatic-differentiation` tutorial and several {ref}`Advanced automatic differentiation guides `. **TL;DR** Use the {func}`jax.checkpoint` decorator (aliased as {func}`jax.remat`) with {func}`jax.grad` to control which intermediates are saved on the forward pass versus the recomputed intermediates on the backward pass, trading off memory and FLOPs. @@ -49,7 +49,7 @@ x = jnp.ones(4) # Inspect the 'residual' values to be saved on the forward pass # if you were to evaluate `jax.grad(f)(W1, W2, W3, x)` from jax.ad_checkpoint import print_saved_residuals -jax.ad_checkpoint.print_saved_residuals(f, W1, W2, W3, x) +print_saved_residuals(f, W1, W2, W3, x) ``` By applying {func}`jax.checkpoint` to sub-functions, as a decorator or at specific application sites, you force JAX not to save any of that sub-function's residuals. Instead, only the inputs of a {func}`jax.checkpoint`-decorated function might be saved, and any residuals consumed on the backward pass are re-computed from those inputs as needed: @@ -61,7 +61,7 @@ def f2(W1, W2, W3, x): x = jax.checkpoint(g)(W3, x) return x -jax.ad_checkpoint.print_saved_residuals(f2, W1, W2, W3, x) +print_saved_residuals(f2, W1, W2, W3, x) ``` Here, the values of two `sin` applications are saved because they are arguments @@ -73,7 +73,7 @@ To control which values are saveable without having to edit the definition of th ```{code-cell} f3 = jax.checkpoint(f, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable) -jax.ad_checkpoint.print_saved_residuals(f3, W1, W2, W3, x) +print_saved_residuals(f3, W1, W2, W3, x) ``` You can also use policies to refer to intermediate values you name using {func}`jax.ad_checkpoint.checkpoint_name`: @@ -88,7 +88,7 @@ def f4(W1, W2, W3, x): return x f4 = jax.checkpoint(f4, policy=jax.checkpoint_policies.save_only_these_names('a')) -jax.ad_checkpoint.print_saved_residuals(f4, W1, W2, W3, x) +print_saved_residuals(f4, W1, W2, W3, x) ``` When playing around with these toy examples, you can get a closer look at what's going on using a custom `print_fwd_bwd` utility defined in this notebook: @@ -144,7 +144,7 @@ print_fwd_bwd(f3, W1, W2, W3, x) ### Let's think step by step -**Note:** It may help to check out the {ref}`advanced-autodiff` tutorial prior to continuing here. +**Note:** It may help to check out the {ref}`"Advanced automatic differentiation" guides ` prior to continuing here. #### `jax.checkpoint` fundamentals @@ -205,7 +205,7 @@ Using words, this alternative implementation doesn't compute `g_vjp`, or the res The cost you pay is redundant work: in `f_bwd2` you must re-evaluate `g(x)` as part of `jax.vjp(g, x)` just to discard its value (in the underscore variable on the line `_, g_vjp = jax.vjp(g, x)`). -You can get this VJP behavior in autodiff — without having to write VJP functions directly — by instead using {func}`jax.checkpoint` in an alternative definition of the original function `f`: +You can get this VJP behavior in autodiff --- without having to write VJP functions directly --- by instead using {func}`jax.checkpoint` in an alternative definition of the original function `f`: ```{code-cell} def f_checkpoint(x): @@ -288,7 +288,7 @@ As shown so far, using {func}`jax.checkpoint` switches from one extreme to anoth To operate between these two extremes, saving some things and not others, you can carefully place {func}`jax.checkpoint` decorators on sub-functions. But that requires editing the function to be differentiated, e.g. model code, which may be inconvenient. It can also be hard to experiment with variations. -So an alternative is to use the `policy` argument to {func}`jax.checkpoint`. A policy is a callable (i.e. a function) which takes as input a type-level specification of a first order primitive application and returns a boolean indicating whether the corresponding output value(s) are allowed to be saved as residuals (or instead must be recomputed in the (co)tangent computation as needed). To write robust code, a policy should be selected from the attributes on {func}`jax.checkpoint_policies`, like {func}`jax.checkpoint_policies.dots_with_no_batch_dims_saveable`, since the API for writing custom policy callables is considered internal. +So an alternative is to use the `policy` argument to {func}`jax.checkpoint`. A policy is a callable (i.e. a function) which takes as input a type-level specification of a first order primitive application and returns a boolean indicating whether the corresponding output value(s) are allowed to be saved as residuals (or instead must be recomputed in the (co)tangent computation as needed). To write robust code, a policy should be selected from the attributes on {obj}`jax.checkpoint_policies`, like {func}`jax.checkpoint_policies.dots_with_no_batch_dims_saveable`, since the API for writing custom policy callables is considered internal. For example, consider this function to be differentiated: @@ -341,7 +341,7 @@ def predict(params, x): return x ``` -By itself, {func}`jax.ad_checkpoint import.checkpoint_name` is just an identity function. But because some policy functions know to look for them, you can use the names to control whether certain values output by {func}`jax.ad_checkpoint import.checkpoint_name` are considered saveable: +By itself, {func}`jax.ad_checkpoint.checkpoint_name` is just an identity function. But because some policy functions know to look for them, you can use the names to control whether certain values output by {func}`jax.ad_checkpoint.checkpoint_name` are considered saveable: ```{code-cell} print_saved_residuals(loss, params, x, y) @@ -359,7 +359,7 @@ Another policy which refers to names is `jax.checkpoint_policies.save_only_these You may consider offloading to CPU memory instead of recomputing when checkpointing to save accelerator memory. `jax.checkpoint_policies.offload_dot_with_no_batch_dims` can offload the results of matrix multiplications with no batch dimensions to the CPU. ```{code-cell} -from jax.ad_checkpoint import checkpoint +from jax import checkpoint def checkpoint_offload_dot_with_no_batch_dims(self): policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims( @@ -380,7 +380,8 @@ def checkpoint_offload_dot_with_no_batch_dims(self): One of JAX's checkpoint policies allows specified checkpoint names to be offloaded to CPUs. This policy is implemented through `jax.checkpoint_policies.save_and_offload_only_these_names`, which has four arguments: `names_which_can_be_saved`, `names_which_can_be_offloaded`, the offloading source, and destination. Names listed in `names_which_can_be_saved` are kept on the device, names listed in `names_which_can_be_offloaded` are moved to CPU memory, and other names or operations without names are recomputed. For example, if we have checkpoint names `y`, `z`, and `w`, `y` can be saved on the device, `z` can be offloaded to CPU memory, and `w` can be recomputed. ```{code-cell} -from jax.ad_checkpoint import checkpoint, checkpoint_name +from jax import checkpoint +from jax.ad_checkpoint import checkpoint_name from jax._src import test_util as jtu def checkpoint_names_saved_offloaded_recomputed(self): @@ -411,22 +412,7 @@ The code defines a function `f` that which applies checkpointing with a custom p #### List of policies -The policies are: -* `everything_saveable` (the default strategy, as if `jax.checkpoint` were not being used at all) -* `nothing_saveable` (i.e. rematerialize everything, as if a custom policy were not being used at all) -* `dots_saveable` or its alias `checkpoint_dots` -* `dots_with_no_batch_dims_saveable` or its alias `checkpoint_dots_with_no_batch_dims` -* `save_anything_but_these_names` (save any values except for the output of - `checkpoint_name` with any of the names given) -* `save_any_names_but_these` (save only named values, i.e. any outputs of - `checkpoint_name`, except for those with the names given) -* `save_only_these_names` (save only named values, and only among the names - given) -* `offload_dot_with_no_batch_dims` same as `dots_with_no_batch_dims_saveable`, - but offload to CPU memory instead of recomputing. -* `save_and_offload_only_these_names` same as `save_only_these_names`, but - offload to CPU memory instead of recomputing. -* `save_from_both_policies(policy_1, policy_2)` (like a logical `or`, so that a residual is saveable if it is saveable according to `policy_1` _or_ `policy_2`) +The policies can be found [here](https://docs.jax.dev/en/latest/jax.html#checkpoint-policies). Policies only indicate what is saveable; a value is only saved if it's actually needed by the backward pass. @@ -515,7 +501,7 @@ def net(params: ParamsList, x: jnp.ndarray): Instead, iterate over the layer application with {func}`jax.lax.scan`: ```{code-cell} -params = [(jnp.array([[0.5, 0.5], [1., 1.]]), jnp.array([0.5, 0.5])), +params = [(jnp.array([[0.5, 0.5], [1., 1.]]), jnp.array([0.5, 0.5])), (jnp.array([[0.5, 0.5], [1., 1.]]), jnp.array([0.5, 0.5]))] all_weights = jnp.stack([W for W, _ in params]) diff --git a/docs/higher-order.md b/docs/higher-order.md new file mode 100644 index 000000000000..e835d3af82cc --- /dev/null +++ b/docs/higher-order.md @@ -0,0 +1,336 @@ +--- +jupytext: + formats: md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.16.4 +kernelspec: + display_name: Python 3 + language: python + name: python3 +--- + +# Higher-order derivatives + +## Taking gradients (part 2) + +JAX's autodiff makes it easy to compute higher-order derivatives, because the functions that compute derivatives are themselves differentiable. Thus, higher-order derivatives are as easy as stacking transformations. + +The single-variable case was covered in the {ref}`automatic-differentiation` tutorial, where the example showed how to use {func}`jax.grad` to compute the derivative of $f(x) = x^3 + 2x^2 - 3x + 1$. + +In the multivariable case, higher-order derivatives are more complicated. The second-order derivative of a function is represented by its [Hessian matrix](https://en.wikipedia.org/wiki/Hessian_matrix), defined according to: + +$$(\mathbf{H}f)_{i,j} = \frac{\partial^2 f}{\partial_i\partial_j}.$$ + +The Hessian of a real-valued function of several variables, $f: \mathbb R^n\to\mathbb R$, can be identified with the [Jacobian](https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant) of its gradient. + +JAX provides two transformations for computing the Jacobian of a function, {func}`jax.jacfwd` and {func}`jax.jacrev`, corresponding to forward- and reverse-mode autodiff. They give the same answer, but one can be more efficient than the other in different circumstances – refer to the [video about autodiff](https://www.youtube.com/watch?v=wG_nF1awSSY). + +```{code-cell} +import jax + +def hessian(f): + return jax.jacfwd(jax.grad(f)) +``` + +Let's double check this is correct on the dot-product $f: \mathbf{x} \mapsto \mathbf{x} ^\top \mathbf{x}$. + +if $i=j$, $\frac{\partial^2 f}{\partial_i\partial_j}(\mathbf{x}) = 2$. Otherwise, $\frac{\partial^2 f}{\partial_i\partial_j}(\mathbf{x}) = 0$. + +```{code-cell} +import jax.numpy as jnp + +def f(x): + return jnp.dot(x, x) + +hessian(f)(jnp.array([1., 2., 3.])) +``` + +## Higher-order derivative applications + +Some meta-learning techniques, such as Model-Agnostic Meta-Learning ([MAML](https://arxiv.org/abs/1703.03400)), require differentiating through gradient updates. In other frameworks this can be quite cumbersome, but in JAX it's much easier: + +```python +def meta_loss_fn(params, data): + """Computes the loss after one step of SGD.""" + grads = jax.grad(loss_fn)(params, data) + return loss_fn(params - lr * grads, data) + +meta_grads = jax.grad(meta_loss_fn)(params, data) +``` + +(stopping-gradients)= +### Stopping gradients + +Autodiff enables automatic computation of the gradient of a function with respect to its inputs. Sometimes, however, you might want some additional control: for instance, you might want to avoid backpropagating gradients through some subset of the computational graph. + +Consider for instance the TD(0) ([temporal difference](https://en.wikipedia.org/wiki/Temporal_difference_learning)) reinforcement learning update. This is used to learn to estimate the *value* of a state in an environment from experience of interacting with the environment. Let's assume the value estimate $v_{\theta}(s_{t-1}$) in a state $s_{t-1}$ is parameterised by a linear function. + +```{code-cell} +# Value function and initial parameters +value_fn = lambda theta, state: jnp.dot(theta, state) +theta = jnp.array([0.1, -0.1, 0.]) +``` + +Consider a transition from a state $s_{t-1}$ to a state $s_t$ during which you observed the reward $r_t$ + +```{code-cell} +# An example transition. +s_tm1 = jnp.array([1., 2., -1.]) +r_t = jnp.array(1.) +s_t = jnp.array([2., 1., 0.]) +``` + +The TD(0) update to the network parameters is: + +$$ +\Delta \theta = (r_t + v_{\theta}(s_t) - v_{\theta}(s_{t-1})) \nabla v_{\theta}(s_{t-1}) +$$ + +This update is not the gradient of any loss function. + +However, it can be **written** as the gradient of the pseudo loss function + +$$ +L(\theta) = - \frac{1}{2} [r_t + v_{\theta}(s_t) - v_{\theta}(s_{t-1})]^2 +$$ + +if the dependency of the target $r_t + v_{\theta}(s_t)$ on the parameter $\theta$ is ignored. + +How can you implement this in JAX? If you write the pseudo loss naively, you get: + +```{code-cell} +def td_loss(theta, s_tm1, r_t, s_t): + v_tm1 = value_fn(theta, s_tm1) + target = r_t + value_fn(theta, s_t) + return -0.5 * ((target - v_tm1) ** 2) + +td_update = jax.grad(td_loss) +delta_theta = td_update(theta, s_tm1, r_t, s_t) + +delta_theta +``` + +But `td_update` will **not** compute a TD(0) update, because the gradient computation will include the dependency of `target` on $\theta$. + +You can use {func}`jax.lax.stop_gradient` to force JAX to ignore the dependency of the target on $\theta$: + +```{code-cell} +def td_loss(theta, s_tm1, r_t, s_t): + v_tm1 = value_fn(theta, s_tm1) + target = r_t + value_fn(theta, s_t) + return -0.5 * ((jax.lax.stop_gradient(target) - v_tm1) ** 2) + +td_update = jax.grad(td_loss) +delta_theta = td_update(theta, s_tm1, r_t, s_t) + +delta_theta +``` + +This will treat `target` as if it did **not** depend on the parameters $\theta$ and compute the correct update to the parameters. + +Now, let's also calculate $\Delta \theta$ using the original TD(0) update expression, to cross-check our work. You may wish to try and implement this yourself using {func}`jax.grad` and your knowledge so far. Here's our solution: + +```{code-cell} +s_grad = jax.grad(value_fn)(theta, s_tm1) +delta_theta_original_calculation = (r_t + value_fn(theta, s_t) - value_fn(theta, s_tm1)) * s_grad + +delta_theta_original_calculation # [1.2, 2.4, -1.2], same as `delta_theta` +``` + +`jax.lax.stop_gradient` may also be useful in other settings, for instance if you want the gradient from some loss to only affect a subset of the parameters of the neural network (because, for instance, the other parameters are trained using a different loss). + +### Straight-through estimator using `stop_gradient` + +The straight-through estimator is a trick for defining a 'gradient' of a function that is otherwise non-differentiable. Given a non-differentiable function $f : \mathbb{R}^n \to \mathbb{R}^n$ that is used as part of a larger function that we wish to find a gradient of, we simply pretend during the backward pass that $f$ is the identity function. This can be implemented neatly using `jax.lax.stop_gradient`: + +```{code-cell} +def f(x): + return jnp.round(x) # non-differentiable + +def straight_through_f(x): + # Create an exactly-zero expression with Sterbenz lemma that has + # an exactly-one gradient. + zero = x - jax.lax.stop_gradient(x) + return zero + jax.lax.stop_gradient(f(x)) + +print("f(x): ", f(3.2)) +print("straight_through_f(x):", straight_through_f(3.2)) + +print("grad(f)(x):", jax.grad(f)(3.2)) +print("grad(straight_through_f)(x):", jax.grad(straight_through_f)(3.2)) +``` + +### Per-example gradients + +While most ML systems compute gradients and updates from batches of data, for reasons of computational efficiency and/or variance reduction, it is sometimes necessary to have access to the gradient/update associated with each specific sample in the batch. + +For instance, this is needed to prioritize data based on gradient magnitude, or to apply clipping / normalisations on a sample by sample basis. + +In many frameworks (PyTorch, TF, Theano) it is often not trivial to compute per-example gradients, because the library directly accumulates the gradient over the batch. Naive workarounds, such as computing a separate loss per example and then aggregating the resulting gradients are typically very inefficient. + +In JAX, you can define the code to compute the gradient per-sample in an easy but efficient way. + +Just combine the {func}`jax.jit`, {func}`jax.vmap` and {func}`jax.grad` transformations together: + +```{code-cell} +perex_grads = jax.jit(jax.vmap(jax.grad(td_loss), in_axes=(None, 0, 0, 0))) + +# Test it: +batched_s_tm1 = jnp.stack([s_tm1, s_tm1]) +batched_r_t = jnp.stack([r_t, r_t]) +batched_s_t = jnp.stack([s_t, s_t]) + +perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t) +``` + +Let's go through this one transformation at a time. + +First, you apply {func}`jax.grad` to `td_loss` to obtain a function that computes the gradient of the loss w.r.t. the parameters on single (unbatched) inputs: + +```{code-cell} +dtdloss_dtheta = jax.grad(td_loss) + +dtdloss_dtheta(theta, s_tm1, r_t, s_t) +``` + +This function computes one row of the array above. + +Then, you vectorise this function using {func}`jax.vmap`. This adds a batch dimension to all inputs and outputs. Now, given a batch of inputs, you produce a batch of outputs — each output in the batch corresponds to the gradient for the corresponding member of the input batch. + +```{code-cell} +almost_perex_grads = jax.vmap(dtdloss_dtheta) + +batched_theta = jnp.stack([theta, theta]) +almost_perex_grads(batched_theta, batched_s_tm1, batched_r_t, batched_s_t) +``` + +This isn't quite what we want, because we have to manually feed this function a batch of `theta`s, whereas we actually want to use a single `theta`. We fix this by adding `in_axes` to the {func}`jax.vmap`, specifying theta as `None`, and the other args as `0`. This makes the resulting function add an extra axis only to the other arguments, leaving `theta` unbatched, as we want: + +```{code-cell} +inefficient_perex_grads = jax.vmap(dtdloss_dtheta, in_axes=(None, 0, 0, 0)) + +inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t) +``` + +This does what we want, but is slower than it has to be. Now, you wrap the whole thing in a {func}`jax.jit` to get the compiled, efficient version of the same function: + +```{code-cell} +perex_grads = jax.jit(inefficient_perex_grads) + +perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t) +``` + +```{code-cell} +%timeit inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready() +%timeit perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready() +``` + +### Hessian-vector products with `jax.grad`-of-`jax.grad` + +One thing you can do with higher-order {func}`jax.grad` is build a Hessian-vector product function. (Later on you'll write an even more efficient implementation that mixes both forward- and reverse-mode, but this one will use pure reverse-mode.) + +A Hessian-vector product function can be useful in a [truncated Newton Conjugate-Gradient algorithm](https://en.wikipedia.org/wiki/Truncated_Newton_method) for minimizing smooth convex functions, or for studying the curvature of neural network training objectives (e.g. [1](https://arxiv.org/abs/1406.2572), [2](https://arxiv.org/abs/1811.07062), [3](https://arxiv.org/abs/1706.04454), [4](https://arxiv.org/abs/1802.03451)). + +For a scalar-valued function $f : \mathbb{R}^n \to \mathbb{R}$ with continuous second derivatives (so that the Hessian matrix is symmetric), the Hessian at a point $x \in \mathbb{R}^n$ is written as $\partial^2 f(x)$. A Hessian-vector product function is then able to evaluate + +$\qquad v \mapsto \partial^2 f(x) \cdot v$ + +for any $v \in \mathbb{R}^n$. + +The trick is not to instantiate the full Hessian matrix: if $n$ is large, perhaps in the millions or billions in the context of neural networks, then that might be impossible to store. + +Luckily, {func}`jax.grad` already gives us a way to write an efficient Hessian-vector product function. You just have to use the identity: + +$\qquad \partial^2 f (x) v = \partial [x \mapsto \partial f(x) \cdot v] = \partial g(x)$, + +where $g(x) = \partial f(x) \cdot v$ is a new scalar-valued function that dots the gradient of $f$ at $x$ with the vector $v$. Notice that you're only ever differentiating scalar-valued functions of vector-valued arguments, which is exactly where you know {func}`jax.grad` is efficient. + +In JAX code, you can just write this: + +```{code-cell} +def hvp(f, x, v): + return jax.grad(lambda x: jnp.vdot(jax.grad(f)(x), v))(x) +``` + +This example shows that you can freely use lexical closure, and JAX will never get perturbed or confused. + +You will check this implementation a few cells down, once you learn how to compute dense Hessian matrices. You'll also write an even better version that uses both forward-mode and reverse-mode. + +### Jacobians and Hessians using `jax.jacfwd` and `jax.jacrev` + +You can compute full Jacobian matrices using the {func}`jax.jacfwd` and {func}`jax.jacrev` functions: + +```{code-cell} +from jax import jacfwd, jacrev + +# Define a sigmoid function. +def sigmoid(x): + return 0.5 * (jnp.tanh(x / 2) + 1) + +# Outputs probability of a label being true. +def predict(W, b, inputs): + return sigmoid(jnp.dot(inputs, W) + b) + +# Build a toy dataset. +inputs = jnp.array([[0.52, 1.12, 0.77], + [0.88, -1.08, 0.15], + [0.52, 0.06, -1.30], + [0.74, -2.49, 1.39]]) + +# Initialize random model coefficients +key = jax.random.key(0) +key, W_key, b_key = jax.random.split(key, 3) +W = jax.random.normal(W_key, (3,)) +b = jax.random.normal(b_key, ()) + +# Isolate the function from the weight matrix to the predictions +f = lambda W: predict(W, b, inputs) + +J = jacfwd(f)(W) +print("jacfwd result, with shape", J.shape) +print(J) + +J = jacrev(f)(W) +print("jacrev result, with shape", J.shape) +print(J) +``` + +These two functions compute the same values (up to machine numerics), but differ in their implementation: {func}`jax.jacfwd` uses forward-mode automatic differentiation, which is more efficient for "tall" Jacobian matrices (more outputs than inputs), while {func}`jax.jacrev` uses reverse-mode, which is more efficient for "wide" Jacobian matrices (more inputs than outputs). For matrices that are near-square, {func}`jax.jacfwd` probably has an edge over {func}`jax.jacrev`. + +You can also use {func}`jax.jacfwd` and {func}`jax.jacrev` with container types: + +```{code-cell} +def predict_dict(params, inputs): + return predict(params['W'], params['b'], inputs) + +J_dict = jax.jacrev(predict_dict)({'W': W, 'b': b}, inputs) +for k, v in J_dict.items(): + print("Jacobian from {} to logits is".format(k)) + print(v) +``` + +For more details on forward- and reverse-mode, as well as how to implement {func}`jax.jacfwd` and {func}`jax.jacrev` as efficiently as possible, read on! + +Using a composition of two of these functions gives us a way to compute dense Hessian matrices: + +```{code-cell} +def hessian(f): + return jax.jacfwd(jax.jacrev(f)) + +H = hessian(f)(W) +print("hessian, with shape", H.shape) +print(H) +``` + +This shape makes sense: if you start with a function $f : \mathbb{R}^n \to \mathbb{R}^m$, then at a point $x \in \mathbb{R}^n$ you expect to get the shapes: + +* $f(x) \in \mathbb{R}^m$, the value of $f$ at $x$, +* $\partial f(x) \in \mathbb{R}^{m \times n}$, the Jacobian matrix at $x$, +* $\partial^2 f(x) \in \mathbb{R}^{m \times n \times n}$, the Hessian at $x$, + +and so on. + +To implement `hessian`, you could have used `jacfwd(jacrev(f))` or `jacrev(jacfwd(f))` or any other composition of these two. But forward-over-reverse is typically the most efficient. That's because in the inner Jacobian computation we're often differentiating a function wide Jacobian (maybe like a loss function $f : \mathbb{R}^n \to \mathbb{R}$), while in the outer Jacobian computation we're differentiating a function with a square Jacobian (since $\nabla f : \mathbb{R}^n \to \mathbb{R}^n$), which is where forward-mode wins out. diff --git a/docs/index.rst b/docs/index.rst index ba8ebcbdd128..9200a4b26a14 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,17 +1,6 @@ JAX: High performance array computing ===================================== -.. raw:: html - - - - .. raw:: html :file: hero.html @@ -57,14 +46,14 @@ JAX: High performance array computing :link-type: ref :class-card: getting-started - .. grid-item-card:: :material-regular:`library_books;2em` User guides + .. grid-item-card:: :material-regular:`library_books;2em` JAX 101 :columns: 12 6 6 4 - :link: user-guides + :link: jax-101 :link-type: ref - :class-card: user-guides + :class-card: jax-101 -If you're looking to train neural networks, use Flax_ and start with its tutorials. -For an end-to-end transformer library built on JAX, see MaxText_. +If you're looking to use JAX to train neural networks, check out the `JAX AI +Stack`_! Ecosystem --------- @@ -107,7 +96,7 @@ numerical computing tools; the following is just a small sample of what is out t .. grid-item:: :material-regular:`bar_chart;2em` **Probabilistic modeling** - - `TensorFlow Probabilty`_ + - `TensorFlow Probability`_ - Distrax_ .. grid-item:: :material-outlined:`animation;2em` **Physics & simulation** @@ -121,6 +110,7 @@ numerical computing tools; the following is just a small sample of what is out t - AXLearn_ - Levanter_ - EasyLM_ + - Marin_ Many more JAX-based libraries have been developed; the community-run `Awesome JAX`_ page @@ -132,29 +122,27 @@ maintains an up-to-date list. :caption: Getting started installation - quickstart + notebooks/thinking_in_jax .. toctree:: :hidden: :maxdepth: 1 - tutorials - notebooks/Common_Gotchas_in_JAX - - faq + jax-101 .. toctree:: :hidden: :maxdepth: 2 - :caption: More guides/resources + :caption: Resources, guides, and references - user_guides - advanced_guide + key-concepts + advanced_guides + jax contributor_guide extensions notes - jax + pallas/index about @@ -162,9 +150,15 @@ maintains an up-to-date list. :hidden: :maxdepth: 1 + faq changelog glossary +.. toctree:: + :hidden: + :maxdepth: 2 + + config_options .. _Awesome JAX: https://github.com/n2cholas/awesome-jax .. _AXLearn: https://github.com/apple/axlearn @@ -179,8 +173,10 @@ maintains an up-to-date list. .. _Grain: https://github.com/google/grain .. _Hugging Face Datasets: https://huggingface.co/docs/datasets/ .. _JAX MD: https://jax-md.readthedocs.io/ +.. _JAX AI Stack: https://docs.jaxstack.ai/en/latest/getting_started.html .. _Keras: https://keras.io/ .. _Levanter: https://github.com/stanford-crfm/levanter +.. _Marin: https://github.com/marin-community/marin .. _Lineax: https://github.com/patrick-kidger/lineax .. _MaxText: https://github.com/google/maxtext/ .. _Numpyro: https://num.pyro.ai/en/latest/index.html @@ -189,4 +185,4 @@ maintains an up-to-date list. .. _Orbax: https://orbax.readthedocs.io/ .. _PyMC: https://www.pymc.io/ .. _TensorFlow Datasets: https://www.tensorflow.org/datasets -.. _TensorFlow Probabilty: https://www.tensorflow.org/probability +.. _TensorFlow Probability: https://www.tensorflow.org/probability diff --git a/docs/installation.md b/docs/installation.md index ee675dd1e586..4564cc8b261a 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -13,9 +13,9 @@ different builds for different operating systems and accelerators. ``` pip install -U jax ``` -* **GPU (NVIDIA, CUDA 12)** +* **GPU (NVIDIA, CUDA 13)** ``` - pip install -U "jax[cuda12]" + pip install -U "jax[cuda13]" ``` * **TPU (Google Cloud TPU VM)** @@ -28,14 +28,14 @@ different builds for different operating systems and accelerators. The table below shows all supported platforms and installation options. Check if your setup is supported; and if it says _"yes"_ or _"experimental"_, then click on the corresponding link to learn how to install JAX in greater detail. -| | Linux, x86_64 | Linux, aarch64 | Mac, x86_64 | Mac, aarch64 | Windows, x86_64 | Windows WSL2, x86_64 | -|------------------|---------------------------------------|---------------------------------|---------------------------------------|---------------------------------------|--------------------------|------------------------------------------| -| CPU | {ref}`yes ` | {ref}`yes ` | {ref}`jax≤0.4.38 only ` | {ref}`yes ` | {ref}`yes ` | {ref}`yes ` | -| NVIDIA GPU | {ref}`yes ` | {ref}`yes ` | no | n/a | no | {ref}`experimental ` | -| Google Cloud TPU | {ref}`yes ` | n/a | n/a | n/a | n/a | n/a | -| AMD GPU | {ref}`experimental ` | no | {ref}`experimental ` | n/a | no | no | -| Apple GPU | n/a | no | n/a | {ref}`experimental ` | n/a | n/a | -| Intel GPU | {ref}`experimental `| n/a | n/a | n/a | no | no | +| | Linux, x86_64 | Linux, aarch64 | Mac, aarch64 | Windows, x86_64 | Windows WSL2, x86_64 | +|------------------|---------------------------------------|---------------------------------|---------------------------------------|--------------------------|------------------------------------------| +| CPU | {ref}`yes ` | {ref}`yes ` | {ref}`yes ` | {ref}`yes ` | {ref}`yes ` | +| NVIDIA GPU | {ref}`yes ` | {ref}`yes ` | n/a | no | {ref}`experimental ` | +| Google Cloud TPU | {ref}`yes ` | n/a | n/a | n/a | n/a | +| AMD GPU | {ref}`yes ` | no | n/a | no | {ref}`experimental ` | +| Apple GPU | n/a | no | {ref}`experimental ` | n/a | n/a | +| Intel GPU | {ref}`experimental `| n/a | n/a | no | no | (install-cpu)= @@ -48,7 +48,6 @@ operating systems and architectures: - Linux, x86_64 - Linux, aarch64 -- macOS, Intel - macOS, Apple ARM-based - Windows, x86_64 (*experimental*) @@ -73,13 +72,15 @@ not being installed alongside `jax`, although `jax` may successfully install (install-nvidia-gpu)= ## NVIDIA GPU -JAX supports NVIDIA GPUs that have SM version 5.2 (Maxwell) or newer. +On CUDA 12, JAX supports NVIDIA GPUs that have SM version 5.2 (Maxwell) or newer. Note that Kepler-series GPUs are no longer supported by JAX since NVIDIA has dropped support for Kepler GPUs in its software. +On CUDA 13, JAX supports NVIDIA GPUs that have SM version 7.5 or newer. NVIDIA +dropped support for previous GPUs in CUDA 13. You must first install the NVIDIA driver. You're recommended to install the newest driver available from NVIDIA, but the driver -version must be >= 525.60.13 for CUDA 12 on Linux. +version must be >= 525 for CUDA 12 on Linux, and >= 580 for CUDA 13 on Linux. If you need to use a newer CUDA toolkit with an older driver, for example on a cluster where you cannot update the NVIDIA driver easily, you may be @@ -97,17 +98,22 @@ There are two ways to install JAX with NVIDIA GPU support: The JAX team strongly recommends installing CUDA and cuDNN using the pip wheels, since it is much easier! -NVIDIA has released CUDA pip packages only for x86_64 and aarch64; on other -platforms you must use a local installation of CUDA. +NVIDIA has released CUDA packages only for x86_64 and aarch64. ```bash pip install --upgrade pip -# NVIDIA CUDA 12 installation +# NVIDIA CUDA 13 installation # Note: wheels only available on linux. -pip install --upgrade "jax[cuda12]" +pip install --upgrade "jax[cuda13]" + +# Alternatively, for CUDA 12, use +# pip install --upgrade "jax[cuda12]" ``` +We recommend migrating to the CUDA 13 wheels; at some point in the future we +will drop CUDA 12 support. + If JAX detects the wrong version of the NVIDIA CUDA libraries, there are several things you need to check: @@ -134,14 +140,27 @@ able to use the [CUDA forward compatibility packages](https://docs.nvidia.com/deploy/cuda-compatibility/) that NVIDIA provides for this purpose. -JAX currently ships one CUDA wheel variant: +JAX currently ships two CUDA wheel variants: CUDA 12 and CUDA 13: + + +The CUDA 12 wheel is: | Built with | Compatible with | |------------|--------------------| | CUDA 12.3 | CUDA >=12.1 | -| CUDNN 9.1 | CUDNN >=9.1, <10.0 | +| CUDNN 9.8 | CUDNN >=9.8, <10.0 | | NCCL 2.19 | NCCL >=2.18 | + +The CUDA 13 wheel is: + +| Built with | Compatible with | +|------------|--------------------| +| CUDA 13.0 | CUDA >=13.0 | +| CUDNN 9.8 | CUDNN >=9.8, <10.0 | +| NCCL 2.19 | NCCL >=2.18 | + + JAX checks the versions of your libraries, and will report an error if they are not sufficiently new. Setting the `JAX_SKIP_CUDA_CONSTRAINTS_CHECK` environment variable will disable @@ -156,9 +175,14 @@ To install, run: ```bash pip install --upgrade pip -# Installs the wheel compatible with NVIDIA CUDA 12 and cuDNN 9.0 or newer. + +# Installs the wheel compatible with NVIDIA CUDA 13 and cuDNN 9.8 or newer. +# Note: wheels only available on linux. +pip install --upgrade "jax[cuda13-local]" + +# Installs the wheel compatible with NVIDIA CUDA 12 and cuDNN 9.8 or newer. # Note: wheels only available on linux. -pip install --upgrade "jax[cuda12_local]" +# pip install --upgrade "jax[cuda12-local]" ``` **These `pip` installations do not work with Windows, and may fail silently; refer to the table @@ -226,10 +250,15 @@ refer to (install-amd-gpu)= ## AMD GPU (Linux) -JAX has experimental ROCm support. There are two ways to install JAX: +AMD GPU support is provided by a ROCm JAX plugin supported by AMD. -* Use [AMD's Docker container](https://hub.docker.com/r/rocm/jax-community/tags); or -* Build from source. Refer to the section [Additional notes for building a ROCm jaxlib for AMD GPUs](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). +There are several ways to use JAX on AMDGPU devices. +Please see [AMD's instructions](https://github.com/jax-ml/jax/blob/main/build/rocm/README.md) for details. + +**Note**: ROCm support on Windows WSL2 is experimental. For WSL installation, you may need to: +1. Install [ROCm for WSL](https://rocm.docs.amd.com/projects/install-on-windows/en/latest/tutorial/quick-start.html) following AMD's official guide +2. Follow the standard Linux ROCm JAX installation steps within your WSL environment +3. Be aware that performance and stability may differ from native Linux installations (install-intel-gpu)= ## Intel GPU @@ -281,32 +310,50 @@ Unlike the instructions for installing a JAX release, here we name all of JAX's packages explicitly on the command line, so `pip` will upgrade them if a newer version is available. +JAX publishes nightlies, release candidates(RCs), and releases to several non-pypi [PEP 503](https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/) indexes. + +All JAX packages can be reached from the index `https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/` +as well as PyPI mirrored packages. This additional mirroring enables nightly +installation to use --index (-i) as the install method with pip. + +**Note:** The unified index could return an RC or release as the newest version +even with `--pre` immediately after a release before the newest nightly is +rebuilt. If automation or testing must be done against nightlies or you cannot +use our full index, use the extra index `https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/` +which only contains nightly artifacts. + +The nightly index URLs can also be browsed directly. The `--index` URL is a +[PEP 503](https://peps.python.org/pep-0503/) simple repository index for `pip`, +and each package has its own sub-directory. For example, you can see the available +`jax` packages at +[https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/jax](https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/jax), +`jax-cuda12-pjrt` packages at +[https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/jax-cuda12-pjrt](https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/jax-cuda12-pjrt), +and `jax-cuda13-pjrt` packages at +[https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/jax-cuda13-pjrt](https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/jax-cuda13-pjrt). + - CPU only: ```bash -pip install -U --pre jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html +pip install -U --pre jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ ``` - Google Cloud TPU: ```bash -pip install -U --pre jax jaxlib libtpu requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html +pip install -U --pre jax jaxlib libtpu requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html ``` -- NVIDIA GPU (CUDA 12): +- NVIDIA GPU (CUDA 13): ```bash -pip install -U --pre jax jaxlib "jax-cuda12-plugin[with_cuda]" jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html +pip install -U --pre jax jaxlib "jax-cuda13-plugin[with-cuda]" jax-cuda13-pjrt -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ ``` -- NVIDIA GPU (CUDA 12) legacy: - -Use the following for historical nightly releases of monolithic CUDA jaxlibs. -You most likely do not want this; no further monolithic CUDA jaxlibs will be -built and those that exist will expire by Sep 2024. Use the "CUDA 12" option above. +- NVIDIA GPU (CUDA 12): ```bash -pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html +pip install -U --pre jax jaxlib "jax-cuda12-plugin[with-cuda]" jax-cuda12-pjrt -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ ``` (building-jax-from-source)= @@ -322,10 +369,10 @@ still be installed directly via the URLs here. For example: ```bash # Install jaxlib on CPU via the wheel archive -pip install "jax[cpu]==0.3.25" -f https://storage.googleapis.com/jax-releases/jax_releases.html +pip install "jax[cpu]==0.3.25" -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ # Install the jaxlib 0.3.25 CPU wheel directly -pip install jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_releases.html +pip install jaxlib==0.3.25 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ ``` For specific older GPU wheels, be sure to use the `jax_cuda_releases.html` URL; for example ```bash diff --git a/docs/internals/constants.md b/docs/internals/constants.md new file mode 100644 index 000000000000..d8eab7e9e5f3 --- /dev/null +++ b/docs/internals/constants.md @@ -0,0 +1,193 @@ +(constants-note)= + +# Handling of closed-over constants + +"Closed-over constants" are non-scalar arrays that are encountered during JAX tracing +of a function and do not have dependencies on any of the function's arguments. +JAX operations such as `jax.numpy` and `lax` are staged out and do not create +closed-over constants. +In the following example, the arrays +`a_jax_array` and `np.full` are closed-over constants, but `jnp.full` +is not. We refer below to closed-over constants simply as constants. + +```python +import numpy as np +from jax import jit +from jax import numpy as jnp + +a_jax_array = jnp.ones((16,), dtype=np.float32) + +@jit +def f(x): + return x + a_jax_array + np.full((16,), 42.) + jnp.full((16,), 142.) +``` + +We describe below the **future** internal implementation details for +constants. As of July 2025, this is not yet the default implementation; +it is enabled by the environment variable `JAX_USE_SIMPLIFIED_JAXPR_CONSTANTS=True`. +See further [below](#previous-implementation) for the details of the previous +implementation, including its drawbacks. + +## Tracing + +When JAX tracing encounters a constant that is either an argument of a JAX primitive +or a function return, it is represented as a `core.Literal`, and is embedded +in the `Jaxpr` along with the primitives that use them. +The function `core.is_literalable` decides which constants are turned into +`core.Literal`. All scalar constants are turned into `core.Literal`, along with +non-scalar `np.ndarray` and `jax.Array`. + +## Lowering + +When lowering the code to HLO we could just emit a `stablehlo.constant` operation +for a `core.Literal`, but this would have several disadvantages: + + * if the constant is a `jax.Array` (e.g., the `a_jax_array` above), then it is + pulled from the device to the host during lowering, and it will later re-materialized + on the device when the lowered module executes. + This can increase the host memory usage, sometimes dramatically. + Furthermore, if the constant is sharded on multiple devices this + sharding is lost. + * large constants increase the size of the HLO, especially if + the same constant is used multiple times. Also, the XLA compiler will attempt + to constant-fold them, resulting in warnings and slow compilation. Furthermore, + we have observed that XLA constant-folding sometimes produces slightly different + numerics compared to compiled code. + See also [Large closed-over constants are inlined in the HLO code #29684](https://github.com/jax-ml/jax/issues/29684). + +Instead, during lowering we use the function `core.jaxpr_const_args` to scan +a `Jaxpr` and return a list of constants contained within, uniquified by their +`id`. The `core.jaxpr_const_args` is memoized for each `Jaxpr` and sub-`Jaxpr` +on which it is called. + +All the lowered HLO functions will take one additional argument +for each unique constant appearing in the `Jaxpr` to which it corresponds. +These arguments, referred to as `const_args`, +come after the dimension variable arguments, after the +token arguments, and just before the actual array arguments. +During lowering we maintain a mapping `const_lowering: dict[int, mlir.IrValues]` +from the `id` of the constants to the HLO values for the corresponding +const args. +This mapping is stored in the `mlir.LoweringRuleContext` and is used +by `mlir.ir_constant`: when a constant is encountered, we just reuse +the existing lowering from `const_lowering` instead of emitting a +`stablehlo.constant`. + +When we lower an HLO inner function (i.e., not the `main` function), +we call again `core.jaxpr_const_args` +to get the actual constants in the corresponding `Jaxpr`. These are +expected to be among the constants for which we have a `const_lowering`. +The inner function will get its own smaller set of `const_args` and +its own `const_lowering` mapping to be used when lowering the body. +E.g., the function `mlir.lower_jaxpr_as_fun` is one place where some +of this happens. + +The function `mlir.jaxpr_subcomp` does not create a new HLO function, +but instead creates a block within the current function. It uses +the enclosing function's `const_lowering`. + +Note also that there will still be `stablehlo.constant` in the lowered +code, in three cases: + * when the constant is a scalar; we want these constants to be + available to XLA for constant folding. + * when the constant did not appear in the traced program, and is + hence not in the `Jaxpr`. This can happen for constants that + arise during lowering, e.g., the lowering of some PRNG functions + include constants. + * when we are exporting: at the moment, we do not hoist constant args + when we export because the export serialization does not currently support + serialization of arrays. + We use the `mlir.LoweringParameters.hoist_constants_as_args` parameter + to control this. + +One additional complication is that some of the internal lowering functions +need to take the argument avals and sometimes also the shardings and +layouts for the arguments. Furthermore, the avals, shardings, and layout for +all arguments, including the const args, +are used also after lowering also. Therefore, it is convenient +to compute these fairly high in the call stack, e.g., in +`pxla.lower_sharding_computations`, and pass them down. + +For example, the functions `mlir.lower_jaxpr_to_module`, +`pjit._pjit_cached_lower_jaxpr_to_fun`, and, `mlir.lower_jaxpr_to_fun` +take `in_avals`, `in_shardings`, and `in_layouts` that +that include both the avals for const_args and for the regular args +(the ones corresponding to the `Jaxpr.invars`). +They also take a `num_const_args` argument. + +## Compilation and execution + +The lowered MLIR module contains arguments for the const args, so +the compiled executable will need to be passed the const args. +It is important to choose the right place where we prepend the +const args. For example, in the following code, the second invocation +of the jitted function `f` is expected to hit the C++ jit cache without +any Python code executing. + +```python +const = jnp.array([42.]) +f = jax.jit(lambda: const) + +f() +f() +``` + +(TODO: yashk2810 plans to write a description of how the jit caches work.) +This means that the `const` will have to be passed to the executable in C++ +(and thus stored in `pxla.MeshExecutableFastpathData`), +and therefore the C++ cache +miss functions (e.g., `pjit._cpp_pjit.cache_miss`, +or `aot_cache_miss` in `pxla.MeshExecutable.create_cpp_call`) +will not take the const args as arguments. Instead these cache +miss functions will have to prepend the const args. + +The C++ fast path has support for const args starting with jaxlib 0.7.1. +In prior versions, the fast path is disabled when there are const args. + +To implement this scheme, we keep the `const_args` in +`stages.Lowering`, `stages.Lowered`, and `stages.CompiledCallParams`. + +Interestingly, when we serialize an executable, e.g., for the compilation +cache, we do not need to serialize the closed over constants. The executable +itself does not contain them, and needs to take them as const args. +Whoever is going to deserialize the cached executable will have to pass +the const args. + +In AOT mode, the lowering and execution may +use different values of the `jax_enable_x64` configuration value. +If the constants are 64-bit `ndarray` we must use the same value +of `jax_enable_x64` for lowering and execution. + +## Previous implementation + +This describes the current way we handle closed-over constants, as +of July 2025 (as long as `JAX_USE_SIMPLIFIED_CONSTANTS=False`). + +When JAX traces a function to a `Jaxpr` it collects the closed-over values +into a set of constants, and adds a corresponding set of `constvars` to the Jaxpr +(the actual arguments are represented by `invars`). +Most tracing functions, e.g., `trace_to_jaxpr_dynamic`, +return both the `Jaxpr` and the constants. + +In many places in the code we use a class `core.ClosedJaxpr` that contains a +`Jaxpr` and `consts` corresponding to the `Jaxpr.constvars`. + +There are several issues with `ClosedJaxpr`: + + * the lowering of the `consts` in `ClosedJaxpr` results in inlined + `stablehlo.constant`, with all the issues described above. + * `Jaxpr` and `ClosedJaxpr` are used pervasively in JAX, often with the + generic name `jaxpr` and it is not easy to tell which kind of `Jaxpr` we have. + We have started to add type declarations, but in some places the code + is written with `isinstance` conditionals to work with both. + * Since Jaxpr and ClosedJaxpr are sometimes used as caching keys, + and they are hashed by `id`, we would like to memoize their construction. + For example, the function [pe.closed_jaxpr](https://github.com/jax-ml/jax/blob/0956da1466d03af81b24d16554f30f2ff8163346/jax/_src/interpreters/partial_eval.py#L1570) + memoizes the construction of `ClosedJaxpr` but only for the case when consts is empty. + This is because sometimes consts are not hashable. + * Handling the constants in ClosedJaxpr requires some extra care. + E.g., there are places in the Mosaic lowering where we have not yet implemented + the handling of ClosedJaxpr with non-empty constants + (e.g. [here](https://github.com/jax-ml/jax/blob/7d924e8f72fd84fb2305f0a1683ae081f171602f/jax/_src/pallas/mosaic/lowering.py#L3115)). + * When we turn closed-over constants into inputs we have to be careful + during transformations with how we handle these auxiliary inputs. diff --git a/docs/internals/index.rst b/docs/internals/index.rst new file mode 100644 index 000000000000..188e37290e89 --- /dev/null +++ b/docs/internals/index.rst @@ -0,0 +1,13 @@ +JAX Internal Implementation Notes +================================= + +This section of the documentation describes implementation details that +are too elaborate to fit in code comments, or in internal function docstrings. + +Like code comments, the material here should be used with caution because it +is prone to become stale. + +.. toctree:: + :maxdepth: 1 + + Handling of closed-over constants diff --git a/docs/jacobian-vector-products.md b/docs/jacobian-vector-products.md new file mode 100644 index 000000000000..bbc678d02d1c --- /dev/null +++ b/docs/jacobian-vector-products.md @@ -0,0 +1,358 @@ +--- +jupytext: + formats: md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.16.4 +kernelspec: + display_name: Python 3 + name: python3 +--- + +(advanced-guides-jvp-vjp)= +# Forward- and reverse-mode autodiff in JAX + +## Jacobian-Vector products (JVPs, a.k.a. forward-mode autodiff) + +JAX includes efficient and general implementations of both forward- and reverse-mode automatic differentiation. The familiar {func}`jax.grad` function is built on reverse-mode, but to explain the difference between the two modes, and when each can be useful, you need a bit of math background. + +### JVPs in math + +Mathematically, given a function $f : \mathbb{R}^n \to \mathbb{R}^m$, the Jacobian of $f$ evaluated at an input point $x \in \mathbb{R}^n$, denoted $\partial f(x)$, is often thought of as a matrix in $\mathbb{R}^m \times \mathbb{R}^n$: + +$\qquad \partial f(x) \in \mathbb{R}^{m \times n}$. + +But you can also think of $\partial f(x)$ as a linear map, which maps the tangent space of the domain of $f$ at the point $x$ (which is just another copy of $\mathbb{R}^n$) to the tangent space of the codomain of $f$ at the point $f(x)$ (a copy of $\mathbb{R}^m$): + +$\qquad \partial f(x) : \mathbb{R}^n \to \mathbb{R}^m$. + +This map is called the [pushforward map](https://en.wikipedia.org/wiki/Pushforward_(differential)) of $f$ at $x$. The Jacobian matrix is just the matrix for this linear map on a standard basis. + +If you don't commit to one specific input point $x$, then you can think of the function $\partial f$ as first taking an input point and returning the Jacobian linear map at that input point: + +$\qquad \partial f : \mathbb{R}^n \to \mathbb{R}^n \to \mathbb{R}^m$. + +In particular, you can uncurry things so that given input point $x \in \mathbb{R}^n$ and a tangent vector $v \in \mathbb{R}^n$, you get back an output tangent vector in $\mathbb{R}^m$. We call that mapping, from $(x, v)$ pairs to output tangent vectors, the *Jacobian-vector product*, and write it as: + +$\qquad (x, v) \mapsto \partial f(x) v$ + +### JVPs in JAX code + +Back in Python code, JAX's {func}`jax.jvp` function models this transformation. Given a Python function that evaluates $f$, JAX's {func}`jax.jvp` is a way to get a Python function for evaluating $(x, v) \mapsto (f(x), \partial f(x) v)$. + +```{code-cell} +import jax +import jax.numpy as jnp + +key = jax.random.key(0) + +# Initialize random model coefficients +key, W_key, b_key = jax.random.split(key, 3) +W = jax.random.normal(W_key, (3,)) +b = jax.random.normal(b_key, ()) + +# Define a sigmoid function. +def sigmoid(x): + return 0.5 * (jnp.tanh(x / 2) + 1) + +# Outputs probability of a label being true. +def predict(W, b, inputs): + return sigmoid(jnp.dot(inputs, W) + b) + +# Build a toy dataset. +inputs = jnp.array([[0.52, 1.12, 0.77], + [0.88, -1.08, 0.15], + [0.52, 0.06, -1.30], + [0.74, -2.49, 1.39]]) + +# Isolate the function from the weight matrix to the predictions +f = lambda W: predict(W, b, inputs) + +key, subkey = jax.random.split(key) +v = jax.random.normal(subkey, W.shape) + +# Push forward the vector `v` along `f` evaluated at `W` +y, u = jax.jvp(f, (W,), (v,)) +``` + +In terms of [Haskell-like type signatures](https://wiki.haskell.org/Type_signature), you could write: + +```haskell +jvp :: (a -> b) -> a -> T a -> (b, T b) +``` + +where `T a` is used to denote the type of the tangent space for `a`. + +In other words, `jvp` takes as arguments a function of type `a -> b`, a value of type `a`, and a tangent vector value of type `T a`. It gives back a pair consisting of a value of type `b` and an output tangent vector of type `T b`. + +The `jvp`-transformed function is evaluated much like the original function, but paired up with each primal value of type `a` it pushes along tangent values of type `T a`. For each primitive numerical operation that the original function would have applied, the `jvp`-transformed function executes a "JVP rule" for that primitive that both evaluates the primitive on the primals and applies the primitive's JVP at those primal values. + +That evaluation strategy has some immediate implications about computational complexity. Since we evaluate JVPs as we go, we don't need to store anything for later, and so the memory cost is independent of the depth of the computation. In addition, the FLOP cost of the `jvp`-transformed function is about 3x the cost of just evaluating the function (one unit of work for evaluating the original function, for example `sin(x)`; one unit for linearizing, like `cos(x)`; and one unit for applying the linearized function to a vector, like `cos_x * v`). Put another way, for a fixed primal point $x$, we can evaluate $v \mapsto \partial f(x) \cdot v$ for about the same marginal cost as evaluating $f$. + +That memory complexity sounds pretty compelling! So why don't we see forward-mode very often in machine learning? + +To answer that, first think about how you could use a JVP to build a full Jacobian matrix. If we apply a JVP to a one-hot tangent vector, it reveals one column of the Jacobian matrix, corresponding to the nonzero entry we fed in. So we can build a full Jacobian one column at a time, and to get each column costs about the same as one function evaluation. That will be efficient for functions with "tall" Jacobians, but inefficient for "wide" Jacobians. + +If you're doing gradient-based optimization in machine learning, you probably want to minimize a loss function from parameters in $\mathbb{R}^n$ to a scalar loss value in $\mathbb{R}$. That means the Jacobian of this function is a very wide matrix: $\partial f(x) \in \mathbb{R}^{1 \times n}$, which we often identify with the Gradient vector $\nabla f(x) \in \mathbb{R}^n$. Building that matrix one column at a time, with each call taking a similar number of FLOPs to evaluate the original function, sure seems inefficient! In particular, for training neural networks, where $f$ is a training loss function and $n$ can be in the millions or billions, this approach just won't scale. + +To do better for functions like this, you just need to use reverse-mode. + +## Vector-Jacobian products (VJPs, a.k.a. reverse-mode autodiff) + +Where forward-mode gives us back a function for evaluating Jacobian-vector products, which we can then use to build Jacobian matrices one column at a time, reverse-mode is a way to get back a function for evaluating vector-Jacobian products (equivalently Jacobian-transpose-vector products), which we can use to build Jacobian matrices one row at a time. + +### VJPs in math + +Let's again consider a function $f : \mathbb{R}^n \to \mathbb{R}^m$. +Starting from our notation for JVPs, the notation for VJPs is pretty simple: + +$\qquad (x, v) \mapsto v \partial f(x)$, + +where $v$ is an element of the cotangent space of $f$ at $x$ (isomorphic to another copy of $\mathbb{R}^m$). When being rigorous, we should think of $v$ as a linear map $v : \mathbb{R}^m \to \mathbb{R}$, and when we write $v \partial f(x)$ we mean function composition $v \circ \partial f(x)$, where the types work out because $\partial f(x) : \mathbb{R}^n \to \mathbb{R}^m$. But in the common case we can identify $v$ with a vector in $\mathbb{R}^m$ and use the two almost interchangeably, just like we might sometimes flip between "column vectors" and "row vectors" without much comment. + +With that identification, we can alternatively think of the linear part of a VJP as the transpose (or adjoint conjugate) of the linear part of a JVP: + +$\qquad (x, v) \mapsto \partial f(x)^\mathsf{T} v$. + +For a given point $x$, we can write the signature as + +$\qquad \partial f(x)^\mathsf{T} : \mathbb{R}^m \to \mathbb{R}^n$. + +The corresponding map on cotangent spaces is often called the [pullback](https://en.wikipedia.org/wiki/Pullback_(differential_geometry)) +of $f$ at $x$. The key for our purposes is that it goes from something that looks like the output of $f$ to something that looks like the input of $f$, just like we might expect from a transposed linear function. + +### VJPs in JAX code + +Switching from math back to Python, the JAX function `vjp` can take a Python function for evaluating $f$ and give us back a Python function for evaluating the VJP $(x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))$. + +```{code-cell} +from jax import vjp + +# Isolate the function from the weight matrix to the predictions +f = lambda W: predict(W, b, inputs) + +y, vjp_fun = vjp(f, W) + +key, subkey = jax.random.split(key) +u = jax.random.normal(subkey, y.shape) + +# Pull back the covector `u` along `f` evaluated at `W` +v = vjp_fun(u) +``` + +In terms of [Haskell-like type signatures](https://wiki.haskell.org/Type_signature), we could write + +```haskell +vjp :: (a -> b) -> a -> (b, CT b -> CT a) +``` + +where we use `CT a` to denote the type for the cotangent space for `a`. In words, `vjp` takes as arguments a function of type `a -> b` and a point of type `a`, and gives back a pair consisting of a value of type `b` and a linear map of type `CT b -> CT a`. + +This is great because it lets us build Jacobian matrices one row at a time, and the FLOP cost for evaluating $(x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))$ is only about three times the cost of evaluating $f$. In particular, if we want the gradient of a function $f : \mathbb{R}^n \to \mathbb{R}$, we can do it in just one call. That's how {func}`jax.grad` is efficient for gradient-based optimization, even for objectives like neural network training loss functions on millions or billions of parameters. + +There's a cost, though the FLOPs are friendly, memory scales with the depth of the computation. Also, the implementation is traditionally more complex than that of forward-mode, though JAX has some tricks up its sleeve (that's a story for a future notebook!). + +For more on how reverse-mode works, check out [this tutorial video from the Deep Learning Summer School in 2017](http://videolectures.net/deeplearning2017_johnson_automatic_differentiation/). + +## Vector-valued gradients with VJPs + +If you're interested in taking vector-valued gradients (like `tf.gradients`): + +```{code-cell} +def vgrad(f, x): + y, vjp_fn = jax.vjp(f, x) + return vjp_fn(jnp.ones(y.shape))[0] + +print(vgrad(lambda x: 3*x**2, jnp.ones((2, 2)))) +``` + +## Hessian-vector products using both forward- and reverse-mode + +In a previous section, you implemented a Hessian-vector product function just using reverse-mode (assuming continuous second derivatives): + +```{code-cell} +def hvp(f, x, v): + return jax.grad(lambda x: jnp.vdot(jax.grad(f)(x), v))(x) +``` + +That's efficient, but you can do even better and save some memory by using forward-mode together with reverse-mode. + +Mathematically, given a function $f : \mathbb{R}^n \to \mathbb{R}$ to differentiate, a point $x \in \mathbb{R}^n$ at which to linearize the function, and a vector $v \in \mathbb{R}^n$, the Hessian-vector product function we want is: + +$(x, v) \mapsto \partial^2 f(x) v$ + +Consider the helper function $g : \mathbb{R}^n \to \mathbb{R}^n$ defined to be the derivative (or gradient) of $f$, namely $g(x) = \partial f(x)$. All you need is its JVP, since that will give us: + +$(x, v) \mapsto \partial g(x) v = \partial^2 f(x) v$. + +We can translate that almost directly into code: + +```{code-cell} +# forward-over-reverse +def hvp(f, primals, tangents): + return jax.jvp(jax.grad(f), primals, tangents)[1] +``` + +Even better, since you didn't have to call {func}`jnp.dot` directly, this `hvp` function works with arrays of any shape and with arbitrary container types (like vectors stored as nested lists/dicts/tuples), and doesn't even have a dependence on {mod}`jax.numpy`. + +Here's an example of how to use it: + +```{code-cell} +def f(X): + return jnp.sum(jnp.tanh(X)**2) + +key, subkey1, subkey2 = jax.random.split(key, 3) +X = jax.random.normal(subkey1, (30, 40)) +V = jax.random.normal(subkey2, (30, 40)) + +def hessian(f): + return jax.jacfwd(jax.jacrev(f)) + +ans1 = hvp(f, (X,), (V,)) +ans2 = jnp.tensordot(hessian(f)(X), V, 2) + +print(jnp.allclose(ans1, ans2, 1e-4, 1e-4)) +``` + +Another way you might consider writing this is using reverse-over-forward: + +```{code-cell} +# Reverse-over-forward +def hvp_revfwd(f, primals, tangents): + g = lambda primals: jax.jvp(f, primals, tangents)[1] + return jax.grad(g)(primals) +``` + +That's not quite as good, though, because forward-mode has less overhead than reverse-mode, and since the outer differentiation operator here has to differentiate a larger computation than the inner one, keeping forward-mode on the outside works best: + +```{code-cell} +# Reverse-over-reverse, only works for single arguments +def hvp_revrev(f, primals, tangents): + x, = primals + v, = tangents + return jax.grad(lambda x: jnp.vdot(jax.grad(f)(x), v))(x) + + +print("Forward over reverse") +%timeit -n10 -r3 hvp(f, (X,), (V,)) +print("Reverse over forward") +%timeit -n10 -r3 hvp_revfwd(f, (X,), (V,)) +print("Reverse over reverse") +%timeit -n10 -r3 hvp_revrev(f, (X,), (V,)) + +print("Naive full Hessian materialization") +%timeit -n10 -r3 jnp.tensordot(jax.hessian(f)(X), V, 2) +``` + +## Composing VJPs, JVPs, and `jax.vmap` + +## Jacobian-Matrix and Matrix-Jacobian products + +Now that you have {func}`jax.jvp` and {func}`jax.vjp` transformations that give you functions to push-forward or pull-back single vectors at a time, you can use JAX's {func}`jax.vmap` [transformation](https://github.com/jax-ml/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, you can use that to write fast matrix-Jacobian and Jacobian-matrix products: + +```{code-cell} +# Isolate the function from the weight matrix to the predictions +f = lambda W: predict(W, b, inputs) + +# Pull back the covectors `m_i` along `f`, evaluated at `W`, for all `i`. +# First, use a list comprehension to loop over rows in the matrix M. +def loop_mjp(f, x, M): + y, vjp_fun = jax.vjp(f, x) + return jnp.vstack([vjp_fun(mi) for mi in M]) + +# Now, use vmap to build a computation that does a single fast matrix-matrix +# multiply, rather than an outer loop over vector-matrix multiplies. +def vmap_mjp(f, x, M): + y, vjp_fun = jax.vjp(f, x) + outs, = jax.vmap(vjp_fun)(M) + return outs + +key = jax.random.key(0) +num_covecs = 128 +U = jax.random.normal(key, (num_covecs,) + y.shape) + +loop_vs = loop_mjp(f, W, M=U) +print('Non-vmapped Matrix-Jacobian product') +%timeit -n10 -r3 loop_mjp(f, W, M=U) + +print('\nVmapped Matrix-Jacobian product') +vmap_vs = vmap_mjp(f, W, M=U) +%timeit -n10 -r3 vmap_mjp(f, W, M=U) + +assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Matrix-Jacobian Products should be identical' +``` + +```{code-cell} +def loop_jmp(f, W, M): + # jvp immediately returns the primal and tangent values as a tuple, + # so we'll compute and select the tangents in a list comprehension + return jnp.vstack([jax.jvp(f, (W,), (mi,))[1] for mi in M]) + +def vmap_jmp(f, W, M): + _jvp = lambda s: jax.jvp(f, (W,), (s,))[1] + return jax.vmap(_jvp)(M) +num_vecs = 128 +S = jax.random.normal(key, (num_vecs,) + W.shape) + +loop_vs = loop_jmp(f, W, M=S) +print('Non-vmapped Jacobian-Matrix product') +%timeit -n10 -r3 loop_jmp(f, W, M=S) +vmap_vs = vmap_jmp(f, W, M=S) +print('\nVmapped Jacobian-Matrix product') +%timeit -n10 -r3 vmap_jmp(f, W, M=S) + +assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Jacobian-Matrix products should be identical' +``` + +## The implementation of `jax.jacfwd` and `jax.jacrev` + +Now that we've seen fast Jacobian-matrix and matrix-Jacobian products, it's not hard to guess how to write {func}`jax.jacfwd` and {func}`jax.jacrev`. We just use the same technique to push-forward or pull-back an entire standard basis (isomorphic to an identity matrix) at once. + +```{code-cell} +from jax import jacrev as builtin_jacrev + +def our_jacrev(f): + def jacfun(x): + y, vjp_fun = jax.vjp(f, x) + # Use vmap to do a matrix-Jacobian product. + # Here, the matrix is the Euclidean basis, so we get all + # entries in the Jacobian at once. + J, = jax.vmap(vjp_fun, in_axes=0)(jnp.eye(len(y))) + return J + return jacfun + +assert jnp.allclose(builtin_jacrev(f)(W), our_jacrev(f)(W)), 'Incorrect reverse-mode Jacobian results!' +``` + +```{code-cell} +from jax import jacfwd as builtin_jacfwd + +def our_jacfwd(f): + def jacfun(x): + _jvp = lambda s: jax.jvp(f, (x,), (s,))[1] + Jt = jax.vmap(_jvp, in_axes=1)(jnp.eye(len(x))) + return jnp.transpose(Jt) + return jacfun + +assert jnp.allclose(builtin_jacfwd(f)(W), our_jacfwd(f)(W)), 'Incorrect forward-mode Jacobian results!' +``` + +Interestingly, the [Autograd](https://github.com/hips/autograd) library couldn't do this. The [implementation](https://github.com/HIPS/autograd/blob/96a03f44da43cd7044c61ac945c483955deba957/autograd/differential_operators.py#L60) of reverse-mode `jacobian` in Autograd had to pull back one vector at a time with an outer-loop `map`. Pushing one vector at a time through the computation is much less efficient than batching it all together with {func}`jax.vmap`. + +Another thing that Autograd couldn't do is {func}`jax.jit`. Interestingly, no matter how much Python dynamism you use in your function to be differentiated, we could always use {func}`jax.jit` on the linear part of the computation. For example: + +```{code-cell} +def f(x): + try: + if x < 3: + return 2 * x ** 3 + else: + raise ValueError + except ValueError: + return jnp.pi * x + +y, f_vjp = jax.vjp(f, 4.) +print(jax.jit(f_vjp)(1.)) +``` diff --git a/docs/jax-101.rst b/docs/jax-101.rst new file mode 100644 index 000000000000..328e60ab4e3d --- /dev/null +++ b/docs/jax-101.rst @@ -0,0 +1,22 @@ +.. _jax-101: + +JAX 101 +======= + +These tutorials cover basic usage of JAX and its features, including some of the +internal mechanisms that make JAX work. They start with the fundamentals and are +meant to be read sequentially. For more in-depth discussions of JAX's design and +implementation, in no particular order, see the :doc:`advanced guides `. + +.. toctree:: + :maxdepth: 1 + + jit-compilation + automatic-vectorization + automatic-differentiation + pytrees + random-numbers + sharded-computation + control-flow + tracing + stateful-computations diff --git a/docs/jax-primitives.md b/docs/jax-primitives.md index abdc8be6d0a8..e0e09c7c509a 100644 --- a/docs/jax-primitives.md +++ b/docs/jax-primitives.md @@ -21,14 +21,14 @@ kernelspec: A JAX primitive is the basic computational unit of a JAX program. This document explains the interface that a JAX primitive must support to allow JAX to perform all its transformations (this is not a how-to guide). -For example, the multiply-add operation can be implemented in terms of the low-level `jax.lax.*` primitives (which are like XLA operator wrappers) or `jax.core.Primitive("multiply_add")`, as demonstrated further below. +For example, the multiply-add operation can be implemented in terms of the low-level `jax.lax.*` primitives (which are like XLA operator wrappers) or `jax.extend.core.Primitive("multiply_add")`, as demonstrated further below. And JAX is able to take sequences of such primitive operations, and transform them via its composable transformations of Python functions, such as {func}`jax.jit`, {func}`jax.grad` and {func}`jax.vmap`. JAX implements these transforms in a *JAX-traceable* way. This means that when a Python function is executed, the only operations it applies to the data are either: -- **Inspections of data attributes:** Data information, such as shape or type; or +- **Inspections of data attributes:** Data information, such as shape or type; or - **JAX primitives:** These are the JAX special operations covered in this tutorial. -JAX primitives know how to operate on both concrete data values and abstract JAX values. *A JAX-traceable function* can be invoked by JAX with abstract arguments. For example, a JAX abstract value — `ShapedArray(float32[2,2])` — captures the type and the shape of values, but not the concrete data values. +JAX primitives know how to operate on both concrete data values and abstract JAX values. *A JAX-traceable function* can be invoked by JAX with abstract arguments. For example, a JAX abstract value — `ShapedArray(float32[2,2])` — captures the type and the shape of values, but not the concrete data values. The JAX-transformed functions must themselves be JAX-traceable functions *to make sure that these transformations are composable*, for example like `jax.jit(jax.jacfwd(jax.grad(f)))`. @@ -49,7 +49,7 @@ Consider the following example: you want to add to JAX support for a multiply-ad The easiest way to define new functions is to write them in terms of JAX primitives, or in terms of other functions that are themselves written using JAX primitives, for example, those defined in the {func}`jax.lax` module: ```{code-cell} -from jax import lax +from jax._src.lax import lax from jax._src import api def multiply_add_lax(x, y, z): @@ -100,7 +100,7 @@ def trace(name): vtype = str(type(v)) if "jax._src.xla_bridge._JaxComputationBuilder" in vtype: return "" - elif "jaxlib.xla_extension.XlaOp" in vtype: + elif "jaxlib._jax_.XlaOp" in vtype: return "".format(id(v)) elif ("partial_eval.JaxprTracer" in vtype or "batching.BatchTracer" in vtype or @@ -112,7 +112,7 @@ def trace(name): return str(v) def pp_values(args): return ", ".join([pp(arg) for arg in args]) - + @functools.wraps(func) def func_wrapper(*args): _trace_indent("call {}({})".format(name, pp_values(args))) @@ -140,7 +140,7 @@ class expectNotImplementedError(object): return False ``` -Instead of using {func}`jax.lax` primitives directly, you can use other functions +Instead of using {func}`jax.lax` primitives directly, you can use other functions that are already written in terms of those primitives, such as those in `jax.numpy`: ```{code-cell} @@ -155,7 +155,7 @@ def multiply_add_numpy(x, y, z): def square_add_numpy(a, b): return multiply_add_numpy(a, a, b) -print("\nNormal evaluation:") +print("\nNormal evaluation:") print("square_add_numpy = ", square_add_numpy(2., 10.)) print("\nGradient evaluation:") print("grad(square_add_numpy) = ", api.grad(square_add_numpy)(2.0, 10.)) @@ -171,16 +171,16 @@ The JAX traceability property is satisfied as long as the function is written in The right way to add support for multiply-add is in terms of existing JAX primitives, as shown above. However, to demonstrate how JAX primitives work, pretend that you want to add a new primitive to JAX for the multiply-add functionality. ```{code-cell} -from jax import core +from jax.extend import core multiply_add_p = core.Primitive("multiply_add") # Create the primitive @trace("multiply_add_prim") def multiply_add_prim(x, y, z): """The JAX-traceable way to use the JAX primitive. - + Note that the traced arguments must be passed as positional arguments - to `bind`. + to `bind`. """ return multiply_add_p.bind(x, y, z) @@ -209,7 +209,7 @@ def multiply_add_impl(x, y, z): This function does not need to be JAX traceable. Args: - x, y, z: The concrete arguments of the primitive. Will only be called with + x, y, z: The concrete arguments of the primitive. Will only be called with concrete values. Returns: @@ -241,8 +241,8 @@ with expectNotImplementedError(): To JIT the function, and for other transformations as well, JAX first evaluates it abstractly using only the shape and type of the arguments. This abstract evaluation serves multiple purposes: - * Gets the sequence of JAX primitives that are used in the computation. This sequence will be compiled. - * Computes the shape and type of all vectors and operations used in the computation. + * Gets the sequence of JAX primitives that are used in the computation. This sequence will be compiled. + * Computes the shape and type of all vectors and operations used in the computation. For example, the abstraction of a vector with 3 elements may be `ShapedArray(float32[3])`, or `ConcreteArray([1., 2., 3.])`. In the latter case, JAX uses the actual concrete value wrapped as an abstract value. @@ -300,7 +300,7 @@ def multiply_add_lowering(ctx, xc, yc, zc): return [hlo.AddOp(hlo.MulOp(xc, yc), zc).result] # Now, register the lowering rule with JAX. -# For GPU, refer to the https://jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html +# For GPU, refer to the https://docs.jax.dev/en/latest/Custom_Operation_for_GPUs.html from jax.interpreters import mlir mlir.register_lowering(multiply_add_p, multiply_add_lowering, platform='cpu') @@ -315,13 +315,13 @@ assert api.jit(lambda x, y: square_add_prim(x, y))(2., 10.) == 14. Below is another use of `jit`, where you compile only with respect to the first argument. Notice how the second argument to `square_add_prim` is concrete, which leads in the third argument to `multiply_add_abstract_eval` being `ConcreteArray`. Notice that `multiply_add_abstract_eval` may be used with both `ShapedArray` and `ConcreteArray`. ```{code-cell} -assert api.jit(lambda x, y: square_add_prim(x, y), +assert api.jit(lambda x, y: square_add_prim(x, y), static_argnums=1)(2., 10.) == 14. ``` ### Forward differentiation -JAX implements forward differentiation in the form of a Jacobian-Vector Product (JVP) (you can learn more about it in {ref}`advanced-autodiff`). +JAX implements forward differentiation in the form of a Jacobian-Vector Product (JVP) (you can learn more about it in {ref}`advanced-guides-jvp-vjp`). If you attempt to compute the `jvp` function, you'll get an error because you have not yet told JAX how to differentiate the `multiply_add` primitive. @@ -342,16 +342,16 @@ from jax.interpreters import ad def multiply_add_value_and_jvp(arg_values, arg_tangents): """Evaluates the primal output and the tangents (Jacobian-vector product). - Given values of the arguments and perturbation of the arguments (tangents), + Given values of the arguments and perturbation of the arguments (tangents), compute the output of the primitive and the perturbation of the output. - This method must be JAX-traceable. JAX may invoke it with abstract values + This method must be JAX-traceable. JAX may invoke it with abstract values for the arguments and tangents. Args: arg_values: A tuple of arguments - arg_tangents: A tuple with the tangents of the arguments. The tuple has - the same length as the arg_values. Some of the tangents may also be the + arg_tangents: A tuple with the tangents of the arguments. The tuple has + the same length as the arg_values. Some of the tangents may also be the special value `ad.Zero` to specify a zero tangent Returns: @@ -360,21 +360,21 @@ def multiply_add_value_and_jvp(arg_values, arg_tangents): x, y, z = arg_values xt, yt, zt = arg_tangents _trace("Primal evaluation:") - # Now, you have a JAX-traceable computation of the output. - # Normally, you can use the multiply add (`ma`) primitive itself to compute the primal output. + # Now, you have a JAX-traceable computation of the output. + # Normally, you can use the multiply add (`ma`) primitive itself to compute the primal output. primal_out = multiply_add_prim(x, y, z) _trace("Tangent evaluation:") - # You must use a JAX-traceable way to compute the tangent. It turns out that + # You must use a JAX-traceable way to compute the tangent. It turns out that # the output tangent can be computed as (xt * y + x * yt + zt), # which you can implement in a JAX-traceable way using the same "multiply_add_prim" primitive. - # You do need to deal specially with `Zero`. Here, you just turn it into a - # proper tensor of 0s (of the same shape as 'x'). - # An alternative would be to check for `Zero` and perform algebraic + # You do need to deal specially with `Zero`. Here, you just turn it into a + # proper tensor of 0s (of the same shape as 'x'). + # An alternative would be to check for `Zero` and perform algebraic # simplification of the output tangent computation. def make_zero(tan): - return lax.zeros_like_array(x) if type(tan) is ad.Zero else tan + return lax.full_like(x, 0) if type(tan) is ad.Zero else tan output_tangent = multiply_add_prim(make_zero(xt), y, multiply_add_prim(x, make_zero(yt), make_zero(zt))) return (primal_out, output_tangent) @@ -393,7 +393,7 @@ assert api.jvp(square_add_prim, (2., 10.), (1., 1.)) == (14., 5.) You can apply `jit` to the forward differentiation function: ```{code-cell} -assert api.jit(lambda arg_values, arg_tangents: +assert api.jit(lambda arg_values, arg_tangents: api.jvp(square_add_prim, arg_values, arg_tangents))( (2., 10.), (1., 1.)) == (14., 5.) ``` @@ -456,8 +456,8 @@ JAX will produce the reverse differentiation computation by processing the JVP c xct += act * 4. ``` -One can verify that this computation produces `xct = 4.` and `yct = 3.`, which -are the partial derivatives of the function `f`. +One can verify that this computation produces `xct = 4.` and `yct = 3.`, which +are the partial derivatives of the function `f`. JAX knows for each primitive that may appear in a JVP calculation how to transpose it. Conceptually, if the primitive `p(x, y, z)` is linear in the arguments `y` and `z` for a constant value of `x`, e.g., `p(x, y, z) = y*cy + z*cz`, then the transposition of the primitive is: @@ -480,13 +480,13 @@ In particular: def multiply_add_transpose(ct, x, y, z): """Evaluates the transpose of a linear primitive. - This method is only used when computing the backward gradient following - `value_and_jvp`, and is only needed for primitives that are used in the JVP - calculation for some other primitive. You need a transposition for `multiply_add_prim`, - because you have used `multiply_add_prim` in the computation of the `output_tangent` in + This method is only used when computing the backward gradient following + `value_and_jvp`, and is only needed for primitives that are used in the JVP + calculation for some other primitive. You need a transposition for `multiply_add_prim`, + because you have used `multiply_add_prim` in the computation of the `output_tangent` in `multiply_add_value_and_jvp`. - In this case, multiply_add is not a linear primitive. However, it is used linearly + In this case, multiply_add is not a linear primitive. However, it is used linearly w.r.t. tangents in `multiply_add_value_and_jvp`: `output_tangent(xt, yt, zt) = multiply_add_prim(xt, y, multiply_add_prim(x, yt, zt))`. @@ -505,12 +505,12 @@ def multiply_add_transpose(ct, x, y, z): if not ad.is_undefined_primal(x): # This use of multiply_add is with a constant "x". assert ad.is_undefined_primal(y) - ct_y = ad.Zero(y.aval) if type(ct) is ad.Zero else multiply_add_prim(x, ct, lax.zeros_like_array(x)) + ct_y = ad.Zero(y.aval) if type(ct) is ad.Zero else multiply_add_prim(x, ct, lax.full_like(x, 0)) res = None, ct_y, ct else: # This use of multiply_add is with a constant "y". assert ad.is_undefined_primal(x) - ct_x = ad.Zero(x.aval) if type(ct) is ad.Zero else multiply_add_prim(ct, y, lax.zeros_like_array(y)) + ct_x = ad.Zero(x.aval) if type(ct) is ad.Zero else multiply_add_prim(ct, y, lax.full_like(y, 0)) res = ct_x, None, ct return res @@ -526,7 +526,7 @@ assert api.grad(square_add_prim)(2., 10.) == 4. Notice the two calls to `multiply_add_transpose`. They correspond to the two uses of `multiply_add_prim` in the computation of the `output_tangent` in `multiply_add_value_and_jvp`. The first call to transpose corresponds to the last use of `multiply_add_prim`: `multiply_add_prim(xt, y, ...)` where `y` is the constant `2.0`. -#### JIT of reverse differentiation +#### JIT of reverse differentiation Notice that the abstract evaluation of the `multiply_add_value_and_jvp` is using only abstract values. Meanwhile, in the absence of JIT, you used `ConcreteArray`. @@ -555,9 +555,9 @@ from jax.interpreters import batching @trace("multiply_add_batch") def multiply_add_batch(vector_arg_values, batch_axes): """Computes the batched version of the primitive. - + This must be a JAX-traceable function. - + Since the `multiply_add primitive` already operates point-wise on arbitrary dimension tensors, to batch it you can use the primitive itself. This works as long as both the inputs have the same dimensions and are batched along the @@ -569,7 +569,7 @@ def multiply_add_batch(vector_arg_values, batch_axes): batch_axes: The axes that are being batched. See vmap documentation. Returns: - A tuple of the result, and the result axis that was batched. + A tuple of the result, and the result axis that was batched. """ assert batch_axes[0] == batch_axes[1] assert batch_axes[0] == batch_axes[2] diff --git a/docs/jax.ad_checkpoint.rst b/docs/jax.ad_checkpoint.rst new file mode 100644 index 000000000000..aa1cfbd24355 --- /dev/null +++ b/docs/jax.ad_checkpoint.rst @@ -0,0 +1,11 @@ +``jax.ad_checkpoint`` module +============================ + +.. currentmodule:: jax.ad_checkpoint + +.. automodule:: jax.ad_checkpoint + +.. autosummary:: + :toctree: _autosummary + + checkpoint_name diff --git a/docs/jax.dlpack.rst b/docs/jax.dlpack.rst index 4a679052775e..c5f0034ab51f 100644 --- a/docs/jax.dlpack.rst +++ b/docs/jax.dlpack.rst @@ -9,4 +9,4 @@ :toctree: _autosummary from_dlpack - to_dlpack \ No newline at end of file + is_supported_dtype diff --git a/docs/jax.dtypes.rst b/docs/jax.dtypes.rst index 151d80f90368..ab8682422fdb 100644 --- a/docs/jax.dtypes.rst +++ b/docs/jax.dtypes.rst @@ -9,6 +9,7 @@ bfloat16 canonicalize_dtype float0 + itemsize_bits issubdtype prng_key result_type diff --git a/docs/jax.experimental.compilation_cache.rst b/docs/jax.experimental.compilation_cache.rst index 8196b1e9cf6d..a7ad6413adb8 100644 --- a/docs/jax.experimental.compilation_cache.rst +++ b/docs/jax.experimental.compilation_cache.rst @@ -8,7 +8,5 @@ JAX disk compilation cache. API --- -.. autofunction:: is_initialized -.. autofunction:: initialize_cache .. autofunction:: set_cache_dir .. autofunction:: reset_cache diff --git a/docs/jax.experimental.pallas.mosaic_gpu.rst b/docs/jax.experimental.pallas.mosaic_gpu.rst index 2d3452609c75..bb3340f248ee 100644 --- a/docs/jax.experimental.pallas.mosaic_gpu.rst +++ b/docs/jax.experimental.pallas.mosaic_gpu.rst @@ -10,10 +10,11 @@ Classes :toctree: _autosummary Barrier - GPUBlockSpec - GPUCompilerParams - GPUMemorySpace + BlockSpec + CompilerParams + MemorySpace Layout + SemaphoreType SwizzleTransform TilingTransform TransposeTransform @@ -22,21 +23,81 @@ Classes Functions --------- +.. autosummary:: + :toctree: _autosummary + + as_torch_kernel + kernel + layout_cast + set_max_registers + planar_snake + +Loop-like functions +------------------- + +.. autosummary:: + :toctree: _autosummary + + emit_pipeline + emit_pipeline_warp_specialized + nd_loop + dynamic_scheduling_loop + +Synchronization +--------------- + .. autosummary:: :toctree: _autosummary barrier_arrive barrier_wait + semaphore_signal_parallel + SemaphoreSignal + +Asynchronous copies +------------------- + +.. autosummary:: + :toctree: _autosummary + commit_smem copy_gmem_to_smem copy_smem_to_gmem - emit_pipeline - layout_cast - set_max_registers wait_smem_to_gmem + +Hopper-specific functions +------------------------- + +.. autosummary:: + :toctree: _autosummary + wgmma wgmma_wait +Blackwell-specific functions +---------------------------- + +.. autosummary:: + :toctree: _autosummary + + tcgen05_mma + tcgen05_commit_arrive + async_load_tmem + async_store_tmem + wait_load_tmem + commit_tmem + try_cluster_cancel + query_cluster_cancel + +Multimem operations +------------------- + +.. autosummary:: + :toctree: _autosummary + + multimem_store + multimem_load_reduce + Aliases ------- diff --git a/docs/jax.experimental.pallas.rst b/docs/jax.experimental.pallas.rst index c945f939fa4d..34fc990c1d2e 100644 --- a/docs/jax.experimental.pallas.rst +++ b/docs/jax.experimental.pallas.rst @@ -9,9 +9,9 @@ Backends .. toctree:: :maxdepth: 1 - jax.experimental.pallas.mosaic_gpu - jax.experimental.pallas.triton - jax.experimental.pallas.tpu + Pallas TPU (TensorCore) + Pallas MGPU + Triton Classes ------- @@ -23,34 +23,39 @@ Classes GridSpec Slice - MemoryRef - Functions --------- .. autosummary:: :toctree: _autosummary + core_map + kernel pallas_call program_id num_programs - load - store - swap - - atomic_and - atomic_add - atomic_cas - atomic_max - atomic_min - atomic_or - atomic_xchg - atomic_xor + cdiv + dslice + empty + empty_like + broadcast_to + debug_check debug_print dot - max_contiguous + get_global + loop multiple_of run_scoped when + +Synchronization +--------------- + +.. autosummary:: + :toctree: _autosummary + + semaphore_read + semaphore_signal + semaphore_wait diff --git a/docs/jax.experimental.pallas.tpu.rst b/docs/jax.experimental.pallas.tpu.rst index ae4e2c2253e4..908f7dee86fb 100644 --- a/docs/jax.experimental.pallas.tpu.rst +++ b/docs/jax.experimental.pallas.tpu.rst @@ -9,8 +9,83 @@ Classes .. autosummary:: :toctree: _autosummary + ChipVersion + CompilerParams + GridDimensionSemantics + MemorySpace + PrefetchScalarGridSpec + SemaphoreType + TpuInfo + Functions --------- .. autosummary:: - :toctree: _autosummary \ No newline at end of file + :toctree: _autosummary + + load + store + +Communication +------------- + +.. autosummary:: + :toctree: _autosummary + + async_copy + async_remote_copy + make_async_copy + make_async_remote_copy + sync_copy + +Pipelining +---------- + +.. autosummary:: + :toctree: _autosummary + + BufferedRef + BufferedRefBase + emit_pipeline + emit_pipeline_with_allocations + get_pipeline_schedule + make_pipeline_allocations + + +Pseudorandom Number Generation +------------------------------ + +.. autosummary:: + :toctree: _autosummary + + prng_seed + sample_block + stateful_bernoulli + stateful_bits + stateful_normal + stateful_uniform + to_pallas_key + +Interpret Mode +-------------- + +.. autosummary:: + :toctree: _autosummary + + force_tpu_interpret_mode + InterpretParams + reset_tpu_interpret_mode_state + set_tpu_interpret_mode + +Miscellaneous +------------- + +.. autosummary:: + :toctree: _autosummary + + core_barrier + get_barrier_semaphore + get_tpu_info + is_tpu_device + run_on_first_core + with_memory_space_constraint diff --git a/docs/jax.experimental.pallas.triton.rst b/docs/jax.experimental.pallas.triton.rst index 76b0896ccf17..b62cafbc6f12 100644 --- a/docs/jax.experimental.pallas.triton.rst +++ b/docs/jax.experimental.pallas.triton.rst @@ -9,7 +9,7 @@ Classes .. autosummary:: :toctree: _autosummary - TritonCompilerParams + CompilerParams Functions --------- @@ -17,6 +17,17 @@ Functions .. autosummary:: :toctree: _autosummary + atomic_and + atomic_add + atomic_cas + atomic_max + atomic_min + atomic_or + atomic_xchg + atomic_xor approx_tanh debug_barrier - elementwise_inline_asm \ No newline at end of file + elementwise_inline_asm + load + max_contiguous + store diff --git a/docs/jax.experimental.pjit.rst b/docs/jax.experimental.pjit.rst deleted file mode 100644 index 34fe95ef0625..000000000000 --- a/docs/jax.experimental.pjit.rst +++ /dev/null @@ -1,9 +0,0 @@ -``jax.experimental.pjit`` module -================================ - -.. automodule:: jax.experimental.pjit - -API ---- - -.. autofunction:: pjit diff --git a/docs/jax.experimental.random.rst b/docs/jax.experimental.random.rst new file mode 100644 index 000000000000..9e9c7092fad5 --- /dev/null +++ b/docs/jax.experimental.random.rst @@ -0,0 +1,10 @@ +``jax.experimental.random`` module +================================== + +.. automodule:: jax.experimental.random + +.. autosummary:: + :toctree: _autosummary + + stateful_rng + StatefulPRNG diff --git a/docs/jax.experimental.rst b/docs/jax.experimental.rst index 39c778db7aca..e01681edc14c 100644 --- a/docs/jax.experimental.rst +++ b/docs/jax.experimental.rst @@ -23,16 +23,6 @@ Experimental Modules jax.experimental.mesh_utils jax.experimental.multihost_utils jax.experimental.pallas - jax.experimental.pjit + jax.experimental.random jax.experimental.serialize_executable - jax.experimental.shard_map jax.experimental.sparse - -Experimental APIs ------------------ - -.. autosummary:: - :toctree: _autosummary - - enable_x64 - disable_x64 diff --git a/docs/jax.experimental.shard_map.rst b/docs/jax.experimental.shard_map.rst deleted file mode 100644 index 65be7f21ba1e..000000000000 --- a/docs/jax.experimental.shard_map.rst +++ /dev/null @@ -1,12 +0,0 @@ -``jax.experimental.shard_map`` module -===================================== - -.. automodule:: jax.experimental.shard_map - -API ---- - -.. autosummary:: - :toctree: _autosummary - - shard_map diff --git a/docs/jax.experimental.sparse.rst b/docs/jax.experimental.sparse.rst index 37ef8ae43d67..f021cce3452f 100644 --- a/docs/jax.experimental.sparse.rst +++ b/docs/jax.experimental.sparse.rst @@ -1,12 +1,6 @@ ``jax.experimental.sparse`` module ================================== -.. note:: - - The methods in ``jax.experimental.sparse`` are experimental reference - implementations, and not recommended for use in performance-critical - applications. - .. automodule:: jax.experimental.sparse .. currentmodule:: jax.experimental.sparse diff --git a/docs/jax.extend.backend.rst b/docs/jax.extend.backend.rst new file mode 100644 index 000000000000..facd151dc29d --- /dev/null +++ b/docs/jax.extend.backend.rst @@ -0,0 +1,17 @@ +``jax.extend.backend`` module +============================= + +.. automodule:: jax.extend.backend + +.. autosummary:: + :toctree: _autosummary + + backends + backend_xla_version + clear_backends + get_backend + get_compile_options + get_default_device + ifrt_proxy + register_backend_cache + register_backend_factory diff --git a/docs/jax.extend.core.rst b/docs/jax.extend.core.rst index 5f3ff0558af6..675be9778c61 100644 --- a/docs/jax.extend.core.rst +++ b/docs/jax.extend.core.rst @@ -16,3 +16,5 @@ array_types jaxpr_as_fun primitives + mapped_aval + unmapped_aval diff --git a/docs/jax.extend.linear_util.rst b/docs/jax.extend.linear_util.rst index f48df024e0e7..8f70a9b8d689 100644 --- a/docs/jax.extend.linear_util.rst +++ b/docs/jax.extend.linear_util.rst @@ -6,10 +6,13 @@ .. autosummary:: :toctree: _autosummary + Callable StoreException WrappedFun cache merge_linear_aux transformation + transformation2 transformation_with_aux + transformation_with_aux2 wrap_init diff --git a/docs/jax.extend.mlir.rst b/docs/jax.extend.mlir.rst index 006e5d30682a..e65daa7278bd 100644 --- a/docs/jax.extend.mlir.rst +++ b/docs/jax.extend.mlir.rst @@ -6,6 +6,10 @@ .. autosummary:: :toctree: _autosummary + deserialize_portable_artifact dialects + hlo_to_stablehlo ir passmanager + refine_polymorphic_shapes + serialize_portable_artifact \ No newline at end of file diff --git a/docs/jax.extend.random.rst b/docs/jax.extend.random.rst index c14730e5885a..a75784033b8d 100644 --- a/docs/jax.extend.random.rst +++ b/docs/jax.extend.random.rst @@ -11,5 +11,6 @@ threefry2x32_p threefry_2x32 threefry_prng_impl + random_seed rbg_prng_impl unsafe_rbg_prng_impl diff --git a/docs/jax.extend.rst b/docs/jax.extend.rst index 3fb2b9d830c0..4308182f6108 100644 --- a/docs/jax.extend.rst +++ b/docs/jax.extend.rst @@ -11,6 +11,7 @@ Modules .. toctree:: :maxdepth: 1 + jax.extend.backend jax.extend.core jax.extend.linear_util jax.extend.mlir diff --git a/docs/jax.ffi.rst b/docs/jax.ffi.rst index dc2c6f8ac873..063658e97bd9 100644 --- a/docs/jax.ffi.rst +++ b/docs/jax.ffi.rst @@ -10,22 +10,4 @@ ffi_lowering pycapsule register_ffi_target - register_ffi_type_id - - -``jax.extend.ffi`` module (deprecated) -====================================== - -The ``jax.extend.ffi`` module has been moved to ``jax.ffi``, and that import -path should be used instead, but these functions remain documented here while -the legacy import is being deprecated. - -.. automodule:: jax.extend.ffi - -.. autosummary:: - :toctree: _autosummary - - ffi_call - ffi_lowering - pycapsule - register_ffi_target + register_ffi_type diff --git a/docs/jax.lax.rst b/docs/jax.lax.rst index 9db79f591a4e..4eea7a70fdbb 100644 --- a/docs/jax.lax.rst +++ b/docs/jax.lax.rst @@ -9,7 +9,7 @@ are typically defined as transformations on :mod:`jax.lax` primitives. Many of the primitives are thin wrappers around equivalent XLA operations, described by the `XLA operation semantics -`_ documentation. In a few +`_ documentation. In a few cases JAX diverges from XLA, usually to ensure that the set of operations is closed under the operation of JVP and transpose rules. @@ -86,6 +86,7 @@ Operators dynamic_update_index_in_dim dynamic_update_slice dynamic_update_slice_in_dim + empty eq erf erfc @@ -102,6 +103,7 @@ Operators ge gt igamma + igamma_grad_a igammac imag index_in_dim @@ -128,6 +130,9 @@ Operators population_count pow random_gamma_grad + ragged_all_to_all + ragged_dot + ragged_dot_general real reciprocal reduce @@ -153,6 +158,7 @@ Operators scatter_max scatter_min scatter_mul + scatter_sub shift_left shift_right_arithmetic shift_right_logical @@ -170,9 +176,9 @@ Operators sub tan tanh + tile top_k transpose - zeros_like_array zeta .. _lax-control-flow: @@ -222,6 +228,9 @@ Parallel operators pshuffle pswapaxes axis_index + axis_size + psend + precv Sharding-related operators -------------------------- @@ -256,12 +265,22 @@ Linear algebra operators (jax.lax.linalg) tridiagonal tridiagonal_solve +.. autoclass:: EigImplementation + :members: + :undoc-members: +.. autoclass:: EighImplementation + :members: + :undoc-members: + Argument classes ---------------- .. currentmodule:: jax.lax +.. autoclass:: AccuracyMode + :members: + :undoc-members: .. autoclass:: ConvDimensionNumbers .. autoclass:: ConvGeneralDilatedDimensionNumbers .. autoclass:: DotAlgorithm @@ -269,12 +288,14 @@ Argument classes :members: :undoc-members: :member-order: bysource +.. autoclass:: DotDimensionNumbers .. autoclass:: FftType :members: .. autoclass:: GatherDimensionNumbers .. autoclass:: GatherScatterMode .. autoclass:: Precision .. autoclass:: PrecisionLike +.. autoclass:: RaggedDotDimensionNumbers .. autoclass:: RandomAlgorithm :members: :member-order: bysource @@ -282,3 +303,4 @@ Argument classes :members: :member-order: bysource .. autoclass:: ScatterDimensionNumbers +.. autoclass:: Tolerance diff --git a/docs/jax.lib.rst b/docs/jax.lib.rst deleted file mode 100644 index 66c9ed9d5c91..000000000000 --- a/docs/jax.lib.rst +++ /dev/null @@ -1,25 +0,0 @@ -``jax.lib`` module -================== -The `jax.lib` package is a set of internal tools and types for bridging between -JAX's Python frontend and its XLA backend. - -jax.lib.xla_bridge ------------------- - -.. currentmodule:: jax.lib.xla_bridge - -.. autosummary:: - :toctree: _autosummary - - get_backend - get_compile_options - -jax.lib.xla_client ------------------- - -.. currentmodule:: jaxlib.xla_client - -.. autosummary:: - :toctree: _autosummary - - register_custom_call_target diff --git a/docs/jax.nn.initializers.rst b/docs/jax.nn.initializers.rst index 246e0cdbe9a1..3b168749f29f 100644 --- a/docs/jax.nn.initializers.rst +++ b/docs/jax.nn.initializers.rst @@ -26,6 +26,8 @@ data type ``dtype``. Argument ``key`` is a PRNG key (e.g. from glorot_uniform he_normal he_uniform + kaiming_normal + kaiming_uniform lecun_normal lecun_uniform normal @@ -34,4 +36,7 @@ data type ``dtype``. Argument ``key`` is a PRNG key (e.g. from truncated_normal uniform variance_scaling + xavier_normal + xavier_uniform zeros + Initializer diff --git a/docs/jax.nn.rst b/docs/jax.nn.rst index adb13f89903d..3271d44eb7db 100644 --- a/docs/jax.nn.rst +++ b/docs/jax.nn.rst @@ -33,6 +33,7 @@ Activation functions hard_silu hard_swish hard_tanh + tanh elu celu selu @@ -40,6 +41,7 @@ Activation functions glu squareplus mish + identity Other functions --------------- @@ -49,7 +51,12 @@ Other functions softmax log_softmax + logmeanexp logsumexp standardize one_hot dot_product_attention + scaled_matmul + get_scaled_dot_general_config + scaled_dot_general + log1mexp diff --git a/docs/jax.ref.rst b/docs/jax.ref.rst new file mode 100644 index 000000000000..d22e6314a8d2 --- /dev/null +++ b/docs/jax.ref.rst @@ -0,0 +1,21 @@ +``jax.ref`` module +================== + +.. automodule:: jax.ref + +:mod:`jax.ref` has the API for working with :code:`ArrayRef`. + +API +--- + +.. autosummary:: + :toctree: _autosummary + + AbstractRef + Ref + freeze + get + new_ref + set + swap + addupdate diff --git a/docs/jax.rst b/docs/jax.rst index 98cd464cda15..93a996ec1432 100644 --- a/docs/jax.rst +++ b/docs/jax.rst @@ -1,7 +1,7 @@ .. currentmodule:: jax -Public API: ``jax`` package -=========================== +API Reference +============= Subpackages ----------- @@ -14,6 +14,7 @@ Subpackages jax.lax jax.random jax.sharding + jax.ad_checkpoint jax.debug jax.dlpack jax.distributed @@ -24,6 +25,7 @@ Subpackages jax.nn jax.ops jax.profiler + jax.ref jax.stages jax.test_util jax.tree @@ -34,11 +36,6 @@ Subpackages jax.example_libraries jax.experimental -.. toctree:: - :hidden: - - jax.lib - Configuration ------------- @@ -56,7 +53,9 @@ Configuration enable_checks enable_custom_prng enable_custom_vjp_by_custom_transpose + enable_x64 log_compiles + no_tracing numpy_rank_promotion transfer_guard @@ -82,6 +81,7 @@ Just-in-time compilation (:code:`jit`) block_until_ready copy_to_host_async make_mesh + set_mesh .. _jax-grad: @@ -105,6 +105,32 @@ Automatic differentiation closure_convert checkpoint +Vectorization +------------- + +.. autosummary:: + :toctree: _autosummary + + vmap + numpy.vectorize + +Parallelization +--------------- + +.. autosummary:: + :toctree: _autosummary + + shard_map + smap + pmap + devices + local_devices + process_index + device_count + local_device_count + process_count + process_indices + Customization ------------- @@ -216,30 +242,6 @@ Array properties and methods Array.T Array.mT -Vectorization (:code:`vmap`) ----------------------------- - -.. autosummary:: - :toctree: _autosummary - - vmap - numpy.vectorize - -Parallelization (:code:`pmap`) ------------------------------- - -.. autosummary:: - :toctree: _autosummary - - pmap - devices - local_devices - process_index - device_count - local_device_count - process_count - process_indices - Callbacks --------- @@ -261,3 +263,24 @@ Miscellaneous print_environment_info live_arrays clear_caches + typeof + +.. _checkpoint-policies: + +Checkpoint policies +------------------- + +.. autosummary:: + :toctree: _autosummary + + checkpoint_policies.everything_saveable + checkpoint_policies.nothing_saveable + checkpoint_policies.dots_saveable + checkpoint_policies.checkpoint_dots + checkpoint_policies.dots_with_no_batch_dims_saveable + checkpoint_policies.checkpoint_dots_with_no_batch_dims + checkpoint_policies.save_any_names_but_these + checkpoint_policies.save_only_these_names + checkpoint_policies.offload_dot_with_no_batch_dims + checkpoint_policies.save_and_offload_only_these_names + checkpoint_policies.save_from_both_policies diff --git a/docs/jax.scipy.rst b/docs/jax.scipy.rst index dcbb673997ad..6cf14389adcd 100644 --- a/docs/jax.scipy.rst +++ b/docs/jax.scipy.rst @@ -69,11 +69,13 @@ jax.scipy.linalg lu lu_factor lu_solve + pascal polar qr rsf2csf schur solve + solve_sylvester solve_triangular sqrtm svd @@ -171,6 +173,7 @@ jax.scipy.special gammaln gammasgn hyp1f1 + hyp2f1 i0 i0e i1 @@ -188,6 +191,7 @@ jax.scipy.special poch polygamma rel_entr + sici softmax spence sph_harm @@ -322,6 +326,34 @@ jax.scipy.stats.gamma sf logsf +jax.scipy.stats.gumbel_l +~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: jax.scipy.stats.gumbel_l +.. autosummary:: + :toctree: _autosummary + + logpdf + pdf + cdf + logcdf + sf + logsf + ppf + +jax.scipy.stats.gumbel_r +~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: jax.scipy.stats.gumbel_r +.. autosummary:: + :toctree: _autosummary + + logpdf + pdf + cdf + logcdf + sf + logsf + ppf + jax.scipy.stats.gennorm ~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: jax.scipy.stats.gennorm @@ -414,8 +446,13 @@ jax.scipy.stats.pareto .. autosummary:: :toctree: _autosummary + logcdf logpdf + logsf + cdf pdf + ppf + sf jax.scipy.stats.poisson ~~~~~~~~~~~~~~~~~~~~~~~ @@ -426,6 +463,7 @@ jax.scipy.stats.poisson logpmf pmf cdf + entropy jax.scipy.stats.t ~~~~~~~~~~~~~~~~~ diff --git a/docs/jax.sharding.rst b/docs/jax.sharding.rst index 954f62b8a52d..0146398cb12d 100644 --- a/docs/jax.sharding.rst +++ b/docs/jax.sharding.rst @@ -16,15 +16,6 @@ Classes .. autoclass:: NamedSharding :members: :show-inheritance: -.. autoclass:: PositionalSharding - :members: - :show-inheritance: -.. autoclass:: PmapSharding - :members: - :show-inheritance: -.. autoclass:: GSPMDSharding - :members: - :show-inheritance: .. autoclass:: PartitionSpec :members: .. autoclass:: Mesh diff --git a/docs/jax.tree.rst b/docs/jax.tree.rst index e65c77c757c1..7bb4c3e557c5 100644 --- a/docs/jax.tree.rst +++ b/docs/jax.tree.rst @@ -12,6 +12,7 @@ List of Functions :toctree: _autosummary all + broadcast flatten flatten_with_path leaves @@ -19,6 +20,7 @@ List of Functions map map_with_path reduce + reduce_associative structure transpose unflatten diff --git a/docs/jax.tree_util.rst b/docs/jax.tree_util.rst index 73fd1f376e9f..664fa0f3ce9e 100644 --- a/docs/jax.tree_util.rst +++ b/docs/jax.tree_util.rst @@ -13,7 +13,6 @@ List of Functions Partial all_leaves - build_tree register_dataclass register_pytree_node register_pytree_node_class @@ -38,10 +37,12 @@ These APIs are now accessed via :mod:`jax.tree`. :toctree: _autosummary tree_all + tree_broadcast tree_flatten tree_leaves tree_map tree_reduce + tree_reduce_associative tree_structure tree_transpose tree_unflatten diff --git a/docs/jax_array_migration.md b/docs/jax_array_migration.md index 95d4a632a295..3cc1629b2068 100644 --- a/docs/jax_array_migration.md +++ b/docs/jax_array_migration.md @@ -1,3 +1,6 @@ +--- +orphan: true +--- (jax-array-migration)= # jax.Array migration @@ -24,7 +27,7 @@ the unified jax.Array After the migration is complete `jax.Array` will be the only type of array in JAX. -This doc explains how to migrate existing codebases to `jax.Array`. For more information on using `jax.Array` and JAX parallelism APIs, see the [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) tutorial. +This doc explains how to migrate existing codebases to `jax.Array`. For more information on using `jax.Array` and JAX parallelism APIs, see the [Distributed arrays and automatic parallelization](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) tutorial. ### How to enable jax.Array? diff --git a/docs/jep/10657-sequencing-effects.md b/docs/jep/10657-sequencing-effects.md index 5f7eb0da4c04..ac3024519101 100644 --- a/docs/jep/10657-sequencing-effects.md +++ b/docs/jep/10657-sequencing-effects.md @@ -47,7 +47,7 @@ g() In many cases, JAX will execute `f` and `g` *in parallel*, dispatching the computations onto different threads -- `g` might actually be executed before `f`. Parallel execution is a nice performance optimization, especially if copying -to and from a device is expensive (see the [asynchronous dispatch note](https://jax.readthedocs.io/en/latest/async_dispatch.html) for more details). +to and from a device is expensive (see the [asynchronous dispatch note](https://docs.jax.dev/en/latest/async_dispatch.html) for more details). In practice, however, we often don't need to think about asynchronous dispatch because we're writing pure functions and only care about the inputs and outputs of functions -- we'll naturally block on future diff --git a/docs/jep/11830-new-remat-checkpoint.md b/docs/jep/11830-new-remat-checkpoint.md index 019188349257..5c3657f666b9 100644 --- a/docs/jep/11830-new-remat-checkpoint.md +++ b/docs/jep/11830-new-remat-checkpoint.md @@ -1,5 +1,7 @@ # `jax.remat` / `jax.checkpoint` changes: what you need to know +This document discusses changes made to `jax.checkpoint` (a.k.a. `jax.remat`) +that were finalized in JAX v0.3.17, released in August 2022. ## Contents diff --git a/docs/jep/12049-type-annotations.md b/docs/jep/12049-type-annotations.md index 7a20958c5cab..bf6123b2bc7f 100644 --- a/docs/jep/12049-type-annotations.md +++ b/docs/jep/12049-type-annotations.md @@ -35,7 +35,7 @@ def slice(operand: Array, start_indices: Sequence[int], For the purposes of static type checking, this use of `Array = Any` for array type annotations puts no constraint on the argument values (`Any` is equivalent to no annotation at all), but it does serve as a form of useful in-code documentation for the developer. -For the sake of generated documentation, the name of the alias gets lost (the [HTML docs](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.slice.html) for `jax.lax.slice` report operand as type `Any`), so the documentation benefit does not go beyond the source code (though we could enable some `sphinx-autodoc` options to improve this: See [autodoc_type_aliases](https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#confval-autodoc_type_aliases)). +For the sake of generated documentation, the name of the alias gets lost (the [HTML docs](https://docs.jax.dev/en/latest/_autosummary/jax.lax.slice.html) for `jax.lax.slice` report operand as type `Any`), so the documentation benefit does not go beyond the source code (though we could enable some `sphinx-autodoc` options to improve this: See [autodoc_type_aliases](https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#confval-autodoc_type_aliases)). A benefit of this level of type annotation is that it is never wrong to annotate a value with `Any`, so it will provide a concrete benefit to developers and users in the form of documentation, without added complexity of satisfying the stricter needs of any particular static type checker. @@ -122,7 +122,7 @@ All told, the array-type-granularity challenge is less of an issue than the othe ### Challenge 5: imprecise APIs inherited from NumPy A large part of JAX’s user-facing API is inherited from NumPy within the {mod}`jax.numpy` submodule. -NumPy’s API was developed years before static type checking became part of the Python language, and follows Python’s historic recommendations to use a [duck-typing](https://docs.python.org/3/glossary.html#term-duck-typing)/[EAFP](https://docs.python.org/3/glossary.html#term-eafp) coding style, in which strict type-checking at runtime is discouraged. As a concrete example of this, consider the {func}`numpy.tile` function, which is defined like this: +NumPy’s API was developed years before static type checking became part of the Python language, and follows Python’s historic recommendations to use a [duck-typing](https://docs.python.org/3/glossary.html#term-duck-typing)/[EAFP](https://docs.python.org/3/glossary.html#term-EAFP) coding style, in which strict type-checking at runtime is discouraged. As a concrete example of this, consider the {func}`numpy.tile` function, which is defined like this: ```python def tile(A, reps): diff --git a/docs/jep/14273-shard-map.md b/docs/jep/14273-shard-map.md index 63742bc852c6..fa6681551d17 100644 --- a/docs/jep/14273-shard-map.md +++ b/docs/jep/14273-shard-map.md @@ -4,7 +4,7 @@ *January 2023* **This was the design doc proposing `shard_map`. You may instead want -[the up-to-date user docs](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html).** +[the up-to-date user docs](https://docs.jax.dev/en/latest/notebooks/shard_map.html).** ## Motivation @@ -18,7 +18,7 @@ We need great APIs for both, and rather than being mutually exclusive alternatives, they need to compose with each other. With `pjit` (now just `jit`) we have [a next-gen -API](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) +API](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) for the first school. But we haven't quite leveled-up the second school. `pmap` follows the second school, but over time we found it has [fatal flaws](#why-dont-pmap-or-xmap-already-solve-this). `xmap` solved those flaws, diff --git a/docs/jep/15856-jex.md b/docs/jep/15856-jex.md index a5625abf8930..a821405c399e 100644 --- a/docs/jep/15856-jex.md +++ b/docs/jep/15856-jex.md @@ -14,13 +14,13 @@ import jax.extend as jex Several projects depend on JAX's codebase internals, often to use its core machinery (e.g. to write a -[transformation over its IR](https://jax.readthedocs.io/en/latest/notebooks/Writing_custom_interpreters_in_Jax.html)) +[transformation over its IR](https://docs.jax.dev/en/latest/notebooks/Writing_custom_interpreters_in_Jax.html)) or to extend it (e.g. to [define new primitives](https://github.com/dfm/extending-jax)). Two challenges for these dependencies are (a) that our internals aren't all solidly designed for external use, and (b) that circumventing JAX's public API is -[unsupported](https://jax.readthedocs.io/en/latest/api_compatibility.html). +[unsupported](https://docs.jax.dev/en/latest/api_compatibility.html). In other words, our internals are often used like a library, but are neither structured nor updated like one. @@ -50,12 +50,12 @@ removed altogether. To keep development overhead low, `jax.extend` would not follow the public -[API compatibility](https://jax.readthedocs.io/en/latest/api_compatibility.html) +[API compatibility](https://docs.jax.dev/en/latest/api_compatibility.html) policy. It would promise no deprecation windows nor backwards compatibility between releases. Every release may break existing callers without simple recourse (e.g. without a flag reintroducing prior behavior). We would rely on the -[changelog](https://jax.readthedocs.io/en/latest/changelog.html) +[changelog](https://docs.jax.dev/en/latest/changelog.html) to call out such changes. Callers of `jax.extend` that need to upgrade their code regularly @@ -108,7 +108,7 @@ to process the Jaxpr IR (the output of At initialization, this module will contain many more symbols than what's needed to define primitives and rules, including various names used in setting up -["final-style transformations"](https://jax.readthedocs.io/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing), +["final-style transformations"](https://docs.jax.dev/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing), such as the current `jax._src.core.Trace` and `Tracer` classes. We can revisit whether `jex.core` should also support final-style extensions alongside initial style approaches, and whether it can do so by a more @@ -137,7 +137,7 @@ tracer types from `jex`. This module plus `jex.core` ought to suffice for replicating today's custom primitive tutorials (e.g. -[ours](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html) +[ours](https://docs.jax.dev/en/latest/notebooks/How_JAX_primitives_work.html) and [dfm's](https://github.com/dfm/extending-jax)). For instance, defining a primitive and its behavior under `jax.jit` @@ -184,6 +184,6 @@ arrays. We have only one item in mind for now. The XLA compiler's array sharding format is more expressive than [those provided by -JAX](https://jax.readthedocs.io/en/latest/jax.sharding.html). We could +JAX](https://docs.jax.dev/en/latest/jax.sharding.html). We could provide this as `jex.sharding.XlaOpShardingProto`, corresponding to today's `jax._src.lib.xla_client.OpSharding` internally. diff --git a/docs/jep/17111-shmap-transpose.md b/docs/jep/17111-shmap-transpose.md index 2fdf5f822835..00d8a3f383fd 100644 --- a/docs/jep/17111-shmap-transpose.md +++ b/docs/jep/17111-shmap-transpose.md @@ -497,7 +497,7 @@ of every function instance along which the outputs are mapped, whereas for mesh axes over which the output is unmapped only one copy of the value is used. See [the `shmap` -JEP](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html) for examples +JEP](https://docs.jax.dev/en/latest/jep/14273-shard-map.html) for examples of unmapped inputs and outputs. For comparison, in `vmap` unmapped inputs/outputs are indicated by using `in_axes` / `out_axes` of `None` (rather than an `int`). diff --git a/docs/jep/2026-custom-derivatives.md b/docs/jep/2026-custom-derivatives.md index ce149fa6fb35..b09926425667 100644 --- a/docs/jep/2026-custom-derivatives.md +++ b/docs/jep/2026-custom-derivatives.md @@ -2,7 +2,7 @@ This is a design document, explaining some of the thinking behind the design and implementation of `jax.custom_jvp` and `jax.custom_vjp`. For user-oriented -documentation, see [the tutorial notebook](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html). +documentation, see [the tutorial notebook](https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html). There are two ways to define differentiation rules in JAX: 1. using `jax.custom_jvp` and `jax.custom_vjp` to define custom differentiation diff --git a/docs/jep/263-prng.md b/docs/jep/263-prng.md index 7ef10ae0e9c4..ff58d1b7b94e 100644 --- a/docs/jep/263-prng.md +++ b/docs/jep/263-prng.md @@ -12,7 +12,7 @@ We want a PRNG design that As a corollary of these we believe the design should be functional. Another corollary is that, at least given current hardware constraints, we’re going to do the PRNG in software. > TLDR -> **JAX PRNG = [Threefry counter PRNG](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf) + a functional array-oriented [splitting model](https://dl.acm.org/citation.cfm?id=2503784)** +> **JAX PRNG = [Threefry counter PRNG](https://thesalmons.org/john/random123/papers/random123sc11.pdf) + a functional array-oriented [splitting model](https://dl.acm.org/doi/10.1145/2503778.2503784)** ## Contents * [Three programming models and toy example programs](#three-programming-models-and-toy-example-programs) @@ -79,7 +79,7 @@ Explicit threading is inconvenient for the programmer. But worse, it hasn’t ac In short, making the code functional by explicitly threading state isn’t enough to achieve our expressiveness (#1) and performance (#5, #6) goals. -The key problem in both the previous models is that there’s too much sequencing. To reduce the amount of sequential dependence we use **functional [splittable](https://dl.acm.org/citation.cfm?id=2503784) PRNGs**. Splitting is a mechanism to ‘fork’ a new PRNG state into two PRNG states while maintaining the usual desirable PRNG properties (the two new streams are computationally parallelizable and produce independent random values, i.e. they behave like [multistreams](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf)). +The key problem in both the previous models is that there’s too much sequencing. To reduce the amount of sequential dependence we use **functional [splittable](https://dl.acm.org/doi/10.1145/2503778.2503784) PRNGs**. Splitting is a mechanism to ‘fork’ a new PRNG state into two PRNG states while maintaining the usual desirable PRNG properties (the two new streams are computationally parallelizable and produce independent random values, i.e. they behave like [multistreams](https://thesalmons.org/john/random123/papers/random123sc11.pdf)). ```python def foo(rng_1): @@ -105,7 +105,7 @@ The example doesn’t show it, but as a consequence of the choice (2) the only w ## Design -We can use the *counter-based PRNG* design, and in particular the Threefry hash function, as described in [Parallel random numbers: as easy as 1, 2, 3](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf). We use the counter to achieve efficient vectorization: for a given key we can generate an array of values in a vectorized fashion by mapping the hash function over a range of integers [k + 1, …, k + sample_size]. We use the key together with the hash function to implement [splittable PRNGs](https://dl.acm.org/citation.cfm?id=2503784): that is, splitting is a way to generate two new keys from an existing one. +We can use the *counter-based PRNG* design, and in particular the Threefry hash function, as described in [Parallel random numbers: as easy as 1, 2, 3](https://thesalmons.org/john/random123/papers/random123sc11.pdf). We use the counter to achieve efficient vectorization: for a given key we can generate an array of values in a vectorized fashion by mapping the hash function over a range of integers [k + 1, …, k + sample_size]. We use the key together with the hash function to implement [splittable PRNGs](https://dl.acm.org/doi/10.1145/2503778.2503784): that is, splitting is a way to generate two new keys from an existing one. ```haskell type Sample = Int256 diff --git a/docs/jep/28661-jax-array-protocol.md b/docs/jep/28661-jax-array-protocol.md new file mode 100644 index 000000000000..3e1d073c9184 --- /dev/null +++ b/docs/jep/28661-jax-array-protocol.md @@ -0,0 +1,214 @@ +# JEP 28661: Supporting the `__jax_array__` protocol + +[@jakevdp](http://github.com/jakevdp), *May 2025* + +An occasional user request is for the ability to define custom array-like objects that +work with jax APIs. JAX currently has a partial implementation of a mechanism that does +this via a `__jax_array__` method defined on the custom object. This was never intended +to be a load-bearing public API (see the discussion at {jax-issue}`#4725`), but has +become essential to packages like Keras and flax, which explicitly document the ability +to use their custom array objects with jax functions. This JEP proposes a design for +full, documented support of the `__jax_array__` protocol. + +## Levels of array extensibility +Requests for extensibility of JAX arrays come in a few flavors: + +### Level 1 Extensibility: polymorphic inputs +What I’ll call "Level 1" extensibility is the desire that JAX APIs accept polymorphic inputs. +That is, a user desires behavior like this: + +```python +class CustomArray: + data: numpy.ndarray + ... + +x = CustomArray(np.arange(5)) +result = jnp.sin(x) # Converts `x` to JAX array and returns a JAX array +``` + +Under this extensibility model, JAX functions would accept CustomArray objects as inputs, +implicitly converting them to `jax.Array` objects for the sake of computation. +This is similar to the functionality offered by NumPy via the `__array__` method, and in +JAX (in many but not all cases) via the `__jax_array__` method. + +This is the mode of extensibility that has been requested by the maintainers of `flax.nnx` +and others. The current implementation is also used by JAX internally for the case of +symbolic dimensions. + +### Level 2 extensibility: polymorphic outputs +What I’ll call "Level 2" extensibility is the desire that JAX APIs should not only accept +polymorphic inputs, but also wrap outputs to match the class of the input. +That is, a user desires behavior like this: + +```python +class CustomArray: + data: numpy.ndarray + ... + +x = CustomArray(np.arange(5)) +result = jnp.sin(x) # returns a new CustomArray +``` + +Under this extensibility model, JAX functions would not only accept custom objects +as inputs, but have some protocol to determine how to correctly re-wrap outputs with +the same class. In NumPy, this sort of functionality is offered in varying degrees by +the special `__array_ufunc__`, `__array_wrap__`, and `__array_function__` protocols, +which allow user-defined objects to customize how NumPy API functions operate on +arbitrary inputs and map input types to outputs. +JAX does not currently have any equivalent to these interfaces in NumPy. + +This is the mode of extensibility that has been requested by the maintainers of `keras`, +among others. + +### Level 3 extensibility: subclassing `Array` + +What I’ll call "Level 3" extensibility is the desire that the JAX array object itself +could be subclassable. NumPy provides some APIs that allow this +(see [Subclassing ndarray](https://numpy.org/devdocs/user/basics.subclassing.html)) but +this sort of approach would take some extra thought in JAX due to the need for +representing array objects abstractly via tracing. + +This mode of extensibility has occasionally been requested by users who want to add +special metadata to JAX arrays, such as units of measurement. + +## Synopsis + +For the sake of this proposal, we will stick with the simplest, level 1 extensibility +model. The proposed interface is the one currently non-uniformly supported by a number +of JAX APIs, the `__jax_array__` method. Its usage looks something like this: + +```python +import jax +import jax.numpy as jnp +import numpy as np + +class CustomArray: + data: np.ndarray + + def __init__(self, data: np.ndarray): + self.data = data + + def __jax_array__(self) -> jax.Array: + return jnp.asarray(self.data) + +arr = CustomArray(np.arange(5)) +result = jnp.multiply(arr, 2) +print(repr(result)) +# Array([0, 2, 4, 6, 8], dtype=int32) +``` + +We may revisit other extensibility levels in the future. + +## Design challenges + +JAX presents some interesting design challenges related to this kind of extensibility, +which have not been fully explored previously. We’ll discuss them in turn here: + +### Priority of `__jax_array__` vs. PyTree flattening +JAX already has a supported mechanism for registering custom objects, namely pytree +registration (see [Custom pytree nodes](https://docs.jax.dev/en/latest/custom_pytrees.html#pytrees-custom-pytree-nodes)). +If we also support __jax_array__, which one should take precedence? + +To put this more concretely, what should be the result of this code? + +```python +@jax.jit +def f(x): + print("is JAX array:", isinstance(x, jax.Array)) + +f(CustomArray(...)) +``` + +If we choose to prioritize `__jax_array__` at the JIT boundary, then the output of this +function would be: +``` +is JAX array: True +``` +That is, at the JIT boundary, the `CustomArray` object would be converted into a +`__jax_array__`, and its shape and dtype would be used to construct a standard JAX +tracer for the function. + +If we choose to prioritize pytree flattening at the JIT boundary, then the output of +this function would be: +``` +type(x)=CustomArray +``` +That is, at the JIT boundary, the `CustomArray` object is flattened, and then unflattened +before being passed to the JIT-compiled function for tracing. If `CustomArray` has been +registered as a pytree, it will generally contain traced arrays as its attributes, and +when x is passed to any JAX API that supports `__jax_array__`, these traced attributes +will be converted to a single traced array according to the logic specified in the method. + +There are deeper consequences here for how other transformations like vmap and grad work +when encountering custom objects: for example, if we prioritize pytree flattening, vmap +would operate over the dimensions of the flattened contents of the custom object, while +if we prioritize `__jax_array__`, vmap would operate over the converted array dimensions. + +This also has consequences when it comes to JIT invariance: consider a function like this: +```python +def f(x): + if isinstance(x, CustomArray): + return x.custom_method() + else: + # do something else + ... + +result1 = f(x) +result2 = jax.jit(f)(x) +``` +If `jit` consumes `x` via pytree flattening, the results should agree for a well-specified +flattening rule. If `jit` consumes `x` via `__jax_array__`, the results will differ because +`x` is no longer a CustomArray within the JIT-compiled version of the function. + +#### Synopsis +As of JAX v0.6.0, transformations prioritize `__jax_array__` when it is available. This status +quo can lead to confusion around lack of JIT invariance, and the current implementation in practice +leads to subtle bugs in the case of automatic differentiation, where the forward and backward pass +do not treat inputs consistently. + +Because the pytree extensibility mechanism already exists for the case of customizing +transformations, it seems most straightforward if transformations act only via this +mechanism: that is, **we propose to remove `__jax_array__` parsing during abstractification.** +This approach will preserve object identity through transformations, and give the user the +most possible flexibility. If the user wants to opt-in to array conversion semantics, that +is always possible by explicitly casting their input via jnp.asarray, which will trigger the +`__jax_array__` protocol. + +### Which APIs should support `__jax_array__`? +JAX has a number of different levels of API, from the level of explicit primitive binding +(e.g. `jax.lax.add_p.bind(x, y)`) to the `jax.lax` APIs (e.g. `jax.lax.add(x, y)`) to the +`jax.numpy` APIs (e.g. `jax.numpy.add(x, y)`). Which of these API categories should handle +implicit conversion via `__jax_array__`? + +In order to limit the scope of the change and the required testing, I propose that `__jax_array__` +only be explicitly supported in `jax.numpy` APIs: after all, it is inspired by the` __array__` +protocol which is supported by the NumPy package. We could always expand this in the future to +`jax.lax` APIs if needed. + +This is in line with the current state of the package, where `__jax_array__` handling is mainly +within the input validation utilities used by `jax.numpy` APIs. + +## Implementation +With these design choices in mind, we plan to implement this as follows: + +- **Adding runtime support to `jax.numpy`**: This is likely the easiest part, as most + `jax.numpy` functions use a common internal utility (`ensure_arraylike`) to validate + inputs and convert them to array. This utility already supports `__jax_array__`, and + so most jax.numpy APIs are already compliant. +- **Adding test coverage**: To ensure compliance across the APIs, we should add a new + test scaffold that calls every `jax.numpy` API with custom inputs and validates correct + behavior. +- **Deprecating `__jax_array__` during abstractification**: Currently JAX's abstractification + pass, used in `jit` and other transformations, does parse the `__jax_array__` protocol, + and this is not the behavior we want long-term. We need to deprecate this behavior, and + ensure that downstream packages that rely on it can move toward pytree registration or + explicit array conversion where necessary. +- **Adding type annotations**: the type interface for jax.numpy functions is in + `jax/numpy/__init__.pyi`, and we’ll need to change each input type from `ArrayLike` to + `ArrayLike | SupportsJAXArray`, where the latter is a protocol with a `__jax_array__` + method. We cannot add this directly to the `ArrayLike` definition, because `ArrayLike` + is used in contexts where `__jax_array__` should not be supported. +- **Documentation**: once the above support is added, we should add a documentation section + on array extensibility that outlines exactly what to expect regarding the `__jax_array__` + protocol, with examples of how it can be used in conjunction with pytree registration + in order to effectively work with user-defined types. diff --git a/docs/jep/28845-stateful-rng.md b/docs/jep/28845-stateful-rng.md new file mode 100644 index 000000000000..637aa6c07b93 --- /dev/null +++ b/docs/jep/28845-stateful-rng.md @@ -0,0 +1,145 @@ +(stateful-randomness-jep)= +# JEP 28845: Stateful Randomness in JAX + +[@jakevdp](http://github.com/jakevdp), *November 2025* + +This document explores the addition of an **optional** stateful pseudo-random number generator (PRNG) for use in JAX; this is meant to be used alongside the classic functional PRNGs described in {ref}`pseudorandom-numbers` in cases where statefulness is convenient. + +## Background + +JAX has always required users to explicitly manage random state as part of its functional programming paradigm (see {ref}`prng-design-jep` for background on this). Although well-motivated, this is a frequently encountered [sharp bit](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers) for new users who are accustomed to stateful pseudorandom number APIs. + +With the recent introduction of limited-scope [mutable refs](https://docs.jax.dev/en/latest/array_refs.html) in JAX, it is now possible to implement a stateful PRNG in JAX that retains most of the performance benefits of the existing functional PRNG, while providing a much more natural API for users familiar with NumPy, Pytorch, and other numerical computing libraries. + +This JAX Enhancement Proposal (or [JEP](https://docs.jax.dev/en/latest/jep/index.html)) proposes the introduction of a Stateful PRNG API into {mod}`jax.experimental.random`, with a goal of eventual inclusino into {mod}`jax.random` itself. + +## API Design + +To align with best practices developed within the larger numerical Python community, we propose for the stateful PRNG API to align with To align with NumPy’s most recent PRNG API iteration, found in {class}`numpy.random.Generator`, and typically created using the {func}`numpy.random.default_rng` function. A full draft of the proposed implementation can be found at {jax-issue}`#28845`, but here we summarize the main features of the implementation. + +A simplified version of the stateful PRNG Generator code looks like this (function and argument names follow the {mod}`numpy.random` APIs): + +```python +def stateful_rng(seed: ArrayLike) -> StatefulPRNG: + """Create a stateful PRNG Generator given an integer seed.""" + return StatefulPRNG(jax.random.key(seed), jax.new_ref(0)) + + +@tree_util.register_dataclass +@dataclass(frozen=True) +class StatefulPRNG: + """Stateful PRNG Generator class.""" + base_key: jax.Array + counter: jax.core.Ref + + def key(self) -> jax.Array: + """Generate a new jax PRNG key""" + key = jax.random.fold_in(self.base_key, self.counter[...]) + jax.ref.addupdate(self.counter, ..., 1) # increment counter + return key + + def random(self, size: Sequence[int], dtype: DType = float): + """Return random floats in the half-open interval [0, 1)""" + return random.uniform(self.key(), shape=size, dtype=dtype) + + # uniform(), normal(), integers(), and others implemented similarly. +``` + +With this implementation exposed in the {mod}`jax.experimental.random` namespace, usage is virtually identical to that of {func}`numpy.random.default_rng`: + +```python +>>> from jax.experimental.random import stateful_rng +>>> rng = stateful_rng(1701) +>>> rng.random((5,)) +Array([0.09609699, 0.26730824, 0.5619041 , 0.24421775, 0.7715055 ], dtype=float32) +>>> rng.random((5,)) # state is updated -> new random draws! +Array([0.8131045 , 0.33873856, 0.88808906, 0.96005905, 0.7616446 ], dtype=float32) + +>>> import numpy as np +>>> rng = np.random.default_rng(1701) +>>> rng.random((5,)) +array([0.4020733 , 0.30563311, 0.67668051, 0.15821208, 0.79247763]) +>>> rng.random((5,)) +array([0.09419469, 0.36753944, 0.06388928, 0.96431608, 0.35200998]) + +``` + +Because the statefulness in {class}`jax.experimental.random.StatefulPRNG` is tracked via mutable refs, the random state will correctly update even if the generator is used in transformations like {func}`jax.jit`, which typically require pure functional semantics. + +### Interaction with `vmap` and `shard_map` + +The proposed stateful RNG design is based on refs, and so under `vmap` and `shard_map` it inherits the limitations of refs. So, for example, you cannot directly use an un-mapped `rng` within a vmapped function: + +```python +rng = stateful_rng(0) + +def f(x): + return x + rng.uniform() + +jax.vmap(f)(jnp.arange(10)) +``` +```pytb +Exception: performing an addupdate operation with vmapped value on an unbatched + array reference of type Ref{int32[]}. Move the array reference to be + an argument to the vmapped function? +``` + +For this reason we need the ability to split the generator in order to pass it to mapped or sharded code. For this we add a `split` method to the `StatefulPRNG` class that looks like this: + +```python +class StatefulPRNG: + ... + + def split(self, num: int | Sequence[int]) -> StatefulPRNG: + return StatefulPRNG( + base_key=jax.random.split(self.key(), num), + counter=jnp.zeros(num, dtype=int), + ) +``` + +With this method present, the stateful rng can be explicitly split and passed to a vmapped function: + +```python +rng = jax.experimental.random.stateful_rng(0) + +def f(x, rng): + return x + rng.uniform() + +result = jax.vmap(f)(jnp.arange(5), rng.split(5)) +print(result) # [0.07174575 1.0163325 2.0435536 3.4391735 4.534091 ] +``` + +A similar approach would work for sharded computations, though `split` would likely have to grow a `sharding` argument. + +This splitting brings up the question of what to do if a user attempts to generate random numbers directly from a split generator, like `rng.split(10).uniform()`. For this we follow the precedent of classic stateless `jax.random` APIs when receiving batched keys, and raise an informative error. + +## Statistical Considerations + +In the proposed design, the random state is tracked via a base key along with an integer counter that increments each time a key is generated. We chose this approach rather than mutating the key itself in order to avoid the pitfalls of iterative splits (see INSERT_REF_HERE); in particular it means that the stateful generator will always fully explore the 32-bit or 64-bit space of keys before looping back to zero and repeating the initial key. + +## Advantages + +The main advantage of this approach is familiarity: many users are familiar with NumPy, and familiar with its stateful RNG utilities. This would let them start using JAX more directly, without the learning curve of the unfamiliar functional PRNG API. + +This does not just affect JAX users: for convenience, even JAX developers tend to context switch and use stateful NumPy APIs outside of transformations, where the functional PRNG is not necessary. This leads to confusion on the part of JAX users (see for example [this github discussion](https://github.com/jax-ml/jax/issues/30881)). Having a JAX-native stateful API would make it more convenient to always use JAX PRNGs in live demos and written tutorials. + +Another pitfall of functional PRNGs is the possibility of accidental key reuse. Users unfamiliar with the need for explicit state may use keys multiple times, inadvertently generating statistically dependent random values (see for example [this StackOverflow question](https://stackoverflow.com/q/76135488)). By encouraging new JAX users to use a stateful PRNG, we avoid this silent trap. + +Finally, the API affords the ability to call `rng.key()` in order to create a standard functional PRNG key, which can then be used in the typical functional mode: this is an easy onramp to explicitly-managed state in cases where it is warranted. + +## Limitations + +Implementing a stateful PRNG key via mutable refs comes with a few inherent limitations; in particular: + +**Sequential dependence restricts the compiler:** Programs using such keys impose an inherent sequential dependence within the program, meaning that the compiler would not have the freedom to reorder operations that depend on pseudorandom values. The pitfall in this case is silent: it would be up to the user to recognize where this may become an issue, and instead switch to batched execution modes over pre-generated sequences of keys or values. Note, however, that this sequential dependence pitfall also exists when users follow the current usage recommendations in the JAX docs: [https://docs.jax.dev/en/latest/jax.random.html\#basic-usage](https://docs.jax.dev/en/latest/jax.random.html#basic-usage). + +**Sequential dependence restricts the user:** Similarly, just as the compiler cannot reorder operations without changing the randomness, this sequential dependence also means the user cannot easily refactor code without changing the specific random draws. One potential example of this: suppose a stateful RNG is used within a neural network, and the user decides to swap an internal layer with one that has different random draws: this would consume a key and affect the random draws of all subsequent layers. + +**Incompatiblity with remat:** Because mutable refs rely on JAX’s effect system, these APIs would not be usable in places where effects are not supported. In particular, this means that in JAX’s current implementation, stateful keys would not be compatible with `remat`, which might limit their usefulness within neural network implementations. The pitfall in this case is loud: attempting to use a mutable ref within remat will lead to an explicit error. There is a possibility that a future redesign of `remat` could remove this incompatibility (see {jax-issue}`#33018` for some progress on this). + +**Refs cannot be return values:** Mutable refs cannot be present in the return values of transformed JAX functions, and the proposed stateful RNG object would inherit this limitation. This is also an explicit limitation: attempting to return a `StatefulPRNG` from a transformed function would lead to an explicit error. + +## Evaluation + +Our judgment is that the advantages of the stateful PRNG API potentially outweigh the limitations, and that we should introduce a new experimental {func}`~jax.experimental.random.stateful_rng` API in the {mod}`jax.experimental.random` module for now. +Once we get a feel for the usefulness of this, we may evenutally graduate this API to the {mod}`jax.random` module, perhaps with a `default_rng` alias in {mod}`jax.numpy.random`. diff --git a/docs/jep/4008-custom-vjp-update.md b/docs/jep/4008-custom-vjp-update.md index 1e2270e052a6..c3f2be151ef7 100644 --- a/docs/jep/4008-custom-vjp-update.md +++ b/docs/jep/4008-custom-vjp-update.md @@ -4,7 +4,7 @@ _Oct 14 2020_ This doc assumes familiarity with `jax.custom_vjp`, as described in the [Custom derivative rules for JAX-transformable Python -functions](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) +functions](https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) notebook. ## What to update diff --git a/docs/jep/4410-omnistaging.md b/docs/jep/4410-omnistaging.md index f95c15f404b6..5b4536864ac2 100644 --- a/docs/jep/4410-omnistaging.md +++ b/docs/jep/4410-omnistaging.md @@ -266,7 +266,7 @@ While tracing the function ex1 at ex1.py:4, this value became a tracer due to JA You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions. -See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information. +See https://docs.jax.dev/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information. Encountered tracer value: Tracedwith ``` diff --git a/docs/jep/9407-type-promotion.ipynb b/docs/jep/9407-type-promotion.ipynb index a1ede3177a3a..5f12877c97a9 100644 --- a/docs/jep/9407-type-promotion.ipynb +++ b/docs/jep/9407-type-promotion.ipynb @@ -12,7 +12,7 @@ "\n", "*Jake VanderPlas, December 2021*\n", "\n", - "One of the challenges faced in the design of any numerical computing library is the choice of how to handle operations between values of different types. This document outlines the thought process behind the promotion semantics used by JAX, summarized in [JAX Type Promotion Semantics](https://jax.readthedocs.io/en/latest/type_promotion.html)." + "One of the challenges faced in the design of any numerical computing library is the choice of how to handle operations between values of different types. This document outlines the thought process behind the promotion semantics used by JAX, summarized in [JAX Type Promotion Semantics](https://docs.jax.dev/en/latest/type_promotion.html)." ] }, { @@ -1335,7 +1335,7 @@ "However, these advantages comes with a few tradeoffs:\n", "\n", "- mixed float/integer promotion is very prone to precision loss: for example, `int64` (with a maximum value of $9.2 \\times 10^{18}$) can be promoted to `float16` (with a maximum value of $6.5 \\times 10^4$), meaning most representable values will become `inf`.\n", - "- as mentioned above, `f*` can no longer be thought of as a \"scalar type\", but as a different flavor of float64. In JAX's parlance, this is referred to as a [*weak type*](https://jax.readthedocs.io/en/latest/type_promotion.html#weakly-typed-values-in-jax), in that it is represented as 64-bit, but only weakly holds to this bit width in promotion with other values.\n", + "- as mentioned above, `f*` can no longer be thought of as a \"scalar type\", but as a different flavor of float64. In JAX's parlance, this is referred to as a [*weak type*](https://docs.jax.dev/en/latest/type_promotion.html#weakly-typed-values-in-jax), in that it is represented as 64-bit, but only weakly holds to this bit width in promotion with other values.\n", "\n", "Note that also, this approach still leaves the `uint64` promotion question unanswered, although it is perhaps reasonable to close the lattice by connecting `u64` to `f*`." ] @@ -1413,7 +1413,7 @@ "id": "o0-E2KWjYEXO" }, "source": [ - "The behavior resulting from this choice is summarized in [JAX Type Promotion Semantics](https://jax.readthedocs.io/en/latest/type_promotion.html). Notably, aside from the inclusion of larger unsigned types (`u16`, `u32`, `u64`) and some details about the behavior of scalar/weak types (`i*`, `f*`, `c*`), this type promotion scheme turns out to be very close to that chosen by PyTorch.\n", + "The behavior resulting from this choice is summarized in [JAX Type Promotion Semantics](https://docs.jax.dev/en/latest/type_promotion.html). Notably, aside from the inclusion of larger unsigned types (`u16`, `u32`, `u64`) and some details about the behavior of scalar/weak types (`i*`, `f*`, `c*`), this type promotion scheme turns out to be very close to that chosen by PyTorch.\n", "\n", "For those interested, the appendix below prints the full promotion tables used by NumPy, Tensorflow, PyTorch, and JAX." ] @@ -2883,7 +2883,7 @@ "source": [ "### JAX Type Promotion: `jax.numpy`\n", "\n", - "`jax.numpy` follows type promotion rules laid out at https://jax.readthedocs.io/en/latest/type_promotion.html. Here we use `i*`, `f*`, `c*` to indicate both Python scalars and weakly-typed arrays." + "`jax.numpy` follows type promotion rules laid out at https://docs.jax.dev/en/latest/type_promotion.html. Here we use `i*`, `f*`, `c*` to indicate both Python scalars and weakly-typed arrays." ] }, { diff --git a/docs/jep/9407-type-promotion.md b/docs/jep/9407-type-promotion.md index ff67a8c21399..c047d76c1b18 100644 --- a/docs/jep/9407-type-promotion.md +++ b/docs/jep/9407-type-promotion.md @@ -20,7 +20,7 @@ kernelspec: *Jake VanderPlas, December 2021* -One of the challenges faced in the design of any numerical computing library is the choice of how to handle operations between values of different types. This document outlines the thought process behind the promotion semantics used by JAX, summarized in [JAX Type Promotion Semantics](https://jax.readthedocs.io/en/latest/type_promotion.html). +One of the challenges faced in the design of any numerical computing library is the choice of how to handle operations between values of different types. This document outlines the thought process behind the promotion semantics used by JAX, summarized in [JAX Type Promotion Semantics](https://docs.jax.dev/en/latest/type_promotion.html). +++ {"id": "Rod6OOyUVbQ8"} @@ -680,7 +680,7 @@ This is important because `f16` and `bf16` are not comparable because they utili However, these advantages comes with a few tradeoffs: - mixed float/integer promotion is very prone to precision loss: for example, `int64` (with a maximum value of $9.2 \times 10^{18}$) can be promoted to `float16` (with a maximum value of $6.5 \times 10^4$), meaning most representable values will become `inf`. -- as mentioned above, `f*` can no longer be thought of as a "scalar type", but as a different flavor of float64. In JAX's parlance, this is referred to as a [*weak type*](https://jax.readthedocs.io/en/latest/type_promotion.html#weakly-typed-values-in-jax), in that it is represented as 64-bit, but only weakly holds to this bit width in promotion with other values. +- as mentioned above, `f*` can no longer be thought of as a "scalar type", but as a different flavor of float64. In JAX's parlance, this is referred to as a [*weak type*](https://docs.jax.dev/en/latest/type_promotion.html#weakly-typed-values-in-jax), in that it is represented as 64-bit, but only weakly holds to this bit width in promotion with other values. Note that also, this approach still leaves the `uint64` promotion question unanswered, although it is perhaps reasonable to close the lattice by connecting `u64` to `f*`. @@ -730,7 +730,7 @@ nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos +++ {"id": "o0-E2KWjYEXO"} -The behavior resulting from this choice is summarized in [JAX Type Promotion Semantics](https://jax.readthedocs.io/en/latest/type_promotion.html). Notably, aside from the inclusion of larger unsigned types (`u16`, `u32`, `u64`) and some details about the behavior of scalar/weak types (`i*`, `f*`, `c*`), this type promotion scheme turns out to be very close to that chosen by PyTorch. +The behavior resulting from this choice is summarized in [JAX Type Promotion Semantics](https://docs.jax.dev/en/latest/type_promotion.html). Notably, aside from the inclusion of larger unsigned types (`u16`, `u32`, `u64`) and some details about the behavior of scalar/weak types (`i*`, `f*`, `c*`), this type promotion scheme turns out to be very close to that chosen by PyTorch. For those interested, the appendix below prints the full promotion tables used by NumPy, Tensorflow, PyTorch, and JAX. @@ -900,7 +900,7 @@ display.HTML(table.to_html()) ### JAX Type Promotion: `jax.numpy` -`jax.numpy` follows type promotion rules laid out at https://jax.readthedocs.io/en/latest/type_promotion.html. Here we use `i*`, `f*`, `c*` to indicate both Python scalars and weakly-typed arrays. +`jax.numpy` follows type promotion rules laid out at https://docs.jax.dev/en/latest/type_promotion.html. Here we use `i*`, `f*`, `c*` to indicate both Python scalars and weakly-typed arrays. ```{code-cell} :cellView: form diff --git a/docs/jep/9419-jax-versioning.md b/docs/jep/9419-jax-versioning.md index b964aa2af45d..85b95257ebae 100644 --- a/docs/jep/9419-jax-versioning.md +++ b/docs/jep/9419-jax-versioning.md @@ -167,16 +167,16 @@ We maintain an additional version number (`_version`) in [`xla_client.py` in the XLA repository](https://github.com/openxla/xla/blob/main/xla/python/xla_client.py). The idea is that this version number, is defined in `xla/python` together with the C++ parts of JAX, is also accessible to JAX Python as -`jax._src.lib.xla_extension_version`, and must +`jax._src.lib.jaxlib_extension_version`, and must be incremented every time that a change is made to the XLA/Python code that has backwards compatibility implications for `jax`. The JAX Python code can then use this version number to maintain backwards compatibility, e.g.: ``` -from jax._src.lib import xla_extension_version +from jax._src.lib import jaxlib_extension_version # 123 is the new version number for _version in xla_client.py -if xla_extension_version >= 123: +if jaxlib_extension_version >= 123: # Use new code path ... else: diff --git a/docs/jep/index.rst b/docs/jep/index.rst index 1c4ecbb3411f..cc104fc1d671 100644 --- a/docs/jep/index.rst +++ b/docs/jep/index.rst @@ -52,6 +52,8 @@ Then create a pull request that adds a file named 17111: Efficient transposition of `shard_map` (and other maps) <17111-shmap-transpose> 18137: Scope of JAX NumPy & SciPy Wrappers <18137-numpy-scipy-scope> 25516: Effort-based versioning <25516-effver> + 28661: Supporting the `__jax_array__` protocol <28661-jax-array-protocol> + 28845: Stateful Randomness in JAX <28845-stateful-rng> Several early JEPs were converted in hindsight from other documentation, diff --git a/docs/jit-compilation.md b/docs/jit-compilation.md index 5e5be308068a..228404477c84 100644 --- a/docs/jit-compilation.md +++ b/docs/jit-compilation.md @@ -55,9 +55,9 @@ The {ref}`jax-internals-jaxpr` section of the documentation provides more inform Importantly, notice that the jaxpr does not capture the side-effect present in the function: there is nothing in it corresponding to `global_list.append(x)`. This is a feature, not a bug: JAX transformations are designed to understand side-effect-free (a.k.a. functionally pure) code. -If *pure function* and *side-effect* are unfamiliar terms, this is explained in a little more detail in [🔪 JAX - The Sharp Bits 🔪: Pure Functions](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions). +If *pure function* and *side-effect* are unfamiliar terms, this is explained in a little more detail in [🔪 JAX - The Sharp Bits 🔪: Pure Functions](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions). -Impure functions are dangerous because under JAX transformations they are likely not to behave as intended; they might fail silently, or produce surprising downstream errors like leaked Tracers. +Impure functions are dangerous because under JAX transformations they are likely not to behave as intended; they might fail silently, or produce surprising downstream errors like leaked [Tracers](key-concepts-tracing). Moreover, JAX often can't detect when side effects are present. (If you want debug printing, use {func}`jax.debug.print`. To express general side-effects at the cost of performance, see {func}`jax.experimental.io_callback`. To check for tracer leaks at the cost of performance, use with {func}`jax.check_tracer_leaks`). diff --git a/docs/key-concepts.md b/docs/key-concepts.md index 91f0c953462e..9354cfe51e3d 100644 --- a/docs/key-concepts.md +++ b/docs/key-concepts.md @@ -19,48 +19,6 @@ kernelspec: This section briefly introduces some key concepts of the JAX package. -(key-concepts-jax-arrays)= -## JAX arrays ({class}`jax.Array`) - -The default array implementation in JAX is {class}`jax.Array`. In many ways it is similar to -the {class}`numpy.ndarray` type that you may be familiar with from the NumPy package, but it -has some important differences. - -### Array creation - -We typically don't call the {class}`jax.Array` constructor directly, but rather create arrays via JAX API functions. -For example, {mod}`jax.numpy` provides familiar NumPy-style array construction functionality -such as {func}`jax.numpy.zeros`, {func}`jax.numpy.linspace`, {func}`jax.numpy.arange`, etc. - -```{code-cell} -import jax -import jax.numpy as jnp - -x = jnp.arange(5) -isinstance(x, jax.Array) -``` - -If you use Python type annotations in your code, {class}`jax.Array` is the appropriate -annotation for jax array objects (see {mod}`jax.typing` for more discussion). - -### Array devices and sharding - -JAX Array objects have a `devices` method that lets you inspect where the contents of the array are stored. In the simplest cases, this will be a single CPU device: - -```{code-cell} -x.devices() -``` - -In general, an array may be *sharded* across multiple devices, in a manner that can be inspected via the `sharding` attribute: - -```{code-cell} -x.sharding -``` - -Here the array is on a single device, but in general a JAX array can be -sharded across multiple devices, or even multiple hosts. -To read more about sharded arrays and parallel computation, refer to {ref}`sharded-computation` - (key-concepts-transformations)= ## Transformations Along with functions to operate on arrays, JAX includes a number of @@ -74,6 +32,9 @@ as well as several others. Transformations accept a function as an argument, and new transformed function. For example, here's how you might JIT-compile a simple SELU function: ```{code-cell} +import jax +import jax.numpy as jnp + def selu(x, alpha=1.67, lambda_=1.05): return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha) @@ -89,9 +50,6 @@ def selu(x, alpha=1.67, lambda_=1.05): return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha) ``` -Transformations like {func}`~jax.jit`, {func}`~jax.vmap`, {func}`~jax.grad`, and others are -key to using JAX effectively, and we'll cover them in detail in later sections. - (key-concepts-tracing)= ## Tracing @@ -118,6 +76,12 @@ by the function before those operations are actually executed: transformations l {func}`~jax.jit`, {func}`~jax.vmap`, and {func}`~jax.grad` can then map this sequence of input operations to a transformed sequence of operations. +**Static vs traced operations**: Just as values can be either static or traced, +operations can be static or traced. Static operations are evaluated at compile-time +in Python; traced operations are compiled & evaluated at run-time in XLA. + +For more details, see [Tracing](tracing-tutorial). + (key-concepts-jaxprs)= ## Jaxprs @@ -190,42 +154,53 @@ in a tree. You can learn more in the {ref}`working-with-pytrees` tutorial. -(key-concepts-prngs)= -## Pseudorandom numbers +## JAX API layering: NumPy, lax & XLA + +All JAX operations are implemented in terms of operations in [XLA](https://www.openxla.org/xla/) – the Accelerated Linear Algebra compiler. If you look at the source of `jax.numpy`, you'll see that all the operations are eventually expressed in terms of functions defined in {mod}`jax.lax`. While `jax.numpy` is a high-level wrapper that provides a familiar interface, you can think of `jax.lax` as a stricter, but often more powerful, lower-level API for working with multi-dimensional arrays. -Generally, JAX strives to be compatible with NumPy, but pseudo random number generation is a notable exception. NumPy supports a method of pseudo random number generation that is based on a global `state`, which can be set using {func}`numpy.random.seed`. Global random state interacts poorly with JAX's compute model and makes it difficult to enforce reproducibility across different threads, processes, and devices. JAX instead tracks state explicitly via a random `key`: +For example, while `jax.numpy` will implicitly promote arguments to allow operations between mixed data types, `jax.lax` will not: + +```{code-cell} +import jax.numpy as jnp +jnp.add(1, 1.0) # jax.numpy API implicitly promotes mixed types. +``` ```{code-cell} -from jax import random +:tags: [raises-exception] -key = random.key(43) -print(key) +from jax import lax +lax.add(1, 1.0) # jax.lax API requires explicit type promotion. ``` -The key is effectively a stand-in for NumPy's hidden state object, but we pass it explicitly to {func}`jax.random` functions. -Importantly, random functions consume the key, but do not modify it: feeding the same key object to a random function will always result in the same sample being generated. +If using `jax.lax` directly, you'll have to do type promotion explicitly in such cases: ```{code-cell} -print(random.normal(key)) -print(random.normal(key)) +lax.add(jnp.float32(1), 1.0) ``` -**The rule of thumb is: never reuse keys (unless you want identical outputs).** +Along with this strictness, `jax.lax` also provides efficient APIs for some more general operations than are supported by NumPy. -In order to generate different and independent samples, you must {func}`~jax.random.split` the key explicitly before passing it to a random function: +For example, consider a 1D convolution, which can be expressed in NumPy this way: ```{code-cell} -for i in range(3): - new_key, subkey = random.split(key) - del key # The old key is consumed by split() -- we must never use it again. +x = jnp.array([1, 2, 1]) +y = jnp.ones(10) +jnp.convolve(x, y) +``` - val = random.normal(subkey) - del subkey # The subkey is consumed by normal(). +Under the hood, this NumPy operation is translated to a much more general convolution implemented by [`lax.conv_general_dilated`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.conv_general_dilated.html): - print(f"draw {i}: {val}") - key = new_key # new_key is safe to use in the next iteration. +```{code-cell} +from jax import lax +result = lax.conv_general_dilated( + x.reshape(1, 1, 3).astype(float), # note: explicit promotion + y.reshape(1, 1, 10), + window_strides=(1,), + padding=[(len(y) - 1, len(y) - 1)]) # equivalent of padding='full' in NumPy +result[0, 0] ``` -Note that this code is thread safe, since the local random state eliminates possible race conditions involving global state. {func}`jax.random.split` is a deterministic function that converts one `key` into several independent (in the pseudorandomness sense) keys. +This is a batched convolution operation designed to be efficient for the types of convolutions often used in deep neural nets. It requires much more boilerplate, but is far more flexible and scalable than the convolution provided by NumPy (See [Convolutions in JAX](https://docs.jax.dev/en/latest/notebooks/convolutions.html) for more detail on JAX convolutions). -For more on pseudo random numbers in JAX, see the {ref}`pseudorandom-numbers` tutorial. +At their heart, all `jax.lax` operations are Python wrappers for operations in XLA; here, for example, the convolution implementation is provided by [XLA:ConvWithGeneralPadding](https://www.openxla.org/xla/operation_semantics#convwithgeneralpadding_convolution). +Every JAX operation is eventually expressed in terms of these fundamental XLA operations, which is what enables just-in-time (JIT) compilation. diff --git a/docs/migrate_pmap.md b/docs/migrate_pmap.md new file mode 100644 index 000000000000..e0cbb83ef0f0 --- /dev/null +++ b/docs/migrate_pmap.md @@ -0,0 +1,758 @@ +--- +orphan: true +--- + +(migrate-pmap)= + +# Migrating to the new `jax.pmap` + +## What's going on? + +As of JAX 0.8.0, the default implementation of `jax.pmap` will be based on +`jax.jit` and +[`jax.shard_map`](https://docs.jax.dev/en/latest/notebooks/shard_map.html). The +new implementation is **_not_** a perfect replacement for the original and this +doc gives guidance for users who run into trouble + +This change makes `jax.pmap` integrate well with JAX shardings and simplifies +the implementation (see {doc}`jep/14273-shard-map` for more rationale). + +## Help! Fix me now! + +**IMPORTANT**: This option is not a permanent fix. Until January 15, 2026, it +will be possible to temporarily use the old version of `jax.pmap` by doing one +of the following: + +- Setting the shell environment variable `JAX_PMAP_SHMAP_MERGE` to something + false-like (e.g., 0); +- Setting the boolean flag `--jax_pmap_shmap_merge` to something false-like if + your code parses flags with `absl-py`. +- Using this statement in your main file or anywhere before you call `jax.pmap`: + ```python + import jax + jax.config.update("jax_pmap_shmap_merge", False) + ``` + +**NOTE**: Please file a [bug](https://github.com/jax-ml/jax/issues) with a +reproducer and tag [@danielsuo](https://github.com/danielsuo/) so we can resolve +it as quickly as possible under the new `jax.pmap`. + +## How can I fix my code for the new `jax.pmap`? + +Below are common errors we're collecting and suggestions for fixing them. This +is more work than setting `jax_pmap_shmap_merge=False`, but a more long-term +solution. However, we still recommend that new or important code be migrated to +`jax.shard_map`. + +### `ValueError: Received incompatible devices ...` + +#### Example + +``` +ValueError: Received incompatible devices for jitted computation. Got argument a +of allclose with shape float32[100] and device ids [0] on platform TPU and +argument b of allclose with shape float32[100] and device ids [0, 1] on platform +TPU +``` + +#### How this can happen + +- `jax.pmap` no longer silently reshards inputs, as per the behavior of + `jax.jit` and `jax.shard_map`. As a result, if inputs are sharded differently + from how your `jax.pmap` expects, it will raise. + +#### How to fix + +- Pass an appropriate `jax.NamedSharding` to `jax.device_put` to explicitly + reshard any offending inputs. +- Alternatively, redefine your `jax.pmap` with the appropriate `in_axes`, + `backend`, and / or `devices` keywords to ensure `jax.pmap`'s mesh and + expected input shardings match your operands. + +### `ValueError: The context mesh ... should match the mesh passed to shard_map` + +#### Example + +``` +ValueError: The context mesh AbstractMesh('x': 1, axis_types=(Manual,), +device_kind=TPU v3, num_cores=1) should match the mesh passed to shard_map +Mesh('y': 4, axis_types=(Auto,)) +``` + +#### How this can happen + +- This error can appear when nesting multiple `jax.pmap`s. This behavior is no + longer supported since the `jax.pmap` API would not know anything about inner + calls to `jax.pmap` and therefore not know about inner mesh axes. + +#### How to fix + +- Migrate to `jax.shard_map`. A single `jax.shard_map` can parallelize along + multiple axes of inputs, with each of those axes assigned to the relevant axes + of the device mesh. +- Alternatively, you can nest `jax.shard_map` calls or use `jax.smap`, which + makes it easier to drop into [manual + parallelism](https://docs.jax.dev/en/latest/notebooks/shard_map.html) mode one + mesh axis at a time. This approach greatly simplifies nested parallelism. + +### `JaxRuntimeError: INVALID_ARGUMENT: CopyArrays ... same size` + +#### Example + +``` +jax.errors.JaxRuntimeError: INVALID_ARGUMENT: CopyArrays only supports +destination device list of the same size as the array device lists. +``` + +#### How this can happen + +- This error can appear in a multi-host setting (i.e., + `jax.process_count() > 1`) + where users try to index into a sharded array (e.g., `x[0]`) with the + intention of grabbing what is semantically a replica. Please see + [Appendix A](#appendix-a) for more details. + +#### How to fix + +Instead of `x[0]`, use one of these approaches: + +- **Access local data directly**: Use `.addressable_shards[0].data` to get the + local shard without triggering global resharding. +- **Explicit resharding**: Use `jax.device_put(x, sharding)` with an appropriate + `NamedSharding` to explicitly control how data is distributed. + +### Using `jax.stages.Lowered` returned by `jax.pmap(f).lower(*args)` + +Because of the default call path of a `jax.stages.Lowered` object, we miss the +conversion from host-local arrays to global arrays to pass into the underlying +`jax.shard_map(f)` as well as the conversion back from global arrays to +host-local arrays for the output. This can lead to unexpected behavior in the +multi-host setting. In this case, we recommend users call +`jax.experimental.multihost_utils`'s `host_local_array_to_global_array` on +inputs and `global_array_to_host_local_array` on outputs of `.compile()(*args)` +to perform the necessary conversions. + +### `JaxRuntimeError: INTERNAL: Core halted unexpectedly` + +#### Example + +``` +jax.errors.JaxRuntimeError: INTERNAL: Core halted unexpectedly: Assertion args: +0x00000000 0x00000000 0x00000000 INTERNAL: Accelerator device halted +prematurely, perhaps due to an on-device check-failure. Node 0 halted +unexpectedly at tag:pc +TensorCoreSequencer:1:0x160 (from TensorCoreSequencer:1:0x208): scheckne: +``` + +#### How this can happen + +- This error typically occurs in multi-host settings when process + synchronization barriers are not properly aligned. The new `jax.pmap` + implementation may have different synchronization semantics compared to the + old implementation. + +#### How to fix + +- Replace any custom process barrier implementations with + `jax.experimental.multihost_utils.sync_global_devices()`. This ensures all + processes reach the same synchronization point before proceeding. + + ```python + from jax.experimental import multihost_utils as mhu + + # Instead of custom barriers + mhu.sync_global_devices("barrier_name") + ``` + +## Performance implications + +### `int` indexing into sharded arrays + +The new implementation of `jax.pmap` uses `NamedSharding` instead of the legacy +`PmapSharding`. We've observe a common pattern with the old `jax.pmap` where +users shard stacked copies of an array to replicate (e.g., via +`jax.device_put_replicated`). These "sharded-but-really-replicated" arrays +suffer unnecessary communication overhead when `int` indexing (e.g., `x[0]`) +because JAX does not know the arrays are actually replicated. For a more +thorough discussion, please see the note on the multi-host setting in +[Appendix A](#appendix-a). + +#### Option 1: Prevent unintended sharding (recommended) + +Avoid creating the leading sharded dimension entirely. + +- Use `jax.pmap`'s `out_axes=None` for arguments that should remain replicated. + The output will be fully replicated (e.g., `P(None, None)`), making access + cheap. +- For inputs: When using `jax.device_put`, specify `jax.P()` (fully replicated) + in the partition spec rather than relying on utilities that stack and shard. + (Note: `jax.device_put_replicated` and `jax.device_put_sharded` are deprecated + because they confusingly produce sharded arrays rather than replicated ones). + +#### Option 2: Access local data directly + +If you must work with a sharded array (or want potentially fewer changes to +code), you can access the local data shard directly without triggering JAX's +distributed consistency checks. Note that this is only recommended when bringing +data back to host (e.g., for logging, checkpointing). Instead of `x[0]`, use +`addressable_shards`: + +```python +# Old slow way: +# result = x[0] + +# New fast way: +# x.addressable_shards is a list of shards on the current process. +# We grab the first one, extract the data, and remove the leading dimension. +result = x.addressable_shards[0].data.squeeze(0) +``` + +In the example of `x` with shape `(8, 3, 4)`, `x.addressable_shards[0].data` +returns the local chunk of shape `(1, 3, 4)`. Calling `.squeeze(0)` results in +the desired `(3, 4)` shape without any cross-device communication. Both +solutions will eliminate the `_gather` operations seen in profiling. + +### Host local array to global array round-trip conversion + +In multi-process JAX programs (i.e., `jax.process_count() > 1`), arrays might +not be [fully +addressable](https://docs.jax.dev/en/latest/_autosummary/jax.Array.is_fully_addressable.html) +(i.e., "host local"), so the new `jax.pmap` will reshard the host-local array +into a global one before passing to `jax.jit` of `jax.shard_map` and back into a +host-local array when returning to user code. + +This round-trip conversion cannot be avoided, so if the performance penalty is +too great, we recommend migrating your code to `jax.shard_map`. + +### Transforming `jax.pmap` e.g., `jax.jit` + +We recommend keeping `jax.pmap` as the top-level transform since it is more +performant than under another transform. However, if your code must put +`jax.pmap` under another transform and the performance penalty is +unacceptable, please file a bug as described above. + +### Buffer donation with `donate_argnums` + +Buffer donation with `donate_argnums` is fully supported in the new `jax.pmap` +implementation, but performance depends on whether inputs are correctly sharded: + +- **Correctly sharded inputs (fast path)**: Arrays with the expected local + sharding use a zero-copy rewrap. Donation invalidates the original array as + expected, with no additional memory overhead. + +- **Incorrectly sharded inputs (slow path)**: Arrays that require resharding + must be copied first, then the original is deleted. This causes a **brief 2x + memory spike** before the original is freed. A warning is logged when this + occurs. + +To maximize donation efficiency, ensure your inputs are correctly sharded +before calling `pmap`. If you see the resharding warning and memory is tight, +consider migrating to `jax.shard_map` where you have full control over +input/output sharding. + +## Migrating to `jax.shard_map` + +For the best support, we recommend migrating from `jax.pmap` to +`jax.jit(jax.shard_map)`. `jax.shard_map` allows you to treat your entire device +cluster as a single computational fabric. + +While the new `jax.pmap` is itself built on `shard_map`, migrating your code +gives you explicit control over data distribution and collective operations. +Migrating involves updates to three primary areas: + +### 1. The pmapped function itself (Rank-preserving vs. Rank-reducing) + +#### Update your mapped function + +The "mapped function" is the function you pass to `jax.pmap` or `jax.shard_map` +(often via a decorator). When migrating, the biggest change within the function +body itself is how array ranks and shapes are handled. While it's possible that +very few if any changes are needed, you should carefully verify any +rank-sensitive logic. + +`jax.pmap` is a **rank-reducing map**: it "unstacks" each array along the mapped +axis. For example, if you map over a `(8, 128)` array on 8 devices, the code +inside `jax.pmap` sees an array of shape `(128,)`. + +In contrast, `jax.shard_map` is a **rank-preserving map**: it slices or +"unconcatenates" the array into blocks. Using the same example on a mesh of size +8, the code inside `jax.shard_map` sees an array of shape `(1, 128)`. + +- **Rank adjustments**: Because `shard_map` slices the array, keeping an + explicit dimension for each mapped axis instead of unstacking it, you may + need to adjust how you treat those dimensions. + + ```python + # pmap style (rank-reduced) + def mapped_fn(x): + # x has shape (128,) + return jnp.dot(x, weights) + + # shard_map style (rank-preserving) + def mapped_fn(x): + # x has shape (1, 128) + # Option 1: restores pmap rank + return jnp.dot(x.squeeze(0), weights) + + # Option 2: use matmul (handles the leading dimension naturally) + # return jnp.matmul(x, weights) + + # Option 3: indexing + # return jnp.dot(x[0], weights) + ``` + +Many JAX functions are sensitive to array rank and may behave differently or +raise errors when moving from `pmap` to `shard_map`. Be particularly careful +with reductions (e.g., `jnp.sum`, `jnp.mean`, `jnp.max`) when the `axis` is not +specified, linear algebra operations (`jnp.dot`, `jnp.matmul`, `jnp.einsum`), +shape manipulations (`jnp.reshape`, `jnp.transpose`, `jnp.squeeze`, +`jnp.expand_dims`), and higher-level neural network layers (e.g., in Flax or +Equinox) that expect specific input ranks for batch or feature dimensions. + +- **Broadcasting vs. Stacking**: In `pmap`, "unmapped" inputs (marked with + `None` in `in_axes`) were implicitly replicated. In `shard_map`, you specify + this via `jax.P()`. The mapped function in `shard_map` sees the _full_ + replicated shape of these inputs, just like `pmap` did. + +#### Rewriting `pmap` to `jit(shard_map)` + +Once you have made any necessary rank adjustments, you can rewrite your +`jax.pmap` calls as `jax.jit(jax.shard_map(...))`. This transition involves a +few key components that differ from the implicit world of `pmap`: + +- **`Mesh`**: Unlike `pmap` which assumes a linear arrangement of devices, + `shard_map` requires an explicit `Mesh` object to define your device topology + and axis names. +- **`in_specs` and `out_specs`**: These replace `in_axes` and `out_axes`. + Instead of just specifying integer axes, you use `jax.P` (PartitionSpec) to + explicitly map array dimensions to named mesh axes. This gives you precise + control over how data is sliced (tiled) for inputs and assembled for outputs. +- **`jax.jit` wrapper**: While `pmap` is itself a compiled transform, + `shard_map` is often used as a building block. Wrapping it in `jax.jit` is + required to trigger the SPMD (Single Program Multiple Data) lowering and + compilation that enables efficient parallel execution across the mesh. + +Below are a number of examples of how to rewrite `jax.pmap` using +`jax.jit(jax.shard_map(...))` after first defining a `Mesh` object. + +```python +from functools import partial +import jax +from jax.sharding import Mesh + +# Define device topology: 8 devices logically arranged as a 1D vector named 'i'. +# This serves as the global context for axis names, similar to 'axis_name' in +# pmap. +mesh = jax.make_mesh(shape=(8,), axis_names=('i',)) +``` + +**Basic Map** + +```python +# pmap style: rank-reducing +# x_global: f32[8, 128] +@jax.pmap +def f(x): + # x: f32[128] + return x * 2 +# output: f32[8, 128] + +# shard_map style: rank-preserving +# x_global: f32[8, 128] +@jax.jit +@partial(jax.shard_map, mesh=mesh, in_specs=jax.P('i'), out_specs=jax.P('i')) +def f(x): + # x: f32[1, 128] (if logically x_global was (8, 128) and mesh size is 8) + return x * 2 +# output: f32[8, 128] +``` + +**Unmapped axes and replicated outputs** + +```python +# pmap style +# x: f32[8, 128], y: f32[128] +@partial(jax.pmap, in_axes=(0, None), out_axes=None) +def f(x, y): + # x: f32[128], y: f32[128] + return x + y +# output: f32[128] (replicated) + +# shard_map style +# x_global: f32[8, 128], y_replicated: f32[128] +@jax.jit +@partial( + jax.shard_map, mesh=mesh, in_specs=(jax.P('i'), jax.P()), out_specs=jax.P() +) +def f(x, y): + # x: f32[1, 128], y: f32[128] + return x + y +# output: f32[128] (replicated) +``` + +**Multiple axes of parallelism** + +```python +# Analogy to pmap(pmap(f, 'i'), 'j') +# mesh2d: 4 devices for 'i', 2 devices for 'j' +mesh2d = jax.make_mesh(shape=(4, 2), axis_names=('i', 'j')) + +# nested pmap +# x: f32[4, 2, 128] +@partial(jax.pmap, axis_name='i') +@partial(jax.pmap, axis_name='j') +def f(x): + # x: f32[128] + return jax.lax.psum(x, ('i', 'j')) +# output: f32[4, 2, 128] (if out_axes=0) + +# shard_map +# x_global: f32[4, 2, 128] +@jax.jit +@partial( + jax.shard_map, mesh=mesh2d, in_specs=jax.P('i', 'j'), out_specs=jax.P() +) +def f(x): + # x: f32[1, 1, 128] + return jax.lax.psum(x, ('i', 'j')) +# output: f32[128] (replicated) +``` + +**Buffer donation** + +```python +# pmap style +# donate_argnums specifies which inputs can be overwritten in-place +f = jax.pmap(func, donate_argnums=(0,)) + +# shard_map style: donate_argnums goes on the jit wrapper +# The underlying shard_map itself just handles the sharding layout +f = jax.jit(jax.shard_map(func, mesh=mesh, ...), donate_argnums=(0,)) +``` + +#### Collectives + +Collective operations like `jax.lax.psum` still use +`axis_name`, but they now operate over named mesh axes defined in your `Mesh` +object. Note that in `shard_map`, you must choose an `out_specs` that is +consistent with your collective (e.g., if you `psum` over `'i'`, an +`out_specs` of `jax.P()` implies you want a replicated result). + +### 2. Input data preparation + +Preparing data for `jax.jit(jax.shard_map)` requires a shift in how you think +about data distribution. While `jax.pmap` often handled sharding implicitly +based on array shapes and `in_axes`, `shard_map` asks you to be explicit about +how global data is sliced and placed across your device mesh. This means you +must directly provide arrays with a `sharding` that matches the `mesh` and +`in_specs` of your `shard_map` call; unlike `pmap`, `shard_map` will not +implicitly reshard inputs and will instead raise a **hard error** (e.g., +[`ValueError: Received incompatible +devices`](#valueerror-received-incompatible-devices-)). +This involves new considerations for data locality, sharding layouts, and +multi-host orchestration. + +#### Host-local vs. Global Views + +Migration often starts with how you currently load data. + +- **Host-local Array**: An array stored only on the devices attached to the + current process. This is the standard `pmap` pattern where each host + independently loads a subset of the dataset (e.g., using + `jax.process_index()` to calculate an offset). +- **Global Array**: The entire logical dataset across all devices in the `Mesh`. + `shard_map` (via `jax.jit`) expects this global view. + +#### Addressability and Topology + +The relationship between these views depends on your hardware setup. + +- **Single-host**: All devices are connected to one process. A "global" + array and a "fully addressable" array are effectively the same thing because + the process can "see" every shard. +- **Multi-host**: Devices are spread across multiple processes (e.g., a + TPU Pod). Each process only "sees" its local devices. +- **Fully Addressable**: A global array is **fully addressable** if the current + process can access all of its shards. In multi-host settings, global arrays + are typically **not fully addressable**; each process only sees the + "host-local" part. You can query this state using the + `x.is_fully_addressable` property. + +#### Shardings + +You define how global arrays are distributed across devices using +`jax.NamedSharding`. When using `shard_map`, it is critical that the **input +array's sharding explicitly matches** the `mesh` and `in_specs` you pass to the +`shard_map` call. If the physical distribution of your data does not align with +the logical distribution expected by `shard_map`, JAX will have to reshard the +data (potentially involving expensive communication) before the parallel +computation can begin. + +- **NamedSharding vs. PmapSharding**: + - `PmapSharding` is the legacy internal representation for `pmap`. It is + inherently **rank-reducing** and tied to the implicit device axis of + `pmap`. + - `NamedSharding` is the modern, flexible representation used with `jit` + and `shard_map`. It is **rank-preserving** and uses a `Mesh` and + `PartitionSpec` to logically map array dimensions to device axes. +- **SingleDeviceSharding**: While `shard_map` is about distributed data, + `jax.SingleDeviceSharding` remains a core part of the system. It is used + for arrays that live entirely on one device, such as host-local data or the + results of unshared computations. + +#### The Migration Pattern: "Stitching" + +In `pmap`, JAX implicitly handled the split across hosts. With `shard_map`, you +must be explicit. The standard pattern is to load **host-local** data (just as +you did for `pmap`) and then use +`jax.make_array_from_process_local_data` to "stitch" that local data into a +single global (but partially addressable) `jax.Array` before passing it to your +sharded computation. + +```python +import jax +import jax.numpy as jnp +import numpy as np + +# 1. Define your mesh and sharding (logical view) +mesh = jax.make_mesh((jax.process_count(),), ('batch',)) +sharding = jax.NamedSharding(mesh, jax.P('batch')) + +# 2. Load host-local data (as you would for pmap) +# Example: each process loads a different subset of a dataset +local_batch_size = 32 +start_idx = jax.process_index() * local_batch_size +local_data = np.arange(start_idx, start_idx + local_batch_size).reshape( + local_batch_size, 10 +) + +# 3. Stitch into a global jax.Array +# The resulting array will have global shape (32 * num_processes, 10) +global_batch = jax.make_array_from_process_local_data(sharding, local_data) + +print(f"Process {jax.process_index()} local shape: {local_data.shape}") +print(f"Global array shape: {global_batch.shape}") +``` + +> [!NOTE] +> `jax.make_array_from_process_local_data` requires that the `local_data` shape +> on each process matches the expected shard size derived from the `sharding`. + +### 3. Output consumption + +While `pmap` returns a value that is often treated as a stack of per-device +outputs (sometimes requiring a `concatenate` to use as a single array), +`shard_map` returns a single `jax.Array`. + +#### Global View + +The output is already a single logical array sharded across +devices. You can immediately perform global operations on it (like +`jnp.mean(output)`) within a `jax.jit` context. + +#### The `unreplicate` Anti-pattern + +As described in [Appendix A](#appendix-a), there is a common pattern where +arrays are **physically sharded** across devices despite being **logically +replicated** (i.e., every shard contains the same data). + +In the legacy `pmap` implementation, users would frequently call +`flax.jax_utils.unreplicate(output)` (equivalent to `output[0]`) to retrieve +what they assumed was a cheap local replica. + +- **The issue**: JAX does not track semantic replication for sharded arrays. + When you call `x[0]` on an array sharded along its leading axis, JAX must + assume the first shard contains unique data that needs to be broadcast to the + entire mesh to satisfy indexing semantics. This triggers a **global gather**, + causing significant performance regressions. +- **Recommendation**: Avoid creating physically sharded replicas. If you must + work with them, use `x.addressable_shards[0].data` to access the local replica + without triggering communication. See [Appendix A](#appendix-a) for a detailed + technical breakdown. + +#### Host access + +To get the data back to the host process, you use standard +JAX patterns like `device_get` or simple indexing. + +### Related documentation + +To help with migration, we recommend reviewing the following documentation based +on your needs: + +- **{doc}`sharded-computation`**: Start here for a high-level introduction to + parallel programming in JAX. This tutorial covers all three sharding modes + (automatic, explicit, and manual) with a comparison table, explains key + concepts like data sharding and `NamedSharding`, and demonstrates how each + mode handles a simple neural network layer. This is the best starting point + for understanding the overall landscape of parallelism in JAX. + +- **{doc}`notebooks/Distributed_arrays_and_automatic_parallelization`**: Read + this for a deeper understanding of `jax.Array` and automatic parallelization + via `jax.jit`. This notebook explains how sharded data works, how computation + follows data placement, and how to use `jax.lax.with_sharding_constraint` to + guide the compiler. It includes practical neural network examples with batch + data parallelism and model tensor parallelism. + +- **{doc}`notebooks/shard_map`**: This is the comprehensive guide for manual + parallelism with `jax.shard_map`. It explains the difference between + rank-reducing maps (like `vmap`) and rank-preserving maps (like `shard_map`), + how to control input splitting and output assembly with `in_specs` and + `out_specs`, and includes a detailed collectives tutorial covering `psum`, + `all_gather`, `psum_scatter`, and more. If you're migrating complex `pmap` + code with explicit collectives, this is essential reading. + +- **{doc}`notebooks/explicit-sharding`**: Explore this for the newest sharding + mode where sharding becomes part of the JAX-level type system. With explicit + sharding, sharding propagation happens at trace time and shardings are + queryable via `jax.typeof(x)`. This mode provides more control than automatic + sharding while still using a global-view programming model. It's particularly + useful when you want deterministic sharding behavior without resorting to + fully manual parallelism. + +- **{doc}`jep/14273-shard-map`**: Read the original design document for + `shard_map`. This JEP (JAX Enhancement Proposal) provides the technical + rationale for the API, detailed comparisons with `pmap` and `xmap`, and + explains the fundamental concepts of rank-reducing vs. rank-preserving maps + over array axes. + +(appendix-a)= + +## Appendix A: More details about `int` indexing into sharded arrays. + +### What should `x[0]` return? + +In **NumPy**, `x[0]` returns a rank-reduced array representing the first slice +along the first dimension. For example, if `x = np.ones((8, 3, 4))`, then `x[0]` +returns an array of shape `(3, 4)`. + +In **JAX** (`jax.numpy`), `x[0]` semantically works the same way: it returns the +rank-reduced slice of the logical array `x`. However, performance depends on how +`x` is sharded or replicated across devices. Consider an array `x` with shape +`(8, 3, 4)` distributed across 8 devices (using `jax.P`): + +1. **Fully Replicated:** `jax.P(None, None, None)` + If `x` is fully replicated, every device holds a complete copy of the `(8, +3, 4)` array. `x[0]` will have the shape `(3, 4)` and a partition spec + `jax.P(None, None)`. Since every device already has `x`, this operation will + slice on each device independently and requires **no communication**. + +2. **Sharded on Non-Leading Dimension:** `jax.P(None, 'x', None)` + If `x` is sharded along the second dimension, `x[0]` results in shape `(3, +4)` with partition spec `jax.P('x', None)`. Since the first dimension (the + one being sliced) is unsharded, this operation also requires **no + communication**. + +3. **Sharded on Leading Dimension:** `jax.P('x', None, None)` + If `x` is sharded along the first dimension, `x[0]` results in shape + `(3, 4)` with partition spec `jax.P(None, None)`. + - **The Issue:** Because the first dimension is sharded, the data for + `x[0]` physically resides _only_ on the first device. To satisfy the + output sharding `jax.P(None, None)` (which implies replication), JAX + must broadcast the data from the first device to all other devices. This + requires **communication**; JAX will gather the _entire_ array of shape + `(8, 3, 4)` to each device and then take a slice. + +### The common performance pitfall + +A common pattern among `jax.pmap` users involves arrays that are **semantically +replicated** (the user intends for them to be identical everywhere) but are +**physically sharded** (stacked along the leading dimension). + +This happens implicitly (e.g., via `jax.pmap(..., out_axes=0)`) or explicitly +(e.g., via `jax.device_put_replicated`). Users often try to retrieve metrics or +checkpoints by calling `unreplicate` or `x[0]`, assuming it is a cheap +operation. + +#### Example: The "unreplicate" anti-pattern + +```python +from flax import jax_utils +import jax.numpy as jnp +import jax + +# jax_utils.replicate calls jax.device_put_replicated. +# This stacks num_devices copies and SHARDS them over the stacked dimension. +# Logical Shape: (8, 3, 4) | Sharding: P('x', None, None) +train_state = jax_utils.replicate({'params': jnp.zeros((3, 4))}) + +# out_axes=0 by default, so the output remains sharded along dim 0. +train_step_pmapped = jax.pmap(lambda x: x) + +# jax_utils.unreplicate performs a jax.tree_map(lambda x: x[0], tree). +# Users do this to grab metrics, log param statistics, checkpoint, etc. +train_state = jax_utils.unreplicate(train_step_pmapped(train_state)) +``` + +#### The consequence + +Even though the user knows `train_state` contains identical data on every +device (it is **logically replicated**), JAX sees an array with +`shape (8, 3, 4)` and spec `jax.P('x', None, None)`—that is, the data is +**physically sharded** along its leading dimension. + +**JAX does not track semantic replication.** It does not "know" that the shard +on device 1 is identical to the shard on device 0. Therefore, when you call +`x[0]`, JAX must satisfy the strict semantics of array indexing: it must +retrieve the first slice and, because the output is typically expected to be +available for subsequent JIT-ted operations, it must often ensure that result +is replicated across the mesh. + +This triggers a **global gather (or broadcast)** of the entire array to all +devices before slicing. What the user assumes is a constant-time "ignore the +extra copies" operation actually becomes a serialized communication bottleneck +(visible as `_gather` operations in a stack trace). + +``` +train + └─ jax_utils.py:48 unreplicate + └─ tree_util.py:354 tree_map + └─ jax_utils.py:50 (performing x[0]) + └─ array.py:335 __getitem__ + └─ indexing.py:734 rewriting_take + │ + ▼ + └─ indexing.py:784 _gather + └─ slicing.py:324 gather + └─ PjitFunction(gather) +``` + +### Why was "old `jax.pmap`" fast? + +Historically, `pmap` used `PmapSharding`, which had a fast-path optimization in +`jax.Array`'s `__getitem__` allowing it to return an array with a +`SingleDeviceSharding` (data residing on only one device). + +However, current JAX uses `NamedSharding`. We do not strictly replicate the +legacy behavior because it breaks the semantics of array indexing. If we allowed +`x[0]` to return a `SingleDeviceSharding` array in a general context (e.g., in +the middle of a train step instead of when trying to bring data back to host for +reporting), only one device would have data while others would have nothing. +This is computationally problematic for subsequent operations. + +The slowdown users experience now is JAX enforcing correct semantics: if you ask +for `x[0]` from an array sharded along its leading dimension, you get a fully +replicated result available on all devices, which requires communication. + +### A note on the multi-host setting + +`x[0]` will still give you the first slice along the first dimension of the +_logical_ global array. In the multi-host setting, we will see a more drastic +version of the performance issues described above as all the hosts gather the +entire array to each device before slicing. In certain cases, users can even +face hard errors (e.g., `INVALID_ARGUMENT: CopyArrays only support...`). + +In multi-host settings (e.g., 4 hosts × 2 devices = 8 devices total): + +1. A global array with shape `(8, ...)` and `jax.P('x')` has each slice + distributed across all 8 devices spanning all hosts. + +2. When you call `x[0]`, JAX needs to slice the first element and reshard the + result so it's available to all hosts. + +3. The `CopyArrays` operation in XLA requires source and destination to have the + same device count. But each host only sees its _local_ subset of devices (2 + in this example), not all 8. When JAX tries to create a resharded array, the + device list mismatch triggers the error. + + diff --git a/docs/multi_process.md b/docs/multi_process.md index 32cfae126784..f8c2566ca872 100644 --- a/docs/multi_process.md +++ b/docs/multi_process.md @@ -1,176 +1,667 @@ -# Multi-host and multi-process environments - - - -## Introduction - -This guide explains how to use JAX in environments such as -GPU clusters and [Cloud TPU](https://cloud.google.com/tpu) pods where -accelerators are spread across multiple CPU hosts or JAX processes. We’ll refer -to these as “multi-process” environments. - -This guide specifically focuses on how to use collective communication -operations (e.g. {func}`jax.lax.psum` ) in multi-process settings, although -other communication methods may be useful too depending on your use case (e.g. -RPC, [mpi4jax](https://github.com/mpi4jax/mpi4jax)). If you’re not already -familiar with JAX’s collective operations, we recommend starting with the -{doc}`/sharded-computation` section. An important requirement of -multi-process environments in JAX is direct communication links between -accelerators, e.g. the high-speed interconnects for Cloud TPUs or -[NCCL](https://developer.nvidia.com/nccl) for GPUs. These links allow -collective operations to run across multiple processes’ worth of accelerators -with high performance. - -## Multi-process programming model - -Key concepts: - - * You must run at least one JAX process per host. - * You should initialize the cluster with {func}`jax.distributed.initialize`. - * Each process has a - distinct set of *local* devices it can address. The *global* devices are the set - of all devices across all processes. - * Use standard JAX parallelism APIs like {func}`~jax.jit` (see - {doc}`/sharded-computation` tutorial) and - {func}`~jax.experimental.shard_map.shard_map`. jax.jit only accepts - globally shaped arrays. shard_map allows you to drop to per-device - shape. - * Make sure all processes run the same parallel computations in the same - order. - * Make sure all processes has the same number of local devices. - * Make sure all devices are the same (e.g., all V100, or all H100). - -### Launching JAX processes - -Unlike other distributed systems where a single controller node manages many -worker nodes, JAX uses a “multi-controller” programming model where each JAX -Python process runs independently, sometimes referred to as a {term}`Single -Program, Multiple Data (SPMD)` model. Generally, the same JAX Python -program is run in each process, with only slight differences between each -process’s execution (e.g. different processes will load different input data). -Furthermore, **you must manually run your JAX program on each host!** JAX -doesn’t automatically start multiple processes from a single program invocation. - -(The requirement for multiple processes is why this guide isn’t offered as a -notebook -- we don’t currently have a good way to manage multiple Python -processes from a single notebook.) - -### Initializing the cluster - -To initialize the cluster, you should call {func}`jax.distributed.initialize` at -the start of each process. {func}`jax.distributed.initialize` must be called -early in the program, before any JAX computations are executed. - -The API {func}`jax.distributed.initialize` takes several arguments, namely: - - * `coordinator_address`: the IP address of process 0 in your cluster, together - with a port available on that process. Process 0 will start a JAX service - exposed via that IP address and port, to which the other processes in the - cluster will connect. - * `coordinator_bind_address`: the IP address and port to which the JAX service - on process 0 in your cluster will bind. By default, it will bind to all - available interfaces using the same port as `coordinator_address`. - * `num_processes`: the number of processes in the cluster - * `process_id`: the ID number of this process, in the range `[0 .. - num_processes)`. - * `local_device_ids`: Restricts the visible devices of the current process to - ``local_device_ids``. - -For example on GPU, a typical usage is: +# Introduction to multi-controller JAX (aka multi-process/multi-host JAX) + + + +By reading this tutorial, you'll learn how to scale JAX computations to more +devices than can fit in a single host machine, e.g. when running on a GPU +cluster, Cloud TPU pod, or multiple CPU-only machines. + +The main idea + +- **Run multiple Python processes**, which we sometimes call "controllers." We + can run one (or more) process per host machine. +- **Initialize the cluster with {func}`jax.distributed.initialize`**. +- **A {class}`jax.Array` can span all processes**, and if each process applies + the same JAX function to it, it's like programming against one big device. +- **Use the same [unified sharding mechanism][unified_sharding]** as in + single-controller JAX to control how data is distributed and computation is + parallelized. XLA automatically exploits high-speed networking links like TPU + ICI or NVLink between hosts when available, and otherwise uses available host + networking (e.g. Ethernet, InfiniBand). +- **All processes (usually) run the same Python script**. You write this Python + code almost exactly the same as you would for a single process — just run + multiple instances of it and JAX takes care of the rest. In other words, + except for array creation, you can write your JAX code as if there were one + giant machine with all devices attached to it. + +This tutorial assumes you've read [Distributed arrays and automatic +parallelization][distributed_arrays], which is about single-controller JAX. + +```{figure} _static/multi_process/mcjax_overview.png +:alt: Illustration of a multi-host TPU pod. Each host in the pod is attached via PCI to a board of four TPU chips. The TPUs chips themselves are connected via high-speed inter-chip interconnects. + +Illustration of a multi-host TPU pod. Each host in the pod (green) is attached +via PCI to a board of four TPU chips (blue). The TPUs chips themselves are +connected via high-speed inter-chip interconnects (ICI). JAX Python code runs on +each host, e.g. via ssh. The JAX processes on each host are aware of each other, +allowing you to orchestrate computation across the entire pods' worth of chips. +The principle is the same for GPU, CPU, and other platforms with JAX support! +``` + +## Toy example + +Before we define terms and walk through the details, here's a toy example: +making a process-spanning {class}`jax.Array` of values and applying +{mod}`jax.numpy` functions to it. ```python +# call this file toy.py, to be run in each process simultaneously + import jax +import jax.numpy as jnp +from jax.sharding import NamedSharding, PartitionSpec as P +import numpy as np + +# in this example, get multi-process parameters from sys.argv +import sys +proc_id = int(sys.argv[1]) +num_procs = int(sys.argv[2]) + +# initialize the distributed system +jax.distributed.initialize('localhost:10000', num_procs, proc_id) + +# this example assumes 8 devices total +assert jax.device_count() == 8 + +# make a 2D mesh that refers to devices from all processes +mesh = jax.make_mesh((4, 2), ('i', 'j')) -jax.distributed.initialize(coordinator_address="192.168.0.1:1234", - num_processes=2, - process_id=0) +# create some toy data +global_data = np.arange(32).reshape((4, 8)) + +# make a process- and device-spanning array from our toy data +sharding = NamedSharding(mesh, P('i', 'j')) +global_array = jax.device_put(global_data, sharding) +assert global_array.shape == global_data.shape + +# each process has different shards of the global array +for shard in global_array.addressable_shards: + print(f"device {shard.device} has local data {shard.data}") + +# apply a simple computation, automatically partitioned +global_result = jnp.sum(jnp.sin(global_array)) +print(f'process={proc_id} got result: {global_result}') ``` -On Cloud TPU, Slurm and Open MPI environments, you can simply call {func}`jax.distributed.initialize()` with no -arguments. Default values for the arguments will be chosen automatically. -When running on GPUs with Slurm and Open MPI, it is assumed that one process is started per GPU, i.e. each process will -be assigned only one visible local device. Otherwise it is assumed that one process is started per host, -i.e. each process will be assigned all local devices. -The Open MPI auto-initialization is only used when the JAX processes are launched via `mpirun`/`mpiexec`. +Here, `mesh` contains devices from all processes. We use it to create +`global_array`, logically a single shared array, stored distributed across +devices from all processes. + +Every process must apply the same operations, in the same order, to +`global_array`. XLA automatically partitions those computations, for example +inserting communication collectives to compute the `jnp.sum` over the full +array. We can print the final result because its value is replicated across +processes. + +We can run this code locally on CPU, e.g. using 4 processes and 2 CPU devices +per process: + +```bash +export JAX_NUM_CPU_DEVICES=2 +num_processes=4 + +range=$(seq 0 $(($num_processes - 1))) + +for i in $range; do + python toy.py $i $num_processes > /tmp/toy_$i.out & +done + +wait + +for i in $range; do + echo "=================== process $i output ===================" + cat /tmp/toy_$i.out + echo +done +``` + +Outputs: + +```text +=================== process 0 output =================== +device TFRT_CPU_0 has local data [[0 1 2 3]] +device TFRT_CPU_1 has local data [[4 5 6 7]] +process=0 got result: -0.12398731708526611 + +=================== process 1 output =================== +device TFRT_CPU_131072 has local data [[ 8 9 10 11]] +device TFRT_CPU_131073 has local data [[12 13 14 15]] +process=1 got result: -0.12398731708526611 + +=================== process 2 output =================== +device TFRT_CPU_262144 has local data [[16 17 18 19]] +device TFRT_CPU_262145 has local data [[20 21 22 23]] +process=2 got result: -0.12398731708526611 + +=================== process 3 output =================== +device TFRT_CPU_393216 has local data [[24 25 26 27]] +device TFRT_CPU_393217 has local data [[28 29 30 31]] +process=3 got result: -0.12398731708526611 +``` + +This might not look so different from single-controller JAX code, and in fact, +this is exactly how you'd write the single-controller version of the same +program! (We don't technically need to call {func}`jax.distributed.initialize` +for single-controller, but it doesn't hurt.) Let's run the same code from a +single process: + +```text +JAX_NUM_CPU_DEVICES=8 python toy.py 0 1 +``` + +Outputs: + +```text +device TFRT_CPU_0 has local data [[0 1 2 3]] +device TFRT_CPU_1 has local data [[4 5 6 7]] +device TFRT_CPU_2 has local data [[ 8 9 10 11]] +device TFRT_CPU_3 has local data [[12 13 14 15]] +device TFRT_CPU_4 has local data [[16 17 18 19]] +device TFRT_CPU_5 has local data [[20 21 22 23]] +device TFRT_CPU_6 has local data [[24 25 26 27]] +device TFRT_CPU_7 has local data [[28 29 30 31]] +process=0 got result: -0.12398731708526611 +``` + +The data is sharded across eight devices on one process rather than eight +devices across four processes, but otherwise we're running the same operations +over the same data. + +## Terminology + +It's worth pinning down some terminology. + +We sometimes call each Python process running JAX computations a **controller**, +but the two terms are essentially synonymous. + +Each process has a set of **local devices**, meaning it can transfer data to and +from those devices' memories and run computation on those devices without +involving any other processes. The local devices are usually physically attached +to the process's corresponding host, e.g. via PCI. A device can only be local to +one process; that is, the local device sets are disjoint. A process's local +devices can be queried by evaluating {func}`jax.local_devices()`. We sometimes +use the term **addressable** to mean the same thing as local. + +```{figure} _static/multi_process/controller_and_local_devices.png +:alt: Illustration of how a process/controller and local devices fit into a larger multi-host cluster. The "global devices" are all devices in the cluster. + +Illustration of how a process/controller and local devices fit into a larger +multi-host cluster. The "global devices" are all devices in the cluster. +``` + +The devices across all processes are called the **global devices**. The list of +global devices is queried by {func}`jax.devices()`. That list of all devices is +populated by running {func}`jax.distributed.initialize` on all processes, which +sets up a simple distributed system connecting the processes. + +We often use the terms **global** and **local** to describe process-spanning and +process-local concepts in general. For example, a "local array" could be a numpy +array that's only visible to a single process, vs. a JAX "global array" is +conceptually visible to all processes. + +## Setting up multiple JAX processes + +In practice, setting up multiple JAX processes looks a bit different from the +toy example, which is run from a single host machine. We usually launch each +process on a separate host, or have multiple hosts with multiple processes each. +We can do that directly using `ssh`, or with a cluster manager like Slurm or +Kubernetes. In any case, **you must manually run your JAX program on each +host!** JAX doesn’t automatically start multiple processes from a single program +invocation. + +However they're launched, the Python processes need to run +{func}`jax.distributed.initialize`. When using Slurm, Kubernetes, or any Cloud +TPU deployment, we can run {func}`jax.distributed.initialize` with no arguments +as they're automatically populated. Initializing the system means we can run +{func}`jax.devices()` to report all devices across all processes. + +```{warning} +{func}`jax.distributed.initialize` must be called before running +{func}`jax.devices()`, {func}`jax.local_devices()`, or running any computations +on devices (e.g. with {mod}`jax.numpy`). Otherwise the JAX process won't be +aware of any non-local devices. (Using {func}`jax.config` or other +non-device-accessing functionality is ok.) {func}`jax.distributed.initialize` +will raise an error if you accidentally call it after accessing any devices. +``` + +### GPU Example + +We can run multi-controller JAX on a cluster of [GPU machines][gpu_machines]. +For example, after creating four VMs on Google Cloud with two GPUs per VM, we +can run the following JAX program on every VM. In this example, we provide +arguments to {func}`jax.distributed.initialize` explicitly. The coordinator +address, process id, and number of processes are read from the command line. ```python +# In file gpu_example.py... + import jax +import sys + +# Get the coordinator_address, process_id, and num_processes from the command line. +coord_addr = sys.argv[1] +proc_id = int(sys.argv[2]) +num_procs = int(sys.argv[3]) + +# Initialize the GPU machines. +jax.distributed.initialize(coordinator_address=coord_addr, + num_processes=num_procs, + process_id=proc_id) +print("process id =", jax.process_index()) +print("global devices =", jax.devices()) +print("local devices =", jax.local_devices()) +``` + +For example, if the first VM has address `192.168.0.1`, then you would run +`python3 gpu_example.py 192.168.0.1:8000 0 4` on the first VM, `python3 +gpu_example.py 192.168.0.1:8000 1 4` on the second VM, and so on. After running +the JAX program on all four VMs, the first process prints the following. + +```text +process id = 0 +global devices = [CudaDevice(id=0), CudaDevice(id=1), CudaDevice(id=2), CudaDevice(id=3), CudaDevice(id=4), CudaDevice(id=5), CudaDevice(id=6), CudaDevice(id=7)] +local devices = [CudaDevice(id=0), CudaDevice(id=1)] +``` + +The process successfully sees all eight GPUs as global devices, as well as its +two local devices. Similarly, the second process prints the following. + +```text +process id = 1 +global devices = [CudaDevice(id=0), CudaDevice(id=1), CudaDevice(id=2), CudaDevice(id=3), CudaDevice(id=4), CudaDevice(id=5), CudaDevice(id=6), CudaDevice(id=7)] +local devices = [CudaDevice(id=2), CudaDevice(id=3)] +``` +This VM sees the same global devices, but has a different set of local devices. + +### TPU Example + +As another example, we can run on [Cloud TPU][cloud_tpu]. After creating a +`v5litepod-16` (which has 4 host machines), we might want to test that we can +connect the processes and list all devices: + +```text +$ TPU_NAME=jax-demo +$ EXTERNAL_IPS=$(gcloud compute tpus tpu-vm describe $TPU_NAME --zone 'us-central1-a' \ + | grep externalIp | cut -d: -f2) +$ cat << EOF > demo.py +import jax jax.distributed.initialize() +if jax.process_index() == 0: + print(jax.devices()) +EOF +$ echo $EXTERNAL_IPS | xargs -n 1 -P 0 bash -c ' +scp demo.py $0: +ssh $0 "pip -q install -U jax[tpu]" +ssh $0 "python demo.py" ' ``` -On TPU at present calling {func}`jax.distributed.initialize` is optional, but -recommended since it enables additional checkpointing and health checking features. +Here we're using `xargs` to run multiple `ssh` commands in parallel, each one +running the same Python program on one of the TPU host machines. In the Python +code, we use {func}`jax.process_index()` to print only on one process. Here's +what it prints: -### Local vs. global devices +```text +[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=2, process_index=1, coords=(2,0,0), core_on_chip=0), TpuDevice(id=3, process_index=1, coords=(3,0,0), core_on_chip=0), TpuDevice(id=6, process_index=1, coords=(2,1,0), core_on_chip=0), TpuDevice(id=7, process_index=1, coords=(3,1,0), core_on_chip=0), TpuDevice(id=8, process_index=2, coords=(0,2,0), core_on_chip=0), TpuDevice(id=9, process_index=2, coords=(1,2,0), core_on_chip=0), TpuDevice(id=12, process_index=2, coords=(0,3,0), core_on_chip=0), TpuDevice(id=13, process_index=2, coords=(1,3,0), core_on_chip=0), TpuDevice(id=10, process_index=3, coords=(2,2,0), core_on_chip=0), TpuDevice(id=11, process_index=3, coords=(3,2,0), core_on_chip=0), TpuDevice(id=14, process_index=3, coords=(2,3,0), core_on_chip=0), TpuDevice(id=15, process_index=3, coords=(3,3,0), core_on_chip=0)] +``` -Before we get to running multi-process computations from your program, it’s -important to understand the distinction between *local* and *global* devices. +Woohoo, look at all those TPU cores! + +### Kubernetes Example + +Running multi-controller JAX on a Kubernetes cluster is almost identical in spirit to the GPU and TPU examples above: every pod runs the same Python program, JAX discovers its peers, and the cluster behaves like one giant machine. + +1. **Container image** - start from a JAX-enabled image, e.g. one of the public JAX AI images on Google Artifact Registry ([TPU][google-artifact-tpu] / [GPU][google-artifact-gpu]) or NVIDIA ([NGC][nvidia-ngc] / [JAX-Toolbox][nvidia-jax-toolbox]). + +2. **Workload type** - use either a [JobSet][k8s-jobset] or an [indexed Job][k8s-indexed-job]. Each replica corresponds to one JAX process. + +3. **Service Account** - JAX needs permission to list the pods that belong to the job so that processes discover their peers. A minimal RBAC setup is provided in [examples/k8s/svc-acct.yaml][rbac-svc-acct]. + +Below is a [minimal JobSet][minimal-jobset] that launches two replicas. Replace the placeholders - +image, GPU count, and any private registry secrets - with values that match your environment. + +```yaml +apiVersion: jobset.x-k8s.io/v1alpha2 +kind: JobSet +metadata: + name: jaxjob +spec: + replicatedJobs: + - name: workers + template: + spec: + parallelism: 2 + completions: 2 + backoffLimit: 0 + template: + spec: + serviceAccountName: jax-job-sa # kubectl apply -f svc-acct.yaml + restartPolicy: Never + imagePullSecrets: + # https://k8s.io/docs/tasks/configure-pod-container/pull-image-private-registry/ + - name: null + containers: + - name: main + image: null # e.g. ghcr.io/nvidia/jax:jax + imagePullPolicy: Always + resources: + limits: + cpu: 1 + # https://k8s.io/docs/tasks/manage-gpus/scheduling-gpus/ + nvidia.com/gpu: null + command: + - python + args: + - -c + - | + import jax + jax.distributed.initialize() + print(jax.devices()) + print(jax.local_devices()) + assert jax.process_count() > 1 + assert len(jax.devices()) > len(jax.local_devices()) +``` -**A process’s *local* devices are those that it can directly address and launch -computations on.** For example, on a GPU cluster, each host can only launch -computations on the directly attached GPUs. On a Cloud TPU pod, each host can -only launch computations on the 8 TPU cores attached directly to that host (see -the -[Cloud TPU System Architecture](https://cloud.google.com/tpu/docs/system-architecture) -documentation for more details). You can see a process’s local devices via -{func}`jax.local_devices()`. +Apply the manifest and watch the pods complete: -**The *global* devices are the devices across all processes.** A computation can -span devices across processes and perform collective operations via the direct -communication links between devices, as long as each process launches the -computation on its local devices. You can see all available global devices via -{func}`jax.devices()`. A process’s local devices are always a subset of the -global devices. +```bash +$ kubectl apply -f example.yaml +$ kubectl get pods -l jobset.sigs.k8s.io/jobset-name=jaxjob +NAME READY STATUS RESTARTS AGE +jaxjob-workers-0-0-xpx8l 0/1 Completed 0 8m32s +jaxjob-workers-0-1-ddkq8 0/1 Completed 0 8m32s +``` -### Running multi-process computations +When the job finishes, inspect the logs to confirm that every process saw all accelerators: -So how do you actually run a computation involving cross-process communication? -**Use the same parallel evaluation APIs that you would in a single process!** +```bash +$ kubectl logs -l jobset.sigs.k8s.io/jobset-name=jaxjob +[CudaDevice(id=0), CudaDevice(id=1)] +[CudaDevice(id=0)] +[CudaDevice(id=0), CudaDevice(id=1)] +[CudaDevice(id=1)] +``` + +Every pod should have the same set of global devices and a different set of local devices. At this point, you can replace the inline script with your real JAX program. + +Once the processes are set up, we can start building global {class}`jax.Array`s +and running computations. The remaining Python code examples in this tutorial +are meant to be run on all processes simultaneously, after running +{func}`jax.distributed.initialize`. + +## Meshes, shardings, and computations can span processes and hosts + +Programming multiple processes from JAX usually looks just like programming a +single process, just with more devices! The main exceptions to this are around +data coming in or out of JAX, e.g. when loading from external data sources. +We'll first go over the basics of multi-process computations here, which largely +look the same as their single-process counterparts. The next section goes over +some data loading fundamentals, i.e. how to create JAX Arrays from non-JAX +sources. + +Recall a {class}`jax.sharding.Mesh` pairs an array of {class}`jax.Device`s with +a sequence of names, with one name per array axis. By creating a `Mesh` using +devices from multiple processes, then using that mesh in a +{class}`jax.sharding.Sharding`, we can construct {class}`jax.Array`s sharded +over devices from multiple processes. + +Here's an example that directly constructs a `Mesh` using {func}`jax.devices()` +to get devices from all processes: + +```python +from jax.sharding import Mesh +mesh = Mesh(jax.devices(), ('a',)) + +# in this case, the same as +mesh = jax.make_mesh((jax.device_count(),), ('a',)) # use this in practice +``` + +You should probably use the {func}`jax.make_mesh` helper in practice, not only +because it's simpler but also because it can choose more performant device +orderings automatically, but we're spelling it out here. By default it includes +all devices across processes, just like {func}`jax.devices()`. + +Once we have a mesh, we can shard arrays over it. There are a few ways to +efficiently build process-spanning arrays, detailed in the next section, but for +now we'll stick to `jax.device_put` for simplicity: + +```python +arr = jax.device_put(jnp.ones((32, 32)), NamedSharding(mesh, P('a'))) +if jax.process_index() == 0: + jax.debug.visualize_array_sharding(arr) +``` + +On process 0, this is printed: + +``` +┌───────────────────────┐ +│ TPU 0 │ +├───────────────────────┤ +│ TPU 1 │ +├───────────────────────┤ +│ TPU 4 │ +├───────────────────────┤ +│ TPU 5 │ +├───────────────────────┤ +│ TPU 2 │ +├───────────────────────┤ +│ TPU 3 │ +├───────────────────────┤ +│ TPU 6 │ +├───────────────────────┤ +│ TPU 7 │ +├───────────────────────┤ +│ TPU 8 │ +├───────────────────────┤ +│ TPU 9 │ +├───────────────────────┤ +│ TPU 12 │ +├───────────────────────┤ +│ TPU 13 │ +├───────────────────────┤ +│ TPU 10 │ +├───────────────────────┤ +│ TPU 11 │ +├───────────────────────┤ +│ TPU 14 │ +├───────────────────────┤ +│ TPU 15 │ +└───────────────────────┘ +``` + +Let's try a slightly more interesting computation! + +```python +mesh = jax.make_mesh((jax.device_count() // 2, 2), ('a', 'b')) + +def device_put(x, spec): + return jax.device_put(x, NamedSharding(mesh, spec)) + +# construct global arrays by sharding over the global mesh +x = device_put(jnp.ones((4096, 2048)), P('a', 'b')) +y = device_put(jnp.ones((2048, 4096)), P('b', None)) + +# run a distributed matmul +z = jax.nn.relu(x @ y) + +# inspect the sharding of the result +if jax.process_index() == 0: + jax.debug.visualize_array_sharding(z) + print() + print(z.sharding) +``` -For example, {func}`~jax.experimental.shard_map.shard_map` can be used -to run a parallel computation across multiple processes. (If you’re -not already familiar with how to use `shard_map` to run across -multiple devices within a single process, check out the -{doc}`/sharded-computation` tutorial.) Conceptually, this can be -thought of as running a pmap over a single array sharded across hosts, -where each host “sees” only its local shard of the input and output. +On process 0, this is printed: -Here’s an example of multi-process pmap in action: +``` +┌───────────────────────┐ +│ TPU 0,1 │ +├───────────────────────┤ +│ TPU 4,5 │ +├───────────────────────┤ +│ TPU 8,9 │ +├───────────────────────┤ +│ TPU 12,13 │ +├───────────────────────┤ +│ TPU 2,3 │ +├───────────────────────┤ +│ TPU 6,7 │ +├───────────────────────┤ +│ TPU 10,11 │ +├───────────────────────┤ +│ TPU 14,15 │ +└───────────────────────┘ + +NamedSharding(mesh=Mesh('a': 8, 'b': 2), spec=PartitionSpec('a',), memory_kind=device) +``` + +Here, just from evaluating `x @ y` on all processes, XLA is automatically +generating and running a distributed matrix multiplication. The result is +sharded against the mesh like `P('a', None)`, since in this case the matmul +included a `psum` over the `'b'` axis. + +```{warning} +When applying JAX computations to process-spanning arrays, to avoid deadlocks +and hangs, **it's crucial that all processes with participating devices run the +same computation in the same order**. That's because the computation may +involve collective communication barriers. If a device over which an array is +sharded does not join in the collective because its controller didn't issue the +same computation, the other devices are left waiting. For example, if only the +first three processes evaluated `x @ y`, while the last process evaluated `y @ +x`, the computation would likely hang indefinitely. This assumption, +computations on process-spanning arrays are run on all participating processes +in the same order, is mostly unchecked. + +So the easiest way to avoid deadlocks in multi-process JAX is to run the same +Python code on every process, and beware of any control flow that depends on +{func}`jax.process_index()` and includes communication. +``` + +If a process-spanning array is sharded over devices on different processes, it +is an error to perform operations on the array that require the data to be +available locally to a process, like printing. For example, if we run `print(z)` +in the preceding example, we see + +``` +RuntimeError: Fetching value for `jax.Array` that spans non-addressable (non process local) devices is not possible. You can use `jax.experimental.multihost_utils.process_allgather` to print the global array or use `.addressable_shards` method of jax.Array to inspect the addressable (process local) shards. +``` + +To print the full array value, we must first ensure it's replicated over +processes (but not necessarily over each process's local devices), e.g. using +`jax.device_put`. In the above example, we can write at the end: + +``` +w = device_put(z, P(None, None)) +if jax.process_index() == 0: + print(w) +``` + +Be careful not to write the {func}`jax.device_put` under the `if process_index() +== 0`, because that would lead to a deadlock as only process 0 initiates the +collective communication and waits indefinitely for the other processes. +The {mod}`jax.experimental.multihost_utils` module has some functions that +make it easier to process global {class}`jax.Array`s (e.g., +{func}`jax.experimental.multihost_utils.process_allgather`). + +Alternatively, to print or otherwise perform Python operations on only +process-local data, we can access `z.addressable_shards`. Accessing that +attribute does not require any communication, so any subset of processes can do +it without needing the others. That attribute is not available under a +{func}`jax.jit`. + +## Making process-spanning arrays from external data + +There are three main ways to create process-spanning {class}`jax.Array`s from +external data sources (e.g. numpy arrays from a data loader): + +1. Create or load the full array on all processes, then shard onto devices using + {func}`jax.device_put`; + +2. Create or load on each process an array representing just the data that will + be locally sharded and stored on that process's devices, then shard onto + devices using {func}`jax.make_array_from_process_local_data`; + +3. Create or load on each process's devices separate arrays, each representing + the data to be stored on that device, then assemble them without any data + movement using {func}`jax.make_array_from_single_device_arrays`. + +The latter two are most often used in practice, since it's often too expensive +to materialize the full global data in every process. + +The toy example above uses {func}`jax.device_put`. + +{func}`jax.make_array_from_process_local_data` is often used for distributed data +loading. It's not as general as {func}`jax.make_array_from_single_device_arrays`, +because it doesn't directly specify which slice of the process-local data goes +on each local device. This is convenient when loading data-parallel batches, +because it doesn't matter exactly which microbatch goes on each device. For +example: ```python -# The following is run in parallel on each host on a GPU cluster or TPU pod slice. ->>> import jax ->>> jax.distributed.initialize() # On GPU, see above for the necessary arguments. ->>> jax.device_count() # total number of accelerator devices in the cluster -32 ->>> jax.local_device_count() # number of accelerator devices attached to this host -8 -# The psum is performed over all mapped devices across the pod slice ->>> xs = jax.numpy.ones(jax.local_device_count()) ->>> jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs) -ShardedDeviceArray([32., 32., 32., 32., 32., 32., 32., 32.], dtype=float32) -``` - -**It’s very important that all processes run the same cross-process computations -in the same order.** Running the same JAX Python program in each process is -usually sufficient. Some common pitfalls to look out for that may cause -differently-ordered computations despite running the same program: - -* Processes passing differently-shaped inputs to the same parallel function - can cause hangs or incorrect return values. Differently-shaped inputs are - safe so long as they result in identically-shaped per-device data shards - across processes; e.g. passing in different leading batch sizes in order to - run on different numbers of local devices per process is ok, but having each - process pad its batch to a different max example length is not. - -* “Last batch” issues where a parallel function is called in a (training) - loop, and one or more processes exit the loop earlier than the rest. This - will cause the rest to hang waiting for the already-finished processes to - start the computation. - -* Conditions based on non-deterministic ordering of collections can cause code - processes to hang. For example, iterating over - `set` on current Python versions or `dict` [before Python 3.7](https://mail.python.org/pipermail/python-dev/2017-December/151283.html) - may result in a different ordering on different processes, even with the - same insertion order. +# target (micro)batch size across the whole cluster +batch_size = 1024 +# how many examples each process should load per batch +per_process_batch_size = batch_size // jax.process_count() +# how many examples each device will process per batch +per_device_batch_size = batch_size // jax.device_count() + +# make a data-parallel mesh and sharding +mesh = jax.make_mesh((jax.device_count(),), ('batch')) +sharding = NamedSharding(mesh, P('batch')) + +# our "data loader". each process loads a different set of "examples". +process_batch = np.random.rand(per_process_batch_size, 2048, 42) + +# assemble a global array containing the per-process batches from all processes +global_batch = jax.make_array_from_process_local_data(sharding, process_batch) + +# sanity check that everything got sharded correctly +assert global_batch.shape[0] == batch_size +assert process_batch.shape[0] == per_process_batch_size +assert global_batch.addressable_shards[0].data.shape[0] == per_device_batch_size +``` + +{func}`jax.make_array_from_single_device_arrays` is the most general way to +build a process-spanning array. It's often used after performing +{func}`jax.device_put`s to send each device its required data. This is the +lowest-level option, since all data movement is performed manually (via e.g. +{func}`jax.device_put`). Here's an example: + +```python +shape = (jax.process_count(), jax.local_device_count()) +mesh = jax.make_mesh(shape, ('i', 'j')) +sharding = NamedSharding(mesh, P('i', 'j')) + +# manually create per-device data equivalent to np.arange(jax.device_count()) +# i.e. each device will get a single scalar value from 0..N +local_arrays = [ + jax.device_put( + jnp.array([[jax.process_index() * jax.local_device_count() + i]]), + device) + for i, device in enumerate(jax.local_devices()) +] + +# assemble a global array from the local_arrays across all processes +global_array = jax.make_array_from_single_device_arrays( + shape=shape, + sharding=sharding, + arrays=local_arrays) + +# sanity check +assert (np.all( + jax.experimental.multihost_utils.process_allgather(global_array) == + np.arange(jax.device_count()).reshape(global_array.shape))) +``` + +[cloud_tpu]: https://cloud.google.com/tpu?hl=en +[distributed_arrays]: https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html +[gpu_machines]: https://cloud.google.com/compute/docs/gpus +[unified_sharding]: https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html +[google-artifact-tpu]: https://console.cloud.google.com/artifacts/docker/cloud-tpu-images/us/jax-ai-image/tpu +[google-artifact-gpu]: https://console.cloud.google.com/artifacts/docker/deeplearning-images/us-central1/jax-ai-image/gpu +[nvidia-ngc]: https://catalog.ngc.nvidia.com/orgs/nvidia/containers/jax +[nvidia-jax-toolbox]: https://github.com/NVIDIA/JAX-Toolbox +[k8s-jobset]: https://github.com/kubernetes-sigs/jobset +[k8s-indexed-job]: https://kubernetes.io/docs/concepts/workloads/controllers/job/#parallel-jobs +[rbac-svc-acct]: https://github.com/jax-ml/jax/blob/main/examples/k8s/svc-acct.yaml +[minimal-jobset]: https://github.com/jax-ml/jax/blob/main/examples/k8s/example.yaml diff --git a/docs/notebooks/Common_Gotchas_in_JAX.ipynb b/docs/notebooks/Common_Gotchas_in_JAX.ipynb index a1435c4e557e..38d2a0e84383 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.ipynb +++ b/docs/notebooks/Common_Gotchas_in_JAX.ipynb @@ -307,7 +307,7 @@ "id": "go3L4x3w4-9p" }, "source": [ - "If we try to update a JAX device array in-place, however, we get an __error__! (☉_☉)" + "If we try to do in-place indexed updating on a `jax.Array`, however, we get an __error__! (☉_☉)" ] }, { @@ -346,7 +346,7 @@ "evalue": "ignored", "output_type": "error", "traceback": [ - "\u001b[0;31mTypeError\u001b[0m\u001b[0;31m:\u001b[0m '' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html\n" + "\u001b[0;31mTypeError\u001b[0m\u001b[0;31m:\u001b[0m '' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html\n" ] } ], @@ -357,6 +357,45 @@ "jax_array[1, :] = 1.0" ] }, + { + "cell_type": "markdown", + "id": "8f520bec", + "metadata": {}, + "source": [ + "And if we try to do `__iadd__`-style in-place updating, we get __different behavior than NumPy__! (☉_☉) (☉_☉)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20fbed45", + "metadata": {}, + "outputs": [], + "source": [ + "jax_array = jnp.array([10, 20])\n", + "jax_array_new = jax_array\n", + "jax_array_new += 10\n", + "print(jax_array_new) # `jax_array_new` is rebound to a new value [20, 30], but...\n", + "print(jax_array) # the original value is unmodified as [10, 20] !\n", + "\n", + "numpy_array = np.array([10, 20])\n", + "numpy_array_new = numpy_array\n", + "numpy_array_new += 10\n", + "print(numpy_array_new) # `numpy_array_new is numpy_array`, and it was updated\n", + "print(numpy_array) # in-place, so both are [20, 30] !" + ] + }, + { + "cell_type": "markdown", + "id": "2604e220", + "metadata": {}, + "source": [ + "That's because NumPy defines `__iadd__` to perform in-place mutation. In\n", + "contrast, `jax.Array` doesn't define an `__iadd__`, so Python treats\n", + "`jax_array_new += 10` as syntactic sugar for `jax_array_new = jax_array_new +\n", + "10`, rebinding the variable without mutating any arrays." + ] + }, { "cell_type": "markdown", "metadata": { @@ -365,7 +404,7 @@ "source": [ "Allowing mutation of variables in-place makes program analysis and transformation difficult. JAX requires that programs are pure functions.\n", "\n", - "Instead, JAX offers a _functional_ array update using the [`.at` property on JAX arrays](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at)." + "Instead, JAX offers a _functional_ array update using the [`.at` property on JAX arrays](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at)." ] }, { @@ -415,6 +454,7 @@ } ], "source": [ + "jax_array = jnp.zeros((3,3), dtype=jnp.float32)\n", "updated_array = jax_array.at[1, :].set(1.0)\n", "print(\"updated array:\\n\", updated_array)" ] @@ -521,7 +561,378 @@ "id": "sTjJ3WuaDyqU" }, "source": [ - "For more details on indexed array updates, see the [documentation for the `.at` property](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at)." + "For more details on indexed array updates, see the [documentation for the `.at` property](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(jax-jit-class-methods)=\n", + "## 🔪 Using `jax.jit` with class methods\n", + "\n", + "Most examples of [`jax.jit`](https://docs.jax.dev/en/latest/_autosummary/jax.jit.html) concern decorating stand-alone Python functions, but decorating a method within a class introduces some complication. For example, consider the following simple class, where we've used a standard `jax.jit` annotation on a method:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "import jax.numpy as jnp\n", + "from jax import jit\n", + "\n", + "class CustomClass:\n", + " def __init__(self, x: jnp.ndarray, mul: bool):\n", + " self.x = x\n", + " self.mul = mul\n", + "\n", + " @jit # <---- How to do this correctly?\n", + " def calc(self, y):\n", + " if self.mul:\n", + " return self.x * y\n", + " return y" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "However, this approach will result in an error when you attempt to call this method:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "tags": [ + "raises-exception" + ] + }, + "outputs": [ + { + "ename": "TypeError", + "evalue": "Error interpreting argument to as an abstract array. The problematic value is of type and was passed to the function at path self.\nThis typically means that a jit-wrapped function was called with a non-array argument, and this argument was not marked as static using the static_argnums or static_argnames parameters of jax.jit.", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mTypeError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[8]\u001b[39m\u001b[32m, line 2\u001b[39m\n\u001b[32m 1\u001b[39m c = CustomClass(\u001b[32m2\u001b[39m, \u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[32m----> \u001b[39m\u001b[32m2\u001b[39m \u001b[43mc\u001b[49m\u001b[43m.\u001b[49m\u001b[43mcalc\u001b[49m\u001b[43m(\u001b[49m\u001b[32;43m3\u001b[39;49m\u001b[43m)\u001b[49m\n", + " \u001b[31m[... skipping hidden 5 frame]\u001b[39m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/.local/share/mamba/envs/jax-dev/lib/python3.12/site-packages/jax/_src/pjit.py:659\u001b[39m, in \u001b[36m_infer_input_type\u001b[39m\u001b[34m(fun, dbg_fn, explicit_args)\u001b[39m\n\u001b[32m 657\u001b[39m dbg = dbg_fn()\n\u001b[32m 658\u001b[39m arg_description = \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mpath \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mdbg.arg_names[i]\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mif\u001b[39;00m\u001b[38;5;250m \u001b[39mdbg.arg_names\u001b[38;5;250m \u001b[39m\u001b[38;5;129;01mis\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;129;01mnot\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01melse\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[33m'\u001b[39m\u001b[33munknown\u001b[39m\u001b[33m'\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m \u001b[38;5;66;03m# pytype: disable=name-error\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m659\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\n\u001b[32m 660\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mError interpreting argument to \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfun\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m as an abstract array.\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 661\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33m The problematic value is of type \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(x)\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m and was passed to\u001b[39m\u001b[33m\"\u001b[39m \u001b[38;5;66;03m# pytype: disable=name-error\u001b[39;00m\n\u001b[32m 662\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33m the function at \u001b[39m\u001b[38;5;132;01m{\u001b[39;00marg_description\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m.\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[33m\"\u001b[39m\n\u001b[32m 663\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mThis typically means that a jit-wrapped function was called with a non-array\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 664\u001b[39m \u001b[33m\"\u001b[39m\u001b[33m argument, and this argument was not marked as static using the\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 665\u001b[39m \u001b[33m\"\u001b[39m\u001b[33m static_argnums or static_argnames parameters of jax.jit.\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 666\u001b[39m ) \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 667\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m config.mutable_array_checks.value:\n\u001b[32m 668\u001b[39m check_no_aliased_ref_args(dbg_fn, avals, explicit_args)\n", + "\u001b[31mTypeError\u001b[39m: Error interpreting argument to as an abstract array. The problematic value is of type and was passed to the function at path self.\nThis typically means that a jit-wrapped function was called with a non-array argument, and this argument was not marked as static using the static_argnums or static_argnames parameters of jax.jit." + ] + } + ], + "source": [ + "c = CustomClass(2, True)\n", + "c.calc(3)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The problem is that the first argument to the function is `self`, which has type `CustomClass`, and JAX does not know how to handle this type. There are three basic strategies we might use in this case, and we'll discuss them below." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Strategy 1: JIT-compiled helper function\n", + "\n", + "The most straightforward approach is to create a helper function external to the class that can be JIT-decorated in the normal way. For example:" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "from functools import partial\n", + "\n", + "class CustomClass:\n", + " def __init__(self, x: jnp.ndarray, mul: bool):\n", + " self.x = x\n", + " self.mul = mul\n", + "\n", + " def calc(self, y):\n", + " return _calc(self.mul, self.x, y)\n", + "\n", + "@partial(jit, static_argnums=0)\n", + "def _calc(mul, x, y):\n", + " if mul:\n", + " return x * y\n", + " return y" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The result will work as expected:" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "6\n" + ] + } + ], + "source": [ + "c = CustomClass(2, True)\n", + "print(c.calc(3))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The benefit of such an approach is that it is simple, explicit, and it avoids the need to teach JAX how to handle objects of type `CustomClass`. However, you may wish to keep all the method logic in the same place." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Strategy 2: Marking `self` as static\n", + "\n", + "Another common pattern is to use `static_argnums` to mark the `self` argument as static. But this must be done with care to avoid unexpected results. You may be tempted to simply do this:" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "class CustomClass:\n", + " def __init__(self, x: jnp.ndarray, mul: bool):\n", + " self.x = x\n", + " self.mul = mul\n", + "\n", + " # WARNING: this example is broken, as we'll see below. Don't copy & paste!\n", + " @partial(jit, static_argnums=0)\n", + " def calc(self, y):\n", + " if self.mul:\n", + " return self.x * y\n", + " return y" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If you call the method, it will no longer raise an error:" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "6\n" + ] + } + ], + "source": [ + "c = CustomClass(2, True)\n", + "print(c.calc(3))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "However, there is a catch: if you mutate the object after the first method call, the subsequent method call may return an incorrect result:" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "6\n" + ] + } + ], + "source": [ + "c.mul = False\n", + "print(c.calc(3)) # Should print 3" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Why is this? When you mark an object as static, it will effectively be used as a dictionary key in JIT's internal compilation cache, meaning its hash (i.e. `hash(obj)`) equality (i.e. `obj1 == obj2`) and object identity (i.e. `obj1 is obj2`) will be assumed to have consistent behavior. The default `__hash__` for a custom object is its object ID, and so JAX has no way of knowing that a mutated object should trigger a re-compilation.\n", + "\n", + "You can partially address this by defining an appropriate `__hash__` and `__eq__` methods for your object; for example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class CustomClass:\n", + " def __init__(self, x: jnp.ndarray, mul: bool):\n", + " self.x = x\n", + " self.mul = mul\n", + "\n", + " @partial(jit, static_argnums=0)\n", + " def calc(self, y):\n", + " if self.mul:\n", + " return self.x * y\n", + " return y\n", + "\n", + " def __hash__(self):\n", + " return hash((self.x, self.mul))\n", + "\n", + " def __eq__(self, other):\n", + " return (isinstance(other, CustomClass) and\n", + " (self.x, self.mul) == (other.x, other.mul))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(see the [`object.__hash__`](https://docs.python.org/3/reference/datamodel.html#object.__hash__) documentation for more discussion of the requirements\n", + "when overriding `__hash__`).\n", + "\n", + "This should work correctly with JIT and other transforms **so long as you never mutate your object**. Mutations of objects used as hash keys lead to several subtle problems, which is why for example mutable Python containers (e.g. [`dict`](https://docs.python.org/3/library/stdtypes.html#dict), [`list`](https://docs.python.org/3/library/stdtypes.html#list)) don't define `__hash__`, while their immutable counterparts (e.g. [`tuple`](https://docs.python.org/3/library/stdtypes.html#tuple)) do.\n", + "\n", + "If your class relies on in-place mutations (such as setting `self.attr = ...` within its methods), then your object is not really \"static\" and marking it as such may lead to problems. Fortunately, there's another option for this case." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Strategy 3: Making `CustomClass` a PyTree\n", + "\n", + "The most flexible approach to correctly JIT-compiling a class method is to register the type as a custom PyTree object; see [Custom pytree nodes](https://docs.jax.dev/en/latest/custom_pytrees.html#pytrees-custom-pytree-nodes). This lets you specify exactly which components of the class should be treated as static and which should be\n", + "treated as dynamic. Here's how it might look:" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "class CustomClass:\n", + " def __init__(self, x: jnp.ndarray, mul: bool):\n", + " self.x = x\n", + " self.mul = mul\n", + "\n", + " @jit\n", + " def calc(self, y):\n", + " if self.mul:\n", + " return self.x * y\n", + " return y\n", + "\n", + " def _tree_flatten(self):\n", + " children = (self.x,) # arrays / dynamic values\n", + " aux_data = {'mul': self.mul} # static values\n", + " return (children, aux_data)\n", + "\n", + " @classmethod\n", + " def _tree_unflatten(cls, aux_data, children):\n", + " return cls(*children, **aux_data)\n", + "\n", + "from jax import tree_util\n", + "tree_util.register_pytree_node(CustomClass,\n", + " CustomClass._tree_flatten,\n", + " CustomClass._tree_unflatten)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This is certainly more involved, but it solves all the issues associated with the simpler approaches used above:" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "6\n" + ] + } + ], + "source": [ + "c = CustomClass(2, True)\n", + "print(c.calc(3))" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3\n" + ] + } + ], + "source": [ + "c.mul = False # mutation is detected\n", + "print(c.calc(3))" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "6\n" + ] + } + ], + "source": [ + "c = CustomClass(jnp.array(2), True) # non-hashable x is supported\n", + "print(c.calc(3))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "So long as your `tree_flatten` and `tree_unflatten` functions correctly handle all relevant attributes in the class, you should be able to use objects of this type directly as arguments to JIT-compiled functions, without any special annotations." ] }, { @@ -604,7 +1015,7 @@ "id": "NAcXJNAcDi_v" }, "source": [ - "If you would like finer-grained control over the behavior for out-of-bound indices, you can use the optional parameters of [`ndarray.at`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html); for example:" + "If you would like finer-grained control over the behavior for out-of-bound indices, you can use the optional parameters of [`ndarray.at`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html); for example:" ] }, { @@ -971,7 +1382,7 @@ "evalue": "ignored", "output_type": "error", "traceback": [ - "\u001b[0;31mNonConcreteBooleanIndexError\u001b[0m\u001b[0;31m:\u001b[0m Array boolean indices must be concrete; got ShapedArray(bool[5])\n\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError\n" + "\u001b[0;31mNonConcreteBooleanIndexError\u001b[0m\u001b[0;31m:\u001b[0m Array boolean indices must be concrete; got ShapedArray(bool[5])\n\nSee https://docs.jax.dev/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError\n" ] } ], @@ -1030,163 +1441,9 @@ "id": "DKTMw6tRZyK2" }, "source": [ - "## 🔪 NaNs" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ncS0NI4jZrwy" - }, - "source": [ - "### Debugging NaNs\n", + "## 🔪 Debugging NaNs and Infs\n", "\n", - "If you want to trace where NaNs are occurring in your functions or gradients, you can turn on the NaN-checker by:\n", - "\n", - "* setting the `JAX_DEBUG_NANS=True` environment variable;\n", - "\n", - "* adding `jax.config.update(\"jax_debug_nans\", True)` near the top of your main file;\n", - "\n", - "* adding `jax.config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;\n", - "\n", - "This will cause computations to error-out immediately on production of a NaN. Switching this option on adds a nan check to every floating point type value produced by XLA. That means values are pulled back to the host and checked as ndarrays for every primitive operation not under an `@jit`. For code under an `@jit`, the output of every `@jit` function is checked and if a nan is present it will re-run the function in de-optimized op-by-op mode, effectively removing one level of `@jit` at a time.\n", - "\n", - "There could be tricky situations that arise, like nans that only occur under a `@jit` but don't get produced in de-optimized mode. In that case you'll see a warning message print out but your code will continue to execute.\n", - "\n", - "If the nans are being produced in the backward pass of a gradient evaluation, when an exception is raised several frames up in the stack trace you will be in the backward_pass function, which is essentially a simple jaxpr interpreter that walks the sequence of primitive operations in reverse. In the example below, we started an ipython repl with the command line `env JAX_DEBUG_NANS=True ipython`, then ran this:" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "p6ZtDHPbBa_W" - }, - "source": [ - "```\n", - "In [1]: import jax.numpy as jnp\n", - "\n", - "In [2]: jnp.divide(0., 0.)\n", - "---------------------------------------------------------------------------\n", - "FloatingPointError Traceback (most recent call last)\n", - " in ()\n", - "----> 1 jnp.divide(0., 0.)\n", - "\n", - ".../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)\n", - " 343 return floor_divide(x1, x2)\n", - " 344 else:\n", - "--> 345 return true_divide(x1, x2)\n", - " 346\n", - " 347\n", - "\n", - ".../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)\n", - " 332 x1, x2 = _promote_shapes(x1, x2)\n", - " 333 return lax.div(lax.convert_element_type(x1, result_dtype),\n", - "--> 334 lax.convert_element_type(x2, result_dtype))\n", - " 335\n", - " 336\n", - "\n", - ".../jax/jax/lax.pyc in div(x, y)\n", - " 244 def div(x, y):\n", - " 245 r\"\"\"Elementwise division: :math:`x \\over y`.\"\"\"\n", - "--> 246 return div_p.bind(x, y)\n", - " 247\n", - " 248 def rem(x, y):\n", - "\n", - "... stack trace ...\n", - "\n", - ".../jax/jax/interpreters/xla.pyc in handle_result(device_buffer)\n", - " 103 py_val = device_buffer.to_py()\n", - " 104 if np.any(np.isnan(py_val)):\n", - "--> 105 raise FloatingPointError(\"invalid value\")\n", - " 106 else:\n", - " 107 return Array(device_buffer, *result_shape)\n", - "\n", - "FloatingPointError: invalid value\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_NCnVt_GBa_W" - }, - "source": [ - "The nan generated was caught. By running `%debug`, we can get a post-mortem debugger. This also works with functions under `@jit`, as the example below shows." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "pf8RF6eiBa_W" - }, - "source": [ - "```\n", - "In [4]: from jax import jit\n", - "\n", - "In [5]: @jit\n", - " ...: def f(x, y):\n", - " ...: a = x * y\n", - " ...: b = (x + y) / (x - y)\n", - " ...: c = a + 2\n", - " ...: return a + b * c\n", - " ...:\n", - "\n", - "In [6]: x = jnp.array([2., 0.])\n", - "\n", - "In [7]: y = jnp.array([3., 0.])\n", - "\n", - "In [8]: f(x, y)\n", - "Invalid value encountered in the output of a jit function. Calling the de-optimized version.\n", - "---------------------------------------------------------------------------\n", - "FloatingPointError Traceback (most recent call last)\n", - " in ()\n", - "----> 1 f(x, y)\n", - "\n", - " ... stack trace ...\n", - "\n", - " in f(x, y)\n", - " 2 def f(x, y):\n", - " 3 a = x * y\n", - "----> 4 b = (x + y) / (x - y)\n", - " 5 c = a + 2\n", - " 6 return a + b * c\n", - "\n", - ".../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)\n", - " 343 return floor_divide(x1, x2)\n", - " 344 else:\n", - "--> 345 return true_divide(x1, x2)\n", - " 346\n", - " 347\n", - "\n", - ".../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)\n", - " 332 x1, x2 = _promote_shapes(x1, x2)\n", - " 333 return lax.div(lax.convert_element_type(x1, result_dtype),\n", - "--> 334 lax.convert_element_type(x2, result_dtype))\n", - " 335\n", - " 336\n", - "\n", - ".../jax/jax/lax.pyc in div(x, y)\n", - " 244 def div(x, y):\n", - " 245 r\"\"\"Elementwise division: :math:`x \\over y`.\"\"\"\n", - "--> 246 return div_p.bind(x, y)\n", - " 247\n", - " 248 def rem(x, y):\n", - "\n", - " ... stack trace ...\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "6ur2yArDBa_W" - }, - "source": [ - "When this code sees a nan in the output of an `@jit` function, it calls into the de-optimized code, so we still get a clear stack trace. And we can run a post-mortem debugger with `%debug` to inspect all the values to figure out the error.\n", - "\n", - "⚠️ You shouldn't have the NaN-checker on if you're not debugging, as it can introduce lots of device-host round-trips and performance regressions!\n", - "\n", - "⚠️ The NaN-checker doesn't work with `pmap`. To debug nans in `pmap` code, one thing to try is replacing `pmap` with `vmap`." + "Use the `jax_debug_nans` and `jax_debug_infs` flags to find the source of NaN/Inf values in functions and gradients. See {ref}`debugging-flags`." ] }, { @@ -1296,8 +1553,8 @@ "While `jax.numpy` makes every attempt to replicate the behavior of numpy's API, there do exist corner cases where the behaviors differ.\n", "Many such cases are discussed in detail in the sections above; here we list several other known places where the APIs diverge.\n", "\n", - "- For binary operations, JAX's type promotion rules differ somewhat from those used by NumPy. See [Type Promotion Semantics](https://jax.readthedocs.io/en/latest/type_promotion.html) for more details.\n", - "- When performing unsafe type casts (i.e. casts in which the target dtype cannot represent the input value), JAX's behavior may be backend dependent, and in general may diverge from NumPy's behavior. Numpy allows control over the result in these scenarios via the `casting` argument (see [`np.ndarray.astype`](https://numpy.org/devdocs/reference/generated/numpy.ndarray.astype.html)); JAX does not provide any such configuration, instead directly inheriting the behavior of [XLA:ConvertElementType](https://www.tensorflow.org/xla/operation_semantics#convertelementtype).\n", + "- For binary operations, JAX's type promotion rules differ somewhat from those used by NumPy. See [Type Promotion Semantics](https://docs.jax.dev/en/latest/type_promotion.html) for more details.\n", + "- When performing unsafe type casts (i.e. casts in which the target dtype cannot represent the input value), JAX's behavior may be backend dependent, and in general may diverge from NumPy's behavior. Numpy allows control over the result in these scenarios via the `casting` argument (see [`np.ndarray.astype`](https://numpy.org/devdocs/reference/generated/numpy.ndarray.astype.html)); JAX does not provide any such configuration, instead directly inheriting the behavior of [XLA:ConvertElementType](https://www.openxla.org/xla/operation_semantics#convertelementtype).\n", "\n", " Here is an example of an unsafe cast with differing results between NumPy and JAX:\n", " ```python\n", @@ -1345,7 +1602,7 @@ "formats": "ipynb,md:myst" }, "kernelspec": { - "display_name": "Python 3", + "display_name": "jax-dev", "language": "python", "name": "python3" }, @@ -1359,15 +1616,10 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.2 (v3.8.2:7b3ab5921f, Feb 24 2020, 17:52:18) \n[Clang 6.0 (clang-600.0.57)]" + "version": "3.12.12" }, "mystnb": { "render_error_lexer": "none" - }, - "vscode": { - "interpreter": { - "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49" - } } }, "nbformat": 4, diff --git a/docs/notebooks/Common_Gotchas_in_JAX.md b/docs/notebooks/Common_Gotchas_in_JAX.md index 80ab69be1ed8..40a40640d8ef 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.md +++ b/docs/notebooks/Common_Gotchas_in_JAX.md @@ -7,7 +7,7 @@ jupytext: format_version: 0.13 jupytext_version: 1.16.4 kernelspec: - display_name: Python 3 + display_name: jax-dev language: python name: python3 --- @@ -177,7 +177,7 @@ print(numpy_array) +++ {"id": "go3L4x3w4-9p"} -If we try to update a JAX device array in-place, however, we get an __error__! (☉_☉) +If we try to do in-place indexed updating on a `jax.Array`, however, we get an __error__! (☉_☉) ```{code-cell} ipython3 :id: iOscaa_GecEK @@ -197,11 +197,32 @@ jax_array = jnp.zeros((3,3), dtype=jnp.float32) jax_array[1, :] = 1.0 ``` +And if we try to do `__iadd__`-style in-place updating, we get __different behavior than NumPy__! (☉_☉) (☉_☉) + +```{code-cell} ipython3 +jax_array = jnp.array([10, 20]) +jax_array_new = jax_array +jax_array_new += 10 +print(jax_array_new) # `jax_array_new` is rebound to a new value [20, 30], but... +print(jax_array) # the original value is unmodified as [10, 20] ! + +numpy_array = np.array([10, 20]) +numpy_array_new = numpy_array +numpy_array_new += 10 +print(numpy_array_new) # `numpy_array_new is numpy_array`, and it was updated +print(numpy_array) # in-place, so both are [20, 30] ! +``` + +That's because NumPy defines `__iadd__` to perform in-place mutation. In +contrast, `jax.Array` doesn't define an `__iadd__`, so Python treats +`jax_array_new += 10` as syntactic sugar for `jax_array_new = jax_array_new + +10`, rebinding the variable without mutating any arrays. + +++ {"id": "7mo76sS25Wco"} Allowing mutation of variables in-place makes program analysis and transformation difficult. JAX requires that programs are pure functions. -Instead, JAX offers a _functional_ array update using the [`.at` property on JAX arrays](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at). +Instead, JAX offers a _functional_ array update using the [`.at` property on JAX arrays](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at). +++ {"id": "hfloZ1QXCS_J"} @@ -219,6 +240,7 @@ For example, the update above can be written as: :id: PBGI-HIeCP_s :outputId: de13f19a-2066-4df1-d503-764c34585529 +jax_array = jnp.zeros((3,3), dtype=jnp.float32) updated_array = jax_array.at[1, :].set(1.0) print("updated array:\n", updated_array) ``` @@ -261,7 +283,192 @@ print(new_jax_array) +++ {"id": "sTjJ3WuaDyqU"} -For more details on indexed array updates, see the [documentation for the `.at` property](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at). +For more details on indexed array updates, see the [documentation for the `.at` property](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at). + ++++ + +(jax-jit-class-methods)= +## 🔪 Using `jax.jit` with class methods + +Most examples of [`jax.jit`](https://docs.jax.dev/en/latest/_autosummary/jax.jit.html) concern decorating stand-alone Python functions, but decorating a method within a class introduces some complication. For example, consider the following simple class, where we've used a standard `jax.jit` annotation on a method: + +```{code-cell} ipython3 +import jax.numpy as jnp +from jax import jit + +class CustomClass: + def __init__(self, x: jnp.ndarray, mul: bool): + self.x = x + self.mul = mul + + @jit # <---- How to do this correctly? + def calc(self, y): + if self.mul: + return self.x * y + return y +``` + +However, this approach will result in an error when you attempt to call this method: + +```{code-cell} ipython3 +:tags: [raises-exception] + +c = CustomClass(2, True) +c.calc(3) +``` + +The problem is that the first argument to the function is `self`, which has type `CustomClass`, and JAX does not know how to handle this type. There are three basic strategies we might use in this case, and we'll discuss them below. + ++++ + +### Strategy 1: JIT-compiled helper function + +The most straightforward approach is to create a helper function external to the class that can be JIT-decorated in the normal way. For example: + +```{code-cell} ipython3 +from functools import partial + +class CustomClass: + def __init__(self, x: jnp.ndarray, mul: bool): + self.x = x + self.mul = mul + + def calc(self, y): + return _calc(self.mul, self.x, y) + +@partial(jit, static_argnums=0) +def _calc(mul, x, y): + if mul: + return x * y + return y +``` + +The result will work as expected: + +```{code-cell} ipython3 +c = CustomClass(2, True) +print(c.calc(3)) +``` + +The benefit of such an approach is that it is simple, explicit, and it avoids the need to teach JAX how to handle objects of type `CustomClass`. However, you may wish to keep all the method logic in the same place. + ++++ + +### Strategy 2: Marking `self` as static + +Another common pattern is to use `static_argnums` to mark the `self` argument as static. But this must be done with care to avoid unexpected results. You may be tempted to simply do this: + +```{code-cell} ipython3 +class CustomClass: + def __init__(self, x: jnp.ndarray, mul: bool): + self.x = x + self.mul = mul + + # WARNING: this example is broken, as we'll see below. Don't copy & paste! + @partial(jit, static_argnums=0) + def calc(self, y): + if self.mul: + return self.x * y + return y +``` + +If you call the method, it will no longer raise an error: + +```{code-cell} ipython3 +c = CustomClass(2, True) +print(c.calc(3)) +``` + +However, there is a catch: if you mutate the object after the first method call, the subsequent method call may return an incorrect result: + +```{code-cell} ipython3 +c.mul = False +print(c.calc(3)) # Should print 3 +``` + +Why is this? When you mark an object as static, it will effectively be used as a dictionary key in JIT's internal compilation cache, meaning its hash (i.e. `hash(obj)`) equality (i.e. `obj1 == obj2`) and object identity (i.e. `obj1 is obj2`) will be assumed to have consistent behavior. The default `__hash__` for a custom object is its object ID, and so JAX has no way of knowing that a mutated object should trigger a re-compilation. + +You can partially address this by defining an appropriate `__hash__` and `__eq__` methods for your object; for example: + +```{code-cell} ipython3 +class CustomClass: + def __init__(self, x: jnp.ndarray, mul: bool): + self.x = x + self.mul = mul + + @partial(jit, static_argnums=0) + def calc(self, y): + if self.mul: + return self.x * y + return y + + def __hash__(self): + return hash((self.x, self.mul)) + + def __eq__(self, other): + return (isinstance(other, CustomClass) and + (self.x, self.mul) == (other.x, other.mul)) +``` + +(see the [`object.__hash__`](https://docs.python.org/3/reference/datamodel.html#object.__hash__) documentation for more discussion of the requirements +when overriding `__hash__`). + +This should work correctly with JIT and other transforms **so long as you never mutate your object**. Mutations of objects used as hash keys lead to several subtle problems, which is why for example mutable Python containers (e.g. [`dict`](https://docs.python.org/3/library/stdtypes.html#dict), [`list`](https://docs.python.org/3/library/stdtypes.html#list)) don't define `__hash__`, while their immutable counterparts (e.g. [`tuple`](https://docs.python.org/3/library/stdtypes.html#tuple)) do. + +If your class relies on in-place mutations (such as setting `self.attr = ...` within its methods), then your object is not really "static" and marking it as such may lead to problems. Fortunately, there's another option for this case. + ++++ + +### Strategy 3: Making `CustomClass` a PyTree + +The most flexible approach to correctly JIT-compiling a class method is to register the type as a custom PyTree object; see [Custom pytree nodes](https://docs.jax.dev/en/latest/custom_pytrees.html#pytrees-custom-pytree-nodes). This lets you specify exactly which components of the class should be treated as static and which should be +treated as dynamic. Here's how it might look: + +```{code-cell} ipython3 +class CustomClass: + def __init__(self, x: jnp.ndarray, mul: bool): + self.x = x + self.mul = mul + + @jit + def calc(self, y): + if self.mul: + return self.x * y + return y + + def _tree_flatten(self): + children = (self.x,) # arrays / dynamic values + aux_data = {'mul': self.mul} # static values + return (children, aux_data) + + @classmethod + def _tree_unflatten(cls, aux_data, children): + return cls(*children, **aux_data) + +from jax import tree_util +tree_util.register_pytree_node(CustomClass, + CustomClass._tree_flatten, + CustomClass._tree_unflatten) +``` + +This is certainly more involved, but it solves all the issues associated with the simpler approaches used above: + +```{code-cell} ipython3 +c = CustomClass(2, True) +print(c.calc(3)) +``` + +```{code-cell} ipython3 +c.mul = False # mutation is detected +print(c.calc(3)) +``` + +```{code-cell} ipython3 +c = CustomClass(jnp.array(2), True) # non-hashable x is supported +print(c.calc(3)) +``` + +So long as your `tree_flatten` and `tree_unflatten` functions correctly handle all relevant attributes in the class, you should be able to use objects of this type directly as arguments to JIT-compiled functions, without any special annotations. +++ {"id": "oZ_jE2WAypdL"} @@ -292,7 +499,7 @@ jnp.arange(10)[11] +++ {"id": "NAcXJNAcDi_v"} -If you would like finer-grained control over the behavior for out-of-bound indices, you can use the optional parameters of [`ndarray.at`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html); for example: +If you would like finer-grained control over the behavior for out-of-bound indices, you can use the optional parameters of [`ndarray.at`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html); for example: ```{code-cell} ipython3 :id: -0-MaFddO-xy @@ -459,138 +666,9 @@ Similar tricks can be played in other situations where dynamically-shaped arrays +++ {"id": "DKTMw6tRZyK2"} -## 🔪 NaNs - -+++ {"id": "ncS0NI4jZrwy"} - -### Debugging NaNs - -If you want to trace where NaNs are occurring in your functions or gradients, you can turn on the NaN-checker by: - -* setting the `JAX_DEBUG_NANS=True` environment variable; - -* adding `jax.config.update("jax_debug_nans", True)` near the top of your main file; - -* adding `jax.config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`; - -This will cause computations to error-out immediately on production of a NaN. Switching this option on adds a nan check to every floating point type value produced by XLA. That means values are pulled back to the host and checked as ndarrays for every primitive operation not under an `@jit`. For code under an `@jit`, the output of every `@jit` function is checked and if a nan is present it will re-run the function in de-optimized op-by-op mode, effectively removing one level of `@jit` at a time. - -There could be tricky situations that arise, like nans that only occur under a `@jit` but don't get produced in de-optimized mode. In that case you'll see a warning message print out but your code will continue to execute. - -If the nans are being produced in the backward pass of a gradient evaluation, when an exception is raised several frames up in the stack trace you will be in the backward_pass function, which is essentially a simple jaxpr interpreter that walks the sequence of primitive operations in reverse. In the example below, we started an ipython repl with the command line `env JAX_DEBUG_NANS=True ipython`, then ran this: - -+++ {"id": "p6ZtDHPbBa_W"} - -``` -In [1]: import jax.numpy as jnp - -In [2]: jnp.divide(0., 0.) ---------------------------------------------------------------------------- -FloatingPointError Traceback (most recent call last) - in () -----> 1 jnp.divide(0., 0.) - -.../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2) - 343 return floor_divide(x1, x2) - 344 else: ---> 345 return true_divide(x1, x2) - 346 - 347 - -.../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2) - 332 x1, x2 = _promote_shapes(x1, x2) - 333 return lax.div(lax.convert_element_type(x1, result_dtype), ---> 334 lax.convert_element_type(x2, result_dtype)) - 335 - 336 - -.../jax/jax/lax.pyc in div(x, y) - 244 def div(x, y): - 245 r"""Elementwise division: :math:`x \over y`.""" ---> 246 return div_p.bind(x, y) - 247 - 248 def rem(x, y): - -... stack trace ... - -.../jax/jax/interpreters/xla.pyc in handle_result(device_buffer) - 103 py_val = device_buffer.to_py() - 104 if np.any(np.isnan(py_val)): ---> 105 raise FloatingPointError("invalid value") - 106 else: - 107 return Array(device_buffer, *result_shape) - -FloatingPointError: invalid value -``` - -+++ {"id": "_NCnVt_GBa_W"} - -The nan generated was caught. By running `%debug`, we can get a post-mortem debugger. This also works with functions under `@jit`, as the example below shows. - -+++ {"id": "pf8RF6eiBa_W"} - -``` -In [4]: from jax import jit - -In [5]: @jit - ...: def f(x, y): - ...: a = x * y - ...: b = (x + y) / (x - y) - ...: c = a + 2 - ...: return a + b * c - ...: - -In [6]: x = jnp.array([2., 0.]) - -In [7]: y = jnp.array([3., 0.]) - -In [8]: f(x, y) -Invalid value encountered in the output of a jit function. Calling the de-optimized version. ---------------------------------------------------------------------------- -FloatingPointError Traceback (most recent call last) - in () -----> 1 f(x, y) - - ... stack trace ... - - in f(x, y) - 2 def f(x, y): - 3 a = x * y -----> 4 b = (x + y) / (x - y) - 5 c = a + 2 - 6 return a + b * c - -.../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2) - 343 return floor_divide(x1, x2) - 344 else: ---> 345 return true_divide(x1, x2) - 346 - 347 - -.../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2) - 332 x1, x2 = _promote_shapes(x1, x2) - 333 return lax.div(lax.convert_element_type(x1, result_dtype), ---> 334 lax.convert_element_type(x2, result_dtype)) - 335 - 336 - -.../jax/jax/lax.pyc in div(x, y) - 244 def div(x, y): - 245 r"""Elementwise division: :math:`x \over y`.""" ---> 246 return div_p.bind(x, y) - 247 - 248 def rem(x, y): - - ... stack trace ... -``` - -+++ {"id": "6ur2yArDBa_W"} - -When this code sees a nan in the output of an `@jit` function, it calls into the de-optimized code, so we still get a clear stack trace. And we can run a post-mortem debugger with `%debug` to inspect all the values to figure out the error. - -⚠️ You shouldn't have the NaN-checker on if you're not debugging, as it can introduce lots of device-host round-trips and performance regressions! +## 🔪 Debugging NaNs and Infs -⚠️ The NaN-checker doesn't work with `pmap`. To debug nans in `pmap` code, one thing to try is replacing `pmap` with `vmap`. +Use the `jax_debug_nans` and `jax_debug_infs` flags to find the source of NaN/Inf values in functions and gradients. See {ref}`debugging-flags`. +++ {"id": "YTktlwTTMgFl"} @@ -664,8 +742,8 @@ x.dtype # --> dtype('float64') While `jax.numpy` makes every attempt to replicate the behavior of numpy's API, there do exist corner cases where the behaviors differ. Many such cases are discussed in detail in the sections above; here we list several other known places where the APIs diverge. -- For binary operations, JAX's type promotion rules differ somewhat from those used by NumPy. See [Type Promotion Semantics](https://jax.readthedocs.io/en/latest/type_promotion.html) for more details. -- When performing unsafe type casts (i.e. casts in which the target dtype cannot represent the input value), JAX's behavior may be backend dependent, and in general may diverge from NumPy's behavior. Numpy allows control over the result in these scenarios via the `casting` argument (see [`np.ndarray.astype`](https://numpy.org/devdocs/reference/generated/numpy.ndarray.astype.html)); JAX does not provide any such configuration, instead directly inheriting the behavior of [XLA:ConvertElementType](https://www.tensorflow.org/xla/operation_semantics#convertelementtype). +- For binary operations, JAX's type promotion rules differ somewhat from those used by NumPy. See [Type Promotion Semantics](https://docs.jax.dev/en/latest/type_promotion.html) for more details. +- When performing unsafe type casts (i.e. casts in which the target dtype cannot represent the input value), JAX's behavior may be backend dependent, and in general may diverge from NumPy's behavior. Numpy allows control over the result in these scenarios via the `casting` argument (see [`np.ndarray.astype`](https://numpy.org/devdocs/reference/generated/numpy.ndarray.astype.html)); JAX does not provide any such configuration, instead directly inheriting the behavior of [XLA:ConvertElementType](https://www.openxla.org/xla/operation_semantics#convertelementtype). Here is an example of an unsafe cast with differing results between NumPy and JAX: ```python diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb index e550cbf36da3..27f53cf32778 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb @@ -6,20 +6,21 @@ "id": "LqiaKasFjH82" }, "source": [ - "# Custom derivative rules\n", + "(advanced-autodiff-custom-derivative-rules)=\n", + "# Custom derivative rules for JAX-transformable Python functions\n", "\n", - "\n", + "\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb)\n", "\n", "There are two ways to define differentiation rules in JAX:\n", "\n", - "1. using `jax.custom_jvp` and `jax.custom_vjp` to define custom differentiation rules for Python functions that are already JAX-transformable; and\n", + "1. using [`jax.custom_jvp`](https://docs.jax.dev/en/latest/_autosummary/jax.custom_jvp.html) and [`jax.custom_vjp`](https://docs.jax.dev/en/latest/_autosummary/jax.custom_vjp.html) to define custom differentiation rules for Python functions that are already JAX-transformable; and\n", "2. defining new `core.Primitive` instances along with all their transformation rules, for example to call into functions from other systems like solvers, simulators, or general numerical computing systems.\n", "\n", - "This notebook is about #1. To read instead about #2, see the [notebook on adding primitives](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html).\n", + "This notebook is about #1. To read instead about #2, see the [notebook on adding primitives](https://docs.jax.dev/en/latest/notebooks/How_JAX_primitives_work.html).\n", "\n", - "For an introduction to JAX's automatic differentiation API, see [The Autodiff Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html). This notebook assumes some familiarity with [jax.jvp](https://jax.readthedocs.io/en/latest/jax.html#jax.jvp) and [jax.grad](https://jax.readthedocs.io/en/latest/jax.html#jax.grad), and the mathematical meaning of JVPs and VJPs." + "For an introduction to JAX's automatic differentiation API, see [The Autodiff Cookbook](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html). This notebook assumes some familiarity with [jax.jvp](https://docs.jax.dev/en/latest/_autosummary/jax.jvp.html) and [jax.grad](https://docs.jax.dev/en/latest/_autosummary/jax.grad.html), and the mathematical meaning of JVPs and VJPs." ] }, { @@ -28,16 +29,7 @@ "id": "9Fg3NFNY-2RY" }, "source": [ - "## Summary" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ZgMNRtXyWIW8" - }, - "source": [ - "### Custom JVPs with `jax.custom_jvp`" + "### TL;DR: Custom JVPs with `jax.custom_jvp`" ] }, { @@ -144,7 +136,7 @@ "id": "N2DOGCREWXFj" }, "source": [ - "### Custom VJPs with `jax.custom_vjp`" + "### TL;DR: Custom VJPs with `jax.custom_vjp`" ] }, { @@ -209,7 +201,7 @@ "id": "AR02eyd1GQhC" }, "source": [ - "### Numerical stability\n", + "### Example: Numerical stability\n", "\n", "One application of `jax.custom_jvp` is to improve the numerical stability of differentiation." ] @@ -370,7 +362,7 @@ "\n", "Instead of generating such large and small values, hoping for a cancellation that floats can't always provide, we'd rather just express the derivative function as a more numerically stable program. In particular, we can write a program that more closely evaluates the equal mathematical expression $1 - \\frac{1}{1 + e^x}$, with no cancellation in sight.\n", "\n", - "This problem is interesting because even though our definition of `log1pexp` could already be JAX-differentiated (and transformed with `jit`, `vmap`, ...), we're not happy with the result of applying standard autodiff rules to the primitives comprising `log1pexp` and composing the result. Instead, we'd like to specify how the whole function `log1pexp` should be differentiated, as a unit, and thus arrange those exponentials better.\n", + "This problem is interesting because even though our definition of `log1pexp` could already be JAX-differentiated (and transformed with [`jit`](https://docs.jax.dev/en/latest/_autosummary/jax.jit.html), [`vmap`](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html), ...), we're not happy with the result of applying standard autodiff rules to the primitives comprising `log1pexp` and composing the result. Instead, we'd like to specify how the whole function `log1pexp` should be differentiated, as a unit, and thus arrange those exponentials better.\n", "\n", "This is one application of custom derivative rules for Python functions that are already JAX transformable: specifying how a composite function should be differentiated, while still using its original Python definition for other transformations (like `jit`, `vmap`, ...).\n", "\n", @@ -450,7 +442,7 @@ "id": "9sVUGbGkUOqO" }, "source": [ - "Here's a `defjvps` convenience wrapper to express the same thing:" + "Here's a [`defjvps`](https://docs.jax.dev/en/latest/_autosummary/jax.custom_jvp.defjvps.html) convenience wrapper to express the same thing:" ] }, { @@ -500,7 +492,7 @@ "id": "V9tHAfrSF1N-" }, "source": [ - "### Enforcing a differentiation convention\n", + "### Example: Enforcing a differentiation convention\n", "\n", "A related application is to enforce a differentiation convention, perhaps at a boundary." ] @@ -657,11 +649,11 @@ "id": "7J2A85wbSAmF" }, "source": [ - "### Gradient clipping\n", + "### Example: Gradient clipping\n", "\n", "While in some cases we want to express a mathematical differentiation computation, in other cases we may even want to take a step away from mathematics to adjust the computation autodiff performs. One canonical example is reverse-mode gradient clipping.\n", "\n", - "For gradient clipping, we can use `jnp.clip` together with a `jax.custom_vjp` reverse-mode-only rule:" + "For gradient clipping, we can use [`jnp.clip`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.clip.html) together with a [`jax.custom_vjp`](https://docs.jax.dev/en/latest/_autosummary/jax.custom_vjp.html) reverse-mode-only rule:" ] }, { @@ -782,7 +774,7 @@ "id": "CICQuI86WK4_" }, "source": [ - "### Python debugging\n", + "### Example: Python debugging\n", "\n", "Another application that is motivated by development workflow rather than numerics is to set a `pdb` debugger trace in the backward pass of reverse-mode autodiff." ] @@ -804,7 +796,7 @@ "id": "IC7tEcr1-Fc5" }, "source": [ - "### Implicit function differentiation of iterative implementations\n", + "### Example: Implicit function differentiation of iterative implementations\n", "\n", "This example gets pretty deep in the mathematical weeds!" ] @@ -815,7 +807,7 @@ "id": "szAt97t80hew" }, "source": [ - "Another application for `jax.custom_vjp` is reverse-mode differentiation of functions that are JAX-transformable (by `jit`, `vmap`, ...) but not efficiently JAX-differentiable for some reason, perhaps because they involve `lax.while_loop`. (It's not possible to produce an XLA HLO program that efficiently computes the reverse-mode derivative of an XLA HLO While loop because that would require a program with unbounded memory use, which isn't possible to express in XLA HLO, at least without side-effecting interactions through infeed/outfeed.)\n", + "Another application for `jax.custom_vjp` is reverse-mode differentiation of functions that are JAX-transformable (by `jit`, `vmap`, ...) but not efficiently JAX-differentiable for some reason, perhaps because they involve [`lax.while_loop`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.while_loop.html). (It's not possible to produce an XLA HLO program that efficiently computes the reverse-mode derivative of an XLA HLO While loop because that would require a program with unbounded memory use, which isn't possible to express in XLA HLO, at least without side-effecting interactions through infeed/outfeed.)\n", "\n", "For example, consider this `fixed_point` routine which computes a fixed point by iteratively applying a function in a `while_loop`:" ] @@ -1069,7 +1061,7 @@ "id": "HowvqayEuy-H" }, "source": [ - "A limitation to this approach is that the argument `f` can't close over any values involved in differentiation. That is, you might notice that we kept the parameter `a` explicit in the argument list of `fixed_point`. For this use case, consider using the low-level primitive `lax.custom_root`, which allows for deriviatives in closed-over variables with custom root-finding functions." + "A limitation to this approach is that the argument `f` can't close over any values involved in differentiation. That is, you might notice that we kept the parameter `a` explicit in the argument list of `fixed_point`. For this use case, consider using the low-level primitive `lax.custom_root`, which allows for derivatives in closed-over variables with custom root-finding functions." ] }, { @@ -1089,7 +1081,7 @@ "source": [ "### Use `jax.custom_jvp` to define forward-mode (and, indirectly, reverse-mode) rules\n", "\n", - "Here's a canonical basic example of using `jax.custom_jvp`, where the comments use\n", + "Here's a canonical basic example of using [`jax.custom_jvp`](https://docs.jax.dev/en/latest/_autosummary/jax.custom_jvp.html), where the comments use\n", "[Haskell-like type signatures](https://wiki.haskell.org/Type_signature):" ] }, @@ -1272,7 +1264,7 @@ "id": "YPsPS3rdaGo2" }, "source": [ - "The `defjvps` convenience wrapper lets us define a JVP for each argument separately, and the results are computed separately then summed:" + "The [`defjvps`](https://docs.jax.dev/en/latest/_autosummary/jax.custom_jvp.defjvps.html) convenience wrapper lets us define a JVP for each argument separately, and the results are computed separately then summed:" ] }, { @@ -1656,7 +1648,7 @@ "source": [ "### Use `jax.custom_vjp` to define custom reverse-mode-only rules\n", "\n", - "While `jax.custom_jvp` suffices for controlling both forward- and, via JAX's automatic transposition, reverse-mode differentiation behavior, in some cases we may want to directly control a VJP rule, for example in the latter two example problems presented above. We can do that with `jax.custom_vjp`:" + "While `jax.custom_jvp` suffices for controlling both forward- and, via JAX's automatic transposition, reverse-mode differentiation behavior, in some cases we may want to directly control a VJP rule, for example in the latter two example problems presented above. We can do that with [`jax.custom_vjp`](https://docs.jax.dev/en/latest/_autosummary/jax.custom_vjp.html):" ] }, { @@ -2035,7 +2027,7 @@ "source": [ "### Working with `list` / `tuple` / `dict` containers (and other pytrees)\n", "\n", - "You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any [pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) are permissible, so long as their structures are consistent according to the type constraints. \n", + "You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any [pytrees](https://docs.jax.dev/en/latest/pytrees.html) are permissible, so long as their structures are consistent according to the type constraints. \n", "\n", "Here's a contrived example with `jax.custom_jvp`:" ] @@ -2200,7 +2192,7 @@ "id": "JKTNivxbmKWO" }, "source": [ - "### Handling non-differentiable arguments" + "### Handling non-differentiable arguments" ] }, { diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.md b/docs/notebooks/Custom_derivative_rules_for_Python_code.md index 8a63f142693e..ccdc709bd48b 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.md +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.md @@ -13,28 +13,25 @@ kernelspec: +++ {"id": "LqiaKasFjH82"} -# Custom derivative rules +(advanced-autodiff-custom-derivative-rules)= +# Custom derivative rules for JAX-transformable Python functions - + [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) There are two ways to define differentiation rules in JAX: -1. using `jax.custom_jvp` and `jax.custom_vjp` to define custom differentiation rules for Python functions that are already JAX-transformable; and +1. using [`jax.custom_jvp`](https://docs.jax.dev/en/latest/_autosummary/jax.custom_jvp.html) and [`jax.custom_vjp`](https://docs.jax.dev/en/latest/_autosummary/jax.custom_vjp.html) to define custom differentiation rules for Python functions that are already JAX-transformable; and 2. defining new `core.Primitive` instances along with all their transformation rules, for example to call into functions from other systems like solvers, simulators, or general numerical computing systems. -This notebook is about #1. To read instead about #2, see the [notebook on adding primitives](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html). +This notebook is about #1. To read instead about #2, see the [notebook on adding primitives](https://docs.jax.dev/en/latest/notebooks/How_JAX_primitives_work.html). -For an introduction to JAX's automatic differentiation API, see [The Autodiff Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html). This notebook assumes some familiarity with [jax.jvp](https://jax.readthedocs.io/en/latest/jax.html#jax.jvp) and [jax.grad](https://jax.readthedocs.io/en/latest/jax.html#jax.grad), and the mathematical meaning of JVPs and VJPs. +For an introduction to JAX's automatic differentiation API, see [The Autodiff Cookbook](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html). This notebook assumes some familiarity with [jax.jvp](https://docs.jax.dev/en/latest/_autosummary/jax.jvp.html) and [jax.grad](https://docs.jax.dev/en/latest/_autosummary/jax.grad.html), and the mathematical meaning of JVPs and VJPs. +++ {"id": "9Fg3NFNY-2RY"} -## Summary - -+++ {"id": "ZgMNRtXyWIW8"} - -### Custom JVPs with `jax.custom_jvp` +### TL;DR: Custom JVPs with `jax.custom_jvp` ```{code-cell} ipython3 :id: zXic8tr--1PK @@ -94,7 +91,7 @@ print(grad(f)(2., 3.)) +++ {"id": "N2DOGCREWXFj"} -### Custom VJPs with `jax.custom_vjp` +### TL;DR: Custom VJPs with `jax.custom_vjp` ```{code-cell} ipython3 :id: 35ScHqhrBwPh @@ -131,7 +128,7 @@ To get an idea of what problems `jax.custom_jvp` and `jax.custom_vjp` are meant +++ {"id": "AR02eyd1GQhC"} -### Numerical stability +### Example: Numerical stability One application of `jax.custom_jvp` is to improve the numerical stability of differentiation. @@ -197,7 +194,7 @@ Stepping through how the jaxpr would be evaluated, we can see that the last line Instead of generating such large and small values, hoping for a cancellation that floats can't always provide, we'd rather just express the derivative function as a more numerically stable program. In particular, we can write a program that more closely evaluates the equal mathematical expression $1 - \frac{1}{1 + e^x}$, with no cancellation in sight. -This problem is interesting because even though our definition of `log1pexp` could already be JAX-differentiated (and transformed with `jit`, `vmap`, ...), we're not happy with the result of applying standard autodiff rules to the primitives comprising `log1pexp` and composing the result. Instead, we'd like to specify how the whole function `log1pexp` should be differentiated, as a unit, and thus arrange those exponentials better. +This problem is interesting because even though our definition of `log1pexp` could already be JAX-differentiated (and transformed with [`jit`](https://docs.jax.dev/en/latest/_autosummary/jax.jit.html), [`vmap`](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html), ...), we're not happy with the result of applying standard autodiff rules to the primitives comprising `log1pexp` and composing the result. Instead, we'd like to specify how the whole function `log1pexp` should be differentiated, as a unit, and thus arrange those exponentials better. This is one application of custom derivative rules for Python functions that are already JAX transformable: specifying how a composite function should be differentiated, while still using its original Python definition for other transformations (like `jit`, `vmap`, ...). @@ -239,7 +236,7 @@ print(vmap(jit(grad(log1pexp)))(jnp.arange(3.))) +++ {"id": "9sVUGbGkUOqO"} -Here's a `defjvps` convenience wrapper to express the same thing: +Here's a [`defjvps`](https://docs.jax.dev/en/latest/_autosummary/jax.custom_jvp.defjvps.html) convenience wrapper to express the same thing: ```{code-cell} ipython3 :id: xfQTp8F7USEM @@ -263,7 +260,7 @@ print(vmap(jit(grad(log1pexp)))(jnp.arange(3.))) +++ {"id": "V9tHAfrSF1N-"} -### Enforcing a differentiation convention +### Example: Enforcing a differentiation convention A related application is to enforce a differentiation convention, perhaps at a boundary. @@ -341,11 +338,11 @@ print(grad(f)(0.)) +++ {"id": "7J2A85wbSAmF"} -### Gradient clipping +### Example: Gradient clipping While in some cases we want to express a mathematical differentiation computation, in other cases we may even want to take a step away from mathematics to adjust the computation autodiff performs. One canonical example is reverse-mode gradient clipping. -For gradient clipping, we can use `jnp.clip` together with a `jax.custom_vjp` reverse-mode-only rule: +For gradient clipping, we can use [`jnp.clip`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.clip.html) together with a [`jax.custom_vjp`](https://docs.jax.dev/en/latest/_autosummary/jax.custom_vjp.html) reverse-mode-only rule: ```{code-cell} ipython3 :id: 8jfjSanIW_tJ @@ -394,7 +391,7 @@ plt.plot(vmap(grad(clip_sin))(t)) +++ {"id": "CICQuI86WK4_"} -### Python debugging +### Example: Python debugging Another application that is motivated by development workflow rather than numerics is to set a `pdb` debugger trace in the backward pass of reverse-mode autodiff. @@ -406,13 +403,13 @@ We'll defer an example until the next section. +++ {"id": "IC7tEcr1-Fc5"} -### Implicit function differentiation of iterative implementations +### Example: Implicit function differentiation of iterative implementations This example gets pretty deep in the mathematical weeds! +++ {"id": "szAt97t80hew"} -Another application for `jax.custom_vjp` is reverse-mode differentiation of functions that are JAX-transformable (by `jit`, `vmap`, ...) but not efficiently JAX-differentiable for some reason, perhaps because they involve `lax.while_loop`. (It's not possible to produce an XLA HLO program that efficiently computes the reverse-mode derivative of an XLA HLO While loop because that would require a program with unbounded memory use, which isn't possible to express in XLA HLO, at least without side-effecting interactions through infeed/outfeed.) +Another application for `jax.custom_vjp` is reverse-mode differentiation of functions that are JAX-transformable (by `jit`, `vmap`, ...) but not efficiently JAX-differentiable for some reason, perhaps because they involve [`lax.while_loop`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.while_loop.html). (It's not possible to produce an XLA HLO program that efficiently computes the reverse-mode derivative of an XLA HLO While loop because that would require a program with unbounded memory use, which isn't possible to express in XLA HLO, at least without side-effecting interactions through infeed/outfeed.) For example, consider this `fixed_point` routine which computes a fixed point by iteratively applying a function in a `while_loop`: @@ -559,7 +556,7 @@ print(grad(grad(jnp.sqrt))(2.)) +++ {"id": "HowvqayEuy-H"} -A limitation to this approach is that the argument `f` can't close over any values involved in differentiation. That is, you might notice that we kept the parameter `a` explicit in the argument list of `fixed_point`. For this use case, consider using the low-level primitive `lax.custom_root`, which allows for deriviatives in closed-over variables with custom root-finding functions. +A limitation to this approach is that the argument `f` can't close over any values involved in differentiation. That is, you might notice that we kept the parameter `a` explicit in the argument list of `fixed_point`. For this use case, consider using the low-level primitive `lax.custom_root`, which allows for derivatives in closed-over variables with custom root-finding functions. +++ {"id": "Dr0aNkBslfQf"} @@ -569,7 +566,7 @@ A limitation to this approach is that the argument `f` can't close over any valu ### Use `jax.custom_jvp` to define forward-mode (and, indirectly, reverse-mode) rules -Here's a canonical basic example of using `jax.custom_jvp`, where the comments use +Here's a canonical basic example of using [`jax.custom_jvp`](https://docs.jax.dev/en/latest/_autosummary/jax.custom_jvp.html), where the comments use [Haskell-like type signatures](https://wiki.haskell.org/Type_signature): ```{code-cell} ipython3 @@ -670,7 +667,7 @@ print(grad(f)(2., 3.)) +++ {"id": "YPsPS3rdaGo2"} -The `defjvps` convenience wrapper lets us define a JVP for each argument separately, and the results are computed separately then summed: +The [`defjvps`](https://docs.jax.dev/en/latest/_autosummary/jax.custom_jvp.defjvps.html) convenience wrapper lets us define a JVP for each argument separately, and the results are computed separately then summed: ```{code-cell} ipython3 :id: CsQIUhUkajua @@ -845,7 +842,7 @@ print(grad(f)(-1.)) ### Use `jax.custom_vjp` to define custom reverse-mode-only rules -While `jax.custom_jvp` suffices for controlling both forward- and, via JAX's automatic transposition, reverse-mode differentiation behavior, in some cases we may want to directly control a VJP rule, for example in the latter two example problems presented above. We can do that with `jax.custom_vjp`: +While `jax.custom_jvp` suffices for controlling both forward- and, via JAX's automatic transposition, reverse-mode differentiation behavior, in some cases we may want to directly control a VJP rule, for example in the latter two example problems presented above. We can do that with [`jax.custom_vjp`](https://docs.jax.dev/en/latest/_autosummary/jax.custom_vjp.html): ```{code-cell} ipython3 :id: zAZk1n3dUw76 @@ -1048,7 +1045,7 @@ Array(-0.91113025, dtype=float32) ### Working with `list` / `tuple` / `dict` containers (and other pytrees) -You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any [pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) are permissible, so long as their structures are consistent according to the type constraints. +You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any [pytrees](https://docs.jax.dev/en/latest/pytrees.html) are permissible, so long as their structures are consistent according to the type constraints. Here's a contrived example with `jax.custom_jvp`: @@ -1141,7 +1138,7 @@ print(grad(fun)(pt)) +++ {"id": "JKTNivxbmKWO"} -### Handling non-differentiable arguments +### Handling non-differentiable arguments +++ {"id": "7g9sXSp_uc36"} diff --git a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb index 8abee469d552..134514880a0e 100644 --- a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb +++ b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb @@ -19,7 +19,7 @@ "source": [ "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb)\n", "\n", - "This tutorial discusses parallelism via `jax.Array`, the unified array object model available in JAX v0.4.1 and newer." + "This tutorial discusses parallelism via `jax.Array`, the unified array object model available in JAX v0.4.1 and newer. See {doc}`../the-training-cookbook` for a real-world machine learning training example that uses this API." ] }, { @@ -67,7 +67,7 @@ "source": [ "## Intro and a quick example\n", "\n", - "By reading this tutorial notebook, you'll learn about `jax.Array`, a unified \n", + "By reading this tutorial notebook, you'll learn about `jax.Array`, a unified\n", "datatype for representing arrays, even with physical storage spanning multiple\n", "devices. You'll also learn about how using `jax.Array`s together with `jax.jit`\n", "can provide automatic compiler-based parallelization.\n", @@ -1276,7 +1276,7 @@ "id": "3qfPjJdhgerc" }, "source": [ - "So computation follows data placement: when we explicitly shard data with `jax.device_put`, and apply functions to that data, the compiler attempts to parallelize the computation and decide the output sharding. This policy for sharded data is a generalization of [JAX's policy of following explicit device placement](https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices)." + "So computation follows data placement: when we explicitly shard data with `jax.device_put`, and apply functions to that data, the compiler attempts to parallelize the computation and decide the output sharding. This policy for sharded data is a generalization of [JAX's policy of following explicit device placement](https://docs.jax.dev/en/latest/faq.html#controlling-data-and-computation-placement-on-devices)." ] }, { @@ -1382,7 +1382,7 @@ "id": "6ZYcK8eXrn0p" }, "source": [ - "We say arrays that have been explicitly placed or sharded with `jax.device_put` are _committed_ to their device(s), and so won't be automatically moved. See the [device placement FAQ](https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices) for more information.\n", + "We say arrays that have been explicitly placed or sharded with `jax.device_put` are _committed_ to their device(s), and so won't be automatically moved. See the [device placement FAQ](https://docs.jax.dev/en/latest/faq.html#controlling-data-and-computation-placement-on-devices) for more information.\n", "\n", "When arrays are _not_ explicitly placed or sharded with `jax.device_put`, they are placed _uncommitted_ on the default device.\n", "Unlike committed arrays, uncommitted arrays can be moved and resharded automatically: that is, uncommitted arrays can be arguments to a computation even if other arguments are explicitly placed on different devices.\n", @@ -2321,269 +2321,6 @@ "source": [ "%timeit -n 10 -r 10 gradfun(params, batch)[0][0].block_until_ready()" ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "3diqi5VRBy6S" - }, - "source": [ - "## Sharp bits" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "OTfoXNnxFYDJ" - }, - "source": [ - "### Generating random numbers\n", - "\n", - "JAX comes with a functional, deterministic [random number generator](https://jax.readthedocs.io/en/latest/jep/263-prng.html). It underlies the various sampling functions in the [`jax.random` module](https://jax.readthedocs.io/en/latest/jax.random.html), such as `jax.random.uniform`.\n", - "\n", - "JAX's random numbers are produced by a counter-based PRNG, so in principle, random number generation should be a pure map over counter values. A pure map is a trivially partitionable operation in principle. It should require no cross-device communication, nor any redundant computation across devices.\n", - "\n", - "However, the existing stable RNG implementation is not automatically partitionable, for historical reasons." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ht_zYFVXNrjN" - }, - "source": [ - "Consider the following example, where a function draws random uniform numbers and adds them to the input, elementwise:" - ] - }, - { - "cell_type": "code", - "execution_count": 75, - "metadata": { - "id": "kwS-aQE_3vGX" - }, - "outputs": [], - "source": [ - "@jax.jit\n", - "def f(key, x):\n", - " numbers = jax.random.uniform(key, x.shape)\n", - " return x + numbers\n", - "\n", - "key = jax.random.key(42)\n", - "mesh = Mesh(jax.devices(), 'x')\n", - "x_sharding = NamedSharding(mesh, P('x'))\n", - "x = jax.device_put(jnp.arange(24), x_sharding)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ZgSA9x9NLMaP" - }, - "source": [ - "On a partitioned input, the function `f` produces output that is also partitioned:" - ] - }, - { - "cell_type": "code", - "execution_count": 76, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 67 - }, - "id": "Oi97rpLz3vGY", - "outputId": "9dd63254-a483-4847-c0f5-5a4367bf08e9" - }, - "outputs": [ - { - "data": { - "text/html": [ - "
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐\n",
-       "│ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │ TPU 4 │ TPU 5 │ TPU 6 │ TPU 7 │\n",
-       "└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘\n",
-       "
\n" - ], - "text/plain": [ - "┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐\n", - "│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n", - "└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "jax.debug.visualize_array_sharding(f(key, x))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "WnjlWDUYLkp6" - }, - "source": [ - "But if we inspect the compiled computation for `f` on this partitioned input, we see that it does involve some communication:" - ] - }, - { - "cell_type": "code", - "execution_count": 77, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "64wIZuSJ3vGY", - "outputId": "fa166d45-ca9c-457a-be84-bcc9236d0730" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Communicating? True\n" - ] - } - ], - "source": [ - "f_exe = f.lower(key, x).compile()\n", - "print('Communicating?', 'collective-permute' in f_exe.as_text())" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AXp9i8fbL8DD" - }, - "source": [ - "One way to work around this is to configure JAX with the experimental upgrade flag `jax_threefry_partitionable`. With the flag on, the \"collective permute\" operation is now gone from the compiled computation:" - ] - }, - { - "cell_type": "code", - "execution_count": 78, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "1I7bqxA63vGY", - "outputId": "756e0a36-ff14-438f-bbd4-3ef03f97a47b" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Communicating? False\n" - ] - } - ], - "source": [ - "jax.config.update('jax_threefry_partitionable', True)\n", - "f_exe = f.lower(key, x).compile()\n", - "print('Communicating?', 'collective-permute' in f_exe.as_text())" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "WV8ZccM5SXOU" - }, - "source": [ - "The output is still partitioned:" - ] - }, - { - "cell_type": "code", - "execution_count": 79, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 67 - }, - "id": "zHPJzdn23vGY", - "outputId": "3332de0f-4827-4f0b-b9ef-69249b7c6bc6" - }, - "outputs": [ - { - "data": { - "text/html": [ - "
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐\n",
-       "│ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │ TPU 4 │ TPU 5 │ TPU 6 │ TPU 7 │\n",
-       "└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘\n",
-       "
\n" - ], - "text/plain": [ - "┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐\n", - "│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n", - "└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "jax.debug.visualize_array_sharding(f(key, x))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "kaK--hPmSPpV" - }, - "source": [ - "One caveat to the `jax_threefry_partitionable` option, however, is that _the random values produced may be different than without the flag set_, even though they were generated by the same random key:" - ] - }, - { - "cell_type": "code", - "execution_count": 80, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "nBUHBBal3vGY", - "outputId": "4b9be948-ccab-4a31-a06f-37ec9c7b5235" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Stable:\n", - "[ 0.72503686 1.8532515 2.983416 3.083253 4.0332246 5.4782867\n", - " 6.1720605 7.6900277 8.602836 9.810046 10.861367 11.907651\n", - " 12.330483 13.456195 14.808557 15.960099 16.067581 17.739723\n", - " 18.335474 19.46401 20.390276 21.116539 22.858128 23.223194 ]\n", - "\n", - "Partitionable:\n", - "[ 0.48870957 1.6797972 2.6162715 3.561016 4.4506445 5.585866\n", - " 6.0748096 7.775133 8.698959 9.818634 10.350306 11.87282\n", - " 12.925881 13.86013 14.477554 15.818481 16.711355 17.586697\n", - " 18.073738 19.777622 20.404566 21.119123 22.026257 23.63918 ]\n" - ] - } - ], - "source": [ - "jax.config.update('jax_threefry_partitionable', False)\n", - "print('Stable:')\n", - "print(f(key, x))\n", - "print()\n", - "\n", - "jax.config.update('jax_threefry_partitionable', True)\n", - "print('Partitionable:')\n", - "print(f(key, x))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8BDPqgOrTMfK" - }, - "source": [ - "In `jax_threefry_partitionable` mode, the JAX PRNG remains deterministic, but its implementation is new (and under development). The random values generated for a given key will be the same at a given JAX version (or a given commit on the `main` branch), but may vary across releases." - ] } ], "metadata": { diff --git a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md index c207f0ae4a00..cc4547b6e417 100644 --- a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md +++ b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md @@ -21,7 +21,7 @@ kernelspec: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) -This tutorial discusses parallelism via `jax.Array`, the unified array object model available in JAX v0.4.1 and newer. +This tutorial discusses parallelism via `jax.Array`, the unified array object model available in JAX v0.4.1 and newer. See {doc}`../the-training-cookbook` for a real-world machine learning training example that uses this API. ```{code-cell} :id: FNxScTfq3vGF @@ -49,7 +49,7 @@ if len(jax.local_devices()) < 8: ## Intro and a quick example -By reading this tutorial notebook, you'll learn about `jax.Array`, a unified +By reading this tutorial notebook, you'll learn about `jax.Array`, a unified datatype for representing arrays, even with physical storage spanning multiple devices. You'll also learn about how using `jax.Array`s together with `jax.jit` can provide automatic compiler-based parallelization. @@ -427,7 +427,7 @@ jax.debug.visualize_array_sharding(w_copy) +++ {"id": "3qfPjJdhgerc"} -So computation follows data placement: when we explicitly shard data with `jax.device_put`, and apply functions to that data, the compiler attempts to parallelize the computation and decide the output sharding. This policy for sharded data is a generalization of [JAX's policy of following explicit device placement](https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices). +So computation follows data placement: when we explicitly shard data with `jax.device_put`, and apply functions to that data, the compiler attempts to parallelize the computation and decide the output sharding. This policy for sharded data is a generalization of [JAX's policy of following explicit device placement](https://docs.jax.dev/en/latest/faq.html#controlling-data-and-computation-placement-on-devices). +++ {"id": "QRB95LaWuT80"} @@ -484,7 +484,7 @@ except ValueError as e: print_exception(e) +++ {"id": "6ZYcK8eXrn0p"} -We say arrays that have been explicitly placed or sharded with `jax.device_put` are _committed_ to their device(s), and so won't be automatically moved. See the [device placement FAQ](https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices) for more information. +We say arrays that have been explicitly placed or sharded with `jax.device_put` are _committed_ to their device(s), and so won't be automatically moved. See the [device placement FAQ](https://docs.jax.dev/en/latest/faq.html#controlling-data-and-computation-placement-on-devices) for more information. When arrays are _not_ explicitly placed or sharded with `jax.device_put`, they are placed _uncommitted_ on the default device. Unlike committed arrays, uncommitted arrays can be moved and resharded automatically: that is, uncommitted arrays can be arguments to a computation even if other arguments are explicitly placed on different devices. @@ -845,121 +845,3 @@ outputId: 479c4d81-cb0b-40a5-89ba-394c10dc3297 --- %timeit -n 10 -r 10 gradfun(params, batch)[0][0].block_until_ready() ``` - -+++ {"id": "3diqi5VRBy6S"} - -## Sharp bits - -+++ {"id": "OTfoXNnxFYDJ"} - -### Generating random numbers - -JAX comes with a functional, deterministic [random number generator](https://jax.readthedocs.io/en/latest/jep/263-prng.html). It underlies the various sampling functions in the [`jax.random` module](https://jax.readthedocs.io/en/latest/jax.random.html), such as `jax.random.uniform`. - -JAX's random numbers are produced by a counter-based PRNG, so in principle, random number generation should be a pure map over counter values. A pure map is a trivially partitionable operation in principle. It should require no cross-device communication, nor any redundant computation across devices. - -However, the existing stable RNG implementation is not automatically partitionable, for historical reasons. - -+++ {"id": "ht_zYFVXNrjN"} - -Consider the following example, where a function draws random uniform numbers and adds them to the input, elementwise: - -```{code-cell} -:id: kwS-aQE_3vGX - -@jax.jit -def f(key, x): - numbers = jax.random.uniform(key, x.shape) - return x + numbers - -key = jax.random.key(42) -mesh = Mesh(jax.devices(), 'x') -x_sharding = NamedSharding(mesh, P('x')) -x = jax.device_put(jnp.arange(24), x_sharding) -``` - -+++ {"id": "ZgSA9x9NLMaP"} - -On a partitioned input, the function `f` produces output that is also partitioned: - -```{code-cell} ---- -colab: - base_uri: https://localhost:8080/ - height: 67 -id: Oi97rpLz3vGY -outputId: 9dd63254-a483-4847-c0f5-5a4367bf08e9 ---- -jax.debug.visualize_array_sharding(f(key, x)) -``` - -+++ {"id": "WnjlWDUYLkp6"} - -But if we inspect the compiled computation for `f` on this partitioned input, we see that it does involve some communication: - -```{code-cell} ---- -colab: - base_uri: https://localhost:8080/ -id: 64wIZuSJ3vGY -outputId: fa166d45-ca9c-457a-be84-bcc9236d0730 ---- -f_exe = f.lower(key, x).compile() -print('Communicating?', 'collective-permute' in f_exe.as_text()) -``` - -+++ {"id": "AXp9i8fbL8DD"} - -One way to work around this is to configure JAX with the experimental upgrade flag `jax_threefry_partitionable`. With the flag on, the "collective permute" operation is now gone from the compiled computation: - -```{code-cell} ---- -colab: - base_uri: https://localhost:8080/ -id: 1I7bqxA63vGY -outputId: 756e0a36-ff14-438f-bbd4-3ef03f97a47b ---- -jax.config.update('jax_threefry_partitionable', True) -f_exe = f.lower(key, x).compile() -print('Communicating?', 'collective-permute' in f_exe.as_text()) -``` - -+++ {"id": "WV8ZccM5SXOU"} - -The output is still partitioned: - -```{code-cell} ---- -colab: - base_uri: https://localhost:8080/ - height: 67 -id: zHPJzdn23vGY -outputId: 3332de0f-4827-4f0b-b9ef-69249b7c6bc6 ---- -jax.debug.visualize_array_sharding(f(key, x)) -``` - -+++ {"id": "kaK--hPmSPpV"} - -One caveat to the `jax_threefry_partitionable` option, however, is that _the random values produced may be different than without the flag set_, even though they were generated by the same random key: - -```{code-cell} ---- -colab: - base_uri: https://localhost:8080/ -id: nBUHBBal3vGY -outputId: 4b9be948-ccab-4a31-a06f-37ec9c7b5235 ---- -jax.config.update('jax_threefry_partitionable', False) -print('Stable:') -print(f(key, x)) -print() - -jax.config.update('jax_threefry_partitionable', True) -print('Partitionable:') -print(f(key, x)) -``` - -+++ {"id": "8BDPqgOrTMfK"} - -In `jax_threefry_partitionable` mode, the JAX PRNG remains deterministic, but its implementation is new (and under development). The random values generated for a given key will be the same at a given JAX version (or a given commit on the `main` branch), but may vary across releases. diff --git a/docs/notebooks/Neural_Network_and_Data_Loading.ipynb b/docs/notebooks/Neural_Network_and_Data_Loading.ipynb index a7ef2a017048..4c9b6c5e48a7 100644 --- a/docs/notebooks/Neural_Network_and_Data_Loading.ipynb +++ b/docs/notebooks/Neural_Network_and_Data_Loading.ipynb @@ -41,7 +41,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": { "id": "OksHydJDtbbI" }, @@ -64,7 +64,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": { "id": "-fmWA06xYE7d" }, @@ -102,7 +102,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": { "id": "7APc6tD7TiuZ" }, @@ -136,7 +136,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": { "id": "4sW2A5mnXHc5", "outputId": "9d3b29e8-fab3-4ecb-9f63-bc8c092f9006" @@ -159,7 +159,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": { "id": "PpyQxuedXfhp", "outputId": "d5d20211-b6da-44e9-f71e-946f2a9d0fc4" @@ -184,7 +184,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": { "id": "oJOOncKMXbwK", "outputId": "31285fab-7667-4871-fcba-28e86adc3fc6" @@ -229,7 +229,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": { "id": "6lTI6I4lWdh5" }, @@ -268,21 +268,37 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "metadata": { "id": "gEvWt8_u2pqG", "outputId": "2c83a679-9ce5-4c67-bccb-9ea835a8eaf6" }, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/m/.opt/miniforge3/envs/jax/lib/python3.12/pty.py:95: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", + " pid, fd = os.forkpty()\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ - "Requirement already satisfied: torch in /opt/anaconda3/lib/python3.7/site-packages (1.4.0)\n", - "Requirement already satisfied: torchvision in /opt/anaconda3/lib/python3.7/site-packages (0.5.0)\n", - "Requirement already satisfied: numpy in /opt/anaconda3/lib/python3.7/site-packages (from torchvision) (1.17.2)\n", - "Requirement already satisfied: six in /opt/anaconda3/lib/python3.7/site-packages (from torchvision) (1.12.0)\n", - "Requirement already satisfied: pillow>=4.1.1 in /opt/anaconda3/lib/python3.7/site-packages (from torchvision) (6.2.0)\n" + "Requirement already satisfied: torch in /home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages (2.4.1)\n", + "Requirement already satisfied: torchvision in /home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages (0.19.1)\n", + "Requirement already satisfied: filelock in /home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages (from torch) (3.16.0)\n", + "Requirement already satisfied: typing-extensions>=4.8.0 in /home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages (from torch) (4.12.2)\n", + "Requirement already satisfied: sympy in /home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages (from torch) (1.13.2)\n", + "Requirement already satisfied: networkx in /home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages (from torch) (3.3)\n", + "Requirement already satisfied: jinja2 in /home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages (from torch) (3.1.4)\n", + "Requirement already satisfied: fsspec in /home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages (from torch) (2024.9.0)\n", + "Requirement already satisfied: setuptools in /home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages (from torch) (73.0.1)\n", + "Requirement already satisfied: numpy in /home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages (from torchvision) (1.26.4)\n", + "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages (from torchvision) (10.4.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages (from jinja2->torch) (2.1.5)\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages (from sympy->torch) (1.3.0)\n" ] } ], @@ -292,7 +308,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "metadata": { "cellView": "both", "id": "94PjXZ8y3dVF" @@ -301,38 +317,24 @@ "source": [ "import numpy as np\n", "from jax.tree_util import tree_map\n", - "from torch.utils import data\n", + "from torch.utils.data import DataLoader, default_collate\n", "from torchvision.datasets import MNIST\n", "\n", "def numpy_collate(batch):\n", - " return tree_map(np.asarray, data.default_collate(batch))\n", - "\n", - "class NumpyLoader(data.DataLoader):\n", - " def __init__(self, dataset, batch_size=1,\n", - " shuffle=False, sampler=None,\n", - " batch_sampler=None, num_workers=0,\n", - " pin_memory=False, drop_last=False,\n", - " timeout=0, worker_init_fn=None):\n", - " super(self.__class__, self).__init__(dataset,\n", - " batch_size=batch_size,\n", - " shuffle=shuffle,\n", - " sampler=sampler,\n", - " batch_sampler=batch_sampler,\n", - " num_workers=num_workers,\n", - " collate_fn=numpy_collate,\n", - " pin_memory=pin_memory,\n", - " drop_last=drop_last,\n", - " timeout=timeout,\n", - " worker_init_fn=worker_init_fn)\n", + " \"\"\"\n", + " Collate function specifies how to combine a list of data samples into a batch.\n", + " default_collate creates pytorch tensors, then tree_map converts them into numpy arrays.\n", + " \"\"\"\n", + " return tree_map(np.asarray, default_collate(batch))\n", "\n", - "class FlattenAndCast(object):\n", - " def __call__(self, pic):\n", - " return np.ravel(np.array(pic, dtype=jnp.float32))" + "def flatten_and_cast(pic):\n", + " \"\"\"Convert PIL image to flat (1-dimensional) numpy array.\"\"\"\n", + " return np.ravel(np.array(pic, dtype=jnp.float32))" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "metadata": { "id": "l314jsfP4TN4" }, @@ -341,108 +343,110 @@ "name": "stdout", "output_type": "stream", "text": [ - "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw/train-images-idx3-ubyte.gz\n" + "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n", + "Failed to download (trying next):\n", + "HTTP Error 404: Not Found\n", + "\n", + "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz\n", + "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw/train-images-idx3-ubyte.gz\n" ] }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "75806ce83ace4f69b81bbc4251c5573f", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" + "name": "stderr", + "output_type": "stream", + "text": [ + "100.0%\n" + ] }, { "name": "stdout", "output_type": "stream", "text": [ "Extracting /tmp/mnist/MNIST/raw/train-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw\n", - "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw/train-labels-idx1-ubyte.gz\n" + "\n", + "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n", + "Failed to download (trying next):\n", + "HTTP Error 404: Not Found\n", + "\n", + "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz\n", + "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw/train-labels-idx1-ubyte.gz\n" ] }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "274ed4ab05f34f70b7a5bb6cf427ffd0", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" + "name": "stderr", + "output_type": "stream", + "text": [ + "100.0%\n" + ] }, { "name": "stdout", "output_type": "stream", "text": [ "Extracting /tmp/mnist/MNIST/raw/train-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw\n", - "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz\n" + "\n", + "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n", + "Failed to download (trying next):\n", + "HTTP Error 404: Not Found\n", + "\n", + "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz\n", + "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz\n" ] }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "d38fa4eabf3c4d4494eb59e078ac94e8", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" + "name": "stderr", + "output_type": "stream", + "text": [ + "100.0%\n" + ] }, { "name": "stdout", "output_type": "stream", "text": [ "Extracting /tmp/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw\n", - "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz\n" + "\n", + "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n", + "Failed to download (trying next):\n", + "HTTP Error 404: Not Found\n", + "\n", + "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz\n", + "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz\n" ] }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "523ac9565c5f4509a1ee8fdbb1e6d66d", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" + "name": "stderr", + "output_type": "stream", + "text": [ + "100.0%" + ] }, { "name": "stdout", "output_type": "stream", "text": [ "Extracting /tmp/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw\n", - "Processing...\n", - "Done!\n" + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" ] } ], "source": [ "# Define our dataset, using torch datasets\n", - "mnist_dataset = MNIST('/tmp/mnist/', download=True, transform=FlattenAndCast())\n", - "training_generator = NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0)" + "mnist_dataset = MNIST('/tmp/mnist/', download=True, transform=flatten_and_cast)\n", + "# Create pytorch data loader with custom collate function\n", + "training_generator = DataLoader(mnist_dataset, batch_size=batch_size, collate_fn=numpy_collate)" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "metadata": { "id": "FTNo4beUvb6t", "outputId": "65a9087c-c326-49e5-cbfc-e0839212fa31" @@ -452,27 +456,13 @@ "name": "stderr", "output_type": "stream", "text": [ - "/opt/anaconda3/lib/python3.7/site-packages/torchvision/datasets/mnist.py:55: UserWarning: train_data has been renamed data\n", + "/home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages/torchvision/datasets/mnist.py:76: UserWarning: train_data has been renamed data\n", " warnings.warn(\"train_data has been renamed data\")\n", - "/opt/anaconda3/lib/python3.7/site-packages/torchvision/datasets/mnist.py:45: UserWarning: train_labels has been renamed targets\n", - " warnings.warn(\"train_labels has been renamed targets\")\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/opt/anaconda3/lib/python3.7/site-packages/torchvision/datasets/mnist.py:60: UserWarning: test_data has been renamed data\n", + "/home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages/torchvision/datasets/mnist.py:66: UserWarning: train_labels has been renamed targets\n", + " warnings.warn(\"train_labels has been renamed targets\")\n", + "/home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages/torchvision/datasets/mnist.py:81: UserWarning: test_data has been renamed data\n", " warnings.warn(\"test_data has been renamed data\")\n", - "/opt/anaconda3/lib/python3.7/site-packages/torchvision/datasets/mnist.py:50: UserWarning: test_labels has been renamed targets\n", + "/home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages/torchvision/datasets/mnist.py:71: UserWarning: test_labels has been renamed targets\n", " warnings.warn(\"test_labels has been renamed targets\")\n" ] } @@ -499,7 +489,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "metadata": { "id": "X2DnZo3iYj18", "outputId": "0eba3ca2-24a1-4cba-aaf4-3ac61d0c650e" @@ -509,30 +499,30 @@ "name": "stdout", "output_type": "stream", "text": [ - "Epoch 0 in 55.15 sec\n", - "Training set accuracy 0.9157500267028809\n", - "Test set accuracy 0.9195000529289246\n", - "Epoch 1 in 42.26 sec\n", - "Training set accuracy 0.9372166991233826\n", - "Test set accuracy 0.9384000301361084\n", - "Epoch 2 in 44.37 sec\n", - "Training set accuracy 0.9491666555404663\n", - "Test set accuracy 0.9469000697135925\n", - "Epoch 3 in 41.75 sec\n", - "Training set accuracy 0.9568166732788086\n", - "Test set accuracy 0.9534000158309937\n", - "Epoch 4 in 41.16 sec\n", - "Training set accuracy 0.9631333351135254\n", - "Test set accuracy 0.9577000737190247\n", - "Epoch 5 in 38.89 sec\n", + "Epoch 0 in 5.53 sec\n", + "Training set accuracy 0.9156666994094849\n", + "Test set accuracy 0.9199000000953674\n", + "Epoch 1 in 1.13 sec\n", + "Training set accuracy 0.9370499849319458\n", + "Test set accuracy 0.9383999705314636\n", + "Epoch 2 in 1.12 sec\n", + "Training set accuracy 0.9490833282470703\n", + "Test set accuracy 0.9467999935150146\n", + "Epoch 3 in 1.21 sec\n", + "Training set accuracy 0.9568833708763123\n", + "Test set accuracy 0.9532999992370605\n", + "Epoch 4 in 1.17 sec\n", + "Training set accuracy 0.9631666541099548\n", + "Test set accuracy 0.9574999809265137\n", + "Epoch 5 in 1.17 sec\n", "Training set accuracy 0.9675000309944153\n", - "Test set accuracy 0.9616000652313232\n", - "Epoch 6 in 40.68 sec\n", - "Training set accuracy 0.9708333611488342\n", - "Test set accuracy 0.9650000333786011\n", - "Epoch 7 in 41.50 sec\n", - "Training set accuracy 0.973716676235199\n", - "Test set accuracy 0.9672000408172607\n" + "Test set accuracy 0.9615999460220337\n", + "Epoch 6 in 1.11 sec\n", + "Training set accuracy 0.9709500074386597\n", + "Test set accuracy 0.9652999639511108\n", + "Epoch 7 in 1.17 sec\n", + "Training set accuracy 0.9736999869346619\n", + "Test set accuracy 0.967199981212616\n" ] } ], @@ -576,7 +566,7 @@ "formats": "ipynb,md:myst" }, "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -590,9 +580,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.4" + "version": "3.12.3" } }, "nbformat": 4, - "nbformat_minor": 1 + "nbformat_minor": 4 } diff --git a/docs/notebooks/Neural_Network_and_Data_Loading.md b/docs/notebooks/Neural_Network_and_Data_Loading.md index cd98022e7421..bcc4019d6da0 100644 --- a/docs/notebooks/Neural_Network_and_Data_Loading.md +++ b/docs/notebooks/Neural_Network_and_Data_Loading.md @@ -7,7 +7,7 @@ jupytext: format_version: 0.13 jupytext_version: 1.16.4 kernelspec: - display_name: Python 3 + display_name: Python 3 (ipykernel) language: python name: python3 --- @@ -192,41 +192,28 @@ JAX is laser-focused on program transformations and accelerator-backed NumPy, so import numpy as np from jax.tree_util import tree_map -from torch.utils import data +from torch.utils.data import DataLoader, default_collate from torchvision.datasets import MNIST def numpy_collate(batch): - return tree_map(np.asarray, data.default_collate(batch)) - -class NumpyLoader(data.DataLoader): - def __init__(self, dataset, batch_size=1, - shuffle=False, sampler=None, - batch_sampler=None, num_workers=0, - pin_memory=False, drop_last=False, - timeout=0, worker_init_fn=None): - super(self.__class__, self).__init__(dataset, - batch_size=batch_size, - shuffle=shuffle, - sampler=sampler, - batch_sampler=batch_sampler, - num_workers=num_workers, - collate_fn=numpy_collate, - pin_memory=pin_memory, - drop_last=drop_last, - timeout=timeout, - worker_init_fn=worker_init_fn) - -class FlattenAndCast(object): - def __call__(self, pic): - return np.ravel(np.array(pic, dtype=jnp.float32)) + """ + Collate function specifies how to combine a list of data samples into a batch. + default_collate creates pytorch tensors, then tree_map converts them into numpy arrays. + """ + return tree_map(np.asarray, default_collate(batch)) + +def flatten_and_cast(pic): + """Convert PIL image to flat (1-dimensional) numpy array.""" + return np.ravel(np.array(pic, dtype=jnp.float32)) ``` ```{code-cell} ipython3 :id: l314jsfP4TN4 # Define our dataset, using torch datasets -mnist_dataset = MNIST('/tmp/mnist/', download=True, transform=FlattenAndCast()) -training_generator = NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0) +mnist_dataset = MNIST('/tmp/mnist/', download=True, transform=flatten_and_cast) +# Create pytorch data loader with custom collate function +training_generator = DataLoader(mnist_dataset, batch_size=batch_size, collate_fn=numpy_collate) ``` ```{code-cell} ipython3 diff --git a/docs/notebooks/README.md b/docs/notebooks/README.md index 07be4441ade8..c945c197ad19 100644 --- a/docs/notebooks/README.md +++ b/docs/notebooks/README.md @@ -1,2 +1,2 @@ For instructions on how to change and test notebooks, see -[Update Documentation](https://jax.readthedocs.io/en/latest/developer.html#update-documentation). +[Update Documentation](https://docs.jax.dev/en/latest/developer.html#update-documentation). diff --git a/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb b/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb index 00ba9186eeec..d22457c5d718 100644 --- a/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb +++ b/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb @@ -24,7 +24,7 @@ "\n", "Here we show how to add your own function transformations to the system, by writing a custom Jaxpr interpreter. And we'll get composability with all the other transformations for free.\n", "\n", - "**This example uses internal JAX APIs, which may break at any time. Anything not in [the API Documentation](https://jax.readthedocs.io/en/latest/jax.html) should be assumed internal.**" + "**This example uses internal JAX APIs, which may break at any time. Anything not in [the API Documentation](https://docs.jax.dev/en/latest/jax.html) should be assumed internal.**" ] }, { @@ -215,8 +215,8 @@ "# Importing Jax functions useful for tracing/interpreting.\n", "from functools import wraps\n", "\n", - "from jax import core\n", "from jax import lax\n", + "from jax.extend import core\n", "from jax._src.util import safe_map" ] }, diff --git a/docs/notebooks/Writing_custom_interpreters_in_Jax.md b/docs/notebooks/Writing_custom_interpreters_in_Jax.md index 10c4e7cb6e3b..ad707a9746fc 100644 --- a/docs/notebooks/Writing_custom_interpreters_in_Jax.md +++ b/docs/notebooks/Writing_custom_interpreters_in_Jax.md @@ -27,7 +27,7 @@ etc.) that enable writing concise, accelerated code. Here we show how to add your own function transformations to the system, by writing a custom Jaxpr interpreter. And we'll get composability with all the other transformations for free. -**This example uses internal JAX APIs, which may break at any time. Anything not in [the API Documentation](https://jax.readthedocs.io/en/latest/jax.html) should be assumed internal.** +**This example uses internal JAX APIs, which may break at any time. Anything not in [the API Documentation](https://docs.jax.dev/en/latest/jax.html) should be assumed internal.** ```{code-cell} ipython3 :id: s27RDKvKXFL8 @@ -147,8 +147,8 @@ Let's use `make_jaxpr` to trace a function into a Jaxpr. # Importing Jax functions useful for tracing/interpreting. from functools import wraps -from jax import core from jax import lax +from jax.extend import core from jax._src.util import safe_map ``` diff --git a/docs/notebooks/autodiff_cookbook.ipynb b/docs/notebooks/autodiff_cookbook.ipynb index 5538b70dac93..46f887f8986f 100644 --- a/docs/notebooks/autodiff_cookbook.ipynb +++ b/docs/notebooks/autodiff_cookbook.ipynb @@ -1637,7 +1637,7 @@ "source": [ "## More advanced autodiff\n", "\n", - "In this notebook, we worked through some easy, and then progressively more complicated, applications of automatic differentiation in JAX. We hope you now feel that taking derivatives in JAX is easy and powerful. \n", + "In this notebook, we worked through some easy, and then progressively more complicated, applications of automatic differentiation in JAX. We hope you now feel that taking derivatives in JAX is easy and powerful. For more details, check out the [\"Advanced automatic differentiation\" section in the JAX advanced guides](https://jax.readthedocs.io/en/latest/advanced_guides.html).\n", "\n", "There's a whole world of other autodiff tricks and functionality out there. Topics we didn't cover, but hope to in an \"Advanced Autodiff Cookbook\" include:\n", "\n", diff --git a/docs/notebooks/autodiff_cookbook.md b/docs/notebooks/autodiff_cookbook.md index db6fde8051d1..d2cb091bc0e8 100644 --- a/docs/notebooks/autodiff_cookbook.md +++ b/docs/notebooks/autodiff_cookbook.md @@ -960,7 +960,7 @@ grad(f, holomorphic=True)(A) ## More advanced autodiff -In this notebook, we worked through some easy, and then progressively more complicated, applications of automatic differentiation in JAX. We hope you now feel that taking derivatives in JAX is easy and powerful. +In this notebook, we worked through some easy, and then progressively more complicated, applications of automatic differentiation in JAX. We hope you now feel that taking derivatives in JAX is easy and powerful. For more details, check out the ["Advanced automatic differentiation" section in the JAX advanced guides](https://jax.readthedocs.io/en/latest/advanced_guides.html). There's a whole world of other autodiff tricks and functionality out there. Topics we didn't cover, but hope to in an "Advanced Autodiff Cookbook" include: diff --git a/docs/notebooks/autodiff_remat.ipynb b/docs/notebooks/autodiff_remat.ipynb index feb906546341..cd03b6b4b0c9 100644 --- a/docs/notebooks/autodiff_remat.ipynb +++ b/docs/notebooks/autodiff_remat.ipynb @@ -348,7 +348,7 @@ "source": [ "### Let's think step by step\n", "\n", - "You might want to first (re)read [the Autodiff Cookbook Part 1](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)." + "You might want to first (re)read [the Autodiff Cookbook Part 1](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html)." ] }, { @@ -473,7 +473,7 @@ "id": "LqTrjPoGqrK7" }, "source": [ - "We can get this VJP behavior in autodiff — without having to write VJP functions directly — by instead using `jax.checkpoint` in an alternative definition of the original function `f`:" + "We can get this VJP behavior in autodiff --- without having to write VJP functions directly --- by instead using `jax.checkpoint` in an alternative definition of the original function `f`:" ] }, { @@ -544,11 +544,11 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "def f_grad_bad(x):\n", + "def f_grad_bad1(x):\n", " _ = f(x) # step 1\n", " _, f_vjp = jax.vjp(f, x) # step 2\n", " x_bar, = f_vjp(1.0) # step 3\n", @@ -812,17 +812,7 @@ "source": [ "Another policy which refers to names is `jax.checkpoint_policies.save_only_these_names`.\n", "\n", - "Some of the policies are:\n", - "* `everything_saveable` (the default strategy, as if `jax.checkpoint` were not being used at all)\n", - "* `nothing_saveable` (i.e. rematerialize everything, as if a custom policy were not being used at all)\n", - "* `dots_saveable` or its alias `checkpoint_dots`\n", - "* `dots_with_no_batch_dims_saveable` or its alias `checkpoint_dots_with_no_batch_dims`\n", - "* `save_anything_but_these_names` (save any values except for the output of\n", - " `checkpoint_name` with any of the names given)\n", - "* `save_any_names_but_these` (save only named values, i.e. any outputs of\n", - " `checkpoint_name`, except for those with the names given)\n", - "* `save_only_these_names` (save only named values, and only among the names\n", - " given)\n", + "A list of policies can be found [here](https://docs.jax.dev/en/latest/jax.html#checkpoint-policies).\n", "\n", "Policies only indicate what is saveable; a value is only saved if it's actually needed by the backward pass." ] diff --git a/docs/notebooks/autodiff_remat.md b/docs/notebooks/autodiff_remat.md index 8ba87dcfee18..5b710e2e0977 100644 --- a/docs/notebooks/autodiff_remat.md +++ b/docs/notebooks/autodiff_remat.md @@ -156,7 +156,7 @@ print_fwd_bwd(f3, W1, W2, W3, x) ### Let's think step by step -You might want to first (re)read [the Autodiff Cookbook Part 1](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html). +You might want to first (re)read [the Autodiff Cookbook Part 1](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html). +++ {"id": "VMfwm_yinvoZ"} @@ -231,7 +231,7 @@ The cost we pay is redundant work: in `f_bwd2` we must re-evaluate `g(x)` as par +++ {"id": "LqTrjPoGqrK7"} -We can get this VJP behavior in autodiff — without having to write VJP functions directly — by instead using `jax.checkpoint` in an alternative definition of the original function `f`: +We can get this VJP behavior in autodiff --- without having to write VJP functions directly --- by instead using `jax.checkpoint` in an alternative definition of the original function `f`: ```{code-cell} def f_checkpoint(x): @@ -275,7 +275,7 @@ Notice that if `f = lambda x: h(g(x))` is the function we want to differentiate, That is, in code we'd have something like: ```{code-cell} -def f_grad_bad(x): +def f_grad_bad1(x): _ = f(x) # step 1 _, f_vjp = jax.vjp(f, x) # step 2 x_bar, = f_vjp(1.0) # step 3 @@ -396,17 +396,7 @@ print_saved_residuals(loss_checkpoint2, params, x, y) Another policy which refers to names is `jax.checkpoint_policies.save_only_these_names`. -Some of the policies are: -* `everything_saveable` (the default strategy, as if `jax.checkpoint` were not being used at all) -* `nothing_saveable` (i.e. rematerialize everything, as if a custom policy were not being used at all) -* `dots_saveable` or its alias `checkpoint_dots` -* `dots_with_no_batch_dims_saveable` or its alias `checkpoint_dots_with_no_batch_dims` -* `save_anything_but_these_names` (save any values except for the output of - `checkpoint_name` with any of the names given) -* `save_any_names_but_these` (save only named values, i.e. any outputs of - `checkpoint_name`, except for those with the names given) -* `save_only_these_names` (save only named values, and only among the names - given) +A list of policies can be found [here](https://docs.jax.dev/en/latest/jax.html#checkpoint-policies). Policies only indicate what is saveable; a value is only saved if it's actually needed by the backward pass. diff --git a/docs/notebooks/colocated-python.ipynb b/docs/notebooks/colocated-python.ipynb new file mode 100644 index 000000000000..5dda5315adb2 --- /dev/null +++ b/docs/notebooks/colocated-python.ipynb @@ -0,0 +1,447 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "WKchP4VBBRgq" + }, + "source": [ + "# Colocated Python\n", + "\n", + "NOTE: Colocated Python is currently an experimental API. Its functionality and\n", + "interface are subject to change without following the standard JAX compatibility\n", + "policy.\n", + "\n", + "Colocated Python provides a uniform way to run Python code on the hosts\n", + "associated with a set of JAX devices. If the JAX devices represent local\n", + "devices, the Python code will run on the local host. If the JAX devices\n", + "represent remote devices, the Python code will be shipped to run on the host of\n", + "these remote devices. This is useful when building a multi-host ML system on top\n", + "of JAX that is portable across multi-controller JAX environments (running JAX\n", + "code on each host with accelerators) as well as single-controller JAX\n", + "environments (running JAX code on a single host orchestrating other hosts with\n", + "accelerators)." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "B38uuH1ZBZmd" + }, + "source": [ + "## Colocated CPU devices\n", + "\n", + "To use colocated Python, the first step is to obtain CPU devices colocated with\n", + "target accelerator devices.\n", + "`jax.experimental.colocated_python.colocated_cpu_devices` provides a standard\n", + "way to do so." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "d7FHtd4wCYEf" + }, + "outputs": [], + "source": [ + "import jax\n", + "import jax.experimental.colocated_python as colocated_python\n", + "\n", + "devices = jax.devices()\n", + "cpu_devices = colocated_python.colocated_cpu_devices(devices)\n", + "print(cpu_devices)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Grfb7H4FCVsE" + }, + "source": [ + "As usual, the CPU devices can be used with JAX APIs." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "id": "5RmWK-s4DQsl" + }, + "outputs": [], + "source": [ + "cpu_mesh = jax.sharding.Mesh(cpu_devices, [\"x\"])\n", + "cpu_sharding = jax.sharding.NamedSharding(cpu_mesh, jax.P())\n", + "x = jax.device_put(1, cpu_sharding)\n", + "y = jax.jit(lambda x: x + 1)(x)\n", + "print(y)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7U1OScHaCjSC" + }, + "source": [ + "## Colocated Python function\n", + "\n", + "CPU devices can also be used to run Python code with colocated Python." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "PJbdHF8mDZNT" + }, + "outputs": [], + "source": [ + "def f(x):\n", + " return x + 1\n", + "\n", + "\n", + "f = colocated_python.colocated_python(f)\n", + "y = f(x)\n", + "assert y.sharding == x.sharding\n", + "print(y)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tpGdXqG9C5X3" + }, + "source": [ + "Since colocated Python runs normal Python code, you can also perform I/O:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "id": "MeWnKNlHDgs3" + }, + "outputs": [], + "source": [ + "def f(x):\n", + " with open('/tmp/foo', 'w') as f:\n", + " f.write(str(x))\n", + " return x\n", + "\n", + "\n", + "f = colocated_python.colocated_python(f)\n", + "jax.block_until_ready(f(x))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HOGQQ5IUC7Pe" + }, + "source": [ + "Note the use of `jax.block_until_ready` to ensure the Python code has\n", + "completed. In principle, colocated Python calls may run asynchronously, similar\n", + "to jitted function calls; the calls would return JAX arrays and do not block\n", + "until their output is produced. Thus, you should block on an output from a\n", + "colocated Python call if the completion of the execution is significant.\n", + "\n", + "There exist cases where a colocated Python call runs synchronously.\n", + "\n", + "* If the colocated Python function is called without \"specialization\" (see\n", + " below), the very first call will run synchronously. This is because the shape\n", + " and sharding of the output must be known for asynchronous execution, and\n", + " colocated Python has to run the Python code once to discover this information.\n", + "\n", + "* Some JAX backends do not yet fully support asynchronous execution, and will\n", + " fall back to synchronous execution.\n", + "\n", + "The wrapped Python code must use exactly the same set of devices in the input\n", + "and the output. This is a requirement similar to jitted functions that represent\n", + "an SPMD execution." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "uX8q-42tC8ia" + }, + "source": [ + "## Specialization\n", + "\n", + "Specialization in colocated Python is a mechanism to supply extra information\n", + "about the input, output, and execution of a colocated Python function, when the\n", + "information cannot be inferred in advance, or you would like to ensure the\n", + "colocated Python executions to happen precisely as specified.\n", + "\n", + "First, functions wrapped in colocated Python has a `specialize` method.\n", + "This method is used to create another colocated Python wrapped function\n", + "specialized with the supplied information.\n", + "\n", + "`out_specs_fn` is a function that takes a pytree of\n", + "`jax.ShapeDtypeStruct` of the call inputs and returns a pytree of\n", + "`jax.ShapeDtypeStruct` expected for the output. Calling this function is\n", + "analogous to jitted function tracing, but this function is separate from the\n", + "original Python code. This function runs on the caller side and not executed on\n", + "the devices." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "id": "SWEuz68nDtXE" + }, + "outputs": [], + "source": [ + "def f(x):\n", + " return x + 1\n", + "\n", + "\n", + "f = colocated_python.colocated_python(f)\n", + "f = f.specialize(out_specs_fn=lambda x: x)\n", + "y = f(x)\n", + "assert y.sharding == x.sharding" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HkQZwqUBC-QV" + }, + "source": [ + "`in_specs` takes a concrete pytree (the top level is tuple) of\n", + "`jax.sharding.ShapeDtypeStruct` expected for the input to the colocated\n", + "Python function call. This is used if a certain input spec must be used, or the\n", + "output specs function can be computed only for a concrete input spec." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "id": "E0SQPPHID1WU" + }, + "outputs": [], + "source": [ + "import jax.numpy as jnp\n", + "\n", + "\n", + "def f(x):\n", + " return x + 1\n", + "\n", + "\n", + "f = colocated_python.colocated_python(f)\n", + "f = f.specialize(\n", + " in_specs=(\n", + " # args\n", + " (\n", + " jax.ShapeDtypeStruct(\n", + " shape=(), dtype=jnp.int32, sharding=cpu_sharding\n", + " ),\n", + " ),\n", + " # kwargs\n", + " {},\n", + " ),\n", + " out_specs_fn=lambda x: jax.ShapeDtypeStruct(\n", + " shape=(), dtype=jnp.int32, sharding=cpu_sharding\n", + " ),\n", + ")\n", + "f(x) # `x` must match the input spec." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2L7aUBvsC_4m" + }, + "source": [ + "`devices` specifies a list of devices that the colocated Python function\n", + "should run on. Having `devices` specialized lets a colocated Python function\n", + "without input arguments run." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "id": "ZwWQRm_PDAll" + }, + "outputs": [], + "source": [ + "def f():\n", + " with open('/tmp/foo', 'w') as f:\n", + " f.write('foo')\n", + " return\n", + "\n", + "\n", + "f = colocated_python.colocated_python(f)\n", + "f = f.specialize(devices=cpu_devices)\n", + "f() # Would be an error if `f` is not specialized with ``devices``." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xIjM-au9DBQL" + }, + "source": [ + "## Colocated Python class\n", + "\n", + "Colocated Python also supports wrapping Python classes. A real instance is\n", + "created on the hosts associated with the devices, and the caller side will get a\n", + "wrapper class that forwards all method calls to the real instance using\n", + "colocated Python." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "id": "Ikb4Hh5iDB7Z" + }, + "outputs": [], + "source": [ + "class Adder:\n", + "\n", + " def __init__(self, increment):\n", + " print('Adder created')\n", + " self.increment = increment\n", + "\n", + " def __del__(self):\n", + " print('Adder destroyed')\n", + "\n", + " def add(self, x):\n", + " return x + self.increment\n", + "\n", + "\n", + "Adder = colocated_python.colocated_python_class(Adder)\n", + "adder = Adder(1)\n", + "x = jax.device_put(1, cpu_sharding)\n", + "y = adder.add(x)\n", + "print(y)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "t4i192BGDCw8" + }, + "source": [ + "When the wrapper class instance is destroyed, the real instance is destroyed as\n", + "well. Note that this destruction will be asynchronous." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "id": "j5g-NNYFDDln" + }, + "outputs": [], + "source": [ + "del adder" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UfQTjAu9DEV-" + }, + "source": [ + "There are a few important semantic differences between colocated Python and\n", + "normal Python.\n", + "\n", + "* A colocated Python class instance is created only on the hosts associated with\n", + " the devices when any non-constructor method is called for the first time. In\n", + " the above example, `Adder(1)` captures the constructor arguments\n", + " `1`, but the actual constructor call `Adder(1)` on the hosts\n", + " happens only when the first `adder.add(x)` call is made. This is because\n", + " it is unknown what hosts the `Adder` instance should be created on until\n", + " there is a call to its method.\n", + "\n", + "* If the method(s) of the same wrapper class is called with inputs with\n", + " different devices, the real instance may be created at different times on\n", + " different hosts. If the first method call used CPU devices on host A, and the\n", + " second method call used CPU devices on host B, the real instance will be\n", + " created on host A during the first method call, and then on host B during the\n", + " second method call.\n", + "\n", + "* The methods of colocated Python classes are not yet specializable. The support\n", + " will be added in the future." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YOsb92ChDFQd" + }, + "source": [ + "## Execution order and concurrency\n", + "\n", + "Colocated Python provides \"program order\" execution. Even if colocated Python\n", + "calls may be asynchronous (returning output JAX arrays without blocking), the\n", + "calls will be executed in the same order as the order the calls are made in the\n", + "user program. Thus, by default, colocated Python calls are sequentially\n", + "executed.\n", + "\n", + "Several use cases of colocated Python will benefit from concurrent execution.\n", + "For example, one colocated Python call may take long time to return because it\n", + "may be doing expensive file reads, while another colocated Python call may need\n", + "to do file writes that are independent from the first one. This situation could\n", + "expect two calls to run concurrently without blocking each other.\n", + "\n", + "Colocated Python provides concurrent execution if colocated Python calls are\n", + "made from different threads. For example, the below example would make two\n", + "colocated Python calls to run concurrently." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "l0L1-HaGDGHo" + }, + "outputs": [], + "source": [ + "import concurrent.futures\n", + "import time\n", + "\n", + "\n", + "def f(x):\n", + " time.sleep(1)\n", + " return x + 1\n", + "\n", + "\n", + "f = colocated_python.colocated_python(f)\n", + "f = f.specialize(out_specs_fn=lambda x: x) # Calls will be asynchronous.\n", + "\n", + "with concurrent.futures.ThreadPoolExecutor(2) as executor:\n", + " fut1 = executor.submit(f, x)\n", + " fut2 = executor.submit(f, x)\n", + " # Will finish in approximately 1 second instead of 2 seconds.\n", + " jax.block_until_ready([fut1.result(), fut2.result()])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lRYja4_pDHFm" + }, + "source": [ + "While calls from different threads run concurrently, on each thread, program\n", + "ordering will continue to apply." + ] + } + ], + "metadata": { + "colab": { + "private_outputs": true + }, + "jupytext": { + "formats": "ipynb,md:myst", + "main_language": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/docs/notebooks/colocated-python.md b/docs/notebooks/colocated-python.md new file mode 100644 index 000000000000..25178fd1339b --- /dev/null +++ b/docs/notebooks/colocated-python.md @@ -0,0 +1,322 @@ +--- +jupytext: + formats: ipynb,md:myst + main_language: python + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.16.4 +--- + ++++ {"id": "WKchP4VBBRgq"} + +# Colocated Python + +NOTE: Colocated Python is currently an experimental API. Its functionality and +interface are subject to change without following the standard JAX compatibility +policy. + +Colocated Python provides a uniform way to run Python code on the hosts +associated with a set of JAX devices. If the JAX devices represent local +devices, the Python code will run on the local host. If the JAX devices +represent remote devices, the Python code will be shipped to run on the host of +these remote devices. This is useful when building a multi-host ML system on top +of JAX that is portable across multi-controller JAX environments (running JAX +code on each host with accelerators) as well as single-controller JAX +environments (running JAX code on a single host orchestrating other hosts with +accelerators). + ++++ {"id": "B38uuH1ZBZmd"} + +## Colocated CPU devices + +To use colocated Python, the first step is to obtain CPU devices colocated with +target accelerator devices. +`jax.experimental.colocated_python.colocated_cpu_devices` provides a standard +way to do so. + +```{code-cell} +:id: d7FHtd4wCYEf + +import jax +import jax.experimental.colocated_python as colocated_python + +devices = jax.devices() +cpu_devices = colocated_python.colocated_cpu_devices(devices) +print(cpu_devices) +``` + ++++ {"id": "Grfb7H4FCVsE"} + +As usual, the CPU devices can be used with JAX APIs. + +```{code-cell} +:id: 5RmWK-s4DQsl + +cpu_mesh = jax.sharding.Mesh(cpu_devices, ["x"]) +cpu_sharding = jax.sharding.NamedSharding(cpu_mesh, jax.P()) +x = jax.device_put(1, cpu_sharding) +y = jax.jit(lambda x: x + 1)(x) +print(y) +``` + ++++ {"id": "7U1OScHaCjSC"} + +## Colocated Python function + +CPU devices can also be used to run Python code with colocated Python. + +```{code-cell} +:id: PJbdHF8mDZNT + +def f(x): + return x + 1 + + +f = colocated_python.colocated_python(f) +y = f(x) +assert y.sharding == x.sharding +print(y) +``` + ++++ {"id": "tpGdXqG9C5X3"} + +Since colocated Python runs normal Python code, you can also perform I/O: + +```{code-cell} +:id: MeWnKNlHDgs3 + +def f(x): + with open('/tmp/foo', 'w') as f: + f.write(str(x)) + return x + + +f = colocated_python.colocated_python(f) +jax.block_until_ready(f(x)) +``` + ++++ {"id": "HOGQQ5IUC7Pe"} + +Note the use of `jax.block_until_ready` to ensure the Python code has +completed. In principle, colocated Python calls may run asynchronously, similar +to jitted function calls; the calls would return JAX arrays and do not block +until their output is produced. Thus, you should block on an output from a +colocated Python call if the completion of the execution is significant. + +There exist cases where a colocated Python call runs synchronously. + +* If the colocated Python function is called without "specialization" (see + below), the very first call will run synchronously. This is because the shape + and sharding of the output must be known for asynchronous execution, and + colocated Python has to run the Python code once to discover this information. + +* Some JAX backends do not yet fully support asynchronous execution, and will + fall back to synchronous execution. + +The wrapped Python code must use exactly the same set of devices in the input +and the output. This is a requirement similar to jitted functions that represent +an SPMD execution. + ++++ {"id": "uX8q-42tC8ia"} + +## Specialization + +Specialization in colocated Python is a mechanism to supply extra information +about the input, output, and execution of a colocated Python function, when the +information cannot be inferred in advance, or you would like to ensure the +colocated Python executions to happen precisely as specified. + +First, functions wrapped in colocated Python has a `specialize` method. +This method is used to create another colocated Python wrapped function +specialized with the supplied information. + +`out_specs_fn` is a function that takes a pytree of +`jax.ShapeDtypeStruct` of the call inputs and returns a pytree of +`jax.ShapeDtypeStruct` expected for the output. Calling this function is +analogous to jitted function tracing, but this function is separate from the +original Python code. This function runs on the caller side and not executed on +the devices. + +```{code-cell} +:id: SWEuz68nDtXE + +def f(x): + return x + 1 + + +f = colocated_python.colocated_python(f) +f = f.specialize(out_specs_fn=lambda x: x) +y = f(x) +assert y.sharding == x.sharding +``` + ++++ {"id": "HkQZwqUBC-QV"} + +`in_specs` takes a concrete pytree (the top level is tuple) of +`jax.sharding.ShapeDtypeStruct` expected for the input to the colocated +Python function call. This is used if a certain input spec must be used, or the +output specs function can be computed only for a concrete input spec. + +```{code-cell} +:id: E0SQPPHID1WU + +import jax.numpy as jnp + + +def f(x): + return x + 1 + + +f = colocated_python.colocated_python(f) +f = f.specialize( + in_specs=( + # args + ( + jax.ShapeDtypeStruct( + shape=(), dtype=jnp.int32, sharding=cpu_sharding + ), + ), + # kwargs + {}, + ), + out_specs_fn=lambda x: jax.ShapeDtypeStruct( + shape=(), dtype=jnp.int32, sharding=cpu_sharding + ), +) +f(x) # `x` must match the input spec. +``` + ++++ {"id": "2L7aUBvsC_4m"} + +`devices` specifies a list of devices that the colocated Python function +should run on. Having `devices` specialized lets a colocated Python function +without input arguments run. + +```{code-cell} +:id: ZwWQRm_PDAll + +def f(): + with open('/tmp/foo', 'w') as f: + f.write('foo') + return + + +f = colocated_python.colocated_python(f) +f = f.specialize(devices=cpu_devices) +f() # Would be an error if `f` is not specialized with ``devices``. +``` + ++++ {"id": "xIjM-au9DBQL"} + +## Colocated Python class + +Colocated Python also supports wrapping Python classes. A real instance is +created on the hosts associated with the devices, and the caller side will get a +wrapper class that forwards all method calls to the real instance using +colocated Python. + +```{code-cell} +:id: Ikb4Hh5iDB7Z + +class Adder: + + def __init__(self, increment): + print('Adder created') + self.increment = increment + + def __del__(self): + print('Adder destroyed') + + def add(self, x): + return x + self.increment + + +Adder = colocated_python.colocated_python_class(Adder) +adder = Adder(1) +x = jax.device_put(1, cpu_sharding) +y = adder.add(x) +print(y) +``` + ++++ {"id": "t4i192BGDCw8"} + +When the wrapper class instance is destroyed, the real instance is destroyed as +well. Note that this destruction will be asynchronous. + +```{code-cell} +:id: j5g-NNYFDDln + +del adder +``` + ++++ {"id": "UfQTjAu9DEV-"} + +There are a few important semantic differences between colocated Python and +normal Python. + +* A colocated Python class instance is created only on the hosts associated with + the devices when any non-constructor method is called for the first time. In + the above example, `Adder(1)` captures the constructor arguments + `1`, but the actual constructor call `Adder(1)` on the hosts + happens only when the first `adder.add(x)` call is made. This is because + it is unknown what hosts the `Adder` instance should be created on until + there is a call to its method. + +* If the method(s) of the same wrapper class is called with inputs with + different devices, the real instance may be created at different times on + different hosts. If the first method call used CPU devices on host A, and the + second method call used CPU devices on host B, the real instance will be + created on host A during the first method call, and then on host B during the + second method call. + +* The methods of colocated Python classes are not yet specializable. The support + will be added in the future. + ++++ {"id": "YOsb92ChDFQd"} + +## Execution order and concurrency + +Colocated Python provides "program order" execution. Even if colocated Python +calls may be asynchronous (returning output JAX arrays without blocking), the +calls will be executed in the same order as the order the calls are made in the +user program. Thus, by default, colocated Python calls are sequentially +executed. + +Several use cases of colocated Python will benefit from concurrent execution. +For example, one colocated Python call may take long time to return because it +may be doing expensive file reads, while another colocated Python call may need +to do file writes that are independent from the first one. This situation could +expect two calls to run concurrently without blocking each other. + +Colocated Python provides concurrent execution if colocated Python calls are +made from different threads. For example, the below example would make two +colocated Python calls to run concurrently. + +```{code-cell} +:id: l0L1-HaGDGHo + +import concurrent.futures +import time + + +def f(x): + time.sleep(1) + return x + 1 + + +f = colocated_python.colocated_python(f) +f = f.specialize(out_specs_fn=lambda x: x) # Calls will be asynchronous. + +with concurrent.futures.ThreadPoolExecutor(2) as executor: + fut1 = executor.submit(f, x) + fut2 = executor.submit(f, x) + # Will finish in approximately 1 second instead of 2 seconds. + jax.block_until_ready([fut1.result(), fut2.result()]) +``` + ++++ {"id": "lRYja4_pDHFm"} + +While calls from different threads run concurrently, on each thread, program +ordering will continue to apply. diff --git a/docs/notebooks/explicit-sharding.ipynb b/docs/notebooks/explicit-sharding.ipynb index d656e12d4068..7fe3ea3184db 100644 --- a/docs/notebooks/explicit-sharding.ipynb +++ b/docs/notebooks/explicit-sharding.ipynb @@ -28,7 +28,7 @@ "of work and it's also easy to make mistakes that way because there's no way to\n", "check that the shardings make sense together. More commonly, people add just\n", "enough sharding annotations to constrain the compiler. But this is a slow\n", - "iterative process. It's hard to know ahead of time what XLA's gSPMD pass will\n", + "iterative process. It's hard to know ahead of time what XLA's GSPMD pass will\n", "do (it's a whole-program optimization) so all you can do is add annotations,\n", "inspect XLA's sharding choices to see what happened, and repeat.\n", "\n", @@ -44,7 +44,8 @@ "also be _queried_ at trace time too. In the rest of this doc we'll describe\n", "how to use explicit sharding mode. Note that this is a new feature so we\n", "expect there to be bugs and unimplemented cases. Please let us know when you\n", - "find something that doesn't work!" + "find something that doesn't work! Also see {doc}`../the-training-cookbook`\n", + "for a real-world machine learning training example that uses explicit sharding." ] }, { @@ -58,8 +59,7 @@ "import jax\n", "import numpy as np\n", "import jax.numpy as jnp\n", - "from jax.sharding import PartitionSpec as P, AxisType, set_mesh, get_abstract_mesh\n", - "from jax.experimental.shard import reshard, auto_axes\n", + "from jax.sharding import PartitionSpec as P, AxisType, get_abstract_mesh, reshard\n", "\n", "jax.config.update('jax_num_cpu_devices', 8)" ] @@ -160,10 +160,11 @@ "These types show the shape and dtype of array but they don't appear to\n", "show sharding. (Actually, they _did_ show sharding, but the shardings were\n", "trivial. See \"Concrete array shardings\", below.) To start seeing some\n", - "interesting shardings we need to set up an explicit-sharding mesh. We use\n", - "`set_mesh` to set it as the current mesh for the remainder of this notebook.\n", - "(If you only want to set the mesh for some particular scope and return to the previous\n", - "mesh afterwards then you can use the context manager `jax.sharding.use_mesh` instead.)" + "interesting shardings we need to set up an explicit-sharding mesh.\n", + "\n", + "`jax.set_mesh` can be used as a global setter or a context manager. We use\n", + "`jax.set_mesh` in this notebook as a global setter. You can use it as a scoped\n", + "context manager via `with jax.set_mesh(mesh)`." ] }, { @@ -188,7 +189,7 @@ "source": [ "mesh = jax.make_mesh((2, 4), (\"X\", \"Y\"),\n", " axis_types=(AxisType.Explicit, AxisType.Explicit))\n", - "set_mesh(mesh)\n", + "jax.set_mesh(mesh)\n", "\n", "print(f\"Current mesh is: {get_abstract_mesh()}\")" ] @@ -397,7 +398,7 @@ " which the split/merged axes are sharded as None then we shard the\n", " resulting split/merged axes as None and the other axes according to their\n", " corresponding input axis shardings. In all other cases we throw an error\n", - " and require the user to provide an `out_shardings` argument." + " and require the user to provide an `out_sharding` argument." ] }, { @@ -414,7 +415,7 @@ "wherever types need to match. For example, the two sides of a `lax.cond` need to\n", "have results with matching shardings. And the carry of `lax.scan` needs to have the\n", "same sharding at the input and the output of the scan body. And when you\n", - "contruct a jaxpr without concrete arguments using `make_jaxpr` you need to\n", + "construct a jaxpr without concrete arguments using `make_jaxpr` you need to\n", "provide shardings too. Certain JAX transformations perform type-level\n", "operations. Automatic differentation constructs a tangent type for each primal\n", "type in the original computation (e.g. `TangentOf(float) == float`,\n", @@ -433,7 +434,7 @@ "id": "ERJx4p0tXoS3" }, "source": [ - "## Working around unimplemented sharding rules using `auto_sharding`\n", + "## Working around unimplemented sharding rules using `auto_axes`\n", "\n", "The implementation of explicit sharding is still a work-in-progress and there\n", "are plenty of ops that are missing sharding rules. For example, `scatter` and\n", @@ -478,6 +479,8 @@ } ], "source": [ + "from jax.sharding import auto_axes, explicit_axes\n", + "\n", "some_x = reshard(np.arange(16).reshape(4, 4), P(\"X\", None))\n", "some_y = reshard(np.arange(16).reshape(4, 4), P(None, \"X\"))\n", "\n", @@ -494,7 +497,7 @@ " print(f\"We're in auto-sharding mode here. This is the current mesh: {get_abstract_mesh()}\")\n", " return x + y\n", "\n", - "result = add_with_out_sharding_kwarg(some_x, some_y, out_shardings=P(\"X\", None))\n", + "result = add_with_out_sharding_kwarg(some_x, some_y, out_sharding=P(\"X\", None))\n", "print(f\"Result type: {jax.typeof(result)}\")" ] }, @@ -527,11 +530,11 @@ "\n", "A summary table:\n", "\n", - "| Mode | Explicit sharding? | Explicit Collectives? |\n", - "|---|---|---|\n", - "| Auto | No | No |\n", - "| Explicit (new) | Yes | No |\n", - "| Manual | Yes | Yes |\n", + "| Mode | View? | Explicit sharding? | Explicit Collectives? |\n", + "|---|---|---|---|\n", + "| Auto | Global | ❌ | ❌ |\n", + "| Explicit | Global | ✅ | ❌ |\n", + "| Manual | Per-device | ✅ | ✅ |\n", "\n", "The current mesh tells us which sharding mode we're in. We can query it with\n", "`get_abstract_mesh`:" @@ -637,7 +640,7 @@ " x = jnp.sin(arr1)\n", " print(f'x.sharding: {jax.typeof(x)}', end='\\n\\n')\n", "\n", - " z = g(x, out_shardings=P(\"X\", \"Y\"))\n", + " z = g(x, out_sharding=P(\"X\", \"Y\"))\n", "\n", " print(f'z.sharding: {jax.typeof(z)}', end=\"\\n\\n\")\n", " return z + 1\n", @@ -652,7 +655,51 @@ "id": "_3sfJjRq8w9f" }, "source": [ - "As you can see, inside `g`, the type of `arr1` is `ShapedArray(float32[4,4@Y])` which indicates it's Explicit over `Y` mesh axis while auto over `X`." + "As you can see, inside `g`, the type of `arr1` is `ShapedArray(float32[4,4@Y])` which indicates it's Explicit over `Y` mesh axis while auto over `X`.\n", + "\n", + "\n", + "You can also use the `explicit_axes` API to drop into `Explicit` mode over some or all mesh axes." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a102e9c7", + "metadata": {}, + "outputs": [], + "source": [ + "auto_mesh = jax.make_mesh((2, 4), (\"X\", \"Y\"),\n", + " axis_types=(AxisType.Auto, AxisType.Auto))\n", + "\n", + "@functools.partial(explicit_axes, axes=('X', 'Y'))\n", + "def explicit_g(y):\n", + " print(f'mesh inside g: {get_abstract_mesh()}')\n", + " print(f'y.sharding inside g: {jax.typeof(y) = }')\n", + " z = y * 2\n", + " print(f'z.sharding inside g: {jax.typeof(z) = }', end='\\n\\n')\n", + " return z\n", + "\n", + "@jax.jit\n", + "def f(arr1):\n", + " print(f'mesh inside f: {get_abstract_mesh()}', end='\\n\\n')\n", + " x = jnp.sin(arr1)\n", + "\n", + " z = explicit_g(x, in_sharding=P(\"X\", \"Y\"))\n", + "\n", + " return z + 1\n", + "\n", + "with jax.set_mesh(auto_mesh):\n", + " some_x = jax.device_put(np.arange(16).reshape(4, 4), P(\"X\", \"Y\"))\n", + " f(some_x)" + ] + }, + { + "cell_type": "markdown", + "id": "e64d40de", + "metadata": {}, + "source": [ + "As you can see, all axes of mesh inside `f` are of type `Auto` while inside `g`, they are of type `Explicit`.\n", + "Because of that, sharding is visible on the type of arrays inside `g`." ] }, { @@ -734,7 +781,7 @@ " compare_shardings(x)\n", " return x\n", "\n", - "check_in_auto_context(my_array, out_shardings=P(\"X\"))" + "check_in_auto_context(my_array, out_sharding=P(\"X\"))" ] }, { diff --git a/docs/notebooks/explicit-sharding.md b/docs/notebooks/explicit-sharding.md index 7c59a675d8ec..7d85e6cdc850 100644 --- a/docs/notebooks/explicit-sharding.md +++ b/docs/notebooks/explicit-sharding.md @@ -31,7 +31,7 @@ constraints? You could put them on every single intermediate but that's a lot of work and it's also easy to make mistakes that way because there's no way to check that the shardings make sense together. More commonly, people add just enough sharding annotations to constrain the compiler. But this is a slow -iterative process. It's hard to know ahead of time what XLA's gSPMD pass will +iterative process. It's hard to know ahead of time what XLA's GSPMD pass will do (it's a whole-program optimization) so all you can do is add annotations, inspect XLA's sharding choices to see what happened, and repeat. @@ -47,7 +47,8 @@ error otherwise. Since the shardings are propagated at trace time they can also be _queried_ at trace time too. In the rest of this doc we'll describe how to use explicit sharding mode. Note that this is a new feature so we expect there to be bugs and unimplemented cases. Please let us know when you -find something that doesn't work! +find something that doesn't work! Also see {doc}`../the-training-cookbook` +for a real-world machine learning training example that uses explicit sharding. ```{code-cell} ipython3 :id: hVi6mApuVw3r @@ -55,8 +56,7 @@ find something that doesn't work! import jax import numpy as np import jax.numpy as jnp -from jax.sharding import PartitionSpec as P, AxisType, set_mesh, get_abstract_mesh -from jax.experimental.shard import reshard, auto_axes +from jax.sharding import PartitionSpec as P, AxisType, get_abstract_mesh, reshard jax.config.update('jax_num_cpu_devices', 8) ``` @@ -107,10 +107,11 @@ foo(some_array) These types show the shape and dtype of array but they don't appear to show sharding. (Actually, they _did_ show sharding, but the shardings were trivial. See "Concrete array shardings", below.) To start seeing some -interesting shardings we need to set up an explicit-sharding mesh. We use -`set_mesh` to set it as the current mesh for the remainder of this notebook. -(If you only want to set the mesh for some particular scope and return to the previous -mesh afterwards then you can use the context manager `jax.sharding.use_mesh` instead.) +interesting shardings we need to set up an explicit-sharding mesh. + +`jax.set_mesh` can be used as a global setter or a context manager. We use +`jax.set_mesh` in this notebook as a global setter. You can use it as a scoped +context manager via `with jax.set_mesh(mesh)`. ```{code-cell} ipython3 --- @@ -121,7 +122,7 @@ outputId: d888371b-080e-4bff-be5d-ea56beda3aac --- mesh = jax.make_mesh((2, 4), ("X", "Y"), axis_types=(AxisType.Explicit, AxisType.Explicit)) -set_mesh(mesh) +jax.set_mesh(mesh) print(f"Current mesh is: {get_abstract_mesh()}") ``` @@ -239,7 +240,7 @@ Here are some example sharding rules: which the split/merged axes are sharded as None then we shard the resulting split/merged axes as None and the other axes according to their corresponding input axis shardings. In all other cases we throw an error - and require the user to provide an `out_shardings` argument. + and require the user to provide an `out_sharding` argument. +++ {"id": "jZMp6w48Xmd7"} @@ -251,7 +252,7 @@ sharding is part of that type. This means that shardings need to match wherever types need to match. For example, the two sides of a `lax.cond` need to have results with matching shardings. And the carry of `lax.scan` needs to have the same sharding at the input and the output of the scan body. And when you -contruct a jaxpr without concrete arguments using `make_jaxpr` you need to +construct a jaxpr without concrete arguments using `make_jaxpr` you need to provide shardings too. Certain JAX transformations perform type-level operations. Automatic differentation constructs a tangent type for each primal type in the original computation (e.g. `TangentOf(float) == float`, @@ -265,7 +266,7 @@ argument. +++ {"id": "ERJx4p0tXoS3"} -## Working around unimplemented sharding rules using `auto_sharding` +## Working around unimplemented sharding rules using `auto_axes` The implementation of explicit sharding is still a work-in-progress and there are plenty of ops that are missing sharding rules. For example, `scatter` and @@ -292,6 +293,8 @@ colab: id: fpFEaMBcXsJG outputId: 5b84b1d1-d7b2-4e9a-ba98-3dd34a5465ef --- +from jax.sharding import auto_axes, explicit_axes + some_x = reshard(np.arange(16).reshape(4, 4), P("X", None)) some_y = reshard(np.arange(16).reshape(4, 4), P(None, "X")) @@ -308,7 +311,7 @@ def add_with_out_sharding_kwarg(x, y): print(f"We're in auto-sharding mode here. This is the current mesh: {get_abstract_mesh()}") return x + y -result = add_with_out_sharding_kwarg(some_x, some_y, out_shardings=P("X", None)) +result = add_with_out_sharding_kwarg(some_x, some_y, out_sharding=P("X", None)) print(f"Result type: {jax.typeof(result)}") ``` @@ -337,11 +340,11 @@ JAX now has three styles of parallelism: A summary table: -| Mode | Explicit sharding? | Explicit Collectives? | -|---|---|---| -| Auto | No | No | -| Explicit (new) | Yes | No | -| Manual | Yes | Yes | +| Mode | View? | Explicit sharding? | Explicit Collectives? | +|---|---|---|---| +| Auto | Global | ❌ | ❌ | +| Explicit | Global | ✅ | ❌ | +| Manual | Per-device | ✅ | ✅ | The current mesh tells us which sharding mode we're in. We can query it with `get_abstract_mesh`: @@ -390,7 +393,7 @@ def f(arr1): x = jnp.sin(arr1) print(f'x.sharding: {jax.typeof(x)}', end='\n\n') - z = g(x, out_shardings=P("X", "Y")) + z = g(x, out_sharding=P("X", "Y")) print(f'z.sharding: {jax.typeof(z)}', end="\n\n") return z + 1 @@ -403,6 +406,38 @@ f(some_x) As you can see, inside `g`, the type of `arr1` is `ShapedArray(float32[4,4@Y])` which indicates it's Explicit over `Y` mesh axis while auto over `X`. + +You can also use the `explicit_axes` API to drop into `Explicit` mode over some or all mesh axes. + +```{code-cell} ipython3 +auto_mesh = jax.make_mesh((2, 4), ("X", "Y"), + axis_types=(AxisType.Auto, AxisType.Auto)) + +@functools.partial(explicit_axes, axes=('X', 'Y')) +def explicit_g(y): + print(f'mesh inside g: {get_abstract_mesh()}') + print(f'y.sharding inside g: {jax.typeof(y) = }') + z = y * 2 + print(f'z.sharding inside g: {jax.typeof(z) = }', end='\n\n') + return z + +@jax.jit +def f(arr1): + print(f'mesh inside f: {get_abstract_mesh()}', end='\n\n') + x = jnp.sin(arr1) + + z = explicit_g(x, in_sharding=P("X", "Y")) + + return z + 1 + +with jax.set_mesh(auto_mesh): + some_x = jax.device_put(np.arange(16).reshape(4, 4), P("X", "Y")) + f(some_x) +``` + +As you can see, all axes of mesh inside `f` are of type `Auto` while inside `g`, they are of type `Explicit`. +Because of that, sharding is visible on the type of arrays inside `g`. + +++ {"id": "sJcWbfAh7UcO"} ## Concrete array shardings can mention `Auto` mesh axis @@ -437,7 +472,7 @@ def check_in_auto_context(x): compare_shardings(x) return x -check_in_auto_context(my_array, out_shardings=P("X")) +check_in_auto_context(my_array, out_sharding=P("X")) ``` +++ {"id": "MRFccsi5X8so"} diff --git a/docs/notebooks/host-offloading.ipynb b/docs/notebooks/host-offloading.ipynb new file mode 100644 index 000000000000..765809354ece --- /dev/null +++ b/docs/notebooks/host-offloading.ipynb @@ -0,0 +1,888 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "bQbS50fIdHw1" + }, + "source": [ + "(host-offloading)=\n", + "# JAX Memories and Host Offloading\n", + "\n", + "\n", + "\n", + "This tutorial provides a practical introduction to host offloading techniques in JAX, focusing on:\n", + "\n", + "- Activation offloading\n", + "- Parameter offloading\n", + "- Optimizer state offloading\n", + "\n", + "By applying offloading strategies, developers can better manage memory resources and reduce memory pressure on devices. To implement these strategies effectively, understanding JAX's core mechanisms for data placement and movement is essential.\n", + "\n", + "## Building Blocks for Offloading\n", + "\n", + "JAX provides several key components for controlling where and how data are stored and moved between the host and the device memory. The following sections explore:\n", + "\n", + "- How to specify data distribution with sharding\n", + "- How to control memory placement between host and device\n", + "- How to manage data movement in jitted functions\n", + "\n", + "### NamedSharding and Memory Kinds\n", + "\n", + "{class}`~jax.sharding.NamedSharding` defines how data are distributed across devices. It includes:\n", + "\n", + "- Basic data distribution configuration\n", + "- `memory_kind` parameter for specifying memory type (`device` or `pinned_host`)\n", + "- By default, `memory_kind` is set to `device` memory\n", + "- `with_memory_kind` method for creating new sharding with modified memory type" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "f-6sxUlqrlBn", + "outputId": "691a3df2-8341-44a9-a4a0-5521c2d891e3" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "NamedSharding(mesh=Mesh('x': 1, 'y': 1), spec=PartitionSpec('x', 'y'), memory_kind=device)\n", + "NamedSharding(mesh=Mesh('x': 1, 'y': 1), spec=PartitionSpec('x', 'y'), memory_kind=pinned_host)\n" + ] + } + ], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "from jax.sharding import Mesh, NamedSharding, PartitionSpec as P\n", + "import numpy as np\n", + "\n", + "# Create mesh\n", + "# 1x1 mesh represents a single device with two named dimensions (x and y)\n", + "mesh = Mesh(np.array(jax.devices()[0]).reshape(1, 1), ('x', 'y'))\n", + "\n", + "# Device sharding - partitions data along x and y dimensions\n", + "s_dev = NamedSharding(mesh, P('x', 'y'), memory_kind=\"device\")\n", + "\n", + "# Host sharding - same partitioning but in pinned host memory\n", + "s_host = s_dev.with_memory_kind('pinned_host')\n", + "\n", + "print(s_dev) # Shows device memory sharding\n", + "print(s_host) # Shows pinned host memory sharding" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "R_pB9465VoMP" + }, + "source": [ + "### Data Placement with device_put\n", + "\n", + "{func}`jax.device_put` is a function that explicitly transfers arrays to a specified memory location according to a sharding specification." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "OJFnf7FGp6Lj", + "outputId": "c762e1df-2453-4ed9-9d53-0defb6a05ce2" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "pinned_host\n", + "device\n" + ] + } + ], + "source": [ + "# Create a 2x4 array\n", + "arr = jnp.arange(8.0).reshape(2, 4)\n", + "\n", + "# Move arrays to different memory locations based on sharding objects\n", + "arr_host = jax.device_put(arr, s_host) # Places in pinned host memory\n", + "arr_dev = jax.device_put(arr, s_dev) # Places in device memory\n", + "\n", + "# Verify memory locations\n", + "print(arr_host.sharding.memory_kind) # Output: pinned_host\n", + "print(arr_dev.sharding.memory_kind) # Output: device" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HHXvBpQKTMCR" + }, + "source": [ + "### Output Sharding Controls\n", + "\n", + "Shardings determine how data is split across devices. JAX provides `out_shardings` to control how output arrays are partitioned when leaving a jitted function.\n", + "\n", + "Key Features:\n", + " - Can differ from input sharding\n", + " - Allows different memory kinds for outputs\n", + "\n", + "Examples:\n", + "\n", + "#### Device Output Sharding" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "ZXNj9NUeaIdX", + "outputId": "399321ef-082a-4a77-c33a-9de3421f429b" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Result value of H2D: \n", + " [[0. 1. 2. 3.]\n", + " [4. 5. 6. 7.]]\n" + ] + } + ], + "source": [ + "f = jax.jit(lambda x:x, out_shardings=s_dev)\n", + "out_dev = f(arr_host)\n", + "print(\"Result value of H2D: \\n\", out_dev)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "iYXC5ix384XP" + }, + "source": [ + "Moving data from host to device memory when needed for computation is the essence of host offloading. Use {func}`jax.device_put` to perform this transfer in this example to optimize performance." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "cmM6tJTS84XQ", + "outputId": "40c353a1-fb55-44bc-bac9-dffc09852f49" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Result value of H2D and add 1 in device memory: \n", + " [[1. 2. 3. 4.]\n", + " [5. 6. 7. 8.]]\n" + ] + } + ], + "source": [ + "# Instead of the lambda function, add_func can be defined explicitly\n", + "# move data to device before computation\n", + "def add_func(x): # Move data to device and add one\n", + " x = jax.device_put(x, s_dev)\n", + " return x + 1\n", + "\n", + "f = jax.jit(add_func, out_shardings=s_dev)\n", + "out_dev = f(arr_host)\n", + "print(\"Result value of H2D and add 1 in device memory: \\n\", out_dev)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EbE-eBrJTBuS" + }, + "source": [ + "#### Host Output Sharding" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "FjZzkxI8ky4r", + "outputId": "2a1b6e7a-1c29-4347-c020-7b47c27a5cc3" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Result value of D2H: \n", + " [[0. 1. 2. 3.]\n", + " [4. 5. 6. 7.]]\n" + ] + } + ], + "source": [ + "f = jax.jit(lambda x: x, out_shardings=s_host)\n", + "out_host = f(arr_dev) # Input arrays in the device memory while output arrays in the host memory\n", + "print(\"Result value of D2H: \\n\", out_host)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UhLVvRO2p6Lj" + }, + "source": [ + "## Activation Offloading\n", + "\n", + "Before diving into activation offloading, let's first take a look at the baseline code.\n", + "\n", + "This code implements a simple neural network with 10 layers, each consisting of two linear transformations. The code demonstrates basic memory usage patterns and provides a foundation for comparing offloading optimization techniques.\n", + "\n", + "Key components:\n", + "- Each layer consists of two sequential linear operations:\n", + " 1. First multiplication: `x @ w1`\n", + " 2. Second multiplication: `y @ w2`\n", + "- 10-layer network using JAX's scan operation\n", + "- Memory usage analysis\n", + "- Gradient computation with JIT compilation\n", + "\n", + "To analyze memory usage in JAX, the {func}`jax.stages.Compiled.memory_analysis` method can be used on a compiled function. This provides detailed statistics about memory consumption during computation. The key metrics include temporary memory size, argument size, output size, and alias size. To calculate the total memory usage, sum the temporary, argument, and output sizes, then subtract the alias size to avoid double-counting the same memory multiple times. This provides a summarized view of how the device memory is utilized across different aspects of the computation." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "UEt0dtxukkaz", + "outputId": "22bb32b7-8491-4100-f212-e56c50f44cfa" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Temp size: 17.25 MB\n", + "Argument size: 20.25 MB\n", + "Total size: 57.50 MB\n", + "Sample of results: [3.8312336e-07 3.8312336e-07 3.8312336e-07 3.8312336e-07 3.8312336e-07]\n" + ] + } + ], + "source": [ + "# Initialize input and weights with small values (0.0001)\n", + "input = jnp.ones((256, 256), dtype=jnp.float32) * 0.001 # Input matrix: 256 x 256\n", + "w1 = jnp.ones((10, 256, 1024), dtype=jnp.float32) * 0.001 # 10 layers of 256 x 1024 matrices\n", + "w2 = jnp.ones((10, 1024, 256), dtype=jnp.float32) * 0.001 # 10 layers of 1024 x 256 matrices\n", + "\n", + "def two_layers(x, w):\n", + " # Simple two-layer linear transformation\n", + " w1, w2 = w\n", + " y = x @ w1\n", + " return y @ w2, None\n", + "\n", + "def scanned(w, x):\n", + " # Applies the layer function 10 times using JAX's scan operation\n", + " # Input: w (tuple of weight matrices), x (input matrix)\n", + " # Output: sum of the final layer's output\n", + " result = jax.lax.scan(two_layers, x, w)[0]\n", + " return jnp.sum(result)\n", + "\n", + "# Compile and compute gradients of the scanned function\n", + "f = jax.jit(jax.grad(scanned)) # Apply JIT compilation to gradient computation\n", + "\n", + "# Analyze memory usage\n", + "compiled_step = f.lower((w1, w2), input).compile()\n", + "compiled_stats = compiled_step.memory_analysis()\n", + "\n", + "if compiled_stats is not None:\n", + " # Calculate total memory usage including temporary storage, arguments, and outputs\n", + " # Subtract alias size to avoid double-counting memory shared between different components\n", + " total = compiled_stats.temp_size_in_bytes + compiled_stats.argument_size_in_bytes \\\n", + " + compiled_stats.output_size_in_bytes - compiled_stats.alias_size_in_bytes\n", + " print(f\"Temp size: {compiled_stats.temp_size_in_bytes / (1024**2):.2f} MB\")\n", + " print(f\"Argument size: {compiled_stats.argument_size_in_bytes / (1024**2):.2f} MB\")\n", + " print(f\"Total size: {total/(1024**2):.2f} MB\")\n", + "\n", + "# Execute the function and print sample results\n", + "result = f((w1, w2), input) # Execute the function with weights and input\n", + "print(\"Sample of results: \", result[0][0, 0, :5])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DnFyRt2nkkaz" + }, + "source": [ + "The detailed coverage of activation offloading can be found in the {ref}`gradient-checkpointing` tutorial. Activation offloading helps manage memory by moving intermediate activations to host memory after the forward pass, and bringing them back to device memory during the backward pass when needed for gradient computation.\n", + "\n", + "To implement activation offloading effectively, it is important to understand checkpoint names and policies. Here's how they work in a simple example:\n", + "\n", + "### Checkpoint Names\n", + "\n", + "The {func}`checkpoint_name` function allows labeling activations for memory management during computation. Here's a simple example that a checkpoint name `x` is specified." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "id": "sLO9ceS6p6Lj" + }, + "outputs": [], + "source": [ + "from jax.ad_checkpoint import checkpoint_name\n", + "\n", + "def layer_name(x, w):\n", + " w1, w2 = w\n", + " x = checkpoint_name(x, \"x\")\n", + " y = x @ w1\n", + " return y @ w2, None" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-_T92oCOp6Lk" + }, + "source": [ + "The checkpoint name helps the system decide whether to:\n", + "* Keep the activation in device memory or\n", + "* Offload it to host memory during computation\n", + "\n", + "This pattern is common in neural networks, where multiple transformations are applied sequentially to input data.\n", + "\n", + "### Checkpoint Policies\n", + "\n", + "This checkpoint policy implements a memory management strategy that optimizes memory usage during computation. It manages memory by handling intermediate values through three strategies:\n", + "1. Recomputing during backward pass (default behavior)\n", + "2. Storing on device\n", + "3. Offloading to host memory after forward pass and loading back during backward pass" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "id": "W8Usw_wOp6Lk" + }, + "outputs": [], + "source": [ + "from jax import checkpoint_policies as cp\n", + "\n", + "policy = cp.save_and_offload_only_these_names(\n", + " names_which_can_be_saved=[], # No values stored on device\n", + " names_which_can_be_offloaded=[\"x\"], # Offload activations labeled \"x\"\n", + " offload_src=\"device\", # Move from device memory\n", + " offload_dst=\"pinned_host\" # To pinned host memory\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "iuDRCXu7ky4r" + }, + "source": [ + "{func}`jax.lax.scan` is commonly used in JAX for handling sequential operations (like RNNs or transformers). It can be integrated with JAX's rematerialization to process sequential data.\n", + "\n", + "Key components:\n", + "* {func}`jax.remat` creates a rematerialized version of the layer function using {func}`jax.remat` and applies the checkpoint policy to the layer function\n", + "* `prevent_cse=False` enables XLA's common subexpression elimination for better performance\n", + "* {func}`jax.lax.scan` iterates the rematerialized layer along an axis" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "xCrxjTx_p6Lk", + "outputId": "13d46584-9b25-4622-b3c3-f50c1dac02c2" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Temp size: 6.50 MB\n", + "Argument size: 20.25 MB\n", + "Total size: 46.75 MB\n", + "Results match within tolerance: True\n", + "Sample of results: [3.8312336e-07 3.8312336e-07 3.8312336e-07 3.8312336e-07 3.8312336e-07]\n" + ] + } + ], + "source": [ + "def scanned(w, x):\n", + " remat_layer = jax.remat(layer_name,\n", + " policy=policy, # Use our offloading policy\n", + " prevent_cse=False) # Allow CSE optimizations\n", + " result = jax.lax.scan(remat_layer, x, w)[0]\n", + " return jnp.sum(result)\n", + "\n", + "# Initialize input and weights with small values (0.0001)\n", + "input = jnp.ones((256, 256), dtype=jnp.float32) * 0.001 # Input matrix: 256 x 256\n", + "w1 = jnp.ones((10, 256, 1024), dtype=jnp.float32) * 0.001 # 10 layers of 256 x 1024 matrices\n", + "w2 = jnp.ones((10, 1024, 256), dtype=jnp.float32) * 0.001 # 10 layers of 1024 x 256 matrices\n", + "\n", + "# Compile and compute gradients of the scanned function\n", + "f = jax.jit(jax.grad(scanned)) # Apply JIT compilation to gradient computation\n", + "\n", + "# Analyze memory usage\n", + "compiled_step = f.lower((w1, w2), input).compile()\n", + "compiled_stats = compiled_step.memory_analysis()\n", + "\n", + "if compiled_stats is not None:\n", + " total = compiled_stats.temp_size_in_bytes + compiled_stats.argument_size_in_bytes \\\n", + " + compiled_stats.output_size_in_bytes - compiled_stats.alias_size_in_bytes\n", + " print(f\"Temp size: {compiled_stats.temp_size_in_bytes / (1024**2):.2f} MB\")\n", + " print(f\"Argument size: {compiled_stats.argument_size_in_bytes / (1024**2):.2f} MB\")\n", + " print(f\"Total size: {total/(1024**2):.2f} MB\")\n", + "\n", + "result_activation = f((w1, w2), input) # Execute the function with weights and input\n", + "# Verify numerical correctness\n", + "are_close = jnp.allclose(\n", + " result_activation[0], # Result from activation offloading only\n", + " result[0], # Result from both activation and parameter offloading\n", + " rtol=1e-5,\n", + " atol=1e-5\n", + ")\n", + "print(f\"Results match within tolerance: {are_close}\")\n", + "print(\"Sample of results: \", result_activation[0][0, 0, :5])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0tx7aara42pY" + }, + "source": [ + "Activation offloading reduces temporary memory usage from 17.25 MB to 6.5 MB while input and output argument sizes remain the same. Totally 10.75 MB is saved. It is achieved by offloading activation `x` to host memory after the forward pass and loading it back to device memory before the backward pass.\n", + "\n", + "### Summary of Activation Offloading\n", + "\n", + "Activation offloading provides a powerful way to manage memory in large computations by:\n", + "\n", + "* Using checkpoint names to mark specific activations\n", + "* Applying policies to control where and how activations are stored\n", + "* Supporting common JAX patterns like scan operations\n", + "* Moving selected activations to host memory when device memory is under budget\n", + "\n", + "This approach is particularly useful when working with large models that would otherwise exceed device memory capacity.\n", + "\n", + "## Parameter Offloading\n", + "\n", + "Model parameters (also known as weights) can be offloaded to the host memory to optimize device memory usage during initialization. This is achieved by using {func}`jax.jit` with a sharding strategy that specifies host memory kind.\n", + "\n", + "While parameter offloading and activation offloading are distinct memory optimization techniques, the following example demonstrates parameter offloading built upon the activation offloading implementation shown earlier.\n", + "\n", + "### Parameter Placement for Computation\n", + "\n", + "Different from the earlier `layer` function, {func}`jax.device_put` is applied to move parameter `w1` and `w2` to the device before the matrix multiplications. This ensures the parameters are available on the device for both forward and backward passes.\n", + "\n", + "Note that the activation offloading implementation remains unchanged, using the same:\n", + "* Checkpoint name `\"x\"`\n", + "* Checkpoint policy\n", + "* `scanned` function combining {func}`jax.remat` and {func}`jax.lax.scan`\n", + "\n", + "### Parameter Initialization with Host Offloading\n", + "\n", + "During the initialization, parameter `w1` and `w2` are placed on host memory before being passed to the {func}`jax.jit` function `f`, while keeping the `input` variable on the device." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "1qGN2hBQdheo", + "outputId": "48c09658-f8b6-4be3-ef0e-02e0e2566e10" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Temp size: 4.75 MB\n", + "Argument size: 0.25 MB\n", + "Total size: 25.00 MB\n", + "Results match within tolerance: True\n" + ] + } + ], + "source": [ + "# Hybrid version: Both activation and parameter offloading\n", + "def hybrid_layer(x, w):\n", + " # Move model parameters w1 and w2 to device memory via device_put\n", + " w1, w2 = jax.tree.map(lambda x: jax.device_put(x, s_dev), w)\n", + " x = checkpoint_name(x, \"x\") # Offload activation x to host memory\n", + " y = x @ w1\n", + " return y @ w2, None\n", + "\n", + "def hybrid_scanned(w, x):\n", + " remat_layer = jax.remat(hybrid_layer, # Use hybrid_layer instead of layer\n", + " policy=policy, # Use offloading policy\n", + " prevent_cse=False) # Allow CSE optimizations\n", + " result = jax.lax.scan(remat_layer, x, w)[0]\n", + " return jnp.sum(result)\n", + "\n", + "# Move model parameters w1 and w2 to the host via device_put\n", + "# Initialize input and weights with small values (0.0001)\n", + "wh1 = jax.device_put(w1, s_host)\n", + "wh2 = jax.device_put(w2, s_host)\n", + "\n", + "# Compile and compute gradients of the scanned function\n", + "f = jax.jit(jax.grad(hybrid_scanned)) # Apply JIT compilation to gradient computation\n", + "\n", + "# Analyze memory usage\n", + "compiled_step = f.lower((wh1, wh2), input).compile()\n", + "compiled_stats = compiled_step.memory_analysis()\n", + "\n", + "if compiled_stats is not None:\n", + " total = compiled_stats.temp_size_in_bytes + compiled_stats.argument_size_in_bytes \\\n", + " + compiled_stats.output_size_in_bytes - compiled_stats.alias_size_in_bytes\n", + " print(f\"Temp size: {compiled_stats.temp_size_in_bytes / (1024**2):.2f} MB\")\n", + " print(f\"Argument size: {compiled_stats.argument_size_in_bytes / (1024**2):.2f} MB\")\n", + " print(f\"Total size: {total / (1024**2):.2f} MB\")\n", + "\n", + "result_both = f((wh1, wh2), input) # Execute with both activation and parameter offloading\n", + "\n", + "# Verify numerical correctness\n", + "are_close = jnp.allclose(\n", + " result_activation[0], # Result from activation offloading only\n", + " result_both[0], # Result from both activation and parameter offloading\n", + " rtol=1e-5,\n", + " atol=1e-5\n", + ")\n", + "print(f\"Results match within tolerance: {are_close}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SVpozzwHflQk" + }, + "source": [ + "This implementation demonstrates how offloading model parameters together with activation offloading to host memory can significantly reduce device memory usage.\n", + "\n", + "### Memory Analysis\n", + "\n", + "**Baseline Memory Usage:**\n", + "- Input tensor: 0.25 MB (256 × 256 × 4 bytes)\n", + "- Model parameters (w1, w2): 10 MB each (256 × 1024 × 4 bytes ≈ 1 MB per layer × 10 layers)\n", + "\n", + "**Memory Usage Comparison:**\n", + "- Argument size without parameter offloading: 20.25 MB (0.25 + 10 + 10)\n", + "- Argument size with parameter offloading: 0.25 MB (only input remains)\n", + "- Temporary memory without activation offloading: 17.25 MB\n", + "- Temporary memory with activation offloading: 6.50 MB\n", + "- Temporary memory with activation and parameter offloading: 4.75 MB\n", + "\n", + "#### Key Optimizations\n", + "\n", + "1. **Parameter Offloading**: Moving parameters (w1, w2) to host memory reduces argument size by 20 MB (from 20.25 MB to 0.25 MB).\n", + "\n", + "2. **Activation Offloading**: Moving activations to host memory reduces temporary memory usage by 10.75 MB (from 17.25 to 6.50 MB).\n", + "\n", + "3. **Hybrid Strategy**: The rematerialization of activation offloading helps avoid keeping weights on the device and reduce temporary memory usage by 1.75 MB (from 6.50 MB to 4.75 MB). Without it, JAX would be eager to keep the on-device copies of the weights alive for the backward pass.\n", + "\n", + "#### Results\n", + "\n", + "**Total Memory Savings**: 33.5 MB (20 MB + 10.75 MB + 1.75 MB)\n", + "\n", + "This hybrid approach demonstrates that parameter and activation offloading work synergistically to achieve significant memory reductions while maintaining computational correctness.\n", + "\n", + "### Limitations of Parameter Offloading\n", + "\n", + "{func}`jax.lax.scan` is crucial for effective parameter management. Using an explicit for loop would cause parameters to continuously occupy device memory, resulting in the same memory usage as without parameter offloading. While {func}`jax.lax.scan` allows specifying the scan axis, parameter offloading currently works only when scanning over axis 0. Scanning over other axes generates a `transpose` operation during compilation before returning parameters to the device, which is expensive and not supported on all platforms.\n", + "\n", + "The offloading performance can vary for different device types. It may degrade performance due to memory transfers between host and device, so it's important to consider this trade-off when designing your optimization strategy.\n", + "\n", + "# Optimizer State Offloading\n", + "\n", + "Optimizer state offloading is a memory management technique that stores optimizer states in host memory instead of device memory. This approach is particularly useful when optimizer states are large, as it reduces device memory usage.\n", + "\n", + "A basic JAX implementation using the Adam optimizer can serve as a starting point, where all tensors are stored on the device. This will serve as a reference implementation before introducing optimizer state offloading.\n", + "\n", + "### Basic Implementation\n", + "\n", + "This section, let's implement a simple model with the Adam optimizer. This implementation helps establish the baseline behavior before exploring optimizer state offloading. It is particularly useful for understanding memory patterns in large-scale neural network training.\n", + "\n", + "In the code example below, a neural network training loop is included to use JAX and Optax's Adam optimizer. The network consists of four linear layers with GELU activation functions, processing large matrices of size 7168x7168. The training process involves:\n", + "- Forward pass: The input flows through four layers, each applying a linear transformation followed by GELU activation\n", + "- Loss computation: Calculates mean squared error between output and input, plus L2 regularization\n", + "- Backward pass: Computes gradients using automatic differentiation\n", + "- Optimization step: Updates parameters using Adam optimizer with gradient clipping\n", + "\n", + "The code uses JIT compilation to optimize performance and includes memory usage analysis to monitor the computational resources required during training. The memory analysis provides insights into temporary memory usage, argument sizes, and total memory consumption during the optimization step." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "ujvC0YJ2VOyV", + "outputId": "d237ca0a-89ae-4e14-edd3-36cc38890349" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Temp size: 2.11 GB\n", + "Argument size: 2.49 GB\n", + "Total size: 4.59 GB\n" + ] + } + ], + "source": [ + "import optax\n", + "\n", + "DIM = 7168\n", + "\n", + "# Initialize data and parameter w1, w2, w3 and w4\n", + "input = jnp.ones((DIM, DIM))\n", + "params = {f'w{i}': jnp.ones((DIM, DIM)) for i in range(1, 5)}\n", + "\n", + "# Initialize optimizer\n", + "optimizer = optax.chain(\n", + " optax.clip_by_global_norm(1.0),\n", + " optax.adam(learning_rate=0.1)\n", + ")\n", + "opt_state = optimizer.init(params)\n", + "\n", + "def gelu(x):\n", + " return 0.5 * x * (1 + jnp.tanh(jnp.sqrt(2 / jnp.pi) * (x + 0.044715 * x**3)))\n", + "\n", + "def single_layer(x, w):\n", + " return x @ w\n", + "\n", + "def forward(params, x):\n", + " for i in range(1, 5):\n", + " x = gelu(single_layer(x, params[f'w{i}']))\n", + " return x\n", + "\n", + "def compute_loss(params, inputs):\n", + " outputs = forward(params, inputs)\n", + " loss = jnp.mean((outputs - inputs) ** 2)\n", + " l2_reg = 0.001 * sum(jnp.sum(w ** 2) for w in jax.tree_util.tree_leaves(params))\n", + " return loss + l2_reg\n", + "\n", + "def step(params, opt_state, inputs):\n", + " grads = jax.grad(lambda p: compute_loss(p, inputs))(params)\n", + " updates, new_opt_state = optimizer.update(grads, opt_state, params)\n", + " return optax.apply_updates(params, updates), new_opt_state\n", + "\n", + "# JIT compile the step function with proper sharding\n", + "step = jax.jit(step, donate_argnums=(0, 1))\n", + "\n", + "# Run a optimization step\n", + "new_params, new_opt_state = step(params, opt_state, input)\n", + "\n", + "# Analyze memory usage\n", + "compiled_step = step.lower(params, opt_state, input).compile()\n", + "compiled_stats = compiled_step.memory_analysis()\n", + "\n", + "if compiled_stats is not None:\n", + " total = compiled_stats.temp_size_in_bytes + compiled_stats.argument_size_in_bytes \\\n", + " + compiled_stats.output_size_in_bytes - compiled_stats.alias_size_in_bytes\n", + " print(f\"Temp size: {compiled_stats.temp_size_in_bytes / (1024**3):.2f} GB\")\n", + " print(f\"Argument size: {compiled_stats.argument_size_in_bytes / (1024**3):.2f} GB\")\n", + " print(f\"Total size: {total / (1024**3):.2f} GB\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "oW4Qm6E5VOyV" + }, + "source": [ + "Optimizer state offloading can be implemented as follows.\n", + "\n", + "### Setting Up Sharding and Memory Kinds\n", + "\n", + "{func}`jax.sharding.SingleDeivceSharding` is adopted to simplify the shardings for both device and host memory kinds. During the model state initialization, move the optimizer state to the host using {func}`device_put`.\n", + "\n", + "### Model and Training Step Implementation\n", + "\n", + "Next, define the model architecture, loss function, and training step. The key addition here is moving the optimizer state to device memory via {func}`device_put` at the beginning of each training step, as it's needed for the parameter update on the device.\n", + "\n", + "### Running and Comparing Results\n", + "\n", + "After setting up the sharding, the optimizer state is moved to host memory and the step function is run with {func}`jax.jit`.\n", + "\n", + "The JIT compilation of the step function uses several important parameters:\n", + "- `donate_argnums=(0,)`: Indicates that the first argument (parameters) can be modified in-place, allowing JAX to reuse its memory\n", + "- `out_shardings`: Specifies how output tensors should be sharded across the mesh (devices and hosts)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "fEDTasJZVOyW", + "outputId": "b36cedd6-cf30-4d36-f4fd-32b2fdfd7564" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Temp size: 1.91 GB\n", + "Argument size: 0.96 MB\n", + "Total size: 2.87 GB\n" + ] + } + ], + "source": [ + "# Create sharding specifications for device and host memory\n", + "s_dev = jax.sharding.SingleDeviceSharding(jax.devices()[0], memory_kind=\"device\")\n", + "s_host = jax.sharding.SingleDeviceSharding(jax.devices()[0], memory_kind=\"pinned_host\")\n", + "\n", + "def step(params, opt_state, inputs):\n", + " grads = jax.grad(lambda p: compute_loss(p, inputs))(params)\n", + " opt_state = jax.device_put(opt_state, s_dev)\n", + " updates, new_opt_state = optimizer.update(grads, opt_state, params)\n", + " new_params = optax.apply_updates(params, updates)\n", + " return new_params, new_opt_state\n", + "\n", + "params = {f'w{i}': jnp.ones((DIM, DIM)) for i in range(1, 5)}\n", + "opt_state = optimizer.init(params)\n", + "\n", + "# Initialize optimizer\n", + "optimizer = optax.chain(\n", + " optax.clip_by_global_norm(1.0),\n", + " optax.adam(learning_rate=0.1)\n", + ")\n", + "\n", + "# Optimizer state is placed on the host during initialization\n", + "opt_state = jax.device_put(opt_state, s_host)\n", + "\n", + "# JIT compile the step function with proper sharding and memory optimization\n", + "step = jax.jit(\n", + " step,\n", + " donate_argnums=(0,),\n", + " out_shardings=(s_dev, s_host)\n", + ")\n", + "\n", + "# Run an optimization step\n", + "new_params, offload_opt_state = step(params, opt_state, input)\n", + "\n", + "# Analyze memory usage\n", + "compiled_step = step.lower(params, opt_state, input).compile()\n", + "compiled_stats = compiled_step.memory_analysis()\n", + "if compiled_stats is not None:\n", + " total = compiled_stats.temp_size_in_bytes + compiled_stats.argument_size_in_bytes \\\n", + " + compiled_stats.output_size_in_bytes - compiled_stats.alias_size_in_bytes\n", + " print(f\"Temp size: {compiled_stats.temp_size_in_bytes / (1024**3):.2f} GB\")\n", + " print(f\"Argument size: {compiled_stats.argument_size_in_bytes / (1024**3):.2f} MB\")\n", + " print(f\"Total size: {total / (1024**3):.2f} GB\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vKo8qYnQVOyW" + }, + "source": [ + "This implementation demonstrates how to:\n", + "1. Set up sharding specifications for `device` and `pinned_host`\n", + "2. Move optimizer states between host and device memory via {func}`jax.device_put`\n", + "3. Use `out_shardings` to ensure proper memory placement\n", + "4. Show the memory usage\n", + "\n", + "This implementation demonstrates how offloading optimizer state to host memory can reduce device memory usage through a trade-off between argument size and temporary memory.\n", + "\n", + "Memory Analysis:\n", + "1. Argument Size Reduction:\n", + " - The optimizer states are arguments of the {func}`jax.jit` function\n", + " - By offloading these states to host memory, the argument size on device is reduced\n", + "\n", + "2. Temporary Memory Impact:\n", + " - Offloading increases temporary memory usage\n", + " - This is because outputs of optimizer states need memory buffers before being copied to host\n", + " - The memory live ranges for these temporary buffers are extended due to the host-device transfers\n", + "\n", + "3. Latency Hiding Scheduling:\n", + " - JAX uses XLA's latency hiding scheduling to overlap computation with host-device transfers\n", + " - The overlapping can cause tensors to have larger live ranges, which increases memory pressure on the device\n", + " - This adaptive behavior helps maintain stable memory usage while still providing some performance benefits\n", + "\n", + "4. Memory Trade-off:\n", + " - Total memory size with offloading: 2.87 GB\n", + " - Total memory size without offloading: 4.59 GB\n", + " - Net memory saving: 1.72 GB\n", + "\n", + "while offloading increases temporary memory usage, the reduction in argument size more than compensates for this increase, resulting in an overall reduction in device memory usage.\n", + "\n", + "Note: The optimizer states can be compared for numerical equivalence using `jax.tree_util.tree_map` and `jnp.allclose`, but this verification step is omitted here for brevity.\n", + "\n", + "## Tools for Host Offloading\n", + "\n", + "{func}`jax.stages.Compiled.memory_analysis` API is utilized above to get memory usage information. For device memory analysis, refer to {doc}`../device_memory_profiling`. The profiling tools described in {doc}`../profiling` can help measure memory savings and performance impact from host offloading." + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [], + "toc_visible": true + }, + "jupytext": { + "formats": "ipynb,md:myst" + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/docs/notebooks/host-offloading.md b/docs/notebooks/host-offloading.md new file mode 100644 index 000000000000..73b5cd819f81 --- /dev/null +++ b/docs/notebooks/host-offloading.md @@ -0,0 +1,640 @@ +--- +jupytext: + formats: ipynb,md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.16.4 +kernelspec: + display_name: Python 3 + name: python3 +--- + ++++ {"id": "bQbS50fIdHw1"} + +(host-offloading)= +# JAX Memories and Host Offloading + + + +This tutorial provides a practical introduction to host offloading techniques in JAX, focusing on: + +- Activation offloading +- Parameter offloading +- Optimizer state offloading + +By applying offloading strategies, developers can better manage memory resources and reduce memory pressure on devices. To implement these strategies effectively, understanding JAX's core mechanisms for data placement and movement is essential. + +## Building Blocks for Offloading + +JAX provides several key components for controlling where and how data are stored and moved between the host and the device memory. The following sections explore: + +- How to specify data distribution with sharding +- How to control memory placement between host and device +- How to manage data movement in jitted functions + +### NamedSharding and Memory Kinds + +{class}`~jax.sharding.NamedSharding` defines how data are distributed across devices. It includes: + +- Basic data distribution configuration +- `memory_kind` parameter for specifying memory type (`device` or `pinned_host`) +- By default, `memory_kind` is set to `device` memory +- `with_memory_kind` method for creating new sharding with modified memory type + +```{code-cell} ipython3 +--- +colab: + base_uri: https://localhost:8080/ +id: f-6sxUlqrlBn +outputId: 691a3df2-8341-44a9-a4a0-5521c2d891e3 +--- +import jax +import jax.numpy as jnp +from jax.sharding import Mesh, NamedSharding, PartitionSpec as P +import numpy as np + +# Create mesh +# 1x1 mesh represents a single device with two named dimensions (x and y) +mesh = Mesh(np.array(jax.devices()[0]).reshape(1, 1), ('x', 'y')) + +# Device sharding - partitions data along x and y dimensions +s_dev = NamedSharding(mesh, P('x', 'y'), memory_kind="device") + +# Host sharding - same partitioning but in pinned host memory +s_host = s_dev.with_memory_kind('pinned_host') + +print(s_dev) # Shows device memory sharding +print(s_host) # Shows pinned host memory sharding +``` + ++++ {"id": "R_pB9465VoMP"} + +### Data Placement with device_put + +{func}`jax.device_put` is a function that explicitly transfers arrays to a specified memory location according to a sharding specification. + +```{code-cell} ipython3 +--- +colab: + base_uri: https://localhost:8080/ +id: OJFnf7FGp6Lj +outputId: c762e1df-2453-4ed9-9d53-0defb6a05ce2 +--- +# Create a 2x4 array +arr = jnp.arange(8.0).reshape(2, 4) + +# Move arrays to different memory locations based on sharding objects +arr_host = jax.device_put(arr, s_host) # Places in pinned host memory +arr_dev = jax.device_put(arr, s_dev) # Places in device memory + +# Verify memory locations +print(arr_host.sharding.memory_kind) # Output: pinned_host +print(arr_dev.sharding.memory_kind) # Output: device +``` + ++++ {"id": "HHXvBpQKTMCR"} + +### Output Sharding Controls + +Shardings determine how data is split across devices. JAX provides `out_shardings` to control how output arrays are partitioned when leaving a jitted function. + +Key Features: + - Can differ from input sharding + - Allows different memory kinds for outputs + +Examples: + +#### Device Output Sharding + +```{code-cell} ipython3 +--- +colab: + base_uri: https://localhost:8080/ +id: ZXNj9NUeaIdX +outputId: 399321ef-082a-4a77-c33a-9de3421f429b +--- +f = jax.jit(lambda x:x, out_shardings=s_dev) +out_dev = f(arr_host) +print("Result value of H2D: \n", out_dev) +``` + ++++ {"id": "iYXC5ix384XP"} + +Moving data from host to device memory when needed for computation is the essence of host offloading. Use {func}`jax.device_put` to perform this transfer in this example to optimize performance. + +```{code-cell} ipython3 +--- +colab: + base_uri: https://localhost:8080/ +id: cmM6tJTS84XQ +outputId: 40c353a1-fb55-44bc-bac9-dffc09852f49 +--- +# Instead of the lambda function, add_func can be defined explicitly +# move data to device before computation +def add_func(x): # Move data to device and add one + x = jax.device_put(x, s_dev) + return x + 1 + +f = jax.jit(add_func, out_shardings=s_dev) +out_dev = f(arr_host) +print("Result value of H2D and add 1 in device memory: \n", out_dev) +``` + ++++ {"id": "EbE-eBrJTBuS"} + +#### Host Output Sharding + +```{code-cell} ipython3 +--- +colab: + base_uri: https://localhost:8080/ +id: FjZzkxI8ky4r +outputId: 2a1b6e7a-1c29-4347-c020-7b47c27a5cc3 +--- +f = jax.jit(lambda x: x, out_shardings=s_host) +out_host = f(arr_dev) # Input arrays in the device memory while output arrays in the host memory +print("Result value of D2H: \n", out_host) +``` + ++++ {"id": "UhLVvRO2p6Lj"} + +## Activation Offloading + +Before diving into activation offloading, let's first take a look at the baseline code. + +This code implements a simple neural network with 10 layers, each consisting of two linear transformations. The code demonstrates basic memory usage patterns and provides a foundation for comparing offloading optimization techniques. + +Key components: +- Each layer consists of two sequential linear operations: + 1. First multiplication: `x @ w1` + 2. Second multiplication: `y @ w2` +- 10-layer network using JAX's scan operation +- Memory usage analysis +- Gradient computation with JIT compilation + +To analyze memory usage in JAX, the {func}`jax.stages.Compiled.memory_analysis` method can be used on a compiled function. This provides detailed statistics about memory consumption during computation. The key metrics include temporary memory size, argument size, output size, and alias size. To calculate the total memory usage, sum the temporary, argument, and output sizes, then subtract the alias size to avoid double-counting the same memory multiple times. This provides a summarized view of how the device memory is utilized across different aspects of the computation. + +```{code-cell} ipython3 +--- +colab: + base_uri: https://localhost:8080/ +id: UEt0dtxukkaz +outputId: 22bb32b7-8491-4100-f212-e56c50f44cfa +--- +# Initialize input and weights with small values (0.0001) +input = jnp.ones((256, 256), dtype=jnp.float32) * 0.001 # Input matrix: 256 x 256 +w1 = jnp.ones((10, 256, 1024), dtype=jnp.float32) * 0.001 # 10 layers of 256 x 1024 matrices +w2 = jnp.ones((10, 1024, 256), dtype=jnp.float32) * 0.001 # 10 layers of 1024 x 256 matrices + +def two_layers(x, w): + # Simple two-layer linear transformation + w1, w2 = w + y = x @ w1 + return y @ w2, None + +def scanned(w, x): + # Applies the layer function 10 times using JAX's scan operation + # Input: w (tuple of weight matrices), x (input matrix) + # Output: sum of the final layer's output + result = jax.lax.scan(two_layers, x, w)[0] + return jnp.sum(result) + +# Compile and compute gradients of the scanned function +f = jax.jit(jax.grad(scanned)) # Apply JIT compilation to gradient computation + +# Analyze memory usage +compiled_step = f.lower((w1, w2), input).compile() +compiled_stats = compiled_step.memory_analysis() + +if compiled_stats is not None: + # Calculate total memory usage including temporary storage, arguments, and outputs + # Subtract alias size to avoid double-counting memory shared between different components + total = compiled_stats.temp_size_in_bytes + compiled_stats.argument_size_in_bytes \ + + compiled_stats.output_size_in_bytes - compiled_stats.alias_size_in_bytes + print(f"Temp size: {compiled_stats.temp_size_in_bytes / (1024**2):.2f} MB") + print(f"Argument size: {compiled_stats.argument_size_in_bytes / (1024**2):.2f} MB") + print(f"Total size: {total/(1024**2):.2f} MB") + +# Execute the function and print sample results +result = f((w1, w2), input) # Execute the function with weights and input +print("Sample of results: ", result[0][0, 0, :5]) +``` + ++++ {"id": "DnFyRt2nkkaz"} + +The detailed coverage of activation offloading can be found in the {ref}`gradient-checkpointing` tutorial. Activation offloading helps manage memory by moving intermediate activations to host memory after the forward pass, and bringing them back to device memory during the backward pass when needed for gradient computation. + +To implement activation offloading effectively, it is important to understand checkpoint names and policies. Here's how they work in a simple example: + +### Checkpoint Names + +The {func}`checkpoint_name` function allows labeling activations for memory management during computation. Here's a simple example that a checkpoint name `x` is specified. + +```{code-cell} ipython3 +:id: sLO9ceS6p6Lj + +from jax.ad_checkpoint import checkpoint_name + +def layer_name(x, w): + w1, w2 = w + x = checkpoint_name(x, "x") + y = x @ w1 + return y @ w2, None +``` + ++++ {"id": "-_T92oCOp6Lk"} + +The checkpoint name helps the system decide whether to: +* Keep the activation in device memory or +* Offload it to host memory during computation + +This pattern is common in neural networks, where multiple transformations are applied sequentially to input data. + +### Checkpoint Policies + +This checkpoint policy implements a memory management strategy that optimizes memory usage during computation. It manages memory by handling intermediate values through three strategies: +1. Recomputing during backward pass (default behavior) +2. Storing on device +3. Offloading to host memory after forward pass and loading back during backward pass + +```{code-cell} ipython3 +:id: W8Usw_wOp6Lk + +from jax import checkpoint_policies as cp + +policy = cp.save_and_offload_only_these_names( + names_which_can_be_saved=[], # No values stored on device + names_which_can_be_offloaded=["x"], # Offload activations labeled "x" + offload_src="device", # Move from device memory + offload_dst="pinned_host" # To pinned host memory +) +``` + ++++ {"id": "iuDRCXu7ky4r"} + +{func}`jax.lax.scan` is commonly used in JAX for handling sequential operations (like RNNs or transformers). It can be integrated with JAX's rematerialization to process sequential data. + +Key components: +* {func}`jax.remat` creates a rematerialized version of the layer function using {func}`jax.remat` and applies the checkpoint policy to the layer function +* `prevent_cse=False` enables XLA's common subexpression elimination for better performance +* {func}`jax.lax.scan` iterates the rematerialized layer along an axis + +```{code-cell} ipython3 +--- +colab: + base_uri: https://localhost:8080/ +id: xCrxjTx_p6Lk +outputId: 13d46584-9b25-4622-b3c3-f50c1dac02c2 +--- +def scanned(w, x): + remat_layer = jax.remat(layer_name, + policy=policy, # Use our offloading policy + prevent_cse=False) # Allow CSE optimizations + result = jax.lax.scan(remat_layer, x, w)[0] + return jnp.sum(result) + +# Initialize input and weights with small values (0.0001) +input = jnp.ones((256, 256), dtype=jnp.float32) * 0.001 # Input matrix: 256 x 256 +w1 = jnp.ones((10, 256, 1024), dtype=jnp.float32) * 0.001 # 10 layers of 256 x 1024 matrices +w2 = jnp.ones((10, 1024, 256), dtype=jnp.float32) * 0.001 # 10 layers of 1024 x 256 matrices + +# Compile and compute gradients of the scanned function +f = jax.jit(jax.grad(scanned)) # Apply JIT compilation to gradient computation + +# Analyze memory usage +compiled_step = f.lower((w1, w2), input).compile() +compiled_stats = compiled_step.memory_analysis() + +if compiled_stats is not None: + total = compiled_stats.temp_size_in_bytes + compiled_stats.argument_size_in_bytes \ + + compiled_stats.output_size_in_bytes - compiled_stats.alias_size_in_bytes + print(f"Temp size: {compiled_stats.temp_size_in_bytes / (1024**2):.2f} MB") + print(f"Argument size: {compiled_stats.argument_size_in_bytes / (1024**2):.2f} MB") + print(f"Total size: {total/(1024**2):.2f} MB") + +result_activation = f((w1, w2), input) # Execute the function with weights and input +# Verify numerical correctness +are_close = jnp.allclose( + result_activation[0], # Result from activation offloading only + result[0], # Result from both activation and parameter offloading + rtol=1e-5, + atol=1e-5 +) +print(f"Results match within tolerance: {are_close}") +print("Sample of results: ", result_activation[0][0, 0, :5]) +``` + ++++ {"id": "0tx7aara42pY"} + +Activation offloading reduces temporary memory usage from 17.25 MB to 6.5 MB while input and output argument sizes remain the same. Totally 10.75 MB is saved. It is achieved by offloading activation `x` to host memory after the forward pass and loading it back to device memory before the backward pass. + +### Summary of Activation Offloading + +Activation offloading provides a powerful way to manage memory in large computations by: + +* Using checkpoint names to mark specific activations +* Applying policies to control where and how activations are stored +* Supporting common JAX patterns like scan operations +* Moving selected activations to host memory when device memory is under budget + +This approach is particularly useful when working with large models that would otherwise exceed device memory capacity. + +## Parameter Offloading + +Model parameters (also known as weights) can be offloaded to the host memory to optimize device memory usage during initialization. This is achieved by using {func}`jax.jit` with a sharding strategy that specifies host memory kind. + +While parameter offloading and activation offloading are distinct memory optimization techniques, the following example demonstrates parameter offloading built upon the activation offloading implementation shown earlier. + +### Parameter Placement for Computation + +Different from the earlier `layer` function, {func}`jax.device_put` is applied to move parameter `w1` and `w2` to the device before the matrix multiplications. This ensures the parameters are available on the device for both forward and backward passes. + +Note that the activation offloading implementation remains unchanged, using the same: +* Checkpoint name `"x"` +* Checkpoint policy +* `scanned` function combining {func}`jax.remat` and {func}`jax.lax.scan` + +### Parameter Initialization with Host Offloading + +During the initialization, parameter `w1` and `w2` are placed on host memory before being passed to the {func}`jax.jit` function `f`, while keeping the `input` variable on the device. + +```{code-cell} ipython3 +--- +colab: + base_uri: https://localhost:8080/ +id: 1qGN2hBQdheo +outputId: 48c09658-f8b6-4be3-ef0e-02e0e2566e10 +--- +# Hybrid version: Both activation and parameter offloading +def hybrid_layer(x, w): + # Move model parameters w1 and w2 to device memory via device_put + w1, w2 = jax.tree.map(lambda x: jax.device_put(x, s_dev), w) + x = checkpoint_name(x, "x") # Offload activation x to host memory + y = x @ w1 + return y @ w2, None + +def hybrid_scanned(w, x): + remat_layer = jax.remat(hybrid_layer, # Use hybrid_layer instead of layer + policy=policy, # Use offloading policy + prevent_cse=False) # Allow CSE optimizations + result = jax.lax.scan(remat_layer, x, w)[0] + return jnp.sum(result) + +# Move model parameters w1 and w2 to the host via device_put +# Initialize input and weights with small values (0.0001) +wh1 = jax.device_put(w1, s_host) +wh2 = jax.device_put(w2, s_host) + +# Compile and compute gradients of the scanned function +f = jax.jit(jax.grad(hybrid_scanned)) # Apply JIT compilation to gradient computation + +# Analyze memory usage +compiled_step = f.lower((wh1, wh2), input).compile() +compiled_stats = compiled_step.memory_analysis() + +if compiled_stats is not None: + total = compiled_stats.temp_size_in_bytes + compiled_stats.argument_size_in_bytes \ + + compiled_stats.output_size_in_bytes - compiled_stats.alias_size_in_bytes + print(f"Temp size: {compiled_stats.temp_size_in_bytes / (1024**2):.2f} MB") + print(f"Argument size: {compiled_stats.argument_size_in_bytes / (1024**2):.2f} MB") + print(f"Total size: {total / (1024**2):.2f} MB") + +result_both = f((wh1, wh2), input) # Execute with both activation and parameter offloading + +# Verify numerical correctness +are_close = jnp.allclose( + result_activation[0], # Result from activation offloading only + result_both[0], # Result from both activation and parameter offloading + rtol=1e-5, + atol=1e-5 +) +print(f"Results match within tolerance: {are_close}") +``` + ++++ {"id": "SVpozzwHflQk"} + +This implementation demonstrates how offloading model parameters together with activation offloading to host memory can significantly reduce device memory usage. + +### Memory Analysis + +**Baseline Memory Usage:** +- Input tensor: 0.25 MB (256 × 256 × 4 bytes) +- Model parameters (w1, w2): 10 MB each (256 × 1024 × 4 bytes ≈ 1 MB per layer × 10 layers) + +**Memory Usage Comparison:** +- Argument size without parameter offloading: 20.25 MB (0.25 + 10 + 10) +- Argument size with parameter offloading: 0.25 MB (only input remains) +- Temporary memory without activation offloading: 17.25 MB +- Temporary memory with activation offloading: 6.50 MB +- Temporary memory with activation and parameter offloading: 4.75 MB + +#### Key Optimizations + +1. **Parameter Offloading**: Moving parameters (w1, w2) to host memory reduces argument size by 20 MB (from 20.25 MB to 0.25 MB). + +2. **Activation Offloading**: Moving activations to host memory reduces temporary memory usage by 10.75 MB (from 17.25 to 6.50 MB). + +3. **Hybrid Strategy**: The rematerialization of activation offloading helps avoid keeping weights on the device and reduce temporary memory usage by 1.75 MB (from 6.50 MB to 4.75 MB). Without it, JAX would be eager to keep the on-device copies of the weights alive for the backward pass. + +#### Results + +**Total Memory Savings**: 33.5 MB (20 MB + 10.75 MB + 1.75 MB) + +This hybrid approach demonstrates that parameter and activation offloading work synergistically to achieve significant memory reductions while maintaining computational correctness. + +### Limitations of Parameter Offloading + +{func}`jax.lax.scan` is crucial for effective parameter management. Using an explicit for loop would cause parameters to continuously occupy device memory, resulting in the same memory usage as without parameter offloading. While {func}`jax.lax.scan` allows specifying the scan axis, parameter offloading currently works only when scanning over axis 0. Scanning over other axes generates a `transpose` operation during compilation before returning parameters to the device, which is expensive and not supported on all platforms. + +The offloading performance can vary for different device types. It may degrade performance due to memory transfers between host and device, so it's important to consider this trade-off when designing your optimization strategy. + +# Optimizer State Offloading + +Optimizer state offloading is a memory management technique that stores optimizer states in host memory instead of device memory. This approach is particularly useful when optimizer states are large, as it reduces device memory usage. + +A basic JAX implementation using the Adam optimizer can serve as a starting point, where all tensors are stored on the device. This will serve as a reference implementation before introducing optimizer state offloading. + +### Basic Implementation + +This section, let's implement a simple model with the Adam optimizer. This implementation helps establish the baseline behavior before exploring optimizer state offloading. It is particularly useful for understanding memory patterns in large-scale neural network training. + +In the code example below, a neural network training loop is included to use JAX and Optax's Adam optimizer. The network consists of four linear layers with GELU activation functions, processing large matrices of size 7168x7168. The training process involves: +- Forward pass: The input flows through four layers, each applying a linear transformation followed by GELU activation +- Loss computation: Calculates mean squared error between output and input, plus L2 regularization +- Backward pass: Computes gradients using automatic differentiation +- Optimization step: Updates parameters using Adam optimizer with gradient clipping + +The code uses JIT compilation to optimize performance and includes memory usage analysis to monitor the computational resources required during training. The memory analysis provides insights into temporary memory usage, argument sizes, and total memory consumption during the optimization step. + +```{code-cell} ipython3 +--- +colab: + base_uri: https://localhost:8080/ +id: ujvC0YJ2VOyV +outputId: d237ca0a-89ae-4e14-edd3-36cc38890349 +--- +import optax + +DIM = 7168 + +# Initialize data and parameter w1, w2, w3 and w4 +input = jnp.ones((DIM, DIM)) +params = {f'w{i}': jnp.ones((DIM, DIM)) for i in range(1, 5)} + +# Initialize optimizer +optimizer = optax.chain( + optax.clip_by_global_norm(1.0), + optax.adam(learning_rate=0.1) +) +opt_state = optimizer.init(params) + +def gelu(x): + return 0.5 * x * (1 + jnp.tanh(jnp.sqrt(2 / jnp.pi) * (x + 0.044715 * x**3))) + +def single_layer(x, w): + return x @ w + +def forward(params, x): + for i in range(1, 5): + x = gelu(single_layer(x, params[f'w{i}'])) + return x + +def compute_loss(params, inputs): + outputs = forward(params, inputs) + loss = jnp.mean((outputs - inputs) ** 2) + l2_reg = 0.001 * sum(jnp.sum(w ** 2) for w in jax.tree_util.tree_leaves(params)) + return loss + l2_reg + +def step(params, opt_state, inputs): + grads = jax.grad(lambda p: compute_loss(p, inputs))(params) + updates, new_opt_state = optimizer.update(grads, opt_state, params) + return optax.apply_updates(params, updates), new_opt_state + +# JIT compile the step function with proper sharding +step = jax.jit(step, donate_argnums=(0, 1)) + +# Run a optimization step +new_params, new_opt_state = step(params, opt_state, input) + +# Analyze memory usage +compiled_step = step.lower(params, opt_state, input).compile() +compiled_stats = compiled_step.memory_analysis() + +if compiled_stats is not None: + total = compiled_stats.temp_size_in_bytes + compiled_stats.argument_size_in_bytes \ + + compiled_stats.output_size_in_bytes - compiled_stats.alias_size_in_bytes + print(f"Temp size: {compiled_stats.temp_size_in_bytes / (1024**3):.2f} GB") + print(f"Argument size: {compiled_stats.argument_size_in_bytes / (1024**3):.2f} GB") + print(f"Total size: {total / (1024**3):.2f} GB") +``` + ++++ {"id": "oW4Qm6E5VOyV"} + +Optimizer state offloading can be implemented as follows. + +### Setting Up Sharding and Memory Kinds + +{func}`jax.sharding.SingleDeivceSharding` is adopted to simplify the shardings for both device and host memory kinds. During the model state initialization, move the optimizer state to the host using {func}`device_put`. + +### Model and Training Step Implementation + +Next, define the model architecture, loss function, and training step. The key addition here is moving the optimizer state to device memory via {func}`device_put` at the beginning of each training step, as it's needed for the parameter update on the device. + +### Running and Comparing Results + +After setting up the sharding, the optimizer state is moved to host memory and the step function is run with {func}`jax.jit`. + +The JIT compilation of the step function uses several important parameters: +- `donate_argnums=(0,)`: Indicates that the first argument (parameters) can be modified in-place, allowing JAX to reuse its memory +- `out_shardings`: Specifies how output tensors should be sharded across the mesh (devices and hosts) + +```{code-cell} ipython3 +--- +colab: + base_uri: https://localhost:8080/ +id: fEDTasJZVOyW +outputId: b36cedd6-cf30-4d36-f4fd-32b2fdfd7564 +--- +# Create sharding specifications for device and host memory +s_dev = jax.sharding.SingleDeviceSharding(jax.devices()[0], memory_kind="device") +s_host = jax.sharding.SingleDeviceSharding(jax.devices()[0], memory_kind="pinned_host") + +def step(params, opt_state, inputs): + grads = jax.grad(lambda p: compute_loss(p, inputs))(params) + opt_state = jax.device_put(opt_state, s_dev) + updates, new_opt_state = optimizer.update(grads, opt_state, params) + new_params = optax.apply_updates(params, updates) + return new_params, new_opt_state + +params = {f'w{i}': jnp.ones((DIM, DIM)) for i in range(1, 5)} +opt_state = optimizer.init(params) + +# Initialize optimizer +optimizer = optax.chain( + optax.clip_by_global_norm(1.0), + optax.adam(learning_rate=0.1) +) + +# Optimizer state is placed on the host during initialization +opt_state = jax.device_put(opt_state, s_host) + +# JIT compile the step function with proper sharding and memory optimization +step = jax.jit( + step, + donate_argnums=(0,), + out_shardings=(s_dev, s_host) +) + +# Run an optimization step +new_params, offload_opt_state = step(params, opt_state, input) + +# Analyze memory usage +compiled_step = step.lower(params, opt_state, input).compile() +compiled_stats = compiled_step.memory_analysis() +if compiled_stats is not None: + total = compiled_stats.temp_size_in_bytes + compiled_stats.argument_size_in_bytes \ + + compiled_stats.output_size_in_bytes - compiled_stats.alias_size_in_bytes + print(f"Temp size: {compiled_stats.temp_size_in_bytes / (1024**3):.2f} GB") + print(f"Argument size: {compiled_stats.argument_size_in_bytes / (1024**3):.2f} MB") + print(f"Total size: {total / (1024**3):.2f} GB") +``` + ++++ {"id": "vKo8qYnQVOyW"} + +This implementation demonstrates how to: +1. Set up sharding specifications for `device` and `pinned_host` +2. Move optimizer states between host and device memory via {func}`jax.device_put` +3. Use `out_shardings` to ensure proper memory placement +4. Show the memory usage + +This implementation demonstrates how offloading optimizer state to host memory can reduce device memory usage through a trade-off between argument size and temporary memory. + +Memory Analysis: +1. Argument Size Reduction: + - The optimizer states are arguments of the {func}`jax.jit` function + - By offloading these states to host memory, the argument size on device is reduced + +2. Temporary Memory Impact: + - Offloading increases temporary memory usage + - This is because outputs of optimizer states need memory buffers before being copied to host + - The memory live ranges for these temporary buffers are extended due to the host-device transfers + +3. Latency Hiding Scheduling: + - JAX uses XLA's latency hiding scheduling to overlap computation with host-device transfers + - The overlapping can cause tensors to have larger live ranges, which increases memory pressure on the device + - This adaptive behavior helps maintain stable memory usage while still providing some performance benefits + +4. Memory Trade-off: + - Total memory size with offloading: 2.87 GB + - Total memory size without offloading: 4.59 GB + - Net memory saving: 1.72 GB + +while offloading increases temporary memory usage, the reduction in argument size more than compensates for this increase, resulting in an overall reduction in device memory usage. + +Note: The optimizer states can be compared for numerical equivalence using `jax.tree_util.tree_map` and `jnp.allclose`, but this verification step is omitted here for brevity. + +## Tools for Host Offloading + +{func}`jax.stages.Compiled.memory_analysis` API is utilized above to get memory usage information. For device memory analysis, refer to {doc}`../device_memory_profiling`. The profiling tools described in {doc}`../profiling` can help measure memory savings and performance impact from host offloading. diff --git a/docs/notebooks/layout.ipynb b/docs/notebooks/layout.ipynb new file mode 100644 index 000000000000..144ff15130c3 --- /dev/null +++ b/docs/notebooks/layout.ipynb @@ -0,0 +1,276 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Device-local array layout control\n", + "\n", + "The `jax.experimental.layout` package provides ways to control\n", + "how JAX arrays are laid out in device-local memory.\n", + "\n", + "## Terminology\n", + "\n", + "Array layout is tightly coupled with array\n", + "[sharding](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html>).\n", + "Together, a layout and a sharding fully describes how an array's\n", + "values are laid out across (distributed) memories. Along these lines,\n", + "we use the following terminology:\n", + "\n", + "* **Layout**: how an array's values are laid out within each memory in\n", + " which they reside (e.g., in the memory of a single device\n", + " memory). A typical layout specification is a minor-to-major order\n", + " listing of array dimensions.\n", + "* **Sharding**: how an array's values are distributed *across*\n", + " different memory spaces, such as multiple device memories\n", + " (e.g. described by sharding some dimensions and replicating\n", + " others).\n", + "* **Format**: the pairing of **layout** and **sharding**,\n", + " providing a complete picture of an array's memory placement.\n", + "\n", + "## Types\n", + "\n", + "There are two Python types that come up when controlling array\n", + "layouts: `Layout` and `Format`.\n", + "\n", + "* The `Layout` class is used to define the in-memory\n", + " layout of an array. It has the following key attributes:\n", + "\n", + " * `major_to_minor`: A tuple of integers specifying the dimension\n", + " ordering in memory. For example, for a 2-dimensional array, `(0, 1)`\n", + " indicates row-major layout and `(1, 0)` indicates column-major.\n", + "\n", + " * `_tiling`: An intentionally hidden, highly experimental, optional\n", + " attribute to specify a tiled layout.\n", + "\n", + " * `AUTO`: A special, static sentinel object that can be used with\n", + " `jax.jit` to request that the compiler automatically determine\n", + " a good layout for a compiled function's input or output arrays.\n", + "\n", + "* The `Format` class carries both a `Layout` and a `Sharding`, with\n", + " either one taking on a default value when it is not specified.\n", + " When the layout is explicitly specified, the sharding must be\n", + " as well.\n", + "\n", + "JAX API functions, such as `jax.jit` and `jax.device_put`, accept\n", + "`Sharding`s for sharding control or `Format`s for additional layout\n", + "control. They typically do not accept `Layout` instances directly.\n", + "\n", + "## Specifying and reading layouts\n", + "\n", + "By passing `Format` objects to `jax.jit` in place of shardings (in the\n", + "`in_shardings` and `out_shardings` arguments), you can guide the\n", + "compiler's layout decisions. Similarly you can pass `Format`s instead\n", + "of `Sharding`s to `jax.device_put` to control the layout of the\n", + "resulting array.\n", + "\n", + "Let's see an example that uses both explicit and automatic layouts (as\n", + "in `Layout.AUTO`). Imagine we have two compiled functions, `init_fn`\n", + "and `apply_fn`. Say we expect `init_fn` to be called roughly once, but\n", + "`apply_fn` to be called on the output of `init_fn` many times, so that\n", + "we care much more about the performance of `apply_fn`. We may want to\n", + "have the compiler choose a good layout for `apply_fn` and constrain\n", + "`init_fn` to produce arrays of such layout. We can do this as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import jax, jax.numpy as jnp\n", + "from jax.experimental.layout import Layout, Format\n", + "from jax.sharding import SingleDeviceSharding\n", + "import numpy as np\n", + "\n", + "def init_fn(x, y):\n", + " return x * 2, y * 3\n", + "\n", + "def apply_fn(x, y):\n", + " return x[0, :], y[:, 0]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Since `apply_fn` reads a contiguous column of its second argument `y`,\n", + "it makes sense to lay it out in column-major order (where columns are\n", + "stored contiguously). Using `Layout.AUTO`, we can ask the compiler to\n", + "infer good input layouts and see that it indeed chooses to request the\n", + "second argument in column-major layout." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "shape = (4 * 128, 8 * 128)\n", + "duck = jax.ShapeDtypeStruct(shape, jnp.float32)\n", + "\n", + "# Compile the `apply` function with layouts inferred automatically\n", + "apply_exe = jax.jit(\n", + " apply_fn,\n", + " in_shardings=Format(Layout.AUTO),\n", + " out_shardings=Format(Layout.AUTO),\n", + ").trace(duck, duck).lower().compile()\n", + "\n", + "# Read back the inferred input layout\n", + "arg_formats, kwarg_formats = apply_exe.input_formats\n", + "assert len(kwarg_formats) == 0\n", + "assert arg_formats[0].layout.major_to_minor == (0, 1)\n", + "assert arg_formats[1].layout.major_to_minor == (1, 0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can then compile `init_fn` to explicitly match this layout in its\n", + "outputs." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "init_exe = jax.jit(init_fn, out_shardings=arg_formats).trace(\n", + " duck, duck).lower().compile()\n", + "\n", + "assert init_exe.output_formats == arg_formats" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally we can see how the compiled `apply_fn` behaves when called\n", + "with differently laid out input arrays. The behavior varies with\n", + "whether inputs are\n", + "[committed](https://docs.jax.dev/en/latest/faq.html#controlling-data-and-computation-placement-on-devices). As\n", + "the following test demonstrates, if the argument arrays are committed,\n", + "then the pre-compiled `apply_fn` requires they match the layout\n", + "determined by the compiler above. Meanwhile it accepts uncommitted\n", + "arrays of any layout (including, of course, the inferred layout). In\n", + "this case, the arrays may be relaid out prior to invoking the compiled\n", + "computation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "-- uncommitted with mismatched layout:\n", + "x major_to_minor = (0, 1)\n", + "y major_to_minor = (0, 1)\n", + "-> `apply` called successfully\n", + "\n", + "-- uncommitted with matching layout:\n", + "x major_to_minor = (0, 1)\n", + "y major_to_minor = (1, 0)\n", + "-> `apply` called successfully\n", + "\n", + "-- committed with matching layout:\n", + "x major_to_minor = (0, 1)\n", + "y major_to_minor = (1, 0)\n", + "-> `apply` called successfully\n", + "\n", + "-- committed with mismatched layout:\n", + "x major_to_minor = (0, 1)\n", + "y major_to_minor = (0, 1)\n", + "-> error: mismatched input layouts\n", + "\n" + ] + } + ], + "source": [ + "def test(x, y, msg):\n", + " print(f'-- {msg}:')\n", + " print('x major_to_minor =', x.format.layout.major_to_minor)\n", + " print('y major_to_minor =', y.format.layout.major_to_minor)\n", + " try:\n", + " apply_exe(x, y)\n", + " print('-> `apply` called successfully')\n", + " except ValueError as e:\n", + " assert 'does not match' in str(e)\n", + " print('-> error: mismatched input layouts')\n", + " print()\n", + "\n", + "dev = jax.devices()[0]\n", + "\n", + "x1 = y1 = jnp.ones(shape)\n", + "test(x1, y1, 'uncommitted with mismatched layout')\n", + "\n", + "x2, y2 = init_exe(x1, y1)\n", + "test(x2, y2, 'uncommitted with matching layout')\n", + "\n", + "x3 = jnp.ones(shape)\n", + "y3 = jax.device_put(np.ones(shape), Format(Layout(major_to_minor=(1, 0)),\n", + " SingleDeviceSharding(dev)))\n", + "test(x3, y3, 'committed with matching layout')\n", + "\n", + "x4 = jnp.ones(shape)\n", + "y4 = jax.device_put(np.ones(shape), Format(Layout(major_to_minor=(0, 1)),\n", + " SingleDeviceSharding(dev)))\n", + "test(x4, y4, 'committed with mismatched layout')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Constraining intermediate layouts\n", + "\n", + "We can also enforce a specific layout on an intermediate value within\n", + "a JIT-compiled function using `with_layout_constraint`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from jax.experimental.layout import with_layout_constraint\n", + "\n", + "@jax.jit\n", + "def f(x):\n", + " y = x.T\n", + " # Enforce a specific layout on `y`\n", + " y = with_layout_constraint(y, Layout(major_to_minor=(0, 1)))\n", + " return y * 2" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This is analogous to\n", + "[`jax.lax.with_sharding_constraint`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.with_sharding_constraint.html),\n", + "for constraining layouts rather than shardings." + ] + } + ], + "metadata": { + "jupytext": { + "formats": "ipynb,md:myst" + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/docs/notebooks/layout.md b/docs/notebooks/layout.md new file mode 100644 index 000000000000..74dacb04eb6c --- /dev/null +++ b/docs/notebooks/layout.md @@ -0,0 +1,191 @@ +--- +jupytext: + formats: ipynb,md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.16.4 +kernelspec: + display_name: Python 3 + language: python + name: python3 +--- + +# Device-local array layout control + +The `jax.experimental.layout` package provides ways to control +how JAX arrays are laid out in device-local memory. + +## Terminology + +Array layout is tightly coupled with array +[sharding](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html>). +Together, a layout and a sharding fully describes how an array's +values are laid out across (distributed) memories. Along these lines, +we use the following terminology: + +* **Layout**: how an array's values are laid out within each memory in + which they reside (e.g., in the memory of a single device + memory). A typical layout specification is a minor-to-major order + listing of array dimensions. +* **Sharding**: how an array's values are distributed *across* + different memory spaces, such as multiple device memories + (e.g. described by sharding some dimensions and replicating + others). +* **Format**: the pairing of **layout** and **sharding**, + providing a complete picture of an array's memory placement. + +## Types + +There are two Python types that come up when controlling array +layouts: `Layout` and `Format`. + +* The `Layout` class is used to define the in-memory + layout of an array. It has the following key attributes: + + * `major_to_minor`: A tuple of integers specifying the dimension + ordering in memory. For example, for a 2-dimensional array, `(0, 1)` + indicates row-major layout and `(1, 0)` indicates column-major. + + * `_tiling`: An intentionally hidden, highly experimental, optional + attribute to specify a tiled layout. + + * `AUTO`: A special, static sentinel object that can be used with + `jax.jit` to request that the compiler automatically determine + a good layout for a compiled function's input or output arrays. + +* The `Format` class carries both a `Layout` and a `Sharding`, with + either one taking on a default value when it is not specified. + When the layout is explicitly specified, the sharding must be + as well. + +JAX API functions, such as `jax.jit` and `jax.device_put`, accept +`Sharding`s for sharding control or `Format`s for additional layout +control. They typically do not accept `Layout` instances directly. + +## Specifying and reading layouts + +By passing `Format` objects to `jax.jit` in place of shardings (in the +`in_shardings` and `out_shardings` arguments), you can guide the +compiler's layout decisions. Similarly you can pass `Format`s instead +of `Sharding`s to `jax.device_put` to control the layout of the +resulting array. + +Let's see an example that uses both explicit and automatic layouts (as +in `Layout.AUTO`). Imagine we have two compiled functions, `init_fn` +and `apply_fn`. Say we expect `init_fn` to be called roughly once, but +`apply_fn` to be called on the output of `init_fn` many times, so that +we care much more about the performance of `apply_fn`. We may want to +have the compiler choose a good layout for `apply_fn` and constrain +`init_fn` to produce arrays of such layout. We can do this as follows: + +```{code-cell} +import jax, jax.numpy as jnp +from jax.experimental.layout import Layout, Format +from jax.sharding import SingleDeviceSharding +import numpy as np + +def init_fn(x, y): + return x * 2, y * 3 + +def apply_fn(x, y): + return x[0, :], y[:, 0] +``` + +Since `apply_fn` reads a contiguous column of its second argument `y`, +it makes sense to lay it out in column-major order (where columns are +stored contiguously). Using `Layout.AUTO`, we can ask the compiler to +infer good input layouts and see that it indeed chooses to request the +second argument in column-major layout. + +```{code-cell} +shape = (4 * 128, 8 * 128) +duck = jax.ShapeDtypeStruct(shape, jnp.float32) + +# Compile the `apply` function with layouts inferred automatically +apply_exe = jax.jit( + apply_fn, + in_shardings=Format(Layout.AUTO), + out_shardings=Format(Layout.AUTO), +).trace(duck, duck).lower().compile() + +# Read back the inferred input layout +arg_formats, kwarg_formats = apply_exe.input_formats +assert len(kwarg_formats) == 0 +assert arg_formats[0].layout.major_to_minor == (0, 1) +assert arg_formats[1].layout.major_to_minor == (1, 0) +``` + +We can then compile `init_fn` to explicitly match this layout in its +outputs. + +```{code-cell} +init_exe = jax.jit(init_fn, out_shardings=arg_formats).trace( + duck, duck).lower().compile() + +assert init_exe.output_formats == arg_formats +``` + +Finally we can see how the compiled `apply_fn` behaves when called +with differently laid out input arrays. The behavior varies with +whether inputs are +[committed](https://docs.jax.dev/en/latest/faq.html#controlling-data-and-computation-placement-on-devices). As +the following test demonstrates, if the argument arrays are committed, +then the pre-compiled `apply_fn` requires they match the layout +determined by the compiler above. Meanwhile it accepts uncommitted +arrays of any layout (including, of course, the inferred layout). In +this case, the arrays may be relaid out prior to invoking the compiled +computation. + +```{code-cell} +def test(x, y, msg): + print(f'-- {msg}:') + print('x major_to_minor =', x.format.layout.major_to_minor) + print('y major_to_minor =', y.format.layout.major_to_minor) + try: + apply_exe(x, y) + print('-> `apply` called successfully') + except ValueError as e: + assert 'does not match' in str(e) + print('-> error: mismatched input layouts') + print() + +dev = jax.devices()[0] + +x1 = y1 = jnp.ones(shape) +test(x1, y1, 'uncommitted with mismatched layout') + +x2, y2 = init_exe(x1, y1) +test(x2, y2, 'uncommitted with matching layout') + +x3 = jnp.ones(shape) +y3 = jax.device_put(np.ones(shape), Format(Layout(major_to_minor=(1, 0)), + SingleDeviceSharding(dev))) +test(x3, y3, 'committed with matching layout') + +x4 = jnp.ones(shape) +y4 = jax.device_put(np.ones(shape), Format(Layout(major_to_minor=(0, 1)), + SingleDeviceSharding(dev))) +test(x4, y4, 'committed with mismatched layout') +``` + +## Constraining intermediate layouts + +We can also enforce a specific layout on an intermediate value within +a JIT-compiled function using `with_layout_constraint`: + +```{code-cell} +from jax.experimental.layout import with_layout_constraint + +@jax.jit +def f(x): + y = x.T + # Enforce a specific layout on `y` + y = with_layout_constraint(y, Layout(major_to_minor=(0, 1))) + return y * 2 +``` + +This is analogous to +[`jax.lax.with_sharding_constraint`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.with_sharding_constraint.html), +for constraining layouts rather than shardings. diff --git a/docs/notebooks/neural_network_with_tfds_data.ipynb b/docs/notebooks/neural_network_with_tfds_data.ipynb index c31a99746866..a909d9329e24 100644 --- a/docs/notebooks/neural_network_with_tfds_data.ipynb +++ b/docs/notebooks/neural_network_with_tfds_data.ipynb @@ -46,7 +46,7 @@ "\n", "![JAX](https://raw.githubusercontent.com/jax-ml/jax/main/images/jax_logo_250px.png)\n", "\n", - "Let's combine everything we showed in the [quickstart](https://jax.readthedocs.io/en/latest/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P).\n", + "Let's combine everything we showed in the [quickstart](https://docs.jax.dev/en/latest/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P).\n", "\n", "Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model." ] diff --git a/docs/notebooks/neural_network_with_tfds_data.md b/docs/notebooks/neural_network_with_tfds_data.md index 53b7d47358c2..9c153d704763 100644 --- a/docs/notebooks/neural_network_with_tfds_data.md +++ b/docs/notebooks/neural_network_with_tfds_data.md @@ -44,7 +44,7 @@ _Forked from_ `neural_network_and_data_loading.ipynb` ![JAX](https://raw.githubusercontent.com/jax-ml/jax/main/images/jax_logo_250px.png) -Let's combine everything we showed in the [quickstart](https://jax.readthedocs.io/en/latest/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P). +Let's combine everything we showed in the [quickstart](https://docs.jax.dev/en/latest/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P). Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model. diff --git a/docs/notebooks/shard_map.ipynb b/docs/notebooks/shard_map.ipynb index d73b0d4c0f3e..0637658c497c 100644 --- a/docs/notebooks/shard_map.ipynb +++ b/docs/notebooks/shard_map.ipynb @@ -13,11 +13,11 @@ "\n", "`shard_map` is a single-program multiple-data (SPMD) multi-device parallelism API to map a function over shards of data. Mapped function applications, or _instances_, communicate with each other via explicit collective communication operations.\n", "\n", - "`shard_map` is complementary to, and composable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and [the compiler can automatically partition computation over multiple devices](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), generating per-device code and communication collectives behind the scenes. With `shard_map` you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed.\n", + "`shard_map` is complementary to, and composable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and [the compiler can automatically partition computation over multiple devices](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), generating per-device code and communication collectives behind the scenes. With `shard_map` you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed.\n", "\n", - "If you're familiar with `pmap`, think of `shard_map` as an evolution. It's more expressive, performant, and composable with other JAX APIs. It even works eagerly, for easier debugging! (For more, see [a detailed comparison to `pmap`.](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html#why-don-t-pmap-or-xmap-already-solve-this))\n", + "If you're familiar with `pmap`, think of `shard_map` as an evolution. It's more expressive, performant, and composable with other JAX APIs. It even works eagerly, for easier debugging! (For more, see [a detailed comparison to `pmap`.](https://docs.jax.dev/en/latest/jep/14273-shard-map.html#why-don-t-pmap-or-xmap-already-solve-this))\n", "\n", - "By reading this tutorial, you'll learn how to use `shard_map` to get full control over your multi-device code. You'll see in detail how it composes with `jax.jit`'s automatic parallelization and `jax.grad`'s automatic differentiation. We'll also give some basic examples of neural network parallelization strategies.\n", + "By reading this tutorial, you'll learn how to use `shard_map` to get full control over your multi-device code. You'll see in detail how it composes with `jax.jit`'s automatic parallelization and `jax.grad`'s automatic differentiation. We'll also give some basic examples of neural network parallelization strategies, for a more detailed example see {doc}`../the-training-cookbook`.\n", "\n", "We'll assume this tutorial is being run in an environment with eight devices:" ] @@ -56,7 +56,8 @@ "import jax.numpy as jnp\n", "\n", "from jax.sharding import Mesh, PartitionSpec as P\n", - "from jax.experimental.shard_map import shard_map" + "Explicit = jax.sharding.AxisType.Explicit\n", + "Auto = jax.sharding.AxisType.Auto" ] }, { @@ -66,13 +67,14 @@ "metadata": {}, "outputs": [], "source": [ - "mesh = jax.make_mesh((4, 2), ('x', 'y'))\n", + "mesh = jax.make_mesh((4, 2), ('x', 'y'), axis_types=(Explicit,) * 2)\n", + "jax.set_mesh(mesh)\n", "\n", "a = jnp.arange( 8 * 16.).reshape(8, 16)\n", "b = jnp.arange(16 * 4.).reshape(16, 4)\n", "\n", - "@partial(shard_map, mesh=mesh, in_specs=(P('x', 'y'), P('y', None)),\n", - " out_specs=P('x', None))\n", + "@jax.shard_map(in_specs=(P('x', 'y'), P('y', None)),\n", + " out_specs=P('x', None))\n", "def matmul_basic(a_block, b_block):\n", " # a_block: f32[2, 8]\n", " # b_block: f32[8, 4]\n", @@ -149,16 +151,15 @@ "source": [ "from jax.sharding import NamedSharding\n", "\n", - "a = jax.device_put(a, NamedSharding(mesh, P('x', 'y')))\n", - "b = jax.device_put(b, NamedSharding(mesh, P('y', None)))\n", + "a = jax.device_put(a, P('x', 'y'))\n", + "b = jax.device_put(b, P('y', None))\n", "\n", "@jax.jit\n", "def matmul_reference(a, b):\n", - " c = jnp.dot(a, b)\n", - " return jax.lax.with_sharding_constraint(c, NamedSharding(mesh, P('x', None)))\n", + " return jnp.dot(a, b, out_sharding=P('x', None))\n", "\n", "c_ref = matmul_reference(a, b)\n", - "allclose(c_ref, jnp.dot(a, b))" + "allclose(c_ref, jnp.dot(a, b, out_sharding=P('x', None)))" ] }, { @@ -246,10 +247,11 @@ "source": [ "import numpy as np\n", "devices = np.array(jax.devices()[:4])\n", - "mesh = Mesh(devices, ('i',)) # mesh.shape['i'] = 4\n", + "mesh = Mesh(devices, ('i',), axis_types=(Explicit,)) # mesh.shape['i'] = 4\n", + "jax.set_mesh(mesh)\n", "\n", "def check_shmap(f, y):\n", - " ans = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i'))(y)\n", + " ans = jax.shard_map(f, mesh=mesh, in_specs=P('i'), out_specs=P('i'))(y)\n", " expected = jnp.concatenate([f(y_blk) for y_blk in jnp.split(y, mesh.shape['i'])])\n", " print(allclose(ans, expected))\n", "\n", @@ -294,9 +296,10 @@ "metadata": {}, "outputs": [], "source": [ - "mesh = jax.make_mesh((4, 2), ('i', 'j'))\n", + "mesh = jax.make_mesh((4, 2), ('i', 'j'), axis_types=(Auto,) * 2)\n", + "jax.set_mesh(mesh)\n", "\n", - "@partial(shard_map, mesh=mesh, in_specs=P('i', None), out_specs=P('i', 'j'))\n", + "@jax.shard_map(mesh=mesh, in_specs=P('i', None), out_specs=P('i', 'j'))\n", "def f1(x_block):\n", " print(x_block.shape) # prints (3, 12)\n", " return x_block\n", @@ -327,7 +330,7 @@ "metadata": {}, "outputs": [], "source": [ - "@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', 'j'))\n", + "@jax.shard_map(mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', 'j'))\n", "def f2(x_block):\n", " print(x_block.shape)\n", " return x_block\n", @@ -383,13 +386,13 @@ "source": [ "x = jnp.array([[3.]])\n", "\n", - "z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', 'j'))()\n", + "z = jax.shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', 'j'))()\n", "print(z) # prints the same as jnp.tile(x, (4, 2))\n", "\n", - "z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', None))()\n", + "z = jax.shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', None))()\n", "print(z) # prints the same as jnp.tile(x, (4, 1)), or just jnp.tile(x, (4,))\n", "\n", - "z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P(None, None))()\n", + "z = jax.shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P(None, None))()\n", "print(z) # prints the same as jnp.tile(x, (1, 1)), or just x" ] }, @@ -410,7 +413,7 @@ "metadata": {}, "outputs": [], "source": [ - "@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', None))\n", + "@jax.shard_map(mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', None))\n", "def f3(x_block):\n", " return jax.lax.psum(x_block, 'j')\n", "\n", @@ -439,7 +442,7 @@ "metadata": {}, "outputs": [], "source": [ - "@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, 'j'))\n", + "@jax.shard_map(mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, 'j'))\n", "def f4(x_block):\n", " return jax.lax.psum(x_block, 'i')\n", "\n", @@ -448,7 +451,7 @@ "print(y4.shape) # (3,12)\n", "\n", "\n", - "@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, None))\n", + "@jax.shard_map(mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, None))\n", "def f5(x_block):\n", " return jax.lax.psum(x_block, ('i', 'j'))\n", "\n", @@ -481,6 +484,354 @@ "`Array`s, or physically how to interpret the buffers across devices as the\n", "physical layout of a single logical `Array`.\n", "\n", + "#### Tracking how values vary over manual mesh axes, and `check_vma=True`\n", + "\n", + "Under a `shard_map`, values can vary across function instances, or they can be\n", + "the same. For example, when we use `in_specs` to split an argument over a mesh\n", + "axis, each function instance along that mesh axis gets a different value:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "38668c79", + "metadata": {}, + "outputs": [], + "source": [ + "mesh = jax.make_mesh((2,), ('i',), axis_types=(Explicit,))\n", + "jax.set_mesh(mesh)\n", + "\n", + "@jax.shard_map(mesh=mesh, in_specs=P('i'), out_specs=P('i'))\n", + "def f(x):\n", + " print(x)\n", + " return 2 * x\n", + "\n", + "x = jnp.arange(6.)\n", + "f(x)" + ] + }, + { + "cell_type": "markdown", + "id": "00b66850", + "metadata": {}, + "source": [ + "If instead `in_specs` does not split the argument over a mesh axis, the value\n", + "is the same for each function instance along that axis:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9d0dfa6d", + "metadata": {}, + "outputs": [], + "source": [ + "@jax.shard_map(mesh=mesh, in_specs=P(), out_specs=P())\n", + "def f(x):\n", + " print(x)\n", + " return 2 * x\n", + "\n", + "x = jnp.arange(6.)\n", + "f(x)" + ] + }, + { + "cell_type": "markdown", + "id": "594b4574", + "metadata": {}, + "source": [ + "A collective's output may have a different variance than its input. For\n", + "example, applying a `psum` produces the same output on each function instance\n", + "along an axis:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "df486b2f", + "metadata": {}, + "outputs": [], + "source": [ + "@jax.shard_map(mesh=mesh, in_specs=P('i'), out_specs=P())\n", + "def f(x):\n", + " y = jax.lax.psum(x, 'i')\n", + " print(y)\n", + " return y\n", + "\n", + "x = jnp.arange(6.)\n", + "f(x)" + ] + }, + { + "cell_type": "markdown", + "id": "bf6a17ad", + "metadata": {}, + "source": [ + "In general, each intermediate value in a `shard_map` can be either unvarying or\n", + "possibly-varying over each manual mesh axis. That information can be tracked in\n", + "the JAX type system, enabled by the `check_vma=True` argument to `shard_map`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d7f32190", + "metadata": {}, + "outputs": [], + "source": [ + "@jax.shard_map(mesh=mesh, in_specs=P('i'), out_specs=P())\n", + "def f(x):\n", + " print(jax.typeof(x)) # f32[3]{i}\n", + " y = jax.lax.psum(x, 'i')\n", + " print(jax.typeof(y)) # f32[3]\n", + " return y\n", + "\n", + "x = jnp.arange(6.)\n", + "f(x)" + ] + }, + { + "cell_type": "markdown", + "id": "f76cc47f", + "metadata": {}, + "source": [ + "Here, the type `f32[3]{i}` means that the value of `x` is varying over mesh\n", + "axis `'i'`. The type of `y` printing as `f32[3]` indicates it is unvarying over\n", + "all mesh axes; that is, empty sets are not printed. We call this part of the\n", + "type the _varying manual axes_ (VMA), and it can be accessed via\n", + "`jax.typeof(x).vma`.\n", + "\n", + "In general, the VMA type of a value can include any subset of the manual mesh\n", + "axes over which the `shard_map` is acting:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e69a02d3", + "metadata": {}, + "outputs": [], + "source": [ + "mesh = jax.make_mesh((4, 2), ('i', 'j'), axis_types=(Explicit,) * 2)\n", + "jax.set_mesh(mesh)\n", + "\n", + "@jax.shard_map(mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i'))\n", + "def f(x):\n", + " print(jax.typeof(x)) # f32[2,2]{i,j}\n", + " y = jax.lax.psum(x, 'j')\n", + " assert jax.typeof(y).vma == {'i'}\n", + " print(jax.typeof(y)) # f32[2,2]{i}\n", + " return y\n", + "\n", + "x = jnp.arange(8 * 4.).reshape(8, 4)\n", + "f(x)" + ] + }, + { + "cell_type": "markdown", + "id": "a36f1654", + "metadata": {}, + "source": [ + "Tracking varying manual axes can be useful:\n", + "1. Your code can include prints, assertions, or conditionals about whether\n", + " values are varying over expected mesh axes;\n", + "2. It enables efficient reverse-mode autodiff that doesn't require defensive\n", + " `psum`s (see [JEP](https://docs.jax.dev/en/latest/jep/17111-shmap-transpose.html));\n", + "3. The correctness of `out_specs` can be checked, ruling out the potential bug\n", + " example below.\n", + "\n", + "For example, this `out_specs` bug is caught with `check_vma=True`, but uncaught\n", + "without it:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c92c1d4d", + "metadata": {}, + "outputs": [], + "source": [ + "mesh = jax.make_mesh((2,), ('i',), axis_types=(Explicit,))\n", + "jax.set_mesh(mesh)\n", + "\n", + "x = jnp.arange(6.)\n", + "try:\n", + " y = jax.shard_map(lambda x: x, mesh=mesh, in_specs=P('i'), out_specs=P())(x)\n", + "except Exception as e:\n", + " print(e)" + ] + }, + { + "cell_type": "markdown", + "id": "68bc33af", + "metadata": {}, + "source": [ + "Here the `out_specs` incorrectly promise that each function instance along mesh\n", + "axis `'i'` produces the same value and thus we can choose just one of them.\n", + "With `check_vma=True` (the default) it raises an exception, while with\n", + "`check_vma=False` there is no exception and instead we get silent undefined\n", + "behavior.\n", + "\n", + "Sometimes we want to treat a value that is unvarying over a mesh axis as\n", + "varying over that mesh axis. That's what `jax.lax.pcast` does:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "21276d78", + "metadata": {}, + "outputs": [], + "source": [ + "@jax.shard_map(mesh=mesh, in_specs=P(), out_specs=None)\n", + "def f(x):\n", + " print(jax.typeof(x)) # f32[6]\n", + " y = jax.lax.pcast(x, 'i', to='varying')\n", + " print(jax.typeof(y)) # f32[6]{i}\n", + "\n", + "x = jnp.arange(6.)\n", + "f(x)" + ] + }, + { + "cell_type": "markdown", + "id": "8f766c1a", + "metadata": {}, + "source": [ + "Think of `jax.lax.pcast(..., to='varying')` as applying a\n", + "type cast: it's a no-op at runtime,\n", + "though under reverse-mode autodiff it transposes to a `jax.lax.psum` (see\n", + "[JEP](https://docs.jax.dev/en/latest/jep/17111-shmap-transpose.html)). That\n", + "makes sense because they do opposite things to the VMA: where `y: f32[3]{i} =\n", + "jax.lax.pcast(x: f32[3], 'i', to='varying')`,\n", + "we correspondingly have `x_grad: f32[3] = jax.lax.psum(y_grad: f32[3]{i}, 'i')`.\n", + "\n", + "JAX implicitly inserts `jax.lax.pcast(..., to='varying')` calls in many cases,\n", + "especially for binary operations:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e22d52a4", + "metadata": {}, + "outputs": [], + "source": [ + "@jax.shard_map(mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i'))\n", + "def f(x, y):\n", + " return x * y\n", + "\n", + "x = jnp.arange(6.)\n", + "y = jnp.arange(3.)\n", + "print(jax.make_jaxpr(f)(x, y))" + ] + }, + { + "cell_type": "markdown", + "id": "1bd7f6a5", + "metadata": {}, + "source": [ + "In a jaxpr, the multiplication operation requires the VMA types of its\n", + "arguments to match, but for convenience the `jax.numpy` and `jax.lax` APIs\n", + "automatically apply `jax.lax.pcast(..., to='varying')` to make argument VMA\n", + "types agree. In a jaxpr, these `jax.lax.pcast` calls show up as `pvary` since\n", + "`jax.lax.pcast(..., to='varying')` dispatches to `lax.pvary`.\n", + "\n", + "\n", + "\n", + "In some cases, like with `jax.lax.scan`, you might need to apply\n", + "`jax.lax.pcast` yourself to ensure VMA types match as required. For example,\n", + "this code raises an error:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6e33a5fb", + "metadata": {}, + "outputs": [], + "source": [ + "mesh = jax.make_mesh((2,), ('i',), axis_types=(Explicit,))\n", + "jax.set_mesh(mesh)\n", + "\n", + "@jax.shard_map(mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i'))\n", + "def f(x, y):\n", + " def body(carry, _):\n", + " c1, c2 = carry\n", + " return (c2, c1), () # swap the carry\n", + " (x_, y_), _ = jax.lax.scan(body, (x, y), (), length=2)\n", + " return x_, y_\n", + "\n", + "x = jnp.arange(6.)\n", + "y = jnp.arange(3.)\n", + "\n", + "try:\n", + " f(x, y)\n", + "except Exception as e:\n", + " print(e)" + ] + }, + { + "cell_type": "markdown", + "id": "7b6fef36", + "metadata": {}, + "source": [ + "To make the types match, we need to apply `jax.lax.pcast` to some arguments to\n", + "the `scan`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8c8dbd11", + "metadata": {}, + "outputs": [], + "source": [ + "mesh = jax.make_mesh((2,), ('i',), axis_types=(Explicit,))\n", + "jax.set_mesh(mesh)\n", + "\n", + "@jax.shard_map(mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i'))\n", + "def f(x, y):\n", + " def body(carry, _):\n", + " c1, c2 = carry\n", + " return (c2, c1), () # swap the carry\n", + "\n", + " y = jax.lax.pcast(y, 'i', to='varying') # apply pcast to fix the error\n", + " (x_, y_), _ = jax.lax.scan(body, (x, y), (), length=2)\n", + " return x_, y_\n", + "\n", + "x = jnp.arange(6.)\n", + "y = jnp.arange(3.)\n", + "\n", + "f(x, y)" + ] + }, + { + "cell_type": "markdown", + "id": "10271c3c", + "metadata": {}, + "source": [ + "Here's a summary of collective primitives and how they affect varying manual axis types:\n", + "\n", + "| Name | Device variance type | Example | Lowers to HLO | Transpose |\n", + "| --- | --- | --- | --- | --- |\n", + "| `psum_invariant` | `Varying -> Invariant` | `y:f32[3]{j} = psum(x:f32[3]{i,j}, axis='i')` | `AllReduceSum` (communication) | `pvary` |\n", + "| `pvary` | `Invariant -> Varying` | `y:f32[3]{i} = pvary(x:f32[3], 'i')` | no-op (no communication) | `psum_invariant` |\n", + "| `all_to_all` | `Varying -> Varying` | `y:f32[16]{i} = all_to_all(x:f32[16]{i}, 'i', 0, 0)` `AllToAll` (communication) | `all_to_all` |\n", + "| `axis_index` | `() -> Varying` | `idx:i32[]{i} = axis_index('i')` | `ReplicaId` and some arithmetic (no communication) | n/a |\n", + "| `psum_scatter` | `Varying -> Varying` | `y:f32[2]{i} = psum_scatter(x:f32[16]{i}, 'i')` | `ReduceScatterSum` (communication) | `all_gather` |\n", + "| `all_gather` | `Varying -> Varying` | `y:f32[16]{i} = all_gather(x:f32[2]{i}, 'i')` | `AllGather` (communication) | `psum_scatter` |\n", + "| `pscatter` | `Invariant -> Varying` | `y:f32[2]{i} = pscatter(x:f32[16], 'i')` | `lambda x: x[axis_index('i'), None]` (no communication) | `all_gather_invariant` |\n", + "| `all_gather_invariant` | `Varying -> Invariant` | `y:f32[16] = all_gather_invariant(x:f32[2]{i}, 'i')` | `AllGather` (communication) | `pscatter` |\n", + "\n", + "A few notes on the table:\n", + "* The function `jax.lax.psum` is a convenience wrapper around `psum_invariant`.\n", + "* It's surprising that `all_gather` is `Varying -> Varying`, but that's because\n", + " it's really the transpose of `psum_scatter` which is `Varying -> Varying`.\n", + "* Neither `pscatter` nor `all_gather_invariant` have user APIs at the time of\n", + " writing, but they're described here for completeness.\n", + "\n", + "\n", "## API Specification\n", "\n", "```python\n", @@ -488,18 +839,21 @@ "Specs = PyTree[PartitionSpec]\n", "\n", "def shard_map(\n", - " f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs,\n", - " auto: collections.abc.Set[AxisName] = frozenset([]),\n", - " check_rep: bool = True,\n", + " f: Callable, /, *, out_specs: Specs, mesh: Mesh | None = None,\n", + " in_specs: Specs | None = None,\n", + " axis_names: collections.abc.Set[AxisName] = set(),\n", + " check_vma: bool = True,\n", ") -> Callable:\n", " ...\n", "```\n", "where:\n", "* communication collectives like `psum` in the body of `f` can mention the axis names of `mesh`;\n", - "* `mesh` encodes devices arranged in an array and with associated axis names, just like it does for `sharding.NamedSharding`;\n", - "* `in_specs` and `out_specs` are `PartitionSpec`s which can affinely mention axis names from `mesh` to express slicing/unconcatenation and concatenation of inputs and outputs, respectively, with unmentioned names corresponding to replication and untiling (assert-replicated-so-give-me-one-copy), respectively;\n", - "* `auto` is an optional set of axis names corresponding to the subset of names of `mesh` to treat automatically in the body, as in the caller, rather than manually;\n", - "* `check_rep` is an optional boolean indicating whether to check statically for any replication errors in `out_specs`, and also whether to enable a related automatic differentiation optimization (see [JEP](https://jax.readthedocs.io/en/latest/jep/17111-shmap-transpose.html)).\n", + "* `mesh` encodes devices arranged in an array and with associated axis names, just like it does for `sharding.NamedSharding`; If None, mesh will be inferred from the\n", + "context which can be set via the `jax.set_mesh` context manager.\n", + "* `in_specs` are `PartitionSpec`s which can zero or one times mention axis names from `mesh` to express slicing/unconcatenation of inputs, respectively, with unmentioned names corresponding to replication and untiling (assert-replicated-so-give-me-one-copy). If None, all mesh axes must be of type `Explicit`, in which case the in_specs are inferred from the argument types;\n", + "* `out_specs` are `PartitionSpec`s which can zero or one times mention axis names from `mesh` to express concatenation of outputs, with unmentioned names corresponding to replication and untiling (assert-replicated-so-give-me-one-copy), respectively;\n", + "* `axis_names` is an optional set of axis names corresponding to the subset of names of `mesh` to treat manual in the body. If empty, `f` is manual over all axes of the mesh.\n", + "* `check_vma` is an optional boolean indicating whether to check statically for any replication errors in `out_specs`, and also whether to enable a related automatic differentiation optimization (see [JEP](https://docs.jax.dev/en/latest/jep/17111-shmap-transpose.html)).\n", "\n", "The shapes of the arguments passed to `f` have the same ranks as the arguments\n", "passed to `shard_map`-of-`f`, and the shape of an argument to `f` is computed\n", @@ -521,7 +875,7 @@ "```python\n", "mesh = Mesh(jax.devices(), ('i',))\n", "x = jnp.arange(16.)\n", - "f_shmapped = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i'))\n", + "f_shmapped = jax.shard_map(f, mesh=mesh, in_specs=P('i'), out_specs=P('i'))\n", "y = f_shmapped(x)\n", "```\n", "\n", @@ -593,8 +947,7 @@ "import jax.numpy as jnp\n", "from jax import lax\n", "\n", - "from jax.sharding import Mesh, NamedSharding, PartitionSpec as P\n", - "from jax.experimental.shard_map import shard_map" + "from jax.sharding import Mesh, PartitionSpec as P" ] }, { @@ -605,8 +958,9 @@ "outputs": [], "source": [ "mesh1d = Mesh(jax.devices()[:4], ('i',))\n", + "jax.set_mesh(mesh1d)\n", "\n", - "@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P(None))\n", + "@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P(None))\n", "def f1(x_block):\n", " print('BEFORE:\\n', x_block)\n", " y_block = jax.lax.psum(x_block, 'i')\n", @@ -661,8 +1015,9 @@ "outputs": [], "source": [ "mesh2d = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('i', 'j'))\n", + "jax.set_mesh(mesh2d)\n", "\n", - "@partial(shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, 'j'))\n", + "@jax.shard_map(mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, 'j'))\n", "def f2(x_block):\n", " print('BEFORE:\\n', x_block)\n", " y_block = jax.lax.psum(x_block, 'i')\n", @@ -693,7 +1048,7 @@ "metadata": {}, "outputs": [], "source": [ - "@partial(shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, None))\n", + "@jax.shard_map(mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, None))\n", "def f3(x_block):\n", " print('BEFORE:\\n', x_block)\n", " y_block = jax.lax.psum(x_block, ('i', 'j'))\n", @@ -730,7 +1085,9 @@ "metadata": {}, "outputs": [], "source": [ - "@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", + "jax.set_mesh(mesh1d)\n", + "\n", + "@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", "def f4(x_block):\n", " print('BEFORE:\\n', x_block)\n", " y_block = jax.lax.all_gather(x_block, 'i', tiled=True)\n", @@ -769,7 +1126,7 @@ "metadata": {}, "outputs": [], "source": [ - "@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", + "@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", "def f5(x_block):\n", " print('BEFORE:\\n', x_block)\n", " y_block = jax.lax.all_gather(x_block, 'i', tiled=False)\n", @@ -812,7 +1169,7 @@ "metadata": {}, "outputs": [], "source": [ - "@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", + "@jax.shard_map(in_specs=P('i'), out_specs=P('i'))\n", "def f6(x_block):\n", " print('BEFORE:\\n', x_block)\n", " y_block = jax.lax.psum_scatter(x_block, 'i', tiled=True)\n", @@ -888,9 +1245,9 @@ "metadata": {}, "outputs": [], "source": [ - "@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", + "@jax.shard_map(in_specs=P('i'), out_specs=P('i'))\n", "def f7(x_block):\n", - " sz = jax.lax.psum(1, 'i')\n", + " sz = jax.lax.axis_size('i')\n", " print('BEFORE:\\n', x_block)\n", " y_block = jax.lax.ppermute(x_block, 'i', [(i, (i + 1) % sz) for i in range(sz)])\n", " print('AFTER:\\n', y_block)\n", @@ -947,7 +1304,7 @@ "outputs": [], "source": [ "def psum_scatter(x, axis_name, *, tiled=False):\n", - " size = jax.lax.psum(1, axis_name)\n", + " size = jax.lax.axis_size(axis_name)\n", " idx = jax.lax.axis_index(axis_name) # function instance index along axis_name\n", " if tiled:\n", " x = x.reshape(size, -1, *x.shape[1:]) # split leading axis\n", @@ -966,7 +1323,7 @@ "metadata": {}, "outputs": [], "source": [ - "@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", + "@jax.shard_map(in_specs=P('i'), out_specs=P('i'))\n", "def f8(x_block):\n", " print('BEFORE:\\n', x_block)\n", " y_block = psum_scatter(x_block, 'i', tiled=True)\n", @@ -1014,7 +1371,7 @@ "metadata": {}, "outputs": [], "source": [ - "@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", + "@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", "def f9(x_block):\n", " print('BEFORE:\\n', x_block)\n", " y_block = jax.lax.all_to_all(x_block, 'i', split_axis=0, concat_axis=0,\n", @@ -1086,8 +1443,7 @@ "import jax\n", "import jax.numpy as jnp\n", "\n", - "from jax.sharding import Mesh, NamedSharding, PartitionSpec as P\n", - "from jax.experimental.shard_map import shard_map" + "from jax.sharding import Mesh, PartitionSpec as P" ] }, { @@ -1098,6 +1454,7 @@ "outputs": [], "source": [ "mesh = Mesh(jax.devices()[:4], ('i',))\n", + "jax.set_mesh(mesh)\n", "\n", "def device_put(x, pspec):\n", " return jax.device_put(x, NamedSharding(mesh, pspec))" @@ -1163,8 +1520,8 @@ "outputs": [], "source": [ "@jax.jit\n", - "@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", - " out_specs=rhs_spec)\n", + "@jax.shard_map(mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", + " out_specs=rhs_spec)\n", "def matmul_allgather(lhs_block, rhs_block):\n", " rhs = jax.lax.all_gather(rhs_block, 'i', tiled=True)\n", " return lhs_block @ rhs" @@ -1207,10 +1564,10 @@ "outputs": [], "source": [ "@jax.jit\n", - "@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", - " out_specs=rhs_spec)\n", + "@jax.shard_map(mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", + " out_specs=rhs_spec)\n", "def matmul_allgather_overlapped(lhs_block, rhs_block):\n", - " size = jax.lax.psum(1, 'i')\n", + " size = jax.lax.axis_size('i')\n", " idx = jax.lax.axis_index('i')\n", " shift = partial(jax.lax.ppermute, axis_name='i',\n", " perm=[(i, (i + 1) % size) for i in range(size)])\n", @@ -1256,10 +1613,10 @@ "outputs": [], "source": [ "@jax.jit\n", - "@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", - " out_specs=rhs_spec)\n", + "@jax.shard_map(mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", + " out_specs=rhs_spec)\n", "def matmul_allgather_overlapped_bidi(lhs_block, rhs_block):\n", - " size = jax.lax.psum(1, 'i')\n", + " size = jax.lax.axis_size('i')\n", " idx = jax.lax.axis_index('i')\n", " shift_up = partial(jax.lax.ppermute, axis_name='i',\n", " perm=[(i, (i + 1) % size) for i in range(size)])\n", @@ -1337,8 +1694,8 @@ "metadata": {}, "outputs": [], "source": [ - "@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", - " out_specs=rhs_spec)\n", + "@jax.shard_map(mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", + " out_specs=rhs_spec)\n", "def matmul_psumscatter(lhs_block, rhs_block):\n", " out_summand = lhs_block @ rhs_block\n", " return jax.lax.psum_scatter(out_summand, 'i', tiled=True)\n", @@ -1365,10 +1722,10 @@ "metadata": {}, "outputs": [], "source": [ - "@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", - " out_specs=rhs_spec)\n", + "@jax.shard_map(mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", + " out_specs=rhs_spec)\n", "def matmul_psumscatter_overlapped(lhs_block, rhs_block):\n", - " size = jax.lax.psum(1, 'i')\n", + " size = jax.lax.axis_size('i')\n", " idx = jax.lax.axis_index('i')\n", " shift = partial(jax.lax.ppermute, axis_name='i',\n", " perm=[(i, (i - 1) % size) for i in range(size)])\n", @@ -1408,10 +1765,10 @@ "metadata": {}, "outputs": [], "source": [ - "@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", - " out_specs=rhs_spec)\n", + "@jax.shard_map(mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", + " out_specs=rhs_spec)\n", "def matmul_psumscatter_overlapped_bidi(lhs_block, rhs_block):\n", - " size = jax.lax.psum(1, 'i')\n", + " size = jax.lax.axis_size('i')\n", " idx = jax.lax.axis_index('i')\n", " shift_up = partial(jax.lax.ppermute, axis_name='i',\n", " perm=[(i, (i + 1) % size) for i in range(size)])\n", @@ -1520,7 +1877,7 @@ "source": [ "Compare these examples with the purely [automatic partitioning examples in the\n", "\"Distributed arrays and automatic partitioning\"\n", - "doc](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html).\n", + "doc](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html).\n", "While in those automatic partitioning examples we don't need to edit the model\n", "functions to use different parallelization strategies, with `shard_map` we\n", "often do.\n", @@ -1542,12 +1899,10 @@ "metadata": {}, "outputs": [], "source": [ - "from functools import partial\n", - "\n", - "from jax.sharding import NamedSharding, Mesh, PartitionSpec as P\n", - "from jax.experimental.shard_map import shard_map\n", + "from jax.sharding import Mesh, PartitionSpec as P\n", "\n", - "mesh = jax.make_mesh((8,), ('batch',))\n", + "mesh = jax.make_mesh((8,), ('batch',), axis_types=(Auto,))\n", + "jax.set_mesh(mesh)\n", "\n", "# replicate initial params on all devices, shard data batch over devices\n", "batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))\n", @@ -1555,7 +1910,7 @@ "\n", "# adapt the loss function to sum the losses across devices\n", "def loss_dp(params, batch):\n", - " @partial(shard_map, mesh=mesh, in_specs=P('batch', None), out_specs=P())\n", + " @jax.shard_map(mesh=mesh, in_specs=P('batch', None), out_specs=P())\n", " def loss_spmd(local_batch):\n", " inputs, targets = local_batch\n", " predictions = predict(params, inputs) # use reference 'predict`\n", @@ -1626,7 +1981,7 @@ "parameters from the forward pass for use on the backward pass. Instead, we want\n", "to gather them again on the backward pass. We can express that by using\n", "`jax.remat` with a [custom\n", - "policy](https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#custom-policies-for-what-s-saveable)\n", + "policy](https://docs.jax.dev/en/latest/notebooks/autodiff_remat.html#custom-policies-for-what-s-saveable)\n", "(or a `custom_vjp`), though XLA typically does that rematerialization\n", "automatically.\n", "\n", @@ -1645,6 +2000,7 @@ "source": [ "# shard data batch *and params* over devices\n", "mesh = Mesh(devices, ('batch',))\n", + "jax.set_mesh(mesh)\n", "batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))\n", "params = jax.device_put(params, NamedSharding(mesh, P('batch')))\n", "\n", @@ -1660,7 +2016,7 @@ " return outputs\n", "\n", "def loss_fsdp(params, batch):\n", - " @partial(shard_map, mesh=mesh, in_specs=P('batch'), out_specs=P())\n", + " @jax.shard_map(mesh=mesh, in_specs=P('batch'), out_specs=P())\n", " def loss_spmd(local_params, local_batch):\n", " inputs, targets = local_batch\n", " predictions = predict_fsdp(local_params, inputs)\n", @@ -1718,7 +2074,8 @@ "metadata": {}, "outputs": [], "source": [ - "mesh = jax.make_mesh((8,), ('feats',))\n", + "mesh = jax.make_mesh((8,), ('feats',), axis_types=(Auto,))\n", + "jax.set_mesh(mesh)\n", "\n", "batch = jax.device_put(batch, NamedSharding(mesh, P(None, 'feats')))\n", "params = jax.device_put(params, NamedSharding(mesh, P('feats')))\n", @@ -1729,9 +2086,8 @@ " inputs = jax.nn.relu(outputs)\n", " return outputs\n", "\n", - "@partial(shard_map, mesh=mesh,\n", - " in_specs=(P(None, 'feats'), P('feats', None), P('feats')),\n", - " out_specs=P(None, 'feats'))\n", + "@jax.shard_map(mesh=mesh, in_specs=(P(None, 'feats'), P('feats', None), P('feats')),\n", + " out_specs=P(None, 'feats'))\n", "def gemm_tp(inputs, W, b):\n", " block_result = jnp.dot(inputs, W)\n", " return jax.lax.psum_scatter(block_result, 'feats',\n", @@ -1760,7 +2116,8 @@ "metadata": {}, "outputs": [], "source": [ - "mesh = jax.make_mesh((4, 2), ('batch', 'feats'))\n", + "mesh = jax.make_mesh((4, 2), ('batch', 'feats'), axis_types=(Auto,) * 2)\n", + "jax.set_mesh(mesh)\n", "\n", "batch_ = jax.device_put(batch, NamedSharding(mesh, P('batch', 'feats')))\n", "params_ = jax.device_put(params, NamedSharding(mesh, P(('batch', 'feats'))))\n", @@ -1777,9 +2134,8 @@ " inputs = jax.nn.relu(outputs)\n", " return outputs\n", "\n", - "@partial(shard_map, mesh=mesh,\n", - " in_specs=(P(('feats', 'batch')), P('batch', 'feats')),\n", - " out_specs=P())\n", + "@jax.shard_map(mesh=mesh, in_specs=(P(('feats', 'batch')), P('batch', 'feats')),\n", + " out_specs=P())\n", "def loss_fsdp_tp(local_params, local_batch):\n", " inputs, targets = local_batch\n", " predictions = predict_fsdp_tp(local_params, inputs)\n", @@ -1887,8 +2243,8 @@ " outputs = jnp.dot(inputs, W_last) + b_last\n", " return outputs\n", "\n", - "@partial(shard_map, mesh=mesh, in_specs=((P(), P('stages'), P()), P('stages')),\n", - " out_specs=P())\n", + "@jax.shard_map(mesh=mesh, in_specs=((P(), P('stages'), P()), P('stages')),\n", + " out_specs=P())\n", "def loss_pp(params, batch):\n", " inputs, targets = batch\n", " predictions = predict_pp(params, inputs.reshape(K, B, -1)).reshape(K * B, -1)\n", @@ -1950,7 +2306,17 @@ "metadata": {}, "outputs": [], "source": [ - "print(jax.jit(loss)(params, batch))\n", + "print(jax.jit(loss)(params, batch))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9ff83661", + "metadata": {}, + "outputs": [], + "source": [ + "jax.set_mesh(mesh)\n", "print(jax.jit(loss_pp)(params_, batch_))" ] }, diff --git a/docs/notebooks/shard_map.md b/docs/notebooks/shard_map.md index c52cf0e6d22b..593ccd0a901c 100644 --- a/docs/notebooks/shard_map.md +++ b/docs/notebooks/shard_map.md @@ -22,11 +22,11 @@ kernelspec: `shard_map` is a single-program multiple-data (SPMD) multi-device parallelism API to map a function over shards of data. Mapped function applications, or _instances_, communicate with each other via explicit collective communication operations. -`shard_map` is complementary to, and composable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and [the compiler can automatically partition computation over multiple devices](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), generating per-device code and communication collectives behind the scenes. With `shard_map` you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed. +`shard_map` is complementary to, and composable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and [the compiler can automatically partition computation over multiple devices](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), generating per-device code and communication collectives behind the scenes. With `shard_map` you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed. -If you're familiar with `pmap`, think of `shard_map` as an evolution. It's more expressive, performant, and composable with other JAX APIs. It even works eagerly, for easier debugging! (For more, see [a detailed comparison to `pmap`.](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html#why-don-t-pmap-or-xmap-already-solve-this)) +If you're familiar with `pmap`, think of `shard_map` as an evolution. It's more expressive, performant, and composable with other JAX APIs. It even works eagerly, for easier debugging! (For more, see [a detailed comparison to `pmap`.](https://docs.jax.dev/en/latest/jep/14273-shard-map.html#why-don-t-pmap-or-xmap-already-solve-this)) -By reading this tutorial, you'll learn how to use `shard_map` to get full control over your multi-device code. You'll see in detail how it composes with `jax.jit`'s automatic parallelization and `jax.grad`'s automatic differentiation. We'll also give some basic examples of neural network parallelization strategies. +By reading this tutorial, you'll learn how to use `shard_map` to get full control over your multi-device code. You'll see in detail how it composes with `jax.jit`'s automatic parallelization and `jax.grad`'s automatic differentiation. We'll also give some basic examples of neural network parallelization strategies, for a more detailed example see {doc}`../the-training-cookbook`. We'll assume this tutorial is being run in an environment with eight devices: @@ -46,17 +46,19 @@ import jax import jax.numpy as jnp from jax.sharding import Mesh, PartitionSpec as P -from jax.experimental.shard_map import shard_map +Explicit = jax.sharding.AxisType.Explicit +Auto = jax.sharding.AxisType.Auto ``` ```{code-cell} -mesh = jax.make_mesh((4, 2), ('x', 'y')) +mesh = jax.make_mesh((4, 2), ('x', 'y'), axis_types=(Explicit,) * 2) +jax.set_mesh(mesh) a = jnp.arange( 8 * 16.).reshape(8, 16) b = jnp.arange(16 * 4.).reshape(16, 4) -@partial(shard_map, mesh=mesh, in_specs=(P('x', 'y'), P('y', None)), - out_specs=P('x', None)) +@jax.shard_map(in_specs=(P('x', 'y'), P('y', None)), + out_specs=P('x', None)) def matmul_basic(a_block, b_block): # a_block: f32[2, 8] # b_block: f32[8, 4] @@ -97,16 +99,15 @@ The above code is performing the same computation as this `jax.jit` automatic pa ```{code-cell} from jax.sharding import NamedSharding -a = jax.device_put(a, NamedSharding(mesh, P('x', 'y'))) -b = jax.device_put(b, NamedSharding(mesh, P('y', None))) +a = jax.device_put(a, P('x', 'y')) +b = jax.device_put(b, P('y', None)) @jax.jit def matmul_reference(a, b): - c = jnp.dot(a, b) - return jax.lax.with_sharding_constraint(c, NamedSharding(mesh, P('x', None))) + return jnp.dot(a, b, out_sharding=P('x', None)) c_ref = matmul_reference(a, b) -allclose(c_ref, jnp.dot(a, b)) +allclose(c_ref, jnp.dot(a, b, out_sharding=P('x', None))) ``` We can think of `shard_map` as performing a `device_put` or @@ -158,10 +159,11 @@ when collectives aren't involved): ```{code-cell} import numpy as np devices = np.array(jax.devices()[:4]) -mesh = Mesh(devices, ('i',)) # mesh.shape['i'] = 4 +mesh = Mesh(devices, ('i',), axis_types=(Explicit,)) # mesh.shape['i'] = 4 +jax.set_mesh(mesh) def check_shmap(f, y): - ans = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i'))(y) + ans = jax.shard_map(f, mesh=mesh, in_specs=P('i'), out_specs=P('i'))(y) expected = jnp.concatenate([f(y_blk) for y_blk in jnp.split(y, mesh.shape['i'])]) print(allclose(ans, expected)) @@ -194,9 +196,10 @@ input array axis size.) If an input's pspec does not mention a mesh axis name, then there's no splitting over that mesh axis. For example: ```{code-cell} -mesh = jax.make_mesh((4, 2), ('i', 'j')) +mesh = jax.make_mesh((4, 2), ('i', 'j'), axis_types=(Auto,) * 2) +jax.set_mesh(mesh) -@partial(shard_map, mesh=mesh, in_specs=P('i', None), out_specs=P('i', 'j')) +@jax.shard_map(mesh=mesh, in_specs=P('i', None), out_specs=P('i', 'j')) def f1(x_block): print(x_block.shape) # prints (3, 12) return x_block @@ -215,7 +218,7 @@ less efficient program where all mesh axes are mentioned but the caller performs a `jnp.tile`, for example: ```{code-cell} -@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', 'j')) +@jax.shard_map(mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', 'j')) def f2(x_block): print(x_block.shape) return x_block @@ -259,13 +262,13 @@ using the same mesh as above: ```{code-cell} x = jnp.array([[3.]]) -z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', 'j'))() +z = jax.shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', 'j'))() print(z) # prints the same as jnp.tile(x, (4, 2)) -z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', None))() +z = jax.shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', None))() print(z) # prints the same as jnp.tile(x, (4, 1)), or just jnp.tile(x, (4,)) -z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P(None, None))() +z = jax.shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P(None, None))() print(z) # prints the same as jnp.tile(x, (1, 1)), or just x ``` @@ -274,7 +277,7 @@ augment with a corresponding input pspec of P(None, None). As another example, following more closely to the other examples above: ```{code-cell} -@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', None)) +@jax.shard_map(mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', None)) def f3(x_block): return jax.lax.psum(x_block, 'j') @@ -291,7 +294,7 @@ two more examples where we vary which mesh axes are mentioned in the output pspec: ```{code-cell} -@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, 'j')) +@jax.shard_map(mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, 'j')) def f4(x_block): return jax.lax.psum(x_block, 'i') @@ -300,7 +303,7 @@ y4 = f4(x) print(y4.shape) # (3,12) -@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, None)) +@jax.shard_map(mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, None)) def f5(x_block): return jax.lax.psum(x_block, ('i', 'j')) @@ -328,6 +331,234 @@ Instead, `out_specs` just encodes how to assemble the block outputs into `Array`s, or physically how to interpret the buffers across devices as the physical layout of a single logical `Array`. +#### Tracking how values vary over manual mesh axes, and `check_vma=True` + +Under a `shard_map`, values can vary across function instances, or they can be +the same. For example, when we use `in_specs` to split an argument over a mesh +axis, each function instance along that mesh axis gets a different value: + +```{code-cell} +mesh = jax.make_mesh((2,), ('i',), axis_types=(Explicit,)) +jax.set_mesh(mesh) + +@jax.shard_map(mesh=mesh, in_specs=P('i'), out_specs=P('i')) +def f(x): + print(x) + return 2 * x + +x = jnp.arange(6.) +f(x) +``` + +If instead `in_specs` does not split the argument over a mesh axis, the value +is the same for each function instance along that axis: + +```{code-cell} +@jax.shard_map(mesh=mesh, in_specs=P(), out_specs=P()) +def f(x): + print(x) + return 2 * x + +x = jnp.arange(6.) +f(x) +``` + +A collective's output may have a different variance than its input. For +example, applying a `psum` produces the same output on each function instance +along an axis: + +```{code-cell} +@jax.shard_map(mesh=mesh, in_specs=P('i'), out_specs=P()) +def f(x): + y = jax.lax.psum(x, 'i') + print(y) + return y + +x = jnp.arange(6.) +f(x) +``` + +In general, each intermediate value in a `shard_map` can be either unvarying or +possibly-varying over each manual mesh axis. That information can be tracked in +the JAX type system, enabled by the `check_vma=True` argument to `shard_map`: + +```{code-cell} +@jax.shard_map(mesh=mesh, in_specs=P('i'), out_specs=P()) +def f(x): + print(jax.typeof(x)) # f32[3]{i} + y = jax.lax.psum(x, 'i') + print(jax.typeof(y)) # f32[3] + return y + +x = jnp.arange(6.) +f(x) +``` + +Here, the type `f32[3]{i}` means that the value of `x` is varying over mesh +axis `'i'`. The type of `y` printing as `f32[3]` indicates it is unvarying over +all mesh axes; that is, empty sets are not printed. We call this part of the +type the _varying manual axes_ (VMA), and it can be accessed via +`jax.typeof(x).vma`. + +In general, the VMA type of a value can include any subset of the manual mesh +axes over which the `shard_map` is acting: + +```{code-cell} +mesh = jax.make_mesh((4, 2), ('i', 'j'), axis_types=(Explicit,) * 2) +jax.set_mesh(mesh) + +@jax.shard_map(mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i')) +def f(x): + print(jax.typeof(x)) # f32[2,2]{i,j} + y = jax.lax.psum(x, 'j') + assert jax.typeof(y).vma == {'i'} + print(jax.typeof(y)) # f32[2,2]{i} + return y + +x = jnp.arange(8 * 4.).reshape(8, 4) +f(x) +``` + +Tracking varying manual axes can be useful: +1. Your code can include prints, assertions, or conditionals about whether + values are varying over expected mesh axes; +2. It enables efficient reverse-mode autodiff that doesn't require defensive + `psum`s (see [JEP](https://docs.jax.dev/en/latest/jep/17111-shmap-transpose.html)); +3. The correctness of `out_specs` can be checked, ruling out the potential bug + example below. + +For example, this `out_specs` bug is caught with `check_vma=True`, but uncaught +without it: + +```{code-cell} +mesh = jax.make_mesh((2,), ('i',), axis_types=(Explicit,)) +jax.set_mesh(mesh) + +x = jnp.arange(6.) +try: + y = jax.shard_map(lambda x: x, mesh=mesh, in_specs=P('i'), out_specs=P())(x) +except Exception as e: + print(e) +``` + +Here the `out_specs` incorrectly promise that each function instance along mesh +axis `'i'` produces the same value and thus we can choose just one of them. +With `check_vma=True` (the default) it raises an exception, while with +`check_vma=False` there is no exception and instead we get silent undefined +behavior. + +Sometimes we want to treat a value that is unvarying over a mesh axis as +varying over that mesh axis. That's what `jax.lax.pcast` does: + +```{code-cell} +@jax.shard_map(mesh=mesh, in_specs=P(), out_specs=None) +def f(x): + print(jax.typeof(x)) # f32[6] + y = jax.lax.pcast(x, 'i', to='varying') + print(jax.typeof(y)) # f32[6]{i} + +x = jnp.arange(6.) +f(x) +``` + +Think of `jax.lax.pcast(..., to='varying')` as applying a +type cast: it's a no-op at runtime, +though under reverse-mode autodiff it transposes to a `jax.lax.psum` (see +[JEP](https://docs.jax.dev/en/latest/jep/17111-shmap-transpose.html)). That +makes sense because they do opposite things to the VMA: where `y: f32[3]{i} = +jax.lax.pcast(x: f32[3], 'i', to='varying')`, +we correspondingly have `x_grad: f32[3] = jax.lax.psum(y_grad: f32[3]{i}, 'i')`. + +JAX implicitly inserts `jax.lax.pcast(..., to='varying')` calls in many cases, +especially for binary operations: + +```{code-cell} +@jax.shard_map(mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i')) +def f(x, y): + return x * y + +x = jnp.arange(6.) +y = jnp.arange(3.) +print(jax.make_jaxpr(f)(x, y)) +``` + +In a jaxpr, the multiplication operation requires the VMA types of its +arguments to match, but for convenience the `jax.numpy` and `jax.lax` APIs +automatically apply `jax.lax.pcast(..., to='varying')` to make argument VMA +types agree. In a jaxpr, these `jax.lax.pcast` calls show up as `pvary` since +`jax.lax.pcast(..., to='varying')` dispatches to `lax.pvary`. + + + +In some cases, like with `jax.lax.scan`, you might need to apply +`jax.lax.pcast` yourself to ensure VMA types match as required. For example, +this code raises an error: + +```{code-cell} +mesh = jax.make_mesh((2,), ('i',), axis_types=(Explicit,)) +jax.set_mesh(mesh) + +@jax.shard_map(mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i')) +def f(x, y): + def body(carry, _): + c1, c2 = carry + return (c2, c1), () # swap the carry + (x_, y_), _ = jax.lax.scan(body, (x, y), (), length=2) + return x_, y_ + +x = jnp.arange(6.) +y = jnp.arange(3.) + +try: + f(x, y) +except Exception as e: + print(e) +``` + +To make the types match, we need to apply `jax.lax.pcast` to some arguments to +the `scan`: + +```{code-cell} +mesh = jax.make_mesh((2,), ('i',), axis_types=(Explicit,)) +jax.set_mesh(mesh) + +@jax.shard_map(mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i')) +def f(x, y): + def body(carry, _): + c1, c2 = carry + return (c2, c1), () # swap the carry + + y = jax.lax.pcast(y, 'i', to='varying') # apply pcast to fix the error + (x_, y_), _ = jax.lax.scan(body, (x, y), (), length=2) + return x_, y_ + +x = jnp.arange(6.) +y = jnp.arange(3.) + +f(x, y) +``` + +Here's a summary of collective primitives and how they affect varying manual axis types: + +| Name | Device variance type | Example | Lowers to HLO | Transpose | +| --- | --- | --- | --- | --- | +| `psum_invariant` | `Varying -> Invariant` | `y:f32[3]{j} = psum(x:f32[3]{i,j}, axis='i')` | `AllReduceSum` (communication) | `pvary` | +| `pvary` | `Invariant -> Varying` | `y:f32[3]{i} = pvary(x:f32[3], 'i')` | no-op (no communication) | `psum_invariant` | +| `all_to_all` | `Varying -> Varying` | `y:f32[16]{i} = all_to_all(x:f32[16]{i}, 'i', 0, 0)` `AllToAll` (communication) | `all_to_all` | +| `axis_index` | `() -> Varying` | `idx:i32[]{i} = axis_index('i')` | `ReplicaId` and some arithmetic (no communication) | n/a | +| `psum_scatter` | `Varying -> Varying` | `y:f32[2]{i} = psum_scatter(x:f32[16]{i}, 'i')` | `ReduceScatterSum` (communication) | `all_gather` | +| `all_gather` | `Varying -> Varying` | `y:f32[16]{i} = all_gather(x:f32[2]{i}, 'i')` | `AllGather` (communication) | `psum_scatter` | +| `pscatter` | `Invariant -> Varying` | `y:f32[2]{i} = pscatter(x:f32[16], 'i')` | `lambda x: x[axis_index('i'), None]` (no communication) | `all_gather_invariant` | +| `all_gather_invariant` | `Varying -> Invariant` | `y:f32[16] = all_gather_invariant(x:f32[2]{i}, 'i')` | `AllGather` (communication) | `pscatter` | + +A few notes on the table: +* The function `jax.lax.psum` is a convenience wrapper around `psum_invariant`. +* It's surprising that `all_gather` is `Varying -> Varying`, but that's because + it's really the transpose of `psum_scatter` which is `Varying -> Varying`. +* Neither `pscatter` nor `all_gather_invariant` have user APIs at the time of + writing, but they're described here for completeness. + + ## API Specification ```python @@ -335,18 +566,21 @@ from jax.sharding import Mesh Specs = PyTree[PartitionSpec] def shard_map( - f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs, - auto: collections.abc.Set[AxisName] = frozenset([]), - check_rep: bool = True, + f: Callable, /, *, out_specs: Specs, mesh: Mesh | None = None, + in_specs: Specs | None = None, + axis_names: collections.abc.Set[AxisName] = set(), + check_vma: bool = True, ) -> Callable: ... ``` where: * communication collectives like `psum` in the body of `f` can mention the axis names of `mesh`; -* `mesh` encodes devices arranged in an array and with associated axis names, just like it does for `sharding.NamedSharding`; -* `in_specs` and `out_specs` are `PartitionSpec`s which can affinely mention axis names from `mesh` to express slicing/unconcatenation and concatenation of inputs and outputs, respectively, with unmentioned names corresponding to replication and untiling (assert-replicated-so-give-me-one-copy), respectively; -* `auto` is an optional set of axis names corresponding to the subset of names of `mesh` to treat automatically in the body, as in the caller, rather than manually; -* `check_rep` is an optional boolean indicating whether to check statically for any replication errors in `out_specs`, and also whether to enable a related automatic differentiation optimization (see [JEP](https://jax.readthedocs.io/en/latest/jep/17111-shmap-transpose.html)). +* `mesh` encodes devices arranged in an array and with associated axis names, just like it does for `sharding.NamedSharding`; If None, mesh will be inferred from the +context which can be set via the `jax.set_mesh` context manager. +* `in_specs` are `PartitionSpec`s which can zero or one times mention axis names from `mesh` to express slicing/unconcatenation of inputs, respectively, with unmentioned names corresponding to replication and untiling (assert-replicated-so-give-me-one-copy). If None, all mesh axes must be of type `Explicit`, in which case the in_specs are inferred from the argument types; +* `out_specs` are `PartitionSpec`s which can zero or one times mention axis names from `mesh` to express concatenation of outputs, with unmentioned names corresponding to replication and untiling (assert-replicated-so-give-me-one-copy), respectively; +* `axis_names` is an optional set of axis names corresponding to the subset of names of `mesh` to treat manual in the body. If empty, `f` is manual over all axes of the mesh. +* `check_vma` is an optional boolean indicating whether to check statically for any replication errors in `out_specs`, and also whether to enable a related automatic differentiation optimization (see [JEP](https://docs.jax.dev/en/latest/jep/17111-shmap-transpose.html)). The shapes of the arguments passed to `f` have the same ranks as the arguments passed to `shard_map`-of-`f`, and the shape of an argument to `f` is computed @@ -368,7 +602,7 @@ so that this: ```python mesh = Mesh(jax.devices(), ('i',)) x = jnp.arange(16.) -f_shmapped = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i')) +f_shmapped = jax.shard_map(f, mesh=mesh, in_specs=P('i'), out_specs=P('i')) y = f_shmapped(x) ``` @@ -433,14 +667,14 @@ import jax import jax.numpy as jnp from jax import lax -from jax.sharding import Mesh, NamedSharding, PartitionSpec as P -from jax.experimental.shard_map import shard_map +from jax.sharding import Mesh, PartitionSpec as P ``` ```{code-cell} mesh1d = Mesh(jax.devices()[:4], ('i',)) +jax.set_mesh(mesh1d) -@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P(None)) +@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P(None)) def f1(x_block): print('BEFORE:\n', x_block) y_block = jax.lax.psum(x_block, 'i') @@ -477,8 +711,9 @@ each one separately, or over multiple axes at once: ```{code-cell} mesh2d = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('i', 'j')) +jax.set_mesh(mesh2d) -@partial(shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, 'j')) +@jax.shard_map(mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, 'j')) def f2(x_block): print('BEFORE:\n', x_block) y_block = jax.lax.psum(x_block, 'i') @@ -497,7 +732,7 @@ If we apply the `psum` over both axes, the `y_block` value is equal along both axes: ```{code-cell} -@partial(shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, None)) +@jax.shard_map(mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, None)) def f3(x_block): print('BEFORE:\n', x_block) y_block = jax.lax.psum(x_block, ('i', 'j')) @@ -522,7 +757,9 @@ each function application has a full copy of the data along that axis: Illustration of an all_gather computation. ```{code-cell} -@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) +jax.set_mesh(mesh1d) + +@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) def f4(x_block): print('BEFORE:\n', x_block) y_block = jax.lax.all_gather(x_block, 'i', tiled=True) @@ -549,7 +786,7 @@ When `tiled=False` (the default), results are stacked along a new axis instead of concatenated: ```{code-cell} -@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) +@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) def f5(x_block): print('BEFORE:\n', x_block) y_block = jax.lax.all_gather(x_block, 'i', tiled=False) @@ -580,7 +817,7 @@ The `jax.lax.psum_scatter` collective is a bit less intuitive. It's like Illustration of a psum_scatter computation. ```{code-cell} -@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) +@jax.shard_map(in_specs=P('i'), out_specs=P('i')) def f6(x_block): print('BEFORE:\n', x_block) y_block = jax.lax.psum_scatter(x_block, 'i', tiled=True) @@ -644,9 +881,9 @@ that mesh axis, `ppermute` sends its argument value from each source function instance to each destination: ```{code-cell} -@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) +@jax.shard_map(in_specs=P('i'), out_specs=P('i')) def f7(x_block): - sz = jax.lax.psum(1, 'i') + sz = jax.lax.axis_size('i') print('BEFORE:\n', x_block) y_block = jax.lax.ppermute(x_block, 'i', [(i, (i + 1) % sz) for i in range(sz)]) print('AFTER:\n', y_block) @@ -691,7 +928,7 @@ this iteration. In code, it might look like this: ```{code-cell} def psum_scatter(x, axis_name, *, tiled=False): - size = jax.lax.psum(1, axis_name) + size = jax.lax.axis_size(axis_name) idx = jax.lax.axis_index(axis_name) # function instance index along axis_name if tiled: x = x.reshape(size, -1, *x.shape[1:]) # split leading axis @@ -704,7 +941,7 @@ def psum_scatter(x, axis_name, *, tiled=False): ``` ```{code-cell} -@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) +@jax.shard_map(in_specs=P('i'), out_specs=P('i')) def f8(x_block): print('BEFORE:\n', x_block) y_block = psum_scatter(x_block, 'i', tiled=True) @@ -740,7 +977,7 @@ transpose operating along one positional axis and one cross-device axis: Illustration of an all_to_all computation. ```{code-cell} -@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) +@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) def f9(x_block): print('BEFORE:\n', x_block) y_block = jax.lax.all_to_all(x_block, 'i', split_axis=0, concat_axis=0, @@ -800,12 +1037,12 @@ overlap and thus improve FLOP utilization? import jax import jax.numpy as jnp -from jax.sharding import Mesh, NamedSharding, PartitionSpec as P -from jax.experimental.shard_map import shard_map +from jax.sharding import Mesh, PartitionSpec as P ``` ```{code-cell} mesh = Mesh(jax.devices()[:4], ('i',)) +jax.set_mesh(mesh) def device_put(x, pspec): return jax.device_put(x, NamedSharding(mesh, pspec)) @@ -835,8 +1072,8 @@ side: ```{code-cell} @jax.jit -@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), - out_specs=rhs_spec) +@jax.shard_map(mesh=mesh, in_specs=(lhs_spec, rhs_spec), + out_specs=rhs_spec) def matmul_allgather(lhs_block, rhs_block): rhs = jax.lax.all_gather(rhs_block, 'i', tiled=True) return lhs_block @ rhs @@ -861,10 +1098,10 @@ multiplies: ```{code-cell} @jax.jit -@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), - out_specs=rhs_spec) +@jax.shard_map(mesh=mesh, in_specs=(lhs_spec, rhs_spec), + out_specs=rhs_spec) def matmul_allgather_overlapped(lhs_block, rhs_block): - size = jax.lax.psum(1, 'i') + size = jax.lax.axis_size('i') idx = jax.lax.axis_index('i') shift = partial(jax.lax.ppermute, axis_name='i', perm=[(i, (i + 1) % size) for i in range(size)]) @@ -892,10 +1129,10 @@ each half in each direction: ```{code-cell} @jax.jit -@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), - out_specs=rhs_spec) +@jax.shard_map(mesh=mesh, in_specs=(lhs_spec, rhs_spec), + out_specs=rhs_spec) def matmul_allgather_overlapped_bidi(lhs_block, rhs_block): - size = jax.lax.psum(1, 'i') + size = jax.lax.axis_size('i') idx = jax.lax.axis_index('i') shift_up = partial(jax.lax.ppermute, axis_name='i', perm=[(i, (i + 1) % size) for i in range(size)]) @@ -943,8 +1180,8 @@ rhs = device_put(rhs, rhs_spec) Here we can use a `reduce_scatter` to perform the contraction sum over shards: ```{code-cell} -@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), - out_specs=rhs_spec) +@jax.shard_map(mesh=mesh, in_specs=(lhs_spec, rhs_spec), + out_specs=rhs_spec) def matmul_psumscatter(lhs_block, rhs_block): out_summand = lhs_block @ rhs_block return jax.lax.psum_scatter(out_summand, 'i', tiled=True) @@ -959,10 +1196,10 @@ inline an implementation of `psum_scatter` in terms of `ppermute`, then interleave the communication steps with local matrix multiplies: ```{code-cell} -@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), - out_specs=rhs_spec) +@jax.shard_map(mesh=mesh, in_specs=(lhs_spec, rhs_spec), + out_specs=rhs_spec) def matmul_psumscatter_overlapped(lhs_block, rhs_block): - size = jax.lax.psum(1, 'i') + size = jax.lax.axis_size('i') idx = jax.lax.axis_index('i') shift = partial(jax.lax.ppermute, axis_name='i', perm=[(i, (i - 1) % size) for i in range(size)]) @@ -984,10 +1221,10 @@ As in the previous example, to fully utilize interconnects on TPU, we'd run a bidirectional version: ```{code-cell} -@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), - out_specs=rhs_spec) +@jax.shard_map(mesh=mesh, in_specs=(lhs_spec, rhs_spec), + out_specs=rhs_spec) def matmul_psumscatter_overlapped_bidi(lhs_block, rhs_block): - size = jax.lax.psum(1, 'i') + size = jax.lax.axis_size('i') idx = jax.lax.axis_index('i') shift_up = partial(jax.lax.ppermute, axis_name='i', perm=[(i, (i + 1) % size) for i in range(size)]) @@ -1061,7 +1298,7 @@ params, batch = init(jax.random.key(0), layer_sizes, batch_size) Compare these examples with the purely [automatic partitioning examples in the "Distributed arrays and automatic partitioning" -doc](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html). +doc](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html). While in those automatic partitioning examples we don't need to edit the model functions to use different parallelization strategies, with `shard_map` we often do. @@ -1076,12 +1313,10 @@ the end. (To evaluate the gradient of the loss, the devices must perform all-reduce-sums of parameter gradients in the backward pass.) ```{code-cell} -from functools import partial - -from jax.sharding import NamedSharding, Mesh, PartitionSpec as P -from jax.experimental.shard_map import shard_map +from jax.sharding import Mesh, PartitionSpec as P -mesh = jax.make_mesh((8,), ('batch',)) +mesh = jax.make_mesh((8,), ('batch',), axis_types=(Auto,)) +jax.set_mesh(mesh) # replicate initial params on all devices, shard data batch over devices batch = jax.device_put(batch, NamedSharding(mesh, P('batch'))) @@ -1089,7 +1324,7 @@ params = jax.device_put(params, NamedSharding(mesh, P())) # adapt the loss function to sum the losses across devices def loss_dp(params, batch): - @partial(shard_map, mesh=mesh, in_specs=P('batch', None), out_specs=P()) + @jax.shard_map(mesh=mesh, in_specs=P('batch', None), out_specs=P()) def loss_spmd(local_batch): inputs, targets = local_batch predictions = predict(params, inputs) # use reference 'predict` @@ -1137,7 +1372,7 @@ There's one other ingredient we need: we don't want to store the fully gathered parameters from the forward pass for use on the backward pass. Instead, we want to gather them again on the backward pass. We can express that by using `jax.remat` with a [custom -policy](https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#custom-policies-for-what-s-saveable) +policy](https://docs.jax.dev/en/latest/notebooks/autodiff_remat.html#custom-policies-for-what-s-saveable) (or a `custom_vjp`), though XLA typically does that rematerialization automatically. @@ -1149,6 +1384,7 @@ to [weight update sharding (WUS)](https://arxiv.org/abs/2004.13336) and ```{code-cell} # shard data batch *and params* over devices mesh = Mesh(devices, ('batch',)) +jax.set_mesh(mesh) batch = jax.device_put(batch, NamedSharding(mesh, P('batch'))) params = jax.device_put(params, NamedSharding(mesh, P('batch'))) @@ -1164,7 +1400,7 @@ def predict_fsdp(params_frag, inputs): return outputs def loss_fsdp(params, batch): - @partial(shard_map, mesh=mesh, in_specs=P('batch'), out_specs=P()) + @jax.shard_map(mesh=mesh, in_specs=P('batch'), out_specs=P()) def loss_spmd(local_params, local_batch): inputs, targets = local_batch predictions = predict_fsdp(local_params, inputs) @@ -1198,7 +1434,8 @@ multiplications followed by a `psum_scatter` to sum the local results and efficiently scatter the result's shards. ```{code-cell} -mesh = jax.make_mesh((8,), ('feats',)) +mesh = jax.make_mesh((8,), ('feats',), axis_types=(Auto,)) +jax.set_mesh(mesh) batch = jax.device_put(batch, NamedSharding(mesh, P(None, 'feats'))) params = jax.device_put(params, NamedSharding(mesh, P('feats'))) @@ -1209,9 +1446,8 @@ def predict_tp(params, inputs): inputs = jax.nn.relu(outputs) return outputs -@partial(shard_map, mesh=mesh, - in_specs=(P(None, 'feats'), P('feats', None), P('feats')), - out_specs=P(None, 'feats')) +@jax.shard_map(mesh=mesh, in_specs=(P(None, 'feats'), P('feats', None), P('feats')), + out_specs=P(None, 'feats')) def gemm_tp(inputs, W, b): block_result = jnp.dot(inputs, W) return jax.lax.psum_scatter(block_result, 'feats', @@ -1228,7 +1464,8 @@ def loss_tp(params, batch): We can compose these strategies together, using multiple axes of parallelism. ```{code-cell} -mesh = jax.make_mesh((4, 2), ('batch', 'feats')) +mesh = jax.make_mesh((4, 2), ('batch', 'feats'), axis_types=(Auto,) * 2) +jax.set_mesh(mesh) batch_ = jax.device_put(batch, NamedSharding(mesh, P('batch', 'feats'))) params_ = jax.device_put(params, NamedSharding(mesh, P(('batch', 'feats')))) @@ -1245,9 +1482,8 @@ def predict_fsdp_tp(params_frag, inputs): inputs = jax.nn.relu(outputs) return outputs -@partial(shard_map, mesh=mesh, - in_specs=(P(('feats', 'batch')), P('batch', 'feats')), - out_specs=P()) +@jax.shard_map(mesh=mesh, in_specs=(P(('feats', 'batch')), P('batch', 'feats')), + out_specs=P()) def loss_fsdp_tp(local_params, local_batch): inputs, targets = local_batch predictions = predict_fsdp_tp(local_params, inputs) @@ -1325,8 +1561,8 @@ def predict_pp(params, inputs): outputs = jnp.dot(inputs, W_last) + b_last return outputs -@partial(shard_map, mesh=mesh, in_specs=((P(), P('stages'), P()), P('stages')), - out_specs=P()) +@jax.shard_map(mesh=mesh, in_specs=((P(), P('stages'), P()), P('stages')), + out_specs=P()) def loss_pp(params, batch): inputs, targets = batch predictions = predict_pp(params, inputs.reshape(K, B, -1)).reshape(K * B, -1) @@ -1371,6 +1607,10 @@ batch_ = jax.device_put(batch, NamedSharding(mesh, P('stages'))) ```{code-cell} print(jax.jit(loss)(params, batch)) +``` + +```{code-cell} +jax.set_mesh(mesh) print(jax.jit(loss_pp)(params_, batch_)) ``` diff --git a/docs/notebooks/thinking_in_jax.ipynb b/docs/notebooks/thinking_in_jax.ipynb index 5ddcdd32e2b4..26809769c981 100644 --- a/docs/notebooks/thinking_in_jax.ipynb +++ b/docs/notebooks/thinking_in_jax.ipynb @@ -6,13 +6,44 @@ "id": "LQHmwePqryRU" }, "source": [ - "# How to think in JAX\n", + "# Quickstart: How to think in JAX\n", "\n", - "\n", + "\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb)\n", "\n", - "JAX provides a simple and powerful API for writing accelerated numerical code, but working effectively in JAX sometimes requires extra consideration. This document is meant to help build a ground-up understanding of how JAX operates, so that you can use it more effectively." + "**JAX is a library for array-oriented numerical computation (*à la* [NumPy](https://numpy.org/)), with automatic differentiation and JIT compilation to enable high-performance machine learning research**." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This document provides a quick overview of essential JAX features, so you can get started with JAX:\n", + "\n", + "* JAX provides a unified NumPy-like interface to computations that run on CPU, GPU, or TPU, in local or distributed settings.\n", + "* JAX features built-in Just-In-Time (JIT) compilation via [Open XLA](https://github.com/openxla), an open-source machine learning compiler ecosystem.\n", + "* JAX functions support efficient evaluation of gradients via its automatic differentiation transformations.\n", + "* JAX functions can be automatically vectorized to efficiently map them over arrays representing batches of inputs." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Installation\n", + "\n", + "JAX can be installed for CPU on Linux, Windows, and macOS directly from the [Python Package Index](https://pypi.org/project/jax/):\n", + "\n", + "```\n", + "pip install jax\n", + "```\n", + "or, for NVIDIA GPU:\n", + "\n", + "```\n", + "pip install -U \"jax[cuda13]\"\n", + "```\n", + "For more detailed platform-specific installation information, check out [Installation](https://docs.jax.dev/en/latest/installation.html)." ] }, { @@ -26,68 +57,40 @@ "**Key concepts:**\n", "\n", "- JAX provides a NumPy-inspired interface for convenience.\n", - "- Through duck-typing, JAX arrays can often be used as drop-in replacements of NumPy arrays.\n", - "- Unlike NumPy arrays, JAX arrays are always immutable.\n", - "\n", - "NumPy provides a well-known, powerful API for working with numerical data. For convenience, JAX provides `jax.numpy` which closely mirrors the numpy API and provides easy entry into JAX. Almost anything that can be done with `numpy` can be done with `jax.numpy`:" + "- Through [duck-typing](https://en.wikipedia.org/wiki/Duck_typing), JAX arrays can often be used as drop-in replacements of NumPy arrays.\n", + "- Unlike NumPy arrays, JAX arrays are always immutable." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "NumPy provides a well-known, powerful API for working with numerical data. For convenience, JAX provides [`jax.numpy`](https://docs.jax.dev/en/latest/jax.numpy.html) which closely mirrors the NumPy API and provides easy entry into JAX. Almost anything that can be done with `numpy` can be done with `jax.numpy`, which is typically imported under the `jnp` alias:" ] }, { "cell_type": "code", - "execution_count": 1, - "metadata": { - "id": "kZaOXL7-uvUP", - "outputId": "7fd4dd8e-4194-4983-ac6b-28059f8feb90" - }, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjgAAAGdCAYAAAAfTAk2AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAABz60lEQVR4nO3deXxU5b0/8M/sS5aZ7JOQjU3CDoKkQVrtNZeg1kprrfRiqdTi71Jpi/hy4V7FVqpU29peLT+pu/6qxfbeatX2ohRFqiIgGAWEsGUlmezJJDOZ9ZzfHzPnzAxkz5w52/f9euXVkszyTDw5z/d5nu/zfTQsy7IghBBCCFEQrdgNIIQQQghJNApwCCGEEKI4FOAQQgghRHEowCGEEEKI4lCAQwghhBDFoQCHEEIIIYpDAQ4hhBBCFIcCHEIIIYQojl7sBoiBYRg0NzcjLS0NGo1G7OYQQgghZBRYlkVfXx8KCgqg1Q4/R6PKAKe5uRlFRUViN4MQQggh49DY2IjCwsJhH6PKACctLQ1A+BeUnp4ucmsIIYQQMhoulwtFRUV8Pz4cVQY43LJUeno6BTiEEEKIzIwmvYSSjAkhhBCiOBTgEEIIIURxKMAhhBBCiOJQgEMIIYQQxaEAhxBCCCGKQwEOIYQQQhSHAhxCCCGEKA4FOIQQQghRHApwCCGEEKI4ggY4+/btw3XXXYeCggJoNBq8/vrrIz5n7969uPTSS2EymTBt2jS88MILFz1m+/btKC0thdlsRnl5OQ4ePJj4xhNCCCFEtgQNcNxuN+bPn4/t27eP6vG1tbW49tpr8dWvfhXV1dXYuHEjfvCDH+Dtt9/mH/Pqq69i06ZNeOCBB3DkyBHMnz8fVVVVaGtrE+pjEEIIIURmNCzLskl5I40Gr732GlauXDnkY+655x787W9/w7Fjx/jvrVq1Cj09Pdi1axcAoLy8HJdddhl+97vfAQAYhkFRURF+9KMf4d577x1VW1wuF2w2G3p7e+ksKkIIIUQmxtJ/S+qwzf3796OysjLue1VVVdi4cSMAwO/34/Dhw9i8eTP/c61Wi8rKSuzfv3/I1/X5fPD5fPy/XS5XYhsuMSzL4lBdNz480wGWZbGoNBNfnpYNrXbkw8mIenkDIew65sRJZx/sVgOWz8rDlJxUsZtFJO58zwD+92gL2vt9mJqTimvm5iPVJKmuhaiUpK5Cp9OJvLy8uO/l5eXB5XJhYGAA3d3dCIVCgz7m5MmTQ77utm3b8LOf/UyQNktNt9uPO/5Ujb017XHfX1hsx+OrFqIo0ypSy4iUHTjXiZ/srIbT5eW/98iuk/j+5ZNx79VlMOhoPwKJxzAsnnj3DJ549zSCTHQhYNvfT+DRb83Hv87KG+bZhAhPFXetzZs3o7e3l/9qbGwUu0mC6HL7ccOOj7C3ph1GnRYrFxTgW4sKkWrS49OGHnxrx0eo63CL3UwiMXtr2rD6mQNwurwosJnx3S+V4IpLcsCywLMf1OJHr3yKYIgRu5lEQliWxX++fhS/+ccpBBkWSyZn4nsVJZicnYJuTwC3/b9P8Pqn58VuJlE5Sc3gOBwOtLa2xn2vtbUV6enpsFgs0Ol00Ol0gz7G4XAM+bomkwkmk0mQNksFw7BY/4fDONfuRoHNjOfXLsEMRxoA4I5/vQRrnz+IU639WPfSJ3j99suRQlPIBMCZtn6s/8MRBBkWK2Y78JubFsBi1AEAdh1z4sd//BS7jjvxy3dqsPnqmSK3lkjF8x/W4Y8HG6HVAL/45jx8+7IiAIA/yOD+14/h1U8acdd/f4aSLCsWFmeI3FqiVpKawamoqMCePXvivrd7925UVFQAAIxGIxYtWhT3GIZhsGfPHv4xavXch7U4UNsFq1GHl26NBjcAMMluwR9uLUdumgmn2/rxq3dqRGwpkYoQw2Ljq59iIBBCxZQsPP6dhXxwAwAr5jjw2E3zAQC/f/8cDtV1idVUIiGnWvvwi13hlID7vzaLD24AwKjXYts35+LqOQ4EQiw2vloNbyAkVlOJygka4PT396O6uhrV1dUAwtvAq6ur0dDQACC8dLRmzRr+8f/+7/+Oc+fO4e6778bJkyfxf//v/8Wf/vQn3HHHHfxjNm3ahKeffhovvvgiTpw4gfXr18PtdmPt2rVCfhRJa+r24Jdvh4OW+66dhWm5aRc9JjfdjF9/O9xZvfhRHY6d701qG4n07DzUgGPnXbBZDPivVQtg1F98O/javALctDjcgd3/+jFaqlI5lmXxn68dhT/I4KszcnDL0tKLHqPVavDIt+YhL92E+k4Pnt53LvkNJQQCBziffPIJFi5ciIULFwIIBycLFy7Eli1bAAAtLS18sAMAkydPxt/+9jfs3r0b8+fPx69//Ws888wzqKqq4h9z00034Ve/+hW2bNmCBQsWoLq6Grt27boo8VhN/usfp+ELMvjSlEx8Z0nRkI/78vQcfG1ePhgWePCtL5LYQiI1vQMB/PqdUwCAOyqnIzfdPORj77m6DHarASedffjjIWXmr5HR2XOiDYfqumE2aLHtm/Og0Qy+MzPdbMB/XjsLALB97xm09A4ks5mEAEhiHRwpUVIdnDNt/Vj+m/fBsMBrP1w64np3S+8Arnh0L/whBq/e9iWUT8lKUkuJlDyx5zR+vfsUpuakYNfGr4y4S+rFj+rwwBvHMcluwd67rqRdVSrEMCxW/Nc+nGrtx/orp+KeFWXDPp5lWdy4Yz8+qe/G9y+fjC3XzUpSS4mSjaX/pruUzO14/ywYFqicmTeqZL58mwXfWlwIAPjde2eEbh6RIG8ghBc+qgMA/Piq6aMKVm66rAg5aSac7xnAa7Q7RpXeq2nDqdZ+pJn1+Pcrpo74eI1Ggx9dNR0A8MeDDehy+4VuIiFxKMCRsc5+H974rBkAsP7KkW84nPVXTIVOq8E/T3fgpFPZRQ/Jxf7nSBM63X5Msltwzdz8UT3HbNBh3ZcnAwCe2ncOKpz4Vb3nP6wDAHxnSTFsFsOonvOV6dmYMykdA4EQ/t/+egFbR8jFKMCRsZ2HGuEPMphXaMOlxfZRP68o04qq2eGcpVcONIzwaKIkLMvyHc33l00e01LTd5YUw2LQ4UxbPw7VdQvVRCJBNc4+fHCmA1oNsKaiZNTP02g0WPflKQCAVw81IMRQYEyShwIcmWIYlg9OvldROmSy31C+s6QYAPDakfPw+IMJbx+Rps+benHS2QeTXotvLSoc03PTzAZ8fX4BAGDnQQqM1WTnofB/7+WzHCjMGFs19KrZDtitBjT3erHvVPvITyAkQSjAkakDtV043zOANJMe184b3TJDrMunZqM404o+XxBvfd4iQAuJFP3pk/AuqBVzHKNeZoj1nfJwYPy3oy3o9QQS2jYiTYEQgzeqw0vh375sbEExEF7e/ObC8PP+SIExSSIKcGSKK4N+zdx8mA26ER59Ma1Wg29Hko3fjOTxEGXzBkJ8zhZX22as5hfaUOZIgy/IYNdxCozVYN+pdnS6/chONeIr03PG9RqrIuUr3j3Zhh4PJRuT5KAAR4a8gRD+fjTcuXzj0knjfp3rIssNH57pQEe/b4RHE7n74HQH+rxB5NvM+NI4ywNoNBr+uqGZP3X4y5HwYOr6BZOgH2d5gEvy0lDmSEOQYfH2cWcim0ckyBsISSL1gQIcGXr/VDv6fEEU2MxYUpo57tcpyUrB/EIbGBb432N001E67r9x1WwHtNqx5WzF+lpkSZQCY+XzBkJ4r6YNAHD9goIJvRYFxuqx65gTCx/cjZ+9eVzUdlCAI0O7vwgfNlo1Z2IdFRAuxQ8Ab9EylaL5gwx2fxEOcK6eM/TBtKNRkpWCuZMoMFaDD890wOMPId9mxtxJtgm91rWRkgQfne1EJwXGirb7i1b4ggysxrGnTyQSBTgyE2JYvHsyPKL611kTP55iRaSz+6S+m9bGFWz/uU64vEFkpxqxeAKzfhyufg4XbBNl4paTls/KG/NOzQuVZqdgdkE6QgyLPZF7GFEeXzCEvTVcHzWxwdREUYAjM4fru9Hl9sNmMUxoeYpTlGnFjLw0hBgW79MWTsXadSy8LFA12wHdBGf9AKByZi4A4ONznZJYayeJF2JY/ONEuKNaPjsxHdVVM8ODsvcowFGs/Wc74faHkJtmwrwJzvpNFAU4MsMtM/xLWe64E/4u9NWycGf1Lt10FIlhWOz+IvzfdsUEl6c403JTUZhhgT/I4MMznQl5TSIt3GAq3azHkskTH0wB4fsWAPzzdAcCdDK9InGzupWz8iacQjFRFODICMuy/MWTiOUpzlWR0fjemnYE6aajOCecLnT0+2Ax6BLWUWk0Gr6zosBYmbjB1FUz8xJ2uOq8STZkpxrR7wviUF1XQl6TSAfLxqRQzExcHzVeFODISEOXB3WdHhh0GnzlkvHVoxjMwiI77FYDegcCONLQk7DXJdLwz9MdAICKqVkw6ROX9MfN/L13so3OplIg7rq5ckbi7jVarQZXXBK9boiynG3vR0uvF0a9FhVTx1eKIpEowJGRD86EbziXFmcg1aRP2OvqdVpcEQmY9pykpFGl4crjf2V6dkJft2JKFswGLZwuL75ooUNblaS9z4eTzj4AwOXTEnvdfLUsfK+hmT/l4YLiy0ozxlWANtEowJGRDyMBTqJvOEB0lPYR5VMoiscfxCeRgzETOesHhEvwXz41fC1+ELmxEWXg7jWz8tORnWpK6Gt/eXoOdFoNzra70dTtSehrE3Fx94Fl0xJ7rxkvCnBkIsSw+OhsOPgQIsBZGumojjX30hlDCvLxuU74Qwwm2S2YnJ2S8NfnpqG5a5MoAzdb/OUEz/oBgM1iwLzC8O6aj89RHo5SBEIMPj4Xvg8Icd2MBwU4MvFFsws9ngBSTXrML0z81ru8dDOm5KSAZYEDtdRZKcW+U+GO6iuX5Ey4jslguADnUF0X7YpRCJZloyNxgTqqishRIfspMFaMTxt64PaHkJlixKz8dLGbA4ACHNngRlRfmpKVsO3hF1pKo3HF4UZUywSY9QOAmY502K0GePwhfN7UK8h7kOQ6294PpyucKHpZAmptDYYLjD8+10kJ6grxwelwrt/SqVmibw/nUIAjE9ya+LJpwmWmc8tUXKdI5K3H40dNazhRNFHbwy+k1WrwpcnRzorIHzfAETJRdHFJJgw6Dc73DKCxa0CQ9yDJxS03CjWYGg8KcGQgEGLwSX344qmYKtzFw50wfdLZR4coKsAndd1gWWBKTgpy0hKbKBpr6TRablCSQ5Gk9PLJwg2mLEYdFhTZAQD7z1GCutz5giFUN/UAEG4wNR4U4MjA8WYXvAEGNosB03NTBXufzBQjyhxpAGg0rgQHI4XUygW+4XD5FIfquuALhgR9LyIslmVxqDZ83SwuzRD0vSgPRzmONvXCH2SQnWoUZDPDeFGAIwOfRDqqxSUZgq9tcrM43NZiIl8HIh2V0COqabmpyE41whdkcJTycGStqXsATpcXBp0GC4uEDXC4ew3tpJK/g3wflSnIZobxogBHBriS5ok4BXoki0rCN7XD9RTgyFm/L4hj58PBxhIBlxqA8LEN3HVzpIGuGznj7jVzJtlgMQpbqG1hcQZ0Wg2cLi+aeygPR864Wb/LJLQ8BVCAI3ksy/KzKZcJPGUMRAOcL1pcdEq0jB2p70aIYVGYYcEku0Xw96PAWBkO8fca4Tsqi1HHbyem60a+QgyLTyL//ZYk4boZCwpwJK62w41Otx9GvRZzBah/c6ECuwX5NjNCDIvPGmm5Qa64Zc1k3XCiAU4PbfuVMW4GJxkBDkCBsRLUOPvQ5w0ixajDzPw0sZsThwIcieNmb+YX2hJ6UOJwLuVvOrQ2LlefNvYAABYW25PyfrMLbDDqtOjo99G2X5nqcvtxpq0fQDjfLxkupaVN2eN2+F5akiFYjbbxklZryEW4iycZ+TecxTSqkjWWZfFZJMBZIHCiKMds0GH2pMhyQwMFxnJU3Rj+e5+ak4KMFGNS3vPSSAD+RbMLA37agSdH1Q09AMKHQEsNBTgSxy0TJfPiiSaM9oBhaLlBbmo73HB5gzDptShL4pTxosg1Sjvw5Im71yQrKAaASXYL8tJNCDIsPo/UUSHy8lnkv9uCJM0WjwUFOBLm9gVxui1ciVaI86eGMjM/HWaDFr0DAZxt70/a+5LEqI7M3syZZIMhiVPGlE8hb1xHNb8oefea2B14h2mZSnZc3gDOtrsBAPML7eI2ZhAU4EjYsfO9YFjAkW5Gbro5ae9r0Gkxb5IdAPAZ1TWRnWp+ecqe1Pfl8ilOtfbB7aMdeHISu6yZ7I6Km50+Ut+T1PclE8fVvSrOtCIzScuaY5GUAGf79u0oLS2F2WxGeXk5Dh48OORjr7zySmg0mou+rr32Wv4xt9xyy0U/X7FiRTI+SlJxhxfOS+LsDYfbsXWUpo1lhwtw5ic5wMlLNyMv3QSGDZcZIPLR2DWAbk8ARl1ylzWB6HV69HxPUt+XTBx3rxGjjxoNwQOcV199FZs2bcIDDzyAI0eOYP78+aiqqkJbW9ugj//LX/6ClpYW/uvYsWPQ6XS48cYb4x63YsWKuMf98Y9/FPqjJF01P2VsT/p7cxfs5+dpBkdOvIEQTkSCi4UiXDdzIzN/dLK4vHDLUzPz05K2W5MzKz8dWg3Q6vKhzeVN6nuTiflMpNni0RI8wHnsscewbt06rF27FrNmzcKOHTtgtVrx3HPPDfr4zMxMOBwO/mv37t2wWq0XBTgmkynucRkZ0svgnigu6U6Mtc25k8IBzhfNLgRCTNLfn4zP8WYXAiEWWSlGFGYIX+DvQlxgfIwCY1n5TKRZPwBIMekxNSd8xt5Rum5k5TMRB+GjIWiA4/f7cfjwYVRWVkbfUKtFZWUl9u/fP6rXePbZZ7Fq1SqkpMQf4LV3717k5uZixowZWL9+PTo7hz6wzefzweVyxX1JXWdMPZFkFPi7UGlWCtJMeviCDE63UqKxXMSOqMQ4E4a7VmlHjLx8JuJgCoi9bijAkQtnrxetLh90Wg1mF6SL3ZxBCRrgdHR0IBQKIS8vL+77eXl5cDqdIz7/4MGDOHbsGH7wgx/EfX/FihV46aWXsGfPHjzyyCN4//33cfXVVyMUGryOwrZt22Cz2fivoqKi8X+oJOGWhqZkp8BmMST9/bVaDeZEZnFobVw+xEow5nAzf+c63OjzBkRpAxmbYIjhZ07EGonPm0Qzf3LD3WsuyUuD1agXtzFDkPQuqmeffRZz587FkiVL4r6/atUqfP3rX8fcuXOxcuVKvPXWWzh06BD27t076Ots3rwZvb29/FdjY2MSWj8xnzeKl2DMmUejKtnhOggxZv0AIDvVhAKbGSwbXi4j0ne6rR/eAIM0kx5TslNGfoIA5sbk/NFRH/LA179JYlmBsRI0wMnOzoZOp0Nra2vc91tbW+FwOIZ9rtvtxs6dO3HrrbeO+D5TpkxBdnY2zpw5M+jPTSYT0tPT476kjrt45olYW4DfSUWjKllw+4Ko7QzXpJhdIN5NJ7oDj64bOeCWNecW2qDVJn9ZEwBm5dug1QDtfT60unyitIGMzecS6KNGImiAYzQasWjRIuzZs4f/HsMw2LNnDyoqKoZ97p///Gf4fD7cfPPNI75PU1MTOjs7kZ+fP+E2SwU3a5LMolsX4mrhnGhxwRekMupSd6LFBZYF8tJNyEkzidYO7oZHO/DkgZtp45akxWAx6nBJXnh7Og2opI9lWXzBXTciDqZGIvgS1aZNm/D000/jxRdfxIkTJ7B+/Xq43W6sXbsWALBmzRps3rz5ouc9++yzWLlyJbKysuK+39/fj7vuugsff/wx6urqsGfPHlx//fWYNm0aqqqqhP44SdHW50VHvw8aTbiqsFiKMi2wWw0IhFicclKisdRxHZWYszdANA+HaijJA1ezSOxEUbpu5KOl14tuTwB6rQbT81LFbs6QBM8Muummm9De3o4tW7bA6XRiwYIF2LVrF5943NDQAK02Ps6qqanBBx98gHfeeeei19PpdPj888/x4osvoqenBwUFBVi+fDm2bt0Kk0m8UWsicZHx5OwUUZO3NBoN5k6y4Z+nO/BZU49oeR1kdI43h0e+Uumo6jo96B0IiJIkT0YnxLB83STRr5tCG/58uIlm/mSA66Om5abCbEhu3aSxSErvuWHDBmzYsGHQnw2WGDxjxowhE80sFgvefvvtRDZPck60hM+fmiXi7A1nTiTAocq00ieVkXhGpAZPU/cAjjf3YunUbFHbQ4ZW3+mGxx+C2aDF5GxxR+JzY3ZSsSwrSpkDMjrcbLEU+qjhSHoXlVpxHdUsCdQW4JbITlCAI2mBEMMvI4q9RAVEb3xcsE6kibvXzHCkQydSgjGnzBGuaNzR70d7HyUaS9kXLeFZNin0UcOhAEeCvogsNUghOp4VOZemxtkHhqHtm1J1urUf/hCDdLNelArGF6LAWB6+kNBI3GLUoTSyTf2EkwJjKZPSIHw4FOBIjMcfxLmO8FZfKVw8pVkpMOm18PhDqO/yiN0cMgQu/2ZWQbokpvZnRgLjk04KcKRMKsuaHAqMpa93IMBX2ZdCYDwcCnAkpsbZB5YNF0zLTTOL3RzodVrMcIQ7K7rpSJdUdlBxuI7qVGs/gnSWmWTxuRQSCXBmUYAjedx/m0l2C+xWo8itGR4FOBIjxam/mQ666UjdF83SGokXZViRYtTBH2T4GUkiLW19XrT3hctRlEUGMWLj2nGScrck6wuJBcXDoQBHYqS0Js7hlhsowJEmlmVjlhqkMYOj1WpQRqNxSeMSwMUuRxGLm/k7295PxUUl6rjEBlPDoQBHYk5IcQaHdsRI2vmeAfT7gjDqtJiSI85ZQoOJBsZ03UiRFAdT+TYzbBYDggyLM21UXFSK+FUGCV03Q6EAR0JCDIuTTq4GjjSmjIHw9k0g3JH2DtAJ0VJTE7lmpuSkwKCTzp90GS1tSprUZv2AcHHRMgcFxlLlDzI40xb+7yJmlf3Rks7dkKChyyOZoluxbFYDJtnDW49PUmclOTWt4RvODInkUXBoR4y01UR2uEkl/4ZD14101XW6EQixSDVJoxzFSCjAkRBuJD49N030olsXojwc6eKuG6kFOGWONGg0QFufD539VLhNSvxBBufaw8nfUrtuuKUPKjEgPdy95pK8VEmUoxgJBTgScqqVu3ikdcMBKA9HyvgAR2LXTYpJj5JMKwDwS69EGmo73AgyLNJMeuTbxC9HEassJndrqCN7iDik3EcNhgIcCalpjUbHUjOTRlWSFAhFR+JSvOnQcoM0cR3VdAmOxC/JS4NWA3S5/WijIxskJTqDI717zWAowJGQ01yAI7EpYyAadJ1u66dRlYTUdbjhDzFIMeokuSbO3Qi5DpVIwymJ5m0BgNmgw+TIkQ01NPMnKVK+bgZDAY5ExK6JSzE6LslKgUGngccfwvmeAbGbQyJqYoJiqY3EgdgAh7b8SonUR+IUGEuPNxA9rkeq182FKMCRiLrO8Jp4qkmPAomtiQOAQafFlMjOrtPUWUnGKYnm33C4mb8zNPMnKVLPpZieG71uiDSE/4aBzBQjslOlfUQDhwIcieB3UElwTZwzPdJZ0ahKOk5KfCRekpUCvVaDfl8QLb1esZtDII+R+HSawZGc6C5f6fZRF6IARyK4/BupjsSB8PZ1gJYbpITrAKRWy4Rj1Gv5fArqrKRBDiNxLvA63Uozf1Iht/wbgAIcyaiR+JQxEJtoTB2VFAz4Y0biEr7pxHZWRHxyqGVSmm2FTqtBny+IVhftpJICOfRRF6IARyK4m7+UL57pMR0Vw9CoSmyn2/rAskBWihHZqSaxmzMkWtqUFqnn3wCASa9DaVa4hhJdN9JwSqIFRYdDAY4EeAMh1HVGdlA5pFcDh1OaZYVRp8VAgHZSSYFUKxhfiFvaPE0Jo5IghwAHoJ1UUuLyBtAcyaG7JFfa100sCnAk4Gx7PxgWsFsNyJHwSFwfc1o1LVOJTz4dFe2kkhIuh076gTHtpJIKboXBkW6GzWoQuTWjRwGOBMR2VFJdE+dMy+WWG+imI7YaGSxrAkBpdriGUr8vyI8CiTj6vAF+9lXqI3HaSSUdsZWv5YQCHAk4xXdU0r94aNpYOs5GRrZSv+kYdLSTSipOyWgkzien08yf6KR63t1IKMCRAKkXa4vF76SiGRxRDcRUlJ6aI+0AB4jm4Zyh60ZUp2U0Eud3UnlpJ5XYuGVCqc8WX4gCHAk41cbddKR/8XBtPNNGO6nEdLY9fMPJTDEiM0WatUxi0U4qaeCum+kSX54CaCeVlHDXzdRc6QfGsSjAEZk3EEJTd3gkPk0GF09JJu2kkgL+hhNJ+pY6fmmTEkZFdTZy3t3UXHldN7QDTzyxVcjlcr/hUIAjstoON1gWsFkMyJLBSDx2JxWNqsTD5d/IISgGYnZStfZRPoWIooGxPK4bbifVabrXiKY2EhRnpxpht0q/j4pFAY7IYkfiUt9BxZlOJ0SLjh+Jy6Sj4k6jd9Np9KLxBkJojFS+lst1QzupxHemPfy7nyKTayYWBTgiOyezjgoApkXaeq6dAhyxcEl/crluDDotSrPCM39ccEaSq77TA4YF0s16yZ5BdSHaSSW+s23y66M4FOCITI7JW9wS1VkKcEQRYljUdoRvOnJZogKi1w0FxuKIvdfIZba4JMsKjQbo8wbR0e8XuzmqJLd8v1gU4IiMu3imZMvn4uEi+bPtbhpViaCxywN/iIFJr0WB3SJ2c0ZtKj/zRzM4Yjgrs1k/ADAbdCjMCF/jFBiLg+uj5DSY4iQlwNm+fTtKS0thNptRXl6OgwcPDvnYF154ARqNJu7LbDbHPYZlWWzZsgX5+fmwWCyorKzE6dOnhf4YCccwbHT6T0YXD1e0rXcggC43jaqSjQ+Kc1Kh08pjJA5E1/Bp5k8ccksw5vCBcQcFxskWDDGo65BX3lYswQOcV199FZs2bcIDDzyAI0eOYP78+aiqqkJbW9uQz0lPT0dLSwv/VV9fH/fzRx99FI8//jh27NiBAwcOICUlBVVVVfB65VUG3unyYiAQgl6rQXGmVezmjJrFqMOkyMwB5VMkXzT/Rj6zfkDsEhVdM2KIJqbL7LrJjgTGtFU86Zq6B/jZ4kkymi3mCB7gPPbYY1i3bh3Wrl2LWbNmYceOHbBarXjuueeGfI5Go4HD4eC/8vLy+J+xLIvf/va3uO+++3D99ddj3rx5eOmll9Dc3IzXX39d6I+TUNyIqiTLCoNOXquF3IwTTRsnn1ynjKdGOiqnywu3Lyhya9SFZVlZ5vsBMYExzeAkXexssVZGs8UcQXtVv9+Pw4cPo7KyMvqGWi0qKyuxf//+IZ/X39+PkpISFBUV4frrr8fx48f5n9XW1sLpdMa9ps1mQ3l5+ZCv6fP54HK54r6kQI5r4hwuZ4iWG5JPblvEOTargd+9U0udVVI5XV54/PKbLQYoOV1Mck4wBgQOcDo6OhAKheJmYAAgLy8PTqdz0OfMmDEDzz33HP7617/iD3/4AxiGwdKlS9HU1AQA/PPG8prbtm2DzWbjv4qKiib60RIiWlVUXh0VEDuDQx1VMrEsyy9RyW0GB4hZbqDOKqm4XD9ZzhZHAvnG7gH4giGRW6Muct4iDkhwF1VFRQXWrFmDBQsW4IorrsBf/vIX5OTk4Pe///24X3Pz5s3o7e3lvxobGxPY4vGT4w4qzlSawRFFp9uP3oEANJposrecREsMUGCcTHJNMAaA3DQTUk16hBgWDZ0esZujKnJd1uQIGuBkZ2dDp9OhtbU17vutra1wOByjeg2DwYCFCxfizJkzAMA/byyvaTKZkJ6eHvclBecUMIPT0OWhUVUScbM3RRlWmA06kVszdlOpSKQo5NxRaTQaCoxFQktUwzAajVi0aBH27NnDf49hGOzZswcVFRWjeo1QKISjR48iPz8fADB58mQ4HI6413S5XDhw4MCoX1MK+n1BOF2RA8yy5XfT4UZVDAsaVSWR3G841FGJQ84zOEB0lvtcBwXGydLl9qPbE54tniLDPgpIwhLVpk2b8PTTT+PFF1/EiRMnsH79erjdbqxduxYAsGbNGmzevJl//IMPPoh33nkH586dw5EjR3DzzTejvr4eP/jBDwCEo/mNGzfi5z//Od544w0cPXoUa9asQUFBAVauXCn0x0kYbgSbnWqCzWoQuTVjFz+qoptOssjtiIYLcbVwajv6wTBUJDJZorkUcg2Mua3iFBgnC3dfn2S3wGKU32wxAOiFfoObbroJ7e3t2LJlC5xOJxYsWIBdu3bxScINDQ3QaqNxVnd3N9atWwen04mMjAwsWrQIH330EWbNmsU/5u6774bb7cZtt92Gnp4eLFu2DLt27bqoIKCUyX0kDoQ72c+bemk0nkTc71qOCcYAUJRhgUGngTfAoMXllWVtDbmJnS2W44GJQGyxPxpMJYucd/lyBA9wAGDDhg3YsGHDoD/bu3dv3L9/85vf4De/+c2wr6fRaPDggw/iwQcfTFQTk06OFYwvRFvFk4+76ci1o9LrtCjJSsGZtn6ca++nACcJuNninDQTbBb5zRYD8UUiWZaVzVlacib3ZU1Agruo1EIJFw8XnNEMTnJ4AyE09w4AiN7w5YgPjKkybVIoYbZ4cnYKNBo6HiaZ+OXwXPleNxTgiCRaIVK+F0/sjhg6dFN4DV0esCyQZtIjK8UodnPGbQqdLZRUcq9lAoQP3Syw0fEwycT9fco1wRigAEcUIYaNHmAm44unJMsKjQbo8wbR3u8TuzmKx1X/Lc1OkfUU/VQ6kyqpuOtGjnWTYlFF4+TxBxk0doX7KDkPwinAEcH5yAFmRp0WkzLkm4NgNuhQGGl/LXVWgqtTTEdF1YyTSSkBDp0qnjyN3R4wLGA16pCbZhK7OeNGAY4IajvDf6DFWVboZHiAWazSrPBNs66TbjpCi53BkTNuBqel1wuPnw7dFBLLsvzfplKuG5rBER43mCrNkvdsMQU4Ioi9eOSOGxXWdlCxP6FFR+LyOizxQnarEZmRHCJaphJWW58PHn8IWk24+rWcRWf+6JoRmlJm/SjAEYFSOiogZgaHpo0FF71u5Ju3xeFH43TdCIq7ZgozrDDq5X2753JBGro88AcZkVujbNFZP3n3UfK+4mVKKVPGQDTCpyUqYbl9QbT1hRO5Jytg5o8C4+RQSt4WADjSzbAYdAgxLBq7acZYSNwmGLmvMlCAIwL+piPziweIBml1nW4qvS8gLoDMTDHK8miPC/HXDQU4guLy/ZQQ4Gg0GpRkhWcU6mlAJShaoiLjEggxaOwOF2ubLOPtd5zCDAt02nDp/dY+r9jNUSw+wThL3lPGHEpOT446hV03lPMnvNiConJfZaAAJ8maugcQYliYDVrkpcnn7KyhGHRaFHFbxWk0Lpg6heyg4nBr+3V0Er2g+KUGxVw34c9BMzjC4QuKmuVdUBSgACfpYndQaWW+RZwTXW6gzkoo0aqiCumoIjM4XW4/egcCIrdGmRgmukVc7ksNHG4migZTwoldnpLzFnGAApykq1XQFnEOLTcIT2kzOCkmPXIiBcRoNC6MFpcXviADvVajmENN6V4jPCWVMaEAJ8mUtIOKE10Xp5uOULilHCXcdDhckj1dN8LgOqriTCv0OmXc6rl7zfnuAdoqLhClFBQFKMBJOiXVwOHQjhhh9XqiJygrZakBiMnDoaVNQSipo+LkpJlgNerAsKCt4gJRUh9FAU6S8TM4ChyJ13d5aKu4ALitvrlpJqSY9CK3JnFKaLlBUEqqgcMJbxWnRGMhKamPogAnifxBBue5LeIKuukU2M0w6DTwBxl+eyFJnNqO8Nk7ShqJA1QkUmhKXA4HojMLtFU88Tz+IFpdkYKiCrhuKMBJooau8AmtKUYdn2CpBHqdFkWZtNwgFO5GrpQdVByqZiysWgUVFI1VQteNYLj7d4bVALtV3lvEAQpwkor7gyyR+Qmtg+ETRmk0nnBK20HF4XJwuj0B9Hpoq3giBUMMGrq4Gjjyz6WINZmWNgWjtLwtCnCSSGk1KWJRorFwlFhaAACsRj1yIzOZ1FklVnOPF4EQC6NeiwKbMraIc7jjGuiaSTy+j1LIvYYCnCSKRsfKGlEBFOAIhWVZ/nc6RQFHe1yolPJwBMHNpJZkWhVTUJRDW8WFQzM4ZNyUlJ1+IVqiEkZHvx99viA0mnA9E6WhWjjCUOqyJkBbxYWktJ13FOAkEZfApZSLJxY3K9XY5UEwRKOqROGC4gKbBWaDTuTWJF5JNnc6NHVUiaSU06AHE7tVnGaME0tpaRQU4CSJkk5oHUyBzQKjXotAiEVzD50qnihK7qgAmsERitI6qgtNpsNaE87lDaCjP1xQVCl9FAU4ScKf0GqS/wmtg9FqNSiJLKHQMlXiRHfeKW95CqAcHKEo6TyhwVCJgcTjfpfZqSakKqSgKAU4SRKbvKW0LeIcSjROPG6rr1IDHO5z9XgC6PH4RW6NMgRCDBoVWFA0Fh26mXhKOqKBQwFOkig56Y9Dh24mHhfgFGcq87qxGvXIS+e2itNyQyI0dQ8gxLCwGHT871ZpaOYv8eoVeKAvBThJEq0voJzo+EI0qko87qaj1BkcgJYbEo37+yvJsip4tjj890BbxRNHifcaCnCSJHrxKCc6vlBp5A+jgUbiCdHrCaB3IFzhV4lbxDmllGicUNzfn5KvmZxUE1IiW8W5WU4yMQ1d4b+/YgX1URTgJIkSo+MLlUSmjRu7PQjRqeITVt8VTfpT0iniF+KWG+h06MSILmsq915Dp4onHt9HKei6oQAnCfxBBi2RLeLFCg5wHOlmGHXcVnE6VXyi1BAUAzGnQ9PMX0Ko5bop5U8VpwBnogb8IbT1hU8RV9J1k5QAZ/v27SgtLYXZbEZ5eTkOHjw45GOffvppfPnLX0ZGRgYyMjJQWVl50eNvueUWaDSauK8VK1YI/THGrak7fIq41ahDTqoyk/4AQKfVoDAzfO4NTRtPnBpG4gDtvks0JS41DKaUn8Ghe81EcRWh08x62CwGkVuTOIIHOK+++io2bdqEBx54AEeOHMH8+fNRVVWFtra2QR+/d+9efOc738F7772H/fv3o6ioCMuXL8f58+fjHrdixQq0tLTwX3/84x+F/ijjVh/TUSk16Y/DTW/STWfi1JBLAUQ/X+8AnSo+USzLRksLKPy6oU0NiRM766ekPkrwAOexxx7DunXrsHbtWsyaNQs7duyA1WrFc889N+jjX375Zfzwhz/EggULUFZWhmeeeQYMw2DPnj1xjzOZTHA4HPxXRkaG0B9l3LiOqkjhNxwgmkTN5Y+Q8eN+h0qaMh6M1ahHTuRUcbpuJqatzwdvgIFOq8GkDGWdIn4hbrmfZosnrp4/nFVZs36CBjh+vx+HDx9GZWVl9A21WlRWVmL//v2jeg2Px4NAIIDMzMy47+/duxe5ubmYMWMG1q9fj87OziFfw+fzweVyxX0lk1pGVEB0NE47qSauQSW5FADN/CUK9/srsJth0Ck7xZL7uzjfPUDn300QvxyusHuNoH8BHR0dCIVCyMvLi/t+Xl4enE7nqF7jnnvuQUFBQVyQtGLFCrz00kvYs2cPHnnkEbz//vu4+uqrEQqFBn2Nbdu2wWaz8V9FRUXj/1DjoJakPyD6GamjmhhfMIQWV/hML6UW+YtFo/HEUEveFgDkpZlh1GsRZOj8u4lS4g4qAJD03tNf/OIX2LlzJ/bu3Quz2cx/f9WqVfz/nzt3LubNm4epU6di7969uOqqqy56nc2bN2PTpk38v10uV1KDHLUk/QHRACd89harqPXcZGrsGgAbSUzPTlXe2WUX4qbGaeZvYhoiSw1qCIq1Wg2KM60409aP+i634mYfkkmpgbGgMzjZ2dnQ6XRobW2N+35rayscDsewz/3Vr36FX/ziF3jnnXcwb968YR87ZcoUZGdn48yZM4P+3GQyIT09Pe4rWdSU9AcAhRlWaDRAvy+ILjedLTRefFCsgsR0IGbmj3JwJqRe4WeXXYiWNicuxLBo6qYlqjEzGo1YtGhRXIIwlzBcUVEx5PMeffRRbN26Fbt27cLixYtHfJ+mpiZ0dnYiPz8/Ie1OJDUl/QGA2aBDfnp4to3OFho/teyg4hRR7lZCKHWpYSi0tDlxLb0DCIRYGHQa5NuU1UcJnoW2adMmPP3003jxxRdx4sQJrF+/Hm63G2vXrgUArFmzBps3b+Yf/8gjj+D+++/Hc889h9LSUjidTjidTvT39wMA+vv7cdddd+Hjjz9GXV0d9uzZg+uvvx7Tpk1DVVWV0B9nzNSU9MeJ3nRoND5eqhuJRz5ni8sLX3DwXDoyMqUmiw4lOoND95rx4nf5Zlih0yprtljwHJybbroJ7e3t2LJlC5xOJxYsWIBdu3bxiccNDQ3QaqMd/5NPPgm/349vfetbca/zwAMP4Kc//Sl0Oh0+//xzvPjii+jp6UFBQQGWL1+OrVu3wmSSXhE9pW6/G05JZgo+PtdF08YTwM/gqCBvCwCyUoxIMerg9ofQ2DWAabmpYjdJdvq8AX5ZWC0zfyVU7G/C6hUcFCclyXjDhg3YsGHDoD/bu3dv3L/r6uqGfS2LxYK33347QS0THjeiUkMNHE4xHbo5YfUqytsCwmcLFWel4ESLCw1dbgpwxoG712SmGJFmVk412uHwS5u0qWHclLysqY41ExGpaYs4J5owSgHOeDBMTGK6mq4bShidELXlbQFAUaYFGg3g8YfQ0U+bGsaDSyVQ4iCcAhyBqWkHFYdbjqOOanxa+7zwB8OJ6QV2ZSX9DYdqKE2M2vK2AMCkj25qoETj8YkOppS3HE4BjsDUlvQHRD9rR78Pbl9Q5NbIDzcSn2S3qCYxHaAdMROl5KWG4dCmhvFjWVbRqwzquXuKIDbpT4nR8VBsFgMyrOEcAOqsxq5eoUW3RhKd+aOOajzUVFA0Fs0Yj1+PJ4A+b3gQqsT7DQU4AuL+4LJSjEg1SbpodMIVZ1FnNV7RHVTKu+EMhxtBNnYPgGFYkVsjP0qtRjsS2tQwftxgKi/dBLNBJ3JrEo8CHAGpcXmKQwmj46e2HVScfJsZeq0G/iCD1j46W2gsAiGGP49JiUsNw6FNDeOn9DImFOAISK1r4gDddCaCO09IbR2VXqdFYaTaNwXGY3O+ewAhhoXZoEVumvTqgQmJlqjGjy/yp9A+igIcAcWeJ6Q2xVR6f9yiOTjKHFUNh45sGJ/YvC211YKhTQ3jp/RyFBTgCKheZdVoY/EVRmlnw5j0DgTQ4wkAUOnSJh26OS5qOkX8QjaLATYLbWoYD6WXFqAAR0BK3n43Eu4zN/d4EQgxIrdGPhq71JuYDtByw3gpfSQ+EqqhND5KLw5JAY5A/EEGLb0DANSZg5ObZoLZoEWIYXG+e0Ds5shGvUp3UHGoFs741Cu8oxoJvyROM3+j5g2E4HRxienKnPmjAEcg53sGwLCAxaBDjsqS/oDI2UKZlGg8VtzSjBqDYoBG4uOl5h2bQPS6ocB49LjZ4lSTnq9bpjQU4AikvjOaYKy2pD8ONypooFo4o6a2U8QvxAXFvQMB9EZykcjwWJZV5ZEwsWhpc+xiZ/2U2kdRgCMQtY+ogOjNto5uOqOm5tICAGA16vkZT0o0Hp32fh88/hC0GqAwQ53XDS1tjp0a8rYowBGI2jsqgJYbxkMNN52RUJHIseFm/fJtFhj16rylc38v57sHEKRNDaOihkG4Ov8akkDtyaJAdJmFEv9Gxx9k0BxJTFf3dUOj8bGgoBjISzPDqNciyLB8RWcyPKVXMQYowBGMmov8cUoyox0Vy9LZQiNp6vaA5RLTU9WXmM7hbrhU7G901L6DCgC0Wg2KuCrYNKAaFaXXwAEowBFEXNKfSpNFAWBShgU6rQbeAIO2Pp/YzZE8NVejjVWcRR3VWKhhqWE0+OKiFBiPKMSwaOqKzBYrODCmAEcAbX0+eAMMtBpgkt0idnNEY9BpUWA3A6Cbzmio9RTxCxXTDM6YqGGpYTSKM2lpc7ScLi/8IQZ6rQb5NrPYzREMBTgC4P7ACuzqTfrjRLdv0mh8JJSYHsZNmbe4vPAFQyK3RvooBycsuqmB7jUj4QYPhRkW6HXK7aOU+8lEpOYjGi5ECaOjx+Vtqf26yUoxIsWoA8sCjV1UBXs4bl8QHf1+ADTzR7s2R4/PEVV4CgUFOAJQ88F3F6Itv6On5sNZY2k0GtqBN0rcwCHDakC6WZnVaEeLX9qkTQ0jUstsMQU4AlBDdvpoUQn10aFqtPEoMB4d2kEVVZRpgUYDePwhdLr9YjdH0tTSR1GAIwC66UTFjqrI0Nr6fPAFw4npBSpOTOfQcsPoqGWpYTRMeh3y02lTw2hwOThFCu+jKMARQEMXBTgcLi+gy+1Hn5fOFhoKd0OmxPQwyt0aHbUsNYxW9Lqhpc3h8DvvaAaHjEWfN4CuyPSo0i+e0Ug16ZGVYgRAo6rhqOWGM1olNPM3KlQDJ14xLW2OqNcTgMsbBKD8QTgFOAnG/WFlphiRpvKkPw6NxkcWnfWjpQYgPneLYShhdCiUtxWPK/ZHNZSGxhXQzEkzwWrUi9waYVGAk2C0PHUxShgdGZUWiJdvM0Ov1cAfZNDaR2cLDSYYYnC+m84ui8XP4NBgakhqWtakACfBqOjWxaJbfummMxQaicfT67SYxJ0tRIHxoJp7vAgyLIx6LfLSlFuNdiwoOX1kahqEU4CTYGqKjkcreugmJf4NhXIpLsaX3qfOalD1MQf6arXqPbssFpe71dHvg8cfFLk10sTl+6nhXkMBToLRts2LFdOoalixielqGFWNFtVQGh4Npi5msxpgs4RzH+m6GZyalsOTEuBs374dpaWlMJvNKC8vx8GDB4d9/J///GeUlZXBbDZj7ty5+Pvf/x73c5ZlsWXLFuTn58NisaCyshKnT58W8iOMmpountHibsDNPQPwBxmRWyM9lJg+OP4cM+qoBtVIs36DomWq4TWqaEOD4AHOq6++ik2bNuGBBx7AkSNHMH/+fFRVVaGtrW3Qx3/00Uf4zne+g1tvvRWffvopVq5ciZUrV+LYsWP8Yx599FE8/vjj2LFjBw4cOICUlBRUVVXB6xU3GdEfZNDco/wj6McqJ80Ei0EHhgXO99DZQhdS05r4WBTxS1S0tDkYmsEZXBEtbQ7JFwyhxRXuJ9UwCBc8wHnsscewbt06rF27FrNmzcKOHTtgtVrx3HPPDfr4//qv/8KKFStw1113YebMmdi6dSsuvfRS/O53vwMQnr357W9/i/vuuw/XX3895s2bh5deegnNzc14/fXXhf44wzrfMwCGBcwGLXLTTKK2RUo0Gk00n4JG4xehWb/B8SNxumYGVU8zOIPid21Szt9FGrsGwLJAilHH1ydTMkEDHL/fj8OHD6OysjL6hlotKisrsX///kGfs3///rjHA0BVVRX/+NraWjidzrjH2Gw2lJeXD/maPp8PLpcr7ksIfPJWphUaDSX9xeJr4dBo/CL8KeI0Eo/DBcU9ngB6B6gKdiyWZelQ3yHQEtXQuHtNkUr6KEEDnI6ODoRCIeTl5cV9Py8vD06nc9DnOJ3OYR/P/e9YXnPbtm2w2Wz8V1FR0bg+z0ioWNvQqMLo0KI7qOi6iZVi0iM7NTwTSssN8Trdfrj9IWg04UMmSRR3/22kmb+LqG22WBW7qDZv3oze3l7+q7GxUZD3mV9ox4//ZRq+Ni9fkNeXM1puGBodzjo02kk1OO6ayU83w6TXidwaaeGumabuAQRDtKkhVjTAUcdgStA6zdnZ2dDpdGhtbY37fmtrKxwOx6DPcTgcwz6e+9/W1lbk5+fHPWbBggWDvqbJZILJJHxOzPwiO+YX2QV/HzmimiaDi01MV8uoaixKMq04XN9N+RQXoB1UQ8tLN8Oo08IfYtDS61X8idlj0aiyDQ2CzuAYjUYsWrQIe/bs4b/HMAz27NmDioqKQZ9TUVER93gA2L17N//4yZMnw+FwxD3G5XLhwIEDQ74mEV9JTDVjlqWzhTiUmD482hEzuOgOKnWMxMdCp9WgMJOqYA+mXmWV9gVfotq0aROefvppvPjiizhx4gTWr18Pt9uNtWvXAgDWrFmDzZs384//yU9+gl27duHXv/41Tp48iZ/+9Kf45JNPsGHDBgDhHTkbN27Ez3/+c7zxxhs4evQo1qxZg4KCAqxcuVLoj0PGaZLdAq0GGAiE0N7vE7s5kkGJ6cOjhNHB8VWMVdJRjRXtpLoYw7CqK0kh+FGiN910E9rb27FlyxY4nU4sWLAAu3bt4pOEGxoaoNVG46ylS5filVdewX333Yf/+I//wPTp0/H6669jzpw5/GPuvvtuuN1u3Hbbbejp6cGyZcuwa9cumM10HotUGfVaFNgtaOoeQEOnB7l0dg4ASkwfCeXgDK6B8raGFZ4xbqfrJkZrnxf+IAOdVoMCuzoS05NyVvqGDRv4GZgL7d2796Lv3XjjjbjxxhuHfD2NRoMHH3wQDz74YKKaSJKgONOKpu4B1Hd6sLg0U+zmSEKDynY1jBUX+DX3hqtgG/Wq2BcxIrUtNYwV5fxdjJsFnWS3wKBTx9+ROj4lkQTaSXUx6qiGl51qhNWoA8sCTd103QDAgD+E9r7wMi/l4AyOljYvpsbBFAU4JGm40TgV+4vibjq002NwsVWwKTAO45ZdbBYDbFY6u2wwsZXTaVNDmNrybwAKcEgS0QxOPJaNJv1RFeOh0XJDvNjEdDI4bsDQ7wuiy+0XuTXSUE8BDiHC4f6wqMJoWHufDwOBELQaoDBDPTedsaLlhngNVANnRGaDDo708EYGSjQO42bOaYmKEAFwN+SOfj/6fUGRWyM+bkSVb7NQ8uwwivkaSrS0CdAp4qNVTDvw4tSrcMcm3VVJ0qSbDciI5AzQcoP6zoUZrxI6iT4OJaaPTgmdf8frHQigxxM+sFZNM38U4JCkotF4VAN1VKNCCaPx6BTx0aGlzShuQJmdakSqKSnVYSSBAhySVDSqiqqnjmpUJmVYoNNq4A0waOtTdxXsYIhBUzedXTYa/DEfNJiKVr5W2bImBTgkqagybRQtUY2OQadFgT2cMKr2wLil14sgw8Ko1/JJtGRw3Pl3ar9mAPWdIs6hAIckVRHlU/DUWJdivLiCdvUqr6HEXTNFGRZotXR22XC42eK2Ph8G/CGRWyMutZ0izqEAhyQVLVGF9XkDfH0OmsEZGZcYqfYSA2odiY+H3WpAmjmcb9Ko8irYap0tpgCHJBV3Yz7fM4BAiBG5NeLhbjiZKUakmaka7UhKqJoxAPXmUoyHRqOhROMItW5ooACHJFVumgkmvRYhhkVzz4DYzRENLU+NTTHN/AGgU8THKnrdqHdp0xcMobk3fK9V24YGCnBIUmm1GuqsoN4p4/Giom1hdN2MDX/+nYqvm6buAbAsYDXqkJ1qFLs5SUUBDkk62kkV3bpK1WhHh1va7HL70ecNiNwaccSdXUYBzqjQvSZ+1k+jUVdiOgU4JOloJ1V0JF5MyaKjkmrSIyslPPpU68xflzt8xImGzi4btRI6qFXVh7NSgEOSroTWxWmpYRzUvpOKS7B2pJthNuhEbo088NdMtwchRp1VsNV8tAcFOCTp1F6Ayx9k0BJJ+qMlqtErVvlOKkowHrt8mwUGnQaBEMv/zakNXwNHhbPFFOCQpItNGFXj2UJN3R4wLGAx6JCTZhK7ObKh9hpKlH8zdjqthl/OU+sylZpPn6cAhyRdYYYFGg3g8YfQ0e8XuzlJV9+l3qS/iVD7Qa1U5G981DzzxzDqTkynAIcknUmvQ37kHB01JhrzSw0qvOFMhNqLtnGBXZEKR+IToeadVG19PviCDHRaDQrsFrGbk3QU4BBRRJep1DcaV/OU8URwv69mlVbBputmfIpVvJOK28gxyW6BQae+7l59n5hIQvTwRPXddPgaODSDMyY5aSaYDVowLHC+W10JowP+ENr6fADouhkrflODGgdTKl6eAijAISLhZ3BUGOBQDZzx0Wg0qs2n4JZX0s162K3qqkY7UbFLm2rb1MDdX9W6rEkBDhEFf9NRWUcVl/Sn0pvORPCl91VWQ4lbaqAE47Eriuyi6vMG0eNRVxXsepXfayjAIaIoUekZMbFJf5My1Jf0N1FqTTTmD2dV6VLDRFiMOuRGyjGo7X6j5h1UAAU4RCTcUkN7nw8ef1Dk1iQPNxIvsJtVmfQ3UWqd+aME44lR63XTwB/ToM6ZP7rDElHYrAbYLAYA6hpVRaeM1XnDmSguMFbbcQ1qH4lPlBqXNl3eALojS3JqnfmjAIeIRo3LDVQDZ2KKYw5qVVPCKBfgqDVZdKLUfK/JTjUi1aQXuTXioACHiEaN9SnUnvQ3UYUZVmgjVbDb+31iNycpQgyLpm6qYjwRalyiqqezyyjAIeKJ3nTUM23c0Ek1cCbCqNci3xZOzlZLYBwubMjCqNPCEakATsamSJWDKdp5J2iA09XVhdWrVyM9PR12ux233nor+vv7h338j370I8yYMQMWiwXFxcX48Y9/jN7e3rjHaTSai7527twp5EchAoguN6inaFv0HCr13nQmSm3LDdzyVGGmBTotnV02HtyMqdPlhTcQErk1yUGnzwsc4KxevRrHjx/H7t278dZbb2Hfvn247bbbhnx8c3Mzmpub8atf/QrHjh3DCy+8gF27duHWW2+96LHPP/88Wlpa+K+VK1cK+EmIENSW+NfrCfB1OCgHZ/zUdrYQ7aCauMyUaB4Kt9yndNHDWdV73QiWeXTixAns2rULhw4dwuLFiwEATzzxBK655hr86le/QkFBwUXPmTNnDv7nf/6H//fUqVPx0EMP4eabb0YwGIReH22u3W6Hw+EQqvkkCbg/vKbuAQRDDPQK3zbNTRlnp5pUm/SXCEWZKgtwaKlhwrgq2F+0uFDf6cG03DSxmyQ4vnaSigNjwXqU/fv3w26388ENAFRWVkKr1eLAgQOjfp3e3l6kp6fHBTcAcPvttyM7OxtLlizBc889N+yOCp/PB5fLFfdFxOdIN8Oo1yLIsGjp9YrdHMHRiCoxoueYqWPmj5YaEkNNS5v+IIOW3vDSv5pniwULcJxOJ3Jzc+O+p9frkZmZCafTOarX6OjowNatWy9a1nrwwQfxpz/9Cbt378YNN9yAH/7wh3jiiSeGfJ1t27bBZrPxX0VFRWP/QCThtFoNiiLVfNVw06EjGhJDtUtUKu6oEqFYRTN/Td0eMCxgNeqQk2oSuzmiGXOAc++99w6a5Bv7dfLkyQk3zOVy4dprr8WsWbPw05/+NO5n999/Py6//HIsXLgQ99xzD+6++2788pe/HPK1Nm/ejN7eXv6rsbFxwu0jiaGmk365GQc1j6gSgfv9dfT70e9TdhVslmWpyF+CFKsoMK6PWZ7SaNSbmD7mRIA777wTt9xyy7CPmTJlChwOB9ra2uK+HwwG0dXVNWLuTF9fH1asWIG0tDS89tprMBgMwz6+vLwcW7duhc/ng8l0cbRqMpkG/T4Rn5pGVTQST4x0swEZVgO6PQE0dnkwMz9d7CYJptsT4IO4wgy6biZCTUubtKwZNuYAJycnBzk5OSM+rqKiAj09PTh8+DAWLVoEAHj33XfBMAzKy8uHfJ7L5UJVVRVMJhPeeOMNmM0j132orq5GRkYGBTEypKZifw20RTxhijOt6Pb0or5T2QEO1xk70s0wG3Qit0beuIFFY/cAGIaFVsFb7mkwFSZYDs7MmTOxYsUKrFu3DgcPHsSHH36IDRs2YNWqVfwOqvPnz6OsrAwHDx4EEA5uli9fDrfbjWeffRYulwtOpxNOpxOhULh2wZtvvolnnnkGx44dw5kzZ/Dkk0/i4Ycfxo9+9COhPgoRkFoS/7yBEJyucCK12m86iVCcxZ1Gr+zROHVUiZNvM0Ov1cAfZPi/RaWqiwTGpdnqHkwJulf15ZdfxoYNG3DVVVdBq9XihhtuwOOPP87/PBAIoKamBh5P+I/4yJEj/A6radOmxb1WbW0tSktLYTAYsH37dtxxxx1gWRbTpk3DY489hnXr1gn5UYhAYhNGWZZV7HpxU7cHLAukGHXISjGK3RzZ4xK1lR4Y13aEO6rJKu+oEkGv06Iww4K6Tg/qOz0osFvEbpJg+ABH5aUFBA1wMjMz8corrwz589LS0rjt3VdeeeWIB+itWLECK1asSFgbibi4vIJ+XxBdbj+yFJrxz58Lk5Wi2CAumdSSMFrXSTVwEqko04q6Tg8auzyomJoldnMEEQwxaIz8Xah9BkfZldWI5JkNOv58HSV3VlSNNrHUMoNTF/l8k7PpukkENZx/19zjDZ9dptciX+Vnl1GAQ0SnhtE4PxKnjiohuBmN8z0DCIQYkVsjnLoOyqVIpOhOKhXcazKtik6kHg0KcIjo1DAa53MpaKkhIXLTTDAbtAgxLM53K/Ow1m63H70D4bPLSmjnXUIUq2BTAy1rRlGAQ0Snhp1UtKshsbRaDd/p1yq0rkldzBZxi5G2iCcCl6xd1+EeMd9Truo6aFmTQwEOEZ3St/z6gww/y0C7YRIntrNSomhQTB1VooQr+wJ9viA63X6xmyMIGkxFUYBDRFcamcGp7VDmDE5DV/RcmNw0Ze4SE0OpwgMc7u9B7Vt9E8ls0KHAFt4ertTrhs/bouuGAhwiPq6j6uj3oc8bELk1icfdcEpoi3hCcVPwtQpd2qynkbgguBmxWgUGOMEQg8Zu2iLOoQCHiC7dbEB2arj4XZ0CZ3G4KWNaE08sboSq/JE4XTeJxC9tKjB3i7aIx6MAh0gCd9NRYsJoLU0ZC4K7Zpq6PfAHlbVVnGXZ6HVDI/GEigbGyhtM1dIW8TgU4BBJ4G46te3KC3Ao6U8YOWkmpBh1YFjl1VDq8QTg8oZPEact4onFBcbnFDjzR8ua8SjAIZJQquBp4+i2TbrpJJJGo+FrfShtmaqWtogLhrvX1Hcqb6t4LS1rxqEAh0jCFG6JSmEdlTcQQnNveIs4LVElnlLzKaIVjKmjSrSiDCu0GsDjD6Gtzyd2cxKKqyVGMzhhFOAQSShVaIATPiUdSDXp+URqkjhK3RETPYOKOqpEM+q1/CG/irtuKN8vDgU4RBK4P8jegQC6FVSAqzZmJE5bxBNvcnYqAOXO4FC5fWEosUhkMMTwuWg0gxNGAQ6RBItRh3xbeFujknZS0YhKWNzWe6XtiOET0+m6EYQSd20293gRZGiLeCwKcIhkKHEnVbQGDnVUQuCumebeAXgDIZFbkxixW8TpuhEGl4SrpBkc2iJ+MQpwiGRMzlFewijVwBFWZooRaWY9WAVtFe/2BNAX2SJenElJxkJQYs5fHdVNuggFOEQyJmcp8aZDa+JC0mg00eUGhVw33OfIt9EWcaFM5reKe8AwytgqHl3WpKCYQwEOkQyljaoG/CE4XV4AtNQgpFKFBcZcsbYS6qgEM8lugV6rgS/IoCXyNyp3NINzMQpwiGTE7mxQQgEubkSVbtYjw2oQuTXKpbRTxeso/0Zwep0WxQrLw+Fq4Eym5XAeBThEMoozwwW43P4Q2hVQgCu2o6It4sKZrLBaONzp6JS3JSwlLYnHbhEvocCYRwEOkQylFeCqpXNhkkJptXCiS1R03QhJSTN/53sGaIv4ICjAIZKipDOpqAZOcnAj8VaXDx5/UOTWTAzLsnyZBFqiEpaScv7OcddMVgptEY9BAQ6RlMmRdXElnPRLh2wmh81q4HOc5F7wr73Phz5fEFoNnUMlNH6JSgGDqbPt/QCAKTl0r4lFAQ6RFCWVUKclquRRyszf2chIvDDDCpOetogLiQsgG7s8CIYYkVszMdyAcGpOqsgtkRYKcIikRNfF5T0S7/cF+URp2tUgPKUkjJ7roJF4shTYLDDqtQiEWDT3yHur+DmawRkUBThEUibHjMTlXICLy6PISjHCRlvEBaeUfAoul4JG4sLTajV8UTy5L1NxM39T6LqJQwEOkZRJdgsMOvkX4OLWxKmjSg5u5Cr3AIdyKZKLG1BxMyBy1OcN8LPFdN3EowCHSIpep0VR5PwdOR+6SVPGycUFkmfa+mVdJJKbwZmSTYFxMnDXzTlZ32vCbc9ONSHdTLPFsSjAIZIzhV9ukO+o6iwtNSRVuJgi0DsQQJfbL3ZzxsUXDKGpO5x7NjWXAuNk4P4+z8p4BofL25pKg6mLUIBDJCd605HvqIpfoqKOKinMBh0KMywA5Hvd1Hd6wLBAmkmPnFST2M1Rham5CghwKP9mSIIGOF1dXVi9ejXS09Nht9tx6623or9/+AvpyiuvhEajifv693//97jHNDQ04Nprr4XVakVubi7uuusuBIPyLvBFomKXG+SIYVg+F4SWGpJH7qPxs23RZU062iM5uCXkVpcPfd6AyK0Zn2i+Hw2mLiRogLN69WocP34cu3fvxltvvYV9+/bhtttuG/F569atQ0tLC//16KOP8j8LhUK49tpr4ff78dFHH+HFF1/ECy+8gC1btgj5UUgSyX1Udb5nAL4gA6NOy88qEOHxAY5MA2OqZZJ86WYDctPCs2VyzcOJzuBQgHMhwQKcEydOYNeuXXjmmWdQXl6OZcuW4YknnsDOnTvR3Nw87HOtViscDgf/lZ6ezv/snXfewRdffIE//OEPWLBgAa6++mps3boV27dvh98vz7V3Em9a5Abf0utFv09+M3NcYFaSZYVeR6vAySL7GRxKTBeFnK+b2NliCowvJtjdd//+/bDb7Vi8eDH/vcrKSmi1Whw4cGDY57788svIzs7GnDlzsHnzZng80aJv+/fvx9y5c5GXl8d/r6qqCi6XC8ePHx/09Xw+H1wuV9wXkS6b1YDsVG5UJb+bDtUyEQc3RS/XHBzKpRAHlycnxyXx+NliOtrjQnqhXtjpdCI3Nzf+zfR6ZGZmwul0Dvm8f/u3f0NJSQkKCgrw+eef45577kFNTQ3+8pe/8K8bG9wA4P891Otu27YNP/vZzybycUiSTc1JQUe/D2fb+zGv0C52c8aERuLi4JY2G7s98AZCMBvkc9QBy7JUO0kkcp7B4ZY1S7Ks0NEhmxcZ8wzOvffee1ES8IVfJ0+eHHeDbrvtNlRVVWHu3LlYvXo1XnrpJbz22ms4e/bsuF9z8+bN6O3t5b8aGxvH/VokOablyjfRmDoqcWSlGGGzGMCy8juTqqPfjz5vEBpNuLMiySPnXZtUb2t4Y57BufPOO3HLLbcM+5gpU6bA4XCgra0t7vvBYBBdXV1wOByjfr/y8nIAwJkzZzB16lQ4HA4cPHgw7jGtra0AMOTrmkwmmEy07VJOogmjcrzpRJaocinASSaNRoOpOSk40tCDs21ulDnSR36SRHAdVWGGRVYzT0rA/Z3Wd7oRCDEwyChvjgZTwxtzgJOTk4OcnJwRH1dRUYGenh4cPnwYixYtAgC8++67YBiGD1pGo7q6GgCQn5/Pv+5DDz2EtrY2fgls9+7dSE9Px6xZs8b4aYhUyXUnVZ83gDYqmy6aqTmp4QBHZtfNOSorIJr8dDMsBh0GAiE0dnlklQNFeVvDEyxUnTlzJlasWIF169bh4MGD+PDDD7FhwwasWrUKBQUFAIDz58+jrKyMn5E5e/Ystm7disOHD6Ourg5vvPEG1qxZg6985SuYN28eAGD58uWYNWsWvvvd7+Kzzz7D22+/jfvuuw+33347zdIoCLdEVdfpRjDEiNya0eNuODlpVDZdDFNkmk/BbW2nkXjyabUafjAit2Uq2iI+PEHn4l5++WWUlZXhqquuwjXXXINly5bhqaee4n8eCARQU1PD75IyGo34xz/+geXLl6OsrAx33nknbrjhBrz55pv8c3Q6Hd566y3odDpUVFTg5ptvxpo1a/Dggw8K+VFIknGjqkCIRUOXZ+QnSASfYJxNNxwxRHdSySvA4WdwqKMShRwTjd2+IJyRA4mn0szfoATbRQUAmZmZeOWVV4b8eWlpadzBeEVFRXj//fdHfN2SkhL8/e9/T0gbiTRptRpMzU3BsfMunGnrl80ULOXfiItf2mxzg2FYaGWys4SSRcUlxyKRXP2b7FQjbFaaLR6MfLKpiOrIcXcDzeCIqzjTCr1Wg4FAiB/dSp0vGEJj9wAAWqISC1cLR04zONwOU8rbGhoFOESy5HgmFddWmsERh0Gn5bdZy6Wzqu1wI8SwSDPr+WMDSHLFDqZiVxWk7HRbHwBgeh7da4ZCAQ6RrGky20nlDzL8tPEleWkit0a95LbccLo13M5L8tLokE2RTM5OgUYD9A4E0OmWx5E/pyLXzXQaTA2JAhwiWbGJf3IYVdV1uhFkWKSa9CiwmcVujmpFSwzIY2nzdGtkJE4dlWjMBh1/MK5cAmNutpgGU0OjAIdIVmm2FVoN0OcNoj1SW0bKuJH4tNxUGomLSG5Lm/xInDoqUfHXjQxmjL2BEOoj1bqn0RLVkCjAIZJl0utQnBnOp5DDTecUjcQlgT/mQwbXDBDNpbiEOipRySkwPtveD4YF7FYDclIpb2soFOAQSePzcGRw04l2VDQSFxMXYLb3+dAt8XwKXzCEus5wnafpuXTdiIkLMLmZWCnjgrDpNFs8LApwiKRxo6rTMghwoksNNBIXU4pJz+dTcLNqUhW7gyovnUbiYuIGJjUSv2aAmNliGkwNiwIcImncTUfqHZU/yKCOdlBJxgyZXDenaAeVZHDBghxm/k7TDqpRoQCHSNoMR2RU5eyT9E6q2B1U+bSDSnSXOOQxGj9DeVuSkSqjmb/TtINqVCjAIZI2LTcVWg3Q7QmgvV+6O6m4GyLtoJIGLp/ilFPaS5u0g0pa5DBjHLuDigLj4VGAQyTNbNChNCtcRl3KnVW0WBvdcKSA76japD3zd4p2UEmKHPJwzrW7wbCAzWJADlW+HhYFOETy5HDToR1U0jI1Jzzz1+MJSLaGki8YQn1kBxVdN9IwwyH9mT/+iAaaLR4RBThE8vh8CqdL5JYMjZYapMVs0KE0cuCpVANjOoNKeuQw83ea7jWjRgEOkbwyPmFUmqOq2B1UtCYuHZfkRhPUpYh2UEmPHGb+qKDo6FGAQySPG1Wdbu0Dw0hvVMXtoEqjHVSSws38SbVwG7eDivJvpCM250+qM3+0g2r0KMAhkleaZYVRr4XHH8L5ngGxm3MRfgdVHq2JS8kMieduneLPLqOOSkr4nD8JzvzF7aCiwHhEFOAQydPrtJgWqWh8UoI3nVORNl1CHZWkREvvSzOf4mQkp2wGjcQlRcozf6dbw2dQZVgNlLc1ChTgEFngCv5JsT7FFy3hNpXlU0clJaXZKTDoNHBLcObP7Quiviu8g2omXTeSIuWZvxMt4aB4Zn46zRaPAgU4RBakPG0ce9Mh0mHQafmzzKQWGNe09oFlgdw0E7LoNGhJ4baKSzHn70Rk1q/MQfea0aAAh8gCX59CYh1V70CAnx2YSTcdyeECY6ktbXJBcRkFxZJTkiXdmb/oYIpm/UaDAhwiCzMiwcPZ9n4EQozIrYk6GbnhTLJbYLMaRG4NuRC3tHmyRZoBDnVU0hM78yelwJhlWZyIXMc0Wzw6FOAQWSiwmZFq0iMQYnGu3S12c3jUUUnbrIJwR/BFi7SKRHIB1yzqqCSJ++9yQkLXjdPlRe9AADqtBtOoBs6oUIBDZEGj0fCjcSnddLgRHq2JS9PsSEd1rr0fA/6QyK0JYxiWrhuJ4wLj4829IrckirvvTc1JgdmgE7k18kABDpGN2RIcjVOCsbTlpJmQnWoEw0pnV0xT9wD6fUEYdVpMyUkRuzlkEFKc+eOWpygoHj0KcIhszJbYqCrEsHynSUtU0qTRaDCrwAYA+KJZGp0V12lOz0uFQUe3YCnilqgauwbQOxAQuTVhNJgaO/rrIrIxO9JRHW92SaJwW22HG94AA4tBh5IsGolLFddZSSUw5gr8UUclXXarEZPsFgDSWRKnfL+xowCHyMb0vFTotRr0eAJo7vWK3Rz+hjPDkQadlopuSZXUlhv4LeIO6qikjL9uJDDz5w2EUBs50JcC49GjAIfIhkmvw/RIXZPj58UfjdOUsTxwMzgnW/oQkkDhthO0g0oWuP8+UgiMT7X2gWGBzBQjHdEwBhTgEFmJ5uGIf9OhKWN5mJydArNBi4FACHWd4pYYcHkDaIgc0UBF/qRttoRmcI6dD7dhFh3RMCYU4BBZkUqAw7IsjkZuOnMm2URtCxmeTqvhd56I3Vkdi8w8TrJbkJliFLUtZHjcEtXptj74g+IWFz0auW7mFtK9ZiwEDXC6urqwevVqpKenw26349Zbb0V//9AntNbV1UGj0Qz69ec//5l/3GA/37lzp5AfhUgEP20scsKo0+VFR78POq2GlhpkQCp5OEebwtftPOqoJG+S3YJ0c7i46Ok2cUsMHD3fAwCYS4OpMRE0wFm9ejWOHz+O3bt346233sK+fftw2223Dfn4oqIitLS0xH397Gc/Q2pqKq6++uq4xz7//PNxj1u5cqWQH4VIBNdRNfd60e32i9aOzyMd1fTcVCq6JQNSmfmjkbh8hEsMiD/z5wuG+EOGKcAZG8ECnBMnTmDXrl145plnUF5ejmXLluGJJ57Azp070dzcPOhzdDodHA5H3Ndrr72Gb3/720hNjS9Nbbfb4x5nNpuF+ihEQtLMBpRkWQGI21nRSFxe5kRKDBxt6hG1xAAX4MybZBetDWT0uNIUR0Xc1FDj7EMgxCLDakBhhkW0dsiRYAHO/v37YbfbsXjxYv57lZWV0Gq1OHDgwKhe4/Dhw6iursatt9560c9uv/12ZGdnY8mSJXjuueeGvWn5fD64XK64LyJf3Gj8mIjLVJ/zI3G7aG0go1eWnwaDToNuTwBN3eKcEN3rCaC+M5xgPGcSLWvKwfwiOwDgsyYR7zWR954zyUYJxmMkWIDjdDqRm5sb9z29Xo/MzEw4nc5Rvcazzz6LmTNnYunSpXHff/DBB/GnP/0Ju3fvxg033IAf/vCHeOKJJ4Z8nW3btsFms/FfRUVFY/9ARDLmRka/nzf1iPL+LMviaOS959GUsSyY9Do+V6q6sUeUNnCzAMWZVtitlGAsBwsiA5gTzS7REo25xHSaLR67MQc4995775CJwNzXyZMnJ9ywgYEBvPLKK4PO3tx///24/PLLsXDhQtxzzz24++678ctf/nLI19q8eTN6e3v5r8bGxgm3j4hnflH4D726oUeU92/qHkC3JwCDToMy2iIuG/xoXKQA53MuUZQ6KtkoyrQgw2qAP8TwFaiTjZvBmUvLmmOmH+sT7rzzTtxyyy3DPmbKlClwOBxoa2uL+34wGERXVxccDseI7/Pf//3f8Hg8WLNmzYiPLS8vx9atW+Hz+WAyXVwEyWQyDfp9Ik/zCu3QaMKJxm0uL3LTk5t/xY3EZzjSYNJTgrFczCu0A6jnO4xk40fiNOsnGxqNBvMK7Xj/VDs+a+yJXEPJ4w2EcCpy3h0FxmM35gAnJycHOTk5Iz6uoqICPT09OHz4MBYtWgQAePfdd8EwDMrLy0d8/rPPPouvf/3ro3qv6upqZGRkUBCjEqkmPS7JTUNNax+qG3uwfPbIAXMi0YhKnhYURRNGgyEG+iQfdMlfN9RRycr8onCAU93Yi+9WJPe9Tzr7EGRYZKUYUWCjjTRjJdhf+MyZM7FixQqsW7cOBw8exIcffogNGzZg1apVKCgoAACcP38eZWVlOHjwYNxzz5w5g3379uEHP/jBRa/75ptv4plnnsGxY8dw5swZPPnkk3j44Yfxox/9SKiPQiRoQWS5QYx8Cq4mBa2Jy8uU7FSkmvQYCIRwpn3oelxC6HL7+eRmKgwpL1xg/JkIOX/cbDElGI+PoEOYl19+GWVlZbjqqqtwzTXXYNmyZXjqqaf4nwcCAdTU1MDj8cQ977nnnkNhYSGWL19+0WsaDAZs374dFRUVWLBgAX7/+9/jsccewwMPPCDkRyESM1+kAIdhWH6LONWkkBetVsP/N0t2Hg6XED85OwXpZkNS35tMDLcsdba9H33eQFLfm8szpMHU+Ix5iWosMjMz8corrwz589LS0kG3dz/88MN4+OGHB33OihUrsGLFioS1kcgTN4PzeVMvGIaFNkmneZ/r6IfLG4TZoMUMOg1aduYX2bH/XCeqG3tx02XJe98j9d0AgIXF9uS9KUmI7FQTCjMsaOoewNHzvVg6NTtp7/1pQ/i6ubQkI2nvqSR0FhWRpUvyUmEx6NDvC+JsEpcbDkc6qvmFdhiSnMNBJo5bbkh2iYEjkZH4pcXUUclRdAde8hLUu9x+nOsIHw57aRFdN+NBd2giS3qdlk/W/DSJyw1cgLOIRlSyxC03nHT2YcAfSsp7hhiWH4nTdSNP8yP3murG7qS9JzfrNy03FTYrLWuOBwU4RLYWipCHQwGOvOXbzMi3mcNBR5I6qxpnH9z+UHj3Xx4ta8rRwsjM2+H65B31cYRbnqJlzXGjAIfIFpeHw410hNbt9uNse3jKeCEtNciSRqPB4tJMAMAndcm5bg43RPNvdEnKFSOJNXeSDUadFh39PtR1ekZ+QgLQYGriKMAhssV1VDWtfej1CL+7gRvxT8lJQWYKldqXqyWl4Q7jUF1XUt7v03puJE4dlVyZDTq+gvqhWuGvm0CI4bel03UzfhTgENnKSTNhSk4KWBb4pF74mw4/oqIbjqxxgfGR+m4EQ8KfL3SY8m8U4bLIdZOMwPhkSx+8AQbpZj2m5qQK/n5KRQEOkbXyyeGbzsEkjKq4JY3FpdRRydkleWlIM+vh9odw0tkn6Hu19XlR3+mBRgMsoFwKWUtmgHOgthNAeHt4skpgKBEFOETWuJvOAYEDHG8gxO/W4mYAiDzptBosjsymCB0Y7z8b7qhmF6RTgT+Zu7QkAxoNUNfpQVufV9D3+vhc+Lr50pQsQd9H6SjAIbK2JDKDc+x8L9y+oGDvc6S+G/4gg7x0E6Zkpwj2PiQ5+ERjgZc2uY6qgjoq2bNZDChzpAMQNkE9xLD8gI2um4mhAIfIWmGGFZPsFgQZFp9GiqkJ4aPISHzp1Gw6E0YBuMD4wLkuMIxw2365GZyKqdRRKcFlkeXpA5HAVQhfNLvQ5w0izaTH7IJ0wd5HDSjAIbLHdVYfC3jT2U8jcUWZX2iH1ahDp9uPE06XIO/R0juAuk4PdFoNv5RK5G1pJFD94EyHYO+x/1z4tZdMzkz6ifdKQ789Intc0CHUTcftC/KHM9JIXBmMei2f3/ChQNcNN3szZ5INaZR/owgVU7Oh1QBn291o6R0Q5D1o1i9xKMAhsvflS8KH333e1IMejz/hr3+orgtBhkVhhgVFmdaEvz4Rx+XTwtfNP08LE+Bwy5o066ccNouBP+5DiOsmGGJwKJLfQwnGE0cBDpG9fJsFl+SlgmGFmcWhjkqZvjw9HOAcquuCN5DYc6lYlqWRuEItiwTGHwgQ4HzW1IN+XxA2iwGz8in/ZqIowCGK8OXpOQCAf55K/E1nb00bAGBZpEMkyjA9NxW5aSZ4A0zCj/s41dqP8z0DMOm1WEL5N4rC3Qc+PNOR8AT1d0+G7zVfuSSH6t8kAAU4RBG+ckk4wNl3uj2hh+E1dnlwqrUfOq0GV0TegyiDRqPhR+P/TPDM356TrQDCSakWoy6hr03EdWlxhmAJ6u+ebAcA/EsZ3WsSgQIcoghLSjNh1GvR0uvF2fb+hL3ue5HZm0XFGbBb6fwppeHyt9490ZbQ130vMhL/l7LchL4uEV9sgvremvaEva6z14sTLS5oNMAVl9B1kwgU4BBFsBh1/LENexLYWXGv9VXqqBTpqzNyodNqUNPah/pOd0Jes8fj588to+tGmSpn5gEA3jnuTNhrcoOpBUV2Osw3QSjAIYqxfFb4pvO/xxJz0/H4g3z9m6tmUkelRHarkQ+Md3/RmpDXfP9UOxgWuCQvFYUZtOtOiSpn5UKjAT5r6k3YdnEu/+ZfZtC9JlEowCGKUTXbAY0GqG7sSchN58MznfAHGUyyWzA9l070VSouMH7neGICHJr1U77cNDMuLQ5XNf5HAgLjAX+Ir8dE103iUIBDFCM33YxFkZvOrgTM4vz9aAsA4F9n5dHxDAr2r7MdAMLnUnX2+yb0Wt5ACHtOhDu85bMcE24bkS4uMH47AYHx3po2ePwhFGZY6HiGBKIAhyjKijnhTmWiAY43EOKXLK6bnz/hdhHpmmS3YM6kdDAs8I8TE+us9ta0we0PYZLdgkuL7YlpIJGk5ZHA+ONznej1BCb0Wm99Hh5MXTsvnwZTCUQBDlGUqshN51BdF9r6vON+nb01bej3BTHJbsHCooxENY9I1NVzwkHs6582T+h13qSOSjUmZ6egzJGGIMPib5HZ3vHw+IN8WYGvzS1IVPMIKMAhClOUacWCIjsYFvjrBDqrNz4LP/faeflUcEsFVi6cBCB8qGpTt2dcr9HnDfDbza+bRx2VGnwjct38z5Gmcb/GO8db4Q0wKM60Ys4kWp5KJApwiOLcuLgQAPCnTxrHVfSvo9/HL09dv4A6KjWYZLfwR3H8tXp8gfEbnzVjIBDC1JwU6qhU4hsLJ0GrAQ7Xd6O2Y3xlBv54sAEAcMOlhTTrl2AU4BDFuW5+AcwGLU639aM6cgr4WPzP4SYEQizmF9owu8CW+AYSSfrGpeHR+KuHGhEaRwn+nQcbAQDfWVJMHZVK5Kab+WNi/vxJ45iff669Hwdqu6DVRAdmJHEowCGKk2424JpITsX/+7h+TM9lGJYfUf1beXHC20ak62vz8pFu1qOhy8NXIh6tY+d7cfR8L4w6Lb55KXVUavKdJUUAwjMxYz209dVD4aDoiktyUGC3JLxtakcBDlGkNUtLAQBvVDePqSbOnpNtqOv0IM2kx9coj0JVrEY9Vi0JB7UvfFQ3puc+te8cAODquQ6qQqsy/zrLgcIMC7o9Abz+6flRP8/lDeCVA+HB1OryEqGap2oU4BBFWlBkx5LJmQgyLJ7/sG5Uz2FZFr977wwA4OaKEqSY9AK2kEjRmooSaDXAB2c6cOx876ieU9fhxlufh/N2/s9XpgrZPCJBOq0Gt0QGVE//89yolzf/8HE9+nxBTM9NpTPLBEIBDlGs//OVKQDCN5JW18hbxved7sBnjT0wG7S4ddlkoZtHJKgww4qvzw/P3D2y6+SonvO7986AYcMHa86iIm2q9O3LimCzGHC23T2qHVV93gCe+6AWALD+yqm0U1MgFOAQxfqXslxcWmyHxx/CL9+uGfax/iCDrW99AQD4tyUlyE41JaOJRII2/esM6LUa/PN0Bz443THsY6sbe/gObcO/TEtG84gEpZsN2PDV8H//3+w+hQH/8Lk4T7x7Bh39fpRmWXHdfFoKF4pgAc5DDz2EpUuXwmq1wm63j+o5LMtiy5YtyM/Ph8ViQWVlJU6fPh33mK6uLqxevRrp6emw2+249dZb0d/fL8AnIHKn0Whw/9dmAQD++3ATDtZ2DfnYZz44hzNt/chKMeInV01PVhOJBBVnWbE6kmC++bXP0e8LDvq4YIjBlr8eA8sC37x0En82EVGn71aUYJLdgpZe77CzfzXOPn725oHrZsOgo3kGoQj2m/X7/bjxxhuxfv36UT/n0UcfxeOPP44dO3bgwIEDSElJQVVVFbze6PLC6tWrcfz4cezevRtvvfUW9u3bh9tuu02Ij0AUYGFxBr4d2X65ceenaO+7+Kyhw/Vd+M3uUwCAzdfMhM1qSGobifTcWTUDk+wWNHYNYMvrxwatp/To2zX4vKkXaSY97r26TIRWEikxG3R4+JtzAYST1PcMcuxHnzeA2185giDDonJmLh2sKTDBApyf/exnuOOOOzB37txRPZ5lWfz2t7/Ffffdh+uvvx7z5s3DSy+9hObmZrz++usAgBMnTmDXrl145plnUF5ejmXLluGJJ57Azp070dw8sRLrRLm2XDcbpVlWNPd68d1nD+B8T3RX1cHaLqx9/hACIRbXzHXghkgtFKJu6WYDfv3t+dBqgL98eh4PvHEcgRADAAgxLB7bfYrfOfXIt+YhN80sZnOJRFxxSQ6fcHz7K0fw7slokNPe58Mtzx/CmbZ+5KaZ8Isb5onUSvWQzDaR2tpaOJ1OVFZW8t+z2WwoLy/H/v37sWrVKuzfvx92ux2LFy/mH1NZWQmtVosDBw7gG9/4xqCv7fP54PNFR+4ul0u4D0IkJ9Wkxwtrl+BbO/bjpLMPlb9+H5Wz8tDvDWDvqXawLLCoJAOPfms+FWgjvC9NycJD35iLzX85ipf212PfqXZcVpqJz5t6UdPaBwC4q2oGrplLh7GSqP+4ZibqO914r6Yd33/hE1w+LQu5aWbsOdEKlzeIdLMez91yGeX5JYFkFv+czvDpz3l5eXHfz8vL43/mdDqRmxs/pafX65GZmck/ZjDbtm2DzWbjv4qKihLceiJ1pdkpeO2HS3FpsR0DgRDe/KwZ79W08/kTL35/CVJpWzi5wHeWFGPHzZciw2pAXacHfz7chJrWPqQYdXj0hnm4/auUWEziGfVa7PjuItyytBQaDfDhmU689ul5uLxBlDnS8N/rl2LOJKqQngxjuqPfe++9eOSRR4Z9zIkTJ1BWJq316M2bN2PTpk38v10uFwU5KlSUacX/rF+K/ec68WlDD0x6La64JAfT89LEbhqRsBVz8vHl6TnY/UUr6js9yLebUTXbAZuFcrXI4Ex6HX769dm4ZWkp9ta0we0PYXZBOr48PQc62hKeNGMKcO68807ccsstwz5mypQp42qIw+EAALS2tiI/Pzrl29raigULFvCPaWuLL6EeDAbR1dXFP38wJpMJJhNNB5LwzqqlU7OxdGq22E0hMpJi0vMnjhMyWqXZKbglm2pqiWVMAU5OTg5ycnIEacjkyZPhcDiwZ88ePqBxuVw4cOAAvxOroqICPT09OHz4MBYtWgQAePfdd8EwDMrLywVpFyGEEELkR7AcnIaGBlRXV6OhoQGhUAjV1dWorq6Oq1lTVlaG1157DUB4ZL1x40b8/Oc/xxtvvIGjR49izZo1KCgowMqVKwEAM2fOxIoVK7Bu3TocPHgQH374ITZs2IBVq1ahoICKJRFCCCEkTLCsyi1btuDFF1/k/71w4UIAwHvvvYcrr7wSAFBTU4Pe3uh5L3fffTfcbjduu+029PT0YNmyZdi1axfM5ugWzJdffhkbNmzAVVddBa1WixtuuAGPP/64UB+DEEIIITKkYQerYKVwLpcLNpsNvb29SE+ns2MIIYQQORhL/y2ZbeKEEEIIIYlCAQ4hhBBCFIcCHEIIIYQoDgU4hBBCCFEcCnAIIYQQojgU4BBCCCFEcSjAIYQQQojiUIBDCCGEEMWhAIcQQgghiiPYUQ1SxhVvdrlcIreEEEIIIaPF9dujOYRBlQFOX18fAKCoqEjklhBCCCFkrPr6+mCz2YZ9jCrPomIYBs3NzUhLS4NGo0noa7tcLhQVFaGxsZHOuRIQ/Z6Tg37PyUG/5+Sg33PyCPW7ZlkWfX19KCgogFY7fJaNKmdwtFotCgsLBX2P9PR0+gNKAvo9Jwf9npODfs/JQb/n5BHidz3SzA2HkowJIYQQojgU4BBCCCFEcSjASTCTyYQHHngAJpNJ7KYoGv2ek4N+z8lBv+fkoN9z8kjhd63KJGNCCCGEKBvN4BBCCCFEcSjAIYQQQojiUIBDCCGEEMWhAIcQQgghikMBTgJt374dpaWlMJvNKC8vx8GDB8VukqJs27YNl112GdLS0pCbm4uVK1eipqZG7GYp3i9+8QtoNBps3LhR7KYo0vnz53HzzTcjKysLFosFc+fOxSeffCJ2sxQlFArh/vvvx+TJk2GxWDB16lRs3bp1VOcZkaHt27cP1113HQoKCqDRaPD666/H/ZxlWWzZsgX5+fmwWCyorKzE6dOnk9Y+CnAS5NVXX8WmTZvwwAMP4MiRI5g/fz6qqqrQ1tYmdtMU4/3338ftt9+Ojz/+GLt370YgEMDy5cvhdrvFbppiHTp0CL///e8xb948sZuiSN3d3bj88sthMBjwv//7v/jiiy/w61//GhkZGWI3TVEeeeQRPPnkk/jd736HEydO4JFHHsGjjz6KJ554QuymyZrb7cb8+fOxffv2QX/+6KOP4vHHH8eOHTtw4MABpKSkoKqqCl6vNzkNZElCLFmyhL399tv5f4dCIbagoIDdtm2biK1Stra2NhYA+/7774vdFEXq6+tjp0+fzu7evZu94oor2J/85CdiN0lx7rnnHnbZsmViN0Pxrr32Wvb73/9+3Pe++c1vsqtXrxapRcoDgH3ttdf4fzMMwzocDvaXv/wl/72enh7WZDKxf/zjH5PSJprBSQC/34/Dhw+jsrKS/55Wq0VlZSX2798vYsuUrbe3FwCQmZkpckuU6fbbb8e1114bd12TxHrjjTewePFi3HjjjcjNzcXChQvx9NNPi90sxVm6dCn27NmDU6dOAQA+++wzfPDBB7j66qtFbply1dbWwul0xt0/bDYbysvLk9YvqvKwzUTr6OhAKBRCXl5e3Pfz8vJw8uRJkVqlbAzDYOPGjbj88ssxZ84csZujODt37sSRI0dw6NAhsZuiaOfOncOTTz6JTZs24T/+4z9w6NAh/PjHP4bRaMT3vvc9sZunGPfeey9cLhfKysqg0+kQCoXw0EMPYfXq1WI3TbGcTicADNovcj8TGgU4RJZuv/12HDt2DB988IHYTVGcxsZG/OQnP8Hu3bthNpvFbo6iMQyDxYsX4+GHHwYALFy4EMeOHcOOHTsowEmgP/3pT3j55ZfxyiuvYPbs2aiursbGjRtRUFBAv2cFoyWqBMjOzoZOp0Nra2vc91tbW+FwOERqlXJt2LABb731Ft577z0UFhaK3RzFOXz4MNra2nDppZdCr9dDr9fj/fffx+OPPw69Xo9QKCR2ExUjPz8fs2bNivvezJkz0dDQIFKLlOmuu+7Cvffei1WrVmHu3Ln47ne/izvuuAPbtm0Tu2mKxfV9YvaLFOAkgNFoxKJFi7Bnzx7+ewzDYM+ePaioqBCxZcrCsiw2bNiA1157De+++y4mT54sdpMU6aqrrsLRo0dRXV3Nfy1evBirV69GdXU1dDqd2E1UjMsvv/yiUgenTp1CSUmJSC1SJo/HA602vrvT6XRgGEakFinf5MmT4XA44vpFl8uFAwcOJK1fpCWqBNm0aRO+973vYfHixViyZAl++9vfwu12Y+3atWI3TTFuv/12vPLKK/jrX/+KtLQ0fh3XZrPBYrGI3DrlSEtLuyivKSUlBVlZWZTvlGB33HEHli5diocffhjf/va3cfDgQTz11FN46qmnxG6aolx33XV46KGHUFxcjNmzZ+PTTz/FY489hu9///tiN03W+vv7cebMGf7ftbW1qK6uRmZmJoqLi7Fx40b8/Oc/x/Tp0zF58mTcf//9KCgowMqVK5PTwKTs1VKJJ554gi0uLmaNRiO7ZMkS9uOPPxa7SYoCYNCv559/XuymKR5tExfOm2++yc6ZM4c1mUxsWVkZ+9RTT4ndJMVxuVzsT37yE7a4uJg1m83slClT2P/8z/9kfT6f2E2Ttffee2/Qe/L3vvc9lmXDW8Xvv/9+Ni8vjzWZTOxVV13F1tTUJK19GpalUo6EEEIIURbKwSGEEEKI4lCAQwghhBDFoQCHEEIIIYpDAQ4hhBBCFIcCHEIIIYQoDgU4hBBCCFEcCnAIIYQQojgU4BBCCCFEcSjAIYQQQojiUIBDCCGEEMWhAIcQQgghikMBDiGEEEIU5/8D+U62NWhCeF0AAAAASUVORK5CYII=\n", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "\n", - "x_np = np.linspace(0, 10, 1000)\n", - "y_np = 2 * np.sin(x_np) * np.cos(x_np)\n", - "plt.plot(x_np, y_np);" + "import jax.numpy as jnp" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "With this import, you can immediately use JAX in a similar manner to typical NumPy programs, including using NumPy-style array creation functions, Python functions and operators, and array attributes and methods:" ] }, { "cell_type": "code", - "execution_count": 2, - "metadata": { - "id": "18XbGpRLuZlr", - "outputId": "3d073b3c-913f-410b-ee33-b3a0eb878436" - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:jax._src.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" - ] - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjgAAAGdCAYAAAAfTAk2AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAABz9UlEQVR4nO3deXyU1b0/8M/sS5bJnknIxiZhB0FSkFZ7zSWotdJaK71YKrX4u1TaKr5cuFex1SrVtrZXy0/qrr9qtb23WrW9KEXRqggIRgEhbNkgmezJJDOZ9Xl+f8ycZyaQPfPMs33fr9e8WibPzJyJT875nnO+5xwdz/M8CCGEEEJURC91AQghhBBCEo0CHEIIIYSoDgU4hBBCCFEdCnAIIYQQojoU4BBCCCFEdSjAIYQQQojqUIBDCCGEENWhAIcQQgghqmOUugBS4DgOTU1NSEtLg06nk7o4hBBCCBkFnufR29uLwsJC6PXDj9FoMsBpampCcXGx1MUghBBCyDg0NjaiqKho2Gs0GeCkpaUBiPyC0tPTJS4NIYQQQkbD7XajuLhYaMeHo8kAh01LpaenU4BDCCGEKMxo0ksoyZgQQgghqkMBDiGEEEJUhwIcQgghhKgOBTiEEEIIUR0KcAghhBCiOhTgEEIIIUR1KMAhhBBCiOpQgEMIIYQQ1aEAhxBCCCGqI2qA8/777+Oqq65CYWEhdDodXnvttRFfs3v3blx44YWwWCyYNm0annvuufOu2bZtG8rKymC1WlFRUYF9+/YlvvCEEEIIUSxRAxyPx4P58+dj27Zto7q+trYWV155Jb761a+iuroat9xyC37wgx/grbfeEq555ZVXsGnTJtx77704ePAg5s+fj6qqKrS2tor1NQghhBCiMDqe5/mkfJBOh1dffRWrVq0a8po777wTf/vb33D48GHhudWrV6O7uxs7duwAAFRUVOCiiy7C7373OwAAx3EoLi7Gj370I9x1112jKovb7YbD4UBPTw+dRUUIIYQoxFjab1kdtrlnzx5UVlYOeK6qqgq33HILACAQCODAgQPYvHmz8HO9Xo/Kykrs2bNnyPf1+/3w+/3Cv91ud2ILLjM8z2N/XRc+PNkOnuexqCwLX56WA71+5MPJiHb5gmHsOOzCMVcvMuwmrJiVjym5qVIXi8jc2e5+/O+hZrT1+TE1NxVXzC1AqkVWTQvRKFndhS6XC/n5+QOey8/Ph9vtRn9/P7q6uhAOhwe95tixY0O+79atW/Gzn/1MlDLLTZcngFv/VI3dNW0Dnl9YkoFHVy9EcZZdopIROdt7ugM/ebkaLrdPeO6hHcfw/Ysn467Ly2Ey0HoEMhDH8XjsnZN47J0TCHGxiYCtfz+Kh781H/86K3+YVxMiPk3UWps3b0ZPT4/waGxslLpIouj0BHDN9o+wu6YNZoMeqxYU4luLipBqMeLThm58a/tHqGv3SF1MIjO7a1qx5qm9cLl9KHRY8d0vleKSC3LB88DTH9TiRy99ilCYk7qYREZ4nsd/vnYIv/nHcYQ4HksmZ+F7S0sxOScFXd4gbvp/n+C1T89KXUyicbIawXE6nWhpaRnwXEtLC9LT02Gz2WAwGGAwGAa9xul0Dvm+FosFFotFlDLLBcfx2PCHAzjd5kGhw4pn1y3BDGcaAODWf70A657dh+MtfVj/wid47eaLkUJDyATAydY+bPjDQYQ4HitnO/Gb6xbAZjYAAHYcduHHf/wUO4648Mu3a7D58pkSl5bIxbMf1uGP+xqh1wG/+OY8fPuiYgBAIMThntcO45VPGnH7f3+G0mw7FpZkSlxaolWyGsFZunQpdu3aNeC5nTt3YunSpQAAs9mMRYsWDbiG4zjs2rVLuEarnvmwFntrO2E3G/DCjbHgBgAmZdjwhxsrkJdmwYnWPvzq7RoJS0rkIszxuOWVT9EfDGPplGw8+p2FQnADACvnOPHIdfMBAL9/7zT213VKVVQiI8dbevGLHZGUgHu+NksIbgDAbNRj6zfn4vI5TgTDPG55pRq+YFiqohKNEzXA6evrQ3V1NaqrqwFEloFXV1ejoaEBQGTqaO3atcL1//7v/47Tp0/jjjvuwLFjx/B//+//xZ/+9CfceuutwjWbNm3Ck08+ieeffx5Hjx7Fhg0b4PF4sG7dOjG/iqyd6fLil29Fgpa7r5yFaXlp512Tl27Fr78daaye/6gOh8/2JLWMRH5e3t+Aw2fdcNhM+K/VC2A2nl8dfG1eIa5bHGnA7nntME1VaRzP8/jPVw8hEOLw1Rm5uGFZ2XnX6PU6PPStechPt6C+w4sn3z+d/IISApEDnE8++QQLFy7EwoULAUSCk4ULF2LLli0AgObmZiHYAYDJkyfjb3/7G3bu3In58+fj17/+NZ566ilUVVUJ11x33XX41a9+hS1btmDBggWorq7Gjh07zks81pL/+scJ+EMcvjQlC99ZUjzkdV+enouvzSsAxwP3vflFEktI5KanP4hfv30cAHBr5XTkpVuHvPbOy8uRYTfhmKsXf9yvzvw1Mjq7jrZif10XrCY9tn5zHnS6wVdmpltN+M8rZwEAtu0+ieae/mQWkxAASdwHR07UtA/OydY+rPjNe+B44NUfLhtxvru5px+XPLwbgTCHV276EiqmZCeppEROHtt1Ar/eeRxTc1Ow45avjLhK6vmP6nDv60cwKcOG3bdfSquqNIjjeKz8r/dxvKUPGy6dijtXlg97Pc/zuHb7HnxS34XvXzwZW66alaSSEjUbS/tNtZTCbX/vFDgeqJyZP6pkvgKHDd9aXAQA+N27J8UuHpEhXzCM5z6qAwD8+LLpowpWrruoGLlpFpzt7sertDpGk96tacXxlj6kWY3490umjni9TqfDjy6bDgD4474GdHoCYheRkAEowFGwjj4/Xv+sCQCw4dKRKxxmwyVTYdDr8M8T7TjmUvemh+R8/3PwDDo8AUzKsOGKuQWjeo3VZMD6L08GADzx/mlocOBX8579sA4A8J0lJXDYTKN6zVem52DOpHT0B8P4f3vqRSwdIeejAEfBXt7fiECIw7wiBy4syRj164qz7KiaHclZemlvwwhXEzXheV5oaL6/fPKYppq+s6QENpMBJ1v7sL+uS6wiEhmqcfXig5Pt0OuAtUtLR/06nU6H9V+eAgB4ZX8DwhwFxiR5KMBRKI7jheDke0vLhkz2G8p3lpQAAF49eBbeQCjh5SPy9PmZHhxz9cJi1ONbi4rG9No0qwlfn18IAHh5HwXGWvLy/sh/7xWznCjKHNtu6FWznciwm9DU48P7x9tGfgEhCUIBjkLtre3E2e5+pFmMuHLe6KYZ4l08NQclWXb0+kN48/NmEUpI5OhPn0RWQa2c4xz1NEO871REAuO/HWpGjzeY0LIReQqGObxeHZkK//ZFYwuKgcj05jcXRl73RwqMSRJRgKNQbBv0K+YWwGoyjHD1+fR6Hb4dTTZ+I5rHQ9TNFwwLOVtsb5uxml/kQLkzDf4Qhx1HKDDWgvePt6HDE0BOqhlfmZ47rvdYHd2+4p1jrej2UrIxSQ4KcBTIFwzj74cijcs3Lpw07ve5Kjrd8OHJdrT3+Ue4mijdByfa0esLocBhxZfGuT2ATqcT7hsa+dOGvxyMdKauXjAJxnFuD3BBfhrKnWkIcTzeOuJKZPGIDPmCYVmkPlCAo0DvHW9Drz+EQocVS8qyxv0+pdkpmF/kAMcD/3uYKh21Y/+Nq2Y7odePLWcr3teiU6IUGKufLxjGuzWtAICrFxRO6L0oMNaOHYddWHjfTvzsjSOSloMCHAXa+UXksNGqORNrqIDIVvwA8CZNU6laIMRh5xeRAOfyOUMfTDsapdkpmDuJAmMt+PBkO7yBMAocVsyd5JjQe10Z3ZLgo1Md6KDAWNV2ftECf4iD3Tz29IlEogBHYcIcj3eORXpU/zpr4sdTrIw2dp/Ud9HcuIrtOd0Bty+EnFQzFk9g1I9h++ewYJuoE5tOWjErf8wrNc9VlpOC2YXpCHM8dkXrMKI+/lAYu2tYGzWxztREUYCjMAfqu9DpCcBhM01oeoopzrJjRn4awhyP92gJp2rtOByZFqia7YRhgqN+AFA5Mw8A8PHpDlnMtZPEC3M8/nE00lCtmJ2YhuqymZFO2bsU4KjWnlMd8ATCyEuzYN4ER/0migIchWHTDP9SnjfuhL9zfbU80li9Q5WOKnEcj51fRP7brpzg9BQzLS8VRZk2BEIcPjzZkZD3JPLCOlPpViOWTJ54ZwqI1FsA8M8T7QjSyfSqxEZ1K2flTziFYqIowFEQnueFmycR01PMZdHe+O6aNoSo0lGdoy432vv8sJkMCWuodDqd0FhRYKxOrDN12cz8hB2uOm+SAzmpZvT5Q9hf15mQ9yTywfNxKRQzE9dGjRcFOArS0OlFXYcXJoMOX7lgfPtRDGZhcQYy7Cb09AdxsKE7Ye9L5OGfJ9oBAEunZsNiTFzSHxv5e/dYK51NpULsvrl0RuLqGr1eh0suiN03RF1OtfWhuccHs1E/7q0oEokCHAX54GSkwrmwJBOpFmPC3tdo0OOSaMC06xgljaoN2x7/K9NzEvq+S6dkw2rSw+X24YtmOrRVTdp6/Tjm6gUAXDwtsffNV8sjdQ2N/KkPC4ovKsuETeIVVAAFOIryYTTASXSFA8R6aR9RPoWqeAMhfBI9GDORo35AZAv+i6dG7sUPohUbUQdW18wqSEdOqiWh7/3l6bkw6HU41ebBmS5vQt+bSIvVA8unJbauGS8KcBQizPH46FQk+BAjwFkWbagON/XQGUMq8vHpDgTCHCZl2DA5JyXh7790amQYmt2bRB3YaPGXEzzqBwAOmwnziiKraz4+TXk4ahEMc/j4dKQeEOO+GQ8KcBTiiyY3ur1BpFqMmF+U+KV3+elWTMlNAc8De2upsVKL949HGqqvXJA74X1MBsMCnP11nbQqRiV4no/1xEVqqJZG8zP2UGCsGp82dMMTCCMrxYxZBelSFwcABTiKwXpUX5qSnbDl4edaRr1x1WE9quUijPoBwExnOjLsJngDYXx+pkeUzyDJdaqtDy53JFH0ogTstTUYFhh/fLqDEtRV4oMTkVy/ZVOzJV8ezlCAoxBsTnz5NPEy09k0FWsUibJ1ewOoaYkkiiZqefi59HodvjQ51lgR5WMdnIvKMmE1iZMourg0CyaDDme7+9HY2S/KZ5DkYtONYnWmxoMCHAUIhjl8Uh+5eZZOFe/mYcv6jrl66RBFFfikrgs8D0zJTUFuWmITReMtm0bTDWqyP5qUXjFZvM6UzWzAguIMAMCe05SgrnT+UBjVZ7oBiNeZGg8KcBTgSJMbviAHh82E6Xmpon1OVooZ5c40ANQbV4N90Y3UKkSucFg+xf66TvhDYVE/i4iL53nsr43cN4vLMkX9LMrDUY9DZ3oQCHHISTWLsphhvCjAUYBPog3V4tJM0ec22SgOW1pMlGtvtKESu0c1LS8VOalm+EMcDlEejqKd6eqHy+2DyaDDwmJxAxxW19BKKuXbJ7RRWaIsZhgvCnAUgG1pnohToEeyqDRSqR2opwBHyfr8IRw+Gwk2log41QBEjm1g983BBrpvlIzVNXMmOUTfqG1hSSYMeh1cbh+auikPR8nYqN9FMpqeAijAkT2e54XRlItEHjIGYgHOF81uOiVawQ7WdyHM8SjKtGFShk30z6PAWB32C3WN+A2VzWwQlhPTfaNcYY7HJ9H/fkuScN+MBQU4Mlfb7kGHJwCzUY+5Iux/c67CDBsKHFaEOR6fNdJ0g1Kxac1kVTixAKeblv0qGBvBSUaAA1BgrAY1rl70+kJIMRswsyBN6uIMQAGOzLHRm/lFjoQelDicC4VKh+bGlerTxm4AwMKSjKR83uxCB8wGPdr7/LTsV6E6PQGcbO0DEMn3S4YLaWpT8dgK3wtLM0Xbo2285FUach528yQj/4ZZTL0qReN5Hp9FA5wFIieKMlaTAbMnRacbGigwVqLqxsjf+9TcFGSmmJPymRdGA/AvmtzoD9AKPCWqbugGEDkEWm4owJE5Nk2UzJsnljDaDY6j6QalqW33wO0LwWLUozyJQ8aLovcorcBTJlbXJCsoBoBJGTbkp1sQ4nh8Ht1HhSjLZ9H/bguSNFo8FhTgyJjHH8KJ1shOtGKcPzWUmQXpsJr06OkP4lRbX9I+lyRGdXT0Zs4kB0xJHDKmfAplYw3V/OLk1TXxK/AO0DSV4rh9QZxq8wAA5hdlSFuYQVCAI2OHz/aA4wFnuhV56dakfa7JoMe8SRkAgM9oXxPFqRampzKS+rksn+J4Sy88flqBpyTx05rJbqjY6PTB+u6kfi6ZOLbvVUmWHVlJmtYci6QEONu2bUNZWRmsVisqKiqwb9++Ia+99NJLodPpzntceeWVwjU33HDDeT9fuXJlMr5KUrHDC+clcfSGYSu2DtGwseKwAGd+kgOc/HQr8tMt4PjINgNEORo7+9HlDcJsSO60JhC7Tw+d7U7q55KJY3WNFG3UaIge4LzyyivYtGkT7r33Xhw8eBDz589HVVUVWltbB73+L3/5C5qbm4XH4cOHYTAYcO211w64buXKlQOu++Mf/yj2V0m6amHIOCPpn81u2M/P0giOkviCYRyNBhcLJbhv5kZH/uhkcWVh01MzC9KStlqTmVWQDr0OaHH70er2JfWzycR8JtFo8WiJHuA88sgjWL9+PdatW4dZs2Zh+/btsNvteOaZZwa9PisrC06nU3js3LkTdrv9vADHYrEMuC4zU34Z3BPFku6kmNucOykS4HzR5EYwzCX988n4HGlyIxjmkZ1iRlGm+Bv8nYsFxocpMFaUzyQa9QOAFIsRU3MjZ+wdovtGUT6TsBM+GqIGOIFAAAcOHEBlZWXsA/V6VFZWYs+ePaN6j6effhqrV69GSsrAA7x2796NvLw8zJgxAxs2bEBHx9AHtvn9frjd7gEPueuI208kGRv8nassOwVpFiP8IQ4nWijRWCnie1RSnAnD7lVaEaMsn0nYmQLi7xsKcJTC1eNDi9sPg16H2YXpUhdnUKIGOO3t7QiHw8jPzx/wfH5+Plwu14iv37dvHw4fPowf/OAHA55fuXIlXnjhBezatQsPPfQQ3nvvPVx++eUIhwffR2Hr1q1wOBzCo7i4ePxfKknY1NCUnBQ4bKakf75er8Oc6CgOzY0rh1QJxgwb+Tvd7kGvLyhJGcjYhMKcMHIiVU983iQa+VMaVtdckJ8Gu9kobWGGIOtVVE8//TTmzp2LJUuWDHh+9erV+PrXv465c+di1apVePPNN7F//37s3r170PfZvHkzenp6hEdjY2MSSj8xnzdKl2DMzKNeleKwBkKKUT8AyEm1oNBhBc9HpsuI/J1o7YMvyCHNYsSUnJSRXyCCuXE5f3TUhzII+98kcVuBsRI1wMnJyYHBYEBLS8uA51taWuB0Ood9rcfjwcsvv4wbb7xxxM+ZMmUKcnJycPLkyUF/brFYkJ6ePuAhd+zmmSfh3gLCSirqVSmCxx9CbUdkT4rZhdJVOrEVeHTfKAGb1pxb5IBen/xpTQCYVeCAXge09frR4vZLUgYyNp/LoI0aiagBjtlsxqJFi7Br1y7hOY7jsGvXLixdunTY1/75z3+G3+/H9ddfP+LnnDlzBh0dHSgoKJhwmeWCjZokc9Otc7G9cI42u+EP0Tbqcne02Q2eB/LTLchNs0hWDlbh0Qo8ZWAjbWxKWgo2swEX5EeWp1OHSv54nscX7L6RsDM1EtGnqDZt2oQnn3wSzz//PI4ePYoNGzbA4/Fg3bp1AIC1a9di8+bN573u6aefxqpVq5CdnT3g+b6+Ptx+++34+OOPUVdXh127duHqq6/GtGnTUFVVJfbXSYrWXh/a+/zQ6SK7CkulOMuGDLsJwTCP4y5KNJY71lBJOXoDxPJwaA8lZWB7FkmdKEr3jXI09/jQ5Q3CqNdhen6q1MUZkuiZQddddx3a2tqwZcsWuFwuLFiwADt27BASjxsaGqDXD4yzampq8MEHH+Dtt98+7/0MBgM+//xzPP/88+ju7kZhYSFWrFiB+++/HxaLdL3WRGKR8eScFEmTt3Q6HeZOcuCfJ9rx2ZluyfI6yOgcaYr0fOXSUNV1eNHTH5QkSZ6MTpjjhX2TJL9vihz484EzNPKnAKyNmpaXCqspufsmjUVSWs+NGzdi48aNg/5ssMTgGTNmDJloZrPZ8NZbbyWyeLJztDly/tQsCUdvmDnRAId2ppU/ufTEM6N78Jzp6seRph4sm5ojaXnI0Oo7PPAGwrCa9JicI21PfG7cSiqe5yXZ5oCMDhstlkMbNRxZr6LSKtZQzZLB3gJsiuwoBTiyFgxzwjSi1FNUQKziY8E6kSdW18xwpsMgUYIxU+6M7Gjc3hdAWy8lGsvZF82RUTY5tFHDoQBHhr6ITjXIITqeFT2XpsbVC46j5ZtydaKlD4Ewh3SrUZIdjM9FgbEyfCGjnrjNbEBZdJn6URcFxnImp074cCjAkRlvIITT7ZGlvnK4ecqyU2Ax6uENhFHf6ZW6OGQILP9mVmG6LIb2Z0YD42MuCnDkTC7TmgwFxvLX0x8UdtmXQ2A8HApwZKbG1Quej2yYlpdmlbo4MBr0mOGMNFZU6ciXXFZQMayhOt7ShxCdZSZbQi6FTAKcWRTgyB77bzMpw4YMu1ni0gyPAhyZkePQ30wnVTpy90WTvHrixZl2pJgNCIQ4YUSSyEtrrw9tvZHtKMqjnRipsXIco9wt2fpCZkHxcCjAkRk5zYkzbLqBAhx54nk+bqpBHiM4er0O5dQblzWWAC71dhTx2MjfqbY+2lxUpo7IrDM1HApwZOaoHEdwaEWMrJ3t7kefPwSzQY8pudKcJTSYWGBM940cybEzVeCwwmEzIcTxONlKm4vKkTDLIKP7ZigU4MhImONxzMX2wJHHkDEQWb4JRBrSnn46IVpuaqL3zJTcFJgM8vmTLqepTVmT26gfENlctNxJgbFcBUIcTrZG/rtIucv+aMmnNiRo6PTKZtOteA67CZMyIkuPj1FjJTs1LZEKZ4ZM8igYWhEjbzXRFW5yyb9h6L6Rr7oOD4JhHqkWeWxHMRIKcGSE9cSn56VJvunWuSgPR77YfSO3AKfcmQadDmjt9aOjjzZuk5NAiMPptkjyt9zuGzb1QVsMyA+ray7IT5XFdhQjoQBHRo63sJtHXhUOQHk4ciYEODK7b1IsRpRm2QFAmHol8lDb7kGI45FmMaLAIf12FPHK43K3hjqyh0hDzm3UYCjAkZGallh0LDczqVclS8FwrCcux0qHphvkiTVU02XYE78gPw16HdDpCaCVjmyQldgIjvzqmsFQgCMjJ1iAI7MhYyAWdJ1o7aNelYzUtXsQCHNIMRtkOSfOKkLWoBJ5OC7TvC0AsJoMmBw9sqGGRv5kRc73zWAowJGJ+DlxOUbHpdkpMBl08AbCONvdL3VxSFRNXFAst544EB/g0JJfOZF7T5wCY/nxBWPH9cj1vjkXBTgyUdcRmRNPtRhRKLM5cQAwGfSYEl3ZdYIaK9k4LtP8G4aN/J2kkT9ZkXsuxfS82H1D5CHyNwxkpZiRkyrvIxoYCnBkQlhBJcM5cWZ6tLGiXpV8HJN5T7w0OwVGvQ59/hCae3xSF4dAGT3x6TSCIzuxVb7ybaPORQGOTLD8G7n2xIHI8nWAphvkhDUActvLhDEb9UI+BTVW8qCEnjgLvE600MifXCgt/wagAEc2amQ+ZAzEJxpTQyUH/YG4nriMK534xopITwl7mZTl2GHQ69DrD6HFTSup5EAJbdS5KMCRCVb5y/nmmR7XUHEc9aqkdqK1FzwPZKeYkZNqkbo4Q6KpTXmRe/4NAFiMBpRlR/ZQovtGHo7LdEPR4VCAIwO+YBh1HdEVVE757YHDlGXbYTbo0R+klVRyINcdjM/FpjZPUMKoLCghwAFoJZWcuH1BNEVz6C7Ik/d9E48CHBk41dYHjgcy7Cbkyrgnbow7rZqmqaSnnIaKVlLJCcuhk39gTCup5ILNMDjTrXDYTRKXZvQowJGB+IZKrnPizLQ8Nt1AlY7UahQwrQkAZTmRPZT6/CGhF0ik0esLCqOvcu+J00oq+Yjf+VpJKMCRgeNCQyX/m4eGjeXjVLRnK/dKx2SglVRycVxBPXEhOZ1G/iQn1/PuRkIBjgzIfbO2eMJKKhrBkVR/3I7SU3PlHeAAsTyck3TfSOqEgnriwkoqH62kkhqbJpT7aPG5KMCRgeOtrNKR/83DyniylVZSSelUW6TCyUoxIytFnnuZxKOVVPLA7pvpMp+eAmgllZyw+2ZqnvwD43gU4EjMFwzjTFekJz5NATdPaRatpJIDocKJJn3LnTC1SQmjkjoVPe9uap6y7htagSed+F3IlVLfMBTgSKy23QOeBxw2E7IV0BOPX0lFvSrpsPwbJQTFQNxKqpZeyqeQUCwwVsZ9w1ZSnaC6RjK10aA4J9WMDLv826h4FOBILL4nLvcVVMx0OiFackJPXCENFTuN3kOn0UvGFwyjMbrztVLuG1pJJb2TbZHf/RSF3DPxKMCR2GmFNVQAMC1a1tNtFOBIhSX9KeW+MRn0KMuOjPyx4IwkV32HFxwPpFuNsj2D6ly0kkp6p1qV10YxFOBITInJW2yK6hQFOJIIczxq2yOVjlKmqIDYfUOBsTTi6xqljBaXZtuh0wG9vhDa+wJSF0eTlJbvF48CHImxm2dKjnJuHhbJn2rzUK9KAo2dXgTCHCxGPQozbFIXZ9SmCiN/NIIjhVMKG/UDAKvJgKLMyD1OgbE0WBulpM4Uk5QAZ9u2bSgrK4PVakVFRQX27ds35LXPPfccdDrdgIfVah1wDc/z2LJlCwoKCmCz2VBZWYkTJ06I/TUSjuP42PCfgm4etmlbT38QnR7qVSWbEBTnpsKgV0ZPHIjN4dPInzSUlmDMCIFxOwXGyRYKc6hrV1beVjzRA5xXXnkFmzZtwr333ouDBw9i/vz5qKqqQmtr65CvSU9PR3Nzs/Cor68f8POHH34Yjz76KLZv3469e/ciJSUFVVVV8PmUtQ28y+1DfzAMo16Hkiy71MUZNZvZgEnRkQPKp0i+WP6Nckb9gPgpKrpnpBBLTFfYfZMTDYxpqXjSnenqF0aLJylotJgRPcB55JFHsH79eqxbtw6zZs3C9u3bYbfb8cwzzwz5Gp1OB6fTKTzy8/OFn/E8j9/+9re4++67cfXVV2PevHl44YUX0NTUhNdee03sr5NQrEdVmm2HyaCs2UI24kTDxsmn1CHjqdGGyuX2weMPSVwabeF5XpH5fkBcYEwjOEkXP1qsV9BoMSNqqxoIBHDgwAFUVlbGPlCvR2VlJfbs2TPk6/r6+lBaWori4mJcffXVOHLkiPCz2tpauFyuAe/pcDhQUVEx5Hv6/X643e4BDzlQ4pw4w3KGaLoh+ZS2RJxx2E3C6p1aaqySyuX2wRtQ3mgxQMnpUlJygjEgcoDT3t6OcDg8YAQGAPLz8+FyuQZ9zYwZM/DMM8/gr3/9K/7whz+A4zgsW7YMZ86cAQDhdWN5z61bt8LhcAiP4uLiiX61hIjtKqqshgqIH8GhhiqZeJ4XpqiUNoIDxE03UGOVVCzXT5GjxdFAvrGrH/5QWOLSaIuSl4gDMlxFtXTpUqxduxYLFizAJZdcgr/85S/Izc3F73//+3G/5+bNm9HT0yM8GhsbE1ji8VPiCipmKo3gSKLDE0BPfxA6XSzZW0liWwxQYJxMSk0wBoC8NAtSLUaEOR4NHV6pi6MpSp3WZEQNcHJycmAwGNDS0jLg+ZaWFjidzlG9h8lkwsKFC3Hy5EkAEF43lve0WCxIT08f8JCD0yoYwWno9FKvKonY6E1xph1Wk0Hi0ozdVNokUhJKbqh0Oh0FxhKhKaphmM1mLFq0CLt27RKe4zgOu3btwtKlS0f1HuFwGIcOHUJBQQEAYPLkyXA6nQPe0+12Y+/evaN+Tzno84fgckcPMMtRXqXDelUcD+pVJZHSKxxqqKSh5BEcIDbKfbqdAuNk6fQE0OWNjBZPUWAbBSRhimrTpk148skn8fzzz+Po0aPYsGEDPB4P1q1bBwBYu3YtNm/eLFx/33334e2338bp06dx8OBBXH/99aivr8cPfvADAJFo/pZbbsHPf/5zvP766zh06BDWrl2LwsJCrFq1SuyvkzCsB5uTaoHDbpK4NGM3sFdFlU6yKO2IhnOxvXBq2/vAcbRJZLLEcimUGhizpeIUGCcLq9cnZdhgMytvtBgAjGJ/wHXXXYe2tjZs2bIFLpcLCxYswI4dO4Qk4YaGBuj1sTirq6sL69evh8vlQmZmJhYtWoSPPvoIs2bNEq6544474PF4cNNNN6G7uxvLly/Hjh07ztsQUM6U3hMHIo3s52d6qDeeROx3rcQEYwAozrTBZNDBF+TQ7PYpcm8NpYkfLVbigYlA/GZ/1JlKFiWv8mVED3AAYOPGjdi4ceOgP9u9e/eAf//mN7/Bb37zm2HfT6fT4b777sN9992XqCImnRJ3MD4XLRVPPlbpKLWhMhr0KM1OwcnWPpxu66MAJwnYaHFumgUOm/JGi4GBm0TyPK+Ys7SUTOnTmoAMV1FphRpuHhac0QhOcviCYTT19AOIVfhKJATGtDNtUqhhtHhyTgp0OjoeJpmE6fA85d43FOBIJLZDpHJvnvgVMXTopvgaOr3geSDNYkR2ilnq4ozbFDpbKKmUvpcJEDl0s9BBx8MkE/v7VGqCMUABjiTCHB87wEzBN09pth06HdDrC6Gtzy91cVSP7f5blpOi6CH6qXQmVVKx+0aJ+ybFox2NkycQ4tDYGWmjlNwJpwBHAmejB5iZDXpMylRuDoLVZEBRtPy11FiJrk41DRXtZpxMaglw6FTx5Gns8oLjAbvZgLw0i9TFGTcKcCRQ2xH5Ay3JtsOgwAPM4pVlRyrNug6qdMQWP4KjZGwEp7nHB2+ADt0UE8/zwt+mWu4bGsERH+tMlWUre7SYAhwJxN88Ssd6hbXttNmf2GI9cWUdlniuDLsZWdEcIpqmEldrrx/eQBh6XWT3ayWLjfzRPSM2tYz6UYAjAbU0VEDcCA4NG4sudt8oN2+LEXrjdN+Iit0zRZl2mI3Kru5ZLkhDpxeBECdxadQtNuqn7DZK2Xe8QqllyBiIRfg0RSUujz+E1t5IIvdkFYz8UWCcHGrJ2wIAZ7oVNpMBYY5HYxeNGIuJLYJR+iwDBTgSECodhd88QCxIq+vw0Nb7ImIBZFaKWZFHe5xLuG8owBEVy/dTQ4Cj0+lQmh0ZUainDpWoaIqKjEswzKGxK7JZ22QFL79jijJtMOgjW++39PqkLo5qCQnG2coeMmYoOT056lR231DOn/jiNxRV+iwDBThJdqarH2GOh9WkR36acs7OGorJoEcxWypOvXHR1KlkBRXD5vbr6CR6UQlTDaq5byLfg0ZwxCNsKGpV9oaiAAU4SRe/gkqv8CXiTGy6gRorscR2FVVJQxUdwen0BNDTH5S4NOrEcbEl4kqfamDYSBR1psQTPz2l5CXiAAU4SVeroiXiDE03iE9tIzgpFiNyoxuIUW9cHM1uH/whDka9TjWHmlJdIz41bWNCAU6SqWkFFRObF6dKRyxsKkcNlQ7DkuzpvhEHa6hKsuwwGtRR1bO65mxXPy0VF4laNhQFKMBJOjXtgcPQihhx9XhjJyirZaoBiMvDoalNUaipoWJy0yywmw3geNBScZGoqY2iACfJhBEcFfbE6zu9tFRcBGypb16aBSkWo8SlSZxSmm4QlZr2wGEiS8Up0VhMamqjKMBJokCIw1m2RFxFlU5hhhUmgw6BECcsLySJU9seOXtHTT1xgDaJFJsap8OB2MgCLRVPPG8ghBZ3dENRFdw3FOAkUUNn5ITWFLNBSLBUA6NBj+Ismm4QC6vI1bKCiqHdjMVVq6INReOV0n0jGlZ/Z9pNyLAre4k4QAFOUrE/yFKFn9A6GCFhlHrjCae2FVQMy8Hp8gbR46Wl4okUCnNo6GR74Cg/lyLeZJraFI3a8rYowEkite1JEY8SjcWjxq0FAMBuNiIvOpJJjVViNXX7EAzzMBv1KHSoY4k4w45roHsm8YQ2SiV1DQU4SRSLjtXVowIowBELz/PC73SKCo72OFcZ5eGIgo2klmbZVbOhKENLxcVDIzhk3NSUnX4umqISR3tfAL3+EHS6yH4makN74YhDrdOaAC0VF5PaVt5RgJNELIFLLTdPPDYq1djpRShMvapEYUFxocMGq8kgcWkSrzSHnQ5NDVUiqeU06MHELxWnEePEUlsaBQU4SaKmE1oHU+iwwWzUIxjm0dRNp4onipobKoBGcMSitobqXJPpsNaEc/uCaO+LbCiqljaKApwkEU5otSj/hNbB6PU6lEanUGiaKnFiK+/UNz0FUA6OWNR0ntBgaIuBxGO/y5xUC1JVsqEoBThJEp+8pbYl4gwlGiceW+qr1gCHfa9ubxDd3oDEpVGHYJhDowo3FI1Hh24mnpqOaGAowEkSNSf9MXToZuKxAKckS533jd1sRH46WypO0w2JcKarH2GOh81kEH63akMjf4lXr8IDfSnASZLY/gLqiY7PRb2qxGOVjlpHcACabkg09vdXmm1X8Whx5O+BloonjhrrGgpwkiR286gnOj5XWfQPo4F64gnR4w2ipz+yw68al4gzZZRonFDs70/N90xuqgUp0aXibJSTTExDZ+Tvr0RFbRQFOEmixuj4XKXRYePGLi/CdKr4hNV3xpL+1HSK+LnYdAOdDp0YsWlN9dY1dKp44gltlIruGwpwkiAQ4tAcXSJeouIAx5luhdnAlorTqeITpYWgGIg7HZpG/hJCK/dNmXCqOAU4E9UfCKO1N3KKuJrum6QEONu2bUNZWRmsVisqKiqwb9++Ia998skn8eUvfxmZmZnIzMxEZWXledffcMMN0Ol0Ax4rV64U+2uM25muyCnidrMBuanqTPoDAINeh6KsyLk3NGw8cVroiQO0+i7R1DjVMJgyYQSH6pqJYjtCp1mNcNhMEpcmcUQPcF555RVs2rQJ9957Lw4ePIj58+ejqqoKra2tg16/e/dufOc738G7776LPXv2oLi4GCtWrMDZs2cHXLdy5Uo0NzcLjz/+8Y9if5Vxq49rqNSa9Mew4U2qdCZOC7kUQOz79fTTqeITxfN8bGsBld83tKghceJH/dTURoke4DzyyCNYv3491q1bh1mzZmH79u2w2+145plnBr3+xRdfxA9/+EMsWLAA5eXleOqpp8BxHHbt2jXgOovFAqfTKTwyMzPF/irjxhqqYpVXOEAsiZrlj5DxY79DNQ0ZD8ZuNiI3eqo43TcT09rrhy/IwaDXYVKmuk4RPxeb7qfR4omrFw5nVdeon6gBTiAQwIEDB1BZWRn7QL0elZWV2LNnz6jew+v1IhgMIisra8Dzu3fvRl5eHmbMmIENGzago6NjyPfw+/1wu90DHsmklR4VEOuN00qqiWvQSC4FQCN/icJ+f4UZVpgM6k6xZH8XZ7v66fy7CRKmw1VW14j6F9De3o5wOIz8/PwBz+fn58Plco3qPe68804UFhYOCJJWrlyJF154Abt27cJDDz2E9957D5dffjnC4fCg77F161Y4HA7hUVxcPP4vNQ5aSfoDYt+RGqqJ8YfCaHZHzvRS6yZ/8ag3nhhaydsCgPw0K8xGPUIcnX83UWpcQQUAsl57+otf/AIvv/wydu/eDavVKjy/evVq4f/PnTsX8+bNw9SpU7F7925cdtll573P5s2bsWnTJuHfbrc7qUGOVpL+gFiAEzl7i1fVfG4yNXb2g48mpuekqu/ssnOxoXEa+ZuYhuhUgxaCYr1eh5IsO0629qG+06O60YdkUmtgLOoITk5ODgwGA1paWgY839LSAqfTOexrf/WrX+EXv/gF3n77bcybN2/Ya6dMmYKcnBycPHly0J9bLBakp6cPeCSLlpL+AKAo0w6dDujzh9DpobOFxksIijWQmA7EjfxRDs6E1Kv87LJz0dTmxIU5Hme6aIpqzMxmMxYtWjQgQZglDC9dunTI1z388MO4//77sWPHDixevHjEzzlz5gw6OjpQUFCQkHInkpaS/gDAajKgID0y2kZnC42fVlZQMcWUu5UQap1qGApNbU5cc08/gmEeJoMOBQ51tVGiZ6Ft2rQJTz75JJ5//nkcPXoUGzZsgMfjwbp16wAAa9euxebNm4XrH3roIdxzzz145plnUFZWBpfLBZfLhb6+PgBAX18fbr/9dnz88ceoq6vDrl27cPXVV2PatGmoqqoS++uMmZaS/phYpUO98fHSXE88+j2b3T74Q4Pn0pGRqTVZdCixERyqa8ZLWOWbaYdBr67RYtFzcK677jq0tbVhy5YtcLlcWLBgAXbs2CEkHjc0NECvjzX8jz/+OAKBAL71rW8NeJ97770XP/3pT2EwGPD555/j+eefR3d3NwoLC7FixQrcf//9sFjkt4meWpffDac0KwUfn+6kYeMJEEZwNJC3BQDZKWakmA3wBMJo7OzHtLxUqYukOL2+oDAtrJWRv1La7G/C6lUcFCclyXjjxo3YuHHjoD/bvXv3gH/X1dUN+142mw1vvfVWgkomPtaj0sIeOEwJHbo5YfUaytsCImcLlWSn4GizGw2dHgpwxoHVNVkpZqRZ1bMb7XCEqU1a1DBuap7W1MaciYS0tESciSWMUoAzHhwXl5iupfuGEkYnRGt5WwBQnGWDTgd4A2G099GihvFgqQRq7IRTgCMyLa2gYth0HDVU49PS60MgFElML8xQV9LfcGgPpYnRWt4WAFiMsUUNlGg8PrHOlPqmwynAEZnWkv6A2Hdt7/PD4w9JXBrlYT3xSRk2zSSmA7QiZqLUPNUwHFrUMH48z6t6lkE7tacE4pP+1BgdD8VhMyHTHskBoMZq7OpVuunWSGIjf9RQjYeWNhSNRyPG49ftDaLXF+mEqrG+oQBHROwPLjvFjFSLrDeNTriSbGqsxiu2gkp9Fc5wWA+ysasfHMdLXBrlUetutCOhRQ3jxzpT+ekWWE0GiUuTeBTgiEiL01MMJYyOn9ZWUDEFDiuMeh0CIQ4tvXS20FgEw5xwHpMapxqGQ4saxk/t25hQgCMirc6JA1TpTAQ7T0hrDZXRoEdRdLdvCozH5mxXP8IcD6tJj7w0+e0HJiaaoho/YZM/lbZRFOCIKP48Ia0poa33xy2Wg6POXtVw6MiG8YnP29LaXjC0qGH81L4dBQU4IqrX2G608YQdRmllw5j09AfR7Q0C0OjUJh26OS5aOkX8XA6bCQ4bLWoYD7VvLUABjojUvPxuJOw7N3X7EAxzEpdGORo7tZuYDtB0w3ipvSc+EtpDaXzUvjkkBTgiCYQ4NPf0A9BmDk5emgVWkx5hjsfZrn6pi6MY9RpdQcXQXjjjU6/yhmokwpQ4jfyNmi8YhsvNEtPVOfJHAY5Iznb3g+MBm8mAXI0l/QHRs4WyKNF4rNjUjBaDYoB64uOl5RWbQOy+ocB49NhocarFKOxbpjYU4IikviOWYKy1pD+G9QoaaC+cUdPaKeLnYkFxT38QPdFcJDI8nuc1eSRMPJraHLv4UT+1tlEU4IhE6z0qIFbZ1lGlM2pa3loAAOxmozDiSYnGo9PW54c3EIZeBxRlavO+oanNsdNC3hYFOCLRekMF0HTDeGih0hkJbRI5NmzUr8Bhg9mozSqd/b2c7epHiBY1jIoWOuHa/GtIAq0niwKxaRZK/BudQIhDUzQxXdv3DfXGx4KCYiA/zQqzUY8Qxws7OpPhqX0XY4ACHNFoeZM/pjQr1lDxPJ0tNJIzXV7wLDE9VXuJ6QyrcGmzv9HR+goqANDrdShmu2BTh2pU1L4HDkABjigGJP1pNFkUACZl2mDQ6+ALcmjt9UtdHNnT8m608UqyqaEaCy1MNYyGsLkoBcYjCnM8znRGR4tVHBhTgCOC1l4/fEEOeh0wKcMmdXEkYzLoUZhhBUCVzmho9RTxc5XQCM6YaGGqYTRKsmhqc7Rcbh8CYQ5GvQ4FDqvUxRENBTgiYH9ghRnaTfpjYss3qTc+EkpMj2BD5s1uH/yhsMSlkT/KwYmILWqgumYkrPNQlGmD0aDeNkq930xCWj6i4VyUMDp6LG9L6/dNdooZKWYDeB5o7KRdsIfj8YfQ3hcAQCN/tGpz9IQcUZWnUFCAIwItH3x3LlryO3paPpw1nk6noxV4o8Q6Dpl2E9Kt6tyNdrSEqU1a1DAirYwWU4AjAi1kp48WbaE+OrQb7UAUGI8OraCKKc6yQacDvIEwOjwBqYsja1ppoyjAEQFVOjHxvSoytNZeP/yhSGJ6oYYT0xmabhgdrUw1jIbFaEBBOi1qGA2Wg1Os8jaKAhwRNHRSgMOwvIBOTwC9PjpbaCisQqbE9AjK3RodrUw1jFbsvqGpzeEIK+9oBIeMRa8viM7o8Kjab57RSLUYkZ1iBkC9quFopcIZrVIa+RsV2gNnoBKa2hxRjzcIty8EQP2dcApwEoz9YWWlmJGm8aQ/hnrjI4uN+tFUAzAwd4vjKGF0KJS3NRDb7I/2UBoa20AzN80Cu9kocWnERQFOgtH01PkoYXRktLXAQAUOK4x6HQIhDi29dLbQYEJhDme76OyyeMIIDnWmhqSlaU0KcBKMNt06X2zJL1U6Q6Ge+EBGgx6T2NlCFBgPqqnbhxDHw2zUIz9NvbvRjgUlp49MS51wCnASTEvR8WjFDt2kxL+hUC7F+YSt96mxGlR93IG+er12zy6Lx3K32vv88AZCEpdGnli+nxbqGgpwEoyWbZ6vhHpVw4pPTNdCr2q0aA+l4VFn6nwOuwkOWyT3ke6bwWlpOjwpAc62bdtQVlYGq9WKiooK7Nu3b9jr//znP6O8vBxWqxVz587F3//+9wE/53keW7ZsQUFBAWw2GyorK3HixAkxv8KoaenmGS1WATd19yMQ4iQujfxQYvrghHPMqKEaVCON+g2KpqmG16ihBQ2iBzivvPIKNm3ahHvvvRcHDx7E/PnzUVVVhdbW1kGv/+ijj/Cd73wHN954Iz799FOsWrUKq1atwuHDh4VrHn74YTz66KPYvn079u7di5SUFFRVVcHnkzYZMRDi0NSt/iPoxyo3zQKbyQCOB85209lC59LSnPhYFAtTVDS1ORgawRlcMU1tDskfCqPZHWkntdAJFz3AeeSRR7B+/XqsW7cOs2bNwvbt22G32/HMM88Mev1//dd/YeXKlbj99tsxc+ZM3H///bjwwgvxu9/9DkBk9Oa3v/0t7r77blx99dWYN28eXnjhBTQ1NeG1114T++sM62x3PzgesJr0yEuzSFoWOdHpdLF8CuqNn4dG/QYn9MTpnhlUPY3gDEpYtUk5f+dp7OwHzwMpZoOwP5maiRrgBAIBHDhwAJWVlbEP1OtRWVmJPXv2DPqaPXv2DLgeAKqqqoTra2tr4XK5BlzjcDhQUVEx5Hv6/X643e4BDzEIyVtZduh0lPQXT9gLh3rj5xFOEaee+AAsKO72BtHTT7tgx+N5ng71HQJNUQ2N1TXFGmmjRA1w2tvbEQ6HkZ+fP+D5/Px8uFyuQV/jcrmGvZ7971jec+vWrXA4HMKjuLh4XN9nJLRZ29Boh9GhxVZQ0X0TL8ViRE5qZCSUphsG6vAE4AmEodNFDpkkMaz+baSRv/NobbRYE6uoNm/ejJ6eHuHR2NgoyufML8rAj/9lGr42r0CU91cymm4YGh3OOjRaSTU4ds8UpFthMRokLo28sHvmTFc/QmFa1BAvFuBoozMl6j7NOTk5MBgMaGlpGfB8S0sLnE7noK9xOp3DXs/+t6WlBQUFBQOuWbBgwaDvabFYYLGInxMzvzgD84szRP8cJaI9TQYXn5iulV7VWJRm2XGgvovyKc5BK6iGlp9uhdmgRyDMobnHp/oTs8eiUWMLGkQdwTGbzVi0aBF27dolPMdxHHbt2oWlS5cO+pqlS5cOuB4Adu7cKVw/efJkOJ3OAde43W7s3bt3yPck0iuN282Y5+lsIYYS04dHK2IGF1tBpY2e+FgY9DoUZdEu2IOp19hO+6JPUW3atAlPPvkknn/+eRw9ehQbNmyAx+PBunXrAABr167F5s2bhet/8pOfYMeOHfj1r3+NY8eO4ac//Sk++eQTbNy4EUBkRc4tt9yCn//853j99ddx6NAhrF27FoWFhVi1apXYX4eM06QMG/Q6oD8YRlufX+riyAYlpg+PEkYHJ+xirJGGaqxoJdX5OI7X3JYUoh8let1116GtrQ1btmyBy+XCggULsGPHDiFJuKGhAXp9LM5atmwZXnrpJdx99934j//4D0yfPh2vvfYa5syZI1xzxx13wOPx4KabbkJ3dzeWL1+OHTt2wGql81jkymzUozDDhjNd/Wjo8CKPzs4BQInpI6EcnME1UN7WsCIjxm1038Rp6fUhEOJg0OtQmKGNxPSknJW+ceNGYQTmXLt37z7vuWuvvRbXXnvtkO+n0+lw33334b777ktUEUkSlGTZcaarH/UdXiwuy5K6OLLQoLFVDWPFAr+mnsgu2GajJtZFjEhrUw1jRTl/52OjoJMybDAZtPF3pI1vSWSBVlKdjxqq4eWkmmE3G8DzwJkuum8AoD8QRltvZJqXcnAGR1Ob59NiZ4oCHJI0rDdOm/3FsEqHVnoMLn4XbAqMI9i0i8NmgsNOZ5cNJn7ndFrUEKG1/BuAAhySRDSCMxDPx5L+aBfjodF0w0DxielkcKzD0OcPodMTkLg08lBPAQ4h4mF/WLTDaERbrx/9wTD0OqAoUzuVzljRdMNADbQHzoisJgOc6ZGFDJRoHMFGzmmKihARsAq5vS+APn9I4tJIj/WoChw2Sp4dRomwhxJNbQJ0ivholdAKvAHqNbhik2pVkjTpVhMyozkDNN2gvXNhxquUTqIfgBLTR6eUzr8T9PQH0e2NHFirpZE/CnBIUlFvPKaBGqpRoYTRgegU8dGhqc0Y1qHMSTUj1ZKU3WFkgQIcklTUq4qpp4ZqVCZl2mDQ6+ALcmjt1fYu2KEwhzNddHbZaAjHfFBnKrbztcamNSnAIUlFO9PG0BTV6JgMehRmRBJGtR4YN/f4EOJ4mI16IYmWDI6df6f1ewbQ3iniDAU4JKmKKZ9CoMV9KcaLbWhXr/E9lNg9U5xpg15PZ5cNh40Wt/b60R8IS1waaWntFHGGAhySVDRFFdHrCwr7c9AIzshYYqTWtxjQak98PDLsJqRZI/kmjRrfBVuro8UU4JCkYhXz2e5+BMOcxKWRDqtwslLMSLPSbrQjKaXdjAFoN5diPHQ6HSUaR2l1QQMFOCSp8tIssBj1CHM8mrr7pS6OZGh6amxKaOQPAJ0iPlax+0a7U5v+UBhNPZG6VmsLGijAIUml1+uosYJ2h4zHizZti6D7ZmyE8+80fN+c6eoHzwN2swE5qWapi5NUFOCQpKOVVLGlq7Qb7eiwqc1OTwC9vqDEpZHGgLPLKMAZFaprBo766XTaSkynAIckHa2kivXESyhZdFRSLUZkp0R6n1od+ev0RI440dHZZaNWSge1avpwVgpwSNKV0rw4TTWMg9ZXUrEEa2e6FVaTQeLSKINwz3R5Eea0uQu2lo/2oACHJJ3WN+AKhDg0R5P+aIpq9Eo0vpKKEozHrsBhg8mgQzDMC39zWiPsgaPB0WIKcEjSxSeMavFsoTNdXnA8YDMZkJtmkbo4iqH1PZQo/2bsDHqdMJ2n1WkqLZ8+TwEOSbqiTBt0OsAbCKO9LyB1cZKuvlO7SX8TofWDWmmTv/HR8sgfx2k7MZ0CHJJ0FqMBBdFzdLSYaCxMNWiwwpkIrW/axgK7Yg32xCdCyyupWnv98Ic4GPQ6FGbYpC5O0lGAQyQRm6bSXm9cy0PGE8F+X00a3QWb7pvxKdHwSiq2kGNShg0mg/aae+19YyILscMTtVfpCHvg0AjOmOSmWWA16cHxwNkubSWM9gfCaO31A6D7ZqyERQ1a7ExpeHoKoACHSEQYwdFggEN74IyPTqfTbD4Fm15JtxqRYdfWbrQTFT+1qbVFDax+1eq0JgU4RBJCpaOxhmpA0p9GK52JELbe19geSmyqgRKMx644uoqq1xdCt1dbu2DXa7yuoQCHSKJUo2fExCf9TcrUXtLfRGk10Vg4nFWjUw0TYTMbkBfdjkFr9Y2WV1ABFOAQibCphrZeP7yBkMSlSR7WEy/MsGoy6W+itDryRwnGE6PV+6ZBOKZBmyN/VMMSSTjsJjhsJgDa6lXFhoy1WeFMFAuMtXZcg9Z74hOlxalNty+IruiUnFZH/ijAIZLR4nQD7YEzMSVxB7VqKWGUBThaTRadKC3XNTmpZqRajBKXRhoU4BDJaHF/Cq0n/U1UUaYd+ugu2G19fqmLkxRhjseZLtrFeCK0OEVVT2eXUYBDpBOrdLQzbNzQQXvgTITZqEeBI5KcrZXAOLKxIQ+zQQ9ndAdwMjbFmuxM0co7UQOczs5OrFmzBunp6cjIyMCNN96Ivr6+Ya//0Y9+hBkzZsBms6GkpAQ//vGP0dPTM+A6nU533uPll18W86sQEcSmG7SzaVvsHCrtVjoTpbXpBjY9VZRlg0FPZ5eNBxsxdbl98AXDEpcmOej0eZEDnDVr1uDIkSPYuXMn3nzzTbz//vu46aabhry+qakJTU1N+NWvfoXDhw/jueeew44dO3DjjTeed+2zzz6L5uZm4bFq1SoRvwkRg9YS/3q8QWEfDsrBGT+tnS1EK6gmLisllofCpvvULnY4q3bvG9Eyj44ePYodO3Zg//79WLx4MQDgsccewxVXXIFf/epXKCwsPO81c+bMwf/8z/8I/546dSoeeOABXH/99QiFQjAaY8XNyMiA0+kUq/gkCdgf3pmufoTCHIwqXzbNhoxzUi2aTfpLhOIsjQU4NNUwYWwX7C+a3ajv8GJaXprURRKdsHeShgNj0VqUPXv2ICMjQwhuAKCyshJ6vR579+4d9fv09PQgPT19QHADADfffDNycnKwZMkSPPPMM8OuqPD7/XC73QMeRHrOdCvMRj1CHI/mHp/UxREd9agSI3aOmTZG/miqITG0NLUZCHFo7olM/Wt5tFi0AMflciEvL2/Ac0ajEVlZWXC5XKN6j/b2dtx///3nTWvdd999+NOf/oSdO3fimmuuwQ9/+EM89thjQ77P1q1b4XA4hEdxcfHYvxBJOL1eh+Lobr5aqHToiIbE0OwUlYYbqkQo0dDI35kuLzgesJsNyE21SF0cyYw5wLnrrrsGTfKNfxw7dmzCBXO73bjyyisxa9Ys/PSnPx3ws3vuuQcXX3wxFi5ciDvvvBN33HEHfvnLXw75Xps3b0ZPT4/waGxsnHD5SGJo6aRfNuKg5R5VIrDfX3tfAH1+de+CzfM8bfKXICUaCozr46andDrtJqaPORHgtttuww033DDsNVOmTIHT6URra+uA50OhEDo7O0fMnent7cXKlSuRlpaGV199FSaTadjrKyoqcP/998Pv98NiOT9atVgsgz5PpKelXhX1xBMj3WpCpt2ELm8QjZ1ezCxIl7pIounyBoUgriiT7puJ0NLUJk1rRow5wMnNzUVubu6I1y1duhTd3d04cOAAFi1aBAB45513wHEcKioqhnyd2+1GVVUVLBYLXn/9dVitI+/7UF1djczMTApiFEhLm/010BLxhCnJsqPL24P6DnUHOKwxdqZbYTUZJC6NsrGORWNXPziOh17FS+6pMxUhWg7OzJkzsXLlSqxfvx779u3Dhx9+iI0bN2L16tXCCqqzZ8+ivLwc+/btAxAJblasWAGPx4Onn34abrcbLpcLLpcL4XBk74I33ngDTz31FA4fPoyTJ0/i8ccfx4MPPogf/ehHYn0VIiKtJP75gmG43JFEaq1XOolQks1Oo1d3b5waqsQpcFhh1OsQCHHC36Ja1UUD47IcbXemRF2r+uKLL2Ljxo247LLLoNfrcc011+DRRx8Vfh4MBlFTUwOvN/JHfPDgQWGF1bRp0wa8V21tLcrKymAymbBt2zbceuut4Hke06ZNwyOPPIL169eL+VWISOITRnmeV+188ZkuL3geSDEbkJ1ilro4iscStdUeGNe2RxqqyRpvqBLBaNCjKNOGug4v6ju8KMywSV0k0QgBjsa3FhA1wMnKysJLL7005M/LysoGLO++9NJLRzxAb+XKlVi5cmXCykikxfIK+vwhdHoCyFZpxr9wLkx2imqDuGTSSsJoXQftgZNIxVl21HV40djpxdKp2VIXRxShMIfG6N+F1kdw1L2zGpE9q8kgnK+j5saKdqNNLK2M4NRFv9/kHLpvEkEL59819/giZ5cZ9SjQ+NllFOAQyWmhNy70xKmhSgg2onG2ux/BMCdxacRT1065FIkUW0ml3rqGTWuWZtlVnUg9GhTgEMlpoTcu5FLQVENC5KVZYDXpEeZ4nO1S52GtXZ4AevojZ5eV0sq7hCjRwKIGmtaMoQCHSE4LK6loVUNi6fU6odGvVem+JnVxS8RtZloinggsWbuu3TNivqdS1bXTtCZDAQ6RnNqX/AZCnDDKQKthEie+sVKjWFBMDVWiRHb2BXr9IXR4AlIXRxTUmYqhAIdIriw6glPbrs4RnIbO2LkweWnqXCUmhTKVBzjs70HrS30TyWoyoNARWR6u1vtGyNui+4YCHCI91lC19/nR6wtKXJrEYxVOKS0RTyg2BF+r0qnNeuqJi4KNiNWqMMAJhTk0dtEScYYCHCK5dKsJOamRze/qVDiKw4aMaU48sVgPVf09cbpvEkmY2lRh7lZTNy0Rj0cBDpEFVumoMWG0loaMRcHumTNdXgRC6loqzvN87L6hnnhCxQJj9XamaIl4BAU4RBZYpVPbpr4Ah5L+xJGbZkGK2QCOV98eSt3eINy+yCnitEQ8sVhgfFqFI39U1wxEAQ6RhTIVDxvHlm1SpZNIOp1O2OtDbdNUtbREXDSsrqnvUN9S8Vqa1hyAAhwiC1PYFJXKGipfMIymnsgScZqiSjy15lPEdjCmhirRijPt0OsAbyCM1l6/1MVJKLaXGI3gRFCAQ2ShTKUBTuSUdCDVYhQSqUniqHVFTOwMKmqoEs1s1AuH/KruvqF8vwEowCGywP4ge/qD6FLRBly1cT1xWiKeeJNzUgGodwSHttsXhxo3iQyFOSEXjUZwIijAIbJgMxtQ4Igsa1TTSirqUYmLLb1X24oYIVmU7htRqHHVZlO3DyGOlojHowCHyIYaV1LF9sChhkoM7J5p6umHLxiWuDSJEb9EnO4bcbAkXDWN4NTSEvHzUIBDZGNyrvoSRmkPHHFlpZiRZjWCV9FS8S5vEL3RJeIlWZRkLAY15vzRztfnowCHyMbkbPVVOmzqhCodceh0uth0g0ruG/Y9Chy0RFwsk4Wl4l5wnDqWitMS8fNRgENkQ229qv5AGC63DwBNNYipTGWBMeuJl1JDJZpJGTYY9Tr4Qxyao3+jSldHO1+fhwIcIhvxKxvUsAEXm2pLtxqRaTdJXBr1Utup4nWUfyM6o0GPEpXl4bA9cCbTdLiAAhwiGyVZkQ24PIEw2lSwAVd8Q0VLxMUzWWV74bDT0SlvS1xqmhKPXyJeSoGxgAIcIhtq24CrlpL+kkJte+HEpqjovhGTmkb+znb30xLxQVCAQ2RFTWdS0R44ycF64i1uP7yBkMSlmRie54VtEmiKSlxqyvk7ze6Z7BRaIh6HAhwiK5Oj8+JqOOmXDtlMDofdJOQ4KX3Dv7ZeP3r9Ieh1dA6V2IQpKhV0pk619QEApuRSXROPAhwiK2raQp2mqJJHLSN/p6I98aJMOyxGWiIuJhZANnZ6EQpzEpdmYliHcGpuqsQlkRcKcIisxObFld0T7/OHhERpWtUgPrUkjJ5up554shQ6bDAb9QiGeTR1K3up+GkawRkUBThEVibH9cSVvAEXy6PITjHDQUvERaeWfAqWS0E9cfHp9TphUzylT1Ox+2YK3TcDUIBDZGVShg0mg/I34GJz4tRQJQfruSo9wKFciuRiHSo2AqJEvb4gWqOjxXTfDEQBDpEVo0GP4uj5O0o+dJOGjJOLBZInW/sUvUmk0BPPocA4Gdh9c1rRdU2k7DmpFqRbabQ4HgU4RHamCNMNyu1VnaKphqSKbKYI9PQH0ekJSF2ccfGHwjjTFck9m5pHgXEysL/PUwoewWF5W1OpM3UeCnCI7MQqHeX2qoQpKmqoksJqMqAo0wZAufdNfYcXHA+kWYzITbVIXRxNmJqnggCH8m+GJGqA09nZiTVr1iA9PR0ZGRm48cYb0dc3/I106aWXQqfTDXj8+7//+4BrGhoacOWVV8JutyMvLw+33347QiFlb/BFYuKnG5SI43ghF4SmGpJH6b3xU62xaU062iM52BRyi9uPXl9Q4tKMTywxnTpT5xI1wFmzZg2OHDmCnTt34s0338T777+Pm266acTXrV+/Hs3NzcLj4YcfFn4WDodx5ZVXIhAI4KOPPsLzzz+P5557Dlu2bBHzq5AkUnqv6mx3P/whDmaDXhhVIOITAhyFBsa0l0nypVtNyEuLjJYpNQ+HEtOHJlqAc/ToUezYsQNPPfUUKioqsHz5cjz22GN4+eWX0dTUNOxr7XY7nE6n8EhPTxd+9vbbb+OLL77AH/7wByxYsACXX3457r//fmzbtg2BgDLn3slA06IVfHOPD31+5Y3MsQqnNNsOo4FmgZNF8SM41FBJQsn3TfxoMQXG5xOt9t2zZw8yMjKwePFi4bnKykro9Xrs3bt32Ne++OKLyMnJwZw5c7B582Z4vbFN3/bs2YO5c+ciPz9feK6qqgputxtHjhwZ9P38fj/cbveAB5Evh92EnFTWq1JepUN7mUiDDdErNQeHcimkwfLklDglPnC0mI72OJdRrDd2uVzIy8sb+GFGI7KysuByuYZ83b/927+htLQUhYWF+Pzzz3HnnXeipqYGf/nLX4T3jQ9uAAj/Hup9t27dip/97GcT+TokyabmpqC9z49TbX2YV5QhdXHGhHri0mBTm41dXviCYVhNyjnqgOd52jtJIkoewWHTmqXZdhjokM3zjHkE56677jovCfjcx7Fjx8ZdoJtuuglVVVWYO3cu1qxZgxdeeAGvvvoqTp06Ne733Lx5M3p6eoRHY2PjuN+LJAdrrJTYq6KGShrZKWY4bCbwvPLOpGrvC6DXF4JOF2msSPIoedUm7bc1vDGP4Nx222244YYbhr1mypQpcDqdaG1tHfB8KBRCZ2cnnE7nqD+voqICAHDy5ElMnToVTqcT+/btG3BNS0sLAAz5vhaLBRYLLbtUkmlCwqgSK53oFFUeBTjJpNPpMDU3BQcbunGq1YNyZ/rIL5IJ1lAVZdoUNfKkBuzvtL7Dg2CYg0lBeXPUmRremAOc3Nxc5Obmjnjd0qVL0d3djQMHDmDRokUAgHfeeQccxwlBy2hUV1cDAAoKCoT3feCBB9Da2ipMge3cuRPp6emYNWvWGL8NkSulrqSibdOlNTU3NRLgKOy+OU3bCkimIN0Km8mA/mAYjZ1eReVAUd7W8EQLVWfOnImVK1di/fr12LdvHz788ENs3LgRq1evRmFhIQDg7NmzKC8vF0ZkTp06hfvvvx8HDhxAXV0dXn/9daxduxZf+cpXMG/ePADAihUrMGvWLHz3u9/FZ599hrfeegt33303br75ZhqlURGWMFrX4UEozElcmtFjFU5uGm2bLoUpCs2nYEvbqSeefHq9TuiMKG2aKhbgUGdqMKKOxb344osoLy/HZZddhiuuuALLly/HE088Ifw8GAyipqZGWCVlNpvxj3/8AytWrEB5eTluu+02XHPNNXjjjTeE1xgMBrz55pswGAxYunQprr/+eqxduxb33XefmF+FJFmhwwabyYBgmEdDp3fkF8iEkGCcQxWOFGIrqZQV4AgjONRQSUKJicYefwiu6IHEU2nkb1CiraICgKysLLz00ktD/rysrGzAwXjFxcV47733Rnzf0tJS/P3vf09IGYk8sV7VkSY3Trb2KWYIlvJvpCVMbbZ6wHE89ApZWULJotJS4iaRbP+bnFQzHHYaLR6McrKpiOYocXUDjeBIqyTLDqNeh/5gWOjdyp0/FEZjVz8AmqKSCtsLR0kjOGyFKeVtDY0CHCJb0xS4VJyVlUZwpGEy6IVl1kpprGrbPQhzPNKsRuHYAJJc8Z2p+FkFOTvR2gsAmJ5Pdc1QKMAhsqW0efFAiBOGjS/IT5O4NNqltOmGEy2Rcl6Qn0aHbEpkck4KdDqgpz+IDo8yjvw5Hr1vplNnakgU4BDZmha3VFwJvaq6Dg9CHI9UixGFDqvUxdGs2BYDypjaPNES7YlTQyUZq8kgHIyrlMCYjRZTZ2poFOAQ2SrNtkOvA3p9IbRF95aRM9YTn5aXSj1xCbERHKVMbQo9cWqoJCXcNwoYMfYFw6iP7tY9jaaohkQBDpEtq8mA4qxIPoUSKp3j1BOXBSF3SwH3DBDLpbiAGipJKSkwPtXWB44HMuwm5KZS3tZQKMAhsjZNQfkUsYaKeuJSYgFmW68fXTLPp/CHwqjriOzzND2P7hspsQCTjcTKGQvCptNo8bAowCGyxvIpTiggwIlNNVBPXEopFqOQT8FG1eQqfgVVfjr1xKXEOiY1Mr9ngLjRYupMDYsCHCJrrNKRe0MVCHGooxVUsjFDIffNcVpBJRssWFDCyN8JWkE1KhTgEFljDVWNq1fWK6niV1AV0AoqyV3gVEZv/CTlbclGqoJG/k7QCqpRoQCHyNr0/FTodECXN4i2PvmupGIVIq2gkgeWT3HcJe+pTVpBJS9KGDGOX0FFgfHwKMAhsmY1GVCWHdlGXc6NVWyzNqpw5EBoqFrlPfJ3nFZQyYoS8nBOt3nA8YDDZkIu7Xw9LApwiOyxyl/OlQ6toJKXqbmp0OuAbm9Qtnso+UNh1EdXUNF9Iw8znPIf+ROOaKDR4hFRgENkb4YzHQBQ43JLXJKh0VSDvFhNBpRFDzyVa2BMZ1DJjxJG/k5QXTNqFOAQ2RMSjWW6P0X8CiqaE5ePC/JiCepyRCuo5EcJI3+0oejoUYBDZG9GdEXMiZZecJz8elVsBVUaraCSlQuE+0aegTFbQUX5N/IRn/Mn15E/WkE1ehTgENkry7bDbNDDGwjjbHe/1MU5j7CCKp/mxOVkhswTRo8LZ5dRQyUnF+TLd+RvwAoqCoxHRAEOkT2jQS/saHxMhpXO8WiZLqCGSlZiW+/LM5/iWDSnbAb1xGVFziN/J1oiZ1Bl2k2UtzUKFOAQRSh3ynd/ii+aI2UqL6CGSk7KclJgMujgkeHIn8cfQn1nZAXVTLpvZEXOI39HmyNB8cyCdBotHgUKcIgiyHnYOL7SIfJhMuiFE6LlFhjXtPSC54G8NAuy6TRoWWFLxeWY83c0OupX7qS6ZjQowCGKIOxPIbOGqqc/KIwOzKRKR3ZYYCy3qU0WFJdTUCw7pdnyHfmLdaZo1G80KMAhisD2wjnV1odgmJO4NDHHohXOpAwbHHaTxKUh52Ir8I41yzPAoYZKfuJH/uQUGPM8j6PR+5hGi0eHAhyiCIUOK1ItRgTDPE63eaQujoAaKnmbVRhpCL5oltcmkSzgmkUNlSyx/y5HZXTfuNw+9PQHYdDrMI32wBkVCnCIIuh0OqE3LqdKh/XwaE5cnmZHG6rTbX3oD4QlLk0Ex/F038gcC4yPNPVIXJIYVu9NzU2B1WSQuDTKQAEOUYzZMuyNU4KxvOWmWZCTagbHy2dVzJmufvT5QzAb9JiSmyJ1ccgg5Djyx6anKCgePQpwiGLMllmvKszxQqNJU1TypNPpMKvQAQD4okkejRVrNKfnp8JkoCpYjtgUVWNnP3r6gxKXJoI6U2NHf11EMWZHG6ojTW5ZbNxW2+6BL8jBZjKgNJt64nLFGiu5BMZsgz9qqOQrw27GpAwbAPlMiVO+39hRgEMUY3p+Kox6Hbq9QTT1+KQujlDhzHCmwaCnTbfkSm7TDcIScSc1VHIm3DcyGPnzBcOojR7oS4Hx6FGAQxTDYjRgenRfkyNnpe+N05CxMrARnGPNvQjLYOO2o7SCShHYfx85BMbHW3rB8UBWipmOaBgDCnCIosTycKSvdGjIWBkm56TAatKjPxhGXYe0Wwy4fUE0RI9ooE3+5G22jEZwDp+NlGEWHdEwJhTgEEWRS4DD8zwORSudOZMckpaFDM+g1wkrT6RurA5HRx4nZdiQlWKWtCxkeGyK6kRrLwIhaTcXPRS9b+YWUV0zFqIGOJ2dnVizZg3S09ORkZGBG2+8EX19Q5/QWldXB51ON+jjz3/+s3DdYD9/+eWXxfwqRCaEYWOJE0Zdbh/a+/ww6HU01aAAcsnDOXQmct/Oo4ZK9iZl2JBujWwueqJV2i0GDp3tBgDMpc7UmIga4KxZswZHjhzBzp078eabb+L999/HTTfdNOT1xcXFaG5uHvD42c9+htTUVFx++eUDrn322WcHXLdq1SoxvwqRCdZQNfX40OUJSFaOz6MN1fS8VNp0SwHkMvJHPXHliGwxIP3Inz8UFg4ZpgBnbEQLcI4ePYodO3bgqaeeQkVFBZYvX47HHnsML7/8MpqamgZ9jcFggNPpHPB49dVX8e1vfxupqQO3ps7IyBhwndVqFeurEBlJs5pQmm0HIG1jRT1xZZkT3WLg0JluSbcYYAHOvEkZkpWBjB7bmuKQhIsaaly9CIZ5ZNpNKMq0SVYOJRItwNmzZw8yMjKwePFi4bnKykro9Xrs3bt3VO9x4MABVFdX48YbbzzvZzfffDNycnKwZMkSPPPMM8NWWn6/H263e8CDKBfrjR+WcJrqc6EnniFZGcjolRekwWTQocsbxJkuaU6I7vEGUd8RSTCeM4mmNZVgfnEGAOCzMxLWNdHPnjPJQQnGYyRagONyuZCXlzfgOaPRiKysLLhcrlG9x9NPP42ZM2di2bJlA56/77778Kc//Qk7d+7ENddcgx/+8Id47LHHhnyfrVu3wuFwCI/i4uKxfyEiG3Ojvd/Pz3RL8vk8z+NQ9LPn0ZCxIliMBiFXqrqxW5IysFGAkiw7MuyUYKwEC6IdmKNNbskSjVliOo0Wj92YA5y77rpryERg9jh27NiEC9bf34+XXnpp0NGbe+65BxdffDEWLlyIO++8E3fccQd++ctfDvlemzdvRk9Pj/BobGyccPmIdOYXR/7Qqxu6Jfn8M1396PIGYTLoUE5LxBVD6I1LFOB8zhJFqaFSjOIsGzLtJgTCnLADdbKxEZy5NK05ZsaxvuC2227DDTfcMOw1U6ZMgdPpRGtr64DnQ6EQOjs74XQ6R/yc//7v/4bX68XatWtHvLaiogL3338//H4/LJbzN0GyWCyDPk+UaV5RBnS6SKJxq9uHvPTk5l+xnvgMZxosRkowVop5RRkA6oUGI9mEnjiN+imGTqfDvKIMvHe8DZ81dkfvoeTxBcM4Hj3vjgLjsRtzgJObm4vc3NwRr1u6dCm6u7tx4MABLFq0CADwzjvvgOM4VFRUjPj6p59+Gl//+tdH9VnV1dXIzMykIEYjUi1GXJCXhpqWXlQ3dmPF7JED5kSiHpUyLSiOJYyGwhyMST7oUrhvqKFSlPnFkQCnurEH312a3M8+5upFiOORnWJGoYMW0oyVaH/hM2fOxMqVK7F+/Xrs27cPH374ITZu3IjVq1ejsLAQAHD27FmUl5dj3759A1578uRJvP/++/jBD35w3vu+8cYbeOqpp3D48GGcPHkSjz/+OB588EH86Ec/EuurEBlaEJ1ukCKfgu1JQXPiyjIlJxWpFiP6g2GcbBt6Py4xdHoCQnIzbQypLCww/kyCnD82WkwJxuMjahfmxRdfRHl5OS677DJcccUVWL58OZ544gnh58FgEDU1NfB6vQNe98wzz6CoqAgrVqw47z1NJhO2bduGpUuXYsGCBfj973+PRx55BPfee6+YX4XIzHyJAhyO44Ul4rQnhbLo9Trhv1my83BYQvzknBSkW01J/WwyMWxa6lRbH3p9waR+NsszpM7U+Ix5imossrKy8NJLLw3587KyskGXdz/44IN48MEHB33NypUrsXLlyoSVkSgTG8H5/EwPOI6HPkmneZ9u74PbF4LVpMcMOg1aceYXZ2DP6Q5UN/bguouS97kH67sAAAtLMpL3oSQhclItKMq04UxXPw6d7cGyqTlJ++xPGyL3zYWlmUn7TDWhs6iIIl2QnwqbyYA+fwinkjjdcCDaUM0vyoApyTkcZOLYdEOytxg4GO2JX1hCDZUSxVbgJS9BvdMTwOn2yOGwFxbTfTMeVEMTRTIa9EKy5qdJnG5gAc4i6lEpEptuOObqRX8gnJTPDHO80BOn+0aZ5kfrmurGrqR9Jhv1m5aXCoedpjXHgwIcolgLJcjDoQBH2QocVhQ4rJGgI0mNVY2rF55AOLL6L5+mNZVoYXTk7UB98o76OMimp2hac9wowCGKxfJwWE9HbF2eAE61RYaMF9JUgyLpdDosLssCAHxSl5z75kBDLP/GkKRcMZJYcyc5YDbo0d7nR12Hd+QXJAB1piaOAhyiWKyhqmnpRY9X/NUNrMc/JTcFWSm01b5SLSmLNBj76zqT8nmf1rOeODVUSmU1GYQd1PfXin/fBMOcsCyd7pvxowCHKFZumgVTclPA88An9eJXOkKPiiocRWOB8cH6LoTC4p8vdIDyb1Thouh9k4zA+FhzL3xBDulWI6bmpor+eWpFAQ5RtIrJkUpnXxJ6VWxKY3EZNVRKdkF+GtKsRngCYRxz9Yr6Wa29PtR3eKHTAQsol0LRkhng7K3tABBZHp6sLTDUiAIcomis0tkrcoDjC4aF1VpsBIAok0Gvw+LoaIrYgfGeU5GGanZhOm3wp3AXlmZCpwPqOrxo7fWJ+lkfn47cN1+aki3q56gdBThE0ZZER3AOn+2Bxx8S7XMO1nchEOKQn27BlJwU0T6HJIeQaCzy1CZrqJZSQ6V4DpsJ5c50AOImqIc5Xuiw0X0zMRTgEEUryrRjUoYNIY7Hp9HN1MTwUbQnvmxqDp0JowIsMN57uhMcJ96yXzaCs3QqNVRqcFF0enpvNHAVwxdNbvT6QkizGDG7MF20z9ECCnCI4rHG6mMRK5091BNXlflFGbCbDejwBHDU5RblM5p7+lHX4YVBrxOmUomyLYsGqh+cbBftM/acjrz3kslZST/xXm3ot0cUjwUdYlU6Hn9IOJyReuLqYDbqhfyGD0W6b9jozZxJDqRR/o0qLJ2aA70OONXmQXNPvyifQaN+iUMBDlG8L18QOfzu8zPd6PYGEv7+++s6EeJ4FGXaUJxlT/j7E2lcPC1y3/zzhDgBDpvWpFE/9XDYTMJxH2LcN6Ewh/3R/B5KMJ44CnCI4hU4bLggPxUcL84oDjVU6vTl6ZEAZ39dJ3zBxJ5LxfM89cRVank0MP5AhADnszPd6POH4LCZMKuA8m8migIcogpfnp4LAPjn8cRXOrtrWgEAy6MNIlGH6XmpyEuzwBfkEn7cx/GWPpzt7ofFqMcSyr9RFVYPfHiyPeEJ6u8ci9Q1X7kgl/a/SQAKcIgqfOWCSIDz/om2hB6G19jpxfGWPhj0OlwS/QyiDjqdTuiN/zPBI3+7jrUAiCSl2syGhL43kdaFJZmiJai/c6wNAPAv5VTXJAIFOEQVlpRlwWzUo7nHh1NtfQl733ejozeLSjKRYafzp9SG5W+9c7Q1oe/7brQn/i/leQl9XyK9+AT13TVtCXtfV48PR5vd0OmASy6g+yYRKMAhqmAzG4RjG3YlsLFi7/VVaqhU6asz8mDQ61DT0ov6Dk9C3rPbGxDOLaP7Rp0qZ+YDAN4+4krYe7LO1ILiDDrMN0EowCGqsWJWpNL538OJqXS8gZCw/81lM6mhUqMMu1kIjHd+0ZKQ93zveBs4HrggPxVFmbTqTo0qZ+VBpwM+O9OTsOXiLP/mX2ZQXZMoFOAQ1aia7YROB1Q3diek0vnwZAcCIQ6TMmyYnkcn+qoVC4zfPpKYAIdG/dQvL82KC0siuxr/IwGBcX8gLOzHRPdN4lCAQ1QjL92KRdFKZ0cCRnH+fqgZAPCvs/LpeAYV+9fZTgCRc6k6+vwTei9fMIxdRyMN3opZzgmXjcgXC4zfSkBgvLumFd5AGEWZNjqeIYEowCGqsnJOpFGZaIDjC4aFKYur5hdMuFxEviZl2DBnUjo4HvjH0Yk1VrtrWuEJhDEpw4YLSzISU0AiSyuigfHHpzvQ4w1O6L3e/DzSmbpyXgF1phKIAhyiKlXRSmd/XSdae33jfp/dNa3o84cwKcOGhcWZiSoekanL50SC2Nc+bZrQ+7xBDZVmTM5JQbkzDSGOx9+io73j4Q2EhG0Fvja3MFHFI6AAh6hMcZYdC4ozwPHAXyfQWL3+WeS1V84roA23NGDVwkkAIoeqnunyjus9en1BYbn5VfOoodKCb0Tvm/85eGbc7/H2kRb4ghxKsuyYM4mmpxKJAhyiOtcuLgIA/OmTxnFt+tfe5xemp65eQA2VFkzKsAlHcfy1enyB8eufNaE/GMbU3BRqqDTiGwsnQa8DDtR3obZ9fNsM/HFfAwDgmguLaNQvwSjAIapz1fxCWE16nGjtQ3X0FPCx+J8DZxAM85hf5MDsQkfiC0hk6RsXRnrjr+xvRHgcW/C/vK8RAPCdJSXUUGlEXrpVOCbmz580jvn1p9v6sLe2E3pdrGNGEocCHKI66VYTrojmVPy/j+vH9FqO44Ue1b9VlCS8bES+vjavAOlWIxo6vcJOxKN1+GwPDp3tgdmgxzcvpIZKS76zpBhAZCRmrIe2vrI/EhRdckEuCjNsCS+b1lGAQ1Rp7bIyAMDr1U1j2hNn17FW1HV4kWYx4muUR6EpdrMRq5dEgtrnPqob02ufeP80AODyuU7ahVZj/nWWE0WZNnR5g3jt07Ojfp3bF8RLeyOdqTUVpWIVT9MowCGqtKA4A0smZyHE8Xj2w7pRvYbnefzu3ZMAgOuXliLFYhSxhESO1i4thV4HfHCyHYfP9ozqNXXtHrz5eSRv5/98ZaqYxSMyZNDrcEO0Q/XkP0+PenrzDx/Xo9cfwvS8VDqzTCQU4BDV+j9fmQIgUpG0uEdeMv7+iXZ81tgNq0mPG5dPFrt4RIaKMu34+vzIyN1DO46N6jW/e/ckOD5ysOYs2qRNk759UTEcNhNOtXlGtaKq1xfEMx/UAgA2XDqVVmqKhAIcolr/Up6HC0sy4A2E8cu3aoa9NhDicP+bXwAA/m1JKXJSLckoIpGhTf86A0a9Dv880Y4PTrQPe211Y7fQoG38l2nJKB6RoXSrCRu/Gvnv/5udx9EfGD4X57F3TqK9L4CybDuumk9T4WIRLcB54IEHsGzZMtjtdmRkZIzqNTzPY8uWLSgoKIDNZkNlZSVOnDgx4JrOzk6sWbMG6enpyMjIwI033oi+vj4RvgFROp1Oh3u+NgsA8N8HzmBfbeeQ1z71wWmcbO1DdooZP7lserKKSGSoJNuONdEE882vfo4+f2jQ60JhDlv+ehg8D3zzwknC2UREm767tBSTMmxo7vENO/pX4+oVRm/uvWo2TAYaZxCLaL/ZQCCAa6+9Fhs2bBj1ax5++GE8+uij2L59O/bu3YuUlBRUVVXB54tNL6xZswZHjhzBzp078eabb+L999/HTTfdJMZXICqwsCQT344uv7zl5U/R1nv+WUMH6jvxm53HAQCbr5gJh92U1DIS+bmtagYmZdjQ2NmPLa8dHnQ/pYffqsHnZ3qQZjHirsvLJSglkROryYAHvzkXQCRJfdcgx370+oK4+aWDCHE8Kmfm0cGaIhMtwPnZz36GW2+9FXPnzh3V9TzP47e//S3uvvtuXH311Zg3bx5eeOEFNDU14bXXXgMAHD16FDt27MBTTz2FiooKLF++HI899hhefvllNDVNbIt1ol5brpqNsmw7mnp8+O7Te3G2O7aqal9tJ9Y9ux/BMI8r5jpxTXQvFKJt6VYTfv3t+dDrgL98ehb3vn4EwTAHAAhzPB7ZeVxYOfXQt+YhL80qZXGJTFxyQa6QcHzzSwfxzrFYkNPW68cNz+7HydY+5KVZ8Itr5klUSu2QzTKR2tpauFwuVFZWCs85HA5UVFRgz549WL16Nfbs2YOMjAwsXrxYuKayshJ6vR579+7FN77xjUHf2+/3w++P9dzdbrd4X4TITqrFiOfWLcG3tu/BMVcvKn/9Hipn5aPPF8Tu423geWBRaSYe/tZ82qCNCL40JRsPfGMuNv/lEF7YU4/3j7fhorIsfH6mBzUtvQCA26tm4Iq5dBgrifmPK2aivsODd2va8P3nPsHF07KRl2bFrqMtcPtCSLca8cwNF1GeXxLIZvLP5Yqc/pyfnz/g+fz8fOFnLpcLeXkDh/SMRiOysrKEawazdetWOBwO4VFcXJzg0hO5K8tJwas/XIYLSzLQHwzjjc+a8G5Nm5A/8fz3lyCVloWTc3xnSQm2X38hMu0m1HV48ecDZ1DT0osUswEPXzMPN3+VEovJQGajHtu/uwg3LCuDTgd8eLIDr356Fm5fCOXONPz3hmWYM4l2SE+GMdXod911Fx566KFhrzl69CjKy+U1H71582Zs2rRJ+Lfb7aYgR4OKs+z4nw3LsOd0Bz5t6IbFqMclF+Rien6a1EUjMrZyTgG+PD0XO79oQX2HFwUZVlTNdsJho1wtMjiL0YCffn02blhWht01rfAEwphdmI4vT8+FgZaEJ82YApzbbrsNN9xww7DXTJkyZVwFcTqdAICWlhYUFMSGfFtaWrBgwQLhmtbWgVuoh0IhdHZ2Cq8fjMVigcVCw4EksrJq2dQcLJuaI3VRiIKkWIzCieOEjFZZTgpuyKE9taQypgAnNzcXubm5ohRk8uTJcDqd2LVrlxDQuN1u7N27V1iJtXTpUnR3d+PAgQNYtGgRAOCdd94Bx3GoqKgQpVyEEEIIUR7RcnAaGhpQXV2NhoYGhMNhVFdXo7q6esCeNeXl5Xj11VcBRHrWt9xyC37+85/j9ddfx6FDh7B27VoUFhZi1apVAICZM2di5cqVWL9+Pfbt24cPP/wQGzduxOrVq1FYSJslEUIIISRCtKzKLVu24Pnnnxf+vXDhQgDAu+++i0svvRQAUFNTg56e2Hkvd9xxBzweD2666SZ0d3dj+fLl2LFjB6zW2BLMF198ERs3bsRll10GvV6Pa665Bo8++qhYX4MQQgghCqTjB9vBSuXcbjccDgd6enqQnk5nxxBCCCFKMJb2WzbLxAkhhBBCEoUCHEIIIYSoDgU4hBBCCFEdCnAIIYQQojoU4BBCCCFEdSjAIYQQQojqUIBDCCGEENWhAIcQQgghqkMBDiGEEEJUR7SjGuSMbd7sdrslLgkhhBBCRou126M5hEGTAU5vby8AoLi4WOKSEEIIIWSsent74XA4hr1Gk2dRcRyHpqYmpKWlQafTJfS93W43iouL0djYSOdciYh+z8lBv+fkoN9zctDvOXnE+l3zPI/e3l4UFhZCrx8+y0aTIzh6vR5FRUWifkZ6ejr9ASUB/Z6Tg37PyUG/5+Sg33PyiPG7HmnkhqEkY0IIIYSoDgU4hBBCCFEdCnASzGKx4N5774XFYpG6KKpGv+fkoN9zctDvOTno95w8cvhdazLJmBBCCCHqRiM4hBBCCFEdCnAIIYQQojoU4BBCCCFEdSjAIYQQQojqUICTQNu2bUNZWRmsVisqKiqwb98+qYukKlu3bsVFF12EtLQ05OXlYdWqVaipqZG6WKr3i1/8AjqdDrfccovURVGls2fP4vrrr0d2djZsNhvmzp2LTz75ROpiqUo4HMY999yDyZMnw2azYerUqbj//vtHdZ4RGdr777+Pq666CoWFhdDpdHjttdcG/JzneWzZsgUFBQWw2WyorKzEiRMnklY+CnAS5JVXXsGmTZtw77334uDBg5g/fz6qqqrQ2toqddFU47333sPNN9+Mjz/+GDt37kQwGMSKFSvg8XikLppq7d+/H7///e8xb948qYuiSl1dXbj44othMpnwv//7v/jiiy/w61//GpmZmVIXTVUeeughPP744/jd736Ho0eP4qGHHsLDDz+Mxx57TOqiKZrH48H8+fOxbdu2QX/+8MMP49FHH8X27duxd+9epKSkoKqqCj6fLzkF5ElCLFmyhL/55puFf4fDYb6wsJDfunWrhKVSt9bWVh4A/95770ldFFXq7e3lp0+fzu/cuZO/5JJL+J/85CdSF0l17rzzTn758uVSF0P1rrzySv773//+gOe++c1v8mvWrJGoROoDgH/11VeFf3McxzudTv6Xv/yl8Fx3dzdvsVj4P/7xj0kpE43gJEAgEMCBAwdQWVkpPKfX61FZWYk9e/ZIWDJ16+npAQBkZWVJXBJ1uvnmm3HllVcOuK9JYr3++utYvHgxrr32WuTl5WHhwoV48sknpS6W6ixbtgy7du3C8ePHAQCfffYZPvjgA1x++eUSl0y9amtr4XK5BtQfDocDFRUVSWsXNXnYZqK1t7cjHA4jPz9/wPP5+fk4duyYRKVSN47jcMstt+Diiy/GnDlzpC6O6rz88ss4ePAg9u/fL3VRVO306dN4/PHHsWnTJvzHf/wH9u/fjx//+Mcwm8343ve+J3XxVOOuu+6C2+1GeXk5DAYDwuEwHnjgAaxZs0bqoqmWy+UCgEHbRfYzsVGAQxTp5ptvxuHDh/HBBx9IXRTVaWxsxE9+8hPs3LkTVqtV6uKoGsdxWLx4MR588EEAwMKFC3H48GFs376dApwE+tOf/oQXX3wRL730EmbPno3q6mrccsstKCwspN+zitEUVQLk5OTAYDCgpaVlwPMtLS1wOp0SlUq9Nm7ciDfffBPvvvsuioqKpC6O6hw4cACtra248MILYTQaYTQa8d577+HRRx+F0WhEOByWuoiqUVBQgFmzZg14bubMmWhoaJCoROp0++2346677sLq1asxd+5cfPe738Wtt96KrVu3Sl001WJtn5TtIgU4CWA2m7Fo0SLs2rVLeI7jOOzatQtLly6VsGTqwvM8Nm7ciFdffRXvvPMOJk+eLHWRVOmyyy7DoUOHUF1dLTwWL16MNWvWoLq6GgaDQeoiqsbFF1983lYHx48fR2lpqUQlUiev1wu9fmBzZzAYwHGcRCVSv8mTJ8PpdA5oF91uN/bu3Zu0dpGmqBJk06ZN+N73vofFixdjyZIl+O1vfwuPx4N169ZJXTTVuPnmm/HSSy/hr3/9K9LS0oR5XIfDAZvNJnHp1CMtLe28vKaUlBRkZ2dTvlOC3XrrrVi2bBkefPBBfPvb38a+ffvwxBNP4IknnpC6aKpy1VVX4YEHHkBJSQlmz56NTz/9FI888gi+//3vS100Revr68PJkyeFf9fW1qK6uhpZWVkoKSnBLbfcgp///OeYPn06Jk+ejHvuuQeFhYVYtWpVcgqYlLVaGvHYY4/xJSUlvNls5pcsWcJ//PHHUhdJVQAM+nj22WelLprq0TJx8bzxxhv8nDlzeIvFwpeXl/NPPPGE1EVSHbfbzf/kJz/hS0pKeKvVyk+ZMoX/z//8T97v90tdNEV79913B62Tv/e97/E8H1kqfs899/D5+fm8xWLhL7vsMr6mpiZp5dPxPG3lSAghhBB1oRwcQgghhKgOBTiEEEIIUR0KcAghhBCiOhTgEEIIIUR1KMAhhBBCiOpQgEMIIYQQ1aEAhxBCCCGqQwEOIYQQQlSHAhxCCCGEqA4FOIQQQghRHQpwCCGEEKI6FOAQQgghRHX+PyzftjdAlFJmAAAAAElFTkSuQmCC\n", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ - "import jax.numpy as jnp\n", + "import matplotlib.pyplot as plt\n", "\n", "x_jnp = jnp.linspace(0, 10, 1000)\n", "y_jnp = 2 * jnp.sin(x_jnp) * jnp.cos(x_jnp)\n", @@ -96,57 +99,54 @@ }, { "cell_type": "markdown", + "metadata": {}, + "source": [ + "The code blocks are identical to what you would expect with NumPy, aside from replacing `np` with `jnp`, and the results are the same. As we can see, JAX arrays can often be used directly in place of NumPy arrays for things like plotting." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The arrays themselves are implemented as different Python types:" + ] + }, + { + "cell_type": "code", + "execution_count": null, "metadata": { - "id": "kTZcsCJiuPG8" + "id": "kZaOXL7-uvUP", + "outputId": "7fd4dd8e-4194-4983-ac6b-28059f8feb90" }, + "outputs": [], "source": [ - "The code blocks are identical aside from replacing `np` with `jnp`, and the results are the same. As we can see, JAX arrays can often be used directly in place of NumPy arrays for things like plotting.\n", + "import numpy as np\n", + "import jax.numpy as jnp\n", "\n", - "The arrays themselves are implemented as different Python types:" + "x_np = np.linspace(0, 10, 1000)\n", + "x_jnp = jnp.linspace(0, 10, 1000)" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": { "id": "PjFFunI7xNe8", "outputId": "d3b0007e-7997-45c0-d4b8-9f5699cedcbc" }, - "outputs": [ - { - "data": { - "text/plain": [ - "numpy.ndarray" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "type(x_np)" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": { "id": "kpv5K7QYxQnX", "outputId": "ba68a1de-f938-477d-9942-83a839aeca09" }, - "outputs": [ - { - "data": { - "text/plain": [ - "jaxlib.xla_extension.ArrayImpl" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "type(x_jnp)" ] @@ -157,29 +157,19 @@ "id": "Mx94Ri7euEZm" }, "source": [ - "Python's [duck-typing](https://en.wikipedia.org/wiki/Duck_typing) allows JAX arrays and NumPy arrays to be used interchangeably in many places.\n", - "\n", - "However, there is one important difference between JAX and NumPy arrays: JAX arrays are immutable, meaning that once created their contents cannot be changed.\n", + "Python's duck-typing allows JAX arrays and NumPy arrays to be used interchangeably in many places. However, there is one important difference between JAX and NumPy arrays: JAX arrays are immutable, meaning that once created their contents cannot be changed.\n", "\n", "Here is an example of mutating an array in NumPy:" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": { "id": "fzp-y1ZVyGD4", "outputId": "6eb76bf8-0edd-43a5-b2be-85a79fb23190" }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[10 1 2 3 4 5 6 7 8 9]\n" - ] - } - ], + "outputs": [], "source": [ "# NumPy: mutable arrays\n", "x = np.arange(10)\n", @@ -198,27 +188,19 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": { "id": "l2AP0QERb0P7", "outputId": "528a8e5f-538f-4739-fe95-1c3605ba8c8a" }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Exception reporting mode: Minimal\n" - ] - } - ], + "outputs": [], "source": [ "%xmode minimal" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": { "id": "pCPX0JR-yM4i", "outputId": "c7bf4afd-8b7f-4dac-d065-8189679861d6", @@ -226,16 +208,7 @@ "raises-exception" ] }, - "outputs": [ - { - "ename": "TypeError", - "evalue": "ignored", - "output_type": "error", - "traceback": [ - "" - ] - } - ], + "outputs": [], "source": [ "# JAX: immutable arrays\n", "x = jnp.arange(10)\n", @@ -248,26 +221,17 @@ "id": "yRYF0YgO3F4H" }, "source": [ - "For updating individual elements, JAX provides an [indexed update syntax](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax-numpy-ndarray-at) that returns an updated copy:" + "For updating individual elements, JAX provides an [indexed update syntax](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax-numpy-ndarray-at) that returns an updated copy:" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": { "id": "8zqPEAeP3UK5", "outputId": "20a40c26-3419-4e60-bd2c-83ad30bd7650" }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[0 1 2 3 4 5 6 7 8 9]\n", - "[10 1 2 3 4 5 6 7 8 9]\n" - ] - } - ], + "outputs": [], "source": [ "y = x.at[0].set(10)\n", "print(x)\n", @@ -276,204 +240,105 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "886BGDPeyXCu" - }, + "metadata": {}, "source": [ - "## NumPy, lax & XLA: JAX API layering\n", - "\n", - "**Key concepts:**\n", + "You'll find a few differences between JAX arrays and NumPy arrays once you begin digging in. See also:\n", "\n", - "- `jax.numpy` is a high-level wrapper that provides a familiar interface.\n", - "- `jax.lax` is a lower-level API that is stricter and often more powerful.\n", - "- All JAX operations are implemented in terms of operations in [XLA](https://www.tensorflow.org/xla/) – the Accelerated Linear Algebra compiler." + "- [Key concepts](https://docs.jax.dev/en/latest/key-concepts.html#jax-arrays-jax-array) for an introduction to the key concepts of JAX, such as transformations, tracing, jaxprs and pytrees.\n", + "- [🔪 JAX - The Sharp Bits 🔪](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html) for common gotchas when using JAX." ] }, { "cell_type": "markdown", - "metadata": { - "id": "BjE4m2sZy4hh" - }, + "metadata": {}, "source": [ - "If you look at the source of `jax.numpy`, you'll see that all the operations are eventually expressed in terms of functions defined in `jax.lax`. You can think of `jax.lax` as a stricter, but often more powerful, API for working with multi-dimensional arrays.\n", + "## JAX arrays (`jax.Array`)\n", + "\n", + "**Key concepts:**\n", + "- Create arrays using JAX API functions.\n", + "- JAX array objects have a `devices` attribute that indicates where the array is stored.\n", + "- JAX arrays can be *sharded* across multiple devices for parallel computation.\n", + "\n", + "The default array implementation in JAX is [`jax.Array`](https://docs.jax.dev/en/latest/_autosummary/jax.Array.html#jax.Array). In many ways it is similar to\n", + "the [`numpy.ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) type that you may be familiar with from the NumPy package, but it\n", + "has some important differences.\n", "\n", - "For example, while `jax.numpy` will implicitly promote arguments to allow operations between mixed data types, `jax.lax` will not:" + "### Array creation\n", + "\n", + "We typically don't call the `jax.Array` constructor directly, but rather create arrays via JAX API functions.\n", + "For example, [`jax.numpy`](https://docs.jax.dev/en/latest/jax.numpy.html#module-jax.numpy) provides familiar NumPy-style array construction functionality\n", + "such as `jax.numpy.zeros`, `jax.numpy.linspace`, `jax.numpy.arange`, etc." ] }, { "cell_type": "code", - "execution_count": 9, - "metadata": { - "id": "c6EFPcj12mw0", - "outputId": "827d09eb-c8aa-43bc-b471-0a6c9c4f6601" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array(2., dtype=float32, weak_type=True)" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ + "import jax\n", "import jax.numpy as jnp\n", - "jnp.add(1, 1.0) # jax.numpy API implicitly promotes mixed types." - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "id": "0VkqlcXL2qSp", - "outputId": "7e1e9233-2fe1-46a8-8eb1-1d1dbc54b58c", - "tags": [ - "raises-exception" - ] - }, - "outputs": [ - { - "ename": "TypeError", - "evalue": "ignored", - "output_type": "error", - "traceback": [ - "\u001b[0;31mTypeError\u001b[0m\u001b[0;31m:\u001b[0m lax.add requires arguments to have the same dtypes, got int32, float32. (Tip: jnp.add is a similar function that does automatic type promotion on inputs).\n" - ] - } - ], - "source": [ - "from jax import lax\n", - "lax.add(1, 1.0) # jax.lax API requires explicit type promotion." + "\n", + "x = jnp.arange(5)\n", + "isinstance(x, jax.Array)" ] }, { "cell_type": "markdown", - "metadata": { - "id": "aC9TkXaTEu7A" - }, - "source": [ - "If using `jax.lax` directly, you'll have to do type promotion explicitly in such cases:" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "id": "3PNQlieT81mi", - "outputId": "4bd2b6f3-d2d1-44cb-f8ee-18976ae40239" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array(2., dtype=float32)" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], + "metadata": {}, "source": [ - "lax.add(jnp.float32(1), 1.0)" + "If you use Python type annotations in your code, `jax.Array` is the appropriate\n", + "annotation for jax array objects (see [`jax.typing`](https://docs.jax.dev/en/latest/jax.typing.html#module-jax.typing) for more discussion)." ] }, { "cell_type": "markdown", - "metadata": { - "id": "M3HDuM4x2eTL" - }, + "metadata": {}, "source": [ - "Along with this strictness, `jax.lax` also provides efficient APIs for some more general operations than are supported by NumPy.\n", + "### Array devices and sharding\n", "\n", - "For example, consider a 1D convolution, which can be expressed in NumPy this way:" + "JAX Array objects have a `devices` method that lets you inspect where the contents of the array are stored. In the simplest cases, this will be a single CPU device:" ] }, { "cell_type": "code", - "execution_count": 12, - "metadata": { - "id": "Bv-7XexyzVCN", - "outputId": "d570f64a-ca61-456f-8cab-6cd643cb8ea1" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ - "x = jnp.array([1, 2, 1])\n", - "y = jnp.ones(10)\n", - "jnp.convolve(x, y)" + "x.devices()" ] }, { "cell_type": "markdown", - "metadata": { - "id": "0GPqgT7S0q8r" - }, + "metadata": {}, "source": [ - "Under the hood, this NumPy operation is translated to a much more general convolution implemented by [`lax.conv_general_dilated`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv_general_dilated.html):" + "In general, an array may be [*sharded*](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) across multiple devices, in a manner that can be inspected via the `sharding` attribute:" ] }, { "cell_type": "code", - "execution_count": 13, - "metadata": { - "id": "pi4f6ikjzc3l", - "outputId": "0bb56ae2-7837-4c04-ff8b-6cbc0565b7d7" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from jax import lax\n", - "result = lax.conv_general_dilated(\n", - " x.reshape(1, 1, 3).astype(float), # note: explicit promotion\n", - " y.reshape(1, 1, 10),\n", - " window_strides=(1,),\n", - " padding=[(len(y) - 1, len(y) - 1)]) # equivalent of padding='full' in NumPy\n", - "result[0, 0]" + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "x.sharding" ] }, { "cell_type": "markdown", - "metadata": { - "id": "7mdo6ycczlbd" - }, + "metadata": {}, "source": [ - "This is a batched convolution operation designed to be efficient for the types of convolutions often used in deep neural nets. It requires much more boilerplate, but is far more flexible and scalable than the convolution provided by NumPy (See [Convolutions in JAX](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html) for more detail on JAX convolutions).\n", - "\n", - "At their heart, all `jax.lax` operations are Python wrappers for operations in XLA; here, for example, the convolution implementation is provided by [XLA:ConvWithGeneralPadding](https://www.tensorflow.org/xla/operation_semantics#convwithgeneralpadding_convolution).\n", - "Every JAX operation is eventually expressed in terms of these fundamental XLA operations, which is what enables just-in-time (JIT) compilation." + "Here the array is on a single device, but in general a JAX array can be\n", + "sharded across multiple devices, or even multiple hosts.\n", + "To read more about sharded arrays and parallel computation, refer to [Introduction to parallel programming](https://docs.jax.dev/en/latest/sharded-computation.html)." ] }, { "cell_type": "markdown", - "metadata": { - "id": "NJfWa2PktD5_" - }, + "metadata": {}, "source": [ - "## To JIT or not to JIT\n", + "## Just-in-time compilation with `jax.jit`\n", "\n", "**Key concepts:**\n", "\n", @@ -481,14 +346,19 @@ "- Using a just-in-time (JIT) compilation decorator, sequences of operations can be optimized together and run at once.\n", "- Not all JAX code can be JIT compiled, as it requires array shapes to be static & known at compile time.\n", "\n", - "The fact that all JAX operations are expressed in terms of XLA allows JAX to use the XLA compiler to execute blocks of code very efficiently.\n", - "\n", + "JAX runs transparently on the GPU or TPU (falling back to CPU if you don't have one), with all JAX operations being expressed in terms of XLA. If we have a sequence of operations, we can use the [`jax.jit`](https://docs.jax.dev/en/latest/_autosummary/jax.jit.html) function to compile this sequence of operations together using the XLA compiler." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ "For example, consider this function that normalizes the rows of a 2D matrix, expressed in terms of `jax.numpy` operations:" ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "metadata": { "id": "SQj_UKGc-7kQ" }, @@ -503,19 +373,15 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "0yVo_OKSAolW" - }, + "metadata": {}, "source": [ "A just-in-time compiled version of the function can be created using the `jax.jit` transform:" ] }, { "cell_type": "code", - "execution_count": 15, - "metadata": { - "id": "oHLwGmhZAnCY" - }, + "execution_count": null, + "metadata": {}, "outputs": [], "source": [ "from jax import jit\n", @@ -524,32 +390,16 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "Q3H9ig5GA2Ms" - }, + "metadata": {}, "source": [ "This function returns the same results as the original, up to standard floating-point accuracy:" ] }, { "cell_type": "code", - "execution_count": 16, - "metadata": { - "id": "oz7zzyS3AwMc", - "outputId": "ed1c796c-59f8-4238-f6e2-f54330edadf0" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "np.random.seed(1701)\n", "X = jnp.array(np.random.rand(10000, 10))\n", @@ -558,30 +408,16 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "3GvisB-CA9M8" - }, + "metadata": {}, "source": [ - "But due to the compilation (which includes fusing of operations, avoidance of allocating temporary arrays, and a host of other tricks), execution times can be orders of magnitude faster in the JIT-compiled case (note the use of `block_until_ready()` to account for JAX's [asynchronous dispatch](https://jax.readthedocs.io/en/latest/async_dispatch.html)):" + "But due to the compilation (which includes fusing of operations, avoidance of allocating temporary arrays, and a host of other tricks), execution times can be orders of magnitude faster in the JIT-compiled case. We can use IPython's `%timeit` to quickly benchmark our function, using `block_until_ready()` to account for JAX's [asynchronous dispatch](https://docs.jax.dev/en/latest/async_dispatch.html):" ] }, { "cell_type": "code", - "execution_count": 18, - "metadata": { - "id": "6mUB6VdDAEIY", - "outputId": "1050a69c-e713-44c1-b3eb-1ef875691978" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "815 µs ± 224 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n", - "656 µs ± 10.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n" - ] - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "%timeit norm(X).block_until_ready()\n", "%timeit norm_compiled(X).block_until_ready()" @@ -589,9 +425,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "B1eGBGn0tMba" - }, + "metadata": {}, "source": [ "That said, `jax.jit` does have limitations: in particular, it requires all arrays to have static shapes. That means that some JAX operations are incompatible with JIT compilation.\n", "\n", @@ -600,23 +434,9 @@ }, { "cell_type": "code", - "execution_count": 19, - "metadata": { - "id": "YfZd9mW7CSKM", - "outputId": "6fdbfde4-7cde-447f-badf-26e1f8db288d" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([-0.10570311, -0.59403396, -0.8680282 , -0.23489487], dtype=float32)" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "def get_negatives(x):\n", " return x[x < 0]\n", @@ -627,444 +447,429 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "g6niKxoQC2mZ" - }, + "metadata": {}, "source": [ "But it returns an error if you attempt to execute it in jit mode:" ] }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "metadata": { - "id": "yYWvE4rxCjPK", - "outputId": "9cf7f2d4-8f28-4265-d701-d52086cfd437", "tags": [ "raises-exception" ] }, - "outputs": [ - { - "ename": "NonConcreteBooleanIndexError", - "evalue": "ignored", - "output_type": "error", - "traceback": [ - "\u001b[0;31mNonConcreteBooleanIndexError\u001b[0m\u001b[0;31m:\u001b[0m Array boolean indices must be concrete; got ShapedArray(bool[10])\n\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError\n" - ] - } - ], + "outputs": [], "source": [ "jit(get_negatives)(x)" ] }, { "cell_type": "markdown", - "metadata": { - "id": "vFL6DNpECfVz" - }, + "metadata": {}, "source": [ "This is because the function generates an array whose shape is not known at compile time: the size of the output depends on the values of the input array, and so it is not compatible with JIT." ] }, { "cell_type": "markdown", - "metadata": { - "id": "BzBnKbXwXjLV" - }, + "metadata": {}, "source": [ - "## JIT mechanics: tracing and static variables\n", + "For more on JIT compilation in JAX, check out [Just-in-time compilation](https://docs.jax.dev/en/latest/jit-compilation.html)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Taking derivatives with `jax.grad`\n", "\n", "**Key concepts:**\n", + "- JAX provides automatic differentiation via the `jax.grad` transformation.\n", + "- The `jax.grad` and `jax.jit` transformations compose and can be mixed arbitrarily.\n", "\n", - "- JIT and other JAX transforms work by *tracing* a function to determine its effect on inputs of a specific shape and type.\n", + "In addition to transforming functions via JIT compilation, JAX also provides other transformations. One such transformation is [`jax.grad`](https://docs.jax.dev/en/latest/_autosummary/jax.grad.html), which performs [automatic differentiation (autodiff)](https://en.wikipedia.org/wiki/Automatic_differentiation):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from jax import grad\n", "\n", - "- Variables that you don't want to be traced can be marked as *static*\n", + "def sum_logistic(x):\n", + " return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))\n", "\n", - "To use `jax.jit` effectively, it is useful to understand how it works. Let's put a few `print()` statements within a JIT-compiled function and then call the function:" + "x_small = jnp.arange(3.)\n", + "derivative_fn = grad(sum_logistic)\n", + "print(derivative_fn(x_small))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's verify with finite differences that our result is correct." ] }, { "cell_type": "code", - "execution_count": 21, - "metadata": { - "id": "TfjVIVuD4gnc", - "outputId": "9f4ddcaa-8ab7-4984-afb6-47fede5314ea" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Running f():\n", - " x = Tracedwith\n", - " y = Tracedwith\n", - " result = Tracedwith\n" - ] - }, - { - "data": { - "text/plain": [ - "Array([0.25773212, 5.3623195 , 5.403243 ], dtype=float32)" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ - "@jit\n", - "def f(x, y):\n", - " print(\"Running f():\")\n", - " print(f\" x = {x}\")\n", - " print(f\" y = {y}\")\n", - " result = jnp.dot(x + 1, y + 1)\n", - " print(f\" result = {result}\")\n", - " return result\n", + "def first_finite_differences(f, x, eps=1E-3):\n", + " return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)\n", + " for v in jnp.eye(len(x))])\n", "\n", - "x = np.random.randn(3, 4)\n", - "y = np.random.randn(4)\n", - "f(x, y)" + "print(first_finite_differences(sum_logistic, x_small))" ] }, { "cell_type": "markdown", - "metadata": { - "id": "Ts1fP45A40QV" - }, + "metadata": {}, "source": [ - "Notice that the print statements execute, but rather than printing the data we passed to the function, though, it prints *tracer* objects that stand-in for them.\n", - "\n", - "These tracer objects are what `jax.jit` uses to extract the sequence of operations specified by the function. Basic tracers are stand-ins that encode the **shape** and **dtype** of the arrays, but are agnostic to the values. This recorded sequence of computations can then be efficiently applied within XLA to new inputs with the same shape and dtype, without having to re-execute the Python code.\n", - "\n", - "When we call the compiled function again on matching inputs, no re-compilation is required and nothing is printed because the result is computed in compiled XLA rather than in Python:" + "The [`jax.grad`](https://docs.jax.dev/en/latest/_autosummary/jax.grad.html) and [`jax.jit`](https://docs.jax.dev/en/latest/_autosummary/jax.jit.html) transformations compose and can be mixed arbitrarily.\n", + "For instance, while the `sum_logistic` function was differentiated directly in the previous example, it could also be JIT-compiled, and these operations can be combined. We can go further:" ] }, { "cell_type": "code", - "execution_count": 22, - "metadata": { - "id": "xGntvzNH7skE", - "outputId": "43aaeee6-3853-4b00-fb2b-646df695204a" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([1.4344584, 4.3004413, 7.9897013], dtype=float32)" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ - "x2 = np.random.randn(3, 4)\n", - "y2 = np.random.randn(4)\n", - "f(x2, y2)" + "print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))" ] }, { "cell_type": "markdown", - "metadata": { - "id": "9EB9WkRX7fm0" - }, + "metadata": {}, "source": [ - "The extracted sequence of operations is encoded in a JAX expression, or *jaxpr* for short. You can view the jaxpr using the `jax.make_jaxpr` transformation:" + "Beyond scalar-valued functions, the [`jax.jacobian`](https://docs.jax.dev/en/latest/_autosummary/jax.jacobian.html) transformation can be\n", + "used to compute the full Jacobian matrix for vector-valued functions:" ] }, { "cell_type": "code", - "execution_count": 23, - "metadata": { - "id": "89TMp_Op5-JZ", - "outputId": "48212815-059a-4af1-de82-cd39ecac264a" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "{ lambda ; a:f32[3,4] b:f32[4]. let\n", - " c:f32[3,4] = add a 1.0\n", - " d:f32[4] = add b 1.0\n", - " e:f32[3] = dot_general[\n", - " dimension_numbers=(([1], [0]), ([], []))\n", - " preferred_element_type=float32\n", - " ] c d\n", - " in (e,) }" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from jax import make_jaxpr\n", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from jax import jacobian\n", + "print(jacobian(jnp.exp)(x_small))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For more advanced autodiff operations, you can use [`jax.vjp`](https://docs.jax.dev/en/latest/_autosummary/jax.vjp.html) for reverse-mode vector-Jacobian products,\n", + "and [`jax.jvp`](https://docs.jax.dev/en/latest/_autosummary/jax.jvp.html) and [`jax.linearize`](https://docs.jax.dev/en/latest/_autosummary/jax.linearize.html) for forward-mode Jacobian-vector products.\n", + "The two can be composed arbitrarily with one another, and with other JAX transformations.\n", + "For example, `jax.jvp` and `jax.vjp` are used to define the forward-mode [`jax.jacfwd`](https://docs.jax.dev/en/latest/_autosummary/jax.jacfwd.html) and reverse-mode [`jax.jacrev`](https://docs.jax.dev/en/latest/_autosummary/jax.jacrev.html) for computing Jacobians in forward- and reverse-mode, respectively.\n", + "Here's one way to compose them to make a function that efficiently computes full Hessian matrices:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from jax import jacfwd, jacrev\n", + "def hessian(fun):\n", + " return jit(jacfwd(jacrev(fun)))\n", + "print(hessian(sum_logistic)(x_small))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This kind of composition produces efficient code in practice; this is more-or-less how JAX's built-in [`jax.hessian`](https://docs.jax.dev/en/latest/_autosummary/jax.hessian.html) function is implemented.\n", + "\n", + "For more on automatic differentiation in JAX, check out [Automatic differentiation](https://docs.jax.dev/en/latest/automatic-differentiation.html)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Auto-vectorization with `jax.vmap`\n", "\n", - "def f(x, y):\n", - " return jnp.dot(x + 1, y + 1)\n", + "**Key concepts:**\n", + "- JAX provides automatic vectorization via the [`jax.vmap`](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html) transformation.\n", + "- `jax.vmap` can be composed with `jax.jit` to produce efficient vectorized code.\n", "\n", - "make_jaxpr(f)(x, y)" + "Another useful transformation is [`jax.vmap`](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html), the vectorizing map.\n", + "It has the familiar semantics of mapping a function along array axes, but instead of explicitly looping\n", + "over function calls, it transforms the function into a natively vectorized version for better performance.\n", + "When composed with [`jax.jit`](https://docs.jax.dev/en/latest/_autosummary/jax.jit.html), it can be just as performant as manually rewriting your function\n", + "to operate over an extra batch dimension.\n", + "\n", + "We're going to work with a simple example, and promote matrix-vector products into matrix-matrix products using [`jax.vmap`](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html).\n", + "Although this is easy to do by hand in this specific case, the same technique can apply to more complicated functions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from jax import random\n", + "\n", + "key = random.key(1701)\n", + "key1, key2 = random.split(key)\n", + "mat = random.normal(key1, (150, 100))\n", + "batched_x = random.normal(key2, (10, 100))\n", + "\n", + "def apply_matrix(x):\n", + " return jnp.dot(mat, x)" ] }, { "cell_type": "markdown", - "metadata": { - "id": "0Oq9S4MZ90TL" - }, + "metadata": {}, "source": [ - "Note one consequence of this: because JIT compilation is done *without* information on the content of the array, control flow statements in the function cannot depend on traced values. For example, this fails:" + "The `apply_matrix` function maps a vector to a vector, but we may want to apply it row-wise across a matrix.\n", + "We could do this by looping over the batch dimension in Python, but this usually results in poor performance." ] }, { "cell_type": "code", - "execution_count": 24, - "metadata": { - "id": "A0rFdM95-Ix_", - "outputId": "e37bf04e-6a6a-4536-e423-f082f52d5f11", - "tags": [ - "raises-exception" - ] - }, - "outputs": [ - { - "ename": "TracerBoolConversionError", - "evalue": "ignored", - "output_type": "error", - "traceback": [ - "\u001b[0;31mTracerBoolConversionError\u001b[0m\u001b[0;31m:\u001b[0m Attempted boolean conversion of traced array with shape bool[]..\nThe error occurred while tracing the function f at :1 for jit. This concrete value was not available in Python because it depends on the value of the argument neg.\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError\n" - ] - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ - "@jit\n", - "def f(x, neg):\n", - " return -x if neg else x\n", + "def naively_batched_apply_matrix(v_batched):\n", + " return jnp.stack([apply_matrix(v) for v in v_batched])\n", "\n", - "f(1, True)" + "print('Naively batched')\n", + "%timeit naively_batched_apply_matrix(batched_x).block_until_ready()" ] }, { "cell_type": "markdown", - "metadata": { - "id": "DkTO9m8j-TYI" - }, + "metadata": {}, "source": [ - "If there are variables that you would not like to be traced, they can be marked as static for the purposes of JIT compilation:" + "A programmer familiar with the `jnp.dot` function might recognize that `apply_matrix` can\n", + "be rewritten to avoid explicit looping, using the built-in batching semantics of `jnp.dot`:" ] }, { "cell_type": "code", - "execution_count": 25, - "metadata": { - "id": "K1C7ZnVv-lbv", - "outputId": "e9d6cce3-b036-43da-ad99-887af9625ab0" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array(-1, dtype=int32, weak_type=True)" - ] - }, - "execution_count": 25, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from functools import partial\n", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", "\n", - "@partial(jit, static_argnums=(1,))\n", - "def f(x, neg):\n", - " return -x if neg else x\n", + "@jit\n", + "def batched_apply_matrix(batched_x):\n", + " return jnp.dot(batched_x, mat.T)\n", "\n", - "f(1, True)" + "np.testing.assert_allclose(naively_batched_apply_matrix(batched_x),\n", + " batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4)\n", + "print('Manually batched')\n", + "%timeit batched_apply_matrix(batched_x).block_until_ready()" ] }, { "cell_type": "markdown", - "metadata": { - "id": "dD7p4LRsGzhx" - }, + "metadata": {}, "source": [ - "Note that calling a JIT-compiled function with a different static argument results in re-compilation, so the function still works as expected:" + "However, as functions become more complicated, this kind of manual batching becomes more difficult and error-prone.\n", + "The `jax.vmap` transformation is designed to automatically transform a function into a batch-aware version:" ] }, { "cell_type": "code", - "execution_count": 26, - "metadata": { - "id": "sXqczBOrG7-w", - "outputId": "5fb7c278-b87e-4a6b-ef50-5e4e9c765b52" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array(1, dtype=int32, weak_type=True)" - ] - }, - "execution_count": 26, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ - "f(1, False)" + "from jax import vmap\n", + "\n", + "@jit\n", + "def vmap_batched_apply_matrix(batched_x):\n", + " return vmap(apply_matrix)(batched_x)\n", + "\n", + "np.testing.assert_allclose(naively_batched_apply_matrix(batched_x),\n", + " vmap_batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4)\n", + "print('Auto-vectorized with vmap')\n", + "%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()" ] }, { "cell_type": "markdown", - "metadata": { - "id": "ZESlrDngGVb1" - }, + "metadata": {}, "source": [ - "Understanding which values and operations will be static and which will be traced is a key part of using `jax.jit` effectively." + "As you would expect, `jax.vmap` can be arbitrarily composed with `jax.jit`,\n", + "`jax.grad`, and any other JAX transformation.\n", + "\n", + "For more on automatic vectorization in JAX, check out [Automatic vectorization](https://docs.jax.dev/en/latest/automatic-vectorization.html)." ] }, { "cell_type": "markdown", - "metadata": { - "id": "r-RCl_wD5lI7" - }, + "metadata": {}, "source": [ - "## Static vs traced operations\n", + "(key-concepts-prngs)=\n", + "## Pseudorandom numbers\n", "\n", "**Key concepts:**\n", "\n", - "- Just as values can be either static or traced, operations can be static or traced.\n", + "- JAX uses a different model for pseudo random number generation than NumPy.\n", + "- JAX random functions consume a random `key` that must be split to generate new independent keys.\n", + "- JAX's random key model is thread-safe and avoids issues with global state.\n", "\n", - "- Static operations are evaluated at compile-time in Python; traced operations are compiled & evaluated at run-time in XLA.\n", + "Generally, JAX strives to be compatible with NumPy, but pseudo random number generation is a notable exception. NumPy supports a method of pseudo random number generation that is based on a global `state`, which can be set using [`numpy.random.seed`](https://numpy.org/doc/stable/reference/random/generated/numpy.random.seed.html). Global random state interacts poorly with JAX's compute model and makes it difficult to enforce reproducibility across different threads, processes, and devices. JAX instead tracks state explicitly via a random `key`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from jax import random\n", "\n", - "- Use `numpy` for operations that you want to be static; use `jax.numpy` for operations that you want to be traced.\n", + "key = random.key(43)\n", + "print(key)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The key is effectively a stand-in for NumPy's hidden state object, but we pass it explicitly to [`jax.random`](https://docs.jax.dev/en/latest/jax.random.html) functions. Importantly, random functions consume the key, but do not modify it: feeding the same key object to a random function will always result in the same sample being generated." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(random.normal(key))\n", + "print(random.normal(key))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**The rule of thumb is: never reuse keys (unless you want identical outputs).**\n", "\n", - "This distinction between static and traced values makes it important to think about how to keep a static value static. Consider this function:" + "In order to generate different and independent samples, you must [`jax.random.split`](https://docs.jax.dev/en/latest/_autosummary/jax.random.split.html) the key explicitly before passing it to a random function:" ] }, { "cell_type": "code", - "execution_count": 27, - "metadata": { - "id": "XJCQ7slcD4iU", - "outputId": "3646dea0-f6b6-48e9-9dc0-c4dec7816b7a", - "tags": [ - "raises-exception" - ] - }, - "outputs": [ - { - "ename": "TypeError", - "evalue": "ignored", - "output_type": "error", - "traceback": [ - "\u001b[0;31mTypeError\u001b[0m\u001b[0;31m:\u001b[0m Shapes must be 1D sequences of concrete values of integer type, got [Tracedwith].\nIf using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.\nThe error occurred while tracing the function f at :4 for jit. This value became a tracer due to JAX operations on these lines:\n\n operation a:i32[2] = convert_element_type[new_dtype=int32 weak_type=False] b\n from line :6 (f)\n" - ] - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ - "import jax.numpy as jnp\n", - "from jax import jit\n", + "for i in range(3):\n", + " new_key, subkey = random.split(key)\n", + " del key # The old key is consumed by split() -- we must never use it again.\n", "\n", - "@jit\n", - "def f(x):\n", - " return x.reshape(jnp.array(x.shape).prod())\n", + " val = random.normal(subkey)\n", + " del subkey # The subkey is consumed by normal().\n", "\n", - "x = jnp.ones((2, 3))\n", - "f(x)" + " print(f\"draw {i}: {val}\")\n", + " key = new_key # new_key is safe to use in the next iteration." ] }, { "cell_type": "markdown", - "metadata": { - "id": "ZO3GMGrHBZDS" - }, + "metadata": {}, "source": [ - "This fails with an error specifying that a tracer was found instead of a 1D sequence of concrete values of integer type. Let's add some print statements to the function to understand why this is happening:" + "Note that this code is thread safe, since the local random state eliminates possible race conditions involving global state. `jax.random.split` is a deterministic function that converts one `key` into several independent (in the pseudorandomness sense) keys.\n", + "\n", + "For more on pseudo random numbers in JAX, see the [Pseudorandom numbers tutorial](https://docs.jax.dev/en/latest/random-numbers.html)." + ] + }, + { + "cell_type": "markdown", + "id": "b79e0c62", + "metadata": {}, + "source": [ + "## Debugging\n", + "\n", + "Debugging JAX code can be challenging due to its functional programming model and the fact that JAX code is often transformed via JIT compilation or vectorization. However, JAX provides several tools to help with debugging.\n", + "\n", + "### `jax.debug.print`\n", + "\n", + "For simple inspection, use [`jax.debug.print`](https://docs.jax.dev/en/latest/_autosummary/jax.debug.print.html).\n", + "\n", + "Python's built-in `print` executes at trace-time, before the runtime values exist. Because of this, `print` will only show tracer values within `jax.jit`-decorated code." ] }, { "cell_type": "code", - "execution_count": 28, - "metadata": { - "id": "Cb4mbeVZEi_q", - "outputId": "30d8621f-34e1-4e1d-e6c4-c3e0d8769ec4" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "x = Tracedwith\n", - "x.shape = (2, 3)\n", - "jnp.array(x.shape).prod() = Tracedwith\n" - ] - } - ], + "execution_count": null, + "id": "61675ec9", + "metadata": {}, + "outputs": [], "source": [ - "@jit\n", + "import jax\n", + "import jax.numpy as jnp\n", + "\n", + "@jax.jit\n", "def f(x):\n", - " print(f\"x = {x}\")\n", - " print(f\"x.shape = {x.shape}\")\n", - " print(f\"jnp.array(x.shape).prod() = {jnp.array(x.shape).prod()}\")\n", - " # comment this out to avoid the error:\n", - " # return x.reshape(jnp.array(x.shape).prod())\n", + " print(\"print(x) ->\", x)\n", + " y = jnp.sin(x)\n", + " print(\"print(y) ->\", y)\n", + " return y\n", "\n", - "f(x)" + "result = f(2.)" ] }, { "cell_type": "markdown", - "metadata": { - "id": "viSQPc3jEwJr" - }, + "id": "a34c34bb", + "metadata": {}, "source": [ - "Notice that although `x` is traced, `x.shape` is a static value. However, when we use `jnp.array` and `jnp.prod` on this static value, it becomes a traced value, at which point it cannot be used in a function like `reshape()` that requires a static input (recall: array shapes must be static).\n", - "\n", - "A useful pattern is to use `numpy` for operations that should be static (i.e. done at compile-time), and use `jax.numpy` for operations that should be traced (i.e. compiled and executed at run-time). For this function, it might look like this:" + "If you want to print the actual runtime values, you can use `jax.debug.print`:" ] }, { "cell_type": "code", - "execution_count": 29, - "metadata": { - "id": "GiovOOPcGJhg", - "outputId": "5363ad1b-23d9-4dd6-d9db-95a6c9de05da" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([1., 1., 1., 1., 1., 1.], dtype=float32)" - ] - }, - "execution_count": 29, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "id": "49b5cb05", + "metadata": {}, + "outputs": [], "source": [ - "from jax import jit\n", - "import jax.numpy as jnp\n", - "import numpy as np\n", - "\n", - "@jit\n", + "@jax.jit\n", "def f(x):\n", - " return x.reshape((np.prod(x.shape),))\n", + " jax.debug.print(\"jax.debug.print(x) -> {x}\", x=x)\n", + " y = jnp.sin(x)\n", + " jax.debug.print(\"jax.debug.print(y) -> {y}\", y=y)\n", + " return y\n", "\n", - "f(x)" + "result = f(2.)" ] }, { "cell_type": "markdown", - "metadata": { - "id": "C-QZ5d1DG-dv" - }, + "id": "515495d4", + "metadata": {}, "source": [ - "For this reason, a standard convention in JAX programs is to `import numpy as np` and `import jax.numpy as jnp` so that both interfaces are available for finer control over whether operations are performed in a static manner (with `numpy`, once at compile-time) or a traced manner (with `jax.numpy`, optimized at run-time)." + "### Debugging flags\n", + "\n", + "JAX offers flags and context managers that enable catching errors more easily. For example, you can enable the `jax.debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code. You can also enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`.\n", + "\n", + "For more details, see [Introduction to debugging](https://docs.jax.dev/en/latest/debugging.html).\n", + "\n", + "---\n", + "\n", + "This is just a taste of what JAX can do. We're really excited to see what you do with it!" ] } ], @@ -1078,6 +883,7 @@ }, "kernelspec": { "display_name": "Python 3", + "language": "python", "name": "python3" }, "language_info": { @@ -1090,7 +896,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.6" + "version": "3.11.13" } }, "nbformat": 4, diff --git a/docs/notebooks/thinking_in_jax.md b/docs/notebooks/thinking_in_jax.md index 0693f6ba8579..77f8797c4f5f 100644 --- a/docs/notebooks/thinking_in_jax.md +++ b/docs/notebooks/thinking_in_jax.md @@ -8,18 +8,44 @@ jupytext: jupytext_version: 1.16.4 kernelspec: display_name: Python 3 + language: python name: python3 --- +++ {"id": "LQHmwePqryRU"} -# How to think in JAX +# Quickstart: How to think in JAX - + [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) -JAX provides a simple and powerful API for writing accelerated numerical code, but working effectively in JAX sometimes requires extra consideration. This document is meant to help build a ground-up understanding of how JAX operates, so that you can use it more effectively. +**JAX is a library for array-oriented numerical computation (*à la* [NumPy](https://numpy.org/)), with automatic differentiation and JIT compilation to enable high-performance machine learning research**. + ++++ + +This document provides a quick overview of essential JAX features, so you can get started with JAX: + +* JAX provides a unified NumPy-like interface to computations that run on CPU, GPU, or TPU, in local or distributed settings. +* JAX features built-in Just-In-Time (JIT) compilation via [Open XLA](https://github.com/openxla), an open-source machine learning compiler ecosystem. +* JAX functions support efficient evaluation of gradients via its automatic differentiation transformations. +* JAX functions can be automatically vectorized to efficiently map them over arrays representing batches of inputs. + ++++ + +## Installation + +JAX can be installed for CPU on Linux, Windows, and macOS directly from the [Python Package Index](https://pypi.org/project/jax/): + +``` +pip install jax +``` +or, for NVIDIA GPU: + +``` +pip install -U "jax[cuda13]" +``` +For more detailed platform-specific installation information, check out [Installation](https://docs.jax.dev/en/latest/installation.html). +++ {"id": "nayIExVUtsVD"} @@ -28,40 +54,44 @@ JAX provides a simple and powerful API for writing accelerated numerical code, b **Key concepts:** - JAX provides a NumPy-inspired interface for convenience. -- Through duck-typing, JAX arrays can often be used as drop-in replacements of NumPy arrays. +- Through [duck-typing](https://en.wikipedia.org/wiki/Duck_typing), JAX arrays can often be used as drop-in replacements of NumPy arrays. - Unlike NumPy arrays, JAX arrays are always immutable. -NumPy provides a well-known, powerful API for working with numerical data. For convenience, JAX provides `jax.numpy` which closely mirrors the numpy API and provides easy entry into JAX. Almost anything that can be done with `numpy` can be done with `jax.numpy`: ++++ -```{code-cell} ipython3 -:id: kZaOXL7-uvUP -:outputId: 7fd4dd8e-4194-4983-ac6b-28059f8feb90 +NumPy provides a well-known, powerful API for working with numerical data. For convenience, JAX provides [`jax.numpy`](https://docs.jax.dev/en/latest/jax.numpy.html) which closely mirrors the NumPy API and provides easy entry into JAX. Almost anything that can be done with `numpy` can be done with `jax.numpy`, which is typically imported under the `jnp` alias: -import matplotlib.pyplot as plt -import numpy as np - -x_np = np.linspace(0, 10, 1000) -y_np = 2 * np.sin(x_np) * np.cos(x_np) -plt.plot(x_np, y_np); +```{code-cell} ipython3 +import jax.numpy as jnp ``` -```{code-cell} ipython3 -:id: 18XbGpRLuZlr -:outputId: 3d073b3c-913f-410b-ee33-b3a0eb878436 +With this import, you can immediately use JAX in a similar manner to typical NumPy programs, including using NumPy-style array creation functions, Python functions and operators, and array attributes and methods: -import jax.numpy as jnp +```{code-cell} ipython3 +import matplotlib.pyplot as plt x_jnp = jnp.linspace(0, 10, 1000) y_jnp = 2 * jnp.sin(x_jnp) * jnp.cos(x_jnp) plt.plot(x_jnp, y_jnp); ``` -+++ {"id": "kTZcsCJiuPG8"} +The code blocks are identical to what you would expect with NumPy, aside from replacing `np` with `jnp`, and the results are the same. As we can see, JAX arrays can often be used directly in place of NumPy arrays for things like plotting. -The code blocks are identical aside from replacing `np` with `jnp`, and the results are the same. As we can see, JAX arrays can often be used directly in place of NumPy arrays for things like plotting. ++++ The arrays themselves are implemented as different Python types: +```{code-cell} ipython3 +:id: kZaOXL7-uvUP +:outputId: 7fd4dd8e-4194-4983-ac6b-28059f8feb90 + +import numpy as np +import jax.numpy as jnp + +x_np = np.linspace(0, 10, 1000) +x_jnp = jnp.linspace(0, 10, 1000) +``` + ```{code-cell} ipython3 :id: PjFFunI7xNe8 :outputId: d3b0007e-7997-45c0-d4b8-9f5699cedcbc @@ -78,9 +108,7 @@ type(x_jnp) +++ {"id": "Mx94Ri7euEZm"} -Python's [duck-typing](https://en.wikipedia.org/wiki/Duck_typing) allows JAX arrays and NumPy arrays to be used interchangeably in many places. - -However, there is one important difference between JAX and NumPy arrays: JAX arrays are immutable, meaning that once created their contents cannot be changed. +Python's duck-typing allows JAX arrays and NumPy arrays to be used interchangeably in many places. However, there is one important difference between JAX and NumPy arrays: JAX arrays are immutable, meaning that once created their contents cannot be changed. Here is an example of mutating an array in NumPy: @@ -117,7 +145,7 @@ x[0] = 10 +++ {"id": "yRYF0YgO3F4H"} -For updating individual elements, JAX provides an [indexed update syntax](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax-numpy-ndarray-at) that returns an updated copy: +For updating individual elements, JAX provides an [indexed update syntax](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax-numpy-ndarray-at) that returns an updated copy: ```{code-cell} ipython3 :id: 8zqPEAeP3UK5 @@ -128,92 +156,64 @@ print(x) print(y) ``` -+++ {"id": "886BGDPeyXCu"} +You'll find a few differences between JAX arrays and NumPy arrays once you begin digging in. See also: -## NumPy, lax & XLA: JAX API layering - -**Key concepts:** +- [Key concepts](https://docs.jax.dev/en/latest/key-concepts.html#jax-arrays-jax-array) for an introduction to the key concepts of JAX, such as transformations, tracing, jaxprs and pytrees. +- [🔪 JAX - The Sharp Bits 🔪](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html) for common gotchas when using JAX. -- `jax.numpy` is a high-level wrapper that provides a familiar interface. -- `jax.lax` is a lower-level API that is stricter and often more powerful. -- All JAX operations are implemented in terms of operations in [XLA](https://www.tensorflow.org/xla/) – the Accelerated Linear Algebra compiler. ++++ -+++ {"id": "BjE4m2sZy4hh"} +## JAX arrays (`jax.Array`) -If you look at the source of `jax.numpy`, you'll see that all the operations are eventually expressed in terms of functions defined in `jax.lax`. You can think of `jax.lax` as a stricter, but often more powerful, API for working with multi-dimensional arrays. +**Key concepts:** +- Create arrays using JAX API functions. +- JAX array objects have a `devices` attribute that indicates where the array is stored. +- JAX arrays can be *sharded* across multiple devices for parallel computation. -For example, while `jax.numpy` will implicitly promote arguments to allow operations between mixed data types, `jax.lax` will not: +The default array implementation in JAX is [`jax.Array`](https://docs.jax.dev/en/latest/_autosummary/jax.Array.html#jax.Array). In many ways it is similar to +the [`numpy.ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) type that you may be familiar with from the NumPy package, but it +has some important differences. -```{code-cell} ipython3 -:id: c6EFPcj12mw0 -:outputId: 827d09eb-c8aa-43bc-b471-0a6c9c4f6601 +### Array creation -import jax.numpy as jnp -jnp.add(1, 1.0) # jax.numpy API implicitly promotes mixed types. -``` +We typically don't call the `jax.Array` constructor directly, but rather create arrays via JAX API functions. +For example, [`jax.numpy`](https://docs.jax.dev/en/latest/jax.numpy.html#module-jax.numpy) provides familiar NumPy-style array construction functionality +such as `jax.numpy.zeros`, `jax.numpy.linspace`, `jax.numpy.arange`, etc. ```{code-cell} ipython3 -:id: 0VkqlcXL2qSp -:outputId: 7e1e9233-2fe1-46a8-8eb1-1d1dbc54b58c -:tags: [raises-exception] +import jax +import jax.numpy as jnp -from jax import lax -lax.add(1, 1.0) # jax.lax API requires explicit type promotion. +x = jnp.arange(5) +isinstance(x, jax.Array) ``` -+++ {"id": "aC9TkXaTEu7A"} - -If using `jax.lax` directly, you'll have to do type promotion explicitly in such cases: +If you use Python type annotations in your code, `jax.Array` is the appropriate +annotation for jax array objects (see [`jax.typing`](https://docs.jax.dev/en/latest/jax.typing.html#module-jax.typing) for more discussion). -```{code-cell} ipython3 -:id: 3PNQlieT81mi -:outputId: 4bd2b6f3-d2d1-44cb-f8ee-18976ae40239 - -lax.add(jnp.float32(1), 1.0) -``` - -+++ {"id": "M3HDuM4x2eTL"} ++++ -Along with this strictness, `jax.lax` also provides efficient APIs for some more general operations than are supported by NumPy. +### Array devices and sharding -For example, consider a 1D convolution, which can be expressed in NumPy this way: +JAX Array objects have a `devices` method that lets you inspect where the contents of the array are stored. In the simplest cases, this will be a single CPU device: ```{code-cell} ipython3 -:id: Bv-7XexyzVCN -:outputId: d570f64a-ca61-456f-8cab-6cd643cb8ea1 - -x = jnp.array([1, 2, 1]) -y = jnp.ones(10) -jnp.convolve(x, y) +x.devices() ``` -+++ {"id": "0GPqgT7S0q8r"} - -Under the hood, this NumPy operation is translated to a much more general convolution implemented by [`lax.conv_general_dilated`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv_general_dilated.html): +In general, an array may be [*sharded*](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) across multiple devices, in a manner that can be inspected via the `sharding` attribute: ```{code-cell} ipython3 -:id: pi4f6ikjzc3l -:outputId: 0bb56ae2-7837-4c04-ff8b-6cbc0565b7d7 - -from jax import lax -result = lax.conv_general_dilated( - x.reshape(1, 1, 3).astype(float), # note: explicit promotion - y.reshape(1, 1, 10), - window_strides=(1,), - padding=[(len(y) - 1, len(y) - 1)]) # equivalent of padding='full' in NumPy -result[0, 0] +x.sharding ``` -+++ {"id": "7mdo6ycczlbd"} - -This is a batched convolution operation designed to be efficient for the types of convolutions often used in deep neural nets. It requires much more boilerplate, but is far more flexible and scalable than the convolution provided by NumPy (See [Convolutions in JAX](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html) for more detail on JAX convolutions). +Here the array is on a single device, but in general a JAX array can be +sharded across multiple devices, or even multiple hosts. +To read more about sharded arrays and parallel computation, refer to [Introduction to parallel programming](https://docs.jax.dev/en/latest/sharded-computation.html). -At their heart, all `jax.lax` operations are Python wrappers for operations in XLA; here, for example, the convolution implementation is provided by [XLA:ConvWithGeneralPadding](https://www.tensorflow.org/xla/operation_semantics#convwithgeneralpadding_convolution). -Every JAX operation is eventually expressed in terms of these fundamental XLA operations, which is what enables just-in-time (JIT) compilation. ++++ -+++ {"id": "NJfWa2PktD5_"} - -## To JIT or not to JIT +## Just-in-time compilation with `jax.jit` **Key concepts:** @@ -221,7 +221,9 @@ Every JAX operation is eventually expressed in terms of these fundamental XLA op - Using a just-in-time (JIT) compilation decorator, sequences of operations can be optimized together and run at once. - Not all JAX code can be JIT compiled, as it requires array shapes to be static & known at compile time. -The fact that all JAX operations are expressed in terms of XLA allows JAX to use the XLA compiler to execute blocks of code very efficiently. +JAX runs transparently on the GPU or TPU (falling back to CPU if you don't have one), with all JAX operations being expressed in terms of XLA. If we have a sequence of operations, we can use the [`jax.jit`](https://docs.jax.dev/en/latest/_autosummary/jax.jit.html) function to compile this sequence of operations together using the XLA compiler. + ++++ For example, consider this function that normalizes the rows of a 2D matrix, expressed in terms of `jax.numpy` operations: @@ -235,52 +237,33 @@ def norm(X): return X / X.std(0) ``` -+++ {"id": "0yVo_OKSAolW"} - A just-in-time compiled version of the function can be created using the `jax.jit` transform: ```{code-cell} ipython3 -:id: oHLwGmhZAnCY - from jax import jit norm_compiled = jit(norm) ``` -+++ {"id": "Q3H9ig5GA2Ms"} - This function returns the same results as the original, up to standard floating-point accuracy: ```{code-cell} ipython3 -:id: oz7zzyS3AwMc -:outputId: ed1c796c-59f8-4238-f6e2-f54330edadf0 - np.random.seed(1701) X = jnp.array(np.random.rand(10000, 10)) np.allclose(norm(X), norm_compiled(X), atol=1E-6) ``` -+++ {"id": "3GvisB-CA9M8"} - -But due to the compilation (which includes fusing of operations, avoidance of allocating temporary arrays, and a host of other tricks), execution times can be orders of magnitude faster in the JIT-compiled case (note the use of `block_until_ready()` to account for JAX's [asynchronous dispatch](https://jax.readthedocs.io/en/latest/async_dispatch.html)): +But due to the compilation (which includes fusing of operations, avoidance of allocating temporary arrays, and a host of other tricks), execution times can be orders of magnitude faster in the JIT-compiled case. We can use IPython's `%timeit` to quickly benchmark our function, using `block_until_ready()` to account for JAX's [asynchronous dispatch](https://docs.jax.dev/en/latest/async_dispatch.html): ```{code-cell} ipython3 -:id: 6mUB6VdDAEIY -:outputId: 1050a69c-e713-44c1-b3eb-1ef875691978 - %timeit norm(X).block_until_ready() %timeit norm_compiled(X).block_until_ready() ``` -+++ {"id": "B1eGBGn0tMba"} - That said, `jax.jit` does have limitations: in particular, it requires all arrays to have static shapes. That means that some JAX operations are incompatible with JIT compilation. For example, this operation can be executed in op-by-op mode: ```{code-cell} ipython3 -:id: YfZd9mW7CSKM -:outputId: 6fdbfde4-7cde-447f-badf-26e1f8db288d - def get_negatives(x): return x[x < 0] @@ -288,203 +271,252 @@ x = jnp.array(np.random.randn(10)) get_negatives(x) ``` -+++ {"id": "g6niKxoQC2mZ"} - But it returns an error if you attempt to execute it in jit mode: ```{code-cell} ipython3 -:id: yYWvE4rxCjPK -:outputId: 9cf7f2d4-8f28-4265-d701-d52086cfd437 :tags: [raises-exception] jit(get_negatives)(x) ``` -+++ {"id": "vFL6DNpECfVz"} - This is because the function generates an array whose shape is not known at compile time: the size of the output depends on the values of the input array, and so it is not compatible with JIT. -+++ {"id": "BzBnKbXwXjLV"} ++++ -## JIT mechanics: tracing and static variables +For more on JIT compilation in JAX, check out [Just-in-time compilation](https://docs.jax.dev/en/latest/jit-compilation.html). -**Key concepts:** ++++ -- JIT and other JAX transforms work by *tracing* a function to determine its effect on inputs of a specific shape and type. +## Taking derivatives with `jax.grad` -- Variables that you don't want to be traced can be marked as *static* +**Key concepts:** +- JAX provides automatic differentiation via the `jax.grad` transformation. +- The `jax.grad` and `jax.jit` transformations compose and can be mixed arbitrarily. -To use `jax.jit` effectively, it is useful to understand how it works. Let's put a few `print()` statements within a JIT-compiled function and then call the function: +In addition to transforming functions via JIT compilation, JAX also provides other transformations. One such transformation is [`jax.grad`](https://docs.jax.dev/en/latest/_autosummary/jax.grad.html), which performs [automatic differentiation (autodiff)](https://en.wikipedia.org/wiki/Automatic_differentiation): ```{code-cell} ipython3 -:id: TfjVIVuD4gnc -:outputId: 9f4ddcaa-8ab7-4984-afb6-47fede5314ea +from jax import grad -@jit -def f(x, y): - print("Running f():") - print(f" x = {x}") - print(f" y = {y}") - result = jnp.dot(x + 1, y + 1) - print(f" result = {result}") - return result +def sum_logistic(x): + return jnp.sum(1.0 / (1.0 + jnp.exp(-x))) -x = np.random.randn(3, 4) -y = np.random.randn(4) -f(x, y) +x_small = jnp.arange(3.) +derivative_fn = grad(sum_logistic) +print(derivative_fn(x_small)) ``` -+++ {"id": "Ts1fP45A40QV"} +Let's verify with finite differences that our result is correct. -Notice that the print statements execute, but rather than printing the data we passed to the function, though, it prints *tracer* objects that stand-in for them. +```{code-cell} ipython3 +def first_finite_differences(f, x, eps=1E-3): + return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps) + for v in jnp.eye(len(x))]) -These tracer objects are what `jax.jit` uses to extract the sequence of operations specified by the function. Basic tracers are stand-ins that encode the **shape** and **dtype** of the arrays, but are agnostic to the values. This recorded sequence of computations can then be efficiently applied within XLA to new inputs with the same shape and dtype, without having to re-execute the Python code. +print(first_finite_differences(sum_logistic, x_small)) +``` -When we call the compiled function again on matching inputs, no re-compilation is required and nothing is printed because the result is computed in compiled XLA rather than in Python: +The [`jax.grad`](https://docs.jax.dev/en/latest/_autosummary/jax.grad.html) and [`jax.jit`](https://docs.jax.dev/en/latest/_autosummary/jax.jit.html) transformations compose and can be mixed arbitrarily. +For instance, while the `sum_logistic` function was differentiated directly in the previous example, it could also be JIT-compiled, and these operations can be combined. We can go further: ```{code-cell} ipython3 -:id: xGntvzNH7skE -:outputId: 43aaeee6-3853-4b00-fb2b-646df695204a - -x2 = np.random.randn(3, 4) -y2 = np.random.randn(4) -f(x2, y2) +print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0)) ``` -+++ {"id": "9EB9WkRX7fm0"} +Beyond scalar-valued functions, the [`jax.jacobian`](https://docs.jax.dev/en/latest/_autosummary/jax.jacobian.html) transformation can be +used to compute the full Jacobian matrix for vector-valued functions: + +```{code-cell} ipython3 +from jax import jacobian +print(jacobian(jnp.exp)(x_small)) +``` -The extracted sequence of operations is encoded in a JAX expression, or *jaxpr* for short. You can view the jaxpr using the `jax.make_jaxpr` transformation: +For more advanced autodiff operations, you can use [`jax.vjp`](https://docs.jax.dev/en/latest/_autosummary/jax.vjp.html) for reverse-mode vector-Jacobian products, +and [`jax.jvp`](https://docs.jax.dev/en/latest/_autosummary/jax.jvp.html) and [`jax.linearize`](https://docs.jax.dev/en/latest/_autosummary/jax.linearize.html) for forward-mode Jacobian-vector products. +The two can be composed arbitrarily with one another, and with other JAX transformations. +For example, `jax.jvp` and `jax.vjp` are used to define the forward-mode [`jax.jacfwd`](https://docs.jax.dev/en/latest/_autosummary/jax.jacfwd.html) and reverse-mode [`jax.jacrev`](https://docs.jax.dev/en/latest/_autosummary/jax.jacrev.html) for computing Jacobians in forward- and reverse-mode, respectively. +Here's one way to compose them to make a function that efficiently computes full Hessian matrices: ```{code-cell} ipython3 -:id: 89TMp_Op5-JZ -:outputId: 48212815-059a-4af1-de82-cd39ecac264a +from jax import jacfwd, jacrev +def hessian(fun): + return jit(jacfwd(jacrev(fun))) +print(hessian(sum_logistic)(x_small)) +``` -from jax import make_jaxpr +This kind of composition produces efficient code in practice; this is more-or-less how JAX's built-in [`jax.hessian`](https://docs.jax.dev/en/latest/_autosummary/jax.hessian.html) function is implemented. -def f(x, y): - return jnp.dot(x + 1, y + 1) +For more on automatic differentiation in JAX, check out [Automatic differentiation](https://docs.jax.dev/en/latest/automatic-differentiation.html). -make_jaxpr(f)(x, y) -``` ++++ -+++ {"id": "0Oq9S4MZ90TL"} +## Auto-vectorization with `jax.vmap` -Note one consequence of this: because JIT compilation is done *without* information on the content of the array, control flow statements in the function cannot depend on traced values. For example, this fails: +**Key concepts:** +- JAX provides automatic vectorization via the [`jax.vmap`](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html) transformation. +- `jax.vmap` can be composed with `jax.jit` to produce efficient vectorized code. + +Another useful transformation is [`jax.vmap`](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html), the vectorizing map. +It has the familiar semantics of mapping a function along array axes, but instead of explicitly looping +over function calls, it transforms the function into a natively vectorized version for better performance. +When composed with [`jax.jit`](https://docs.jax.dev/en/latest/_autosummary/jax.jit.html), it can be just as performant as manually rewriting your function +to operate over an extra batch dimension. + +We're going to work with a simple example, and promote matrix-vector products into matrix-matrix products using [`jax.vmap`](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html). +Although this is easy to do by hand in this specific case, the same technique can apply to more complicated functions. ```{code-cell} ipython3 -:id: A0rFdM95-Ix_ -:outputId: e37bf04e-6a6a-4536-e423-f082f52d5f11 -:tags: [raises-exception] +from jax import random -@jit -def f(x, neg): - return -x if neg else x +key = random.key(1701) +key1, key2 = random.split(key) +mat = random.normal(key1, (150, 100)) +batched_x = random.normal(key2, (10, 100)) -f(1, True) +def apply_matrix(x): + return jnp.dot(mat, x) ``` -+++ {"id": "DkTO9m8j-TYI"} - -If there are variables that you would not like to be traced, they can be marked as static for the purposes of JIT compilation: +The `apply_matrix` function maps a vector to a vector, but we may want to apply it row-wise across a matrix. +We could do this by looping over the batch dimension in Python, but this usually results in poor performance. ```{code-cell} ipython3 -:id: K1C7ZnVv-lbv -:outputId: e9d6cce3-b036-43da-ad99-887af9625ab0 +def naively_batched_apply_matrix(v_batched): + return jnp.stack([apply_matrix(v) for v in v_batched]) -from functools import partial +print('Naively batched') +%timeit naively_batched_apply_matrix(batched_x).block_until_ready() +``` -@partial(jit, static_argnums=(1,)) -def f(x, neg): - return -x if neg else x +A programmer familiar with the `jnp.dot` function might recognize that `apply_matrix` can +be rewritten to avoid explicit looping, using the built-in batching semantics of `jnp.dot`: -f(1, True) -``` +```{code-cell} ipython3 +import numpy as np -+++ {"id": "dD7p4LRsGzhx"} +@jit +def batched_apply_matrix(batched_x): + return jnp.dot(batched_x, mat.T) -Note that calling a JIT-compiled function with a different static argument results in re-compilation, so the function still works as expected: +np.testing.assert_allclose(naively_batched_apply_matrix(batched_x), + batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4) +print('Manually batched') +%timeit batched_apply_matrix(batched_x).block_until_ready() +``` + +However, as functions become more complicated, this kind of manual batching becomes more difficult and error-prone. +The `jax.vmap` transformation is designed to automatically transform a function into a batch-aware version: ```{code-cell} ipython3 -:id: sXqczBOrG7-w -:outputId: 5fb7c278-b87e-4a6b-ef50-5e4e9c765b52 +from jax import vmap + +@jit +def vmap_batched_apply_matrix(batched_x): + return vmap(apply_matrix)(batched_x) -f(1, False) +np.testing.assert_allclose(naively_batched_apply_matrix(batched_x), + vmap_batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4) +print('Auto-vectorized with vmap') +%timeit vmap_batched_apply_matrix(batched_x).block_until_ready() ``` -+++ {"id": "ZESlrDngGVb1"} +As you would expect, `jax.vmap` can be arbitrarily composed with `jax.jit`, +`jax.grad`, and any other JAX transformation. -Understanding which values and operations will be static and which will be traced is a key part of using `jax.jit` effectively. +For more on automatic vectorization in JAX, check out [Automatic vectorization](https://docs.jax.dev/en/latest/automatic-vectorization.html). -+++ {"id": "r-RCl_wD5lI7"} ++++ -## Static vs traced operations +(key-concepts-prngs)= +## Pseudorandom numbers **Key concepts:** -- Just as values can be either static or traced, operations can be static or traced. +- JAX uses a different model for pseudo random number generation than NumPy. +- JAX random functions consume a random `key` that must be split to generate new independent keys. +- JAX's random key model is thread-safe and avoids issues with global state. -- Static operations are evaluated at compile-time in Python; traced operations are compiled & evaluated at run-time in XLA. - -- Use `numpy` for operations that you want to be static; use `jax.numpy` for operations that you want to be traced. - -This distinction between static and traced values makes it important to think about how to keep a static value static. Consider this function: +Generally, JAX strives to be compatible with NumPy, but pseudo random number generation is a notable exception. NumPy supports a method of pseudo random number generation that is based on a global `state`, which can be set using [`numpy.random.seed`](https://numpy.org/doc/stable/reference/random/generated/numpy.random.seed.html). Global random state interacts poorly with JAX's compute model and makes it difficult to enforce reproducibility across different threads, processes, and devices. JAX instead tracks state explicitly via a random `key`: ```{code-cell} ipython3 -:id: XJCQ7slcD4iU -:outputId: 3646dea0-f6b6-48e9-9dc0-c4dec7816b7a -:tags: [raises-exception] +from jax import random -import jax.numpy as jnp -from jax import jit +key = random.key(43) +print(key) +``` -@jit -def f(x): - return x.reshape(jnp.array(x.shape).prod()) +The key is effectively a stand-in for NumPy's hidden state object, but we pass it explicitly to [`jax.random`](https://docs.jax.dev/en/latest/jax.random.html) functions. Importantly, random functions consume the key, but do not modify it: feeding the same key object to a random function will always result in the same sample being generated. -x = jnp.ones((2, 3)) -f(x) +```{code-cell} ipython3 +print(random.normal(key)) +print(random.normal(key)) ``` -+++ {"id": "ZO3GMGrHBZDS"} +**The rule of thumb is: never reuse keys (unless you want identical outputs).** -This fails with an error specifying that a tracer was found instead of a 1D sequence of concrete values of integer type. Let's add some print statements to the function to understand why this is happening: +In order to generate different and independent samples, you must [`jax.random.split`](https://docs.jax.dev/en/latest/_autosummary/jax.random.split.html) the key explicitly before passing it to a random function: ```{code-cell} ipython3 -:id: Cb4mbeVZEi_q -:outputId: 30d8621f-34e1-4e1d-e6c4-c3e0d8769ec4 +for i in range(3): + new_key, subkey = random.split(key) + del key # The old key is consumed by split() -- we must never use it again. -@jit -def f(x): - print(f"x = {x}") - print(f"x.shape = {x.shape}") - print(f"jnp.array(x.shape).prod() = {jnp.array(x.shape).prod()}") - # comment this out to avoid the error: - # return x.reshape(jnp.array(x.shape).prod()) + val = random.normal(subkey) + del subkey # The subkey is consumed by normal(). -f(x) + print(f"draw {i}: {val}") + key = new_key # new_key is safe to use in the next iteration. ``` -+++ {"id": "viSQPc3jEwJr"} +Note that this code is thread safe, since the local random state eliminates possible race conditions involving global state. `jax.random.split` is a deterministic function that converts one `key` into several independent (in the pseudorandomness sense) keys. -Notice that although `x` is traced, `x.shape` is a static value. However, when we use `jnp.array` and `jnp.prod` on this static value, it becomes a traced value, at which point it cannot be used in a function like `reshape()` that requires a static input (recall: array shapes must be static). +For more on pseudo random numbers in JAX, see the [Pseudorandom numbers tutorial](https://docs.jax.dev/en/latest/random-numbers.html). -A useful pattern is to use `numpy` for operations that should be static (i.e. done at compile-time), and use `jax.numpy` for operations that should be traced (i.e. compiled and executed at run-time). For this function, it might look like this: ++++ -```{code-cell} ipython3 -:id: GiovOOPcGJhg -:outputId: 5363ad1b-23d9-4dd6-d9db-95a6c9de05da +## Debugging -from jax import jit +Debugging JAX code can be challenging due to its functional programming model and the fact that JAX code is often transformed via JIT compilation or vectorization. However, JAX provides several tools to help with debugging. + +### `jax.debug.print` + +For simple inspection, use [`jax.debug.print`](https://docs.jax.dev/en/latest/_autosummary/jax.debug.print.html). + +Python's built-in `print` executes at trace-time, before the runtime values exist. Because of this, `print` will only show tracer values within `jax.jit`-decorated code. + +```{code-cell} ipython3 +import jax import jax.numpy as jnp -import numpy as np -@jit +@jax.jit def f(x): - return x.reshape((np.prod(x.shape),)) + print("print(x) ->", x) + y = jnp.sin(x) + print("print(y) ->", y) + return y -f(x) +result = f(2.) ``` -+++ {"id": "C-QZ5d1DG-dv"} +If you want to print the actual runtime values, you can use `jax.debug.print`: + +```{code-cell} ipython3 +@jax.jit +def f(x): + jax.debug.print("jax.debug.print(x) -> {x}", x=x) + y = jnp.sin(x) + jax.debug.print("jax.debug.print(y) -> {y}", y=y) + return y + +result = f(2.) +``` + +### Debugging flags + +JAX offers flags and context managers that enable catching errors more easily. For example, you can enable the `jax.debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code. You can also enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`. + +For more details, see [Introduction to debugging](https://docs.jax.dev/en/latest/debugging.html). + +--- -For this reason, a standard convention in JAX programs is to `import numpy as np` and `import jax.numpy as jnp` so that both interfaces are available for finer control over whether operations are performed in a static manner (with `numpy`, once at compile-time) or a traced manner (with `jax.numpy`, optimized at run-time). +This is just a taste of what JAX can do. We're really excited to see what you do with it! diff --git a/docs/notes.rst b/docs/notes.rst index 08265638000e..502385142b16 100644 --- a/docs/notes.rst +++ b/docs/notes.rst @@ -9,9 +9,6 @@ Dependencies and version compatibility: - :doc:`api_compatibility` outlines JAX's policies with regard to API compatibility across releases. - :doc:`deprecation` outlines JAX's policies with regard to compatibility with Python and NumPy. -Migrations and deprecations: - - :doc:`jax_array_migration` summarizes the changes to the default array type in jax v 0.4.1 - Memory and computation usage: - :doc:`async_dispatch` describes JAX's asynchronous dispatch model. - :doc:`concurrency` describes how JAX interacts with other Python concurrency. @@ -20,6 +17,10 @@ Memory and computation usage: Programmer guardrails: - :doc:`rank_promotion_warning` describes how to configure :mod:`jax.numpy` to avoid implicit rank promotion. +Arrays and data types: + - :doc:`type_promotion` describes JAX's implicit type promotion for functions of two or more values. + - :doc:`default_dtypes` describes how JAX determines the default dtype for array creation functions. + .. toctree:: :hidden: @@ -27,8 +28,9 @@ Programmer guardrails: api_compatibility deprecation - jax_array_migration async_dispatch concurrency gpu_memory_allocation - rank_promotion_warning \ No newline at end of file + rank_promotion_warning + type_promotion + default_dtypes diff --git a/docs/pallas/CHANGELOG.md b/docs/pallas/CHANGELOG.md index 2b1cad7c9a66..2721aa3dba89 100644 --- a/docs/pallas/CHANGELOG.md +++ b/docs/pallas/CHANGELOG.md @@ -2,15 +2,111 @@ # Pallas Changelog - + This is the list of changes specific to {class}`jax.experimental.pallas`. -For the overall JAX change log see [here](https://jax.readthedocs.io/en/latest/changelog.html). +For the overall JAX change log see [here](https://docs.jax.dev/en/latest/changelog.html). +## Unreleased + +* Changes + + * The default lowering path on GPU now goes through Mosaic GPU. To keep using + Triton, call {func}`jax.experimental.pallas.pallas_call` with + the `backend` argument set to `'triton'`. + +* Removals + + * Removed the previously deprecated `pl.atomic_*`, `pl.load`, `pl.store`, + `pl.swap` and `pl.max_contiguous`. + +## Released with jax 0.8.1 + +* New features: + + * Added {func}`jax.experimental.pallas.tpu.get_tpu_info` to get TPU hardware information. + +* Deprecations + + * `pl.max_contiguous` has been moved to {mod}`jax.experimental.pallas.triton`. + Accessing it via {mod}`jax.experimental.pallas` is deprecated. + * `pl.swap` is deprecated and will be removed in a future release. Use + indexing or backend-specific loading/storing APIs instead. + +* Removals + + * Removed the previously deprecated + {class}`jax.experimental.pallas.tpu.TPUCompilerParams`, + {class}`jax.experimental.pallas.tpu.TPUMemorySpace`, + {class}`jax.experimental.pallas.tpu.TritonCompilerParams`. + +## Released with jax 0.7.1 + +* New features: + + * `pltpu.make_async_remote_copy` and `pltpu.semaphore_signal`'s `device_id` + argument now allows user to pass in a dictionary that only specifies the + device index along the communication axis, instead of the full coordinates. + It also supports TPU core id index. + * `jax.debug.print` now works in Pallas kernels and is the recommended way to + print. + +* Deprecations + + * `pl.atomic_*` APIs have been moved to {mod}`jax.experimental.pallas.triton`. + Accessing them via {mod}`jax.experimental.pallas` is deprecated. + * `pl.load` and `pl.store` are deprecated. Use indexing or backend-specific + loading/storing APIs instead. + +## Released with jax 0.7.0 + +* New functionality + + * Added a new decorator {func}`jax.experimental.pallas.loop` which allows + to write stateless loops as functions. + * Added new multiple buffering and lookahead functionality to + {func}`jax.experimental.pallas.tpu.emit_pipeline`. Input buffers can now + be multiple-buffered with more than 2 buffers and support a lookahead option + to fetch blocks that are an arbitrary number of grid iterations ahead + rather than the immediate next iterations. Additionally, pipeline state + can now be held in registers to reduce scalar memory usage. + +* Deprecations + + * {class}`jax.experimental.pallas.triton.TritonCompilerParams` has been + renamed to {class}`jax.experimental.pallas.triton.CompilerParams`. The + old name is deprecated and will be removed in a future release. + * {class}`jax.experimental.pallas.tpu.TPUCompilerParams` + and {class}`jax.experimental.pallas.tpu.TPUMemorySpace` have been + renamed to {class}`jax.experimental.pallas.tpu.CompilerParams` + and {class}`jax.experimental.pallas.tpu.MemorySpace`. The + old names are deprecated and will be removed in a future release. + +## Released with jax 0.6.1 + +* Removals + + * Removed previously deprecated {mod}`jax.experimental.pallas.gpu`. To use + the Triton backend import {mod}`jax.experimental.pallas.triton`. + +* Changes + + * {func}`jax.experimental.pallas.BlockSpec` now takes in special types in + addition to ints/None in the `block_shape`. `indexing_mode` has been + removed. To achieve "Unblocked", pass a `pl.Element(size)` into + `block_shape` for each entry that needs unblocked indexing. + * {func}`jax.experimental.pallas.pallas_call` now requires `compiler_params` + to be a backend-specific dataclass instead of a param to value mapping. + * {func}`jax.experimental.pallas.debug_check` is now supported both on + TPU and Mosaic GPU. Previously, this functionality was only supported + on TPU and required using the APIs from {mod}`jax.experimental.checkify`. + Note that debug checks are not executed unless + {data}`jax.experimental.pallas.enable_debug_checks` is set. + ## Released with jax 0.5.0 * New functionality diff --git a/docs/pallas/design/async_note.md b/docs/pallas/design/async_note.md index 42e32a074fd7..b21725f7a29e 100644 --- a/docs/pallas/design/async_note.md +++ b/docs/pallas/design/async_note.md @@ -1,3 +1,4 @@ +(pallas_async)= # Pallas Async Operations ## Background \+ Motivation @@ -17,7 +18,7 @@ def f(x): In this function, we could perform the `ppermute` at the same time as the `x + 1`. This is an optimization XLA does automatically by: -1. decomposing `ppermute` into a `ppermute_start` and `ppermute_done` op, which are connected via a future. +1. decomposing `ppermute` into a `ppermute_start` and `ppermute_done` op, which are connected via a future. 2. scheduling the `x + 1` between the `ppermute_start` and `ppermute_done`, resulting in the following program: @@ -106,12 +107,12 @@ def ppermute_start(x, *, axis_name) -> tuple[Semaphore, Semaphore, Array]: ), ), in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], out_specs=( pl.BlockSpec(memory_space=pltpu.SEMAPHORE), pl.BlockSpec(memory_space=pltpu.SEMAPHORE), - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), ), )(x) return send_sem, recv_sem, out @@ -138,11 +139,11 @@ def ppermute_done(send_sem, recv_sem, out) ->Array: ), ), in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), pl.BlockSpec(memory_space=pltpu.SEMAPHORE), pl.BlockSpec(memory_space=pltpu.SEMAPHORE), ], - out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), input_output_aliases={0:0} )(out, send_sem, recv_sem) return out @@ -166,9 +167,9 @@ def f(x): There are three remaining issues with this, each of which exists outside of Pallas to some degree. Here they are at a high level. -1. Scheduling \- just because we write `ppermute_start`, then `x + 1`, then `ppermute_done` doesn’t guarantee that they will happen in that order. XLA is responsible for scheduling, so when we write JAX programs, we are setting up data dependencies that XLA will respect but XLA will not respect the specific order of operations written in JAX. -2. Lifetimes \- XLA assumes that once a value is out of scope in the dependency graph, its memory can be freed for use by other values. If we have an op that asynchronously copies x \-\> y, we need to ensure that x is alive until the copy is complete, otherwise we will be copying from garbage memory. -3. Defensive copies \- XLA reserves the right to create copies of values. We need to make sure we don’t introduce unnecessary copies to a) avoid unnecessary runtime overhead and b) ensure correctness. +1. Scheduling \- just because we write `ppermute_start`, then `x + 1`, then `ppermute_done` doesn’t guarantee that they will happen in that order. XLA is responsible for scheduling, so when we write JAX programs, we are setting up data dependencies that XLA will respect but XLA will not respect the specific order of operations written in JAX. +2. Lifetimes \- XLA assumes that once a value is out of scope in the dependency graph, its memory can be freed for use by other values. If we have an op that asynchronously copies x \-\> y, we need to ensure that x is alive until the copy is complete, otherwise we will be copying from garbage memory. +3. Defensive copies \- XLA reserves the right to create copies of values. We need to make sure we don’t introduce unnecessary copies to a) avoid unnecessary runtime overhead and b) ensure correctness. We will go over these issues one by one and suggest fixes. @@ -291,13 +292,13 @@ def ppermute_start(x, *, axis_name) -> tuple[Semaphore, Semaphore, Array, Array] ), ), in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], out_specs=( pl.BlockSpec(memory_space=pltpu.SEMAPHORE), pl.BlockSpec(memory_space=pltpu.SEMAPHORE), - pl.BlockSpec(memory_space=pltpu.ANY), - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pl.ANY), ), input_output_aliases={0:2} )(x) @@ -321,12 +322,12 @@ def ppermute_done(send_sem, recv_sem, x, out) ->Array: ), ), in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pl.ANY), pl.BlockSpec(memory_space=pltpu.SEMAPHORE), pl.BlockSpec(memory_space=pltpu.SEMAPHORE), ], - out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), input_output_aliases={1:0} )(x, out, send_sem, recv_sem) return out @@ -463,7 +464,7 @@ def f(x): return fori_loop(0, 8, body, x) ``` -If you run the alias analysis, you’ll find that all of the buffers have been colored the same\! Intuitively, this is problematic because if we are doing a loop of `ppermute`s, we can’t write into the same buffer we are sending into. We generally need an extra (i.e. a “double”) buffer to receive, and then usually we will switch the send/recv buffers on the next iteration. What XLA will do in practice is that it will observe the buffer re-use and defensively insert a copy. +If you run the alias analysis, you’ll find that all of the buffers have been colored the same\! Intuitively, this is problematic because if we are doing a loop of `ppermute`s, we can’t write into the same buffer we are sending into. We generally need an extra (i.e. a “double”) buffer to receive, and then usually we will switch the send/recv buffers on the next iteration. What XLA will do in practice is that it will observe the buffer reuse and defensively insert a copy. ```py def f(x): @@ -484,7 +485,7 @@ def f(x): def body(i, x): *sems, x, x2 = ppermute_start(x) x2 = ppermute_done((*sems, x, x2)) - + *sems, x2, y = ppermute_start(x2) y = ppermute_done((*sems, x2, y)) return y @@ -573,10 +574,10 @@ our program should now be correct. So we’ve come up with some rules of thumb: -1. If we have operations dependent on the input value to the `ppermute`, unpack the future to use the aliased value instead of the original value. +1. If we have operations dependent on the input value to the `ppermute`, unpack the future to use the aliased value instead of the original value. 2. Use `unroll >= 2` when doing `ppermute`s in a loop body. -Let’s combine everything into one function that does `ppermute`s in a loop and accumulates the result. +Let’s combine everything into one function that does `ppermute`s in a loop and accumulates the result. ```py def f(x): @@ -640,7 +641,7 @@ def f(x): return y_ref[...] ``` -Before, we ran into scheduling ambiguity, where XLA could re-order the add w.r.t. the `ppermute`. With stateful semantics, we actually add in an ordering constraint\! `x_ref[...] += 1` mutates `x_ref` so it can’t be moved wrt to `ppermute_done_stateful`. JAX can inject these scheduling constraints as part of the lowering to HLO. +Before, we ran into scheduling ambiguity, where XLA could re-order the add w.r.t. the `ppermute`. With stateful semantics, we actually add in an ordering constraint\! `x_ref[...] += 1` mutates `x_ref` so it can’t be moved wrt to `ppermute_done_stateful`. JAX can inject these scheduling constraints as part of the lowering to HLO. The final key difference is evident when we try our loop examples. @@ -664,8 +665,8 @@ To handle this without the manual unrolling, we’d create a scratch buffer with The realization here is that being stateful forces us to deal with a lot of the issues that pop up with value semantics earlier on. We define them away\! -1. Scheduling \- stateful ops that have `Ref`s as inputs force an ordering of our program. Note that this will schedule operations on the same `Ref` wrt to each other. We might also need an `opt_barrier_stateful` to enforce more ordering constraints. -2. Lifetimes \- `Ref` lifetimes can be scoped via `run_state` or could be inputs to stateful ops. +1. Scheduling \- stateful ops that have `Ref`s as inputs force an ordering of our program. Note that this will schedule operations on the same `Ref` wrt to each other. We might also need an `opt_barrier_stateful` to enforce more ordering constraints. +2. Lifetimes \- `Ref` lifetimes can be scoped via `run_state` or could be inputs to stateful ops. 3. Defensive copies \- Using `Ref`s forces us to handle buffer assignment “manually” and the lowering can ensure the aliasing works out to avoid any copies. Another important fundamental limitation is that we eventually stage out an HLO program where the live buffers and semaphores are represented as array value types. XLA does not provide guarantees about buffer lifetimes or which memory spaces they live in for these intermediate values. *Therefore, it is possible XLA can copy array values even if they are actively being copied into by Pallas kernels.* This is easy to verify in HLO but it is a sharp edge of using custom calls to represent asynchronous operations in HLO. diff --git a/docs/pallas/design/design.md b/docs/pallas/design/design.md index 17c7a6dbdc0f..53a5eb209510 100644 --- a/docs/pallas/design/design.md +++ b/docs/pallas/design/design.md @@ -71,7 +71,7 @@ A JAX-based kernel language offers several advantages: * JAX as a tracing-based frontend for numerical computing is both mature and well-used. By embedding the kernel programming language in JAX itself, - we can re-use JAX’s tracing infrastructure and provide a + we can reuse JAX’s tracing infrastructure and provide a NumPy-like frontend that’s already familiar to users. * JAX transformations are key to its success, allowing users to express simple programs but transform them to achieve complex @@ -551,7 +551,7 @@ along that dimension. `grad` of `pallas_call` enables automatic differentiation of kernels. `jax.grad` breaks down into applications of three distinct transforms: `jvp`, `partial_eval` and `transpose`. -In principle, we can re-use most of JAX’s infrastructure when +In principle, we can reuse most of JAX’s infrastructure when implementing these rules for `pallas_call` (since it behaves much like existing JAX higher order primitives). diff --git a/docs/pallas/gpu/blackwell_matmul.md b/docs/pallas/gpu/blackwell_matmul.md new file mode 100644 index 000000000000..81af3a6ee5ea --- /dev/null +++ b/docs/pallas/gpu/blackwell_matmul.md @@ -0,0 +1,948 @@ +# Writing high-performance matrix multiplication kernels for Blackwell + +In this guide, we'll progressively iterate on a matrix multiplication kernel. +The first implementation will be very simple, but also quite slow. +However, in just a few simple steps it can be modified into a state-of-the-art +kernel, matching or exceeding highly optimized implementations such as cuBLAS +and CUTLASS. + +```{warning} +The utilization shown in the table below might be different than what you see online, +but the differences can likely be explained by a different input data distribution. +All our benchmarks here use arrays with iid normal float16 entries, which turn out +to be one of the slower distributions you can choose. You can reproduce +the numbers for yourself by running [our test file](https://github.com/jax-ml/jax/blob/main/tests/pallas/mgpu_examples_test.py) after changing the `BENCHMARK` variable to `True`. + +**tl;dr** don't believe matmul benchmarks if they don't specify input data distribution. +``` + +| Implementation | TensorCore utilization | % of cuBLAS utilization | +|---------------------------------|------------------------|-------------------------| +| 0. Basic kernel | 37.62% | 59.4% | +| 1. Warp specialization | 45.47% | 71.7% | +| 2. Tiled epilogue | 55.82% | 88.1% | +| 3. Collective (2CTA) MMA | 59.41% | 93.7% | +| 4. Persistent kernel | 61.46% | 97.0% | +| 5. Dedicated epilogue warpgroup | 63.38% | 100.0% | +| 6. Grid tiling | 69.44% | 109.6% | +| cuBLAS | 63.38% | 100.0% | +| CUTLASS | 69.30% | 109.3% | + +The cuBLAS baseline is obtained by measuring the performace of `jax.dot`. The +CUTLASS performance is measured by taking the best result from the following +`cutlass_profiler` invocation (excluding sparse matmuls): +``` +cutlass_profiler --dist=gaussian,mean:0,stddev:1,scale:-1 --output=results.csv --accumulator-type=f32 --m=4096 --k=4096 --n=8192 --kernels='*sm100*' --A=f16 --B=f16 --C=void --D=f16 +``` + +At each step, we will showcase either the full implementation of the kernel, or +the difference between the code listings shown in the previous and current steps. +Full implementations can be found in [our test file](https://github.com/jax-ml/jax/blob/main/tests/pallas/mgpu_examples_test.py). You can also find the a full standalone +optimized kernel implementation [in the Pallas ops package](https://github.com/jax-ml/jax/blob/main/jax/experimental/pallas/ops/gpu/blackwell_matmul_mgpu.py). + +## 0. Basic kernel + +We begin with a simple single-CTA (block) single-warpgroup example. +For convenience, we split the tuning parameters of the kernel into a separate +class: + +```python +@dataclasses.dataclass(frozen=True) +class TuningConfig: + tile_m: int + tile_n: int + tile_k: int + max_concurrent_steps: int +``` + +`tile_m`, `tile_n` and `tile_k` specify the size of the matmul performed at +every step of the pipeline. In general, `tile_k` should ideally be equal to +128 divided by the bytewidth of the input element type. `max_concurrent_steps` +specifies the depth of memory prefetch in the compute/memory pipeline, which is +frequently called the number of stages in other implementations. + +The kernel implementation begins with a bit of setup: + +```python +def matmul0(a, b, config: TuningConfig): + dtype = a.dtype + m, k = a.shape + _, n = b.shape + tile_m, tile_n, tile_k = config.tile_m, config.tile_n, config.tile_k + swizzle = plgpu.find_swizzle(tile_k * jnp.dtype(dtype).itemsize * 8) + swizzle_elems = swizzle // jnp.dtype(dtype).itemsize + transforms = ( + plgpu.TilingTransform((8, swizzle_elems)), plgpu.SwizzleTransform(swizzle) + ) + if m % tile_m != 0: + raise ValueError(f"{m=} must be divisible by {tile_m=}") + if n % tile_n != 0: + raise ValueError(f"{n=} must be divisible by {tile_n=}") + if k % tile_k != 0: + raise ValueError(f"{k=} must be divisible by {tile_k=}") + m_iters = m // tile_m + n_iters = n // tile_n + k_iters = k // tile_k + max_concurrent_steps = config.max_concurrent_steps +``` + +We unpack the config variables for easier access, set the tiling and swizzling +transforms to get the SMEM data format to match [what's expected by MMA instructions](#memory-space-a-b-mma). + +The kernel implementation itself is relatively short. The first part sets up +a [compute/memory pipeline](./pipelining.md) using {py:func}`plgpu.emit_pipeline `. At each step, the compute function (`do_mma`) consumes a +`(tile_m, tile_k)` slice of LHS and `(tile_k, tile_n)` slice of RHS. As mentioned +before, we specify `transforms`, as well `delay_release=1`. This last parameter +ensures that the input windows (`a_smem`, `b_smem`) passed into `do_mma` will not +be overwritten at least until the next invocation of `do_mma` completes. This is +necessary because we only await the completion of the MMA from the one step +in the following step, which is why `arrive_barrier_slot` and `wait_barrier_slot` +flip between 0 and 1 at each invocation. + +```python + def kernel(a_gmem, b_gmem, out_gmem, acc_tmem, acc_smem, consumed_barriers): + mi = lax.axis_index("m") + ni = lax.axis_index("n") + m_slice = pl.ds(mi * tile_m, tile_m) + n_slice = pl.ds(ni * tile_n, tile_n) + + def do_mma(idxs, a_smem, b_smem): + (ki,) = idxs + arrive_barrier_slot = ki % 2 + wait_barrier_slot = 1 - arrive_barrier_slot + plgpu.tcgen05_mma( + acc_tmem, + a_smem, + b_smem, + barrier=consumed_barriers.at[arrive_barrier_slot], + accumulate=(ki > 0), + ) + plgpu.barrier_wait(consumed_barriers.at[wait_barrier_slot]) + + # Make sure the wait succeeds in the first iteration. + plgpu.barrier_arrive(consumed_barriers.at[1]) + block_kwargs = dict(transforms=transforms, delay_release=1) + plgpu.emit_pipeline( + do_mma, + in_specs=[ + plgpu.BlockSpec((tile_m, tile_k), lambda ki: (mi, ki), **block_kwargs), + plgpu.BlockSpec((tile_k, tile_n), lambda ki: (ki, ni), **block_kwargs), + ], + grid=(k_iters,), + max_concurrent_steps=max_concurrent_steps, + )(a_gmem, b_gmem) +``` + +The kernel itself ends with an epilogue. We await the completion of the last MMA +issued by the pipeline before doing anything. Then, we load the final accumulator +from TMEM, write it to SMEM ([remembering to call `plgpu.commit_smem`](#commit-smem)), +and copy it back to GMEM using TMA. + +```python + def kernel(...): + ... # compute pipeline as above + final_barrier = 1 - (k_iters % 2) + plgpu.barrier_wait(consumed_barriers.at[final_barrier]) + acc_smem[...] = plgpu.async_load_tmem(acc_tmem).astype(dtype) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(acc_smem, out_gmem.at[m_slice, n_slice]) + plgpu.wait_smem_to_gmem(0, wait_read_only=True) +``` + +What remains is to actually turn the kernel into a function that can be called +with JAX arrays. We use {py:func}`plgpu.kernel ` +for that. The grid is for now simply 2D and iterates over the output tiles. We +allocate intermediates used by the kernel: +1. The TMEM buffer used as an accumulator +2. The SMEM buffer used to stage the accumulator before its copy to GMEM +3. The barrier used to await the completion of MMA operations. + +```python +def matmul0(a, b, config): + ... # Setup code from the first snippet + def kernel(...): + ... # The whole kernel body + + f = plgpu.kernel( + kernel, + out_shape=jax.ShapeDtypeStruct((m, n), dtype), + grid=(m_iters, n_iters), + grid_names=("m", "n"), + scratch_shapes=dict( + acc_tmem=plgpu.TMEM((tile_m, tile_n), jnp.float32), + acc_smem=plgpu.SMEM((tile_m, tile_n), dtype, transforms=transforms), + consumed_barriers=plgpu.Barrier( + num_arrivals=1, num_barriers=2, orders_tensor_core=True + ), + ) + ) + return f(a, b) +``` + +Omitting the setup code, that's just 50 lines! Unfortunately, it's not very +fast just yet, but it does achieve half the utilization of cuBLAS already! + +## 1. Warp specialization + +```{note} +Recall that on Blackwell a single Pallas:MGPU thread of execution corresponds to +a warpgroup of CUDA lanes/threads. +``` + +The kernel above uses a single warpgroup to do everything: from fetching the data, +through issuing MMA operations, to storing the results into GMEM. While one would +think that the asynchronicity in TensorCore execution should allow us to overlap +the overheads of async copies (TMA) and control-flow, it does not seem to be the +case. + +A common solution to this problem in the Hopper generation of GPUs was to utilize +_warpgroup_ specialization. In Pallas terms, `plgpu.kernel` can be called with +`num_threads=2`, meaning that each program in the grid would result in two calls +to the body. The thread index is then often queried using `lax.axis_index` and +used to select one of multiple different roles, such as _only_ issuing async +copies or _only_ running the MMA operations. + +This solution also works in the Blackwell generation, but it is in fact even +simpler. Since both the async copy (TMA) as well as the `tcgen05` MMA instruction [only require a single CUDA lane to issue them](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-issue-granularity), +we don't even need to use multiple _warpgroups_. We can simply break up a single +warpgroup into _four warps_ and specialize those! + +In Pallas, this can be achieved using `pl.core_map` with a `plgpu.WarpMesh`. +For each Pallas thread that calls such a `core_map`, the body will be invoked +exactly four times. The `core_map` synchronizes all warps both at entry at exit. +Note that only scalar operations are allowed in the body. + +This will be the biggest rewrite to this kernel we'll perform in this whole +sequence, which is why we'll list the entire kernel source once again. + +```python +def matmul1(a, b, config: TuningConfig): + ... # Setup code remains unmodified + + def kernel(a_gmem, b_gmem, out_gmem, + a_smem, b_smem, acc_tmem, acc_smem, + load_barriers, consumed_barriers, mma_done_barrier): + m_index = lax.axis_index("m") + n_index = lax.axis_index("n") + m_slice = pl.ds(m_index * tile_m, tile_m) + n_slice = pl.ds(n_index * tile_n, tile_n) + + @pl.core_map(plgpu.WarpMesh(axis_name="warp")) + def _per_warp(): + warp_id = lax.axis_index("warp") + + @pl.when(warp_id == 0) + def _memory(): + def _loop_body(ki, _): + slot = lax.rem(ki, max_concurrent_steps) + @pl.when(ki >= max_concurrent_steps) + def _(): # Make sure the data has been consumed before overwriting. + plgpu.barrier_wait(consumed_barriers.at[slot]) + k_slice = pl.ds(ki * tile_k, tile_k) + plgpu.copy_gmem_to_smem( + a_gmem.at[m_slice, k_slice], a_smem.at[slot], load_barriers.at[slot] + ) + plgpu.copy_gmem_to_smem( + b_gmem.at[k_slice, n_slice], b_smem.at[slot], load_barriers.at[slot] + ) + + lax.fori_loop(0, k_iters, _loop_body, None) + + @pl.when(warp_id == 1) + def _compute(): + def _loop_body(ki, _): + slot = lax.rem(ki, max_concurrent_steps) + plgpu.barrier_wait(load_barriers.at[slot]) # Wait for data to arrive. + plgpu.tcgen05_mma( + acc_tmem, + a_smem.at[slot], + b_smem.at[slot], + consumed_barriers.at[slot], + accumulate=(ki > 0), + ) + lax.fori_loop(0, k_iters, _loop_body, None) + plgpu.tcgen05_commit_arrive(mma_done_barrier) + + plgpu.barrier_wait(mma_done_barrier) + acc_smem[...] = plgpu.async_load_tmem(acc_tmem).astype(dtype) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(acc_smem, out_gmem.at[m_slice, n_slice]) + plgpu.wait_smem_to_gmem(0, wait_read_only=True) +``` + +The kernel has exactly the same structure as before: we first perform the compute, +which is followed by the epilogue. The epilogue remains the same (we only use a +different barrier to await the completion), so we will not discuss it further. + +The `plgpu.emit_pipeline` call and the `do_mma` function has been replaced by +a single `pl.core_map` invocation. You can see that immediately after entering +its body, each Pallas thread (now representing a warp!) finds out which of the +four threads it is. We then use thread with index 0 to _only_ issue async +copies that fetch the MMA operands in a loop, while thread with index 1 enters +another loop in which it repeatedly calls `plgpu.tcgen05_mma`. + +One interesting aspect here is the synchronization. We keep an array of +`load_barriers`, each tracking progress of an outstanding GMEM->SMEM copy. +The compute thread must await their completion before feeding the respective +operands to the MMA operation. Going in the other direction, the thread responsible +for async copies must await the completion of MMAs that consume operands before +it can overwrite the memory by issuing another async copy. This is tracked through +`consumed_barriers`. Finally, when the compute thread is done issuing all MMA +operations, it calls `plgpu.tcgen05_commit_arrive(mma_done_barrier)`, requesting +the TensorCore to complete the `mma_done_barrier` once all the MMA operations complete. + +We can now turn our attention to the `plgpu.kernel` definition. The only difference +to the prior version is that we explicitly allocate two additional SMEM buffers +that hold the MMA operands (previously they were implicitly allocated by +`plgpu.emit_pipeline`), as well as the additional barriers. Note that the +`load_barriers` have `num_arrivals=2`, because we issue two async copies on the +same barrier. `orders_tensor_core` is necessary to specify on barriers that are +meant to indicate the completion of TensorCore operations. + +```python +def matmul1(a, b, config: TuningConfig): + ... # Setup code remains unmodified + + def kernel(...): + ... # Kernel code above + + f = plgpu.kernel( + kernel, + ..., # Other parameters remain unchanged + scratch_shapes=dict( + a_smem=plgpu.SMEM( + (max_concurrent_steps, tile_m, tile_k), dtype, transforms=transforms + ), + b_smem=plgpu.SMEM( + (max_concurrent_steps, tile_k, tile_n), dtype, transforms=transforms + ), + acc_tmem=plgpu.TMEM((tile_m, tile_n), jnp.float32), + acc_smem=plgpu.SMEM((tile_m, tile_n), dtype, transforms=transforms), + load_barriers=plgpu.Barrier( + num_arrivals=2, num_barriers=max_concurrent_steps + ), + consumed_barriers=plgpu.Barrier( + num_arrivals=1, + num_barriers=max_concurrent_steps, + orders_tensor_core=True, + ), + mma_done_barrier=plgpu.Barrier( + num_arrivals=1, num_barriers=1, orders_tensor_core=True + ), + ) + ) + return f(a, b) +``` + +This relatively simple modification already gives us a meaningful bump in performance, +getting us up to almost 70% of cuBLAS performance. + +## 2. Tiled epilogue + +This time, we turn our attention away from the compute portion of the kernel and +instead focus on its epilogue. We can improve its efficiency by pipelining +the copy from TMEM to SMEM together with a transfer from SMEM to GMEM. To do this, +we change our `scratch_shapes` to allocate two smaller buffers instead of an +SMEM window that can hold the entire output (which also decreases our SMEM usage): + +```python +def matmul2(a, b, config): + ... # Setup and kernel code + f = plgpu.kernel( + ... + scratch_shapes=dict( + ... + # Previously: plgpu.SMEM((tile_m, tile_n), dtype, transforms=transforms), + acc_smem=plgpu.SMEM( + (2, tile_m, config.epilogue_tile_n), dtype, transforms=transforms + ), + ... + ) + ) +``` + +Then, in the kernel, we loop over the output columns in chunks of `epilogue_tile_n`, +and progressively send out the output to GMEM: + +```python +def matmul2(a, b, config): + ... # Setup code remains unchanged + + def kernel(...): + ... # Compute part remains unchanged + + plgpu.barrier_wait(mma_done_barrier) + out_gmem_window = out_gmem.at[m_slice, n_slice] + for ni in range(tile_n // config.epilogue_tile_n): + acc_smem_ni = acc_smem.at[ni % 2] + ni_slice = pl.ds(ni * config.epilogue_tile_n, config.epilogue_tile_n) + # Make sure that previous copy is done before we overwrite. + plgpu.wait_smem_to_gmem(1, wait_read_only=True) + acc_smem_ni[...] = plgpu.async_load_tmem(acc_tmem.at[:, ni_slice]).astype(dtype) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(acc_smem_ni, out_gmem_window.at[:, ni_slice]) + plgpu.wait_smem_to_gmem(0, wait_read_only=True) +``` + +## 3. Collective (2CTA) MMA + +If you benchmark our latest kernel, you'll quickly find out that it can't utilize +its compute units well, because they are constantly waiting on the memory to deliver +the MMA operands. This means that our kernel is memory bound, because it has too +low _arithmetic intensity_: the number of flops we perform for each byte we load +is too small. + +One very effective trick of the Blackwell architecture that allows us to double +our arithmetic intensity are [collective MMAs](#collective-mma). +The core idea is quite simple: we use a cluster of two blocks (on two SMs) to +compute a single matmul. Each block only loads half of each operand, but the MMA +operation exchanges the data from SMEM of each block as its running. + +We'll start with the kernel configuration changes again: + +```python +def matmul3(a, b, config): + ... # Setup code + cluster_tile_m = 2 * tile_m + cluster_tile_n = 2 * tile_n + m_iters = m // cluster_tile_m + n_iters = n // cluster_tile_n + ... # Setup code and kernel + + f = plgpu.kernel( + ... + grid=(m_iters, n_iters), + ... + cluster=(2,), + cluster_names=("cluster",), + scratch_shapes=dict( + ... + # Previously: plgpu.TMEM((tile_m, tile_n), jnp.float32), + acc_tmem=plgpu.TMEM( + (tile_m, cluster_tile_n), jnp.float32, collective=True + ), + ... + ) + ) +``` + +We add the `cluster` parameter to `plgpu.kernel` to indicate that we intend to +have pairs of programs collaborate (as CUDA block clusters). We also append +`collective=True` to our TMEM allocation, to ensure that it will be allowed to +be used by collective MMAs and double its number of columns (to `cluster_tile_n`). + +Another notable change is that our pair of blocks will ultimately compute a +4x larger output tile, which is why we shrink the grid correspondingly. + +We first update the entry of the kernel: + +```python + def kernel(...): + is_lead_block = lax.axis_index("cluster") == 0 + m_index = lax.axis_index("m") + n_index = lax.axis_index("n") + m_slice = pl.ds(m_index * cluster_tile_m, cluster_tile_m) + n_slice = pl.ds(n_index * cluster_tile_n, cluster_tile_n) +``` + +The only changes here are that we use `cluster_tile_m` and `cluster_tile_n` to +compute the slice of the output the two blocks will collectively compute, and +we also check if the current invocation corresponds to the first (leader) block +in the cluster. This is important, because _only the leader block is supposed to +issue MMA instructions_: + +```python + @pl.core_map(plgpu.WarpMesh(axis_name="warp")) + def _per_warp(): + warp_id = lax.axis_index("warp") + + @pl.when(warp_id == 0) + def _memory(): + def _loop_body(ki, _): + ... # Wait for the data to be consumed, as previously. + plgpu.copy_gmem_to_smem( + ..., collective_axes="cluster", partitioned_axis=0 + ) + plgpu.copy_gmem_to_smem( + ..., collective_axes="cluster", partitioned_axis=1 + ) + lax.fori_loop(0, k_iters, _loop_body, None) + + @pl.when(jnp.logical_and(warp_id == 1, is_lead_block)) + def _compute(): + def _loop_body(ki, _): + ... # Wait for the data to arrive, as previously. + plgpu.tcgen05_mma( + ..., + collective_axis="cluster", + ) + lax.fori_loop(0, k_iters, _loop_body, None) + plgpu.tcgen05_commit_arrive(mma_done_barrier, collective_axis="cluster") +``` + +You can see a few modifications here. First of all, both blocks must issue the +async copies. In both blocks we request a copy of the full window for the whole +cluster, but the addition of `collective_axes="cluster"` indicates that the load +is performed jointly by both blocks. `partitioned_axis=` specifies which axis of +the operand should be split across the cluster. We split the LHS rows and RHS +columns. + +```{warning} +A partitioned collective copy only completes the barrier passed in to `copy_gmem_to_smem` +in the leader block of the cluster! This is why you will see the kernel never +awaits the loads in the second block. +``` + +Secondly, as mentioned before, we additionally predicate the `_compute` body so +that only the leader block runs MMA instructions. All `tcgen05` calls additionally +get a `collective_axis=` argument, to indicate that the completion of MMAs should +complete the barriers in both blocks in the cluster. + +Finally, we apply a small modification to the epilogue. Even though the two +blocks in the cluster collectively compute a result of shape `(cluster_tile_m, cluster_tile_n)`, +each individual block only holds a result of shape `(tile_m, cluster_tile_n)`. +We change the output slicing code to need to slice out the right `out_gmem_window`: + +```python +def matmul3(a, b, config): + ... + def kernel(...): + ... # Compute + + plgpu.barrier_wait(mma_done_barrier) + out_m_index = m_index * 2 + lax.axis_index("cluster") + out_m_slice = pl.ds(out_m_index * tile_m, tile_m) + out_gmem_window = out_gmem.at[out_m_slice, n_slice] + for ni in range(cluster_tile_n // config.epilogue_tile_n): + ... + + ... +``` + +## 4. Persistent kernel + +Our next step is to make the kernel persistent. This means that we'll only +launch however many clusters we can actually run concurrently on the GPU (SM +count divided by 2), and we'll have each cluster loop over a fixed number of +output tiles. This technique allows us to better amortize block +(de)initialization costs (since they are only performed once on each SM) and +achieve a small degree of overlap between the SMEM to GMEM copy in the epilogue +with the compute on the next output tile. + +```python +def matmul4(a, b, config): + ... + + num_sms = jax.extend.backend.get_default_device().core_count + f = plgpu.kernel( + ... + grid=(num_sms // 2,), + grid_names=("cluster_grid",), + ... + ) +``` + +The change is relatively simple. We utilize the {py:func}`plgpu.nd_loop ` +helper to specify that our iteration space is `(m_iters, n_iters)`, but we also +request that it should be split accross the cluster grid using the `collective_axes=` +argument. + +```python +def matmul4(a, b, config): + ... + + def kernel(...): + is_lead_block = lax.axis_index("cluster") == 0 + + @plgpu.nd_loop((m_iters, n_iters), collective_axes="cluster_grid") + def _mn_loop(loop_info: plgpu.NDLoopInfo): + m_index, n_index = loop_info.index + m_slice = ... + n_slice = ... + + ... # Compute + epilogue +``` + +The only meaningful modification in the compute portion of the kernel body is +to ensure that the first few waits on `consumed_barriers` in the memory warp +are only skipped when processing the first output tile (as indicated by +`loop_info.local_index == 0`). When processing the second (or later) tile, the SMEM buffers +were used to compute the previous output tile, so we need to ensure that those +computations have completed before we overwrite them: + +```python +def matmul4(a, b, config): + ... + def kernel(...): + ... + def _mn_loop(...): + ... + + @pl.core_map(plgpu.WarpMesh(axis_name="warp")) + def _per_warp(): + warp_id = lax.axis_index("warp") + + @pl.when(warp_id == 0) + def _memory(): + def _loop_body(ki, _): + slot = lax.rem(ki, max_concurrent_steps) + @pl.when(jnp.logical_or(ki >= max_concurrent_steps, loop_info.local_index > 0)) + def _(): # Make sure the data has been consumed before overwriting. + plgpu.barrier_wait(consumed_barriers.at[slot]) +``` + +Finally, we modify the kernel epilogue by appending a single line: +```python +def matmul4(a, b, config): + ... + def kernel(...): + ... + def _mn_loop(...): + ... # Compute + epilogue + plgpu.wait_load_tmem() # Load must complete before MMA can overwrite TMEM. +``` + +As the comment indicates, since [TMEM loads are asynchronous](#tmem-loads), +we must await their completion before we move on to the next output tile and +overwrite our TMEM allocation by issuing another MMA. + +## 5. Dedicated epilogue warpgroup + +While persistence was useful by itself, it also unlocks another optimization. +When the single Pallas thread in our kernel finishes the compute portion of the +kernel, it performs the entire epilogue. However, this means that it can't issue +any more work for the TensorCore until it's done! + +This leads us to a simple solution: just use 2 Pallas threads (warpgroups)! The +first one will only focus on fetching the MMA operands and issuing the MMA +operations, while the second one will only perform the epilogue! Of course, to +enable them to run concurrently, we need to double-buffer the TMEM used for +the accumulator, and use additional barriers to synchronize: + +```python +def matmul5(a, b, config): + ... + + f = plgpu.kernel( + ..., + num_threads=2, + thread_name="wg", + scratch_shapes=dict( + ... + # Previously: plgpu.TMEM((tile_m, cluster_tile_n), jnp.float32, collective=True), + acc_tmem=plgpu.TMEM( + (tile_m, 2 * cluster_tile_n), jnp.float32, collective=True + ), + ... + # mma_done_barrier (now 2 barriers) + a new store_done_barrier (also 2 barriers) + # Previously: plgpu.Barrier(num_arrivals=1, num_barriers=1, orders_tensor_core=True), + mma_done_barrier=plgpu.Barrier( + num_arrivals=1, num_barriers=2, orders_tensor_core=True + ), + store_done_barrier=plgpu.ClusterBarrier( + collective_axes=("cluster",), + num_arrivals=1, + num_barriers=2, + orders_tensor_core=True, + ), + ), + ) +``` + +The kernel begins similarly to what we had before. We renamed `acc_tmem` to `acc_tmem_slots` +and switch between its halves as we step through the loop over the output tiles: + +```python +def matmul(a, b, config): + ... + + def kernel(a_gmem, b_gmem, out_gmem, + a_smem, b_smem, acc_tmem_slots, acc_smem, + load_barriers, consumed_barriers, mma_done_barrier, store_done_barrier): + wg_idx = lax.axis_index("wg") + is_lead_block = ... + + @plgpu.nd_loop(...) + def _mn_loop(...): + ... + acc_slot = lax.rem(loop_info.local_index, jnp.int32(2)) + acc_tmem = acc_tmem_slots.at[:, pl.ds(acc_slot * cluster_tile_n, cluster_tile_n)] + + ... +``` + +The compute portion is additionally predicated on `wg_idx == 0`. There are also +two important changes to how we use the barriers. First of all, if we want to +reuse our TMEM allocation for MMA (which happens only for `loop_info.local_index >= 2`), +we need to wait on the `store_done_barrier` for the TMEM half we want to reuse +(as indicated by `acc_slot`). Secondly, once we want to request the TensorCore +to arrive on the `mma_done_barrier` upon completion, we again need to select one +of the two barriers that corresponds to the currently used half of TMEM. + +```{warning} +Note that even though only one of the blocks in the cluster issues MMAs, they +both await the `store_done_barrier`. This is only necessary, because arriving on +the same barrier twice without a `wait` in between sometimes leads to hardware +assertions. +``` + +```python +def matmul(a, b, config): + ... + def kernel(...): + ... + def _mn_loop(...): + acc_slot = ... + acc_tmem = ... + + @pl.when(wg_idx == 0) + def _compute_wg(): + @pl.core_map(plgpu.WarpMesh(axis_name="warp")) + def _per_warp(): + warp_id = lax.axis_index("warp") + + @pl.when(warp_id == 0) + def _memory(): + ... # Memory code remains unchanged + + # Wait for store to complete (except for the first two steps). + @pl.when(jnp.logical_and(warp_id == 1, loop_info.local_index >= 2)) + def _wait_store(): + plgpu.barrier_wait(store_done_barrier.at[acc_slot]) + @pl.when(jnp.logical_and(warp_id == 1, is_lead_block)) + def _compute(): + ... # Compute loop remains unchanged + plgpu.tcgen05_commit_arrive(mma_done_barrier.at[acc_slot], collective_axis="cluster") +``` + +Finally, we modify the epilogue, by only having the second warpgroup execute +it, and by making the warpgroup signal the completion of the store by arriving +on the `store_done_barrier` associated with the half of TMEM it used. + +```python +def matmul(a, b, config): + ... + def kernel(...): + ... + def _mn_loop(...): + ... # Compute + + @pl.when(wg_idx == 1) + def _store_wg(): + ... # Unmodified epilogue + plgpu.wait_load_tmem() # Load must complete before we signal. + plgpu.barrier_arrive(store_done_barrier.at[acc_slot]) +``` + +## 6. Grid tiling + +Our final change to this kernel is to change the order in which we produce the +output blocks to better utilize L2. As mentioned before, the compute units are +extremely fast compared to the memory system and so we could use all the help +we can get to try to keep them busy. + +```{note} +This is trick goes by many different names. CUTLASS calls it "rasterization order", +ThunderKittens calls it "supergrouping", while the Triton tutorials call it +"program re-ordering". We use the name "grid tiling". +``` + +Our strategy for this is inspired by CUTLASS and works as follows. First, you +select which of the two dimensions in your iteration space is the faster changing +(we call it `grid_minor_dim`). Then, you select the tile size along that dimension +(`grid_tile_width`). Instead of traversing the whole minor dimension of the grid +before incrementing the more major index, we do it every time we traverse +`grid_tile_width` elements. Once we run out of elements, we move on to the next +tile. But there's a twist! Instead of jumping to the beginning of the second tile, +we start from the end and work our way back. This ensures that as we switch the +tiles, we can reuse some of the recent blocks of one of the operands. + +Since this strategy is so common, we provide a helper for it: {py:func}`plgpu.planar_snake `. +When using the helper, the changes to the kernel are quite trivial: + +```python +def matmul(a, b, config): + ... + def kernel(...): + ... + # We now only iterate over a 1D loop (but we still split it across clusters). + @plgpu.nd_loop((m_iters * n_iters,), collective_axes="cluster_grid") + def _mn_loop(loop_info: plgpu.NDLoopInfo): + (lin_idx,) = loop_info.index + m_index, n_index = plgpu.planar_snake( + lin_idx, # Linear index. + (m_iters, n_iters), # The 2D iteration space. + config.grid_minor_dim, # 0 or 1, indicates the fastest changing dim. + config.grid_tile_width, # The width of tiles along the fastest changing dim. + ) + ... # Rest of the code remains unmodified +``` + +This simple trick is _incredibly effectful_ and is crucial in achieving state of +the art performance. + +## Final kernel + +You've reached the end of this tutorial, congratulations! In the previous +sections, we focused only on the differences between the different kernels and +rarely listed the complete source. This is useful to hide the irrelevant details +when extending the implementation, but it can also be helpful to see the full +source. So here it is! The whole implementation is less than 150 lines and +reaches SOTA performance (at least on the shape used in our benchmarks). + +```python +def matmul6(a, b, config: TuningConfig): + dtype = a.dtype + m, k = a.shape + _, n = b.shape + tile_m, tile_n, tile_k = config.tile_m, config.tile_n, config.tile_k + swizzle = plgpu.find_swizzle(tile_k * jnp.dtype(dtype).itemsize * 8) + swizzle_elems = swizzle // jnp.dtype(dtype).itemsize + transforms = ( + plgpu.TilingTransform((8, swizzle_elems)), plgpu.SwizzleTransform(swizzle) + ) + if m % tile_m != 0: + raise ValueError(f"{m=} must be divisible by {tile_m=}") + if n % tile_n != 0: + raise ValueError(f"{n=} must be divisible by {tile_n=}") + if k % tile_k != 0: + raise ValueError(f"{k=} must be divisible by {tile_k=}") + cluster_tile_m = 2 * tile_m + cluster_tile_n = 2 * tile_n + m_iters = m // cluster_tile_m + n_iters = n // cluster_tile_n + k_iters = k // tile_k + max_concurrent_steps = config.max_concurrent_steps + + def kernel(a_gmem, b_gmem, out_gmem, + a_smem, b_smem, acc_tmem, acc_smem, + load_barriers, consumed_barriers, mma_done_barrier, store_done_barrier): + wg_idx = lax.axis_index("wg") + is_lead_block = lax.axis_index("cluster") == 0 + + @plgpu.nd_loop((m_iters * n_iters,), collective_axes="cluster_grid") + def _mn_loop(loop_info: plgpu.NDLoopInfo): + (lin_idx,) = loop_info.index + m_index, n_index = plgpu.planar_snake( + lin_idx, + (m_iters, n_iters), + config.grid_minor_dim, + config.grid_tile_width, + ) + m_slice = pl.ds(m_index * cluster_tile_m, cluster_tile_m) + n_slice = pl.ds(n_index * cluster_tile_n, cluster_tile_n) + acc_slot = lax.rem(loop_info.local_index, jnp.int32(2)) + mn_acc_tmem = acc_tmem.at[:, pl.ds(acc_slot * cluster_tile_n, cluster_tile_n)] + + @pl.when(wg_idx == 0) + def _compute_wg(): + @pl.core_map(plgpu.WarpMesh(axis_name="warp")) + def _per_warp(): + warp_id = lax.axis_index("warp") + + @pl.when(warp_id == 0) + def _memory(): + def _loop_body(ki, _): + slot = lax.rem(ki, max_concurrent_steps) + @pl.when(jnp.logical_or(ki >= max_concurrent_steps, loop_info.local_index > 0)) + def _(): # Make sure the data has been consumed before overwriting. + plgpu.barrier_wait(consumed_barriers.at[slot]) + k_slice = pl.ds(ki * tile_k, tile_k) + plgpu.copy_gmem_to_smem( + a_gmem.at[m_slice, k_slice], a_smem.at[slot], load_barriers.at[slot], + collective_axes="cluster", partitioned_axis=0 + ) + plgpu.copy_gmem_to_smem( + b_gmem.at[k_slice, n_slice], b_smem.at[slot], load_barriers.at[slot], + collective_axes="cluster", partitioned_axis=1 + ) + + lax.fori_loop(0, k_iters, _loop_body, None) + + # Wait for store to complete (except for the first two steps). + @pl.when(jnp.logical_and(warp_id == 1, loop_info.local_index >= 2)) + def _wait_store(): + plgpu.barrier_wait(store_done_barrier.at[acc_slot]) + @pl.when(jnp.logical_and(warp_id == 1, is_lead_block)) + def _compute(): + def _loop_body(ki, _): + slot = lax.rem(ki, max_concurrent_steps) + plgpu.barrier_wait(load_barriers.at[slot]) # Wait for data to arrive. + plgpu.tcgen05_mma( + mn_acc_tmem, + a_smem.at[slot], + b_smem.at[slot], + consumed_barriers.at[slot], + accumulate=(ki > 0), + collective_axis="cluster", + ) + lax.fori_loop(0, k_iters, _loop_body, None) + plgpu.tcgen05_commit_arrive( + mma_done_barrier.at[acc_slot], + collective_axis="cluster", + ) + + @pl.when(wg_idx == 1) + def _store_wg(): + # Ensure that copies from the previous mn step have completed. + plgpu.wait_smem_to_gmem(0, wait_read_only=True) + plgpu.barrier_wait(mma_done_barrier.at[acc_slot]) + out_m_index = m_index * 2 + lax.axis_index("cluster") + out_m_slice = pl.ds(out_m_index * tile_m, tile_m) + out_gmem_window = out_gmem.at[out_m_slice, n_slice] + for ni in range(cluster_tile_n // config.epilogue_tile_n): + acc_smem_ni = acc_smem.at[ni % 2] + ni_slice = pl.ds(ni * config.epilogue_tile_n, config.epilogue_tile_n) + # Make sure that previous copy is done before we overwrite. + plgpu.wait_smem_to_gmem(1, wait_read_only=True) + acc_smem_ni[...] = plgpu.async_load_tmem(mn_acc_tmem.at[:, ni_slice]).astype(dtype) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(acc_smem_ni, out_gmem_window.at[:, ni_slice]) + plgpu.wait_load_tmem() # Load must complete before we signal. + plgpu.barrier_arrive(store_done_barrier.at[acc_slot]) + plgpu.wait_smem_to_gmem(0, wait_read_only=True) + + num_sms = backend.get_default_device().core_count + f = plgpu.kernel( + kernel, + out_shape=jax.ShapeDtypeStruct((m, n), dtype), + grid=(num_sms // 2,), + grid_names=("cluster_grid",), + cluster=(2,), + cluster_names=("cluster",), + num_threads=2, + thread_name="wg", + scratch_shapes=dict( + a_smem=plgpu.SMEM( + (max_concurrent_steps, tile_m, tile_k), dtype, transforms=transforms + ), + b_smem=plgpu.SMEM( + (max_concurrent_steps, tile_k, tile_n), dtype, transforms=transforms + ), + acc_tmem=plgpu.TMEM( + (tile_m, 2 * cluster_tile_n), jnp.float32, collective=True + ), + acc_smem=plgpu.SMEM( + (2, tile_m, config.epilogue_tile_n), dtype, transforms=transforms + ), + load_barriers=plgpu.Barrier( + num_arrivals=2, num_barriers=max_concurrent_steps + ), + consumed_barriers=plgpu.Barrier( + num_arrivals=1, + num_barriers=max_concurrent_steps, + orders_tensor_core=True, + ), + mma_done_barrier=plgpu.Barrier( + num_arrivals=1, num_barriers=2, orders_tensor_core=True + ), + store_done_barrier=plgpu.ClusterBarrier( + collective_axes=("cluster",), + num_arrivals=1, + num_barriers=2, + orders_tensor_core=True, + ), + ) + ) + return f(a, b) +``` diff --git a/docs/pallas/gpu/collective_matmul.md b/docs/pallas/gpu/collective_matmul.md new file mode 100644 index 000000000000..e8176110eaab --- /dev/null +++ b/docs/pallas/gpu/collective_matmul.md @@ -0,0 +1,303 @@ +# Collective matrix multiplication + +Tensor parallelism (TP) and data parallelism (DP) are the most frequently used +parallelism techniques that make it possible to fit the ever larger models onto +a number of accelerators. However, their joint use means that in our programs, +we sometimes end up with data sharded in ways that don't make it directly +possible to execute an operation without additional communication. One such +problem frequently happens at the beginning of the MLP block of a Transformer. +There, the input activations might be sharded on the batch axis (DP), while the +weights might be partitioned on the output feature dimension (TP). + +
Left matrix is split into halves by rows, right matrix is split into halves by columns
+ +The contraction dimension is not sharded, so it might seem that we can just +multiply the inputs, but there is a problem: the output can't be sharded along +the same device axis on both of its dimensions! + +There's a simple way to solve this problem: we can all-gather activations or +weights (here we focus on the activation side), and then perform a local matrix +multiplication with the other operand sharded. This simple strategy works, but +it has a downside: we can't begin computing the matrix multiplication while the +all-gather is running! That means we're underutilizing our hardware! + +To achieve better utilization, we'll show how simple it is to implement a +Pallas:MGPU kernel that overlaps the cross-device communication with the +matrix-multiplication, achieving almost optimal utilization on large enough +problem shapes. Our implementation makes heavy use of the NVLINK interconnect, +which allows us to perform high-bandwidth inter-GPU communication without +involving the host. + +This approach already yields considerable performance improvements! If we +consider a f16 matmul with M=1024, K=4096 and N=4096 and normally distributed +data, our benchmarks indicate that it should take about 43us on a single H100. +In the table below, we scale up the M dimension so that the per-shard shape is +M=1024. We can compute an expected lower bound for the execution of our +distributed kernel by multiplying that local runtime estimate by the number of +devices and by adding about 6us for each round of communication (the memory +fences associated with the synchronization are expensive). Benchmarking our +kernel yields the following results: + +| Device count | Kernel time | TC utilization | Lower bound | TC utilization | Reference time | TC utilization | +|--------------|-------------|----------------|-------------|----------------|----------------|----------------| +| 2 | 102us | 68% | 92us | 75% | 147us | 47% | +| 4 | 212us | 66% | 190us | 73% | 290us | 48% | +| 8 | 436us | 64% | 386us | 72% | 565us | 49% | + +As you can see there are still some opportunities for optimization here, but at +least we're getting much better utilization compared to the baseline +implementation of a NCCL all gather and cuBLAS matmul. + +## Algorithm overview: Ring All-Gather + +To compute `AllGather(A) @ B`, we form a ring on the participating `D` devices. +At each step, the device takes the last received shard (starting from its local +shard), and passes it to the next device in the ring. While the send is +happening, we compute the matrix multiplication between the last received `A` shard +and the local `B` shard. + +![all_gather](../../_static/pallas/distributed/all_gather.svg) + +More formally, the algorithm proceeds in `D` steps. In step `i` (`0 <= i < D`), +device `d` receives shard `A_{(d + i) % D}` (we don't actually receive in the +first step) from device `(d + 1) % D`, computes `A_{(d + i) % D} @ B_d`, and +writes the result to a slice of the output buffer. Concurrently with the +compute, the device `d` sends shard `A_{(i + d) % D}` to device `(i - 1) % D` +for its use in step `i + 1` (we don't send in the last step). After `D` steps, +device `d` will have seen every shard of `A` and computed the full output. + +## Pallas primitives for inter-device communication + +We use three Pallas functions for inter-device communication: + +* **`plgpu.remote_ref(ref, device_id)`**: This function takes a reference to a + buffer in global memory (GMEM) and returns a reference to the same buffer on a + *different* device, specified by `device_id`. When communicating over NVLINK, + this reference can be read or written to directly, even though its data is located + in remote memory. +* **`pl.semaphore_signal(sem, device_id=...)`**: Increments a semaphore on a + target device. This is usually used to indicate completion of some process, + such as when we notify the remote device that the data it's waiting for has + been sent. +* **`pl.semaphore_wait(sem, value=..., decrement=...)`**: Blocks until a local + semaphore reaches a certain value. If decrement is `True` (default), the + value of the semaphore is decreased by the awaited amount. If it is `False`, + the operation is more efficient, but it does not modify the value of the + semaphore after the wait completes. This is frequently used to await signals + from a remote device. + +## Implementation with Pallas + +```{note} +Here, we only present a simplified version of the kernel, which allows us to +focus on the most interesting details. You can find [the full implementation in +our examples directory](https://github.com/jax-ml/jax/blob/main/jax/experimental/pallas/ops/gpu/collective_matmul_mgpu.py). +``` + +First, we focus on the set-up of our kernel. For the compute part, we will reuse +our optimized matmul kernel implementation from `hopper_matmul_mgpu`. Since the +compute kernel will utilize warp-specialization, we use 3 Pallas threads. It +is also persistent, which means that we launch a grid as large as the number of +SMs (queried from `.core_count` on the JAX device). The compute kernel uses +`pl.run_scoped` for SMEM allocations, so we don't use `scratch_shapes`. + +```python +def all_gather_lhs_matmul( + lhs: jax.Array, + rhs: jax.Array, + axis_name, + *, + config: hopper_matmul_mgpu.TuningConfig, + dtype: jnp.dtype = jnp.bfloat16, +) -> jax.Array: + if (num_devices := jax.device_count()) != jax.process_count(): + raise ValueError("The kernel only supports one device per process") + if (axis_size := lax.axis_size(axis_name)) != num_devices: + raise ValueError("The kernel can only work over all devices in a Mesh.") + ... + + m_shard, k = lhs.shape + _, n_shard = rhs.shape + tile_m, tile_n, tile_k = config.tile_m, config.tile_n, config.tile_k + cta_tile_m = tile_m * (1 + (config.wg_dimension == MatmulDimension.M)) + num_sms = jax.extend.backend.get_default_device().core_count + + def kernel_body(lhs_local_ref, rhs_ref, out_ref, scratch_ref): + ... + + result, _ = plgpu.kernel( + kernel_body, + out_shape=[ + # The output (with M gathered) + jax.ShapeDtypeStruct((axis_size * m_shard, n_shard), dtype), + # A scratch buffer for LHS all-gather + jax.ShapeDtypeStruct((axis_size - 1, m_shard, k), dtype), + ], + grid=(num_sms,), + num_threads=3, # The matmul kernel uses 3 threads: 2 compute and 1 memory + thread_name="wg", + )(lhs, rhs) + return result +``` + +The kernel above has two outputs. First one is the actual result of our +primitive, while the second one is used as a scratch space to receive the left +operands. Note that we could shrink the leading axis to be smaller than +`axis_size - 1`, but at that point we would need to introduce backpressure to +the sending devices, which requires additional expensive communication. + +```{note} +You can see how to deal with this backpressure in the [TPU distributed communication guide](../tpu/distributed.md#run-ahead-and-race-conditions). +``` + +Let us now look at the outline of the kernel body: + +```python +def all_gather_lhs_matmul(...): + def kernel_body(lhs_local_ref, rhs_ref, out_ref, scratch_ref, out_smem, received_sem): + wg_idx = lax.axis_index("wg") + dev_id = lax.axis_index(axis_name) + # This device sends to dev_id - 1, forming a ring. + send_dev_id = lax.rem(dev_id + axis_size - 1, axis_size) + send_scratch_ref = plgpu.remote_ref(scratch_ref, send_dev_id) + + def device_step(lhs_source_ref, device_offset): + # Invariant: lhs_source_ref contains A_{(dev_id + device_offset) % D} + # and is ready to be used for computation. + + ... + + # We peel the first step to read data directly from lhs_local_ref. + device_step(lhs_local_ref, 0) + @pl.loop(1, num_devices) + def _device_loop(device_offset): + device_step(scratch_ref.at[device_offset - 1], device_offset) +``` + +We locate our position in the ring by querying `lax.axis_index(axis_name)` and +compute the index of the next device, to which we will be sending the data +(`send_dev_id`). Then, we loop over invocations of the `device_body` as many +times as there are devices. We peel the first step of the loop, because we use +the local reference as the source for the send in that step only (after that the +sends originate from the data previously received in the scratch buffer). + +We are ready to investigate the main loop now: + +```python +def all_gather_lhs_matmul(...): + ... + + def kernel_body(lhs_local_ref, rhs_ref, out_ref, scratch_ref, out_smem, received_sem): + ... + + def device_step(lhs_source_ref, device_offset): + # We are computing block (dev_id + device_offset) % D of the output. + out_device_idx = lax.rem(device_offset + dev_id, axis_size) + out_device_m_slice = pl.ds(out_device_idx * m_shard, m_shard) + + # In step `device_offset`, we send A_{(dev_id + device_offset) % D} to + # the next device in the ring, into scratch slot `device_offset`. + # We also don't send on the last step since that would return the data + # back to its original source. + next_scratch_slot = device_offset + is_send_wg = wg_idx == 0 # Only one warpgroup per CTA sends + has_send_space = next_scratch_slot < axis_size - 1 + should_send = is_send_wg & has_send_space + + # This function will be called by hopper_matmul_mgpu.kernel in the body + # of its pipeline. We use it to take the tile of LHS loaded into SMEM and + # issue a TMA send to the next device in the ring. + def send_lhs(m_idx, n_idx, k_idx, a_smem, b_smem, send_ref, should_send): + del b_smem # Unused. + # We only send when n_idx == 0 to avoid sending the same data + # multiple times when revisiting the left operand. + @pl.when(should_send & jnp.bool(n_idx == 0)) + def _(): + k_slice = pl.ds(k_idx * tile_k, tile_k) + m_slice = pl.ds(m_idx * cta_tile_m, cta_tile_m) + plgpu.copy_smem_to_gmem(a_smem, send_ref.at[m_slice, k_slice]) + # Wait for previous copies to complete. We pass in delay_release=1 + # to the pipeline in the matmul kernel to ensure that it doesn't + # overwrite the input until at least the next step completes, but it + # will not wait any longer. + plgpu.wait_smem_to_gmem(1, wait_read_only=True) + + hopper_matmul_mgpu.kernel( + lhs_source_ref, # LHS shard for this step + rhs_ref, # RHS shard is always the same + out_ref.at[out_device_m_slice], # Slice of output to update + out_smem, + config=config, + pipeline_callback=functools.partial( + send_lhs, + send_ref=send_scratch_ref.at[next_scratch_slot], + should_send=should_send, + ), + delay_release=1, + ) + + # Wait for the next scratch to arrive for the next step's computation. + # Each device signals its neighbor when it has finished sending. + @pl.when(should_send) + def _signal(): + # Make sure our remote copy is done, then signal. + plgpu.wait_smem_to_gmem(0, wait_read_only=False) + pl.semaphore_signal(received_sem, device_id=send_dev_id) + @pl.when(has_send_space) + def _wait(): + # Here, we wait for the data to arrive from the previous device in the + # ring. At each step, will expect to receive a signal from each SM. + # We use decrement=False to make this operation slightly faster, but + # this also means that we need to scale the expected number of signals + # by the number of steps taken so far (as the value only increases). + pl.semaphore_wait(received_sem, value=(device_offset + 1) * num_sms, decrement=False) + + ... +``` + +A few things happen here in a sequence: +1. We begin by computing the slice of the + output that we will compute at this step of the loop. +2. Then, we call into the optimized matmul kernel, but injecting it with a + `pipeline_callback`. We use it to take advantage of the fact that the compute + kernel has to fetch the left operand into SMEM, and we instruct the TMA engine + to asynchronously stream the local data to the next device. The traffic is + transparently routed through NVLINK by the hardware. It is worth noting that we + only issue sends from one of the compute threads and only when we visit the left + operand for the first time (it might be reloaded many times to compute many + output tiles). +3. Finally, the sending thread makes sure that the sends have completed and + signals the `received_sem` on the receiving device to indicate that. After + that, all threads wait until they are sure that all the data for the next + step of the loop has been received (the wait is skipped on the last step). + +## Integrating the kernel with JAX + +To invoke the kernel, you need to wrap it into `jax.shard_map`: +```python +m_shard, n_shard, k = 1024, 1024, 1024 +dtype = jnp.float16 +mesh = jax.make_mesh((jax.device_count(),), ("x",), + axis_types=(jax.sharding.AxisType.Explicit,)) +with jax.set_mesh(mesh): + a = jax.random.normal(jax.random.key(1), (m_shard * jax.device_count(), k), dtype) + b = jax.random.normal(jax.random.key(2), (k, n_shard * jax.device_count()), dtype) + a = jax.sharding.reshard(a, P("x", None)) + b = jax.sharding.reshard(b, P(None, "x")) + + # Example config for 8xH100. You might need to retune to your shape. + config = hopper_matmul_mgpu.TuningConfig( + tile_m=128, tile_n=128, tile_k=64, max_concurrent_steps=4, + grid_minor_dim=MatmulDimension.N, grid_tile_width=8, + wg_dimension=MatmulDimension.N, + ) + + kernel = jax.jit( + jax.shard_map( + functools.partial(all_gather_lhs_matmul, axis_name="x", config=config), + out_specs=P(None, "x"), + check_vma=False, + ) + ) + c = kernel(a, b) +``` \ No newline at end of file diff --git a/docs/pallas/gpu/index.rst b/docs/pallas/gpu/index.rst new file mode 100644 index 000000000000..1dc3b2e3373e --- /dev/null +++ b/docs/pallas/gpu/index.rst @@ -0,0 +1,17 @@ +Pallas:Mosaic GPU +================= +Backend specific documentation for the Mosaic GPU backend. + +.. toctree:: + :caption: Reference documentation + :maxdepth: 2 + + reference + pipelining + blackwell_matmul + collective_matmul + +.. toctree:: + :caption: Guides + :maxdepth: 2 + diff --git a/docs/pallas/gpu/pipelining.ipynb b/docs/pallas/gpu/pipelining.ipynb new file mode 100644 index 000000000000..392421b69654 --- /dev/null +++ b/docs/pallas/gpu/pipelining.ipynb @@ -0,0 +1,428 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "9552ee76", + "lines_to_next_cell": 0 + }, + "source": [ + "(pallas_mgpu_pipelining)=" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bJ5yuIr-M0x0" + }, + "source": [ + "\n", + "## Mosaic GPU Pipelining\n", + "\n", + "This guide covers software pipelining using the Mosaic GPU backend for Pallas.\n", + "\n", + "For a general overview of the pipelining API in Pallas, we recommend that users first read {ref}`pallas_software_pipelining`. Pipelining in Pallas is programmed explicitly. For those who are familiar with Triton, this is a significant difference in programming model because in Triton, pipelining is an optimization that is done automatically by the compiler.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "dGAa3iO5DoRT" + }, + "outputs": [], + "source": [ + "import jax\n", + "from jax import lax\n", + "from jax import numpy as jnp\n", + "from jax.experimental.pallas import mosaic_gpu as plgpu\n", + "from jax.experimental import pallas as pl\n", + "import numpy as np" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Pv9j90hVyswo" + }, + "source": [ + "\n", + "### Pipelining with Mosaic GPU\n", + "\n", + "The recommended approach to pipeline using Mosaic GPU is to use the `plgpu.emit_pipeline` function to pipeline over sequential loops (and to use `plgpu.kernel` to partition the problem in parallel over the CUDA grid). `emit_pipeline` follows a similar API as `pl.pallas_call` except it exposes a few additional GPU-specific options.\n", + "\n", + "- `body`, `grid` have similar semantics as in `pl.pallas_call`. The `grid` denotes how many invocations of the `body` function to run. In contrast with a CUDA grid, the pipeline grid is guaranteed to run sequentially.\n", + "- `in_specs` and `out_specs` also work similarly to `pl.pallas_call`, except they also accept `plgpu.BlockSpec` instances that can be used specify GPU-specific transforms, such as swizzling. See [memory reference transforms](https://docs.jax.dev/en/latest/pallas/gpu/reference.html#memory-reference-transforms) for more detail on available transformations.\n", + "- `max_concurrent_steps` controls the maximum number of concurrent memory transfers. Using additional concurrent steps will consume more SMEM to hold temporary buffers, but it can improve the utilization of the memory subsystem. We recommend autotuning this parameter. Low values (e.g. 2) can sometimes achieve higher occupancy (due to lower SMEM usage) which can improve throughput in ALU-heavy kernels, but will introduce more noise due to the hardware taking care of scheduling. Larger values (between 4 and 6) will work best for kernels that can't take advantage of extra occupancy\n", + "- `delay_release` allows the user to specify an additional number of iterations to wait before the buffer is re-used by the pipeline. For example, a buffer copied into SMEM on iteration 0 with `delay_release=1` and `max_concurrent_steps=2` will not be re-used until iteration 3, as opposed to iteration 2 for a standard double-buffered strategy. `delay_release=1` is necessary if you don't await a `plgpu.wgmma` operation on the pipeline operands, as otherwise the pipeline will begin overwriting the buffers while the WGMMA is still reading them. This is useful for certain optimizations such as allowing multiple async matmuls in flight to keep the tensor core pipeline filled, but care must be taken when using such a strategy as **omitting this parameter will silent data races**, and it reduces the efficiency of `emit_pipeline` as we are overlapping fewer memory transfers.\n", + "\n", + "#### Compatibility API using `pl.pallas_call`\n", + "\n", + "As an alternative to `emit_pipeline` and to maintain compatibility with Pallas TPU, Mosaic GPU also implements the existing `pl.pallas_call` API. By default, `pl.pallas_call` on Mosaic GPU will partition your kernel in parallel over the CUDA grid. You can opt-in to pipelining by passing in a `plgpu.CompilerParams` object as the `compiler_params` argument, which specifies the following options that are relevant for pipelining:\n", + "- `dimension_semantics`: A tuple of `Literal['parallel', 'sequential']` that specifies iteration semantics for each grid dimension. `parallel` will partition the corresponding dimension over the CUDA grid, and `sequential` dimensions will be pipelined sequentially. **Note that if no dimensions are marked `sequential`, no pipelining will happen!**\n", + "- `max_concurrent_steps`: identical to the option in `plgpu.emit_pipeline`.\n", + "- `delay_release`: identical to the option in `plgpu.emit_pipeline`.\n", + "\n", + "Pipelining lets you re-use scratch buffers across the sequential iterations of the grid (e.g. for implementing reductions). Additionally, `pallas_call` supports using `plgpu.BlockSpec` objects in place of `pl.BlockSpec` objects when using the Mosaic GPU backend, allowing you to specify GPU-specific memory transformations.\n", + "\n", + "We recommend that users use `plgpu.kernel` rather than `pl.pallas_call` as `plgpu.kernel` supports more features (such as specifying the number of warpgroups and warp specialization).\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Qp3X6wylJtoa" + }, + "source": [ + "### GPU Memory Spaces\n", + "\n", + "Refs exist primarily in one of two memory spaces, which can be explicitly specified by the `memory_space` argument of `BlockSpec`, i.e. `BlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM)`.\n", + "\n", + "- `plgpu.GPUMemorySpace.SMEM` allocates a Ref in Shared Memory (SMEM). SMEM Refs can be dereferenced using array indexing syntax to store values in registers for compute, i.e. `x = y_ref[...]`. This memory space used for a Ref when using `emit_pipeline`.\n", + "\n", + "- `plgpu.GPUMemorySpace.GMEM` allocates a Ref in Global Memory (GMEM/HBM). Any Refs allocated in GMEM are not pipelined, and values cannot be accessed directly using array indexing operations. Instead, GMEM must be accessed via SMEM using `plgpu.copy_gmem_to_smem` for reading, or `plgpu.copy_smem_to_gmem` for writing, or pipelined into SMEM using `plgpu.emit_pipeline`.\n", + "\n", + "The primary purpose of `emit_pipeline` is used to overlap TensorCore computation with data transfers between GMEM and SMEM, since asynchronous copies between GMEM/SMEM have a long latency, but all TensorCore computation must operate on registers (or SMEM Refs in the case of matrix multiplication)." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0uzcrDCtKABQ" + }, + "source": [ + "### Example: Matmul Kernel on Hopper GPUs" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vILVdlqEdoEK" + }, + "source": [ + "Let's begin with a matrix multiplication example designed to run on Hopper GPUs. This kernel utilizes the Hopper-specific `wgmma` (warpgroup matrix multiply accumulate) instruction. `wgmma` is issued by a single Mosaic GPU thread and runs asynchronously on the TensorCore.\n", + "\n", + "Our example kernel implements a blockwise matrix multiplication of two matrices of shape `[M, K] @ [K, N] = [M, N]`, where each output block is computed in parallel over the CUDA grid. This grid is specified as the `grid` argument to the outer `plgpu.kernel`, and parallelizes over the non-contracting dimensions M, N of the matrix multiplication." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KSvqVNdy726B" + }, + "source": [ + "\n", + "
\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "10ebHCQ571Fn" + }, + "source": [ + "\n", + "Within a program instance, we run a sequential pipeline using `plgpu.emit_pipeline` that reduces over the contracting dimension K of the matrix multiplication. On each iteration of the pipeline, we load one tile from each input matrix, multiply them, and then store the result in an accumulator Ref (`plgpu.ACC`). `plgpu.ACC` is a special type of Ref that lives in registers and holds the intermediate results of WGMMA. Once we have accumulated over the entire contracting dimension, we write out the result to the output Ref.\n", + "\n", + "To perform the actual matrix multiplication, we call `plgpu.wgmma` with the accumulator, LHS, and RHS Refs as arguments in order to push the arguments into the TensorCore pipeline. All WGMMA operations are executed in order, so this can be viewed as pushing operations into a queue. Since `wgmma` is an asynchronous instruction, `plgpu.wgmma_wait(N)` is used to wait until there are no more than N `wgmma` operations left in-flight. In this particular implementation we wait for 1 in-flight WGMMA, meaning that the WGMMA we queue on the current iteration will be waited for on the next iteration.\n", + "- `wgmma` wants it's arguments to be in a specific format, defined in the [CUDA documentation](https://docs.nvidia.com/cuda/parallel-thread-execution/#register-fragments-and-shared-memory-matrix-layouts). These are implemented by the `TilingTransform` and `SwizzleTransform` transformations on the input BlockSpecs. Note that in the future transforms will be inferred automatically by Mosaic GPU and these will not need to be manually specified. See the [wgmma reference](https://docs.jax.dev/en/latest/pallas/gpu/reference.html#hopper-wgmma) for full details on using this instruction.\n", + "- We use the `delay_release` parameter in conjunction with `plgpu.wgmma_wait(1)` to always allow one `WGMMA` operation to stay in-flight in order to ensure good TensorCore utilization. Without this, we would be flushing the TensorCore pipeline on every iteration of the kernel." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "6Vf5_VA9iCD1" + }, + "outputs": [], + "source": [ + "def matmul(a, b, tile_m=128, tile_n=128, swizzle=128):\n", + " dtype = jnp.float16\n", + " swizzle_elems = swizzle // jnp.dtype(dtype).itemsize\n", + " tile_k = swizzle_elems\n", + " grid_m = m // tile_m\n", + " grid_k = k // tile_k\n", + " grid_n = n // tile_n\n", + " assert tile_m % swizzle_elems == 0\n", + "\n", + " # Note: Transforms will be inferred automatically\n", + " # by Mosaic GPU in the future.\n", + " transforms = (\n", + " plgpu.TilingTransform((8, swizzle_elems)),\n", + " plgpu.SwizzleTransform(swizzle),\n", + " )\n", + "\n", + " def kernel(a_gmem, b_gmem, o_gmem, o_smem, acc):\n", + " def pipeline_step(_, a_smem, b_smem):\n", + " plgpu.wgmma(acc, a_smem, b_smem)\n", + " plgpu.wgmma_wait(1)\n", + "\n", + " # pl.program_id obtains the index into the grid.\n", + " pid_m = pl.program_id(0)\n", + " pid_n = pl.program_id(1)\n", + "\n", + " pipeline = plgpu.emit_pipeline(\n", + " pipeline_step,\n", + " in_specs=[\n", + " plgpu.BlockSpec(\n", + " (tile_m, tile_k), lambda k: (pid_m, k), transforms=transforms\n", + " ),\n", + " plgpu.BlockSpec(\n", + " (tile_k, tile_n), lambda k: (k, pid_n), transforms=transforms\n", + " ),\n", + " ],\n", + " grid=(grid_k,),\n", + " max_concurrent_steps=2,\n", + " delay_release=1,\n", + " )\n", + "\n", + " pipeline(a_gmem, b_gmem)\n", + " # Store WGMMA accumulator to SMEM and then to GMEM.\n", + " o_smem[...] = acc[...].astype(dtype)\n", + " plgpu.commit_smem()\n", + " m_slice = pl.ds(pid_m * tile_m, tile_m)\n", + " n_slice = pl.ds(pid_n * tile_n, tile_n)\n", + " plgpu.copy_smem_to_gmem(o_smem, o_gmem.at[m_slice, n_slice])\n", + " plgpu.wait_smem_to_gmem(0)\n", + "\n", + " return plgpu.kernel(\n", + " kernel,\n", + " out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16),\n", + " scratch_shapes=dict(\n", + " o_smem=plgpu.SMEM((tile_m, tile_n), jnp.float16),\n", + " acc=plgpu.ACC((tile_m, tile_n), jnp.float32)\n", + " ),\n", + " # grid specifies the CUDA grid.\n", + " # Instances of `kernel` will be executed in parallel over this grid.\n", + " grid=(grid_m, grid_n),\n", + " grid_names=(\"m\", \"n\"),\n", + " )(a, b)\n", + "\n", + "m = 132 * 128\n", + "n = 4 * 128\n", + "k = 10 * 64\n", + "key1, key2 = jax.random.split(jax.random.key(42), 2)\n", + "a = jax.random.uniform(key1, shape=(m, k), dtype=jnp.float16)\n", + "b = jax.random.uniform(key2, shape=(k, n), dtype=jnp.float16)\n", + "\n", + "result = matmul(a, b)\n", + "\n", + "np.testing.assert_allclose(result, a @ b)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lIYV7PN9J8Px" + }, + "source": [ + "### Warp Specialization\n", + "\n", + "Warp specialization is a technique where we program each warp/warpgroup to perform a single task in order to give the GPU hardware the flexibility to schedule them at runtime. Recall that each streaming multiprocessor (SM) in a GPU contains warp schedulers that can swap execution between warps, so for example when one warp is stalling it can begin executing a different warp. In practice, this can be more performant than programming a single instruction stream where the compiler must statically schedule the operations and attempt to overlap them optimally.\n", + "\n", + "In particular, we are interested in warpgroup specialization on Hopper+ GPUs, where it can be useful to have a separate warpgroup issuing TMAs (GMEM/SMEM copies) from the warpgroups performing arithmetic, since indexing calculations and issuing TMAs can take up a significant amount of time and potentially leave the TensorCore idle. The figure below depicts a standard, non-specialized kernel on the left where TMAs (async copies) and matrix multiplication are issued from a single instruction stream, and a warp-specialized version on the right where communication and arithmetic are handled on separate warpgroups. A *consumed barrier* is used to synchronize between the specialized warpgroups that signals to the memory warpgroup when it is safe to begin the next TMA.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "n-y90IC7v7vL" + }, + "source": [ + "\n", + "
\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZH0Pui5kFSdD" + }, + "source": [ + "Warp specialization can be enabled in Pallas by using the `plgpu.emit_pipeline_warp_specialized` helper. This pipeline helper handles all of the logic in the memory thread, and the user only needs to specify the work done in the compute threads. It shares the a similar API as the standard `emit_pipeline`, and currently supports the following arguments:\n", + "\n", + "```python\n", + "plgpu.emit_pipeline_warp_specialized(\n", + " body: Callable,\n", + " *\n", + " grid: tuple[int, ...],\n", + " in_specs: Sequence[pallas_core.BlockSpec] = (),\n", + " out_specs: Sequence[pallas_core.BlockSpec] = (),\n", + " max_concurrent_steps: int,\n", + " compute_context: Callable\n", + " num_compute_wgs: int,\n", + " memory_registers: int\n", + " wg_axis: str,\n", + " memory_thread_idx: int | None = None,\n", + ")\n", + "```\n", + "\n", + "There are a few arguments specific to this pipeline emitter, which are:\n", + "- `num_compute_wgs` specifies how many compute threads/warpgroups to use. The pipeline emitter always uses a single memory thread, so in `plgpu.kernel` you should specify `num_threads=num_compute_wgs+1`.\n", + "- `memory_registers` controls how many registers to allocate to the memory thread. The remaining registers are partitioned evenly among the compute threads. The default value is 40 and should be adjusted up or down depending on whether register spills are encountered.\n", + "- `wg_axis` the name of the thread/warpgroup axis (as specified by the `thead_name` argument of `plgpu.kernel`).\n", + "- `memory_thread_idx` specifies which Pallas thread to designate as the memory thread. Defaults to the last thread.\n", + "- `compute_context` is a enables you to specify a prologue/epilogue to the pipeline that only runs in the compute thread. The function allows you to define the initialization and consumption of a loop carry through the pipeline. All compute thread specific arrays should be instantiated here so the memory thread does not materialize them in registers -- otherwise, you may experience slowdowns due to register spills.\n", + "\n", + "The pipeline body of the warp specialized pipeline is run in parallel by all compute threads, and SMEM is shared between compute threads since they are scheduled within the same CUDA block.`lax.axis_index` can be used inside the kernel to obtain the Pallas thread index in order to divide up work amongst compute threads.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZGbK5gIvFZKy" + }, + "source": [ + "### Example: Matrix Multiplication with Warp Specialization\n", + "\n", + "The following example extends the previous matrix multiplication example to use warp specialization. This particular kernel uses 2 compute threads, which operate on separate columns of the RHS matrix but share the same LHS. Each invocation of the pipeline therefore computes 2 adjacent blocks in the output matrix.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NYWBqa9-bp2p" + }, + "source": [ + "\n", + "
\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OkWmfqn7b53M" + }, + "source": [ + "We use the `compute_context` pattern to initialize the WGMMA accumulator, and copy the final accumulator from registers into SMEM. Here, the compute context is defined in the function `compute_thread`. It is critical that the accumulator be created inside of the `compute_thread` function to avoid allocating it in the memory thread which would waste registers. To perform the WGMMA, we wrap the `wgmma` instruction in a `pl.run_state` in order to create an accumulator ref that is initialized to the carry value.\n", + "\n", + "Instead of using `pl.pallas_call` to call the kernel, we instead use the GPU-specific `plgpu.kernel` entry point. `plgpu.kernel` allows us to specify the number of threads to launch per CUDA block via the `num_threads` argument, and allows us to specify a `thread_name` we can use to query the Pallas thread index inside of the kernel.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "EJhWnwJlFGaT" + }, + "outputs": [], + "source": [ + "def matmul_warp_specialized(a, b, tile_m=128, tile_n=128, swizzle=128,\n", + " compute_wgs=2):\n", + " dtype = jnp.float16\n", + " elems_128b = swizzle // jnp.dtype(dtype).itemsize\n", + " tile_k = elems_128b\n", + " grid_m = m // tile_m\n", + " grid_k = k // tile_k\n", + " grid_n = n // tile_n\n", + " assert tile_m % elems_128b == 0\n", + "\n", + " transforms = (\n", + " plgpu.TilingTransform((8, elems_128b)),\n", + " plgpu.SwizzleTransform(128),\n", + " )\n", + "\n", + " def kernel(a_gmem, b_gmem, o_gmem, o_smem):\n", + " wg_idx = lax.axis_index(\"wg\")\n", + " wg_slice = pl.ds(wg_idx * tile_n, tile_n)\n", + " # pl.program_id obtains the index into the pallas_call grid.\n", + " pid_m = pl.program_id(0)\n", + " pid_n = pl.program_id(1)\n", + "\n", + " def compute_thread(pipeline):\n", + " acc = plgpu.layout_cast(\n", + " jnp.full((tile_m, tile_n), 0, dtype=jnp.float32), plgpu.Layout.WGMMA,\n", + " )\n", + " # yield marks the place where the pipelined loop will be inserted.\n", + " # Its argument are the initial carry values, and its result is the carry\n", + " # value after the loop completes.\n", + " final_acc = pipeline(acc)\n", + " o_smem[:, wg_slice] = final_acc[...].astype(dtype)\n", + "\n", + " def kernel_body(_, a_smem, b_smem, carry):\n", + " acc = carry\n", + " b_smem_wg = b_smem.at[:, wg_slice]\n", + " def do_wgmma(acc_ref):\n", + " plgpu.wgmma(acc_ref, a_smem, b_smem_wg)\n", + " acc = pl.run_state(do_wgmma)(\n", + " plgpu.ACC.init(acc))\n", + " return acc\n", + "\n", + " pipeline = plgpu.emit_pipeline_warp_specialized(\n", + " kernel_body,\n", + " in_specs=[\n", + " plgpu.BlockSpec(\n", + " (tile_m, tile_k), lambda k: (pid_m, k), transforms=transforms\n", + " ),\n", + " plgpu.BlockSpec(\n", + " (tile_k, tile_n * 2), lambda k: (k, pid_n),transforms=transforms\n", + " ),\n", + " ],\n", + " grid=(grid_k,),\n", + " compute_context=compute_thread,\n", + " max_concurrent_steps=2,\n", + " num_compute_wgs=compute_wgs,\n", + " memory_registers=40,\n", + " memory_thread_idx=2,\n", + " wg_axis=\"wg\",\n", + " )\n", + " # Call the pipeline\n", + " pipeline(a_gmem, b_gmem)\n", + " # Copy the output from SMEM to GMEM.\n", + " plgpu.commit_smem()\n", + " m_slice = pl.ds(pid_m * tile_m, tile_m)\n", + " n_slice = pl.ds(pid_n * tile_n * 2, tile_n * 2)\n", + " plgpu.copy_smem_to_gmem(o_smem, o_gmem.at[m_slice, n_slice])\n", + " plgpu.wait_smem_to_gmem(0)\n", + "\n", + " return plgpu.kernel(\n", + " kernel,\n", + " out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16),\n", + " scratch_shapes=dict(\n", + " o_smem=plgpu.SMEM((tile_m, tile_n * 2), jnp.float16)\n", + " ),\n", + " grid=(grid_m, grid_n // 2),\n", + " grid_names=(\"m\", \"n\"),\n", + " num_threads=3, # 2 compute, 1 memory.\n", + " thread_name=\"wg\"\n", + " )(a, b)\n", + "\n", + "m = 132 * 128\n", + "n = 4 * 128\n", + "k = 10 * 64\n", + "key1, key2 = jax.random.split(jax.random.key(42), 2)\n", + "a = jax.random.uniform(key1, shape=(m, k), dtype=jnp.float16)\n", + "b = jax.random.uniform(key2, shape=(k, n), dtype=jnp.float16)\n", + "\n", + "result = matmul_warp_specialized(a, b)\n", + "\n", + "np.testing.assert_allclose(result, a @ b)" + ] + } + ], + "metadata": { + "colab": { + "last_runtime": { + "build_target": "//experimental/users/justinfu/pallas:colab_gpu", + "kind": "private" + }, + "provenance": [] + }, + "jupytext": { + "formats": "ipynb,md", + "main_language": "python" + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/docs/pallas/gpu/pipelining.md b/docs/pallas/gpu/pipelining.md new file mode 100644 index 000000000000..8f779bd7ded5 --- /dev/null +++ b/docs/pallas/gpu/pipelining.md @@ -0,0 +1,332 @@ +--- +jupyter: + jupytext: + formats: ipynb,md + main_language: python + text_representation: + extension: .md + format_name: markdown + format_version: '1.3' + jupytext_version: 1.16.4 + kernelspec: + display_name: Python 3 + name: python3 +--- + + +(pallas_mgpu_pipelining)= + + + +## Mosaic GPU Pipelining + +This guide covers software pipelining using the Mosaic GPU backend for Pallas. + +For a general overview of the pipelining API in Pallas, we recommend that users first read {ref}`pallas_software_pipelining`. Pipelining in Pallas is programmed explicitly. For those who are familiar with Triton, this is a significant difference in programming model because in Triton, pipelining is an optimization that is done automatically by the compiler. + + + +```python id="dGAa3iO5DoRT" +import jax +from jax import lax +from jax import numpy as jnp +from jax.experimental.pallas import mosaic_gpu as plgpu +from jax.experimental import pallas as pl +import numpy as np +``` + + + +### Pipelining with Mosaic GPU + +The recommended approach to pipeline using Mosaic GPU is to use the `plgpu.emit_pipeline` function to pipeline over sequential loops (and to use `plgpu.kernel` to partition the problem in parallel over the CUDA grid). `emit_pipeline` follows a similar API as `pl.pallas_call` except it exposes a few additional GPU-specific options. + +- `body`, `grid` have similar semantics as in `pl.pallas_call`. The `grid` denotes how many invocations of the `body` function to run. In contrast with a CUDA grid, the pipeline grid is guaranteed to run sequentially. +- `in_specs` and `out_specs` also work similarly to `pl.pallas_call`, except they also accept `plgpu.BlockSpec` instances that can be used specify GPU-specific transforms, such as swizzling. See [memory reference transforms](https://docs.jax.dev/en/latest/pallas/gpu/reference.html#memory-reference-transforms) for more detail on available transformations. +- `max_concurrent_steps` controls the maximum number of concurrent memory transfers. Using additional concurrent steps will consume more SMEM to hold temporary buffers, but it can improve the utilization of the memory subsystem. We recommend autotuning this parameter. Low values (e.g. 2) can sometimes achieve higher occupancy (due to lower SMEM usage) which can improve throughput in ALU-heavy kernels, but will introduce more noise due to the hardware taking care of scheduling. Larger values (between 4 and 6) will work best for kernels that can't take advantage of extra occupancy +- `delay_release` allows the user to specify an additional number of iterations to wait before the buffer is re-used by the pipeline. For example, a buffer copied into SMEM on iteration 0 with `delay_release=1` and `max_concurrent_steps=2` will not be re-used until iteration 3, as opposed to iteration 2 for a standard double-buffered strategy. `delay_release=1` is necessary if you don't await a `plgpu.wgmma` operation on the pipeline operands, as otherwise the pipeline will begin overwriting the buffers while the WGMMA is still reading them. This is useful for certain optimizations such as allowing multiple async matmuls in flight to keep the tensor core pipeline filled, but care must be taken when using such a strategy as **omitting this parameter will silent data races**, and it reduces the efficiency of `emit_pipeline` as we are overlapping fewer memory transfers. + +#### Compatibility API using `pl.pallas_call` + +As an alternative to `emit_pipeline` and to maintain compatibility with Pallas TPU, Mosaic GPU also implements the existing `pl.pallas_call` API. By default, `pl.pallas_call` on Mosaic GPU will partition your kernel in parallel over the CUDA grid. You can opt-in to pipelining by passing in a `plgpu.CompilerParams` object as the `compiler_params` argument, which specifies the following options that are relevant for pipelining: +- `dimension_semantics`: A tuple of `Literal['parallel', 'sequential']` that specifies iteration semantics for each grid dimension. `parallel` will partition the corresponding dimension over the CUDA grid, and `sequential` dimensions will be pipelined sequentially. **Note that if no dimensions are marked `sequential`, no pipelining will happen!** +- `max_concurrent_steps`: identical to the option in `plgpu.emit_pipeline`. +- `delay_release`: identical to the option in `plgpu.emit_pipeline`. + +Pipelining lets you re-use scratch buffers across the sequential iterations of the grid (e.g. for implementing reductions). Additionally, `pallas_call` supports using `plgpu.BlockSpec` objects in place of `pl.BlockSpec` objects when using the Mosaic GPU backend, allowing you to specify GPU-specific memory transformations. + +We recommend that users use `plgpu.kernel` rather than `pl.pallas_call` as `plgpu.kernel` supports more features (such as specifying the number of warpgroups and warp specialization). + + + + +### GPU Memory Spaces + +Refs exist primarily in one of two memory spaces, which can be explicitly specified by the `memory_space` argument of `BlockSpec`, i.e. `BlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM)`. + +- `plgpu.GPUMemorySpace.SMEM` allocates a Ref in Shared Memory (SMEM). SMEM Refs can be dereferenced using array indexing syntax to store values in registers for compute, i.e. `x = y_ref[...]`. This memory space used for a Ref when using `emit_pipeline`. + +- `plgpu.GPUMemorySpace.GMEM` allocates a Ref in Global Memory (GMEM/HBM). Any Refs allocated in GMEM are not pipelined, and values cannot be accessed directly using array indexing operations. Instead, GMEM must be accessed via SMEM using `plgpu.copy_gmem_to_smem` for reading, or `plgpu.copy_smem_to_gmem` for writing, or pipelined into SMEM using `plgpu.emit_pipeline`. + +The primary purpose of `emit_pipeline` is used to overlap TensorCore computation with data transfers between GMEM and SMEM, since asynchronous copies between GMEM/SMEM have a long latency, but all TensorCore computation must operate on registers (or SMEM Refs in the case of matrix multiplication). + + + +### Example: Matmul Kernel on Hopper GPUs + + + +Let's begin with a matrix multiplication example designed to run on Hopper GPUs. This kernel utilizes the Hopper-specific `wgmma` (warpgroup matrix multiply accumulate) instruction. `wgmma` is issued by a single Mosaic GPU thread and runs asynchronously on the TensorCore. + +Our example kernel implements a blockwise matrix multiplication of two matrices of shape `[M, K] @ [K, N] = [M, N]`, where each output block is computed in parallel over the CUDA grid. This grid is specified as the `grid` argument to the outer `plgpu.kernel`, and parallelizes over the non-contracting dimensions M, N of the matrix multiplication. + + + + +
+ + + + + +Within a program instance, we run a sequential pipeline using `plgpu.emit_pipeline` that reduces over the contracting dimension K of the matrix multiplication. On each iteration of the pipeline, we load one tile from each input matrix, multiply them, and then store the result in an accumulator Ref (`plgpu.ACC`). `plgpu.ACC` is a special type of Ref that lives in registers and holds the intermediate results of WGMMA. Once we have accumulated over the entire contracting dimension, we write out the result to the output Ref. + +To perform the actual matrix multiplication, we call `plgpu.wgmma` with the accumulator, LHS, and RHS Refs as arguments in order to push the arguments into the TensorCore pipeline. All WGMMA operations are executed in order, so this can be viewed as pushing operations into a queue. Since `wgmma` is an asynchronous instruction, `plgpu.wgmma_wait(N)` is used to wait until there are no more than N `wgmma` operations left in-flight. In this particular implementation we wait for 1 in-flight WGMMA, meaning that the WGMMA we queue on the current iteration will be waited for on the next iteration. +- `wgmma` wants it's arguments to be in a specific format, defined in the [CUDA documentation](https://docs.nvidia.com/cuda/parallel-thread-execution/#register-fragments-and-shared-memory-matrix-layouts). These are implemented by the `TilingTransform` and `SwizzleTransform` transformations on the input BlockSpecs. Note that in the future transforms will be inferred automatically by Mosaic GPU and these will not need to be manually specified. See the [wgmma reference](https://docs.jax.dev/en/latest/pallas/gpu/reference.html#hopper-wgmma) for full details on using this instruction. +- We use the `delay_release` parameter in conjunction with `plgpu.wgmma_wait(1)` to always allow one `WGMMA` operation to stay in-flight in order to ensure good TensorCore utilization. Without this, we would be flushing the TensorCore pipeline on every iteration of the kernel. + + +```python id="6Vf5_VA9iCD1" +def matmul(a, b, tile_m=128, tile_n=128, swizzle=128): + dtype = jnp.float16 + swizzle_elems = swizzle // jnp.dtype(dtype).itemsize + tile_k = swizzle_elems + grid_m = m // tile_m + grid_k = k // tile_k + grid_n = n // tile_n + assert tile_m % swizzle_elems == 0 + + # Note: Transforms will be inferred automatically + # by Mosaic GPU in the future. + transforms = ( + plgpu.TilingTransform((8, swizzle_elems)), + plgpu.SwizzleTransform(swizzle), + ) + + def kernel(a_gmem, b_gmem, o_gmem, o_smem, acc): + def pipeline_step(_, a_smem, b_smem): + plgpu.wgmma(acc, a_smem, b_smem) + plgpu.wgmma_wait(1) + + # pl.program_id obtains the index into the grid. + pid_m = pl.program_id(0) + pid_n = pl.program_id(1) + + pipeline = plgpu.emit_pipeline( + pipeline_step, + in_specs=[ + plgpu.BlockSpec( + (tile_m, tile_k), lambda k: (pid_m, k), transforms=transforms + ), + plgpu.BlockSpec( + (tile_k, tile_n), lambda k: (k, pid_n), transforms=transforms + ), + ], + grid=(grid_k,), + max_concurrent_steps=2, + delay_release=1, + ) + + pipeline(a_gmem, b_gmem) + # Store WGMMA accumulator to SMEM and then to GMEM. + o_smem[...] = acc[...].astype(dtype) + plgpu.commit_smem() + m_slice = pl.ds(pid_m * tile_m, tile_m) + n_slice = pl.ds(pid_n * tile_n, tile_n) + plgpu.copy_smem_to_gmem(o_smem, o_gmem.at[m_slice, n_slice]) + plgpu.wait_smem_to_gmem(0) + + return plgpu.kernel( + kernel, + out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16), + scratch_shapes=dict( + o_smem=plgpu.SMEM((tile_m, tile_n), jnp.float16), + acc=plgpu.ACC((tile_m, tile_n), jnp.float32) + ), + # grid specifies the CUDA grid. + # Instances of `kernel` will be executed in parallel over this grid. + grid=(grid_m, grid_n), + grid_names=("m", "n"), + )(a, b) + +m = 132 * 128 +n = 4 * 128 +k = 10 * 64 +key1, key2 = jax.random.split(jax.random.key(42), 2) +a = jax.random.uniform(key1, shape=(m, k), dtype=jnp.float16) +b = jax.random.uniform(key2, shape=(k, n), dtype=jnp.float16) + +result = matmul(a, b) + +np.testing.assert_allclose(result, a @ b) +``` + + +### Warp Specialization + +Warp specialization is a technique where we program each warp/warpgroup to perform a single task in order to give the GPU hardware the flexibility to schedule them at runtime. Recall that each streaming multiprocessor (SM) in a GPU contains warp schedulers that can swap execution between warps, so for example when one warp is stalling it can begin executing a different warp. In practice, this can be more performant than programming a single instruction stream where the compiler must statically schedule the operations and attempt to overlap them optimally. + +In particular, we are interested in warpgroup specialization on Hopper+ GPUs, where it can be useful to have a separate warpgroup issuing TMAs (GMEM/SMEM copies) from the warpgroups performing arithmetic, since indexing calculations and issuing TMAs can take up a significant amount of time and potentially leave the TensorCore idle. The figure below depicts a standard, non-specialized kernel on the left where TMAs (async copies) and matrix multiplication are issued from a single instruction stream, and a warp-specialized version on the right where communication and arithmetic are handled on separate warpgroups. A *consumed barrier* is used to synchronize between the specialized warpgroups that signals to the memory warpgroup when it is safe to begin the next TMA. + + + + + + +
+ + + + + +Warp specialization can be enabled in Pallas by using the `plgpu.emit_pipeline_warp_specialized` helper. This pipeline helper handles all of the logic in the memory thread, and the user only needs to specify the work done in the compute threads. It shares the a similar API as the standard `emit_pipeline`, and currently supports the following arguments: + +```python +plgpu.emit_pipeline_warp_specialized( + body: Callable, + * + grid: tuple[int, ...], + in_specs: Sequence[pallas_core.BlockSpec] = (), + out_specs: Sequence[pallas_core.BlockSpec] = (), + max_concurrent_steps: int, + compute_context: Callable + num_compute_wgs: int, + memory_registers: int + wg_axis: str, + memory_thread_idx: int | None = None, +) +``` + +There are a few arguments specific to this pipeline emitter, which are: +- `num_compute_wgs` specifies how many compute threads/warpgroups to use. The pipeline emitter always uses a single memory thread, so in `plgpu.kernel` you should specify `num_threads=num_compute_wgs+1`. +- `memory_registers` controls how many registers to allocate to the memory thread. The remaining registers are partitioned evenly among the compute threads. The default value is 40 and should be adjusted up or down depending on whether register spills are encountered. +- `wg_axis` the name of the thread/warpgroup axis (as specified by the `thead_name` argument of `plgpu.kernel`). +- `memory_thread_idx` specifies which Pallas thread to designate as the memory thread. Defaults to the last thread. +- `compute_context` is a enables you to specify a prologue/epilogue to the pipeline that only runs in the compute thread. The function allows you to define the initialization and consumption of a loop carry through the pipeline. All compute thread specific arrays should be instantiated here so the memory thread does not materialize them in registers -- otherwise, you may experience slowdowns due to register spills. + +The pipeline body of the warp specialized pipeline is run in parallel by all compute threads, and SMEM is shared between compute threads since they are scheduled within the same CUDA block.`lax.axis_index` can be used inside the kernel to obtain the Pallas thread index in order to divide up work amongst compute threads. + + + + +### Example: Matrix Multiplication with Warp Specialization + +The following example extends the previous matrix multiplication example to use warp specialization. This particular kernel uses 2 compute threads, which operate on separate columns of the RHS matrix but share the same LHS. Each invocation of the pipeline therefore computes 2 adjacent blocks in the output matrix. + + + + + +
+ + + + +We use the `compute_context` pattern to initialize the WGMMA accumulator, and copy the final accumulator from registers into SMEM. Here, the compute context is defined in the function `compute_thread`. It is critical that the accumulator be created inside of the `compute_thread` function to avoid allocating it in the memory thread which would waste registers. To perform the WGMMA, we wrap the `wgmma` instruction in a `pl.run_state` in order to create an accumulator ref that is initialized to the carry value. + +Instead of using `pl.pallas_call` to call the kernel, we instead use the GPU-specific `plgpu.kernel` entry point. `plgpu.kernel` allows us to specify the number of threads to launch per CUDA block via the `num_threads` argument, and allows us to specify a `thread_name` we can use to query the Pallas thread index inside of the kernel. + + + +```python id="EJhWnwJlFGaT" +def matmul_warp_specialized(a, b, tile_m=128, tile_n=128, swizzle=128, + compute_wgs=2): + dtype = jnp.float16 + elems_128b = swizzle // jnp.dtype(dtype).itemsize + tile_k = elems_128b + grid_m = m // tile_m + grid_k = k // tile_k + grid_n = n // tile_n + assert tile_m % elems_128b == 0 + + transforms = ( + plgpu.TilingTransform((8, elems_128b)), + plgpu.SwizzleTransform(128), + ) + + def kernel(a_gmem, b_gmem, o_gmem, o_smem): + wg_idx = lax.axis_index("wg") + wg_slice = pl.ds(wg_idx * tile_n, tile_n) + # pl.program_id obtains the index into the pallas_call grid. + pid_m = pl.program_id(0) + pid_n = pl.program_id(1) + + def compute_thread(pipeline): + acc = plgpu.layout_cast( + jnp.full((tile_m, tile_n), 0, dtype=jnp.float32), plgpu.Layout.WGMMA, + ) + # yield marks the place where the pipelined loop will be inserted. + # Its argument are the initial carry values, and its result is the carry + # value after the loop completes. + final_acc = pipeline(acc) + o_smem[:, wg_slice] = final_acc[...].astype(dtype) + + def kernel_body(_, a_smem, b_smem, carry): + acc = carry + b_smem_wg = b_smem.at[:, wg_slice] + def do_wgmma(acc_ref): + plgpu.wgmma(acc_ref, a_smem, b_smem_wg) + acc = pl.run_state(do_wgmma)( + plgpu.ACC.init(acc)) + return acc + + pipeline = plgpu.emit_pipeline_warp_specialized( + kernel_body, + in_specs=[ + plgpu.BlockSpec( + (tile_m, tile_k), lambda k: (pid_m, k), transforms=transforms + ), + plgpu.BlockSpec( + (tile_k, tile_n * 2), lambda k: (k, pid_n),transforms=transforms + ), + ], + grid=(grid_k,), + compute_context=compute_thread, + max_concurrent_steps=2, + num_compute_wgs=compute_wgs, + memory_registers=40, + memory_thread_idx=2, + wg_axis="wg", + ) + # Call the pipeline + pipeline(a_gmem, b_gmem) + # Copy the output from SMEM to GMEM. + plgpu.commit_smem() + m_slice = pl.ds(pid_m * tile_m, tile_m) + n_slice = pl.ds(pid_n * tile_n * 2, tile_n * 2) + plgpu.copy_smem_to_gmem(o_smem, o_gmem.at[m_slice, n_slice]) + plgpu.wait_smem_to_gmem(0) + + return plgpu.kernel( + kernel, + out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16), + scratch_shapes=dict( + o_smem=plgpu.SMEM((tile_m, tile_n * 2), jnp.float16) + ), + grid=(grid_m, grid_n // 2), + grid_names=("m", "n"), + num_threads=3, # 2 compute, 1 memory. + thread_name="wg" + )(a, b) + +m = 132 * 128 +n = 4 * 128 +k = 10 * 64 +key1, key2 = jax.random.split(jax.random.key(42), 2) +a = jax.random.uniform(key1, shape=(m, k), dtype=jnp.float16) +b = jax.random.uniform(key2, shape=(k, n), dtype=jnp.float16) + +result = matmul_warp_specialized(a, b) + +np.testing.assert_allclose(result, a @ b) +``` diff --git a/docs/pallas/gpu/reference.md b/docs/pallas/gpu/reference.md new file mode 100644 index 000000000000..99db4d2b1354 --- /dev/null +++ b/docs/pallas/gpu/reference.md @@ -0,0 +1,1510 @@ +# Writing Mosaic GPU kernels with Pallas + +This page is a reference for the most important features of the Pallas:MGPU backend. +It's not a tutorial and as such we do not expect everyone to read it top to bottom. +Still, it is worth going over +just to familiarise yourself with some patterns you can find in other tutorials. + +In the following examples, we're going to assume the following imports are in scope: +```python +import jax.experimental.pallas as pl +import jax.experimental.pallas.mosaic_gpu as plgpu +``` + +## What is a GPU? + +Technically, the NVIDIA GPU architecture looks as follows: the GPU is partitioned into +_streaming multiprocessors_ (SMs). The way this manifests in the CUDA programming model +is that each _CUDA thread block_ (or CTA) is scheduled on exactly one SM, but multiple +blocks can be scheduled onto a single SM at a time. + +Each SM contains a chunk of fast memory called _shared memory_ (SMEM) and 4 subdivisions, +each containing a _warp scheduler_ and compute units (ALU, TensorCore, ...). +This is also reflected in the CUDA programs: each _warp_ (a group of consecutive 32 CUDA +threads in a block) is assigned to one of those subdivisions in a round-robin fashion. +Similarly to blocks, each warp is assigned to exactly one subdivision (it never migrates), +but multiple warps can be assigned to the same SM subdivision. At each clock cycle, the +warp scheduler from each subdivision tries to select one of its resident warps to execute +the next instruction. + +
A diagram of one NVIDIA SM
+ +Going further, recent CUDA versions also outline the concept of a _warpgroup_, which are +4 consecutive warps. Knowing how the hardware looks like, we can see where this is coming +from: 4 consecutive warps occupy the 4 quarters of an SM and let us issue instructions +that utilize the whole SM. + +```{note} +A GPU can be viewed in many different ways and in here we want to focus on a slightly +simplified model that is very TensorCore-centric. This should help you navigate the +complexities of writing kernels involving the TensorCore, but keep in mind that the +real picture is more complicated. +``` + +For our purposes, TensorCore operations have grown so big that it no longer makes much +sense to follow the CUDA model. As such, to us, a GPU is a collection of single-threaded cores +(SMs) with one thread of Pallas:MGPU corresponding to a CUDA warpgroup. In this model, each +operation you perform in the kernel occupies the whole CUDA warpgroup, and its constituent +warps always run in lockstep (modulo the jitter from hardware scheduling) and never take +different paths through control flow (with the small exception of `core_map` that we will +discuss later). One notable addition here is that we still allow you to co-schedule multiple +of those Pallas-level threads on the same SM so that they can cooperate and communicate +through shared memory (we realize that by putting them in the same CUDA block). + +```{note} +From now on, whenever we say "thread", we refer to the Pallas thread, not a CUDA thread/lane. +``` + +```{note} +This is very similar to a programming model popularized by [Triton](https://triton-lang.org/), +but as you will see there are a few differences. Mosaic GPU tends to be more low level, +which usually means you will have to put in more work, but it also puts you more in control. +In our view both approaches have their merits and we encourage you to pick the backend that +suits your needs the best! Pallas supports and will continue to support Triton as an alternative +GPU backend. +``` + +### In-order execution & using multiple hardware units + +Unlike more complicated CPU architectures GPU only support in-order execution. That, however, +does not mean that at any given time only a single instruction is running! Each SM quarter +has multiple independent functional units: TensorCore, Arithmetic logic unit (ALU), +Load/Store (LSU), Special function unit (SFU). If the first instruction targets one of the +units and is followed by another one (that does not use the result of the first one), then the +warp scheduler can issue the second one before the first one completes. This is often referred +to as instruction-level parallelism (ILP) and is a common theme in modern TensorCore kernels: +TensorCore operations are so big and take so many cycles to complete, that it is a waste to not +try to use other units in the meantime. + +To extend this even further, we can take advantage of this hardware-unit-level parallelism by +allowing multiple Pallas threads to run concurrently. If one of the threads primarily +occupies the ALU, while another one primarily issues TensorCore related instructions, we can +take advantage of the efficient context switching built into the warp schedulers to keep both +units busy. This is one of the core idea behind algorithms such as [FlashAttention 3](https://arxiv.org/abs/2407.08608) +or [CUTLASS ping-pong matmul kernels](https://pytorch.org/blog/cutlass-ping-pong-gemm-kernel/). + +For more information on how warp scheduling and instruction issue works, we recommend reading +[Analyzing Modern NVIDIA GPU cores](https://arxiv.org/abs/2503.20481). + +### Memory spaces + +The GPU features a few different memory spaces that can be totally ordered from largest (in +terms of capacity) and slowest (in both total bandwidth and latency of a single access). + +
A diagram of memory spaces of an NVIDIA GPU
+ +The biggest memory space is `plgpu.GMEM`, for _global memory_. In recent data-center grade GPUs +this memory space is often measured in tens or even hudreds of gigabytes, but it is also the +slowest one. + +The next memory space, used for the L2 cache, is also more or less global in the +sense that it is shared by the whole GPU, but its use can only be influenced indirectly through +cache hints. As such, there's no way to manually place values in there and so this memory space +is not exposed in Pallas:MGPU. While only about a 100MB in size, this memory has considerably +higher bandwidth than GMEM, and so it is still often recommended to take advantage of it while +writing high-performance kernels. + +Next in line is _shared memory_, or `plgpu.SMEM`. This memory is located directly inside each SM +and so it is partitioned. Unless block clusters are used (see the section of clusters below), +each block is only allowed to access its own SMEM allocations. + +Finally, the lowest level memory space is the _register memory_. This is where every single value +(i.e. JAX array) in a Pallas kernel will be located. If the compiler runs out of registers to +store those arrays, it will insert _spills_, meaning that it will periodically store and reload +values to memory. Those spills often introduce other significant performance degradations and so +we recommend avoiding them. The warning messages about spills can be clearly seen in the `ptxas` +messages during kernel compilation. To make them visible, run with `MOSAIC_GPU_DUMP_PTXAS=1` +in your environment. + +The Blackwell GPU generation, has one additional memory space called _tensor memory_ or `plgpu.TMEM`. +TMEM is very similar to register memory, only it is explicitly allocated and managed by you. +It is used to store the MMA accumulator, operand metadata (for sparsity or scaling), +and optionally the left MMA operand. See the Blackwell MMA section for more information about TMEM. + +#### Requesting/allocating memory in specific memory spaces + +Kernel inputs or outputs are placed in SMEM by default. If you want to access them as GMEM references +add `memory_space=plgpu.GMEM` to their `BlockSpec`. If you want the kernel to be called with the whole +input or output array in GMEM, it is sufficient to specify `BlockSpec(memory_space=plgpu.GMEM)`. + +`SMEM` and `TMEM` can be allocated explicitly in the `scratch_shapes` argument of `pl.pallas_call`, +or using `pl.run_scoped`. To allocate a reference, simply call the memory space object with the +requested shape and dtype. For example: `plgpu.SMEM((128, 128), jnp.float16)` will allocate a 128x128 +array of float16 elements in shared memory. + +#### Taking advantage of the L2 cache + +While the L2 cache cannot be managed manually, its noticeably higher bandwidth compared to global +memory makes it worth thinking about. The simplest way to take advantage of it, is to reorder +the parallel grid dimensions so that invocations that are scheduled in similar time periods also +access the same input data. + +While the CUDA programming model does not guarantee anything about the order in which the blocks +are assigned to SMs, in recent generations the heuristic seems to simply iterate over the +`(x, y, z)` CUDA grids in column-major order (i.e. `x` is the fastest-changing dimension and +`z` is the slowest). Similarly, Pallas:MGPU does not guarantee how a user-specified grid is mapped to +the CUDA grid (Pallas supports grids of arbitrary rank, not just up to 3D). However, you can assume that +the iteration will happen in _row-major_ order. That is, if a grid has dimensions `(a, b)`, then +`b` will be the fastest-changing dimension and `a` will be the slower one. + +To give a practical example of this, consider a plain matrix multiplication kernel. There, one +usually uses two parallel grid dimensions `(m, n)`, corresponding to tiling the two non-contracting +dimensions. If we use this simple scheme, in Pallas:MGPU all programs with id `(0, ...)` will be +scheduled before any block with id `(1, ...)`. And, collectively, the programs with `m=0` have to +read all of the `B` operand! If the `n` or `k` dimensions are very large, there is no chance that +we'll be able to get cache hits from the `(1, ...)` programs from accesses made by the `(0, ...)` +programs. For simplicity, assuming we can only run 16 blocks at a time, we see this access pattern +from the first scheduled wave: + +
+ + Your browser does not support SVGs or scripting is disabled. + This would be an image showing the access pattern of first 16 blocks without grid tiling. + +
+ +However, if we simply rearrange the grid to be `(m // mt, n, mt)` (and then replace `pl.program_id(0)` +with `pl.program_id(0) * mt + pl.program_id(2)` in the kernel), it is straightforward to see that a +band of programs along both dimensions will be scheduled concurrently (instead of scheduling a single +row). This greatly increases the number of concurrent programs that load similar slices of data, +usually significantly improves the L2 utilization and hence the overall performance of the kernel +(if it was memory bound). Continuing our example with 16 blocks and using `mt=4`, we get the following +access pattern: + +
+ + Your browser does not support SVGs or scripting is disabled. + This would be an image showing the access pattern of first 16 blocks with grid tiling. + +
+ +Note that even though the number of active blocks hasn't changed, the total footprint of the data they +access has halved! We get a much higher chance of getting L2 hits now. + +## Array layouts and memory reference transforms + +In Pallas, the data structures you work with (arrays and references) have a +**logical shape** (e.g., a 128x128 matrix). This +logical shape must be mapped to a **physical representation** (how the data is +actually represented in the GPU's memory). The specific mapping depends on where the +data resides: + +1. **Array Layouts:** Arrays are stored in register memory and we call this mapping + a _layout_. Layouts define how the elements of an array are + distributed across the registers available to the CUDA lanes that form a Pallas thread. +2. **Memory Reference Transforms:** For mutable references pointing + to `SMEM`, this mapping is called a _transform_. + Transforms describe how the logical data structure is arranged within that + block of memory. + +These concepts are crucial for performance, especially when interacting with +specialized hardware units like TensorCores or optimizing memory access +patterns. + +```{note} +We are working on a mode that will deal with assigning layouts and transforms fully +automatically (although with way to provide hints and more control). The APIs listed +below will likely continue to function, but will become optional. +``` + +### Memory reference transforms + +Transforms are applied when a memory reference is first allocated. Pallas +primitives that operate on these references will automatically account for their +associated transforms. + +``` +def body(..., scratch_ref): + # Asynchronous copy will reformat the GMEM data to match the SMEM transforms + plgpu.copy_gmem_to_smem(..., scratch_ref, barrier) + plgpu.barrier_wait(barrier) + plgpu.wgmma(..., scratch_ref) # wgmma only accepts properly transformed refs + ... +``` + +There are two ways in which references are allocated and each has a way to select +the desired transforms: + +**1. Using `plgpu.BlockSpec`** + +```python +transforms = (plgpu.TileTransform((8, 64)), plgpu.SwizzleTransform(128)) +f = pl.pallas_call( + in_specs=plgpu.BlockSpec(in_block_shape, in_index_map, transforms=transforms), + out_specs=plgpu.BlockSpec(out_block_shape, out_index_map, transforms=transforms), + ... +) +``` + +Note that unlike `plgpu.BlockSpec`, `pl.BlockSpec` does *not* allow specifying +transforms. + +**2. Specifying the `transforms` argument on the allocated `SMEM`** + +```python +transforms = (plgpu.TileTransform((8, 64)), plgpu.SwizzleTransform(128)) +f = pl.pallas_call( + scratch_shapes=plgpu.SMEM((128, 128), jnp.float16, transforms=transforms), + ... +) +``` + +The available transforms are: +* `plgpu.TileTransform(tile_shape)`, which organizes the data into contiguous, + non-overlapping tiles of shape `tile_shape`. The data of one tile is always + fully linearized (row-major), before another tile begins (tiles are also + traversed in row-major order). As an example, applying `TileTransform((8, + 64))` to a `(128, 128)` reference means the data corresponding to the logical + slice `[0:8, 0:64]` will be stored first (row-major), followed by + `[0:8, 64:128], [8:16, 0:64], [8:16, 64:128]`, and so on. A different way to achieve + this would be to take the input array `x` and traverse + `x.reshape(128 // 8, 128 // 64, 8, 64).transpose(0, 2, 1, 3)` in row-major order. +* `plgpu.SwizzleTransform(swizzle_in_bytes)`, which transforms the data as described in the + [PTX docs](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-swizzling-modes) and + [CUDA docs](https://docs.nvidia.com/cuda/cuda-c-programming-guide/#the-swizzle-modes). + Swizzling is useful, because it allows transferring data in MMA-related layouts + between register and shared memory without bank conflicts. The exact details + of how the memory looks like after swizzling _are not that important_, since + all primitives will account for it automatically. Note that the swizzle amount + is specified in bytes (only 128, 64, 32 and 16 are supported), and is usually + accompanied by a `TileTransform` (which uses elements in its shape!). +* `plgpu.TransposeTransform(permutation)`, which permutes the dimensions of the array before it is linearized. + This is primarily useful in that it lets you change the layout during the GMEM-SMEM copies (only + do keep in mind that changing the minormost/last dimension is not supported by the hardware). + +```{note} +When performing GMEM-SMEM or SMEM-GMEM copies and a `plgpu.TileTransform` is +applied on the SMEM reference, the offsets into the GMEM reference must be +aligned with the tile sizes. If that is not the case, the transfer may produce +wrong results. +``` + +### Array layouts + +There are a few useful layouts we have defined for you so far: +* `plgpu.Layout.WGMMA`, which is the layout in which the Hopper-generation TensorCore + expects the MMA accumulator or 16-bit input operands to have in registers. +* `plgpu.Layout.WGMMA_ROW`, which is the layout obtained after the above after reducing + it along the rows. Re-broadcasting the rows is free and will produce a value with `WGMMA` + layout. +* `plgpu.Layout.WGMMA_COL`, which is an analogue of the one above, only reduced along + columns instead of rows. +* `plgpu.Layout.WG_STRIDED`, where the value is partitioned equally among the 128 + CUDA lanes making up a Pallas thread. The consecutive elements (after vectorization) + are assigned to the lanes in a round-robin fashion. Very simple and effective when + no interaction with TensorCores is needed. +* `plgpu.Layout.WG_SPLAT`, indicating that the value is constant. Each CUDA lane will + hold a single register that contains the value. You normally never have to interact + with this layout, as it is implicitly used when constant values are created and + is always implicitly convertible to other layouts. + +At the moment, in the default mode of operation, array layout propagation happens +only in a forward direction and there is little implicit support for reconciling +layout conflicts: only splat layouts can be implicitly converted into any other +layout. If you e.g. try to add two arrays that have a different layout, the lowering +will complain and fail. There are very limited facilities that let you convert between +layouts, and we usually recommend storing the value to SMEM and reading it back in +the target layout. + +## MMA (TensorCore) + +In this section, we focus on how Pallas:MGPU kernels can utilize the TensorCore unit. +The programming interface of the TensorCore changes significantly between different +NVIDIA GPU generations, which is why the lowest-level interfaces differ in Pallas:MGPU as well. + +Each MMA operation is associated with three operands: +* the accumulator `D` of shape `(M, N)`, +* the left input `A` of shape `(M, K)`, +* the right input `B` of shape `(K, N)`. +All operands must have the same element type. + +Each use of MMA involves a few steps: +1. Allocating the space for the accumulator (MMA implicitly performs `D += A @ B`) +2. Preparing the `A` and `B` operands +3. Issuing the operation +4. Waiting for the operation to complete +5. Reading out the result + +Steps 2.-4. are usually performed in a loop over the contraction dimension (`K`). + +(memory-space-a-b-mma)= +### Memory space of `A` and `B` operands + +The `A` and `B` operands are generally best passed in through SMEM, where they can +be conveniently loaded using `plgpu.copy_gmem_to_smem`. For those operands to be +compatible with MMA operations, they need to have the appropriate tiling and swizzling +transforms specified upon their allocation. For all currently supported generations, +the TensorCore requires the data to be laid out into row-major 2D tiles of shape +`(8, swizzle_elems)`, where `swizzle_elems` is derived by dividing the swizzle by the +element type bytewidth. The currently supported swizzles are: 128, 64, and 32. Larger +swizzles are preferable as they improve the performance of GMEM-to-SMEM copies. + +```python +def mma_transforms(shape_dtype: jax.ShapeDtypeStruct): + assert len(shape_dtype.shape) == 2 + if shape_dtype.shape[0] % 8: + raise ValueError("Number of rows must be divisible by 8") + for swizzle_bytes in (128, 64, 32): + swizzle_elems = swizzle_bytes // shape_dtype.dtype.itemsize + if shape_dtype.shape[-1] % swizzle_elems == 0: + return (plgpu.TilingTransform((8, swizzle_elems)), + plgpu.SwizzleTransform(swizzle_bytes)) + raise ValueError("Failed to find transforms for the specified window type") +``` + +If the operands need to be transformed, the `A` operand can be passed in through a different +memory space (architecture dependent, see below). The `B` operand _must_ be located in SMEM. + +### Transposed operands + +When performing MMA on 16-bit operands, the TensorCore can automatically transpose the +input data. For example, the `A` reference is allowed to be of shape `(K, M)`, but it +has to be transposed before passing it into the mma function. For example: +```python +assert acc_ref.shape == (M, N) and a_ref.shape == (K, M) and b_ref.shape == (K, N) +a_ref_t = plgpu.transpose_ref(a_ref, (1, 0)) +assert a_ref_t.shape == (M, K) # The shape expected by plgpu.wgmma +plgpu.wgmma(acc, a_ref_t, b_ref) +``` +An analogous operation is allowed on the `B` reference in this case too. + +### Hopper (`wgmma`) + +In this section, we cover the basics of using the Hopper-generation TensorCores, exposed in +PTX as the [`wgmma.mma_async` instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-instructions-wgmma-mma). + +#### Allocating the accumulator + +In the Hopper hardware architecture the accumulator is allocated in registers, but in Pallas +it is modeled as a mutable reference, as each MMA operation accumulates in-place. +There are two ways to allocate the accumulator. + +To create a zero-initialized accumulator you can use `pl.run_scoped` with a +`plgpu.ACC((m, n), dtype)` type. +```python +def compute(acc_ref): + ... + return acc_ref[...] +output = pl.run_scoped(compute, plgpu.ACC((m, n), jnp.float32)) +``` +Dereferencing the accumulator reference, as seen in the end of the `compute` function will +implicitly await all outstanding WGMMA operations. + +If you'd like to initialize it with an existing array, you can use `pl.run_state` with +`plgpu.ACC.init(init_array)`: +```python +def compute(acc_ref): + ... + return # pl.run_state only returns the final value of the accumulator +output = pl.run_state(compute)(plgpu.ACC.init(init_array)) +``` +If `pl.run_state` has accumulator operands, it implicitly awaits all outstanding WGMMA +operations before returning the final values. + +#### Preparing the `A` and `B` operands + +As discussed above, we recommend passing in `A` and `B` through shared memory. In this +case the correct tiling and swizzling transforms must be specified. + +`plgpu.wgmma` additionally allows passing in `A` through registers (i.e. not an SMEM +reference but as a regular JAX array). This mode, however, comes with a number of +significant drawbacks and it is very difficult to ensure sufficient synchronization to +make this safe. + +TODO: Explain the conditions under which it is acceptable to do this. + +#### Issuing the operation + +The supported MMA shapes are such that: +* `M` is divisible by 64 +* `N` is divisible by 8 and not greater than 256 +* `K` is a multiple of `swizzle` divided by the operand's element type bytewidth + +The currently supported data types are: `jnp.float32`, `jnp.bfloat16` and `jnp.float16`. +The accumulator `D` must be a `jnp.float32`, with the exception of `jnp.float16` inputs, +in which case it is allowed to be `jnp.float16` as well. + +#### Waiting for the operation to complete + +Each `plgpu.wgmma` call implicitly synchronizes with all previous `plgpu.wgmma` calls, such +that once control returns from it, we guarantee that no WGMMA other than the last issued +one is still running. As such, any SMEM regions that were read by previously issued WGMMA +instructions can be reused. This is especially relevant for pipelining WGMMA with async memory copies: +```python +buffers = 3 # In reality you might want even more +assert a_smem.shape == (buffers, m, k) +assert b_smem.shape == (buffers, k, n) +assert acc_ref.shape == (m, n) + +def fetch_a_b(ki, slot): + a_slice = ... # Replace with the right M/K slice + b_slice = ... # Replace with the right K/N slice + plgpu.copy_gmem_to_smem(a_gmem.at[a_slice], a_smem.at[slot], a_loaded.at[slot]) + plgpu.copy_gmem_to_smem(b_gmem.at[b_slice], b_smem.at[slot], b_loaded.at[slot]) + +def loop_body(i, _): + slot = jax.lax.rem(i, buffers) + plgpu.barrier_wait(a_loaded.at[slot]) + plgpu.barrier_wait(b_loaded.at[slot]) + plgpu.wgmma(acc_ref, a_smem.at[slot], b_smem.at[slot]) + # We know that only the last issued WGMMA is running, so we can issue a async load in + # into the other buffer + load_i = i + buffers - 1 + load_slot = jax.lax.rem(load_i, buffers) + @pl.when(jnp.logical_and(load_i >= buffers, load_i < num_steps)) + def _do_fetch(): + fetch_a_b(load_i, slot) +for slot in range(buffers): + fetch_a_b(slot, slot) +jax.lax.fori_loop(0, num_steps, loop_body, None) +``` + +### Blackwell (`tcgen05`) + +The Blackwell generation has significantly redesigned the TensorCore subunit. +It is now significantly more independent from the regular warp schedulers and +no longer uses or even supports using registers as its operands. In their place, +a new memory space called _tensor memory_ (TMEM) has been introduced. What's +more TensorCores from pairs of SMs can now pool their resources and compute +larger MMA operations that span both SMs. We call this a ["collective MMA operation"](#collective-mma). + +#### Allocating the accumulator / Using TMEM + +TMEM references can be allocated in the same way in which all other references +are allocated---using {py:func}`pl.run_scoped `: + +```python +@functools.partial(pl.run_scoped, tmem_ref=plgpu.TMEM((128, 128), jnp.float32)) +def barrier_scope(tmem_ref): + ... +``` + +Not all shapes can be allocated in TMEM. Only 2D references are supported, and +the number of rows (the size of the first dimension) must be 128 or 64 at the +moment. + +What's more, if the data type has a bitwidth smaller than 32-bits, it is necessary +to declare if the allocation is supposed to be packed (e.g. putting two 16-bit +elements into a single 32-bit cell in TMEM) or not (with each element padded up +to 32-bits). MMA accumulators (fp32 or fp16) are never packed, but if the left +operand it passed in TMEM, it must always be packed: + +```python +@functools.partial(pl.run_scoped, + acc_ref=plgpu.TMEM((128, 128), jnp.float16, packed=False), + lhs_ref=plgpu.TMEM((128, 128), jnp.float16, packed=True)) +def barrier_scope(acc_ref, lhs_ref): + plgpu.tcgen05_mma(acc_ref, lhs_ref, rhs_smem_ref, ...) + ... +``` + +Another interesting complication with TMEM is that all operations on it are asynchronous. +For that reason, reads and writes using the Python subscript syntax that are normally +used e.g. for SMEM are not allowed for TMEM. + +(tmem-loads)= +##### Loads + +Loads can be performed using {py:func}`plgpu.async_load_tmem ` and awaited using {py:func}`plgpu.wait_load_tmem `: + +```python +smem_ref[...] = plgpu.async_load_tmem(tmem_ref) +plgpu.commit_smem() +plgpu.copy_smem_to_gmem(smem_ref, gmem_ref) +plgpu.wait_smem_to_gmem(0) +plgpu.wait_load_tmem() # Wait for the read to fully complete before we overwrite tmem_ref again. +``` + +The load semantics are quite confusing, in that the array returned from the load +can be safely used without any additional synchronization. However, if the read +TMEM region is ever overwritten again (e.g. by a store or an MMA operation), the +thread that issued the load must first call `plgpu.wait_load_tmem()` to ensure +the program remains race-free. + +```{note} +One way to make peace with this seemingly causality-breaking behavior (data +arrives in registers before it is fully read from TMEM) is to consider that it +might be an effect of an interaction of a limitation and a convenience feature +in the PTX compiler. We don't know if this is true, but at least it makes sense. + +The convenience feature is that the compiler can reliably track the usage of +registers produced by TMEM loads and will insert the minimum number of delays +necessary to ensure the data arrives from TMEM before it's used. The read +operation is unrolled into many instructions, meaning that they don't have to +all be awaited before we start consuming the registers filled in by the first load. +This is why we don't need to guard the use of the result. + +The limitation is that the compiler cannot reliably perform alias analysis on +TMEM loads and stores, which is why any load and store that is not separated +by an explicit wait is considered safe to execute concurrently. The alternative +would unnecessarily pessimize the performance of loads and stores that are truly +unrelated. This is why we need to explicitly wait before we reuse TMEM again. +``` + +##### Stores + +Conversely, stores are performed using {py:func}`plgpu.async_store_tmem ` and awaited using {py:func}`plgpu.commit_tmem `: + +```python +plgpu.async_store_tmem(tmem_ref, smem_ref[...]) +plgpu.commit_tmem() +smem_ref2[...] = plgpu.async_load_tmem(tmem_ref) # Safe to read from tmem_ref now +``` + +#### Preparing the `A` and `B` operands + +We recommend passing in `A` and `B` through shared memory. In this case the +[correct tiling and swizzling transforms must be specified](#memory-space-a-b-mma). +The `A` operand can be passed in as a TMEM reference as well, but it must be packed. + +#### Issuing the operation + +The supported **non-collective** MMA shapes are such that: +* `M` is 64 or 128 +* `N` is divisible by 8 and not greater than 512 +* `K` is a multiple of `8 * swizzle` divided by the bitwidth of element type + +The supported [**collective** MMA](#collective-mma) shapes are such that: +* `M` is 128 or 256 (half of that per block) +* `N` is divisible by 8 and not greater than 256 (not greater than 128 in each block) +* `K` is a multiple of `8 * swizzle` divided by the bitwidth of element type + +The currently supported floating-point data types are: `jnp.bfloat16`, +`jnp.float16`, `jnp.float8_e5m2`, `jnp.float8_e4m3fn`. The accumulator can be +a `jnp.float32` or `jnp.float16`, with the exception of `jnp.bfloat16` when it +must be a `jnp.float32`. + +The only currently supported integer data type is `jnp.int8` with a `jnp.int32` +accumulator. + +```{note} +According to our benchmarks, here are some performance rules-of-thumb: + +* Non-collective MMA should always use M=128 and N >= 128. + - M=64 causes a significant performance drop. + - N=64 causes a noticeable performance drop, but not as significant as M=64. +* Collective MMA is always reasonably fast, but not faster than non-collective MMA. + - The biggest benefit from collective MMA is not higher TensorCore throughput + but the ability to share data between SMs, allowing to increase the arithmetic + intensity of the kernel. +* Swizzle and transposes do not seem to affect performance in a significant way. +``` + +#### Waiting for the operation to complete + +Awaiting the result of a {py:func}`plgpu.tcgen05_mma ` +call requires the use of a `Barrier`. We recommend reading through the reference +documentation for [`Barrier`s](#barrier), and especially its +[Blackwell-related subsection](#awaiting-tcgen05-instructions) for more information. + +If the barrier is passed in directly to +the {py:func}`plgpu.tcgen05_mma `, +completing a wait on that barrier will indicate that the final accumulator has +been written to TMEM. For example: + +```python +@functools.partial(pl.run_scoped, barrier_ref=plgpu.Barrier(orders_tensor_core=True)) +def barrier_scope(barrier_ref): + plgpu.tcgen05_mma(acc_tmem, lhs_ref, rhs_ref, barrier_ref, accumulate=False) + plgpu.barrier_wait(barrier_ref) + # We can read the result now. + result = plgpu.async_load_tmem(acc_tmem) + ... +``` + +If no barrier is given to {py:func}`plgpu.tcgen05_mma `, +its completion will be tracked only once {py:func}`plgpu.tcgen05_commit ` is called: + +```python +@functools.partial(pl.run_scoped, barrier_ref=plgpu.Barrier(orders_tensor_core=True)) +def barrier_scope(barrier_ref): + plgpu.tcgen05_mma(acc_tmem, lhs_ref, rhs_ref, accumulate=False) + plgpu.tcgen05_mma(acc_tmem, lhs_ref2, rhs_ref2) + plgpu.tcgen05_commit(barrier_ref) + plgpu.barrier_wait(barrier_ref) + # We can read the result now. Both MMAs have completed. + result = plgpu.async_load_tmem(acc_tmem) + ... +``` + +(collective-mma)= +#### Collective MMA + +The Blackwell generation gains a new way to perform MMA operations, where the +TensorCores of 2 SMs in a cluster collaborate on a single MMA operation. The +`B` operand from each SM is shared with the other. The `D` and `A` operands are +local to each SM and not shared. + +
A diagram showing the partitioning of operands in a collective MMA
+ +This means that to perform a collective MMA with shape M, N, and K, the operands +in each of the two Pallas threads should be of sizes: `(M // 2, K)` for `A`, +`(K, N // 2)` for `B` and `(M // 2, N)` for `D` (the accumulator). Stacking the +two accumulators on top would recover the result of performing a MxNxK matrix +multiplication. + +To make loading of the `B` operand easier, {py:func}`plgpu.copy_gmem_to_smem ` +can be used together with `collective_axes` and `partitioned_axis` to indicate +that the two Pallas threads along the collective axis should load the same slice, +but each will only obtain half of it. Unlike a copy with `collective_axes` alone +it does not utilize TMA multicast (since each thread loads a distinct slice of +data), but it can simplify the indexing logic a bit. + +```python +plgpu.copy_gmem_to_smem( + b_gmem, # [K, N] + b_smem, # [K, N // 2] + b_tma_barrier, + collective_axes="x", + partitioned_axis=1, +) +``` + +## Using `core_map` + +`pl.pallas_call` is suitable for kernels where a single Pallas thread can +perform the whole computation for an entire CUDA block. The `pl.core_map` +function relaxes this restriction, allowing for using multiple threads within a +single block (e.g. for warp specialization) or across multiple blocks in a block +cluster (e.g. to utilize multicast TMA). + +### Replacing `pl.pallas_call` with `pl.core_map` or `plgpu.kernel` + +Let us begin with a simple Pallas kernel that increments an array: + +```python +@functools.partial( + pl.pallas_call, + grid=(2,), + in_specs=[pl.BlockSpec(block_shape=(128,), index_map=lambda i: (i,))], + out_specs=pl.BlockSpec(block_shape=(128,), index_map=lambda i: (i,)), + out_shape=jax.ShapeDtypeStruct((256,), jnp.float32), # Total output shape +) +def run_kernel(x_ref, y_ref): + # x_ref and y_ref are in SMEM! + y_ref[...] = x_ref[...] + 1 + +x = jnp.arange(256, dtype=jnp.float32) +y = run_kernel(x) +np.testing.assert_array_equal(y, x + 1) +``` + +We can write a similar kernel using `pl.core_map`. One big difference is that +unlike `pl.pallas_call`, no GMEM<->SMEM copies will be inserted automatically. +If you want them, you can either insert them yourself or use the +{py:func}`plgpu.emit_pipeline ` +helper. We recommend reviewing the [software pipelining guide](./pipelining.md). + +```python +@pl.run_state +def run_kernel(refs): + x_ref, y_ref = refs + # Here, we're not in the kernel yet! pl.run_state simply changes the JAX + # immutable arrays into mutable GMEM (not SMEM!) references. + + # Define the mesh: 2 CUDA blocks over 1 axis called "x" + mesh = plgpu.Mesh(grid=(2,), grid_names=("x",)) + + @pl.core_map(mesh) # core_map executes the body + def kernel_body(): + # Once we enter the pl.core_map scope, we are in the body of the kernel. + block_slice = pl.ds(jax.lax.axis_index("x") * 128, 128) + y_ref[block_slice] = x_ref[block_slice] + 1 + +x = jnp.arange(256, dtype=jnp.float32) +y_init = jnp.zeros_like(x) +_, y = run_kernel((x, y_init)) +np.testing.assert_array_equal(y, x + 1) +``` + +While `pl.core_map` is a powerful API, it is also quite low-level and is pretty +much always used in under `pl.run_state` (to make JAX arrays into refs) or +`pl.run_scoped` (to allocate for scratch refs). For that reason, we also +provide a convenience API `plgpu.kernel`: + +```python +@functools.partial( + plgpu.kernel, + out_shape=jax.ShapeDtypeStruct((256,), jnp.float32), + grid=(2,), + grid_names=("x",), +) +def run_kernel(x_ref, y_ref): + # x_ref and y_ref are in GMEM! + block_slice = pl.ds(jax.lax.axis_index("x") * 128, 128) + y_ref[block_slice] = x_ref[block_slice] + 1 + +x = jnp.arange(256, dtype=jnp.float32) +y = run_kernel(x) # No need to preallocate outputs as in pl.core_map. +np.testing.assert_array_equal(y, x + 1) +``` + +```{note} +The `plgpu.Mesh` used with `pl.core_map` defines a topology for computation +*within a single GPU*, specifying how work is distributed across CUDA blocks +(the `grid`), Pallas threads within a block (`num_threads`), and potentially +CUDA block clusters (`cluster`). This is analogous to how `jax.sharding.Mesh` +defines a topology for distributed computation *across multiple devices* in JAX. +Both involve SPMD programs executing across the defined topology. Furthermore, +you can run "collectives" over the Pallas threads and cluster (e.g., using +`plgpu.ClusterBarrier` or collective async copies), similar to how JAX +collectives (`psum`, `all_gather`, etc.) operate across devices in a JAX `Mesh`. +Both also use named axes, and `jax.lax.axis_index(axis_name)` can be used to get +a thread's or block's coordinate. +``` + +### Using multiple Pallas threads per CUDA block + +Below, you can find an example of two Pallas threads within a single block +synchronizing through a barrier and even exchanging data through SMEM. + +```python +x = jnp.arange(128, dtype=jnp.float32) + +@functools.partial( + plgpu.kernel, + out_shape=x, + scratch_shapes=dict( + smem_ref=plgpu.SMEM(x.shape, x.dtype), + barrier_ref=plgpu.Barrier(), + ), + num_threads=2, + thread_name="pallas_thread", +) +def run_kernel(x_ref, y_ref, smem_ref, barrier_ref): + thread_id = jax.lax.axis_index("pallas_thread") + + @pl.when(thread_id == 0) + def producer_thread(): + smem_ref[...] = x_ref[...] + 1 + plgpu.barrier_arrive(barrier_ref) # Signal the consumer thread + + @pl.when(thread_id == 1) + def consumer_thread(): + plgpu.barrier_wait(barrier_ref) # Wait for the producer thread + out_ref[...] = smem_ref[...] + 1 + +y = run_kernel(x) # There's no need to preallocate the input anymore. +np.testing.assert_array_equal(y, x + 2) +``` + +While this example is simple, you can find a more complicated example in the +[synchronization section](#cross-thread-synchronization). + +Multiple threads are frequently used in high-performance kernels such as the +latest flash attention variants or ping-pong matrix multiplication. In both of +those, there are 2 compute threads in the program that use the SM's ALU +and TensorCore in an alternating fashion to ensure no execution conflicts. + +Another common technique is to allocate one Pallas thread and devote it entirely +to scheduling asynchronous copies for data consumed by other threads. While +implementing this scheme from scratch can be complicated, we provide a +convenient helper API: `plgpu.emit_pipeline_warp_specialized`. + +### Using CUDA block clusters + +The kernel below launches a single cluster of 2 CUDA blocks and uses the TMA +multicast feature to collectively perform a copy of GMEM into SMEM of both +blocks. All blocks participating in the collective copy must schedule the exact +same copy for the program to be valid. + +```python +@functools.partial( + plgpu.kernel, + out_shape=jax.ShapeDtypeStruct((2, 128), jnp.float32), + scratch_shapes=dict( + smem_ref=plgpu.SMEM((128,), jnp.float32), + barrier_ref=plgpu.Barrier(), + ), + cluster=(2,), + cluster_names=("cluster",), +) +def run_kernel(x_ref, y_ref, smem_ref, barrier_ref): + # Specifying collective_axes will enable TMA multicast automatically. + plgpu.copy_gmem_to_smem(x_ref, smem_ref, barrier_ref, collective_axes="cluster") + plgpu.barrier_wait(barrier_ref) + plgpu.copy_smem_to_gmem(smem_ref, o_ref.at[jax.lax.axis_index("cluster")]) + plgpu.wait_smem_to_gmem(0) + +x = jnp.arange(128, dtype=jnp.float32) +y = run_kernel(x) +# Each block gets the same data and writes it out. +np.testing.assert_array_equal(y, jnp.stack([x, x], axis=0)) +``` + +### Collective allocations in `pl.run_scoped` + +When using `pl.core_map` with multiple Pallas threads (i.e., `num_threads > 1` +in `plgpu.Mesh`), allocations made via `pl.run_scoped` (for SMEM or Barriers) +must be performed _collectively by all threads_. This is indicated by specifying +a `collective_axis` argument to the `run_scoped`, which has two effects: +1. it promises that all threads will call the same allocation, and +2. all threads will receive the exact same allocation. + +If collective_axes is not specified or does not include the Pallas thread axis, +each thread would get its own private copy of the scratch variable. This is +usually undesired and not supported at the moment. + +### Global (grid-wide) allocations using `pl.get_global` + +Sometimes, it is useful to allocate [semaphores](#semaphore) in a way that enables them to be +shared by all the parallel program instances. For example, when the number of +parallel instances is small enough that the kernel is persistent. Such allocations +are possible using `pl.get_global`: + +```python +def body(out_ref): + sem_ref = pl.get_global(plgpu.SemaphoreType.REGULAR) + block_id = lax.axis_index("x") + @pl.when(block_id == 0) + def _(): + pl.semaphore_signal(sem_ref) # Block 0 signals + @pl.when(block_id == 1) + def _(): + pl.semaphore_wait(sem_ref) # Block 1 waits + out_ref[...] = jnp.ones_like(out_ref) + +out_shape = jax.ShapeDtypeStruct((128,), jnp.float32) +plgpu.kernel(body, out_shape=out_shape, grid=(2,), grid_names=("x",))() +``` + +## Synchronization structures and primitives + +In this section, we go over the most important functions and data structures +used for synchronization between threads and also some asynchronous operations. + +(commit-smem)= +### `commit_smem` + +Regular reads/writes to references are guaranteed to produce values consistent +with the sequential program order. For example, in the following program, it is +guaranteed that `value` is equal to `value2`. +```python +ref[...] = value +value2 = ref[...] +``` + +This guarantee, however, does not extend to asynchronous primitives such as async +copies or MMA operations. To make the SMEM writes visible to those primitives, you +are required to explicitly synchronize with them using the `plgpu.commit_smem()` function. + +For example: +```python +smem_ref[...] = value +plgpu.commit_smem() +plgpu.copy_smem_to_gmem(smem_ref, ...) +``` +or: +```python +smem_ref[...] = value +plgpu.commit_smem() +plgpu.wgmma(smem_ref, ...) +``` + +This explicit synchronization is also required in the other direction, for +example: +```python +v = plgpu.load(smem_ref, ()) +plgpu.commit_smem() +plgpu.copy_gmem_to_smem(..., smem_ref, ...) +``` + +Failing to call this function is likely to cause subtle data races, due to those asynchronous +hardware units reading stale data from SMEM. Unfortunately, this function is relatively expensive, +which is why we rely on you, the user, to insert it in the minimal number of places where it's necessary. + +(barrier)= +### `Barrier` + +This is essentially a thin wrapper around an array of PTX `mbarrier` types and is +passed in as a reference. All functions involving barriers expect to only get a single +barrier argument, and so if the reference contains multiple, you have to extract one +of them explicitly using `barriers.at[index]`. `Barrier`s are always allocated in SMEM +and as such have relatively low overheads. Each barrier can be configured to complete +after a fixed number of "arrivals" (by default 1). + +To block a thread until a barrier completes, use the following function: +```python +plgpu.barrier_wait(barrier) +``` + +```{warning} +It is critical to ensure that the synchronization scheme makes it impossible for two +barrier completions to happen without a call to `plgpu.barrier_wait` in between them. +For example, if you use `Barrier`s to synchronize two producer/consumer threads, you +need to perform barrier synchronization going both ways to introduce "backpressure" +that will stop one thread from arriving twice before the other one had a chance to await. +Failing to satisfy this will corrupt the data structure and can cause surprising failures +(including CUDA runtime errors). See below for an example of a valid program with two threads. +``` + +```{warning} +Another critical restriction is that the number of barrier completions must equal the +number of barrier waits throughout the barrier's lifetime. It is not allowed to end a scoped +allocation of a barrier when it has an unawaited completion. Otherwise, when it is +reused by the compiler, leaving it in this state can cause problems downstream. +``` + +```{warning} +Finally, it is crucial to ensure that each thread that ever waits on a `Barrier` +takes part in all `wait` operations on it. It is not allowed to e.g. await every +other completion of a barrier from one thread, and all other completions from another +one. Doing so will lead to deadlocks. To recap: when a `Barrier` is used to wait in +some thread, it must observe every single completion of that barrier (by waiting on it). + +Note that the `Barrier` can receive arrivals from any source, without restrictions. +``` + +There are three operations that can complete a barrier: + +#### Asynchronous GMEM-to-SMEM copies + +When an asynchronous GMEM-to-SMEM copy is being executed by the TMA engine, it will +post progress updates to the barrier given to `plgpu.copy_gmem_to_smem`. Once the copy +is complete, the barrier will complete one arrival as well. + +(cross-thread-synchronization)= +#### Explicit arrival (cross-thread synchronization) + +Any thread can explicitly arrival on a barrier using the following function: +```python +plgpu.barrier_arrive(barrier) +``` + +This is especially useful when synchronizing two threads that are in producer/consumer +roles. In this case, we recommend allocating two arrays of `Barrier`s, with size equal +to the size of the "queue" used to pass data between the two threads. For example, +assume one thread continues writing tiles of an array to SMEM while another thread +reads them. We triple-buffer the SMEM region to allow more asynchrony between the two +threads: + +```python +tid = jax.lax.axis_index("thread") +assert queue.shape == (buffering, *item_shape) +assert produced.shape == consumed.shape == (buffering,) + +def thread0_body(i, _): + slot = jax.lax.rem(i, buffering) + @pl.when(i >= buffering) + def _await_consumed(): + plgpu.barrier_wait(consumed.at[slot]) # Wait for consumption of the value before overwriting it + # Option 1: Compute the next value + queue[slot] = produce() + plgpu.barrier_arrive(produced.at[slot]) # Signal the value is ready + # Option 2: Produce the value through async_copy + # plgpu.copy_gmem_to_smem(..., queue.at[slot], barrier=produced.at[slot]) +pl.when(tid == 0)(lambda: jax.lax.fori_loop(0, steps, thread0_body, None)) + +def thread1_body(i, _): + slot = jax.lax.rem(i, buffering) + plgpu.barrier_wait(produced.at[slot]) # Wait for the value to be ready + consume(queue[slot]) # Load and compute + plgpu.barrier_arrive(consumed.at[slot]) # Signal that the value is consumed +pl.when(tid == 1)(lambda: jax.lax.fori_loop(0, steps, thread1_body, None)) +``` + +(awaiting-tcgen05-instructions)= +#### Awaiting `tcgen05` TensorCore instructions + +Before we begin, an important warning: + +```{warning} +On Blackwell generation of GPUs, `Barrier` operations by default have relaxed +semantics with respect to the TensorCore operations. This means that by default +any TensorCore-related operation (including TMEM operation) can be moved by the +compiler _after a barrier signal_. Similarly, any TensorCore-related operation +can be moved _before a barrier wait_. + +If you mean to use `Barrier`s to indicate to other threads that a TensorCore +operation is complete, allocate the barrier with `orders_tensor_core=True`. This +argument will insert the necessary instructions to prevent the problematic +reordering mentioned above. +``` + +Unlike in older GPUs, the only way to observe the completion of +Blackwell-generation TensorCore instructions is to pass in a `Barrier` reference +to the {py:func}`plgpu.tcgen05_mma ` +function. Once the MMA is complete, the TensorCore will arrive on the barrier. + +Note that this use of `Barrier`s requires that they are created with +`orders_tensor_core=True`, since they are used to synchronize with TensorCore +operations. + +```python +@functools.partial(pl.run_scoped, barrier_ref=plgpu.Barrier(orders_tensor_core=True)) +def barrier_scope(barrier_ref): + plgpu.tcgen05_mma(acc_tmem, lhs_ref, rhs_ref, barrier_ref, accumulate=False) + plgpu.barrier_wait(barrier_ref) + # We can read the result now + result = plgpu.async_load_tmem(acc_tmem) + ... +``` + +### `ClusterBarrier` + +`ClusterBarrier` is very similar to `Barrier`, only used to synchronize across +block clusters, instead of threads within a single block. This is always +necessary when the blocks in the cluster collaborate on shared resources. +Below we outline some of the more common cases when `ClusterBarrier` is necessary +to ensure correctness. + +#### Reusing SMEM for collective async copies + +In the following example, `ClusterBarrier` ensures that both blocks are done +using `x_smem` before it is overwritten. Without the barrier, one of the blocks +would be able to run ahead and start overwriting `x_smem` by entering the +collective copy before the other block is done reading from it. + +```python +def collective_smem_reuse(x_gmem, x_gmem2, y_gmem, x_smem, local_barrier, cluster_barrier): + plgpu.copy_gmem_to_smem(x_gmem, x_smem, local_barrier, collective_axes="cluster") + plgpu.barrier_wait(local_barrier) # x_smem is ready to be used once the local wait completes + y_gmem[0] = x_smem[...] + plgpu.barrier_arrive(cluster_barrier) + plgpu.barrier_wait(cluster_barrier) # x_smem can only be reused once the cluster barrier completes + plgpu.copy_gmem_to_smem(x_gmem2, x_smem, local_barrier, collective_axes="cluster") + plgpu.barrier_wait(local_barrier) # x_smem is ready to be used once the local wait completes + y_gmem[1] = x_smem[...] +``` + +#### Reusing TMEM for collective MMAs on Blackwell + +This example works very similarly to the one before, only this time TMEM is the +shared resource. One block issues collective MMAs for both of them, but they both +need to safely complete a read from TMEM before it can be reused for another +collective MMA. + +```python +def collective_tmem_reuse(acc_tmem, lhs_ref, rhs_ref, mma_barrier, cluster_barrier): + leader_block = lax.axis_index("cluster") == 0 + @pl.when(leader_block) + def _do_mma(): + plgpu.tcgen05_mma( + acc_tmem, lhs_ref.at[0], rhs_ref.at[0], mma_barrier, + accumulate=False, collective_axis="x", + ) + plgpu.barrier_wait(mma_barrier) + do_something(plgpu.async_load_tmem(acc_tmem)) + plgpu.wait_load_tmem() # Ensure the load is complete. + plgpu.barrier_arrive(cluster_barrier) + plgpu.barrier_wait(cluster_barrier) # acc_tmem can only be reused once the cluster barrier completes + @pl.when(leader_block) + def _do_mma(): + plgpu.tcgen05_mma( + acc_tmem, lhs_ref.at[1], rhs_ref.at[1], mma_barrier, + accumulate=False, collective_axis="x", + ) + ... +``` + +### `Semaphore` + +Semaphores are powerful synchronization structures, primarily used to +synchronize across different blocks, potentially running on different devices. +For synchronization between threads within a single block, it is preferable to +use `Barrier`s, while for cluster synchronization it is preferable to use +`ClusterBarrier`s. Semaphores are implemented as 32-bit atomic counters located in +GMEM that support the following operations: + +* {py:func}`pl.semaphore_signal `, + which atomically increments the semaphore. Any effects performed by the thread + before the signal (including reads or writes to remote memory over NVLINK) are + guaranteed to complete before the signal is visible on the target device. +* {py:func}`pl.semaphore_wait `, which + blocks the thread until the semaphore reaches _at least_ the desired value, at + which point the value is atomically decreased and the thread is awoken. The + function can be optionally called with `decrement=False`, which will wake the + thread as soon as the value is at least the requested value, but the value of + the semaphore will not be decreased. The non-decrementing version is a bit + more efficient. + +Here we present a small example kernel that exchanges two small shards between +two devices: + +```python +def exchange_shards(x_ref, y_ref, done_sem): + other_dev_id = 1 - lax.axis_index("x") # We assume two devices + neighbor_ref = plgpu.remote_ref(y_ref, other_dev_id) + neighbor_ref[...] = x_ref[...] # This will write over NVLINK + pl.semaphore_signal(done_sem, device_id=other_dev_id) # Signal that the write is complete + pl.semaphore_wait(done_sem) # Wait for the other device to write to our memory + +mesh = jax.make_mesh((2,), ("x",)) +y = jax.jit( + jax.shard_map( + lambda x: plgpu.kernel(exchange_shards, out_shape=x, + scratch_shapes=[plgpu.Semaphore.REGULAR])(x), + mesh=mesh, in_specs=P("x"), out_specs=P("x"), check_vma=False, + ) +)(x) +``` + +## Cluster launch control + +[Cluster launch control](https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_cluster_launch_control.html#blackwell-cluster-launch-control) +is a feature introduced in Blackwell GPUs (SM100A+) that enables work stealing +or dynamic scheduling of the CUDA grid. This allows an SM +(or cluster of SMs) that has finished its work to cancel the launch of block +intended for another SM and execute the work for itself. The end result is +that load balancing across SMs is improved and you should see better utilization +of the GPU towards the tail end of a kernel. Mosaic GPU exposes both the +low-level cluster launch control commands as well as a helper API that abstracts +away most of the implementation details. + +### Directly using the cluster launch control API + +Mosaic GPU directly exposes the low-level cluster launch control API as two +functions: {py:func}`plgpu.try_cluster_cancel ` +and {py:func}`plgpu.query_cluster_cancel `. +`try_cluster_cancel` is an asynchronous operation that will atomically attempt +to cancel the launch of an available block, and place the result in a Ref. +The result Ref should be a scratch Ref allocated via +`plgpu.TryClusterCancelResult()` (which under the hood is a 16-byte SMEM Ref). + `query_cluster_cancel` will read the result and return two +values: a tuple containing the indices of the grid axes that were requested, +and a boolean indicating whether the cancellation was successful. If +`query_cluster_cancel` was not successful, then the result of the grid indices +is undefined and should not be used. + +When used with clusters, all blocks within the same cluster will receive the +same result from `query_cluster_cancel`. + +The following example demonstrates how to call these with a kernel: +```python +@functools.partial( + plgpu.kernel, + grid=grid, + grid_names=grid_names, + scratch_shapes=dict( + result_ref=plgpu.TryCancelResultRef(), + barrier_ref=plgpu.Barrier() + ) +) +def kernel(result_ref, barrier_ref): + plgpu.try_cluster_cancel(result_ref, barrier_ref) + # ... do work + plgpu.barrier_wait(barrier_ref) + grid_idxs, success = plgpu.query_cluster_cancel(result_ref, grid_names) +``` +```{warning} +It is important to ensure proper synchronization on all threads throughout the +cluster. In most cases when canceling multiple blocks, you may need to +double-buffer the result and barrier to ensure no race conditions occur. For +this reason we recommend using the {py:func}`plgpu.dynamic_scheduling_loop ` +helper function. +``` + +### Using the `plgpu.dynamic_scheduling_loop` helper + +A common pattern when using dynamic work scheduling is to continuously poll +and execute work within the kernel body until there are no more work left, and +then exit the kernel. The {py:func}`plgpu.dynamic_scheduling_loop ` +helper function implements exactly this pattern. + +```python +@plgpu.dynamic_scheduling_loop( + grid_names=grid_names, + thread_axis=thread_name # Required if using multiple threads in a kernel. +) +def body(loop_info): + grid_indices = loop_info.index + # ... do work +``` + +When using this pattern, the kernel should be instantiated with a grid +equal to the logical amount of work to be done (as opposed to a persistent +kernel where the grid is set to the number of cores). Each core running +this loop will continuously query the next available block of work and +the loop will terminate when the entire grid has been scheduled. +The signature of the body function is identical to the one used in +{py:func}`plgpu.nd_loop ` (which +is used for normal persistent kernels) and takes in a `loop_info` dataclass +that contains iteration info, and optionally supports carry values. + +## Asynchronous copies + +Modern GPUs can directly and asynchronously copy data between GMEM and SMEM without +involving registers. Starting from the Hopper generation, the copies can even +be offloaded to a special hardware unit called the Tensor Memory Accelerator (TMA), +which is what Mosaic uses to implement them. + +### GMEM to SMEM copies + +To schedule an asynchronous GMEM to SMEM copy, use {py:func}`plgpu.copy_gmem_to_smem `. The function takes three operands: a source ref, +a destination ref and a `Barrier`. Once the copy is complete, a single arrival will +be observed on the barrier, as if `plgpu.barrier_arrive(barrier)` was called by a background thread: + +```python +def body(in_gmem_ref, out_gmem_ref, smem_ref, barrier): + plgpu.copy_gmem_to_smem(in_gmem_ref, smem_ref, barrier) + plgpu.barrier_wait(barrier) + ... + +plgpu.kernel( + body, + out_shape=..., + scratch_shapes=[plgpu.SMEM(x.shape, x.dtype), plgpu.Barrier()], +) +``` + +A single barrier can be used to synchronize multiple copies, but it has to be +allocated with a higher `arrival_count`: + +```python +def body(in_gmem_ref, in_gmem_ref2, out_gmem_ref, smem_ref, smem_ref2, barrier): + plgpu.copy_gmem_to_smem(in_gmem_ref, smem_ref, barrier) + plgpu.copy_gmem_to_smem(in_gmem_ref2, smem_ref2, barrier) + plgpu.barrier_wait(barrier) # Awaits both copies + ... + +plgpu.kernel( + body, + out_shape=..., + # Barrier is allocated with 2 arrivals. + scratch_shapes=[plgpu.SMEM(x.shape, x.dtype), plgpu.Barrier(num_arrivals=2)], +) +``` + +#### Collective copies + +When using block clusters, the asynchronous transfers feature a _multicast_ option, +meaning that multiple blocks from the cluster can collectively load the same input. +In some sense, this can be seen as a guaranteed L2 hit for all participating blocks, +as it allows for better sharing of the limited HBM bandwidth. + +```{warning} +When using collective copies, all blocks along the specified cluster axes must +issue the same collective copy for the program to be valid. It is not allowed to +only issue it from one block but not from others and it will result in undefined +behavior (most likely a deadlock). +``` + +```{warning} +When using collective copies, you need to be extra careful about reusing the SMEM +buffers. The different blocks in the cluster might finish using them at different +points in time but the first block that issues the next collective copy can overwrite +the data still used by other blocks. See the [`ClusterBarrier` section](#clusterbarrier) +for examples for how to make this safe. +``` + +```python +def body(in_gmem_ref, in_gmem_ref2, out_gmem_ref, smem_ref, smem_ref2, barrier): + block_id = lax.axis_index("cluster") + # Both blocks in the cluster load the same data into smem_ref, so we can use + # a collective copy here. + plgpu.copy_gmem_to_smem(in_gmem_ref, smem_ref, barrier, collective_axes="cluster") + # Each block in the cluster loads a different slice of in_gmem_ref2, so we + # are not allowed to use collective copies. + plgpu.copy_gmem_to_smem(in_gmem_ref2.at[block_id], smem_ref2, barrier) + plgpu.barrier_wait(barrier) # Awaits both copies + ... + +plgpu.kernel( + body, + out_shape=..., + # Barrier is allocated with 2 arrivals. + scratch_shapes=[plgpu.SMEM(x.shape, x.dtype), plgpu.Barrier(num_arrivals=2)], +) +``` + +#### Collective partitioned copies (Blackwell only) + +In the Blackwell generations, collective copies that involve clusters of two +blocks can be _partitioned_ by passing an additional `partitioned_axis` argument. +When specified, the GMEM reference is expected to be double the size of the +destination SMEM reference along the specified dimension. The destination in the +first block will be overwritten with the first half of the GMEM ref, while the +second block will receive the second half. + +This by itself would be equivalent to performing two non-collective copies on +different input slices, but there's one crucial difference: only the barrier in +the first block will receive the arrival once both copies complete. The barrier +argument in the second block is ignored and the second block cannot use it to +await the completion of the transfer. + +Arguably, this is a bit of a surprising feature, but it makes sense in the +context of collective MMAs on Blackwell. There, each block is responsible for +loading the operands into SMEM, but only the first block awaits the +completion of the transfers and issues the MMA instructions. The second block +usually waits on the completion of the MMA to indicate that the transfer is done, +and the SMEM data has been read out, implying that it can safely overwrite it. + +### SMEM to GMEM copies + +To schedule an asynchronous GMEM to SMEM copy, use {py:func}`plgpu.copy_smem_to_gmem `. As opposed to the other direction, this primitive +only takes in the source and destination references. To await the completion of +the copy, use the {py:func}`plgpu.wait_smem_to_gmem `. + +The synchronization scheme for SMEM to GMEM copies is a little unexpected in that +they cannot be awaited in arbitrary orders. `plgpu.wait_smem_to_gmem` takes as +an argument the number of most recent copies **you do not want to await**, or equivalently +the number of asynchronous SMEM to GMEM copies that you still want to allow +to run: + +```python +def copy_out(x_smem, y_smem, x_gmem, y_gmem): + plgpu.copy_smem_to_gmem(x_smem, x_gmem) + plgpu.copy_smem_to_gmem(y_smem, y_gmem) + plgpu.wait_smem_to_gmem(1, wait_read_only=True) + # At this point we know that the data of x_smem has been read, but we don't + # yet know that x_gmem contains the updated data. + plgpu.wait_smem_to_gmem(1) + # At this point we know that the x_smem -> x_gmem copy is done, but we know + # nothing about the y_smem -> y_gmem copy. + plgpu.wait_smem_to_gmem(0) + # At this point we know that both copies are complete. +``` + +Note that an SMEM to GMEM copy can only ever be awaited in the same thread that +has issued it. `wait_smem_to_gmem` returns immediately if no copies have been +issued or they have all completed. + +#### Only awaiting the read from SMEM + +Another option is that you can either await the copy being committed to GMEM +You can choose to wait until the copy is fully written into GMEM +(in a way that will be visible to following reads), or you can only await the +data being read from SMEM by specifying `wait_read_only` in the wait function. +This allows for a faster reuse of SMEM buffers if you don't intend to read back +the data sent to GMEM just yet. + +#### Grouping multiple copies + +When `copy_smem_to_gmem` receives `commit_group=False` as an argument, it cannot +be awaited until {py:func}`plgpu.commit_group ` +is called explicitly, or another `copy_smem_to_gmem` without that argument is issued. +All SMEM to GMEM copies since the last commit are grouped together as a single awaitable unit: + +```python +def copy_out(x_smem, y_smem, x_gmem, y_gmem): + plgpu.copy_smem_to_gmem(x_smem, x_gmem, commit_group=False) + plgpu.copy_smem_to_gmem(y_smem, y_gmem) # Implicitly commits both copies + plgpu.wait_smem_to_gmem(1) + # At this point we only know that no SMEM to GMEM copies other than the two + # above are active. + plgpu.wait_smem_to_gmem(0) + # Only now we know that both copies above have completed. +``` + +### Asynchronous gathers + +On Blackwell GPUs, the TMA engine has an additional mode that allows for an efficient +implementation of gathers along the first dimension on a 2D matrix. Using this +mode is actually very simple. The 1D array of indices should be loaded into +a `plgpu.Layout.TMA_GATHER_INDICES` layout, and the source GMEM reference +has to be indexed with that array using the `.at` operator: + +```python +@functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct(out_shape, dtype), + out_specs=plgpu.BlockSpec(memory_space=plgpu.SMEM, transforms=transforms), + in_specs=( + pl.BlockSpec(memory_space=plgpu.GMEM), + pl.BlockSpec(memory_space=plgpu.SMEM), + ), + scratch_shapes=[plgpu.Barrier()], +) +def kernel(x_ref_gmem, idx_ref, o_ref, barrier_ref): + idxs = plgpu.load(idx_ref, (), layout=plgpu.Layout.TMA_GATHER_INDICES) + plgpu.copy_gmem_to_smem(x_ref_gmem.at[idxs], o_ref, barrier_ref) + plgpu.barrier_wait(barrier_ref) +``` + +The `plgpu.copy_gmem_to_smem` automatically recognizes that the reference has +been sliced with an array and will use the gather TMA instructions to implement +the copy. + +### NVLINK transfers + +Asynchronous copies in either direction support GMEM references returned from +`plgpu.peer_ref`, which makes it possible to perform NVLINK transfers asynchronously. + +```python +def exchange_shards(x_ref, y_ref, smem_ref, local_barrier, done_sem): + plgpu.copy_gmem_to_smem(x_ref, smem_ref, local_barrier) # Local copy + plgpu.barrier_wait(local_barrier) + other_dev_id = 1 - lax.axis_index("x") # We assume two devices + neighbor_ref = plgpu.remote_ref(y_ref, other_dev_id) + plgpu.copy_smem_to_gmem(smem_ref, neighbor_ref) + plgpu.wait_smem_to_gmem(0) # Wait for the asynchronous write to complete + pl.semaphore_signal(done_sem, device_id=other_dev_id) # Signal that the write is complete + pl.semaphore_wait(done_sem) # Wait for the other device to write to our memory + +mesh = jax.make_mesh((2,), ("x",)) +y = jax.jit( + jax.shard_map( + lambda x: plgpu.kernel( + exchange_shards, + out_shape=x, + scratch_shapes=[x, plgpu.Barrier(), plgpu.Semaphore.REGULAR] + )(x), + mesh=mesh, in_specs=P("x"), out_specs=P("x"), check_vma=False, + ) +)(x) +``` + +## Inline Mosaic GPU + +TODO + +## Compiler parameters + +TODO + +## Debugging + +Mosaic GPU exposes a number of environment variables to diagnose issues with the +generated low-level code: + +* `MOSAIC_GPU_DUMP_PTXAS` allows dumping the compilation logs from `ptxas` to + standard output when set; +* `MOSAIC_GPU_DUMP_PTX` allows dumping the PTX code generated during compilation + to standard output when set; +* `MOSAIC_GPU_DUMP_MLIR_PASSES` allows dumping the IR after every MLIR pass + in the compilation pipeline to standard output; +* `MOSAIC_GPU_DUMP_SASS` allows dumping the SASS code produced at the end of + compilation to standard output; +* `MOSAIC_GPU_DUMP_SASS_CTRL` allows dumping the SASS control codes following + [NervanaSystems/maxas](https://github.com/NervanaSystems/maxas) to standard + output; +* `MOSAIC_GPU_DUMP_TO` allows specifying a directory path (that must exist) + where all of the above will be dumped as files. +* `MOSAIC_GPU_LLVM_DEBUG_ONLY` allows specifying a comma-separated list of + [LLVM debug types](https://llvm.org/docs/ProgrammersManual.html#fine-grained-debug-info-with-debug-type-and-the-debug-only-option), + in order to produce relevant LLVM debugging logs. This environment variable is + only available in debug builds (i.e. builds without `NDEBUG`). +* `MOSAIC_GPU_DUMP_LLVM` allows dumping LLVM IR when set. It is equivalent to + setting `MOSAIC_GPU_LLVM_DEBUG_ONLY=serialize-to-llvm`, and both environment + variables compose. Like `MOSAIC_GPU_LLVM_DEBUG_ONLY`, this environment + variable is only available in debug builds. + +## Calling kernels from PyTorch + +The {py:func}`plgpu.as_torch_kernel ` +decorator wraps a Pallas:MGPU kernel to allow invoking it with PyTorch tensors. +It accepts CUDA tensors as inputs and returns newly allocated CUDA tensors +on the same device. + +Example: + +```python +import functools +import jax +import jax.numpy as jnp +import torch + +@functools.partial( + pl.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.int32) +) +def add_kernel(x_ref, y_ref, o_ref): + o_ref[...] = x_ref[...] + y_ref[...] + +x = torch.arange(128, dtype=torch.int32, device="cuda") +y = x * x +out = plgpu.as_torch_kernel(add_kernel)(x, y) +``` + +`plgpu.as_torch_kernel` only supports functions that contain a single kernel +invocation (e.g. via `pl.pallas_call` or `plgpu.kernel`), and no calls to +other JAX operations, e.g. from {mod}`jax.numpy`. diff --git a/docs/pallas/grid_blockspec.md b/docs/pallas/grid_blockspec.md index ea1df15f2fd4..d360e3e660b5 100644 --- a/docs/pallas/grid_blockspec.md +++ b/docs/pallas/grid_blockspec.md @@ -80,8 +80,14 @@ Not all block shapes are supported. must be equal to the array dimension, or be divisible by `128 * (32 / bitwidth(dtype))`. - * On GPU, the size of the blocks themselves is not restricted, but each - operation must operate on arrays whose size is a power of 2. + * On GPU, when using the Mosaic GPU backend, the size of the blocks is + unrestricted. However, due to hardware limitations, the size of the minormost + array dimension must by such that it is a multiple of 16 bytes. For example, + it must be a multiple of 8 if the input is `jnp.float16`. + + * On GPU, when using the Triton backend, the size of the blocks themselves is + unrestricted, but each operation (including a load or store) must operate + on arrays whose size is a power of 2. ``` If the block shape does not divide evenly the overall shape then the @@ -151,8 +157,7 @@ over the second axis: ```python >>> def show_program_ids(x_shape, block_shape, grid, -... index_map=lambda i, j: (i, j), -... indexing_mode=pl.Blocked()): +... index_map=lambda i, j: (i, j)): ... def program_ids_kernel(o_ref): # Fill the output block with 10*program_id(1) + program_id(0) ... axes = 0 ... for axis in range(len(grid)): @@ -162,7 +167,7 @@ over the second axis: ... out_shape=jax.ShapeDtypeStruct(x_shape, dtype=np.int32), ... grid=grid, ... in_specs=[], -... out_specs=pl.BlockSpec(block_shape, index_map, indexing_mode=indexing_mode), +... out_specs=pl.BlockSpec(block_shape, index_map), ... interpret=True)() ... print(res) @@ -227,7 +232,8 @@ See {ref}`pallas_tpu_noteworthy_properties`. A `None` value appearing as a dimension value in the `block_shape` behaves as the value `1`, except that the corresponding -block axis is squeezed. In the example below, observe that the +block axis is squeezed (you could also pass in `pl.Squeezed()` instead of +`None`). In the example below, observe that the shape of the `o_ref` is (2,) when the block shape was specified as `(None, 2)` (the leading dimension was squeezed). @@ -269,27 +275,33 @@ used: `index_map=lambda *invocation_indices: (0,) * len(block_shape)`. ``` -### The "unblocked" indexing mode +### The "element" indexing mode -The behavior documented above applies to the `indexing_mode=pl.Blocked()`. -When using the `pl.Unblocked` indexing mode the values returned by the +The behavior documented above applies to the default "blocked" indexing mode. +When integers are used in the `block_shape` tuple e.g. `(4, 8)`, it is +equivalent to passing in a `pl.Blocked(block_size)` object instead, e.g. +`(pl.Blocked(4), pl.Blocked(8))`. Blocked indexing mode means the indices +returned by `index_map` are *block indices*. We can pass in objects other than +`pl.Blocked` to change the semantics of `index_map`, most notably, +`pl.Element(block_size)`.. +When using the `pl.Element` indexing mode the values returned by the index map function are used directly as the array indices, without first scaling them by the block size. -When using the unblocked mode you can specify virtual padding -of the array as a tuple of low-high paddings for each dimension: the +When using the `pl.Element` mode you can specify virtual padding +of the array as a tuple of low-high paddings for the dimension: the behavior is as if the overall array is padded on input. No guarantees -are made for the padding values in the unblocked mode, similarly to the padding +are made for the padding values in element mode, similarly to the padding values for the blocked indexing mode when the block shape does not divide the overall array shape. -The unblocked mode is currently supported only on TPUs. +The `Element` mode is currently supported only on TPUs. ```python ->>> # unblocked without padding ->>> show_program_ids(x_shape=(8, 6), block_shape=(2, 3), grid=(4, 2), -... index_map=lambda i, j: (2*i, 3*j), -... indexing_mode=pl.Unblocked()) +>>> # element without padding +>>> show_program_ids(x_shape=(8, 6), block_shape=(pl.Element(2), pl.Element(3)), +... grid=(4, 2), +... index_map=lambda i, j: (2*i, 3*j)) [[ 0 0 0 1 1 1] [ 0 0 0 1 1 1] [10 10 10 11 11 11] @@ -299,10 +311,12 @@ The unblocked mode is currently supported only on TPUs. [30 30 30 31 31 31] [30 30 30 31 31 31]] ->>> # unblocked, first pad the array with 1 row and 2 columns. ->>> show_program_ids(x_shape=(7, 7), block_shape=(2, 3), grid=(4, 3), -... index_map=lambda i, j: (2*i, 3*j), -... indexing_mode=pl.Unblocked(((1, 0), (2, 0)))) +>>> # element, first pad the array with 1 row and 2 columns. +>>> show_program_ids(x_shape=(7, 7), +... block_shape=(pl.Element(2, (1, 0)), +... pl.Element(3, (2, 0))), +... grid=(4, 3), +... index_map=lambda i, j: (2*i, 3*j)) [[ 0 1 1 1 2 2 2] [10 11 11 11 12 12 12] [10 11 11 11 12 12 12] diff --git a/docs/pallas/index.rst b/docs/pallas/index.rst index b2e2fca6c82e..b3e3ca327f1d 100644 --- a/docs/pallas/index.rst +++ b/docs/pallas/index.rst @@ -22,15 +22,28 @@ See also the :class:`jax.experimental.pallas` module API documentation. :maxdepth: 2 quickstart + pipelining grid_blockspec .. toctree:: - :caption: Platform Features + :caption: TPU backend guide :maxdepth: 2 tpu/index +.. toctree:: + :caption: Mosaic GPU backend guide + :maxdepth: 2 + + gpu/index + +.. toctree:: + :caption: Instruction Reference + :maxdepth: 2 + + Instruction Reference <../jax.experimental.pallas> + .. toctree:: :caption: Design Notes :maxdepth: 2 diff --git a/docs/pallas/pipelining.ipynb b/docs/pallas/pipelining.ipynb new file mode 100644 index 000000000000..e2fb97a42cd9 --- /dev/null +++ b/docs/pallas/pipelining.ipynb @@ -0,0 +1,870 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "C93Xlf0DRW9H" + }, + "source": [ + "\n", + "(pallas_software_pipelining)=\n", + "\n", + "# Software Pipelining\n", + "\n", + "Software pipelining is an important technique in performance optimization by overlapping multiple asynchronous operations even if there are data dependencies between them. In the context of kernel writing, the most common form of pipelining involves overlapping communication and memory transfers with compute such that the hardware accelerator never stalls while waiting for data to arrive. Therefore, we will solely focus on the problem of communication-compute pipelining in this tutorial. We will begin by covering the problem conceptually, outlining the Pallas API for writing pipelines, and going over some realistic examples using the API.\n", + "\n", + "This tutorial only covers the conceptual foundations of pipelining. For platform-specific references, please see {ref}`pallas_tpu_pipelining`, or {ref}`pallas_mgpu_pipelining`.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "YkOjspo5BKPD" + }, + "outputs": [], + "source": [ + "import jax\n", + "from jax import numpy as jnp\n", + "from jax.experimental import pallas as pl\n", + "import numpy as np" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "shnVghWUSvpx" + }, + "source": [ + "## Memory Hierarchies\n", + "\n", + "The first step in understanding pipelining conceptually involves understanding the different forms of memory available and the tradeoffs between them. Most hardware architectures (including CPUs, GPUs, and TPUs) utilize a wide variety of memory spaces that tradeoff capacity vs latency/bandwidth. For the purpose of Pallas, we are typically interested in registers, SRAM, DRAM, and potentially network communication:\n", + "- **Registers** are the the memory physically closest to the processor, and typically values must be loaded directly into registers before doing any compute on them.\n", + "- **SRAM** (also known as Shared Memory/L1 and L2 cache on GPUs, or VMEM on TPUs) also lives fairly close to the processor, but has larger capacity than registers.\n", + "SRAM on modern ML accelerators typically range in the 10-100MB range (TPU v5p contains 96MB of VMEM, and H100 GPUs contain ~30MB of L1 cache and 50MB of L2).\n", + "It's reasonable to expect the latency to access SRAM to be on the order of 10x longer than accessing a register.\n", + "- **DRAM** (also known as HBM) has much higher capacity than SRAM, typically in the 10-100GB range for modern ML accelerators. However, the latency is roughly on the order of 10x longer to access compared to SRAM.\n", + "- **Network** communication becomes crucial for larger workloads when the size of DRAM on a single device becomes insufficient or when we'd like to take advantage of parallel computations. We do not cover distributed pipelining in this tutorial, but see the [distributed TPU kernels](https://docs.jax.dev/en/latest/pallas/tpu/distributed.html) guide for writing pipelines across multiple devices.\n", + "\n", + "\n", + "\n", + "\n", + "![memory_hierarchy](../_static/pallas/pipelining_mem_hierarchy.svg)\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WvW6Lo7d2jfb" + }, + "source": [ + "\n", + "In order to perform computation on values X and Y that live in HBM, we need to:\n", + "\n", + "1. Copy the values x and y into SRAM.\n", + "2. Load the values from SRAM into registers.\n", + "3. Execute the computation and store the result into registers.\n", + "4. Store the values in the output registers into SRAM.\n", + "5. Copy the output values in SRAM back to HBM.\n", + "\n", + "Let’s implement a Pallas function that does just that!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 108, + "status": "ok", + "timestamp": 1744764235906, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "IrPhDFnT3Nvw", + "outputId": "8bc03872-fd9f-4610-9d53-d4b46be560f4" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([[2., 2., 2., ..., 2., 2., 2.],\n", + " [2., 2., 2., ..., 2., 2., 2.],\n", + " [2., 2., 2., ..., 2., 2., 2.],\n", + " ...,\n", + " [2., 2., 2., ..., 2., 2., 2.],\n", + " [2., 2., 2., ..., 2., 2., 2.],\n", + " [2., 2., 2., ..., 2., 2., 2.]], dtype=float32)" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Note: This is a TPU example.\n", + "\n", + "def add_matrices_kernel(x_sram_ref, y_sram_ref, z_sram_ref):\n", + " # Load x and y from SRAM into registers\n", + " x_regs = x_sram_ref[:, :]\n", + " y_regs = y_sram_ref[:, :]\n", + " # Execute a vectorized add\n", + " z_regs = x_regs + y_regs\n", + " # Store the output values in registers back into SRAM\n", + " z_sram_ref[:, :] = z_regs\n", + "\n", + "\n", + "def add_matrices(x: jax.Array, y: jax.Array) -> jax.Array:\n", + " # pallas_call will first allocate scratch buffers for `x` and `y` in SRAM.\n", + " # It will then copy `x` and `y` from HBM into SRAM.\n", + " z = pl.pallas_call(\n", + " add_matrices_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)\n", + " )(x, y)\n", + " # pallas_call will also copy the output from SRAM back into HBM.\n", + " return z\n", + "\n", + "\n", + "x, y = jnp.ones((512, 512)), jnp.ones((512, 512))\n", + "add_matrices(x, y)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gGjtwv9u3UNK" + }, + "source": [ + "We've written two functions: `add_matrices_kernel` and `add_matrices`.\n", + "\n", + "`add_matrices_kernel` operates using `Refs` that live in SRAM. Loading from a SRAM Ref produces a value that lives in registers. Values in registers behave like jax.Arrays in that we can use `jnp` and `jax.lax` operations on them to produce new values that live in registers. When we produce the values we'd like to return, we store them in the output SRAM `Ref`.\n", + "\n", + "The `add_matrices` function acts on `jax.Array`s and returns a `jax.Array`. Inside it, we pass `x` and `y` into pallas_call. `pallas_call` is responsible for copying `x` and `y` into SRAM and for allocating the SRAM buffers that the kernel operates on (including allocating `z_vmem_ref`, the output SRAM buffer). After the kernel function is finished running, `pallas_call` will also copy the value in `z_vmem_ref` to HBM, resulting in an output `jax.Array`.\n", + "\n", + "Pallas exposes access to lower level memory spaces like SRAM but writing performant kernels requires more care in utilizing the various memory spaces. For example, we need to consider both:\n", + "\n", + "- **Memory capacity**. SRAM is small! If our arrays are too big, the above kernel would not work because we cannot fit the input into SRAM. For reference, an `f32[2048, 2048]` array is 16MiB, so our above kernel won't scale beyond moderately sized arrays.\n", + "\n", + "- **Memory bandwidth**. Copying to/from HBM and SRAM takes a long time, at least compared to most compute instructions. The `add_matrices` function above will likely spend more time copying between HBM and SRAM than actually performing the addition itself.\n", + "\n", + "With these two constraints in mind, we'll have to rethink our strategy for getting performance out of our accelerators.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0Ebs2pCDgsEW" + }, + "source": [ + "## Pipelining Basics\n", + "\n", + "\n", + "How can we take advantage of the strengths of each form of type memory in the hierarchy, and be able to operate on large arrays stored in HBM while still utilizing fast SRAM for compute? Pipelining is a very general programming pattern which will allow us to do exactly this, but it requires transforming your problem into smaller sub-problems that can be overlapped in parallel.\n", + "\n", + "The first step in pipelining is to divide our problem into smaller subproblems that can fit inside of SRAM. For example, an elementwise operation is can be trivially transformed by operating on one slice of the source array at a time, which results in the following 3 steps (also known as stages): \n", + "\n", + "1. **copy_in**: Copy a slice `A[i]` from HBM to SRAM `X`.\n", + "2. **compute**: Load `X` into registers, compute a result, and store in SRAM `Y`\n", + "3. **copy_out**: Copy result `Y` back into HBM `A[i]`.\n", + "\n", + "Note that there is a data-dependence between steps 1-3, and we cannot trivially overlap them since we need step (1) to complete before starting step (2), and so on. However, there is no data dependence across multiple invocations of the subproblem - that is, we can execute step (1) for block `A[i+1]` while executing step (2) for block `A[i]` and step (3) for block `A[i-1]`.\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8vCtShhBjzTd" + }, + "source": [ + "\n", + "![pipelining_example](../_static/pallas/pipelining_example.svg)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Qs3F--kwiOJm" + }, + "source": [ + "The diagram above depicts how an idealized pipelined program can be scheduled across time. The key insight is that in the majority of the kernel, the copy operations are executed in parallel with compute operations, meaning we can ideally \"hide\" the cost of transferring between HBM/SRAM with computation and keep the processor busy with as much uptime as possible.\n", + "\n", + "The initial startup time and final teardown time known as \"bubbles\", where only a subset of the stages are being executed while the pipeline is being \"filled\" or \"drained\". The bulk of the time is spent in the \"steady-state\" phase of the pipeline, where each pipeline stage is being executed in parallel across different iterations of the subproblem. While with more general pipelining approaches the goal is to achieve N-way parallelism (where N is the number of stages), with kernel pipelining we are usually bottlenecked either by memory bandwidth or processing speed. Therefore, our goal with kernel pipelining is typically to achieve full utilization of the FLOPs/s of our processor, meaning that at any point in time there is always a `compute` block active. In the figure above, the compute block is active in 6/8 timeslots, and assuming we are fully utilizing the processor in each compute timeslot, we would have achieved 75% utilization of the processor." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZcSzl4N6pPbG" + }, + "source": [ + "### Deriving a Double-Buffered Pipeline\n", + "\n", + "Now lets look at how we could implement a pipeline in pseudocode. Consider the following elementwise program, where we load values from HBM (`A[i]`) with a `copy_in` instruction, add 1 to the result, and store the result back to HBM with `copy_out`:\n", + "\n", + "
\n",
+    "for i in range(N):\n",
+    "  copy_in(A[i], X)\n",
+    "  Y = X + 1\n",
+    "  copy_out(Y, A[i])\n",
+    "
\n", + "The issue with this approach is that `copy_in` and `copy_out` are typically blocking operations. So we are forced to wait for the copies to finish while the GPU/TPU is idle, then perform compute while the memory is idle. What we would like to do is to \"pre-fetch\" the input value that is required on the next iteration of the loop asynchronously while performing the computation for the current loop, so that compute and memory communication are happening simultaneously.\n", + "\n", + "In order to reason about the code transformation we will make, lets unroll the loop for N=4, and decompose the copy instructions into separate `copy_start` and `copy_wait` operations to be able to express asynchrony:\n", + "
\n",
+    "  # Itr 1\n",
+    "  copy_in_start(A[0], X)\n",
+    "  copy_in_wait(X)\n",
+    "  Y = X + 1\n",
+    "  copy_out_start(Y, A[0])\n",
+    "  copy_out_wait(Y)\n",
+    "\n",
+    "  # Itr 2\n",
+    "  copy_in_start(A[1], X)\n",
+    "  copy_in_wait(X)\n",
+    "  Y = X + 1\n",
+    "  copy_out_start(Y, A[1])\n",
+    "  copy_out_wait(Y)\n",
+    "\n",
+    "  # Itr 3\n",
+    "  copy_in_start(A[2], X)\n",
+    "  copy_in_wait(X)\n",
+    "  Y = X + 1\n",
+    "  copy_out_start(Y, A[2])\n",
+    "  copy_out_wait(Y)\n",
+    "\n",
+    "  # Itr 4\n",
+    "  copy_in_start(A[3], X)\n",
+    "  copy_in_wait(X)\n",
+    "  Y = X + 1\n",
+    "  copy_out_start(Y, A[3])\n",
+    "  copy_out_wait(Y)\n",
+    "
\n", + "\n", + "Once the loop has been unrolled, the pipelining transformation simply involves issuing `copy_start` instructions as early as possible, and `copy_wait` values as late as possible (right before we need the value). However, in the current state of the loop there is a fake data dependency through X - we cannot simultaneously perform an async copy into X while using it for computation or else we may have a race condition. Therefore, we can use a **multiple-buffering** technique where we keep 2 buffers for each input X and each output Y. With 2 buffers, we can push the `copy_in_start` one iteration ahead (with 3 buffers you can push 2 iterations, and so on) and we rewrite our loop as follows:\n", + "
\n",
+    "  # Prologue\n",
+    "  copy_in_start(A[0], X[0])\n",
+    "  \n",
+    "  # Itr 1\n",
+    "  copy_in_start(A[1], X[1])\n",
+    "  copy_in_wait(X[0])\n",
+    "  Y[0] = X[0] + 1\n",
+    "  copy_out_start(Y[0], A[0])\n",
+    "  copy_out_wait(Y[0])\n",
+    "\n",
+    "  # Itr 2 - Steady state\n",
+    "  copy_in_start(A[2], X[0])\n",
+    "  copy_in_wait(X[1])\n",
+    "  Y[1] = X[1] + 1\n",
+    "  copy_out_start(Y[1], A[1])\n",
+    "  copy_out_wait(Y[1])\n",
+    "\n",
+    "  # Itr 3 - Steady state\n",
+    "  copy_in_start(A[3], X[1])\n",
+    "  copy_in_wait(X[0])\n",
+    "  Y[0] = X[0] + 1\n",
+    "  copy_out_start(Y[0], A[2])\n",
+    "  copy_out_wait(Y[0])\n",
+    "\n",
+    "  # Itr 4 - No copy-in\n",
+    "  copy_in_wait(X[1])\n",
+    "  Y[1] = X[1] + 1\n",
+    "  copy_out_start(Y[1], A[3])\n",
+    "  copy_out_wait(Y[1])\n",
+    "
\n", + "\n", + "Next, we can push the `copy_out_wait` as late as possible, right before we need to write into Y on the subsequent loop iteration.\n", + "\n", + "
\n",
+    "  # Prologue\n",
+    "  copy_in_start(A[0], X[0])\n",
+    "  \n",
+    "  # Itr 1\n",
+    "  copy_in_start(A[1], X[1])\n",
+    "  copy_in_wait(X[0])\n",
+    "  Y[0] = X[0] + 1\n",
+    "  copy_out_start(Y[0], A[0])\n",
+    "\n",
+    "  # Itr 2 - Steady state\n",
+    "  copy_in_start(A[2], X[0])\n",
+    "  copy_in_wait(X[1])\n",
+    "  Y[1] = X[1] + 1\n",
+    "  copy_out_start(Y[1], A[1])\n",
+    "  copy_out_wait(Y[0])\n",
+    "\n",
+    "  # Itr 3 - Steady state\n",
+    "  copy_in_start(A[3], X[1])\n",
+    "  copy_in_wait(X[0])\n",
+    "  Y[0] = X[0] + 1\n",
+    "  copy_out_start(Y[0], A[2])\n",
+    "  copy_out_wait(Y[1])\n",
+    "\n",
+    "  # Itr 4 - No copy-in\n",
+    "  copy_in_wait(X[1])\n",
+    "  Y[1] = X[1] + 1\n",
+    "  copy_out_start(Y[1], A[3])\n",
+    "  copy_out_wait(Y[0])\n",
+    "\n",
+    "  # Epilogue\n",
+    "  copy_out_wait(Y[1])\n",
+    "
\n", + "\n", + "Finally, re-rolling our loop back into a for loop, we obtain the following pipelined loop:\n", + "\n", + "```\n", + "# Prologue\n", + "copy_in_start(A[0], X[0])\n", + "\n", + "# Main loop\n", + "for i in range(N):\n", + " cur_slot = i % 2\n", + " next_slot = (i + 1) % 2\n", + "\n", + " if i+1 < N:\n", + " copy_in_start(A[i+1], X[next_slot])\n", + " \n", + " copy_in_wait(X[cur_slot])\n", + " Y[cur_slot] = X[cur_slot] + 1\n", + " copy_out_start(Y[cur_slot], A[i])\n", + "\n", + " if i > 0:\n", + " copy_out_wait(Y[next_slot])\n", + "\n", + "# Epilogue\n", + "copy_out_wait(Y[1])\n", + "```\n", + "\n", + "If we want to generalize this loop to handle a broader set of computations, notice that we essentially need to specify 3 pieces of information to the pipeline:\n", + "\n", + "- The **grid**, or the bounds of the for loop that specifies the number of subproblems to compute. In our example we had a 1-dimensional grid with size `(N,)`.\n", + "- The **kernel**, or the actual computation happening once the inputs have been loaded into SRAM. In our example we performed an elementwise addition `Y = X + 1`.\n", + "- The **data_slices**, which map a subproblem to corresponding slices into the HBM buffer. In our example the data slice was the identity function `lambda i: i`.\n", + "\n", + "By allowing the user to specify these pieces of information we can write a wide variety of programs following this pattern:\n", + "```python\n", + "def double_buffered_pipeline(\n", + " grid: tuple[int, ...],\n", + " kernel: Callable,\n", + " in_slices: Callable,\n", + " out_slices: Callable):\n", + " # Prologue\n", + " copy_in_start(in_hbm[in_slices(0)], in_sram[0])\n", + "\n", + " # Main loop\n", + " grid_size = prod(grid)\n", + " for i in range(grid_size):\n", + " cur_slot = i % 2\n", + " next_slot = (i + 1) % 2\n", + " if (i + 1) < grid_size:\n", + " copy_in_start(in_hbm[in_slices(i+1)], in_sram[next_slot])\n", + " copy_in_wait(in_sram[cur_slot])\n", + "\n", + " kernel(in_sram[cur_slot], out_ram[cur_slot])\n", + "\n", + " copy_out_start(out_sram[cur_slot], out_hbm[out_slices(i)])\n", + " if i > 0:\n", + " copy_out_wait(out_sram[next_slot])\n", + "\n", + " # Epilogue\n", + " last_slot = (grid_size - 1) % 2\n", + " copy_out_wait(out_sram[last_slot])\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ziBuvv8jDgxo" + }, + "source": [ + "Now that we've seen how to manually implement a pipelined loop, let's look into how to use the Pallas API." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "niMr39cPkJ2m" + }, + "source": [ + "## Pallas Pipelining API\n", + "\n", + "Pallas offers a pipelining API that abstracts away the boilerplate of maintaining multiple buffers and overlapping asynchronous communication with computation. The basics of this API are covered in {ref}`pallas_quickstart`, so we will go over the API briefly here for completeness and discuss some sharp edges that arise from the use of pipelining.\n", + "\n", + "\n", + "### Grid\n", + "\n", + "The program **grid** is a tuple of integers specifying the number of subproblems as an array. The structure of the pipeline can be interpreted as a nested for-loop where the bounds of each loop.\n", + "\n", + "```\n", + "# For grid (N, M, K)\n", + "for n in range (N):\n", + " for m in range(M):\n", + " for k in range(K):\n", + " kernel()\n", + "```\n", + "\n", + "The kernel will be invoked a total of `prod(grid)` times. For more details, see [grid and blockspecs](https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#grid-a-k-a-kernels-in-a-loop).\n", + "\n", + "### BlockSpecs\n", + "\n", + "A BlockSpec specifies the size and slice of data copied to the kernel on each subproblem. The basic constructor to `pl.BlockSpec` involves specifying the `block_shape`, the size of a slice of data, and `index_map`, a function that takes in the program ids of the current subproblem and outputs _blocked_ indices into the source buffer. Blocked indices specify which block to copy on each iteration, assuming the source buffer has been carved into blocks of shape as `block_shape`. The `memory_space` argument specifies what memory space to copy the inputs to - be default this will be SRAM.\n", + "\n", + "```python\n", + "pl.BlockSpec(\n", + " block_shape: tuple[int, ...],\n", + " index_map: Callable,\n", + " memory_space: pl.MemorySpace\n", + ")\n", + "```\n", + "There should be one BlockSpec for each input and each output to the kernel. For more details, see [grid and blockspecs](https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#grid-a-k-a-kernels-in-a-loop).\n", + "\n", + "### Kernel\n", + "\n", + "The kernel function specifies what compute to perform on each subproblem. The kernel function should return no outputs, and instead all outputs should be written into the output buffers that are passed into the kernel. All inputs and output buffers are SRAM buffers by default (unless the user has overridden the behavior by specifying a `memory_space` on the corresponding `BlockSpec`).\n", + "\n", + "```python\n", + "def kernel(*input_buffers, *output_buffers):\n", + " # ... perform compute\n", + " # ... store result into output buffers\n", + "```\n", + "\n", + "The index of the current subproblem can be queried inside the kernel using `pl.program_id(grid_axis: int)`.\n", + "\n", + "\n", + "### Pallas Call\n", + "\n", + "The `pl.pallas_call` function is the main entry point to Pallas and performs pipelined execution when a grid and BlockSpecs are supplied. It has the following signature:\n", + "```python\n", + "def pallas_call(\n", + " kernel,\n", + " grid: tuple[int, ...],\n", + " in_specs: Sequence[PyTree[BlockSpec]],\n", + " out_specs: PyTree[BlockSpec],\n", + " out_shape: PyTree[jax.ShapeDtypeStruct],\n", + ") -> Callable:\n", + "```\n", + "`pallas_call` will return a callable function that when invoked with input values, will return outputs of the same shape as `out_shape`.\n", + "\n", + "`in_specs`, `out_specs`, and `out_shape` are PyTrees of their respective element type. The PyTrees for `in_specs` and the input buffers supplied to the kernel should match, and the PyTrees for `out_specs` and `out_shape` should also match.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0mHZ63eAq_8j" + }, + "source": [ + "### Example - Elementwise Kernel revisited\n", + "\n", + "Let's revisit the initial `add_matrices_kernel` from the beginning of the tutorial, except using pipelining. We will add two input arrays of shape `f32[4096, 4096]` that live in HBM. As subproblems, we will carve up the inputs into `block_shape=(512, 512)` blocks and only add two blocks together at a time in the kernel. Because addition is elementwise, each `index_map` is identical and selects out the `i, j`th block on the `i, j`th iteration." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "iqr_qjONAHN9" + }, + "outputs": [], + "source": [ + "# Note: This is a TPU example.\n", + "\n", + "total_shape = (4096, 4096)\n", + "block_shape = (512, 512)\n", + "\n", + "def add_matrices_pipelined_kernel(x_ref, y_ref, o_ref):\n", + " o_ref[...] = x_ref[...] + y_ref[...]\n", + "\n", + "def add_matrices_pipelined(x: jax.Array, y: jax.Array):\n", + " return pl.pallas_call(\n", + " add_matrices_pipelined_kernel,\n", + " grid=tuple(total // block for (total, block) in zip(total_shape, block_shape)),\n", + " in_specs=[\n", + " pl.BlockSpec(block_shape, index_map=lambda i, j: (i, j)),\n", + " pl.BlockSpec(block_shape, index_map=lambda i, j: (i, j))\n", + " ],\n", + " out_specs=pl.BlockSpec(block_shape, index_map=lambda i, j: (i, j)),\n", + " out_shape=jax.ShapeDtypeStruct(total_shape, dtype=jnp.float32),\n", + " )(x, y)\n", + "\n", + "x = jax.random.uniform(jax.random.key(0), total_shape, dtype=jnp.float32)\n", + "y = jax.random.uniform(jax.random.key(1), total_shape, dtype=jnp.float32)\n", + "result = add_matrices_pipelined(x, y)\n", + "np.testing.assert_array_equal(\n", + " result, x + y\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UWHD0_qm6DL7" + }, + "source": [ + "It turns out that with this API, writing a pipelined kernel is not much more lines of code than writing our original naive addition kernel!" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BZ-4U6Cv6cvU" + }, + "source": [ + "### Parameterizing a Kernel\n", + "\n", + "It's common to parameterize the block shapes in our kernel. Block sizes are perhaps the most important parameter to tune when optimizing the performance of Pallas kernels! They give us control over the pipeline (for example, picking smaller blocks adds more iterations to our pipelined loop where each iteration has less work to do). Let's write a a function that does so:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "RZTAiwrZ6srD" + }, + "outputs": [], + "source": [ + "def add_matrices_pipelined_param(\n", + " x: jax.Array, y: jax.Array, *, bm: int = 256, bn: int = 256\n", + ") -> jax.Array:\n", + " m, n = x.shape\n", + " block_spec = pl.BlockSpec((bm, bn), lambda i, j: (i, j))\n", + " return pl.pallas_call(\n", + " add_matrices_kernel,\n", + " out_shape=x,\n", + " in_specs=[block_spec, block_spec],\n", + " out_specs=block_spec,\n", + " grid=(m // bm, n // bn),\n", + " )(x, y)\n", + "\n", + "np.testing.assert_array_equal(\n", + " add_matrices_pipelined_param(x, y, bm=256, bn=256), x + y\n", + ")\n", + "np.testing.assert_array_equal(\n", + " add_matrices_pipelined_param(x, y, bm=128, bn=128), x + y\n", + ")\n", + "np.testing.assert_array_equal(\n", + " add_matrices_pipelined_param(x, y, bm=512, bn=512), x + y\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vO8VkbYj_ral" + }, + "source": [ + "## Sharp edges\n", + "\n", + "While pipelining provides a close approximation to the mental model of simply calling a kernel function in a loop, there are a number of sharp edges that arise from the use of intermediate buffers that are not fully hidden from the user and can result in subtle bugs.\n", + "\n", + "### Buffer Revisiting\n", + "\n", + "In general, a good rule-of-thumb to follow is that **the input buffers passed into the kernel function should be interpreted as read-only, and output buffers are write only**.\n", + "\n", + "Writing to inputs and reading from outputs will in most cases result in incorrectness. This is because the SRAM buffers passed to a kernel only contain copies of the data contained in the underlying HBM buffer. If an input SRAM buffer is updated, the updated results will never be written back out to HBM, and if an output buffer is updated, it's updated value is never read into SRAM. This issue is analogous to staleness issues encountered when using caches in general.\n", + "\n", + "There are two cases where a buffer supports both reads and writes - accumulation (discussed next), and marking a pair of input and output buffers as input-output aliased by passing in the `input_output_aliases` argument to `pallas_call`.\n", + "\n", + "\n", + "### Reductions and accumulation\n", + "\n", + "**Reduction/accumulation should only be performed over the last (innermost) dimensions of the grid, and the buffer should be initialized manually first.**\n", + "\n", + "Reductions are one of the few cases where the pipeline supports both reading and writing to an output buffer, but the reason it works is subtle.\n", + "The Pallas pipeline emitter performs an optimization where if the data slices between two consecutive iterations are the same, the pipeline will not issue a `copy_in`/`copy_out` on that buffer. This means the same SRAM buffer used in a previous iteration will be passed into the kernel again on the following iteration, and thus any writes that were issued to the output buffer will become visible on the next iteration. Once the data slice changes, the final accumulated SRAM buffer will be written out to HBM. This is also why reductions must be performed over the last dimensions of the grid -- we want to finish all of the accumulation while the output buffer is in SRAM in the innermost loop, then write it to HBM and never touch that output block again.\n", + "\n", + "As a concrete example, let's consider performing the following computation for reducing an `(8, 1024, 1024)` array along the first axies into a `(1024, 1024)` array.\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 244, + "status": "ok", + "timestamp": 1744763773938, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "4qz1ET-_f9fJ", + "outputId": "e43067ef-933a-45a5-912a-e224151cfa60" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([[8., 8., 8., ..., 8., 8., 8.],\n", + " [8., 8., 8., ..., 8., 8., 8.],\n", + " [8., 8., 8., ..., 8., 8., 8.],\n", + " ...,\n", + " [8., 8., 8., ..., 8., 8., 8.],\n", + " [8., 8., 8., ..., 8., 8., 8.],\n", + " [8., 8., 8., ..., 8., 8., 8.]], dtype=float32)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x = jnp.ones((8, 1024, 1024))\n", + "jnp.sum(x, axis=0)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "yX762DRrgCOG" + }, + "source": [ + "To do this using `pallas_call`, we could use a grid of size `(8,)` and in each iteration i load `x[i]` into SRAM. Then we could add `x[i]` to an output SRAM buffer. Let's implement this naively first." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 79, + "status": "ok", + "timestamp": 1744763774254, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "ZEi1_vQVf-81", + "outputId": "581744b7-ddc1-4dc1-98ec-03c852772eda" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[65. 65. 65. ... 66. 66. 66.]\n", + " [65. 65. 65. ... 66. 66. 66.]\n", + " [65. 65. 65. ... 66. 66. 66.]\n", + " ...\n", + " [71. 71. 71. ... 72. 72. 72.]\n", + " [71. 71. 71. ... 72. 72. 72.]\n", + " [71. 71. 71. ... 72. 72. 72.]]\n" + ] + } + ], + "source": [ + "# Note: This is a TPU example.\n", + "\n", + "# Warning: this implementation is incorrect!\n", + "def incorrect_sum_kernel(x_ref, o_ref):\n", + " o_ref[...] += x_ref[...]\n", + "\n", + "def incorrect_sum(x: jax.Array,\n", + " block_size: tuple[int, ...] = (256, 256)) -> jax.Array:\n", + " reduction_size, *out_shape = x.shape\n", + " grid = (reduction_size, *(out // blk for out, blk in zip(out_shape, block_size)))\n", + " return pl.pallas_call(\n", + " incorrect_sum_kernel,\n", + " grid=grid,\n", + " # None in `block_shape` means we pick a size of 1 and squeeze it away\n", + " in_specs=[pl.BlockSpec((None, *block_size), lambda i, j, k: (i, j, k))],\n", + " out_specs=pl.BlockSpec(block_size, lambda i, j, k: (j, k)),\n", + " out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),\n", + " )(x)\n", + "\n", + "result = incorrect_sum(x)\n", + "print(result)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MglScPDD9618" + }, + "source": [ + "This result is completely wrong!\n", + "\n", + "There are two errors inside this kernel. First, we are accumulating along the first grid dimension instead of the last grid dimension. Second, `o_ref` initially contains garbage values and thus we need to initialize it to zeros before we begin accumulation.\n", + "\n", + "After fixing these two issues, we obtain the following corrected kernel. In this new kernel, we use `@pl.when` to create a conditional that checks when the program ID is `0` along the reduction axis, indicating we are beginning to accumulate into a new output block. We have also moved the reduction dimension to the last axis of the `grid`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 104, + "status": "ok", + "timestamp": 1744763774523, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "XtgD4nMa9_Bd", + "outputId": "9ef07cdf-9e22-4dc8-c17f-c96172639801" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[8. 8. 8. ... 8. 8. 8.]\n", + " [8. 8. 8. ... 8. 8. 8.]\n", + " [8. 8. 8. ... 8. 8. 8.]\n", + " ...\n", + " [8. 8. 8. ... 8. 8. 8.]\n", + " [8. 8. 8. ... 8. 8. 8.]\n", + " [8. 8. 8. ... 8. 8. 8.]]\n" + ] + } + ], + "source": [ + "# Note: This is a TPU example.\n", + "\n", + "def correct_sum_kernel(x_ref, o_ref):\n", + " @pl.when(pl.program_id(2) == 0)\n", + " def _():\n", + " o_ref[...] = jnp.zeros_like(o_ref)\n", + " o_ref[...] += x_ref[...]\n", + "\n", + "def correct_sum(x: jax.Array,\n", + " block_size: tuple[int, ...] = (256, 256)) -> jax.Array:\n", + " reduction_size, *out_shape = x.shape\n", + " # We moved the reduction to the last axis of the grid.\n", + " grid = (*(out // blk for out, blk in zip(out_shape, block_size)), reduction_size)\n", + " return pl.pallas_call(\n", + " correct_sum_kernel,\n", + " grid=grid,\n", + " # None in `block_shape` means we pick a size of 1 and squeeze it away\n", + " in_specs=[pl.BlockSpec((None, *block_size), lambda i, j, k: (k, i, j))],\n", + " out_specs=pl.BlockSpec(block_size, lambda i, j, k: (i, j)),\n", + " out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),\n", + " )(x)\n", + "\n", + "result = correct_sum(x)\n", + "print(result)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BckuFg6qcnVw" + }, + "source": [ + "\n", + "## Analyzing the performance\n", + "\n", + "What is the performance of a pipelined kernel? This question can vary depending on where the bottleneck is the hardware is. We are typically interested in 3 quantities:\n", + "- **Memory latency** $α$, the minimum latency of a memory transfer.\n", + "- **Memory bandwidth** $β$, the rate in bytes/second that we can transfer from HBM to SRAM.\n", + "- **FLOP/s** $F$, or floating-point-operations per second, the number of calculations per second that the processor can perform.\n", + "\n", + "We refer to a program as **compute-bound** if the processing speed FLOPs/s is the bottleneck, and as **memory-bound** if the bandwidth or latency are the bottleneck. Generally, our goal is to optimize a kernel such that it is compute-bound, meaning we are utilizing all of the available processing power of our hardware.\n", + "\n", + "Suppose we are running a program that requires $X$ bytes of memory transfers per kernel iteration, and runs $Y$ floating-point operations per iteration. The ratio of $X$ to $Y$ varies depending on the type of compute -- for elementwise operations such as addition or multiplication, they will both scale equally. However, for operations such as matrix multiplication, compute scales cubically with the size of the problem while memory scales quadratically.\n", + "\n", + "In a **compute-bound** regime, a pipeline running $N$ iterations would take $(\\alpha + X/\\beta) + N (Y/F)$ seconds, where the first term represents the cost of the initial bubble (multiply by a factor of 2 if there is also a bubble at the end), and the second term represents the total time of the steady-state of the pipeline. Assuming that N is large and there is enough work to produce a long pipeline, the dominating term in the runtime is $F$, the processing speed of the accelerator.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NDY4mcae_nMO" + }, + "source": [ + "\n", + "![pipelining_compute](../_static/pallas/pipelining_compute_bound.svg)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HFWcaAudW4z1" + }, + "source": [ + "In a **memory-bound** regime it is useful to identify if the problem is the latency versus the bandwidth. If the bandwidth is the bottleneck, then the total runtime would take $\\alpha + N(X / \\beta)$ seconds. In contrast with a latency-bound regime, the memory copies happen serially because the bandwidth is already saturated. Being memory-bound is generally not ideal as there will be gaps in time where the processor is idle, and in most hardware configurations the memory bandwidth $\\beta$ is orders of magnitude slower than the processing speed $F$." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gqcCDsGg_sca" + }, + "source": [ + "\n", + "![pipelining_bandwidth](../_static/pallas/pipelining_bandwidth_bound.svg)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "V4YQCZf1W7X5" + }, + "source": [ + "If the bottleneck is specifically the latency and not the bandwidth, it is possible to fix the problem by inserting additional pipeline stages at the cost of additional SRAM required to store more buffers. With sufficient stages, the problem will either become compute or bandwidth bound again depending on which bottleneck we hit first during the steady-stage stage of the pipeline. The downside, however, of a multi-stage pipeline is that the size of the bubble is proportional to the number of stages so it is important to make sure the pipeline is long enough such that the bubble does not take up a substantial amount of the total runtime.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Sj5PFl0s_yc6" + }, + "source": [ + "\n", + "![pipelining_latency](../_static/pallas/pipelining_latency_multistage.svg)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ar4NVxxFfKEb" + }, + "source": [ + "Pallas on TPU only supports double-buffering, as TPU programs can operate on larger block sizes and double-buffering is typically enough to cover the latency. On GPU, the number of pipeline stages can be specified in both the Triton (via `CompilerParams`) and Mosaic GPU backends (via argument to the pipeline emitter). See the platform-specific pipelining documentation for more details." + ] + } + ], + "metadata": { + "colab": { + "last_runtime": { + "build_target": "//experimental/users/justinfu/pallas:colab", + "kind": "private" + }, + "provenance": [] + }, + "jupytext": { + "formats": "ipynb,md", + "main_language": "python" + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/docs/pallas/pipelining.md b/docs/pallas/pipelining.md new file mode 100644 index 000000000000..6fbc9caa985e --- /dev/null +++ b/docs/pallas/pipelining.md @@ -0,0 +1,600 @@ +--- +jupyter: + jupytext: + formats: ipynb,md + main_language: python + text_representation: + extension: .md + format_name: markdown + format_version: '1.3' + jupytext_version: 1.16.4 + kernelspec: + display_name: Python 3 + name: python3 +--- + + + +(pallas_software_pipelining)= + +# Software Pipelining + +Software pipelining is an important technique in performance optimization by overlapping multiple asynchronous operations even if there are data dependencies between them. In the context of kernel writing, the most common form of pipelining involves overlapping communication and memory transfers with compute such that the hardware accelerator never stalls while waiting for data to arrive. Therefore, we will solely focus on the problem of communication-compute pipelining in this tutorial. We will begin by covering the problem conceptually, outlining the Pallas API for writing pipelines, and going over some realistic examples using the API. + +This tutorial only covers the conceptual foundations of pipelining. For platform-specific references, please see {ref}`pallas_tpu_pipelining`, or {ref}`pallas_mgpu_pipelining`. + + + +```python id="YkOjspo5BKPD" +import jax +from jax import numpy as jnp +from jax.experimental import pallas as pl +import numpy as np +``` + + +## Memory Hierarchies + +The first step in understanding pipelining conceptually involves understanding the different forms of memory available and the tradeoffs between them. Most hardware architectures (including CPUs, GPUs, and TPUs) utilize a wide variety of memory spaces that tradeoff capacity vs latency/bandwidth. For the purpose of Pallas, we are typically interested in registers, SRAM, DRAM, and potentially network communication: +- **Registers** are the the memory physically closest to the processor, and typically values must be loaded directly into registers before doing any compute on them. +- **SRAM** (also known as Shared Memory/L1 and L2 cache on GPUs, or VMEM on TPUs) also lives fairly close to the processor, but has larger capacity than registers. +SRAM on modern ML accelerators typically range in the 10-100MB range (TPU v5p contains 96MB of VMEM, and H100 GPUs contain ~30MB of L1 cache and 50MB of L2). +It's reasonable to expect the latency to access SRAM to be on the order of 10x longer than accessing a register. +- **DRAM** (also known as HBM) has much higher capacity than SRAM, typically in the 10-100GB range for modern ML accelerators. However, the latency is roughly on the order of 10x longer to access compared to SRAM. +- **Network** communication becomes crucial for larger workloads when the size of DRAM on a single device becomes insufficient or when we'd like to take advantage of parallel computations. We do not cover distributed pipelining in this tutorial, but see the [distributed TPU kernels](https://docs.jax.dev/en/latest/pallas/tpu/distributed.html) guide for writing pipelines across multiple devices. + + + + +![memory_hierarchy](../_static/pallas/pipelining_mem_hierarchy.svg) + + + + + + +In order to perform computation on values X and Y that live in HBM, we need to: + +1. Copy the values x and y into SRAM. +2. Load the values from SRAM into registers. +3. Execute the computation and store the result into registers. +4. Store the values in the output registers into SRAM. +5. Copy the output values in SRAM back to HBM. + +Let’s implement a Pallas function that does just that! + + +```python executionInfo={"elapsed": 108, "status": "ok", "timestamp": 1744764235906, "user": {"displayName": "Justin Fu", "userId": "17543197034567316452"}, "user_tz": 420} id="IrPhDFnT3Nvw" outputId="8bc03872-fd9f-4610-9d53-d4b46be560f4" +# Note: This is a TPU example. + +def add_matrices_kernel(x_sram_ref, y_sram_ref, z_sram_ref): + # Load x and y from SRAM into registers + x_regs = x_sram_ref[:, :] + y_regs = y_sram_ref[:, :] + # Execute a vectorized add + z_regs = x_regs + y_regs + # Store the output values in registers back into SRAM + z_sram_ref[:, :] = z_regs + + +def add_matrices(x: jax.Array, y: jax.Array) -> jax.Array: + # pallas_call will first allocate scratch buffers for `x` and `y` in SRAM. + # It will then copy `x` and `y` from HBM into SRAM. + z = pl.pallas_call( + add_matrices_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype) + )(x, y) + # pallas_call will also copy the output from SRAM back into HBM. + return z + + +x, y = jnp.ones((512, 512)), jnp.ones((512, 512)) +add_matrices(x, y) +``` + + +We've written two functions: `add_matrices_kernel` and `add_matrices`. + +`add_matrices_kernel` operates using `Refs` that live in SRAM. Loading from a SRAM Ref produces a value that lives in registers. Values in registers behave like jax.Arrays in that we can use `jnp` and `jax.lax` operations on them to produce new values that live in registers. When we produce the values we'd like to return, we store them in the output SRAM `Ref`. + +The `add_matrices` function acts on `jax.Array`s and returns a `jax.Array`. Inside it, we pass `x` and `y` into pallas_call. `pallas_call` is responsible for copying `x` and `y` into SRAM and for allocating the SRAM buffers that the kernel operates on (including allocating `z_vmem_ref`, the output SRAM buffer). After the kernel function is finished running, `pallas_call` will also copy the value in `z_vmem_ref` to HBM, resulting in an output `jax.Array`. + +Pallas exposes access to lower level memory spaces like SRAM but writing performant kernels requires more care in utilizing the various memory spaces. For example, we need to consider both: + +- **Memory capacity**. SRAM is small! If our arrays are too big, the above kernel would not work because we cannot fit the input into SRAM. For reference, an `f32[2048, 2048]` array is 16MiB, so our above kernel won't scale beyond moderately sized arrays. + +- **Memory bandwidth**. Copying to/from HBM and SRAM takes a long time, at least compared to most compute instructions. The `add_matrices` function above will likely spend more time copying between HBM and SRAM than actually performing the addition itself. + +With these two constraints in mind, we'll have to rethink our strategy for getting performance out of our accelerators. + + + + + +## Pipelining Basics + + +How can we take advantage of the strengths of each form of type memory in the hierarchy, and be able to operate on large arrays stored in HBM while still utilizing fast SRAM for compute? Pipelining is a very general programming pattern which will allow us to do exactly this, but it requires transforming your problem into smaller sub-problems that can be overlapped in parallel. + +The first step in pipelining is to divide our problem into smaller subproblems that can fit inside of SRAM. For example, an elementwise operation is can be trivially transformed by operating on one slice of the source array at a time, which results in the following 3 steps (also known as stages): + +1. **copy_in**: Copy a slice `A[i]` from HBM to SRAM `X`. +2. **compute**: Load `X` into registers, compute a result, and store in SRAM `Y` +3. **copy_out**: Copy result `Y` back into HBM `A[i]`. + +Note that there is a data-dependence between steps 1-3, and we cannot trivially overlap them since we need step (1) to complete before starting step (2), and so on. However, there is no data dependence across multiple invocations of the subproblem - that is, we can execute step (1) for block `A[i+1]` while executing step (2) for block `A[i]` and step (3) for block `A[i-1]`. + + + + + + + +![pipelining_example](../_static/pallas/pipelining_example.svg) + + + + +The diagram above depicts how an idealized pipelined program can be scheduled across time. The key insight is that in the majority of the kernel, the copy operations are executed in parallel with compute operations, meaning we can ideally "hide" the cost of transferring between HBM/SRAM with computation and keep the processor busy with as much uptime as possible. + +The initial startup time and final teardown time known as "bubbles", where only a subset of the stages are being executed while the pipeline is being "filled" or "drained". The bulk of the time is spent in the "steady-state" phase of the pipeline, where each pipeline stage is being executed in parallel across different iterations of the subproblem. While with more general pipelining approaches the goal is to achieve N-way parallelism (where N is the number of stages), with kernel pipelining we are usually bottlenecked either by memory bandwidth or processing speed. Therefore, our goal with kernel pipelining is typically to achieve full utilization of the FLOPs/s of our processor, meaning that at any point in time there is always a `compute` block active. In the figure above, the compute block is active in 6/8 timeslots, and assuming we are fully utilizing the processor in each compute timeslot, we would have achieved 75% utilization of the processor. + + + +### Deriving a Double-Buffered Pipeline + +Now lets look at how we could implement a pipeline in pseudocode. Consider the following elementwise program, where we load values from HBM (`A[i]`) with a `copy_in` instruction, add 1 to the result, and store the result back to HBM with `copy_out`: + +
+for i in range(N):
+  copy_in(A[i], X)
+  Y = X + 1
+  copy_out(Y, A[i])
+
+The issue with this approach is that `copy_in` and `copy_out` are typically blocking operations. So we are forced to wait for the copies to finish while the GPU/TPU is idle, then perform compute while the memory is idle. What we would like to do is to "pre-fetch" the input value that is required on the next iteration of the loop asynchronously while performing the computation for the current loop, so that compute and memory communication are happening simultaneously. + +In order to reason about the code transformation we will make, lets unroll the loop for N=4, and decompose the copy instructions into separate `copy_start` and `copy_wait` operations to be able to express asynchrony: +
+  # Itr 1
+  copy_in_start(A[0], X)
+  copy_in_wait(X)
+  Y = X + 1
+  copy_out_start(Y, A[0])
+  copy_out_wait(Y)
+
+  # Itr 2
+  copy_in_start(A[1], X)
+  copy_in_wait(X)
+  Y = X + 1
+  copy_out_start(Y, A[1])
+  copy_out_wait(Y)
+
+  # Itr 3
+  copy_in_start(A[2], X)
+  copy_in_wait(X)
+  Y = X + 1
+  copy_out_start(Y, A[2])
+  copy_out_wait(Y)
+
+  # Itr 4
+  copy_in_start(A[3], X)
+  copy_in_wait(X)
+  Y = X + 1
+  copy_out_start(Y, A[3])
+  copy_out_wait(Y)
+
+ +Once the loop has been unrolled, the pipelining transformation simply involves issuing `copy_start` instructions as early as possible, and `copy_wait` values as late as possible (right before we need the value). However, in the current state of the loop there is a fake data dependency through X - we cannot simultaneously perform an async copy into X while using it for computation or else we may have a race condition. Therefore, we can use a **multiple-buffering** technique where we keep 2 buffers for each input X and each output Y. With 2 buffers, we can push the `copy_in_start` one iteration ahead (with 3 buffers you can push 2 iterations, and so on) and we rewrite our loop as follows: +
+  # Prologue
+  copy_in_start(A[0], X[0])
+  
+  # Itr 1
+  copy_in_start(A[1], X[1])
+  copy_in_wait(X[0])
+  Y[0] = X[0] + 1
+  copy_out_start(Y[0], A[0])
+  copy_out_wait(Y[0])
+
+  # Itr 2 - Steady state
+  copy_in_start(A[2], X[0])
+  copy_in_wait(X[1])
+  Y[1] = X[1] + 1
+  copy_out_start(Y[1], A[1])
+  copy_out_wait(Y[1])
+
+  # Itr 3 - Steady state
+  copy_in_start(A[3], X[1])
+  copy_in_wait(X[0])
+  Y[0] = X[0] + 1
+  copy_out_start(Y[0], A[2])
+  copy_out_wait(Y[0])
+
+  # Itr 4 - No copy-in
+  copy_in_wait(X[1])
+  Y[1] = X[1] + 1
+  copy_out_start(Y[1], A[3])
+  copy_out_wait(Y[1])
+
+ +Next, we can push the `copy_out_wait` as late as possible, right before we need to write into Y on the subsequent loop iteration. + +
+  # Prologue
+  copy_in_start(A[0], X[0])
+  
+  # Itr 1
+  copy_in_start(A[1], X[1])
+  copy_in_wait(X[0])
+  Y[0] = X[0] + 1
+  copy_out_start(Y[0], A[0])
+
+  # Itr 2 - Steady state
+  copy_in_start(A[2], X[0])
+  copy_in_wait(X[1])
+  Y[1] = X[1] + 1
+  copy_out_start(Y[1], A[1])
+  copy_out_wait(Y[0])
+
+  # Itr 3 - Steady state
+  copy_in_start(A[3], X[1])
+  copy_in_wait(X[0])
+  Y[0] = X[0] + 1
+  copy_out_start(Y[0], A[2])
+  copy_out_wait(Y[1])
+
+  # Itr 4 - No copy-in
+  copy_in_wait(X[1])
+  Y[1] = X[1] + 1
+  copy_out_start(Y[1], A[3])
+  copy_out_wait(Y[0])
+
+  # Epilogue
+  copy_out_wait(Y[1])
+
+ +Finally, re-rolling our loop back into a for loop, we obtain the following pipelined loop: + +``` +# Prologue +copy_in_start(A[0], X[0]) + +# Main loop +for i in range(N): + cur_slot = i % 2 + next_slot = (i + 1) % 2 + + if i+1 < N: + copy_in_start(A[i+1], X[next_slot]) + + copy_in_wait(X[cur_slot]) + Y[cur_slot] = X[cur_slot] + 1 + copy_out_start(Y[cur_slot], A[i]) + + if i > 0: + copy_out_wait(Y[next_slot]) + +# Epilogue +copy_out_wait(Y[1]) +``` + +If we want to generalize this loop to handle a broader set of computations, notice that we essentially need to specify 3 pieces of information to the pipeline: + +- The **grid**, or the bounds of the for loop that specifies the number of subproblems to compute. In our example we had a 1-dimensional grid with size `(N,)`. +- The **kernel**, or the actual computation happening once the inputs have been loaded into SRAM. In our example we performed an elementwise addition `Y = X + 1`. +- The **data_slices**, which map a subproblem to corresponding slices into the HBM buffer. In our example the data slice was the identity function `lambda i: i`. + +By allowing the user to specify these pieces of information we can write a wide variety of programs following this pattern: +```python +def double_buffered_pipeline( + grid: tuple[int, ...], + kernel: Callable, + in_slices: Callable, + out_slices: Callable): + # Prologue + copy_in_start(in_hbm[in_slices(0)], in_sram[0]) + + # Main loop + grid_size = prod(grid) + for i in range(grid_size): + cur_slot = i % 2 + next_slot = (i + 1) % 2 + if (i + 1) < grid_size: + copy_in_start(in_hbm[in_slices(i+1)], in_sram[next_slot]) + copy_in_wait(in_sram[cur_slot]) + + kernel(in_sram[cur_slot], out_ram[cur_slot]) + + copy_out_start(out_sram[cur_slot], out_hbm[out_slices(i)]) + if i > 0: + copy_out_wait(out_sram[next_slot]) + + # Epilogue + last_slot = (grid_size - 1) % 2 + copy_out_wait(out_sram[last_slot]) +``` + + + +Now that we've seen how to manually implement a pipelined loop, let's look into how to use the Pallas API. + + + +## Pallas Pipelining API + +Pallas offers a pipelining API that abstracts away the boilerplate of maintaining multiple buffers and overlapping asynchronous communication with computation. The basics of this API are covered in {ref}`pallas_quickstart`, so we will go over the API briefly here for completeness and discuss some sharp edges that arise from the use of pipelining. + + +### Grid + +The program **grid** is a tuple of integers specifying the number of subproblems as an array. The structure of the pipeline can be interpreted as a nested for-loop where the bounds of each loop. + +``` +# For grid (N, M, K) +for n in range (N): + for m in range(M): + for k in range(K): + kernel() +``` + +The kernel will be invoked a total of `prod(grid)` times. For more details, see [grid and blockspecs](https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#grid-a-k-a-kernels-in-a-loop). + +### BlockSpecs + +A BlockSpec specifies the size and slice of data copied to the kernel on each subproblem. The basic constructor to `pl.BlockSpec` involves specifying the `block_shape`, the size of a slice of data, and `index_map`, a function that takes in the program ids of the current subproblem and outputs _blocked_ indices into the source buffer. Blocked indices specify which block to copy on each iteration, assuming the source buffer has been carved into blocks of shape as `block_shape`. The `memory_space` argument specifies what memory space to copy the inputs to - be default this will be SRAM. + +```python +pl.BlockSpec( + block_shape: tuple[int, ...], + index_map: Callable, + memory_space: pl.MemorySpace +) +``` +There should be one BlockSpec for each input and each output to the kernel. For more details, see [grid and blockspecs](https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#grid-a-k-a-kernels-in-a-loop). + +### Kernel + +The kernel function specifies what compute to perform on each subproblem. The kernel function should return no outputs, and instead all outputs should be written into the output buffers that are passed into the kernel. All inputs and output buffers are SRAM buffers by default (unless the user has overridden the behavior by specifying a `memory_space` on the corresponding `BlockSpec`). + +```python +def kernel(*input_buffers, *output_buffers): + # ... perform compute + # ... store result into output buffers +``` + +The index of the current subproblem can be queried inside the kernel using `pl.program_id(grid_axis: int)`. + + +### Pallas Call + +The `pl.pallas_call` function is the main entry point to Pallas and performs pipelined execution when a grid and BlockSpecs are supplied. It has the following signature: +```python +def pallas_call( + kernel, + grid: tuple[int, ...], + in_specs: Sequence[PyTree[BlockSpec]], + out_specs: PyTree[BlockSpec], + out_shape: PyTree[jax.ShapeDtypeStruct], +) -> Callable: +``` +`pallas_call` will return a callable function that when invoked with input values, will return outputs of the same shape as `out_shape`. + +`in_specs`, `out_specs`, and `out_shape` are PyTrees of their respective element type. The PyTrees for `in_specs` and the input buffers supplied to the kernel should match, and the PyTrees for `out_specs` and `out_shape` should also match. + + + + +### Example - Elementwise Kernel revisited + +Let's revisit the initial `add_matrices_kernel` from the beginning of the tutorial, except using pipelining. We will add two input arrays of shape `f32[4096, 4096]` that live in HBM. As subproblems, we will carve up the inputs into `block_shape=(512, 512)` blocks and only add two blocks together at a time in the kernel. Because addition is elementwise, each `index_map` is identical and selects out the `i, j`th block on the `i, j`th iteration. + + +```python id="iqr_qjONAHN9" +# Note: This is a TPU example. + +total_shape = (4096, 4096) +block_shape = (512, 512) + +def add_matrices_pipelined_kernel(x_ref, y_ref, o_ref): + o_ref[...] = x_ref[...] + y_ref[...] + +def add_matrices_pipelined(x: jax.Array, y: jax.Array): + return pl.pallas_call( + add_matrices_pipelined_kernel, + grid=tuple(total // block for (total, block) in zip(total_shape, block_shape)), + in_specs=[ + pl.BlockSpec(block_shape, index_map=lambda i, j: (i, j)), + pl.BlockSpec(block_shape, index_map=lambda i, j: (i, j)) + ], + out_specs=pl.BlockSpec(block_shape, index_map=lambda i, j: (i, j)), + out_shape=jax.ShapeDtypeStruct(total_shape, dtype=jnp.float32), + )(x, y) + +x = jax.random.uniform(jax.random.key(0), total_shape, dtype=jnp.float32) +y = jax.random.uniform(jax.random.key(1), total_shape, dtype=jnp.float32) +result = add_matrices_pipelined(x, y) +np.testing.assert_array_equal( + result, x + y +) +``` + + +It turns out that with this API, writing a pipelined kernel is not much more lines of code than writing our original naive addition kernel! + + + +### Parameterizing a Kernel + +It's common to parameterize the block shapes in our kernel. Block sizes are perhaps the most important parameter to tune when optimizing the performance of Pallas kernels! They give us control over the pipeline (for example, picking smaller blocks adds more iterations to our pipelined loop where each iteration has less work to do). Let's write a a function that does so: + + +```python id="RZTAiwrZ6srD" +def add_matrices_pipelined_param( + x: jax.Array, y: jax.Array, *, bm: int = 256, bn: int = 256 +) -> jax.Array: + m, n = x.shape + block_spec = pl.BlockSpec((bm, bn), lambda i, j: (i, j)) + return pl.pallas_call( + add_matrices_kernel, + out_shape=x, + in_specs=[block_spec, block_spec], + out_specs=block_spec, + grid=(m // bm, n // bn), + )(x, y) + +np.testing.assert_array_equal( + add_matrices_pipelined_param(x, y, bm=256, bn=256), x + y +) +np.testing.assert_array_equal( + add_matrices_pipelined_param(x, y, bm=128, bn=128), x + y +) +np.testing.assert_array_equal( + add_matrices_pipelined_param(x, y, bm=512, bn=512), x + y +) +``` + + +## Sharp edges + +While pipelining provides a close approximation to the mental model of simply calling a kernel function in a loop, there are a number of sharp edges that arise from the use of intermediate buffers that are not fully hidden from the user and can result in subtle bugs. + +### Buffer Revisiting + +In general, a good rule-of-thumb to follow is that **the input buffers passed into the kernel function should be interpreted as read-only, and output buffers are write only**. + +Writing to inputs and reading from outputs will in most cases result in incorrectness. This is because the SRAM buffers passed to a kernel only contain copies of the data contained in the underlying HBM buffer. If an input SRAM buffer is updated, the updated results will never be written back out to HBM, and if an output buffer is updated, it's updated value is never read into SRAM. This issue is analogous to staleness issues encountered when using caches in general. + +There are two cases where a buffer supports both reads and writes - accumulation (discussed next), and marking a pair of input and output buffers as input-output aliased by passing in the `input_output_aliases` argument to `pallas_call`. + + +### Reductions and accumulation + +**Reduction/accumulation should only be performed over the last (innermost) dimensions of the grid, and the buffer should be initialized manually first.** + +Reductions are one of the few cases where the pipeline supports both reading and writing to an output buffer, but the reason it works is subtle. +The Pallas pipeline emitter performs an optimization where if the data slices between two consecutive iterations are the same, the pipeline will not issue a `copy_in`/`copy_out` on that buffer. This means the same SRAM buffer used in a previous iteration will be passed into the kernel again on the following iteration, and thus any writes that were issued to the output buffer will become visible on the next iteration. Once the data slice changes, the final accumulated SRAM buffer will be written out to HBM. This is also why reductions must be performed over the last dimensions of the grid -- we want to finish all of the accumulation while the output buffer is in SRAM in the innermost loop, then write it to HBM and never touch that output block again. + +As a concrete example, let's consider performing the following computation for reducing an `(8, 1024, 1024)` array along the first axies into a `(1024, 1024)` array. + + + + + + + + +```python executionInfo={"elapsed": 244, "status": "ok", "timestamp": 1744763773938, "user": {"displayName": "Justin Fu", "userId": "17543197034567316452"}, "user_tz": 420} id="4qz1ET-_f9fJ" outputId="e43067ef-933a-45a5-912a-e224151cfa60" +x = jnp.ones((8, 1024, 1024)) +jnp.sum(x, axis=0) +``` + + +To do this using `pallas_call`, we could use a grid of size `(8,)` and in each iteration i load `x[i]` into SRAM. Then we could add `x[i]` to an output SRAM buffer. Let's implement this naively first. + + +```python executionInfo={"elapsed": 79, "status": "ok", "timestamp": 1744763774254, "user": {"displayName": "Justin Fu", "userId": "17543197034567316452"}, "user_tz": 420} id="ZEi1_vQVf-81" outputId="581744b7-ddc1-4dc1-98ec-03c852772eda" +# Note: This is a TPU example. + +# Warning: this implementation is incorrect! +def incorrect_sum_kernel(x_ref, o_ref): + o_ref[...] += x_ref[...] + +def incorrect_sum(x: jax.Array, + block_size: tuple[int, ...] = (256, 256)) -> jax.Array: + reduction_size, *out_shape = x.shape + grid = (reduction_size, *(out // blk for out, blk in zip(out_shape, block_size))) + return pl.pallas_call( + incorrect_sum_kernel, + grid=grid, + # None in `block_shape` means we pick a size of 1 and squeeze it away + in_specs=[pl.BlockSpec((None, *block_size), lambda i, j, k: (i, j, k))], + out_specs=pl.BlockSpec(block_size, lambda i, j, k: (j, k)), + out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype), + )(x) + +result = incorrect_sum(x) +print(result) +``` + + +This result is completely wrong! + +There are two errors inside this kernel. First, we are accumulating along the first grid dimension instead of the last grid dimension. Second, `o_ref` initially contains garbage values and thus we need to initialize it to zeros before we begin accumulation. + +After fixing these two issues, we obtain the following corrected kernel. In this new kernel, we use `@pl.when` to create a conditional that checks when the program ID is `0` along the reduction axis, indicating we are beginning to accumulate into a new output block. We have also moved the reduction dimension to the last axis of the `grid`. + + +```python executionInfo={"elapsed": 104, "status": "ok", "timestamp": 1744763774523, "user": {"displayName": "Justin Fu", "userId": "17543197034567316452"}, "user_tz": 420} id="XtgD4nMa9_Bd" outputId="9ef07cdf-9e22-4dc8-c17f-c96172639801" +# Note: This is a TPU example. + +def correct_sum_kernel(x_ref, o_ref): + @pl.when(pl.program_id(2) == 0) + def _(): + o_ref[...] = jnp.zeros_like(o_ref) + o_ref[...] += x_ref[...] + +def correct_sum(x: jax.Array, + block_size: tuple[int, ...] = (256, 256)) -> jax.Array: + reduction_size, *out_shape = x.shape + # We moved the reduction to the last axis of the grid. + grid = (*(out // blk for out, blk in zip(out_shape, block_size)), reduction_size) + return pl.pallas_call( + correct_sum_kernel, + grid=grid, + # None in `block_shape` means we pick a size of 1 and squeeze it away + in_specs=[pl.BlockSpec((None, *block_size), lambda i, j, k: (k, i, j))], + out_specs=pl.BlockSpec(block_size, lambda i, j, k: (i, j)), + out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype), + )(x) + +result = correct_sum(x) +print(result) +``` + + + +## Analyzing the performance + +What is the performance of a pipelined kernel? This question can vary depending on where the bottleneck is the hardware is. We are typically interested in 3 quantities: +- **Memory latency** $α$, the minimum latency of a memory transfer. +- **Memory bandwidth** $β$, the rate in bytes/second that we can transfer from HBM to SRAM. +- **FLOP/s** $F$, or floating-point-operations per second, the number of calculations per second that the processor can perform. + +We refer to a program as **compute-bound** if the processing speed FLOPs/s is the bottleneck, and as **memory-bound** if the bandwidth or latency are the bottleneck. Generally, our goal is to optimize a kernel such that it is compute-bound, meaning we are utilizing all of the available processing power of our hardware. + +Suppose we are running a program that requires $X$ bytes of memory transfers per kernel iteration, and runs $Y$ floating-point operations per iteration. The ratio of $X$ to $Y$ varies depending on the type of compute -- for elementwise operations such as addition or multiplication, they will both scale equally. However, for operations such as matrix multiplication, compute scales cubically with the size of the problem while memory scales quadratically. + +In a **compute-bound** regime, a pipeline running $N$ iterations would take $(\alpha + X/\beta) + N (Y/F)$ seconds, where the first term represents the cost of the initial bubble (multiply by a factor of 2 if there is also a bubble at the end), and the second term represents the total time of the steady-state of the pipeline. Assuming that N is large and there is enough work to produce a long pipeline, the dominating term in the runtime is $F$, the processing speed of the accelerator. + + + + + + +![pipelining_compute](../_static/pallas/pipelining_compute_bound.svg) + + + + +In a **memory-bound** regime it is useful to identify if the problem is the latency versus the bandwidth. If the bandwidth is the bottleneck, then the total runtime would take $\alpha + N(X / \beta)$ seconds. In contrast with a latency-bound regime, the memory copies happen serially because the bandwidth is already saturated. Being memory-bound is generally not ideal as there will be gaps in time where the processor is idle, and in most hardware configurations the memory bandwidth $\beta$ is orders of magnitude slower than the processing speed $F$. + + + + +![pipelining_bandwidth](../_static/pallas/pipelining_bandwidth_bound.svg) + + + + +If the bottleneck is specifically the latency and not the bandwidth, it is possible to fix the problem by inserting additional pipeline stages at the cost of additional SRAM required to store more buffers. With sufficient stages, the problem will either become compute or bandwidth bound again depending on which bottleneck we hit first during the steady-stage stage of the pipeline. The downside, however, of a multi-stage pipeline is that the size of the bubble is proportional to the number of stages so it is important to make sure the pipeline is long enough such that the bubble does not take up a substantial amount of the total runtime. + + + + + +![pipelining_latency](../_static/pallas/pipelining_latency_multistage.svg) + + + + +Pallas on TPU only supports double-buffering, as TPU programs can operate on larger block sizes and double-buffering is typically enough to cover the latency. On GPU, the number of pipeline stages can be specified in both the Triton (via `CompilerParams`) and Mosaic GPU backends (via argument to the pipeline emitter). See the platform-specific pipelining documentation for more details. + diff --git a/docs/pallas/quickstart.ipynb b/docs/pallas/quickstart.ipynb index 11dd2108e405..3fefc2cbc157 100644 --- a/docs/pallas/quickstart.ipynb +++ b/docs/pallas/quickstart.ipynb @@ -5,6 +5,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ + "(pallas_quickstart)=\n", "# Pallas Quickstart\n", "\n", "\n", @@ -88,7 +89,53 @@ "We then write `x + y` to `o_ref`.\n", "Mutation has not historically been supported in JAX -- `jax.Array`s are immutable!\n", "`Ref`s are new (experimental) types that allow mutation under certain circumstances.\n", - "We can interpret writing to a `Ref` as mutating its underlying buffer." + "We can interpret writing to a `Ref` as mutating its underlying buffer.\n", + "\n", + "**Indexing and Slicing `Ref`s with `.at`**\n", + "\n", + "In addition to accessing the entire underlying buffer through a reference, it\n", + "is possible to also access only a slice by using the `.at` property. Using\n", + "`x_ref.at[slice]` does not immediately read or write data; it\n", + "creates a new `Ref` object that points to a slice of the original buffer. For\n", + "example `ref.at[0:128]` creates a view of the first 128 elements; `ref.at[::2]`\n", + "creates a strided view.\n", + "\n", + "Once you have a new `Ref` that represents a slice you can read it or write to it\n", + "with the usual syntax. Here is a simple example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b2563d6a", + "metadata": {}, + "outputs": [], + "source": [ + "def add_sliced_kernel(x_ref, y_ref, o_ref):\n", + " small_mid = x_ref.shape[0] // 2\n", + "\n", + " x_left = x_ref.at[:small_mid]\n", + " x_right = x_ref.at[small_mid:]\n", + " y_left = y_ref.at[:small_mid]\n", + " y_right = y_ref.at[small_mid:]\n", + "\n", + " # The output shape is (4*small_mid).\n", + " large_mid = 2*small_mid\n", + " o_ref.at[:large_mid][:small_mid] = x_left[...] + y_left[...]\n", + " o_ref.at[:large_mid][small_mid:] = x_left[...] + y_right[...]\n", + " o_ref.at[large_mid:][:small_mid] = x_right[...] + y_left[...]\n", + " o_ref.at[large_mid:][small_mid:] = x_right[...] + y_right[...]" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note that using `x_ref.at[slice][...]` is equivalent to `x_ref[slice]`. The\n", + "`.at` is useful if you want to compose multiple slices (e.g.\n", + "`x_ref.at[block_slice][thread_slice]`) or if need to pass a slice to a subkernel\n", + "function that takes a `Ref`." ] }, { @@ -279,7 +326,7 @@ "metadata": {}, "source": [ "TPUs distinguish between vector and scalar memory spaces and in this case the\n", - "output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is\n", + "output must be placed in scalar memory (`MemorySpace.SMEM`) since `i` is\n", "a scalar. For more details read {ref}`tpu_and_its_memory_spaces`.\n", "To call the above kernel on TPU, run:" ] @@ -296,7 +343,7 @@ "\n", "def iota(size: int):\n", " return pl.pallas_call(iota_kernel,\n", - " out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),\n", + " out_specs=pl.BlockSpec(memory_space=pltpu.SMEM),\n", " out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),\n", " grid=(size,))()\n", "iota(8)" diff --git a/docs/pallas/quickstart.md b/docs/pallas/quickstart.md index fff1dcb730f3..f18225a589d5 100644 --- a/docs/pallas/quickstart.md +++ b/docs/pallas/quickstart.md @@ -12,6 +12,7 @@ kernelspec: name: python3 --- +(pallas_quickstart)= # Pallas Quickstart @@ -71,6 +72,40 @@ Mutation has not historically been supported in JAX -- `jax.Array`s are immutabl `Ref`s are new (experimental) types that allow mutation under certain circumstances. We can interpret writing to a `Ref` as mutating its underlying buffer. +**Indexing and Slicing `Ref`s with `.at`** + +In addition to accessing the entire underlying buffer through a reference, it +is possible to also access only a slice by using the `.at` property. Using +`x_ref.at[slice]` does not immediately read or write data; it +creates a new `Ref` object that points to a slice of the original buffer. For +example `ref.at[0:128]` creates a view of the first 128 elements; `ref.at[::2]` +creates a strided view. + +Once you have a new `Ref` that represents a slice you can read it or write to it +with the usual syntax. Here is a simple example: + +```{code-cell} ipython3 +def add_sliced_kernel(x_ref, y_ref, o_ref): + small_mid = x_ref.shape[0] // 2 + + x_left = x_ref.at[:small_mid] + x_right = x_ref.at[small_mid:] + y_left = y_ref.at[:small_mid] + y_right = y_ref.at[small_mid:] + + # The output shape is (4*small_mid). + large_mid = 2*small_mid + o_ref.at[:large_mid][:small_mid] = x_left[...] + y_left[...] + o_ref.at[:large_mid][small_mid:] = x_left[...] + y_right[...] + o_ref.at[large_mid:][:small_mid] = x_right[...] + y_left[...] + o_ref.at[large_mid:][small_mid:] = x_right[...] + y_right[...] +``` + +Note that using `x_ref.at[slice][...]` is equivalent to `x_ref[slice]`. The +`.at` is useful if you want to compose multiple slices (e.g. +`x_ref.at[block_slice][thread_slice]`) or if need to pass a slice to a subkernel +function that takes a `Ref`. + +++ So we've written what we call a "kernel", which we define as a program that will @@ -185,7 +220,7 @@ iota(8) ``` TPUs distinguish between vector and scalar memory spaces and in this case the -output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is +output must be placed in scalar memory (`MemorySpace.SMEM`) since `i` is a scalar. For more details read {ref}`tpu_and_its_memory_spaces`. To call the above kernel on TPU, run: @@ -195,7 +230,7 @@ from jax.experimental.pallas import tpu as pltpu def iota(size: int): return pl.pallas_call(iota_kernel, - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), out_shape=jax.ShapeDtypeStruct((size,), jnp.int32), grid=(size,))() iota(8) diff --git a/docs/pallas/tpu/core_map.ipynb b/docs/pallas/tpu/core_map.ipynb new file mode 100644 index 000000000000..38be63d61cb4 --- /dev/null +++ b/docs/pallas/tpu/core_map.ipynb @@ -0,0 +1,628 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "last_runtime": { + "build_target": "//third_party/py/jax_triton/google/pallas_tpu:notebook", + "kind": "private" + } + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Pallas Core-specific Programming" + ], + "metadata": { + "id": "YIt0Za36LYg9" + } + }, + { + "cell_type": "markdown", + "source": [ + "In this guide, we explore using `pl.core_map` to write Pallas kernels. Compared with `pallas_call`, `core_map` offers a few key characteristics:\n", + "\n", + "* **Per-core level programming**: You write code for an TPU/GPU core, not for a JAX device. This gives you full control over what runs on every core, or how cores communicate and distribute work among one another.\n", + "\n", + "* **Collectives**: `core_map` explicitly models physical cores, so inter-core communication can be expressed safely.\n", + "\n", + "* **Platform generic**: `core_map` programming model works for TPU (TensorCore and SparseCore) and GPU with minimal boilerplate changes." + ], + "metadata": { + "id": "khDWSc7aOVts" + } + }, + { + "cell_type": "markdown", + "source": [ + "This guide focuses on TPU. For how to use `core_map` on GPU to achieve higher thread flexibility, check out our [Pallas GPU `core_map` tutorial](https://docs.jax.dev/en/latest/pallas/gpu/reference.html#using-core-map)." + ], + "metadata": { + "id": "i8pl0CLqTVvL" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Environment setup\n", + "\n", + "Modern accelerators often have multiple cores under a device. For recent TPU chips (v4, v5p), every JAX device may contains 2 TensorCores (aka. a [Megacore](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#chips)). Some TPUs (v5p, v6e, 7x) also contain [SparseCores](https://openxla.org/xla/sparsecore#specifications_at_a_glance), each of which consists of many subcores.\n", + "\n", + "This guide was written on a v5p chip, which contains 4 devices (2 TensorCores each) and 4 SparseCores, each with 16 subcores." + ], + "metadata": { + "id": "bsOPXdJkzC-x" + } + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "14PNaMVsLUur", + "executionInfo": { + "status": "ok", + "timestamp": 1764795546418, + "user_tz": 480, + "elapsed": 2087, + "user": { + "displayName": "Ivy Zheng", + "userId": "15297372265856137303" + } + }, + "outputId": "01976bb1-2f2f-40e9-ca23-f0e480a82ab3" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Running on 4 TPU v5p devices.\n" + ] + } + ], + "source": [ + "from functools import partial\n", + "\n", + "import jax\n", + "from jax.sharding import NamedSharding\n", + "from jax.experimental import pallas as pl\n", + "from jax.experimental.pallas import tpu as pltpu\n", + "from jax.experimental.pallas import tpu_sc as plsc\n", + "import jax.numpy as jnp\n", + "import numpy as np\n", + "\n", + "\n", + "num_devices = jax.local_device_count()\n", + "assert num_devices > 1, \"Please run this notebook with more than one device.\"\n", + "\n", + "tpu_info = pltpu.get_tpu_info() # This notebook only runs on TPU.\n", + "print(f\"Running on {num_devices} TPU {tpu_info.chip_version} devices.\")" + ] + }, + { + "cell_type": "markdown", + "source": [ + "In addition to the typical TPU device mesh, you need to make a mesh of cores. Consider this as an addition dimension called `core`, with length 2, in addition to the 4-device mesh you work with. That is 8 cores in total." + ], + "metadata": { + "id": "3f0XEhaYnGyk" + } + }, + { + "cell_type": "code", + "source": [ + "# Mesh of devices\n", + "mesh = jax.make_mesh((jax.device_count(),), ('device',))\n", + "print(mesh)\n", + "\n", + "# Mesh of cores, within a JAX device\n", + "tc_mesh = pltpu.create_tensorcore_mesh('core')\n", + "print(tc_mesh)\n", + "\n", + "num_devices = mesh.size\n", + "num_cores = len(tc_mesh.devices)\n", + "print(f\"There are {num_devices} devices, and {num_cores} cores each.\")" + ], + "metadata": { + "id": "jr5MARD-mIlC", + "executionInfo": { + "status": "ok", + "timestamp": 1764795546665, + "user_tz": 480, + "elapsed": 57, + "user": { + "displayName": "Ivy Zheng", + "userId": "15297372265856137303" + } + }, + "outputId": "1ea63c2f-3aec-4cdd-9674-d0e2df32460c" + }, + "execution_count": 2, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Mesh('device': 4, axis_types=(Explicit,))\n", + "TensorCoreMesh(devices=array([TensorCore(id=0), TensorCore(id=1)], dtype=object), axis_names=('core',))\n", + "There are 4 devices, and 2 cores each.\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "## A simple per-core kernel\n", + "\n", + "`pl.core_map` allows you to write per-core local code, just as `jax.shard_map` allows you to write per-device code.\n", + "\n", + "In the example kernel below, each core has its own VMEM and semaphore allocations. As with normal kernel, you can initiate copies between HBM and VMEM refs using `pltpu.async_copy`.\n", + "\n", + "**Communication between cores**\n", + "\n", + "Before communicating between cores, it is good practice to perform a barrier (using `pltpu.semaphore_signal`) to ensure resources have been allocated and both cores are at the same point during the program.\n", + "\n", + "Once the cores are synchronized, use `pltpu.make_async_remote_copy` to send data between them. The `device_id` keyword argument generically allows sending to any core on any device, but if you just pass in `{'core': other_core_id}`, it will perform a intra-device inter-core copy (the other axis names are held constant).\n" + ], + "metadata": { + "id": "CYxwiULfndlh" + } + }, + { + "cell_type": "code", + "source": [ + "# This runs on every core\n", + "def swap_cores_kernel(in_hbm, out_hbm,\n", + " in_vmem, scratch_vmem, out_vmem,\n", + " sem, send_sem, recv_sem):\n", + " core_index = jax.lax.axis_index('core')\n", + " num_cores = jax.lax.axis_size('core')\n", + " slc_size = in_hbm.shape[-1] // num_cores\n", + " slc = pl.ds(core_index * slc_size, slc_size)\n", + "\n", + " # Copy in a core-dependent slice of the input\n", + " pltpu.async_copy(in_hbm.at[:, slc], in_vmem, sem).wait()\n", + "\n", + " # A barrier to make sure all cores have entered run_scoped.\n", + " # You won't need this if not doing inter-core communications.\n", + " dst_core = (core_index + 1) % num_cores\n", + " sem0 = pltpu.get_barrier_semaphore()\n", + " pltpu.semaphore_signal(sem0, 1, device_id={'core': dst_core})\n", + " pltpu.semaphore_wait(sem0, 1)\n", + "\n", + " # Swap data between core 0 and core 1\n", + " the_copy = pltpu.make_async_remote_copy(\n", + " in_vmem, scratch_vmem, send_sem, recv_sem, device_id={'core': dst_core},\n", + " )\n", + " the_copy.start()\n", + " the_copy.wait()\n", + "\n", + " # Core-local compute\n", + " out_vmem[...] = scratch_vmem[...] * 2\n", + "\n", + " # Copy out the output\n", + " pltpu.async_copy(out_vmem, out_hbm.at[:, slc], sem).wait()\n" + ], + "metadata": { + "id": "GkGRT2HRJOUU", + "executionInfo": { + "status": "ok", + "timestamp": 1764795546946, + "user_tz": 480, + "elapsed": 53, + "user": { + "displayName": "Ivy Zheng", + "userId": "15297372265856137303" + } + } + }, + "execution_count": 3, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Once you have the local kernel:\n", + "\n", + " * Start your top-level JAX code with HBM refs, and allocate output refs if needed.\n", + "\n", + " * Use `pl.core_map`, which takes the TensorCore mesh, to start per-core programming.\n", + "\n", + " * You will need `collective_id` for the barrier semaphore.\n", + "\n", + " * Inside `pl.core_map`, invoke `pl.run_scoped` to allocate per-core scratch spaces (VMEM and semaphores) and run the local kernel." + ], + "metadata": { + "id": "2T0tSkFmoFLI" + } + }, + { + "cell_type": "code", + "source": [ + "input_shape = (32, 256)\n", + "local_vmem_shape = (32 // num_devices, 256 // num_cores)\n", + "in_spec = jax.P('device', None)\n", + "sharding = NamedSharding(mesh, in_spec)\n", + "\n", + "@jax.jit\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=in_spec, out_specs=in_spec,\n", + " check_vma=False)\n", + "def swap_cores(x):\n", + " # Get buffers out of the input and output\n", + " x_hbm_ref = jax.new_ref(x)\n", + " o_hbm_ref = jax.new_ref(jax.lax.empty(x.shape, x.dtype))\n", + "\n", + " @pl.core_map(tc_mesh, compiler_params=pltpu.CompilerParams(collective_id=0))\n", + " def _():\n", + " pl.run_scoped(\n", + " partial(swap_cores_kernel, x_hbm_ref, o_hbm_ref),\n", + " *([pltpu.VMEM(local_vmem_shape, x.dtype)] * 3), # VMEM allocations\n", + " *([pltpu.SemaphoreType.DMA] * 3), # semaphores\n", + " )\n", + " return o_hbm_ref[...]\n", + "\n", + "\n", + "x = jax.random.normal(jax.random.key(0), input_shape, jnp.float32)\n", + "x = jax.device_put(x, sharding)\n", + "y = swap_cores(x)\n", + "\n", + "np.testing.assert_array_equal(y[:, 128:], x[:, :128] * 2)\n", + "np.testing.assert_array_equal(y[:, :128], x[:, 128:] * 2)" + ], + "metadata": { + "id": "KT6zkEKi1Sbc", + "executionInfo": { + "status": "ok", + "timestamp": 1764795548996, + "user_tz": 480, + "elapsed": 1800, + "user": { + "displayName": "Ivy Zheng", + "userId": "15297372265856137303" + } + } + }, + "execution_count": 4, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Save the boilerplate\n", + "\n", + "You can use the `pl.kernel` decorator to wrap boilerplate such as `core_map`, `run_scoped`, and output buffer allocation.\n", + "\n", + "Note that this should run inside any `jax.shard_map` you may have at the top level." + ], + "metadata": { + "id": "dLV8sKa4HuSX" + } + }, + { + "cell_type": "code", + "source": [ + "@jax.jit\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=in_spec, out_specs=in_spec, check_vma=False)\n", + "def swap_cores(x):\n", + " scratch_shapes = [pltpu.VMEM(local_vmem_shape, x.dtype)] * 3 + [pltpu.SemaphoreType.DMA] * 3\n", + " return pl.kernel(swap_cores_kernel, out_shape=x, mesh=tc_mesh,\n", + " scratch_shapes=scratch_shapes,\n", + " compiler_params=pltpu.CompilerParams(collective_id=0))(x)\n", + "\n", + "y = swap_cores(x)\n", + "np.testing.assert_array_equal(y[:, 128:], x[:, :128] * 2)\n", + "np.testing.assert_array_equal(y[:, :128], x[:, 128:] * 2)" + ], + "metadata": { + "id": "7cHnsRHPHyfH", + "executionInfo": { + "status": "ok", + "timestamp": 1764795549347, + "user_tz": 480, + "elapsed": 106, + "user": { + "displayName": "Ivy Zheng", + "userId": "15297372265856137303" + } + } + }, + "execution_count": 5, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Pipelining with `core_map`\n", + "\n", + "Note that the kernel above only does simple copies and compute, without automatic pipelining via Pallas `grid` and `BlockSpec`. To do pipelining inside `core_map`, use `pltpu.emit_pipeline` inside the core-local kernel.\n", + "\n", + "**Automatically parallelize work amongst cores**\n", + "\n", + "The simple way is to annotate a block axis as `pltpu.PARALLEL`, and Pallas will automatically parallelize work along this axis. Both `pl.pallas_call` and `pltpu.emit_pipeline` supports this, via arguments `core_axis` and `dimension_semantics`. The `pallas_call` example is [in another guide](https://docs.jax.dev/en/latest/pallas/tpu/pipelining.html#tpus-in-megacore-configuration), and the `emit_pipeline` case is shown below.\n", + "\n", + "When the `PARALLEL` annotation is provided, the corresponding grid dimension will be logically split and executed on separate cores. (The exact semantics of which grid dimensions are executed on which core is guaranteed).\n", + "\n", + "**Scratch shapes allocation**\n", + "\n", + "Note that in the example below, the top level `pl.run_scoped` (wrapped inside `kernel`) did not allocate any VMEM scratch buffers. Instead, `pltpu.emit_pipeline` allocates its own scratch buffers in VMEM and use them for its multiple buffering.\n" + ], + "metadata": { + "id": "4-G--Wnysdjs" + } + }, + { + "cell_type": "code", + "source": [ + "def add_one_body(in_vmem, out_vmem):\n", + " out_vmem[...] = in_vmem[...] + 1\n", + "\n", + "input_shape = (1024, 1024)\n", + "in_spec = jax.P('device', None)\n", + "\n", + "def add_one_kernel(x_hbm_ref, o_hbm_ref):\n", + " in_shape = x_hbm_ref.shape\n", + " pltpu.emit_pipeline(\n", + " add_one_body,\n", + " grid=(in_shape[0] // 8, in_shape[1] // 128),\n", + " in_specs=[pl.BlockSpec(\n", + " block_shape=(8, 128), index_map=lambda i, j: (i, j),\n", + " )],\n", + " out_specs=[pl.BlockSpec(\n", + " block_shape=(8, 128), index_map=lambda i, j: (i, j),\n", + " )],\n", + " core_axis_name='core',\n", + " dimension_semantics=(pltpu.PARALLEL, pltpu.ARBITRARY),\n", + " )(x_hbm_ref, o_hbm_ref)\n", + "\n", + "\n", + "@jax.jit\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=in_spec, out_specs=in_spec, check_vma=False)\n", + "def add_one(x):\n", + " return pl.kernel(add_one_kernel, out_shape=x, mesh=tc_mesh, scratch_shapes=[])(x)\n", + "\n", + "\n", + "x = jax.random.normal(jax.random.key(0), input_shape, jnp.float32)\n", + "x = jax.device_put(x, NamedSharding(mesh, in_spec))\n", + "y = add_one(x)\n", + "\n", + "np.testing.assert_array_equal(y, x + 1)" + ], + "metadata": { + "id": "xUMRPLxb1rEH", + "executionInfo": { + "status": "ok", + "timestamp": 1764795550106, + "user_tz": 480, + "elapsed": 518, + "user": { + "displayName": "Ivy Zheng", + "userId": "15297372265856137303" + } + } + }, + "execution_count": 6, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Scalar prefetch\n", + "\n", + "The code below extends the kernel above but uses [scalar prefetch and dynamic block indexing](https://docs.jax.dev/en/latest/pallas/tpu/sparse.html) to select a specific sub-slice of the input.\n", + "\n", + "This involves pre-allocating an SMEM buffer (via the `pl.run_scoped` call inside `kernel`) and populating the buffer using a `sync_copy` before the pipeline starts. Close over the dynamic index value inside the `index_map` to use it.\n", + "\n", + "**Manually delegate work amongst cores**\n", + "\n", + "The code example below also shows how `core_map` allows you to customize exactly how the work is split between cores, without relying on the automatic API shown above.\n", + "\n", + "To achieve that, customize your `index_map` to use the core index to work on different slices on different cores.\n" + ], + "metadata": { + "id": "Cq5rYyvL2Tte" + } + }, + { + "cell_type": "code", + "source": [ + "input_shape = (1024, 1024)\n", + "in_spec = jax.P('device', None)\n", + "output_shape = (1024, 512)\n", + "\n", + "def indexed_add_one_kernel(in_refs, out_refs, i_smem_ref):\n", + " (x_hbm_ref, i_hbm_ref), o_hbm_ref = in_refs, out_refs\n", + " in_shape = x_hbm_ref.shape\n", + " pltpu.sync_copy(i_hbm_ref, i_smem_ref)\n", + "\n", + " core_idx = jax.lax.axis_index('core')\n", + " core_slc_size = in_shape[0] // num_cores\n", + " i_map = lambda i: core_idx * core_slc_size // 8 + i # split work among cores\n", + " j_map = lambda j: i_smem_ref[0] // 128 + j # use the prefetched offset\n", + "\n", + " pltpu.emit_pipeline(\n", + " add_one_body,\n", + " grid=(core_slc_size // 8, output_shape[1] // 128),\n", + " in_specs=[pl.BlockSpec(\n", + " block_shape=(8, 128), index_map=lambda i, j: (i_map(i), j_map(j)),\n", + " )],\n", + " out_specs=[pl.BlockSpec(\n", + " block_shape=(8, 128), index_map=lambda i, j: (i_map(i), j),\n", + " )]\n", + " )(x_hbm_ref, o_hbm_ref)\n", + "\n", + "\n", + "@jax.jit\n", + "@partial(jax.shard_map, mesh=mesh,\n", + " in_specs=(in_spec, jax.P()), out_specs=in_spec, check_vma=False)\n", + "def indexed_add_one(x, index):\n", + " out_shape = jax.ShapeDtypeStruct((x.shape[0], x.shape[1] // 2), x.dtype)\n", + " return pl.kernel(indexed_add_one_kernel,\n", + " out_shape=out_shape, mesh=tc_mesh,\n", + " scratch_shapes=[pltpu.SMEM((1,), jnp.int32)])((x, index))\n", + "\n", + "\n", + "xs = jax.random.normal(jax.random.key(0), input_shape, jnp.float32)\n", + "xs = jax.device_put(xs, NamedSharding(mesh, in_spec))\n", + "idx = 256\n", + "y = indexed_add_one(xs, jnp.array([idx]))\n", + "\n", + "np.testing.assert_array_equal(y, xs[:, idx:(idx+512)] + 1)" + ], + "metadata": { + "id": "SE8pTStHeSWB", + "executionInfo": { + "status": "ok", + "timestamp": 1764795550778, + "user_tz": 480, + "elapsed": 378, + "user": { + "displayName": "Ivy Zheng", + "userId": "15297372265856137303" + } + } + }, + "execution_count": 7, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Mapping over SparseCores\n", + "\n", + "TPU v5p contains 4 [SparseCores](https://openxla.org/xla/sparsecore), which are specialized for sparse memory access and operations. This guide will not dive into the full capabilities of SparseCore, but rather show how to run a program on SparseCore with the same semantics and minimal changes from the TensorCore code.\n", + "\n", + "Start with knowing the basic SparseCore specs of your chip, and create a `VectorSubcoreMesh` for vector operations. Note that each SparseCore has 16 (or other number) subcores on TPU v5p, and `core_map` will run your code SPMD on each of them." + ], + "metadata": { + "id": "B8qeo-4A2KRm" + } + }, + { + "cell_type": "code", + "source": [ + "sc_info = pltpu.get_tpu_info().sparse_core\n", + "assert sc_info is not None\n", + "print(sc_info)\n", + "\n", + "sc_mesh = plsc.VectorSubcoreMesh(\n", + " core_axis_name=\"core\", subcore_axis_name=\"subcore\",\n", + " num_cores=sc_info.num_cores\n", + ")\n", + "sc_num_cores = sc_info.num_cores\n", + "sc_num_subcores = sc_info.num_subcores" + ], + "metadata": { + "id": "AHurx-yyYVvs", + "executionInfo": { + "status": "ok", + "timestamp": 1764795551102, + "user_tz": 480, + "elapsed": 55, + "user": { + "displayName": "Ivy Zheng", + "userId": "15297372265856137303" + } + }, + "outputId": "aa4a45da-dd9a-4f57-de1a-bc9b5872b2df" + }, + "execution_count": 8, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "SparseCoreInfo(num_cores=4, num_subcores=16, num_lanes=8)\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "The code below is very similar to the `add_one_kernel` we wrote earlier, except for a few differences:\n", + "\n", + "1. You need to split the work amongst all subcores, so a few lines to compute the specific slice for each subcore.\n", + "\n", + "1. SparseCore register computation allows smaller slices (`4x16` max for int32), so you need nested loops to iterate the slice during computation phase." + ], + "metadata": { + "id": "n2_dfsUWFgwU" + } + }, + { + "cell_type": "code", + "source": [ + "input_shape = (4096, 128)\n", + "SC_REG_OP_SHAPE = (4, 16)\n", + "\n", + "def sc_add_one_body(in_vmem, out_vmem):\n", + " @pl.loop(0, in_vmem.shape[0], step=SC_REG_OP_SHAPE[0])\n", + " def _reg_loop_0(c0):\n", + " @pl.loop(0, in_vmem.shape[1], step=SC_REG_OP_SHAPE[1])\n", + " def _reg_loop_1(c1):\n", + " slc = (pl.ds(c0, SC_REG_OP_SHAPE[0]), pl.ds(c1, SC_REG_OP_SHAPE[1]))\n", + " out_vmem[slc] = in_vmem[slc] + 1\n", + "\n", + "\n", + "def sc_add_one_kernel(x_hbm_ref, o_hbm_ref):\n", + " in_shape = x_hbm_ref.shape\n", + " core_idx = jax.lax.axis_index('core')\n", + " subcore_idx = jax.lax.axis_index(\"subcore\")\n", + " cm_idx = core_idx * sc_num_subcores + subcore_idx # index on the core_map\n", + " slc_size = in_shape[0] // (sc_num_subcores * sc_num_cores)\n", + " index_map = lambda i, j: (\n", + " pl.ds(pl.multiple_of(cm_idx * slc_size + i * 8, 8), 8), j)\n", + "\n", + " pltpu.emit_pipeline(\n", + " sc_add_one_body,\n", + " grid=(slc_size // 8, in_shape[1] // 128),\n", + " in_specs=[pl.BlockSpec(\n", + " block_shape=(pl.BoundedSlice(8), 128), index_map=index_map,\n", + " )],\n", + " out_specs=[pl.BlockSpec(\n", + " block_shape=(pl.BoundedSlice(8), 128), index_map=index_map,\n", + " )]\n", + " )(x_hbm_ref, o_hbm_ref)\n", + "\n", + "\n", + "@jax.jit\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=in_spec, out_specs=in_spec, check_vma=False)\n", + "def sc_add_one(x):\n", + " return pl.kernel(sc_add_one_kernel, out_shape=x, mesh=sc_mesh, scratch_shapes=[])(x)\n", + "\n", + "\n", + "x = jax.random.randint(jax.random.key(0), input_shape, 0, 64, jnp.int32)\n", + "x = jax.device_put(x, NamedSharding(mesh, in_spec))\n", + "y = sc_add_one(x)\n", + "\n", + "np.testing.assert_array_equal(y, x + 1)" + ], + "metadata": { + "id": "6fNShx6k2kxi", + "executionInfo": { + "status": "ok", + "timestamp": 1764795552411, + "user_tz": 480, + "elapsed": 1117, + "user": { + "displayName": "Ivy Zheng", + "userId": "15297372265856137303" + } + } + }, + "execution_count": 9, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/docs/pallas/tpu/core_map.md b/docs/pallas/tpu/core_map.md new file mode 100644 index 000000000000..4e00399b6fe8 --- /dev/null +++ b/docs/pallas/tpu/core_map.md @@ -0,0 +1,363 @@ +# Pallas Core-specific Programming + +In this guide, we explore using `pl.core_map` to write Pallas kernels. Compared with `pallas_call`, `core_map` offers a few key characteristics: + +* **Per-core level programming**: You write code for an TPU/GPU core, not for a JAX device. This gives you full control over what runs on every core, or how cores communicate and distribute work among one another. + +* **Collectives**: `core_map` explicitly models physical cores, so inter-core communication can be expressed safely. + +* **Platform generic**: `core_map` programming model works for TPU (TensorCore and SparseCore) and GPU with minimal boilerplate changes. + +This guide focuses on TPU. For how to use `core_map` on GPU to achieve higher thread flexibility, check out our [Pallas GPU `core_map` tutorial](https://docs.jax.dev/en/latest/pallas/gpu/reference.html#using-core-map). + +## Environment setup + +Modern accelerators often have multiple cores under a device. For recent TPU chips (v4, v5p), every JAX device may contains 2 TensorCores (aka. a [Megacore](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#chips)). Some TPUs (v5p, v6e, 7x) also contain [SparseCores](https://openxla.org/xla/sparsecore#specifications_at_a_glance), each of which consists of many subcores. + +This guide was written on a v5p chip, which contains 4 devices (2 TensorCores each) and 4 SparseCores, each with 16 subcores. + + +```python +from functools import partial + +import jax +from jax.sharding import NamedSharding +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu +from jax.experimental.pallas import tpu_sc as plsc +import jax.numpy as jnp +import numpy as np + + +num_devices = jax.local_device_count() +assert num_devices > 1, "Please run this notebook with more than one device." + +tpu_info = pltpu.get_tpu_info() # This notebook only runs on TPU. +print(f"Running on {num_devices} TPU {tpu_info.chip_version} devices.") +``` + + Running on 4 TPU v5p devices. + + +In addition to the typical TPU device mesh, you need to make a mesh of cores. Consider this as an addition dimension called `core`, with length 2, in addition to the 4-device mesh you work with. That is 8 cores in total. + + +```python +# Mesh of devices +mesh = jax.make_mesh((jax.device_count(),), ('device',)) +print(mesh) + +# Mesh of cores, within a JAX device +tc_mesh = pltpu.create_tensorcore_mesh('core') +print(tc_mesh) + +num_devices = mesh.size +num_cores = len(tc_mesh.devices) +print(f"There are {num_devices} devices, and {num_cores} cores each.") +``` + + Mesh('device': 4, axis_types=(Explicit,)) + TensorCoreMesh(devices=array([TensorCore(id=0), TensorCore(id=1)], dtype=object), axis_names=('core',)) + There are 4 devices, and 2 cores each. + + +## A simple per-core kernel + +`pl.core_map` allows you to write per-core local code, just as `jax.shard_map` allows you to write per-device code. + +In the example kernel below, each core has its own VMEM and semaphore allocations. As with normal kernel, you can initiate copies between HBM and VMEM refs using `pltpu.async_copy`. + +**Communication between cores** + +Before communicating between cores, it is good practice to perform a barrier (using `pltpu.semaphore_signal`) to ensure resources have been allocated and both cores are at the same point during the program. + +Once the cores are synchronized, use `pltpu.make_async_remote_copy` to send data between them. The `device_id` keyword argument generically allows sending to any core on any device, but if you just pass in `{'core': other_core_id}`, it will perform a intra-device inter-core copy (the other axis names are held constant). + + + +```python +# This runs on every core +def swap_cores_kernel(in_hbm, out_hbm, + in_vmem, scratch_vmem, out_vmem, + sem, send_sem, recv_sem): + core_index = jax.lax.axis_index('core') + num_cores = jax.lax.axis_size('core') + slc_size = in_hbm.shape[-1] // num_cores + slc = pl.ds(core_index * slc_size, slc_size) + + # Copy in a core-dependent slice of the input + pltpu.async_copy(in_hbm.at[:, slc], in_vmem, sem).wait() + + # A barrier to make sure all cores have entered run_scoped. + # You won't need this if not doing inter-core communications. + dst_core = (core_index + 1) % num_cores + sem0 = pltpu.get_barrier_semaphore() + pltpu.semaphore_signal(sem0, 1, device_id={'core': dst_core}) + pltpu.semaphore_wait(sem0, 1) + + # Swap data between core 0 and core 1 + the_copy = pltpu.make_async_remote_copy( + in_vmem, scratch_vmem, send_sem, recv_sem, device_id={'core': dst_core}, + ) + the_copy.start() + the_copy.wait() + + # Core-local compute + out_vmem[...] = scratch_vmem[...] * 2 + + # Copy out the output + pltpu.async_copy(out_vmem, out_hbm.at[:, slc], sem).wait() + +``` + +Once you have the local kernel: + + * Start your top-level JAX code with HBM refs, and allocate output refs if needed. + + * Use `pl.core_map`, which takes the TensorCore mesh, to start per-core programming. + + * You will need `collective_id` for the barrier semaphore. + + * Inside `pl.core_map`, invoke `pl.run_scoped` to allocate per-core scratch spaces (VMEM and semaphores) and run the local kernel. + + +```python +input_shape = (32, 256) +local_vmem_shape = (32 // num_devices, 256 // num_cores) +in_spec = jax.P('device', None) +sharding = NamedSharding(mesh, in_spec) + +@jax.jit +@partial(jax.shard_map, mesh=mesh, in_specs=in_spec, out_specs=in_spec, + check_vma=False) +def swap_cores(x): + # Get buffers out of the input and output + x_hbm_ref = jax.new_ref(x) + o_hbm_ref = jax.new_ref(jax.lax.empty(x.shape, x.dtype)) + + @pl.core_map(tc_mesh, compiler_params=pltpu.CompilerParams(collective_id=0)) + def _(): + pl.run_scoped( + partial(swap_cores_kernel, x_hbm_ref, o_hbm_ref), + *([pltpu.VMEM(local_vmem_shape, x.dtype)] * 3), # VMEM allocations + *([pltpu.SemaphoreType.DMA] * 3), # semaphores + ) + return o_hbm_ref[...] + + +x = jax.random.normal(jax.random.key(0), input_shape, jnp.float32) +x = jax.device_put(x, sharding) +y = swap_cores(x) + +np.testing.assert_array_equal(y[:, 128:], x[:, :128] * 2) +np.testing.assert_array_equal(y[:, :128], x[:, 128:] * 2) +``` + +### Save the boilerplate + +You can use the `pl.kernel` decorator to wrap boilerplate such as `core_map`, `run_scoped`, and output buffer allocation. + +Note that this should run inside any `jax.shard_map` you may have at the top level. + + +```python +@jax.jit +@partial(jax.shard_map, mesh=mesh, in_specs=in_spec, out_specs=in_spec, check_vma=False) +def swap_cores(x): + scratch_shapes = [pltpu.VMEM(local_vmem_shape, x.dtype)] * 3 + [pltpu.SemaphoreType.DMA] * 3 + return pl.kernel(swap_cores_kernel, out_shape=x, mesh=tc_mesh, + scratch_shapes=scratch_shapes, + compiler_params=pltpu.CompilerParams(collective_id=0))(x) + +y = swap_cores(x) +np.testing.assert_array_equal(y[:, 128:], x[:, :128] * 2) +np.testing.assert_array_equal(y[:, :128], x[:, 128:] * 2) +``` + +## Pipelining with `core_map` + +Note that the kernel above only does simple copies and compute, without automatic pipelining via Pallas `grid` and `BlockSpec`. To do pipelining inside `core_map`, use `pltpu.emit_pipeline` inside the core-local kernel. + +**Automatically parallelize work amongst cores** + +The simple way is to annotate a block axis as `pltpu.PARALLEL`, and Pallas will automatically parallelize work along this axis. Both `pl.pallas_call` and `pltpu.emit_pipeline` supports this, via arguments `core_axis` and `dimension_semantics`. The `pallas_call` example is [in another guide](https://docs.jax.dev/en/latest/pallas/tpu/pipelining.html#tpus-in-megacore-configuration), and the `emit_pipeline` case is shown below. + +When the `PARALLEL` annotation is provided, the corresponding grid dimension will be logically split and executed on separate cores. (The exact semantics of which grid dimensions are executed on which core is guaranteed). + +**Scratch shapes allocation** + +Note that in the example below, the top level `pl.run_scoped` (wrapped inside `kernel`) did not allocate any VMEM scratch buffers. Instead, `pltpu.emit_pipeline` allocates its own scratch buffers in VMEM and use them for its multiple buffering. + + + +```python +def add_one_body(in_vmem, out_vmem): + out_vmem[...] = in_vmem[...] + 1 + +input_shape = (1024, 1024) +in_spec = jax.P('device', None) + +def add_one_kernel(x_hbm_ref, o_hbm_ref): + in_shape = x_hbm_ref.shape + pltpu.emit_pipeline( + add_one_body, + grid=(in_shape[0] // 8, in_shape[1] // 128), + in_specs=[pl.BlockSpec( + block_shape=(8, 128), index_map=lambda i, j: (i, j), + )], + out_specs=[pl.BlockSpec( + block_shape=(8, 128), index_map=lambda i, j: (i, j), + )], + core_axis_name='core', + dimension_semantics=(pltpu.PARALLEL, pltpu.ARBITRARY), + )(x_hbm_ref, o_hbm_ref) + + +@jax.jit +@partial(jax.shard_map, mesh=mesh, in_specs=in_spec, out_specs=in_spec, check_vma=False) +def add_one(x): + return pl.kernel(add_one_kernel, out_shape=x, mesh=tc_mesh, scratch_shapes=[])(x) + + +x = jax.random.normal(jax.random.key(0), input_shape, jnp.float32) +x = jax.device_put(x, NamedSharding(mesh, in_spec)) +y = add_one(x) + +np.testing.assert_array_equal(y, x + 1) +``` + +## Scalar prefetch + +The code below extends the kernel above but uses [scalar prefetch and dynamic block indexing](https://docs.jax.dev/en/latest/pallas/tpu/sparse.html) to select a specific sub-slice of the input. + +This involves pre-allocating an SMEM buffer (via the `pl.run_scoped` call inside `kernel`) and populating the buffer using a `sync_copy` before the pipeline starts. Close over the dynamic index value inside the `index_map` to use it. + +**Manually delegate work amongst cores** + +The code example below also shows how `core_map` allows you to customize exactly how the work is split between cores, without relying on the automatic API shown above. + +To achieve that, customize your `index_map` to use the core index to work on different slices on different cores. + + + +```python +input_shape = (1024, 1024) +in_spec = jax.P('device', None) +output_shape = (1024, 512) + +def indexed_add_one_kernel(in_refs, out_refs, i_smem_ref): + (x_hbm_ref, i_hbm_ref), o_hbm_ref = in_refs, out_refs + in_shape = x_hbm_ref.shape + pltpu.sync_copy(i_hbm_ref, i_smem_ref) + + core_idx = jax.lax.axis_index('core') + core_slc_size = in_shape[0] // num_cores + i_map = lambda i: core_idx * core_slc_size // 8 + i # split work among cores + j_map = lambda j: i_smem_ref[0] // 128 + j # use the prefetched offset + + pltpu.emit_pipeline( + add_one_body, + grid=(core_slc_size // 8, output_shape[1] // 128), + in_specs=[pl.BlockSpec( + block_shape=(8, 128), index_map=lambda i, j: (i_map(i), j_map(j)), + )], + out_specs=[pl.BlockSpec( + block_shape=(8, 128), index_map=lambda i, j: (i_map(i), j), + )] + )(x_hbm_ref, o_hbm_ref) + + +@jax.jit +@partial(jax.shard_map, mesh=mesh, + in_specs=(in_spec, jax.P()), out_specs=in_spec, check_vma=False) +def indexed_add_one(x, index): + out_shape = jax.ShapeDtypeStruct((x.shape[0], x.shape[1] // 2), x.dtype) + return pl.kernel(indexed_add_one_kernel, + out_shape=out_shape, mesh=tc_mesh, + scratch_shapes=[pltpu.SMEM((1,), jnp.int32)])((x, index)) + + +xs = jax.random.normal(jax.random.key(0), input_shape, jnp.float32) +xs = jax.device_put(xs, NamedSharding(mesh, in_spec)) +idx = 256 +y = indexed_add_one(xs, jnp.array([idx])) + +np.testing.assert_array_equal(y, xs[:, idx:(idx+512)] + 1) +``` + +## Mapping over SparseCores + +TPU v5p contains 4 [SparseCores](https://openxla.org/xla/sparsecore), which are specialized for sparse memory access and operations. This guide will not dive into the full capabilities of SparseCore, but rather show how to run a program on SparseCore with the same semantics and minimal changes from the TensorCore code. + +Start with knowing the basic SparseCore specs of your chip, and create a `VectorSubcoreMesh` for vector operations. Note that each SparseCore has 16 (or other number) subcores on TPU v5p, and `core_map` will run your code SPMD on each of them. + + +```python +sc_info = pltpu.get_tpu_info().sparse_core +assert sc_info is not None +print(sc_info) + +sc_mesh = plsc.VectorSubcoreMesh( + core_axis_name="core", subcore_axis_name="subcore", + num_cores=sc_info.num_cores +) +sc_num_cores = sc_info.num_cores +sc_num_subcores = sc_info.num_subcores +``` + + SparseCoreInfo(num_cores=4, num_subcores=16, num_lanes=8) + + +The code below is very similar to the `add_one_kernel` we wrote earlier, except for a few differences: + +1. You need to split the work amongst all subcores, so a few lines to compute the specific slice for each subcore. + +1. SparseCore register computation allows smaller slices (`4x16` max for int32), so you need nested loops to iterate the slice during computation phase. + + +```python +input_shape = (4096, 128) +SC_REG_OP_SHAPE = (4, 16) + +def sc_add_one_body(in_vmem, out_vmem): + @pl.loop(0, in_vmem.shape[0], step=SC_REG_OP_SHAPE[0]) + def _reg_loop_0(c0): + @pl.loop(0, in_vmem.shape[1], step=SC_REG_OP_SHAPE[1]) + def _reg_loop_1(c1): + slc = (pl.ds(c0, SC_REG_OP_SHAPE[0]), pl.ds(c1, SC_REG_OP_SHAPE[1])) + out_vmem[slc] = in_vmem[slc] + 1 + + +def sc_add_one_kernel(x_hbm_ref, o_hbm_ref): + in_shape = x_hbm_ref.shape + core_idx = jax.lax.axis_index('core') + subcore_idx = jax.lax.axis_index("subcore") + cm_idx = core_idx * sc_num_subcores + subcore_idx # index on the core_map + slc_size = in_shape[0] // (sc_num_subcores * sc_num_cores) + index_map = lambda i, j: ( + pl.ds(pl.multiple_of(cm_idx * slc_size + i * 8, 8), 8), j) + + pltpu.emit_pipeline( + sc_add_one_body, + grid=(slc_size // 8, in_shape[1] // 128), + in_specs=[pl.BlockSpec( + block_shape=(pl.BoundedSlice(8), 128), index_map=index_map, + )], + out_specs=[pl.BlockSpec( + block_shape=(pl.BoundedSlice(8), 128), index_map=index_map, + )] + )(x_hbm_ref, o_hbm_ref) + + +@jax.jit +@partial(jax.shard_map, mesh=mesh, in_specs=in_spec, out_specs=in_spec, check_vma=False) +def sc_add_one(x): + return pl.kernel(sc_add_one_kernel, out_shape=x, mesh=sc_mesh, scratch_shapes=[])(x) + + +x = jax.random.randint(jax.random.key(0), input_shape, 0, 64, jnp.int32) +x = jax.device_put(x, NamedSharding(mesh, in_spec)) +y = sc_add_one(x) + +np.testing.assert_array_equal(y, x + 1) +``` diff --git a/docs/pallas/tpu/details.rst b/docs/pallas/tpu/details.rst index 0575806e6037..91aefd52d2e8 100644 --- a/docs/pallas/tpu/details.rst +++ b/docs/pallas/tpu/details.rst @@ -99,8 +99,8 @@ for exceptions). This unlocks some interesting capabilities: output, without any risk of race conditions. However, we do require that all invocations that write to a particular slice are consecutive. -The "consecutive" restriction on the output usually means that the some prefix -of the grid dimensions always vary the slice of the output an invocation needs +The "consecutive" restriction on the output usually means that some prefix +of the grid dimensions always varies the slice of the output an invocation needs to access, while the output window remains constant for the remaining suffix. For example, when implementing a Pallas TPU kernel for matrix multiplication, @@ -128,7 +128,7 @@ has no impact on performance, as the compiler is free to rearrange them. However, as Pallas is meant to expose lower-level capabilities, the dimension order can have great impact on the quality of generated code. -TPUs perform bulk of the computation on 2D vector registers, which are typically of +TPUs perform the bulk of the computation on 2D vector registers, which are typically of size 8x128 for 32-bit values (as of TPU v6). When a vector value is loaded from VMEM into registers (e.g. ``x = x_ref[...]``), the last two dimensions of the array will be tiled into the registers. @@ -167,10 +167,11 @@ sequential grid execution guarantees, and will need to parallelize one of the grid axes over cores. This is an opt-in procedure. To allow that, ``pallas_call`` requires an extra parameter named ``dimension_semantics``: -.. +.. code:: python + pallas_call( ..., - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=["parallel", "parallel", "arbitrary"] ), ) diff --git a/docs/pallas/tpu/distributed.ipynb b/docs/pallas/tpu/distributed.ipynb index b52ec579f508..feebb7c2f8e7 100644 --- a/docs/pallas/tpu/distributed.ipynb +++ b/docs/pallas/tpu/distributed.ipynb @@ -8,21 +8,21 @@ "source": [ "# Distributed Computing in Pallas for TPUs\n", "\n", - "In this tutorial, we will cover the basics of distributed computing in Pallas on TPUs. We will learn about TPU topologies, communication using the remote DMA primitive, and calling a distributed kernel from JAX using `shard_map`. We will also cover some more advanced kernel writing techniques, such as double-buffering, bi-directional bandwidth optimization, and nested pipelining. As educational examples, we will learn how to implement various collective primitives from JAX, such as `lax.ppermute`, `lax.all_gather`, `lax.psum`, and `lax.psum_scatter`.\n", + "In this tutorial, we will cover the basics of distributed computing in Pallas on TPUs. We will learn about TPU topologies, communication using the remote DMA primitive, and calling a distributed kernel from JAX using `jax.shard_map`. We will also cover some more advanced kernel writing techniques, such as double-buffering, bi-directional bandwidth optimization, and nested pipelining. As educational examples, we will learn how to implement various collective primitives from JAX, such as `lax.ppermute`, `lax.all_gather`, `lax.psum`, and `lax.psum_scatter`.\n", "\n", "Some recommended readings beforehand:\n", " - [Pallas Pipelining on TPU](pallas_tpu_pipelining)\n", - " - [Collectives with `shard_map`](shard_map_collectives_tutorial)" + " - [Collectives with `jax.shard_map`](shard_map_collectives_tutorial)" ] }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 6, "metadata": { "executionInfo": { - "elapsed": 1978, + "elapsed": 52, "status": "ok", - "timestamp": 1722904801801, + "timestamp": 1744390458993, "user": { "displayName": "Justin Fu", "userId": "17543197034567316452" @@ -30,23 +30,23 @@ "user_tz": 420 }, "id": "PyAGnWc9yI8T", - "outputId": "1d8229bd-cab5-495f-93e9-fff2e41db480" + "outputId": "c5912653-c34b-4810-c373-4a2787691317" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Running with 4 TPU v5 lite devices.\n" + "Running with 4 TPU v4 devices.\n" ] } ], "source": [ + "import functools\n", "import jax\n", "from jax import lax\n", "from jax import numpy as jnp\n", "from jax.experimental import pallas as pl\n", - "from jax.experimental import shard_map\n", "from jax.experimental.pallas import tpu as pltpu\n", "\n", "P = jax.sharding.PartitionSpec\n", @@ -71,7 +71,7 @@ "\n", "![tpu_topologies](https://cloud.google.com/static/tpu/docs/images/v4-topologies.png)\n", "\n", - "Flattened as a graph, the torus can be visualized as follows. Each edge (orange or black) is a bidirectional connection between two devices. You will commonly hear about rings in conjunction with discussion about device toplogies — a key feature of a torus is that when taking a slice along an axis of the pod, such as the nodes `[(0,1), (1, 1), (2, 1), (3, 1)]` or `[(0, 1), (1, 1)]`, we have a ring of devices. This is a feature we can use to simplify communication patterns within the pod.\n", + "Flattened as a graph, the torus can be visualized as follows. Each edge (orange or black) is a bidirectional connection between two devices. You will commonly hear about rings in conjunction with discussion about device topologies — a key feature of a torus is that when taking a slice along an axis of the pod, such as the nodes `[(0,1), (1, 1), (2, 1), (3, 1)]` or `[(0, 1), (1, 1)]`, we have a ring of devices. This is a feature we can use to simplify communication patterns within the pod.\n", "\n", "![tpu_torus](https://cloud.google.com/static/tpu/docs/images/untwisted-tori.png)" ] @@ -178,7 +178,7 @@ "\n", "`send_sem` and `recv_sem` are instances of a special type of semaphore reserved exclusively for use with DMAs. They must be allocated with the `tpu.SemaphoreType.DMA` type when specifying input specs to `pallas_call`.\n", "\n", - "Internally, DMA semaphores can be thought of as integer-valued progress trackers. On DMA start, the local device will begin to increment the value of `send_sem` and the receiver's `recv_sem` asynchronously. Waiting on a semaphore will block until the value of the semaphore reaches the total bytes of data sent/received; when the value is reached, waiting threads are released and the sempahore's value is decremented by the same amount. This means that either all data has been sent (for `send_sem`) or all data has been received (for `dst_sem`). The value of the semaphore can be read with `pl.semaphore_read`, but note that the underlying semantics of the value could change between hardware generations (e.g. the value may not represent exactly the number of bytes sent, although this is a useful mental model to have when reasoning about the behavior of the semaphore).\n", + "Internally, DMA semaphores can be thought of as integer-valued progress trackers. On DMA start, the local device will begin to increment the value of `send_sem` and the receiver's `recv_sem` asynchronously. Waiting on a semaphore will block until the value of the semaphore reaches the total bytes of data sent/received; when the value is reached, waiting threads are released and the semaphore's value is decremented by the same amount. This means that either all data has been sent (for `send_sem`) or all data has been received (for `recv_sem`). The value of the semaphore can be read with `pl.semaphore_read`, but note that the underlying semantics of the value could change between hardware generations (e.g. the value may not represent exactly the number of bytes sent, although this is a useful mental model to have when reasoning about the behavior of the semaphore).\n", "\n", "### Routing\n", "\n", @@ -215,12 +215,12 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 7, "metadata": { "executionInfo": { - "elapsed": 1606, + "elapsed": 152, "status": "ok", - "timestamp": 1722904803566, + "timestamp": 1744390459367, "user": { "displayName": "Justin Fu", "userId": "17543197034567316452" @@ -228,7 +228,7 @@ "user_tz": 420 }, "id": "YkyIKN2thZ-V", - "outputId": "9b7ed142-d161-4237-fed8-cbce41adc5f0" + "outputId": "26719bb9-87ff-46dd-af90-a114ce332417" }, "outputs": [ { @@ -271,11 +271,11 @@ "out_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32)\n", "grid_spec = pltpu.PrefetchScalarGridSpec(\n", " num_scalar_prefetch=0,\n", - " # TPUMemorySpace.ANY will (usually) place the tensor in HBM.\n", + " # MemorySpace.ANY will (usually) place the tensor in HBM.\n", " in_specs=[\n", - " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", + " pl.BlockSpec(memory_space=pl.ANY),\n", " ],\n", - " out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", + " out_specs=pl.BlockSpec(memory_space=pl.ANY),\n", " scratch_shapes=(\n", " # We allocate DMA semaphores in scratch memory.\n", " [pltpu.SemaphoreType.DMA] * 2\n", @@ -288,12 +288,12 @@ ")\n", "# Wrap the kernel within a shard_map to call.\n", "pallas_result = jax.jit(\n", - " shard_map.shard_map(\n", + " jax.shard_map(\n", " right_permute,\n", " mesh=mesh,\n", " in_specs=partition,\n", " out_specs=partition,\n", - " check_rep=False,\n", + " check_vma=False,\n", " )\n", ")(input_arr)\n", "\n", @@ -301,7 +301,7 @@ "perm = tuple((src, (src + 1) % num_devices) for src in range(num_devices))\n", "\n", "xla_result = jax.jit(\n", - " shard_map.shard_map(\n", + " jax.shard_map(\n", " lambda x: lax.ppermute(x, 'x', perm),\n", " mesh=mesh, in_specs=partition, out_specs=partition)\n", ")(input_arr)\n", @@ -338,12 +338,12 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 8, "metadata": { "executionInfo": { - "elapsed": 812, + "elapsed": 209, "status": "ok", - "timestamp": 1722904804531, + "timestamp": 1744390459789, "user": { "displayName": "Justin Fu", "userId": "17543197034567316452" @@ -351,7 +351,7 @@ "user_tz": 420 }, "id": "ojQEZB5mBRqM", - "outputId": "e1648f54-737c-4921-ca3b-b4c639a38d2b" + "outputId": "3a4373f8-1fb5-4a6b-b88e-3461c2609021" }, "outputs": [ { @@ -420,10 +420,10 @@ "grid_spec = pltpu.PrefetchScalarGridSpec(\n", " num_scalar_prefetch=0,\n", " in_specs=[\n", - " # TPUMemorySpace.ANY will (usually) place the tensor in HBM.\n", - " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", + " # MemorySpace.ANY will (usually) place the tensor in HBM.\n", + " pl.BlockSpec(memory_space=pl.ANY),\n", " ],\n", - " out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", + " out_specs=pl.BlockSpec(memory_space=pl.ANY),\n", " scratch_shapes=(\n", " # DMA semaphores are allocated in scratch memory.\n", " # We allocated one semaphore for a local HBM-VMEM copy,\n", @@ -447,18 +447,18 @@ "\n", "# Wrap the kernel within a shard_map to call.\n", "pallas_result = jax.jit(\n", - " shard_map.shard_map(\n", + " jax.shard_map(\n", " all_gather,\n", " mesh=mesh,\n", " in_specs=partition,\n", " out_specs=partition,\n", - " check_rep=False\n", + " check_vma=False\n", " )\n", ")(input_arr)\n", "\n", "# Compare Pallas result to XLA shard_map result.\n", "xla_result = jax.jit(\n", - " shard_map.shard_map(\n", + " jax.shard_map(\n", " lambda x: lax.all_gather(x, 'x'),\n", " mesh=mesh, in_specs=partition, out_specs=partition\n", " )\n", @@ -477,13 +477,13 @@ "id": "KgU7HI2pS4om" }, "source": [ - "A detail worth mentioning here is the use of multiple receive semaphores. Because we only block on the receiving device, it is still possible for a sender to have sent multiple DMAs in flight before the receiver has finished processing the first one (see the next section and reduce-sum example which discusses race conditions in more detail). In this situation we may hit a situation where the same semaphore is being used for multiple DMAs occurring simultaneously. To avoid this, we allocate `num_devices-1` semaphores so there is no risk of re-use. While this race condition is unlikely to happen on such a small kernel, on larger kernels there is more chance for devices to fall out of sync and potentially cause a silent failure." + "A detail worth mentioning here is the use of multiple receive semaphores. Because we only block on the receiving device, it is still possible for a sender to have sent multiple DMAs in flight before the receiver has finished processing the first one (see the next section and reduce-sum example which discusses race conditions in more detail). In this situation we may hit a situation where the same semaphore is being used for multiple DMAs occurring simultaneously. To avoid this, we allocate `num_devices-1` semaphores so there is no risk of reuse. While this race condition is unlikely to happen on such a small kernel, on larger kernels there is more chance for devices to fall out of sync and potentially cause a silent failure." ] }, { "cell_type": "markdown", "metadata": { - "id": "KgU7HI2pS4om" + "id": "EDCmAaHVtY7x" }, "source": [ "## Advanced Techniques\n", @@ -529,9 +529,9 @@ "\n", "In order to use regular semaphores, they can be allocated in the same way as a DMA semaphore, but by specifying `pltpu.SemaphoreType.REGULAR` rather than `pltpu.SemaphoreType.DMA`.\n", "\n", - "Semaphores must be zero at the end of a Pallas program to complete succesfully. There are two error cases where this may happen:\n", + "Semaphores must be zero at the end of a Pallas program to complete successfully. There are two error cases where this may happen:\n", " - If a semaphore is over-signaled, the program will end with non-zero (>0) semaphores. In this case, the program will crash upon completion. This is useful for debugging as non-zero semaphores typically means there is a bug somewhere inside of the program.\n", - " - If a semaphore is over-waited, the program will hang on the blocking `semaphore_wait` call while it waits for the sempahore to be incremented. In this case the device or program will need to be restarted.\n", + " - If a semaphore is over-waited, the program will hang on the blocking `semaphore_wait` call while it waits for the semaphore to be incremented. In this case the device or program will need to be restarted.\n", "\n", "#### Barrier Semaphores\n", "\n", @@ -569,7 +569,7 @@ "kernel = pl.pallas_call(\n", " example_kernel,\n", " ...,\n", - " compiler_params=pltpu.TPUCompilerParams(collective_id=0),\n", + " compiler_params=pltpu.CompilerParams(collective_id=0),\n", ")\n", "```" ] @@ -644,19 +644,19 @@ "\n", "The main body assumes that a value has already been copied into our local working slot, either from the previous iteration or from the prologue. A complicating factor is that our destination buffers live in HBM, but we need to load values to VMEM before we perform arithmetic. Therefore, we simultaneously copy the working slot value into our VMEM (`receive_scratch`) and pass the value on to our right neighbor's receiving slot. Once the value has been copied into our VMEM, we can accumulate it into our result (contained in `o_ref`).\n", "\n", - "A subtle race condition can occur if one device runs one loop ahead of it's right neighbor. In this case, it could copy into the receiver's `working_slot` at the same time the receiver is reading from it. In order to avoid this, each device will block on a `REGULAR` semaphore before copying into the right neighbor's `dst_ref` until it has signaled that it is done reading from its `working_slot`. This race condition is rarely triggered for a small kernel such as this example, but can it can be explicitly triggered if for example using a `pltpu.delay` instruction to artifically hang a device.\n", + "A subtle race condition can occur if one device runs one loop ahead of it's right neighbor. In this case, it could copy into the receiver's `working_slot` at the same time the receiver is reading from it. In order to avoid this, each device will block on a `REGULAR` semaphore before copying into the right neighbor's `dst_ref` until it has signaled that it is done reading from its `working_slot`. This race condition is rarely triggered for a small kernel such as this example, but can it can be explicitly triggered if for example using a `pl.delay` instruction to artificially hang a device.\n", "\n", "Note that this is not an optimal or fully general kernel, as the block sizes must entirely fit in VMEM and we could better interleave communication and accumulation. We will discuss these optimizations in later sections." ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 9, "metadata": { "executionInfo": { - "elapsed": 254, + "elapsed": 248, "status": "ok", - "timestamp": 1722904804952, + "timestamp": 1744390460289, "user": { "displayName": "Justin Fu", "userId": "17543197034567316452" @@ -664,7 +664,7 @@ "user_tz": 420 }, "id": "XrY5bMlvBroQ", - "outputId": "77497000-4496-462e-cc3c-73fb640cc14c" + "outputId": "9216e749-48d2-43ff-d64b-bd419acf3e11" }, "outputs": [ { @@ -674,7 +674,7 @@ "Input = [0.9858954 0.11763906 0.9955574 0.775211 ]\n", "Pallas result = [2.8743029 2.8743029 2.8743029 2.8743029]\n", "lax.psum result = [2.8743029 2.8743029 2.8743029 2.8743029]\n", - "Difference |Pallas - lax.psum| = 1.4959369e-08\n" + "Difference |Pallas - lax.psum| = 1.0535587e-08\n" ] } ], @@ -687,6 +687,41 @@ "input_arr = jax.device_put(input_arr, sharding)\n", "\n", "\n", + "def local_barrier(left_neighbor, right_neighbor, double_barrier=True):\n", + " \"\"\"Performs a barrier with neighbors on the global barrier semaphore.\n", + "\n", + " Optionally performs a second barrier, which prevents a potential race\n", + " when reusing the same collective_id across kernel invocations.\n", + " \"\"\"\n", + " barrier_sem = pltpu.get_barrier_semaphore()\n", + " for neighbor in [left_neighbor, right_neighbor]:\n", + " pltpu.semaphore_signal(\n", + " barrier_sem,\n", + " inc=1,\n", + " device_id=(neighbor,),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + " pltpu.semaphore_wait(barrier_sem, 2)\n", + " if double_barrier:\n", + " # The double-barrier prevents a race condition where one neighbor can\n", + " # re-enter the kernel again on a subsequent call and increment the\n", + " # barrier semaphore a second time. This would unblock the current device\n", + " # even if the other neighbor is not ready yet.\n", + " # To implement a double-barrier, we stack-allocate a second REGULAR\n", + " # semaphore using run_scoped.\n", + " @functools.partial(pl.run_scoped,\n", + " second_barrier=pltpu.SemaphoreType.REGULAR)\n", + " def _(second_barrier):\n", + " for neighbor in [left_neighbor, right_neighbor]:\n", + " pltpu.semaphore_signal(\n", + " second_barrier,\n", + " inc=1,\n", + " device_id=(neighbor,),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + " pltpu.semaphore_wait(second_barrier, 2)\n", + "\n", + "\n", "def all_reduce_kernel(\n", " x_ref,\n", " o_ref,\n", @@ -709,20 +744,7 @@ " def _():\n", " # Barrier with both neighbors at the start, since we will be\n", " # communicating with both.\n", - " barrier_sem = pltpu.get_barrier_semaphore()\n", - " pltpu.semaphore_signal(\n", - " barrier_sem,\n", - " inc=1,\n", - " device_id=(left_neighbor,),\n", - " device_id_type=pltpu.DeviceIdType.MESH,\n", - " )\n", - " pltpu.semaphore_signal(\n", - " barrier_sem,\n", - " inc=1,\n", - " device_id=(right_neighbor,),\n", - " device_id_type=pltpu.DeviceIdType.MESH,\n", - " )\n", - " pltpu.semaphore_wait(barrier_sem, 2)\n", + " local_barrier(left_neighbor, right_neighbor)\n", "\n", " # Initialize o_ref, acc_scratch, and hbm_scratch.\n", " o_ref[...] = jnp.zeros_like(o_ref)\n", @@ -787,13 +809,13 @@ " num_scalar_prefetch=0,\n", " in_specs=[\n", " # Our input lives in VMEM\n", - " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),\n", + " pl.BlockSpec(memory_space=pltpu.VMEM),\n", " ],\n", " out_specs=[\n", " # Our output lives in VMEM\n", - " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),\n", + " pl.BlockSpec(memory_space=pltpu.VMEM),\n", " # Our double-buffer lives in HBM\n", - " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", + " pl.BlockSpec(memory_space=pl.ANY),\n", " ],\n", " grid=(num_devices,),\n", " scratch_shapes=(\n", @@ -807,16 +829,16 @@ " all_reduce_kernel,\n", " out_shape=out_shape,\n", " grid_spec=grid_spec,\n", - " compiler_params=pltpu.TPUCompilerParams(collective_id=0),\n", + " compiler_params=pltpu.CompilerParams(collective_id=0),\n", ")\n", "\n", "pallas_result = jax.jit(\n", - " shard_map.shard_map(\n", + " jax.shard_map(\n", " kernel,\n", " mesh=mesh,\n", " in_specs=partition,\n", " out_specs=partition,\n", - " check_rep=False,\n", + " check_vma=False,\n", " )\n", ")(input_arr)\n", "pallas_result = jax.block_until_ready(pallas_result)[0]\n", @@ -827,7 +849,7 @@ "\n", "\n", "xla_result = jax.jit(\n", - " shard_map.shard_map(\n", + " jax.shard_map(\n", " lax_sum, mesh=mesh, in_specs=P(None, 'x'), out_specs=P(None, 'x')\n", " )\n", ")(input_arr)\n", @@ -892,12 +914,12 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 10, "metadata": { "executionInfo": { - "elapsed": 544, + "elapsed": 362, "status": "ok", - "timestamp": 1722904805699, + "timestamp": 1744390460871, "user": { "displayName": "Justin Fu", "userId": "17543197034567316452" @@ -1017,20 +1039,7 @@ " def _():\n", " # Barrier with both neighbors at the start, since we will be\n", " # communicating with both.\n", - " barrier_sem = pltpu.get_barrier_semaphore()\n", - " pltpu.semaphore_signal(\n", - " barrier_sem,\n", - " inc=1,\n", - " device_id=(left_neighbor,),\n", - " device_id_type=pltpu.DeviceIdType.MESH,\n", - " )\n", - " pltpu.semaphore_signal(\n", - " barrier_sem,\n", - " inc=1,\n", - " device_id=(right_neighbor,),\n", - " device_id_type=pltpu.DeviceIdType.MESH,\n", - " )\n", - " pltpu.semaphore_wait(barrier_sem, 2)\n", + " local_barrier(left_neighbor, right_neighbor)\n", "\n", " # Initialize o_ref, acc_scratch, and hbm_scratch with initial copies.\n", " o_ref[...] = jnp.zeros_like(o_ref[...])\n", @@ -1137,11 +1146,11 @@ "grid_spec = pltpu.PrefetchScalarGridSpec(\n", " num_scalar_prefetch=0,\n", " in_specs=[\n", - " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),\n", + " pl.BlockSpec(memory_space=pltpu.VMEM),\n", " ],\n", " out_specs=[\n", - " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),\n", - " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", + " pl.BlockSpec(memory_space=pltpu.VMEM),\n", + " pl.BlockSpec(memory_space=pl.ANY),\n", " ],\n", " grid=(num_devices, 2),\n", " scratch_shapes=(\n", @@ -1160,17 +1169,17 @@ " reduce_scatter_kernel,\n", " out_shape=out_shape,\n", " grid_spec=grid_spec,\n", - " compiler_params=pltpu.TPUCompilerParams(collective_id=0),\n", + " compiler_params=pltpu.CompilerParams(collective_id=0),\n", " )(input_arr)[0]\n", "\n", "\n", "pallas_result = jax.jit(\n", - " shard_map.shard_map(\n", + " jax.shard_map(\n", " pallas_reduce_scatter,\n", " mesh=mesh,\n", " in_specs=P(None, 'x'),\n", " out_specs=P('x', None),\n", - " check_rep=False,\n", + " check_vma=False,\n", " )\n", ")(input_arr)\n", "\n", @@ -1179,12 +1188,12 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 11, "metadata": { "executionInfo": { - "elapsed": 596, + "elapsed": 917, "status": "ok", - "timestamp": 1722904806442, + "timestamp": 1744390461967, "user": { "displayName": "Justin Fu", "userId": "17543197034567316452" @@ -1192,7 +1201,7 @@ "user_tz": 420 }, "id": "E-NMh-_teoi4", - "outputId": "24beb42f-1bdd-4c34-e8d2-681dd7f2e9c0" + "outputId": "6c8b82bc-ed64-4cc1-8c5f-65e29cdb333c" }, "outputs": [ { @@ -1220,7 +1229,7 @@ "\n", "\n", "xla_result = jax.jit(\n", - " shard_map.shard_map(\n", + " jax.shard_map(\n", " lax_reduce_sum_scatter,\n", " mesh=mesh,\n", " in_specs=P(None, 'x'),\n", @@ -1247,21 +1256,7 @@ "\n", "A limitation of the previous all-reduce and reduce-scatter kernels that we wrote is that the blocks we copy via remote DMA must be small enough to fit in our working VMEM that we use for accumulation. For some kernels it may be advantageous to use larger block sizes to better utilize the TPU. For example, a matrix multiplication requires on the order of $O(N^3)$ compute operations, but only $O(N^2)$ memory transfers. Therefore, we want each block of work transferred between devices to be large enough such that the operation becomes compute bound and we can hide the communication cost using pipelining. For reference, the VMEM of a TPU (for generations v4/v5) is typically on the order of 10-100MB, whereas HBM ranges from 10-100GB.\n", "\n", - "To address this problem, we need to be able to write an \"inner kernel\" that handles local HBM-VMEM pipelining inside of the \"outer kernel\" that handles pipelining larger HBM-HBM transfers between devices. Pallas offers an API for constructing nested pipelines using the `emit_pipeline` function. The basic call signature for `emit_pipeline` follows that of a standard `pallas_call` by specifying a `grid` and `BlockSpec`s for the inputs and outputs:\n", - "\n", - "```python\n", - "def emit_pipeline(\n", - " kernel: Callable,\n", - " grid: tuple[int],\n", - " in_specs: PyTree[BlockSpec] = None,\n", - " out_specs: PyTree[BlockSpec] = None,\n", - " should_accumulate_out: bool = False,\n", - " dimension_semantics: tuple[GridDimensionSemantics] = None,\n", - ") -> Callable:\n", - " ... # Returns a custom pipeline given an inner kernel and BlockSpecs.\n", - "```\n", - "\n", - "Indeed, one can view `pallas_call` itself as simply a wrapper around `emit_pipeline`. Because our outer kernel only involves remote HBM-HBM transfers, we are not using any of the built-in pipelining that `pallas_call` provides for HBM-VMEM transfers. The following code skeleton demonstrates what a typical program structure would look like using this pattern:\n", + "To address this problem, we need to be able to write an \"inner kernel\" that handles local HBM-VMEM pipelining inside of the \"outer kernel\" that handles pipelining larger HBM-HBM transfers between devices. Pallas offers an API for constructing nested pipelines using the `emit_pipeline` function. See the [TPU pipelining](pallas_tpu_emit_pipeline) guide for a general overview on `emit_pipeline`. Because our outer kernel only involves remote HBM-HBM transfers, we are not using any of the built-in pipelining that `pallas_call` provides for HBM-VMEM transfers. The following code skeleton demonstrates what a typical program structure would look like using this pattern:\n", "\n", "```python\n", "\n", @@ -1298,7 +1293,7 @@ "\n", "In this next example we will modify our previous reduce-scatter example to utilize a nested inner pipeline. Note that the communication and computation costs of `reduce_scatter` both scale linearly with the size of the input, so we do not necessarily expect to see the operation become compute-bound with larger block sizes. This example is purely for demonstration purposes on how to use the pipeline emitter.\n", "\n", - "We will increase the block sizes of the outer kernel such that they would be undesirable to place inside of VMEM, and allocate all inputs and outputs in HBM (`memory_space=TPUMemorySpace.Any`). The only major change from our previous kernel is the body of the kernel where accumulation is done. Rather than manually copying from HBM to VMEM, accumulating, and copying back to HBM, we use `emit_pipeline` to handle the memory transfers for us. Accumulation is done in an inner kernel with a much smaller, VMEM-friendly block size.\n", + "We will increase the block sizes of the outer kernel such that they would be undesirable to place inside of VMEM, and allocate all inputs and outputs in HBM (`memory_space=MemorySpace.ANY`). The only major change from our previous kernel is the body of the kernel where accumulation is done. Rather than manually copying from HBM to VMEM, accumulating, and copying back to HBM, we use `emit_pipeline` to handle the memory transfers for us. Accumulation is done in an inner kernel with a much smaller, VMEM-friendly block size.\n", "\n", "In our previous kernel we had the following kernel body to copy data from HBM to the VMEM accumulator, increment, and then copy the results back to HBM:\n", "\n", @@ -1356,12 +1351,12 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 12, "metadata": { "executionInfo": { - "elapsed": 1341, + "elapsed": 997, "status": "ok", - "timestamp": 1722904807930, + "timestamp": 1744390463178, "user": { "displayName": "Justin Fu", "userId": "17543197034567316452" @@ -1399,7 +1394,7 @@ "inner_block_spec = pl.BlockSpec(\n", " index_map=lambda i, j: (i, j),\n", " block_shape=inner_block_size,\n", - " memory_space=pltpu.TPUMemorySpace.ANY,\n", + " memory_space=pltpu.TPUMemorySpace.VMEM,\n", ")\n", "\n", "\n", @@ -1474,20 +1469,7 @@ " def _():\n", " # Barrier with both neighbors at the start, since we will be\n", " # communicating with both.\n", - " barrier_sem = pltpu.get_barrier_semaphore()\n", - " pltpu.semaphore_signal(\n", - " barrier_sem,\n", - " inc=1,\n", - " device_id=(left_neighbor,),\n", - " device_id_type=pltpu.DeviceIdType.MESH,\n", - " )\n", - " pltpu.semaphore_signal(\n", - " barrier_sem,\n", - " inc=1,\n", - " device_id=(right_neighbor,),\n", - " device_id_type=pltpu.DeviceIdType.MESH,\n", - " )\n", - " pltpu.semaphore_wait(barrier_sem, 2)\n", + " local_barrier(left_neighbor, right_neighbor)\n", "\n", " initial_left_copy.start()\n", " initial_left_copy.wait()\n", @@ -1594,11 +1576,11 @@ "grid_spec = pltpu.PrefetchScalarGridSpec(\n", " num_scalar_prefetch=0,\n", " in_specs=[\n", - " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", + " pl.BlockSpec(memory_space=pl.ANY),\n", " ],\n", " out_specs=[\n", - " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", - " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", + " pl.BlockSpec(memory_space=pl.ANY),\n", + " pl.BlockSpec(memory_space=pl.ANY),\n", " ],\n", " grid=(num_devices, 2),\n", " scratch_shapes=(\n", @@ -1616,17 +1598,17 @@ " reduce_scatter_kernel,\n", " out_shape=out_shape,\n", " grid_spec=grid_spec,\n", - " compiler_params=pltpu.TPUCompilerParams(collective_id=0),\n", + " compiler_params=pltpu.CompilerParams(collective_id=0),\n", " )(input_arr)[0]\n", "\n", "\n", "pallas_result = jax.jit(\n", - " shard_map.shard_map(\n", + " jax.shard_map(\n", " pallas_reduce_scatter,\n", " mesh=mesh,\n", " in_specs=P(None, 'x'),\n", " out_specs=P('x', None),\n", - " check_rep=False,\n", + " check_vma=False,\n", " )\n", ")(input_arr)\n", "\n", @@ -1635,12 +1617,12 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 13, "metadata": { "executionInfo": { - "elapsed": 768, + "elapsed": 1132, "status": "ok", - "timestamp": 1722904808851, + "timestamp": 1744390464532, "user": { "displayName": "Justin Fu", "userId": "17543197034567316452" @@ -1648,7 +1630,7 @@ "user_tz": 420 }, "id": "cTEyiMDyx9Y0", - "outputId": "1de26695-3713-430e-9ab4-4ea646691680" + "outputId": "70ce154e-dab2-4ae0-e297-c4774d29da85" }, "outputs": [ { @@ -1670,7 +1652,7 @@ "\n", "\n", "xla_result = jax.jit(\n", - " shard_map.shard_map(\n", + " jax.shard_map(\n", " lax_reduce_sum_scatter,\n", " mesh=mesh,\n", " in_specs=P(None, 'x'),\n", @@ -1705,11 +1687,18 @@ "\n", "### Next Steps\n", "\n", - "Excellent follow-up excercises for the reader could include implementing a distributed matrix multiplication, implementing `lax.all_to_all`, and relaxing synchronization to allow for additional run-ahead." + "Excellent follow-up exercises for the reader could include implementing a distributed matrix multiplication, implementing `lax.all_to_all`, and relaxing synchronization to allow for additional run-ahead." ] } ], "metadata": { + "colab": { + "last_runtime": { + "build_target": "//experimental/users/justinfu/pallas:colab", + "kind": "private" + }, + "provenance": [] + }, "jupytext": { "formats": "ipynb,md:myst", "main_language": "python" @@ -1733,5 +1722,5 @@ } }, "nbformat": 4, - "nbformat_minor": 4 + "nbformat_minor": 0 } diff --git a/docs/pallas/tpu/distributed.md b/docs/pallas/tpu/distributed.md index c1f216c6153e..e8bbdb3089cc 100644 --- a/docs/pallas/tpu/distributed.md +++ b/docs/pallas/tpu/distributed.md @@ -17,30 +17,30 @@ kernelspec: # Distributed Computing in Pallas for TPUs -In this tutorial, we will cover the basics of distributed computing in Pallas on TPUs. We will learn about TPU topologies, communication using the remote DMA primitive, and calling a distributed kernel from JAX using `shard_map`. We will also cover some more advanced kernel writing techniques, such as double-buffering, bi-directional bandwidth optimization, and nested pipelining. As educational examples, we will learn how to implement various collective primitives from JAX, such as `lax.ppermute`, `lax.all_gather`, `lax.psum`, and `lax.psum_scatter`. +In this tutorial, we will cover the basics of distributed computing in Pallas on TPUs. We will learn about TPU topologies, communication using the remote DMA primitive, and calling a distributed kernel from JAX using `jax.shard_map`. We will also cover some more advanced kernel writing techniques, such as double-buffering, bi-directional bandwidth optimization, and nested pipelining. As educational examples, we will learn how to implement various collective primitives from JAX, such as `lax.ppermute`, `lax.all_gather`, `lax.psum`, and `lax.psum_scatter`. Some recommended readings beforehand: - [Pallas Pipelining on TPU](pallas_tpu_pipelining) - - [Collectives with `shard_map`](shard_map_collectives_tutorial) + - [Collectives with `jax.shard_map`](shard_map_collectives_tutorial) ```{code-cell} ipython3 --- executionInfo: - elapsed: 1978 + elapsed: 52 status: ok - timestamp: 1722904801801 + timestamp: 1744390458993 user: displayName: Justin Fu userId: '17543197034567316452' user_tz: 420 id: PyAGnWc9yI8T -outputId: 1d8229bd-cab5-495f-93e9-fff2e41db480 +outputId: c5912653-c34b-4810-c373-4a2787691317 --- +import functools import jax from jax import lax from jax import numpy as jnp from jax.experimental import pallas as pl -from jax.experimental import shard_map from jax.experimental.pallas import tpu as pltpu P = jax.sharding.PartitionSpec @@ -61,7 +61,7 @@ TPUs pods are typically arranged in an ND torus topology. The following graphic ![tpu_topologies](https://cloud.google.com/static/tpu/docs/images/v4-topologies.png) -Flattened as a graph, the torus can be visualized as follows. Each edge (orange or black) is a bidirectional connection between two devices. You will commonly hear about rings in conjunction with discussion about device toplogies — a key feature of a torus is that when taking a slice along an axis of the pod, such as the nodes `[(0,1), (1, 1), (2, 1), (3, 1)]` or `[(0, 1), (1, 1)]`, we have a ring of devices. This is a feature we can use to simplify communication patterns within the pod. +Flattened as a graph, the torus can be visualized as follows. Each edge (orange or black) is a bidirectional connection between two devices. You will commonly hear about rings in conjunction with discussion about device topologies — a key feature of a torus is that when taking a slice along an axis of the pod, such as the nodes `[(0,1), (1, 1), (2, 1), (3, 1)]` or `[(0, 1), (1, 1)]`, we have a ring of devices. This is a feature we can use to simplify communication patterns within the pod. ![tpu_torus](https://cloud.google.com/static/tpu/docs/images/untwisted-tori.png) @@ -163,7 +163,7 @@ def example_kernel(input_ref, output_ref, send_sem, recv_sem): `send_sem` and `recv_sem` are instances of a special type of semaphore reserved exclusively for use with DMAs. They must be allocated with the `tpu.SemaphoreType.DMA` type when specifying input specs to `pallas_call`. -Internally, DMA semaphores can be thought of as integer-valued progress trackers. On DMA start, the local device will begin to increment the value of `send_sem` and the receiver's `recv_sem` asynchronously. Waiting on a semaphore will block until the value of the semaphore reaches the total bytes of data sent/received; when the value is reached, waiting threads are released and the sempahore's value is decremented by the same amount. This means that either all data has been sent (for `send_sem`) or all data has been received (for `dst_sem`). The value of the semaphore can be read with `pl.semaphore_read`, but note that the underlying semantics of the value could change between hardware generations (e.g. the value may not represent exactly the number of bytes sent, although this is a useful mental model to have when reasoning about the behavior of the semaphore). +Internally, DMA semaphores can be thought of as integer-valued progress trackers. On DMA start, the local device will begin to increment the value of `send_sem` and the receiver's `recv_sem` asynchronously. Waiting on a semaphore will block until the value of the semaphore reaches the total bytes of data sent/received; when the value is reached, waiting threads are released and the semaphore's value is decremented by the same amount. This means that either all data has been sent (for `send_sem`) or all data has been received (for `recv_sem`). The value of the semaphore can be read with `pl.semaphore_read`, but note that the underlying semantics of the value could change between hardware generations (e.g. the value may not represent exactly the number of bytes sent, although this is a useful mental model to have when reasoning about the behavior of the semaphore). ### Routing @@ -195,15 +195,15 @@ In order to call the kernel in distributed mode, we wrap the `pallas_call` in a ```{code-cell} ipython3 --- executionInfo: - elapsed: 1606 + elapsed: 152 status: ok - timestamp: 1722904803566 + timestamp: 1744390459367 user: displayName: Justin Fu userId: '17543197034567316452' user_tz: 420 id: YkyIKN2thZ-V -outputId: 9b7ed142-d161-4237-fed8-cbce41adc5f0 +outputId: 26719bb9-87ff-46dd-af90-a114ce332417 --- partition = P(None, 'x') mesh = jax.make_mesh((num_devices,), ('x',)) @@ -233,11 +233,11 @@ def right_permute_kernel(input_ref, output_ref, send_sem, recv_sem): out_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32) grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, - # TPUMemorySpace.ANY will (usually) place the tensor in HBM. + # MemorySpace.ANY will (usually) place the tensor in HBM. in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), scratch_shapes=( # We allocate DMA semaphores in scratch memory. [pltpu.SemaphoreType.DMA] * 2 @@ -250,12 +250,12 @@ right_permute = pl.pallas_call( ) # Wrap the kernel within a shard_map to call. pallas_result = jax.jit( - shard_map.shard_map( + jax.shard_map( right_permute, mesh=mesh, in_specs=partition, out_specs=partition, - check_rep=False, + check_vma=False, ) )(input_arr) @@ -263,7 +263,7 @@ pallas_result = jax.jit( perm = tuple((src, (src + 1) % num_devices) for src in range(num_devices)) xla_result = jax.jit( - shard_map.shard_map( + jax.shard_map( lambda x: lax.ppermute(x, 'x', perm), mesh=mesh, in_specs=partition, out_specs=partition) )(input_arr) @@ -296,15 +296,15 @@ We can re-purpose Pallas's `grid` argument to implement the loop. Rather than it ```{code-cell} ipython3 --- executionInfo: - elapsed: 812 + elapsed: 209 status: ok - timestamp: 1722904804531 + timestamp: 1744390459789 user: displayName: Justin Fu userId: '17543197034567316452' user_tz: 420 id: ojQEZB5mBRqM -outputId: e1648f54-737c-4921-ca3b-b4c639a38d2b +outputId: 3a4373f8-1fb5-4a6b-b88e-3461c2609021 --- partition = P('x', None) mesh = jax.make_mesh((num_devices,), ('x',)) @@ -356,10 +356,10 @@ out_shape = jax.ShapeDtypeStruct((num_devices, 8, 128), jnp.float32) grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ - # TPUMemorySpace.ANY will (usually) place the tensor in HBM. - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + # MemorySpace.ANY will (usually) place the tensor in HBM. + pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), scratch_shapes=( # DMA semaphores are allocated in scratch memory. # We allocated one semaphore for a local HBM-VMEM copy, @@ -383,18 +383,18 @@ all_gather = pl.pallas_call( # Wrap the kernel within a shard_map to call. pallas_result = jax.jit( - shard_map.shard_map( + jax.shard_map( all_gather, mesh=mesh, in_specs=partition, out_specs=partition, - check_rep=False + check_vma=False ) )(input_arr) # Compare Pallas result to XLA shard_map result. xla_result = jax.jit( - shard_map.shard_map( + jax.shard_map( lambda x: lax.all_gather(x, 'x'), mesh=mesh, in_specs=partition, out_specs=partition ) @@ -409,9 +409,9 @@ print('Difference |Pallas - lax.all_gather| = ', +++ {"id": "KgU7HI2pS4om"} -A detail worth mentioning here is the use of multiple receive semaphores. Because we only block on the receiving device, it is still possible for a sender to have sent multiple DMAs in flight before the receiver has finished processing the first one (see the next section and reduce-sum example which discusses race conditions in more detail). In this situation we may hit a situation where the same semaphore is being used for multiple DMAs occurring simultaneously. To avoid this, we allocate `num_devices-1` semaphores so there is no risk of re-use. While this race condition is unlikely to happen on such a small kernel, on larger kernels there is more chance for devices to fall out of sync and potentially cause a silent failure. +A detail worth mentioning here is the use of multiple receive semaphores. Because we only block on the receiving device, it is still possible for a sender to have sent multiple DMAs in flight before the receiver has finished processing the first one (see the next section and reduce-sum example which discusses race conditions in more detail). In this situation we may hit a situation where the same semaphore is being used for multiple DMAs occurring simultaneously. To avoid this, we allocate `num_devices-1` semaphores so there is no risk of reuse. While this race condition is unlikely to happen on such a small kernel, on larger kernels there is more chance for devices to fall out of sync and potentially cause a silent failure. -+++ {"id": "KgU7HI2pS4om"} ++++ {"id": "EDCmAaHVtY7x"} ## Advanced Techniques @@ -451,9 +451,9 @@ def semaphore_read( In order to use regular semaphores, they can be allocated in the same way as a DMA semaphore, but by specifying `pltpu.SemaphoreType.REGULAR` rather than `pltpu.SemaphoreType.DMA`. -Semaphores must be zero at the end of a Pallas program to complete succesfully. There are two error cases where this may happen: +Semaphores must be zero at the end of a Pallas program to complete successfully. There are two error cases where this may happen: - If a semaphore is over-signaled, the program will end with non-zero (>0) semaphores. In this case, the program will crash upon completion. This is useful for debugging as non-zero semaphores typically means there is a bug somewhere inside of the program. - - If a semaphore is over-waited, the program will hang on the blocking `semaphore_wait` call while it waits for the sempahore to be incremented. In this case the device or program will need to be restarted. + - If a semaphore is over-waited, the program will hang on the blocking `semaphore_wait` call while it waits for the semaphore to be incremented. In this case the device or program will need to be restarted. #### Barrier Semaphores @@ -491,7 +491,7 @@ When using barrier semaphores, the `collective_id` compiler parameter must be pa kernel = pl.pallas_call( example_kernel, ..., - compiler_params=pltpu.TPUCompilerParams(collective_id=0), + compiler_params=pltpu.CompilerParams(collective_id=0), ) ``` @@ -556,22 +556,22 @@ The prologue (executed when `outer_step==0`) first initiates a barrier with both The main body assumes that a value has already been copied into our local working slot, either from the previous iteration or from the prologue. A complicating factor is that our destination buffers live in HBM, but we need to load values to VMEM before we perform arithmetic. Therefore, we simultaneously copy the working slot value into our VMEM (`receive_scratch`) and pass the value on to our right neighbor's receiving slot. Once the value has been copied into our VMEM, we can accumulate it into our result (contained in `o_ref`). -A subtle race condition can occur if one device runs one loop ahead of it's right neighbor. In this case, it could copy into the receiver's `working_slot` at the same time the receiver is reading from it. In order to avoid this, each device will block on a `REGULAR` semaphore before copying into the right neighbor's `dst_ref` until it has signaled that it is done reading from its `working_slot`. This race condition is rarely triggered for a small kernel such as this example, but can it can be explicitly triggered if for example using a `pltpu.delay` instruction to artifically hang a device. +A subtle race condition can occur if one device runs one loop ahead of it's right neighbor. In this case, it could copy into the receiver's `working_slot` at the same time the receiver is reading from it. In order to avoid this, each device will block on a `REGULAR` semaphore before copying into the right neighbor's `dst_ref` until it has signaled that it is done reading from its `working_slot`. This race condition is rarely triggered for a small kernel such as this example, but can it can be explicitly triggered if for example using a `pl.delay` instruction to artificially hang a device. Note that this is not an optimal or fully general kernel, as the block sizes must entirely fit in VMEM and we could better interleave communication and accumulation. We will discuss these optimizations in later sections. ```{code-cell} ipython3 --- executionInfo: - elapsed: 254 + elapsed: 248 status: ok - timestamp: 1722904804952 + timestamp: 1744390460289 user: displayName: Justin Fu userId: '17543197034567316452' user_tz: 420 id: XrY5bMlvBroQ -outputId: 77497000-4496-462e-cc3c-73fb640cc14c +outputId: 9216e749-48d2-43ff-d64b-bd419acf3e11 --- partition = P(None, 'x') mesh = jax.make_mesh((num_devices,), ('x',)) @@ -581,6 +581,41 @@ input_arr = jax.random.uniform(jax.random.key(0), shape=(8, 128 * num_devices)) input_arr = jax.device_put(input_arr, sharding) +def local_barrier(left_neighbor, right_neighbor, double_barrier=True): + """Performs a barrier with neighbors on the global barrier semaphore. + + Optionally performs a second barrier, which prevents a potential race + when reusing the same collective_id across kernel invocations. + """ + barrier_sem = pltpu.get_barrier_semaphore() + for neighbor in [left_neighbor, right_neighbor]: + pltpu.semaphore_signal( + barrier_sem, + inc=1, + device_id=(neighbor,), + device_id_type=pltpu.DeviceIdType.MESH, + ) + pltpu.semaphore_wait(barrier_sem, 2) + if double_barrier: + # The double-barrier prevents a race condition where one neighbor can + # re-enter the kernel again on a subsequent call and increment the + # barrier semaphore a second time. This would unblock the current device + # even if the other neighbor is not ready yet. + # To implement a double-barrier, we stack-allocate a second REGULAR + # semaphore using run_scoped. + @functools.partial(pl.run_scoped, + second_barrier=pltpu.SemaphoreType.REGULAR) + def _(second_barrier): + for neighbor in [left_neighbor, right_neighbor]: + pltpu.semaphore_signal( + second_barrier, + inc=1, + device_id=(neighbor,), + device_id_type=pltpu.DeviceIdType.MESH, + ) + pltpu.semaphore_wait(second_barrier, 2) + + def all_reduce_kernel( x_ref, o_ref, @@ -603,20 +638,7 @@ def all_reduce_kernel( def _(): # Barrier with both neighbors at the start, since we will be # communicating with both. - barrier_sem = pltpu.get_barrier_semaphore() - pltpu.semaphore_signal( - barrier_sem, - inc=1, - device_id=(left_neighbor,), - device_id_type=pltpu.DeviceIdType.MESH, - ) - pltpu.semaphore_signal( - barrier_sem, - inc=1, - device_id=(right_neighbor,), - device_id_type=pltpu.DeviceIdType.MESH, - ) - pltpu.semaphore_wait(barrier_sem, 2) + local_barrier(left_neighbor, right_neighbor) # Initialize o_ref, acc_scratch, and hbm_scratch. o_ref[...] = jnp.zeros_like(o_ref) @@ -681,13 +703,13 @@ grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ # Our input lives in VMEM - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.VMEM), ], out_specs=[ # Our output lives in VMEM - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.VMEM), # Our double-buffer lives in HBM - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], grid=(num_devices,), scratch_shapes=( @@ -701,16 +723,16 @@ kernel = pl.pallas_call( all_reduce_kernel, out_shape=out_shape, grid_spec=grid_spec, - compiler_params=pltpu.TPUCompilerParams(collective_id=0), + compiler_params=pltpu.CompilerParams(collective_id=0), ) pallas_result = jax.jit( - shard_map.shard_map( + jax.shard_map( kernel, mesh=mesh, in_specs=partition, out_specs=partition, - check_rep=False, + check_vma=False, ) )(input_arr) pallas_result = jax.block_until_ready(pallas_result)[0] @@ -721,7 +743,7 @@ def lax_sum(x): xla_result = jax.jit( - shard_map.shard_map( + jax.shard_map( lax_sum, mesh=mesh, in_specs=P(None, 'x'), out_specs=P(None, 'x') ) )(input_arr) @@ -772,9 +794,9 @@ In terms of construction of the kernel, we introduce an additional `phase` dimen ```{code-cell} ipython3 --- executionInfo: - elapsed: 544 + elapsed: 362 status: ok - timestamp: 1722904805699 + timestamp: 1744390460871 user: displayName: Justin Fu userId: '17543197034567316452' @@ -890,20 +912,7 @@ def reduce_scatter_kernel( def _(): # Barrier with both neighbors at the start, since we will be # communicating with both. - barrier_sem = pltpu.get_barrier_semaphore() - pltpu.semaphore_signal( - barrier_sem, - inc=1, - device_id=(left_neighbor,), - device_id_type=pltpu.DeviceIdType.MESH, - ) - pltpu.semaphore_signal( - barrier_sem, - inc=1, - device_id=(right_neighbor,), - device_id_type=pltpu.DeviceIdType.MESH, - ) - pltpu.semaphore_wait(barrier_sem, 2) + local_barrier(left_neighbor, right_neighbor) # Initialize o_ref, acc_scratch, and hbm_scratch with initial copies. o_ref[...] = jnp.zeros_like(o_ref[...]) @@ -1010,11 +1019,11 @@ out_shape = ( grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.VMEM), ], out_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.VMEM), + pl.BlockSpec(memory_space=pl.ANY), ], grid=(num_devices, 2), scratch_shapes=( @@ -1033,17 +1042,17 @@ def pallas_reduce_scatter(input_arr): reduce_scatter_kernel, out_shape=out_shape, grid_spec=grid_spec, - compiler_params=pltpu.TPUCompilerParams(collective_id=0), + compiler_params=pltpu.CompilerParams(collective_id=0), )(input_arr)[0] pallas_result = jax.jit( - shard_map.shard_map( + jax.shard_map( pallas_reduce_scatter, mesh=mesh, in_specs=P(None, 'x'), out_specs=P('x', None), - check_rep=False, + check_vma=False, ) )(input_arr) @@ -1053,15 +1062,15 @@ pallas_result = jax.block_until_ready(pallas_result) ```{code-cell} ipython3 --- executionInfo: - elapsed: 596 + elapsed: 917 status: ok - timestamp: 1722904806442 + timestamp: 1744390461967 user: displayName: Justin Fu userId: '17543197034567316452' user_tz: 420 id: E-NMh-_teoi4 -outputId: 24beb42f-1bdd-4c34-e8d2-681dd7f2e9c0 +outputId: 6c8b82bc-ed64-4cc1-8c5f-65e29cdb333c --- # Compare our result to XLA. def lax_reduce_sum_scatter(x): @@ -1070,7 +1079,7 @@ def lax_reduce_sum_scatter(x): xla_result = jax.jit( - shard_map.shard_map( + jax.shard_map( lax_reduce_sum_scatter, mesh=mesh, in_specs=P(None, 'x'), @@ -1093,21 +1102,7 @@ print( A limitation of the previous all-reduce and reduce-scatter kernels that we wrote is that the blocks we copy via remote DMA must be small enough to fit in our working VMEM that we use for accumulation. For some kernels it may be advantageous to use larger block sizes to better utilize the TPU. For example, a matrix multiplication requires on the order of $O(N^3)$ compute operations, but only $O(N^2)$ memory transfers. Therefore, we want each block of work transferred between devices to be large enough such that the operation becomes compute bound and we can hide the communication cost using pipelining. For reference, the VMEM of a TPU (for generations v4/v5) is typically on the order of 10-100MB, whereas HBM ranges from 10-100GB. -To address this problem, we need to be able to write an "inner kernel" that handles local HBM-VMEM pipelining inside of the "outer kernel" that handles pipelining larger HBM-HBM transfers between devices. Pallas offers an API for constructing nested pipelines using the `emit_pipeline` function. The basic call signature for `emit_pipeline` follows that of a standard `pallas_call` by specifying a `grid` and `BlockSpec`s for the inputs and outputs: - -```python -def emit_pipeline( - kernel: Callable, - grid: tuple[int], - in_specs: PyTree[BlockSpec] = None, - out_specs: PyTree[BlockSpec] = None, - should_accumulate_out: bool = False, - dimension_semantics: tuple[GridDimensionSemantics] = None, -) -> Callable: - ... # Returns a custom pipeline given an inner kernel and BlockSpecs. -``` - -Indeed, one can view `pallas_call` itself as simply a wrapper around `emit_pipeline`. Because our outer kernel only involves remote HBM-HBM transfers, we are not using any of the built-in pipelining that `pallas_call` provides for HBM-VMEM transfers. The following code skeleton demonstrates what a typical program structure would look like using this pattern: +To address this problem, we need to be able to write an "inner kernel" that handles local HBM-VMEM pipelining inside of the "outer kernel" that handles pipelining larger HBM-HBM transfers between devices. Pallas offers an API for constructing nested pipelines using the `emit_pipeline` function. See the [TPU pipelining](pallas_tpu_emit_pipeline) guide for a general overview on `emit_pipeline`. Because our outer kernel only involves remote HBM-HBM transfers, we are not using any of the built-in pipelining that `pallas_call` provides for HBM-VMEM transfers. The following code skeleton demonstrates what a typical program structure would look like using this pattern: ```python @@ -1139,7 +1134,7 @@ pl.pallas_call( In this next example we will modify our previous reduce-scatter example to utilize a nested inner pipeline. Note that the communication and computation costs of `reduce_scatter` both scale linearly with the size of the input, so we do not necessarily expect to see the operation become compute-bound with larger block sizes. This example is purely for demonstration purposes on how to use the pipeline emitter. -We will increase the block sizes of the outer kernel such that they would be undesirable to place inside of VMEM, and allocate all inputs and outputs in HBM (`memory_space=TPUMemorySpace.Any`). The only major change from our previous kernel is the body of the kernel where accumulation is done. Rather than manually copying from HBM to VMEM, accumulating, and copying back to HBM, we use `emit_pipeline` to handle the memory transfers for us. Accumulation is done in an inner kernel with a much smaller, VMEM-friendly block size. +We will increase the block sizes of the outer kernel such that they would be undesirable to place inside of VMEM, and allocate all inputs and outputs in HBM (`memory_space=MemorySpace.ANY`). The only major change from our previous kernel is the body of the kernel where accumulation is done. Rather than manually copying from HBM to VMEM, accumulating, and copying back to HBM, we use `emit_pipeline` to handle the memory transfers for us. Accumulation is done in an inner kernel with a much smaller, VMEM-friendly block size. In our previous kernel we had the following kernel body to copy data from HBM to the VMEM accumulator, increment, and then copy the results back to HBM: @@ -1197,9 +1192,9 @@ The full kernel is as follows: ```{code-cell} ipython3 --- executionInfo: - elapsed: 1341 + elapsed: 997 status: ok - timestamp: 1722904807930 + timestamp: 1744390463178 user: displayName: Justin Fu userId: '17543197034567316452' @@ -1233,7 +1228,7 @@ inner_grid = ( inner_block_spec = pl.BlockSpec( index_map=lambda i, j: (i, j), block_shape=inner_block_size, - memory_space=pltpu.TPUMemorySpace.ANY, + memory_space=pltpu.TPUMemorySpace.VMEM, ) @@ -1308,20 +1303,7 @@ def reduce_scatter_kernel( def _(): # Barrier with both neighbors at the start, since we will be # communicating with both. - barrier_sem = pltpu.get_barrier_semaphore() - pltpu.semaphore_signal( - barrier_sem, - inc=1, - device_id=(left_neighbor,), - device_id_type=pltpu.DeviceIdType.MESH, - ) - pltpu.semaphore_signal( - barrier_sem, - inc=1, - device_id=(right_neighbor,), - device_id_type=pltpu.DeviceIdType.MESH, - ) - pltpu.semaphore_wait(barrier_sem, 2) + local_barrier(left_neighbor, right_neighbor) initial_left_copy.start() initial_left_copy.wait() @@ -1428,11 +1410,11 @@ out_shape = ( grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], out_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], grid=(num_devices, 2), scratch_shapes=( @@ -1450,17 +1432,17 @@ def pallas_reduce_scatter(input_arr): reduce_scatter_kernel, out_shape=out_shape, grid_spec=grid_spec, - compiler_params=pltpu.TPUCompilerParams(collective_id=0), + compiler_params=pltpu.CompilerParams(collective_id=0), )(input_arr)[0] pallas_result = jax.jit( - shard_map.shard_map( + jax.shard_map( pallas_reduce_scatter, mesh=mesh, in_specs=P(None, 'x'), out_specs=P('x', None), - check_rep=False, + check_vma=False, ) )(input_arr) @@ -1470,15 +1452,15 @@ pallas_result = jax.block_until_ready(pallas_result) ```{code-cell} ipython3 --- executionInfo: - elapsed: 768 + elapsed: 1132 status: ok - timestamp: 1722904808851 + timestamp: 1744390464532 user: displayName: Justin Fu userId: '17543197034567316452' user_tz: 420 id: cTEyiMDyx9Y0 -outputId: 1de26695-3713-430e-9ab4-4ea646691680 +outputId: 70ce154e-dab2-4ae0-e297-c4774d29da85 --- # Now we compare our result to XLA. def lax_reduce_sum_scatter(x): @@ -1487,7 +1469,7 @@ def lax_reduce_sum_scatter(x): xla_result = jax.jit( - shard_map.shard_map( + jax.shard_map( lax_reduce_sum_scatter, mesh=mesh, in_specs=P(None, 'x'), @@ -1518,4 +1500,4 @@ In this tutorial we covered several kernel examples which replicate the function ### Next Steps -Excellent follow-up excercises for the reader could include implementing a distributed matrix multiplication, implementing `lax.all_to_all`, and relaxing synchronization to allow for additional run-ahead. +Excellent follow-up exercises for the reader could include implementing a distributed matrix multiplication, implementing `lax.all_to_all`, and relaxing synchronization to allow for additional run-ahead. diff --git a/docs/pallas/tpu/index.rst b/docs/pallas/tpu/index.rst index 20abad5f610e..784e037bca06 100644 --- a/docs/pallas/tpu/index.rst +++ b/docs/pallas/tpu/index.rst @@ -11,4 +11,6 @@ TPU specific documentation. matmul sparse distributed + core_map + prng diff --git a/docs/pallas/tpu/matmul.ipynb b/docs/pallas/tpu/matmul.ipynb index 9c90add16ab0..dbe9747c4884 100644 --- a/docs/pallas/tpu/matmul.ipynb +++ b/docs/pallas/tpu/matmul.ipynb @@ -210,7 +210,7 @@ " pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))],\n", " out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),\n", " grid=(m // bm, n // bn, k // bk),\n", - " compiler_params=pltpu.TPUCompilerParams(\n", + " compiler_params=pltpu.CompilerParams(\n", " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\")),\n", " )(x, y)" ] @@ -466,7 +466,7 @@ " grid=(m // bm, n // bn, k // bk),\n", " ),\n", " out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),\n", - " compiler_params=pltpu.TPUCompilerParams(\n", + " compiler_params=pltpu.CompilerParams(\n", " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\")),\n", " )(x, y)" ] @@ -496,7 +496,14 @@ "\n", "Our above analysis about FLOPs vs memory usage applies at a coarse scale i.e. when we are looking at the the size of a the total matrix multiplication. However, remember that in practice, we are pipelining the execution of a blocked matrix multiplication, meaning we have a loop in which we are doing matrix multiplication with smaller blocks.\n", "\n", - "This means that we actually care about the FLOPs vs memory bandwidth usage of each individual instance of the kernel, not the global FLOPs vs memory bandwidth usage. Therefore, the block sizes `bm`, `bk`, `bn` are extremely important for performance. Even if we have the largest matrices in the world, if we pick very small `bm`, `bk`, and `bn`, we will be memory bound because each time we invoke the kernel we will have too few FLOPs to hide the memory transfers happening in the background.\n", + "This means that we actually care about the FLOPs vs memory bandwidth usage of each individual instance of the kernel, not the global FLOPs vs memory bandwidth usage.\n", + "\n", + "In addition, when tiling the matmul operation, the same values could be read multiple times from memory.\n", + "Specifically the memory bandwidth for the first operand of the kernel is `(bm * bk)`, which needs to be multiplied by the grid dimensions, that is `(bm * bk) * m // bm * n // bn * k // bk = m * k * n // bn`.\n", + "Similarly for the second operand, yielding a total bandwidth usage `(m * k * n // bn + k * n * m // bm + m * n) * element_size`.\n", + "\n", + "Therefore, the block sizes `bm`, `bk`, `bn` are extremely important for performance.\n", + " Even if we have the largest matrices in the world, if we pick very small `bm`, `bk`, and `bn`, we will be memory bound because each time we invoke the kernel we will have too few FLOPs to hide the memory transfers happening in the background.\n", "\n", "The intuition should therefore be: to be compute bound, make the blocks as big as possible! There are two main constraints:\n", "\n", @@ -741,7 +748,7 @@ " grid=(m // bm, n // bn, k // bk),\n", " ),\n", " out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),\n", - " compiler_params=pltpu.TPUCompilerParams(\n", + " compiler_params=pltpu.CompilerParams(\n", " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\")),\n", " )(x, y)" ] @@ -929,7 +936,7 @@ " grid=(m // bm, n // bn, k // bk),\n", " ),\n", " out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),\n", - " compiler_params=pltpu.TPUCompilerParams(\n", + " compiler_params=pltpu.CompilerParams(\n", " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\")),\n", " )(x, y)" ] diff --git a/docs/pallas/tpu/matmul.md b/docs/pallas/tpu/matmul.md index 42084f12d5f5..509d47093af7 100644 --- a/docs/pallas/tpu/matmul.md +++ b/docs/pallas/tpu/matmul.md @@ -167,7 +167,7 @@ def matmul( pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))], out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)), grid=(m // bm, n // bn, k // bk), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=("parallel", "parallel", "arbitrary")), )(x, y) ``` @@ -321,7 +321,7 @@ def matmul( grid=(m // bm, n // bn, k // bk), ), out_shape=jax.ShapeDtypeStruct((m, n), x.dtype), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=("parallel", "parallel", "arbitrary")), )(x, y) ``` @@ -342,7 +342,14 @@ np.testing.assert_array_equal(x @ y, matmul(x, y)) Our above analysis about FLOPs vs memory usage applies at a coarse scale i.e. when we are looking at the the size of a the total matrix multiplication. However, remember that in practice, we are pipelining the execution of a blocked matrix multiplication, meaning we have a loop in which we are doing matrix multiplication with smaller blocks. -This means that we actually care about the FLOPs vs memory bandwidth usage of each individual instance of the kernel, not the global FLOPs vs memory bandwidth usage. Therefore, the block sizes `bm`, `bk`, `bn` are extremely important for performance. Even if we have the largest matrices in the world, if we pick very small `bm`, `bk`, and `bn`, we will be memory bound because each time we invoke the kernel we will have too few FLOPs to hide the memory transfers happening in the background. +This means that we actually care about the FLOPs vs memory bandwidth usage of each individual instance of the kernel, not the global FLOPs vs memory bandwidth usage. + +In addition, when tiling the matmul operation, the same values could be read multiple times from memory. +Specifically the memory bandwidth for the first operand of the kernel is `(bm * bk)`, which needs to be multiplied by the grid dimensions, that is `(bm * bk) * m // bm * n // bn * k // bk = m * k * n // bn`. +Similarly for the second operand, yielding a total bandwidth usage `(m * k * n // bn + k * n * m // bm + m * n) * element_size`. + +Therefore, the block sizes `bm`, `bk`, `bn` are extremely important for performance. + Even if we have the largest matrices in the world, if we pick very small `bm`, `bk`, and `bn`, we will be memory bound because each time we invoke the kernel we will have too few FLOPs to hide the memory transfers happening in the background. The intuition should therefore be: to be compute bound, make the blocks as big as possible! There are two main constraints: @@ -489,7 +496,7 @@ def matmul( grid=(m // bm, n // bn, k // bk), ), out_shape=jax.ShapeDtypeStruct((m, n), x.dtype), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=("parallel", "parallel", "arbitrary")), )(x, y) ``` @@ -613,7 +620,7 @@ def matmul( grid=(m // bm, n // bn, k // bk), ), out_shape=jax.ShapeDtypeStruct((m, n), x.dtype), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=("parallel", "parallel", "arbitrary")), )(x, y) ``` diff --git a/docs/pallas/tpu/pipelining.ipynb b/docs/pallas/tpu/pipelining.ipynb index 10de587105f2..12a4b852e84a 100644 --- a/docs/pallas/tpu/pipelining.ipynb +++ b/docs/pallas/tpu/pipelining.ipynb @@ -2,8 +2,9 @@ "cells": [ { "cell_type": "markdown", - "id": "7704d3bb", - "metadata": {}, + "metadata": { + "id": "7704d3bb" + }, "source": [ "(pallas_tpu_pipelining)=" ] @@ -14,7 +15,7 @@ "id": "teoJ_fUwlu0l" }, "source": [ - "# Pipelining\n", + "# TPU Pipelining\n", "\n", "" ] @@ -25,14 +26,24 @@ "id": "gAJDZh1gBh-h" }, "source": [ - "In this guide we'll cover how memory spaces in TPU work and how to write\n", - "pipelines in Pallas that overlap memory I/O with compute." + "This guide serves as a reference for TPU-specific pipelining concerns.\n", + "We'll review the memory hierarchy and compute units on TPUs, and TPU-specific features of the pipelining API. For a more general-purpose overview of pipelining, see the {ref}`pallas_software_pipelining`." ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": { + "executionInfo": { + "elapsed": 54, + "status": "ok", + "timestamp": 1744908474512, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, "id": "ejAVO6ikUUuF" }, "outputs": [], @@ -48,9 +59,8 @@ }, { "cell_type": "markdown", - "id": "0e212a5e", "metadata": { - "id": "TWKESTKAlyjT" + "id": "0e212a5e" }, "source": [ "(tpu_and_its_memory_spaces)=\n", @@ -60,7 +70,9 @@ }, { "cell_type": "markdown", - "metadata": {}, + "metadata": { + "id": "NnWW9GV4kW6P" + }, "source": [ "A TPU and its TensorCore consist of memory spaces (where arrays can reside),\n", "registers (which temporarily store scalar and array values) and compute units\n", @@ -83,568 +95,203 @@ " Values can be loaded into memory from their respective caches (VMEM for\n", " VREGs and SMEM for SREGs).\n", "* **Compute units**: A TensorCore has a scalar unit, vector unit (VPU) and\n", - " matrix unit (MXU) that can do numerical computation.\n", + " matrix unit (MXU) that can do numerical computation. Each of these compute units can operate asynchronously, but this is managed by the TPU compiler and thus from the programmer's perspective a TPU program is single-threaded.\n", " Compute units operate on values that live in SREGs and VREGs and output\n", - " values into those registers as well.\n", - "\n", - "In order to do a vectorized computation on our values `x` and `y` that live\n", - "in HBM, we need to:\n", - "\n", - "1. Copy the values `x` and `y` into VMEM.\n", - "2. Load the values from VMEM into VREGs.\n", - "3. Execute the computation using the VPU or MXU, storing the output in VREGs.\n", - "4. Store the values in the output VREGs into VMEM.\n", - "5. Copy the output values in VMEM back to HBM." + " values into those registers as well." ] }, { "cell_type": "markdown", "metadata": { - "id": "TzctMbNsn3vc" + "id": "8Tl3wt5Wk3Ek" }, "source": [ - "Let's implement a Pallas function that does just that!" + "## TPU-specific Pipelining Features\n", + "\n", + "Pallas TPU supports the following platform-specific features." ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": { - "id": "2IXQxNWrKJyb", - "outputId": "d62eb493-5f92-4496-f113-d3cd24cb0b9f" + "id": "1jg5WmExk47l" }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([[2., 2., 2., ..., 2., 2., 2.],\n", - " [2., 2., 2., ..., 2., 2., 2.],\n", - " [2., 2., 2., ..., 2., 2., 2.],\n", - " ...,\n", - " [2., 2., 2., ..., 2., 2., 2.],\n", - " [2., 2., 2., ..., 2., 2., 2.],\n", - " [2., 2., 2., ..., 2., 2., 2.]], dtype=float32)" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], "source": [ - "def add_matrices_kernel(x_vmem_ref, y_vmem_ref, z_vmem_ref):\n", - " # Load x and y from VMEM into VREGs\n", - " x_vregs = x_vmem_ref[:, :]\n", - " y_vregs = y_vmem_ref[:, :]\n", - " # Execute a vectorized add\n", - " z_vregs = x_vregs + y_vregs\n", - " # Store the output values in VREGs back into VMEM\n", - " z_vmem_ref[:, :] = z_vregs\n", + "### TPU Memory Spaces\n", "\n", + "Pallas exposes all levels of the TPU memory hierarchy to users. The following table maps from Pallas TPU memory spaces to their standard memory types (DRAM/SRAM):\n", "\n", - "def add_matrices(x: jax.Array, y: jax.Array) -> jax.Array:\n", - " # pallas_call will first allocate scratch buffers for `x` and `y` in VMEM.\n", - " # It will then copy `x` and `y` from HBM into VMEM.\n", - " z = pl.pallas_call(\n", - " add_matrices_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)\n", - " )(x, y)\n", - " # pallas_call will also copy the output from VMEM back into HBM.\n", - " return z\n", + "| Pallas Enum | TPU Memory Space | Type (DRAM/SRAM) |\n", + "| --- | --- | --- |\n", + "| `pl.ANY` | HBM (usually) or VMEM | DRAM |\n", + "| `pltpu.VMEM` | VMEM | SRAM |\n", + "| `pltpu.SMEM` | SMEM | SRAM |\n", + "| `pltpu.SEMAPHORE` | Semaphore | SRAM |\n", "\n", + "- `MemorySpace.VMEM` denotes vector SRAM. It is the default memory space if nothing is specified.\n", + "- `MemorySpace.SMEM` denotes scalar SRAM. Only scalar loads and stores can be performed to/from SMEM.\n", + "- `MemorySpace.ANY` is a hint to the compiler that the memory space is unconstrained. In most cases, XLA will place this buffer in HBM. A buffer assigned to the `ANY` memory space cannot be dereferenced normally using array indexing syntax (e.g. `x[...]`). Instead, we must first copy the values into a VMEM or SMEM buffer using `pltpu.sync_copy` or `pltpu.async_copy`.\n", + "- `MemorySpace.SEMAPHORE` is used to allocate semaphores for constructing barriers or tracking asynchronous operations. It is also possible to return semaphores from the kernel for building asynchronous kernels - this is an experimental feature; see {ref}`pallas_async` for more details.\n", "\n", - "x, y = jnp.ones((512, 512)), jnp.ones((512, 512))\n", - "add_matrices(x, y)" + "Pipelining on TPUs is typically done between HBM (DRAM) to VMEM (Vector SRAM). The default behavior for `pallas_call` on TPU is that arguments to `pallas_call` are assumed to live in HBM, and inputs to the user kernel body are stored in VMEM.\n", + "\n", + "While not specific to pipelining, it is possible to gain manual control over the memory space of input and output buffers, you can specify the `memory_space` argument on a `BlockSpec`. Note that pipelining is not allowed unless the `memory_space` is marked as `VMEM`. Memory spaces can also be used to specify scratch arguments to a kernel via the `scratch_shapes` argument on `pallas_call`. Scratch buffers are persistent across kernel iterations and are useful for storing intermediate results such as partial accumulations and reductions. A scratch buffer must reside in `VMEM`, `SMEM`, or `SEMAPHORE`.\n", + "\n", + "As an example for using multiple manual memory space assignments in a kernel, the following program copies a slice of an HBM buffer `x_hbm_ref` into a scratch VMEM buffer `scratch_vmem_ref` before using it for arithmetic and storing the result into an output VMEM buffer:" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": 8, "metadata": { - "id": "HMENNLy8okCL" + "executionInfo": { + "elapsed": 65, + "status": "ok", + "timestamp": 1744908591430, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "zcqz1CA_o50a" }, + "outputs": [], "source": [ - "We've written two functions: `add_matrices_kernel` and `add_matrices`.\n", - "\n", - "`add_matrices_kernel` operates using `Ref`s that live in VMEM.\n", - "Loading from a VMEM `Ref` produces a value that lives in VREGs.\n", - "Values in VREGs behave like `jax.Array`s in that we can use `jnp` and\n", - "`jax.lax` operations on them to produce new values that live in VREGs.\n", - "When we produce the values we'd like to return, we store them in the output\n", - "VMEM `Ref`.\n", - "\n", - "The `add_matrices` function acts on `jax.Array`s and returns a `jax.Array`.\n", - "Inside it, we pass `x` and `y` into `pallas_call`.\n", - "`pallas_call` is responsible for copying `x` and `y` into VMEM and for\n", - "allocating the VMEM buffers that the kernel operates on (including allocating\n", - "`z_vmem_ref`, the output VMEM buffer).\n", - "After the kernel function is finished running, `pallas_call` will also copy\n", - "the value in `z_vmem_ref` to HBM, resulting in an output `jax.Array`." + "def hbm_vmem_kernel(x_hbm_ref, out_vmem_ref, scratch_vmem_ref):\n", + " pltpu.sync_copy(x_hbm_ref.at[0:1], scratch_vmem_ref)\n", + " out_vmem_ref[...] = scratch_vmem_ref[...] + 1\n", + "\n", + "x = jax.random.uniform(jax.random.key(0), (8, 128), jnp.float32)\n", + "out = pl.pallas_call(hbm_vmem_kernel,\n", + " in_specs=[pl.BlockSpec(memory_space=pl.ANY)],\n", + " out_shape=jax.ShapeDtypeStruct((1, 128), jnp.float32),\n", + " scratch_shapes=(pltpu.VMEM(shape=(1, 128), dtype=jnp.float32),)\n", + ")(x)\n", + "\n", + "np.testing.assert_allclose(out, x[0:1] + 1)" ] }, { "cell_type": "markdown", - "metadata": { - "id": "5kWr-1tKpYro" - }, + "metadata": {}, "source": [ - "## Constraints of using VMEM/SMEM\n", - "\n", - "Pallas exposes access to lower level memory spaces like VMEM and SMEM but\n", - "writing kernels utilizing them adds some considerations.\n", + "### Multiple Buffering\n", "\n", - "1. Memory capacity. VMEM and SMEM are *small*! VMEM on v4 TPUs is only 16MiB\n", - " and SMEM ranges in the tens to hundreds of KiB.\n", - " If our arrays are too big, we won't even be able to fit them into VMEM at all.\n", - " For reference, a `f32[2048, 2048]` array is 16MiB, so our above kernel won't\n", - " scale beyond moderately sized arrays.\n", + "Multiple buffering can be specified on a per-argument basis to the pipeline via the `pipeline_mode` option on `pl.BlockSpec`. To do so, pass a `pl.Buffered` object to `pl.BlockSpec` specifying the number of buffers to allocate for this particular argument:\n", "\n", - "2. Memory bandwidth. Copying to/from HBM and VMEM takes a long time, at least\n", - " compared to most compute instructions.\n", - " The `add_matrices` function above will likely spend more time copying\n", - " between HBM and VMEM than actually performing the addition itself.\n", + "```python\n", + "pl.BlockSpec(\n", + " pipeline_mode=pl.Buffered(buffer_count=buffer_count)\n", + ")\n", + "```\n", "\n", - "With these two constraints in mind, we'll have to rethink our strategy for\n", - "getting performance out of our TPUs." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_NTqvlbetB3P" - }, - "source": [ - "## Primer: Pipelining\n", - "\n", - "Pipelining our computation offers a way of dealing with both the memory\n", - "capacity and bandwidth constraints in one fell swoop.\n", - "What do we mean by pipelining?\n", - "\n", - "The goal is: *in parallel* copy to/from HBM and VMEM *while* utilizing our\n", - "compute units.\n", - "Naively this is difficult because in our program above we copy *all* of `x`\n", - "and `y` before we start doing any compute with them, creating a dependence\n", - "between the copy and the compute.\n", - "\n", - "However, if we can chunk up our computation into several subcomputations\n", - "(e.g. when we add two matrices, we can express that as addition of \"blocks\"\n", - "of the original matrices together), we can now overlap the copies of one of\n", - "those subcomputations with the compute of the other. Let's walk through a\n", - "simple example:\n", - "\n", - "Let's say we split our arrays `x` and `y` into `x1, x2` and `y1, y2` (for\n", - "example, split along the leading axis, resulting in two `(256, 512)` arrays\n", - "for each input.\n", - "We can now execute the following pipelined computation.\n", - "\n", - "1. Copy `x1` and `y1` into VMEM.\n", - "1. Start copying `x2` and `y2` into VMEM\n", - "2. Load `x1, y1` from VMEM into VREGs.\n", - "3. Execute the `z1 = x1 + y1` using the compute units.\n", - "4. Store `z1` into VMEM.\n", - "5. Start copying `z1` from VMEM back into HBM.\n", - "6. Wait until `x2, y2` have been copied into VMEM.\n", - "7. Load `x2, y2` from VMEM into VREGs.\n", - "8. Execute the `z2 = x2 + y2` using the compute units.\n", - "9. Store `z2` into VMEM.\n", - "10. Wait until `z1` is copied into HBM.\n", - "10. Start copying `z2` from VMEM back into HBM.\n", - "10. Wait until `z2` is copied into HBM.\n", - "\n", - "Any time we are doing compute here, we are asynchronously copying something.\n", - "This means that some of the time spent copying is not wasted.\n", - "\n", - "The two most important numbers for determining how efficient a pipelined\n", - "computation are a) how many floating point operations (FLOPs) we need to\n", - "execute and b) how many bytes we need to copy to execute that computation.\n", - "The ratio of these two (FLOPs/memory usage) is called the\n", - "*arithmetic intensity* of an operation and determines if our pipeline will\n", - "be compute bound or memory bound." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "gutx7y8uvZKH" - }, - "source": [ - "## Pipelining in Pallas" + "The default buffer count is 2 for all inputs and outputs." ] }, { "cell_type": "markdown", - "metadata": { - "id": "U-dPTjlBverB" - }, + "metadata": {}, "source": [ - "How do we implement a pipeline like the one above in Pallas?\n", - "It seems like a complex sequence of asynchronous data operations and\n", - "executing kernels that would be a pain to implement manually.\n", - "Fear not! Pallas offers an API for expressing pipelines without too much\n", - "boilerplate, namely through `grid`s and `BlockSpec`s.\n", - "\n", - "See how in the above pipelined example, we are executing the same logic\n", - "multiple times: steps 3-5 and 8-10 both execute the same operations,\n", - "only on different inputs.\n", - "The {func}`jax.experimental.pallas.pallas_call` provides a way to\n", - "execute a kernel multiple times, by using the `grid` argument.\n", - "See {ref}`pallas_grid`.\n", - "\n", - "We also use {class}`jax.experimental.pallas.BlockSpec` to specify\n", - "how to construct the input of each kernel invocation.\n", - "See {ref}`pallas_blockspec`.\n", - "\n", - "In the pipelining example above, we had `(512, 512)`-shaped arrays and\n", - "split them along the leading dimension into two `(256, 512)`-shaped arrays.\n", - "In this pipeline, our `BlockSpec.block_shape` would be `(256, 512)`.\n", - "On the 1st iteration we'd\n", - "like to select `x1` and on the second iteration we'd like to use `x2`.\n", - "This can be expressed with the following `index_map`:\n", + "(pallas_tpu_emit_pipeline)=\n", "\n", - "```python\n", - "def x_index_map(i):\n", - " return (i, 0)\n", - "```\n", + "### pltpu.emit_pipeline\n", + "\n", + "`pltpu.emit_pipeline` is a pipelining API implemented in Pallas that allows you to construct pipelines inside of a kernel rather than only on kernel entry. This several use-cases over using `pl.pallas_call`, such as:\n", + "- For constructing nested pipelines. For example, an outer pipeline that communicates between chips, and an inner pipeline that performs HBM-VMEM pipelining.\n", + "- For using `emit_pipeline` specific features such as lookahead prefetch and dynamic block shapes (covered below).\n", + "\n", + "`pltpu.emit_pipeline` follows a similar signature to `pl.pallas_call` and requires you to specify a body `kernel`, a grid, and block specs for inputs and outputs:\n", "\n", - "We'd then construct the `BlockSpec`:\n", "```python\n", - "block_spec = pl.BlockSpec((256, 512), x_index_map)\n", + "def emit_pipeline(\n", + " kernel: Callable,\n", + " grid: tuple[int],\n", + " in_specs: PyTree[BlockSpec] = None,\n", + " out_specs: PyTree[BlockSpec] = None,\n", + " dimension_semantics: tuple[GridDimensionSemantics] = None,\n", + " core_axis: int | None = None,\n", + ") -> Callable:\n", + " ... # Returns a custom pipeline given an inner kernel and BlockSpecs.\n", "```\n", "\n", - "The `BlockSpec`s for `y` and `z` will be the same as the one for `x`." + "The `dimension_semantics` and `core_axis` arguments are used for partitioning the kernel grid over Megacore (see below)." ] }, { "cell_type": "markdown", - "metadata": { - "id": "noybOKghzjwG" - }, + "metadata": {}, "source": [ - "### Putting it together\n", + "### Lookahead Prefetch\n", "\n", - "We provide these arguments to `pallas_call` via `grid`, `in_specs` and\n", - "`out_specs` (`in_specs` corresponds to the tuple of positional arguments,\n", - "and `out_specs` corresponds to the output)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "ehKAYAwIojfv", - "outputId": "504bab29-83f3-4e1f-8664-1860ad15b6de" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([[2., 2., 2., ..., 2., 2., 2.],\n", - " [2., 2., 2., ..., 2., 2., 2.],\n", - " [2., 2., 2., ..., 2., 2., 2.],\n", - " ...,\n", - " [2., 2., 2., ..., 2., 2., 2.],\n", - " [2., 2., 2., ..., 2., 2., 2.],\n", - " [2., 2., 2., ..., 2., 2., 2.]], dtype=float32)" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def add_matrices_pipelined(x: jax.Array, y: jax.Array) -> jax.Array:\n", - " block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0))\n", - " return pl.pallas_call(\n", - " add_matrices_kernel,\n", - " out_shape=x,\n", - " in_specs=[block_spec, block_spec],\n", - " out_specs=block_spec,\n", - " grid=(2,)\n", - " )(x, y)\n", + "Lookahead prefetch is a pipelining feature where the pipeline will attempt to prefetch the next input block as soon as a buffering slot is available, rather than the iteration directly before it would be used. For example, if the kernel had a grid of `(8,)` and the block indices to fetch on each iteration were `0, 0, 0, 0, 1, 1, 1, 1`, then lookahead prefetch will begin fetching both blocks `0` and `1` on iteration 0, whereas the standard pipeline schedule would fetch block `0` on iteration 0 but not begin fetching block `1` until iteration 3. There is a small amount of control flow overhead in performing lookahead so it is disabled by default.\n", "\n", - "add_matrices_pipelined(x, y)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "rkytgIZYzz4t" - }, - "source": [ - "We've only added a little bit of code to our original function to add\n", - "automatic pipelining but the `BlockSpec`s and `grid` do a lot of heavy\n", - "lifting!\n", - "\n", - "How does it work? Well, the `BlockSpec`s provide enough information to start\n", - "*prefetching* blocks of our input from HBM into VMEM.\n", - "For example, if we are starting iteration `i` of our `grid`, we can pass\n", - "`i + 1` into the `index_map` functions to obtain the blocks needed for the\n", - "next iteration. We can then start an asynchronous copy for those blocks.\n", - "Similarly for outputs, we can wait for the outputs of the previous iteration\n", - "to be copied before starting the copy for the current iteration's outputs." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "7Xtz9oMs0ZRL" - }, - "source": [ - "### Parameterizing a pipeline" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "esY4GcIB0bqQ" - }, - "source": [ - "It's common to parameterize the block shapes in our kernel. Block sizes are\n", - "perhaps the most important parameter to tune when optimizing the performance\n", - "of Pallas kernels! They give us control over the pipeline (for example,\n", - "picking smaller blocks adds more iterations to our pipelined loop where each\n", - "iteration has less work to do).\n", - "\n", - "Furthermore, we could also carve up the inputs and outputs along the 2nd\n", - "dimension (we are only splitting along the first right now). Let's write a\n", - "more general kernel that handles both of these features." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "VartelFd0YfY" - }, - "outputs": [], - "source": [ - "def add_matrices_pipelined_2d(\n", - " x: jax.Array, y: jax.Array, *, bm: int = 256, bn: int = 256\n", - ") -> jax.Array:\n", - " m, n = x.shape\n", - " block_spec = pl.BlockSpec((bm, bn), lambda i, j: (i, j))\n", - " return pl.pallas_call(\n", - " add_matrices_kernel,\n", - " out_shape=x,\n", - " in_specs=[block_spec, block_spec],\n", - " out_specs=block_spec,\n", - " grid=(m // bm, n // bn),\n", - " )(x, y)\n", + "Lookahead is primarily useful when there is a variable amount of compute work in each block, such as when some blocks contain skipped or a reduced amount of work. In these cases, there may not be enough compute work in the iteration immediately preceding the step when the block is needed to fully overlap with the memory transfer. Therefore, we would like to begin fetching blocks earlier in the pipeline.\n", "\n", - "np.testing.assert_array_equal(\n", - " add_matrices_pipelined_2d(x, y, bm=256, bn=256), x + y\n", - ")\n", - "np.testing.assert_array_equal(\n", - " add_matrices_pipelined_2d(x, y, bm=128, bn=128), x + y\n", + "Lookahead prefetch can be used in conjunction with multiple buffering and can likewise be enabled by passing `pl.Buffered` into the `pipeline_mode` argument:\n", + "```python\n", + "pl.BlockSpec(\n", + " pipeline_mode=pl.Buffered(buffer_count=buffer_count, use_lookahead=True)\n", ")\n", - "np.testing.assert_array_equal(\n", - " add_matrices_pipelined_2d(x, y, bm=512, bn=512), x + y\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "KrfeYwaW1QA-" - }, - "source": [ - "## Handling reductions" + "```" ] }, { "cell_type": "markdown", - "metadata": { - "id": "P3SqEKDe3Mar" - }, + "metadata": {}, "source": [ - "How would you implement something like `jnp.sum` using `pallas_call`?\n", - "Specifically, we'd like to pipeline across the reduction dimension.\n", + "### Dynamic Block Shapes\n", "\n", - "Take the example of reducing a `(8, 512, 512)`-shaped array to a\n", - "`(512, 512)`-shaped one." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "JoT-ZKEk1R7l", - "outputId": "fd842223-98a5-4e5c-87fc-5dadc94da4fa" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([[8., 8., 8., ..., 8., 8., 8.],\n", - " [8., 8., 8., ..., 8., 8., 8.],\n", - " [8., 8., 8., ..., 8., 8., 8.],\n", - " ...,\n", - " [8., 8., 8., ..., 8., 8., 8.],\n", - " [8., 8., 8., ..., 8., 8., 8.],\n", - " [8., 8., 8., ..., 8., 8., 8.]], dtype=float32)" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "x = jnp.ones((8, 512, 512))\n", - "jnp.sum(x, axis=0)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "5O3ByvuT3iyC" - }, - "source": [ - "To do this using `pallas_call`, we could use a grid of size `(8,)` and in\n", - "each iteration `i` load `x[i]` into VMEM.\n", - "Then we could add `x[i]` to an output VMEM buffer. Let's implement this\n", - "naively first." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "hqvv_WRQ3bvP", - "outputId": "200648d2-3f4d-4d1a-b95a-d2c1352cd7b8" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([[9., 9., 9., ..., 9., 9., 9.],\n", - " [9., 9., 9., ..., 9., 9., 9.],\n", - " [9., 9., 9., ..., 9., 9., 9.],\n", - " ...,\n", - " [9., 9., 9., ..., 9., 9., 9.],\n", - " [9., 9., 9., ..., 9., 9., 9.],\n", - " [9., 9., 9., ..., 9., 9., 9.]], dtype=float32)" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Warning: this implementation is incorrect!\n", + "`pltpu.emit_pipeline` supports pipelining over blocks with dynamic but bounded shapes. In order to specify such an block shape, the dynamic-sized dimension in the block should be marked with `pl.BoundedSlice(max_size)` rather than a static integer size, where `max_size` is the maximum size of the block. In addition, the corresponding index returned by `index_map` should be a dynamic slice constructed via `pl.ds(start, size)` where both `start` and `size` are _element_ indices (not block indices) and can be dynamic.\n", "\n", - "def naive_sum_kernel(x_ref, o_ref):\n", - " o_ref[...] += x_ref[...]\n", + "The following is an example for a block spec with a dynamic first dimension:\n", "\n", - "def naive_sum(x: jax.Array) -> jax.Array:\n", - " grid, *out_shape = x.shape\n", - " return pl.pallas_call(\n", - " naive_sum_kernel,\n", - " grid=grid,\n", - " # None in `block_shape` means we pick a size of 1 and squeeze it away\n", - " in_specs=[pl.BlockSpec((None, *out_shape), lambda i: (i, 0, 0))],\n", - " out_specs=pl.BlockSpec(out_shape, lambda i: (0, 0)),\n", - " out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),\n", - " )(x)\n", - "naive_sum(x)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Kv9qJYJY4jbK" - }, - "source": [ - "Notice how we've set up the `BlockSpec`s: we're loading the entirety of\n", - "the `(512, 512)` dimension into VMEM (no pipelining there) but selecting\n", - "the `i`-th dimension of `x` each iteration in the `index_map`.\n", - "We are using a `None` for that dimension in the block shape, which indicates\n", - "that we are selecting a singleton dimension from `x` that we would like\n", - "to squeeze away in the kernel.\n", - "Therefore, `x_ref` is `(512, 512)`-shaped in VMEM as well.\n", - "\n", - "`out_spec` uses `lambda i: (0, 0)` as its `index_map`, indicating that\n", - "`o_ref` is unchanged over the course of the pipeline.\n", - "This means that we can update its value each iteration by reading from and\n", - "writing to it. Or can it?\n", - "Actually there is one catch: *`o_ref` is initially garbage*, meaning we'll\n", - "be accumulating into garbage.\n", - "This will result in the overall function outputting the incorrect value!\n", - "\n", - "Therefore, **whenever we do a reduction in a kernel, we need to make sure\n", - "to initialize the `Ref` that is storing the reduced value**.\n", - "We can accomplish this by conditionally writing a value to `out_ref`\n", - "when we're on iteration 0.\n", - "We can do this with the helper function `pl.when`, a convenience wrapper\n", - "around `jax.lax.cond`, and `pl.program_id`,\n", - "which queries which iteration in a grid axis we are in." + "```python\n", + "pl.BlockSpec(\n", + " block_shape=(pl.BoundedSlice(32), 256),\n", + " index_map=lambda *grid_idxs: (pl.ds(start, end), 0),\n", + ")\n", + "```" ] }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "JXN2RthX5cSw", - "outputId": "195df19b-a889-479b-95b6-1fb7281f1518" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([[8., 8., 8., ..., 8., 8., 8.],\n", - " [8., 8., 8., ..., 8., 8., 8.],\n", - " [8., 8., 8., ..., 8., 8., 8.],\n", - " ...,\n", - " [8., 8., 8., ..., 8., 8., 8.],\n", - " [8., 8., 8., ..., 8., 8., 8.],\n", - " [8., 8., 8., ..., 8., 8., 8.]], dtype=float32)" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def sum_kernel(x_ref, o_ref):\n", - " @pl.when(pl.program_id(axis=0) == 0)\n", - " def _():\n", - " o_ref[...] = jnp.zeros_like(o_ref)\n", - "\n", - " o_ref[...] += x_ref[...]\n", - "\n", - "def sum(x: jax.Array) -> jax.Array:\n", - " grid, *out_shape = x.shape\n", - " return pl.pallas_call(\n", - " sum_kernel,\n", - " grid=grid,\n", - " # None in `block_shape` means we pick a size of 1 and squeeze it away\n", - " in_specs=[pl.BlockSpec((None, *out_shape), lambda i: (i, 0, 0))],\n", - " out_specs=pl.BlockSpec(out_shape, lambda i: (0, 0)),\n", - " out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype)\n", - " )(x)\n", - "\n", - "sum(x)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "2828qXBI5ksZ" - }, + "metadata": {}, + "outputs": [], "source": [ - "This `sum` function now outputs the correct values!\n", - "\n", - "One last thing to note about reductions in Pallas are that **they must be\n", - "done in the minormost (rightmost) dimensions of our grid** (our grid is\n", - "1-dimensional in the above example so we are reducing over its minormost\n", - "dimension). This is because the pipeline that Pallas generates using\n", - "the `BlockSpec`s, `grid` and kernel function *does not read outputs back\n", - "from HBM*.\n", - "Once you've written an output value back to HBM you cannot revisit it.\n", - "Therefore, you cannot do a reduction across a grid dimension that has any\n", - "revisiting and therefore all reductions need to happen in the rightmost\n", - "dimensions." + "# The following kernel copies `x` to the output in dynamic-sized chunks\n", + "# passed in via `slices`.\n", + "\n", + "def dynamic_block_example_kernel(x_hbm, slices_hbm, o_hbm, slices_smem):\n", + " pltpu.sync_copy(slices_hbm, slices_smem) # Copy slices into SMEM.\n", + " def pipeline_body(x_vmem, o_vmem):\n", + " o_vmem[...] = x_vmem[...]\n", + " def index_map(i):\n", + " start = slices_smem[i, 0]\n", + " size = slices_smem[i, 1] - slices_smem[i, 0]\n", + " return (pl.ds(start, size), 0)\n", + " block_spec = pl.BlockSpec(block_shape=(pl.BoundedSlice(8), 128),\n", + " index_map=index_map)\n", + " pltpu.emit_pipeline(\n", + " pipeline_body,\n", + " grid=(slices.shape[0],),\n", + " in_specs=[block_spec],\n", + " out_specs=block_spec\n", + " )(x_hbm, o_hbm)\n", + "\n", + "x = jax.random.uniform(jax.random.key(0), (8, 128), jnp.float32)\n", + "slices = jnp.array([[0, 2], [2, 3], [3, 5], [5, 8]], dtype=jnp.int32)\n", + "\n", + "hbm_block_spec = pl.BlockSpec(memory_space=pl.ANY)\n", + "out = pl.pallas_call(dynamic_block_example_kernel,\n", + " in_specs=[hbm_block_spec, hbm_block_spec],\n", + " out_specs=hbm_block_spec,\n", + " out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),\n", + " scratch_shapes=(pltpu.SMEM(slices.shape, jnp.int32),)\n", + " )(x, slices)\n", + "\n", + "np.testing.assert_allclose(x, out)" ] }, { @@ -655,7 +302,7 @@ "source": [ "(pallas_tpu_megacore)=\n", "\n", - "## TPUs in Megacore configuration" + "### TPUs in Megacore configuration" ] }, { @@ -683,10 +330,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": { + "executionInfo": { + "elapsed": 106, + "status": "ok", + "timestamp": 1744910274556, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, "id": "nQNa8RaQ-TR1", - "outputId": "385ed87c-d95c-466c-af77-df3845c979f2" + "outputId": "29c0b574-3528-49a5-8a88-b6987efc69ce" }, "outputs": [ { @@ -701,21 +358,31 @@ " [2., 2., 2., ..., 2., 2., 2.]], dtype=float32)" ] }, - "execution_count": 9, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ + "def add_matrices_kernel(x_vmem_ref, y_vmem_ref, z_vmem_ref):\n", + " # Load x and y from VMEM into VREGs\n", + " x_vregs = x_vmem_ref[:, :]\n", + " y_vregs = y_vmem_ref[:, :]\n", + " # Execute a vectorized add\n", + " z_vregs = x_vregs + y_vregs\n", + " # Store the output values in VREGs back into VMEM\n", + " z_vmem_ref[:, :] = z_vregs\n", + "\n", "def add_matrices_pipelined_megacore(x: jax.Array, y: jax.Array) -> jax.Array:\n", " block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0))\n", " return pl.pallas_call(\n", " add_matrices_kernel,\n", - " out_shape=x,\n", + " out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),\n", " in_specs=[block_spec, block_spec],\n", " out_specs=block_spec,\n", " grid=(2,),\n", - " compiler_params=pltpu.TPUCompilerParams(dimension_semantics=(\"parallel\",))\n", + " compiler_params=pltpu.CompilerParams(\n", + " dimension_semantics=(\"parallel\",))\n", " )(x, y)\n", "\n", "x, y = jnp.ones((512, 512)), jnp.ones((512, 512))\n", @@ -735,41 +402,58 @@ "simultaneously on each TensorCore. Pallas will handle splitting up the grid\n", "automatically.\n", "\n", - "> Note that Megacore is only currently available on TPU `v4` and TPU `v5p`. Supplying `dimension_semantics` annotations is a no-op on other platforms, but *not* specifying it will result in only one TensorCore being used (even if there are more than one available)." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1ZJ2rV5W8FAe" - }, - "source": [ - "## Conclusion\n", + "> Note that Megacore is only currently available on TPU `v4` and TPU `v5p`. Supplying `dimension_semantics` annotations is a no-op on other platforms, but *not* specifying it will result in only one TensorCore being used (even if there are more than one available).\n", "\n", - "In this guide we covered how to express TPU pipelines using `pallas_call`,\n", - "`grid` and `BlockSpec`s. We covered how to express nested loops via a\n", - "multi-dimensional grid and how to handle reductions by initialize our\n", - "accumulators at the beginning of the reduction.\n", - "We also learned how to handle Megacore by adding annotations to the kernel.\n", + "When using `pltpu.emit_pipeline`, `core_axis` should be passed into `emit_pipeline`. `core_axis` should be the index of a parallel grid axis to partition the grid on. For example, the following template can be used to partition the kernel over a leading parallel grid dimension:\n", "\n", - "Exercises left to the reader:\n", - "* Try implementing a `sum` kernel that pipelines the other dimensions as well\n", - "* Add megacore support to the `add` kernel and the `sum` kernel as well." + "```python\n", + "def kernel_body(...):\n", + " def inner_pipeline_body(...):\n", + " ...\n", + " pltpu.emit_pipeline(inner_pipeline_body,\n", + " grid=(4, 4), \n", + " core_axis=0,\n", + " dimension_semantics=(\"parallel\", \"sequential\"))\n", + "\n", + "pl.pallas_call(\n", + " kernel_body,\n", + " grid=(num_cores,),\n", + " compiler_params=pltpu.CompilerParams(\n", + " dimension_semantics=(\"parallel\",))\n", + " )\n", + "```" ] } ], "metadata": { + "colab": { + "last_runtime": { + "build_target": "//experimental/users/justinfu/pallas:colab", + "kind": "private" + }, + "provenance": [] + }, "jupytext": { "formats": "ipynb,md:myst" }, "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", + "language": "python", "name": "python3" }, "language_info": { - "name": "python" + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" } }, "nbformat": 4, - "nbformat_minor": 0 + "nbformat_minor": 4 } diff --git a/docs/pallas/tpu/pipelining.md b/docs/pallas/tpu/pipelining.md index df570cf0806c..02c9187edd2e 100644 --- a/docs/pallas/tpu/pipelining.md +++ b/docs/pallas/tpu/pipelining.md @@ -7,26 +7,38 @@ jupytext: format_version: 0.13 jupytext_version: 1.16.4 kernelspec: - display_name: Python 3 + display_name: Python 3 (ipykernel) + language: python name: python3 --- ++++ {"id": "7704d3bb"} + (pallas_tpu_pipelining)= +++ {"id": "teoJ_fUwlu0l"} -# Pipelining +# TPU Pipelining +++ {"id": "gAJDZh1gBh-h"} -In this guide we'll cover how memory spaces in TPU work and how to write -pipelines in Pallas that overlap memory I/O with compute. - -```{code-cell} -:id: ejAVO6ikUUuF +This guide serves as a reference for TPU-specific pipelining concerns. +We'll review the memory hierarchy and compute units on TPUs, and TPU-specific features of the pipelining API. For a more general-purpose overview of pipelining, see the {ref}`pallas_software_pipelining`. +```{code-cell} ipython3 +--- +executionInfo: + elapsed: 54 + status: ok + timestamp: 1744908474512 + user: + displayName: Justin Fu + userId: '17543197034567316452' + user_tz: 420 +id: ejAVO6ikUUuF +--- #@title Imports import jax @@ -36,13 +48,13 @@ import jax.numpy as jnp import numpy as np ``` -+++ {"id": "TWKESTKAlyjT"} ++++ {"id": "0e212a5e"} (tpu_and_its_memory_spaces)= ## TPU and its memory spaces -+++ ++++ {"id": "NnWW9GV4kW6P"} A TPU and its TensorCore consist of memory spaces (where arrays can reside), registers (which temporarily store scalar and array values) and compute units @@ -65,384 +77,174 @@ Let's talk about the components of this diagram in more detail: Values can be loaded into memory from their respective caches (VMEM for VREGs and SMEM for SREGs). * **Compute units**: A TensorCore has a scalar unit, vector unit (VPU) and - matrix unit (MXU) that can do numerical computation. + matrix unit (MXU) that can do numerical computation. Each of these compute units can operate asynchronously, but this is managed by the TPU compiler and thus from the programmer's perspective a TPU program is single-threaded. Compute units operate on values that live in SREGs and VREGs and output values into those registers as well. -In order to do a vectorized computation on our values `x` and `y` that live -in HBM, we need to: ++++ {"id": "8Tl3wt5Wk3Ek"} -1. Copy the values `x` and `y` into VMEM. -2. Load the values from VMEM into VREGs. -3. Execute the computation using the VPU or MXU, storing the output in VREGs. -4. Store the values in the output VREGs into VMEM. -5. Copy the output values in VMEM back to HBM. +## TPU-specific Pipelining Features -+++ {"id": "TzctMbNsn3vc"} +Pallas TPU supports the following platform-specific features. -Let's implement a Pallas function that does just that! ++++ {"id": "1jg5WmExk47l"} -```{code-cell} -:id: 2IXQxNWrKJyb -:outputId: d62eb493-5f92-4496-f113-d3cd24cb0b9f +### TPU Memory Spaces -def add_matrices_kernel(x_vmem_ref, y_vmem_ref, z_vmem_ref): - # Load x and y from VMEM into VREGs - x_vregs = x_vmem_ref[:, :] - y_vregs = y_vmem_ref[:, :] - # Execute a vectorized add - z_vregs = x_vregs + y_vregs - # Store the output values in VREGs back into VMEM - z_vmem_ref[:, :] = z_vregs +Pallas exposes all levels of the TPU memory hierarchy to users. The following table maps from Pallas TPU memory spaces to their standard memory types (DRAM/SRAM): +| Pallas Enum | TPU Memory Space | Type (DRAM/SRAM) | +| --- | --- | --- | +| `pl.ANY` | HBM (usually) or VMEM | DRAM | +| `pltpu.VMEM` | VMEM | SRAM | +| `pltpu.SMEM` | SMEM | SRAM | +| `pltpu.SEMAPHORE` | Semaphore | SRAM | -def add_matrices(x: jax.Array, y: jax.Array) -> jax.Array: - # pallas_call will first allocate scratch buffers for `x` and `y` in VMEM. - # It will then copy `x` and `y` from HBM into VMEM. - z = pl.pallas_call( - add_matrices_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype) - )(x, y) - # pallas_call will also copy the output from VMEM back into HBM. - return z +- `MemorySpace.VMEM` denotes vector SRAM. It is the default memory space if nothing is specified. +- `MemorySpace.SMEM` denotes scalar SRAM. Only scalar loads and stores can be performed to/from SMEM. +- `MemorySpace.ANY` is a hint to the compiler that the memory space is unconstrained. In most cases, XLA will place this buffer in HBM. A buffer assigned to the `ANY` memory space cannot be dereferenced normally using array indexing syntax (e.g. `x[...]`). Instead, we must first copy the values into a VMEM or SMEM buffer using `pltpu.sync_copy` or `pltpu.async_copy`. +- `MemorySpace.SEMAPHORE` is used to allocate semaphores for constructing barriers or tracking asynchronous operations. It is also possible to return semaphores from the kernel for building asynchronous kernels - this is an experimental feature; see {ref}`pallas_async` for more details. +Pipelining on TPUs is typically done between HBM (DRAM) to VMEM (Vector SRAM). The default behavior for `pallas_call` on TPU is that arguments to `pallas_call` are assumed to live in HBM, and inputs to the user kernel body are stored in VMEM. -x, y = jnp.ones((512, 512)), jnp.ones((512, 512)) -add_matrices(x, y) -``` - -+++ {"id": "HMENNLy8okCL"} - -We've written two functions: `add_matrices_kernel` and `add_matrices`. - -`add_matrices_kernel` operates using `Ref`s that live in VMEM. -Loading from a VMEM `Ref` produces a value that lives in VREGs. -Values in VREGs behave like `jax.Array`s in that we can use `jnp` and -`jax.lax` operations on them to produce new values that live in VREGs. -When we produce the values we'd like to return, we store them in the output -VMEM `Ref`. - -The `add_matrices` function acts on `jax.Array`s and returns a `jax.Array`. -Inside it, we pass `x` and `y` into `pallas_call`. -`pallas_call` is responsible for copying `x` and `y` into VMEM and for -allocating the VMEM buffers that the kernel operates on (including allocating -`z_vmem_ref`, the output VMEM buffer). -After the kernel function is finished running, `pallas_call` will also copy -the value in `z_vmem_ref` to HBM, resulting in an output `jax.Array`. - -+++ {"id": "5kWr-1tKpYro"} - -## Constraints of using VMEM/SMEM - -Pallas exposes access to lower level memory spaces like VMEM and SMEM but -writing kernels utilizing them adds some considerations. - -1. Memory capacity. VMEM and SMEM are *small*! VMEM on v4 TPUs is only 16MiB - and SMEM ranges in the tens to hundreds of KiB. - If our arrays are too big, we won't even be able to fit them into VMEM at all. - For reference, a `f32[2048, 2048]` array is 16MiB, so our above kernel won't - scale beyond moderately sized arrays. - -2. Memory bandwidth. Copying to/from HBM and VMEM takes a long time, at least - compared to most compute instructions. - The `add_matrices` function above will likely spend more time copying - between HBM and VMEM than actually performing the addition itself. - -With these two constraints in mind, we'll have to rethink our strategy for -getting performance out of our TPUs. - -+++ {"id": "_NTqvlbetB3P"} - -## Primer: Pipelining - -Pipelining our computation offers a way of dealing with both the memory -capacity and bandwidth constraints in one fell swoop. -What do we mean by pipelining? - -The goal is: *in parallel* copy to/from HBM and VMEM *while* utilizing our -compute units. -Naively this is difficult because in our program above we copy *all* of `x` -and `y` before we start doing any compute with them, creating a dependence -between the copy and the compute. - -However, if we can chunk up our computation into several subcomputations -(e.g. when we add two matrices, we can express that as addition of "blocks" -of the original matrices together), we can now overlap the copies of one of -those subcomputations with the compute of the other. Let's walk through a -simple example: - -Let's say we split our arrays `x` and `y` into `x1, x2` and `y1, y2` (for -example, split along the leading axis, resulting in two `(256, 512)` arrays -for each input. -We can now execute the following pipelined computation. - -1. Copy `x1` and `y1` into VMEM. -1. Start copying `x2` and `y2` into VMEM -2. Load `x1, y1` from VMEM into VREGs. -3. Execute the `z1 = x1 + y1` using the compute units. -4. Store `z1` into VMEM. -5. Start copying `z1` from VMEM back into HBM. -6. Wait until `x2, y2` have been copied into VMEM. -7. Load `x2, y2` from VMEM into VREGs. -8. Execute the `z2 = x2 + y2` using the compute units. -9. Store `z2` into VMEM. -10. Wait until `z1` is copied into HBM. -10. Start copying `z2` from VMEM back into HBM. -10. Wait until `z2` is copied into HBM. - -Any time we are doing compute here, we are asynchronously copying something. -This means that some of the time spent copying is not wasted. - -The two most important numbers for determining how efficient a pipelined -computation are a) how many floating point operations (FLOPs) we need to -execute and b) how many bytes we need to copy to execute that computation. -The ratio of these two (FLOPs/memory usage) is called the -*arithmetic intensity* of an operation and determines if our pipeline will -be compute bound or memory bound. - -+++ {"id": "gutx7y8uvZKH"} - -## Pipelining in Pallas - -+++ {"id": "U-dPTjlBverB"} - -How do we implement a pipeline like the one above in Pallas? -It seems like a complex sequence of asynchronous data operations and -executing kernels that would be a pain to implement manually. -Fear not! Pallas offers an API for expressing pipelines without too much -boilerplate, namely through `grid`s and `BlockSpec`s. - -See how in the above pipelined example, we are executing the same logic -multiple times: steps 3-5 and 8-10 both execute the same operations, -only on different inputs. -The {func}`jax.experimental.pallas.pallas_call` provides a way to -execute a kernel multiple times, by using the `grid` argument. -See {ref}`pallas_grid`. - -We also use {class}`jax.experimental.pallas.BlockSpec` to specify -how to construct the input of each kernel invocation. -See {ref}`pallas_blockspec`. - -In the pipelining example above, we had `(512, 512)`-shaped arrays and -split them along the leading dimension into two `(256, 512)`-shaped arrays. -In this pipeline, our `BlockSpec.block_shape` would be `(256, 512)`. -On the 1st iteration we'd -like to select `x1` and on the second iteration we'd like to use `x2`. -This can be expressed with the following `index_map`: +While not specific to pipelining, it is possible to gain manual control over the memory space of input and output buffers, you can specify the `memory_space` argument on a `BlockSpec`. Note that pipelining is not allowed unless the `memory_space` is marked as `VMEM`. Memory spaces can also be used to specify scratch arguments to a kernel via the `scratch_shapes` argument on `pallas_call`. Scratch buffers are persistent across kernel iterations and are useful for storing intermediate results such as partial accumulations and reductions. A scratch buffer must reside in `VMEM`, `SMEM`, or `SEMAPHORE`. -```python -def x_index_map(i): - return (i, 0) -``` +As an example for using multiple manual memory space assignments in a kernel, the following program copies a slice of an HBM buffer `x_hbm_ref` into a scratch VMEM buffer `scratch_vmem_ref` before using it for arithmetic and storing the result into an output VMEM buffer: -We'd then construct the `BlockSpec`: -```python -block_spec = pl.BlockSpec((256, 512), x_index_map) +```{code-cell} ipython3 +--- +executionInfo: + elapsed: 65 + status: ok + timestamp: 1744908591430 + user: + displayName: Justin Fu + userId: '17543197034567316452' + user_tz: 420 +id: zcqz1CA_o50a +--- +def hbm_vmem_kernel(x_hbm_ref, out_vmem_ref, scratch_vmem_ref): + pltpu.sync_copy(x_hbm_ref.at[0:1], scratch_vmem_ref) + out_vmem_ref[...] = scratch_vmem_ref[...] + 1 + +x = jax.random.uniform(jax.random.key(0), (8, 128), jnp.float32) +out = pl.pallas_call(hbm_vmem_kernel, + in_specs=[pl.BlockSpec(memory_space=pl.ANY)], + out_shape=jax.ShapeDtypeStruct((1, 128), jnp.float32), + scratch_shapes=(pltpu.VMEM(shape=(1, 128), dtype=jnp.float32),) +)(x) + +np.testing.assert_allclose(out, x[0:1] + 1) ``` -The `BlockSpec`s for `y` and `z` will be the same as the one for `x`. - -+++ {"id": "noybOKghzjwG"} - -### Putting it together +### Multiple Buffering -We provide these arguments to `pallas_call` via `grid`, `in_specs` and -`out_specs` (`in_specs` corresponds to the tuple of positional arguments, -and `out_specs` corresponds to the output). +Multiple buffering can be specified on a per-argument basis to the pipeline via the `pipeline_mode` option on `pl.BlockSpec`. To do so, pass a `pl.Buffered` object to `pl.BlockSpec` specifying the number of buffers to allocate for this particular argument: -```{code-cell} -:id: ehKAYAwIojfv -:outputId: 504bab29-83f3-4e1f-8664-1860ad15b6de - -def add_matrices_pipelined(x: jax.Array, y: jax.Array) -> jax.Array: - block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0)) - return pl.pallas_call( - add_matrices_kernel, - out_shape=x, - in_specs=[block_spec, block_spec], - out_specs=block_spec, - grid=(2,) - )(x, y) - -add_matrices_pipelined(x, y) +```python +pl.BlockSpec( + pipeline_mode=pl.Buffered(buffer_count=buffer_count) +) ``` -+++ {"id": "rkytgIZYzz4t"} - -We've only added a little bit of code to our original function to add -automatic pipelining but the `BlockSpec`s and `grid` do a lot of heavy -lifting! - -How does it work? Well, the `BlockSpec`s provide enough information to start -*prefetching* blocks of our input from HBM into VMEM. -For example, if we are starting iteration `i` of our `grid`, we can pass -`i + 1` into the `index_map` functions to obtain the blocks needed for the -next iteration. We can then start an asynchronous copy for those blocks. -Similarly for outputs, we can wait for the outputs of the previous iteration -to be copied before starting the copy for the current iteration's outputs. - -+++ {"id": "7Xtz9oMs0ZRL"} - -### Parameterizing a pipeline +The default buffer count is 2 for all inputs and outputs. -+++ {"id": "esY4GcIB0bqQ"} ++++ -It's common to parameterize the block shapes in our kernel. Block sizes are -perhaps the most important parameter to tune when optimizing the performance -of Pallas kernels! They give us control over the pipeline (for example, -picking smaller blocks adds more iterations to our pipelined loop where each -iteration has less work to do). +(pallas_tpu_emit_pipeline)= -Furthermore, we could also carve up the inputs and outputs along the 2nd -dimension (we are only splitting along the first right now). Let's write a -more general kernel that handles both of these features. +### pltpu.emit_pipeline -```{code-cell} -:id: VartelFd0YfY +`pltpu.emit_pipeline` is a pipelining API implemented in Pallas that allows you to construct pipelines inside of a kernel rather than only on kernel entry. This several use-cases over using `pl.pallas_call`, such as: +- For constructing nested pipelines. For example, an outer pipeline that communicates between chips, and an inner pipeline that performs HBM-VMEM pipelining. +- For using `emit_pipeline` specific features such as lookahead prefetch and dynamic block shapes (covered below). -def add_matrices_pipelined_2d( - x: jax.Array, y: jax.Array, *, bm: int = 256, bn: int = 256 -) -> jax.Array: - m, n = x.shape - block_spec = pl.BlockSpec((bm, bn), lambda i, j: (i, j)) - return pl.pallas_call( - add_matrices_kernel, - out_shape=x, - in_specs=[block_spec, block_spec], - out_specs=block_spec, - grid=(m // bm, n // bn), - )(x, y) +`pltpu.emit_pipeline` follows a similar signature to `pl.pallas_call` and requires you to specify a body `kernel`, a grid, and block specs for inputs and outputs: -np.testing.assert_array_equal( - add_matrices_pipelined_2d(x, y, bm=256, bn=256), x + y -) -np.testing.assert_array_equal( - add_matrices_pipelined_2d(x, y, bm=128, bn=128), x + y -) -np.testing.assert_array_equal( - add_matrices_pipelined_2d(x, y, bm=512, bn=512), x + y -) +```python +def emit_pipeline( + kernel: Callable, + grid: tuple[int], + in_specs: PyTree[BlockSpec] = None, + out_specs: PyTree[BlockSpec] = None, + dimension_semantics: tuple[GridDimensionSemantics] = None, + core_axis: int | None = None, +) -> Callable: + ... # Returns a custom pipeline given an inner kernel and BlockSpecs. ``` -+++ {"id": "KrfeYwaW1QA-"} +The `dimension_semantics` and `core_axis` arguments are used for partitioning the kernel grid over Megacore (see below). -## Handling reductions - -+++ {"id": "P3SqEKDe3Mar"} ++++ -How would you implement something like `jnp.sum` using `pallas_call`? -Specifically, we'd like to pipeline across the reduction dimension. +### Lookahead Prefetch -Take the example of reducing a `(8, 512, 512)`-shaped array to a -`(512, 512)`-shaped one. +Lookahead prefetch is a pipelining feature where the pipeline will attempt to prefetch the next input block as soon as a buffering slot is available, rather than the iteration directly before it would be used. For example, if the kernel had a grid of `(8,)` and the block indices to fetch on each iteration were `0, 0, 0, 0, 1, 1, 1, 1`, then lookahead prefetch will begin fetching both blocks `0` and `1` on iteration 0, whereas the standard pipeline schedule would fetch block `0` on iteration 0 but not begin fetching block `1` until iteration 3. There is a small amount of control flow overhead in performing lookahead so it is disabled by default. -```{code-cell} -:id: JoT-ZKEk1R7l -:outputId: fd842223-98a5-4e5c-87fc-5dadc94da4fa +Lookahead is primarily useful when there is a variable amount of compute work in each block, such as when some blocks contain skipped or a reduced amount of work. In these cases, there may not be enough compute work in the iteration immediately preceding the step when the block is needed to fully overlap with the memory transfer. Therefore, we would like to begin fetching blocks earlier in the pipeline. -x = jnp.ones((8, 512, 512)) -jnp.sum(x, axis=0) +Lookahead prefetch can be used in conjunction with multiple buffering and can likewise be enabled by passing `pl.Buffered` into the `pipeline_mode` argument: +```python +pl.BlockSpec( + pipeline_mode=pl.Buffered(buffer_count=buffer_count, use_lookahead=True) +) ``` -+++ {"id": "5O3ByvuT3iyC"} - -To do this using `pallas_call`, we could use a grid of size `(8,)` and in -each iteration `i` load `x[i]` into VMEM. -Then we could add `x[i]` to an output VMEM buffer. Let's implement this -naively first. ++++ -```{code-cell} -:id: hqvv_WRQ3bvP -:outputId: 200648d2-3f4d-4d1a-b95a-d2c1352cd7b8 +### Dynamic Block Shapes -# Warning: this implementation is incorrect! +`pltpu.emit_pipeline` supports pipelining over blocks with dynamic but bounded shapes. In order to specify such an block shape, the dynamic-sized dimension in the block should be marked with `pl.BoundedSlice(max_size)` rather than a static integer size, where `max_size` is the maximum size of the block. In addition, the corresponding index returned by `index_map` should be a dynamic slice constructed via `pl.ds(start, size)` where both `start` and `size` are _element_ indices (not block indices) and can be dynamic. -def naive_sum_kernel(x_ref, o_ref): - o_ref[...] += x_ref[...] +The following is an example for a block spec with a dynamic first dimension: -def naive_sum(x: jax.Array) -> jax.Array: - grid, *out_shape = x.shape - return pl.pallas_call( - naive_sum_kernel, - grid=grid, - # None in `block_shape` means we pick a size of 1 and squeeze it away - in_specs=[pl.BlockSpec((None, *out_shape), lambda i: (i, 0, 0))], - out_specs=pl.BlockSpec(out_shape, lambda i: (0, 0)), - out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype), - )(x) -naive_sum(x) +```python +pl.BlockSpec( + block_shape=(pl.BoundedSlice(32), 256), + index_map=lambda *grid_idxs: (pl.ds(start, end), 0), +) ``` -+++ {"id": "Kv9qJYJY4jbK"} - -Notice how we've set up the `BlockSpec`s: we're loading the entirety of -the `(512, 512)` dimension into VMEM (no pipelining there) but selecting -the `i`-th dimension of `x` each iteration in the `index_map`. -We are using a `None` for that dimension in the block shape, which indicates -that we are selecting a singleton dimension from `x` that we would like -to squeeze away in the kernel. -Therefore, `x_ref` is `(512, 512)`-shaped in VMEM as well. - -`out_spec` uses `lambda i: (0, 0)` as its `index_map`, indicating that -`o_ref` is unchanged over the course of the pipeline. -This means that we can update its value each iteration by reading from and -writing to it. Or can it? -Actually there is one catch: *`o_ref` is initially garbage*, meaning we'll -be accumulating into garbage. -This will result in the overall function outputting the incorrect value! - -Therefore, **whenever we do a reduction in a kernel, we need to make sure -to initialize the `Ref` that is storing the reduced value**. -We can accomplish this by conditionally writing a value to `out_ref` -when we're on iteration 0. -We can do this with the helper function `pl.when`, a convenience wrapper -around `jax.lax.cond`, and `pl.program_id`, -which queries which iteration in a grid axis we are in. - -```{code-cell} -:id: JXN2RthX5cSw -:outputId: 195df19b-a889-479b-95b6-1fb7281f1518 - -def sum_kernel(x_ref, o_ref): - @pl.when(pl.program_id(axis=0) == 0) - def _(): - o_ref[...] = jnp.zeros_like(o_ref) - - o_ref[...] += x_ref[...] - -def sum(x: jax.Array) -> jax.Array: - grid, *out_shape = x.shape - return pl.pallas_call( - sum_kernel, - grid=grid, - # None in `block_shape` means we pick a size of 1 and squeeze it away - in_specs=[pl.BlockSpec((None, *out_shape), lambda i: (i, 0, 0))], - out_specs=pl.BlockSpec(out_shape, lambda i: (0, 0)), - out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype) - )(x) - -sum(x) +```{code-cell} ipython3 +# The following kernel copies `x` to the output in dynamic-sized chunks +# passed in via `slices`. + +def dynamic_block_example_kernel(x_hbm, slices_hbm, o_hbm, slices_smem): + pltpu.sync_copy(slices_hbm, slices_smem) # Copy slices into SMEM. + def pipeline_body(x_vmem, o_vmem): + o_vmem[...] = x_vmem[...] + def index_map(i): + start = slices_smem[i, 0] + size = slices_smem[i, 1] - slices_smem[i, 0] + return (pl.ds(start, size), 0) + block_spec = pl.BlockSpec(block_shape=(pl.BoundedSlice(8), 128), + index_map=index_map) + pltpu.emit_pipeline( + pipeline_body, + grid=(slices.shape[0],), + in_specs=[block_spec], + out_specs=block_spec + )(x_hbm, o_hbm) + +x = jax.random.uniform(jax.random.key(0), (8, 128), jnp.float32) +slices = jnp.array([[0, 2], [2, 3], [3, 5], [5, 8]], dtype=jnp.int32) + +hbm_block_spec = pl.BlockSpec(memory_space=pl.ANY) +out = pl.pallas_call(dynamic_block_example_kernel, + in_specs=[hbm_block_spec, hbm_block_spec], + out_specs=hbm_block_spec, + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + scratch_shapes=(pltpu.SMEM(slices.shape, jnp.int32),) + )(x, slices) + +np.testing.assert_allclose(x, out) ``` -+++ {"id": "2828qXBI5ksZ"} - -This `sum` function now outputs the correct values! - -One last thing to note about reductions in Pallas are that **they must be -done in the minormost (rightmost) dimensions of our grid** (our grid is -1-dimensional in the above example so we are reducing over its minormost -dimension). This is because the pipeline that Pallas generates using -the `BlockSpec`s, `grid` and kernel function *does not read outputs back -from HBM*. -Once you've written an output value back to HBM you cannot revisit it. -Therefore, you cannot do a reduction across a grid dimension that has any -revisiting and therefore all reductions need to happen in the rightmost -dimensions. - +++ {"id": "KvPFez9N8cKJ"} (pallas_tpu_megacore)= -## TPUs in Megacore configuration +### TPUs in Megacore configuration +++ {"id": "0f4HAVzQ8n71"} @@ -462,19 +264,38 @@ computation, we can split up those dimensions across the TensorCores. We can indicate which dimensions are parallelizable by providing an annotation to `pallas_call` called `dimension_semantics`. -```{code-cell} -:id: nQNa8RaQ-TR1 -:outputId: 385ed87c-d95c-466c-af77-df3845c979f2 +```{code-cell} ipython3 +--- +executionInfo: + elapsed: 106 + status: ok + timestamp: 1744910274556 + user: + displayName: Justin Fu + userId: '17543197034567316452' + user_tz: 420 +id: nQNa8RaQ-TR1 +outputId: 29c0b574-3528-49a5-8a88-b6987efc69ce +--- +def add_matrices_kernel(x_vmem_ref, y_vmem_ref, z_vmem_ref): + # Load x and y from VMEM into VREGs + x_vregs = x_vmem_ref[:, :] + y_vregs = y_vmem_ref[:, :] + # Execute a vectorized add + z_vregs = x_vregs + y_vregs + # Store the output values in VREGs back into VMEM + z_vmem_ref[:, :] = z_vregs def add_matrices_pipelined_megacore(x: jax.Array, y: jax.Array) -> jax.Array: block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0)) return pl.pallas_call( add_matrices_kernel, - out_shape=x, + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), in_specs=[block_spec, block_spec], out_specs=block_spec, grid=(2,), - compiler_params=pltpu.TPUCompilerParams(dimension_semantics=("parallel",)) + compiler_params=pltpu.CompilerParams( + dimension_semantics=("parallel",)) )(x, y) x, y = jnp.ones((512, 512)), jnp.ones((512, 512)) @@ -492,16 +313,21 @@ automatically. > Note that Megacore is only currently available on TPU `v4` and TPU `v5p`. Supplying `dimension_semantics` annotations is a no-op on other platforms, but *not* specifying it will result in only one TensorCore being used (even if there are more than one available). -+++ {"id": "1ZJ2rV5W8FAe"} +When using `pltpu.emit_pipeline`, `core_axis` should be passed into `emit_pipeline`. `core_axis` should be the index of a parallel grid axis to partition the grid on. For example, the following template can be used to partition the kernel over a leading parallel grid dimension: -## Conclusion - -In this guide we covered how to express TPU pipelines using `pallas_call`, -`grid` and `BlockSpec`s. We covered how to express nested loops via a -multi-dimensional grid and how to handle reductions by initialize our -accumulators at the beginning of the reduction. -We also learned how to handle Megacore by adding annotations to the kernel. - -Exercises left to the reader: -* Try implementing a `sum` kernel that pipelines the other dimensions as well -* Add megacore support to the `add` kernel and the `sum` kernel as well. +```python +def kernel_body(...): + def inner_pipeline_body(...): + ... + pltpu.emit_pipeline(inner_pipeline_body, + grid=(4, 4), + core_axis=0, + dimension_semantics=("parallel", "sequential")) + +pl.pallas_call( + kernel_body, + grid=(num_cores,), + compiler_params=pltpu.CompilerParams( + dimension_semantics=("parallel",)) + ) +``` diff --git a/docs/pallas/tpu/prng.rst b/docs/pallas/tpu/prng.rst new file mode 100644 index 000000000000..be1f4b04ca3a --- /dev/null +++ b/docs/pallas/tpu/prng.rst @@ -0,0 +1,171 @@ +Pseudo-Random Number Generation +=============================== + +Pallas TPU implements several APIs for generating pseudorandom numbers inside of a kernel with varying tradeoffs in portability and efficiency. For maximum portability, consider using `jax.random` functions directly. Pallas also exposes the hardware PRNG contained on TPUs which are the fastest to compute but the underlying implementation can vary between hardware generations. + +Using the ``jax.random`` API +---------------------------- + +Pallas supports a subset of operations in the ``jax.random`` API. These functions are guaranteed to produce bitwise-equal results compared to calling these functions in JAX outside of Pallas when given the same key. Only ``threefry2x32`` keys are supported. + +The following random sampling functions are currently supported: + +* :func:`jax.random.bits` +* :func:`jax.random.uniform` +* :func:`jax.random.bernoulli` +* :func:`jax.random.normal` + +The following utility functions are supported: + +* :func:`jax.random.key` +* :func:`jax.random.fold_in` +* :func:`jax.random.wrap_key_data` + +PRNG keys can be generated inside of the kernel using :func:`jax.random.key`. However, the more likely scenario is that a key will be passed into the kernel from the caller. In such a case, the key can be passed into the kernel via VMEM as follows: + +.. code-block:: python + + def body(key_ref, o_ref): + key = key_ref[...] + o_ref[...] = jax_random.uniform( + key, shape=o_ref[...].shape, minval=0.0, maxval=1.0 + ) + + threefry_key = jax_random.key(0, impl="threefry2x32") + + # We generate a threefry key outside of the kernel and pass it in via VMEM. + result = pl.pallas_call( + body, + in_specs=[pl.BlockSpec(memory_space=pltpu.VMEM)], + out_shape=jax.ShapeDtypeStruct((256, 256), jnp.float32) + )(threefry_key) + +.. note:: + + In terms of performance concerns, generating random numbers inside of a kernel helps reduce memory bandwidth usage as it is cheaper to pass in a key than a large array of random numbers. However, ``threefry2x32`` is a vector-heavy algorithm that involves dozens of chained bitwise operations. This can become a bottleneck and lead to low accelerator usage as it does not utilize the matrix multiply unit (MXU) where the majority of FLOP/s are. + +Using the hardware PRNG +----------------------- + +TPUs implement a sequential (rather than counter-based) PRNG natively in hardware that is much faster to compute than using a software-implemented PRNG such as ``threefry2x32``. However, JAX random APIs assume a stateless, counter-based PRNG so Pallas introduces its own stateful PRNG API to offer equivalent functionality. + +.. warning:: + + The underlying implementation of the hardware PRNG varies between TPU generations, so it is best practice to not depend on its exact behavior. For a more stable PRNG implemented in software, it is recommended to use the ``threefry2x32`` implementation. + + +Stateful Random Number Generation +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Using the Pallas PRNG in stateful mode is the most native and efficient method for generative random numbers. First, the PRNG seed should be set using ``pltpu.prng_seed(N)``, where N is an integer seed. + +Afterwards, you can call any number of stateful sampling functions which are equivalent to the corresponding JAX version but lack the ``key`` argument: + +* ``pltpu.stateful_uniform``: the stateful equivalent to :func:`jax.random.uniform` +* ``pltpu.stateful_normal``: the stateful equivalent to :func:`jax.random.normal` +* ``pltpu.stateful_bernoulli``: the stateful equivalent to :func:`jax.random.bernoulli` + +Generating any random number updates the internal state of the PRNG and subsequent calls will generate different numbers. Unlike in JAX, there is no need to ``split`` or ``fold_in`` keys and pass them into the sampling functions. + +For example, the following kernel generates a set of uniform numbers from 0 to 1: + +.. code-block:: python + + from jax.experimental.pallas import tpu as pltpu + + def kernel_body(o_ref): + pltpu.prng_seed(0) + o_ref[...] = pltpu.stateful_uniform(shape=o_ref.shape, minval=0.0, maxval=1.0) + + pl.pallas_call(kernel_body, + out_shape=jax.ShapeDtypeStruct((256, 256), jnp.float32)) + +Note that in kernels with a grid, the seed should only be set on the first iteration, or else the random numbers generated in each program instance will be identical due to resetting the seed. + +Stateless Generation +^^^^^^^^^^^^^^^^^^^^ + +Pallas offers an intermediate API between the stateless API described previously and the stateless ``jax.random`` API and allows you to use the hardware PRNG in a stateless manner. In order to do so, convert a JAX key into a special Pallas-typed key via ``pltpu.to_pallas_key(key)`` and pass this key into the kernel via SMEM. Once the key is dereferenced inside the kernel, it can be passed into supported sampling functions from ``jax.random`` to produce random numbers. Compared to the stateless API, there is an overhead of computing and setting a seed every time the random number generator is invoked. + +For example, the following kernel draws uniform numbers using the hardware PRNG: + +.. code-block:: python + + def body(key_ref, o_ref): + o_ref[...] = jax.random.uniform( + key_ref[...], shape=o_ref[...].shape + ) + + rbg_key = jax_random.key(0, impl="threefry2x32") + key = pltpu.to_pallas_key(rbg_key) + o_shape = jax.ShapeDtypeStruct((8, 128), dtype) + result = pl.pallas_call( + body, + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], + out_shape=o_shape, + )(key) + +For larger kernels with a grid, :func:`jax.random.fold_in` can be used on the ``program_id`` to generate a unique key for each program instance. + + +Block-invariant sampling +------------------------ + +Block-invariant sampling is a method for generating random numbers in blocks that is invariant to the block sizes and iteration order used. For example, you may wish to generate identical sets of random numbers between two kernels (such as a forwards and backwards pass), but the two kernels may have different block sizes chosen after tuning. + +Pallas providers a helper function (``pltpu.sample_block``) that allows one to guarantee identical random numbers drawn over different block and grid settings. The first step is to select a ``tile_size``, which is a tile that divides all block sizes you wish to be invariant to. For example, ``tile_size=(16, 128)`` would work for block sizes of ``(32, 128)`` and ``(16, 256)``. The larger the tile size, the more efficient the sampling process will be, so the greatest common divisor between all potential block sizes is the best choice. + +Next, call ``pltpu.sample_block`` with the following arguments: + +.. code-block:: python + + pltpu.sample_block( + sampler_function, # A JAX random function, such as `jax.random.uniform`. + global_key, # A global key shared across all blocks. + block_size, # The local block size to generate. + tile_size, # The tile size. + total_size, # The total shape of the generated array across all blocks. + block_index, # The block index into total_size. Usually this is the current program instance. + **sampler_kwargs # Keyword arguments to sampler_function + ) + +For example, the following snippet generates identical numbers over a `(16, 128)` block shape, and a `(32, 256)` block shape with a transposed grid iteration order: + +.. code-block:: python + + def make_kernel_body(index_map): + def body(key_ref, o_ref): + key = key_ref[...] + samples = pltpu.sample_block( + jax.random.uniform, + key, + block_size=o_ref[...].shape, + tile_size=(16, 128), + total_size=(64, 512), + block_index=index_map(pl.program_id(0), pl.program_id(1)), + minval=0.0, + maxval=1.0) + o_ref[...] = samples + return body + + global_key = pltpu.to_pallas_key(jax_random.key(0)) + o_shape = jnp.ones((64, 512), dtype=jnp.float32) + key_spec = pl.BlockSpec(memory_space=pltpu.SMEM) + out_spec = pl.BlockSpec((16, 128), lambda i, j: (i, j)) + result_16x128 = pl.pallas_call( + make_kernel_body(index_map=lambda i, j: (i, j)), + out_shape=o_shape, + in_specs=[key_spec], + out_specs=out_spec, + grid=(4, 4), + )(global_key) + + out_spec = pl.BlockSpec((32, 256), lambda i, j: (j, i)) + result_32x256_transposed = pl.pallas_call( + make_kernel_body(index_map=lambda i, j: (j, i)), + in_specs=[key_spec], + out_shape=o_shape, + out_specs=out_spec, + grid=(2, 2), + )(global_key) + diff --git a/docs/pallas/tpu/sparse.ipynb b/docs/pallas/tpu/sparse.ipynb index ac3a0dad2404..6834f2d7d930 100644 --- a/docs/pallas/tpu/sparse.ipynb +++ b/docs/pallas/tpu/sparse.ipynb @@ -62,7 +62,7 @@ "source": [ "## Dynamic Block Indexing with Scalar Prefetch\n", "\n", - "We will be exploiting the \"scalar prefetch\" feature of Pallas to enable us to write sparse kernels. Scalar prefetch allows you to pass in a small amount of data into SMEM (\"scalar memory\") that is loaded before the start of the pipeline (\"prefetch\"). Because this data is loaded before the pipeline, it is available for use in the `index_map` for each BlockSpec, allowing the you to perform data-dependent indexing calculations. The main goal of this tutorial is to go over common programming patterns that utilize this feature.\n", + "We will be exploiting the \"scalar prefetch\" feature of Pallas to enable us to write sparse kernels. Scalar prefetch allows you to pass in a small amount of data into SMEM (\"scalar memory\") that is loaded before the start of the pipeline (\"prefetch\"). Because this data is loaded before the pipeline, it is available for use in the `index_map` for each BlockSpec, allowing you to perform data-dependent indexing calculations. The main goal of this tutorial is to go over common programming patterns that utilize this feature.\n", "\n", "To use scalar prefetch, use `pltpu.PrefetchScalarGridSpec` in place of the standard `pl.GridSpec`:\n", "\n", @@ -253,13 +253,13 @@ "source": [ "## Example: Sparse @ Dense Matrix Multiplication\n", "\n", - "In our first example, we will multiple a sparse LHS matrix with a dense RHS matrix to produce a dense output.\n", + "In our first example, we will multiply a sparse LHS matrix with a dense RHS matrix to produce a dense output.\n", "\n", "We will structure our kernel grid with 2 loops - the outer loop over the columns of the RHS/output, and inner loop over the sparse blocks of the LHS. During each inner loop iteration, we load one block from the LHS and lookup the corresponding block on in the RHS using the block index of the contracting dimension (K). We multiply the two blocks together and accumulate into the correct output block. One outer loop iteration will compute a result for an entire column as depicted by the following diagram:\n", "\n", "![sparse_matmul](../../_static/pallas/sparse/sparse_matmul.svg)\n", "\n", - "It is important that we group the block indices by row (e.g. `[0, 0, 1, 2, 3, 3]`) before we pass them into the kernel for two reasons. First, in our kernel we need to know when to initially zero-out the accumulator in the output ref, and it is easy to do so if the row indices are grouped. Second, the pipelining logic for Pallas does not allow us to re-visit blocks in the output `Ref` on non-consecutive iterations, and therefore we need to do all accumulation into an output block in consecutive kernel iterations. This is because the pipeline emitter will realize that we loading the same output block on consecutive iterations and keep the block in VMEM. When we change output block Pallas will finally store the output into HBM and assume we never touch it again. Failure to access output blocks consecutively will result in incorrect values even though the kernel is otherwise logically correct." + "It is important that we group the block indices by row (e.g. `[0, 0, 1, 2, 3, 3]`) before we pass them into the kernel for two reasons. First, in our kernel we need to know when to initially zero-out the accumulator in the output ref, and it is easy to do so if the row indices are grouped. Second, the pipelining logic for Pallas does not allow us to re-visit blocks in the output `Ref` on non-consecutive iterations, and therefore we need to do all accumulation into an output block in consecutive kernel iterations. This is because the pipeline emitter will realize that we are loading the same output block on consecutive iterations and keep the block in VMEM. When we change output block Pallas will finally store the output into HBM and assume we never touch it again. Failure to access output blocks consecutively will result in incorrect values even though the kernel is otherwise logically correct." ] }, { @@ -437,7 +437,7 @@ "\n", "In our previous example we considered the case when the data itself is sparse. This manifested itself in the kernel structure as a dimension in the kernel grid that was dynamic and looped over the number of nonzero blocks (`num_blocks`).\n", "\n", - "A second useful programming pattern emerges when the underlying is data is dense, but we wish to perform sparse computation over it. Our kernel grid in this case will be dense, but we wish to skip over some blocks in the grid as indicated by a block-sparse mask. This type of programming pattern is commonly arises when using masks in many machine learning applications, such as causal or local masks in self-attention. In these cases, we can entirely skip over computation in blocks where the mask is zeroed-out. Examples of this programming pattern can be found in the Splash Attention and Grouped Matrix Multiplication kernels located in `jax/experimental/pallas/ops/tpu`, or in PyTorch's [FlexAttention](https://pytorch.org/blog/flexattention/).\n", + "A second useful programming pattern emerges when the underlying data is dense, but we wish to perform sparse computation over it. Our kernel grid in this case will be dense, but we wish to skip over some blocks in the grid as indicated by a block-sparse mask. This type of programming pattern commonly arises when using masks in many machine learning applications, such as causal or local masks in self-attention. In these cases, we can entirely skip over computation in blocks where the mask is zeroed-out. Examples of this programming pattern can be found in the Splash Attention and Grouped Matrix Multiplication kernels located in `jax/experimental/pallas/ops/tpu`, or in PyTorch's [FlexAttention](https://pytorch.org/blog/flexattention/).\n", "\n", "The main performance consideration with dealing with a sparse access pattern on dense data is the interaction with pipelining. On any given kernel iteration, the Pallas pipeline emitter will attempt to prefetch the next block of data by calling the `index_map` for each `BlockSpec` on the next iteration of the grid. However, if our computation is sparse we may be skipping the computation for the next block in the grid, so we need some method to tell the pipeline instead begin fetching the *next block that we are not skipping*. In order to do this, we need to construct *prefetch maps* which contains indices to the next non-skipped block of data for each kernel input. The following diagram illustrates how a prefetch map could be constructed for a block-sparse mask that is stored in a COO-like format.\n", "\n", @@ -491,7 +491,7 @@ "source": [ "def sparsify_mask(mask: jax.Array,\n", " block_shape: tuple[int, int]):\n", - " \"\"\"Preprocesses a mask into a sparse reprentation.\n", + " \"\"\"Preprocesses a mask into a sparse representation.\n", "\n", " Args:\n", " mask: A boolean array of shape [M, N]\n", @@ -511,7 +511,6 @@ " block_mask = jnp.zeros((M // bm, N // bn), dtype=mask.dtype)\n", " mask_types_finder = []\n", " mask_data = []\n", - " mask_type_idxs = []\n", "\n", " next_mask_type_idx = 0\n", " prefetch_mask = jnp.zeros_like(block_mask)\n", @@ -536,7 +535,6 @@ " next_j = j\n", " else:\n", " type_index = -1\n", - " mask_type_idxs.append(type_index)\n", " block_mask = block_mask.at[i, j].set(is_nonzero)\n", " prefetch_mask = prefetch_mask.at[i, j].set(next_mask_type_idx)\n", " prefetch_i = prefetch_i.at[i, j].set(next_i)\n", @@ -665,7 +663,7 @@ "\n", "We would generally expect performance to get closer to the theoretical peak as our inputs get larger, since a few of the main reasons why we don't exactly reach theoretical performance are:\n", "- We skip slightly less than half of computation since the blocks along the diagonal are mixed 0s and 1s, and for mixed blocks we need to compute the entire block. With larger inputs, our overhead for mixed blocks becomes smaller relative to the overall computation.\n", - "- The pipeline bubble also becomes accounts for a less percentage of the overall runtime as inputs become larger." + "- The pipeline bubble also accounts for a less percentage of the overall runtime as inputs become larger." ] }, { diff --git a/docs/pallas/tpu/sparse.md b/docs/pallas/tpu/sparse.md index 113f31d8bab2..e9a4bb143a2f 100644 --- a/docs/pallas/tpu/sparse.md +++ b/docs/pallas/tpu/sparse.md @@ -51,7 +51,7 @@ print("Running on", jax.devices()[0].device_kind) ## Dynamic Block Indexing with Scalar Prefetch -We will be exploiting the "scalar prefetch" feature of Pallas to enable us to write sparse kernels. Scalar prefetch allows you to pass in a small amount of data into SMEM ("scalar memory") that is loaded before the start of the pipeline ("prefetch"). Because this data is loaded before the pipeline, it is available for use in the `index_map` for each BlockSpec, allowing the you to perform data-dependent indexing calculations. The main goal of this tutorial is to go over common programming patterns that utilize this feature. +We will be exploiting the "scalar prefetch" feature of Pallas to enable us to write sparse kernels. Scalar prefetch allows you to pass in a small amount of data into SMEM ("scalar memory") that is loaded before the start of the pipeline ("prefetch"). Because this data is loaded before the pipeline, it is available for use in the `index_map` for each BlockSpec, allowing you to perform data-dependent indexing calculations. The main goal of this tutorial is to go over common programming patterns that utilize this feature. To use scalar prefetch, use `pltpu.PrefetchScalarGridSpec` in place of the standard `pl.GridSpec`: @@ -208,13 +208,13 @@ def generate_block_sparse_mat(key, M, N, blk_M, blk_N, p=0.2, dtype=jnp.float32) ## Example: Sparse @ Dense Matrix Multiplication -In our first example, we will multiple a sparse LHS matrix with a dense RHS matrix to produce a dense output. +In our first example, we will multiply a sparse LHS matrix with a dense RHS matrix to produce a dense output. We will structure our kernel grid with 2 loops - the outer loop over the columns of the RHS/output, and inner loop over the sparse blocks of the LHS. During each inner loop iteration, we load one block from the LHS and lookup the corresponding block on in the RHS using the block index of the contracting dimension (K). We multiply the two blocks together and accumulate into the correct output block. One outer loop iteration will compute a result for an entire column as depicted by the following diagram: ![sparse_matmul](../../_static/pallas/sparse/sparse_matmul.svg) -It is important that we group the block indices by row (e.g. `[0, 0, 1, 2, 3, 3]`) before we pass them into the kernel for two reasons. First, in our kernel we need to know when to initially zero-out the accumulator in the output ref, and it is easy to do so if the row indices are grouped. Second, the pipelining logic for Pallas does not allow us to re-visit blocks in the output `Ref` on non-consecutive iterations, and therefore we need to do all accumulation into an output block in consecutive kernel iterations. This is because the pipeline emitter will realize that we loading the same output block on consecutive iterations and keep the block in VMEM. When we change output block Pallas will finally store the output into HBM and assume we never touch it again. Failure to access output blocks consecutively will result in incorrect values even though the kernel is otherwise logically correct. +It is important that we group the block indices by row (e.g. `[0, 0, 1, 2, 3, 3]`) before we pass them into the kernel for two reasons. First, in our kernel we need to know when to initially zero-out the accumulator in the output ref, and it is easy to do so if the row indices are grouped. Second, the pipelining logic for Pallas does not allow us to re-visit blocks in the output `Ref` on non-consecutive iterations, and therefore we need to do all accumulation into an output block in consecutive kernel iterations. This is because the pipeline emitter will realize that we are loading the same output block on consecutive iterations and keep the block in VMEM. When we change output block Pallas will finally store the output into HBM and assume we never touch it again. Failure to access output blocks consecutively will result in incorrect values even though the kernel is otherwise logically correct. ```{code-cell} --- @@ -353,7 +353,7 @@ print("Reference: %.3f ms (avg over %d trials)" % (time * 1000, n_trials)) In our previous example we considered the case when the data itself is sparse. This manifested itself in the kernel structure as a dimension in the kernel grid that was dynamic and looped over the number of nonzero blocks (`num_blocks`). -A second useful programming pattern emerges when the underlying is data is dense, but we wish to perform sparse computation over it. Our kernel grid in this case will be dense, but we wish to skip over some blocks in the grid as indicated by a block-sparse mask. This type of programming pattern is commonly arises when using masks in many machine learning applications, such as causal or local masks in self-attention. In these cases, we can entirely skip over computation in blocks where the mask is zeroed-out. Examples of this programming pattern can be found in the Splash Attention and Grouped Matrix Multiplication kernels located in `jax/experimental/pallas/ops/tpu`, or in PyTorch's [FlexAttention](https://pytorch.org/blog/flexattention/). +A second useful programming pattern emerges when the underlying data is dense, but we wish to perform sparse computation over it. Our kernel grid in this case will be dense, but we wish to skip over some blocks in the grid as indicated by a block-sparse mask. This type of programming pattern commonly arises when using masks in many machine learning applications, such as causal or local masks in self-attention. In these cases, we can entirely skip over computation in blocks where the mask is zeroed-out. Examples of this programming pattern can be found in the Splash Attention and Grouped Matrix Multiplication kernels located in `jax/experimental/pallas/ops/tpu`, or in PyTorch's [FlexAttention](https://pytorch.org/blog/flexattention/). The main performance consideration with dealing with a sparse access pattern on dense data is the interaction with pipelining. On any given kernel iteration, the Pallas pipeline emitter will attempt to prefetch the next block of data by calling the `index_map` for each `BlockSpec` on the next iteration of the grid. However, if our computation is sparse we may be skipping the computation for the next block in the grid, so we need some method to tell the pipeline instead begin fetching the *next block that we are not skipping*. In order to do this, we need to construct *prefetch maps* which contains indices to the next non-skipped block of data for each kernel input. The following diagram illustrates how a prefetch map could be constructed for a block-sparse mask that is stored in a COO-like format. @@ -391,7 +391,7 @@ As we will be working with a sparse mask, we will begin by implementing a functi def sparsify_mask(mask: jax.Array, block_shape: tuple[int, int]): - """Preprocesses a mask into a sparse reprentation. + """Preprocesses a mask into a sparse representation. Args: mask: A boolean array of shape [M, N] @@ -411,7 +411,6 @@ def sparsify_mask(mask: jax.Array, block_mask = jnp.zeros((M // bm, N // bn), dtype=mask.dtype) mask_types_finder = [] mask_data = [] - mask_type_idxs = [] next_mask_type_idx = 0 prefetch_mask = jnp.zeros_like(block_mask) @@ -436,7 +435,6 @@ def sparsify_mask(mask: jax.Array, next_j = j else: type_index = -1 - mask_type_idxs.append(type_index) block_mask = block_mask.at[i, j].set(is_nonzero) prefetch_mask = prefetch_mask.at[i, j].set(next_mask_type_idx) prefetch_i = prefetch_i.at[i, j].set(next_i) @@ -542,7 +540,7 @@ Now let's compare performance versus a naive dense implementation. On TPU v5e, w We would generally expect performance to get closer to the theoretical peak as our inputs get larger, since a few of the main reasons why we don't exactly reach theoretical performance are: - We skip slightly less than half of computation since the blocks along the diagonal are mixed 0s and 1s, and for mixed blocks we need to compute the entire block. With larger inputs, our overhead for mixed blocks becomes smaller relative to the overall computation. -- The pipeline bubble also becomes accounts for a less percentage of the overall runtime as inputs become larger. +- The pipeline bubble also accounts for a less percentage of the overall runtime as inputs become larger. ```{code-cell} --- diff --git a/docs/persistent_compilation_cache.md b/docs/persistent_compilation_cache.md index 0a5a89abe26d..6e82d995b782 100644 --- a/docs/persistent_compilation_cache.md +++ b/docs/persistent_compilation_cache.md @@ -132,6 +132,34 @@ Cloud Storage (GCS) bucket. We recommend the following configuration: * All encryption policies are supported. +It is **recommended** to use +[Google Cloud Storage Fuse](https://cloud.google.com/storage/docs/cloud-storage-fuse) +to mount the GCS bucket as a local directory. This is because when running JAX +in a multi-node setup, multiple nodes might try to write to the cache +simultaneously, leading to GCS rate-limit errors. GCSFuse handles this by +ensuring that only one process can write to a file at a time, preventing these +errors. + +To set up GCSFuse, follow instructions for +[GCE](https://cloud.google.com/storage/docs/cloud-storage-fuse/mount-bucket) or +[GKE](https://cloud.google.com/kubernetes-engine/docs/how-to/cloud-storage-fuse-csi-driver-setup). +For better performance, enable file caching +([GCE](https://cloud.google.com/storage/docs/cloud-storage-fuse/file-caching) and +[GKE](https://cloud.google.com/kubernetes-engine/docs/how-to/cloud-storage-fuse-csi-driver-perf#enable-and-use-file-caching)). + +Once GCSFuse is configured, set the JAX cache directory to the GCSFuse mount +point: + +```python +# Example assuming the GCS bucket is mounted at /gcs/my-bucket +jax.config.update("jax_compilation_cache_dir", "/gcs/my-bucket/jax-cache") +``` + +**Direct GCS access :** + +If you choose not to use GCSFuse, you can point the cache directly to a GCS +bucket. + Assuming that `gs://jax-cache` is the GCS bucket, set cache location as follows: @@ -234,7 +262,7 @@ jax.config.update("jax_explain_cache_misses", True) There are a couple of pitfalls that have currently been discovered: -* Currently the persistent cache doesn't work with function that have host callbacks. In this situation, caching in completely avoided. +* Currently the persistent cache doesn't work with function that have host callbacks. In this situation, caching is completely avoided. - This is because the HLO contains a pointer to the callback and changes from run to run even if the computation and compute infrastructure is exactly the same. * Currently the persistent cache doesn't work with a function that uses primitives that implement their own custom_partitioning. @@ -260,14 +288,13 @@ If we were to merely compile this function without shard_map, the cache key for layernorm_matmul_without_shard_map = jax.jit(F, in_shardings=(...), out_sharding=(...))(x1, x2, gamma, beta) ``` -However, if we were to wrap the layernorm primitive in shard_map and define a function G that performs the same computation, the cache key for `layernorm_matmul_with_shard_map` will be the same everytime despite `LayerNorm` being implementing `custom_partitioning`: +However, if we were to wrap the layernorm primitive in shard_map and define a function G that performs the same computation, the cache key for `layernorm_matmul_with_shard_map` will be the same every time despite `LayerNorm` being implementing `custom_partitioning`: ```python import jax -from jax.experimental.shard_map import shard_map def G(x1, x2, gamma, beta, mesh, ispecs, ospecs): - ln_out = shard_map(LayerNorm, mesh, in_specs=ispecs, out_specs=ospecs, check_rep=False)(x1, x2, gamma, beta) + ln_out = jax.shard_map(LayerNorm, mesh=mesh, in_specs=ispecs, out_specs=ospecs, check_vma=False)(x1, x2, gamma, beta) return ln_out @ x2 ispecs = jax.sharding.PartitionSpec(...) diff --git a/docs/profiling.md b/docs/profiling.md index ac992b3a05da..8d9a9e37190b 100644 --- a/docs/profiling.md +++ b/docs/profiling.md @@ -1,6 +1,6 @@ # Profiling computation - + ## Viewing program traces with Perfetto @@ -8,7 +8,7 @@ We can use the JAX profiler to generate traces of a JAX program that can be visualized using the [Perfetto visualizer](https://ui.perfetto.dev). Currently, this method blocks the program until a link is clicked and the Perfetto UI loads the trace. If you wish to get profiling information without any interaction, -check out the Tensorboard profiler below. +check out the XProf profiler below. ```python with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True): @@ -64,55 +64,53 @@ Also, by default, the program will prompt you to open a link to file and open a visualizer. This feature is disabled by passing in `--no_perfetto_link` into the command. Alternatively, you can also point Tensorboard to the `log_dir` to analyze the trace (see the -"Tensorboard Profiling" section below). +"XProf (Tensorboard Profiling)" section below). -(tensorboard-profiling)= -## TensorBoard profiling +(xprof-profiling)= +## XProf (TensorBoard profiling) -[TensorBoard's -profiler](https://www.tensorflow.org/tensorboard/tensorboard_profiling_keras) -can be used to profile JAX programs. Tensorboard is a great way to acquire and +[XProf](https://openxla.org/xprof) +can be used to profile JAX programs. XProf is a great way to acquire and visualize performance traces and profiles of your program, including activity on GPU and TPU. The end result looks something like this: -![TensorBoard profiler example](_static/tensorboard_profiler.png) +![XProf example](_static/tensorboard_profiler.png) ### Installation -The TensorBoard profiler is only available with the version of TensorBoard -bundled with TensorFlow. - +XProf is available as a plugin to TensorBoard, as well as an independently +run program. ```shell -pip install tensorflow tensorboard-plugin-profile +pip install xprof ``` -If you already have TensorFlow installed, you only need to install the -`tensorboard-plugin-profile` pip package. Be careful to only install one version -of TensorFlow or TensorBoard, otherwise you may encounter the "duplicate -plugins" error described {ref}`below `. See +If you have TensorBoard installed, the `xprof` pip package will also install +the TensorBoard Profiler plugin. Be careful to only install one version of +TensorFlow or TensorBoard, otherwise you may encounter the "duplicate plugins" +error described {ref}`below `. See for more information on installing TensorBoard. -Nightly version of TensorBoard profiler requires nightly tensorflow and -tensorboard +Profiling with the nightly version of TensorBoard requires the nightly +XProf. ```shell -pip install tf-nightly tb-nightly tbp-nightly +pip install tb-nightly xprof-nightly ``` ### Programmatic capture You can instrument your code to capture a profiler trace via the -{func}`jax.profiler.start_trace` and {func}`jax.profiler.stop_trace` -methods. Call {func}`~jax.profiler.start_trace` with the directory to write -trace files to. This should be the same `--logdir` directory used to start -TensorBoard. Then, you can use TensorBoard to view the traces. +{func}`jax.profiler.start_trace` and {func}`jax.profiler.stop_trace` methods. +Call {func}`~jax.profiler.start_trace` with the directory to write trace files +to. This should be the same `--logdir` directory used to start XProf. +Then, you can XProf to view the traces. For example, to take a profiler trace: ```python import jax -jax.profiler.start_trace("/tmp/tensorboard") +jax.profiler.start_trace("/tmp/profile-data") # Run the operations to be profiled key = jax.random.key(0) @@ -133,49 +131,51 @@ alternative to `start_trace` and `stop_trace`: ```python import jax -with jax.profiler.trace("/tmp/tensorboard"): +with jax.profiler.trace("/tmp/profile-data"): key = jax.random.key(0) x = jax.random.normal(key, (5000, 5000)) y = x @ x y.block_until_ready() ``` -To view the trace, first start TensorBoard if you haven't already: +### Viewing the trace + +After capturing a trace, you can view it using the XProf UI. + +You can launch the profiler UI directly using the standalone XProf command by +pointing it to your log directory: ```shell -$ tensorboard --logdir=/tmp/tensorboard -[...] -Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all -TensorBoard 2.5.0 at http://localhost:6006/ (Press CTRL+C to quit) +$ xprof --port 8791 /tmp/profile-data +Attempting to start XProf server: + Log Directory: /tmp/profile-data + Port: 8791 +XProf at http://localhost:8791/ (Press CTRL+C to quit) ``` -You should be able to load TensorBoard at in this -example. You can specify a different port with the `--port` flag. See -{ref}`remote_profiling` below if running JAX on a remote server. - -Then, either select "Profile" in the upper-right dropdown menu, or go directly -to . Available traces appear in the "Runs" -dropdown menu on the left. Select the run you're interested in, and then under -"Tools", select `trace_viewer`. You should now see a timeline of the -execution. You can use the WASD keys to navigate the trace, and click or drag to -select events to see more details at the bottom. See [these TensorFlow -docs](https://www.tensorflow.org/tensorboard/tensorboard_profiling_keras#use_the_tensorflow_profiler_to_profile_model_training_performance) -for more details on using the trace viewer. +Navigate to the provided URL (e.g., http://localhost:8791/) in your browser +to view the profile. -You can also use the `memory_viewer`, `op_profile`, and `graph_viewer` tools. +Available traces appear in the "Runs" dropdown menu on the left. Select the +run you're interested in, and then under the "Tools" dropdown, select +trace_viewer. You should now see a timeline of the execution. You can use the +WASD keys to navigate the trace, and click or drag to select events for more +details. See +[these TensorFlow docs](https://www.tensorflow.org/tensorboard/tensorboard_profiling_keras#use_the_tensorflow_profiler_to_profile_model_training_performance) +for more details on using the trace viewer. -### Manual capture via TensorBoard +### Manual capture via XProf The following are instructions for capturing a manually-triggered N-second trace from a running program. -1. Start a TensorBoard server: +1. Start an XProf server: ```shell - tensorboard --logdir /tmp/tensorboard/ + xprof --logdir /tmp/profile-data/ ``` - You should be able to load TensorBoard at . You can + You should be able to load XProf at . You can specify a different port with the `--port` flag. See {ref}`remote_profiling` below if running JAX on a remote server.

@@ -187,7 +187,7 @@ from a running program. jax.profiler.start_server(9999) ``` - This starts the profiler server that TensorBoard connects to. The profiler + This starts the profiler server that XProf connects to. The profiler server must be running before you move on to the next step. When you're done using the server, you can call `jax.profiler.stop_server()` to shut it down. @@ -200,7 +200,7 @@ from a running program. beginning of the program and use `time.sleep()` to give you enough time to start the capture.

-1. Open , and click the "CAPTURE PROFILE" button +1. Open , and click the "CAPTURE PROFILE" button in the upper left. Enter "localhost:9999" as the profile service URL (this is the address of the profiler server you started in the previous step). Enter the number of milliseconds you'd like to profile for, and click "CAPTURE".

-1. After the capture finishes, TensorBoard should automatically refresh. (Not - all of the TensorBoard profiling features are hooked up with JAX, so it may +1. After the capture finishes, XProf should automatically refresh. (Not + all of the XProf profiling features are hooked up with JAX, so it may initially look like nothing was captured.) On the left under "Tools", select `trace_viewer`. - You should now see a timeline of the execution. You can use the WASD keys to - navigate the trace, and click or drag to select events to see more details at - the bottom. See [these TensorFlow - docs](https://www.tensorflow.org/tensorboard/tensorboard_profiling_keras#use_the_tensorflow_profiler_to_profile_model_training_performance) - for more details on using the trace viewer. +You should now see a timeline of the execution. You can use the WASD keys to +navigate the trace, and click or drag to select events to see more details at +the bottom. See [these XProf docs](https://openxla.org/xprof/trace_viewer) +for more details on using the trace viewer. + +You can also use the following tools: - You can also use the `memory_viewer`, `op_profile`, and `graph_viewer` - tools.

+- [Framework Op Stats](https://openxla.org/xprof/framework_op_stats) +- [Graph Viewer](https://openxla.org/xprof/graph_viewer) +- [HLO Op Stats](https://openxla.org/xprof/hlo_op_stats) +- [Memory Profile](https://openxla.org/xprof/memory_profile) +- [Memory Viewer](https://openxla.org/xprof/memory_viewer) +- [HLO Op Profile](https://openxla.org/xprof/hlo_op_profile) +- [Roofline Model](https://openxla.org/xprof/roofline_analysis)

+ +### XProf and Tensorboard + +XProf is the underlying tool that powers the profiling and trace capturing +functionality in Tensorboard. As long as `xprof` is installed, a "Profile" tab +will be present within Tensorboard. Using this is identical to launching XProf +independently, as long as it is launched pointing to the same log directory. +This includes profile capture, analysis, and viewing functionality. XProf +supplants the `tensorboard_plugin_profile` functionality that was previously +recommended. + +```shell +$ tensorboard --logdir=/tmp/profile-data +[...] +Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all +TensorBoard 2.19.0 at http://localhost:6006/ (Press CTRL+C to quit) +``` ### Adding custom trace events @@ -231,6 +254,135 @@ functions. You can add your own events and functions by using {class}`jax.profiler.TraceAnnotation` and {func}`jax.profiler.annotate_function` in your code. +### Configuring profiler options + +The `start_trace` method accepts an optional `profiler_options` parameter, which +allows for fine-grained control over the profiler's behavior. This parameter +should be an instance of `jax.profiler.ProfileOptions`. + + +For example, to disable all python and host traces: + +```python +import jax + +options = jax.profiler.ProfileOptions() +options.python_tracer_level = 0 +options.host_tracer_level = 0 +jax.profiler.start_trace("/tmp/profile-data", profiler_options=options) + +# Run the operations to be profiled +key = jax.random.key(0) +x = jax.random.normal(key, (5000, 5000)) +y = x @ x +y.block_until_ready() + +jax.profiler.stop_trace() +``` + +#### General options + +1. `host_tracer_level`: Sets the trace level for host-side activities. + + Supported Values: + + `0`: Disables host (CPU) tracing entirely. + + `1`: Enables tracing of only user-instrumented TraceMe events. + + `2`: Includes level 1 traces plus high-level program execution details like + expensive XLA operations (default). + + `3`: Includes level 2 traces plus more verbose, low-level program execution + details such as cheap XLA operations. + +2. `device_tracer_level`: Controls whether device tracing is enabled. + + Supported Values: + + `0`: Disables device tracing. + + `1`: Enables device tracing (default). + +3. `python_tracer_level`: Controls whether Python tracing is enabled. + + Supported Values: + + `0`: Disables Python function call tracing (default). + + `1`: Enables Python tracing. + +#### Advanced configuration options + +##### TPU options + +1. `tpu_trace_mode`: Specifies the mode for TPU tracing. + + Supported Values: + + `TRACE_ONLY_HOST`: This means only host-side (CPU) activities are traced, + and no device (TPU/GPU) traces are collected. + + `TRACE_ONLY_XLA`: This means only XLA-level operations on the device are + traced. + + `TRACE_COMPUTE`: This traces compute operations on the device. + + `TRACE_COMPUTE_AND_SYNC`: This traces both compute operations and + synchronization events on the device. + + If "tpu_trace_mode" is not provided the trace_mode defaults to + TRACE_ONLY_XLA. + +2. `tpu_num_sparse_cores_to_trace`: Specifies the number of sparse cores to + trace on the TPU. +3. `tpu_num_sparse_core_tiles_to_trace`: Specifies the number of tiles within + each sparse core to trace on the TPU. +4. `tpu_num_chips_to_profile_per_task`: Specifies the number of TPU chips to + profile per task. + +##### GPU options + +The following options are available for GPU profiling: + +* `gpu_max_callback_api_events`: Sets the maximum number of events collected + by the CUPTI callback API. Defaults to `2*1024*1024`. +* `gpu_max_activity_api_events`: Sets the maximum number of events collected + by the CUPTI activity API. Defaults to `2*1024*1024`. +* `gpu_max_annotation_strings`: Sets the maximum number of annotation + strings that can be collected. Defaults to `1024*1024`. +* `gpu_enable_nvtx_tracking`: Enables NVTX tracking in CUPTI. Defaults to + `False`. +* `gpu_enable_cupti_activity_graph_trace`: Enables CUPTI activity graph + tracing for CUDA graphs. Defaults to `False`. +* `gpu_pm_sample_counters`: A comma-separated string of GPU + Performance Monitoring metrics to collect using CUPTI's PM sampling feature + (e.g. `"sm__cycles_active.avg.pct_of_peak_sustained_elapsed"`). PM sampling + is disabled by default. For available metrics, see + [NVIDIA's CUPTI documentation](https://docs.nvidia.com/cupti/main/main.html#metrics-table). +* `gpu_pm_sample_interval_us`: Sets the sampling interval in microseconds + for CUPTI PM sampling. Defaults to `500`. +* `gpu_pm_sample_buffer_size_per_gpu_mb`: Sets the system memory buffer size + per device in MB for CUPTI PM sampling. Defaults to 64MB. The maximum + supported value is 4GB. +* `gpu_num_chips_to_profile_per_task`: Specifies the number of GPU devices to + profile per task. If not specified, set to 0, or set to an invalid value, + all available GPUs will be profiled. This can be used to decrease the trace + collection size. +* `gpu_dump_graph_node_mapping`: If enabled, dumps CUDA graph node + mapping information into the trace. Defaults to `False`. + +For example: + +``` +options = ProfileOptions() +options.advanced_configuration = {"tpu_trace_mode" : "TRACE_ONLY_HOST", "tpu_num_sparse_cores_to_trace" : 2} + +``` + +Returns InvalidArgumentError if any unrecognized keys or option values are +found. + ### Troubleshooting #### GPU profiling @@ -308,8 +460,8 @@ replace, so it may be necessary to uninstall everything and reinstall a single version: ```shell -pip uninstall tensorflow tf-nightly tensorboard tb-nightly -pip install tensorflow +pip uninstall tensorflow tf-nightly tensorboard tb-nightly xprof xprof-nightly tensorboard-plugin-profile tbp-nightly +pip install tensorboard xprof ``` ## Nsight diff --git a/docs/pytrees.md b/docs/pytrees.md index a39c36db5de6..cde86ce1323f 100644 --- a/docs/pytrees.md +++ b/docs/pytrees.md @@ -1,338 +1,325 @@ --- jupytext: + formats: md:myst text_representation: extension: .md format_name: myst + format_version: 0.13 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python name: python3 -language_info: - name: python - file_extension: .py --- -(pytrees)= +```{code-cell} +:tags: [remove-cell] + +# This ensures that code cell tracebacks appearing below will be concise. +%xmode minimal +``` +(pytrees)= +(working-with-pytrees)= # Pytrees - + + +JAX has built-in support for objects that look like dictionaries (dicts) of arrays, or lists of lists of dicts, or other nested structures — in JAX these are called pytrees. +This section will explain how to use them, provide useful code examples, and point out common "gotchas" and patterns. +For an explanation of how to create custom pytrees, see {doc}`custom_pytrees`. + +(pytrees-what-is-a-pytree)= ## What is a pytree? -In JAX, we use the term *pytree* to refer to a tree-like structure built out of -container-like Python objects. Classes are considered container-like if they -are in the pytree registry, which by default includes lists, tuples, and dicts. -That is: +A pytree is a container-like structure built out of container-like Python objects — “leaf” pytrees and/or more pytrees. A pytree can include lists, tuples, and dicts. A leaf is anything that’s not a pytree, such as an array, but a single leaf is also a pytree. -1. any object whose type is *not* in the pytree container registry is - considered a *leaf* pytree; -2. any object whose type is in the pytree container registry, and which - contains pytrees, is considered a pytree. +In the context of machine learning (ML), a pytree can contain: -For each entry in the pytree container registry, a container-like type is -registered with a pair of functions that specify how to convert an instance of -the container type to a `(children, metadata)` pair and how to convert such a -pair back to an instance of the container type. Using these functions, JAX can -canonicalize any tree of registered container objects into tuples. +- Model parameters +- Dataset entries +- Reinforcement learning agent observations -Example pytrees: +When working with datasets, you can often come across pytrees (such as lists of lists of dicts). -``` -[1, "a", object()] # 3 leaves +Below is an example of a simple pytree. In JAX, you can use {func}`jax.tree.leaves`, to extract the flattened leaves from the trees, as demonstrated here: -(1, (2, 3), ()) # 3 leaves +```{code-cell} +import jax +import jax.numpy as jnp -[1, {"k1": 2, "k2": (3, 4)}, 5] # 5 leaves +example_trees = [ + [1, 'a', object()], + (1, (2, 3), ()), + [1, {'k1': 2, 'k2': (3, 4)}, 5], + {'a': 2, 'b': (2, 3)}, + jnp.array([1, 2, 3]), +] + +# Print how many leaves the pytrees have. +for pytree in example_trees: + # This `jax.tree.leaves()` method extracts the flattened leaves from the pytrees. + leaves = jax.tree.leaves(pytree) + print(f"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}") ``` -JAX can be extended to consider other container types as pytrees; see -{ref}`extending-pytrees` below. +Any tree-like structure built out of container-like Python objects can be treated as a pytree in JAX. +Classes are considered container-like if they are in the pytree registry, which by default includes lists, tuples, and dicts. Any object whose type is *not* in the pytree container registry will be treated as a leaf node in the tree. -## Pytrees and JAX functions +The pytree registry can be extended to include user-defined container classes by registering the class +with functions that specify how to flatten the tree; see {ref}`pytrees-custom-pytree-nodes` below. -Many JAX functions, like {func}`jax.lax.scan`, operate over pytrees of arrays. -JAX function transformations can be applied to functions that accept as input -and produce as output pytrees of arrays. +(pytrees-common-pytree-functions)= +## Common pytree functions -## Applying optional parameters to pytrees +JAX provides a number of utilities to operate over pytrees. These can be found in the {mod}`jax.tree_util` subpackage; +for convenience many of these have aliases in the {mod}`jax.tree` module. -Some JAX function transformations take optional parameters that specify how -certain input or output values should be treated (e.g. the `in_axes` and -`out_axes` arguments to {func}`~jax.vmap`). These parameters can also be pytrees, -and their structure must correspond to the pytree structure of the corresponding -arguments. In particular, to be able to "match up" leaves in these parameter -pytrees with values in the argument pytrees, the parameter pytrees are often -constrained to be tree prefixes of the argument pytrees. +### Common function: `jax.tree.map` -For example, if we pass the following input to {func}`~jax.vmap` (note that the input -arguments to a function are considered a tuple): +The most commonly used pytree function is {func}`jax.tree.map`. It works analogously to Python's native `map`, but transparently operates over entire pytrees. -``` -(a1, {"k1": a2, "k2": a3}) -``` +Here's an example: -We can use the following `in_axes` pytree to specify that only the `k2` -argument is mapped (`axis=0`) and the rest aren't mapped over -(`axis=None`): +```{code-cell} +list_of_lists = [ + [1, 2, 3], + [1, 2], + [1, 2, 3, 4] +] -``` -(None, {"k1": None, "k2": 0}) +jax.tree.map(lambda x: x*2, list_of_lists) ``` -The optional parameter pytree structure must match that of the main input -pytree. However, the optional parameters can optionally be specified as a -"prefix" pytree, meaning that a single leaf value can be applied to an entire -sub-pytree. For example, if we have the same {func}`~jax.vmap` input as above, -but wish to only map over the dictionary argument, we can use: +{func}`jax.tree.map` also allows mapping a [N-ary](https://en.wikipedia.org/wiki/N-ary) function over multiple arguments. For example: +```{code-cell} +another_list_of_lists = list_of_lists +jax.tree.map(lambda x, y: x+y, list_of_lists, another_list_of_lists) ``` -(None, 0) # equivalent to (None, {"k1": 0, "k2": 0}) -``` -Or, if we want every argument to be mapped, we can simply write a single leaf -value that is applied over the entire argument tuple pytree: +When using multiple arguments with {func}`jax.tree.map`, the structure of the inputs must exactly match. That is, lists must have the same number of elements, dicts must have the same keys, etc. + +(pytrees-example-jax-tree-map-ml)= +### Example of `jax.tree.map` with ML model parameters + +This example demonstrates how pytree operations can be useful when training a simple [multi-layer perceptron (MLP)](https://en.wikipedia.org/wiki/Multilayer_perceptron). + +Begin with defining the initial model parameters: +```{code-cell} +import numpy as np + +def init_mlp_params(layer_widths): + params = [] + for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]): + params.append( + dict(weights=np.random.normal(size=(n_in, n_out)) * np.sqrt(2/n_in), + biases=np.ones(shape=(n_out,)) + ) + ) + return params + +params = init_mlp_params([1, 128, 128, 1]) ``` -0 + +Use {func}`jax.tree.map` to check the shapes of the initial parameters: + +```{code-cell} +jax.tree.map(lambda x: x.shape, params) ``` -This happens to be the default `in_axes` value for {func}`~jax.vmap`! +Next, define the functions for training the MLP model: -The same logic applies to other optional parameters that refer to specific input -or output values of a transformed function, e.g. `vmap`'s `out_axes`. +```{code-cell} +# Define the forward pass. +def forward(params, x): + *hidden, last = params + for layer in hidden: + x = jax.nn.relu(x @ layer['weights'] + layer['biases']) + return x @ last['weights'] + last['biases'] + +# Define the loss function. +def loss_fn(params, x, y): + return jnp.mean((forward(params, x) - y) ** 2) + +# Set the learning rate. +LEARNING_RATE = 0.0001 + +# Using the stochastic gradient descent, define the parameter update function. +# Apply `@jax.jit` for JIT compilation (speed). +@jax.jit +def update(params, x, y): + # Calculate the gradients with `jax.grad`. + grads = jax.grad(loss_fn)(params, x, y) + # Note that `grads` is a pytree with the same structure as `params`. + # `jax.grad` is one of many JAX functions that has + # built-in support for pytrees. + # This is useful - you can apply the SGD update using JAX pytree utilities. + return jax.tree.map( + lambda p, g: p - LEARNING_RATE * g, params, grads + ) +``` ## Viewing the pytree definition of an object To view the pytree definition of an arbitrary `object` for debugging purposes, you can use: -``` +```{code-cell} from jax.tree_util import tree_structure print(tree_structure(object)) ``` -## Developer information +(pytree-and-jax-transformations)= +## Pytrees and JAX transformations -*This is primarily JAX internal documentation, end-users are not supposed to need -to understand this to use JAX, except when registering new user-defined -container types with JAX. Some of these details may change.* +Many JAX functions, like {func}`jax.lax.scan`, operate over pytrees of arrays. In addition, all JAX function transformations can be applied to functions that accept as input and produce as output pytrees of arrays. -### Internal pytree handling +Some JAX function transformations take optional parameters that specify how certain input or output values should be treated (such as the `in_axes` and `out_axes` arguments to {func}`jax.vmap`). These parameters can also be pytrees, and their structure must correspond to the pytree structure of the corresponding arguments. In particular, to be able to “match up” leaves in these parameter pytrees with values in the argument pytrees, the parameter pytrees are often constrained to be tree prefixes of the argument pytrees. -JAX flattens pytrees into lists of leaves at the `api.py` boundary (and also -in control flow primitives). This keeps downstream JAX internals simpler: -transformations like {func}`~jax.grad`, {func}`~jax.jit`, and {func}`~jax.vmap` -can handle user functions that accept and return the myriad different Python -containers, while all the other parts of the system can operate on functions -that only take (multiple) array arguments and always return a flat list of arrays. +For example, if you pass the following input to {func}`jax.vmap` (note that the input arguments to a function are considered a tuple): -When JAX flattens a pytree it will produce a list of leaves and a `treedef` -object that encodes the structure of the original value. The `treedef` can -then be used to construct a matching structured value after transforming the -leaves. Pytrees are tree-like, rather than DAG-like or graph-like, in that we -handle them assuming referential transparency and that they can't contain -reference cycles. - -Here is a simple example: +```python +vmap(f, in_axes=(a1, {"k1": a2, "k2": a3})) +``` -```{code-cell} -:tags: [remove-cell] +then you can use the following `in_axes` pytree to specify that only the `k2` argument is mapped (`axis=0`), and the rest aren’t mapped over (`axis=None`): -# Execute this to consume & hide the GPU warning. -import jax.numpy as _jnp -_jnp.arange(10) +```python +vmap(f, in_axes=(None, {"k1": None, "k2": 0})) ``` -```{code-cell} -from jax.tree_util import tree_flatten, tree_unflatten -import jax.numpy as jnp +The optional parameter pytree structure must match that of the main input pytree. However, the optional parameters can optionally be specified as a “prefix” pytree, meaning that a single leaf value can be applied to an entire sub-pytree. -# The structured value to be transformed -value_structured = [1., (2., 3.)] +For example, if you have the same {func}`jax.vmap` input as above, but wish to only map over the dictionary argument, you can use: -# The leaves in value_flat correspond to the `*` markers in value_tree -value_flat, value_tree = tree_flatten(value_structured) -print(f"{value_flat=}\n{value_tree=}") +```python +vmap(f, in_axes=(None, 0)) # equivalent to (None, {"k1": 0, "k2": 0}) +``` -# Transform the flat value list using an element-wise numeric transformer -transformed_flat = list(map(lambda v: v * 2., value_flat)) -print(f"{transformed_flat=}") +Alternatively, if you want every argument to be mapped, you can write a single leaf value that is applied over the entire argument tuple pytree: -# Reconstruct the structured output, using the original -transformed_structured = tree_unflatten(value_tree, transformed_flat) -print(f"{transformed_structured=}") +```python +vmap(f, in_axes=0) # equivalent to (0, {"k1": 0, "k2": 0}) ``` -By default, pytree containers can be lists, tuples, dicts, namedtuple, None, -OrderedDict. Other types of values, including numeric and ndarray values, are -treated as leaves: +This happens to be the default `in_axes` value for {func}`jax.vmap`. -```{code-cell} -from collections import namedtuple -Point = namedtuple('Point', ['x', 'y']) - -example_containers = [ - (1., [2., 3.]), - (1., {'b': 2., 'a': 3.}), - 1., - None, - jnp.zeros(2), - Point(1., 2.) -] -def show_example(structured): - flat, tree = tree_flatten(structured) - unflattened = tree_unflatten(tree, flat) - print(f"{structured=}\n {flat=}\n {tree=}\n {unflattened=}") +The same logic applies to other optional parameters that refer to specific input or output values of a transformed function, such as `out_axes` in {func}`jax.vmap`. -for structured in example_containers: - show_example(structured) -``` +(pytrees-explicity-key-paths)= +## Explicit key paths -(extending-pytrees)= +In a pytree each leaf has a _key path_. A key path for a leaf is a `list` of _keys_, where the length of the list is equal to the depth of the leaf in the pytree . Each _key_ is a [hashable object](https://docs.python.org/3/glossary.html#term-hashable) that represents an index into the corresponding pytree node type. The type of the key depends on the pytree node type; for example, the type of keys for `dict`s is different from the type of keys for `tuple`s. -### Extending pytrees +For built-in pytree node types, the set of keys for any pytree node instance is unique. For a pytree comprising nodes with this property, the key path for each leaf is unique. + +JAX has the following `jax.tree_util.*` methods for working with key paths: -By default, any part of a structured value that is not recognized as an -internal pytree node (i.e. container-like) is treated as a leaf: +- {func}`jax.tree_util.tree_flatten_with_path`: Works similarly to {func}`jax.tree.flatten`, but returns key paths. +- {func}`jax.tree_util.tree_map_with_path`: Works similarly to {func}`jax.tree.map`, but the function also takes key paths as arguments. +- {func}`jax.tree_util.keystr`: Given a general key path, returns a reader-friendly string expression. + +For example, one use case is to print debugging information related to a certain leaf value: ```{code-cell} -class Special(object): - def __init__(self, x, y): - self.x = x - self.y = y +import collections - def __repr__(self): - return "Special(x={}, y={})".format(self.x, self.y) +ATuple = collections.namedtuple("ATuple", ('name')) +tree = [1, {'k1': 2, 'k2': (3, 4)}, ATuple('foo')] +flattened, _ = jax.tree_util.tree_flatten_with_path(tree) -show_example(Special(1., 2.)) +for key_path, value in flattened: + print(f'Value of tree{jax.tree_util.keystr(key_path)}: {value}') ``` -The set of Python types that are considered internal pytree nodes is extensible, -through a global registry of types, and values of registered types are traversed -recursively. To register a new type, you can use -{func}`~jax.tree_util.register_pytree_node`: +To express key paths, JAX provides a few default key types for the built-in pytree node types, namely: -```{code-cell} -from jax.tree_util import register_pytree_node - -class RegisteredSpecial(Special): - def __repr__(self): - return "RegisteredSpecial(x={}, y={})".format(self.x, self.y) - -def special_flatten(v): - """Specifies a flattening recipe. - - Params: - v: the value of registered type to flatten. - Returns: - a pair of an iterable with the children to be flattened recursively, - and some opaque auxiliary data to pass back to the unflattening recipe. - The auxiliary data is stored in the treedef for use during unflattening. - The auxiliary data could be used, e.g., for dictionary keys. - """ - children = (v.x, v.y) - aux_data = None - return (children, aux_data) +* `SequenceKey(idx: int)`: For lists and tuples. +* `DictKey(key: Hashable)`: For dictionaries. +* `GetAttrKey(name: str)`: For `namedtuple`s and preferably custom pytree nodes (more in the next section) -def special_unflatten(aux_data, children): - """Specifies an unflattening recipe. +You are free to define your own key types for your custom nodes. They will work with {func}`jax.tree_util.keystr` as long as their `__str__()` method is also overridden with a reader-friendly expression. - Params: - aux_data: the opaque data that was specified during flattening of the - current treedef. - children: the unflattened children +```{code-cell} +for key_path, _ in flattened: + print(f'Key path of tree{jax.tree_util.keystr(key_path)}: {repr(key_path)}') +``` - Returns: - a re-constructed object of the registered type, using the specified - children and auxiliary data. - """ - return RegisteredSpecial(*children) +(pytrees-common-pytree-gotchas)= +## Common pytree gotchas -# Global registration -register_pytree_node( - RegisteredSpecial, - special_flatten, # tell JAX what are the children nodes - special_unflatten # tell JAX how to pack back into a RegisteredSpecial -) +This section covers some of the most common problems ("gotchas") encountered when using JAX pytrees. -show_example(RegisteredSpecial(1., 2.)) -``` +### Mistaking pytree nodes for leaves -Alternatively, you can define appropriate `tree_flatten` and `tree_unflatten` methods -on your class and decorate it with {func}`~jax.tree_util.register_pytree_node_class`: +A common gotcha to look out for is accidentally introducing _tree nodes_ instead of _leaves_: ```{code-cell} -from jax.tree_util import register_pytree_node_class +a_tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))] -@register_pytree_node_class -class RegisteredSpecial2(Special): - def __repr__(self): - return "RegisteredSpecial2(x={}, y={})".format(self.x, self.y) +# Try to make another pytree with ones instead of zeros. +shapes = jax.tree.map(lambda x: x.shape, a_tree) +jax.tree.map(jnp.ones, shapes) +``` - def tree_flatten(self): - children = (self.x, self.y) - aux_data = None - return (children, aux_data) +What happened here is that the `shape` of an array is a tuple, which is a pytree node, with its elements as leaves. Thus, in the map, instead of calling `jnp.ones` on e.g. `(2, 3)`, it's called on `2` and `3`. - @classmethod - def tree_unflatten(cls, aux_data, children): - return cls(*children) +The solution will depend on the specifics, but there are two broadly applicable options: -show_example(RegisteredSpecial2(1., 2.)) -``` +- Rewrite the code to avoid the intermediate {func}`jax.tree.map`. +- Convert the tuple into a NumPy array (`np.array`) or a JAX NumPy array (`jnp.array`), which makes the entire sequence a leaf. -When defining unflattening functions, in general `children` should contain all the -dynamic elements of the data structure (arrays, dynamic scalars, and pytrees), while -`aux_data` should contain all the static elements that will be rolled into the `treedef` -structure. JAX sometimes needs to compare `treedef` for equality, or compute its hash -for use in the JIT cache, and so care must be taken to ensure that the auxiliary data -specified in the flattening recipe supports meaningful hashing and equality comparisons. +### Handling of `None` by `jax.tree_util` -The whole set of functions for operating on pytrees are in {mod}`jax.tree_util`. +`jax.tree_util` functions treat `None` as the absence of a pytree node, not as a leaf: -### Custom PyTrees and Initialization +```{code-cell} +jax.tree.leaves([None, None, None]) +``` -One common gotcha with user-defined PyTree objects is that JAX transformations occasionally -initialize them with unexpected values, so that any input validation done at initialization -may fail. For example: +To treat `None` as a leaf, you can use the `is_leaf` argument: ```{code-cell} -:tags: [skip-execution] -class MyTree: - def __init__(self, a): - self.a = jnp.asarray(a) +jax.tree.leaves([None, None, None], is_leaf=lambda x: x is None) +``` -register_pytree_node(MyTree, lambda tree: ((tree.a,), None), - lambda _, args: MyTree(*args)) +(pytrees-common-pytree-patterns)= +## Common pytree patterns -tree = MyTree(jnp.arange(5.0)) +This section covers some of the most common patterns with JAX pytrees. -jax.vmap(lambda x: x)(tree) # Error because object() is passed to MyTree. -jax.jacobian(lambda x: x)(tree) # Error because MyTree(...) is passed to MyTree -``` -In the first case, JAX's internals use arrays of `object()` values to infer the structure -of the tree; in the second case, the jacobian of a function mapping a tree to a tree -is defined as a tree of trees. +### Transposing pytrees with `jax.tree.map` and `jax.tree.transpose` + +To transpose a pytree (turn a list of trees into a tree of lists), JAX has two functions: {func}`jax.tree.map` (more basic) and {func}`jax.tree.transpose` (more flexible, complex and verbose). + +**Option 1:** Use {func}`jax.tree.map`. Here's an example: -For this reason, the `__init__` and `__new__` methods of custom PyTree classes should -generally avoid doing any array conversion or other input validation, or else -anticipate and handle these special cases. For example: ```{code-cell} -class MyTree: - def __init__(self, a): - if not (type(a) is object or a is None or isinstance(a, MyTree)): - a = jnp.asarray(a) - self.a = a +def tree_transpose(list_of_trees): + """ + Converts a list of trees of identical structure into a single tree of lists. + """ + return jax.tree.map(lambda *xs: list(xs), *list_of_trees) + +# Convert a dataset from row-major to column-major. +episode_steps = [dict(t=1, obs=3), dict(t=2, obs=4)] +tree_transpose(episode_steps) ``` -Another possibility is to structure your `tree_unflatten` function so that it avoids -calling `__init__`; for example: + +**Option 2:** For more complex transposes, use {func}`jax.tree.transpose`, which is more verbose, but allows you specify the structure of the inner and outer pytree for more flexibility. For example: + ```{code-cell} -def tree_unflatten(aux_data, children): - del aux_data # unused in this class - obj = object.__new__(MyTree) - obj.a = a - return obj +jax.tree.transpose( + outer_treedef = jax.tree.structure([0 for e in episode_steps]), + inner_treedef = jax.tree.structure(episode_steps[0]), + pytree_to_transpose = episode_steps +) ``` -If you go this route, make sure that your `tree_unflatten` function stays in-sync with -`__init__` if and when the code is updated. \ No newline at end of file + +(extending-pytrees)= +### Extending pytrees + +Material on extending pytrees has been moved to {ref}`pytrees-custom-pytree-nodes`. diff --git a/docs/quickstart.md b/docs/quickstart.md deleted file mode 100644 index 77cbb9d46ab8..000000000000 --- a/docs/quickstart.md +++ /dev/null @@ -1,223 +0,0 @@ ---- -jupytext: - formats: md:myst - text_representation: - extension: .md - format_name: myst - format_version: 0.13 - jupytext_version: 1.16.4 -kernelspec: - display_name: Python 3 - language: python - name: python3 ---- - -# Quickstart - - - -**JAX is a library for array-oriented numerical computation (*à la* [NumPy](https://numpy.org/)), with automatic differentiation and JIT compilation to enable high-performance machine learning research**. - -This document provides a quick overview of essential JAX features, so you can get started with JAX quickly: - -* JAX provides a unified NumPy-like interface to computations that run on CPU, GPU, or TPU, in local or distributed settings. -* JAX features built-in Just-In-Time (JIT) compilation via [Open XLA](https://github.com/openxla), an open-source machine learning compiler ecosystem. -* JAX functions support efficient evaluation of gradients via its automatic differentiation transformations. -* JAX functions can be automatically vectorized to efficiently map them over arrays representing batches of inputs. - -## Installation - -JAX can be installed for CPU on Linux, Windows, and macOS directly from the [Python Package Index](https://pypi.org/project/jax/): -``` -pip install jax -``` -or, for NVIDIA GPU: -``` -pip install -U "jax[cuda12]" -``` -For more detailed platform-specific installation information, check out {ref}`installation`. - -## JAX as NumPy - -Most JAX usage is through the familiar {mod}`jax.numpy` API, which is typically imported under the `jnp` alias: - -```{code-cell} -import jax.numpy as jnp -``` - -With this import, you can immediately use JAX in a similar manner to typical NumPy programs, -including using NumPy-style array creation functions, Python functions and operators, and -array attributes and methods: - -```{code-cell} -def selu(x, alpha=1.67, lmbda=1.05): - return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha) - -x = jnp.arange(5.0) -print(selu(x)) -``` - -You'll find a few differences between JAX arrays and NumPy arrays once you begin digging-in; -these are explored in [🔪 JAX - The Sharp Bits 🔪](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html). - -## Just-in-time compilation with {func}`jax.jit` -JAX runs transparently on the GPU or TPU (falling back to CPU if you don't have one). However, in the above example, JAX is dispatching kernels to the chip one operation at a time. If we have a sequence of operations, we can use the {func}`jax.jit` function to compile this sequence of operations together using XLA. - -We can use IPython's `%timeit` to quickly benchmark our `selu` function, using `block_until_ready()` to -account for JAX's dynamic dispatch (See {ref}`async-dispatch`): - -```{code-cell} -from jax import random - -key = random.key(1701) -x = random.normal(key, (1_000_000,)) -%timeit selu(x).block_until_ready() -``` - -(notice we've used {mod}`jax.random` to generate some random numbers; for details on -how to generate random numbers in JAX, check out {ref}`pseudorandom-numbers`). - -We can speed the execution of this function with the {func}`jax.jit` transformation, -which will jit-compile the first time `selu` is called and will be cached thereafter. - -```{code-cell} -from jax import jit - -selu_jit = jit(selu) -_ = selu_jit(x) # compiles on first call -%timeit selu_jit(x).block_until_ready() -``` - -The above timing represents execution on CPU, but the same code can be run on GPU or -TPU, typically for an even greater speedup. - -For more on JIT compilation in JAX, check out {ref}`jit-compilation`. - -## Taking derivatives with {func}`jax.grad` - -In addition to transforming functions via JIT compilation, JAX also provides other -transformations. One such transformation is {func}`jax.grad`, which performs -[automatic differentiation (autodiff)](https://en.wikipedia.org/wiki/Automatic_differentiation): - -```{code-cell} -from jax import grad - -def sum_logistic(x): - return jnp.sum(1.0 / (1.0 + jnp.exp(-x))) - -x_small = jnp.arange(3.) -derivative_fn = grad(sum_logistic) -print(derivative_fn(x_small)) -``` - -Let's verify with finite differences that our result is correct. - -```{code-cell} -def first_finite_differences(f, x, eps=1E-3): - return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps) - for v in jnp.eye(len(x))]) - -print(first_finite_differences(sum_logistic, x_small)) -``` - -The {func}`~jax.grad` and {func}`~jax.jit` transformations compose and can be mixed arbitrarily. -In the above example we jitted `sum_logistic` and then took its derivative. We can go further: - -```{code-cell} -print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0)) -``` - -Beyond scalar-valued functions, the {func}`jax.jacobian` transformation can be -used to compute the full Jacobian matrix for vector-valued functions: - -```{code-cell} -from jax import jacobian -print(jacobian(jnp.exp)(x_small)) -``` - -For more advanced autodiff operations, you can use {func}`jax.vjp` for reverse-mode vector-Jacobian products, -and {func}`jax.jvp` and {func}`jax.linearize` for forward-mode Jacobian-vector products. -The two can be composed arbitrarily with one another, and with other JAX transformations. -For example, {func}`jax.jvp` and {func}`jax.vjp` are used to define the forward-mode {func}`jax.jacfwd` and reverse-mode {func}`jax.jacrev` for computing Jacobians in forward- and reverse-mode, respectively. -Here's one way to compose them to make a function that efficiently computes full Hessian matrices: - -```{code-cell} -from jax import jacfwd, jacrev -def hessian(fun): - return jit(jacfwd(jacrev(fun))) -print(hessian(sum_logistic)(x_small)) -``` - -This kind of composition produces efficient code in practice; this is more-or-less how JAX's built-in {func}`jax.hessian` function is implemented. - -For more on automatic differentiation in JAX, check out {ref}`automatic-differentiation`. - -## Auto-vectorization with {func}`jax.vmap` - -Another useful transformation is {func}`~jax.vmap`, the vectorizing map. -It has the familiar semantics of mapping a function along array axes, but instead of explicitly looping -over function calls, it transforms the function into a natively vectorized version for better performance. -When composed with {func}`~jax.jit`, it can be just as performant as manually rewriting your function -to operate over an extra batch dimension. - -We're going to work with a simple example, and promote matrix-vector products into matrix-matrix products using {func}`~jax.vmap`. -Although this is easy to do by hand in this specific case, the same technique can apply to more complicated functions. - -```{code-cell} -key1, key2 = random.split(key) -mat = random.normal(key1, (150, 100)) -batched_x = random.normal(key2, (10, 100)) - -def apply_matrix(x): - return jnp.dot(mat, x) -``` - -The `apply_matrix` function maps a vector to a vector, but we may want to apply it row-wise across a matrix. -We could do this by looping over the batch dimension in Python, but this usually results in poor performance. - -```{code-cell} -def naively_batched_apply_matrix(v_batched): - return jnp.stack([apply_matrix(v) for v in v_batched]) - -print('Naively batched') -%timeit naively_batched_apply_matrix(batched_x).block_until_ready() -``` - -A programmer familiar with the `jnp.dot` function might recognize that `apply_matrix` can -be rewritten to avoid explicit looping, using the built-in batching semantics of `jnp.dot`: - -```{code-cell} -import numpy as np - -@jit -def batched_apply_matrix(batched_x): - return jnp.dot(batched_x, mat.T) - -np.testing.assert_allclose(naively_batched_apply_matrix(batched_x), - batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4) -print('Manually batched') -%timeit batched_apply_matrix(batched_x).block_until_ready() -``` - -However, as functions become more complicated, this kind of manual batching becomes more difficult and error-prone. -The {func}`~jax.vmap` transformation is designed to automatically transform a function into a batch-aware version: - -```{code-cell} -from jax import vmap - -@jit -def vmap_batched_apply_matrix(batched_x): - return vmap(apply_matrix)(batched_x) - -np.testing.assert_allclose(naively_batched_apply_matrix(batched_x), - vmap_batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4) -print('Auto-vectorized with vmap') -%timeit vmap_batched_apply_matrix(batched_x).block_until_ready() -``` - -As you would expect, {func}`~jax.vmap` can be arbitrarily composed with {func}`~jax.jit`, -{func}`~jax.grad`, and any other JAX transformation. - -For more on automatic vectorization in JAX, check out {ref}`automatic-vectorization`. - -This is just a taste of what JAX can do. We're really excited to see what you do with it! diff --git a/docs/random-numbers.md b/docs/random-numbers.md index 00f77e3473bb..5562dc3f43d5 100644 --- a/docs/random-numbers.md +++ b/docs/random-numbers.md @@ -150,9 +150,9 @@ print(random.normal(key)) print(random.normal(key)) ``` -Re-using the same key, even with different {mod}`~jax.random` APIs, can result in correlated outputs, which is generally undesirable. +Reusing the same key, even with different {mod}`~jax.random` APIs, can result in correlated outputs, which is generally undesirable. -**The rule of thumb is: never reuse keys (unless you want identical outputs).** +**The rule of thumb is: never reuse keys (unless you want identical outputs). Reusing the same state will cause __sadness__ and __monotony__, depriving the end user of __lifegiving chaos__.** JAX uses a modern [Threefry counter-based PRNG](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md) that's splittable. That is, its design allows us to fork the PRNG state into new PRNGs for use with parallel stochastic generation. In order to generate different and independent samples, you must {func}`~jax.random.split` the key explicitly before passing it to a random function: diff --git a/docs/rank_promotion_warning.rst b/docs/rank_promotion_warning.rst index 5e4e7ec65cbc..6ec0000e2ffc 100644 --- a/docs/rank_promotion_warning.rst +++ b/docs/rank_promotion_warning.rst @@ -9,14 +9,14 @@ surprising bugs where a silent rank promotion masks an underlying shape error. Here's an example of rank promotion: ->>> import numpy as np ->>> x = np.arange(12).reshape(4, 3) ->>> y = np.array([0, 1, 0]) +>>> from jax import numpy as jnp +>>> x = jnp.arange(12).reshape(4, 3) +>>> y = jnp.array([0, 1, 0]) >>> x + y -array([[ 0, 2, 2], +Array([[ 0, 2, 2], [ 3, 5, 5], [ 6, 8, 8], - [ 9, 11, 11]]) + [ 9, 11, 11]], dtype=int32) To avoid potential surprises, :code:`jax.numpy` is configurable so that expressions requiring rank promotion can lead to a warning, error, or can be diff --git a/docs/requirements.txt b/docs/requirements.txt index 5d49222bbb42..f88eb4a45db5 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,6 +1,7 @@ absl-py ipython>=8.8.0 # 8.7.0 has ipython3 lexer error pydata-sphinx-theme==0.14.4 # v0.15 breaks sidebar toggling +snowballstemmer<3.0.0 # v3.0.0 incompatible with older sphinx; missing stemmer sphinx>=7.3.2,<8.0 # 7.3.0 breaks sphinx-book-theme; 8.0 breaks myst-nb 1.1 sphinx-book-theme==1.1.1 # v1.1.2 requires pydata-sphinx-theme v0.15 sphinx-copybutton>=0.5.0 @@ -8,6 +9,7 @@ sphinx-remove-toctrees sphinx-design sphinxext-rediraffe myst-nb>=1.0.0 +sphinxcontrib-mermaid # Packages used for CI tests. flatbuffers @@ -21,4 +23,5 @@ pooch numpy rich[jupyter] cmake +cloudpickle .[ci] # Install jax from the current directory; jaxlib from pypi. diff --git a/docs/sharded-computation.ipynb b/docs/sharded-computation.ipynb index d3ddac4edbdb..b05ba18da7ef 100644 --- a/docs/sharded-computation.ipynb +++ b/docs/sharded-computation.ipynb @@ -7,27 +7,50 @@ "(sharded-computation)=\n", "# Introduction to parallel programming\n", "\n", - "\n", + "\n", "\n", "This tutorial serves as an introduction to device parallelism for Single-Program Multi-Data (SPMD) code in JAX. SPMD is a parallelism technique where the same computation, such as the forward pass of a neural network, can be run on different input data (for example, different inputs in a batch) in parallel on different devices, such as several GPUs or Google TPUs.\n", "\n", "The tutorial covers three modes of parallel computation:\n", "\n", - "- _Automatic parallelism via {func}`jax.jit`_: The compiler chooses the optimal computation strategy (a.k.a. \"the compiler takes the wheel\").\n", - "- _Semi-automated parallelism_ using {func}`jax.jit` and {func}`jax.lax.with_sharding_constraint`\n", - "- _Fully manual parallelism with manual control using {func}`jax.experimental.shard_map.shard_map`_: `shard_map` enables per-device code and explicit communication collectives\n", + "- _Automatic sharding via {func}`jax.jit`_: The compiler chooses the optimal computation strategy (a.k.a. \"the compiler takes the wheel\").\n", + "- *Explicit Sharding* (\\*new\\*) is similar to automatic sharding in that\n", + " you're writing a global-view program. The difference is that the sharding\n", + " of each array is part of the array's JAX-level type making it an explicit\n", + " part of the programming model. These shardings are propagated at the JAX\n", + " level and queryable at trace time. It's still the compiler's responsibility\n", + " to turn the whole-array program into per-device programs (turning `jnp.sum`\n", + " into `psum` for example) but the compiler is heavily constrained by the\n", + " user-supplied shardings.\n", + "- _Fully manual sharding with manual control using {func}`jax.shard_map`_: `shard_map` enables per-device code and explicit communication collectives\n", "\n", - "Using these schools of thought for SPMD, you can transform a function written for one device into a function that can run in parallel on multiple devices.\n", + "A summary table:\n", "\n", - "If you are running these examples in a Google Colab notebook, make sure that your hardware accelerator is the latest Google TPU by checking your notebook settings: **Runtime** > **Change runtime type** > **Hardware accelerator** > **TPU v2** (which provides eight devices to work with)." + "| Mode | View? | Explicit sharding? | Explicit Collectives? |\n", + "|---|---|---|---|\n", + "| Auto | Global | ❌ | ❌ |\n", + "| Explicit | Global | ✅ | ❌ |\n", + "| Manual | Per-device | ✅ | ✅ |\n", + "\n", + "Using these schools of thought for SPMD, you can transform a function written for one device into a function that can run in parallel on multiple devices." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7efa1e66", + "metadata": {}, + "outputs": [], + "source": [ + "import jax\n", + "\n", + "jax.config.update('jax_num_cpu_devices', 8)" ] }, { "cell_type": "code", "execution_count": 1, - "metadata": { - "outputId": "18905ae4-7b5e-4bb9-acb4-d8ab914cb456" - }, + "metadata": {}, "outputs": [ { "data": { @@ -48,7 +71,6 @@ } ], "source": [ - "import jax\n", "jax.devices()" ] }, @@ -84,7 +106,9 @@ } ], "source": [ + "import numpy as np\n", "import jax.numpy as jnp\n", + "\n", "arr = jnp.arange(32.0).reshape(4, 8)\n", "arr.devices()" ] @@ -264,51 +288,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "UEObolTqw4pp" - }, - "source": [ - "The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the device.\n", - "\n", - "The {class}`~jax.sharding.NamedSharding` includes a parameter called `memory_kind`. This parameter determines the type of memory to be used and defaults to `device`. You can set this parameter to `pinned_host` if you prefer to place it on the host.\n", - "\n", - "To create a new sharding that only differs from an existing sharding in terms of its memory kind, you can use the `with_memory_kind` method on the existing sharding." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "aKNeOHTJnqmS", - "outputId": "847c53ec-8b2e-4be0-f993-7fde7d77c0f2" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "pinned_host\n", - "device\n" - ] - } - ], - "source": [ - "s_host = jax.NamedSharding(mesh, P('x', 'y'), memory_kind='pinned_host')\n", - "s_dev = s_host.with_memory_kind('device')\n", - "arr_host = jax.device_put(arr, s_host)\n", - "arr_dev = jax.device_put(arr, s_dev)\n", - "print(arr_host.sharding.memory_kind)\n", - "print(arr_dev.sharding.memory_kind)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jDHYnVqHwaST" - }, + "metadata": {}, "source": [ "## 1. Automatic parallelism via `jit`\n", "\n", @@ -400,159 +380,170 @@ "id": "Q4N5mrr9i_ki" }, "source": [ - "The result is partially replicated: that is, the first two elements of the array are replicated on devices `0` and `6`, the second on `1` and `7`, and so on.\n", - "\n", - "### 1.1 Sharding transformation between memory types\n", - "\n", - "The output sharding of a {func}`jax.jit` function can differ from the input sharding if you specify the output sharding using the `out_shardings` parameter. Specifically, the `memory_kind` of the output can be different from that of the input array.\n", + "The result is partially replicated: that is, the first two elements of the array are replicated on devices `0` and `4`, the second on `1` and `5`, and so on.\n", "\n", - "#### Example 1: Pinned host to device memory\n", + "## 2. Explicit sharding\n", "\n", - "In the example below, the {func}`jax.jit` function `f` takes an array sharded in `pinned_host` memory and generates an array in `device` memory." + "The main idea behind explicit shardings, (a.k.a. sharding-in-types), is that\n", + "the JAX-level _type_ of a value includes a description of how the value is sharded.\n", + "We can query the JAX-level type of any JAX value (or Numpy array, or Python\n", + "scalar) using `jax.typeof`:" ] }, { "cell_type": "code", - "execution_count": 11, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "PXu3MhafyRHo", - "outputId": "7bc6821f-a4a9-4cf8-8b21-e279d516d27b" - }, + "execution_count": 9, + "metadata": {}, "outputs": [ + { + "data": { + "text/html": [ + "
  TPU 0    TPU 1    TPU 2    TPU 3    TPU 6    TPU 7    TPU 4    TPU 5  \n",
+       "                                                                        \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mTPU 0\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107mTPU 1\u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82mTPU 2\u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mTPU 3\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148mTPU 6\u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207mTPU 7\u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148mTPU 4\u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49mTPU 5\u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, { "name": "stdout", "output_type": "stream", "text": [ - "[[ 0. 1. 2. 3. 4. 5. 6. 7.]\n", - " [ 8. 9. 10. 11. 12. 13. 14. 15.]\n", - " [16. 17. 18. 19. 20. 21. 22. 23.]\n", - " [24. 25. 26. 27. 28. 29. 30. 31.]]\n", - "device\n" + "[48. 52. 56. 60. 64. 68. 72. 76.]\n" ] } ], "source": [ - "f = jax.jit(lambda x: x, out_shardings=s_dev)\n", - "out_dev = f(arr_host)\n", - "print(out_dev)\n", - "print(out_dev.sharding.memory_kind)" + "some_array = np.arange(8)\n", + "print(f\"JAX-level type of some_array: {jax.typeof(some_array)}\")" ] }, { "cell_type": "markdown", - "metadata": { - "id": "LuYFqpcBySiX" - }, + "metadata": {}, "source": [ - "#### Example 2: Device to pinned_host memory\n", + "Importantly, we can query the type even while tracing under a `jit` (the JAX-level type\n", + "is almost _defined_ as \"the information about a value we have access to while\n", + "under a jit)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ffe62839", + "metadata": {}, + "outputs": [], + "source": [ + "@jax.jit\n", + "def foo(x):\n", + " print(f\"JAX-level type of x during tracing: {jax.typeof(x)}\")\n", + " return x + x\n", "\n", - "In the example below, the {func}`jax.jit` function `g` takes an array sharded in `device` memory and generates an array in `pinned_host` memory." + "foo(some_array)" + ] + }, + { + "cell_type": "markdown", + "id": "74995421", + "metadata": {}, + "source": [ + "To start seeing shardings in the type we need to set up an explicit-sharding mesh." ] }, { "cell_type": "code", "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "qLsgNlKfybRw", - "outputId": "a16448b9-7e39-408f-b200-505f65ad4464" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[ 0. 1. 2. 3. 4. 5. 6. 7.]\n", - " [ 8. 9. 10. 11. 12. 13. 14. 15.]\n", - " [16. 17. 18. 19. 20. 21. 22. 23.]\n", - " [24. 25. 26. 27. 28. 29. 30. 31.]]\n", - "pinned_host\n" - ] - } - ], + "id": "e785a694", + "metadata": {}, + "outputs": [], "source": [ - "g = jax.jit(lambda x: x, out_shardings=s_host)\n", - "out_host = g(arr_dev)\n", - "print(out_host)\n", - "print(out_host.sharding.memory_kind)" + "from jax.sharding import AxisType\n", + "\n", + "mesh = jax.make_mesh((2, 4), (\"X\", \"Y\"),\n", + " axis_types=(AxisType.Explicit, AxisType.Explicit))" ] }, { "cell_type": "markdown", - "metadata": { - "id": "7BGD31-owaSU" - }, + "id": "8d81409c", + "metadata": {}, + "source": [ + "Now we can create some sharded arrays:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4969cabd", + "metadata": {}, + "outputs": [], "source": [ - "## 2. Semi-automated sharding with constraints\n", + "replicated_array = np.arange(8).reshape(4, 2)\n", + "sharded_array = jax.device_put(replicated_array, jax.NamedSharding(mesh, P(\"X\", None)))\n", "\n", - "If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of {func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constraints how the intermediate values and outputs are distributed.\n", + "print(f\"replicated_array type: {jax.typeof(replicated_array)}\")\n", + "print(f\"sharded_array type: {jax.typeof(sharded_array)}\")" + ] + }, + { + "cell_type": "markdown", + "id": "c09acf7d", + "metadata": {}, + "source": [ + "We should read the type `int32[4@X, 2]` as \"a 4-by-2 array of 32-bit ints whose first dimension\n", + "is sharded along mesh axis 'X'. The array is replicated along all other mesh\n", + "axes\"\n", "\n", - "For example, suppose that within `f_contract` above, you'd prefer the output not to be partially-replicated, but rather to be fully sharded across the eight devices:" + "These shardings associated with JAX-level types propagate through operations. For example:" ] }, { "cell_type": "code", - "execution_count": 9, - "metadata": { - "outputId": "8468f5c6-76ca-4367-c9f2-93c723687cfd" - }, - "outputs": [ - { - "data": { - "text/html": [ - "
  TPU 0    TPU 1    TPU 2    TPU 3    TPU 6    TPU 7    TPU 4    TPU 5  \n",
-       "                                                                        \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mTPU 0\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107mTPU 1\u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82mTPU 2\u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mTPU 3\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148mTPU 6\u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207mTPU 7\u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148mTPU 4\u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49mTPU 5\u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[48. 52. 56. 60. 64. 68. 72. 76.]\n" - ] - } - ], + "execution_count": null, + "id": "ab2f9500", + "metadata": {}, + "outputs": [], "source": [ + "arg0 = jax.device_put(np.arange(4).reshape(4, 1),\n", + " jax.NamedSharding(mesh, P(\"X\", None)))\n", + "arg1 = jax.device_put(np.arange(8).reshape(1, 8),\n", + " jax.NamedSharding(mesh, P(None, \"Y\")))\n", + "\n", "@jax.jit\n", - "def f_contract_2(x):\n", - " out = x.sum(axis=0)\n", - " sharding = jax.sharding.NamedSharding(mesh, P('x'))\n", - " return jax.lax.with_sharding_constraint(out, sharding)\n", + "def add_arrays(x, y):\n", + " ans = x + y\n", + " print(f\"x sharding: {jax.typeof(x)}\")\n", + " print(f\"y sharding: {jax.typeof(y)}\")\n", + " print(f\"ans sharding: {jax.typeof(ans)}\")\n", + " return ans\n", "\n", - "result = f_contract_2(arr_sharded)\n", - "jax.debug.visualize_array_sharding(result)\n", - "print(result)" + "with jax.set_mesh(mesh):\n", + " add_arrays(arg0, arg1)" ] }, { "cell_type": "markdown", + "id": "dda3d0c5", "metadata": {}, "source": [ - "This gives you a function with the particular output sharding you'd like.\n", + "That's the gist of it. Shardings propagate deterministically at trace time and\n", + "we can query them at trace time.\n", "\n", "## 3. Manual parallelism with `shard_map`\n", "\n", - "In the automatic parallelism methods explored above, you can write a function as if you're operating on the full dataset, and `jit` will split that computation across multiple devices. By contrast, with {func}`jax.experimental.shard_map.shard_map` you write the function that will handle a single shard of data, and `shard_map` will construct the full function.\n", + "In the automatic parallelism methods explored above, you can write a function as if you're operating on the full dataset, and `jit` will split that computation across multiple devices. By contrast, with {func}`jax.shard_map` you write the function that will handle a single shard of data, and `shard_map` will construct the full function.\n", "\n", "`shard_map` works by mapping a function across a particular *mesh* of devices (`shard_map` maps over shards). In the example below:\n", "\n", "- As before, {class}`jax.sharding.Mesh` allows for precise device placement, with the axis names parameter for logical and physical axis names.\n", "- The `in_specs` argument determines the shard sizes. The `out_specs` argument identifies how the blocks are assembled back together.\n", "\n", - "**Note:** {func}`jax.experimental.shard_map.shard_map` code can work inside {func}`jax.jit` if you need it." + "**Note:** {func}`jax.shard_map` code can work inside {func}`jax.jit` if you need it." ] }, { @@ -580,10 +571,9 @@ } ], "source": [ - "from jax.experimental.shard_map import shard_map\n", "mesh = jax.make_mesh((8,), ('x',))\n", "\n", - "f_elementwise_sharded = shard_map(\n", + "f_elementwise_sharded = jax.shard_map(\n", " f_elementwise,\n", " mesh=mesh,\n", " in_specs=P('x'),\n", @@ -624,7 +614,7 @@ " print(f\"device local shape: {x.shape=}\")\n", " return x * 2\n", "\n", - "y = shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)" + "y = jax.shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)" ] }, { @@ -658,7 +648,7 @@ "def f(x):\n", " return jnp.sum(x, keepdims=True)\n", "\n", - "shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)" + "jax.shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)" ] }, { @@ -693,7 +683,7 @@ " sum_in_shard = x.sum()\n", " return jax.lax.psum(sum_in_shard, 'x')\n", "\n", - "shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P())(x)" + "jax.shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P())(x)" ] }, { @@ -757,7 +747,8 @@ "source": [ "You can automatically run this in a distributed manner using {func}`jax.jit` and passing appropriately sharded data.\n", "\n", - "If you shard the leading axis of both `x` and `weights` in the same way, then the matrix multiplication will automatically happen in parallel:" + "If you shard the leading axis of both `x` and make `weights` fully replicated,\n", + "then the matrix multiplication will automatically happen in parallel:" ] }, { @@ -780,10 +771,8 @@ ], "source": [ "mesh = jax.make_mesh((8,), ('x',))\n", - "sharding = jax.sharding.NamedSharding(mesh, P('x'))\n", - "\n", - "x_sharded = jax.device_put(x, sharding)\n", - "weights_sharded = jax.device_put(weights, sharding)\n", + "x_sharded = jax.device_put(x, jax.NamedSharding(mesh, P('x')))\n", + "weights_sharded = jax.device_put(weights, jax.NamedSharding(mesh, P()))\n", "\n", "layer(x_sharded, weights_sharded, bias)" ] @@ -792,15 +781,13 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Alternatively, you can use {func}`jax.lax.with_sharding_constraint` in the function to automatically distribute unsharded inputs:" + "Alternatively, you can use explicit sharding mode too:" ] }, { "cell_type": "code", "execution_count": 17, - "metadata": { - "outputId": "bb63e8da-ff4f-4e95-f083-10584882daf4" - }, + "metadata": {}, "outputs": [ { "data": { @@ -814,13 +801,22 @@ } ], "source": [ + "explicit_mesh = jax.make_mesh((8,), ('X',), axis_types=(AxisType.Explicit,))\n", + "\n", + "x_sharded = jax.device_put(x, jax.NamedSharding(explicit_mesh, P('X')))\n", + "weights_sharded = jax.device_put(weights, jax.NamedSharding(explicit_mesh, P()))\n", + "\n", "@jax.jit\n", "def layer_auto(x, weights, bias):\n", - " x = jax.lax.with_sharding_constraint(x, sharding)\n", - " weights = jax.lax.with_sharding_constraint(weights, sharding)\n", - " return layer(x, weights, bias)\n", + " print(f\"x sharding: {jax.typeof(x)}\")\n", + " print(f\"weights sharding: {jax.typeof(weights)}\")\n", + " print(f\"bias sharding: {jax.typeof(bias)}\")\n", + " out = layer(x, weights, bias)\n", + " print(f\"out sharding: {jax.typeof(out)}\")\n", + " return out\n", "\n", - "layer_auto(x, weights, bias) # pass in unsharded inputs" + "with jax.set_mesh(explicit_mesh):\n", + " layer_auto(x_sharded, weights_sharded, bias)" ] }, { @@ -852,7 +848,7 @@ "from functools import partial\n", "\n", "@jax.jit\n", - "@partial(shard_map, mesh=mesh,\n", + "@partial(jax.shard_map, mesh=mesh,\n", " in_specs=(P('x'), P('x', None), P(None)),\n", " out_specs=P(None))\n", "def layer_sharded(x, weights, bias):\n", @@ -865,13 +861,82 @@ "cell_type": "markdown", "metadata": {}, "source": [ + "(sharded-data-placement)=\n", + "## Controlling data and computation placement on devices\n", + "\n", + "Let's look at the principles of data and computation placement in JAX.\n", + "\n", + "In JAX, the computation follows data placement. JAX arrays have two placement\n", + "properties: 1) the device where the data resides; and 2) whether it is\n", + "**committed** to the device or not (the data is sometimes referred to as being\n", + "*sticky* to the device).\n", + "\n", + "By default, JAX arrays are placed uncommitted on the default device\n", + "(`jax.devices()[0]`), which is the first GPU or TPU by default. If no GPU or\n", + "TPU is present, `jax.devices()[0]` is the CPU. The default device can be\n", + "temporarily overridden with the {func}`jax.default_device` context manager, or\n", + "set for the whole process by setting the environment variable `JAX_PLATFORMS`\n", + "or the absl flag `--jax_platforms` to \"cpu\", \"gpu\", or \"tpu\" (`JAX_PLATFORMS`\n", + "can also be a list of platforms, which determines which platforms are available\n", + "in priority order).\n", + "\n", + "```python\n", + ">>> from jax import numpy as jnp\n", + ">>> print(jnp.ones(3).devices()) # doctest: +SKIP\n", + "{CudaDevice(id=0)}\n", + "```\n", + "\n", + "Computations involving uncommitted data are performed on the default device and\n", + "the results are uncommitted on the default device.\n", + "\n", + "Data can also be placed explicitly on a device using {func}`jax.device_put` with\n", + "a `device` parameter, in which case the data becomes **committed** to the\n", + "device:\n", + "\n", + "```python\n", + ">>> import jax\n", + ">>> from jax import device_put\n", + ">>> arr = device_put(1, jax.devices()[2]) # doctest: +SKIP\n", + ">>> print(arr.devices()) # doctest: +SKIP\n", + "{CudaDevice(id=2)}\n", + "```\n", + "\n", + "Computations involving some committed inputs will happen on the committed device\n", + "and the result will be committed on the same device. Invoking an operation on\n", + "arguments that are committed to more than one device will raise an error.\n", + "\n", + "You can also use {func}`jax.device_put` without a `device` parameter. If the\n", + "data is already on a device (committed or not), it's left as-is. If the data\n", + "isn't on any device—that is, it's a regular Python or NumPy value—it's placed\n", + "uncommitted on the default device.\n", + "\n", + "Jitted functions behave like any other primitive operations—they will follow the\n", + "data and will show errors if invoked on data committed on more than one device.\n", + "\n", + "(Before [PR #6002](https://github.com/jax-ml/jax/pull/6002) in March 2021\n", + "there was some laziness in creation of array constants, so that\n", + "`jax.device_put(jnp.zeros(...), jax.devices()[1])` or similar would actually\n", + "create the array of zeros on `jax.devices()[1]`, instead of creating the\n", + "array on the default device then moving it. But this optimization was removed\n", + "so as to simplify the implementation.)\n", + "\n", + "(As of April 2020, {func}`jax.jit` has a `device` parameter that affects the device\n", + "placement. That parameter is experimental, is likely to be removed or changed,\n", + "and its use is not recommended.)\n", + "\n", + "For a worked-out example, we recommend reading through\n", + "`test_computation_follows_data` in\n", + "[multi_device_test.py](https://github.com/jax-ml/jax/blob/main/tests/multi_device_test.py).\n", + "\n", "## Next steps\n", "\n", "This tutorial serves as a brief introduction of sharded and parallel computation in JAX.\n", "\n", "To learn about each SPMD method in-depth, check out these docs:\n", "- {doc}`../notebooks/Distributed_arrays_and_automatic_parallelization`\n", - "- {doc}`../notebooks/shard_map`" + "- {doc}`../notebooks/explicit-sharding`\n", + "- {doc}`../notebooks/shard_map`\n", + "- {doc}`../the-training-cookbook`" ] } ], diff --git a/docs/sharded-computation.md b/docs/sharded-computation.md index b05eb8d5f66e..3838176ec069 100644 --- a/docs/sharded-computation.md +++ b/docs/sharded-computation.md @@ -14,24 +14,40 @@ kernelspec: (sharded-computation)= # Introduction to parallel programming - + This tutorial serves as an introduction to device parallelism for Single-Program Multi-Data (SPMD) code in JAX. SPMD is a parallelism technique where the same computation, such as the forward pass of a neural network, can be run on different input data (for example, different inputs in a batch) in parallel on different devices, such as several GPUs or Google TPUs. The tutorial covers three modes of parallel computation: -- _Automatic parallelism via {func}`jax.jit`_: The compiler chooses the optimal computation strategy (a.k.a. "the compiler takes the wheel"). -- _Semi-automated parallelism_ using {func}`jax.jit` and {func}`jax.lax.with_sharding_constraint` -- _Fully manual parallelism with manual control using {func}`jax.experimental.shard_map.shard_map`_: `shard_map` enables per-device code and explicit communication collectives +- _Automatic sharding via {func}`jax.jit`_: The compiler chooses the optimal computation strategy (a.k.a. "the compiler takes the wheel"). +- *Explicit Sharding* (\*new\*) is similar to automatic sharding in that + you're writing a global-view program. The difference is that the sharding + of each array is part of the array's JAX-level type making it an explicit + part of the programming model. These shardings are propagated at the JAX + level and queryable at trace time. It's still the compiler's responsibility + to turn the whole-array program into per-device programs (turning `jnp.sum` + into `psum` for example) but the compiler is heavily constrained by the + user-supplied shardings. +- _Fully manual sharding with manual control using {func}`jax.shard_map`_: `shard_map` enables per-device code and explicit communication collectives + +A summary table: + +| Mode | View? | Explicit sharding? | Explicit Collectives? | +|---|---|---|---| +| Auto | Global | ❌ | ❌ | +| Explicit | Global | ✅ | ❌ | +| Manual | Per-device | ✅ | ✅ | Using these schools of thought for SPMD, you can transform a function written for one device into a function that can run in parallel on multiple devices. -If you are running these examples in a Google Colab notebook, make sure that your hardware accelerator is the latest Google TPU by checking your notebook settings: **Runtime** > **Change runtime type** > **Hardware accelerator** > **TPU v2** (which provides eight devices to work with). - ```{code-cell} -:outputId: 18905ae4-7b5e-4bb9-acb4-d8ab914cb456 - import jax + +jax.config.update('jax_num_cpu_devices', 8) +``` + +```{code-cell} jax.devices() ``` @@ -46,7 +62,9 @@ In the simplest cases, arrays are sharded on a single device, as demonstrated be ```{code-cell} :outputId: 39fdbb79-d5c0-4ea6-8b20-88b2c502a27a +import numpy as np import jax.numpy as jnp + arr = jnp.arange(32.0).reshape(4, 8) arr.devices() ``` @@ -90,31 +108,6 @@ print(arr_sharded) jax.debug.visualize_array_sharding(arr_sharded) ``` -+++ {"id": "UEObolTqw4pp"} - -The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the device. - -The {class}`~jax.sharding.NamedSharding` includes a parameter called `memory_kind`. This parameter determines the type of memory to be used and defaults to `device`. You can set this parameter to `pinned_host` if you prefer to place it on the host. - -To create a new sharding that only differs from an existing sharding in terms of its memory kind, you can use the `with_memory_kind` method on the existing sharding. - -```{code-cell} ---- -colab: - base_uri: https://localhost:8080/ -id: aKNeOHTJnqmS -outputId: 847c53ec-8b2e-4be0-f993-7fde7d77c0f2 ---- -s_host = jax.NamedSharding(mesh, P('x', 'y'), memory_kind='pinned_host') -s_dev = s_host.with_memory_kind('device') -arr_host = jax.device_put(arr, s_host) -arr_dev = jax.device_put(arr, s_dev) -print(arr_host.sharding.memory_kind) -print(arr_dev.sharding.memory_kind) -``` - -+++ {"id": "jDHYnVqHwaST"} - ## 1. Automatic parallelism via `jit` Once you have sharded data, the easiest way to do parallel computation is to simply pass the data to a {func}`jax.jit`-compiled function! In JAX, you need to only specify how you want the input and output of your code to be partitioned, and the compiler will figure out how to: 1) partition everything inside; and 2) compile inter-device communications. @@ -154,90 +147,96 @@ print(result) +++ {"id": "Q4N5mrr9i_ki"} -The result is partially replicated: that is, the first two elements of the array are replicated on devices `0` and `6`, the second on `1` and `7`, and so on. - -### 1.1 Sharding transformation between memory types +The result is partially replicated: that is, the first two elements of the array are replicated on devices `0` and `4`, the second on `1` and `5`, and so on. -The output sharding of a {func}`jax.jit` function can differ from the input sharding if you specify the output sharding using the `out_shardings` parameter. Specifically, the `memory_kind` of the output can be different from that of the input array. +## 2. Explicit sharding -#### Example 1: Pinned host to device memory - -In the example below, the {func}`jax.jit` function `f` takes an array sharded in `pinned_host` memory and generates an array in `device` memory. +The main idea behind explicit shardings, (a.k.a. sharding-in-types), is that +the JAX-level _type_ of a value includes a description of how the value is sharded. +We can query the JAX-level type of any JAX value (or Numpy array, or Python +scalar) using `jax.typeof`: ```{code-cell} ---- -colab: - base_uri: https://localhost:8080/ -id: PXu3MhafyRHo -outputId: 7bc6821f-a4a9-4cf8-8b21-e279d516d27b ---- -f = jax.jit(lambda x: x, out_shardings=s_dev) -out_dev = f(arr_host) -print(out_dev) -print(out_dev.sharding.memory_kind) +some_array = np.arange(8) +print(f"JAX-level type of some_array: {jax.typeof(some_array)}") ``` -+++ {"id": "LuYFqpcBySiX"} +Importantly, we can query the type even while tracing under a `jit` (the JAX-level type +is almost _defined_ as "the information about a value we have access to while +under a jit). + +```{code-cell} +@jax.jit +def foo(x): + print(f"JAX-level type of x during tracing: {jax.typeof(x)}") + return x + x -#### Example 2: Device to pinned_host memory +foo(some_array) +``` -In the example below, the {func}`jax.jit` function `g` takes an array sharded in `device` memory and generates an array in `pinned_host` memory. +To start seeing shardings in the type we need to set up an explicit-sharding mesh. ```{code-cell} ---- -colab: - base_uri: https://localhost:8080/ -id: qLsgNlKfybRw -outputId: a16448b9-7e39-408f-b200-505f65ad4464 ---- -g = jax.jit(lambda x: x, out_shardings=s_host) -out_host = g(arr_dev) -print(out_host) -print(out_host.sharding.memory_kind) +from jax.sharding import AxisType + +mesh = jax.make_mesh((2, 4), ("X", "Y"), + axis_types=(AxisType.Explicit, AxisType.Explicit)) ``` -+++ {"id": "7BGD31-owaSU"} +Now we can create some sharded arrays: -## 2. Semi-automated sharding with constraints +```{code-cell} +replicated_array = np.arange(8).reshape(4, 2) +sharded_array = jax.device_put(replicated_array, jax.NamedSharding(mesh, P("X", None))) + +print(f"replicated_array type: {jax.typeof(replicated_array)}") +print(f"sharded_array type: {jax.typeof(sharded_array)}") +``` -If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of {func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constraints how the intermediate values and outputs are distributed. +We should read the type `int32[4@X, 2]` as "a 4-by-2 array of 32-bit ints whose first dimension +is sharded along mesh axis 'X'. The array is replicated along all other mesh +axes" -For example, suppose that within `f_contract` above, you'd prefer the output not to be partially-replicated, but rather to be fully sharded across the eight devices: +These shardings associated with JAX-level types propagate through operations. For example: ```{code-cell} -:outputId: 8468f5c6-76ca-4367-c9f2-93c723687cfd +arg0 = jax.device_put(np.arange(4).reshape(4, 1), + jax.NamedSharding(mesh, P("X", None))) +arg1 = jax.device_put(np.arange(8).reshape(1, 8), + jax.NamedSharding(mesh, P(None, "Y"))) @jax.jit -def f_contract_2(x): - out = x.sum(axis=0) - sharding = jax.sharding.NamedSharding(mesh, P('x')) - return jax.lax.with_sharding_constraint(out, sharding) - -result = f_contract_2(arr_sharded) -jax.debug.visualize_array_sharding(result) -print(result) +def add_arrays(x, y): + ans = x + y + print(f"x sharding: {jax.typeof(x)}") + print(f"y sharding: {jax.typeof(y)}") + print(f"ans sharding: {jax.typeof(ans)}") + return ans + +with jax.set_mesh(mesh): + add_arrays(arg0, arg1) ``` -This gives you a function with the particular output sharding you'd like. +That's the gist of it. Shardings propagate deterministically at trace time and +we can query them at trace time. ## 3. Manual parallelism with `shard_map` -In the automatic parallelism methods explored above, you can write a function as if you're operating on the full dataset, and `jit` will split that computation across multiple devices. By contrast, with {func}`jax.experimental.shard_map.shard_map` you write the function that will handle a single shard of data, and `shard_map` will construct the full function. +In the automatic parallelism methods explored above, you can write a function as if you're operating on the full dataset, and `jit` will split that computation across multiple devices. By contrast, with {func}`jax.shard_map` you write the function that will handle a single shard of data, and `shard_map` will construct the full function. `shard_map` works by mapping a function across a particular *mesh* of devices (`shard_map` maps over shards). In the example below: - As before, {class}`jax.sharding.Mesh` allows for precise device placement, with the axis names parameter for logical and physical axis names. - The `in_specs` argument determines the shard sizes. The `out_specs` argument identifies how the blocks are assembled back together. -**Note:** {func}`jax.experimental.shard_map.shard_map` code can work inside {func}`jax.jit` if you need it. +**Note:** {func}`jax.shard_map` code can work inside {func}`jax.jit` if you need it. ```{code-cell} :outputId: 435c32f3-557a-4676-c11b-17e6bab8c1e2 -from jax.experimental.shard_map import shard_map mesh = jax.make_mesh((8,), ('x',)) -f_elementwise_sharded = shard_map( +f_elementwise_sharded = jax.shard_map( f_elementwise, mesh=mesh, in_specs=P('x'), @@ -259,7 +258,7 @@ def f(x): print(f"device local shape: {x.shape=}") return x * 2 -y = shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x) +y = jax.shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x) ``` Because each of your functions only "sees" the device-local part of the data, it means that aggregation-like functions require some extra thought. @@ -272,7 +271,7 @@ For example, here's what a `shard_map` of a {func}`jax.numpy.sum` looks like: def f(x): return jnp.sum(x, keepdims=True) -shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x) +jax.shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x) ``` Your function `f` operates separately on each shard, and the resulting summation reflects this. @@ -286,7 +285,7 @@ def f(x): sum_in_shard = x.sum() return jax.lax.psum(sum_in_shard, 'x') -shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P())(x) +jax.shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P())(x) ``` Because the output no longer has a sharded dimension, set `out_specs=P()` (recall that the `out_specs` argument identifies how the blocks are assembled back together in `shard_map`). @@ -320,32 +319,38 @@ layer(x, weights, bias) You can automatically run this in a distributed manner using {func}`jax.jit` and passing appropriately sharded data. -If you shard the leading axis of both `x` and `weights` in the same way, then the matrix multiplication will automatically happen in parallel: +If you shard the leading axis of both `x` and make `weights` fully replicated, +then the matrix multiplication will automatically happen in parallel: ```{code-cell} :outputId: 80be899e-8dbc-4bfc-acd2-0f3d554a0aa5 mesh = jax.make_mesh((8,), ('x',)) -sharding = jax.sharding.NamedSharding(mesh, P('x')) - -x_sharded = jax.device_put(x, sharding) -weights_sharded = jax.device_put(weights, sharding) +x_sharded = jax.device_put(x, jax.NamedSharding(mesh, P('x'))) +weights_sharded = jax.device_put(weights, jax.NamedSharding(mesh, P())) layer(x_sharded, weights_sharded, bias) ``` -Alternatively, you can use {func}`jax.lax.with_sharding_constraint` in the function to automatically distribute unsharded inputs: +Alternatively, you can use explicit sharding mode too: ```{code-cell} -:outputId: bb63e8da-ff4f-4e95-f083-10584882daf4 +explicit_mesh = jax.make_mesh((8,), ('X',), axis_types=(AxisType.Explicit,)) + +x_sharded = jax.device_put(x, jax.NamedSharding(explicit_mesh, P('X'))) +weights_sharded = jax.device_put(weights, jax.NamedSharding(explicit_mesh, P())) @jax.jit def layer_auto(x, weights, bias): - x = jax.lax.with_sharding_constraint(x, sharding) - weights = jax.lax.with_sharding_constraint(weights, sharding) - return layer(x, weights, bias) - -layer_auto(x, weights, bias) # pass in unsharded inputs + print(f"x sharding: {jax.typeof(x)}") + print(f"weights sharding: {jax.typeof(weights)}") + print(f"bias sharding: {jax.typeof(bias)}") + out = layer(x, weights, bias) + print(f"out sharding: {jax.typeof(out)}") + return out + +with jax.set_mesh(explicit_mesh): + layer_auto(x_sharded, weights_sharded, bias) ``` Finally, you can do the same thing with `shard_map`, using {func}`jax.lax.psum` to indicate the cross-shard collective required for the matrix product: @@ -356,7 +361,7 @@ Finally, you can do the same thing with `shard_map`, using {func}`jax.lax.psum` from functools import partial @jax.jit -@partial(shard_map, mesh=mesh, +@partial(jax.shard_map, mesh=mesh, in_specs=(P('x'), P('x', None), P(None)), out_specs=P(None)) def layer_sharded(x, weights, bias): @@ -365,10 +370,79 @@ def layer_sharded(x, weights, bias): layer_sharded(x, weights, bias) ``` +(sharded-data-placement)= +## Controlling data and computation placement on devices + +Let's look at the principles of data and computation placement in JAX. + +In JAX, the computation follows data placement. JAX arrays have two placement +properties: 1) the device where the data resides; and 2) whether it is +**committed** to the device or not (the data is sometimes referred to as being +*sticky* to the device). + +By default, JAX arrays are placed uncommitted on the default device +(`jax.devices()[0]`), which is the first GPU or TPU by default. If no GPU or +TPU is present, `jax.devices()[0]` is the CPU. The default device can be +temporarily overridden with the {func}`jax.default_device` context manager, or +set for the whole process by setting the environment variable `JAX_PLATFORMS` +or the absl flag `--jax_platforms` to "cpu", "gpu", or "tpu" (`JAX_PLATFORMS` +can also be a list of platforms, which determines which platforms are available +in priority order). + +```python +>>> from jax import numpy as jnp +>>> print(jnp.ones(3).devices()) # doctest: +SKIP +{CudaDevice(id=0)} +``` + +Computations involving uncommitted data are performed on the default device and +the results are uncommitted on the default device. + +Data can also be placed explicitly on a device using {func}`jax.device_put` with +a `device` parameter, in which case the data becomes **committed** to the +device: + +```python +>>> import jax +>>> from jax import device_put +>>> arr = device_put(1, jax.devices()[2]) # doctest: +SKIP +>>> print(arr.devices()) # doctest: +SKIP +{CudaDevice(id=2)} +``` + +Computations involving some committed inputs will happen on the committed device +and the result will be committed on the same device. Invoking an operation on +arguments that are committed to more than one device will raise an error. + +You can also use {func}`jax.device_put` without a `device` parameter. If the +data is already on a device (committed or not), it's left as-is. If the data +isn't on any device—that is, it's a regular Python or NumPy value—it's placed +uncommitted on the default device. + +Jitted functions behave like any other primitive operations—they will follow the +data and will show errors if invoked on data committed on more than one device. + +(Before [PR #6002](https://github.com/jax-ml/jax/pull/6002) in March 2021 +there was some laziness in creation of array constants, so that +`jax.device_put(jnp.zeros(...), jax.devices()[1])` or similar would actually +create the array of zeros on `jax.devices()[1]`, instead of creating the +array on the default device then moving it. But this optimization was removed +so as to simplify the implementation.) + +(As of April 2020, {func}`jax.jit` has a `device` parameter that affects the device +placement. That parameter is experimental, is likely to be removed or changed, +and its use is not recommended.) + +For a worked-out example, we recommend reading through +`test_computation_follows_data` in +[multi_device_test.py](https://github.com/jax-ml/jax/blob/main/tests/multi_device_test.py). + ## Next steps This tutorial serves as a brief introduction of sharded and parallel computation in JAX. To learn about each SPMD method in-depth, check out these docs: - {doc}`../notebooks/Distributed_arrays_and_automatic_parallelization` +- {doc}`../notebooks/explicit-sharding` - {doc}`../notebooks/shard_map` +- {doc}`../the-training-cookbook` diff --git a/docs/shardy_jax_migration.md b/docs/shardy_jax_migration.md new file mode 100644 index 000000000000..308ba11c88df --- /dev/null +++ b/docs/shardy_jax_migration.md @@ -0,0 +1,155 @@ +--- +orphan: true +--- +(shardy-jax-migration)= +# Shardy JAX Migration + + + +## TL;DR + +### What’s going on? + +[Shardy](https://openxla.org/shardy) is a new partitioning system co-developed +by GDM Model Scaling (author of [PartIR](https://arxiv.org/abs/2401.11202)) and +XLA/CoreML teams (author of [GSPMD](https://arxiv.org/abs//2105.04663)). Shardy +aims to provide better usability and control to users, and will gradually +replace GSPMD and PartIR. + +After the migration is complete in March 2026, Shardy will be the only +partitioner in JAX. + +Until then, as a temporary workaround for any problems, Shardy +[can be disabled](#how-can-i-disable-shardy-for-now). Please file a +[JAX issue](https://github.com/jax-ml/jax/issues) if you encounter any problem. + +### How do I know if Shardy broke my code? + +The easiest way to tell if Shardy is responsible for any problems is to disable +Shardy and see if the issues go away. See +[What issues can arise when Shardy is switched on?](#what-issues-can-arise-when-shardy-is-switched-on) +section below. + +You can tell that Shardy is enabled by looking for +`Using Shardy for XLA SPMD propagation in the logs`. + +### How can I disable Shardy for now? + +Until March, 2026 it will be possible to temporarily disable Shardy by: + + * setting the shell environment variable `JAX_USE_SHARDY_PARTITIONER` to + something false-like (e.g., 0); + + * setting the boolean flag `jax_use_shardy_partitioner` to something + false-like if your code parses flags with absl; + + * using this statement in your main file or anywhere before you call + `jax.jit`: + + ``` python + import jax + jax.config.update('jax_use_shardy_partitioner', False) + ``` + +To debug partitioning with Shardy enabled, you can enable MLIR dumps as follows: + +``` +--xla_dump_hlo_pass_re=shardy --xla_dump_to= +``` + +NOTE: Please disable only the specific use cases that are not working as +expected if possible, and file a [bug](https://github.com/jax-ml/jax/issues) +with a reproducer, so we can resolve it asap and re-enable Shardy. + +### JAX export backwards compatibility + +Enabling Shardy in JAX by default is maintaining the 6 months backwards +compatibility guarantee. This means that you will be able to load a model +exported with Shardy disabled for at least 6 months after Shardy becomes enabled +for your model. That old checkpointed model will run with GSPMD, and only when +re-exporting the model will it start running with Shardy. + +However, if you still encounter an issue with loading an old checkpoint, please +contact us or file a [bug](https://github.com/jax-ml/jax/issues). + +NOTE: exporting a model with Shardy enabled, then loading it with Shardy +disabled isn’t supported and will fail. + +### How do I prepare for Shardy being enabled in March 2026 permanently? + +Due to us falling back to GSPMD for any JAX export checkpoint for 6 months, to +help find any potential issues, please re-export any models you have with Shardy +enabled. Then you can see if it runs fine, or there is any bug we need to fix. + +## What issues can arise when Shardy is switched on? + +### Performance regression or OOM + +While Shardy improves on the existing sharding propagation systems (GSPMD and +PartIR), it can sometimes output slightly different results due to different +propagation order or conflict resolution heuristics. + +This doesn’t necessarily mean that Shardy is doing the wrong thing, but possibly +that there aren't enough sharding constraints in the program, so a small change +in propagation order can affect the final result. It can also hint that existing +sharding constraints were overfitted to GSPMD and require slight adjustments +with Shardy. + +Therefore, it is possible that enabling Shardy will cause some models to have a +performance regression or OOM (especially if the model was already close to the +memory capacity). However, we have already migrated many use cases across +Alphabet, and have observed equivalent or better performance than GSPMD. + +To resolve such issues, users can either: + +1. Disable Shardy temporarily and open a [bug](https://github.com/jax-ml/jax/issues) + with a reproducer. +2. Add additional sharding constraints to make sure Shardy does the desired + thing. + +### Compilation failure + +We have done extensive testing across many JAX models. However, it’s possible +that there are certain edge cases or situations we don’t support/handle (because +we didn't know we needed to). + +This means that although rare, it’s possible that you will get a compilation +failure in the form of a segfault, hard check, python value error, etc. + +In such a case, please disable Shardy temporarily and open a +[bug](https://github.com/jax-ml/jax/issues) with a reproducer. + +### Inconsistent value of the use Shardy flag + +If Shardy is disabled somewhere in your code, but there are still paths that use +the default value of the JAX flag, this can cause issues. For example, exporting +a model with Shardy enabled, then loading it with Shardy disabled isn’t +supported and will fail (the other way is supported for +[backwards compatibility](#jax-export-backwards-compatibility)). + +The symptom for an issue like this can be an error in JAX or in XLA/Shardy, or +just undefined behavior. You can try disabling Shardy globally in +[JAX config](https://github.com/jax-ml/jax/blob/main/jax/_src/config.py) to see +if the issue goes away. + +NOTE: Please ensure that Shardy is disabled consistently if needed, or remove +any explicit modification of the flag, to have the default value apply +throughout. + +### New way to use the JAX `jax.experimental.custom_partitioning` API + +If you use this API, you may see the error + +``` +Shardy is used, but sharding propagation callbacks instead of sharding_rule are +provided. Need to provide sharding_rule to migrate to Shardy. +``` + +Instead of defining `infer_sharding_from_operands` and `propagate_user_sharding` +callbacks, define a `jax.experimental.SdyShardingRule` that specifies an einsum-like relationship between dimensions during propagation. Refer to the [`custom_partitioning` doc](https://docs.jax.dev/en/latest/jax.experimental.custom_partitioning.html#module-jax.experimental.custom_partitioning) +for more info on how to define a sharding rule. + +### `jax.export` requires all inputs and outputs to have the same mesh + +As part of the Shardy migration, `jax.export` now requires all input/output +shardings to live on the same mesh - same axis names and sizes. diff --git a/docs/sphinxext/jax_list_config_options.py b/docs/sphinxext/jax_list_config_options.py new file mode 100644 index 000000000000..e4dee1852333 --- /dev/null +++ b/docs/sphinxext/jax_list_config_options.py @@ -0,0 +1,172 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from operator import itemgetter +from typing import Any, List + +from docutils import nodes +from sphinx.util import logging +from sphinx.util.docutils import SphinxDirective + +logger = logging.getLogger(__name__) + +_deprecations = ( + 'jax_default_dtype_bits', # an experiment that we never documented, but we can't remove it because Keras depends on its existing broken behavior + 'jax_serialization_version' +) + +def create_field_item(label, content): + """Create a field list item with a label and content side by side. + + Args: + label: The label text for the field name + content: The content to add (a node or text) + + Returns: + A field list item with the label and content side by side. + """ + # Create a field list item + field = nodes.field() + + # Create the field name (label) + field_name = nodes.field_name() + field_name += nodes.Text(label) + field += field_name + + # Create the field body (content) + field_body = nodes.field_body() + + if isinstance(content, str): + para = nodes.paragraph() + para += nodes.Text(content) + field_body += para + elif isinstance(content, nodes.Node): + field_body += content + + field += field_body + return field + +class ConfigOptionDirective(SphinxDirective): + required_arguments = 0 + optional_arguments = 0 + has_content = False + + def run(self) -> List[nodes.Node]: + from jax._src.config import config as jax_config + + config_options = sorted(jax_config.meta.items(), key=itemgetter(0)) + result = [] + + for name, (opt_type, meta_args, meta_kwargs) in config_options: + if name in _deprecations: + continue + + holder = jax_config._value_holders[name] + + # Create target for linking + target = nodes.target() + target['ids'].append(name) + result.append(target) + + # Create a section for this option + option_section = nodes.section() + option_section['ids'].append(name) + option_section['classes'].append('config-option-section') + + # Create a title with the option name (important for TOC) + title = nodes.title() + title['classes'] = ['h4'] + title += nodes.Text(name.replace("jax_", "").replace("_", " ").title()) + option_section += title + + # Create a field list for side-by-side display + field_list = nodes.field_list() + field_list['classes'].append('config-field-list') + + # Add type information as a field item + if opt_type == "enum": + type_para = nodes.paragraph() + emphasis_node = nodes.emphasis() + emphasis_node += nodes.Text("Enum values: ") + type_para += emphasis_node + + for i, value in enumerate(enum_values := meta_kwargs.get('enum_values', [])): + type_para += nodes.literal(text=repr(value)) + if i < len(enum_values) - 1: + type_para += nodes.Text(", ") + elif opt_type == "enum_class": + type_para = nodes.paragraph() + emphasis_node = nodes.emphasis() + emphasis_node += nodes.Text("Enum values: ") + type_para += emphasis_node + + enum_class = meta_kwargs.get('enum_class') + members = enum_class.__members__ + for i, value in enumerate(members.keys()): + type_para += nodes.literal(text=value) + if i < len(members) - 1: + type_para += nodes.Text(", ") + else: + type_para = nodes.paragraph() + type_para += nodes.literal(text=opt_type.__name__) + + field_list += create_field_item("Type", type_para) + + # Add default value information + default_para = nodes.paragraph() + default_para += nodes.literal(text=repr(holder.value)) + field_list += create_field_item("Default Value", default_para) + + # Add configuration string information + string_para = nodes.paragraph() + string_para += nodes.literal(text=repr(name)) + field_list += create_field_item("Configuration String", string_para) + + string_para = nodes.paragraph() + string_para += nodes.literal(text=name.upper()) + field_list += create_field_item("Environment Variable", string_para) + + # Add the field list to the section + option_section += field_list + + # Add help text in a description box + if (help_text := meta_kwargs.get('help')): + help_para = nodes.paragraph() + # logger.error(name) + # logger.warning(help_text) + + # If we get here, help text seems valid - proceed with normal parsing + # parsed = nodes.Text(help_text) + help_para += self.parse_text_to_nodes(help_text) + + option_section += help_para + + result.append(option_section) + # Add an extra paragraph to ensure proper separation + result.append(nodes.paragraph()) + result.append(nodes.paragraph()) # ensure new line + + return result + + def get_location(self) -> Any: + return (self.env.docname, self.lineno) + +def setup(app): + app.add_directive("list_config_options", ConfigOptionDirective) + + return { + "version": "0.1", + "parallel_read_safe": True, + "parallel_write_safe": True, + } diff --git a/docs/sphinxext/source_include.py b/docs/sphinxext/source_include.py new file mode 100644 index 000000000000..12f2185fddfb --- /dev/null +++ b/docs/sphinxext/source_include.py @@ -0,0 +1,106 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import ast +from pathlib import Path +from docutils import nodes +from sphinx.util.docutils import SphinxDirective +from sphinx.util.logging import getLogger + +logger = getLogger(__name__) + + +# (The parse_lines_spec and get_tagged_block functions are unchanged) +def parse_lines_spec(spec: str) -> list[int]: + items = [] + if not spec: + return items + for part in spec.split(","): + part = part.strip() + if "-" in part: + start, end = part.split("-", 1) + items.extend(range(int(start), int(end) + 1)) + else: + items.append(int(part)) + return items + + +def get_tagged_block(filepath, tag, lines_spec=None): + try: + full_path = Path(filepath) + if not full_path.exists(): + raise FileNotFoundError(f"Source file not found at {full_path}") + content_full = full_path.read_text() + regex_pattern = rf"# tag: {tag}\n(.*?)\s*# tag: {tag}" + pattern = re.compile(regex_pattern, re.DOTALL) + match = pattern.search(content_full) + if not match: + raise ValueError(f"Tag '{tag}' not found in '{filepath}'") + content = match.group(1).strip("\n") + if lines_spec is None: + return content + line_list = content.split("\n") + if lines_spec.startswith("[") and lines_spec.endswith("]"): + indexer = ast.literal_eval(lines_spec) + final_lines = [line_list[i] for i in indexer] + elif ":" in lines_spec: + parts_str = (lines_spec.split(":") + ["", "", ""])[:3] + indexer = slice(*(int(p.strip()) if p.strip() else None for p in parts_str)) + final_lines = line_list[indexer] + else: + indexer = int(lines_spec) + final_lines = [line_list[indexer]] + if not final_lines: + return "" + try: + indent_level = len(final_lines[0]) - len(final_lines[0].lstrip()) + return "\n".join(line[indent_level:] for line in final_lines) + except IndexError: + return "" + except Exception as e: + logger.warning(f"Error processing tagged_block: {e}") + return f"Error processing tagged_block for tag '{tag}' in '{filepath}'." + + +class TaggedBlockDirective(SphinxDirective): + has_content = False + required_arguments = 2 + optional_arguments = 1 + option_spec = { + "hl_lines": str, + } + + def run(self): + source_dir = Path(self.env.srcdir) + filepath = source_dir / self.arguments[0] + tag = self.arguments[1] + lines_spec = self.arguments[2] if len(self.arguments) > 2 else None + code = get_tagged_block(filepath, tag, lines_spec) + literal = nodes.literal_block(code, code) + literal["language"] = "python" + if "hl_lines" in self.options: + highlight_lines = parse_lines_spec(self.options["hl_lines"]) + literal["highlight_args"] = {"hl_lines": highlight_lines} + return [literal] + + +def setup(app): + app.add_directive("tagged-block", TaggedBlockDirective) + # This dictionary fixes the "parallel reading" warning + return { + "version": "0.1", + "parallel_read_safe": True, + "parallel_write_safe": True, + } diff --git a/docs/stateful-computations.md b/docs/stateful-computations.md index 30c626bec4e3..5c1100dee1d8 100644 --- a/docs/stateful-computations.md +++ b/docs/stateful-computations.md @@ -20,7 +20,7 @@ kernelspec: JAX transformations like {func}`~jax.jit`, {func}`~jax.vmap`, {func}`~jax.grad`, require the functions they wrap to be pure: that is, functions whose outputs depend *solely* on the inputs, and which have no side effects such as updating of global state. -You can find a discussion of this in [JAX sharp bits: Pure functions](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions). +You can find a discussion of this in [JAX sharp bits: Pure functions](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions). This constraint can pose some challenges in the context of machine learning, where state may exist in many forms. For example: @@ -235,4 +235,4 @@ Handling parameters manually seems fine if you're dealing with two parameters, b 2) Are we supposed to pipe all these things around manually? -The details can be tricky to handle, but there are examples of libraries that take care of this for you. See [JAX Neural Network Libraries](https://github.com/jax-ml/jax#neural-network-libraries) for some examples. +The details can be tricky to handle, but there are examples of libraries that take care of this for you. See [JAX Ecosystem Libraries](https://docs.jax.dev/en/latest/#ecosystem) for some examples. diff --git a/docs/the-training-cookbook.py b/docs/the-training-cookbook.py new file mode 100644 index 000000000000..edfa581920e2 --- /dev/null +++ b/docs/the-training-cookbook.py @@ -0,0 +1,250 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools as ft +import itertools as it +import time +from dataclasses import dataclass +from typing import Iterator + +import jax +import jax.numpy as jnp +import numpy as np +from jax.sharding import AxisType + +ode = """ +We are the music makers + And we are the dreamers of dreams +Wandering by lone sea-breakers + And sitting by desolate streams; +World-losers and world-forsakers + On whom the pale moon gleams +Yet we are the movers and shaker + Of the world for ever, it seems +""" + +# tag: config +@jax.tree_util.register_static +@dataclass(kw_only=True, frozen=True) +class Config: + mesh_axis_names: tuple[str, ...] = ("fsdp",) + mesh_shape: tuple[int, ...] = (8,) + seq_length: int = 128 + + num_train_steps: int = 10**6 + host_batch_size: int = 16 + learning_rate: float = 1e-4 + beta_1: float = 0.9 + beta_2: float = 0.999 + eps: float = 1e-8 + eps_root: float = 0.0 + + param_seed: int = 12738 + num_layers: int = 4 + embed_dim: int = 512 + mlp_dim: int = 512 * 4 + vocab_size: int = 2**8 # uint8 ascii encoding + num_heads: int = 8 + head_dim: int = 128 + dtype: str = "bfloat16" + + embed: jax.P = jax.P(None, None) + pos_embed: jax.P = jax.P(None, None) + att_qkv: jax.P = jax.P(None, "fsdp", None, None) + att_out: jax.P = jax.P("fsdp", None, None) + mlp_in: jax.P = jax.P("fsdp", None) + mlp_out: jax.P = jax.P(None, "fsdp") + in_kernel: jax.P = jax.P(None, None) + in_bias: jax.P = jax.P(None) + out_kernel: jax.P = jax.P("fsdp", None) + out_bias: jax.P = jax.P(None) + + act_ids: jax.P = jax.P("fsdp") + act_seq: jax.P = jax.P("fsdp", None, None) + act_att: jax.P = jax.P("fsdp", None, None, None) + act_hidden: jax.P = jax.P("fsdp", None, None) + + def __post_init__(self): + mesh = jax.make_mesh(self.mesh_shape, self.mesh_axis_names, len(self.mesh_shape) * (AxisType.Explicit,)) + jax.sharding.set_mesh(mesh) + # tag: config + + +@jax.tree_util.register_pytree_with_keys_class +class dot_dict(dict): + __setattr__ = dict.__setitem__ + __getattr__ = dict.__getitem__ + + def tree_flatten_with_keys(self): + keys = tuple(sorted(self)) + return tuple((jax.tree_util.DictKey(k), self[k]) for k in keys), keys + + @classmethod + def tree_unflatten(cls, keys, values): + return cls(zip(keys, values)) + + +# tag: get-param-state +def init_param_state(config: Config) -> dot_dict: + root_key = jax.random.key(config.param_seed) + key = map(ft.partial(jax.random.fold_in, root_key), it.count()) + zero_init = jax.nn.initializers.constant(0.0) + he_init = jax.nn.initializers.he_normal(1, 1) + dtype = config.dtype + + params = dot_dict( + pos_embed=zero_init(next(key), (config.seq_length, config.embed_dim), dtype, config.pos_embed), + layers=dot_dict(), + ) + params.embedding = he_init(next(key), (config.vocab_size, config.embed_dim), dtype, config.embed) + params.linear_in = dot_dict( + kernel=he_init(next(key), (1, config.embed_dim), dtype, config.in_kernel), + bias=zero_init(next(key), (config.embed_dim,), dtype, config.in_bias), + ) + params.linear_out = dot_dict( + kernel=he_init(next(key), (config.embed_dim, config.vocab_size), dtype, config.out_kernel), + ) + for layer in range(config.num_layers): + qkv_shape = (3, config.embed_dim, config.num_heads, config.head_dim) + out_shape = (config.num_heads, config.head_dim, config.embed_dim) + params.layers[layer] = dot_dict( + attention=dot_dict( + qkv=he_init(next(key), qkv_shape, dtype, config.att_qkv), + out=he_init(next(key), out_shape, dtype, config.att_out), + ), + mlp=dot_dict( + in_kernel=he_init(next(key), (config.embed_dim, config.mlp_dim), dtype, config.mlp_in), + out_kernel=he_init(next(key), (config.mlp_dim, config.embed_dim), dtype, config.mlp_out), + ), + ) + return params # tag: get-param-state + + +# tag: model-apply +def model_apply(config: Config, params: dot_dict, tokens: jax.Array) -> jax.Array: + out = params.embedding.at[tokens].get(out_sharding=config.act_seq) + out += params.pos_embed + del tokens + + for layer in range(config.num_layers): + block = params.layers[layer] + att_skip = out # 1 billion dollars in venture capital funding please + qkv = jnp.einsum("bsd,3dkh->bs3kh", out, block.attention.qkv, out_sharding=config.act_att) + out = jax.nn.dot_product_attention(qkv[:, :, 0, :], qkv[:, :, 1, :], qkv[:, :, 2, :], is_causal=True) + out = jnp.einsum("bskh,khd->bsd", out, block.attention.out, out_sharding=config.act_seq) + out += att_skip + out *= jax.lax.rsqrt(jnp.linalg.norm(out, axis=-1, keepdims=True) + 1e-6) + + mlp_skip = out # machine learning circa 1986 + out = jnp.einsum("bsd,dh->bsh", out, block.mlp.in_kernel, out_sharding=config.act_hidden) + out = jax.nn.gelu(out) + out = jnp.einsum("bsh,hd->bsd", out, block.mlp.out_kernel, out_sharding=config.act_seq) + out += mlp_skip + out *= jax.lax.rsqrt(jnp.linalg.norm(out, axis=-1, keepdims=True) + 1e-6) + + logits = jnp.einsum("bsd,dl->bsl", out, params.linear_out.kernel, out_sharding=config.act_seq) + return logits # tag: model-apply + + +# tag: get-adam-state +def init_adam_state(param: jax.Array) -> dot_dict: + adam_state = dot_dict(mu=jnp.zeros_like(param), nu=jnp.zeros_like(param), count=jnp.array(0)) + return adam_state # tag: get-adam-state + + +# tag: adam-apply +def adam_update(config: Config, param: jax.Ref, grad: jax.Array, adam_state: dot_dict): + adam_state.mu[...] = (1 - config.beta_1) * adam_state.mu[...] + config.beta_1 * grad + adam_state.nu[...] = (1 - config.beta_2) * adam_state.nu[...] + config.beta_2 * grad**2 + adam_state.count[...] += 1 + + mu_hat = adam_state.mu[...] / (1 - config.beta_1 ** adam_state.count[...]) + nu_hat = adam_state.nu[...] / (1 - config.beta_2 ** adam_state.count[...]) + param[...] -= config.learning_rate * mu_hat / (jnp.sqrt(nu_hat + config.eps_root) + config.eps) + # tag: adam-apply + + +# tag: get-train-state +@jax.jit +def init_train_state(config: Config) -> dot_dict: + train_state = dot_dict() + train_state.params = init_param_state(config) + train_state.opt = jax.tree.map(init_adam_state, train_state.params) + return train_state # tag: get-train-state + + +# tag: train-step +@jax.jit +def train_step(config: Config, train_state: dot_dict, batch: dict) -> dict: + def loss_fn(params): + logits = model_apply(config, params, batch["observed_ids"]) + labels = jax.nn.one_hot(batch["target_ids"], config.vocab_size) + return -(labels * jax.nn.log_softmax(logits)).mean() + + params = jax.tree.map(jax.ref.get, train_state.params) + loss, grad = jax.value_and_grad(loss_fn)(params) + jax.tree.map(ft.partial(adam_update, config), train_state.params, grad, train_state.opt) + metrics = {"train_loss": loss} + return metrics # tag: train-step + + +# tag: record-writer +class RecordWriter: + prev_metrics = None + + def __call__(self, cur_metrics: dict): + self.prev_metrics, log_metrics = cur_metrics, self.prev_metrics + if log_metrics is None: + return + print(*it.starmap("{}: {}".format, log_metrics.items()), sep="\t") + # tag: record-writer + + +# tag: get-dataset +def get_dataset(config: Config, single_batch=ode) -> Iterator[dict[str, np.ndarray]]: + while True: + observed_array = np.frombuffer(single_batch.encode("ascii"), dtype=np.uint8) + target_array = np.roll(observed_array, -1) + time.sleep(0.5) + yield { # repeat the sequence across the batch size to simulate multiple data points + "observed_ids": np.tile(observed_array[: config.seq_length], (config.host_batch_size, 1)), + "target_ids": np.tile(target_array[: config.seq_length], (config.host_batch_size, 1)), + } + # tag: get-dataset + + +# tag: get-dataset-on-device +def get_dataset_on_device(config: Config) -> Iterator[dict[str, jax.Array]]: + datset = get_dataset(config) + sharding = jax.P(config.mesh_axis_names) + return map(ft.partial(jax.make_array_from_process_local_data, sharding), datset) + # tag: get-dataset-on-device + + +# tag: train-loop +def train_loop(config: Config): + record_writer = RecordWriter() + train_state = init_train_state(config) + train_state = jax.tree.map(jax.ref.new_ref, train_state) + batch = iter(get_dataset_on_device(config)) + for step in range(config.num_train_steps): + metrics = train_step(config, train_state, next(batch)) + record_writer({"step": step} | metrics) + # tag: train-loop + + +if __name__ == "__main__": + jax.config.update("jax_platform_name", "cpu") + jax.config.update("jax_num_cpu_devices", 8) + train_loop(config=Config()) diff --git a/docs/the-training-cookbook.rst b/docs/the-training-cookbook.rst new file mode 100644 index 000000000000..4059ec5f8ed2 --- /dev/null +++ b/docs/the-training-cookbook.rst @@ -0,0 +1,355 @@ +===================== +The Training Cookbook +===================== + +Traditionally, machine learning codebases rely on libraries to perform much of the bookkeeping and parameter wrangling necessary for training large, complex models. While convenient, these libraries can abstract the key functionality and core APIs offered in JAX. The purpose of this cookbook, therefore, is to demonstrate best practices (or "recipes") for writing simple yet high-performance machine learning training code directly in JAX. Following the patterns documented below will prepare your machine learning workloads to maximally leverage our compiler (XLA) for performance and tractability. Most training scripts adhere roughly to the following structure: + +.. tagged-block:: the-training-cookbook.py train-loop + +For each line of code above, we will explain the best practices and showcase the core technologies we have assembled to empower you to write simple, yet unbelievably performant code in JAX. The code above is a segment of a self-contained, completely functional `companion script `_ in which we initialize a `Vaswani et al. (2017) `_ Transformer decoder, define the training loss for next-token prediction, and `Adam optimizer `_, in pure JAX. The code therein is suited to TPUs, CPUs, and GPUs, as well as single- and multi-host systems. For that reason, we use the terms *device* or *accelerator* to refer interchangeably to the hardware JAX is primarily performing arithmetic on—whether it be a TPU, GPU, or CPU—and *host system* to refer to operations performed exclusively using the host CPU. In this guide, there are many aspects of the JAX APIs we will gloss over for the sake of expediency. These are available for you to peruse at your leisure in our API documentation. However, there is a central JAX concept that one must confront in detail for much of what follows to cohere. + +Device Mesh and Shardings +------------------------- + +JAX employs the `Single Program, Multiple Data (SPMD) `_ model of parallelism. This means we write a single program that runs on multiple devices, using annotations to specify which part of the data each device is responsible for. The two primary concepts for this are the :class:`jax.sharding.Mesh` and :class:`jax.P`. + +Device Mesh +~~~~~~~~~~~ +A :class:`jax.sharding.Mesh` is an arrangement of all our accelerators into a NumPy ``ndarray``, together with string labels for the axes of the device array. The reason for using an array is that this allows for a very convenient annotation for how arrays should be partitioned across devices. For this introduction, we will use the notation of an ordered dictionary [#ordered]_, so that ``{"x": 2, "y": 4}`` refers to a device mesh of shape ``(2, 4)`` with labeled axes ``"x"`` and ``"y"``. To shard an array ``param``, we decorate it with a :class:`jax.P`, which is a tuple of ``str | None`` elements of the same length as the dimensions of the array. The ``jax.P`` specifies which axes of our array are to be sharded over which axes of devices. A more thorough account of the notation of shardings and sharded computations is available in :ref:`sharded-computation`. Some common sharding strategies such as data parallel, fully sharded data parallel, and basic tensor parallelism will be covered in :ref:`achieving-high-performance`. + +.. admonition:: Example + + Suppose we have a device mesh of ``{"x": 2, "y": 4}`` and an array ``param`` of shape ``(32, 64, 64, 128)``. If we shard this array with `jax.P(None, "x", "y", None) `, we end up with shards of size ``(32, 32, 16, 128)`` distributed across the devices. The ``None`` indicates that an axis should not be sharded. JAX implicitly broadcasts trailing axes, so an identical sharding can be achieved more concisely with `jax.P(None, "x", "y")`. As a result, the shorthand for a fully replicated array (of any dimension) is `jax.P()`. + +.. admonition:: Example + + More advanced mesh geometries are convenient when aligned with the communication hierarchy of our devices. Host-to-host communication is typically slower than accelerator-to-accelerator communication. Suppose we have two host machines, each with eight attached GPUs. One might arrange the devices into a mesh of ``{"host": 2, "gpu": 8}``. Then we can shard a parameter as follows: + + .. code-block:: python + + param = jnp.zeros((256, 192), out_sharding=jax.P("gpu", None)) + + The whole of ``param`` will be replicated twice, but within each host, it will be spread across the eight locally attached GPUs, with each GPU storing a shard of shape ``(32, 192)`` in HBM. This is particularly useful for :ref:`fsdp-sharding`. + + +Train State Initialization +-------------------------- + +.. tagged-block:: the-training-cookbook.py get-train-state + :hl_lines: 4 + +Before we can get started, the first thing we need to do is set up the train state. The train state encapsulates (unsurprisingly) all the *stateful* aspects of the training process. This typically includes, at a minimum, the model parameters and the optimizer state. The way we have structured this function (though you may choose to do otherwise) is to: + +1. Create a series of nested dictionaries to house the model parameters, and then + +2. :func:`jax.tree.map` over those parameters to produce a similar set of nested dictionaries to house the accompanying optimizer states. (More on this `below <#optimizer-initialization>`_.) + +Parameter Initialization +~~~~~~~~~~~~~~~~~~~~~~~~ +.. tagged-block:: the-training-cookbook.py get-train-state + :hl_lines: 4 + +To initialize our parameters, we build a series of nested dictionaries that correspond to the semantic sections of the neural network. If we were using a layer-based library such as PyTorch or Flax, these might correspond to neural network layers. For this example, we could, in fact, get by with a completely flattened dictionary, but the nested approach is convenient both for working with some of the APIs in JAX and for structuring our code. + +.. tagged-block:: the-training-cookbook.py get-param-state + +Our ``get_param_state`` function makes use of the ``constant`` and ``he_normal`` factories provided in :mod:`jax.nn.initializers`. These factories return an *initializer*, which is a function conforming to the following protocol: + +.. code-block:: python + + class Initializer(Protocol): + def __call__(self, key, shape, dtype, out_sharding) -> jax.Array: + ... + +The functional flavor of JAX requires explicit handling of all stochasticity (viz. :ref:`pseudorandom-numbers`), so we set up a little iterator that yields PRNG keys. Then, to build our parameters, we initialize them at their respective positions in the ``params`` nested dictionary, supplying the parameter shape, dtype, and sharding from the ``Config`` class. + +.. note:: + + By specifying the shardings here, we initialize each shard of each parameter directly on the correct device in the device mesh where it needs to be, preventing the need for needless host-to-device transfers or, in the case of a model that does not fit in system memory, avoiding out-of-memory errors. + +Optimizer Initialization +~~~~~~~~~~~~~~~~~~~~~~~~ +.. tagged-block:: the-training-cookbook.py get-train-state + :hl_lines: 5 + +When it comes to setting up the optimizer state, things are a little less straightforward than when we built the model parameters. The `Adam optimizer `_ requires that, for each parameter, we keep track of three optimization states: ``mu``, ``nu``, and ``count``. The simplest of these is ``count``, which stores the number of training steps we have performed. This is just a scalar used to de-bias the Adam updates. The ``mu`` and ``nu`` states will be arrays of the same shape, dtype, and sharding as the accompanying parameter ``param`` [#zeros_like]_ + +.. tagged-block:: the-training-cookbook.py get-adam-state + +When we use :func:`jax.tree.map`, it iterates over the items in ``train_state.params``. For each parameter, it creates a corresponding Adam state, resulting in a new nested dictionary that mirrors the structure of ``train_state.params``. Each leaf in this new structure contains the optimizer state for the corresponding parameter. + +The Train Step (Functional Transformations) +------------------------------------------- + +.. tagged-block:: the-training-cookbook.py train-step + +The train step is where we calculate the gradient of the model with respect to the current parameters and use the gradient, together with the optimizer, to update the parameters. To do this in JAX, we define the forward pass of the model, then we leverage JAX's functional transformations to automatically generate the backward pass, which we use to calculate the gradients and perform the update. + +Model Forward Pass +~~~~~~~~~~~~~~~~~~ + +.. tagged-block:: the-training-cookbook.py model-apply + +The model's forward pass is mostly unremarkable, aside from the ``out_sharding`` annotations we have supplied. These annotations declare what the result-sharding should be after the operation executes. The compiler uses these activation shardings, together with the parameter shardings we supplied when we `initialized the model <#parameter-initialization>`_, to dynamically insert `communication collectives `_ that ferry parameters and activations alike between devices. By choosing a good sharding strategy, we can achieve highly performant training (and inference) code. We will cover some standard strategies that serve most use cases in the section titled :ref:`achieving-high-performance`. For a detailed discussion of the principles underpinning the design of sharding strategies, see `The Scaling Cookbook `_. + +Gradient and Optimizer Update +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. tagged-block:: the-training-cookbook.py train-step + :hl_lines: 3-6 + +In order to calculate the gradient, we define the training loss. This is a function of the parameters that returns a scalar which summarizes how well our model, with the current ``train_state`` parameters, is explaining the data. + +.. tagged-block:: the-training-cookbook.py train-step 8 + +By supplying this function to :func:`jax.value_and_grad`, we transform it into a function that returns both the scalar value and the gradient of ``loss_fn`` evaluated at ``params`` (the *value* and *grad*). Since we have defined our parameters in terms of a series of nested dictionaries, the gradient will also be a series of nested dictionaries, mirroring the parameters. Recall that, unlike the parameters, the optimizer states contain some extra, deeper nested dictionaries corresponding to the optimizer state per parameter. Take a moment, before reading the explanation, to ponder what the semantics of the following function call might be: + +.. tagged-block:: the-training-cookbook.py train-step 9 + +Examining the call signature of the function ``adam_apply`` gives us a hint: + +.. tagged-block:: the-training-cookbook.py adam-apply + +Because ``train_state.params`` is the first argument, :func:`jax.tree.map` uses its tree structure to guide the mapping process. [#prefix_tree]_ This means that ``train_state.opt`` is traversed only as deep as the leaves of ``train_state.params``. The optimizer state for each parameter is therefore passed in as a complete subtree, which allows us to easily access all relevant states (like ``mu`` and ``nu``) for a given ``param`` inside ``adam_apply``. + +.. tip:: + + If we wished to use different optimization algorithms and states on different parameters in our model (or freeze some parameters), we could achieve this by modifying the body of ``adam_apply`` and replacing :func:`jax.tree.map` with :func:`jax.tree_util.tree_map_with_path`, which allows the operand function to customize its behavior depending on the parameter. + +The Training Loop +----------------- +.. tagged-block:: the-training-cookbook.py train-loop + :hl_lines: 11-13 + +During training, we have to orchestrate the flow of data between two key players: the host system and the accelerator. Ensuring smooth interplay between these systems is key to writing highly performant training code. The Python `GIL `_ would ordinarily pose a significant obstacle here, but to work around this, the paradigm of :ref:`Asynchronous Dispatch ` adopted by JAX makes this orchestration easy to accomplish. But, in order to leverage this paradigm, we need to be mindful of how our code will be executed when structuring our training step. + +Efficiency via Asynchronous Dispatch +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +One of the most important tasks performed by the host system is to fetch data and place it on the accelerators so that the accelerators are never waiting for data. The time when accelerators are waiting idle between train steps is referred to as the *step bubble*. We can leverage asynchronous dispatch to minimize the step bubble. Let's see how this works with our training loop, discarding, for the moment, the line concerning the ``record_writer``. + +.. tagged-block:: the-training-cookbook.py train-loop 5:7 + +When this code executes, Python will first query the range iterator, get ``step`` (with value ``0``), then call ``next(batch)``, which will take some time to retrieve the batch. Then, ``train_step`` gets called. So far, nothing out of the ordinary. + +What happens next is interesting. Because :func:`jax.jit`-decorated calls are non-blocking, the call to ``train_step`` returns to the Python interpreter immediately. While the computation is enqueued on the accelerator, no work is actually performed yet. The Python loop continues, advancing the step counter and calling ``next(batch)`` for the *next* iteration. Once the second call to ``train_step`` is made, its inputs are now the mutated reference to ``train_state`` from the previous JIT call and a fresh batch of data. The runtime is clever and sees that in order to execute the second call to ``train_step``, we first need to realize the ``train_state`` result of step ``0`` to perform the mutation. And so it fires off the computation for the first step, and, crucially, while this happens, ``train_step``, once again, returns immediately, and the loop skips over again. Python now runs ahead until it encounters the ``next(batch)`` function at step 3, which proceeds to execute in Python, loading data, *while* the first train step is executing (for real this time). And just like that, we can simultaneously load data and perform math on the accelerator, without any traditional multiprocessing. [#sleep]_ + +.. mermaid:: + + --- + displayMode: compact + --- + gantt + title Synchronous Dispatch: No Overlap + axisFormat % + + section Host + next(batch) :gb0, 0, 1000s + next(batch) :gb1, after ajc0, 1000s + next(batch) :gb2, after ajc1, 1000s + + section Accelerator + + train_step 0 :ajc0, after gb0, 2000s + train_step 1 :ajc1, after gb1, 2000s + + +.. mermaid:: + + --- + displayMode: compact + --- + gantt + title JAX Asynchronous Dispatch: Host-Device Overlap + axisFormat % + + section Host + %% Task: id, name, start, duration_or_end + next(batch) :gb0, 0, 1000s + next(batch) :gb1, after gb0, 1000s + next(batch) :gb2, after gb1, 1000s + next(batch) :gb3, after jc0, 1000s + next(batch) :gb4, after jc1, 1000s + + section Accelerator + %% Task: id, name, start, duration_or_end + train_step 0 :jc0, after gb1, 2000s + train_step 1 :jc1, after jc0, 2000s + train_step 2 :jc2, after jc1, 2000s + +Common Mistakes +~~~~~~~~~~~~~~~ +When writing asynchronous dispatch code in Python, there are two primary mistakes one should be wary of so as not to interrupt our careful orchestration of compute. + +Requesting device-to-host transfers +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Up until now, we have ignored what happens to the variable ``metrics``. Indeed, if this is left dangling, nothing will happen, and we will achieve good overlap just as advertised. However, more often than not, we would like to observe telemetry from our train step, such as the current loss, gradient statistics, and so on. Suppose we were to insert code such as: + +.. code-block:: python + + metrics = train_step(config, train_state, next(batch)) + print({"step": step} | metrics) + +Instead of the loop ticking over, ``print`` will incur a device-to-host transfer of whatever on-device arrays are in ``metrics``. This interrupts the Python interpreter, and the code is forced to execute synchronously, producing a step bubble. The solution is slightly counterintuitive: at each step, we gather the telemetry for the *previous* step. + +.. tagged-block:: the-training-cookbook.py record-writer + +and + +.. tagged-block:: the-training-cookbook.py train-loop 6:7 + +A small helper function like this is essential to achieve good overlap and make the most of the resources of our host system and our accelerator. Of course, the simple ``print`` statement here can be swapped out for any Python operation that requests data from the accelerator. + +Interrupting the accelerator +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +The other common way in which we can waste spectacular amounts of cloud compute money is by unintentionally enqueuing math operations on the accelerator outside of the train step. Suppose we are using a cosine learning rate schedule. + +.. code-block:: python + + def learning_rate(count, init_value: float = 1e-4, decay_steps: int = 10_000, alpha: float = 1e-6): + cosine_decay = 0.5 * (1 + jnp.cos(jnp.pi * jnp.minimum(count, decay_steps) / decay_steps)) + return init_value * (1 - alpha) * cosine_decay + +A common pattern is to want to visualize the schedule alongside the other metrics we're gathering. However, even if we use the clever ``record_writer`` class we defined earlier, the following code will create a bubble on the accelerator. + +.. code-block:: python + + metrics = train_step(config, train_state, next(batch)) + record_writer({"step": step, "learning_rate": learning_rate(step)} | metrics) + + +This is because we have used :mod:`jax.numpy` in our calculations. When :func:`jax.numpy.minimum` is called, the Python integer ``step`` is promoted to a :class:`jax.Array` and transferred to the accelerator (a host-to-device transfer). The calculation is now enqueued on the accelerator, outside our main ``train_step``. To ``print`` the result, the value must be transferred back to the host (a device-to-host transfer). This round-trip forces the accelerator to synchronize with the host, and we have thrown away money by creating a performance bubble. The two ways to avoid this are to use NumPy for these calculations or to use the :func:`jax.default_device` context manager. + +.. code-block:: python + + metrics = train_step(config, train_state, next(batch)) + with jax.default_device('cpu'): + record_writer({"step": step, "learning_rate": learning_rate(step)} | metrics) + + +Data Loading +~~~~~~~~~~~~ +In addition to overlapping the actual loading of the data (that is, retrieving it from network storage to the host), JAX also allows us to overlap the host-to-device transfer of the data itself with the computation of the train step. The special function :func:`jax.device_put` is carefully designed to be non-blocking, executing asynchronously, which makes it perfectly fine to use in the context of our train step. However, there is a more convenient function specifically designed for the task of loading data. In the following code, ``dataset`` is an ordinary Python iterator that yields a ``dict`` of batched data. By mapping over this iterator with :func:`jax.make_array_from_process_local_data`, we generate a new iterator. Yielding from this new iterator will generate data placed on the device, ready for consumption by our train step. Internally, it will :func:`jax.tree.map` to create :class:`jax.Array` objects and queue them to be transferred to the device. Provided the data can be batched fast enough, on both TPUs and GPUs, these transfers will be overlapped with the train step computation. + +.. tagged-block:: the-training-cookbook.py get-dataset-on-device + + +.. _achieving-high-performance: + +Achieving High Performance +-------------------------- + +In this section, we will describe the three primary forms of model parallelism that are useful for training. During training, *throughput* is of paramount importance; that is, we wish to maximize the average number of operations per second. This contrasts with inference, where the goal is to minimize *latency* by ensuring all the operations happen in as little time as possible. Keeping throughput in mind as our ultimate goal for training, this section introduces the three primary strategies for sharding during training. For each strategy, we outline the JAX shardings that implement it and describe the collectives involved so that when studying program traces, you'll have landmarks to look for to confirm that the program is behaving as expected. The sharding variables we define in the code blocks below correspond to their uses in the `initialization <#train-state-initialization>`_ and `model forward pass <#model-forward-pass>`_. But in the companion script these and other aspects of the training code are set conveniently using the global `Config` class. + +.. tagged-block:: the-training-cookbook.py config + + +Data Parallel +~~~~~~~~~~~~~ +Data parallel is the most common and easy-to-understand form of parallelism. In this scheme, each accelerator stores a complete copy of the model parameters, and we shard activations along the batch axis to split the computation of the gradients. To compute the gradients, each accelerator performs an individual forward and backward pass. Then, before the parameters are updated, XLA inserts an ``AllReduce`` to share the updates and keep the models in sync. + +*Mesh:* + +.. code-block:: python + + mesh = jax.sharding.Mesh(jax.devices(), ('devices',)) + +*Parameter Shardings:* + +.. code-block:: python + + pos_embed = jax.P(None, None) + att_qkv = jax.P(None, None, None, None) + att_out = jax.P(None, None, None) + mlp_in = jax.P(None, None) + mlp_out = jax.P(None, None) + in_kernel = jax.P(None, None) + in_bias = jax.P(None) + out_kernel = jax.P(None, None) + out_bias = jax.P(None) + +*Activation Shardings:* + +.. code-block:: python + + act_ids = jax.P("devices") + act_seq = jax.P("devices", None, None) + act_att = jax.P("devices", None, None, None) + act_hidden = jax.P("devices", None, None) + + +.. _fsdp-sharding: + +Fully-Sharded Data Parallel (FSDP) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +The drawback of data-parallel sharding is that we have to keep multiple, full, redundant copies of the model parameters in HBM. This is a very performant strategy for small models, but since HBM is in short supply, we need to shard the model parameters as well. In the *Fully-Sharded Data Parallel (FSDP)* strategy, we shard both the model and the parameters. Now, as the forward pass happens, the parameters are, one-by-one, unsharded (via ``AllGather``) into whole arrays before they are applied to the activations. This unsharding is brief and temporary, however, leading to a large saving in HBM. In the backward pass, each ``AllGather`` becomes a ``ReduceScatter``. Then there is a final ``ReduceScatter`` at the optimizer update to synchronize gradients. Compared with Data parallelism, the total communication traffic is 50% highter, but we our HBM pressure is reduced by the size of the model divided by the number of devices. + +*Mesh:* + +.. code-block:: python + + mesh = jax.make_mesh((128*4,), ("fsdp",)) + +*Parameter Shardings:* + +.. code-block:: python + + pos_embed = jax.P(None, None) + att_qkv = jax.P(None, "fsdp", None, None) + att_out = jax.P("fsdp", None, None) + mlp_in = jax.P("fsdp", None) + mlp_out = jax.P(None, "fsdp") + in_kernel = jax.P(None, None) + in_bias = jax.P(None) + out_kernel = jax.P("fsdp", None) + out_bias = jax.P(None) + +*Activation Shardings:* + +.. code-block:: python + + act_ids = jax.P("fsdp") + act_seq = jax.P("fsdp", None, None) + act_att = jax.P("fsdp", None, None, None) + act_hidden = jax.P("fsdp", None, None) + + +.. note:: + + While FSDP entails a great deal more communication than data parallel, in practice we are able to overlap the communication with the compute, thereby hiding it and achieving the same throughput at a drastically improved HBM budget. + +Tensor Parallel +~~~~~~~~~~~~~~~ +If our model is large enough and structured appropriately, it becomes beneficial to partition the computation within a single example across our accelerators. Using a matrix multiplication as an example, we can spread the large matrix multiplications over two or four accelerators. This entails significantly more communication, and so this strategy only works for computations with a very high arithmetic intensity, such as extremely large matrix multiplications. With multi-head self-attention, we opt to shard along the heads with a replicated sequence axis, since this offers the most natural amount of parallelism. If the MLP is large enough we can also efficiently shard the matrix multiplications. + +*Mesh:* + +.. code-block:: python + + mesh = jax.make_mesh((128,4), ("fsdp", "tensor")) + +*Parameter Shardings:* + +.. code-block:: python + + pos_embed = jax.P(None, "tensor") + att_qkv = jax.P(None, "fsdp", "tensor", None) + att_out = jax.P("fsdp", None, None) + mlp_in = jax.P("fsdp", "tensor") + mlp_out = jax.P("tensor", "fsdp") + in_kernel = jax.P(None, None) + in_bias = jax.P(None) + out_kernel = jax.P("fsdp", None) + out_bias = jax.P(None) + +*Activation Shardings:* + +.. code-block:: python + + act_ids = jax.P("fsdp") + act_seq = jax.P("fsdp", None, None) + act_att = jax.P("fsdp", None, "tensor", None) + act_hidden = jax.P("fsdp", None, "tensor") + +.. [#ordered] Of course, all dictionaries are order-preserving in modern Python, so this is somewhat redundant. +.. [#zeros_like] This is accomplished by using the ``zeros_like`` constructor, but we could have specified the sharding manually using the ``devices`` argument of many of the :mod:`jax.numpy` functions. +.. [#prefix_tree] We could have achieved the same behavior equivalently by ordering ``grad`` first. +.. [#sleep] For the purposes of this explanation, you can think of ``next(batch)`` as just a sleep. diff --git a/docs/tracing.md b/docs/tracing.md new file mode 100644 index 000000000000..5167c8822a93 --- /dev/null +++ b/docs/tracing.md @@ -0,0 +1,236 @@ +--- +jupytext: + formats: md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.16.4 +kernelspec: + display_name: Python 3 + language: python + name: python3 +--- + +(tracing-tutorial)= +# Tracing + +`jax.jit` and other JAX transforms work by *tracing* a function to determine its effect on inputs of a specific shape and type. For a window into tracing, let's put a few `print()` statements within a JIT-compiled function and then call the function: + +```{code-cell} +from jax import jit +import jax.numpy as jnp +import numpy as np + +@jit +def f(x, y): + print("Running f():") + print(f" x = {x}") + print(f" y = {y}") + result = jnp.dot(x + 1, y + 1) + print(f" result = {result}") + return result + +x = np.random.randn(3, 4) +y = np.random.randn(4) +f(x, y) +``` + +Notice that the print statements execute, but rather than printing the data we +passed to the function, though, it prints *tracer* objects that stand-in for +them (something like `Traced`). + +These tracer objects are what `jax.jit` uses to extract the sequence of +operations specified by the function. Basic tracers are stand-ins that encode +the **shape** and **dtype** of the arrays, but are agnostic to the values. This +recorded sequence of computations can then be efficiently applied within XLA to +new inputs with the same shape and dtype, without having to re-execute the +Python code. + +When we call the compiled function again on matching inputs, no re-compilation +is required and nothing is printed because the result is computed in compiled +XLA rather than in Python: + +```{code-cell} +x2 = np.random.randn(3, 4) +y2 = np.random.randn(4) +f(x2, y2) +``` + +The extracted sequence of operations is encoded in a JAX expression, or +[*jaxpr*](key-concepts-jaxprs) for short. You can view the jaxpr using the +`jax.make_jaxpr` transformation: + +```{code-cell} +from jax import make_jaxpr + +def f(x, y): + return jnp.dot(x + 1, y + 1) + +make_jaxpr(f)(x, y) +``` + +Note one consequence of this: because JIT compilation is done *without* +information on the content of the array, control flow statements in the function +cannot depend on traced values (see {ref}`control-flow`). For example, this fails: + +```{code-cell} +:tags: [raises-exception] + +@jit +def f(x, neg): + return -x if neg else x + +f(1, True) +``` + +If there are variables that you would not like to be traced, they can be marked +as *static* for the purposes of JIT compilation: + +```{code-cell} +from functools import partial + +@partial(jit, static_argnums=(1,)) +def f(x, neg): + return -x if neg else x + +f(1, True) +``` + +Note that calling a JIT-compiled function with a different static argument +results in re-compilation, so the function still works as expected: + +```{code-cell} +f(1, False) +``` + +### Static vs traced operations + +Just as values can be either static or traced, operations can be static or +traced. Static operations are evaluated at compile-time in Python; traced +operations are compiled & evaluated at run-time in XLA. + +This distinction between static and traced values makes it important to think +about how to keep a static value static. Consider this function: + +```{code-cell} +:tags: [raises-exception] + +import jax.numpy as jnp +from jax import jit + +@jit +def f(x): + return x.reshape(jnp.array(x.shape).prod()) + +x = jnp.ones((2, 3)) +f(x) +``` + +This fails with an error specifying that a tracer was found instead of a 1D +sequence of concrete values of integer type. Let's add some print statements to +the function to understand why this is happening: + +```{code-cell} +@jit +def f(x): + print(f"x = {x}") + print(f"x.shape = {x.shape}") + print(f"jnp.array(x.shape).prod() = {jnp.array(x.shape).prod()}") + # comment this out to avoid the error: + # return x.reshape(jnp.array(x.shape).prod()) + +f(x) +``` + +Notice that although `x` is traced, `x.shape` is a static value. However, when +we use `jnp.array` and `jnp.prod` on this static value, it becomes a traced +value, at which point it cannot be used in a function like `reshape()` that +requires a static input (recall: array shapes must be static). + +A useful pattern is to use `numpy` for operations that should be static (i.e. +done at compile-time), and use `jax.numpy` for operations that should be traced +(i.e. compiled and executed at run-time). For this function, it might look like +this: + +```{code-cell} +from jax import jit +import jax.numpy as jnp +import numpy as np + +@jit +def f(x): + return x.reshape((np.prod(x.shape),)) + +f(x) +``` + +For this reason, a standard convention in JAX programs is to +`import numpy as np` and `import jax.numpy as jnp` so that both interfaces are +available for finer control over whether operations are performed in a static +manner (with `numpy`, once at compile-time) or a traced manner (with +`jax.numpy`, optimized at run-time). + +Understanding which values and operations will be static and which will be +traced is a key part of using `jax.jit` effectively. + +(faq-different-kinds-of-jax-values)= +## Different kinds of JAX values + +A tracer value carries an **abstract** value, e.g., `ShapedArray` with +information about the shape and dtype of an array. We will refer here to such +tracers as **abstract tracers**. Some tracers, e.g., those that are introduced +for arguments of autodiff transformations, carry `ConcreteArray` abstract values +that actually include the regular array data, and are used, e.g., for resolving +conditionals. We will refer here to such tracers as **concrete tracers**. Tracer +values computed from these concrete tracers, perhaps in combination with regular +values, result in concrete tracers. A **concrete value** is either a regular +value or a concrete tracer. + +Typically, computations that involve at least a tracer value will produce a +tracer value. There are very few exceptions, when a computation can be +entirely done using the abstract value carried by a tracer, in which case the +result can be a **regular** Python value. For example, getting the shape of a +tracer with `ShapedArray` abstract value. Another example is when explicitly +casting a concrete tracer value to a regular type, e.g., `int(x)` or +`x.astype(float)`. Another such situation is for `bool(x)`, which produces a +Python bool when concreteness makes it possible. That case is especially salient +because of how often it arises in control flow. + +Here is how the transformations introduce abstract or concrete tracers: + +* {func}`jax.jit`: introduces **abstract tracers** for all positional arguments + except those denoted by `static_argnums`, which remain regular + values. +* {func}`jax.pmap`: introduces **abstract tracers** for all positional arguments + except those denoted by `static_broadcasted_argnums`. +* {func}`jax.vmap`, {func}`jax.make_jaxpr`, {func}`xla_computation`: + introduce **abstract tracers** for all positional arguments. +* {func}`jax.jvp` and {func}`jax.grad` introduce **concrete tracers** + for all positional arguments. An exception is when these transformations + are within an outer transformation and the actual arguments are + themselves abstract tracers; in that case, the tracers introduced + by the autodiff transformations are also abstract tracers. +* All higher-order control-flow primitives ({func}`lax.cond`, + {func}`lax.while_loop`, {func}`lax.fori_loop`, {func}`lax.scan`) when they + process the functionals introduce **abstract tracers**, whether or not there + is a JAX transformation in progress. + +All of this is relevant when you have code that can operate +only on regular Python values, such as code that has conditional +control-flow based on data: + +```{code-cell} +def divide(x, y): + return x / y if y >= 1. else 0. +``` + +If we want to apply {func}`jax.jit`, we must ensure to specify `static_argnums=1` +to ensure `y` stays a regular value. This is due to the boolean expression +`y >= 1.`, which requires concrete values (regular or tracers). The +same would happen if we write explicitly `bool(y >= 1.)`, or `int(y)`, +or `float(y)`. + +Interestingly, `jax.grad(divide)(3., 2.)`, works because {func}`jax.grad` +uses concrete tracers, and resolves the conditional using the concrete +value of `y`. diff --git a/docs/tutorials.rst b/docs/tutorials.rst deleted file mode 100644 index c9c2fdb1dcc7..000000000000 --- a/docs/tutorials.rst +++ /dev/null @@ -1,29 +0,0 @@ -.. _jax-tutorials: - -Tutorials -========= - -.. toctree:: - :maxdepth: 1 - - quickstart - key-concepts - jit-compilation - automatic-vectorization - automatic-differentiation - debugging - random-numbers - working-with-pytrees - sharded-computation - stateful-computations - control-flow - -.. toctree:: - :maxdepth: 1 - :caption: Advanced tutorials - - advanced-autodiff - external-callbacks - gradient-checkpointing - jax-primitives - jaxpr diff --git a/docs/type_promotion.rst b/docs/type_promotion.rst index d3724745fe08..60695f6eeafb 100644 --- a/docs/type_promotion.rst +++ b/docs/type_promotion.rst @@ -4,7 +4,7 @@ Type promotion semantics ======================== This document describes JAX's type promotion rules–i.e., the result of :func:`jax.numpy.promote_types` for each pair of types. -For some background on the considerations that went into the design of what is described below, see `Design of Type Promotion Semantics for JAX `_. +For some background on the considerations that went into the design of what is described below, see `Design of Type Promotion Semantics for JAX `_. JAX's type promotion behavior is determined via the following type promotion lattice: @@ -119,7 +119,7 @@ on this lattice, which generates the following binary promotion table: for t1 in types: out += "{}".format(name(t1)) for t2 in types: - t, weak_type = dtypes._lattice_result_type(t1, t2) + t, weak_type = dtypes.lattice_result_type(t1, t2) if weak_type: t = type(t.type(0).item()) different = jnp.bfloat16 in (t1, t2) or jnp.promote_types(t1, t2) is not np.promote_types(t1, t2) diff --git a/docs/user_guides.rst b/docs/user_guides.rst deleted file mode 100644 index 6481da7a31dd..000000000000 --- a/docs/user_guides.rst +++ /dev/null @@ -1,45 +0,0 @@ -.. _user-guides: - -User guides -=========== - -User guides are deeper dives into particular topics within JAX -that become relevant as your JAX project matures into larger -or deployed codebases. - -.. toctree:: - :maxdepth: 1 - :caption: Debugging and performance - - notebooks/thinking_in_jax - profiling - device_memory_profiling - debugging/index - gpu_performance_tips - persistent_compilation_cache - -.. toctree:: - :maxdepth: 1 - :caption: Interfaces - - pytrees - errors - aot - export/index - type_promotion - transfer_guard - -.. toctree:: - :maxdepth: 1 - :caption: Custom operations - - pallas/index - ffi - -.. toctree:: - :caption: Example applications - :maxdepth: 1 - - notebooks/neural_network_with_tfds_data - notebooks/Neural_Network_and_Data_Loading - notebooks/vmapped_log_probs diff --git a/docs/working-with-pytrees.md b/docs/working-with-pytrees.md deleted file mode 100644 index ffa47eba07c0..000000000000 --- a/docs/working-with-pytrees.md +++ /dev/null @@ -1,517 +0,0 @@ ---- -jupytext: - formats: md:myst - text_representation: - extension: .md - format_name: myst - format_version: 0.13 - jupytext_version: 1.16.4 -kernelspec: - display_name: Python 3 - language: python - name: python3 ---- - -```{code-cell} -:tags: [remove-cell] - -# This ensures that code cell tracebacks appearing below will be concise. -%xmode minimal -``` - -(working-with-pytrees)= -# Working with pytrees - - - -JAX has built-in support for objects that look like dictionaries (dicts) of arrays, or lists of lists of dicts, or other nested structures — in JAX these are called pytrees. -This section will explain how to use them, provide useful code examples, and point out common "gotchas" and patterns. - - -(pytrees-what-is-a-pytree)= -## What is a pytree? - -A pytree is a container-like structure built out of container-like Python objects — “leaf” pytrees and/or more pytrees. A pytree can include lists, tuples, and dicts. A leaf is anything that’s not a pytree, such as an array, but a single leaf is also a pytree. - -In the context of machine learning (ML), a pytree can contain: - -- Model parameters -- Dataset entries -- Reinforcement learning agent observations - -When working with datasets, you can often come across pytrees (such as lists of lists of dicts). - -Below is an example of a simple pytree. In JAX, you can use {func}`jax.tree.leaves`, to extract the flattened leaves from the trees, as demonstrated here: - -```{code-cell} -import jax -import jax.numpy as jnp - -example_trees = [ - [1, 'a', object()], - (1, (2, 3), ()), - [1, {'k1': 2, 'k2': (3, 4)}, 5], - {'a': 2, 'b': (2, 3)}, - jnp.array([1, 2, 3]), -] - -# Print how many leaves the pytrees have. -for pytree in example_trees: - # This `jax.tree.leaves()` method extracts the flattened leaves from the pytrees. - leaves = jax.tree.leaves(pytree) - print(f"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}") -``` - -Any tree-like structure built out of container-like Python objects can be treated as a pytree in JAX. -Classes are considered container-like if they are in the pytree registry, which by default includes lists, tuples, and dicts. Any object whose type is *not* in the pytree container registry will be treated as a leaf node in the tree. - -The pytree registry can be extended to include user-defined container classes by registering the class -with functions that specify how to flatten the tree; see {ref}`pytrees-custom-pytree-nodes` below. - -(pytrees-common-pytree-functions)= -## Common pytree functions - -JAX provides a number of utilities to operate over pytrees. These can be found in the {mod}`jax.tree_util` subpackage; -for convenience many of these have aliases in the {mod}`jax.tree` module. - -### Common function: `jax.tree.map` - -The most commonly used pytree function is {func}`jax.tree.map`. It works analogously to Python's native `map`, but transparently operates over entire pytrees. - -Here's an example: - -```{code-cell} -list_of_lists = [ - [1, 2, 3], - [1, 2], - [1, 2, 3, 4] -] - -jax.tree.map(lambda x: x*2, list_of_lists) -``` - -{func}`jax.tree.map` also allows mapping a [N-ary](https://en.wikipedia.org/wiki/N-ary) function over multiple arguments. For example: - -```{code-cell} -another_list_of_lists = list_of_lists -jax.tree.map(lambda x, y: x+y, list_of_lists, another_list_of_lists) -``` - -When using multiple arguments with {func}`jax.tree.map`, the structure of the inputs must exactly match. That is, lists must have the same number of elements, dicts must have the same keys, etc. - -(pytrees-example-jax-tree-map-ml)= -### Example of `jax.tree.map` with ML model parameters - -This example demonstrates how pytree operations can be useful when training a simple [multi-layer perceptron (MLP)](https://en.wikipedia.org/wiki/Multilayer_perceptron). - -Begin with defining the initial model parameters: - -```{code-cell} -import numpy as np - -def init_mlp_params(layer_widths): - params = [] - for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]): - params.append( - dict(weights=np.random.normal(size=(n_in, n_out)) * np.sqrt(2/n_in), - biases=np.ones(shape=(n_out,)) - ) - ) - return params - -params = init_mlp_params([1, 128, 128, 1]) -``` - -Use {func}`jax.tree.map` to check the shapes of the initial parameters: - -```{code-cell} -jax.tree.map(lambda x: x.shape, params) -``` - -Next, define the functions for training the MLP model: - -```{code-cell} -# Define the forward pass. -def forward(params, x): - *hidden, last = params - for layer in hidden: - x = jax.nn.relu(x @ layer['weights'] + layer['biases']) - return x @ last['weights'] + last['biases'] - -# Define the loss function. -def loss_fn(params, x, y): - return jnp.mean((forward(params, x) - y) ** 2) - -# Set the learning rate. -LEARNING_RATE = 0.0001 - -# Using the stochastic gradient descent, define the parameter update function. -# Apply `@jax.jit` for JIT compilation (speed). -@jax.jit -def update(params, x, y): - # Calculate the gradients with `jax.grad`. - grads = jax.grad(loss_fn)(params, x, y) - # Note that `grads` is a pytree with the same structure as `params`. - # `jax.grad` is one of many JAX functions that has - # built-in support for pytrees. - # This is useful - you can apply the SGD update using JAX pytree utilities. - return jax.tree.map( - lambda p, g: p - LEARNING_RATE * g, params, grads - ) -``` - -(pytrees-custom-pytree-nodes)= -## Custom pytree nodes - -This section explains how in JAX you can extend the set of Python types that will be considered _internal nodes_ in pytrees (pytree nodes) by using {func}`jax.tree_util.register_pytree_node` with {func}`jax.tree.map`. - -Why would you need this? In the previous examples, pytrees were shown as lists, tuples, and dicts, with everything else as pytree leaves. This is because if you define your own container class, it will be considered to be a pytree leaf unless you _register_ it with JAX. This is also the case even if your container class has trees inside it. For example: - -```{code-cell} -class Special(object): - def __init__(self, x, y): - self.x = x - self.y = y - -jax.tree.leaves([ - Special(0, 1), - Special(2, 4), -]) -``` - -Accordingly, if you try to use a {func}`jax.tree.map` expecting the leaves to be elements inside the container, you will get an error: - -```{code-cell} -:tags: [raises-exception] - -jax.tree.map(lambda x: x + 1, - [ - Special(0, 1), - Special(2, 4) - ]) -``` - -As a solution, JAX allows to extend the set of types to be considered internal pytree nodes through a global registry of types. Additionally, the values of registered types are traversed recursively. - -First, register a new type using {func}`jax.tree_util.register_pytree_node`: - -```{code-cell} -from jax.tree_util import register_pytree_node - -class RegisteredSpecial(Special): - def __repr__(self): - return "RegisteredSpecial(x={}, y={})".format(self.x, self.y) - -def special_flatten(v): - """Specifies a flattening recipe. - - Params: - v: The value of the registered type to flatten. - Returns: - A pair of an iterable with the children to be flattened recursively, - and some opaque auxiliary data to pass back to the unflattening recipe. - The auxiliary data is stored in the treedef for use during unflattening. - The auxiliary data could be used, for example, for dictionary keys. - """ - children = (v.x, v.y) - aux_data = None - return (children, aux_data) - -def special_unflatten(aux_data, children): - """Specifies an unflattening recipe. - - Params: - aux_data: The opaque data that was specified during flattening of the - current tree definition. - children: The unflattened children - - Returns: - A reconstructed object of the registered type, using the specified - children and auxiliary data. - """ - return RegisteredSpecial(*children) - -# Global registration -register_pytree_node( - RegisteredSpecial, - special_flatten, # Instruct JAX what are the children nodes. - special_unflatten # Instruct JAX how to pack back into a `RegisteredSpecial`. -) -``` - -Now you can traverse the special container structure: - -```{code-cell} -jax.tree.map(lambda x: x + 1, - [ - RegisteredSpecial(0, 1), - RegisteredSpecial(2, 4), - ]) -``` - -Modern Python comes equipped with helpful tools to make defining containers easier. Some will work with JAX out-of-the-box, but others require more care. - -For instance, a Python `NamedTuple` subclass doesn't need to be registered to be considered a pytree node type: - -```{code-cell} -from typing import NamedTuple, Any - -class MyOtherContainer(NamedTuple): - name: str - a: Any - b: Any - c: Any - -# NamedTuple subclasses are handled as pytree nodes, so -# this will work out-of-the-box. -jax.tree.leaves([ - MyOtherContainer('Alice', 1, 2, 3), - MyOtherContainer('Bob', 4, 5, 6) -]) -``` - -Notice that the `name` field now appears as a leaf, because all tuple elements are children. This is what happens when you don't have to register the class the hard way. - -Unlike `NamedTuple` subclasses, classes decorated with `@dataclass` are not automatically pytrees. However, they can be registered as pytrees using the {func}`jax.tree_util.register_dataclass` decorator: - -```{code-cell} -from dataclasses import dataclass -import functools - -@functools.partial(jax.tree_util.register_dataclass, - data_fields=['a', 'b', 'c'], - meta_fields=['name']) -@dataclass -class MyDataclassContainer(object): - name: str - a: Any - b: Any - c: Any - -# MyDataclassContainer is now a pytree node. -jax.tree.leaves([ - MyDataclassContainer('apple', 5.3, 1.2, jnp.zeros([4])), - MyDataclassContainer('banana', np.array([3, 4]), -1., 0.) -]) -``` - -Notice that the `name` field does not appear as a leaf. This is because we included it in the `meta_fields` argument to {func}`jax.tree_util.register_dataclass`, indicating that it should be treated as metadata/auxiliary data, just like `aux_data` in `RegisteredSpecial` above. Now instances of `MyDataclassContainer` can be passed into JIT-ed functions, and `name` will be treated as static (see {ref}`jit-marking-arguments-as-static` for more information on static args): - -```{code-cell} -@jax.jit -def f(x: MyDataclassContainer | MyOtherContainer): - return x.a + x.b - -# Works fine! `mdc.name` is static. -mdc = MyDataclassContainer('mdc', 1, 2, 3) -y = f(mdc) -``` - -Contrast this with `MyOtherContainer`, the `NamedTuple` subclass. Since the `name` field is a pytree leaf, JIT expects it to be convertible to {class}`jax.Array`, and the following raises an error: - -```{code-cell} -:tags: [raises-exception] - -moc = MyOtherContainer('moc', 1, 2, 3) -y = f(moc) -``` - -(pytree-and-jax-transformations)= -## Pytrees and JAX transformations - -Many JAX functions, like {func}`jax.lax.scan`, operate over pytrees of arrays. In addition, all JAX function transformations can be applied to functions that accept as input and produce as output pytrees of arrays. - -Some JAX function transformations take optional parameters that specify how certain input or output values should be treated (such as the `in_axes` and `out_axes` arguments to {func}`jax.vmap`). These parameters can also be pytrees, and their structure must correspond to the pytree structure of the corresponding arguments. In particular, to be able to “match up” leaves in these parameter pytrees with values in the argument pytrees, the parameter pytrees are often constrained to be tree prefixes of the argument pytrees. - -For example, if you pass the following input to {func}`jax.vmap` (note that the input arguments to a function are considered a tuple): - -```python -vmap(f, in_axes=(a1, {"k1": a2, "k2": a3})) -``` - -then you can use the following `in_axes` pytree to specify that only the `k2` argument is mapped (`axis=0`), and the rest aren’t mapped over (`axis=None`): - -```python -vmap(f, in_axes=(None, {"k1": None, "k2": 0})) -``` - -The optional parameter pytree structure must match that of the main input pytree. However, the optional parameters can optionally be specified as a “prefix” pytree, meaning that a single leaf value can be applied to an entire sub-pytree. - -For example, if you have the same {func}`jax.vmap` input as above, but wish to only map over the dictionary argument, you can use: - -```python -vmap(f, in_axes=(None, 0)) # equivalent to (None, {"k1": 0, "k2": 0}) -``` - -Alternatively, if you want every argument to be mapped, you can write a single leaf value that is applied over the entire argument tuple pytree: - -```python -vmap(f, in_axes=0) # equivalent to (0, {"k1": 0, "k2": 0}) -``` - -This happens to be the default `in_axes` value for {func}`jax.vmap`. - -The same logic applies to other optional parameters that refer to specific input or output values of a transformed function, such as `out_axes` in {func}`jax.vmap`. - -(pytrees-explicity-key-paths)= -## Explicit key paths - -In a pytree each leaf has a _key path_. A key path for a leaf is a `list` of _keys_, where the length of the list is equal to the depth of the leaf in the pytree . Each _key_ is a [hashable object](https://docs.python.org/3/glossary.html#term-hashable) that represents an index into the corresponding pytree node type. The type of the key depends on the pytree node type; for example, the type of keys for `dict`s is different from the type of keys for `tuple`s. - -For built-in pytree node types, the set of keys for any pytree node instance is unique. For a pytree comprising nodes with this property, the key path for each leaf is unique. - -JAX has the following `jax.tree_util.*` methods for working with key paths: - -- {func}`jax.tree_util.tree_flatten_with_path`: Works similarly to {func}`jax.tree.flatten`, but returns key paths. -- {func}`jax.tree_util.tree_map_with_path`: Works similarly to {func}`jax.tree.map`, but the function also takes key paths as arguments. -- {func}`jax.tree_util.keystr`: Given a general key path, returns a reader-friendly string expression. - -For example, one use case is to print debugging information related to a certain leaf value: - -```{code-cell} -import collections - -ATuple = collections.namedtuple("ATuple", ('name')) - -tree = [1, {'k1': 2, 'k2': (3, 4)}, ATuple('foo')] -flattened, _ = jax.tree_util.tree_flatten_with_path(tree) - -for key_path, value in flattened: - print(f'Value of tree{jax.tree_util.keystr(key_path)}: {value}') -``` - -To express key paths, JAX provides a few default key types for the built-in pytree node types, namely: - -* `SequenceKey(idx: int)`: For lists and tuples. -* `DictKey(key: Hashable)`: For dictionaries. -* `GetAttrKey(name: str)`: For `namedtuple`s and preferably custom pytree nodes (more in the next section) - -You are free to define your own key types for your custom nodes. They will work with {func}`jax.tree_util.keystr` as long as their `__str__()` method is also overridden with a reader-friendly expression. - -```{code-cell} -for key_path, _ in flattened: - print(f'Key path of tree{jax.tree_util.keystr(key_path)}: {repr(key_path)}') -``` - -(pytrees-common-pytree-gotchas)= -## Common pytree gotchas - -This section covers some of the most common problems ("gotchas") encountered when using JAX pytrees. - -### Mistaking pytree nodes for leaves - -A common gotcha to look out for is accidentally introducing _tree nodes_ instead of _leaves_: - -```{code-cell} -a_tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))] - -# Try to make another pytree with ones instead of zeros. -shapes = jax.tree.map(lambda x: x.shape, a_tree) -jax.tree.map(jnp.ones, shapes) -``` - -What happened here is that the `shape` of an array is a tuple, which is a pytree node, with its elements as leaves. Thus, in the map, instead of calling `jnp.ones` on e.g. `(2, 3)`, it's called on `2` and `3`. - -The solution will depend on the specifics, but there are two broadly applicable options: - -- Rewrite the code to avoid the intermediate {func}`jax.tree.map`. -- Convert the tuple into a NumPy array (`np.array`) or a JAX NumPy array (`jnp.array`), which makes the entire sequence a leaf. - -### Handling of `None` by `jax.tree_util` - -`jax.tree_util` functions treat `None` as the absence of a pytree node, not as a leaf: - -```{code-cell} -jax.tree.leaves([None, None, None]) -``` - -To treat `None` as a leaf, you can use the `is_leaf` argument: - -```{code-cell} -jax.tree.leaves([None, None, None], is_leaf=lambda x: x is None) -``` - -### Custom pytrees and initialization with unexpected values - -Another common gotcha with user-defined pytree objects is that JAX transformations occasionally initialize them with unexpected values, so that any input validation done at initialization may fail. For example: - -```{code-cell} -:tags: [raises-exception] - -class MyTree: - def __init__(self, a): - self.a = jnp.asarray(a) - -register_pytree_node(MyTree, lambda tree: ((tree.a,), None), - lambda _, args: MyTree(*args)) - -tree = MyTree(jnp.arange(5.0)) - -jax.vmap(lambda x: x)(tree) # Error because object() is passed to `MyTree`. -``` - -```{code-cell} -:tags: [raises-exception] - -jax.jacobian(lambda x: x)(tree) # Error because MyTree(...) is passed to `MyTree`. -``` - -- In the first case with `jax.vmap(...)(tree)`, JAX’s internals use arrays of `object()` values to infer the structure of the tree -- In the second case with `jax.jacobian(...)(tree)`, the Jacobian of a function mapping a tree to a tree is defined as a tree of trees. - -**Potential solution 1:** - -- The `__init__` and `__new__` methods of custom pytree classes should generally avoid doing any array conversion or other input validation, or else anticipate and handle these special cases. For example: - -```{code-cell} -class MyTree: - def __init__(self, a): - if not (type(a) is object or a is None or isinstance(a, MyTree)): - a = jnp.asarray(a) - self.a = a -``` - -**Potential solution 2:** - -- Structure your custom `tree_unflatten` function so that it avoids calling `__init__`. If you choose this route, make sure that your `tree_unflatten` function stays in sync with `__init__` if and when the code is updated. Example: - -```{code-cell} -def tree_unflatten(aux_data, children): - del aux_data # Unused in this class. - obj = object.__new__(MyTree) - obj.a = a - return obj -``` - -(pytrees-common-pytree-patterns)= -## Common pytree patterns - -This section covers some of the most common patterns with JAX pytrees. - -### Transposing pytrees with `jax.tree.map` and `jax.tree.transpose` - -To transpose a pytree (turn a list of trees into a tree of lists), JAX has two functions: {func}`jax.tree.map` (more basic) and {func}`jax.tree.transpose` (more flexible, complex and verbose). - -**Option 1:** Use {func}`jax.tree.map`. Here's an example: - -```{code-cell} -def tree_transpose(list_of_trees): - """ - Converts a list of trees of identical structure into a single tree of lists. - """ - return jax.tree.map(lambda *xs: list(xs), *list_of_trees) - -# Convert a dataset from row-major to column-major. -episode_steps = [dict(t=1, obs=3), dict(t=2, obs=4)] -tree_transpose(episode_steps) -``` - -**Option 2:** For more complex transposes, use {func}`jax.tree.transpose`, which is more verbose, but allows you specify the structure of the inner and outer pytree for more flexibility. For example: - -```{code-cell} -jax.tree.transpose( - outer_treedef = jax.tree.structure([0 for e in episode_steps]), - inner_treedef = jax.tree.structure(episode_steps[0]), - pytree_to_transpose = episode_steps -) -``` diff --git a/docs/xla_flags.md b/docs/xla_flags.md index 1e374abea005..dec2d81e6cf6 100644 --- a/docs/xla_flags.md +++ b/docs/xla_flags.md @@ -1,10 +1,10 @@ -# List of XLA compiler flags +# XLA compiler flags ## Introduction This guide gives a brief overview of XLA and how XLA relates to Jax. -For in-depth details please refer to [XLA documentation](https://openxla.org/xla). Then it lists commonly-used XLA compiler flags designed to optimize performance of Jax programs. +For in-depth details please refer to [XLA documentation](https://openxla.org/xla). ## XLA: The Powerhouse Behind Jax XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear algebra that plays a pivotal role in Jax's performance and flexibility. It enables Jax to generate optimized code for various hardware backends (CPUs, GPUs, TPUs) by transforming and compiling your Python/NumPy-like code into efficient machine instructions. @@ -44,45 +44,8 @@ XLA_FLAGS='--flag1=value1 --flag2=value2' python3 source.py * Complete and up to date documentation about XLA can be found in the official [XLA documentation](https://openxla.org/xla). * For backends supported by open-source version of XLA (CPU, GPU), XLA flags are defined with their default values in [xla/debug_options_flags.cc](https://github.com/openxla/xla/blob/main/xla/debug_options_flags.cc), and a complete list of flags could be found [here](https://github.com/openxla/xla/blob/main/xla/xla.proto). -* TPU compiler flags are not part of [OpenXLA](https://github.com/openxla/xla), but commonly-used options are listed below. - -* Please note that this list of flags is not exhaustive and is subject to change. These flags are implementation details, and there is no guarantee that they will remain available or maintain their current behavior. -### Common XLA flags -| Flag | Type | Notes | -| ---- | ---- | ----- | -| `xla_dump_to` | String (filepath) | The folder where pre-optimization HLO files and other artifacts will be placed (see [XLA Tools](https://openxla.org/xla/tools)). | -| `xla_enable_async_collective_permute` | TristateFlag (true/false/auto) | Rewrites all collective-permute operations to their asynchronous variants. When set to `auto`, XLA can turn on async collective based on other configurations or conditions automatically. | -| `xla_enable_async_all_gather` | TristateFlag (true/false/auto) | If set to true, enables async all gather. If `auto`, enables only for platforms that implement async all-gather. The implementation (such as BC-offload or continuation fusion) is chosen based on other flag values. | -| `xla_disable_hlo_passes` | String (comma-separated list of pass names) | Comma-separated list of HLO passes to be disabled. These names must exactly match the pass name (no whitespace around commas). | - -### TPU XLA flags -| Flag | Type | Notes | -| ---- | ---- | ----- | -| `xla_tpu_enable_data_parallel_all_reduce_opt` | Boolean (true/false) | Optimization to increase overlap opportunities for DCN (data center networking) all-reduces used for data parallel sharding. | -| `xla_tpu_data_parallel_opt_different_sized_ops` | Boolean (true/false) | Enables pipelining of data parallel ops across multiple iterations even if their output sizes don't match what can be saved in place in the stacked variables. Can increase memory pressure. | -| `xla_tpu_enable_async_collective_fusion` | Boolean (true/false) | Enables the pass which fuses async collective communications with compute ops (output/loop-fusion or convolution) that are scheduled between their -start and -done instructions. | -| `xla_tpu_enable_async_collective_fusion_fuse_all_gather` | TristateFlag (true/false/auto) | Enables fusing all-gathers within the AsyncCollectiveFusion pass.
If set to `auto`, it will be enabled based on the target. | -| `xla_tpu_enable_async_collective_fusion_multiple_steps` | Boolean (true/false) | Enables continuing the same async collective in multiple steps (fusions) in the AsyncCollectiveFusion pass. | -| `xla_tpu_overlap_compute_collective_tc` | Boolean (true/false) | Enables the overlap of compute and communication on a single TensorCore, i.e., one core equivalent of MegaCore fusion. | -| `xla_tpu_spmd_rng_bit_generator_unsafe` | Boolean (true/false) | Whether to run RngBitGenerator HLO in a partitioned way, which is unsafe if deterministic results are expected with different shardings on different parts of the computation. | -| `xla_tpu_megacore_fusion_allow_ags` | Boolean (true/false) | Allows fusing all-gathers with convolutions/all-reduces. | -| `xla_tpu_enable_ag_backward_pipelining` | Boolean (true/false) | Pipelines all-gathers (currently megascale all-gathers) backwards through scan loops. | - -### GPU XLA flags -| Flag | Type | Notes | -| ---- | ---- | ----- | -| `xla_gpu_enable_latency_hiding_scheduler` | Boolean (true/false) |This flag enables latency hiding schedulers to overlap asynchronous communication with computation efficiently. The default value is False. | -| `xla_gpu_enable_triton_gemm` | Boolean (true/false) | Use Triton-based matrix multiplication. | -| `xla_gpu_graph_level` | Flag (0-3) | The legacy flag for setting GPU graph level. Use xla_gpu_enable_command_buffer in new use cases. 0 = off; 1 = capture fusions and memcpys; 2 = capture gemms; 3 = capture convolutions. | -| `xla_gpu_all_reduce_combine_threshold_bytes` | Integer (bytes) | These flags tune when to combine multiple small AllGather / ReduceScatter / AllReduce into one big AllGather / ReduceScatter / AllReduce to reduce time spent on cross-device communication. For example, for the AllGather / ReduceScatter thresholds on a Transformer-based workload, consider tuning them high enough so as to combine at least a Transformer Layer’s weight AllGather / ReduceScatter. By default, the combine_threshold_bytes is set to 256. | -| `xla_gpu_all_gather_combine_threshold_bytes` | Integer (bytes) | See xla_gpu_all_reduce_combine_threshold_bytes above. | -| `xla_gpu_reduce_scatter_combine_threshold_bytes` | Integer (bytes) | See xla_gpu_all_reduce_combine_threshold_bytes above. | -| `xla_gpu_enable_pipelined_all_gather` | Boolean (true/false) | Enable pipelinling of all-gather instructions. | -| `xla_gpu_enable_pipelined_reduce_scatter` | Boolean (true/false) | Enable pipelinling of reduce-scatter instructions. | -| `xla_gpu_enable_pipelined_all_reduce` | Boolean (true/false) | Enable pipelinling of all-reduce instructions. | -| `xla_gpu_enable_while_loop_double_buffering` | Boolean (true/false) | Enable double-buffering for while loop. | -| `xla_gpu_enable_all_gather_combine_by_dim` | Boolean (true/false) | Combine all-gather ops with the same gather dimension or irrespective of their dimension. | -| `xla_gpu_enable_reduce_scatter_combine_by_dim` | Boolean (true/false) | Combine reduce-scatter ops with the same dimension or irrespective of their dimension. | + +* A guide on how to use key XLA flags can be found [here](https://openxla.org/xla/flags_guidance). **Additional reading:** -* [GPU performance tips](https://jax.readthedocs.io/en/latest/gpu_performance_tips.html#xla-performance-flags) +* [GPU performance tips](https://docs.jax.dev/en/latest/gpu_performance_tips.html#xla-performance-flags) diff --git a/examples/ffi/CMakeLists.txt b/examples/ffi/CMakeLists.txt index ea7670b81ccc..4a93cc490d33 100644 --- a/examples/ffi/CMakeLists.txt +++ b/examples/ffi/CMakeLists.txt @@ -3,10 +3,10 @@ project(${SKBUILD_PROJECT_NAME} LANGUAGES CXX) option(JAX_FFI_EXAMPLE_ENABLE_CUDA "Enable CUDA support" OFF) -find_package(Python 3.10 REQUIRED COMPONENTS Interpreter Development.Module) +find_package(Python 3.11 REQUIRED COMPONENTS Interpreter Development.Module) execute_process( COMMAND "${Python_EXECUTABLE}" - "-c" "from jax.extend import ffi; print(ffi.include_dir())" + "-c" "from jax import ffi; print(ffi.include_dir())" OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE XLA_DIR) message(STATUS "XLA include directory: ${XLA_DIR}") diff --git a/examples/ffi/README.md b/examples/ffi/README.md index bd45408e50d8..c490f014859b 100644 --- a/examples/ffi/README.md +++ b/examples/ffi/README.md @@ -2,7 +2,7 @@ This directory includes an example project demonstrating the use of JAX's foreign function interface (FFI). The JAX docs provide more information about -this interface in [the FFI tutorial](https://jax.readthedocs.io/en/latest/ffi.html), +this interface in [the FFI tutorial](https://docs.jax.dev/en/latest/ffi.html), but the example in this directory complements that document by demonstrating (and testing!) the full packaging workflow, and some more advanced use cases. Within the example project, there are several example calls: diff --git a/examples/ffi/pyproject.toml b/examples/ffi/pyproject.toml index 130dd91bbc70..84e2c4700500 100644 --- a/examples/ffi/pyproject.toml +++ b/examples/ffi/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "scikit_build_core.build" [project] name = "jax_ffi_example" version = "0.0.1" -requires-python = ">=3.10" +requires-python = ">=3.11" dependencies = ["jax"] [project.optional-dependencies] diff --git a/examples/ffi/src/jax_ffi_example/gpu_examples.cc b/examples/ffi/src/jax_ffi_example/gpu_examples.cc index 921039debe5d..c2b7c95ce9e4 100644 --- a/examples/ffi/src/jax_ffi_example/gpu_examples.cc +++ b/examples/ffi/src/jax_ffi_example/gpu_examples.cc @@ -16,8 +16,8 @@ limitations under the License. #include #include -#include "nanobind/nanobind.h" #include "cuda_runtime_api.h" +#include "nanobind/nanobind.h" #include "xla/ffi/api/ffi.h" namespace nb = nanobind; @@ -53,6 +53,19 @@ XLA_FFI_DEFINE_HANDLER(kStateExecute, StateExecute, NB_MODULE(_gpu_examples, m) { m.def("type_id", []() { return nb::capsule(reinterpret_cast(&State::id)); }); + m.def("state_type", []() { + // In earlier versions of XLA:FFI, the `MakeTypeInfo` helper was not + // available. In latest XLF:FFI `TypeInfo` is an alias for C API struct. +#if XLA_FFI_API_MINOR >= 2 + static auto kStateTypeInfo = xla::ffi::MakeTypeInfo(); +#else + static auto kStateTypeInfo = xla::ffi::TypeInfo(); +#endif + nb::dict d; + d["type_id"] = nb::capsule(reinterpret_cast(&State::id)); + d["type_info"] = nb::capsule(reinterpret_cast(&kStateTypeInfo)); + return d; + }); m.def("handler", []() { nb::dict d; d["instantiate"] = nb::capsule(reinterpret_cast(kStateInstantiate)); diff --git a/examples/ffi/src/jax_ffi_example/gpu_examples.py b/examples/ffi/src/jax_ffi_example/gpu_examples.py index 8f775c265fd4..e733ea7e44bf 100644 --- a/examples/ffi/src/jax_ffi_example/gpu_examples.py +++ b/examples/ffi/src/jax_ffi_example/gpu_examples.py @@ -16,8 +16,10 @@ from jax_ffi_example import _gpu_examples import jax.numpy as jnp + +jax.ffi.register_ffi_type( + "state", _gpu_examples.state_type(), platform="CUDA") jax.ffi.register_ffi_target("state", _gpu_examples.handler(), platform="CUDA") -jax.ffi.register_ffi_type_id("state", _gpu_examples.type_id(), platform="CUDA") def read_state(): diff --git a/examples/ffi/src/jax_ffi_example/rms_norm.cc b/examples/ffi/src/jax_ffi_example/rms_norm.cc index 819f3b9f868d..bcfc1eb67aa4 100644 --- a/examples/ffi/src/jax_ffi_example/rms_norm.cc +++ b/examples/ffi/src/jax_ffi_example/rms_norm.cc @@ -16,8 +16,6 @@ limitations under the License. #include #include #include -#include -#include #include #include diff --git a/examples/ffi/src/jax_ffi_example/rms_norm.py b/examples/ffi/src/jax_ffi_example/rms_norm.py index 6dbfe5043ddf..996eb9e5d935 100644 --- a/examples/ffi/src/jax_ffi_example/rms_norm.py +++ b/examples/ffi/src/jax_ffi_example/rms_norm.py @@ -14,9 +14,9 @@ """An example demontrating the basic end-to-end use of the JAX FFI. This example is exactly the same as the one in the `FFI tutorial -`, so more details can be found +`, so more details can be found on that page. But, the high level summary is that we implement our custom -extension in ``rms_norm.cc``, then call it usin ``jax.ffi.ffi_call`` in +extension in ``rms_norm.cc``, then call it using ``jax.ffi.ffi_call`` in this module. The behavior under autodiff is implemented using ``jax.custom_vjp``. """ diff --git a/examples/jax_cpp/BUILD b/examples/jax_cpp/BUILD index b3cb995aae21..97f8a5804a79 100644 --- a/examples/jax_cpp/BUILD +++ b/examples/jax_cpp/BUILD @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -package(default_applicable_licenses = ["//jax:license"]) +load("@rules_cc//cc:cc_binary.bzl", "cc_binary") + +package(default_applicable_licenses = []) licenses(["notice"]) @@ -21,6 +23,7 @@ cc_binary( srcs = ["main.cc"], tags = ["manual"], deps = [ + "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:platform_port", @@ -33,6 +36,7 @@ cc_binary( "@xla//xla/pjrt/plugin/xla_cpu:cpu_client_options", "@xla//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client", "@xla//xla/service:hlo_module_config", + "@xla//xla/service:hlo_proto_cc", "@xla//xla/tools:hlo_module_loader", ], ) diff --git a/examples/jax_cpp/main.cc b/examples/jax_cpp/main.cc index 0a1d3a63acfd..b911711ad53f 100644 --- a/examples/jax_cpp/main.cc +++ b/examples/jax_cpp/main.cc @@ -18,7 +18,7 @@ limitations under the License. // // To build a HloModule, // -// $ python3 jax/tools/jax_to_hlo.py \ +// $ python3 jax/tools/jax_to_ir.py \ // --fn examples.jax_cpp.prog.fn \ // --input_shapes '[("x", "f32[2,2]"), ("y", "f32[2,2]")]' \ // --constants '{"z": 2.0}' \ @@ -41,7 +41,8 @@ limitations under the License. #include #include -#include "third_party/absl/status/statusor.h" +#include "absl/log/log.h" +#include "absl/status/statusor.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/literal.h" @@ -50,6 +51,7 @@ limitations under the License. #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/plugin/xla_cpu/cpu_client_options.h" #include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h" +#include "xla/service/hlo.pb.h" #include "xla/service/hlo_module_config.h" #include "xla/tools/hlo_module_loader.h" #include "tsl/platform/init_main.h" @@ -104,7 +106,7 @@ int main(int argc, char** argv) { // Get result. std::shared_ptr result_literal = - results[0][0]->ToLiteralSync().value(); + results[0][0]->ToLiteral().Await().value(); LOG(INFO) << "result = " << *result_literal; return 0; } diff --git a/examples/k8s/example.yaml b/examples/k8s/example.yaml new file mode 100644 index 000000000000..9039626e9c82 --- /dev/null +++ b/examples/k8s/example.yaml @@ -0,0 +1,39 @@ +apiVersion: jobset.x-k8s.io/v1alpha2 +kind: JobSet +metadata: + name: jaxjob +spec: + replicatedJobs: + - name: workers + template: + spec: + parallelism: 2 + completions: 2 + backoffLimit: 0 + template: + spec: + serviceAccountName: jax-job-sa # kubectl apply -f svc-acct.yaml + restartPolicy: Never + imagePullSecrets: + # https://k8s.io/docs/tasks/configure-pod-container/pull-image-private-registry/ + - name: null + containers: + - name: main + image: null # e.g. ghcr.io/nvidia/jax:jax + imagePullPolicy: Always + resources: + limits: + cpu: 900m + # https://k8s.io/docs/tasks/manage-gpus/scheduling-gpus/ + nvidia.com/gpu: null + command: + - python + args: + - -c + - | + import jax + jax.distributed.initialize() + print(jax.devices()) + print(jax.local_devices()) + assert jax.process_count() > 1 + assert len(jax.devices()) > len(jax.local_devices()) diff --git a/examples/k8s/svc-acct.yaml b/examples/k8s/svc-acct.yaml new file mode 100644 index 000000000000..c1523964c515 --- /dev/null +++ b/examples/k8s/svc-acct.yaml @@ -0,0 +1,31 @@ +apiVersion: v1 +kind: ServiceAccount +metadata: + name: jax-job-sa + namespace: default +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: Role +metadata: + name: pod-reader +rules: + - apiGroups: [""] + resources: ["pods", "services"] + verbs: ["get", "list", "watch"] + - apiGroups: ["batch"] + resources: ["jobs"] + verbs: ["get", "list", "watch"] +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: RoleBinding +metadata: + name: pod-reader-binding + namespace: default +subjects: + - kind: ServiceAccount + name: jax-job-sa + namespace: default +roleRef: + kind: Role + name: pod-reader + apiGroup: rbac.authorization.k8s.io diff --git a/examples/spmd_mnist_classifier_fromscratch.py b/examples/spmd_mnist_classifier_fromscratch.py index 3698314708c7..7f2f18de8b1b 100644 --- a/examples/spmd_mnist_classifier_fromscratch.py +++ b/examples/spmd_mnist_classifier_fromscratch.py @@ -12,33 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""An MNIST example with single-program multiple-data (SPMD) data parallelism. - -The aim here is to illustrate how to use JAX's `pmap` to express and execute -SPMD programs for data parallelism along a batch dimension, while also -minimizing dependencies by avoiding the use of higher-level layers and -optimizers libraries. -""" - - from functools import partial import time +from jax import NamedSharding import numpy as np import numpy.random as npr - import jax -from jax import jit, grad, pmap +from jax import jit, grad +from jax.sharding import PartitionSpec as P, AxisType, reshard from jax.scipy.special import logsumexp -from jax.tree_util import tree_map -from jax import lax import jax.numpy as jnp -from examples import datasets +import datasets def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)): - return [(scale * rng.randn(m, n), scale * rng.randn(n)) - for m, n, in zip(layer_sizes[:-1], layer_sizes[1:])] + return [ + (scale * rng.randn(m, n), scale * rng.randn(n)) + for m, n in zip(layer_sizes[:-1], layer_sizes[1:]) + ] + def predict(params, inputs): activations = inputs @@ -50,11 +43,21 @@ def predict(params, inputs): logits = jnp.dot(activations, final_w) + final_b return logits - logsumexp(logits, axis=1, keepdims=True) + def loss(params, batch): inputs, targets = batch preds = predict(params, inputs) return -jnp.mean(jnp.sum(preds * targets, axis=1)) + +@partial(jax.jit, donate_argnums=0) +def train_step(params, batch): + grads = grad(loss)(params, batch) + return [ + (w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads) + ] + + @jit def accuracy(params, batch): inputs, targets = batch @@ -72,57 +75,72 @@ def accuracy(params, batch): train_images, train_labels, test_images, test_labels = datasets.mnist() num_train = train_images.shape[0] + + num_devices = jax.device_count() + print(f"Using {num_devices} devices") + + if batch_size % num_devices != 0: + batch_size = (batch_size // num_devices) * num_devices + print(f"Adjusting batch size to {batch_size} for divisibility") + num_complete_batches, leftover = divmod(num_train, batch_size) num_batches = num_complete_batches + bool(leftover) - # For this manual SPMD example, we get the number of devices (e.g. GPUs or - # TPU cores) that we're using, and use it to reshape data minibatches. - num_devices = jax.device_count() + devices = np.array(jax.devices()) + mesh = jax.make_mesh( + (jax.device_count(),), ("batch",), axis_types=(AxisType.Explicit,) + ) + + replicated_sharding = NamedSharding(mesh, P()) + data_sharding = NamedSharding(mesh, P("batch")) + def data_stream(): rng = npr.RandomState(0) while True: perm = rng.permutation(num_train) for i in range(num_batches): - batch_idx = perm[i * batch_size:(i + 1) * batch_size] - images, labels = train_images[batch_idx], train_labels[batch_idx] - # For this SPMD example, we reshape the data batch dimension into two - # batch dimensions, one of which is mapped over parallel devices. - batch_size_per_device, ragged = divmod(images.shape[0], num_devices) - if ragged: - msg = "batch size must be divisible by device count, got {} and {}." - raise ValueError(msg.format(batch_size, num_devices)) - shape_prefix = (num_devices, batch_size_per_device) - images = images.reshape(shape_prefix + images.shape[1:]) - labels = labels.reshape(shape_prefix + labels.shape[1:]) + batch_idx = perm[i * batch_size : (i + 1) * batch_size] + images_np, labels_np = train_images[batch_idx], train_labels[batch_idx] + + current_batch_size = images_np.shape[0] + if current_batch_size < batch_size: + pad_len = batch_size - current_batch_size + images_np = np.concatenate([images_np, images_np[:pad_len]], axis=0) + labels_np = np.concatenate([labels_np, labels_np[:pad_len]], axis=0) + + images = jax.device_put(images_np, data_sharding) + labels = jax.device_put(labels_np, data_sharding) yield images, labels + batches = data_stream() - @partial(pmap, axis_name='batch') - def spmd_update(params, batch): - grads = grad(loss)(params, batch) - # We compute the total gradients, summing across the device-mapped axis, - # using the `lax.psum` SPMD primitive, which does a fast all-reduce-sum. - grads = [(lax.psum(dw, 'batch'), lax.psum(db, 'batch')) for dw, db in grads] - return [(w - step_size * dw, b - step_size * db) - for (w, b), (dw, db) in zip(params, grads)] - - # We replicate the parameters so that the constituent arrays have a leading - # dimension of size equal to the number of devices we're pmapping over. - init_params = init_random_params(param_scale, layer_sizes) - replicate_array = lambda x: np.broadcast_to(x, (num_devices,) + x.shape) - replicated_params = tree_map(replicate_array, init_params) + params = init_random_params(param_scale, layer_sizes) + replicated_params = jax.device_put(params, replicated_sharding) for epoch in range(num_epochs): start_time = time.time() - for _ in range(num_batches): - replicated_params = spmd_update(replicated_params, next(batches)) + for i in range(num_batches - 1): + print(f"Batch no {i+1} of {num_batches}") + batch = next(batches) + with jax.set_mesh(mesh): + replicated_params = train_step(replicated_params, batch) epoch_time = time.time() - start_time - # We evaluate using the jitted `accuracy` function (not using pmap) by - # grabbing just one of the replicated parameter values. - params = tree_map(lambda x: x[0], replicated_params) - train_acc = accuracy(params, (train_images, train_labels)) - test_acc = accuracy(params, (test_images, test_labels)) + # Reshard train_images, train_labels, test_images, test_labels + sharded_train_images = reshard(train_images, data_sharding) + sharded_train_labels = reshard(train_labels, data_sharding) + sharded_test_images = reshard(test_images, data_sharding) + sharded_test_labels = reshard(test_labels, data_sharding) + + train_acc = accuracy( + replicated_params, (sharded_train_images, sharded_train_labels) + ) + test_acc = accuracy(replicated_params, (sharded_test_images, sharded_test_labels)) print(f"Epoch {epoch} in {epoch_time:0.2f} sec") print(f"Training set accuracy {train_acc}") print(f"Test set accuracy {test_acc}") + + if epoch < num_epochs - 1: + batches = data_stream() + print(f"Batch no {0} of {num_batches}") + replicated_params = train_step(replicated_params, next(batches)) diff --git a/jax/BUILD b/jax/BUILD index 12eae4afdcf7..6431c08656d1 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -15,28 +15,16 @@ # JAX is Autograd and XLA load("@bazel_skylib//rules:common_settings.bzl", "string_flag") -load("@rules_python//python:defs.bzl", "py_library") load( "//jaxlib:jax.bzl", - "if_building_jaxlib", - "jax_export_file_visibility", "jax_extend_internal_users", "jax_extra_deps", - "jax_internal_export_back_compat_test_util_visibility", "jax_internal_packages", - "jax_internal_test_harnesses_visibility", - "jax_test_util_visibility", "jax_visibility", - "mosaic_gpu_internal_users", - "mosaic_internal_users", - "pallas_fuser_users", - "pallas_gpu_internal_users", - "pallas_tpu_internal_users", "py_deps", "py_library_providing_imports_info", "pytype_library", "pytype_strict_library", - "serialize_executable_internal_users", ) package( @@ -63,23 +51,70 @@ string_flag( ) config_setting( - name = "enable_jaxlib_build", + name = "config_build_jaxlib_true", flag_values = { ":build_jaxlib": "true", }, ) +config_setting( + name = "config_build_jaxlib_false", + flag_values = { + ":build_jaxlib": "false", + }, +) + +config_setting( + name = "config_build_jaxlib_wheel", + flag_values = { + ":build_jaxlib": "wheel", + }, +) + +# The flag controls whether jax should be built by Bazel. +# If ":build_jax=true", then jax will be built. +# If ":build_jax=false", then jax is not built. It is assumed that the pre-built jax wheel +# is available in the "dist" folder. +# If ":build_jax=wheel", then jax wheel will be built as a py_import rule attribute. +# The py_import rule unpacks the wheel and provides its content as a py_library. +string_flag( + name = "build_jax", + build_setting_default = "true", + values = [ + "true", + "false", + "wheel", + ], +) + +config_setting( + name = "config_build_jax_true", + flag_values = { + ":build_jax": "true", + }, +) + +config_setting( + name = "config_build_jax_false", + flag_values = { + ":build_jax": "false", + }, +) + +config_setting( + name = "config_build_jax_wheel", + flag_values = { + ":build_jax": "wheel", + }, +) + exports_files([ "LICENSE", "version.py", "py.typed", + "oss/pyproject.toml", ]) -exports_files( - ["_src/export/serialization.fbs"], - visibility = jax_export_file_visibility, -) - # Packages that have access to JAX-internal implementation details. package_group( name = "internal", @@ -93,175 +128,16 @@ package_group( includes = [":internal"], packages = [ # Intentionally avoid jax dependencies on jax.extend. - # See https://jax.readthedocs.io/en/latest/jep/15856-jex.html + # See https://docs.jax.dev/en/latest/jep/15856-jex.html "//tests/...", ] + jax_extend_internal_users, ) -package_group( - name = "mosaic_users", - includes = [":internal"], - packages = mosaic_internal_users, -) - -package_group( - name = "pallas_gpu_users", - includes = [":internal"], - packages = pallas_gpu_internal_users, -) - -package_group( - name = "pallas_tpu_users", - includes = [":internal"], - packages = pallas_tpu_internal_users, -) - -package_group( - name = "pallas_fuser_users", - includes = [":internal"], - packages = pallas_fuser_users, -) - -package_group( - name = "mosaic_gpu_users", - includes = [":internal"], - packages = mosaic_gpu_internal_users, -) - -package_group( - name = "serialize_executable_users", - includes = [":internal"], - packages = serialize_executable_internal_users, -) - -# JAX-private test utilities. -py_library( - # This build target is required in order to use private test utilities in jax._src.test_util, - # and its visibility is intentionally restricted to discourage its use outside JAX itself. - # JAX does provide some public test utilities (see jax/test_util.py); - # these are available in jax.test_util via the standard :jax target. - name = "test_util", - srcs = [ - "_src/test_util.py", - "_src/test_warning_util.py", - ], - visibility = [ - ":internal", - ] + jax_test_util_visibility, - deps = [ - ":compilation_cache_internal", - ":jax", - ] + py_deps("absl/testing") + py_deps("numpy"), -) - -# TODO(necula): break the internal_test_util into smaller build targets. -py_library( - name = "internal_test_util", - srcs = [ - "_src/internal_test_util/__init__.py", - "_src/internal_test_util/deprecation_module.py", - "_src/internal_test_util/lax_test_util.py", - ] + glob( - [ - "_src/internal_test_util/lazy_loader_module/*.py", - ], - ), - visibility = [":internal"], - deps = [ - ":jax", - ] + py_deps("numpy"), -) - -py_library( - name = "internal_test_harnesses", - srcs = ["_src/internal_test_util/test_harnesses.py"], - visibility = [":internal"] + jax_internal_test_harnesses_visibility, - deps = [ - ":ad_util", - ":config", - ":jax", - ":test_util", - "//jax/_src/lib", - ] + py_deps("numpy"), -) - -py_library( - name = "internal_export_back_compat_test_util", - srcs = ["_src/internal_test_util/export_back_compat_test_util.py"], - visibility = [ - ":internal", - ] + jax_internal_export_back_compat_test_util_visibility, - deps = [ - ":jax", - ":test_util", - ] + py_deps("numpy"), -) - -py_library( - name = "internal_export_back_compat_test_data", - testonly = 1, - srcs = glob([ - "_src/internal_test_util/export_back_compat_test_data/*.py", - "_src/internal_test_util/export_back_compat_test_data/pallas/*.py", - ]), - visibility = [ - ":internal", - ], - deps = py_deps("numpy"), -) - py_library_providing_imports_info( name = "jax", - srcs = [ - "_src/__init__.py", - "_src/ad_checkpoint.py", - "_src/api.py", - "_src/array.py", - "_src/blocked_sampler.py", - "_src/callback.py", - "_src/checkify.py", - "_src/custom_batching.py", - "_src/custom_dce.py", - "_src/custom_derivatives.py", - "_src/custom_partitioning.py", - "_src/custom_partitioning_sharding_rule.py", - "_src/custom_transpose.py", - "_src/debugging.py", - "_src/dispatch.py", - "_src/dlpack.py", - "_src/earray.py", - "_src/error_check.py", - "_src/ffi.py", - "_src/flatten_util.py", - "_src/interpreters/__init__.py", - "_src/interpreters/ad.py", - "_src/interpreters/batching.py", - "_src/interpreters/pxla.py", - "_src/pjit.py", - "_src/prng.py", - "_src/public_test_util.py", - "_src/random.py", - "_src/shard_alike.py", - "_src/sourcemap.py", - "_src/stages.py", - "_src/tree.py", - ] + glob( + srcs = glob( [ "*.py", - "_src/cudnn/**/*.py", - "_src/debugger/**/*.py", - "_src/extend/**/*.py", - "_src/image/**/*.py", - "_src/export/**/*.py", - "_src/lax/**/*.py", - "_src/nn/**/*.py", - "_src/numpy/**/*.py", - "_src/ops/**/*.py", - "_src/scipy/**/*.py", - "_src/state/**/*.py", - "_src/third_party/**/*.py", - "experimental/key_reuse/**/*.py", - "experimental/roofline/**/*.py", "image/**/*.py", "interpreters/**/*.py", "lax/**/*.py", @@ -275,814 +151,104 @@ py_library_providing_imports_info( exclude = [ "*_test.py", "**/*_test.py", - "_src/internal_test_util/**", ], - ) + [ - "experimental/attrs.py", - "experimental/pjit.py", - "experimental/multihost_utils.py", - "experimental/shard_map.py", - # until checkify is moved out of experimental - "experimental/checkify.py", - "experimental/compilation_cache/compilation_cache.py", - ], + # TODO(dsuo): Consider moving these files out of experimental if they're in the public API. + ) + ["//jax/experimental:jax_public"], + lazy_imports = True, lib_rule = pytype_library, pytype_srcs = glob( [ + "nn/*.pyi", "numpy/*.pyi", - "_src/**/*.pyi", - ], - exclude = [ - "_src/basearray.pyi", - ], - ), - visibility = ["//visibility:public"], - deps = [ - ":abstract_arrays", - ":ad_util", - ":api_util", - ":basearray", - ":cloud_tpu_init", - ":compilation_cache_internal", - ":compiler", - ":compute_on", - ":config", - ":core", - ":custom_api_util", - ":deprecations", - ":dtypes", - ":effects", - ":environment_info", - ":internal_mesh_utils", - ":jaxpr_util", - ":layout", - ":lazy_loader", - ":mesh", - ":mlir", - ":monitoring", - ":named_sharding", - ":op_shardings", - ":partial_eval", - ":partition_spec", - ":path", - ":pickle_util", - ":pretty_printer", - ":profiler", - ":sharding", - ":sharding_impls", - ":sharding_specs", - ":source_info_util", - ":traceback_util", - ":tree_util", - ":typing", - ":util", - ":version", - ":xla", - ":xla_bridge", - ":xla_metadata", - "//jax/_src/lib", - ] + py_deps("numpy") + py_deps("scipy") + py_deps("opt_einsum") + py_deps("flatbuffers") + jax_extra_deps, -) - -pytype_strict_library( - name = "abstract_arrays", - srcs = ["_src/abstract_arrays.py"], - deps = [ - ":ad_util", - ":core", - ":dtypes", - ":traceback_util", - ] + py_deps("numpy"), -) - -pytype_strict_library( - name = "ad_util", - srcs = ["_src/ad_util.py"], - deps = [ - ":core", - ":traceback_util", - ":tree_util", - ":typing", - ":util", - ], -) - -pytype_strict_library( - name = "api_util", - srcs = ["_src/api_util.py"], - deps = [ - ":abstract_arrays", - ":config", - ":core", - ":dtypes", - ":state_types", - ":traceback_util", - ":tree_util", - ":util", - ] + py_deps("numpy"), -) - -pytype_strict_library( - name = "basearray", - srcs = ["_src/basearray.py"], - pytype_srcs = ["_src/basearray.pyi"], - deps = [ - ":partition_spec", - ":sharding", - "//jax/_src/lib", - ] + py_deps("numpy"), -) - -pytype_strict_library( - name = "cloud_tpu_init", - srcs = ["_src/cloud_tpu_init.py"], - deps = [ - ":config", - ":hardware_utils", - ":version", - ], -) - -pytype_strict_library( - name = "compilation_cache_internal", - srcs = ["_src/compilation_cache.py"], - visibility = [":internal"] + jax_visibility("compilation_cache"), - deps = [ - ":cache_key", - ":compilation_cache_interface", - ":config", - ":lru_cache", - ":monitoring", - ":path", - "//jax/_src/lib", - ] + py_deps("numpy") + py_deps("zstandard"), -) - -pytype_strict_library( - name = "cache_key", - srcs = ["_src/cache_key.py"], - visibility = [":internal"] + jax_visibility("compilation_cache"), - deps = [ - ":config", - "//jax/_src/lib", - ] + py_deps("numpy"), -) - -pytype_strict_library( - name = "compilation_cache_interface", - srcs = ["_src/compilation_cache_interface.py"], - deps = [ - ":path", - ":util", - ], -) - -pytype_strict_library( - name = "lru_cache", - srcs = ["_src/lru_cache.py"], - deps = [ - ":compilation_cache_interface", - ":path", - ] + py_deps("filelock"), -) - -pytype_strict_library( - name = "config", - srcs = ["_src/config.py"], - deps = [ - ":logging_config", - "//jax/_src/lib", - ], -) - -pytype_strict_library( - name = "logging_config", - srcs = ["_src/logging_config.py"], -) - -pytype_strict_library( - name = "compiler", - srcs = ["_src/compiler.py"], - deps = [ - ":cache_key", - ":compilation_cache_internal", - ":config", - ":mlir", - ":monitoring", - ":path", - ":profiler", - ":traceback_util", - ":xla_bridge", - "//jax/_src/lib", - ] + py_deps("numpy"), -) - -pytype_strict_library( - name = "core", - srcs = [ - "_src/core.py", - "_src/errors.py", - "_src/linear_util.py", - ], - deps = [ - ":compute_on", - ":config", - ":deprecations", - ":dtypes", - ":effects", - ":mesh", - ":named_sharding", - ":partition_spec", - ":pretty_printer", - ":source_info_util", - ":traceback_util", - ":tree_util", - ":typing", - ":util", - ":xla_metadata", - "//jax/_src/lib", - ] + py_deps("numpy"), -) - -pytype_strict_library( - name = "custom_api_util", - srcs = ["_src/custom_api_util.py"], -) - -pytype_strict_library( - name = "deprecations", - srcs = ["_src/deprecations.py"], -) - -pytype_strict_library( - name = "dtypes", - srcs = [ - "_src/dtypes.py", - ], - deps = [ - ":config", - ":traceback_util", - ":typing", - ":util", - "//jax/_src/lib", - ] + py_deps("ml_dtypes") + py_deps("numpy"), -) - -pytype_strict_library( - name = "effects", - srcs = ["_src/effects.py"], -) - -pytype_strict_library( - name = "environment_info", - srcs = ["_src/environment_info.py"], - deps = [ - ":version", - ":xla_bridge", - "//jax/_src/lib", - ] + py_deps("numpy"), -) - -pytype_strict_library( - name = "hardware_utils", - srcs = ["_src/hardware_utils.py"], -) - -pytype_library( - name = "lax_reference", - srcs = ["_src/lax_reference.py"], - visibility = [":internal"] + jax_visibility("lax_reference"), - deps = [ - ":core", - ":util", - ] + py_deps("numpy") + py_deps("scipy") + py_deps("opt_einsum"), -) - -pytype_strict_library( - name = "lazy_loader", - srcs = ["_src/lazy_loader.py"], -) - -pytype_strict_library( - name = "jaxpr_util", - srcs = ["_src/jaxpr_util.py"], - deps = [ - ":core", - ":source_info_util", - ":util", - "//jax/_src/lib", - ], -) - -pytype_strict_library( - name = "mesh", - srcs = ["_src/mesh.py"], - deps = [ - ":config", - ":util", - ":xla_bridge", - "//jax/_src/lib", - ] + py_deps("numpy"), -) - -pytype_strict_library( - name = "mlir", - srcs = ["_src/interpreters/mlir.py"], - deps = [ - ":ad_util", - ":api_util", - ":config", - ":core", - ":dtypes", - ":effects", - ":layout", - ":op_shardings", - ":partial_eval", - ":partition_spec", - ":path", - ":pickle_util", - ":sharding", - ":sharding_impls", - ":source_info_util", - ":state_types", - ":util", - ":xla", - ":xla_bridge", - "//jax/_src/lib", - ] + py_deps("numpy"), -) - -pytype_strict_library( - name = "monitoring", - srcs = ["_src/monitoring.py"], -) - -pytype_strict_library( - name = "op_shardings", - srcs = ["_src/op_shardings.py"], - deps = [ - "//jax/_src/lib", - ] + py_deps("numpy"), -) - -pytype_strict_library( - name = "serialize_executable", - srcs = ["experimental/serialize_executable.py"], - visibility = [":serialize_executable_users"], - deps = [ - ":jax", - "//jax/_src/lib", - ], -) - -pytype_strict_library( - name = "source_mapper", - srcs = glob(include = ["experimental/source_mapper/**/*.py"]), - visibility = [ - "//visibility:public", - ], - deps = [ - ":config", - ":core", - ":jax", - ":source_info_util", - ] + py_deps("absl/flags"), -) - -pytype_strict_library( - name = "pallas", - srcs = glob( - [ - "experimental/pallas/**/*.py", - ], - exclude = [ - "experimental/pallas/gpu.py", - "experimental/pallas/mosaic_gpu.py", - "experimental/pallas/ops/gpu/**/*.py", - "experimental/pallas/ops/tpu/**/*.py", - "experimental/pallas/tpu.py", - "experimental/pallas/fuser.py", - "experimental/pallas/triton.py", ], ), - visibility = [ - "//visibility:public", - ], - deps = [ - ":deprecations", - ":jax", - ":source_info_util", - "//jax/_src/pallas", - ] + py_deps("numpy"), -) - -pytype_strict_library( - name = "pallas_tpu", - srcs = ["experimental/pallas/tpu.py"], - visibility = [ - ":pallas_tpu_users", - ], - deps = [ - ":pallas", # build_cleaner: keep - ":tpu_custom_call", - "//jax/_src/pallas", - "//jax/_src/pallas/mosaic:core", - "//jax/_src/pallas/mosaic:helpers", - "//jax/_src/pallas/mosaic:interpret", - "//jax/_src/pallas/mosaic:lowering", - "//jax/_src/pallas/mosaic:pallas_call_registration", # build_cleaner: keep - "//jax/_src/pallas/mosaic:pipeline", - "//jax/_src/pallas/mosaic:primitives", - "//jax/_src/pallas/mosaic:random", - "//jax/_src/pallas/mosaic:verification", - ], -) - -pytype_strict_library( - name = "pallas_fuser", - srcs = ["experimental/pallas/fuser.py"], - visibility = [ - ":pallas_fuser_users", - ], - deps = [ - ":pallas", # build_cleaner: keep - "//jax/_src/pallas/fuser:block_spec", - "//jax/_src/pallas/fuser:custom_evaluate", - "//jax/_src/pallas/fuser:fusable", - "//jax/_src/pallas/fuser:fusion", - "//jax/_src/pallas/fuser:jaxpr_fusion", - ], -) - -pytype_strict_library( - name = "pallas_gpu_ops", - srcs = ["//jax/experimental/pallas/ops/gpu:triton_ops"], - visibility = [ - ":pallas_gpu_users", - ], - deps = [ - ":jax", - ":pallas", - ":pallas_gpu", - ] + py_deps("numpy"), -) - -pytype_strict_library( - name = "pallas_experimental_gpu_ops", - srcs = ["//jax/experimental/pallas/ops/gpu:mgpu_ops"], - visibility = [ - ":mosaic_gpu_users", - ], - deps = [ - ":jax", - ":mosaic_gpu", - ":pallas", - ":pallas_mosaic_gpu", - ":test_util", # This is only to make them runnable as jax_multiplatform_test... - ] + py_deps("numpy"), -) - -pytype_strict_library( - name = "pallas_tpu_ops", - srcs = glob(["experimental/pallas/ops/tpu/**/*.py"]), - visibility = [ - ":pallas_tpu_users", - ], - deps = [ - ":jax", - ":pallas", - ":pallas_tpu", - ] + py_deps("numpy"), -) - -pytype_strict_library( - name = "pallas_gpu", - visibility = [ - ":pallas_gpu_users", - ], - deps = [ - ":pallas_triton", - # TODO(slebedev): Add :pallas_mosaic_gpu once it is ready. - ], -) - -pytype_strict_library( - name = "pallas_triton", - srcs = [ - "experimental/pallas/gpu.py", - "experimental/pallas/triton.py", - ], - visibility = [ - ":pallas_gpu_users", - ], - deps = [ - ":deprecations", - "//jax/_src/pallas/triton:core", - "//jax/_src/pallas/triton:pallas_call_registration", # build_cleaner: keep - "//jax/_src/pallas/triton:primitives", - ], -) - -pytype_strict_library( - name = "pallas_mosaic_gpu", - srcs = ["experimental/pallas/mosaic_gpu.py"], - visibility = [ - ":mosaic_gpu_users", - ], - deps = [ - ":mosaic_gpu", - "//jax/_src/pallas/mosaic_gpu:core", - "//jax/_src/pallas/mosaic_gpu:pallas_call_registration", # build_cleaner: keep - "//jax/_src/pallas/mosaic_gpu:pipeline", - "//jax/_src/pallas/mosaic_gpu:primitives", - ], -) - -# This target only supports sm_90 GPUs. -py_library_providing_imports_info( - name = "mosaic_gpu", - srcs = glob(["experimental/mosaic/gpu/*.py"]), - visibility = [ - ":mosaic_gpu_users", - ], - deps = [ - ":config", - ":core", - ":jax", - ":mlir", - "//jax/_src/lib", - "//jaxlib/mlir:arithmetic_dialect", - "//jaxlib/mlir:builtin_dialect", - "//jaxlib/mlir:func_dialect", - "//jaxlib/mlir:gpu_dialect", - "//jaxlib/mlir:ir", - "//jaxlib/mlir:llvm_dialect", - "//jaxlib/mlir:math_dialect", - "//jaxlib/mlir:memref_dialect", - "//jaxlib/mlir:nvgpu_dialect", - "//jaxlib/mlir:nvvm_dialect", - "//jaxlib/mlir:pass_manager", - "//jaxlib/mlir:scf_dialect", - "//jaxlib/mlir:vector_dialect", - "//jaxlib/mosaic/python:gpu_dialect", - ] + py_deps("absl/flags") + py_deps("numpy"), -) - -pytype_strict_library( - name = "partial_eval", - srcs = ["_src/interpreters/partial_eval.py"], - deps = [ - ":ad_util", - ":api_util", - ":compute_on", - ":config", - ":core", - ":dtypes", - ":effects", - ":profiler", - ":source_info_util", - ":state_types", - ":tree_util", - ":util", - ":xla_metadata", - ] + py_deps("numpy"), -) - -pytype_strict_library( - name = "partition_spec", - srcs = ["_src/partition_spec.py"], -) - -pytype_strict_library( - name = "path", - srcs = ["_src/path.py"], - deps = py_deps("epath"), -) - -pytype_library( - name = "experimental_profiler", - srcs = ["experimental/profiler.py"], visibility = ["//visibility:public"], deps = [ - "//jax/_src/lib", - ], -) - -pytype_library( - name = "experimental_transfer", - srcs = ["experimental/transfer.py"], - deps = [ - ":jax", - "//jax/_src/lib", - ], -) - -pytype_strict_library( - name = "pickle_util", - srcs = ["_src/pickle_util.py"], - deps = [":profiler"] + py_deps("cloudpickle"), -) - -pytype_strict_library( - name = "pretty_printer", - srcs = ["_src/pretty_printer.py"], - deps = [ - ":config", - ":util", - ] + py_deps("colorama"), -) - -pytype_strict_library( - name = "profiler", - srcs = ["_src/profiler.py"], - deps = [ - ":traceback_util", - ":xla_bridge", - "//jax/_src/lib", - ], -) - -pytype_strict_library( - name = "sharding", - srcs = ["_src/sharding.py"], - deps = [ - ":op_shardings", - ":util", - ":xla_bridge", - "//jax/_src/lib", - ], -) - -pytype_strict_library( - name = "compute_on", - srcs = ["_src/compute_on.py"], - deps = [ - ":config", - "//jax/_src/lib", - ], -) - -pytype_strict_library( - name = "xla_metadata", - srcs = ["_src/xla_metadata.py"], - deps = [ - ":config", - "//jax/_src/lib", - ], -) - -pytype_strict_library( - name = "layout", - srcs = ["_src/layout.py"], - deps = [ - ":dtypes", - ":sharding", - ":sharding_impls", - "//jax/_src/lib", - ] + py_deps("numpy"), -) - -pytype_strict_library( - name = "sharding_impls", - srcs = ["_src/sharding_impls.py"], - deps = [ - ":config", - ":core", - ":internal_mesh_utils", - ":mesh", - ":named_sharding", - ":op_shardings", - ":partition_spec", - ":sharding", - ":sharding_specs", - ":source_info_util", - ":tree_util", - ":util", - ":xla_bridge", - "//jax/_src/lib", - ] + py_deps("numpy"), -) - -pytype_strict_library( - name = "named_sharding", - srcs = ["_src/named_sharding.py"], - deps = [ - ":config", - ":mesh", - ":partition_spec", - ":sharding", - ":util", - ":xla_bridge", - "//jax/_src/lib", - ] + py_deps("numpy"), -) - -pytype_strict_library( - name = "sharding_specs", - srcs = ["_src/sharding_specs.py"], - deps = [ - ":config", - ":op_shardings", - ":util", - "//jax/_src/lib", - ] + py_deps("numpy"), -) - -pytype_library( - name = "internal_mesh_utils", - srcs = ["_src/mesh_utils.py"], - deps = [ - ":xla_bridge", - ], -) - -pytype_strict_library( - name = "source_info_util", - srcs = ["_src/source_info_util.py"], - visibility = [":internal"] + jax_visibility("source_info_util"), - deps = [ - ":traceback_util", ":version", - "//jax/_src/lib", - ], -) - -pytype_strict_library( - name = "state_types", - srcs = [ - "_src/state/__init__.py", - "_src/state/indexing.py", - "_src/state/types.py", - ], - deps = [ - ":core", - ":dtypes", - ":effects", - ":pretty_printer", - ":traceback_util", - ":tree_util", - ":typing", - ":util", - ] + py_deps("numpy"), -) - -pytype_strict_library( - name = "tree_util", - srcs = ["_src/tree_util.py"], - visibility = [":internal"] + jax_visibility("tree_util"), - deps = [ - ":traceback_util", - ":util", - "//jax/_src/lib", - ], -) - -pytype_strict_library( - name = "traceback_util", - srcs = ["_src/traceback_util.py"], - visibility = [":internal"] + jax_visibility("traceback_util"), - deps = [ - ":config", - ":util", - "//jax/_src/lib", - ], -) - -pytype_strict_library( - name = "typing", - srcs = [ - "_src/typing.py", - ], - deps = [":basearray"] + py_deps("numpy"), -) - -pytype_strict_library( - name = "tpu_custom_call", - srcs = ["_src/tpu_custom_call.py"], - visibility = [":internal"], - deps = [ - ":config", - ":core", - ":jax", - ":mlir", - ":sharding_impls", - "//jax/_src/lib", - "//jax/_src/pallas", - ] + if_building_jaxlib([ - "//jaxlib/mlir:ir", - "//jaxlib/mlir:mhlo_dialect", - "//jaxlib/mlir:pass_manager", - "//jaxlib/mlir:stablehlo_dialect", - ]) + py_deps("numpy") + py_deps("absl/flags"), -) - -pytype_strict_library( - name = "util", - srcs = ["_src/util.py"], - deps = [ - ":config", - "//jax/_src/lib", - ] + py_deps("numpy"), + "//jax/_src:abstract_arrays", + "//jax/_src:ad", + "//jax/_src:ad_util", + "//jax/_src:api", + "//jax/_src:api_util", + "//jax/_src:basearray", + "//jax/_src:batching", + "//jax/_src:blocked_sampler", + "//jax/_src:buffer_callback", + "//jax/_src:callback", + "//jax/_src:checkify", + "//jax/_src:cloud_tpu_init", + "//jax/_src:compilation_cache_internal", + "//jax/_src:compiler", + "//jax/_src:compute_on", + "//jax/_src:config", + "//jax/_src:core", + "//jax/_src:cudnn", + "//jax/_src:custom_api_util", + "//jax/_src:custom_batching", + "//jax/_src:custom_dce", + "//jax/_src:custom_derivatives", + "//jax/_src:custom_partitioning", + "//jax/_src:custom_partitioning_sharding_rule", + "//jax/_src:custom_transpose", + "//jax/_src:debugger", + "//jax/_src:debugging", + "//jax/_src:deprecations", + "//jax/_src:dlpack", + "//jax/_src:dtypes", + "//jax/_src:earray", + "//jax/_src:effects", + "//jax/_src:environment_info", + "//jax/_src:error_check", + "//jax/_src:export", + "//jax/_src:ffi", + "//jax/_src:flatten_util", + "//jax/_src:hashable_array", + "//jax/_src:hijax", + "//jax/_src:image", + "//jax/_src:init", + "//jax/_src:internal_mesh_utils", + "//jax/_src:jaxpr_util", + "//jax/_src:lax", + "//jax/_src:layout", + "//jax/_src:lazy_loader", + "//jax/_src:memory", + "//jax/_src:mesh", + "//jax/_src:mlir", + "//jax/_src:monitoring", + "//jax/_src:named_sharding", + "//jax/_src:nn", + "//jax/_src:numpy", + "//jax/_src:op_shardings", + "//jax/_src:partial_eval", + "//jax/_src:partition_spec", + "//jax/_src:path", + "//jax/_src:pickle_util", + "//jax/_src:pmap", + "//jax/_src:pretty_printer", + "//jax/_src:profiler", + "//jax/_src:public_test_util", + "//jax/_src:random", + "//jax/_src:ref", + "//jax/_src:scipy", + "//jax/_src:shard_alike", + "//jax/_src:shard_map", + "//jax/_src:sharding", + "//jax/_src:sharding_impls", + "//jax/_src:sharding_specs", + "//jax/_src:source_info_util", + "//jax/_src:sourcemap", + "//jax/_src:stages", + "//jax/_src:tpu", + "//jax/_src:traceback_util", + "//jax/_src:tree", + "//jax/_src:tree_util", + "//jax/_src:typing", + "//jax/_src:util", + "//jax/_src:xla_bridge", + "//jax/_src:xla_metadata", + "//jax/_src:xla_metadata_lib", + "//jax/_src/lib", + ] + py_deps("numpy") + py_deps("opt_einsum") + py_deps("flatbuffers") + jax_extra_deps, ) pytype_strict_library( @@ -1090,218 +256,37 @@ pytype_strict_library( srcs = ["version.py"], ) -pytype_strict_library( - name = "xla", - srcs = ["_src/interpreters/xla.py"], - deps = [ - ":abstract_arrays", - ":config", - ":core", - ":dtypes", - ":sharding_impls", - ":source_info_util", - ":typing", - ":util", - ":xla_bridge", - "//jax/_src/lib", - ] + py_deps("numpy"), -) - -# TODO(phawkins): break up this SCC. -pytype_strict_library( - name = "xla_bridge", - srcs = [ - "_src/clusters/__init__.py", - "_src/clusters/cloud_tpu_cluster.py", - "_src/clusters/cluster.py", - "_src/clusters/k8s_cluster.py", - "_src/clusters/mpi4py_cluster.py", - "_src/clusters/ompi_cluster.py", - "_src/clusters/slurm_cluster.py", - "_src/distributed.py", - "_src/xla_bridge.py", - ], - visibility = [":internal"] + jax_visibility("xla_bridge"), - deps = [ - ":cloud_tpu_init", - ":config", - ":hardware_utils", - ":traceback_util", - ":util", - "//jax/_src/lib", - ], -) - # Public JAX libraries below this point. -py_library_providing_imports_info( +# Aliases of experimental targets. +# TODO(dsuo): remove these aliases/targets. +pytype_strict_library( name = "experimental", - srcs = glob( - [ - "experimental/*.py", - "example_libraries/*.py", - ], - ), - visibility = ["//visibility:public"], + visibility = jax_visibility("experimental_deprecated_alias"), deps = [ ":jax", - ] + py_deps("absl/logging") + py_deps("numpy"), -) - -pytype_library( - name = "stax", - srcs = [ - "example_libraries/stax.py", + "//jax/example_libraries:optimizers", + "//jax/example_libraries:stax", + "//jax/experimental", + "//jax/experimental:checkify", + "//jax/experimental:compute_on", + "//jax/experimental:custom_dce", + "//jax/experimental:custom_partitioning", + "//jax/experimental:fused", + "//jax/experimental:hijax", + "//jax/experimental:jet", + "//jax/experimental:layout", + "//jax/experimental:mesh_utils", + "//jax/experimental:multihost_utils", + "//jax/experimental:ode", + "//jax/experimental:pjit", + "//jax/experimental:profiler", + "//jax/experimental:rnn", + "//jax/experimental:scheduling_groups", + "//jax/experimental:shard_alike", + "//jax/experimental:shard_map", + "//jax/experimental:topologies", + "//jax/experimental:transfer", + "//jax/experimental:xla_metadata", ], - visibility = ["//visibility:public"], - deps = [":jax"], -) - -pytype_library( - name = "experimental_sparse", - srcs = glob( - [ - "experimental/sparse/*.py", - ], - exclude = ["experimental/sparse/test_util.py"], - ), - visibility = ["//visibility:public"], - deps = [":jax"], -) - -pytype_library( - name = "sparse_test_util", - srcs = [ - "experimental/sparse/test_util.py", - ], - visibility = [":internal"], - deps = [ - ":experimental_sparse", - ":jax", - ":test_util", - ] + py_deps("numpy"), -) - -pytype_library( - name = "optimizers", - srcs = [ - "example_libraries/optimizers.py", - ], - visibility = ["//visibility:public"], - deps = [":jax"] + py_deps("numpy"), -) - -pytype_library( - name = "ode", - srcs = ["experimental/ode.py"], - visibility = ["//visibility:public"], - deps = [":jax"], -) - -# TODO(apaszke): Remove this target -pytype_library( - name = "pjit", - srcs = ["experimental/pjit.py"], - visibility = ["//visibility:public"], - deps = [ - ":experimental", - ":jax", - ], -) - -pytype_library( - name = "jet", - srcs = ["experimental/jet.py"], - visibility = ["//visibility:public"], - deps = [":jax"], -) - -pytype_library( - name = "experimental_host_callback", - srcs = [ - "experimental/__init__.py", # To support JAX_HOST_CALLBACK_LEGACY=False - "experimental/host_callback.py", - "experimental/x64_context.py", # To support JAX_HOST_CALLBACK_LEGACY=False - ], - visibility = ["//visibility:public"], - deps = [ - ":jax", - ], -) - -pytype_library( - name = "compilation_cache", - srcs = [ - "experimental/compilation_cache/__init__.py", - "experimental/compilation_cache/compilation_cache.py", - ], - visibility = ["//visibility:public"], - deps = [":jax"], -) - -pytype_library( - name = "mesh_utils", - srcs = ["experimental/mesh_utils.py"], - visibility = ["//visibility:public"], - deps = [ - ":internal_mesh_utils", - ], -) - -# TODO(phawkins): remove this target in favor of the finer-grained targets in jax/extend/... -pytype_strict_library( - name = "extend", - visibility = [":jax_extend_users"], - deps = [ - "//jax/extend", - "//jax/extend:backend", - "//jax/extend:core", - "//jax/extend:linear_util", - "//jax/extend:random", - "//jax/extend:source_info_util", - ], -) - -pytype_library( - name = "mosaic", - srcs = [ - "experimental/mosaic/__init__.py", - "experimental/mosaic/dialects.py", - ], - visibility = [":mosaic_users"], - deps = [ - ":tpu_custom_call", - "//jax/_src/lib", - ], -) - -pytype_library( - name = "rnn", - srcs = ["experimental/rnn.py"], - visibility = ["//visibility:public"], - deps = [":jax"], -) - -pytype_library( - name = "experimental_colocated_python", - srcs = [ - "experimental/colocated_python/__init__.py", - "experimental/colocated_python/api.py", - "experimental/colocated_python/func.py", - "experimental/colocated_python/func_backend.py", - "experimental/colocated_python/obj.py", - "experimental/colocated_python/obj_backend.py", - "experimental/colocated_python/serialization.py", - ], - visibility = ["//visibility:public"], - deps = [ - ":api_util", - ":jax", - ":traceback_util", - ":tree_util", - ":util", - ":xla_bridge", - "//jax/_src/lib", - "//jax/extend:ifrt_programs", - ] + py_deps("numpy") + py_deps("cloudpickle"), ) diff --git a/jax/__init__.py b/jax/__init__.py index ae3bac4ad3fa..9baed340ca50 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -44,9 +44,12 @@ from jax import tree as tree from jax import typing as typing +from jax._src.lib import jaxlib_extension_version + from jax._src.config import ( config as config, enable_checks as enable_checks, + enable_x64 as enable_x64, debug_key_reuse as debug_key_reuse, check_tracer_leaks as check_tracer_leaks, checking_leaks as checking_leaks, @@ -57,6 +60,7 @@ debug_infs as debug_infs, log_compiles as log_compiles, no_tracing as no_tracing, + no_execution as no_execution, explain_cache_misses as explain_cache_misses, default_device as default_device, default_matmul_precision as default_matmul_precision, @@ -70,8 +74,13 @@ transfer_guard_host_to_device as transfer_guard_host_to_device, transfer_guard_device_to_device as transfer_guard_device_to_device, transfer_guard_device_to_host as transfer_guard_device_to_host, - spmd_mode as spmd_mode, + make_user_context as make_user_context, + remove_size_one_mesh_axis_from_type as remove_size_one_mesh_axis_from_type, ) +if jaxlib_extension_version >= 395: + from jax._src.config import thread_guard as thread_guard +del jaxlib_extension_version + from jax._src.core import ensure_compile_time_eval as ensure_compile_time_eval from jax._src.environment_info import print_environment_info as print_environment_info @@ -82,8 +91,9 @@ from jax._src.core import typeof as typeof from jax._src.api import effects_barrier as effects_barrier from jax._src.api import block_until_ready as block_until_ready -from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint # noqa: F401 +from jax._src.ad_checkpoint import checkpoint as checkpoint from jax._src.ad_checkpoint import checkpoint_policies as checkpoint_policies +from jax._src.ad_checkpoint import remat as remat from jax._src.api import clear_caches as clear_caches from jax._src.api import copy_to_host_async as copy_to_host_async from jax._src.custom_derivatives import closure_convert as closure_convert @@ -94,12 +104,13 @@ from jax._src.xla_bridge import device_count as device_count from jax._src.api import device_get as device_get from jax._src.api import device_put as device_put -from jax._src.api import device_put_sharded as device_put_sharded -from jax._src.api import device_put_replicated as device_put_replicated +from jax._src.api import device_put_sharded as _deprecated_device_put_sharded +from jax._src.api import device_put_replicated as _deprecated_device_put_replicated from jax._src.xla_bridge import devices as devices from jax._src.api import disable_jit as disable_jit from jax._src.api import eval_shape as eval_shape from jax._src.dtypes import float0 as float0 +from jax._src.api import fwd_and_bwd as fwd_and_bwd from jax._src.api import grad as grad from jax._src.api import hessian as hessian from jax._src.xla_bridge import host_count as host_count @@ -123,13 +134,22 @@ from jax._src.xla_bridge import process_index as process_index from jax._src.xla_bridge import process_indices as process_indices from jax._src.callback import pure_callback as pure_callback -from jax._src.ad_checkpoint import checkpoint_wrapper as remat # noqa: F401 -from jax._src.api import ShapeDtypeStruct as ShapeDtypeStruct +from jax._src.core import ShapeDtypeStruct as ShapeDtypeStruct from jax._src.api import value_and_grad as value_and_grad from jax._src.api import vjp as vjp from jax._src.api import vmap as vmap from jax._src.sharding_impls import NamedSharding as NamedSharding from jax._src.sharding_impls import make_mesh as make_mesh +from jax._src.sharding_impls import set_mesh as set_mesh +from jax._src.partition_spec import P as P +from jax._src.pjit import reshard as reshard + +from jax._src.shard_map import shard_map as shard_map +from jax._src.shard_map import smap as smap + +from jax.ref import new_ref as new_ref +from jax.ref import freeze as freeze +from jax.ref import Ref as Ref # Force import, allowing jax.interpreters.* to be used after import jax. from jax.interpreters import ad, batching, mlir, partial_eval, pxla, xla @@ -141,16 +161,6 @@ make_array_from_process_local_data as make_array_from_process_local_data, ) -from jax._src.tree_util import ( - tree_map as _deprecated_tree_map, - treedef_is_leaf as _deprecated_treedef_is_leaf, - tree_flatten as _deprecated_tree_flatten, - tree_leaves as _deprecated_tree_leaves, - tree_structure as _deprecated_tree_structure, - tree_transpose as _deprecated_tree_transpose, - tree_unflatten as _deprecated_tree_unflatten, -) - # These submodules are separate because they are in an import cycle with # jax and rely on the names imported above. from jax import custom_derivatives as custom_derivatives @@ -162,6 +172,7 @@ from jax import dlpack as dlpack from jax import dtypes as dtypes from jax import errors as errors +from jax import export as export from jax import ffi as ffi from jax import image as image from jax import lax as lax @@ -173,9 +184,9 @@ from jax import random as random from jax import scipy as scipy from jax import sharding as sharding +from jax import memory as memory from jax import stages as stages from jax import tree_util as tree_util -from jax import util as util # Also circular dependency. from jax._src.array import Shard as Shard @@ -184,59 +195,31 @@ del _ccache _deprecations = { - # Added July 2022 - "treedef_is_leaf": ( - "jax.treedef_is_leaf is deprecated: use jax.tree_util.treedef_is_leaf.", - _deprecated_treedef_is_leaf - ), - "tree_flatten": ( - "jax.tree_flatten is deprecated: use jax.tree.flatten (jax v0.4.25 or newer) " - "or jax.tree_util.tree_flatten (any JAX version).", - _deprecated_tree_flatten + # Remove in v0.10.0 + "array_ref": ( + "jax.array_ref was removed in JAX v0.9.0; use jax.new_ref instead.", + None, ), - "tree_leaves": ( - "jax.tree_leaves is deprecated: use jax.tree.leaves (jax v0.4.25 or newer) " - "or jax.tree_util.tree_leaves (any JAX version).", - _deprecated_tree_leaves - ), - "tree_structure": ( - "jax.tree_structure is deprecated: use jax.tree.structure (jax v0.4.25 or newer) " - "or jax.tree_util.tree_structure (any JAX version).", - _deprecated_tree_structure - ), - "tree_transpose": ( - "jax.tree_transpose is deprecated: use jax.tree.transpose (jax v0.4.25 or newer) " - "or jax.tree_util.tree_transpose (any JAX version).", - _deprecated_tree_transpose - ), - "tree_unflatten": ( - "jax.tree_unflatten is deprecated: use jax.tree.unflatten (jax v0.4.25 or newer) " - "or jax.tree_util.tree_unflatten (any JAX version).", - _deprecated_tree_unflatten + "ArrayRef": ( + "jax.ArrayRef was removed in JAX v0.9.0; use jax.Ref instead.", + None ), - # Added Feb 28, 2024 - "tree_map": ( - "jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) " - "or jax.tree_util.tree_map (any JAX version).", - _deprecated_tree_map + # Added for v0.8.1 + "device_put_replicated": ( + "jax.device_put_replicated is deprecated; use jax.device_put instead.", + _deprecated_device_put_replicated ), - # Finalized Nov 12 2024; remove after Feb 12 2025 - "clear_backends": ( - "jax.clear_backends was removed in JAX v0.4.36", - None + # Added for v0.8.1 + "device_put_sharded": ( + "jax.device_put_sharded is deprecated; use jax.device_put instead.", + _deprecated_device_put_sharded ), } import typing as _typing if _typing.TYPE_CHECKING: - from jax._src.tree_util import treedef_is_leaf as treedef_is_leaf - from jax._src.tree_util import tree_flatten as tree_flatten - from jax._src.tree_util import tree_leaves as tree_leaves - from jax._src.tree_util import tree_map as tree_map - from jax._src.tree_util import tree_structure as tree_structure - from jax._src.tree_util import tree_transpose as tree_transpose - from jax._src.tree_util import tree_unflatten as tree_unflatten - + device_put_replicated = _deprecated_device_put_replicated + device_put_sharded = _deprecated_device_put_sharded else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr __getattr__ = _deprecation_getattr(__name__, _deprecations) @@ -246,3 +229,5 @@ import jax.lib # TODO(phawkins): remove this export. # noqa: F401 # trailer +del _deprecated_device_put_sharded +del _deprecated_device_put_replicated diff --git a/jax/_src/BUILD b/jax/_src/BUILD new file mode 100644 index 000000000000..cef7cadf66a9 --- /dev/null +++ b/jax/_src/BUILD @@ -0,0 +1,1613 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load( + "//jaxlib:jax.bzl", + "if_building_jaxlib", + "jax_export_file_visibility", + "jax_internal_export_back_compat_test_util_visibility", + "jax_internal_test_harnesses_visibility", + "jax_test_util_visibility", + "jax_visibility", + "py_deps", + "py_library_providing_imports_info", + "pytype_strict_library", +) + +package( + default_applicable_licenses = [], + default_visibility = ["//jax:internal"], +) + +exports_files( + ["export/serialization.fbs"], + visibility = jax_export_file_visibility, +) + +pytype_strict_library( + name = "init", + srcs = [ + "__init__.py", + "interpreters/__init__.py", + ], + deps = [":traceback_util"], +) + +# JAX-private test utilities. +pytype_strict_library( + # This build target is required in order to use private test utilities in jax._src.test_util, + # and its visibility is intentionally restricted to discourage its use outside JAX itself. + # JAX does provide some public test utilities (see jax/test_util.py); + # these are available in jax.test_util via the standard :jax target. + name = "test_util", + srcs = [ + "test_loader.py", + "test_util.py", + "test_warning_util.py", + ], + visibility = [ + "//jax:internal", + ] + jax_test_util_visibility, + deps = [ + ":api", + ":cloud_tpu_init", + ":compilation_cache_internal", + ":config", + ":core", + ":deprecations", + ":dtypes", + ":lax", + ":mesh", + ":mlir", + ":monitoring", + ":numpy", + ":public_test_util", + ":sharding_impls", + ":tree_util", + ":typing", + ":util", + ":xla_bridge", + "//jax/_src/lib", + ] + py_deps("absl/testing") + py_deps("numpy"), +) + +# TODO(necula): break the internal_test_util into smaller build targets. +pytype_strict_library( + name = "internal_test_util", + srcs = [ + "internal_test_util/__init__.py", + "internal_test_util/deprecation_module.py", + "internal_test_util/lax_test_util.py", + ] + glob( + [ + "internal_test_util/lazy_loader_module/*.py", + ], + ), + visibility = ["//jax:internal"], + deps = if_building_jaxlib( + if_building = [ + ":api", + ":config", + ":core", + ":dtypes", + ":deprecations", + ":lax", + ":lazy_loader", + ":random", + ":test_util", + ":tree_util", + ":typing", + ":util", + ":xla_bridge", + ], + if_not_building = [], + ) + py_deps("numpy"), +) + +pytype_strict_library( + name = "internal_test_harnesses", + srcs = ["internal_test_util/test_harnesses.py"], + visibility = ["//jax:internal"] + jax_internal_test_harnesses_visibility, + deps = if_building_jaxlib( + if_building = [ + ":ad_util", + ":api", + ":config", + ":dtypes", + ":lax", + ":numpy", + ":random", + ":test_util", + ":typing", + ":xla_bridge", + "//jax/_src/lib", + ], + if_not_building = [], + ) + py_deps("numpy") + py_deps("absl/testing"), +) + +pytype_strict_library( + name = "test_multiprocess", + srcs = ["test_multiprocess.py"], + visibility = ["//jax:internal"], + deps = if_building_jaxlib( + if_building = [ + ":config", + ":test_util", + ":xla_bridge", + "//jax/_src/lib", + ], + if_not_building = [], + ) + py_deps("absl-all") + py_deps("portpicker"), +) + +pytype_strict_library( + name = "internal_export_back_compat_test_util", + srcs = ["internal_test_util/export_back_compat_test_util.py"], + visibility = [ + "//jax:internal", + ] + jax_internal_export_back_compat_test_util_visibility, + deps = if_building_jaxlib( + if_building = [ + ":api", + ":core", + ":stages", + ":export", + ":test_util", + ":tree_util", + ":typing", + ":xla_bridge", + ], + if_not_building = [], + ) + py_deps("numpy") + py_deps("absl/logging"), +) + +pytype_strict_library( + name = "internal_export_back_compat_test_data", + srcs = glob([ + "internal_test_util/export_back_compat_test_data/*.py", + "internal_test_util/export_back_compat_test_data/pallas/*.py", + ]), + visibility = [ + "//jax:internal", + ], + deps = py_deps("numpy"), +) + +pytype_strict_library( + name = "abstract_arrays", + srcs = ["abstract_arrays.py"], + deps = [ + ":ad_util", + ":config", + ":core", + ":dtypes", + ":literals", + ":traceback_util", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "ad_util", + srcs = ["ad_util.py"], + deps = [ + ":core", + ":traceback_util", + ":tree_util", + ":typing", + ":util", + ], +) + +pytype_strict_library( + name = "api", + srcs = [ + "api.py", + "array.py", + "dispatch.py", + "interpreters/pxla.py", + "pjit.py", + ], + visibility = ["//jax:internal"] + jax_visibility("api"), + deps = [ + ":abstract_arrays", + ":ad", + ":api_util", + ":basearray", + ":batching", + ":compiler", + ":config", + ":core", + ":deprecations", + ":dtypes", + ":effects", + ":jaxpr_util", + ":layout", + ":literals", + ":mesh", + ":mlir", + ":monitoring", + ":op_shardings", + ":partial_eval", + ":partition_spec", + ":profiler", + ":sharding", + ":sharding_impls", + ":sharding_specs", + ":source_info_util", + ":stages", + ":state_types", + ":traceback_util", + ":tree_util", + ":typing", + ":util", + ":xla_bridge", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "api_util", + srcs = ["api_util.py"], + deps = [ + ":abstract_arrays", + ":config", + ":core", + ":dtypes", + ":state_types", + ":traceback_util", + ":tree_util", + ":util", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "basearray", + srcs = ["basearray.py"], + pytype_srcs = ["basearray.pyi"], + deps = [ + ":literals", + ":named_sharding", + ":partition_spec", + ":sharding", + ":util", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "blocked_sampler", + srcs = ["blocked_sampler.py"], + deps = [ + ":numpy", + ":random", + ":typing", + ], +) + +pytype_strict_library( + name = "buffer_callback", + srcs = ["buffer_callback.py"], + deps = [ + ":ad", + ":api", + ":batching", + ":core", + ":effects", + ":ffi", + ":mlir", + ":tree_util", + ":util", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "memory", + srcs = ["memory.py"], +) + +pytype_strict_library( + name = "callback", + srcs = ["callback.py"], + deps = [ + ":ad", + ":api", + ":batching", + ":config", + ":core", + ":dtypes", + ":effects", + ":ffi", + ":mlir", + ":pickle_util", + ":sharding", + ":sharding_impls", + ":tree_util", + ":typing", + ":util", + ":xla_bridge", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "checkify", + srcs = ["checkify.py"], + visibility = ["//jax:internal"] + jax_visibility("checkify"), + deps = [ + ":ad", + ":ad_util", + ":api", + ":api_util", + ":batching", + ":callback", + ":config", + ":core", + ":custom_derivatives", + ":dtypes", + ":effects", + ":lax", + ":mesh", + ":mlir", + ":numpy", + ":partial_eval", + ":partition_spec", + ":shard_map", + ":sharding_impls", + ":source_info_util", + ":traceback_util", + ":tree_util", + ":typing", + ":util", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "cloud_tpu_init", + srcs = ["cloud_tpu_init.py"], + deps = [ + ":config", + ":hardware_utils", + ], +) + +pytype_strict_library( + name = "compilation_cache_internal", + srcs = ["compilation_cache.py"], + visibility = ["//jax:internal"] + jax_visibility("compilation_cache"), + deps = [ + ":cache_key", + ":compilation_cache_interface", + ":config", + ":lru_cache", + ":monitoring", + ":path", + "//jax/_src/lib", + ] + py_deps("numpy") + py_deps("zstandard"), +) + +pytype_strict_library( + name = "cache_key", + srcs = ["cache_key.py"], + visibility = ["//jax:internal"] + jax_visibility("compilation_cache"), + deps = [ + ":config", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "compilation_cache_interface", + srcs = ["compilation_cache_interface.py"], + deps = [":util"], +) + +py_library_providing_imports_info( + name = "lax", + srcs = glob( + [ + "lax/**/*.py", + "state/**/*.py", + ], + ) + [ + "ad_checkpoint.py", + ], + visibility = ["//jax:internal"] + jax_visibility("lax"), + deps = [ + ":abstract_arrays", + ":ad", + ":ad_util", + ":api", + ":api_util", + ":batching", + ":callback", + ":config", + ":core", + ":custom_derivatives", + ":custom_partitioning_sharding_rule", + ":dtypes", + ":effects", + ":ffi", + ":literals", + ":mesh", + ":mlir", + ":named_sharding", + ":partial_eval", + ":partition_spec", + ":pretty_printer", + ":sharding", + ":sharding_impls", + ":source_info_util", + ":state_types", + ":traceback_util", + ":tree_util", + ":typing", + ":util", + ":xla_bridge", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "lru_cache", + srcs = ["lru_cache.py"], + deps = [ + ":compilation_cache_interface", + ":path", + ] + py_deps("filelock"), +) + +pytype_strict_library( + name = "config", + srcs = ["config.py"], + deps = [ + ":deprecations", + ":logging_config", + "//jax/_src/lib", + ], +) + +pytype_strict_library( + name = "logging_config", + srcs = ["logging_config.py"], + deps = ["//jax/_src/lib"], +) + +pytype_strict_library( + name = "compiler", + srcs = ["compiler.py"], + visibility = ["//jax:internal"] + jax_visibility("compiler"), + deps = [ + ":cache_key", + ":compilation_cache_internal", + ":config", + ":mlir", + ":monitoring", + ":path", + ":profiler", + ":traceback_util", + ":util", + ":xla_bridge", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "core", + srcs = [ + "core.py", + "errors.py", + "linear_util.py", + ], + deps = [ + ":config", + ":dtypes", + ":effects", + ":layout", + ":memory", + ":mesh", + ":named_sharding", + ":partition_spec", + ":pretty_printer", + ":sharding", + ":source_info_util", + ":traceback_util", + ":tree_util", + ":typing", + ":util", + ":xla_metadata_lib", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "literals", + srcs = ["literals.py"], + deps = [ + "//jax/_src/lib", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "custom_api_util", + srcs = ["custom_api_util.py"], +) + +pytype_strict_library( + name = "custom_batching", + srcs = ["custom_batching.py"], + deps = [ + ":ad", + ":api", + ":api_util", + ":batching", + ":core", + ":custom_api_util", + ":mlir", + ":partial_eval", + ":source_info_util", + ":traceback_util", + ":tree_util", + ":util", + ], +) + +pytype_strict_library( + name = "custom_dce", + srcs = ["custom_dce.py"], + deps = [ + ":ad", + ":api_util", + ":batching", + ":core", + ":custom_api_util", + ":mlir", + ":partial_eval", + ":source_info_util", + ":traceback_util", + ":tree_util", + ":util", + ], +) + +pytype_strict_library( + name = "custom_derivatives", + srcs = ["custom_derivatives.py"], + deps = [ + ":ad", + ":ad_util", + ":api", + ":api_util", + ":batching", + ":config", + ":core", + ":custom_api_util", + ":custom_transpose", + ":dtypes", + ":effects", + ":mlir", + ":partial_eval", + ":state_types", + ":traceback_util", + ":tree_util", + ":util", + ], +) + +pytype_strict_library( + name = "custom_partitioning", + srcs = ["custom_partitioning.py"], + deps = [ + ":api", + ":api_util", + ":config", + ":core", + ":custom_api_util", + ":custom_partitioning_sharding_rule", + ":mesh", + ":mlir", + ":partial_eval", + ":sharding", + ":sharding_impls", + ":tree_util", + ":xla_bridge", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "custom_partitioning_sharding_rule", + srcs = ["custom_partitioning_sharding_rule.py"], + deps = [ + "//jax/_src/lib", + ], +) + +pytype_strict_library( + name = "custom_transpose", + srcs = ["custom_transpose.py"], + deps = [ + ":ad", + ":ad_util", + ":api", + ":api_util", + ":core", + ":custom_api_util", + ":mlir", + ":partial_eval", + ":source_info_util", + ":traceback_util", + ":tree_util", + ":util", + ], +) + +pytype_strict_library( + name = "debugger", + srcs = glob(["debugger/**/*.py"]), + deps = [ + ":callback", + ":core", + ":debugging", + ":lax", + ":traceback_util", + ":tree_util", + ":util", + ], +) + +pytype_strict_library( + name = "debugging", + srcs = [ + "debugging.py", + ], + deps = [ + ":ad", + ":api", + ":batching", + ":callback", + ":config", + ":core", + ":effects", + ":lax", + ":mesh", + ":mlir", + ":numpy", + ":partial_eval", + ":shard_map", + ":sharding", + ":sharding_impls", + ":source_info_util", + ":tree_util", + ":util", + ":xla_bridge", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "deprecations", + srcs = ["deprecations.py"], +) + +pytype_strict_library( + name = "dlpack", + srcs = ["dlpack.py"], + deps = [ + ":api", + ":deprecations", + ":dtypes", + ":lax", + ":numpy", + ":sharding", + ":typing", + ":xla_bridge", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "dtypes", + srcs = [ + "dtypes.py", + ], + deps = [ + ":config", + ":literals", + ":traceback_util", + ":typing", + ":util", + "//jax/_src/lib", + ] + py_deps("ml_dtypes") + py_deps("numpy"), +) + +pytype_strict_library( + name = "earray", + srcs = ["earray.py"], + deps = [ + ":api", + ":basearray", + ":core", + ":dtypes", + ":sharding_impls", + ":tree_util", + ":util", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "effects", + srcs = ["effects.py"], +) + +pytype_strict_library( + name = "environment_info", + srcs = ["environment_info.py"], + deps = [ + ":xla_bridge", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "error_check", + srcs = ["error_check.py"], + deps = [ + ":core", + ":export", + ":lax", + ":mesh", + ":shard_map", + ":sharding_impls", + ":source_info_util", + ":traceback_util", + ":tree_util", + ":typing", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "export", + srcs = glob([ + "export/**/*.py", + ]), + visibility = ["//jax:internal"] + jax_visibility("export"), + deps = [ + ":ad_util", + ":api", + ":config", + ":core", + ":custom_derivatives", + ":dtypes", + ":effects", + ":mesh", + ":mlir", + ":named_sharding", + ":partition_spec", + ":sharding", + ":sharding_impls", + ":source_info_util", + ":stages", + ":traceback_util", + ":tree_util", + ":typing", + ":util", + ":xla_bridge", + "//jax/_src/lib", + ] + py_deps("flatbuffers") + py_deps("numpy") + py_deps("opt_einsum"), +) + +pytype_strict_library( + name = "ffi", + srcs = ["ffi.py"], + deps = [ + ":ad", + ":api", + ":batching", + ":core", + ":effects", + ":frozen_dict", + ":hashable_array", + ":layout", + ":mlir", + ":typing", + ":util", + ":xla_bridge", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "flatten_util", + srcs = [ + "flatten_util.py", + ], + deps = [ + ":dtypes", + ":lax", + ":tree_util", + ":typing", + ":util", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "frozen_dict", + srcs = ["frozen_dict.py"], +) + +pytype_strict_library( + name = "hardware_utils", + srcs = ["hardware_utils.py"], +) + +pytype_strict_library( + name = "hashable_array", + srcs = ["hashable_array.py"], + deps = py_deps("numpy"), +) + +pytype_strict_library( + name = "image", + srcs = glob([ + "image/**/*.py", + ]), + visibility = ["//jax:internal"] + jax_visibility("image"), # buildifier: disable=visibility-as-string-list + deps = [ + ":api", + ":core", + ":dtypes", + ":lax", + ":numpy", + ":util", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "lax_reference", + srcs = ["lax_reference.py"], + visibility = ["//jax:internal"] + jax_visibility("lax_reference"), + deps = [ + ":core", + ":dtypes", + ":util", + ] + py_deps("numpy") + py_deps("scipy") + py_deps("opt_einsum"), +) + +pytype_strict_library( + name = "lazy_loader", + srcs = ["lazy_loader.py"], +) + +pytype_strict_library( + name = "jaxpr_util", + srcs = ["jaxpr_util.py"], + deps = [ + ":config", + ":core", + ":path", + ":source_info_util", + ":util", + "//jax/_src/lib", + ], +) + +pytype_strict_library( + name = "mesh", + srcs = ["mesh.py"], + deps = [ + ":config", + ":util", + ":xla_bridge", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "ad", + srcs = ["interpreters/ad.py"], + deps = [ + ":ad_util", + ":api_util", + ":config", + ":core", + ":dtypes", + ":mesh", + ":partial_eval", + ":source_info_util", + ":state_types", + ":tree_util", + ":util", + ], +) + +pytype_strict_library( + name = "batching", + srcs = ["interpreters/batching.py"], + deps = [ + ":ad_util", + ":config", + ":core", + ":mesh", + ":partial_eval", + ":partition_spec", + ":sharding_impls", + ":source_info_util", + ":tree_util", + ":typing", + ":util", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "mlir", + srcs = ["interpreters/mlir.py"], + deps = [ + ":ad_util", + ":api_util", + ":config", + ":core", + ":dtypes", + ":effects", + ":frozen_dict", + ":hashable_array", + ":jaxpr_util", + ":layout", + ":literals", + ":mesh", + ":op_shardings", + ":partial_eval", + ":partition_spec", + ":path", + ":pickle_util", + ":sharding", + ":sharding_impls", + ":source_info_util", + ":state_types", + ":typing", + ":util", + ":xla_bridge", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "monitoring", + srcs = ["monitoring.py"], +) + +pytype_strict_library( + name = "op_shardings", + srcs = ["op_shardings.py"], + deps = [ + "//jax/_src/lib", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "scipy", + srcs = glob([ + "scipy/**/*.py", + "third_party/**/*.py", + ]), + deps = [ + ":api", + ":api_util", + ":config", + ":core", + ":custom_derivatives", + ":deprecations", + ":dtypes", + ":lax", + ":nn", + ":numpy", + ":random", + ":tpu", + ":tree_util", + ":typing", + ":util", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "sourcemap", + srcs = ["sourcemap.py"], +) + +pytype_strict_library( + name = "partial_eval", + srcs = ["interpreters/partial_eval.py"], + visibility = ["//jax:internal"] + jax_visibility("partial_eval"), + deps = [ + ":ad_util", + ":api_util", + ":config", + ":core", + ":dtypes", + ":effects", + ":profiler", + ":source_info_util", + ":state_types", + ":tree_util", + ":util", + ":xla_metadata_lib", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "partition_spec", + srcs = ["partition_spec.py"], + deps = [ + ":util", + "//jax/_src/lib", + ], +) + +pytype_strict_library( + name = "path", + srcs = ["path.py"], + deps = py_deps("epath"), +) + +pytype_strict_library( + name = "pickle_util", + srcs = ["pickle_util.py"], + deps = [":profiler"] + py_deps("cloudpickle"), +) + +pytype_strict_library( + name = "pretty_printer", + srcs = ["pretty_printer.py"], + visibility = ["//jax:internal"] + jax_visibility("pretty_printer"), + deps = [ + ":config", + "//jax/_src/lib", + ], +) + +pytype_strict_library( + name = "profiler", + srcs = ["profiler.py"], + deps = [ + ":traceback_util", + ":xla_bridge", + "//jax/_src/lib", + ], +) + +pytype_strict_library( + name = "pmap", + srcs = ["pmap.py"], + deps = [ + ":api", + ":core", + ":lax", + ":mesh", + ":shard_map", + ":stages", + ":traceback_util", + ":tree_util", + ":util", + ":xla_bridge", + ], +) + +pytype_strict_library( + name = "public_test_util", + srcs = [ + "public_test_util.py", + ], + deps = [ + ":api", + ":config", + ":dtypes", + ":tree_util", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "sharding", + srcs = ["sharding.py"], + deps = [ + ":op_shardings", + ":util", + ":xla_bridge", + "//jax/_src/lib", + ], +) + +pytype_strict_library( + name = "shard_alike", + srcs = [ + "shard_alike.py", + ], + deps = [ + ":ad", + ":api", + ":batching", + ":config", + ":core", + ":mlir", + ":tree_util", + ":util", + "//jax/_src/lib", + ], +) + +pytype_strict_library( + name = "shard_map", + srcs = ["shard_map.py"], + deps = [ + ":ad", + ":ad_util", + ":api", + ":api_util", + ":batching", + ":config", + ":core", + ":dtypes", + ":effects", + ":lax", + ":layout", + ":mesh", + ":mlir", + ":partial_eval", + ":sharding", + ":sharding_impls", + ":source_info_util", + ":stages", + ":traceback_util", + ":tree_util", + ":util", + ":xla_bridge", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "hijax", + srcs = ["hijax.py"], + deps = [ + ":ad", + ":ad_util", + ":api", + ":api_util", + ":batching", + ":config", + ":core", + ":custom_derivatives", + ":dtypes", + ":effects", + ":partial_eval", + ":state_types", + ":tree_util", + ":util", + ], +) + +pytype_strict_library( + name = "stages", + srcs = ["stages.py"], + visibility = ["//jax:internal"] + jax_visibility("stages"), + deps = [ + ":config", + ":core", + ":layout", + ":mlir", + ":sharding", + ":sharding_impls", + ":source_info_util", + ":traceback_util", + ":tree_util", + ":typing", + ":util", + "//jax/_src/lib", + ], +) + +pytype_strict_library( + name = "compute_on", + srcs = ["compute_on.py"], + deps = [ + ":ad", + ":api", + ":api_util", + ":batching", + ":config", + ":core", + ":mlir", + ":partial_eval", + ":tree_util", + ":util", + "//jax/_src/lib", + ], +) + +pytype_strict_library( + name = "xla_metadata", + srcs = ["xla_metadata.py"], + deps = [ + ":ad", + ":api", + ":batching", + ":config", + ":core", + ":mlir", + ":tree_util", + ":xla_metadata_lib", + "//jax/_src/lib", + ], +) + +pytype_strict_library( + name = "xla_metadata_lib", + srcs = ["xla_metadata_lib.py"], + deps = [ + ":config", + "//jax/_src/lib", + ], +) + +pytype_strict_library( + name = "layout", + srcs = ["layout.py"], + deps = [ + ":dtypes", + ":named_sharding", + ":sharding", + ":util", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "sharding_impls", + srcs = ["sharding_impls.py"], + visibility = ["//jax:internal"], + deps = [ + ":config", + ":core", + ":deprecations", + ":internal_mesh_utils", + ":mesh", + ":named_sharding", + ":op_shardings", + ":partition_spec", + ":sharding", + ":sharding_specs", + ":source_info_util", + ":tree_util", + ":util", + ":xla_bridge", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "named_sharding", + srcs = ["named_sharding.py"], + deps = [ + ":config", + ":mesh", + ":partition_spec", + ":sharding", + ":util", + ":xla_bridge", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "nn", + srcs = glob([ + "nn/**/*.py", + ]), + deps = [ + ":api", + ":config", + ":core", + ":cudnn", + ":custom_derivatives", + ":deprecations", + ":dtypes", + ":lax", + ":named_sharding", + ":numpy", + ":partition_spec", + ":random", + ":sharding_impls", + ":typing", + ":util", + ] + py_deps("numpy"), +) + +py_library_providing_imports_info( + name = "numpy", + srcs = glob([ + "numpy/**/*.py", + "ops/**/*.py", + ]), + deps = [ + ":api", + ":api_util", + ":config", + ":core", + ":custom_derivatives", + ":deprecations", + ":dtypes", + ":error_check", + ":export", + ":lax", + ":literals", + ":mesh", + ":sharding", + ":sharding_impls", + ":tree_util", + ":typing", + ":util", + ":xla_bridge", + "//jax/_src/lib", + ] + py_deps("numpy") + py_deps("opt_einsum"), +) + +pytype_strict_library( + name = "sharding_specs", + srcs = ["sharding_specs.py"], + deps = [ + ":op_shardings", + ":util", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "internal_mesh_utils", + srcs = ["mesh_utils.py"], + deps = [ + ":xla_bridge", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "source_info_util", + srcs = ["source_info_util.py"], + visibility = ["//jax:internal"] + jax_visibility("source_info_util"), + deps = [ + ":traceback_util", + "//jax/_src/lib", + ], +) + +pytype_strict_library( + name = "state_types", + srcs = [ + "state/__init__.py", + "state/indexing.py", + "state/types.py", + ], + visibility = ["//jax:internal"] + jax_visibility("state_types"), + deps = [ + ":core", + ":dtypes", + ":effects", + ":pretty_printer", + ":traceback_util", + ":tree_util", + ":typing", + ":util", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "tpu", + srcs = glob([ + "tpu/**/*.py", + ]), + deps = [ + ":api", + ":config", + ":core", + ":dtypes", + ":lax", + ":mlir", + ":numpy", + ":traceback_util", + ":tree_util", + ":typing", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "tree", + srcs = ["tree.py"], + deps = [":tree_util"], +) + +pytype_strict_library( + name = "tree_util", + srcs = ["tree_util.py"], + visibility = ["//jax:internal"] + jax_visibility("tree_util"), + deps = [ + ":traceback_util", + ":util", + "//jax/_src/lib", + ], +) + +pytype_strict_library( + name = "traceback_util", + srcs = ["traceback_util.py"], + visibility = ["//jax:internal"] + jax_visibility("traceback_util"), + deps = [ + ":config", + ":util", + "//jax/_src/lib", + ], +) + +pytype_strict_library( + name = "typing", + srcs = [ + "typing.py", + ], + deps = [":basearray"] + py_deps("numpy"), +) + +pytype_strict_library( + name = "tpu_custom_call", + srcs = ["tpu_custom_call.py"], + visibility = ["//jax:internal"], + deps = [ + ":api", + ":batching", + ":cloud_tpu_init", + ":config", + ":core", + ":frozen_dict", + ":mlir", + ":sharding_impls", + "//jax/_src/lib", + "//jax/_src/pallas", + ] + if_building_jaxlib([ + "//jaxlib/mlir:ir", + "//jaxlib/mlir:mhlo_dialect", + "//jaxlib/mlir:pass_manager", + "//jaxlib/mlir:stablehlo_dialect", + ]) + py_deps("numpy") + py_deps("absl/flags"), +) + +pytype_strict_library( + name = "util", + srcs = ["util.py"], + deps = [ + ":config", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + +# TODO(phawkins): break up this SCC. +pytype_strict_library( + name = "xla_bridge", + srcs = [ + "clusters/__init__.py", + "clusters/cloud_tpu_cluster.py", + "clusters/cluster.py", + "clusters/k8s_cluster.py", + "clusters/mpi4py_cluster.py", + "clusters/ompi_cluster.py", + "clusters/slurm_cluster.py", + "distributed.py", + "xla_bridge.py", + ], + visibility = ["//jax:internal"] + jax_visibility("xla_bridge"), + deps = [ + ":cloud_tpu_init", + ":config", + ":hardware_utils", + ":traceback_util", + ":util", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "cudnn", + srcs = glob(["cudnn/**/*.py"]), + deps = [ + ":api", + ":batching", + ":core", + ":custom_derivatives", + ":custom_partitioning", + ":custom_partitioning_sharding_rule", + ":dtypes", + ":lax", + ":mlir", + ":numpy", + ":sharding_impls", + ":tree_util", + ":typing", + ":xla_bridge", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "extend_src", + srcs = glob(include = ["extend/**/*.py"]), + deps = [ + ":random", + ":typing", + ], +) + +pytype_strict_library( + name = "random", + srcs = [ + "prng.py", + "random.py", + ], + visibility = ["//jax:internal"] + jax_visibility("random"), + deps = [ + ":ad", + ":api", + ":batching", + ":config", + ":core", + ":dtypes", + ":ffi", + ":lax", + ":literals", + ":mesh", + ":mlir", + ":numpy", + ":pretty_printer", + ":sharding_impls", + ":source_info_util", + ":tree_util", + ":typing", + ":util", + ":xla_bridge", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "ref", + srcs = ["ref.py"], + deps = [":core"], +) + +pytype_strict_library( + name = "stateful_rng", + srcs = [ + "stateful_rng.py", + ], + deps = [ + ":api_util", + ":core", + ":dtypes", + ":lax", + ":numpy", + ":random", + ":ref", + ":tree_util", + ":typing", + ] + py_deps("numpy"), +) diff --git a/jax/_src/abstract_arrays.py b/jax/_src/abstract_arrays.py index 83c431499d63..8b3c0b37c5cd 100644 --- a/jax/_src/abstract_arrays.py +++ b/jax/_src/abstract_arrays.py @@ -14,11 +14,11 @@ from __future__ import annotations -from functools import partial - import numpy as np +from jax._src import config from jax._src import core +from jax._src import literals from jax._src import dtypes from jax._src import traceback_util @@ -41,7 +41,7 @@ numpy_scalar_types.add(dtypes.int2) numpy_scalar_types.add(dtypes.uint2) -array_types: set[type] = {np.ndarray} | numpy_scalar_types # pylint: disable=g-bare-generic +array_types: set[type] = {literals.TypedNdArray, np.ndarray} | numpy_scalar_types # pylint: disable=g-bare-generic def masked_array_error(*args, **kwargs): @@ -60,6 +60,17 @@ def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray: core.pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array +def _make_shaped_array_for_typed_ndarray( + x: literals.TypedNdArray, +) -> ShapedArray: + dtype = x.dtype + dtypes.check_valid_dtype(dtype) + return ShapedArray(x.shape, dtype, sharding=None, weak_type=x.weak_type) + + +core.pytype_aval_mappings[literals.TypedNdArray] = _make_shaped_array_for_typed_ndarray + + def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray: dtype = np.dtype(x) dtypes.check_valid_dtype(dtype) @@ -72,13 +83,83 @@ def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray: core.literalable_types.update(array_types) -def _make_abstract_python_scalar(typ, val): - # Note: all python scalar types are weak except bool, because bool only - # comes in a single width. - return ShapedArray((), dtypes._scalar_type_to_dtype(typ, val), - weak_type=typ is not bool, sharding=None) +core.literalable_types.add(literals.TypedNdArray) + +_int32_min = np.iinfo(np.int32).min +_int32_max = np.iinfo(np.int32).max +_int64_min = np.iinfo(np.int64).min +_int64_max = np.iinfo(np.int64).max + +# Note: all python scalar types are weak except bool, because bool only +# comes in a single width. +_bool_aval = ShapedArray((), dtype=np.dtype(bool)) +_int32_aval = ShapedArray((), dtype=np.dtype(np.int32), weak_type=True) +_int64_aval = ShapedArray((), dtype=np.dtype(np.int64), weak_type=True) +_float32_aval = ShapedArray((), dtype=np.dtype(np.float32), weak_type=True) +_float64_aval = ShapedArray((), dtype=np.dtype(np.float64), weak_type=True) +_complex64_aval = ShapedArray((), dtype=np.dtype(np.complex64), weak_type=True) +_complex128_aval = ShapedArray((), dtype=np.dtype(np.complex128), weak_type=True) + +core.pytype_aval_mappings[bool] = lambda v: _bool_aval + +def _int_aval(value): + if config.enable_x64.value: + if value < _int64_min or value > _int64_max: + raise OverflowError(f"Python int {value} too large to convert to int64") + return _int64_aval + else: + if value < _int32_min or value > _int32_max: + raise OverflowError(f"Python int {value} too large to convert to int32") + return _int32_aval +core.pytype_aval_mappings[int] = _int_aval + +_float_aval = lambda v: _float64_aval if config.enable_x64.value else _float32_aval +core.pytype_aval_mappings[float] = _float_aval + +_complex_aval = lambda v: _complex128_aval if config.enable_x64.value else _complex64_aval +core.pytype_aval_mappings[complex] = _complex_aval + +core.literalable_types.update(dtypes.python_scalar_types) + + +def _aval_for_typed_scalar(x): + return ShapedArray((), x.dtype, weak_type=True, sharding=None) + +for t in literals.typed_scalar_types: + core.pytype_aval_mappings[t] = _aval_for_typed_scalar +core.literalable_types.update(literals.typed_scalar_types) + + +def _canonicalize_ndarray_dtype(x): + dtype = dtypes.canonicalize_dtype(x.dtype) + return literals.TypedNdArray(np.asarray(x, dtype), weak_type=False) + +def _canonicalize_masked_array_dtype(x): + raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. " + "Use arr.filled() to convert the value to a standard numpy array.") + +dtypes.canonicalize_value_handlers.update( + (t, _canonicalize_ndarray_dtype) for t in numpy_scalar_types) + + +dtypes.canonicalize_value_handlers[literals.TypedNdArray] = lambda x: x + +dtypes.canonicalize_value_handlers[np.ndarray] = _canonicalize_ndarray_dtype +dtypes.canonicalize_value_handlers[np.ma.MaskedArray] = _canonicalize_masked_array_dtype + +def _canonicalize_python_scalar(literal_type, typ): + def canonicalize_scalar(x): + return literal_type(x, dtypes.scalar_type_to_dtype(typ, x)) # pytype: disable=wrong-arg-types + return canonicalize_scalar -for t in dtypes.python_scalar_dtypes: - core.pytype_aval_mappings[t] = partial(_make_abstract_python_scalar, t) +dtypes.canonicalize_value_handlers[bool] = lambda x: x +dtypes.canonicalize_value_handlers[int] = _canonicalize_python_scalar( + literals.TypedInt, int) +dtypes.canonicalize_value_handlers[float] = _canonicalize_python_scalar( + literals.TypedFloat, float) +dtypes.canonicalize_value_handlers[complex] = _canonicalize_python_scalar( + literals.TypedComplex, complex) -core.literalable_types.update(dtypes.python_scalar_dtypes.keys()) +dtypes.canonicalize_value_handlers[literals.TypedInt] = lambda x: x +dtypes.canonicalize_value_handlers[literals.TypedFloat] = lambda x: x +dtypes.canonicalize_value_handlers[literals.TypedComplex] = lambda x: x diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index c2868cf7c078..62002cd8eb51 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -15,7 +15,6 @@ from __future__ import annotations from collections.abc import Callable, Sequence -import functools from functools import partial import logging from typing import Any @@ -27,6 +26,7 @@ from jax._src import api from jax._src import config from jax._src import core +from jax._src import deprecations from jax._src import dtypes from jax._src import linear_util as lu from jax._src import effects @@ -40,8 +40,13 @@ from jax._src.lax import lax as lax_internal from jax._src.lax import convolution as lax_convolution from jax._src.lib.mlir.dialects import hlo +from jax._src.state import discharge +from jax._src.state.types import AbstractRef from jax._src.traceback_util import api_boundary -from jax._src.tree_util import PyTreeDef, tree_flatten, tree_unflatten, tree_structure +from jax._src.tree_util import ( + PyTreeDef, tree_flatten, tree_unflatten, tree_structure, broadcast_prefix, + tree_map) +from jax._src.typing import DeprecatedArg from jax._src.util import (unzip2, wraps, split_list, partition_list, safe_map, safe_zip, merge_lists, weakref_lru_cache) @@ -56,30 +61,48 @@ ### Policies def everything_saveable(*_, **__) -> bool: - # This is the effective policy without any use of jax.remat. + """The default strategy, as if ``jax.checkpoint`` were not being used at all. + + This is the effective policy without any use of jax.remat.""" return True def nothing_saveable(*_, **__) -> bool: - # This is the effective policy when using jax.remat without explicit policy. + """Rematerialize everything, as if a custom policy were not being used at all. + + This is the effective policy when using jax.remat without explicit policy.""" return False def dots_saveable(prim, *_, **__) -> bool: # Matrix multiplies are expensive, so let's save them (and nothing else). - return prim in {lax_internal.dot_general_p, - lax_convolution.conv_general_dilated_p} + # We check for scaled matmul by name because we want to avoid importing + # cudnn here. + return (prim in {lax_internal.dot_general_p, + lax_convolution.conv_general_dilated_p} or + prim.name == "scaled_matmul_wrapper") checkpoint_dots = dots_saveable -def dots_with_no_batch_dims_saveable(prim, *_, **params) -> bool: - # This is a useful heuristic for transformers. +def dots_with_no_batch_dims_saveable(prim, *args, **params) -> bool: + """This is a useful heuristic for transformers.""" if prim is lax_internal.dot_general_p: (_, _), (lhs_b, rhs_b) = params['dimension_numbers'] if not lhs_b and not rhs_b: return True + + # We check for scaled matmul by name because we want to avoid importing + # cudnn here. + if prim.name == "scaled_matmul_wrapper": + lhs = args[0] + # Only save the dot if its batch dim is of size 1. + return lhs.shape[0] == 1 + return False def offload_dot_with_no_batch_dims(offload_src, offload_dst): + """Same as ``dots_with_no_batch_dims_saveable``, but offload to CPU memory + instead of recomputing. + + This is a useful heuristic for transformers.""" def policy(prim, *_, **params): - # This is a useful heuristic for transformers. if prim is lax_internal.dot_general_p: (_, _), (lhs_b, rhs_b) = params['dimension_numbers'] if not lhs_b and not rhs_b: @@ -100,7 +123,8 @@ def policy(prim, *_, **params): return policy def save_any_names_but_these(*names_not_to_save): - """Save only named values, excluding the names given.""" + """Save only named values, i.e. any outputs of `checkpoint_name`, excluding + the names given.""" names_not_to_save = frozenset(names_not_to_save) def policy(prim, *_, **params): if prim is name_p: @@ -120,6 +144,8 @@ def policy(prim, *_, **params): def save_and_offload_only_these_names( *, names_which_can_be_saved, names_which_can_be_offloaded, offload_src, offload_dst): + """Same as ``save_only_these_names``, but offload to CPU memory instead of + recomputing.""" names_which_can_be_saved = set(names_which_can_be_saved) names_which_can_be_offloaded = set(names_which_can_be_offloaded) intersection = names_which_can_be_saved.intersection(names_which_can_be_offloaded) @@ -140,7 +166,9 @@ def policy(prim, *_, **params): def save_from_both_policies(policy_1, policy_2): + """Logical OR of the given policies. + A residual is saveable iff it is saveable according to either policy.""" def policy(prim, *args, **params): out1 = policy_1(prim, *args, **params) out2 = policy_2(prim, *args, **params) @@ -172,11 +200,11 @@ def policy(prim, *args, **params): ### Main API -@api_boundary +@partial(api_boundary, repro_api_name="jax.checkpoint") def checkpoint(fun: Callable, *, prevent_cse: bool = True, policy: Callable[..., bool] | None = None, static_argnums: int | tuple[int, ...] = (), - ) -> Callable: + concrete: bool | DeprecatedArg = DeprecatedArg()) -> Callable: """Make ``fun`` recompute internal linearization points when differentiated. The :func:`jax.checkpoint` decorator, aliased to :func:`jax.remat`, provides a @@ -230,6 +258,8 @@ def checkpoint(fun: Callable, *, prevent_cse: bool = True, returns a boolean indicating whether the corresponding output value(s) can be saved as residuals (or instead must be recomputed in the (co)tangent computation if needed). + concrete: Optional boolean; deprecated. Will raise a DeprecationWarning if + used, and passing True will result in a NotImplementedError. Returns: A function (callable) with the same input/output behavior as ``fun`` but @@ -317,8 +347,23 @@ def foo(x, y): ``jax.ensure_compile_time_eval``), it may be easier to compute some values outside the :func:`jax.checkpoint`-decorated function and then close over them. """ + if not isinstance(concrete, DeprecatedArg): + concrete_msg = ( + "The `concrete` option to `jax.checkpoint` has been deprecated." + " In its place please use `static_argnums`; for details refer to" + " https://docs.jax.dev/en/latest/jep/11830-new-remat-checkpoint.html." + ) + deprecations.warn("jax-checkpoint-concrete", concrete_msg, stacklevel=2) + if concrete: + raise NotImplementedError(concrete_msg) + if isinstance(static_argnums, int): static_argnums = static_argnums, + if isinstance(prevent_cse, list): + prevent_cse = tuple(prevent_cse) + if not isinstance(prevent_cse, (tuple, bool)): + raise TypeError("prevent_cse must be a bool or tuple of bools, got " + f"{type(prevent_cse)=}") @wraps(fun) @api_boundary @@ -330,13 +375,25 @@ def fun_remat(*args, **kwargs): args_flat, in_tree = tree_flatten((args, kwargs)) in_avals = [core.shaped_abstractify(x) for x in args_flat] jaxpr, consts, out_tree = _trace_to_jaxpr(fun_, in_tree, tuple(in_avals), debug) + if isinstance(prevent_cse, tuple): + cse_args = (tuple(args), kwargs) if kwargs else tuple(args) + cse = (False,) * len(consts) + tuple(broadcast_prefix(prevent_cse, cse_args)) + else: + cse = prevent_cse out_flat = remat_p.bind( - *consts, *args_flat, jaxpr=jaxpr, prevent_cse=prevent_cse, - differentiated=False, policy=policy) + *consts, *args_flat, jaxpr=jaxpr, prevent_cse=cse, differentiated=False, + policy=policy) return tree_unflatten(out_tree, out_flat) return fun_remat -remat = checkpoint # alias + +def remat(fun: Callable, *, prevent_cse: bool = True, + policy: Callable[..., bool] | None = None, + static_argnums: int | tuple[int, ...] = (), + concrete: bool | DeprecatedArg = DeprecatedArg()) -> Callable: + """Alias of :func:`jax.checkpoint`.""" + return checkpoint(fun, prevent_cse=prevent_cse, policy=policy, + static_argnums=static_argnums, concrete=concrete) # This function is similar to api_util.argnums_partial, except the error # messages are specific to jax.remat (and thus more actionable), the @@ -422,7 +479,7 @@ def _trace_to_jaxpr(fun: Callable, ) -> tuple[core.Jaxpr, Sequence[Any], PyTreeDef]: flat_fun, out_tree = api_util.flatten_fun(lu.wrap_init(fun, debug_info=debug), in_tree) try: - jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals) + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals) except core.ConcretizationTypeError as e: msg, = e.args if 'for checkpoint' in msg: @@ -430,7 +487,7 @@ def _trace_to_jaxpr(fun: Callable, "Consider using the `static_argnums` parameter for `jax.remat` or " "`jax.checkpoint`. See the `jax.checkpoint` docstring and its example " "involving `static_argnums`:\n" - "https://jax.readthedocs.io/en/latest/_autosummary/jax.checkpoint.html" + "https://docs.jax.dev/en/latest/_autosummary/jax.checkpoint.html" "\n") e.args = msg, raise @@ -448,14 +505,18 @@ def f_(*args): return f(*args, **kwargs) debug_info = api_util.debug_info("saved_residuals", f, args, kwargs) - out = api.make_jaxpr(lambda *args: api.linearize(f_, *args)[1], + out = api.make_jaxpr(lambda *args: api.linearize(f_, *args), return_shape=True)(*in_leaves) assert isinstance(out, tuple) - jaxpr_, out_shape = out + jaxpr_, out_shape_ = out jaxpr = jaxpr_.jaxpr - out_tree = lambda: tree_structure(out_shape) + out_shape = out_shape_[1] + num_res = tree_structure(out_shape).num_leaves + jaxpr = jaxpr.replace( + outvars=jaxpr.outvars[len(jaxpr.outvars) - num_res:], + debug_info=debug_info._replace(result_paths=None)) assert len(jaxpr.invars) == len(in_leaves) - return _saved_residuals(jaxpr, debug_info.arg_names) + return _saved_residuals(jaxpr, debug_info.arg_names or ("unknown",) * len(jaxpr.invars)) def _saved_residuals(jaxpr: core.Jaxpr, arg_names: Sequence[str]) -> list[tuple[core.AbstractValue, str]]: @@ -488,7 +549,7 @@ def _saved_residuals(jaxpr: core.Jaxpr, if v in res_vars: if eqn.primitive is name_p or v in named_vars and (eqn := named_vars[v]): results.append((v.aval, f"named '{eqn.params['name']}' from {src}")) - elif str(eqn.primitive) == 'pjit': + elif eqn.primitive.name == 'jit': results.append((v.aval, f"output of jitted function '{eqn.params['name']}' " f"from {src}")) @@ -508,6 +569,12 @@ def print_saved_residuals(f, *args, **kwargs): remat_p = core.Primitive('remat2') remat_p.multiple_results = True +def _remat_bind(*args, jaxpr, prevent_cse, differentiated, policy): + assert isinstance(prevent_cse, bool) or len(prevent_cse) == len(args) + return core.Primitive.bind(remat_p, *args, jaxpr=jaxpr, prevent_cse=prevent_cse, + differentiated=differentiated, policy=policy) +remat_p.bind = _remat_bind # type: ignore + @remat_p.def_impl def remat_impl(*args, jaxpr, prevent_cse, differentiated, policy): del prevent_cse, differentiated, policy # Unused. @@ -516,7 +583,7 @@ def remat_impl(*args, jaxpr, prevent_cse, differentiated, policy): @remat_p.def_effectful_abstract_eval def remat_abstract_eval(*args, jaxpr, prevent_cse, differentiated, policy): del args, prevent_cse, differentiated, policy # Unused. - return [v.aval for v in jaxpr.outvars], jaxpr.effects + return [v.aval for v in jaxpr.outvars], core.eqn_effects(jaxpr) def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy): assert not jaxpr.constvars @@ -524,6 +591,8 @@ def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy): jaxpr_jvp_, out_nz = ad.jvp_jaxpr(pe.close_jaxpr(jaxpr), in_nonzeros, False) nonzero_tangents = [t for t in tangents if type(t) is not ad_util.Zero] jaxpr_jvp = pe.convert_constvars_jaxpr(jaxpr_jvp_.jaxpr) + if isinstance(prevent_cse, tuple): + prevent_cse += (True,) * len(nonzero_tangents) outs = remat_p.bind( *jaxpr_jvp_.consts, *primals, *nonzero_tangents, jaxpr=jaxpr_jvp, prevent_cse=prevent_cse, differentiated=differentiated, policy=policy) @@ -534,10 +603,8 @@ def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy): return out_primals, out_tangents ad.primitive_jvps[remat_p] = remat_jvp -effects.remat_allowed_effects.add_type(lax_internal.InOutFeedEffect) - def remat_partial_eval(trace: pe.JaxprTrace, *tracers: core.Tracer, - jaxpr: core.Jaxpr, **params): + jaxpr: core.Jaxpr, prevent_cse, **params): assert not jaxpr.constvars disallowed_effects = effects.remat_allowed_effects.filter_not_in(jaxpr.effects) if disallowed_effects: @@ -565,7 +632,7 @@ def remat_partial_eval(trace: pe.JaxprTrace, *tracers: core.Tracer, # on producers of any residuals. See https://github.com/jax-ml/jax/pull/22244. jaxpr_known_ = _insert_reduce_precision(jaxpr_known, num_res) - # compute known outputs and residuals (hoisted out of remat primitive) + # Compute known outputs and residuals (hoisted out of remat primitive) _, in_consts_ = unzip2(t.pval for t in tracers if t.pval.is_known()) _, in_consts = partition_list(in_used_known, in_consts_) out_consts = core.eval_jaxpr(jaxpr_known_, (), *in_consts) @@ -577,9 +644,13 @@ def remat_partial_eval(trace: pe.JaxprTrace, *tracers: core.Tracer, in_jaxpr_tracers = res_tracers + map(trace.instantiate_const, tracers_staged) # type: ignore out_jaxpr_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(x.aval), None) for x in jaxpr_unknown.outvars] - new_params = dict(params, jaxpr=jaxpr_unknown, differentiated=True) - recipe = pe.new_eqn_recipe(in_jaxpr_tracers, out_jaxpr_tracers, remat_p, - new_params, jaxpr_unknown.effects, + if isinstance(prevent_cse, tuple): + _, prevent_cse_ = partition_list(in_used_staged, prevent_cse) + prevent_cse = (True,) * len(res_tracers) + tuple(prevent_cse_) + new_params = dict(params, jaxpr=jaxpr_unknown, differentiated=True, + prevent_cse=prevent_cse) + recipe = pe.new_eqn_recipe(trace, in_jaxpr_tracers, out_jaxpr_tracers, remat_p, + new_params, core.eqn_effects(jaxpr_unknown), source_info_util.current()) # log info about saved residuals @@ -615,13 +686,13 @@ def _insert_reduce_precision(jaxpr: core.Jaxpr, num_res: int) -> core.Jaxpr: used_vars = {x for e in jaxpr.eqns for x in e.invars if isinstance(x, core.Var)} invars, constvars, eqns = jaxpr.invars[:], jaxpr.constvars[:], jaxpr.eqns[:] for v in res_vars: - if (not isinstance(v.aval, core.UnshapedArray) or + if (not isinstance(v.aval, core.ShapedArray) or not dtypes.issubdtype(v.aval.dtype, np.inexact)): continue if v not in used_vars: continue assert isinstance(v, core.Var) - newvar = core.Var(v.suffix, v.aval) + newvar = core.Var(v.aval) finfo = dtypes.finfo(v.aval.dtype) params = dict(exponent_bits=finfo.nexp, mantissa_bits=finfo.nmant) if v in constvars or v in invars: @@ -647,29 +718,43 @@ def _insert_reduce_precision(jaxpr: core.Jaxpr, num_res: int) -> core.Jaxpr: return new_jaxpr def remat_partial_eval_custom_params_updater(*args): - *_, params_known, params_staged = args + unks_in, inst_in, *_, params_known, params_staged = args + prevent_cse = params_known['prevent_cse'] + assert prevent_cse == params_staged['prevent_cse'] + if isinstance(prevent_cse, tuple): + prevent_cse_known, _ = partition_list(unks_in, prevent_cse) + _, prevent_cse_staged = partition_list(inst_in, prevent_cse) + params_known = dict(params_known, prevent_cse=tuple(prevent_cse_known)) + params_staged = dict(params_staged, prevent_cse=tuple(prevent_cse_staged)) return params_known, dict(params_staged, differentiated=True) pe.partial_eval_jaxpr_custom_rules[remat_p] = \ partial(pe.call_partial_eval_custom_rule, 'jaxpr', remat_partial_eval_custom_params_updater) -def remat_transpose(out_cts, *in_primals, jaxpr, **params): +def remat_transpose(out_cts, *args, jaxpr, prevent_cse, **params): + # TODO(mattjj): avoid round-tripping into UndefinedPrimals + args_ = [ad.UndefinedPrimal(x.aval) if isinstance(x, ad.ValAccum) else x + for x in args] + if any(isinstance(x, ad.GradAccum) for x in args_): raise NotImplementedError + assert not jaxpr.constvars - in_linear = [ad.is_undefined_primal(x) for x in in_primals] + in_linear = [ad.is_undefined_primal(x) for x in args_] out_zeros = [type(ct) is ad_util.Zero for ct in out_cts] transposed_jaxpr_, in_zeros = transpose_jaxpr( pe.close_jaxpr(jaxpr), in_linear, out_zeros) transposed_jaxpr, consts = transposed_jaxpr_.jaxpr, transposed_jaxpr_.consts transposed_jaxpr = pe.convert_constvars_jaxpr(transposed_jaxpr) - args, _ = tree_flatten((in_primals, out_cts)) - in_cts_nz = remat_p.bind(*consts, *args, jaxpr=transposed_jaxpr, **params) + flat_args, _ = tree_flatten((args_, out_cts)) + if isinstance(prevent_cse, tuple): + prevent_cse_, _ = partition_list(in_linear, prevent_cse) + prevent_cse = tuple(prevent_cse_) + (True,) * (len(out_zeros) - sum(out_zeros)) + in_cts_nz = remat_p.bind(*consts, *flat_args, jaxpr=transposed_jaxpr, + prevent_cse=prevent_cse, **params) in_cts_nz_, in_zeros_ = iter(in_cts_nz), iter(in_zeros) - in_cts = [None if not ad.is_undefined_primal(x) else - ad_util.Zero(x.aval) if next(in_zeros_) else next(in_cts_nz_) - for x in in_primals] - assert next(in_cts_nz_, None) is next(in_zeros_, None) is None - return in_cts -ad.primitive_transposes[remat_p] = remat_transpose + for x in args: + if isinstance(x, ad.ValAccum) and not next(in_zeros_): + x.accum(next(in_cts_nz_)) +ad.fancy_transposes[remat_p] = remat_transpose # TODO(mattjj): move this to ad.py def transpose_jaxpr(jaxpr: core.ClosedJaxpr, in_linear: bool | Sequence[bool], @@ -698,27 +783,32 @@ def transposed(*args_flat): pe.PartialVal.known(next(ins_iter)) for aval, lin in zip(jaxpr.in_avals, in_lin)] assert next(ins_iter, None) is None + + # TODO(mattjj): revise not to require disabling checks + with config.mutable_array_checks(False): + jaxpr_rematted, lin_jaxpr, out_uk, res_avals = \ + pe.partial_eval_jaxpr_nounits(jaxpr, in_lin, False) with source_info_util.extend_name_stack('rematted_computation'): - lin_jaxpr, _, consts = pe.trace_to_jaxpr_nounits( - lu.wrap_init(core.jaxpr_as_fun(jaxpr), debug_info=jaxpr.jaxpr.debug_info), - in_pvals, False) + consts = core.jaxpr_as_fun(jaxpr_rematted)(*ins_flat) # Transpose the linear jaxpr (which only has linear inputs). out_cts_iter = iter(out_cts_flat) out_cts = [ad_util.Zero(aval) if zero else next(out_cts_iter) for aval, zero in zip(jaxpr.out_avals, out_zeros)] assert next(out_cts_iter, None) is None - dummy_args = [ad.UndefinedPrimal(v.aval) for v in lin_jaxpr.invars] - in_cts = ad.backward_pass(lin_jaxpr, False, consts, dummy_args, out_cts) + dummy_args = [ad.UndefinedPrimal(aval) for aval in lin_jaxpr.in_avals[len(consts):]] + in_cts = ad.backward_pass(lin_jaxpr.jaxpr, False, lin_jaxpr.consts, + [*consts, *dummy_args], out_cts) + in_cts = in_cts[len(consts):] # Identify symbolic zeros in the resulting cotangents, and return nonzeros. in_zeros = cell.in_cts_zero = [type(ct) is ad_util.Zero for ct in in_cts] in_cts_nz, _ = partition_list(in_zeros, in_cts) return in_cts_nz - transposed_wrapped = lu.wrap_init(transposed, - debug_info=jaxpr.jaxpr.debug_info) - transposed_jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic( + dbg = jaxpr.jaxpr.debug_info.with_unknown_names() + transposed_wrapped = lu.wrap_init(transposed, debug_info=dbg) + transposed_jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic( transposed_wrapped, in_avals) transposed_jaxpr = core.ClosedJaxpr(transposed_jaxpr_, consts) return transposed_jaxpr, cell.in_cts_zero # pytype: disable=attribute-error @@ -741,7 +831,10 @@ def remat_dce(used_outputs: list[bool], eqn: core.JaxprEqn if not any(used_outputs) and not pe.has_effects(eqn): return [False] * len(eqn.invars), None new_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs) - new_params = dict(eqn.params, jaxpr=new_jaxpr) + prevent_cse = eqn.params['prevent_cse'] + if isinstance(prevent_cse, tuple): + prevent_cse = tuple(p for p, u in zip(prevent_cse, used_inputs) if u) + new_params = dict(eqn.params, jaxpr=new_jaxpr, prevent_cse=prevent_cse) if (not any(used_inputs) and not any(used_outputs) and _has_effects(new_jaxpr.effects)): return used_inputs, None @@ -749,117 +842,105 @@ def remat_dce(used_outputs: list[bool], eqn: core.JaxprEqn new_eqn = pe.new_jaxpr_eqn( [v for v, used in zip(eqn.invars, used_inputs) if used], [v for v, used in zip(eqn.outvars, used_outputs) if used], - eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info, eqn.ctx) + eqn.primitive, new_params, core.eqn_effects(new_jaxpr), + eqn.source_info, eqn.ctx) return used_inputs, new_eqn pe.dce_rules[remat_p] = remat_dce def _has_effects(effects) -> bool: - return bool({e for e in effects if not isinstance(e, core.NamedAxisEffect)}) + not_really_effects = (core.NamedAxisEffect, core.InternalMutableArrayEffect) + return any(not isinstance(e, not_really_effects) for e in effects) -def remat_expansion(*args, jaxpr: core.Jaxpr, prevent_cse: bool, - differentiated: bool, is_gpu_platform: bool = False, - **_): +def remat_expansion( + *args, jaxpr: core.Jaxpr, prevent_cse: bool, differentiated: bool, **_ +): assert not jaxpr.constvars if differentiated and prevent_cse: - if config.remat_opt_barrier.value: - translation_rule = _remat_translation_using_opt_barrier - elif is_gpu_platform: - translation_rule = _remat_translation_using_while - else: - translation_rule = _remat_translation_using_cond + translation_rule = _remat_translation_using_opt_barrier else: translation_rule = lambda *args, jaxpr: core.eval_jaxpr(jaxpr, (), *args) return api.named_call(translation_rule, name="checkpoint")(*args, jaxpr=jaxpr) + def _remat_translation_using_opt_barrier(*args, jaxpr: core.Jaxpr): args = lax_internal.optimization_barrier(args) return core.eval_jaxpr(jaxpr, (), *args) -# TODO(mattjj): add core utility for 'create dummy value for this type'? -def _dummy_like(aval: core.AbstractValue) -> Any: - if aval is core.abstract_token: - return lax_internal.create_token() - elif isinstance(aval, (core.ShapedArray, core.DShapedArray)): - return lax_internal.broadcast(lax_internal.empty(aval.dtype), aval.shape) # type: ignore - else: - raise ValueError(aval) - -def _remat_translation_using_while(*args, jaxpr: core.Jaxpr): - # Implements: - # for(counter=0, result=0; counter < rng(1, 2); counter ++) { - # result = eval_jaxpr(*args) - # } - # The loop carry is a tuple: (counter, result, args) - from jax._src.lax import control_flow as lax_control_flow - - avals_out = tuple(v.aval for v in jaxpr.outvars) - carry_init = (np.int32(0), tuple(map(_dummy_like, avals_out)), args) - def cond(carry): - counter, _, _ = carry - unif = lax_internal.rng_uniform(np.int32(1), np.int32(2), shape=()) - return counter < unif - - def body(carry): - counter, _, args = carry - results = core.eval_jaxpr(jaxpr, (), *args) - return (counter + 1, tuple(results), args) - - carry_res = lax_control_flow.while_loop(cond, body, carry_init) - return carry_res[1] - -def _remat_translation_using_cond(*args, jaxpr: core.Jaxpr): - # Implements: - # if(rng(0, 1) < 2) - # return eval_jaxpr(*args) - # else: - # return 0 - from jax._src.lax import control_flow as lax_control_flow - - avals_out = tuple(v.aval for v in jaxpr.outvars) - - def remat_comp(*args): - return tuple(core.eval_jaxpr(jaxpr, (), *args)) - def dummy_comp(*args): - return tuple(map(_dummy_like, avals_out)) - - unif = lax_internal.rng_uniform(np.float32(0), np.float32(1), shape=()) - return lax_control_flow.cond(unif < np.float32(2), remat_comp, dummy_comp, *args) - -def _remat_lowering(ctx, *args, jaxpr: core.Jaxpr, prevent_cse: bool, - differentiated: bool, policy, is_gpu_platform=False): - jaxpr_args: Sequence[mlir.IrValues] - if differentiated and prevent_cse: - # If we're using the loop or cond lowerings, use the slower lower_fun - # based path. - if not config.remat_opt_barrier.value: - return mlir.lower_fun(remat_expansion, multiple_results=True)( - ctx, *args, jaxpr=jaxpr, prevent_cse=prevent_cse, - differentiated=differentiated, policy=policy, - is_gpu_platform=is_gpu_platform) - - arg_types = map(mlir.aval_to_ir_type, ctx.avals_in) - flat_args = mlir.flatten_ir_values(args) - barrier_op = hlo.OptimizationBarrierOp(flat_args) - jaxpr_args = mlir.unflatten_ir_values_like_types( - barrier_op.results, arg_types) - else: - jaxpr_args = args + +def _remat_lowering( + ctx: mlir.LoweringRuleContext, + *args, + jaxpr: core.Jaxpr, + prevent_cse: bool, + differentiated: bool, + policy, +): + if isinstance(prevent_cse, bool): + prevent_cse = (prevent_cse,) * len(ctx.avals_in) # type: ignore + assert isinstance(prevent_cse, tuple) + if differentiated and any(prevent_cse): + _, barrier_avals = partition_list(prevent_cse, ctx.avals_in) + other_args, barrier_args = partition_list(prevent_cse, args) + barrier_op = hlo.OptimizationBarrierOp( + mlir.flatten_ir_values(barrier_args)) + barrier_results = mlir.unflatten_ir_values_like_types( + barrier_op.results, map(mlir.aval_to_ir_type, barrier_avals)) + args = merge_lists(prevent_cse, other_args, barrier_results) # type: ignore outs, tokens_out = mlir.jaxpr_subcomp( ctx.module_context, jaxpr, ctx.name_stack.extend('checkpoint'), - ctx.tokens_in, (), *jaxpr_args, dim_var_values=ctx.dim_var_values) + ctx.tokens_in, (), *args, dim_var_values=ctx.dim_var_values, + const_lowering=ctx.const_lowering) ctx.set_tokens_out(tokens_out) return outs mlir.register_lowering(remat_p, _remat_lowering) -mlir.register_lowering(remat_p, partial(_remat_lowering, is_gpu_platform=True), - platform="gpu") def checkpoint_name(x, name): - return name_p.bind(x, name=name) + """Identifies a value with a name within :func:`jax.checkpoint`. + + This function acts as an identity function at runtime (returning ``x`` + unchanged) but attaches a string name to the value in the JAX trace. + These names can be targeted by specific checkpointing policies (see + :ref:`checkpoint-policies`) to control which intermediate values + are saved during the forward pass and which are recomputed during the + backward pass. + + Args: + x: array or PyTree of arrays to be named. + name: A string name to associate with the value ``x``. + + Returns: + The input ``x``, unchanged. + + See Also: + - :func:`jax.checkpoint` (alias: :func:`jax.remat`): decorator to + enable checkpointing. + - :mod:`jax.checkpoint_policies`: a namespace containing policies + that use names marked via ``checkpoint_name`` to determine behavior. + + Example: + >>> import jax + >>> import jax.numpy as jnp + >>> from jax.ad_checkpoint import checkpoint_name + + >>> # Define a function where we explicitly name an intermediate value + >>> def f(x): + ... y = jnp.sin(x) + ... z = checkpoint_name(y, "my_intermediate") + ... return jnp.cos(z) + + >>> # Use a policy that saves only the named value + >>> policy = jax.checkpoint_policies.save_only_these_names("my_intermediate") + >>> f_checkpointed = jax.checkpoint(f, policy=policy) + + For further examples, see the `remat example notebook + `_. + """ + return tree_map(partial(name_p.bind, name=name), x) name_p.def_impl(lambda x, *, name: x) name_p.def_abstract_eval(lambda x, *, name: x) @@ -877,64 +958,14 @@ def name_batcher(args, dims, *, name): batching.primitive_batchers[name_p] = name_batcher -@functools.wraps(checkpoint) -def checkpoint_wrapper( - fun: Callable, - *, - concrete: bool = False, - prevent_cse: bool = True, - static_argnums: int | tuple[int, ...] = (), - policy: Callable[..., bool] | None = None, -) -> Callable: - if concrete: - msg = ("The 'concrete' option to jax.checkpoint / jax.remat is deprecated; " - "in its place, you can use its `static_argnums` option, and if " - "necessary the `jax.ensure_compile_time_eval()` context manager.\n" - "\n" - "For example, if using `concrete=True` for an `is_training` flag:\n" - "\n" - " from functools import partial\n" - "\n" - " @partial(jax.checkpoint, concrete=True)\n" - " def foo(x, is_training):\n" - " if is_training:\n" - " return f(x)\n" - " else:\n" - " return g(x)\n" - "\n" - "replace it with a use of `static_argnums`:\n" - "\n" - " @partial(jax.checkpoint, static_argnums=(1,))\n" - " def foo(x, is_training):\n" - " ...\n" - "\n" - "If jax.numpy operations need to be performed on static arguments, " - "we can use the `jax.ensure_compile_time_eval()` context manager. " - "For example, we can replace this use of `concrete=True`\n:" - "\n" - " @partial(jax.checkpoint, concrete=True)\n" - " def foo(x, y):\n" - " if y > 0:\n" - " return f(x)\n" - " else:\n" - " return g(x)\n" - "\n" - "with this combination of `static_argnums` and " - "`jax.ensure_compile_time_eval()`:\n" - "\n" - " @partial(jax.checkpoint, static_argnums=(1,))\n" - " def foo(x, y):\n" - " with jax.ensure_compile_time_eval():\n" - " y_pos = y > 0\n" - " if y_pos:\n" - " return f(x)\n" - " else:\n" - " return g(x)\n" - "\n" - "See https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html\n") - raise NotImplementedError(msg) - return checkpoint(fun, prevent_cse=prevent_cse, policy=policy, - static_argnums=static_argnums) - -# TODO(phawkins): update users to refer to the public name. -_optimization_barrier = lax_internal.optimization_barrier +@discharge.register_discharge_rule(remat_p) +def _remat_state_discharge_rule( + in_avals, out_avals, *args, jaxpr, **params): + discharged_jaxpr, () = discharge.discharge_state(jaxpr, []) + out_vals_ref_vals = remat_p.bind(*args, jaxpr=discharged_jaxpr, **params) + out_vals, ref_vals = split_list(out_vals_ref_vals, [len(jaxpr.outvars)]) + ref_vals_ = iter(ref_vals) + new_invals = [next(ref_vals_) if isinstance(a, AbstractRef) else None + for a in in_avals] + assert next(ref_vals_, None) is None + return new_invals, out_vals diff --git a/jax/_src/ad_util.py b/jax/_src/ad_util.py index c729a57cfb11..4b693fb1a890 100644 --- a/jax/_src/ad_util.py +++ b/jax/_src/ad_util.py @@ -31,6 +31,10 @@ map = safe_map def add_jaxvals(x: ArrayLike, y: ArrayLike) -> Array: + ty = core.typeof(x) + if hasattr(ty, 'vspace_add'): # TODO(mattjj,dougalm): revise away hasattr + return ty.vspace_add(x, y) + x, y = core.standard_insert_pvary(x, y) return add_jaxvals_p.bind(x, y) add_jaxvals_p = Primitive('add_any') @@ -43,10 +47,12 @@ def add_impl(x, y): @add_jaxvals_p.def_abstract_eval def add_abstract(x, y): - assert core.typematch(x, y) + assert core.typematch(x, y), (x, y) return x def zeros_like_aval(aval: core.AbstractValue) -> Array: + if hasattr(aval, 'vspace_zero'): # TODO(mattjj,dougalm): revise away hasattr + return aval.vspace_zero() return aval_zeros_likers[type(aval)](aval) aval_zeros_likers: dict[type, Callable[[Any], Array]] = {} @@ -67,7 +73,10 @@ def __repr__(self) -> str: return f'Zero({self.aval})' @staticmethod def from_primal_value(val: Any) -> Zero: + # TODO(mattjj,yashkatariya): sometimes we want to_cotangent_aval... return Zero(get_aval(val).to_tangent_aval()) + def instantiate(self): + return zeros_like_aval(self.aval) register_pytree_node(Zero, lambda z: ((), z.aval), lambda aval, _: Zero(aval)) diff --git a/jax/_src/api.py b/jax/_src/api.py index cdcc3e534e74..fdf877424561 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -25,23 +25,25 @@ import atexit import collections from collections.abc import Callable, Hashable, Iterable, Sequence -from functools import partial, lru_cache +import dataclasses +from functools import partial import inspect -import math import typing -from typing import (Any, Literal, NamedTuple, TypeVar, overload, - cast) +from typing import (Any, Literal, NamedTuple, Optional, TypeVar, overload, + cast, TYPE_CHECKING) import weakref import numpy as np from contextlib import contextmanager +from jax._src import api_util from jax._src import linear_util as lu from jax._src import stages from jax._src.tree_util import ( tree_map, tree_flatten, tree_unflatten, tree_structure, tree_transpose, - tree_leaves, Partial, PyTreeDef, all_leaves, keystr, broadcast_prefix, - prefix_errors, generate_key_paths, tree_flatten_with_path) + tree_leaves, Partial, PyTreeDef, keystr, broadcast_prefix, + prefix_errors, generate_key_paths, tree_flatten_with_path, + equality_errors_pytreedef, register_pytree_node, register_dataclass) from jax._src import config from jax._src import core from jax._src import dispatch @@ -49,43 +51,43 @@ from jax._src import basearray from jax._src import distributed from jax._src import dtypes +from jax._src.dtypes import canonicalize_value from jax._src import sharding_impls from jax._src import sharding_specs from jax._src import source_info_util from jax._src import traceback_util from jax._src import pjit from jax._src import xla_bridge as xb -from jax._src.core import eval_jaxpr, shaped_abstractify, ShapedArray +from jax._src.core import eval_jaxpr, shaped_abstractify, ShapedArray, typeof from jax._src.api_util import ( flatten_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2, argnums_partial, - flatten_axes, donation_vector, - rebase_donate_argnums, _ensure_index, _ensure_index_tuple, - apply_flat_fun_nokwargs, check_callable, debug_info, - flat_out_axes) -from jax._src.lax import lax as lax_internal + flatten_axes, donation_vector, rebase_donate_argnums, + _ensure_index, _ensure_index_tuple, apply_flat_fun_nokwargs, check_callable, + debug_info, flat_out_axes) from jax._src.lib import jax_jit from jax._src.lib import xla_client as xc from jax._src.lib import pmap_lib from jax._src.sharding import Sharding -from jax._src.mesh import get_concrete_mesh -from jax._src.sharding_impls import ( - PmapSharding, TransferToMemoryKind, PartitionSpec as P, NamedSharding) -from jax._src.layout import Layout, AutoLayout +from jax._src.mesh import get_concrete_mesh, get_abstract_mesh, Mesh +from jax._src.sharding_impls import (PmapSharding, PartitionSpec as P, + NamedSharding) +from jax._src.layout import Format from jax._src.traceback_util import api_boundary from jax._src import tree_util -from jax._src.util import unzip2, safe_map, safe_zip, wraps, split_list +from jax._src.util import unzip2, safe_map, safe_zip, wraps from jax._src import util from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import pxla -from jax._src.interpreters import xla + +config_ext = xc._xla.config traceback_util.register_exclusion(__file__) -_dtype = partial(dtypes.dtype, canonicalize=True) +_dtype = dtypes.dtype AxisName = Hashable @@ -100,6 +102,8 @@ map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip +ShapeDtypeStruct = core.ShapeDtypeStruct + @api_boundary def _nan_check_posthook(fun, args, kwargs, output): @@ -110,32 +114,37 @@ def _nan_check_posthook(fun, args, kwargs, output): buffers.extend([shard.data for shard in leaf.addressable_shards]) try: - dispatch.check_special(pjit.pjit_p.name, buffers) - except dispatch.InternalFloatingPointError as e: + dispatch.check_special(pjit.jit_p.name, buffers) + except api_util.InternalFloatingPointError as e: assert config.debug_nans.value or config.debug_infs.value if hasattr(fun, '_fun'): f = fun._fun if getattr(f, '_apply_primitive', False): raise FloatingPointError(f"invalid value ({e.ty}) encountered in {f.__qualname__}") from None # compiled_fun can only raise in this case - dispatch.maybe_recursive_nan_check(e, f, args, kwargs) + api_util.maybe_recursive_nan_check(e, f, args, kwargs) raise AssertionError("Unreachable") from e else: # TODO(emilyaf): Shouldn't need this fallback. raise +_post_hook_state = config_ext.Config[Optional[Callable]]( + "post_hook", None, include_in_jit_key=False +) +jax_jit.set_post_hook_state(_post_hook_state) + def _update_debug_special_global(_): if config._read("jax_debug_nans") or config._read("jax_debug_infs"): - jax_jit.global_state().post_hook = _nan_check_posthook + _post_hook_state.set_global(_nan_check_posthook) else: - jax_jit.global_state().post_hook = None + _post_hook_state.set_global(None) def _update_debug_special_thread_local(_): if (config.debug_nans.get_local() == True or config.debug_infs.get_local() == True): - jax_jit.thread_local_state().post_hook = _nan_check_posthook + _post_hook_state.set_local(_nan_check_posthook) else: - jax_jit.thread_local_state().post_hook = None + _post_hook_state.set_local(None) config.debug_nans._add_hooks(_update_debug_special_global, _update_debug_special_thread_local) @@ -145,9 +154,45 @@ def _update_debug_special_thread_local(_): float0 = dtypes.float0 +class NotSpecified: + """Sentinel for use in jax.jit""" + def __repr__(self): + return "" + +@overload +def jit( + fun: Callable, /, *, + in_shardings: Any = ..., + out_shardings: Any = ..., + static_argnums: int | Sequence[int] | None = ..., + static_argnames: str | Iterable[str] | None = ..., + donate_argnums: int | Sequence[int] | None = ..., + donate_argnames: str | Iterable[str] | None = ..., + keep_unused: bool = ..., + device: xc.Device | None = ..., + backend: str | None = ..., + inline: bool = ..., + compiler_options: dict[str, Any] | None = ..., +) -> pjit.JitWrapped: ... + +@overload +def jit( + *, + in_shardings: Any = ..., + out_shardings: Any = ..., + static_argnums: int | Sequence[int] | None = ..., + static_argnames: str | Iterable[str] | None = ..., + donate_argnums: int | Sequence[int] | None = ..., + donate_argnames: str | Iterable[str] | None = ..., + keep_unused: bool = ..., + device: xc.Device | None = ..., + backend: str | None = ..., + inline: bool = ..., + compiler_options: dict[str, Any] | None = ..., +) -> Callable[[Callable], pjit.JitWrapped]: ... def jit( - fun: Callable, + fun: Callable | NotSpecified = NotSpecified(), /, *, in_shardings: Any = sharding_impls.UNSPECIFIED, out_shardings: Any = sharding_impls.UNSPECIFIED, static_argnums: int | Sequence[int] | None = None, @@ -158,21 +203,21 @@ def jit( device: xc.Device | None = None, backend: str | None = None, inline: bool = False, - abstracted_axes: Any | None = None, compiler_options: dict[str, Any] | None = None, -) -> pjit.JitWrapped: +) -> pjit.JitWrapped | Callable[[Callable], pjit.JitWrapped]: """Sets up ``fun`` for just-in-time compilation with XLA. Args: fun: Function to be jitted. ``fun`` should be a pure function. - The arguments and return value of ``fun`` should be arrays, scalar, or (nested) standard Python containers (tuple/list/dict) thereof. Positional arguments indicated by ``static_argnums`` can be any hashable type. Static arguments are included as part of a compilation cache key, which is why hash and equality operators must be defined. JAX keeps a weak reference to ``fun`` for use as a compilation cache key, so the object ``fun`` must be - weakly-referenceable. + weakly-referenceable. Starting in JAX v0.8.1, when ``fun`` is omitted, + the return value will be a partially-evaluated function to allow the + decorator factory pattern (see Examples below). in_shardings: optional, a :py:class:`Sharding` or pytree with :py:class:`Sharding` leaves and structure that is a tree prefix of the positional arguments tuple to ``fun``. If provided, the positional @@ -184,14 +229,13 @@ def jit( out_shardings: optional, a :py:class:`Sharding` or pytree with :py:class:`Sharding` leaves and structure that is a tree prefix of the output of ``fun``. If provided, it has the same effect as applying - corresponding :py:func:`jax.lax.with_sharding_constraint`s to the output - of ``fun``. + :py:func:`jax.lax.with_sharding_constraint` to the output of ``fun``. static_argnums: optional, an int or collection of ints that specify which positional arguments to treat as static (trace- and compile-time constant). Static arguments should be hashable, meaning both ``__hash__`` and - ``__eq__`` are implemented, and immutable. Otherwise they can be arbitrary + ``__eq__`` are implemented, and immutable. Otherwise, they can be arbitrary Python objects. Calling the jitted function with different values for these constants will trigger recompilation. Arguments that are not array-like or containers thereof must be marked as static. @@ -231,7 +275,7 @@ def jit( be donated. For more details on buffer donation see the - `FAQ `_. + `FAQ `_. donate_argnames: optional, a string or collection of strings specifying which named arguments are donated to the computation. See the comment on ``donate_argnums`` for details. If not @@ -272,8 +316,20 @@ def jit( [-0.54485 0.27744 -0.29255 -0.91421 -0.62452 -0.24748 -0.85743 -0.78232 0.76827 0.59566 ] - To pass arguments such as ``static_argnames`` when decorating a function, a - common pattern is to use :func:`functools.partial`: + Starting in JAX v0.8.1, :func:`jit` supports the decorator factory pattern + for specifying optional keywords: + + >>> @jax.jit(static_argnames=['n']) + ... def g(x, n): + ... for i in range(n): + ... x = x ** 2 + ... return x + >>> + >>> g(jnp.arange(4), 3) + Array([ 0, 1, 256, 6561], dtype=int32) + + For compatiblity with older JAX versions, a common pattern is to use + :func:`functools.partial`: >>> from functools import partial >>> @@ -286,17 +342,28 @@ def jit( >>> g(jnp.arange(4), 3) Array([ 0, 1, 256, 6561], dtype=int32) """ - return pjit.make_jit( - fun, in_shardings, out_shardings, donate_argnums, donate_argnames, - static_argnums, static_argnames, device, backend, abstracted_axes, - keep_unused, inline, compiler_options, use_resource_env=False) + kwds = dict( + in_shardings=in_shardings, out_shardings=out_shardings, + static_argnums=static_argnums, static_argnames=static_argnames, + donate_argnums=donate_argnums, donate_argnames=donate_argnames, + keep_unused=keep_unused, device=device, backend=backend, inline=inline, + compiler_options=compiler_options, use_resource_env=False) + if isinstance(fun, NotSpecified): + return lambda fun: pjit.make_jit(fun, **kwds) + else: + return pjit.make_jit(fun, **kwds) + +if not TYPE_CHECKING: + # TODO(slebedev): This ought to be a decorator, but it seems it makes + # pytype ignore the overloads + jit = api_boundary(jit, repro_api_name="jax.jit") @contextmanager def disable_jit(disable: bool = True): """Context manager that disables :py:func:`jit` behavior under its dynamic context. - For debugging it is useful to have a mechanism that disables :py:func:`jit` + For debugging, it is useful to have a mechanism that disables :py:func:`jit` everywhere in a dynamic context. Note that this not only disables explicit uses of :func:`jit` by the user, but will also remove any implicit JIT compilation used by the JAX library: this includes implicit JIT computation of `body` and @@ -321,8 +388,8 @@ def disable_jit(disable: bool = True): ... print("Value of y is", y) ... return y + 3 ... - >>> print(f(jax.numpy.array([1, 2, 3]))) # doctest:+ELLIPSIS - Value of y is Tracedwith + >>> print(f(jax.numpy.array([1, 2, 3]))) + Value of y is JitTracer(int32[3]) [5 7 9] Here ``y`` has been abstracted by :py:func:`jit` to a :py:class:`ShapedArray`, @@ -343,6 +410,7 @@ def disable_jit(disable: bool = True): yield +@partial(api_boundary, repro_api_name="jax.grad") def grad(fun: Callable, argnums: int | Sequence[int] = 0, has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False, @@ -409,6 +477,7 @@ def grad_f_aux(*args, **kwargs): return grad_f_aux if has_aux else grad_f +@partial(api_boundary, repro_api_name="jax.value_and_grad") def value_and_grad(fun: Callable, argnums: int | Sequence[int] = 0, has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False, reduce_axes: Sequence[AxisName] = () @@ -440,6 +509,8 @@ def value_and_grad(fun: Callable, argnums: int | Sequence[int] = 0, shapes and types as the corresponding arguments. If ``has_aux`` is True then a tuple of ((value, auxiliary_data), gradient) is returned. """ + from jax._src.lax import lax as lax_internal # pytype: disable=import-error + if reduce_axes: raise NotImplementedError("reduce_axes argument to grad is deprecated") del reduce_axes @@ -471,11 +542,10 @@ def value_and_grad_f(*args, **kwargs): if not has_aux: ans, vjp_py = _vjp(f_partial, *dyn_args) else: - ans, vjp_py, aux = _vjp( - f_partial, *dyn_args, has_aux=True) + ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True) _check_scalar(ans) tree_map(partial(_check_output_dtype_grad, holomorphic), ans) - g = vjp_py(lax_internal._one(ans)) + g = vjp_py(lax_internal._one_vjp(ans)) g = g[0] if isinstance(argnums, int) else g if not has_aux: return ans, g @@ -504,17 +574,18 @@ def _check_input_dtype_revderiv(name, holomorphic, allow_int, x): if not dtypes.issubdtype(aval.dtype, np.complexfloating): raise TypeError(f"{name} with holomorphic=True requires inputs with complex dtype, " f"but got {aval.dtype.name}.") - if (dtypes.issubdtype(aval.dtype, dtypes.extended) or - dtypes.issubdtype(aval.dtype, np.integer) or - dtypes.issubdtype(aval.dtype, np.bool_)): - if not allow_int: - raise TypeError(f"{name} requires real- or complex-valued inputs (input dtype " - f"that is a sub-dtype of np.inexact), but got {aval.dtype.name}. " - "If you want to use Boolean- or integer-valued inputs, use vjp " - "or set allow_int to True.") - elif not dtypes.issubdtype(aval.dtype, np.inexact): - raise TypeError(f"{name} requires numerical-valued inputs (input dtype that is a " - f"sub-dtype of np.bool_ or np.number), but got {aval.dtype.name}.") + if isinstance(aval, ShapedArray): + if (dtypes.issubdtype(aval.dtype, dtypes.extended) or + dtypes.issubdtype(aval.dtype, np.integer) or + dtypes.issubdtype(aval.dtype, np.bool_)): + if not allow_int: + raise TypeError(f"{name} requires real- or complex-valued inputs (input dtype " + f"that is a sub-dtype of np.inexact), but got {aval.dtype.name}. " + "If you want to use Boolean- or integer-valued inputs, use vjp " + "or set allow_int to True.") + elif not dtypes.issubdtype(aval.dtype, np.inexact): + raise TypeError(f"{name} requires numerical-valued inputs (input dtype that is a " + f"sub-dtype of np.bool_ or np.number), but got {aval.dtype.name}.") _check_input_dtype_grad = partial(_check_input_dtype_revderiv, "grad") def _check_output_dtype_revderiv(name, holomorphic, x): @@ -539,7 +610,82 @@ def _check_output_dtype_revderiv(name, holomorphic, x): "jax.vjp directly.") _check_output_dtype_grad = partial(_check_output_dtype_revderiv, "grad") +@partial(api_boundary, repro_api_name="jax.fwd_and_bwd") +def fwd_and_bwd( + fun: Callable, argnums: int | Sequence[int], has_aux: bool = False, + jitted: bool = True, +) -> tuple[Callable, Callable]: + """Creates functions ``fwd`` and ``bwd`` corresponding to the forward and + backward pass of a given function ``fun``. The forward function ``fwd(*args)`` + functionally behaves much like ``y, fun_vjp = jax.vjp(fun, *args)``, but allows + reuse of the backward function ``bwd`` across multiple iterations, which is + useful to avoid recompilation when the forward and backward do not end up in a + single jitted function: + + >>> import jax + >>> + >>> x = W = cot_out = jax.numpy.ones((4,4)) + >>> + >>> def f(x, W): + ... return x @ W + ... + >>> f_jitted = jax.jit(f) + >>> for i in range(3): + ... y, f_vjp = jax.vjp(f_jitted, x, W) + ... cot_x, cot_W = f_vjp(cot_out) # not jitted + ... cot_x, cot_W = jax.jit(f_vjp)(cot_out) # recompiles on every iteration + ... + >>> fwd, bwd = jax.fwd_and_bwd(f, argnums=(0,1)) + >>> for i in range(3): + ... y, residuals = fwd(x, W) + ... cot_x, cot_W = bwd(residuals, cot_out) # jitted, compiles once + ... + + Args: + fun: Function to produce a forward and backward of. + argnums: Integer or sequence of integers. Specifies which positional argument(s) + to differentiate with respect to. + has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the + first element is considered the output of the mathematical function to be + differentiated and the second element is auxiliary data. Default False. + jitted: Optional, bool. Indicates whether to return the ``jax.jit`` of + forward and backward. Note that jit-ing only the backward but not the + forward will result in the backward recompiling on every invocation, so we + default to jit-ing both. + Returns: + The two functions, ``fwd`` and ``bwd``. + + If ``has_aux`` is ``False``, ``fwd(*primals)`` returns a tuple + ``(primals_out, residuals)``, where ``primals_out`` is ``fun(*primals)``. + If ``has_aux`` is ``True``, returns a ``(primals_out, residuals, aux)`` tuple + where ``aux`` is the auxiliary data returned by ``fun``. + + ``bwd`` is a function from ``residuals`` and a cotangent vector with the same + shape as ``primals_out`` to a tuple of cotangent vectors with the same number + and shapes as the ``primals`` designated by ``argnums``, representing the + vector-Jacobian product of ``fun`` evaluated at ``primals``. + """ + check_callable(fun) + argnums = _ensure_index(argnums) + + def fwd(*args, **kwargs): + dbg = debug_info('fwd_and_bwd', fun, args, kwargs) + f = lu.wrap_init(fun, params=kwargs, debug_info=dbg) + f_partial, dyn_args = argnums_partial( + f, argnums, args, require_static_args_hashable=False) + return _vjp(f_partial, *dyn_args, has_aux=has_aux) # type: ignore + def bwd(f_vjp, outgrad): + g = f_vjp(outgrad) + g = g[0] if isinstance(argnums, int) else g + return g + if jitted: + fwd = jit(fwd) + bwd = jit(bwd) + return fwd, bwd + + +@partial(api_boundary, repro_api_name="jax.jacfwd") def jacfwd(fun: Callable, argnums: int | Sequence[int] = 0, has_aux: bool = False, holomorphic: bool = False) -> Callable: """Jacobian of ``fun`` evaluated column-by-column using forward-mode AD. @@ -630,8 +776,10 @@ def _check_output_dtype_jacfwd(holomorphic, x): raise TypeError("jacfwd with holomorphic=True requires outputs with complex dtype, " f"but got {aval.dtype.name}.") +@partial(api_boundary, repro_api_name="jax.jacrev") def jacrev(fun: Callable, argnums: int | Sequence[int] = 0, - has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False) -> Callable: + has_aux: bool = False, holomorphic: bool = False, + allow_int: bool = False) -> Callable: """Jacobian of ``fun`` evaluated row-by-row using reverse-mode AD. Args: @@ -710,6 +858,7 @@ def jacobian(fun: Callable, argnums: int | Sequence[int] = 0, _check_output_dtype_jacrev = partial(_check_output_dtype_revderiv, "jacrev") +@partial(api_boundary, repro_api_name="jax.hessian") def hessian(fun: Callable, argnums: int | Sequence[int] = 0, has_aux: bool = False, holomorphic: bool = False) -> Callable: """Hessian of ``fun`` as a dense array. @@ -777,31 +926,53 @@ def hessian(fun: Callable, argnums: int | Sequence[int] = 0, return jacfwd(jacrev(fun, argnums, has_aux=has_aux, holomorphic=holomorphic), argnums, has_aux=has_aux, holomorphic=holomorphic) +def _insert_pvary(basis, leaf): + if not config._check_vma.value: + return basis + return core.pvary(basis, tuple(core.typeof(leaf).vma)) + def _std_basis(pytree): - import jax.numpy as jnp + import jax.numpy as jnp # pytype: disable=import-error leaves, _ = tree_flatten(pytree) ndim = sum(map(np.size, leaves)) dtype = dtypes.result_type(*leaves) flat_basis = jnp.eye(ndim, dtype=dtype) - return _unravel_array_into_pytree(pytree, 1, None, flat_basis) + axis = 1 + arr_s = [None] * flat_basis.ndim + specs = tree_map(lambda l: P(arr_s[:axis], *core.typeof(l).sharding.spec, + arr_s[axis+1:]), pytree) + out_pytree = _unravel_array_into_pytree(pytree, axis, None, flat_basis, specs) + out_pytree = tree_map(_insert_pvary, out_pytree, pytree) + return out_pytree def _jacfwd_unravel(input_pytree, output_pytree_leaf, arr): + axis = -1 % arr.ndim + arr_s = core.typeof(arr).sharding.spec + specs = tree_map( + lambda l: P(*arr_s[:axis], *[None] * len(np.shape(l)), *arr_s[axis+1:]), + input_pytree) return _unravel_array_into_pytree( - input_pytree, -1, output_pytree_leaf, arr) + input_pytree, axis, output_pytree_leaf, arr, specs) def _jacrev_unravel(output_pytree, input_pytree_leaf, arr): + specs = tree_map( + lambda l: P(*[None] * len(np.shape(l)), *core.typeof(arr).sharding.spec[1:]), + output_pytree) return _unravel_array_into_pytree( - output_pytree, 0, input_pytree_leaf, arr) + output_pytree, 0, input_pytree_leaf, arr, specs) -def _possible_downcast(x, example): +def _possible_downcast(x, example, spec): + from jax._src.lax import lax as lax_internal # pytype: disable=import-error if (dtypes.issubdtype(x.dtype, np.complexfloating) and not dtypes.issubdtype(_dtype(example), np.complexfloating)): x = x.real - dtype = None if example is None else _dtype(example) - weak_type = None if example is None else dtypes.is_weakly_typed(example) - return lax_internal._convert_element_type(x, dtype, weak_type) + dtype = _dtype(example) + weak_type = dtypes.is_weakly_typed(example) + sharding = NamedSharding(core.typeof(example).sharding.mesh, spec) + return lax_internal._convert_element_type( + x, dtype, weak_type, sharding=sharding) -def _unravel_array_into_pytree(pytree, axis, example, arr): +def _unravel_array_into_pytree(pytree, axis, example, arr, specs): """Unravel an array into a PyTree with a given structure. Args: pytree: The pytree that provides the structure. @@ -812,12 +983,14 @@ def _unravel_array_into_pytree(pytree, axis, example, arr): arr: The array to be unraveled. """ leaves, treedef = tree_flatten(pytree) - axis = axis % arr.ndim + specs, _ = tree_flatten(specs) shapes = [arr.shape[:axis] + np.shape(l) + arr.shape[axis+1:] for l in leaves] parts = _split(arr, np.cumsum(map(np.size, leaves[:-1])), axis) reshaped_parts = [ - _possible_downcast(np.reshape(x, shape), leaf if example is None else example) - for x, shape, leaf in zip(parts, shapes, leaves)] + _possible_downcast(np.reshape(x, shape), + leaf if example is None else example, + spec=spec) + for x, shape, leaf, spec in zip(parts, shapes, leaves, specs)] return tree_unflatten(treedef, reshaped_parts) def _split(x, indices, axis): @@ -827,12 +1000,15 @@ def _split(x, indices, axis): return x._split(indices, axis) +@partial(api_boundary, repro_api_name="jax.vmap") def vmap(fun: F, in_axes: int | None | Sequence[Any] = 0, out_axes: Any = 0, axis_name: AxisName | None = None, axis_size: int | None = None, - spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None) -> F: + spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None, + sum_match: bool = False + ) -> F: """Vectorizing map. Creates a function which maps ``fun`` over argument axes. Args: @@ -855,7 +1031,7 @@ def vmap(fun: F, be a container with a matching pytree structure specifying the mapping of its container elements. In other words, ``in_axes`` must be a container tree prefix of the positional argument tuple passed to ``fun``. See this link for more detail: - https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees + https://docs.jax.dev/en/latest/pytrees.html#applying-optional-parameters-to-pytrees Either ``axis_size`` must be provided explicitly, or at least one positional argument must have ``in_axes`` not None. The sizes of the @@ -971,7 +1147,7 @@ def vmap(fun: F, docstr += fun.__doc__ axis_name = core.no_axis_name if axis_name is None else axis_name - if spmd_axis_name is not None and type(spmd_axis_name) is not tuple: + if spmd_axis_name is not None and not isinstance(spmd_axis_name, tuple): spmd_axis_name = (spmd_axis_name,) if isinstance(in_axes, list): @@ -995,28 +1171,46 @@ def vmap(fun: F, @wraps(fun, docstr=docstr) @api_boundary def vmap_f(*args, **kwargs): + nonlocal spmd_axis_name if isinstance(in_axes, tuple) and len(in_axes) != len(args): raise ValueError("vmap in_axes must be an int, None, or a tuple of entries corresponding " "to the positional arguments passed to the function, " f"but got {len(in_axes)=}, {len(args)=}") + args_flat, in_tree = tree_flatten((args, kwargs), is_leaf=batching.is_vmappable) - f = lu.wrap_init(fun, debug_info=debug_info("vmap", fun, args, kwargs)) + dbg = debug_info("vmap", fun, args, kwargs) + + f = lu.wrap_init(fun, debug_info=dbg) flat_fun, out_tree = batching.flatten_fun_for_vmap(f, in_tree) in_axes_flat = flatten_axes("vmap in_axes", in_tree, (in_axes, 0), kws=True) + + if config.mutable_array_checks.value: + avals = [None if d is None or batching.is_vmappable(x) else core.typeof(x) + for x, d in zip(args_flat, in_axes_flat)] + api_util.check_no_aliased_ref_args(lambda: dbg, avals, args_flat) + axis_size_ = (axis_size if axis_size is not None else _mapped_axis_size(fun, in_tree, args_flat, in_axes_flat, "vmap")) explicit_mesh_axis = _mapped_axis_spec(args_flat, in_axes_flat) if spmd_axis_name is not None and explicit_mesh_axis is not None: - raise ValueError( - "Only one of spmd_axis_name or arrays sharded on `Explicit` mesh" - f" axis type is allowed. Got {spmd_axis_name=} and" - f" arrays sharded on {explicit_mesh_axis=}") + if config.remove_size_one_mesh_axis_from_type.value: + mesh = get_abstract_mesh() + spmd_axis_name = tuple(i for i in spmd_axis_name if mesh.shape[i] != 1) + if spmd_axis_name == explicit_mesh_axis: + spmd_axis_name = None + else: + raise ValueError( + "Only one of spmd_axis_name or arrays sharded on `Explicit` mesh" + f" axis type is allowed. Got {spmd_axis_name=} and" + f" arrays sharded on {explicit_mesh_axis=}") + assert spmd_axis_name is None try: axis_data = batching.AxisData(axis_name, axis_size_, spmd_axis_name, explicit_mesh_axis) - out_flat = batching.batch( + out_flat, inferred_out_axes = batching.batch( flat_fun, axis_data, in_axes_flat, - lambda: flatten_axes("vmap out_axes", out_tree(), out_axes) + lambda: flatten_axes("vmap out_axes", out_tree(), out_axes), + sum_match=sum_match ).call_wrapped(*args_flat) except batching.SpecMatchError as e: out_axes_flat = flatten_axes("vmap out_axes", out_tree(), out_axes) @@ -1026,7 +1220,11 @@ def vmap_f(*args, **kwargs): path, _ = pairs[e.leaf_idx] raise ValueError(f'at vmap out_axes{keystr(path)}, got axis spec {e.dst} ' f'but output was batched on axis {e.src}') from None - return tree_unflatten(out_tree(), out_flat) + if any(d is batching.infer for d in tree_leaves(out_axes)): + return (tree_unflatten(out_tree(), out_flat), + tree_unflatten(out_tree(), inferred_out_axes)) + else: + return tree_unflatten(out_tree(), out_flat) return cast(F, vmap_f) @@ -1038,16 +1236,20 @@ def _get_spec(arg, i): except (IndexError, TypeError): return None - temp_spec = None + out_spec = None + non_none_count = 0 for arg, i in zip(args_flat, in_axes_flat): if i is not None: spec = _get_spec(arg, i) - if temp_spec is not None and temp_spec != spec: + if non_none_count != 0 and out_spec != spec: raise ValueError( "Mapped away dimension of inputs passed to vmap should be sharded" - f" the same. Got inconsistent axis specs: {temp_spec} vs {spec}") - temp_spec = spec - return temp_spec + f" the same. Got inconsistent axis specs: {out_spec} vs {spec}") + out_spec = spec + non_none_count += 1 + if out_spec is not None and not isinstance(out_spec, tuple): + out_spec = (out_spec,) + return out_spec def _mapped_axis_size(fn, tree, vals, dims, name): if not vals: @@ -1135,7 +1337,7 @@ def _all_sizes_index(sz): msg.append(f" * some axes ({ct} of them) had size {sz}, e.g. axis {ax} of {ex};\n") raise ValueError(''.join(msg)[:-2]) # remove last semicolon and newline - +@partial(api_boundary, repro_api_name="jax.pmap") def pmap( fun: Callable, axis_name: AxisName | None = None, @@ -1149,7 +1351,19 @@ def pmap( donate_argnums: int | Iterable[int] = (), global_arg_shapes: tuple[tuple[int, ...], ...] | None = None, ) -> Any: - """Parallel map with support for collective operations. + """Old way of doing parallel map. Use :py:func:`jax.shard_map` instead. + + .. note:: + While :py:func:`jax.pmap` works, you should probably use + :py:func:`jax.shard_map` or ``jax.smap`` instead. shard_map supports more + efficient autodiff, and is more composable in the multi-controller setting. + See https://docs.jax.dev/en/latest/notebooks/shard_map.html for examples. + + .. note:: + :py:func:`pmap` is now implemented in terms of :py:func:`jit` and + :py:func:`shard_map`. Please see the `migration + guide `_ for + more information. The purpose of :py:func:`pmap` is to express single-program multiple-data (SPMD) programs. Applying :py:func:`pmap` to a function will compile the @@ -1241,7 +1455,7 @@ def pmap( arguments will not be donated. For more details on buffer donation see the - `FAQ `_. + `FAQ `_. Returns: A parallelized version of ``fun`` with arguments that correspond to those of @@ -1308,26 +1522,6 @@ def pmap( are important particularly in the case of nested :py:func:`pmap` functions, where collective operations can operate over distinct axes: - >>> from functools import partial - >>> import jax - >>> - >>> @partial(pmap, axis_name='rows') - ... @partial(pmap, axis_name='cols') - ... def normalize(x): - ... row_normed = x / jax.lax.psum(x, 'rows') - ... col_normed = x / jax.lax.psum(x, 'cols') - ... doubly_normed = x / jax.lax.psum(x, ('rows', 'cols')) - ... return row_normed, col_normed, doubly_normed - >>> - >>> x = jnp.arange(8.).reshape((4, 2)) - >>> row_normed, col_normed, doubly_normed = normalize(x) # doctest: +SKIP - >>> print(row_normed.sum(0)) # doctest: +SKIP - [ 1. 1.] - >>> print(col_normed.sum(1)) # doctest: +SKIP - [ 1. 1. 1. 1.] - >>> print(doubly_normed.sum((0, 1))) # doctest: +SKIP - 1.0 - On multi-process platforms, collective operations operate over all devices, including those on other processes. For example, assuming the following code runs on two processes with 4 XLA devices each: @@ -1371,10 +1565,8 @@ def pmap( " removed from JAX. Please migrate to pjit and remove global_arg_shapes" " from pmap.") - # TODO(yashkatariya): Move this out after shard_map is out of experimental and - # in _src if config.pmap_shmap_merge.value: - from jax.experimental.shard_map import pmap + from jax._src.pmap import pmap # pytype: disable=import-error return pmap(fun, axis_name, in_axes=in_axes, out_axes=out_axes, static_broadcasted_argnums=static_broadcasted_argnums, devices=devices, backend=backend, @@ -1427,7 +1619,7 @@ def _get_global_axis_size(local_axis_size: int, in_devices, backend_name: str, if global_axis_size is None: if xb.process_count(backend) == 1: global_axis_size = local_axis_size - elif in_devices: + elif in_devices is not None: global_axis_size = len(in_devices) else: global_axis_size = local_axis_size * xb.process_count(backend) @@ -1488,7 +1680,7 @@ def _prepare_pmap(fun: Callable, in_axes, out_axes, static_broadcasted_tuple, "Instead, each argument passed by keyword is mapped over its " "leading axis. See the description of `in_axes` in the `pmap` " "docstring: " - "https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html#jax.pmap") + "https://docs.jax.dev/en/latest/_autosummary/jax.pmap.html#jax.pmap") msg += ("\n\nCheck that the value of the `in_axes` argument to `pmap` " "is a tree prefix of the tuple of arguments passed positionally to " "the pmapped function.") @@ -1568,11 +1760,13 @@ def _cpp_pmap( out_axes) del static_broadcasted_argnums, donate_argnums + prepare_pmap_fn = partial(_prepare_pmap, + fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple, + devices, backend, axis_size) + @api_boundary def cache_miss(*args, **kwargs): - p = _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple, - donate_tuple, devices, backend, - axis_size, args, kwargs) + p = prepare_pmap_fn(args, kwargs) for arg in p.flat_args: dispatch.check_arg(arg) @@ -1593,11 +1787,11 @@ def cache_miss(*args, **kwargs): with core.take_current_trace() as trace: try: if isinstance(trace, core.EvalTrace): - execute = pxla.xla_pmap_impl_lazy(p.flat_fun, *p.flat_args, **params) - out = execute(*p.flat_args) + execute, const_args = pxla.xla_pmap_impl_lazy(p.flat_fun, *p.flat_args, **params) + out = execute(*const_args, *p.flat_args) else: out = pxla.xla_pmap_p.bind_with_trace(trace, (p.flat_fun, *p.flat_args), params) - except dispatch.InternalFloatingPointError as e: + except api_util.InternalFloatingPointError as e: raise FloatingPointError(f'Invalid value ({e.ty}) encountered in parallel computation.') out_tree, out_flat = p.out_tree, out @@ -1612,6 +1806,7 @@ def cache_miss(*args, **kwargs): # TODO(sharadmv): Enable effects in replicated computation not execute_replicated.has_unordered_effects and not execute_replicated.has_host_callbacks and + len(const_args) == 0 and # No tracers in the outputs. all(isinstance(x, xc.ArrayImpl) for x in out_flat)) @@ -1644,58 +1839,51 @@ def cache_miss(*args, **kwargs): cpp_mapped_f = pmap_lib.pmap( fun, cache_miss, static_broadcasted_tuple, - lambda x, s: pxla.shard_args([s], [None], [None], [x])[0], + lambda x, s: pxla.shard_args([s], [None], + [xc.ArrayCopySemantics.REUSE_INPUT], [x])[0], pytree_registry=tree_util.default_registry) _pmap_cache_clears.add(cpp_mapped_f) pmap_f = wraps(fun)(cpp_mapped_f) + # Store some data for the `lower` and `trace` methods pmap_f._fun = fun - - @api_boundary - def lower(*args, **kwargs): - return trace(*args, **kwargs).lower() - - @api_boundary - def trace(*args, **kwargs): - p = _prepare_pmap( - fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple, - devices, backend, axis_size, args, kwargs) - abstract_args = list(map(shaped_abstractify, p.flat_args)) - closed_jaxpr, xc_backend, replicas, shards, pci = pxla.get_pmap_jaxpr( - p.flat_fun, backend, axis_name, - axis_size=p.local_axis_size, global_axis_size=p.global_axis_size, - devices=p.devices, - name=p.flat_fun.__name__, - in_axes=p.in_axes_flat, - out_axes_thunk=p.out_axes_thunk, - avals=abstract_args) - lower_callable = partial( - pxla.lower_parallel_callable, p.flat_fun, axis_name, - axis_size=p.local_axis_size, global_axis_size=p.global_axis_size, - devices=p.devices, - name=p.flat_fun.__name__, - in_axes=p.in_axes_flat, - donated_invars=p.donated_invars, - is_explicit_global_axis_size=p.is_explicit_global_axis_size, - avals=abstract_args, - closed_jaxpr=closed_jaxpr, - backend=xc_backend, - replicas=replicas, - shards=shards, - pci=pci) - args_info = stages.make_args_info(p.in_tree, abstract_args, donate_tuple) - return stages.Traced(closed_jaxpr, args_info, p.flat_fun.__name__, - p.out_tree(), lower_callable) - - pmap_f.lower = lower - pmap_f.trace = trace - + pmap_f._prepare_pmap = prepare_pmap_fn + pmap_f._backend = backend + pmap_f._axis_name = axis_name + pmap_f._donate_tuple = donate_tuple + + # TODO(necula): move these to top-level; we don't need to do this for + # every pmap + cpp_mapped_f_class = type(pmap_f) + cpp_mapped_f_class.lower = _cpp_mapped_lower + # We return directly the function produced by pmap_lib.pmap, because we do not + # want to have Python in the dispatch path. return pmap_f +@api_boundary +def _cpp_mapped_lower(pmap_f, *args, **kwargs): + p = pmap_f._prepare_pmap(args, kwargs) + abstract_args = list(map(shaped_abstractify, p.flat_args)) + closed_jaxpr, xc_backend, replicas, shards, pci = pxla.get_pmap_jaxpr( + p.flat_fun, pmap_f._backend, pmap_f._axis_name, + axis_size=p.local_axis_size, global_axis_size=p.global_axis_size, + devices=p.devices, name=p.flat_fun.__name__, in_axes=p.in_axes_flat, + out_axes_thunk=p.out_axes_thunk, avals=abstract_args) + lowering = pxla.lower_parallel_callable( + p.flat_fun, pmap_f._axis_name, axis_size=p.local_axis_size, + global_axis_size=p.global_axis_size, devices=p.devices, + name=p.flat_fun.__name__, in_axes=p.in_axes_flat, + donated_invars=p.donated_invars, + is_explicit_global_axis_size=p.is_explicit_global_axis_size, + avals=abstract_args, closed_jaxpr=closed_jaxpr, backend=xc_backend, + replicas=replicas, shards=shards, pci=pci, lowering_platforms=None, + lowering_parameters=pxla.mlir.LoweringParameters()) + args_info = stages.make_args_info(p.in_tree, abstract_args, pmap_f._donate_tuple) + return stages.Lowered(lowering, args_info, p.out_tree()) _pmap_cache_clears = weakref.WeakSet() # type: ignore -@api_boundary +@partial(api_boundary, repro_api_name="jax.jvp") def jvp( fun: Callable, primals, tangents, has_aux: bool = False ) -> tuple[Any, ...]: @@ -1753,6 +1941,7 @@ def _jvp(fun: lu.WrappedFun, primals, tangents, has_aux=False): f"structure; primals have tree structure {tree_def} whereas tangents have " f"tree structure {tree_def_2}.") for p, t in zip(ps_flat, ts_flat): + if not isinstance(core.typeof(p), ShapedArray): continue if core.primal_dtype_to_tangent_dtype(_dtype(p)) != _dtype(t): raise TypeError("primal and tangent arguments to jax.jvp do not match; " "dtypes must be equal, or in case of int/bool primal dtype " @@ -1789,6 +1978,7 @@ def linearize(fun: Callable, *primals, has_aux: Literal[True] ) -> tuple[Any, Callable, Any]: ... +@partial(api_boundary, repro_api_name="jax.linearize") def linearize(fun: Callable, *primals, has_aux: bool = False ) -> tuple[Any, Callable] | tuple[Any, Callable, Any]: """Produces a linear approximation to ``fun`` using :py:func:`jvp` and partial eval. @@ -1887,10 +2077,27 @@ def fun(*tangents): for primal_aval, tangent_aval in zip(primal_avals, tangent_avals): expected_tangent_aval = primal_aval.to_tangent_aval() if not core.typecompat(expected_tangent_aval, tangent_aval): - raise ValueError("linearized function called on tangent values inconsistent with " - "the original primal values: " - f"got tangent aval {tangent_aval} for primal aval {primal_aval} " - f"but expected {expected_tangent_aval}") + extra_msg = '' + if (isinstance(primal_aval, core.ShapedArray) and + isinstance(tangent_aval, core.ShapedArray) and + primal_aval.vma != tangent_aval.vma): + pvary_applications = [] + if left := tangent_aval.vma - primal_aval.vma: + pvary_applications.append( + f"applying `jax.lax.pcast(..., {tuple(left)}, to='varying')` to" + " the primal value passed to `jax.linearize`") + if left := primal_aval.vma - tangent_aval.vma: + pvary_applications.append( + f"applying `jax.lax.pcast(..., {tuple(left)}, to='varying')` to" + " the tangent value passed to the callable `f_jvp` returned by" + " `jax.linearize`") + extra_msg = " \nThis might be fixed by:\n" + "\n".join( + f" * {d};" for d in pvary_applications) + raise ValueError( + "linearized function called on tangent values inconsistent with " + "the original primal values:\n" + f"Got tangent aval {tangent_aval} for primal aval {primal_aval} " + f"but expected {expected_tangent_aval}.{extra_msg}") tangents_out = eval_jaxpr(jaxpr, consts, *tangents) tangents_out_ = iter(tangents_out) full_out = [pval.get_known() if pval.is_known() else next(tangents_out_) @@ -1931,7 +2138,7 @@ def _vjp_pullback_wrapper(name, out_primal_avals, io_tree, fun, *py_args_): f"got {in_tree}, but expected to match {in_tree_expected}") for arg, aval in zip(args, out_primal_avals): ct_aval = shaped_abstractify(arg) - ct_aval_expected = aval.to_tangent_aval() + ct_aval_expected = aval.to_cotangent_aval() if (not core.typecompat(ct_aval, ct_aval_expected) and not _temporary_dtype_exception(ct_aval, ct_aval_expected)): raise ValueError( @@ -1960,7 +2167,8 @@ def vjp(fun: Callable[..., tuple[T, U]], *primals: Any, has_aux: Literal[True], reduce_axes: Sequence[AxisName] = ()) -> tuple[T, Callable, U]: ... -@api_boundary + +@partial(api_boundary, repro_api_name="jax.vjp") def vjp( fun: Callable, *primals, has_aux: bool = False, reduce_axes=() ) -> tuple[Any, Callable] | tuple[Any, Callable, Any]: @@ -2011,28 +2219,181 @@ def vjp( fun, debug_info=debug_info("vjp", fun, primals, {})) return _vjp(wrapped_fun, *primals, has_aux=has_aux) -def _vjp(fun: lu.WrappedFun, *primals, has_aux=False): - """Variant of vjp() that takes an lu.WrappedFun.""" +def _vjp(fun, *primals, has_aux=False): + canon = lambda x: x if isinstance(x, core.Tracer) else canonicalize_value(x) + primals = tree_map(canon, primals) primals_flat, in_tree = tree_flatten(primals) - for arg in primals_flat: dispatch.check_arg(arg) + for arg in primals_flat: + dispatch.check_arg(arg) if not has_aux: flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree) - out_primals, vjp = ad.vjp(flat_fun, primals_flat) + out_primals_flat, out_pvals, jaxpr, residuals = ad.linearize( + flat_fun, *primals_flat) out_tree = out_tree() else: flat_fun, out_aux_trees = flatten_fun_nokwargs2(fun, in_tree) - out_primals, vjp, aux = ad.vjp(flat_fun, primals_flat, has_aux=True) + out_primals_flat, out_pvals, jaxpr, residuals, aux = ad.linearize( + flat_fun, *primals_flat, has_aux=True) out_tree, aux_tree = out_aux_trees() - out_primal_avals = map(shaped_abstractify, out_primals) - out_primal_py = tree_unflatten(out_tree, out_primals) - vjp_py = Partial(partial(_vjp_pullback_wrapper, fun.__name__, - out_primal_avals, (out_tree, in_tree)), vjp) + del out_aux_trees + out_known = [pval.is_known() for pval in out_pvals] + id_map = {id(x): i for i, x in enumerate(primals_flat)} + used, opaque_residuals = set(), [] + spec = [used.add(id(r)) or RSpec(id_map[id(r)], True) if id(r) in id_map else # type: ignore + RSpec(opaque_residuals.append(r) or (len(opaque_residuals) - 1), False) # type: ignore + for r in residuals] + args_res = tuptree_map(lambda x: x if id(x) in used else NotNeeded(), + in_tree, primals_flat) + out_primal_avals = [typeof(x) for x in out_primals_flat] + f_vjp = VJP(partial(_vjp3_callable, spec, out_known, jaxpr, out_primal_avals), + in_tree, out_tree, list(args_res), opaque_residuals) + out_primals = tree_unflatten(out_tree, out_primals_flat) if not has_aux: - return out_primal_py, vjp_py + return out_primals, f_vjp + else: + return out_primals, f_vjp, tree_unflatten(aux_tree, aux) + +def _vjp3_callable(spec, out_known, jaxpr, out_primal_avals, in_tree, out_tree, + args_res, opaque_res, *maybe_ct_refs): + if not maybe_ct_refs: + maybe_ct_refs_flat = [GradValue()] * in_tree.num_leaves else: - return out_primal_py, vjp_py, tree_unflatten(aux_tree, aux) + maybe_ct_refs_flat, in_tree_ = tree_flatten(maybe_ct_refs) + if in_tree != in_tree_: + raise Exception # TODO accept isomorph tuple tree + args_res_ = tree_leaves(args_res, is_leaf=lambda x: isinstance(x, NotNeeded)) + residuals = [args_res_[i.idx] if i.primal else opaque_res[i.idx] for i in spec] + maybe_refs = [ad.RefAccum(v.aval, x) if _is_ref(x) else ad.ValAccum(v.aval) + for v, x in zip(jaxpr.invars, maybe_ct_refs_flat)] + return Partial(partial(_vjp3_bwd, in_tree, out_tree, out_known, jaxpr, + out_primal_avals), residuals, maybe_refs) + +def _vjp3_bwd(in_tree, out_tree, out_known, jaxpr, out_primal_avals, residuals, + maybe_refs, out_ct): + cts_flat, out_tree_ = tree_flatten(out_ct) + if out_tree != out_tree_: + _vjp_ct_tree_error(jaxpr, out_tree, out_tree_) + _vjp_check_ct_avals(cts_flat, out_primal_avals) + cts_flat = [ct for ct, k in zip(cts_flat, out_known) if not k] + ad.backward_pass3(jaxpr, True, residuals, maybe_refs, cts_flat) + arg_cts = [x.freeze() if isinstance(x, ad.ValAccum) else GradRef() + for x in maybe_refs] + arg_cts = map(ad.instantiate_zeros, arg_cts) + return tree_unflatten(in_tree, arg_cts) + + +@dataclasses.dataclass(frozen=True) +class RSpec: + idx: int + primal: bool + +def tuptree_map(f, treedef, x): + return treedef.walk(lambda xs, _: tuple(xs), f, x) + +def _is_ref(x): + from jax._src.state.types import AbstractRef + try: + return isinstance(typeof(x), AbstractRef) + except: + return False + + +_vjp_too_many_args = """ +The function returned by `jax.vjp` applied to {} was called with {} arguments, +but functions returned by `jax.vjp` must be called with a single argument +corresponding to the single value returned by the function being differentiated +(even if that returned value is a tuple or other container). + +For example, if we have: + + def f(x): + return (x, x) + _, f_vjp = jax.vjp(f, 1.0) + +the function `f` returns a single tuple as output, and so we call `f_vjp` with a +single tuple as its argument: + x_bar, = f_vjp((2.0, 2.0)) +If we instead call `f_vjp(2.0, 2.0)`, with the values 'splatted out' as +arguments rather than in a tuple, this error can arise. +""".format + + +def _vjp_ct_tree_error(jaxpr, out_tree, ct_tree): + msg = f"""unexpected tree structure. + +The argument to a VJP function returned by `jax.vjp` must match the pytree +structure of the differentiated function {jaxpr.debug_info.func_src_info}. + +But the tree structures differ: +""" + msg += '\n'.join(f" * out{keystr(path)} was a {thing1} in the original " + f" output, but a {thing2} here, so {explanation}." + for path, thing1, thing2, explanation + in equality_errors_pytreedef(out_tree, ct_tree)) + raise ValueError(msg) + + +def _vjp_check_ct_avals(cts, primal_avals): + # TODO(mattjj): improve this error by flattening with keys in the first place + for ct, aval in zip(cts, primal_avals): + ct_aval = typeof(ct) + ct_aval_expected = aval.to_cotangent_aval() + if (not core.typecompat(ct_aval, ct_aval_expected) and + not _temporary_dtype_exception(ct_aval, ct_aval_expected)): + raise ValueError( + "unexpected JAX type (e.g. shape/dtype) for argument to VJP function: " + f"got {ct_aval.str_short()}, but expected {ct_aval_expected.str_short()} " + "because the corresponding output of the differentiated function had JAX type " + f"{aval.str_short()}") + + +@register_dataclass +@dataclasses.dataclass(frozen=True) +class NotNeeded: + pass + +@dataclasses.dataclass(frozen=True) +class GradValue: + pass + +@dataclasses.dataclass(frozen=True) +class GradRef: + pass + +@dataclasses.dataclass +class VJP: + fun: Callable + in_tree: PyTreeDef + out_tree: PyTreeDef + args_res: list[Any] + opaque_residuals: list[Any] + jaxpr = property(lambda self: self.fun.args[2]) # type: ignore + + def __call__(self, out_ct, *extra_args): + if extra_args: + name, *_ = self.jaxpr.debug_info.func_src_info.split(' ') + raise TypeError(_vjp_too_many_args(name, len(extra_args))) + return self.fun(self.in_tree, self.out_tree, self.args_res, + self.opaque_residuals)(out_ct) + + def with_refs(self, *maybe_ct_refs): + return self.fun(self.in_tree, self.out_tree, self.args_res, + self.opaque_residuals, *maybe_ct_refs) + + # Only safe to put these in cache keys if residuals aren't mutated. Beware! + __hash__ = object.__hash__ + __eq__ = object.__eq__ + +register_pytree_node( + VJP, + lambda vjp: ((vjp.args_res, vjp.opaque_residuals), + (vjp.fun, vjp.in_tree, vjp.out_tree)), + lambda meta, args_res: VJP(*meta, *args_res)) + + +@partial(api_boundary, repro_api_name="jax.linear_transpose") def linear_transpose(fun: Callable, *primals, reduce_axes=()) -> Callable: """Transpose a function that is promised to be linear. @@ -2078,14 +2439,14 @@ def linear_transpose(fun: Callable, *primals, reduce_axes=()) -> Callable: debug_info=debug_info("linear_transpose", fun, primals, {})), in_tree) in_avals = map(shaped_abstractify, primals_flat) - in_dtypes = map(dtypes.dtype, in_avals) + in_dtypes = map(lambda a: a.dtype, in_avals) in_pvals = map(pe.PartialVal.unknown, in_avals) jaxpr, out_pvals, const = pe.trace_to_jaxpr_nounits(flat_fun, in_pvals, instantiate=True) jaxpr, _ = pe.dce_jaxpr(jaxpr, [True] * len(jaxpr.outvars), True) out_avals, _ = unzip2(out_pvals) - out_dtypes = map(dtypes.dtype, out_avals) + out_dtypes = map(lambda a: a.dtype, out_avals) if not (all(dtypes.issubdtype(d, np.inexact) for d in in_dtypes + out_dtypes) or all(dtypes.issubdtype(d, np.integer) for d in in_dtypes + out_dtypes)): @@ -2111,22 +2472,12 @@ def transposed_fun(const, out_cotangent): return Partial(transposed_fun, const) -def _flat_axes_specs(abstracted_axes, *args, **kwargs - ) -> list[pe.AbstractedAxesSpec]: - if kwargs: raise NotImplementedError - def ax_leaf(l): - return (isinstance(l, dict) and all_leaves(l.values()) or - isinstance(l, tuple) and all_leaves(l, lambda x: x is None)) - return broadcast_prefix(abstracted_axes, args, ax_leaf) - - @overload def make_jaxpr( fun: Callable, static_argnums: int | Iterable[int] = (), axis_env: Sequence[tuple[AxisName, int]] | None = None, return_shape: Literal[False] = ..., - abstracted_axes: Any | None = None, ) -> Callable[..., core.ClosedJaxpr]: ... @@ -2136,18 +2487,17 @@ def make_jaxpr( static_argnums: int | Iterable[int] = (), axis_env: Sequence[tuple[AxisName, int]] | None = None, return_shape: Literal[True] = ..., - abstracted_axes: Any | None = None, ) -> Callable[..., tuple[core.ClosedJaxpr, Any]]: ... +@partial(api_boundary, repro_api_name="jax.make_japr") def make_jaxpr( fun: Callable, static_argnums: int | Iterable[int] = (), axis_env: Sequence[tuple[AxisName, int]] | None = None, return_shape: bool = False, - abstracted_axes: Any | None = None, ) -> Callable[..., core.ClosedJaxpr | tuple[core.ClosedJaxpr, Any]]: - """Creates a function that produces its jaxpr given example args. + """Create a function that returns the jaxpr of ``fun`` given example args. Args: fun: The function whose ``jaxpr`` is to be computed. Its positional @@ -2198,7 +2548,7 @@ def make_jaxpr( c:f32[] = sin a _:f32[] = sin b d:f32[] = cos b - e:f32[] = mul 1.0 d + e:f32[] = mul 1.0:f32[] d f:f32[] = neg e g:f32[] = mul f c in (g,) } @@ -2213,15 +2563,13 @@ def make_jaxpr( @api_boundary def make_jaxpr_f(*args, **kwargs): with core.extend_axis_env_nd(axis_env or []): - traced = jit(fun, static_argnums=static_argnums, - abstracted_axes=abstracted_axes).trace(*args, **kwargs) - # `jit` converts tracers in consts to args but that breaks the semantics of - # `make_jaxpr`. Hence convert the tracers in args back to consts in jaxpr. - if traced._num_consts: - consts, _ = split_list(traced._args_flat, [traced._num_consts]) - jaxpr_ = pe.convert_invars_to_constvars(traced.jaxpr.jaxpr, - traced._num_consts) - jaxpr = core.ClosedJaxpr(jaxpr_, consts) + traced = jit(fun, static_argnums=static_argnums).trace(*args, **kwargs) + # `jit` converts tracers in consts to args but `make_jaxpr` callers expect + # consts not to be converted. + num_consts = traced._num_consts + if num_consts: + jaxpr_ = pe.convert_invars_to_constvars(traced.jaxpr.jaxpr, num_consts) + jaxpr = core.ClosedJaxpr(jaxpr_, traced._consts) else: jaxpr = traced.jaxpr if return_shape: @@ -2248,7 +2596,7 @@ def _infer_src_sharding(src, x) -> Sharding | None: return None -@lru_cache(maxsize=2048) +@util.cache(max_size=2048, trace_context_in_key=False) def _check_string_compatible_sharding(s): """Checks if target devices are compatible with string arrays.""" if isinstance(s, xc.Device) and s.device_kind == "cpu": @@ -2260,17 +2608,15 @@ def _check_string_compatible_sharding(s): "String arrays can only be sharded to CPU devices. Received" f" unsupported device or sharding: {s}") -# TODO(yashkatariya): Generalize check_compatible_aval (maybe renamed) and use -# that to check if shardings are compatible with the input. -@lru_cache(maxsize=2048) + +@util.cache(max_size=2048, trace_context_in_key=False) def _check_sharding(aval, s): if (s is not None and - not isinstance(s, (xc.Device, Sharding, Layout, TransferToMemoryKind))): + not isinstance(s, (xc.Device, Sharding, Format, core.MemorySpace))): raise ValueError( "`jax.device_put` only accepts `None`, `jax.sharding.Sharding`," - " `jax.Device`, `Layout` or a pytree of these values. Received" - f" invalid value: {s}") - + " `jax.Device`, `Format`, `jax.memory.Space` or a pytree of these" + f" values. Received invalid value: {s}") if isinstance(aval, core.ShapedArray) and dtypes.is_string_dtype(aval.dtype): _check_string_compatible_sharding(s) @@ -2282,20 +2628,20 @@ def _check_sharding(aval, s): (s,), (aval,), ("",), "device_put args", allow_uneven_sharding=False) s.shard_shape(aval.shape) # should raise an Error if incompatible -def pspec_to_sharding(val): +def pspec_to_sharding(name, val): if isinstance(val, P): mesh = get_concrete_mesh() - if mesh is None: + if mesh.empty: raise ValueError( - "Please set a mesh via `jax.sharding.use_mesh` if a PartitionSpec is" - " passed to device_put") + "Please set a mesh via `jax.set_mesh` if a PartitionSpec is" + f" passed to {name}") return NamedSharding(mesh, val) return val def device_put( x, - device: None | xc.Device | Sharding | P | Layout | Any | TransferToMemoryKind = None, - *, src: None | xc.Device | Sharding | P | Layout | Any | TransferToMemoryKind = None, + device: None | xc.Device | Sharding | P | Format | Any = None, + *, src: None | xc.Device | Sharding | P | Format | Any = None, donate: bool | Any = False, may_alias: bool | None | Any = None): """Transfers ``x`` to ``device``. @@ -2332,20 +2678,20 @@ def device_put( with config.explicit_device_put_scope(): x_flat, treedef = tree_flatten(x) if (device is None or - isinstance(device, (xc.Device, Sharding, TransferToMemoryKind))): + isinstance(device, (xc.Device, Sharding, core.MemorySpace))): device_flat = [device] * len(x_flat) else: device_flat = flatten_axes("device_put device", treedef, device) if (src is None or - isinstance(src, (xc.Device, Sharding, TransferToMemoryKind))): + isinstance(src, (xc.Device, Sharding, core.MemorySpace))): src_flat = [_infer_src_sharding(src, xf) for xf in x_flat] else: src_flat = flatten_axes("device_put source", treedef, src) src_flat = list(map(_infer_src_sharding, src_flat, x_flat)) - device_flat = map(pspec_to_sharding, device_flat) - src_flat = map(pspec_to_sharding, src_flat) + device_flat = map(partial(pspec_to_sharding, 'device_put'), device_flat) + src_flat = map(partial(pspec_to_sharding, 'device_put'), src_flat) if isinstance(donate, bool): donate_flat = [donate] * len(x_flat) @@ -2364,18 +2710,27 @@ def device_put( if m is None: m = not d if m and not d: - copy_semantics.append(dispatch.CopySemantics.ALIAS) + copy_semantics.append(dispatch.ArrayCopySemantics.REUSE_INPUT) elif not m and d: - copy_semantics.append(dispatch.CopySemantics.DONATE) + copy_semantics.append(dispatch.ArrayCopySemantics.DONATE_INPUT) else: assert not m and not d - copy_semantics.append(dispatch.CopySemantics.COPY) + copy_semantics.append(dispatch.ArrayCopySemantics.ALWAYS_COPY) + dst_avals = [] for xf, d in zip(x_flat, device_flat): - _check_sharding(shaped_abstractify(xf), d) - out_flat = dispatch.device_put_p.bind( - *x_flat, devices=device_flat, srcs=src_flat, - copy_semantics=copy_semantics) + aval = shaped_abstractify(xf) + aval = dispatch.update_dp_aval(aval, d) + dst_avals.append(aval) + _check_sharding(aval, d) + if core.trace_state_clean(): + out_flat = dispatch._batched_device_put_impl( + *x_flat, devices=device_flat, srcs=src_flat, # type: ignore + copy_semantics=copy_semantics, dst_avals=dst_avals) + else: + out_flat = dispatch.device_put_p.bind( + *x_flat, devices=tuple(device_flat), srcs=tuple(src_flat), + copy_semantics=tuple(copy_semantics)) return tree_unflatten(treedef, out_flat) @@ -2403,8 +2758,8 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]): # >>> import jax >>> devices = jax.local_devices() >>> x = [jax.numpy.ones(5) for device in devices] - >>> y = jax.device_put_sharded(x, devices) - >>> np.allclose(y, jax.numpy.stack(x)) + >>> y = jax.device_put_sharded(x, devices) # doctest: +SKIP + >>> np.allclose(y, jax.numpy.stack(x)) # doctest: +SKIP True Passing a list of nested container objects with arrays at the leaves for @@ -2412,14 +2767,14 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]): # all entries in the list to have the same tree structure: >>> x = [(i, jax.numpy.arange(i, i + 4)) for i in range(len(devices))] - >>> y = jax.device_put_sharded(x, devices) - >>> type(y) + >>> y = jax.device_put_sharded(x, devices) # doctest: +SKIP + >>> type(y) # doctest: +SKIP - >>> y0 = jax.device_put_sharded([a for a, b in x], devices) - >>> y1 = jax.device_put_sharded([b for a, b in x], devices) - >>> np.allclose(y[0], y0) + >>> y0 = jax.device_put_sharded([a for a, b in x], devices) # doctest: +SKIP + >>> y1 = jax.device_put_sharded([b for a, b in x], devices) # doctest: +SKIP + >>> np.allclose(y[0], y0) # doctest: +SKIP True - >>> np.allclose(y[1], y1) + >>> np.allclose(y[1], y1) # doctest: +SKIP True See Also: @@ -2443,18 +2798,19 @@ def _device_put_sharded(*xs): raise ValueError("the shards passed to device_put_sharded must have " f"consistent shape and dtype, but got {a1} and {a2}.") stacked_aval = avals[0].update(shape=(len(devices),) + avals[0].shape) - sharding_spec = sharding_specs.create_pmap_sharding_spec(stacked_aval.shape) - sharding = PmapSharding(np.array(devices), sharding_spec) + if config.pmap_shmap_merge.value: + mesh = Mesh(np.array(devices), ('_device_put_sharded',)) + sharding = NamedSharding(mesh, P('_device_put_sharded')) + else: + sharding_spec = sharding_specs.create_pmap_sharding_spec(stacked_aval.shape) + sharding = PmapSharding(np.array(devices), sharding_spec) if dtypes.issubdtype(stacked_aval.dtype, dtypes.extended): return stacked_aval.dtype._rules.device_put_sharded(xs, stacked_aval, sharding, devices) - if config.pmap_no_rank_reduction.value: - ys = [] - for x in xs: - if not isinstance(x, (np.ndarray, basearray.Array)): - x = np.asarray(x) - ys.append(x[None]) - else: - ys = xs + ys = [] + for x in xs: + if not isinstance(x, (np.ndarray, basearray.Array)): + x = np.asarray(x) + ys.append(x[None]) return pxla.batched_device_put(stacked_aval, sharding, ys, list(devices)) @@ -2485,8 +2841,8 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]): # noqa: F811 >>> import jax >>> devices = jax.local_devices() >>> x = jax.numpy.array([1., 2., 3.]) - >>> y = jax.device_put_replicated(x, devices) - >>> np.allclose(y, jax.numpy.stack([x for _ in devices])) + >>> y = jax.device_put_replicated(x, devices) # doctest: +SKIP + >>> np.allclose(y, jax.numpy.stack([x for _ in devices])) # doctest: +SKIP True See Also: @@ -2499,18 +2855,18 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]): # noqa: F811 def _device_put_replicated(x): aval = core.unmapped_aval(len(devices), 0, core.get_aval(x)) assert isinstance(aval, ShapedArray) - sharding_spec = sharding_specs.create_pmap_sharding_spec(aval.shape) - if config.pmap_no_rank_reduction.value: - if isinstance(x, (np.ndarray, basearray.Array)): - buf = device_put(x[None], devices[0]) - else: - buf = device_put(x, devices[0])[None] + if isinstance(x, (np.ndarray, basearray.Array)): + buf = device_put(x[None], devices[0]) + else: + buf = device_put(x, devices[0])[None] + if config.pmap_shmap_merge.value: + mesh = Mesh(np.array(devices), ('_device_put_replicated',)) + sharding = NamedSharding(mesh, P('_device_put_replicated')) else: - buf = device_put(x, devices[0]) - sharding = PmapSharding(np.array(devices), sharding_spec) + sharding_spec = sharding_specs.create_pmap_sharding_spec(aval.shape) + sharding = PmapSharding(np.array(devices), sharding_spec) if dtypes.issubdtype(aval.dtype, dtypes.extended): return aval.dtype._rules.device_put_replicated(buf, aval, sharding, devices) - assert len(xla.aval_to_xla_shapes(aval)) == 1 return pxla.batched_device_put(aval, sharding, [buf] * len(devices), devices) with config.explicit_device_put_scope(): @@ -2575,80 +2931,7 @@ def device_get(x: Any): return tree_map(_device_get, x) -class ShapeDtypeStruct: - """A container for the shape, dtype, and other static attributes of an array. - - ``ShapeDtypeStruct`` is often used in conjunction with :func:`jax.eval_shape`. - - Args: - shape: a sequence of integers representing an array shape - dtype: a dtype-like object - sharding: (optional) a :class:`jax.Sharding` object - """ - __slots__ = ["shape", "dtype", "sharding", "_dll", "weak_type"] - - def __init__(self, shape, dtype, *, sharding=None, weak_type=False): - self.shape = tuple(shape) - if dtype is None: - raise ValueError("ShapeDtypeStruct: dtype must be specified.") - self.dtype = dtype if dtypes.issubdtype(dtype, dtypes.extended) else np.dtype(dtype) - if sharding is not None and not isinstance(sharding, (Sharding, Layout)): - raise ValueError( - "sharding should be an instance of `jax.sharding.Sharding` or" - f" `jax.experimental.layout.Layout`. Got {sharding} of type" - f" {type(sharding)}.") - if (isinstance(sharding, Layout) and - isinstance(sharding.device_local_layout, AutoLayout)): - raise TypeError( - "`DeviceLocalLayout.AUTO` cannot be used in place of a device-local" - f" layout in a `ShapeDtypeStruct`. Got {sharding}") - self.sharding = sharding.sharding if isinstance(sharding, Layout) else sharding - self._dll = sharding.device_local_layout if isinstance(sharding, Layout) else None - self.weak_type = weak_type - - size = property(lambda self: math.prod(self.shape)) - ndim = property(lambda self: len(self.shape)) - - @property - def layout(self): - return Layout(self._dll, self.sharding) - - def __len__(self): - try: - return self.shape[0] - except IndexError as e: - raise TypeError("len() of unsized object") from e # same as numpy error - - def __repr__(self): - sh = f", sharding={self.sharding}" if self.sharding is not None else "" - l = f", layout={self.layout}" if self._dll is not None else "" - wt = f", weak_type={self.weak_type}" if self.weak_type else "" - return (f"{type(self).__name__}(shape={self.shape}, " - f"dtype={self.dtype.name}{sh}{l}{wt})") - - __str__ = __repr__ - - def __eq__(self, other): - if not isinstance(other, ShapeDtypeStruct): - return False - else: - return ((self.shape, self.dtype, self.sharding, self.layout, self.weak_type) == - (other.shape, other.dtype, other.sharding, other.layout, other.weak_type)) - - def __hash__(self): - # TODO(frostig): avoid the conversion from dict by addressing - # https://github.com/jax-ml/jax/issues/8182 - return hash((self.shape, self.dtype, self.sharding, self.layout, self.weak_type)) - -def _sds_aval_mapping(x): - aval = ShapedArray( - x.shape, dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True), - weak_type=x.weak_type) - return core.update_aval_with_sharding(aval, x.sharding) -core.pytype_aval_mappings[ShapeDtypeStruct] = _sds_aval_mapping - - -@api_boundary +@partial(api_boundary, repro_api_name="jax.eval_shape") def eval_shape(fun: Callable, *args, **kwargs): """Compute the shape/dtype of ``fun`` without any FLOPs. @@ -2714,11 +2997,14 @@ def eval_shape(fun, *args, **kwargs): >>> print(out.dtype) float32 """ + if type(fun) is xc._xla.PjitFunction: + return fun.trace(*args, **kwargs).out_info # type: ignore try: hash(fun) except TypeError: fun = partial(fun) - return jit(fun).eval_shape(*args, **kwargs) + return jit(fun).trace(*args, **kwargs).out_info +@partial(api_boundary, repro_api_name="jax.named_call") def named_call( fun: F, *, @@ -2745,7 +3031,7 @@ def named_call( within the name scope. Use the fun.__name__ if not specified. Returns: - A version of `fun` that is wrapped in a name_scope. + A version of ``fun`` that is wrapped in a ``named_scope``. """ if name is None: name = fun.__name__ @@ -2870,21 +3156,16 @@ def copy_to_host_async(x): return x + def clear_backends(): """ Clear all backend clients so that new backend clients can be created later. """ xb._clear_backends() - xb.local_devices.cache_clear() - xb.process_count.cache_clear() - dispatch.xla_primitive_callable.cache_clear() util.clear_all_caches() - pjit._infer_params_cached.cache_clear() - pjit._create_pjit_jaxpr.cache_clear() # pytype: disable=attribute-error pjit._cpp_pjit_cache_fun_only.clear() pjit._cpp_pjit_cache_explicit_attributes.clear() xc._xla.PjitFunctionCache.clear_all() - xc._xla.jax_jit.thread_local_state().extra_jit_context = None @atexit.register def clean_up(): @@ -2911,17 +3192,11 @@ def clear_caches(): # Clear all lu.cache, util.cache and util.weakref_lru_cache instances # (used for staging and Python-dispatch compiled executable caches). util.clear_all_caches() - util.clear_all_weakref_lru_caches() - # Clear all C++ compiled executable caches for pjit pjit._cpp_pjit_cache_fun_only.clear() pjit._cpp_pjit_cache_explicit_attributes.clear() - pjit._infer_params_cached.cache_clear() xc._xla.PjitFunctionCache.clear_all() # Clear all C++ compiled executable caches for pmap for fun in _pmap_cache_clears: fun._cache_clear() - - # Clear particular util.cache instances. - dispatch.xla_primitive_callable.cache_clear() diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index a42141b96fbd..1f62ed7b1a78 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -26,17 +26,18 @@ from jax._src import dtypes from jax._src.state.types import AbstractRef from jax._src.tree_util import ( - PyTreeDef, tree_flatten, tree_unflatten, tree_map, - treedef_children, generate_key_paths, broadcast_prefix, - prefix_errors) -from jax._src.tree_util import _replace_nones + PyTreeDef, tree_flatten, tree_unflatten, treedef_children, + generate_key_paths, broadcast_prefix, prefix_errors, + none_leaf_registry, broadcast_flattened_prefix_with_treedef) from jax._src import linear_util as lu from jax._src.util import (safe_map, WrapKwArgs, Hashable, HashableFunction, Unhashable, safe_zip) from jax._src import traceback_util + traceback_util.register_exclusion(__file__) -map = safe_map +map, unsafe_map = safe_map, map +zip, unsafe_zip = safe_zip, zip def _ensure_index(x: Any) -> int | tuple[int, ...]: """Ensure x is either an index or a tuple of indices.""" @@ -201,9 +202,11 @@ def _validate_argnames( f"in {argnames_name}. Function does not take these args.") -def argnums_partial(f, dyn_argnums, args, require_static_args_hashable=True): +def argnums_partial(f: lu.WrappedFun, dyn_argnums: int | Sequence[int], + args: Sequence, require_static_args_hashable=True): dyn_argnums = _ensure_index_tuple(dyn_argnums) dyn_argnums = _ensure_inbounds(False, len(args), dyn_argnums) + fixed_args: list if require_static_args_hashable: fixed_args = [] for i, arg in enumerate(args): @@ -246,18 +249,22 @@ def _ensure_inbounds(allow_invalid: bool, num_args: int, argnums: Sequence[int] result.append(i % num_args) # Resolve negative return tuple(result) +def _split_args(static_argnums, args, allow_invalid): + static_argnums = _ensure_inbounds(allow_invalid, len(args), static_argnums) + dyn_argnums = tuple(i for i in range(len(args)) if i not in static_argnums) + dyn_args = tuple(args[i] for i in dyn_argnums) + return static_argnums, dyn_argnums, dyn_args def argnums_partial_except(f: lu.WrappedFun, static_argnums: tuple[int, ...], args: tuple[Any, ...], *, allow_invalid: bool): "Version of ``argnums_partial`` that checks hashability of static_argnums." if not static_argnums: return f, args - static_argnums = _ensure_inbounds(allow_invalid, len(args), static_argnums) - dyn_argnums = tuple(i for i in range(len(args)) if i not in static_argnums) - dyn_args = tuple(args[i] for i in dyn_argnums) + static_argnums, dyn_argnums, dyn_args = _split_args( + static_argnums, args, allow_invalid) fixed_args = [] - for i in static_argnums: + for i in sorted(static_argnums): # TODO(shoyer): set allow_invalid=True permanently after static_argnames. if allow_invalid and i >= len(args): continue @@ -273,7 +280,9 @@ def argnums_partial_except(f: lu.WrappedFun, static_argnums: tuple[int, ...], return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args @lu.transformation2 -def _argnums_partial(_fun, _dyn_argnums, _fixed_args, *dyn_args, **kwargs): +def _argnums_partial(_fun: Callable, + _dyn_argnums: Sequence[int], + _fixed_args: Sequence, *dyn_args, **kwargs): sentinel = object() args = [sentinel] * (len(_fixed_args) + len(dyn_args)) for i, arg in zip(_dyn_argnums, dyn_args): @@ -334,7 +343,7 @@ def donation_vector(donate_argnums, donate_argnames, in_tree, donate = bool(i in donate_argnums) res.extend((donate,) * arg.num_leaves) if kwargs_tree is not None: - for key, val in safe_zip(kwargs_tree.node_data()[1], kwargs_tree.children()): # type: ignore + for key, val in zip(kwargs_tree.node_data()[1], kwargs_tree.children()): # type: ignore donate = key in donate_argnames res.extend((donate,) * val.num_leaves) return tuple(res) @@ -390,13 +399,10 @@ def flatten_axes(name, treedef, axis_tree, *, kws=False, tupled_args=False): # leaves, i.e. the Nones are to be considered leaves) that is a tree prefix of # the given treedef, build a complete axis spec tree with the same structure # and return the flattened result - # TODO(mattjj,phawkins): improve this implementation - proxy = object() - dummy = tree_unflatten(treedef, [SENTINEL] * treedef.num_leaves) - axes = [] - add_leaves = lambda i, x: axes.extend([i] * len(tree_flatten(x)[0])) + axis_tree_leaves, axis_treedef = none_leaf_registry.flatten(axis_tree) try: - tree_map(add_leaves, _replace_nones(proxy, axis_tree), dummy) + axes = broadcast_flattened_prefix_with_treedef( + axis_tree_leaves, axis_treedef, treedef) except ValueError: if kws: # if keyword arguments are included in the tree, we make adapt the error @@ -419,7 +425,6 @@ def flatten_axes(name, treedef, axis_tree, *, kws=False, tupled_args=False): raise ValueError(f"{name} specification must be a tree prefix of the " f"corresponding value, got specification {axis_tree} " f"for value tree {treedef}.{hint}") from None - axes = [None if a is proxy else a for a in axes] assert len(axes) == treedef.num_leaves return axes @@ -504,9 +509,6 @@ def resolve_argnums( * fills in any missing pieces (e.g., names given numbers, or vice versa), * validates the argument names/numbers against the function signature, * validates that donated and static arguments don't intersect. - * rebases the donated arguments so they index into the dynamic arguments, - (after static arguments have been removed), in the order that parameters - are passed into the compiled function. """ if signature is None: # Some built-in functions don't support signature. @@ -540,7 +542,6 @@ def resolve_argnums( # Compensate for static argnums absorbing args _assert_no_intersection(static_argnames, donate_argnames) - donate_argnums = rebase_donate_argnums(donate_argnums, static_argnums) return donate_argnums, donate_argnames, static_argnums, static_argnames @@ -594,20 +595,23 @@ def debug_info( *, static_argnums: Sequence[int] = (), static_argnames: Sequence[str] = (), - result_paths_thunk: Callable[[], tuple[str, ...]] | None = None, + result_paths_thunk: Callable[[], tuple[str, ...]] | core.InitialResultPaths = core.initial_result_paths, # TODO(necula): check if we really need this, e.g., to speed up tracing? sourceinfo: str | None = None, signature: inspect.Signature | None = None, ) -> core.DebugInfo: - """Constructd core.DebugInfo for a function given example args and kwargs. + """Construct core.DebugInfo for a function given example args and kwargs. - `args` and `kwargs` are example positional and keyword arguments, users with - `inspect.Signature` to get the names of argments. The arguments that are + `args` and `kwargs` are example positional and keyword arguments, used with + `inspect.Signature` to get the names of arguments. The arguments that are considered static for tracing purposes should be included, and designated using `static_argnums` and `static_argnames`. See docstring for linear_util.DebugInfo. """ + res = getattr(fun, "__fun_debug_info__", None) + if res is not None: + return res if sourceinfo is None: sourceinfo = fun_sourceinfo(fun) if signature is None: @@ -623,25 +627,15 @@ def fun_signature(fun: Callable) -> inspect.Signature | None: except (ValueError, TypeError): return None -def save_wrapped_fun_sourceinfo(wrapper: Callable, - wrapped: Callable | core.DebugInfo) -> None: - # Prefer this to functools.wraps because it does not create a reference to - # the wrapped function. - if isinstance(wrapped, core.DebugInfo): - func_src_info = wrapped.func_src_info - elif callable(wrapped): - func_src_info = fun_sourceinfo(wrapped) - else: - assert False, wrapped # Unreachable - setattr(wrapper, "__fun_sourceinfo__", func_src_info) +def save_wrapped_fun_debug_info(wrapper: Callable, + dbg: core.DebugInfo) -> None: + setattr(wrapper, "__fun_debug_info__", dbg) _fun_name_re = re.compile(r"(?:)") # TODO(mattjj): make this function internal to this module def fun_sourceinfo(fun: Callable) -> str: # See DebugInfo.fun_src_info - res = getattr(fun, "__fun_sourceinfo__", None) - if res is not None: return res while isinstance(fun, partial): fun = fun.func fun = inspect.unwrap(fun) @@ -671,40 +665,47 @@ def _non_static_arg_names(fn_signature: inspect.Signature | None, If the `fn_signature` is given then we get from it the names of the top-level arguments. In other cases, including when the `args` and `kwargs` - do not match the signature, we use names like `args[0[]`, `args[1]`, etc. + do not match the signature, we use names like `args[0]`, `args[1]`, etc. """ + # Use the same argument parsing as jit: positional followed by kwargs + # sorted by keys. static = object() static_argnums_ = _ensure_inbounds(True, len(args), static_argnums) static_argnames_ = set(static_argnames) args_ = [static if i in static_argnums_ else x for i, x in enumerate(args)] - kwargs_ = {k:static if k in static_argnames_ else x for k, x in kwargs.items()} + kwargs_ = {k: static if k in static_argnames_ else x for k, x in kwargs.items()} + ordered_args: Sequence[tuple[str, Any]] | None = None if fn_signature is not None: try: ba = fn_signature.bind(*args_, **kwargs_) except (ValueError, TypeError): pass else: - return tuple(f'{name}{lu._clean_keystr_arg_names(path)}' - for name, x in ba.arguments.items() - for path, l in generate_key_paths(x) if l is not static) - args_arg_names = tuple(f'args{lu._clean_keystr_arg_names(path)}' - for path, l in generate_key_paths(args_) - if l is not static) - kwargs_arg_names = tuple(f'kwargs{lu._clean_keystr_arg_names(path)}' - for path, l in generate_key_paths(kwargs_) - if l is not static) - arg_names = args_arg_names + kwargs_arg_names - return arg_names - -def hoist_obj_attrs(f, flat_args): - idxs, objs, flat_args_ = [], [], [] - for i, x in enumerate(flat_args): - if type(x) in _class_with_attrs: - objs.append(_HashableByObjectId(x)) - else: - idxs.append(i) - flat_args_.append(x) - return _argnums_partial(f, tuple(idxs), tuple(objs)), flat_args_ + # Do we have a **kwargs + kwargs_name = next((name for name, p in fn_signature.parameters.items() + if p.kind == inspect.Parameter.VAR_KEYWORD), None) + # Positional argument are those not passed by keyword and not passed + # by **kwargs. + positional = [(name, x) for name, x in ba.arguments.items() + if name not in kwargs and name != kwargs_name] + # Keyword arguments are passed sorted by actual kwarg keyword + sorted_kwargs = sorted(((name, x) for name, x in kwargs_.items()), + key=lambda name_x: name_x[0]) + sorted_kwargs = [(name if name in ba.arguments else f"{kwargs_name}['{name}']", + x) + for name, x in sorted_kwargs] + ordered_args = positional + sorted_kwargs + + if ordered_args is None: + positional = [("args", args_)] + keyword = sorted([(f"kwargs['{name}']", x) for name, x in kwargs_.items() if x is not static], + key=lambda name_x: name_x[0]) + ordered_args = positional + keyword + + return tuple(f'{name}{lu._clean_keystr_arg_names(path)}' + for name, x in ordered_args + for path, l in generate_key_paths(x) if l is not static) + class _HashableByObjectId: __slots__ = ['val'] @@ -715,22 +716,21 @@ def __hash__(self): def __eq__(self, other): return self.val is other.val -def register_class_with_attrs(t: type) -> None: - _class_with_attrs.add(t) -_class_with_attrs: set[type] = set() - # TODO(mattjj): make this function faster -def _check_no_aliased_ref_args(dbg: core.DebugInfo, avals, args): +def check_no_aliased_ref_args(dbg_fn: Callable[[], core.DebugInfo], + maybe_avals, args) -> None: assert config.mutable_array_checks.value refs: dict[int, int] = {} - for i, (a, x) in enumerate(zip(avals, args)): + for i, (a, x) in enumerate(zip(maybe_avals, args)): if (isinstance(a, AbstractRef) and (dup_idx := refs.setdefault(id(core.get_referent(x)), i)) != i): + dbg = dbg_fn() raise ValueError( "only one reference to a mutable array may be passed as an argument " f"to a function, but when tracing {dbg.func_src_info} for {dbg.traced_for} " f"the mutable array reference of type {a.str_short()} appeared at both " - f"{dbg.arg_names[dup_idx]} and {dbg.arg_names[i]}." + f"{dbg.arg_names[dup_idx] if dbg.arg_names is not None else 'unknown'} " + f"and {dbg.arg_names[i] if dbg.arg_names is not None else 'unknown'}." if dbg else f"at both flat index {dup_idx} and flat index {i}") from None @@ -746,3 +746,41 @@ def _check_no_aliased_closed_over_refs(dbg: core.DebugInfo, consts, args) -> Non f"array reference of type {a.str_short()} was both closed over and " f"passed as the argument " f"{dbg.safe_arg_names(len(args))[i]}" if dbg else "at flat index {i}") + +class InternalFloatingPointError(Exception): + name: str + ty: str + + def __init__(self, name: str, ty: str): + self.name = name + self.ty = ty + +def maybe_recursive_nan_check(e: Exception, fun: Callable, args, kwargs, +) -> None: # always raises an exception + print("Invalid nan value encountered in the output of a jax.jit " + "function. Calling the de-optimized version.") + try: + _ = fun(*args, **kwargs) + except (FloatingPointError, ZeroDivisionError) as e2: + raise e2 from None + else: + _raise_no_nan_in_deoptimized(e) + + +def _raise_no_nan_in_deoptimized(e) -> None: + msg = (f"{str(e)}. Because " + "jax_config.debug_nans.value and/or config.jax_debug_infs is set, the " + "de-optimized function (i.e., the function as if the `jit` " + "decorator were removed) was called in an attempt to get a more " + "precise error message. However, the de-optimized function did not " + "produce invalid values during its execution. This behavior can " + "result from `jit` optimizations causing the invalid value to be " + "produced. It may also arise from having nan/inf literals as " + "inputs or outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. " + "\n\n" + "It may be possible to avoid the invalid value by removing the " + "`jit` decorator, at the cost of losing optimizations. " + "\n\n" + "If you see this error, consider opening a bug report at " + "https://github.com/jax-ml/jax.") + raise FloatingPointError(msg) from None diff --git a/jax/_src/array.py b/jax/_src/array.py index b0793d2c3330..1ce52849a9de 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -18,6 +18,7 @@ from collections.abc import Callable, Sequence import enum import functools +from functools import partial import math import operator as op from typing import Any, TYPE_CHECKING, cast @@ -26,28 +27,31 @@ from jax._src import basearray from jax._src import config from jax._src import core -from jax._src import deprecations from jax._src import dispatch from jax._src import dtypes from jax._src import errors +from jax._src import literals from jax._src import profiler from jax._src import util from jax._src import xla_bridge +from jax._src.op_shardings import are_hlo_shardings_equal from jax._src.interpreters import mlir from jax._src.interpreters import pxla -from jax._src.interpreters import xla -from jax._src.layout import AutoLayout, DeviceLocalLayout, Layout +from jax._src.layout import AutoLayout, Format, Layout +from jax._src.lib import _jax from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension as xe +from jax._src.mesh import empty_concrete_mesh from jax._src.sharding import Sharding +from jax._src.tree_util import broadcast_prefix, tree_flatten, tree_unflatten from jax._src.sharding_impls import ( - PmapSharding, SingleDeviceSharding, + PmapSharding, SingleDeviceSharding, NamedSharding, device_replica_id_map, hashed_index, num_addressable_indices, - local_to_global_shape, use_concrete_mesh) # pyformat: disable -from jax._src.typing import ArrayLike, DLDeviceType, DTypeLike + local_to_global_shape, _internal_use_concrete_mesh) # pyformat: disable +from jax._src.typing import ArrayLike, DLDeviceType, DTypeLike, ExtendedDType from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method, cache import numpy as np +zip, unsafe_zip = safe_zip, zip Shape = tuple[int, ...] Device = xc.Device @@ -119,16 +123,6 @@ def _reconstruct_array(fun, args, arr_state, aval_state): np_value = fun(*args) np_value.__setstate__(arr_state) jnp_value = api.device_put(np_value) - # TODO(slebedev): Remove this branch after December 10th 2024. - if "named_shape" in aval_state: - deprecations.warn( - "jax-aval-named-shape", - "Pickled array contains an aval with a named_shape attribute. This is" - " deprecated and the code path supporting such avals will be removed." - " Please re-pickle the array.", - stacklevel=2, - ) - del aval_state["named_shape"] jnp_value.aval = jnp_value.aval.update(**aval_state) return jnp_value @@ -160,7 +154,7 @@ def _process_has_full_value_in_mcjax(s, shape): def _validate_shape_and_dtype_for_per_device_arrays( - arrays: Sequence[ArrayImpl | np.ndarray], + arrays: Sequence[ArrayImpl | np.ndarray | literals.TypedNdArray], sharding: Sharding, aval: core.ShapedArray, expected_shape: Shape, @@ -182,8 +176,6 @@ def _validate_shape_and_dtype_for_per_device_arrays( class ArrayImpl(basearray.Array): - # TODO(yashkatariya): Add __slots__ here. - aval: core.ShapedArray _sharding: Sharding _arrays: list[ArrayImpl] @@ -292,9 +284,6 @@ def weak_type(self): def committed(self) -> bool: return self._committed - def __str__(self): - return str(self._value) - def __len__(self): try: return self.shape[0] @@ -336,29 +325,26 @@ def tolist(self): return self._value.tolist() def __format__(self, format_spec): - # Simulates behavior of https://github.com/numpy/numpy/pull/9883 - if self.ndim == 0: - return format(self._value[()], format_spec) + if isinstance(self.sharding, NamedSharding) and self.sharding.spec.unreduced: + return repr(self) + elif (self.is_fully_addressable or self.is_fully_replicated and + self.sharding.has_addressable_devices): + # Simulates behavior of https://github.com/numpy/numpy/pull/9883 + return format(self._value if self.ndim else self._value[()], format_spec) else: - return format(self._value, format_spec) + return repr(self) def __getitem__(self, idx): - from jax._src.lax import lax - from jax._src.numpy import indexing + from jax._src.lax import lax # pytype: disable=import-error + from jax._src.numpy import indexing # pytype: disable=import-error self._check_if_deleted() if isinstance(self.sharding, PmapSharding): - if config.pmap_no_rank_reduction.value: - cidx = idx if isinstance(idx, tuple) else (idx,) + cidx = idx if isinstance(idx, tuple) else (idx,) - padded_cidx = tuple( - slice(i, i + 1, None) if isinstance(i, int) else i for i in cidx - ) + (slice(None),) * (len(self.shape) - len(cidx)) - else: - if not isinstance(idx, tuple): - padded_cidx = (idx,) + (slice(None),) * (len(self.shape) - 1) - else: - padded_cidx = idx + (slice(None),) * (len(self.shape) - len(idx)) + padded_cidx = tuple( + slice(i, i + 1, None) if isinstance(i, int) else i for i in cidx + ) + (slice(None),) * (len(self.shape) - len(cidx)) indices = tuple(self.sharding.devices_indices_map(self.shape).values()) try: @@ -369,12 +355,11 @@ def __getitem__(self, idx): out = self._arrays[arr_idx] sharding = SingleDeviceSharding(_get_device(out)) - if config.pmap_no_rank_reduction.value: - # If cidx was the index of a single shard, then it corresponds to one - # shard of the chunked dimension. - dims = tuple(i for i, x in enumerate(cidx) if isinstance(x, int)) - # Squeeze on committed arrays to avoid data movement to shard 0. - out = lax.squeeze(out, dimensions=dims) + # If cidx was the index of a single shard, then it corresponds to one + # shard of the chunked dimension. + dims = tuple(i for i, x in enumerate(cidx) if isinstance(x, int)) + # Squeeze on committed arrays to avoid data movement to shard 0. + out = lax.squeeze(out, dimensions=dims) return ArrayImpl( out.aval, sharding, [out], committed=False, _skip_checks=True) @@ -402,24 +387,37 @@ def is_fully_replicated(self) -> bool: def __repr__(self): prefix = 'Array(' if self.aval is not None and self.aval.weak_type: - dtype_str = f'dtype={self.dtype.name}, weak_type=True)' + dtype_str = f'dtype={self.dtype.name}, weak_type=True' else: - dtype_str = f'dtype={self.dtype.name})' + dtype_str = f'dtype={self.dtype.name}' - if self.is_fully_addressable or self.is_fully_replicated: + if isinstance(self.sharding, NamedSharding) and self.sharding.spec.unreduced: + return f"Array(shape={self.shape}, {dtype_str}, sharding={self.sharding})" + elif self.is_fully_addressable or self.is_fully_replicated: line_width = np.get_printoptions()["linewidth"] if self.size == 0: s = f"[], shape={self.shape}" + elif not self.sharding.has_addressable_devices: + s = f"shape={self.shape}" else: s = np.array2string(self._value, prefix=prefix, suffix=',', separator=', ', max_line_width=line_width) last_line_len = len(s) - s.rfind('\n') + 1 sep = ' ' - if last_line_len + len(dtype_str) + 1 > line_width: + if last_line_len + len(dtype_str) + 2 > line_width: sep = ' ' * len(prefix) - return f"{prefix}{s},{sep}{dtype_str}" + return f"{prefix}{s},{sep}{dtype_str})" else: - return f"{prefix}{self.shape}, {dtype_str}" + return f"{prefix}shape={self.shape}, {dtype_str})" + + def __str__(self): + if isinstance(self.sharding, NamedSharding) and self.sharding.spec.unreduced: + return repr(self) + elif (self.is_fully_addressable or self.is_fully_replicated and + self.sharding.has_addressable_devices): + return str(self._value) # doesn't print Array(...) + else: + return repr(self) @property def is_fully_addressable(self) -> bool: @@ -444,7 +442,7 @@ def __dlpack__(self, *, stream: int | Any | None = None, max_version: tuple[int, int] | None = None, dl_device: tuple[DLDeviceType, int] | None = None, copy: bool | None = None): - from jax._src.dlpack import to_dlpack # pylint: disable=g-import-not-at-top + from jax._src.dlpack import to_dlpack # pytype: disable=import-error # pylint: disable=g-import-not-at-top device_set = self.sharding.device_set if len(device_set) > 1: @@ -464,7 +462,7 @@ def __dlpack_device__(self) -> tuple[enum.Enum, int]: if len(self._arrays) != 1: raise BufferError("__dlpack__ only supported for unsharded arrays.") - from jax._src.dlpack import DLDeviceType # pylint: disable=g-import-not-at-top + from jax._src.dlpack import DLDeviceType # pytype: disable=import-error # pylint: disable=g-import-not-at-top if self.platform() == "cpu": return DLDeviceType.kDLCPU, 0 @@ -547,17 +545,17 @@ def addressable_shards(self) -> Sequence[Shard]: return out @property - def layout(self): + def format(self): # TODO(yashkatariya): Remove the deleted check from here. if self.is_deleted(): - return Layout(None, self.sharding) + return Format(None, self.sharding) try: - return Layout(DeviceLocalLayout.from_pjrt_layout(self._pjrt_layout), + return Format(Layout.from_pjrt_layout(self._pjrt_layout), self.sharding) - except xe.XlaRuntimeError as e: + except _jax.JaxRuntimeError as e: msg, *_ = e.args if type(msg) is str and msg.startswith("UNIMPLEMENTED"): - return Layout(None, self.sharding) + return Format(None, self.sharding) else: raise @@ -624,7 +622,7 @@ def _copy_single_device_array_to_host_async(self): def copy_to_host_async(self): self._check_if_deleted() if self._npy_value is None: - if self.is_fully_replicated: + if self.is_fully_replicated and self.sharding.has_addressable_devices: self._copy_single_device_array_to_host_async() return for i, _ in _cached_index_calc(self.sharding, self.shape): @@ -636,7 +634,8 @@ def _value(self) -> np.ndarray: self._check_if_deleted() if self._npy_value is None: - if self.is_fully_replicated: + # addressable_device_list can be empty. If it's empty, we will error below + if self.is_fully_replicated and self.sharding.has_addressable_devices: npy_value, did_copy = self._single_device_array_to_np_array_did_copy() npy_value.flags.writeable = False if did_copy: @@ -645,6 +644,7 @@ def _value(self) -> np.ndarray: # TODO(yashkatariya): Merge `_process_has_full_value_in_mcjax` with # is_fully_addressable. + # is_fully_addressable return False if addressable_device_list is empty. if (not self.is_fully_addressable and not _process_has_full_value_in_mcjax(self.sharding, self.shape)): raise RuntimeError( @@ -680,29 +680,29 @@ def _get_shape_from_index(slc: Index, shape: Shape) -> Shape: if isinstance(s, slice) # If element is int, this dimension is reduced ) -def _get_and_check_dtype(arrays: Sequence[basearray.Array | np.ndarray], - dtype: DTypeLike | None, fname: str): - if arrays: - if dtype is None: + +def _get_and_check_dtype( + arrays: Sequence[basearray.Array | np.ndarray | literals.TypedNdArray], + dtype: DTypeLike | ExtendedDType | None, + fname: str, +): + if dtype is None: + if arrays: dtype = arrays[0].dtype else: - if arrays[0].dtype != dtype: - raise ValueError( - f"If `dtype` is provided to `jax.{fname}`, it must match the dtype " - f"of the addressable shards. Got dtype={dtype} and shard " - f"dtype={arrays[0].dtype}`.") - else: - if not config.enable_empty_arrays.value: - raise ValueError( - f"Building an Array with no addressable shards with `jax.{fname}` is " - "supported only if `jax.config.enable_empty_arrays` is set to True." - ) - if dtype is None: raise ValueError( "If the Array has no addressable shards, `dtype` must be provided " f"via the `dtype` argument to `jax.{fname}`.") + else: + dtype = dtypes.check_and_canonicalize_user_dtype(dtype, fname) + if arrays and arrays[0].dtype != dtype: + raise ValueError( + f"If `dtype` is provided to `jax.{fname}`, it must match the dtype " + f"of the addressable shards. Got dtype={dtype} and shard " + f"dtype={arrays[0].dtype}`.") return dtype + # explicitly set to be unhashable. setattr(ArrayImpl, "__hash__", None) setattr(ArrayImpl, "__array_priority__", 100) @@ -710,7 +710,7 @@ def _get_and_check_dtype(arrays: Sequence[basearray.Array | np.ndarray], # TODO(yashkatariya): Remove None from callback input type. def make_array_from_callback( - shape: Shape, sharding: Sharding | Layout, + shape: Shape, sharding: Sharding | Format, data_callback: Callable[[Index | None], ArrayLike], dtype: DTypeLike | None = None) -> ArrayImpl: # pyformat: disable @@ -755,18 +755,20 @@ def make_array_from_callback( (4, 2) """ # pyformat: enable - dll = sharding.device_local_layout if isinstance(sharding, Layout) else None + dll = sharding.layout if isinstance(sharding, Format) else None if isinstance(dll, AutoLayout): raise TypeError( - "`DeviceLocalLayout.AUTO` cannot be used in place of a device-local" + "`Layout.AUTO` cannot be used in place of a device-local" f" layout when calling `jax.make_array_from_callback`. Got {sharding}") - sharding = sharding.sharding if isinstance(sharding, Layout) else sharding + sharding = sharding.sharding if isinstance(sharding, Format) else sharding if not isinstance(sharding, Sharding): raise TypeError( f"sharding should be an instance of `jax.sharding`. Got {sharding} of" f" type {type(sharding)}") - def get_data(index: Index | None) -> ArrayImpl | np.ndarray: + def get_data( + index: Index | None, + ) -> ArrayImpl | literals.TypedNdArray | np.ndarray: # Perhaps cache on index here, then we can unify fully_replicated # and non-fully_replicated cases below and become faster for # partially replicated cases. @@ -777,8 +779,14 @@ def get_data(index: Index | None) -> ArrayImpl | np.ndarray: "jax.make_array_from_callback cannot be called within a traced" " context." ) - # Value can be python scalar, resolve it into something with dtype. - return xla.canonicalize_dtype(r) + # Value can be python scalars, resolve it into something with dtype. + r = dtypes.canonicalize_value(r) + if isinstance(r, (literals.TypedInt, literals.TypedFloat, + literals.TypedComplex)): + r = literals.TypedNdArray(np.asarray(r, dtype=r.dtype), weak_type=False) + elif isinstance(r, bool): + r = literals.TypedNdArray(np.asarray(r, dtype=np.bool_), weak_type=False) + return r if sharding.is_fully_replicated: devices = list(sharding._internal_device_list.addressable_device_list) # type: ignore @@ -811,7 +819,7 @@ def get_data(index: Index | None) -> ArrayImpl | np.ndarray: and sharding.is_fully_replicated and first_value.is_fully_replicated and first_value.sharding._device_assignment == tuple(devices) - and first_value.layout.device_local_layout == dll): + and first_value.format.layout == dll): return first_value if dtypes.issubdtype(aval.dtype, dtypes.extended): @@ -822,7 +830,7 @@ def get_data(index: Index | None) -> ArrayImpl | np.ndarray: ) if dll is not None: - devices = [Layout(dll, SingleDeviceSharding(d)) for d in devices] + devices = [Format(dll, SingleDeviceSharding(d)) for d in devices] # type: ignore # pxla.batched_device_put doesn't support Layout... Take the slow route arrays = api.device_put(per_device_values, devices) return ArrayImpl(aval, sharding, arrays, committed=True) @@ -836,10 +844,9 @@ def get_data(index: Index | None) -> ArrayImpl | np.ndarray: def make_array_from_process_local_data( - sharding: Sharding, - local_data: np.ndarray, - global_shape: Shape | None = None, -) -> ArrayImpl: + sharding, # PyTree[jax.sharding.Sharding] + local_data, # PyTree[np.ndarray] + global_shape=None): # PyTree[Shape] # pyformat: disable """Creates distributed tensor using the data available in process. @@ -900,14 +907,14 @@ def make_array_from_process_local_data( >>> assert output_global_array.addressable_data(0).shape == per_device_shape >>> assert output_global_array.shape == global_shape - NB: While most shardings are uniform, It is possible to design am exotic + NB: While most shardings are uniform, It is possible to design an exotic sharding mesh where each process's devices will be arranged in a non-grid like pattern in some dimensions, or for indices to overlap non-trivially. Such sharding is called "non-uniform" in those dimensions. In that case, the global shape along those directions must match local shape as there is no meaningful way to represent all needed per-process data in non-overlapping fashion. For example for global_shape 4x4 - if sharding looks like this: + if sharding looks like this:: 0123 2103 @@ -915,7 +922,7 @@ def make_array_from_process_local_data( 4567 with 4 processes, containing devices (0,1), (2, 3), (4, 5), (6, 7) respectively. - Then the data for each host look like + Then the data for each host look like:: xx.. ..xx .... .... .xx. x..x .... .... @@ -929,7 +936,7 @@ def make_array_from_process_local_data( In this case user must provide global_shape explicitly and for local_shape=(2, 4), potentially valid global shapes are (2, 4) and (4, 4). - On the other hand for sharding: + On the other hand for sharding:: 0213 x.x. .x.x. .... .... 0213 x.x. .x.x. .... .... @@ -952,9 +959,33 @@ def make_array_from_process_local_data( Tensor that will have sharding=sharding and of shape global_shape. """ # pyformat: enable + local_data_flat, treedef = tree_flatten(local_data) + sharding_flat = broadcast_prefix(sharding, local_data) + sharding_flat = map( + partial(api.pspec_to_sharding, 'make_array_from_process_local_data'), + sharding_flat) + global_shape_flat = broadcast_prefix( + global_shape, local_data, + is_leaf=lambda x: x is None or isinstance(x, tuple)) if xla_bridge.process_count() == 1: + # Safety check if the provided data doesn't match expected global_shape + for s, d in zip(global_shape_flat, local_data_flat): + if s is not None and s != d.shape: + raise ValueError( + "When calling `make_array_from_process_local_data` on a single" + " process, global_shape should be None or equal to" + f" local_data.shape.Got global_shape={s} and" + f" local_data.shape={d.shape}." + ) return api.device_put(local_data, sharding) + out = [_array_from_process_local_data(data, s, shape) + for data, s, shape in zip(local_data_flat, sharding_flat, global_shape_flat)] + return tree_unflatten(treedef, out) + +def _array_from_process_local_data( + local_data: np.ndarray, sharding: Sharding, + global_shape: Shape | None = None) -> ArrayImpl: # TODO(sandler): consider supporting partially specified global_shape or # making local_to_global_shape available in the api. local_shape = local_data.shape @@ -978,7 +1009,7 @@ def make_array_from_process_local_data( if process_slice != data_dim: raise ValueError( "Invalid host data, each dimension should match either global or " - f"process shape. In dimension {i=}, the process data has {data_dim}" + f"process shape. In dimension {i}, the process data has {data_dim} " f"elements. Process addresses {process_slice} elements and " f"{global_shape=}." ) @@ -1024,7 +1055,7 @@ def make_array_from_single_device_arrays( shape : Shape of the output ``jax.Array``. This conveys information already included with ``sharding`` and ``arrays`` and serves as a double check. sharding: Sharding: A global Sharding instance which describes how the output jax.Array is laid out across devices. - arrays: Sequence of ``jax.Array``\s that are each single device addressable. ``len(arrays)`` + arrays: `list` or `tuple` of ``jax.Array``\s that are each single device addressable. ``len(arrays)`` must equal ``len(sharding.addressable_devices)`` and the shape of each array must be the same. For multiprocess code, each process will call with a different ``arrays`` argument that corresponds to that processes' data. These arrays are commonly created via ``jax.device_put``. @@ -1071,14 +1102,15 @@ def make_array_from_single_device_arrays( if dtypes.issubdtype(aval.dtype, dtypes.extended): return aval.dtype._rules.make_sharded_array(aval, sharding, arrays, committed=True) + arrays = list(arrays) if isinstance(arrays, tuple) else arrays # TODO(phawkins): ideally the cast() could be checked. try: return ArrayImpl(aval, sharding, cast(Sequence[ArrayImpl], arrays), committed=True) except TypeError: - if not isinstance(arrays, Sequence): + if not isinstance(arrays, list): raise TypeError("jax.make_array_from_single_device_arrays `arrays` " - "argument must be a Sequence (list or tuple), but got " + "argument must be a list or tuple, but got " f"{type(arrays)}.") if any(isinstance(arr, core.Tracer) for arr in arrays): raise ValueError( @@ -1086,17 +1118,14 @@ def make_array_from_single_device_arrays( f" arrays as input, but got types {set(map(type, arrays))}") raise -xla.canonicalize_dtype_handlers[ArrayImpl] = pxla.identity +dtypes.canonicalize_value_handlers[ArrayImpl] = lambda x: x def _get_aval_array(self): return core.update_aval_with_sharding(self.aval, self.sharding) core.pytype_aval_mappings[ArrayImpl] = _get_aval_array -# TODO(jakevdp) replace this with true inheritance at the C++ level. -basearray.Array.register(ArrayImpl) - -def _array_mlir_constant_handler(val): +def _array_mlir_constant_handler(val, aval): try: return mlir.ir_constant(val._value) except RuntimeError as e: @@ -1111,6 +1140,8 @@ def _array_mlir_constant_handler(val): mlir.register_constant_handler(ArrayImpl, _array_mlir_constant_handler) +if config.use_simplified_jaxpr_constants.value: + core.literalable_types.add(ArrayImpl) # NOTE(skye): we could refactor to generate _multi_slice parameters directly # from the input ShardingSpec, rather than the indices. However, this would @@ -1149,7 +1180,7 @@ def shard_device_array(x, devices, indices, sharding): else: # TODO(yashkatariya): Maybe this should be set when we call the handler in # InputsHandler.__call__? - with use_concrete_mesh(None): + with _internal_use_concrete_mesh(empty_concrete_mesh): shards = x._multi_slice(start_indices, limit_indices, removed_dims) aval = core.shaped_abstractify(x) return pxla.batched_device_put(aval, sharding, shards, devices) @@ -1167,7 +1198,8 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding): # Look up all buffers that contain the correct slice of the logical array. candidates_list = candidates[hashed_index(idx)] if not candidates_list: - return pxla.shard_args([sharding], [None], [None], [x._value], + return pxla.shard_args([sharding], [None], + [xc.ArrayCopySemantics.REUSE_INPUT], [x._value], canonicalize=False)[0] # Try to find a candidate buffer already on the correct device, # otherwise copy one of them. @@ -1181,10 +1213,18 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding): @cache(max_size=4096, trace_context_in_key=False) -def _sharding_indices_and_eq(src_sharding, shape, dst_sharding): +def _fallback_check_via_indices(src_sharding, dst_sharding, shape): src_indices = src_sharding.addressable_devices_indices_map(shape).values() dst_indices = dst_sharding.addressable_devices_indices_map(shape).values() - return dst_indices, tuple(src_indices) == tuple(dst_indices) + return tuple(src_indices) == tuple(dst_indices) + +@cache(max_size=4096, trace_context_in_key=False) +def _sharding_indices_and_eq(src_sharding, dst_sharding, ndim): + hlos_eq = are_hlo_shardings_equal(src_sharding._to_xla_hlo_sharding(ndim), + dst_sharding._to_xla_hlo_sharding(ndim)) + len_eq = (len(src_sharding._internal_device_list.addressable_device_list) == + len(dst_sharding._internal_device_list.addressable_device_list)) + return hlos_eq and len_eq def _array_shard_arg(xs, shardings, layouts, copy_semantics): @@ -1196,35 +1236,39 @@ def _array_shard_arg(xs, shardings, layouts, copy_semantics): for i, (x, sharding, layout, cs) in enumerate( safe_zip(xs, shardings, layouts, copy_semantics)): x._check_if_deleted() - indices, same_indices = _sharding_indices_and_eq(x.sharding, x.shape, sharding) - same_layout = (True if layout is None else - x.layout.device_local_layout == layout) + try: + same_sharding = _sharding_indices_and_eq(x.sharding, sharding, len(x.shape)) + except NotImplementedError: + same_sharding = _fallback_check_via_indices(x.sharding, sharding, x.shape) + same_layout = True if layout is None else x.format.layout == layout if not x.is_fully_addressable: - if same_indices and same_layout: + if same_sharding and same_layout: results.append(x) else: raise NotImplementedError( "Cannot reshard an input that is not fully addressable") else: - devices = sharding._addressable_device_assignment - if same_indices and same_layout: + devices = sharding._internal_device_list.addressable_device_list + if same_sharding and same_layout: # Add a placeholder result that will be filled in later. results.append(None) # Accumulate arguments to `batched_copy_array_to_devices_with_sharding`. batch_xs.append(x) - batch_devs.append(list(devices)) + batch_devs.append(devices) batch_shardings.append(sharding) batch_indices.append(i) batch_cs.append(cs) # Resharding starts here: elif not same_layout: - results.append(api.device_put(x, Layout(layout, sharding))) - elif dispatch.is_single_device_sharding(x.sharding): - results.append(shard_device_array(x, devices, indices, sharding)) + results.append(api.device_put(x, Format(layout, sharding))) else: - results.append( - shard_sharded_device_array_slow_path(x, devices, indices, sharding)) + indices = sharding.addressable_devices_indices_map(x.shape).values() + if dispatch.is_single_device_sharding(x.sharding): + results.append(shard_device_array(x, devices, indices, sharding)) + else: + results.append( + shard_sharded_device_array_slow_path(x, devices, indices, sharding)) util.test_event("batched_copy_array") copy_outs = xc.batched_copy_array_to_devices_with_sharding( @@ -1237,9 +1281,12 @@ def _array_shard_arg(xs, shardings, layouts, copy_semantics): def _array_global_result_handler(global_aval, out_sharding, committed): - global_aval = core.update_aval_with_sharding(global_aval, out_sharding) if global_aval.dtype == dtypes.float0: - return lambda _: np.zeros(global_aval.shape, dtypes.float0) + def handler(xs): + return np.zeros(global_aval.shape, dtypes.float0) + phys_aval = core.physical_aval(global_aval) + return xc.array_result_handler(phys_aval, out_sharding, committed=committed, + _skip_checks=True).wrap(handler) if dtypes.issubdtype(global_aval.dtype, dtypes.extended): return global_aval.dtype._rules.global_sharded_result_handler( global_aval, out_sharding, committed) @@ -1251,7 +1298,11 @@ def _array_global_result_handler(global_aval, out_sharding, committed): # Only used for Arrays that come out of pmap. def _array_local_result_handler(aval, sharding, indices): if aval.dtype == dtypes.float0: - return lambda _: np.zeros(aval.shape, dtypes.float0) + def handler(xs): + return np.zeros(aval.shape, dtypes.float0) + phys_aval = core.physical_aval(aval) + return xc.array_result_handler(phys_aval, sharding, committed=True, + _skip_checks=True).wrap(handler) if dtypes.issubdtype(aval.dtype, dtypes.extended): return aval.dtype._rules.local_sharded_result_handler( aval, sharding, indices) @@ -1280,9 +1331,7 @@ def _token_shard_arg(xs, shardings, layouts, copy_semantics): def _token_global_result_handler(global_aval, out_sharding, committed): array_handler = _array_global_result_handler( core.get_token_aval(), out_sharding, committed) - - def wrapper(*args, **kwargs): - out_buf = array_handler(*args, **kwargs) - return core.Token(out_buf) - return wrapper + def wrapper(array): + return core.Token(array) + return array_handler.wrap(wrapper) # type: ignore pxla.global_result_handlers[core.AbstractToken] = _token_global_result_handler diff --git a/jax/_src/basearray.py b/jax/_src/basearray.py index a89d4a2949be..8053a99b61d6 100644 --- a/jax/_src/basearray.py +++ b/jax/_src/basearray.py @@ -16,10 +16,16 @@ from __future__ import annotations -import abc -import numpy as np -from typing import Any, Union from collections.abc import Sequence +import sys +from typing import Any, Union +import warnings + +from jax._src import literals +from jax._src.lib import xla_client as xc +from jax._src.util import use_cpp_class +import numpy as np + # TODO(jakevdp): fix import cycles and define these. Device = Any @@ -29,7 +35,9 @@ # Array is a type annotation for standard JAX arrays and tracers produced by # core functions in jax.lax and jax.numpy; it is not meant to include # future non-standard array types like KeyArray and BInt. -class Array(abc.ABC): + + +class Array: """Array base class for JAX ``jax.Array`` is the public interface for instance checks and type annotation @@ -47,51 +55,58 @@ def f(x: Array) -> Array: # type annotations are valid for traced and non-trace :func:`jax.numpy.array`, :func:`jax.numpy.zeros`, :func:`jax.numpy.ones`, :func:`jax.numpy.full`, :func:`jax.numpy.arange`, etc. """ - # Note: abstract methods for this class are defined dynamically in - # lax_numpy.py # For the sake of static type analysis, these definitions are mirrored in the # associated basearray.pyi file. __slots__ = ['__weakref__'] __hash__ = None + # TODO(jakevdp): set __numpy_dtype__ = None after deprecation period. + @property + def __numpy_dtype__(self) -> np.dtype: + # __numpy_dtype__ protocol added in NumPy v2.4.0. + warnings.warn( + "Implicit conversion of an array to a dtype is deprecated;" + " rather than dtype=arr use dtype=arr.dtype. In the future" + " this will result in an error.", DeprecationWarning, stacklevel=2) + return self.dtype + @property - @abc.abstractmethod def dtype(self) -> np.dtype: """The data type (:class:`numpy.dtype`) of the array.""" + raise NotImplementedError @property - @abc.abstractmethod def ndim(self) -> int: """The number of dimensions in the array.""" + raise NotImplementedError @property - @abc.abstractmethod def size(self) -> int: """The total number of elements in the array.""" + raise NotImplementedError @property - @abc.abstractmethod def shape(self) -> tuple[int, ...]: """The shape of the array.""" + raise NotImplementedError # Documentation for sharding-related methods and properties defined on ArrayImpl: - @abc.abstractmethod def addressable_data(self, index: int) -> Array: """Return an array of the addressable data at a particular index.""" + raise NotImplementedError @property - @abc.abstractmethod def addressable_shards(self) -> Sequence[Shard]: """List of addressable shards.""" + raise NotImplementedError @property - @abc.abstractmethod def global_shards(self) -> Sequence[Shard]: """List of global shards.""" + raise NotImplementedError @property - @abc.abstractmethod def is_fully_addressable(self) -> bool: """Is this Array fully addressable? @@ -103,19 +118,19 @@ def is_fully_addressable(self) -> bool: a jax.Array which is fully replicated can span across multiple hosts and is not fully addressable. """ + raise NotImplementedError @property - @abc.abstractmethod def is_fully_replicated(self) -> bool: """Is this Array fully replicated?""" + raise NotImplementedError @property - @abc.abstractmethod def sharding(self) -> Sharding: """The sharding for the array.""" + raise NotImplementedError @property - @abc.abstractmethod def committed(self) -> bool: """Whether the array is committed or not. @@ -137,20 +152,20 @@ def committed(self) -> bool: a + b # Raises an error ``` - See https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices + See https://docs.jax.dev/en/latest/faq.html#controlling-data-and-computation-placement-on-devices for more information. """ + raise NotImplementedError @property - @abc.abstractmethod def device(self) -> Device | Sharding: """Array API-compatible device attribute. For single-device arrays, this returns a Device. For sharded arrays, this returns a Sharding. """ + raise NotImplementedError - @abc.abstractmethod def copy_to_host_async(self): """Copies an ``Array`` to the host asynchronously. @@ -165,17 +180,24 @@ def copy_to_host_async(self): array, but does not wait for the copy to complete. This may speed up a future on-host access to the array's contents. """ + raise NotImplementedError +Array = use_cpp_class(xc.Array)(Array) Array.__module__ = "jax" + # StaticScalar is the Union of all scalar types that can be converted to # JAX arrays, and are possible to mark as static arguments. StaticScalar = Union[ np.bool_, np.number, # NumPy scalar types bool, int, float, complex, # Python scalar types ] -StaticScalar.__doc__ = "Type annotation for JAX-compatible static scalars." + +if sys.version_info[:2] < (3, 14): + # Python 3.14 raises + # AttributeError: 'typing.Union' object attribute '__doc__' is read-only + StaticScalar.__doc__ = "Type annotation for JAX-compatible static scalars." # ArrayLike is a Union of all objects that can be implicitly converted to a @@ -186,5 +208,10 @@ def copy_to_host_async(self): Array, # JAX array type np.ndarray, # NumPy array type StaticScalar, # valid scalars + literals.TypedNdArray, # Typed array type ] -ArrayLike.__doc__ = "Type annotation for JAX array-like objects." + +if sys.version_info[:2] < (3, 14): + # Python 3.14 raises + # AttributeError: 'typing.Union' object attribute '__doc__' is read-only + ArrayLike.__doc__ = "Type annotation for JAX array-like objects." diff --git a/jax/_src/basearray.pyi b/jax/_src/basearray.pyi index a368b593332d..5f5681d60f7a 100644 --- a/jax/_src/basearray.pyi +++ b/jax/_src/basearray.pyi @@ -11,14 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import abc from collections.abc import Callable, Sequence from types import ModuleType -from typing import Any, Protocol, Union, runtime_checkable +from typing import Any, Protocol, runtime_checkable, Union import numpy as np +from jax._src import literals +from jax._src.partition_spec import PartitionSpec as P +from jax._src.named_sharding import NamedSharding from jax._src.sharding import Sharding -from jax._src.partition_spec import PartitionSpec + # TODO(jakevdp) de-duplicate this with the DTypeLike definition in typing.py. # We redefine these here to prevent circular imports. @@ -39,12 +41,16 @@ Traceback = Any PrecisionLike = Any -class Array(abc.ABC): +class Array: aval: Any @property def dtype(self) -> np.dtype: ... + # TODO(jakevdp) set to None after deprecation period. + @property + def __numpy_dtype__(self) -> np.dtype: ... + @property def ndim(self) -> int: ... @@ -74,12 +80,12 @@ class Array(abc.ABC): # Comparisons # these return bool for object, so ignore override errors. - def __lt__(self, other) -> Array: ... - def __le__(self, other) -> Array: ... - def __eq__(self, other) -> Array: ... # type: ignore[override] - def __ne__(self, other) -> Array: ... # type: ignore[override] - def __gt__(self, other) -> Array: ... - def __ge__(self, other) -> Array: ... + def __lt__(self, other: ArrayLike) -> Array: ... + def __le__(self, other: ArrayLike) -> Array: ... + def __eq__(self, other: ArrayLike) -> Array: ... # type: ignore[override] + def __ne__(self, other: ArrayLike) -> Array: ... # type: ignore[override] + def __gt__(self, other: ArrayLike) -> Array: ... + def __ge__(self, other: ArrayLike) -> Array: ... # Unary arithmetic @@ -90,35 +96,35 @@ class Array(abc.ABC): # Binary arithmetic - def __add__(self, other) -> Array: ... - def __sub__(self, other) -> Array: ... - def __mul__(self, other) -> Array: ... - def __matmul__(self, other) -> Array: ... - def __truediv__(self, other) -> Array: ... - def __floordiv__(self, other) -> Array: ... - def __mod__(self, other) -> Array: ... - def __divmod__(self, other) -> tuple[Array, Array]: ... - def __pow__(self, other) -> Array: ... - def __lshift__(self, other) -> Array: ... - def __rshift__(self, other) -> Array: ... - def __and__(self, other) -> Array: ... - def __xor__(self, other) -> Array: ... - def __or__(self, other) -> Array: ... - - def __radd__(self, other) -> Array: ... - def __rsub__(self, other) -> Array: ... - def __rmul__(self, other) -> Array: ... - def __rmatmul__(self, other) -> Array: ... - def __rtruediv__(self, other) -> Array: ... - def __rfloordiv__(self, other) -> Array: ... - def __rmod__(self, other) -> Array: ... - def __rdivmod__(self, other) -> Array: ... - def __rpow__(self, other) -> Array: ... - def __rlshift__(self, other) -> Array: ... - def __rrshift__(self, other) -> Array: ... - def __rand__(self, other) -> Array: ... - def __rxor__(self, other) -> Array: ... - def __ror__(self, other) -> Array: ... + def __add__(self, other: ArrayLike) -> Array: ... + def __sub__(self, other: ArrayLike) -> Array: ... + def __mul__(self, other: ArrayLike) -> Array: ... + def __matmul__(self, other: ArrayLike) -> Array: ... + def __truediv__(self, other: ArrayLike) -> Array: ... + def __floordiv__(self, other: ArrayLike) -> Array: ... + def __mod__(self, other: ArrayLike) -> Array: ... + def __divmod__(self, other: ArrayLike) -> tuple[Array, Array]: ... + def __pow__(self, other: ArrayLike) -> Array: ... + def __lshift__(self, other: ArrayLike) -> Array: ... + def __rshift__(self, other: ArrayLike) -> Array: ... + def __and__(self, other: ArrayLike) -> Array: ... + def __xor__(self, other: ArrayLike) -> Array: ... + def __or__(self, other: ArrayLike) -> Array: ... + + def __radd__(self, other: ArrayLike) -> Array: ... # type: ignore[misc] + def __rsub__(self, other: ArrayLike) -> Array: ... # type: ignore[misc] + def __rmul__(self, other: ArrayLike) -> Array: ... # type: ignore[misc] + def __rmatmul__(self, other: ArrayLike) -> Array: ... + def __rtruediv__(self, other: ArrayLike) -> Array: ... # type: ignore[misc] + def __rfloordiv__(self, other: ArrayLike) -> Array: ... + def __rmod__(self, other: ArrayLike) -> Array: ... + def __rdivmod__(self, other: ArrayLike) -> Array: ... + def __rpow__(self, other: ArrayLike) -> Array: ... # type: ignore[misc] + def __rlshift__(self, other: ArrayLike) -> Array: ... + def __rrshift__(self, other: ArrayLike) -> Array: ... + def __rand__(self, other: ArrayLike) -> Array: ... + def __rxor__(self, other: ArrayLike) -> Array: ... + def __ror__(self, other: ArrayLike) -> Array: ... def __bool__(self) -> bool: ... def __complex__(self) -> complex: ... @@ -181,12 +187,15 @@ class Array(abc.ABC): promote_integers: bool = True) -> Array: ... def ptp(self, axis: Axis = None, out: None = None, keepdims: bool = False) -> Array: ... - def ravel(self, order: str = 'C') -> Array: ... + def ravel(self, order: str = 'C', *, + out_sharding: NamedSharding | P | None = ...) -> Array: ... @property def real(self) -> Array: ... def repeat(self, repeats: ArrayLike, axis: int | None = None, *, - total_repeat_length: int | None = None) -> Array: ... - def reshape(self, *args: Any, order: str = "C") -> Array: ... + total_repeat_length: int | None = None, + out_sharding: NamedSharding | P | None = None) -> Array: ... + def reshape(self, *args: Any, order: str = "C", + out_sharding: NamedSharding | P | None = ...) -> Array: ... def round(self, decimals: int = 0, out: None = None) -> Array: ... def searchsorted(self, v: ArrayLike, side: str = 'left', sorter: ArrayLike | None = None, *, method: str = 'scan') -> Array: ... @@ -268,6 +277,8 @@ ArrayLike = Union[ Array, # JAX array type np.ndarray, # NumPy array type StaticScalar, # valid scalars + # Typed array and scalar types + literals.TypedNdArray, ] @@ -280,25 +291,38 @@ class _IndexUpdateHelper: class _IndexUpdateRef: def get(self, indices_are_sorted: bool = False, unique_indices: bool = False, mode: str | None = None, fill_value: StaticScalar | None = None, - out_spec: Sharding | PartitionSpec | None = None) -> Array: ... + out_sharding: NamedSharding | P | None = None, + wrap_negative_indices: bool = True) -> Array: ... def set(self, values: Any, indices_are_sorted: bool = False, unique_indices: bool = False, - mode: str | None = None, fill_value: StaticScalar | None = None) -> Array: ... + mode: str | None = None, fill_value: StaticScalar | None = None, + out_sharding: NamedSharding | P | None = None, + wrap_negative_indices: bool = True) -> Array: ... def add(self, values: Any, indices_are_sorted: bool = False, - unique_indices: bool = False, mode: str | None = None) -> Array: ... + unique_indices: bool = False, mode: str | None = None, + out_sharding: NamedSharding | P | None = None, + wrap_negative_indices: bool = True) -> Array: ... def subtract(self, values: Any, *, indices_are_sorted: bool = False, - unique_indices: bool = False, mode: str | None = None) -> Array: ... + unique_indices: bool = False, mode: str | None = None, + wrap_negative_indices: bool = True) -> Array: ... def mul(self, values: Any, indices_are_sorted: bool = False, - unique_indices: bool = False, mode: str | None = None) -> Array: ... + unique_indices: bool = False, mode: str | None = None, + wrap_negative_indices: bool = True) -> Array: ... def multiply(self, values: Any, indices_are_sorted: bool = False, - unique_indices: bool = False, mode: str | None = None) -> Array: ... + unique_indices: bool = False, mode: str | None = None, + wrap_negative_indices: bool = True) -> Array: ... def divide(self, values: Any, indices_are_sorted: bool = False, - unique_indices: bool = False, mode: str | None = None) -> Array: ... + unique_indices: bool = False, mode: str | None = None, + wrap_negative_indices: bool = True) -> Array: ... def power(self, values: Any, indices_are_sorted: bool = False, - unique_indices: bool = False, mode: str | None = None) -> Array: ... + unique_indices: bool = False, mode: str | None = None, + wrap_negative_indices: bool = True) -> Array: ... def min(self, values: Any, indices_are_sorted: bool = False, - unique_indices: bool = False, mode: str | None = None) -> Array: ... + unique_indices: bool = False, mode: str | None = None, + wrap_negative_indices: bool = True) -> Array: ... def max(self, values: Any, indices_are_sorted: bool = False, - unique_indices: bool = False, mode: str | None = None) -> Array: ... + unique_indices: bool = False, mode: str | None = None, + wrap_negative_indices: bool = True) -> Array: ... def apply(self, func: Callable[[ArrayLike], ArrayLike], indices_are_sorted: bool = False, - unique_indices: bool = False, mode: str | None = None) -> Array: ... + unique_indices: bool = False, mode: str | None = None, + wrap_negative_indices: bool = True) -> Array: ... diff --git a/jax/_src/blocked_sampler.py b/jax/_src/blocked_sampler.py index 3021b6a1604f..2afe3f82826a 100644 --- a/jax/_src/blocked_sampler.py +++ b/jax/_src/blocked_sampler.py @@ -14,10 +14,10 @@ from collections.abc import Sequence from typing import Any, Protocol -import jax + +from jax._src import numpy as jnp from jax._src import random from jax._src.typing import Array, ArrayLike -from jax import numpy as jnp NdKeyList = Any Shape = random.Shape @@ -111,7 +111,7 @@ def blocked_fold_in( def _keygen_loop(axis, prefix): if axis == len(block_size_in_tiles): - subtile_key = jax.random.fold_in( + subtile_key = random.fold_in( global_key, _compute_tile_index( block_index, block_size_in_tiles, total_size_in_tiles, prefix)) return subtile_key @@ -120,7 +120,7 @@ def _keygen_loop(axis, prefix): for i in range(block_size_in_tiles[axis]): keys.append(_keygen_loop(axis+1, prefix+(i,))) return keys - return _keygen_loop(0, tuple()) + return _keygen_loop(0, ()) def sample_block( @@ -130,7 +130,7 @@ def sample_block( tile_size: Shape, *args, **kwargs - ) -> jax.Array: + ) -> Array: """Draws random samples for a single block. This function is intended to be used in conjunction with `blocked_fold_in`: @@ -155,12 +155,12 @@ def sample_block( """ size_in_tiles = tuple( _shape // _element for _shape, _element in zip(block_size, tile_size)) - def _nested_index(arr: jax.Array, idx: Sequence[int]) -> jax.Array: + def _nested_index(arr: Array, idx: Sequence[int]) -> Array: if len(idx) == 1: return arr[idx[0]] return _nested_index(arr[idx[0]], idx[1:]) - def _sample_loop(axis: int, prefix: tuple[int, ...]) -> jax.Array: + def _sample_loop(axis: int, prefix: tuple[int, ...]) -> Array: if axis == len(size_in_tiles): return sampler_fn(_nested_index(keys, prefix), *args, shape=tile_size, **kwargs) @@ -169,4 +169,4 @@ def _sample_loop(axis: int, prefix: tuple[int, ...]) -> jax.Array: for i in range(size_in_tiles[axis]): samples.append(_sample_loop(axis+1, prefix+(i,))) return jnp.concatenate(samples, axis=axis) - return _sample_loop(0, tuple()) + return _sample_loop(0, ()) diff --git a/jax/_src/buffer_callback.py b/jax/_src/buffer_callback.py new file mode 100644 index 000000000000..1049b2def14c --- /dev/null +++ b/jax/_src/buffer_callback.py @@ -0,0 +1,266 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable, Sequence +import functools +from typing import Any + +import numpy as np + +from jax._src import core +from jax._src import dispatch +from jax._src import effects +from jax._src import ffi +from jax._src import tree_util +from jax._src import util +from jax._src.interpreters import ad +from jax._src.interpreters import batching +from jax._src.interpreters import mlir +from jax._src.lib import ffi as ffi_lib + +export = util.set_module("jax.experimental.buffer_callback") +Buffer = export(ffi_lib.Buffer) +ExecutionStage = export(ffi_lib.ExecutionStage) +ExecutionContext = export(ffi_lib.ExecutionContext) + + +def buffer_callback( + callback: Callable[..., None], + result_shape_dtypes: object, + *, + has_side_effect: bool = False, + vmap_method: str | None = None, + input_output_aliases: dict[int, int] | None = None, + command_buffer_compatible: bool = False, +): + """An experimental callback that operates in place on device buffers. + + Only supported on CPU and GPU backends. + + Note that the plan is for this to eventually be replaced by a consolidated + callback API built using JAX mutable arrays, but for now this provides a + mechanism for prototyping computational kernels using other Python libraries + including Numpy, PyTorch, Cupy, and others. + + Let's start with a simple example: + + >>> def py_add_one_inplace(ctx, out, x): + ... np.asarray(out)[...] = np.asarray(x) + 1 + ... + >>> x = jnp.array(41, dtype=jnp.int32) + >>> out_type = jax.ShapeDtypeStruct(x.shape, x.dtype) + >>> add_one = buffer_callback(py_add_one_inplace, out_type) + >>> add_one(x) # doctest: +SKIP + Array(42, dtype=int32) + + In this example, we're executing a numpy computation via JAX, and this could + have been implemented using :func:`jax.pure_callback`, but in this case, the + output is being populated in-place. This means that JAX doesn't need to copy + the output arrays upon returning from the callback. Note that even though the + callback function operates on mutable buffers, JAX still sees this as an + operation that consumes and produces regular immutable JAX arrays. + + Unlike the other JAX callback APIs, ``buffer_callback`` requires that the + user-defined Python function have the following signature: + + .. code-block:: python + + def callback(ctx: ExecutionContext, out, *args) -> None: + ... + + where ``ctx`` is an instance of + :class:`~jax.experimental.buffer_callback.ExecutionContext`, which mainly + provides access to XLA's computation stream when running on GPU, ``out`` is a + pytree of mutable :class:`~jax.experimental.buffer_callback.Buffer` objects, + and the ``args`` arguments have the same pytree structure as the inputs, but + each leaf is :class:`~jax.experimental.buffer_callback.Buffer`. This callback + should not return any values, and it should overwrite the ``out`` buffers in + place to output values back to JAX. + + It's important to note that this Python function can't really be called + except via ```buffer_callback`` itself, because it's not (yet!) possible to + construct mutable JAX buffers directly in Python. + + The bespoke :class:`~jax.experimental.buffer_callback.Buffer` type is an + array-like object that supports the ``__array__`` protocol on CPU, the + ``__cuda_array_interface__`` protocol on GPU, and the ``__dlpack__`` protocol + on both CPU and GPU. + + Args: + callback: A Python function with the signature and behavior described above. + result_shape_dtypes: A pytree whose leaves have ``shape`` and ``dtype`` + attributes, with a structure that matches the expected output of the + callback function at runtime. :class:`jax.ShapeDtypeStruct` is often used + to define leaf values. + has_side_effect: Whether the callback has side effects. + vmap_method: A string specifying how the callback transforms under + :func:`~jax.vmap` as described in the docs for :func:`~jax.pure_callback`. + input_output_aliases: a dictionary mapping the index of some inputs to + the index of the output that aliases them. These indices are in the + flattened inputs and outputs. + command_buffer_compatible: if ``True``, the callback will be traced into + the command buffer. This means that the Python code should only be + executed once, and then the operations will be replayed for every + subsequent call. + + Returns: + A new callable that accepts :class:`jax.Array` inputs (and pytrees thereof), + and pytree of :class:`jax.Array` objects whose structure matches that + of ``result_shape_dtypes``. + + See Also: + - :func:`jax.pure_callback`: callback designed for pure host functions. + - :func:`jax.experimental.io_callback`: callback designed for impure host + functions. + - :func:`jax.debug.callback`: callback designed for general-purpose + debugging. + - :func:`jax.debug.print`: callback designed for printing. + """ + flat_shape_dtypes, out_tree = tree_util.tree_flatten(result_shape_dtypes) + flat_result_avals = tuple( + core.ShapedArray(x.shape, x.dtype) for x in flat_shape_dtypes + ) + + def wrapped_callback(*args, **kwargs): + flat_args, in_tree = tree_util.tree_flatten((args, kwargs)) + + in_avals = [core.get_aval(x) for x in flat_args] + static_input_output_aliases: tuple[tuple[int, int], ...] = () + if input_output_aliases is not None: + for i_idx, o_idx in sorted(input_output_aliases.items()): + i_idx, o_idx = int(i_idx), int(o_idx) + if i_idx >= len(args): + raise ValueError( + f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' " + f"with input index {i_idx} outside the range [0, " + f"{len(args)}).") + if o_idx >= len(flat_result_avals): + raise ValueError( + f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' " + f"with output index {o_idx} outside the range [0, " + f"{len(flat_result_avals)}).") + in_aval = in_avals[i_idx] + out_aval = flat_result_avals[o_idx] + if not ffi._check_compatible_avals(in_aval, out_aval): + raise ValueError( + f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' " + f"referring to an input with abstract value {in_aval} and an " + f"output with a different abstract value {out_aval}.") + static_input_output_aliases += ((i_idx, o_idx),) + + out_flat = buffer_callback_p.bind( + *flat_args, + callback=callback, + result_avals=flat_result_avals, + in_tree=in_tree, + out_tree=out_tree, + vmap_method=vmap_method, + has_side_effect=has_side_effect, + input_output_aliases=static_input_output_aliases, + command_buffer_compatible=command_buffer_compatible, + ) + return tree_util.tree_unflatten(out_tree, out_flat) + + return wrapped_callback + + +buffer_callback_p = core.Primitive("buffer_callback") +buffer_callback_p.multiple_results = True +dispatch.simple_impl(buffer_callback_p) + + +class BufferCallbackEffect(effects.Effect): + def __str__(self): + return "BufferCallback" + +_BufferCallbackEffect = BufferCallbackEffect() +effects.lowerable_effects.add_type(BufferCallbackEffect) +effects.control_flow_allowed_effects.add_type(BufferCallbackEffect) + + +@buffer_callback_p.def_effectful_abstract_eval +def _buffer_callback_abstract_eval( + *args, + result_avals: tuple[core.ShapedArray, ...], + has_side_effect: bool, + **_, +): + del args + effects = {_BufferCallbackEffect} if has_side_effect else core.no_effects + return result_avals, effects + + +def _buffer_callback_jvp_rule(*args, **kwargs): + del args, kwargs + raise ValueError( + "Buffer callbacks do not support JVP. " + "Please use `jax.custom_jvp` to use callbacks while taking gradients.") +ad.primitive_jvps[buffer_callback_p] = _buffer_callback_jvp_rule + + +def _buffer_callback_transpose_rule(*args, **kwargs): + del args, kwargs + raise ValueError( + "Buffer callbacks do not support transpose. " + "Please use `jax.custom_vjp` to use callbacks while taking gradients.") +ad.primitive_transposes[buffer_callback_p] = _buffer_callback_transpose_rule + +batching.primitive_batchers[buffer_callback_p] = functools.partial( + ffi.ffi_batching_rule, buffer_callback_p +) + + +def _buffer_callback_lowering( + ctx: mlir.LoweringRuleContext, + *args: Any, + callback, + in_tree: Any, + out_tree: Any, + has_side_effect: bool, + input_output_aliases: Sequence[tuple[int, int]], + command_buffer_compatible: bool, + **_, +): + + if len(ctx.module_context.platforms) > 1: + raise NotImplementedError("multi-platform lowering for buffer_callback") + platform = ctx.module_context.platforms[0] + target_name = { + "cpu": "xla_buffer_python_cpu_callback", + "cuda": "xla_buffer_python_gpu_callback", + "rocm": "xla_buffer_python_gpu_callback", + }.get(platform) + if target_name is None: + raise ValueError(f"`buffer_callback` not supported on {platform} backend.") + + if command_buffer_compatible and platform in ("cuda", "rocm"): + target_name += "_cmd_buffer" + + def wrapped_callback(exec_ctx, *args: Any): + args_in, args_out = util.split_list(args, [in_tree.num_leaves]) + py_args_in, py_kwargs_in = tree_util.tree_unflatten(in_tree, args_in) + py_args_out = tree_util.tree_unflatten(out_tree, args_out) + if callback(exec_ctx, py_args_out, *py_args_in, **py_kwargs_in) is not None: + raise ValueError("buffer_callback callback must not return any values.") + return () + + ctx.module_context.add_host_callback(wrapped_callback) + index = np.uint64(len(ctx.module_context.host_callbacks) - 1) + rule = ffi.ffi_lowering( + target_name, + has_side_effect=has_side_effect, + operand_output_aliases=dict(input_output_aliases), + ) + return rule(ctx, *args, index=index) +mlir.register_lowering(buffer_callback_p, _buffer_callback_lowering) diff --git a/jax/_src/cache_key.py b/jax/_src/cache_key.py index e4b6e7a2669c..296f65b5ed3f 100644 --- a/jax/_src/cache_key.py +++ b/jax/_src/cache_key.py @@ -23,6 +23,7 @@ from jax._src import config from jax._src.lib import version_str as jaxlib_version_str +from jax._src.lib import _jax from jax._src.lib import xla_client from jax._src.lib.mlir import ir from jax._src.lib.mlir import passmanager as pm @@ -56,7 +57,7 @@ def get_flag_prefixes() -> list[str]: def custom_hook() -> str: """Custom hook for any addition to the cache key. - The custom hook will be called everytime get() is called and can be + The custom hook will be called every time get() is called and can be defined to return a string that will be hashed into the cache key. """ return "" @@ -110,6 +111,10 @@ def get( bytes(jaxlib_version_str.encode("utf-8")) ), ), + ( + "backend version", + lambda hash_obj: _hash_platform(hash_obj, backend) + ), ( "XLA flags", lambda hash_obj: _hash_xla_flags(hash_obj, get_flag_prefixes()), @@ -126,7 +131,7 @@ def get( ), ( "accelerator_config", - lambda hash_obj: _hash_accelerator_config(hash_obj, devices, backend), + lambda hash_obj: _hash_accelerator_config(hash_obj, devices), ), ( "compression", @@ -220,7 +225,7 @@ def _hash_devices(hash_obj, devices: np.ndarray) -> None: _hash_string(hash_obj, device.device_kind) -def _hash_accelerator_config(hash_obj, accelerators: np.ndarray, backend): +def _hash_accelerator_config(hash_obj, accelerators: np.ndarray): accelerator_devices = [] for accelerator in accelerators.flat: accelerator_devices.append(accelerator) @@ -228,14 +233,13 @@ def _hash_accelerator_config(hash_obj, accelerators: np.ndarray, backend): hash_obj.update( xla_client.get_topology_for_devices(accelerator_devices).serialize() ) - except xla_client._xla.XlaRuntimeError as ex: + except _jax.JaxRuntimeError as ex: # Fall back for those backends that do not support serialized # PjRtTopologyDescription as yet. logger.info("get (_hash_accelerator_config): unable to hash " "accelerator config, falling back to hashing " - "devices + platform: %s (type %s)", ex, type(ex)) + "devices %s (type %s)", ex, type(ex)) _hash_devices(hash_obj, accelerators) - _hash_platform(hash_obj, backend) # LINT.IfChange(xla_flags) xla_flags_to_exclude_from_cache_key = [ @@ -330,7 +334,6 @@ def _hash_serialized_compile_options(hash_obj, compile_options_obj, def _hash_platform(hash_obj, backend): _hash_string(hash_obj, backend.platform) _hash_string(hash_obj, backend.platform_version) - _hash_string(hash_obj, backend.runtime_type) def _hash_xla_flags(hash_obj, extra_flag_prefixes: list[str]): diff --git a/jax/_src/callback.py b/jax/_src/callback.py index bdceb98d92b7..1cafe97d663a 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -18,12 +18,11 @@ import dataclasses import functools import logging -from typing import Any +from typing import Any, cast -import jax +from jax._src import api from jax._src import config from jax._src import core -from jax._src import deprecations from jax._src import dispatch from jax._src import dtypes from jax._src import effects @@ -36,21 +35,15 @@ from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir -from jax._src.interpreters import xla -from jax._src.lax.control_flow.loops import map as lax_map from jax._src.lib import xla_client as xc from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo -from jax._src.sharding_impls import SdyArraySharding, SdyArrayShardingList, SingleDeviceSharding -from jax._src.typing import DeprecatedArg +from jax._src.sharding_impls import SdyArray, SdyArrayList, SdyDim, SingleDeviceSharding +from jax._src.typing import Array, DeprecatedArg import numpy as np logger = logging.getLogger(__name__) -# TODO(dfm): Remove after 6 months. -# Added Oct 1, 2024 -deprecations.register("jax-callback-vectorized") - # `pure_callback_p` is the main primitive for staging out Python pure callbacks. pure_callback_p = core.Primitive("pure_callback") pure_callback_p.multiple_results = True @@ -72,7 +65,7 @@ class _FlatCallback: callback_func: Callable[..., Any] in_tree: tree_util.PyTreeDef # (args, kwargs) pytree for `callback_func`. - def __call__(self, *flat_args: jax.Array) -> Sequence[jax.Array]: + def __call__(self, *flat_args: Array) -> Sequence[Array]: args, kwargs = tree_util.tree_unflatten(self.in_tree, flat_args) return tree_util.tree_leaves(self.callback_func(*args, **kwargs)) @@ -82,20 +75,19 @@ def pure_callback_impl( result_avals, callback: _FlatCallback, sharding: SingleDeviceSharding | None, - vectorized: bool | DeprecatedArg, vmap_method: str | None, ): - del sharding, vectorized, vmap_method, result_avals + del sharding, vmap_method, result_avals try: - cpu_device, *_ = jax.local_devices(backend="cpu") + cpu_device, *_ = xb.local_devices(backend="cpu") except RuntimeError as e: raise RuntimeError( "jax.pure_callback failed to find a local CPU device to place the" " inputs on. Make sure \"cpu\" is listed in --jax_platforms or the" " JAX_PLATFORMS environment variable." ) from e - args = jax.device_put(args, cpu_device) - with jax.default_device(cpu_device): + args = api.device_put(args, cpu_device) + with config.default_device(cpu_device): try: return tree_util.tree_map(np.asarray, callback(*args)) except BaseException: @@ -113,10 +105,9 @@ def pure_callback_abstract_eval( callback: _FlatCallback, result_avals, sharding: SingleDeviceSharding | None, - vectorized: bool | DeprecatedArg, vmap_method: str | None, ): - del avals, callback, sharding, vectorized, vmap_method + del avals, callback, sharding, vmap_method return result_avals @@ -143,6 +134,17 @@ def pure_callback_transpose_rule(*args, **kwargs): ffi.ffi_batching_rule, pure_callback_p ) +def _get_sdy_array_list_for_callbacks(avals: Sequence[core.ShapedArray]) -> SdyArrayList: + """Returns an SdyArrayList with `max(1, len(avals))` replicated shardings.""" + ndims = [0] + if avals: + ndims = [x.ndim for x in avals if isinstance(x, core.ShapedArray)] + return SdyArrayList([ + SdyArray( + mesh_shape=(), + dim_shardings=[SdyDim(axes=[], is_open=False)] * ndim, + logical_device_ids=()) for ndim in ndims]) + def _callback_op_sharding( axis_context, sharding: SingleDeviceSharding | None, avals_out @@ -161,14 +163,7 @@ def _callback_op_sharding( " computations" ) if config.use_shardy_partitioner.value: - assert len(avals_out) == 1 - op_sharding = sharding_impls.SdyArrayShardingList([ - sharding_impls.SdyArraySharding( - mesh_shape=(), - dimension_shardings=[ - sharding_impls.SdyDimSharding(axes=[], is_closed=True) - ] * avals_out[0].ndim, - logical_device_ids=())]) + op_sharding = _get_sdy_array_list_for_callbacks(avals_out) else: op_sharding = xc.OpSharding() # type: ignore[assignment] op_sharding.type = xc.OpSharding.Type.MANUAL @@ -200,10 +195,14 @@ def _callback_op_sharding( # program has bulk array semantics, so we run the callback with a MAXIMAL # sharding and hence execute it only once on the full logical value). if config.use_shardy_partitioner.value: - op_sharding = sharding_impls.SdyArrayShardingList([ - sharding_impls.SdyArraySharding( + # For shardy, we need to have the same number of shardy annotations as the + # number of result ops. If there are no result ops, we need 1 shardy + # annotation. + num_sdy_shardings = max(1, len(avals_out)) + op_sharding = SdyArrayList(num_sdy_shardings * [ + SdyArray( mesh_shape=(), - dimension_shardings=[], + dim_shardings=[], logical_device_ids=(device_index,))]) else: op_sharding = xc.OpSharding() # type: ignore[assignment] @@ -239,12 +238,15 @@ def _callback(*flat_args): ctx.avals_in, ctx.avals_out, has_side_effect=False, + returns_token=False, sharding=op_sharding, ) return result -mlir.register_lowering(pure_callback_p, pure_callback_lowering) +# TODO(phawkins): On TPU, these have an embedded channel ID that should be +# unique for each callback. Caching defeats this. +mlir.register_lowering(pure_callback_p, pure_callback_lowering, cacheable=False) def _check_shape_dtype(shape_dtype): dt = np.dtype(shape_dtype.dtype) @@ -287,7 +289,7 @@ def pure_callback( When `vmap`-ed the behavior will depend on the value of the ``vmap_method``. * Calling :func:`~jax.vmap` on a callback without an explicit ``vmap_method`` - is deprecated and it will eventually raise ``NotImplementedError``. + raises a ``NotImplementedError``. * ``vmap_method="sequential"`` uses :func:`~jax.lax.map` to loop over the batched arguments, calling ``callback`` once for each batch element. * ``vmap_method="sequential_unrolled"`` is like ``sequential``, but the loop @@ -297,9 +299,8 @@ def pure_callback( * ``vmap_method="broadcast_all"`` behaves like ``expand_dims``, but the inputs are tiled to the expected batched shape. - If necessary, the legacy behavior provided by the deprecated - ``vectorized=True`` argument can be recovered using - ``vmap_method="legacy_vectorized"``. + If necessary, the legacy behavior provided by the removed ``vectorized=True`` + argument can be recovered using ``vmap_method="legacy_vectorized"``. The current default behavior is to use ``vmap_method="sequential"`` when not specified, but this behavior is deprecated, and in the future, the @@ -366,20 +367,13 @@ def pure_callback( (4,) (4,) Array([1., 2., 3., 4.], dtype=float32) - .. _External Callbacks: https://jax.readthedocs.io/en/latest/external-callbacks.html + .. _External Callbacks: https://docs.jax.dev/en/latest/external-callbacks.html """ - if not isinstance(vectorized, DeprecatedArg) and not vectorized is None: - deprecations.warn( - "jax-callback-vectorized", - "The vectorized argument of jax.pure_callback is deprecated and setting " - "it will soon raise an error. To avoid an error in the future, and to " - "suppress this warning, please use the vmap_method argument instead.", - stacklevel=2) - if vmap_method is not None: - raise ValueError( - "the vectorized and vmap_method arguments of jax.pure_callback cannot " - "be used together. Please use the vmap_method argument.") - vmap_method = "legacy_vectorized" if vectorized else "sequential" + # TODO(danfm): Remove this check 3 months after v0.6.0 is released. + if not isinstance(vectorized, DeprecatedArg): + raise ValueError( + "The 'vectorized' argument of jax.pure_callback was removed in JAX " + "v0.6.0. Use 'vmap_method' instead.") allowed_vmap_methods = ["sequential", "sequential_unrolled", "expand_dims", "broadcast_all", "legacy_vectorized", None] if vmap_method not in allowed_vmap_methods: @@ -397,7 +391,6 @@ def pure_callback( callback=_FlatCallback(callback, in_tree), result_avals=tuple(flat_result_avals), sharding=sharding, - vectorized=vectorized, vmap_method=vmap_method, ) return tree_util.tree_unflatten(out_tree, out_flat) @@ -434,15 +427,15 @@ def io_callback_impl( ): del result_avals, sharding, ordered try: - cpu_device, *_ = jax.local_devices(backend="cpu") + cpu_device, *_ = xb.local_devices(backend="cpu") except RuntimeError as e: raise RuntimeError( "jax.io_callback failed to find a local CPU device to place the" " inputs on. Make sure \"cpu\" is listed in --jax_platforms or the" " JAX_PLATFORMS environment variable." ) from e - args = jax.device_put(args, cpu_device) - with jax.default_device(cpu_device): + args = api.device_put(args, cpu_device) + with config.default_device(cpu_device): try: return tree_util.tree_map(np.asarray, callback(*args)) except BaseException: @@ -482,6 +475,7 @@ def io_callback_transpose_rule(*args, **kwargs): def io_callback_batching_rule( args, dims, callback, result_avals, sharding, ordered ): + from jax._src.lax.control_flow.loops import map as lax_map # pytype: disable=import-error if ordered: raise ValueError("Cannot `vmap` ordered IO callback.") is_batched = [d is not batching.not_mapped for d in dims] @@ -521,6 +515,7 @@ def _callback(*flat_args): ctx.avals_in, ctx.avals_out, has_side_effect=True, + returns_token=True, sharding=op_sharding, ) ctx.set_tokens_out(mlir.TokenSet({_OrderedIOEffect: token})) @@ -533,13 +528,14 @@ def _callback(*flat_args): ctx.avals_in, ctx.avals_out, has_side_effect=True, + returns_token=False, sharding=op_sharding, ) return result - -mlir.register_lowering(io_callback_p, io_callback_lowering) - +# TODO(phawkins): On TPU, these have an embedded channel ID that should be +# unique for each callback. Caching defeats this. +mlir.register_lowering(io_callback_p, io_callback_lowering, cacheable=False) def io_callback( callback: Callable[..., Any], @@ -575,7 +571,7 @@ def io_callback( - :func:`jax.debug.callback`: callback designed for general-purpose debugging. - :func:`jax.debug.print`: callback designed for printing. - .. _External Callbacks: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html + .. _External Callbacks: https://docs.jax.dev/en/latest/notebooks/external_callbacks.html """ flat_args, in_tree = tree_util.tree_flatten((args, kwargs)) tree_util.tree_map(_check_shape_dtype, result_shape_dtypes) @@ -592,26 +588,39 @@ def io_callback( return tree_util.tree_unflatten(out_tree, out_flat) - def is_empty_shape(s: core.Shape) -> bool: return any(d == 0 for d in s) +_XLA_HOST_TRANSFER_PJRT_RENDEZVOUS_HANDLER_NAME = "pjrt_rendezvous" + + def send_to_host( channel: int, token: hlo.TokenType, operand: Any, - name: str, + name: str | None = None, *, - sharding: SdyArrayShardingList | xc.OpSharding | None = None, + sharding: SdyArrayList | xc.OpSharding | None = None, ) -> ir.Value: channel_handle = hlo.ChannelHandle.get(channel, mlir.SEND_TO_HOST_TYPE) send_op = hlo.SendOp([operand], token, channel_handle, is_host_transfer=ir.BoolAttr.get(True)) - send_op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get( - dict( - _xla_host_transfer_handler_name=ir.StringAttr.get(str(name)), - _xla_host_transfer_rendezvous=ir.StringAttr.get(str(name)))) + if mlir.USE_NEW_TPU_CALLBACK_LOWERING: + send_op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get( + dict( + _xla_host_transfer_handler_name=ir.StringAttr.get( + _XLA_HOST_TRANSFER_PJRT_RENDEZVOUS_HANDLER_NAME + ), + _xla_host_transfer_rendezvous=ir.StringAttr.get(str(channel)), + ) + ) + else: + assert name is not None + send_op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get( + dict( + _xla_host_transfer_handler_name=ir.StringAttr.get(str(name)), + _xla_host_transfer_rendezvous=ir.StringAttr.get(str(name)))) if sharding is not None: if config.use_shardy_partitioner.value: # `SendOp`'s return type is a StableHLO `TokenType`. However JAX passed @@ -619,11 +628,11 @@ def send_to_host( # we need to create an equivalent sharding with no dimensions. If there # are multiple shardings, just grab the first one since all these # shardings should be the same. - assert isinstance(sharding, SdyArrayShardingList) + assert isinstance(sharding, SdyArrayList) assert len(sharding.shardings) >= 1 - sharding = SdyArrayShardingList([ - SdyArraySharding( - mesh_shape=(), dimension_shardings=[], + sharding = SdyArrayList([ + SdyArray( + mesh_shape=(), dim_shardings=[], logical_device_ids=sharding.shardings[0].logical_device_ids)]) mlir.set_sharding(send_op, sharding) return send_op.result @@ -633,32 +642,43 @@ def receive_from_host( channel: int, token: hlo.TokenType, out_aval: core.ShapedArray, - name: str, + name: str | None = None, *, - sharding: SdyArrayShardingList | xc.OpSharding | None = None, + sharding: SdyArrayList | xc.OpSharding | None = None, ) -> tuple[ir.Value, ir.Value]: channel_handle = hlo.ChannelHandle.get(channel, mlir.RECV_FROM_HOST_TYPE) recv_op = hlo.RecvOp([mlir.aval_to_ir_type(out_aval), hlo.TokenType.get()], token, channel_handle, is_host_transfer=ir.BoolAttr.get(True)) - recv_op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get( - dict( - _xla_host_transfer_handler_name=ir.StringAttr.get(str(name)), - _xla_host_transfer_rendezvous=ir.StringAttr.get(str(name)))) + if mlir.USE_NEW_TPU_CALLBACK_LOWERING: + recv_op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get( + dict( + _xla_host_transfer_handler_name=ir.StringAttr.get( + _XLA_HOST_TRANSFER_PJRT_RENDEZVOUS_HANDLER_NAME + ), + _xla_host_transfer_rendezvous=ir.StringAttr.get(str(channel)), + ) + ) + else: + assert name is not None + recv_op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get( + dict( + _xla_host_transfer_handler_name=ir.StringAttr.get(str(name)), + _xla_host_transfer_rendezvous=ir.StringAttr.get(str(name)))) if sharding is not None: if config.use_shardy_partitioner.value: - assert isinstance(sharding, SdyArrayShardingList) + assert isinstance(sharding, SdyArrayList) assert len(sharding.shardings) >= 1 - # `RecvOp`'s last argument is a `TokenType`. Since Shardy requires the + # `RecvOp`'s last argument is a `TokenType`. Since Shardy requires the # number of shardings to match the number of results, but JAX only sees # the array result, we need to add an equivalent sharding for the token. # Note that even if a function returns N results, we will end up with N # `RecvOp`s, so we only need to get the first sharding. All shardings are # the same anyways, operating on the same single device ID. - sharding = SdyArrayShardingList([ + sharding = SdyArrayList([ sharding.shardings[0], - SdyArraySharding( - mesh_shape=(), dimension_shardings=[], + SdyArray( + mesh_shape=(), dim_shardings=[], logical_device_ids=sharding.shardings[0].logical_device_ids)]) mlir.set_sharding(recv_op, sharding) # Token should be at the end of the results @@ -666,8 +686,26 @@ def receive_from_host( return token, result +def _aval_to_xla_shape(aval: core.AbstractValue) -> xc.Shape: + try: + return _xla_shape_handlers[type(aval)](aval) + except KeyError as err: + raise TypeError(f"No xla_shape_handler for type: {type(aval)}") from err + +_xla_shape_handlers: dict[type[core.AbstractValue], + Callable[[Any], xc.Shape]] = {} + +def _make_array_shape(aval: core.ShapedArray) -> xc.Shape: + aval = core.physical_aval(aval) + dtype = np.dtype('bool') if aval.dtype == dtypes.float0 else aval.dtype + return xc.Shape.array_shape(dtype, aval.shape) +_xla_shape_handlers[core.ShapedArray] = _make_array_shape + +_xla_shape_handlers[core.AbstractToken] = lambda _: xc.Shape.token_shape() + + def _emit_tpu_python_callback( - backend: xb.XlaBackend, + backend: xc.Client, ctx: mlir.LoweringRuleContext, callback, token: Any | None, @@ -677,7 +715,8 @@ def _emit_tpu_python_callback( result_avals: Sequence[core.ShapedArray], result_shapes: Sequence[xc.Shape], *, - sharding: SdyArrayShardingList | xc.OpSharding | None = None, + returns_token: bool, + sharding: SdyArrayList | xc.OpSharding | None = None, ) -> tuple[Sequence[ir.Value], Any]: token = token or hlo.create_token() _wrapped_callback = callback @@ -695,27 +734,62 @@ def _wrapped_callback(*args): # pylint: disable=function-redefined send_channel = ctx.module_context.new_channel() dummy_send_aval = core.ShapedArray((1,), np.float32) dummy_send_val = mlir.ir_constant(np.zeros(1, np.float32)) - operand_shapes = [*operand_shapes, - xla.aval_to_xla_shapes(dummy_send_aval)[0]] - token = send_to_host(send_channel, token, dummy_send_val, callback.__name__, - sharding=sharding) + operand_shapes = [*operand_shapes, _aval_to_xla_shape(dummy_send_aval)] + if mlir.USE_NEW_TPU_CALLBACK_LOWERING: + token = send_to_host(send_channel, token, dummy_send_val, + sharding=sharding) + else: + token = send_to_host(send_channel, token, dummy_send_val, + _wrapped_callback.__name__, sharding=sharding) send_channels.append(send_channel) else: for operand in operands: channel = ctx.module_context.new_channel() - token = send_to_host(channel, token, operand, callback.__name__, - sharding=sharding) + if mlir.USE_NEW_TPU_CALLBACK_LOWERING: + token = send_to_host(channel, token, operand, sharding=sharding) + else: + token = send_to_host(channel, token, operand, + _wrapped_callback.__name__, sharding=sharding) send_channels.append(channel) recv_channels = [] outputs = [] - for result_aval in result_avals: + if returns_token and not result_avals: + # If the caller expects a token, we need at least one result so that the + # token from the recv is used as an indication that the callback is + # complete. Without this, we would only wait for the send to finish. + callback_without_results = _wrapped_callback + def _wrapped_callback(*args): # pylint: disable=function-redefined + callback_without_results(*args) + return 0.0, + dummy_recv_aval = core.ShapedArray((), np.float32) + result_shapes = [_aval_to_xla_shape(dummy_recv_aval)] channel = ctx.module_context.new_channel() - assert isinstance(result_aval, core.ShapedArray) - token, out = receive_from_host(channel, token, result_aval, - callback.__name__, sharding=sharding) - outputs.append(out) + if mlir.USE_NEW_TPU_CALLBACK_LOWERING: + token, _ = receive_from_host( + channel, token, dummy_recv_aval, sharding=sharding + ) + else: + token, _ = receive_from_host( + channel, token, dummy_recv_aval, _wrapped_callback.__name__, + sharding=sharding + ) recv_channels.append(channel) + else: + for result_aval in result_avals: + channel = ctx.module_context.new_channel() + assert isinstance(result_aval, core.ShapedArray) + if mlir.USE_NEW_TPU_CALLBACK_LOWERING: + token, out = receive_from_host( + channel, token, result_aval, sharding=sharding + ) + else: + token, out = receive_from_host( + channel, token, result_aval, _wrapped_callback.__name__, + sharding=sharding + ) + outputs.append(out) + recv_channels.append(channel) ifrt_callback = backend.make_python_callback_from_host_send_and_recv( _wrapped_callback, operand_shapes, result_shapes, send_channels, recv_channels, pickle_util.dumps) @@ -723,21 +797,6 @@ def _wrapped_callback(*args): # pylint: disable=function-redefined return outputs, token -def _layout_to_mlir_layout(minor_to_major: Sequence[int] | None): - if minor_to_major is None: - # Needed for token layouts - layout: np.ndarray = np.zeros((0,), dtype="int64") - else: - layout = np.array(minor_to_major, dtype="int64") - return ir.DenseIntElementsAttr.get(layout, type=ir.IndexType.get()) - - -def _aval_to_default_layouts(aval): - avals = [core.physical_aval(aval)] - # Row major order is default for `NumPy`. - return [list(range(aval.ndim - 1, -1, -1)) for aval in avals] - - def emit_python_callback( ctx: mlir.LoweringRuleContext, callback, @@ -747,30 +806,44 @@ def emit_python_callback( result_avals: Sequence[core.ShapedArray], *, has_side_effect: bool, - sharding: SdyArrayShardingList | xc.OpSharding | None = None, - operand_layouts: Sequence[Sequence[int] | None] | None = None, - result_layouts: Sequence[Sequence[int] | None] | None = None, + returns_token: bool = True, + partitioned: bool = False, + sharding: SdyArrayList | xc.OpSharding | None = None, ) -> tuple[Sequence[mlir.IrValues], Any, Any]: - """Emits MLIR that calls back to a provided Python function.""" + """Emits MLIR that calls back to a provided Python function. + + Args: + ctx: The lowering context. + callback: The Python callback function. + token: The token to use for the callback. + operands: The operands to the callback. + operand_avals: The abstract values of the operands. + result_avals: The abstract values of the results. + has_side_effect: Whether the callback has side effects. + returns_token: Whether the callback should return a token. + partitioned: If True, then `callback` is called on local shards only. If + False, then `callback` is called on all shards. + sharding: The sharding of the callback. + + Returns: + A tuple of MLIR result values, a new token (if any), and the host callback + object. + """ if len(ctx.module_context.platforms) > 1: raise NotImplementedError("multi-platform lowering for python_callback") platform = ctx.module_context.platforms[0] if platform not in {"cpu", "cuda", "rocm", "tpu"}: raise ValueError( f"`EmitPythonCallback` not supported on {platform} backend.") - backend = ctx.module_context.get_backend() - result_shapes = util.flatten( - [xla.aval_to_xla_shapes(result_aval) for result_aval in result_avals]) - operand_shapes = util.flatten( - [xla.aval_to_xla_shapes(op_aval) for op_aval in operand_avals]) - # Handling layouts - if operand_layouts is None: - operand_layouts = util.concatenate( - map(_aval_to_default_layouts, operand_avals)) - operand_mlir_layouts = map(_layout_to_mlir_layout, operand_layouts) - if result_layouts is None: - result_layouts = util.concatenate(map(_aval_to_default_layouts, result_avals)) - result_mlir_layouts = map(_layout_to_mlir_layout, result_layouts) + if partitioned: + if platform not in {"cpu", "cuda", "rocm"}: + raise NotImplementedError( + f"Partitioned callback not implemented on {platform} backend.") + if result_avals: + raise ValueError("Partitioned callback not supported with return values.") + backend: xc.Client = cast(xc.Client, ctx.module_context.get_backend()) + result_shapes = [_aval_to_xla_shape(aval) for aval in result_avals] + operand_shapes = [_aval_to_xla_shape(aval) for aval in operand_avals] # First we apply checks to ensure output shapes and dtypes match the expected # ones. @@ -781,7 +854,7 @@ def _wrapped_callback(*args): "Mismatched number of outputs from callback. " "Expected: {}, Actual: {}".format(len(result_avals), len(out_vals))) # Handle Python literals, and custom arrays, e.g., tf.Tensor. - out_vals = tuple(xla.canonicalize_dtype(np.asarray(a)) for a in out_vals) + out_vals = tuple(dtypes.canonicalize_value(np.asarray(a)) for a in out_vals) for i, (out_val, out_aval) in enumerate(zip(out_vals, result_avals)): if out_val.shape != out_aval.shape: raise RuntimeError( @@ -814,7 +887,7 @@ def _wrapped_callback(*args): backend, ctx, _wrapped_callback, token, operands, operand_avals, operand_shapes, non_empty_result_avals, non_empty_result_shapes, - sharding=sharding) + returns_token=returns_token, sharding=sharding) non_empty_outputs_iter = iter(non_empty_outputs) outputs = [ mlir.ir_constant(np.zeros(result_aval.shape, dtype=result_aval.dtype)) @@ -822,55 +895,51 @@ def _wrapped_callback(*args): for result_aval in result_avals] return outputs, token, None - result_types = mlir.flatten_ir_types([mlir.aval_to_ir_type(aval) for aval in result_avals]) + device = "gpu" if platform in {"cuda", "rocm"} else "cpu" + partition = "_partitioned" if partitioned else "" + call_target_name = f"xla_ffi{partition}_python_{device}_callback" if token: - callback_without_token = _wrapped_callback def _wrapped_callback(token, *args): # type: ignore # pylint: disable=function-redefined return (token, *callback_without_token(*args)) - - operand_shapes = [ - xla.aval_to_xla_shapes(core.abstract_token)[0], *operand_shapes - ] - result_shapes = [ - xla.aval_to_xla_shapes(core.abstract_token)[0], *result_shapes - ] operands = [token, *operands] - result_types = [mlir.token_type(), *result_types] - operand_mlir_layouts = [_layout_to_mlir_layout(None), *operand_mlir_layouts] - result_mlir_layouts = [_layout_to_mlir_layout(None), *result_mlir_layouts] - callback_descriptor, ifrt_callback = ( - backend.get_emit_python_callback_descriptor(_wrapped_callback, - operand_shapes, - result_shapes)) + if ( + config.use_shardy_partitioner.value + and sharding is not None + and len(ctx.avals_out) > 0 + and isinstance(sharding, SdyArrayList) + ): + # Add a sharding annotation for the token if we have at least one + # output. Otherwise, the single shardy annotation required of all ops + # (even those without any results) can annotate the token. + sharding = SdyArrayList([ + SdyArray( + mesh_shape=(), + dim_shardings=[], + logical_device_ids=()), + *sharding.shardings]) + ctx = dataclasses.replace( + ctx, + avals_in=[core.abstract_token, *ctx.avals_in], + avals_out=[core.abstract_token, *ctx.avals_out], + ) + + # TODO(dsuo): Remove this line once we deprecate the XLA custom call + # handler. + ifrt_callback = _wrapped_callback ctx.module_context.add_host_callback(ifrt_callback) - descriptor_operand = mlir.ir_constant(callback_descriptor) - callback_operands = [descriptor_operand, *operands] - if operand_mlir_layouts is not None: - operand_mlir_layouts = [_layout_to_mlir_layout([]), *operand_mlir_layouts] - result_type = ir.TupleType.get_tuple(result_types) - call_target_name = ("xla_python_gpu_callback" - if platform in {"cuda", "rocm"} else "xla_python_cpu_callback") - result = hlo.CustomCallOp( - [result_type], - callback_operands, - call_target_name=ir.StringAttr.get(call_target_name), - has_side_effect=ir.BoolAttr.get(has_side_effect), - api_version=mlir.i32_attr(2), - called_computations=ir.ArrayAttr.get([]), - backend_config=ir.StringAttr.get(str(callback_descriptor)), - operand_layouts=( - None if operand_mlir_layouts is None - else ir.ArrayAttr.get(operand_mlir_layouts)), - result_layouts=( - None if result_mlir_layouts is None - else ir.ArrayAttr.get(result_mlir_layouts))) + index = np.uint64(len(ctx.module_context.host_callbacks) - 1) + result = ffi.build_ffi_lowering_function( # type: ignore + call_target_name, + has_side_effect=has_side_effect, + )(ctx, *operands, index=np.uint64(index)) + if sharding is not None: mlir.set_sharding(result, sharding) - results = [ - hlo.get_tuple_element(result, mlir.i32_attr(i)) - for i in range(len(result_types)) - ] + + results = result.results # type: ignore + if token: - token, *results = results - return results, token, ifrt_callback + token, *results = results # type: ignore + + return results, token, ifrt_callback # type: ignore diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 1ec8ad50b456..697b3fd0ff0f 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -21,22 +21,21 @@ import numpy as np -import jax.numpy as jnp -from jax import dtypes -from jax import lax - -from jax.experimental import shard_map +from jax._src import ad_checkpoint from jax._src import api from jax._src import api_util -from jax._src import ad_checkpoint -from jax._src import linear_util as lu from jax._src import callback from jax._src import config from jax._src import core from jax._src import custom_derivatives +from jax._src import dtypes from jax._src import effects -from jax._src import pjit +from jax._src import lax +from jax._src import linear_util as lu from jax._src import mesh as mesh_lib +from jax._src import numpy as jnp +from jax._src import pjit +from jax._src import shard_map as jshmap from jax._src import sharding_impls from jax._src import source_info_util from jax._src import traceback_util @@ -46,13 +45,22 @@ from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe +from jax._src.partition_spec import PartitionSpec as P from jax._src.tree_util import tree_flatten -from jax._src.tree_util import tree_map +from jax._src.tree_util import tree_map, FlatTree from jax._src.tree_util import tree_unflatten from jax._src.typing import Array from jax._src.util import (as_hashable_function, split_list, safe_map, safe_zip, unzip3, weakref_lru_cache, HashableWrapper, foreach) +# Backward compatibility: some downstream users implicitly rely on this import, +# and reference jax.experimental.shard_map without an explicit import. +# TODO(yashkatariya): remove this once users are migrated to jax.shard_map. +try: + import jax.experimental.shard_map as _ # pytype: disable=import-error # noqa: F401 +except ImportError: + pass + source_info_util.register_exclusion(__file__) traceback_util.register_exclusion(__file__) @@ -115,8 +123,10 @@ def __lt__(self, other: ErrorEffect): unpack = lambda x: (str(x.error_type), shape_dtypes(x)) return (unpack(self) < unpack(other)) -effects.control_flow_allowed_effects.add_type(ErrorEffect) effects.lowerable_effects.add_type(ErrorEffect) +effects.control_flow_allowed_effects.add_type(ErrorEffect) +effects.custom_derivatives_allowed_effects.add_type(ErrorEffect) +effects.remat_allowed_effects.add_type(ErrorEffect) class DivisionByZeroError(JaxException): @@ -167,7 +177,7 @@ def __str__(self): f'{self._payload[1]} with size {self._payload[2]}. ') def get_effect_type(self): - return ErrorEffect(OOBError, (api.ShapeDtypeStruct((3,), jnp.int32),)) + return ErrorEffect(OOBError, (api.ShapeDtypeStruct((3,), np.int32),)) class FailedCheckError(JaxException): @@ -261,7 +271,7 @@ def _get_batched_exception(self) -> BatchedError | None: cur_effect = None for error_effect, code in self._code.items(): if self._pred[error_effect][idx]: # type: ignore - if min_code is None or code[idx] < min_code: + if min_code is None or code[idx] < min_code: # type: ignore[index] min_code = code[idx] # type: ignore cur_effect = error_effect @@ -464,6 +474,7 @@ def _reduce_any_error(error: Error): ## check_p primitive check_p = core.Primitive('check') +check_p.is_effectful = lambda _: True # type: ignore check_p.multiple_results = True # zero results @@ -525,7 +536,8 @@ def check_lowering_rule(ctx, *args, err_tree, debug): operands=args, operand_avals=list(ctx.avals_in), result_avals=list(ctx.avals_out), - has_side_effect=True) + has_side_effect=True, + returns_token=False) return out_op def check_lowering_rule_unsupported(*a, debug, **k): @@ -580,7 +592,7 @@ def check_nans(prim, error, enabled_errors, out): return error def isnan(x): - if jnp.issubdtype(x.dtype, dtypes.prng_key): + if dtypes.issubdtype(x.dtype, dtypes.prng_key): return False return jnp.any(jnp.isnan(x)) @@ -600,7 +612,7 @@ def isnan(x): lax.igamma_p, lax.igammac_p, lax.integer_pow_p, lax.lgamma_p, lax.linear_solve_p, lax.log1p_p, lax.log_p, lax.logistic_p, lax.mul_p, lax.pad_p, lax.pow_p, lax.psum_p, - lax.random_gamma_grad_p, lax.reduce_p, lax.reduce_prod_p, + lax.reduce_p, lax.reduce_prod_p, lax.reduce_sum_p, lax.reduce_window_p, lax.reduce_window_sum_p, lax.regularized_incomplete_beta_p, lax.rem_p, lax.rng_uniform_p, lax.rsqrt_p, lax.sin_p, @@ -616,9 +628,9 @@ def dynamic_slice_error_check(error, enabled_errors, operand, *start_indices, sl if OOBError not in enabled_errors: return error, out - operand_dims = np.array(operand.shape) - slice_sizes = np.array(slice_sizes) start_indices = jnp.array(start_indices) + operand_dims = np.array(operand.shape, dtype=start_indices.dtype) + slice_sizes = np.array(slice_sizes, dtype=start_indices.dtype) oob_mask = (start_indices < 0) | (start_indices + slice_sizes > operand_dims) payload = oob_payload(oob_mask, start_indices, range(operand.ndim), operand.shape) @@ -683,7 +695,7 @@ def oob_payload(oob_mask, indices, dims_map, operand_shape): oob_axis = jnp.array(dims_map)[multi_idx[-1]] oob_axis_size = jnp.array(operand_shape)[oob_axis] oob_index = jnp.ravel(indices)[flat_idx] - payload = jnp.array([oob_index, oob_axis, oob_axis_size], dtype=jnp.int32) + payload = jnp.array([oob_index, oob_axis, oob_axis_size], dtype=np.int32) return payload def scatter_oob(operand, indices, updates, dnums): @@ -741,22 +753,31 @@ def scatter_error_check(prim, error, enabled_errors, operand, indices, updates, # HOP error check rules +@jtu.register_static +class ErrorEffects: + def __init__(self, val): + self.val = val + @weakref_lru_cache def jaxpr_to_checkify_jaxpr( jaxpr: core.ClosedJaxpr, enabled_errors, err_tree: PyTreeDef, *flat_err_and_in_vals) -> tuple[core.ClosedJaxpr, PyTreeDef, set[ErrorEffect]]: - checkify_jaxpr_partial = functools.partial(checkify_jaxpr_flat, jaxpr.jaxpr, - jaxpr.consts, enabled_errors, - err_tree) - fun = lu.wrap_init(checkify_jaxpr_partial, debug_info=jaxpr.jaxpr.debug_info) - fun, metadata = _flatten_and_get_error_metadata_thunk(fun) - - new_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun, flat_err_and_in_vals) - checked_jaxpr = core.ClosedJaxpr(new_jaxpr, consts) - out_tree, error_effects = metadata() - return checked_jaxpr, out_tree, error_effects - -def cond_error_check(error: Error, enabled_errors, index, *ops, branches): + + def fun_wrapped(*invals): + error, out = checkify_jaxpr_flat( + jaxpr.jaxpr, jaxpr.consts, enabled_errors, err_tree, *invals) + error_effects = ErrorEffects(set(error._pred.keys())) + return (error, out), error_effects + + debug_info = jaxpr.jaxpr.debug_info.with_unknown_names() + args_avals = FlatTree.flatten((flat_err_and_in_vals, {})) + checked_jaxpr, full_out_avals = pe.trace_to_jaxpr(fun_wrapped, args_avals, debug_info) + out_avals, error_effects = full_out_avals.unpack() + error_effects = error_effects.unflatten().val + return checked_jaxpr, out_avals.tree, error_effects + +def cond_error_check(error: Error, enabled_errors, index, *ops, + branches, **params): # Get the error-effects out of all branches so the cond can be called with # a merged error with all these effects. err_vals, err_tree = jtu.tree_flatten(error) @@ -777,7 +798,7 @@ def get_error_effects_from_jaxpr(jxpr): err_and_outs = lax.cond_p.bind( index, *err_vals, *ops, - branches=tuple(new_branches)) + branches=tuple(new_branches), **params) # we need to merge metadata across out_trees (a tuple) err0, out = tree_unflatten(out_trees[0], err_and_outs) @@ -832,18 +853,19 @@ def new_body_f(*c_consts_and_vals): c_consts, vals = split_list(c_consts_and_vals, [c_consts_num]) out = body_f(*vals) # This checks if the next cond application will error - _ = cond_f(*c_consts, *out) + lax.dce_sink(cond_f(*c_consts, *out)) return out - new_body_f_ = lu.wrap_init(new_body_f, debug_info=body_jaxpr.jaxpr.debug_info) c_consts_avals = cond_jaxpr.in_avals[:c_consts_num] - jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(new_body_f_, [*c_consts_avals, - *body_jaxpr.in_avals]) - closed_jaxpr = pe.close_jaxpr(jaxpr) + + jaxpr, _ = pe.trace_to_jaxpr( + new_body_f, + FlatTree.flatten(((*c_consts_avals, *body_jaxpr.in_avals), {})), + debug_info=body_jaxpr.jaxpr.debug_info.with_unknown_names()) err_vals, err_tree = jtu.tree_flatten(error) err_vals = map(core.get_aval, err_vals) flat_err_and_in_vals = [*err_vals, *c_consts_avals, *body_jaxpr.in_avals] jaxpr, out_tree, error_effects = jaxpr_to_checkify_jaxpr( - closed_jaxpr, enabled_errors, err_tree, *flat_err_and_in_vals) + jaxpr, enabled_errors, err_tree, *flat_err_and_in_vals) return jaxpr, out_tree, error_effects @@ -913,15 +935,15 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr, # Update pjit params to account for extra error values. num_error_vals = len(err_vals) num_out_error_vals = out_tree.num_leaves - len(out_shardings) - sharding = sharding_impls.UNSPECIFIED new_in_shardings = (*[sharding] * num_error_vals, *in_shardings) - new_out_shardings = (*[sharding] * num_out_error_vals, *out_shardings) new_in_layouts = (*[None] * num_error_vals, *in_layouts) - new_out_layouts = (*[None] * num_out_error_vals, *out_layouts) new_donated_invars = (*[False] * num_error_vals, *donated_invars) - err_and_out = pjit.pjit_p.bind( + new_out_shardings = (*[sharding] * num_out_error_vals, *out_shardings) + new_out_layouts = (*[None] * num_out_error_vals, *out_layouts) + + err_and_out = pjit.jit_p.bind( *new_vals_in, jaxpr=checked_jaxpr, in_shardings=new_in_shardings, @@ -936,7 +958,7 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr, compiler_options_kvs=compiler_options_kvs, ) return tree_unflatten(out_tree, err_and_out) -error_checks[pjit.pjit_p] = pjit_error_check +error_checks[pjit.jit_p] = pjit_error_check def remat_error_check(error, enabled_errors, *vals_in, jaxpr, **params): @@ -954,7 +976,7 @@ def remat_error_check(error, enabled_errors, *vals_in, jaxpr, **params): def shard_map_error_check( error: Error, enabled_errors, *vals_in, - jaxpr: core.Jaxpr, in_names, out_names, **kwargs + jaxpr: core.Jaxpr, in_specs, out_specs, **kwargs ): if (mesh := kwargs.get('mesh')) is None: raise ValueError('Mesh must be provided for shard_map with checkify.') @@ -962,22 +984,24 @@ def shard_map_error_check( err_vals, err_tree = jtu.tree_flatten(error) num_error_vals = len(err_vals) # Replicated sharding for in errors. - new_in_names = (*([{}] * num_error_vals), *in_names) + new_in_specs = (*([P()] * num_error_vals), *in_specs) new_vals_in = [*err_vals, *vals_in] in_avals = list(map(core.get_aval, new_vals_in)) - auto = kwargs.get('auto') + manual_axes = kwargs.get('manual_axes') + check_vma = kwargs.get('check_vma') for i, v in enumerate(in_avals): if not (sharder := core.shard_aval_handlers.get(type(v))): raise ValueError(f'Unsupported aval type: {type(v)}') - in_avals[i] = sharder(mesh, auto, new_in_names[i], v) + in_avals[i] = sharder(mesh, manual_axes, check_vma, new_in_specs[i], v) - with (shard_map._extend_axis_env(mesh, auto), - mesh_lib.use_abstract_mesh(shard_map._as_manual_mesh(mesh, auto))): + with (jshmap._extend_axis_env(mesh, manual_axes), + mesh_lib.use_abstract_mesh(jshmap._as_manual_mesh(mesh, manual_axes)), # type: ignore[arg-type] + config._check_vma(check_vma)): # jaxpr to checked_jaxpr checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr( pe.close_jaxpr(jaxpr), enabled_errors, err_tree, *in_avals ) - num_out_error_vals = out_tree.num_leaves - len(out_names) + num_out_error_vals = out_tree.num_leaves - len(out_specs) def expand_errors_leading_dim(*xs): outs = core.eval_jaxpr(checked_jaxpr.jaxpr, checked_jaxpr.consts, *xs) @@ -985,32 +1009,30 @@ def expand_errors_leading_dim(*xs): errs = [lax.expand_dims(e, [0]) for e in errs] return *errs, *outs - with core.extend_axis_env_nd(mesh.shape.items()): - jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(expand_errors_leading_dim, - debug_info=checked_jaxpr.jaxpr.debug_info), - checked_jaxpr.in_avals - ) - checked_jaxpr = core.ClosedJaxpr(jaxpr, consts) + with core.extend_axis_env_nd(mesh.shape.items()), config._check_vma(check_vma): + checked_jaxpr, _ = pe.trace_to_jaxpr( + expand_errors_leading_dim, + FlatTree.flatten((tuple(checked_jaxpr.in_avals), {})), + debug_info=checked_jaxpr.jaxpr.debug_info) # Update shard_map params to account for extra error values. # Use fully sharded partitioning for out errors. - new_out_names = (*([{0: mesh.axis_names}] * num_out_error_vals), *out_names) + new_out_specs = (*([P(mesh.axis_names)] * num_out_error_vals), *out_specs) subfun = lu.hashable_partial( lu.wrap_init(core.eval_jaxpr, debug_info=checked_jaxpr.jaxpr.debug_info), checked_jaxpr.jaxpr, checked_jaxpr.consts ) new_params = dict( jaxpr=checked_jaxpr.jaxpr, - in_names=new_in_names, - out_names=new_out_names, + in_specs=new_in_specs, + out_specs=new_out_specs, **kwargs, ) - _, new_params = shard_map.shard_map_p.get_bind_params(new_params) + _, new_params = jshmap.shard_map_p.get_bind_params(new_params) - err_and_out = shard_map.shard_map_p.bind(subfun, *new_vals_in, **new_params) + err_and_out = jshmap.shard_map_p.bind(subfun, *new_vals_in, **new_params) return tree_unflatten(out_tree, err_and_out) -error_checks[shard_map.shard_map_p] = shard_map_error_check +error_checks[jshmap.shard_map_p] = shard_map_error_check def custom_jvp_call_rule(in_err: Error, enabled_errors: set, *in_vals, num_consts, @@ -1073,17 +1095,17 @@ def jvp(*xs): return [*primal_errs, *out_primals, *tangent_errs, *out_tangents] return lu.wrap_init(jvp, debug_info=jvp_jaxpr_fun.debug_info) -def custom_vjp_call_jaxpr_rule(in_err, enabled_errors, *in_vals, - fun_jaxpr: core.ClosedJaxpr, - fwd_jaxpr_thunk, num_consts, - bwd: lu.WrappedFun, out_trees, - symbolic_zeros: bool): +def custom_vjp_call_rule(in_err, enabled_errors, *in_vals, + call_jaxpr: core.ClosedJaxpr, + fwd_jaxpr_thunk, num_consts, + bwd: lu.WrappedFun, out_trees, + symbolic_zeros: bool): err_vals, err_tree = jtu.tree_flatten(in_err) num_errs = err_tree.num_leaves checkified_fun = lu.wrap_init( - functools.partial(checkify_jaxpr_flat, fun_jaxpr.jaxpr, - fun_jaxpr.consts, enabled_errors, err_tree), - debug_info=fun_jaxpr.jaxpr.debug_info) + functools.partial(checkify_jaxpr_flat, call_jaxpr.jaxpr, + call_jaxpr.consts, enabled_errors, err_tree), + debug_info=call_jaxpr.jaxpr.debug_info) checkified_fun, fun_metadata = _flatten_and_get_error_metadata_thunk( checkified_fun) @@ -1091,13 +1113,13 @@ def checkified_fwd(*args): # TODO(lenamartens, sharadmv): why not checkify here? xs, zeros = args[::2], args[1::2] xs, zeros = xs[num_errs:], zeros[num_errs:] - fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk(*zeros) + fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk.call_wrapped(*zeros) xs_without_consts = xs[num_consts:] return core.eval_jaxpr(fwd_jaxpr, fwd_consts, *xs_without_consts) # TODO(necula): the fwd result_paths are not quite the same as fun_jaxpr checkified_fwd_wrapped = lu.wrap_init(checkified_fwd, - debug_info=fun_jaxpr.jaxpr.debug_info) + debug_info=fwd_jaxpr_thunk.debug_info) bwd_ = lu.wrap_init(lambda *args: (*(None,)*num_errs, *bwd.call_wrapped(*args)), debug_info=bwd.debug_info) checkified_fwd_wrapped, fwd_out_tree = flatten_fun_output(checkified_fwd_wrapped) @@ -1112,7 +1134,7 @@ def checkified_fwd(*args): else: out_err, out_vals = in_err, all_outs return out_err, out_vals -error_checks[custom_derivatives.custom_vjp_call_jaxpr_p] = custom_vjp_call_jaxpr_rule +error_checks[custom_derivatives.custom_vjp_call_p] = custom_vjp_call_rule def check_discharge_rule(error, enabled_errors, *args, err_tree, debug): @@ -1217,18 +1239,15 @@ def checkify(f: Callable[..., Out], @traceback_util.api_boundary def checked_fun(*args, **kwargs): # close over all arguments so they're not turned into abstract values. - in_tree = jtu.tree_structure(((), {})) + in_avals = FlatTree.flatten(((), {})) closed_f = lambda: f(*args, **kwargs) # stage: - debug = api_util.debug_info("checkify", f, args, kwargs) - fun_, out_tree = api_util.flatten_fun(lu.wrap_init(closed_f, - debug_info=debug), - in_tree) - jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, ()) - jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr_)) + debug_info = api_util.debug_info("checkify", f, args, kwargs).with_unknown_names() + jaxpr_, out_avals = pe.trace_to_jaxpr(closed_f, in_avals, debug_info) + jaxpr, consts = pe.separate_consts(jaxpr_) # checkify: error, out_flat = checkify_jaxpr(jaxpr, errors, init_error, *consts) - return error, jtu.tree_unflatten(out_tree(), out_flat) + return error, out_avals.update(out_flat).unflatten() return checked_fun def check(pred: Bool, msg: str, @@ -1297,7 +1316,7 @@ def _check_error(error, *, debug=False): def is_scalar_pred(pred) -> bool: return (isinstance(pred, bool) or isinstance(pred, Array) and pred.shape == () and - pred.dtype == jnp.dtype('bool')) + pred.dtype == np.dtype('bool')) def debug_check(pred: Bool, msg: str, *fmt_args, **fmt_kwargs) -> None: diff --git a/jax/_src/cloud_tpu_init.py b/jax/_src/cloud_tpu_init.py index 0539e4253063..2563b391ceda 100644 --- a/jax/_src/cloud_tpu_init.py +++ b/jax/_src/cloud_tpu_init.py @@ -16,7 +16,7 @@ import os import re import warnings -from jax import version + from jax._src import config from jax._src import hardware_utils @@ -38,9 +38,16 @@ def maybe_import_libtpu(): def get_tpu_library_path() -> str | None: - path_from_env = os.getenv("TPU_LIBRARY_PATH") - if path_from_env is not None and os.path.isfile(path_from_env): - return path_from_env + path_from_env = os.getenv('TPU_LIBRARY_PATH') + if path_from_env is not None: + if os.path.isfile(path_from_env): + return path_from_env + warning_message = ( + f'TPU_LIBRARY_PATH is set to a non-existent path: {path_from_env}.' + ' Falling back to default libtpu path. Please unset TPU_LIBRARY_PATH' + ' or set it to a valid path.' + ) + warnings.warn(warning_message) libtpu_module = maybe_import_libtpu() if libtpu_module is not None: @@ -71,9 +78,13 @@ def cloud_tpu_init() -> None: """ global running_in_cloud_tpu_vm + from jax import version # pytype: disable=import-error + # Exit early if we're not running on a Cloud TPU VM or libtpu isn't installed. libtpu_path = get_tpu_library_path() num_tpu_chips, tpu_id = hardware_utils.num_available_tpu_chips_and_device_id() + if num_tpu_chips == 0: + os.environ['TPU_SKIP_MDS_QUERY'] = '1' if ( tpu_id is not None and tpu_id >= hardware_utils.TpuVersion.v5e @@ -95,8 +106,13 @@ def cloud_tpu_init() -> None: os.environ.setdefault('TPU_ML_PLATFORM', 'JAX') os.environ.setdefault('TPU_ML_PLATFORM_VERSION', version.__version__) os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1') - if '--xla_tpu_use_enhanced_launch_barrier' not in os.environ.get('LIBTPU_INIT_ARGS', ''): - os.environ['LIBTPU_INIT_ARGS'] = os.environ.get('LIBTPU_INIT_ARGS','') + ' --xla_tpu_use_enhanced_launch_barrier=true' + if '--xla_tpu_use_enhanced_launch_barrier' not in os.environ.get( + 'LIBTPU_INIT_ARGS', '' + ): + os.environ['LIBTPU_INIT_ARGS'] = ( + os.environ.get('LIBTPU_INIT_ARGS', '') + + ' --xla_tpu_use_enhanced_launch_barrier=true' + ) # this makes tensorstore serialization work better on TPU os.environ.setdefault('TENSORSTORE_CURL_LOW_SPEED_TIME_SECONDS', '60') @@ -110,25 +126,23 @@ def cloud_tpu_init() -> None: if config.jax_pjrt_client_create_options.value is None: config.update( - 'jax_pjrt_client_create_options', - f'ml_framework_name:JAX;ml_framework_version:{version.__version__}' - ) + 'jax_pjrt_client_create_options', + f'ml_framework_name:JAX;ml_framework_version:{version.__version__}', + ) -def is_cloud_tpu_older_than(year: int, month: int, day: int): - # We import locally because the functions above must run before the runtime - # modules are imported. - from jax._src import xla_bridge # pytype: disable=import-error - date = datetime.date(year, month, day) - if not running_in_cloud_tpu_vm: +def is_cloud_tpu_older_than(year: int, month: int, day: int, backend): + if 'TFRT TPU' not in backend.platform_version: return False # The format of Cloud TPU platform_version is like: # PJRT C API # TFRT TPU v2 # Built on Oct 30 2023 03:04:42 (1698660263) cl/577737722 - platform_version = xla_bridge.get_backend().platform_version.split('\n')[-1] + platform_version = backend.platform_version.split('\n')[-1] results = re.findall(r'\(.*?\)', platform_version) if len(results) != 1: return True + date = datetime.date(year, month, day) build_date = date.fromtimestamp(int(results[0][1:-1])) - return build_date < date + # Filter out ridiculously old dates that some test builds get. + return build_date < date and build_date.year > 2010 diff --git a/jax/_src/clusters/cloud_tpu_cluster.py b/jax/_src/clusters/cloud_tpu_cluster.py index c8aa765c181c..c4e0f5d1e104 100644 --- a/jax/_src/clusters/cloud_tpu_cluster.py +++ b/jax/_src/clusters/cloud_tpu_cluster.py @@ -26,7 +26,7 @@ # We use an arbitrarily chosen port for the coordinator since we cannot # rely on communication to choose one in real time. -coordinator_port = '8476' +coordinator_port = '8482' metadata_response_code_success = 200 @@ -54,24 +54,26 @@ def get_metadata(key): raise RuntimeError(f"Getting metadata['{key}'] failed for 6 tries") return api_resp.text, api_resp.status_code -def get_tpu_env_value(key): - def get_tpu_env_value_from_metadata(key): - tpu_env_data = get_metadata('tpu-env')[0] - key_value_pairs = tpu_env_data.split('\n') - for key_value_pair in key_value_pairs: - # Typical line is MEGASCALE_NUM_SLICES: '2' - if ':' in key_value_pair: - row_key, value = re.split(':', key_value_pair, 1) - row_key = row_key.strip() - if row_key == key: - return value.strip().strip("'") - return None - +def get_tpu_env_value_from_metadata(key) -> str | None: + metadata_value = None + tpu_env_data = get_metadata('tpu-env')[0] + key_value_pairs = tpu_env_data.split('\n') + for key_value_pair in key_value_pairs: + # Typical line is MEGASCALE_NUM_SLICES: '2' + if ':' in key_value_pair: + row_key, value = re.split(':', key_value_pair, 1) + row_key = row_key.strip() + if row_key == key: + metadata_value = value.strip().strip("'") + return metadata_value + +def get_tpu_env_value(key) -> str | None: + # First try to get the value from the environment. value = os.environ.get(key, None) - return value if value is not None else get_tpu_env_value_from_metadata(key) - -def has_megascale_address(): - return get_tpu_env_value('MEGASCALE_COORDINATOR_ADDRESS') is not None + if value is None: + # If not found, try to get it from the metadata. + value = get_tpu_env_value_from_metadata(key) + return value class BaseTpuCluster(clusters.ClusterEnv): @@ -93,13 +95,12 @@ def is_env_present(cls) -> bool: return False @classmethod - def get_coordinator_address(cls, timeout_secs: int | None) -> str: - if has_megascale_address(): - # For both GCE via QueuedResources and GKE via JobSet, the - # Megascale coordinator address is set as the host with process id = 0, - # so can be used as the jax distributed system coordinator. - coordinator_address = get_tpu_env_value('MEGASCALE_COORDINATOR_ADDRESS') - else: + def get_coordinator_address(cls, timeout_secs: int | None, override_coordinator_port: str | None) -> str: + # For both GCE via QueuedResources and GKE via JobSet, the + # Megascale coordinator address is set as the host with process id = 0, + # so can be used as the jax distributed system coordinator. + coordinator_address = get_tpu_env_value('MEGASCALE_COORDINATOR_ADDRESS') + if not coordinator_address: # For both GCE (QueuedResources and TPUVM create) and GKE via Job API, # the workers lists are sorted by process ID so the first one can # be used as the jax distributed system coordinator. @@ -107,7 +108,8 @@ def get_coordinator_address(cls, timeout_secs: int | None) -> str: coordinator_address = coordinator_address.split(':')[0] logger.debug("TPU Cluster using coordinator address: %s", coordinator_address) cls.wait_for_coordinator(coordinator_address, timeout_secs) - return f'{coordinator_address}:{coordinator_port}' + port = override_coordinator_port or coordinator_port + return f'{coordinator_address}:{port}' @classmethod def wait_for_coordinator(cls, coordinator_address, timeout_secs): @@ -149,17 +151,18 @@ def get_process_id(cls) -> int: @staticmethod def _get_num_slices() -> int: - if has_megascale_address(): - return int(get_tpu_env_value('MEGASCALE_NUM_SLICES')) - else: + num_slices = get_tpu_env_value('MEGASCALE_NUM_SLICES') + if not num_slices: return 1 + return int(num_slices) # type: ignore + @staticmethod def _get_slice_id() -> int: - if has_megascale_address(): - return int(get_tpu_env_value('MEGASCALE_SLICE_ID')) - else: + slice_id = get_tpu_env_value('MEGASCALE_SLICE_ID') + if not slice_id: return 0 + return int(slice_id) # type: ignore @staticmethod def _get_process_id_in_slice() -> int: @@ -208,20 +211,38 @@ class GkeTpuCluster(BaseTpuCluster): @classmethod def is_env_present(cls) -> bool: - if running_in_cloud_tpu_vm and os.environ.get("TPU_WORKER_HOSTNAMES") is not None: + if running_in_cloud_tpu_vm and cls._get_worker_host_names_env_var() is not None: logger.debug("Gke Tpu Cluster detected for Jax Distributed System") return True else: if not running_in_cloud_tpu_vm: logger.debug("Did not detect cloud TPU VM") else: - logger.debug("Did not detect TPU GKE cluster since TPU_WORKER_HOSTNAMES is not set") + logger.debug("Did not detect TPU GKE cluster since neither " + "TPU_PROCESS_ADDRESSES nor TPU_WORKER_HOSTNAMES is set.") return False @staticmethod def _get_process_id_in_slice() -> int: return int(str(os.environ.get('TPU_WORKER_ID'))) + @staticmethod + def _get_worker_host_names_env_var() -> str | None: + """ + Retrieves the list of worker hostnames from environment variables. + + Checks 'TPU_PROCESS_ADDRESSES' first, then 'TPU_WORKER_HOSTNAMES'. + Returns None if neither environment variable is set. + """ + worker_hostnames = os.environ.get('TPU_PROCESS_ADDRESSES', None) + if worker_hostnames is not None: + return worker_hostnames + return os.environ.get('TPU_WORKER_HOSTNAMES', None) + @staticmethod def _get_worker_list_in_slice() -> list[str]: - return str(os.environ.get('TPU_WORKER_HOSTNAMES', None)).split(',') + """ + Returns a list of worker endpoints/hostnames within slice. + """ + worker_hostnames_str = str(GkeTpuCluster._get_worker_host_names_env_var()) + return worker_hostnames_str.split(',') diff --git a/jax/_src/clusters/cluster.py b/jax/_src/clusters/cluster.py index 69ef77a6421d..8fe3f3605e5b 100644 --- a/jax/_src/clusters/cluster.py +++ b/jax/_src/clusters/cluster.py @@ -15,6 +15,7 @@ from __future__ import annotations from collections.abc import Sequence +import os import logging from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm @@ -23,7 +24,7 @@ class ClusterEnv: """Interface for defining a cluster environment. - To enable auto bootrapping (aka :func:`jax.distributed.initialize()`), + To enable auto bootstrapping (aka :func:`jax.distributed.initialize()`), cluster environments need to derive from :class:`ClusterEnv` and implement :func:`is_env_present`, :func:`get_coordinator_address`, :func:`get_process_count`, and :func:`get_process_id`. @@ -69,7 +70,8 @@ def auto_detect_unset_distributed_params(cls, if env: logger.debug('Initializing distributed JAX environment via %s', env.__name__) if coordinator_address is None: - coordinator_address = env.get_coordinator_address(timeout_secs=initialization_timeout) + coordinator_port = os.environ.get("JAX_COORDINATOR_PORT") + coordinator_address = env.get_coordinator_address(timeout_secs=initialization_timeout, override_coordinator_port=coordinator_port) if num_processes is None: num_processes = env.get_process_count() if process_id is None: @@ -95,7 +97,7 @@ def is_env_present(cls) -> bool: raise NotImplementedError("ClusterEnv subclasses must implement is_env_present") @classmethod - def get_coordinator_address(cls, timeout_secs: int | None) -> str: + def get_coordinator_address(cls, timeout_secs: int | None, override_coordinator_port: str | None) -> str: """Returns address and port used by JAX to bootstrap. Process id 0 will open a tcp socket at "hostname:port" where diff --git a/jax/_src/clusters/k8s_cluster.py b/jax/_src/clusters/k8s_cluster.py index 1274724b8ebd..a3c6c9af56b9 100644 --- a/jax/_src/clusters/k8s_cluster.py +++ b/jax/_src/clusters/k8s_cluster.py @@ -16,13 +16,51 @@ from contextlib import contextmanager from functools import cache +from itertools import chain +import logging +import numpy as np import os import socket +import time import textwrap import warnings from jax._src import clusters +logger = logging.getLogger(__name__) + + +def retry( + func=None, + initial_delay=0, + wait=np.logspace(-1, 1, 5) * np.random.rand(5), + exceptions=Exception, +): + def retry_decorator(func): + def retry_driver(*args, **kwargs): + # Retry the function call with exponential backoff + for i, t in enumerate(chain([initial_delay], wait)): + logger.debug( + f"Trying {func.__name__} in {t:.2f} seconds, attempt {i}/{len(wait)}" + ) + time.sleep(t) + try: + return func(*args, **kwargs) + except exceptions as e: + if i == len(wait): + raise RuntimeError('Retry failed with all attempts exhausted') from e + finally: + logger.debug( + f"Finished {func.__name__} after {i+1} attempts" + ) + return retry_driver + + if func is None: + return retry_decorator + else: + return retry_decorator(func) + + class K8sCluster(clusters.ClusterEnv): # Use an arbitrarily chosen port for the coordinator since we cannot @@ -34,16 +72,18 @@ def is_env_present(cls) -> bool: if 'KUBERNETES_SERVICE_HOST' in os.environ: try: import kubernetes as k8s # pytype: disable=import-error - except ImportError as e: - warnings.warn(textwrap.fill( - "Kubernetes environment detected, but the `kubernetes` package is " - "not installed to enable automatic bootstrapping in this " - "environment. To enable automatic boostrapping, please install " - "jax with the [k8s] extra. For example:" - " pip install jax[k8s]" - " OR" - " pip install jax[k8s,]" - )) + except (ImportError, ModuleNotFoundError): + warnings.warn( + '\n'.join([ + textwrap.fill( + "Kubernetes environment detected, but the `kubernetes` package " + "is not installed to enable automatic bootstrapping in this " + "environment. To enable automatic bootstrapping, please install " + "jax with the [k8s] extra. For example:"), + " pip install jax[k8s]", + " pip install jax[k8s,]", + ]) + ) return False k8s.config.load_incluster_config() @@ -67,7 +107,9 @@ def _handle_api_exception(cls): "this job does not have the permission for pod introspection. Please " "either grant the default SA permission to read pod info, or create a " "dedicated service account with the permission and associated with " - "the job. For more details, see .", + "the job. For an example on setting up the service account, see the " + "example/k8s directory in the JAX repo. For more details, please refer to " + "https://docs.jax.dev/en/latest/multi_process.html#kubernetes-example", width=80 )) raise RuntimeError('\n'.join(err_msg)) from e @@ -81,16 +123,16 @@ def _namespace(cls): @classmethod @cache + # in case of latency for core DNS to update pod IP to etcd/API server + @retry(exceptions=ValueError) def _pod(cls): + ip = socket.gethostbyname(os.getenv('HOSTNAME')) with cls._handle_api_exception(): - ip = socket.gethostbyname(os.getenv('HOSTNAME')) - pods = cls._core_api.list_namespaced_pod( + [pod] = cls._core_api.list_namespaced_pod( namespace=cls._namespace(), field_selector=f'status.podIP={ip}' ).items - assert len(pods) == 1, \ - f"Exactly 1 Kubernetes pod should have IP {ip}, got {len(pods)}." - return pods[0] + return pod @classmethod @cache @@ -101,13 +143,128 @@ def _job(cls): ) @classmethod - def get_coordinator_address(cls, timeout_secs: int | None) -> str: - return '{job_name}-0.{jobset_name}:{port}'.format( - job_name=cls._pod().metadata.labels['job-name'], - jobset_name=cls._job().metadata.labels['jobset.sigs.k8s.io/jobset-name'], - port=cls._coordinator_port + @cache + def _headless_svc(cls): + with cls._handle_api_exception(): + services = cls._core_api.list_namespaced_service(cls._namespace()).items + + pod_labels = cls._pod().metadata.labels or {} + for svc in services: + if svc.spec.cluster_ip == "None": # if headless service + svc_selector = svc.spec.selector or {} + if all(pod_labels.get(k) == v for k, v in svc_selector.items()): + return svc + + # returns None if no headless service targets the current pod + return None + + @classmethod + @cache + def _controller(cls): + # https://github.com/kubernetes/apimachinery/blob/7b4292b/pkg/apis/meta/v1/types.go#L235 + # states that there cannot be more than one managing controller. + for owner in cls._pod().metadata.owner_references: + if owner.controller is True: + return owner + + raise RuntimeError( + 'Cannot automatically initialize distributed workload: ' + f'pod {cls._pod().metadata.name} does not have a controller.' ) + @classmethod + def get_coordinator_address(cls, timeout_secs: int | None, override_coordinator_port: str | None) -> str: + controller = cls._controller() + job = cls._job() + pod = cls._pod() + if controller.kind == 'Job': + # if job belongs to a jobset + if 'jobset.sigs.k8s.io/jobset-name' in job.metadata.labels: + coordinator_hostname = '{job_name}-0.{subdomain}'.format( + job_name=job.metadata.name, + subdomain=job.metadata.labels['jobset.sigs.k8s.io/jobset-name'] + ) + # if job is standalone + else: + # check if the job is associated with a headless service, which is + # necessary for pods to communicate with each other + if pod.spec.subdomain is None: + # check if a headless service exists but not specified as subdomain + svc = cls._headless_svc() + err_msg = ( + "Pods within a job need a headless service in order to " + "communicate with each other. " + ) + if svc: + err_msg += ( + f"A headless service '{svc.metadata.name}' is found that " + "targets this job, but it is not specified as the job subdomain. " + "Please add the following to the job specification: " + ) + fix_msg = [ + "```", + "kind: Job", + "spec:", + " ...", + " template:", + " spec:", + f" subdomain: {svc.metadata.name}", + "```", + ] + else: + err_msg += "To fix, add the following to the job specification:" + fix_msg = [ + "```", + "apiVersion: v1", + "kind: Service", + "metadata:", + " name: jaxpods", + "spec:", + " publishNotReadyAddresses: true", + " clusterIP: None", + " selector:", + f" job-name: {job.metadata.name}", + "---", + "kind: Job", + "spec:", + " ...", + " template:", + " spec:", + " subdomain: jaxpods", + "```", + ] + + raise RuntimeError('\n'.join([textwrap.fill(err_msg)] + fix_msg)) + + coordinator_hostname = '{job_name}-0.{subdomain}'.format( + job_name=job.metadata.name, + subdomain=pod.spec.subdomain + ) + + if timeout_secs: + # Ensure host pod is up before trying to communicate + # Retry in case of cached NXDOMAIN DNS failure (30 secs default) + @retry( + initial_delay=0.5, + wait=np.logspace(-1, 1.5, 8) * np.random.rand(8), + exceptions=socket.gaierror + ) + def wait_for_host(hostname): + socket.gethostbyname(hostname) + + wait_for_host(coordinator_hostname) + + port = override_coordinator_port or cls._coordinator_port + return '{hostname}:{port}'.format( + hostname=coordinator_hostname, + port=port + ) + + else: + raise RuntimeError( + 'In K8s, cluster automatic bootstrap only supports Job/JobSet.' + ) + @classmethod def get_process_count(cls) -> int: # https://kubernetes.io/docs/concepts/workloads/controllers/job/#controlling-parallelism @@ -120,5 +277,6 @@ def get_process_id(cls) -> int: return int(os.environ['JOB_COMPLETION_INDEX']) except KeyError: raise RuntimeError( - 'K8s job must be run with `completionMode: "Indexed"`.' + 'To enable automatic bootstrap in a K8s cluster, ' + 'jobs must be indexed by setting `completionMode: "Indexed"`.' ) diff --git a/jax/_src/clusters/mpi4py_cluster.py b/jax/_src/clusters/mpi4py_cluster.py index fc37842e7683..2b63adc8391e 100644 --- a/jax/_src/clusters/mpi4py_cluster.py +++ b/jax/_src/clusters/mpi4py_cluster.py @@ -33,7 +33,7 @@ def is_env_present(cls) -> bool: return find_spec("mpi4py") is not None @classmethod - def get_coordinator_address(cls, timeout_secs: int | None) -> str: + def get_coordinator_address(cls, timeout_secs: int | None, override_coordinator_port: str | None) -> str: # Using mpi4py, figure out rank 0 and it's hostname. # Then broadcast the hostname and port. @@ -49,8 +49,11 @@ def get_coordinator_address(cls, timeout_secs: int | None) -> str: # Order all the hostnames, and find unique ones hostname = socket.gethostname() - # Apparently, we want to pick a port in an ephemeral range... - port_id = hash(hostname) % 2**12 + (65535 - 2**12 + 1) + if override_coordinator_port: + port_id = override_coordinator_port + else: + # Apparently, we want to pick a port in an ephemeral range... + port_id = str(hash(hostname) % 2**12 + (65535 - 2**12 + 1)) hostname = f'{hostname}:{port_id}' diff --git a/jax/_src/clusters/ompi_cluster.py b/jax/_src/clusters/ompi_cluster.py index 151968c1c2bc..ad0a97e1c32e 100644 --- a/jax/_src/clusters/ompi_cluster.py +++ b/jax/_src/clusters/ompi_cluster.py @@ -33,17 +33,20 @@ def is_env_present(cls) -> bool: return _ORTE_URI in os.environ @classmethod - def get_coordinator_address(cls, timeout_secs: int | None) -> str: + def get_coordinator_address(cls, timeout_secs: int | None, override_coordinator_port: str | None) -> str: # Examples of orte_uri: # 1531576320.0;tcp://10.96.0.1,10.148.0.1,10.108.0.1:34911 # 1314521088.0;tcp6://[fe80::b9b:ac5d:9cf0:b858,2620:10d:c083:150e::3000:2]:43370 orte_uri = os.environ[_ORTE_URI] - job_id_str = orte_uri.split('.', maxsplit=1)[0] - # The jobid is always a multiple of 2^12, let's divide it by 2^12 - # to reduce likelihood of port conflict between jobs - job_id = int(job_id_str) // 2**12 - # Pick port in ephemeral range [(65535 - 2^12 + 1), 65535] - port = job_id % 2**12 + (65535 - 2**12 + 1) + if override_coordinator_port: + port = override_coordinator_port + else: + job_id_str = orte_uri.split('.', maxsplit=1)[0] + # The jobid is always a multiple of 2^12, let's divide it by 2^12 + # to reduce likelihood of port conflict between jobs + job_id = int(job_id_str) // 2**12 + # Pick port in ephemeral range [(65535 - 2^12 + 1), 65535] + port = str(job_id % 2**12 + (65535 - 2**12 + 1)) launcher_ip_match = re.search(r"tcp://(.+?)[,:]|tcp6://\[(.+?)[,\]]", orte_uri) if launcher_ip_match is None: raise RuntimeError('Could not parse coordinator IP address from Open MPI environment.') diff --git a/jax/_src/clusters/slurm_cluster.py b/jax/_src/clusters/slurm_cluster.py index 8cec07601094..e05974023fc2 100644 --- a/jax/_src/clusters/slurm_cluster.py +++ b/jax/_src/clusters/slurm_cluster.py @@ -30,12 +30,16 @@ class SlurmCluster(clusters.ClusterEnv): @classmethod def is_env_present(cls) -> bool: - return _JOBID_PARAM in os.environ + return all(var in os.environ for var in + (_JOBID_PARAM, _NODE_LIST, _PROCESS_COUNT, _PROCESS_ID, _LOCAL_PROCESS_ID)) @classmethod - def get_coordinator_address(cls, timeout_secs: int | None) -> str: - # Pick port in ephemeral range [(65535 - 2^12 + 1), 65535] - port = int(os.environ[_JOBID_PARAM]) % 2**12 + (65535 - 2**12 + 1) + def get_coordinator_address(cls, timeout_secs: int | None, override_coordinator_port: str | None) -> str: + if override_coordinator_port: + port = override_coordinator_port + else: + # Pick port in ephemeral range [(65535 - 2^12 + 1), 65535] + port = str(int(os.environ[_JOBID_PARAM]) % 2**12 + (65535 - 2**12 + 1)) # Parse the first hostname of the job # If we are looking for 'node001', diff --git a/jax/_src/compilation_cache.py b/jax/_src/compilation_cache.py index f1b56adf3359..bb416fd1c633 100644 --- a/jax/_src/compilation_cache.py +++ b/jax/_src/compilation_cache.py @@ -23,8 +23,18 @@ # If zstandard is installed, we use zstd compression, otherwise we use zlib. try: - import zstandard + # compression.zstd should be present in Python 3.14+ + from compression import zstd # pytype: disable=import-error except ImportError: + zstd = None + +if zstd is None: + # TODO(phawkins): remove this case when we drop support for Python 3.13. + try: + import zstandard # pytype: disable=import-error + except ImportError: + zstandard = None +else: zstandard = None from jax._src import cache_key @@ -111,8 +121,6 @@ def initialize_cache(path) -> None: Set the path. To take effect, should be called prior to any calls to get_executable_and_time() and put_executable_and_time(). """ - warnings.warn("initialize_cache is deprecated; use set_cache_dir instead", - DeprecationWarning, stacklevel=2) config.config.update("jax_compilation_cache_dir", path) @@ -181,14 +189,18 @@ def _get_cache(backend) -> CacheInterface | None: def compress_executable(executable: bytes) -> bytes: - if zstandard: + if zstd: + return zstd.compress(executable) + elif zstandard: compressor = zstandard.ZstdCompressor() return compressor.compress(executable) else: return zlib.compress(executable) def decompress_executable(executable: bytes) -> bytes: - if zstandard: + if zstd: + return zstd.decompress(executable) + elif zstandard: decompressor = zstandard.ZstdDecompressor() return decompressor.decompress(executable) else: @@ -207,7 +219,7 @@ def is_executable_in_cache(backend, cache_key: str) -> bool: def get_executable_and_time( - cache_key: str, compile_options, backend + cache_key: str, compile_options, backend, executable_devices ) -> tuple[xla_client.LoadedExecutable | None, int | None]: """Returns the cached executable and its compilation time if present, or None otherwise. @@ -224,7 +236,7 @@ def get_executable_and_time( serialized_executable, compile_time = extract_executable_and_time( executable_and_time) xla_executable_deserialized = backend.deserialize_executable( - serialized_executable, compile_options) + serialized_executable, executable_devices, compile_options) return xla_executable_deserialized, compile_time @@ -249,7 +261,10 @@ def put_executable_and_time( " since cache is disabled/not initialized", cache_key) return - serialized_executable = backend.serialize_executable(executable) + if hasattr(executable, "serialize") or xla_client._version >= 389: + serialized_executable = executable.serialize() + else: + serialized_executable = backend.serialize_executable(executable) executable_and_time = combine_executable_and_time( serialized_executable, compile_time) executable_and_time = compress_executable(executable_and_time) @@ -275,7 +290,7 @@ def put_executable_and_time( f"PERSISTENT CACHE WRITE with key {cache_key}, this is unexpected because " "JAX_COMPILATION_CACHE_EXPECT_PGLE is set. The execution that populated the " "cache may lack coverage, " - "https://jax.readthedocs.io/en/latest/persistent_compilation_cache.html may " + "https://docs.jax.dev/en/latest/persistent_compilation_cache.html may " "help debug why this has happened") cache.put(cache_key, executable_and_time) @@ -306,8 +321,6 @@ def is_initialized() -> bool: initialized status is not checked. The name is retained for backwards compatibility. """ - warnings.warn("is_initialized is deprecated; do not use", - DeprecationWarning, stacklevel=2) return _is_cache_enabled() @@ -341,7 +354,7 @@ def combine_executable_and_time( def extract_executable_and_time( - exectuable_and_time: bytes + executable_and_time: bytes ) -> tuple[bytes, int]: """Given the cache entry in the format shown below, extract the serialized executable and the compilation time. @@ -351,5 +364,5 @@ def extract_executable_and_time( Content: compilation time serialized executable (big-endian int) """ - return exectuable_and_time[4:], int.from_bytes( - exectuable_and_time[:4], byteorder='big') + return executable_and_time[4:], int.from_bytes( + executable_and_time[:4], byteorder='big') diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index dea532d13031..1b604d2abb52 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -17,9 +17,12 @@ from __future__ import annotations from collections.abc import Sequence +import copy +from functools import partial import logging import time -from typing import Any, Callable +from typing import Any +from collections.abc import Callable import warnings from jax._src import cache_key as cache_key_type @@ -31,8 +34,10 @@ from jax._src import path as pathlib from jax._src import profiler from jax._src import traceback_util +from jax._src import util from jax._src.interpreters import mlir from jax._src.lib import xla_client as xc +from jax._src.lib import _jax from jax._src.lib.mlir import ir import numpy as np @@ -113,7 +118,6 @@ def get_compile_options( num_partitions: int, device_assignment=None, use_spmd_partitioning: bool = True, - use_shardy_partitioner: bool = False, use_auto_spmd_partitioning: bool = False, auto_spmd_partitioning_mesh_shape: list[int] | None = None, auto_spmd_partitioning_mesh_ids: list[int] | None = None, @@ -133,10 +137,6 @@ def get_compile_options( `num_partitions`. use_spmd_partitioning: boolean indicating whether to enable SPMD or MPMD partitioning in XLA. - use_shardy_partitioner: boolean indicating whether to use the Shardy - partitioner in XLA. Shardy is a new open sourced propagation framework for - MLIR. Currently Shardy is experimental in JAX. See - www.github.com/openxla/shardy. use_auto_spmd_partitioning: boolean indicating whether to automatically generate XLA shardings for SPMD partitioner. auto_spmd_partitioning_mesh_shape: device mesh shape used to create @@ -156,7 +156,7 @@ def get_compile_options( build_options = compile_options.executable_build_options build_options.use_spmd_partitioning = use_spmd_partitioning build_options.use_auto_spmd_partitioning = use_auto_spmd_partitioning - build_options.use_shardy_partitioner = use_shardy_partitioner + build_options.use_shardy_partitioner = config.use_shardy_partitioner.value if fdo_profile is not None: build_options.fdo_profile = fdo_profile if use_auto_spmd_partitioning: @@ -197,15 +197,6 @@ def get_compile_options( config.memory_fitting_level.value ).value - # This is a temporary workaround to simplify the AutoPGLE usage. - # TODO(b/376647494): Remove once the bug is fixed. - if ((config.enable_pgle.value and config.pgle_profiling_runs.value > 0) - or config.compilation_cache_expect_pgle.value): - logger.debug("Explicitly disabling command buffer scheduling for AutoPGLE.") - if env_options_overrides is None: - env_options_overrides = {} - env_options_overrides['xla_gpu_enable_command_buffer'] = '' - if env_options_overrides is not None: # Some overrides are passed directly on build_options. overrides_on_build_options = [ @@ -248,7 +239,7 @@ def get_compile_options( else: compile_options.profile_version = _NO_PROFILE_DONT_RETRIEVE if backend is None: - logging.info("get_compile_options: no backend supplied; " + logger.info("get_compile_options: no backend supplied; " "disabling XLA-AutoFDO profile") else: fdo_profile_version = get_latest_profile_version(backend) @@ -295,31 +286,85 @@ def get_compile_options( def backend_compile( backend: xc.Client, module: ir.Module, + executable_devices: xc.DeviceList, + options: xc.CompileOptions, +) -> xc.Executable: + sym_name = module.operation.attributes['sym_name'] + module_name = ir.StringAttr(sym_name).value + if (options.executable_build_options.fdo_profile is not None + and len(options.executable_build_options.fdo_profile)): + logger.debug( + "Compiling module %s with FDO profile of length %d", + module_name, + len(options.executable_build_options.fdo_profile), + ) + + try: + return backend.compile(module, executable_devices, options) + except _jax.JaxRuntimeError as e: + for error_handler in _XLA_RUNTIME_ERROR_HANDLERS: + handler_result = error_handler(e) + if handler_result is not None: + raise handler_result from e + raise e + + +@profiler.annotate_function +def backend_compile_and_load( + backend: xc.Client, + module: ir.Module, + executable_devices: xc.DeviceList, options: xc.CompileOptions, host_callbacks: Sequence[Any], ) -> xc.LoadedExecutable: - # Convert ir.Module to a string representation, unless the backend - # explicitly flags the ability to handle a module directly (avoiding the - # overhead of back and forth conversions). - # TODO(slebedev): Change the backend.compile() to accept ir.Module. - built_c: Any - if getattr(backend, "needs_str_ir", True): - built_c = mlir.module_to_bytecode(module) - else: - built_c = module + sym_name = module.operation.attributes['sym_name'] + module_name = ir.StringAttr(sym_name).value + + if (options.executable_build_options.fdo_profile is not None + and len(options.executable_build_options.fdo_profile)): + logger.debug( + "Compiling module %s with FDO profile of length %d", + module_name, + len(options.executable_build_options.fdo_profile), + ) try: # we use a separate function call to ensure that XLA compilation appears # separately in Python profiling results - if host_callbacks: - return backend.compile( - built_c, compile_options=options, host_callbacks=host_callbacks + # TODO(dsuo): Simplify this logic once we delete _jax.CompileOnlyPyClient. + if isinstance(backend, _jax.CompileOnlyPyClient): + if host_callbacks: + return backend.compile( + module, + executable_devices=executable_devices, # type: ignore + compile_options=options, + host_callbacks=host_callbacks, # type: ignore + ) + # Some backends don't have `host_callbacks` option yet + # TODO(sharadmv): remove this fallback when all backends allow `compile` + # to take in `host_callbacks` + return backend.compile( # type: ignore + module, + executable_devices=executable_devices, + compile_options=options, + ) + else: + if host_callbacks: + return backend.compile_and_load( + module, + executable_devices=executable_devices, + compile_options=options, + host_callbacks=host_callbacks, + ) + # Some backends don't have `host_callbacks` option yet + # TODO(sharadmv): remove this fallback when all backends allow `compile` + # to take in `host_callbacks` + return backend.compile_and_load( + module, + executable_devices=executable_devices, + compile_options=options, ) - # Some backends don't have `host_callbacks` option yet - # TODO(sharadmv): remove this fallback when all backends allow `compile` - # to take in `host_callbacks` - return backend.compile(built_c, compile_options=options) - except xc.XlaRuntimeError as e: + except _jax.JaxRuntimeError as e: for error_handler in _XLA_RUNTIME_ERROR_HANDLERS: handler_result = error_handler(e) if handler_result is not None: @@ -331,7 +376,7 @@ def backend_compile( def register_xla_runtime_error_handler( - handler_fn: Callable[[xc.XlaRuntimeError], Exception | None], + handler_fn: Callable[[_jax.JaxRuntimeError], Exception | None], ): """Registers a custom exception handler for XLA runtime errors. @@ -354,15 +399,14 @@ def compile_or_get_cached( devices: np.ndarray, compile_options: xc.CompileOptions, host_callbacks: Sequence[Any], + executable_devices: xc.DeviceList, pgle_profiler: profiler.PGLEProfiler | None = None, ) -> xc.LoadedExecutable: sym_name = computation.operation.attributes['sym_name'] module_name = ir.StringAttr(sym_name).value if dumped_to := mlir.dump_module_to_file(computation, "compile"): - logging.info("Dumped the module to %s.", dumped_to) - - use_compilation_cache = compilation_cache.is_cache_used(backend) + logger.info("Dumped the module to %s.", dumped_to) is_multi_process = ( len({device.process_index for device in devices.flatten()}) > 1 @@ -370,67 +414,29 @@ def compile_or_get_cached( min_device_process_id = min( devices.flatten(), key=lambda device: device.id ).process_index - is_auto_pgle_used = ( - config.enable_pgle.value and config.pgle_profiling_runs.value > 0 - ) - if not use_compilation_cache: - if ( - is_multi_process - and is_auto_pgle_used - and distributed.global_state.client is not None - ): - compile_options.executable_build_options.fdo_profile = ( - _share_fdo_profiles( - computation, - devices, - compile_options, - backend, - distributed.global_state.client, - min_device_process_id, - ) - ) + # cache_key: may be None if compilation caching is disabled + cache_key, compile_options = _resolve_compilation_strategy( + computation, + devices, + compile_options, + backend, + pgle_profiler, + is_multi_process, + module_name, + min_device_process_id, + ) - return backend_compile(backend, computation, compile_options, - host_callbacks) + if cache_key is None: + return backend_compile_and_load( + backend, computation, executable_devices, compile_options, + host_callbacks) monitoring.record_event('/jax/compilation_cache/compile_requests_use_cache') - try: - if config.remove_custom_partitioning_ptr_from_cache_key.value: - ignore_callbacks = cache_key_type.IgnoreCallbacks.CUSTOM_PARTITIONING - else: - ignore_callbacks = cache_key_type.IgnoreCallbacks.NO - - cache_key = compilation_cache.get_cache_key( - computation, - devices, - compile_options, - backend, - ignore_callbacks=ignore_callbacks, - ) - except xc._xla.XlaRuntimeError as ex: - logger.error("compile_or_get_cached: unable to generate cache key, " - "skipping the cache: %s", ex) - return backend_compile(backend, computation, compile_options, - host_callbacks) - - if is_auto_pgle_used or config.compilation_cache_expect_pgle.value: - cache_key = _resolve_pgle_module_cache_key( - computation, - devices, - compile_options, - backend, - pgle_profiler, - is_multi_process, - cache_key, - module_name, - min_device_process_id, - ) - cache_retrieval_start = time.monotonic() retrieved_executable, retrieved_compile_time = _cache_read( - module_name, cache_key, compile_options, backend) + module_name, cache_key, compile_options, backend, executable_devices) cache_retrieval_time = time.monotonic() - cache_retrieval_start if retrieved_executable is not None: @@ -446,11 +452,12 @@ def compile_or_get_cached( "/jax/compilation_cache/cache_retrieval_time_sec", cache_retrieval_time) return retrieved_executable - elif ( + util.test_event("compile_after_persistent_compilation_miss") + if ( config.share_binary_between_hosts.value and is_multi_process and distributed.global_state.client is not None - # Host callbacks are currently baked into the HLO module so we cant share + # Host callbacks are currently baked into the HLO module so we can't share # them. and len(host_callbacks) == 0 ): @@ -458,6 +465,7 @@ def compile_or_get_cached( return _compile_and_share_module( backend, computation, + executable_devices, compile_options, host_callbacks, distributed.global_state.client, @@ -470,6 +478,7 @@ def compile_or_get_cached( return _compile_and_write_cache( backend, computation, + executable_devices, compile_options, host_callbacks, module_name, @@ -481,85 +490,130 @@ def compile_or_get_cached( # 1. PGLE optimized module (the one which was recompiled with FDO profile) is # in the persistent cache. In this case the module should be returned from # cache and PGLE should be disabled for this module. Is module is stored in -# the persistent cache under the "pgle_profiled_module_key" which calculated -# with replacing FDO profile with flag which identify that module were PGLE -# profiled. +# the persistent cache under the "pgle_optimized_cache_key", which is +# calculated by replacing the FDO profile with a sentinel value that identifies +# that the module was optimized with PGLE. # 2. PGLE profiled module is not in the persistent cache and the module is -# getting built with an FDO profile. In this case we need to share FDO profile -# with other processes and store the result under the -# "pgle_profiled_module_key" so later in case 1 we will be able to find the +# getting built with an FDO profile. In this case we need to share the FDO +# profile with any other processes and store the result under the +# "pgle_optimized_cache_key" so later in case 1 we will be able to find the # module. # 3. PGLE profiled module is not in the persistent cache and the module is # getting compiled to be PGLEd (FDO profile is empty). In this case we need to -# simply return the non-PGLE profiled module from the persistent cache. +# simply return the non-PGLE profiled module from the persistent cache if it +# exists, and otherwise compile it. # # If the compilation_cache_expect_pgle option is set then in case 1 the PGLE # optimized module will be loaded even if PGLE is not enabled in the current # process. This is useful if we want to combine the use of PGLE with other # profiling tools (e.g. Nsight Systems) that cannot co-exist with PGLE due to # contention for CUPTI resources. -def _resolve_pgle_module_cache_key( +def _resolve_compilation_strategy( computation: ir.Module, devices: np.ndarray, compile_options: xc.CompileOptions, backend: xc.Client, pgle_profiler: profiler.PGLEProfiler | None, is_multi_process: bool, - cache_key: str, module_name: str, min_device_process_id: int, -) -> str: - fdo_profile = compile_options.executable_build_options.fdo_profile - compile_options.executable_build_options.fdo_profile = b"pgle profiled" - - pgle_profiled_module_key = compilation_cache.get_cache_key( - computation, - devices, - compile_options, - backend, - cache_key_type.IgnoreCallbacks.ALL, +) -> tuple[str | None, xc.CompileOptions]: + is_auto_pgle_used = ( + config.enable_pgle.value and config.pgle_profiling_runs.value > 0 ) - compile_options.executable_build_options.fdo_profile = fdo_profile - - result_key = cache_key - if _is_executable_in_cache(backend, pgle_profiled_module_key): - # Load PGLE profiled module from the persistent cache. - result_key = pgle_profiled_module_key - if config.compilation_cache_expect_pgle.value: - logging.info(f"PGLE-optimized {module_name} loaded from compilation cache") - if pgle_profiler is not None: - pgle_profiler.disable() + + get_cache_key = partial(_get_cache_key, backend=backend, + computation=computation, devices=devices) + + if is_auto_pgle_used or config.compilation_cache_expect_pgle.value: + # This can be None if cache key generation fails. + pgle_optimized_cache_key = get_cache_key(compile_options, + override_fdo_profile=b"pgle profiled") + # TODO(b/376647494): remove the workaround when the bug is fixed; the JAX + # profiler cannot collect sufficiently detailed profile data for PGLE if + # command buffers / CUDA graphs are enabled. Therefore disable command + # buffers when compiling for PGLE data collection, but not if AutoPGLE is + # not enabled, and not when re-compiling using PGLE data. This condition + # includes `compilation_cache_expect_pgle` so that slow-to-compile modules + # that are not executed often enough to trigger re-compilation will still + # be cached between an "enable_pgle" run and an "expect_pgle" run. + first_pass_compile_options = copy.deepcopy(compile_options) + first_pass_compile_options.env_option_overrides += [ + ("xla_gpu_enable_command_buffer", ""), + ] else: - # No PGLE-optimised module found in the persistent cache. - if (config.compilation_cache_expect_pgle.value - and _is_executable_in_cache(backend, cache_key)): - # The user asserted this miss was unexpected; emit a warning + pgle_optimized_cache_key = None + first_pass_compile_options = compile_options + + # This can be None if cache key generation fails or caching is disabled + cache_key = get_cache_key(first_pass_compile_options) + + if cache_key is not None and pgle_optimized_cache_key is not None: + # The compilation cache is enabled and AutoPGLE is enabled/expected + if _is_executable_in_cache(backend, pgle_optimized_cache_key): + if config.compilation_cache_expect_pgle.value: + logger.info(f"PGLE-optimized {module_name} loaded from compilation cache") + # No need to record N profiles in this case + if pgle_profiler is not None: + pgle_profiler.disable() + return pgle_optimized_cache_key, compile_options + elif (config.compilation_cache_expect_pgle.value + and _is_executable_in_cache(backend, cache_key)): + # No PGLE-optimized module found in the persistent cache, and the user + # asserted (expect_pgle) that this miss was unexpected warnings.warn(f"PERSISTENT CACHE MISS for PGLE-optimized {module_name} " "despite non-PGLE hit; it may not have been executed " "enough times when the cache was populated") - if fdo_profile is not None and len(fdo_profile) > 0: - # Store module under PGLE profiled module cache key. - result_key = pgle_profiled_module_key - if is_multi_process and distributed.global_state.client is not None: - compile_options.executable_build_options.fdo_profile = ( - _share_fdo_profiles( - computation, - devices, - compile_options, - backend, - distributed.global_state.client, - min_device_process_id, - ) - ) - else: - compile_options.executable_build_options.fdo_profile = fdo_profile - logger.debug( - "Compiling module %s with FDO profile of length %d", - module_name, - len(compile_options.executable_build_options.fdo_profile), + + if (is_auto_pgle_used + and compile_options.executable_build_options.fdo_profile is not None + and len(compile_options.executable_build_options.fdo_profile)): + # Profile data are available to trigger a PGLE-optimized recompilation; + # store under `pgle_optimized_cache_key` if the cache is enabled + if is_multi_process and distributed.global_state.client is not None: + compile_options.executable_build_options.fdo_profile = ( + _share_fdo_profiles( + computation, + devices, + compile_options, + backend, + distributed.global_state.client, + min_device_process_id, ) - return result_key + ) + return pgle_optimized_cache_key, compile_options + else: + # Compile for PGLE collection, store under `cache_key` if the cache is + # enabled. This is also the AutoPGLE-disabled path. + return cache_key, first_pass_compile_options +def _get_cache_key( + options: xc.CompileOptions, + backend: xc.Client, + computation: ir.Module, + devices: np.ndarray, + override_fdo_profile: bytes | None = None) -> str | None: + if not compilation_cache.is_cache_used(backend): + return None + if config.remove_custom_partitioning_ptr_from_cache_key.value: + ignore_callbacks = cache_key_type.IgnoreCallbacks.CUSTOM_PARTITIONING + else: + ignore_callbacks = cache_key_type.IgnoreCallbacks.NO + if override_fdo_profile is not None: + options = copy.deepcopy(options) + options.executable_build_options.fdo_profile = override_fdo_profile + try: + return compilation_cache.get_cache_key( + computation, + devices, + options, + backend, + ignore_callbacks, + ) + except _jax.JaxRuntimeError as ex: + logger.error("compile_or_get_cached: unable to generate cache key, " + "skipping the cache: %s", ex) + return None # The process that has the lowest device ID should share FDO profile before # compilation with other processes. @@ -568,13 +622,13 @@ def _share_fdo_profiles( devices: np.ndarray, compile_options: xc.CompileOptions, backend: xc.Client, - global_client: lib.xla_extension.DistributedRuntimeClient, + global_client: lib._jax.DistributedRuntimeClient, min_process_id -) -> bytes | None: +) -> bytes: sym_name = computation.operation.attributes['sym_name'] module_name = ir.StringAttr(sym_name).value fdo_profile = compile_options.executable_build_options.fdo_profile - if fdo_profile is None or len(fdo_profile) == 0: + if len(fdo_profile) == 0: return fdo_profile compile_options.executable_build_options.fdo_profile = b"" @@ -589,7 +643,7 @@ def _share_fdo_profiles( ) + "_fdo_sync" ) - except xc._xla.XlaRuntimeError as ex: + except _jax.JaxRuntimeError as ex: logger.error( "compile_or_get_cached: unable to generate cache key, " "skipping the fdo profile sharing: %s", @@ -624,14 +678,16 @@ def _share_fdo_profiles( _share_fdo_profiles.modules_profiles = {} + # The process with the first_process_id should compile the module and write it # to the K-V storage. def _compile_and_share_module( backend: xc.Client, computation: ir.Module, + executable_devices: xc.DeviceList, compile_options: xc.CompileOptions, host_callbacks: Sequence[Any], - global_client: lib.xla_extension.DistributedRuntimeClient, + global_client: lib._jax.DistributedRuntimeClient, module_name: str, cache_key: str, first_process_id: int @@ -647,6 +703,7 @@ def _compile_and_share_module( executable = _compile_and_write_cache( backend, computation, + executable_devices, compile_options, host_callbacks, module_name, @@ -667,25 +724,27 @@ def _compile_and_share_module( serialized_executable ) executable = backend.deserialize_executable( - serialized_executable, compile_options - ) + serialized_executable, executable_devices, compile_options) # type: ignore _compile_and_share_module.modules_cache[cache_key] = executable return executable + _compile_and_share_module.modules_cache = {} + def _compile_and_write_cache( backend: xc.Client, computation: ir.Module, + executable_devices: xc.DeviceList, compile_options: xc.CompileOptions, host_callbacks: Sequence[Any], module_name: str, cache_key: str, ) -> xc.LoadedExecutable: start_time = time.monotonic() - executable = backend_compile( - backend, computation, compile_options, host_callbacks + executable = backend_compile_and_load( + backend, computation, executable_devices, compile_options, host_callbacks ) compile_time = time.monotonic() - start_time _cache_write( @@ -693,6 +752,7 @@ def _compile_and_write_cache( ) return executable + def _is_executable_in_cache(backend, cache_key) -> bool: """Checks if executable is presented in cache on a given key """ @@ -709,14 +769,14 @@ def _is_executable_in_cache(backend, cache_key) -> bool: def _cache_read( module_name: str, cache_key: str, compile_options: xc.CompileOptions, - backend: xc.Client + backend: xc.Client, executable_devices: xc.DeviceList, ) -> tuple[xc.LoadedExecutable | None, int | None]: """Looks up the `computation` and it's compilation time in the persistent compilation cache repository. """ try: return compilation_cache.get_executable_and_time( - cache_key, compile_options, backend) + cache_key, compile_options, backend, executable_devices) except Exception as ex: if config.raise_persistent_cache_errors.value: raise diff --git a/jax/_src/compute_on.py b/jax/_src/compute_on.py index e7130c970cd4..447a6bd503d9 100644 --- a/jax/_src/compute_on.py +++ b/jax/_src/compute_on.py @@ -14,10 +14,25 @@ from __future__ import annotations from contextlib import contextmanager +from functools import partial +from typing import Sequence + from jax._src import config from jax._src.lib import xla_client +from jax._src import dispatch +from jax._src import core +from jax._src import linear_util as lu +from jax._src.interpreters import ad, batching, mlir, partial_eval as pe +from jax._src.tree_util import tree_flatten, tree_unflatten +from jax._src.util import (safe_map, safe_zip, weakref_lru_cache, unzip2, + split_list) +from jax._src.api_util import debug_info, flatten_fun_nokwargs, flatten_axes +from jax._src.lib.mlir.dialects import func as func_dialect +from jax._src.lib.mlir import ir config_ext = xla_client._xla.config +map, unsafe_map = safe_map, map +zip, unsafe_zip = safe_zip, zip @contextmanager @@ -36,8 +51,6 @@ def extend_compute_type(c_type: str | None): finally: config.compute_on_context_manager.set_local(prev) -def current_compute_type() -> str | None: - return config.compute_on_context_manager.value def _check_valid(c_type: str): if (c_type not in {'device_host', 'device', 'tpu_sparsecore'} @@ -54,3 +67,194 @@ def compute_on(compute_type: str): with extend_compute_type(compute_type): yield + +def compute_on2(f=None, *, compute_type, out_memory_spaces): + kwargs = dict(compute_type=compute_type, out_memory_spaces=out_memory_spaces) + if f is None: + return lambda g: _compute_on2(g, **kwargs) + return _compute_on2(f, **kwargs) + +def _compute_on2(f, *, compute_type, out_memory_spaces): + def wrapped(*args): + dbg = debug_info('compute_on', f, args, {}) + args_flat, in_tree = tree_flatten(args) + in_avals = tuple(core.shaped_abstractify(x) for x in args_flat) + jaxpr, out_tree = _trace_to_jaxpr(f, in_avals, in_tree, dbg) + out_memory_spaces_flat = flatten_axes( + "compute_on out_memory_spaces", out_tree, out_memory_spaces) + outs_flat = compute_on_p.bind( + *args_flat, jaxpr=jaxpr, compute_type=compute_type, + out_memory_spaces=tuple(out_memory_spaces_flat)) + return tree_unflatten(out_tree, outs_flat) + return wrapped + +@weakref_lru_cache +def _trace_to_jaxpr(fun, in_avals, in_tree, dbg): + f = lu.wrap_init(fun, debug_info=dbg) + f, out_tree = flatten_fun_nokwargs(f, in_tree) + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(f, in_avals) + return core.ClosedJaxpr(jaxpr, consts), out_tree() + +compute_on_p = core.Primitive('compute_on') +compute_on_p.multiple_results = True +dispatch.simple_impl(compute_on_p) + + +def _compute_on_abstract_eval(*in_avals, jaxpr, compute_type, out_memory_spaces): + return [a.update(memory_space=s) + for a, s in zip(jaxpr.out_avals, out_memory_spaces)] +compute_on_p.def_abstract_eval(_compute_on_abstract_eval) + + +def _compute_on_lowering(ctx, *args, jaxpr, compute_type, out_memory_spaces): + const_args_and_avals = core.jaxpr_const_args(jaxpr.jaxpr) + const_args, const_avals = unzip2(const_args_and_avals) + const_arg_values = [ + mlir.ir_constant(c, const_lowering=ctx.const_lowering, aval=aval) + for c, aval in const_args_and_avals] + in_avals = (*const_avals, *ctx.avals_in) + func_op, output_types, effects = mlir.lower_called_computation( + "compute_on", jaxpr, ctx.module_context, len(const_args), in_avals, + ctx.avals_out, ctx.tokens_in) + + symbol_name = func_op.name.value + flat_output_types = mlir.flatten_ir_types(output_types) + tokens = [ctx.tokens_in.get(eff) for eff in effects] + args = (*ctx.dim_var_values, *tokens, *const_arg_values, *args) + call = func_dialect.CallOp( + flat_output_types, ir.FlatSymbolRefAttr.get(symbol_name), + mlir.flatten_ir_values(args)) + + if compute_type.startswith("gpu_stream:"): + dict_attr = { + "_xla_stream_annotation": ir.StringAttr.get(compute_type.split(":")[1]), + "inlineable": ir.StringAttr.get("false"), + } + else: + dict_attr = { + "_xla_compute_type": ir.StringAttr.get(mlir.map_compute_type(compute_type)) + } + call.operation.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(dict_attr) + + out_nodes = mlir.unflatten_ir_values_like_types(call.results, output_types) + tokens, out_nodes = split_list(out_nodes, [len(effects)]) + tokens_out = ctx.tokens_in.update_tokens(mlir.TokenSet(zip(effects, tokens))) + ctx.set_tokens_out(tokens_out) + return [mlir.wrap_with_memory_kind(on, core.mem_space_to_kind(oms), out_aval) + for on, out_aval, oms in zip(out_nodes, ctx.avals_out, out_memory_spaces)] + +mlir.register_lowering(compute_on_p, _compute_on_lowering) + + +def _compute_on_batcher(axis_data, vals_in, dims_in, *, jaxpr, compute_type, + out_memory_spaces): + batched_jaxpr, dims_out = batching.batch_jaxpr2(jaxpr, axis_data, dims_in) + outs = compute_on_p.bind(*vals_in, jaxpr=batched_jaxpr, + compute_type=compute_type, + out_memory_spaces=out_memory_spaces) + return outs, dims_out +batching.fancy_primitive_batchers[compute_on_p] = _compute_on_batcher + + +def _compute_on_jvp(primals, tangents, *, jaxpr, compute_type, + out_memory_spaces): + nzs = [not isinstance(t, ad.Zero) for t in tangents] + jaxpr_jvp, out_nzs = ad.jvp_jaxpr(jaxpr, nzs, False) + nz_tangents = [t for t in tangents if not isinstance(t, ad.Zero)] + spaces_jvp = (*out_memory_spaces, + *[s for s, nz in zip(out_memory_spaces, out_nzs) if nz]) + outs = compute_on_p.bind(*primals, *nz_tangents, jaxpr=jaxpr_jvp, + compute_type=compute_type, + out_memory_spaces=spaces_jvp) + primals_out, nz_tangents_out = outs[:len(out_nzs)], outs[len(out_nzs):] + nz_outs = iter(nz_tangents_out) + tangents_out = [next(nz_outs) if nz else ad.Zero(aval.to_tangent_aval()) + for aval, nz in zip(jaxpr.out_avals, out_nzs)] + assert next(nz_outs, None) is None + return primals_out, tangents_out +ad.primitive_jvps[compute_on_p] = _compute_on_jvp + + +def _compute_on_lin(nzs, *primals, jaxpr, compute_type, out_memory_spaces): + jaxpr_jvp, out_nzs = ad.jvp_jaxpr(jaxpr, nzs, False) + lin_outs = [False] * len(out_nzs) + [True] * sum(out_nzs) + jaxpr_lin_, used_inputs = pe.dce_jaxpr(jaxpr_jvp.jaxpr, lin_outs, False) + jaxpr_lin = pe.close_jaxpr(jaxpr_lin_) + spaces_lin = tuple(s for s, nz in zip(out_memory_spaces, out_nzs) if nz) + primals_out = compute_on_p.bind(*primals, jaxpr=jaxpr, + compute_type=compute_type, + out_memory_spaces=out_memory_spaces) + tangent_avals_out = [a.to_tangent_aval() for a in jaxpr.out_avals] + + def compute_on_lin(primals, *tangents): + nz_tangents = [t for t in tangents if not isinstance(t, ad.Zero)] + inputs = [x for x, u in zip([*primals, *nz_tangents], used_inputs) if u] + nz_outs = compute_on_p.bind(*inputs, jaxpr=jaxpr_lin, + compute_type=compute_type, + out_memory_spaces=spaces_lin) + nz_outs_ = iter(nz_outs) + outs = [next(nz_outs_) if nz else ad.Zero(a) + for nz, a in zip(out_nzs, tangent_avals_out)] + assert next(nz_outs_, None) is None + return outs + return primals_out, out_nzs, primals, compute_on_lin +ad.primitive_linearizations[compute_on_p] = _compute_on_lin + +def _compute_on_partial_eval_custom_params_updater( + unks_in: Sequence[bool], inst_in: Sequence[bool], + kept_outs_known: Sequence[bool], kept_outs_staged: Sequence[bool], + num_res_out: int, num_res_in: int, params_known, params_staged): + # prune inputs to jaxpr_known according to unks_in + _, out_memory_spaces_known = pe.partition_list( + kept_outs_known, params_known['out_memory_spaces']) + new_params_known = dict( + params_known, + out_memory_spaces=(*out_memory_spaces_known, + *[core.MemorySpace.Device] * num_res_out), + ) + assert (len(new_params_known['out_memory_spaces']) == + len(params_known['jaxpr'].out_avals)) + + # added num_res new inputs to jaxpr_staged, and pruning according to inst_in + _, out_memory_spaces_staged = pe.partition_list( + kept_outs_staged, params_staged['out_memory_spaces']) + new_params_staged = dict( + params_staged, + out_memory_spaces=tuple(out_memory_spaces_staged), + ) + assert (len(new_params_staged['out_memory_spaces']) == + len(params_staged['jaxpr'].out_avals)) + return new_params_known, new_params_staged + +pe.partial_eval_jaxpr_custom_rules[compute_on_p] = \ + partial(pe.closed_call_partial_eval_custom_rule, 'jaxpr', + _compute_on_partial_eval_custom_params_updater) + +@weakref_lru_cache +def _transpose_jaxpr(jaxpr, in_avals, in_tree): + cell = lambda: None + def transposed(*in_flat): + primals_in, cts_in = tree_unflatten(in_tree, in_flat) + out = ad.backward_pass(jaxpr.jaxpr, False, jaxpr.consts, primals_in, cts_in) + out = [ct if not isinstance(ct, ad.Zero) else None for ct in out] + cts_out, cell.out_tree = tree_flatten(out) # type: ignore + return cts_out + dbg = jaxpr.jaxpr.debug_info.with_unknown_names() + trans_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(transposed, debug_info=dbg), in_avals) + return core.ClosedJaxpr(trans_jaxpr, consts), cell.out_tree # type: ignore + +def _compute_on_transpose(cts_in, *primals_in, jaxpr, compute_type, + out_memory_spaces): + in_flat, in_tree = tree_flatten((primals_in, cts_in)) + in_avals = tuple(core.typeof(x) for x in in_flat) + trans_jaxpr, out_tree = _transpose_jaxpr(jaxpr, in_avals, in_tree) + in_spaces = [x.aval.memory_space if isinstance(x, ad.UndefinedPrimal) + else core.typeof(x).memory_space for x in primals_in] + cts_out_ = tree_unflatten(out_tree, trans_jaxpr.out_avals) + trans_spaces = tuple(s for x, s in zip(cts_out_, in_spaces) if x) + cts_out = compute_on_p.bind(*in_flat, jaxpr=trans_jaxpr, + compute_type=compute_type, + out_memory_spaces=trans_spaces) + return tree_unflatten(out_tree, cts_out) +ad.primitive_transposes[compute_on_p] = _compute_on_transpose diff --git a/jax/_src/config.py b/jax/_src/config.py index cf6a07834a10..95433bec05eb 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -22,19 +22,23 @@ import logging import os import sys -from typing import Any, Generic, NoReturn, Optional, Protocol, TypeVar, cast +from typing import Any, Generic, NoReturn, Optional, Protocol, Type, TypeVar, cast +import warnings +from jax._src import deprecations +from jax._src import logging_config +from jax._src.lib import _jax from jax._src.lib import guard_lib from jax._src.lib import jax_jit +from jax._src.lib import jaxlib_extension_version from jax._src.lib import xla_client -from jax._src import logging_config config_ext = xla_client._xla.config logger = logging.getLogger(__name__) _T = TypeVar('_T') - +_ET = TypeVar('_ET', bound=enum.Enum) class EffortLevel(enum.Enum): """Effort level enum, mirroring the XLA effort options.""" @@ -161,11 +165,14 @@ def config_with_absl(self): self.use_absl = True self.absl_flags = absl_flags - absl_defs = { bool: absl_flags.DEFINE_bool, - int: absl_flags.DEFINE_integer, - float: absl_flags.DEFINE_float, - str: absl_flags.DEFINE_string, - 'enum': absl_flags.DEFINE_enum } + absl_defs = { + bool: absl_flags.DEFINE_bool, + int: absl_flags.DEFINE_integer, + float: absl_flags.DEFINE_float, + str: absl_flags.DEFINE_string, + 'enum': absl_flags.DEFINE_enum, + 'enum_class': absl_flags.DEFINE_enum_class, + } for name, (flag_type, meta_args, meta_kwargs) in self.meta.items(): holder = self._value_holders[name] @@ -214,38 +221,9 @@ def parse_flags_with_absl(self): self.complete_absl_config(absl.flags) already_configured_with_absl = True +register_trace_context_callback = [] # type: ignore -def trace_context(): - """Returns a tuple of configuration values that affect tracing. - - These values are included in the cache key for linear_util.cache. - - Values included in this set should also most likely be included in - the C++ JIT state, which is handled separately. - """ - return (axis_env_state.value, mesh_context_manager.value, - xla_metadata_context_manager.value, - abstract_mesh_context_manager.value, - compute_on_context_manager.value, enable_x64.value, - numpy_rank_promotion.value, default_matmul_precision.value, - dynamic_shapes.value, - eager_constant_folding.value, - numpy_dtype_promotion.value, - default_device.value, random_seed_offset.value, - threefry_partitionable.value, - threefry_gpu_kernel_lowering.value, - use_direct_linearize.value, - varying_axes_in_types.value, - softmax_custom_jvp.value, - disable_jit.value, - debug_key_reuse.value, - jax_xla_profile_version.value, - # Technically this affects jaxpr->stablehlo lowering, not tracing. - hlo_source_file_canonicalization_regex.value, - pgle_profiling_runs.value, - enable_pgle.value, - use_shardy_partitioner.value, - use_high_dynamic_range_gumbel.value) +trace_context = config_ext.trace_context config = Config() @@ -263,7 +241,7 @@ class State(config_ext.Config[_T]): __slots__ = ( '_name', '_update_thread_local_hook', '_update_global_hook', - '_validator', '_default_context_manager_value', '__doc__', '__name__', + '_parser', '_default_context_manager_value', '__doc__', '__name__', ) def __init__( @@ -273,22 +251,24 @@ def __init__( help, update_global_hook: Callable[[_T], None] | None = None, update_thread_local_hook: Callable[[_T | None], None] | None = None, - validator: Callable[[Any], None] | None = None, + parser: Callable[[Any], Any] | None = None, extra_description: str = '', default_context_manager_value: Any = no_default, include_in_jit_key: bool = False, + include_in_trace_context: bool = False, ): - super().__init__(default, include_in_jit_key) + if parser is not None: + default = parser(default) + super().__init__(name, default, include_in_jit_key=include_in_jit_key, + include_in_trace_context=include_in_trace_context) self._name = name self.__name__ = name[4:] if name.startswith('jax_') else name self.__doc__ = (f"Context manager for `{name}` config option" f"{extra_description}.\n\n{help}") self._update_global_hook = update_global_hook self._update_thread_local_hook = update_thread_local_hook - self._validator = validator + self._parser = parser self._default_context_manager_value = default_context_manager_value - if self._validator: - self._validator(default) if self._update_global_hook: self._update_global_hook(default) config_states[name] = self @@ -300,8 +280,8 @@ def __bool__(self) -> NoReturn: type(self).__name__)) def _set(self, value: _T) -> None: - if self._validator: - self._validator(value) + if self._parser: + value = self._parser(value) self.set_global(value) if self._update_global_hook: self._update_global_hook(value) @@ -323,7 +303,6 @@ class StateContextManager(contextlib.ContextDecorator): def __init__(self, state, new_val): self.state = state - self.new_val = new_val if new_val is no_default: if state._default_context_manager_value is not no_default: @@ -334,8 +313,10 @@ def __init__(self, state, new_val): raise TypeError(f"Context manager for {state.__name__} config option " "requires an argument representing the new value for " "the config option.") - if state._validator: - state._validator(new_val) + if state._parser: + self.new_val = state._parser(new_val) + else: + self.new_val = new_val def __enter__(self): @@ -356,7 +337,7 @@ def __exit__(self, exc_type, exc_value, traceback): " This will be enabled by default in future versions of JAX, at which " "point all uses of the flag will be considered deprecated (following " "the `API compatibility policy " - "`_).") + "`_).") UPGRADE_BOOL_EXTRA_DESC = " (transient)" @@ -371,6 +352,8 @@ def bool_state( upgrade: bool = False, extra_description: str = '', include_in_jit_key: bool = False, + include_in_trace_context: bool = False, + validator: Callable[[str], None] | None = None, ) -> State[bool]: """Set up thread-local state and return a contextmanager for managing it. @@ -398,6 +381,11 @@ def bool_state( for the outgoing functionality to be deprecated. extra_description: string, optional: extra information to add to the summary description. + include_in_jit_key: bool, optional: whether to include the state in the + JIT cache key. + include_in_trace_context: bool, optional: whether to include the state in + the trace context. + validator: optional function to validate the value of the config option. Returns: A contextmanager to control the thread-local state value. @@ -432,11 +420,17 @@ def bool_state( extra_description += UPGRADE_BOOL_EXTRA_DESC config._contextmanager_flags.add(name) + def parser(val): + if validator: + validator(val) + return bool(val) + s = State[bool]( name, default, help, update_global_hook=update_global_hook, update_thread_local_hook=update_thread_local_hook, extra_description=extra_description, default_context_manager_value=True, - include_in_jit_key=include_in_jit_key) + parser=parser, include_in_jit_key=include_in_jit_key, + include_in_trace_context=include_in_trace_context) config.add_option(name, s, bool, meta_args=[], meta_kwargs={"help": help}) setattr(Config, name, property(lambda _: s.value)) return s @@ -451,6 +445,8 @@ def enum_state( update_global_hook: Callable[[str], None] | None = None, update_thread_local_hook: Callable[[str | None], None] | None = None, include_in_jit_key: bool = False, + include_in_trace_context: bool = False, + extra_validator: Callable[[str], None] | None = None, ) -> State[str]: """Set up thread-local state and return a contextmanager for managing it. @@ -465,6 +461,10 @@ def enum_state( default: string, default value. help: string, used to populate the flag help information as well as the docstring of the returned context manager. + include_in_jit_key: bool, optional: whether to include the state in the + JIT cache key. + extra_validator: optional function to validate the value of the config + option. Returns: A contextmanager to control the thread-local state value. @@ -478,10 +478,13 @@ def enum_state( raise ValueError(f"Invalid value \"{default}\" for JAX flag {name}") config._contextmanager_flags.add(name) - def validator(new_val): + def parser(new_val): if type(new_val) is not str or new_val not in enum_values: raise ValueError(f"new enum value must be in {enum_values}, " f"got {new_val} of type {type(new_val)}.") + if extra_validator is not None: + extra_validator(new_val) + return new_val s = State[str]( name, @@ -489,8 +492,9 @@ def validator(new_val): help, update_global_hook=update_global_hook, update_thread_local_hook=update_thread_local_hook, - validator=validator, + parser=parser, include_in_jit_key=include_in_jit_key, + include_in_trace_context=include_in_trace_context, ) config.add_option( name, s, 'enum', @@ -510,6 +514,7 @@ def optional_enum_state( update_global_hook: Callable[[str | None], None] | None = None, update_thread_local_hook: Callable[[str | None], None] | None = None, include_in_jit_key: bool = False, + include_in_trace_context: bool = False, ) -> State[str | None]: """Set up thread-local state and return a contextmanager for managing it. @@ -537,15 +542,17 @@ def optional_enum_state( raise ValueError(f"Invalid value \"{default}\" for JAX flag {name}") config._contextmanager_flags.add(name) - def validate(new_val): + def parser(new_val): if (new_val is not None and (type(new_val) is not str or new_val not in enum_values)): raise ValueError(f"new enum value must be None or in {enum_values}, " f"got {new_val} of type {type(new_val)}.") + return new_val s = State['str | None']( name, default, help, update_global_hook, update_thread_local_hook, - validate, include_in_jit_key=include_in_jit_key, + parser, include_in_jit_key=include_in_jit_key, + include_in_trace_context=include_in_trace_context, ) config.add_option( name, s, 'enum', @@ -556,6 +563,85 @@ def validate(new_val): return s +def enum_class_state( + name: str, + enum_class: Type[_ET], + default: _ET, + help: str, + *, + update_global_hook: Callable[[_ET], None] | None = None, + update_thread_local_hook: Callable[[_ET | None], None] | None = None, + include_in_jit_key: bool = False, + include_in_trace_context: bool = False, + extra_validator: Callable[[_ET], None] | None = None, +) -> State[_ET]: + """Set up thread-local state and return a contextmanager for managing it. + + See docstring for ``bool_state``. + + Args: + name: string, converted to lowercase to define the name of the config + option (and absl flag). It is converted to uppercase to define the + corresponding shell environment variable. + enum_class: a subtype of enum.Enum. + default: an instance of enum_class that is the default value. + help: string, used to populate the flag help information as well as the + docstring of the returned context manager. + include_in_jit_key: bool, optional: whether to include the state in the + JIT cache key. + include_in_trace_context: bool, optional: whether to include the state in + the trace context. + extra_validator: optional function to validate the value of the config + option. + + Returns: + A contextmanager to control the thread-local state value. + """ + if not isinstance(default, enum_class): + raise TypeError( + f'Default value must be of type {enum_class}, got {default} ' + f"of type {getattr(type(default), '__name__', type(default))}" + ) + name = name.lower() + default_str = os.getenv(name.upper(), None) + if default_str is not None: + try: + default = enum_class(default_str) + except ValueError as e: + raise ValueError(f"Invalid value \"{default_str}\" for JAX flag {name}") from e + config._contextmanager_flags.add(name) + + def parser(new_val): + if isinstance(new_val, str): + return enum_class(new_val) + if not isinstance(new_val, enum_class): + raise TypeError( + f'new enum value must be an instance of {enum_class}, got' + f' {new_val} of type {type(new_val)}.' + ) + if extra_validator is not None: + extra_validator(new_val) + return new_val + + s = State[_ET]( + name, + default, + help, + update_global_hook=update_global_hook, + update_thread_local_hook=update_thread_local_hook, + parser=parser, + include_in_jit_key=include_in_jit_key, + include_in_trace_context=include_in_trace_context, + ) + config.add_option( + name, s, 'enum_class', + meta_args=[], + meta_kwargs={"enum_class": enum_class, "help": help} + ) + setattr(Config, name, property(lambda _: s.value)) + return s + + def int_state( name: str, default: int, @@ -564,6 +650,7 @@ def int_state( update_global_hook: Callable[[int], None] | None = None, update_thread_local_hook: Callable[[int | None], None] | None = None, include_in_jit_key: bool = False, + include_in_trace_context: bool = False, validator: Callable[[Any], None] | None = None, ) -> State[int]: """Set up thread-local state and return a contextmanager for managing it. @@ -593,16 +680,18 @@ def int_state( raise ValueError(f"Invalid value \"{default_env}\" for JAX flag {name}") config._contextmanager_flags.add(name) - def validate(new_val): + def parser(new_val): if new_val is not None and not isinstance(new_val, int): raise ValueError(f'new int config value must be None or of type int, ' f'got {new_val} of type {type(new_val)}') if new_val is not None and validator is not None: validator(new_val) + return new_val s = State[int](name, default, help, update_global_hook, - update_thread_local_hook, validate, - include_in_jit_key=include_in_jit_key) + update_thread_local_hook, parser, + include_in_jit_key=include_in_jit_key, + include_in_trace_context=include_in_trace_context) config.add_option(name, s, int, meta_args=[], meta_kwargs={"help": help}) setattr(Config, name, property(lambda _: s.value)) return s @@ -643,14 +732,15 @@ def float_state( raise ValueError(f"Invalid value \"{default_env}\" for JAX flag {name}") config._contextmanager_flags.add(name) - def validate(new_val): + def parser(new_val): if new_val is not None and not isinstance(new_val, (float, int)): raise ValueError( f'new float config value must be None or of type float, ' f'got {new_val} of type {type(new_val)}') + return new_val s = State[float](name, default, help, update_global_hook, - update_thread_local_hook, validate) + update_thread_local_hook, parser) config.add_option(name, s, float, meta_args=[], meta_kwargs={"help": help}) setattr(Config, name, property(lambda _: s.value)) return s @@ -707,6 +797,7 @@ def optional_string_state( *, update_global_hook: Callable[[str], None] | None = None, update_thread_local_hook: Callable[[str | None], None] | None = None, + include_in_trace_context: bool = False, ) -> State[str | None]: """Set up thread-local state and return a contextmanager for managing it. @@ -741,7 +832,8 @@ def validator(new_val): name, default, help, update_global_hook=update_global_hook, update_thread_local_hook=update_thread_local_hook, - validator=validator) + validator=validator, + include_in_trace_context=include_in_trace_context) def string_or_object_state( name: str, @@ -751,6 +843,8 @@ def string_or_object_state( update_global_hook: Callable[[Any], None] | None = None, update_thread_local_hook: Callable[[Any], None] | None = None, validator: Callable[[Any], None] | None = None, + include_in_jit_key: bool = False, + include_in_trace_context: bool = False, ) -> State[Any]: """Set up thread-local state and return a contextmanager for managing it. @@ -781,9 +875,15 @@ def string_or_object_state( default = os.getenv(name.upper(), default) config._contextmanager_flags.add(name) + def parser(new_val): + if validator is not None: + validator(new_val) + return new_val + s = State[Any]( name, default, help, update_global_hook, update_thread_local_hook, - validator) + parser, include_in_jit_key=include_in_jit_key, + include_in_trace_context=include_in_trace_context) setattr(Config, name, property(lambda _: s.value)) config.add_option(name, s, str, meta_args=[], meta_kwargs={"help": help}) return s @@ -853,13 +953,98 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: already_configured_with_absl = False -trace_state = config_ext.Config(None, include_in_jit_key=True) -axis_env_state = config_ext.Config((), include_in_jit_key=True) -mesh_context_manager = config_ext.Config((), include_in_jit_key=True) -abstract_mesh_context_manager = config_ext.Config(None, include_in_jit_key=True) -device_context = config_ext.Config(None, include_in_jit_key=True) -compute_on_context_manager = config_ext.Config(None, include_in_jit_key=True) -xla_metadata_context_manager = config_ext.Config(None, include_in_jit_key=True) +trace_state = config_ext.Config('trace_state', None, include_in_jit_key=True) +axis_env_state = config_ext.Config( + 'axis_env_state', + (), + include_in_jit_key=True, + include_in_trace_context=True, +) +mesh_context_manager = config_ext.Config( + 'mesh_context_manager', + (), + include_in_jit_key=True, + include_in_trace_context=True, +) +abstract_mesh_context_manager = config_ext.Config( + 'abstract_mesh_context_manager', + None, + include_in_jit_key=True, + include_in_trace_context=True, +) +device_context = config_ext.Config( + 'device_context', None, include_in_jit_key=True +) +compute_on_context_manager = config_ext.Config( + 'compute_on_context_manager', + None, + include_in_jit_key=True, + include_in_trace_context=True, +) +xla_metadata_context_manager = config_ext.Config( + 'xla_metadata_context_manager', + None, + include_in_jit_key=True, + include_in_trace_context=True, +) +pallas_tpu_interpret_mode_context_manager = config_ext.Config( + 'pallas_tpu_interpret_mode_context_manager', + None, + include_in_jit_key=True, + include_in_trace_context=True, +) + +class UserConfig: + def __init__(self, default_value): + self._obj = config_ext.Config("user_context", default_value, include_in_jit_key=True, + include_in_trace_context=True) + + @property + def value(self): + return self._obj.value + + def __call__(self, new_value): + return UserContext(self._obj, new_value) + +class UserContext: + __slots__ = ["_config", "_new_value", "_prev_value"] + + def __init__(self, config, new_value): + self._config = config + self._new_value = new_value + + def __enter__(self): + self._prev_value = self._config.swap_local(self._new_value) + + def __exit__(self, exc_type, exc_val, exc_tb): + self._config.set_local(self._prev_value) + +def make_user_context(default_value=None): + """Creates a `jax.jit` cache sensitive context. + + If the value of the context changes, JAX's tracing, lowering and compilation + cache won't get a hit and the jitted function will be re-traced, re-lowered + and re-compiled. + + This function is not thread-safe. Do not call it concurrently with other JAX + APIs. + + Example: + + ``` + @jax.jit + def f(x): + return x * 2 + + my_context = jax.make_user_context(default_value=None) + with my_context(1): + f(1.) + with my_context(2): + f(1.) # tracing cache miss + ``` + """ + obj = UserConfig(default_value) + return obj # TODO(b/214340779): remove flag when XLA:CPU is improved. @@ -902,13 +1087,13 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: # Note: bump the default calling convention version at least one month after # we update XlaCallModule to support the new version, so that serialized # modules are forward compatible with deployed versions of XlaCallModule. - # Version 9 of XlaCallModule is supported since October 27th, 2023. - default=int_env('JAX_EXPORT_CALLING_CONVENTION_VERSION', 9), + # Version 10 of XlaCallModule is supported since May 20th, 2025. + default=int_env('JAX_EXPORT_CALLING_CONVENTION_VERSION', 10), help=( 'The calling convention version number to use for exporting. This must be ' 'within the range of versions supported by the tf.XlaCallModule ' 'used in your deployment environment. ' - 'See https://jax.readthedocs.io/en/latest/export/shape_poly.html#calling-convention-versions.' + 'See https://docs.jax.dev/en/latest/export/shape_poly.html#calling-convention-versions.' ) ) @@ -917,7 +1102,7 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: default=bool_env('JAX_EXPORT_IGNORE_FORWARD_COMPATIBILIY', False), help=( 'Whether to ignore the forward compatibility lowering rules. ' - 'See https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees-for-custom-calls.' + 'See https://docs.jax.dev/en/latest/export/export.html#compatibility-guarantees-for-custom-calls.' ) ) @@ -937,11 +1122,17 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: 'otherwise.' )) -jax_pjrt_client_create_options = optional_string_state( +def _validate_jax_pjrt_client_create_options(new_val): + if new_val is not None and not isinstance(new_val, (str, dict)): + raise ValueError('new string config value must be None or of type dict' + f' | str, got {new_val} of type {type(new_val)}.') + +jax_pjrt_client_create_options = string_or_object_state( name='jax_pjrt_client_create_options', default=None, help=('A set of key-value pairs in the format of "k1:v1;k2:v2" strings ' - 'provided to a device platform pjrt client as extra arguments.')) + 'provided to a device platform pjrt client as extra arguments.'), + validator=_validate_jax_pjrt_client_create_options) enable_checks = bool_state( name='jax_enable_checks', @@ -955,7 +1146,8 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: ' typed PRNG keys (i.e. keys created with jax.random.key()) will have their' ' usage tracked, and incorrect reuse of a previously-used key will lead to' ' an error. Currently enabling this leads to a small Python overhead on' - ' every call to a JIT-compiled function with keys as inputs or outputs.')) + ' every call to a JIT-compiled function with keys as inputs or outputs.'), + include_in_trace_context=True) check_tracer_leaks = bool_state( name='jax_check_tracer_leaks', @@ -967,6 +1159,37 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: 'to disable any debuggers while leak checking is enabled.')) checking_leaks = functools.partial(check_tracer_leaks, True) +check_static_indices = bool_state( + name='jax_check_static_indices', + default=False, + help=('Turn on bounds checks for static indices during array indexing operations.' + ' These will only be checked when indexing mode is PROMISE_IN_BOUNDS, which' + ' is the default for gather-type operations.'), + include_in_jit_key=True, + include_in_trace_context=True, +) + +captured_constants_warn_bytes = int_state( + name='jax_captured_constants_warn_bytes', + default=2 * 10 ** 9, + help=('The number of bytes of parameters that may be captured as constants ' + 'before a warning is issued. Defaults to approximately 2GB. ' + 'Set to -1 to disable issuing a warning.' + ) +) + +captured_constants_report_frames = int_state( + name='jax_captured_constants_report_frames', + default=0, + help=('The number of stack frames reported for each captured constant ' + 'indicating the file and operation where the constant was captured. ' + 'Set to -1 to print the complete set of frames, or 0 to disable. ' + 'N.b. the report is only generated if the total amount of captured ' + 'constants exceeds `jax_captured_constants_warn_bytes`, as it is expensive' + 'to generate the report.' + ) +) + debug_nans = bool_state( name='jax_debug_nans', default=False, @@ -995,7 +1218,7 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: name='jax_explain_cache_misses', default=False, help=('Each time there is a miss on one of the main caches (e.g. the ' - 'tracing cache), log an explanation.. Logging is performed with ' + 'tracing cache), log an explanation. Logging is performed with ' '`logging`. When this option is set, the log level is WARNING; ' 'otherwise the level is DEBUG.')) @@ -1006,24 +1229,32 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: 'partially evaluated (e.g. for autodiff), printing what residuals ' 'are saved.')) +# Since we want a deprecation warning regardless of value, we need an +# exemption for when config.py is first loaded. +_pmap_shmap_merge_initialized = False + + +def _default_pmap_shmap_merge(new_val): + del new_val + global _pmap_shmap_merge_initialized + if _pmap_shmap_merge_initialized: + deprecations.warn( + 'jax-pmap-shmap-merge', + ( + 'Setting `jax_pmap_shmap_merge` is deprecated in JAX v0.9.0 and ' + 'will be removed in JAX v0.10.0.' + ), + stacklevel=3, + ) + _pmap_shmap_merge_initialized = True + pmap_shmap_merge = bool_state( name='jax_pmap_shmap_merge', - default=False, + default=True, upgrade=True, - help='If True, pmap and shard_map API will be merged.') - - -spmd_mode = enum_state( - name='jax_spmd_mode', - enum_values=['allow_all', 'allow_jit'], - default='allow_jit', - help=("Decides whether Math on `jax.Array`'s that are not fully addressable " - "(i.e. spans across multiple processes) is allowed. The options are: " - "* allow_jit: Default, `pjit` and `jax.jit` computations are allowed " - " to execute on non-fully addressable `jax.Array`s\n" - "* allow_all: `jnp`, normal math (like `a + b`, etc), `pjit`, " - " `jax.jit` and all other operations are allowed to " - " execute on non-fully addressable `jax.Array`s.")) + help='If True, pmap and shard_map API will be merged.', + validator=_default_pmap_shmap_merge, +) distributed_debug = bool_state( @@ -1038,12 +1269,18 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: default=0, help=('Offset to all random seeds (e.g. argument to jax.random.key()).'), include_in_jit_key=True, + include_in_trace_context=True, ) -legacy_prng_key = enum_state( +class LegacyPrngKeyState(enum.StrEnum): + ALLOW = 'allow' + WARN = 'warn' + ERROR = 'error' + +legacy_prng_key = enum_class_state( name='jax_legacy_prng_key', - enum_values=['allow', 'warn', 'error'], - default='allow', + enum_class=LegacyPrngKeyState, + default=LegacyPrngKeyState.ALLOW, help=('Specify the behavior when raw PRNG keys are passed to ' 'jax.random APIs.') ) @@ -1072,7 +1309,8 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: 'may result in extraneous communication and/or redundant distributed ' 'computation. With this flag, the communication overheads disappear ' 'in some cases.'), - include_in_jit_key=True) + include_in_jit_key=True, + include_in_trace_context=True) threefry_gpu_kernel_lowering = bool_state( name='jax_threefry_gpu_kernel_lowering', @@ -1080,27 +1318,45 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: help=('On GPU, lower threefry PRNG operations to a kernel implementation. ' 'This makes compile times faster at a potential runtime memory ' 'cost.'), - include_in_jit_key=True) + include_in_jit_key=True, + include_in_trace_context=True) use_direct_linearize = bool_state( name='jax_use_direct_linearize', - default=False, + default=True, help=('Use direct linearization instead JVP followed by partial eval'), - include_in_jit_key=True) + include_in_jit_key=True, + include_in_trace_context=True) -varying_axes_in_types = bool_state( - name='jax_varying_axes_in_types', +use_simplified_jaxpr_constants = bool_state( + name='jax_use_simplified_jaxpr_constants', default=False, - help=('Adds varying manual axes to ShapedArray to track which mesh axes the' - ' array is varying over. This will help to remove the efficient' - ' transpose rewrite machinery in shard_map'), - include_in_jit_key=True) + help=('Enable a simplification of the handling of closed-over constants ' + 'in Jaxpr. The value `True` enables the new behavior. ' + 'This flag will exist only briefly, while we transition ' + 'users. See https://github.com/jax-ml/jax/pull/29679.' + 'DO NOT RELY ON THIS FLAG.'), + include_in_jit_key=True, + include_in_trace_context=True) + +# This config is temporary and should go away since this is a user problem. +# If they don't want 1 sized mesh axis names to show up in sharding and vma +# bits on ShapedArray, then their mesh (which they pass to set_mesh) should not +# contain those axes at all. +remove_size_one_mesh_axis_from_type = bool_state( + name='jax_remove_size_one_mesh_axis_from_type', + default=False, + help="Removes mesh axes of size 1 from ShapedArray.sharding and vma", + include_in_jit_key=True, + include_in_trace_context=True) -data_dependent_tracing_fallback = bool_state( - name='jax_data_dependent_tracing_fallback', +# TODO make it so people don't use this, this is internal... +_check_vma = bool_state( + name='check_vma', default=False, - help=('When True, falls back to trace dispatch based on data dependence ' - 'instead of throwing an escaped tracer error.')) + help='internal implementation detail of shard_map, DO NOT USE', + include_in_jit_key=True, + include_in_trace_context=True) softmax_custom_jvp = bool_state( name='jax_softmax_custom_jvp', @@ -1109,7 +1365,8 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: help=('Use a new custom_jvp rule for jax.nn.softmax. The new rule should ' 'improve memory usage and stability. Set True to use new ' 'behavior. See https://github.com/jax-ml/jax/pull/15677'), - include_in_jit_key=True) + include_in_jit_key=True, + include_in_trace_context=True) enable_custom_vjp_by_custom_transpose = bool_state( @@ -1176,7 +1433,8 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: 'source_file with the given regex, and all matches are removed. ' 'This can be used to avoid spurious cache misses when using the ' 'persistent compilation cache, which includes HLO metadata in the ' - 'cache key.')) + 'cache key.'), + include_in_trace_context=True) include_full_tracebacks_in_locations = bool_state( name='jax_include_full_tracebacks_in_locations', @@ -1221,6 +1479,7 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: 'estimator.' ), include_in_jit_key=True, + include_in_trace_context=True, ) pgle_profiling_runs = int_state( @@ -1231,6 +1490,7 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: 'PGLE is used.' ), include_in_jit_key=True, + include_in_trace_context=True, ) pgle_aggregation_percentile = int_state( @@ -1290,23 +1550,69 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: 'what they are trying to achieve should set it.'), ) +def _default_dtype_bits_deprecation(val): + if val != '_default': + warnings.warn( + ( + 'The jax_default_dtype_bits configuration is deprecated in JAX v0.7.1' + ' and has no effect as of JAX v0.9.0. It will be removed in JAX v0.10.0.' + ), + category=DeprecationWarning, + stacklevel=4) + + default_dtype_bits = enum_state( name='jax_default_dtype_bits', - enum_values=['32', '64'], - default='64', - help=('Specify bit width of default dtypes, either 32-bit or 64-bit. ' - 'This is a temporary flag that will be used during the process ' - 'of deprecating the ``jax_enable_x64`` flag.')) + enum_values=['_default', '32', '64'], + default='_default', + help=('[deprecated]. This has no effect starting with JAX v0.9.0, and' + ' will be removed in JAX v0.10.0.'), + extra_validator=_default_dtype_bits_deprecation) + + +class ExplicitX64Mode(enum.IntEnum): + WARN = enum.auto() + ERROR = enum.auto() + ALLOW = enum.auto() + + @classmethod + def _missing_(cls, value: object) -> ExplicitX64Mode | None: + if value == "warn": + return cls.WARN + if value == "error": + return cls.ERROR + if value == "allow": + return cls.ALLOW + return None + + +explicit_x64_dtypes = enum_class_state( + name='jax_explicit_x64_dtypes', + enum_class=ExplicitX64Mode, + default=ExplicitX64Mode.WARN, + help=( + 'If set to ALLOW, explicit specification of 64-bit types will be ' + 'respected even if enable_x64 is false. If set to WARN, a warning will ' + 'be issued, and if set to ERROR, an error will be raised.' + ), + include_in_jit_key=True, + include_in_trace_context=True, +) -numpy_dtype_promotion = enum_state( +class NumpyDtypePromotion(enum.StrEnum): + STANDARD = 'standard' + STRICT = 'strict' + +numpy_dtype_promotion = enum_class_state( name='jax_numpy_dtype_promotion', - enum_values=['standard', 'strict'], - default='standard', + enum_class=NumpyDtypePromotion, + default=NumpyDtypePromotion.STANDARD, help=('Specify the rules used for implicit type promotion in operations ' 'between arrays. Options are "standard" or "strict"; in strict-mode, ' 'binary operations between arrays of differing strongly-specified ' 'dtypes will result in an error.'), - include_in_jit_key=True) + include_in_jit_key=True, + include_in_trace_context=True) disallow_mesh_context_manager = bool_state( name='jax_disallow_mesh_context_manager', @@ -1317,32 +1623,58 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: ), ) -def _update_x64_global(val): - jax_jit.global_state().enable_x64 = val +# TODO(ayx): Move these 3 flags out of config once we have a user-level +# extension mechanism for adding contexts to which the jit cache is sensitive. +error_checking_behavior_nan = enum_state( + name='jax_error_checking_behavior_nan', + enum_values=['ignore', 'raise'], + default='ignore', + help=( + 'Specify the behavior when a NaN is encountered. Options are "ignore"' + ' or "raise".' + ), + include_in_jit_key=True, + include_in_trace_context=True, +) + +error_checking_behavior_divide = enum_state( + name='jax_error_checking_behavior_divide', + enum_values=['ignore', 'raise'], + default='ignore', + help=( + 'Specify the behavior when a divide by zero is encountered. Options are' + ' "ignore" or "raise".' + ), + include_in_jit_key=True, + include_in_trace_context=True, +) -def _update_x64_thread_local(val): - jax_jit.thread_local_state().enable_x64 = val +error_checking_behavior_oob = enum_state( + name='jax_error_checking_behavior_oob', + enum_values=['ignore', 'raise'], + default='ignore', + help=( + 'Specify the behavior when an out of bounds access is encountered.' + ' Options are "ignore" or "raise".' + ), + include_in_jit_key=True, + include_in_trace_context=True, +) enable_x64 = bool_state( name='jax_enable_x64', default=False, help='Enable 64-bit types to be used', - update_global_hook=_update_x64_global, - update_thread_local_hook=_update_x64_thread_local) + include_in_jit_key=True, + include_in_trace_context=True) + +jax_jit.set_enable_x64_state(enable_x64) # TODO(phawkins): remove after fixing users of FLAGS.x64_enabled. config._contextmanager_flags.remove('jax_enable_x64') setattr(Config, "x64_enabled", property(lambda _: enable_x64.value)) -def _update_default_device_global(val): - jax_jit.global_state().default_device = val - - -def _update_default_device_thread_local(val): - jax_jit.thread_local_state().default_device = val - - def _validate_default_device(val): if (val is not None and not isinstance(val, xla_client.Device) and @@ -1369,23 +1701,17 @@ def _validate_default_device(val): 'no effect on multi-device computations, e.g. pmapped function calls). ' 'Set to None to use the system default device. See ' ':ref:`faq-data-placement` for more information on device placement.'), - update_global_hook=_update_default_device_global, - update_thread_local_hook=_update_default_device_thread_local, - validator=_validate_default_device) - -def _update_disable_jit_global(val): - jax_jit.global_state().disable_jit = val - -def _update_disable_jit_thread_local(val): - jax_jit.thread_local_state().disable_jit = val + validator=_validate_default_device, + include_in_jit_key=True, + include_in_trace_context=True) disable_jit = bool_state( name='jax_disable_jit', default=False, help=('Disable JIT compilation and just call original Python.'), - update_global_hook=_update_disable_jit_global, - update_thread_local_hook=_update_disable_jit_thread_local) + include_in_trace_context=True) +jax_jit.set_disable_jit_state(disable_jit) numpy_rank_promotion = enum_state( name='jax_numpy_rank_promotion', @@ -1393,7 +1719,8 @@ def _update_disable_jit_thread_local(val): default='allow', help=('Control NumPy-style automatic rank promotion broadcasting ' '("allow", "warn", or "raise").'), - include_in_jit_key=True) + include_in_jit_key=True, + include_in_trace_context=True) default_matmul_precision = optional_enum_state( name='jax_default_matmul_precision', @@ -1429,7 +1756,8 @@ def _update_disable_jit_thread_local(val): '"algorithm" for functions that perform matrix multiplications, like ' ':func:`jax.lax.dot`. To specify an algorithm, set this option to ' 'the name of a :class:`~jax.lax.DotAlgorithmPreset`.\n\n'), - include_in_jit_key=True) + include_in_jit_key=True, + include_in_trace_context=True) traceback_filtering = enum_state( @@ -1437,18 +1765,17 @@ def _update_disable_jit_thread_local(val): enum_values=["off", "tracebackhide", "remove_frames", "quiet_remove_frames", "auto"], default="auto", - help="Controls how JAX filters internal frames out of tracebacks.\n\n" - "Valid values are:\n" - " * \"off\": disables traceback filtering.\n" - " * \"auto\": use \"tracebackhide\" if running under a sufficiently" - " new IPython, or \"remove_frames\" otherwise.\n" - " * \"tracebackhide\": adds \"__tracebackhide__\" annotations to" - " hidden stack frames, which some traceback printers support.\n" - " * \"remove_frames\": removes hidden frames from tracebacks, and adds" - " the unfiltered traceback as a __cause__ of the exception.\n" - " * \"quiet_remove_frames\": removes hidden frames from tracebacks, and adds" - " a brief message (to the __cause__ of the exception) describing that this has" - " happened.\n") + help="Controls how JAX filters internal frames out of tracebacks. Valid values are:\n" + "- ``off``: disables traceback filtering.\n" + "- ``auto``: use ``tracebackhide`` if running under a sufficiently " + "new IPython, or ``remove_frames`` otherwise.\n" + "- ``tracebackhide``: adds ``__tracebackhide__`` annotations to " + "hidden stack frames, which some traceback printers support.\n" + "- ``remove_frames``: removes hidden frames from tracebacks, and adds " + "the unfiltered traceback as a ``__cause__`` of the exception.\n" + "- ``quiet_remove_frames``: removes hidden frames from tracebacks, and adds " + "a brief message (to the ``__cause__`` of the exception) describing that this has " + "happened.\n\n") # This flag is for internal use. # TODO(tianjianlu): Removes once we always enable cusparse lowering. @@ -1458,28 +1785,13 @@ def _update_disable_jit_thread_local(val): default=False, help=('Enables lowering BCOO ops to cuSparse.')) -# TODO(mattjj): remove this flag when we ensure we only succeed at trace-staging -# if the intended backend can handle lowering the result -dynamic_shapes = bool_state( - name='jax_dynamic_shapes', - default=bool(os.getenv('JAX_DYNAMIC_SHAPES', '')), - help=('Enables experimental features for staging out computations with ' - 'dynamic shapes.'), - include_in_jit_key=True) - # This is for stackless backward compat with e.g. equinox eager_constant_folding = bool_state( name='eager_constant_folding', default=False, help=('Attempt constant folding during staging.'), - include_in_jit_key=True) - -# This flag is temporary during rollout of the remat barrier. -# TODO(parkers): Remove if there are no complaints. -remat_opt_barrier = bool_state( - name='jax_remat_opt_barrier', - default=True, - help=('Enables using optimization-barrier op for lowering remat.')) + include_in_jit_key=True, + include_in_trace_context=True) enable_remat_opt_pass = bool_state( name='jax_compiler_enable_remat_pass', @@ -1489,36 +1801,43 @@ def _update_disable_jit_thread_local(val): 'compute when encountering OOM errors. However, you are ' 'likely to get better results manually with jax.checkpoint')) -# TODO(sharadmv,mattjj): set default to True, then remove -eager_pmap = bool_state( - name='jax_eager_pmap', - default=True, - upgrade=True, - help='Enable eager-mode pmap when jax_disable_jit is activated.') - no_tracing = bool_state( name='jax_no_tracing', default=False, help='Disallow tracing for JIT compilation.') +no_execution = bool_state( + name='jax_no_execution', + default=False, + help='Disallow JAX executions.', + include_in_jit_key=True, + include_in_trace_context=True) + disable_vmap_shmap_error = bool_state( name='jax_disable_vmap_shmap_error', default=False, upgrade=False, help='Temporary workaround to disable an error check in vmap-of-shmap.') -# TODO(mattjj): remove once we land mutable array plumbing, or face great shame -custom_vjp_disable_shape_check = bool_state( - name='jax_custom_vjp_disable_shape_check', +mutable_array_checks = bool_state( + name='jax_mutable_array_checks', + default=True, + upgrade=True, + help='Enable error checks for mutable arrays that rule out aliasing.', + include_in_trace_context=True) + +refs_to_pins = bool_state( + name='jax_refs_to_pins', default=False, upgrade=True, - help='Disable the check from #19009 to enable some custom_vjp hacks.') + help='Lower refs to pinned buffers in HLO.') -mutable_array_checks = bool_state( - name='jax_mutable_array_checks', +# TODO(mattjj, yashkatariya): remove once we land box plumbing +disable_bwd_checks = bool_state( + name='jax_disable_bwd_checks', default=False, upgrade=True, - help='Enable error checks for mutable arrays that rule out aliasing.') + help='Disables all bwd pass checks') xla_runtime_errors = bool_state( name='jax_experimental_unsafe_xla_runtime_errors', @@ -1538,6 +1857,7 @@ def _update_disable_jit_thread_local(val): 'only when XLA is configured to support the remote compilation ' 'profile feature.'), include_in_jit_key=True, + include_in_trace_context=True, ) @contextlib.contextmanager @@ -1650,7 +1970,7 @@ def transfer_guard(new_val: str) -> Iterator[None]: """A contextmanager to control the transfer guard level for all transfers. For more information, see - https://jax.readthedocs.io/en/latest/transfer_guard.html + https://docs.jax.dev/en/latest/transfer_guard.html Args: new_val: The new thread-local transfer guard level for all transfers. @@ -1685,16 +2005,17 @@ def _update_garbage_collection_guard(state, key, val): # The default is applied by guard_lib. default=None, help=( - 'Select garbage collection guard level for "jax.Array" objects.\nThis' - ' option can be used to control what happens when a "jax.Array"' - ' object is garbage collected. It is desirable for "jax.Array"' - ' objects to be freed by Python reference couting rather than garbage' + 'Select garbage collection guard level for ``jax.Array`` objects.\n\n' + 'This option can be used to control what happens when a ``jax.Array``' + ' object is garbage collected. It is desirable for ``jax.Array``' + ' objects to be freed by Python reference counting rather than garbage' ' collection in order to avoid device memory being held by the arrays' - ' until garbage collection occurs.\n\nValid values are:\n * "allow":' - ' do not log garbage collection of "jax.Array" objects.\n * "log":' - ' log an error when a "jax.Array" is garbage collected.\n * "fatal":' - ' fatal error if a "jax.Array" is garbage collected.\nDefault is' - ' "allow". Note that not all cycles may be detected.' + ' until garbage collection occurs.\n\n' + 'Valid values are:\n\n' + '* ``allow``: do not log garbage collection of ``jax.Array`` objects.\n' + '* ``log``: log an error when a ``jax.Array`` is garbage collected.\n' + '* ``fatal``: fatal error if a ``jax.Array`` is garbage collected.\n\n' + 'Default is ``allow``. Note that not all cycles may be detected.' ), update_global_hook=lambda val: _update_garbage_collection_guard( guard_lib.global_state(), 'garbage_collect_array', val @@ -1704,6 +2025,59 @@ def _update_garbage_collection_guard(state, key, val): ), ) +if jaxlib_extension_version >= 395: + thread_guard = bool_state( + name='jax_thread_guard', + default=False, + help=( + 'If True, an error will be raised at runtime if a multi-process JAX ' + 'operation is called from a thread other than the one in which the ' + 'thread guard was set. This is useful for detecting cases where ' + 'threads may schedule operations in different orders in different ' + 'processes, leading to non-deterministic crashes.' + ), + update_thread_local_hook=( + # If the state is None, set it to False. + lambda val: guard_lib.update_thread_guard_global_state(val or False)), + ) + +# TODO(nbasile): Remove hasattr checks after jaxlib 0.8.1 release +if hasattr(_jax, 'RuntimeTracebackMode'): + class RuntimeTracebackMode(enum.StrEnum): + OFF = 'off' + ON = 'on' + FULL = 'full' + + @classmethod + def _missing_(cls, value): + if isinstance(value, str): + try: + return cls[value.upper()] + except KeyError: + pass + return None + + def as_cpp_enum(self): + return getattr(_jax.RuntimeTracebackMode, self.name) + + send_traceback_to_runtime = enum_class_state( + name='jax_send_traceback_to_runtime', + enum_class=RuntimeTracebackMode, + default=RuntimeTracebackMode.OFF, + help=( + 'Controls the level of Python traceback information sent to the' + ' runtime at dispatch time:\n- "OFF": (default) No Python traceback' + ' information is sent.\n- "ON": Only the most recent user frame call' + ' location is sent.\n- "FULL": The full Python traceback of the call' + ' location is sent. This has a high fixed cost on the dispatch path' + ' and should be used only for debugging.' + ), + update_global_hook=lambda val: _jax.set_send_traceback_to_runtime_global( + val.as_cpp_enum() if val is not None else _jax.RuntimeTracebackMode.OFF), + update_thread_local_hook=lambda val: _jax.set_send_traceback_to_runtime_thread_local( + val.as_cpp_enum() if val is not None else None), + ) + # Don't define a context manager since this isn't threadsafe. string_state( name='jax_debug_log_modules', @@ -1726,22 +2100,18 @@ def _update_garbage_collection_guard(state, key, val): logging_config.update_logging_level_global(logging_level=logging_level) ) -pmap_no_rank_reduction = bool_state( - name='jax_pmap_no_rank_reduction', - default=True, - help='If True, pmap shards have a the same rank as their enclosing array.', -) + use_shardy_partitioner = bool_state( name='jax_use_shardy_partitioner', - default=False, + default=True, upgrade=True, help=( - 'Whether to lower to Shardy. Shardy is a new open sourced propagation ' - 'framework for MLIR. Currently Shardy is experimental in JAX. See ' - 'www.github.com/openxla/shardy' + 'Whether to lower to Shardy. See the migration guide for more ' + 'information: https://docs.jax.dev/en/latest/shardy_jax_migration.html.' ), include_in_jit_key=True, + include_in_trace_context=True, ) gpu_use_magma = enum_state( @@ -1790,35 +2160,70 @@ def _update_garbage_collection_guard(state, key, val): 'O2', 'O3', ], - default='UNKNOWN', + default='O2', help=( 'The degree to which the compiler should attempt to make the program' ' fit in memory' ), - include_in_jit_key=True + include_in_jit_key=True, ) +DEFAULT_CPU_COLLECTIVES_IMPL = "gloo" + cpu_collectives_implementation = optional_enum_state( name='jax_cpu_collectives_implementation', enum_values=["gloo", "mpi", "megascale"], - default=None, + default=DEFAULT_CPU_COLLECTIVES_IMPL, help=( "Cross-process collective implementation used on CPU. Must be one of " '("gloo", "mpi")'), ) -enable_empty_arrays = bool_state( - name='jax_enable_empty_arrays', - default=False, - help=( - "Enable the creation of an Array from an empty list of single-device " - "arrays. This is to support MPMD/pipeline parallelism in McJAX (WIP)." - ) -) - use_high_dynamic_range_gumbel = bool_state( name='jax_high_dynamic_range_gumbel', default=False, help='If True, gumble noise draws two samples to cover low probability ' 'events with more precision.', + include_in_trace_context=True, +) + +jax_dump_ir_to = string_flag( + name='jax_dump_ir_to', + default=os.getenv('JAX_DUMP_IR_TO', ''), + help="Path to which IR(s) emitted by JAX should be dumped as text files." + "If omitted, JAX will not dump any IR. " + "Supports the special value 'sponge' to pick the path from the " + "environment variable TEST_UNDECLARED_OUTPUTS_DIR. See " + "jax_dump_ir_modes for options governing what is dumped.") + +jax_include_debug_info_in_dumps = bool_flag( + name='jax_include_debug_info_in_dumps', + default=bool_env('JAX_INCLUDE_DEBUG_INFO_IN_DUMPS', True), + help='Determine whether or not to keep debug symbols and location ' + 'information when dumping IR code. By default, debug information will ' + 'be preserved in the IR dump. To avoid exposing source code and ' + 'potentially sensitive information, set to false ') + +# TODO(dsuo): Turn this into a list-valued flag. +jax_dump_ir_modes = string_flag( + name="jax_dump_ir_modes", + default=os.getenv("JAX_DUMP_IR_MODES", "stablehlo"), + help="Comma-delimited modes in which to dump IR. Can be 'stablehlo' (the " + "default), 'jaxpr', or 'eqn_count_pprof' for " + "jaxpr equation count pprof profile.") + +jax_ragged_dot_use_ragged_dot_instruction = bool_state( + name='jax_ragged_dot_use_ragged_dot_instruction', + default=True, + help=( + '(TPU only) If True, use chlo.ragged_dot instruction for ragged_dot()' + ' lowering. Otherwise, rely on the rollout logic in lowering rule for' + ' ragged_dot_general_p.' + ), +) + +jax_pallas_verbose_errors = bool_flag( + "jax_pallas_verbose_errors", + default=bool_env("JAX_PALLAS_VERBOSE_ERRORS", False), + help="If True, print verbose error messages for Pallas kernels.", ) diff --git a/jax/_src/core.py b/jax/_src/core.py index 36ce2f004ed4..d23fcb768f01 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -28,7 +28,7 @@ import threading import types from typing import (Any, ClassVar, Generic, NamedTuple, TypeVar, - overload, Union) + overload, Union, TYPE_CHECKING) import warnings import weakref @@ -37,7 +37,6 @@ from jax._src import dtypes from jax._src import config from jax._src import effects -from jax._src import compute_on from jax._src import mesh as mesh_lib from jax._src.mesh import AxisType from jax._src.partition_spec import PartitionSpec as P @@ -45,20 +44,24 @@ ConcretizationTypeError, TracerArrayConversionError, TracerBoolConversionError, TracerIntegerConversionError, UnexpectedTracerError) from jax._src import linear_util as lu - +from jax._src.tree_util import tree_map from jax._src import source_info_util from jax._src.util import (safe_zip, safe_map, curry, tuple_insert, tuple_delete, cache, HashableFunction, HashableWrapper, weakref_lru_cache, - partition_list, StrictABCMeta, foreach) + partition_list, StrictABCMeta, foreach, + weakref_cache_key_types, set_module) import jax._src.pretty_printer as pp from jax._src.named_sharding import NamedSharding +from jax._src.sharding import Sharding +from jax._src.layout import Format, AutoLayout +from jax._src.memory import Space as MemorySpace +from jax._src.lib import _jax from jax._src.lib import jax_jit from jax._src.lib import xla_client from jax._src import traceback_util -from jax._src.typing import Array, DimSize, Shape -from jax._src import typing -from jax._src import xla_metadata as xla_metadata_lib +from jax._src.typing import Array, ArrayLike, DimSize, Shape +from jax._src import xla_metadata_lib traceback_util.register_exclusion(__file__) @@ -67,6 +70,8 @@ config_ext = xla_client._xla.config +PyTree = Any + _TRACER_ERROR_NUM_TRACEBACK_FRAMES = config.int_flag( 'jax_tracer_error_num_traceback_frames', @@ -74,6 +79,7 @@ help='Set the number of stack frames in JAX tracer error messages.' ) +def identity(x): return x # -------------------- jaxprs -------------------- @@ -84,10 +90,12 @@ DebugInfo = lu.DebugInfo +InitialResultPaths = lu.InitialResultPaths +initial_result_paths = lu.initial_result_paths class Jaxpr: __slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns', - '_effects', '_debug_info'] + '_effects', '_debug_info', '_is_high'] _constvars: list[Var] _invars: list[Var] @@ -95,6 +103,7 @@ class Jaxpr: _eqns: list[JaxprEqn] _effects: Effects _debug_info: DebugInfo + _is_high: bool @property def constvars(self) -> list[Var]: @@ -120,6 +129,28 @@ def effects(self) -> Effects: def debug_info(self) -> DebugInfo: return self._debug_info + @property + def is_high(self) -> bool: + return self._is_high + + @property + def in_avals(self): + return [v.aval for v in self.invars] + + @property + def in_aval_qdds(self) -> list[AbstractValue | AvalQDD]: + return [v.aval if v.initial_qdd is None else AvalQDD(v.aval, v.initial_qdd) + for v in self.invars] + + @property + def final_aval_qdds(self) -> list[AbstractValue | AvalQDD]: + return [v.aval if v.final_qdd is None else AvalQDD(v.aval, v.final_qdd) + for v in self.invars] + + @property + def out_avals(self): + return [v.aval for v in self.outvars] + def __init__(self, constvars: Sequence[Var], invars: Sequence[Var], outvars: Sequence[Atom], eqns: Sequence[JaxprEqn], effects: Effects = no_effects, @@ -127,6 +158,7 @@ def __init__(self, constvars: Sequence[Var], invars: Sequence[Var], # compatibility we have to allow calls when the debug_info # is missing. debug_info: DebugInfo = None, # type: ignore[annotation-type-mismatch,assignment] + is_high: bool = False, ): """ Args: @@ -148,9 +180,9 @@ def __init__(self, constvars: Sequence[Var], invars: Sequence[Var], # TODO(https://github.com/jax-ml/jax/issues/26480) debug_info = debug_info or lu._missing_debug_info("core.Jaxpr") self._debug_info = debug_info.resolve_result_paths() - # TODO(necula): re-enable these safety checks - # assert (len(debug_info.arg_names) == len(invars)), (debug_info, invars) - # assert (len(debug_info.result_paths) == len(outvars)), (debug_info, outvars) + config.enable_checks.value and self._debug_info.assert_arg_names(len(invars)) + config.enable_checks.value and self._debug_info.assert_result_paths(len(outvars)) + self._is_high = is_high def __str__(self): return str(self.pretty_print()) @@ -170,18 +202,24 @@ def _repr_pretty_(self, p, cycle): return p.text(self.pretty_print(use_color=True)) def replace(self, **kwargs): + debug_default = self.debug_info + if (kwargs.get('invars', self.invars) != self.invars or + kwargs.get('outvars', self.outvars) != self.outvars): + debug_default = debug_default.with_unknown_names() jaxpr = Jaxpr( constvars=kwargs.pop("constvars", self.constvars), invars=kwargs.pop("invars", self.invars), outvars=kwargs.pop("outvars", self.outvars), eqns=kwargs.pop("eqns", self.eqns), effects=kwargs.pop("effects", self.effects), - debug_info=kwargs.pop("debug_info", self.debug_info), + debug_info=kwargs.pop("debug_info", debug_default), + is_high=kwargs.pop("is_high", self.is_high), ) if kwargs: raise ValueError(f"Unknown keyword arguments: {kwargs}") return jaxpr +weakref_cache_key_types.add(Jaxpr) def join_effects(*effects: Effects) -> Effects: return set().union(*effects) if effects else no_effects @@ -212,6 +250,15 @@ class ClosedJaxpr: jaxpr = property(lambda self: self._jaxpr) consts = property(lambda self: self._consts) + literals = consts + + constvars = property(lambda self: self._jaxpr.constvars) + invars = property(lambda self: self._jaxpr.invars) + outvars = property(lambda self: self._jaxpr.outvars) + eqns = property(lambda self: self._jaxpr.eqns) + effects = property(lambda self: self._jaxpr.effects) + debug_info = property(lambda self: self._jaxpr.debug_info) + is_high = property(lambda self: self._jaxpr.is_high) def __init__(self, jaxpr: Jaxpr, consts: Sequence): assert len(consts) == len(jaxpr.constvars) @@ -221,23 +268,21 @@ def __init__(self, jaxpr: Jaxpr, consts: Sequence): @property def in_avals(self): - return [v.aval for v in self.jaxpr.invars] - - @property - def out_avals(self): - return [v.aval for v in self.jaxpr.outvars] + return [v.aval for v in self.invars] @property - def literals(self): - return self.consts # backwards compatible alias + def in_aval_qdds(self) -> list[AbstractValue | AvalQDD]: + return [v.aval if v.initial_qdd is None else AvalQDD(v.aval, v.initial_qdd) + for v in self.invars] @property - def eqns(self): - return self.jaxpr.eqns + def final_aval_qdds(self) -> list[AbstractValue | AvalQDD]: + return [v.aval if v.final_qdd is None else AvalQDD(v.aval, v.final_qdd) + for v in self.invars] @property - def effects(self) -> Effects: - return self.jaxpr.effects + def out_avals(self): + return [v.aval for v in self.outvars] def map_jaxpr(self, f): return ClosedJaxpr(f(self.jaxpr), self.consts) @@ -264,6 +309,9 @@ def pretty_print(self, *, source_info=False, print_shapes=True, def _repr_pretty_(self, p, cycle): return p.text(self.pretty_print(use_color=True)) +weakref_cache_key_types.add(ClosedJaxpr) + + @curry def jaxpr_as_fun(closed_jaxpr: ClosedJaxpr, *args): # TODO(dougalm): remove this hack when we add contexts to jaxpr. @@ -320,7 +368,7 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): config.compute_on_context_manager.set_local(self.prev_compute_type) config.threefry_partitionable.set_local(self.prev_threefry_partitionable) - if self.context.xla_metadata is not None: + if self.context.xla_metadata: config.xla_metadata_context_manager.set_local(self.prev_xla_metadata) config.abstract_mesh_context_manager.set_local(self.prev_abstract_mesh) @@ -330,8 +378,13 @@ class JaxprEqnContext: __slots__ = ['compute_type', 'threefry_partitionable', 'xla_metadata', 'cur_abstract_mesh'] + compute_type: str | None + threefry_partitionable: bool + xla_metadata: dict[str, Any] | None + cur_abstract_mesh: mesh_lib.AbstractMesh + def __init__(self, compute_type: str | None, threefry_partitionable: bool, - xla_metadata=None): + xla_metadata: dict[str, Any] | None = None): self.compute_type = compute_type self.threefry_partitionable = threefry_partitionable self.cur_abstract_mesh = mesh_lib.get_abstract_mesh() @@ -349,6 +402,21 @@ def __repr__(self): f"xla_metadata={self.xla_metadata})" ) + def __hash__(self): + return hash(( + self.compute_type, + self.threefry_partitionable, + self.cur_abstract_mesh, + None if self.xla_metadata is None + else tuple(sorted(self.xla_metadata.items())), + )) + + def __eq__(self, other): + return (self.compute_type == other.compute_type and + self.threefry_partitionable == other.threefry_partitionable and + self.cur_abstract_mesh == other.cur_abstract_mesh and + self.xla_metadata == other.xla_metadata) + class JaxprEqn: invars: list[Atom] @@ -356,6 +424,12 @@ class JaxprEqn: primitive: Primitive params: dict[str, Any] effects: Effects + + # The source_info.name_stack is always relative to the enclosing jaxpr (only) + # and does not include any name context from the caller of the jaxpr. A jaxpr + # might have multiple callers, after all. + # TODO(phawkins): update source_info.tracebacks to also be relative to the + # enclosing jaxpr. source_info: source_info_util.SourceInfo ctx: JaxprEqnContext @@ -363,13 +437,13 @@ class JaxprEqn: __slots__ = ['invars', 'outvars', 'primitive', 'params', 'effects', 'source_info', 'ctx'] - def __init__(self, invars, outvars, primitive, params, effects, source_info, + def __init__(self, invars, outvars, primitive, params, effs, source_info, ctx): self.invars = invars self.outvars = outvars self.primitive = primitive self.params = params - self.effects = effects + self.effects = effs self.source_info = source_info self.ctx = ctx @@ -402,7 +476,7 @@ def new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info=None, ctx=None) -> JaxprEqn: source_info = source_info or source_info_util.new_source_info() ctx = ctx or JaxprEqnContext( - compute_on.current_compute_type(), + config.compute_on_context_manager.value, config.threefry_partitionable.value, xla_metadata_lib.current_xla_metadata()) if config.enable_checks.value: @@ -412,31 +486,32 @@ def new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info=None, _var_counter = it.count() -@total_ordering class Var: - __slots__ = ["count", "suffix", "aval"] + __slots__ = ["count", "aval", "initial_qdd", "final_qdd"] count: int - suffix: str aval: AbstractValue + # these are only useful for jaxpr binders but rather than create a separate + # type for those, breaking existing interpreters, we add fields here. + initial_qdd : QuasiDynamicData | None + final_qdd : QuasiDynamicData | None - def __init__(self, suffix: str, aval: AbstractValue): + def __init__(self, aval: AbstractValue, initial_qdd=None, final_qdd=None): + assert isinstance(aval, AbstractValue), aval self.count = next(_var_counter) - self.suffix = suffix self.aval = aval - - # TODO(phawkins, mattjj): remove ordering of variables. JAX itself does not - # care about variable ordering, but the downstream package kfac_jax does. - def __lt__(self, other): - return self.count < other.count + self.initial_qdd = initial_qdd + self.final_qdd = final_qdd def __repr__(self): - return f'Var(id={id(self)}){self.suffix}:{self.aval.str_short()}' + return f'Var(id={id(self)}):{self.aval.str_short()}' + def pretty_print(self, context: JaxprPpContext, *, print_dtype: bool = True): + del print_dtype # unused + return f"{context.var_names[self]}" -def gensym(suffix: str = '') -> Callable[[AbstractValue], Var]: - """Produce distinct variables, printed with the optional suffix.""" - return partial(Var, suffix) + +gensym = lambda: Var # In a jaxpr, `dropvar` can appear in place of a bound variable to indicate that # the assignment is dropped, i.e. that an expression's output value will never @@ -444,38 +519,91 @@ def gensym(suffix: str = '') -> Callable[[AbstractValue], Var]: # treat it as a special case of one. Its `aval` is similarly inexact. class DropVar(Var): def __init__(self, aval: AbstractValue): - super().__init__('', aval) + super().__init__(aval) def __repr__(self): return '_' + def pretty_print(self, context: JaxprPpContext, *, print_dtype: bool = True): + del context, print_dtype # unused + return '_' class Literal: - __slots__ = ["val", "aval", "hash"] + # See https://docs.jax.dev/en/latest/internals/constants.html + __slots__ = ["val", "aval"] val: Any aval: AbstractValue - hash: int | None def __init__(self, val, aval): self.val = val self.aval = aval + + @property + def hash(self): try: - self.hash = hash(val) + return hash(self.val) except TypeError: - if type(val) in literalable_types: + if type(self.val) in literalable_types: try: - self.hash = hash((val.item(), val.dtype)) + return hash((self.val.item(), self.val.dtype)) except (TypeError, AttributeError, ValueError): - self.hash = None + return None __hash__ = None # type: ignore - def __repr__(self): - if hasattr(self, 'hash'): - return f'{self.val}' + def pretty_print(self, context: JaxprPpContext, *, print_dtype: bool = True): + del context # unused + dtype = getattr(self.aval, 'dtype', None) + if not np.shape(self.val): + val_str = str(np.asarray(self.val).item()) + else: + val_str = "[...]" + if print_dtype and dtype: + return f'{val_str}:{self.aval.str_short(short_dtypes=True)}' else: - return f'Literal(val={self.val})' + return val_str + + def __repr__(self): + return f'Literal({self.val})' +# The types of constants that can be used with core.Literal. Other constants +# end up as `constvars`. literalable_types: set[type] = set() +def is_literalable(x: Any) -> bool: + # See https://docs.jax.dev/en/latest/internals/constants.html + for t in type(x).__mro__: + if t in literalable_types: + return (not np.shape(x) or config.use_simplified_jaxpr_constants.value) + return False + +@partial(weakref_lru_cache, trace_context_in_key=False) +def jaxpr_const_args(jaxpr: Jaxpr) -> list[tuple[ArrayLike, AbstractValue]]: + # The non-scalar constants in core.Literal, in the entire Jaxpr, + # uniquified by id. These will be hoisted as const arguments to the functions + # in which they appear. + # See https://docs.jax.dev/en/latest/internals/constants.html + if not config.use_simplified_jaxpr_constants.value: + return [] + consts_by_id: dict[int, tuple[ArrayLike, AbstractValue]] = {} + for v in jaxpr.outvars: + if type(v) is Literal and np.shape(v.val): # type: ignore + consts_by_id[id(v)] = (v.val, v.aval) # type: ignore + + for eqn in jaxpr.eqns: + for v in eqn.invars: + if type(v) is Literal and np.shape(v.val): # type: ignore + consts_by_id[id(v)] = (v.val, v.aval) # type: ignore + consts_by_id.update({id(v_aval[0]): v_aval + for v_aval in eqn_params_const_args(eqn.params)}) + return list(consts_by_id.values()) + +def eqn_params_const_args(params) -> list[tuple[ArrayLike, AbstractValue]]: + consts_by_id: dict[int, tuple[ArrayLike, AbstractValue]] = {} + for j in jaxprs_in_params(params): + consts_by_id.update( + {id(v_aval[0]): v_aval for v_aval in jaxpr_const_args(j)} + ) + return list(consts_by_id.values()) + Atom = Union[Var, Literal] class Primitive: @@ -491,6 +619,8 @@ class Primitive: # set for primitives that can skip canonicalization of values skip_canonicalization: bool = False + is_effectful = None + def __init__(self, name: str): self.name = name @@ -503,11 +633,9 @@ def bind(self, *args, **params): def _true_bind(self, *args, **params): for arg in args: - if (isinstance(arg, Tracer) - and not arg._trace.is_valid() - and not config.data_dependent_tracing_fallback.value): + if isinstance(arg, Tracer) and not arg._trace.is_valid(): raise escaped_tracer_error(arg) - # TODO: figure out how to handle function arguments + # TODO: figure out how to handle function arguments for this assert # assert (not config.enable_checks.value or # all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args @@ -522,7 +650,17 @@ def _true_bind(self, *args, **params): trace_ctx.set_trace(prev_trace) def bind_with_trace(self, trace, args, params): - return trace.process_primitive(self, args, params) + # TODO(mattjj,dougalm): remove this block? + try: in_type = map(typeof, args) + except: pass # try lojax error message + else: + if self.is_high(*in_type, **params) and trace.requires_low: + with set_current_trace(trace): + return self.to_lojax(*args, **params) # type: ignore + return trace.process_primitive(self, args, params) + trace.process_primitive(self, args, params) # may raise lojax error + raise Exception(f"couldn't apply typeof to args: {args}") + def def_impl(self, impl): self.impl = impl @@ -536,6 +674,10 @@ def def_effectful_abstract_eval(self, effectful_abstract_eval): self.abstract_eval = effectful_abstract_eval return effectful_abstract_eval + def def_effectful_abstract_eval2(self, abstract_eval): + self.abstract_eval = _generic_effectful_abstract_eval(abstract_eval, self) + return abstract_eval + def def_bind_with_trace(self, bind_with_trace): self.bind_with_trace = bind_with_trace return bind_with_trace @@ -551,12 +693,27 @@ def abstract_eval(self, *args, **params): def get_bind_params(self, params): return [], params + def is_high(self, *avals, **params) -> bool: + return False + def _effect_free_abstract_eval(abstract_eval): def abstract_eval_(*args, **kwargs): return abstract_eval(*args, **kwargs), no_effects return abstract_eval_ +@dataclass(frozen=True) +class GenericEffect(Effect): + prim: Primitive +effects.lowerable_effects.add_type(GenericEffect) +effects.control_flow_allowed_effects.add_type(GenericEffect) +effects.custom_derivatives_allowed_effects.add_type(GenericEffect) + +def _generic_effectful_abstract_eval(abstract_eval, prim): + def abstract_eval_(*args, **kwargs): + return abstract_eval(*args, **kwargs), {GenericEffect(prim)} + return abstract_eval_ + # -------------------- lifting -------------------- # TODO(mattjj): replace this approach with a primitive-keyed table of rules @@ -573,8 +730,8 @@ def read(v: Atom) -> Any: return v.val if isinstance(v, Literal) else env[v] def write(v: Var, val: Any) -> None: - if config.enable_checks.value and not config.dynamic_shapes.value: - assert typecheck(v.aval, val), (v.aval, val) + if config.enable_checks.value: + assert typecheck(v.aval, val), (v.aval, get_aval(val), val) env[v] = val env: dict[Var, Any] = {} @@ -603,7 +760,7 @@ def check_avals_context_mesh(avals, prim_name): continue # avals can have meshes with different axis_names so allow that in # full auto mode. - if a.sharding.mesh._are_all_axes_auto and cur_mesh._are_all_axes_auto: + if a.sharding.mesh.are_all_axes_auto and cur_mesh.are_all_axes_auto: continue if a.sharding.mesh != cur_mesh: raise ValueError( @@ -617,12 +774,13 @@ def check_avals_context_mesh(avals, prim_name): TracerType = TypeVar('TracerType', bound='Tracer') class Trace(Generic[TracerType]): - __slots__ = ("__weakref__", "_invalidated", "_weakref") + __slots__ = ("__weakref__", "_invalidated", "_weakref", "requires_low") def __init__(self): self._invalidated = False # We frequently need a weakref to a trace, so let's precompute one. self._weakref = weakref.ref(self) + self.requires_low = True def process_primitive(self, primitive, tracers, params): raise NotImplementedError("must override") @@ -634,7 +792,7 @@ def is_valid(self): return not self._invalidated def __repr__(self): - return '{}'.format(self.__class__.__name__) + return f'{self.__class__.__name__}' def process_call(self, call_primitive, f, tracers, params): msg = (f"{type(self)} must override process_call to handle call-like " @@ -736,9 +894,21 @@ def _aval_property(name): return property(lambda self: getattr(self.aval, name)) -class Tracer(typing.Array, metaclass=StrictABCMeta): +if TYPE_CHECKING: + # We want Python type checkers to accept `some_tracer: jax.Array`, even though + # tracers can represent non-arrays. That is, ideally we would only accept that + # annotation when the Tracer instance has a ShapedArray aval, but we can't + # decide that at Python type checking time. So instead we're overly permissive + # and allow all Tracer instances to typecheck against a jax.Array annotation. + TracerBase = Array + TracerMeta = StrictABCMeta +else: + TracerBase = object + TracerMeta = type + +class Tracer(TracerBase, metaclass=TracerMeta): __array_priority__ = 1000 - __slots__ = ['_trace', '_line_info'] + __slots__ = ['__weakref__', '_trace', '_line_info'] __hash__ = None # type: ignore _trace: Trace @@ -760,6 +930,10 @@ def _error_repr(self): def __array__(self, *args, **kw): raise TracerArrayConversionError(self) + # helper for isinstance(tracer, jax.Array), here to avoid circular imports + def _is_traced_array(self): + return isinstance(self.aval, ShapedArray) + def __dlpack__(self, *args, **kw): raise ConcretizationTypeError(self, f"The __dlpack__() method was called on {self._error_repr()}." @@ -799,7 +973,6 @@ def sharding(self): # Raising a ConcretizationTypeError would make sense, but for backward compatibility # we raise an AttributeError so that hasattr() and getattr() work as expected. raise AttributeError( - self, f"The 'sharding' attribute is not available on {self._error_repr()}." f"{self._origin_msg()}") @@ -815,7 +988,7 @@ def device(self): # This attribute is part of the jax.Array API, but only defined on concrete arrays. # Raising a ConcretizationTypeError would make sense, but for backward compatibility # we raise an AttributeError so that hasattr() and getattr() work as expected. - raise AttributeError(self, + raise AttributeError( f"The 'device' attribute is not available on {self._error_repr()}." f"{self._origin_msg()}") @@ -887,9 +1060,8 @@ def __getattr__(self, name): if name == 'sharding': raise AttributeError( - self, - f"The 'sharding' attribute is not available on {self._error_repr()}." - f"{self._origin_msg()}") + f"The 'sharding' attribute is not available on {self._error_repr()}. " + "To query sharding information on tracers, use `jax.typeof(x)`.") try: attr = getattr(self.aval, name) @@ -906,7 +1078,13 @@ def __getattr__(self, name): else: return attr - def _pretty_print(self): + def _short_repr(self) -> str: + return f'{self.__class__.__name__}<{self.aval}>' + + def _pretty_print(self, verbose: bool = False) -> pp.Doc: + if not verbose: + return pp.text(self._short_repr()) + base = pp.text(f'Traced<{self.aval}>with<{self._trace}>') contents = [(name, attr._pretty_print() if isinstance(attr, Tracer) else pp.text(repr(attr))) for name, attr in self._contents()] @@ -919,7 +1097,7 @@ def _pretty_print(self): return base def __repr__(self): - return self._pretty_print().format() + return self._pretty_print(verbose=False).format() def _contents(self): try: @@ -939,14 +1117,14 @@ def addressable_data(self, index): @property def block_until_ready(self): # Raise AttributeError for backward compatibility with hasattr() and getattr() checks. - raise AttributeError(self, + raise AttributeError( f"The 'block_until_ready' method is not available on {self._error_repr()}." f"{self._origin_msg()}") @property def copy_to_host_async(self): # Raise AttributeError for backward compatibility with hasattr() and getattr() checks. - raise AttributeError(self, + raise AttributeError( f"The 'copy_to_host_async' method is not available on {self._error_repr()}." f"{self._origin_msg()}") @@ -999,6 +1177,8 @@ def unsafe_buffer_pointer(self): f"The unsafe_buffer_pointer() method was called on {self._error_repr()}." f"{self._origin_msg()}") +_jax.set_tracer_class(Tracer) + # these can be used to set up forwarding of properties and instance methods from # Tracer instances to the underlying avals aval_property = namedtuple("aval_property", ["fget"]) @@ -1021,10 +1201,6 @@ def process_primitive(self, primitive, args, params): else: # TODO(dougalm): delete. this shouldn't be necessary args = map(full_lower, args) - if config.data_dependent_tracing_fallback.value: - for arg in args: - if isinstance(arg, Tracer): - return primitive.bind_with_trace(arg._trace, args, params) check_eval_args(args) return primitive.impl(*args, **params) @@ -1049,6 +1225,8 @@ def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, **_): # py del primitive, fwd, bwd, _ # Unused. return fun.call_wrapped(*tracers) + def cur_qdd(self, x): + return x.cur_qdd() class TraceTag: # TODO: this works for surprisingly subtle reasons. Function transformations @@ -1074,6 +1252,7 @@ def __eq__(self, other): class AxisEnv: axis_sizes : dict[AxisName, int] spmd_axis_names : set[AxisName] + explicit_mesh_axis_names: frozenset[AxisName] def axis_size(self, axis_name): if axis_name not in self.axis_sizes: @@ -1090,24 +1269,31 @@ def axis_names(self): def pop_pure(self, axis_name): new_sizes = self.axis_sizes.copy() new_sizes.pop(axis_name) - return AxisEnv(new_sizes, self.spmd_axis_names) + return AxisEnv(new_sizes, self.spmd_axis_names, + self.explicit_mesh_axis_names) def extend_pure(self, name_size_pairs): new_sizes = self.axis_sizes.copy() new_sizes.update((name, size) for name, size in name_size_pairs if name is not no_axis_name) - return AxisEnv(new_sizes, self.spmd_axis_names) + return AxisEnv(new_sizes, self.spmd_axis_names, + self.explicit_mesh_axis_names) def add_spmd_axis_names(self, axis_names): new_spmd_axis_names = self.spmd_axis_names | set(axis_names) - return AxisEnv(self.axis_sizes, new_spmd_axis_names) + return AxisEnv(self.axis_sizes, new_spmd_axis_names, + self.explicit_mesh_axis_names) + + def add_explicit_mesh_axis_names(self, axis_names): + new_ema = self.explicit_mesh_axis_names | frozenset(axis_names) + return AxisEnv(self.axis_sizes, self.spmd_axis_names, new_ema) def as_hashable_key(self): return tuple((name, size) for (name, size) in self.axis_sizes.items() if name is not no_axis_name) eval_trace = EvalTrace() -top_axis_env = AxisEnv({}, set()) +top_axis_env = AxisEnv({}, set(), frozenset()) class TracingContext(threading.local): trace: Trace | None @@ -1213,6 +1399,24 @@ def __exit__(self, exc_type, exc_value, traceback): add_spmd_axis_names = AddSpmdAxisNamesContextManager +class AddExplicitMeshAxisNamesContextManager: + __slots__ = ['prev', 'axis_names'] + + def __init__(self, axis_names: AxisName | None): + self.axis_names = axis_names + + def __enter__(self): + self.prev = trace_ctx.axis_env + if self.axis_names is not None: + trace_ctx.set_axis_env(self.prev.add_explicit_mesh_axis_names( + self.axis_names)) + + def __exit__(self, exc_type, exc_value, traceback): + trace_ctx.set_axis_env(self.prev) + +add_explicit_mesh_axis_names = AddExplicitMeshAxisNamesContextManager + + def get_axis_env(): return trace_ctx.axis_env @@ -1439,10 +1643,15 @@ def definitely_equal(x, y): class AbstractValue: __slots__: list[str] = [] + is_high = False + has_qdd = False def to_tangent_aval(self): raise NotImplementedError("must override") + def to_cotangent_aval(self): + raise NotImplementedError("must override") + # TODO(dougalm): deprecate this alias def at_least_vspace(self): return self.to_tangent_aval() @@ -1457,6 +1666,9 @@ def __repr__(self): def update_weak_type(self, weak_type): return self + def update_vma(self, vma): + return self + def strip_weak_type(self) -> AbstractValue: return self.update_weak_type(False) @@ -1466,46 +1678,17 @@ def normalize(self) -> AbstractValue: def update(self, **kwargs): raise NotImplementedError("must override") - def str_short(self, short_dtypes=False): - return str(self) + def lo_ty(self): + return [self] -# For type signatures involving dynamic shapes, we use lists of abstract values -# which may contain (reverse) de Bruijn indices in their shapes. -class DBIdx(NamedTuple): - val: int + def lo_ty_qdd(self, qdd): + raise NotImplementedError("avals with qdd must override") -@dataclass(frozen=True) -class InDBIdx: - val: int + def str_short(self, short_dtypes=False, mesh_axis_types=False): + return str(self) -@dataclass(frozen=True) -class OutDBIdx: - val: int - -# For annotating input types of callables (i.e. linear_util.WrappedFuns), we use -# a sequence of pairs where the first element of each pair is an AbstractValue -# (possibly containing DBIdx instances in its shape) and the second is a boolean -# indicating whether that argument is explicit (i.e. passed to the callable). -InputType = tuple[tuple[AbstractValue, bool], ...] # DBIdx in shapes - -# For annotating jaxpr output types, we use a sequence of pairs where the first -# element of each pair is an AbstractValue (possibly containing InDBIdx and/or -# OutDBIdx instances in its shape) and the second is a boolean indicating -# whether that argument is explicit (i.e. returned by the callable). -OutputType = tuple[tuple[AbstractValue, bool], ...] # InDBIdx / OutDBIdx shapes - - -def _jaxpr_type_to_callable_annotation(jaxpr: Jaxpr) -> InputType: - idxs = {v: DBIdx(i) for i, v in enumerate((*jaxpr.constvars, *jaxpr.invars))} - out = [(v.aval.update(shape=tuple(idxs.get(d, d) for d in v.aval.shape)) # type: ignore - if type(v.aval) is DShapedArray else v.aval, True) - for v in jaxpr.invars] - return tuple(out) - -# TODO(dougalm): Deprecate. This is here for backwards compat. -def lattice_join(x, y): - assert typematch(x, y) - return x +InputType = tuple[AbstractValue] +OutputType = tuple[AbstractValue] # For use in typing annotations to denote either a Tracer or a `valid_jaxtype`. Value = Any @@ -1526,12 +1709,30 @@ def check_valid_jaxtype(x): raise TypeError( f"Value {x!r} of type {type(x)} is not a valid JAX type") -def update_aval_with_sharding(aval, sharding): + +def mem_kind_to_space(mem_kind: str) -> MemorySpace: + if mem_kind == 'pinned_host': + return MemorySpace.Host + return MemorySpace.Device + +def mem_space_to_kind(mem_space: MemorySpace) -> str: + if mem_space == MemorySpace.Device: + return 'device' + elif mem_space == MemorySpace.Host: + return 'pinned_host' + else: + assert False, "unreachable" + + +@cache(max_size=4096, + trace_context_in_key=lambda: config.remove_size_one_mesh_axis_from_type.value) +def update_aval_with_sharding(aval, sharding, vma=None): if isinstance(sharding, NamedSharding): - aval = aval.update(sharding=NamedSharding( - sharding.mesh.abstract_mesh, - sharding.spec._normalized_spec_for_aval(aval.ndim))) - return aval + s = NamedSharding(sharding.mesh.abstract_mesh, + sharding.spec._normalized_spec_for_aval(aval.ndim)) + return aval.update(sharding=s, vma=aval.vma if vma is None else vma, + memory_space=mem_kind_to_space(sharding.memory_kind)) + return aval if vma is None else aval.update(vma=vma) # We have three flavors of abstractification APIs here which each used to have @@ -1554,10 +1755,17 @@ def shaped_abstractify(x): if isinstance(x, AbstractValue): return x if hasattr(x, '__jax_array__'): - return shaped_abstractify(x.__jax_array__()) + raise ValueError( + 'Triggering __jax_array__() during abstractification is no longer' + ' supported. To avoid this error, either explicitly convert your object' + ' using jax.numpy.array(), or register your object as a pytree.' + ) if hasattr(x, 'dtype'): - aval = ShapedArray(np.shape(x), x.dtype, - weak_type=getattr(x, 'weak_type', False)) + aval = ShapedArray( + np.shape(x), + dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True), + weak_type=getattr(x, "weak_type", False), + ) return update_aval_with_sharding(aval, getattr(x, 'sharding', None)) raise TypeError( f"Cannot interpret value of type {typ} as an abstract array; it " @@ -1570,7 +1778,8 @@ def abstractify(x): return get_aval(x) -def get_aval(x): +# TODO(phawkins): the return type should be AbstractValue. +def get_aval(x: Any) -> Any: typ = type(x) if (aval_fn := pytype_aval_mappings.get(typ)): # fast path return aval_fn(x) @@ -1578,10 +1787,21 @@ def get_aval(x): if (aval_fn := pytype_aval_mappings.get(t)): return aval_fn(x) if hasattr(x, '__jax_array__'): - return get_aval(x.__jax_array__()) + raise ValueError( + 'Triggering __jax_array__() during abstractification is no longer' + ' supported. To avoid this error, either explicitly convert your object' + ' using jax.numpy.array(), or register your object as a pytree.' + ) raise TypeError(f"Argument '{x}' of type '{typ}' is not a valid JAX type") -typeof = get_aval + +# TODO(phawkins): the return type should be AbstractValue. +def typeof(x: Any, /) -> Any: + """Return the JAX type (i.e. :class:`AbstractValue`) of the input. + + Raises a ``TypeError`` if ``x`` is not a valid JAX type. + """ + return get_aval(x) def is_concrete(x): return to_concrete_value(x) is not None @@ -1630,6 +1850,63 @@ def concrete_dim_or_error(val: Any, context=""): else: return concrete_or_error(operator.index, val, context=context) +### Quasi-dynamic data + +# Quasi-dynamic data includes things like liveness bits and the content type of +# a type-changeable box. These change throughout the program but at a given +# point in the program they have a single statically known value. + +class MutableQuasiDynamicData: + def __init__(self, val : QuasiDynamicData | None): + self.init_val = val + self.cur_val = val # immutable payload + + def update(self, val): + self.cur_val = val + + def __repr__(self): + return f'MutableQuasiDynamicData(init_val={self.init_val}, cur_val={self.cur_val})' + +class QuasiDynamicData: + pass + +@dataclass(frozen=True) +class AvalQDD: + is_high = True + aval: AbstractValue + qdd: QuasiDynamicData | None # immutable + + has_qdd = True + def lo_ty(self): + return self.aval.lo_ty_qdd(self.qdd) # type: ignore + + def read_loval(self, val): + return self.aval.read_loval(self.qdd, val) # type: ignore + + def new_from_loval(self, *lovals): + return self.aval.new_from_loval(self.qdd, *lovals) # type: ignore + + def to_tangent_aval(self): + return AvalQDD(self.aval.to_tangent_aval(), self.qdd.to_tangent_qdd()) + +@dataclass(frozen=True) +class AvalMutableQDD: + aval: AbstractValue + mutable_qdd: MutableQuasiDynamicData + +def cur_qdd(x): + prev_trace = trace_ctx.trace + trace_ctx.set_trace(eval_trace) + try: + return prev_trace.cur_qdd(x) + finally: + trace_ctx.set_trace(prev_trace) + +def cur_aval_qdd(x): + aval = typeof(x) + qdd = cur_qdd(x) if aval.has_qdd else None + return AvalQDD(aval, qdd) + ### Extended dtypes # # Extended dtypes are JAX-specific dtypes that allow us to represent logical @@ -1648,20 +1925,17 @@ def concrete_dim_or_error(val: Any, context=""): @overload def physical_aval(aval: ShapedArray) -> ShapedArray: ... -@overload -def physical_aval(aval: DShapedArray) -> DShapedArray: ... @overload # TODO(frostig): remove this case def physical_aval(aval: AbstractValue) -> AbstractValue: ... def physical_aval(aval): - if (isinstance(aval, (ShapedArray, DShapedArray)) and + if (isinstance(aval, ShapedArray) and isinstance(aval.dtype, dtypes.ExtendedDType)): elt_aval = physical_element_aval(aval.dtype) - if isinstance(aval, ShapedArray): - from jax._src.sharding_impls import physical_sharding # type: ignore - return ShapedArray((*aval.shape, *elt_aval.shape), elt_aval.dtype, - sharding=physical_sharding(aval, aval.sharding)) - return DShapedArray((*aval.shape, *elt_aval.shape), elt_aval.dtype) + from jax._src.sharding_impls import physical_sharding # type: ignore + return ShapedArray((*aval.shape, *elt_aval.shape), elt_aval.dtype, + sharding=physical_sharding(aval, aval.sharding), + vma=aval.vma) return aval def physical_shape(logical_shape, dtype): @@ -1676,47 +1950,6 @@ def physical_element_aval(edtype: dtypes.ExtendedDType) -> ShapedArray: def _dtype_object(dtype): return dtype if isinstance(dtype, dtypes.ExtendedDType) else np.dtype(dtype) -class UnshapedArray(AbstractValue): - __slots__ = ['dtype', 'weak_type'] - array_abstraction_level = 4 - - def __init__(self, dtype, weak_type=False): - # Is it silly to initialize this object and then complain that we should - # never create one? Yes. But otherwise pytype complains. - self.dtype = _dtype_object(dtype) - self.weak_type = weak_type - raise Exception("We should never create an UnshapedArray object") - - def __eq__(self, other): - return (type(self) is type(other) and self.dtype == other.dtype and - self.weak_type == other.weak_type) - - def __ne__(self, other): - return not self == other - - def __hash__(self): - # can use hash(self.dtype) and rely on the fact that numpy reuses base dtype - # objects, e.g. `np.zeros(3).dtype is np.zeros(4).dtype`, or we can use - # the unique character code via hash(self.dtype.char) - return hash((self.dtype, self.weak_type)) - - def __repr__(self): - return '{}({}{})'.format(self.__class__.__name__, self.str_short(), - ", weak_type=True" if self.weak_type else "") - - _bool = concretization_function_error(bool) - _int = concretization_function_error(int, True) - _float = concretization_function_error(float, True) - _complex = concretization_function_error(complex, True) - _hex = concretization_function_error(hex) - _oct = concretization_function_error(oct) - _index = concretization_function_error(operator.index) - - def str_short(self, short_dtypes=False) -> str: - return dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name - - def update_weak_type(self, weak_type): - return self.update(weak_type=weak_type) def _canonicalize_dimension(dim: DimSize) -> DimSize: # Dimensions are most commonly integral (by far), so we check that first. @@ -1724,15 +1957,7 @@ def _canonicalize_dimension(dim: DimSize) -> DimSize: return operator.index(dim) except TypeError as e: type_error = e - if isinstance(dim, Tracer) and config.dynamic_shapes.value: - if not (dim.ndim == 0 and (dtypes.issubdtype(dim.dtype, np.integer) - or isinstance(dim.dtype, bint))): - raise TypeError(f"Dimensions must be integer scalars; got {dim.ndim=} {dim.dtype=}") - return dim - elif (config.dynamic_shapes.value and isinstance(dim, DArray) and - type(dim._aval.dtype) is bint and not dim._aval.shape): - return dim - elif is_dim(dim): + if is_dim(dim): return dim else: raise type_error @@ -1766,16 +1991,11 @@ def canonicalize_dim(d: DimSize, context: str="") -> DimSize: return canonicalize_shape((d,), context)[0] def _invalid_shape_error(shape: Shape, context: str=""): - if config.dynamic_shapes.value: - msg = ("Shapes must be 1D sequences of integer scalars, " - f"got {shape}") - else: - msg = ("Shapes must be 1D sequences of concrete values of integer type, " - f"got {shape}.") + msg = ("Shapes must be 1D sequences of concrete values of integer type, " + f"got {shape}.") if context: msg += f" {context}." - if not config.dynamic_shapes.value and any( - isinstance(x, Tracer) and isinstance(get_aval(x), ShapedArray) + if any(isinstance(x, Tracer) and isinstance(get_aval(x), ShapedArray) and not is_concrete(x) for x in shape): msg += ("\nIf using `jit`, try using `static_argnums` or applying `jit` to " "smaller subfunctions.") @@ -1786,6 +2006,10 @@ def _invalid_shape_error(shape: Shape, context: str=""): return TypeError(msg) +class ShardingTypeError(Exception): + pass + + # TODO(dougalm): Cast scalar, numpy arrays, etc to jax arrays so that values # passed to primitives are always have avals, etc i.e. they are canonical. def canonicalize_value(val): @@ -1801,14 +2025,20 @@ def canonicalize_value(val): cur_mesh = mesh_lib.get_abstract_mesh() if cur_mesh == aval.sharding.mesh: return val + # TODO(yashkatariya): Casting to Explicit is not yet allowed. Maybe we need + # cast_and_slice_p for it since shape might change? # Atleast 1 mesh axis should be Manual and all other axes should be # Manual or Auto to allow casting. - # TODO(yashkatariy): Casting to Explicit is not yet allowed. Maybe we need - # cast_and_slice_p for it since shape might change? - if (cur_mesh._any_axis_manual and cur_mesh._are_all_axes_auto_or_manual and - aval.sharding.mesh._are_all_axes_auto): - from jax._src.pjit import mesh_cast # pytype: disable=import-error - return mesh_cast(val, NamedSharding(cur_mesh, P(*[None] * aval.ndim))) + if cur_mesh._any_axis_manual and cur_mesh._are_all_axes_auto_or_manual: + if aval.sharding.mesh.are_all_axes_auto: + from jax._src.pjit import reshard # pytype: disable=import-error + return reshard(val, NamedSharding(cur_mesh, P(*[None] * aval.ndim))) + elif aval.sharding.mesh._any_axis_explicit: + raise NotImplementedError( + "Closing over inputs to shard_map where the input is sharded on" + " `Explicit` axes is not implemented. As a workaround, please pass" + " those inputs as an argument to shard_map. Got input with shape" + f" {aval.str_short(True, True)}") return val @@ -1817,43 +2047,55 @@ def get_cur_mesh_sharding(spec=None): return NamedSharding(mesh_lib.get_abstract_mesh(), spec) def _make_lengths_same(sharding, ndim): - if ndim > len(sharding.spec): - return sharding.with_spec(sharding.spec._normalized_spec_for_aval(ndim)) - if ndim < len(sharding.spec): - assert all(s is None for s in sharding.spec[ndim:]) - return sharding.with_spec(sharding.spec[:ndim]) + pspec = sharding.spec + if ndim > len(pspec): + return sharding.update(spec=pspec._normalized_spec_for_aval(ndim)) + if ndim < len(pspec): + assert all(s is None for s in pspec[ndim:]), (ndim, pspec) + return sharding.update(spec=P(*pspec[:ndim], unreduced=pspec.unreduced, + reduced=pspec.reduced)) assert False, "unreachable" -# TODO(yashkatariya): Only works with User/Auto. Generalize it to work with -# Collective too. def modify_spec_for_auto_manual(spec, mesh) -> P: - new_spec = [] + new_spec = [] # type: ignore + # PartitionSpec can only mention mesh axes that are Explicit. for s in spec: - if not s: - new_spec.append(s) + if s is None: + new_spec.append(s) # type: ignore + elif isinstance(s, tuple): + new_spec.append(tuple( + p for p in s if mesh._name_to_type[p] == AxisType.Explicit)) else: - temp_s = s[0] if isinstance(s, tuple) else s - new_spec.append( - None - if mesh._name_to_type[temp_s] in (AxisType.Auto, AxisType.Manual) - else s) - return P(*new_spec) + new_spec.append(s if mesh._name_to_type[s] == AxisType.Explicit else None) # type: ignore + # Unreduced and reduced can mention mesh axes that are Explicit and Manual. + new_unreduced = {u for u in spec.unreduced + if mesh._name_to_type[u] != AxisType.Auto} + new_reduced = {u for u in spec.reduced + if mesh._name_to_type[u] != AxisType.Auto} + return P(*new_spec, unreduced=new_unreduced, reduced=new_reduced) + +def remove_size_one_mesh_axis(spec, mesh) -> P: + new_spec = [] # type: ignore + for s in spec: + if s is None: + new_spec.append(s) # type: ignore + elif isinstance(s, tuple): + new_spec.append(tuple(i for i in s if mesh.shape[i] != 1)) + else: + new_spec.append(None if mesh.shape[s] == 1 else s) # type: ignore + return P(*new_spec, unreduced=spec.unreduced, reduced=spec.reduced) def _maybe_modify_sharding(sharding, ndim): if len(sharding.spec) == 0 or all(s is None for s in sharding.spec): - if len(sharding.spec) != ndim: - return _make_lengths_same(sharding, ndim) - return sharding - - if sharding.mesh._are_all_axes_explicit: - if ndim > len(sharding.spec): - return sharding.with_spec(sharding.spec._normalized_spec_for_aval(ndim)) - return sharding - - out = sharding.with_spec(modify_spec_for_auto_manual( - sharding.spec, sharding.mesh)) - if (len(out.spec) != ndim and - (out.mesh.empty or out.mesh._are_all_axes_auto_or_manual)): + out = sharding + elif sharding.mesh.are_all_axes_explicit: + out = sharding + else: + out = sharding.update(spec=modify_spec_for_auto_manual( + sharding.spec, sharding.mesh)) + if config.remove_size_one_mesh_axis_from_type.value: + out = out.update(spec=remove_size_one_mesh_axis(out.spec, out.mesh)) + if len(out.spec) != ndim: out = _make_lengths_same(out, ndim) return out @@ -1871,7 +2113,8 @@ def _check_divisibility(sharding, shape): f" {size} times, but does not evenly divide the dimension size {sh}." f" Got shape: {shape} and sharding {sharding}") -@cache(max_size=4096, trace_context_in_key=False) +@cache(max_size=4096, + trace_context_in_key=lambda: config.remove_size_one_mesh_axis_from_type.value) def get_sharding(sharding, shape): """Modifies and checks the sharding. @@ -1894,28 +2137,69 @@ def get_sharding(sharding, shape): raise ValueError("Mesh of an aval must be an AbstractMesh. " f"Got {out_s.mesh} of type {type(out_s.mesh)}") _check_divisibility(out_s, shape) + assert out_s.memory_kind is None return out_s -def str_short_aval(shape, dtype, mesh, spec, short_dtypes=False, - mesh_axis_types=False) -> str: - dt_str = dtypes.short_dtype_name(dtype) if short_dtypes else dtype.name - dt_str = dt_str.replace('void', 'float0') - shapestr = _get_shape_sharding_str(shape, spec) - mesh_axes = f'({mesh._axis_types_dict})' if mesh_axis_types else '' - return f'{dt_str}[{shapestr}]{mesh_axes}' -class ShapedArray(UnshapedArray): - __slots__ = ['shape', 'sharding', 'varying_manual_axes'] # inherits slots from parent +@cache(max_size=4096, + trace_context_in_key=lambda: config.remove_size_one_mesh_axis_from_type.value) +def get_vma(vma, sharding): + mesh = sharding.mesh + spec = sharding.spec + if mesh.empty: + assert not vma, vma + return vma + + axis_env = get_axis_env() + for i in vma: + if axis_env.axis_exists(i) and i not in mesh._name_to_type: + continue + if mesh._name_to_type[i] != AxisType.Manual: + raise ValueError( + "Axes mentioned in `vma` field of ShapedArray should" + f" be of type `Manual`. Got axis: {i} of type {mesh._name_to_type[i]}") + if config.remove_size_one_mesh_axis_from_type.value: + vma = frozenset(i for i in vma if mesh.shape[i] != 1) + + if vma & spec.unreduced: + raise ValueError( + f"vma and unreduced cannot have common mesh axes. Got {vma=} and" + f" unreduced={spec.unreduced}") + if vma & spec.reduced: + raise ValueError( + f"vma and reduced cannot have common mesh axes. Got {vma=} and" + f" reduced={spec.reduced}") + assert isinstance(vma, frozenset) + return vma + + +def get_memory_space(memory_space): + assert isinstance(memory_space, MemorySpace) + return memory_space + + +class ShapedArray(AbstractValue): + # inherits slots from parent + __slots__ = ['shape', 'dtype', 'weak_type', 'sharding', 'vma', 'memory_space'] array_abstraction_level = 2 def __init__(self, shape, dtype, weak_type=False, *, sharding=None, - varying_manual_axes: frozenset[AxisName] = frozenset()): + vma: frozenset[AxisName] = frozenset(), + memory_space: MemorySpace = MemorySpace.Device): self.shape = canonicalize_shape(shape) self.dtype = _dtype_object(dtype) self.weak_type = weak_type + # The ShapedArray.sharding.memory_kind is always None; use memory_space. self.sharding = get_sharding(sharding, self.shape) - if config.varying_axes_in_types.value: - self.varying_manual_axes = varying_manual_axes + # short for varying_manual_axes. See docs at + # https://docs.jax.dev/en/latest/notebooks/shard_map.html#tracking-how-values-vary-over-manual-mesh-axes-and-check-vma-true + self.vma = get_vma(vma, self.sharding) + # See description of https://github.com/jax-ml/jax/pull/30556 + self.memory_space = get_memory_space(memory_space) + + def lower_val(self, val): return [val] + def raise_val(self, val): return val + def lo_ty(self): return [self] def update(self, shape=None, dtype=None, weak_type=None, **kwargs): if shape is None: @@ -1926,9 +2210,10 @@ def update(self, shape=None, dtype=None, weak_type=None, **kwargs): weak_type = self.weak_type if 'sharding' not in kwargs: kwargs['sharding'] = self.sharding - if 'varying_manual_axes' not in kwargs: - kwargs['varying_manual_axes'] = getattr(self, 'varying_manual_axes', - frozenset()) + if 'vma' not in kwargs: + kwargs['vma'] = self.vma + if 'memory_space' not in kwargs: + kwargs['memory_space'] = self.memory_space return ShapedArray(shape, dtype, weak_type, **kwargs) ndim = property(lambda self: len(self.shape)) @@ -1946,26 +2231,44 @@ def __eq__(self, other): and self.dtype == other.dtype and self.shape == other.shape and self.weak_type == other.weak_type and self.sharding == other.sharding - and (getattr(self, 'varying_manual_axes', frozenset()) == - getattr(other, 'varying_manual_axes', frozenset()))) + and self.vma == other.vma + and self.memory_space == other.memory_space) def __hash__(self): # can use hash(self.dtype) and rely on the fact that numpy reuses base dtype # objects, e.g. `np.zeros(3).dtype is np.zeros(4).dtype`, or we can use # the unique character code via hash(self.dtype.char) return hash((self.shape, self.dtype, self.weak_type, self.sharding, - getattr(self, 'varying_manual_axes', frozenset()))) + self.vma, self.memory_space)) + + def __ne__(self, other): + return not self == other + + def __repr__(self): + wt_str = ", weak_type=True" if self.weak_type else "" + return f'ShapedArray({self.str_short()}{wt_str})' + + def __str__(self): + wt_str = "~" if self.weak_type else "" + return f'{wt_str}{self.str_short()}' def to_tangent_aval(self): return ShapedArray( self.shape, primal_dtype_to_tangent_dtype(self.dtype), - self.weak_type, sharding=self.sharding, - varying_manual_axes=getattr(self, 'varying_manual_axes', frozenset())) + self.weak_type, sharding=self.sharding, vma=self.vma, + memory_space=self.memory_space) + + def to_cotangent_aval(self): + dtype = primal_dtype_to_tangent_dtype(self.dtype) + sharding = primal_sharding_to_cotangent_sharding(self.sharding) + return ShapedArray( + self.shape, dtype, self.weak_type, sharding=sharding, vma=self.vma, + memory_space=self.memory_space) def str_short(self, short_dtypes=False, mesh_axis_types=False): return str_short_aval( self.shape, self.dtype, self.sharding.mesh, self.sharding.spec, - short_dtypes, mesh_axis_types) + self.vma, self.memory_space, short_dtypes, mesh_axis_types) def _len(self, ignored_tracer): try: @@ -1973,6 +2276,20 @@ def _len(self, ignored_tracer): except IndexError as err: raise TypeError("len() of unsized object") from err # same as numpy error + def update_vma(self, vma): + return self.update(vma=vma) + + def update_weak_type(self, weak_type): + return self.update(weak_type=weak_type) + + _bool = concretization_function_error(bool) + _int = concretization_function_error(int, True) + _float = concretization_function_error(float, True) + _complex = concretization_function_error(complex, True) + _hex = concretization_function_error(hex) + _oct = concretization_function_error(oct) + _index = concretization_function_error(operator.index) + def _get_shape_sharding_str(shape, spec): out = [] @@ -1986,6 +2303,42 @@ def _get_shape_sharding_str(shape, spec): out.append(f"{s1}@{s2}") return ','.join(out) +@cache(max_size=1024, trace_context_in_key=False) +def _axis_types_dict(mesh): + if not mesh.axis_names: + return {} + d = defaultdict(list) + for n, t in safe_zip(mesh.axis_names, mesh.axis_types): + d[t].append(n) + return {t: tuple(n) for t, n in d.items()} + +def str_short_aval(shape, dtype, mesh, spec, vma, memory_space, + short_dtypes=False, mesh_axis_types=False) -> str: + dt_str = dtypes.short_dtype_name(dtype) if short_dtypes else dtype.name + dt_str = dt_str.replace('void', 'float0') + shapestr = _get_shape_sharding_str(shape, spec) + mesh_axes = f'({_axis_types_dict(mesh)})' if mesh_axis_types else '' + vma_ur = _vma_ur_str(vma, spec.unreduced, spec.reduced, mesh) + ms_str = ("" if memory_space == MemorySpace.Device else + f"<{memory_space.name.lower()}>") + return f'{dt_str}{ms_str}[{shapestr}]{vma_ur}{mesh_axes}' + +def _create_str(x, prefix): + x_str = f"{','.join(i for i in x)}" + x_str = x_str if len(x) == 1 else f"({x_str})" + return f"{prefix}:{x_str}, " + +def order_wrt_mesh(mesh, x): + return tuple(a for a in mesh.axis_names if a in x) + +def _vma_ur_str(vma, unreduced, reduced, mesh): + if not vma and not unreduced and not reduced: + return '' + vma_str = _create_str(order_wrt_mesh(mesh, vma), 'V') if vma else '' + ur_str = _create_str(unreduced, 'U') if unreduced else '' + red_str = _create_str(reduced, 'R') if reduced else '' + m_str = f"{vma_str}{ur_str}{red_str}".rstrip(', ') + return f"{{{m_str}}}" def primal_dtype_to_tangent_dtype(primal_dtype): if isinstance(primal_dtype, dtypes.ExtendedDType): @@ -1995,125 +2348,97 @@ def primal_dtype_to_tangent_dtype(primal_dtype): else: return primal_dtype +def primal_spec_to_cotangent_spec(spec): + return P(*spec, unreduced=spec.reduced, reduced=spec.unreduced) -# Dynamic shape stuff below here! We keep the abstract values distinct just so -# as not to interfere with any static shape machinery. - -# We have a convention of reusing AbsractValues as types, even though we could -# make a distinction and use abstract values during tracing only. This reuse -# becomes a bit more extreme with DShapedArrays. A DShapedArray's shape -# attribute is a tuple which can contain several different types: int, DArray -# (scalar and with dtype of bint type), Tracer (while tracing), Var (when used -# as jaxpr type annotations), or DBIdx/InDBIdx/OutDBIdx (when used in InputType -# or OutputType). We could reduce this polymorphism if it seems cleaner, though -# it's kind of convenient! -class DShapedArray(UnshapedArray): - __slots__ = ['shape'] - shape: tuple[AxisSize, ...] # noqa: F821 - array_abstraction_level: int = 3 - - def __init__(self, shape, dtype, weak_type=False): - self.shape = shape - self.dtype = dtype - self.weak_type = weak_type - - ndim = property(lambda self: len(self.shape)) - size = property(lambda self: - 0 if any(type(d) is int and d == 0 for d in self.shape) - else math.prod(self.shape)) - - def str_short(self, short_dtypes=False) -> str: - del short_dtypes # ignored - shape = f'{",".join(str(d) for d in self.shape)}' if self.shape else '' - dtype = dtypes.short_dtype_name(self.dtype) - return f'{dtype}[{shape}]' - __str__ = __repr__ = str_short - - def update(self, shape=None, dtype=None, weak_type=None): - if shape is None: - shape = self.shape - if dtype is None: - dtype = self.dtype - if weak_type is None: - weak_type = self.weak_type - return DShapedArray(shape, dtype, weak_type) - - @property - def sharding(self): - return NamedSharding(mesh_lib.empty_abstract_mesh, P()) - - def _len(self, tracer): - return self.shape[0] - - def __eq__(self, other): - return (type(self) is type(other) - and self.dtype == other.dtype and self.shape == other.shape - and self.weak_type == other.weak_type) - - def __hash__(self): - return hash((self.shape, self.dtype, self.weak_type)) - - def to_tangent_aval(self): - return DShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype), - self.weak_type) - +def primal_sharding_to_cotangent_sharding(sharding): + return sharding.update(spec=primal_spec_to_cotangent_spec(sharding.spec)) -class DArray: - _aval: DShapedArray - _data: Any # standard array type - def __init__(self, aval, data): - pad_shape = tuple(d.dtype.bound if type(d) is DArray and - type(d.dtype) is bint else d for d in aval.shape) - assert data.shape == pad_shape - self._aval = aval - self._data = data +############################## pvary ################################# - shape = property(lambda self: self._aval.shape) - dtype = property(lambda self: self._aval.dtype) - aval = property(lambda self: self._aval) - def __repr__(self) -> str: - if not self.shape and type(self.dtype) is bint: - # special-case scalar bints - return f'{int(self._data)}{{≤{self.dtype.bound}}}' - - dtypestr = dtypes.short_dtype_name(self._aval.dtype) - shapestr = ','.join(map(str, self.shape)) - data = self.data - return f'{dtypestr}[{shapestr}] with value: {data}' +# Invariant -> Variant no-op cast +def pvary(x, axis_name): + axes = (axis_name,) if not isinstance(axis_name, tuple) else axis_name + if not axis_name: + return x + cur_mesh = mesh_lib.get_abstract_mesh() + new_axes = axes if cur_mesh.empty else order_wrt_mesh(cur_mesh, axes) + assert set(new_axes) == set(axes) + del axes + # TODO(yashkatariya): Remove this handling and remove_size_one_mesh_axis_from_type + # generally from JAX. + if config.remove_size_one_mesh_axis_from_type.value and not cur_mesh.empty: + new_axes = tuple(i for i in new_axes if cur_mesh.shape[i] != 1) + if not new_axes: + return x + return tree_map(lambda leaf: pvary_p.bind(leaf, axes=new_axes), x) - def __hash__(self) -> int: - if not self.shape: - return hash((self._aval, int(self._data))) - raise TypeError("unhashable type: DArray") +pvary_p = Primitive('pvary') - def __eq__(self, other): - if isinstance(other, DArray) and self._aval == other._aval: - return self._data == other._data - return False +####################### reduced_vary_cast ############################# - def __len__(self): - return self.shape[0] +# Reduced -> Varying no-op cast +def reduced_vary_cast(x, axis_name): + axes = (axis_name,) if not isinstance(axis_name, tuple) else axis_name + if not axis_name: + return x + return tree_map(lambda leaf: reduced_vary_cast_p.bind(leaf, axes=axes), x) - @property - def data(self): - if not self.shape and type(self.dtype) is bint: - # special-case scalar bints - return self._data - - slices = tuple( - slice(int(d._data)) - if type(d) is DArray and type(d.dtype) is bint - else slice(None) - for d in self.shape - ) - data = self._data[slices] - return data +reduced_vary_cast_p = Primitive('reduced_vary_cast_p') -def _darray_aval(x): - return DShapedArray(x._aval.shape, x._aval.dtype, x._aval.weak_type) +####################################################################### -pytype_aval_mappings[DArray] = _darray_aval +def check_unreduced_args(args, name): + for a in args: + if a.sharding.spec.unreduced: + raise ValueError( + f"{name} cannot accept args which are unreduced. Got" + f" {a.str_short(True)}") + if a.sharding.spec.reduced: + raise ValueError( + f"{name} cannot accept args which are reduced. Got" + f" {a.str_short(True)}") + +def standard_insert_pvary(*args): + if not config._check_vma.value: + return args + if not args: + return args + in_vma = [aval.vma if isinstance(aval := get_aval(a), ShapedArray) + else frozenset() for a in args] + in_reduced = [aval.sharding.spec.reduced + if isinstance(aval := get_aval(a), ShapedArray) else frozenset() + for a in args] + out_vma = frozenset.union(*in_vma) + out = [] + for arg, src_vma, src_reduced in zip(args, in_vma, in_reduced): + if (isinstance(get_aval(arg), ShapedArray) and + (rest_vma := out_vma - src_vma)): + # TODO(yashkatariya): Handle partial reduced_vary_cast and partial pvary. + # Will need more changes to pvary to allow such partialness. + if src_reduced == rest_vma: + out.append( + reduced_vary_cast(arg, tuple(n for n in out_vma if n in rest_vma))) + else: + out.append(pvary(arg, tuple(n for n in out_vma if n in rest_vma))) + else: + out.append(arg) + return out +def standard_vma_rule(prim_name, *avals, **kwargs) -> frozenset[AxisName]: + if not config._check_vma.value: + return frozenset() + avals = tuple(a for a in avals if a is not abstract_token) + if not avals: + return frozenset() + vma, *vmas = (a.vma for a in avals) + if not all(vma == vma_ for vma_ in vmas): + raise ValueError( + f'Primitive {prim_name} requires varying manual axes ' + f'to match, but got {[vma, *vmas]}. Please open an issue at ' + 'https://github.com/jax-ml/jax/issues and as a temporary ' + 'workaround pass the check_vma=False argument to `jax.shard_map`') + return vma @dataclass(frozen=True) class bint(dtypes.ExtendedDType): @@ -2130,48 +2455,149 @@ def name(self) -> str: def __str__(self) -> str: return self.name -AxisSize = Union[int, DArray, Tracer, Var, DBIdx, InDBIdx, OutDBIdx] +AxisSize = Union[int, Tracer, Var] -class MutableArray: - _aval: ShapedArray - _buf: Array - def __init__(self, aval, buf): +class RefMeta(type): + def __instancecheck__(self, inst): + from jax._src.state.types import AbstractRef # pytype: disable=import-error + return (super().__instancecheck__(inst) or + isinstance(inst, Tracer) and isinstance(inst.aval, AbstractRef)) + +class Ref(metaclass=RefMeta): + """Mutable array reference. + + In most cases this should not be constructed directly, but rather + via :func:`jax.ref.new_ref`. For examples of how this can be + used, refer to the `Ref guide`_. + + .. _Ref guide: https://docs.jax.dev/en/latest/array_refs.html + """ + _aval: AbstractValue + _refs: PyTree # list of ArrayRefImpl + + def __init__(self, aval, refs): + from jax._src.state.types import AbstractRef # pytype: disable=import-error + assert isinstance(aval, AbstractRef) self._aval = aval - self._buf = buf + self._refs = refs + + # TODO(mattjj): update repr to handle non-lojax refs + def __repr__(self) -> str: return 'Ref' + repr(self._refs._buf)[5:] + + # forward type-level info to aval aval = property(lambda self: self._aval) shape = property(lambda self: self._aval.shape) + ndim = property(lambda self: len(self._aval.shape)) dtype = property(lambda self: self._aval.dtype) - sharding = property(lambda self: self._buf.sharding) + + # get operations from aval, munging the name def __getitem__(self, idx): return self._aval._getitem(self, idx) def __setitem__(self, idx, x): return self._aval._setitem(self, idx, x) - def __repr__(self) -> str: return 'Mutable' + repr(self[...]) -pytype_aval_mappings[MutableArray] = lambda x: x._aval + def __len__(self) -> int: return self._aval._len(self) + def addupdate(self, x, idx=()): return self._aval._addupdate(self, idx, x) + + # some attributes/methods only work for lojax refs + sharding = property(lambda self: self._refs._buf.sharding) + format = property(lambda self: self._refs._buf.format) + committed = _committed = property(lambda self: True) + def unsafe_buffer_pointer(self): return self._refs._buf.unsafe_buffer_pointer() + + @property + def at(self): raise NotImplementedError() # TODO(mattjj) + +class ArrayRefImpl: + _aval: ShapedArray + _buf: Array # mutable field + + def __init__(self, aval, buf): + from jax._src.state.types import AbstractRef # pytype: disable=import-error + assert isinstance(aval, AbstractRef) and isinstance(aval.inner_aval, ShapedArray) + self._aval = aval + self._buf = buf + +pytype_aval_mappings[Ref] = lambda x: x._aval +dtypes.canonicalize_value_handlers[Ref] = lambda x: x + +def new_ref(init_val: Any, *, memory_space: Any = None, kind: Any = None): + """Create a mutable array reference with initial value ``init_val``. + + For more discussion, see the `Ref guide`_. + + Args: + init_val: A :class:`jax.Array` representing the initial state + of the buffer. + memory_space: An optional memory space attribute for the Ref. + + Returns: + A :class:`jax.ref.Ref` containing a reference to a mutable buffer. + + .. _Ref guide: https://docs.jax.dev/en/latest/array_refs.html + """ + return ref_p.bind(init_val, memory_space=memory_space, kind=kind) +ref_p = Primitive('new_ref') +ref_p.is_effectful = lambda params: True # type: ignore +ref_p.ref_primitive = True + +ref_p.is_high = lambda aval, *, memory_space, kind: aval.is_high # type: ignore +def _ref_to_lojax(init_val, *, memory_space, kind): + from jax._src.state.types import AbstractRef # pytype: disable=import-error + val_ty = typeof(init_val) + hival_of_refs = val_ty.raise_val(*map(new_ref, val_ty.lower_val(init_val))) # type: ignore + aval = AbstractRef(typeof(init_val)) + return Ref(AbstractRef(val_ty), hival_of_refs) +ref_p.to_lojax = _ref_to_lojax # type: ignore -def mutable_array(init_val): - return mutable_array_p.bind(init_val) -mutable_array_p = Primitive('mutable_array') -mutable_array_p.ref_primitive = True class InternalMutableArrayEffect(effects.Effect): pass -internal_mutable_array_effect = InternalMutableArrayEffect() +array_ref_effect = internal_mutable_array_effect = InternalMutableArrayEffect() effects.control_flow_allowed_effects.add_type(InternalMutableArrayEffect) +effects.remat_allowed_effects.add_type(InternalMutableArrayEffect) -@mutable_array_p.def_effectful_abstract_eval -def mutable_array_abstract_eval(init_aval): +@ref_p.def_effectful_abstract_eval +def _ref_abstract_eval(init_aval, *, memory_space: Any, kind: Any): from jax._src.state.types import AbstractRef # pytype: disable=import-error - return AbstractRef(init_aval), {internal_mutable_array_effect} - -@mutable_array_p.def_impl -def _mutable_array_impl(init_val): + return (AbstractRef(init_aval, memory_space=memory_space, kind=kind), + {internal_mutable_array_effect}) + +@ref_p.def_impl +def _ref_impl(init_val, *, memory_space: Any, kind: Any): + if memory_space is not None: + raise NotImplementedError( + "array ref with memory space only works inside of a `jit`.") from jax._src.state.types import AbstractRef # pytype: disable=import-error from jax._src.lax.lax import _array_copy # pytype: disable=import-error - return MutableArray(AbstractRef(get_aval(init_val)), _array_copy(init_val)) + aval = AbstractRef(typeof(init_val), kind=kind) + return Ref(aval, ArrayRefImpl(aval, _array_copy(init_val))) + +def freeze(ref: Ref) -> Array: + """Invalidate a given reference and return its final value. + + For more information about mutable array references, refer to the + `Ref guide`_. + + Args: + ref: A :class:`jax.ref.Ref` object. + + Returns: + A :class:`jax.Array` containing the contents of ``ref``. + + Examples: + >>> import jax + >>> ref = jax.new_ref(jax.numpy.arange(5)) + >>> ref[3] = 100 + >>> ref + Ref([ 0, 1, 2, 100, 4], dtype=int32) + + >>> jax.ref.freeze(ref) + Array([ 0, 1, 2, 100, 4], dtype=int32) -def freeze(ref): + .. _Ref guide: https://docs.jax.dev/en/latest/array_refs.html + """ return freeze_p.bind(ref) freeze_p = Primitive('freeze') +freeze_p.is_effectful = lambda params: True # type: ignore freeze_p.ref_primitive = True @freeze_p.def_effectful_abstract_eval @@ -2182,9 +2608,20 @@ def freeze_abstract_eval(ref_aval): def _freeze_impl(ref): return ref[()] +def accum_grad_in_ref(x): + return accum_grad_in_ref_p.bind(x) + +accum_grad_in_ref_p = Primitive('accum_grad_in_ref') +accum_grad_in_ref_p.is_high = lambda *_: True # type: ignore +accum_grad_in_ref_p.to_lojax = lambda x: x # type: ignore +accum_grad_in_ref_p.def_abstract_eval(lambda x: x) # type: ignore +accum_grad_in_ref_p.def_impl(lambda x: x) # type: ignore + + class AbstractToken(AbstractValue): - def str_short(self, short_dtypes=False): return 'Tok' + def str_short(self, short_dtypes=False, mesh_axis_types=False): return 'Tok' def to_tangent_aval(self): return self + def to_cotangent_aval(self): return self abstract_token: AbstractToken = AbstractToken() # Singleton shaped array used by all abstract tokens when shape/dtype is needed. @@ -2201,18 +2638,13 @@ def __init__(self, buf): def block_until_ready(self): self._buf.block_until_ready() pytype_aval_mappings[Token] = lambda _: abstract_token +dtypes.canonicalize_value_handlers[Token] = lambda x: x -# TODO(dougalm): Deprecate these. They're just here for backwards compat. -def raise_to_shaped(aval): - return aval -raise_to_shaped_mappings: dict[type, Callable] = {} - ### Operations on shapes and dimension sizes. class InconclusiveDimensionOperation(Exception): """Raised when we cannot conclusively compute with symbolic dimensions.""" - pass def is_symbolic_dim(v: Any) -> bool: """Checks if a value is a symbolic dimension used for shape polymorphism. @@ -2224,6 +2656,9 @@ def is_symbolic_dim(v: Any) -> bool: def is_constant_dim(d: DimSize) -> bool: # Whether the dimension is a static integer constant. + # Try using a fast path for non-concrete Tracers. + if isinstance(d, Tracer) and not is_concrete(d): + return False try: operator.index(d) return True @@ -2443,7 +2878,7 @@ def eval_one_dim(d: DimSize): def dim_value_dtype(): """The dtype to be used for dimension values.""" - return dtypes.canonicalize_dtype(np.int64) + return dtypes.default_int_dtype() def dim_constant(ct: int): dtype = dim_value_dtype() @@ -2473,10 +2908,8 @@ def bind_with_trace(self, trace, fun_and_args, params): def get_bind_params(self, params): new_params = dict(params) jaxpr = new_params.pop('call_jaxpr') - subfun = lu.hashable_partial(lu.wrap_init(eval_jaxpr, debug_info=jaxpr.debug_info), - jaxpr, ()) - if config.dynamic_shapes.value: - subfun = lu.annotate(subfun, _jaxpr_type_to_callable_annotation(jaxpr)) + subfun = lu.hashable_partial( + lu.wrap_init(eval_jaxpr, debug_info=jaxpr.debug_info), jaxpr, ()) return [subfun], new_params def call_impl(f: lu.WrappedFun, *args, **params): @@ -2499,7 +2932,7 @@ def get_bind_params(self, params): closed_call_p: ClosedCallPrimitive = ClosedCallPrimitive('closed_call') closed_call_p.def_impl(call_impl) closed_call_p.def_effectful_abstract_eval( - lambda *_, call_jaxpr: (call_jaxpr.out_avals, call_jaxpr.effects)) + lambda *_, call_jaxpr: (call_jaxpr.out_avals, eqn_effects(call_jaxpr))) # ------------------- Map ------------------- @@ -2548,42 +2981,32 @@ def unmapped_aval(size: AxisSize, axis: int | None, def _map_shaped_array( size: int, axis: int | None, aval: ShapedArray) -> ShapedArray: assert axis is None or aval.shape[axis] == size - # TODO: Extend the named shape - if axis is None: return aval - sharding = aval.sharding.with_spec(tuple_delete(aval.sharding.spec, axis)) + if axis is None: + return aval + aval_s = aval.sharding + sharding = aval_s.update( + spec=aval_s.spec.update(partitions=tuple_delete(aval_s.spec, axis))) return ShapedArray(tuple_delete(aval.shape, axis), aval.dtype, - weak_type=aval.weak_type, sharding=sharding) + weak_type=aval.weak_type, sharding=sharding, vma=aval.vma, + memory_space=aval.memory_space) def _unmap_shaped_array( size: int, axis: int | None, explicit_mesh_axis, aval: ShapedArray ) -> ShapedArray: - if axis is None: return aval + if axis is None: + return aval elif type(axis) is int: - sharding = aval.sharding.with_spec(tuple_insert( - aval.sharding.spec, axis, explicit_mesh_axis)) + aval_s = aval.sharding + sharding = aval_s.update(spec=aval_s.spec.update(partitions=tuple_insert( + aval_s.spec, axis, explicit_mesh_axis))) return ShapedArray(tuple_insert(aval.shape, axis, size), aval.dtype, - weak_type=aval.weak_type, sharding=sharding) - else: raise TypeError(axis) - -def _map_dshaped_array( - size: AxisSize, axis: int | None, aval: DShapedArray) -> DShapedArray: - if axis is None: return aval - return DShapedArray(tuple_delete(aval.shape, axis), aval.dtype, - aval.weak_type) - -def _unmap_dshaped_array( - size: AxisSize, axis: int | None, explicit_mesh_axis, aval: DShapedArray - ) -> DShapedArray: - if axis is None: return aval - elif type(axis) is int: - return DShapedArray(tuple_insert(aval.shape, axis, size), aval.dtype, - weak_type=aval.weak_type) + weak_type=aval.weak_type, sharding=sharding, + vma=aval.vma, memory_space=aval.memory_space) else: raise TypeError(axis) AvalMapHandlerPair = tuple[Callable, Callable] aval_mapping_handlers: dict[type, AvalMapHandlerPair] = { - DShapedArray: (_map_dshaped_array, _unmap_dshaped_array), ShapedArray: (_map_shaped_array, _unmap_shaped_array), AbstractToken: (lambda _, __, a: a, lambda _, __, ____, a: a) } @@ -2616,10 +3039,8 @@ def __lt__(self, other): @dataclass(frozen=True) class NamedAxisEffect(effects.Effect): """A side-effect introducing a new named axis into the current scope.""" - name: AxisName - effects.control_flow_allowed_effects.add_type(NamedAxisEffect) effects.custom_derivatives_allowed_effects.add_type(NamedAxisEffect) effects.lowerable_effects.add_type(NamedAxisEffect) @@ -2662,22 +3083,59 @@ def typecompat(aval_ref: AbstractValue, aval: AbstractValue) -> bool: except TypeError: return False -def typematch(t1: AbstractValue, t2: AbstractValue) -> bool: +def typematch(t1: AbstractValue, t2: AbstractValue, + only_shape_shd_check: bool = False) -> bool: """Determine whether `t1` and `t2` are equivalent. Ignores weak_type.""" t1 = t1.normalize() t2 = t2.normalize() + from jax._src.state.types import AbstractRef # pytype: disable=import-error if t1 == t2: return True - elif (isinstance(t1, (ShapedArray, DShapedArray)) and - isinstance(t2, (ShapedArray, DShapedArray))): - # This case handles DShapedArray and shape polynomials. Alternatively we - # could try normalizing first and then doing simple equality. - # TODO(yashkatariya): Also check `sharding` here. - # See https://github.com/jax-ml/jax/issues/26474 - return t1.dtype == t2.dtype and definitely_equal_shape(t1.shape, t2.shape) + elif isinstance(t1, ShapedArray) and isinstance(t2, ShapedArray): + if only_shape_shd_check: + return cmp_shape_sharding_vma(t1, t2) + return (t1.dtype == t2.dtype and cmp_shape_sharding_vma(t1, t2) and + t1.memory_space == t2.memory_space) + elif isinstance(t1, AbstractRef) and isinstance(t2, AbstractRef): + # We want to use the regular typecheck for ShapedArray here. + return (typematch(t1.inner_aval, t2.inner_aval, only_shape_shd_check) and # type: ignore + (t1.memory_space is None or t2.memory_space is None or # type: ignore + t1.memory_space == t2.memory_space)) # type: ignore else: return False +def cmp_shape_sharding_vma(t1, t2): + # TODO(yashkatariya): Expand this to Manual and Auto mode. + # See https://github.com/jax-ml/jax/issues/26474 + if (not t1.sharding.mesh.empty and not t2.sharding.mesh.empty and + (t1.sharding.mesh._any_axis_explicit or + t2.sharding.mesh._any_axis_explicit)): + shd_eq = t1.sharding == t2.sharding + else: + shd_eq = True + return (shd_eq and definitely_equal_shape(t1.shape, t2.shape) and + t1.vma == t2.vma) + +def aval_mismatch_extra(a1: AbstractValue, a2: AbstractValue) -> str: + assert not typematch(a1, a2) + if isinstance(a1, ShapedArray) and isinstance(a2, ShapedArray): + mismatches = [] + if a1.dtype != a2.dtype: + mismatches.append('the dtypes do not match') + if a1.shape != a2.shape: + mismatches.append('the shapes do not match') + if a1.vma != a2.vma: + mismatches.append('the varying manual axes do not match') + # TODO(yashkatariya,mattjj): add check for sharding-in-types mismatch + + if len(mismatches) == 0: + return '' + elif len(mismatches) == 1: + return ', so ' + mismatches[0] + else: + return ', so ' + ', '.join(mismatches[:-1]) + ', and ' + mismatches[-1] + return '' + class JaxprTypeError(TypeError): pass custom_typechecks: dict[Primitive, Callable] = {} @@ -2686,7 +3144,7 @@ def _check_closed_call(_, *in_atoms, call_jaxpr): in_avals = [x.aval for x in in_atoms] if not all(map(typecompat, call_jaxpr.in_avals, in_avals)): raise JaxprTypeError("Closed call in_avals mismatch") - return call_jaxpr.out_avals, call_jaxpr.effects + return call_jaxpr.out_avals, eqn_effects(call_jaxpr) custom_typechecks[closed_call_p] = _check_closed_call def check_jaxpr(jaxpr: Jaxpr): @@ -2728,15 +3186,23 @@ def ctx_factory(): from jax.experimental.key_reuse._core import check_key_reuse_jaxpr # pytype: disable=import-error check_key_reuse_jaxpr(jaxpr) +# A place to track the quasi-dynamic data associated with a variable during typechecking +@dataclass(frozen=True) +class MutableTypecheckVal: + aval : AbstractValue + mutable_qdd : MutableQuasiDynamicData + + +_ref_allocating_primitives = {ref_p} + def _check_jaxpr( ctx_factory: Callable[[], tuple[JaxprPpContext, JaxprPpSettings]], jaxpr: Jaxpr ) -> None: - # Use set of variables to types to check that variables are in scope. - env: set[Var] = set() + env: dict[Var, Atom | MutableTypecheckVal] = {} - def read(x: Atom) -> Atom: + def read(x: Atom) -> Atom | MutableTypecheckVal: # Check the type annotation is itself well-typed. check_type(ctx_factory, env, x.aval) if isinstance(x, Var): @@ -2744,7 +3210,7 @@ def read(x: Atom) -> Atom: if x not in env: ctx, _ = ctx_factory() raise JaxprTypeError(f"Variable '{pp_var(x, ctx)}' not defined") - return x + return env[x] elif isinstance(x, Literal): # Check that the literal matches its type annotation. if not typecheck(x.aval, x.val): @@ -2756,7 +3222,8 @@ def read(x: Atom) -> Atom: else: assert False, "syntactically invalid jaxpr" - def write(v: Var, a: AbstractValue) -> None: + def write(v: Var, a: AvalQDD) -> None: + aval, qdd = a.aval, a.qdd assert isinstance(v, Var), "syntactically invalid jaxpr" # Check the type annotation of the binder is itself well-typed. check_type(ctx_factory, env, v.aval) @@ -2765,19 +3232,30 @@ def write(v: Var, a: AbstractValue) -> None: ctx, _ = ctx_factory() raise JaxprTypeError(f"Variable '{pp_var(v, ctx)}' already bound") # Check that the computed type is consistent with the binder annotation. - if not typematch(v.aval, a): + if not typematch(v.aval, aval): ctx, _ = ctx_factory() raise JaxprTypeError( f"Value for variable '{pp_var(v, ctx)}' inconsistently typed " - f"as {pp_aval(a, ctx)} for let-binder of type {pp_aval(v.aval, ctx)}") + f"as {pp_aval(aval, ctx)} for let-binder of type {pp_aval(v.aval, ctx)}") + # If the variable is not a DropVar, add it to the environment. if not isinstance(v, DropVar): - env.add(v) + if qdd is None: + env[v] = v + else: + env[v] = MutableTypecheckVal(aval, MutableQuasiDynamicData(qdd)) + + # # Don't return refs + if config.mutable_array_checks.value: + from jax._src.state.types import AbstractRef # pytype: disable=import-error + for v in jaxpr.outvars: + if isinstance(v.aval, AbstractRef): + raise JaxprTypeError("returned a ref!") # Check type annotations on lambda binders. for v in it.chain(jaxpr.constvars, jaxpr.invars): check_type(ctx_factory, env, v.aval) - write(v, v.aval) + write(v, AvalQDD(v.aval, v.initial_qdd)) # Check each eqn. sentinel = object() @@ -2787,7 +3265,8 @@ def write(v: Var, a: AbstractValue) -> None: prim = eqn.primitive try: in_atoms = map(read, eqn.invars) - in_avals = [x.aval for x in in_atoms] # use in_atoms for dyn shapes + in_avals = [AvalMutableQDD(x.aval, x.mutable_qdd) if isinstance(x, MutableTypecheckVal) + else x.aval for x in in_atoms] # use in_atoms for dyn shapes # Compute the type of the primitive application. with eqn.ctx.manager: @@ -2806,7 +3285,7 @@ def write(v: Var, a: AbstractValue) -> None: # Check the computed effect type matches the eqn's annotation, and is # included in the jaxpr's annotation. if prim.ref_primitive: - if prim is mutable_array_p: + if prim in _ref_allocating_primitives: outvar, = eqn.outvars in_idx[outvar] = None # type: ignore mut_arrays.add(outvar) @@ -2817,7 +3296,7 @@ def write(v: Var, a: AbstractValue) -> None: for eff in eqn.effects: if isinstance(eff, effects.JaxprInputEffect): eqn_invar = eqn.invars[eff.input_index] - if eqn_invar in mut_arrays: + if type(eqn_invar) is Literal or eqn_invar in mut_arrays: continue if (jaxpr_index := in_idx.get(eqn_invar, sentinel)) is sentinel: raise JaxprTypeError( @@ -2836,7 +3315,7 @@ def write(v: Var, a: AbstractValue) -> None: f"Jaxpr effects: {jaxpr.effects}") # Check out_type matches the let-binders' annotation (after substitution). - out_type = substitute_vars_in_output_ty(out_type, eqn.invars, eqn.outvars) + out_type = [t if isinstance(t, AvalQDD) else AvalQDD(t, None) for t in out_type] foreach(write, eqn.outvars, out_type) except JaxprTypeError as e: @@ -2847,59 +3326,22 @@ def write(v: Var, a: AbstractValue) -> None: f"from source: {src}"]) raise JaxprTypeError(msg, eqn_idx) from None + # Check there are no output refs + # TODO(mattjj): improve this error message + if config.mutable_array_checks.value: + from jax._src.state.types import AbstractRef # pytype: disable=import-error + for v in jaxpr.outvars: + if isinstance(v.aval, AbstractRef): raise TypeError("returned ref") + # TODO(mattjj): include output type annotation on jaxpr and check it here foreach(read, jaxpr.outvars) def check_type( ctx_factory: Callable[[], tuple[JaxprPpContext, JaxprPpSettings]], - env: set[Var], + env: dict[Var, Atom | MutableTypecheckVal], ty: AbstractValue, ) -> None: - if isinstance(ty, DShapedArray): - # Check all elements in the shape tuple are well-typed. - for d in ty.shape: - if (isinstance(d, int) or - isinstance(d, DArray) and not d.shape and type(d.dtype) == bint): - continue - elif isinstance(d, Var): - if d not in env: - ctx, _ = ctx_factory() - raise JaxprTypeError(f"unbound axis size: '{pp_var(d, ctx)}'") - if not isinstance(d.aval, (ShapedArray, DShapedArray)): - raise JaxprTypeError(f"axis size with unexpected type annotation: " - f"{d.aval} of type {type(d.aval)}") - if isinstance(d.aval, ShapedArray): - shape, dtype = d.aval.shape, d.aval.dtype - if shape: raise JaxprTypeError(f"axis size nonscalar: {d.aval}") - if not dtypes.issubdtype(dtype, np.integer): - raise JaxprTypeError(f"axis size with non-integer dtype: {d.aval}") - else: - assert isinstance(d.aval, DShapedArray) - shape, dtype = d.aval.shape, d.aval.dtype - if shape: raise JaxprTypeError(f"axis size nonscalar: {d.aval}") - if type(dtype) is not bint: - raise JaxprTypeError( - f"DArray axis size with non-bint dtype: {d.aval}") - else: - raise JaxprTypeError(f"unexpected type in shape: {type(d)}") - else: - return # Except in above case(s), all syntactic forms are valid - -def substitute_vars_in_output_ty( - out_type: Sequence[AbstractValue], # shapes may contain InDBIdx / OutDBIdx - in_atoms: Sequence[Atom], - out_binders: Sequence[Var], - ) -> list[AbstractValue]: # shapes may contain Vars - in_atoms = [x.val if type(x) is Literal else x for x in in_atoms] - result = [] - for aval in out_type: - if type(aval) is DShapedArray: - shape = [in_atoms[d.val] if type(d) is InDBIdx else - out_binders[d.val] if type(d) is OutDBIdx else - d for d in aval.shape] - aval = aval.update(shape=tuple(shape)) - result.append(aval) - return result + return # Except in above case(s), all syntactic forms are valid def check_eqn(prim, in_avals, params): for jaxpr in jaxprs_in_params(params): @@ -2914,7 +3356,10 @@ def _check_call(ctx_factory, prim, in_atoms, params): if "call_jaxpr" not in params: raise JaxprTypeError( f"Call primitive {prim} missing 'call_jaxpr' parameter") - call_jaxpr = params["call_jaxpr"] + if isinstance(prim, ClosedCallPrimitive): + call_jaxpr = params["call_jaxpr"].jaxpr + else: + call_jaxpr = params["call_jaxpr"] if len(in_atoms) != len(call_jaxpr.invars): raise JaxprTypeError(f"Call primitive {prim} with {len(in_atoms)} " @@ -2922,30 +3367,27 @@ def _check_call(ctx_factory, prim, in_atoms, params): f"{len(call_jaxpr.invars)} inputs") # Check `call_jaxpr` can be applied to in_atoms. - env: dict[Var, Atom] = {} - def substitute(aval: AbstractValue): - if isinstance(aval, DShapedArray): - aval = aval.update(shape=tuple(env.get(d, d) for d in aval.shape)) # type: ignore - return aval + env: dict[Var, Atom | MutableTypecheckVal] = {} for v, x in zip(call_jaxpr.invars, in_atoms): - if not typecompat(substitute(v.aval), x.aval): + if not typecompat(v.aval, x.aval): # TODO(mattjj): vars in error message are confusing b/c of Var.__repr__ raise JaxprTypeError(f"Call primitive {prim} passes operand {x} of type " f"{x.aval} to jaxpr expecting type " - f"{substitute(v.aval)}") - env[v] = x if type(x) is Var else x.val + f"{v.aval}") + env[v] = x.val if type(x) is Literal else x - _check_jaxpr(ctx_factory, call_jaxpr) + check_jaxpr(call_jaxpr) invars, outvars = call_jaxpr.invars, call_jaxpr.outvars - in_map : dict[Var, InDBIdx] = {v: InDBIdx(i) for i, v in enumerate( invars)} - out_map: dict[Var, OutDBIdx] = {x: OutDBIdx(i) for i, x in enumerate(outvars) - if type(x) is Var} out_avals = [x.aval for x in call_jaxpr.outvars] - out_type = [a.update(shape=tuple(in_map.get(d, out_map.get(d)) - if type(d) is Var else d for d in a.shape)) - if type(a) is DShapedArray else a for a in out_avals] - return out_type, call_jaxpr.effects + out_type = out_avals + # jaxpr input effects are indexed to include jaxpr.constvars, but the eqn + # should have effects indexed only on its explicit arguments + effs = {e.replace(input_index=e.input_index - len(call_jaxpr.constvars)) + if isinstance(e, effects.JaxprInputEffect) + else e for e in call_jaxpr.effects} + + return out_type, effs def _check_map(ctx_factory, prim, in_avals, params): if "call_jaxpr" not in params: @@ -2985,6 +3427,165 @@ def _check_map(ctx_factory, prim, in_avals, params): for aval, out_axis in zip(mapped_out_avals, out_axes)] return out_avals, filter_named_axis_effects(call_jaxpr.effects, {axis_name}) +def eqn_effects(jaxpr): + # jaxpr input effects are indexed to include jaxpr.constvars, but the eqn + # should have effects indexed only on its explicit arguments + effs = jaxpr.effects + return {e.replace(input_index=e.input_index - len(jaxpr.constvars)) + if isinstance(e, effects.JaxprInputEffect) else e for e in effs} + + +# ------------------- ShapeDtypeStruct ------------------- + +def _check_sharding(sharding, shape): + if sharding is None: + return + if isinstance(sharding, P): + sharding._check_compatible_wrt_shape(shape) + else: + sharding.check_compatible_aval(shape) + +@set_module("jax") +class ShapeDtypeStruct: + """A container for the shape, dtype, and other static attributes of an array. + + ``ShapeDtypeStruct`` is often used in conjunction with :func:`jax.eval_shape`. + + Args: + shape: a sequence of integers representing an array shape + dtype: a dtype-like object + sharding: (optional) a :class:`jax.Sharding` object + """ + __slots__ = ["shape", "dtype", "_sharding", "_dll", "weak_type", "vma", + "is_ref"] + + def __init__(self, shape, dtype, *, sharding=None, weak_type=False, + vma=None, is_ref=False): + self.shape = tuple(shape) + if dtype is None: + raise ValueError("ShapeDtypeStruct: dtype must be specified.") + self.dtype = dtype if dtypes.issubdtype(dtype, dtypes.extended) else np.dtype(dtype) + if sharding is not None and not isinstance(sharding, (Sharding, Format, P)): + raise ValueError( + "sharding should be an instance of `jax.sharding.Sharding`, " + "`jax.sharding.PartitionSpec` or" + f" `jax.experimental.layout.Format`. Got {sharding} of type" + f" {type(sharding)}.") + if (isinstance(sharding, Format) and + isinstance(sharding.layout, AutoLayout)): + raise TypeError( + "`Layout.AUTO` cannot be used in place of a device-local" + f" layout in a `ShapeDtypeStruct`. Got {sharding}") + self._sharding = (sharding.sharding if isinstance(sharding, Format) + else sharding) + _check_sharding(self._sharding, self.shape) + self._dll = sharding.layout if isinstance(sharding, Format) else None + self.weak_type = weak_type + if vma is not None and not isinstance(vma, (set, frozenset)): + raise TypeError( + "`vma` argument passed to ShapeDtypeStruct should be of type `set`" + f" or `frozenset`. Got type {type(vma)}") + self.vma = None if vma is None else frozenset(vma) + self.is_ref = is_ref + + size = property(lambda self: math.prod(self.shape)) + ndim = property(lambda self: len(self.shape)) + + @property + def sharding(self): + if isinstance(self._sharding, P): + # TODO(yashkatariya): Maybe use `get_abstract_mesh()` here but switch + # on `core.trace_state_clean()`? + cur_mesh = mesh_lib.get_concrete_mesh() + if cur_mesh.empty: + raise TypeError( + "When specifying PartitionSpec to `ShapeDtypeStruct`, the context" + " mesh cannot be empty. Please use `jax.set_mesh` to set" + " the mesh context.") + return NamedSharding(cur_mesh, self._sharding) + else: + return self._sharding + + @property + def format(self): + return Format(self._dll, self.sharding) + + def __len__(self): + try: + return self.shape[0] + except IndexError as e: + raise TypeError("len() of unsized object") from e # same as numpy error + + def __repr__(self): + sh = f", sharding={self.sharding}" if self.sharding is not None else "" + l = f", format={self._dll}" if self._dll is not None else "" + wt = f", weak_type={self.weak_type}" if self.weak_type else "" + vma = f", vma={self.vma}" if self.vma else "" + is_ref = f", is_ref={self.is_ref}" if self.is_ref else "" + return (f"{type(self).__name__}(shape={self.shape}, " + f"dtype={self.dtype.name}{sh}{l}{wt}{vma}{is_ref})") + + __str__ = __repr__ + + def __eq__(self, other): + if not isinstance(other, ShapeDtypeStruct): + return False + else: + return ((self.shape, self.dtype, self.sharding, self._dll, + self.weak_type, self.vma, self.is_ref) == + (other.shape, other.dtype, other.sharding, other._dll, + other.weak_type, other.vma, other.is_ref)) + + def __hash__(self): + # TODO(frostig): avoid the conversion from dict by addressing + # https://github.com/jax-ml/jax/issues/8182 + return hash((self.shape, self.dtype, self.sharding, self._dll, + self.weak_type, self.vma, self.is_ref)) + + def __setattr__(self, name, value): + if hasattr(self, name): + if getattr(self, name) == value: + # This can happen if two threads race, for example if two threads + # are trying to hash the same SDS instance. + return + raise RuntimeError( + f"Cannot reassign attributes ({name}) of immutable ShapeDtypeStruct" + " objects") + super().__setattr__(name, value) + + def update(self, **kwargs): + if 'sharding' in kwargs: + s = kwargs['sharding'] + if self._dll is not None and isinstance(s, Sharding): + raise ValueError( + f"You are updating ShapeDtypeStruct with a {type(s)} when the" + f" original ShapeDtypeStruct had a concrete layout {self.format}." + " This might lead to bugs. If you want to do this, create a new" + " ShapeDtypeStruct via the constructor.") + sharding = s + else: + sharding = self.format + return ShapeDtypeStruct( + shape=kwargs.pop('shape', self.shape), + dtype=kwargs.pop('dtype', self.dtype), + sharding=sharding, + weak_type=kwargs.pop('weak_type', self.weak_type), + vma=kwargs.pop('vma', self.vma), + is_ref=kwargs.pop('is_ref', self.is_ref)) + + +def _sds_aval_mapping(x): + aval = ShapedArray( + x.shape, dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True), + weak_type=x.weak_type) + aval = update_aval_with_sharding( + aval, x.sharding, vma=(frozenset() if x.vma is None else x.vma)) + if x.is_ref: + from jax._src.state.types import AbstractRef # type: ignore + return AbstractRef(aval) + return aval +pytype_aval_mappings[ShapeDtypeStruct] = _sds_aval_mapping + # ------------------- Jaxpr printed representation ------------------- @@ -3097,17 +3698,12 @@ def suggest_same_var_names(self, self.var_names[for_v] = pp_var(like_v, self) -def pp_var(v: Var | Literal, context: JaxprPpContext) -> str: - if isinstance(v, (Literal, DropVar)): return str(v) - return f"{context.var_names[v]}{v.suffix}" +def pp_var(v: Var | Literal, context: JaxprPpContext, *, + print_literal_dtype: bool = True) -> str: + return v.pretty_print(context, print_dtype=print_literal_dtype) def pp_aval(a: AbstractValue, context: JaxprPpContext) -> str: - if isinstance(a, DShapedArray): - shape = [pp_var(d, context) if type(d) is Var else str(d) for d in a.shape] - dtype = dtypes.short_dtype_name(a.dtype) - return f'{dtype}[{",".join(shape)}]' - else: - return a.str_short(short_dtypes=True) + return a.str_short(short_dtypes=True) def pp_vars(vs: Sequence[Atom], context: JaxprPpContext, *, separator="", print_shapes: bool = False) -> pp.Doc: @@ -3139,20 +3735,20 @@ def pp_kv_pair(k:str, v: Any, context: JaxprPpContext, settings: JaxprPpSettings def pp_kv_pairs(kv_pairs, context: JaxprPpContext, settings: JaxprPpSettings) -> pp.Doc: if not kv_pairs: return pp.nil() - return pp.group( + return pp.group(pp.concat([ pp.nest(2, pp.concat([ pp.text("["), pp.brk(""), pp.join(pp.brk(), [pp_kv_pair(k, v, context, settings) for k, v in kv_pairs]) - ])) - + pp.brk("") + pp.text("]") - ) + ])), + pp.brk(""), pp.text("]") + ])) def pp_eqn(eqn: JaxprEqn, context: JaxprPpContext, settings: JaxprPpSettings ) -> pp.Doc: rule = (_pp_eqn if not settings.custom_pp_eqn_rules else pp_eqn_rules.get(eqn.primitive, _pp_eqn)) doc = rule(eqn, context, settings) - user_frame = source_info_util.user_frame(eqn.source_info) + user_frame = source_info_util.user_frame(eqn.source_info.traceback) return doc if user_frame is None else pp.source_map(doc, user_frame) def _pp_eqn(eqn: JaxprEqn, context: JaxprPpContext, settings: JaxprPpSettings, @@ -3166,7 +3762,7 @@ def _pp_eqn(eqn: JaxprEqn, context: JaxprPpContext, settings: JaxprPpSettings, rhs = [pp.text(eqn.primitive.name, annotation=name_stack_annotation), pp_kv_pairs([(p, eqn.params[p]) for p in params], context, settings), pp.text(" ") + pp_vars(eqn.invars, context)] - if lhs.format(): + if eqn.outvars: return pp.concat([lhs, pp.text(" = ", annotation=annotation), *rhs]) else: return pp.concat(rhs) @@ -3261,10 +3857,10 @@ def pp_jaxpr( def pp_jaxprs(jaxprs: Sequence[ClosedJaxpr | Jaxpr], context: JaxprPpContext, settings: JaxprPpSettings) -> pp.Doc: jaxprs = [j.jaxpr if isinstance(j, ClosedJaxpr) else j for j in jaxprs] - return pp.group(pp.nest(2, pp.concat([ + return pp.group(pp.concat([pp.nest(2, pp.concat([ pp.text('('), pp.brk(""), pp.join(pp.brk(), map(lambda x: pp_jaxpr(x, context, settings), jaxprs))] - )) + pp.brk("") + pp.text(')') + )), pp.brk(""), pp.text(')')]) ) @@ -3315,6 +3911,21 @@ def clean_up_dead_vars(eqn: JaxprEqn, env: dict[Var, Any], shard_aval_handlers = {} # type: ignore unshard_aval_handlers = {} # type: ignore +def shard_aval(mesh, manual_axes, check_vma, spec, aval: AbstractValue + ) -> AbstractValue: + if type(aval) in shard_aval_handlers: + return shard_aval_handlers[type(aval)](mesh, manual_axes, check_vma, + spec, aval) + raise NotImplementedError(f"Unsupported aval type: {type(aval)}") + +def unshard_aval(mesh, check_vma, spec, aval: AbstractValue + ) -> AbstractValue: + if type(aval) in unshard_aval_handlers: + return unshard_aval_handlers[type(aval)](mesh, check_vma, spec, aval) + else: + raise NotImplementedError(f"Unsupported aval type: {type(aval)}") + + # ----------------- external APIs for querying tracing context ----------------- # TODO(dougalm, jakevdp): expose these via jax.extend @@ -3330,7 +3941,7 @@ def __eq__(self, other): else: return False -def get_opaque_trace_state(convention): +def get_opaque_trace_state(convention=None): del convention return OpaqueTraceState(trace_ctx.trace._weakref) diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py index c7e7c83f30f8..43599c264cbc 100644 --- a/jax/_src/cudnn/fused_attention_stablehlo.py +++ b/jax/_src/cudnn/fused_attention_stablehlo.py @@ -18,22 +18,25 @@ import math from typing import TypedDict -import jax -from jax import dtypes from jax._src import core +from jax._src import custom_derivatives from jax._src import dispatch +from jax._src import dtypes +from jax._src import numpy as jnp +from jax._src import xla_bridge from jax._src.custom_partitioning import custom_partitioning +from jax._src.custom_partitioning_sharding_rule import BATCHING, ArrayMapping, CompoundFactor, SdyShardingRule from jax._src.interpreters import batching +from jax._src.interpreters import mlir +from jax._src.lax import parallel as lax_parallel from jax._src.lib import cuda_versions -from jax._src import xla_bridge -from jax.interpreters import mlir -from jax.interpreters import xla -from jax.interpreters.mlir import hlo -from jax.interpreters.mlir import ir -import jax.numpy as jnp -from jax.sharding import NamedSharding, PartitionSpec +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import hlo +from jax._src.sharding_impls import NamedSharding, PartitionSpec +from jax._src.typing import Array + +import numpy as np -Array = jnp.ndarray class FP8Params(TypedDict): amax_dQ: float # Amax of gradient of query @@ -86,18 +89,13 @@ def has_padding(mask_type: MaskType) -> bool: return mask_type == MaskType.PADDING or mask_type == MaskType.PADDING_CAUSAL def should_export_dbias(bias_shape, query_shape, layout) -> bool: - b_B, b_N, _, _ = bias_shape - if layout == AttentionLayout.BNTH.value: - _, q_N, _, _ = query_shape - else: - _, _, q_N, _ = query_shape - return b_B == 1 and b_N == q_N + return True def get_large_negative_number(dtype): # temp WAR as cuDNN has a bug for subtraction between two large negative value - if dtype == jnp.bfloat16: + if dtype == np.dtype('bfloat16'): return jnp.asarray(-2 << 40, dtype=dtype) - elif dtype == jnp.float16: + elif dtype == np.dtype('float16'): return jnp.asarray(-2 << 14, dtype=dtype) else: raise ValueError("Unsupported dtype for inputs.") @@ -122,6 +120,9 @@ def default_layouts(*shapes): def get_max_seg_per_batch(q_offsets): return q_offsets.shape[1] - 1 if len(q_offsets.shape) == 2 else 1 +def check_is_paged_attention(page_table_k): + return len(page_table_k.shape) == 4 + def create_dot_product_attention_backend_config_base( batch, num_heads, seq_q, seq_kv, dtype, fmha_scale, mask_type, layout, is_bwd ): @@ -229,6 +230,7 @@ def create_dot_product_attention_backend_config( layout, sliding_window_length, max_seg_per_batch, + is_paged_attention, is_bwd ): backend_config = create_dot_product_attention_backend_config_base( @@ -241,6 +243,7 @@ def create_dot_product_attention_backend_config( backend_config['cudnn_fmha_backend_config']["seed"] = seed backend_config['cudnn_fmha_backend_config']["sliding_window_length"] = sliding_window_length backend_config['cudnn_fmha_backend_config']["max_seg_per_batch"] = max_seg_per_batch + backend_config['cudnn_fmha_backend_config']["is_paged_attention"] = is_paged_attention return json.dumps(backend_config) def create_dot_product_attention_fp8_backend_config( @@ -273,7 +276,7 @@ def get_custom_call_name(has_bias, has_dropout, is_bwd, is_fp8=False): ) def check_layout(query, key, value, bias, q_seqlen, kv_seqlen, - q_offsets, kv_offsets, layout): + q_offsets, kv_offsets, page_table_k, page_table_v, layout): def check_eq(a, b, c, msg): if not (a == b == c): raise ValueError(f"{msg} must be same, got {a}, {b}, {b}") @@ -284,7 +287,7 @@ def check_eq(a, b, c, msg): check_eq(q_rank, k_rank, v_rank, "QKV rank") q_dtype, k_dtype, v_dtype = query.dtype, key.dtype, value.dtype - if q_dtype not in [jnp.bfloat16, jnp.float16, jnp.float8_e4m3fn, jnp.float8_e5m2]: + if q_dtype not in [np.float16, dtypes.bfloat16, dtypes.float8_e4m3fn, dtypes.float8_e5m2]: raise NotImplementedError(f"Q must be fp16/bf16/fp8_e4m3fn/fp8_e5m2, got {q_dtype}") check_eq(q_dtype, k_dtype, v_dtype, "QKV dtype") @@ -298,8 +301,25 @@ def check_eq(a, b, c, msg): kB, kS, kN, kH = key.shape vB, vS, vN, vH = value.shape + if page_table_k is not None and page_table_v is not None: + k_blocks, k_block_size = kB, kS + v_blocks, v_block_size = vB, vS + kB, _, k_blocks_per_batch, _ = page_table_k.shape + vB, _, v_blocks_per_batch, _ = page_table_v.shape + kS = k_blocks_per_batch * k_block_size + vS = v_blocks_per_batch * v_block_size + if kB * k_blocks_per_batch != k_blocks: + raise ValueError( + f"Key and page_table_k must have same number of blocks, " + f"got {k_blocks} vs {kB * k_blocks_per_batch}") + if vB * v_blocks_per_batch != v_blocks: + raise ValueError( + f"Value and page_table_v must have same number of blocks, " + f"got {v_blocks} vs {vB * v_blocks_per_batch}") + check_eq(qB, kB, vB, "QKV batch") - check_eq(qH, kH, vH, "QKV dim_per_head") + if qH != kH: + raise ValueError(f"QK must have same head dim, got {qH} vs {kH}") if kN != vN: raise ValueError(f"KV must have same number of heads, got {kN} vs {vN}") if kS != vS: @@ -318,7 +338,7 @@ def check_seqlen_offsets(tensor, name): if tensor is not None: dtype = tensor.dtype rank = len(tensor.shape) - if dtype != jnp.int32: + if dtype != np.dtype('int32'): raise ValueError(f"{name} must have int32 datatype, got {dtype}") if rank != expected_rank: raise ValueError(f"{name} must have a rank of {expected_rank}, got {rank}") @@ -333,33 +353,40 @@ def check_seqlen_offsets(tensor, name): def check_is_flash_attention( - query, key, layout: int, cudnn_version, has_bias, is_training, is_packed=False, - is_fp8=False): + query, key, value, layout: int, cudnn_version, has_bias, is_training, + is_packed=False, is_paged_attention=False, is_fp8=False): # Extract sequence length (T) and head dim (H) based on layout if layout == AttentionLayout.BNTH.value: - _, _, T, H = query.shape - _, _, S, _ = key.shape + _, _, T, qH = query.shape + _, _, S, vH = value.shape else: - _, T, _, H = query.shape - _, S, _, _ = key.shape + _, T, _, qH = query.shape + _, S, _, vH = value.shape + + if is_cuda_compute_capability_equal("10.3") and cudnn_version < 91100: + # cudnn support compute_cap 10.3 on cudnn 9.11+ + raise NotImplementedError( + "Compute capability 10.3 requires cuDNN version >= 9.11.") # Flash attention conditions if is_fp8: # FP8 specific conditions - if not ((is_training and H == 128 and T % 128 == 0 and S % 128 == 0) or - (not is_training and H <= 256 and H % 16 == 0)): + if not ((is_training and qH == 128 and T % 128 == 0 and S % 128 == 0) or + (not is_training and qH <= 256 and qH % 16 == 0)): raise NotImplementedError( - f"Unsupported sequence length Q {T}, KV {S} and head dim {H} for FP8." + f"Unsupported sequence length Q {T}, KV {S} and head dim {qH} for FP8." ) else: # bf16/fp16 attention conditions # Check the head dim. - is_on_hopper = is_cuda_compute_capability_equal("9.0") - H_max = 256 if cudnn_version >= 90500 and is_on_hopper else 128 - if not (H <= H_max and H % 8 == 0): + is_hopper_or_later = check_compute_capability("9.0") + H_max = 256 if is_hopper_or_later else 128 + # check if multi-head latent attention is needed + is_mla = qH != vH + if not (qH <= H_max and qH % 8 == 0): raise NotImplementedError( - f"The head dim must be <= {H_max} and a mutiple of 8, " - f"but got {H}." + f"The head dim must be <= {H_max} and a multiple of 8, " + f"but got {qH}." ) # Check patterns with bias, seqlen should be divisible by 2 @@ -368,8 +395,12 @@ def check_is_flash_attention( f"Unsupported sequence length Q {T}, KV {S}." ) - if is_packed and cudnn_version < 90600: - raise NotImplementedError("Packed layout requires cudnn version >= 9.6.") + if is_packed and not check_compute_capability("9.0"): + raise NotImplementedError( + "Packed layout requires a GPU with at least Hopper architecture.") + if is_mla and (cudnn_version < 91000 or not check_compute_capability("9.0")): + raise NotImplementedError( + "mla requires cudnn version >= 9.10 and at least hopper arch.") def check_cudnn_version(): # check if cuDNN is installed @@ -380,7 +411,7 @@ def check_cudnn_version(): def check_compute_capability(capability): if not 'cuda' in xla_bridge.get_backend().platform_version: return False - d, *_ = jax.local_devices(backend="gpu") + d, *_ = xla_bridge.local_devices(backend="gpu") target = tuple(int(x) for x in capability.split(".")) current = tuple(int(x) for x in d.compute_capability.split(".")) return current >= target @@ -388,57 +419,66 @@ def check_compute_capability(capability): def is_cuda_compute_capability_equal(capability): if not 'cuda' in xla_bridge.get_backend().platform_version: return False - d, *_ = jax.local_devices(backend="gpu") + d, *_ = xla_bridge.local_devices(backend="gpu") target = tuple(int(x) for x in capability.split(".")) current = tuple(int(x) for x in d.compute_capability.split(".")) return current == target def _dot_product_attention_fwd( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, + page_table_k, page_table_v, scale, seed, dropout_rate, variadic_args, mask_type, layout, - sliding_window_length, cudnn_version): + sliding_window_length, cudnn_version, return_residual): # check if flash attention is supported for this attention pattern check_is_flash_attention( - query, key, layout, cudnn_version, bias is not None, False, - get_max_seg_per_batch(q_offsets) > 1) + query, key, value, layout, cudnn_version, bias is not None, False, + get_max_seg_per_batch(q_offsets) > 1, check_is_paged_attention(page_table_k)) outputs = _dot_product_attention_fwd_p_wrapper.bind( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - scale=scale, seed=seed, dropout_rate=dropout_rate, + page_table_k, page_table_v, scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, mask_type=mask_type, layout=layout, - sliding_window_length=sliding_window_length, is_training=False) - output = outputs[0] - return output + sliding_window_length=sliding_window_length, is_training=False or return_residual) + if return_residual: + return tuple(outputs) + else: + return outputs[0] def _dot_product_attention_fwd_rule( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - scale, seed, dropout_rate, variadic_args, mask_type, layout, - sliding_window_length, cudnn_version): + page_table_k, page_table_v, scale, seed, dropout_rate, variadic_args, + mask_type, layout, sliding_window_length, cudnn_version, + return_residual): # check if flash attention is supported for this attention pattern check_is_flash_attention( - query, key, layout, cudnn_version, bias is not None, True, + query, key, value, layout, cudnn_version, bias is not None, True, get_max_seg_per_batch(q_offsets) > 1) outputs = _dot_product_attention_fwd_p_wrapper.bind( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - scale=scale, seed=seed, dropout_rate=dropout_rate, + page_table_k, page_table_v, scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, mask_type=mask_type, layout=layout, sliding_window_length=sliding_window_length, is_training=True) res = (query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, - kv_offsets, outputs[1], outputs[0]) - return outputs[0], res + kv_offsets, page_table_k, page_table_v, outputs[1], outputs[0]) + if return_residual: + return tuple(outputs), res + else: + return outputs[0], res def _dot_product_attention_bwd_rule( scale, seed, dropout_rate, variadic_args, mask_type, layout, - sliding_window_length, is_training, res, grad_output): + sliding_window_length, is_training, return_residual, res, grad_output): (query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - activation, fwd_output) = res + page_table_k, page_table_v, activation, fwd_output) = res + if return_residual: + grad_output = grad_output[0] grads = _dot_product_attention_bwd_p_wrapper.bind( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - activation, fwd_output, grad_output, scale=scale, seed=seed, - dropout_rate=dropout_rate, variadic_args=variadic_args, + page_table_k, page_table_v, activation, fwd_output, grad_output, + scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, mask_type=mask_type, layout=layout, sliding_window_length=sliding_window_length ) - grads = (*grads,) + (None,) * (8 - len(grads)) + grads = (*grads,) + (None,) * (10 - len(grads)) return grads def _fix_seqlen_offsets(q_seqlen, kv_seqlen, q_offsets, kv_offsets, query, key): @@ -471,7 +511,7 @@ def _cu_offset(offsets, max_seq): batch = offsets.shape[0] offsets = jnp.where( offsets >= 0, - offsets + (jnp.arange(batch) * max_seq)[..., jnp.newaxis], + offsets + (jnp.arange(batch, dtype=offsets.dtype) * max_seq)[..., np.newaxis], offsets, ) return offsets @@ -480,17 +520,13 @@ def _cu_offset(offsets, max_seq): B, T, N, H = query.shape _, S, _, _ = key.shape - q_seqlen = _shift_to_left(q_seqlen, -1) - kv_seqlen = _shift_to_left(kv_seqlen, -1) + q_seqlen = _shift_to_left(q_seqlen, 0) + kv_seqlen = _shift_to_left(kv_seqlen, 0) q_offsets = _cu_offset(q_offsets, T) kv_offsets = _cu_offset(kv_offsets, S) - q_offsets = _shift_to_left(q_offsets, -1) - kv_offsets = _shift_to_left(kv_offsets, -1) - - # mark any invalid entries as maximum offset - q_offsets = jnp.where(q_offsets < 0, B * T, q_offsets) - kv_offsets = jnp.where(kv_offsets < 0, B * S, kv_offsets) + q_offsets = _shift_to_left(q_offsets, B * T) + kv_offsets = _shift_to_left(kv_offsets, B * S) # multiply by stride_per_token to get correct offsets # do it here because real stride changes after sharding @@ -501,27 +537,28 @@ def _cu_offset(offsets, max_seq): def _dot_product_attention_fwd_impl( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - scale, seed, dropout_rate, variadic_args, mask_type, layout, - sliding_window_length, is_training): + page_table_k, page_table_v, scale, seed, dropout_rate, variadic_args, + mask_type, layout, sliding_window_length, is_training): # args: {Q, K, V, mask*, bias*} q_seqlen, kv_seqlen, q_offsets, kv_offsets = \ _fix_seqlen_offsets(q_seqlen, kv_seqlen, q_offsets, kv_offsets, query, key) outputs = _dot_product_attention_fwd_p.bind( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - scale=scale, seed=seed, dropout_rate=dropout_rate, + page_table_k, page_table_v, scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, mask_type=mask_type, layout=layout, sliding_window_length=sliding_window_length, is_training=is_training) return outputs def _dot_product_attention_bwd_impl( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - activation, fwd_output, grad_output, scale, seed, dropout_rate, - variadic_args, mask_type, layout, sliding_window_length): + page_table_k, page_table_v, activation, fwd_output, grad_output, scale, + seed, dropout_rate, variadic_args, mask_type, layout, sliding_window_length): q_seqlen, kv_seqlen, q_offsets, kv_offsets = \ _fix_seqlen_offsets(q_seqlen, kv_seqlen, q_offsets, kv_offsets, query, key) grads = _dot_product_attention_bwd_p.bind( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - activation, fwd_output, grad_output, scale=scale, seed=seed, + page_table_k, page_table_v, activation, fwd_output, grad_output, + scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, mask_type=mask_type, layout=layout, sliding_window_length=sliding_window_length) @@ -529,91 +566,88 @@ def _dot_product_attention_bwd_impl( def _dot_product_attention_fwd_abstract( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - *, scale, seed, dropout_rate, variadic_args, mask_type, layout, - sliding_window_length, is_training): - query_dtype = dtypes.canonicalize_dtype(query.dtype) + page_table_k, page_table_v, *, scale, seed, dropout_rate, variadic_args, + mask_type, layout, sliding_window_length, is_training): if layout == AttentionLayout.BNTH.value: B, N, T, _ = query.shape - _, _, S, _ = key.shape + _, _, S, H = value.shape + output_shape = (B, N, T, H) else: B, T, N, _ = query.shape - _, S, _, _ = key.shape - output_shape = query.shape + _, S, _, H = value.shape + output_shape = (B, T, N, H) max_seg_per_batch = get_max_seg_per_batch(q_offsets) softmax_stat_shape = (B * max_seg_per_batch, N, T) if is_training: return ( - core.ShapedArray(output_shape, query_dtype), # output - core.ShapedArray(softmax_stat_shape, jnp.float32), # softmax_stat + core.ShapedArray(output_shape, query.dtype), # output + core.ShapedArray(softmax_stat_shape, np.float32), # softmax_stat ) else: return ( - core.ShapedArray(output_shape, query_dtype), # output + core.ShapedArray(output_shape, query.dtype), # output ) def _dot_product_attention_bwd_abstract( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - activation, fwd_output, grad_output, *, scale, seed, dropout_rate, - variadic_args, mask_type, layout, sliding_window_length): - query_dtype = dtypes.canonicalize_dtype(query.dtype) - key_dtype = dtypes.canonicalize_dtype(key.dtype) - value_dtype = dtypes.canonicalize_dtype(value.dtype) - + page_table_k, page_table_v, activation, fwd_output, grad_output, *, + scale, seed, dropout_rate, variadic_args, mask_type, layout, sliding_window_length): _, has_dbias = variadic_args if has_dbias: # cuDNN supports bias for this case - bias_dtype = dtypes.canonicalize_dtype(bias.dtype) return ( core.ShapedArray( - query.shape, query_dtype + query.shape, query.dtype ), # grad query core.ShapedArray( - key.shape, key_dtype + key.shape, key.dtype ), # grad key core.ShapedArray( - value.shape, value_dtype + value.shape, value.dtype ), # grad value core.ShapedArray( - bias.shape, bias_dtype + bias.shape, bias.dtype ), # grad bias ) else: return ( core.ShapedArray( - query.shape, query_dtype + query.shape, query.dtype ), # grad query core.ShapedArray( - key.shape, key_dtype + key.shape, key.dtype ), # grad key core.ShapedArray( - value.shape, value_dtype + value.shape, value.dtype ), # grad value ) def _dot_product_attention_fwd_cuda_lowering( ctx, query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, - kv_offsets, scale, seed, dropout_rate, variadic_args, mask_type, - layout, sliding_window_length, is_training): + kv_offsets, page_table_k, page_table_v, scale, seed, dropout_rate, + variadic_args, mask_type, layout, sliding_window_length, is_training): query_type = ir.RankedTensorType(query.type) query_shape = query_type.shape - key_type = ir.RankedTensorType(key.type) - key_shape = key_type.shape + value_type = ir.RankedTensorType(value.type) + value_shape = value_type.shape if layout == AttentionLayout.BNTH.value: - B, N, T, H = query_shape - _, _, S, _ = key_shape + B, N, T, qk_H = query_shape + _, _, S, v_H = value_shape output_layout = (3, 2, 1, 0) output_transpose_perm = mlir.dense_int_array((0, 1, 2, 3)) else: - B, T, N, H = query_shape - _, S, _, _ = key_shape + B, T, N, qk_H = query_shape + _, S, _, v_H = value_shape output_layout = (3, 1, 2, 0) output_transpose_perm = mlir.dense_int_array((0, 2, 1, 3)) max_seg_per_batch = get_max_seg_per_batch(ir.RankedTensorType(q_offsets.type)) - output_shape = (B, N, T, H) + is_paged_attention = check_is_paged_attention(ir.RankedTensorType(page_table_k.type)) + + output_shape = (B, N, T, v_H) softmax_stat_shape = (B * max_seg_per_batch, N, T) workspace_shape = (0,) workspace_type = ir.IntegerType.get_unsigned(8) @@ -622,19 +656,22 @@ def _dot_product_attention_fwd_cuda_lowering( backend_config = create_dot_product_attention_backend_config( B, N, T, S, query_type.element_type, scale, seed, dropout_rate, mask_type, layout, sliding_window_length, max_seg_per_batch, - is_bwd=False) + is_paged_attention, is_bwd=False) # {Q, K, V, bias*, q_seqlen*, kv_seqlen*, q_offsets*, kv_offsets*}} # {output, activation*, workspace} has_dropout = dropout_rate > 0 operands = [query, key, value] if has_bias: operands.append(bias) - if has_padding(mask_type) or max_seg_per_batch > 1: + if has_padding(mask_type) or max_seg_per_batch > 1 or is_paged_attention: operands.append(q_seqlen) operands.append(kv_seqlen) if max_seg_per_batch > 1: operands.append(q_offsets) operands.append(kv_offsets) + if is_paged_attention: + operands.append(page_table_k) + operands.append(page_table_v) custom_call_name = get_custom_call_name(has_bias, has_dropout, False) @@ -670,38 +707,38 @@ def _dot_product_attention_fwd_cuda_lowering( def _dot_product_attention_bwd_cuda_lowering( ctx, query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - activation, fwd_output, grad_output, scale, seed, dropout_rate, - variadic_args, mask_type, layout, sliding_window_length): + page_table_k, page_table_v, activation, fwd_output, grad_output, + scale, seed, dropout_rate, variadic_args, mask_type, layout, sliding_window_length): query_type = ir.RankedTensorType(query.type) query_shape = query_type.shape key_type = ir.RankedTensorType(key.type) - key_shape = key_type.shape value_type = ir.RankedTensorType(value.type) + value_shape = value_type.shape if layout == AttentionLayout.BNTH.value: - B, q_N, T, H = query_shape - _, k_N, S, _ = key_shape + B, q_N, T, qk_H = query_shape + _, v_N, S, v_H = value_shape grad_layout = (3, 2, 1, 0) grad_transpose_perm = mlir.dense_int_array((0, 1, 2, 3)) else: - B, T, q_N, H = query_shape - _, S, k_N, _ = key_shape + B, T, q_N, qk_H = query_shape + _, S, v_N, v_H = value_shape grad_layout = (3, 1, 2, 0) grad_transpose_perm = mlir.dense_int_array((0, 2, 1, 3)) workspace_shape = (0,) workspace_type = ir.IntegerType.get_unsigned(8) - grad_query_shape = (B, q_N, T, H) - grad_key_shape = (B, k_N, S, H) - grad_value_shape = (B, k_N, S, H) + grad_query_shape = (B, q_N, T, qk_H) + grad_key_shape = (B, v_N, S, qk_H) + grad_value_shape = (B, v_N, S, v_H) has_bias, has_dbias = variadic_args max_seg_per_batch = get_max_seg_per_batch(ir.RankedTensorType(q_offsets.type)) backend_config = create_dot_product_attention_backend_config( B, q_N, T, S, query_type.element_type, scale, seed, dropout_rate, mask_type, layout, sliding_window_length, max_seg_per_batch, - is_bwd=True) + False, is_bwd=True) # {Q, K, V, activation, dO, bias*, O, q_seqlen*, kv_seqlen*, # q_offsets*, kv_offsets*} # {dQ, dK, dV, dbias*, workspace} @@ -769,7 +806,7 @@ def _dot_product_attention_fwd_batcher( mask_type, layout, sliding_window_length, is_training): _check_valid_batch_dims(batch_dims) query, key, value, bias, q_seqlen, kv_seqlen, \ - q_offsets, kv_offsets = batched_args + q_offsets, kv_offsets, page_table_k, page_table_v = batched_args query_bdim = batch_dims[0] if is_training: out_bdims = query_bdim, query_bdim @@ -797,7 +834,7 @@ def _dot_product_attention_fwd_batcher( outputs = _dot_product_attention_fwd_p_wrapper.bind( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - scale=scale, seed=seed, dropout_rate=dropout_rate, + page_table_k, page_table_v, scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, mask_type=mask_type, layout=layout, sliding_window_length=sliding_window_length, is_training=is_training) @@ -816,7 +853,7 @@ def _dot_product_attention_bwd_batcher( mask_type, layout, sliding_window_length): _check_valid_batch_dims(batch_dims) query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, \ - activation, fwd_output, grad_output = batched_args + page_table_k, page_table_v, activation, fwd_output, grad_output = batched_args query_bdim = batch_dims[0] out_bdims = query_bdim, query_bdim, query_bdim @@ -828,11 +865,6 @@ def _dot_product_attention_bwd_batcher( *_, S, _, _ = key.shape B = math.prod(Bs) has_bias, has_dbias = variadic_args - # Reset the has_dbias if the combined batch size is not 1, because cuDNN only - # supports dbias with a single batch. In this case, an all-zero dbias will be - # appended instead. - if B > 1: - variadic_args = (has_bias, False) original_query_shape = query.shape original_key_shape = key.shape original_value_shape = value.shape @@ -853,8 +885,8 @@ def _dot_product_attention_bwd_batcher( grads = _dot_product_attention_bwd_p_wrapper.bind( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - activation, fwd_output, grad_output, scale=scale, seed=seed, - dropout_rate=dropout_rate, variadic_args=variadic_args, + page_table_k, page_table_v, activation, fwd_output, grad_output, + scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, mask_type=mask_type, layout=layout, sliding_window_length=sliding_window_length, ) @@ -906,7 +938,7 @@ def _check_qkv_bias_mask_spec( # fwd custom partition -def _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args,is_training, layout): +def _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args, is_training, layout): # only sharding on batch and num_head dim is allowed # (*batch, q_seq, num_head, head) query_spec = _get_padded_spec(arg_shapes[0]) @@ -922,20 +954,54 @@ def _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args,is_training, layo out_sharding = NamedSharding(mesh, PartitionSpec(*query_spec)) if is_training: # activation sharding - *batch_spec, q_seq_spec, num_head_spec, _ = query_spec + if layout == AttentionLayout.BNTH.value: + *batch_spec, num_head_spec, q_seq_spec, _ = query_spec + else: + *batch_spec, q_seq_spec, num_head_spec, _ = query_spec activation_sharding = NamedSharding( mesh, PartitionSpec(*batch_spec, num_head_spec, q_seq_spec, None)) return [out_sharding, activation_sharding] return [out_sharding] +def _fwd_shardy_rule(value_types, result_types, layout, is_training, is_fp8): + num_args = len(value_types) + # We only need the query and value sharding, so use placeholders for the remaining args. + input_sharding = [ArrayMapping(f'{BATCHING}{n}') for n in range(num_args)] + if layout == AttentionLayout.BNTH.value: + input_sharding[0] = ArrayMapping('batch', 'nhead', 'qseq', 'head') + else: + input_sharding[0] = ArrayMapping('batch', 'qseq', 'nhead', 'head') + input_sharding[2] += ('v',) + + # The major dimensions are sharded like the query, the minor like the value. + output_sharding = (ArrayMapping(*input_sharding[0][:-1], 'v'),) + if is_fp8: + # `amax` is a scalar. + amax = ArrayMapping(f'{BATCHING}{num_args}') + output_sharding += (amax, amax) + factor_sizes = {} + if is_training: + # Activation sharding. + if result_types[-1].shape[0] == value_types[0].shape[0]: + output_sharding += (ArrayMapping('batch', 'nhead', 'qseq'),) + else: + factor_sizes['n'] = result_types[-1].shape[0] // value_types[0].shape[0] + output_sharding += (ArrayMapping(CompoundFactor('batch', 'n'), 'nhead', 'qseq'),) + return SdyShardingRule(tuple(input_sharding), output_sharding, **factor_sizes) + _dot_product_attention_fwd_lower = custom_partitioning( - _dot_product_attention_fwd_impl, static_argnums=(8, 9, 10, 11, 12, 13, 14, 15)) + _dot_product_attention_fwd_impl, static_argnums=(10, 11, 12, 13, 14, 15, 16, 17)) def _dot_product_attention_fwd_infer_sharding_from_operands( scale, seed, dropout_rate, variadic_args, mask_type, layout, sliding_window_length, is_training, mesh, arg_shapes, result_shape): return _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args, is_training, layout) +def _dot_product_attention_fwd_shardy_rule( + scale, seed, dropout_rate, variadic_args, mask_type, layout, sliding_window_length, + is_training, mesh, value_types, result_types): + return _fwd_shardy_rule(value_types, result_types, layout, is_training, is_fp8=False) + def _dot_product_attention_fwd_partition( scale, seed, dropout_rate, variadic_args, mask_type, layout, sliding_window_length, is_training, mesh, arg_shapes, result_shape): @@ -977,8 +1043,20 @@ def _infer_bwd_output_sharding(mesh, arg_shapes, layout, variadic_args): out_shardings = out_shardings + [grad_bias_sharding] return out_shardings +def _bwd_shardy_rule(num_args, has_dbias, is_fp8): + input_sharding = tuple(ArrayMapping(f'{BATCHING}{n}') for n in range(num_args)) + + if has_dbias: + output_sharding = input_sharding[0:4] + else: + output_sharding = input_sharding[0:3] + if is_fp8: + amax = ArrayMapping(f'{BATCHING}{num_args}') + output_sharding += (amax, amax, amax, amax) + return SdyShardingRule(input_sharding, output_sharding) + _dot_product_attention_bwd_lower = custom_partitioning( - _dot_product_attention_bwd_impl, static_argnums=(11, 12, 13, 14, 15, 16, 17) + _dot_product_attention_bwd_impl, static_argnums=(13, 14, 15, 16, 17, 18, 19) ) def _dot_product_attention_bwd_infer_sharding_from_operands( @@ -986,6 +1064,12 @@ def _dot_product_attention_bwd_infer_sharding_from_operands( sliding_window_length, mesh, arg_shapes, result_shape): return _infer_bwd_output_sharding(mesh, arg_shapes, layout, variadic_args) +def _dot_product_attention_bwd_shardy_rule( + scale, seed, dropout_rate, variadic_args, + mask_type, layout, sliding_window_length, mesh, value_types, result_types): + _, has_dbias = variadic_args + return _bwd_shardy_rule(len(value_types), has_dbias, is_fp8=False) + def _dot_product_attention_bwd_partition( scale, seed, dropout_rate, variadic_args, mask_type, layout, sliding_window_length, mesh, arg_shapes, result_shape): @@ -1007,10 +1091,21 @@ def sharded_impl(*args): _, has_dbias = variadic_args if has_dbias: query_spec = arg_shardings[0].spec - batch_spec = query_spec[0] - local_dbias = grads[3] - global_dbias = jax.lax.psum(local_dbias, batch_spec) - grads = grads[:3] + [global_dbias] + bias_spec = arg_shardings[3].spec + if layout == AttentionLayout.BNTH.value: + q_batch_spec, q_num_head_spec, _, _ = query_spec + else: + q_batch_spec, _, q_num_head_spec, _ = query_spec + b_batch_spec, b_num_head_spec, _, _ = bias_spec + + dbias = grads[3] + if q_batch_spec is not None and b_batch_spec is None: + # bias is replicated alone batch dim + dbias = lax_parallel.psum(dbias, q_batch_spec) + if q_num_head_spec is not None and b_num_head_spec is None: + # bias is replicated alone num_head dim + dbias = lax_parallel.psum(dbias, q_num_head_spec) + grads = grads[:3] + [dbias] return grads return mesh, sharded_impl, out_shardings, arg_shardings @@ -1018,7 +1113,7 @@ def sharded_impl(*args): _dot_product_attention_fwd_p = core.Primitive("dot_product_attention_fwd") _dot_product_attention_fwd_p.multiple_results = True _dot_product_attention_fwd_p.def_impl( - functools.partial(xla.apply_primitive, _dot_product_attention_fwd_p) + functools.partial(dispatch.apply_primitive, _dot_product_attention_fwd_p) ) _dot_product_attention_fwd_p.def_abstract_eval( _dot_product_attention_fwd_abstract @@ -1043,7 +1138,7 @@ def sharded_impl(*args): _dot_product_attention_bwd_p = core.Primitive("dot_product_attention_bwd") _dot_product_attention_bwd_p.multiple_results = True _dot_product_attention_bwd_p.def_impl( - functools.partial(xla.apply_primitive, _dot_product_attention_bwd_p) + functools.partial(dispatch.apply_primitive, _dot_product_attention_bwd_p) ) _dot_product_attention_bwd_p.def_abstract_eval( _dot_product_attention_bwd_abstract @@ -1073,14 +1168,16 @@ def sharded_impl(*args): _dot_product_attention_fwd_lower.def_partition( infer_sharding_from_operands=_dot_product_attention_fwd_infer_sharding_from_operands, - partition=_dot_product_attention_fwd_partition) + partition=_dot_product_attention_fwd_partition, + sharding_rule=_dot_product_attention_fwd_shardy_rule) mlir.register_lowering(_dot_product_attention_fwd_p_wrapper, mlir.lower_fun(_dot_product_attention_fwd_lower, multiple_results=True)) _dot_product_attention_bwd_lower.def_partition( infer_sharding_from_operands=_dot_product_attention_bwd_infer_sharding_from_operands, - partition=_dot_product_attention_bwd_partition) + partition=_dot_product_attention_bwd_partition, + sharding_rule=_dot_product_attention_bwd_shardy_rule) mlir.register_lowering(_dot_product_attention_bwd_p_wrapper, mlir.lower_fun(_dot_product_attention_bwd_lower, multiple_results=True)) @@ -1098,7 +1195,7 @@ def sharded_impl(*args): _dot_product_attention_bwd_p_wrapper ) -@functools.partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12, 13, 14, 15)) +@functools.partial(custom_derivatives.custom_vjp, nondiff_argnums=(10, 11, 12, 13, 14, 15, 16, 17, 18)) def _dot_product_attention(query: Array, key: Array, value: Array, @@ -1107,6 +1204,8 @@ def _dot_product_attention(query: Array, kv_seqlen: Array, q_offsets: Array, kv_offsets: Array, + page_table_k: Array, + page_table_v: Array, scale: float, seed: int, dropout_rate: float, @@ -1114,13 +1213,14 @@ def _dot_product_attention(query: Array, mask_type: bool, layout: int, sliding_window_length: int | None, - cudnn_version: int): + cudnn_version: int, + return_residual: bool): output = _dot_product_attention_fwd( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - scale=scale, seed=seed, dropout_rate=dropout_rate, + page_table_k, page_table_v, scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, mask_type=mask_type, layout=layout, sliding_window_length=sliding_window_length, - cudnn_version=cudnn_version) + cudnn_version=cudnn_version, return_residual=return_residual) return output _dot_product_attention.defvjp( @@ -1161,7 +1261,7 @@ def _dot_product_attention_fp8_fwd( fp8_params_fwd, scale, use_causal_mask, layout, cudnn_version): check_is_flash_attention_fp8( - query, key, layout, cudnn_version, is_training=False) + query, key, value, layout, cudnn_version, is_training=False) descale_q, descale_k, descale_v, descale_s, scale_s, scale_o = fp8_params_fwd outputs = _dot_product_attention_fp8_fwd_p_wrapper.bind( query, key, value, @@ -1175,7 +1275,7 @@ def _dot_product_attention_fp8_fwd_rule( fp8_params, scale, use_causal_mask, layout, cudnn_version): check_is_flash_attention_fp8( - query, key, layout, cudnn_version, is_training=True) + query, key, value, layout, cudnn_version, is_training=True) outputs = _dot_product_attention_fp8_fwd_p_wrapper.bind( query, key, value, *params_from_keys(fp8_params, fp8_params_keys_fwd), @@ -1246,7 +1346,6 @@ def _dot_product_attention_fp8_fwd_abstract( query, key, value, descale_q, descale_k, descale_v, descale_s, scale_s, scale_o, scale, use_causal_mask, layout, is_training): - query_dtype = dtypes.canonicalize_dtype(query.dtype) if layout == AttentionLayout.BNTH.value: B, N, T, _ = query.shape _, _, S, _ = key.shape @@ -1259,16 +1358,16 @@ def _dot_product_attention_fp8_fwd_abstract( # output, amax_s, amax_o[, softmax_stat] if is_training: return ( - core.ShapedArray(output_shape, query_dtype), - core.ShapedArray((1,1,1,1), jnp.float32), - core.ShapedArray((1,1,1,1), jnp.float32), - core.ShapedArray(softmax_stat_shape, jnp.float32), + core.ShapedArray(output_shape, query.dtype), + core.ShapedArray((1,1,1,1), np.float32), + core.ShapedArray((1,1,1,1), np.float32), + core.ShapedArray(softmax_stat_shape, np.float32), ) else: return ( - core.ShapedArray(output_shape, query_dtype), - core.ShapedArray((1,1,1,1), jnp.float32), - core.ShapedArray((1,1,1,1), jnp.float32), + core.ShapedArray(output_shape, query.dtype), + core.ShapedArray((1,1,1,1), np.float32), + core.ShapedArray((1,1,1,1), np.float32), ) def _dot_product_attention_fp8_bwd_abstract( @@ -1276,20 +1375,15 @@ def _dot_product_attention_fp8_bwd_abstract( descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, scale, use_causal_mask, layout): - query_dtype = dtypes.canonicalize_dtype(query.dtype) - key_dtype = dtypes.canonicalize_dtype(key.dtype) - value_dtype = dtypes.canonicalize_dtype(value.dtype) - amax_shape = (1,1,1,1) - return ( - core.ShapedArray(query.shape, query_dtype), - core.ShapedArray(key.shape, key_dtype), - core.ShapedArray(value.shape, value_dtype), - core.ShapedArray(amax_shape, jnp.float32), - core.ShapedArray(amax_shape, jnp.float32), - core.ShapedArray(amax_shape, jnp.float32), - core.ShapedArray(amax_shape, jnp.float32), + core.ShapedArray(query.shape, query.dtype), + core.ShapedArray(key.shape, key.dtype), + core.ShapedArray(value.shape, value.dtype), + core.ShapedArray(amax_shape, np.float32), + core.ShapedArray(amax_shape, np.float32), + core.ShapedArray(amax_shape, np.float32), + core.ShapedArray(amax_shape, np.float32), ) def _dot_product_attention_fp8_fwd_cuda_lowering( @@ -1562,6 +1656,11 @@ def _dot_product_attention_fp8_fwd_partition( layout=layout, is_training=is_training) return mesh, impl, out_shardings, arg_shardings +def _dot_product_attention_fp8_fwd_shardy_rule( + scale, use_causal_mask, layout, is_training, + mesh, value_types, result_types): + return _fwd_shardy_rule(value_types, result_types, layout, is_training, is_fp8=True) + def _infer_fp8_bwd_output_sharding(mesh, arg_shapes, layout): # Prepare variadic_args for the original function has_bias = False # Adjust as needed @@ -1588,6 +1687,10 @@ def _dot_product_attention_fp8_bwd_infer_sharding_from_operands( arg_shapes, result_shape): return _infer_fp8_bwd_output_sharding(mesh, arg_shapes, layout) +def _dot_product_attention_fp8_bwd_shardy_rule( + scale, use_causal_mask, layout, mesh, value_types, result_types): + return _bwd_shardy_rule(len(value_types), has_dbias=False, is_fp8=True) + def _dot_product_attention_fp8_bwd_partition( scale, use_causal_mask, layout, mesh, arg_shapes, result_shape): @@ -1604,7 +1707,7 @@ def _dot_product_attention_fp8_bwd_partition( _dot_product_attention_fp8_fwd_p = core.Primitive("dot_product_attention_fp8_fwd") _dot_product_attention_fp8_fwd_p.multiple_results = True _dot_product_attention_fp8_fwd_p.def_impl( - functools.partial(xla.apply_primitive, _dot_product_attention_fp8_fwd_p) + functools.partial(dispatch.apply_primitive, _dot_product_attention_fp8_fwd_p) ) _dot_product_attention_fp8_fwd_p.def_abstract_eval( _dot_product_attention_fp8_fwd_abstract @@ -1629,7 +1732,7 @@ def _dot_product_attention_fp8_bwd_partition( _dot_product_attention_fp8_bwd_p = core.Primitive("dot_product_attention_fp8_bwd") _dot_product_attention_fp8_bwd_p.multiple_results = True _dot_product_attention_fp8_bwd_p.def_impl( - functools.partial(xla.apply_primitive, _dot_product_attention_fp8_bwd_p) + functools.partial(dispatch.apply_primitive, _dot_product_attention_fp8_bwd_p) ) _dot_product_attention_fp8_bwd_p.def_abstract_eval( _dot_product_attention_fp8_bwd_abstract @@ -1659,14 +1762,16 @@ def _dot_product_attention_fp8_bwd_partition( _dot_product_attention_fp8_fwd_lower.def_partition( infer_sharding_from_operands=_dot_product_attention_fp8_fwd_infer_sharding_from_operands, - partition=_dot_product_attention_fp8_fwd_partition) + partition=_dot_product_attention_fp8_fwd_partition, + sharding_rule=_dot_product_attention_fp8_fwd_shardy_rule) mlir.register_lowering(_dot_product_attention_fp8_fwd_p_wrapper, mlir.lower_fun(_dot_product_attention_fp8_fwd_lower, multiple_results=True)) _dot_product_attention_fp8_bwd_lower.def_partition( infer_sharding_from_operands=_dot_product_attention_fp8_bwd_infer_sharding_from_operands, - partition=_dot_product_attention_fp8_bwd_partition) + partition=_dot_product_attention_fp8_bwd_partition, + sharding_rule=_dot_product_attention_fp8_bwd_shardy_rule) mlir.register_lowering(_dot_product_attention_fp8_bwd_p_wrapper, mlir.lower_fun(_dot_product_attention_fp8_bwd_lower, multiple_results=True)) @@ -1684,7 +1789,7 @@ def _dot_product_attention_fp8_bwd_partition( _dot_product_attention_fp8_bwd_p_wrapper ) -@functools.partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7)) +@functools.partial(custom_derivatives.custom_vjp, nondiff_argnums=(4, 5, 6, 7)) def _dot_product_attention_fp8(query: Array, key: Array, value: Array, @@ -1701,7 +1806,119 @@ def _dot_product_attention_fp8(query: Array, _dot_product_attention_fp8.defvjp(_dot_product_attention_fp8_fwd_rule, _dot_product_attention_fp8_bwd_rule) +def combine_bias_and_mask(bias, mask, dtype): + if bias is not None: + # reshape bias to have 4D shape + bias = bias.reshape((1,) * (4 - len(bias.shape)) + bias.shape) + + if mask is not None: + if mask.dtype == np.dtype('bool'): + large_negative_number = get_large_negative_number(dtype) + mask = jnp.where(mask, jnp.asarray(0, dtype), large_negative_number) + # reshape mask to have 4D shape + mask = mask.reshape((1,) * (4 - len(mask.shape)) + mask.shape) # type: ignore[union-attr] + + # combine bias and mask + if bias is None: + bias = mask + else: + if mask is not None: + # should be broadcast to same shape + bias = bias + mask + return bias + # User interface +def paged_attention( + query: Array, + key: Array, + value: Array, + q_seqlen: Array, + kv_seqlen: Array, + page_table_k: Array, + page_table_v: Array, + bias: Array | None = None, + mask: Array | None = None, + fp8_params: FP8Params | None = None, + *, + scale: float = 1.0, + mask_type: MaskType = MaskType.NO_MASK, + seed: int = 42, + dropout_rate: float = 0., + qkv_layout: str = "BTNH", + sliding_window_length: int | None = None, + use_fp8: bool = False, + return_residual: bool = False +): + """Computes paged attention described in https://arxiv.org/pdf/2309.06180. + + B = batch size + S = length of the key/value (source) + T = length of the query (target) + N = number of attention heads + H = dimensions of each attention head. + + Args: + query: Queries for attention calculation with a shape of BTNH or BNTH. + key: Keys for attention calculation with a shape of + [num_blocks, block_size, N, H] or [num_blocks, N, block_size, H] where + num_blocks = B * Ceil(S / block_size). + value: Values to be used in attention with a shape of + [num_blocks, block_size, N, H] or [num_blocks, N, block_size, H] where + num_blocks = B * Ceil(S / block_size). + q_seqlen: Non padded sequence length of query with a shape of B. + kv_seqlen: Non padded sequence length of key and value with a shape of B. + page_table_k: page table for key of shape [B, 1, num_blocks_per_batch, 1] + where num_blocks_per_batch = Ceil(S / block_size). + page_table_v: page table for value of shape [B, 1, num_blocks_per_batch, 1] + where num_blocks_per_batch = Ceil(S / block_size). + bias: Bias to be added to logits with a shape of BNTS. + mask: Mask used to filter out logits with a shape of BNTS. + scale: Scale for the query. + qkv_layout: Layout string, with supported formats being BTNH, BNTH, BSNH, + BNSH. + sliding_window_length: Window size to make attention only attend to each + token's left local window (pos - sliding_window_length, pos] where `pos` + is the index of each token. E.g., if sliding_window_length == 3 and the + sequence is [0, 1, 2, 3, c, 4, 5], token `c` can attend to [4, 5, c]. + use_fp8: Whether to use FP8 attention mechanism. + return_residual: Whether to return the logsumexp tensor of shape BTN + or BNT to users. See section 3.1.1 in the FlashAttention-2 paper: + https://arxiv.org/pdf/2307.08691 to find the definition of logsumexp. + Returns: + output: the same shape as the query. + residual: the logsumexp tensor if return_residual=True. (non fp8) + """ + cudnn_version = check_cudnn_version() + layout = _normalize_layout(qkv_layout) + if use_fp8: + raise ValueError("Paged attention doesn't support fp8 for now.") + if has_padding(mask_type) and (q_seqlen is None or kv_seqlen is None): + raise ValueError("Require q_seqlen and kv_seqlen to generate padding mask.") + if sliding_window_length is not None and sliding_window_length <= 0: + raise ValueError( + f"Require sliding_window_length > 0, got {sliding_window_length}.") + + bias = combine_bias_and_mask(bias, mask, query.dtype) + # check if input shape and data type is compatiable + check_layout(query, key, value, bias, q_seqlen, kv_seqlen, None, None, + page_table_k, page_table_v, layout) + has_bias = bias is not None + has_dbias = has_bias and \ + should_export_dbias(bias.shape, query.shape, layout) # type: ignore[union-attr] + variadic_args = (has_bias, has_dbias) + + _not_used = jnp.zeros(0, dtype=query.dtype) + if bias is None: + bias = _not_used + + output = _dot_product_attention( + query, key, value, bias, q_seqlen, kv_seqlen, _not_used, _not_used, + page_table_k, page_table_v, scale, seed, dropout_rate, variadic_args, + mask_type, layout.value, sliding_window_length, cudnn_version, + return_residual) + return output + + def dot_product_attention( query: Array, key: Array, @@ -1720,7 +1937,8 @@ def dot_product_attention( dropout_rate: float = 0., qkv_layout: str = "BTNH", sliding_window_length: int | None = None, - use_fp8: bool = False + use_fp8: bool = False, + return_residual: bool = False ): """Computes dot-product attention given query (Q), key (K), and value (V). @@ -1776,8 +1994,12 @@ def dot_product_attention( is the index of each token. E.g., if sliding_window_length == 3 and the sequence is [0, 1, 2, 3, c, 4, 5], token `c` can attend to [4, 5, c]. use_fp8: Whether to use FP8 attention mechanism. + return_residual: Whether to return the logsumexp tensor of shape BTN + or BNT to users. See section 3.1.1 in the FlashAttention-2 paper: + https://arxiv.org/pdf/2307.08691 to find the definition of logsumexp. Returns: - Output of the same shape as the query. + output: the same shape as the query. + residual: the logsumexp tensor if return_residual=True. (non fp8) amax_s: amax of state. (fp8 only) amax_o: amax of output. (fp8 only) """ @@ -1797,7 +2019,8 @@ def dot_product_attention( f"but got: bias={bias}, mask={mask}, q_seqlen={q_seqlen}, kv_seqlen={kv_seqlen}" ) check_fp8_params(fp8_params) - check_layout(query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, layout) + check_layout(query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, + None, None, layout) output, amax_s, amax_o = _dot_product_attention_fp8( query, key, value, fp8_params, scale, mask_type == MaskType.CAUSAL, layout.value, cudnn_version @@ -1812,44 +2035,30 @@ def dot_product_attention( if q_offsets is not None and (q_seqlen is None or kv_seqlen is None): raise ValueError("Require q_seqlen and kv_seqlen to use packed layout") - if bias is not None: - # reshape bias to have 4D shape - bias = bias.reshape((1,) * (4 - len(bias.shape)) + bias.shape) - - if mask is not None: - if mask.dtype == jnp.bool: - large_negative_number = get_large_negative_number(query.dtype) - mask = jnp.where(mask, jnp.asarray(0, query.dtype), large_negative_number) - # reshape mask to have 4D shape - mask = mask.reshape((1,) * (4 - len(mask.shape)) + mask.shape) # type: ignore[union-attr] - - # combine bias and mask - if bias is None: - bias = mask - else: - if mask is not None: - # should be broadcast to same shape - bias = bias + mask - + bias = combine_bias_and_mask(bias, mask, query.dtype) # check if input shape and data type is compatiable - check_layout(query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, layout) + check_layout(query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, + None, None, layout) has_bias = bias is not None has_dbias = has_bias and \ should_export_dbias(bias.shape, query.shape, layout) # type: ignore[union-attr] variadic_args = (has_bias, has_dbias) + _not_used = jnp.zeros(0, dtype=query.dtype) if bias is None: - bias = jnp.zeros(0, dtype=query.dtype) + bias = _not_used if q_seqlen is None: - q_seqlen = jnp.zeros(0, dtype=query.dtype) + q_seqlen = _not_used if kv_seqlen is None: - kv_seqlen = jnp.zeros(0, dtype=query.dtype) + kv_seqlen = _not_used if q_offsets is None: - q_offsets = jnp.zeros(0, dtype=query.dtype) + q_offsets = _not_used if kv_offsets is None: - kv_offsets = jnp.zeros(0, dtype=query.dtype) + kv_offsets = _not_used + output = _dot_product_attention( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - scale, seed, dropout_rate, variadic_args, mask_type, layout.value, - sliding_window_length, cudnn_version) + _not_used, _not_used, scale, seed, dropout_rate, variadic_args, + mask_type, layout.value, sliding_window_length, cudnn_version, + return_residual) return output diff --git a/jax/_src/cudnn/fusion.py b/jax/_src/cudnn/fusion.py index f320672463cb..3029953c3550 100644 --- a/jax/_src/cudnn/fusion.py +++ b/jax/_src/cudnn/fusion.py @@ -13,12 +13,13 @@ # limitations under the License. import functools -import jax -from jax._src import core as jax_core -from jax.interpreters import mlir -from jax.interpreters.mlir import hlo -from jax.interpreters.mlir import ir +from jax._src import api +from jax._src import core as jax_core +from jax._src import tree_util +from jax._src.interpreters import mlir +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import hlo def _cudnn_fusion_impl(*args, jaxpr, **unused_kwargs): @@ -41,21 +42,16 @@ def _custom_abstract_eval(*args, jaxpr, **unused_kwargs): def call_cudnn_fusion(f, *args, **kwargs): """Creates a new cudnn_fusion corresponding to calling the given function f with args and kwargs.""" - jaxpr, out_shapes = jax.make_jaxpr( + jaxpr, out_shapes = api.make_jaxpr( functools.partial(f, **kwargs), return_shape=True )(*args) - flat_args = jax.tree.leaves(args) - out_tree = jax.tree.structure(out_shapes) + flat_args = tree_util.tree_leaves(args) + out_tree = tree_util.tree_structure(out_shapes) out_flat = cudnn_fusion_p.bind(*flat_args, name=f.__name__, jaxpr=jaxpr) - return jax.tree.unflatten(out_tree, out_flat) + return tree_util.tree_unflatten(out_tree, out_flat) -def _cudnn_fusion_stablehlo_lowering( - ctx, - *args, - name, - jaxpr, -): +def _cudnn_fusion_stablehlo_lowering(ctx, *args, name, jaxpr): """Make cudnn_fusion which calls the implementation function. Currently this leaks a CallOp since we're using the `core_call_lowering` function, but this should get cleaned up by DCE easily. diff --git a/jax/_src/cudnn/scaled_matmul_stablehlo.py b/jax/_src/cudnn/scaled_matmul_stablehlo.py index 1a8dee293082..b6b80748fb24 100644 --- a/jax/_src/cudnn/scaled_matmul_stablehlo.py +++ b/jax/_src/cudnn/scaled_matmul_stablehlo.py @@ -16,25 +16,27 @@ import json import operator from functools import partial, reduce -from typing import List -# Third-party imports -import jax -import jax.numpy as jnp import numpy as np -from jax import custom_vjp, lax -from jax._src import core, dispatch, dtypes + +from jax._src import api +from jax._src import core +from jax._src import dispatch +from jax._src import dtypes +from jax._src import numpy as jnp +from jax._src import tree_util +from jax._src.custom_derivatives import custom_vjp from jax._src.custom_partitioning import custom_partitioning from jax._src.interpreters import batching +from jax._src.interpreters import mlir +from jax._src.interpreters.mlir import ir +from jax._src.lax import lax +from jax._src.lax import parallel as lax_parallel from jax._src.lax.lax import ranges_like, remaining -from jax._src.typing import DTypeLike -from jax.interpreters import mlir, xla -from jax.interpreters.mlir import ir -from jax.sharding import NamedSharding -from jax.sharding import PartitionSpec as P +from jax._src.typing import Array, DTypeLike +from jax._src.sharding_impls import NamedSharding, PartitionSpec as P -Array = jnp.ndarray block_scaled_dot_name = "__op$block_scaled_dot" @dataclass @@ -64,7 +66,7 @@ def _scaled_matmul_impl(a, b, a_scale, b_scale, preferred_element_type): ) -def _scaled_matmul_cuda_lowering( +def _scaled_matmul_gpu_lowering( ctx, a, b, a_scales, b_scales, preferred_element_type ): lhs_type = ir.RankedTensorType(a.type) @@ -103,7 +105,6 @@ def _scaled_matmul_cuda_lowering( def _scaled_matmul_abstract(a, b, a_scale, b_scale, *, preferred_element_type): - a_dtype = dtypes.canonicalize_dtype(a.dtype) batch, non_contracting_lhs, contracting_lhs = a.shape _, non_contracting_rhs, _ = b.shape output_shape = (batch, non_contracting_lhs, non_contracting_rhs) @@ -112,15 +113,20 @@ def _scaled_matmul_abstract(a, b, a_scale, b_scale, *, preferred_element_type): _scaled_matmul_p = core.Primitive("scaled_matmul") _scaled_matmul_p.multiple_results = True -_scaled_matmul_p.def_impl(partial(xla.apply_primitive, _scaled_matmul_p)) +dispatch.simple_impl(_scaled_matmul_p) _scaled_matmul_p.def_abstract_eval(_scaled_matmul_abstract) mlir.register_lowering( _scaled_matmul_p, - _scaled_matmul_cuda_lowering, + _scaled_matmul_gpu_lowering, platform="cuda", ) +mlir.register_lowering( + _scaled_matmul_p, + _scaled_matmul_gpu_lowering, + platform="rocm", +) _scaled_matmul_p_wrapper = core.Primitive("scaled_matmul_wrapper") _scaled_matmul_p_wrapper.multiple_results = True @@ -159,7 +165,7 @@ def _check_shardings(shardings): def _enable_reduce_scatter(lhs, rhs): - batch_spec, m_spec, lhs_k_spec = lhs.spec + _, m_spec, lhs_k_spec = lhs.spec _, n_spec, rhs_k_spec = rhs.spec return ( lhs_k_spec != None @@ -170,12 +176,18 @@ def _enable_reduce_scatter(lhs, rhs): def _enable_all_reduce(lhs, rhs): - batch_spec, m_spec, lhs_k_spec = lhs.spec + _, _, lhs_k_spec = lhs.spec _, n_spec, rhs_k_spec = rhs.spec return lhs_k_spec != None and lhs_k_spec == rhs_k_spec and n_spec == None +def _are_specs_overlapping(lhs, rhs): + if lhs is None or rhs is None: + return False + lhs = (lhs,) if isinstance(lhs, str) else lhs + rhs = (rhs,) if isinstance(rhs, str) else rhs + return not set(lhs).isdisjoint(rhs) -def _get_output_sharding(mesh, shardings): +def _get_output_sharding(shardings): lhs, rhs = shardings[0], shardings[1] batch_spec, m_spec, _ = lhs.spec _, n_spec, _ = rhs.spec @@ -184,71 +196,107 @@ def _get_output_sharding(mesh, shardings): return [NamedSharding(lhs.mesh, P(*lhs.spec))] output_specs = (batch_spec, m_spec) - output_specs += (n_spec,) if m_spec != n_spec else (None,) + # If the m and n specs are overlapping, we cannot keep both - + # we (arbitrarily) pick m and replicate for n. + output_specs += (None,) if _are_specs_overlapping(m_spec, n_spec) else (n_spec,) return [NamedSharding(lhs.mesh, P(*output_specs))] def _scaled_matmul_infer_sharding_from_operands( preferred_element_type, mesh, shapes, output_shape ): - shardings = jax.tree.map(lambda x: x.sharding, shapes) + shardings = tree_util.tree_map(lambda x: x.sharding, shapes) _check_shardings(shardings) - return _get_output_sharding(mesh, shardings) + return _get_output_sharding(shardings) -def supported_in_sharding(mesh, shardings): - lhs_sharding, rhs_sharding = shardings[0], shardings[1] - use_reduce_scatter = _enable_reduce_scatter(lhs_sharding, rhs_sharding) - use_all_reduce = _enable_all_reduce(lhs_sharding, rhs_sharding) - assert not (use_all_reduce and use_reduce_scatter) - - lhs_specs, rhs_specs = list(lhs_sharding.spec), list(rhs_sharding.spec) +# If one of the non contracting dimensions of the output (M or N) is sharded +# alone the same axes as the contracting dimension (K), returns that output +# dimension along which to perform reduce-scatter. Otherwise, returns None. +def _get_reduce_scatter_dim(lhs, rhs, output): + _, _, lhs_k_spec = lhs.spec + _, _, rhs_k_spec = rhs.spec + _, out_m_spec, out_n_spec = output.spec - def named_sharding(lhs, rhs, lhs_specs, rhs_specs): - lhs_sharding = NamedSharding(lhs.mesh, P(*lhs_specs)) - rhs_sharding = NamedSharding(rhs.mesh, P(*rhs_specs)) - return (lhs_sharding, rhs_sharding, lhs_sharding, rhs_sharding) + if lhs_k_spec == None or lhs_k_spec != rhs_k_spec: + return None - if use_all_reduce: - return named_sharding(lhs_sharding, rhs_sharding, lhs_specs, rhs_specs) + if out_m_spec == lhs_k_spec: + return 1 + if out_n_spec == lhs_k_spec: + return 2 + return None - if use_reduce_scatter: - rhs_specs[1] = None - return named_sharding(lhs_sharding, rhs_sharding, lhs_specs, rhs_specs) - lhs_specs[2] = None - rhs_specs[2] = None - m_spec, n_spec = lhs_specs[1], rhs_specs[1] - if m_spec == n_spec: - rhs_specs[1] = None - - return named_sharding(lhs_sharding, rhs_sharding, lhs_specs, rhs_specs) +def _supported_in_out_sharding(lhs_sharding, rhs_sharding, out_sharding, reduce_scatter_dim): + use_all_reduce = _enable_all_reduce(lhs_sharding, rhs_sharding) + batch_spec, m_spec, k_spec = lhs_sharding.spec + batch_spec_rhs, n_spec, _ = rhs_sharding.spec + + # This is checked by the caller, assert here for documentation. + assert batch_spec == batch_spec_rhs + + def named_sharding(lhs_specs, rhs_specs, out_specs): + lhs = NamedSharding(lhs_sharding.mesh, P(*lhs_specs)) + rhs = NamedSharding(rhs_sharding.mesh, P(*rhs_specs)) + out = NamedSharding(lhs_sharding.mesh, P(*out_specs)) + return ((lhs, rhs, lhs, rhs), [out]) + + if reduce_scatter_dim == 1: + lhs_specs = (batch_spec, None, k_spec) + rhs_specs = (batch_spec, n_spec, k_spec) + out_specs = (batch_spec, k_spec, n_spec) + return named_sharding(lhs_specs, rhs_specs, out_specs) + + if reduce_scatter_dim == 2: + lhs_specs = (batch_spec, m_spec, k_spec) + rhs_specs = (batch_spec, None, k_spec) + out_specs = (batch_spec, m_spec, k_spec) + return named_sharding(lhs_specs, rhs_specs, out_specs) + + if not use_all_reduce: + k_spec = None + + if _are_specs_overlapping(m_spec, n_spec): + # We have m and n specs that share an axis, so we can't keep both. + # Let us keep the one that was inferred in the output. + if n_spec == out_sharding.spec[2]: + # Output has n spec, so we get rid of m. + m_spec = None + else: + # Otherwise, we get rid of n. + n_spec = None + + lhs_specs = (batch_spec, m_spec, k_spec) + rhs_specs = (batch_spec, n_spec, k_spec) + out_specs = (batch_spec, m_spec, n_spec) + return named_sharding(lhs_specs, rhs_specs, out_specs) def _scaled_matmul_partition( preferred_element_type, mesh, shapes, output_shape ): - shardings = jax.tree.map(lambda x: x.sharding, shapes) + shardings = tree_util.tree_map(lambda x: x.sharding, shapes) _check_shardings(shardings) lhs, rhs = shardings[0], shardings[1] + out = output_shape[0].sharding use_all_reduce = _enable_all_reduce(lhs, rhs) - use_reduce_scatter = _enable_reduce_scatter(lhs, rhs) + reduce_scatter_dim = _get_reduce_scatter_dim(lhs, rhs, out) lhs_k_spec = lhs.spec[2] def _scaled_matmul_impl_partition(a, b, a_scale, b_scale): z = _scaled_matmul_impl(a, b, a_scale, b_scale, preferred_element_type) - if use_reduce_scatter: - z = jax.lax.psum_scatter( - z, lhs_k_spec, scatter_dimension=2, tiled=True - ) - if use_all_reduce: - z = jax.lax.psum(z, lhs_k_spec) + if reduce_scatter_dim is not None: + z = lax_parallel.psum_scatter( + z, lhs_k_spec, scatter_dimension=reduce_scatter_dim, tiled=True + ) + elif use_all_reduce: + z = lax_parallel.psum(z, lhs_k_spec) return z - out_shardings = _get_output_sharding(mesh, shardings) - arg_shardings = supported_in_sharding(mesh, shardings) + arg_shardings, out_shardings = _supported_in_out_sharding(lhs, rhs, out, reduce_scatter_dim) return mesh, _scaled_matmul_impl_partition, out_shardings, arg_shardings @@ -259,6 +307,7 @@ def _scaled_matmul_impl_partition(a, b, a_scale, b_scale): _scaled_matmul_lower.def_partition( infer_sharding_from_operands=_scaled_matmul_infer_sharding_from_operands, partition=_scaled_matmul_partition, + sharding_rule='b m k, b n k, b m k, b n k -> b m n', ) @@ -311,13 +360,13 @@ def _scaled_matmul_batcher(batched_args, batch_dims, *, preferred_element_type): batching.primitive_batchers[_scaled_matmul_p] = _scaled_matmul_batcher -@partial(jax.jit, static_argnames=("preferred_element_type",)) +@api.jit(static_argnames=("preferred_element_type",)) def _scaled_matmul( lhs: Array, rhs: Array, lhs_scales: Array, rhs_scales: Array, - preferred_element_type: DTypeLike = jnp.float32, + preferred_element_type: DTypeLike = np.dtype('float32'), ) -> Array: output = _scaled_matmul_p_wrapper.bind( lhs, rhs, lhs_scales, rhs_scales, @@ -330,7 +379,7 @@ def scaled_matmul_wrapper( rhs: Array, lhs_scales: Array, rhs_scales: Array, - preferred_element_type: DTypeLike = jnp.float32, + preferred_element_type: DTypeLike = np.dtype('float32'), ) -> Array: """ Performs scaled matrix multiplication between two 3D arrays, with scaling @@ -364,9 +413,8 @@ def scaled_matmul_wrapper( assert lhs_K == rhs_K _, _, K_block = lhs_scales.shape - preferred_element_type = dtypes.canonicalize_dtype( - np.dtype(preferred_element_type) - ) + preferred_element_type = dtypes.check_and_canonicalize_user_dtype( + preferred_element_type, "scaled_matmul_wrapper") out = _scaled_matmul( lhs, @@ -451,7 +499,7 @@ def compute_dot_output_shape( def cast_to_e8m0_with_rounding_up(x): - temp = x.astype(jnp.float32).view(jnp.uint32) + temp = x.astype(np.float32).view(np.uint32) exp = temp >> 23 mant = temp & 0x7FFFFF is_ru = jnp.logical_and( @@ -459,17 +507,17 @@ def cast_to_e8m0_with_rounding_up(x): ~jnp.logical_and((exp == 0), (mant <= 0x400000)) ) exp = jnp.where(is_ru, exp + 1, exp) - new_x = exp.astype(jnp.uint8) + new_x = exp.astype(np.uint8) return new_x def e8m0_to_dtype(x, dtype): - temp = x.astype(jnp.uint32) + temp = x.astype(np.uint32) exp = temp << 23 - new_x = exp.view(jnp.float32) - near_zero_value = 2**-15 if dtype == jnp.float16 else 2**-127 + new_x = exp.view(np.float32) + near_zero_value = 2**-15 if dtype == np.float16 else 2**-127 new_x = jnp.where( - new_x == 0, jnp.array(near_zero_value, jnp.float32), new_x + new_x == 0, jnp.array(near_zero_value, np.float32), new_x ) return new_x.astype(dtype) @@ -480,25 +528,31 @@ def quantize(x, config): assert contract_dim >= block_size and contract_dim % block_size == 0 x_new_shape = x_shape[:-1] + (x_shape[-1] // block_size, block_size) x = x.reshape(x_new_shape) # shape = (B, M, K / block_size, block_size) + MAX = dtypes.finfo(config.data_type).max.astype(x.dtype) - amax = jnp.max(jnp.abs(x), axis=-1, keepdims=True) - MAX = jnp.finfo(config.data_type).max.astype(x.dtype) - scales = amax / MAX # shape = (B, M, K / block_size, 1) + def get_scales_per_block(values): + # shape = (B, M, K / block_size, 1) + return jnp.max(jnp.abs(values), axis=-1, keepdims=True) / MAX if config.mode == "mxfp8": - assert config.scale_type == jnp.float8_e8m0fnu - scales_q = cast_to_e8m0_with_rounding_up(scales) - scaled_x = x / e8m0_to_dtype(scales_q, scales.dtype) + assert config.global_scale is None + assert config.scale_type == dtypes.float8_e8m0fnu + + scales_q = cast_to_e8m0_with_rounding_up(get_scales_per_block(x)) + scaled_x = x / e8m0_to_dtype(scales_q, x.dtype) elif config.mode == "nvfp4": - assert config.scale_type == jnp.float8_e4m3fn - assert config.global_scale.dtype == jnp.float32 + assert config.scale_type == dtypes.float8_e4m3fn + assert config.global_scale.dtype == np.float32 - scales = scales / config.global_scale - scales_q = jax.lax.optimization_barrier(scales.astype(jnp.float8_e4m3fn)) - scaled_x = x / (scales_q.astype(jnp.float32) * - config.global_scale).astype(x.dtype) + SCALE_MAX = dtypes.finfo(config.scale_type).max.astype(x.dtype) + + x /= config.global_scale + scales_q = jnp.clip(get_scales_per_block(x), 0, SCALE_MAX) + scales_q = lax.optimization_barrier(scales_q.astype(config.scale_type)) + scaled_x = x / scales_q.astype(np.float32) else: raise ValueError(f"Unrecognized mode: {config.mode}.") + clipped_x = jnp.clip(scaled_x, -MAX, MAX) x_q = clipped_x.astype(config.data_type) @@ -515,9 +569,8 @@ def scaled_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type, lhs, rhs, return_weak_type_flag=False ) else: - preferred_element_type = dtypes.canonicalize_dtype( - np.dtype(preferred_element_type) - ) + preferred_element_type = dtypes.check_and_canonicalize_user_dtype( + preferred_element_type, "scaled_dot_impl") (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers lhs_dn = (lhs_contract, lhs_batch) @@ -531,7 +584,7 @@ def scaled_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type, out_dtype = preferred_element_type if configs[0].mode == 'nvfp4': - out_dtype = jnp.float32 + out_dtype = np.float32 out = scaled_matmul_wrapper( lhs_q, rhs_q, lhs_scales, rhs_scales, preferred_element_type=out_dtype @@ -590,7 +643,7 @@ def scaled_dot_general_transpose_lhs( def scaled_dot_general_transpose_rhs( g, x, y, *, dimension_numbers, preferred_element_type: DTypeLike, - configs: List[BlockScaleConfig] + configs: list[BlockScaleConfig] ): (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch)) @@ -639,6 +692,17 @@ def scaled_dot_bwd(dimension_numbers, preferred_element_type, configs, res, g): } grad_lhs = scaled_dot_general_transpose_lhs(*args, **lhs_kw_args) grad_rhs = scaled_dot_general_transpose_rhs(*args, **rhs_kw_args) + + # We apply a Straight-Through Estimator (STE) with zero-out behavior: if + # inputs are clipped during quantization in fprop, their corresponding gradients + # are zeroed out; otherwise, they pass through unchanged. + if configs[2].mode == "nvfp4": + assert rhs.dtype == lhs.dtype + MAX = dtypes.finfo(configs[0].data_type).max.astype(lhs.dtype) + SCALE_MAX = dtypes.finfo(configs[0].scale_type).max.astype(lhs.dtype) + grad_lhs = jnp.where(jnp.abs(lhs) <= configs[0].global_scale * MAX * SCALE_MAX, grad_lhs, 0) + grad_rhs = jnp.where(jnp.abs(rhs) <= configs[1].global_scale * MAX * SCALE_MAX, grad_rhs, 0) + return (grad_lhs, grad_rhs) @@ -673,10 +737,10 @@ def _ensure_batch_dim(lhs, rhs, dimension_numbers): def scaled_dot_general_wrapper( lhs, rhs, dimension_numbers, - preferred_element_type=jnp.float32, - configs: List[BlockScaleConfig] | None=None, + preferred_element_type=np.float32, + configs: list[BlockScaleConfig] | None=None, ): - if preferred_element_type not in (jnp.float32, jnp.bfloat16, jnp.float16): + if preferred_element_type not in (np.dtype('float32'), np.dtype('bfloat16'), np.dtype('float16')): msg = ('Only support preferred_element_type in (f32, bf16, f16), but got ' '{preferred_element_type}') raise TypeError(msg) @@ -684,8 +748,8 @@ def scaled_dot_general_wrapper( mxfp8_config = BlockScaleConfig( mode='mxfp8', block_size=32, - data_type=jnp.float8_e4m3fn, - scale_type=jnp.float8_e8m0fnu, + data_type=dtypes.float8_e4m3fn, + scale_type=dtypes.float8_e8m0fnu, global_scale=None, infer_only=False ) diff --git a/jax/_src/custom_batching.py b/jax/_src/custom_batching.py index 338074837ea5..e394e8273260 100644 --- a/jax/_src/custom_batching.py +++ b/jax/_src/custom_batching.py @@ -19,7 +19,6 @@ import functools import operator -from jax import lax from jax._src import api from jax._src import core from jax._src import custom_api_util @@ -34,7 +33,7 @@ from jax._src.interpreters.batching import not_mapped from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe -from jax._src.interpreters import xla +from jax._src.interpreters import pxla from jax._src.tree_util import (tree_flatten, tree_map, tree_structure, tree_unflatten, treedef_tuple) @@ -103,7 +102,7 @@ class custom_vmap: >>> jax.grad(f)(jnp.zeros(()), jnp.ones(())) Array(1., dtype=float32) - Note that the :py:class:`jax.custom_vjp` must be on the ouside, wrapping the + Note that the :py:class:`jax.custom_vjp` must be on the outside, wrapping the ``custom_vmap``-decorated function. """ @@ -161,7 +160,7 @@ def __call__(self, *args, **kwargs): lu.wrap_init(self.fun, debug_info=debug_fun), in_tree) in_avals = [core.get_aval(x) for x in args_flat] - jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals) + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals) closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ()) in_tree = treedef_tuple((tree_structure(consts), in_tree)) assert self.vmap_rule is not None @@ -261,7 +260,8 @@ def custom_vmap_batching(args_flat, dims, *, call, rule, in_tree, out_tree): def custom_vmap_abstract_eval(*in_avals, call, **_): - return call.out_avals + del in_avals + return call.out_avals, call.effects def custom_vmap_jvp(primals, tangents, *, @@ -296,7 +296,7 @@ def to_jvp(*primals): out_mutually_batched.store(out_batched) return out - api_util.save_wrapped_fun_sourceinfo(to_jvp, call.jaxpr.debug_info) + api_util.save_wrapped_fun_debug_info(to_jvp, call.jaxpr.debug_info) def to_vmap_over_extra_batched_dims(primals, tangents): return api.jvp(to_jvp, primals, tangents) @@ -348,10 +348,10 @@ def to_vmap_over_extra_batched_dims(primals, tangents): custom_vmap_p = core.Primitive('custom_vmap_call') custom_vmap_p.multiple_results = True custom_vmap_p.def_impl(custom_vmap_impl) -custom_vmap_p.def_abstract_eval(custom_vmap_abstract_eval) +custom_vmap_p.def_effectful_abstract_eval(custom_vmap_abstract_eval) batching.primitive_batchers[custom_vmap_p] = custom_vmap_batching ad.primitive_jvps[custom_vmap_p] = custom_vmap_jvp -xla.register_initial_style_primitive(custom_vmap_p) +pxla.register_initial_style_primitive(custom_vmap_p) mlir.register_lowering(custom_vmap_p, mlir.lower_fun( custom_vmap_impl, multiple_results=True)) @@ -394,6 +394,8 @@ def sequential_vmap(f): See the documentation for :py:class:`~jax.custom_batching.custom_vmap` for more details. """ + from jax._src.lax import control_flow # pytype: disable=import-error + f = custom_vmap(f) @f.def_vmap @@ -405,7 +407,7 @@ def to_map(mapped_args): return f(*args) mapped_args, bcast_args = tree_split(in_batched, list(args)) - out = lax.map(to_map, mapped_args) + out = control_flow.map(to_map, mapped_args) out_batched = tree_map(lambda _: True, out) return out, out_batched diff --git a/jax/_src/custom_dce.py b/jax/_src/custom_dce.py index d336c969a3c4..2321241b15b4 100644 --- a/jax/_src/custom_dce.py +++ b/jax/_src/custom_dce.py @@ -183,7 +183,7 @@ def dce_jaxpr_thunk( out_avals, ) assert self.dce_rule is not None - dce_jaxpr, _, dce_consts, () = pe.trace_to_jaxpr_dynamic( + dce_jaxpr, _, dce_consts = pe.trace_to_jaxpr_dynamic( flat_rule, in_avals ) @@ -199,7 +199,7 @@ def dce_jaxpr_thunk( return core.ClosedJaxpr(dce_jaxpr, dce_consts), used_ins - jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals) + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals) closed_call = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr)) out_avals = closed_call.out_avals out_flat = custom_dce_p.bind( @@ -251,9 +251,9 @@ def flatten_dce_rule( # For error checking purposes, we need to reformat the pytree structure # of the output of the DCE rule to match the original output. The catch is # that the DCE rule can return a None to indicated an unused subtree, so we - # need to rebuild those subtrees with a sentinal value at the leaves. This + # need to rebuild those subtrees with a sentinel value at the leaves. This # logic is very similar to what is used in custom_dervatives._flatten_bwd. - sentinal = object() + sentinel = object() dummy = tree_util.tree_unflatten(out_tree, [object()] * out_tree.num_leaves) keypaths, _ = util.unzip2(tree_util.tree_flatten_with_path(dummy)[0]) out_flat = [] @@ -261,7 +261,7 @@ def flatten_dce_rule( def append(x, d): num_leaves = len(tree_util.tree_flatten(d)[0]) if x is None and d is not None: - out_flat.extend([sentinal] * num_leaves) + out_flat.extend([sentinel] * num_leaves) elif x is not None: out_flat.extend([x] * num_leaves) return x @@ -281,7 +281,7 @@ def append(x, d): for kp, used, aval, val in zip(keypaths, used_outs, out_avals, out_flat): if not used: continue - if val is sentinal: + if val is sentinel: raise ValueError( f"Custom DCE rule {rule_name} for function {fun_name} must produce " "values for all of the required outputs (as specified by the " diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 32856106ad8f..94528250d67d 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -31,20 +31,20 @@ stop_gradient_p, SymbolicZero, Zero, zeros_like_aval) from jax._src.api_util import ( argnums_partial, flatten_fun_nokwargs, resolve_kwargs, - prepend_static_args, debug_info) + prepend_static_args, debug_info, fun_signature, + infer_argnums_and_argnames) from jax._src.errors import UnexpectedTracerError from jax._src.state.types import AbstractRef from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe -from jax._src.interpreters import xla +from jax._src.interpreters import pxla from jax._src.interpreters.batching import not_mapped -from jax._src.lax import lax from jax._src.tree_util import ( tree_flatten, tree_unflatten, tree_map, treedef_is_leaf, treedef_tuple, register_pytree_node_class, tree_leaves, tree_flatten_with_path, - tree_leaves_with_path, keystr, treedef_children, PyTreeDef) + tree_leaves_with_path, keystr, treedef_children, tree_structure, PyTreeDef) from jax._src.util import (cache, safe_zip, safe_map, split_list, unzip2, weakref_lru_cache) @@ -60,7 +60,7 @@ def _initial_style_jaxpr(fun: lu.WrappedFun, in_avals: Sequence[core.AbstractValue] ) -> tuple[core.Jaxpr, Sequence[Any]]: - jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun, in_avals) + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals) return jaxpr, consts def _close_jaxpr(jaxpr: core.Jaxpr) -> core.ClosedJaxpr: @@ -87,7 +87,7 @@ def _flatten_fun_nokwargs(f: Callable, ans = f(*py_args) ans_flat, ans_tree = tree_flatten(ans) ans_avals = [core.get_aval(x) for x in ans_flat] - store.store((ans_tree, ans_avals)) + store.store((ans_tree, ans_avals, ())) return ans_flat @@ -130,20 +130,35 @@ def f_jvp(primals, tangents): For a more detailed introduction, see the tutorial_. - .. _tutorial: https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html + .. _tutorial: https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html """ fun: Callable[..., ReturnValue] nondiff_argnums: Sequence[int] + nondiff_argnames: Sequence[str] jvp: Callable[..., tuple[ReturnValue, ReturnValue]] | None = None symbolic_zeros: bool = False def __init__(self, fun: Callable[..., ReturnValue], nondiff_argnums: Sequence[int] = (), + nondiff_argnames: Sequence[str] = (), ): update_wrapper(self, fun) self.fun = fun - self.nondiff_argnums = nondiff_argnums + + nondiff_argnums_: set[int] = set() + if nondiff_argnames: + sig = fun_signature(self.fun) + assert sig is not None + inferred_nondiff_argnums, _ = infer_argnums_and_argnames( + sig, None, nondiff_argnames + ) + nondiff_argnums_.update(inferred_nondiff_argnums) + + if nondiff_argnums: + nondiff_argnums_.update(nondiff_argnums) + + self.nondiff_argnums = tuple(sorted(nondiff_argnums_)) __getattr__ = custom_api_util.forward_attr @@ -241,7 +256,8 @@ def jvp(primals, tangents): self.defjvp(jvp) - @traceback_util.api_boundary + @partial(traceback_util.api_boundary, + repro_api_name="jax.custom_jvp.__call__") def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable=invalid-annotation debug = debug_info("custom_jvp fun", self.fun, args, kwargs, static_argnums=self.nondiff_argnums) @@ -260,10 +276,9 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable ) from e if self.nondiff_argnums: - nondiff_argnums = set(self.nondiff_argnums) - args = tuple(_stop_gradient(x) if i in nondiff_argnums else x + args = tuple(_stop_gradient(x) if i in self.nondiff_argnums else x for i, x in enumerate(args)) - diff_argnums = [i for i in range(len(args)) if i not in nondiff_argnums] + diff_argnums = [i for i in range(len(args)) if i not in self.nondiff_argnums] f_, dyn_args = argnums_partial(lu.wrap_init(self.fun, debug_info=debug), diff_argnums, args, require_static_args_hashable=False) @@ -287,7 +302,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable in_tree, out_type1) out_flat = custom_jvp_call_p.bind(flat_fun, flat_jvp, *args_flat, symbolic_zeros=self.symbolic_zeros) - _, (out_tree, _) = lu.merge_linear_aux(out_type1, out_type2) + _, (out_tree, _, _) = lu.merge_linear_aux(out_type1, out_type2) return tree_unflatten(out_tree, out_flat) @partial(lu.transformation_with_aux2, use_eq_store=True) @@ -314,7 +329,7 @@ def _flatten_jvp(f, store, primal_name, jvp_name, in_tree, maybe_out_type, *args try: out_type_ = maybe_out_type() except lu.StoreException: out_type_ = None if out_type_ is not None: - out_tree_, primal_avals_ = out_type_ + out_tree_, primal_avals_, () = out_type_ ty_tree = tree_unflatten(out_tree , [a.str_short() for a in primal_avals]) ty_tree_ = tree_unflatten(out_tree_, [a.str_short() for a in primal_avals_]) if out_tree_ != out_tree: @@ -351,7 +366,7 @@ def _flatten_jvp(f, store, primal_name, jvp_name, in_tree, maybe_out_type, *args tangent_avals_out = [core.get_aval(t).strip_weak_type() if type(t) is not SymbolicZero else t.aval.strip_weak_type() for t in tangents_out] - if expected_tangent_avals_out != tangent_avals_out: + if not all(map(core.typematch, expected_tangent_avals_out, tangent_avals_out)): if len(expected_tangent_avals_out) == 1: (av_p,), (av_et,), (av_t,) = primal_avals_out, expected_tangent_avals_out, tangent_avals_out msg = ("Custom JVP rule must produce primal and tangent outputs with " @@ -364,9 +379,8 @@ def _flatten_jvp(f, store, primal_name, jvp_name, in_tree, maybe_out_type, *args f" primal {av_p.str_short()} with tangent {av_t.str_short()}, expecting tangent {av_et}" for av_p, av_et, av_t in zip(primal_avals_out, expected_tangent_avals_out, tangent_avals_out) if av_et != av_t) - raise TypeError(msg.format('\n'.join(disagreements))) - store.store((out_tree, primal_avals)) + store.store((out_tree, primal_avals, ())) return primals_out + tangents_out class CustomJVPCallPrimitive(core.Primitive): @@ -410,8 +424,6 @@ def jvp(*xs): return [*out_primals, *out_tangents] return lu.wrap_init(jvp, debug_info=jvp_jaxpr_fun.debug_info) -effects.custom_derivatives_allowed_effects.add_type(lax.InOutFeedEffect) - custom_jvp_call_p = CustomJVPCallPrimitive('custom_jvp_call') def _custom_jvp_call_typecheck(_, *in_avals, call_jaxpr, jvp_jaxpr_fun, @@ -425,31 +437,31 @@ def _custom_jvp_call_typecheck(_, *in_avals, call_jaxpr, jvp_jaxpr_fun, return call_jaxpr.out_avals, call_jaxpr.effects core.custom_typechecks[custom_jvp_call_p] = _custom_jvp_call_typecheck -def _custom_jvp_call_mlir_translation(ctx, *args, call_jaxpr, jvp_jaxpr_fun, - num_consts, symbolic_zeros): - del jvp_jaxpr_fun, num_consts, symbolic_zeros - consts = mlir._ir_consts(call_jaxpr.consts) +def _custom_jvp_vjp_call_lowering(ctx: mlir.LoweringRuleContext, *args, + call_jaxpr: core.ClosedJaxpr, **_): + consts = mlir.ir_consts( + call_jaxpr.consts, [v.aval for v in call_jaxpr.jaxpr.constvars]) out, tokens = mlir.jaxpr_subcomp(ctx.module_context, call_jaxpr.jaxpr, ctx.name_stack, ctx.tokens_in, consts, - *args, dim_var_values=ctx.dim_var_values) + *args, dim_var_values=ctx.dim_var_values, + const_lowering=ctx.const_lowering) ctx.set_tokens_out(tokens) return out -mlir.register_lowering(custom_jvp_call_p, _custom_jvp_call_mlir_translation) +mlir.register_lowering(custom_jvp_call_p, _custom_jvp_vjp_call_lowering) -# If a (multi)linear function is defined with a custom jvp, then -# custom_jvp_call_ can appear in jaxprs to be transposed. Since it's already -# been linearized, we can drop the jvp rule. -def _custom_jvp_call_transpose(params, jaxpr, args, ct, _): +def _custom_jvp_call_transpose_fancy(params, jaxpr, args, ct, _): del params - return ad.backward_pass(jaxpr.jaxpr, None, jaxpr.consts, args, ct) -ad.primitive_transposes[custom_jvp_call_p] = _custom_jvp_call_transpose + return ad.backward_pass3(jaxpr.jaxpr, False, jaxpr.consts, args, ct) +ad.fancy_transposes[custom_jvp_call_p] = _custom_jvp_call_transpose_fancy @weakref_lru_cache def _cached_closed_call_dce_instantiate(jaxpr_: core.ClosedJaxpr, used_outputs: tuple[bool, ...] ) -> tuple[core.ClosedJaxpr, list[bool]]: jaxpr, consts = jaxpr_.jaxpr, jaxpr_.consts - new_jaxpr, used_inputs = pe.dce_jaxpr(jaxpr, used_outputs, True) + new_jaxpr, used_inputs = pe.dce_jaxpr( + jaxpr.replace(debug_info=jaxpr.debug_info.with_unknown_names()), + used_outputs, True) return core.ClosedJaxpr(new_jaxpr, consts), used_inputs def _custom_jvp_call_dce( @@ -469,7 +481,9 @@ def _custom_jvp_call_dce( @pe._memoize def dce_jvp_jaxpr_thunk(*in_zeros): jvp_jaxpr, consts, out_zeros = jvp_jaxpr_fun.call_wrapped(*in_zeros) - dce_jvp_jaxpr, _ = pe.dce_jaxpr(jvp_jaxpr, [*used_outs, *used_outs], True) + sz = eqn.params["symbolic_zeros"] + nz_used_outs = [u for u, z in zip(used_outs, out_zeros) if not z] if sz else used_outs + dce_jvp_jaxpr, _ = pe.dce_jaxpr(jvp_jaxpr, [*used_outs, *nz_used_outs], True) dce_out_zeros = [v for used, v in zip(used_outs, out_zeros) if used] return dce_jvp_jaxpr, consts, dce_out_zeros @@ -487,6 +501,21 @@ def dce_jvp_jaxpr_thunk(*in_zeros): pe.dce_rules[custom_jvp_call_p] = _custom_jvp_call_dce +def _custom_jvp_call_pp_rule(eqn: core.JaxprEqn, + context: core.JaxprPpContext, + settings: core.JaxprPpSettings) -> core.pp.Doc: + params = dict(eqn.params) + if not params["num_consts"]: + params.pop("num_consts") + params["jvp"] = params.pop("jvp_jaxpr_fun").debug_info.func_name + names = sorted(params) + params["name"] = params["call_jaxpr"].jaxpr.debug_info.func_name + return core._pp_eqn(eqn.replace(params=params), context, settings, + params=["name"] + names) + + +core.pp_eqn_rules[custom_jvp_call_p] = _custom_jvp_call_pp_rule + ### VJPs @custom_api_util.register_custom_decorator_type @@ -521,15 +550,29 @@ def f_bwd(res, g): For a more detailed introduction, see the tutorial_. - .. _tutorial: https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html + .. _tutorial: https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html """ def __init__(self, fun: Callable[..., ReturnValue], - nondiff_argnums: Sequence[int] = ()): + nondiff_argnums: Sequence[int] = (), + nondiff_argnames: Sequence[str] = ()): update_wrapper(self, fun) self.fun = fun - self.nondiff_argnums = nondiff_argnums + + nondiff_argnums_: set[int] = set() + if nondiff_argnames: + sig = fun_signature(self.fun) + assert sig is not None + inferred_nondiff_argnums, _ = infer_argnums_and_argnames( + sig, None, nondiff_argnames + ) + nondiff_argnums_.update(inferred_nondiff_argnums) + + if nondiff_argnums: + nondiff_argnums_.update(nondiff_argnums) + + self.nondiff_argnums = tuple(sorted(nondiff_argnums_)) self.fwd: Callable[..., tuple[ReturnValue, Any]] | None = None self.bwd: Callable[..., tuple[Any, ...]] | None = None self.symbolic_zeros = False @@ -635,7 +678,8 @@ def defvjp(self, raise NotImplementedError( "remat optimization for custom_vjp does not support symbolic zeros") - @traceback_util.api_boundary + @partial(traceback_util.api_boundary, + repro_api_name="jax.custom_vjp.__call__") def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable=invalid-annotation debug_fun = debug_info("custom_vjp fun", self.fun, args, kwargs, static_argnums=self.nondiff_argnums) @@ -671,8 +715,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable else: if self.nondiff_argnums: for i in self.nondiff_argnums: _check_for_tracers(args[i]) - nondiff_argnums = set(self.nondiff_argnums) - dyn_argnums = [i for i in range(len(args)) if i not in nondiff_argnums] + dyn_argnums = [i for i in range(len(args)) if i not in self.nondiff_argnums] f_, dyn_args = argnums_partial( lu.wrap_init(self.fun, debug_info=debug_fun), dyn_argnums, args, require_static_args_hashable=False) @@ -698,24 +741,27 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd, *args_flat, out_trees=out_trees, symbolic_zeros=self.symbolic_zeros) - _, (out_tree, _) = lu.merge_linear_aux(out_type, out_trees) + _, (out_tree, _, _) = lu.merge_linear_aux(out_type, out_trees) return tree_unflatten(out_tree, out_flat) @lu.transformation2 -def _check_primal_refs(f: Callable, nondiff_argnums: Sequence[int], - debug_info: core.DebugInfo, *args): - _check_for_aliased_refs(f, nondiff_argnums, debug_info, args) +def _check_primal_refs( + f: Callable, nondiff_argnums: Sequence[int], debug: core.DebugInfo, *args): + _check_for_aliased_refs(f, nondiff_argnums, debug, args) out = f(*args) - _check_for_returned_refs(f, out, 'primal') + _check_for_returned_refs(f, out, 'primal', [], 0) return out -def _check_for_aliased_refs(f: Callable, - nondiff_argnums: Sequence[int], - debug: core.DebugInfo, - args): +def _check_for_aliased_refs( + f: Callable, nondiff_argnums: Sequence[int], debug: core.DebugInfo, args): + nondiff_argnums_ = set(nondiff_argnums) + argnums = [x for i, arg in enumerate(args) + for x in [i] * tree_structure(arg).num_leaves] leaves = tree_leaves(args) refs: dict[int, int] = {} - for i, x in enumerate(leaves): + for i, (argnum, x) in enumerate(zip(argnums, leaves)): + if argnum in nondiff_argnums: continue + x = x.value if isinstance(x, CustomVJPPrimal) else x if (isinstance((a := core.get_aval(x)), AbstractRef) and (dup_idx := refs.setdefault(id(core.get_referent(x)), i)) != i): arg_names = debug.safe_arg_names(len(leaves)) @@ -725,14 +771,21 @@ def _check_for_aliased_refs(f: Callable, f"array reference of type {a.str_short()} at {arg_names[dup_idx]} and" f" {arg_names[i]}.") -def _check_for_returned_refs(f, out, kind): +def _check_for_returned_refs(f, out, kind, args, after_idx): + args = [x.value if isinstance(x, CustomVJPPrimal) else x for x in args] + ids = {id(x) for x in args if isinstance(core.get_aval(x), AbstractRef)} leaves = tree_leaves_with_path(out) - for path, leaf in leaves: + for i, (path, leaf) in enumerate(leaves): if isinstance((a := core.get_aval(leaf)), AbstractRef): loc = f' at output tree path {keystr(path)}' if path else '' - raise ValueError(f"custom_vjp {kind} function {f} returned a mutable " - f"a array reference of type {a.str_short()}{loc}, " - "but mutable array references cannot be returned.") + if i < after_idx: + raise ValueError(f"custom_vjp {kind} function {f} returned a mutable " + f"array reference of type {a.str_short()}{loc}, " + "but mutable array references cannot be returned there.") + if id(leaf) not in ids: + raise ValueError(f"custom_vjp {kind} function {f} returned a mutable " + f"array reference of type {a.str_short()}{loc} " + "that was not an argument.") @dataclasses.dataclass class CustomVJPPrimal: @@ -787,8 +840,6 @@ def _flatten_fwd(f: Callable, store: lu.EqualStore, if config.mutable_array_checks.value: _check_for_aliased_refs(f, nondiff_argnums, debug_primal, py_args) pair_out = f(*py_args) - if config.mutable_array_checks.value: - _check_for_returned_refs(f, pair_out, "fwd") if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2: msg = (f"Custom VJP fwd rule {fwd_name} for function {primal_name} " "must produce a pair (list or tuple of length two) where the first " @@ -801,12 +852,14 @@ def _flatten_fwd(f: Callable, store: lu.EqualStore, py_primals_out, res = pair_out primals_out, out_tree = tree_flatten(py_primals_out) res, res_tree = tree_flatten(res) + if config.mutable_array_checks.value: + _check_for_returned_refs(f, pair_out, "fwd", args, out_tree.num_leaves) primal_avals = [core.get_aval(x) for x in primals_out] # If the primal function already ran, check out_tree agreement. try: out_type_ = maybe_out_type() except lu.StoreException: out_type_ = None if out_type_ is not None: - out_tree_, primal_avals_ = out_type_ + out_tree_, primal_avals_, () = out_type_ ty_tree = tree_unflatten(out_tree , [a.str_short() for a in primal_avals]) ty_tree_ = tree_unflatten(out_tree_, [a.str_short() for a in primal_avals_]) if out_tree_ != out_tree: @@ -838,15 +891,21 @@ def _flatten_fwd(f: Callable, store: lu.EqualStore, "shapes/dtypes of:\n" f""" {str(ty_tree_).replace("'", "")}""") raise TypeError(m) - store.store((out_tree, res_tree)) - return (*res, *primals_out) + pruned_res, input_forwards = _filter_forwarded_inputs(res, args) # prune + store.store((out_tree, res_tree, input_forwards)) + return (*pruned_res, *primals_out) + +def _filter_forwarded_inputs(outs, ins): + idxs: dict[int, int] = {id(x): i for i, x in enumerate(ins)} + return [o for o in outs if id(o) not in idxs], [idxs.get(id(o)) for o in outs] @lu.transformation2 def _flatten_bwd(f: Callable, in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue], - out_trees: Callable[[], Sequence[PyTreeDef]], *args): - out_tree, res_tree = out_trees() + out_trees: Callable[[], tuple[PyTreeDef, PyTreeDef, list[int | None]]], + *args): + out_tree, res_tree, _ = out_trees() assert len(args) == res_tree.num_leaves + out_tree.num_leaves res, cts_out = split_list(args, [res_tree.num_leaves]) py_res = tree_unflatten(res_tree, res) @@ -886,7 +945,7 @@ def append(x, d): if ct is zero or getattr(a.to_tangent_aval(), 'dtype') == dtypes.float0: results.append(Zero(a.to_tangent_aval())) elif type(ct) is SymbolicZero: - if not core.typecompat(a.to_tangent_aval(), a_ := ct.aval): + if not core.typecompat(a.to_cotangent_aval(), a_ := ct.aval): msg = ("Custom VJP bwd rule produced a SymbolicZero with a shape/dtype " "that does not match the corresponding input tangent shape/dtype: " f"at output{keystr(kp)} the SymbolicZero had shape/dtype " @@ -897,32 +956,36 @@ def append(x, d): raise ValueError(msg) results.append(Zero(ct.aval)) else: - if (not core.typecompat(a.to_tangent_aval(), a_ := core.get_aval(ct)) - and not (_temporary_dtype_exception(a, a_) or - _temporary_shape_exception(a, a_))): + if (not config.disable_bwd_checks.value and + not core.typecompat(a.to_cotangent_aval(), a_ := core.get_aval(ct)) + and not _ref_typecompat(a.to_cotangent_aval(), a_) + and not _temporary_dtype_exception(a.to_cotangent_aval(), a_)): msg = ("Custom VJP bwd rule must produce an output with the same " - "shape/dtypes as the args tuple of the primal function, but at " + "type as the args tuple of the primal function, but at " f"output{keystr(kp)} the bwd rule produced an output of " - f"shape/dtype {a_.str_short()} corresponding " - f"to an input of shape/dtype {a.str_short()}.") + f"type {a_.str_short()} corresponding " + f"to an input of type {a.str_short()}" + f"{core.aval_mismatch_extra(a, a_)}") raise ValueError(msg) results.append(ct) return results +def _ref_typecompat(a, a_): + return (isinstance(a, AbstractRef) and + core.typecompat(a.to_cotangent_aval().inner_aval, a_)) + # TODO(mattjj): remove both these exceptions to cotangent compatibility check def _temporary_dtype_exception(a, a_) -> bool: if isinstance(a, core.ShapedArray) and isinstance(a_, core.ShapedArray): return (a.shape == a_.shape and + core.typematch(a, a_, only_shape_shd_check=True) and (dtypes.issubdtype(a_.dtype, dtypes.extended) or dtypes.issubdtype(a.dtype, dtypes.np.inexact))) return False -# TODO(mattjj): remove both these exceptions to cotangent compatibility check -def _temporary_shape_exception(a, a_) -> bool: - return config.custom_vjp_disable_shape_check.value -class CustomVJPCallPrimitive(core.CallPrimitive): - initial_style: core.Primitive +class CustomVJPCallPrimitive(core.Primitive): + multiple_results = True def bind(self, *args, **params): return self._true_bind(*args, **params) @@ -931,119 +994,83 @@ def bind_with_trace(self, trace, args, params): fun, fwd, bwd, tracers = args[0], args[1], args[2], args[3:] return trace.process_custom_vjp_call(self, fun, fwd, bwd, tracers, **params) -custom_vjp_call_p = CustomVJPCallPrimitive('custom_vjp_call') + def impl(self, fun, fwd, bwd, *args): + raise NotImplementedError -def _custom_vjp_call_jaxpr_impl(*args, fun_jaxpr, **_): - return core.jaxpr_as_fun(fun_jaxpr)(*args) + def get_bind_params(self, params): + new_params = dict(params) + call_jaxpr: core.ClosedJaxpr = new_params.pop('call_jaxpr') + num_consts: int = new_params.pop('num_consts') + fwd_jaxpr_thunk = new_params.pop('fwd_jaxpr_thunk') + fun = lu.wrap_init(core.jaxpr_as_fun(call_jaxpr), + debug_info=call_jaxpr.jaxpr.debug_info) + fwd = lift_fwd(num_consts, fwd_jaxpr_thunk) + const_avals, _ = split_list(call_jaxpr.in_avals, [num_consts]) + bwd = _handle_consts_in_bwd(new_params.pop('bwd'), const_avals) + return [fun, fwd, bwd], new_params + +def lift_fwd(num_consts: int, fwd_jaxpr_thunk: lu.WrappedFun) -> lu.WrappedFun: + def fwd(*args): + vals, nonzeros = args[::2], args[1::2] + assert len(vals) == len(nonzeros) + _, primals = split_list(vals, [num_consts]) + const_nonzeros, in_nonzeros = split_list(nonzeros, [num_consts]) + if any(const_nonzeros): raise ad.CustomVJPException() + fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk.call_wrapped(*in_nonzeros) + return core.eval_jaxpr(fwd_jaxpr, fwd_consts, *primals) + return lu.wrap_init(fwd, debug_info=fwd_jaxpr_thunk.debug_info) + +@lu.transformation2 +def _handle_consts_in_bwd(f, const_avals, *args): + return [Zero(a) for a in const_avals] + list(f(*args)) -def _custom_vjp_call_jaxpr_abstract_eval(*_, fun_jaxpr, **__): - disallowed_effects = effects.custom_derivatives_allowed_effects.filter_not_in(fun_jaxpr.effects) +custom_vjp_call_p = CustomVJPCallPrimitive('custom_vjp_call') +# TODO(phawkins,mattjj): make this primitive cacheable. +mlir.register_lowering(custom_vjp_call_p, _custom_jvp_vjp_call_lowering, + cacheable=False) + +def _custom_vjp_call_typecheck(_, *in_avals, call_jaxpr, **kwargs): + del in_avals, kwargs + disallowed_effects = effects.custom_derivatives_allowed_effects.filter_not_in( + call_jaxpr.effects) if disallowed_effects: raise NotImplementedError( f'Effects not supported in `custom_vjp`: {disallowed_effects}') - return fun_jaxpr.out_avals, fun_jaxpr.effects - -custom_vjp_call_jaxpr_p = core.Primitive('custom_vjp_call_jaxpr') -custom_vjp_call_jaxpr_p.multiple_results = True -custom_vjp_call_jaxpr_p.def_impl(_custom_vjp_call_jaxpr_impl) -custom_vjp_call_jaxpr_p.def_effectful_abstract_eval(_custom_vjp_call_jaxpr_abstract_eval) -CustomVJPCallPrimitive.initial_style = custom_vjp_call_jaxpr_p - -mlir.register_lowering(custom_vjp_call_jaxpr_p, mlir.lower_fun( - _custom_vjp_call_jaxpr_impl, multiple_results=True)) - -def _custom_vjp_call_jaxpr_jvp( - primals, tangents, *, fun_jaxpr: core.ClosedJaxpr, - fwd_jaxpr_thunk: Callable[..., tuple[core.Jaxpr, Sequence[Any]]], - num_consts: int, bwd: lu.WrappedFun, - out_trees: Callable[[], Sequence[PyTreeDef]], - symbolic_zeros: bool): - _, args = split_list(primals, [num_consts]) - consts_dot, args_dot = split_list(tangents, [num_consts]) - if any(type(t) is not Zero for t in consts_dot): - raise ad.CustomVJPException() - zeros = [type(t) is not Zero for t in args_dot] - fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk(*zeros) # consts can be tracers! - _, res_tree = out_trees() - res_and_primals_out = core.eval_jaxpr(fwd_jaxpr, fwd_consts, *args) - res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves]) - avals_out = [core.get_aval(x).to_tangent_aval() for x in primals_out] - args_dot = map(ad.instantiate_zeros, args_dot) - tangents_out = ad.custom_lin_p.bind( - *res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd, - out_avals=avals_out, symbolic_zeros=symbolic_zeros) - tangents_out = map(lax.tie_p.bind, primals_out, tangents_out) - return primals_out, tangents_out -ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp - -def _custom_vjp_call_jaxpr_vmap( - axis_data, args, in_dims, *, - fun_jaxpr: core.ClosedJaxpr, - fwd_jaxpr_thunk: Callable[..., tuple[core.Jaxpr, Sequence[Any]]], - num_consts: int, bwd: lu.WrappedFun, - out_trees: Callable, symbolic_zeros: bool): - args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0 - else x for x, d in zip(args, in_dims)] - in_batched = [d is not not_mapped for d in in_dims] - _, args_batched = split_list(in_batched, [num_consts]) - batched_fun_jaxpr, out_batched = batching.batch_jaxpr( - fun_jaxpr, axis_data, in_batched, False) - out_dims1 = [0 if b else not_mapped for b in out_batched] - out_dims2 = [] - - @pe._memoize - def batched_fwd_jaxpr_thunk(*zeros): - fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk(*zeros)) # consts can be tracers - batched_fwd_jaxpr, out_batched = batching.batch_jaxpr( - fwd_jaxpr, axis_data, args_batched, False) - out_dims2.append([0 if b else not_mapped for b in out_batched]) - return batched_fwd_jaxpr.jaxpr, batched_fwd_jaxpr.consts - - fwd_args_batched = [0 if b else not_mapped for b in args_batched] - fwd_out_dims = lambda: out_dims2[0] - tag = core.TraceTag() - batched_bwd = batching.batch_custom_vjp_bwd( - bwd, tag, axis_data, fwd_out_dims, fwd_args_batched) - - batched_outs = custom_vjp_call_jaxpr_p.bind( - *args, fun_jaxpr=batched_fun_jaxpr, - fwd_jaxpr_thunk=batched_fwd_jaxpr_thunk, bwd=batched_bwd, - num_consts=num_consts, out_trees=out_trees, symbolic_zeros=symbolic_zeros) - out_dims = out_dims2[0] if out_dims2 else out_dims1 - return batched_outs, out_dims -batching.fancy_primitive_batchers[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_vmap + return call_jaxpr.out_avals, call_jaxpr.effects +core.custom_typechecks[custom_vjp_call_p] = _custom_vjp_call_typecheck -def _custom_vjp_call_jaxpr_dce( +def _custom_vjp_call_dce( used_outs: Sequence[bool], eqn: core.JaxprEqn ) -> tuple[list[bool], core.JaxprEqn | None]: if not any(used_outs) and not pe.has_effects(eqn): return [False] * len(eqn.invars), None - fun_jaxpr: core.ClosedJaxpr = eqn.params["fun_jaxpr"] + call_jaxpr: core.ClosedJaxpr = eqn.params["call_jaxpr"] fwd_jaxpr_thunk = eqn.params["fwd_jaxpr_thunk"] bwd: lu.WrappedFun = eqn.params["bwd"] - out_trees: Callable[[], Sequence[PyTreeDef]] = eqn.params["out_trees"] + out_trees: Callable[[], tuple[PyTreeDef, PyTreeDef, list[int | None]]] = eqn.params["out_trees"] symbolic_zeros: bool = eqn.params["symbolic_zeros"] - dce_fun_jaxpr: core.ClosedJaxpr + dce_call_jaxpr: core.ClosedJaxpr used_ins: Sequence[bool] - dce_fun_jaxpr, used_ins = _cached_closed_call_dce_instantiate( - fun_jaxpr, tuple(used_outs)) + dce_call_jaxpr, used_ins = _cached_closed_call_dce_instantiate( + call_jaxpr, tuple(used_outs)) assert all(used_ins) + @partial(lu.wrap_init, debug_info=fwd_jaxpr_thunk.debug_info) @pe._memoize def dce_fwd_jaxpr_thunk(*zeros): - fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk(*zeros)) - _, res_tree = out_trees() - num_res = res_tree.num_leaves + fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk.call_wrapped(*zeros)) + _, res_tree, fwds = out_trees() + num_res_out = res_tree.num_leaves - sum(f is not None for f in fwds) dce_fwd_jaxpr, _ = _cached_closed_call_dce_instantiate( - fwd_jaxpr, (True,) * num_res + tuple(used_outs)) + fwd_jaxpr, (True,) * num_res_out + tuple(used_outs)) return dce_fwd_jaxpr.jaxpr, dce_fwd_jaxpr.consts def dce_bwd(*args): - _, res_tree = out_trees() + _, res_tree, _ = out_trees() res, cts = split_list(args, [res_tree.num_leaves]) cts_ = iter(cts) all_cts = [] - for used, aval in zip(used_outs, fun_jaxpr.out_avals): + for used, aval in zip(used_outs, call_jaxpr.out_avals): if used: all_cts.append(next(cts_)) else: @@ -1060,20 +1087,37 @@ def dce_bwd(*args): outvars = [v for used, v in zip(used_outs, eqn.outvars) if used] new_params = dict( eqn.params, - fun_jaxpr=dce_fun_jaxpr, + call_jaxpr=dce_call_jaxpr, fwd_jaxpr_thunk=dce_fwd_jaxpr_thunk, bwd=dce_bwd_wrapped, ) new_eqn = pe.new_jaxpr_eqn( - eqn.invars, outvars, eqn.primitive, new_params, dce_fun_jaxpr.effects, + eqn.invars, outvars, eqn.primitive, new_params, dce_call_jaxpr.effects, eqn.source_info, eqn.ctx) return list(used_ins), new_eqn -pe.dce_rules[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_dce +pe.dce_rules[custom_vjp_call_p] = _custom_vjp_call_dce + -xla.register_initial_style_primitive(custom_vjp_call_jaxpr_p) +def _custom_vjp_call_pp_rule(eqn: core.JaxprEqn, + context: core.JaxprPpContext, + settings: core.JaxprPpSettings) -> core.pp.Doc: + params = dict(eqn.params) + if not params["num_consts"]: + params.pop("num_consts") + params.pop("out_trees") + params["fwd"] = params.pop("fwd_jaxpr_thunk").debug_info.func_name + params["bwd"] = params.pop("bwd").debug_info.func_name + names = sorted(params) + params["name"] = params["call_jaxpr"].jaxpr.debug_info.func_name + return core._pp_eqn(eqn.replace(params=params), context, settings, + params=["name"] + names) + +core.pp_eqn_rules[custom_vjp_call_p] = _custom_vjp_call_pp_rule batching.primitive_batchers[ad.custom_lin_p] = ad.raise_custom_vjp_error_on_jvp -mlir.register_lowering(ad.custom_lin_p, ad.raise_custom_vjp_error_on_jvp) +# TODO(phawkins,mattjj): make this primitive cacheable. +mlir.register_lowering(ad.custom_lin_p, ad.raise_custom_vjp_error_on_jvp, + cacheable=False) def custom_gradient(fun): @@ -1149,7 +1193,7 @@ def fwd(*args, **kwargs): rule, in_tree = flatten_fun_nokwargs(lu.wrap_init(rule, debug_info=debug_fwd), out_tree) ans_avals = [core.get_aval(x).to_tangent_aval() for x in ans_flat] - jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(rule, ans_avals) + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(rule, ans_avals) return ans, Residuals(jaxpr, in_tree(), out_tree, consts) def bwd(res, cts): @@ -1276,18 +1320,18 @@ def _maybe_perturbed(x: Any) -> bool: @cache() def _closure_convert_for_avals(fun, in_tree, in_avals, debug_info: core.DebugInfo): - wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun, debug_info=debug_info), - in_tree) - jaxpr, out_pvals, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals) + wrapped_fun, out_tree = flatten_fun_nokwargs( + lu.wrap_init(fun, debug_info=debug_info), in_tree) + jaxpr, out_pvals, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals) out_tree = out_tree() - (closure_consts, hoisted_consts), merge = partition_list(_maybe_perturbed, consts) - num_consts = len(hoisted_consts) + (closure_consts, const_args), merge = partition_list(_maybe_perturbed, consts) + num_consts = len(const_args) def converted_fun(*args_hconsts): num_args = len(args_hconsts) - num_consts - args, hoisted_consts = split_list(args_hconsts, [num_args]) - consts = merge(closure_consts, hoisted_consts) + args, const_args = split_list(args_hconsts, [num_args]) + consts = merge(closure_consts, const_args) all_args, in_tree2 = tree_flatten(tuple(args)) if in_tree != in_tree2: msg = ("The inputs to the closure produced by closure_convert must have " @@ -1298,7 +1342,7 @@ def converted_fun(*args_hconsts): out_flat = core.eval_jaxpr(jaxpr, consts, *all_args) return tree_unflatten(out_tree, out_flat) - return converted_fun, hoisted_consts + return converted_fun, const_args def partition_list(choice, lst): out = [], [] @@ -1370,11 +1414,11 @@ def linear_call(fun: Callable, >>> custom_id(1.) 1.0 >>> transpose(custom_id, 1.)(1.) - 7.0 + TypedFloat(7.0, dtype=float32) >>> transpose(transpose(custom_id, 1.), 1.)(1.) 1.0 >>> transpose(transpose(transpose(custom_id, 1.), 1.), 1.)(1.) - 7.0 + TypedFloat(7.0, dtype=float32) Args: fun: a Python callable specifying a linear function. It should @@ -1424,53 +1468,66 @@ def linear_call(fun: Callable, (residual_args, linear_args), {})), t_in_tree) - t_jaxpr, t_consts = _initial_style_jaxpr(t, (*res_avals, *out_avals)) - t_jaxpr_closed = _close_jaxpr(t_jaxpr) - - if t_out_tree() != lin_tree: - raise TypeError( - 'transpose output pytree structure must match that of linear inputs, ' - f'got output structure {t_out_tree()} ' - f'and input structure {lin_tree}.') + @pe._memoize + def transpose_thunk(): + t_jaxpr, t_consts = _initial_style_jaxpr(t.with_unknown_names(), + (*res_avals, *out_avals)) + if t_out_tree() != lin_tree: + raise TypeError( + 'transpose output pytree structure must match that of linear inputs, ' + f'got output structure {t_out_tree()} ' + f'and input structure {lin_tree}.') + return _close_jaxpr(t_jaxpr), t_consts - out = linear_call_p.bind(*f_consts, *t_consts, *operands_res, *operands_lin, + out = linear_call_p.bind(*f_consts, *operands_res, *operands_lin, callee=f_jaxpr_closed, - transpose=t_jaxpr_closed, + transpose_thunk=transpose_thunk, num_callee_consts=len(f_consts), - num_transpose_consts=len(t_consts), num_res=len(operands_res)) return tree_unflatten(out_tree(), out) -def _linear_call_impl(*args, callee, transpose, num_callee_consts, - num_transpose_consts, num_res): - del transpose - consts, _, operands_res, operands_lin = split_list( - args, [num_callee_consts, num_transpose_consts, num_res]) - return core.eval_jaxpr(callee.jaxpr, (), *consts, *operands_res, *operands_lin) - -def _linear_call_transpose_rule(cts, *args, callee, transpose, - num_callee_consts, - num_transpose_consts, num_res): - f_consts, t_consts, operands_res, operands_lin = split_list( - args, [num_callee_consts, num_transpose_consts, num_res]) +def _linear_call_impl(*args, callee, transpose_thunk, num_callee_consts, + num_res): + del transpose_thunk, num_callee_consts, num_res + return core.eval_jaxpr(callee.jaxpr, (), *args) + +def _linear_call_jvp_rule(primals, tangents, callee, transpose_thunk, + num_callee_consts, num_res): + consts_and_res, primals = split_list(primals, [num_callee_consts + num_res]) + const_tangents, tangents = split_list(tangents, [num_callee_consts + num_res]) + assert all(type(t) is Zero for t in const_tangents) + primals_out = linear_call_p.bind( + *consts_and_res, *primals, callee=callee, transpose_thunk=transpose_thunk, + num_callee_consts=num_callee_consts, num_res=num_res) + tangents_out = linear_call_p.bind( + *consts_and_res, *tangents, callee=callee, transpose_thunk=transpose_thunk, + num_callee_consts=num_callee_consts, num_res=num_res) + return primals_out, tangents_out + +def _linear_call_transpose_rule(cts, *args, callee, transpose_thunk, + num_callee_consts, num_res): + transpose, t_consts = transpose_thunk() + f_consts, operands_res, operands_lin = split_list( + args, [num_callee_consts, num_res]) _, _, cts_avals = split_list( - transpose.in_avals, [num_transpose_consts, num_res]) + transpose.in_avals, [len(t_consts), num_res]) assert all(ad.is_undefined_primal(x) for x in operands_lin) assert all(not ad.is_undefined_primal(x) for x in operands_res) + def new_transpose_thunk(): + return callee, f_consts + cts = [zeros_like_aval(a) if type(ct) is Zero else ct for ct, a in zip(cts, cts_avals)] - - cts_out = linear_call_p.bind(*t_consts, *f_consts, *operands_res, *cts, + cts_out = linear_call_p.bind(*t_consts, *operands_res, *cts, callee=transpose, - transpose=callee, + transpose_thunk=new_transpose_thunk, num_callee_consts=len(t_consts), - num_transpose_consts=len(f_consts), num_res=len(operands_res)) - return [None] * (num_callee_consts + num_transpose_consts + num_res) + cts_out + return [None] * (num_callee_consts + num_res) + cts_out def _linear_call_abstract_eval(*args, **kwargs): return kwargs['callee'].out_avals @@ -1479,8 +1536,9 @@ def _linear_call_abstract_eval(*args, **kwargs): linear_call_p.multiple_results = True linear_call_p.def_impl(_linear_call_impl) linear_call_p.def_abstract_eval(_linear_call_abstract_eval) +ad.primitive_jvps[linear_call_p] = _linear_call_jvp_rule ad.primitive_transposes[linear_call_p] = _linear_call_transpose_rule -xla.register_initial_style_primitive(linear_call_p) +pxla.register_initial_style_primitive(linear_call_p) mlir.register_lowering(linear_call_p, mlir.lower_fun( _linear_call_impl, multiple_results=True)) @@ -1558,7 +1616,6 @@ def jvp(primals, tangents): # TODO(mattjj): remove these stubs, which exist to avoid breaking internal users custom_jvp_call_jaxpr_p = core.Primitive("custom_jvp_call_jaxpr") - # The following is a helper for optimizing the behavior of custom_vjp when used # under remat. This is really only useful when the `fwd` function to custom_vjp # executes a black box kernel. Otherwise, DCE will perform this optimization @@ -1586,7 +1643,6 @@ def optimize_remat_of_custom_vjp_fwd( def wrapped_fwd(*args, **kwargs) -> tuple[ReturnValue, Any]: # TODO(dfm): This initial logic is duplicated from custom_vjp.__call__ # above and it would be good to consolidate it. - fwd_name = debug_fwd.func_name if debug_fwd else str(fwd) # Note: we use `fun` instead of `fwd` here for consistency with # custom_vjp.__call__ above. args = resolve_kwargs(fun, args, kwargs) @@ -1610,28 +1666,30 @@ def wrapped_fwd(*args, **kwargs) -> tuple[ReturnValue, Any]: flat_fwd = _fix_fwd_args(flat_fwd) in_avals = [core.get_aval(x) for x in args_flat] - fwd_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fwd, in_avals) + fwd_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fwd.with_unknown_names(), + in_avals) fwd_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(fwd_jaxpr)) - prim_tree, res_tree = out_trees() - num_res = res_tree.num_leaves + prim_tree, res_tree, fwds = out_trees() + num_res_out = res_tree.num_leaves - sum(f is not None for f in fwds) - if fwd_jaxpr.effects: + disallowed_effects = effects.custom_derivatives_allowed_effects.filter_not_in(fwd_jaxpr.effects) + if disallowed_effects: raise NotImplementedError( "remat optimization for custom_vjp does not support forward " - f"functions with side effects, but {fwd_name} has the following " - f"effects: {fwd_jaxpr.effects}") + f"functions with these side effects: {disallowed_effects}") @pe._memoize def fun_jaxpr_thunk(): - jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals) + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals) return jaxpr, consts - out_flat = remat_opt_p.bind(*consts, *args_flat, - num_consts=len(consts), - num_res=num_res, - fwd_jaxpr=fwd_jaxpr, + out_flat = remat_opt_p.bind(*consts, *args_flat, num_consts=len(consts), + num_res=num_res_out, fwd_jaxpr=fwd_jaxpr, fun_jaxpr_thunk=fun_jaxpr_thunk) - res, out_flat = split_list(out_flat, [num_res]) + res, out_flat = split_list(out_flat, [num_res_out]) + res_ = iter(res) + res = [next(res_) if f is None else args_flat[f] for f in fwds] + assert next(res_, None) is None out_tree = treedef_tuple((prim_tree, res_tree)) return tree_unflatten(out_tree, (*out_flat, *res)) @@ -1788,7 +1846,7 @@ def _remat_opt_dce(used_outs: list[bool], eqn: core.JaxprEqn): remat_opt_p.multiple_results = True remat_opt_p.def_impl(_remat_opt_impl) remat_opt_p.def_effectful_abstract_eval(_remat_opt_abstract_eval) -xla.register_initial_style_primitive(remat_opt_p) +pxla.register_initial_style_primitive(remat_opt_p) mlir.register_lowering(remat_opt_p, mlir.lower_fun( _remat_opt_impl, multiple_results=True)) diff --git a/jax/_src/custom_partitioning.py b/jax/_src/custom_partitioning.py index 5374071517f1..a839eeb64772 100644 --- a/jax/_src/custom_partitioning.py +++ b/jax/_src/custom_partitioning.py @@ -21,20 +21,23 @@ from functools import partial import inspect -from typing import Any, Callable +from typing import Any +from collections.abc import Callable import weakref import numpy as np -import jax -from jax import tree_util + +from jax._src import api from jax._src import api_util from jax._src import config from jax._src import core from jax._src import custom_api_util from jax._src import dispatch +from jax._src import errors from jax._src import linear_util as lu from jax._src import mesh as mesh_lib from jax._src import sharding_impls +from jax._src import tree_util from jax._src import xla_bridge as xb from jax._src.custom_partitioning_sharding_rule import sdy_sharding_rule_to_mlir, SdyShardingRule, str_to_sdy_sharding_rule from jax._src.interpreters import mlir @@ -42,7 +45,7 @@ from jax._src.lib import xla_client as xc from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo -from jax.errors import UnexpectedTracerError +from jax._src.sharding import Sharding def _resolve_kwargs(fun, args, kwargs): @@ -93,7 +96,7 @@ def _to_jax_shape(s): def _to_jax_sharded_shape(s, sharding): - return jax.ShapeDtypeStruct( + return api.ShapeDtypeStruct( s.dimensions(), s.numpy_dtype(), sharding=sharding ) @@ -140,7 +143,7 @@ def _custom_partitioning_propagate_user_sharding(user_sharding, shape, def _to_hlo_sharding(sharding, num_dimensions): - if not isinstance(sharding, jax.sharding.Sharding): + if not isinstance(sharding, Sharding): raise ValueError("Custom Partitioning rules must return Sharding.") return sharding._to_xla_hlo_sharding(num_dimensions) @@ -178,7 +181,7 @@ def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape, _to_jax_shape(sharding.tile(s)) for sharding, s in zip(result_shardings, result_shapes) ] - closed_jaxpr = jax.make_jaxpr(lower_fn, axis_env=list(mesh.shape.items()))( + closed_jaxpr = api.make_jaxpr(lower_fn, axis_env=list(mesh.shape.items()))( *info.in_tree.unflatten(tiled_args) ) if ([(o.shape, o.dtype) for o in closed_jaxpr.out_avals] != @@ -251,7 +254,7 @@ def _custom_partitioning_impl(*args, call, in_tree, out_tree, def _check_for_tracers(x): if any(isinstance(leaf, core.Tracer) for leaf in tree_util.tree_leaves(x)): - raise UnexpectedTracerError( + raise errors.UnexpectedTracerError( "Found a JAX Tracer object passed as an argument to a" "custom_partitioning function in a position indicated as static by" "static_argnums. " @@ -307,10 +310,29 @@ def infer_sharding_from_operands(mesh, arg_shapes, shape): provide a contextual mesh. * ``sharding_rule``: an SdyShardingRule object, an Einsum-like notation string that describes the sharding rule, or a Callable that produces either of - these. We borrow the idea from the einops.rearrange string , to use a space - separator between factors and allow multiple letters factor names. See + these. We call the index labels in Einsum notation factors in our sharding + rule. We borrow the idea from the einops.rearrange string , to use a space + separator between factors and allow multiple letters factor names. By + default, a factor corresponds to a passthrough/elementwise dimension. + Factors corresponding to other dimensions can be specified via keyword + arguments described below. See `jax-shardy-guide `_ - for more details and examples on how to use this. + for more details and examples. + * ``reduction_factors``: A tuple of strings, specifying the reduction factors + for a string `sharding_rule`. A reduction factor corresponds to a dimension + that appears in operands but not in the result, such as the contracting + dimensions in a matmul operation. If a reduction factor is sharded, the + result would need to be all-reduced along the same axes. + * ``need_replication_factors``: A tuple of strings, specifying the + need_replication factors for a string `sharding_rule`. A need_replication + factor corresponds to a dimension that shouldn't be sharded to support + the implementation. + * ``permutation_factors``: A tuple of strings, specifying the permutation + factors for a string `sharding_rule`. A permutation factor corresponds to a + dimension that would trigger collective permute if it is sharded. + * ``factor_sizes``: A dictionary of variable keyword arguments, specifying + the sizes of the factors that are only used in compound factors in a string + `sharding_rule`. When config.use_shardy_partitioner.value is True, `sharding_rule` is used; otherwise, `propagate_user_sharding` and `infer_sharding_from_operands` are @@ -376,7 +398,7 @@ def my_fft(x): my_fft.def_partition( infer_sharding_from_operands=infer_sharding_from_operands, partition=partition, - sharding_rule=SdyShardingRule(operand_mappings=((SDY_BATCHING, 'i'),), result_mappings=((SDY_BATCHING, 'i'),)))) + sharding_rule=SdyShardingRule(operand_mappings=((BATCHING, 'i'),), result_mappings=((BATCHING, 'i'),)))) Now create a 2D array sharded along the first axis, pass it through ``my_fft`` and notice how it is still sharded as expected, and identical to the output @@ -455,22 +477,38 @@ def __init__(self, fun, static_argnums=()): def def_partition(self, partition, infer_sharding_from_operands=None, propagate_user_sharding=None, decode_shardings=True, - sharding_rule=None): + sharding_rule=None, *, reduction_factors=(), + need_replication_factors=(), permutation_factors=(), + **factor_sizes): self.partition = partition self.propagate_user_sharding = propagate_user_sharding self.infer_sharding_from_operands = infer_sharding_from_operands self.decode_shardings = decode_shardings if (sharding_rule is None or isinstance(sharding_rule, Callable) or isinstance(sharding_rule, SdyShardingRule)): + sharding_rule_dict = factor_sizes + if len(reduction_factors) > 0: + sharding_rule_dict["reduction_factors"] = reduction_factors + if len(need_replication_factors) > 0: + sharding_rule_dict["need_replication_factors"] = need_replication_factors + if len(permutation_factors) > 0: + sharding_rule_dict["permutation_factors"] = permutation_factors + if sharding_rule_dict: + raise ValueError(f"Unknown keyword arguments: {sharding_rule_dict}") self.sharding_rule = sharding_rule else: - self.sharding_rule = str_to_sdy_sharding_rule(sharding_rule) + self.sharding_rule = str_to_sdy_sharding_rule( + sharding_rule, + reduction_factors=reduction_factors, + need_replication_factors=need_replication_factors, + permutation_factors=permutation_factors, + **factor_sizes) return partition def __call__(self, *args, **kwargs): args = _resolve_kwargs(self.fun, args, kwargs) debug = api_util.debug_info("custom_partitioning", self.fun, - args, kwargs, + args, {}, static_argnums=self.static_argnums) if self.static_argnums: static_argnums = set(self.static_argnums) @@ -482,17 +520,17 @@ def __call__(self, *args, **kwargs): args, require_static_args_hashable=False, ) - static_args = [args[i] for i in self.static_argnums] + static_args = tuple(args[i] for i in self.static_argnums) _check_for_tracers(static_args) else: - static_args = [] + static_args = () f_, dyn_args = lu.wrap_init(self.fun, debug_info=debug), args args_flat, in_tree = tree_util.tree_flatten(dyn_args) flat_fun, out_tree = api_util.flatten_fun_nokwargs(f_, in_tree) in_avals = [core.get_aval(x) for x in args_flat] mesh = mesh_lib.thread_resources.env.physical_mesh with core.extend_axis_env_nd(mesh.shape.items()): - jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals) + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals) assert not len(consts) closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ()) @@ -500,6 +538,14 @@ def __call__(self, *args, **kwargs): infer_sharding_from_operands = None sharding_rule = None if config.use_shardy_partitioner.value: + if (self.sharding_rule is None and + (self.propagate_user_sharding is not None or + self.infer_sharding_from_operands is not None)): + raise NotImplementedError( + "Shardy is used, but sharding propagation callbacks instead of " + "sharding_rule are provided. Need to provide sharding_rule to " + "migrate to Shardy." + ) sharding_rule = self.sharding_rule else: propagate_user_sharding = self.propagate_user_sharding @@ -557,11 +603,11 @@ def to_mesh_pspec_sharding(hlo_sharding: xc.HloSharding | None, ndim): return hlo_sharding if mesh.empty or not decode_shardings: assert devices is not None - return sharding_impls._op_sharding_to_pos_sharding(hlo_sharding, devices) + return sharding_impls.GSPMDSharding(devices, hlo_sharding) pspec = sharding_impls.parse_flatten_op_sharding( hlo_sharding, mesh)[0] - pspec = jax.sharding.PartitionSpec(*pspec, *((None,) * (ndim - len(pspec)))) - return jax.sharding.NamedSharding(mesh, pspec) + pspec = sharding_impls.PartitionSpec(*pspec, *((None,) * (ndim - len(pspec)))) + return sharding_impls.NamedSharding(mesh, pspec) sharding_callback_info = _ShardingCallbackInfo(propagate_user_sharding, partition, to_mesh_pspec_sharding, in_tree, out_tree, @@ -587,8 +633,12 @@ def to_mesh_pspec_sharding(hlo_sharding: xc.HloSharding | None, ndim): value_types = [mlir.aval_to_ir_type(s) for s in call.in_avals] if callable(sharding_rule): sharding_rule = sharding_rule(*static_args, mesh, value_types, result_types) + if isinstance(sharding_rule, (list, tuple)) and len(sharding_rule) == 2: + sharding_rule, sharding_rule_dict = sharding_rule + else: + sharding_rule_dict = {} if isinstance(sharding_rule, str): - sharding_rule = str_to_sdy_sharding_rule(sharding_rule) + sharding_rule = str_to_sdy_sharding_rule(sharding_rule, **sharding_rule_dict) elif not isinstance(sharding_rule, SdyShardingRule): raise ValueError("sharding_rule callable must produce either an " "SdyShardingRule object or an Einsum-like notation " diff --git a/jax/_src/custom_partitioning_sharding_rule.py b/jax/_src/custom_partitioning_sharding_rule.py index 5e2e5f4e0479..95377771a38d 100644 --- a/jax/_src/custom_partitioning_sharding_rule.py +++ b/jax/_src/custom_partitioning_sharding_rule.py @@ -15,6 +15,7 @@ """Implements SdyShardingRule.""" from collections import OrderedDict +from typing import Union from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import sdy @@ -28,7 +29,7 @@ _BATCHING_DIM_FACTOR_PREFIX = "?" # A Jax value in general corresponds to an ir.Type or a tuple of ir.Types. -IrTypes = ir.Type | tuple[ir.Type, ...] +IrTypes = Union[ir.Type, tuple[ir.Type, ...]] def _check_factor(factor:str): """Validates a factor. @@ -107,18 +108,30 @@ def __new__(cls, *dim_mappings): class SdyShardingRule: """Represents a Shardy sharding rule. - An SdyShardingRule contains the ArrayMappings for operands and results, and an - optional list of factor sizes. A factor is a name used in the ArrayMappings. - If a factor is only used in CompoundFactors, its size must be specified. + An SdyShardingRule contains the ArrayMappings for operands and results, + optional special factors and optional factor sizes. A factor is a name used in + the ArrayMappings. If a factor is only used in CompoundFactors, its size must + be specified. + + By default, a factor is a passthrough factor. Keyword arguments can be used to + specify other factor kinds including reduction_factors, need_replication_factors, + and permutation_factors. """ operand_mappings: tuple[ArrayMapping, ...] result_mappings: tuple[ArrayMapping, ...] factor_sizes: dict[str, int] + reduction_factors: tuple[str, ...] + need_replication_factors: tuple[str, ...] + permutation_factors: tuple[str, ...] def __init__(self, operand_mappings: tuple[ArrayMapping, ...], - result_mappings: tuple[ArrayMapping, ...], **factor_sizes): + result_mappings: tuple[ArrayMapping, ...], + *, reduction_factors: tuple[str, ...] = (), + need_replication_factors: tuple[str, ...] = (), + permutation_factors: tuple[str, ...] = (), + **factor_sizes: int): # Find all factors and mark whether their size can be inferred. - factors_inferrable = dict() + factors_inferrable = {} for value in operand_mappings + result_mappings: for dim in value: if isinstance(dim, str): @@ -137,22 +150,56 @@ def __init__(self, operand_mappings: tuple[ArrayMapping, ...], # Check that factors that are used for a whole dimension aren't in # factor_sizes and factors that are never used for a whole dimension are # in factor_sizes. - for factor, inferrable in factors_inferrable.items(): - if factor not in factor_sizes and not inferrable: + for factor, inferable in factors_inferrable.items(): + if factor not in factor_sizes and not inferable: raise ValueError( f"Factor {factor} is only used in compound factors; must specify" " its size") - if factor in factor_sizes and inferrable: + if factor in factor_sizes and inferable: raise ValueError( f"Factor {factor} represents a whole dimension; do not specify its" " size") + special_factors = set() + def check_special_factors(kind, factors): + if not isinstance(factors, tuple): + raise ValueError(f"{kind} must be a tuple of factors") + + if len(factors) != len(set(factors)): + raise ValueError(f"{kind} contains duplicated factors") + + for factor in factors: + if factor not in factors_inferrable: + raise ValueError( + f"Factor {factor} in {kind} is not used in the rule") + if factor in special_factors: + raise ValueError(f"Factor {factor} can only be in one of the " + f"reduction, need replication, or permutation factor sets.") + special_factors.add(factor) + + check_special_factors("reduction_factors", reduction_factors) + check_special_factors("need_replication_factors", need_replication_factors) + check_special_factors("permutation_factors", permutation_factors) + self.operand_mappings = operand_mappings self.result_mappings = result_mappings self.factor_sizes = factor_sizes + self.reduction_factors = reduction_factors + self.need_replication_factors = need_replication_factors + self.permutation_factors = permutation_factors + def __str__(self): - return f"SdyShardingRule({self.operand_mappings}, {self.result_mappings}, {self.factor_sizes})" + def to_str(kind, factors): + if len(factors) > 0: + return f" {kind}={factors}" + return "" + + special_factors = (to_str("reduction_factors", self.reduction_factors) + + to_str("need_replication_factors", self.need_replication_factors) + + to_str("permutation_factors", self.permutation_factors)) + return (f"SdyShardingRule({self.operand_mappings}, {self.result_mappings}, " + f"{self.factor_sizes}{special_factors})") def _get_batching_dim_factor_name(batch_group: str,batch_dim_order : int): @@ -267,15 +314,22 @@ def add_factor(x): return tuple(all_values) -def str_to_sdy_sharding_rule(rule: str, **factor_sizes) -> SdyShardingRule: +def str_to_sdy_sharding_rule(rule: str, *, + reduction_factors: tuple[str, ...] = (), + need_replication_factors: tuple[str, ...] = (), + permutation_factors: tuple[str, ...] = (), + **factor_sizes: int) -> SdyShardingRule: """Constructs a SdyShardingRule object from the Einsum notation like string. This is done by verifying that the input Einsum notation like string and - with optional factor sizes represents a valid sharding rule and converting - it to an internal representation. + with optional special factors and factor sizes represents a valid sharding + rule and converting it to an internal representation. Args: rule: The Einsum notation like string for an operation. + reduction_factors: A tuple of factors that are reduction factors. + need_replication_factors: A tuple of factors that are need_replication factors. + permutation_factors: A tuple of factors that are permutation factors. **factor_sizes: The optional factor sizes. Raises: @@ -302,8 +356,11 @@ def str_to_sdy_sharding_rule(rule: str, **factor_sizes) -> SdyShardingRule: operand_mappings = _parse_values(operands) result_mappings = _parse_values(results) - - return SdyShardingRule(operand_mappings, result_mappings, **factor_sizes) + return SdyShardingRule(operand_mappings, result_mappings, + reduction_factors=reduction_factors, + need_replication_factors=need_replication_factors, + permutation_factors=permutation_factors, + **factor_sizes) def sdy_sharding_rule_to_mlir( @@ -350,18 +407,20 @@ def add_factor(factor, size): `size` may be a dimensions size, a user specified factor size, or UNKNOWN if a factor is first used as in a compound factor and then used for a - whole dimension. + whole dimension. If a factor is not for a leading batching dimension and + it corresponds to multiple sizes, the smallest size is used. """ factor_index, factor_size = factors_to_indices_sizes.get(factor, [UNKNOWN, UNKNOWN]) if factor_index != UNKNOWN: # Not the first time seeing the factor. if size != UNKNOWN and factor_size != UNKNOWN and factor_size != size: - factor_or_batching_dim = ( - f"Factor {factor}" if _BATCHING_DIM_FACTOR_PREFIX not in factor - else f"Batching dimension {factor[1:]}") - raise ValueError( - f"{factor_or_batching_dim} corresponds to two sizes:" - f" {factor_size} and {size}") + if _BATCHING_DIM_FACTOR_PREFIX in factor: + raise ValueError(f"Batching dimension {factor[1:]} corresponds to " + f"two sizes: {factor_size} and {size}") + else: + if size < factor_size: + # Use the smaller size to update the factor size. + factor_size = UNKNOWN if size != UNKNOWN and factor_size == UNKNOWN: factors_to_indices_sizes[factor] = [factor_index, size] else: @@ -389,6 +448,9 @@ def build_dim_mapping_for_compound_factors(i, j, factors): return sdy.DimMappingAttr.get(factor_indices=all_indices) + def factors_to_indices(factors): + return [factors_to_indices_sizes[factor][0] for factor in factors] + # Add factors and their sizes in the order they appear in the rule, # including the batching dimensions represented by ellipsis. batching_group_to_rank: dict[str, int] = {} @@ -472,4 +534,9 @@ def build_dim_mapping_for_compound_factors(i, j, factors): return sdy.OpShardingRuleAttr.get( factor_sizes=[item[1] for item in factors_to_indices_sizes.values()], operand_mappings=tensor_mappings[0:len(operand_types)], - result_mappings=tensor_mappings[len(operand_types):]) + result_mappings=tensor_mappings[len(operand_types):], + is_custom=True, + reduction_factors=factors_to_indices(rule.reduction_factors), + need_replication_factors=factors_to_indices(rule.need_replication_factors), + permutation_factors=factors_to_indices(rule.permutation_factors), + ) diff --git a/jax/_src/custom_transpose.py b/jax/_src/custom_transpose.py index 5e87fdb203c9..1762522a8deb 100644 --- a/jax/_src/custom_transpose.py +++ b/jax/_src/custom_transpose.py @@ -29,7 +29,7 @@ from jax._src.interpreters import ad from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe -from jax._src.interpreters import xla +from jax._src.interpreters import pxla from jax._src.tree_util import (tree_flatten, tree_leaves, tree_map, tree_structure, treedef_tuple, tree_unflatten, PyTreeDef) @@ -105,7 +105,7 @@ def __call__(self, out_types, res_arg, lin_arg): (res_arg, out_types), {}) ) out_flat = custom_transpose_p.bind(flat_fun, *args_flat, transpose=transpose_wrapped, - out_types=out_types_flat, + out_types=tuple(out_types_flat), lin_tree=lin_tree, res_tree=res_tree, out_tree=out_tree) @@ -177,15 +177,19 @@ def bind_with_trace(self, trace, call_args, params): # TODO(frostig,mattjj): consider keeping `call` as a named parameter # instead of following this "call primitive" convention. def get_bind_params(self, params): - assert 'call_jaxpr' in params - assert 'transpose_jaxpr_thunk' in params - new_params: dict[str, Any] = dict(params) - new_params['transpose'] = make_transpose_from_thunk( - new_params.pop('transpose_jaxpr_thunk'), - new_params['lin_tree']) - call_jaxpr: core.ClosedJaxpr = new_params.pop('call_jaxpr') - call = lu.wrap_init(core.jaxpr_as_fun(call_jaxpr), - debug_info=call_jaxpr.jaxpr.debug_info) + if 'call_jaxpr' in params: + assert 'transpose_jaxpr_thunk' in params + new_params: dict[str, Any] = dict(params) + new_params['transpose'] = make_transpose_from_thunk( + new_params.pop('transpose_jaxpr_thunk'), + new_params['lin_tree']) + call_jaxpr: core.ClosedJaxpr = new_params.pop('call_jaxpr') + call = lu.wrap_init(core.jaxpr_as_fun(call_jaxpr), + debug_info=call_jaxpr.jaxpr.debug_info) + else: + assert 'transpose' in params + new_params: dict[str, Any] = dict(params) + call = new_params.pop("call") return [call], new_params @@ -213,7 +217,6 @@ def custom_transpose_transpose_rule( # Consider passing this information to the custom transpose rule? res_arg, lin_arg = tree_unflatten(call_in_tree, args) - del lin_arg assert all(not ad.is_undefined_primal(x) for x in tree_leaves(res_arg)) cts = [ad_util.zeros_like_aval(ct.aval) if type(ct) is ad_util.Zero else ct @@ -221,10 +224,17 @@ def custom_transpose_transpose_rule( ct_out = tree_unflatten(out_tree, cts) ct_lin = transpose.call_wrapped(res_arg, ct_out) check_transpose_rule_trees(transpose, lin_tree, tree_structure(ct_lin)) - ct_lin_flat, _ = tree_flatten( - tree_broadcast(lin_tree, ct_lin, is_leaf=lambda x: x is None), - is_leaf=lambda x: x is None) - return [None] * len(tree_leaves(res_arg)) + ct_lin_flat + ct_lin = tree_broadcast(lin_tree, ct_lin, is_leaf=lambda x: x is None) + + # When the transpose returns None, we treat that as a Zero, except when the + # input is also None. In that case, the cotangent corresponding to that input + # should be dropped. + zero = object() + ct_lin = tree_map(lambda l, ct: zero if ct is None and l is not None else ct, + lin_arg, ct_lin, is_leaf=ad.is_undefined_primal) + + ct_lin_flat, _ = tree_flatten(ct_lin) + return [None] * res_tree.num_leaves + [None if ct is zero else ct for ct in ct_lin_flat] def custom_transpose_lowering(*args, call_jaxpr, **params): @@ -237,4 +247,4 @@ def custom_transpose_lowering(*args, call_jaxpr, **params): mlir.register_lowering( custom_transpose_p, mlir.lower_fun(custom_transpose_lowering, multiple_results=True)) -xla.register_initial_style_primitive(custom_transpose_p) +pxla.register_initial_style_primitive(custom_transpose_p) diff --git a/jax/_src/debugger/cli_debugger.py b/jax/_src/debugger/cli_debugger.py index bf4b38765026..eb1eca3bec48 100644 --- a/jax/_src/debugger/cli_debugger.py +++ b/jax/_src/debugger/cli_debugger.py @@ -105,7 +105,7 @@ def do_pp(self, arg): def do_up(self, _): """u(p) - Move down a stack frame. + Move up a stack frame. """ if self.frame_index == len(self.frames) - 1: print('At topmost frame.', file=self.stdout) diff --git a/jax/_src/debugger/colab_debugger.py b/jax/_src/debugger/colab_debugger.py index 57a5be4825d6..2ebcbafe0643 100644 --- a/jax/_src/debugger/colab_debugger.py +++ b/jax/_src/debugger/colab_debugger.py @@ -66,7 +66,7 @@ def _highlight_code(self, code: str, highlights, linenostart: int): hl_color = "#4e56b7" if is_dark_mode else "#fff7c1" if IS_PYGMENTS_ENABLED: lexer = pygments.lexers.get_lexer_by_name("python") - formatter = pygments.formatters.HtmlFormatter( + formatter = pygments.formatters.HtmlFormatter( # pytype: disable=module-attr full=False, hl_lines=highlights, linenos=True, diff --git a/jax/_src/debugger/core.py b/jax/_src/debugger/core.py index 1efeed73cbc8..54a086977b13 100644 --- a/jax/_src/debugger/core.py +++ b/jax/_src/debugger/core.py @@ -19,12 +19,13 @@ import threading from typing import Any, Protocol -import jax -from jax import tree_util +from jax._src import callback from jax._src import core from jax._src import debugging from jax._src import traceback_util +from jax._src import tree_util from jax._src import util +from jax._src.lax import lax @tree_util.register_pytree_node_class @@ -120,7 +121,7 @@ def from_frameinfo(cls, frame_info) -> DebuggerFrame: except OSError: source = [] offset = None - return DebuggerFrame( + return DebuggerFrame( # pytype: disable=wrong-arg-types filename=frame_info.filename, locals=frame_info.frame.f_locals, globals={}, @@ -225,5 +226,5 @@ def _breakpoint_callback(*flat_args): def _breakpoint_callback_wrapper(x, *flat_args): _breakpoint_callback(*flat_args) return x - token, flat_args = jax.lax.stop_gradient((token, flat_args)) - return jax.pure_callback(_breakpoint_callback_wrapper, token, token, *flat_args) + token, flat_args = lax.stop_gradient((token, flat_args)) + return callback.pure_callback(_breakpoint_callback_wrapper, token, token, *flat_args) diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py index b61b28e12f43..c2303b0b7f3a 100644 --- a/jax/_src/debugging.py +++ b/jax/_src/debugging.py @@ -16,6 +16,7 @@ from __future__ import annotations from collections.abc import Callable, Sequence +import copy from functools import partial import importlib.util import logging @@ -26,18 +27,20 @@ import numpy as np -import jax -import jax.numpy as jnp -from jax import lax +from jax._src import api from jax._src import callback as cb from jax._src import config from jax._src import core from jax._src import dispatch from jax._src import effects +from jax._src import lax from jax._src import mesh as mesh_lib +from jax._src import shard_map from jax._src import sharding_impls +from jax._src import source_info_util from jax._src import tree_util from jax._src import util +from jax._src import xla_bridge from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -45,6 +48,7 @@ from jax._src.lib import xla_client as xc from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo +from jax._src.numpy import lax_numpy as jnp from jax._src.sharding import Sharding from jax._src.sharding_impls import ( NamedSharding, PartitionSpec as P, parse_flatten_op_sharding) @@ -58,8 +62,8 @@ class DebugEffect(effects.Effect): class OrderedDebugEffect(effects.Effect): __str__ = lambda self: "OrderedDebug" -ordered_debug_effect = OrderedDebugEffect() +ordered_debug_effect = OrderedDebugEffect() effects.ordered_effects.add_type(OrderedDebugEffect) effects.lowerable_effects.add_type(DebugEffect) effects.lowerable_effects.add_type(OrderedDebugEffect) @@ -69,6 +73,8 @@ class OrderedDebugEffect(effects.Effect): effects.remat_allowed_effects.add_type(OrderedDebugEffect) effects.custom_derivatives_allowed_effects.add_type(DebugEffect) effects.custom_derivatives_allowed_effects.add_type(OrderedDebugEffect) +effects.partial_eval_kept_effects.add_type(DebugEffect) +effects.partial_eval_kept_effects.add_type(OrderedDebugEffect) # `debug_callback_p` is the main primitive for staging out Python callbacks. debug_callback_p = core.Primitive('debug_callback') @@ -78,18 +84,20 @@ class OrderedDebugEffect(effects.Effect): @debug_callback_p.def_impl def debug_callback_impl(*args, callback: Callable[..., Any], - effect: DebugEffect): - del effect + effect: DebugEffect, partitioned: bool): + del effect, partitioned try: - cpu_device, *_ = jax.local_devices(backend="cpu") + cpu_device, *_ = xla_bridge.local_devices(backend="cpu") except RuntimeError as e: raise RuntimeError( "jax.debug.callback failed to find a local CPU device to place the" " inputs on. Make sure \"cpu\" is listed in --jax_platforms or the" " JAX_PLATFORMS environment variable." ) from e - args = jax.device_put(args, cpu_device) - with jax.default_device(cpu_device): + args = api.device_put(args, cpu_device) + with (config.default_device(cpu_device), + sharding_impls._internal_use_concrete_mesh(mesh_lib.empty_concrete_mesh), + mesh_lib.use_abstract_mesh(mesh_lib.empty_abstract_mesh)): try: callback(*args) except BaseException: @@ -99,16 +107,16 @@ def debug_callback_impl(*args, callback: Callable[..., Any], @debug_callback_p.def_effectful_abstract_eval def debug_callback_abstract_eval(*flat_avals, callback: Callable[..., Any], - effect: DebugEffect): - del flat_avals, callback + effect: DebugEffect, partitioned: bool): + del flat_avals, callback, partitioned return [], {effect} -def debug_callback_batching_rule(args, dims, **params): + +def debug_batching_rule(args, dims, *, primitive, **params): """Unrolls the debug callback across the mapped axis.""" axis_size = next(x.shape[i] for x, i in zip(args, dims) if i is not None) - # TODO(sharadmv): implement in terms of rolled loop unstead of - # unrolled. + # TODO(sharadmv): implement in terms of rolled loop unstead of unrolled. def get_arg_at_dim(i, dim, arg): if dim is batching.not_mapped: # Broadcast unmapped argument @@ -117,34 +125,35 @@ def get_arg_at_dim(i, dim, arg): outs = [] for i in range(axis_size): args_idx = map(partial(get_arg_at_dim, i), dims, args) - outs.append(debug_callback_p.bind(*args_idx, **params)) + outs.append(primitive.bind(*args_idx, **params)) outs = [jnp.stack(xs) for xs in zip(*outs)] return outs, (0,) * len(outs) -batching.primitive_batchers[debug_callback_p] = debug_callback_batching_rule + + +batching.primitive_batchers[debug_callback_p] = partial( + debug_batching_rule, primitive=debug_callback_p +) def debug_callback_jvp_rule(primals, tangents, **params): return debug_callback_p.bind(*primals, **params), [] ad.primitive_jvps[debug_callback_p] = debug_callback_jvp_rule -def debug_callback_transpose_rule(*flat_args, callback: Callable[..., Any], - effect: DebugEffect): - del flat_args, callback, effect - raise ValueError("Transpose doesn't support debugging callbacks.") +def debug_callback_transpose_rule(_, *flat_args, callback: Callable[..., Any], + effect: DebugEffect, partitioned): + del callback, effect, partitioned + return [None for _ in flat_args] ad.primitive_transposes[debug_callback_p] = debug_callback_transpose_rule def _debug_callback_partial_auto(axis_context, *args, **params): - from jax.experimental.shard_map import shard_map partial_auto = list(set(axis_context.mesh.axis_names) - axis_context.manual_axes) def f(): - idx = jax.lax.with_sharding_constraint( - jax.lax.axis_index(*partial_auto), - NamedSharding(axis_context.mesh, P())) - return jax.lax.cond(idx == 0, - lambda: debug_callback_p.bind(*args, **params), - lambda: []) - return shard_map(f, axis_context.mesh, in_specs=(), out_specs=[])() - -def debug_callback_lowering(ctx, *args, effect, callback, **params): + idx = lax.axis_index(*partial_auto) + return lax.cond(idx == 0, + lambda: debug_callback_p.bind(*args, **params), + lambda: []) + return shard_map.shard_map(f, in_specs=(), out_specs=[])() + +def debug_callback_lowering(ctx, *args, effect, partitioned, callback, **params): axis_context = ctx.module_context.axis_context if isinstance(axis_context, sharding_impls.SPMDAxisContext): # We're a shard_map, which might be partial-manual or full-manual. @@ -152,21 +161,20 @@ def debug_callback_lowering(ctx, *args, effect, callback, **params): if partial_auto: # If we have partial manual / partial auto sharding, we gather and # conditionally run the callback. - lower = partial(_debug_callback_partial_auto, axis_context, - effect=effect, callback=callback, **params) + lower = partial( + _debug_callback_partial_auto, + axis_context, + effect=effect, + partitioned=partitioned, + callback=callback, + **params, + ) return mlir.lower_fun(lower)(ctx, *args) elif set(axis_context.manual_axes) == set(axis_context.mesh.axis_names): # If we have fully manual sharding during lowering, that means the JAX # program has per-device semantics, so we run the callback on each device. if config.use_shardy_partitioner.value: - assert len(ctx.avals_out) == 1 - sharding = sharding_impls.SdyArrayShardingList([ - sharding_impls.SdyArraySharding( - mesh_shape=(), - dimension_shardings=[ - sharding_impls.SdyDimSharding(axes=[], is_closed=True) - ] * ctx.avals_out[0].ndim, - logical_device_ids=())]) + sharding = cb._get_sdy_array_list_for_callbacks(ctx.avals_out) else: sharding = xc.OpSharding() sharding.type = xc.OpSharding.Type.MANUAL @@ -177,9 +185,9 @@ def debug_callback_lowering(ctx, *args, effect, callback, **params): # program has bulk array semantics, so we run the callback with a MAXIMAL # sharding and hence execute it only once on the full logical value). if config.use_shardy_partitioner.value: - sharding = sharding_impls.SdyArrayShardingList([ - sharding_impls.SdyArraySharding( - mesh_shape=(), dimension_shardings=[], logical_device_ids=(0,))]) + sharding = sharding_impls.SdyArrayList([ + sharding_impls.SdyArray( + mesh_shape=(), dim_shardings=[], logical_device_ids=(0,))]) else: sharding = xc.OpSharding() sharding.type = xc.OpSharding.Type.MAXIMAL @@ -191,27 +199,36 @@ def debug_callback_lowering(ctx, *args, effect, callback, **params): def _callback(*flat_args): debug_callback_p.impl( - *flat_args, effect=effect, callback=callback, **params) + *flat_args, + effect=effect, + partitioned=partitioned, + callback=callback, + **params, + ) return () if effects.ordered_effects.contains(effect): token = ctx.tokens_in.get(effect) result, token, _ = cb.emit_python_callback( ctx, _callback, token, list(args), ctx.avals_in, ctx.avals_out, - has_side_effect=True) + has_side_effect=True, returns_token=True, partitioned=partitioned) ctx.set_tokens_out(mlir.TokenSet({effect: token})) else: result, _, _ = cb.emit_python_callback( ctx, _callback, None, list(args), ctx.avals_in, ctx.avals_out, - has_side_effect=True, sharding=sharding) + has_side_effect=True, returns_token=True, partitioned=partitioned, + sharding=sharding) return result mlir.register_lowering(debug_callback_p, debug_callback_lowering, platform="cpu") mlir.register_lowering( debug_callback_p, debug_callback_lowering, platform="gpu") +# Debug callbacks use channel IDs on TPU, which require non-caching. mlir.register_lowering( - debug_callback_p, debug_callback_lowering, platform="tpu") + debug_callback_p, debug_callback_lowering, platform="tpu", + cacheable=False) + -def _debug_callback_partial_eval_custom(saveable, unks_in, inst_in, eqn): +def _debug_partial_eval_custom(saveable, unks_in, inst_in, eqn, primitive): # The default behavior for effectful primitives is to not stage them if # possible. For debug callback, we actually want it to be staged to # provide more information to the user. This rule bypasses partial_eval's @@ -226,7 +243,7 @@ def _debug_callback_partial_eval_custom(saveable, unks_in, inst_in, eqn): # The usual case (if we have any unknowns, we need to stage it out) res = [v for v, inst in zip(eqn.invars, inst_in) if not inst] return None, eqn, [], [], res - if saveable(debug_callback_p, *[v.aval for v in eqn.invars], **eqn.params): + if saveable(primitive, *[v.aval for v in eqn.invars], **eqn.params): # The policy is telling us we can save the debug callback. if all(inst_in): # If all of the inputs are instantiated, we also stage out the @@ -239,19 +256,162 @@ def _debug_callback_partial_eval_custom(saveable, unks_in, inst_in, eqn): # If we can't save the debug callback (thanks to the policy) we listen to # the policy and stage out the debug callback. return eqn, eqn, [], [], [] -pe.partial_eval_jaxpr_custom_rules[debug_callback_p] = ( - _debug_callback_partial_eval_custom) + + +pe.partial_eval_jaxpr_custom_rules[debug_callback_p] = partial( + _debug_partial_eval_custom, primitive=debug_callback_p +) @state_discharge.register_discharge_rule(debug_callback_p) def _debug_callback_state_discharge_rule( - in_avals, out_avals, *args, effect, callback, **params + in_avals, out_avals, *args, effect, partitioned, callback, **params ): del in_avals, out_avals # Unused. - out = debug_callback_p.bind(*args, effect=effect, callback=callback, **params) + out = debug_callback_p.bind( + *args, effect=effect, partitioned=partitioned, callback=callback, **params + ) return args, out -def debug_callback(callback: Callable[..., None], *args: Any, - ordered: bool = False, **kwargs: Any) -> None: + +def _split_callback_args(args, kwargs): + flat_args, in_tree = tree_util.tree_flatten((args, kwargs)) + static_args, dyn_args = {}, [] + for i, a in enumerate(flat_args): + try: + core.shaped_abstractify(a) + dyn_args.append(a) + except (AssertionError, TypeError): + static_args[i] = a + return in_tree, dyn_args, static_args + + +def merge_callback_args(in_tree, dyn_args, static_args): + static_args_dict = dict(static_args) + all_args = [None] * (len(static_args) + len(dyn_args)) + di = iter(dyn_args) + for i in range(len(all_args)): + if i in static_args_dict: + all_args[i] = static_args_dict[i] + else: + all_args[i] = next(di) + assert next(di, None) is None + return tree_util.tree_unflatten(in_tree, all_args) + + +def _make_flat_callback(in_tree, callback, static_args): + def _flat_callback(*dyn_args): + args, kwargs = merge_callback_args(in_tree, dyn_args, static_args) + callback(*args, **kwargs) + return () + return _flat_callback + + +debug_print_p = core.Primitive("debug_print") +debug_print_p.multiple_results = True + + +@debug_print_p.def_impl +def debug_print_impl( + *args: Any, + fmt: str, + ordered, + partitioned, + in_tree, + static_args, + np_printoptions, + has_placeholders, + logging_record, +): + callback = partial( + _format_print_callback, fmt, dict(np_printoptions), has_placeholders, + logging_record, + ) + callback = _make_flat_callback(in_tree, callback, static_args) + effect = ordered_debug_effect if ordered else debug_effect + debug_callback_impl( + *args, callback=callback, effect=effect, partitioned=partitioned + ) + return () + + +@debug_print_p.def_effectful_abstract_eval +def debug_print_abstract_eval(*avals: Any, fmt: str, ordered, **kwargs): + del avals, fmt, kwargs # Unused. + effect = ordered_debug_effect if ordered else debug_effect + return [], {effect} + + +batching.primitive_batchers[debug_print_p] = partial( + debug_batching_rule, primitive=debug_print_p +) + + +def debug_print_jvp_rule(primals, tangents, **params): + return debug_print_p.bind(*primals, **params), [] + + +ad.primitive_jvps[debug_print_p] = debug_print_jvp_rule + + +def debug_print_transpose_rule(_, *args, **kwargs): + del kwargs + return [None for _ in args] + + +ad.primitive_transposes[debug_print_p] = debug_print_transpose_rule + + +def debug_print_lowering_rule( + ctx, + *dyn_args, + fmt, + ordered, + partitioned, + in_tree, + static_args, + np_printoptions, + has_placeholders, + logging_record, +): + callback = partial( + _format_print_callback, + fmt, + dict(np_printoptions), + has_placeholders, + logging_record, + ) + callback = _make_flat_callback(in_tree, callback, static_args) + effect = ordered_debug_effect if ordered else debug_effect + return debug_callback_lowering( + ctx, *dyn_args, effect=effect, partitioned=partitioned, callback=callback + ) + + +mlir.register_lowering(debug_print_p, debug_print_lowering_rule, platform="cpu") +mlir.register_lowering(debug_print_p, debug_print_lowering_rule, platform="gpu") +mlir.register_lowering( + debug_print_p, debug_print_lowering_rule, platform="tpu", cacheable=False +) + +pe.partial_eval_jaxpr_custom_rules[debug_print_p] = partial( + _debug_partial_eval_custom, primitive=debug_print_p +) + + +@state_discharge.register_discharge_rule(debug_print_p) +def _debug_print_state_discharge_rule(in_avals, out_avals, *args, **kwargs): + del in_avals, out_avals # Unused. + out = debug_print_p.bind(*args, **kwargs) + return args, out + + +def debug_callback( + callback: Callable[..., None], + *args: Any, + ordered: bool = False, + partitioned: bool = False, + **kwargs: Any, +) -> None: """Calls a stageable Python callback. For more explanation, see `External Callbacks`_. @@ -274,6 +434,9 @@ def debug_callback(callback: Callable[..., None], *args: Any, ordered: A keyword only argument used to indicate whether or not the staged out computation will enforce ordering of this callback w.r.t. other ordered callbacks. + partitioned: If True, then print local shards only; this option avoids an + all-gather of the operands. If False, print with logical operands; this + option requires an all-gather of operands first. **kwargs: The keyword arguments to the callback. Returns: @@ -284,19 +447,12 @@ def debug_callback(callback: Callable[..., None], *args: Any, - :func:`jax.pure_callback`: callback designed for pure functions. - :func:`jax.debug.print`: callback designed for printing. - .. _External Callbacks: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html + .. _External Callbacks: https://docs.jax.dev/en/latest/notebooks/external_callbacks.html """ if not callable(callback): raise TypeError("first argument to jax.debug.callback must be callable, " f"but got an object of type {type(callback)}") - flat_args, in_tree = tree_util.tree_flatten((args, kwargs)) - static_args, dyn_args = {}, [] - for i, a in enumerate(flat_args): - try: - core.shaped_abstractify(a) - dyn_args.append(a) - except (AssertionError, TypeError): - static_args[i] = a + in_tree, dyn_args, static_args = _split_callback_args(args, kwargs) def _flat_callback(*dyn_args): all_args = [None] * (len(static_args) + len(dyn_args)) @@ -312,7 +468,10 @@ def _flat_callback(*dyn_args): return () effect = ordered_debug_effect if ordered else debug_effect - debug_callback_p.bind(*dyn_args, callback=_flat_callback, effect=effect) + debug_callback_p.bind( + *dyn_args, callback=_flat_callback, effect=effect, partitioned=partitioned + ) + class _DebugPrintFormatChecker(string.Formatter): @@ -334,11 +493,48 @@ def check_unused_args(self, used_args, args, kwargs): formatter = _DebugPrintFormatChecker() -def _format_print_callback(fmt: str, np_printoptions, *args, **kwargs): - with np.printoptions(**np_printoptions): - sys.stdout.write(fmt.format(*args, **kwargs) + "\n") -def debug_print(fmt: str, *args, ordered: bool = False, **kwargs) -> None: +def _format_print_callback( + fmt: str, np_printoptions, has_placeholders, logging_record, *args, **kwargs +): + if has_placeholders: + with np.printoptions(**np_printoptions): + msg = fmt.format(*args, **kwargs) + else: + assert not kwargs, "Format without placeholders should not have kwargs." + msg = " ".join((fmt, *(str(a) for a in args))) + if logging_record: + logging_record = copy.copy(logging_record) + logging_record.msg = msg + logger.handle(logging_record) + else: + sys.stdout.write(msg + "\n") + + +def _make_logging_record(level): + si = source_info_util.current() + user_frame = source_info_util.user_frame(si.traceback) + + file_name = "(unknown file)" + line_no = 0 + if user_frame: + file_name = user_frame.file_name + line_no = user_frame.start_line + args = () + return logger.makeRecord( + logger.name, level, file_name, line_no, "", args, None + ) + + +def debug_print( + fmt: str, + *args, + ordered: bool = False, + partitioned: bool = False, + skip_format_check: bool = False, + _use_logging: bool = False, + **kwargs, +) -> None: """Prints values and works in staged out JAX functions. This function does *not* work with f-strings because formatting is delayed. @@ -359,23 +555,50 @@ def debug_print(fmt: str, *args, **kwargs): Args: fmt: A format string, e.g. ``"hello {x}"``, that will be used to format - input arguments, like ``str.format``. See the Python docs on - `string formatting `_ - and `format string syntax `_. + input arguments, like ``str.format``. See the Python docs on `string + formatting `_ + and `format string syntax + `_. *args: A list of positional arguments to be formatted, as if passed to ``fmt.format``. - ordered: A keyword only argument used to indicate whether or not the - staged out computation will enforce ordering of this ``jax.debug.print`` - w.r.t. other ordered ``jax.debug.print`` calls. + ordered: A keyword only argument used to indicate whether or not the staged + out computation will enforce ordering of this ``jax.debug.print`` w.r.t. + other ordered ``jax.debug.print`` calls. + partitioned: If True, then print local shards only; this option avoids an + all-gather of the operands. If False, print with logical operands; this + option requires an all-gather of operands first. + skip_format_check: If True, the format string is not checked. This is useful + when using the function from inside a Pallas TPU kernel, where scalars + args will be printed after the format string. **kwargs: Additional keyword arguments to be formatted, as if passed to ``fmt.format``. """ - # Check that we provide the correct arguments to be formatted. - formatter.format(fmt, *args, **kwargs) - - debug_callback(partial(_format_print_callback, fmt, np.get_printoptions()), - *args, **kwargs, ordered=ordered) - + if not skip_format_check: + # Check that we provide the correct arguments to be formatted. + formatter.format(fmt, *args, **kwargs) + has_placeholders = False + if fmt: + _, field_name, *_ = next(iter(string.Formatter().parse(fmt))) + has_placeholders = field_name is not None + in_tree, dyn_args, static_args = _split_callback_args(args, kwargs) + static_args = tuple(static_args.items()) + np_printoptions = tuple(np.get_printoptions().items()) + + debug_print_p.bind( + *dyn_args, + fmt=fmt, + ordered=ordered, + partitioned=partitioned, + in_tree=in_tree, + static_args=static_args, + np_printoptions=np_printoptions, + has_placeholders=has_placeholders, + logging_record=(_make_logging_record(logging.INFO) if _use_logging + else None), + ) + + +debug_log = partial(debug_print, _use_logging=True) # Sharding visualization @@ -429,6 +652,7 @@ def _inspect_sharding_lowering_rule(ctx: mlir.LoweringRuleContext, value, *, mesh = mesh_lib.Mesh(np.array(devices).reshape(am.axis_sizes), am.axis_names) elif isinstance(axis_context, sharding_impls.SPMDAxisContext): + mesh = axis_context.mesh devices = axis_context.mesh._flat_devices_tuple else: raise NotImplementedError(type(axis_context)) @@ -439,8 +663,9 @@ def _inspect_sharding_lowering_rule(ctx: mlir.LoweringRuleContext, value, *, def _hlo_sharding_callback(hlo_sharding: xc.HloSharding): if mesh.empty: return callback( - sharding_impls._op_sharding_to_pos_sharding(hlo_sharding, devices)) - pspec = parse_flatten_op_sharding(hlo_sharding, mesh)[0] + sharding_impls.GSPMDSharding(devices, hlo_sharding)) + pspec = (P() if hlo_sharding.is_manual() else + parse_flatten_op_sharding(hlo_sharding, mesh)[0]) return callback(NamedSharding(mesh, pspec)) if len(devices) == 1: @@ -634,7 +859,7 @@ def inspect_array_sharding(value, *, callback: Callable[[Sharding], None]): """Enables inspecting array sharding inside JIT-ted functions. This function, when provided with a Pytree of arrays, calls back with each of - their shardings and works in ``pjit``-ted computations, enabling inspecting + their shardings and works in ``jax.jit``-ted computations, enabling inspecting the chosen intermediate shardings. The policy for when ``callback`` is called is *as early as possible* when the @@ -643,9 +868,9 @@ def inspect_array_sharding(value, *, callback: Callable[[Sharding], None]): since we have the array and its sharding readily available. Inside of a ``jax.jit``, the callback will happen at lowering time, meaning you can trigger the callback using the AOT API (``jit(f).lower(...)``). When inside of - a ``pjit``, the callback happens *at compile time* since the sharding is + a ``jax.jit``, the callback happens *at compile time* since the sharding is determined by XLA. You can trigger the callback by using JAX's AOT API - (``pjit(f).lower(...).compile()``). In all cases, the callback will be + (``jax.jit(f).lower(...).compile()``). In all cases, the callback will be triggered by running the function, since running a function entails lowering and compiling it first. However, once the function is compiled and cached, the callback will no longer occur. @@ -657,11 +882,10 @@ def inspect_array_sharding(value, *, callback: Callable[[Sharding], None]): callback: A callable that takes in a ``Sharding`` and doesn't return a value. In the following example, we print out the sharding of an intermediate value - in a ``pjit``-ted computation: + in a ``jax.jit``-ted computation: >>> import jax >>> import jax.numpy as jnp - >>> from jax.experimental.pjit import pjit >>> from jax.sharding import Mesh, PartitionSpec >>> >>> x = jnp.arange(8, dtype=jnp.float32) @@ -669,9 +893,9 @@ def inspect_array_sharding(value, *, callback: Callable[[Sharding], None]): ... x = jnp.sin(x) ... jax.debug.inspect_array_sharding(x, callback=print) ... return jnp.square(x) - >>> f = pjit(f_, in_shardings=PartitionSpec('dev'), - ... out_shardings=PartitionSpec('dev')) - >>> with Mesh(jax.devices(), ('dev',)): + >>> f = jax.jit(f_, in_shardings=PartitionSpec('dev'), + ... out_shardings=PartitionSpec('dev')) + >>> with jax.set_mesh(Mesh(jax.devices(), ('dev',))): ... f.lower(x).compile() # doctest: +SKIP ... NamedSharding(mesh={'dev': 8}, partition_spec=PartitionSpec(('dev',),)) @@ -685,3 +909,48 @@ def visualize_array_sharding(arr, **kwargs): def _visualize(sharding): return visualize_sharding(arr.shape, sharding, **kwargs) inspect_array_sharding(arr, callback=_visualize) + + +# TODO(mattjj): working around an apparent XLA or PjRt bug, remove eventually +def _debug_callback_eager_rule( + mesh, + *args, + callback: Callable[..., Any], + effect: DebugEffect, + partitioned: bool, +): + del effect + with core.eval_context(): + all_blocks = zip(*map(list, args)) + for (idx, device), blocks in zip(np.ndenumerate(mesh.devices), all_blocks): + callback(*blocks) + return [] +shard_map.eager_rules[debug_callback_p] = _debug_callback_eager_rule + + +def _debug_print_eager_rule( + mesh, + *args, + fmt: str, + ordered, + partitioned, + in_tree, + static_args, + np_printoptions, + has_placeholders, + logging_record, +): + del ordered, partitioned + callback = partial( + _format_print_callback, fmt, dict(np_printoptions), has_placeholders, + logging_record, + ) + callback = _make_flat_callback(in_tree, callback, static_args) + with core.eval_context(): + all_blocks = zip(*map(list, args)) + for (idx, device), blocks in zip(np.ndenumerate(mesh.devices), all_blocks): + callback(*blocks) + return [] + + +shard_map.eager_rules[debug_print_p] = _debug_print_eager_rule diff --git a/jax/_src/deprecations.py b/jax/_src/deprecations.py index 37f2f0264782..633a1b1808b9 100644 --- a/jax/_src/deprecations.py +++ b/jax/_src/deprecations.py @@ -123,16 +123,12 @@ def warn(deprecation_id: str, message: str, stacklevel: int) -> None: # Register a number of deprecations: we do this here to ensure they're # always registered by the time `accelerate` and `is_acelerated` are called. -register('jax-aval-named-shape') -register('jax-dlpack-import-legacy') +register('default-dtype-bits-config') +register('jax-checkpoint-concrete') register('jax-nn-one-hot-float-input') -register("jax-numpy-astype-complex-to-real") -register("jax-numpy-array-none") +register('jax-numpy-arange-complex') +register('jax-numpy-astype-complex-to-real') register('jax-numpy-clip-args') -register('jax-numpy-linalg-matrix_rank-tol') -register('jax-numpy-linalg-pinv-rcond') -register('jax-numpy-quantile-interpolation') -register('jax-numpy-reduction-non-boolean-where') -register('jax-numpy-trimzeros-not-1d-array') -register('pallas-gpu-triton') register('jax-scipy-special-sph-harm') +register('jax-pmap-shmap-merge') +register('pltpu-memory-space-any') diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 2330f7628966..e3876469c8e9 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -18,40 +18,42 @@ import atexit from collections.abc import Sequence import dataclasses -import enum from functools import partial -import itertools import logging import threading import time -from typing import Any, Callable, NamedTuple +from typing import Any -import jax from jax._src import api from jax._src import array from jax._src import basearray from jax._src import config from jax._src import core from jax._src import dtypes -from jax._src import lib -from jax._src import source_info_util + +from jax._src import literals +from jax._src import pjit from jax._src import traceback_util from jax._src import util + +from jax._src import xla_bridge from jax._src.abstract_arrays import array_types from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir +from jax._src.interpreters import partial_eval from jax._src.interpreters import pxla -from jax._src.interpreters import xla -from jax._src.layout import DeviceLocalLayout, Layout +from jax._src.api_util import InternalFloatingPointError +from jax._src.layout import Layout, Format from jax._src.lib import xla_client as xc from jax._src.mesh import AbstractMesh, Mesh -from jax._src.monitoring import record_event_duration_secs, record_event_time_span +from jax._src.monitoring import record_scalar, record_event_duration_secs, record_event_time_span from jax._src.partition_spec import PartitionSpec from jax._src.sharding import Sharding -from jax._src.sharding_impls import ( NamedSharding, - SingleDeviceSharding, TransferToMemoryKind, +from jax._src.sharding_impls import ( + NamedSharding, SingleDeviceSharding, GSPMDSharding, is_single_device_sharding) +from jax._src.stages import SourceInfo import numpy as np @@ -65,6 +67,7 @@ Backend = xe.Client Device = xc.Device +ArrayCopySemantics = xc.ArrayCopySemantics CompileOptions = xc.CompileOptions @@ -83,13 +86,18 @@ def apply_primitive(prim, *args, **params): fun = xla_primitive_callable(prim, **params) # TODO(yashkatariya): Investigate adding is_primitive to jit and never # triggering the disable jit path instead of messing around with it here. - prev = lib.jax_jit.swap_thread_local_state_disable_jit(False) + prev = config.disable_jit.swap_local(False) try: outs = fun(*args) finally: - lib.jax_jit.swap_thread_local_state_disable_jit(prev) + config.disable_jit.set_local(prev) return outs +# TODO(necula): this cache will contain strong references to all +# Jaxprs in `params` (for higher-order primitives). +# This is not immediately fixable by using +# util.multi_weakref_lru_cache, because the `params` (including the Jaxpr) +# are closed over in the `prim_fun` lambda. Leaving this fix for a later PR. @util.cache() def xla_primitive_callable(prim: core.Primitive, **params): util.test_event("xla_primitive_callable_cache_miss") @@ -132,12 +140,17 @@ def get_token_input( # TODO(yueshengys): This might still be buggy in a multi-process SPMD # scenario. Revise the logic later. A distributed shutdown barrier inside # the XLA program may be needed. - return jax.device_put(tok, jax.sharding.PositionalSharding(devices)) + return api.device_put( + tok, NamedSharding(Mesh(devices, 'x'), PartitionSpec('x'))) # We only use replicated sharding for the first time when the token for the # order effect hasn't been created. - s = jax.sharding.GSPMDSharding.get_replicated(devices) - sharded_tok = core.Token(pxla.shard_args([s], [None], [None], [tok])[0]) + s = GSPMDSharding.get_replicated(devices) + sharded_tok = core.Token( + pxla.shard_args( + [s], [None], [xc.ArrayCopySemantics.REUSE_INPUT], [tok] + )[0] + ) self.current_tokens[eff] = sharded_tok return sharded_tok @@ -178,6 +191,10 @@ def __init__(self, fmt: str, fun_name: str, event: str | None = None): def __enter__(self): self.start_time = time.time() + if self.event is not None: + record_scalar( + self.event, self.start_time, fun_name=self.fun_name + ) def __exit__(self, exc_type, exc_value, traceback): if _on_exit: @@ -190,8 +207,12 @@ def __exit__(self, exc_type, exc_value, traceback): logger.log(log_priority, self.fmt.format( fun_name=self.fun_name, elapsed_time=elapsed_time)) if self.event is not None: - record_event_duration_secs(self.event, elapsed_time) - record_event_time_span(self.event, self.start_time, end_time) + record_event_duration_secs( + self.event, elapsed_time, fun_name=self.fun_name + ) + record_event_time_span( + self.event, self.start_time, end_time, fun_name=self.fun_name + ) log_elapsed_time = LogElapsedTimeContextManager @@ -231,16 +252,10 @@ def jaxpr_has_prim_requiring_devices(jaxpr: core.Jaxpr) -> bool: return False -class SourceInfo(NamedTuple): - source_info: source_info_util.SourceInfo - eqn_name: str - - @util.weakref_lru_cache def get_intermediate_shardings( jaxpr: core.Jaxpr) -> Sequence[tuple[Sharding, SourceInfo]]: - from jax._src import pjit - from jax.experimental import shard_map + from jax._src import shard_map # pytype: disable=import-error out = [] for eqn in jaxpr.eqns: @@ -250,19 +265,17 @@ def get_intermediate_shardings( continue source_info = SourceInfo(eqn.source_info, eqn.primitive.name) out.append((s, source_info)) - elif eqn.primitive is pjit.pjit_p: + elif eqn.primitive is pjit.jit_p: source_info = SourceInfo(eqn.source_info, eqn.primitive.name) out.extend((i, source_info) for i in eqn.params['in_shardings']) out.extend((o, source_info) for o in eqn.params['out_shardings']) elif eqn.primitive is shard_map.shard_map_p: - if isinstance(eqn.params['mesh'], AbstractMesh): + mesh = eqn.params['mesh'] + if isinstance(mesh, AbstractMesh): continue source_info = SourceInfo(eqn.source_info, eqn.primitive.name) - def _names_to_pspec(names): - ndmin = max(names) + 1 if names else 0 - return PartitionSpec(*(names.get(i) for i in range(ndmin))) - out.extend((NamedSharding(eqn.params['mesh'], _names_to_pspec(names)), source_info) - for names in [*eqn.params['in_names'], *eqn.params['out_names']]) + out.extend((NamedSharding(mesh, spec), source_info) + for spec in [*eqn.params['in_specs'], *eqn.params['out_specs']]) elif eqn.primitive is device_put_p: source_info = SourceInfo(eqn.source_info, eqn.primitive.name) out.extend((s, source_info) for s in eqn.params['devices'] @@ -272,53 +285,12 @@ def _names_to_pspec(names): return out -def jaxpr_has_bints(jaxpr: core.Jaxpr) -> bool: - return (any(type(v.aval.dtype) is core.bint for v in jaxpr.invars - if isinstance(v.aval, core.UnshapedArray)) or - any(_is_bint_axis_size(d) - for j in itertools.chain([jaxpr], core.subjaxprs(jaxpr)) - for e in j.eqns for v in e.outvars - if isinstance(v.aval, core.DShapedArray) for d in v.aval.shape)) - -def _is_bint_axis_size(d: core.AxisSize) -> bool: - if isinstance(d, core.DArray): - assert not d.shape - return type(d.dtype) is core.bint - elif isinstance(d, core.Var): - return (isinstance(d.aval, core.DShapedArray) and - type(d.aval.dtype) is core.bint) - return False - - def check_arg(arg: Any): if not (isinstance(arg, core.Tracer) or core.valid_jaxtype(arg)): raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid " "JAX type.") -def jaxpr_replicas(jaxpr: core.Jaxpr) -> int: - """The number of replicas needed for a jaxpr. - - For a eqn, multiply the `axis_size` with the `jaxpr_replicas` of the - subjaxprs. For a list of eqns, take the maximum number of replicas. - """ - return max(unsafe_map(_eqn_replicas, jaxpr.eqns), default=1) - -# TODO(mattjj): this function assumes that only pmap has a parameter named -# axis_size, and that it corresponds to cross-replica mapping -def _eqn_replicas(eqn: core.JaxprEqn) -> int: - call_jaxpr = eqn.params.get("call_jaxpr") - if call_jaxpr: - return eqn.params.get('axis_size', 1) * jaxpr_replicas(call_jaxpr) - elif eqn.primitive in xla.initial_style_primitives: - return _initial_style_primitive_replicas(eqn.params) - else: - return 1 - -def _initial_style_primitive_replicas(params: dict[str, Any]) -> int: - return max(core.traverse_jaxpr_params(jaxpr_replicas, params).values(), - default=1) - def needs_check_special() -> bool: return config.debug_infs.value or config.debug_nans.value @@ -327,6 +299,15 @@ def check_special(name: str, bufs: Sequence[basearray.Array]) -> None: for buf in bufs: _check_special(name, buf.dtype, buf) + +def check_special_array(name: str, arr: array.ArrayImpl) -> array.ArrayImpl: + if needs_check_special(): + if dtypes.issubdtype(arr.dtype, np.inexact): + for buf in arr._arrays: + _check_special(name, buf.dtype, buf) + return arr + + def _check_special(name: str, dtype: np.dtype, buf: basearray.Array) -> None: if dtypes.issubdtype(dtype, np.inexact): if config.debug_nans.value and np.any(np.isnan(np.asarray(buf))): @@ -334,71 +315,22 @@ def _check_special(name: str, dtype: np.dtype, buf: basearray.Array) -> None: if config.debug_infs.value and np.any(np.isinf(np.asarray(buf))): raise InternalFloatingPointError(name, "inf") -class CopySemantics(enum.Enum): - ALIAS = enum.auto() - COPY = enum.auto() - DONATE = enum.auto() - -class InternalFloatingPointError(Exception): - name: str - ty: str - - def __init__(self, name: str, ty: str): - self.name = name - self.ty = ty - -def maybe_recursive_nan_check(e: Exception, fun: Callable, args, kwargs, -) -> None: # always raises an exception - print("Invalid nan value encountered in the output of a jax.jit " - "function. Calling the de-optimized version.") - try: - _ = fun(*args, **kwargs) - except (FloatingPointError, ZeroDivisionError) as e2: - raise e2 from None - else: - _raise_no_nan_in_deoptimized(e) - -def _raise_no_nan_in_deoptimized(e) -> None: - msg = (f"{str(e)}. Because " - "jax_config.debug_nans.value and/or config.jax_debug_infs is set, the " - "de-optimized function (i.e., the function as if the `jit` " - "decorator were removed) was called in an attempt to get a more " - "precise error message. However, the de-optimized function did not " - "produce invalid values during its execution. This behavior can " - "result from `jit` optimizations causing the invalid value to be " - "produced. It may also arise from having nan/inf literals as " - "inputs or outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. " - "\n\n" - "It may be possible to avoid the invalid value by removing the " - "`jit` decorator, at the cost of losing optimizations. " - "\n\n" - "If you see this error, consider opening a bug report at " - "https://github.com/jax-ml/jax.") - raise FloatingPointError(msg) from None - def _identity_fn(x): return x -def _different_device_order_reshard(x, target_sharding, copy: CopySemantics): + +def _different_device_order_reshard( + x: array.ArrayImpl, target_sharding: NamedSharding, copy: ArrayCopySemantics +) -> array.ArrayImpl: x._check_if_deleted() inp_sharding = x.sharding assert isinstance(inp_sharding, NamedSharding) - donate_argnums = 0 if copy == CopySemantics.DONATE else None + donate_argnums = 0 if copy == ArrayCopySemantics.DONATE_INPUT else None if inp_sharding._device_assignment == target_sharding._device_assignment: return api.jit(_identity_fn, out_shardings=target_sharding, donate_argnums=donate_argnums)(x) - if inp_sharding.device_set != target_sharding.device_set: - inp_ids = [d.id for d in inp_sharding._device_assignment] - inp_plat = inp_sharding._device_assignment[0].platform.upper() - target_ids = [d.id for d in target_sharding._device_assignment] - target_plat = target_sharding._device_assignment[0].platform.upper() - raise ValueError("Input and target sharding should have the same set of " - f"devices. Got input's device set ids: {inp_ids} on " - f"platform {inp_plat} and target sharding's device set " - f"ids: {target_ids} on platform {target_plat}") - if inp_sharding.is_fully_replicated: permute_order = None else: @@ -411,16 +343,44 @@ def _different_device_order_reshard(x, target_sharding, copy: CopySemantics): new_mesh, inp_sharding.spec, memory_kind=target_sharding.memory_kind, _logical_device_ids=(None if permute_order is None else tuple(permute_order.tolist()))) - new_x = _reorder_shards(x, new_s, CopySemantics.ALIAS) + new_x = xc.reorder_shards(x, new_s, ArrayCopySemantics.REUSE_INPUT) # type: ignore return api.jit(_identity_fn, out_shardings=target_sharding, donate_argnums=donate_argnums)(new_x) -def _reorder_shards(x, new_s, copy_semantics: CopySemantics): - """Reorders array shards to match the order indicated by the new sharding.""" - xc_copy_semantics = pxla.to_xc_copy_semantics([copy_semantics])[0] - return xc.reorder_shards(x, new_s, xc_copy_semantics) # type: ignore - +@util.cache(max_size=2048, trace_context_in_key=False) +def _is_supported_cross_host_transfer(ndim, src_sharding, dst_sharding): + """Returns True if src->dst is a supported cross-host transfer.""" + if (src_sharding._internal_device_list.device_kind != + dst_sharding._internal_device_list.device_kind): + return False + if (src_sharding._to_xla_hlo_sharding(ndim) != + dst_sharding._to_xla_hlo_sharding(ndim)): + return False + # This check excludes the case where the source and destination shardings + # have the same process index sets but there are shards that require + # cross-host transfers. This case is supportable but expensive to check for. + different_process_inds = ( + src_sharding._internal_device_list.process_indices != + dst_sharding._internal_device_list.process_indices) + backend = xla_bridge.get_backend() + # If a cross-host device transfer is requested but the backend does not + # support it, then the user must set the flags to enable DCN-based transfers. + if (different_process_inds and + (xla_bridge.FORCE_DCN_CROSS_HOST_TRANSFERS.value + or not getattr(backend, "supports_cross_host_transfers", False)) and + not xla_bridge.CROSS_HOST_TRANSFER_SOCKET_ADDRESS.value): + if xla_bridge.FORCE_DCN_CROSS_HOST_TRANSFERS.value: + msg = ("DCN-based cross-host transfers were requested with the " + "jax_force_dcn_cross_host_transfers flag.") + else: + msg = ("The backend ({backend.platform}, {backend.platform_version}) " + "does not support cross-host device transfers.") + raise ValueError( + f"{msg} Please set jax_cross_host_transfer_socket_address and " + "(optionally) jax_cross_host_transport_addresses flags to enable " + "DCN-based cross host device transfers.") + return different_process_inds @dataclasses.dataclass(frozen=True) class _DeferredShardArg: @@ -435,40 +395,101 @@ class _DeferredShardArg: s: Sharding aval: core.AbstractValue committed: bool - copy_semantics: CopySemantics + copy_semantics: ArrayCopySemantics def result_handler(self, shard_arg_result): return pxla.global_aval_to_result_handler( self.aval, self.s, self.committed)(shard_arg_result) +@dataclasses.dataclass(frozen=True) +class _DeferredCrossHostTransferArg: + """Deferred call to `xc.batched_copy_array_to_devices_with_sharding` for + cross-host data transfers. + + Per-array impls return this object instead of a result array to indicate a + deferred `batched_copy_array_to_devices_with_sharding` call for a cross-host + data transfer. `_batched_device_put_impl` then batches all + `_DeferredCrossHostTransferArg` objects into a single + `_batched_device_put_impl` call. -def _device_put_sharding_impl(x, aval, device, copy): - from jax.experimental import multihost_utils + For any _DeferredCrossHostTransferArg, _is_supported_cross_host_transfer( + x.ndim, x.sharding, dst_sharding) == True. + """ + + x: array.ArrayImpl + dst_sharding: Sharding + copy_semantics: ArrayCopySemantics + + +def _device_put_sharding_impl( + x: Any, + aval: core.ShapedArray, + device: Device | Sharding | None, + copy: ArrayCopySemantics, +): + from jax.experimental import multihost_utils # pytype: disable=import-error + + if isinstance(x, array.ArrayImpl): + x_is_jax_array = True + x_is_fully_addressable, x_sharding = x.is_fully_addressable, x.sharding + else: + x_is_jax_array = False + x_is_fully_addressable, x_sharding = None, None if isinstance(device, Sharding): s = device + s_is_fully_addressable = s.is_fully_addressable if (getattr(x, 'sharding', None) == s and getattr(x, '_committed', False) - and copy == CopySemantics.ALIAS): + and copy == ArrayCopySemantics.REUSE_INPUT): return x - if (not s.is_fully_addressable and - isinstance(x, array.ArrayImpl) and not x.is_fully_addressable): - assert isinstance(s, Sharding) + if (not s_is_fully_addressable and + x_is_jax_array and not x_is_fully_addressable and + s.device_set == x_sharding.device_set): + assert isinstance(s, NamedSharding), s return _different_device_order_reshard(x, s, copy) - if (s.is_fully_addressable and isinstance(x, array.ArrayImpl) and - x.is_fully_addressable and s.num_devices > 1 and - s._internal_device_list != x.sharding._internal_device_list and # pytype: disable=attribute-error - s.device_set == x.sharding.device_set): - assert isinstance(s, Sharding) + if (s_is_fully_addressable and x_is_jax_array and + x_is_fully_addressable and s.num_devices > 1 and + s._internal_device_list != x_sharding._internal_device_list and # pytype: disable=attribute-error + s.device_set == x_sharding.device_set): + assert isinstance(s, NamedSharding), s return _different_device_order_reshard(x, s, copy) - if not s.is_fully_addressable: - if ((isinstance(x, array.ArrayImpl) and not x._committed) or - type(x) in array_types): - # TODO(emilyaf): Remove this condition when jit works when a sharding - # has no local devices. - if not config.enable_empty_arrays.value: + if (x_is_jax_array and x._committed and xla_bridge.process_count() > 1 + and _is_supported_cross_host_transfer(x.ndim, x_sharding, s)): + return _DeferredCrossHostTransferArg(x, s, copy) + + if not s_is_fully_addressable: + # If both the source and target shardings are not fully addressable and + # one of the above conditions has not been met, then assume that the user + # is attempting a different device order reshard. + if (x_is_jax_array and not x_is_fully_addressable + and s.device_set != x_sharding.device_set): + inp_ids = [d.id for d in x_sharding._device_assignment] + inp_plat = x_sharding._device_assignment[0].platform.upper() + target_ids = [d.id for d in s._device_assignment] + target_plat = s._device_assignment[0].platform.upper() + raise ValueError( + "For a cross-host reshard in multi-controller JAX, input and target" + " sharding should have the same set of devices. Got input's device" + f" set ids: {inp_ids} on platform {inp_plat} and target sharding's" + f" device set ids: {target_ids} on platform {target_plat}.\n\n" + "There is experimental support for cross-host transfers with " + "different device sets, when input/output shardings have the same " + "indices and layouts, in the TFRT TPU runtime only.") + + if ((x_is_jax_array and not x._committed) or + type(x) in array_types or type(x) in dtypes.python_scalar_types): + # If all hosts participate in the sharding, assert that the input is the + # same on all hosts. If some hosts have no addressable devices in the + # sharding, bypass the check, since we can't easily distinguish between + # these two cases: (1) the sharding contains the same subset of global + # devices on all hosts (and hosts with no addressable devices in the + # sharding do not transfer data) or (2) the sharding contains a + # different subset of devices on each host. For (1), the input should be + # the same on all hosts, but for (2) it need not be. + if xla_bridge.process_count() == len(s._internal_device_list.process_indices): # pytype: disable=attribute-error multihost_utils.assert_equal( x, fail_message=( f"{type(x)} passed to device_put is not the same on each" @@ -483,18 +504,21 @@ def _device_put_sharding_impl(x, aval, device, copy): return _DeferredShardArg(x, s, aval, True, copy) # Only `Device` exists below. `Sharding` instance is handled above. - if isinstance(x, array.ArrayImpl): - if not x.is_fully_addressable: + if x_is_jax_array: + if not x_is_fully_addressable: raise ValueError( "device_put's first argument must be a fully addressable array, but " f"got value with devices {x.devices()}") if device is None: - if copy == CopySemantics.ALIAS: + if copy == ArrayCopySemantics.REUSE_INPUT: return x else: - return _DeferredShardArg(x, x.sharding, aval, x.committed, copy) - elif is_single_device_sharding(x.sharding): - device = x.sharding._device_assignment[0] if device is None else device + return _DeferredShardArg(x, x_sharding, aval, x.committed, copy) + elif is_single_device_sharding(x_sharding): + device = x_sharding._device_assignment[0] if device is None else device + if copy == ArrayCopySemantics.ALWAYS_COPY: + return xc.batched_device_put(aval, SingleDeviceSharding(device), [x], + [device], True, True) return pxla.batched_device_put(aval, SingleDeviceSharding(device), [x], [device]) @@ -504,80 +528,138 @@ def _device_put_sharding_impl(x, aval, device, copy): def _device_put_impl( - x, *, device: Device | Sharding | Layout | None, - src: Device | Sharding | Layout | None, copy: CopySemantics): - if (isinstance(device, TransferToMemoryKind) or - isinstance(src, TransferToMemoryKind)): - raise ValueError( - "TransferToMemoryKind argument to jax.device_put can only be used" - " inside jax.jit. If you are using device_put outside jax.jit, then" - " please provide a concrete Sharding with memory_kind.") - - try: - aval = core.abstractify(x) - except TypeError as err: - raise TypeError( - f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err - - if isinstance(device, Layout): + x, *, device: Device | Sharding | Format | None, + src: Device | Sharding | Format | None, copy: ArrayCopySemantics, aval): + if aval is None: + try: + aval = core.abstractify(x) + aval = update_dp_aval(aval, device) + except TypeError as err: + raise TypeError( + f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err + + if isinstance(device, core.MemorySpace): + return apply_primitive(device_put_p, x, devices=(device,), srcs=(src,), + copy_semantics=(copy,))[0] + + if isinstance(device, Format): l = device - dll = l.device_local_layout - x_dll = x.layout.device_local_layout if hasattr(x, 'layout') else None + dll = l.layout + x_dll = x.format.layout if hasattr(x, 'format') else None if dll is None and l.sharding is None: return _device_put_sharding_impl(x, aval, l.sharding, copy) if (not isinstance(l.sharding, Sharding) or - not isinstance(dll, (DeviceLocalLayout, type(None)))): + not isinstance(dll, (Layout, type(None)))): raise ValueError( - "sharding and device_local_layout in `Layout` instance should be" + "sharding and layout in `Layout` instance should be" f" concrete. Got layout: {l} for input {aval.str_short()}") - if (getattr(x, 'layout', None) == l and getattr(x, '_committed', False) and - copy == CopySemantics.ALIAS): + if (getattr(x, 'format', None) == l and getattr(x, '_committed', False) and + copy == ArrayCopySemantics.REUSE_INPUT): return x if x_dll is None and dll is None: return _device_put_sharding_impl(x, aval, l.sharding, copy) return api.jit( - _identity_fn, out_shardings=l, - donate_argnums=(0 if copy == CopySemantics.DONATE else None))(x) + _identity_fn, + out_shardings=l, + donate_argnums=(0 if copy == ArrayCopySemantics.DONATE_INPUT else None), + )(x) return _device_put_sharding_impl(x, aval, device, copy) def _batched_device_put_impl( *xs, - devices: Sequence[Device | Sharding | Layout | None], - srcs: Sequence[Device | Sharding | Layout | None], - copy_semantics: Sequence[CopySemantics]): + devices: Sequence[Device | Sharding | Format | None], + srcs: Sequence[Device | Sharding | Format | None], + copy_semantics: Sequence[ArrayCopySemantics], + dst_avals: Sequence[core.ShapedArray | None]): ys = [] + + # Used to batch transfers when _device_put_impl returns a _DeferredShardArg. dsa_indices, dsa_xs, dsa_shardings, dsa_copy_semantics = [], [], [], [] - for i, (x, device, src, cp) in enumerate(zip(xs, devices, srcs, copy_semantics)): - y = _device_put_impl(x, device=device, src=src, copy=cp) + # Used to batch transfers when _device_put_impl returns a + # _DeferredCrossHostTransferArg. + dca_indices, dca_xs, dca_shardings, dca_device_lists, dca_copy_semantics = \ + [], [], [], [], [] + + for i, (x, device, src, cp, aval) in enumerate( + zip(xs, devices, srcs, copy_semantics, dst_avals)): + y = _device_put_impl(x, device=device, src=src, copy=cp, aval=aval) if isinstance(y, _DeferredShardArg): dsa_indices.append(i) dsa_xs.append(y.x) dsa_shardings.append(y.s) dsa_copy_semantics.append(y.copy_semantics) + elif isinstance(y, _DeferredCrossHostTransferArg): + dca_indices.append(i) + dca_xs.append(y.x) + dca_shardings.append(y.dst_sharding) + dca_device_lists.append(y.dst_sharding._internal_device_list) # pytype: disable=attribute-error + dca_copy_semantics.append(y.copy_semantics) ys.append(y) + # Batch shard_arg / batched_copy_array_to_devices_with_sharding calls. Helps + # improve efficiency for backends that support efficient batch transfer. if dsa_xs: - # Batch shard_arg calls. Helps improve efficiency for backends that support - # efficient batch transfer. - # device_put handles `Layout` via a different path, so just pass `None` as + # device_put handles `Format` via a different path, so just pass `None` as # the layout here. shard_arg_results = pxla.shard_args(dsa_shardings, [None] * len(dsa_xs), dsa_copy_semantics, dsa_xs) for i, shard_arg_result in zip(dsa_indices, shard_arg_results): assert isinstance(ys[i], _DeferredShardArg) ys[i] = ys[i].result_handler(shard_arg_result) + if dca_xs: + copy_array_results = xc.batched_copy_array_to_devices_with_sharding( + dca_xs, dca_device_lists, dca_shardings, dca_copy_semantics) + for i, copy_array_result in zip(dca_indices, copy_array_results): + assert isinstance(ys[i], _DeferredCrossHostTransferArg) + ys[i] = copy_array_result return ys +def batched_device_put_impl( + *xs, + devices: Sequence[Device | Sharding | Format | None], + srcs: Sequence[Device | Sharding | Format | None], + copy_semantics: Sequence[ArrayCopySemantics]): + return _batched_device_put_impl( + *xs, devices=devices, srcs=srcs, copy_semantics=copy_semantics, + dst_avals=[None] * len(devices)) + device_put_p = core.Primitive('device_put') device_put_p.multiple_results = True -device_put_p.def_impl(_batched_device_put_impl) +device_put_p.def_impl(batched_device_put_impl) + + +def _device_put_folding_rule(consts, params, out_avals): + # We elide device_puts that do nothing; these can be generated by jnp.array, + # for example. + if (all(x is None for x in params["devices"]) + and all(isinstance(x, literals.TypedNdArray) for x in consts) + and all(x == ArrayCopySemantics.REUSE_INPUT for x in params["copy_semantics"])): + return consts + return None + +partial_eval.const_fold_rules[device_put_p] = _device_put_folding_rule + + +def update_dp_aval(aval, d): + if not isinstance(aval, core.ShapedArray): + return aval + if isinstance(d, Sharding): + aval = (aval.update(sharding=aval.sharding.update(mesh=d.mesh.abstract_mesh, + spec=d.spec)) + if isinstance(d, NamedSharding) else aval.update(sharding=None)) + if d.memory_kind is not None: + aval = aval.update(memory_space=core.mem_kind_to_space(d.memory_kind)) + return aval + elif isinstance(d, core.MemorySpace): + return aval.update(memory_space=d) + return aval def _device_put_abstract_eval(*xs, devices, srcs, copy_semantics): - return xs + return [update_dp_aval(x, d) for x, d in zip(xs, devices)] device_put_p.def_abstract_eval(_device_put_abstract_eval) def _device_put_transpose(cts, *_, devices, srcs, copy_semantics): @@ -590,17 +672,17 @@ def _device_put_transpose(cts, *_, devices, srcs, copy_semantics): indices, args, devices, srcs, copy_semantics = list(zip(*dp_args)) new_copy_semantics = [] for cp in copy_semantics: - if cp == CopySemantics.DONATE: + if cp == ArrayCopySemantics.DONATE_INPUT: raise ValueError( "donate=True is not allowed during tranposition of device_put." " Please file an issue if you want this to be supported.") - elif cp == CopySemantics.ALIAS: - new_copy_semantics.append(CopySemantics.COPY) + elif cp == ArrayCopySemantics.REUSE_INPUT: + new_copy_semantics.append(ArrayCopySemantics.ALWAYS_COPY) else: - assert cp == CopySemantics.COPY - new_copy_semantics.append(CopySemantics.COPY) + assert cp == ArrayCopySemantics.ALWAYS_COPY + new_copy_semantics.append(ArrayCopySemantics.ALWAYS_COPY) ys = device_put_p.bind(*args, devices=srcs, srcs=devices, - copy_semantics=new_copy_semantics) + copy_semantics=tuple(new_copy_semantics)) for i, y in zip(indices, ys): results[i] = y return results @@ -621,8 +703,8 @@ def _tpu_gpu_device_put_lowering(ctx, *xs, devices, srcs, copy_semantics): if ctx.module_context.all_default_mem_kind: return xs def lower(x, device, aval, out_aval): - if (isinstance(device, (Sharding, TransferToMemoryKind)) and - device.memory_kind is not None): + if ((isinstance(device, Sharding) and device.memory_kind is not None) or + isinstance(device, core.MemorySpace)): if isinstance(device, Sharding): if config.use_shardy_partitioner.value: x = mlir.wrap_with_sharding_op( @@ -632,7 +714,9 @@ def lower(x, device, aval, out_aval): x = mlir.wrap_with_sharding_op( ctx, x, out_aval, device._to_xla_hlo_sharding(aval.ndim).to_proto()) - x = mlir.wrap_with_memory_kind(x, device.memory_kind, out_aval) + mem_kind = (core.mem_space_to_kind(device) + if isinstance(device, core.MemorySpace) else device.memory_kind) + x = mlir.wrap_with_memory_kind(x, mem_kind, out_aval) return x return x return list(map(lower, xs, devices, ctx.avals_in, ctx.avals_out)) @@ -646,13 +730,3 @@ def lower(x, device, aval, out_aval): def _common_device_put_lowering(ctx, *xs, devices, srcs, copy_semantics): return xs mlir.register_lowering(device_put_p, _common_device_put_lowering) - -def _propagate_mem_kind_dp(*xm, devices, srcs, copy_semantics): - memory_kinds = [] - for device in devices: - if isinstance(device, (Sharding, TransferToMemoryKind)): - memory_kinds.append(device.memory_kind) - else: - memory_kinds.append(None) - return memory_kinds -pxla.memory_kind_propagate_rule[device_put_p] = _propagate_mem_kind_dp diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py index af50e2e9e31a..b6f1ab7443d6 100644 --- a/jax/_src/distributed.py +++ b/jax/_src/distributed.py @@ -18,11 +18,12 @@ import logging import os from typing import Any +import warnings from jax._src import clusters from jax._src import config from jax._src import xla_bridge -from jax._src.lib import xla_extension +from jax._src.lib import _jax logger = logging.getLogger(__name__) @@ -34,13 +35,32 @@ ) +_ENABLE_RECOVERABILITY = config.bool_state( + name="jax_enable_recoverability", + default=False, + help=( + "Allows a multi-controller JAX job to continue running, even after some" + " tasks have failed." + ), +) + +_ENABLE_PREEMPTION_SERVICE = config.bool_state( + name='jax_enable_preemption_service', + default=True, + help=( + "Enables the preemption service. See" + " multihost_utils.reached_preemption_sync_point for details." + ), +) + class State: process_id: int = 0 num_processes: int = 1 - service: Any | None = None - client: Any | None = None + service: _jax.DistributedRuntimeService | Any | None = None + client: _jax.DistributedRuntimeClient | Any | None = None preemption_sync_manager: Any | None = None coordinator_address: str | None = None + partition_index: int | None = None def initialize(self, coordinator_address: str | None = None, @@ -50,10 +70,9 @@ def initialize(self, cluster_detection_method: str | None = None, initialization_timeout: int = 300, coordinator_bind_address: str | None = None, - service_heartbeat_interval_seconds: int = 10, - service_max_missing_heartbeats: int = 10, - client_heartbeat_interval_seconds: int = 10, - client_max_missing_heartbeats: int = 10): + heartbeat_timeout_seconds: int = 100, + shutdown_timeout_seconds: int = 300, + partition_index: int | None = None): coordinator_address = (coordinator_address or os.environ.get('JAX_COORDINATOR_ADDRESS')) if isinstance(local_device_ids, int): @@ -130,41 +149,65 @@ def initialize(self, logger.info( 'Starting JAX distributed service on %s', coordinator_bind_address ) - self.service = xla_extension.get_distributed_runtime_service( + self.service = _jax.get_distributed_runtime_service( coordinator_bind_address, num_processes, - heartbeat_interval=service_heartbeat_interval_seconds, - max_missing_heartbeats=service_max_missing_heartbeats) + heartbeat_timeout=heartbeat_timeout_seconds, + shutdown_timeout=shutdown_timeout_seconds) self.num_processes = num_processes if self.client is not None: raise RuntimeError('distributed.initialize should only be called once.') - self.client = xla_extension.get_distributed_runtime_client( + self.client = _jax.get_distributed_runtime_client( coordinator_address, process_id, init_timeout=initialization_timeout, - heartbeat_interval=client_heartbeat_interval_seconds, - max_missing_heartbeats=client_max_missing_heartbeats, use_compression=True) + use_compression=True, heartbeat_timeout=heartbeat_timeout_seconds, + recoverable=_ENABLE_RECOVERABILITY.value) # type: ignore logger.info('Connecting to JAX distributed service on %s', coordinator_address) self.client.connect() self.initialize_preemption_sync_manager() + if partition_index is None: + jax_partition_index = os.environ.get('JAX_PARTITION_INDEX') + jax_slice_index = os.environ.get('JAX_SLICE_INDEX') + if jax_partition_index is not None: + partition_index = int(jax_partition_index) # type: ignore + elif jax_slice_index is not None: + # Deprecation added 2025-08-05. Should be removed after 3 months. + warnings.warn( + 'JAX_SLICE_INDEX has been deprecated. Please use' + ' JAX_PARTITION_INDEX instead.', + DeprecationWarning, + ) + partition_index = int(jax_slice_index) # type: ignore + self.partition_index = partition_index + def shutdown(self): + if self.preemption_sync_manager: + # It's important to shut down the preemption sync manager before the + # client because the preemption sync manager depends on the client. + self.preemption_sync_manager.shutdown() + self.preemption_sync_manager = None if self.client: self.client.shutdown() self.client = None if self.service: self.service.shutdown() self.service = None - if self.preemption_sync_manager: - self.preemption_sync_manager = None def initialize_preemption_sync_manager(self): + if not _ENABLE_PREEMPTION_SERVICE.value: + logger.info( + 'The JAX preemption service is disabled. You can enable it using the' + ' jax_enable_preemption_service configuration option.' + ) + return if self.preemption_sync_manager is not None: raise RuntimeError( 'Preemption sync manager should only be initialized once.') self.preemption_sync_manager = ( - xla_extension.create_preemption_sync_manager()) + _jax.create_preemption_sync_manager()) self.preemption_sync_manager.initialize(self.client) global_state = State() @@ -175,7 +218,11 @@ def initialize(coordinator_address: str | None = None, local_device_ids: int | Sequence[int] | None = None, cluster_detection_method: str | None = None, initialization_timeout: int = 300, - coordinator_bind_address: str | None = None): + heartbeat_timeout_seconds: int = 100, + shutdown_timeout_seconds: int = 300, + coordinator_bind_address: str | None = None, + slice_index: int | None = None, + partition_index: int | None = None): """Initializes the JAX distributed system. Calling :func:`~jax.distributed.initialize` prepares JAX for execution on @@ -231,11 +278,19 @@ def initialize(coordinator_address: str | None = None, initialization_timeout: Time period (in seconds) for which connection will be retried. If the initialization takes more than the timeout specified, the initialization will error. Defaults to 300 secs i.e. 5 mins. + heartbeat_timeout_seconds: The time (in seconds) after which a process is + considered dead if it hasn't successfully sent any heartbeats. Defaults + to 100 seconds. + shutdown_timeout_seconds: The time (in seconds) a terminating process will + wait for all other processes to also terminate. Defaults to 300 seconds. coordinator_bind_address: the address and port to which the coordinator service on process `0` should bind. If this is not specified, the default is to bind to all available addresses on the same port as ``coordinator_address``. On systems that have multiple network interfaces per node it may be insufficient to only have the coordinator service listen on one address/interface. + slice_index: DEPRECATED: Use ``partition_index`` instead. + partition_index: The partition index assigned to this process' local devices. If any process sets ``partition_index``, + then all processes must do so. If ``None`` the partition indices will be chosen automatically. Raises: RuntimeError: If :func:`~jax.distributed.initialize` is called more than once @@ -259,9 +314,20 @@ def initialize(coordinator_address: str | None = None, raise RuntimeError("jax.distributed.initialize() must be called before " "any JAX calls that might initialise the XLA backend. " "This includes any computation, but also calls to jax.devices, jax.device_put, and others.") + if partition_index is None: + if slice_index is not None: + # Deprecation added 2025-08-05. Should be removed after 3 months. + warnings.warn( + '`slice_index` has been deprecated. Please use `partition_index` instead.', + DeprecationWarning, + ) + partition_index = slice_index global_state.initialize(coordinator_address, num_processes, process_id, local_device_ids, cluster_detection_method, - initialization_timeout, coordinator_bind_address) + initialization_timeout, coordinator_bind_address, + heartbeat_timeout_seconds=heartbeat_timeout_seconds, + shutdown_timeout_seconds=shutdown_timeout_seconds, + partition_index=partition_index) def is_initialized() -> bool: diff --git a/jax/_src/dlpack.py b/jax/_src/dlpack.py index a0b1db608ad0..1ba5422012ef 100644 --- a/jax/_src/dlpack.py +++ b/jax/_src/dlpack.py @@ -16,16 +16,19 @@ from typing import Any -from jax import numpy as jnp from jax._src import array -from jax._src import deprecations +from jax._src import dtypes from jax._src import xla_bridge from jax._src.api import device_put from jax._src.lax.lax import _array_copy +from jax._src.lib import _jax from jax._src.lib import xla_client +from jax._src.numpy import lax_numpy as jnp +from jax._src.numpy import scalar_types as jnp_types from jax._src.sharding import Sharding -from jax._src.typing import Array -from jax._src.typing import DLDeviceType +from jax._src.typing import Array, DLDeviceType, DTypeLike + +import numpy as np DLPACK_VERSION = (0, 8) @@ -37,16 +40,28 @@ # For example, # hash(jnp.float32) != hash(jnp.dtype(jnp.float32)) # hash(jnp.float32) == hash(jnp.dtype(jnp.float32).type) -# TODO(phawkins): Migrate to using dtypes instead of the scalar type objects. -SUPPORTED_DTYPES = frozenset({ - jnp.int8, jnp.int16, jnp.int32, jnp.int64, jnp.uint8, jnp.uint16, - jnp.uint32, jnp.uint64, jnp.float16, jnp.bfloat16, jnp.float32, - jnp.float64, jnp.complex64, jnp.complex128, jnp.bool_}) + +# TODO(vanderplas): remove this set +SUPPORTED_DTYPES: frozenset[DTypeLike] = frozenset({ + jnp_types.int8, jnp_types.int16, jnp_types.int32, jnp_types.int64, + jnp_types.uint8, jnp_types.uint16, jnp_types.uint32, jnp_types.uint64, + jnp_types.float16, jnp_types.bfloat16, jnp_types.float32, jnp_types.float64, + jnp_types.complex64, jnp_types.complex128, jnp_types.bool_}) + +SUPPORTED_DTYPES_SET: frozenset[np.dtype] = frozenset({np.dtype(dt) for dt in SUPPORTED_DTYPES}) + + +def is_supported_dtype(dtype: DTypeLike) -> bool: + """Check if dtype is supported by jax.dlpack.""" + if dtype is None: + # NumPy will silently cast this to float64, which may be surprising. + raise TypeError(f"Expected a string or dtype-like object; got {dtype=}") + return np.dtype(dtype) in SUPPORTED_DTYPES_SET def _to_dlpack(x: Array, stream: int | Any | None, - src_device: xla_client.Device | None = None, - device: xla_client.Device | None = None, + src_device: _jax.Device | None = None, + device: _jax.Device | None = None, copy: bool | None = None): if src_device is None: @@ -62,7 +77,7 @@ def _to_dlpack(x: Array, stream: int | Any | None, arr = device_put(x, device) else: arr = _array_copy(x) if copy else x - return xla_client._xla.buffer_to_dlpack_managed_tensor( + return _jax.buffer_to_dlpack_managed_tensor( arr.addressable_data(0), stream=stream ) @@ -75,7 +90,7 @@ def _to_dlpack(x: Array, stream: int | Any | None, def to_dlpack(x: Array, stream: int | Any | None = None, - src_device: xla_client.Device | None = None, + src_device: _jax.Device | None = None, dl_device: tuple[DLDeviceType, int] | None = None, max_version: tuple[int, int] | None = None, copy : bool | None = None): @@ -130,7 +145,7 @@ def to_dlpack(x: Array, stream: int | Any | None = None, ) from None # As new versions are adopted over time, we can maintain some legacy paths - # for compatability mediated through the max_version parameter. + # for compatibility mediated through the max_version parameter. # TODO(micky774): Deprecate default usage of DLPackManagedTensor when XLA # supports DLManagedTensorVersioned (DLPack version 1.0) and repurpose the # current _to_dlpack as a legacy path for (0,5) <= max_version < (1,0). @@ -156,7 +171,7 @@ def to_dlpack(x: Array, stream: int | Any | None = None, f"version ({max_version}) was requested." ) -def _place_array(_arr, device, dlpack_device, copy): +def _check_device(device, dlpack_device, copy): if device and dlpack_device != device: if copy is not None and not copy: raise ValueError( @@ -164,75 +179,23 @@ def _place_array(_arr, device, dlpack_device, copy): f"is {repr(dlpack_device)}, however copy=False. Set copy=True or " "copy=None to perform the requested operation." ) - else: - return device_put(_arr, device) + +def _place_array(_arr, device, dlpack_device, copy): + if device and dlpack_device != device: + return device_put(_arr, device) if copy: return jnp.array(_arr, copy=True) return _arr -def _legacy_from_dlpack(dlpack, device: xla_client.Device | None = None, - copy: bool | None = None): - preferred_platform = getattr(device, "platform", None) - if device and preferred_platform == "gpu": - preferred_platform = "cuda" if "cuda" in device.client.platform_version else "rocm" - - cpu_backend = xla_bridge.get_backend("cpu") - gpu_backend = None - - if preferred_platform in {"cuda", "rocm"}: - try: - gpu_backend = xla_bridge.get_backend(preferred_platform) - except RuntimeError: - raise TypeError( - f"A {str.upper(preferred_platform)} device was specified, however no " - f"{str.upper(preferred_platform)} backend was found." - ) - - if preferred_platform is None: - try: - gpu_backend = xla_bridge.get_backend("cuda") - except RuntimeError: - pass - # Try ROCm if CUDA backend not found - if gpu_backend is None: - try: - gpu_backend = xla_bridge.get_backend("rocm") - except RuntimeError: - pass - - _arr = jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer( - dlpack, cpu_backend, gpu_backend)) - dlpack_device, = _arr.devices() - return _place_array(_arr, device, dlpack_device, copy) - -def _from_dlpack(external_array, device: xla_client.Device | None = None, - copy: bool | None = None): - dl_device_type, device_id = external_array.__dlpack_device__() - try: - dl_device_platform = _DL_DEVICE_TO_PLATFORM[dl_device_type] - except KeyError: - raise TypeError( - "Array passed to from_dlpack is on unsupported device type " - f"(DLDeviceType: {dl_device_type}, array: {external_array}" - ) from None - - backend = xla_bridge.get_backend(dl_device_platform) - dlpack_device = backend.device_from_local_hardware_id(device_id) - try: - stream = dlpack_device.get_stream_for_external_ready_events() - except xla_client.XlaRuntimeError as err: - if "UNIMPLEMENTED" in str(err): - stream = None - else: - raise - dlpack = external_array.__dlpack__(stream=stream) - - _arr = jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer( - dlpack, dlpack_device, stream)) - return _place_array(_arr, device, dlpack_device, copy) +def _is_tensorflow_tensor(external_array): + t = type(external_array) + return ( + t.__qualname__ == "EagerTensor" + and t.__module__.endswith("tensorflow.python.framework.ops") + ) def from_dlpack(external_array, - device: xla_client.Device | Sharding | None = None, + device: _jax.Device | Sharding | None = None, copy: bool | None = None): """Returns a :class:`~jax.Array` representation of a DLPack tensor. @@ -240,7 +203,7 @@ def from_dlpack(external_array, device transfer or copy was requested. Args: - external_array: An array object that has ``__dlpack__` and + external_array: An array object that has ``__dlpack__`` and ``__dlpack_device__`` methods. device: The (optional) :py:class:`Device`, representing the device on which the returned array should be placed. If given, then the result is @@ -272,18 +235,54 @@ def from_dlpack(external_array, f"a Sharding with {len(device_set)} devices was provided." ) device, = device_set - if hasattr(external_array, "__dlpack__"): - return _from_dlpack(external_array, device, copy) - - # Deprecated legacy path. - # TODO(slebedev): Remove on or after December 3rd 2023. - deprecations.warn( - "jax-dlpack-import-legacy", - ( - "Calling from_dlpack with a DLPack tensor is deprecated. The argument" - " to from_dlpack should be an array from another framework that" - " implements the __dlpack__ protocol." - ), - stacklevel=2, - ) - return _legacy_from_dlpack(external_array, device, copy) + if not hasattr(external_array, "__dlpack__") or not hasattr(external_array, "__dlpack_device__"): + raise TypeError( + "The array passed to from_dlpack must have __dlpack__ and __dlpack_device__ methods." + ) + + dl_device_type, device_id = external_array.__dlpack_device__() + try: + dl_device_platform = _DL_DEVICE_TO_PLATFORM[dl_device_type] + except KeyError: + raise TypeError( + "Array passed to from_dlpack is on unsupported device type " + f"(DLDeviceType: {dl_device_type}, array: {external_array}" + ) from None + + backend = xla_bridge.get_backend(dl_device_platform) + dlpack_device = backend.device_from_local_hardware_id(device_id) + _check_device(device, dlpack_device, copy) + if _is_tensorflow_tensor(external_array): + # TensorFlow does not support stream=. + stream = None + else: + try: + stream = dlpack_device.get_stream_for_external_ready_events() + except _jax.JaxRuntimeError as err: + if "UNIMPLEMENTED" in str(err): + stream = None + else: + raise + dlpack = external_array.__dlpack__(stream=stream) + + try: + arr = _jax.dlpack_managed_tensor_to_buffer( + dlpack, dlpack_device, stream, copy) + except xla_client.XlaRuntimeError as e: + se = str(e) + if "is not aligned to" in se: + i = se.index("is not aligned to") + raise ValueError( + "Specified input which requires a copy since the source data " + f"buffer {se[i:]} However copy=False. Set copy=True or " + "copy=None to perform the requested operation." + ) + else: + raise + # TODO(phawkins): when we are ready to support x64 arrays in + # non-x64 mode, change the semantics to not canonicalize here. + arr = jnp.asarray(arr, dtype=dtypes.canonicalize_dtype(arr.dtype)) + if copy: + # copy was already handled by dlpack_managed_tensor_to_buffer. + copy = None + return _place_array(arr, device, dlpack_device, copy) diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index 01500c008405..96c821ab2953 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -22,17 +22,17 @@ from __future__ import annotations import abc -import builtins import dataclasses import functools import types -from typing import cast, overload, Any, Literal, Union +from typing import cast, overload, Any, Callable, Literal, Union import warnings import ml_dtypes import numpy as np from jax._src import config +from jax._src import literals from jax._src.typing import Array, DType, DTypeLike from jax._src.util import set_module, StrictABC @@ -44,8 +44,8 @@ except: pass else: - if _ml_dtypes_version < (0, 2, 0): - raise ValueError("JAX requires ml_dtypes version 0.2.0 or newer; " + if _ml_dtypes_version < (0, 5): + raise ValueError("JAX requires ml_dtypes version 0.5 or newer; " f"installed version is {ml_dtypes.__version__}.") export = set_module('jax.dtypes') @@ -90,19 +90,18 @@ def type(self) -> type: ... # fp8 support -# TODO: remove Optional when minimum ml_dtypes version >= 0.5.0 -float8_e3m4: type[np.generic] | None = None -float8_e4m3: type[np.generic] | None = None -float8_e8m0fnu: type[np.generic] | None = None +float8_e3m4: type[np.generic] = ml_dtypes.float8_e3m4 +float8_e4m3: type[np.generic] = ml_dtypes.float8_e4m3 +float8_e8m0fnu: type[np.generic] = ml_dtypes.float8_e8m0fnu float8_e4m3b11fnuz: type[np.generic] = ml_dtypes.float8_e4m3b11fnuz float8_e4m3fn: type[np.generic] = ml_dtypes.float8_e4m3fn float8_e4m3fnuz: type[np.generic] = ml_dtypes.float8_e4m3fnuz float8_e5m2: type[np.generic] = ml_dtypes.float8_e5m2 float8_e5m2fnuz: type[np.generic] = ml_dtypes.float8_e5m2fnuz -_float8_e3m4_dtype: np.dtype | None = None -_float8_e4m3_dtype: np.dtype | None = None -_float8_e8m0fnu_dtype: np.dtype | None = None +_float8_e3m4_dtype: np.dtype = np.dtype(float8_e3m4) +_float8_e4m3_dtype: np.dtype = np.dtype(float8_e4m3) +_float8_e8m0fnu_dtype: np.dtype = np.dtype(float8_e8m0fnu) _float8_e4m3b11fnuz_dtype: np.dtype = np.dtype(float8_e4m3b11fnuz) _float8_e4m3fn_dtype: np.dtype = np.dtype(float8_e4m3fn) _float8_e4m3fnuz_dtype: np.dtype = np.dtype(float8_e4m3fnuz) @@ -110,10 +109,9 @@ def type(self) -> type: ... _float8_e5m2fnuz_dtype: np.dtype = np.dtype(float8_e5m2fnuz) # fp4 support -# TODO: remove Optional when minimum ml_dtypes version >= 0.5.0 -float4_e2m1fn: type[np.generic] | None = None +float4_e2m1fn: type[np.generic] = ml_dtypes.float4_e2m1fn -_float4_e2m1fn_dtype: np.dtype | None = None +_float4_e2m1fn_dtype: np.dtype = np.dtype(float4_e2m1fn) def supports_inf(dtype: DTypeLike) -> bool: """Return true if the dtype supports infinity, else return False.""" @@ -127,6 +125,10 @@ def supports_inf(dtype: DTypeLike) -> bool: _bfloat16_dtype: np.dtype = np.dtype(bfloat16) _custom_float_scalar_types = [ + float4_e2m1fn, + float8_e3m4, + float8_e4m3, + float8_e8m0fnu, float8_e4m3b11fnuz, float8_e4m3fn, float8_e4m3fnuz, @@ -135,6 +137,10 @@ def supports_inf(dtype: DTypeLike) -> bool: bfloat16, ] _custom_float_dtypes = [ + _float4_e2m1fn_dtype, + _float8_e3m4_dtype, + _float8_e4m3_dtype, + _float8_e8m0fnu_dtype, _float8_e4m3b11fnuz_dtype, _float8_e4m3fn_dtype, _float8_e4m3fnuz_dtype, @@ -143,6 +149,9 @@ def supports_inf(dtype: DTypeLike) -> bool: _bfloat16_dtype, ] _float8_dtypes = [ + _float8_e3m4_dtype, + _float8_e4m3_dtype, + _float8_e8m0fnu_dtype, _float8_e4m3b11fnuz_dtype, _float8_e4m3fn_dtype, _float8_e4m3fnuz_dtype, @@ -150,83 +159,71 @@ def supports_inf(dtype: DTypeLike) -> bool: _float8_e5m2fnuz_dtype, ] -_float4_dtypes: list[np.dtype] = [] - -# TODO: remove the if statements below when minimum ml_dtypes version >= 0.5.0 -if hasattr(ml_dtypes, "float8_e4m3"): - float8_e4m3 = ml_dtypes.float8_e4m3 - _float8_e4m3_dtype = np.dtype(float8_e4m3) - _custom_float_scalar_types.insert(0, float8_e4m3) # type: ignore[arg-type] - _custom_float_dtypes.insert(0, _float8_e4m3_dtype) - _float8_dtypes.insert(0, _float8_e4m3_dtype) -if hasattr(ml_dtypes, "float8_e3m4"): - float8_e3m4 = ml_dtypes.float8_e3m4 - _float8_e3m4_dtype = np.dtype(float8_e3m4) - _custom_float_scalar_types.insert(0, float8_e3m4) # type: ignore[arg-type] - _custom_float_dtypes.insert(0, _float8_e3m4_dtype) - _float8_dtypes.insert(0, _float8_e3m4_dtype) -if hasattr(ml_dtypes, "float8_e8m0fnu"): - float8_e8m0fnu = ml_dtypes.float8_e8m0fnu - _float8_e8m0fnu_dtype = np.dtype(float8_e8m0fnu) - _custom_float_scalar_types.insert(0, float8_e8m0fnu) # type: ignore[arg-type] - _custom_float_dtypes.insert(0, _float8_e8m0fnu_dtype) - _float8_dtypes.insert(0, _float8_e8m0fnu_dtype) -if hasattr(ml_dtypes, "float4_e2m1fn"): - float4_e2m1fn = ml_dtypes.float4_e2m1fn - _float4_e2m1fn_dtype = np.dtype(float4_e2m1fn) - _custom_float_scalar_types.insert(0, float4_e2m1fn) # type: ignore[arg-type] - _custom_float_dtypes.insert(0, _float4_e2m1fn_dtype) - _float4_dtypes.insert(0, _float4_e2m1fn_dtype) - -# 2-bit integer support -int2: type[np.generic] | None = None -uint2: type[np.generic] | None = None - -_int2_dtype: np.dtype | None = None -_uint2_dtype: np.dtype | None = None - -_intn_dtypes = [] - -# Remove the condition once the minimum ml_dtypes version required by JAX -# contains https://github.com/jax-ml/ml_dtypes/pull/154. -if hasattr(ml_dtypes, 'int2'): - int2 = ml_dtypes.int2 - uint2 = ml_dtypes.uint2 - _int2_dtype = np.dtype(int2) - _uint2_dtype = np.dtype(uint2) - _intn_dtypes.extend([_int2_dtype, _uint2_dtype]) +_float4_dtypes: list[np.dtype] = [ + _float4_e2m1fn_dtype, +] + +int2: type[np.generic] = ml_dtypes.int2 +uint2: type[np.generic] = ml_dtypes.uint2 + +_int2_dtype: np.dtype = np.dtype(int2) +_uint2_dtype: np.dtype = np.dtype(uint2) # 4-bit integer support int4: type[np.generic] = ml_dtypes.int4 uint4: type[np.generic] = ml_dtypes.uint4 _int4_dtype = np.dtype(int4) _uint4_dtype = np.dtype(uint4) -_intn_dtypes.extend([_int4_dtype, _uint4_dtype]) + +_intn_dtypes = [ + _int2_dtype, + _uint2_dtype, + _int4_dtype, + _uint4_dtype, +] # Default types. bool_ = np.bool_ -int_: type[Any] -uint: type[Any] -float_: type[Any] -complex_: type[Any] -if config.default_dtype_bits.value == '32': - int_ = np.int32 - uint = np.uint32 - float_ = np.float32 - complex_ = np.complex64 -else: - int_ = np.int64 - uint = np.uint64 - float_ = np.float64 - complex_ = np.complex128 -_default_types: dict[str, type[Any]] = { - 'b': bool_, - 'i': int_, - 'u': uint, - 'f': float_, - 'c': complex_, -} +int_: type[Any] = np.int64 +uint: type[Any] = np.uint64 +float_: type[Any] = np.float64 +complex_: type[Any] = np.complex128 + +# Default dtypes. These are intended to have the same semantics as, say, +# canonicalize_dtype(np.float64), but are preparing for the reduction in the +# number of places we perform dtype canonicalization. + + +def default_int_dtype() -> DType: + return np.dtype(np.int64) if config.enable_x64.value else np.dtype(np.int32) + + +def default_uint_dtype() -> DType: + return np.dtype(np.uint64) if config.enable_x64.value else np.dtype(np.uint32) + + +def default_float_dtype() -> DType: + return ( + np.dtype(np.float64) if config.enable_x64.value else np.dtype(np.float32) + ) + + +def default_complex_dtype() -> DType: + return ( + np.dtype(np.complex128) + if config.enable_x64.value + else np.dtype(np.complex64) + ) + + +default_types: dict[str, Callable[[], DType]] = { + 'b': lambda: np.dtype(bool), + 'i': default_int_dtype, + 'u': default_uint_dtype, + 'f': default_float_dtype, + 'c': default_complex_dtype, +} def jax_dtype(obj: DTypeLike | None, *, align: bool = False, copy: bool = False) -> DType: @@ -235,24 +232,26 @@ def jax_dtype(obj: DTypeLike | None, *, align: bool = False, Arguments mirror those of :func:`numpy.dtype`. """ if obj is None: - obj = float_ + obj = default_float_dtype() elif issubdtype(obj, extended): return obj # type: ignore[return-value] - elif isinstance(obj, type): - obj = _DEFAULT_TYPEMAP.get(obj, obj) + elif isinstance(obj, type) and (f := _DEFAULT_TYPEMAP.get(obj)) is not None: + obj = f() return np.dtype(obj, align=align, copy=copy) -_DEFAULT_TYPEMAP: dict[type, DTypeLike] = { - bool: bool, - int: int_, - float: float_, - complex: complex_, +_DEFAULT_TYPEMAP: dict[type, Callable[[], np.dtype]] = { + bool: lambda: np.dtype(bool), + int: default_int_dtype, + float: default_float_dtype, + complex: default_complex_dtype, } -def bit_width(dtype: DTypeLike) -> int: +def itemsize_bits(dtype: DTypeLike) -> int: """Number of bits per element for the dtype.""" # Note: we cannot use dtype.itemsize here because this is # incorrect for sub-byte integer types. + if dtype is None: + raise ValueError("dtype cannot be None.") if dtype == np.dtype(bool): return 8 # physical bit layout for boolean dtype elif issubdtype(dtype, np.integer): @@ -280,6 +279,7 @@ def bit_width(dtype: DTypeLike) -> int: _dtype_to_inexact: dict[DType, DType] = { np.dtype(k): np.dtype(v) for k, v in [ ('bool', 'float32'), + ('uint4', 'float32'), ('int4', 'float32'), ('uint8', 'float32'), ('int8', 'float32'), ('uint16', 'float32'), ('int16', 'float32'), ('uint32', 'float32'), ('int32', 'float32'), @@ -299,6 +299,12 @@ def to_inexact_dtype(dtype: DTypeLike) -> DType: return _dtype_to_inexact.get(dtype_, dtype_) +def to_floating_dtype(dtype: DTypeLike) -> DType: + """Promotes a dtype to a non-complex floating dtype.""" + dtype_ = np.dtype(dtype) + return finfo(_dtype_to_inexact.get(dtype_, dtype_)).dtype + + def to_complex_dtype(dtype: DTypeLike) -> DType: ftype = to_inexact_dtype(dtype) if ftype in [np.dtype('float64'), np.dtype('complex128')]: @@ -334,8 +340,38 @@ def canonicalize_dtype(dtype: Any, allow_extended_dtype: bool = False) -> DType """Convert from a dtype to a canonical dtype based on config.x64_enabled.""" return _canonicalize_dtype(config.enable_x64.value, allow_extended_dtype, dtype) # pytype: disable=bad-return-type +class InvalidInputException(Exception): + pass + +canonicalize_value_handlers: dict[Any, Callable] = {} + + +# TODO(mattjj): try to remove this canonicalize_dtype stuff +def canonicalize_value(x): + typ = type(x) + handler = canonicalize_value_handlers.get(typ) + if handler: + return handler(x) + for typ in typ.__mro__: + handler = canonicalize_value_handlers.get(typ) + if handler: + return handler(x) + if hasattr(x, '__jax_array__'): + raise ValueError( + 'Triggering __jax_array__() during abstractification is no longer' + ' supported. To avoid this error, either explicitly convert your object' + ' using jax.numpy.array(), or register your object as a pytree.' + ) + raise InvalidInputException( + f"Argument '{x}' of type {type(x)} is not a valid JAX type." + ) + + +# The list of all known Python scalar types. +python_scalar_types: set[type] = {bool, int, float, complex} + # Default dtypes corresponding to Python scalars. -python_scalar_dtypes : dict[type, DType] = { +python_scalar_types_to_dtypes: dict[type, DType] = { bool: np.dtype('bool'), int: np.dtype('int64'), float: np.dtype('float64'), @@ -362,7 +398,7 @@ def scalar_type_of(x: Any) -> type: raise TypeError(f"Invalid scalar value {x}") -def _scalar_type_to_dtype(typ: type, value: Any = None) -> DType: +def scalar_type_to_dtype(typ: type, value: Any = None) -> DType: """Return the numpy dtype for the given scalar type. Raises @@ -371,23 +407,24 @@ def _scalar_type_to_dtype(typ: type, value: Any = None) -> DType: Examples -------- - >>> _scalar_type_to_dtype(int) + >>> scalar_type_to_dtype(int) dtype('int32') - >>> _scalar_type_to_dtype(float) + >>> scalar_type_to_dtype(float) dtype('float32') - >>> _scalar_type_to_dtype(complex) + >>> scalar_type_to_dtype(complex) dtype('complex64') - >>> _scalar_type_to_dtype(int) + >>> scalar_type_to_dtype(int) dtype('int32') - >>> _scalar_type_to_dtype(int, 0) + >>> scalar_type_to_dtype(int, 0) dtype('int32') - >>> _scalar_type_to_dtype(int, 1 << 63) # doctest: +IGNORE_EXCEPTION_DETAIL + >>> scalar_type_to_dtype(int, 1 << 63) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): OverflowError: Python int 9223372036854775808 too large to convert to int32 """ - dtype = canonicalize_dtype(python_scalar_dtypes[typ]) + dtype = canonicalize_dtype(python_scalar_types_to_dtypes[typ]) if typ is int and value is not None: - if value < np.iinfo(dtype).min or value > np.iinfo(dtype).max: + iinfo = np.iinfo(dtype) + if value < iinfo.min or value > iinfo.max: raise OverflowError(f"Python int {value} too large to convert to {dtype}") return dtype @@ -398,8 +435,8 @@ def coerce_to_array(x: Any, dtype: DTypeLike | None = None) -> np.ndarray: Handles Python scalar type promotion according to JAX's rules, not NumPy's rules. """ - if dtype is None and type(x) in python_scalar_dtypes: - dtype = _scalar_type_to_dtype(type(x), x) + if dtype is None and type(x) in python_scalar_types: + dtype = scalar_type_to_dtype(type(x), x) return np.asarray(x, dtype) iinfo = ml_dtypes.iinfo @@ -472,9 +509,9 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType, # to the normal scalar type hierarchy. if a_sctype in _custom_float_scalar_types: return b_sctype in {a_sctype, np.floating, np.inexact, np.number, np.generic} - if (int2 is not None and a_sctype == int2) or a_sctype == int4: + if a_sctype in [int2, int4]: return b_sctype in {a_sctype, np.signedinteger, np.integer, np.number, np.generic} - if (uint2 is not None and a_sctype == uint2) or a_sctype == uint4: + if a_sctype in [uint2, uint4]: return b_sctype in {a_sctype, np.unsignedinteger, np.integer, np.number, np.generic} # Otherwise, fall back to numpy.issubdtype @@ -491,6 +528,7 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType, _unsigned_types: list[JAXType] _int_types: list[JAXType] _unsigned_types = [ + np.dtype(uint2), np.dtype(uint4), np.dtype('uint8'), np.dtype('uint16'), @@ -498,6 +536,7 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType, np.dtype('uint64'), ] _signed_types = [ + np.dtype(int2), np.dtype(int4), np.dtype('int8'), np.dtype('int16'), @@ -505,11 +544,6 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType, np.dtype('int64'), ] -if _int2_dtype is not None: - _signed_types.insert(0, _int2_dtype) -if _uint2_dtype is not None: - _unsigned_types.insert(0, _uint2_dtype) - _int_types = _unsigned_types + _signed_types _float_types: list[JAXType] = [ @@ -616,51 +650,63 @@ def _dtype_and_weaktype(value: Any) -> tuple[DType, bool]: """Return a (dtype, weak_type) tuple for the given input.""" return dtype(value), any(value is typ for typ in _weak_types) or is_weakly_typed(value) -def _type_promotion_lattice(jax_numpy_dtype_promotion: str) -> dict[JAXType, list[JAXType]]: +def _type_promotion_lattice(strict: bool, x64: bool) -> dict[JAXType, list[JAXType]]: """ Return the type promotion lattice in the form of a DAG. - This DAG maps each type to its immediately higher type on the lattice. + This DAG maps each type to its immediately higher types on the lattice. + + Args: + strict: use strict promotion lattice? + x64: allow promotions that form x64 types from non-x64 inputs? """ b1, = _bool_types - if _int2_dtype is not None: - assert _uint2_dtype is not None - _uint2, uint4, u1, u2, u4, u8, _int2, int4, i1, i2, i4, i8 = _int_types - else: - uint4, u1, u2, u4, u8, int4, i1, i2, i4, i8 = _int_types - *f1_types, bf, f2, f4, f8 = _float_types - c4, c8 = _complex_types + u2, u4, u8, u16, u32, u64, i2, i4, i8, i16, i32, i64 = _int_types + *small_float_types, bf16, f16, f32, f64 = _float_types + c64, c128 = _complex_types i_, f_, c_ = _weak_types - if jax_numpy_dtype_promotion == 'standard': - out: dict[JAXType, list[JAXType]] - out = { - b1: [i_], - i_: [u1, uint4, i1, int4], - uint4: [], u1: [i2, u2], u2: [i4, u4], u4: [i8, u8], u8: [f_], - int4: [], i1: [i2], i2: [i4], i4: [i8], i8: [f_], - f_: [*f1_types, bf, f2, c_], - **{t: [] for t in f1_types}, bf: [f4], f2: [f4], f4: [f8, c4], f8: [c8], - c_: [c4], c4: [c8], c8: [], + if not strict: + out: dict[JAXType, list[JAXType]] = { + b1: [i_], + i_: [u8, u2, u4, i8, i2, i4], + u2: [], + u4: [], + u8: [i16, u16], + u16: [i32, u32], + u32: [i64, u64], + u64: [f_], + i2: [], + i4: [], + i8: [i16], + i16: [i32], + i32: [i64], + i64: [f_], + f_: [*small_float_types, bf16, f16, c_], + **{t: [] for t in small_float_types}, + bf16: [f32], + f16: [f32], + f32: [f64, c64], + f64: [c128], + c_: [c64], + c64: [c128], + c128: [], } - if _int2_dtype is not None: - out[i_].append(_int2_dtype) - out[_int2_dtype] = [] - if _uint2_dtype is not None: - out[i_].append(_uint2_dtype) - out[_uint2_dtype] = [] + # If x64 mode is not enabled, then we want to avoid any promotions that form + # 64-bit types from non-64-bit inputs. There's only one of these in the + # entire promotion lattice, namely u4xi4->i8, which we can avoid by + # replacing it with u4xi4->i4. + if not x64: + out[u32] = [i32, u64] return out - elif jax_numpy_dtype_promotion == 'strict': + else: return { i_: [f_] + _int_types, f_: [c_] + _float_types, c_: _complex_types, **{t: [] for t in _jax_types} } - else: - raise ValueError( - f"Unexpected value of jax_numpy_dtype_promotion={jax_numpy_dtype_promotion!r}") -def _make_lattice_upper_bounds(jax_numpy_dtype_promotion: str) -> dict[JAXType, set[JAXType]]: - lattice = _type_promotion_lattice(jax_numpy_dtype_promotion) +def _make_lattice_upper_bounds(strict: bool, x64: bool) -> dict[JAXType, set[JAXType]]: + lattice = _type_promotion_lattice(strict, x64) upper_bounds = {node: {node} for node in lattice} for n in lattice: while True: @@ -672,16 +718,17 @@ def _make_lattice_upper_bounds(jax_numpy_dtype_promotion: str) -> dict[JAXType, upper_bounds[n] |= new_upper_bounds return upper_bounds -_lattice_upper_bounds: dict[str, dict[JAXType, set[JAXType]]] = { - 'standard': _make_lattice_upper_bounds('standard'), - 'strict': _make_lattice_upper_bounds('strict'), -} +_standard_x64_lattice_ubs = _make_lattice_upper_bounds(strict=False, x64=True) +_standard_x32_lattice_ubs = _make_lattice_upper_bounds(strict=False, x64=False) +_strict_lattice_ubs = _make_lattice_upper_bounds(strict=True, x64=True) class TypePromotionError(ValueError): pass -@functools.lru_cache(512) # don't use util.memoize because there is no X64 dependence. -def _least_upper_bound(jax_numpy_dtype_promotion: str, *nodes: JAXType) -> JAXType: +# We don't use util.memoize because there is no implicit X64 dependence. +@functools.lru_cache(512) +def _least_upper_bound(jax_numpy_dtype_promotion: config.NumpyDtypePromotion, + x64: bool, *nodes: JAXType) -> JAXType: """Compute the least upper bound of a set of nodes. Args: @@ -708,7 +755,16 @@ def _least_upper_bound(jax_numpy_dtype_promotion: str, *nodes: JAXType) -> JAXTy # ∀ c ∈ N: CUB(N) ⊆ UB(c) # So if N ∩ CUB(N) is nonempty, if follows that LUB(N) = N ∩ CUB(N). N = set(nodes) - UB = _lattice_upper_bounds[jax_numpy_dtype_promotion] + if jax_numpy_dtype_promotion == config.NumpyDtypePromotion.STRICT: + UB = _strict_lattice_ubs + elif jax_numpy_dtype_promotion == config.NumpyDtypePromotion.STANDARD: + if x64: + UB = _standard_x64_lattice_ubs + else: + UB = _standard_x32_lattice_ubs + else: + raise ValueError( + f"Unexpected value of jax_numpy_dtype_promotion={jax_numpy_dtype_promotion!r}") try: bounds = [UB[n] for n in N] except KeyError: @@ -719,7 +775,7 @@ def _least_upper_bound(jax_numpy_dtype_promotion: str, *nodes: JAXType) -> JAXTy if len(LUB) == 1: return LUB.pop() elif len(LUB) == 0: - if config.numpy_dtype_promotion.value == 'strict': + if config.numpy_dtype_promotion.value == config.NumpyDtypePromotion.STRICT: msg = ( f"Input dtypes {tuple(str(n) for n in nodes)} have no available implicit dtype " "promotion path when jax_numpy_dtype_promotion=strict. Try explicitly casting " @@ -802,18 +858,26 @@ def promote_types(a: DTypeLike, b: DTypeLike) -> DType: # object identity, not object equality, due to the behavior of np.dtype.__eq__ a_tp = cast(JAXType, a if any(a is t for t in _weak_types) else np.dtype(a)) b_tp = cast(JAXType, b if any(b is t for t in _weak_types) else np.dtype(b)) - return np.dtype(_least_upper_bound(config.numpy_dtype_promotion.value, a_tp, b_tp)) + return np.dtype(_least_upper_bound( + config.numpy_dtype_promotion.value, config.enable_x64.value, a_tp, b_tp)) def register_weak_scalar_type(typ: type): """Register a scalar type as a weak type.""" - _registered_weak_types.append(typ) -_registered_weak_types: list[JAXType] = [] + _registered_weak_types.add(typ) + +_registered_weak_types: set[JAXType] = { + literals.TypedInt, + literals.TypedFloat, + literals.TypedComplex, +} def is_weakly_typed(x: Any) -> bool: if type(x) in _weak_types or type(x) in _registered_weak_types: return True + if isinstance(x, literals.TypedNdArray): + return x.weak_type try: return x.aval.weak_type except AttributeError: @@ -823,25 +887,91 @@ def is_python_scalar(x: Any) -> bool: try: return x.aval.weak_type and np.ndim(x) == 0 except AttributeError: - return type(x) in python_scalar_dtypes + return type(x) in python_scalar_types def check_valid_dtype(dtype: DType) -> None: if dtype not in _jax_dtype_set: raise TypeError(f"Dtype {dtype} is not a valid JAX array " "type. Only arrays of numeric types are supported by JAX.") -def dtype(x: Any, *, canonicalize: bool = False) -> DType: - """Return the dtype object for a value or type, optionally canonicalized based on X64 mode.""" +def _maybe_canonicalize_explicit_dtype(dtype: DType, fun_name: str) -> DType: + "Canonicalizes explicitly requested dtypes, per explicit_x64_dtypes." + allow = config.explicit_x64_dtypes.value + if allow == config.ExplicitX64Mode.ALLOW or config.enable_x64.value: + return dtype + canonical_dtype = canonicalize_dtype(dtype) + if canonical_dtype == dtype: + return dtype + fun_name = f" requested in {fun_name}" if fun_name else "" + if allow == config.ExplicitX64Mode.ERROR: + msg = ("Explicitly requested dtype {}{} is not available. To enable more " + "dtypes, set the jax_enable_x64 or allow_explicit_x64_dtypes " + "configuration options." + "See https://github.com/jax-ml/jax#current-gotchas for more.") + msg = msg.format(dtype, fun_name, canonical_dtype.name) + raise ValueError(msg) + else: # WARN + msg = ("Explicitly requested dtype {}{} is not available, " + "and will be truncated to dtype {}. To enable more dtypes, set the " + "jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell " + "environment variable. " + "See https://github.com/jax-ml/jax#current-gotchas for more.") + msg = msg.format(dtype, fun_name, canonical_dtype.name) + warnings.warn(msg, stacklevel=4) + return canonical_dtype + + +_types_whose_dtype_should_not_be_canonicalized = ( + Array, + literals.TypedNdArray, + literals.TypedInt, + literals.TypedFloat, + literals.TypedComplex, +) + +def dtype(x: Any) -> DType: + """Return the dtype object for a value or type. + + Python scalars, Python scalar types, NumPy scalar type, NumPy dtypes, and + non-JAX arrays will have their dtypes canonicalized. + + Note: this is not the same function as jax.numpy.dtype, which simply aliases + numpy.dtype.""" + # TODO(phawkins): in the future, we would like to: + # - return the default dtype for Python scalar types and values + # - canonicalize NumPy array and scalar types + # - return NumPy dtypes as-is, uncanonicalized. if x is None: raise ValueError(f"Invalid argument to dtype: {x}.") - is_type = isinstance(x, type) - if is_type and x in python_scalar_dtypes: - dt = python_scalar_dtypes[x] - elif type(x) in python_scalar_dtypes: - dt = python_scalar_dtypes[type(x)] - elif is_type and _issubclass(x, np.generic): - return np.dtype(x) - elif issubdtype(getattr(x, 'dtype', None), extended): + if isinstance(x, type): + # Python scalar types, e.g., int, float + if (dt := python_scalar_types_to_dtypes.get(x)) is not None: + return canonicalize_dtype(dt) + + # Numpy scalar types, e.g., np.int32, np.float32 + if _issubclass(x, np.generic): + dt = np.dtype(x) + return _maybe_canonicalize_explicit_dtype(dt, "dtype") + + # Python scalar values, e.g., int(3), float(3.14) + elif (dt := python_scalar_types_to_dtypes.get(type(x))) is not None: + return canonicalize_dtype(dt) + # Jax Arrays, literal arrays, and scalars. + # We intentionally do not canonicalize these types: once we've formed an x64 + # value, that is something we respect irrespective of the x64 mode. + elif isinstance(x, _types_whose_dtype_should_not_be_canonicalized): + return x.dtype + + if isinstance(x, str): + x = np.dtype(x) + + if isinstance(x, np.dtype): + if x not in _jax_dtype_set and not issubdtype(x, extended): + raise TypeError(f"Value '{x}' with dtype {dt} is not a valid JAX array " + "type. Only arrays of numeric types are supported by JAX.") + return _maybe_canonicalize_explicit_dtype(x, "dtype") + + if issubdtype(getattr(x, 'dtype', None), extended): dt = x.dtype else: try: @@ -852,9 +982,9 @@ def dtype(x: Any, *, canonicalize: bool = False) -> DType: raise TypeError(f"Value '{x}' with dtype {dt} is not a valid JAX array " "type. Only arrays of numeric types are supported by JAX.") # TODO(jakevdp): fix return type annotation and remove this ignore. - return canonicalize_dtype(dt, allow_extended_dtype=True) if canonicalize else dt # type: ignore[return-value] + return canonicalize_dtype(dt, allow_extended_dtype=True) # type: ignore[return-value] -def _lattice_result_type(*args: Any) -> tuple[DType, bool]: +def lattice_result_type(*args: Any) -> tuple[DType, bool]: dtypes, weak_types = zip(*(_dtype_and_weaktype(arg) for arg in args)) if len(dtypes) == 1: out_dtype = dtypes[0] @@ -863,18 +993,20 @@ def _lattice_result_type(*args: Any) -> tuple[DType, bool]: # Trivial promotion case. This allows extended dtypes through. out_dtype = dtypes[0] out_weak_type = False - elif all(weak_types) and config.numpy_dtype_promotion.value != 'strict': + elif all(weak_types) and config.numpy_dtype_promotion.value != config.NumpyDtypePromotion.STRICT: # If all inputs are weakly typed, we compute the bound of the strongly-typed # counterparts and apply the weak type at the end. This avoids returning the # incorrect result with non-canonical weak types (e.g. weak int16). # TODO(jakevdp): explore removing this special case. - result_type = _least_upper_bound(config.numpy_dtype_promotion.value, - *{_jax_type(dtype, False) for dtype in dtypes}) + result_type = _least_upper_bound( + config.numpy_dtype_promotion.value, config.enable_x64.value, + *{_jax_type(dtype, False) for dtype in dtypes}) out_dtype = dtype(result_type) out_weak_type = True else: - result_type = _least_upper_bound(config.numpy_dtype_promotion.value, - *{_jax_type(d, w) for d, w in zip(dtypes, weak_types)}) + result_type = _least_upper_bound( + config.numpy_dtype_promotion.value, config.enable_x64.value, + *{_jax_type(d, w) for d, w in zip(dtypes, weak_types)}) out_dtype = dtype(result_type) out_weak_type = any(result_type is t for t in _weak_types) return out_dtype, (out_dtype != bool_) and out_weak_type @@ -902,54 +1034,35 @@ def result_type(*args: Any, return_weak_type_flag: bool = False) -> DType | tupl if len(args) == 0: raise ValueError("at least one array or dtype is required") dtype: DType | ExtendedDType - dtype, weak_type = _lattice_result_type(*(float_ if arg is None else arg for arg in args)) + dtype, weak_type = lattice_result_type(*(default_float_dtype() if arg is None else arg for arg in args)) if weak_type: - dtype = canonicalize_dtype( - _default_types['f' if dtype in _custom_float_dtypes else dtype.kind]) - else: - dtype = canonicalize_dtype(dtype, allow_extended_dtype=True) + dtype = default_types['f' if dtype in _custom_float_dtypes else dtype.kind]() # TODO(jakevdp): fix return type annotation and remove this ignore. return (dtype, weak_type) if return_weak_type_flag else dtype # type: ignore[return-value] -def check_user_dtype_supported(dtype, fun_name=None): +def check_and_canonicalize_user_dtype(dtype, fun_name=None) -> DType: + """Checks validity of a user-provided dtype, and returns its canonical form. + + For Python scalar types this function returns the corresponding default dtype. + """ + if dtype is None: + raise ValueError("dtype must be specified.") if isinstance(dtype, Array): - # Deprecation warning added 2024 June 13. - warnings.warn("Passing an array as a dtype argument is deprecated; " - "instead of dtype=arr use dtype=arr.dtype.", - category=DeprecationWarning, stacklevel=3) - return # no further check needed, as array dtypes have already been validated. + raise ValueError("Passing an array as a dtype argument is no longer " + "supported; instead of dtype=arr use dtype=arr.dtype.") if issubdtype(dtype, extended): - return + return dtype # Avoid using `dtype in [...]` because of numpy dtype equality overloading. - if isinstance(dtype, type) and dtype in {bool, int, float, builtins.complex}: - return + if isinstance(dtype, type) and (f := _DEFAULT_TYPEMAP.get(dtype)) is not None: + return f() np_dtype = np.dtype(dtype) - is_custom_dtype = np_dtype.type in [ - *_custom_float_scalar_types, - int2, - int4, - uint2, - uint4 - ] - if ( - np_dtype.kind not in 'biufcT' - and not is_custom_dtype - and not dtype == float0 - ): + if np_dtype not in _jax_dtype_set: msg = ( f'JAX only supports number, bool, and string dtypes, got dtype {dtype}' ) msg += f" in {fun_name}" if fun_name else "" raise TypeError(msg) - if dtype is not None and np_dtype != canonicalize_dtype(np_dtype): - msg = ("Explicitly requested dtype {} {} is not available, " - "and will be truncated to dtype {}. To enable more dtypes, set the " - "jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell " - "environment variable. " - "See https://github.com/jax-ml/jax#current-gotchas for more.") - fun_name = f"requested in {fun_name}" if fun_name else "" - truncated_dtype = canonicalize_dtype(np_dtype).name - warnings.warn(msg.format(dtype, fun_name, truncated_dtype), stacklevel=3) + return _maybe_canonicalize_explicit_dtype(np_dtype, fun_name) def safe_to_cast(input_dtype_or_value: Any, output_dtype_or_value: Any) -> bool: @@ -971,17 +1084,17 @@ def safe_to_cast(input_dtype_or_value: Any, Examples: - >>> safe_to_cast('int32', 'float64') + >>> safe_to_cast('int16', 'float32') True - >>> safe_to_cast('float64', 'int32') + >>> safe_to_cast('float32', 'int16') False >>> safe_to_cast('float32', 'complex64') True - >>> safe_to_cast('complex64', 'float64') + >>> safe_to_cast('complex64', 'float32') False """ - input_dtype = dtype(input_dtype_or_value, canonicalize=True) - output_dtype = dtype(output_dtype_or_value, canonicalize=True) + input_dtype = dtype(input_dtype_or_value) + output_dtype = dtype(output_dtype_or_value) if input_dtype == output_dtype: return True # We deliberately use output_dtype rather than output_dtype_or_value here: @@ -1010,6 +1123,7 @@ class PrimalTangentDType(ExtendedDType): return PrimalTangentDType() +@functools.cache def short_dtype_name(dtype) -> str: if isinstance(dtype, ExtendedDType): return str(dtype) diff --git a/jax/_src/earray.py b/jax/_src/earray.py index a85138584afb..b45d371057e4 100644 --- a/jax/_src/earray.py +++ b/jax/_src/earray.py @@ -18,10 +18,10 @@ from jax._src import basearray from jax._src import core +from jax._src import dtypes from jax._src import tree_util from jax._src import sharding_impls from jax._src.interpreters import pxla -from jax._src.interpreters import xla from jax._src.util import safe_zip, safe_map map, unsafe_map = safe_map, map @@ -116,6 +116,6 @@ def _earray_shard_arg_handler(xs, shardings, layouts, copy_semantics): pxla.shard_arg_handlers[EArray] = _earray_shard_arg_handler core.pytype_aval_mappings[EArray] = lambda x: x.aval -xla.canonicalize_dtype_handlers[EArray] = lambda x: x +dtypes.canonicalize_value_handlers[EArray] = lambda x: x tree_util.dispatch_registry.register_node( EArray, lambda x: ((x._data,), x.aval), lambda a, xs: EArray(a, xs[0])) diff --git a/jax/_src/effects.py b/jax/_src/effects.py index 36528c5feae5..efbf10638cf0 100644 --- a/jax/_src/effects.py +++ b/jax/_src/effects.py @@ -47,7 +47,7 @@ for each thread the `RuntimeToken` returned by the last dispatched computation. For more details, see the design note: -https://jax.readthedocs.io/en/latest/jep/10657-sequencing-effects.html. +https://docs.jax.dev/en/latest/jep/10657-sequencing-effects.html. """ from __future__ import annotations @@ -62,9 +62,13 @@ class Effect: Effects = Set[Effect] class JaxprInputEffect(Effect): - """A side-effect associated with the input of a jaxpr. + """A side-effect associated with the input of a `JaxprEqn` or a `Jaxpr`. - Note that the `input_index` includes constvars. + This is used as a base class for effects associated with inputs, e.g., + reading/writing from mutable inputs. + + When used in a `JaxprEqn`, `input_index` refers to `eqn.invars`. + When used in a `Jaxpr`, `input_index` refers to `jaxpr.constvars + jaxpr.invars`. """ def __init__(self, input_index: Any): @@ -91,6 +95,9 @@ class EffectTypeSet: def __init__(self): self._effect_types: set[type[Effect]] = set() + def __repr__(self): + return f"EffectTypeSet({self._effect_types})" + def add_type(self, effect_type: type[Effect]): self._effect_types.add(effect_type) @@ -118,3 +125,5 @@ def filter_not_in(self, effects: Iterable[Effect]) -> list[Effect]: control_flow_allowed_effects: EffectTypeSet = EffectTypeSet() custom_derivatives_allowed_effects: EffectTypeSet = EffectTypeSet() remat_allowed_effects: EffectTypeSet = EffectTypeSet() + +partial_eval_kept_effects: EffectTypeSet = EffectTypeSet() diff --git a/jax/_src/environment_info.py b/jax/_src/environment_info.py index 4abfdeaa0f14..d9365cc75131 100644 --- a/jax/_src/environment_info.py +++ b/jax/_src/environment_info.py @@ -14,12 +14,12 @@ from __future__ import annotations +import os import platform import subprocess import sys import textwrap -from jax import version from jax._src import lib from jax._src import xla_bridge as xb import numpy as np @@ -39,6 +39,8 @@ def print_environment_info(return_string: bool = False) -> str | None: Args: return_string (bool) : if True, return the string rather than printing to stdout. """ + from jax import version # pytype: disable=import-error + # TODO(jakevdp): should we include other info, e.g. jax.config.values? python_version = sys.version.replace('\n', ' ') info = textwrap.dedent(f"""\ @@ -48,8 +50,10 @@ def print_environment_info(return_string: bool = False) -> str | None: python: {python_version} device info: {xb.devices()[0].device_kind}-{xb.device_count()}, {xb.local_device_count()} local devices" process_count: {xb.process_count()} - platform: {platform.uname()} -""") + platform: {platform.uname()}""") + for key, value in os.environ.items(): + if key.startswith(("JAX_", "XLA_")): + info += f"\n{key}={value}" nvidia_smi = try_nvidia_smi() if nvidia_smi: info += '\n\n$ nvidia-smi\n' + nvidia_smi diff --git a/jax/_src/error_check.py b/jax/_src/error_check.py index 60dc2f76a5b2..5375267805af 100644 --- a/jax/_src/error_check.py +++ b/jax/_src/error_check.py @@ -14,82 +14,93 @@ from __future__ import annotations +import dataclasses from functools import partial +import json import threading +import traceback as tb_lib +from types import TracebackType +import warnings + +import numpy as np -import jax from jax._src import core from jax._src import source_info_util from jax._src import traceback_util +from jax._src import tree_util import jax._src.mesh as mesh_lib -from jax.experimental.shard_map import shard_map -import jax.numpy as jnp -from jax.sharding import NamedSharding, PartitionSpec as P - - -Traceback = source_info_util.Traceback +from jax._src import shard_map +from jax._src.export import _export +from jax._src.lax import lax +from jax._src.sharding_impls import NamedSharding, PartitionSpec as P +from jax._src.typing import Array, ArrayLike traceback_util.register_exclusion(__file__) class JaxValueError(ValueError): - """Exception raised for failed runtime error checks in JAX.""" + """Exception raised for runtime errors detected within JAX computations.""" #: The default error code for no error. #: #: This value is chosen because we can use `jnp.min()` to obtain the #: first error when performing reductions. -_NO_ERROR = jnp.iinfo(jnp.uint32).max +_NO_ERROR = np.iinfo(np.uint32).max -_error_list_lock = threading.Lock() -_error_list: list[tuple[str, Traceback]] = [] # (error_message, traceback) pair +_error_list_lock = threading.RLock() +# (error_message, traceback) pairs. Traceback is `str` when imported from AOT. +_error_list: list[tuple[str, TracebackType | str]] = [] class _ErrorStorage(threading.local): def __init__(self): - self.ref: core.MutableArray | None = None + self.ref: core.Ref | None = None _error_storage = _ErrorStorage() def _initialize_error_code_ref() -> None: - """Initialize error_code_ref in the current thread. + """Initialize the error code ref in the current thread. - The size of the error code array is determined by the mesh in the context. In - single-device environment, the array is a scalar. In multi-device - environment, the array has the same shape as the mesh. + The shape and size of the error code array depend on the mesh in the context. + In single-device environments, the array is a scalar. In multi-device + environments, its shape and size match those of the mesh. """ - with core.eval_context(): - # Get mesh from the context. - mesh = mesh_lib.get_concrete_mesh() - - if mesh is None: # single-device case. - error_code = jnp.uint32(_NO_ERROR) - - else: # multi-device case. - sharding = NamedSharding(mesh, P(*mesh.axis_names)) - error_code = jnp.full( - mesh.axis_sizes, - jnp.uint32(_NO_ERROR), - device=sharding, - ) + # Get mesh from the context. + mesh = mesh_lib.get_concrete_mesh() + + if mesh.empty: # single-device case. + error_code: ArrayLike = np.uint32(_NO_ERROR) + + else: # multi-device case. + sharding = NamedSharding(mesh, P(*mesh.axis_names)) + error_code = lax.full( + mesh.axis_sizes, + np.uint32(_NO_ERROR), + sharding=sharding, + ) - _error_storage.ref = core.mutable_array(error_code) + _error_storage.ref = core.new_ref(error_code) class error_checking_context: - """Redefine the error checking state based on the mesh in the context. + """Redefine the internal error state based on the mesh in the context. + + When using JAX in multi-device environments in explicit mode, error tracking + needs to be properly aligned with the device mesh. This context manager + ensures that the internal error state is correctly initialized based on the + current mesh configuration. - This context manager should be used when starting a multi-device - computation, and whenever the mesh is changed. + This context manager should be used when starting a multi-device computation, + or when switching between different device meshes. - When exiting the context, the error checking state will be reset to the - original state. + On entering the context, it initializes a new error state based on the mesh in + the context. On exiting the context, it restores the previous error state. """ __slots__ = ("old_ref",) @@ -99,69 +110,114 @@ def __init__(self): def __enter__(self): self.old_ref = _error_storage.ref - _initialize_error_code_ref() + with core.eval_context(): + _initialize_error_code_ref() return self def __exit__(self, exc_type, exc_value, traceback): _error_storage.ref = self.old_ref -def set_error_if(pred: jax.Array, /, msg: str) -> None: - """Set error if any element of pred is true. +def set_error_if(pred: Array, /, msg: str) -> None: + """Set the internal error state if any element of `pred` is `True`. + + This function is used inside JAX computations to detect runtime errors without + immediately halting execution. When this function is traced (e.g., inside + :func:`jax.jit`), the corresponding error message and its traceback are + recorded. At execution time, if `pred` contains any `True` values, the error + state is set, but execution continues without interruption. The recorded error + can later be raised using :func:`raise_if_error`. - If the error is already set, the new error will be ignored. It will not - override the existing error. + If the error state has already been set, subsequent errors are ignored and + will not override the existing error. - In auto mode, this function does not work under jit. + For multi-device environments, in explicit mode, users must call + :func:`error_checking_context` to initialize a new error tracking state that + matches the device mesh. In auto mode, implicit cross-device communication may + occur inside this function, which could impact performance. A warning is + issued in such cases. + + When exporting a function with `jax.export`, error checking must be explicitly + wrapped using :func:`wrap_for_export` before export and + :func:`unwrap_from_import` after import. + + Args: + pred: A JAX boolean array. If any element of `pred` is `True`, the internal + error state will be set. + msg: The corresponding error message to be raised later. """ + # TODO(jakevdp): remove this import and express the following using lax APIs. + import jax.numpy as jnp # pytype: disable=import-error + if _error_storage.ref is None: - _initialize_error_code_ref() + with core.eval_context(): + _initialize_error_code_ref() assert _error_storage.ref is not None + # Get the traceback. traceback = source_info_util.current().traceback assert traceback is not None + traceback = traceback.as_python_traceback() + assert isinstance(traceback, TracebackType) + traceback = traceback_util.filter_traceback(traceback) + assert isinstance(traceback, TracebackType) + with _error_list_lock: - new_error_code = jnp.uint32(len(_error_list)) + new_error_code = np.uint32(len(_error_list)) _error_list.append((msg, traceback)) out_sharding = core.typeof(_error_storage.ref).sharding in_sharding: NamedSharding = core.typeof(pred).sharding - if out_sharding.mesh.shape_tuple == (): # single-device case. + # Reduce `pred`. + if all(dim is None for dim in out_sharding.spec): # single-device case. pred = pred.any() else: # multi-device case. has_auto_axes = mesh_lib.AxisType.Auto in in_sharding.mesh.axis_types - if has_auto_axes: - raise NotImplementedError( - "Error checking in auto mode is not supported yet. Please use" - " explicit mode." - ) - if out_sharding.mesh != in_sharding.mesh: - raise ValueError( - "The error code state and the predicate must be on the same mesh, " - f"but got {out_sharding.mesh} and {in_sharding.mesh} respectively. " - "Please use `with error_checking_context()` to redefine the error " - "code state based on the mesh." + if has_auto_axes: # auto mode. + warnings.warn( + "When at least one mesh axis of `pred` is in auto mode, calling" + " `set_error_if` will cause implicit communication between devices." + " To avoid this, consider converting the mesh axis in auto mode to" + " explicit mode.", + RuntimeWarning, ) - pred = shard_map( - partial(jnp.any, keepdims=True), - mesh=out_sharding.mesh, - in_specs=in_sharding.spec, - out_specs=out_sharding.spec, - )(pred) # perform per-device reduction + pred = pred.any() # reduce to a single scalar + else: # explicit mode. + if out_sharding.mesh != in_sharding.mesh: + raise ValueError( + "The error code state and the predicate must be on the same mesh, " + f"but got {out_sharding.mesh} and {in_sharding.mesh} respectively. " + "Please use `with error_checking_context()` to redefine the error " + "code state based on the mesh." + ) + pred = shard_map.shard_map( + partial(jnp.any, keepdims=True), + mesh=out_sharding.mesh, + in_specs=in_sharding.spec, + out_specs=out_sharding.spec, + )(pred) # perform per-device reduction error_code = _error_storage.ref[...] - should_update = jnp.logical_and(pred, error_code == jnp.uint32(_NO_ERROR)) + should_update = jnp.logical_and(error_code == jnp.uint32(_NO_ERROR), pred) error_code = jnp.where(should_update, new_error_code, error_code) # TODO(ayx): support vmap and shard_map. _error_storage.ref[...] = error_code def raise_if_error() -> None: - """Raise error if an error is set. + """Raise an exception if the internal error state is set. - This function should be called after the computation is finished. It should - not be called within a traced context, such as within a jitted function." + This function should be called after a computation completes to check for any + errors that were marked during execution via `set_error_if()`. If an error + exists, it raises a `JaxValueError` with the corresponding error message. + + This function should not be called inside a traced function (e.g., inside + :func:`jax.jit`). Doing so will raise a `ValueError`. + + Raises: + JaxValueError: If the internal error state is set. + ValueError: If called within a traced JAX function. """ if _error_storage.ref is None: # if not initialized, do nothing return @@ -172,16 +228,144 @@ def raise_if_error() -> None: "raise_if_error() should not be called within a traced context, such as" " within a jitted function." ) - if error_code == jnp.uint32(_NO_ERROR): + if error_code == np.uint32(_NO_ERROR): return - _error_storage.ref[...] = jnp.full( + _error_storage.ref[...] = lax.full( _error_storage.ref.shape, - jnp.uint32(_NO_ERROR), - device=_error_storage.ref.sharding, + np.uint32(_NO_ERROR), + sharding=_error_storage.ref.sharding, ) # clear the error code - msg, traceback = _error_list[error_code] - exc = JaxValueError(msg) - traceback = traceback.as_python_traceback() - filtered_traceback = traceback_util.filter_traceback(traceback) - raise exc.with_traceback(filtered_traceback) + with _error_list_lock: + msg, traceback = _error_list[error_code] + if isinstance(traceback, str): # from imported AOT functions + exc = JaxValueError( + f"{msg}\nThe original traceback is shown below:\n{traceback}" + ) + raise exc + else: + exc = JaxValueError(msg) + raise exc.with_traceback(traceback) + + +@dataclasses.dataclass(frozen=True) +class _ErrorClass: + """A class to store error information for AOT compilation. + + This class is used internally by the wrapper functions `wrap_for_export` and + `unwrap_from_import` to encapsulate error-related data within an exported + function. + + Attributes: + error_code (jax.Array): A JAX array representing the final error state of + the function to be exported. This value is local to the wrapper function. + error_list (list[tuple[str, str]]): A list of `(error_message, traceback)` + pairs containing error messages and corresponding stack traces. This error + list is local to the wrapper function, and does not contain pairs of error + information from other functions. + """ + + error_code: Array + error_list: list[tuple[str, str]] + + +tree_util.register_dataclass( + _ErrorClass, data_fields=("error_code",), meta_fields=("error_list",) +) +_export.register_pytree_node_serialization( + _ErrorClass, + serialized_name=f"{_ErrorClass.__module__}.{_ErrorClass.__name__}", + serialize_auxdata=lambda x: json.dumps(x, ensure_ascii=False).encode( + "utf-8" + ), + deserialize_auxdata=lambda x: json.loads(x.decode("utf-8")), +) + + +def _traceback_to_str(traceback: TracebackType) -> str: + """Convert a traceback to a string for export.""" + return "".join(tb_lib.format_list(tb_lib.extract_tb(traceback))).rstrip("\n") + + +def wrap_for_export(f): + """Wrap a function with error checking to make it compatible with AOT mode. + + Error checking relies on global state, which cannot be serialized across + processes. This wrapper ensures that the error state remains within the + function scope, making it possible to export the function and later import in + other processes. + + When the function is later imported, it must be wrapped with + :func:`unwrap_from_import` to integrate the error checking mechanism of the + imported function into the global error checking mechanism of the current + process. + + This function should only be applied once to a function; wrapping the same + function multiple times is unnecessary. + """ + + def inner(*args, **kwargs): + global _error_list + + # 1. Save the old state and initialize a new state. + with core.eval_context(): + old_ref = _error_storage.ref + _initialize_error_code_ref() + with _error_list_lock: + old_error_list, _error_list = _error_list, [] + + # 2. Trace the function. + out = f(*args, **kwargs) + error_code = _error_storage.ref[...].min() + + # 3. Restore the old state. + _error_list, new_error_list = old_error_list, _error_list + with core.eval_context(): + _error_storage.ref = old_ref + + new_error_list = [ + (msg, _traceback_to_str(traceback)) for msg, traceback in new_error_list + ] + return out, _ErrorClass(error_code, new_error_list) + + return inner + + +def unwrap_from_import(f): + """Unwrap a function after AOT import to restore error checking. + + When an AOT-exported function is imported in a new process, its error state is + separate from the global error state of the current process. This wrapper + ensures that errors detected during execution are correctly integrated into + the global error checking mechanism of the current process. + + This function should only be applied to functions that were previously wrapped + with :func:`wrap_for_export` before export. + """ + if _error_storage.ref is None: + with core.eval_context(): + _initialize_error_code_ref() + assert _error_storage.ref is not None + + def inner(*args, **kwargs): + out, error_class = f(*args, **kwargs) + new_error_code, error_list = error_class.error_code, error_class.error_list + + # Update the global error list. + with _error_list_lock: + offset = len(_error_list) + _error_list.extend(error_list) + + # Update the global error code array. + error_code = _error_storage.ref[...] + should_update = lax.bitwise_and( + error_code == np.uint32(_NO_ERROR), + new_error_code != np.uint32(_NO_ERROR), + ) + error_code = lax.select(should_update, new_error_code + offset, error_code) + # TODO(ayx): support vmap and shard_map. + _error_storage.ref[...] = error_code + + return out + + return inner diff --git a/jax/_src/errors.py b/jax/_src/errors.py index 6540fd1f5d41..a81d414b182a 100644 --- a/jax/_src/errors.py +++ b/jax/_src/errors.py @@ -21,7 +21,7 @@ class _JAXErrorMixin: """Mixin for JAX-specific errors""" - _error_page = 'https://jax.readthedocs.io/en/latest/errors.html' + _error_page = 'https://docs.jax.dev/en/latest/errors.html' _module_name = "jax.errors" def __init__(self, message: str): @@ -35,12 +35,12 @@ def __init__(self, message: str): @export class JAXTypeError(_JAXErrorMixin, TypeError): - pass + """JAX-specific :class:`TypeError`""" @export class JAXIndexError(_JAXErrorMixin, IndexError): - pass + """JAX-specific :class:`IndexError`""" @export @@ -306,7 +306,7 @@ class TracerArrayConversionError(JAXTypeError): and concrete vs. abstract values, you may want to read :ref:`faq-different-kinds-of-jax-values`. - .. _External Callbacks: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html + .. _External Callbacks: https://docs.jax.dev/en/latest/notebooks/external_callbacks.html """ def __init__(self, tracer: core.Tracer): super().__init__( @@ -503,7 +503,7 @@ class TracerBoolConversionError(ConcretizationTypeError): In this case, the error occurs because Python's built-in ``min`` function is not compatible with JAX transforms. This can be fixed by replacing it with - ``jnp.minumum``: + ``jnp.minimum``: >>> @jit ... def func(x): @@ -530,7 +530,7 @@ class UnexpectedTracerError(JAXTypeError): function ``f`` that stores, in some scope outside of ``f``, a reference to an intermediate value, that value is considered to have been leaked. Leaking values is a side effect. (Read more about avoiding side effects in - `Pure Functions `_) + `Pure Functions `_) JAX detects leaks when you then use the leaked value in another operation later on, at which point it raises an ``UnexpectedTracerError``. @@ -678,6 +678,5 @@ class KeyReuseError(JAXTypeError): This sort of key reuse is problematic because the JAX PRNG is stateless, and keys must be manually split; For more information on this see `the Pseudorandom Numbers - tutorial `_. + tutorial `_. """ - pass diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index afae3d9bcdc2..a5e48df82fd0 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -11,9 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""JAX APIs for exporting JAX functions for interoperation. - -""" +"""JAX APIs for exporting JAX functions for interoperation.""" from __future__ import annotations @@ -27,15 +25,14 @@ import re from typing import Any, Protocol, TypeVar, Union, cast -from absl import logging +import logging import numpy as np -import jax -from jax import sharding - from jax._src import ad_util +from jax._src import api from jax._src import config from jax._src import core +from jax._src import custom_derivatives from jax._src import dispatch from jax._src import dtypes from jax._src import effects @@ -43,20 +40,25 @@ from jax._src.interpreters import mlir from jax._src.interpreters import pxla from jax._src.lib import xla_client -from jax._src.lib import xla_extension, xla_extension_version +from jax._src.lib import _jax from jax._src.lib.mlir import ir, passmanager from jax._src.lib.mlir.dialects import hlo -from jax._src.lib.mlir.dialects import func as func_dialect +from jax._src.lib.mlir.dialects import func as func_dialect, sdy from jax._src import pjit +from jax._src import sharding from jax._src import sharding_impls from jax._src import source_info_util from jax._src import stages +from jax._src import traceback_util from jax._src import tree_util +from jax._src import typing from jax._src import util from jax._src import xla_bridge as xb from jax._src.export import shape_poly +logger = logging.getLogger(__name__) + map = util.safe_map zip = util.safe_zip @@ -64,12 +66,13 @@ Shape = core.Shape # The values of input and output sharding from the lowering. LoweringSharding = Union[sharding.Sharding, pxla.UnspecifiedValue] +NamedSharding = sharding_impls.NamedSharding HloSharding = xla_client.HloSharding # The minimum and maximum supported calling convention version. -# See https://jax.readthedocs.io/en/latest/export/export.html#export-calling-convention-version +# See https://docs.jax.dev/en/latest/export/export.html#export-calling-convention-version minimum_supported_calling_convention_version = 9 -maximum_supported_calling_convention_version = 9 +maximum_supported_calling_convention_version = 10 class DisabledSafetyCheck: @@ -77,11 +80,11 @@ class DisabledSafetyCheck: Most of these checks are performed on serialization, but some are deferred to deserialization. The list of disabled checks is attached to the serialization, - e.g., as a sequence of string attributes to `jax.export.Exported` or of - `tf.XlaCallModuleOp`. + e.g., as a sequence of string attributes to :class:`jax.export.Exported` or of + ``tf.XlaCallModuleOp``. When using jax2tf, you can disable more deserialization safety checks - by passing `TF_XLA_FLAGS=--tf_xla_call_module_disabled_checks=platform`. + by passing ``TF_XLA_FLAGS=--tf_xla_call_module_disabled_checks=platform``. """ _impl: str @@ -130,7 +133,7 @@ class Exported: Attributes: fun_name: the name of the exported function, for error messages. in_tree: a PyTreeDef describing the tuple (args, kwargs) of the lowered JAX - function. The actual lowering does not depend on the `in_tree`, but this + function. The actual lowering does not depend on the ``in_tree``, but this can be used to invoke the exported function using the same argument structure. in_avals: the flat tuple of input abstract values. May contain dimension @@ -138,42 +141,49 @@ class Exported: out_tree: a PyTreeDef describing the result of the lowered JAX function. out_avals: the flat tuple of output abstract values. May contain dimension expressions in the shapes, with dimension variables among those in - `in_avals`. + ``in_avals``. Note that when the out_shardings are not specified for + an output, the `out_avals.sharding.spec` for `Auto` axes may be `None` + even if after compilation the compiler may pick a non-replicated + sharding. + See https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html#concrete-array-shardings-can-mention-auto-mesh-axis + for more details. in_shardings_hlo: the flattened input shardings, a sequence as long - as `in_avals`. `None` means unspecified sharding. + as ``in_avals``. ``None`` means unspecified sharding. Note that these do not include the mesh or the actual devices used in - the mesh. See `in_shardings_jax` for a way to turn these + the mesh, and in general you should avoid using this field directly. + See ``in_shardings_jax`` for a way to turn these into sharding specification that can be used with JAX APIs. out_shardings_hlo: the flattened output shardings, a sequence as long - as `out_avals`. `None` means unspecified sharding. + as ``out_avals``. ``None`` means unspecified sharding. Note that these do not include the mesh or the actual devices used in - the mesh. See `out_shardings_jax` for a way to turn these + the mesh, and in general you should avoid using this field directly. + See ``out_shardings_jax`` for a way to turn these into sharding specification that can be used with JAX APIs. nr_devices: the number of devices that the module has been lowered for. platforms: a tuple containing the platforms for which the function should be exported. The set of platforms in JAX is open-ended; users can add platforms. JAX built-in platforms are: 'tpu', 'cpu', 'cuda', 'rocm'. - See https://jax.readthedocs.io/en/latest/export/export.html#cross-platform-and-multi-platform-export. + See https://docs.jax.dev/en/latest/export/export.html#cross-platform-and-multi-platform-export. ordered_effects: the ordered effects present in the serialized module. - This is present from serialization version 9. See https://jax.readthedocs.io/en/latest/export/export.html#module-calling-convention + This is present from serialization version 9. See https://docs.jax.dev/en/latest/export/export.html#module-calling-convention for the calling convention in presence of ordered effects. unordered_effects: the unordered effects present in the serialized module. This is present from serialization version 9. mlir_module_serialized: the serialized lowered VHLO module. calling_convention_version: a version number for the calling convention of the exported module. - See more versioning details at https://jax.readthedocs.io/en/latest/export/export.html#calling-convention-versions. + See more versioning details at https://docs.jax.dev/en/latest/export/export.html#calling-convention-versions. module_kept_var_idx: the sorted indices of the arguments among `in_avals` that must be passed to the module. The other arguments have been dropped because they are not used. - uses_global_constants: whether the `mlir_module_serialized` uses shape + uses_global_constants: whether the ``mlir_module_serialized`` uses shape polymorphism or multi-platform export. - This may be because `in_avals` contains dimension + This may be because ``in_avals`` contains dimension variables, or due to inner calls of Exported modules that have dimension variables or platform index arguments. Such modules need shape refinement before XLA compilation. disabled_safety_checks: a list of descriptors of safety checks that have been - disabled at export time. See docstring for `DisabledSafetyCheck`. + disabled at export time. See docstring for ``DisabledSafetyCheck``. _get_vjp: an optional function that takes the current exported function and returns the exported VJP function. The VJP function takes a flat list of arguments, @@ -181,7 +191,10 @@ class Exported: for each primal output. It returns a tuple with the cotangents corresponding to the flattened primal inputs. - See a [description of the calling convention for the `mlir_module`](https://jax.readthedocs.io/en/latest/export/export.html#module-calling-convention). + DO NOT RELY directly on fields whose name starts with '_'. They will change. + + See a description of the calling convention for the :meth:`~jax.export.Exported.mlir_module` + method at https://docs.jax.dev/en/latest/export/export.html#module-calling-convention. """ fun_name: str in_tree: tree_util.PyTreeDef @@ -189,8 +202,17 @@ class Exported: out_tree: tree_util.PyTreeDef out_avals: tuple[core.ShapedArray, ...] + # _has_named_shardings is True if the export was done after 1/15/2026 and + # we have _in_named_shardings and _out_named_shardings. In that case we + # support multiple meshes for inputs and outputs, and we do not rely + # anymore on the Shardy-saved meshes, which do not have axis_types anyway, + # and do not support multiple meshes. + _has_named_shardings: bool + _in_named_shardings: tuple[NamedSharding | None, ...] # all None if not _has_named_shardings + _out_named_shardings: tuple[NamedSharding | None, ...] # all None if not _has_named_shardings in_shardings_hlo: tuple[HloSharding | None, ...] out_shardings_hlo: tuple[HloSharding | None, ...] + nr_devices: int platforms: tuple[str, ...] ordered_effects: tuple[effects.Effect, ...] @@ -205,7 +227,7 @@ class Exported: _get_vjp: Callable[[Exported], Exported] | None def mlir_module(self) -> str: - """A string representation of the `mlir_module_serialized`.""" + """A string representation of the ``mlir_module_serialized``.""" return xla_client._xla.mlir.deserialize_portable_artifact(self.mlir_module_serialized) def __str__(self): @@ -215,17 +237,18 @@ def __str__(self): def in_shardings_jax( self, - mesh: sharding.Mesh) -> Sequence[sharding.Sharding | None]: - """Creates Shardings corresponding to self.in_shardings_hlo. + mesh: mesh_lib.Mesh) -> Sequence[sharding.Sharding | None]: + """Creates Shardings corresponding to ``self.in_shardings_hlo`` and ``self._in_named_shardings``. - The Exported object stores `in_shardings_hlo` as HloShardings, which are - independent of a mesh or set of devices. This method constructs - Sharding that can be used in JAX APIs such as `jax.jit` or - `jax.device_put`. + The Exported object stores ``in_shardings_hlo`` as HloShardings, and + after 12/5/2025 also ``_in_named_shardings`` as NamedShardings with + abstract meshes. This method constructs + Sharding that can be used in JAX APIs such as :func:`jax.jit` or + :func:`jax.device_put`. The `mesh` argument may be a concrete mesh. Example usage: - >>> from jax import export + >>> from jax import export, sharding >>> # Prepare the exported object: >>> exp_mesh = sharding.Mesh(jax.devices(), ("a",)) >>> exp = export.export(jax.jit(lambda x: jax.numpy.add(x, x), @@ -234,7 +257,7 @@ def in_shardings_jax( >>> exp.in_shardings_hlo ({devices=[8]<=[8]},) >>> # Create a mesh for running the exported object - >>> run_mesh = sharding.Mesh(jax.devices()[::-1], ("b",)) + >>> run_mesh = sharding.Mesh(jax.devices()[::-1], ("a",)) >>> # Put the args and kwargs on the appropriate devices >>> run_arg = jax.device_put(np.arange(jax.device_count()), ... exp.in_shardings_jax(run_mesh)[0]) @@ -250,18 +273,24 @@ def in_shardings_jax( Shard(device=CpuDevice(id=0), index=(slice(7, 8, None),), replica_id=0, data=[14])] """ - return tuple(_hlo_sharding_to_named_sharding(s, mesh) - for s in self.in_shardings_hlo) + return tuple( + _get_named_sharding(self._has_named_shardings, named_sharding, + hlo_sharding, aval, mesh) + for named_sharding, hlo_sharding, aval in zip( + self._in_named_shardings, self.in_shardings_hlo, self.in_avals)) def out_shardings_jax( self, - mesh: sharding.Mesh) -> Sequence[sharding.Sharding | None]: - """Creates Shardings corresponding to `self.out_shardings_hlo`. + mesh: mesh_lib.Mesh) -> Sequence[sharding.Sharding | None]: + """Creates Shardings for ``out_shardings_hlo`` and ``_out_named_shardings``. See documentation for in_shardings_jax. """ - return tuple(_hlo_sharding_to_named_sharding(s, mesh) - for s in self.out_shardings_hlo) + return tuple( + _get_named_sharding(self._has_named_shardings, named_sharding, + hlo_sharding, aval, mesh) + for named_sharding, hlo_sharding, aval in zip( + self._out_named_shardings, self.out_shardings_hlo, self.out_avals)) def has_vjp(self) -> bool: """Returns if this Exported supports VJP.""" @@ -283,9 +312,9 @@ def serialize(self, Args: vjp_order: The maximum vjp order to include. E.g., the value 2 means that we - serialize the primal functions and two orders of the `vjp` function. This + serialize the primal functions and two orders of the ``vjp`` function. This should allow 2nd order reverse mode differentiation of the deserialized - function. i.e., `jax.grad(jax.grad(f)).` + function. i.e., ``jax.grad(jax.grad(f))``. """ # Lazy load the serialization module, since flatbuffers is an optional # dependency. @@ -306,7 +335,7 @@ def call(self, *args, **kwargs): The invocation supports reverse-mode AD, and all the features supported by exporting: shape polymorphism, multi-platform, device polymorphism. - See the examples in the [JAX export documentation](https://jax.readthedocs.io/en/latest/export/export.html). + See the examples in the [JAX export documentation](https://docs.jax.dev/en/latest/export/export.html). """ return call_exported(self)(*args, **kwargs) @@ -315,7 +344,7 @@ def deserialize(blob: bytearray) -> Exported: """Deserializes an Exported. Args: - blob: a bytearray obtained from `Exported.serialize`. + blob: a bytearray obtained from :meth:`jax.export.Exported.serialize`. """ # Lazy load the serialization module, since flatbuffers is an optional # dependency. @@ -331,8 +360,8 @@ class _SerializeAuxData(Protocol): def __call__(self, aux_data: PyTreeAuxData) -> bytes: """Serializes the PyTree node AuxData. - The AuxData is returned by the `flatten_func` registered by - `tree_util.register_pytree_node`). + The AuxData is returned by the ``flatten_func`` registered by + :func:`jax.tree_util.register_pytree_node`). """ @@ -340,7 +369,7 @@ class _DeserializeAuxData(Protocol): def __call__(self, serialized_aux_data: bytes) -> PyTreeAuxData: """Deserializes the PyTree node AuxData. - The result will be passed to `_BuildFromChildren`. + The result will be passed to ``_BuildFromChildren``. """ @@ -348,7 +377,7 @@ class _BuildFromChildren(Protocol): def __call__(self, aux_data: PyTreeAuxData, children: Sequence[Any]) -> Any: """Materializes a T given a deserialized AuxData and children. - This is similar in scope with the `unflatten_func`. + This is similar in scope with the ``unflatten_func``. """ @@ -378,37 +407,37 @@ def register_pytree_node_serialization( You must use this function before you can serialize and deserialize PyTree nodes for the types not supported natively. We serialize PyTree nodes for - the `in_tree` and `out_tree` fields of `Exported`, which are part of the + the ``in_tree`` and ``out_tree`` fields of ``Exported``, which are part of the exported function's calling convention. This function must be called after calling - `jax.tree_util.register_pytree_node` (except for `collections.namedtuple`, - which do not require a call to `register_pytree_node`). + :func:`jax.tree_util.register_pytree_node` (except for ``collections.namedtuple``, + which do not require a call to ``register_pytree_node``). Args: nodetype: the type whose PyTree nodes we want to serialize. It is an - error to attempt to register multiple serializations for a `nodetype`. + error to attempt to register multiple serializations for a ``nodetype``. serialized_name: a string that will be present in the serialization and will be used to look up the registration during deserialization. It is an error to attempt to register multiple serializations for a - `serialized_name`. + ``serialized_name``. serialize_auxdata: serialize the PyTree auxdata (returned by the - `flatten_func` argument to `jax.tree_util.register_pytree_node`.). + ``flatten_func`` argument to :func:`jax.tree_util.register_pytree_node`.). deserialize_auxdata: deserialize the auxdata that was serialized by the - `serialize_auxdata`. + ``serialize_auxdata``. from_children: if present, this is a function that takes that result of - `deserialize_auxdata` along with some children and creates an instance - of `nodetype`. This is similar to the `unflatten_func` passed to - `jax.tree_util.register_pytree_node`. If not present, we look up - and use the `unflatten_func`. This is needed for `collections.namedtuple`, - which does not have a `register_pytree_node`, but it can be useful to - override that function. Note that the result of `from_children` is - only used with `jax.tree_util.tree_structure` to construct a proper + ``deserialize_auxdata`` along with some children and creates an instance + of ``nodetype``. This is similar to the ``unflatten_func`` passed to + :func:`jax.tree_util.register_pytree_node`. If not present, we look up + and use the ``unflatten_func``. This is needed for ``collections.namedtuple``, + which does not have a ``register_pytree_node``, but it can be useful to + override that function. Note that the result of ``from_children`` is + only used with :func:`jax.tree_util.tree_structure` to construct a proper PyTree node, it is not used to construct the outputs of the serialized function. Returns: - the same type passed as `nodetype`, so that this function can + the same type passed as ``nodetype``, so that this function can be used as a class decorator. """ if nodetype in serialization_registry: @@ -442,23 +471,23 @@ def register_namedtuple_serialization( serialized_name: str) -> type[T]: """Registers a namedtuple for serialization and deserialization. - JAX has native PyTree support for `collections.namedtuple`, and does not - require a call to `jax.tree_util.register_pytree_node`. However, if you + JAX has native PyTree support for ``collections.namedtuple``, and does not + require a call to :func:`jax.tree_util.register_pytree_node`. However, if you want to serialize functions that have inputs of outputs of a namedtuple type, you must register that type for serialization. Args: nodetype: the type whose PyTree nodes we want to serialize. It is an - error to attempt to register multiple serializations for a `nodetype`. + error to attempt to register multiple serializations for a ``nodetype``. On deserialization, this type must have the same set of keys that were present during serialization. serialized_name: a string that will be present in the serialization and will be used to look up the registration during deserialization. It is an error to attempt to register multiple serializations for - a `serialized_name`. + a ``serialized_name``. Returns: - the same type passed as `nodetype`, so that this function can + the same type passed as ``nodetype``, so that this function can be used as a class decorator. """ if not _is_namedtuple(nodetype): @@ -509,45 +538,55 @@ def _serialize_ordereddict_keys(keys): def default_export_platform() -> str: """Retrieves the default export platform. - One of: `tpu`, `cpu`, `cuda`, `rocm`. + One of: ``'tpu'``, ``'cpu'``, ``'cuda'``, ``'rocm'``. """ # Canonicalize to turn 'gpu' into 'cuda' or 'rocm' - return xb.canonicalize_platform(jax.default_backend()) + return xb.canonicalize_platform(xb.default_backend()) default_lowering_platform = default_export_platform def shape_and_dtype_jax_array(a) -> tuple[Sequence[int | None], DType]: """Returns the shape and dtype of a jax.Array or a j""" - if isinstance(a, jax.ShapeDtypeStruct): + if isinstance(a, api.ShapeDtypeStruct): return a.shape, a.dtype aval = core.get_aval(a) return aval.shape, aval.dtype +@functools.partial(traceback_util.api_boundary, + repro_api_name="jax.export.export") def export( fun_jit: stages.Wrapped, *, platforms: Sequence[str] | None = None, disabled_checks: Sequence[DisabledSafetyCheck] = (), + _override_lowering_rules: Sequence[tuple[Any, Any]] | None = None ) -> Callable[..., Exported]: """Exports a JAX function for persistent serialization. Args: - fun_jit: the function to export. Should be the result of `jax.jit`. + fun_jit: the function to export. Should be the result of :func:`jax.jit`. platforms: Optional sequence containing a subset of 'tpu', 'cpu', 'cuda', 'rocm'. If more than one platform is specified, then the exported code takes an argument specifying the platform. If None, then use the default JAX backend. The calling convention for multiple platforms is explained at - https://jax.readthedocs.io/en/latest/export/export.html#module-calling-convention. + https://docs.jax.dev/en/latest/export/export.html#module-calling-convention. + _override_lowering_rules: an optional sequence of custom lowering rules + for some JAX primitives. Each element of the sequence is a pair + of a JAX primitive and a lowering function. Defining lowering rules + is an advanced feature using JAX internal APIs, which are subject + to change. Furthermore, the responsibility for the stability of the + MLIR emitted through these custom lowering rules, rests with the user + of these rules. disabled_checks: the safety checks to disable. See documentation for - of `jax.export.DisabledSafetyCheck`. + of :class:`jax.export.DisabledSafetyCheck`. Returns: - a function that takes args and kwargs pytrees of {class}`jax.ShapeDtypeStruct`, - or values with `.shape` and `.dtype` attributes, and returns an - `Exported`. + a function that takes args and kwargs pytrees of :class:`jax.ShapeDtypeStruct`, + or values with ``.shape`` and ``.dtype`` attributes, and returns an + :class:`~jax.export.Exported`. Usage: @@ -568,7 +607,8 @@ def export( Array([0.09983342, 0.19866933, 0.29552022, 0.38941833], dtype=float32) """ return _export_internal(fun_jit, platforms=platforms, - disabled_checks=disabled_checks) + disabled_checks=disabled_checks, + override_lowering_rules=_override_lowering_rules) # TODO(necula): remove this once we improve the integration with jax2tf. @@ -577,15 +617,16 @@ def _export_internal( *, platforms: Sequence[str] | None = None, disabled_checks: Sequence[DisabledSafetyCheck] = (), - _device_assignment_for_internal_jax2tf_use_only = None, + _device_assignment_for_internal_jax2tf_use_only=None, + override_lowering_rules=None, ) -> Callable[..., Exported]: """Exports native serialization for a JAX function. Note: this function exists only for internal usage by jax2tf. Use - `jax.export` instead. - See https://jax.readthedocs.io/en/latest/export/export.html + :mod:`jax.export` instead. + See https://docs.jax.dev/en/latest/export/export.html - See docstring of `export` for more details. + See docstring of ``export`` for more details. """ if not isinstance(fun_jit, stages.Wrapped): raise ValueError( @@ -604,7 +645,9 @@ def do_export(*args_specs, **kwargs_specs) -> Exported: lowered = traced.lower( lowering_platforms=actual_lowering_platforms, _private_parameters=mlir.LoweringParameters( + override_lowering_rules=override_lowering_rules, for_export=True, + hoist_constants_as_args=False, export_ignore_forward_compatibility=config.export_ignore_forward_compatibility.value)) return _export_lowered( lowered, traced.jaxpr, traced.fun_name, @@ -630,6 +673,55 @@ def check_symbolic_scope_errors(fun_jax, args_specs, kwargs_specs): other_descr=shape_poly.args_kwargs_path_to_str(k_path)) +def to_named_sharding_with_abstract_mesh( + s: LoweringSharding, + aval: core.ShapedArray, + mesh: mesh_lib.Mesh | mesh_lib.AbstractMesh | None) -> NamedSharding | None: + # We store and serialize all shardings as NamedShardings with abstract mesh, + # even if they are SingleDeviceShardings. + if isinstance(s, sharding_impls.UnspecifiedValue): + return None + if isinstance(s, sharding_impls.NamedSharding): + return sharding_impls.NamedSharding(s.mesh.abstract_mesh, + s.spec, memory_kind=s.memory_kind) + if isinstance(s, sharding_impls.SingleDeviceSharding): + return sharding_impls.NamedSharding( + mesh_lib.empty_abstract_mesh, + sharding_impls.PartitionSpec(*([None] *aval.ndim)), + memory_kind=s.memory_kind) + + if isinstance(s, sharding_impls.GSPMDSharding): + if mesh is None: + if s._hlo_sharding.tuple_elements(): + raise TypeError( + f"Cannot convert GSPMDSharding {s} into NamedSharding.") + elif s._hlo_sharding.is_replicated(): + return sharding_impls.NamedSharding( + mesh_lib.empty_abstract_mesh, + sharding_impls.PartitionSpec(*([None] *aval.ndim)), + memory_kind=s.memory_kind) + elif s._hlo_sharding.is_tiled(): + if not s._hlo_sharding.is_tile_assignment_iota(): + raise TypeError( + f"Cannot convert GSPMDSharding {s} into NamedSharding.") + axis_sizes = tuple(s._hlo_sharding.get_axis_sizes()) + axis_names = tuple(f'_axis_{i}' for i in range(len(axis_sizes))) + mesh = mesh_lib.AbstractMesh(axis_sizes, axis_names) + return sharding_impls._gspmd_to_named_sharding_via_mesh(s, mesh) + else: + raise TypeError( + f"Cannot convert GSPMDSharding {s} into NamedSharding.") + else: + return sharding_impls._gspmd_to_named_sharding_via_mesh(s, mesh) + + assert False, f"Unsupported sharding: {s}" + +def named_to_hlo_sharding(s: NamedSharding | None, + aval: core.ShapedArray) -> HloSharding | None: + if s is None: return None + return s._to_xla_hlo_sharding(aval.ndim) + + def _export_lowered( lowered: stages.Lowered, jaxpr: core.ClosedJaxpr, @@ -658,6 +750,12 @@ def _export_lowered( # For pmap module_kept_var_idx = tuple(range(len(args_avals_flat))) shape_poly_state = lowering.compile_args["shape_poly_state"] + + # Make a copy of mlir module as we should not mutate it + # because it may be cached + context = mlir.make_ir_context() + with context, ir.Location.unknown(context): + mlir_module = ir.Module.parse(mlir.module_to_bytecode(mlir_module)) if (not all(core.is_constant_shape(a.shape) for a in args_avals_flat) or lowering.compile_args.get("ordered_effects", [])): mlir_module = _wrap_main_func( @@ -674,12 +772,10 @@ def _export_lowered( # Shardy was used during lowering if we can find the Shardy mesh in the # module. Note that the mesh should have been lifted by the # `sdy-lift-inlined-meshes` pass in mlir.py. - shardy_enabled = False - if xla_extension_version >= 319: - shardy_enabled = xla_extension.sdy.lowered_with_shardy( - mlir.module_to_bytecode(mlir_module)) + shardy_enabled = has_sdy_mesh(ir.SymbolTable(mlir_module.operation), + mlir_module) - mlir_module_serialized = _module_to_bytecode(mlir_module, shardy_enabled) + mlir_module_serialized = _module_to_bytecode(mlir_module) # Figure out the result types and shapes if "global_out_avals" in lowering.compile_args: @@ -691,16 +787,15 @@ def _export_lowered( out_avals_flat = lowered.compile_args["out_avals"] # type: ignore # Log and then check the module. - if logging.vlog_is_on(3): - logmsg = (f"fun_name={fun_name} version={version} " - f"lowering_platforms={lowering._platforms} " # type: ignore[unused-ignore,attribute-error] - f"disabled_checks={disabled_checks}") - logging.info("Exported JAX function: %s\n", logmsg) - logging.info(mlir.dump_module_message(mlir_module, "export")) - logging.info( - "Size of mlir_module_serialized: %d byte", - len(mlir_module_serialized), - ) + logmsg = (f"fun_name={fun_name} version={version} " + f"lowering_platforms={lowering._platforms} " # type: ignore[unused-ignore,attribute-error] + f"disabled_checks={disabled_checks}") + logger.debug("Exported JAX function: %s\n", logmsg) + logger.debug(mlir.dump_module_message(mlir_module, "export")) + logger.debug( + "Size of mlir_module_serialized: %d byte", + len(mlir_module_serialized), + ) _check_module(mlir_module, disabled_checks=disabled_checks, @@ -710,39 +805,33 @@ def _export_lowered( unordered_effects = tuple(lowering.compile_args["unordered_effects"]) nr_devices = lowering.compile_args["num_devices"] - def export_sharding(s: LoweringSharding, - aval: core.ShapedArray) -> HloSharding | None: - if isinstance(s, sharding_impls.UnspecifiedValue): - return None - return s._to_xla_hlo_sharding(aval.ndim) all_in_shardings = expand_in_shardings(lowering.compile_args["in_shardings"], module_kept_var_idx, len(args_avals_flat)) - in_shardings = tuple( - export_sharding(s, aval) + + cur_mesh = None + if config.use_shardy_partitioner.value: + for sharding in itertools.chain.from_iterable([ + all_in_shardings, lowering.compile_args["out_shardings"]]): + if isinstance(sharding, sharding_impls.NamedSharding): + cur_mesh = sharding.mesh + break + if cur_mesh and isinstance(cur_mesh, mesh_lib.Mesh): + cur_mesh = cur_mesh.abstract_mesh + + in_named_shardings = tuple( + to_named_sharding_with_abstract_mesh(s, aval, cur_mesh) for s, aval in zip(all_in_shardings, args_avals_flat)) - out_shardings = tuple( - export_sharding(s, aval) + + out_named_shardings = tuple( + to_named_sharding_with_abstract_mesh(s, aval, cur_mesh) for s, aval in zip(lowering.compile_args["out_shardings"], out_avals_flat)) - device_assignment = lowering.compile_args["device_assignment"] + device_assignment = lowering._device_list # type: ignore if _device_assignment_for_internal_jax2tf_use_only is not None: _device_assignment_for_internal_jax2tf_use_only[0] = device_assignment - mesh = None - if config.use_shardy_partitioner.value: - for sharding in itertools.chain.from_iterable( - [all_in_shardings, lowering.compile_args["out_shardings"]]): - if isinstance(sharding, sharding_impls.NamedSharding): - if mesh is not None and mesh.shape_tuple != sharding.mesh.shape_tuple: - raise ValueError( - f'Mesh for all inputs should be equal. Got one mesh: {mesh} and' - f' another mesh: {sharding.mesh}') - mesh = sharding.mesh - if mesh and isinstance(mesh, mesh_lib.Mesh): - mesh = mesh.abstract_mesh - def _get_exported_vjp(exp_primal: Exported) -> Exported: # Turn the primal jaxpr into a function, in preparation for exporting # the VJP. Note that jaxpr_as_fun produces a function with flat arguments @@ -753,13 +842,16 @@ def _get_exported_vjp(exp_primal: Exported) -> Exported: fun_jax, in_tree=exp_primal.in_tree, in_avals=exp_primal.in_avals, + has_named_shardings=exp_primal._has_named_shardings, in_shardings_hlo=exp_primal.in_shardings_hlo, - out_avals=exp_primal.out_avals, out_shardings_hlo=exp_primal.out_shardings_hlo, + in_named_shardings=exp_primal._in_named_shardings, + out_named_shardings=exp_primal._out_named_shardings, + out_avals=exp_primal.out_avals, device_assignment=device_assignment, apply_jit=True, flat_primal_fun=True, - mesh=mesh) # type: ignore[arg-type] + mesh=cur_mesh) # type: ignore[arg-type] return export(fun_vjp_jax, # type: ignore[arg-type] platforms=exp_primal.platforms, disabled_checks=exp_primal.disabled_safety_checks)(*vjp_in_avals) @@ -770,8 +862,14 @@ def _get_exported_vjp(exp_primal: Exported) -> Exported: out_tree=lowered.out_tree, in_avals=tuple(args_avals_flat), out_avals=tuple(out_avals_flat), - in_shardings_hlo=in_shardings, - out_shardings_hlo=out_shardings, + _has_named_shardings=True, + _in_named_shardings=in_named_shardings, + _out_named_shardings=out_named_shardings, + in_shardings_hlo=tuple(named_to_hlo_sharding(s, aval) + for s, aval in zip(in_named_shardings, args_avals_flat)), + out_shardings_hlo=tuple(named_to_hlo_sharding(s, aval) + for s, aval in zip(out_named_shardings, out_avals_flat)), + nr_devices=nr_devices, platforms=lowering._platforms, # type: ignore ordered_effects=ordered_effects, @@ -783,12 +881,8 @@ def _get_exported_vjp(exp_primal: Exported) -> Exported: calling_convention_version=version, _get_vjp=_get_exported_vjp) -def _module_to_bytecode(module: ir.Module, shardy_enabled: bool) -> bytes: - if xla_extension_version >= 319 and shardy_enabled: - mlir_str = xla_extension.sdy.sdy_round_trip_export_pipeline( - mlir.module_to_bytecode(module)) - else: - mlir_str = mlir.module_to_bytecode(module) +def _module_to_bytecode(module: ir.Module) -> bytes: + mlir_str = mlir.module_to_bytecode(module) # `target_version` is used to manage situations when a StableHLO producer # and a StableHLO consumer were built using different versions of StableHLO. # @@ -807,13 +901,10 @@ def _module_to_bytecode(module: ir.Module, shardy_enabled: bool) -> bytes: # Note that this does not verify any JAX custom calls, which are only # guaranteed 3w of forward compatibility, and only prevents use of new # StableHLO features from failing on older hardware. - if hlo.get_api_version() < 9: - target_version = hlo.get_minimum_version() - else: - target_version = hlo.get_version_from_compatibility_requirement( - hlo.StablehloCompatibilityRequirement.WEEK_4) + target_version = hlo.get_version_from_compatibility_requirement( + hlo.StablehloCompatibilityRequirement.WEEK_4) module_serialized = xla_client._xla.mlir.serialize_portable_artifact( # type: ignore - mlir_str, target_version) + mlir_str, target_version, xb.get_backend().serialize_with_sdy) return module_serialized @@ -828,27 +919,26 @@ def _wrap_main_func( ) -> ir.Module: """Wraps the lowered module with a new "main" handling dimension arguments. - See calling convention documentation https://jax.readthedocs.io/en/latest/export/export.html#module-calling-convention. + See calling convention documentation https://docs.jax.dev/en/latest/export/export.html#module-calling-convention. Args: - module: the HLO module as obtained from lowering. + module: a copy of HLO module as obtained from lowering. args_avals_flat: the avals for all the arguments of the lowered function, - which correspond to the array arguments of the `module`. - args_kwargs_tree: the PyTreeDef corresponding to `(args, kwargs)`, for error + which correspond to the array arguments of the ``module``. + args_kwargs_tree: the PyTreeDef corresponding to ``(args, kwargs)``, for error messages. - has_platform_index_argument: whether the `module` has a first platform + has_platform_index_argument: whether the ``module`` has a first platform index argument module_kept_var_idx: a sorted tuple of integers with the indices of arguments - in `args_avals_flat` that are kept as `module` arguments. + in ``args_avals_flat`` that are kept as ``module`` arguments. serialization_version: the target serialization version Returns the wrapped module, without dimension and token arguments. """ dim_vars = shape_poly.all_dim_vars(args_avals_flat) - context = mlir.make_ir_context() + context = module.context + wrapped_module = module with context, ir.Location.unknown(context): - # Make a copy, do not mutate because it may be cached - wrapped_module = ir.Module.parse(mlir.module_to_bytecode(module)) symbol_table = ir.SymbolTable(wrapped_module.operation) orig_main = symbol_table["main"] orig_main.attributes["sym_visibility"] = ir.StringAttr.get("private") @@ -933,14 +1023,16 @@ def is_token(typ, attrs): host_callbacks=[], module=wrapped_module, context=context, lowering_parameters=mlir.LoweringParameters( global_constant_computation=True, - for_export=True, + for_export=True, hoist_constants_as_args=False, export_ignore_forward_compatibility=config.export_ignore_forward_compatibility.value, )) ctx = mlir.LoweringRuleContext( module_context=module_context, - name_stack=source_info_util.new_name_stack(), primitive=None, + name_stack=source_info_util.new_name_stack(), traceback=None, + primitive=None, avals_in=args_avals_flat, avals_out=None, - tokens_in=mlir.TokenSet(), tokens_out=None) + tokens_in=mlir.TokenSet(), tokens_out=None, + const_lowering={}) # We compute dim_values from the array arguments. new_main_op_array_args = new_main_op.arguments[-nr_array_args:] if shape_poly.all_dim_vars(args_avals_flat): @@ -982,6 +1074,9 @@ def is_token(typ, attrs): orig_main_args) func_dialect.ReturnOp([call.results[idx] for idx in new_main_result_indices]) symbol_table.set_symbol_name(new_main_op, "main") + pipeline = passmanager.PassManager.parse( + 'builtin.module(symbol-dce)') + pipeline.run(wrapped_module.operation) return wrapped_module @@ -1059,6 +1154,8 @@ def _check_lowering(lowering) -> None: # qr on GPU "cusolver_geqrf_ffi", "cusolver_orgqr_ffi", "hipsolver_geqrf_ffi", "hipsolver_orgqr_ffi", + # cholesky on GPU + "cusolver_potrf_ffi", "hipsolver_potrf_ffi", # eigh on GPU "cusolver_syevd_ffi", "hipsolver_syevd_ffi", # svd on GPU @@ -1066,6 +1163,8 @@ def _check_lowering(lowering) -> None: "hipsolver_gesvd_ffi", "hipsolver_gesvdj_ffi", # tridiagonal on GPU "cusolver_sytrd_ffi", + # tridiagonal_solve on GPU + "cusparse_gtsv2_ffi", ] # These are the JAX custom call target names that are guaranteed to be stable. # Their backwards compatibility is tested by back_compat_test.py. @@ -1073,33 +1172,22 @@ def _check_lowering(lowering) -> None: *_CPU_FFI_KERNELS, *_GPU_FFI_KERNELS, "Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape", + "annotate_device_placement", "cu_threefry2x32_ffi", # Triton IR does not guarantee stability. # "__gpu$xla.gpu.triton", - # cholesky on CPU - "lapack_spotrf", "lapack_dpotrf", "lapack_cpotrf", "lapack_zpotrf", # eigh on TPU "Eigh", - # eig on CPU - "lapack_sgeev", "lapack_dgeev", "lapack_cgeev", "lapack_zgeev", - # svd on CPU - "lapack_sgesdd", "lapack_dgesdd", "lapack_cgesdd", "lapack_zgesdd", # qr and svd on TPU "Qr", "ProductOfElementaryHouseholderReflectors", - # triangular_solve on CPU - "blas_strsm", "blas_dtrsm", "blas_ctrsm", "blas_ztrsm", - # schur on CPU - "lapack_sgees", "lapack_dgees", "lapack_cgees", "lapack_zgees", - # tridiagonal on CPU - "lapack_ssytrd", "lapack_dsytrd", "lapack_chetrd", "lapack_zhetrd", - # hessenberg on CPU - "lapack_sgehrd", "lapack_dgehrd", "lapack_cgehrd", "lapack_zgehrd", # lu on TPU "LuDecomposition", # ApproxTopK on TPU "ApproxTopK", "stablehlo.dynamic_approx_top_k", "tf.call_tf_function", # From jax2tf.call_tf(func, call_tf_graph=True) "tpu_custom_call", # Pallas/TPU kernels + "AllocateBuffer", # lax.empty implementation + "mosaic_gpu_v2", # Pallas Mosaic GPU kernels # TODO(burmako): maintain backwards compatibility for these, until they # are upstreamed to StableHLO. # See https://github.com/openxla/stablehlo/issues/8. @@ -1109,7 +1197,7 @@ def _check_lowering(lowering) -> None: "shape_assertion", # Used by shape_poly to evaluate assertions } -check_sharding_pattern = re.compile(r"^({replicated}|{unknown shard_as.*}|\[({}, )*{}\]"")$") +check_sharding_pattern = re.compile(r"^({replicated}|{unknown shard_as.*}|.*\[({}, )*{}\]"")$") def _check_module(mod: ir.Module, *, disabled_checks: Sequence[DisabledSafetyCheck], @@ -1135,7 +1223,7 @@ def _check_module(mod: ir.Module, *, module_uses_non_replicated_sharding = False def check_sharding(op: ir.Operation, loc: ir.Location): try: - sharding = (op.attributes["sdy.sharding"] if shardy_enabled else + sharding = (op.attributes["sharding"] if shardy_enabled else op.attributes["mhlo.sharding"]) except KeyError: pass @@ -1164,6 +1252,8 @@ def check_op(op: ir.Operation): disallowed_custom_call_ops.append(f"{op} at {op.location}") if call_target_name_attr == sharding_attr: check_sharding(op, op.location) + elif op_name == "sdy.sharding_constraint": + check_sharding(op, op.location) def walk_operations(op): check_op(op) @@ -1177,7 +1267,7 @@ def walk_operations(op): disallowed_custom_call_ops_str = "\n".join(disallowed_custom_call_ops) msg = ("Cannot serialize code with custom calls whose targets have no " "compatibility guarantees. " - "See https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees-for-custom-calls. " + "See https://docs.jax.dev/en/latest/export/export.html#compatibility-guarantees-for-custom-calls. " "Examples are:\n" f"{disallowed_custom_call_ops_str}.\n") raise ValueError(msg) @@ -1198,31 +1288,45 @@ def expand_in_shardings(in_shardings: Sequence[LoweringSharding], return tuple(all_in_shardings) -def _hlo_sharding_to_xla_compatible_sharding( +def _hlo_sharding_to_gspmd_sharding( hlo_sharding: HloSharding | None, - mesh: sharding.Mesh) -> sharding.Sharding | None: + device_assignment: Sequence[_jax.Device] + ) -> sharding_impls.GSPMDSharding | None: if hlo_sharding is None: return None - return sharding_impls._gspmd_to_named_sharding_via_mesh( - _hlo_sharding_to_gspmd_sharding(hlo_sharding, tuple(mesh.devices.flat)), # type: ignore[arg-type] - mesh) + return sharding_impls.GSPMDSharding(device_assignment, hlo_sharding) -def _hlo_sharding_to_gspmd_sharding( +def _get_named_sharding( + has_named_shardings: bool, + named_sharding: NamedSharding | None, hlo_sharding: HloSharding | None, - device_assignment: Sequence[jax.Device]) -> sharding.GSPMDSharding | None: - if hlo_sharding is None: - return None - return sharding.GSPMDSharding(device_assignment, hlo_sharding) + aval: core.ShapedArray, + new_mesh: mesh_lib.Mesh | mesh_lib.AbstractMesh | None + ) -> sharding_impls.NamedSharding | None: + if has_named_shardings: + if named_sharding is None: # Unspecified + return None + if new_mesh is None: + return named_sharding + # TODO(necula): for now we require that the mesh uses the same axis_names + if (new_mesh.abstract_mesh.axis_sizes != named_sharding.mesh.axis_sizes or + new_mesh.abstract_mesh.axis_names != named_sharding.mesh.axis_names or + new_mesh.abstract_mesh.axis_types != named_sharding.mesh.axis_types): + raise ValueError(f"NamedSharding new mesh {new_mesh} does not match the mesh used for export {named_sharding.mesh}.") + return named_sharding.update(mesh=new_mesh) -def _hlo_sharding_to_named_sharding( - hlo_sharding: HloSharding | None, - mesh: mesh_lib.Mesh | mesh_lib.AbstractMesh): if hlo_sharding is None: return None - return sharding_impls.create_mesh_pspec_sharding( - mesh, sharding_impls.parse_flatten_op_sharding(hlo_sharding, mesh)[0]) + + mem_kind: str | None = None + if aval.memory_space is not None: + mem_kind = core.mem_space_to_kind(aval.memory_space) + + return sharding_impls.cached_named_sharding( + new_mesh, sharding_impls.parse_flatten_op_sharding(hlo_sharding, new_mesh)[0], # type: ignore + memory_kind=mem_kind) def _get_vjp_fun( @@ -1231,8 +1335,11 @@ def _get_vjp_fun( in_tree: tree_util.PyTreeDef, in_avals: Sequence[core.AbstractValue], out_avals: Sequence[core.AbstractValue], + has_named_shardings: bool, in_shardings_hlo: tuple[HloSharding | None, ...], out_shardings_hlo: tuple[HloSharding | None, ...], + in_named_shardings: tuple[NamedSharding | None, ...], + out_named_shardings: tuple[NamedSharding | None, ...], device_assignment: Sequence[sharding_impls.Device] | None, apply_jit: bool, flat_primal_fun: bool = False, @@ -1254,7 +1361,7 @@ def flattened_primal_fun_jax(*args_flat): args_flat_jax, out_cts_flat_jax = util.split_list(args_and_out_cts_flat_jax, [len(in_avals)]) - _, pullback_jax = jax.vjp(primal_fun if flat_primal_fun else flattened_primal_fun_jax, + _, pullback_jax = api.vjp(primal_fun if flat_primal_fun else flattened_primal_fun_jax, *args_flat_jax) return pullback_jax(out_cts_flat_jax) @@ -1263,19 +1370,26 @@ def flattened_primal_fun_jax(*args_flat): map(lambda a: a.to_tangent_aval(), out_avals))) if apply_jit: - if mesh: + if has_named_shardings or mesh: vjp_in_shardings = tuple( - _hlo_sharding_to_named_sharding(s, mesh) - for s in itertools.chain(in_shardings_hlo, out_shardings_hlo)) - vjp_out_shardings = tuple(_hlo_sharding_to_named_sharding(s, mesh) - for s in in_shardings_hlo) + _get_named_sharding(has_named_shardings, named_sharding, # type: ignore + hlo_sharding, aval, mesh) + for named_sharding, hlo_sharding, aval in zip( + itertools.chain(in_named_shardings, out_named_shardings), + itertools.chain(in_shardings_hlo, out_shardings_hlo), + vjp_in_avals)) + vjp_out_shardings = tuple( + _get_named_sharding(has_named_shardings, named_sharding, + hlo_sharding, aval, mesh) # type: ignore + for named_sharding, hlo_sharding, aval in zip( + in_named_shardings, in_shardings_hlo, in_avals)) else: assert device_assignment is not None vjp_in_shardings = tuple( - _hlo_sharding_to_gspmd_sharding(s, device_assignment) + _hlo_sharding_to_gspmd_sharding(s, device_assignment) # type: ignore for s in itertools.chain(in_shardings_hlo, out_shardings_hlo)) vjp_out_shardings = tuple( - _hlo_sharding_to_gspmd_sharding(s, device_assignment) + _hlo_sharding_to_gspmd_sharding(s, device_assignment) # type: ignore for s in in_shardings_hlo) return pjit.pjit(fun_vjp_jax, in_shardings=vjp_in_shardings, @@ -1286,12 +1400,12 @@ def flattened_primal_fun_jax(*args_flat): ### Calling the exported function -def call(exported: Exported) -> Callable[..., jax.Array]: +def call(exported: Exported) -> Callable[..., typing.Array]: if not isinstance(exported, Exported): raise ValueError( "The exported argument must be an export.Exported. " f"Found {exported}.") - @jax.custom_vjp + @custom_derivatives.custom_vjp def f_flat(*args_flat): return call_exported_p.bind(*args_flat, exported=exported) @@ -1400,13 +1514,25 @@ def pp_arg_dim(dim_idx: int | None) -> str: # it would be ambiguous whether we should continue tracing with a result # of type `f32[c]` or `f32[d]`. shape_constraints.check_statically(synthetic_eval) - exported_dim_values = [synthetic_eval.evaluate(solution[var]) + exported_dim_values = [synthetic_eval.evaluate(solution[var]) # type: ignore[arg-type] for var in exported_dim_vars] - out_avals = tuple( - core.ShapedArray(core.evaluate_shape(out_aval.shape, exported_dim_vars, - *exported_dim_values), - dtype=out_aval.dtype, weak_type=out_aval.weak_type) - for out_aval in exported.out_avals) + + def make_aval(out_aval_idx: int): + out_aval = exported.out_avals[out_aval_idx] + if exported._has_named_shardings: + sharding = exported._out_named_shardings[out_aval_idx] + else: + sharding = None + aval = core.ShapedArray( + core.evaluate_shape(out_aval.shape, exported_dim_vars, + *exported_dim_values), + dtype=out_aval.dtype, weak_type=out_aval.weak_type, + # memory_space from out_aval because sharding may be None + memory_space=out_aval.memory_space) + return core.update_aval_with_sharding(aval, sharding) + + out_avals = tuple(make_aval(out_aval_idx) + for out_aval_idx in range(len(exported.out_avals))) return out_avals, set(exported.ordered_effects + exported.unordered_effects) @@ -1417,31 +1543,62 @@ def _call_exported_impl(*args, exported: Exported): call_exported_p.def_impl(_call_exported_impl) + +def get_mesh_from_symbol(symtab: ir.SymbolTable) -> mesh_lib.AbstractMesh: + if "mesh" not in symtab: + return mesh_lib.empty_abstract_mesh + mesh_attr = sdy.MeshAttr(symtab["mesh"].mesh) + axes = [sdy.MeshAxisAttr(a) for a in mesh_attr.axes] + if not axes: + return mesh_lib.empty_abstract_mesh + axes_sizes = tuple(a.size for a in axes) + axes_names = tuple(a.name for a in axes) + # TODO(necula): Shardy meshes do not have axis_types :-( + return mesh_lib.AbstractMesh(axes_sizes, axes_names) + +def has_sdy_meshes_in_frontend_attributes(submodule: ir.Module) -> bool: + if "mhlo.frontend_attributes" not in submodule.operation.attributes: + return False + frontend_attributes = submodule.operation.attributes[ + "mhlo.frontend_attributes" + ] + return "xla.sdy.meshes" in frontend_attributes + +def has_sdy_mesh(symtab: ir.SymbolTable, submodule: ir.Module) -> bool: + for mesh_name in ("mesh", "empty_mesh", "maximal_mesh_0"): + if mesh_name in symtab: + return isinstance(symtab[mesh_name], sdy.MeshOp) + return has_sdy_meshes_in_frontend_attributes(submodule) + def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args, exported: Exported): if exported.uses_global_constants: ctx.module_context.shape_poly_state.uses_dim_vars = True submodule = ir.Module.parse(exported.mlir_module()) - shardy_enabled = False - if xla_extension_version >= 319: - shardy_enabled = xla_extension.sdy.lowered_with_shardy( - mlir.module_to_bytecode(submodule)) + symtab = ir.SymbolTable(submodule.operation) + shardy_enabled = has_sdy_mesh(symtab, submodule) if shardy_enabled: - submodule = ir.Module.parse(xla_extension.sdy.sdy_round_trip_import_shardings( - mlir.module_to_bytecode(submodule))) + if not config.use_shardy_partitioner.value: + raise ValueError( + "The function was exported with shardy enabled but you are calling " + "it with Shardy disabled. Please enable Shardy using " + "`--jax_use_shardy_partitioner=True`.") + # TODO(b/422690222): remove this pass once we don't need to support 6m + # old exported modules. + if has_sdy_meshes_in_frontend_attributes(submodule): + with submodule.context: + pipeline = passmanager.PassManager.parse( + 'builtin.module(xla-sdy-round-trip-import-shardy-attrs)') + pipeline.run(submodule.operation) with submodule.context: pipeline = passmanager.PassManager.parse( 'builtin.module(sdy-lift-inlined-meshes)') pipeline.run(submodule.operation) - - # TODO(bartchr): delete this once I have JAX export support multiple meshes. mesh = None if shardy_enabled: - sdy_mesh_axes = xla_extension.sdy.get_mesh(mlir.module_to_bytecode(submodule)) - mesh = mesh_lib.AbstractMesh( - *list(zip(*sdy_mesh_axes))[::-1]) if sdy_mesh_axes else None + mesh = get_mesh_from_symbol(symtab) axis_context = ctx.module_context.axis_context if isinstance(axis_context, sharding_impls.ShardingContext): @@ -1452,40 +1609,46 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args, num_devices = axis_context.axis_env.nreps else: raise NotImplementedError(type(axis_context)) - if num_devices != exported.nr_devices: - # In some special cases we allow running with a different number of devices - # than the function was exported for. - err_msg = "" - if exported.nr_devices != 1: - err_msg = "the function was exported for more than 1 device." - elif (_check_module(submodule, disabled_checks=(), shardy_enabled=shardy_enabled) - or any(s is not None and not s.is_replicated() - for s in exported.in_shardings_hlo + exported.out_shardings_hlo)): - err_msg = "the function contains non-replicated sharding annotations." - if err_msg: - raise ValueError( + if num_devices != exported.nr_devices and exported.nr_devices != 1: + raise ValueError( f"Function {exported.fun_name} was exported for " f"{exported.nr_devices} devices and is called in a context with " - f"{num_devices} devices. This is disallowed because: {err_msg}" - ) + f"{num_devices} devices, which is not allowed." + ) # Apply in_shardings - if shardy_enabled: + if exported._has_named_shardings: args = tuple( wrap_with_sharding( ctx, x, x_aval, - _hlo_sharding_to_named_sharding(x_sharding, mesh)) # type: ignore[arg-type] - for x, x_aval, x_sharding in zip(args, ctx.avals_in, exported.in_shardings_hlo)) + _get_named_sharding(exported._has_named_shardings, + named_sharding, None, x_aval, None), + use_shardy=True) # type: ignore[arg-type] + for x, named_sharding, x_aval in zip( + args, exported._in_named_shardings, exported.in_avals)) + elif mesh: + # A mesh only exists if Shardy is enabled, or we saved named shardings. + args = tuple( + wrap_with_sharding( + ctx, x, x_aval, + _get_named_sharding(False, None, hlo_sharding, x_aval, mesh), + use_shardy=True) # type: ignore[arg-type] + for x, hlo_sharding, x_aval in zip( + args, exported.in_shardings_hlo, exported.in_avals)) else: + # Since there is no mesh - either due to shardy being disabled or the loaded + # function being lowered for GSPMD (so no shardy mesh) - need to create a + # GSPMD sharding from the HLO sharding (can't use shardy lowering). args = tuple( - wrap_with_sharding(ctx, x, x_aval, x_sharding) - for x, x_aval, x_sharding in zip(args, ctx.avals_in, exported.in_shardings_hlo)) + wrap_with_sharding(ctx, x, x_aval, x_sharding, use_shardy=False) + for x, x_aval, x_sharding in zip( + args, ctx.avals_in, exported.in_shardings_hlo)) - symtab = ir.SymbolTable(submodule.operation) # The called function may have been exported with polymorphic shapes and called # now with more refined shapes. We insert hlo.ConvertOp to ensure the module # is valid. - def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.AbstractValue) -> ir.Value: + def convert_shape(x: ir.Value, x_aval: core.AbstractValue, + new_aval: core.AbstractValue) -> ir.Value: new_ir_type = mlir.aval_to_ir_type(new_aval) if x.type != new_ir_type: return hlo.convert(mlir.aval_to_ir_type(new_aval), x) @@ -1567,15 +1730,30 @@ def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.Abstra for out, out_aval, refined_out_aval in zip(call.results[len(ordered_effects):], exported.out_avals, ctx.avals_out)) # Apply out_shardings - if shardy_enabled: + if exported._has_named_shardings: + results = tuple( + wrap_with_sharding( + ctx, x, x_aval, + _get_named_sharding(True, x_sharding, None, x_aval, None), + use_shardy=True) # type: ignore[arg-type] + for x, x_aval, x_sharding in \ + zip(results, ctx.avals_out, exported._out_named_shardings)) + elif mesh: results = tuple( wrap_with_sharding( - ctx, x, x_aval, _hlo_sharding_to_named_sharding(x_sharding, mesh)) # type: ignore[arg-type] - for x, x_aval, x_sharding in zip(results, ctx.avals_out, exported.out_shardings_hlo)) + ctx, x, x_aval, + _get_named_sharding(False, None, x_sharding, x_aval, mesh), + use_shardy=True) # type: ignore[arg-type] + for x, x_aval, x_sharding in \ + zip(results, ctx.avals_out, exported.out_shardings_hlo)) else: + # Since there is no mesh - either due to shardy being disabled or the loaded + # function being lowered for GSPMD (so no shardy mesh) - need to create a + # GSPMD sharding from the HLO sharding (can't use shardy lowering). results = tuple( - wrap_with_sharding(ctx, x, x_aval, x_sharding) - for x, x_aval, x_sharding in zip(results, ctx.avals_out, exported.out_shardings_hlo)) + wrap_with_sharding(ctx, x, x_aval, x_sharding, use_shardy=False) + for x, x_aval, x_sharding in \ + zip(results, ctx.avals_out, exported.out_shardings_hlo)) return results mlir.register_lowering(call_exported_p, _call_exported_lowering) @@ -1585,12 +1763,14 @@ def wrap_with_sharding( ctx: mlir.LoweringRuleContext, x: ir.Value, x_aval: core.AbstractValue, - x_sharding: sharding_impls.NamedSharding | HloSharding | None, + x_sharding: sharding_impls.NamedSharding | sharding_impls.GSPMDSharding | HloSharding | None, + use_shardy: bool, ) -> ir.Value: if x_sharding is None: return x - if config.use_shardy_partitioner.value: + if use_shardy: x_sharding = x_sharding._to_sdy_sharding(x_aval.ndim) # type: ignore else: x_sharding = x_sharding.to_proto() # type: ignore - return mlir.wrap_with_sharding_op(ctx, x, x_aval, x_sharding) # type: ignore[arg-type] + return mlir.wrap_with_sharding_op(ctx, x, x_aval, x_sharding, # type: ignore[arg-type] + allow_shardy_lowering=use_shardy) diff --git a/jax/_src/export/serialization.fbs b/jax/_src/export/serialization.fbs index 7d3e342f1879..844d879c5e54 100644 --- a/jax/_src/export/serialization.fbs +++ b/jax/_src/export/serialization.fbs @@ -45,7 +45,7 @@ enum AbstractValueKind: byte { } enum DType: byte { - // Last used id: 22 + // Last used id: 29 bool = 0, i8 = 1, i16 = 2, @@ -76,22 +76,65 @@ enum DType: byte { f8_e5m2fnuz = 21, f8_e8m0fnu = 25, f4_e2m1fn = 26, + + key_fry = 27, + key_rbg = 28, + key_unsafe_rbg = 29, +} + +enum MemorySpace: byte { + Missing = 0, // default if missing (pre 11/25/2025) + Device = 1, + Host = 2, + Any = 3, +} + +enum AxisType: byte { + Missing = 0, // default if missing (pre 1/15/2026) + Auto = 1, + Explicit = 2, + Manual = 3, +} + +table AbstractMesh { + axis_sizes: [uint32]; + axis_names: [string]; + axis_types: [AxisType]; +} + +table PartitionSpecOneAxis { + axes: [string]; // [] = None, ['x'] = 'x', ['x', 'y'] = ('x', 'y') +} + +table PartitionSpec { + partitions: [PartitionSpecOneAxis]; + reduced: [string]; + unreduced: [string]; +} + +table NamedSharding { + mesh: AbstractMesh; + spec: PartitionSpec; + memory_kind: string; } table AbstractValue { kind: AbstractValueKind; shape: [string]; // Support shape polymorphism dtype: DType; + memory_space: MemorySpace; } enum ShardingKind: byte { - unspecified, - hlo_sharding, + unspecified = 0, + hlo_sharding = 1, + named_sharding = 2, // Added 1/15/2026 } table Sharding { kind: ShardingKind; - hlo_sharding_proto: [byte]; + hlo_sharding_proto: [byte]; // if kind == hlo_sharding + named_sharding: NamedSharding; // if kind == named_sharding; added 1/15/2026 } table Effect { @@ -115,6 +158,7 @@ table Exported { /// Note that this field has different semantics and purpose from /// `mlir_module_serialization_version`, which encodes /// the calling convention of the `mlir_module_serialized`. + /// See comments in serialization.py for more details. serialization_version: uint16; function_name: string; @@ -122,7 +166,7 @@ table Exported { in_avals: [AbstractValue]; out_tree: PyTreeDef; out_avals: [AbstractValue]; - nr_devices: short; + nr_devices_short: short; // Deprecated as of 11/25/2025 in_shardings: [Sharding]; out_shardings: [Sharding]; @@ -138,6 +182,7 @@ table Exported { uses_global_constants: bool; vjp: Exported; + nr_devices: uint32 = 0; // Added 11/25/2025 } root_type Exported; diff --git a/jax/_src/export/serialization.py b/jax/_src/export/serialization.py index ac97c11d1177..7f490532661d 100644 --- a/jax/_src/export/serialization.py +++ b/jax/_src/export/serialization.py @@ -18,8 +18,9 @@ import types from collections.abc import Callable, Sequence +import itertools from functools import partial -from typing import TypeVar +from typing import Any, TypeVar try: import flatbuffers @@ -31,11 +32,14 @@ from jax._src import core from jax._src import dtypes from jax._src import effects -from jax._src import tree_util from jax._src.export import serialization_generated as ser_flatbuf from jax._src.export import _export from jax._src.export import shape_poly from jax._src.lib import xla_client +from jax._src import mesh +from jax._src import named_sharding +from jax._src import partition_spec +from jax._src import tree_util import numpy as np @@ -48,7 +52,15 @@ # Version 2, Dec 16th, 2023, adds the f0 dtype. # Version 3, October 16th, 2024, adds serialization for namedtuple and custom types # This version is backwards compatible with Version 2. -_SERIALIZATION_VERSION = 2 +# Version 4, April 7th, 2025, adds serialization for PRNGs key types. +# This version is backwards compatible with Version 2 and 3. +# Version 5, November 23rd, 2025, adds serialization for aval memory_space, +# upgrade num_devices to a 32 bit value. +# This version is backwards compatible with Version 2 to 4. +# Version 6, January 15th, 2026, adds serialization for sharding as +# NamedSharding, including the abstract mesh, and the partition spec. +# This version is backwards compatible with Version 2 to 5. +_SERIALIZATION_VERSION = 6 def serialize(exp: _export.Exported, vjp_order: int = 0) -> bytearray: """Serializes an Exported. @@ -81,11 +93,16 @@ def _serialize_exported( in_avals = _serialize_array(builder, _serialize_aval, exp.in_avals) out_tree = _serialize_pytreedef(builder, exp.out_tree) out_avals = _serialize_array(builder, _serialize_aval, exp.out_avals) + # TODO(necula): For 30 days after 1/15/2026 we must serialize the HLO + # shardings. After that we can error out when serialized Exported with not + # _has_named_shardings. in_shardings = _serialize_array( - builder, _serialize_sharding, exp.in_shardings_hlo + builder, partial(_serialize_sharding, has_named_sharding=exp._has_named_shardings), + zip(exp._in_named_shardings, exp.in_shardings_hlo) # type: ignore ) out_shardings = _serialize_array( - builder, _serialize_sharding, exp.out_shardings_hlo + builder, partial(_serialize_sharding, has_named_sharding=exp._has_named_shardings), + zip(exp._out_named_shardings, exp.out_shardings_hlo) # type: ignore ) ordered_effects = _serialize_array( builder, _serialize_effect, exp.ordered_effects @@ -115,13 +132,19 @@ def _serialize_exported( vjp = _serialize_exported(builder, exp.vjp(), vjp_order - 1) ser_flatbuf.ExportedStart(builder) - ser_flatbuf.ExportedAddSerializationVersion(builder, _SERIALIZATION_VERSION) + # TODO(necula): we cannot really store the actual serialization_version + # in the flatbuffer because prior to 11/25/2025 deserializers checked + # if the version is 2 or 3. I have now removed that check, but for the + # sake of old deserializers we can only store version 3. Starting + # on January 2026 we can store the actual version. + ser_flatbuf.ExportedAddSerializationVersion(builder, 3) ser_flatbuf.ExportedAddFunctionName(builder, fun_name) ser_flatbuf.ExportedAddInTree(builder, in_tree) ser_flatbuf.ExportedAddInAvals(builder, in_avals) ser_flatbuf.ExportedAddOutTree(builder, out_tree) ser_flatbuf.ExportedAddOutAvals(builder, out_avals) ser_flatbuf.ExportedAddNrDevices(builder, exp.nr_devices) + ser_flatbuf.ExportedAddNrDevicesShort(builder, exp.nr_devices) # For forward compatibility, can remove after January 2026 ser_flatbuf.ExportedAddInShardings(builder, in_shardings) ser_flatbuf.ExportedAddOutShardings(builder, out_shardings) ser_flatbuf.ExportedAddPlatforms(builder, platforms) @@ -155,33 +178,49 @@ def _serialize_array( def _deserialize_exported(exp: ser_flatbuf.Exported) -> _export.Exported: serialization_version = exp.SerializationVersion() - if serialization_version not in [2, 3]: - raise NotImplementedError( - f"deserialize unsupported version {serialization_version}" - ) fun_name = exp.FunctionName().decode("utf-8") in_tree = tree_util.tree_structure( _deserialize_pytreedef_to_pytree(exp.InTree()) ) - scope = shape_poly.SymbolicScope(()) # TODO: serialize the constraints - deser_aval = partial(_deserialize_aval, scope=scope) - in_avals = _deserialize_tuple( - exp.InAvalsLength, exp.InAvals, deser_aval - ) + scope = shape_poly.SymbolicScope(()) # TODO(necula): serialize the constraints out_tree = tree_util.tree_structure( _deserialize_pytreedef_to_pytree(exp.OutTree()) ) - out_avals = _deserialize_tuple( - exp.OutAvalsLength, exp.OutAvals, deser_aval - ) - nr_devices = exp.NrDevices() + + # TODO(necula): remove the fallback to NrDevicesShort and mark + # the field "deprecated" once we abandon the old + # serialization format (6 months after 11/24/2025). + nr_devices = exp.NrDevices() or exp.NrDevicesShort() in_shardings = _deserialize_tuple( exp.InShardingsLength, exp.InShardings, _deserialize_sharding ) out_shardings = _deserialize_tuple( exp.OutShardingsLength, exp.OutShardings, _deserialize_sharding ) + # has_named_sharding will be True for all exports after 1/15/2026 + has_named_shardings = not any(isinstance(s, _export.HloSharding) + for s in itertools.chain(in_shardings, out_shardings)) + if has_named_shardings: + in_avals = tuple( + _deserialize_aval(exp.InAvals(i), scope=scope, sharding=in_shardings[i]) # type: ignore + for i in range(exp.InAvalsLength()) + ) + out_avals = tuple( + _deserialize_aval(exp.OutAvals(i), scope=scope, sharding=out_shardings[i]) # type: ignore + for i in range(exp.OutAvalsLength()) + ) + in_shardings_hlo = tuple(_export.named_to_hlo_sharding(s, aval) # type: ignore + for s, aval in zip(in_shardings, in_avals)) + out_shardings_hlo = tuple(_export.named_to_hlo_sharding(s, aval) # type: ignore + for s, aval in zip(out_shardings, out_avals)) + else: + in_avals = _deserialize_tuple(exp.InAvalsLength, exp.InAvals, + partial(_deserialize_aval, scope=scope, sharding=None)) + out_avals = _deserialize_tuple(exp.OutAvalsLength, exp.OutAvals, + partial(_deserialize_aval, scope=scope, sharding=None)) + in_shardings_hlo, in_shardings = in_shardings, (None,) * len(in_shardings) # type: ignore + out_shardings_hlo, out_shardings = out_shardings, (None,) * len(out_shardings) # type: ignore platforms = _deserialize_tuple( exp.PlatformsLength, exp.Platforms, @@ -215,8 +254,11 @@ def _deserialize_exported(exp: ser_flatbuf.Exported) -> _export.Exported: out_tree=out_tree, out_avals=out_avals, nr_devices=nr_devices, - in_shardings_hlo=in_shardings, - out_shardings_hlo=out_shardings, + in_shardings_hlo=in_shardings_hlo, + out_shardings_hlo=out_shardings_hlo, + _has_named_shardings=has_named_shardings, + _in_named_shardings=in_shardings, # type: ignore + _out_named_shardings=out_shardings, # type: ignore platforms=platforms, ordered_effects=ordered_effects, unordered_effects=unordered_effects, @@ -264,8 +306,14 @@ def _serialize_pytreedef( elif node_type is dict: kind = ser_flatbuf.PyTreeDefKind.dict assert len(node_data[1]) == len(children) + def serialize_key(builder, k): + if not isinstance(k, str): + raise TypeError( + "Serialization is supported only for dictionaries with string keys." + f" Found key {k} of type {type(k)}.") + return builder.CreateString(k) children_names_vector_offset = _serialize_array( - builder, lambda b, s: b.CreateString(s), node_data[1] + builder, serialize_key, node_data[1] ) elif node_type in _export.serialization_registry: kind = ser_flatbuf.PyTreeDefKind.custom @@ -357,21 +405,152 @@ def _deserialize_pytreedef_to_pytree(p: ser_flatbuf.PyTreeDef): dtypes._float8_e4m3fnuz_dtype: ser_flatbuf.DType.f8_e4m3fnuz, dtypes._float8_e5m2_dtype: ser_flatbuf.DType.f8_e5m2, dtypes._float8_e5m2fnuz_dtype: ser_flatbuf.DType.f8_e5m2fnuz, + dtypes._float8_e3m4_dtype: ser_flatbuf.DType.f8_e3m4, + dtypes._float8_e4m3_dtype: ser_flatbuf.DType.f8_e4m3, + dtypes._float8_e8m0fnu_dtype: ser_flatbuf.DType.f8_e8m0fnu, + dtypes._float4_e2m1fn_dtype: ser_flatbuf.DType.f4_e2m1fn, } -if dtypes._float8_e3m4_dtype is not None: - _dtype_to_dtype_kind[dtypes._float8_e3m4_dtype] = ser_flatbuf.DType.f8_e3m4 -if dtypes._float8_e4m3_dtype is not None: - _dtype_to_dtype_kind[dtypes._float8_e4m3_dtype] = ser_flatbuf.DType.f8_e4m3 -if dtypes._float8_e8m0fnu_dtype is not None: - _dtype_to_dtype_kind[dtypes._float8_e8m0fnu_dtype] = ser_flatbuf.DType.f8_e8m0fnu -if dtypes._float4_e2m1fn_dtype is not None: - _dtype_to_dtype_kind[dtypes._float4_e2m1fn_dtype] = ser_flatbuf.DType.f4_e2m1fn _dtype_kind_to_dtype = { kind: dtype for dtype, kind in _dtype_to_dtype_kind.items() } +def register_dtype_kind(dtype: Any, kind: int): + _dtype_to_dtype_kind[dtype] = kind + _dtype_kind_to_dtype[kind] = dtype + + +_memory_space_to_enum = { + core.MemorySpace.Device: ser_flatbuf.MemorySpace.Device, + core.MemorySpace.Host: ser_flatbuf.MemorySpace.Host, + core.MemorySpace.Any: ser_flatbuf.MemorySpace.Any, +} +_memory_space_from_enum = {v: k for k, v in _memory_space_to_enum.items()} + + +_axis_type_to_enum = { + core.AxisType.Auto: ser_flatbuf.AxisType.Auto, + core.AxisType.Explicit: ser_flatbuf.AxisType.Explicit, + core.AxisType.Manual: ser_flatbuf.AxisType.Manual, +} +_axis_type_from_enum = {v: k for k, v in _axis_type_to_enum.items()} + + +def _serialize_abstract_mesh(builder: flatbuffers.Builder, + mesh: mesh.AbstractMesh) -> int: + ser_flatbuf.AbstractMeshStartAxisSizesVector(builder, len(mesh.axis_sizes)) + for axis_size in reversed(mesh.axis_sizes): + builder.PrependUint32(axis_size) + axis_sizes = builder.EndVector() + + axis_names = _serialize_array(builder, + lambda builder, an: builder.CreateString(an), + mesh.axis_names) + + assert mesh.axis_types is not None, mesh + ser_flatbuf.AbstractMeshStartAxisTypesVector(builder, len(mesh.axis_types)) + for axis_type in reversed(mesh.axis_types): + builder.PrependByte(_axis_type_to_enum[axis_type]) + axis_types = builder.EndVector() + + ser_flatbuf.AbstractMeshStart(builder) + ser_flatbuf.AbstractMeshAddAxisSizes(builder, axis_sizes) + ser_flatbuf.AbstractMeshAddAxisNames(builder, axis_names) + ser_flatbuf.AbstractMeshAddAxisTypes(builder, axis_types) + return ser_flatbuf.AbstractMeshEnd(builder) + + +def _deserialize_abstract_mesh(ser_mesh: ser_flatbuf.AbstractMesh +) -> mesh.AbstractMesh: + axis_sizes = tuple(ser_mesh.AxisSizes(i) + for i in range(ser_mesh.AxisSizesLength())) + axis_names = tuple(ser_mesh.AxisNames(i).decode("utf-8") + for i in range(ser_mesh.AxisNamesLength())) + axis_types = tuple(_axis_type_from_enum[ser_mesh.AxisTypes(i)] + for i in range(ser_mesh.AxisTypesLength())) + return mesh.AbstractMesh(axis_sizes, axis_names, axis_types) + + +def _serialize_partition_spec_one_axis(builder: flatbuffers.Builder, + spec: str | tuple[str, ...] | None) -> int: + if spec is None: + axes = () + else: + axes = (spec,) if isinstance(spec, str) else spec # type: ignore + + axes_offset = _serialize_array(builder, + lambda builder, ps: builder.CreateString(ps), + axes) + ser_flatbuf.PartitionSpecOneAxisStart(builder) + ser_flatbuf.PartitionSpecOneAxisAddAxes(builder, axes_offset) + return ser_flatbuf.PartitionSpecOneAxisEnd(builder) + + +def _deserialize_partition_spec_one_axis( + spec: ser_flatbuf.PartitionSpecOneAxis) -> str | tuple[str, ...] | None: + axes = tuple(spec.Axes(i).decode("utf-8") for i in range(spec.AxesLength())) + if not axes: + return None + else: + return axes[0] if len(axes) == 1 else axes + + +def _serialize_partition_spec(builder: flatbuffers.Builder, + spec: partition_spec.PartitionSpec) -> int: + partitions = _serialize_array(builder, _serialize_partition_spec_one_axis, + spec._partitions) + reduced = _serialize_array(builder, # type: ignore + lambda builder, ps: builder.CreateString(ps), + spec.reduced) + unreduced = _serialize_array(builder, # type: ignore + lambda builder, ps: builder.CreateString(ps), + spec.unreduced) + + ser_flatbuf.PartitionSpecStart(builder) + ser_flatbuf.PartitionSpecAddPartitions(builder, partitions) + ser_flatbuf.PartitionSpecAddReduced(builder, reduced) + ser_flatbuf.PartitionSpecAddUnreduced(builder, unreduced) + return ser_flatbuf.PartitionSpecEnd(builder) + + +def _deserialize_partition_spec(spec: ser_flatbuf.PartitionSpec + ) -> partition_spec.PartitionSpec: + partitions = tuple(_deserialize_partition_spec_one_axis(spec.Partitions(i)) + for i in range(spec.PartitionsLength())) + reduced = frozenset(spec.Reduced(i).decode("utf-8") + for i in range(spec.ReducedLength())) + unreduced = frozenset(spec.Unreduced(i).decode("utf-8") + for i in range(spec.UnreducedLength())) + return partition_spec.PartitionSpec(*partitions, + reduced=reduced, + unreduced=unreduced) + + +def _serialize_named_sharding( + builder: flatbuffers.Builder, sharding: named_sharding.NamedSharding +) -> int: + mesh_offset = _serialize_abstract_mesh(builder, sharding.mesh.abstract_mesh) + spec_offset = _serialize_partition_spec(builder, sharding.spec) + memory_kind = builder.CreateString(sharding.memory_kind) if sharding.memory_kind is not None else 0 + + ser_flatbuf.NamedShardingStart(builder) + ser_flatbuf.NamedShardingAddMesh(builder, mesh_offset) + ser_flatbuf.NamedShardingAddSpec(builder, spec_offset) + if memory_kind != 0: + ser_flatbuf.NamedShardingAddMemoryKind(builder, memory_kind) + return ser_flatbuf.NamedShardingEnd(builder) + + +def _deserialize_named_sharding( + s: ser_flatbuf.NamedSharding +) -> named_sharding.NamedSharding: + amesh = _deserialize_abstract_mesh(s.Mesh()) + spec = _deserialize_partition_spec(s.Spec()) + memory_kind = s.MemoryKind().decode("utf-8") if s.MemoryKind() is not None else None + return named_sharding.NamedSharding(amesh, spec, memory_kind=memory_kind) + + def _serialize_aval( builder: flatbuffers.Builder, aval: core.ShapedArray ) -> int: @@ -386,56 +565,75 @@ def _serialize_aval( ser_flatbuf.AbstractValueAddKind(builder, aval_kind) ser_flatbuf.AbstractValueAddShape(builder, shape_vector_offset) ser_flatbuf.AbstractValueAddDtype(builder, _dtype_to_dtype_kind[aval.dtype]) + ser_flatbuf.AbstractValueAddMemorySpace(builder, _memory_space_to_enum[aval.memory_space]) return ser_flatbuf.AbstractValueEnd(builder) -def _deserialize_aval(aval: ser_flatbuf.AbstractValue, - scope) -> core.ShapedArray: - aval_kind = aval.Kind() - if aval_kind == ser_flatbuf.AbstractValueKind.shapedArray: - dtype = _dtype_kind_to_dtype[aval.Dtype()] - shape = shape_poly.symbolic_shape( - ",".join( +def _deserialize_aval(aval: ser_flatbuf.AbstractValue, *, + scope: shape_poly.SymbolicScope, + sharding: named_sharding.NamedSharding | None, + ) -> core.ShapedArray: + dtype = _dtype_kind_to_dtype[aval.Dtype()] + shape = shape_poly.symbolic_shape( + ",".join( aval.Shape(i).decode("utf-8") for i in range(aval.ShapeLength()) ), scope=scope ) - return core.ShapedArray(shape, dtype) + if (ser_mem_space := aval.MemorySpace()): + mem_space = _memory_space_from_enum[ser_mem_space] else: - assert False, aval_kind + mem_space = core.MemorySpace.Device + + aval = core.ShapedArray(shape, dtype, memory_space=mem_space) + return core.update_aval_with_sharding(aval, sharding) def _serialize_sharding( - builder: flatbuffers.Builder, s: _export.HloSharding | None + builder: flatbuffers.Builder, s: tuple[_export.NamedSharding | None, _export.HloSharding | None], + has_named_sharding: bool, ) -> int: - proto = None - if s is None: + named_s, hlo_s = s + is_unspecified = (named_s is None) if has_named_sharding else (hlo_s is None) + named_sharding = None + hlo_sharding = None + if is_unspecified: kind = ser_flatbuf.ShardingKind.unspecified else: + # TODO(necula): We must use the hlo_sharding kind for at least 30 days after + # 1/15/2026 because old deserializers check the kind and abort if they + # do not recognize it. After that date, we can stop using the kind. kind = ser_flatbuf.ShardingKind.hlo_sharding - proto_bytes = s.to_proto().SerializeToString() - proto = builder.CreateByteVector(proto_bytes) + + if has_named_sharding and named_s is not None: + named_sharding = _serialize_named_sharding(builder, named_s) + + # TODO(necula): We must serialize the hlo_sharding for at least 30 days after + # 1/15/2026 because old deserializers can only deserialize this sharding. + if hlo_s is not None: + hlo_sharding = builder.CreateByteVector(hlo_s.to_proto().SerializeToString()) ser_flatbuf.ShardingStart(builder) ser_flatbuf.ShardingAddKind(builder, kind) - if proto is not None: - ser_flatbuf.ShardingAddHloShardingProto(builder, proto) + if hlo_sharding is not None: + ser_flatbuf.ShardingAddHloShardingProto(builder, hlo_sharding) + if named_sharding is not None: + ser_flatbuf.ShardingAddNamedSharding(builder, named_sharding) return ser_flatbuf.ShardingEnd(builder) -def _deserialize_sharding(s: ser_flatbuf.Sharding) -> _export.HloSharding | None: - kind = s.Kind() - if kind == ser_flatbuf.ShardingKind.unspecified: - return None +def _deserialize_sharding(s: ser_flatbuf.Sharding) -> _export.HloSharding | named_sharding.NamedSharding | None: + if (named_sharding_off := s.NamedSharding()) is not None: + # After 1/15/26 all exports will have named shardings (or None) + return _deserialize_named_sharding(named_sharding_off) - if kind == ser_flatbuf.ShardingKind.hlo_sharding: - proto_str = s.HloShardingProtoAsNumpy().tobytes() + # TODO(necula): We must keep reading the HloSharding for 6 months after 1/15/2026. + if not s.HloShardingProtoIsNone(): proto = xla_client.OpSharding() - proto.ParseFromString(proto_str) - + proto.ParseFromString(s.HloShardingProtoAsNumpy().tobytes()) return xla_client.HloSharding.from_proto(proto) - assert False, kind + return None # Unspecified sharding def _serialize_effect(builder: flatbuffers.Builder, eff: core.Effect) -> int: diff --git a/jax/_src/export/serialization_generated.py b/jax/_src/export/serialization_generated.py index b1fc13333777..c021fa6f639f 100644 --- a/jax/_src/export/serialization_generated.py +++ b/jax/_src/export/serialization_generated.py @@ -53,21 +53,39 @@ class DType(object): bf16 = 14 i4 = 15 ui4 = 16 - f8_e3m4 = 24 - f8_e4m3 = 23 f8_e4m3b11fnuz = 17 f8_e4m3fn = 18 f8_e4m3fnuz = 19 f8_e5m2 = 20 f8_e5m2fnuz = 21 f0 = 22 + f8_e4m3 = 23 + f8_e3m4 = 24 f8_e8m0fnu = 25 f4_e2m1fn = 26 + key_fry = 27 + key_rbg = 28 + key_unsafe_rbg = 29 + + +class MemorySpace(object): + Missing = 0 + Device = 1 + Host = 2 + Any = 3 + + +class AxisType(object): + Missing = 0 + Auto = 1 + Explicit = 2 + Manual = 3 class ShardingKind(object): unspecified = 0 hlo_sharding = 1 + named_sharding = 2 class DisabledSafetyCheckKind(object): @@ -211,6 +229,346 @@ def PyTreeDefEnd(builder): +class AbstractMesh(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = AbstractMesh() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsAbstractMesh(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + # AbstractMesh + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # AbstractMesh + def AxisSizes(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get(flatbuffers.number_types.Uint32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # AbstractMesh + def AxisSizesAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint32Flags, o) + return 0 + + # AbstractMesh + def AxisSizesLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # AbstractMesh + def AxisSizesIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + return o == 0 + + # AbstractMesh + def AxisNames(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return "" + + # AbstractMesh + def AxisNamesLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # AbstractMesh + def AxisNamesIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + return o == 0 + + # AbstractMesh + def AxisTypes(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get(flatbuffers.number_types.Int8Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 1)) + return 0 + + # AbstractMesh + def AxisTypesAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int8Flags, o) + return 0 + + # AbstractMesh + def AxisTypesLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # AbstractMesh + def AxisTypesIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + return o == 0 + +def AbstractMeshStart(builder): + builder.StartObject(3) + +def AbstractMeshAddAxisSizes(builder, axisSizes): + builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(axisSizes), 0) + +def AbstractMeshStartAxisSizesVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + +def AbstractMeshAddAxisNames(builder, axisNames): + builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(axisNames), 0) + +def AbstractMeshStartAxisNamesVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + +def AbstractMeshAddAxisTypes(builder, axisTypes): + builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(axisTypes), 0) + +def AbstractMeshStartAxisTypesVector(builder, numElems): + return builder.StartVector(1, numElems, 1) + +def AbstractMeshEnd(builder): + return builder.EndObject() + + + +class PartitionSpecOneAxis(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = PartitionSpecOneAxis() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsPartitionSpecOneAxis(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + # PartitionSpecOneAxis + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # PartitionSpecOneAxis + def Axes(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return "" + + # PartitionSpecOneAxis + def AxesLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # PartitionSpecOneAxis + def AxesIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + return o == 0 + +def PartitionSpecOneAxisStart(builder): + builder.StartObject(1) + +def PartitionSpecOneAxisAddAxes(builder, axes): + builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(axes), 0) + +def PartitionSpecOneAxisStartAxesVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + +def PartitionSpecOneAxisEnd(builder): + return builder.EndObject() + + + +class PartitionSpec(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = PartitionSpec() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsPartitionSpec(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + # PartitionSpec + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # PartitionSpec + def Partitions(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + x = self._tab.Vector(o) + x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4 + x = self._tab.Indirect(x) + obj = PartitionSpecOneAxis() + obj.Init(self._tab.Bytes, x) + return obj + return None + + # PartitionSpec + def PartitionsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # PartitionSpec + def PartitionsIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + return o == 0 + + # PartitionSpec + def Reduced(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return "" + + # PartitionSpec + def ReducedLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # PartitionSpec + def ReducedIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + return o == 0 + + # PartitionSpec + def Unreduced(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return "" + + # PartitionSpec + def UnreducedLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # PartitionSpec + def UnreducedIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + return o == 0 + +def PartitionSpecStart(builder): + builder.StartObject(3) + +def PartitionSpecAddPartitions(builder, partitions): + builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(partitions), 0) + +def PartitionSpecStartPartitionsVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + +def PartitionSpecAddReduced(builder, reduced): + builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(reduced), 0) + +def PartitionSpecStartReducedVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + +def PartitionSpecAddUnreduced(builder, unreduced): + builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(unreduced), 0) + +def PartitionSpecStartUnreducedVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + +def PartitionSpecEnd(builder): + return builder.EndObject() + + + +class NamedSharding(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = NamedSharding() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsNamedSharding(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + # NamedSharding + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # NamedSharding + def Mesh(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + x = self._tab.Indirect(o + self._tab.Pos) + obj = AbstractMesh() + obj.Init(self._tab.Bytes, x) + return obj + return None + + # NamedSharding + def Spec(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + x = self._tab.Indirect(o + self._tab.Pos) + obj = PartitionSpec() + obj.Init(self._tab.Bytes, x) + return obj + return None + + # NamedSharding + def MemoryKind(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + return None + +def NamedShardingStart(builder): + builder.StartObject(3) + +def NamedShardingAddMesh(builder, mesh): + builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(mesh), 0) + +def NamedShardingAddSpec(builder, spec): + builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(spec), 0) + +def NamedShardingAddMemoryKind(builder, memoryKind): + builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(memoryKind), 0) + +def NamedShardingEnd(builder): + return builder.EndObject() + + + class AbstractValue(object): __slots__ = ['_tab'] @@ -263,8 +621,15 @@ def Dtype(self): return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) return 0 + # AbstractValue + def MemorySpace(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + def AbstractValueStart(builder): - builder.StartObject(3) + builder.StartObject(4) def AbstractValueAddKind(builder, kind): builder.PrependInt8Slot(0, kind, 0) @@ -278,6 +643,9 @@ def AbstractValueStartShapeVector(builder, numElems): def AbstractValueAddDtype(builder, dtype): builder.PrependInt8Slot(2, dtype, 0) +def AbstractValueAddMemorySpace(builder, memorySpace): + builder.PrependInt8Slot(3, memorySpace, 0) + def AbstractValueEnd(builder): return builder.EndObject() @@ -335,8 +703,18 @@ def HloShardingProtoIsNone(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) return o == 0 + # Sharding + def NamedSharding(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + x = self._tab.Indirect(o + self._tab.Pos) + obj = NamedSharding() + obj.Init(self._tab.Bytes, x) + return obj + return None + def ShardingStart(builder): - builder.StartObject(2) + builder.StartObject(3) def ShardingAddKind(builder, kind): builder.PrependInt8Slot(0, kind, 0) @@ -347,6 +725,9 @@ def ShardingAddHloShardingProto(builder, hloShardingProto): def ShardingStartHloShardingProtoVector(builder, numElems): return builder.StartVector(1, numElems, 1) +def ShardingAddNamedSharding(builder, namedSharding): + builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(namedSharding), 0) + def ShardingEnd(builder): return builder.EndObject() @@ -457,6 +838,7 @@ def Init(self, buf, pos): # Note that this field has different semantics and purpose from # `mlir_module_serialization_version`, which encodes # the calling convention of the `mlir_module_serialized`. + # See comments in serialization.py for more details. # Exported def SerializationVersion(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) @@ -540,7 +922,7 @@ def OutAvalsIsNone(self): return o == 0 # Exported - def NrDevices(self): + def NrDevicesShort(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16)) if o != 0: return self._tab.Get(flatbuffers.number_types.Int16Flags, o + self._tab.Pos) @@ -764,8 +1146,15 @@ def Vjp(self): return obj return None + # Exported + def NrDevices(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(40)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos) + return 0 + def ExportedStart(builder): - builder.StartObject(18) + builder.StartObject(19) def ExportedAddSerializationVersion(builder, serializationVersion): builder.PrependUint16Slot(0, serializationVersion, 0) @@ -791,8 +1180,8 @@ def ExportedAddOutAvals(builder, outAvals): def ExportedStartOutAvalsVector(builder, numElems): return builder.StartVector(4, numElems, 4) -def ExportedAddNrDevices(builder, nrDevices): - builder.PrependInt16Slot(6, nrDevices, 0) +def ExportedAddNrDevicesShort(builder, nrDevicesShort): + builder.PrependInt16Slot(6, nrDevicesShort, 0) def ExportedAddInShardings(builder, inShardings): builder.PrependUOffsetTRelativeSlot(7, flatbuffers.number_types.UOffsetTFlags.py_type(inShardings), 0) @@ -851,5 +1240,8 @@ def ExportedAddUsesGlobalConstants(builder, usesGlobalConstants): def ExportedAddVjp(builder, vjp): builder.PrependUOffsetTRelativeSlot(17, flatbuffers.number_types.UOffsetTFlags.py_type(vjp), 0) +def ExportedAddNrDevices(builder, nrDevices): + builder.PrependUint32Slot(18, nrDevices, 0) + def ExportedEnd(builder): return builder.EndObject() diff --git a/jax/_src/export/shape_poly.py b/jax/_src/export/shape_poly.py index 6a6ce93712ff..1e58834387f5 100644 --- a/jax/_src/export/shape_poly.py +++ b/jax/_src/export/shape_poly.py @@ -13,7 +13,7 @@ # limitations under the License. """Shape polymorphism support. -See documentation at https://jax.readthedocs.io/en/latest/export/shape_poly.html. +See documentation at https://docs.jax.dev/en/latest/export/shape_poly.html. """ from __future__ import annotations @@ -34,23 +34,21 @@ import numpy as np import opt_einsum -import jax - +from jax._src import api from jax._src import config from jax._src import core from jax._src import dtypes from jax._src import effects -from jax._src.lax import lax from jax._src.interpreters import mlir -from jax._src.numpy import einsum as jnp_einsum from jax._src import source_info_util from jax._src import tree_util +from jax._src import typing from jax._src import util DimSize = Union["_DimExpr", int] TfVal = Any -DimVarEnv = dict[str, jax.Array] +DimVarEnv = dict[str, typing.Array] DType = Any # Tuples of terms and their coefficients, sorted with the largest term first. @@ -70,7 +68,7 @@ class InconclusiveDimensionOperation(core.InconclusiveDimensionOperation): are non-constant, and the result of the operation cannot be represented as a boolean value for all values of the symbolic dimensions involved. -Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#comparison-of-symbolic-dimensions-is-partially-supported +Please see https://docs.jax.dev/en/latest/export/shape_poly.html#comparison-of-symbolic-dimensions-is-partially-supported for more details. """ @@ -214,6 +212,8 @@ def __ge__(self, other: _DimFactor): return self._syntactic_cmp(other) >= 0 def evaluate(self, env: DimVarEnv, scope: SymbolicScope): + from jax._src.lax import lax # pytype: disable=import-error + if self.var is not None: try: return env[self.var] @@ -227,7 +227,8 @@ def evaluate(self, env: DimVarEnv, scope: SymbolicScope): return normalized_var._evaluate(env) # type: ignore err_msg = ( f"Encountered dimension variable '{self.var}' that is not appearing in the shapes of the function arguments.\n" - "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details.") + f"The following dimension variables are appearing in the shapes of the function arguments: {list(env.keys())}.\n" + "Please see https://docs.jax.dev/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details.") raise UnexpectedDimVar(err_msg) else: operand_values = [opnd._evaluate(env) for opnd in self.operands] @@ -654,7 +655,7 @@ def _eq(self, other: _DimExpr) -> bool: # Here we really ought to raise InconclusiveDimensionOperation, but __eq__ # cannot raise exceptions, because it is used indirectly when hashing. # So, we say that the expressions are disequal, which is really unsound. - # See https://jax.readthedocs.io/en/latest/export/shape_poly.html#comparison-of-symbolic-dimensions-is-partially-supported + # See https://docs.jax.dev/en/latest/export/shape_poly.html#comparison-of-symbolic-dimensions-is-partially-supported return False return diff == 0 @@ -841,7 +842,7 @@ def __eq__(self, other: Any) -> bool: # Here we really ought to raise InconclusiveDimensionOperation, but __eq__ # cannot raise exceptions, because it is used indirectly when hashing. # So, we say that the expressions are disequal, which is really unsound. - # See https://jax.readthedocs.io/en/latest/export/shape_poly.html#comparison-of-symbolic-dimensions-is-partially-supported + # See https://docs.jax.dev/en/latest/export/shape_poly.html#comparison-of-symbolic-dimensions-is-partially-supported return False return diff == 0 @@ -978,7 +979,7 @@ def cmp_sequence(s1, s2, elem_cmp) -> int: class SymbolicScope: - """Indentifies a scope for symbolic expressions. + """Identifies a scope for symbolic expressions. All symbolic expressions that interact (e.g., appear in the argument shapes for one JAX function invocation, or are involved in arithmetic operations) @@ -986,7 +987,7 @@ class SymbolicScope: Holds the constraints on symbolic expressions. - See [the README](https://jax.readthedocs.io/en/latest/export/shape_poly.html#user-specified-symbolic-constraints) + See [the README](https://docs.jax.dev/en/latest/export/shape_poly.html#user-specified-symbolic-constraints) for more details. Args: @@ -1001,7 +1002,8 @@ def __init__(self, "The symbolic constraints should be a sequence of strings. " f"Got {repr(constraints_str)}") self._initialized = False - self._location_frame = source_info_util.user_frame(source_info_util.current()) + self._location_frame = source_info_util.user_frame( + source_info_util.current().traceback) # Keep the explicit constraints in the order in which they were added self._explicit_constraints: list[_SymbolicConstraint] = [] @@ -1112,7 +1114,7 @@ def _check_same_scope(self, other: _DimExpr, f"Invalid mixing of symbolic scopes {when}.\n" f"Expected {self_descr}scope {self}\n" f"and found for '{other}' ({other_descr}) scope {other.scope}\n" - f"See https://jax.readthedocs.io/en/latest/export/shape_poly.html#user-specified-symbolic-constraints.") + f"See https://docs.jax.dev/en/latest/export/shape_poly.html#user-specified-symbolic-constraints.") def _clear_caches(self): self._bounds_cache.clear() @@ -1224,7 +1226,8 @@ def is_symbolic_dim(p: DimSize) -> bool: """ return isinstance(p, _DimExpr) -dtypes.python_scalar_dtypes[_DimExpr] = dtypes.python_scalar_dtypes[int] +dtypes.python_scalar_types.add(_DimExpr) +dtypes.python_scalar_types_to_dtypes[_DimExpr] = dtypes.python_scalar_types_to_dtypes[int] def _einsum_contract_path(*operands, **kwargs): """Like opt_einsum.contract_path, with support for DimExpr shapes. @@ -1255,7 +1258,7 @@ def fake_dim(d): # here some errors due to non-equal dimensions, but we catch them # later. return 8 - fake_ops.append(jax.ShapeDtypeStruct(tuple(map(fake_dim, shape)), + fake_ops.append(api.ShapeDtypeStruct(tuple(map(fake_dim, shape)), operand.dtype)) contract_fake_ops, contractions = opt_einsum.contract_path(*fake_ops, @@ -1267,8 +1270,6 @@ def fake_dim(d): contract_operands.append(operands[idx[0]]) return contract_operands, contractions -jnp_einsum._poly_einsum_handlers[_DimExpr] = _einsum_contract_path - # To implement shape-constraint checking we use a shape assertion primitive. # shape_assertion_p.bind(assert_what: bool, *error_message_inputs, # error_message="...{0}...{1}") @@ -1303,8 +1304,8 @@ class ShapeAssertionEffect(effects.Effect): effects.remat_allowed_effects.add_type(ShapeAssertionEffect) effects.custom_derivatives_allowed_effects.add_type(ShapeAssertionEffect) -def shape_assertion(assert_what: jax.Array, - *error_message_inputs: jax.Array, +def shape_assertion(assert_what: typing.Array, + *error_message_inputs: typing.Array, error_message: str) -> None: """Adds a shape assertion in the code. @@ -1384,7 +1385,7 @@ def symbolic_shape(shape_spec: str | None, ) -> Sequence[DimSize]: """Constructs a symbolic shape from a string representation. - See https://jax.readthedocs.io/en/latest/export/shape_poly.html for examples. + See https://docs.jax.dev/en/latest/export/shape_poly.html for examples. Args: shape_spec: a symbolic shape specification. None stands for "...". @@ -1396,13 +1397,13 @@ def symbolic_shape(shape_spec: str | None, mod(e1, e2), max(e1, e2), or min(e1, e2). constraints: a sequence of constraints on symbolic dimension expressions, of the form `e1 >= e2` or `e1 <= e2`, or `e1 == e2`. - See [the documentation](https://jax.readthedocs.io/en/latest/export/shape_poly.html#user-specified-symbolic-constraints) + See [the documentation](https://docs.jax.dev/en/latest/export/shape_poly.html#user-specified-symbolic-constraints) for usage. scope: optionally, you can specify that the parsed symbolic expressions be created in the given scope. If this is missing, then a new `SymbolicScope` is created with the given `constraints`. You cannot specify both a `scope` and `constraints`. - See [the documentation](https://jax.readthedocs.io/en/latest/export/shape_poly.html#user-specified-symbolic-constraints) + See [the documentation](https://docs.jax.dev/en/latest/export/shape_poly.html#user-specified-symbolic-constraints) for usage. like: when `shape_spec` contains placeholders ("_", "..."), use this shape to fill in the placeholders. @@ -1434,13 +1435,13 @@ def symbolic_args_specs( constraints: Sequence[str] = (), scope: SymbolicScope | None = None, ): - """Constructs a pytree of jax.ShapeDtypeSpec arguments specs for `export`. + """Constructs a pytree of jax.ShapeDtypeStruct arguments specs for `export`. See the documentation of :func:`jax.export.symbolic_shape` and - the [shape polymorphism documentation](https://jax.readthedocs.io/en/latest/export/shape_poly.html) for details. + the [shape polymorphism documentation](https://docs.jax.dev/en/latest/export/shape_poly.html) for details. Args: - args: a pytree of arguments. These can be jax.Array, or jax.ShapeDTypeSpec. + args: a pytree of arguments. These can be jax.Array, or jax.ShapeDtypeStruct. They are used to learn the pytree structure of the arguments, their dtypes, and to fill-in the actual shapes where the `shapes_specs` contains placeholders. Note that only the shape dimensions for which @@ -1450,11 +1451,11 @@ def symbolic_args_specs( applies to all arguments), or a pytree matching a prefix of the `args`. See [how optional parameters are matched to - arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees). + arguments](https://docs.jax.dev/en/latest/pytrees.html#applying-optional-parameters-to-pytrees). constraints: as for :func:`jax.export.symbolic_shape`. scope: as for :func:`jax.export.symbolic_shape`. - Returns: a pytree of jax.ShapeDTypeStruct matching the `args` with the shapes + Returns: a pytree of jax.ShapeDtypeStruct matching the `args` with the shapes replaced with symbolic dimensions as specified by `shapes_specs`. """ polymorphic_shapes = shapes_specs @@ -1485,14 +1486,14 @@ def symbolic_args_specs( elif constraints: raise ValueError("Cannot use both `scope` and `constraints`") args_specs_flat = ( - jax.ShapeDtypeStruct(symbolic_shape(spec, like=s, scope=scope), t) + api.ShapeDtypeStruct(symbolic_shape(spec, like=s, scope=scope), t) for s, t, spec in zip(shapes, dtypes, polymorphic_shapes_flat)) return args_tree.unflatten(args_specs_flat) def shape_and_dtype_jax_array(a) -> tuple[Sequence[int | None], DType]: """Returns the shape and dtype of a jax.Array or a j""" - if isinstance(a, jax.ShapeDtypeStruct): + if isinstance(a, api.ShapeDtypeStruct): return a.shape, a.dtype aval = core.get_aval(a) return aval.shape, aval.dtype @@ -1785,7 +1786,7 @@ def check_statically(self, eval: ShapeEvaluator) -> None: if not ok: raise self.make_error(eval) - def compute(self, eval: ShapeEvaluator) -> jax.Array | None: + def compute(self, eval: ShapeEvaluator) -> typing.Array | None: """Computes if the constraint is satisfied. If the constraint can be resolved statically returns None @@ -1793,6 +1794,8 @@ def compute(self, eval: ShapeEvaluator) -> jax.Array | None: resolved statically, returns a value representing if the constraint is satisfied. """ + from jax._src.lax import lax # pytype: disable=import-error + left, right = eval.evaluate(self.left), eval.evaluate(self.right) # Try to evaluate the constraint statically. if core.is_constant_shape((left, right)): @@ -1997,8 +2000,8 @@ def solve_dim_vars( def compute_dim_vars_from_arg_shapes( args_avals: Sequence[core.ShapedArray], - *actual_args: jax.Array, - args_kwargs_tree: tree_util.PyTreeDef) -> Sequence[jax.Array]: + *actual_args: typing.Array, + args_kwargs_tree: tree_util.PyTreeDef) -> Sequence[typing.Array]: """Computes values of dimension variables to unify args_avals with actual arguments. Like `solve_dim_vars` except that here we express the solution as @@ -2021,7 +2024,7 @@ def compute_dim_vars_from_arg_shapes( } synthetic_eval = ShapeEvaluator(synthetic_env) shape_constraints.shape_assertions(synthetic_eval) - return tuple(synthetic_eval.evaluate(solution[var]) for var in dim_vars) + return tuple(synthetic_eval.evaluate(solution[var]) for var in dim_vars) # type: ignore[arg-type] def _solve_dim_equations( eqns: list[_DimEquation], @@ -2038,7 +2041,7 @@ def _solve_dim_equations( " Using the following polymorphic shapes specifications: " + ",".join(f"{arg_name}.shape = {arg_spec}" for arg_name, arg_spec in polymorphic_shape_specs)) + "." - solution_err_msg_trailer_errors = ". Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details." + solution_err_msg_trailer_errors = ". Please see https://docs.jax.dev/en/latest/export/shape_poly.html#shape-assertion-errors for more details." shape_constraints = ShapeConstraints() # accumulate shape constraints scope: SymbolicScope | None = None @@ -2171,6 +2174,6 @@ def add_explicit_symbolic_constraints(shape_env: DimVarEnv): " Unprocessed specifications: " + ", ".join(f"'{eqn.aval_dim_expr}' for dimension size {eqn.dim_name}" for eqn in eqns) + - ". Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details." + ". Please see https://docs.jax.dev/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details." ) raise ValueError(err_msg) diff --git a/jax/_src/extend/random.py b/jax/_src/extend/random.py index df927486dd2f..dd85042cedca 100644 --- a/jax/_src/extend/random.py +++ b/jax/_src/extend/random.py @@ -14,10 +14,9 @@ from collections.abc import Callable, Hashable -from jax import Array - from jax._src import prng from jax._src import random +from jax._src.typing import Array Shape = tuple[int, ...] diff --git a/jax/_src/ffi.py b/jax/_src/ffi.py index 05697f00e945..d3fc47ff6b3a 100644 --- a/jax/_src/ffi.py +++ b/jax/_src/ffi.py @@ -16,36 +16,32 @@ from collections.abc import Callable, Mapping, Sequence import ctypes +import dataclasses import functools import os -from typing import Any, overload +from typing import Any, TypedDict, NotRequired, overload import numpy as np -import jax from jax._src import core -from jax._src import deprecations from jax._src import dispatch from jax._src import effects from jax._src import util from jax._src import xla_bridge +from jax._src.hashable_array import HashableArray +from jax._src.frozen_dict import FrozenDict from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir -from jax._src.layout import DeviceLocalLayout +from jax._src.layout import Layout from jax._src.lib import jaxlib from jax._src.lib import xla_client from jax._src.lib.mlir import ir from jax._src.typing import (Array, ArrayLike, DeprecatedArg, DuckTypedArray, Shape) -# TODO(dfm): Remove after 6 months or less because there aren't any offical -# compatibility guarantees for jax.extend (see JEP 15856) -# Added Oct 13, 2024 -deprecations.register("jax-ffi-call-args") - map, unsafe_map = util.safe_map, map -FfiLayoutOptions = Sequence[int] | DeviceLocalLayout | None +FfiLayoutOptions = Sequence[int] | Layout | None def register_ffi_target( @@ -61,7 +57,7 @@ def register_ffi_target( name: the name of the target. fn: a ``PyCapsule`` object containing the function pointer, or a ``dict`` where the keys are FFI stage names (e.g. `"execute"`) and the values are - ``PyCapsule`` objects continaing a pointer to the handler for that stage. + ``PyCapsule`` objects containing a pointer to the handler for that stage. platform: the target platform. api_version: the XLA custom call API version to use. Supported versions are: 1 (default) for the typed FFI or 0 for the earlier "custom call" API. @@ -73,6 +69,20 @@ def register_ffi_target( **kwargs) +class TypeRegistration(TypedDict): + """A dictionary type for registering FFI types. + + Attributes: + type_id: A ``PyCapsule`` object containing a pointer to the + ``XLA_FFI_TypeId``. + type_info: An optional ``PyCapsule`` object containing a pointer to the type + ``XLA_FFI_TypeInfo``. + """ + + type_id: Any + type_info: NotRequired[Any] + + def register_ffi_type_id( name: str, obj: Any, @@ -85,7 +95,24 @@ def register_ffi_type_id( obj: a ``PyCapsule`` object encapsulating a pointer to the type ID. platform: the target platform. """ - return xla_client.register_custom_type_id(name, obj, platform=platform) + raise ValueError( + "register_ffi_type_id is not supported after jaxlib version 381.") + +def register_ffi_type( + name: str, + type_registration: TypeRegistration, + platform: str = "cpu", +) -> None: + """Registers a custom type for a FFI target. + + Args: + name: the name of the type. This name must be unique within the process. + type_registration: a ``TypeRegistration`` defining the external type. + platform: the target platform. + """ + return xla_client.register_custom_type( + name, type_registration, platform=platform + ) def register_ffi_target_as_batch_partitionable(name: str) -> None: @@ -141,7 +168,7 @@ def include_dir() -> str: def _aval_shape(aval: core.AbstractValue) -> Shape: - return () if aval is core.abstract_token else aval.shape # pytype: disable=attribute-error + return () if aval is core.abstract_token else core.physical_aval(aval).shape # pytype: disable=attribute-error def _convert_layout_for_lowering( @@ -149,8 +176,8 @@ def _convert_layout_for_lowering( """Convert a layout to the minor-to-major order used by the custom call API.""" if layout is None: return tuple(reversed(range(len(_aval_shape(aval))))) - elif isinstance(layout, DeviceLocalLayout): - if layout._tiling is not None: + elif isinstance(layout, Layout): + if layout.tiling is not None: raise ValueError("The FFI does not support layouts with tiling") return layout.major_to_minor[::-1] else: @@ -163,6 +190,7 @@ def build_ffi_lowering_function( operand_layouts: Sequence[FfiLayoutOptions] | None = None, result_layouts: Sequence[FfiLayoutOptions] | None = None, backend_config: Mapping[str, ir.Attribute] | str | None = None, + skip_ffi_layout_processing: bool = False, **lowering_args: Any, ) -> Callable[..., ir.Operation]: """Build a lowering op for an foreign function interface (FFI) target. @@ -173,7 +201,7 @@ def build_ffi_lowering_function( Note that layouts passed to this function as tuples should be in minor-to-major order (as expected by XLA) rather than major-to-minor as used - by :func:`~jax.ffi.ffi_call` and ``DeviceLocalLayout``. + by :func:`~jax.ffi.ffi_call` and ``Layout``. If keyword arguments are passed to the lowering rule, these are treated as attributes, and added to `backend_config`. @@ -188,6 +216,8 @@ def build_ffi_lowering_function( arguments passed to the lowering rule will added to this dictionary. lowering_args: Any other arguments to :func:`mlir.custom_call` will also be passed through if provided as extra arguments to this function. + skip_ffi_layout_processing: If true, skip processing of operand and result + layout arguments passed to the lowering rule. """ def _lowering_op( @@ -209,18 +239,25 @@ def _lowering_op( kwargs["backend_config"] = backend_config if "result_types" not in kwargs: kwargs["result_types"] = [mlir.aval_to_ir_type(aval) for aval in ctx.avals_out] - if operand_layouts is None: - kwargs["operand_layouts"] = map(_convert_layout_for_lowering, ctx.avals_in) - else: - kwargs["operand_layouts"] = [ - _convert_layout_for_lowering(*args) - for args in zip(ctx.avals_in, operand_layouts)] - if result_layouts is None: - kwargs["result_layouts"] = map(_convert_layout_for_lowering, ctx.avals_out) - else: - kwargs["result_layouts"] = [ - _convert_layout_for_lowering(*args) - for args in zip(ctx.avals_out, result_layouts)] + if not skip_ffi_layout_processing: + if operand_layouts is None: + kwargs["operand_layouts"] = map( + _convert_layout_for_lowering, ctx.avals_in + ) + else: + kwargs["operand_layouts"] = [ + _convert_layout_for_lowering(*args) + for args in zip(ctx.avals_in, operand_layouts) + ] + if result_layouts is None: + kwargs["result_layouts"] = map( + _convert_layout_for_lowering, ctx.avals_out + ) + else: + kwargs["result_layouts"] = [ + _convert_layout_for_lowering(*args) + for args in zip(ctx.avals_out, result_layouts) + ] if "result_shapes" not in kwargs and not all( core.is_constant_shape(_aval_shape(aval)) for aval in ctx.avals_out): kwargs["result_shapes"] = [ @@ -238,6 +275,7 @@ def ffi_lowering( operand_layouts: Sequence[FfiLayoutOptions] | None = None, result_layouts: Sequence[FfiLayoutOptions] | None = None, backend_config: Mapping[str, ir.Attribute] | str | None = None, + skip_ffi_layout_processing: bool = False, **lowering_args: Any ) -> mlir.LoweringRule: """Build a lowering rule for an foreign function interface (FFI) target. @@ -248,7 +286,7 @@ def ffi_lowering( Note that layouts passed to this function as tuples should be in minor-to-major order (as expected by XLA) rather than major-to-minor as used - by :func:`~jax.ffi.ffi_call` and ``DeviceLocalLayout``. + by :func:`~jax.ffi.ffi_call` and ``Layout``. If keyword arguments are passed to the lowering rule, these are treated as attributes, and added to `backend_config`. @@ -263,6 +301,8 @@ def ffi_lowering( arguments passed to the lowering rule will added to this dictionary. lowering_args: Any other arguments to :func:`mlir.custom_call` will also be passed through if provided as extra arguments to this function. + skip_ffi_layout_processing: If true, skip processing of operand and result + layout arguments passed to the lowering rule. """ def _lowering( @@ -273,6 +313,7 @@ def _lowering( operand_layouts=operand_layouts, result_layouts=result_layouts, backend_config=backend_config, + skip_ffi_layout_processing=skip_ffi_layout_processing, **lowering_args, )(ctx, *operands, **params) @@ -287,17 +328,21 @@ def _lowering( def _result_avals(results: Sequence[ResultMetadata]) -> tuple[core.AbstractValue, ...]: avals: list[core.AbstractValue] = [] for idx, result in enumerate(results): - if isinstance(result, core.AbstractToken): - avals.append(result) + if result is core.abstract_token: + avals.append(result) # type: ignore else: if not hasattr(result, "shape") or not hasattr(result, "dtype"): raise ValueError( "All elements of result_shape_dtypes must have 'shape' and 'dtype' " f"attributes. Got {result} at position {idx}.") - avals.append(core.ShapedArray(result.shape, result.dtype)) + # Update the dtype because shaped_abstractify can canonicalize the dtype. + # We need to call shaped_abstractify here to handle sharding, vma and + # memory_kind bits. + # TODO(yashkatariya): Maybe add an option to shaped_abstractify/typeof + # to not canonicalize dtype. + avals.append(core.shaped_abstractify(result).update(dtype=result.dtype)) return tuple(avals) - def _check_compatible_avals(a: core.AbstractValue, b: core.AbstractValue) -> bool: if isinstance(a, core.AbstractToken) and isinstance(b, core.AbstractToken): return True @@ -314,7 +359,7 @@ def _convert_layouts_for_ffi_call( return tuple( _convert_layout_for_lowering( aval, - layout if layout is None or isinstance(layout, DeviceLocalLayout) + layout if layout is None or isinstance(layout, Layout) else layout[::-1] ) for aval, layout in zip(avals, layouts)) @@ -325,7 +370,7 @@ def _convert_layouts_for_ffi_call( def ffi_call( target_name: str, result_shape_dtypes: ResultMetadata, - *deprecated_args: ArrayLike, + *, has_side_effect: bool = ..., vmap_method: str | None = ..., input_layouts: Sequence[FfiLayoutOptions] | None = ..., @@ -333,9 +378,8 @@ def ffi_call( input_output_aliases: dict[int, int] | None = ..., custom_call_api_version: int = ..., legacy_backend_config: str | None = ..., - vectorized: bool | DeprecatedArg = ..., - **deprecated_kwargs: Any, -) -> Callable[..., Array] | Array: + vectorized: bool | None | DeprecatedArg = DeprecatedArg(), +) -> Callable[..., Array]: ... @@ -343,7 +387,7 @@ def ffi_call( def ffi_call( target_name: str, result_shape_dtypes: Sequence[ResultMetadata], - *deprecated_args: ArrayLike, + *, has_side_effect: bool = ..., vmap_method: str | None = ..., input_layouts: Sequence[FfiLayoutOptions] | None = ..., @@ -351,16 +395,15 @@ def ffi_call( input_output_aliases: dict[int, int] | None = ..., custom_call_api_version: int = ..., legacy_backend_config: str | None = ..., - vectorized: bool | DeprecatedArg = ..., - **deprecated_kwargs: Any, -) -> Callable[..., Sequence[Array]] | Sequence[Array]: + vectorized: bool | None | DeprecatedArg = DeprecatedArg(), +) -> Callable[..., Sequence[Array]]: ... def ffi_call( target_name: str, result_shape_dtypes: ResultMetadata | Sequence[ResultMetadata], - *deprecated_args: ArrayLike, + *, has_side_effect: bool = False, vmap_method: str | None = None, input_layouts: Sequence[FfiLayoutOptions] | None = None, @@ -368,16 +411,15 @@ def ffi_call( input_output_aliases: dict[int, int] | None = None, custom_call_api_version: int = 4, legacy_backend_config: str | None = None, - vectorized: bool | DeprecatedArg = DeprecatedArg(), - **deprecated_kwargs: Any, -) -> Callable[..., Array | Sequence[Array]] | Array | Sequence[Array]: + vectorized: bool | None | DeprecatedArg = DeprecatedArg(), +) -> Callable[..., Array | Sequence[Array]]: """Call a foreign function interface (FFI) target. See the :ref:`ffi-tutorial` tutorial for more information. Like :func:`~jax.pure_callback`, the behavior of ``ffi_call`` under :func:`~jax.vmap` depends on the value of ``vmap_method``. See the - :func:`~jax.pure_callback` documenation for more details about the allowed + :func:`~jax.pure_callback` documentation for more details about the allowed values and examples of their behavior. The current default behavior is to use ``vmap_method="sequential"`` when @@ -400,7 +442,7 @@ def ffi_call( :func:`~jax.vmap` as described above. input_layouts: a sequence of layouts for each input argument. In each case, the layout can be (a) ``None`` indicating that this input is in default - row-major order, (b) a ``DeviceLocalLayout`` specifying the axis order, + row-major order, (b) a ``Layout`` specifying the axis order, or (c) a sequence of integers specifying the major-to-minor axis ordering. Users who are familiar with XLA layouts should note that this function expects layouts in major-to-minor order instead of the @@ -430,18 +472,11 @@ def ffi_call( to execute the FFI handler. Any keyword arguments are passed as named attributes to the FFI handler using XLA's FFI interface. """ - if not isinstance(vectorized, DeprecatedArg) and not vectorized is None: - deprecations.warn( - "jax-callback-vectorized", - "The vectorized argument of ffi_call is deprecated and setting " - "it will soon raise an error. To avoid an error in the future, and to " - "suppress this warning, please use the vmap_method argument instead.", - stacklevel=2) - if vmap_method is not None: - raise ValueError( - "the vectorized and vmap_method arguments of ffi_call cannot " - "be used together. Please use the vmap_method argument.") - vmap_method = "legacy_vectorized" if vectorized else "sequential" + # TODO(danfm): Remove this check 3 months after v0.6.0 is released. + if not isinstance(vectorized, DeprecatedArg): + raise ValueError( + "The 'vectorized' argument of jax.ffi.ffi_call was removed in JAX " + "v0.6.0. Use 'vmap_method' instead.") allowed_vmap_methods = ["sequential", "sequential_unrolled", "expand_dims", "broadcast_all", "legacy_vectorized", None] if vmap_method not in allowed_vmap_methods: @@ -456,7 +491,7 @@ def ffi_call( result_avals = _result_avals(result_shape_dtypes) else: multiple_results = False - result_avals = _result_avals((result_shape_dtypes,)) + result_avals = _result_avals([result_shape_dtypes]) output_layouts_ = (output_layouts,) # type: ignore if custom_call_api_version >= 4 and legacy_backend_config is not None: @@ -515,11 +550,10 @@ def wrapped(*args: ArrayLike, **kwargs: Any): "and an output with a different layout " f"{static_output_layouts[o_idx]}.") static_input_output_aliases += ((i_idx, o_idx),) - + args = core.standard_insert_pvary(*args) results = ffi_call_p.bind( *args, result_avals=result_avals, - vectorized=vectorized, vmap_method=vmap_method, target_name=target_name, has_side_effect=has_side_effect, @@ -537,19 +571,7 @@ def wrapped(*args: ArrayLike, **kwargs: Any): else: return results[0] - if deprecated_args or deprecated_kwargs: - deprecations.warn( - "jax-ffi-call-args", - "Calling ffi_call directly with input arguments is deprecated. " - "Instead, ffi_call should be used to construct a callable, which can " - "then be called with the appropriate inputs. For example,\n" - " ffi_call('target_name', output_type, x, argument=5)\n" - "should be replaced with\n" - " ffi_call('target_name', output_type)(x, argument=5)", - stacklevel=2) - return wrapped(*deprecated_args, **deprecated_kwargs) - else: - return wrapped + return wrapped # ffi_call must support some small non-hashable input arguments, like np.arrays @@ -563,7 +585,7 @@ def _wrap_kwargs_hashable(kwargs: dict[str, Any]) -> Sequence[tuple[str, Any]]: if isinstance(v, np.ndarray): hashable_kwargs.append((k, HashableArray(v))) elif isinstance(v, dict): - hashable_kwargs.append((k, HashableDict(v))) + hashable_kwargs.append((k, FrozenDict(v))) else: try: hash(v) @@ -580,48 +602,14 @@ def _unwrap_kwargs_hashable(kwargs: Sequence[tuple[str, Any]]) -> dict[str, Any] for k, v in kwargs: if isinstance(v, HashableArray): unwrapped_kwargs[k] = v.val - elif isinstance(v, HashableDict): - unwrapped_kwargs[k] = dict(v.val) + elif isinstance(v, FrozenDict): + unwrapped_kwargs[k] = v._d else: unwrapped_kwargs[k] = v return unwrapped_kwargs -class HashableArray: - __slots__ = ["val"] - - def __init__(self, val): - assert isinstance(val, np.ndarray) - self.val = np.copy(val) - self.val.setflags(write=False) - - def __repr__(self): - return f"HashableArray({self.val})" - - def __hash__(self): - return hash((self.val.shape, self.val.dtype, self.val.tobytes())) - - def __eq__(self, other): - return isinstance(other, HashableArray) and np.array_equal(self.val, other.val) - - -class HashableDict: - __slots__ = ["val"] - - def __init__(self, val): - assert isinstance(val, dict) - self.val = tuple(sorted(val.items())) - - def __repr__(self): - return f"HashableDict({dict(self.val)})" - - def __hash__(self): - return hash(self.val) - - def __eq__(self, other): - return isinstance(other, HashableDict) and self.val == other.val - - +@dataclasses.dataclass(frozen=True) class FfiEffect(effects.Effect): def __str__(self): return "FFI" @@ -638,9 +626,12 @@ def ffi_call_abstract_eval( has_side_effect: bool, **_, ): - del avals_in # unused + core.standard_vma_rule('ffi_call', *avals_in) effects = {_FfiEffect} if has_side_effect else core.no_effects - return result_avals, effects + return tuple(r if r is core.abstract_token else + r.update(sharding=(core.get_cur_mesh_sharding() + if r.sharding.mesh.empty else r.sharding)) # type: ignore + for r in result_avals), effects def ffi_call_jvp(*args, target_name, **_): @@ -684,20 +675,12 @@ def ffi_batching_rule( args, dims, *, - vectorized: bool | None | DeprecatedArg, vmap_method: str | None, result_avals: Sequence[core.ShapedArray], **kwargs: Any, ): - if isinstance(vectorized, DeprecatedArg) and vmap_method is None: - deprecations.warn( - "jax-callback-vectorized", - f"The default behavior of {prim.name} under vmap will soon " - "change. Currently, the default behavior is to generate a sequential " - "vmap (i.e. a loop), but in the future the default will be to raise " - "an error. To keep the current default, set vmap_method='sequential'.", - stacklevel=6) - vmap_method = "sequential" + from jax._src.lax import control_flow # pytype: disable=import-error + from jax._src.lax import lax # pytype: disable=import-error axis_size, = {a.shape[d] for a, d in zip(args, dims) if d is not batching.not_mapped} @@ -726,7 +709,6 @@ def ffi_batching_rule( for layout, d in zip(kwargs["input_layouts"], dims)) outvals = prim.bind( *new_args, - vectorized=vectorized, vmap_method=vmap_method, result_avals=batched_result_avals, **kwargs, @@ -734,7 +716,7 @@ def ffi_batching_rule( elif vmap_method == "expand_dims" or vmap_method == "broadcast_all": size = axis_size if vmap_method == "broadcast_all" else 1 bcast_args = [ - jax.lax.broadcast(x, (size,)) if d is batching.not_mapped else x + lax.broadcast(x, (size,)) if d is batching.not_mapped else x for x, d in zip(new_args, dims)] if kwargs.get("input_layouts") is not None: kwargs["input_layouts"] = tuple( @@ -742,7 +724,6 @@ def ffi_batching_rule( for layout in kwargs["input_layouts"]) outvals = prim.bind( *bcast_args, - vectorized=vectorized, vmap_method=vmap_method, result_avals=batched_result_avals, **kwargs, @@ -755,13 +736,12 @@ def _batch_fun(batched_args): return prim.bind( *merged_args, result_avals=result_avals, - vectorized=vectorized, vmap_method=vmap_method, **kwargs, ) unroll = vmap_method == "sequential_unrolled" g = lambda _, x: ((), _batch_fun(x)) - _, outvals = jax.lax.scan(g, (), batched_args, unroll=unroll) + _, outvals = control_flow.scan(g, (), batched_args, unroll=unroll) else: raise NotImplementedError( f"vmap is only supported for the {prim.name} primitive when vmap_method " diff --git a/jax/_src/flatten_util.py b/jax/_src/flatten_util.py index ff35b8db8e25..ec5fb2984cc5 100644 --- a/jax/_src/flatten_util.py +++ b/jax/_src/flatten_util.py @@ -12,20 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Iterable import numpy as np +from typing import Any, Callable, TypeAlias -from jax import lax -import jax.numpy as jnp - -from jax._src.lax import lax as lax_internal +from jax._src.lax import lax from jax._src import dtypes -from jax._src.tree_util import tree_flatten, tree_unflatten -from jax._src.util import safe_zip, unzip2, HashablePartial +from jax._src.tree_util import tree_flatten, tree_unflatten, PyTreeDef, Leaf +from jax._src.util import safe_zip as zip, unzip2, HashablePartial +from jax._src.typing import Array -zip = safe_zip +Sizes: TypeAlias = tuple[int, ...] +Shapes: TypeAlias = tuple[tuple[int, ...], ...] -def ravel_pytree(pytree): +def ravel_pytree(pytree: Any) -> tuple[Array, Callable[[Array], Any]]: """Ravel (flatten) a pytree of arrays down to a 1D array. Args: @@ -41,49 +42,67 @@ def ravel_pytree(pytree): component of the output. For details on dtype promotion, see - https://jax.readthedocs.io/en/latest/type_promotion.html. + https://docs.jax.dev/en/latest/type_promotion.html. """ leaves, treedef = tree_flatten(pytree) flat, unravel_list = _ravel_list(leaves) return flat, HashablePartial(unravel_pytree, treedef, unravel_list) -def unravel_pytree(treedef, unravel_list, flat): + +def unravel_pytree( + treedef: PyTreeDef, + unravel_list: Callable[[Array], Iterable[Leaf]], + flat: Array, +) -> Any: return tree_unflatten(treedef, unravel_list(flat)) -def _ravel_list(lst): - if not lst: return jnp.array([], jnp.float32), lambda _: [] + +def _ravel_list(lst: list[Any], /) -> tuple[Array, Callable[[Array], list[Any]]]: + if not lst: + return lax.full([0], 0, "float32"), lambda _: [] from_dtypes = tuple(dtypes.dtype(l) for l in lst) to_dtype = dtypes.result_type(*from_dtypes) - sizes, shapes = unzip2((jnp.size(x), jnp.shape(x)) for x in lst) - indices = tuple(np.cumsum(sizes)) + sizes, shapes = unzip2((np.size(x), np.shape(x)) for x in lst) if all(dt == to_dtype for dt in from_dtypes): # Skip any dtype conversion, resulting in a dtype-polymorphic `unravel`. # See https://github.com/jax-ml/jax/issues/7809. del from_dtypes, to_dtype - raveled = jnp.concatenate([jnp.ravel(e) for e in lst]) - return raveled, HashablePartial(_unravel_list_single_dtype, indices, shapes) + ravel = lambda e: lax.reshape(e, (np.size(e),)) + raveled = lax.concatenate([ravel(e) for e in lst], dimension=0) + return raveled, HashablePartial(_unravel_list_single_dtype, sizes, shapes) # When there is more than one distinct input dtype, we perform type # conversions and produce a dtype-specific unravel function. - ravel = lambda e: jnp.ravel(lax.convert_element_type(e, to_dtype)) - raveled = jnp.concatenate([ravel(e) for e in lst]) - unrav = HashablePartial(_unravel_list, indices, shapes, from_dtypes, to_dtype) + ravel = lambda e: lax.convert_element_type(e, to_dtype).ravel() + raveled = lax.concatenate([ravel(e) for e in lst], dimension=0) + unrav = HashablePartial(_unravel_list, sizes, shapes, from_dtypes, to_dtype) return raveled, unrav -def _unravel_list_single_dtype(indices, shapes, arr): - chunks = jnp.split(arr, indices[:-1]) + +def _unravel_list_single_dtype(sizes: Sizes, shapes: Shapes, arr: Array) -> list[Array]: + chunks = lax.split(arr, sizes) return [chunk.reshape(shape) for chunk, shape in zip(chunks, shapes)] -def _unravel_list(indices, shapes, from_dtypes, to_dtype, arr): + +def _unravel_list( + sizes: Sizes, + shapes: Shapes, + from_dtypes: tuple[np.dtype, ...], + to_dtype: np.dtype, + arr: Array, +) -> list[Array]: arr_dtype = dtypes.dtype(arr) if arr_dtype != to_dtype: - raise TypeError(f"unravel function given array of dtype {arr_dtype}, " - f"but expected dtype {to_dtype}") - chunks = jnp.split(arr, indices[:-1]) + raise TypeError( + f"unravel function given array of dtype {arr_dtype}, " + f"but expected dtype {to_dtype}" + ) + chunks = lax.split(arr, sizes) return [ - lax_internal._convert_element_type(chunk.reshape(shape), dtype, - warn_on_complex_to_real_cast=False) + lax._convert_element_type( + chunk.reshape(shape), dtype, warn_on_complex_to_real_cast=False + ) for chunk, shape, dtype in zip(chunks, shapes, from_dtypes) ] diff --git a/jax/_src/frozen_dict.py b/jax/_src/frozen_dict.py new file mode 100644 index 000000000000..f443cce0dfd1 --- /dev/null +++ b/jax/_src/frozen_dict.py @@ -0,0 +1,52 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, TypeVar +from collections.abc import Iterator, Mapping + +K = TypeVar("K") +V = TypeVar("V") + + +class FrozenDict(Mapping[K, V]): + + def __init__(self, d: Mapping[K, V]): + self._d = dict(d.items()) + + def __repr__(self) -> str: + return f"FrozenDict({self._d!r})" + + def __str__(self) -> str: + return f"FrozenDict({self._d})" + + def __getitem__(self, key: K) -> V: + return self._d[key] + + def __hash__(self) -> int: + # This assumes that the values are hashable. + return hash(frozenset(self._d.items())) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, FrozenDict): + return False + return self._d == other._d + + def __iter__(self) -> Iterator[K]: + return iter(self._d) + + def __len__(self) -> int: + return len(self._d) + + def get(self, key: K) -> V | None: # type: ignore + return self._d.get(key, None) diff --git a/jax/_src/hardware_utils.py b/jax/_src/hardware_utils.py index 84ad9edf919f..fb439edcf507 100644 --- a/jax/_src/hardware_utils.py +++ b/jax/_src/hardware_utils.py @@ -40,6 +40,8 @@ class TpuVersion(enum.IntEnum): v5e = 5 # TPU v6e v6e = 6 + # TPU7x + tpu7x = 7 _TPU_PCI_DEVICE_IDS = { @@ -49,6 +51,7 @@ class TpuVersion(enum.IntEnum): '0x0062': TpuVersion.v5p, '0x0063': TpuVersion.v5e, '0x006f': TpuVersion.v6e, + '0x0076': TpuVersion.tpu7x, } def num_available_tpu_chips_and_device_id(): diff --git a/jax/_src/hashable_array.py b/jax/_src/hashable_array.py new file mode 100644 index 000000000000..4757a9c5eb24 --- /dev/null +++ b/jax/_src/hashable_array.py @@ -0,0 +1,37 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the + +import numpy as np + + +class HashableArray: + __slots__ = ["val"] + val: np.ndarray + + def __init__(self, val): + self.val = np.array(val, copy=True) + self.val.setflags(write=False) + + def __repr__(self): + return f"HashableArray({self.val!r})" + + def __str__(self): + return f"HashableArray({self.val})" + + def __hash__(self): + return hash((self.val.shape, self.val.dtype, self.val.tobytes())) + + def __eq__(self, other): + return isinstance(other, HashableArray) and np.array_equal( + self.val, other.val + ) diff --git a/jax/_src/hijax.py b/jax/_src/hijax.py new file mode 100644 index 000000000000..5dae4366291f --- /dev/null +++ b/jax/_src/hijax.py @@ -0,0 +1,760 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass +from functools import partial +import inspect +import itertools as it +from typing import Any, Hashable, Callable + +from jax._src import api +from jax._src import config +from jax._src import core +from jax._src import dtypes +from jax._src import effects +from jax._src.api_util import resolve_kwargs, infer_argnums_and_argnames +from jax._src.core import typeof +from jax._src.interpreters import ad +from jax._src.interpreters import batching +from jax._src.interpreters import partial_eval as pe +from jax._src.custom_derivatives import CustomVJPPrimal +from jax._src.errors import UnexpectedTracerError +from jax._src.state.types import AbstractRef +from jax._src import ad_util +from jax._src.util import safe_zip, safe_map, split_list, unzip2 +from jax._src.tree_util import ( + tree_map, tree_flatten, tree_unflatten, tree_leaves, tree_leaves_checked, + broadcast_prefix, register_static, tree_structure, tree_map_with_path, + keystr) +map, unsafe_map = safe_map, map +zip, unsafe_zip = safe_zip, zip + +PyTreeOfAvals = Any +PyTreeDef = Any +LoVal = Any +HiVal = Any + + +# Hijax extension API + +Ty = core.AbstractValue +LoType = core.AbstractValue +QDD = core.QuasiDynamicData +ShapedArray = core.ShapedArray + +class HiPrimitive(core.Primitive): + def __init__(self, name): + self.name = name + ad.primitive_jvps[self] = self.jvp + ad.primitive_transposes[self] = self.transpose + + def is_high(self, *avals, **params) -> bool: + return True + + def is_effectful(self, params) -> bool: # type: ignore + return False # default immutable + + # type checking and forward type propagation + def abstract_eval(self, *arg_avals, **params): + assert False, "must override" + + # lowering implements the primitive in terms of lojax inputs/outputs/ops + def to_lojax(self, *lotypes_wrapped_in_hitypes, **params): + assert False, f"must override for {self}" + + # autodiff interface + def jvp(self, primals, tangents, **params): + assert False, "must override" + # transposition is only required if the primitive is linear in some inputs + def transpose(self, *args, **params): + assert False, "must override" + + +class HiType(core.AbstractValue): + is_high = True + has_qdd = False # immutable + + # type equality + def __hash__(self): assert False, "must override" + def __eq__(self, other): assert False, "must override" + + # lowering from hijax type to lojax types + def lo_ty(self) -> list[core.AbstractValue]: + assert False, "must override" + + # define lowering from hijax value to lojax values and back (like pytrees) + def lower_val(self, hi_val: HiVal) -> list[LoVal]: # TODO(mattjj); not lovals + assert False, "must override" + def raise_val(self, *lo_vals: LoVal) -> HiVal: + assert False, "must override" + + # autodiff interface + def to_tangent_aval(self) -> HiType: + assert False, "must override" + + # Subclasses should override if the cotangent type is a function of primal + # type. For example, CT unreduced = reduced and vice-versa. + def to_cotangent_aval(self) -> HiType: + return self.to_tangent_aval() + + # the next two are required if this type is itself a tangent type + def vspace_zero(self) -> HiVal: + assert False, "must override" + + def vspace_add(self, x: HiVal, y: HiVal) -> HiVal: + assert False, "must override" + +class MutableHiType(core.AbstractValue): + is_high = True + has_qdd = True # mutable and potentially type-changing + type_state = core.aval_method(core.cur_qdd) + + # type equality + def __hash__(self): assert False, "must override" + def __eq__(self, other): assert False, "must override" + + # define lowering from (mutable) hijax type to (immutable) lojax types + def lo_ty_qdd(self, state: QDD) -> list[core.AbstractValue]: + assert False, "must override" + def lo_ty(self): + assert False, "mutable hitypes should use lo_ty_qdd instead" + + # define lowering from hijax value to lojax values and back, depending on qdd + def new_from_loval(self, state: QDD, *vals: LoVal) -> HiVal: + assert False, "must override" + def read_loval(self, state: QDD, val: HiVal) -> list[LoVal]: + assert False, "must override" + + # define how to mutate/set the mutable hijax value given immutable lojax vals + def update_from_loval(self, state: QDD, val: HiVal, *lo_vals: LoVal) -> None: + assert False, "must override" + + # autodiff interface + def to_tangent_aval(self) -> HiType: + assert False, "must override" + + # Subclasses should override if the cotangent type is a function of primal + # type. For example, CT unreduced = reduced and vice-versa. + def to_cotangent_aval(self) -> HiType: + return self.to_tangent_aval() + +def register_hitype(val_cls, typeof_fn) -> None: + core.pytype_aval_mappings[val_cls] = typeof_fn + dtypes.canonicalize_value_handlers[val_cls] = lambda x: x + +def hijax_method(f): + return core.aval_method(f) + + +# Boxes + +## Box API + +def new_box(): + (), treedef = tree_flatten(None) + return new_box_p.bind(treedef=treedef) + +def box_get(box): + tys = core.cur_qdd(box) + leaf_vals = box_get_p.bind(box, avals=tuple(tys.leaf_avals)) + return tree_unflatten(tys.treedef, leaf_vals) + +def box_set(box, val): + leaves, treedef = tree_flatten(val) + box_set_p.bind(box, *leaves, treedef=treedef) + +## Box implementation + +@dataclass(frozen=True) +class BoxTypeState(QDD): + leaf_avals: tuple[core.AbstractValue, ...] + treedef: PyTreeDef + + def to_tangent_qdd(self): + leaf_avals = tuple(a.to_tangent_aval() for a in self.leaf_avals) + return BoxTypeState(leaf_avals, self.treedef) + + def normalize(self): + leaf_types = tuple(a.normalize() for a in self.leaf_avals) + return BoxTypeState(leaf_types, self.treedef) + +class BoxTy(MutableHiType): + has_qdd = True + + # forwarded to value + get = core.aval_method(box_get) + set = core.aval_method(box_set) + + # aval interface: hashability and str_short + def __hash__(self): return hash(BoxTy) + def __eq__(self, other): return isinstance(other, BoxTy) + + def str_short(self, short_dtypes=False, **_) -> str: # type: ignore + return 'BoxTy' + + # mutable interface + def lo_ty_qdd(self, box_state): + return [lo_ty for t in box_state.leaf_avals for lo_ty in t.lo_ty()] + + def new_from_loval(self, box_state: BoxTypeState, *lo_vals) -> Box: # type: ignore + lo_vals_ = iter(lo_vals) + hi_vals = [hi_ty.raise_val(*it.islice(lo_vals_, len(hi_ty.lo_ty()))) # type: ignore + for hi_ty in box_state.leaf_avals] + assert next(lo_vals_, None) is None + return Box._new(tree_unflatten(box_state.treedef, hi_vals)) # will be mutated + + def read_loval(self, box_state: BoxTypeState, box) -> list: # type: ignore + leaf_vals, treedef = tree_flatten(box_get(box)) + assert treedef == box_state.treedef + return [lo_val for hi_ty, hi_val in zip(box_state.leaf_avals, leaf_vals) + for lo_val in hi_ty.lower_val(hi_val)] # type: ignore + + def update_from_loval(self, box_state: BoxTypeState, box, *lo_vals) -> None: # type: ignore + lo_vals_ = iter(lo_vals) + hi_vals = [hi_ty.raise_val(*it.islice(lo_vals_, len(hi_ty.lo_ty()))) # type: ignore + for hi_ty in box_state.leaf_avals] + assert next(lo_vals_, None) is None + box_set(box, tree_unflatten(box_state.treedef, hi_vals)) + + def to_tangent_aval(self): + return BoxTy() + +# Override isinstance checks under tracing +class _BoxMeta(type): + def __instancecheck__(self, instance): + return (super().__instancecheck__(instance) or + isinstance(instance, core.Tracer) and + isinstance(core.typeof(instance), BoxTy)) + +class Box(metaclass=_BoxMeta): # noqa: F811 + _val = None # always clobbered by __new__, but pytype likes this + + # We want `Box(x)` to bind a primitive, so we override __new__ and provide a + # raw `_new` method below. + def __new__(cls, init_val=None): + (), treedef = tree_flatten(None) + box = new_box_p.bind(treedef=treedef) + box.set(init_val) + return box + + @classmethod + def _new(cls, init_val): + new = super().__new__(cls) + new._val = init_val + return new + + def get(self): + return box_get(self) + + def set(self, val): + box_set(self, val) + + def cur_qdd(self): + return self.type_state() + + @property + def ty(self): + return BoxTy() + + def type_state(self): + leaves, treedef = tree_flatten(self._val) + leaf_avals = tuple(map(core.typeof, leaves)) + return BoxTypeState(leaf_avals, treedef) + +register_hitype(Box, lambda b: b.ty) + +class BoxEffect(effects.Effect): ... +box_effect = BoxEffect() +effects.control_flow_allowed_effects.add_type(BoxEffect) +effects.custom_derivatives_allowed_effects.add_type(BoxEffect) + +class NewBox(HiPrimitive): + def is_high(self, *, treedef) -> bool: return True # type: ignore + + def abstract_eval(self, *, treedef): + leaves, treedef = tree_flatten(None) + qdd = BoxTypeState(tuple(leaves), treedef) + return core.AvalQDD(BoxTy(), qdd), {box_effect} + + def to_lojax(_, *, treedef): + return Box._new(None) + + def jvp(_, primals, tangents, *, treedef): + assert False # TODO + + def transpose(_, *args, treedef): + assert False # TODO +new_box_p = NewBox('new_box') + +class BoxSet(HiPrimitive): + multiple_results = True + + def is_high(self, *leaf_avals, treedef) -> bool: return True # type: ignore + + def abstract_eval(self, box_ty, *leaf_avals, treedef): + box_ty.mutable_qdd.update(BoxTypeState(leaf_avals, treedef)) + return [], {box_effect} # TODO better typechecking... + + def to_lojax(_, box, *leaves, treedef): + box._val = tree_unflatten(treedef, leaves) + return [] + + def jvp(_, primals, tangents, *, treedef): + box, *vals = primals + box_dot, *val_dots = tangents + if type(box_dot) is ad_util.Zero: + raise Exception("can't differentiate Box.set operation, " + "did you forget jax.lax.stop_gradient?") + box_set_p.bind(box, *vals, treedef=treedef) + box_set_p.bind(box_dot, *val_dots, treedef=treedef) + return [], [] + + def transpose(_, *args, treedef): + assert False # TODO +box_set_p = BoxSet('box_set') + + +class BoxGet(HiPrimitive): + multiple_results = True + + def abstract_eval(self, box_ty, *, avals): + return avals, {box_effect} + + def to_lojax(_, box, *, avals): + return tree_leaves(box._val) + + def jvp(_, primals, tangents, *, avals): + (box,), (box_dot,) = primals, tangents + return ( + box_get_p.bind(box, avals=avals), + box_get_p.bind(box_dot, avals=tuple(a.to_tangent_aval() for a in avals)) + ) + + def transpose(_, *args): + assert False # TODO +box_get_p = BoxGet('box_get') + + +# === new-style hijax primitive implementation === + +class VJPHiPrimitive: + in_avals: tuple[PyTreeOfAvals, ...] + out_aval: PyTreeOfAvals + params: dict[str, Hashable] + + def __init__(self): + if not hasattr(self, 'in_avals'): + raise AttributeError("subclass __init__ should set `self.in_avals`") + if not hasattr(self, 'out_aval'): + raise AttributeError("subclass __init__ should set `self.out_aval`") + if not hasattr(self, 'params'): + raise AttributeError("subclass __init__ should set `self.params`") + if (type(self).vjp_bwd is not VJPHiPrimitive.vjp_bwd and + type(self).vjp_bwd_retval is not VJPHiPrimitive.vjp_bwd_retval): + raise AttributeError(f"subclass {type(self)} should not override both " + "`vjp_bwd` and `vjp_bwd_retval`") + self.in_avals_flat, self.in_tree = tree_flatten(self.in_avals) + self.out_avals_flat, self.out_tree = tree_flatten(self.out_aval) + self.__dict__.update(self.params) + + # Operation implementation in terms of lojax primitives + def expand(self, *args): + raise NotImplementedError(f"subclass {type(self)} must implement `expand`") + + def vjp_fwd(self, nzs_in, *args): + raise NotImplementedError(f"for grad support, subclass {type(self)} must " + "implement `vjp_fwd`") + + def vjp_bwd(self, res, outgrad, *arg_accums): + args_grad = self.vjp_bwd_retval(res, outgrad) + tree_map(lambda acc, leaf_grad: acc.accum(leaf_grad), arg_accums, args_grad) + + def vjp_bwd_retval(self, res, outgrad): + # Classic API: returns values instead of using accumulators + raise NotImplementedError(f"for grad support, subclass {type(self)} must " + "implement `vjp_bwd` or `vjp_bwd_retval`") + + def batch(self, axis_data, args, dims): + out_dim = self.batch_dim_rule(axis_data, dims) + return VmapOf(self, axis_data, dims, out_dim)(*args), out_dim + + def batch_dim_rule(self, axis_data, dims): + raise NotImplementedError(f"for vmap support, subclass {type(self)} must " + "implement `batch` or `batch_dim_rule`") + + def jvp(self, primals, tangents): + raise NotImplementedError(f"for jvp support, subclass {type(self)} must " + "implement `jvp`") + + def __call__(self, *args): + args_flat = tree_leaves_checked(self.in_tree, args) + ans_flat = call_hi_primitive_p.bind(*args_flat, prim=self) + return tree_unflatten(self.out_tree, ans_flat) + + def check(self, *arg_tys): + return # subclass can optionally override this to add checking logic + + def staging(self, trace, source_info, *args): + args_flat = tree_leaves_checked(self.in_tree, args) + ans_flat = trace.default_process_primitive( + call_hi_primitive_p, args_flat, dict(prim=self), source_info) + return tree_unflatten(self.out_tree, ans_flat) + + def __repr__(self): + return f"{self.__class__.__name__}[{self.params}]" + + def __hash__(self): + return hash((self.__class__.__name__, tuple(self.params.items()))) + + def __eq__(self, other): + return type(self) is type(other) and self.params == other.params + +class VmapOf(VJPHiPrimitive): + def __init__(self, prim, axis_data, in_dims, out_dim): + unmap = lambda a, d: core.unmapped_aval(axis_data.size, d, a, + axis_data.explicit_mesh_axis) + self.in_avals = tree_map(unmap, prim.in_avals, in_dims) + self.out_aval = tree_map(unmap, prim.out_aval, out_dim) + self.params = dict(prim=prim, axis_data=axis_data, in_dims=in_dims, + out_dim=out_dim) + super().__init__() + + @property + def _vmap_params(self): + return dict(axis_size=self.axis_data.size, axis_name=self.axis_data.name, # type: ignore + spmd_axis_name=self.axis_data.spmd_name or self.axis_data.explicit_mesh_axis) # type: ignore + + def expand(self, *args): + return api.vmap(self.prim.expand, in_axes=self.in_dims, out_axes=self.out_dim, # type: ignore + **self._vmap_params)(*args) + + def jvp(self, primals, tangents): + # TODO probably gonna get non-pytree-prefix errors because of sym zeros... + return api.vmap(self.prim.jvp, in_axes=(self.in_dims, self.in_dims), # type: ignore + out_axes=(self.out_dim, self.out_dim), # type: ignore + **self._vmap_params)(primals, tangents) # type: ignore + + def vjp_fwd(self, in_nzs, *args): + store = lambda: None + def fwd(*args): + primal_out, res, *maybe_out_nzs = self.prim.vjp_fwd(in_nzs, *args) # type: ignore + store.out_nzs = maybe_out_nzs + return primal_out, res + (primal_out, res), (_, res_axes) = api.vmap( + fwd, in_axes=self.in_dims, out_axes=(self.out_dim, batching.infer), # type: ignore + **self._vmap_params)(*args) + return primal_out, (res, Static(res_axes)), *store.out_nzs # type: ignore + + def vjp_bwd_retval(self, res_, g): + # TODO probably gonna get non-pytree-prefix errors because of sym zeros... + res, res_axes = res_[0], res_[1].val + in_dims = tree_map(lambda x: batching.sum_axis if x is None else x, self.in_dims, # type: ignore + is_leaf=lambda x: x is None) + g = tree_map(partial(map_zero, self.axis_data), self.out_dim, g, is_leaf=lambda x: x is None) # type: ignore + out = api.vmap(self.prim.vjp_bwd_retval, in_axes=(res_axes, self.out_dim), # type: ignore + out_axes=in_dims, **self._vmap_params, sum_match=True)(res, g) + return tree_map(partial(unmap_zero, self.axis_data), self.in_dims, out, is_leaf=lambda x: x is None) # type: ignore + + def batch_dim_rule(self, axis_data, in_dims): + in_dims_ = tree_map(lambda d, d_: d - (d_ < d), in_dims, self.in_dims) # type: ignore + out_dim = self.prim.batch_dim_rule(axis_data, in_dims_) # type: ignore + return tree_map(lambda d, d_: d + (d_ < d), out_dim, self.out_dim) # type: ignore + +def map_zero(axis_data, d, ct): + if isinstance(ct, ad_util.Zero): + return ad_util.Zero(core.mapped_aval(axis_data.size, d, ct.aval)) + return ct + +def unmap_zero(axis_data, d, ct): + if isinstance(ct, ad_util.Zero): + return ad_util.Zero(core.unmapped_aval(axis_data.size, d, ct.aval, + axis_data.explicit_mesh_axis)) + return ct + + +call_hi_primitive_p = core.Primitive("call_hi_primitive") +call_hi_primitive_p.multiple_results = True +call_hi_primitive_p.is_high = lambda *args, prim: True # type: ignore +@call_hi_primitive_p.def_abstract_eval +def _call_hi_primitive_abstract_eval(*_args, prim): + return prim.out_avals_flat + +def _call_hi_primitive_staging(trace, source_info, *args_flat, prim): + trace.frame.is_high = True + args = tree_unflatten(prim.in_tree, args_flat) + ans = prim.staging(trace, source_info, *args) + return tree_leaves_checked(prim.out_tree, ans) +pe.custom_staging_rules[call_hi_primitive_p] = _call_hi_primitive_staging + +def _call_hi_primitive_to_lojax(*args_flat, prim): + args = tree_unflatten(prim.in_tree, args_flat) + ans = prim.expand(*args) + return tree_leaves_checked(prim.out_tree, ans) +call_hi_primitive_p.to_lojax = _call_hi_primitive_to_lojax + +def _call_hi_primitive_batcher(axis_data, args_flat, dims_flat, prim): + args = tree_unflatten(prim.in_tree, args_flat) + dims = tree_unflatten(prim.in_tree, dims_flat) + ans, dims = prim.batch(axis_data, args, dims) + ans_flat = tree_leaves_checked(prim.out_tree, ans) + dims_flat = prim.out_tree.flatten_up_to(dims) + return ans_flat, dims_flat +batching.fancy_primitive_batchers[call_hi_primitive_p] = _call_hi_primitive_batcher + +def _call_hi_primitive_linearize(nz_in_flat, *args_flat, prim): + args = tree_unflatten(prim.in_tree, args_flat) + nzs_in = tree_unflatten(prim.in_tree, nz_in_flat) + ans, residuals, *maybe_nzs_out = prim.vjp_fwd(nzs_in, *args) + ans_flat = tree_leaves_checked(prim.out_tree, ans) + nzs_out = True if maybe_nzs_out == [] else maybe_nzs_out[0] + nzs_out_flat = broadcast_prefix(nzs_out, ans) + linearized = partial(fake_linear_op, prim, nz_in_flat) + return (ans_flat, nzs_out_flat, residuals, linearized) + +def fake_linear_op(prim, nz_in_flat, rs, *tangents): + residuals_flat, residuals_tree = tree_flatten(rs) + assert nz_in_flat == [not isinstance(t, ad_util.Zero) for t in tangents] + nz_tangents = tree_leaves(tangents) + return call_hi_primitive_linearized_p.bind( + *residuals_flat, *nz_tangents, residuals_tree=residuals_tree, prim=prim, + nz_in_flat=tuple(nz_in_flat)) + +ad.primitive_linearizations[call_hi_primitive_p] = _call_hi_primitive_linearize + +call_hi_primitive_linearized_p = core.Primitive("call_hi_primitive_linearized") +call_hi_primitive_linearized_p.multiple_results = True +call_hi_primitive_linearized_p.is_high = lambda *args, prim, **_: True # type: ignore +@call_hi_primitive_linearized_p.def_abstract_eval +def _call_hi_primitive_linearized_abstract_eval(*_args, prim, residuals_tree, nz_in_flat): + return [t.to_tangent_aval() for t in prim.out_avals_flat] # TODO(dougalm): handle nonzeros + +def _call_hi_primitive_linearized_transpose(cts_flat, *args, prim, residuals_tree, nz_in_flat): + residuals_flat, accums_flat = split_list(args, [residuals_tree.num_leaves]) + residuals = tree_unflatten(residuals_tree, residuals_flat) + accums_flat_ = iter(accums_flat) + accums_flat = [next(accums_flat_) if nz else ad.NullAccum() for nz in nz_in_flat] + assert next(accums_flat_, None) is None + accums = tree_unflatten(prim.in_tree, accums_flat) + cts = tree_unflatten(prim.out_tree, cts_flat) + none = prim.vjp_bwd(residuals, cts, *accums) + assert none is None +ad.fancy_transposes[call_hi_primitive_linearized_p] = _call_hi_primitive_linearized_transpose + +def _call_hi_primitive_jvp(primals, tangents, *, prim): + primals = tree_unflatten(prim.in_tree, primals) + tangents = tree_unflatten(prim.in_tree, tangents) + out_primals, out_tangents = prim.jvp(primals, tangents) + out_primals_flat = tree_leaves_checked(prim.out_tree, out_primals) + out_tangents_flat = prim.out_tree.flatten_up_to(out_tangents) + return out_primals_flat, out_tangents_flat +ad.primitive_jvps[call_hi_primitive_p] = _call_hi_primitive_jvp + +def _call_hi_primitive_dce(used_outs_flat, eqn): + if hasattr(prim := eqn.params['prim'], 'dce'): + return prim.dce(used_outs_flat, eqn) + else: + return pe._default_dce_rule(used_outs_flat, eqn) +pe.dce_rules[call_hi_primitive_p] = _call_hi_primitive_dce + +call_hi_primitive_linearized_p.to_lojax = ad.raise_custom_vjp_error_on_jvp +batching.fancy_primitive_batchers[call_hi_primitive_linearized_p] = ad.raise_custom_vjp_error_on_jvp + + +class CustomVJPTraced(VJPHiPrimitive): + def __init__(self, traced, fwd, bwd, in_avals, sym_zeros, static_argnums, opt_remat): + self.in_avals = in_avals + self.out_aval = traced.out_avals + self.params = dict(traced=traced, fwd=fwd, bwd=bwd, symbolic_zeros=sym_zeros, + static_argnums=static_argnums, opt_remat=opt_remat) + super().__init__() + + def expand(self, *args): + args = [x for x in args if not isinstance(x, Static)] + return self.traced(*args) # type: ignore + + def vjp_fwd(self, in_nzs, *args): + in_nzs = tuple(x.val if isinstance(x, Static) else x for x in in_nzs) + args = tuple(x.val if isinstance(x, Static) else x for x in args) + if self.symbolic_zeros: # type: ignore + args = tree_map(CustomVJPPrimal, args, in_nzs) + out, res = self.fwd(*args) # type: ignore + if ((tree := tree_structure(out)) != self.out_tree): + raise TypeError(_vjp_primal_fwd_tree_mismatch_err(self, tree)) + tree_map_with_path(_vjp_fwd_aval_mismatch_err, self.out_aval, out) + if self.symbolic_zeros: # type: ignore + out_pairs_flat = tree_leaves_checked(self.out_tree, out) + out_flat, out_nzs_flat = unzip2( + (x.value, x.perturbed) if isinstance(x, CustomVJPPrimal) else + (x, True) for x in out_pairs_flat) + out_nzs = tree_unflatten(self.out_tree, out_nzs_flat) + out = tree_unflatten(self.out_tree, out_flat) + return out, res, out_nzs + else: + return out, res + + def vjp_bwd_retval(self, res, out_ct): + static_args = tuple(x.val for x in self.in_avals if isinstance(x, Static)) + in_avals_ = tuple(x for x in self.in_avals if not isinstance(x, Static)) + leaf = lambda x: isinstance(x, ad_util.Zero) + if self.symbolic_zeros: # type: ignore + out_ct = tree_map(ad_util.replace_internal_symbolic_zeros, out_ct, is_leaf=leaf) + else: + out_ct = tree_map(ad_util.instantiate, out_ct, is_leaf=leaf) + in_cts = self.bwd(*static_args, res, out_ct) # type: ignore + if isinstance(in_cts, list): + in_cts = tuple(in_cts) + if not isinstance(in_cts, tuple): + raise TypeError(f"Custom VJP bwd rule {self.bwd} must produce a tuple " # type: ignore + f"but got {type(in_cts)}.") # type: ignore + if len(in_cts) != len(self.in_tree.children()) - len(self.static_argnums): # type: ignore + raise ValueError(f"Custom VJP bwd rule {self.bwd} must produce a tuple " # type: ignore + "of length equal to the primal args tuple, but got " + f"length {len(in_cts)}") # type: ignore + in_cts = broadcast_prefix(in_cts, in_avals_, is_leaf=lambda x: x is None) + in_cts = tree_unflatten(self.in_tree, map(_replace_none, self.in_avals_flat, in_cts)) + tree_map_with_path(_vjp_bwd_aval_mismatch_err, self.in_avals, in_cts) + if self.symbolic_zeros: # type: ignore + in_cts = tree_map(ad_util.replace_rule_output_symbolic_zeros, in_cts) + return in_cts + + def jvp(self, primals, tangents): + if self.symbolic_zeros: raise NotImplementedError # type: ignore + zero = lambda x: isinstance(x, ad_util.Zero) + tangents = tree_map(ad_util.instantiate, tangents, is_leaf=zero) + if self.opt_remat: # type: ignore + fwd_traced = api.jit(partial(self.vjp_fwd, (True,) * len(primals))).trace(*primals) + primals_out, residuals = OptRemat(self.traced, fwd_traced)(*primals) # type: ignore + else: + primals_out, residuals, *_ = self.vjp_fwd((True,) * len(primals), *primals) + tangents_out_flat = fake_linear_op(self, [True] * len(tangents), residuals, *tangents) + tangents_out = tree_unflatten(self.out_tree, tangents_out_flat) + return primals_out, tangents_out + + def batch_dim_rule(self, axis_data, in_dims): + in_dims_flat = self.in_tree.flatten_up_to(in_dims) + _, out_dims = batching.batch_jaxpr2(self.traced.jaxpr, axis_data, tuple(in_dims_flat)) # type: ignore + return tree_unflatten(self.out_tree, out_dims) + +def _vjp_primal_fwd_tree_mismatch_err(self, tree): + return (f"Custom VJP fwd rule {self.fwd.__name__} for function {self.traced.fun_name} " # type: ignore + "must produce a pair (list or tuple of length two) where the first " + "element represents the primal output " + "(equal to the output of the custom_vjp-decorated function " + f"{self.traced.fun_name}) and the " # type: ignore + "second element represents residuals (i.e. values stored from the " + "forward pass for use on the backward pass), but " + f"instead the fwd rule output's first element had container/pytree " + "structure:\n" + f""" {str(tree ).replace("'", "")}\n""" # type: ignore + f"while the custom_vjp-decorated function {self.traced.fun_name} had output " # type: ignore + "container/pytree structure:\n" + f""" {str(self.out_tree).replace("'", "")}.""") # type: ignore + +def _vjp_fwd_aval_mismatch_err(path, primal_aval, fwd_val): + if not core.typematch(ty := typeof(fwd_val), primal_aval): + raise TypeError(f"at {keystr(path)}, got fwd output type {ty.str_short()} " + f"which doesn't match primal output type {primal_aval.str_short()}") + +def _vjp_bwd_aval_mismatch_err(path, primal_aval, ct_val): + if config.disable_bwd_checks.value: return + if isinstance(ct_val, ad_util.Zero): return + if isinstance(primal_aval, AbstractRef): primal_aval = primal_aval.inner_aval + expected = primal_aval.to_cotangent_aval() + ty = ct_val.aval if isinstance(ct_val, ad_util.SymbolicZero) else typeof(ct_val) + if not core.typematch(ty, expected) and getattr(expected, 'dtype', None) is not dtypes.float0: + result = f"at output{keystr(path)} " if path else "" + raise ValueError(f"{result}the bwd rule produced an output of type {ty.str_short()} " + f"which doesn't match expected type {expected.str_short()}") + +def _replace_none(primal_in_aval, maybe_ct): + if maybe_ct is None: + return ad_util.Zero(primal_in_aval.to_cotangent_aval()) + else: + return maybe_ct + +class custom_vjp3: + fwd: Callable | None = None + bwd: Callable | None = None + + def __init__(self, f, *, nondiff_argnums=(), nondiff_argnames=()): + self.f = f + self.static_argnums = _set_up_nondiff(f, nondiff_argnums, nondiff_argnames) + + def defvjp(self, fwd, bwd, *, symbolic_zeros=False, optimize_remat=False): + self.fwd = fwd + self.bwd = bwd + self.symz = symbolic_zeros + self.opt_remat = optimize_remat + return self + + def __call__(self, *args, **kwargs): + if not self.fwd or not self.bwd: + msg = f"No VJP defined for custom_vjp function {self.f.__name__} using defvjp." + raise AttributeError(msg) + + args = resolve_kwargs(self.f, args, kwargs) + if any(isinstance(args[i], core.Tracer) for i in self.static_argnums): + raise UnexpectedTracerError("custom_vjp inputs marked with nondiff_argnums " + "must be static, not Tracers") + traced = api.jit(self.f, static_argnums=(*self.static_argnums,)).trace(*args) + if any(isinstance(x, core.Tracer) for x in traced._consts): + raise Exception # TODO(mattjj):error tracer type, value type, primal name + args = tuple(Static(x) if i in self.static_argnums else x for i, x in enumerate(args)) + in_avals = tree_map(typeof, args) + prim = CustomVJPTraced(traced, self.fwd, self.bwd, in_avals, self.symz, # type: ignore + self.static_argnums, self.opt_remat) # type: ignore + return prim(*args) + +class OptRemat(VJPHiPrimitive): + traced_fwd: Any + traced_primal: Any + + def __init__(self, traced_primal, traced_fwd): + self.in_avals, _ = traced_primal.in_avals + self.out_aval = traced_fwd.out_avals + self.params = dict(traced_primal=traced_primal, traced_fwd=traced_fwd) + super().__init__() + + def expand(self, *primals): + return self.traced_fwd(*primals) + + def dce(self, used_outs, eqn): + num_primals_in = len(self.traced_primal.jaxpr.in_avals) + num_primals_out = len(self.traced_primal.jaxpr.out_avals) + _, used_res = split_list(used_outs, [num_primals_out]) + if any(used_res): + return [True] * num_primals_in, eqn + else: + outvars = [v for used, v in zip(used_outs, eqn.outvars) if used] + primal_eqn = pe.new_jaxpr_eqn( + eqn.invars, outvars, core.closed_call_p, dict(call_jaxpr=self.traced_primal.jaxpr), + self.traced_primal.jaxpr.effects, eqn.source_info, eqn.ctx) + return [True] * num_primals_in, primal_eqn + + # TODO(mattjj): jvp and transpose? does anyone rely on them? + + +def _set_up_nondiff(f, argnums_, argnames) -> frozenset[int]: + argnums = set(argnums_) + if argnames: + sig = inspect.signature(f) # needed for static_argnames + argnums |= set(infer_argnums_and_argnames(sig, None, argnames)[0]) + return frozenset(argnums) + +@register_static +@dataclass(frozen=True) +class Static: + val: Any diff --git a/jax/_src/image/scale.py b/jax/_src/image/scale.py index faaee6a540e4..ea95bfe9ed14 100644 --- a/jax/_src/image/scale.py +++ b/jax/_src/image/scale.py @@ -15,16 +15,17 @@ from __future__ import annotations from collections.abc import Callable, Sequence -from functools import partial import enum from typing import Any import numpy as np -from jax import jit -from jax import lax -from jax import numpy as jnp +from jax._src import api from jax._src import core +from jax._src import dtypes +from jax._src import numpy as jnp +from jax._src.lax import lax +from jax._src.numpy import einsum as jnp_einsum from jax._src.util import canonicalize_axis from jax._src.numpy.util import promote_dtypes_inexact @@ -56,7 +57,7 @@ def compute_weight_mat(input_size: core.DimSize, translation, kernel: Callable, antialias: bool): - dtype = jnp.result_type(scale, translation) + dtype = dtypes.result_type(scale, translation) inv_scale = 1. / scale # When downsampling the kernel should be scaled since we want to low pass # filter and interpolate, but when upsampling it should not be since we only @@ -65,8 +66,8 @@ def compute_weight_mat(input_size: core.DimSize, sample_f = ((jnp.arange(output_size, dtype=dtype) + 0.5) * inv_scale - translation * inv_scale - 0.5) x = ( - jnp.abs(sample_f[jnp.newaxis, :] - - jnp.arange(input_size, dtype=dtype)[:, jnp.newaxis]) / + jnp.abs(sample_f[np.newaxis, :] - + jnp.arange(input_size, dtype=dtype)[:, np.newaxis]) / kernel_scale) weights = kernel(x) @@ -81,7 +82,7 @@ def compute_weight_mat(input_size: core.DimSize, input_size_minus_0_5 = core.dimension_as_value(input_size) - 0.5 return jnp.where( jnp.logical_and(sample_f >= -0.5, - sample_f <= input_size_minus_0_5)[jnp.newaxis, :], weights, 0) + sample_f <= input_size_minus_0_5)[np.newaxis, :], weights, 0) def _scale_and_translate(x, output_shape: core.Shape, @@ -106,7 +107,7 @@ def _scale_and_translate(x, output_shape: core.Shape, contractions.append([d, len(output_shape) + i]) out_indices[d] = len(output_shape) + i contractions.append(out_indices) - return jnp.einsum(x, in_indices, *contractions, precision=precision) + return jnp_einsum.einsum(x, in_indices, *contractions, precision=precision) class ResizeMethod(enum.Enum): @@ -270,7 +271,7 @@ def _resize_nearest(x, output_shape: core.Shape): return x -@partial(jit, static_argnums=(1, 2, 3, 4)) +@api.jit(static_argnums=(1, 2, 3, 4)) def _resize(image, shape: core.Shape, method: str | ResizeMethod, antialias: bool, precision): if len(shape) != image.ndim: diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/__init__.py b/jax/_src/internal_test_util/export_back_compat_test_data/__init__.py new file mode 100644 index 000000000000..3da0dd1fa3ca --- /dev/null +++ b/jax/_src/internal_test_util/export_back_compat_test_data/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/annotate_data_placement.py b/jax/_src/internal_test_util/export_back_compat_test_data/annotate_data_placement.py new file mode 100644 index 000000000000..c05b7d039cf1 --- /dev/null +++ b/jax/_src/internal_test_util/export_back_compat_test_data/annotate_data_placement.py @@ -0,0 +1,254 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa + +import datetime +import numpy as np + +array = np.array +float32 = np.float32 + +data_2025_04_07_tpu = {} + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2025_04_07_tpu['gspmd'] = dict( + testdata_version=1, + platform='tpu', + custom_call_targets=['annotate_device_placement'], + serialized_date=datetime.date(2025, 4, 7), + inputs=(array([0.], dtype=float32), array([0.], dtype=float32)), + expected_outputs=(array([0.], dtype=float32),), + mlir_module_text=r""" +#loc1 = loc("x") +#loc2 = loc("y") +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<1xf32> {mhlo.memory_kind = "device", mhlo.sharding = "{maximal device=0}"} loc("x"), %arg1: tensor<1xf32> {mhlo.memory_kind = "pinned_host", mhlo.sharding = "{maximal device=0}"} loc("y")) -> (tensor<1xf32> {jax.result_info = "result", mhlo.memory_kind = "pinned_host", mhlo.sharding = "{maximal device=0}"}) { + %0 = stablehlo.add %arg0, %arg1 : tensor<1xf32> loc(#loc4) + %1 = stablehlo.custom_call @annotate_device_placement(%0) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "pinned_host"}} : (tensor<1xf32>) -> tensor<1xf32> loc(#loc) + return %1 : tensor<1xf32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":878:13) +#loc4 = loc("jit(func)/jit(main)/add"(#loc3)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.9.3\x00\x01\x1b\x05\x01\x05\x0b\x01\x03\x0b\x03\t\x0f\x13\x17\x1b\x03oQ\x0b\x01%\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0b\x13\x0b\x03-\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0f#\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x01\x05\x0b\x0f\x03\x07\x13\x1b\x07\x02\n\x02\x1f\x11\x03\x05\x03\x07\x07\t\x0b\x03\r\x03\x05\x0f\x11\x01\x00\x05\x11\x05\x13\x05\x15\x1d\x13\x01\x05\x17\x1d\x17\x01\x05\x19\x1d\x1b\x1d\x05\x1b\x17\x1f\xba\r\x1b\x05\x1d\x03\x03#E\x05\x1f\x03\x01\x1d!\x1d#\x1d%\x1d'\x03\x0515\r\x05'3)+\x1d)\r\x05'-)+#\x07\x03\x03;\r\x07=?'-)+\x1d+\x1d-\x1d/\x1d1\r\x03G-\x1d3\x0b\x03\x1d5\x1d7\x05\x03\x01\t\x01\x02\x02)\x03\x05\t\x11\x05\x05\x05\x03\x05\t\x04e\x05\x01Q\x01\x05\x01\x07\x04S\x03\x01\x05\x03P\x01\x03\x07\x04?\x03\t\x0f\x05\x0b\x11\x0b\x15\x00\x05\x06\x19\x03\x05\x05\x01\x03\x07G\x01!\x05\x03\x05\x03\x05\t\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00\x9a\x0695\x03-\x0f\x0b\x0f!\x0f\x19'\x1d#3i1\x05\x05\x13%)9\x15\x1f\x0f\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00add_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00x\x00y\x00jit(func)/jit(main)/add\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.frontend_attributes\x00mhlo.memory_kind\x00mhlo.sharding\x00{maximal device=0}\x00pinned_host\x00device\x00jax.result_info\x00result\x00main\x00public\x00_xla_buffer_placement\x00\x00annotate_device_placement\x00\x08'\x07\x05\x1f\x01\x0b/79AC\x11IKM%O%%%", + xla_call_module_version=9, + nr_devices=1, +) # End paste + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2025_04_07_tpu['shardy'] = dict( + testdata_version=1, + platform='tpu', + custom_call_targets=['annotate_device_placement', 'xla.sdy.FuncResultSharding'], + serialized_date=datetime.date(2025, 5, 28), + inputs=(array([0.], dtype=float32), array([0.], dtype=float32)), + expected_outputs=(array([0.], dtype=float32),), + mlir_module_text=r""" +#loc1 = loc("x") +#loc2 = loc("y") +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.frontend_attributes = {xla.sdy.meshes = "{mesh = #sdy.mesh<[\22a\22=1]>}"}, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<1xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\22a\22}]>"}, mhlo.memory_kind = "device", mhlo.sharding = "{devices=[1]<=[1]}"} loc("x"), %arg1: tensor<1xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\22a\22}]>"}, mhlo.memory_kind = "pinned_host", mhlo.sharding = "{devices=[1]<=[1]}"} loc("y")) -> (tensor<1xf32> {jax.result_info = "result", mhlo.memory_kind = "pinned_host", mhlo.sharding = "{devices=[1]<=[1]}"}) { + %0 = stablehlo.add %arg0, %arg1 : tensor<1xf32> loc(#loc4) + %1 = stablehlo.custom_call @annotate_device_placement(%0) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "pinned_host"}} : (tensor<1xf32>) -> tensor<1xf32> loc(#loc) + %2 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%1) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\22a\22}]>]>"}} : (tensor<1xf32>) -> tensor<1xf32> loc(#loc) + return %2 : tensor<1xf32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":801:13) +#loc4 = loc("jit(func)/jit(main)/add"(#loc3)) +""", + mlir_module_serialized=b'ML\xefR\rStableHLO_v1.10.3\x00\x01\x1b\x05\x01\x05\x0b\x01\x03\x0b\x03\t\x0f\x13\x17\x1b\x03\x85g\x0b\x01-\x07\x0b\x0f+\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0b\x13\x13\x03;\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x13#\x0b\x0b#\x0b\x0f#\x0b\x0b\x0b\x0b\x13\x0b\x0b\x13\x0b\x0b\x01\x05\x0b\x0f\x03\x07\x13\x1b\x07\x02\x9a\x02\x1f\x05\x0f\x11\x03\x05\x03\t\t\x0b\x03\r\x13\x05\x15\x05\x05\x11\x11\x01\x00\x03\x03\x0f\x11\x05\x13\x05\x15\x05\x17\x05\x19\x05\x1b\x1d\x1b\x01\x05\x1d\x1d\x1f\x01\x05\x1f\x1d#%\x05!\x17\'\x86\x0c\x1b\x05#\x03\x03\x03[\x03\x03\x03a\x03\x01\x1d%\x1d\'\x1d)\x1d+\x1d\x0f\r\x03;G\x1d-\x0b\x03\x1d/\x05\x03\x03\x05EK\r\x0779/I13\x1d1\x1d3\r\x0779/513#\x07\x03\x03Q\r\x07SU/513\x1d5\x1d7\x1d9\x1d;\r\x03]5\x1d=\x1d?\r\x03;c\x1dA\x1dC\x01\t\x01\x02\x02)\x03\x05\t\x11\x05\x05\x05\x03\x05\t\x04w\x05\x01Q\x01\x07\x01\x07\x04e\x03\x01\x05\x05P\x01\x03\x07\x04Q\x03\x0b\x13\x05\x0b\x19\x0b\x1d\x00\x07\x06!\x03\x05\x05\x01\x03\x03G\x01)\x05\x03\x05\x03\x05\x03G\x01+\x07\x03\x05\x03\x07\t\x04\x01\x03\t\x06\x03\x01\x05\x01\x006\tE7Y5-\x0f\x0b\x0f!\x0f=\x03#\x19\'\x1d#i1\x05\x05\x13%)9\x1f93\x15\x0f\x11\x1f\x0f\x0b\x11builtin\x00vhlo\x00module\x00custom_call_v1\x00func_v1\x00add_v1\x00return_v1\x00mhlo.frontend_attributes\x00jax.uses_shape_polymorphism\x00xla.sdy.meshes\x00{mesh = #sdy.mesh<["a"=1]>}\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00x\x00y\x00jit(func)/jit(main)/add\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.memory_kind\x00mhlo.sharding\x00{devices=[1]<=[1]}\x00pinned_host\x00xla.sdy.sharding\x00\x00#sdy.sharding<@mesh, [{"a"}]>\x00device\x00jax.result_info\x00result\x00main\x00public\x00_xla_buffer_placement\x00annotate_device_placement\x00#sdy.sharding_per_value<[<@mesh, [{"a"}]>]>\x00xla.sdy.FuncResultSharding\x00\x089\t\x05/\x01\x0bCMOWY\x11=?_-A---\x11=?e-A---', + xla_call_module_version=9, + nr_devices=1, +) # End paste + +data_2025_06_30_tpu = {} + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2025_06_30_tpu['gspmd'] = dict( + testdata_version=1, + platform='tpu', + custom_call_targets=['annotate_device_placement'], + serialized_date=datetime.date(2025, 7, 1), + inputs=(array([0.], dtype=float32), array([0.], dtype=float32)), + expected_outputs=(array([0.], dtype=float32),), + mlir_module_text=r""" +#loc1 = loc("x") +#loc2 = loc("y") +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<1xf32> {mhlo.memory_kind = "device", mhlo.sharding = "{replicated}"} loc("x"), %arg1: tensor<1xf32> {mhlo.memory_kind = "pinned_host", mhlo.sharding = "{replicated}"} loc("y")) -> (tensor<1xf32> {jax.result_info = "result", mhlo.memory_kind = "pinned_host", mhlo.sharding = "{replicated}"}) { + %0 = stablehlo.add %arg0, %arg1 : tensor<1xf32> loc(#loc5) + %1 = stablehlo.custom_call @annotate_device_placement(%0) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "pinned_host"}} : (tensor<1xf32>) -> tensor<1xf32> loc(#loc) + return %1 : tensor<1xf32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":819:13) +#loc4 = loc("jit(func)/add"(#loc3)) +#loc5 = loc("add:"(#loc4)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.10.9\x00\x01\x1b\x05\x01\x05\x0b\x01\x03\x0b\x03\t\x0f\x13\x17\x1b\x03sU\x0b\x01)\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0b\x13\x0b\x03-\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0f#\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x01\x05\x0b\x0f\x03\x07\x13\x1b\x07\x02\x1e\x02\x1f\x11\x03\x05\x03\x07\x07\t\x0b\x03\r\x03\x05\x0f\x11\x01\x00\x05\x11\x05\x13\x05\x15\x1d\x13\x01\x05\x17\x1d\x17\x01\x05\x19\x1d\x1b\x1d\x05\x1b\x1d\x1f!\x05\x1d\x17#\xce\x0c\x1b\x05\x1f\x03\x03'I\x05!\x03\x01\x1d#\x1d%\x1d'\x1d)\x03\x0559\r\x05+7-/\x1d+\r\x05+1-/#\x07\x03\x03?\r\x07AC+1-/\x1d-\x1d/\x1d1\x1d3\r\x03K1\x1d5\x0b\x03\x1d7\x1d9\x05\x03\x01\t\x01\x02\x02)\x03\x05\t\x11\x05\x05\x05\x03\x05\t\x04e\x05\x01Q\x01\x05\x01\x07\x04S\x03\x01\x05\x03P\x01\x03\x07\x04?\x03\t\x0f\x05\x0b\x11\x0b\x15\x00\x05\x06\x19\x03\x05\x05\x01\x03\x07G\x01%\x05\x03\x05\x03\x05\t\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00r\x06;5\x03-\x0f\x0b\x0f!\x0f\x19\x1b\x1d#3i\x1d\x0b\x05\x05\x13%)9\x15\x1f\x0f\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00add_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00x\x00y\x00add:\x00jit(func)/add\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.frontend_attributes\x00mhlo.memory_kind\x00mhlo.sharding\x00{replicated}\x00pinned_host\x00device\x00jax.result_info\x00result\x00main\x00public\x00_xla_buffer_placement\x00\x00annotate_device_placement\x00\x08'\x07\x05\x1f\x01\x0b3;=EG\x11MOQ)S)))", + xla_call_module_version=9, + nr_devices=1, +) # End paste + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2025_06_30_tpu['shardy'] = dict( + testdata_version=1, + platform='tpu', + custom_call_targets=['annotate_device_placement'], + serialized_date=datetime.date(2025, 6, 30), + inputs=(array([0.], dtype=float32), array([0.], dtype=float32)), + expected_outputs=(array([0.], dtype=float32),), + mlir_module_text=r""" +#loc1 = loc("x") +#loc2 = loc("y") +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + sdy.mesh @mesh = <["a"=1]> loc(#loc) + func.func public @main(%arg0: tensor<1xf32> {mhlo.memory_kind = "device", sdy.sharding = #sdy.sharding<@mesh, [{"a"}]>} loc("x"), %arg1: tensor<1xf32> {mhlo.memory_kind = "pinned_host", sdy.sharding = #sdy.sharding<@mesh, [{"a"}]>} loc("y")) -> (tensor<1xf32> {jax.result_info = "result", mhlo.memory_kind = "pinned_host", sdy.sharding = #sdy.sharding<@mesh, [{"a"}]>}) { + %0 = stablehlo.add %arg0, %arg1 : tensor<1xf32> loc(#loc5) + %1 = stablehlo.custom_call @annotate_device_placement(%0) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "pinned_host"}} : (tensor<1xf32>) -> tensor<1xf32> loc(#loc) + return %1 : tensor<1xf32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":810:13) +#loc4 = loc("jit(func)/add"(#loc3)) +#loc5 = loc("add:"(#loc4)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.10.9\x00\x01#\x07\x01\x05\t\r\x01\x03\x0f\x03\x03\x13\x05\t\x17\x1b\x1f#\x03\x83a\x0b\x01-\x07\x0f\x0b#\x0b\x0f\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0b\x13\x0b\x03\x0b\x17\x13\x0f\x17\x0f\x05+\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0f#\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x01\x05\x0b\x0f\x05\x07\x13\x1b\x07\x02v\x02\x1f\x11\x03\x05\x05\t\x03\x07\t\x0b\r\x03\x0f\x03\x05\x13\x11\x01\x00\x05\x15\x05\x17\x05\x19\t\x05\x1d\x17\x01\x05\x1b\x1d\x1b\x01\x05\x1d\x1d\x1f!\x05\x1f\x1d#%\x05!\x17'\xaa\x0c\x1b\x05#\x03\x03+U\x05%\r\x13\x033\x01\x05\x031\x01\x03'\x05\x0b\x035\x01\x01\t'\x01\x03\x01\x1d)\x1d+\x1d-\x03\x05AE\r\x059C;-\x1d/\r\x059=;-#\x07\x03\x03K\r\x07MO9=;-\x1d1\x1d3\x1d5\x1d7\r\x03W=\x1d9\x0b\x03\x1d;\x1d=\x05\x03\x01\t\x01\x02\x02)\x03\x05\t\x11\x05\x05\x05\x03\x05\t\x04m\x05\x01Q\x01\x07\x01\x07\x04[\x03\x01\t\x03@\x01\x03\x05P\x01\x05\x07\x04?\x03\t\x0f\x05\x0b\x15\x0b\x19\x00\x07\x06\x1d\x03\x05\x05\x01\x03\tG\x01)\x07\x03\x05\x03\x05\x0b\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00n\x06?5\x03-\x0f\x0b\x0f!\x0f\x19\x1b#\x053i\x1d\x0b\x05\x05\x13%)9\x15\x1f\x0f\x11\x0b\x0f\x0b\t\x11builtin\x00sdy\x00vhlo\x00module\x00mesh\x00func_v1\x00add_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00x\x00y\x00add:\x00jit(func)/add\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.frontend_attributes\x00a\x00mhlo.memory_kind\x00sdy.sharding\x00pinned_host\x00device\x00jax.result_info\x00result\x00main\x00public\x00_xla_buffer_placement\x00\x00annotate_device_placement\x00\x08-\t\x05#\x01\x05/\x05\x0b?GIQS\x11Y[]7_777", + xla_call_module_version=9, + nr_devices=1, +) # End paste + +data_2025_04_07_cuda = {} + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2025_04_07_cuda['gspmd'] = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['annotate_device_placement'], + serialized_date=datetime.date(2025, 4, 7), + inputs=(array([0.], dtype=float32), array([0.], dtype=float32)), + expected_outputs=(array([0.], dtype=float32),), + mlir_module_text=r""" +#loc1 = loc("x") +#loc2 = loc("y") +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<1xf32> {mhlo.memory_kind = "device", mhlo.sharding = "{maximal device=0}"} loc("x"), %arg1: tensor<1xf32> {mhlo.memory_kind = "pinned_host", mhlo.sharding = "{maximal device=0}"} loc("y")) -> (tensor<1xf32> {jax.result_info = "result", mhlo.memory_kind = "pinned_host", mhlo.sharding = "{maximal device=0}"}) { + %0 = stablehlo.add %arg0, %arg1 : tensor<1xf32> loc(#loc4) + %1 = stablehlo.custom_call @annotate_device_placement(%0) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "pinned_host"}} : (tensor<1xf32>) -> tensor<1xf32> loc(#loc) + return %1 : tensor<1xf32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":878:13) +#loc4 = loc("jit(func)/jit(main)/add"(#loc3)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.9.3\x00\x01\x1b\x05\x01\x05\x0b\x01\x03\x0b\x03\t\x0f\x13\x17\x1b\x03oQ\x0b\x01%\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0b\x13\x0b\x03-\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0f#\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x01\x05\x0b\x0f\x03\x07\x13\x1b\x07\x02\n\x02\x1f\x11\x03\x05\x03\x07\x07\t\x0b\x03\r\x03\x05\x0f\x11\x01\x00\x05\x11\x05\x13\x05\x15\x1d\x13\x01\x05\x17\x1d\x17\x01\x05\x19\x1d\x1b\x1d\x05\x1b\x17\x1f\xba\r\x1b\x05\x1d\x03\x03#E\x05\x1f\x03\x01\x1d!\x1d#\x1d%\x1d'\x03\x0515\r\x05'3)+\x1d)\r\x05'-)+#\x07\x03\x03;\r\x07=?'-)+\x1d+\x1d-\x1d/\x1d1\r\x03G-\x1d3\x0b\x03\x1d5\x1d7\x05\x03\x01\t\x01\x02\x02)\x03\x05\t\x11\x05\x05\x05\x03\x05\t\x04e\x05\x01Q\x01\x05\x01\x07\x04S\x03\x01\x05\x03P\x01\x03\x07\x04?\x03\t\x0f\x05\x0b\x11\x0b\x15\x00\x05\x06\x19\x03\x05\x05\x01\x03\x07G\x01!\x05\x03\x05\x03\x05\t\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00\x9a\x0695\x03-\x0f\x0b\x0f!\x0f\x19'\x1d#3i1\x05\x05\x13%)9\x15\x1f\x0f\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00add_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00x\x00y\x00jit(func)/jit(main)/add\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.frontend_attributes\x00mhlo.memory_kind\x00mhlo.sharding\x00{maximal device=0}\x00pinned_host\x00device\x00jax.result_info\x00result\x00main\x00public\x00_xla_buffer_placement\x00\x00annotate_device_placement\x00\x08'\x07\x05\x1f\x01\x0b/79AC\x11IKM%O%%%", + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2025_04_07_cuda['shardy'] = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['annotate_device_placement', 'xla.sdy.FuncResultSharding'], + serialized_date=datetime.date(2025, 5, 28), + inputs=(array([0.], dtype=float32), array([0.], dtype=float32)), + expected_outputs=(array([0.], dtype=float32),), + mlir_module_text=r""" +#loc1 = loc("x") +#loc2 = loc("y") +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.frontend_attributes = {xla.sdy.meshes = "{mesh = #sdy.mesh<[\22a\22=1]>}"}, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<1xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\22a\22}]>"}, mhlo.memory_kind = "device", mhlo.sharding = "{devices=[1]<=[1]}"} loc("x"), %arg1: tensor<1xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\22a\22}]>"}, mhlo.memory_kind = "pinned_host", mhlo.sharding = "{devices=[1]<=[1]}"} loc("y")) -> (tensor<1xf32> {jax.result_info = "result", mhlo.memory_kind = "pinned_host", mhlo.sharding = "{devices=[1]<=[1]}"}) { + %0 = stablehlo.add %arg0, %arg1 : tensor<1xf32> loc(#loc4) + %1 = stablehlo.custom_call @annotate_device_placement(%0) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "pinned_host"}} : (tensor<1xf32>) -> tensor<1xf32> loc(#loc) + %2 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%1) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\22a\22}]>]>"}} : (tensor<1xf32>) -> tensor<1xf32> loc(#loc) + return %2 : tensor<1xf32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":806:13) +#loc4 = loc("jit(func)/jit(main)/add"(#loc3)) +""", + mlir_module_serialized=b'ML\xefR\rStableHLO_v1.10.3\x00\x01\x1b\x05\x01\x05\x0b\x01\x03\x0b\x03\t\x0f\x13\x17\x1b\x03\x85g\x0b\x01-\x07\x0b\x0f+\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0b\x13\x13\x03;\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x13#\x0b\x0b#\x0b\x0f#\x0b\x0b\x0b\x0b\x13\x0b\x0b\x13\x0b\x0b\x01\x05\x0b\x0f\x03\x07\x13\x1b\x07\x02\x9a\x02\x1f\x05\x0f\x11\x03\x05\x03\t\t\x0b\x03\r\x13\x05\x15\x05\x05\x11\x11\x01\x00\x03\x03\x0f\x11\x05\x13\x05\x15\x05\x17\x05\x19\x05\x1b\x1d\x1b\x01\x05\x1d\x1d\x1f\x01\x05\x1f\x1d#%\x05!\x17\'\x9a\x0c\x1b\x05#\x03\x03\x03[\x03\x03\x03a\x03\x01\x1d%\x1d\'\x1d)\x1d+\x1d\x0f\r\x03;G\x1d-\x0b\x03\x1d/\x05\x03\x03\x05EK\r\x0779/I13\x1d1\x1d3\r\x0779/513#\x07\x03\x03Q\r\x07SU/513\x1d5\x1d7\x1d9\x1d;\r\x03]5\x1d=\x1d?\r\x03;c\x1dA\x1dC\x01\t\x01\x02\x02)\x03\x05\t\x11\x05\x05\x05\x03\x05\t\x04w\x05\x01Q\x01\x07\x01\x07\x04e\x03\x01\x05\x05P\x01\x03\x07\x04Q\x03\x0b\x13\x05\x0b\x19\x0b\x1d\x00\x07\x06!\x03\x05\x05\x01\x03\x03G\x01)\x05\x03\x05\x03\x05\x03G\x01+\x07\x03\x05\x03\x07\t\x04\x01\x03\t\x06\x03\x01\x05\x01\x006\tE7Y5-\x0f\x0b\x0f!\x0f=\x03#\x19\'\x1d#i1\x05\x05\x13%)9\x1f93\x15\x0f\x11\x1f\x0f\x0b\x11builtin\x00vhlo\x00module\x00custom_call_v1\x00func_v1\x00add_v1\x00return_v1\x00mhlo.frontend_attributes\x00jax.uses_shape_polymorphism\x00xla.sdy.meshes\x00{mesh = #sdy.mesh<["a"=1]>}\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00x\x00y\x00jit(func)/jit(main)/add\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.memory_kind\x00mhlo.sharding\x00{devices=[1]<=[1]}\x00pinned_host\x00xla.sdy.sharding\x00\x00#sdy.sharding<@mesh, [{"a"}]>\x00device\x00jax.result_info\x00result\x00main\x00public\x00_xla_buffer_placement\x00annotate_device_placement\x00#sdy.sharding_per_value<[<@mesh, [{"a"}]>]>\x00xla.sdy.FuncResultSharding\x00\x089\t\x05/\x01\x0bCMOWY\x11=?_-A---\x11=?e-A---', + xla_call_module_version=9, + nr_devices=1, +) # End paste + +data_2025_06_30_cuda = {} + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2025_06_30_cuda['gspmd'] = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['annotate_device_placement'], + serialized_date=datetime.date(2025, 4, 7), + inputs=(array([0.], dtype=float32), array([0.], dtype=float32)), + expected_outputs=(array([0.], dtype=float32),), + mlir_module_text=r""" +#loc1 = loc("x") +#loc2 = loc("y") +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<1xf32> {mhlo.memory_kind = "device", mhlo.sharding = "{maximal device=0}"} loc("x"), %arg1: tensor<1xf32> {mhlo.memory_kind = "pinned_host", mhlo.sharding = "{maximal device=0}"} loc("y")) -> (tensor<1xf32> {jax.result_info = "result", mhlo.memory_kind = "pinned_host", mhlo.sharding = "{maximal device=0}"}) { + %0 = stablehlo.add %arg0, %arg1 : tensor<1xf32> loc(#loc4) + %1 = stablehlo.custom_call @annotate_device_placement(%0) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "pinned_host"}} : (tensor<1xf32>) -> tensor<1xf32> loc(#loc) + return %1 : tensor<1xf32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":878:13) +#loc4 = loc("jit(func)/jit(main)/add"(#loc3)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.9.3\x00\x01\x1b\x05\x01\x05\x0b\x01\x03\x0b\x03\t\x0f\x13\x17\x1b\x03oQ\x0b\x01%\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0b\x13\x0b\x03-\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0f#\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x01\x05\x0b\x0f\x03\x07\x13\x1b\x07\x02\n\x02\x1f\x11\x03\x05\x03\x07\x07\t\x0b\x03\r\x03\x05\x0f\x11\x01\x00\x05\x11\x05\x13\x05\x15\x1d\x13\x01\x05\x17\x1d\x17\x01\x05\x19\x1d\x1b\x1d\x05\x1b\x17\x1f\xba\r\x1b\x05\x1d\x03\x03#E\x05\x1f\x03\x01\x1d!\x1d#\x1d%\x1d'\x03\x0515\r\x05'3)+\x1d)\r\x05'-)+#\x07\x03\x03;\r\x07=?'-)+\x1d+\x1d-\x1d/\x1d1\r\x03G-\x1d3\x0b\x03\x1d5\x1d7\x05\x03\x01\t\x01\x02\x02)\x03\x05\t\x11\x05\x05\x05\x03\x05\t\x04e\x05\x01Q\x01\x05\x01\x07\x04S\x03\x01\x05\x03P\x01\x03\x07\x04?\x03\t\x0f\x05\x0b\x11\x0b\x15\x00\x05\x06\x19\x03\x05\x05\x01\x03\x07G\x01!\x05\x03\x05\x03\x05\t\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00\x9a\x0695\x03-\x0f\x0b\x0f!\x0f\x19'\x1d#3i1\x05\x05\x13%)9\x15\x1f\x0f\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00add_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00x\x00y\x00jit(func)/jit(main)/add\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.frontend_attributes\x00mhlo.memory_kind\x00mhlo.sharding\x00{maximal device=0}\x00pinned_host\x00device\x00jax.result_info\x00result\x00main\x00public\x00_xla_buffer_placement\x00\x00annotate_device_placement\x00\x08'\x07\x05\x1f\x01\x0b/79AC\x11IKM%O%%%", + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2025_06_30_cuda['shardy'] = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['annotate_device_placement'], + serialized_date=datetime.date(2025, 6, 30), + inputs=(array([0.], dtype=float32), array([0.], dtype=float32)), + expected_outputs=(array([0.], dtype=float32),), + mlir_module_text=r""" +#loc1 = loc("x") +#loc2 = loc("y") +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + sdy.mesh @mesh = <["a"=1]> loc(#loc) + func.func public @main(%arg0: tensor<1xf32> {mhlo.memory_kind = "device", sdy.sharding = #sdy.sharding<@mesh, [{"a"}]>} loc("x"), %arg1: tensor<1xf32> {mhlo.memory_kind = "pinned_host", sdy.sharding = #sdy.sharding<@mesh, [{"a"}]>} loc("y")) -> (tensor<1xf32> {jax.result_info = "result", mhlo.memory_kind = "pinned_host", sdy.sharding = #sdy.sharding<@mesh, [{"a"}]>}) { + %0 = stablehlo.add %arg0, %arg1 : tensor<1xf32> loc(#loc5) + %1 = stablehlo.custom_call @annotate_device_placement(%0) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "pinned_host"}} : (tensor<1xf32>) -> tensor<1xf32> loc(#loc) + return %1 : tensor<1xf32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":810:13) +#loc4 = loc("jit(func)/add"(#loc3)) +#loc5 = loc("add:"(#loc4)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.10.9\x00\x01#\x07\x01\x05\t\r\x01\x03\x0f\x03\x03\x13\x05\t\x17\x1b\x1f#\x03\x83a\x0b\x01-\x07\x0f\x0b#\x0b\x0f\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0b\x13\x0b\x03\x0b\x17\x13\x0f\x17\x0f\x05+\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0f#\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x01\x05\x0b\x0f\x05\x07\x13\x1b\x07\x02v\x02\x1f\x11\x03\x05\x05\t\x03\x07\t\x0b\r\x03\x0f\x03\x05\x13\x11\x01\x00\x05\x15\x05\x17\x05\x19\t\x05\x1d\x17\x01\x05\x1b\x1d\x1b\x01\x05\x1d\x1d\x1f!\x05\x1f\x1d#%\x05!\x17'\xaa\x0c\x1b\x05#\x03\x03+U\x05%\r\x13\x033\x01\x05\x031\x01\x03'\x05\x0b\x035\x01\x01\t'\x01\x03\x01\x1d)\x1d+\x1d-\x03\x05AE\r\x059C;-\x1d/\r\x059=;-#\x07\x03\x03K\r\x07MO9=;-\x1d1\x1d3\x1d5\x1d7\r\x03W=\x1d9\x0b\x03\x1d;\x1d=\x05\x03\x01\t\x01\x02\x02)\x03\x05\t\x11\x05\x05\x05\x03\x05\t\x04m\x05\x01Q\x01\x07\x01\x07\x04[\x03\x01\t\x03@\x01\x03\x05P\x01\x05\x07\x04?\x03\t\x0f\x05\x0b\x15\x0b\x19\x00\x07\x06\x1d\x03\x05\x05\x01\x03\tG\x01)\x07\x03\x05\x03\x05\x0b\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00n\x06?5\x03-\x0f\x0b\x0f!\x0f\x19\x1b#\x053i\x1d\x0b\x05\x05\x13%)9\x15\x1f\x0f\x11\x0b\x0f\x0b\t\x11builtin\x00sdy\x00vhlo\x00module\x00mesh\x00func_v1\x00add_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00x\x00y\x00add:\x00jit(func)/add\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.frontend_attributes\x00a\x00mhlo.memory_kind\x00sdy.sharding\x00pinned_host\x00device\x00jax.result_info\x00result\x00main\x00public\x00_xla_buffer_placement\x00\x00annotate_device_placement\x00\x08-\t\x05#\x01\x05/\x05\x0b?GIQS\x11Y[]7_777", + xla_call_module_version=9, + nr_devices=1, +) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_cholesky_lapack_potrf.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_cholesky_lapack_potrf.py index eb4143615da6..8188a7ffa73c 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_cholesky_lapack_potrf.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_cholesky_lapack_potrf.py @@ -15,347 +15,14 @@ # ruff: noqa import datetime -from numpy import array, float32, complex64 +import numpy as np -data_2023_06_19 = {} - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["f32"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_spotrf'], - serialized_date=datetime.date(2023, 6, 19), - inputs=(array([[ 24.343887, 13.603932, 20.50489 , 12.063956], - [ 13.603932, 58.879757, -31.84056 , 16.328012], - [ 20.50489 , -31.84056 , 66.890755, -9.92216 ], - [ 12.063956, 16.328012, -9.92216 , 23.640734]], dtype=float32),), - expected_outputs=(array([[ 4.9339523, 0. , 0. , 0. ], - [ 2.7572079, 7.1608353, 0. , 0. ], - [ 4.155875 , -6.0466647, 3.6134892, 0. ], - [ 2.4450896, 1.3387254, -3.3177967, 2.2050648]], dtype=float32),), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_cholesky attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<4x4xf32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<4x4xf32> {jax.result_info = ""}) { - %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<4x4xf32>) -> tensor<4x4xf32> loc(#loc2) - %1 = stablehlo.add %arg0, %0 : tensor<4x4xf32> loc(#loc3) - %2 = stablehlo.constant dense<2.000000e+00> : tensor loc(#loc) - %3 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor) -> tensor<4x4xf32> loc(#loc4) - %4 = stablehlo.divide %1, %3 : tensor<4x4xf32> loc(#loc4) - %5 = stablehlo.constant dense<1> : tensor loc(#loc5) - %6 = stablehlo.constant dense<1> : tensor loc(#loc5) - %7 = stablehlo.constant dense<4> : tensor loc(#loc5) - %8:2 = stablehlo.custom_call @lapack_spotrf(%5, %6, %7, %4) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor<4x4xf32>) -> (tensor<4x4xf32>, tensor) loc(#loc5) - %9 = stablehlo.constant dense<0> : tensor loc(#loc5) - %10 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor loc(#loc5) - %11 = stablehlo.compare EQ, %8#1, %10, SIGNED : (tensor, tensor) -> tensor loc(#loc5) - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %13 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc5) - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<4x4xf32> loc(#loc5) - %15 = stablehlo.broadcast_in_dim %12, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) - %16 = stablehlo.select %15, %8#0, %14 : tensor<4x4xi1>, tensor<4x4xf32> loc(#loc5) - %17 = call @tril(%16) : (tensor<4x4xf32>) -> tensor<4x4xf32> loc(#loc6) - return %17 : tensor<4x4xf32> loc(#loc) - } loc(#loc) - func.func private @tril(%arg0: tensor<4x4xf32> loc(unknown)) -> tensor<4x4xf32> { - %0 = stablehlo.iota dim = 0 : tensor<4x4xi32> loc(#loc7) - %1 = stablehlo.constant dense<0> : tensor loc(#loc6) - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<4x4xi32> loc(#loc8) - %3 = stablehlo.add %0, %2 : tensor<4x4xi32> loc(#loc8) - %4 = stablehlo.iota dim = 1 : tensor<4x4xi32> loc(#loc9) - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<4x4xi32>, tensor<4x4xi32>) -> tensor<4x4xi1> loc(#loc10) - %6 = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc6) - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<4x4xf32> loc(#loc11) - %8 = stablehlo.select %5, %arg0, %7 : tensor<4x4xi1>, tensor<4x4xf32> loc(#loc12) - return %8 : tensor<4x4xf32> loc(#loc6) - } loc(#loc6) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":292:0) -#loc2 = loc("jit(cholesky)/jit(main)/transpose[permutation=(1, 0)]"(#loc1)) -#loc3 = loc("jit(cholesky)/jit(main)/add"(#loc1)) -#loc4 = loc("jit(cholesky)/jit(main)/div"(#loc1)) -#loc5 = loc("jit(cholesky)/jit(main)/cholesky"(#loc1)) -#loc6 = loc("jit(cholesky)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]"(#loc1)) -#loc7 = loc("jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=0]"(#loc1)) -#loc8 = loc("jit(cholesky)/jit(main)/jit(tril)/add"(#loc1)) -#loc9 = loc("jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=1]"(#loc1)) -#loc10 = loc("jit(cholesky)/jit(main)/jit(tril)/ge"(#loc1)) -#loc11 = loc("jit(cholesky)/jit(main)/jit(tril)/broadcast_in_dim[shape=(4, 4) broadcast_dimensions=()]"(#loc1)) -#loc12 = loc("jit(cholesky)/jit(main)/jit(tril)/select_n"(#loc1)) -""", - mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01)\x05\x01\x03\x01\x03\x05\x03\x19\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x03"\x02\xd9%\x01\x87\x0f\x17\x07\x0b\x13\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x0f\x13#\x0b\x0b\x0b33\x0b\x0b\x13\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x1b\x13\x13\x13\x0b\x03S\x0f\x0b\x0b\x0f\x0b\x0bO\x0f\x1b\x0b\x0b\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b\x1fO\x1f\x1f\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17\x13\x0b\x1fO\x01\x03\x0f\x03#\x17\x0f\x0f\x17\x07\x07\x07\x07\x17\x13\x07\x17\x13\x13\x13\x0f\x17\x02J\x07\x1dg\x03\x177\x92\x04\x01\x1f\x05\x1f\x03\x03\x1d\xb3\x1d5\x03\x05!\x11\x01\x05\x05#\x05%\x05\'\x05)\x05+\x03\x03\x07\xb1\x05-\x1d?\x03\x05/\x051\x1de\x03\x03\x03\x07\xbf\x03\x07+\x0f-\x0f\r/\x053\x055\x057\x03\x0b\x11\x95\x13\x89\x15\xa1\r\xa7\x17\xa9\x03\x0b\x11\x8d\x13\x89\x15\x8d\r\x8f\x17\xad\x059\x05;\x03\x03\x19\xaf\x1d=\x03\x05=\x05?\x03\x03\x19\xb5\x1dE\x03\x05A\x03\x05!\x91#\xb7\x1dK\x03\x05C\x03\x03\x07\xb9\x1dQ\x03\x05E\x1dU\x03\x05G\x03\x03Y\xbb\x05I\x1d]\x03\x05K\x1da\x03\x05M\x03\x03\x07\xbd\x05O\x05Q\x03\x03\x07\xc1\x03\x11m\xc3o\x8bq\xc5s\xc7u\xc9w\xcby\xcd{\xd1\x05S\x05U\x05W\x05Y\x05[\x05]\x05_\x05a\x03\x05!\x91#\xd3\x03\x03\x07\xd5\x03\x03\x1d\xd7\x03\x03\x85\x8f\x05c\x1f\x1d\x01#\x19\x1de\x03\x03\xab\x1dg\t\x07\x1f\x1f!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03\x97\r\x05\x99\x9b\x9d\x9f\x1di\x1dk\x1dm\x1do\x03\x03\xa3\r\x03\xa5\x8b\x1dq\x1ds\x1du\r\x01\x1dw\x13\x0b\x01\x1f\x05\t\x00\x00\x00\x00\x1f\x1b\x01\x13\x0b\x05\x07\x05\x1f\x07\t\x00\x00\x00\x00\x1f\x15!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07\t\x00\x00\x00@\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x0b\x05\x1dy\x03\x01\x05\x01\x03\t\x87\x87\x87\x93\x03\x03\xcf\x15\x03\x01\r\x01\x03\x05\x93\x87\x07\x01\x1f\x07\t\x00\x00\xc0\x7f\x1f\x15!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\x0f)\x01\x11)\x01\x0f)\x05\x11\x11\x11\x1d\x01\t\x1b)\x05\x11\x11\r)\x03\t\x0b\x13\x11\x03\x03\x03\x03)\x03\x01\x0b)\x03\x01\x17)\x03\t\x17)\x01\r)\x05\x05\x05\r\x04\xd6\x03\x05\x01\x11\x05)\x07\x03\x01\t\x07\x11\x051\x05\x03)O\x03\x03\x05\x13\x07[W\x03\x03\x03\x01\x0b\x06_\x03\x03\x05\x01\x03\x03\x03\x05c\x03\x07\x05\x07%\t\x03\x03\x03\x07\x15\x06%\x03\x03\x05\x05\t\x03\x03\x01\'\x03\x05\x03\x03\x01\'\x03\x05\x03\x03\x01i\x03\x05\x17\x07\x01k\x05\x03\x05\t\r\x0f\x11\x0b\x03\x03\x01\x1b\x03\x05\x05\x07\x01\t\x03\x05\x03\x17\r\x07\x01}\x03!\x05\x15\x19\x05\x07\x01\t\x03#\x03\x1b\x03\x03\x01\x7f\x03\x07\x05\x07\x01\t\x03\x03\x03\x1f\x05\x07\x01\x81\x03\x13\x03\x1d\x0f\x06\x01\x03\x03\x07#\x13!\x19\x07\x0b\x83\x03\x03\x03%\x11\x04\x05\x03\'\x07\x11\x0b3\x05\x03\x15+\x03\x03\x05\t\x03;9\x03\t\x03\x03\x0b\x1b\x03\x05\x05\x07\x1f\t\x03\t\x03\x05\x0b\x06\x1f\x03\t\x05\x03\x07\t\x03CA\x03\t\r\x07IG\x03\x13\x05\t\x0b\x03\x03\x0bM\x03\x07\x05\x07O\t\x03\x03\x03\x0f\x0f\x06S\x03\x03\x07\r\x01\x11\x11\x04\x0b\x03\x13\x06\x03\x01\x05\x01\x00\n\x16{\x1d\x11\x0f\x0b!\x1b\x1d\x05\x1b\x0b\x03\x0f\x1f/!!)#\x1f\x19C99m\x19W\xb3K\x9bM\x9b\x97\xd2\x02\x1b%)+\x1b+\x1f\x1f\x15\x1d\x15\x13\r\x11\x1f\x15\x1b\x15\x15\x17\x0f\x11\x11)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00func_v1\x00iota_v1\x00add_v1\x00compare_v1\x00select_v1\x00return_v1\x00transpose_v1\x00divide_v1\x00custom_call_v1\x00call_v1\x00value\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_cholesky\x00jit(cholesky)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=0]\x00jit(cholesky)/jit(main)/jit(tril)/add\x00jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=1]\x00jit(cholesky)/jit(main)/jit(tril)/ge\x00jit(cholesky)/jit(main)/jit(tril)/broadcast_in_dim[shape=(4, 4) broadcast_dimensions=()]\x00jit(cholesky)/jit(main)/jit(tril)/select_n\x00permutation\x00jit(cholesky)/jit(main)/transpose[permutation=(1, 0)]\x00jit(cholesky)/jit(main)/add\x00jit(cholesky)/jit(main)/div\x00jit(cholesky)/jit(main)/cholesky\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00callee\x00\x00tril\x00jax.arg_info\x00x\x00mhlo.sharding\x00{replicated}\x00jax.result_info\x00main\x00public\x00private\x00lapack_spotrf\x00', - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["f64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_dpotrf'], - serialized_date=datetime.date(2023, 6, 19), - inputs=(array([[ 23.022171138130666 , -16.79765603341739 , 0.9133449305189146, - -25.36636199966769 ], - [-16.79765603341739 , 31.655770252600092 , -1.5189878284433445, - 20.0344758332268 ], - [ 0.9133449305189146, -1.5189878284433445, 10.940134497877208 , - 8.169020034607513 ], - [-25.36636199966769 , 20.0344758332268 , 8.169020034607513 , - 37.054603917509596 ]]),), - expected_outputs=(array([[ 4.7981424674691215 , 0. , 0. , - 0. ], - [-3.500866459740129 , 4.404509539513645 , 0. , - 0. ], - [ 0.19035385812557523, -0.1935707899825621 , 3.2964268922333835 , - 0. ], - [-5.286704630312426 , 0.3465604732420997 , 2.8037778311164425 , - 1.060228174247855 ]]),), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_cholesky attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<4x4xf64> {jax.arg_info = "x", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<4x4xf64> {jax.result_info = ""}) { - %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<4x4xf64>) -> tensor<4x4xf64> loc(#loc2) - %1 = stablehlo.add %arg0, %0 : tensor<4x4xf64> loc(#loc3) - %2 = stablehlo.constant dense<2.000000e+00> : tensor loc(#loc) - %3 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor) -> tensor<4x4xf64> loc(#loc4) - %4 = stablehlo.divide %1, %3 : tensor<4x4xf64> loc(#loc4) - %5 = stablehlo.constant dense<1> : tensor loc(#loc5) - %6 = stablehlo.constant dense<1> : tensor loc(#loc5) - %7 = stablehlo.constant dense<4> : tensor loc(#loc5) - %8:2 = stablehlo.custom_call @lapack_dpotrf(%5, %6, %7, %4) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor<4x4xf64>) -> (tensor<4x4xf64>, tensor) loc(#loc5) - %9 = stablehlo.constant dense<0> : tensor loc(#loc5) - %10 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor loc(#loc5) - %11 = stablehlo.compare EQ, %8#1, %10, SIGNED : (tensor, tensor) -> tensor loc(#loc5) - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %13 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc5) - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<4x4xf64> loc(#loc5) - %15 = stablehlo.broadcast_in_dim %12, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) - %16 = stablehlo.select %15, %8#0, %14 : tensor<4x4xi1>, tensor<4x4xf64> loc(#loc5) - %17 = call @tril(%16) : (tensor<4x4xf64>) -> tensor<4x4xf64> loc(#loc6) - return %17 : tensor<4x4xf64> loc(#loc) - } loc(#loc) - func.func private @tril(%arg0: tensor<4x4xf64> loc(unknown)) -> tensor<4x4xf64> { - %0 = stablehlo.iota dim = 0 : tensor<4x4xi32> loc(#loc7) - %1 = stablehlo.constant dense<0> : tensor loc(#loc6) - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<4x4xi32> loc(#loc8) - %3 = stablehlo.add %0, %2 : tensor<4x4xi32> loc(#loc8) - %4 = stablehlo.iota dim = 1 : tensor<4x4xi32> loc(#loc9) - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<4x4xi32>, tensor<4x4xi32>) -> tensor<4x4xi1> loc(#loc10) - %6 = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc6) - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<4x4xf64> loc(#loc11) - %8 = stablehlo.select %5, %arg0, %7 : tensor<4x4xi1>, tensor<4x4xf64> loc(#loc12) - return %8 : tensor<4x4xf64> loc(#loc6) - } loc(#loc6) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":292:0) -#loc2 = loc("jit(cholesky)/jit(main)/transpose[permutation=(1, 0)]"(#loc1)) -#loc3 = loc("jit(cholesky)/jit(main)/add"(#loc1)) -#loc4 = loc("jit(cholesky)/jit(main)/div"(#loc1)) -#loc5 = loc("jit(cholesky)/jit(main)/cholesky"(#loc1)) -#loc6 = loc("jit(cholesky)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]"(#loc1)) -#loc7 = loc("jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=0]"(#loc1)) -#loc8 = loc("jit(cholesky)/jit(main)/jit(tril)/add"(#loc1)) -#loc9 = loc("jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=1]"(#loc1)) -#loc10 = loc("jit(cholesky)/jit(main)/jit(tril)/ge"(#loc1)) -#loc11 = loc("jit(cholesky)/jit(main)/jit(tril)/broadcast_in_dim[shape=(4, 4) broadcast_dimensions=()]"(#loc1)) -#loc12 = loc("jit(cholesky)/jit(main)/jit(tril)/select_n"(#loc1)) -""", - mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01)\x05\x01\x03\x01\x03\x05\x03\x19\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x03"\x02\xd9%\x01\x87\x0f\x17\x07\x0b\x13\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x0f\x13#\x0b\x0b\x0b33\x0b\x0b\x13\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x1b\x13\x13\x13\x0b\x03S\x0f\x0b\x0b\x0f\x0b\x0bO\x0f\x1b\x0b\x0b\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b/O/\x1f\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17\x13\x0b/O\x01\x03\x0f\x03#\x17\x0f\x0f\x17\x07\x07\x07\x07\x17\x13\x07\x17\x13\x13\x13\x0f\x17\x02z\x07\x1dg\x03\x177\x92\x04\x01\x1f\x05\x1f\x03\x03\x1d\xb3\x1d5\x03\x05!\x11\x01\x05\x05#\x05%\x05\'\x05)\x05+\x03\x03\x07\xb1\x05-\x1d?\x03\x05/\x051\x1de\x03\x03\x03\x07\xbf\x03\x07+\x0f-\x0f\r/\x053\x055\x057\x03\x0b\x11\x95\x13\x89\x15\xa1\r\xa7\x17\xa9\x03\x0b\x11\x8d\x13\x89\x15\x8d\r\x8f\x17\xad\x059\x05;\x03\x03\x19\xaf\x1d=\x03\x05=\x05?\x03\x03\x19\xb5\x1dE\x03\x05A\x03\x05!\x91#\xb7\x1dK\x03\x05C\x03\x03\x07\xb9\x1dQ\x03\x05E\x1dU\x03\x05G\x03\x03Y\xbb\x05I\x1d]\x03\x05K\x1da\x03\x05M\x03\x03\x07\xbd\x05O\x05Q\x03\x03\x07\xc1\x03\x11m\xc3o\x8bq\xc5s\xc7u\xc9w\xcby\xcd{\xd1\x05S\x05U\x05W\x05Y\x05[\x05]\x05_\x05a\x03\x05!\x91#\xd3\x03\x03\x07\xd5\x03\x03\x1d\xd7\x03\x03\x85\x8f\x05c\x1f\x1d\x01#\x19\x1de\x03\x03\xab\x1dg\t\x07\x1f\x1f!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03\x97\r\x05\x99\x9b\x9d\x9f\x1di\x1dk\x1dm\x1do\x03\x03\xa3\r\x03\xa5\x8b\x1dq\x1ds\x1du\r\x01\x1dw\x13\x0b\x01\x1f\x05\t\x00\x00\x00\x00\x1f\x1b\x01\x13\x0b\x05\x07\x05\x1f\x07\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x15!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07\x11\x00\x00\x00\x00\x00\x00\x00@\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x0b\x05\x1dy\x03\x01\x05\x01\x03\t\x87\x87\x87\x93\x03\x03\xcf\x15\x03\x01\r\x01\x03\x05\x93\x87\x07\x01\x1f\x07\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x15!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\x0f)\x01\x11)\x01\x0f)\x05\x11\x11\x11\x1d\x01\x0b\x1b)\x05\x11\x11\r)\x03\t\x0b\x13\x11\x03\x03\x03\x03)\x03\x01\x0b)\x03\x01\x17)\x03\t\x17)\x01\r)\x05\x05\x05\r\x04\xd6\x03\x05\x01\x11\x05)\x07\x03\x01\t\x07\x11\x051\x05\x03)O\x03\x03\x05\x13\x07[W\x03\x03\x03\x01\x0b\x06_\x03\x03\x05\x01\x03\x03\x03\x05c\x03\x07\x05\x07%\t\x03\x03\x03\x07\x15\x06%\x03\x03\x05\x05\t\x03\x03\x01\'\x03\x05\x03\x03\x01\'\x03\x05\x03\x03\x01i\x03\x05\x17\x07\x01k\x05\x03\x05\t\r\x0f\x11\x0b\x03\x03\x01\x1b\x03\x05\x05\x07\x01\t\x03\x05\x03\x17\r\x07\x01}\x03!\x05\x15\x19\x05\x07\x01\t\x03#\x03\x1b\x03\x03\x01\x7f\x03\x07\x05\x07\x01\t\x03\x03\x03\x1f\x05\x07\x01\x81\x03\x13\x03\x1d\x0f\x06\x01\x03\x03\x07#\x13!\x19\x07\x0b\x83\x03\x03\x03%\x11\x04\x05\x03\'\x07\x11\x0b3\x05\x03\x15+\x03\x03\x05\t\x03;9\x03\t\x03\x03\x0b\x1b\x03\x05\x05\x07\x1f\t\x03\t\x03\x05\x0b\x06\x1f\x03\t\x05\x03\x07\t\x03CA\x03\t\r\x07IG\x03\x13\x05\t\x0b\x03\x03\x0bM\x03\x07\x05\x07O\t\x03\x03\x03\x0f\x0f\x06S\x03\x03\x07\r\x01\x11\x11\x04\x0b\x03\x13\x06\x03\x01\x05\x01\x00\n\x16{\x1d\x11\x0f\x0b!\x1b\x1d\x05\x1b\x0b\x03\x0f\x1f/!!)#\x1f\x19C99m\x19W\xb3K\x9bM\x9b\x97\xd2\x02\x1b%)+\x1b+\x1f\x1f\x15\x1d\x15\x13\r\x11\x1f\x15\x1b\x15\x15\x17\x0f\x11\x11)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00func_v1\x00iota_v1\x00add_v1\x00compare_v1\x00select_v1\x00return_v1\x00transpose_v1\x00divide_v1\x00custom_call_v1\x00call_v1\x00value\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_cholesky\x00jit(cholesky)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=0]\x00jit(cholesky)/jit(main)/jit(tril)/add\x00jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=1]\x00jit(cholesky)/jit(main)/jit(tril)/ge\x00jit(cholesky)/jit(main)/jit(tril)/broadcast_in_dim[shape=(4, 4) broadcast_dimensions=()]\x00jit(cholesky)/jit(main)/jit(tril)/select_n\x00permutation\x00jit(cholesky)/jit(main)/transpose[permutation=(1, 0)]\x00jit(cholesky)/jit(main)/add\x00jit(cholesky)/jit(main)/div\x00jit(cholesky)/jit(main)/cholesky\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00callee\x00\x00tril\x00jax.arg_info\x00x\x00mhlo.sharding\x00{replicated}\x00jax.result_info\x00main\x00public\x00private\x00lapack_dpotrf\x00', - xla_call_module_version=6, -) # End paste - - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["c64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_cpotrf'], - serialized_date=datetime.date(2023, 6, 19), - inputs=(array([[ 38.089394 +6.36582342e-09j, 3.3509154+3.13455486e+01j, - -0.5972489-3.80308151e+01j, -19.04205 +1.22770605e+01j], - [ 3.3509154-3.13455486e+01j, 73.875755 +4.06565448e-09j, - -12.427276 -1.23379612e+01j, 41.542507 -9.63993359e+00j], - [ -0.5972489+3.80308151e+01j, -12.427276 +1.23379612e+01j, - 73.04141 -4.18667753e-07j, 8.193126 -2.60565052e+01j], - [-19.04205 -1.22770605e+01j, 41.542507 +9.63993359e+00j, - 8.193126 +2.60565052e+01j, 52.977036 -1.09952367e-07j]], - dtype=complex64),), - expected_outputs=(array([[ 6.1716604 +0.j , 0. +0.j , - 0. +0.j , 0. +0.j ], - [ 0.542952 -5.078949j , 6.912687 +0.j , - 0. +0.j , 0. +0.j ], - [-0.09677281+6.162169j , 2.7373738 +1.3719271j, - 5.0679703 +0.j , 0. +0.j ], - [-3.0854013 -1.9892638j, 4.7903748 +3.8177056j, - 0.3555784 +0.5865844j, 1.2276335 +0.j ]], dtype=complex64),), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_cholesky attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<4x4xcomplex> {jax.arg_info = "x", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<4x4xcomplex> {jax.result_info = ""}) { - %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<4x4xcomplex>) -> tensor<4x4xcomplex> loc(#loc2) - %1 = stablehlo.real %0 : (tensor<4x4xcomplex>) -> tensor<4x4xf32> loc(#loc3) - %2 = stablehlo.imag %0 : (tensor<4x4xcomplex>) -> tensor<4x4xf32> loc(#loc4) - %3 = stablehlo.negate %2 : tensor<4x4xf32> loc(#loc5) - %4 = stablehlo.complex %1, %3 : tensor<4x4xcomplex> loc(#loc6) - %5 = stablehlo.add %arg0, %4 : tensor<4x4xcomplex> loc(#loc7) - %6 = stablehlo.constant dense<(2.000000e+00,0.000000e+00)> : tensor> loc(#loc) - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc8) - %8 = stablehlo.divide %5, %7 : tensor<4x4xcomplex> loc(#loc8) - %9 = stablehlo.constant dense<1> : tensor loc(#loc9) - %10 = stablehlo.constant dense<1> : tensor loc(#loc9) - %11 = stablehlo.constant dense<4> : tensor loc(#loc9) - %12:2 = stablehlo.custom_call @lapack_cpotrf(%9, %10, %11, %8) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor<4x4xcomplex>) -> (tensor<4x4xcomplex>, tensor) loc(#loc9) - %13 = stablehlo.constant dense<0> : tensor loc(#loc9) - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor loc(#loc9) - %15 = stablehlo.compare EQ, %12#1, %14, SIGNED : (tensor, tensor) -> tensor loc(#loc9) - %16 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc9) - %17 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc9) - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc9) - %19 = stablehlo.broadcast_in_dim %16, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc9) - %20 = stablehlo.select %19, %12#0, %18 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc9) - %21 = call @tril(%20) : (tensor<4x4xcomplex>) -> tensor<4x4xcomplex> loc(#loc10) - return %21 : tensor<4x4xcomplex> loc(#loc) - } loc(#loc) - func.func private @tril(%arg0: tensor<4x4xcomplex> loc(unknown)) -> tensor<4x4xcomplex> { - %0 = stablehlo.iota dim = 0 : tensor<4x4xi32> loc(#loc11) - %1 = stablehlo.constant dense<0> : tensor loc(#loc10) - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<4x4xi32> loc(#loc12) - %3 = stablehlo.add %0, %2 : tensor<4x4xi32> loc(#loc12) - %4 = stablehlo.iota dim = 1 : tensor<4x4xi32> loc(#loc13) - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<4x4xi32>, tensor<4x4xi32>) -> tensor<4x4xi1> loc(#loc14) - %6 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc10) - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc15) - %8 = stablehlo.select %5, %arg0, %7 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc16) - return %8 : tensor<4x4xcomplex> loc(#loc10) - } loc(#loc10) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":292:0) -#loc2 = loc("jit(cholesky)/jit(main)/transpose[permutation=(1, 0)]"(#loc1)) -#loc3 = loc("jit(cholesky)/jit(main)/real"(#loc1)) -#loc4 = loc("jit(cholesky)/jit(main)/imag"(#loc1)) -#loc5 = loc("jit(cholesky)/jit(main)/neg"(#loc1)) -#loc6 = loc("jit(cholesky)/jit(main)/complex"(#loc1)) -#loc7 = loc("jit(cholesky)/jit(main)/add"(#loc1)) -#loc8 = loc("jit(cholesky)/jit(main)/div"(#loc1)) -#loc9 = loc("jit(cholesky)/jit(main)/cholesky"(#loc1)) -#loc10 = loc("jit(cholesky)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]"(#loc1)) -#loc11 = loc("jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=0]"(#loc1)) -#loc12 = loc("jit(cholesky)/jit(main)/jit(tril)/add"(#loc1)) -#loc13 = loc("jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=1]"(#loc1)) -#loc14 = loc("jit(cholesky)/jit(main)/jit(tril)/ge"(#loc1)) -#loc15 = loc("jit(cholesky)/jit(main)/jit(tril)/broadcast_in_dim[shape=(4, 4) broadcast_dimensions=()]"(#loc1)) -#loc16 = loc("jit(cholesky)/jit(main)/jit(tril)/select_n"(#loc1)) -""", - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x011\x05\x01\x03\x01\x03\x05\x03!\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!#%\x03J\x02\xe9)\x01\x97\x17\x0f\x07\x0b\x13\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x0f\x13#\x0b\x0b\x0b33\x0b\x0b\x13\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x1b\x13\x13\x13\x0b\x03S\x0f\x0b\x0b\x0f\x0b\x0bO\x0f\x1b\x0b\x0b\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b/O/\x1f\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17\x13\x0b/O\x01\x03\x0f\x03'\x17\x0f\x0f\x17\x07\x07\x17\x0b\x07\x07\x17\x13\x07\x17\x13\x13\x13\x0f\x17\x02\xe6\x07\x177\x92\x04\x01\x1dw\x01\x1f\x05'\x03\x03\x1d\xc3\x1d5\x01\x05)\x11\x01\x05\x05+\x05-\x05/\x051\x053\x03\x03\x07\xc1\x055\x1d?\x01\x057\x059\x1du\x01\x03\x03\x07\xcf\x03\x07+\x0f-\x0f\r/\x05;\x05=\x05?\x03\x0b\x11\xa5\x13\x99\x15\xb1\r\xb7\x17\xb9\x03\x0b\x11\x9d\x13\x99\x15\x9d\r\x9f\x17\xbd\x05A\x05C\x03\x03\x19\xbf\x1d=\x01\x05E\x05G\x03\x03\x19\xc5\x1dE\x01\x05I\x03\x05!\xa1#\xc7\x1dK\x01\x05K\x03\x03\x07\xc9\x1dQ\x01\x05M\x1dU\x01\x05O\x03\x03Y\xcb\x05Q\x1d]\x01\x05S\x1da\x01\x05U\x1de\x01\x05W\x1di\x01\x05Y\x1dm\x01\x05[\x1dq\x01\x05]\x03\x03\x07\xcd\x05_\x05a\x03\x03\x07\xd1\x03\x11}\xd3\x7f\x9b\x81\xd5\x83\xd7\x85\xd9\x87\xdb\x89\xdd\x8b\xe1\x05c\x05e\x05g\x05i\x05k\x05m\x05o\x05q\x03\x05!\xa1#\xe3\x03\x03\x07\xe5\x03\x03\x1d\xe7\x03\x03\x95\x9f\x05s\x1f!\x01#\x1d\x1du\x03\x03\xbb\x1dw\t\x07\x1f#!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03\xa7\r\x05\xa9\xab\xad\xaf\x1dy\x1d{\x1d}\x1d\x7f\x03\x03\xb3\r\x03\xb5\x9b\x1d\x81\x1d\x83\x1d\x85\r\x01\x1d\x87\x13\x0b\x01\x1f\x05\t\x00\x00\x00\x00\x1f\x1f\x01\x13\x0b\x05\x07\x05\x1f\x07\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x19!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07\x11\x00\x00\x00@\x00\x00\x00\x00\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x0b\x05\x1d\x89\x03\x01\x05\x01\x03\t\x97\x97\x97\xa3\x03\x03\xdf\x15\x03\x01\r\x01\x03\x05\xa3\x97\x07\x01\x1f\x07\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f\x19!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\x11)\x01\x15)\x01\x11)\x05\x11\x11\x15\x1d\x01)\x05\x11\x11\x13\x03\x13\t\x1b)\x05\x11\x11\r)\x03\t\x0b\x13\x11\x03\x03\x03\x03)\x03\x01\x0b)\x03\x01\x1b)\x03\t\x1b)\x01\r)\x05\x05\x05\r\x04J\x04\x05\x01\x11\x05)\x07\x03\x01\t\x07\x11\x051\x05\x031_\x03\x03\x05\x13\x07[W\x03\x03\x03\x01\x15\x06_\x03\x0f\x03\x03\x17\x06c\x03\x0f\x03\x03\x19\x06g\x03\x0f\x03\x07\x1b\x06k\x03\x03\x05\x05\t\x0b\x06o\x03\x03\x05\x01\x0b\x03\x03\x05s\x03\x07\x05\x07%\t\x03\x03\x03\x0f\x1d\x06%\x03\x03\x05\r\x11\x03\x03\x03'\x03\x05\x03\x03\x03'\x03\x05\x03\x03\x03y\x03\x05\x1f\x07\x03{\x05\x03\x05\t\x15\x17\x19\x13\x03\x03\x03\x1b\x03\x05\x05\x07\x03\t\x03\x05\x03\x1f\r\x07\x03\x8d\x03%\x05\x1d!\x05\x07\x03\t\x03'\x03#\x03\x03\x03\x8f\x03\x07\x05\x07\x03\t\x03\x03\x03'\x05\x07\x03\x91\x03\x17\x03%\x0f\x06\x03\x03\x03\x07+\x1b)!\x07\x0b\x93\x03\x03\x03-\x11\x04\x05\x03/\x07\x11\x0b3\x05\x03\x15+\x03\x03\x05\t\x03;9\x03\t\x03\x03\x0b\x1b\x03\x05\x05\x07\x1f\t\x03\t\x03\x05\x0b\x06\x1f\x03\t\x05\x03\x07\t\x03CA\x03\t\r\x07IG\x03\x17\x05\t\x0b\x03\x03\x0bM\x03\x07\x05\x07O\t\x03\x03\x03\x0f\x0f\x06S\x03\x03\x07\r\x01\x11\x11\x04\x0b\x03\x13\x06\x03\x01\x05\x01\x00\x96\x18\x8b\x1d\x11\x0f\x0b!\x1b\x1d\x05\x1b\x0b\x03\x0f\x1f/!!)#\x1f\x19C99A9;;m\x19W\xb3K\x9bM\x9b\x97\xd2\x02\x1b%)+\x1b+\x1f\x1f\x15\x1d\x15\x13\r\x11\x1f\x15\x17\x15\x11\x11\x1b\x15\x15\x17\x0f\x11\x11)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00func_v1\x00iota_v1\x00add_v1\x00compare_v1\x00select_v1\x00return_v1\x00transpose_v1\x00real_v1\x00imag_v1\x00negate_v1\x00complex_v1\x00divide_v1\x00custom_call_v1\x00call_v1\x00value\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_cholesky\x00jit(cholesky)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=0]\x00jit(cholesky)/jit(main)/jit(tril)/add\x00jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=1]\x00jit(cholesky)/jit(main)/jit(tril)/ge\x00jit(cholesky)/jit(main)/jit(tril)/broadcast_in_dim[shape=(4, 4) broadcast_dimensions=()]\x00jit(cholesky)/jit(main)/jit(tril)/select_n\x00permutation\x00jit(cholesky)/jit(main)/transpose[permutation=(1, 0)]\x00jit(cholesky)/jit(main)/real\x00jit(cholesky)/jit(main)/imag\x00jit(cholesky)/jit(main)/neg\x00jit(cholesky)/jit(main)/complex\x00jit(cholesky)/jit(main)/add\x00jit(cholesky)/jit(main)/div\x00jit(cholesky)/jit(main)/cholesky\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00callee\x00\x00tril\x00jax.arg_info\x00x\x00mhlo.sharding\x00{replicated}\x00jax.result_info\x00main\x00public\x00private\x00lapack_cpotrf\x00", - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["c128"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_zpotrf'], - serialized_date=datetime.date(2023, 6, 19), - inputs=(array([[ 77.35445791180521 -6.4555004827448569e-16j, - 16.89356598261691 -5.4959586590823566e+00j, - -21.124380423202325+6.4431220601700787e+01j, - 55.385054340628855+2.5198457006849742e+00j], - [ 16.89356598261691 +5.4959586590823566e+00j, - 67.125263428637 -3.2921739472953976e-16j, - 25.14078382035968 +1.2783276691803774e+01j, - 51.116221409460884-2.2635508887939348e+00j], - [-21.124380423202325-6.4431220601700787e+01j, - 25.14078382035968 -1.2783276691803774e+01j, - 107.43449297637208 -2.8959717546347756e-15j, - 12.493792156221616-5.7556567757218694e+01j], - [ 55.385054340628855-2.5198457006849715e+00j, - 51.116221409460884+2.2635508887939326e+00j, - 12.493792156221616+5.7556567757218708e+01j, - 78.9856503203742 +2.0971925518284437e-16j]]),), - expected_outputs=(array([[ 8.795138311124232 +0.j , - 0. +0.j , - 0. +0.j , - 0. +0.j ], - [ 1.9207845726825759+0.624885984127274j , - 7.940111306576433 +0.j , - 0. +0.j , - 0. +0.j ], - [-2.401824698593298 -7.325776846534311j , - 4.3238621722485755-0.026813746599595675j, - 5.413152651345813 +0.j , - 0. +0.j ], - [ 6.297235174866659 -0.28650438589440164j , - 4.936910868956218 +0.849977768846063j , - 0.7751580530200595+1.279980716041562j , - 3.451611642915363 +0.j ]]),), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_cholesky attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<4x4xcomplex> {jax.arg_info = "x", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<4x4xcomplex> {jax.result_info = ""}) { - %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<4x4xcomplex>) -> tensor<4x4xcomplex> loc(#loc2) - %1 = stablehlo.real %0 : (tensor<4x4xcomplex>) -> tensor<4x4xf64> loc(#loc3) - %2 = stablehlo.imag %0 : (tensor<4x4xcomplex>) -> tensor<4x4xf64> loc(#loc4) - %3 = stablehlo.negate %2 : tensor<4x4xf64> loc(#loc5) - %4 = stablehlo.complex %1, %3 : tensor<4x4xcomplex> loc(#loc6) - %5 = stablehlo.add %arg0, %4 : tensor<4x4xcomplex> loc(#loc7) - %6 = stablehlo.constant dense<(2.000000e+00,0.000000e+00)> : tensor> loc(#loc) - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc8) - %8 = stablehlo.divide %5, %7 : tensor<4x4xcomplex> loc(#loc8) - %9 = stablehlo.constant dense<1> : tensor loc(#loc9) - %10 = stablehlo.constant dense<1> : tensor loc(#loc9) - %11 = stablehlo.constant dense<4> : tensor loc(#loc9) - %12:2 = stablehlo.custom_call @lapack_zpotrf(%9, %10, %11, %8) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor<4x4xcomplex>) -> (tensor<4x4xcomplex>, tensor) loc(#loc9) - %13 = stablehlo.constant dense<0> : tensor loc(#loc9) - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor loc(#loc9) - %15 = stablehlo.compare EQ, %12#1, %14, SIGNED : (tensor, tensor) -> tensor loc(#loc9) - %16 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc9) - %17 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc9) - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc9) - %19 = stablehlo.broadcast_in_dim %16, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc9) - %20 = stablehlo.select %19, %12#0, %18 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc9) - %21 = call @tril(%20) : (tensor<4x4xcomplex>) -> tensor<4x4xcomplex> loc(#loc10) - return %21 : tensor<4x4xcomplex> loc(#loc) - } loc(#loc) - func.func private @tril(%arg0: tensor<4x4xcomplex> loc(unknown)) -> tensor<4x4xcomplex> { - %0 = stablehlo.iota dim = 0 : tensor<4x4xi32> loc(#loc11) - %1 = stablehlo.constant dense<0> : tensor loc(#loc10) - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<4x4xi32> loc(#loc12) - %3 = stablehlo.add %0, %2 : tensor<4x4xi32> loc(#loc12) - %4 = stablehlo.iota dim = 1 : tensor<4x4xi32> loc(#loc13) - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<4x4xi32>, tensor<4x4xi32>) -> tensor<4x4xi1> loc(#loc14) - %6 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc10) - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc15) - %8 = stablehlo.select %5, %arg0, %7 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc16) - return %8 : tensor<4x4xcomplex> loc(#loc10) - } loc(#loc10) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":292:0) -#loc2 = loc("jit(cholesky)/jit(main)/transpose[permutation=(1, 0)]"(#loc1)) -#loc3 = loc("jit(cholesky)/jit(main)/real"(#loc1)) -#loc4 = loc("jit(cholesky)/jit(main)/imag"(#loc1)) -#loc5 = loc("jit(cholesky)/jit(main)/neg"(#loc1)) -#loc6 = loc("jit(cholesky)/jit(main)/complex"(#loc1)) -#loc7 = loc("jit(cholesky)/jit(main)/add"(#loc1)) -#loc8 = loc("jit(cholesky)/jit(main)/div"(#loc1)) -#loc9 = loc("jit(cholesky)/jit(main)/cholesky"(#loc1)) -#loc10 = loc("jit(cholesky)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]"(#loc1)) -#loc11 = loc("jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=0]"(#loc1)) -#loc12 = loc("jit(cholesky)/jit(main)/jit(tril)/add"(#loc1)) -#loc13 = loc("jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=1]"(#loc1)) -#loc14 = loc("jit(cholesky)/jit(main)/jit(tril)/ge"(#loc1)) -#loc15 = loc("jit(cholesky)/jit(main)/jit(tril)/broadcast_in_dim[shape=(4, 4) broadcast_dimensions=()]"(#loc1)) -#loc16 = loc("jit(cholesky)/jit(main)/jit(tril)/select_n"(#loc1)) -""", - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x011\x05\x01\x03\x01\x03\x05\x03!\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!#%\x03J\x02\xe9)\x01\x97\x17\x0f\x07\x0b\x13\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x0f\x13#\x0b\x0b\x0b33\x0b\x0b\x13\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x1b\x13\x13\x13\x0b\x03S\x0f\x0b\x0b\x0f\x0b\x0bO\x0f\x1b\x0b\x0b\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0bOOO\x1f\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17\x13\x0bOO\x01\x03\x0f\x03'\x17\x0f\x0f\x17\x07\x07\x17\x0b\x07\x07\x17\x13\x07\x17\x13\x13\x13\x0f\x17\x02F\x08\x177\x92\x04\x01\x1dw\x01\x1f\x05'\x03\x03\x1d\xc3\x1d5\x01\x05)\x11\x01\x05\x05+\x05-\x05/\x051\x053\x03\x03\x07\xc1\x055\x1d?\x01\x057\x059\x1du\x01\x03\x03\x07\xcf\x03\x07+\x0f-\x0f\r/\x05;\x05=\x05?\x03\x0b\x11\xa5\x13\x99\x15\xb1\r\xb7\x17\xb9\x03\x0b\x11\x9d\x13\x99\x15\x9d\r\x9f\x17\xbd\x05A\x05C\x03\x03\x19\xbf\x1d=\x01\x05E\x05G\x03\x03\x19\xc5\x1dE\x01\x05I\x03\x05!\xa1#\xc7\x1dK\x01\x05K\x03\x03\x07\xc9\x1dQ\x01\x05M\x1dU\x01\x05O\x03\x03Y\xcb\x05Q\x1d]\x01\x05S\x1da\x01\x05U\x1de\x01\x05W\x1di\x01\x05Y\x1dm\x01\x05[\x1dq\x01\x05]\x03\x03\x07\xcd\x05_\x05a\x03\x03\x07\xd1\x03\x11}\xd3\x7f\x9b\x81\xd5\x83\xd7\x85\xd9\x87\xdb\x89\xdd\x8b\xe1\x05c\x05e\x05g\x05i\x05k\x05m\x05o\x05q\x03\x05!\xa1#\xe3\x03\x03\x07\xe5\x03\x03\x1d\xe7\x03\x03\x95\x9f\x05s\x1f!\x01#\x1d\x1du\x03\x03\xbb\x1dw\t\x07\x1f#!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03\xa7\r\x05\xa9\xab\xad\xaf\x1dy\x1d{\x1d}\x1d\x7f\x03\x03\xb3\r\x03\xb5\x9b\x1d\x81\x1d\x83\x1d\x85\r\x01\x1d\x87\x13\x0b\x01\x1f\x05\t\x00\x00\x00\x00\x1f\x1f\x01\x13\x0b\x05\x07\x05\x1f\x07!\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x19!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07!\x00\x00\x00\x00\x00\x00\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x0b\x05\x1d\x89\x03\x01\x05\x01\x03\t\x97\x97\x97\xa3\x03\x03\xdf\x15\x03\x01\r\x01\x03\x05\xa3\x97\x07\x01\x1f\x07!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x19!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\x11)\x01\x15)\x01\x11)\x05\x11\x11\x15\x1d\x01)\x05\x11\x11\x13\x03\x13\x0b\x1b)\x05\x11\x11\r)\x03\t\x0b\x13\x11\x03\x03\x03\x03)\x03\x01\x0b)\x03\x01\x1b)\x03\t\x1b)\x01\r)\x05\x05\x05\r\x04J\x04\x05\x01\x11\x05)\x07\x03\x01\t\x07\x11\x051\x05\x031_\x03\x03\x05\x13\x07[W\x03\x03\x03\x01\x15\x06_\x03\x0f\x03\x03\x17\x06c\x03\x0f\x03\x03\x19\x06g\x03\x0f\x03\x07\x1b\x06k\x03\x03\x05\x05\t\x0b\x06o\x03\x03\x05\x01\x0b\x03\x03\x05s\x03\x07\x05\x07%\t\x03\x03\x03\x0f\x1d\x06%\x03\x03\x05\r\x11\x03\x03\x03'\x03\x05\x03\x03\x03'\x03\x05\x03\x03\x03y\x03\x05\x1f\x07\x03{\x05\x03\x05\t\x15\x17\x19\x13\x03\x03\x03\x1b\x03\x05\x05\x07\x03\t\x03\x05\x03\x1f\r\x07\x03\x8d\x03%\x05\x1d!\x05\x07\x03\t\x03'\x03#\x03\x03\x03\x8f\x03\x07\x05\x07\x03\t\x03\x03\x03'\x05\x07\x03\x91\x03\x17\x03%\x0f\x06\x03\x03\x03\x07+\x1b)!\x07\x0b\x93\x03\x03\x03-\x11\x04\x05\x03/\x07\x11\x0b3\x05\x03\x15+\x03\x03\x05\t\x03;9\x03\t\x03\x03\x0b\x1b\x03\x05\x05\x07\x1f\t\x03\t\x03\x05\x0b\x06\x1f\x03\t\x05\x03\x07\t\x03CA\x03\t\r\x07IG\x03\x17\x05\t\x0b\x03\x03\x0bM\x03\x07\x05\x07O\t\x03\x03\x03\x0f\x0f\x06S\x03\x03\x07\r\x01\x11\x11\x04\x0b\x03\x13\x06\x03\x01\x05\x01\x00\x96\x18\x8b\x1d\x11\x0f\x0b!\x1b\x1d\x05\x1b\x0b\x03\x0f\x1f/!!)#\x1f\x19C99A9;;m\x19W\xb3K\x9bM\x9b\x97\xd2\x02\x1b%)+\x1b+\x1f\x1f\x15\x1d\x15\x13\r\x11\x1f\x15\x17\x15\x11\x11\x1b\x15\x15\x17\x0f\x11\x11)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00func_v1\x00iota_v1\x00add_v1\x00compare_v1\x00select_v1\x00return_v1\x00transpose_v1\x00real_v1\x00imag_v1\x00negate_v1\x00complex_v1\x00divide_v1\x00custom_call_v1\x00call_v1\x00value\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_cholesky\x00jit(cholesky)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=0]\x00jit(cholesky)/jit(main)/jit(tril)/add\x00jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=1]\x00jit(cholesky)/jit(main)/jit(tril)/ge\x00jit(cholesky)/jit(main)/jit(tril)/broadcast_in_dim[shape=(4, 4) broadcast_dimensions=()]\x00jit(cholesky)/jit(main)/jit(tril)/select_n\x00permutation\x00jit(cholesky)/jit(main)/transpose[permutation=(1, 0)]\x00jit(cholesky)/jit(main)/real\x00jit(cholesky)/jit(main)/imag\x00jit(cholesky)/jit(main)/neg\x00jit(cholesky)/jit(main)/complex\x00jit(cholesky)/jit(main)/add\x00jit(cholesky)/jit(main)/div\x00jit(cholesky)/jit(main)/cholesky\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00callee\x00\x00tril\x00jax.arg_info\x00x\x00mhlo.sharding\x00{replicated}\x00jax.result_info\x00main\x00public\x00private\x00lapack_zpotrf\x00", - xla_call_module_version=6, -) # End paste +array = np.array +float32 = np.float32 +complex64 = np.complex64 data_2024_05_31 = {} - # Pasted from the test output (see export_back_compat_test_util.py module docstring) data_2024_05_31["c128"] = dict( testdata_version=1, diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eig_lapack_geev.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eig_lapack_geev.py index bc28857fa325..52ef915968af 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eig_lapack_geev.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eig_lapack_geev.py @@ -15,279 +15,13 @@ # ruff: noqa import datetime -from numpy import array, float32, complex64 - -data_2023_06_19 = {} - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["f32"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_sgeev'], - serialized_date=datetime.date(2023, 6, 19), - inputs=(), - expected_outputs=(array([ 3.2464241e+01+0.j, -2.4642489e+00+0.j, 1.4189274e-07+0.j, - -4.0686123e-07+0.j], dtype=complex64), array([[-0.40377745 +0.j, -0.82883257 +0.j, -0.06733338 +0.j, - -0.5208027 +0.j], - [-0.46480742 +0.j, -0.4371466 +0.j, 0.49492982 +0.j, - 0.82081676 +0.j], - [-0.52583724 +0.j, -0.045459956+0.j, -0.78785884 +0.j, - -0.07922471 +0.j], - [-0.5868671 +0.j, 0.3462263 +0.j, 0.36026272 +0.j, - -0.2207891 +0.j]], dtype=complex64), array([[-0.11417642+0.j, -0.73277813+0.j, 0.16960056+0.j, - -0.5435681 +0.j], - [-0.33000448+0.j, -0.28974825+0.j, 0.16204938+0.j, - 0.67456985+0.j], - [-0.54583275+0.j, 0.15328142+0.j, -0.8329006 +0.j, - 0.28156415+0.j], - [-0.761661 +0.j, 0.5963111 +0.j, 0.5012507 +0.j, - -0.41256607+0.j]], dtype=complex64)), - mlir_module_text=r""" -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<4xcomplex> {jax.result_info = "[0]"}, tensor<4x4xcomplex> {jax.result_info = "[1]"}, tensor<4x4xcomplex> {jax.result_info = "[2]"}) { - %0 = stablehlo.iota dim = 0 : tensor<16xf32> loc(#loc3) - %1 = stablehlo.reshape %0 : (tensor<16xf32>) -> tensor<4x4xf32> loc(#loc4) - %2 = stablehlo.constant dense<1> : tensor loc(#loc5) - %3 = stablehlo.constant dense<4> : tensor loc(#loc5) - %4 = stablehlo.constant dense<86> : tensor loc(#loc5) - %5 = stablehlo.constant dense<86> : tensor loc(#loc5) - %6:8 = stablehlo.custom_call @lapack_sgeev(%2, %3, %4, %5, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor, tensor<4x4xf32>) -> (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xcomplex>, tensor<4x4xcomplex>, tensor) loc(#loc5) - %7 = stablehlo.complex %6#3, %6#4 : tensor<4xcomplex> loc(#loc5) - %8 = stablehlo.constant dense<0> : tensor loc(#loc5) - %9 = stablehlo.broadcast_in_dim %8, dims = [] : (tensor) -> tensor loc(#loc5) - %10 = stablehlo.compare EQ, %6#7, %9, SIGNED : (tensor, tensor) -> tensor loc(#loc5) - %11 = stablehlo.broadcast_in_dim %10, dims = [] : (tensor) -> tensor<1xi1> loc(#loc5) - %12 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc5) - %13 = stablehlo.broadcast_in_dim %12, dims = [] : (tensor>) -> tensor<4xcomplex> loc(#loc5) - %14 = stablehlo.broadcast_in_dim %11, dims = [0] : (tensor<1xi1>) -> tensor<4xi1> loc(#loc5) - %15 = stablehlo.select %14, %7, %13 : tensor<4xi1>, tensor<4xcomplex> loc(#loc5) - %16 = stablehlo.broadcast_in_dim %10, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %17 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc5) - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) - %19 = stablehlo.broadcast_in_dim %16, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) - %20 = stablehlo.select %19, %6#5, %18 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) - %21 = stablehlo.broadcast_in_dim %10, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %22 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc5) - %23 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) - %24 = stablehlo.broadcast_in_dim %21, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) - %25 = stablehlo.select %24, %6#6, %23 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) - return %15, %20, %25 : tensor<4xcomplex>, tensor<4x4xcomplex>, tensor<4x4xcomplex> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc = loc(unknown) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":496:0) -#loc2 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":497:0) -#loc3 = loc("jit(func)/jit(main)/iota[dtype=float32 shape=(16,) dimension=0]"(#loc1)) -#loc4 = loc("jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]"(#loc1)) -#loc5 = loc("jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]"(#loc2)) -""", - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01%\x05\x01\x03\x01\x03\x05\x03\x15\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x03\xe7\x9b9\x01[\x0f\x13\x0b\x07\x0b\x13\x0f\x0b\x17\x0b\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x17\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x03AO\x0f\x0b\x0b/\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b\x0f\x1f\x1f\x13\x0b\x0b\x0b\x0b\x1f+\x1f\x0f\x0b\x0b//O\x01\x03\x0f\x037\x17\x0f\x07\x13\x07\x07\x17\x0f\x0b\x0f\x07\x13\x17\x17\x1b\x13\x07\x07\x13\x13\x13\x13\x0f\x13\x13\x13\x13\x02v\x06\x1d9;\x03\x03\t\x8f\x05\x1b\x1f\x05\x1d\x03\x03\x05\x95\x11\x01\x05\x05\x1f\x17\x13\xc2\x07\x01\x05!\x03\x03\x05\x7f\x03\x03\t\x99\x03\x07\x1b\r\x1d\r\x0f\x1f\x05#\x05%\x05'\x03\x0b#_%e'g\x0fu)w\x05)\x05+\x05-\x05/\x03\x03-y\x051\x1d1\x11\x053\x1d5\x11\x055\x03\x03\x05{\x057\x17\x13\xc6\x07\x01\x03\x03\x05}\x03\x11A\x81C\x83E\x85G_I\x87K\x89M_O\x8b\x059\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x03\x03\x05\x8d\x03\x05U\x91W\x93\x05I\x05K\x03\x03\t\x97\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f'\x01\x03\x01\x1dM\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x1f\x03\x07imq\r\x03ak\x1dO\r\x03ao\x1dQ\r\x03as\x1dS\x1dU\x1dW\x13\r\x01\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f\x15\x03V\x0b\x05\x1dY\x1d[\x05\x01\x03\x0b]]]][\x03\x11[[[cc[[]\x1f\x05\t\x00\x00\x00\x00\x1f-\x01\t\x07\x07\x01\x1f\x11\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f7!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\x13)\x01#\x01)\x03\x11\x13\t\x1d)\x05\x11\x11\x0b)\x01\x13\x03\x0b)\x01%\x13)\x03\x11\x0b)\x05\x05\x05\x07)\x05\x11\x11\x07\x11\x01\x07\t\x03\x03)\x03A\x0b\x1b!)\x03\x01\x17)\x03\t\x17)\x03\x05\x17)\x03\x01\r)\x01\x07)\x03\x05\x07)\x03\x11\x07)\x03\x05\r)\x03\t\r\x04\x92\x03\x05\x01\x11\x07\x19\x07\x03\x01\x05\t\x11\x07!\x05\x03Cm\x0b\x03/+\x03!\r\x063\x03\x0f\x03\x01\x05\x03\x017\x03\x05\x05\x03\x01=\x03\x05\x05\x03\x01\x15\x03\x15\x05\x03\x01\x15\x03\x15\x0f\x07\x01?\x11\x0f\x0f\x0f\x19\x19\x03\x03\x05\x0b\x05\x07\t\x0b\x03\x11\x06\x01\x03\t\x05\x13\x15\x05\x03\x01Q\x03\x05\x03\x07\x01\x03\x03\x05\x03\x1f\x13\x07\x01S\x03/\x05\x1b!\x03\x07\x01\x03\x031\x03#\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\t\x03'\x03\x07\x01Y\x033\x03%\x07\x06\x01\x03\t\x07+\x1d)\x03\x07\x01\x03\x03\x1b\x03#\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x03\x031\x03\x07\x01\x17\x03\x1d\x03/\x07\x06\x01\x03\x03\x075\x173\x03\x07\x01\x03\x03\x1b\x03#\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x03\x03;\x03\x07\x01\x17\x03\x1d\x039\x07\x06\x01\x03\x03\x07?\x19=\x15\x04\x07\x07-7A\x06\x03\x01\x05\x01\x00&\r]\x1b\x03\x0f\x0b\t\t\t!+\x1b\x1f/!!)#\x1f\x19\xb1}\x81\x1f\x1f\x15\x1d\x15\x13%)\x97\x13+\r\x15\x17\x17\x1f\x17\x11\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00complex_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit(func)/jit(main)/iota[dtype=float32 shape=(16,) dimension=0]\x00jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]\x00jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_sgeev\x00", - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["f64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_dgeev'], - serialized_date=datetime.date(2023, 6, 19), - inputs=(), - expected_outputs=(array([ 3.2464249196572972e+01+0.j, -2.4642491965729802e+00+0.j, - -1.5210037805054253e-15+0.j, 1.2568096307462507e-16+0.j]), array([[-0.4037774907686232 +0.j, 0.8288327563197505 +0.j, - 0.5454962288885842 +0.j, -0.2420483778598153 +0.j], - [-0.46480737115848986 +0.j, 0.43714638836388725 +0.j, - -0.7640998541831632 +0.j, -0.04349021275982002 +0.j], - [-0.5258372515483576 +0.j, 0.045460020408024715+0.j, - -0.10828897829942748 +0.j, 0.8131255590990858 +0.j], - [-0.5868671319382249 +0.j, -0.3462263475478384 +0.j, - 0.32689260359400607 +0.j, -0.5275869684794504 +0.j]]), array([[-0.11417645138733863+0.j, 0.7327780959803554 +0.j, - 0.49133754464261303+0.j, -0.04933420991901029+0.j], - [-0.33000459866554765+0.j, 0.28974835239692637+0.j, - -0.8355289351028521 +0.j, -0.3408099365295394 +0.j], - [-0.545832745943757 +0.j, -0.1532813911865017 +0.j, - 0.1970452362778633 +0.j, 0.8296225028161098 +0.j], - [-0.7616608932219663 +0.j, -0.5963111347699308 +0.j, - 0.14714615418237506+0.j, -0.43947835636755994+0.j]])), - mlir_module_text=r""" -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<4xcomplex> {jax.result_info = "[0]"}, tensor<4x4xcomplex> {jax.result_info = "[1]"}, tensor<4x4xcomplex> {jax.result_info = "[2]"}) { - %0 = stablehlo.iota dim = 0 : tensor<16xf64> loc(#loc3) - %1 = stablehlo.reshape %0 : (tensor<16xf64>) -> tensor<4x4xf64> loc(#loc4) - %2 = stablehlo.constant dense<1> : tensor loc(#loc5) - %3 = stablehlo.constant dense<4> : tensor loc(#loc5) - %4 = stablehlo.constant dense<86> : tensor loc(#loc5) - %5 = stablehlo.constant dense<86> : tensor loc(#loc5) - %6:8 = stablehlo.custom_call @lapack_dgeev(%2, %3, %4, %5, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor, tensor<4x4xf64>) -> (tensor<4x4xf64>, tensor<4x4xf64>, tensor<4x4xf64>, tensor<4xf64>, tensor<4xf64>, tensor<4x4xcomplex>, tensor<4x4xcomplex>, tensor) loc(#loc5) - %7 = stablehlo.complex %6#3, %6#4 : tensor<4xcomplex> loc(#loc5) - %8 = stablehlo.constant dense<0> : tensor loc(#loc5) - %9 = stablehlo.broadcast_in_dim %8, dims = [] : (tensor) -> tensor loc(#loc5) - %10 = stablehlo.compare EQ, %6#7, %9, SIGNED : (tensor, tensor) -> tensor loc(#loc5) - %11 = stablehlo.broadcast_in_dim %10, dims = [] : (tensor) -> tensor<1xi1> loc(#loc5) - %12 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc5) - %13 = stablehlo.broadcast_in_dim %12, dims = [] : (tensor>) -> tensor<4xcomplex> loc(#loc5) - %14 = stablehlo.broadcast_in_dim %11, dims = [0] : (tensor<1xi1>) -> tensor<4xi1> loc(#loc5) - %15 = stablehlo.select %14, %7, %13 : tensor<4xi1>, tensor<4xcomplex> loc(#loc5) - %16 = stablehlo.broadcast_in_dim %10, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %17 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc5) - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) - %19 = stablehlo.broadcast_in_dim %16, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) - %20 = stablehlo.select %19, %6#5, %18 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) - %21 = stablehlo.broadcast_in_dim %10, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %22 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc5) - %23 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) - %24 = stablehlo.broadcast_in_dim %21, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) - %25 = stablehlo.select %24, %6#6, %23 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) - return %15, %20, %25 : tensor<4xcomplex>, tensor<4x4xcomplex>, tensor<4x4xcomplex> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc = loc(unknown) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":496:0) -#loc2 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":497:0) -#loc3 = loc("jit(func)/jit(main)/iota[dtype=float64 shape=(16,) dimension=0]"(#loc1)) -#loc4 = loc("jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]"(#loc1)) -#loc5 = loc("jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]"(#loc2)) -""", - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01%\x05\x01\x03\x01\x03\x05\x03\x15\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x03\xe7\x9b9\x01[\x0f\x13\x0b\x07\x0b\x13\x0f\x0b\x17\x0b\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x17\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x03AO\x0f\x0b\x0b/\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b\x0f\x1f\x1f\x13\x0b\x0b\x0b\x0b\x1f+\x1f\x0f\x0b\x0bO/O\x01\x03\x0f\x037\x17\x0f\x07\x13\x07\x07\x17\x0f\x0b\x0f\x07\x13\x17\x17\x1b\x13\x07\x07\x13\x13\x13\x13\x0f\x13\x13\x13\x13\x02\x96\x06\x1d9;\x03\x03\t\x8f\x05\x1b\x1f\x05\x1d\x03\x03\x05\x95\x11\x01\x05\x05\x1f\x17\x13\xc2\x07\x01\x05!\x03\x03\x05\x7f\x03\x03\t\x99\x03\x07\x1b\r\x1d\r\x0f\x1f\x05#\x05%\x05'\x03\x0b#_%e'g\x0fu)w\x05)\x05+\x05-\x05/\x03\x03-y\x051\x1d1\x11\x053\x1d5\x11\x055\x03\x03\x05{\x057\x17\x13\xc6\x07\x01\x03\x03\x05}\x03\x11A\x81C\x83E\x85G_I\x87K\x89M_O\x8b\x059\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x03\x03\x05\x8d\x03\x05U\x91W\x93\x05I\x05K\x03\x03\t\x97\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f'\x01\x03\x01\x1dM\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x1f\x03\x07imq\r\x03ak\x1dO\r\x03ao\x1dQ\r\x03as\x1dS\x1dU\x1dW\x13\r\x01\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f\x15\x03V\x0b\x05\x1dY\x1d[\x05\x01\x03\x0b]]]][\x03\x11[[[cc[[]\x1f\x05\t\x00\x00\x00\x00\x1f-\x01\t\x07\x07\x01\x1f\x11!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f7!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\x13)\x01#\x01)\x03\x11\x13\x0b\x1d)\x05\x11\x11\x0b)\x01\x13\x03\x0b)\x01%\x13)\x03\x11\x0b)\x05\x05\x05\x07)\x05\x11\x11\x07\x11\x01\x07\t\x03\x03)\x03A\x0b\x1b!)\x03\x01\x17)\x03\t\x17)\x03\x05\x17)\x03\x01\r)\x01\x07)\x03\x05\x07)\x03\x11\x07)\x03\x05\r)\x03\t\r\x04\x92\x03\x05\x01\x11\x07\x19\x07\x03\x01\x05\t\x11\x07!\x05\x03Cm\x0b\x03/+\x03!\r\x063\x03\x0f\x03\x01\x05\x03\x017\x03\x05\x05\x03\x01=\x03\x05\x05\x03\x01\x15\x03\x15\x05\x03\x01\x15\x03\x15\x0f\x07\x01?\x11\x0f\x0f\x0f\x19\x19\x03\x03\x05\x0b\x05\x07\t\x0b\x03\x11\x06\x01\x03\t\x05\x13\x15\x05\x03\x01Q\x03\x05\x03\x07\x01\x03\x03\x05\x03\x1f\x13\x07\x01S\x03/\x05\x1b!\x03\x07\x01\x03\x031\x03#\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\t\x03'\x03\x07\x01Y\x033\x03%\x07\x06\x01\x03\t\x07+\x1d)\x03\x07\x01\x03\x03\x1b\x03#\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x03\x031\x03\x07\x01\x17\x03\x1d\x03/\x07\x06\x01\x03\x03\x075\x173\x03\x07\x01\x03\x03\x1b\x03#\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x03\x03;\x03\x07\x01\x17\x03\x1d\x039\x07\x06\x01\x03\x03\x07?\x19=\x15\x04\x07\x07-7A\x06\x03\x01\x05\x01\x00&\r]\x1b\x03\x0f\x0b\t\t\t!+\x1b\x1f/!!)#\x1f\x19\xb1}\x81\x1f\x1f\x15\x1d\x15\x13%)\x97\x13+\r\x15\x17\x17\x1f\x17\x11\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00complex_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit(func)/jit(main)/iota[dtype=float64 shape=(16,) dimension=0]\x00jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]\x00jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_dgeev\x00", - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["c64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_cgeev'], - serialized_date=datetime.date(2023, 6, 19), - inputs=(), - expected_outputs=(array([ 3.2464237e+01+0.j, -2.4642489e+00+0.j, -5.7737714e-07+0.j, - 1.4719126e-07+0.j], dtype=complex64), array([[ 0.4037776 +0.j, 0.8288327 +0.j, -0.53126234 -0.j, - 0.052026853-0.j], - [ 0.46480742 +0.j, 0.43714646 -0.j, 0.80768156 +0.j, - -0.47577178 -0.j], - [ 0.52583724 +0.j, 0.045459922-0.j, -0.021575088-0.j, - 0.79546237 +0.j], - [ 0.5868671 +0.j, -0.3462263 -0.j, -0.25484383 -0.j, - -0.3717177 -0.j]], dtype=complex64), array([[ 0.114176475+0.j, 0.7327782 +0.j, -0.5452461 -0.j, - -0.13326685 -0.j], - [ 0.3300045 +0.j, 0.28974816 -0.j, 0.68821603 +0.j, - -0.2182906 -0.j], - [ 0.5458328 +0.j, -0.1532814 -0.j, 0.25930583 -0.j, - 0.8363818 +0.j], - [ 0.76166093 +0.j, -0.5963111 -0.j, -0.40227592 -0.j, - -0.4848244 -0.j]], dtype=complex64)), - mlir_module_text=r""" -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<4xcomplex> {jax.result_info = "[0]"}, tensor<4x4xcomplex> {jax.result_info = "[1]"}, tensor<4x4xcomplex> {jax.result_info = "[2]"}) { - %0 = stablehlo.iota dim = 0 : tensor<16xcomplex> loc(#loc3) - %1 = stablehlo.reshape %0 : (tensor<16xcomplex>) -> tensor<4x4xcomplex> loc(#loc4) - %2 = stablehlo.constant dense<1> : tensor loc(#loc5) - %3 = stablehlo.constant dense<4> : tensor loc(#loc5) - %4 = stablehlo.constant dense<86> : tensor loc(#loc5) - %5 = stablehlo.constant dense<86> : tensor loc(#loc5) - %6:6 = stablehlo.custom_call @lapack_cgeev(%2, %3, %4, %5, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor, tensor<4x4xcomplex>) -> (tensor<4x4xcomplex>, tensor<8xf32>, tensor<4xcomplex>, tensor<4x4xcomplex>, tensor<4x4xcomplex>, tensor) loc(#loc5) - %7 = stablehlo.constant dense<0> : tensor loc(#loc5) - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor loc(#loc5) - %9 = stablehlo.compare EQ, %6#5, %8, SIGNED : (tensor, tensor) -> tensor loc(#loc5) - %10 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1xi1> loc(#loc5) - %11 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc5) - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor>) -> tensor<4xcomplex> loc(#loc5) - %13 = stablehlo.broadcast_in_dim %10, dims = [0] : (tensor<1xi1>) -> tensor<4xi1> loc(#loc5) - %14 = stablehlo.select %13, %6#2, %12 : tensor<4xi1>, tensor<4xcomplex> loc(#loc5) - %15 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %16 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc5) - %17 = stablehlo.broadcast_in_dim %16, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) - %18 = stablehlo.broadcast_in_dim %15, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) - %19 = stablehlo.select %18, %6#3, %17 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) - %20 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %21 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc5) - %22 = stablehlo.broadcast_in_dim %21, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) - %23 = stablehlo.broadcast_in_dim %20, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) - %24 = stablehlo.select %23, %6#4, %22 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) - return %14, %19, %24 : tensor<4xcomplex>, tensor<4x4xcomplex>, tensor<4x4xcomplex> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc = loc(unknown) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":496:0) -#loc2 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":497:0) -#loc3 = loc("jit(func)/jit(main)/iota[dtype=complex64 shape=(16,) dimension=0]"(#loc1)) -#loc4 = loc("jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]"(#loc1)) -#loc5 = loc("jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]"(#loc2)) -""", - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01#\x05\x01\x03\x01\x03\x05\x03\x13\x07\t\x0b\r\x0f\x11\x13\x15\x17\x03\xe5\x9b7\x01[\x0f\x13\x0b\x07\x0b\x13\x0f\x0b\x17\x0b\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x17\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x03A\x0fO\x0b\x0b/\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b\x0f\x1f\x1f\x13\x0b\x0b\x0b\x0b\x1f#\x1f\x0f\x0b\x0b//O\x01\x03\x0f\x035\x17\x0f\x07\x13\x0b\x07\x0f\x0f\x07\x07\x17\x17\x1b\x13\x07\x07\x13\x13\x13\x13\x13\x0f\x13\x13\x13\x13\x02Z\x06\x1d9;\x03\x03\t\x8f\x05\x19\x1f\x05\x1b\x03\x03\x05\x95\x11\x01\x05\x05\x1d\x17\x13\xc2\x07\x01\x05\x1f\x03\x03\x05\x7f\x03\x03\t\x99\x03\x07\x1b\r\x1d\r\x0f\x1f\x05!\x05#\x05%\x03\x0b#_%e'g\x0fu)w\x05'\x05)\x05+\x05-\x03\x03-y\x05/\x1d1\x11\x051\x1d5\x11\x053\x03\x03\x05{\x055\x17\x13\xc6\x07\x01\x03\x03\x05}\x03\x11A\x81C\x83E\x85G_I\x87K\x89M_O\x8b\x057\x059\x05;\x05=\x05?\x05A\x05C\x05E\x03\x03\x05\x8d\x03\x05U\x91W\x93\x05G\x05I\x03\x03\t\x97\x1f%\x01\x1f'!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1dK\x1f)\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x1b\x03\x07imq\r\x03ak\x1dM\r\x03ao\x1dO\r\x03as\x1dQ\x1dS\x1dU\x13\r\x01\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f\x11\x03V\x0b\x05\x1dW\x1dY\x05\x01\x03\x0b[[[[]\x03\r]cc]][\x1f\x05\t\x00\x00\x00\x00\x1f+\x01\t\x07\x07\x01\x1f\x0f\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f5!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\x0b)\x01\x1f\x01)\x03\x11\x0b\x03\x15\x1d)\x01\x0b)\x01!\x13\t)\x05\x05\x05\x07)\x05\x11\x11\x07\x11\x01\x07\t\x03\x03)\x03A\x0b\x1b!)\x03!\x15)\x03\x01\x13)\x03\t\x13)\x03\x05\x13)\x03\x01\r)\x01\x07)\x03\x05\x07)\x03\x11\x07)\x03\x05\r)\x03\t\r\x04j\x03\x05\x01\x11\x07\x19\x07\x03\x01\x05\t\x11\x07!\x05\x03=i\x0b\x03/+\x03\x1d\r\x063\x03\x03\x03\x01\x05\x03\x017\x03\x05\x05\x03\x01=\x03\x05\x05\x03\x01\x15\x03\x11\x05\x03\x01\x15\x03\x11\x0f\x07\x01?\r\x03#\t\x03\x03\x05\x0b\x05\x07\t\x0b\x03\x05\x03\x01Q\x03\x05\x03\x07\x01\x03\x03\x05\x03\x19\x11\x07\x01S\x03-\x05\x17\x1b\x03\x07\x01\x03\x03/\x03\x1d\x05\x03\x01\x0b\x03\x0f\x03\x07\x01\x03\x03\t\x03!\x03\x07\x01Y\x031\x03\x1f\x07\x06\x01\x03\t\x07%\x11#\x03\x07\x01\x03\x03\x17\x03\x1d\x05\x03\x01\x0b\x03\x0f\x03\x07\x01\x03\x03\x03\x03+\x03\x07\x01\x17\x03\x19\x03)\x07\x06\x01\x03\x03\x07/\x13-\x03\x07\x01\x03\x03\x17\x03\x1d\x05\x03\x01\x0b\x03\x0f\x03\x07\x01\x03\x03\x03\x035\x03\x07\x01\x17\x03\x19\x033\x07\x06\x01\x03\x03\x079\x157\x13\x04\x07\x07'1;\x06\x03\x01\x05\x01\x00\xfe\x0c[\x1b\x03\x0f\x0b\t\t\t!+\x1b\x1f/!!)#\x1f\x19\xb1}\x85\x1f\x1f\x15\x1d\x15\x13%)\x97\x13+\r\x15\x17\x1f\x17\x11\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit(func)/jit(main)/iota[dtype=complex64 shape=(16,) dimension=0]\x00jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]\x00jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_cgeev\x00", - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["c128"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_zgeev'], - serialized_date=datetime.date(2023, 6, 19), - inputs=(), - expected_outputs=(array([ 3.2464249196572965e+01+0.j, -2.4642491965729807e+00+0.j, - -1.6035677295293283e-15+0.j, 1.2218554396786611e-16+0.j]), array([[ 0.40377749076862335 +0.j, 0.8288327563197505 +0.j, - -0.5457111210844892 +0.j, -0.2322136424094458 -0.j], - [ 0.46480737115848997 +0.j, 0.4371463883638875 -0.j, - 0.7625701354883243 +0.j, -0.06012408092789514 -0.j], - [ 0.5258372515483578 +0.j, 0.045460020408024694-0.j, - 0.1119930922768192 +0.j, 0.8168890890841272 +0.j], - [ 0.5868671319382247 +0.j, -0.34622634754783854 -0.j, - -0.32885210668065423 +0.j, -0.5245513657467864 -0.j]]), array([[ 0.11417645138733871+0.j, 0.7327780959803554 +0.j, - -0.49606131100796214+0.j, -0.04689746607984153-0.j], - [ 0.3300045986655476 +0.j, 0.2897483523969264 -0.j, - 0.8344969112540657 +0.j, -0.34421909950105706-0.j], - [ 0.5458327459437571 +0.j, -0.15328139118650172-0.j, - -0.18080988948424467+0.j, 0.8291305972416383 +0.j], - [ 0.7616608932219663 +0.j, -0.5963111347699308 -0.j, - -0.1576257107618584 +0.j, -0.4380140316607401 -0.j]])), - mlir_module_text=r""" -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<4xcomplex> {jax.result_info = "[0]"}, tensor<4x4xcomplex> {jax.result_info = "[1]"}, tensor<4x4xcomplex> {jax.result_info = "[2]"}) { - %0 = stablehlo.iota dim = 0 : tensor<16xcomplex> loc(#loc3) - %1 = stablehlo.reshape %0 : (tensor<16xcomplex>) -> tensor<4x4xcomplex> loc(#loc4) - %2 = stablehlo.constant dense<1> : tensor loc(#loc5) - %3 = stablehlo.constant dense<4> : tensor loc(#loc5) - %4 = stablehlo.constant dense<86> : tensor loc(#loc5) - %5 = stablehlo.constant dense<86> : tensor loc(#loc5) - %6:6 = stablehlo.custom_call @lapack_zgeev(%2, %3, %4, %5, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor, tensor<4x4xcomplex>) -> (tensor<4x4xcomplex>, tensor<8xf64>, tensor<4xcomplex>, tensor<4x4xcomplex>, tensor<4x4xcomplex>, tensor) loc(#loc5) - %7 = stablehlo.constant dense<0> : tensor loc(#loc5) - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor loc(#loc5) - %9 = stablehlo.compare EQ, %6#5, %8, SIGNED : (tensor, tensor) -> tensor loc(#loc5) - %10 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1xi1> loc(#loc5) - %11 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc5) - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor>) -> tensor<4xcomplex> loc(#loc5) - %13 = stablehlo.broadcast_in_dim %10, dims = [0] : (tensor<1xi1>) -> tensor<4xi1> loc(#loc5) - %14 = stablehlo.select %13, %6#2, %12 : tensor<4xi1>, tensor<4xcomplex> loc(#loc5) - %15 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %16 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc5) - %17 = stablehlo.broadcast_in_dim %16, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) - %18 = stablehlo.broadcast_in_dim %15, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) - %19 = stablehlo.select %18, %6#3, %17 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) - %20 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %21 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc5) - %22 = stablehlo.broadcast_in_dim %21, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) - %23 = stablehlo.broadcast_in_dim %20, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) - %24 = stablehlo.select %23, %6#4, %22 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) - return %14, %19, %24 : tensor<4xcomplex>, tensor<4x4xcomplex>, tensor<4x4xcomplex> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc = loc(unknown) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":496:0) -#loc2 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":497:0) -#loc3 = loc("jit(func)/jit(main)/iota[dtype=complex128 shape=(16,) dimension=0]"(#loc1)) -#loc4 = loc("jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]"(#loc1)) -#loc5 = loc("jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]"(#loc2)) -""", - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01#\x05\x01\x03\x01\x03\x05\x03\x13\x07\t\x0b\r\x0f\x11\x13\x15\x17\x03\xe5\x9b7\x01[\x0f\x13\x0b\x07\x0b\x13\x0f\x0b\x17\x0b\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x17\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x03A\x0fO\x0b\x0b/\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b\x0f\x1f\x1f\x13\x0b\x0b\x0b\x0b\x1f#\x1f\x0f\x0b\x0bO/O\x01\x03\x0f\x035\x17\x0f\x07\x13\x0b\x07\x0f\x0f\x07\x07\x17\x17\x1b\x13\x07\x07\x13\x13\x13\x13\x13\x0f\x13\x13\x13\x13\x02z\x06\x1d9;\x03\x03\t\x8f\x05\x19\x1f\x05\x1b\x03\x03\x05\x95\x11\x01\x05\x05\x1d\x17\x13\xc2\x07\x01\x05\x1f\x03\x03\x05\x7f\x03\x03\t\x99\x03\x07\x1b\r\x1d\r\x0f\x1f\x05!\x05#\x05%\x03\x0b#_%e'g\x0fu)w\x05'\x05)\x05+\x05-\x03\x03-y\x05/\x1d1\x11\x051\x1d5\x11\x053\x03\x03\x05{\x055\x17\x13\xc6\x07\x01\x03\x03\x05}\x03\x11A\x81C\x83E\x85G_I\x87K\x89M_O\x8b\x057\x059\x05;\x05=\x05?\x05A\x05C\x05E\x03\x03\x05\x8d\x03\x05U\x91W\x93\x05G\x05I\x03\x03\t\x97\x1f%\x01\x1f'!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1dK\x1f)\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x1b\x03\x07imq\r\x03ak\x1dM\r\x03ao\x1dO\r\x03as\x1dQ\x1dS\x1dU\x13\r\x01\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f\x11\x03V\x0b\x05\x1dW\x1dY\x05\x01\x03\x0b[[[[]\x03\r]cc]][\x1f\x05\t\x00\x00\x00\x00\x1f+\x01\t\x07\x07\x01\x1f\x0f!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f5!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\x0b)\x01\x1f\x01)\x03\x11\x0b\x03\x15\x1d)\x01\x0b)\x01!\x13\x0b)\x05\x05\x05\x07)\x05\x11\x11\x07\x11\x01\x07\t\x03\x03)\x03A\x0b\x1b!)\x03!\x15)\x03\x01\x13)\x03\t\x13)\x03\x05\x13)\x03\x01\r)\x01\x07)\x03\x05\x07)\x03\x11\x07)\x03\x05\r)\x03\t\r\x04j\x03\x05\x01\x11\x07\x19\x07\x03\x01\x05\t\x11\x07!\x05\x03=i\x0b\x03/+\x03\x1d\r\x063\x03\x03\x03\x01\x05\x03\x017\x03\x05\x05\x03\x01=\x03\x05\x05\x03\x01\x15\x03\x11\x05\x03\x01\x15\x03\x11\x0f\x07\x01?\r\x03#\t\x03\x03\x05\x0b\x05\x07\t\x0b\x03\x05\x03\x01Q\x03\x05\x03\x07\x01\x03\x03\x05\x03\x19\x11\x07\x01S\x03-\x05\x17\x1b\x03\x07\x01\x03\x03/\x03\x1d\x05\x03\x01\x0b\x03\x0f\x03\x07\x01\x03\x03\t\x03!\x03\x07\x01Y\x031\x03\x1f\x07\x06\x01\x03\t\x07%\x11#\x03\x07\x01\x03\x03\x17\x03\x1d\x05\x03\x01\x0b\x03\x0f\x03\x07\x01\x03\x03\x03\x03+\x03\x07\x01\x17\x03\x19\x03)\x07\x06\x01\x03\x03\x07/\x13-\x03\x07\x01\x03\x03\x17\x03\x1d\x05\x03\x01\x0b\x03\x0f\x03\x07\x01\x03\x03\x03\x035\x03\x07\x01\x17\x03\x19\x033\x07\x06\x01\x03\x03\x079\x157\x13\x04\x07\x07'1;\x06\x03\x01\x05\x01\x00\x02\r[\x1b\x03\x0f\x0b\t\t\t!+\x1b\x1f/!!)#\x1f\x19\xb1}\x87\x1f\x1f\x15\x1d\x15\x13%)\x97\x13+\r\x15\x17\x1f\x17\x11\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit(func)/jit(main)/iota[dtype=complex128 shape=(16,) dimension=0]\x00jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]\x00jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_zgeev\x00", - xla_call_module_version=6, -) # End paste +import numpy as np +array = np.array +complex64 = np.complex64 data_2024_08_19 = {} - # Pasted from the test output (see export_back_compat_test_util.py module docstring) data_2024_08_19["c128"] = dict( testdata_version=1, diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eigh_lapack_syev.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eigh_lapack_syev.py index f0696db1aeda..68656d799b35 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eigh_lapack_syev.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eigh_lapack_syev.py @@ -15,378 +15,14 @@ # ruff: noqa import datetime -from numpy import array, float32, complex64 +import numpy as np -data_2023_03_17 = dict( - # Pasted from the test output (see back_compat_test.py module docstring) - f32=dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_ssyevd'], - serialized_date=datetime.date(2023, 3, 17), - inputs=(), - expected_outputs=(array([[-0.6185769 , -0.20142993 , -0.09725195 , 0.62983674 , - -0.07926044 , 0.3605001 , -0.019093221 , -0.18446997 ], - [-0.47070873 , 0.29325768 , -0.19454119 , -0.6394365 , - 0.0622955 , 0.33249345 , 0.28112718 , -0.22856665 ], - [-0.32284075 , -0.12361939 , 0.20547704 , -0.18307868 , - 0.47294614 , -0.3170349 , -0.6373532 , -0.27266347 ], - [-0.17497246 , -0.079641335 , 0.15042791 , -0.15416273 , - -0.815209 , -0.38054234 , -0.083263926 , -0.31676024 ], - [-0.027104253 , -0.26490977 , 0.32271704 , 0.08653544 , - 0.30305928 , -0.33998996 , 0.6926741 , -0.360857 ], - [ 0.12076397 , 0.43288827 , -0.64385164 , 0.2652551 , - 0.09482376 , -0.37435007 , 0.00091664493, -0.40495378 ], - [ 0.26863196 , 0.51607686 , 0.53846526 , 0.16969058 , - -0.021670295 , 0.35755336 , -0.113144726 , -0.4490505 ], - [ 0.4165004 , -0.57262254 , -0.2814425 , -0.17463988 , - -0.01698498 , 0.3613705 , -0.12186296 , -0.49314725 ]], - dtype=float32), array([-2.4598808e+01, -3.3105560e-05, -3.1002426e-05, -1.0103593e-05, - -1.0022322e-05, 4.0141886e-06, 9.5510331e-06, 2.7659882e+02], - dtype=float32)), - mlir_module_text=r""" -module @jit__lambda_ { - func.func public @main() -> (tensor<8x8xf32> {jax.result_info = "[0]"}, tensor<8xf32> {jax.result_info = "[1]"}) { - %0 = stablehlo.iota dim = 0 : tensor<64xf32> - %1 = stablehlo.reshape %0 : (tensor<64xf32>) -> tensor<8x8xf32> - %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<8x8xf32>) -> tensor<8x8xf32> - %3 = stablehlo.add %1, %2 : tensor<8x8xf32> - %4 = stablehlo.constant dense<2.000000e+00> : tensor - %5 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<8x8xf32> - %6 = stablehlo.divide %3, %5 : tensor<8x8xf32> - %7 = call @tril(%6) : (tensor<8x8xf32>) -> tensor<8x8xf32> - %8 = stablehlo.constant dense<1> : tensor - %9 = stablehlo.constant dense<1> : tensor - %10 = stablehlo.constant dense<8> : tensor - %11 = stablehlo.custom_call @lapack_ssyevd(%8, %9, %10, %7) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor<8x8xf32>) -> tuple, tensor<8xf32>, tensor, tensor<177xf32>, tensor<43xi32>> - %12 = stablehlo.get_tuple_element %11[0] : (tuple, tensor<8xf32>, tensor, tensor<177xf32>, tensor<43xi32>>) -> tensor<8x8xf32> - %13 = stablehlo.get_tuple_element %11[1] : (tuple, tensor<8xf32>, tensor, tensor<177xf32>, tensor<43xi32>>) -> tensor<8xf32> - %14 = stablehlo.get_tuple_element %11[2] : (tuple, tensor<8xf32>, tensor, tensor<177xf32>, tensor<43xi32>>) -> tensor - %15 = stablehlo.get_tuple_element %11[3] : (tuple, tensor<8xf32>, tensor, tensor<177xf32>, tensor<43xi32>>) -> tensor<177xf32> - %16 = stablehlo.get_tuple_element %11[4] : (tuple, tensor<8xf32>, tensor, tensor<177xf32>, tensor<43xi32>>) -> tensor<43xi32> - %17 = stablehlo.constant dense<0> : tensor - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor) -> tensor - %19 = stablehlo.compare EQ, %14, %18, SIGNED : (tensor, tensor) -> tensor - %20 = stablehlo.broadcast_in_dim %19, dims = [] : (tensor) -> tensor<1x1xi1> - %21 = stablehlo.constant dense<0x7FC00000> : tensor - %22 = stablehlo.broadcast_in_dim %21, dims = [] : (tensor) -> tensor<8x8xf32> - %23 = stablehlo.broadcast_in_dim %20, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<8x8xi1> - %24 = stablehlo.select %23, %12, %22 : tensor<8x8xi1>, tensor<8x8xf32> - %25 = stablehlo.broadcast_in_dim %19, dims = [] : (tensor) -> tensor<1xi1> - %26 = stablehlo.constant dense<0x7FC00000> : tensor - %27 = stablehlo.broadcast_in_dim %26, dims = [] : (tensor) -> tensor<8xf32> - %28 = stablehlo.broadcast_in_dim %25, dims = [0] : (tensor<1xi1>) -> tensor<8xi1> - %29 = stablehlo.select %28, %13, %27 : tensor<8xi1>, tensor<8xf32> - return %24, %29 : tensor<8x8xf32>, tensor<8xf32> - } - func.func private @tril(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { - %0 = stablehlo.iota dim = 0 : tensor<8x8xi32> - %1 = stablehlo.constant dense<0> : tensor - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<8x8xi32> - %3 = stablehlo.add %0, %2 : tensor<8x8xi32> - %4 = stablehlo.iota dim = 1 : tensor<8x8xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<8x8xi32>, tensor<8x8xi32>) -> tensor<8x8xi1> - %6 = stablehlo.constant dense<0.000000e+00> : tensor - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<8x8xf32> - %8 = stablehlo.select %5, %arg0, %7 : tensor<8x8xi1>, tensor<8x8xf32> - return %8 : tensor<8x8xf32> - } -} -""", - mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01-\x05\x01\x05\x01\x03\x05\x03\x1d\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!\x03z\x02\xf77\x01\x9b\x0f\x17\x13\x0b\x07\x0f\x0b\x0b\x0b\x0b\x17\x0b\x0b\x0b\x0b\x13\x0b\x13\x0f\x0b\x0b\x17\x0f\x13\x13\x13\x0b33\x0b\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x13\x0b\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x13\x1b\x13\x13\x03]\x0f/\x0b\x0b\x0f\x0b\x0bO\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b\x1fO\x1f\x1f\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17\x1f\x0f\x0f\x0f\x0f\x0f\x0b\x1fO/\x037\x17\x0f\x07\x0f\x07\x13\x07\x07\x17\x07\x17\x13\x17\x13\x17\x17\x13\x17\x1f\x13\x13\x13\x0f\x17\x13\x13\x13\x02\n\t\x1du\x03\x17\x11\xf6\x04\x01\x03\x03\x13\xc5\x05#\x1f\x1d;\x03\x05%\x05'\x05)\x05+\x17\x11\xf2\x04\x01\x05-\x05/\x051\x053\x03\x03!\xc1\x055\x03\x03\x07\xc3\x1dA\x03\x057\x059\x17\x11\xea\x04\x01\x1do\x15\x03\x03\x07\xd1\x03\x03\x07\xf1\x03\x03\x0f5\x05;\x03\x0b\x17\x9f\x19\xab\x1b\xad\x0f\xb7\x1d\xb9\x03\x0b\x17\xa3\x19\xbd\x1b\xa3\x0f\xa5\x1d\xbf\x05=\x1d?\x03\x05?\x05A\x03\x03!\xc7\x1dG\x03\x05C\x03\x05'\xa7)\xc9\x1dM\x03\x05E\x03\x03\x07\xcb\x1dS\x03\x05G\x1dW\x03\x05I\x1d[+\x05K\x1d_+\x05M\x03\x03c\xcd\x05O\x1dg\x15\x05Q\x1dk\x15\x05S\x03\x03\x07\xcf\x05U\x03\x03s\xa5\x05W\x05Y\x03\x03\x07\xd3\x03\x11{\xd5}\xd7\x7f\xd9\x81\x9f\x83\xdb\x85\xdd\x87\xdf\x89\xe3\x05[\x05]\x05_\x05a\x05c\x05e\x05g\x05i\x03\x03\r\xe5\x03\x03\r\xe7\x03\x03\r\xe9\x03\x03\r\xeb\x03\x03\r\xed\x03\x05'\xa7)\xef\x03\x03\x13\xf3\x03\x03\x13\xf5\x1f'\x01\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1dk\x03\x03\xbb\x1dm\t\x07\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00#\x1d\x03\x05\xaf\xb3\r\x03\xa1\xb1\x1do\r\x03\xa1\xb5\x1dq\x1ds\x1du\r\x01#\x1f\x1dw\x13\r\x01\x1f\x03\t\x00\x00\x00\x00\x1f!\x01\x13\r\x05\x07\x05\x1f\x07\t\x00\x00\x00\x00\x1f\x17!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07\t\x00\x00\x00@\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x08\x00\x00\x00\x0b\x05\x1dy\x1d{\x05\x01\x03\t\x9b\x9b\x9b\xa9\x03\x03\xe1\x15\x03\x01\r\x01\x03\x0b\xa9\x9d\x9b\x9d\x9d\x13\x05\x01\x13\x05\x05\x13\x05\t\x13\x05\r\x13\x05\x11\x07\x01\x1f\x07\t\x00\x00\xc0\x7f\x1f\x17!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00)\x05!!\t)\x01\x05\x1b)\x01\t\t)\x03!\t\x1d\x01)\x05!!\x05\x13)\x05!!\x0f)\x03\t\r)\x03\x8a\x05\t)\x03\xad\x05\x11\x01\x05\x01\x0b\x11\x03\x01\x03\x01)\x03\x01\r)\x03\x02\x02\t/\x0b\x01\x0b\x03\x19\x1b)\x03\x01\x13)\x03\t\x13)\x03\x05\x13)\x01\x0f)\x05\x05\x05\x0f)\x03\x05\x0f)\x03!\x0f)\x03\x05\r\x04:\x05\x05\x01\x11\t3\x07\x03\x01\t\r\x11\t7\x05\x03=}\t\x03Y\x1f\x03#\x15\x06]\x03\x01\x03\x01\x17\x07ea\x03\x01\x03\x03\x0f\x06i\x03\x01\x05\x03\x05\x05\x03\tm\x03\x07\x03\x07-\x05\x03\x01\x03\t\x19\x06-\x03\x01\x05\x07\x0b\x1b\x07\x0bq\x03\x01\x03\r\x05\x03\x01/\x03\x03\x05\x03\x01/\x03\x03\x05\x03\x01w\x03\x03\x1d\x07\x01y\x03%\t\x11\x13\x15\x0f\x07\x07\x01\x8b\x03\x01\x03\x17\x07\x07\x01\x8d\x03\x0b\x03\x17\x07\x07\x01\x8f\x03\x03\x03\x17\x07\x07\x01\x91\x03\x19\x03\x17\x07\x07\x01\x93\x03\x1b\x03\x17\x05\x03\x01#\x03\x03\x03\x07\x01\x05\x03\x03\x03#\x11\x07\x01\x95\x03-\x05\x1d%\x03\x07\x01\x05\x03/\x03'\x05\x03\x011\x03\x07\x03\x07\x01\x05\x03\x01\x03+\x03\x07\x01\x97\x03\x15\x03)\x0b\x06\x01\x03\x01\x07/\x19-\x03\x07\x01\x05\x031\x03'\x05\x03\x011\x03\x07\x03\x07\x01\x05\x03\x0b\x035\x03\x07\x01\x99\x033\x033\x0b\x06\x01\x03\x0b\x079\x1b7\x13\x04\t\x051;\r\x11\x0b9\x05\x03\x15+\x03\x01\t\t\x03=\x1f\x03\x11\x05\x03\x0b#\x03\x03\x03\x07%\x05\x03\x11\x03\x05\x0f\x06%\x03\x11\x05\x03\x07\t\x03EC\x03\x11\x11\x07KI\x03\x15\x05\t\x0b\x05\x03\x0bO\x03\x07\x03\x07Q\x05\x03\x01\x03\x0f\x0b\x06U\x03\x01\x07\r\x01\x11\x13\x04\x0b\x03\x13\x06\x03\x01\x05\x01\x00\xb2\x19}\x1d\x03\x11\x0f\x0b\t\t\x0b!\x1f/!!)#\x1f\x19\x7f\x0f99m\x19\x85\x89W\xb3K\x9bM\x9b\x96\x04\x1b+\x1b\x1f\x1f\x15\x1d\x15+\x83\x13\r\r\x1f\x11\x15\x1b\x17\x15\x17\x0f\x11\x15\x11+\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00get_tuple_element_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00value\x00index\x00sym_name\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00compare_type\x00comparison_direction\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril in_positional_semantics=(<_PositionalSemantics.GLOBAL: 1>,) out_positional_semantics=_PositionalSemantics.GLOBAL keep_unused=False inline=False]\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota[dtype=float32 shape=(64,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]\x00permutation\x00jit()/jit(main)/transpose[permutation=(1, 0)]\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00callee\x00jit()/jit(main)/eigh[lower=True sort_eigenvalues=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00lapack_ssyevd\x00", - xla_call_module_version=4, - ), # End paste - - # Pasted from the test output (see back_compat_test.py module docstring) - f64=dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_dsyevd'], - serialized_date=datetime.date(2023, 3, 17), - inputs=(), - expected_outputs=(array([[-6.1857700048412056e-01, 2.4081403770912022e-01, - 3.5662489253627483e-01, -6.3034019033669797e-01, - 1.0043483479985752e-16, -2.8842036081919542e-02, - 7.7164692943283169e-25, -1.8446994643771725e-01], - [-4.7070881487314614e-01, 4.7473787464450845e-01, - -4.8036836210243367e-01, 4.3802686872516400e-01, - 1.7961797619639258e-01, 8.3080980076741355e-03, - 2.1415294457221756e-01, -2.2856669794666584e-01], - [-3.2284062926217072e-01, -5.4336490915553370e-01, - 2.2181041859724990e-01, 2.9947877954402297e-01, - -3.6491813600134632e-01, 3.2867679819727436e-01, - 3.8223299448843473e-01, -2.7266344945561438e-01], - [-1.7497244365119530e-01, -8.9251550609769414e-02, - -6.3518515114898394e-02, 1.9162997359209971e-01, - -2.2087281326110139e-01, 5.9957027043505064e-02, - -8.7632498908241274e-01, -3.1676020096456303e-01], - [-2.7104258040220038e-02, -3.3772873786627672e-01, - 2.5901386593721748e-01, 1.7032650752287815e-01, - 6.7521217612940332e-01, -4.5036136532965476e-01, - -1.2279030059078447e-02, -3.6085695247351163e-01], - [ 1.2076392757075530e-01, -3.3834734096469254e-01, - -6.5506827461665540e-01, -5.0472498521116749e-01, - 6.9987430903492118e-02, 1.0595648906599275e-01, - 8.3443844143082022e-02, -4.0495370398246017e-01], - [ 2.6863211318173097e-01, 2.2958613191407318e-01, - 6.3952843755683941e-02, 1.8776775771084137e-02, - -5.3523731432241317e-01, -5.9199531677602002e-01, - 1.7916671834524248e-01, -4.4905045549140887e-01], - [ 4.1650029879270661e-01, 3.6355449432857079e-01, - 2.9755313100756142e-01, 1.6826270392615944e-02, - 1.9621068035557282e-01, 5.6830030587314817e-01, - 2.9607517592514246e-02, -4.9314720700035747e-01]]), array([-2.4598804776133626e+01, -4.6567755957874661e-14, - -1.9932120610662194e-14, -5.7323356091157378e-15, - -4.5459724251334835e-16, 4.0479851042511616e-14, - 9.2325194924982089e-14, 2.7659880477613365e+02])), - mlir_module_text=r""" -module @jit__lambda_ { - func.func public @main() -> (tensor<8x8xf64> {jax.result_info = "[0]"}, tensor<8xf64> {jax.result_info = "[1]"}) { - %0 = stablehlo.iota dim = 0 : tensor<64xf64> - %1 = stablehlo.reshape %0 : (tensor<64xf64>) -> tensor<8x8xf64> - %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<8x8xf64>) -> tensor<8x8xf64> - %3 = stablehlo.add %1, %2 : tensor<8x8xf64> - %4 = stablehlo.constant dense<2.000000e+00> : tensor - %5 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<8x8xf64> - %6 = stablehlo.divide %3, %5 : tensor<8x8xf64> - %7 = call @tril(%6) : (tensor<8x8xf64>) -> tensor<8x8xf64> - %8 = stablehlo.constant dense<1> : tensor - %9 = stablehlo.constant dense<1> : tensor - %10 = stablehlo.constant dense<8> : tensor - %11 = stablehlo.custom_call @lapack_dsyevd(%8, %9, %10, %7) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor<8x8xf64>) -> tuple, tensor<8xf64>, tensor, tensor<177xf64>, tensor<43xi32>> - %12 = stablehlo.get_tuple_element %11[0] : (tuple, tensor<8xf64>, tensor, tensor<177xf64>, tensor<43xi32>>) -> tensor<8x8xf64> - %13 = stablehlo.get_tuple_element %11[1] : (tuple, tensor<8xf64>, tensor, tensor<177xf64>, tensor<43xi32>>) -> tensor<8xf64> - %14 = stablehlo.get_tuple_element %11[2] : (tuple, tensor<8xf64>, tensor, tensor<177xf64>, tensor<43xi32>>) -> tensor - %15 = stablehlo.get_tuple_element %11[3] : (tuple, tensor<8xf64>, tensor, tensor<177xf64>, tensor<43xi32>>) -> tensor<177xf64> - %16 = stablehlo.get_tuple_element %11[4] : (tuple, tensor<8xf64>, tensor, tensor<177xf64>, tensor<43xi32>>) -> tensor<43xi32> - %17 = stablehlo.constant dense<0> : tensor - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor) -> tensor - %19 = stablehlo.compare EQ, %14, %18, SIGNED : (tensor, tensor) -> tensor - %20 = stablehlo.broadcast_in_dim %19, dims = [] : (tensor) -> tensor<1x1xi1> - %21 = stablehlo.constant dense<0x7FF8000000000000> : tensor - %22 = stablehlo.broadcast_in_dim %21, dims = [] : (tensor) -> tensor<8x8xf64> - %23 = stablehlo.broadcast_in_dim %20, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<8x8xi1> - %24 = stablehlo.select %23, %12, %22 : tensor<8x8xi1>, tensor<8x8xf64> - %25 = stablehlo.broadcast_in_dim %19, dims = [] : (tensor) -> tensor<1xi1> - %26 = stablehlo.constant dense<0x7FF8000000000000> : tensor - %27 = stablehlo.broadcast_in_dim %26, dims = [] : (tensor) -> tensor<8xf64> - %28 = stablehlo.broadcast_in_dim %25, dims = [0] : (tensor<1xi1>) -> tensor<8xi1> - %29 = stablehlo.select %28, %13, %27 : tensor<8xi1>, tensor<8xf64> - return %24, %29 : tensor<8x8xf64>, tensor<8xf64> - } - func.func private @tril(%arg0: tensor<8x8xf64>) -> tensor<8x8xf64> { - %0 = stablehlo.iota dim = 0 : tensor<8x8xi32> - %1 = stablehlo.constant dense<0> : tensor - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<8x8xi32> - %3 = stablehlo.add %0, %2 : tensor<8x8xi32> - %4 = stablehlo.iota dim = 1 : tensor<8x8xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<8x8xi32>, tensor<8x8xi32>) -> tensor<8x8xi1> - %6 = stablehlo.constant dense<0.000000e+00> : tensor - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<8x8xf64> - %8 = stablehlo.select %5, %arg0, %7 : tensor<8x8xi1>, tensor<8x8xf64> - return %8 : tensor<8x8xf64> - } -} -""", - mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01-\x05\x01\x05\x01\x03\x05\x03\x1d\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!\x03z\x02\xf77\x01\x9b\x0f\x17\x13\x0b\x07\x0f\x0b\x0b\x0b\x0b\x17\x0b\x0b\x0b\x0b\x13\x0b\x13\x0f\x0b\x0b\x17\x0f\x13\x13\x13\x0b33\x0b\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x13\x0b\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x13\x1b\x13\x13\x03]\x0f/\x0b\x0b\x0f\x0b\x0bO\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b/O/\x1f\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17\x1f\x0f\x0f\x0f\x0f\x0f\x0b/O/\x037\x17\x0f\x07\x0f\x07\x13\x07\x07\x17\x07\x17\x13\x17\x13\x17\x17\x13\x17\x1f\x13\x13\x13\x0f\x17\x13\x13\x13\x02:\t\x1du\x03\x17\x11\xf6\x04\x01\x03\x03\x13\xc5\x05#\x1f\x1d;\x03\x05%\x05'\x05)\x05+\x17\x11\xf2\x04\x01\x05-\x05/\x051\x053\x03\x03!\xc1\x055\x03\x03\x07\xc3\x1dA\x03\x057\x059\x17\x11\xea\x04\x01\x1do\x15\x03\x03\x07\xd1\x03\x03\x07\xf1\x03\x03\x0f5\x05;\x03\x0b\x17\x9f\x19\xab\x1b\xad\x0f\xb7\x1d\xb9\x03\x0b\x17\xa3\x19\xbd\x1b\xa3\x0f\xa5\x1d\xbf\x05=\x1d?\x03\x05?\x05A\x03\x03!\xc7\x1dG\x03\x05C\x03\x05'\xa7)\xc9\x1dM\x03\x05E\x03\x03\x07\xcb\x1dS\x03\x05G\x1dW\x03\x05I\x1d[+\x05K\x1d_+\x05M\x03\x03c\xcd\x05O\x1dg\x15\x05Q\x1dk\x15\x05S\x03\x03\x07\xcf\x05U\x03\x03s\xa5\x05W\x05Y\x03\x03\x07\xd3\x03\x11{\xd5}\xd7\x7f\xd9\x81\x9f\x83\xdb\x85\xdd\x87\xdf\x89\xe3\x05[\x05]\x05_\x05a\x05c\x05e\x05g\x05i\x03\x03\r\xe5\x03\x03\r\xe7\x03\x03\r\xe9\x03\x03\r\xeb\x03\x03\r\xed\x03\x05'\xa7)\xef\x03\x03\x13\xf3\x03\x03\x13\xf5\x1f'\x01\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1dk\x03\x03\xbb\x1dm\t\x07\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00#\x1d\x03\x05\xaf\xb3\r\x03\xa1\xb1\x1do\r\x03\xa1\xb5\x1dq\x1ds\x1du\r\x01#\x1f\x1dw\x13\r\x01\x1f\x03\t\x00\x00\x00\x00\x1f!\x01\x13\r\x05\x07\x05\x1f\x07\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x17!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07\x11\x00\x00\x00\x00\x00\x00\x00@\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x08\x00\x00\x00\x0b\x05\x1dy\x1d{\x05\x01\x03\t\x9b\x9b\x9b\xa9\x03\x03\xe1\x15\x03\x01\r\x01\x03\x0b\xa9\x9d\x9b\x9d\x9d\x13\x05\x01\x13\x05\x05\x13\x05\t\x13\x05\r\x13\x05\x11\x07\x01\x1f\x07\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x17!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00)\x05!!\t)\x01\x05\x1b)\x01\t\x0b)\x03!\t\x1d\x01)\x05!!\x05\x13)\x05!!\x0f)\x03\t\r)\x03\x8a\x05\t)\x03\xad\x05\x11\x01\x05\x01\x0b\x11\x03\x01\x03\x01)\x03\x01\r)\x03\x02\x02\t/\x0b\x01\x0b\x03\x19\x1b)\x03\x01\x13)\x03\t\x13)\x03\x05\x13)\x01\x0f)\x05\x05\x05\x0f)\x03\x05\x0f)\x03!\x0f)\x03\x05\r\x04:\x05\x05\x01\x11\t3\x07\x03\x01\t\r\x11\t7\x05\x03=}\t\x03Y\x1f\x03#\x15\x06]\x03\x01\x03\x01\x17\x07ea\x03\x01\x03\x03\x0f\x06i\x03\x01\x05\x03\x05\x05\x03\tm\x03\x07\x03\x07-\x05\x03\x01\x03\t\x19\x06-\x03\x01\x05\x07\x0b\x1b\x07\x0bq\x03\x01\x03\r\x05\x03\x01/\x03\x03\x05\x03\x01/\x03\x03\x05\x03\x01w\x03\x03\x1d\x07\x01y\x03%\t\x11\x13\x15\x0f\x07\x07\x01\x8b\x03\x01\x03\x17\x07\x07\x01\x8d\x03\x0b\x03\x17\x07\x07\x01\x8f\x03\x03\x03\x17\x07\x07\x01\x91\x03\x19\x03\x17\x07\x07\x01\x93\x03\x1b\x03\x17\x05\x03\x01#\x03\x03\x03\x07\x01\x05\x03\x03\x03#\x11\x07\x01\x95\x03-\x05\x1d%\x03\x07\x01\x05\x03/\x03'\x05\x03\x011\x03\x07\x03\x07\x01\x05\x03\x01\x03+\x03\x07\x01\x97\x03\x15\x03)\x0b\x06\x01\x03\x01\x07/\x19-\x03\x07\x01\x05\x031\x03'\x05\x03\x011\x03\x07\x03\x07\x01\x05\x03\x0b\x035\x03\x07\x01\x99\x033\x033\x0b\x06\x01\x03\x0b\x079\x1b7\x13\x04\t\x051;\r\x11\x0b9\x05\x03\x15+\x03\x01\t\t\x03=\x1f\x03\x11\x05\x03\x0b#\x03\x03\x03\x07%\x05\x03\x11\x03\x05\x0f\x06%\x03\x11\x05\x03\x07\t\x03EC\x03\x11\x11\x07KI\x03\x15\x05\t\x0b\x05\x03\x0bO\x03\x07\x03\x07Q\x05\x03\x01\x03\x0f\x0b\x06U\x03\x01\x07\r\x01\x11\x13\x04\x0b\x03\x13\x06\x03\x01\x05\x01\x00\xb2\x19}\x1d\x03\x11\x0f\x0b\t\t\x0b!\x1f/!!)#\x1f\x19\x7f\x0f99m\x19\x85\x89W\xb3K\x9bM\x9b\x96\x04\x1b+\x1b\x1f\x1f\x15\x1d\x15+\x83\x13\r\r\x1f\x11\x15\x1b\x17\x15\x17\x0f\x11\x15\x11+\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00get_tuple_element_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00value\x00index\x00sym_name\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00compare_type\x00comparison_direction\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril in_positional_semantics=(<_PositionalSemantics.GLOBAL: 1>,) out_positional_semantics=_PositionalSemantics.GLOBAL keep_unused=False inline=False]\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota[dtype=float64 shape=(64,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]\x00permutation\x00jit()/jit(main)/transpose[permutation=(1, 0)]\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00callee\x00jit()/jit(main)/eigh[lower=True sort_eigenvalues=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00lapack_dsyevd\x00", - xla_call_module_version=4, - ), # End paste - - # Pasted from the test output (see back_compat_test.py module docstring) - c64=dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_cheevd'], - serialized_date=datetime.date(2023, 3, 17), - inputs=(), - expected_outputs=(array([[-0.6185769 +0.j, -0.20142993 +0.j, -0.09725195 +0.j, - 0.62983674 +0.j, -0.07926044 +0.j, 0.3605001 -0.j, - -0.019093221 +0.j, -0.18446997 +0.j], - [-0.47070873 +0.j, 0.29325768 +0.j, -0.19454116 +0.j, - -0.6394365 +0.j, 0.06229549 +0.j, 0.33249345 +0.j, - 0.28112718 +0.j, -0.22856665 +0.j], - [-0.32284075 +0.j, -0.12361939 +0.j, 0.20547704 +0.j, - -0.18307868 +0.j, 0.47294614 +0.j, -0.3170349 +0.j, - -0.6373532 +0.j, -0.27266347 +0.j], - [-0.17497246 +0.j, -0.079641335 +0.j, 0.15042792 +0.j, - -0.15416273 +0.j, -0.815209 +0.j, -0.38054234 +0.j, - -0.083263926 +0.j, -0.31676024 +0.j], - [-0.027104257 +0.j, -0.26490977 +0.j, 0.32271704 +0.j, - 0.08653544 +0.j, 0.30305928 +0.j, -0.33998996 +0.j, - 0.6926741 +0.j, -0.360857 +0.j], - [ 0.120763965 +0.j, 0.43288827 +0.j, -0.64385164 +0.j, - 0.2652551 +0.j, 0.094823755 +0.j, -0.37435007 +0.j, - 0.00091664493+0.j, -0.40495378 +0.j], - [ 0.26863196 +0.j, 0.51607686 +0.j, 0.53846526 +0.j, - 0.16969058 +0.j, -0.0216703 +0.j, 0.35755336 +0.j, - -0.113144726 +0.j, -0.4490505 +0.j], - [ 0.4165004 +0.j, -0.57262254 +0.j, -0.28144246 +0.j, - -0.17463988 +0.j, -0.016984984 +0.j, 0.3613705 +0.j, - -0.12186296 +0.j, -0.49314725 +0.j]], dtype=complex64), array([-2.4598808e+01, -3.3105560e-05, -3.1002426e-05, -1.0103593e-05, - -1.0022322e-05, 4.0141886e-06, 9.5510331e-06, 2.7659882e+02], - dtype=float32)), - mlir_module_text=r""" -module @jit__lambda_ { - func.func public @main() -> (tensor<8x8xcomplex> {jax.result_info = "[0]"}, tensor<8xf32> {jax.result_info = "[1]"}) { - %0 = stablehlo.iota dim = 0 : tensor<64xcomplex> - %1 = stablehlo.reshape %0 : (tensor<64xcomplex>) -> tensor<8x8xcomplex> - %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<8x8xcomplex>) -> tensor<8x8xcomplex> - %3 = stablehlo.real %2 : (tensor<8x8xcomplex>) -> tensor<8x8xf32> - %4 = stablehlo.imag %2 : (tensor<8x8xcomplex>) -> tensor<8x8xf32> - %5 = stablehlo.negate %4 : tensor<8x8xf32> - %6 = stablehlo.complex %3, %5 : tensor<8x8xcomplex> - %7 = stablehlo.add %1, %6 : tensor<8x8xcomplex> - %8 = stablehlo.constant dense<(2.000000e+00,0.000000e+00)> : tensor> - %9 = stablehlo.broadcast_in_dim %8, dims = [] : (tensor>) -> tensor<8x8xcomplex> - %10 = stablehlo.divide %7, %9 : tensor<8x8xcomplex> - %11 = call @tril(%10) : (tensor<8x8xcomplex>) -> tensor<8x8xcomplex> - %12 = stablehlo.constant dense<1> : tensor - %13 = stablehlo.constant dense<1> : tensor - %14 = stablehlo.constant dense<8> : tensor - %15 = stablehlo.custom_call @lapack_cheevd(%12, %13, %14, %11) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor<8x8xcomplex>) -> tuple>, tensor<8xf32>, tensor, tensor<81xcomplex>, tensor<169xf32>, tensor<43xi32>> - %16 = stablehlo.get_tuple_element %15[0] : (tuple>, tensor<8xf32>, tensor, tensor<81xcomplex>, tensor<169xf32>, tensor<43xi32>>) -> tensor<8x8xcomplex> - %17 = stablehlo.get_tuple_element %15[1] : (tuple>, tensor<8xf32>, tensor, tensor<81xcomplex>, tensor<169xf32>, tensor<43xi32>>) -> tensor<8xf32> - %18 = stablehlo.get_tuple_element %15[2] : (tuple>, tensor<8xf32>, tensor, tensor<81xcomplex>, tensor<169xf32>, tensor<43xi32>>) -> tensor - %19 = stablehlo.get_tuple_element %15[3] : (tuple>, tensor<8xf32>, tensor, tensor<81xcomplex>, tensor<169xf32>, tensor<43xi32>>) -> tensor<81xcomplex> - %20 = stablehlo.get_tuple_element %15[4] : (tuple>, tensor<8xf32>, tensor, tensor<81xcomplex>, tensor<169xf32>, tensor<43xi32>>) -> tensor<169xf32> - %21 = stablehlo.get_tuple_element %15[5] : (tuple>, tensor<8xf32>, tensor, tensor<81xcomplex>, tensor<169xf32>, tensor<43xi32>>) -> tensor<43xi32> - %22 = stablehlo.constant dense<0> : tensor - %23 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor) -> tensor - %24 = stablehlo.compare EQ, %18, %23, SIGNED : (tensor, tensor) -> tensor - %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1x1xi1> - %26 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> - %27 = stablehlo.broadcast_in_dim %26, dims = [] : (tensor>) -> tensor<8x8xcomplex> - %28 = stablehlo.broadcast_in_dim %25, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<8x8xi1> - %29 = stablehlo.select %28, %16, %27 : tensor<8x8xi1>, tensor<8x8xcomplex> - %30 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1xi1> - %31 = stablehlo.constant dense<0x7FC00000> : tensor - %32 = stablehlo.broadcast_in_dim %31, dims = [] : (tensor) -> tensor<8xf32> - %33 = stablehlo.broadcast_in_dim %30, dims = [0] : (tensor<1xi1>) -> tensor<8xi1> - %34 = stablehlo.select %33, %17, %32 : tensor<8xi1>, tensor<8xf32> - return %29, %34 : tensor<8x8xcomplex>, tensor<8xf32> - } - func.func private @tril(%arg0: tensor<8x8xcomplex>) -> tensor<8x8xcomplex> { - %0 = stablehlo.iota dim = 0 : tensor<8x8xi32> - %1 = stablehlo.constant dense<0> : tensor - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<8x8xi32> - %3 = stablehlo.add %0, %2 : tensor<8x8xi32> - %4 = stablehlo.iota dim = 1 : tensor<8x8xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<8x8xi32>, tensor<8x8xi32>) -> tensor<8x8xi1> - %6 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor>) -> tensor<8x8xcomplex> - %8 = stablehlo.select %5, %arg0, %7 : tensor<8x8xi1>, tensor<8x8xcomplex> - return %8 : tensor<8x8xcomplex> - } -} -""", - mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x015\x05\x01\x05\x01\x03\x05\x03%\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!#%')\x03\xc6\x02\x1e\x02?\x01\xa9\x0f\x17\x13\x0b\x17\x0b\x07\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0b\x13\x0f\x0b\x0b\x17\x0f\x13\x13\x0b33\x0b\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x13\x0b\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x13\x13\x1b\x17\x03a\x0f/\x0b\x0b\x0f\x0b\x0bO\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b/O/\x1f\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17#\x0f\x0f\x0f\x0f\x0f\x0f\x0b/O\x1f/\x01\x07\x17\x17\x17\x03?\x17\x0f\x07\x0f\x07\x13\x07\x07\x0b\x17\x17\x07\x17\x13\x17\x17\x13\x0f\x17\x17\x13\x17#\x13\x13\x13\x0f\x17\x13\x13\x13\x02&\n\x1d\x83\x03\x17\x13\xf6\x04\x01\x03\x03\x15\xd3\x05+\x17\x13\xf2\x04\x01\x05-\x1f\x1d9\x03\x05/\x051\x053\x055\x057\x059\x05;\x03\x03!\xcf\x05=\x03\x03\x07\xd1\x1d?\x03\x05?\x05A\x17\x13\xea\x04\x01\x1d}\t\x03\x03\x07\xdf\x03\x03\x113\x05C\x03\x0b\x17\xad\x19\xb9\x1b\xbb\x11\xc5\x1d\xc7\x03\x0b\x17\xb1\x19\xcb\x1b\xb1\x11\xb3\x1d\xcd\x05E\x1d=\x03\x05G\x05I\x03\x03!\xd5\x1dE\x03\x05K\x03\x05'\xb5)\xd7\x1dK\x03\x05M\x03\x03\x07\xd9\x1dQ\x03\x05O\x1dU\x03\x05Q\x1dY+\x05S\x1d]+\x05U\x03\x03a\xdb\x05W\x1de\t\x05Y\x1di\t\x05[\x1dm\t\x05]\x1dq\t\x05_\x1du\t\x05a\x1dy\t\x05c\x03\x03\x07\xdd\x05e\x03\x03\x81\xb3\x05g\x05i\x03\x03\x07\xe1\x03\x11\x89\xe3\x8b\xe5\x8d\xe7\x8f\xad\x91\xe9\x93\xeb\x95\xed\x97\xf1\x05k\x05m\x05o\x05q\x05s\x05u\x05w\x05y\x03\x03\x0b\xf3\x03\x03\x0b\xf5\x03\x03\x0b\xf7\x03\x03\x0b\xf9\x03\x03\x0b\xfb\x03\x03\x0b\xfd\x03\x05'\xb5)\xff\x03\x03\x07\x02\x02\x1f/\x01\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1d{\x03\x03\xc9\x1d}\t\x07\x1f1!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00#%\x03\x05\xbd\xc1\r\x03\xaf\xbf\x1d\x7f\r\x03\xaf\xc3\x1d\x81\x1d\x83\x1d\x85\r\x01#'\x1d\x87\x13\r\x01\x1f\x03\t\x00\x00\x00\x00\x1f)\x01\x13\r\x05\x07\x05\x1f\x07\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1b!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07\x11\x00\x00\x00@\x00\x00\x00\x00\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x08\x00\x00\x00\x0b\x05\x1d\x89\x1d\x8b\x05\x01\x03\t\xa9\xa9\xa9\xb7\x03\x03\xef\x15\x03\x01\r\x01\x03\r\xb7\xab\xa9\xab\xab\xab\x13\x05\x01\x13\x05\x05\x13\x05\t\x13\x05\r\x13\x05\x11\x13\x05\x15\x07\x01\x1f\x07\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f#\t\x00\x00\xc0\x7f\x1f=\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03\x15\x06\x02\x03\x03\x07\n\x02\x03\x03\x15\x0e\x02)\x05!!\x11)\x01\x05\x1b)\x01\x11\t)\x03!\t\x1d\x01\x03\t)\x05!!\x05)\x05!!\t\x13)\x05!!\x0f)\x03\t\r)\x03\x8a\x02\x11)\x03J\x05\t)\x03\xad\x05)\x01\t\x11\x01\x05\x01\x0b\x11\x03\x01\x03\x01)\x03\x01\r)\x03\x02\x02\x11/\r\x01\x0b\x03\x1d\x1f!)\x03\x01\x17)\x03\t\x17)\x03\x05\x17)\x01\x0f)\x05\x05\x05\x0f)\x03\x05\x0f)\x03!\x0f)\x03\x05\r\x04\xda\x05\x05\x01\x11\r1\x07\x03\x01\t\r\x11\r5\x05\x03G\x91\t\x03W\x1f\x03+\x15\x06[\x03\x01\x03\x01\x17\x07c_\x03\x01\x03\x03\x19\x06g\x03\x15\x03\x05\x1b\x06k\x03\x15\x03\x05\x1d\x06o\x03\x15\x03\t\x1f\x06s\x03\x01\x05\x07\x0b\x0f\x06w\x03\x01\x05\x03\r\x05\x03\r{\x03\x07\x03\x07-\x05\x03\x01\x03\x11!\x06-\x03\x01\x05\x0f\x13#\x07\x0f\x7f\x03\x01\x03\x15\x05\x03\x01/\x03\x03\x05\x03\x01/\x03\x03\x05\x03\x01\x85\x03\x03%\x07\x01\x87\x03-\t\x19\x1b\x1d\x17\x07\x07\x01\x99\x03\x01\x03\x1f\x07\x07\x01\x9b\x03\x0b\x03\x1f\x07\x07\x01\x9d\x03\x03\x03\x1f\x07\x07\x01\x9f\x03\x1d\x03\x1f\x07\x07\x01\xa1\x03\x1f\x03\x1f\x07\x07\x01\xa3\x03!\x03\x1f\x05\x03\x01#\x03\x03\x03\x07\x01\x05\x03\x03\x03-\x11\x07\x01\xa5\x035\x05%/\x03\x07\x01\x05\x037\x031\x05\x03\x01\xa7\x03\x07\x03\x07\x01\x05\x03\x01\x035\x03\x07\x01\x12\x02\x03\x19\x033\x0b\x06\x01\x03\x01\x079!7\x03\x07\x01\x05\x039\x031\x05\x03\x01\x16\x02\x03#\x03\x07\x01\x05\x03\x0b\x03?\x03\x07\x01\x1a\x02\x03;\x03=\x0b\x06\x01\x03\x0b\x07C#A\x13\x04\r\x05;E\r\x11\x0f7\x05\x03\x15+\x03\x01\r\t\x03;\x1f\x03\x13\x05\x03\x0f#\x03\x03\x03\x07%\x05\x03\x13\x03\x05\x0f\x06%\x03\x13\x05\x03\x07\t\x03CA\x03\x13\x11\x07IG\x03\x19\x05\t\x0b\x05\x03\x0fM\x03\x07\x03\x07O\x05\x03\x01\x03\x0f\x0b\x06S\x03\x01\x07\r\x01\x11\x13\x04\x0f\x03\x13\x06\x03\x01\x05\x01\x00F\x1c\x8d\x1d\x03\x11\x0f\x0b\t\t\x0b!\x1f/!!)#\x1f\x19\x7f\x0f99A9;;m\x19\x85\x8dW\xb3K\x9bM\x9b\x96\x04\x1b+\x1b\x1f\x1f\x15\x1d\x15+\x83\x13\r\r\x1f\x11\x15\x17\x15\x11\x11\x1b\x17\x15\x17\x0f\x11\x15\x11+\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00get_tuple_element_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00real_v1\x00imag_v1\x00negate_v1\x00complex_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00value\x00index\x00sym_name\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00compare_type\x00comparison_direction\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril in_positional_semantics=(<_PositionalSemantics.GLOBAL: 1>,) out_positional_semantics=_PositionalSemantics.GLOBAL keep_unused=False inline=False]\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota[dtype=complex64 shape=(64,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]\x00permutation\x00jit()/jit(main)/transpose[permutation=(1, 0)]\x00jit()/jit(main)/real\x00jit()/jit(main)/imag\x00jit()/jit(main)/neg\x00jit()/jit(main)/complex\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00callee\x00jit()/jit(main)/eigh[lower=True sort_eigenvalues=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00lapack_cheevd\x00", - xla_call_module_version=4, - ), # End paste - - # Pasted from the test output (see back_compat_test.py module docstring) - c128=dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_zheevd'], - serialized_date=datetime.date(2023, 3, 17), - inputs=(), - expected_outputs=(array([[-6.1857700048412056e-01+0.j, 2.4081403770912022e-01+0.j, - 3.5662489253627483e-01+0.j, -6.3034019033669797e-01+0.j, - 1.0043483479985752e-16+0.j, -2.8842036081919542e-02+0.j, - 7.7164692943283169e-25+0.j, -1.8446994643771725e-01+0.j], - [-4.7070881487314609e-01+0.j, 4.7473787464450828e-01+0.j, - -4.8036836210243361e-01+0.j, 4.3802686872516400e-01+0.j, - 1.7961797619639255e-01+0.j, 8.3080980076741355e-03+0.j, - 2.1415294457221759e-01+0.j, -2.2856669794666584e-01+0.j], - [-3.2284062926217072e-01+0.j, -5.4336490915553370e-01+0.j, - 2.2181041859724987e-01+0.j, 2.9947877954402286e-01+0.j, - -3.6491813600134637e-01+0.j, 3.2867679819727436e-01+0.j, - 3.8223299448843473e-01+0.j, -2.7266344945561438e-01+0.j], - [-1.7497244365119527e-01+0.j, -8.9251550609769331e-02+0.j, - -6.3518515114898352e-02+0.j, 1.9162997359209963e-01+0.j, - -2.2087281326110142e-01+0.j, 5.9957027043505008e-02+0.j, - -8.7632498908241274e-01+0.j, -3.1676020096456303e-01+0.j], - [-2.7104258040220017e-02+0.j, -3.3772873786627688e-01+0.j, - 2.5901386593721754e-01+0.j, 1.7032650752287815e-01+0.j, - 6.7521217612940321e-01+0.j, -4.5036136532965476e-01+0.j, - -1.2279030059078447e-02+0.j, -3.6085695247351163e-01+0.j], - [ 1.2076392757075533e-01+0.j, -3.3834734096469249e-01+0.j, - -6.5506827461665529e-01+0.j, -5.0472498521116760e-01+0.j, - 6.9987430903492132e-02+0.j, 1.0595648906599270e-01+0.j, - 8.3443844143082035e-02+0.j, -4.0495370398246017e-01+0.j], - [ 2.6863211318173102e-01+0.j, 2.2958613191407312e-01+0.j, - 6.3952843755683969e-02+0.j, 1.8776775771084192e-02+0.j, - -5.3523731432241317e-01+0.j, -5.9199531677602002e-01+0.j, - 1.7916671834524250e-01+0.j, -4.4905045549140887e-01+0.j], - [ 4.1650029879270667e-01+0.j, 3.6355449432857068e-01+0.j, - 2.9755313100756148e-01+0.j, 1.6826270392616000e-02+0.j, - 1.9621068035557282e-01+0.j, 5.6830030587314817e-01+0.j, - 2.9607517592514260e-02+0.j, -4.9314720700035747e-01+0.j]]), array([-2.4598804776133626e+01, -4.6567755957874661e-14, - -1.9932120610662194e-14, -5.7323356091157378e-15, - -4.5459724251334835e-16, 4.0479851042511616e-14, - 9.2325194924982089e-14, 2.7659880477613365e+02])), - mlir_module_text=r""" -module @jit__lambda_ { - func.func public @main() -> (tensor<8x8xcomplex> {jax.result_info = "[0]"}, tensor<8xf64> {jax.result_info = "[1]"}) { - %0 = stablehlo.iota dim = 0 : tensor<64xcomplex> - %1 = stablehlo.reshape %0 : (tensor<64xcomplex>) -> tensor<8x8xcomplex> - %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<8x8xcomplex>) -> tensor<8x8xcomplex> - %3 = stablehlo.real %2 : (tensor<8x8xcomplex>) -> tensor<8x8xf64> - %4 = stablehlo.imag %2 : (tensor<8x8xcomplex>) -> tensor<8x8xf64> - %5 = stablehlo.negate %4 : tensor<8x8xf64> - %6 = stablehlo.complex %3, %5 : tensor<8x8xcomplex> - %7 = stablehlo.add %1, %6 : tensor<8x8xcomplex> - %8 = stablehlo.constant dense<(2.000000e+00,0.000000e+00)> : tensor> - %9 = stablehlo.broadcast_in_dim %8, dims = [] : (tensor>) -> tensor<8x8xcomplex> - %10 = stablehlo.divide %7, %9 : tensor<8x8xcomplex> - %11 = call @tril(%10) : (tensor<8x8xcomplex>) -> tensor<8x8xcomplex> - %12 = stablehlo.constant dense<1> : tensor - %13 = stablehlo.constant dense<1> : tensor - %14 = stablehlo.constant dense<8> : tensor - %15 = stablehlo.custom_call @lapack_zheevd(%12, %13, %14, %11) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor<8x8xcomplex>) -> tuple>, tensor<8xf64>, tensor, tensor<81xcomplex>, tensor<169xf64>, tensor<43xi32>> - %16 = stablehlo.get_tuple_element %15[0] : (tuple>, tensor<8xf64>, tensor, tensor<81xcomplex>, tensor<169xf64>, tensor<43xi32>>) -> tensor<8x8xcomplex> - %17 = stablehlo.get_tuple_element %15[1] : (tuple>, tensor<8xf64>, tensor, tensor<81xcomplex>, tensor<169xf64>, tensor<43xi32>>) -> tensor<8xf64> - %18 = stablehlo.get_tuple_element %15[2] : (tuple>, tensor<8xf64>, tensor, tensor<81xcomplex>, tensor<169xf64>, tensor<43xi32>>) -> tensor - %19 = stablehlo.get_tuple_element %15[3] : (tuple>, tensor<8xf64>, tensor, tensor<81xcomplex>, tensor<169xf64>, tensor<43xi32>>) -> tensor<81xcomplex> - %20 = stablehlo.get_tuple_element %15[4] : (tuple>, tensor<8xf64>, tensor, tensor<81xcomplex>, tensor<169xf64>, tensor<43xi32>>) -> tensor<169xf64> - %21 = stablehlo.get_tuple_element %15[5] : (tuple>, tensor<8xf64>, tensor, tensor<81xcomplex>, tensor<169xf64>, tensor<43xi32>>) -> tensor<43xi32> - %22 = stablehlo.constant dense<0> : tensor - %23 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor) -> tensor - %24 = stablehlo.compare EQ, %18, %23, SIGNED : (tensor, tensor) -> tensor - %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1x1xi1> - %26 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> - %27 = stablehlo.broadcast_in_dim %26, dims = [] : (tensor>) -> tensor<8x8xcomplex> - %28 = stablehlo.broadcast_in_dim %25, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<8x8xi1> - %29 = stablehlo.select %28, %16, %27 : tensor<8x8xi1>, tensor<8x8xcomplex> - %30 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1xi1> - %31 = stablehlo.constant dense<0x7FF8000000000000> : tensor - %32 = stablehlo.broadcast_in_dim %31, dims = [] : (tensor) -> tensor<8xf64> - %33 = stablehlo.broadcast_in_dim %30, dims = [0] : (tensor<1xi1>) -> tensor<8xi1> - %34 = stablehlo.select %33, %17, %32 : tensor<8xi1>, tensor<8xf64> - return %29, %34 : tensor<8x8xcomplex>, tensor<8xf64> - } - func.func private @tril(%arg0: tensor<8x8xcomplex>) -> tensor<8x8xcomplex> { - %0 = stablehlo.iota dim = 0 : tensor<8x8xi32> - %1 = stablehlo.constant dense<0> : tensor - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<8x8xi32> - %3 = stablehlo.add %0, %2 : tensor<8x8xi32> - %4 = stablehlo.iota dim = 1 : tensor<8x8xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<8x8xi32>, tensor<8x8xi32>) -> tensor<8x8xi1> - %6 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor>) -> tensor<8x8xcomplex> - %8 = stablehlo.select %5, %arg0, %7 : tensor<8x8xi1>, tensor<8x8xcomplex> - return %8 : tensor<8x8xcomplex> - } -} -""", - mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x015\x05\x01\x05\x01\x03\x05\x03%\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!#%')\x03\xc6\x02\x1e\x02?\x01\xa9\x0f\x17\x13\x0b\x17\x0b\x07\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0b\x13\x0f\x0b\x0b\x17\x0f\x13\x13\x0b33\x0b\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x13\x0b\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x13\x13\x1b\x17\x03a\x0f/\x0b\x0b\x0f\x0b\x0bO\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0bOOO\x1f\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17#\x0f\x0f\x0f\x0f\x0f\x0f\x0bOO//\x01\x07\x17\x17\x17\x03?\x17\x0f\x07\x0f\x07\x13\x07\x07\x0b\x17\x17\x07\x17\x13\x17\x17\x13\x0f\x17\x17\x13\x17#\x13\x13\x13\x0f\x17\x13\x13\x13\x02\x96\n\x1d\x83\x03\x17\x13\xf6\x04\x01\x03\x03\x15\xd3\x05+\x17\x13\xf2\x04\x01\x05-\x1f\x1d9\x03\x05/\x051\x053\x055\x057\x059\x05;\x03\x03!\xcf\x05=\x03\x03\x07\xd1\x1d?\x03\x05?\x05A\x17\x13\xea\x04\x01\x1d}\t\x03\x03\x07\xdf\x03\x03\x113\x05C\x03\x0b\x17\xad\x19\xb9\x1b\xbb\x11\xc5\x1d\xc7\x03\x0b\x17\xb1\x19\xcb\x1b\xb1\x11\xb3\x1d\xcd\x05E\x1d=\x03\x05G\x05I\x03\x03!\xd5\x1dE\x03\x05K\x03\x05'\xb5)\xd7\x1dK\x03\x05M\x03\x03\x07\xd9\x1dQ\x03\x05O\x1dU\x03\x05Q\x1dY+\x05S\x1d]+\x05U\x03\x03a\xdb\x05W\x1de\t\x05Y\x1di\t\x05[\x1dm\t\x05]\x1dq\t\x05_\x1du\t\x05a\x1dy\t\x05c\x03\x03\x07\xdd\x05e\x03\x03\x81\xb3\x05g\x05i\x03\x03\x07\xe1\x03\x11\x89\xe3\x8b\xe5\x8d\xe7\x8f\xad\x91\xe9\x93\xeb\x95\xed\x97\xf1\x05k\x05m\x05o\x05q\x05s\x05u\x05w\x05y\x03\x03\x0b\xf3\x03\x03\x0b\xf5\x03\x03\x0b\xf7\x03\x03\x0b\xf9\x03\x03\x0b\xfb\x03\x03\x0b\xfd\x03\x05'\xb5)\xff\x03\x03\x07\x02\x02\x1f/\x01\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1d{\x03\x03\xc9\x1d}\t\x07\x1f1!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00#%\x03\x05\xbd\xc1\r\x03\xaf\xbf\x1d\x7f\r\x03\xaf\xc3\x1d\x81\x1d\x83\x1d\x85\r\x01#'\x1d\x87\x13\r\x01\x1f\x03\t\x00\x00\x00\x00\x1f)\x01\x13\r\x05\x07\x05\x1f\x07!\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1b!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07!\x00\x00\x00\x00\x00\x00\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x08\x00\x00\x00\x0b\x05\x1d\x89\x1d\x8b\x05\x01\x03\t\xa9\xa9\xa9\xb7\x03\x03\xef\x15\x03\x01\r\x01\x03\r\xb7\xab\xa9\xab\xab\xab\x13\x05\x01\x13\x05\x05\x13\x05\t\x13\x05\r\x13\x05\x11\x13\x05\x15\x07\x01\x1f\x07!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f#\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f=\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03\x15\x06\x02\x03\x03\x07\n\x02\x03\x03\x15\x0e\x02)\x05!!\x11)\x01\x05\x1b)\x01\x11\x0b)\x03!\t\x1d\x01\x03\t)\x05!!\x05)\x05!!\t\x13)\x05!!\x0f)\x03\t\r)\x03\x8a\x02\x11)\x03J\x05\t)\x03\xad\x05)\x01\t\x11\x01\x05\x01\x0b\x11\x03\x01\x03\x01)\x03\x01\r)\x03\x02\x02\x11/\r\x01\x0b\x03\x1d\x1f!)\x03\x01\x17)\x03\t\x17)\x03\x05\x17)\x01\x0f)\x05\x05\x05\x0f)\x03\x05\x0f)\x03!\x0f)\x03\x05\r\x04\xda\x05\x05\x01\x11\r1\x07\x03\x01\t\r\x11\r5\x05\x03G\x91\t\x03W\x1f\x03+\x15\x06[\x03\x01\x03\x01\x17\x07c_\x03\x01\x03\x03\x19\x06g\x03\x15\x03\x05\x1b\x06k\x03\x15\x03\x05\x1d\x06o\x03\x15\x03\t\x1f\x06s\x03\x01\x05\x07\x0b\x0f\x06w\x03\x01\x05\x03\r\x05\x03\r{\x03\x07\x03\x07-\x05\x03\x01\x03\x11!\x06-\x03\x01\x05\x0f\x13#\x07\x0f\x7f\x03\x01\x03\x15\x05\x03\x01/\x03\x03\x05\x03\x01/\x03\x03\x05\x03\x01\x85\x03\x03%\x07\x01\x87\x03-\t\x19\x1b\x1d\x17\x07\x07\x01\x99\x03\x01\x03\x1f\x07\x07\x01\x9b\x03\x0b\x03\x1f\x07\x07\x01\x9d\x03\x03\x03\x1f\x07\x07\x01\x9f\x03\x1d\x03\x1f\x07\x07\x01\xa1\x03\x1f\x03\x1f\x07\x07\x01\xa3\x03!\x03\x1f\x05\x03\x01#\x03\x03\x03\x07\x01\x05\x03\x03\x03-\x11\x07\x01\xa5\x035\x05%/\x03\x07\x01\x05\x037\x031\x05\x03\x01\xa7\x03\x07\x03\x07\x01\x05\x03\x01\x035\x03\x07\x01\x12\x02\x03\x19\x033\x0b\x06\x01\x03\x01\x079!7\x03\x07\x01\x05\x039\x031\x05\x03\x01\x16\x02\x03#\x03\x07\x01\x05\x03\x0b\x03?\x03\x07\x01\x1a\x02\x03;\x03=\x0b\x06\x01\x03\x0b\x07C#A\x13\x04\r\x05;E\r\x11\x0f7\x05\x03\x15+\x03\x01\r\t\x03;\x1f\x03\x13\x05\x03\x0f#\x03\x03\x03\x07%\x05\x03\x13\x03\x05\x0f\x06%\x03\x13\x05\x03\x07\t\x03CA\x03\x13\x11\x07IG\x03\x19\x05\t\x0b\x05\x03\x0fM\x03\x07\x03\x07O\x05\x03\x01\x03\x0f\x0b\x06S\x03\x01\x07\r\x01\x11\x13\x04\x0f\x03\x13\x06\x03\x01\x05\x01\x00J\x1c\x8d\x1d\x03\x11\x0f\x0b\t\t\x0b!\x1f/!!)#\x1f\x19\x7f\x0f99A9;;m\x19\x85\x8fW\xb3K\x9bM\x9b\x96\x04\x1b+\x1b\x1f\x1f\x15\x1d\x15+\x83\x13\r\r\x1f\x11\x15\x17\x15\x11\x11\x1b\x17\x15\x17\x0f\x11\x15\x11+\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00get_tuple_element_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00real_v1\x00imag_v1\x00negate_v1\x00complex_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00value\x00index\x00sym_name\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00compare_type\x00comparison_direction\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril in_positional_semantics=(<_PositionalSemantics.GLOBAL: 1>,) out_positional_semantics=_PositionalSemantics.GLOBAL keep_unused=False inline=False]\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota[dtype=complex128 shape=(64,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]\x00permutation\x00jit()/jit(main)/transpose[permutation=(1, 0)]\x00jit()/jit(main)/real\x00jit()/jit(main)/imag\x00jit()/jit(main)/neg\x00jit()/jit(main)/complex\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00callee\x00jit()/jit(main)/eigh[lower=True sort_eigenvalues=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00lapack_zheevd\x00", - xla_call_module_version=4, - ), # End paste -) +array = np.array +float32 = np.float32 +complex64 = np.complex64 data_2024_08_19 = {} - # Pasted from the test output (see export_back_compat_test_util.py module docstring) data_2024_08_19["c128"] = dict( testdata_version=1, diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_hessenberg_lapack_gehrd.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_hessenberg_lapack_gehrd.py index 204af8f55396..3fe090e8b270 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_hessenberg_lapack_gehrd.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_hessenberg_lapack_gehrd.py @@ -15,277 +15,14 @@ # ruff: noqa import datetime -from numpy import array, float32, complex64 +import numpy as np -data_2024_08_30 = {} - - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_30["c128"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_zgehrd'], - serialized_date=datetime.date(2024, 8, 30), - inputs=(), - expected_outputs=(array([[[ 0.7137638961069523 +2.4533812415320035e+00j, - -0.3272236912989258 -3.2003874808591863e+00j, - -3.065817294924296 +1.6978219378771007e+00j, - -3.3971558164664 +2.6931967836060400e-01j], - [ 6.346214936866542 +0.0000000000000000e+00j, - 2.083218259144673 -1.2191838498692813e+00j, - 1.9552582313969427 -3.3216313521481879e+00j, - 2.7451664155727293 +2.5460553490974451e+00j], - [-0.16133388943502391 +3.6906265775683444e-01j, - -4.698636849217318 +0.0000000000000000e+00j, - 2.5396292124414077 -3.3038474840573420e+00j, - 2.5410992366186456 +4.1958389320867528e-01j], - [ 0.47396123039280513 +3.9524384493417053e-03j, - 0.058880409351504966-7.8934332132630333e-02j, - 0.9469634796174572 +0.0000000000000000e+00j, - -3.130422531669044 -8.8070401977461810e-01j]], - - [[-6.7065483048969465 -4.1981401054281309e-01j, - -0.21813268822330256 -3.8602920478381799e+00j, - -0.8248337528620167 -2.9073223456990824e+00j, - -3.597231249446879 +2.7626541679004930e+00j], - [-6.812126638479044 +0.0000000000000000e+00j, - -0.20651586628458585 -1.0948249928988512e+00j, - -1.6675586608354327 +4.2553627621795744e+00j, - -2.410110723267707 +3.6065122124698634e-01j], - [ 0.038235817369200516-3.7823713529009173e-01j, - -8.508141062606947 +0.0000000000000000e+00j, - 4.260708077719245 -6.8052584397204630e-02j, - 5.345997177836541 -1.1955161503390279e+00j], - [-0.18541509608158574 -1.2016051097247168e-01j, - -0.02698777746917469 -4.4847463691672246e-01j, - 6.149305574585603 +0.0000000000000000e+00j, - -2.483131585236393 +2.8524912589603817e+00j]]]), array([[1.2286220194325557+0.5121060656500841j , - 1.9529937219183482-0.23299856112387676j, - 1.5940499664125072-0.8044281430962614j ], - [1.6682114302246909-0.11372755955977935j, - 1.4075913155446236-0.6008708461880701j , - 1.5086928152468893-0.8609480935086589j ]])), - mlir_module_text=r""" -module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<2x4x4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x3xcomplex> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { - %cst = stablehlo.constant dense<[[[(0.71376389610695234,2.4533812415320035), (-1.0686093138739379,-1.885041510645256), (3.2629529488994033,-0.87160041258342402), (2.4332168907311504,3.4960248990882183)], [(-1.450884474619478,-3.249935163088522), (0.53920035905924757,-5.0056840575116066), (0.13157186736298554,2.5015499854549939), (-1.2451270607408882,0.24345856951924827)], [(2.457366083193417,-2.3532935513245605), (-0.37595429769485644,1.5729223427874068), (3.5877693970448052,-0.30904304334212157), (-1.685615117470264,2.6148811836470265)], [(-3.6826776618664727,-1.5711608241015744), (-0.12407609317204518,-4.7137561145212281), (1.3298255603911306,-1.6739172003954141), (-2.6345448161870149,-0.089008252847513236)]], [[(-6.7065483048969465,-0.41981401054281309), (-2.1586544949255457,0.34815132010709054), (-5.1462488701272413,3.440817752555807), (1.0301804086076078,-0.6994760434270566)], [(4.551940883969797,-0.77472653800638502), (4.4485186470774796,-0.0024458890677252756), (0.66610302132250898,2.5976571401862039), (-5.0693248202533674,-5.7405538897950699)], [(0.14148406399087146,-4.3279346473525058), (-2.353557113110897,2.0880432773400326), (-3.2524452107293618,-0.42398740171508631), (3.7200566224095519,-0.56951559566037058)], [(-2.2001612082232613,-1.2218661647417151), (0.72437359623190833,8.6381970213061301), (0.72314820631775734,0.058458198280771749), (0.37498718985014962,2.1160469724471378)]]]> : tensor<2x4x4xcomplex> loc(#loc) - %c = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %c_1 = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_2 = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_3 = stablehlo.constant dense<2> : tensor loc(#loc2) - %c_4 = stablehlo.constant dense<4288> : tensor loc(#loc2) - %0:4 = stablehlo.custom_call @lapack_zgehrd(%c, %c_0, %c_1, %c_2, %c_3, %c_4, %cst) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xcomplex>) -> (tensor<2x4x4xcomplex>, tensor<2x3xcomplex>, tensor<2xi32>, tensor<4288xcomplex>) loc(#loc2) - %c_5 = stablehlo.constant dense<0> : tensor loc(#loc2) - %1 = stablehlo.broadcast_in_dim %c_5, dims = [] : (tensor) -> tensor<2xi32> loc(#loc2) - %2 = stablehlo.compare EQ, %0#2, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc2) - %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %cst_6 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc2) - %4 = stablehlo.broadcast_in_dim %cst_6, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc2) - %5 = stablehlo.broadcast_in_dim %3, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %6 = stablehlo.select %5, %0#0, %4 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc2) - %7 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc2) - %cst_7 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc2) - %8 = stablehlo.broadcast_in_dim %cst_7, dims = [] : (tensor>) -> tensor<2x3xcomplex> loc(#loc2) - %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc2) - %10 = stablehlo.select %9, %0#1, %8 : tensor<2x3xi1>, tensor<2x3xcomplex> loc(#loc2) - return %6, %10 : tensor<2x4x4xcomplex>, tensor<2x3xcomplex> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc = loc(unknown) -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":697:13) -#loc2 = loc("jit(func)/jit(main)/hessenberg"(#loc1)) -""", - mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xef\xa19\x01W\x0f\x0b\x07\x0b\x13\x13\x0f\x0b\x13\x13+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x13\x03K\x0f\x0b\x0b\x0b\x0bo/\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b&\x10\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b\'\x0f\x17\x1bO\x1f\x0f\x0b\x0b/OoO\x01\x05\x0b\x0f\x035\x0f\x1b\x07\x0b\x17\x07\x07\x0f\x07\x13\x17\x07\x17\x13\x13\x13\x13\x13\x13\x1b\x13\x1b\x13\x17\x17\x13\x02\xce\x0f\x1d-/\x05\x15\x1f\x05\x17\x03\x03\x03w\x03\x03\x07\x93\x11\x03\x05\x05\x19\x03\x03\x07\x99\x03\x03\x03\x9b\x03\t\x17\x19\x1b\r\x1d\r\x0f\x1f\x05\x1b\x11\x01\x00\x05\x1d\x05\x1f\x05!\x03\x0b#Y%e\'g\x0fq)s\x05#\x05%\x05\'\x05)\x03\x03\x03u\x05+\x171\xe6\n\x1b\x05-\x03\x03\x03y\x03\x03\x03{\x03\x03\x03}\x03\x11;\x7f=\x81?\x83AYC\x85E\x87G\x89I\x8d\x05/\x051\x053\x055\x057\x059\x05;\x05=\x03\x03\x03\x91\x03\x05O\x95Q\x97\x05?\x05A\x03\x03\x07\x9d\x03\x03\x07\x9f\x1f\x1f\x01\x03\x01\x1dC\x1dE\x1dG\x1f!1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f%\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x19\x03\x05im\r\x05[k]_\x1dI\r\x05[o]_\x1dK\x1dM\x1dO\x1f\x07\x02\x08p\t\xdba\'\xd7\xe6?\xa8\xff\'X\x86\xa0\x03@\x0c\xa2t\x14\x06\x19\xf1\xbfT.}I!)\xfe\xbf\x0fG_\x13\x87\x1a\n@\xae:g\x8c&\xe4\xeb\xbf\xeb\x1e\xcej:w\x03@N\xaf\xfc\xe6\xdb\xf7\x0b@\x9f<\x8c\xa3\xd26\xf7\xbf^\xaf\xbc\x01\xde\xff\t\xc0b\xd4\x84\x1c!A\xe1?\xd6{\xa4\n\xd2\x05\x14\xc0\xf0\xe6\xb2\xd1X\xd7\xc0?2\xb5\x86\xa3,\x03\x04@\x91\xf2SZ\n\xec\xf3\xbf\x04\x10\x02\x81\xa6)\xcf?8\xec\x8c\x8c\xaf\xa8\x03@\r\x9d\xc6\x91\x8b\xd3\x02\xc0\xb0\xf6X\x9d\xa2\x0f\xd8\xbf\xbd\xb6V\x9e\xb0*\xf9?7-\x0fq\xc0\xb3\x0c@{|\ry\\\xc7\xd3\xbf\x04\xd9\xb2\x8eG\xf8\xfa\xbf\x9b\x84u\xd3F\xeb\x04@\xf4h\xbb\xb4\x1fv\r\xc0\xdc\\D\x88y#\xf9\xbf\x9a\xaecjs\xc3\xbf\xbf<\xc1\x04\xe2\xe2\xda\x12\xc0\x89<\xb4*\xf7F\xf5?\x1b\x90\xfef]\xc8\xfa\xbf\xdc\xf4\x8a;\x8c\x13\x05\xc0\xf8\xdd\r\xaf>\xc9\xb6\xbfvN\x1af\x81\xd3\x1a\xc0Z\xc6k\x95;\xde\xda\xbf\x87\x8c\xd8\xa5\xecD\x01\xc0\xdd\xd3zy\x1cH\xd6?\x04\x18\x89C\xc2\x95\x14\xc0\x8c\xc95u\xcb\x86\x0b@\x881\xbfs\x9e{\xf0?\x92Y[\x95\x1bb\xe6\xbf\x06\xe7\xb7\xfd/5\x12@L\x95\x02O\x8f\xca\xe8\xbf2`\xe3xH\xcb\x11@>\xda\xc6\xb1f\td\xbfZ\x1a\x8bH\xb7P\xe5?\xa8\x90zw\x00\xc8\x04@<(\xef\x15\xfdF\x14\xc0\xb4aF\xc2S\xf6\x16\xc0\xc1{\xdfY&\x1c\xc2?\xcfj\xa6\x19\xceO\x11\xc0\xc4\xa2p\xc0\x15\xd4\x02\xc0\xfcv\xa6\x08P\xb4\x00@^\xea\xa0\xfe\x01\x05\n\xc0^\x11\x12\x0e\x9c"\xdb\xbfR#\xe4\x0b\xad\xc2\r@F\x8b=\xc5x9\xe2\xbfZ\xf9\x99\x1e\xee\x99\x01\xc0My\x1a\x89\xc3\x8c\xf3\xbf\xd1\xdc<\x89\x11.\xe7?2\xd4\x8d\xc2\xc1F!@mw\t\xb5\x07$\xe7?G\x16\x99\xa3;\xee\xad?M\xd24E\xca\xff\xd7?\xa2\xae\xfb\x08\xaa\xed\x00@\x1f\x05\t\x04\x00\x00\x00\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x02\x00\x00\x00\x1f\x05\t\xc0\x10\x00\x00\x0b\x05\x1dQ\x1dS\x05\x01\x03\x0fWWWWWWa\x03\x03\x8b\x15\x03\x01\x19\x01\x03\ta\x8fcc\x1f#!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x00\x00\x00\x00\x1f\'\x01\t\x07\x07\x01\x1f-\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x13!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f11\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f7!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x01\x15)\x07\t\x11\x11\x0b\x01\x03\x1b)\x05\t\r\x0b\x13\x1d)\x01\x0b\x1b)\x03\t\x15\x11\x01\x05\x07\r\x0b)\x03\x02\x86\x0b)\x03\x01\x0f)\x03\r\x0f)\x03\t\x0f)\x03\x05\x0f)\x03\x01\x11)\x03\t\t)\x07\t\x05\x05\t)\x03\x05\x11)\x07\t\x11\x11\t)\x03\r\x11)\x05\t\x05\t)\x05\t\r\t)\x03\t\x11\x04\xde\x02\x05\x01\x11\x05\x15\x07\x03\x01\x05\t\x11\x05!\x07\x031Y\x03\x03\x05+\x03\x07\x03\x03\x01\t\x03\x05\x03\x03\x013\x03\x05\x03\x03\x01\t\x03\x05\x03\x03\x01\t\x03\x05\x03\x03\x015\x03\x05\x03\x03\x017\x03\x05\x0b\x07\x019\t\x07\r\x17\x1d\x0f\x03\x05\x07\t\x0b\r\x01\x03\x03\x01K\x03\x05\x05\x07\x01\x0b\x03\x17\x03\x17\r\x07\x01M\x03)\x05\x13\x19\x05\x07\x01\x11\x03+\x03\x1b\x03\x03\x01\x13\x03\x13\x05\x07\x01\x0b\x03\x07\x03\x1f\x05\x07\x01S\x03/\x03\x1d\x07\x06\x01\x03\x07\x07#\x0f!\x05\x07\x01\x11\x033\x03\x1b\x03\x03\x01\x13\x03\x13\x05\x07\x01\x0b\x03\r\x03)\x05\x07\x01U\x035\x03\'\x07\x06\x01\x03\r\x07-\x11+\x0f\x04\x05\x05%/\x06\x03\x01\x05\x01\x00\xf2\tU\x1d\x03\x0f\x0b\t\t\x11#!+\x1b\x1f/!!)#\x1f\x19i?\x1f\x15\x1d\x15\x13%)9\x13+\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/hessenberg\x00third_party/py/jax/tests/export_back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00mhlo.layout_mode\x00default\x00[0]\x00[1]\x00main\x00public\x00\x00lapack_zgehrd\x00', - xla_call_module_version=9, - nr_devices=1, -) # End paste - - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_30["c64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_cgehrd'], - serialized_date=datetime.date(2024, 8, 30), - inputs=(), - expected_outputs=(array([[[ 5.2023945 -0.878671j , -2.8841915 -0.47488597j , - 1.3024182 +0.6651789j , 4.9291854 -1.9147056j ], - [ 6.3457894 +0.j , 1.6869383 -4.6557646j , - 0.88955224-1.7617276j , 2.9149916 +4.342665j ], - [-0.2465725 -0.5776757j , -5.3007755 +0.j , - -0.9786545 -0.0633831j , -1.3690261 -1.5921416j ], - [ 0.35462287+0.35993803j , -0.38403815-0.46558398j , - 2.8020499 +0.j , 0.5636822 -6.218306j ]], - - [[ 1.0687767 -3.88293j , -4.0144 -2.5885587j , - 5.3900986 -0.8850739j , 2.079677 +3.5515747j ], - [ 7.5675693 +0.j , 0.5971966 -3.6699948j , - 2.246994 -1.0858283j , -0.8870981 -0.022960603j], - [-0.2183232 +0.10552277j , 5.860886 +0.j , - -5.091036 +6.2841997j , 5.008773 +1.8765848j ], - [ 0.1378771 +0.427895j , 0.63263524-0.3470098j , - 6.4528017 +0.j , -4.233642 -0.84165764j ]]], - dtype=complex64), array([[1.0933675-0.3605358j , 1.1987956+0.5659744j , - 1.9999101-0.013409062j], - [1.4504763-0.44363326j , 1.3110259-0.07426627j , - 1.227255 +0.97383535j ]], dtype=complex64)), - mlir_module_text=r""" -module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<2x4x4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x3xcomplex> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { - %cst = stablehlo.constant dense<[[[(5.20239449,-0.87867099), (-0.211780012,-0.923053801), (-5.25181627,1.90887547), (-1.61342144,-1.98000157)], [(-5.924900e-01,2.28788424), (-1.74142945,-3.25563216), (3.08765078,-3.25260139), (-3.35189271,-0.571629047)], [(3.032444,3.44394636), (1.22205484,0.808871626), (2.58686161,-7.47011566), (1.9139297,-2.57945323)], [(-3.28396916,-1.68601465), (2.62759161,-0.953538239), (-2.78763294,-0.0429570749), (0.426534384,-0.211706176)]], [[(1.06877673,-3.882930e+00), (-0.0192247611,5.96663713), (1.15329504,-5.0599103), (-1.76508892,-1.98541296)], [(-3.40901089,3.35722542), (-6.13531398,2.55851483), (-4.8095789,0.164206699), (-0.247624069,-3.13545418)], [(2.04217815,-1.89123917), (-1.18974173,-1.69466627), (-2.28673625,-0.487834573), (3.01541853,-1.85637176)], [(-2.9499588,-4.23393869), (8.44624137,5.57274485), (-1.09048736,2.4864223), (-0.305431545,-0.298133373)]]]> : tensor<2x4x4xcomplex> loc(#loc) - %c = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %c_1 = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_2 = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_3 = stablehlo.constant dense<2> : tensor loc(#loc2) - %c_4 = stablehlo.constant dense<4288> : tensor loc(#loc2) - %0:4 = stablehlo.custom_call @lapack_cgehrd(%c, %c_0, %c_1, %c_2, %c_3, %c_4, %cst) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xcomplex>) -> (tensor<2x4x4xcomplex>, tensor<2x3xcomplex>, tensor<2xi32>, tensor<4288xcomplex>) loc(#loc2) - %c_5 = stablehlo.constant dense<0> : tensor loc(#loc2) - %1 = stablehlo.broadcast_in_dim %c_5, dims = [] : (tensor) -> tensor<2xi32> loc(#loc2) - %2 = stablehlo.compare EQ, %0#2, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc2) - %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %cst_6 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc2) - %4 = stablehlo.broadcast_in_dim %cst_6, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc2) - %5 = stablehlo.broadcast_in_dim %3, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %6 = stablehlo.select %5, %0#0, %4 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc2) - %7 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc2) - %cst_7 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc2) - %8 = stablehlo.broadcast_in_dim %cst_7, dims = [] : (tensor>) -> tensor<2x3xcomplex> loc(#loc2) - %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc2) - %10 = stablehlo.select %9, %0#1, %8 : tensor<2x3xi1>, tensor<2x3xcomplex> loc(#loc2) - return %6, %10 : tensor<2x4x4xcomplex>, tensor<2x3xcomplex> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc = loc(unknown) -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":697:13) -#loc2 = loc("jit(func)/jit(main)/hessenberg"(#loc1)) -""", - mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xef\xa19\x01W\x0f\x0b\x07\x0b\x13\x13\x0f\x0b\x13\x13+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x13\x03K\x0f\x0b\x0b\x0b\x0bo/\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b&\x08\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b\'\x0f\x17\x1bO\x1f\x0f\x0b\x0b//oO\x01\x05\x0b\x0f\x035\x0f\x1b\x07\x0b\x17\x07\x07\x0f\x07\x13\x17\x07\x17\x13\x13\x13\x13\x13\x13\x1b\x13\x1b\x13\x17\x17\x13\x02\xae\x0b\x1d-/\x05\x15\x1f\x05\x17\x03\x03\x03w\x03\x03\x07\x93\x11\x03\x05\x05\x19\x03\x03\x07\x99\x03\x03\x03\x9b\x03\t\x17\x19\x1b\r\x1d\r\x0f\x1f\x05\x1b\x11\x01\x00\x05\x1d\x05\x1f\x05!\x03\x0b#Y%e\'g\x0fq)s\x05#\x05%\x05\'\x05)\x03\x03\x03u\x05+\x171\xe6\n\x1b\x05-\x03\x03\x03y\x03\x03\x03{\x03\x03\x03}\x03\x11;\x7f=\x81?\x83AYC\x85E\x87G\x89I\x8d\x05/\x051\x053\x055\x057\x059\x05;\x05=\x03\x03\x03\x91\x03\x05O\x95Q\x97\x05?\x05A\x03\x03\x07\x9d\x03\x03\x07\x9f\x1f\x1f\x01\x03\x01\x1dC\x1dE\x1dG\x1f!1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f%\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x19\x03\x05im\r\x05[k]_\x1dI\r\x05[o]_\x1dK\x1dM\x1dO\x1f\x07\x02\x04\x04z\xa6@\x95\xf0`\xbf\xdc\xdcX\xbeAMl\xbf\xe1\x0e\xa8\xc0\x08V\xf4?\x98\x84\xce\xbf\xb1p\xfd\xbfm\xad\x17\xbf\xb2l\x12@)\xe7\xde\xbfG\\P\xc0\x12\x9cE@\x9f*P\xc0i\x85V\xc0HV\x12\xbf\x90\x13B@\x9ei\\@Kl\x9c?6\x12O?$\x8f%@0\x0b\xef\xc0\xa6\xfb\xf4?\xc3\x15%\xc0\x8d,R\xc0T\xcf\xd7\xbfv*(@\x15\x1bt\xbf\x94h2\xc0\xc2\xf3/\xbd\xb7b\xda>\x81\xc9X\xbe\xad\xcd\x88?\xed\x81x\xc0?}\x9d\xbc\xb1\xee\xbe@,\x9f\x93?\xc9\xea\xa1\xc0o\xee\xe1\xbf\x03"\xfe\xbf<-Z\xc0\xc8\xdcV@~T\xc4\xc0\xb5\xbe#@\x12\xe8\x99\xc0\xcd%(>*\x91}\xbeH\xabH\xc0\x0c\xb3\x02@ \x14\xf2\xbfuI\x98\xbf\xd3\xea\xd8\xbf\xe3Y\x12\xc0t\xc5\xf9\xbe\x9e\xfc@@\x97\x9d\xed\xbf \xcc<\xc0m|\x87\xc0\xce#\x07A\xedS\xb2@\x17\x95\x8b\xbf\x8b!\x1f@\x86a\x9c\xbe\xf0\xa4\x98\xbe\x1f\x05\t\x04\x00\x00\x00\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x02\x00\x00\x00\x1f\x05\t\xc0\x10\x00\x00\x0b\x05\x1dQ\x1dS\x05\x01\x03\x0fWWWWWWa\x03\x03\x8b\x15\x03\x01\x19\x01\x03\ta\x8fcc\x1f#!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x00\x00\x00\x00\x1f\'\x01\t\x07\x07\x01\x1f-\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x13\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f11\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f7!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x01\x15)\x07\t\x11\x11\x0b\x01\x03\x1b)\x05\t\r\x0b\x13\x1d)\x01\x0b\x1b)\x03\t\x15\x11\x01\x05\x07\r\t)\x03\x02\x86\x0b)\x03\x01\x0f)\x03\r\x0f)\x03\t\x0f)\x03\x05\x0f)\x03\x01\x11)\x03\t\t)\x07\t\x05\x05\t)\x03\x05\x11)\x07\t\x11\x11\t)\x03\r\x11)\x05\t\x05\t)\x05\t\r\t)\x03\t\x11\x04\xde\x02\x05\x01\x11\x05\x15\x07\x03\x01\x05\t\x11\x05!\x07\x031Y\x03\x03\x05+\x03\x07\x03\x03\x01\t\x03\x05\x03\x03\x013\x03\x05\x03\x03\x01\t\x03\x05\x03\x03\x01\t\x03\x05\x03\x03\x015\x03\x05\x03\x03\x017\x03\x05\x0b\x07\x019\t\x07\r\x17\x1d\x0f\x03\x05\x07\t\x0b\r\x01\x03\x03\x01K\x03\x05\x05\x07\x01\x0b\x03\x17\x03\x17\r\x07\x01M\x03)\x05\x13\x19\x05\x07\x01\x11\x03+\x03\x1b\x03\x03\x01\x13\x03\x13\x05\x07\x01\x0b\x03\x07\x03\x1f\x05\x07\x01S\x03/\x03\x1d\x07\x06\x01\x03\x07\x07#\x0f!\x05\x07\x01\x11\x033\x03\x1b\x03\x03\x01\x13\x03\x13\x05\x07\x01\x0b\x03\r\x03)\x05\x07\x01U\x035\x03\'\x07\x06\x01\x03\r\x07-\x11+\x0f\x04\x05\x05%/\x06\x03\x01\x05\x01\x00\xf2\tU\x1d\x03\x0f\x0b\t\t\x11#!+\x1b\x1f/!!)#\x1f\x19i?\x1f\x15\x1d\x15\x13%)9\x13+\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/hessenberg\x00third_party/py/jax/tests/export_back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00mhlo.layout_mode\x00default\x00[0]\x00[1]\x00main\x00public\x00\x00lapack_cgehrd\x00', - xla_call_module_version=9, - nr_devices=1, -) # End paste - - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_30["f32"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_sgehrd'], - serialized_date=datetime.date(2024, 8, 30), - inputs=(), - expected_outputs=(array([[[-3.5237675 , -6.1161256 , -0.549011 , -4.7706876 ], - [ 5.8401766 , 3.424213 , 0.3059119 , 2.3492367 ], - [ 0.63135445 , 2.7238827 , -0.106214404, -0.82470125 ], - [-0.27146497 , 0.09917235 , 0.2545611 , -0.5113605 ]], - - [[ 4.297168 , -1.8758869 , 0.33528137 , 5.867136 ], - [-7.129698 , -3.3118155 , -1.3492918 , -2.8959117 ], - [-0.7266852 , -3.506432 , 4.77164 , -4.0780373 ], - [ 0.14084078 , 0.3389384 , 2.3910007 , -0.79807365 ]]], - dtype=float32), array([[1.3584172, 1.9805213, 0. ], - [1.2920669, 1.7939165, 0. ]], dtype=float32)), - mlir_module_text=r""" -module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<2x4x4xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x3xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { - %cst = stablehlo.constant dense<[[[-3.52376747, -0.758410036, 4.85795927, -6.0243597], [-2.09321976, -1.27957773, -0.956288218, -1.11928439], [-5.00878525, 0.51314038, 3.53047514, -2.91282868], [2.15363932, 0.635739565, -0.21264787, 0.555740714]], [[4.29716778, -3.86209464, -2.39021468, 4.17441607], [2.08234859, -1.03958249, 4.09025383, 5.22586823], [-6.69425774, 3.43749118, -0.691099107, 1.59547663], [1.29743183, -2.00156212, 3.08750296, 2.39243269]]]> : tensor<2x4x4xf32> loc(#loc) - %c = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %c_1 = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_2 = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_3 = stablehlo.constant dense<2> : tensor loc(#loc2) - %c_4 = stablehlo.constant dense<4288> : tensor loc(#loc2) - %0:4 = stablehlo.custom_call @lapack_sgehrd(%c, %c_0, %c_1, %c_2, %c_3, %c_4, %cst) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xf32>) -> (tensor<2x4x4xf32>, tensor<2x3xf32>, tensor<2xi32>, tensor<4288xf32>) loc(#loc2) - %c_5 = stablehlo.constant dense<0> : tensor loc(#loc2) - %1 = stablehlo.broadcast_in_dim %c_5, dims = [] : (tensor) -> tensor<2xi32> loc(#loc2) - %2 = stablehlo.compare EQ, %0#2, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc2) - %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %cst_6 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc2) - %4 = stablehlo.broadcast_in_dim %cst_6, dims = [] : (tensor) -> tensor<2x4x4xf32> loc(#loc2) - %5 = stablehlo.broadcast_in_dim %3, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %6 = stablehlo.select %5, %0#0, %4 : tensor<2x4x4xi1>, tensor<2x4x4xf32> loc(#loc2) - %7 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc2) - %cst_7 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc2) - %8 = stablehlo.broadcast_in_dim %cst_7, dims = [] : (tensor) -> tensor<2x3xf32> loc(#loc2) - %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc2) - %10 = stablehlo.select %9, %0#1, %8 : tensor<2x3xi1>, tensor<2x3xf32> loc(#loc2) - return %6, %10 : tensor<2x4x4xf32>, tensor<2x3xf32> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc = loc(unknown) -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":697:13) -#loc2 = loc("jit(func)/jit(main)/hessenberg"(#loc1)) -""", - mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xed\xa17\x01W\x0f\x0b\x07\x0b\x13\x13\x0f\x0b\x13\x13+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x13\x03K\x0f\x0b\x0b\x0b\x0bo/\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b&\x04\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b\'\x0f\x17\x1bO\x1f\x0f\x0b\x0b/\x1foO\x01\x05\x0b\x0f\x033\x0f\x1b\x07\x07\x17\x07\x07\x0f\x07\x13\x17\x17\x13\x13\x13\x13\x13\x13\x1b\x13\x1b\x13\x17\x17\x13\x02\x96\t\x1d-/\x05\x15\x1f\x05\x17\x03\x03\x03w\x03\x03\x07\x93\x11\x03\x05\x05\x19\x03\x03\x07\x99\x03\x03\x03\x9b\x03\t\x17\x19\x1b\r\x1d\r\x0f\x1f\x05\x1b\x11\x01\x00\x05\x1d\x05\x1f\x05!\x03\x0b#Y%e\'g\x0fq)s\x05#\x05%\x05\'\x05)\x03\x03\x03u\x05+\x171\xe6\n\x1b\x05-\x03\x03\x03y\x03\x03\x03{\x03\x03\x03}\x03\x11;\x7f=\x81?\x83AYC\x85E\x87G\x89I\x8d\x05/\x051\x053\x055\x057\x059\x05;\x05=\x03\x03\x03\x91\x03\x05O\x95Q\x97\x05?\x05A\x03\x03\x07\x9d\x03\x03\x07\x9f\x1f\x1d\x01\x03\x01\x1dC\x1dE\x1dG\x1f\x1f1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f#\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x19\x03\x05im\r\x05[k]_\x1dI\r\x05[o]_\x1dK\x1dM\x1dO\x1f\x07\x02\x02h\x85a\xc0)\'B\xbfgt\x9b@\x8e\xc7\xc0\xc0P\xf7\x05\xc04\xc9\xa3\xbfN\xcft\xbf\xb6D\x8f\xbf\xf8G\xa0\xc0+]\x03?N\xf3a@\xc9k:\xc0:\xd5\t@\xd4\xbf"?]\xc0Y\xbe\x06E\x0e?f\x82\x89@\x8f,w\xc0G\xf9\x18\xc0\xd1\x94\x85@3E\x05@\n\x11\x85\xbf\\\xe3\x82@P:\xa7@\\7\xd6\xc0\xdb\xff[@\xdf\xeb0\xbf\x948\xcc??\x12\xa6?\x98\x19\x00\xc0\xa6\x99E@\x9e\x1d\x19@\x1f\x05\t\x04\x00\x00\x00\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x02\x00\x00\x00\x1f\x05\t\xc0\x10\x00\x00\x0b\x05\x1dQ\x1dS\x05\x01\x03\x0fWWWWWWa\x03\x03\x8b\x15\x03\x01\x19\x01\x03\ta\x8fcc\x1f!!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x00\x00\x00\x00\x1f%\x01\t\x07\x07\x01\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x13\t\x00\x00\xc0\x7f\x1f/1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f5!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x01\x15)\x07\t\x11\x11\x0b\x01\t)\x05\t\r\x0b\x13\x1d)\x01\x0b\x1b)\x03\t\x15\x11\x01\x05\x07\r)\x03\x02\x86\x0b)\x03\x01\x0f)\x03\r\x0f)\x03\t\x0f)\x03\x05\x0f)\x03\x01\x11)\x03\t\t)\x07\t\x05\x05\t)\x03\x05\x11)\x07\t\x11\x11\t)\x03\r\x11)\x05\t\x05\t)\x05\t\r\t)\x03\t\x11\x04\xde\x02\x05\x01\x11\x05\x15\x07\x03\x01\x05\t\x11\x05!\x07\x031Y\x03\x03\x05+\x03\x07\x03\x03\x01\t\x03\x05\x03\x03\x013\x03\x05\x03\x03\x01\t\x03\x05\x03\x03\x01\t\x03\x05\x03\x03\x015\x03\x05\x03\x03\x017\x03\x05\x0b\x07\x019\t\x07\r\x17\x1b\x0f\x03\x05\x07\t\x0b\r\x01\x03\x03\x01K\x03\x05\x05\x07\x01\x0b\x03\x17\x03\x17\r\x07\x01M\x03\'\x05\x13\x19\x05\x07\x01\x11\x03)\x03\x1b\x03\x03\x01\x13\x03\x13\x05\x07\x01\x0b\x03\x07\x03\x1f\x05\x07\x01S\x03-\x03\x1d\x07\x06\x01\x03\x07\x07#\x0f!\x05\x07\x01\x11\x031\x03\x1b\x03\x03\x01\x13\x03\x13\x05\x07\x01\x0b\x03\r\x03)\x05\x07\x01U\x033\x03\'\x07\x06\x01\x03\r\x07-\x11+\x0f\x04\x05\x05%/\x06\x03\x01\x05\x01\x00\xf2\tU\x1d\x03\x0f\x0b\t\t\x11#!+\x1b\x1f/!!)#\x1f\x19i?\x1f\x15\x1d\x15\x13%)9\x13+\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/hessenberg\x00third_party/py/jax/tests/export_back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00mhlo.layout_mode\x00default\x00[0]\x00[1]\x00main\x00public\x00\x00lapack_sgehrd\x00', - xla_call_module_version=9, - nr_devices=1, -) # End paste - - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_30["f64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_dgehrd'], - serialized_date=datetime.date(2024, 8, 30), - inputs=(), - expected_outputs=(array([[[ 0.9307390587491866 , -0.35692982324474015 , - -0.1271353200176119 , -0.43952156917870067 ], - [ 2.2633695323673964 , 0.9965090965971986 , - -1.3244131008423046 , 1.7324542351344163 ], - [ 0.24558316247256504 , 2.922776762811796 , - 3.630059093036474 , 1.4330664619737252 ], - [-0.2856727718012896 , -0.4601276537179077 , - -2.8602148466873802 , 1.9928744545245372 ]], - - [[-0.5351339571818844 , 5.753313169426148 , - 0.1385440281649789 , 2.8445493054193807 ], - [ 4.676815781213274 , 2.920688567170204 , - -2.610159425457712 , 4.0359806870679655 ], - [-0.16963242599901043 , -2.342935131066633 , - 4.179999589709703 , -0.6810604472011716 ], - [ 0.030645999613174775, -0.2271804227402005 , - -2.2755242550977153 , 0.7136684502626782 ]]]), array([[1.751436143556826 , 1.6505497938190505, 0. ], - [1.9422862513069978, 1.9018440331997255, 0. ]])), - mlir_module_text=r""" -module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<2x4x4xf64> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x3xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { - %cst = stablehlo.constant dense<[[[0.93073905874918661, 0.18483901505653183, -0.11804347408930886, -0.53725392025434981], [-1.700777672846173, 1.3531570270421245, -2.4375034855727518, 2.2945174202226699], [-0.97352780716312858, -0.8319788592736328, 2.4986640885328582, -2.8118637941861766], [1.1324489199416958, -1.9301638714393787, 1.5523821278819048, 2.7676215285832253]], [[-0.53513395718188439, -5.2137633671981938, 2.9644475919777618, 2.2891023676266191], [-4.4068992105328642, 1.2751848926168665, -2.8947257279736456, -2.6817410994805888], [1.5408926111334784, -0.85423691880254915, 6.4217874587762065, -0.43997818045540715], [-0.27837952612324207, 1.1509460853774549, -0.21686805683301608, 0.11738425574951133]]]> : tensor<2x4x4xf64> loc(#loc) - %c = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %c_1 = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_2 = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_3 = stablehlo.constant dense<2> : tensor loc(#loc2) - %c_4 = stablehlo.constant dense<4288> : tensor loc(#loc2) - %0:4 = stablehlo.custom_call @lapack_dgehrd(%c, %c_0, %c_1, %c_2, %c_3, %c_4, %cst) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xf64>) -> (tensor<2x4x4xf64>, tensor<2x3xf64>, tensor<2xi32>, tensor<4288xf64>) loc(#loc2) - %c_5 = stablehlo.constant dense<0> : tensor loc(#loc2) - %1 = stablehlo.broadcast_in_dim %c_5, dims = [] : (tensor) -> tensor<2xi32> loc(#loc2) - %2 = stablehlo.compare EQ, %0#2, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc2) - %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %cst_6 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc2) - %4 = stablehlo.broadcast_in_dim %cst_6, dims = [] : (tensor) -> tensor<2x4x4xf64> loc(#loc2) - %5 = stablehlo.broadcast_in_dim %3, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %6 = stablehlo.select %5, %0#0, %4 : tensor<2x4x4xi1>, tensor<2x4x4xf64> loc(#loc2) - %7 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc2) - %cst_7 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc2) - %8 = stablehlo.broadcast_in_dim %cst_7, dims = [] : (tensor) -> tensor<2x3xf64> loc(#loc2) - %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc2) - %10 = stablehlo.select %9, %0#1, %8 : tensor<2x3xi1>, tensor<2x3xf64> loc(#loc2) - return %6, %10 : tensor<2x4x4xf64>, tensor<2x3xf64> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc = loc(unknown) -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":697:13) -#loc2 = loc("jit(func)/jit(main)/hessenberg"(#loc1)) -""", - mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xed\xa17\x01W\x0f\x0b\x07\x0b\x13\x13\x0f\x0b\x13\x13+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x13\x03K\x0f\x0b\x0b\x0b\x0bo/\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b&\x08\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b\'\x0f\x17\x1bO\x1f\x0f\x0b\x0b//oO\x01\x05\x0b\x0f\x033\x0f\x1b\x07\x07\x17\x07\x07\x0f\x07\x13\x17\x17\x13\x13\x13\x13\x13\x13\x1b\x13\x1b\x13\x17\x17\x13\x02\xa6\x0b\x1d-/\x05\x15\x1f\x05\x17\x03\x03\x03w\x03\x03\x07\x93\x11\x03\x05\x05\x19\x03\x03\x07\x99\x03\x03\x03\x9b\x03\t\x17\x19\x1b\r\x1d\r\x0f\x1f\x05\x1b\x11\x01\x00\x05\x1d\x05\x1f\x05!\x03\x0b#Y%e\'g\x0fq)s\x05#\x05%\x05\'\x05)\x03\x03\x03u\x05+\x171\xe6\n\x1b\x05-\x03\x03\x03y\x03\x03\x03{\x03\x03\x03}\x03\x11;\x7f=\x81?\x83AYC\x85E\x87G\x89I\x8d\x05/\x051\x053\x055\x057\x059\x05;\x05=\x03\x03\x03\x91\x03\x05O\x95Q\x97\x05?\x05A\x03\x03\x07\x9d\x03\x03\x07\x9f\x1f\x1d\x01\x03\x01\x1dC\x1dE\x1dG\x1f\x1f1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f#\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x19\x03\x05im\r\x05[k]_\x1dI\r\x05[o]_\x1dK\x1dM\x1dO\x1f\x07\x02\x04\xa6\x00NG\x9d\xc8\xed?\xf2\xa8X\n\xce\xa8\xc7?#E\xb8\xdc\x188\xbe\xbf\xb8|$"/1\xe1\xbf\xc4B*\xa6b6\xfb\xbf\xe8\xf9\x97\xfb\x87\xa6\xf5?)^\xd3\xd3\x01\x80\x03\xc0T\xab\xff\xf2+[\x02@4d\xb0\xc9#\'\xef\xbf~e\xf1 \x92\x9f\xea\xbf\x96\x81\xff\x98C\xfd\x03@W\xb0\xe6q\xb2~\x06\xc0F\xa48\xc2\x82\x1e\xf2?\xcc\x0b\xfc\x82\xf3\xe1\xfe\xbf\xdc\\b\xa4\x8e\xd6\xf8?\x8c\xc3\x87\xc1\x16$\x06@\x83h\xa2?\xd1\x1f\xe1\xbf\xdc\xcb\xbc\xc8\xe4\xda\x14\xc0\xe6\x00\x92L0\xb7\x07@Q8\xf1\xe6\x14P\x02@\t\x07\xc8/\xaa\xa0\x11\xc0\x8eH"F(g\xf4?\xf5Jd\xf6e(\x07\xc0\x9e\xddt\xad4t\x05\xc0\x1cv\xb7\x02\x7f\xa7\xf8?B^\xa9\xa9\xe8U\xeb\xbf\x1e:5\r\xe9\xaf\x19@\xa2\x9c\x00>\x9a(\xdc\xbf\xc1\xd1$\\\xf8\xd0\xd1\xbf}|BqFj\xf2?6\x8b\xd2\x1dU\xc2\xcb\xbfdk\x82\x03\xe5\x0c\xbe?\x1f\x05\t\x04\x00\x00\x00\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x02\x00\x00\x00\x1f\x05\t\xc0\x10\x00\x00\x0b\x05\x1dQ\x1dS\x05\x01\x03\x0fWWWWWWa\x03\x03\x8b\x15\x03\x01\x19\x01\x03\ta\x8fcc\x1f!!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x00\x00\x00\x00\x1f%\x01\t\x07\x07\x01\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x13\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f/1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f5!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x01\x15)\x07\t\x11\x11\x0b\x01\x0b)\x05\t\r\x0b\x13\x1d)\x01\x0b\x1b)\x03\t\x15\x11\x01\x05\x07\r)\x03\x02\x86\x0b)\x03\x01\x0f)\x03\r\x0f)\x03\t\x0f)\x03\x05\x0f)\x03\x01\x11)\x03\t\t)\x07\t\x05\x05\t)\x03\x05\x11)\x07\t\x11\x11\t)\x03\r\x11)\x05\t\x05\t)\x05\t\r\t)\x03\t\x11\x04\xde\x02\x05\x01\x11\x05\x15\x07\x03\x01\x05\t\x11\x05!\x07\x031Y\x03\x03\x05+\x03\x07\x03\x03\x01\t\x03\x05\x03\x03\x013\x03\x05\x03\x03\x01\t\x03\x05\x03\x03\x01\t\x03\x05\x03\x03\x015\x03\x05\x03\x03\x017\x03\x05\x0b\x07\x019\t\x07\r\x17\x1b\x0f\x03\x05\x07\t\x0b\r\x01\x03\x03\x01K\x03\x05\x05\x07\x01\x0b\x03\x17\x03\x17\r\x07\x01M\x03\'\x05\x13\x19\x05\x07\x01\x11\x03)\x03\x1b\x03\x03\x01\x13\x03\x13\x05\x07\x01\x0b\x03\x07\x03\x1f\x05\x07\x01S\x03-\x03\x1d\x07\x06\x01\x03\x07\x07#\x0f!\x05\x07\x01\x11\x031\x03\x1b\x03\x03\x01\x13\x03\x13\x05\x07\x01\x0b\x03\r\x03)\x05\x07\x01U\x033\x03\'\x07\x06\x01\x03\r\x07-\x11+\x0f\x04\x05\x05%/\x06\x03\x01\x05\x01\x00\xf2\tU\x1d\x03\x0f\x0b\t\t\x11#!+\x1b\x1f/!!)#\x1f\x19i?\x1f\x15\x1d\x15\x13%)9\x13+\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/hessenberg\x00third_party/py/jax/tests/export_back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00mhlo.layout_mode\x00default\x00[0]\x00[1]\x00main\x00public\x00\x00lapack_dgehrd\x00', - xla_call_module_version=9, - nr_devices=1, -) # End paste +array = np.array +float32 = np.float32 +complex64 = np.complex64 data_2024_08_31 = {} - # Pasted from the test output (see export_back_compat_test_util.py module docstring) data_2024_08_31["c128"] = dict( testdata_version=1, diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_lu_lapack_getrf.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_lu_lapack_getrf.py index 72d97df53a4f..294b309a5904 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_lu_lapack_getrf.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_lu_lapack_getrf.py @@ -15,529 +15,15 @@ # ruff: noqa import datetime -from numpy import array, int32, float32, complex64 +import numpy as np -data_2023_06_14 = {} - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_14['f32'] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_sgetrf'], - serialized_date=datetime.date(2023, 6, 14), - inputs=(), - expected_outputs=(array([[6. , 7. , 8. ], - [0. , 1. , 2. ], - [0.5, 0.5, 0. ]], dtype=float32), array([2, 2, 2], dtype=int32), array([2, 0, 1], dtype=int32)), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<3x3xf32> {jax.result_info = "[0]"}, tensor<3xi32> {jax.result_info = "[1]"}, tensor<3xi32> {jax.result_info = "[2]"}) { - %0 = stablehlo.iota dim = 0 : tensor<9xf32> loc(#loc3) - %1 = stablehlo.reshape %0 : (tensor<9xf32>) -> tensor<3x3xf32> loc(#loc4) - %2 = stablehlo.constant dense<3> : tensor loc(#loc5) - %3 = stablehlo.constant dense<3> : tensor loc(#loc5) - %4 = stablehlo.convert %2 : (tensor) -> tensor loc(#loc5) - %5 = stablehlo.reshape %4 : (tensor) -> tensor<1xi32> loc(#loc5) - %6 = stablehlo.convert %3 : (tensor) -> tensor loc(#loc5) - %7 = stablehlo.reshape %6 : (tensor) -> tensor<1xi32> loc(#loc5) - %8 = stablehlo.concatenate %5, %7, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> loc(#loc5) - %9 = stablehlo.constant dense<3> : tensor loc(#loc5) - %10 = stablehlo.convert %9 : (tensor) -> tensor loc(#loc5) - %11 = stablehlo.reshape %10 : (tensor) -> tensor<1xi32> loc(#loc5) - %12 = stablehlo.constant dense<> : tensor<0xi32> loc(#loc5) - %13 = stablehlo.constant dense<1> : tensor loc(#loc5) - %14 = stablehlo.constant dense<3> : tensor loc(#loc5) - %15 = stablehlo.constant dense<3> : tensor loc(#loc5) - %16:3 = stablehlo.custom_call @lapack_sgetrf(%13, %14, %15, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor<3x3xf32>) -> (tensor<3x3xf32>, tensor<3xi32>, tensor) loc(#loc5) - %17 = stablehlo.constant dense<1> : tensor loc(#loc5) - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor) -> tensor<3xi32> loc(#loc5) - %19 = stablehlo.subtract %16#1, %18 : tensor<3xi32> loc(#loc5) - %20 = stablehlo.constant dense<0> : tensor loc(#loc5) - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor loc(#loc5) - %22 = stablehlo.compare GE, %16#2, %21, SIGNED : (tensor, tensor) -> tensor loc(#loc5) - %23 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %24 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc5) - %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<3x3xf32> loc(#loc5) - %26 = stablehlo.broadcast_in_dim %23, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc5) - %27 = stablehlo.select %26, %16#0, %25 : tensor<3x3xi1>, tensor<3x3xf32> loc(#loc5) - %28 = stablehlo.iota dim = 0 : tensor<3xi32> loc(#loc6) - %29 = stablehlo.constant dense<0> : tensor loc(#loc7) - %30 = stablehlo.constant dense<0> : tensor loc(#loc8) - %31:4 = stablehlo.while(%iterArg = %30, %iterArg_0 = %29, %iterArg_1 = %28, %iterArg_2 = %19) : tensor, tensor, tensor<3xi32>, tensor<3xi32> - cond { - %32 = stablehlo.constant dense<3> : tensor loc(#loc9) - %33 = stablehlo.compare LT, %iterArg, %32, SIGNED : (tensor, tensor) -> tensor loc(#loc10) - stablehlo.return %33 : tensor loc(#loc9) - } do { - %32 = stablehlo.constant dense<1> : tensor loc(#loc9) - %33 = stablehlo.add %iterArg_0, %32 : tensor loc(#loc11) - %34 = stablehlo.constant dense<0> : tensor loc(#loc9) - %35 = stablehlo.compare LT, %iterArg_0, %34, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %36 = stablehlo.constant dense<3> : tensor loc(#loc9) - %37 = stablehlo.add %iterArg_0, %36 : tensor loc(#loc11) - %38 = stablehlo.select %35, %37, %iterArg_0 : tensor, tensor loc(#loc13) - %39 = stablehlo.convert %38 : (tensor) -> tensor loc(#loc14) - %40 = stablehlo.broadcast_in_dim %39, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %41 = "stablehlo.gather"(%iterArg_2, %40) {dimension_numbers = #stablehlo.gather, indices_are_sorted = true, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1xi32>) -> tensor loc(#loc16) - %42 = stablehlo.constant dense<0> : tensor loc(#loc9) - %43 = stablehlo.compare LT, %iterArg_0, %42, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %44 = stablehlo.constant dense<3> : tensor loc(#loc9) - %45 = stablehlo.add %iterArg_0, %44 : tensor loc(#loc11) - %46 = stablehlo.select %43, %45, %iterArg_0 : tensor, tensor loc(#loc13) - %47 = stablehlo.convert %46 : (tensor) -> tensor loc(#loc14) - %48 = stablehlo.broadcast_in_dim %47, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %49 = "stablehlo.gather"(%iterArg_1, %48) {dimension_numbers = #stablehlo.gather, indices_are_sorted = true, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1xi32>) -> tensor loc(#loc16) - %50 = stablehlo.constant dense<0> : tensor loc(#loc9) - %51 = stablehlo.compare LT, %41, %50, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %52 = stablehlo.constant dense<3> : tensor loc(#loc9) - %53 = stablehlo.add %41, %52 : tensor loc(#loc11) - %54 = stablehlo.select %51, %53, %41 : tensor, tensor loc(#loc13) - %55 = stablehlo.dynamic_slice %iterArg_1, %54, sizes = [1] : (tensor<3xi32>, tensor) -> tensor<1xi32> loc(#loc17) - %56 = stablehlo.reshape %55 : (tensor<1xi32>) -> tensor loc(#loc18) - %57 = stablehlo.constant dense<0> : tensor loc(#loc9) - %58 = stablehlo.compare LT, %iterArg_0, %57, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %59 = stablehlo.constant dense<3> : tensor loc(#loc9) - %60 = stablehlo.add %iterArg_0, %59 : tensor loc(#loc11) - %61 = stablehlo.select %58, %60, %iterArg_0 : tensor, tensor loc(#loc13) - %62 = stablehlo.convert %61 : (tensor) -> tensor loc(#loc14) - %63 = stablehlo.broadcast_in_dim %62, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %64 = "stablehlo.scatter"(%iterArg_1, %63, %56) ({ - ^bb0(%arg0: tensor loc(unknown), %arg1: tensor loc(unknown)): - stablehlo.return %arg1 : tensor loc(#loc19) - }) {indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true} : (tensor<3xi32>, tensor<1xi32>, tensor) -> tensor<3xi32> loc(#loc19) - %65 = stablehlo.constant dense<0> : tensor loc(#loc9) - %66 = stablehlo.compare LT, %41, %65, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %67 = stablehlo.constant dense<3> : tensor loc(#loc9) - %68 = stablehlo.add %41, %67 : tensor loc(#loc11) - %69 = stablehlo.select %66, %68, %41 : tensor, tensor loc(#loc13) - %70 = stablehlo.broadcast_in_dim %69, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %71 = "stablehlo.scatter"(%64, %70, %49) ({ - ^bb0(%arg0: tensor loc(unknown), %arg1: tensor loc(unknown)): - stablehlo.return %arg1 : tensor loc(#loc19) - }) {indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true} : (tensor<3xi32>, tensor<1xi32>, tensor) -> tensor<3xi32> loc(#loc19) - %72 = stablehlo.constant dense<1> : tensor loc(#loc9) - %73 = stablehlo.add %iterArg, %72 : tensor loc(#loc11) - stablehlo.return %73, %33, %71, %iterArg_2 : tensor, tensor, tensor<3xi32>, tensor<3xi32> loc(#loc9) - } loc(#loc9) - return %27, %19, %31#2 : tensor<3x3xf32>, tensor<3xi32>, tensor<3xi32> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":550:0) -#loc2 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":551:0) -#loc3 = loc("jit()/jit(main)/iota[dtype=float32 shape=(9,) dimension=0]"(#loc1)) -#loc4 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc1)) -#loc5 = loc("jit()/jit(main)/lu"(#loc2)) -#loc6 = loc("jit()/jit(main)/iota[dtype=int32 shape=(3,) dimension=0]"(#loc2)) -#loc7 = loc("jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]"(#loc2)) -#loc8 = loc("jit()/jit(main)/scan[reverse=False length=3 num_consts=0 num_carry=3 linear=(False, False, False) unroll=1]"(#loc2)) -#loc9 = loc("jit()/jit(main)/while[cond_nconsts=0 body_nconsts=0]"(#loc2)) -#loc10 = loc("jit()/jit(main)/while/cond/lt"(#loc2)) -#loc11 = loc("jit()/jit(main)/while/body/add"(#loc2)) -#loc12 = loc("jit()/jit(main)/while/body/lt"(#loc2)) -#loc13 = loc("jit()/jit(main)/while/body/select_n"(#loc2)) -#loc14 = loc("jit()/jit(main)/while/body/convert_element_type[new_dtype=int32 weak_type=False]"(#loc2)) -#loc15 = loc("jit()/jit(main)/while/body/broadcast_in_dim[shape=(1,) broadcast_dimensions=()]"(#loc2)) -#loc16 = loc("jit()/jit(main)/while/body/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) slice_sizes=(1,) unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]"(#loc2)) -#loc17 = loc("jit()/jit(main)/while/body/dynamic_slice[slice_sizes=(1,)]"(#loc2)) -#loc18 = loc("jit()/jit(main)/while/body/squeeze[dimensions=(0,)]"(#loc2)) -#loc19 = loc("jit()/jit(main)/while/body/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]"(#loc2)) -""", - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x013\x05\x01\x03\x01\x03\x05\x03#\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!#%'\x03\xa6\x02\x0e\x023\x01\xb1\x0f\x0f\x07\x17\x0b\x13\x13\x0f\x1b\x13\x0f\x0f\x13\x0f\x13\x13\x13\x0f\x0f\x0b\x13\x17\x0b\x0b\x0b\x0b;\x0b\x0b\x0b\x0f;#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0b\x13\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x13\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x03Q\x0f\x0f/\x0b\x0f\x0b\x0bO\x0b/\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b/\x0f/\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17\x17/\x1f\x1f\x0b\x1fO/\x0b\x01\x07\x0b\x13\x0b\x01\x03\x0f\x031\x0f\x0f\x13\x13\x0f\x17\x07\x07\x07\x07\x07\x13\x0f\x13\x1b\x13\x13\x13\x13\x13\x13\x17\x17\x13\x02J\t\x1d]\x07\x1d\x8b\x07\x1f\x17-\x9e\x08\x01\x05)\x03\x03/\xb9\x03\x03\t\xd9\x1d\x8d\x07\x03\x051\xc13\xff\x03\x03\t\xfd\x1d\x8f\x07\x1d\x91\x07\x03\x03\t\xdf\x1d\x95\x07\x1d\x02\x02\x07\x03\x03\t\xdd\x03\x03\t\xf5\x1d\x93\x07\x11\x01\x05\x05+\x03\x03S\xb1\x17-\x9a\x08\x01\x05-\x05/\x051\x053\x03\r\x97\xb57\xb19\xbb\x99\xb9;\xc3\x9b\xb5\x055\x057\x059\x1d\x9d\x07\x03\r7\xb19\xbb\xa9\xb5\xab\xb5\xad\xbb\xaf\xb9\x03\x07C%E%'G\x05;\x05=\x05?\x03\x0bK\xbdM\xc5O\xc7'\xd5Q\xd7\x05A\x05C\x05E\x05G\x05I\x1dW+\x05K\x1d[+\x05M\x05O\x03\x03a\xb1\x05Q\x03\x03\t\xdb\x03\x11g\xe1i\xe3k\xe5m\xbdo\xe7q\xe9s\xebu\xef\x05S\x05U\x05W\x05Y\x05[\x05]\x05_\x05a\x03\x03\t\xf3\x03\x051\xc13\xf7\x03\x03\t\xf9\x03\x03/\xfb\x1d\x81\x07\x05c\x1d\x85\x07\x05e\x1d\x89\x07\x05g\x05i\x05k\x05m\x05o\x05q\x05s\x05u\x05w\x05y\x05{\x03\x03;\xc3\x1d\xa3\x07\x05}\x1d\xa7\x07\x05\x7f\x05\x81\x05\x83\x05\x85\x05\x87\x13\x11\x01\x1f%\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d\x89\x1f+\x01\x05\x03\x03\x01\x1f'!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\t\x07\x1f\x1d\x11\x01\x00\x00\x00\x00\x00\x00\x00#\x1f\x03\x07\xc9\xcd\xd1\r\x03\xb7\xcb\x1d\x8b\r\x03\xb7\xcf\x1d\x8d\r\x03\xb7\xd3\x1d\x8f\x1d\x91\x1d\x93\x1f\x03\x11\x03\x00\x00\x00\x00\x00\x00\x00\x1f\x19\x01\x1f\x03\x11\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x03\x00\x00\x00\x0b\x05\x1d\x95\x1d\x97\x05\x01\x03\t\xb3\xb3\xb3\xbf\x03\x03\xed\x15\x03\x01\r\x01\x03\x07\xbf\xf1\xb3\x1f)\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x00\x00\x00\x00\x07\x05\x1f\x1b\t\x00\x00\xc0\x7f\x1f1!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x03\x11\x00\x00\x00\x00\x00\x00\x00\x00\x07\x0b\x05\x99\x1d\n\x02\x07\x05\x9b\x01\x02\x02)\x01\x11)\x01\x0f)\x03\r\x0f)\x03\x05\x0f)\x01\x17)\x05\r\r\x13\x1b\x1d\t\x13\x01)\x03\x01\x0f)\x01\x13)\x03\x05\x11\x11\x01\x07\r\x07\x07)\x03%\x13)\x03\t\x0f)\x03\x01\x15)\x03\t\x15)\x03\x05\x15)\x03\x01\x11)\x05\x05\x05\x17)\x05\r\r\x17)\x03\t\x11\x04f\n\x05\x01\x11\x05A\x07\x03\x01\x05\x19\x11\x05I\x05\x03K\x85\x13\x03U)\x03!\x0f\x06Y\x03\r\x03\x01\x03\x03\x01\r\x03\x03\x03\x03\x01\r\x03\x03\x0b\x06\x01\x03\x05\x03\x05\x0f\x06\x01\x03\t\x03\t\x0b\x06\x01\x03\x05\x03\x07\x0f\x06\x01\x03\t\x03\r\x1b\x07\x01_\x03#\x05\x0b\x0f\x03\x03\x01\r\x03\x03\x0b\x06\x01\x03\x05\x03\x13\x0f\x06\x01\x03\t\x03\x15\x03\x03\x01c\x03\x19\x03\x03\x01\x1f\x03\x03\x03\x03\x01\x19\x03\x05\x03\x03\x01\x19\x03\x05\x1d\x07\x01e\x07\r\x07\x05\t\x1b\x1d\x1f\x03\x03\x03\x01w\x03\x05\x05\x07\x01\x0b\x03\x07\x03'\x1f\x06\x01\x03\x07\x05#)\x03\x03\x01!\x03\x05\x05\x07\x01\x0b\x03\x05\x03-\x07\x07\x01y\x03\x0b\x05%/\x05\x07\x01\x0b\x03-\x031\x03\x03\x01{\x03\x1b\x05\x07\x01\x0b\x03\r\x035\x05\x07\x01}\x03/\x033\r\x06\x01\x03\r\x079!7\x13\x03\x7f)\x03\x07\x03\x03\x83\x13\x03\x03\x03\x03\x87\x13\x03\x03!\x16\x03\t\x03\x03\x07\x07\tA?=+\t\x03\r\x0f\t\x03\x05\x03\x05\x07\x05\x07\x05\x03\x03\x03\r\x03\x03\x07\x07\x06\x02\x11\x03\x0b\x05KS\x11\x04\x03\x03U\x03]\xaf\t\x03\x05\x03\x05\x07\x05\x07\x05\x03\x03\x03\x1f\x03\x03\t\x06\x0f\x03\x03\x05MS\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05MW\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05M[\r\x06\x17\x03\x03\x07Y]M\x0b\x06#\x03\x05\x03_\x05\x07\x1b\x0b\x03\t\x03a\x15\x07=5\x03\x05\x05Qc\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05Mg\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05Mk\r\x06\x17\x03\x03\x07imM\x0b\x06#\x03\x05\x03o\x05\x07\x1b\x0b\x03\t\x03q\x15\x07=5\x03\x05\x05Os\x03\x03\x03!\x03\x05\x07\x07\x15\x11\x03\x0b\x05ew\x03\x03\x03\x19\x03\x05\t\x06\x0f\x03\x05\x05e{\r\x06\x17\x03\x05\x07y}e#\x07\xa1\x9f\x03\t\x05O\x7f\x0f\x06\xa5\x03\x05\x03\x81\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05M\x85\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05M\x89\r\x06\x17\x03\x03\x07\x87\x8bM\x0b\x06#\x03\x05\x03\x8d\x05\x07\x1b\x0b\x03\t\x03\x8f\x17\x17\x1d?\x03\x07\x07O\x91\x83\x05\x03\x05\x07\x05\x05\x05\x05\x05\x11\x04\x1d\x03\xa9\x03\x03\x03!\x03\x05\x07\x07\x15\x11\x03\x0b\x05e\x95\x03\x03\x03\x19\x03\x05\t\x06\x0f\x03\x05\x05e\x99\r\x06\x17\x03\x05\x07\x97\x9be\x05\x07\x1b\x0b\x03\t\x03\x9d\x17\x17\x1d?\x03\x07\x07\x93\x9fu\x05\x03\x05\x07\x05\x05\x05\x05\x05\x11\x04\x1d\x03\xa9\x03\x03\x03\x1f\x03\x03\t\x06\x0f\x03\x03\x05K\xa3\x11\x04\x03\t\xa5U\xa1Q\x11\x04\x05\x07;+G\x06\x03\x01\x05\x01\x00v%\x9dM2\x04\x1d\x03\x0f\x0b\t\t\t!'\x1f;+y\x87.\x04!\x19+\xb1\xb3YMO{\xe9\x8b\x83\x1f/!!)#\x1f\x19\x157\x85\x87\x1f\x1f\x15\x1d\x15\x1b%)\x19'#+\x1b+\x83\x13\r#\x13\x19\x1f\x1f\x11\x17\x15\x11\x15\x17\x15\x17\x0f\x17)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00compare_v1\x00add_v1\x00convert_v1\x00select_v1\x00reshape_v1\x00return_v1\x00iota_v1\x00gather_v1\x00scatter_v1\x00func_v1\x00concatenate_v1\x00custom_call_v1\x00subtract_v1\x00while_v1\x00dynamic_slice_v1\x00value\x00sym_name\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00index_vector_dim\x00indices_are_sorted\x00slice_sizes\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit()/jit(main)/iota[dtype=float32 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jit()/jit(main)/lu\x00dimension\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit()/jit(main)/iota[dtype=int32 shape=(3,) dimension=0]\x00jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]\x00jit()/jit(main)/scan[reverse=False length=3 num_consts=0 num_carry=3 linear=(False, False, False) unroll=1]\x00jit()/jit(main)/while[cond_nconsts=0 body_nconsts=0]\x00jit()/jit(main)/while/body/add\x00jit()/jit(main)/while/body/lt\x00jit()/jit(main)/while/body/select_n\x00jit()/jit(main)/while/body/convert_element_type[new_dtype=int32 weak_type=False]\x00jit()/jit(main)/while/body/broadcast_in_dim[shape=(1,) broadcast_dimensions=()]\x00collapsed_slice_dims\x00offset_dims\x00start_index_map\x00jit()/jit(main)/while/body/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) slice_sizes=(1,) unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]\x00jit()/jit(main)/while/body/dynamic_slice[slice_sizes=(1,)]\x00jit()/jit(main)/while/body/squeeze[dimensions=(0,)]\x00inserted_window_dims\x00scatter_dims_to_operand_dims\x00unique_indices\x00update_window_dims\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_sgetrf\x00jit()/jit(main)/while/body/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]\x00jit()/jit(main)/while/cond/lt\x00", - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_14['f64'] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_dgetrf'], - serialized_date=datetime.date(2023, 6, 14), - inputs=(), - expected_outputs=(array([[6. , 7. , 8. ], - [0. , 1. , 2. ], - [0.5, 0.5, 0. ]]), array([2, 2, 2], dtype=int32), array([2, 0, 1], dtype=int32)), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<3x3xf64> {jax.result_info = "[0]"}, tensor<3xi32> {jax.result_info = "[1]"}, tensor<3xi32> {jax.result_info = "[2]"}) { - %0 = stablehlo.iota dim = 0 : tensor<9xf64> loc(#loc3) - %1 = stablehlo.reshape %0 : (tensor<9xf64>) -> tensor<3x3xf64> loc(#loc4) - %2 = stablehlo.constant dense<3> : tensor loc(#loc5) - %3 = stablehlo.constant dense<3> : tensor loc(#loc5) - %4 = stablehlo.convert %2 : (tensor) -> tensor loc(#loc5) - %5 = stablehlo.reshape %4 : (tensor) -> tensor<1xi32> loc(#loc5) - %6 = stablehlo.convert %3 : (tensor) -> tensor loc(#loc5) - %7 = stablehlo.reshape %6 : (tensor) -> tensor<1xi32> loc(#loc5) - %8 = stablehlo.concatenate %5, %7, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> loc(#loc5) - %9 = stablehlo.constant dense<3> : tensor loc(#loc5) - %10 = stablehlo.convert %9 : (tensor) -> tensor loc(#loc5) - %11 = stablehlo.reshape %10 : (tensor) -> tensor<1xi32> loc(#loc5) - %12 = stablehlo.constant dense<> : tensor<0xi32> loc(#loc5) - %13 = stablehlo.constant dense<1> : tensor loc(#loc5) - %14 = stablehlo.constant dense<3> : tensor loc(#loc5) - %15 = stablehlo.constant dense<3> : tensor loc(#loc5) - %16:3 = stablehlo.custom_call @lapack_dgetrf(%13, %14, %15, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor<3x3xf64>) -> (tensor<3x3xf64>, tensor<3xi32>, tensor) loc(#loc5) - %17 = stablehlo.constant dense<1> : tensor loc(#loc5) - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor) -> tensor<3xi32> loc(#loc5) - %19 = stablehlo.subtract %16#1, %18 : tensor<3xi32> loc(#loc5) - %20 = stablehlo.constant dense<0> : tensor loc(#loc5) - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor loc(#loc5) - %22 = stablehlo.compare GE, %16#2, %21, SIGNED : (tensor, tensor) -> tensor loc(#loc5) - %23 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %24 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc5) - %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<3x3xf64> loc(#loc5) - %26 = stablehlo.broadcast_in_dim %23, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc5) - %27 = stablehlo.select %26, %16#0, %25 : tensor<3x3xi1>, tensor<3x3xf64> loc(#loc5) - %28 = stablehlo.iota dim = 0 : tensor<3xi32> loc(#loc6) - %29 = stablehlo.constant dense<0> : tensor loc(#loc7) - %30 = stablehlo.constant dense<0> : tensor loc(#loc8) - %31:4 = stablehlo.while(%iterArg = %30, %iterArg_0 = %29, %iterArg_1 = %28, %iterArg_2 = %19) : tensor, tensor, tensor<3xi32>, tensor<3xi32> - cond { - %32 = stablehlo.constant dense<3> : tensor loc(#loc9) - %33 = stablehlo.compare LT, %iterArg, %32, SIGNED : (tensor, tensor) -> tensor loc(#loc10) - stablehlo.return %33 : tensor loc(#loc9) - } do { - %32 = stablehlo.constant dense<1> : tensor loc(#loc9) - %33 = stablehlo.add %iterArg_0, %32 : tensor loc(#loc11) - %34 = stablehlo.constant dense<0> : tensor loc(#loc9) - %35 = stablehlo.compare LT, %iterArg_0, %34, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %36 = stablehlo.constant dense<3> : tensor loc(#loc9) - %37 = stablehlo.add %iterArg_0, %36 : tensor loc(#loc11) - %38 = stablehlo.select %35, %37, %iterArg_0 : tensor, tensor loc(#loc13) - %39 = stablehlo.convert %38 : (tensor) -> tensor loc(#loc14) - %40 = stablehlo.broadcast_in_dim %39, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %41 = "stablehlo.gather"(%iterArg_2, %40) {dimension_numbers = #stablehlo.gather, indices_are_sorted = true, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1xi32>) -> tensor loc(#loc16) - %42 = stablehlo.constant dense<0> : tensor loc(#loc9) - %43 = stablehlo.compare LT, %iterArg_0, %42, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %44 = stablehlo.constant dense<3> : tensor loc(#loc9) - %45 = stablehlo.add %iterArg_0, %44 : tensor loc(#loc11) - %46 = stablehlo.select %43, %45, %iterArg_0 : tensor, tensor loc(#loc13) - %47 = stablehlo.convert %46 : (tensor) -> tensor loc(#loc14) - %48 = stablehlo.broadcast_in_dim %47, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %49 = "stablehlo.gather"(%iterArg_1, %48) {dimension_numbers = #stablehlo.gather, indices_are_sorted = true, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1xi32>) -> tensor loc(#loc16) - %50 = stablehlo.constant dense<0> : tensor loc(#loc9) - %51 = stablehlo.compare LT, %41, %50, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %52 = stablehlo.constant dense<3> : tensor loc(#loc9) - %53 = stablehlo.add %41, %52 : tensor loc(#loc11) - %54 = stablehlo.select %51, %53, %41 : tensor, tensor loc(#loc13) - %55 = stablehlo.dynamic_slice %iterArg_1, %54, sizes = [1] : (tensor<3xi32>, tensor) -> tensor<1xi32> loc(#loc17) - %56 = stablehlo.reshape %55 : (tensor<1xi32>) -> tensor loc(#loc18) - %57 = stablehlo.constant dense<0> : tensor loc(#loc9) - %58 = stablehlo.compare LT, %iterArg_0, %57, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %59 = stablehlo.constant dense<3> : tensor loc(#loc9) - %60 = stablehlo.add %iterArg_0, %59 : tensor loc(#loc11) - %61 = stablehlo.select %58, %60, %iterArg_0 : tensor, tensor loc(#loc13) - %62 = stablehlo.convert %61 : (tensor) -> tensor loc(#loc14) - %63 = stablehlo.broadcast_in_dim %62, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %64 = "stablehlo.scatter"(%iterArg_1, %63, %56) ({ - ^bb0(%arg0: tensor loc(unknown), %arg1: tensor loc(unknown)): - stablehlo.return %arg1 : tensor loc(#loc19) - }) {indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true} : (tensor<3xi32>, tensor<1xi32>, tensor) -> tensor<3xi32> loc(#loc19) - %65 = stablehlo.constant dense<0> : tensor loc(#loc9) - %66 = stablehlo.compare LT, %41, %65, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %67 = stablehlo.constant dense<3> : tensor loc(#loc9) - %68 = stablehlo.add %41, %67 : tensor loc(#loc11) - %69 = stablehlo.select %66, %68, %41 : tensor, tensor loc(#loc13) - %70 = stablehlo.broadcast_in_dim %69, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %71 = "stablehlo.scatter"(%64, %70, %49) ({ - ^bb0(%arg0: tensor loc(unknown), %arg1: tensor loc(unknown)): - stablehlo.return %arg1 : tensor loc(#loc19) - }) {indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true} : (tensor<3xi32>, tensor<1xi32>, tensor) -> tensor<3xi32> loc(#loc19) - %72 = stablehlo.constant dense<1> : tensor loc(#loc9) - %73 = stablehlo.add %iterArg, %72 : tensor loc(#loc11) - stablehlo.return %73, %33, %71, %iterArg_2 : tensor, tensor, tensor<3xi32>, tensor<3xi32> loc(#loc9) - } loc(#loc9) - return %27, %19, %31#2 : tensor<3x3xf64>, tensor<3xi32>, tensor<3xi32> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":553:0) -#loc2 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":554:0) -#loc3 = loc("jit()/jit(main)/iota[dtype=float64 shape=(9,) dimension=0]"(#loc1)) -#loc4 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc1)) -#loc5 = loc("jit()/jit(main)/lu"(#loc2)) -#loc6 = loc("jit()/jit(main)/iota[dtype=int32 shape=(3,) dimension=0]"(#loc2)) -#loc7 = loc("jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]"(#loc2)) -#loc8 = loc("jit()/jit(main)/scan[reverse=False length=3 num_consts=0 num_carry=3 linear=(False, False, False) unroll=1]"(#loc2)) -#loc9 = loc("jit()/jit(main)/while[cond_nconsts=0 body_nconsts=0]"(#loc2)) -#loc10 = loc("jit()/jit(main)/while/cond/lt"(#loc2)) -#loc11 = loc("jit()/jit(main)/while/body/add"(#loc2)) -#loc12 = loc("jit()/jit(main)/while/body/lt"(#loc2)) -#loc13 = loc("jit()/jit(main)/while/body/select_n"(#loc2)) -#loc14 = loc("jit()/jit(main)/while/body/convert_element_type[new_dtype=int32 weak_type=False]"(#loc2)) -#loc15 = loc("jit()/jit(main)/while/body/broadcast_in_dim[shape=(1,) broadcast_dimensions=()]"(#loc2)) -#loc16 = loc("jit()/jit(main)/while/body/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) slice_sizes=(1,) unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]"(#loc2)) -#loc17 = loc("jit()/jit(main)/while/body/dynamic_slice[slice_sizes=(1,)]"(#loc2)) -#loc18 = loc("jit()/jit(main)/while/body/squeeze[dimensions=(0,)]"(#loc2)) -#loc19 = loc("jit()/jit(main)/while/body/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]"(#loc2)) -""", - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x013\x05\x01\x03\x01\x03\x05\x03#\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!#%'\x03\xa6\x02\x0e\x023\x01\xb1\x0f\x0f\x07\x17\x0b\x13\x13\x0f\x1b\x13\x0f\x0f\x13\x0f\x13\x13\x13\x0f\x0f\x0b\x13\x17\x0b\x0b\x0b\x0b;\x0b\x0b\x0b\x0f;#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0b\x13\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x13\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x03Q\x0f\x0f/\x0b\x0f\x0b\x0bO\x0b/\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b/\x0f/\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17\x17/\x1f\x1f\x0b/O/\x0b\x01\x07\x0b\x13\x0b\x01\x03\x0f\x031\x0f\x0f\x13\x13\x0f\x17\x07\x07\x07\x07\x07\x13\x0f\x13\x1b\x13\x13\x13\x13\x13\x13\x17\x17\x13\x02Z\t\x1d]\x07\x1d\x8b\x07\x1f\x17-\xaa\x08\x01\x05)\x03\x03/\xb9\x03\x03\t\xd9\x1d\x8d\x07\x03\x051\xc13\xff\x03\x03\t\xfd\x1d\x8f\x07\x1d\x91\x07\x03\x03\t\xdf\x1d\x95\x07\x1d\x02\x02\x07\x03\x03\t\xdd\x03\x03\t\xf5\x1d\x93\x07\x11\x01\x05\x05+\x03\x03S\xb1\x17-\xa6\x08\x01\x05-\x05/\x051\x053\x03\r\x97\xb57\xb19\xbb\x99\xb9;\xc3\x9b\xb5\x055\x057\x059\x1d\x9d\x07\x03\r7\xb19\xbb\xa9\xb5\xab\xb5\xad\xbb\xaf\xb9\x03\x07C%E%'G\x05;\x05=\x05?\x03\x0bK\xbdM\xc5O\xc7'\xd5Q\xd7\x05A\x05C\x05E\x05G\x05I\x1dW+\x05K\x1d[+\x05M\x05O\x03\x03a\xb1\x05Q\x03\x03\t\xdb\x03\x11g\xe1i\xe3k\xe5m\xbdo\xe7q\xe9s\xebu\xef\x05S\x05U\x05W\x05Y\x05[\x05]\x05_\x05a\x03\x03\t\xf3\x03\x051\xc13\xf7\x03\x03\t\xf9\x03\x03/\xfb\x1d\x81\x07\x05c\x1d\x85\x07\x05e\x1d\x89\x07\x05g\x05i\x05k\x05m\x05o\x05q\x05s\x05u\x05w\x05y\x05{\x03\x03;\xc3\x1d\xa3\x07\x05}\x1d\xa7\x07\x05\x7f\x05\x81\x05\x83\x05\x85\x05\x87\x13\x11\x01\x1f%\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d\x89\x1f+\x01\x05\x03\x03\x01\x1f'!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\t\x07\x1f\x1d\x11\x01\x00\x00\x00\x00\x00\x00\x00#\x1f\x03\x07\xc9\xcd\xd1\r\x03\xb7\xcb\x1d\x8b\r\x03\xb7\xcf\x1d\x8d\r\x03\xb7\xd3\x1d\x8f\x1d\x91\x1d\x93\x1f\x03\x11\x03\x00\x00\x00\x00\x00\x00\x00\x1f\x19\x01\x1f\x03\x11\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x03\x00\x00\x00\x0b\x05\x1d\x95\x1d\x97\x05\x01\x03\t\xb3\xb3\xb3\xbf\x03\x03\xed\x15\x03\x01\r\x01\x03\x07\xbf\xf1\xb3\x1f)\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x00\x00\x00\x00\x07\x05\x1f\x1b\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f1!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x03\x11\x00\x00\x00\x00\x00\x00\x00\x00\x07\x0b\x05\x99\x1d\n\x02\x07\x05\x9b\x01\x02\x02)\x01\x11)\x01\x0f)\x03\r\x0f)\x03\x05\x0f)\x01\x17)\x05\r\r\x13\x1b\x1d\x0b\x13\x01)\x03\x01\x0f)\x01\x13)\x03\x05\x11\x11\x01\x07\r\x07\x07)\x03%\x13)\x03\t\x0f)\x03\x01\x15)\x03\t\x15)\x03\x05\x15)\x03\x01\x11)\x05\x05\x05\x17)\x05\r\r\x17)\x03\t\x11\x04f\n\x05\x01\x11\x05A\x07\x03\x01\x05\x19\x11\x05I\x05\x03K\x85\x13\x03U)\x03!\x0f\x06Y\x03\r\x03\x01\x03\x03\x01\r\x03\x03\x03\x03\x01\r\x03\x03\x0b\x06\x01\x03\x05\x03\x05\x0f\x06\x01\x03\t\x03\t\x0b\x06\x01\x03\x05\x03\x07\x0f\x06\x01\x03\t\x03\r\x1b\x07\x01_\x03#\x05\x0b\x0f\x03\x03\x01\r\x03\x03\x0b\x06\x01\x03\x05\x03\x13\x0f\x06\x01\x03\t\x03\x15\x03\x03\x01c\x03\x19\x03\x03\x01\x1f\x03\x03\x03\x03\x01\x19\x03\x05\x03\x03\x01\x19\x03\x05\x1d\x07\x01e\x07\r\x07\x05\t\x1b\x1d\x1f\x03\x03\x03\x01w\x03\x05\x05\x07\x01\x0b\x03\x07\x03'\x1f\x06\x01\x03\x07\x05#)\x03\x03\x01!\x03\x05\x05\x07\x01\x0b\x03\x05\x03-\x07\x07\x01y\x03\x0b\x05%/\x05\x07\x01\x0b\x03-\x031\x03\x03\x01{\x03\x1b\x05\x07\x01\x0b\x03\r\x035\x05\x07\x01}\x03/\x033\r\x06\x01\x03\r\x079!7\x13\x03\x7f)\x03\x07\x03\x03\x83\x13\x03\x03\x03\x03\x87\x13\x03\x03!\x16\x03\t\x03\x03\x07\x07\tA?=+\t\x03\r\x0f\t\x03\x05\x03\x05\x07\x05\x07\x05\x03\x03\x03\r\x03\x03\x07\x07\x06\x02\x11\x03\x0b\x05KS\x11\x04\x03\x03U\x03]\xaf\t\x03\x05\x03\x05\x07\x05\x07\x05\x03\x03\x03\x1f\x03\x03\t\x06\x0f\x03\x03\x05MS\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05MW\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05M[\r\x06\x17\x03\x03\x07Y]M\x0b\x06#\x03\x05\x03_\x05\x07\x1b\x0b\x03\t\x03a\x15\x07=5\x03\x05\x05Qc\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05Mg\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05Mk\r\x06\x17\x03\x03\x07imM\x0b\x06#\x03\x05\x03o\x05\x07\x1b\x0b\x03\t\x03q\x15\x07=5\x03\x05\x05Os\x03\x03\x03!\x03\x05\x07\x07\x15\x11\x03\x0b\x05ew\x03\x03\x03\x19\x03\x05\t\x06\x0f\x03\x05\x05e{\r\x06\x17\x03\x05\x07y}e#\x07\xa1\x9f\x03\t\x05O\x7f\x0f\x06\xa5\x03\x05\x03\x81\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05M\x85\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05M\x89\r\x06\x17\x03\x03\x07\x87\x8bM\x0b\x06#\x03\x05\x03\x8d\x05\x07\x1b\x0b\x03\t\x03\x8f\x17\x17\x1d?\x03\x07\x07O\x91\x83\x05\x03\x05\x07\x05\x05\x05\x05\x05\x11\x04\x1d\x03\xa9\x03\x03\x03!\x03\x05\x07\x07\x15\x11\x03\x0b\x05e\x95\x03\x03\x03\x19\x03\x05\t\x06\x0f\x03\x05\x05e\x99\r\x06\x17\x03\x05\x07\x97\x9be\x05\x07\x1b\x0b\x03\t\x03\x9d\x17\x17\x1d?\x03\x07\x07\x93\x9fu\x05\x03\x05\x07\x05\x05\x05\x05\x05\x11\x04\x1d\x03\xa9\x03\x03\x03\x1f\x03\x03\t\x06\x0f\x03\x03\x05K\xa3\x11\x04\x03\t\xa5U\xa1Q\x11\x04\x05\x07;+G\x06\x03\x01\x05\x01\x00v%\x9dM2\x04\x1d\x03\x0f\x0b\t\t\t!'\x1f;+y\x87.\x04!\x19+\xb1\xb3YMO{\xe9\x8b\x83\x1f/!!)#\x1f\x19\x157\x85\x87\x1f\x1f\x15\x1d\x15\x1b%)\x19'#+\x1b+\x83\x13\r#\x13\x19\x1f\x1f\x11\x17\x15\x11\x15\x17\x15\x17\x0f\x17)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00compare_v1\x00add_v1\x00convert_v1\x00select_v1\x00reshape_v1\x00return_v1\x00iota_v1\x00gather_v1\x00scatter_v1\x00func_v1\x00concatenate_v1\x00custom_call_v1\x00subtract_v1\x00while_v1\x00dynamic_slice_v1\x00value\x00sym_name\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00index_vector_dim\x00indices_are_sorted\x00slice_sizes\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit()/jit(main)/iota[dtype=float64 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jit()/jit(main)/lu\x00dimension\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit()/jit(main)/iota[dtype=int32 shape=(3,) dimension=0]\x00jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]\x00jit()/jit(main)/scan[reverse=False length=3 num_consts=0 num_carry=3 linear=(False, False, False) unroll=1]\x00jit()/jit(main)/while[cond_nconsts=0 body_nconsts=0]\x00jit()/jit(main)/while/body/add\x00jit()/jit(main)/while/body/lt\x00jit()/jit(main)/while/body/select_n\x00jit()/jit(main)/while/body/convert_element_type[new_dtype=int32 weak_type=False]\x00jit()/jit(main)/while/body/broadcast_in_dim[shape=(1,) broadcast_dimensions=()]\x00collapsed_slice_dims\x00offset_dims\x00start_index_map\x00jit()/jit(main)/while/body/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) slice_sizes=(1,) unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]\x00jit()/jit(main)/while/body/dynamic_slice[slice_sizes=(1,)]\x00jit()/jit(main)/while/body/squeeze[dimensions=(0,)]\x00inserted_window_dims\x00scatter_dims_to_operand_dims\x00unique_indices\x00update_window_dims\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_dgetrf\x00jit()/jit(main)/while/body/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]\x00jit()/jit(main)/while/cond/lt\x00", - xla_call_module_version=6, -) # End paste - - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_14['c64'] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_cgetrf'], - serialized_date=datetime.date(2023, 6, 14), - inputs=(), - expected_outputs=(array([[6. +0.j, 7. +0.j, 8. +0.j], - [0. +0.j, 1. +0.j, 2. +0.j], - [0.5+0.j, 0.5+0.j, 0. +0.j]], dtype=complex64), array([2, 2, 2], dtype=int32), array([2, 0, 1], dtype=int32)), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<3x3xcomplex> {jax.result_info = "[0]"}, tensor<3xi32> {jax.result_info = "[1]"}, tensor<3xi32> {jax.result_info = "[2]"}) { - %0 = stablehlo.iota dim = 0 : tensor<9xcomplex> loc(#loc3) - %1 = stablehlo.reshape %0 : (tensor<9xcomplex>) -> tensor<3x3xcomplex> loc(#loc4) - %2 = stablehlo.constant dense<3> : tensor loc(#loc5) - %3 = stablehlo.constant dense<3> : tensor loc(#loc5) - %4 = stablehlo.convert %2 : (tensor) -> tensor loc(#loc5) - %5 = stablehlo.reshape %4 : (tensor) -> tensor<1xi32> loc(#loc5) - %6 = stablehlo.convert %3 : (tensor) -> tensor loc(#loc5) - %7 = stablehlo.reshape %6 : (tensor) -> tensor<1xi32> loc(#loc5) - %8 = stablehlo.concatenate %5, %7, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> loc(#loc5) - %9 = stablehlo.constant dense<3> : tensor loc(#loc5) - %10 = stablehlo.convert %9 : (tensor) -> tensor loc(#loc5) - %11 = stablehlo.reshape %10 : (tensor) -> tensor<1xi32> loc(#loc5) - %12 = stablehlo.constant dense<> : tensor<0xi32> loc(#loc5) - %13 = stablehlo.constant dense<1> : tensor loc(#loc5) - %14 = stablehlo.constant dense<3> : tensor loc(#loc5) - %15 = stablehlo.constant dense<3> : tensor loc(#loc5) - %16:3 = stablehlo.custom_call @lapack_cgetrf(%13, %14, %15, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor<3x3xcomplex>) -> (tensor<3x3xcomplex>, tensor<3xi32>, tensor) loc(#loc5) - %17 = stablehlo.constant dense<1> : tensor loc(#loc5) - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor) -> tensor<3xi32> loc(#loc5) - %19 = stablehlo.subtract %16#1, %18 : tensor<3xi32> loc(#loc5) - %20 = stablehlo.constant dense<0> : tensor loc(#loc5) - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor loc(#loc5) - %22 = stablehlo.compare GE, %16#2, %21, SIGNED : (tensor, tensor) -> tensor loc(#loc5) - %23 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %24 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc5) - %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor>) -> tensor<3x3xcomplex> loc(#loc5) - %26 = stablehlo.broadcast_in_dim %23, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc5) - %27 = stablehlo.select %26, %16#0, %25 : tensor<3x3xi1>, tensor<3x3xcomplex> loc(#loc5) - %28 = stablehlo.iota dim = 0 : tensor<3xi32> loc(#loc6) - %29 = stablehlo.constant dense<0> : tensor loc(#loc7) - %30 = stablehlo.constant dense<0> : tensor loc(#loc8) - %31:4 = stablehlo.while(%iterArg = %30, %iterArg_0 = %29, %iterArg_1 = %28, %iterArg_2 = %19) : tensor, tensor, tensor<3xi32>, tensor<3xi32> - cond { - %32 = stablehlo.constant dense<3> : tensor loc(#loc9) - %33 = stablehlo.compare LT, %iterArg, %32, SIGNED : (tensor, tensor) -> tensor loc(#loc10) - stablehlo.return %33 : tensor loc(#loc9) - } do { - %32 = stablehlo.constant dense<1> : tensor loc(#loc9) - %33 = stablehlo.add %iterArg_0, %32 : tensor loc(#loc11) - %34 = stablehlo.constant dense<0> : tensor loc(#loc9) - %35 = stablehlo.compare LT, %iterArg_0, %34, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %36 = stablehlo.constant dense<3> : tensor loc(#loc9) - %37 = stablehlo.add %iterArg_0, %36 : tensor loc(#loc11) - %38 = stablehlo.select %35, %37, %iterArg_0 : tensor, tensor loc(#loc13) - %39 = stablehlo.convert %38 : (tensor) -> tensor loc(#loc14) - %40 = stablehlo.broadcast_in_dim %39, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %41 = "stablehlo.gather"(%iterArg_2, %40) {dimension_numbers = #stablehlo.gather, indices_are_sorted = true, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1xi32>) -> tensor loc(#loc16) - %42 = stablehlo.constant dense<0> : tensor loc(#loc9) - %43 = stablehlo.compare LT, %iterArg_0, %42, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %44 = stablehlo.constant dense<3> : tensor loc(#loc9) - %45 = stablehlo.add %iterArg_0, %44 : tensor loc(#loc11) - %46 = stablehlo.select %43, %45, %iterArg_0 : tensor, tensor loc(#loc13) - %47 = stablehlo.convert %46 : (tensor) -> tensor loc(#loc14) - %48 = stablehlo.broadcast_in_dim %47, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %49 = "stablehlo.gather"(%iterArg_1, %48) {dimension_numbers = #stablehlo.gather, indices_are_sorted = true, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1xi32>) -> tensor loc(#loc16) - %50 = stablehlo.constant dense<0> : tensor loc(#loc9) - %51 = stablehlo.compare LT, %41, %50, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %52 = stablehlo.constant dense<3> : tensor loc(#loc9) - %53 = stablehlo.add %41, %52 : tensor loc(#loc11) - %54 = stablehlo.select %51, %53, %41 : tensor, tensor loc(#loc13) - %55 = stablehlo.dynamic_slice %iterArg_1, %54, sizes = [1] : (tensor<3xi32>, tensor) -> tensor<1xi32> loc(#loc17) - %56 = stablehlo.reshape %55 : (tensor<1xi32>) -> tensor loc(#loc18) - %57 = stablehlo.constant dense<0> : tensor loc(#loc9) - %58 = stablehlo.compare LT, %iterArg_0, %57, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %59 = stablehlo.constant dense<3> : tensor loc(#loc9) - %60 = stablehlo.add %iterArg_0, %59 : tensor loc(#loc11) - %61 = stablehlo.select %58, %60, %iterArg_0 : tensor, tensor loc(#loc13) - %62 = stablehlo.convert %61 : (tensor) -> tensor loc(#loc14) - %63 = stablehlo.broadcast_in_dim %62, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %64 = "stablehlo.scatter"(%iterArg_1, %63, %56) ({ - ^bb0(%arg0: tensor loc(unknown), %arg1: tensor loc(unknown)): - stablehlo.return %arg1 : tensor loc(#loc19) - }) {indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true} : (tensor<3xi32>, tensor<1xi32>, tensor) -> tensor<3xi32> loc(#loc19) - %65 = stablehlo.constant dense<0> : tensor loc(#loc9) - %66 = stablehlo.compare LT, %41, %65, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %67 = stablehlo.constant dense<3> : tensor loc(#loc9) - %68 = stablehlo.add %41, %67 : tensor loc(#loc11) - %69 = stablehlo.select %66, %68, %41 : tensor, tensor loc(#loc13) - %70 = stablehlo.broadcast_in_dim %69, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %71 = "stablehlo.scatter"(%64, %70, %49) ({ - ^bb0(%arg0: tensor loc(unknown), %arg1: tensor loc(unknown)): - stablehlo.return %arg1 : tensor loc(#loc19) - }) {indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true} : (tensor<3xi32>, tensor<1xi32>, tensor) -> tensor<3xi32> loc(#loc19) - %72 = stablehlo.constant dense<1> : tensor loc(#loc9) - %73 = stablehlo.add %iterArg, %72 : tensor loc(#loc11) - stablehlo.return %73, %33, %71, %iterArg_2 : tensor, tensor, tensor<3xi32>, tensor<3xi32> loc(#loc9) - } loc(#loc9) - return %27, %19, %31#2 : tensor<3x3xcomplex>, tensor<3xi32>, tensor<3xi32> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":553:0) -#loc2 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":554:0) -#loc3 = loc("jit()/jit(main)/iota[dtype=complex64 shape=(9,) dimension=0]"(#loc1)) -#loc4 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc1)) -#loc5 = loc("jit()/jit(main)/lu"(#loc2)) -#loc6 = loc("jit()/jit(main)/iota[dtype=int32 shape=(3,) dimension=0]"(#loc2)) -#loc7 = loc("jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]"(#loc2)) -#loc8 = loc("jit()/jit(main)/scan[reverse=False length=3 num_consts=0 num_carry=3 linear=(False, False, False) unroll=1]"(#loc2)) -#loc9 = loc("jit()/jit(main)/while[cond_nconsts=0 body_nconsts=0]"(#loc2)) -#loc10 = loc("jit()/jit(main)/while/cond/lt"(#loc2)) -#loc11 = loc("jit()/jit(main)/while/body/add"(#loc2)) -#loc12 = loc("jit()/jit(main)/while/body/lt"(#loc2)) -#loc13 = loc("jit()/jit(main)/while/body/select_n"(#loc2)) -#loc14 = loc("jit()/jit(main)/while/body/convert_element_type[new_dtype=int32 weak_type=False]"(#loc2)) -#loc15 = loc("jit()/jit(main)/while/body/broadcast_in_dim[shape=(1,) broadcast_dimensions=()]"(#loc2)) -#loc16 = loc("jit()/jit(main)/while/body/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) slice_sizes=(1,) unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]"(#loc2)) -#loc17 = loc("jit()/jit(main)/while/body/dynamic_slice[slice_sizes=(1,)]"(#loc2)) -#loc18 = loc("jit()/jit(main)/while/body/squeeze[dimensions=(0,)]"(#loc2)) -#loc19 = loc("jit()/jit(main)/while/body/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]"(#loc2)) -""", - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x013\x05\x01\x03\x01\x03\x05\x03#\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!#%'\x03\xaa\x02\x0e\x025\x01\xb1\x0f\x0f\x07\x17\x0b\x13\x13\x0f\x1b\x13\x0f\x0f\x13\x0f\x13\x13\x13\x0f\x0f\x0b\x13\x17\x0b\x0b\x0b\x0b;\x0b\x0b\x0b\x0f;#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0b\x13\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x13\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x03Q\x0f\x0f/\x0b\x0f\x0b\x0bO\x0b/\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b/\x0f/\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17\x17/\x1f\x1f\x0b/O/\x0b\x01\x07\x0b\x13\x0b\x01\x03\x0f\x033\x0f\x0f\x13\x13\x0f\x17\x07\x07\x0b\x07\x07\x13\x0f\x13\x1b\x07\x13\x13\x13\x13\x13\x13\x17\x17\x13\x02b\t\x1d]\x07\x1d\x8b\x07\x1f\x17-\xaa\x08\x01\x05)\x03\x03/\xb9\x03\x03\t\xd9\x1d\x8d\x07\x03\x051\xc13\xff\x03\x03\t\xfd\x1d\x8f\x07\x1d\x91\x07\x03\x03\t\xdf\x1d\x95\x07\x1d\x02\x02\x07\x03\x03\t\xdd\x03\x03\t\xf5\x1d\x93\x07\x11\x01\x05\x05+\x03\x03S\xb1\x17-\xa6\x08\x01\x05-\x05/\x051\x053\x03\r\x97\xb57\xb19\xbb\x99\xb9;\xc3\x9b\xb5\x055\x057\x059\x1d\x9d\x07\x03\r7\xb19\xbb\xa9\xb5\xab\xb5\xad\xbb\xaf\xb9\x03\x07C%E%'G\x05;\x05=\x05?\x03\x0bK\xbdM\xc5O\xc7'\xd5Q\xd7\x05A\x05C\x05E\x05G\x05I\x1dW+\x05K\x1d[+\x05M\x05O\x03\x03a\xb1\x05Q\x03\x03\t\xdb\x03\x11g\xe1i\xe3k\xe5m\xbdo\xe7q\xe9s\xebu\xef\x05S\x05U\x05W\x05Y\x05[\x05]\x05_\x05a\x03\x03\t\xf3\x03\x051\xc13\xf7\x03\x03\t\xf9\x03\x03/\xfb\x1d\x81\x07\x05c\x1d\x85\x07\x05e\x1d\x89\x07\x05g\x05i\x05k\x05m\x05o\x05q\x05s\x05u\x05w\x05y\x05{\x03\x03;\xc3\x1d\xa3\x07\x05}\x1d\xa7\x07\x05\x7f\x05\x81\x05\x83\x05\x85\x05\x87\x13\x11\x01\x1f'\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d\x89\x1f-\x01\x05\x03\x03\x01\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\t\x07\x1f\x1d\x11\x01\x00\x00\x00\x00\x00\x00\x00#\x1f\x03\x07\xc9\xcd\xd1\r\x03\xb7\xcb\x1d\x8b\r\x03\xb7\xcf\x1d\x8d\r\x03\xb7\xd3\x1d\x8f\x1d\x91\x1d\x93\x1f\x03\x11\x03\x00\x00\x00\x00\x00\x00\x00\x1f\x19\x01\x1f\x03\x11\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x03\x00\x00\x00\x0b\x05\x1d\x95\x1d\x97\x05\x01\x03\t\xb3\xb3\xb3\xbf\x03\x03\xed\x15\x03\x01\r\x01\x03\x07\xbf\xf1\xb3\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x00\x00\x00\x00\x07\x05\x1f\x1b\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f3!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x03\x11\x00\x00\x00\x00\x00\x00\x00\x00\x07\x0b\x05\x99\x1d\n\x02\x07\x05\x9b\x01\x02\x02)\x01\x11)\x01\x0f)\x03\r\x0f)\x03\x05\x0f)\x01\x17)\x05\r\r\x13\x1b\x1d\x03!\x13\x01)\x03\x01\x0f)\x01\x13)\x03\x05\x11\x11\x01\x07\r\x07\x07\t)\x03%\x13)\x03\t\x0f)\x03\x01\x15)\x03\t\x15)\x03\x05\x15)\x03\x01\x11)\x05\x05\x05\x17)\x05\r\r\x17)\x03\t\x11\x04f\n\x05\x01\x11\x05A\x07\x03\x01\x05\x19\x11\x05I\x05\x03K\x85\x13\x03U)\x03#\x0f\x06Y\x03\r\x03\x01\x03\x03\x01\r\x03\x03\x03\x03\x01\r\x03\x03\x0b\x06\x01\x03\x05\x03\x05\x0f\x06\x01\x03\t\x03\t\x0b\x06\x01\x03\x05\x03\x07\x0f\x06\x01\x03\t\x03\r\x1b\x07\x01_\x03%\x05\x0b\x0f\x03\x03\x01\r\x03\x03\x0b\x06\x01\x03\x05\x03\x13\x0f\x06\x01\x03\t\x03\x15\x03\x03\x01c\x03\x19\x03\x03\x01\x1f\x03\x03\x03\x03\x01\x19\x03\x05\x03\x03\x01\x19\x03\x05\x1d\x07\x01e\x07\r\x07\x05\t\x1b\x1d\x1f\x03\x03\x03\x01w\x03\x05\x05\x07\x01\x0b\x03\x07\x03'\x1f\x06\x01\x03\x07\x05#)\x03\x03\x01!\x03\x05\x05\x07\x01\x0b\x03\x05\x03-\x07\x07\x01y\x03\x0b\x05%/\x05\x07\x01\x0b\x03/\x031\x03\x03\x01{\x03\x1b\x05\x07\x01\x0b\x03\r\x035\x05\x07\x01}\x031\x033\r\x06\x01\x03\r\x079!7\x13\x03\x7f)\x03\x07\x03\x03\x83\x13\x03\x03\x03\x03\x87\x13\x03\x03!\x16\x03\t\x03\x03\x07\x07\tA?=+\t\x03\r\x0f\t\x03\x05\x03\x05\x07\x05\x07\x05\x03\x03\x03\r\x03\x03\x07\x07\x06\x02\x11\x03\x0b\x05KS\x11\x04\x03\x03U\x03]\xaf\t\x03\x05\x03\x05\x07\x05\x07\x05\x03\x03\x03\x1f\x03\x03\t\x06\x0f\x03\x03\x05MS\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05MW\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05M[\r\x06\x17\x03\x03\x07Y]M\x0b\x06#\x03\x05\x03_\x05\x07\x1b\x0b\x03\t\x03a\x15\x07=5\x03\x05\x05Qc\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05Mg\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05Mk\r\x06\x17\x03\x03\x07imM\x0b\x06#\x03\x05\x03o\x05\x07\x1b\x0b\x03\t\x03q\x15\x07=5\x03\x05\x05Os\x03\x03\x03!\x03\x05\x07\x07\x15\x11\x03\x0b\x05ew\x03\x03\x03\x19\x03\x05\t\x06\x0f\x03\x05\x05e{\r\x06\x17\x03\x05\x07y}e#\x07\xa1\x9f\x03\t\x05O\x7f\x0f\x06\xa5\x03\x05\x03\x81\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05M\x85\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05M\x89\r\x06\x17\x03\x03\x07\x87\x8bM\x0b\x06#\x03\x05\x03\x8d\x05\x07\x1b\x0b\x03\t\x03\x8f\x17\x17\x1d?\x03\x07\x07O\x91\x83\x05\x03\x05\x07\x05\x05\x05\x05\x05\x11\x04\x1d\x03\xa9\x03\x03\x03!\x03\x05\x07\x07\x15\x11\x03\x0b\x05e\x95\x03\x03\x03\x19\x03\x05\t\x06\x0f\x03\x05\x05e\x99\r\x06\x17\x03\x05\x07\x97\x9be\x05\x07\x1b\x0b\x03\t\x03\x9d\x17\x17\x1d?\x03\x07\x07\x93\x9fu\x05\x03\x05\x07\x05\x05\x05\x05\x05\x11\x04\x1d\x03\xa9\x03\x03\x03\x1f\x03\x03\t\x06\x0f\x03\x03\x05K\xa3\x11\x04\x03\t\xa5U\xa1Q\x11\x04\x05\x07;+G\x06\x03\x01\x05\x01\x00~%\x9dM2\x04\x1d\x03\x0f\x0b\t\t\t!'\x1f;+y\x87.\x04!\x19+\xb1\xb3YMO{\xe9\x8b\x83\x1f/!!)#\x1f\x19\x157\x85\x8b\x1f\x1f\x15\x1d\x15\x1b%)\x19'#+\x1b+\x83\x13\r#\x13\x19\x1f\x1f\x11\x17\x15\x11\x15\x17\x15\x17\x0f\x17)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00compare_v1\x00add_v1\x00convert_v1\x00select_v1\x00reshape_v1\x00return_v1\x00iota_v1\x00gather_v1\x00scatter_v1\x00func_v1\x00concatenate_v1\x00custom_call_v1\x00subtract_v1\x00while_v1\x00dynamic_slice_v1\x00value\x00sym_name\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00index_vector_dim\x00indices_are_sorted\x00slice_sizes\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit()/jit(main)/iota[dtype=complex64 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jit()/jit(main)/lu\x00dimension\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit()/jit(main)/iota[dtype=int32 shape=(3,) dimension=0]\x00jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]\x00jit()/jit(main)/scan[reverse=False length=3 num_consts=0 num_carry=3 linear=(False, False, False) unroll=1]\x00jit()/jit(main)/while[cond_nconsts=0 body_nconsts=0]\x00jit()/jit(main)/while/body/add\x00jit()/jit(main)/while/body/lt\x00jit()/jit(main)/while/body/select_n\x00jit()/jit(main)/while/body/convert_element_type[new_dtype=int32 weak_type=False]\x00jit()/jit(main)/while/body/broadcast_in_dim[shape=(1,) broadcast_dimensions=()]\x00collapsed_slice_dims\x00offset_dims\x00start_index_map\x00jit()/jit(main)/while/body/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) slice_sizes=(1,) unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]\x00jit()/jit(main)/while/body/dynamic_slice[slice_sizes=(1,)]\x00jit()/jit(main)/while/body/squeeze[dimensions=(0,)]\x00inserted_window_dims\x00scatter_dims_to_operand_dims\x00unique_indices\x00update_window_dims\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_cgetrf\x00jit()/jit(main)/while/body/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]\x00jit()/jit(main)/while/cond/lt\x00", - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_14['c128'] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_zgetrf'], - serialized_date=datetime.date(2023, 6, 14), - inputs=(), - expected_outputs=(array([[6. +0.j, 7. +0.j, 8. +0.j], - [0. +0.j, 1. +0.j, 2. +0.j], - [0.5+0.j, 0.5+0.j, 0. +0.j]]), array([2, 2, 2], dtype=int32), array([2, 0, 1], dtype=int32)), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<3x3xcomplex> {jax.result_info = "[0]"}, tensor<3xi32> {jax.result_info = "[1]"}, tensor<3xi32> {jax.result_info = "[2]"}) { - %0 = stablehlo.iota dim = 0 : tensor<9xcomplex> loc(#loc3) - %1 = stablehlo.reshape %0 : (tensor<9xcomplex>) -> tensor<3x3xcomplex> loc(#loc4) - %2 = stablehlo.constant dense<3> : tensor loc(#loc5) - %3 = stablehlo.constant dense<3> : tensor loc(#loc5) - %4 = stablehlo.convert %2 : (tensor) -> tensor loc(#loc5) - %5 = stablehlo.reshape %4 : (tensor) -> tensor<1xi32> loc(#loc5) - %6 = stablehlo.convert %3 : (tensor) -> tensor loc(#loc5) - %7 = stablehlo.reshape %6 : (tensor) -> tensor<1xi32> loc(#loc5) - %8 = stablehlo.concatenate %5, %7, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> loc(#loc5) - %9 = stablehlo.constant dense<3> : tensor loc(#loc5) - %10 = stablehlo.convert %9 : (tensor) -> tensor loc(#loc5) - %11 = stablehlo.reshape %10 : (tensor) -> tensor<1xi32> loc(#loc5) - %12 = stablehlo.constant dense<> : tensor<0xi32> loc(#loc5) - %13 = stablehlo.constant dense<1> : tensor loc(#loc5) - %14 = stablehlo.constant dense<3> : tensor loc(#loc5) - %15 = stablehlo.constant dense<3> : tensor loc(#loc5) - %16:3 = stablehlo.custom_call @lapack_zgetrf(%13, %14, %15, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor<3x3xcomplex>) -> (tensor<3x3xcomplex>, tensor<3xi32>, tensor) loc(#loc5) - %17 = stablehlo.constant dense<1> : tensor loc(#loc5) - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor) -> tensor<3xi32> loc(#loc5) - %19 = stablehlo.subtract %16#1, %18 : tensor<3xi32> loc(#loc5) - %20 = stablehlo.constant dense<0> : tensor loc(#loc5) - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor loc(#loc5) - %22 = stablehlo.compare GE, %16#2, %21, SIGNED : (tensor, tensor) -> tensor loc(#loc5) - %23 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %24 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc5) - %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor>) -> tensor<3x3xcomplex> loc(#loc5) - %26 = stablehlo.broadcast_in_dim %23, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc5) - %27 = stablehlo.select %26, %16#0, %25 : tensor<3x3xi1>, tensor<3x3xcomplex> loc(#loc5) - %28 = stablehlo.iota dim = 0 : tensor<3xi32> loc(#loc6) - %29 = stablehlo.constant dense<0> : tensor loc(#loc7) - %30 = stablehlo.constant dense<0> : tensor loc(#loc8) - %31:4 = stablehlo.while(%iterArg = %30, %iterArg_0 = %29, %iterArg_1 = %28, %iterArg_2 = %19) : tensor, tensor, tensor<3xi32>, tensor<3xi32> - cond { - %32 = stablehlo.constant dense<3> : tensor loc(#loc9) - %33 = stablehlo.compare LT, %iterArg, %32, SIGNED : (tensor, tensor) -> tensor loc(#loc10) - stablehlo.return %33 : tensor loc(#loc9) - } do { - %32 = stablehlo.constant dense<1> : tensor loc(#loc9) - %33 = stablehlo.add %iterArg_0, %32 : tensor loc(#loc11) - %34 = stablehlo.constant dense<0> : tensor loc(#loc9) - %35 = stablehlo.compare LT, %iterArg_0, %34, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %36 = stablehlo.constant dense<3> : tensor loc(#loc9) - %37 = stablehlo.add %iterArg_0, %36 : tensor loc(#loc11) - %38 = stablehlo.select %35, %37, %iterArg_0 : tensor, tensor loc(#loc13) - %39 = stablehlo.convert %38 : (tensor) -> tensor loc(#loc14) - %40 = stablehlo.broadcast_in_dim %39, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %41 = "stablehlo.gather"(%iterArg_2, %40) {dimension_numbers = #stablehlo.gather, indices_are_sorted = true, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1xi32>) -> tensor loc(#loc16) - %42 = stablehlo.constant dense<0> : tensor loc(#loc9) - %43 = stablehlo.compare LT, %iterArg_0, %42, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %44 = stablehlo.constant dense<3> : tensor loc(#loc9) - %45 = stablehlo.add %iterArg_0, %44 : tensor loc(#loc11) - %46 = stablehlo.select %43, %45, %iterArg_0 : tensor, tensor loc(#loc13) - %47 = stablehlo.convert %46 : (tensor) -> tensor loc(#loc14) - %48 = stablehlo.broadcast_in_dim %47, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %49 = "stablehlo.gather"(%iterArg_1, %48) {dimension_numbers = #stablehlo.gather, indices_are_sorted = true, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1xi32>) -> tensor loc(#loc16) - %50 = stablehlo.constant dense<0> : tensor loc(#loc9) - %51 = stablehlo.compare LT, %41, %50, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %52 = stablehlo.constant dense<3> : tensor loc(#loc9) - %53 = stablehlo.add %41, %52 : tensor loc(#loc11) - %54 = stablehlo.select %51, %53, %41 : tensor, tensor loc(#loc13) - %55 = stablehlo.dynamic_slice %iterArg_1, %54, sizes = [1] : (tensor<3xi32>, tensor) -> tensor<1xi32> loc(#loc17) - %56 = stablehlo.reshape %55 : (tensor<1xi32>) -> tensor loc(#loc18) - %57 = stablehlo.constant dense<0> : tensor loc(#loc9) - %58 = stablehlo.compare LT, %iterArg_0, %57, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %59 = stablehlo.constant dense<3> : tensor loc(#loc9) - %60 = stablehlo.add %iterArg_0, %59 : tensor loc(#loc11) - %61 = stablehlo.select %58, %60, %iterArg_0 : tensor, tensor loc(#loc13) - %62 = stablehlo.convert %61 : (tensor) -> tensor loc(#loc14) - %63 = stablehlo.broadcast_in_dim %62, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %64 = "stablehlo.scatter"(%iterArg_1, %63, %56) ({ - ^bb0(%arg0: tensor loc(unknown), %arg1: tensor loc(unknown)): - stablehlo.return %arg1 : tensor loc(#loc19) - }) {indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true} : (tensor<3xi32>, tensor<1xi32>, tensor) -> tensor<3xi32> loc(#loc19) - %65 = stablehlo.constant dense<0> : tensor loc(#loc9) - %66 = stablehlo.compare LT, %41, %65, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %67 = stablehlo.constant dense<3> : tensor loc(#loc9) - %68 = stablehlo.add %41, %67 : tensor loc(#loc11) - %69 = stablehlo.select %66, %68, %41 : tensor, tensor loc(#loc13) - %70 = stablehlo.broadcast_in_dim %69, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %71 = "stablehlo.scatter"(%64, %70, %49) ({ - ^bb0(%arg0: tensor loc(unknown), %arg1: tensor loc(unknown)): - stablehlo.return %arg1 : tensor loc(#loc19) - }) {indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true} : (tensor<3xi32>, tensor<1xi32>, tensor) -> tensor<3xi32> loc(#loc19) - %72 = stablehlo.constant dense<1> : tensor loc(#loc9) - %73 = stablehlo.add %iterArg, %72 : tensor loc(#loc11) - stablehlo.return %73, %33, %71, %iterArg_2 : tensor, tensor, tensor<3xi32>, tensor<3xi32> loc(#loc9) - } loc(#loc9) - return %27, %19, %31#2 : tensor<3x3xcomplex>, tensor<3xi32>, tensor<3xi32> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":553:0) -#loc2 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":554:0) -#loc3 = loc("jit()/jit(main)/iota[dtype=complex128 shape=(9,) dimension=0]"(#loc1)) -#loc4 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc1)) -#loc5 = loc("jit()/jit(main)/lu"(#loc2)) -#loc6 = loc("jit()/jit(main)/iota[dtype=int32 shape=(3,) dimension=0]"(#loc2)) -#loc7 = loc("jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]"(#loc2)) -#loc8 = loc("jit()/jit(main)/scan[reverse=False length=3 num_consts=0 num_carry=3 linear=(False, False, False) unroll=1]"(#loc2)) -#loc9 = loc("jit()/jit(main)/while[cond_nconsts=0 body_nconsts=0]"(#loc2)) -#loc10 = loc("jit()/jit(main)/while/cond/lt"(#loc2)) -#loc11 = loc("jit()/jit(main)/while/body/add"(#loc2)) -#loc12 = loc("jit()/jit(main)/while/body/lt"(#loc2)) -#loc13 = loc("jit()/jit(main)/while/body/select_n"(#loc2)) -#loc14 = loc("jit()/jit(main)/while/body/convert_element_type[new_dtype=int32 weak_type=False]"(#loc2)) -#loc15 = loc("jit()/jit(main)/while/body/broadcast_in_dim[shape=(1,) broadcast_dimensions=()]"(#loc2)) -#loc16 = loc("jit()/jit(main)/while/body/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) slice_sizes=(1,) unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]"(#loc2)) -#loc17 = loc("jit()/jit(main)/while/body/dynamic_slice[slice_sizes=(1,)]"(#loc2)) -#loc18 = loc("jit()/jit(main)/while/body/squeeze[dimensions=(0,)]"(#loc2)) -#loc19 = loc("jit()/jit(main)/while/body/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]"(#loc2)) -""", - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x013\x05\x01\x03\x01\x03\x05\x03#\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!#%'\x03\xaa\x02\x0e\x025\x01\xb1\x0f\x0f\x07\x17\x0b\x13\x13\x0f\x1b\x13\x0f\x0f\x13\x0f\x13\x13\x13\x0f\x0f\x0b\x13\x17\x0b\x0b\x0b\x0b;\x0b\x0b\x0b\x0f;#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0b\x13\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x13\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x03Q\x0f\x0f/\x0b\x0f\x0b\x0bO\x0b/\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b/\x0f/\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17\x17/\x1f\x1f\x0bOO/\x0b\x01\x07\x0b\x13\x0b\x01\x03\x0f\x033\x0f\x0f\x13\x13\x0f\x17\x07\x07\x0b\x07\x07\x13\x0f\x13\x1b\x07\x13\x13\x13\x13\x13\x13\x17\x17\x13\x02\x82\t\x1d]\x07\x1d\x8b\x07\x1f\x17-\xaa\x08\x01\x05)\x03\x03/\xb9\x03\x03\t\xd9\x1d\x8d\x07\x03\x051\xc13\xff\x03\x03\t\xfd\x1d\x8f\x07\x1d\x91\x07\x03\x03\t\xdf\x1d\x95\x07\x1d\x02\x02\x07\x03\x03\t\xdd\x03\x03\t\xf5\x1d\x93\x07\x11\x01\x05\x05+\x03\x03S\xb1\x17-\xa6\x08\x01\x05-\x05/\x051\x053\x03\r\x97\xb57\xb19\xbb\x99\xb9;\xc3\x9b\xb5\x055\x057\x059\x1d\x9d\x07\x03\r7\xb19\xbb\xa9\xb5\xab\xb5\xad\xbb\xaf\xb9\x03\x07C%E%'G\x05;\x05=\x05?\x03\x0bK\xbdM\xc5O\xc7'\xd5Q\xd7\x05A\x05C\x05E\x05G\x05I\x1dW+\x05K\x1d[+\x05M\x05O\x03\x03a\xb1\x05Q\x03\x03\t\xdb\x03\x11g\xe1i\xe3k\xe5m\xbdo\xe7q\xe9s\xebu\xef\x05S\x05U\x05W\x05Y\x05[\x05]\x05_\x05a\x03\x03\t\xf3\x03\x051\xc13\xf7\x03\x03\t\xf9\x03\x03/\xfb\x1d\x81\x07\x05c\x1d\x85\x07\x05e\x1d\x89\x07\x05g\x05i\x05k\x05m\x05o\x05q\x05s\x05u\x05w\x05y\x05{\x03\x03;\xc3\x1d\xa3\x07\x05}\x1d\xa7\x07\x05\x7f\x05\x81\x05\x83\x05\x85\x05\x87\x13\x11\x01\x1f'\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d\x89\x1f-\x01\x05\x03\x03\x01\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\t\x07\x1f\x1d\x11\x01\x00\x00\x00\x00\x00\x00\x00#\x1f\x03\x07\xc9\xcd\xd1\r\x03\xb7\xcb\x1d\x8b\r\x03\xb7\xcf\x1d\x8d\r\x03\xb7\xd3\x1d\x8f\x1d\x91\x1d\x93\x1f\x03\x11\x03\x00\x00\x00\x00\x00\x00\x00\x1f\x19\x01\x1f\x03\x11\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x03\x00\x00\x00\x0b\x05\x1d\x95\x1d\x97\x05\x01\x03\t\xb3\xb3\xb3\xbf\x03\x03\xed\x15\x03\x01\r\x01\x03\x07\xbf\xf1\xb3\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x00\x00\x00\x00\x07\x05\x1f\x1b!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f3!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x03\x11\x00\x00\x00\x00\x00\x00\x00\x00\x07\x0b\x05\x99\x1d\n\x02\x07\x05\x9b\x01\x02\x02)\x01\x11)\x01\x0f)\x03\r\x0f)\x03\x05\x0f)\x01\x17)\x05\r\r\x13\x1b\x1d\x03!\x13\x01)\x03\x01\x0f)\x01\x13)\x03\x05\x11\x11\x01\x07\r\x07\x07\x0b)\x03%\x13)\x03\t\x0f)\x03\x01\x15)\x03\t\x15)\x03\x05\x15)\x03\x01\x11)\x05\x05\x05\x17)\x05\r\r\x17)\x03\t\x11\x04f\n\x05\x01\x11\x05A\x07\x03\x01\x05\x19\x11\x05I\x05\x03K\x85\x13\x03U)\x03#\x0f\x06Y\x03\r\x03\x01\x03\x03\x01\r\x03\x03\x03\x03\x01\r\x03\x03\x0b\x06\x01\x03\x05\x03\x05\x0f\x06\x01\x03\t\x03\t\x0b\x06\x01\x03\x05\x03\x07\x0f\x06\x01\x03\t\x03\r\x1b\x07\x01_\x03%\x05\x0b\x0f\x03\x03\x01\r\x03\x03\x0b\x06\x01\x03\x05\x03\x13\x0f\x06\x01\x03\t\x03\x15\x03\x03\x01c\x03\x19\x03\x03\x01\x1f\x03\x03\x03\x03\x01\x19\x03\x05\x03\x03\x01\x19\x03\x05\x1d\x07\x01e\x07\r\x07\x05\t\x1b\x1d\x1f\x03\x03\x03\x01w\x03\x05\x05\x07\x01\x0b\x03\x07\x03'\x1f\x06\x01\x03\x07\x05#)\x03\x03\x01!\x03\x05\x05\x07\x01\x0b\x03\x05\x03-\x07\x07\x01y\x03\x0b\x05%/\x05\x07\x01\x0b\x03/\x031\x03\x03\x01{\x03\x1b\x05\x07\x01\x0b\x03\r\x035\x05\x07\x01}\x031\x033\r\x06\x01\x03\r\x079!7\x13\x03\x7f)\x03\x07\x03\x03\x83\x13\x03\x03\x03\x03\x87\x13\x03\x03!\x16\x03\t\x03\x03\x07\x07\tA?=+\t\x03\r\x0f\t\x03\x05\x03\x05\x07\x05\x07\x05\x03\x03\x03\r\x03\x03\x07\x07\x06\x02\x11\x03\x0b\x05KS\x11\x04\x03\x03U\x03]\xaf\t\x03\x05\x03\x05\x07\x05\x07\x05\x03\x03\x03\x1f\x03\x03\t\x06\x0f\x03\x03\x05MS\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05MW\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05M[\r\x06\x17\x03\x03\x07Y]M\x0b\x06#\x03\x05\x03_\x05\x07\x1b\x0b\x03\t\x03a\x15\x07=5\x03\x05\x05Qc\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05Mg\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05Mk\r\x06\x17\x03\x03\x07imM\x0b\x06#\x03\x05\x03o\x05\x07\x1b\x0b\x03\t\x03q\x15\x07=5\x03\x05\x05Os\x03\x03\x03!\x03\x05\x07\x07\x15\x11\x03\x0b\x05ew\x03\x03\x03\x19\x03\x05\t\x06\x0f\x03\x05\x05e{\r\x06\x17\x03\x05\x07y}e#\x07\xa1\x9f\x03\t\x05O\x7f\x0f\x06\xa5\x03\x05\x03\x81\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05M\x85\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05M\x89\r\x06\x17\x03\x03\x07\x87\x8bM\x0b\x06#\x03\x05\x03\x8d\x05\x07\x1b\x0b\x03\t\x03\x8f\x17\x17\x1d?\x03\x07\x07O\x91\x83\x05\x03\x05\x07\x05\x05\x05\x05\x05\x11\x04\x1d\x03\xa9\x03\x03\x03!\x03\x05\x07\x07\x15\x11\x03\x0b\x05e\x95\x03\x03\x03\x19\x03\x05\t\x06\x0f\x03\x05\x05e\x99\r\x06\x17\x03\x05\x07\x97\x9be\x05\x07\x1b\x0b\x03\t\x03\x9d\x17\x17\x1d?\x03\x07\x07\x93\x9fu\x05\x03\x05\x07\x05\x05\x05\x05\x05\x11\x04\x1d\x03\xa9\x03\x03\x03\x1f\x03\x03\t\x06\x0f\x03\x03\x05K\xa3\x11\x04\x03\t\xa5U\xa1Q\x11\x04\x05\x07;+G\x06\x03\x01\x05\x01\x00\x82%\x9dM2\x04\x1d\x03\x0f\x0b\t\t\t!'\x1f;+y\x87.\x04!\x19+\xb1\xb3YMO{\xe9\x8b\x83\x1f/!!)#\x1f\x19\x157\x85\x8d\x1f\x1f\x15\x1d\x15\x1b%)\x19'#+\x1b+\x83\x13\r#\x13\x19\x1f\x1f\x11\x17\x15\x11\x15\x17\x15\x17\x0f\x17)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00compare_v1\x00add_v1\x00convert_v1\x00select_v1\x00reshape_v1\x00return_v1\x00iota_v1\x00gather_v1\x00scatter_v1\x00func_v1\x00concatenate_v1\x00custom_call_v1\x00subtract_v1\x00while_v1\x00dynamic_slice_v1\x00value\x00sym_name\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00index_vector_dim\x00indices_are_sorted\x00slice_sizes\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit()/jit(main)/iota[dtype=complex128 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jit()/jit(main)/lu\x00dimension\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit()/jit(main)/iota[dtype=int32 shape=(3,) dimension=0]\x00jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]\x00jit()/jit(main)/scan[reverse=False length=3 num_consts=0 num_carry=3 linear=(False, False, False) unroll=1]\x00jit()/jit(main)/while[cond_nconsts=0 body_nconsts=0]\x00jit()/jit(main)/while/body/add\x00jit()/jit(main)/while/body/lt\x00jit()/jit(main)/while/body/select_n\x00jit()/jit(main)/while/body/convert_element_type[new_dtype=int32 weak_type=False]\x00jit()/jit(main)/while/body/broadcast_in_dim[shape=(1,) broadcast_dimensions=()]\x00collapsed_slice_dims\x00offset_dims\x00start_index_map\x00jit()/jit(main)/while/body/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) slice_sizes=(1,) unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]\x00jit()/jit(main)/while/body/dynamic_slice[slice_sizes=(1,)]\x00jit()/jit(main)/while/body/squeeze[dimensions=(0,)]\x00inserted_window_dims\x00scatter_dims_to_operand_dims\x00unique_indices\x00update_window_dims\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_zgetrf\x00jit()/jit(main)/while/body/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]\x00jit()/jit(main)/while/cond/lt\x00", - xla_call_module_version=6, -) # End paste +array = np.array +int32 = np.int32 +float32 = np.float32 +complex64 = np.complex64 data_2024_05_31 = {} - # Pasted from the test output (see export_back_compat_test_util.py module docstring) data_2024_05_31["c128"] = dict( testdata_version=1, diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_qr_lapack_geqrf.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_qr_lapack_geqrf.py index 94314a7ae518..2ebe6bea0360 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_qr_lapack_geqrf.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_qr_lapack_geqrf.py @@ -15,261 +15,19 @@ # ruff: noqa import datetime -from numpy import array, float32, complex64 +import numpy as np -data_2023_03_17 = {} +array = np.array +float32 = np.float32 +complex64 = np.complex64 -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_03_17["f32"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_sgeqrf', 'lapack_sorgqr'], - serialized_date=datetime.date(2023, 3, 17), - inputs=(), - expected_outputs=(array([[ 0. , 0.91287076, 0.4082487 ], - [-0.44721356, 0.36514866, -0.8164965 ], - [-0.8944271 , -0.18257445, 0.40824816]], dtype=float32), array([[-6.7082043e+00, -8.0498438e+00, -9.3914852e+00], - [ 0.0000000e+00, 1.0954441e+00, 2.1908894e+00], - [ 0.0000000e+00, 0.0000000e+00, 7.1525574e-07]], dtype=float32)), - mlir_module_text=r""" -module @jit__lambda_ { - func.func public @main() -> (tensor<3x3xf32> {jax.result_info = "[0]"}, tensor<3x3xf32> {jax.result_info = "[1]"}) { - %0 = stablehlo.iota dim = 0 : tensor<9xf32> - %1 = stablehlo.reshape %0 : (tensor<9xf32>) -> tensor<3x3xf32> - %2 = stablehlo.constant dense<1> : tensor - %3 = stablehlo.constant dense<3> : tensor - %4 = stablehlo.constant dense<3> : tensor - %5 = stablehlo.constant dense<96> : tensor - %6 = stablehlo.custom_call @lapack_sgeqrf(%2, %3, %4, %5, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor<3x3xf32>) -> tuple, tensor<3xf32>, tensor, tensor<96xf32>> - %7 = stablehlo.get_tuple_element %6[0] : (tuple, tensor<3xf32>, tensor, tensor<96xf32>>) -> tensor<3x3xf32> - %8 = stablehlo.get_tuple_element %6[1] : (tuple, tensor<3xf32>, tensor, tensor<96xf32>>) -> tensor<3xf32> - %9 = stablehlo.get_tuple_element %6[2] : (tuple, tensor<3xf32>, tensor, tensor<96xf32>>) -> tensor - %10 = stablehlo.get_tuple_element %6[3] : (tuple, tensor<3xf32>, tensor, tensor<96xf32>>) -> tensor<96xf32> - %11 = stablehlo.constant dense<0> : tensor - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor) -> tensor - %13 = stablehlo.compare EQ, %9, %12, SIGNED : (tensor, tensor) -> tensor - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1x1xi1> - %15 = stablehlo.constant dense<0x7FC00000> : tensor - %16 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor) -> tensor<3x3xf32> - %17 = stablehlo.broadcast_in_dim %14, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> - %18 = stablehlo.select %17, %7, %16 : tensor<3x3xi1>, tensor<3x3xf32> - %19 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi1> - %20 = stablehlo.constant dense<0x7FC00000> : tensor - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<3xf32> - %22 = stablehlo.broadcast_in_dim %19, dims = [0] : (tensor<1xi1>) -> tensor<3xi1> - %23 = stablehlo.select %22, %8, %21 : tensor<3xi1>, tensor<3xf32> - %24 = stablehlo.constant dense<0.000000e+00> : tensor - %25 = stablehlo.pad %18, %24, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xf32>, tensor) -> tensor<3x3xf32> - %26 = stablehlo.constant dense<1> : tensor - %27 = stablehlo.constant dense<3> : tensor - %28 = stablehlo.constant dense<3> : tensor - %29 = stablehlo.constant dense<3> : tensor - %30 = stablehlo.constant dense<96> : tensor - %31 = stablehlo.custom_call @lapack_sorgqr(%26, %27, %28, %29, %30, %25, %23) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<3x3xf32>, tensor<3xf32>) -> tuple, tensor, tensor<96xf32>> - %32 = stablehlo.get_tuple_element %31[0] : (tuple, tensor, tensor<96xf32>>) -> tensor<3x3xf32> - %33 = stablehlo.get_tuple_element %31[1] : (tuple, tensor, tensor<96xf32>>) -> tensor - %34 = stablehlo.get_tuple_element %31[2] : (tuple, tensor, tensor<96xf32>>) -> tensor<96xf32> - %35 = stablehlo.constant dense<0> : tensor - %36 = stablehlo.broadcast_in_dim %35, dims = [] : (tensor) -> tensor - %37 = stablehlo.compare EQ, %33, %36, SIGNED : (tensor, tensor) -> tensor - %38 = stablehlo.broadcast_in_dim %37, dims = [] : (tensor) -> tensor<1x1xi1> - %39 = stablehlo.constant dense<0x7FC00000> : tensor - %40 = stablehlo.broadcast_in_dim %39, dims = [] : (tensor) -> tensor<3x3xf32> - %41 = stablehlo.broadcast_in_dim %38, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> - %42 = stablehlo.select %41, %32, %40 : tensor<3x3xi1>, tensor<3x3xf32> - %43 = call @triu(%18) : (tensor<3x3xf32>) -> tensor<3x3xf32> - return %42, %43 : tensor<3x3xf32>, tensor<3x3xf32> - } - func.func private @triu(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> - %1 = stablehlo.constant dense<-1> : tensor - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<3x3xi32> - %3 = stablehlo.add %0, %2 : tensor<3x3xi32> - %4 = stablehlo.iota dim = 1 : tensor<3x3xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> - %6 = stablehlo.constant dense<0.000000e+00> : tensor - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<3x3xf32> - %8 = stablehlo.select %5, %7, %arg0 : tensor<3x3xi1>, tensor<3x3xf32> - return %8 : tensor<3x3xf32> - } -} -""", - mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01+\x05\x01\x05\x01\x03\x05\x03\x1b\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f\x03\xa2\x02\n\x027\x01\x9b\x0f\x0f\x17\x13\x0b\x0f\x13\x07\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x13\x17\x13\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x1b\x13\x13\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0bK\x13\x13\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0bK\x03g\x0fO/\x0b/\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b\x1f\x1f\x1f\x1f\x0b\x1f\x0f\x17\x1b\x0f\x0f\x0f\x0f\x1f\x0b\x1fO/\x0b'\x0f\x17\x17\x01\x05\x17\x0b\x037\x0f\x17\x0f\x07\x07\x07\x07\x17\x13\x17\x17\x07\x0f\x17\x13\x17\x17\x13\x13\x1b\x13\x13\x13\x13\x13\x13\x17\x02\xae\t\x1d\x7f\x05\x1d\x97\x05\x17!\xee\x05\x01\x03\x03\x15\xcd\x05!\x1dY\x05\x03\x03\t\xd7\x1f\x05#\x05%\x05'\x03\x03\t\xf1\x05)\x05+\x05-\x05/\x051\x03\x03%\xc9\x053\x1da\x05\x055\x057\x03\x03\t\xd3\x17!\xea\x05\x01\x03\x03\t\xd5\x03\x03\t\xd9\x059\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x03\x03\x11\xe5\x03\x03\x11\xe7\x03\x03\x11\xe9\x03\x03\t\xed\x03\x05)\xab+\xef\x03\x03\x15\xf3\x03\x03\x13S\x05I\x03\x0b\x19\xa1\x1b\xb3\x1d\xb5\x13\xbf\x1f\xc1\x03\x0b\x19\xa7\x1b\xc5\x1d\xa7\x13\xa9\x1f\xc7\x05K\x1d]\x05\x05M\x03\x03\t\xcb\x05O\x03\x03%\xcf\x1dg\x05\x05Q\x03\x05)\xab+\xd1\x1dm\x05\x05S\x1dq\x05\x05U\x1du\x05\x05W\x1dy/\x05Y\x1d}/\x05[\x05]\x03\x115\xad7\xaf9\xdb;\xa1=\xb1?\xddA\xdfC\xe3\x03\x03\x11\xeb\x03\x03\x15\xf5\x1d\x89\x05\x05_\x03\x07\x8d\xa3\x8f\xa3\x91\xa3\x05a\x05c\x05e\x1d\x95\x05\x05g\x05i\x03\x115\xad7\xaf9\xf7;\xa1=\xb1?\xf9A\xfbC\xff\x1f)\x01\x1f+!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f-\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dk\x03\x03\xc3\x1dm\t\x07\x0b\x05\x1do\x05\x01#\x1f\x03\x05\xb7\xbb\r\x03\xa5\xb9\x1dq\r\x03\xa5\xbd\x1ds\x1du\x1dw\r\x01#!\x1dy\x13\x0b\x01\x1f\x01\t\xff\xff\xff\xff\x1f#\x01\x13\x0b\x05\x07\x05\x1f\x05\t\x00\x00\x00\x00\x1f\x01\t\x01\x00\x00\x00\x1f\x01\t\x03\x00\x00\x00\x1f\x01\t`\x00\x00\x00\x1d{\x03\x0b\x9b\x9b\x9b\x9b\x9d\x03\x03\xe1\x15\x03\x01\x11\x01\x03\t\x9d\x9f\x9b\x9f\x13\x07\x01\x13\x07\x05\x13\x07\t\x13\x07\r\x1f\x01\t\x00\x00\x00\x00\x07\x01\x1f\x05\t\x00\x00\xc0\x7f\x1f\x1d!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d}\x03\x0f\x9b\x9b\x9b\x9b\x9b\x9d\x9f\x03\x03\xfd\x15\x03\x01\x15\x01\x03\x07\x9d\x9b\x9f\x03\x03\x06\x02\xa9\x05\x7f)\x01\x07)\x05\r\r\t)\x01\t\x1b\t\x1d\x01)\x05\r\r\x07)\x03\r\t)\x03\x02\x03\t)\x05\r\r\r\x13)\x01\r)\x05\x05\x05\r)\x03\t\x0b\x11\x01\x05\x03\x03\x11\x03\x03\x03\x03)\x03\x01\x0b)\x03%\t/\t\x03\x11\x01\x13)\x03\x01\x17)\x03\t\x17)\x03\x05\x17)\x03\x05\r)\x03\r\r)\x03\x05\x0b/\x07\x03\x01\x13\x04\xe6\x06\x05\x01\x11\x0fQ\x07\x03\x01\t\x0f\x11\x0fU\x05\x03Y\xb5\x0b\x03w#\x03%\x17\x06{\x03\x03\x03\x01\x03\x03\x011\x03\x01\x03\x03\x01\r\x03\x01\x03\x03\x01\r\x03\x01\x03\x03\x013\x03\x01\x13\x07\x01\x81\x03'\x0b\x05\x07\t\x0b\x03\x07\x07\x01E\x03\x03\x03\r\x07\x07\x01G\x03\x11\x03\r\x07\x07\x01I\x03\x01\x03\r\x07\x07\x01\x83\x03\x13\x03\r\x03\x03\x01K\x03\x01\x05\x07\x01\x07\x03\x01\x03\x17\r\x07\x01M\x03\x19\x05\x13\x19\x05\x07\x01\x07\x03\x1b\x03\x1b\x03\x03\x01\x17\x03\x05\x05\x07\x01\x07\x03\x03\x03\x1f\x05\x07\x01O\x03\x15\x03\x1d\t\x06\x01\x03\x03\x07#\x0f!\x05\x07\x01\x07\x03/\x03\x1b\x03\x03\x01\x17\x03\x05\x05\x07\x01\x07\x03\x11\x03)\x05\x07\x01\x85\x031\x03'\t\x06\x01\x03\x11\x07-\x11+\x03\x03\x87-\x03\x05\x19\x07\x93\x8b\x03\x03\x05%1\x03\x03\x031\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x033\x03\x01\x13\x07\x03\x99\x035\x0f579;=3/\x07\x07\x03E\x03\x03\x03?\x07\x07\x03G\x03\x01\x03?\x07\x07\x03I\x03\x13\x03?\x03\x03\x03K\x03\x01\x05\x07\x03\x07\x03\x01\x03G\r\x07\x03M\x03\x19\x05CI\x05\x07\x03\x07\x03\x1b\x03K\x03\x03\x03\x17\x03\x05\x05\x07\x03\x07\x03\x03\x03O\x05\x07\x03O\x03\x15\x03M\t\x06\x03\x03\x03\x07SAQ\x1b\x07\x0b\x02\x02\x03\x03\x03%\x11\x04\x0f\x05UW\x0f\x11\x0bW\x05\x03\x15+\x03\x03\x0f\x0b\x03[#\x03\x0f\x03\x03\x0b_\x03\x01\x05\x07'\x07\x03\x0f\x03\x05\x15\x06'\x03\x0f\x05\x03\x07\x0b\x03ec\x03\x0f\r\x07ki\x03\x15\x05\t\x0b\x03\x03\x0b-\x03\x05\x05\x07o\x07\x03\x03\x03\x0f\t\x06s\x03\x03\x07\r\x11\x01\x11\x04\x0b\x03\x13\x06\x03\x01\x05\x01\x00\xc6\x18\x81\x0f\x1d\x1d\x11\x0f\x0b\t\t\x03\x0b!Y\x87##%_=\x85\x87W\xb3K\x9bM\x9b\xd2\x02\x1b\x1f/!!)#\x1f\x19+\x1b\x1f\x83\x1f\x15\x1d\x15+\x13\r\r\x11\x0f\x17\x0f\x1f\x15\x11\x17\x11\x15+)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00get_tuple_element_v1\x00select_v1\x00iota_v1\x00compare_v1\x00func_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00index\x00sym_name\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00iota_dimension\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=float32 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]\x00jit()/jit(main)/householder_product\x00jax.result_info\x00triu\x00\x00[0]\x00[1]\x00main\x00public\x00private\x00lapack_sgeqrf\x00lapack_sorgqr\x00callee\x00", - xla_call_module_version=4, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_03_17["f64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_dgeqrf', 'lapack_dorgqr'], - serialized_date=datetime.date(2023, 3, 17), - inputs=(), - expected_outputs=(array([[ 0. , 0.9128709291752773 , 0.40824829046386235], - [-0.447213595499958 , 0.3651483716701102 , -0.8164965809277263 ], - [-0.894427190999916 , -0.1825741858350548 , 0.40824829046386324]]), array([[-6.7082039324993694e+00, -8.0498447189992444e+00, - -9.3914855054991175e+00], - [ 0.0000000000000000e+00, 1.0954451150103341e+00, - 2.1908902300206665e+00], - [ 0.0000000000000000e+00, 0.0000000000000000e+00, - -8.8817841970012523e-16]])), - mlir_module_text=r""" -module @jit__lambda_ { - func.func public @main() -> (tensor<3x3xf64> {jax.result_info = "[0]"}, tensor<3x3xf64> {jax.result_info = "[1]"}) { - %0 = stablehlo.iota dim = 0 : tensor<9xf64> - %1 = stablehlo.reshape %0 : (tensor<9xf64>) -> tensor<3x3xf64> - %2 = stablehlo.constant dense<1> : tensor - %3 = stablehlo.constant dense<3> : tensor - %4 = stablehlo.constant dense<3> : tensor - %5 = stablehlo.constant dense<96> : tensor - %6 = stablehlo.custom_call @lapack_dgeqrf(%2, %3, %4, %5, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor<3x3xf64>) -> tuple, tensor<3xf64>, tensor, tensor<96xf64>> - %7 = stablehlo.get_tuple_element %6[0] : (tuple, tensor<3xf64>, tensor, tensor<96xf64>>) -> tensor<3x3xf64> - %8 = stablehlo.get_tuple_element %6[1] : (tuple, tensor<3xf64>, tensor, tensor<96xf64>>) -> tensor<3xf64> - %9 = stablehlo.get_tuple_element %6[2] : (tuple, tensor<3xf64>, tensor, tensor<96xf64>>) -> tensor - %10 = stablehlo.get_tuple_element %6[3] : (tuple, tensor<3xf64>, tensor, tensor<96xf64>>) -> tensor<96xf64> - %11 = stablehlo.constant dense<0> : tensor - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor) -> tensor - %13 = stablehlo.compare EQ, %9, %12, SIGNED : (tensor, tensor) -> tensor - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1x1xi1> - %15 = stablehlo.constant dense<0x7FF8000000000000> : tensor - %16 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor) -> tensor<3x3xf64> - %17 = stablehlo.broadcast_in_dim %14, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> - %18 = stablehlo.select %17, %7, %16 : tensor<3x3xi1>, tensor<3x3xf64> - %19 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi1> - %20 = stablehlo.constant dense<0x7FF8000000000000> : tensor - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<3xf64> - %22 = stablehlo.broadcast_in_dim %19, dims = [0] : (tensor<1xi1>) -> tensor<3xi1> - %23 = stablehlo.select %22, %8, %21 : tensor<3xi1>, tensor<3xf64> - %24 = stablehlo.constant dense<0.000000e+00> : tensor - %25 = stablehlo.pad %18, %24, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xf64>, tensor) -> tensor<3x3xf64> - %26 = stablehlo.constant dense<1> : tensor - %27 = stablehlo.constant dense<3> : tensor - %28 = stablehlo.constant dense<3> : tensor - %29 = stablehlo.constant dense<3> : tensor - %30 = stablehlo.constant dense<96> : tensor - %31 = stablehlo.custom_call @lapack_dorgqr(%26, %27, %28, %29, %30, %25, %23) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<3x3xf64>, tensor<3xf64>) -> tuple, tensor, tensor<96xf64>> - %32 = stablehlo.get_tuple_element %31[0] : (tuple, tensor, tensor<96xf64>>) -> tensor<3x3xf64> - %33 = stablehlo.get_tuple_element %31[1] : (tuple, tensor, tensor<96xf64>>) -> tensor - %34 = stablehlo.get_tuple_element %31[2] : (tuple, tensor, tensor<96xf64>>) -> tensor<96xf64> - %35 = stablehlo.constant dense<0> : tensor - %36 = stablehlo.broadcast_in_dim %35, dims = [] : (tensor) -> tensor - %37 = stablehlo.compare EQ, %33, %36, SIGNED : (tensor, tensor) -> tensor - %38 = stablehlo.broadcast_in_dim %37, dims = [] : (tensor) -> tensor<1x1xi1> - %39 = stablehlo.constant dense<0x7FF8000000000000> : tensor - %40 = stablehlo.broadcast_in_dim %39, dims = [] : (tensor) -> tensor<3x3xf64> - %41 = stablehlo.broadcast_in_dim %38, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> - %42 = stablehlo.select %41, %32, %40 : tensor<3x3xi1>, tensor<3x3xf64> - %43 = call @triu(%18) : (tensor<3x3xf64>) -> tensor<3x3xf64> - return %42, %43 : tensor<3x3xf64>, tensor<3x3xf64> - } - func.func private @triu(%arg0: tensor<3x3xf64>) -> tensor<3x3xf64> { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> - %1 = stablehlo.constant dense<-1> : tensor - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<3x3xi32> - %3 = stablehlo.add %0, %2 : tensor<3x3xi32> - %4 = stablehlo.iota dim = 1 : tensor<3x3xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> - %6 = stablehlo.constant dense<0.000000e+00> : tensor - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<3x3xf64> - %8 = stablehlo.select %5, %7, %arg0 : tensor<3x3xi1>, tensor<3x3xf64> - return %8 : tensor<3x3xf64> - } -} -""", - mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01+\x05\x01\x05\x01\x03\x05\x03\x1b\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f\x03\xa2\x02\n\x027\x01\x9b\x0f\x0f\x17\x13\x0b\x0f\x13\x07\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x13\x17\x13\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x1b\x13\x13\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0bK\x13\x13\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0bK\x03g\x0fO/\x0b/\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b/\x1f\x1f\x1f\x0b\x1f\x0f\x17\x1b\x0f\x0f\x0f\x0f\x1f\x0b/O/\x0b'\x0f\x17\x17\x01\x05\x17\x0b\x037\x0f\x17\x0f\x07\x07\x07\x07\x17\x13\x17\x17\x07\x0f\x17\x13\x17\x17\x13\x13\x1b\x13\x13\x13\x13\x13\x13\x17\x02\xce\t\x1d\x7f\x05\x1d\x97\x05\x17!\xee\x05\x01\x03\x03\x15\xcd\x05!\x1dY\x05\x03\x03\t\xd7\x1f\x05#\x05%\x05'\x03\x03\t\xf1\x05)\x05+\x05-\x05/\x051\x03\x03%\xc9\x053\x1da\x05\x055\x057\x03\x03\t\xd3\x17!\xea\x05\x01\x03\x03\t\xd5\x03\x03\t\xd9\x059\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x03\x03\x11\xe5\x03\x03\x11\xe7\x03\x03\x11\xe9\x03\x03\t\xed\x03\x05)\xab+\xef\x03\x03\x15\xf3\x03\x03\x13S\x05I\x03\x0b\x19\xa1\x1b\xb3\x1d\xb5\x13\xbf\x1f\xc1\x03\x0b\x19\xa7\x1b\xc5\x1d\xa7\x13\xa9\x1f\xc7\x05K\x1d]\x05\x05M\x03\x03\t\xcb\x05O\x03\x03%\xcf\x1dg\x05\x05Q\x03\x05)\xab+\xd1\x1dm\x05\x05S\x1dq\x05\x05U\x1du\x05\x05W\x1dy/\x05Y\x1d}/\x05[\x05]\x03\x115\xad7\xaf9\xdb;\xa1=\xb1?\xddA\xdfC\xe3\x03\x03\x11\xeb\x03\x03\x15\xf5\x1d\x89\x05\x05_\x03\x07\x8d\xa3\x8f\xa3\x91\xa3\x05a\x05c\x05e\x1d\x95\x05\x05g\x05i\x03\x115\xad7\xaf9\xf7;\xa1=\xb1?\xf9A\xfbC\xff\x1f)\x01\x1f+!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f-\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dk\x03\x03\xc3\x1dm\t\x07\x0b\x05\x1do\x05\x01#\x1f\x03\x05\xb7\xbb\r\x03\xa5\xb9\x1dq\r\x03\xa5\xbd\x1ds\x1du\x1dw\r\x01#!\x1dy\x13\x0b\x01\x1f\x01\t\xff\xff\xff\xff\x1f#\x01\x13\x0b\x05\x07\x05\x1f\x05\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x01\t\x01\x00\x00\x00\x1f\x01\t\x03\x00\x00\x00\x1f\x01\t`\x00\x00\x00\x1d{\x03\x0b\x9b\x9b\x9b\x9b\x9d\x03\x03\xe1\x15\x03\x01\x11\x01\x03\t\x9d\x9f\x9b\x9f\x13\x07\x01\x13\x07\x05\x13\x07\t\x13\x07\r\x1f\x01\t\x00\x00\x00\x00\x07\x01\x1f\x05\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x1d!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d}\x03\x0f\x9b\x9b\x9b\x9b\x9b\x9d\x9f\x03\x03\xfd\x15\x03\x01\x15\x01\x03\x07\x9d\x9b\x9f\x03\x03\x06\x02\xa9\x05\x7f)\x01\x07)\x05\r\r\t)\x01\t\x1b\x0b\x1d\x01)\x05\r\r\x07)\x03\r\t)\x03\x02\x03\t)\x05\r\r\r\x13)\x01\r)\x05\x05\x05\r)\x03\t\x0b\x11\x01\x05\x03\x03\x11\x03\x03\x03\x03)\x03\x01\x0b)\x03%\t/\t\x03\x11\x01\x13)\x03\x01\x17)\x03\t\x17)\x03\x05\x17)\x03\x05\r)\x03\r\r)\x03\x05\x0b/\x07\x03\x01\x13\x04\xe6\x06\x05\x01\x11\x0fQ\x07\x03\x01\t\x0f\x11\x0fU\x05\x03Y\xb5\x0b\x03w#\x03%\x17\x06{\x03\x03\x03\x01\x03\x03\x011\x03\x01\x03\x03\x01\r\x03\x01\x03\x03\x01\r\x03\x01\x03\x03\x013\x03\x01\x13\x07\x01\x81\x03'\x0b\x05\x07\t\x0b\x03\x07\x07\x01E\x03\x03\x03\r\x07\x07\x01G\x03\x11\x03\r\x07\x07\x01I\x03\x01\x03\r\x07\x07\x01\x83\x03\x13\x03\r\x03\x03\x01K\x03\x01\x05\x07\x01\x07\x03\x01\x03\x17\r\x07\x01M\x03\x19\x05\x13\x19\x05\x07\x01\x07\x03\x1b\x03\x1b\x03\x03\x01\x17\x03\x05\x05\x07\x01\x07\x03\x03\x03\x1f\x05\x07\x01O\x03\x15\x03\x1d\t\x06\x01\x03\x03\x07#\x0f!\x05\x07\x01\x07\x03/\x03\x1b\x03\x03\x01\x17\x03\x05\x05\x07\x01\x07\x03\x11\x03)\x05\x07\x01\x85\x031\x03'\t\x06\x01\x03\x11\x07-\x11+\x03\x03\x87-\x03\x05\x19\x07\x93\x8b\x03\x03\x05%1\x03\x03\x031\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x033\x03\x01\x13\x07\x03\x99\x035\x0f579;=3/\x07\x07\x03E\x03\x03\x03?\x07\x07\x03G\x03\x01\x03?\x07\x07\x03I\x03\x13\x03?\x03\x03\x03K\x03\x01\x05\x07\x03\x07\x03\x01\x03G\r\x07\x03M\x03\x19\x05CI\x05\x07\x03\x07\x03\x1b\x03K\x03\x03\x03\x17\x03\x05\x05\x07\x03\x07\x03\x03\x03O\x05\x07\x03O\x03\x15\x03M\t\x06\x03\x03\x03\x07SAQ\x1b\x07\x0b\x02\x02\x03\x03\x03%\x11\x04\x0f\x05UW\x0f\x11\x0bW\x05\x03\x15+\x03\x03\x0f\x0b\x03[#\x03\x0f\x03\x03\x0b_\x03\x01\x05\x07'\x07\x03\x0f\x03\x05\x15\x06'\x03\x0f\x05\x03\x07\x0b\x03ec\x03\x0f\r\x07ki\x03\x15\x05\t\x0b\x03\x03\x0b-\x03\x05\x05\x07o\x07\x03\x03\x03\x0f\t\x06s\x03\x03\x07\r\x11\x01\x11\x04\x0b\x03\x13\x06\x03\x01\x05\x01\x00\xc6\x18\x81\x0f\x1d\x1d\x11\x0f\x0b\t\t\x03\x0b!Y\x87##%_=\x85\x87W\xb3K\x9bM\x9b\xd2\x02\x1b\x1f/!!)#\x1f\x19+\x1b\x1f\x83\x1f\x15\x1d\x15+\x13\r\r\x11\x0f\x17\x0f\x1f\x15\x11\x17\x11\x15+)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00get_tuple_element_v1\x00select_v1\x00iota_v1\x00compare_v1\x00func_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00index\x00sym_name\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00iota_dimension\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=float64 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]\x00jit()/jit(main)/householder_product\x00jax.result_info\x00triu\x00\x00[0]\x00[1]\x00main\x00public\x00private\x00lapack_dgeqrf\x00lapack_dorgqr\x00callee\x00", - xla_call_module_version=4, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_03_17["c64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_cgeqrf', 'lapack_cungqr'], - serialized_date=datetime.date(2023, 3, 17), - inputs=(), - expected_outputs=(array([[ 0. +0.j, 0.91287076+0.j, 0.4082487 +0.j], - [-0.44721356-0.j, 0.36514866+0.j, -0.8164965 +0.j], - [-0.8944271 -0.j, -0.18257445+0.j, 0.40824816+0.j]], - dtype=complex64), array([[-6.7082043e+00+0.j, -8.0498438e+00+0.j, -9.3914852e+00+0.j], - [ 0.0000000e+00+0.j, 1.0954441e+00+0.j, 2.1908894e+00+0.j], - [ 0.0000000e+00+0.j, 0.0000000e+00+0.j, 7.1525574e-07+0.j]], - dtype=complex64)), - mlir_module_text=r""" -module @jit__lambda_ { - func.func public @main() -> (tensor<3x3xcomplex> {jax.result_info = "[0]"}, tensor<3x3xcomplex> {jax.result_info = "[1]"}) { - %0 = stablehlo.iota dim = 0 : tensor<9xcomplex> - %1 = stablehlo.reshape %0 : (tensor<9xcomplex>) -> tensor<3x3xcomplex> - %2 = stablehlo.constant dense<1> : tensor - %3 = stablehlo.constant dense<3> : tensor - %4 = stablehlo.constant dense<3> : tensor - %5 = stablehlo.constant dense<96> : tensor - %6 = stablehlo.custom_call @lapack_cgeqrf(%2, %3, %4, %5, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor<3x3xcomplex>) -> tuple>, tensor<3xcomplex>, tensor, tensor<96xcomplex>> - %7 = stablehlo.get_tuple_element %6[0] : (tuple>, tensor<3xcomplex>, tensor, tensor<96xcomplex>>) -> tensor<3x3xcomplex> - %8 = stablehlo.get_tuple_element %6[1] : (tuple>, tensor<3xcomplex>, tensor, tensor<96xcomplex>>) -> tensor<3xcomplex> - %9 = stablehlo.get_tuple_element %6[2] : (tuple>, tensor<3xcomplex>, tensor, tensor<96xcomplex>>) -> tensor - %10 = stablehlo.get_tuple_element %6[3] : (tuple>, tensor<3xcomplex>, tensor, tensor<96xcomplex>>) -> tensor<96xcomplex> - %11 = stablehlo.constant dense<0> : tensor - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor) -> tensor - %13 = stablehlo.compare EQ, %9, %12, SIGNED : (tensor, tensor) -> tensor - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1x1xi1> - %15 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> - %16 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor>) -> tensor<3x3xcomplex> - %17 = stablehlo.broadcast_in_dim %14, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> - %18 = stablehlo.select %17, %7, %16 : tensor<3x3xi1>, tensor<3x3xcomplex> - %19 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi1> - %20 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor>) -> tensor<3xcomplex> - %22 = stablehlo.broadcast_in_dim %19, dims = [0] : (tensor<1xi1>) -> tensor<3xi1> - %23 = stablehlo.select %22, %8, %21 : tensor<3xi1>, tensor<3xcomplex> - %24 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> - %25 = stablehlo.pad %18, %24, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xcomplex>, tensor>) -> tensor<3x3xcomplex> - %26 = stablehlo.constant dense<1> : tensor - %27 = stablehlo.constant dense<3> : tensor - %28 = stablehlo.constant dense<3> : tensor - %29 = stablehlo.constant dense<3> : tensor - %30 = stablehlo.constant dense<96> : tensor - %31 = stablehlo.custom_call @lapack_cungqr(%26, %27, %28, %29, %30, %25, %23) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<3x3xcomplex>, tensor<3xcomplex>) -> tuple>, tensor, tensor<96xcomplex>> - %32 = stablehlo.get_tuple_element %31[0] : (tuple>, tensor, tensor<96xcomplex>>) -> tensor<3x3xcomplex> - %33 = stablehlo.get_tuple_element %31[1] : (tuple>, tensor, tensor<96xcomplex>>) -> tensor - %34 = stablehlo.get_tuple_element %31[2] : (tuple>, tensor, tensor<96xcomplex>>) -> tensor<96xcomplex> - %35 = stablehlo.constant dense<0> : tensor - %36 = stablehlo.broadcast_in_dim %35, dims = [] : (tensor) -> tensor - %37 = stablehlo.compare EQ, %33, %36, SIGNED : (tensor, tensor) -> tensor - %38 = stablehlo.broadcast_in_dim %37, dims = [] : (tensor) -> tensor<1x1xi1> - %39 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> - %40 = stablehlo.broadcast_in_dim %39, dims = [] : (tensor>) -> tensor<3x3xcomplex> - %41 = stablehlo.broadcast_in_dim %38, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> - %42 = stablehlo.select %41, %32, %40 : tensor<3x3xi1>, tensor<3x3xcomplex> - %43 = call @triu(%18) : (tensor<3x3xcomplex>) -> tensor<3x3xcomplex> - return %42, %43 : tensor<3x3xcomplex>, tensor<3x3xcomplex> - } - func.func private @triu(%arg0: tensor<3x3xcomplex>) -> tensor<3x3xcomplex> { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> - %1 = stablehlo.constant dense<-1> : tensor - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<3x3xi32> - %3 = stablehlo.add %0, %2 : tensor<3x3xi32> - %4 = stablehlo.iota dim = 1 : tensor<3x3xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> - %6 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor>) -> tensor<3x3xcomplex> - %8 = stablehlo.select %5, %7, %arg0 : tensor<3x3xi1>, tensor<3x3xcomplex> - return %8 : tensor<3x3xcomplex> - } -} -""", - mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01+\x05\x01\x05\x01\x03\x05\x03\x1b\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f\x03\xa6\x02\n\x029\x01\x9b\x0f\x0f\x17\x13\x0b\x0f\x13\x07\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x13\x17\x13\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x1b\x13\x13\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0bK\x13\x13\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0bK\x03g\x0fO/\x0b/\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b/\x1f\x1f\x1f\x0b\x1f\x0f\x17\x1b\x0f\x0f\x0f\x0f\x1f\x0b/O/\x0b'\x0f\x17\x17\x01\x05\x17\x0b\x039\x0f\x17\x0f\x07\x0b\x07\x07\x17\x13\x17\x17\x07\x0f\x17\x13\x17\x07\x17\x13\x13\x1b\x13\x13\x13\x13\x13\x13\x17\x02\xd6\t\x1d\x7f\x05\x1d\x97\x05\x17!\xee\x05\x01\x03\x03\x15\xcd\x05!\x1dY\x05\x03\x03\t\xd7\x1f\x05#\x05%\x05'\x03\x03\t\xf1\x05)\x05+\x05-\x05/\x051\x03\x03%\xc9\x053\x1da\x05\x055\x057\x03\x03\t\xd3\x17!\xea\x05\x01\x03\x03\t\xd5\x03\x03\t\xd9\x059\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x03\x03\x11\xe5\x03\x03\x11\xe7\x03\x03\x11\xe9\x03\x03\t\xed\x03\x05)\xab+\xef\x03\x03\x15\xf3\x03\x03\x13S\x05I\x03\x0b\x19\xa1\x1b\xb3\x1d\xb5\x13\xbf\x1f\xc1\x03\x0b\x19\xa7\x1b\xc5\x1d\xa7\x13\xa9\x1f\xc7\x05K\x1d]\x05\x05M\x03\x03\t\xcb\x05O\x03\x03%\xcf\x1dg\x05\x05Q\x03\x05)\xab+\xd1\x1dm\x05\x05S\x1dq\x05\x05U\x1du\x05\x05W\x1dy/\x05Y\x1d}/\x05[\x05]\x03\x115\xad7\xaf9\xdb;\xa1=\xb1?\xddA\xdfC\xe3\x03\x03\x11\xeb\x03\x03\x15\xf5\x1d\x89\x05\x05_\x03\x07\x8d\xa3\x8f\xa3\x91\xa3\x05a\x05c\x05e\x1d\x95\x05\x05g\x05i\x03\x115\xad7\xaf9\xf7;\xa1=\xb1?\xf9A\xfbC\xff\x1f+\x01\x1f-!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f/\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dk\x03\x03\xc3\x1dm\t\x07\x0b\x05\x1do\x05\x01#\x1f\x03\x05\xb7\xbb\r\x03\xa5\xb9\x1dq\r\x03\xa5\xbd\x1ds\x1du\x1dw\r\x01##\x1dy\x13\x0b\x01\x1f\x01\t\xff\xff\xff\xff\x1f%\x01\x13\x0b\x05\x07\x05\x1f\x05\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x01\t\x01\x00\x00\x00\x1f\x01\t\x03\x00\x00\x00\x1f\x01\t`\x00\x00\x00\x1d{\x03\x0b\x9b\x9b\x9b\x9b\x9d\x03\x03\xe1\x15\x03\x01\x11\x01\x03\t\x9d\x9f\x9b\x9f\x13\x07\x01\x13\x07\x05\x13\x07\t\x13\x07\r\x1f\x01\t\x00\x00\x00\x00\x07\x01\x1f\x05\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f\x1d!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d}\x03\x0f\x9b\x9b\x9b\x9b\x9b\x9d\x9f\x03\x03\xfd\x15\x03\x01\x15\x01\x03\x07\x9d\x9b\x9f\x03\x03\x06\x02\xa9\x05\x7f)\x01\x07)\x05\r\r\t)\x01\t\x1b\x03!\x1d\x01)\x05\r\r\x07)\x03\r\t)\x03\x02\x03\t)\x05\r\r\r\x13)\x01\r)\x05\x05\x05\r)\x03\t\x0b\x11\x01\x05\x03\x03\t\x11\x03\x03\x03\x03)\x03\x01\x0b)\x03%\t/\t\x03\x11\x01\x13)\x03\x01\x17)\x03\t\x17)\x03\x05\x17)\x03\x05\r)\x03\r\r)\x03\x05\x0b/\x07\x03\x01\x13\x04\xe6\x06\x05\x01\x11\x0fQ\x07\x03\x01\t\x0f\x11\x0fU\x05\x03Y\xb5\x0b\x03w#\x03'\x17\x06{\x03\x03\x03\x01\x03\x03\x011\x03\x01\x03\x03\x01\r\x03\x01\x03\x03\x01\r\x03\x01\x03\x03\x013\x03\x01\x13\x07\x01\x81\x03)\x0b\x05\x07\t\x0b\x03\x07\x07\x01E\x03\x03\x03\r\x07\x07\x01G\x03\x11\x03\r\x07\x07\x01I\x03\x01\x03\r\x07\x07\x01\x83\x03\x13\x03\r\x03\x03\x01K\x03\x01\x05\x07\x01\x07\x03\x01\x03\x17\r\x07\x01M\x03\x19\x05\x13\x19\x05\x07\x01\x07\x03\x1b\x03\x1b\x03\x03\x01\x17\x03\x05\x05\x07\x01\x07\x03\x03\x03\x1f\x05\x07\x01O\x03\x15\x03\x1d\t\x06\x01\x03\x03\x07#\x0f!\x05\x07\x01\x07\x031\x03\x1b\x03\x03\x01\x17\x03\x05\x05\x07\x01\x07\x03\x11\x03)\x05\x07\x01\x85\x033\x03'\t\x06\x01\x03\x11\x07-\x11+\x03\x03\x87-\x03\x05\x19\x07\x93\x8b\x03\x03\x05%1\x03\x03\x031\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x033\x03\x01\x13\x07\x03\x99\x037\x0f579;=3/\x07\x07\x03E\x03\x03\x03?\x07\x07\x03G\x03\x01\x03?\x07\x07\x03I\x03\x13\x03?\x03\x03\x03K\x03\x01\x05\x07\x03\x07\x03\x01\x03G\r\x07\x03M\x03\x19\x05CI\x05\x07\x03\x07\x03\x1b\x03K\x03\x03\x03\x17\x03\x05\x05\x07\x03\x07\x03\x03\x03O\x05\x07\x03O\x03\x15\x03M\t\x06\x03\x03\x03\x07SAQ\x1b\x07\x0b\x02\x02\x03\x03\x03%\x11\x04\x0f\x05UW\x0f\x11\x0bW\x05\x03\x15+\x03\x03\x0f\x0b\x03[#\x03\x0f\x03\x03\x0b_\x03\x01\x05\x07'\x07\x03\x0f\x03\x05\x15\x06'\x03\x0f\x05\x03\x07\x0b\x03ec\x03\x0f\r\x07ki\x03\x15\x05\t\x0b\x03\x03\x0b-\x03\x05\x05\x07o\x07\x03\x03\x03\x0f\t\x06s\x03\x03\x07\r\x11\x01\x11\x04\x0b\x03\x13\x06\x03\x01\x05\x01\x00\xce\x18\x81\x0f\x1d\x1d\x11\x0f\x0b\t\t\x03\x0b!Y\x87##%_=\x85\x8bW\xb3K\x9bM\x9b\xd2\x02\x1b\x1f/!!)#\x1f\x19+\x1b\x1f\x83\x1f\x15\x1d\x15+\x13\r\r\x11\x0f\x17\x0f\x1f\x15\x11\x17\x11\x15+)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00get_tuple_element_v1\x00select_v1\x00iota_v1\x00compare_v1\x00func_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00index\x00sym_name\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00iota_dimension\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=complex64 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]\x00jit()/jit(main)/householder_product\x00jax.result_info\x00triu\x00\x00[0]\x00[1]\x00main\x00public\x00private\x00lapack_cgeqrf\x00lapack_cungqr\x00callee\x00", - xla_call_module_version=4, -) # End paste +data_2025_04_02 = {} - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_03_17["c128"] = dict( +data_2025_04_02['c128'] = dict( testdata_version=1, platform='cpu', - custom_call_targets=['lapack_zgeqrf', 'lapack_zungqr'], - serialized_date=datetime.date(2023, 3, 17), + custom_call_targets=['lapack_zgeqrf_ffi', 'lapack_zungqr_ffi'], + serialized_date=datetime.date(2025, 4, 2), inputs=(), expected_outputs=(array([[ 0. +0.j, 0.9128709291752773 +0.j, 0.40824829046386235+0.j], @@ -283,531 +41,199 @@ [ 0.0000000000000000e+00+0.j, 0.0000000000000000e+00+0.j, -8.8817841970012523e-16+0.j]])), mlir_module_text=r""" -module @jit__lambda_ { - func.func public @main() -> (tensor<3x3xcomplex> {jax.result_info = "[0]"}, tensor<3x3xcomplex> {jax.result_info = "[1]"}) { - %0 = stablehlo.iota dim = 0 : tensor<9xcomplex> - %1 = stablehlo.reshape %0 : (tensor<9xcomplex>) -> tensor<3x3xcomplex> - %2 = stablehlo.constant dense<1> : tensor - %3 = stablehlo.constant dense<3> : tensor - %4 = stablehlo.constant dense<3> : tensor - %5 = stablehlo.constant dense<96> : tensor - %6 = stablehlo.custom_call @lapack_zgeqrf(%2, %3, %4, %5, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor<3x3xcomplex>) -> tuple>, tensor<3xcomplex>, tensor, tensor<96xcomplex>> - %7 = stablehlo.get_tuple_element %6[0] : (tuple>, tensor<3xcomplex>, tensor, tensor<96xcomplex>>) -> tensor<3x3xcomplex> - %8 = stablehlo.get_tuple_element %6[1] : (tuple>, tensor<3xcomplex>, tensor, tensor<96xcomplex>>) -> tensor<3xcomplex> - %9 = stablehlo.get_tuple_element %6[2] : (tuple>, tensor<3xcomplex>, tensor, tensor<96xcomplex>>) -> tensor - %10 = stablehlo.get_tuple_element %6[3] : (tuple>, tensor<3xcomplex>, tensor, tensor<96xcomplex>>) -> tensor<96xcomplex> - %11 = stablehlo.constant dense<0> : tensor - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor) -> tensor - %13 = stablehlo.compare EQ, %9, %12, SIGNED : (tensor, tensor) -> tensor - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1x1xi1> - %15 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> - %16 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor>) -> tensor<3x3xcomplex> - %17 = stablehlo.broadcast_in_dim %14, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> - %18 = stablehlo.select %17, %7, %16 : tensor<3x3xi1>, tensor<3x3xcomplex> - %19 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi1> - %20 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor>) -> tensor<3xcomplex> - %22 = stablehlo.broadcast_in_dim %19, dims = [0] : (tensor<1xi1>) -> tensor<3xi1> - %23 = stablehlo.select %22, %8, %21 : tensor<3xi1>, tensor<3xcomplex> - %24 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> - %25 = stablehlo.pad %18, %24, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xcomplex>, tensor>) -> tensor<3x3xcomplex> - %26 = stablehlo.constant dense<1> : tensor - %27 = stablehlo.constant dense<3> : tensor - %28 = stablehlo.constant dense<3> : tensor - %29 = stablehlo.constant dense<3> : tensor - %30 = stablehlo.constant dense<96> : tensor - %31 = stablehlo.custom_call @lapack_zungqr(%26, %27, %28, %29, %30, %25, %23) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<3x3xcomplex>, tensor<3xcomplex>) -> tuple>, tensor, tensor<96xcomplex>> - %32 = stablehlo.get_tuple_element %31[0] : (tuple>, tensor, tensor<96xcomplex>>) -> tensor<3x3xcomplex> - %33 = stablehlo.get_tuple_element %31[1] : (tuple>, tensor, tensor<96xcomplex>>) -> tensor - %34 = stablehlo.get_tuple_element %31[2] : (tuple>, tensor, tensor<96xcomplex>>) -> tensor<96xcomplex> - %35 = stablehlo.constant dense<0> : tensor - %36 = stablehlo.broadcast_in_dim %35, dims = [] : (tensor) -> tensor - %37 = stablehlo.compare EQ, %33, %36, SIGNED : (tensor, tensor) -> tensor - %38 = stablehlo.broadcast_in_dim %37, dims = [] : (tensor) -> tensor<1x1xi1> - %39 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> - %40 = stablehlo.broadcast_in_dim %39, dims = [] : (tensor>) -> tensor<3x3xcomplex> - %41 = stablehlo.broadcast_in_dim %38, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> - %42 = stablehlo.select %41, %32, %40 : tensor<3x3xi1>, tensor<3x3xcomplex> - %43 = call @triu(%18) : (tensor<3x3xcomplex>) -> tensor<3x3xcomplex> - return %42, %43 : tensor<3x3xcomplex>, tensor<3x3xcomplex> - } - func.func private @triu(%arg0: tensor<3x3xcomplex>) -> tensor<3x3xcomplex> { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> - %1 = stablehlo.constant dense<-1> : tensor - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<3x3xi32> - %3 = stablehlo.add %0, %2 : tensor<3x3xi32> - %4 = stablehlo.iota dim = 1 : tensor<3x3xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> - %6 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor>) -> tensor<3x3xcomplex> - %8 = stablehlo.select %5, %7, %arg0 : tensor<3x3xi1>, tensor<3x3xcomplex> - return %8 : tensor<3x3xcomplex> - } -} -""", - mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01+\x05\x01\x05\x01\x03\x05\x03\x1b\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f\x03\xa6\x02\n\x029\x01\x9b\x0f\x0f\x17\x13\x0b\x0f\x13\x07\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x13\x17\x13\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x1b\x13\x13\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0bK\x13\x13\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0bK\x03g\x0fO/\x0b/\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0bO\x1f\x1f\x1f\x0b\x1f\x0f\x17\x1b\x0f\x0f\x0f\x0f\x1f\x0bOO/\x0b'\x0f\x17\x17\x01\x05\x17\x0b\x039\x0f\x17\x0f\x07\x0b\x07\x07\x17\x13\x17\x17\x07\x0f\x17\x13\x17\x07\x17\x13\x13\x1b\x13\x13\x13\x13\x13\x13\x17\x02\x16\n\x1d\x7f\x05\x1d\x97\x05\x17!\xee\x05\x01\x03\x03\x15\xcd\x05!\x1dY\x05\x03\x03\t\xd7\x1f\x05#\x05%\x05'\x03\x03\t\xf1\x05)\x05+\x05-\x05/\x051\x03\x03%\xc9\x053\x1da\x05\x055\x057\x03\x03\t\xd3\x17!\xea\x05\x01\x03\x03\t\xd5\x03\x03\t\xd9\x059\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x03\x03\x11\xe5\x03\x03\x11\xe7\x03\x03\x11\xe9\x03\x03\t\xed\x03\x05)\xab+\xef\x03\x03\x15\xf3\x03\x03\x13S\x05I\x03\x0b\x19\xa1\x1b\xb3\x1d\xb5\x13\xbf\x1f\xc1\x03\x0b\x19\xa7\x1b\xc5\x1d\xa7\x13\xa9\x1f\xc7\x05K\x1d]\x05\x05M\x03\x03\t\xcb\x05O\x03\x03%\xcf\x1dg\x05\x05Q\x03\x05)\xab+\xd1\x1dm\x05\x05S\x1dq\x05\x05U\x1du\x05\x05W\x1dy/\x05Y\x1d}/\x05[\x05]\x03\x115\xad7\xaf9\xdb;\xa1=\xb1?\xddA\xdfC\xe3\x03\x03\x11\xeb\x03\x03\x15\xf5\x1d\x89\x05\x05_\x03\x07\x8d\xa3\x8f\xa3\x91\xa3\x05a\x05c\x05e\x1d\x95\x05\x05g\x05i\x03\x115\xad7\xaf9\xf7;\xa1=\xb1?\xf9A\xfbC\xff\x1f+\x01\x1f-!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f/\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dk\x03\x03\xc3\x1dm\t\x07\x0b\x05\x1do\x05\x01#\x1f\x03\x05\xb7\xbb\r\x03\xa5\xb9\x1dq\r\x03\xa5\xbd\x1ds\x1du\x1dw\r\x01##\x1dy\x13\x0b\x01\x1f\x01\t\xff\xff\xff\xff\x1f%\x01\x13\x0b\x05\x07\x05\x1f\x05!\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x01\t\x01\x00\x00\x00\x1f\x01\t\x03\x00\x00\x00\x1f\x01\t`\x00\x00\x00\x1d{\x03\x0b\x9b\x9b\x9b\x9b\x9d\x03\x03\xe1\x15\x03\x01\x11\x01\x03\t\x9d\x9f\x9b\x9f\x13\x07\x01\x13\x07\x05\x13\x07\t\x13\x07\r\x1f\x01\t\x00\x00\x00\x00\x07\x01\x1f\x05!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x1d!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d}\x03\x0f\x9b\x9b\x9b\x9b\x9b\x9d\x9f\x03\x03\xfd\x15\x03\x01\x15\x01\x03\x07\x9d\x9b\x9f\x03\x03\x06\x02\xa9\x05\x7f)\x01\x07)\x05\r\r\t)\x01\t\x1b\x03!\x1d\x01)\x05\r\r\x07)\x03\r\t)\x03\x02\x03\t)\x05\r\r\r\x13)\x01\r)\x05\x05\x05\r)\x03\t\x0b\x11\x01\x05\x03\x03\x0b\x11\x03\x03\x03\x03)\x03\x01\x0b)\x03%\t/\t\x03\x11\x01\x13)\x03\x01\x17)\x03\t\x17)\x03\x05\x17)\x03\x05\r)\x03\r\r)\x03\x05\x0b/\x07\x03\x01\x13\x04\xe6\x06\x05\x01\x11\x0fQ\x07\x03\x01\t\x0f\x11\x0fU\x05\x03Y\xb5\x0b\x03w#\x03'\x17\x06{\x03\x03\x03\x01\x03\x03\x011\x03\x01\x03\x03\x01\r\x03\x01\x03\x03\x01\r\x03\x01\x03\x03\x013\x03\x01\x13\x07\x01\x81\x03)\x0b\x05\x07\t\x0b\x03\x07\x07\x01E\x03\x03\x03\r\x07\x07\x01G\x03\x11\x03\r\x07\x07\x01I\x03\x01\x03\r\x07\x07\x01\x83\x03\x13\x03\r\x03\x03\x01K\x03\x01\x05\x07\x01\x07\x03\x01\x03\x17\r\x07\x01M\x03\x19\x05\x13\x19\x05\x07\x01\x07\x03\x1b\x03\x1b\x03\x03\x01\x17\x03\x05\x05\x07\x01\x07\x03\x03\x03\x1f\x05\x07\x01O\x03\x15\x03\x1d\t\x06\x01\x03\x03\x07#\x0f!\x05\x07\x01\x07\x031\x03\x1b\x03\x03\x01\x17\x03\x05\x05\x07\x01\x07\x03\x11\x03)\x05\x07\x01\x85\x033\x03'\t\x06\x01\x03\x11\x07-\x11+\x03\x03\x87-\x03\x05\x19\x07\x93\x8b\x03\x03\x05%1\x03\x03\x031\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x033\x03\x01\x13\x07\x03\x99\x037\x0f579;=3/\x07\x07\x03E\x03\x03\x03?\x07\x07\x03G\x03\x01\x03?\x07\x07\x03I\x03\x13\x03?\x03\x03\x03K\x03\x01\x05\x07\x03\x07\x03\x01\x03G\r\x07\x03M\x03\x19\x05CI\x05\x07\x03\x07\x03\x1b\x03K\x03\x03\x03\x17\x03\x05\x05\x07\x03\x07\x03\x03\x03O\x05\x07\x03O\x03\x15\x03M\t\x06\x03\x03\x03\x07SAQ\x1b\x07\x0b\x02\x02\x03\x03\x03%\x11\x04\x0f\x05UW\x0f\x11\x0bW\x05\x03\x15+\x03\x03\x0f\x0b\x03[#\x03\x0f\x03\x03\x0b_\x03\x01\x05\x07'\x07\x03\x0f\x03\x05\x15\x06'\x03\x0f\x05\x03\x07\x0b\x03ec\x03\x0f\r\x07ki\x03\x15\x05\t\x0b\x03\x03\x0b-\x03\x05\x05\x07o\x07\x03\x03\x03\x0f\t\x06s\x03\x03\x07\r\x11\x01\x11\x04\x0b\x03\x13\x06\x03\x01\x05\x01\x00\xd2\x18\x81\x0f\x1d\x1d\x11\x0f\x0b\t\t\x03\x0b!Y\x87##%_=\x85\x8dW\xb3K\x9bM\x9b\xd2\x02\x1b\x1f/!!)#\x1f\x19+\x1b\x1f\x83\x1f\x15\x1d\x15+\x13\r\r\x11\x0f\x17\x0f\x1f\x15\x11\x17\x11\x15+)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00get_tuple_element_v1\x00select_v1\x00iota_v1\x00compare_v1\x00func_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00index\x00sym_name\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00iota_dimension\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=complex128 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]\x00jit()/jit(main)/householder_product\x00jax.result_info\x00triu\x00\x00[0]\x00[1]\x00main\x00public\x00private\x00lapack_zgeqrf\x00lapack_zungqr\x00callee\x00", - xla_call_module_version=4, -) # End paste - - -data_2024_08_22 = {} - - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_22['c128'] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_zgeqrf_ffi', 'lapack_zungqr_ffi'], - serialized_date=datetime.date(2024, 8, 22), - inputs=(), - expected_outputs=( - array([ - [0.0 + 0.0j, 0.9128709291752773 + 0.0j, 0.40824829046386235 + 0.0j], - [ - -0.447213595499958 - 0.0j, - 0.3651483716701102 + 0.0j, - -0.8164965809277263 + 0.0j, - ], - [ - -0.894427190999916 - 0.0j, - -0.1825741858350548 + 0.0j, - 0.40824829046386324 + 0.0j, - ], - ]), - array([ - [ - -6.7082039324993694e00 + 0.0j, - -8.0498447189992444e00 + 0.0j, - -9.3914855054991175e00 + 0.0j, - ], - [ - 0.0000000000000000e00 + 0.0j, - 1.0954451150103341e00 + 0.0j, - 2.1908902300206665e00 + 0.0j, - ], - [ - 0.0000000000000000e00 + 0.0j, - 0.0000000000000000e00 + 0.0j, - -8.8817841970012523e-16 + 0.0j, - ], - ]), - ), - mlir_module_text=r""" -#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":364:11) -#loc10 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3)) module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<3x3xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<3x3xcomplex> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + func.func public @main() -> (tensor<3x3xcomplex> {jax.result_info = "result[0]"}, tensor<3x3xcomplex> {jax.result_info = "result[1]"}) { + %c = stablehlo.constant dense<-1> : tensor loc(#loc) + %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc) %0 = stablehlo.iota dim = 0 : tensor<9xcomplex> loc(#loc4) %1 = stablehlo.reshape %0 : (tensor<9xcomplex>) -> tensor<3x3xcomplex> loc(#loc5) - %c = stablehlo.constant dense<3> : tensor loc(#loc6) - %c_0 = stablehlo.constant dense<3> : tensor loc(#loc6) - %2:2 = stablehlo.custom_call @lapack_zgeqrf_ffi(%1) {mhlo.backend_config = {}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xcomplex>) -> (tensor<3x3xcomplex>, tensor<3xcomplex>) loc(#loc6) - %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc7) - %3 = stablehlo.pad %2#0, %cst, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xcomplex>, tensor>) -> tensor<3x3xcomplex> loc(#loc8) - %c_1 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_2 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_3 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_4 = stablehlo.constant dense<1> : tensor loc(#loc9) - %c_5 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_6 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_7 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_8 = stablehlo.constant dense<96> : tensor loc(#loc9) - %4:3 = stablehlo.custom_call @lapack_zungqr(%c_4, %c_5, %c_6, %c_7, %c_8, %3, %2#1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<3x3xcomplex>, tensor<3xcomplex>) -> (tensor<3x3xcomplex>, tensor, tensor<96xcomplex>) loc(#loc9) - %c_9 = stablehlo.constant dense<0> : tensor loc(#loc9) - %5 = stablehlo.broadcast_in_dim %c_9, dims = [] : (tensor) -> tensor loc(#loc9) - %6 = stablehlo.compare EQ, %4#1, %5, SIGNED : (tensor, tensor) -> tensor loc(#loc9) - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc9) - %cst_10 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc9) - %8 = stablehlo.broadcast_in_dim %cst_10, dims = [] : (tensor>) -> tensor<3x3xcomplex> loc(#loc9) - %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc9) - %10 = stablehlo.select %9, %4#0, %8 : tensor<3x3xi1>, tensor<3x3xcomplex> loc(#loc9) - %11 = call @triu(%2#0) : (tensor<3x3xcomplex>) -> tensor<3x3xcomplex> loc(#loc10) - return %10, %11 : tensor<3x3xcomplex>, tensor<3x3xcomplex> loc(#loc) + %2:2 = stablehlo.custom_call @lapack_zgeqrf_ffi(%1) {mhlo.backend_config = {}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xcomplex>) -> (tensor<3x3xcomplex>, tensor<3xcomplex>) loc(#loc6) + %3 = stablehlo.pad %2#0, %cst, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xcomplex>, tensor>) -> tensor<3x3xcomplex> loc(#loc7) + %4 = stablehlo.custom_call @lapack_zungqr_ffi(%3, %2#1) {mhlo.backend_config = {}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<3x3xcomplex>, tensor<3xcomplex>) -> tensor<3x3xcomplex> loc(#loc8) + %5 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc9) + %6 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc10) + %7 = stablehlo.add %5, %6 : tensor<3x3xi32> loc(#loc10) + %8 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc9) + %9 = stablehlo.compare GE, %7, %8, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc11) + %10 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<3x3xcomplex> loc(#loc12) + %11 = stablehlo.select %9, %10, %2#0 : tensor<3x3xi1>, tensor<3x3xcomplex> loc(#loc13) + return %4, %11 : tensor<3x3xcomplex>, tensor<3x3xcomplex> loc(#loc) } loc(#loc) - func.func private @triu(%arg0: tensor<3x3xcomplex> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3))) -> (tensor<3x3xcomplex> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc11) - %c = stablehlo.constant dense<-1> : tensor loc(#loc10) - %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc12) - %2 = stablehlo.add %0, %1 : tensor<3x3xi32> loc(#loc12) - %3 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc13) - %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc14) - %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc10) - %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<3x3xcomplex> loc(#loc15) - %6 = stablehlo.select %4, %5, %arg0 : tensor<3x3xi1>, tensor<3x3xcomplex> loc(#loc16) - return %6 : tensor<3x3xcomplex> loc(#loc10) - } loc(#loc10) } loc(#loc) #loc = loc(unknown) -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:26) -#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:14) -#loc4 = loc("jit()/jit(main)/iota[dtype=complex128 shape=(9,) dimension=0]"(#loc1)) -#loc5 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc2)) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":411:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":411:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":412:11) +#loc4 = loc("jit()/jit(main)/iota"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape"(#loc2)) #loc6 = loc("jit()/jit(main)/geqrf"(#loc3)) -#loc7 = loc("jit()/jit(main)/qr[full_matrices=True]"(#loc3)) -#loc8 = loc("jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]"(#loc3)) -#loc9 = loc("jit()/jit(main)/householder_product"(#loc3)) -#loc11 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]"(#loc3)) -#loc12 = loc("jit()/jit(main)/jit(triu)/add"(#loc3)) -#loc13 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]"(#loc3)) -#loc14 = loc("jit()/jit(main)/jit(triu)/ge"(#loc3)) -#loc15 = loc("jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]"(#loc3)) -#loc16 = loc("jit()/jit(main)/jit(triu)/select_n"(#loc3)) +#loc7 = loc("jit()/jit(main)/pad"(#loc3)) +#loc8 = loc("jit()/jit(main)/householder_product"(#loc3)) +#loc9 = loc("jit()/jit(main)/iota"(#loc3)) +#loc10 = loc("jit()/jit(main)/add"(#loc3)) +#loc11 = loc("jit()/jit(main)/ge"(#loc3)) +#loc12 = loc("jit()/jit(main)/broadcast_in_dim"(#loc3)) +#loc13 = loc("jit()/jit(main)/select_n"(#loc3)) """, - mlir_module_serialized=( - b"ML\xefR\x01StableHLO_v0.9.0\x00\x01)\x05\x01\x03\x01\x03\x05\x03\x19\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x03\xae\x02\x12\x025\x01\x9d\x0f\x17\x0b\x0f\x13\x13\x0b\x07\x0b\x0f\x13\x0f\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0f\x0b\x17\x0bS\x0b\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0b\x13\x13K\x13\x1b\x13\x03g\x0fO\x0b\x0b\x0b//\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0bO/\x0b\x0b\x0b\x0f\x0f\x17\x13\x1f\x1f\x1f\x0b\x0b'\x0f\x17\x17\x1f\x0bOO\x01\x07\x17\x17\x0b\x01\x05\x0b\x0f\x031\x17\x0f\x0f\x0b\x07\x0f\x17\x07\x07\x07\x17\x13\x17\x07\x17\x13\x13\x13\x13\x13\x17\x13\x0f\x17\x02\xf2\t\x1d\x8f\x03\x17\x11\xb2\x05\x17\x05\x1f\x1dO\x03\x03\x03%\xd1\x03\x03\x05\xd9\x05!\x1f\x05#\x1dy\x03\x03\x03\x05\xeb\x11\x03\x05\x05%\x05'\x05)\x05+\x03\x03#\xcd\x05-\x05/\x1dW\x03\x051\x053\x03\x03\x05\xd7\x055\x057\x059\x05;\x05=\x05?\x05A\x05C\x03\tACE\x17G\x17\rI\x05E\x11\x01\x00\x05G\x05I\x05K\x03\x0b\x19\xa1\x1b\xb7\x1d\xb9\r\xc3\x1f\xc5\x03\x0b\x19\xad\x1b\xc9\x1d\xad\r\xaf\x1f\xcb\x05M\x1dS\x03\x05O\x03\x03\x05\xcf\x05Q\x03\x03#\xd3\x1d]\x03\x05S\x03\x05)\xb1+\xd5\x1dc\x03\x05U\x1dg\x03\x05W\x1dk\x03\x05Y\x1doq\x05[\x17\x11\xae\x055\x1duw\x05]\x17\x11\xae\x05\x1d\x05_\x03\x13/\xdb1\xb33\xdd5\xa17\xb5}\xdf9\xe1;\xe3=\xe7\x05a\x1d\x81\x03\x05c\x03\x07\x85\xa9\x87\xa9\x89\xa9\x05e\x05g\x05i\x1d\x8d\x03\x05k\x05m\x03\x03\x05\xe9\x03\x03\x05\xed\x03\x11/\xef1\xb33\xf15\xa17\xb59\xf3;\xf5=\xf9\x03\x03\x05\xfb\x03\x05)\xb1+\xfd\x03\x03\x05\xff\x1f/\x01\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1do\x1dq\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1b\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1ds\x03\x03\xc7\x1du\t\x07\x1dw\x05\x01#\x1d\x03\x05\xbb\xbf\r\x05\xab\xbd\xa3\xa5\x1dy\r\x05\xab\xc1\xa3\xa5\x1d{\x1d}\x1d\x7f\r\x03\xa3\xa5#!\x1d\x81\x13\r\x01\x1f\x07\t\xff\xff\xff\xff\x1f#\x01\x13\r\x05\x07\x05\x1f\x0f!\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\t\x11\x03\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1d\x83\r\x01\x03\x03\x9f\x03\x03\xe5\x15\x03\x01\x01\x01\x03\x05\x9f\xa7\x1f\x07\t\x01\x00\x00\x00\x1f\x07\t\x03\x00\x00\x00\x1f\x07\t`\x00\x00\x00\x0b\x05\x1d\x85\x03\x0f\x9d\x9d\x9d\x9d\x9d\x9f\xa7\x03\x03\xf7\x15\x03\x01\x15\x01\x03\x07\x9f\x9d\xa7\x1f\x07\t\x00\x00\x00\x00\x07\x01\x1f\x0f!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03%\x02\x02\x03\x03\x0e\x02\xaf\x05\x87\x01\t\x01\x02\x02)\x05\r\r\x0b)\x01\x17)\x01\r\x03\x1f\x1d)\x01\x0b)\x05\r\r\x17\x01\x13\x1b)\x05\r\r\x13)\x03\t\r\x11\x01\x05\x05\x05\x0b\x11\x03\x05\x03\x05)\x03\x01\r)\x03%\x0b)\x03\r\x0b)\x03\t\x15)\x03\x05\x15)\x03\x02\x03\x0b)\x03\x01\x15)\x01\x13)\x05\x05\x05\x13\x04\x8a\x04\x05\x01\x11\x0f?\x07\x03\x01\t\t\x11\x0fK\x07\x039i\x07\x03m!\x03%\x15\x06s\x03\x05\x03\x01\x03\x03\x13\x0b\x03\t\x03\x03\x13\x0b\x03\t\x11\x07\x13{\x05\x05'\x03\x03\x03\x03\x7f-\x03\x0f\x17\x07\x8b\x83\x03\x05\x05\t\r\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x91\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x93\x03\x07\x11\x07\x01\x95\x07\x05\x07-\x0f\x17\x19\x1b\x1d\x1f\x0f\x0b\x03\x03\x01\x97\x03\x07\x05\x07\x01\t\x03\x07\x03'\x0b\x07\x01\x99\x031\x05#)\x05\x07\x01\t\x033\x03+\x03\x03\x01\x9b\x03\x0f\x05\x07\x01\t\x03\x05\x03/\x05\x07\x01\x06\x02\x03\x19\x03-\r\x06\x01\x03\x05\x073!1\x19\x07\x07\n\x02\x03\x05\x03\t\x0f\x04\x0f\x0557\t\x11\x07M\x07\x03\x15+\x03\x05\x07\x07\x03Q!\x03\x11\x03\x03\x07U\x03\x07\x05\x07'\t\x03\x11\x03\x05\x13\x06'\x03\x11\x05\x03\x07\x07\x03[Y\x03\x11\x0b\x07a_\x03\x19\x05\t\x0b\x03\x03\x07-\x03\x0f\x05\x07e\t\x03\x05\x03\x0f\r\x06i\x03\x05\x07\r\x11\x01\x0f\x04\x07\x03\x13\x06\x03\x01\x05\x01\x00\xaa\x1a\x89\x0f\x1d%\x11\x0f\x0b\t\t\x03\x0b!\x11#Y\x87##%_)=\x85\x8dW\xb3K\x9bM\x9bn\x03\x1b%)9\x1f/!!)#\x1f\x19+\x1b+\x1f\x1f\x15\x1d\x15i\x13\r\x11\x0f\x17\x0f\x1f\x15\x15\x17\x11\x11)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00iota_v1\x00func_v1\x00compare_v1\x00select_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00sym_name\x00third_party/py/jax/tests/export_back_compat_test.py\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,)" - b' out_shardings=(UnspecifiedValue,) in_layouts=(None,)' - b' out_layouts=(None,) resource_env=None donated_invars=(False,)' - b' name=triu keep_unused=False' - b' inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' - b' shape=(3, 3)' - b' dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' - b' shape=(3, 3)' - b' dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3,' - b' 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=complex128' - b' shape=(9,)' - b' dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3)' - b' dimensions=None]\x00jit()/jit(main)/geqrf\x00mhlo.backend_config\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0,' - b' 0, 0), (0, 0,' - b' 0))]\x00jit()/jit(main)/householder_product\x00mhlo.layout_mode\x00default\x00jax.result_info\x00triu\x00\x00[0]\x00[1]\x00main\x00public\x00private\x00lapack_zgeqrf_ffi\x00lapack_zungqr\x00callee\x00' - ), + mlir_module_serialized=b'ML\xefR\rStableHLO_v1.9.3\x00\x01)\x05\x01\x05\x19\x01\x03\x0b\x03\x17\x0f\x13\x17\x1b\x1f#\'+/37\x03\xc7\x8b)\x01E\x17\x07\x0b\x0f\x0b\x1b\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x17\x0f\x0b\x17\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x03G\x0b/\x0b\x0f\x0b\x0b\x0b\x0fO\x13\x0f\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x1fO\x0b\x13\x0b\x0b\x0b\x0f\x17/\x0b\x0f\x13\x0f\x0b\x0b\x01\x05\x0b\x0f\x03%\x17\x0b\x07\x17\x0f\x07\x0f\x07\x17\x07\x13\x13\x13\x13\x13\x13\x17\x07\x02\xd2\x04\x17\x05r\x06\x17\x1f\x05\x1d\x11\x03\x05\x05\x1f\x03\x05\'o)q\x1d\t\x01\x1d7\x01\x03\x07\x13\x15\x17\x07\x19\x07\x05!\x11\x01\x00\x05#\x05%\x05\'\x1d\t\x1f\x17\x05n\x065\x1d#%\x05)\x17\x05n\x06\x1d\x05+\x05-\x1d-\x01\x05/\x1d1\x01\x051\x1d5\x01\x053\x055\x1d;\x01\x057\x1d?\x01\x059\x1dC\x01\x05;\x03\x01\x1f!\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d=\x13\t\x01\x0b\x03\x1d?\x05\x01\x03\x03U\x1f\x1d!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x05U}\x1f#\x01#\x15\x03\x05_c\r\x03Ia\x1dA\r\x03Ie\x1dC\x1dE\x1dG\x1f\r\t\xff\xff\xff\xff\x1f\x11!\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\r\x01\r\x03su\x1dI\x1dK\x1dM\x03\x03{\x15\x03\x01\x01\x01\x1f\x1f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dO\x03\x03\x83\x15\x01\x01\x01\x13\t\x05\t\x07\x07\x05\x01\t\x01\x02\x02)\x05\r\r\x07\x03\x17\x1d)\x05\r\r\x0f)\x01\x0f\x1b)\x01\x07\x13\x11\x01\x05\x05\x05\x0b)\x03%\x07)\x03\r\x07)\x03\t\x13)\x03\x05\x13)\x03\t\t)\x03\x01\t)\x05\r\r\'\x01\x04"\x02\x05\x01Q\x03\x11\x01\x07\x04\xff\x03\x01\x05\x0bP\x03\x03\x07\x04\xeb\x03\x1f=\x05B\x03\x05\x03\r\x05B\x03\x07\x03\x11\x03B\x1d\t\x03\x19\r\x06!\x03\x05\x03\x05\x07G+\x0b\x0b\x05\x05\x1b\x03\x07\x0fF/\r\x03\x05\x05\t\x03\x07G3\x0b\x0f\x03\x05\x05\r\x0b\x03B\r\t\x03\x0b\tF\x0f\x11\x03\x0b\x03\x01\x11\x06\x0f\x03\x0b\x05\x11\x13\x03B\r\x13\x03\x0b\x13F9\x15\x03%\x05\x15\x17\tF=\x11\x03\x05\x03\x03\x15\x06A\x03\x05\x07\x19\x1b\t\x17\x04\x03\x05\x0f\x1d\x06\x03\x01\x05\x01\x00\xba\x0bQ%%\x05\x1f\x0f\x0b\x15\x15\x03!CS79Y9=3)A\x1b%)9;i\x15\x15\x17\x0f\x0f\x17\x11)\x1f\x19\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00iota_v1\x00constant_v1\x00custom_call_v1\x00broadcast_in_dim_v1\x00func_v1\x00reshape_v1\x00pad_v1\x00add_v1\x00compare_v1\x00select_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit()/jit(main)/iota\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/reshape\x00mhlo.backend_config\x00mhlo.frontend_attributes\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/pad\x00jit()/jit(main)/householder_product\x00jit()/jit(main)/add\x00jit()/jit(main)/ge\x00jit()/jit(main)/broadcast_in_dim\x00jit()/jit(main)/select_n\x00jax.result_info\x00\x00result[0]\x00result[1]\x00main\x00public\x00num_batch_dims\x000\x00lapack_zgeqrf_ffi\x00lapack_zungqr_ffi\x00\x08[\x17\x057\x01\x0bE[]gi\x03k\x03m\x03K\x11MOwEQSyW\x07GGG\x11MO\x7fEQW\x81S\x03Y\x03\x85\x05\x87\x89', xla_call_module_version=9, nr_devices=1, ) # End paste - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_22['c64'] = dict( +data_2025_04_02['c64'] = dict( testdata_version=1, platform='cpu', custom_call_targets=['lapack_cgeqrf_ffi', 'lapack_cungqr_ffi'], - serialized_date=datetime.date(2024, 8, 22), + serialized_date=datetime.date(2025, 4, 2), inputs=(), - expected_outputs=( - array( - [ - [0.0 + 0.0j, 0.91287076 + 0.0j, 0.4082487 + 0.0j], - [-0.44721356 - 0.0j, 0.36514866 + 0.0j, -0.8164965 + 0.0j], - [-0.8944271 - 0.0j, -0.18257445 + 0.0j, 0.40824816 + 0.0j], - ], - dtype=complex64, - ), - array( - [ - [ - -6.7082043e00 + 0.0j, - -8.0498438e00 + 0.0j, - -9.3914852e00 + 0.0j, - ], - [0.0000000e00 + 0.0j, 1.0954441e00 + 0.0j, 2.1908894e00 + 0.0j], - [ - 0.0000000e00 + 0.0j, - 0.0000000e00 + 0.0j, - 7.1525574e-07 + 0.0j, - ], - ], - dtype=complex64, - ), - ), + expected_outputs=(array([[ 0. +0.j, 0.91287076+0.j, 0.4082487 +0.j], + [-0.44721356-0.j, 0.36514866+0.j, -0.8164965 +0.j], + [-0.8944271 -0.j, -0.18257445+0.j, 0.40824816+0.j]], + dtype=complex64), array([[-6.7082043e+00+0.j, -8.0498438e+00+0.j, -9.3914852e+00+0.j], + [ 0.0000000e+00+0.j, 1.0954441e+00+0.j, 2.1908894e+00+0.j], + [ 0.0000000e+00+0.j, 0.0000000e+00+0.j, 7.1525574e-07+0.j]], + dtype=complex64)), mlir_module_text=r""" -#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":364:11) -#loc10 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3)) module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<3x3xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<3x3xcomplex> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + func.func public @main() -> (tensor<3x3xcomplex> {jax.result_info = "result[0]"}, tensor<3x3xcomplex> {jax.result_info = "result[1]"}) { + %c = stablehlo.constant dense<-1> : tensor loc(#loc) + %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc) %0 = stablehlo.iota dim = 0 : tensor<9xcomplex> loc(#loc4) %1 = stablehlo.reshape %0 : (tensor<9xcomplex>) -> tensor<3x3xcomplex> loc(#loc5) - %c = stablehlo.constant dense<3> : tensor loc(#loc6) - %c_0 = stablehlo.constant dense<3> : tensor loc(#loc6) - %2:2 = stablehlo.custom_call @lapack_cgeqrf_ffi(%1) {mhlo.backend_config = {}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xcomplex>) -> (tensor<3x3xcomplex>, tensor<3xcomplex>) loc(#loc6) - %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc7) - %3 = stablehlo.pad %2#0, %cst, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xcomplex>, tensor>) -> tensor<3x3xcomplex> loc(#loc8) - %c_1 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_2 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_3 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_4 = stablehlo.constant dense<1> : tensor loc(#loc9) - %c_5 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_6 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_7 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_8 = stablehlo.constant dense<96> : tensor loc(#loc9) - %4:3 = stablehlo.custom_call @lapack_cungqr(%c_4, %c_5, %c_6, %c_7, %c_8, %3, %2#1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<3x3xcomplex>, tensor<3xcomplex>) -> (tensor<3x3xcomplex>, tensor, tensor<96xcomplex>) loc(#loc9) - %c_9 = stablehlo.constant dense<0> : tensor loc(#loc9) - %5 = stablehlo.broadcast_in_dim %c_9, dims = [] : (tensor) -> tensor loc(#loc9) - %6 = stablehlo.compare EQ, %4#1, %5, SIGNED : (tensor, tensor) -> tensor loc(#loc9) - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc9) - %cst_10 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc9) - %8 = stablehlo.broadcast_in_dim %cst_10, dims = [] : (tensor>) -> tensor<3x3xcomplex> loc(#loc9) - %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc9) - %10 = stablehlo.select %9, %4#0, %8 : tensor<3x3xi1>, tensor<3x3xcomplex> loc(#loc9) - %11 = call @triu(%2#0) : (tensor<3x3xcomplex>) -> tensor<3x3xcomplex> loc(#loc10) - return %10, %11 : tensor<3x3xcomplex>, tensor<3x3xcomplex> loc(#loc) + %2:2 = stablehlo.custom_call @lapack_cgeqrf_ffi(%1) {mhlo.backend_config = {}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xcomplex>) -> (tensor<3x3xcomplex>, tensor<3xcomplex>) loc(#loc6) + %3 = stablehlo.pad %2#0, %cst, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xcomplex>, tensor>) -> tensor<3x3xcomplex> loc(#loc7) + %4 = stablehlo.custom_call @lapack_cungqr_ffi(%3, %2#1) {mhlo.backend_config = {}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<3x3xcomplex>, tensor<3xcomplex>) -> tensor<3x3xcomplex> loc(#loc8) + %5 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc9) + %6 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc10) + %7 = stablehlo.add %5, %6 : tensor<3x3xi32> loc(#loc10) + %8 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc9) + %9 = stablehlo.compare GE, %7, %8, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc11) + %10 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<3x3xcomplex> loc(#loc12) + %11 = stablehlo.select %9, %10, %2#0 : tensor<3x3xi1>, tensor<3x3xcomplex> loc(#loc13) + return %4, %11 : tensor<3x3xcomplex>, tensor<3x3xcomplex> loc(#loc) } loc(#loc) - func.func private @triu(%arg0: tensor<3x3xcomplex> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3))) -> (tensor<3x3xcomplex> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc11) - %c = stablehlo.constant dense<-1> : tensor loc(#loc10) - %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc12) - %2 = stablehlo.add %0, %1 : tensor<3x3xi32> loc(#loc12) - %3 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc13) - %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc14) - %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc10) - %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<3x3xcomplex> loc(#loc15) - %6 = stablehlo.select %4, %5, %arg0 : tensor<3x3xi1>, tensor<3x3xcomplex> loc(#loc16) - return %6 : tensor<3x3xcomplex> loc(#loc10) - } loc(#loc10) } loc(#loc) #loc = loc(unknown) -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:26) -#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:14) -#loc4 = loc("jit()/jit(main)/iota[dtype=complex64 shape=(9,) dimension=0]"(#loc1)) -#loc5 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc2)) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":411:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":411:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":412:11) +#loc4 = loc("jit()/jit(main)/iota"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape"(#loc2)) #loc6 = loc("jit()/jit(main)/geqrf"(#loc3)) -#loc7 = loc("jit()/jit(main)/qr[full_matrices=True]"(#loc3)) -#loc8 = loc("jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]"(#loc3)) -#loc9 = loc("jit()/jit(main)/householder_product"(#loc3)) -#loc11 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]"(#loc3)) -#loc12 = loc("jit()/jit(main)/jit(triu)/add"(#loc3)) -#loc13 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]"(#loc3)) -#loc14 = loc("jit()/jit(main)/jit(triu)/ge"(#loc3)) -#loc15 = loc("jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]"(#loc3)) -#loc16 = loc("jit()/jit(main)/jit(triu)/select_n"(#loc3)) +#loc7 = loc("jit()/jit(main)/pad"(#loc3)) +#loc8 = loc("jit()/jit(main)/householder_product"(#loc3)) +#loc9 = loc("jit()/jit(main)/iota"(#loc3)) +#loc10 = loc("jit()/jit(main)/add"(#loc3)) +#loc11 = loc("jit()/jit(main)/ge"(#loc3)) +#loc12 = loc("jit()/jit(main)/broadcast_in_dim"(#loc3)) +#loc13 = loc("jit()/jit(main)/select_n"(#loc3)) """, - mlir_module_serialized=( - b"ML\xefR\x01StableHLO_v0.9.0\x00\x01)\x05\x01\x03\x01\x03\x05\x03\x19\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x03\xae\x02\x12\x025\x01\x9d\x0f\x17\x0b\x0f\x13\x13\x0b\x07\x0b\x0f\x13\x0f\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0f\x0b\x17\x0bS\x0b\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0b\x13\x13K\x13\x1b\x13\x03g\x0fO\x0b\x0b\x0b//\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0b//\x0b\x0b\x0b\x0f\x0f\x17\x13\x1f\x1f\x1f\x0b\x0b'\x0f\x17\x17\x1f\x0b/O\x01\x07\x17\x17\x0b\x01\x05\x0b\x0f\x031\x17\x0f\x0f\x0b\x07\x0f\x17\x07\x07\x07\x17\x13\x17\x07\x17\x13\x13\x13\x13\x13\x17\x13\x0f\x17\x02\xb2\t\x1d\x8f\x03\x17\x11\xb2\x05\x17\x05\x1f\x1dO\x03\x03\x03%\xd1\x03\x03\x05\xd9\x05!\x1f\x05#\x1dy\x03\x03\x03\x05\xeb\x11\x03\x05\x05%\x05'\x05)\x05+\x03\x03#\xcd\x05-\x05/\x1dW\x03\x051\x053\x03\x03\x05\xd7\x055\x057\x059\x05;\x05=\x05?\x05A\x05C\x03\tACE\x17G\x17\rI\x05E\x11\x01\x00\x05G\x05I\x05K\x03\x0b\x19\xa1\x1b\xb7\x1d\xb9\r\xc3\x1f\xc5\x03\x0b\x19\xad\x1b\xc9\x1d\xad\r\xaf\x1f\xcb\x05M\x1dS\x03\x05O\x03\x03\x05\xcf\x05Q\x03\x03#\xd3\x1d]\x03\x05S\x03\x05)\xb1+\xd5\x1dc\x03\x05U\x1dg\x03\x05W\x1dk\x03\x05Y\x1doq\x05[\x17\x11\xae\x055\x1duw\x05]\x17\x11\xae\x05\x1d\x05_\x03\x13/\xdb1\xb33\xdd5\xa17\xb5}\xdf9\xe1;\xe3=\xe7\x05a\x1d\x81\x03\x05c\x03\x07\x85\xa9\x87\xa9\x89\xa9\x05e\x05g\x05i\x1d\x8d\x03\x05k\x05m\x03\x03\x05\xe9\x03\x03\x05\xed\x03\x11/\xef1\xb33\xf15\xa17\xb59\xf3;\xf5=\xf9\x03\x03\x05\xfb\x03\x05)\xb1+\xfd\x03\x03\x05\xff\x1f/\x01\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1do\x1dq\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1b\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1ds\x03\x03\xc7\x1du\t\x07\x1dw\x05\x01#\x1d\x03\x05\xbb\xbf\r\x05\xab\xbd\xa3\xa5\x1dy\r\x05\xab\xc1\xa3\xa5\x1d{\x1d}\x1d\x7f\r\x03\xa3\xa5#!\x1d\x81\x13\r\x01\x1f\x07\t\xff\xff\xff\xff\x1f#\x01\x13\r\x05\x07\x05\x1f\x0f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\t\x11\x03\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1d\x83\r\x01\x03\x03\x9f\x03\x03\xe5\x15\x03\x01\x01\x01\x03\x05\x9f\xa7\x1f\x07\t\x01\x00\x00\x00\x1f\x07\t\x03\x00\x00\x00\x1f\x07\t`\x00\x00\x00\x0b\x05\x1d\x85\x03\x0f\x9d\x9d\x9d\x9d\x9d\x9f\xa7\x03\x03\xf7\x15\x03\x01\x15\x01\x03\x07\x9f\x9d\xa7\x1f\x07\t\x00\x00\x00\x00\x07\x01\x1f\x0f\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03%\x02\x02\x03\x03\x0e\x02\xaf\x05\x87\x01\t\x01\x02\x02)\x05\r\r\x0b)\x01\x17)\x01\r\x03\x1f\x1d)\x01\x0b)\x05\r\r\x17\x01\x13\x1b)\x05\r\r\x13)\x03\t\r\x11\x01\x05\x05\x05\t\x11\x03\x05\x03\x05)\x03\x01\r)\x03%\x0b)\x03\r\x0b)\x03\t\x15)\x03\x05\x15)\x03\x02\x03\x0b)\x03\x01\x15)\x01\x13)\x05\x05\x05\x13\x04\x8a\x04\x05\x01\x11\x0f?\x07\x03\x01\t\t\x11\x0fK\x07\x039i\x07\x03m!\x03%\x15\x06s\x03\x05\x03\x01\x03\x03\x13\x0b\x03\t\x03\x03\x13\x0b\x03\t\x11\x07\x13{\x05\x05'\x03\x03\x03\x03\x7f-\x03\x0f\x17\x07\x8b\x83\x03\x05\x05\t\r\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x91\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x93\x03\x07\x11\x07\x01\x95\x07\x05\x07-\x0f\x17\x19\x1b\x1d\x1f\x0f\x0b\x03\x03\x01\x97\x03\x07\x05\x07\x01\t\x03\x07\x03'\x0b\x07\x01\x99\x031\x05#)\x05\x07\x01\t\x033\x03+\x03\x03\x01\x9b\x03\x0f\x05\x07\x01\t\x03\x05\x03/\x05\x07\x01\x06\x02\x03\x19\x03-\r\x06\x01\x03\x05\x073!1\x19\x07\x07\n\x02\x03\x05\x03\t\x0f\x04\x0f\x0557\t\x11\x07M\x07\x03\x15+\x03\x05\x07\x07\x03Q!\x03\x11\x03\x03\x07U\x03\x07\x05\x07'\t\x03\x11\x03\x05\x13\x06'\x03\x11\x05\x03\x07\x07\x03[Y\x03\x11\x0b\x07a_\x03\x19\x05\t\x0b\x03\x03\x07-\x03\x0f\x05\x07e\t\x03\x05\x03\x0f\r\x06i\x03\x05\x07\r\x11\x01\x0f\x04\x07\x03\x13\x06\x03\x01\x05\x01\x00\xa6\x1a\x89\x0f\x1d%\x11\x0f\x0b\t\t\x03\x0b!\x11#Y\x87##%_)=\x85\x8bW\xb3K\x9bM\x9bn\x03\x1b%)9\x1f/!!)#\x1f\x19+\x1b+\x1f\x1f\x15\x1d\x15i\x13\r\x11\x0f\x17\x0f\x1f\x15\x15\x17\x11\x11)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00iota_v1\x00func_v1\x00compare_v1\x00select_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00sym_name\x00third_party/py/jax/tests/export_back_compat_test.py\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,)" - b' out_shardings=(UnspecifiedValue,) in_layouts=(None,)' - b' out_layouts=(None,) resource_env=None donated_invars=(False,)' - b' name=triu keep_unused=False' - b' inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' - b' shape=(3, 3)' - b' dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' - b' shape=(3, 3)' - b' dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3,' - b' 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=complex64' - b' shape=(9,)' - b' dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3)' - b' dimensions=None]\x00jit()/jit(main)/geqrf\x00mhlo.backend_config\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0,' - b' 0, 0), (0, 0,' - b' 0))]\x00jit()/jit(main)/householder_product\x00mhlo.layout_mode\x00default\x00jax.result_info\x00triu\x00\x00[0]\x00[1]\x00main\x00public\x00private\x00lapack_cgeqrf_ffi\x00lapack_cungqr\x00callee\x00' - ), + mlir_module_serialized=b'ML\xefR\rStableHLO_v1.9.3\x00\x01)\x05\x01\x05\x19\x01\x03\x0b\x03\x17\x0f\x13\x17\x1b\x1f#\'+/37\x03\xc7\x8b)\x01E\x17\x07\x0b\x0f\x0b\x1b\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x17\x0f\x0b\x17\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x03G\x0b/\x0b\x0f\x0b\x0b\x0b\x0fO\x13\x0f\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x1f/\x0b\x13\x0b\x0b\x0b\x0f\x17/\x0b\x0f\x13\x0f\x0b\x0b\x01\x05\x0b\x0f\x03%\x17\x0b\x07\x17\x0f\x07\x0f\x07\x17\x07\x13\x13\x13\x13\x13\x13\x17\x07\x02\xb2\x04\x17\x05r\x06\x17\x1f\x05\x1d\x11\x03\x05\x05\x1f\x03\x05\'o)q\x1d\t\x01\x1d7\x01\x03\x07\x13\x15\x17\x07\x19\x07\x05!\x11\x01\x00\x05#\x05%\x05\'\x1d\t\x1f\x17\x05n\x065\x1d#%\x05)\x17\x05n\x06\x1d\x05+\x05-\x1d-\x01\x05/\x1d1\x01\x051\x1d5\x01\x053\x055\x1d;\x01\x057\x1d?\x01\x059\x1dC\x01\x05;\x03\x01\x1f!\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d=\x13\t\x01\x0b\x03\x1d?\x05\x01\x03\x03U\x1f\x1d!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x05U}\x1f#\x01#\x15\x03\x05_c\r\x03Ia\x1dA\r\x03Ie\x1dC\x1dE\x1dG\x1f\r\t\xff\xff\xff\xff\x1f\x11\x11\x00\x00\x00\x00\x00\x00\x00\x00\r\x01\r\x03su\x1dI\x1dK\x1dM\x03\x03{\x15\x03\x01\x01\x01\x1f\x1f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dO\x03\x03\x83\x15\x01\x01\x01\x13\t\x05\t\x07\x07\x05\x01\t\x01\x02\x02)\x05\r\r\x07\x03\x17\x1d)\x05\r\r\x0f)\x01\x0f\x1b)\x01\x07\x13\x11\x01\x05\x05\x05\t)\x03%\x07)\x03\r\x07)\x03\t\x13)\x03\x05\x13)\x03\t\t)\x03\x01\t)\x05\r\r\'\x01\x04"\x02\x05\x01Q\x03\x11\x01\x07\x04\xff\x03\x01\x05\x0bP\x03\x03\x07\x04\xeb\x03\x1f=\x05B\x03\x05\x03\r\x05B\x03\x07\x03\x11\x03B\x1d\t\x03\x19\r\x06!\x03\x05\x03\x05\x07G+\x0b\x0b\x05\x05\x1b\x03\x07\x0fF/\r\x03\x05\x05\t\x03\x07G3\x0b\x0f\x03\x05\x05\r\x0b\x03B\r\t\x03\x0b\tF\x0f\x11\x03\x0b\x03\x01\x11\x06\x0f\x03\x0b\x05\x11\x13\x03B\r\x13\x03\x0b\x13F9\x15\x03%\x05\x15\x17\tF=\x11\x03\x05\x03\x03\x15\x06A\x03\x05\x07\x19\x1b\t\x17\x04\x03\x05\x0f\x1d\x06\x03\x01\x05\x01\x00\xba\x0bQ%%\x05\x1f\x0f\x0b\x15\x15\x03!CS79Y9=3)A\x1b%)9;i\x15\x15\x17\x0f\x0f\x17\x11)\x1f\x19\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00iota_v1\x00constant_v1\x00custom_call_v1\x00broadcast_in_dim_v1\x00func_v1\x00reshape_v1\x00pad_v1\x00add_v1\x00compare_v1\x00select_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit()/jit(main)/iota\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/reshape\x00mhlo.backend_config\x00mhlo.frontend_attributes\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/pad\x00jit()/jit(main)/householder_product\x00jit()/jit(main)/add\x00jit()/jit(main)/ge\x00jit()/jit(main)/broadcast_in_dim\x00jit()/jit(main)/select_n\x00jax.result_info\x00\x00result[0]\x00result[1]\x00main\x00public\x00num_batch_dims\x000\x00lapack_cgeqrf_ffi\x00lapack_cungqr_ffi\x00\x08[\x17\x057\x01\x0bE[]gi\x03k\x03m\x03K\x11MOwEQSyW\x07GGG\x11MO\x7fEQW\x81S\x03Y\x03\x85\x05\x87\x89', xla_call_module_version=9, nr_devices=1, ) # End paste - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_22['f32'] = dict( +data_2025_04_02['f32'] = dict( testdata_version=1, platform='cpu', custom_call_targets=['lapack_sgeqrf_ffi', 'lapack_sorgqr_ffi'], - serialized_date=datetime.date(2024, 8, 22), + serialized_date=datetime.date(2025, 4, 2), inputs=(), - expected_outputs=( - array( - [ - [0.0, 0.91287076, 0.4082487], - [-0.44721356, 0.36514866, -0.8164965], - [-0.8944271, -0.18257445, 0.40824816], - ], - dtype=float32, - ), - array( - [ - [-6.7082043e00, -8.0498438e00, -9.3914852e00], - [0.0000000e00, 1.0954441e00, 2.1908894e00], - [0.0000000e00, 0.0000000e00, 7.1525574e-07], - ], - dtype=float32, - ), - ), + expected_outputs=(array([[ 0. , 0.91287076, 0.4082487 ], + [-0.44721356, 0.36514866, -0.8164965 ], + [-0.8944271 , -0.18257445, 0.40824816]], dtype=float32), array([[-6.7082043e+00, -8.0498438e+00, -9.3914852e+00], + [ 0.0000000e+00, 1.0954441e+00, 2.1908894e+00], + [ 0.0000000e+00, 0.0000000e+00, 7.1525574e-07]], dtype=float32)), mlir_module_text=r""" -#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":364:11) -#loc10 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3)) module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<3x3xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<3x3xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + func.func public @main() -> (tensor<3x3xf32> {jax.result_info = "result[0]"}, tensor<3x3xf32> {jax.result_info = "result[1]"}) { + %c = stablehlo.constant dense<-1> : tensor loc(#loc) + %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc) %0 = stablehlo.iota dim = 0 : tensor<9xf32> loc(#loc4) %1 = stablehlo.reshape %0 : (tensor<9xf32>) -> tensor<3x3xf32> loc(#loc5) - %c = stablehlo.constant dense<3> : tensor loc(#loc6) - %c_0 = stablehlo.constant dense<3> : tensor loc(#loc6) - %2:2 = stablehlo.custom_call @lapack_sgeqrf_ffi(%1) {mhlo.backend_config = {}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xf32>) -> (tensor<3x3xf32>, tensor<3xf32>) loc(#loc6) - %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc7) - %3 = stablehlo.pad %2#0, %cst, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xf32>, tensor) -> tensor<3x3xf32> loc(#loc8) - %c_1 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_2 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_3 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_4 = stablehlo.constant dense<1> : tensor loc(#loc9) - %c_5 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_6 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_7 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_8 = stablehlo.constant dense<96> : tensor loc(#loc9) - %4:3 = stablehlo.custom_call @lapack_sorgqr(%c_4, %c_5, %c_6, %c_7, %c_8, %3, %2#1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<3x3xf32>, tensor<3xf32>) -> (tensor<3x3xf32>, tensor, tensor<96xf32>) loc(#loc9) - %c_9 = stablehlo.constant dense<0> : tensor loc(#loc9) - %5 = stablehlo.broadcast_in_dim %c_9, dims = [] : (tensor) -> tensor loc(#loc9) - %6 = stablehlo.compare EQ, %4#1, %5, SIGNED : (tensor, tensor) -> tensor loc(#loc9) - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc9) - %cst_10 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc9) - %8 = stablehlo.broadcast_in_dim %cst_10, dims = [] : (tensor) -> tensor<3x3xf32> loc(#loc9) - %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc9) - %10 = stablehlo.select %9, %4#0, %8 : tensor<3x3xi1>, tensor<3x3xf32> loc(#loc9) - %11 = call @triu(%2#0) : (tensor<3x3xf32>) -> tensor<3x3xf32> loc(#loc10) - return %10, %11 : tensor<3x3xf32>, tensor<3x3xf32> loc(#loc) + %2:2 = stablehlo.custom_call @lapack_sgeqrf_ffi(%1) {mhlo.backend_config = {}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xf32>) -> (tensor<3x3xf32>, tensor<3xf32>) loc(#loc6) + %3 = stablehlo.pad %2#0, %cst, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xf32>, tensor) -> tensor<3x3xf32> loc(#loc7) + %4 = stablehlo.custom_call @lapack_sorgqr_ffi(%3, %2#1) {mhlo.backend_config = {}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<3x3xf32>, tensor<3xf32>) -> tensor<3x3xf32> loc(#loc8) + %5 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc9) + %6 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc10) + %7 = stablehlo.add %5, %6 : tensor<3x3xi32> loc(#loc10) + %8 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc9) + %9 = stablehlo.compare GE, %7, %8, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc11) + %10 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<3x3xf32> loc(#loc12) + %11 = stablehlo.select %9, %10, %2#0 : tensor<3x3xi1>, tensor<3x3xf32> loc(#loc13) + return %4, %11 : tensor<3x3xf32>, tensor<3x3xf32> loc(#loc) } loc(#loc) - func.func private @triu(%arg0: tensor<3x3xf32> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3))) -> (tensor<3x3xf32> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc11) - %c = stablehlo.constant dense<-1> : tensor loc(#loc10) - %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc12) - %2 = stablehlo.add %0, %1 : tensor<3x3xi32> loc(#loc12) - %3 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc13) - %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc14) - %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc10) - %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<3x3xf32> loc(#loc15) - %6 = stablehlo.select %4, %5, %arg0 : tensor<3x3xi1>, tensor<3x3xf32> loc(#loc16) - return %6 : tensor<3x3xf32> loc(#loc10) - } loc(#loc10) } loc(#loc) #loc = loc(unknown) -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:26) -#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:14) -#loc4 = loc("jit()/jit(main)/iota[dtype=float32 shape=(9,) dimension=0]"(#loc1)) -#loc5 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc2)) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":411:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":411:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":412:11) +#loc4 = loc("jit()/jit(main)/iota"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape"(#loc2)) #loc6 = loc("jit()/jit(main)/geqrf"(#loc3)) -#loc7 = loc("jit()/jit(main)/qr[full_matrices=True]"(#loc3)) -#loc8 = loc("jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]"(#loc3)) -#loc9 = loc("jit()/jit(main)/householder_product"(#loc3)) -#loc11 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]"(#loc3)) -#loc12 = loc("jit()/jit(main)/jit(triu)/add"(#loc3)) -#loc13 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]"(#loc3)) -#loc14 = loc("jit()/jit(main)/jit(triu)/ge"(#loc3)) -#loc15 = loc("jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]"(#loc3)) -#loc16 = loc("jit()/jit(main)/jit(triu)/select_n"(#loc3)) +#loc7 = loc("jit()/jit(main)/pad"(#loc3)) +#loc8 = loc("jit()/jit(main)/householder_product"(#loc3)) +#loc9 = loc("jit()/jit(main)/iota"(#loc3)) +#loc10 = loc("jit()/jit(main)/add"(#loc3)) +#loc11 = loc("jit()/jit(main)/ge"(#loc3)) +#loc12 = loc("jit()/jit(main)/broadcast_in_dim"(#loc3)) +#loc13 = loc("jit()/jit(main)/select_n"(#loc3)) """, - mlir_module_serialized=( - b"ML\xefR\x01StableHLO_v0.9.0\x00\x01)\x05\x01\x03\x01\x03\x05\x03\x19\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x03\xaa\x02\x12\x023\x01\x9d\x0f\x17\x0b\x0f\x13\x13\x0b\x07\x0b\x0f\x13\x0f\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0f\x0b\x17\x0bS\x0b\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0b\x13\x13K\x13\x1b\x13\x03g\x0fO\x0b\x0b\x0b//\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0b\x1f/\x0b\x0b\x0b\x0f\x0f\x17\x13\x1f\x1f\x1f\x0b\x0b'\x0f\x17\x17\x1f\x0b\x1fO\x01\x07\x17\x17\x0b\x01\x05\x0b\x0f\x03/\x17\x0f\x0f\x07\x07\x0f\x17\x07\x07\x07\x17\x13\x17\x17\x13\x13\x13\x13\x13\x17\x13\x0f\x17\x02\x8a\t\x1d\x8f\x03\x17\x11\xb2\x05\x17\x05\x1f\x1dO\x03\x03\x03%\xd1\x03\x03\x05\xd9\x05!\x1f\x05#\x1dy\x03\x03\x03\x05\xeb\x11\x03\x05\x05%\x05'\x05)\x05+\x03\x03#\xcd\x05-\x05/\x1dW\x03\x051\x053\x03\x03\x05\xd7\x055\x057\x059\x05;\x05=\x05?\x05A\x05C\x03\tACE\x17G\x17\rI\x05E\x11\x01\x00\x05G\x05I\x05K\x03\x0b\x19\xa1\x1b\xb7\x1d\xb9\r\xc3\x1f\xc5\x03\x0b\x19\xad\x1b\xc9\x1d\xad\r\xaf\x1f\xcb\x05M\x1dS\x03\x05O\x03\x03\x05\xcf\x05Q\x03\x03#\xd3\x1d]\x03\x05S\x03\x05)\xb1+\xd5\x1dc\x03\x05U\x1dg\x03\x05W\x1dk\x03\x05Y\x1doq\x05[\x17\x11\xae\x055\x1duw\x05]\x17\x11\xae\x05\x1d\x05_\x03\x13/\xdb1\xb33\xdd5\xa17\xb5}\xdf9\xe1;\xe3=\xe7\x05a\x1d\x81\x03\x05c\x03\x07\x85\xa9\x87\xa9\x89\xa9\x05e\x05g\x05i\x1d\x8d\x03\x05k\x05m\x03\x03\x05\xe9\x03\x03\x05\xed\x03\x11/\xef1\xb33\xf15\xa17\xb59\xf3;\xf5=\xf9\x03\x03\x05\xfb\x03\x05)\xb1+\xfd\x03\x03\x05\xff\x1f-\x01\x1f'!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1do\x1dq\x1f)\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1b\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1ds\x03\x03\xc7\x1du\t\x07\x1dw\x05\x01#\x1d\x03\x05\xbb\xbf\r\x05\xab\xbd\xa3\xa5\x1dy\r\x05\xab\xc1\xa3\xa5\x1d{\x1d}\x1d\x7f\r\x03\xa3\xa5#\x1f\x1d\x81\x13\r\x01\x1f\x07\t\xff\xff\xff\xff\x1f!\x01\x13\r\x05\x07\x05\x1f\x0f\t\x00\x00\x00\x00\x1f\t\x11\x03\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1d\x83\r\x01\x03\x03\x9f\x03\x03\xe5\x15\x03\x01\x01\x01\x03\x05\x9f\xa7\x1f\x07\t\x01\x00\x00\x00\x1f\x07\t\x03\x00\x00\x00\x1f\x07\t`\x00\x00\x00\x0b\x05\x1d\x85\x03\x0f\x9d\x9d\x9d\x9d\x9d\x9f\xa7\x03\x03\xf7\x15\x03\x01\x15\x01\x03\x07\x9f\x9d\xa7\x1f\x07\t\x00\x00\x00\x00\x07\x01\x1f\x0f\t\x00\x00\xc0\x7f\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03%\x02\x02\x03\x03\x0e\x02\xaf\x05\x87\x01\t\x01\x02\x02)\x05\r\r\x0b)\x01\x17)\x01\r\t\x1d)\x01\x0b)\x05\r\r\x17\x01\x13\x1b)\x05\r\r\x13)\x03\t\r\x11\x01\x05\x05\x05\x11\x03\x05\x03\x05)\x03\x01\r)\x03%\x0b)\x03\r\x0b)\x03\t\x15)\x03\x05\x15)\x03\x02\x03\x0b)\x03\x01\x15)\x01\x13)\x05\x05\x05\x13\x04\x8a\x04\x05\x01\x11\x0f?\x07\x03\x01\t\t\x11\x0fK\x07\x039i\x07\x03m!\x03#\x15\x06s\x03\x05\x03\x01\x03\x03\x13\x0b\x03\t\x03\x03\x13\x0b\x03\t\x11\x07\x13{\x05\x05%\x03\x03\x03\x03\x7f-\x03\x0f\x17\x07\x8b\x83\x03\x05\x05\t\r\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x91\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x93\x03\x07\x11\x07\x01\x95\x07\x05\x07+\x0f\x17\x19\x1b\x1d\x1f\x0f\x0b\x03\x03\x01\x97\x03\x07\x05\x07\x01\t\x03\x07\x03'\x0b\x07\x01\x99\x03/\x05#)\x05\x07\x01\t\x031\x03+\x03\x03\x01\x9b\x03\x0f\x05\x07\x01\t\x03\x05\x03/\x05\x07\x01\x06\x02\x03\x19\x03-\r\x06\x01\x03\x05\x073!1\x19\x07\x07\n\x02\x03\x05\x03\t\x0f\x04\x0f\x0557\t\x11\x07M\x07\x03\x15+\x03\x05\x07\x07\x03Q!\x03\x11\x03\x03\x07U\x03\x07\x05\x07'\t\x03\x11\x03\x05\x13\x06'\x03\x11\x05\x03\x07\x07\x03[Y\x03\x11\x0b\x07a_\x03\x19\x05\t\x0b\x03\x03\x07-\x03\x0f\x05\x07e\t\x03\x05\x03\x0f\r\x06i\x03\x05\x07\r\x11\x01\x0f\x04\x07\x03\x13\x06\x03\x01\x05\x01\x00\x9e\x1a\x89\x0f\x1d%\x11\x0f\x0b\t\t\x03\x0b!\x11#Y\x87##%_)=\x85\x87W\xb3K\x9bM\x9bn\x03\x1b%)9\x1f/!!)#\x1f\x19+\x1b+\x1f\x1f\x15\x1d\x15i\x13\r\x11\x0f\x17\x0f\x1f\x15\x15\x17\x11\x11)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00iota_v1\x00func_v1\x00compare_v1\x00select_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00sym_name\x00third_party/py/jax/tests/export_back_compat_test.py\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,)" - b' out_shardings=(UnspecifiedValue,) in_layouts=(None,)' - b' out_layouts=(None,) resource_env=None donated_invars=(False,)' - b' name=triu keep_unused=False' - b' inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' - b' shape=(3, 3)' - b' dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' - b' shape=(3, 3)' - b' dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3,' - b' 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=float32' - b' shape=(9,)' - b' dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3)' - b' dimensions=None]\x00jit()/jit(main)/geqrf\x00mhlo.backend_config\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0,' - b' 0, 0), (0, 0,' - b' 0))]\x00jit()/jit(main)/householder_product\x00mhlo.layout_mode\x00default\x00jax.result_info\x00triu\x00\x00[0]\x00[1]\x00main\x00public\x00private\x00lapack_sgeqrf_ffi\x00lapack_sorgqr\x00callee\x00' - ), + mlir_module_serialized=b'ML\xefR\rStableHLO_v1.9.3\x00\x01)\x05\x01\x05\x19\x01\x03\x0b\x03\x17\x0f\x13\x17\x1b\x1f#\'+/37\x03\xc5\x8b\'\x01E\x17\x07\x0b\x0f\x0b\x1b\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x17\x0f\x0b\x17\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x03G\x0b/\x0b\x0f\x0b\x0b\x0b\x0fO\x13\x0f\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x1f\x1f\x0b\x13\x0b\x0b\x0b\x0f\x17/\x0b\x0f\x13\x0f\x0b\x0b\x01\x05\x0b\x0f\x03#\x17\x07\x07\x17\x0f\x07\x0f\x07\x17\x13\x13\x13\x13\x13\x13\x17\x07\x02\x9a\x04\x17\x05r\x06\x17\x1f\x05\x1d\x11\x03\x05\x05\x1f\x03\x05\'o)q\x1d\t\x01\x1d7\x01\x03\x07\x13\x15\x17\x07\x19\x07\x05!\x11\x01\x00\x05#\x05%\x05\'\x1d\t\x1f\x17\x05n\x065\x1d#%\x05)\x17\x05n\x06\x1d\x05+\x05-\x1d-\x01\x05/\x1d1\x01\x051\x1d5\x01\x053\x055\x1d;\x01\x057\x1d?\x01\x059\x1dC\x01\x05;\x03\x01\x1f\x1f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d=\x13\t\x01\x0b\x03\x1d?\x05\x01\x03\x03U\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x05U}\x1f!\x01#\x15\x03\x05_c\r\x03Ia\x1dA\r\x03Ie\x1dC\x1dE\x1dG\x1f\r\t\xff\xff\xff\xff\x1f\x11\t\x00\x00\x00\x00\r\x01\r\x03su\x1dI\x1dK\x1dM\x03\x03{\x15\x03\x01\x01\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dO\x03\x03\x83\x15\x01\x01\x01\x13\t\x05\t\x07\x07\x05\x01\t\x01\x02\x02)\x05\r\r\x07\t\x1d)\x05\r\r\x0f)\x01\x0f\x1b)\x01\x07\x13\x11\x01\x05\x05\x05)\x03%\x07)\x03\r\x07)\x03\t\x13)\x03\x05\x13)\x03\t\t)\x03\x01\t)\x05\r\r%\x01\x04"\x02\x05\x01Q\x03\x11\x01\x07\x04\xff\x03\x01\x05\x0bP\x03\x03\x07\x04\xeb\x03\x1f=\x05B\x03\x05\x03\r\x05B\x03\x07\x03\x11\x03B\x1d\t\x03\x17\r\x06!\x03\x05\x03\x05\x07G+\x0b\x0b\x05\x05\x19\x03\x07\x0fF/\r\x03\x05\x05\t\x03\x07G3\x0b\x0f\x03\x05\x05\r\x0b\x03B\r\t\x03\x0b\tF\x0f\x11\x03\x0b\x03\x01\x11\x06\x0f\x03\x0b\x05\x11\x13\x03B\r\x13\x03\x0b\x13F9\x15\x03#\x05\x15\x17\tF=\x11\x03\x05\x03\x03\x15\x06A\x03\x05\x07\x19\x1b\t\x17\x04\x03\x05\x0f\x1d\x06\x03\x01\x05\x01\x00\xba\x0bQ%%\x05\x1f\x0f\x0b\x15\x15\x03!CS79Y9=3)A\x1b%)9;i\x15\x15\x17\x0f\x0f\x17\x11)\x1f\x19\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00iota_v1\x00constant_v1\x00custom_call_v1\x00broadcast_in_dim_v1\x00func_v1\x00reshape_v1\x00pad_v1\x00add_v1\x00compare_v1\x00select_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit()/jit(main)/iota\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/reshape\x00mhlo.backend_config\x00mhlo.frontend_attributes\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/pad\x00jit()/jit(main)/householder_product\x00jit()/jit(main)/add\x00jit()/jit(main)/ge\x00jit()/jit(main)/broadcast_in_dim\x00jit()/jit(main)/select_n\x00jax.result_info\x00\x00result[0]\x00result[1]\x00main\x00public\x00num_batch_dims\x000\x00lapack_sgeqrf_ffi\x00lapack_sorgqr_ffi\x00\x08[\x17\x057\x01\x0bE[]gi\x03k\x03m\x03K\x11MOwEQSyW\x07GGG\x11MO\x7fEQW\x81S\x03Y\x03\x85\x05\x87\x89', xla_call_module_version=9, nr_devices=1, ) # End paste - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_22['f64'] = dict( +data_2025_04_02['f64'] = dict( testdata_version=1, platform='cpu', custom_call_targets=['lapack_dgeqrf_ffi', 'lapack_dorgqr_ffi'], - serialized_date=datetime.date(2024, 8, 22), + serialized_date=datetime.date(2025, 4, 2), inputs=(), - expected_outputs=( - array([ - [0.0, 0.9128709291752773, 0.40824829046386235], - [-0.447213595499958, 0.3651483716701102, -0.8164965809277263], - [-0.894427190999916, -0.1825741858350548, 0.40824829046386324], - ]), - array([ - [ - -6.7082039324993694e00, - -8.0498447189992444e00, - -9.3914855054991175e00, - ], - [ - 0.0000000000000000e00, - 1.0954451150103341e00, - 2.1908902300206665e00, - ], - [ - 0.0000000000000000e00, - 0.0000000000000000e00, - -8.8817841970012523e-16, - ], - ]), - ), + expected_outputs=(array([[ 0. , 0.9128709291752773 , 0.40824829046386235], + [-0.447213595499958 , 0.3651483716701102 , -0.8164965809277263 ], + [-0.894427190999916 , -0.1825741858350548 , 0.40824829046386324]]), array([[-6.7082039324993694e+00, -8.0498447189992444e+00, + -9.3914855054991175e+00], + [ 0.0000000000000000e+00, 1.0954451150103341e+00, + 2.1908902300206665e+00], + [ 0.0000000000000000e+00, 0.0000000000000000e+00, + -8.8817841970012523e-16]])), mlir_module_text=r""" -#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":364:11) -#loc10 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3)) module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<3x3xf64> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<3x3xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + func.func public @main() -> (tensor<3x3xf64> {jax.result_info = "result[0]"}, tensor<3x3xf64> {jax.result_info = "result[1]"}) { + %c = stablehlo.constant dense<-1> : tensor loc(#loc) + %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc) %0 = stablehlo.iota dim = 0 : tensor<9xf64> loc(#loc4) %1 = stablehlo.reshape %0 : (tensor<9xf64>) -> tensor<3x3xf64> loc(#loc5) - %c = stablehlo.constant dense<3> : tensor loc(#loc6) - %c_0 = stablehlo.constant dense<3> : tensor loc(#loc6) - %2:2 = stablehlo.custom_call @lapack_dgeqrf_ffi(%1) {mhlo.backend_config = {}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xf64>) -> (tensor<3x3xf64>, tensor<3xf64>) loc(#loc6) - %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc7) - %3 = stablehlo.pad %2#0, %cst, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xf64>, tensor) -> tensor<3x3xf64> loc(#loc8) - %c_1 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_2 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_3 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_4 = stablehlo.constant dense<1> : tensor loc(#loc9) - %c_5 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_6 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_7 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_8 = stablehlo.constant dense<96> : tensor loc(#loc9) - %4:3 = stablehlo.custom_call @lapack_dorgqr(%c_4, %c_5, %c_6, %c_7, %c_8, %3, %2#1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<3x3xf64>, tensor<3xf64>) -> (tensor<3x3xf64>, tensor, tensor<96xf64>) loc(#loc9) - %c_9 = stablehlo.constant dense<0> : tensor loc(#loc9) - %5 = stablehlo.broadcast_in_dim %c_9, dims = [] : (tensor) -> tensor loc(#loc9) - %6 = stablehlo.compare EQ, %4#1, %5, SIGNED : (tensor, tensor) -> tensor loc(#loc9) - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc9) - %cst_10 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc9) - %8 = stablehlo.broadcast_in_dim %cst_10, dims = [] : (tensor) -> tensor<3x3xf64> loc(#loc9) - %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc9) - %10 = stablehlo.select %9, %4#0, %8 : tensor<3x3xi1>, tensor<3x3xf64> loc(#loc9) - %11 = call @triu(%2#0) : (tensor<3x3xf64>) -> tensor<3x3xf64> loc(#loc10) - return %10, %11 : tensor<3x3xf64>, tensor<3x3xf64> loc(#loc) + %2:2 = stablehlo.custom_call @lapack_dgeqrf_ffi(%1) {mhlo.backend_config = {}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xf64>) -> (tensor<3x3xf64>, tensor<3xf64>) loc(#loc6) + %3 = stablehlo.pad %2#0, %cst, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xf64>, tensor) -> tensor<3x3xf64> loc(#loc7) + %4 = stablehlo.custom_call @lapack_dorgqr_ffi(%3, %2#1) {mhlo.backend_config = {}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<3x3xf64>, tensor<3xf64>) -> tensor<3x3xf64> loc(#loc8) + %5 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc9) + %6 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc10) + %7 = stablehlo.add %5, %6 : tensor<3x3xi32> loc(#loc10) + %8 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc9) + %9 = stablehlo.compare GE, %7, %8, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc11) + %10 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<3x3xf64> loc(#loc12) + %11 = stablehlo.select %9, %10, %2#0 : tensor<3x3xi1>, tensor<3x3xf64> loc(#loc13) + return %4, %11 : tensor<3x3xf64>, tensor<3x3xf64> loc(#loc) } loc(#loc) - func.func private @triu(%arg0: tensor<3x3xf64> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3))) -> (tensor<3x3xf64> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc11) - %c = stablehlo.constant dense<-1> : tensor loc(#loc10) - %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc12) - %2 = stablehlo.add %0, %1 : tensor<3x3xi32> loc(#loc12) - %3 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc13) - %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc14) - %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc10) - %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<3x3xf64> loc(#loc15) - %6 = stablehlo.select %4, %5, %arg0 : tensor<3x3xi1>, tensor<3x3xf64> loc(#loc16) - return %6 : tensor<3x3xf64> loc(#loc10) - } loc(#loc10) } loc(#loc) #loc = loc(unknown) -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:26) -#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:14) -#loc4 = loc("jit()/jit(main)/iota[dtype=float64 shape=(9,) dimension=0]"(#loc1)) -#loc5 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc2)) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":411:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":411:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":412:11) +#loc4 = loc("jit()/jit(main)/iota"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape"(#loc2)) #loc6 = loc("jit()/jit(main)/geqrf"(#loc3)) -#loc7 = loc("jit()/jit(main)/qr[full_matrices=True]"(#loc3)) -#loc8 = loc("jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]"(#loc3)) -#loc9 = loc("jit()/jit(main)/householder_product"(#loc3)) -#loc11 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]"(#loc3)) -#loc12 = loc("jit()/jit(main)/jit(triu)/add"(#loc3)) -#loc13 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]"(#loc3)) -#loc14 = loc("jit()/jit(main)/jit(triu)/ge"(#loc3)) -#loc15 = loc("jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]"(#loc3)) -#loc16 = loc("jit()/jit(main)/jit(triu)/select_n"(#loc3)) +#loc7 = loc("jit()/jit(main)/pad"(#loc3)) +#loc8 = loc("jit()/jit(main)/householder_product"(#loc3)) +#loc9 = loc("jit()/jit(main)/iota"(#loc3)) +#loc10 = loc("jit()/jit(main)/add"(#loc3)) +#loc11 = loc("jit()/jit(main)/ge"(#loc3)) +#loc12 = loc("jit()/jit(main)/broadcast_in_dim"(#loc3)) +#loc13 = loc("jit()/jit(main)/select_n"(#loc3)) """, - mlir_module_serialized=( - b"ML\xefR\x01StableHLO_v0.9.0\x00\x01)\x05\x01\x03\x01\x03\x05\x03\x19\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x03\xaa\x02\x12\x023\x01\x9d\x0f\x17\x0b\x0f\x13\x13\x0b\x07\x0b\x0f\x13\x0f\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0f\x0b\x17\x0bS\x0b\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0b\x13\x13K\x13\x1b\x13\x03g\x0fO\x0b\x0b\x0b//\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0b//\x0b\x0b\x0b\x0f\x0f\x17\x13\x1f\x1f\x1f\x0b\x0b'\x0f\x17\x17\x1f\x0b/O\x01\x07\x17\x17\x0b\x01\x05\x0b\x0f\x03/\x17\x0f\x0f\x07\x07\x0f\x17\x07\x07\x07\x17\x13\x17\x17\x13\x13\x13\x13\x13\x17\x13\x0f\x17\x02\xaa\t\x1d\x8f\x03\x17\x11\xb2\x05\x17\x05\x1f\x1dO\x03\x03\x03%\xd1\x03\x03\x05\xd9\x05!\x1f\x05#\x1dy\x03\x03\x03\x05\xeb\x11\x03\x05\x05%\x05'\x05)\x05+\x03\x03#\xcd\x05-\x05/\x1dW\x03\x051\x053\x03\x03\x05\xd7\x055\x057\x059\x05;\x05=\x05?\x05A\x05C\x03\tACE\x17G\x17\rI\x05E\x11\x01\x00\x05G\x05I\x05K\x03\x0b\x19\xa1\x1b\xb7\x1d\xb9\r\xc3\x1f\xc5\x03\x0b\x19\xad\x1b\xc9\x1d\xad\r\xaf\x1f\xcb\x05M\x1dS\x03\x05O\x03\x03\x05\xcf\x05Q\x03\x03#\xd3\x1d]\x03\x05S\x03\x05)\xb1+\xd5\x1dc\x03\x05U\x1dg\x03\x05W\x1dk\x03\x05Y\x1doq\x05[\x17\x11\xae\x055\x1duw\x05]\x17\x11\xae\x05\x1d\x05_\x03\x13/\xdb1\xb33\xdd5\xa17\xb5}\xdf9\xe1;\xe3=\xe7\x05a\x1d\x81\x03\x05c\x03\x07\x85\xa9\x87\xa9\x89\xa9\x05e\x05g\x05i\x1d\x8d\x03\x05k\x05m\x03\x03\x05\xe9\x03\x03\x05\xed\x03\x11/\xef1\xb33\xf15\xa17\xb59\xf3;\xf5=\xf9\x03\x03\x05\xfb\x03\x05)\xb1+\xfd\x03\x03\x05\xff\x1f-\x01\x1f'!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1do\x1dq\x1f)\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1b\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1ds\x03\x03\xc7\x1du\t\x07\x1dw\x05\x01#\x1d\x03\x05\xbb\xbf\r\x05\xab\xbd\xa3\xa5\x1dy\r\x05\xab\xc1\xa3\xa5\x1d{\x1d}\x1d\x7f\r\x03\xa3\xa5#\x1f\x1d\x81\x13\r\x01\x1f\x07\t\xff\xff\xff\xff\x1f!\x01\x13\r\x05\x07\x05\x1f\x0f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\t\x11\x03\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1d\x83\r\x01\x03\x03\x9f\x03\x03\xe5\x15\x03\x01\x01\x01\x03\x05\x9f\xa7\x1f\x07\t\x01\x00\x00\x00\x1f\x07\t\x03\x00\x00\x00\x1f\x07\t`\x00\x00\x00\x0b\x05\x1d\x85\x03\x0f\x9d\x9d\x9d\x9d\x9d\x9f\xa7\x03\x03\xf7\x15\x03\x01\x15\x01\x03\x07\x9f\x9d\xa7\x1f\x07\t\x00\x00\x00\x00\x07\x01\x1f\x0f\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03%\x02\x02\x03\x03\x0e\x02\xaf\x05\x87\x01\t\x01\x02\x02)\x05\r\r\x0b)\x01\x17)\x01\r\x0b\x1d)\x01\x0b)\x05\r\r\x17\x01\x13\x1b)\x05\r\r\x13)\x03\t\r\x11\x01\x05\x05\x05\x11\x03\x05\x03\x05)\x03\x01\r)\x03%\x0b)\x03\r\x0b)\x03\t\x15)\x03\x05\x15)\x03\x02\x03\x0b)\x03\x01\x15)\x01\x13)\x05\x05\x05\x13\x04\x8a\x04\x05\x01\x11\x0f?\x07\x03\x01\t\t\x11\x0fK\x07\x039i\x07\x03m!\x03#\x15\x06s\x03\x05\x03\x01\x03\x03\x13\x0b\x03\t\x03\x03\x13\x0b\x03\t\x11\x07\x13{\x05\x05%\x03\x03\x03\x03\x7f-\x03\x0f\x17\x07\x8b\x83\x03\x05\x05\t\r\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x91\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x93\x03\x07\x11\x07\x01\x95\x07\x05\x07+\x0f\x17\x19\x1b\x1d\x1f\x0f\x0b\x03\x03\x01\x97\x03\x07\x05\x07\x01\t\x03\x07\x03'\x0b\x07\x01\x99\x03/\x05#)\x05\x07\x01\t\x031\x03+\x03\x03\x01\x9b\x03\x0f\x05\x07\x01\t\x03\x05\x03/\x05\x07\x01\x06\x02\x03\x19\x03-\r\x06\x01\x03\x05\x073!1\x19\x07\x07\n\x02\x03\x05\x03\t\x0f\x04\x0f\x0557\t\x11\x07M\x07\x03\x15+\x03\x05\x07\x07\x03Q!\x03\x11\x03\x03\x07U\x03\x07\x05\x07'\t\x03\x11\x03\x05\x13\x06'\x03\x11\x05\x03\x07\x07\x03[Y\x03\x11\x0b\x07a_\x03\x19\x05\t\x0b\x03\x03\x07-\x03\x0f\x05\x07e\t\x03\x05\x03\x0f\r\x06i\x03\x05\x07\r\x11\x01\x0f\x04\x07\x03\x13\x06\x03\x01\x05\x01\x00\x9e\x1a\x89\x0f\x1d%\x11\x0f\x0b\t\t\x03\x0b!\x11#Y\x87##%_)=\x85\x87W\xb3K\x9bM\x9bn\x03\x1b%)9\x1f/!!)#\x1f\x19+\x1b+\x1f\x1f\x15\x1d\x15i\x13\r\x11\x0f\x17\x0f\x1f\x15\x15\x17\x11\x11)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00iota_v1\x00func_v1\x00compare_v1\x00select_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00sym_name\x00third_party/py/jax/tests/export_back_compat_test.py\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,)" - b' out_shardings=(UnspecifiedValue,) in_layouts=(None,)' - b' out_layouts=(None,) resource_env=None donated_invars=(False,)' - b' name=triu keep_unused=False' - b' inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' - b' shape=(3, 3)' - b' dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' - b' shape=(3, 3)' - b' dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3,' - b' 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=float64' - b' shape=(9,)' - b' dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3)' - b' dimensions=None]\x00jit()/jit(main)/geqrf\x00mhlo.backend_config\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0,' - b' 0, 0), (0, 0,' - b' 0))]\x00jit()/jit(main)/householder_product\x00mhlo.layout_mode\x00default\x00jax.result_info\x00triu\x00\x00[0]\x00[1]\x00main\x00public\x00private\x00lapack_dgeqrf_ffi\x00lapack_dorgqr\x00callee\x00' - ), + mlir_module_serialized=b'ML\xefR\rStableHLO_v1.9.3\x00\x01)\x05\x01\x05\x19\x01\x03\x0b\x03\x17\x0f\x13\x17\x1b\x1f#\'+/37\x03\xc5\x8b\'\x01E\x17\x07\x0b\x0f\x0b\x1b\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x17\x0f\x0b\x17\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x03G\x0b/\x0b\x0f\x0b\x0b\x0b\x0fO\x13\x0f\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x1f/\x0b\x13\x0b\x0b\x0b\x0f\x17/\x0b\x0f\x13\x0f\x0b\x0b\x01\x05\x0b\x0f\x03#\x17\x07\x07\x17\x0f\x07\x0f\x07\x17\x13\x13\x13\x13\x13\x13\x17\x07\x02\xaa\x04\x17\x05r\x06\x17\x1f\x05\x1d\x11\x03\x05\x05\x1f\x03\x05\'o)q\x1d\t\x01\x1d7\x01\x03\x07\x13\x15\x17\x07\x19\x07\x05!\x11\x01\x00\x05#\x05%\x05\'\x1d\t\x1f\x17\x05n\x065\x1d#%\x05)\x17\x05n\x06\x1d\x05+\x05-\x1d-\x01\x05/\x1d1\x01\x051\x1d5\x01\x053\x055\x1d;\x01\x057\x1d?\x01\x059\x1dC\x01\x05;\x03\x01\x1f\x1f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d=\x13\t\x01\x0b\x03\x1d?\x05\x01\x03\x03U\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x05U}\x1f!\x01#\x15\x03\x05_c\r\x03Ia\x1dA\r\x03Ie\x1dC\x1dE\x1dG\x1f\r\t\xff\xff\xff\xff\x1f\x11\x11\x00\x00\x00\x00\x00\x00\x00\x00\r\x01\r\x03su\x1dI\x1dK\x1dM\x03\x03{\x15\x03\x01\x01\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dO\x03\x03\x83\x15\x01\x01\x01\x13\t\x05\t\x07\x07\x05\x01\t\x01\x02\x02)\x05\r\r\x07\x0b\x1d)\x05\r\r\x0f)\x01\x0f\x1b)\x01\x07\x13\x11\x01\x05\x05\x05)\x03%\x07)\x03\r\x07)\x03\t\x13)\x03\x05\x13)\x03\t\t)\x03\x01\t)\x05\r\r%\x01\x04"\x02\x05\x01Q\x03\x11\x01\x07\x04\xff\x03\x01\x05\x0bP\x03\x03\x07\x04\xeb\x03\x1f=\x05B\x03\x05\x03\r\x05B\x03\x07\x03\x11\x03B\x1d\t\x03\x17\r\x06!\x03\x05\x03\x05\x07G+\x0b\x0b\x05\x05\x19\x03\x07\x0fF/\r\x03\x05\x05\t\x03\x07G3\x0b\x0f\x03\x05\x05\r\x0b\x03B\r\t\x03\x0b\tF\x0f\x11\x03\x0b\x03\x01\x11\x06\x0f\x03\x0b\x05\x11\x13\x03B\r\x13\x03\x0b\x13F9\x15\x03#\x05\x15\x17\tF=\x11\x03\x05\x03\x03\x15\x06A\x03\x05\x07\x19\x1b\t\x17\x04\x03\x05\x0f\x1d\x06\x03\x01\x05\x01\x00\xba\x0bQ%%\x05\x1f\x0f\x0b\x15\x15\x03!CS79Y9=3)A\x1b%)9;i\x15\x15\x17\x0f\x0f\x17\x11)\x1f\x19\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00iota_v1\x00constant_v1\x00custom_call_v1\x00broadcast_in_dim_v1\x00func_v1\x00reshape_v1\x00pad_v1\x00add_v1\x00compare_v1\x00select_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit()/jit(main)/iota\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/reshape\x00mhlo.backend_config\x00mhlo.frontend_attributes\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/pad\x00jit()/jit(main)/householder_product\x00jit()/jit(main)/add\x00jit()/jit(main)/ge\x00jit()/jit(main)/broadcast_in_dim\x00jit()/jit(main)/select_n\x00jax.result_info\x00\x00result[0]\x00result[1]\x00main\x00public\x00num_batch_dims\x000\x00lapack_dgeqrf_ffi\x00lapack_dorgqr_ffi\x00\x08[\x17\x057\x01\x0bE[]gi\x03k\x03m\x03K\x11MOwEQSyW\x07GGG\x11MO\x7fEQW\x81S\x03Y\x03\x85\x05\x87\x89', xla_call_module_version=9, nr_devices=1, ) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_schur_lapack_gees.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_schur_lapack_gees.py index 309aa73f20ba..091c41c26cf6 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_schur_lapack_gees.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_schur_lapack_gees.py @@ -15,232 +15,11 @@ # ruff: noqa import datetime -from numpy import array, int32, float32, complex64 +import numpy as np -data_2023_07_16 = {} - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_07_16["f32"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_sgees'], - serialized_date=datetime.date(2023, 7, 16), - inputs=(array([[ 0., 1., 2., 3.], - [ 4., 5., 6., 7.], - [ 8., 9., 10., 11.], - [12., 13., 14., 15.]], dtype=float32),), - expected_outputs=(array([[ 3.2464233e+01, -1.3416403e+01, -1.5532076e-05, -4.3390692e-06], - [ 0.0000000e+00, -2.4642491e+00, -1.4625000e-06, -6.4478525e-07], - [ 0.0000000e+00, 0.0000000e+00, -8.1893580e-07, -2.5704816e-07], - [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 1.5155359e-07]], - dtype=float32), array([[-0.11417631 , 0.828833 , -0.546308 , -0.039330132], - [-0.33000442 , 0.4371459 , 0.69909686 , 0.45963493 ], - [-0.54583275 , 0.045459975, 0.24073309 , -0.80127877 ], - [-0.7616609 , -0.34622616 , -0.39352104 , 0.3809742 ]], - dtype=float32)), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<4x4xf32> {jax.arg_info = "input", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<4x4xf32> {jax.result_info = "[0]"}, tensor<4x4xf32> {jax.result_info = "[1]"}) { - %0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %1 = stablehlo.constant dense<4> : tensor loc(#loc2) - %2 = stablehlo.constant dense<86> : tensor loc(#loc2) - %3 = stablehlo.constant dense<78> : tensor loc(#loc2) - %4:6 = stablehlo.custom_call @lapack_sgees(%0, %1, %2, %3, %arg0) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor, tensor<4x4xf32>) -> (tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor, tensor) loc(#loc2) - %5 = stablehlo.constant dense<0> : tensor loc(#loc2) - %6 = stablehlo.broadcast_in_dim %5, dims = [] : (tensor) -> tensor loc(#loc2) - %7 = stablehlo.compare EQ, %4#5, %6, SIGNED : (tensor, tensor) -> tensor loc(#loc2) - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc2) - %9 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc2) - %10 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<4x4xf32> loc(#loc2) - %11 = stablehlo.broadcast_in_dim %8, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc2) - %12 = stablehlo.select %11, %4#0, %10 : tensor<4x4xi1>, tensor<4x4xf32> loc(#loc2) - %13 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc2) - %14 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc2) - %15 = stablehlo.broadcast_in_dim %14, dims = [] : (tensor) -> tensor<4x4xf32> loc(#loc2) - %16 = stablehlo.broadcast_in_dim %13, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc2) - %17 = stablehlo.select %16, %4#3, %15 : tensor<4x4xi1>, tensor<4x4xf32> loc(#loc2) - return %12, %17 : tensor<4x4xf32>, tensor<4x4xf32> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":483:0) -#loc2 = loc("jit(func)/jit(main)/schur[compute_schur_vectors=True sort_eig_vals=False select_callable=None]"(#loc1)) -""", - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xd5\x97+\x01M\x0f\x0b\x13\x07\x0f\x0b\x0b\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x03K\x0fO\x0b/\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x1f\x1f\x13\x13\x0b\x0b\x0b\x0b\x0b\x1f\x0f\x17#\x1f\x0f\x0b\x0b\x1fO\x01\x03\x0f\x03)\x17\x0f\x0f\x07\x07\x07\x0f\x13\x07\x17\x17\x1b\x07\x07\x13\x13\x13\x13\x0f\x13\x02\xbe\x05\x1d')\x05\x15\x03\x03\r\x8d\x1f\x11\x01\x05\x05\x17\x05\x19\x03\x03\x03\x93\x03\x03\r\x95\x03\x07\x15\t\x17\t\x0b\x19\x05\x1b\x05\x1d\x05\x1f\x03\x0b\x1dU\x1fa!c\x0bm#o\x05!\x05#\x05%\x05'\x03\x03\x03q\x05)\x17+\x8e\x07\x01\x05+\x03\x03\x03s\x03\x03\x03u\x03\x03\x03w\x03\x115y7{9};\x7f=\x81?\x83A\x85C\x89\x05-\x05/\x051\x053\x055\x057\x059\x05;\x03\x03\x03\x8b\x03\x05I\x8fK\x91\x05=\x05?\x1f\x1f\x01\x1f!!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1dA\x1f#\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03W\r\x05Y[]_\x1dC\x1dE\x1dG\x1dI#\x19\x03\x05ei\r\x03Qg\x1dK\r\x03Qk\x1dM\x1dO\x1dQ\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f\x07\x03V\x1f\x07\x03N\x0b\x05\x1dS\x1dU\x03\x01\x05\x01\x03\x0bMMMMO\x03\x03\x87\x15\x03\x01\x11\x01\x03\rOSSOMM\x1f\x05\t\x00\x00\x00\x00\x1f%\x01\t\x07\x07\x01\x1f\x0f\t\x00\x00\xc0\x7f\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\t)\x01\x1b)\x01\x1d\t\x13\x01)\x01\t)\x03\x11\t\x1d)\x05\x05\x05\r)\x05\x11\x11\r\x11\x03\x03\x05\x03\x03\x1b!)\x03\x01\x0b)\x03\t\x0b)\x03\x05\x0b)\x03\x01\x13)\x01\r)\x03\t\x13\x04\xa2\x02\x05\x01\x11\x07\x13\x07\x03\x01\x05\t\x11\x07\x1b\x05\x031O\x03\x03\x07\x03\x03\x01%\x03\x05\x03\x03\x01-\x03\x05\x03\x03\x01/\x03\x07\x03\x03\x011\x03\x07\x0b\x07\x013\r\x03\x11\x11\x03\x05\x05\x0b\x03\x05\x07\t\x01\x03\x03\x01E\x03\x05\x05\x07\x01\x05\x03\x05\x03\x17\r\x07\x01G\x03'\x05\x15\x19\x05\x07\x01\x05\x03\x15\x03\x1b\x03\x03\x01\x0f\x03\x0f\x05\x07\x01\x05\x03\x03\x03\x1f\x05\x07\x01\x11\x03\x17\x03\x1d\x07\x06\x01\x03\x03\x07#\x0b!\x05\x07\x01\x05\x03\x15\x03\x1b\x03\x03\x01\x0f\x03\x0f\x05\x07\x01\x05\x03\x03\x03)\x05\x07\x01\x11\x03\x17\x03'\x07\x06\x01\x03\x03\x07-\x11+\x0f\x04\x07\x05%/\x06\x03\x01\x05\x01\x002\x0bW\x1b\x03\x0f\x0b\t\t\x1b\x1d\r\x1b!+\x1b\x1f/!!)#\x1f\x19\x97\xbf\x1f\x15\x1d\x15\x13%)+\x13\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00sym_name\x00broadcast_dimensions\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/schur[compute_schur_vectors=True sort_eig_vals=False select_callable=None]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00jax.arg_info\x00input\x00mhlo.sharding\x00{replicated}\x00[0]\x00[1]\x00main\x00public\x00\x00lapack_sgees\x00", - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_07_16["f64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_dgees'], - serialized_date=datetime.date(2023, 7, 16), - inputs=(array([[ 0., 1., 2., 3.], - [ 4., 5., 6., 7.], - [ 8., 9., 10., 11.], - [12., 13., 14., 15.]]),), - expected_outputs=(array([[ 3.2464249196572958e+01, -1.3416407864998734e+01, - 1.4217165257496823e-15, 1.7257338996070338e-16], - [ 0.0000000000000000e+00, -2.4642491965729794e+00, - 4.0099214829607365e-16, 2.9384059908060751e-16], - [ 0.0000000000000000e+00, 0.0000000000000000e+00, - -1.5668631265126207e-15, 6.3403580326623540e-16], - [ 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 1.2369554016158485e-16]]), array([[-0.11417645138733855 , 0.8288327563197505 , - 0.4940336612834742 , -0.23649681080057947 ], - [-0.3300045986655475 , 0.4371463883638869 , - -0.8349858635153001 , -0.052901868866879136], - [-0.545832745943757 , 0.045460020408024784, - 0.18787074318017621 , 0.8152941701354965 ], - [-0.7616608932219662 , -0.3462263475478383 , - 0.1530814590516493 , -0.525895490468038 ]])), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<4x4xf64> {jax.arg_info = "input", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<4x4xf64> {jax.result_info = "[0]"}, tensor<4x4xf64> {jax.result_info = "[1]"}) { - %0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %1 = stablehlo.constant dense<4> : tensor loc(#loc2) - %2 = stablehlo.constant dense<86> : tensor loc(#loc2) - %3 = stablehlo.constant dense<78> : tensor loc(#loc2) - %4:6 = stablehlo.custom_call @lapack_dgees(%0, %1, %2, %3, %arg0) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor, tensor<4x4xf64>) -> (tensor<4x4xf64>, tensor<4xf64>, tensor<4xf64>, tensor<4x4xf64>, tensor, tensor) loc(#loc2) - %5 = stablehlo.constant dense<0> : tensor loc(#loc2) - %6 = stablehlo.broadcast_in_dim %5, dims = [] : (tensor) -> tensor loc(#loc2) - %7 = stablehlo.compare EQ, %4#5, %6, SIGNED : (tensor, tensor) -> tensor loc(#loc2) - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc2) - %9 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc2) - %10 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<4x4xf64> loc(#loc2) - %11 = stablehlo.broadcast_in_dim %8, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc2) - %12 = stablehlo.select %11, %4#0, %10 : tensor<4x4xi1>, tensor<4x4xf64> loc(#loc2) - %13 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc2) - %14 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc2) - %15 = stablehlo.broadcast_in_dim %14, dims = [] : (tensor) -> tensor<4x4xf64> loc(#loc2) - %16 = stablehlo.broadcast_in_dim %13, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc2) - %17 = stablehlo.select %16, %4#3, %15 : tensor<4x4xi1>, tensor<4x4xf64> loc(#loc2) - return %12, %17 : tensor<4x4xf64>, tensor<4x4xf64> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":483:0) -#loc2 = loc("jit(func)/jit(main)/schur[compute_schur_vectors=True sort_eig_vals=False select_callable=None]"(#loc1)) -""", - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xd5\x97+\x01M\x0f\x0b\x13\x07\x0f\x0b\x0b\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x03K\x0fO\x0b/\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x1f\x1f\x13\x13\x0b\x0b\x0b\x0b\x0b\x1f\x0f\x17#\x1f\x0f\x0b\x0b/O\x01\x03\x0f\x03)\x17\x0f\x0f\x07\x07\x07\x0f\x13\x07\x17\x17\x1b\x07\x07\x13\x13\x13\x13\x0f\x13\x02\xce\x05\x1d')\x05\x15\x03\x03\r\x8d\x1f\x11\x01\x05\x05\x17\x05\x19\x03\x03\x03\x93\x03\x03\r\x95\x03\x07\x15\t\x17\t\x0b\x19\x05\x1b\x05\x1d\x05\x1f\x03\x0b\x1dU\x1fa!c\x0bm#o\x05!\x05#\x05%\x05'\x03\x03\x03q\x05)\x17+\x8e\x07\x01\x05+\x03\x03\x03s\x03\x03\x03u\x03\x03\x03w\x03\x115y7{9};\x7f=\x81?\x83A\x85C\x89\x05-\x05/\x051\x053\x055\x057\x059\x05;\x03\x03\x03\x8b\x03\x05I\x8fK\x91\x05=\x05?\x1f\x1f\x01\x1f!!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1dA\x1f#\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03W\r\x05Y[]_\x1dC\x1dE\x1dG\x1dI#\x19\x03\x05ei\r\x03Qg\x1dK\r\x03Qk\x1dM\x1dO\x1dQ\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f\x07\x03V\x1f\x07\x03N\x0b\x05\x1dS\x1dU\x03\x01\x05\x01\x03\x0bMMMMO\x03\x03\x87\x15\x03\x01\x11\x01\x03\rOSSOMM\x1f\x05\t\x00\x00\x00\x00\x1f%\x01\t\x07\x07\x01\x1f\x0f\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\t)\x01\x1b)\x01\x1d\x0b\x13\x01)\x01\t)\x03\x11\t\x1d)\x05\x05\x05\r)\x05\x11\x11\r\x11\x03\x03\x05\x03\x03\x1b!)\x03\x01\x0b)\x03\t\x0b)\x03\x05\x0b)\x03\x01\x13)\x01\r)\x03\t\x13\x04\xa2\x02\x05\x01\x11\x07\x13\x07\x03\x01\x05\t\x11\x07\x1b\x05\x031O\x03\x03\x07\x03\x03\x01%\x03\x05\x03\x03\x01-\x03\x05\x03\x03\x01/\x03\x07\x03\x03\x011\x03\x07\x0b\x07\x013\r\x03\x11\x11\x03\x05\x05\x0b\x03\x05\x07\t\x01\x03\x03\x01E\x03\x05\x05\x07\x01\x05\x03\x05\x03\x17\r\x07\x01G\x03'\x05\x15\x19\x05\x07\x01\x05\x03\x15\x03\x1b\x03\x03\x01\x0f\x03\x0f\x05\x07\x01\x05\x03\x03\x03\x1f\x05\x07\x01\x11\x03\x17\x03\x1d\x07\x06\x01\x03\x03\x07#\x0b!\x05\x07\x01\x05\x03\x15\x03\x1b\x03\x03\x01\x0f\x03\x0f\x05\x07\x01\x05\x03\x03\x03)\x05\x07\x01\x11\x03\x17\x03'\x07\x06\x01\x03\x03\x07-\x11+\x0f\x04\x07\x05%/\x06\x03\x01\x05\x01\x002\x0bW\x1b\x03\x0f\x0b\t\t\x1b\x1d\r\x1b!+\x1b\x1f/!!)#\x1f\x19\x97\xbf\x1f\x15\x1d\x15\x13%)+\x13\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00sym_name\x00broadcast_dimensions\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/schur[compute_schur_vectors=True sort_eig_vals=False select_callable=None]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00jax.arg_info\x00input\x00mhlo.sharding\x00{replicated}\x00[0]\x00[1]\x00main\x00public\x00\x00lapack_dgees\x00", - xla_call_module_version=6, -) # End paste - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_07_16["c64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_cgees'], - serialized_date=datetime.date(2023, 7, 16), - inputs=(array([[ 0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j], - [ 4.+0.j, 5.+0.j, 6.+0.j, 7.+0.j], - [ 8.+0.j, 9.+0.j, 10.+0.j, 11.+0.j], - [12.+0.j, 13.+0.j, 14.+0.j, 15.+0.j]], dtype=complex64),), - expected_outputs=(array([[ 3.2464264e+01+0.j, -1.3416414e+01+0.j, -3.3649465e-06+0.j, - 3.5482326e-06+0.j], - [ 0.0000000e+00+0.j, -2.4642489e+00+0.j, -7.4810049e-07+0.j, - 6.1193055e-07+0.j], - [ 0.0000000e+00+0.j, 0.0000000e+00+0.j, -5.7737759e-07+0.j, - 2.5704813e-07+0.j], - [ 0.0000000e+00+0.j, 0.0000000e+00+0.j, 0.0000000e+00+0.j, - 1.4719124e-07+0.j]], dtype=complex64), array([[ 0.11417647 +0.j, -0.8288329 +0.j, 0.5452458 +0.j, - -0.05202686 +0.j], - [ 0.3300045 +0.j, -0.43714625 +0.j, -0.68821627 +0.j, - 0.47577178 +0.j], - [ 0.54583293 +0.j, -0.045460097-0.j, -0.25930598 +0.j, - -0.79546237 +0.j], - [ 0.76166105 +0.j, 0.3462263 +0.j, 0.40227604 +0.j, - 0.37171766 +0.j]], dtype=complex64)), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<4x4xcomplex> {jax.arg_info = "input", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<4x4xcomplex> {jax.result_info = "[0]"}, tensor<4x4xcomplex> {jax.result_info = "[1]"}) { - %0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %1 = stablehlo.constant dense<4> : tensor loc(#loc2) - %2 = stablehlo.constant dense<86> : tensor loc(#loc2) - %3 = stablehlo.constant dense<78> : tensor loc(#loc2) - %4:6 = stablehlo.custom_call @lapack_cgees(%0, %1, %2, %3, %arg0) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor, tensor<4x4xcomplex>) -> (tensor<4x4xcomplex>, tensor<4xf32>, tensor<4xcomplex>, tensor<4x4xcomplex>, tensor, tensor) loc(#loc2) - %5 = stablehlo.constant dense<0> : tensor loc(#loc2) - %6 = stablehlo.broadcast_in_dim %5, dims = [] : (tensor) -> tensor loc(#loc2) - %7 = stablehlo.compare EQ, %4#5, %6, SIGNED : (tensor, tensor) -> tensor loc(#loc2) - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc2) - %9 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc2) - %10 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc2) - %11 = stablehlo.broadcast_in_dim %8, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc2) - %12 = stablehlo.select %11, %4#0, %10 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc2) - %13 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc2) - %14 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc2) - %15 = stablehlo.broadcast_in_dim %14, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc2) - %16 = stablehlo.broadcast_in_dim %13, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc2) - %17 = stablehlo.select %16, %4#3, %15 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc2) - return %12, %17 : tensor<4x4xcomplex>, tensor<4x4xcomplex> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":483:0) -#loc2 = loc("jit(func)/jit(main)/schur[compute_schur_vectors=True sort_eig_vals=False select_callable=None]"(#loc1)) -""", - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xd9\x97/\x01M\x0f\x0b\x13\x07\x0f\x0b\x0b\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x03K\x0fO\x0b/\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x1f\x1f\x13\x13\x0b\x0b\x0b\x0b\x0b\x1f\x0f\x17#\x1f\x0f\x0b\x0b/O\x01\x03\x0f\x03-\x17\x0f\x0f\x0b\x07\x07\x0f\x07\x07\x17\x17\x1b\x07\x07\x13\x13\x13\x13\x13\x13\x0f\x13\x02\xe6\x05\x1d')\x05\x15\x03\x03\r\x8d\x1f\x11\x01\x05\x05\x17\x05\x19\x03\x03\x03\x93\x03\x03\r\x95\x03\x07\x15\t\x17\t\x0b\x19\x05\x1b\x05\x1d\x05\x1f\x03\x0b\x1dU\x1fa!c\x0bm#o\x05!\x05#\x05%\x05'\x03\x03\x03q\x05)\x17+\x8e\x07\x01\x05+\x03\x03\x03s\x03\x03\x03u\x03\x03\x03w\x03\x115y7{9};\x7f=\x81?\x83A\x85C\x89\x05-\x05/\x051\x053\x055\x057\x059\x05;\x03\x03\x03\x8b\x03\x05I\x8fK\x91\x05=\x05?\x1f#\x01\x1f%!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1dA\x1f'\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03W\r\x05Y[]_\x1dC\x1dE\x1dG\x1dI#\x19\x03\x05ei\r\x03Qg\x1dK\r\x03Qk\x1dM\x1dO\x1dQ\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f\x07\x03V\x1f\x07\x03N\x0b\x05\x1dS\x1dU\x03\x01\x05\x01\x03\x0bMMMMO\x03\x03\x87\x15\x03\x01\x11\x01\x03\rOSSOMM\x1f\x05\t\x00\x00\x00\x00\x1f)\x01\t\x07\x07\x01\x1f\x0f\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f-!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\t)\x01\x1b)\x01\x1d\x03\x11\x13\x01)\x01\t\t\x1d)\x05\x05\x05\r)\x05\x11\x11\r\x11\x03\x03\x05\x03\x03\x1b!)\x03\x11\x11)\x03\x11\t)\x03\x01\x0b)\x03\t\x0b)\x03\x05\x0b)\x03\x01\x13)\x01\r)\x03\t\x13\x04\xa2\x02\x05\x01\x11\x07\x13\x07\x03\x01\x05\t\x11\x07\x1b\x05\x031O\x03\x03\x07\x03\x03\x01%\x03\x05\x03\x03\x01-\x03\x05\x03\x03\x01/\x03\x07\x03\x03\x011\x03\x07\x0b\x07\x013\r\x03\x1f!\x03\x05\x05\x0b\x03\x05\x07\t\x01\x03\x03\x01E\x03\x05\x05\x07\x01\x05\x03\x05\x03\x17\r\x07\x01G\x03+\x05\x15\x19\x05\x07\x01\x05\x03\x15\x03\x1b\x03\x03\x01\x0f\x03\x0f\x05\x07\x01\x05\x03\x03\x03\x1f\x05\x07\x01\x11\x03\x17\x03\x1d\x07\x06\x01\x03\x03\x07#\x0b!\x05\x07\x01\x05\x03\x15\x03\x1b\x03\x03\x01\x0f\x03\x0f\x05\x07\x01\x05\x03\x03\x03)\x05\x07\x01\x11\x03\x17\x03'\x07\x06\x01\x03\x03\x07-\x11+\x0f\x04\x07\x05%/\x06\x03\x01\x05\x01\x002\x0bW\x1b\x03\x0f\x0b\t\t\x1b\x1d\r\x1b!+\x1b\x1f/!!)#\x1f\x19\x97\xbf\x1f\x15\x1d\x15\x13%)+\x13\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00sym_name\x00broadcast_dimensions\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/schur[compute_schur_vectors=True sort_eig_vals=False select_callable=None]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00jax.arg_info\x00input\x00mhlo.sharding\x00{replicated}\x00[0]\x00[1]\x00main\x00public\x00\x00lapack_cgees\x00", - xla_call_module_version=6, -) # End paste - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_07_16["c128"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_zgees'], - serialized_date=datetime.date(2023, 7, 16), - inputs=(array([[ 0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j], - [ 4.+0.j, 5.+0.j, 6.+0.j, 7.+0.j], - [ 8.+0.j, 9.+0.j, 10.+0.j, 11.+0.j], - [12.+0.j, 13.+0.j, 14.+0.j, 15.+0.j]]),), - expected_outputs=(array([[ 3.2464249196572965e+01+0.j, -1.3416407864998730e+01+0.j, - 4.3084836728703156e-15+0.j, 2.8665351303736084e-15+0.j], - [ 0.0000000000000000e+00+0.j, -2.4642491965729802e+00+0.j, - -2.3716026934523430e-16+0.j, 3.7279396143672773e-16+0.j], - [ 0.0000000000000000e+00+0.j, 0.0000000000000000e+00+0.j, - -1.6035677295293287e-15+0.j, -6.3403580326623540e-16+0.j], - [ 0.0000000000000000e+00+0.j, 0.0000000000000000e+00+0.j, - 0.0000000000000000e+00+0.j, 1.2218554396786608e-16+0.j]]), array([[ 0.11417645138733863+0.j, -0.8288327563197504 +0.j, - 0.4960613110079619 +0.j, 0.2322136424094458 +0.j], - [ 0.33000459866554754+0.j, -0.43714638836388703+0.j, - -0.8344969112540657 +0.j, 0.06012408092789509+0.j], - [ 0.5458327459437572 +0.j, -0.04546002040802478-0.j, - 0.18080988948424495+0.j, -0.8168890890841272 +0.j], - [ 0.7616608932219662 +0.j, 0.34622634754783854+0.j, - 0.15762571076185886+0.j, 0.5245513657467864 +0.j]])), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<4x4xcomplex> {jax.arg_info = "input", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<4x4xcomplex> {jax.result_info = "[0]"}, tensor<4x4xcomplex> {jax.result_info = "[1]"}) { - %0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %1 = stablehlo.constant dense<4> : tensor loc(#loc2) - %2 = stablehlo.constant dense<86> : tensor loc(#loc2) - %3 = stablehlo.constant dense<78> : tensor loc(#loc2) - %4:6 = stablehlo.custom_call @lapack_zgees(%0, %1, %2, %3, %arg0) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor, tensor<4x4xcomplex>) -> (tensor<4x4xcomplex>, tensor<4xf64>, tensor<4xcomplex>, tensor<4x4xcomplex>, tensor, tensor) loc(#loc2) - %5 = stablehlo.constant dense<0> : tensor loc(#loc2) - %6 = stablehlo.broadcast_in_dim %5, dims = [] : (tensor) -> tensor loc(#loc2) - %7 = stablehlo.compare EQ, %4#5, %6, SIGNED : (tensor, tensor) -> tensor loc(#loc2) - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc2) - %9 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc2) - %10 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc2) - %11 = stablehlo.broadcast_in_dim %8, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc2) - %12 = stablehlo.select %11, %4#0, %10 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc2) - %13 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc2) - %14 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc2) - %15 = stablehlo.broadcast_in_dim %14, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc2) - %16 = stablehlo.broadcast_in_dim %13, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc2) - %17 = stablehlo.select %16, %4#3, %15 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc2) - return %12, %17 : tensor<4x4xcomplex>, tensor<4x4xcomplex> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":483:0) -#loc2 = loc("jit(func)/jit(main)/schur[compute_schur_vectors=True sort_eig_vals=False select_callable=None]"(#loc1)) -""", - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xd9\x97/\x01M\x0f\x0b\x13\x07\x0f\x0b\x0b\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x03K\x0fO\x0b/\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x1f\x1f\x13\x13\x0b\x0b\x0b\x0b\x0b\x1f\x0f\x17#\x1f\x0f\x0b\x0bOO\x01\x03\x0f\x03-\x17\x0f\x0f\x0b\x07\x07\x0f\x07\x07\x17\x17\x1b\x07\x07\x13\x13\x13\x13\x13\x13\x0f\x13\x02\x06\x06\x1d')\x05\x15\x03\x03\r\x8d\x1f\x11\x01\x05\x05\x17\x05\x19\x03\x03\x03\x93\x03\x03\r\x95\x03\x07\x15\t\x17\t\x0b\x19\x05\x1b\x05\x1d\x05\x1f\x03\x0b\x1dU\x1fa!c\x0bm#o\x05!\x05#\x05%\x05'\x03\x03\x03q\x05)\x17+\x8e\x07\x01\x05+\x03\x03\x03s\x03\x03\x03u\x03\x03\x03w\x03\x115y7{9};\x7f=\x81?\x83A\x85C\x89\x05-\x05/\x051\x053\x055\x057\x059\x05;\x03\x03\x03\x8b\x03\x05I\x8fK\x91\x05=\x05?\x1f#\x01\x1f%!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1dA\x1f'\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03W\r\x05Y[]_\x1dC\x1dE\x1dG\x1dI#\x19\x03\x05ei\r\x03Qg\x1dK\r\x03Qk\x1dM\x1dO\x1dQ\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f\x07\x03V\x1f\x07\x03N\x0b\x05\x1dS\x1dU\x03\x01\x05\x01\x03\x0bMMMMO\x03\x03\x87\x15\x03\x01\x11\x01\x03\rOSSOMM\x1f\x05\t\x00\x00\x00\x00\x1f)\x01\t\x07\x07\x01\x1f\x0f!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f-!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\t)\x01\x1b)\x01\x1d\x03\x11\x13\x01)\x01\t\x0b\x1d)\x05\x05\x05\r)\x05\x11\x11\r\x11\x03\x03\x05\x03\x03\x1b!)\x03\x11\x11)\x03\x11\t)\x03\x01\x0b)\x03\t\x0b)\x03\x05\x0b)\x03\x01\x13)\x01\r)\x03\t\x13\x04\xa2\x02\x05\x01\x11\x07\x13\x07\x03\x01\x05\t\x11\x07\x1b\x05\x031O\x03\x03\x07\x03\x03\x01%\x03\x05\x03\x03\x01-\x03\x05\x03\x03\x01/\x03\x07\x03\x03\x011\x03\x07\x0b\x07\x013\r\x03\x1f!\x03\x05\x05\x0b\x03\x05\x07\t\x01\x03\x03\x01E\x03\x05\x05\x07\x01\x05\x03\x05\x03\x17\r\x07\x01G\x03+\x05\x15\x19\x05\x07\x01\x05\x03\x15\x03\x1b\x03\x03\x01\x0f\x03\x0f\x05\x07\x01\x05\x03\x03\x03\x1f\x05\x07\x01\x11\x03\x17\x03\x1d\x07\x06\x01\x03\x03\x07#\x0b!\x05\x07\x01\x05\x03\x15\x03\x1b\x03\x03\x01\x0f\x03\x0f\x05\x07\x01\x05\x03\x03\x03)\x05\x07\x01\x11\x03\x17\x03'\x07\x06\x01\x03\x03\x07-\x11+\x0f\x04\x07\x05%/\x06\x03\x01\x05\x01\x002\x0bW\x1b\x03\x0f\x0b\t\t\x1b\x1d\r\x1b!+\x1b\x1f/!!)#\x1f\x19\x97\xbf\x1f\x15\x1d\x15\x13%)+\x13\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00sym_name\x00broadcast_dimensions\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/schur[compute_schur_vectors=True sort_eig_vals=False select_callable=None]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00jax.arg_info\x00input\x00mhlo.sharding\x00{replicated}\x00[0]\x00[1]\x00main\x00public\x00\x00lapack_zgees\x00", - xla_call_module_version=6, -) # End paste +array = np.array +float32 = np.float32 +complex64 = np.complex64 data_2024_11_29 = {} diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_svd_lapack_gesdd.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_svd_lapack_gesdd.py index 2d71308caeda..2351bed05905 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_svd_lapack_gesdd.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_svd_lapack_gesdd.py @@ -15,437 +15,14 @@ # ruff: noqa import datetime -from numpy import array, float32, complex64 +import numpy as np -data_2023_06_19 = {} - - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["f32"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_sgesdd'], - serialized_date=datetime.date(2023, 6, 22), - inputs=(array([[[ 1.5410905 , -2.775912 , -2.374003 , 4.028736 ], - [-0.56933475, 1.6115232 , 0.9041465 , -0.8321383 ], - [-5.382895 , 4.734856 , 2.1972926 , 1.5553856 ], - [ 0.5109847 , -1.1969309 , 3.3766198 , -1.3678027 ]], - - [[ 2.2637439 , 3.406768 , 4.809871 , 2.8010902 ], - [-1.9981416 , -0.6599986 , 0.5138156 , 4.5982494 ], - [-2.335944 , -9.151717 , -1.0481138 , 2.272443 ], - [-8.257684 , 1.8223318 , 0.38403794, 5.0769973 ]]], - dtype=float32),), - expected_outputs=(array([[[-0.48540133 , 0.6682397 , -0.48819906 , -0.28196266 ], - [ 0.2180054 , -0.13631375 , 0.14819765 , -0.95495003 ], - [ 0.8457052 , 0.44643915 , -0.27943406 , 0.08597418 ], - [ 0.040523227, -0.57928085 , -0.8133977 , -0.03429017 ]], - - [[-0.21146733 , 0.46376425 , 0.786309 , 0.34917438 ], - [ 0.3461469 , 0.21883713 , 0.3399653 , -0.84659094 ], - [ 0.6526192 , -0.5834038 , 0.3972404 , 0.2755518 ], - [ 0.6399631 , 0.6298203 , -0.32915345 , 0.2922879 ]]], - dtype=float32), array([[ 8.551608 , 5.3574076, 2.8073738, 0.5226082], - [11.457576 , 10.041606 , 5.6716514, 1.4754109]], dtype=float32), array([[[-0.6319046 , 0.6612254 , 0.39110154 , -0.102553196], - [-0.2971051 , 0.13673358 , -0.50112 , 0.80119365 ], - [ 0.08969147 , 0.4433047 , -0.73647296 , -0.5030348 ], - [-0.7101976 , -0.5895471 , -0.23135659 , -0.30745354 ]], - - [[-0.6964344 , -0.5023085 , -0.11150039 , 0.50023323 ], - [-0.32121164 , 0.7889568 , 0.3183193 , 0.41598475 ], - [ 0.5096958 , -0.31399378 , 0.60193455 , 0.5284816 ], - [-0.3898877 , -0.16322286 , 0.7238198 , -0.5453721 ]]], - dtype=float32)), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<2x4x4xf32> {jax.arg_info = "input", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<2x4x4xf32> {jax.result_info = "[0]"}, tensor<2x4xf32> {jax.result_info = "[1]"}, tensor<2x4x4xf32> {jax.result_info = "[2]"}) { - %0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %1 = stablehlo.constant dense<1> : tensor loc(#loc2) - %2 = stablehlo.constant dense<2> : tensor loc(#loc2) - %3 = stablehlo.constant dense<4> : tensor loc(#loc2) - %4 = stablehlo.constant dense<4> : tensor loc(#loc2) - %5 = stablehlo.constant dense<268> : tensor loc(#loc2) - %6:7 = stablehlo.custom_call @lapack_sgesdd(%0, %1, %2, %3, %4, %5, %arg0) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xf32>) -> (tensor<2x4x4xf32>, tensor<2x4xf32>, tensor<2x4x4xf32>, tensor<2x4x4xf32>, tensor<2xi32>, tensor<32xi32>, tensor<268xf32>) loc(#loc2) - %7 = stablehlo.constant dense<0> : tensor loc(#loc2) - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<2xi32> loc(#loc2) - %9 = stablehlo.compare EQ, %6#4, %8, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc2) - %10 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc2) - %11 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc2) - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor) -> tensor<2x4xf32> loc(#loc2) - %13 = stablehlo.broadcast_in_dim %10, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc2) - %14 = stablehlo.select %13, %6#1, %12 : tensor<2x4xi1>, tensor<2x4xf32> loc(#loc2) - %15 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %16 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc2) - %17 = stablehlo.broadcast_in_dim %16, dims = [] : (tensor) -> tensor<2x4x4xf32> loc(#loc2) - %18 = stablehlo.broadcast_in_dim %15, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %19 = stablehlo.select %18, %6#2, %17 : tensor<2x4x4xi1>, tensor<2x4x4xf32> loc(#loc2) - %20 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %21 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc2) - %22 = stablehlo.broadcast_in_dim %21, dims = [] : (tensor) -> tensor<2x4x4xf32> loc(#loc2) - %23 = stablehlo.broadcast_in_dim %20, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %24 = stablehlo.select %23, %6#3, %22 : tensor<2x4x4xi1>, tensor<2x4x4xf32> loc(#loc2) - return %19, %14, %24 : tensor<2x4x4xf32>, tensor<2x4xf32>, tensor<2x4x4xf32> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":355:0) -#loc2 = loc("jit(func)/jit(main)/svd[full_matrices=True compute_uv=True]"(#loc1)) -""", - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xef\xa57\x01Q\x0f\x0b\x07\x13\x0b\x13\x13\x0f\x0b\x13\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x0b\x17\x0b\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x03U\x0fo\x0b/\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b\x0b'\x0f\x17'O\x1f\x0f\x0b\x0b/\x1fOo\x01\x03\x0f\x035\x0f\x1b\x07\x07\x17\x07\x07\x0f\x07\x13\x1b\x1b\x1f\x13\x17\x13\x13\x13\x13\x13\x13\x17\x13\x17\x13\x13\x02\xb6\x07\x1d+-\x05\x15\x1f\x03\x03\t\x97\x05\x17\x03\x03\t\x9d\x03\x03\x03\x9f\x11\x01\x05\x05\x19\x03\x03\x03y\x03\x03\x03}\x03\x03\t\xa3\x03\x07\x1b\x0f\x1d\x0f\x11\x1f\x05\x1b\x05\x1d\x05\x1f\x03\x0b#Y%e'g\x11u)w\x05!\x05#\x05%\x05'\x05)\x17/\x8e\x05\x01\x05+\x03\x03\x03{\x03\x03\x03\x7f\x03\x117\x819\x83;\x85=\x87?\x89A\x8bC\x8dE\x91\x05-\x05/\x051\x053\x055\x057\x059\x05;\x03\x03\x03\x95\x03\x05K\x99M\x9b\x05=\x05?\x03\x03\t\xa1\x1f!\x01\x1f#1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1dA\x1f'\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03[\r\x05]_ac\x1dC\x1dE\x1dG\x1dI#\x1b\x03\x07imq\r\x03Uk\x1dK\r\x03Uo\x1dM\r\x03Us\x1dO\x1dQ\x1dS\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x02\x00\x00\x00\x1f\x03\t\x04\x00\x00\x00\x1f\x03\t\x0c\x01\x00\x00\x0b\x05\x1dU\x1dW\x03\x01\x05\x01\x03\x0fQQQQQQS\x03\x03\x8f\x15\x03\x01\x19\x01\x03\x0fS\x93SSWWW\x1f%!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x03\t\x00\x00\x00\x00\x1f)\x01\t\x07\x07\x01\x1f/\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x11\t\x00\x00\xc0\x7f\x1f3!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f51\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x01\x13)\x07\t\x11\x11\t\x01\t)\x05\t\x11\t\x13\x1d)\x01\t\x1b)\x03\t\x13)\x07\t\x05\x05\x07)\x07\t\x11\x11\x07\x11\x03\x05\x07\x05\x0b\x05)\x03\x81\x13)\x03b\x08\t)\x03\x01\r)\x03\r\r)\x03\t\r)\x03\x05\r)\x03\x01\x0f)\x03\t\x07)\x05\t\x05\x07)\x03\x05\x0f)\x05\t\x11\x07)\x03\t\x0f)\x03\r\x0f\x04~\x03\x05\x01\x11\x05\x19\x07\x03\x01\x05\t\x11\x05!\x05\x03Ak\x03\x05\x05\x03\x03\x01\x13\x03\x03\x03\x03\x01\x13\x03\x03\x03\x03\x011\x03\x03\x03\x03\x01\x15\x03\x03\x03\x03\x01\x15\x03\x03\x03\x03\x013\x03\x03\x0b\x07\x015\x0f\x05\x0b\x05\x05\x15\x1d\x1f\x0f\x03\x05\x07\t\x0b\r\x01\x03\x03\x01G\x03\x03\x05\x07\x01\x07\x03\x15\x03\x1d\r\x07\x01I\x03+\x05\x17\x1f\x05\x07\x01\x0b\x03-\x03!\x03\x03\x01\r\x03\x11\x05\x07\x01\x07\x03\x0b\x03%\x05\x07\x01O\x031\x03#\x07\x06\x01\x03\x0b\x07)\x11'\x05\x07\x01\x0b\x03\x17\x03!\x03\x03\x01\r\x03\x11\x05\x07\x01\x07\x03\x05\x03/\x05\x07\x01\x17\x03\x19\x03-\x07\x06\x01\x03\x05\x073\x131\x05\x07\x01\x0b\x03\x17\x03!\x03\x03\x01\r\x03\x11\x05\x07\x01\x07\x03\x05\x039\x05\x07\x01\x17\x03\x19\x037\x07\x06\x01\x03\x05\x07=\x15;\x0f\x04\x05\x075+?\x06\x03\x01\x05\x01\x00\xbe\nY\x1d\x03\x0f\x0b\t\t\t\x1b\x1d\r\x1b!+\x1b\x1f/!!)#\x1f\x19\x97y\x1f\x15\x1d\x15\x13%)\x13+\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/svd[full_matrices=True compute_uv=True]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00jax.arg_info\x00input\x00mhlo.sharding\x00{replicated}\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_sgesdd\x00", - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["f64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_dgesdd'], - serialized_date=datetime.date(2023, 6, 22), - inputs=(array([[[ 0.3445689867809981 , 3.5114993759427104 , - 4.702602090972179 , -0.2702264758497052 ], - [ 2.209901632583705 , -2.6286702510632773 , - 4.591276599385847 , 3.4465035398844828 ], - [-1.5083742421154478 , 3.3225165204269635 , - 1.2596205557926703 , 3.524804355848018 ], - [ 1.5118969169108838 , 1.838885943509677 , - 2.818520751293422 , 3.06002540493494 ]], - - [[-2.4045510943950843 , -1.5657555633438576 , - -0.6061472334580296 , -0.23926156407779164], - [ 4.087879920053448 , -3.2507640936811715 , - -2.2556577657517476 , 6.090369998330348 ], - [ 1.1165401344486945 , 2.2134726894037247 , - 5.225178515435584 , 1.9794693474107725 ], - [-4.127878192684534 , -0.37313660200336163, - 0.7893465897510026 , -2.0315217791342848 ]]]),), - expected_outputs=(array([[[-0.5109626909166218 , -0.41744996156105785, - -0.731253241567692 , 0.1729779025790829 ], - [-0.5623501368035175 , 0.7608931604238581 , - 0.03470920608540986, 0.32186828528169453], - [-0.39585755254587435, -0.4954770291405409 , - 0.6561880513437818 , 0.4089212062978684 ], - [-0.5157288533916834 , -0.03577207859388855, - 0.18297871183094833, -0.8362194085221047 ]], - - [[-0.12124821978030875, -0.30260506534356213, - -0.5817463045715607 , -0.7451847292758064 ], - [ 0.8877417367326685 , -0.15794001239879188, - -0.3761180739267688 , 0.2133184375808915 ], - [ 0.03055221675864994, 0.9244545314395409 , - -0.3686107533067095 , -0.09260936183071355], - [-0.44303503260363514, -0.16990864078317836, - -0.619864940232637 , 0.624994775612963 ]]]), array([[8.951386926411189 , 5.762891699811626 , 3.839104008889441 , - 1.2696468971033248 ], - [9.21500688857692 , 6.477297670883227 , 3.24626945855818 , - 0.05112101994354587]]), array([[[-0.17890276924244797 , -0.2881812520705063 , - -0.7749616998111006 , -0.5332726590950898 ], - [ 0.38712159387038353 , -0.8985113987184378 , - 0.1397618670046424 , 0.15258033445914954 ], - [-0.23140697924040152 , -0.03708202130554661 , - -0.5045854966104308 , 0.8309447696839614 ], - [-0.8744034999217865 , -0.32901938548360005 , - 0.35396957633060866 , -0.043246992182741084]], - - [[ 0.6276106632546885 , -0.26728735347872895 , - -0.22995258718774078 , 0.6941067163520401 ], - [ 0.2802931697592562 , 0.4781137804659157 , - 0.808362569504731 , 0.19847646746808023 ], - [ 0.6187014005224262 , 0.47714095343944474 , - -0.3740686697560633 , -0.49961757159793246 ], - [-0.3804591585793503 , 0.6872417290515944 , - -0.3921025301835001 , 0.47875384105714014 ]]])), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<2x4x4xf64> {jax.arg_info = "input", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<2x4x4xf64> {jax.result_info = "[0]"}, tensor<2x4xf64> {jax.result_info = "[1]"}, tensor<2x4x4xf64> {jax.result_info = "[2]"}) { - %0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %1 = stablehlo.constant dense<1> : tensor loc(#loc2) - %2 = stablehlo.constant dense<2> : tensor loc(#loc2) - %3 = stablehlo.constant dense<4> : tensor loc(#loc2) - %4 = stablehlo.constant dense<4> : tensor loc(#loc2) - %5 = stablehlo.constant dense<268> : tensor loc(#loc2) - %6:7 = stablehlo.custom_call @lapack_dgesdd(%0, %1, %2, %3, %4, %5, %arg0) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xf64>) -> (tensor<2x4x4xf64>, tensor<2x4xf64>, tensor<2x4x4xf64>, tensor<2x4x4xf64>, tensor<2xi32>, tensor<32xi32>, tensor<268xf64>) loc(#loc2) - %7 = stablehlo.constant dense<0> : tensor loc(#loc2) - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<2xi32> loc(#loc2) - %9 = stablehlo.compare EQ, %6#4, %8, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc2) - %10 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc2) - %11 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc2) - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor) -> tensor<2x4xf64> loc(#loc2) - %13 = stablehlo.broadcast_in_dim %10, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc2) - %14 = stablehlo.select %13, %6#1, %12 : tensor<2x4xi1>, tensor<2x4xf64> loc(#loc2) - %15 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %16 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc2) - %17 = stablehlo.broadcast_in_dim %16, dims = [] : (tensor) -> tensor<2x4x4xf64> loc(#loc2) - %18 = stablehlo.broadcast_in_dim %15, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %19 = stablehlo.select %18, %6#2, %17 : tensor<2x4x4xi1>, tensor<2x4x4xf64> loc(#loc2) - %20 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %21 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc2) - %22 = stablehlo.broadcast_in_dim %21, dims = [] : (tensor) -> tensor<2x4x4xf64> loc(#loc2) - %23 = stablehlo.broadcast_in_dim %20, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %24 = stablehlo.select %23, %6#3, %22 : tensor<2x4x4xi1>, tensor<2x4x4xf64> loc(#loc2) - return %19, %14, %24 : tensor<2x4x4xf64>, tensor<2x4xf64>, tensor<2x4x4xf64> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":355:0) -#loc2 = loc("jit(func)/jit(main)/svd[full_matrices=True compute_uv=True]"(#loc1)) -""", - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xef\xa57\x01Q\x0f\x0b\x07\x13\x0b\x13\x13\x0f\x0b\x13\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x0b\x17\x0b\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x03U\x0fo\x0b/\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b\x0b'\x0f\x17'O\x1f\x0f\x0b\x0b//Oo\x01\x03\x0f\x035\x0f\x1b\x07\x07\x17\x07\x07\x0f\x07\x13\x1b\x1b\x1f\x13\x17\x13\x13\x13\x13\x13\x13\x17\x13\x17\x13\x13\x02\xc6\x07\x1d+-\x05\x15\x1f\x03\x03\t\x97\x05\x17\x03\x03\t\x9d\x03\x03\x03\x9f\x11\x01\x05\x05\x19\x03\x03\x03y\x03\x03\x03}\x03\x03\t\xa3\x03\x07\x1b\x0f\x1d\x0f\x11\x1f\x05\x1b\x05\x1d\x05\x1f\x03\x0b#Y%e'g\x11u)w\x05!\x05#\x05%\x05'\x05)\x17/\x8e\x05\x01\x05+\x03\x03\x03{\x03\x03\x03\x7f\x03\x117\x819\x83;\x85=\x87?\x89A\x8bC\x8dE\x91\x05-\x05/\x051\x053\x055\x057\x059\x05;\x03\x03\x03\x95\x03\x05K\x99M\x9b\x05=\x05?\x03\x03\t\xa1\x1f!\x01\x1f#1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1dA\x1f'\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03[\r\x05]_ac\x1dC\x1dE\x1dG\x1dI#\x1b\x03\x07imq\r\x03Uk\x1dK\r\x03Uo\x1dM\r\x03Us\x1dO\x1dQ\x1dS\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x02\x00\x00\x00\x1f\x03\t\x04\x00\x00\x00\x1f\x03\t\x0c\x01\x00\x00\x0b\x05\x1dU\x1dW\x03\x01\x05\x01\x03\x0fQQQQQQS\x03\x03\x8f\x15\x03\x01\x19\x01\x03\x0fS\x93SSWWW\x1f%!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x03\t\x00\x00\x00\x00\x1f)\x01\t\x07\x07\x01\x1f/\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x11\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f3!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f51\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x01\x13)\x07\t\x11\x11\t\x01\x0b)\x05\t\x11\t\x13\x1d)\x01\t\x1b)\x03\t\x13)\x07\t\x05\x05\x07)\x07\t\x11\x11\x07\x11\x03\x05\x07\x05\x0b\x05)\x03\x81\x13)\x03b\x08\t)\x03\x01\r)\x03\r\r)\x03\t\r)\x03\x05\r)\x03\x01\x0f)\x03\t\x07)\x05\t\x05\x07)\x03\x05\x0f)\x05\t\x11\x07)\x03\t\x0f)\x03\r\x0f\x04~\x03\x05\x01\x11\x05\x19\x07\x03\x01\x05\t\x11\x05!\x05\x03Ak\x03\x05\x05\x03\x03\x01\x13\x03\x03\x03\x03\x01\x13\x03\x03\x03\x03\x011\x03\x03\x03\x03\x01\x15\x03\x03\x03\x03\x01\x15\x03\x03\x03\x03\x013\x03\x03\x0b\x07\x015\x0f\x05\x0b\x05\x05\x15\x1d\x1f\x0f\x03\x05\x07\t\x0b\r\x01\x03\x03\x01G\x03\x03\x05\x07\x01\x07\x03\x15\x03\x1d\r\x07\x01I\x03+\x05\x17\x1f\x05\x07\x01\x0b\x03-\x03!\x03\x03\x01\r\x03\x11\x05\x07\x01\x07\x03\x0b\x03%\x05\x07\x01O\x031\x03#\x07\x06\x01\x03\x0b\x07)\x11'\x05\x07\x01\x0b\x03\x17\x03!\x03\x03\x01\r\x03\x11\x05\x07\x01\x07\x03\x05\x03/\x05\x07\x01\x17\x03\x19\x03-\x07\x06\x01\x03\x05\x073\x131\x05\x07\x01\x0b\x03\x17\x03!\x03\x03\x01\r\x03\x11\x05\x07\x01\x07\x03\x05\x039\x05\x07\x01\x17\x03\x19\x037\x07\x06\x01\x03\x05\x07=\x15;\x0f\x04\x05\x075+?\x06\x03\x01\x05\x01\x00\xbe\nY\x1d\x03\x0f\x0b\t\t\t\x1b\x1d\r\x1b!+\x1b\x1f/!!)#\x1f\x19\x97y\x1f\x15\x1d\x15\x13%)\x13+\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/svd[full_matrices=True compute_uv=True]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00jax.arg_info\x00input\x00mhlo.sharding\x00{replicated}\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_dgesdd\x00", - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["c64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_cgesdd'], - serialized_date=datetime.date(2023, 6, 22), - inputs=(array([[[ 1.6052934 +0.45878917j, 4.587192 -4.5177283j , - 0.4177733 -1.9419309j , -2.2248359 -4.5042715j ], - [-7.083374 -8.127356j , 2.7596245 -4.991001j , - -0.52622825+5.033981j , -0.35441273-1.8215327j ], - [-0.7996552 -2.4052901j , -0.8506142 -3.164714j , - -0.3090829 +2.2020447j , 1.2367196 +2.8830793j ], - [ 1.4633094 -0.5451007j , -3.7833478 +6.6770763j , - -3.1279542 -2.2322626j , -2.1099617 -2.9661314j ]], - - [[ 1.2560439 -5.4743752j , -2.0085676 +2.0063214j , - -0.8132642 -3.4407883j , -0.17360081+0.6419895j ], - [ 2.3756726 +6.3315964j , -0.31447247-1.9387872j , - 4.6732006 -4.286903j , 1.7702469 -1.4957623j ], - [ 1.6918924 -0.52161306j, 0.49963537+4.7751374j , - -1.9243752 -4.5870543j , 2.8829405 +1.7382988j ], - [ 1.4884951 -0.44194785j, -1.3645276 -2.8733373j , - -0.39430943+2.4366508j , -0.76268387+5.2014065j ]]], - dtype=complex64),), - expected_outputs=(array([[[ 0.016725361+0.19210356j , 0.5452691 +0.5572638j , - 0.41363996 +0.18964858j , -0.26152334 -0.28195143j ], - [ 0.53678626 +0.64057267j , -0.21783225 -0.21288812j , - 0.28426644 +0.30535883j , 0.15201284 +0.10768581j ], - [ 0.21286921 +0.154735j , 0.066471666-0.25652882j , - -0.4074613 -0.10356682j , -0.11794163 -0.81844836j ], - [-0.39079374 -0.20583564j , -0.18335931 -0.4421772j , - 0.63489586 +0.19758748j , 0.038680226-0.36351213j ]], - - [[-0.3178596 +0.39032036j , -0.1273337 -0.30841744j , - 0.26394194 +0.26815224j , -0.21332254 -0.66947937j ], - [-0.39241245 -0.60790956j , -0.14006221 +0.41040683j , - -0.0830612 -0.10184447j , -0.45091942 -0.2603987j ], - [-0.36103728 +0.2876153j , -0.4965461 +0.10084368j , - -0.13752826 -0.6203828j , 0.35439825 -0.028546419j], - [ 0.062335093-0.078214265j, 0.35014474 -0.5668197j , - -0.42214075 -0.5090833j , -0.2889288 -0.15894148j ]]], - dtype=complex64), array([[15.135655 , 9.373035 , 7.444931 , 0.41523397], - [12.316969 , 8.661011 , 5.005059 , 2.115905 ]], - dtype=float32), array([[[-0.6537865 +0.j , -0.20306697 -0.6166746j , - 0.29948467 +0.24257992j , -0.007604365+0.04945353j ], - [ 0.52712685 +0.j , -0.11291563 -0.7116954j , - -0.089219 -0.36348897j , -0.23654723 -0.08269388j ], - [-0.31538543 +0.j , -0.014410622+0.15958191j , - -0.17958623 -0.13690898j , -0.6930434 -0.58613425j ], - [-0.44185135 +0.j , 0.17604677 -0.050492246j, - -0.4213856 -0.69485146j , 0.22373371 +0.2465445j ]], - - [[-0.64551586 +0.j , 0.32932255 -0.11672116j , - -0.093527466+0.6710145j , -0.038554154+0.02716677j ], - [ 0.4241116 +0.j , 0.031135002-0.539813j , - -0.26271763 +0.22760014j , -0.63609654 -0.04817467j ], - [-0.4577485 +0.j , -0.15202768 +0.2734652j , - 0.18931003 -0.3297506j , -0.7331101 -0.10269702j ], - [ 0.44034657 +0.j , 0.29474002 +0.63307834j , - 0.31271848 +0.4216674j , -0.20595454 -0.020532424j]]], - dtype=complex64)), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<2x4x4xcomplex> {jax.arg_info = "input", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<2x4x4xcomplex> {jax.result_info = "[0]"}, tensor<2x4xf32> {jax.result_info = "[1]"}, tensor<2x4x4xcomplex> {jax.result_info = "[2]"}) { - %0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %1 = stablehlo.constant dense<1> : tensor loc(#loc2) - %2 = stablehlo.constant dense<2> : tensor loc(#loc2) - %3 = stablehlo.constant dense<4> : tensor loc(#loc2) - %4 = stablehlo.constant dense<4> : tensor loc(#loc2) - %5 = stablehlo.constant dense<264> : tensor loc(#loc2) - %6:8 = stablehlo.custom_call @lapack_cgesdd(%0, %1, %2, %3, %4, %5, %arg0) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xcomplex>) -> (tensor<2x4x4xcomplex>, tensor<2x4xf32>, tensor<2x4x4xcomplex>, tensor<2x4x4xcomplex>, tensor<2xi32>, tensor<32xi32>, tensor<100xf32>, tensor<264xcomplex>) loc(#loc2) - %7 = stablehlo.constant dense<0> : tensor loc(#loc2) - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<2xi32> loc(#loc2) - %9 = stablehlo.compare EQ, %6#4, %8, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc2) - %10 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc2) - %11 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc2) - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor) -> tensor<2x4xf32> loc(#loc2) - %13 = stablehlo.broadcast_in_dim %10, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc2) - %14 = stablehlo.select %13, %6#1, %12 : tensor<2x4xi1>, tensor<2x4xf32> loc(#loc2) - %15 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %16 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc2) - %17 = stablehlo.broadcast_in_dim %16, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc2) - %18 = stablehlo.broadcast_in_dim %15, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %19 = stablehlo.select %18, %6#2, %17 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc2) - %20 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %21 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc2) - %22 = stablehlo.broadcast_in_dim %21, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc2) - %23 = stablehlo.broadcast_in_dim %20, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %24 = stablehlo.select %23, %6#3, %22 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc2) - return %19, %14, %24 : tensor<2x4x4xcomplex>, tensor<2x4xf32>, tensor<2x4x4xcomplex> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":355:0) -#loc2 = loc("jit(func)/jit(main)/svd[full_matrices=True compute_uv=True]"(#loc1)) -""", - mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xf9\xa9=\x01S\x0f\x0b\x07\x13\x0b\x13\x0f\x0b\x13\x13\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x0b\x17\x0b\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x13\x03W\x0fo/\x0b\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b\x0b\'\x0f\x17+O\x1f\x0f\x0b\x0b/\x1fO/o\x01\x03\x0f\x03;\x0f\x1b\x07\x07\x17\x07\x07\x0b\x07\x0f\x13\x0f\x1b\x1b\x1f\x13\x17\x17\x13\x13\x13\x13\x13\x13\x17\x13\x17\x13\x13\x02\x1e\x08\x1d+-\x05\x15\x1f\x03\x03\t\x99\x05\x17\x03\x03\t\x9f\x11\x01\x05\x05\x19\x03\x03\x03{\x03\x03\x03\x7f\x03\x03\x03\xa5\x03\x03\t\xa7\x03\x07\x1b\r\x1d\r\x0f\x1f\x05\x1b\x05\x1d\x05\x1f\x03\x0b#[%g\'i\x0fw)y\x05!\x05#\x05%\x05\'\x05)\x17/\x8e\x05\x01\x05+\x03\x03\x03}\x03\x03\x03\x81\x03\x117\x839\x85;\x87=\x89?\x8bA\x8dC\x8fE\x93\x05-\x05/\x051\x053\x055\x057\x059\x05;\x03\x03\x03\x97\x03\x05K\x9bM\x9d\x05=\x05?\x03\x03\x03\xa1\x03\x03\t\xa3\x1f\'\x01\x1f)1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f-\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dA\x03\x03]\r\x05_ace\x1dC\x1dE\x1dG\x1dI#\x1f\x03\x07kos\r\x03Ym\x1dK\r\x03Yq\x1dM\r\x03Yu\x1dO\x1dQ\x1dS\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x02\x00\x00\x00\x1f\x03\t\x04\x00\x00\x00\x1f\x03\t\x08\x01\x00\x00\x0b\x05\x1dU\x1dW\x03\x01\x05\x01\x03\x0fSSSSSSU\x03\x03\x91\x15\x03\x01\x19\x01\x03\x11U\x95UUWWWW\x1f+!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x03\t\x00\x00\x00\x00\x1f/\x01\t\x07\x07\x01\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x19\t\x00\x00\xc0\x7f\x1f9!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x15\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f;1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x01\x13)\x07\t\x11\x11\x11\x01\t)\x05\t\x11\t\x13\x1d\x03\t\x1b)\x01\x11)\x03\t\x13)\x01\t)\x07\t\x05\x05\x07)\x07\t\x11\x11\x07\x11\x03\x05\x07\x05\x0b\x05)\x03\x81\x13)\x03"\x03\t)\x03B\x08\x11)\x03\x01\r)\x03\r\r)\x03\t\r)\x03\x05\r)\x03\x01\x0f)\x03\t\x07)\x05\t\x05\x07)\x03\x05\x0f)\x05\t\x11\x07)\x03\t\x0f)\x03\r\x0f\x04\x82\x03\x05\x01\x11\x05\x19\x07\x03\x01\x05\t\x11\x05!\x05\x03Ck\x03\x05\x05\x03\x03\x01\x11\x03\x03\x03\x03\x01\x11\x03\x03\x03\x03\x011\x03\x03\x03\x03\x01\x13\x03\x03\x03\x03\x01\x13\x03\x03\x03\x03\x013\x03\x03\x0b\x07\x015\x11\x05\x0b\x05\x05\x17!#%\x0f\x03\x05\x07\t\x0b\r\x01\x03\x03\x01G\x03\x03\x05\x07\x01\x07\x03\x17\x03\x1f\r\x07\x01I\x031\x05\x17!\x05\x07\x01\x0b\x033\x03#\x03\x03\x01O\x03\x19\x05\x07\x01\x07\x03\x0b\x03\'\x05\x07\x01Q\x037\x03%\x07\x06\x01\x03\x0b\x07+\x11)\x05\x07\x01\x0b\x03\x1b\x03#\x03\x03\x01\x15\x03\x15\x05\x07\x01\x07\x03\x05\x031\x05\x07\x01\x17\x03\x1d\x03/\x07\x06\x01\x03\x05\x075\x133\x05\x07\x01\x0b\x03\x1b\x03#\x03\x03\x01\x15\x03\x15\x05\x07\x01\x07\x03\x05\x03;\x05\x07\x01\x17\x03\x1d\x039\x07\x06\x01\x03\x05\x07?\x15=\x0f\x04\x05\x077-A\x06\x03\x01\x05\x01\x00\xbe\nY\x1d\x03\x0f\x0b\t\t\t\x1b\x1d\r\x1b!+\x1b\x1f/!!)#\x1f\x19\x97y\x1f\x15\x1d\x15\x13%)\x13+\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/svd[full_matrices=True compute_uv=True]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00jax.arg_info\x00input\x00mhlo.sharding\x00{replicated}\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_cgesdd\x00', - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["c128"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_zgesdd'], - serialized_date=datetime.date(2023, 6, 22), - inputs=(array([[[-0.9247611722912019-1.3615157109291343j , - -1.0663457975211892+4.73170030936092j , - -1.4918732811689488-2.880861991859318j , - -1.111356346434667 -2.869701609083459j ], - [-4.71291623424314 -1.5444012898828912j , - -5.232967549101415 -0.41287816948482003j, - 0.8905737109262459+9.50245186328329j , - 4.397722119094926 -6.842005210371916j ], - [ 1.9369405063276903+2.3496014107398917j , - -1.5609345742256133+4.2102103739897805j , - 0.6596030248996742+5.195353435247212j , - 0.6315014498240328-1.2778849649354402j ], - [ 5.115159214503849 -0.8856276268773485j , - 1.3719934567460779-2.236070491368575j , - 0.4974504006612811-3.0462081956756637j , - -0.2620346712025989+4.424682727912594j ]], - - [[-1.8242711798401063-0.8543252170262536j , - -2.724527211360488 +2.256038331706666j , - -1.2777487543905157+0.976556823566376j , - 3.7438974536713223-0.4994301527847589j ], - [-0.6359051102028691+2.730662301129662j , - -1.2877728943263032+3.9124921723649053j , - -3.4618573226579894+1.7835551986994034j , - -1.4710491660152465+2.144967500163963j ], - [-3.6013691182532828+2.8182351980619034j , - 2.0045935428878803+1.1146211993017152j , - -2.332213857689336 -0.874915651404938j , - -1.5393862406530452+0.6852883119580928j ], - [-2.674897392856801 +2.0724239502976984j , - -3.349108041292141 -1.0215359152295307j , - 0.2603515088197114-1.9093411474619364j , - 5.41252457188561 +8.634368042893094j ]]]),), - expected_outputs=(array([[[-0.04173678258633362+0.10796693731538423j , - 0.6813428383170976 +0.34327979589293334j , - -0.41770229002865755+0.20028957850808823j , - -0.43443513665085287+0.034743251442636465j], - [-0.8408468609573512 -0.1326064604464803j , - -0.21674151028481228+0.015170556885426551j, - 0.17147327711152338+0.1531041615298256j , - -0.3568765623609291 +0.21904384306708768j ], - [-0.2673618144044136 +0.1379833616281103j , - -0.17534278352558025-0.378992615769627j , - -0.8179957069096054 -0.037506032257391624j, - 0.25392637883428526-0.009771014463849802j], - [ 0.40569239968065934-0.08297706578106905j , - -0.4321527034953765 +0.09791545663574397j , - -0.23439193826962654-0.08427130532228161j , - -0.42348296145608866+0.6251448114949291j ]], - - [[ 0.0272684373986653 +0.36312055550335454j , - 0.270297713559288 +0.1304616587162563j , - 0.04286867013923673-0.4765859417602139j , - 0.7242702256119968 +0.15420620503522459j ], - [-0.08593436615104483+0.1189990183325552j , - 0.37050286109355285-0.6240865462984536j , - 0.46902056878806025-0.34747949920770266j , - -0.31667671459632074-0.10340064369932994j ], - [-0.07914843440873574-0.033487314943774035j, - 0.4110353453489128 -0.455090805566563j , - -0.431131803930273 +0.40910871949632j , - 0.13782730102420274+0.49428280062680086j ], - [-0.7478497242333215 +0.5283836938016964j , - -0.08345894989956631+0.011807690067190268j, - -0.27178304569905287+0.056526279406748176j, - -0.09911954913441999-0.2598859654000683j ]]]), array([[16.80132997488892 , 7.744755614558116 , 5.831221808032041 , - 1.1195288361137765], - [12.39537594694893 , 8.218551160453814 , 4.683634850274079 , - 1.8820915363839188]]), array([[[ 0.35796251040556704 +0.j , - 0.40179383774178046 -0.1269359716702074j , - -0.0751486661300563 -0.6109813931761136j , - -0.23049271148274278 +0.51209309438597j ], - [-0.4682861415308549 +0.j , - -0.013958972669495105+0.4210606476774211j , - -0.6006888466394119 -0.3766516564723718j , - -0.24264518623237025 -0.20408557153193485j ], - [-0.6392945524816095 +0.j , - 0.2432388607602898 -0.6679928485374246j , - 0.18168178910997038 -0.08126854868489754j , - -0.2030612067046724 -0.07124733621915219j ], - [-0.49383540371426055 +0.j , - -0.010402968929686592+0.3734624991410737j , - 0.27994282704104956 +0.01949406216762731j , - 0.32588905219319236 +0.6569569657140543j ]], - - [[ 0.2666920370516844 +0.j , - 0.24929033811571413 +0.27271089049933883j , - -0.012922512768026735+0.16383354123801513j , - 0.07388201893235022 -0.8717175469187741j ], - [-0.6156140469162428 +0.j , - -0.33787077397020143 +0.37797154650923376j , - -0.3916043058726119 -0.2839601305776179j , - -0.2714888604157674 -0.23729034093304682j ], - [ 0.5618758038857617 +0.j , - -0.5788776267734554 -0.13833058883452312j , - -0.48995086206819644 +0.19259594116096765j , - -0.22967101640965012 -0.012926826751577613j], - [-0.48393210641613593 +0.j , - -0.1049229605428438 -0.4911419972025977j , - -0.07782239226461217 +0.6751317817750165j , - 0.11941657609231515 -0.19354808489959852j ]]])), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<2x4x4xcomplex> {jax.arg_info = "input", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<2x4x4xcomplex> {jax.result_info = "[0]"}, tensor<2x4xf64> {jax.result_info = "[1]"}, tensor<2x4x4xcomplex> {jax.result_info = "[2]"}) { - %0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %1 = stablehlo.constant dense<1> : tensor loc(#loc2) - %2 = stablehlo.constant dense<2> : tensor loc(#loc2) - %3 = stablehlo.constant dense<4> : tensor loc(#loc2) - %4 = stablehlo.constant dense<4> : tensor loc(#loc2) - %5 = stablehlo.constant dense<264> : tensor loc(#loc2) - %6:8 = stablehlo.custom_call @lapack_zgesdd(%0, %1, %2, %3, %4, %5, %arg0) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xcomplex>) -> (tensor<2x4x4xcomplex>, tensor<2x4xf64>, tensor<2x4x4xcomplex>, tensor<2x4x4xcomplex>, tensor<2xi32>, tensor<32xi32>, tensor<100xf64>, tensor<264xcomplex>) loc(#loc2) - %7 = stablehlo.constant dense<0> : tensor loc(#loc2) - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<2xi32> loc(#loc2) - %9 = stablehlo.compare EQ, %6#4, %8, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc2) - %10 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc2) - %11 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc2) - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor) -> tensor<2x4xf64> loc(#loc2) - %13 = stablehlo.broadcast_in_dim %10, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc2) - %14 = stablehlo.select %13, %6#1, %12 : tensor<2x4xi1>, tensor<2x4xf64> loc(#loc2) - %15 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %16 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc2) - %17 = stablehlo.broadcast_in_dim %16, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc2) - %18 = stablehlo.broadcast_in_dim %15, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %19 = stablehlo.select %18, %6#2, %17 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc2) - %20 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %21 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc2) - %22 = stablehlo.broadcast_in_dim %21, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc2) - %23 = stablehlo.broadcast_in_dim %20, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %24 = stablehlo.select %23, %6#3, %22 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc2) - return %19, %14, %24 : tensor<2x4x4xcomplex>, tensor<2x4xf64>, tensor<2x4x4xcomplex> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":355:0) -#loc2 = loc("jit(func)/jit(main)/svd[full_matrices=True compute_uv=True]"(#loc1)) -""", - mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xf9\xa9=\x01S\x0f\x0b\x07\x13\x0b\x13\x0f\x0b\x13\x13\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x0b\x17\x0b\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x13\x03W\x0fo/\x0b\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b\x0b\'\x0f\x17+O\x1f\x0f\x0b\x0b//OOo\x01\x03\x0f\x03;\x0f\x1b\x07\x07\x17\x07\x07\x0b\x07\x0f\x13\x0f\x1b\x1b\x1f\x13\x17\x17\x13\x13\x13\x13\x13\x13\x17\x13\x17\x13\x13\x02N\x08\x1d+-\x05\x15\x1f\x03\x03\t\x99\x05\x17\x03\x03\t\x9f\x11\x01\x05\x05\x19\x03\x03\x03{\x03\x03\x03\x7f\x03\x03\x03\xa5\x03\x03\t\xa7\x03\x07\x1b\r\x1d\r\x0f\x1f\x05\x1b\x05\x1d\x05\x1f\x03\x0b#[%g\'i\x0fw)y\x05!\x05#\x05%\x05\'\x05)\x17/\x8e\x05\x01\x05+\x03\x03\x03}\x03\x03\x03\x81\x03\x117\x839\x85;\x87=\x89?\x8bA\x8dC\x8fE\x93\x05-\x05/\x051\x053\x055\x057\x059\x05;\x03\x03\x03\x97\x03\x05K\x9bM\x9d\x05=\x05?\x03\x03\x03\xa1\x03\x03\t\xa3\x1f\'\x01\x1f)1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f-\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dA\x03\x03]\r\x05_ace\x1dC\x1dE\x1dG\x1dI#\x1f\x03\x07kos\r\x03Ym\x1dK\r\x03Yq\x1dM\r\x03Yu\x1dO\x1dQ\x1dS\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x02\x00\x00\x00\x1f\x03\t\x04\x00\x00\x00\x1f\x03\t\x08\x01\x00\x00\x0b\x05\x1dU\x1dW\x03\x01\x05\x01\x03\x0fSSSSSSU\x03\x03\x91\x15\x03\x01\x19\x01\x03\x11U\x95UUWWWW\x1f+!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x03\t\x00\x00\x00\x00\x1f/\x01\t\x07\x07\x01\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x19\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f9!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x15!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f;1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x01\x13)\x07\t\x11\x11\x11\x01\x0b)\x05\t\x11\t\x13\x1d\x03\t\x1b)\x01\x11)\x03\t\x13)\x01\t)\x07\t\x05\x05\x07)\x07\t\x11\x11\x07\x11\x03\x05\x07\x05\x0b\x05)\x03\x81\x13)\x03"\x03\t)\x03B\x08\x11)\x03\x01\r)\x03\r\r)\x03\t\r)\x03\x05\r)\x03\x01\x0f)\x03\t\x07)\x05\t\x05\x07)\x03\x05\x0f)\x05\t\x11\x07)\x03\t\x0f)\x03\r\x0f\x04\x82\x03\x05\x01\x11\x05\x19\x07\x03\x01\x05\t\x11\x05!\x05\x03Ck\x03\x05\x05\x03\x03\x01\x11\x03\x03\x03\x03\x01\x11\x03\x03\x03\x03\x011\x03\x03\x03\x03\x01\x13\x03\x03\x03\x03\x01\x13\x03\x03\x03\x03\x013\x03\x03\x0b\x07\x015\x11\x05\x0b\x05\x05\x17!#%\x0f\x03\x05\x07\t\x0b\r\x01\x03\x03\x01G\x03\x03\x05\x07\x01\x07\x03\x17\x03\x1f\r\x07\x01I\x031\x05\x17!\x05\x07\x01\x0b\x033\x03#\x03\x03\x01O\x03\x19\x05\x07\x01\x07\x03\x0b\x03\'\x05\x07\x01Q\x037\x03%\x07\x06\x01\x03\x0b\x07+\x11)\x05\x07\x01\x0b\x03\x1b\x03#\x03\x03\x01\x15\x03\x15\x05\x07\x01\x07\x03\x05\x031\x05\x07\x01\x17\x03\x1d\x03/\x07\x06\x01\x03\x05\x075\x133\x05\x07\x01\x0b\x03\x1b\x03#\x03\x03\x01\x15\x03\x15\x05\x07\x01\x07\x03\x05\x03;\x05\x07\x01\x17\x03\x1d\x039\x07\x06\x01\x03\x05\x07?\x15=\x0f\x04\x05\x077-A\x06\x03\x01\x05\x01\x00\xbe\nY\x1d\x03\x0f\x0b\t\t\t\x1b\x1d\r\x1b!+\x1b\x1f/!!)#\x1f\x19\x97y\x1f\x15\x1d\x15\x13%)\x13+\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/svd[full_matrices=True compute_uv=True]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00jax.arg_info\x00input\x00mhlo.sharding\x00{replicated}\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_zgesdd\x00', - xla_call_module_version=6, -) # End paste +array = np.array +float32 = np.float32 +complex64 = np.complex64 data_2024_08_13 = {} - # Pasted from the test output (see export_back_compat_test_util.py module docstring) data_2024_08_13["c128"] = dict( testdata_version=1, diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_triangular_solve_blas_trsm.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_triangular_solve_blas_trsm.py index c401ca041bfb..5b909d3a8d8f 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_triangular_solve_blas_trsm.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_triangular_solve_blas_trsm.py @@ -15,15 +15,21 @@ # ruff: noqa import datetime -from numpy import array, float32, complex64 +import numpy as np -data_2023_07_16 = {} -# Pasted from the test output (see back_compat_test_util.py module docstring) -data_2023_07_16["f32"] = dict( +array = np.array +float32 = np.float32 +complex64 = np.complex64 + +data_2025_10_20 = {} + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2025_10_20['f32'] = dict( testdata_version=1, platform='cpu', - custom_call_targets=['blas_strsm'], - serialized_date=datetime.date(2023, 7, 16), + custom_call_targets=['lapack_strsm_ffi'], + serialized_date=datetime.date(2025, 10, 20), inputs=(array([[ 5., 0., 0., 0.], [ 4., 10., 0., 0.], [ 8., 9., 15., 0.], @@ -40,34 +46,31 @@ [ 0.16833334 , 0.12173338 , 0.0751333 , 0.02853328 , -0.018066704]], dtype=float32),), mlir_module_text=r""" -#loc = loc(unknown) -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<4x4xf32> {jax.arg_info = "a", mhlo.sharding = "{replicated}"} loc(unknown), %arg1: tensor<4x5xf32> {jax.arg_info = "b", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<4x5xf32> {jax.result_info = ""}) { - %0 = stablehlo.constant dense<1.000000e+00> : tensor loc(#loc2) - %1 = stablehlo.constant dense<1> : tensor loc(#loc2) - %2 = stablehlo.constant dense<1> : tensor loc(#loc2) - %3 = stablehlo.constant dense<0> : tensor loc(#loc2) - %4 = stablehlo.constant dense<0> : tensor loc(#loc2) - %5 = stablehlo.constant dense<4> : tensor loc(#loc2) - %6 = stablehlo.constant dense<5> : tensor loc(#loc2) - %7 = stablehlo.constant dense<1> : tensor loc(#loc2) - %8 = stablehlo.custom_call @blas_strsm(%1, %2, %3, %4, %5, %6, %7, %0, %arg0, %arg1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor<4x4xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> loc(#loc2) - return %8 : tensor<4x5xf32> loc(#loc) +#loc1 = loc("a") +#loc2 = loc("b") +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<4x4xf32> loc("a"), %arg1: tensor<4x5xf32> loc("b")) -> (tensor<4x5xf32> {jax.result_info = "result"}) { + %0 = stablehlo.custom_call @lapack_strsm_ffi(%arg0, %arg1) {mhlo.backend_config = {diag = 78 : ui8, side = 76 : ui8, trans_x = 78 : ui8, uplo = 76 : ui8}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>], sdy.sharding_rule = #sdy.op_sharding_rule<([i, j], [k, l])->([m, n]) {i=4, j=4, k=4, l=5, m=4, n=5}, custom>} : (tensor<4x4xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> loc(#loc5) + return %0 : tensor<4x5xf32> loc(#loc) } loc(#loc) } loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":508:0) -#loc2 = loc("jit(func)/jit(main)/triangular_solve[left_side=True lower=True transpose_a=False conjugate_a=False unit_diagonal=False]"(#loc1)) +#loc = loc(unknown) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":654:13) +#loc4 = loc("jit(func)"(#loc3)) +#loc5 = loc("triangular_solve"(#loc4)) """, - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x19\x05\x01\x03\x01\x03\x05\x03\t\x07\t\x0b\r\x03\xa5{\x17\x01?\x0f\x07\x0b\x13\x0f\x0b\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03=\x0fO\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0f\x13\x0b\x0b\x0b\x1f\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b3\x0f\x13\x0f\x01\x03\x0f\x03\x15\x0f\x17\x07\x17\x0f\x07\x1b\x07\x13\x13\x02J\x04\x1d#%\x1f\x05\x0f\x03\x03\x05c\x11\x01\x05\x05\x11\x03\x03\x05e\x03\x07\x11\t\x13\t\x0b\x15\x05\x13\x05\x15\x05\x17\x03\x0b\x19K\x1bU\x1dW\x0b]\x1f_\x05\x19\x05\x1b\x05\x1d\x05\x1f\x03\x03\x05a\x05!\x17'\xf2\x07\x01\x05#\x03\x03\x05g\x03\x03\x05i\x03\x11/k1I3m5o7q9s;u=y\x05%\x05'\x05)\x05+\x05-\x05/\x051\x053\x1f\x13\x01\x1f\x15!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1d5\x1d7\x1d9\x1d;\x03\x05MQ\r\x05COEG\x1d=\r\x05CSEG\x1d?#\x0f\x03\x03Y\r\x03[I\x1dA\x1dC\x1dE\x1f\x0b\t\x00\x00\x80?\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x00\x00\x00\x00\x1f\x03\t\x04\x00\x00\x00\x1f\x03\t\x05\x00\x00\x00\x0b\x05\x1dG\x03\x01\x05\x01\x03\x15????????AA\x03\x03w\x15\x01%\x01\x03\x03A\x01\x02\x02)\x01\x11)\x05\x11\x15\x07\t)\x05\x11\x11\x07)\x01\x07\x13\x11\x05\t\x05\x03\x05\x1b)\x03\x01\r)\x03\t\r\x04\xb9\x05\x01\x11\x03\x0f\x07\x03\x01\x05\x05\x11\x03\x17\x05\x03\x17+\x05\t\x03\x05\x03\x03\x03\x01!\x03\x0b\x03\x03\x01\x07\x03\x03\x03\x03\x01\x07\x03\x03\x03\x03\x01\r\x03\x03\x03\x03\x01\r\x03\x03\x03\x03\x01)\x03\x03\x03\x03\x01+\x03\x03\x03\x03\x01\x07\x03\x03\x07\x07\x01-\x03\x05\x15\x07\t\x0b\r\x0f\x11\x13\x05\x01\x03\t\x04\x03\x03\x15\x06\x03\x01\x05\x01\x00\xca\tI\x17\x0f\x0b!\x05\x05\x03\x1b\x1d\x1b\x1f/!!)#\x1f\x19\x97\xf1\x1f\x15\x1d\x15\x13%)\x13\r\x15\x1f\x11\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00func_v1\x00custom_call_v1\x00return_v1\x00value\x00sym_name\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/triangular_solve[left_side=True lower=True transpose_a=False conjugate_a=False unit_diagonal=False]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.arg_info\x00mhlo.sharding\x00{replicated}\x00\x00a\x00b\x00jax.result_info\x00main\x00public\x00blas_strsm\x00", - xla_call_module_version=6, + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.12.1\x00\x01\x1b\x07\x01\x05\t\t\x01\x03\x0f\x03\x07\x13\x17\x1b\x03\xa5{\x13\x01-\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x17\x0b\x03;O\x0b\x0f\x0f\x13\x0b\x0f\x13\x0b\x0b\x0b\x0b+\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0f\x13\x0f\x05\x15K\x13\x0f\x0f\x13\x0f\x0f\x13\x0f\x0f\x01\x05\x0b\x0f\x03\x0f\x17\x17\x07\x07\x1b\x13\x07\x02\xba\x03\x1f\x11\x03\x05\x03\x07\x07\t\x0b\x03\r\x03\x05\x0f\x11\x01\x00\x05\x11\x05\x13\x05\x15\x1d\x13\x01\x05\x17\x1d\x17\x01\x05\x19\x03\x07\x1bE\x1dO\x1fg\x05\x1b\x05\x1d\x05\x1f\x1d#%\x05!\x1d')\x05#\x17+:\n\x1b\x05%\x1f\x0f!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\r\x01\x13\x0bN\x13\x0bL\x03\x05//#\r\x03\x03;\r\x03=?\x1d'\x1d)\x1d+\x1d-\r\tG1I3K1M3\x1d/\x1d1\x1d3\x1d5\r\x03QS\x1d7\x1d9\x0b\x03\x1d;\x1d=\x03\x01\x05\x01\x03\x05--\x03\x03c\x15\x01\x05\x01\x03\x03-\x15\r\x11\x11\x11\x15\x11\x15\x05io\x03u\x01\x01\x01\x01\x01\x13\x05km\x11\x03\x01\x11\x03\x05\x13\x05qs\x11\x03\t\x11\x03\r\x13\x05wy\x11\x03\x11\x11\x03\x15\x01\t\x01\x02\x02)\x05\x11\x15\t)\x05\x11\x11\t\t!\x11\x05\x07\x05\x03\x05)\x03\t\x11\x13\x04W\x05\x01Q\x01\x05\x01\x07\x04E\x03\x01\x05\x03P\x01\x03\x07\x041\x03\x07\x0b\x05\x0f\x11\x0b\x15\x00\x05G!\x19\x05\x03\x05\x05\x01\x03\x07\x04\x01\x03\x05\x06\x03\x01\x05\x01\x00N\x06?#\x03\x05\x1f\x0b\x11\x0b\x0b\x0f\x0b\x0f!i\x15#%3)\x05\x05\x13%)9\x15\x1f\x11\x0f\t\x0b\x11builtin\x00vhlo\x00sdy\x00module\x00func_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00a\x00b\x00mhlo.backend_config\x00mhlo.frontend_attributes\x00sdy.sharding_rule\x00triangular_solve\x00jit(func)\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.result_info\x00result\x00main\x00public\x00diag\x00side\x00trans_x\x00uplo\x00num_batch_dims\x000\x00\x00lapack_strsm_ffi\x00\x08'\x07\x05\x1f\x01\x0b579AC\x11UWY[]_ae", + xla_call_module_version=10, + nr_devices=1, ) # End paste -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_07_16["f64"] = dict( + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2025_10_20['f64'] = dict( testdata_version=1, platform='cpu', - custom_call_targets=['blas_dtrsm'], - serialized_date=datetime.date(2023, 7, 16), + custom_call_targets=['lapack_dtrsm_ffi'], + serialized_date=datetime.date(2025, 10, 20), inputs=(array([[ 5., 0., 0., 0.], [ 4., 10., 0., 0.], [ 8., 9., 15., 0.], @@ -88,35 +91,31 @@ 0.07513333333333323 , 0.0285333333333333 , -0.018066666666666675]]),), mlir_module_text=r""" -#loc = loc(unknown) -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<4x4xf64> {jax.arg_info = "a", mhlo.sharding = "{replicated}"} loc(unknown), %arg1: tensor<4x5xf64> {jax.arg_info = "b", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<4x5xf64> {jax.result_info = ""}) { - %0 = stablehlo.constant dense<1.000000e+00> : tensor loc(#loc2) - %1 = stablehlo.constant dense<1> : tensor loc(#loc2) - %2 = stablehlo.constant dense<1> : tensor loc(#loc2) - %3 = stablehlo.constant dense<0> : tensor loc(#loc2) - %4 = stablehlo.constant dense<0> : tensor loc(#loc2) - %5 = stablehlo.constant dense<4> : tensor loc(#loc2) - %6 = stablehlo.constant dense<5> : tensor loc(#loc2) - %7 = stablehlo.constant dense<1> : tensor loc(#loc2) - %8 = stablehlo.custom_call @blas_dtrsm(%1, %2, %3, %4, %5, %6, %7, %0, %arg0, %arg1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor<4x4xf64>, tensor<4x5xf64>) -> tensor<4x5xf64> loc(#loc2) - return %8 : tensor<4x5xf64> loc(#loc) +#loc1 = loc("a") +#loc2 = loc("b") +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<4x4xf64> loc("a"), %arg1: tensor<4x5xf64> loc("b")) -> (tensor<4x5xf64> {jax.result_info = "result"}) { + %0 = stablehlo.custom_call @lapack_dtrsm_ffi(%arg0, %arg1) {mhlo.backend_config = {diag = 78 : ui8, side = 76 : ui8, trans_x = 78 : ui8, uplo = 76 : ui8}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>], sdy.sharding_rule = #sdy.op_sharding_rule<([i, j], [k, l])->([m, n]) {i=4, j=4, k=4, l=5, m=4, n=5}, custom>} : (tensor<4x4xf64>, tensor<4x5xf64>) -> tensor<4x5xf64> loc(#loc5) + return %0 : tensor<4x5xf64> loc(#loc) } loc(#loc) } loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":511:0) -#loc2 = loc("jit(func)/jit(main)/triangular_solve[left_side=True lower=True transpose_a=False conjugate_a=False unit_diagonal=False]"(#loc1)) +#loc = loc(unknown) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":654:13) +#loc4 = loc("jit(func)"(#loc3)) +#loc5 = loc("triangular_solve"(#loc4)) """, - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x19\x05\x01\x03\x01\x03\x05\x03\t\x07\t\x0b\r\x03\xa5{\x17\x01?\x0f\x07\x0b\x13\x0f\x0b\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03=\x0fO\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0f\x13\x0b\x0b\x0b/\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b3\x0f\x13\x0f\x01\x03\x0f\x03\x15\x0f\x17\x07\x17\x0f\x07\x1b\x07\x13\x13\x02Z\x04\x1d#%\x1f\x05\x0f\x03\x03\x05c\x11\x01\x05\x05\x11\x03\x03\x05e\x03\x07\x11\t\x13\t\x0b\x15\x05\x13\x05\x15\x05\x17\x03\x0b\x19K\x1bU\x1dW\x0b]\x1f_\x05\x19\x05\x1b\x05\x1d\x05\x1f\x03\x03\x05a\x05!\x17'\xfe\x07\x01\x05#\x03\x03\x05g\x03\x03\x05i\x03\x11/k1I3m5o7q9s;u=y\x05%\x05'\x05)\x05+\x05-\x05/\x051\x053\x1f\x13\x01\x1f\x15!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1d5\x1d7\x1d9\x1d;\x03\x05MQ\r\x05COEG\x1d=\r\x05CSEG\x1d?#\x0f\x03\x03Y\r\x03[I\x1dA\x1dC\x1dE\x1f\x0b\x11\x00\x00\x00\x00\x00\x00\xf0?\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x00\x00\x00\x00\x1f\x03\t\x04\x00\x00\x00\x1f\x03\t\x05\x00\x00\x00\x0b\x05\x1dG\x03\x01\x05\x01\x03\x15????????AA\x03\x03w\x15\x01%\x01\x03\x03A\x01\x02\x02)\x01\x11)\x05\x11\x15\x07\x0b)\x05\x11\x11\x07)\x01\x07\x13\x11\x05\t\x05\x03\x05\x1b)\x03\x01\r)\x03\t\r\x04\xb9\x05\x01\x11\x03\x0f\x07\x03\x01\x05\x05\x11\x03\x17\x05\x03\x17+\x05\t\x03\x05\x03\x03\x03\x01!\x03\x0b\x03\x03\x01\x07\x03\x03\x03\x03\x01\x07\x03\x03\x03\x03\x01\r\x03\x03\x03\x03\x01\r\x03\x03\x03\x03\x01)\x03\x03\x03\x03\x01+\x03\x03\x03\x03\x01\x07\x03\x03\x07\x07\x01-\x03\x05\x15\x07\t\x0b\r\x0f\x11\x13\x05\x01\x03\t\x04\x03\x03\x15\x06\x03\x01\x05\x01\x00\xca\tI\x17\x0f\x0b!\x05\x05\x03\x1b\x1d\x1b\x1f/!!)#\x1f\x19\x97\xf1\x1f\x15\x1d\x15\x13%)\x13\r\x15\x1f\x11\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00func_v1\x00custom_call_v1\x00return_v1\x00value\x00sym_name\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/triangular_solve[left_side=True lower=True transpose_a=False conjugate_a=False unit_diagonal=False]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.arg_info\x00mhlo.sharding\x00{replicated}\x00\x00a\x00b\x00jax.result_info\x00main\x00public\x00blas_dtrsm\x00", - xla_call_module_version=6, + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.12.1\x00\x01\x1b\x07\x01\x05\t\t\x01\x03\x0f\x03\x07\x13\x17\x1b\x03\xa5{\x13\x01-\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x17\x0b\x03;O\x0b\x0f\x0f\x13\x0b\x0f\x13\x0b\x0b\x0b\x0b+\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0f\x13\x0f\x05\x15K\x13\x0f\x0f\x13\x0f\x0f\x13\x0f\x0f\x01\x05\x0b\x0f\x03\x0f\x17\x17\x07\x07\x1b\x13\x07\x02\xba\x03\x1f\x11\x03\x05\x03\x07\x07\t\x0b\x03\r\x03\x05\x0f\x11\x01\x00\x05\x11\x05\x13\x05\x15\x1d\x13\x01\x05\x17\x1d\x17\x01\x05\x19\x03\x07\x1bE\x1dO\x1fg\x05\x1b\x05\x1d\x05\x1f\x1d#%\x05!\x1d')\x05#\x17+:\n\x1b\x05%\x1f\x0f!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\r\x01\x13\x0bN\x13\x0bL\x03\x05//#\r\x03\x03;\r\x03=?\x1d'\x1d)\x1d+\x1d-\r\tG1I3K1M3\x1d/\x1d1\x1d3\x1d5\r\x03QS\x1d7\x1d9\x0b\x03\x1d;\x1d=\x03\x01\x05\x01\x03\x05--\x03\x03c\x15\x01\x05\x01\x03\x03-\x15\r\x11\x11\x11\x15\x11\x15\x05io\x03u\x01\x01\x01\x01\x01\x13\x05km\x11\x03\x01\x11\x03\x05\x13\x05qs\x11\x03\t\x11\x03\r\x13\x05wy\x11\x03\x11\x11\x03\x15\x01\t\x01\x02\x02)\x05\x11\x15\t)\x05\x11\x11\t\x0b!\x11\x05\x07\x05\x03\x05)\x03\t\x11\x13\x04W\x05\x01Q\x01\x05\x01\x07\x04E\x03\x01\x05\x03P\x01\x03\x07\x041\x03\x07\x0b\x05\x0f\x11\x0b\x15\x00\x05G!\x19\x05\x03\x05\x05\x01\x03\x07\x04\x01\x03\x05\x06\x03\x01\x05\x01\x00N\x06?#\x03\x05\x1f\x0b\x11\x0b\x0b\x0f\x0b\x0f!i\x15#%3)\x05\x05\x13%)9\x15\x1f\x11\x0f\t\x0b\x11builtin\x00vhlo\x00sdy\x00module\x00func_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00a\x00b\x00mhlo.backend_config\x00mhlo.frontend_attributes\x00sdy.sharding_rule\x00triangular_solve\x00jit(func)\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.result_info\x00result\x00main\x00public\x00diag\x00side\x00trans_x\x00uplo\x00num_batch_dims\x000\x00\x00lapack_dtrsm_ffi\x00\x08'\x07\x05\x1f\x01\x0b579AC\x11UWY[]_ae", + xla_call_module_version=10, + nr_devices=1, ) # End paste -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_07_16["c64"] = dict( +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2025_10_20['c64'] = dict( testdata_version=1, platform='cpu', - custom_call_targets=['blas_ctrsm'], - serialized_date=datetime.date(2023, 7, 16), + custom_call_targets=['lapack_ctrsm_ffi'], + serialized_date=datetime.date(2025, 10, 20), inputs=(array([[ 5.+0.j, 0.+0.j, 0.+0.j, 0.+0.j], [ 4.+0.j, 10.+0.j, 0.+0.j, 0.+0.j], [ 8.+0.j, 9.+0.j, 15.+0.j, 0.+0.j], @@ -133,35 +132,31 @@ [ 0.16833334 +0.j, 0.12173338 +0.j, 0.0751333 +0.j, 0.02853328 +0.j, -0.018066704+0.j]], dtype=complex64),), mlir_module_text=r""" -#loc = loc(unknown) -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<4x4xcomplex> {jax.arg_info = "a", mhlo.sharding = "{replicated}"} loc(unknown), %arg1: tensor<4x5xcomplex> {jax.arg_info = "b", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<4x5xcomplex> {jax.result_info = ""}) { - %0 = stablehlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor> loc(#loc2) - %1 = stablehlo.constant dense<1> : tensor loc(#loc2) - %2 = stablehlo.constant dense<1> : tensor loc(#loc2) - %3 = stablehlo.constant dense<0> : tensor loc(#loc2) - %4 = stablehlo.constant dense<0> : tensor loc(#loc2) - %5 = stablehlo.constant dense<4> : tensor loc(#loc2) - %6 = stablehlo.constant dense<5> : tensor loc(#loc2) - %7 = stablehlo.constant dense<1> : tensor loc(#loc2) - %8 = stablehlo.custom_call @blas_ctrsm(%1, %2, %3, %4, %5, %6, %7, %0, %arg0, %arg1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor>, tensor<4x4xcomplex>, tensor<4x5xcomplex>) -> tensor<4x5xcomplex> loc(#loc2) - return %8 : tensor<4x5xcomplex> loc(#loc) +#loc1 = loc("a") +#loc2 = loc("b") +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<4x4xcomplex> loc("a"), %arg1: tensor<4x5xcomplex> loc("b")) -> (tensor<4x5xcomplex> {jax.result_info = "result"}) { + %0 = stablehlo.custom_call @lapack_ctrsm_ffi(%arg0, %arg1) {mhlo.backend_config = {diag = 78 : ui8, side = 76 : ui8, trans_x = 78 : ui8, uplo = 76 : ui8}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>], sdy.sharding_rule = #sdy.op_sharding_rule<([i, j], [k, l])->([m, n]) {i=4, j=4, k=4, l=5, m=4, n=5}, custom>} : (tensor<4x4xcomplex>, tensor<4x5xcomplex>) -> tensor<4x5xcomplex> loc(#loc5) + return %0 : tensor<4x5xcomplex> loc(#loc) } loc(#loc) } loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":510:0) -#loc2 = loc("jit(func)/jit(main)/triangular_solve[left_side=True lower=True transpose_a=False conjugate_a=False unit_diagonal=False]"(#loc1)) +#loc = loc(unknown) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":654:13) +#loc4 = loc("jit(func)"(#loc3)) +#loc5 = loc("triangular_solve"(#loc4)) """, - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x19\x05\x01\x03\x01\x03\x05\x03\t\x07\t\x0b\r\x03\xa7{\x19\x01?\x0f\x07\x0b\x13\x0f\x0b\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03=\x0fO\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0f\x13\x0b\x0b\x0b/\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b3\x0f\x13\x0f\x01\x03\x0f\x03\x17\x0f\x17\x0b\x17\x0f\x07\x1b\x07\x07\x13\x13\x02b\x04\x1d#%\x1f\x05\x0f\x03\x03\x05c\x11\x01\x05\x05\x11\x03\x03\x05e\x03\x07\x11\t\x13\t\x0b\x15\x05\x13\x05\x15\x05\x17\x03\x0b\x19K\x1bU\x1dW\x0b]\x1f_\x05\x19\x05\x1b\x05\x1d\x05\x1f\x03\x03\x05a\x05!\x17'\xfa\x07\x01\x05#\x03\x03\x05g\x03\x03\x05i\x03\x11/k1I3m5o7q9s;u=y\x05%\x05'\x05)\x05+\x05-\x05/\x051\x053\x1f\x15\x01\x1f\x17!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1d5\x1d7\x1d9\x1d;\x03\x05MQ\r\x05COEG\x1d=\r\x05CSEG\x1d?#\x0f\x03\x03Y\r\x03[I\x1dA\x1dC\x1dE\x1f\x0b\x11\x00\x00\x80?\x00\x00\x00\x00\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x00\x00\x00\x00\x1f\x03\t\x04\x00\x00\x00\x1f\x03\t\x05\x00\x00\x00\x0b\x05\x1dG\x03\x01\x05\x01\x03\x15????????AA\x03\x03w\x15\x01%\x01\x03\x03A\x01\x02\x02)\x01\x13)\x05\x11\x15\x07\x03\x11)\x05\x11\x11\x07)\x01\x07\x13\x11\x05\t\x05\x03\x05\t\x1b)\x03\x01\r)\x03\t\r\x04\xb9\x05\x01\x11\x03\x0f\x07\x03\x01\x05\x05\x11\x03\x17\x05\x03\x17+\x05\t\x03\x05\x03\x03\x03\x01!\x03\x0b\x03\x03\x01\x07\x03\x03\x03\x03\x01\x07\x03\x03\x03\x03\x01\r\x03\x03\x03\x03\x01\r\x03\x03\x03\x03\x01)\x03\x03\x03\x03\x01+\x03\x03\x03\x03\x01\x07\x03\x03\x07\x07\x01-\x03\x05\x15\x07\t\x0b\r\x0f\x11\x13\x05\x01\x03\t\x04\x03\x03\x15\x06\x03\x01\x05\x01\x00\xca\tI\x17\x0f\x0b!\x05\x05\x03\x1b\x1d\x1b\x1f/!!)#\x1f\x19\x97\xf1\x1f\x15\x1d\x15\x13%)\x13\r\x15\x1f\x11\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00func_v1\x00custom_call_v1\x00return_v1\x00value\x00sym_name\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/triangular_solve[left_side=True lower=True transpose_a=False conjugate_a=False unit_diagonal=False]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.arg_info\x00mhlo.sharding\x00{replicated}\x00\x00a\x00b\x00jax.result_info\x00main\x00public\x00blas_ctrsm\x00", - xla_call_module_version=6, + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.12.1\x00\x01\x1b\x07\x01\x05\t\t\x01\x03\x0f\x03\x07\x13\x17\x1b\x03\xa7{\x15\x01-\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x17\x0b\x03;O\x0b\x0f\x0f\x13\x0b\x0f\x13\x0b\x0b\x0b\x0b+\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0f\x13\x0f\x05\x15K\x13\x0f\x0f\x13\x0f\x0f\x13\x0f\x0f\x01\x05\x0b\x0f\x03\x11\x17\x17\x0b\x07\x1b\x07\x13\x07\x02\xc2\x03\x1f\x11\x03\x05\x03\x07\x07\t\x0b\x03\r\x03\x05\x0f\x11\x01\x00\x05\x11\x05\x13\x05\x15\x1d\x13\x01\x05\x17\x1d\x17\x01\x05\x19\x03\x07\x1bE\x1dO\x1fg\x05\x1b\x05\x1d\x05\x1f\x1d#%\x05!\x1d')\x05#\x17+:\n\x1b\x05%\x1f\x11!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\r\x01\x13\x0bN\x13\x0bL\x03\x05//#\r\x03\x03;\r\x03=?\x1d'\x1d)\x1d+\x1d-\r\tG1I3K1M3\x1d/\x1d1\x1d3\x1d5\r\x03QS\x1d7\x1d9\x0b\x03\x1d;\x1d=\x03\x01\x05\x01\x03\x05--\x03\x03c\x15\x01\x05\x01\x03\x03-\x15\r\x11\x11\x11\x15\x11\x15\x05io\x03u\x01\x01\x01\x01\x01\x13\x05km\x11\x03\x01\x11\x03\x05\x13\x05qs\x11\x03\t\x11\x03\r\x13\x05wy\x11\x03\x11\x11\x03\x15\x01\t\x01\x02\x02)\x05\x11\x15\t)\x05\x11\x11\t\x03\x0f!\x11\x05\x07\x05\x03\x05\t)\x03\t\x13\x13\x04W\x05\x01Q\x01\x05\x01\x07\x04E\x03\x01\x05\x03P\x01\x03\x07\x041\x03\x07\x0b\x05\x0f\x11\x0b\x15\x00\x05G!\x19\x05\x03\x05\x05\x01\x03\x07\x04\x01\x03\x05\x06\x03\x01\x05\x01\x00N\x06?#\x03\x05\x1f\x0b\x11\x0b\x0b\x0f\x0b\x0f!i\x15#%3)\x05\x05\x13%)9\x15\x1f\x11\x0f\t\x0b\x11builtin\x00vhlo\x00sdy\x00module\x00func_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00a\x00b\x00mhlo.backend_config\x00mhlo.frontend_attributes\x00sdy.sharding_rule\x00triangular_solve\x00jit(func)\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.result_info\x00result\x00main\x00public\x00diag\x00side\x00trans_x\x00uplo\x00num_batch_dims\x000\x00\x00lapack_ctrsm_ffi\x00\x08'\x07\x05\x1f\x01\x0b579AC\x11UWY[]_ae", + xla_call_module_version=10, + nr_devices=1, ) # End paste -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_07_16["c128"] = dict( +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2025_10_20['c128'] = dict( testdata_version=1, platform='cpu', - custom_call_targets=['blas_ztrsm'], - serialized_date=datetime.date(2023, 7, 16), + custom_call_targets=['lapack_ztrsm_ffi'], + serialized_date=datetime.date(2025, 10, 20), inputs=(array([[ 5.+0.j, 0.+0.j, 0.+0.j, 0.+0.j], [ 4.+0.j, 10.+0.j, 0.+0.j, 0.+0.j], [ 8.+0.j, 9.+0.j, 15.+0.j, 0.+0.j], @@ -182,302 +177,20 @@ 0.07513333333333323 +0.j, 0.0285333333333333 +0.j, -0.018066666666666675+0.j]]),), mlir_module_text=r""" -#loc = loc(unknown) -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<4x4xcomplex> {jax.arg_info = "a", mhlo.sharding = "{replicated}"} loc(unknown), %arg1: tensor<4x5xcomplex> {jax.arg_info = "b", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<4x5xcomplex> {jax.result_info = ""}) { - %0 = stablehlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor> loc(#loc2) - %1 = stablehlo.constant dense<1> : tensor loc(#loc2) - %2 = stablehlo.constant dense<1> : tensor loc(#loc2) - %3 = stablehlo.constant dense<0> : tensor loc(#loc2) - %4 = stablehlo.constant dense<0> : tensor loc(#loc2) - %5 = stablehlo.constant dense<4> : tensor loc(#loc2) - %6 = stablehlo.constant dense<5> : tensor loc(#loc2) - %7 = stablehlo.constant dense<1> : tensor loc(#loc2) - %8 = stablehlo.custom_call @blas_ztrsm(%1, %2, %3, %4, %5, %6, %7, %0, %arg0, %arg1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor>, tensor<4x4xcomplex>, tensor<4x5xcomplex>) -> tensor<4x5xcomplex> loc(#loc2) - return %8 : tensor<4x5xcomplex> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":510:0) -#loc2 = loc("jit(func)/jit(main)/triangular_solve[left_side=True lower=True transpose_a=False conjugate_a=False unit_diagonal=False]"(#loc1)) -""", - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x19\x05\x01\x03\x01\x03\x05\x03\t\x07\t\x0b\r\x03\xa7{\x19\x01?\x0f\x07\x0b\x13\x0f\x0b\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03=\x0fO\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0f\x13\x0b\x0b\x0bO\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b3\x0f\x13\x0f\x01\x03\x0f\x03\x17\x0f\x17\x0b\x17\x0f\x07\x1b\x07\x07\x13\x13\x02\x82\x04\x1d#%\x1f\x05\x0f\x03\x03\x05c\x11\x01\x05\x05\x11\x03\x03\x05e\x03\x07\x11\t\x13\t\x0b\x15\x05\x13\x05\x15\x05\x17\x03\x0b\x19K\x1bU\x1dW\x0b]\x1f_\x05\x19\x05\x1b\x05\x1d\x05\x1f\x03\x03\x05a\x05!\x17'\xfa\x07\x01\x05#\x03\x03\x05g\x03\x03\x05i\x03\x11/k1I3m5o7q9s;u=y\x05%\x05'\x05)\x05+\x05-\x05/\x051\x053\x1f\x15\x01\x1f\x17!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1d5\x1d7\x1d9\x1d;\x03\x05MQ\r\x05COEG\x1d=\r\x05CSEG\x1d?#\x0f\x03\x03Y\r\x03[I\x1dA\x1dC\x1dE\x1f\x0b!\x00\x00\x00\x00\x00\x00\xf0?\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x00\x00\x00\x00\x1f\x03\t\x04\x00\x00\x00\x1f\x03\t\x05\x00\x00\x00\x0b\x05\x1dG\x03\x01\x05\x01\x03\x15????????AA\x03\x03w\x15\x01%\x01\x03\x03A\x01\x02\x02)\x01\x13)\x05\x11\x15\x07\x03\x11)\x05\x11\x11\x07)\x01\x07\x13\x11\x05\t\x05\x03\x05\x0b\x1b)\x03\x01\r)\x03\t\r\x04\xb9\x05\x01\x11\x03\x0f\x07\x03\x01\x05\x05\x11\x03\x17\x05\x03\x17+\x05\t\x03\x05\x03\x03\x03\x01!\x03\x0b\x03\x03\x01\x07\x03\x03\x03\x03\x01\x07\x03\x03\x03\x03\x01\r\x03\x03\x03\x03\x01\r\x03\x03\x03\x03\x01)\x03\x03\x03\x03\x01+\x03\x03\x03\x03\x01\x07\x03\x03\x07\x07\x01-\x03\x05\x15\x07\t\x0b\r\x0f\x11\x13\x05\x01\x03\t\x04\x03\x03\x15\x06\x03\x01\x05\x01\x00\xca\tI\x17\x0f\x0b!\x05\x05\x03\x1b\x1d\x1b\x1f/!!)#\x1f\x19\x97\xf1\x1f\x15\x1d\x15\x13%)\x13\r\x15\x1f\x11\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00func_v1\x00custom_call_v1\x00return_v1\x00value\x00sym_name\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/triangular_solve[left_side=True lower=True transpose_a=False conjugate_a=False unit_diagonal=False]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.arg_info\x00mhlo.sharding\x00{replicated}\x00\x00a\x00b\x00jax.result_info\x00main\x00public\x00blas_ztrsm\x00", - xla_call_module_version=6, -) # End paste - -data_2024_12_02 = {} - - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_12_02['c128'] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_ztrsm_ffi'], - serialized_date=datetime.date(2024, 12, 2), - inputs=( - array([ - [5.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j], - [4.0 + 0.0j, 10.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j], - [8.0 + 0.0j, 9.0 + 0.0j, 15.0 + 0.0j, 0.0 + 0.0j], - [12.0 + 0.0j, 13.0 + 0.0j, 14.0 + 0.0j, 20.0 + 0.0j], - ]), - array([ - [0.0 + 0.0j, 1.0 + 0.0j, 2.0 + 0.0j, 3.0 + 0.0j, 4.0 + 0.0j], - [5.0 + 0.0j, 6.0 + 0.0j, 7.0 + 0.0j, 8.0 + 0.0j, 9.0 + 0.0j], - [10.0 + 0.0j, 11.0 + 0.0j, 12.0 + 0.0j, 13.0 + 0.0j, 14.0 + 0.0j], - [15.0 + 0.0j, 16.0 + 0.0j, 17.0 + 0.0j, 18.0 + 0.0j, 19.0 + 0.0j], - ]), - ), - expected_outputs=( - array([ - [ - 0.0 + 0.0j, - 0.2 + 0.0j, - 0.4 + 0.0j, - 0.6000000000000001 + 0.0j, - 0.8 + 0.0j, - ], - [ - 0.5 + 0.0j, - 0.52 + 0.0j, - 0.54 + 0.0j, - 0.5599999999999999 + 0.0j, - 0.58 + 0.0j, - ], - [ - 0.36666666666666664 + 0.0j, - 0.3146666666666667 + 0.0j, - 0.2626666666666667 + 0.0j, - 0.21066666666666667 + 0.0j, - 0.15866666666666665 + 0.0j, - ], - [ - 0.16833333333333336 + 0.0j, - 0.1217333333333333 + 0.0j, - 0.07513333333333323 + 0.0j, - 0.0285333333333333 + 0.0j, - -0.018066666666666675 + 0.0j, - ], - ]), - ), - mlir_module_text=r""" #loc1 = loc("a") #loc2 = loc("b") module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<4x4xcomplex> loc("a"), %arg1: tensor<4x5xcomplex> loc("b")) -> (tensor<4x5xcomplex> {jax.result_info = ""}) { - %cst = stablehlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor> loc(#loc4) - %0 = stablehlo.custom_call @lapack_ztrsm_ffi(%arg0, %arg1, %cst) {mhlo.backend_config = {diag = 78 : ui8, side = 76 : ui8, trans_x = 78 : ui8, uplo = 76 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<4x4xcomplex>, tensor<4x5xcomplex>, tensor>) -> tensor<4x5xcomplex> loc(#loc4) + func.func public @main(%arg0: tensor<4x4xcomplex> loc("a"), %arg1: tensor<4x5xcomplex> loc("b")) -> (tensor<4x5xcomplex> {jax.result_info = "result"}) { + %0 = stablehlo.custom_call @lapack_ztrsm_ffi(%arg0, %arg1) {mhlo.backend_config = {diag = 78 : ui8, side = 76 : ui8, trans_x = 78 : ui8, uplo = 76 : ui8}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>], sdy.sharding_rule = #sdy.op_sharding_rule<([i, j], [k, l])->([m, n]) {i=4, j=4, k=4, l=5, m=4, n=5}, custom>} : (tensor<4x4xcomplex>, tensor<4x5xcomplex>) -> tensor<4x5xcomplex> loc(#loc5) return %0 : tensor<4x5xcomplex> loc(#loc) } loc(#loc) } loc(#loc) #loc = loc(unknown) -#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":715:13) -#loc4 = loc("jit(func)/jit(main)/triangular_solve"(#loc3)) -""", - mlir_module_serialized=b"ML\xefR\rStableHLO_v1.8.1\x00\x01\x1b\x05\x01\x05\x0b\x01\x03\x0b\x03\t\x0f\x13\x17\x1b\x03\x87[\x19\x01%\x07\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0b\x17\x0b\x13\x0b\x037O\x0b\x0b\x0f\x0f\x13\x0b\x0f\x13\x0b\x0b\x0bO+\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x17\x0f\x0f\x13\x0f\x01\x05\x0b\x0f\x03\x15\x17\x0b\x17\x0f\x07\x07\x1b\x07\x13\x13\x02\x1e\x03\x1f\x11\x03\x05\x1d\x1b\x1d\x03\x07\t\x0b\r\x03\x0f\x03\x05\x0f\x11\x01\x00\x05\x11\x05\x13\x05\x15\x1d\x15\x01\x05\x17\x1d\x19\x01\x05\x19\x05\x1b\x17\x1f.\x0b\x1b\x05\x1d\x03\x03#?\x05\x1f\x1f\x15!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\r\x01\x1d!\x13\rN\x13\rL\x03\x05''#\x11\x03\x035\r\x037)\x1d#\x1d%\x1d'\x1f\x0b!\x00\x00\x00\x00\x00\x00\xf0?\x00\x00\x00\x00\x00\x00\x00\x00\r\tA+C-E+G-\x1d)\x1d+\x1d-\x1d/\x0b\x03\x1d1\x03\x01\x05\x01\x03\x07%%S\x1f\x17\x01\x03\x03W\x15\x01\x05\x01\x03\x03%\x01\t\x01\x02\x02)\x05\x11\x15\x07\x03\x13)\x05\x11\x11\x07)\x01\x07!\x13\x11\x05\t\x05\x03\x05\x0b)\x03\t\x0f)\x03\x01\x0f\x04e\x05\x01Q\x01\x07\x01\x07\x04S\x03\x01\x05\x03P\x01\x03\x07\x04?\x03\t\x0f\x05\x13\x13\x0b\x17\x00\x05B\x05\x05\x03\x0b\x07G\x05!\x07\x03\x05\x07\x01\x03\x05\t\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00r\x053#\x0b\x11\x0b\x0b\x0f\x0b!\x03)iK\x05\x05\x13%)9\x15\x1f\x19\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00constant_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00a\x00b\x00jit(func)/jit(main)/triangular_solve\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.backend_config\x00\x00jax.result_info\x00main\x00public\x00diag\x00side\x00trans_x\x00uplo\x00lapack_ztrsm_ffi\x00\x08+\t\x05#\x01\x0b/139;\x03=\x11I)KMOQUY", - xla_call_module_version=9, - nr_devices=1, -) # End paste - - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_12_02['c64'] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_ctrsm_ffi'], - serialized_date=datetime.date(2024, 12, 2), - inputs=( - array( - [ - [5.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j], - [4.0 + 0.0j, 10.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j], - [8.0 + 0.0j, 9.0 + 0.0j, 15.0 + 0.0j, 0.0 + 0.0j], - [12.0 + 0.0j, 13.0 + 0.0j, 14.0 + 0.0j, 20.0 + 0.0j], - ], - dtype=complex64, - ), - array( - [ - [0.0 + 0.0j, 1.0 + 0.0j, 2.0 + 0.0j, 3.0 + 0.0j, 4.0 + 0.0j], - [5.0 + 0.0j, 6.0 + 0.0j, 7.0 + 0.0j, 8.0 + 0.0j, 9.0 + 0.0j], - [ - 10.0 + 0.0j, - 11.0 + 0.0j, - 12.0 + 0.0j, - 13.0 + 0.0j, - 14.0 + 0.0j, - ], - [ - 15.0 + 0.0j, - 16.0 + 0.0j, - 17.0 + 0.0j, - 18.0 + 0.0j, - 19.0 + 0.0j, - ], - ], - dtype=complex64, - ), - ), - expected_outputs=( - array( - [ - [0.0 + 0.0j, 0.2 + 0.0j, 0.4 + 0.0j, 0.6 + 0.0j, 0.8 + 0.0j], - [ - 0.5 + 0.0j, - 0.52 + 0.0j, - 0.54 + 0.0j, - 0.56 + 0.0j, - 0.58000004 + 0.0j, - ], - [ - 0.36666667 + 0.0j, - 0.31466666 + 0.0j, - 0.26266667 + 0.0j, - 0.21066667 + 0.0j, - 0.15866666 + 0.0j, - ], - [ - 0.16833334 + 0.0j, - 0.12173338 + 0.0j, - 0.0751333 + 0.0j, - 0.02853328 + 0.0j, - -0.018066704 + 0.0j, - ], - ], - dtype=complex64, - ), - ), - mlir_module_text=r""" -#loc1 = loc("a") -#loc2 = loc("b") -module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<4x4xcomplex> loc("a"), %arg1: tensor<4x5xcomplex> loc("b")) -> (tensor<4x5xcomplex> {jax.result_info = ""}) { - %cst = stablehlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor> loc(#loc4) - %0 = stablehlo.custom_call @lapack_ctrsm_ffi(%arg0, %arg1, %cst) {mhlo.backend_config = {diag = 78 : ui8, side = 76 : ui8, trans_x = 78 : ui8, uplo = 76 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<4x4xcomplex>, tensor<4x5xcomplex>, tensor>) -> tensor<4x5xcomplex> loc(#loc4) - return %0 : tensor<4x5xcomplex> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc = loc(unknown) -#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":715:13) -#loc4 = loc("jit(func)/jit(main)/triangular_solve"(#loc3)) -""", - mlir_module_serialized=b"ML\xefR\rStableHLO_v1.8.1\x00\x01\x1b\x05\x01\x05\x0b\x01\x03\x0b\x03\t\x0f\x13\x17\x1b\x03\x87[\x19\x01%\x07\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0b\x17\x0b\x13\x0b\x037O\x0b\x0b\x0f\x0f\x13\x0b\x0f\x13\x0b\x0b\x0b/+\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x17\x0f\x0f\x13\x0f\x01\x05\x0b\x0f\x03\x15\x17\x0b\x17\x0f\x07\x07\x1b\x07\x13\x13\x02\xfe\x02\x1f\x11\x03\x05\x1d\x1b\x1d\x03\x07\t\x0b\r\x03\x0f\x03\x05\x0f\x11\x01\x00\x05\x11\x05\x13\x05\x15\x1d\x15\x01\x05\x17\x1d\x19\x01\x05\x19\x05\x1b\x17\x1f.\x0b\x1b\x05\x1d\x03\x03#?\x05\x1f\x1f\x15!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\r\x01\x1d!\x13\rN\x13\rL\x03\x05''#\x11\x03\x035\r\x037)\x1d#\x1d%\x1d'\x1f\x0b\x11\x00\x00\x80?\x00\x00\x00\x00\r\tA+C-E+G-\x1d)\x1d+\x1d-\x1d/\x0b\x03\x1d1\x03\x01\x05\x01\x03\x07%%S\x1f\x17\x01\x03\x03W\x15\x01\x05\x01\x03\x03%\x01\t\x01\x02\x02)\x05\x11\x15\x07\x03\x13)\x05\x11\x11\x07)\x01\x07!\x13\x11\x05\t\x05\x03\x05\t)\x03\t\x0f)\x03\x01\x0f\x04e\x05\x01Q\x01\x07\x01\x07\x04S\x03\x01\x05\x03P\x01\x03\x07\x04?\x03\t\x0f\x05\x13\x13\x0b\x17\x00\x05B\x05\x05\x03\x0b\x07G\x05!\x07\x03\x05\x07\x01\x03\x05\t\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00r\x053#\x0b\x11\x0b\x0b\x0f\x0b!\x03)iK\x05\x05\x13%)9\x15\x1f\x19\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00constant_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00a\x00b\x00jit(func)/jit(main)/triangular_solve\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.backend_config\x00\x00jax.result_info\x00main\x00public\x00diag\x00side\x00trans_x\x00uplo\x00lapack_ctrsm_ffi\x00\x08+\t\x05#\x01\x0b/139;\x03=\x11I)KMOQUY", - xla_call_module_version=9, - nr_devices=1, -) # End paste - - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_12_02['f32'] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_strsm_ffi'], - serialized_date=datetime.date(2024, 12, 2), - inputs=( - array( - [ - [5.0, 0.0, 0.0, 0.0], - [4.0, 10.0, 0.0, 0.0], - [8.0, 9.0, 15.0, 0.0], - [12.0, 13.0, 14.0, 20.0], - ], - dtype=float32, - ), - array( - [ - [0.0, 1.0, 2.0, 3.0, 4.0], - [5.0, 6.0, 7.0, 8.0, 9.0], - [10.0, 11.0, 12.0, 13.0, 14.0], - [15.0, 16.0, 17.0, 18.0, 19.0], - ], - dtype=float32, - ), - ), - expected_outputs=( - array( - [ - [0.0, 0.2, 0.4, 0.6, 0.8], - [0.5, 0.52, 0.54, 0.56, 0.58000004], - [0.36666667, 0.31466666, 0.26266667, 0.21066667, 0.15866666], - [0.16833334, 0.12173338, 0.0751333, 0.02853328, -0.018066704], - ], - dtype=float32, - ), - ), - mlir_module_text=r""" -#loc1 = loc("a") -#loc2 = loc("b") -module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<4x4xf32> loc("a"), %arg1: tensor<4x5xf32> loc("b")) -> (tensor<4x5xf32> {jax.result_info = ""}) { - %cst = stablehlo.constant dense<1.000000e+00> : tensor loc(#loc4) - %0 = stablehlo.custom_call @lapack_strsm_ffi(%arg0, %arg1, %cst) {mhlo.backend_config = {diag = 78 : ui8, side = 76 : ui8, trans_x = 78 : ui8, uplo = 76 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<4x4xf32>, tensor<4x5xf32>, tensor) -> tensor<4x5xf32> loc(#loc4) - return %0 : tensor<4x5xf32> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc = loc(unknown) -#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":715:13) -#loc4 = loc("jit(func)/jit(main)/triangular_solve"(#loc3)) -""", - mlir_module_serialized=b"ML\xefR\rStableHLO_v1.8.1\x00\x01\x1b\x05\x01\x05\x0b\x01\x03\x0b\x03\t\x0f\x13\x17\x1b\x03\x85[\x17\x01%\x07\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0b\x17\x0b\x13\x0b\x037O\x0b\x0b\x0f\x0f\x13\x0b\x0f\x13\x0b\x0b\x0b\x1f+\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x17\x0f\x0f\x13\x0f\x01\x05\x0b\x0f\x03\x13\x17\x07\x17\x0f\x07\x07\x1b\x13\x13\x02\xe6\x02\x1f\x11\x03\x05\x1d\x1b\x1d\x03\x07\t\x0b\r\x03\x0f\x03\x05\x0f\x11\x01\x00\x05\x11\x05\x13\x05\x15\x1d\x15\x01\x05\x17\x1d\x19\x01\x05\x19\x05\x1b\x17\x1f.\x0b\x1b\x05\x1d\x03\x03#?\x05\x1f\x1f\x13!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\r\x01\x1d!\x13\rN\x13\rL\x03\x05''#\x11\x03\x035\r\x037)\x1d#\x1d%\x1d'\x1f\x0b\t\x00\x00\x80?\r\tA+C-E+G-\x1d)\x1d+\x1d-\x1d/\x0b\x03\x1d1\x03\x01\x05\x01\x03\x07%%S\x1f\x15\x01\x03\x03W\x15\x01\x05\x01\x03\x03%\x01\t\x01\x02\x02)\x05\x11\x15\x07\t)\x05\x11\x11\x07)\x01\x07!\x13\x11\x05\t\x05\x03\x05)\x03\t\x0f)\x03\x01\x0f\x04e\x05\x01Q\x01\x07\x01\x07\x04S\x03\x01\x05\x03P\x01\x03\x07\x04?\x03\t\x0f\x05\x13\x13\x0b\x17\x00\x05B\x05\x05\x03\x0b\x07G\x05!\x07\x03\x05\x07\x01\x03\x05\t\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00r\x053#\x0b\x11\x0b\x0b\x0f\x0b!\x03)iK\x05\x05\x13%)9\x15\x1f\x19\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00constant_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00a\x00b\x00jit(func)/jit(main)/triangular_solve\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.backend_config\x00\x00jax.result_info\x00main\x00public\x00diag\x00side\x00trans_x\x00uplo\x00lapack_strsm_ffi\x00\x08+\t\x05#\x01\x0b/139;\x03=\x11I)KMOQUY", - xla_call_module_version=9, - nr_devices=1, -) # End paste - - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_12_02['f64'] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_dtrsm_ffi'], - serialized_date=datetime.date(2024, 12, 2), - inputs=( - array([ - [5.0, 0.0, 0.0, 0.0], - [4.0, 10.0, 0.0, 0.0], - [8.0, 9.0, 15.0, 0.0], - [12.0, 13.0, 14.0, 20.0], - ]), - array([ - [0.0, 1.0, 2.0, 3.0, 4.0], - [5.0, 6.0, 7.0, 8.0, 9.0], - [10.0, 11.0, 12.0, 13.0, 14.0], - [15.0, 16.0, 17.0, 18.0, 19.0], - ]), - ), - expected_outputs=( - array([ - [0.0, 0.2, 0.4, 0.6000000000000001, 0.8], - [0.5, 0.52, 0.54, 0.5599999999999999, 0.58], - [ - 0.36666666666666664, - 0.3146666666666667, - 0.2626666666666667, - 0.21066666666666667, - 0.15866666666666665, - ], - [ - 0.16833333333333336, - 0.1217333333333333, - 0.07513333333333323, - 0.0285333333333333, - -0.018066666666666675, - ], - ]), - ), - mlir_module_text=r""" -#loc1 = loc("a") -#loc2 = loc("b") -module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<4x4xf64> loc("a"), %arg1: tensor<4x5xf64> loc("b")) -> (tensor<4x5xf64> {jax.result_info = ""}) { - %cst = stablehlo.constant dense<1.000000e+00> : tensor loc(#loc4) - %0 = stablehlo.custom_call @lapack_dtrsm_ffi(%arg0, %arg1, %cst) {mhlo.backend_config = {diag = 78 : ui8, side = 76 : ui8, trans_x = 78 : ui8, uplo = 76 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<4x4xf64>, tensor<4x5xf64>, tensor) -> tensor<4x5xf64> loc(#loc4) - return %0 : tensor<4x5xf64> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc = loc(unknown) -#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":715:13) -#loc4 = loc("jit(func)/jit(main)/triangular_solve"(#loc3)) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":654:13) +#loc4 = loc("jit(func)"(#loc3)) +#loc5 = loc("triangular_solve"(#loc4)) """, - mlir_module_serialized=b"ML\xefR\rStableHLO_v1.8.1\x00\x01\x1b\x05\x01\x05\x0b\x01\x03\x0b\x03\t\x0f\x13\x17\x1b\x03\x85[\x17\x01%\x07\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0b\x17\x0b\x13\x0b\x037O\x0b\x0b\x0f\x0f\x13\x0b\x0f\x13\x0b\x0b\x0b/+\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x17\x0f\x0f\x13\x0f\x01\x05\x0b\x0f\x03\x13\x17\x07\x17\x0f\x07\x07\x1b\x13\x13\x02\xf6\x02\x1f\x11\x03\x05\x1d\x1b\x1d\x03\x07\t\x0b\r\x03\x0f\x03\x05\x0f\x11\x01\x00\x05\x11\x05\x13\x05\x15\x1d\x15\x01\x05\x17\x1d\x19\x01\x05\x19\x05\x1b\x17\x1f.\x0b\x1b\x05\x1d\x03\x03#?\x05\x1f\x1f\x13!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\r\x01\x1d!\x13\rN\x13\rL\x03\x05''#\x11\x03\x035\r\x037)\x1d#\x1d%\x1d'\x1f\x0b\x11\x00\x00\x00\x00\x00\x00\xf0?\r\tA+C-E+G-\x1d)\x1d+\x1d-\x1d/\x0b\x03\x1d1\x03\x01\x05\x01\x03\x07%%S\x1f\x15\x01\x03\x03W\x15\x01\x05\x01\x03\x03%\x01\t\x01\x02\x02)\x05\x11\x15\x07\x0b)\x05\x11\x11\x07)\x01\x07!\x13\x11\x05\t\x05\x03\x05)\x03\t\x0f)\x03\x01\x0f\x04e\x05\x01Q\x01\x07\x01\x07\x04S\x03\x01\x05\x03P\x01\x03\x07\x04?\x03\t\x0f\x05\x13\x13\x0b\x17\x00\x05B\x05\x05\x03\x0b\x07G\x05!\x07\x03\x05\x07\x01\x03\x05\t\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00r\x053#\x0b\x11\x0b\x0b\x0f\x0b!\x03)iK\x05\x05\x13%)9\x15\x1f\x19\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00constant_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00a\x00b\x00jit(func)/jit(main)/triangular_solve\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.backend_config\x00\x00jax.result_info\x00main\x00public\x00diag\x00side\x00trans_x\x00uplo\x00lapack_dtrsm_ffi\x00\x08+\t\x05#\x01\x0b/139;\x03=\x11I)KMOQUY", - xla_call_module_version=9, + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.12.1\x00\x01\x1b\x07\x01\x05\t\t\x01\x03\x0f\x03\x07\x13\x17\x1b\x03\xa7{\x15\x01-\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x17\x0b\x03;O\x0b\x0f\x0f\x13\x0b\x0f\x13\x0b\x0b\x0b\x0b+\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0f\x13\x0f\x05\x15K\x13\x0f\x0f\x13\x0f\x0f\x13\x0f\x0f\x01\x05\x0b\x0f\x03\x11\x17\x17\x0b\x07\x1b\x07\x13\x07\x02\xc2\x03\x1f\x11\x03\x05\x03\x07\x07\t\x0b\x03\r\x03\x05\x0f\x11\x01\x00\x05\x11\x05\x13\x05\x15\x1d\x13\x01\x05\x17\x1d\x17\x01\x05\x19\x03\x07\x1bE\x1dO\x1fg\x05\x1b\x05\x1d\x05\x1f\x1d#%\x05!\x1d')\x05#\x17+:\n\x1b\x05%\x1f\x11!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\r\x01\x13\x0bN\x13\x0bL\x03\x05//#\r\x03\x03;\r\x03=?\x1d'\x1d)\x1d+\x1d-\r\tG1I3K1M3\x1d/\x1d1\x1d3\x1d5\r\x03QS\x1d7\x1d9\x0b\x03\x1d;\x1d=\x03\x01\x05\x01\x03\x05--\x03\x03c\x15\x01\x05\x01\x03\x03-\x15\r\x11\x11\x11\x15\x11\x15\x05io\x03u\x01\x01\x01\x01\x01\x13\x05km\x11\x03\x01\x11\x03\x05\x13\x05qs\x11\x03\t\x11\x03\r\x13\x05wy\x11\x03\x11\x11\x03\x15\x01\t\x01\x02\x02)\x05\x11\x15\t)\x05\x11\x11\t\x03\x0f!\x11\x05\x07\x05\x03\x05\x0b)\x03\t\x13\x13\x04W\x05\x01Q\x01\x05\x01\x07\x04E\x03\x01\x05\x03P\x01\x03\x07\x041\x03\x07\x0b\x05\x0f\x11\x0b\x15\x00\x05G!\x19\x05\x03\x05\x05\x01\x03\x07\x04\x01\x03\x05\x06\x03\x01\x05\x01\x00N\x06?#\x03\x05\x1f\x0b\x11\x0b\x0b\x0f\x0b\x0f!i\x15#%3)\x05\x05\x13%)9\x15\x1f\x11\x0f\t\x0b\x11builtin\x00vhlo\x00sdy\x00module\x00func_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00a\x00b\x00mhlo.backend_config\x00mhlo.frontend_attributes\x00sdy.sharding_rule\x00triangular_solve\x00jit(func)\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.result_info\x00result\x00main\x00public\x00diag\x00side\x00trans_x\x00uplo\x00num_batch_dims\x000\x00\x00lapack_ztrsm_ffi\x00\x08'\x07\x05\x1f\x01\x0b579AC\x11UWY[]_ae", + xla_call_module_version=10, nr_devices=1, ) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_tridiagonal_lapack_sytrd_hetrd.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_tridiagonal_lapack_sytrd_hetrd.py index 9e245052e03a..4bd672a59aa1 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_tridiagonal_lapack_sytrd_hetrd.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_tridiagonal_lapack_sytrd_hetrd.py @@ -15,434 +15,14 @@ # ruff: noqa import datetime -from numpy import array, float32, complex64 +import numpy as np -data_2024_09_03 = {} - - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_09_03["c128"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_zhetrd'], - serialized_date=datetime.date(2024, 9, 3), - inputs=(), - expected_outputs=(array([[[-1.6782909868280393 +0.j , - -0.44670237330570184+4.847000766107959j , - 2.05945450900321 -2.2848432268240106j , - -1.852046418980849 +1.672382006137275j ], - [ 8.516713699516982 +0.j , - -2.7881860505313174 +0.j , - 0.9238284715039695 -2.3790501284019947j , - 0.5005102262291599 -1.30066052934836j ], - [-0.12132810525381293-0.2963030371159077j , - -3.6374350042782893 +0.j , - 0.5605752523031344 +0.j , - -2.9865099107523174 +0.5492956557924651j ], - [-0.40379248092949666-0.7813328344426929j , - -0.07101654492399719-0.27208840961051617j, - -7.4654253782049285 +0.j , - -8.172380353916964 +0.j ]], - - [[-3.996403598623405 +0.j , - 0.59408630943699 +2.531609474375295j , - -1.789098034543644 -2.538389274566601j , - -1.291106590337488 +3.1576544511573843j ], - [10.8950662522622 +0.j , - -2.8151642043836693 +0.j , - 6.18998567202382 +1.1866537964613415j , - 3.1900218245393352 +2.7291222716752372j ], - [-0.3142889671188478 -0.37781876498252764j, - 3.049208563595754 +0.j , - -2.4383044880335487 +0.j , - 4.075435464493341 -0.6653616942280807j ], - [ 0.32757687545025194+0.565870910342534j , - 0.8177026465997795 -0.15906305615104555j, - 3.3415143060767125 +0.j , - 4.094619408678314 +0.j ]]]), array([[-1.6782909868280393, -2.7881860505313174, 0.5605752523031344, - -8.172380353916964 ], - [-3.996403598623405 , -2.8151642043836693, -2.4383044880335487, - 4.094619408678314 ]]), array([[ 8.516713699516982 , -3.6374350042782893, -7.4654253782049285], - [10.8950662522622 , 3.049208563595754 , 3.3415143060767125]]), array([[1.0626274644222748+0.06050271598884928j, - 1.834630852474663 +0.18575551495730305j, - 1.981584368497257 +0.19102912741736966j], - [1.0365789616521406-0.40942548304121656j, - 1.0872592163018966-0.3187050677167622j , - 1.0458498304770472-0.9989483435319496j ]])), - mlir_module_text=r""" -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":695:13) -#loc5 = loc("jit(func)/jit(main)/pjit"(#loc1)) -module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<2x4x4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x3xf64> {jax.result_info = "[2]", mhlo.layout_mode = "default"}, tensor<2x3xcomplex> {jax.result_info = "[3]", mhlo.layout_mode = "default"}) { - %cst = stablehlo.constant dense<[[[(-1.6782909868280393,-0.44303325034407437), (-0.44670237330570184,4.8470007661079588), (2.0594545090032099,-2.2848432268240106), (-1.852046418980849,1.6723820061372749)], [(-0.53338018421119981,-0.5152843101202178), (-8.6208093221459947,-1.4723511111926109), (0.92382847150396952,-2.3790501284019947), (0.50051022622915986,-1.30066052934836)], [(0.94535043721506584,2.744088772946665), (-5.9178492824175759,-4.3744650461123786), (1.8341291553102983,-4.8378584827626838), (-2.9865099107523174,0.54929565579246509)], [(3.2517513113853891,7.2792034361133062), (-0.09841002311276037,0.88008791818205689), (-0.035759860211603468,2.4677764344580244), (-3.6133109853094476,-2.2833696560058976)]], [[(-3.996403598623405,2.42308766118121), (0.59408630943699003,2.531609474375295), (-1.789098034543644,-2.538389274566601), (-1.2911065903374881,3.1576544511573843)], [(-0.39853021063902833,4.4607177630985086), (1.0742061295773189,-2.6002112528615386), (6.1899856720238198,1.1866537964613415), (3.1900218245393352,2.7291222716752372)], [(5.2347956435718022,2.8649782894514577), (2.3527586611916762,2.4688953673448575), (-2.317572140163894,4.3609023810820053), (4.0754354644933413,-0.66536169422808067)], [(-6.2237114632988675,-4.9294897244018943), (4.2994486027667103,-1.3300494261380422), (-0.51942958410141249,0.60038999428238982), (0.084516726847668963,-7.2944134049318752)]]]> : tensor<2x4x4xcomplex> loc(#loc) - %c = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %c_1 = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_2 = stablehlo.constant dense<2> : tensor loc(#loc2) - %c_3 = stablehlo.constant dense<128> : tensor loc(#loc2) - %0:6 = stablehlo.custom_call @lapack_zhetrd(%c, %c_0, %c_1, %c_2, %c_3, %cst) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xcomplex>) -> (tensor<2x4x4xcomplex>, tensor<2x4xf64>, tensor<2x3xf64>, tensor<2x3xcomplex>, tensor<2xi32>, tensor<128xcomplex>) loc(#loc2) - %c_4 = stablehlo.constant dense<0> : tensor loc(#loc) - %1 = stablehlo.broadcast_in_dim %c_4, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc4) - %cst_5 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc) - %4 = call @_where(%3, %0#0, %cst_5) : (tensor<2x1x1xi1>, tensor<2x4x4xcomplex>, tensor>) -> tensor<2x4x4xcomplex> loc(#loc5) - %c_6 = stablehlo.constant dense<0> : tensor loc(#loc) - %5 = stablehlo.broadcast_in_dim %c_6, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %6 = stablehlo.compare EQ, %0#4, %5, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %7 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) - %cst_7 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) - %8 = call @_where_0(%7, %0#1, %cst_7) : (tensor<2x1xi1>, tensor<2x4xf64>, tensor) -> tensor<2x4xf64> loc(#loc5) - %c_8 = stablehlo.constant dense<0> : tensor loc(#loc) - %9 = stablehlo.broadcast_in_dim %c_8, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %10 = stablehlo.compare EQ, %0#4, %9, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %11 = stablehlo.broadcast_in_dim %10, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) - %cst_9 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) - %12 = call @_where_1(%11, %0#2, %cst_9) : (tensor<2x1xi1>, tensor<2x3xf64>, tensor) -> tensor<2x3xf64> loc(#loc5) - %c_10 = stablehlo.constant dense<0> : tensor loc(#loc) - %13 = stablehlo.broadcast_in_dim %c_10, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %14 = stablehlo.compare EQ, %0#4, %13, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %15 = stablehlo.broadcast_in_dim %14, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) - %cst_11 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc) - %16 = call @_where_2(%15, %0#3, %cst_11) : (tensor<2x1xi1>, tensor<2x3xcomplex>, tensor>) -> tensor<2x3xcomplex> loc(#loc5) - return %4, %8, %12, %16 : tensor<2x4x4xcomplex>, tensor<2x4xf64>, tensor<2x3xf64>, tensor<2x3xcomplex> loc(#loc) - } loc(#loc) - func.func private @_where(%arg0: tensor<2x1x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4x4xcomplex> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4x4xcomplex> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc6) - %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc6) - %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc7) - return %2 : tensor<2x4x4xcomplex> loc(#loc5) - } loc(#loc5) - func.func private @_where_0(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4xf64> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4xf64> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc6) - %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x4xf64> loc(#loc6) - %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4xi1>, tensor<2x4xf64> loc(#loc7) - return %2 : tensor<2x4xf64> loc(#loc5) - } loc(#loc5) - func.func private @_where_1(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x3xf64> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x3xf64> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc6) - %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x3xf64> loc(#loc6) - %2 = stablehlo.select %0, %arg1, %1 : tensor<2x3xi1>, tensor<2x3xf64> loc(#loc7) - return %2 : tensor<2x3xf64> loc(#loc5) - } loc(#loc5) - func.func private @_where_2(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x3xcomplex> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x3xcomplex> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc6) - %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor>) -> tensor<2x3xcomplex> loc(#loc6) - %2 = stablehlo.select %0, %arg1, %1 : tensor<2x3xi1>, tensor<2x3xcomplex> loc(#loc7) - return %2 : tensor<2x3xcomplex> loc(#loc5) - } loc(#loc5) -} loc(#loc) -#loc = loc(unknown) -#loc2 = loc("jit(func)/jit(main)/tridiagonal"(#loc1)) -#loc3 = loc("jit(func)/jit(main)/eq"(#loc1)) -#loc4 = loc("jit(func)/jit(main)/broadcast_in_dim"(#loc1)) -#loc6 = loc("jit(func)/jit(main)/jit(_where)/broadcast_in_dim"(#loc1)) -#loc7 = loc("jit(func)/jit(main)/jit(_where)/select_n"(#loc1)) -""", - mlir_module_serialized=b'ML\xefR\rStableHLO_v1.3.0\x00\x01#\x05\x01\x05\x13\x01\x03\x0b\x03\x11\x0f\x13\x17\x1b\x1f#\'+\x03\xf7\x99I\x01-\x0f\x07\x0f\x0f\x17\x0f\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03m\x0f\x0b\x0b\x0f\x0b\x17\x13\x0f\x0b\x1f\x0b\x0b/OO\x0b\x0b\x0b\x0b\x0b\x1fo/O/\x0b\x1b\x1b\x0b\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0b\x0b\x0b\x0bo&\x10\x1f\x1f\x1f\x0b\x0b\x0b\x0b#\x0f\x17#\x01\x05\x0b\x0f\x03E\x0f\x1b\x17\x17\x17\x17\x0f\x0f\x07\x13\x0b\x07\x07\x07\x13\x1b\x17\x07\x1f\x1f\x1f\x1f\x1f\x13\x13\x17\x1b\x13\x17\x13\x13\x13\x13\x13\x02\x12\x10\x1d\x1f\t\x1f\x1d#\t\x1d)\t\x17!\xde\n\x1b\x1d\'\t\x1d%\t\x1d+\t\x11\x03\x05\x03\x07\x15\x17\x19\x11\x1b\x11\x05\x17\x11\x01\x00\x05\x19\x05\x1b\x05\x1d\x05\x1f\x05!\x05#\x05%\x05\'\x05)\x05+\x1f5\x01\x1d-\x1d/\x1f?\x01\x1d1\x03\x07999\r\x03/1\x03\x039\x1d3\x1f\x05\t\x00\x00\x00\x00\t\x07\x07\x01\x1fG\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f3!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1fC!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1d5\x1d7\x1d9\x1d;\x1f\x05\t\x04\x00\x00\x00\x1fA1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1fE\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x11!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x13\x11\x00\x00\x00\x00\x00\x00\xf8\x7f#)\x03\tcgko\r\x055e/1\x1d=\r\x055i/1\x1d?\r\x055m/1\x1dA\r\x055q/1\x1dC\x1dE\x1dG#+#-#/#1\x1f;1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x07\x02\x08d\x91Y\xa6G\xda\xfa\xbf$-Q"\xa8Z\xdc\xbfL0\x19\x8d\xc5\x96\xdc\xbf\x86{8+Tc\x13@\xf0%\x1eI\xc3y\x00@\xe4\x91\xbd\xe2[G\x02\xc0\x85%\x03m\xfb\xa1\xfd\xbf\x9atl\xa2\x13\xc2\xfa?\x9c\xb0\xf0Qs\x11\xe1\xbf\xd8v\x83\x855}\xe0\xbf\x84V/\xb8\xda=!\xc0\n\xd3\xec\t\xc0\x8e\xf7\xbf\x98$\x07\xba\x00\x90\xed?\xd5?\x08oK\x08\x03\xc0>\xf8\x9e\x05.\x04\xe0?\xf2\xfcKj\x81\xcf\xf4\xbf\xe4"c\x8fO@\xee?y\x03\x89\xd0\xe4\xf3\x05@\xee\x8f\xaa\xae\xe0\xab\x17\xc0\xf20\xda\xc3s\x7f\x11\xc0V*+\xd0\x97X\xfd?P\x91\xf8\x92\xf7Y\x13\xc0\x7f\xe3\xdeN_\xe4\x07\xc0\x14\xd5\xae{\xd4\x93\xe1?\xbc\x00\t1\x96\x03\n@`&l\x81\xe7\x1d\x1d@X/\xde6f1\xb9\xbf\x06KF#\xae)\xec?\xcd\x9a<\xcc\x1dO\xa2\xbf\x91\xb1>\x92\x01\xbe\x03@\xf2s\x01\x97\x0f\xe8\x0c\xc0\xf5\xcaiOWD\x02\xc0F\xa2-s\xa2\xf8\x0f\xc0X\xea\xa0\xc8{b\x03@\x0b\x10\xc1J\xc1\x02\xe3?2|\xd5w\xbc@\x04@\xca>\xbbB%\xa0\xfc\xbf\xe8>6\t\x9fN\x04\xc0\xafdRb_\xa8\xf4\xbf\x80Q>V\xe0B\t@UhJ\xdb\x84\x81\xd9\xbf\t\xc7\xb4e\xc6\xd7\x11@<(;\xc4\xf2/\xf1?\x1a\xda\xad\x8e;\xcd\x04\xc0\x1c4\xa0\x9a\x8b\xc2\x18@z\x9c\xf7\xb0\x88\xfc\xf2?\xaea\x8f)*\x85\t@\x00\x0b\xbd\x0e>\xd5\x05@b\x89\xe9Dn\xf0\x14@a\x8d\xc7\xbcy\xeb\x06@\x8a\x97\t"s\xd2\x02@\xc2\xef\xdf6L\xc0\x03@J\xff Cc\x8a\x02\xc0\xd7.\xcfd\x90q\x11@s\xd4S\xf4>M\x10@t\x10\x97\x9b\xa4J\xe5\xbf\x8eo*\x9e\x14\xe5\x18\xc0\xc5\x18\x81\'\xcc\xb7\x13\xc0\x19\xdd\x8e\xa7\xa22\x11@-95\xe8\xe1G\xf5\xbfZK\x89\xca*\x9f\xe0\xbfR;\xc9\x13e6\xe3?\x7f\x94\xc6a\xe3\xa2\xb5?\xe2\xbe&\xb5z-\x1d\xc0\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x02\x00\x00\x00\x1f\x05\t\x80\x00\x00\x00\x0b\x05\x1dI\x1dK\x05\x01\x03\r33333W\x03\x03\x95\x15\x03\x01\x15\x01\x03\rWIIIYY\x01\t\x01\x02\x02)\x01\')\x07\t\x11\x11\x19)\x05\t\x05\x15)\x05\t\x11\x1b)\x05\t\r\x1b)\x05\t\r\x19)\x01\x19)\x01\x1b\x01)\x03\t\'\x03\x1b\x0b\x1d\x13)\x03\t\x15)\x07\t\x05\x05\x15)\x05\t\r\x15\x1b\x11\x01\t\x07\x0b\r\x0f\x11\x07#\x07\x11\x03\x07\x11\x07\t\x0b\x13\x03\x0b\x11\x07\t\r\x13\x03\r\x11\x07\t\x0f\x11\x03\x0f)\x03\t\x1d)\x03\x01\x1d)\x05\t\x11\x15)\x07\t\x11\x11\x15)\x03\r\x1d)\x03\x02\x04\x19)\x03\x01\x1f)\x03\r\x1f)\x03\t\x1f)\x03\x05\x1f)\x03\x05\x1d\x04J\x07\x05\x01Q\x03\x13\x01\x07\x04"\x07\x03\x01\x15\x07P\x03\x03\x07\x04\xf6\x03\x03I\x81\x05B\x03\x05\x03\x07\x05B\x0b\x07\x03\x05\x05B\x0b\t\x03\x05\x05B\x0b\x07\x03\x05\x05B\x0b\x0b\x03\x05\x05B\x0b\r\x03\x05\x11F\x0b\x0f\r\x07\x0b\r\x0f\x17=\r\x03\x05\x07\t\x0b\x01\x05B\x03\x11\x03\x05\x03F\x07\x13\x03\x17\x03\x19\rF\x07\x15\x03!\x05\x15\x1b\x03F\x0f\x17\x03#\x03\x1d\x05B\x03\x19\x03\x11\x0fF\x01\x1b\x03\x07\x07\x1f\r!\x05B\x03\x11\x03\x05\x03F\x07\x13\x03\x17\x03%\rF\x07\x15\x03!\x05\x15\'\x03F\x0f\x17\x03\t\x03)\x05B\x03\x1d\x03\x13\x0fF\x01\x1f\x03\x0b\x07+\x0f-\x05B\x03\x11\x03\x05\x03F\x07\x13\x03\x17\x031\rF\x07\x15\x03!\x05\x153\x03F\x0f\x17\x03\t\x035\x05B\x03\x1d\x03\x13\x0fF\x01!\x03\r\x077\x119\x05B\x03\x11\x03\x05\x03F\x07\x13\x03\x17\x03=\rF\x07\x15\x03!\x05\x15?\x03F\x0f\x17\x03\t\x03A\x05B\x03\x19\x03\x11\x0fF\x01#\x03\x0f\x07C\x13E\t\x04\x03\t#/;G\x07P\x01%\x07\x04S\x03\r\x13\x07G\x01\x0f\x01#\x01\x00\x03F\x05\'\x039\x03\x01\x03F\x05\x13\x03\x07\x03\x05\x0b\x06\r\x03\x07\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01)\x07\x04S\x03\r\x13\x07\x13\x01\x17\x01\'\x01\x00\x03F\x05+\x037\x03\x01\x03F\x05\x13\x03\x0b\x03\x05\x0b\x06\r\x03\x0b\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01-\x07\x04S\x03\r\x13\x07\x13\x01\x1b\x01\'\x01\x00\x03F\x05+\x03%\x03\x01\x03F\x05\x13\x03\r\x03\x05\x0b\x06\r\x03\r\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01/\x07\x04S\x03\r\x13\x07\x13\x01\x1f\x01#\x01\x00\x03F\x05+\x03%\x03\x01\x03F\x05\x13\x03\x0f\x03\x05\x0b\x06\r\x03\x0f\x07\x07\x03\t\t\x04\x01\x03\x0b\x06\x03\x01\x05\x01\x00\x96\tM\x1d\x03\x0f\x0b\t\t\t\t\x13\x13\x13\x0f\x11!\x11#K/ASci3\x13%)9\x1f\x11\x17\x15\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00func_v1\x00return_v1\x00select_v1\x00compare_v1\x00call_v1\x00custom_call_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/pjit\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit(func)/jit(main)/jit(_where)/broadcast_in_dim\x00jit(func)/jit(main)/jit(_where)/select_n\x00jit(func)/jit(main)/tridiagonal\x00jit(func)/jit(main)/eq\x00jit(func)/jit(main)/broadcast_in_dim\x00mhlo.layout_mode\x00default\x00jax.result_info\x00private\x00_where\x00_where_0\x00_where_1\x00_where_2\x00[0]\x00[1]\x00[2]\x00[3]\x00main\x00public\x00\x00lapack_zhetrd\x00\x08\x9d1\x05;\x01\x0bK_asu\x03\x81\x03U\x03\x83\x03\x85\x03\x87\x11\x89\x8b\x8dK\x8f\x91\x93\x97\x03?\x03-\x05AC\x03E\x03[\x03M\x03]\x03O\x03Q\x03S\x0b7w;M=\x03\x7f\x0b7y;O=\x03G\x0b7{;Q=\x0b7};S=', - xla_call_module_version=9, - nr_devices=1, -) # End paste - - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_09_03["c64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_chetrd'], - serialized_date=datetime.date(2024, 9, 3), - inputs=(), - expected_outputs=(array([[[ 3.3228416 +0.j , -1.9756439 +4.593356j , - 7.367708 +0.88518727j , -8.659938 +1.6132793j ], - [-6.9206004 +0.j , -3.6362798 +0.j , - 3.3011198 -4.644362j , -4.8589935 -0.61439794j ], - [ 0.64957 +0.060723424j, 6.620491 +0.j , - 0.2882607 +0.j , -1.0288142 +1.8544064j ], - [-0.05458622 +0.10473086j , -0.15611424 +0.06925995j , - -4.431866 +0.j , 2.364208 +0.j ]], - - [[-4.1803885 +0.j , 0.5670845 +0.6913016j , - 2.675204 -0.23881845j , -0.41825035 -1.4060576j ], - [ 8.33625 +0.j , 2.6144838 +0.j , - -2.4941807 -1.9316154j , 0.6687787 -2.209776j ], - [ 0.019031923+0.17462212j , 2.7034955 +0.j , - -0.70924187 +0.j , 2.7962255 +1.5316825j ], - [-0.057821754+0.023692288j, -0.62805307 -0.0882424j , - 6.6364865 +0.j , -1.698973 +0.j ]]], - dtype=complex64), array([[ 3.3228416 , -3.6362798 , 0.2882607 , 2.364208 ], - [-4.1803885 , 2.6144838 , -0.70924187, -1.698973 ]], - dtype=float32), array([[-6.9206004, 6.620491 , -4.431866 ], - [ 8.33625 , 2.7034955, 6.6364865]], dtype=float32), array([[1.360567 +0.1977107j , 1.7586378-0.56989706j, - 1.5772758-0.8165493j ], - [1.9152443-0.1834492j , 1.1593437+0.55631363j, - 1.6889225-0.724835j ]], dtype=complex64)), - mlir_module_text=r""" -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":695:13) -#loc5 = loc("jit(func)/jit(main)/pjit"(#loc1)) -module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<2x4x4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x3xf32> {jax.result_info = "[2]", mhlo.layout_mode = "default"}, tensor<2x3xcomplex> {jax.result_info = "[3]", mhlo.layout_mode = "default"}) { - %cst = stablehlo.constant dense<[[[(3.32284164,1.14621949), (-1.97564387,4.59335613), (7.36770821,0.885187268), (-8.65993785,1.61327934)], [(2.495340e+00,1.36827672), (-3.96969199,-0.636681795), (3.3011198,-4.64436197), (-4.85899353,-0.614397943)], [(6.03322554,1.46055949), (-3.89591122,-4.1833396), (-1.46423841,-0.106284566), (-1.0288142,1.85440636)], [(-0.657281339,0.911450386), (3.18693113,-2.02812219), (-2.64483237,0.351429433), (4.45011663,-1.79112875)]], [[(-4.18038845,-3.65238023), (0.567084491,0.691301584), (2.67520404,-0.238818452), (-0.418250352,-1.4060576)], [(-7.62970591,1.5292784), (0.269325763,2.48722434), (-2.49418068,-1.93161535), (0.668778717,-2.20977592)], [(-0.570908666,-2.75890398), (-0.235837936,3.45861554), (-0.946199476,0.23120968), (2.79622555,1.53168249)], [(0.886947453,-0.466695577), (-3.194850e+00,-0.0176551137), (-4.37602425,-3.7703948), (0.883143305,-4.70016575)]]]> : tensor<2x4x4xcomplex> loc(#loc) - %c = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %c_1 = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_2 = stablehlo.constant dense<2> : tensor loc(#loc2) - %c_3 = stablehlo.constant dense<128> : tensor loc(#loc2) - %0:6 = stablehlo.custom_call @lapack_chetrd(%c, %c_0, %c_1, %c_2, %c_3, %cst) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xcomplex>) -> (tensor<2x4x4xcomplex>, tensor<2x4xf32>, tensor<2x3xf32>, tensor<2x3xcomplex>, tensor<2xi32>, tensor<128xcomplex>) loc(#loc2) - %c_4 = stablehlo.constant dense<0> : tensor loc(#loc) - %1 = stablehlo.broadcast_in_dim %c_4, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc4) - %cst_5 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc) - %4 = call @_where(%3, %0#0, %cst_5) : (tensor<2x1x1xi1>, tensor<2x4x4xcomplex>, tensor>) -> tensor<2x4x4xcomplex> loc(#loc5) - %c_6 = stablehlo.constant dense<0> : tensor loc(#loc) - %5 = stablehlo.broadcast_in_dim %c_6, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %6 = stablehlo.compare EQ, %0#4, %5, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %7 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) - %cst_7 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) - %8 = call @_where_0(%7, %0#1, %cst_7) : (tensor<2x1xi1>, tensor<2x4xf32>, tensor) -> tensor<2x4xf32> loc(#loc5) - %c_8 = stablehlo.constant dense<0> : tensor loc(#loc) - %9 = stablehlo.broadcast_in_dim %c_8, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %10 = stablehlo.compare EQ, %0#4, %9, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %11 = stablehlo.broadcast_in_dim %10, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) - %cst_9 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) - %12 = call @_where_1(%11, %0#2, %cst_9) : (tensor<2x1xi1>, tensor<2x3xf32>, tensor) -> tensor<2x3xf32> loc(#loc5) - %c_10 = stablehlo.constant dense<0> : tensor loc(#loc) - %13 = stablehlo.broadcast_in_dim %c_10, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %14 = stablehlo.compare EQ, %0#4, %13, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %15 = stablehlo.broadcast_in_dim %14, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) - %cst_11 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc) - %16 = call @_where_2(%15, %0#3, %cst_11) : (tensor<2x1xi1>, tensor<2x3xcomplex>, tensor>) -> tensor<2x3xcomplex> loc(#loc5) - return %4, %8, %12, %16 : tensor<2x4x4xcomplex>, tensor<2x4xf32>, tensor<2x3xf32>, tensor<2x3xcomplex> loc(#loc) - } loc(#loc) - func.func private @_where(%arg0: tensor<2x1x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4x4xcomplex> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4x4xcomplex> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc6) - %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc6) - %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc7) - return %2 : tensor<2x4x4xcomplex> loc(#loc5) - } loc(#loc5) - func.func private @_where_0(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4xf32> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc6) - %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x4xf32> loc(#loc6) - %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4xi1>, tensor<2x4xf32> loc(#loc7) - return %2 : tensor<2x4xf32> loc(#loc5) - } loc(#loc5) - func.func private @_where_1(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x3xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x3xf32> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc6) - %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x3xf32> loc(#loc6) - %2 = stablehlo.select %0, %arg1, %1 : tensor<2x3xi1>, tensor<2x3xf32> loc(#loc7) - return %2 : tensor<2x3xf32> loc(#loc5) - } loc(#loc5) - func.func private @_where_2(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x3xcomplex> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x3xcomplex> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc6) - %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor>) -> tensor<2x3xcomplex> loc(#loc6) - %2 = stablehlo.select %0, %arg1, %1 : tensor<2x3xi1>, tensor<2x3xcomplex> loc(#loc7) - return %2 : tensor<2x3xcomplex> loc(#loc5) - } loc(#loc5) -} loc(#loc) -#loc = loc(unknown) -#loc2 = loc("jit(func)/jit(main)/tridiagonal"(#loc1)) -#loc3 = loc("jit(func)/jit(main)/eq"(#loc1)) -#loc4 = loc("jit(func)/jit(main)/broadcast_in_dim"(#loc1)) -#loc6 = loc("jit(func)/jit(main)/jit(_where)/broadcast_in_dim"(#loc1)) -#loc7 = loc("jit(func)/jit(main)/jit(_where)/select_n"(#loc1)) -""", - mlir_module_serialized=b'ML\xefR\rStableHLO_v1.3.0\x00\x01#\x05\x01\x05\x13\x01\x03\x0b\x03\x11\x0f\x13\x17\x1b\x1f#\'+\x03\xf7\x99I\x01-\x0f\x07\x0f\x0f\x17\x0f\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03m\x0f\x0b\x0b\x0f\x0b\x17\x13\x0f\x0b\x1f\x0b\x0b/OO\x0b\x0b\x0b\x0b\x0b\x1fo//\x1f\x0b\x1b\x1b\x0b\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0b\x0b\x0b\x0bo&\x08\x1f\x1f\x1f\x0b\x0b\x0b\x0b#\x0f\x17#\x01\x05\x0b\x0f\x03E\x0f\x1b\x17\x17\x17\x17\x0f\x0f\x07\x13\x0b\x07\x07\x07\x13\x1b\x17\x07\x1f\x1f\x1f\x1f\x1f\x13\x13\x17\x1b\x13\x17\x13\x13\x13\x13\x13\x02\xe2\x0b\x1d\x1f\t\x1f\x1d#\t\x1d)\t\x17!\xde\n\x1b\x1d\'\t\x1d%\t\x1d+\t\x11\x03\x05\x03\x07\x15\x17\x19\x11\x1b\x11\x05\x17\x11\x01\x00\x05\x19\x05\x1b\x05\x1d\x05\x1f\x05!\x05#\x05%\x05\'\x05)\x05+\x1f5\x01\x1d-\x1d/\x1f?\x01\x1d1\x03\x07999\r\x03/1\x03\x039\x1d3\x1f\x05\t\x00\x00\x00\x00\t\x07\x07\x01\x1fG\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f3!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1fC!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1d5\x1d7\x1d9\x1d;\x1f\x05\t\x04\x00\x00\x00\x1fA1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1fE\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x11\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f\x13\t\x00\x00\xc0\x7f#)\x03\tcgko\r\x055e/1\x1d=\r\x055i/1\x1d?\r\x055m/1\x1dA\r\x055q/1\x1dC\x1dE\x1dG#+#-#/#1\x1f;1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x07\x02\x04p\xa9T@R\xb7\x92?\xe6\xe1\xfc\xbf\xc6\xfc\x92@D\xc4\xeb@\xa2\x9bb?\x1b\x8f\n\xc1\xf0\x7f\xce?\xa7\xb3\x1f@\xb1#\xaf?o\x0f~\xc0\x94\xfd"\xbf\x8cES@\x9d\x9e\x94\xc0\xe0|\x9b\xc0/I\x1d\xbf/\x10\xc1@\x9d\xf3\xba?\x9cVy\xc0\xeb\xdd\x85\xc0*l\xbb\xbf\xb9\xab\xd9\xbd/\xb0\x83\xbf0]\xed?\x97C(\xbf\xd0Ti?\xae\xf6K@\xc1\xcc\x01\xc0\xefD)\xc0\x8f\xee\xb3>[g\x8e@\xb5C\xe5\xbf\xbe\xc5\x85\xc0\x99\xc0i\xc0s,\x11?$\xf90?\x8b6+@\xd3\x8ct\xbe\xe9$\xd6\xbe\xb2\xf9\xb3\xbf\x8d&\xf4\xc0e\xbf\xc3?\x11\xe5\x89>\xaf.\x1f@\xa8\xa0\x1f\xc0,?\xf7\xbf\x155+?\xf8l\r\xc0\x12\'\x12\xbf\xe2\x910\xc0\x80\x7fq\xbe\xf5Y]@!:r\xbf;\xc2l>\\\xf52@,\x0e\xc4?\xfd\x0ec?\xb9\xf2\xee\xbelxL\xc0u\xa1\x90\xbcd\x08\x8c\xc0&Nq\xc0\xae\x15b?\xc2g\x96\xc0\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x02\x00\x00\x00\x1f\x05\t\x80\x00\x00\x00\x0b\x05\x1dI\x1dK\x05\x01\x03\r33333W\x03\x03\x95\x15\x03\x01\x15\x01\x03\rWIIIYY\x01\t\x01\x02\x02)\x01\')\x07\t\x11\x11\x19)\x05\t\x05\x15)\x05\t\x11\x1b)\x05\t\r\x1b)\x05\t\r\x19)\x01\x19)\x01\x1b\x01)\x03\t\'\x03\x1b\t\x1d\x13)\x03\t\x15)\x07\t\x05\x05\x15)\x05\t\r\x15\x1b\x11\x01\t\x07\x0b\r\x0f\x11\x07#\x07\x11\x03\x07\x11\x07\t\x0b\x13\x03\x0b\x11\x07\t\r\x13\x03\r\x11\x07\t\x0f\x11\x03\x0f)\x03\t\x1d)\x03\x01\x1d)\x05\t\x11\x15)\x07\t\x11\x11\x15)\x03\r\x1d)\x03\x02\x04\x19)\x03\x01\x1f)\x03\r\x1f)\x03\t\x1f)\x03\x05\x1f)\x03\x05\x1d\x04J\x07\x05\x01Q\x03\x13\x01\x07\x04"\x07\x03\x01\x15\x07P\x03\x03\x07\x04\xf6\x03\x03I\x81\x05B\x03\x05\x03\x07\x05B\x0b\x07\x03\x05\x05B\x0b\t\x03\x05\x05B\x0b\x07\x03\x05\x05B\x0b\x0b\x03\x05\x05B\x0b\r\x03\x05\x11F\x0b\x0f\r\x07\x0b\r\x0f\x17=\r\x03\x05\x07\t\x0b\x01\x05B\x03\x11\x03\x05\x03F\x07\x13\x03\x17\x03\x19\rF\x07\x15\x03!\x05\x15\x1b\x03F\x0f\x17\x03#\x03\x1d\x05B\x03\x19\x03\x11\x0fF\x01\x1b\x03\x07\x07\x1f\r!\x05B\x03\x11\x03\x05\x03F\x07\x13\x03\x17\x03%\rF\x07\x15\x03!\x05\x15\'\x03F\x0f\x17\x03\t\x03)\x05B\x03\x1d\x03\x13\x0fF\x01\x1f\x03\x0b\x07+\x0f-\x05B\x03\x11\x03\x05\x03F\x07\x13\x03\x17\x031\rF\x07\x15\x03!\x05\x153\x03F\x0f\x17\x03\t\x035\x05B\x03\x1d\x03\x13\x0fF\x01!\x03\r\x077\x119\x05B\x03\x11\x03\x05\x03F\x07\x13\x03\x17\x03=\rF\x07\x15\x03!\x05\x15?\x03F\x0f\x17\x03\t\x03A\x05B\x03\x19\x03\x11\x0fF\x01#\x03\x0f\x07C\x13E\t\x04\x03\t#/;G\x07P\x01%\x07\x04S\x03\r\x13\x07G\x01\x0f\x01#\x01\x00\x03F\x05\'\x039\x03\x01\x03F\x05\x13\x03\x07\x03\x05\x0b\x06\r\x03\x07\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01)\x07\x04S\x03\r\x13\x07\x13\x01\x17\x01\'\x01\x00\x03F\x05+\x037\x03\x01\x03F\x05\x13\x03\x0b\x03\x05\x0b\x06\r\x03\x0b\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01-\x07\x04S\x03\r\x13\x07\x13\x01\x1b\x01\'\x01\x00\x03F\x05+\x03%\x03\x01\x03F\x05\x13\x03\r\x03\x05\x0b\x06\r\x03\r\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01/\x07\x04S\x03\r\x13\x07\x13\x01\x1f\x01#\x01\x00\x03F\x05+\x03%\x03\x01\x03F\x05\x13\x03\x0f\x03\x05\x0b\x06\r\x03\x0f\x07\x07\x03\t\t\x04\x01\x03\x0b\x06\x03\x01\x05\x01\x00\x96\tM\x1d\x03\x0f\x0b\t\t\t\t\x13\x13\x13\x0f\x11!\x11#K/ASci3\x13%)9\x1f\x11\x17\x15\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00func_v1\x00return_v1\x00select_v1\x00compare_v1\x00call_v1\x00custom_call_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/pjit\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit(func)/jit(main)/jit(_where)/broadcast_in_dim\x00jit(func)/jit(main)/jit(_where)/select_n\x00jit(func)/jit(main)/tridiagonal\x00jit(func)/jit(main)/eq\x00jit(func)/jit(main)/broadcast_in_dim\x00mhlo.layout_mode\x00default\x00jax.result_info\x00private\x00_where\x00_where_0\x00_where_1\x00_where_2\x00[0]\x00[1]\x00[2]\x00[3]\x00main\x00public\x00\x00lapack_chetrd\x00\x08\x9d1\x05;\x01\x0bK_asu\x03\x81\x03U\x03\x83\x03\x85\x03\x87\x11\x89\x8b\x8dK\x8f\x91\x93\x97\x03?\x03-\x05AC\x03E\x03[\x03M\x03]\x03O\x03Q\x03S\x0b7w;M=\x03\x7f\x0b7y;O=\x03G\x0b7{;Q=\x0b7};S=', - xla_call_module_version=9, - nr_devices=1, -) # End paste - - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_09_03["f32"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_ssytrd'], - serialized_date=datetime.date(2024, 9, 3), - inputs=(), - expected_outputs=(array([[[-0.8395241 , 0.156272 , -1.6810869 , 0.23832119], - [-2.985257 , -5.571 , -0.22652794, -0.83806676], - [ 0.27237308, -1.6295947 , 2.0042834 , -1.148861 ], - [-0.17183593, 0.57464546, 0.5536146 , -4.206357 ]], - - [[ 1.7666914 , 2.569005 , -0.86576384, -0.1617768 ], - [-5.143918 , 5.0426254 , -3.7237067 , 4.383015 ], - [ 0.33311516, -1.5299042 , -8.854181 , -2.896776 ], - [ 0.3419102 , 0.2669245 , -2.8250606 , 5.752488 ]]], - dtype=float32), array([[-0.8395241, -5.571 , 2.0042834, -4.206357 ], - [ 1.7666914, 5.0426254, -8.854181 , 5.752488 ]], dtype=float32), array([[-2.985257 , -1.6295947, 0.5536146], - [-5.143918 , -1.5299042, -2.8250606]], dtype=float32), array([[1.8120625, 1.5035137, 0. ], - [1.6288393, 1.8669801, 0. ]], dtype=float32)), - mlir_module_text=r""" -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":695:13) -#loc5 = loc("jit(func)/jit(main)/pjit"(#loc1)) -module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<2x4x4xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x3xf32> {jax.result_info = "[2]", mhlo.layout_mode = "default"}, tensor<2x3xf32> {jax.result_info = "[3]", mhlo.layout_mode = "default"}) { - %cst = stablehlo.constant dense<[[[-0.83952409, 1.562720e-01, -1.6810869, 0.238321185], [2.42421508, -5.17118931, -0.226527944, -0.838066756], [1.47339451, -1.32866347, -3.3505435, -1.14886105], [-0.929541587, -0.955984473, 2.71886253, 0.748659431]], [[1.76669145, 2.56900501, -0.865763843, -0.161776796], [3.23469758, -0.362713158, -3.72370672, 4.38301516], [2.79104376, 7.36582708, -3.04437494, -2.89677596], [2.86473417, 0.981746375, -2.13533139, 5.34802151]]]> : tensor<2x4x4xf32> loc(#loc) - %c = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %c_1 = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_2 = stablehlo.constant dense<2> : tensor loc(#loc2) - %c_3 = stablehlo.constant dense<128> : tensor loc(#loc2) - %0:6 = stablehlo.custom_call @lapack_ssytrd(%c, %c_0, %c_1, %c_2, %c_3, %cst) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xf32>) -> (tensor<2x4x4xf32>, tensor<2x4xf32>, tensor<2x3xf32>, tensor<2x3xf32>, tensor<2xi32>, tensor<128xf32>) loc(#loc2) - %c_4 = stablehlo.constant dense<0> : tensor loc(#loc) - %1 = stablehlo.broadcast_in_dim %c_4, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc4) - %cst_5 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) - %4 = call @_where(%3, %0#0, %cst_5) : (tensor<2x1x1xi1>, tensor<2x4x4xf32>, tensor) -> tensor<2x4x4xf32> loc(#loc5) - %c_6 = stablehlo.constant dense<0> : tensor loc(#loc) - %5 = stablehlo.broadcast_in_dim %c_6, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %6 = stablehlo.compare EQ, %0#4, %5, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %7 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) - %cst_7 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) - %8 = call @_where_0(%7, %0#1, %cst_7) : (tensor<2x1xi1>, tensor<2x4xf32>, tensor) -> tensor<2x4xf32> loc(#loc5) - %c_8 = stablehlo.constant dense<0> : tensor loc(#loc) - %9 = stablehlo.broadcast_in_dim %c_8, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %10 = stablehlo.compare EQ, %0#4, %9, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %11 = stablehlo.broadcast_in_dim %10, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) - %cst_9 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) - %12 = call @_where_1(%11, %0#2, %cst_9) : (tensor<2x1xi1>, tensor<2x3xf32>, tensor) -> tensor<2x3xf32> loc(#loc5) - %c_10 = stablehlo.constant dense<0> : tensor loc(#loc) - %13 = stablehlo.broadcast_in_dim %c_10, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %14 = stablehlo.compare EQ, %0#4, %13, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %15 = stablehlo.broadcast_in_dim %14, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) - %cst_11 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) - %16 = call @_where_1(%15, %0#3, %cst_11) : (tensor<2x1xi1>, tensor<2x3xf32>, tensor) -> tensor<2x3xf32> loc(#loc5) - return %4, %8, %12, %16 : tensor<2x4x4xf32>, tensor<2x4xf32>, tensor<2x3xf32>, tensor<2x3xf32> loc(#loc) - } loc(#loc) - func.func private @_where(%arg0: tensor<2x1x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4x4xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4x4xf32> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc6) - %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x4x4xf32> loc(#loc6) - %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4x4xi1>, tensor<2x4x4xf32> loc(#loc7) - return %2 : tensor<2x4x4xf32> loc(#loc5) - } loc(#loc5) - func.func private @_where_0(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4xf32> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc6) - %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x4xf32> loc(#loc6) - %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4xi1>, tensor<2x4xf32> loc(#loc7) - return %2 : tensor<2x4xf32> loc(#loc5) - } loc(#loc5) - func.func private @_where_1(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x3xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x3xf32> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc6) - %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x3xf32> loc(#loc6) - %2 = stablehlo.select %0, %arg1, %1 : tensor<2x3xi1>, tensor<2x3xf32> loc(#loc7) - return %2 : tensor<2x3xf32> loc(#loc5) - } loc(#loc5) -} loc(#loc) -#loc = loc(unknown) -#loc2 = loc("jit(func)/jit(main)/tridiagonal"(#loc1)) -#loc3 = loc("jit(func)/jit(main)/eq"(#loc1)) -#loc4 = loc("jit(func)/jit(main)/broadcast_in_dim"(#loc1)) -#loc6 = loc("jit(func)/jit(main)/jit(_where)/broadcast_in_dim"(#loc1)) -#loc7 = loc("jit(func)/jit(main)/jit(_where)/select_n"(#loc1)) -""", - mlir_module_serialized=b'ML\xefR\rStableHLO_v1.3.0\x00\x01#\x05\x01\x05\x13\x01\x03\x0b\x03\x11\x0f\x13\x17\x1b\x1f#\'+\x03\xe9\x93A\x01-\x0f\x07\x0f\x17\x0f\x0f\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03g\x0f\x0b\x0b\x0f\x0b\x13\x1f\x0b\x0b/\x1f\x17\x0f\x0b\x0bO\x0b\x0b\x0bO\x1fo/\x0b\x1b\x1b\x0b\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0b\x0b\x0bo&\x04\x1f\x1f\x1f\x0b\x0b\x0b\x0b#\x0f\x17#\x01\x05\x0b\x0f\x03=\x0f\x17\x0f\x1b\x17\x17\x07\x07\x13\x07\x07\x13\x1b\x07\x1f\x1f\x1f\x1f\x17\x13\x13\x17\x1b\x13\x17\x13\x13\x13\x13\x13\x02b\t\x1d\x1f\x07\x1f\x1d)\x07\x17!\xde\n\x1b\x1d#\x07\x1d\'\x07\x1d+\x07\x1d%\x07\x11\x03\x05\x03\x07\x15\x17\x19\x11\x1b\x11\x05\x17\x11\x01\x00\x05\x19\x05\x1b\x05\x1d\x05\x1f\x05!\x05#\x05%\x05\'\x05)\x05+\x1f-\x01\x1d-\x1d/\x1f7\x01\x1d1\r\x03/1\x1f\x05\t\x00\x00\x00\x00\t\x07\x07\x01\x1f?\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\t\t\x00\x00\xc0\x7f\x03\x07777\x03\x037\x1d3\x1d5\x1f;!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1d7\x1d9\x1f+!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f91\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f=\x11\x00\x00\x00\x00\x00\x00\x00\x00#!\x03\t_cgk\r\x055a/1\x1d;\r\x055e/1\x1d=\r\x055i/1\x1d?\r\x055m/1\x1dA\x1dC\x1dE###%#\'\x1f31\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x0b\x02\x02\r\xebV\xbf\xc4\x05 >\xdb-\xd7\xbfx\nt>W&\x1b@bz\xa5\xc0\xf1\xf6g\xbe\x8b\x8bV\xbf1\x98\xbc?\xa5\x11\xaa\xbfNoV\xc0\xe1\r\x93\xbfp\xf6m\xbff\xbbt\xbf\xd8\x01.@%\xa8??\xf2"\xe2?\x94j$@\xb3\xa2]\xbf\xd1\xa8%\xbeI\x05O@\x8a\xb5\xb9\xbe6Qn\xc0\xa9A\x8c@v\xa02@\xdb\xb4\xeb@\n\xd7B\xc0\xc7d9\xc0\xceW7@\xbbS{?E\xa9\x08\xc0\xfe"\xab@\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x02\x00\x00\x00\x1f\x05\t\x80\x00\x00\x00\x0b\x05\x1dG\x1dI\x05\x01\x03\r33333W\x03\x03\x8f\x15\x03\x01\x15\x01\x03\rWKKKYY\x01\t\x01\x02\x02)\x01\x1f)\x05\t\r\x13)\x01\x13)\x07\t\x11\x11\x13)\x05\t\x11\x13)\x05\t\x05\x11\x01\t)\x03\t\x1f\x1d\x13)\x03\t\x11)\x07\t\x05\x05\x11\x1b\x11\x01\t\x0b\r\x07\x07\x11\x07\x1d\x0b\t\x03\x0b\x11\x07\x0f\r\t\x03\r\x11\x07\x0f\x07\t\x03\x07)\x05\t\r\x11)\x03\t\x17)\x03\x01\x17)\x05\t\x11\x11)\x07\t\x11\x11\x11)\x03\r\x17)\x03\x02\x04\x13)\x03\x01\x19)\x03\r\x19)\x03\t\x19)\x03\x05\x19)\x03\x05\x17\x04\x8a\x06\x05\x01Q\x03\x13\x01\x07\x04b\x06\x03\x01\x11\x07P\x03\x03\x07\x04\xf6\x03\x03I\x81\x05B\x03\x05\x03\x0b\x05B\x0b\x07\x03\x05\x05B\x0b\t\x03\x05\x05B\x0b\x07\x03\x05\x05B\x0b\x0b\x03\x05\x05B\x0b\r\x03\x05\x11F\x0b\x0f\r\x0b\r\x07\x07\x155\r\x03\x05\x07\t\x0b\x01\x05B\x03\x11\x03\x05\x03F\x05\x13\x03\x15\x03\x19\x0bF\x05\x15\x03\x1b\x05\x15\x1b\x03F\r\x17\x03\x1d\x03\x1d\x05B\x03\x19\x03\t\rF\x01\x1b\x03\x0b\x07\x1f\r!\x05B\x03\x11\x03\x05\x03F\x05\x13\x03\x15\x03%\x0bF\x05\x15\x03\x1b\x05\x15\'\x03F\r\x17\x03\x0f\x03)\x05B\x03\x19\x03\t\rF\x01\x1d\x03\r\x07+\x0f-\x05B\x03\x11\x03\x05\x03F\x05\x13\x03\x15\x031\x0bF\x05\x15\x03\x1b\x05\x153\x03F\r\x17\x03\x0f\x035\x05B\x03\x19\x03\t\rF\x01\x1f\x03\x07\x077\x119\x05B\x03\x11\x03\x05\x03F\x05\x13\x03\x15\x03=\x0bF\x05\x15\x03\x1b\x05\x15?\x03F\r\x17\x03\x0f\x03A\x05B\x03\x19\x03\t\rF\x01\x1f\x03\x07\x07C\x13E\t\x04\x03\t#/;G\x07P\x01!\x07\x04S\x03\r\x13\x07;\x01\x17\x01\x13\x01\x00\x03F\t#\x031\x03\x01\x03F\t\x13\x03\x0b\x03\x05\x0f\x06\x0f\x03\x0b\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01%\x07\x04S\x03\r\x13\x07\x1f\x01\x1b\x01\x13\x01\x00\x03F\t\'\x03/\x03\x01\x03F\t\x13\x03\r\x03\x05\x0f\x06\x0f\x03\r\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01)\x07\x04S\x03\r\x13\x07\x1f\x01\x0f\x01\x13\x01\x00\x03F\t\'\x03)\x03\x01\x03F\t\x13\x03\x07\x03\x05\x0f\x06\x0f\x03\x07\x07\x07\x03\t\t\x04\x01\x03\x0b\x06\x03\x01\x05\x01\x00n\tK\x1d\x03\x0f\x0b\t\t\t\t\x13\x0f\x13\x11!\x11#K/ASci3\x13%)9\x1f\x15\x11\x17\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00func_v1\x00return_v1\x00compare_v1\x00call_v1\x00select_v1\x00custom_call_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/pjit\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit(func)/jit(main)/jit(_where)/broadcast_in_dim\x00jit(func)/jit(main)/jit(_where)/select_n\x00jit(func)/jit(main)/tridiagonal\x00jit(func)/jit(main)/eq\x00jit(func)/jit(main)/broadcast_in_dim\x00mhlo.layout_mode\x00default\x00jax.result_info\x00private\x00_where_1\x00_where\x00_where_0\x00[0]\x00[1]\x00[2]\x00[3]\x00main\x00public\x00\x00lapack_ssytrd\x00\x08\x89+\x05;\x01\x0bM[]oq\x03{\x03U\x03}\x03\x7f\x03\x81\x11\x83\x85\x87M\x89\x8b\x8d\x91\x039\x03-\x05;=\x03?\x03A\x03O\x03Q\x03I\x0bCsEOG\x03y\x0bCuEQG\x03S\x0bCwEIG', - xla_call_module_version=9, - nr_devices=1, -) # End paste - - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_09_03["f64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_dsytrd'], - serialized_date=datetime.date(2024, 9, 3), - inputs=(), - expected_outputs=(array([[[ 0.8251247184208595 , -2.6963562039892532 , - 0.8082445002373937 , -1.551980329390836 ], - [-2.629505060186711 , 4.427374205796291 , - -2.2111093161901074 , 7.552489598405787 ], - [ 0.2269453213819231 , 0.3650586474106988 , - -3.5933639667756205 , 4.828829679372501 ], - [-0.6415372293575187 , -0.2519326897319508 , - -1.7607827845801751 , -3.381311711243865 ]], - - [[-4.000421911405985 , 3.6303350337601055 , - 2.8066821235532355 , 1.099224389184342 ], - [-4.141622408467332 , -5.276404169116551 , - -0.8496056221591237 , -2.275319346221659 ], - [ 0.5828958067901202 , 0.9351254869793256 , - 2.7765603683442177 , -4.339686212557215 ], - [-0.6391146585297987 , 0.3129920702652711 , - -0.25441692469349864, -1.4155240723557498 ]]]), array([[ 0.8251247184208595, 4.427374205796291 , -3.5933639667756205, - -3.381311711243865 ], - [-4.000421911405985 , -5.276404169116551 , 2.7765603683442177, - -1.4155240723557498]]), array([[-2.629505060186711 , 0.3650586474106988 , -1.7607827845801751 ], - [-4.141622408467332 , 0.9351254869793256 , -0.25441692469349864]]), array([[1.3669846724688552, 1.8806358893589366, 0. ], - [1.1440109149169537, 1.8215532880266878, 0. ]])), - mlir_module_text=r""" -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":695:13) -#loc5 = loc("jit(func)/jit(main)/pjit"(#loc1)) -module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<2x4x4xf64> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x3xf64> {jax.result_info = "[2]", mhlo.layout_mode = "default"}, tensor<2x3xf64> {jax.result_info = "[3]", mhlo.layout_mode = "default"}) { - %cst = stablehlo.constant dense<[[[0.82512471842085955, -2.6963562039892532, 0.80824450023739369, -1.5519803293908361], [0.96498805326781766, -4.1313349231964409, -2.2111093161901074, 7.5524895984057867], [0.81575339483804743, 1.0647235400727899, -1.0064296232364345, 4.8288296793725012], [-2.3060011529502993, -2.9182106402942192, -1.7781896154088577, 2.5904630742096817]], [[-4.0004219114059847, 3.6303350337601055, 2.8066821235532355, 1.0992243891843421], [0.59643883228393779, -1.5243235004961249, -0.84960562215912372, -2.275319346221659], [2.7617960295487092, -0.57538970930521982, 0.12559406141906576, -4.3396862125572149], [-3.0281643919760217, 0.38177997229319849, 3.860398204232184, -2.5166384340510231]]]> : tensor<2x4x4xf64> loc(#loc) - %c = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %c_1 = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_2 = stablehlo.constant dense<2> : tensor loc(#loc2) - %c_3 = stablehlo.constant dense<128> : tensor loc(#loc2) - %0:6 = stablehlo.custom_call @lapack_dsytrd(%c, %c_0, %c_1, %c_2, %c_3, %cst) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xf64>) -> (tensor<2x4x4xf64>, tensor<2x4xf64>, tensor<2x3xf64>, tensor<2x3xf64>, tensor<2xi32>, tensor<128xf64>) loc(#loc2) - %c_4 = stablehlo.constant dense<0> : tensor loc(#loc) - %1 = stablehlo.broadcast_in_dim %c_4, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc4) - %cst_5 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) - %4 = call @_where(%3, %0#0, %cst_5) : (tensor<2x1x1xi1>, tensor<2x4x4xf64>, tensor) -> tensor<2x4x4xf64> loc(#loc5) - %c_6 = stablehlo.constant dense<0> : tensor loc(#loc) - %5 = stablehlo.broadcast_in_dim %c_6, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %6 = stablehlo.compare EQ, %0#4, %5, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %7 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) - %cst_7 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) - %8 = call @_where_0(%7, %0#1, %cst_7) : (tensor<2x1xi1>, tensor<2x4xf64>, tensor) -> tensor<2x4xf64> loc(#loc5) - %c_8 = stablehlo.constant dense<0> : tensor loc(#loc) - %9 = stablehlo.broadcast_in_dim %c_8, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %10 = stablehlo.compare EQ, %0#4, %9, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %11 = stablehlo.broadcast_in_dim %10, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) - %cst_9 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) - %12 = call @_where_1(%11, %0#2, %cst_9) : (tensor<2x1xi1>, tensor<2x3xf64>, tensor) -> tensor<2x3xf64> loc(#loc5) - %c_10 = stablehlo.constant dense<0> : tensor loc(#loc) - %13 = stablehlo.broadcast_in_dim %c_10, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %14 = stablehlo.compare EQ, %0#4, %13, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %15 = stablehlo.broadcast_in_dim %14, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) - %cst_11 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) - %16 = call @_where_1(%15, %0#3, %cst_11) : (tensor<2x1xi1>, tensor<2x3xf64>, tensor) -> tensor<2x3xf64> loc(#loc5) - return %4, %8, %12, %16 : tensor<2x4x4xf64>, tensor<2x4xf64>, tensor<2x3xf64>, tensor<2x3xf64> loc(#loc) - } loc(#loc) - func.func private @_where(%arg0: tensor<2x1x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4x4xf64> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4x4xf64> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc6) - %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x4x4xf64> loc(#loc6) - %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4x4xi1>, tensor<2x4x4xf64> loc(#loc7) - return %2 : tensor<2x4x4xf64> loc(#loc5) - } loc(#loc5) - func.func private @_where_0(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4xf64> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4xf64> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc6) - %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x4xf64> loc(#loc6) - %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4xi1>, tensor<2x4xf64> loc(#loc7) - return %2 : tensor<2x4xf64> loc(#loc5) - } loc(#loc5) - func.func private @_where_1(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x3xf64> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x3xf64> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc6) - %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x3xf64> loc(#loc6) - %2 = stablehlo.select %0, %arg1, %1 : tensor<2x3xi1>, tensor<2x3xf64> loc(#loc7) - return %2 : tensor<2x3xf64> loc(#loc5) - } loc(#loc5) -} loc(#loc) -#loc = loc(unknown) -#loc2 = loc("jit(func)/jit(main)/tridiagonal"(#loc1)) -#loc3 = loc("jit(func)/jit(main)/eq"(#loc1)) -#loc4 = loc("jit(func)/jit(main)/broadcast_in_dim"(#loc1)) -#loc6 = loc("jit(func)/jit(main)/jit(_where)/broadcast_in_dim"(#loc1)) -#loc7 = loc("jit(func)/jit(main)/jit(_where)/select_n"(#loc1)) -""", - mlir_module_serialized=b'ML\xefR\rStableHLO_v1.3.0\x00\x01#\x05\x01\x05\x13\x01\x03\x0b\x03\x11\x0f\x13\x17\x1b\x1f#\'+\x03\xe9\x93A\x01-\x0f\x07\x0f\x17\x0f\x0f\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03g\x0f\x0b\x0b\x0f\x0b\x13\x1f\x0b\x0b//\x17\x0f\x0b\x0bO\x0b\x0b\x0bO\x1fo/\x0b\x1b\x1b\x0b\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0b\x0b\x0bo&\x08\x1f\x1f\x1f\x0b\x0b\x0b\x0b#\x0f\x17#\x01\x05\x0b\x0f\x03=\x0f\x17\x0f\x1b\x17\x17\x07\x07\x13\x07\x07\x13\x1b\x07\x1f\x1f\x1f\x1f\x17\x13\x13\x17\x1b\x13\x17\x13\x13\x13\x13\x13\x02r\x0b\x1d\x1f\x07\x1f\x1d)\x07\x17!\xde\n\x1b\x1d#\x07\x1d\'\x07\x1d+\x07\x1d%\x07\x11\x03\x05\x03\x07\x15\x17\x19\x11\x1b\x11\x05\x17\x11\x01\x00\x05\x19\x05\x1b\x05\x1d\x05\x1f\x05!\x05#\x05%\x05\'\x05)\x05+\x1f-\x01\x1d-\x1d/\x1f7\x01\x1d1\r\x03/1\x1f\x05\t\x00\x00\x00\x00\t\x07\x07\x01\x1f?\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\t\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x03\x07777\x03\x037\x1d3\x1d5\x1f;!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1d7\x1d9\x1f+!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f91\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f=\x11\x00\x00\x00\x00\x00\x00\x00\x00#!\x03\t_cgk\r\x055a/1\x1d;\r\x055e/1\x1d=\r\x055i/1\x1d?\r\x055m/1\x1dA\x1dC\x1dE###%#\'\x1f31\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x0b\x02\x04A\xa4\x17\xf4kg\xea?\x1f\x01\x943#\x92\x05\xc0\x86 \xf6\x91#\xdd\xe9?\x9dMlS\xe9\xd4\xf8\xbf\x88\x1c:\xa0.\xe1\xee?8\xce\x7f\xa9|\x86\x10\xc0\xe8V\xc7\x14Z\xb0\x01\xc0\xd2!R\xd5\xbf5\x1e@\xbf\xc5\r\xdd\xa6\x1a\xea?\xbcM\xfe\x8c\x1b\t\xf1?\xdbj\xd8\xf2U\x1a\xf0\xbf\xado;\xba\xb8P\x13@\xbb\xad\x83\xbb\xb0r\x02\xc0\x1f9\xf7\xd1~X\x07\xc0)ID\xf4vs\xfc\xbfD\xcfI\xb4D\xb9\x04@\x16\xc3\xfe\x99n\x00\x10\xc0\x82.\x1c\x18\xed\n\r@\x8cn\xd7\xc1\x15t\x06@|2(Pl\x96\xf1?\x88*\xd7\xe3\x06\x16\xe3?F{\xf2\t\xa1c\xf8\xbf8z5!\xf8/\xeb\xbf4\xd3\x1f\xa1\xda3\x02\xc0)\x13I\x84(\x18\x06@\xbcw\xfd\xad\x97i\xe2\xbf\x1e\xf0.Yw\x13\xc0?dW\xd7\xb3\xd6[\x11\xc0\x04\x97\xb3@\xae9\x08\xc0\xbc\x17\xd1C\x15o\xd8?\x02\xb7%t\x18\xe2\x0e@\xac\xd8\xd0T\x13"\x04\xc0\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x02\x00\x00\x00\x1f\x05\t\x80\x00\x00\x00\x0b\x05\x1dG\x1dI\x05\x01\x03\r33333W\x03\x03\x8f\x15\x03\x01\x15\x01\x03\rWKKKYY\x01\t\x01\x02\x02)\x01\x1f)\x05\t\r\x13)\x01\x13)\x07\t\x11\x11\x13)\x05\t\x11\x13)\x05\t\x05\x11\x01\x0b)\x03\t\x1f\x1d\x13)\x03\t\x11)\x07\t\x05\x05\x11\x1b\x11\x01\t\x0b\r\x07\x07\x11\x07\x1d\x0b\t\x03\x0b\x11\x07\x0f\r\t\x03\r\x11\x07\x0f\x07\t\x03\x07)\x05\t\r\x11)\x03\t\x17)\x03\x01\x17)\x05\t\x11\x11)\x07\t\x11\x11\x11)\x03\r\x17)\x03\x02\x04\x13)\x03\x01\x19)\x03\r\x19)\x03\t\x19)\x03\x05\x19)\x03\x05\x17\x04\x8a\x06\x05\x01Q\x03\x13\x01\x07\x04b\x06\x03\x01\x11\x07P\x03\x03\x07\x04\xf6\x03\x03I\x81\x05B\x03\x05\x03\x0b\x05B\x0b\x07\x03\x05\x05B\x0b\t\x03\x05\x05B\x0b\x07\x03\x05\x05B\x0b\x0b\x03\x05\x05B\x0b\r\x03\x05\x11F\x0b\x0f\r\x0b\r\x07\x07\x155\r\x03\x05\x07\t\x0b\x01\x05B\x03\x11\x03\x05\x03F\x05\x13\x03\x15\x03\x19\x0bF\x05\x15\x03\x1b\x05\x15\x1b\x03F\r\x17\x03\x1d\x03\x1d\x05B\x03\x19\x03\t\rF\x01\x1b\x03\x0b\x07\x1f\r!\x05B\x03\x11\x03\x05\x03F\x05\x13\x03\x15\x03%\x0bF\x05\x15\x03\x1b\x05\x15\'\x03F\r\x17\x03\x0f\x03)\x05B\x03\x19\x03\t\rF\x01\x1d\x03\r\x07+\x0f-\x05B\x03\x11\x03\x05\x03F\x05\x13\x03\x15\x031\x0bF\x05\x15\x03\x1b\x05\x153\x03F\r\x17\x03\x0f\x035\x05B\x03\x19\x03\t\rF\x01\x1f\x03\x07\x077\x119\x05B\x03\x11\x03\x05\x03F\x05\x13\x03\x15\x03=\x0bF\x05\x15\x03\x1b\x05\x15?\x03F\r\x17\x03\x0f\x03A\x05B\x03\x19\x03\t\rF\x01\x1f\x03\x07\x07C\x13E\t\x04\x03\t#/;G\x07P\x01!\x07\x04S\x03\r\x13\x07;\x01\x17\x01\x13\x01\x00\x03F\t#\x031\x03\x01\x03F\t\x13\x03\x0b\x03\x05\x0f\x06\x0f\x03\x0b\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01%\x07\x04S\x03\r\x13\x07\x1f\x01\x1b\x01\x13\x01\x00\x03F\t\'\x03/\x03\x01\x03F\t\x13\x03\r\x03\x05\x0f\x06\x0f\x03\r\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01)\x07\x04S\x03\r\x13\x07\x1f\x01\x0f\x01\x13\x01\x00\x03F\t\'\x03)\x03\x01\x03F\t\x13\x03\x07\x03\x05\x0f\x06\x0f\x03\x07\x07\x07\x03\t\t\x04\x01\x03\x0b\x06\x03\x01\x05\x01\x00n\tK\x1d\x03\x0f\x0b\t\t\t\t\x13\x0f\x13\x11!\x11#K/ASci3\x13%)9\x1f\x15\x11\x17\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00func_v1\x00return_v1\x00compare_v1\x00call_v1\x00select_v1\x00custom_call_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/pjit\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit(func)/jit(main)/jit(_where)/broadcast_in_dim\x00jit(func)/jit(main)/jit(_where)/select_n\x00jit(func)/jit(main)/tridiagonal\x00jit(func)/jit(main)/eq\x00jit(func)/jit(main)/broadcast_in_dim\x00mhlo.layout_mode\x00default\x00jax.result_info\x00private\x00_where_1\x00_where\x00_where_0\x00[0]\x00[1]\x00[2]\x00[3]\x00main\x00public\x00\x00lapack_dsytrd\x00\x08\x89+\x05;\x01\x0bM[]oq\x03{\x03U\x03}\x03\x7f\x03\x81\x11\x83\x85\x87M\x89\x8b\x8d\x91\x039\x03-\x05;=\x03?\x03A\x03O\x03Q\x03I\x0bCsEOG\x03y\x0bCuEQG\x03S\x0bCwEIG', - xla_call_module_version=9, - nr_devices=1, -) # End paste +array = np.array +float32 = np.float32 +complex64 = np.complex64 data_2024_12_01 = {} - # Pasted from the test output (see export_back_compat_test_util.py module docstring) data_2024_12_01["c128"] = dict( testdata_version=1, diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_tridiagonal_solve_lapack_gtsv.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_tridiagonal_solve_lapack_gtsv.py index 3286d8d3b3d9..fd821951b093 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_tridiagonal_solve_lapack_gtsv.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_tridiagonal_solve_lapack_gtsv.py @@ -15,7 +15,11 @@ # ruff: noqa import datetime -from numpy import array, float32, complex64 +import numpy as np + +array = np.array +float32 = np.float32 +complex64 = np.complex64 data_2025_01_09 = {} diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_cholesky_solver_potrf.py b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_cholesky_solver_potrf.py new file mode 100644 index 000000000000..96996ac67a84 --- /dev/null +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_cholesky_solver_potrf.py @@ -0,0 +1,321 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa + +import datetime +import numpy as np + +array = np.array +float32 = np.float32 +int32 = np.int32 +complex64 = np.complex64 + +data_2025_10_15 = {} + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2025_10_15["f32"] = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['cusolver_potrf_ffi'], + serialized_date=datetime.date(2025, 10, 15), + inputs=(array([[ 10.895978 , 4.2912526, -18.31493 , 3.697414 ], + [ 4.2912526, 61.31485 , 4.850662 , -2.1822023], + [-18.31493 , 4.850662 , 36.91168 , -4.3174276], + [ 3.697414 , -2.1822023, -4.3174276, 16.82287 ]], + dtype=float32),), + expected_outputs=(array([[ 3.3009057, 0. , 0. , 0. ], + [ 1.3000228, 7.721709 , 0. , 0. ], + [-5.548456 , 1.5623201, 1.9197574, 0. ], + [ 1.120121 , -0.4711891, 1.3718729, 3.6693518]], dtype=float32),), + mlir_module_text=r""" +#loc1 = loc("x") +module @jit_cholesky attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<4x4xf32> loc("x")) -> (tensor<4x4xf32> {jax.result_info = "result"}) { + %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc) + %cst_0 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc4) + %c = stablehlo.constant dense<0> : tensor loc(#loc) + %cst_1 = stablehlo.constant dense<2.000000e+00> : tensor loc(#loc) + %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<4x4xf32>) -> tensor<4x4xf32> loc(#loc5) + %1 = stablehlo.add %arg0, %0 : tensor<4x4xf32> loc(#loc6) + %2 = stablehlo.broadcast_in_dim %cst_1, dims = [] : (tensor) -> tensor<4x4xf32> loc(#loc7) + %3 = stablehlo.divide %1, %2 : tensor<4x4xf32> loc(#loc7) + %4:2 = stablehlo.custom_call @cusolver_potrf_ffi(%3) {mhlo.backend_config = {lower = true}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>], sdy.sharding_rule = #sdy.op_sharding_rule<([i, j])->([k, l], []) {i=4, j=4, k=4, l=4}, custom>} : (tensor<4x4xf32>) -> (tensor<4x4xf32>, tensor) loc(#loc4) + %5 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor loc(#loc4) + %6 = stablehlo.compare EQ, %4#1, %5, SIGNED : (tensor, tensor) -> tensor loc(#loc4) + %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc4) + %8 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor) -> tensor<4x4xf32> loc(#loc4) + %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc4) + %10 = stablehlo.select %9, %4#0, %8 : tensor<4x4xi1>, tensor<4x4xf32> loc(#loc4) + %11 = stablehlo.iota dim = 0 : tensor<4x4xi32> loc(#loc8) + %12 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<4x4xi32> loc(#loc6) + %13 = stablehlo.add %11, %12 : tensor<4x4xi32> loc(#loc6) + %14 = stablehlo.iota dim = 1 : tensor<4x4xi32> loc(#loc8) + %15 = stablehlo.compare GE, %13, %14, SIGNED : (tensor<4x4xi32>, tensor<4x4xi32>) -> tensor<4x4xi1> loc(#loc9) + %16 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<4x4xf32> loc(#loc10) + %17 = stablehlo.select %15, %10, %16 : tensor<4x4xi1>, tensor<4x4xf32> loc(#loc11) + return %17 : tensor<4x4xf32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":225:6) +#loc3 = loc("jit(cholesky)"(#loc2)) +#loc4 = loc("cholesky"(#loc3)) +#loc5 = loc("transpose"(#loc3)) +#loc6 = loc("add"(#loc3)) +#loc7 = loc("div"(#loc3)) +#loc8 = loc("iota"(#loc3)) +#loc9 = loc("ge"(#loc3)) +#loc10 = loc("broadcast_in_dim"(#loc3)) +#loc11 = loc("select_n"(#loc3)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.12.1\x00\x01+\x07\x01\x05\t\x19\x01\x03\x0f\x03\x17\x13\x17\x1b\x1f#'+/37;\x03\xdf\xa1'\x01E\x0f\x0f\x07\x0f\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0b\x0b\x17\x0b\x0f\x0b\x0b\x0b#\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x03M\x0fO\x0b\x0f\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x1f\x1f\x1f\x1fO\x13\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17\x13\x0f\x0bO\x0f\x0f\x0b\x05\x11C\x13\x0f\x0f\x13\x0f\x0f\x0b\x01\x05\x0b\x0f\x03#\x17\x0f\x0f\x07\x17\x07\x07\x07\x13\x07\x17\x17\x13\x13\x13\x0f\x17\x02\x9a\x05\x1d\x1f\x03\x1d!#\x1f\x1d+\x03\x11\x03\x05\x1d-\x03\x1d7\x03\x03\x07\x11\x13\x15\t\x17\t\x05\x1f\x11\x01\x00\x05!\x05#\x05%\x1d\x1d\x05\x05'\x05)\x05+\x17%\x86\x03\r\x05-\x1d)\x03\x05/\x051\x053\x03\x071g3m5\x91\x055\x057\x059\x05;\x1d;\x03\x05=\x1d?\x03\x05?\x1dC\x03\x05A\x1f\x1d\x01\x1f\x1f!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\t\x07\x03\x03M\r\x01#\x1b\x03\x03S\r\x03UW\x1dC\x1dE\x1dG\x1dI\x1f\x07\t\x00\x00\x00\x00\x1f\x07\t\x00\x00\xc0\x7f\x1f\t\t\x00\x00\x00\x00\x1f\x07\t\x00\x00\x00@\x1f\x15!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\r\x03ik\x1dK\x05\x03\r\x03oq\x1dM\x1dO\x0b\x03\x1dQ\x1dS\x03\x01\x05\x01\x03\x03G\x03\x03\x81\x15\x03\x01\x01\x01\x03\x05G\x85\x1f!\x01\x07\x01\x1f\x15!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x13\x0b\x01\x13\x0b\x05\x07\x05\x15\t\x11\x11\x11\x11\x03\x93\x05\x99\x9f\x01\x01\x01\x01\x01\x13\x05\x95\x97\x11\x03\x01\x11\x03\x05\x13\x05\x9b\x9d\x11\x03\t\x11\x03\r\x13\x01\x01\t\x01\x02\x02)\x05\x11\x11\x11)\x01\x11)\x01\x13\x1d)\x05\x11\x11\x13\x01\t\x1b)\x03\t\x0b\x13)\x05\x11\x11\x0f\x11\x03\x05\x03\x05)\x03\x01\x0b)\x03\t\x17)\x03\x01\x17)\x01\x0f)\x05\x05\x05\x0f\x04.\x03\x05\x01Q\x05\x0f\x01\x07\x04\x06\x03\x03\x01\x05\x0fP\x05\x03\x07\x04\xda\x02\x031_\x03\x0b\x1b\x00\x05B\x05\x05\x03\x07\x05B\x01\x07\x03\x07\x05B\x05\t\x03\t\x05B\x05\x0b\x03\x07\x11F'\r\x03\x05\x03\x01\x07\x06\x07\x03\x05\x05\x01\x0b\x03F\x0b\x0f\x03\x05\x03\t\x13\x06\x0b\x03\x05\x05\r\x0f\x15G\x01/\x11\x05\x05\t\x03\x11\x03F\x01\x0f\x03\t\x03\x07\tF\x01\x13\x03#\x05\x15\x17\x03F\x01\x0f\x03%\x03\x19\x03F\x01\x0f\x03\x05\x03\x05\x03F\x01\x15\x03\x19\x03\x1b\x0b\x06\x01\x03\x05\x07\x1f\x13\x1d\rB\r\x17\x03\r\x03F\x07\x0f\x03\r\x03\x07\x07\x06\x07\x03\r\x05#%\rB\r\x19\x03\r\tF9\x1b\x03\x19\x05')\x03F=\x0f\x03\x05\x03\x03\x0b\x06A\x03\x05\x07+!-\x17\x04\x05\x03/\x06\x03\x01\x05\x01\x00r\x08U'\x03\x05\x1f\r\x0f\x0b\x0f!\x13#\x07\x0b%3)\t\t\x15i\x1d\x13\x05\x1b%)9\x15\x1f\x15\x1b\x11\x11\x15\x17\x0f\x19)\x0f\t\x0b\x11builtin\x00vhlo\x00sdy\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00add_v1\x00compare_v1\x00select_v1\x00iota_v1\x00func_v1\x00transpose_v1\x00divide_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_cholesky\x00x\x00cholesky\x00jit(cholesky)\x00third_party/py/jax/tests/export_back_compat_test.py\x00transpose\x00add\x00div\x00mhlo.backend_config\x00mhlo.frontend_attributes\x00sdy.sharding_rule\x00iota\x00ge\x00broadcast_in_dim\x00select_n\x00jax.result_info\x00result\x00main\x00public\x00lower\x00num_batch_dims\x000\x00\x00cusolver_potrf_ffi\x00\x08W\x1d\x053\x01\x0bKOQY[\x03]\x03_\x03a\x03c\x03e\x03E\x11suwy{}\x7f\x83\x05I\x87\x03\x89\x03\x8b\x03\x8d\x05I\x8f", + xla_call_module_version=10, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2025_10_15["f64"] = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['cusolver_potrf_ffi'], + serialized_date=datetime.date(2025, 10, 15), + inputs=(array([[ 19.64203917602577 , -3.257129261958479 , + -6.113417498905007 , -38.31045157187507 ], + [ -3.257129261958479 , 30.34754344041282 , + -0.21110017803103753, 12.603145919345353 ], + [ -6.113417498905007 , -0.21110017803103753, + 8.567277222442657 , 15.150956096041329 ], + [-38.31045157187507 , 12.603145919345353 , + 15.150956096041329 , 77.94548898633911 ]]),), + expected_outputs=(array([[ 4.431934022074987 , 0. , 0. , + 0. ], + [-0.7349227776711179 , 5.459618297213917 , 0. , + 0. ], + [-1.3794017393884324 , -0.22434790660947998, 2.571807940071491 , + 0. ], + [-8.644183641059373 , 1.1448306689037868 , 1.3546868938685381 , + 0.27886255592811615]]),), + mlir_module_text=r""" +#loc1 = loc("x") +module @jit_cholesky attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<4x4xf64> loc("x")) -> (tensor<4x4xf64> {jax.result_info = "result"}) { + %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc) + %cst_0 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc4) + %c = stablehlo.constant dense<0> : tensor loc(#loc) + %cst_1 = stablehlo.constant dense<2.000000e+00> : tensor loc(#loc) + %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<4x4xf64>) -> tensor<4x4xf64> loc(#loc5) + %1 = stablehlo.add %arg0, %0 : tensor<4x4xf64> loc(#loc6) + %2 = stablehlo.broadcast_in_dim %cst_1, dims = [] : (tensor) -> tensor<4x4xf64> loc(#loc7) + %3 = stablehlo.divide %1, %2 : tensor<4x4xf64> loc(#loc7) + %4:2 = stablehlo.custom_call @cusolver_potrf_ffi(%3) {mhlo.backend_config = {lower = true}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>], sdy.sharding_rule = #sdy.op_sharding_rule<([i, j])->([k, l], []) {i=4, j=4, k=4, l=4}, custom>} : (tensor<4x4xf64>) -> (tensor<4x4xf64>, tensor) loc(#loc4) + %5 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor loc(#loc4) + %6 = stablehlo.compare EQ, %4#1, %5, SIGNED : (tensor, tensor) -> tensor loc(#loc4) + %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc4) + %8 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor) -> tensor<4x4xf64> loc(#loc4) + %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc4) + %10 = stablehlo.select %9, %4#0, %8 : tensor<4x4xi1>, tensor<4x4xf64> loc(#loc4) + %11 = stablehlo.iota dim = 0 : tensor<4x4xi32> loc(#loc8) + %12 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<4x4xi32> loc(#loc6) + %13 = stablehlo.add %11, %12 : tensor<4x4xi32> loc(#loc6) + %14 = stablehlo.iota dim = 1 : tensor<4x4xi32> loc(#loc8) + %15 = stablehlo.compare GE, %13, %14, SIGNED : (tensor<4x4xi32>, tensor<4x4xi32>) -> tensor<4x4xi1> loc(#loc9) + %16 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<4x4xf64> loc(#loc10) + %17 = stablehlo.select %15, %10, %16 : tensor<4x4xi1>, tensor<4x4xf64> loc(#loc11) + return %17 : tensor<4x4xf64> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":225:6) +#loc3 = loc("jit(cholesky)"(#loc2)) +#loc4 = loc("cholesky"(#loc3)) +#loc5 = loc("transpose"(#loc3)) +#loc6 = loc("add"(#loc3)) +#loc7 = loc("div"(#loc3)) +#loc8 = loc("iota"(#loc3)) +#loc9 = loc("ge"(#loc3)) +#loc10 = loc("broadcast_in_dim"(#loc3)) +#loc11 = loc("select_n"(#loc3)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.12.1\x00\x01+\x07\x01\x05\t\x19\x01\x03\x0f\x03\x17\x13\x17\x1b\x1f#'+/37;\x03\xdf\xa1'\x01E\x0f\x0f\x07\x0f\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0b\x0b\x17\x0b\x0f\x0b\x0b\x0b#\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x03M\x0fO\x0b\x0f\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b//\x1f/O\x13\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17\x13\x0f\x0bO\x0f\x0f\x0b\x05\x11C\x13\x0f\x0f\x13\x0f\x0f\x0b\x01\x05\x0b\x0f\x03#\x17\x0f\x0f\x07\x17\x07\x07\x07\x13\x07\x17\x17\x13\x13\x13\x0f\x17\x02\xca\x05\x1d\x1f\x03\x1d!#\x1f\x1d+\x03\x11\x03\x05\x1d-\x03\x1d7\x03\x03\x07\x11\x13\x15\t\x17\t\x05\x1f\x11\x01\x00\x05!\x05#\x05%\x1d\x1d\x05\x05'\x05)\x05+\x17%\x86\x03\r\x05-\x1d)\x03\x05/\x051\x053\x03\x071g3m5\x91\x055\x057\x059\x05;\x1d;\x03\x05=\x1d?\x03\x05?\x1dC\x03\x05A\x1f\x1d\x01\x1f\x1f!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\t\x07\x03\x03M\r\x01#\x1b\x03\x03S\r\x03UW\x1dC\x1dE\x1dG\x1dI\x1f\x07\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\t\t\x00\x00\x00\x00\x1f\x07\x11\x00\x00\x00\x00\x00\x00\x00@\x1f\x15!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\r\x03ik\x1dK\x05\x03\r\x03oq\x1dM\x1dO\x0b\x03\x1dQ\x1dS\x03\x01\x05\x01\x03\x03G\x03\x03\x81\x15\x03\x01\x01\x01\x03\x05G\x85\x1f!\x01\x07\x01\x1f\x15!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x13\x0b\x01\x13\x0b\x05\x07\x05\x15\t\x11\x11\x11\x11\x03\x93\x05\x99\x9f\x01\x01\x01\x01\x01\x13\x05\x95\x97\x11\x03\x01\x11\x03\x05\x13\x05\x9b\x9d\x11\x03\t\x11\x03\r\x13\x01\x01\t\x01\x02\x02)\x05\x11\x11\x11)\x01\x11)\x01\x13\x1d)\x05\x11\x11\x13\x01\x0b\x1b)\x03\t\x0b\x13)\x05\x11\x11\x0f\x11\x03\x05\x03\x05)\x03\x01\x0b)\x03\t\x17)\x03\x01\x17)\x01\x0f)\x05\x05\x05\x0f\x04.\x03\x05\x01Q\x05\x0f\x01\x07\x04\x06\x03\x03\x01\x05\x0fP\x05\x03\x07\x04\xda\x02\x031_\x03\x0b\x1b\x00\x05B\x05\x05\x03\x07\x05B\x01\x07\x03\x07\x05B\x05\t\x03\t\x05B\x05\x0b\x03\x07\x11F'\r\x03\x05\x03\x01\x07\x06\x07\x03\x05\x05\x01\x0b\x03F\x0b\x0f\x03\x05\x03\t\x13\x06\x0b\x03\x05\x05\r\x0f\x15G\x01/\x11\x05\x05\t\x03\x11\x03F\x01\x0f\x03\t\x03\x07\tF\x01\x13\x03#\x05\x15\x17\x03F\x01\x0f\x03%\x03\x19\x03F\x01\x0f\x03\x05\x03\x05\x03F\x01\x15\x03\x19\x03\x1b\x0b\x06\x01\x03\x05\x07\x1f\x13\x1d\rB\r\x17\x03\r\x03F\x07\x0f\x03\r\x03\x07\x07\x06\x07\x03\r\x05#%\rB\r\x19\x03\r\tF9\x1b\x03\x19\x05')\x03F=\x0f\x03\x05\x03\x03\x0b\x06A\x03\x05\x07+!-\x17\x04\x05\x03/\x06\x03\x01\x05\x01\x00r\x08U'\x03\x05\x1f\r\x0f\x0b\x0f!\x13#\x07\x0b%3)\t\t\x15i\x1d\x13\x05\x1b%)9\x15\x1f\x15\x1b\x11\x11\x15\x17\x0f\x19)\x0f\t\x0b\x11builtin\x00vhlo\x00sdy\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00add_v1\x00compare_v1\x00select_v1\x00iota_v1\x00func_v1\x00transpose_v1\x00divide_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_cholesky\x00x\x00cholesky\x00jit(cholesky)\x00third_party/py/jax/tests/export_back_compat_test.py\x00transpose\x00add\x00div\x00mhlo.backend_config\x00mhlo.frontend_attributes\x00sdy.sharding_rule\x00iota\x00ge\x00broadcast_in_dim\x00select_n\x00jax.result_info\x00result\x00main\x00public\x00lower\x00num_batch_dims\x000\x00\x00cusolver_potrf_ffi\x00\x08W\x1d\x053\x01\x0bKOQY[\x03]\x03_\x03a\x03c\x03e\x03E\x11suwy{}\x7f\x83\x05I\x87\x03\x89\x03\x8b\x03\x8d\x05I\x8f", + xla_call_module_version=10, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2025_10_15["c64"] = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['cusolver_potrf_ffi'], + serialized_date=datetime.date(2025, 10, 15), + inputs=(array([[123.867 +0.j , 87.94228 +8.670872j, + -28.046396 +75.86514j , -44.46436 -25.990704j], + [ 87.94228 -8.670872j, 122.20061 +0.j , + -4.3756332+54.971073j, -21.559742 +1.933352j], + [-28.046396 -75.86514j , -4.3756332-54.971073j, + 101.799835 +0.j , -3.7747602+28.983824j], + [-44.46436 +25.990704j, -21.559742 -1.933352j, + -3.7747602-28.983824j, 83.28284 +0.j ]], dtype=complex64),), + expected_outputs=(array([[11.129555 +0.j , 0. +0.j , + 0. +0.j , 0. +0.j ], + [ 7.901689 -0.7790852j , 7.6913548 +0.j , + 0. +0.j , 0. +0.j ], + [-2.5199926 -6.8165474j , 1.3295308 +0.11109254j, + 6.8705287 +0.j , 0. +0.j ], + [-3.9951606 +2.3352869j , 1.5378516 -2.2458322j , + 0.04088954+1.0612036j , 7.3028345 +0.j ]], + dtype=complex64),), + mlir_module_text=r""" +#loc1 = loc("x") +module @jit_cholesky attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<4x4xcomplex> loc("x")) -> (tensor<4x4xcomplex> {jax.result_info = "result"}) { + %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc) + %cst_0 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc4) + %c = stablehlo.constant dense<0> : tensor loc(#loc) + %cst_1 = stablehlo.constant dense<(2.000000e+00,0.000000e+00)> : tensor> loc(#loc) + %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<4x4xcomplex>) -> tensor<4x4xcomplex> loc(#loc5) + %1 = stablehlo.real %0 : (tensor<4x4xcomplex>) -> tensor<4x4xf32> loc(#loc6) + %2 = stablehlo.imag %0 : (tensor<4x4xcomplex>) -> tensor<4x4xf32> loc(#loc6) + %3 = stablehlo.negate %2 : tensor<4x4xf32> loc(#loc6) + %4 = stablehlo.complex %1, %3 : tensor<4x4xcomplex> loc(#loc6) + %5 = stablehlo.add %arg0, %4 : tensor<4x4xcomplex> loc(#loc7) + %6 = stablehlo.broadcast_in_dim %cst_1, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc8) + %7 = stablehlo.divide %5, %6 : tensor<4x4xcomplex> loc(#loc8) + %8:2 = stablehlo.custom_call @cusolver_potrf_ffi(%7) {mhlo.backend_config = {lower = true}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>], sdy.sharding_rule = #sdy.op_sharding_rule<([i, j])->([k, l], []) {i=4, j=4, k=4, l=4}, custom>} : (tensor<4x4xcomplex>) -> (tensor<4x4xcomplex>, tensor) loc(#loc4) + %9 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor loc(#loc4) + %10 = stablehlo.compare EQ, %8#1, %9, SIGNED : (tensor, tensor) -> tensor loc(#loc4) + %11 = stablehlo.broadcast_in_dim %10, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc4) + %12 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc4) + %13 = stablehlo.broadcast_in_dim %11, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc4) + %14 = stablehlo.select %13, %8#0, %12 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc4) + %15 = stablehlo.iota dim = 0 : tensor<4x4xi32> loc(#loc9) + %16 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<4x4xi32> loc(#loc7) + %17 = stablehlo.add %15, %16 : tensor<4x4xi32> loc(#loc7) + %18 = stablehlo.iota dim = 1 : tensor<4x4xi32> loc(#loc9) + %19 = stablehlo.compare GE, %17, %18, SIGNED : (tensor<4x4xi32>, tensor<4x4xi32>) -> tensor<4x4xi1> loc(#loc10) + %20 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc11) + %21 = stablehlo.select %19, %14, %20 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc12) + return %21 : tensor<4x4xcomplex> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":225:6) +#loc3 = loc("jit(cholesky)"(#loc2)) +#loc4 = loc("cholesky"(#loc3)) +#loc5 = loc("transpose"(#loc3)) +#loc6 = loc(""(#loc3)) +#loc7 = loc("add"(#loc3)) +#loc8 = loc("div"(#loc3)) +#loc9 = loc("iota"(#loc3)) +#loc10 = loc("ge"(#loc3)) +#loc11 = loc("broadcast_in_dim"(#loc3)) +#loc12 = loc("select_n"(#loc3)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.12.1\x00\x013\x07\x01\x05\t!\x01\x03\x0f\x03\x1f\x13\x17\x1b\x1f#'+/37;?CGK\x03\xe7\xa5+\x01I\x0f\x0f\x07\x0f\x0f\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0b\x0b\x17\x0b\x0f\x0b\x0b\x0b\x0b#\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x03M\x0fO\x0b\x0f\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b//\x1f/O\x13\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17\x13\x0f\x0bO\x0f\x0f\x0b\x05\x11C\x13\x0f\x0f\x13\x0f\x0f\x0b\x01\x05\x0b\x0f\x03'\x17\x0f\x0f\x07\x17\x17\x07\x0b\x07\x07\x13\x07\x17\x17\x13\x13\x13\x0f\x17\x02\xfa\x05\x1d#%\x1d!\x01\x1f\x1d-\x01\x1d/\x01\x11\x03\x05\x1d1\x01\x1d;\x01\x03\x07\x13\x15\x17\x0b\x19\x0b\x05'\x11\x01\x00\x05)\x05+\x05-\x1d\x1f\x05\x05/\x051\x053\x17'\x86\x03\r\x055\x1d+\x01\x057\x059\x05;\x05=\x03\x075k7q9\x95\x05?\x05A\x05C\x05E\x1d?\x01\x05G\x1dC\x01\x05I\x1dG\x01\x05K\x1f!\x01\x1f#!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\t\x07\x03\x03Q\r\x01#\x1f\x03\x03W\r\x03Y[\x1dM\x1dO\x1dQ\x1dS\x1f\x07\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f\t\t\x00\x00\x00\x00\x1f\x07\x11\x00\x00\x00@\x00\x00\x00\x00\x1f\x19!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\r\x03mo\x1dU\x05\x03\r\x03su\x1dW\x1dY\x0b\x03\x1d9\x1d[\x03\x01\x05\x01\x03\x03K\x03\x03\x85\x15\x03\x01\x01\x01\x03\x05K\x89\x1f%\x01\x07\x01\x1f\x19!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x13\x0b\x01\x13\x0b\x05\x07\x05\x15\t\x11\x11\x11\x11\x03\x97\x05\x9d\xa3\x01\x01\x01\x01\x01\x13\x05\x99\x9b\x11\x03\x01\x11\x03\x05\x13\x05\x9f\xa1\x11\x03\t\x11\x03\r\x13\x01\x01\t\x01\x02\x02)\x05\x11\x11\x13)\x01\x13)\x01\x17\x1d)\x05\x11\x11\x17)\x05\x11\x11\x15\x01\x03\x15\t\x1b)\x03\t\x0b\x13)\x05\x11\x11\x11\x11\x03\x05\x03\x05)\x03\x01\x0b)\x03\t\x1b)\x03\x01\x1b)\x01\x11)\x05\x05\x05\x11\x04\xa2\x03\x05\x01Q\x05\x11\x01\x07\x04z\x03\x03\x01\x05\x0fP\x05\x03\x07\x04N\x03\x039o\x03\x0b\x1d\x00\x05B\x05\x05\x03\x07\x05B\x03\x07\x03\x07\x05B\x05\t\x03\t\x05B\x05\x0b\x03\x07\x11F)\r\x03\x05\x03\x01\x13\x06\x07\x03\x0f\x03\x0b\x15\x06\x07\x03\x0f\x03\x0b\x17\x06\x07\x03\x0f\x03\x0f\x19\x06\x07\x03\x05\x05\r\x11\x07\x06\t\x03\x05\x05\x01\x13\x03F\r\x0f\x03\x05\x03\t\x1b\x06\r\x03\x05\x05\x15\x17\x1dG\x033\x11\x05\x05\t\x03\x19\x03F\x03\x0f\x03\t\x03\x07\tF\x03\x13\x03'\x05\x1d\x1f\x03F\x03\x0f\x03)\x03!\x03F\x03\x0f\x03\x05\x03\x05\x03F\x03\x15\x03\x1d\x03#\x0b\x06\x03\x03\x05\x07'\x1b%\rB\x0f\x17\x03\r\x03F\t\x0f\x03\r\x03\x07\x07\x06\t\x03\r\x05+-\rB\x0f\x19\x03\r\tF=\x1b\x03\x1d\x05/1\x03FA\x0f\x03\x05\x03\x03\x0b\x06E\x03\x05\x073)5\x1f\x04\x05\x037\x06\x03\x01\x05\x01\x00\x16\t]'\x05\x1f\r\x0f\x0b\x0f!\x13#\x07\x0b%3)\t\t\x03\x15i\x1d\x13\x05\x1b%)9\x15\x1f\x15\x17\x15\x11\x11\x1b\x11\x11\x15\x17\x0f\x19)\x0f\t\x0b\x11builtin\x00vhlo\x00sdy\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00add_v1\x00compare_v1\x00select_v1\x00iota_v1\x00func_v1\x00transpose_v1\x00real_v1\x00imag_v1\x00negate_v1\x00complex_v1\x00divide_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_cholesky\x00x\x00cholesky\x00jit(cholesky)\x00third_party/py/jax/tests/export_back_compat_test.py\x00transpose\x00\x00add\x00div\x00mhlo.backend_config\x00mhlo.frontend_attributes\x00sdy.sharding_rule\x00iota\x00ge\x00broadcast_in_dim\x00select_n\x00jax.result_info\x00result\x00main\x00public\x00lower\x00num_batch_dims\x000\x00cusolver_potrf_ffi\x00\x08W\x1d\x057\x01\x0bOSU]_\x03a\x03c\x03e\x03g\x03i\x03I\x11wy{}\x7f\x81\x83\x87\x05M\x8b\x03\x8d\x03\x8f\x03\x91\x05M\x93", + xla_call_module_version=10, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2025_10_15["c128"] = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['cusolver_potrf_ffi'], + serialized_date=datetime.date(2025, 10, 15), + inputs=(array([[145.98892137813692 +0.j , + 40.91401793296874 +4.175781485327598j, + 2.782341135635754 +9.885118329189208j, + 26.733955883991726 +52.65661439791964j ], + [ 40.91401793296874 -4.175781485327598j, + 33.78265769398051 +0.j , + 3.6132624138937786 +5.213542211682853j, + 4.589810669550227 -12.339958092149333j], + [ 2.782341135635754 -9.885118329189208j, + 3.6132624138937786 -5.213542211682853j, + 93.29157057865525 +0.j , + -0.20930676609536647-34.05459375322113j ], + [ 26.733955883991726 -52.65661439791964j , + 4.589810669550227 +12.339958092149333j, + -0.20930676609536647+34.05459375322113j , + 78.03147614159427 +0.j ]]),), + expected_outputs=(array([[12.082587528263014 +0.j , + 0. +0.j , + 0. +0.j , + 0. +0.j ], + [ 3.386196693155718 -0.3456032472812474j, + 4.711357346318622 +0.j , + 0. +0.j , + 0. +0.j ], + [ 0.2302769277795368 -0.8181292546870784j, + 0.5414047646829693 -0.5354677863425775j, + 9.59110852656672 +0.j , + 0. +0.j ], + [ 2.2126018803055993 -4.358057764923928j , + -0.935750165427917 +5.589157126898457j , + -0.08182988296307511+3.2032822133798207j, + 3.4294580838507387 +0.j ]]),), + mlir_module_text=r""" +#loc1 = loc("x") +module @jit_cholesky attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<4x4xcomplex> loc("x")) -> (tensor<4x4xcomplex> {jax.result_info = "result"}) { + %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc) + %cst_0 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc4) + %c = stablehlo.constant dense<0> : tensor loc(#loc) + %cst_1 = stablehlo.constant dense<(2.000000e+00,0.000000e+00)> : tensor> loc(#loc) + %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<4x4xcomplex>) -> tensor<4x4xcomplex> loc(#loc5) + %1 = stablehlo.real %0 : (tensor<4x4xcomplex>) -> tensor<4x4xf64> loc(#loc6) + %2 = stablehlo.imag %0 : (tensor<4x4xcomplex>) -> tensor<4x4xf64> loc(#loc6) + %3 = stablehlo.negate %2 : tensor<4x4xf64> loc(#loc6) + %4 = stablehlo.complex %1, %3 : tensor<4x4xcomplex> loc(#loc6) + %5 = stablehlo.add %arg0, %4 : tensor<4x4xcomplex> loc(#loc7) + %6 = stablehlo.broadcast_in_dim %cst_1, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc8) + %7 = stablehlo.divide %5, %6 : tensor<4x4xcomplex> loc(#loc8) + %8:2 = stablehlo.custom_call @cusolver_potrf_ffi(%7) {mhlo.backend_config = {lower = true}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>], sdy.sharding_rule = #sdy.op_sharding_rule<([i, j])->([k, l], []) {i=4, j=4, k=4, l=4}, custom>} : (tensor<4x4xcomplex>) -> (tensor<4x4xcomplex>, tensor) loc(#loc4) + %9 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor loc(#loc4) + %10 = stablehlo.compare EQ, %8#1, %9, SIGNED : (tensor, tensor) -> tensor loc(#loc4) + %11 = stablehlo.broadcast_in_dim %10, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc4) + %12 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc4) + %13 = stablehlo.broadcast_in_dim %11, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc4) + %14 = stablehlo.select %13, %8#0, %12 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc4) + %15 = stablehlo.iota dim = 0 : tensor<4x4xi32> loc(#loc9) + %16 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<4x4xi32> loc(#loc7) + %17 = stablehlo.add %15, %16 : tensor<4x4xi32> loc(#loc7) + %18 = stablehlo.iota dim = 1 : tensor<4x4xi32> loc(#loc9) + %19 = stablehlo.compare GE, %17, %18, SIGNED : (tensor<4x4xi32>, tensor<4x4xi32>) -> tensor<4x4xi1> loc(#loc10) + %20 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc11) + %21 = stablehlo.select %19, %14, %20 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc12) + return %21 : tensor<4x4xcomplex> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":225:6) +#loc3 = loc("jit(cholesky)"(#loc2)) +#loc4 = loc("cholesky"(#loc3)) +#loc5 = loc("transpose"(#loc3)) +#loc6 = loc(""(#loc3)) +#loc7 = loc("add"(#loc3)) +#loc8 = loc("div"(#loc3)) +#loc9 = loc("iota"(#loc3)) +#loc10 = loc("ge"(#loc3)) +#loc11 = loc("broadcast_in_dim"(#loc3)) +#loc12 = loc("select_n"(#loc3)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.12.1\x00\x013\x07\x01\x05\t!\x01\x03\x0f\x03\x1f\x13\x17\x1b\x1f#'+/37;?CGK\x03\xe7\xa5+\x01I\x0f\x0f\x07\x0f\x0f\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0b\x0b\x17\x0b\x0f\x0b\x0b\x0b\x0b#\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x03M\x0fO\x0b\x0f\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0bOO\x1fOO\x13\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17\x13\x0f\x0bO\x0f\x0f\x0b\x05\x11C\x13\x0f\x0f\x13\x0f\x0f\x0b\x01\x05\x0b\x0f\x03'\x17\x0f\x0f\x07\x17\x17\x07\x0b\x07\x07\x13\x07\x17\x17\x13\x13\x13\x0f\x17\x02Z\x06\x1d#%\x1d!\x01\x1f\x1d-\x01\x1d/\x01\x11\x03\x05\x1d1\x01\x1d;\x01\x03\x07\x13\x15\x17\x0b\x19\x0b\x05'\x11\x01\x00\x05)\x05+\x05-\x1d\x1f\x05\x05/\x051\x053\x17'\x86\x03\r\x055\x1d+\x01\x057\x059\x05;\x05=\x03\x075k7q9\x95\x05?\x05A\x05C\x05E\x1d?\x01\x05G\x1dC\x01\x05I\x1dG\x01\x05K\x1f!\x01\x1f#!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\t\x07\x03\x03Q\r\x01#\x1f\x03\x03W\r\x03Y[\x1dM\x1dO\x1dQ\x1dS\x1f\x07!\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\t\t\x00\x00\x00\x00\x1f\x07!\x00\x00\x00\x00\x00\x00\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x19!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\r\x03mo\x1dU\x05\x03\r\x03su\x1dW\x1dY\x0b\x03\x1d9\x1d[\x03\x01\x05\x01\x03\x03K\x03\x03\x85\x15\x03\x01\x01\x01\x03\x05K\x89\x1f%\x01\x07\x01\x1f\x19!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x13\x0b\x01\x13\x0b\x05\x07\x05\x15\t\x11\x11\x11\x11\x03\x97\x05\x9d\xa3\x01\x01\x01\x01\x01\x13\x05\x99\x9b\x11\x03\x01\x11\x03\x05\x13\x05\x9f\xa1\x11\x03\t\x11\x03\r\x13\x01\x01\t\x01\x02\x02)\x05\x11\x11\x13)\x01\x13)\x01\x17\x1d)\x05\x11\x11\x17)\x05\x11\x11\x15\x01\x03\x15\x0b\x1b)\x03\t\x0b\x13)\x05\x11\x11\x11\x11\x03\x05\x03\x05)\x03\x01\x0b)\x03\t\x1b)\x03\x01\x1b)\x01\x11)\x05\x05\x05\x11\x04\xa2\x03\x05\x01Q\x05\x11\x01\x07\x04z\x03\x03\x01\x05\x0fP\x05\x03\x07\x04N\x03\x039o\x03\x0b\x1d\x00\x05B\x05\x05\x03\x07\x05B\x03\x07\x03\x07\x05B\x05\t\x03\t\x05B\x05\x0b\x03\x07\x11F)\r\x03\x05\x03\x01\x13\x06\x07\x03\x0f\x03\x0b\x15\x06\x07\x03\x0f\x03\x0b\x17\x06\x07\x03\x0f\x03\x0f\x19\x06\x07\x03\x05\x05\r\x11\x07\x06\t\x03\x05\x05\x01\x13\x03F\r\x0f\x03\x05\x03\t\x1b\x06\r\x03\x05\x05\x15\x17\x1dG\x033\x11\x05\x05\t\x03\x19\x03F\x03\x0f\x03\t\x03\x07\tF\x03\x13\x03'\x05\x1d\x1f\x03F\x03\x0f\x03)\x03!\x03F\x03\x0f\x03\x05\x03\x05\x03F\x03\x15\x03\x1d\x03#\x0b\x06\x03\x03\x05\x07'\x1b%\rB\x0f\x17\x03\r\x03F\t\x0f\x03\r\x03\x07\x07\x06\t\x03\r\x05+-\rB\x0f\x19\x03\r\tF=\x1b\x03\x1d\x05/1\x03FA\x0f\x03\x05\x03\x03\x0b\x06E\x03\x05\x073)5\x1f\x04\x05\x037\x06\x03\x01\x05\x01\x00\x16\t]'\x05\x1f\r\x0f\x0b\x0f!\x13#\x07\x0b%3)\t\t\x03\x15i\x1d\x13\x05\x1b%)9\x15\x1f\x15\x17\x15\x11\x11\x1b\x11\x11\x15\x17\x0f\x19)\x0f\t\x0b\x11builtin\x00vhlo\x00sdy\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00add_v1\x00compare_v1\x00select_v1\x00iota_v1\x00func_v1\x00transpose_v1\x00real_v1\x00imag_v1\x00negate_v1\x00complex_v1\x00divide_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_cholesky\x00x\x00cholesky\x00jit(cholesky)\x00third_party/py/jax/tests/export_back_compat_test.py\x00transpose\x00\x00add\x00div\x00mhlo.backend_config\x00mhlo.frontend_attributes\x00sdy.sharding_rule\x00iota\x00ge\x00broadcast_in_dim\x00select_n\x00jax.result_info\x00result\x00main\x00public\x00lower\x00num_batch_dims\x000\x00cusolver_potrf_ffi\x00\x08W\x1d\x057\x01\x0bOSU]_\x03a\x03c\x03e\x03g\x03i\x03I\x11wy{}\x7f\x81\x83\x87\x05M\x8b\x03\x8d\x03\x8f\x03\x91\x05M\x93", + xla_call_module_version=10, + nr_devices=1, +) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_eigh_cusolver_syev.py b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_eigh_cusolver_syev.py index 56479e82f9d9..aba2e4083d2c 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_eigh_cusolver_syev.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_eigh_cusolver_syev.py @@ -15,1400 +15,11 @@ # ruff: noqa import datetime -from numpy import array, float32, complex64 +import numpy as np -data_2023_03_17=dict( - # Pasted from the test output (see back_compat_test.py module docstring) - f32_syevj=dict( - testdata_version=1, - platform='cuda', - custom_call_targets=['cusolver_syevj'], - serialized_date=datetime.date(2023, 3, 17), - inputs=(), - expected_outputs=(array([[ 6.18577063e-01, -8.00570633e-05, -1.96905047e-01, - -8.95753130e-02, 7.24549413e-01, -1.07546024e-01, - -4.77200520e-04, 1.84469908e-01], - [ 4.70708847e-01, 3.31519186e-05, 2.80930042e-01, - -5.84393919e-01, -4.93098050e-01, -2.50211239e-01, - -1.14346610e-03, 2.28566617e-01], - [ 3.22840720e-01, -5.11042356e-01, -3.03526163e-01, - 2.48800799e-01, -3.14544559e-01, 5.54342926e-01, - 1.10838346e-06, 2.72663534e-01], - [ 1.74972475e-01, 4.18093473e-01, -2.66933769e-01, - 5.78716159e-01, -2.97307134e-01, -4.46864694e-01, - 1.09066934e-06, 3.16760242e-01], - [ 2.71042082e-02, 4.29418474e-01, 4.71952170e-01, - 1.10573582e-01, 9.57800150e-02, 4.65731144e-01, - -4.72866714e-01, 3.60856950e-01], - [-1.20763958e-01, -3.84347916e-01, 5.79687178e-01, - 2.87678182e-01, 1.63329691e-01, -2.02215970e-01, - 4.32829827e-01, 4.04953718e-01], - [-2.68632114e-01, 3.63640338e-01, -2.97110289e-01, - -3.32554609e-01, 3.46945561e-02, 2.77071655e-01, - 5.63131213e-01, 4.49050426e-01], - [-4.16500419e-01, -3.15715015e-01, -2.68094122e-01, - -2.19244853e-01, 8.65960941e-02, -2.90307850e-01, - -5.21475971e-01, 4.93147314e-01]], dtype=float32), array([-2.4598812e+01, -2.4345848e-06, -1.2664314e-06, -8.6959182e-07, - -8.2917722e-07, 1.6633214e-06, 2.0499781e-06, 2.7659885e+02], - dtype=float32)), - mlir_module_text=""" -module @jit__lambda_ { - func.func public @main() -> (tensor<8x8xf32> {jax.result_info = "[0]"}, tensor<8xf32> {jax.result_info = "[1]"}) { - %0 = stablehlo.iota dim = 0 : tensor<64xf32> - %1 = stablehlo.reshape %0 : (tensor<64xf32>) -> tensor<8x8xf32> - %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<8x8xf32>) -> tensor<8x8xf32> - %3 = stablehlo.add %1, %2 : tensor<8x8xf32> - %4 = stablehlo.constant dense<2.000000e+00> : tensor - %5 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<8x8xf32> - %6 = stablehlo.divide %3, %5 : tensor<8x8xf32> - %7 = call @tril(%6) : (tensor<8x8xf32>) -> tensor<8x8xf32> - %8 = stablehlo.custom_call @cusolver_syevj(%7) {api_version = 2 : i32, backend_config = "\00\00\00\00\00\00\00\00\01\00\00\00\08\00\00\00M\08\00\00", operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor<8x8xf32>) -> tuple, tensor<8xf32>, tensor, tensor<2125xf32>> - %9 = stablehlo.get_tuple_element %8[0] : (tuple, tensor<8xf32>, tensor, tensor<2125xf32>>) -> tensor<8x8xf32> - %10 = stablehlo.get_tuple_element %8[1] : (tuple, tensor<8xf32>, tensor, tensor<2125xf32>>) -> tensor<8xf32> - %11 = stablehlo.get_tuple_element %8[2] : (tuple, tensor<8xf32>, tensor, tensor<2125xf32>>) -> tensor - %12 = stablehlo.get_tuple_element %8[3] : (tuple, tensor<8xf32>, tensor, tensor<2125xf32>>) -> tensor<2125xf32> - %13 = stablehlo.constant dense<0> : tensor - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor - %15 = stablehlo.compare EQ, %11, %14, SIGNED : (tensor, tensor) -> tensor - %16 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor) -> tensor<1x1xi1> - %17 = stablehlo.constant dense<0x7FC00000> : tensor - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor) -> tensor<8x8xf32> - %19 = stablehlo.broadcast_in_dim %16, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<8x8xi1> - %20 = stablehlo.select %19, %9, %18 : tensor<8x8xi1>, tensor<8x8xf32> - %21 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor) -> tensor<1xi1> - %22 = stablehlo.constant dense<0x7FC00000> : tensor - %23 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor) -> tensor<8xf32> - %24 = stablehlo.broadcast_in_dim %21, dims = [0] : (tensor<1xi1>) -> tensor<8xi1> - %25 = stablehlo.select %24, %10, %23 : tensor<8xi1>, tensor<8xf32> - return %20, %25 : tensor<8x8xf32>, tensor<8xf32> - } - func.func private @tril(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { - %0 = stablehlo.iota dim = 0 : tensor<8x8xi32> - %1 = stablehlo.constant dense<0> : tensor - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<8x8xi32> - %3 = stablehlo.add %0, %2 : tensor<8x8xi32> - %4 = stablehlo.iota dim = 1 : tensor<8x8xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<8x8xi32>, tensor<8x8xi32>) -> tensor<8x8xi1> - %6 = stablehlo.constant dense<0.000000e+00> : tensor - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<8x8xf32> - %8 = stablehlo.select %5, %arg0, %7 : tensor<8x8xi1>, tensor<8x8xf32> - return %8 : tensor<8x8xf32> - } -} -""", - mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01-\x05\x01\x05\x01\x03\x05\x03\x1d\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!\x03^\x02\xeb5\x01\x95\x0f\x17\x13\x07\x0f\x0b\x0b\x0b\x0b\x0b\x17\x0b\x0b\x0b\x0b\x13\x0b\x13\x0f\x0b\x0b\x17\x0f\x13\x13\x0b33\x0b\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x13\x0bK\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x1b\x13\x13\x03W\x0b\x0b\x0f\x0b\x0bO/\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b\x1fO\x1f\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1b\x0f\x0f\x0f\x0f\x0f\x0b\x1fO/\x035\x17\x0f\x07\x0f\x07\x13\x07\x07\x17\x07\x17\x13\x17\x17\x17\x13\x17\x1b\x13\x13\x13\x0f\x17\x13\x13\x13\x02r\x08\x1d\x85\x03\x17\x116\x04\x01\x03\x03\x13\xbd\x1f\x1d9\x03\x05#\x05%\x05'\x05)\x05+\x17\x112\x04\x01\x05-\x05/\x051\x053\x03\x03!\xb9\x055\x03\x03\x0b\xbb\x1d?\x03\x057\x059\x17\x11*\x04\x01\x1dm\x15\x03\x03\x0b\xe5\x03\x03\x0f3\x05;\x03\x0b\x17\x95\x19\xa3\x1b\xa5\x0f\xaf\x1d\xb1\x03\x0b\x17\x99\x19\xb5\x1b\x99\x0f\x9b\x1d\xb7\x05=\x1d=\x03\x05?\x05A\x03\x03!\xbf\x1dE\x03\x05C\x03\x05'\x9d)\xc1\x1dK\x03\x05E\x03\x03\x0b\xc3\x1dQ\x03\x05G\x1dU\x03\x05I\x1dY+\x05K\x1d]+\x05M\x03\x03a\xc5\x05O\x1de\x15\x05Q\x1di\x15\x05S\x03\x03\x0b\xc7\x05U\x03\x03q\x9b\x05W\x03\x11u\xc9w\xcby\xcd{\x95}\xcf\x7f\xd1\x81\xd3\x83\xd7\x05Y\x05[\x05]\x05_\x05a\x05c\x05e\x05g\x05i\x03\x03\r\xdb\x03\x03\r\xdd\x03\x03\r\xdf\x03\x03\r\xe1\x03\x05'\x9d)\xe3\x03\x03\x13\xe7\x03\x03\x13\xe9\x03\x01\x1dk\x03\x03\xb3\x1dm\t\x07\x1f%!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f'\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x1b\x03\x05\xa7\xab\r\x03\x97\xa9\x1do\r\x03\x97\xad\x1dq\x1ds\x1du\r\x01#\x1d\x1dw\x13\r\x01\x1f\x07\t\x00\x00\x00\x00\x1f\x1f\x01\x13\r\x05\x07\x05\x1f\x03\t\x00\x00\x00\x00\x1f\x17!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x03\t\x00\x00\x00@\x0b\x05\x1dy\x1d{\x05\x01\x03\x03\x9f\x03\x03\xd5\x15\x03\x01\x01\x01\x03\t\x9f\xa1\xd9\xa1\x1f)\x01\x13\x05\x01\x13\x05\x05\x13\x05\t\x13\x05\r\x07\x01\x1f\x03\t\x00\x00\xc0\x7f\x1f\x17!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00)\x05!!\t)\x01\t\x1b)\x01\x05\t)\x03!\t\x1d\x01)\x05!!\x05\x13)\x05!!\x0f)\x03\t\r)\x03jB\t\x11\x01\x05\x01\x0b\x11\x03\x01\x03\x01)\x03\x01\r)\x03\x02\x02\t/\t\x01\x0b\x07\x19)\x03\t\x13)\x03\x05\x13)\x03\x01\x13)\x01\x0f)\x05\x05\x05\x0f)\x03\x05\x0f)\x03!\x0f)\x03\x05\r\x04\xc6\x04\x05\x01\x11\x071\x07\x03\x01\t\r\x11\x075\x05\x035m\t\x03W\x1f\x03!\x15\x06[\x03\x01\x03\x01\x17\x07c_\x03\x01\x03\x03\x0f\x06g\x03\x01\x05\x03\x05\x05\x03\x07k\x03\x03\x03\x07-\x05\x03\x01\x03\t\x19\x06-\x03\x01\x05\x07\x0b\x1b\x07\to\x03\x01\x03\r\x1d\x07\x01s\x03#\x03\x0f\x07\x07\x01\x87\x03\x01\x03\x11\x07\x07\x01\x89\x03\x0b\x03\x11\x07\x07\x01\x8b\x03\x07\x03\x11\x07\x07\x01\x8d\x03\x19\x03\x11\x05\x03\x01#\x03\x07\x03\x07\x01\x05\x03\x07\x03\x1b\x11\x07\x01\x8f\x03+\x05\x17\x1d\x03\x07\x01\x05\x03-\x03\x1f\x05\x03\x01/\x03\x03\x03\x07\x01\x05\x03\x01\x03#\x03\x07\x01\x91\x03\x15\x03!\x0b\x06\x01\x03\x01\x07'\x13%\x03\x07\x01\x05\x03/\x03\x1f\x05\x03\x01/\x03\x03\x03\x07\x01\x05\x03\x0b\x03-\x03\x07\x01\x93\x031\x03+\x0b\x06\x01\x03\x0b\x071\x15/\x13\x04\x07\x05)3\r\x11\t7\x05\x03\x15+\x03\x01\x07\t\x03;\x1f\x03\x11\x05\x03\t#\x03\x07\x03\x07%\x05\x03\x11\x03\x05\x0f\x06%\x03\x11\x05\x03\x07\t\x03CA\x03\x11\x11\x07IG\x03\x15\x05\t\x0b\x05\x03\tM\x03\x03\x03\x07O\x05\x03\x01\x03\x0f\x0b\x06S\x03\x01\x07\r\x01\x11\x13\x04\t\x03\x13\x06\x03\x01\x05\x01\x00\x06\x1a}\x1f+\x11\x0f\x0b\t\t\x0b!\x7f\x1f/!!)#\x1f\x19\x0f99m\x19\x85\x89W\xb3K\x9bM\x9b\x96\x04\x1b+\x1b\x1f\x1f\x15\x1d\x15+\x83\x13\r\r\x1f\x11\x15\x1b\x17\x15\x17\x0f\x11\x15\x11+\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00get_tuple_element_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00value\x00index\x00sym_name\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00compare_type\x00comparison_direction\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril in_positional_semantics=(<_PositionalSemantics.GLOBAL: 1>,) out_positional_semantics=_PositionalSemantics.GLOBAL keep_unused=False inline=False]\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota[dtype=float32 shape=(64,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]\x00permutation\x00jit()/jit(main)/transpose[permutation=(1, 0)]\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00callee\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit()/jit(main)/eigh[lower=True sort_eigenvalues=True]\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x08\x00\x00\x00M\x08\x00\x00\x00cusolver_syevj\x00", - xla_call_module_version=4, - ), # End paste - - # Pasted from the test output (see back_compat_test.py module docstring) - f32_syevd=dict( - testdata_version=1, - platform='cuda', - custom_call_targets=['cusolver_syevd'], - serialized_date=datetime.date(2023, 3, 17), - inputs=(), - expected_outputs=(array([[ 3.14863890e-01, 0.00000000e+00, 0.00000000e+00, - 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, - 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, - 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, - 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, - 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, - 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, - -4.91220355e-01, 0.00000000e+00, 0.00000000e+00, - 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, - 0.00000000e+00, 0.00000000e+00, 8.05416584e-01, - 0.00000000e+00, -1.77893345e-03, -2.64500137e-02, - 1.46598322e-04, -5.19353598e-02, -8.64148438e-02], - [ 2.99391806e-01, 2.77544819e-02, 6.73292065e-03, - -6.83086272e-03, -3.54272849e-03, -1.21014733e-02, - -1.32716037e-02, -1.15843862e-03, -8.83520208e-03, - -6.63395738e-03, 1.60171092e-03, -1.01765711e-03, - 1.19860061e-02, -1.33239310e-02, 1.76237477e-03, - 1.27085261e-02, 3.38556734e-03, -8.78101215e-03, - 1.58616400e-03, -7.37631368e-03, 3.81911686e-03, - -5.18379211e-02, -7.22059654e-03, 1.85085051e-02, - 2.94725411e-03, 4.74284729e-03, -1.33781182e-02, - -3.61499190e-03, -5.49228955e-03, -1.05845921e-01, - 1.01772454e-02, 4.47412670e-01, 1.95654288e-01, - 3.94686669e-01, 7.00925171e-01, -9.06614065e-02], - [ 2.83920437e-01, 1.69272088e-02, 6.64264262e-02, - -1.18565477e-01, 3.54601629e-02, -1.52457461e-01, - 6.84847543e-03, 1.90414500e-03, -2.76310533e-01, - 3.76881436e-02, 1.22269124e-01, -1.01556584e-01, - -1.90264836e-01, -1.16590485e-01, 6.09031200e-01, - -9.43092555e-02, -3.74726858e-03, -2.33182713e-01, - 1.95203945e-01, -1.20613754e-01, 3.94887812e-02, - -5.88066364e-03, 1.19152360e-01, -1.46030456e-01, - -4.74781469e-02, 2.67041594e-01, -1.22617789e-01, - 5.77996820e-02, 2.58437768e-02, -1.34434626e-01, - -3.28330845e-02, -9.32494774e-02, 1.14714004e-01, - 1.21207587e-01, -2.04871535e-01, -9.49072391e-02], - [ 2.68448830e-01, 2.17946004e-02, -1.94895901e-02, - 3.40374447e-02, 6.18659109e-02, 1.72068894e-01, - -8.02555401e-03, 9.68076065e-02, 4.98391055e-02, - 5.55528253e-02, -3.23998183e-02, -2.63249427e-01, - -4.35045222e-03, 5.20016700e-02, -5.92328422e-02, - 4.31317724e-02, -2.00986061e-02, -2.69871447e-02, - 1.54309347e-01, 1.74670279e-01, -4.97168908e-03, - -4.15510803e-01, -4.33471389e-02, -3.71299796e-02, - 5.26434295e-02, -1.18867345e-01, -2.42547281e-02, - -3.90263759e-02, -2.58720964e-01, -3.92957211e-01, - -1.28192365e-01, 2.77028710e-01, -4.02157485e-01, - -1.77024350e-01, -1.76668167e-01, -9.91534367e-02], - [ 2.52977222e-01, 3.48518007e-02, 7.02044442e-02, - 1.42712081e-02, 4.50692251e-02, 7.16193160e-03, - 1.19931757e-01, 2.32399218e-02, -6.05047755e-02, - 1.06077030e-01, 1.03731848e-01, -1.13200452e-02, - 5.94755262e-03, -2.32813850e-01, 8.72232541e-02, - 8.17264095e-02, 3.30835059e-02, 4.88227099e-01, - 6.14454560e-02, 1.43805355e-01, -7.40422234e-02, - 2.25823849e-01, -3.86487693e-01, 1.30468249e-01, - 3.16427708e-01, -1.19733319e-01, -4.18486483e-02, - -2.74667948e-01, -2.16731444e-01, 2.60375626e-02, - 5.77645637e-02, -7.56322592e-02, 2.28632554e-01, - 2.37157010e-02, -1.40153974e-01, -1.03399649e-01], - [ 2.37505659e-01, 7.01064467e-02, -3.83728333e-02, - 5.06979637e-02, 1.83892641e-02, 4.02548499e-02, - -3.88330072e-02, 3.13181393e-02, -5.75652197e-02, - 7.04995319e-02, -6.92743529e-03, -9.82947052e-02, - -4.91717793e-02, 4.06844541e-02, -1.53035461e-03, - 4.68783826e-02, 5.36918640e-03, -1.67432979e-01, - 1.03467651e-01, 3.48554403e-02, 3.20128165e-02, - 4.70223904e-01, 9.19904634e-02, 6.90946281e-02, - -6.94891065e-02, 3.92344594e-02, -6.30731881e-02, - 2.22810470e-02, -3.87494615e-03, 1.96694940e-01, - -1.92701817e-02, 2.01028123e-01, 1.89283062e-02, - -6.97807550e-01, 2.03354478e-01, -1.07645869e-01], - [ 2.22034067e-01, -1.60748392e-01, 2.42968962e-01, - -3.35482806e-01, -3.41870189e-02, 1.28819138e-01, - 1.24212839e-01, -3.87125909e-02, -5.60933471e-01, - 7.95257688e-02, -3.60307507e-02, 3.67332071e-01, - -5.87672107e-02, 7.33083040e-02, -3.94398779e-01, - -7.60597512e-02, 1.71925854e-02, 1.17799109e-02, - -2.65986789e-02, 1.98394638e-02, -1.35528380e-02, - -3.39059532e-02, 9.92002785e-02, -7.92167559e-02, - 9.19176906e-04, -4.89958897e-02, 5.72972372e-02, - 1.21006947e-02, 4.03640568e-02, -1.18844979e-01, - -2.80744191e-02, -1.74218431e-01, -4.31395955e-02, - -6.09265082e-02, 3.76862884e-02, -1.11892074e-01], - [ 2.06562474e-01, 1.73960440e-02, -2.63249487e-01, - 1.38902217e-01, -4.79032584e-02, -2.24852517e-01, - 4.69521992e-02, -3.35566737e-02, 1.37603536e-01, - -5.11448458e-02, 8.18398222e-02, 1.07205749e-01, - -1.46739393e-01, -1.30916521e-01, -2.28276670e-01, - -7.91462511e-02, 6.24803789e-02, 4.59876209e-02, - 8.15130547e-02, 1.46908918e-02, -2.61019613e-03, - 1.13239333e-01, 2.98404664e-01, -1.80148214e-01, - 1.44556239e-01, -3.98542970e-01, -4.15323582e-03, - 4.42554235e-01, 4.46505845e-02, -3.50878686e-02, - -1.36736231e-02, 1.28197059e-01, 1.92225441e-01, - 9.25138816e-02, -2.71676213e-01, -1.16138257e-01], - [ 1.91090912e-01, -3.68523598e-02, -6.60930753e-01, - 3.02158773e-01, 1.77503861e-02, 1.00428194e-01, - -1.10393446e-02, 9.11340117e-03, -7.01573640e-02, - -3.42316413e-03, -7.93189174e-05, 2.59178817e-01, - 1.22925844e-02, 6.14976510e-02, -1.56667307e-01, - -5.03374226e-02, -4.95696850e-02, -1.59401018e-02, - -4.26767953e-02, -5.12050986e-02, -6.04047906e-03, - 5.44762500e-02, -1.07276395e-01, -1.12534806e-01, - -1.20743208e-01, 3.80993217e-01, -2.20808387e-02, - -2.89817184e-01, 3.23761255e-02, -6.17432930e-02, - -3.90686616e-02, -5.96804358e-02, -4.96021062e-02, - 8.57739672e-02, -8.64073634e-02, -1.20384485e-01], - [ 1.75619304e-01, 1.71932317e-02, 4.29833472e-01, - 8.81271958e-02, -3.94745134e-02, -5.61874844e-02, - 7.05854744e-02, 7.86138419e-03, 4.67237175e-01, - -1.88353360e-02, 6.92435876e-02, -3.38627174e-02, - -8.19625556e-02, -4.84902970e-02, -2.62022078e-01, - -1.48765266e-01, 7.19114691e-02, -1.21600203e-01, - 1.18209779e-01, 2.58331411e-02, 4.69931588e-02, - 9.96347591e-02, 2.32059956e-01, -1.78489253e-01, - 1.77511200e-03, 1.59484446e-01, 3.28991674e-02, - -4.70239580e-01, 1.65105104e-01, -2.61324756e-02, - -1.49319443e-04, -8.15570727e-02, 7.44131976e-05, - 8.14437792e-02, -7.25714415e-02, -1.24630690e-01], - [ 1.60147712e-01, -1.10780589e-01, 2.73144871e-01, - 1.10703602e-01, 2.37337053e-02, 4.52041216e-02, - 1.52682560e-02, -3.83009948e-02, 2.30164632e-01, - 2.54375394e-02, -3.03758867e-02, 8.13979190e-03, - 2.33282149e-02, 3.12441736e-02, -1.84844747e-01, - 2.14728359e-02, -5.53616770e-02, -2.22909674e-02, - -9.31906551e-02, -1.01961263e-01, -3.32283713e-02, - 8.18983093e-02, -3.90430242e-01, 1.43959653e-02, - -1.31596243e-02, 4.55893874e-01, -4.22518775e-02, - 5.82709551e-01, -1.36653170e-01, -3.07889320e-02, - -4.67781313e-02, -6.33331314e-02, -5.06754033e-03, - 3.76623571e-02, -6.18892610e-02, -1.28876895e-01], - [ 1.44676119e-01, -2.91557442e-02, 2.55934417e-01, - 5.66692650e-01, 3.84408869e-02, 1.04354315e-01, - -1.37322113e-01, 7.15484237e-03, -1.95520781e-02, - -2.59401686e-02, -9.82144028e-02, 2.44248882e-01, - 1.52861271e-02, 1.99174404e-01, 2.76121795e-01, - 8.94557908e-02, -1.24152258e-01, 6.37411512e-03, - -1.13803938e-01, -3.23315486e-02, -3.17632034e-02, - 2.70075332e-02, 2.75091957e-02, -4.90174480e-02, - -2.08239228e-01, -3.95830333e-01, -5.95310889e-02, - -4.46558185e-03, -7.16161057e-02, -4.99811508e-02, - -1.02262713e-01, -2.79212356e-01, -5.11405505e-02, - 2.62467805e-02, 1.03744328e-01, -1.33123115e-01], - [ 1.29204527e-01, -1.63312718e-01, -1.99243486e-01, - -2.34051406e-01, 3.55675933e-03, 1.56449080e-02, - 9.30304453e-02, -7.26388171e-02, 1.25461653e-01, - 1.20737530e-01, 4.42517921e-02, -4.18601990e-01, - -1.94645032e-01, 1.02710314e-01, -7.12260604e-02, - -6.79927021e-02, -3.08946688e-02, -8.88019353e-02, - -4.35314551e-02, -2.15784147e-01, -1.86102502e-02, - 5.49090989e-02, -3.75167191e-01, -8.20007622e-02, - -2.06737250e-01, -3.52603942e-01, -3.86392660e-02, - -2.84039471e-02, 2.83454835e-01, -2.61564963e-02, - 1.20758023e-02, -2.92337686e-01, -5.17344326e-02, - 3.77417319e-02, 1.23368390e-01, -1.37369320e-01], - [ 1.13732927e-01, -6.13378249e-02, -1.77854180e-01, - -4.99198377e-01, 2.01901477e-02, 1.41450047e-01, - -3.23677920e-02, 9.39797983e-03, 5.04098058e-01, - 1.23931216e-02, -8.47154856e-02, 3.81212860e-01, - 1.21610202e-01, 4.87964153e-02, 2.52459884e-01, - 1.51112108e-02, -4.74717468e-02, -1.84605867e-02, - -7.36073852e-02, 3.58235948e-02, -7.69592915e-03, - -7.00120777e-02, 1.28127992e-01, 4.49521616e-02, - -7.93955289e-04, -3.76549661e-02, 1.04670962e-02, - 7.88062997e-03, -2.23614484e-01, -9.32817012e-02, - -4.67354655e-02, -1.74636483e-01, 1.47633761e-01, - -1.42957285e-01, 7.11189136e-02, -1.41615525e-01], - [ 9.82613564e-02, -1.55768439e-01, -1.11842593e-02, - 6.37831986e-02, 5.79317398e-02, 3.34746271e-01, - 3.84975046e-01, -2.11655404e-02, -4.85437140e-02, - -4.50517267e-01, -3.28294598e-02, -2.49714255e-01, - 3.28522325e-01, -1.25372112e-01, 2.82705110e-02, - 1.42169207e-01, -8.04641694e-02, 6.62415996e-02, - -9.59652960e-02, -5.61193414e-02, -4.80792150e-02, - -4.04721648e-02, 2.45707080e-01, 2.35501617e-01, - -4.14447524e-02, 4.34486791e-02, -4.62412462e-02, - 4.26126681e-02, 2.55748153e-01, -7.83308148e-02, - 2.59090564e-03, -3.38329338e-02, 1.78729519e-01, - -3.09782606e-02, -8.34960043e-02, -1.45861730e-01], - [ 8.27897936e-02, -5.39819747e-02, 5.41151650e-02, - -2.87518036e-02, 1.98750496e-02, -1.58728033e-01, - -4.75713938e-01, 1.16178179e-02, -2.98879808e-03, - 2.26475924e-01, 2.46154964e-02, 1.24507852e-01, - 4.07826692e-01, -2.43859500e-01, 1.46053182e-02, - 8.78053382e-02, -7.19747171e-02, -4.02797535e-02, - -8.92022029e-02, -4.73439731e-02, 2.02829354e-02, - -9.01956186e-02, -1.16379023e-01, 1.02566876e-01, - 1.27621949e-01, -3.85584086e-02, -1.85301397e-02, - -1.46384817e-02, 5.42852879e-01, -1.11336805e-01, - -4.69652563e-02, 1.10105053e-01, -3.25540863e-02, - -9.18325037e-02, -1.09285243e-01, -1.50107935e-01], - [ 6.73181787e-02, -1.69579491e-01, -5.90509735e-02, - -8.87142718e-02, -4.61161807e-02, -1.32888526e-01, - -4.28256035e-01, -4.96512838e-02, -1.00748278e-01, - -1.56540096e-01, -1.33985683e-01, -3.31550747e-01, - 3.25447232e-01, 2.73245610e-02, -6.19893037e-02, - -1.48184791e-01, 1.88705355e-01, 1.62340149e-01, - 1.02853999e-01, 3.19841057e-01, -6.06105961e-02, - 1.69779122e-01, 1.54020518e-01, -8.75391066e-02, - -2.06520095e-01, 6.03866279e-02, 1.08508043e-01, - 4.56446186e-02, -2.30992153e-01, 6.16142601e-02, - 5.93037927e-04, -2.22505212e-01, -4.13618460e-02, - 1.47342280e-01, 4.37493399e-02, -1.54354155e-01], - [ 5.18466011e-02, 1.40082181e-01, -2.43853368e-02, - -9.01594944e-03, -2.02037729e-02, -2.15594158e-01, - -1.49669036e-01, -2.02583615e-02, 4.76960652e-03, - -4.28980350e-01, -2.16286242e-01, 2.93388069e-02, - -2.61512101e-01, 4.32281435e-01, 5.15976362e-02, - -2.38068718e-02, 1.35174215e-01, 1.65118262e-01, - 1.18229888e-01, -4.75422740e-02, -1.69874616e-02, - -9.87956077e-02, -6.16191179e-02, 1.92472130e-01, - 4.03664082e-01, 9.86855701e-02, 2.18016505e-02, - 9.58452746e-03, 2.42479756e-01, -9.45590809e-02, - 6.06411323e-02, -1.15035795e-01, -5.60823381e-02, - -1.10115618e-01, 7.84227401e-02, -1.58600360e-01], - [ 3.63750085e-02, 2.90070504e-01, 2.58655623e-02, - -4.51171659e-02, -9.76288766e-02, -7.32196262e-03, - 2.62665208e-02, -1.30719528e-01, -3.34864855e-02, - 1.83281839e-01, -2.03847468e-01, -7.86208585e-02, - 2.39961028e-01, 9.32282284e-02, -1.40201841e-02, - -1.65743440e-01, 2.50046160e-02, 1.87149823e-01, - -1.68221984e-02, -6.99453712e-01, 2.46135090e-02, - 9.76792276e-02, 1.59403309e-01, 1.05807781e-01, - -1.64897703e-02, -3.37719321e-02, 9.97098759e-02, - -5.71760125e-02, -2.09543109e-01, 1.61970984e-02, - 4.49959114e-02, 1.13044158e-01, -1.33089647e-01, - 6.79383874e-02, -1.17107280e-01, -1.62846550e-01], - [ 2.09034402e-02, 9.87452939e-02, 3.10002435e-02, - -3.82550769e-02, 6.49476936e-03, -1.86508909e-01, - -1.58566430e-01, 1.52609888e-02, 2.44785240e-03, - -1.72963649e-01, 2.82357018e-02, 6.35804012e-02, - -4.01134878e-01, -3.48292142e-01, -9.30772051e-02, - 2.69406252e-02, -1.48355186e-01, 6.67649359e-02, - -1.52495161e-01, -4.16254858e-03, -7.79623985e-02, - -8.69922712e-02, 1.67651065e-02, 4.43452805e-01, - -4.69122916e-01, 1.32700158e-02, 1.84264123e-01, - -4.69396599e-02, -8.76988843e-02, -8.42647329e-02, - 1.80242240e-01, 4.39915545e-02, -3.01284958e-02, - -4.19178084e-02, -6.55100867e-02, -1.67092770e-01], - [ 5.43184578e-03, -8.44964292e-03, 5.85759105e-03, - -7.32589066e-02, -6.53161779e-02, 1.58945680e-01, - -1.98484868e-01, -2.29594544e-01, -3.62942442e-02, - -4.60159145e-02, 4.65791941e-01, -1.32931456e-01, - -1.30874768e-01, 1.82594404e-01, 4.72868867e-02, - 7.68151507e-02, -1.17584936e-01, -7.83182383e-02, - -5.70569098e-01, 5.07849343e-02, -6.92476258e-02, - 1.45652056e-01, 1.57256410e-01, -2.92076059e-02, - 2.85284370e-01, 2.52744146e-02, 2.82830708e-02, - -5.04164398e-02, -1.00659683e-01, 5.86346574e-02, - 1.91001222e-02, 8.99196714e-02, -1.54763028e-01, - 1.01448707e-01, -7.42661506e-02, -1.71338975e-01], - [-1.00397598e-02, 6.89980984e-02, 5.02617331e-03, - -5.32203764e-02, 1.92967560e-02, -5.64105034e-01, - 3.46719325e-01, -7.40835667e-02, -5.14018210e-03, - 9.32325572e-02, 1.93343818e-01, 3.23573984e-02, - 2.21131876e-01, 3.06417048e-01, -8.70961323e-03, - 4.47171003e-01, 8.35162401e-02, 8.83740187e-02, - -8.72178078e-02, 1.18704282e-01, 1.05058528e-01, - -4.56921048e-02, 1.59751941e-02, -3.00876088e-02, - -2.47394085e-01, 4.93424907e-02, -6.64604902e-02, - -3.64027135e-02, -1.82686392e-02, -4.59523462e-02, - -1.26862470e-02, 2.52796169e-02, -4.81151454e-02, - -2.86283679e-02, -2.56162435e-02, -1.75585181e-01], - [-2.55113579e-02, 1.63476765e-02, -6.48622513e-02, - 8.53358284e-02, -1.47179626e-02, -2.74279952e-01, - 3.23813617e-01, 1.18787922e-01, -3.12188938e-02, - 1.27388835e-01, -1.47029653e-01, -6.44396339e-03, - 1.59717619e-01, -8.00469816e-02, 4.15628105e-02, - -3.71895492e-01, -2.58336008e-01, -3.58502686e-01, - -9.30814072e-02, 2.37474293e-01, -1.02323368e-01, - 7.77886510e-02, -2.62345857e-04, 3.05618107e-01, - 2.69323707e-01, -4.94645983e-02, 7.17321262e-02, - 1.81141701e-02, -7.26979673e-02, 3.66130173e-02, - 3.41478437e-02, -1.42837018e-01, -2.29302347e-01, - 9.40499976e-02, 9.85415503e-02, -1.79831386e-01], - [-4.09829244e-02, 2.96095997e-01, 5.72670512e-02, - -1.39296770e-01, -1.60581374e-03, 2.67294142e-02, - 5.13432994e-02, 3.44210893e-01, -4.88008671e-02, - -1.20673403e-01, -4.54095185e-01, -3.60888802e-02, - -3.48375738e-02, -3.80728357e-02, 6.19033575e-02, - 2.85812598e-02, -5.49174994e-02, 8.16437509e-03, - -3.89526159e-01, 1.42197743e-01, -6.57034442e-02, - 9.32944417e-02, -1.29381031e-01, -4.54968363e-01, - -7.63084590e-02, -1.27602285e-02, -3.93663906e-02, - -2.22954508e-02, 9.34363678e-02, 4.61584628e-02, - 1.17300354e-01, 1.84356645e-01, 4.64061499e-02, - 2.61230320e-02, -1.38632745e-01, -1.84077591e-01], - [-5.64545169e-02, -3.65092814e-01, -4.26685773e-02, - 1.75265297e-02, -1.79290678e-03, 7.54252076e-02, - -2.16403184e-03, 1.22491851e-01, 4.61655157e-03, - 9.93698239e-02, -2.86250204e-01, 1.17600495e-02, - -1.76643163e-01, -1.61555171e-01, 4.21675071e-02, - 4.96386349e-01, 2.84064054e-01, -1.88499331e-01, - 5.03461063e-02, -9.29289460e-02, 2.72047639e-01, - 1.54824242e-01, 7.62812719e-02, 9.09931362e-02, - 1.82046860e-01, -1.51961623e-02, 1.57171339e-01, - -2.52939817e-02, -6.88583925e-02, 8.74516144e-02, - 1.06507227e-01, 3.63174151e-03, -2.16592148e-01, - 1.95526704e-01, -2.63463091e-02, -1.88323811e-01], - [-7.19260871e-02, 1.53307199e-01, 2.98810583e-02, - -1.76042188e-02, 4.68952209e-02, 2.30930567e-01, - -1.91631261e-02, -3.50371659e-01, -1.39247498e-03, - -3.16982158e-02, 3.19441818e-02, 1.38011038e-01, - 1.15297228e-01, 1.21593997e-01, 1.12343794e-02, - -6.25559241e-02, 2.27593221e-02, -1.95765942e-01, - 2.61839062e-01, 1.88924655e-01, 1.47905156e-01, - 3.61047573e-02, -1.53986499e-01, 4.26004231e-02, - -1.01659156e-01, -9.87078920e-02, -1.97795078e-01, - 2.87956242e-02, 2.66166143e-02, 2.03926936e-02, - 6.36121154e-01, 1.17329828e-01, -1.68884546e-02, - 1.05052806e-01, -1.36004210e-01, -1.92570001e-01], - [-8.73977244e-02, 2.91939259e-01, -6.38535023e-02, - 1.23778999e-01, 2.33115517e-02, 8.99281502e-02, - -2.38235518e-02, 2.54457176e-01, -2.92873345e-02, - 1.45903289e-01, 2.51857221e-01, -1.22888424e-01, - 4.71667722e-02, -1.51163086e-01, -6.75680041e-02, - 1.34960130e-01, -5.27166612e-02, 5.85827529e-02, - 6.49949759e-02, -6.27990216e-02, 7.91215152e-02, - -2.11644500e-01, 1.25666901e-01, -2.19153777e-01, - 1.45102561e-01, 9.46507752e-02, 2.63710856e-01, - 1.36273995e-01, -2.85680946e-02, -9.64817554e-02, - 3.51572961e-01, -3.73799771e-01, 7.54300505e-02, - -1.52278930e-01, 2.77134597e-01, -1.96816236e-01], - [-1.02869295e-01, 4.54483837e-01, -3.16920318e-02, - -9.15080402e-03, 4.94015254e-02, 2.09832817e-01, - 9.22076330e-02, -3.92193407e-01, -1.33265834e-03, - 1.03313603e-01, -7.82989189e-02, 8.86598602e-03, - -9.18587223e-02, -1.70766622e-01, 5.54255210e-02, - 2.28601284e-02, 1.81634039e-01, 4.14796174e-02, - 3.81892845e-02, 2.48120666e-01, 1.65915981e-01, - 2.87097245e-02, -2.50649545e-02, 4.36540544e-02, - -5.01171201e-02, 3.54694985e-02, 1.90053612e-01, - 9.52630565e-02, 1.70738876e-01, 3.70882489e-02, - -4.90600616e-01, -9.28841755e-02, -8.13470930e-02, - 8.31348598e-02, 5.93565181e-02, -2.01062426e-01], - [-1.18340865e-01, -6.85950592e-02, 4.95309308e-02, - -1.77844893e-02, -9.69045609e-02, 2.31995173e-02, - -1.06131600e-03, 2.21603140e-01, -6.05566725e-02, - -2.82245725e-01, 2.64784724e-01, 8.62200931e-02, - 1.37575060e-01, 1.50092602e-01, 4.38311473e-02, - -1.27834529e-01, -1.75913945e-02, -2.03415841e-01, - 1.48476526e-01, -7.80855790e-02, 2.29345813e-01, - 3.37421596e-02, -3.02611887e-01, -3.64654101e-02, - -4.98286486e-02, -1.24875009e-01, 5.32554924e-01, - -5.55246398e-02, -8.19649324e-02, 4.32646945e-02, - -1.92818239e-01, 1.91410363e-01, 1.91146538e-01, - -1.30635314e-02, -1.27977282e-01, -2.05308631e-01], - [-1.33812457e-01, 5.83807267e-02, 6.38746191e-03, - -6.32736981e-02, 2.60766506e-01, 1.92557305e-01, - -4.26477045e-02, 5.47973156e-01, 1.53431622e-02, - 2.03396276e-01, 2.18420655e-01, 1.71779748e-02, - -7.09848702e-02, 2.39939511e-01, -2.50959713e-02, - -1.48106590e-01, 1.51656091e-01, 1.71890616e-01, - 7.37760216e-02, 5.53064533e-02, 1.98505912e-02, - 9.67100039e-02, 1.37430176e-01, 2.82746285e-01, - -1.24559112e-01, 1.80215873e-02, -2.68079907e-01, - 9.55012143e-02, 1.30839288e-01, 8.27972442e-02, - -9.96278524e-02, 4.17835526e-02, -4.81917933e-02, - 1.98767141e-01, -6.95911944e-02, -2.09554836e-01], - [-1.49284035e-01, -7.56456144e-03, -8.76261014e-03, - 2.92932428e-02, -8.39372516e-01, 5.67366369e-02, - -2.41059046e-02, 8.43372419e-02, -2.29054149e-02, - 3.72556150e-02, 3.59098194e-03, -3.51436548e-02, - -4.86128107e-02, -4.90781479e-02, -2.96334457e-02, - 2.16081198e-02, -6.04292788e-02, 1.73466746e-02, - 5.54120354e-02, 4.32790630e-02, 1.27067477e-01, - -9.41377804e-02, -1.37587115e-02, 7.06801787e-02, - -1.22610051e-02, 2.18931045e-02, -3.70597780e-01, - -1.30672632e-02, -4.53533195e-02, -1.70034133e-02, - -1.13316208e-01, -3.45941707e-02, 1.05737671e-01, - -2.95185428e-02, 2.46357918e-02, -2.13801056e-01], - [-1.64755657e-01, -1.91551998e-01, 1.24477036e-02, - 1.76897332e-01, -1.70191415e-02, 2.34046783e-02, - 6.76611960e-02, -1.21719569e-01, -1.60261299e-02, - 2.84169883e-01, -7.72131458e-02, -4.39732298e-02, - -6.60723150e-02, 8.68341923e-02, 7.35200867e-02, - -1.56345084e-01, 4.99212921e-01, -9.53519195e-02, - -1.69593558e-01, 3.12364921e-02, -4.14223462e-01, - -2.19161183e-01, -7.49167113e-04, 4.25142385e-02, - -2.26298310e-02, 3.90600637e-02, 1.34113848e-01, - -4.32782359e-02, -2.25105719e-03, -8.36708769e-02, - 7.53742829e-02, 1.09890841e-01, 3.47145647e-01, - -1.67040601e-01, -4.17540558e-02, -2.18047246e-01], - [-1.80227250e-01, -3.65751952e-01, 1.95310116e-02, - 3.56181487e-02, -2.47674435e-02, -2.56252866e-02, - 1.70394495e-01, -1.01341322e-01, 6.43750429e-02, - -1.18520278e-02, 7.76712969e-02, 1.21111691e-01, - -7.56260678e-02, -1.32285401e-01, 2.50612080e-01, - -2.70852149e-01, -9.66061503e-02, 4.63890702e-01, - 5.18286489e-02, 1.14975851e-02, 7.05922395e-02, - 7.95801077e-03, 3.40116471e-02, -2.50298321e-01, - -4.72176410e-02, 7.11330771e-02, 7.71585703e-02, - 7.12307394e-02, 1.51480496e-01, 4.94032800e-02, - 9.26278085e-02, 1.93590626e-01, -3.63108933e-01, - -1.36400744e-01, 1.46016315e-01, -2.22293481e-01], - [-1.95698813e-01, 8.16941485e-02, 6.35532150e-03, - -5.50320372e-02, 1.45350844e-01, -7.66825154e-02, - -1.48402769e-02, 8.44644289e-03, -3.05129532e-02, - -3.45072865e-01, 1.88118920e-01, 1.39703169e-01, - 9.01852995e-02, -3.05740625e-01, -7.54492134e-02, - 6.51175901e-02, 2.45817453e-01, -1.89270392e-01, - 1.16880536e-01, -2.26171866e-01, -3.72853994e-01, - 5.43844700e-03, -1.24716990e-01, -1.48458153e-01, - 5.83554097e-02, -8.44632387e-02, -3.41172040e-01, - -5.05601391e-02, -1.60052970e-01, 5.74440435e-02, - -1.45993277e-01, -4.03214097e-02, -2.16732427e-01, - -2.84256153e-02, 1.41579702e-01, -2.26539686e-01], - [-2.11170420e-01, -6.31088763e-02, 8.17671046e-03, - -5.57366088e-02, 6.94130734e-02, 3.52174342e-02, - -6.57851174e-02, -9.82191563e-02, -1.27271414e-02, - 1.43996403e-01, -1.19659491e-01, -5.62400967e-02, - -1.02117673e-01, 1.46197915e-01, -6.46053180e-02, - 2.75428176e-01, -5.38663089e-01, 1.51460487e-02, - 3.81278455e-01, 1.08411210e-02, -4.44346756e-01, - 4.02242467e-02, 9.23668295e-02, -7.21167400e-02, - 3.91138941e-02, 4.99221608e-02, 9.94546860e-02, - -3.87978405e-02, 1.93843860e-02, 8.32882449e-02, - -1.15623131e-01, 8.08125958e-02, 1.40358344e-01, - 1.01261795e-01, -5.90205789e-02, -2.30785877e-01], - [-2.26641983e-01, -1.44536331e-01, 8.91233422e-03, - 5.05167954e-02, 3.87359351e-01, -1.25706807e-01, - -9.50697213e-02, -1.42298609e-01, -7.01352954e-02, - -3.15868692e-03, -1.33074358e-01, -1.18453935e-01, - -7.71054849e-02, -4.75535467e-02, -1.50268868e-01, - -1.44392461e-01, -1.82032049e-01, -1.19762598e-02, - -1.21959276e-01, -6.38470054e-02, 4.80738163e-01, - -1.59658909e-01, 2.71296166e-02, -4.31644246e-02, - 1.02411315e-01, 2.07743910e-03, -2.89108336e-01, - -1.03720047e-01, -2.01758668e-01, -2.16420572e-02, - -1.27163813e-01, -7.36601278e-03, 3.14732850e-01, - -1.12868495e-01, 3.11465543e-02, -2.35032097e-01]], dtype=float32), array([-1.89882166e+03, -1.79985218e-04, -1.70435800e-04, -1.27975552e-04, - -1.24901737e-04, -1.24676313e-04, -1.16428266e-04, -1.06598200e-04, - -1.00050034e-04, -9.61478145e-05, -8.36294785e-05, -6.41566730e-05, - -4.51904889e-05, -2.39018827e-05, -1.49146554e-05, -9.43070791e-06, - -8.04440424e-06, 1.51055592e-05, 2.01099483e-05, 2.64523860e-05, - 3.25085311e-05, 5.15936626e-05, 5.31896258e-05, 7.24942220e-05, - 9.04739063e-05, 1.04830775e-04, 1.08393360e-04, 1.37811687e-04, - 1.49946762e-04, 1.86386926e-04, 1.89535742e-04, 2.40968098e-03, - 2.56012683e-03, 2.69382820e-03, 3.27441283e-03, 2.52088105e+04], - dtype=float32)), - mlir_module_text=r""" -module @jit__lambda_ { - func.func public @main() -> (tensor<36x36xf32> {jax.result_info = "[0]"}, tensor<36xf32> {jax.result_info = "[1]"}) { - %0 = stablehlo.iota dim = 0 : tensor<1296xf32> - %1 = stablehlo.reshape %0 : (tensor<1296xf32>) -> tensor<36x36xf32> - %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<36x36xf32>) -> tensor<36x36xf32> - %3 = stablehlo.add %1, %2 : tensor<36x36xf32> - %4 = stablehlo.constant dense<2.000000e+00> : tensor - %5 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<36x36xf32> - %6 = stablehlo.divide %3, %5 : tensor<36x36xf32> - %7 = call @tril(%6) : (tensor<36x36xf32>) -> tensor<36x36xf32> - %8 = stablehlo.custom_call @cusolver_syevd(%7) {api_version = 2 : i32, backend_config = "\00\00\00\00\00\00\00\00\01\00\00\00$\00\00\00Y\98\00\00", operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor<36x36xf32>) -> tuple, tensor<36xf32>, tensor, tensor<39001xf32>> - %9 = stablehlo.get_tuple_element %8[0] : (tuple, tensor<36xf32>, tensor, tensor<39001xf32>>) -> tensor<36x36xf32> - %10 = stablehlo.get_tuple_element %8[1] : (tuple, tensor<36xf32>, tensor, tensor<39001xf32>>) -> tensor<36xf32> - %11 = stablehlo.get_tuple_element %8[2] : (tuple, tensor<36xf32>, tensor, tensor<39001xf32>>) -> tensor - %12 = stablehlo.get_tuple_element %8[3] : (tuple, tensor<36xf32>, tensor, tensor<39001xf32>>) -> tensor<39001xf32> - %13 = stablehlo.constant dense<0> : tensor - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor - %15 = stablehlo.compare EQ, %11, %14, SIGNED : (tensor, tensor) -> tensor - %16 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor) -> tensor<1x1xi1> - %17 = stablehlo.constant dense<0x7FC00000> : tensor - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor) -> tensor<36x36xf32> - %19 = stablehlo.broadcast_in_dim %16, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<36x36xi1> - %20 = stablehlo.select %19, %9, %18 : tensor<36x36xi1>, tensor<36x36xf32> - %21 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor) -> tensor<1xi1> - %22 = stablehlo.constant dense<0x7FC00000> : tensor - %23 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor) -> tensor<36xf32> - %24 = stablehlo.broadcast_in_dim %21, dims = [0] : (tensor<1xi1>) -> tensor<36xi1> - %25 = stablehlo.select %24, %10, %23 : tensor<36xi1>, tensor<36xf32> - return %20, %25 : tensor<36x36xf32>, tensor<36xf32> - } - func.func private @tril(%arg0: tensor<36x36xf32>) -> tensor<36x36xf32> { - %0 = stablehlo.iota dim = 0 : tensor<36x36xi32> - %1 = stablehlo.constant dense<0> : tensor - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<36x36xi32> - %3 = stablehlo.add %0, %2 : tensor<36x36xi32> - %4 = stablehlo.iota dim = 1 : tensor<36x36xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<36x36xi32>, tensor<36x36xi32>) -> tensor<36x36xi1> - %6 = stablehlo.constant dense<0.000000e+00> : tensor - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<36x36xf32> - %8 = stablehlo.select %5, %arg0, %7 : tensor<36x36xi1>, tensor<36x36xf32> - return %8 : tensor<36x36xf32> - } -} -""", - mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01-\x05\x01\x05\x01\x03\x05\x03\x1d\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!\x03^\x02\xeb5\x01\x95\x0f\x17\x13\x07\x0f\x0b\x0b\x0b\x0b\x0b\x17\x0b\x0b\x0b\x0b\x13\x0b\x13\x0f\x0b\x0b\x17\x0f\x13\x13\x0b33\x0b\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x13\x0bK\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x1b\x13\x13\x03W\x0b\x0b\x0f\x0b\x0bO/\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b\x1fO\x1f\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1b\x0f\x0f\x0f\x0f\x0f\x0b\x1fO/\x035\x17\x0f\x07\x0f\x07\x13\x07\x07\x17\x07\x17\x13\x1b\x17\x17\x13\x17\x1b\x13\x13\x13\x0f\x17\x13\x13\x13\x02v\x08\x1d\x85\x03\x17\x11R\x04\x01\x03\x03\x13\xbd\x1f\x1d9\x03\x05#\x05%\x05'\x05)\x05+\x17\x11N\x04\x01\x05-\x05/\x051\x053\x03\x03!\xb9\x055\x03\x03\x0b\xbb\x1d?\x03\x057\x059\x17\x11F\x04\x01\x1dm\x15\x03\x03\x0b\xe5\x03\x03\x0f3\x05;\x03\x0b\x17\x95\x19\xa3\x1b\xa5\x0f\xaf\x1d\xb1\x03\x0b\x17\x99\x19\xb5\x1b\x99\x0f\x9b\x1d\xb7\x05=\x1d=\x03\x05?\x05A\x03\x03!\xbf\x1dE\x03\x05C\x03\x05'\x9d)\xc1\x1dK\x03\x05E\x03\x03\x0b\xc3\x1dQ\x03\x05G\x1dU\x03\x05I\x1dY+\x05K\x1d]+\x05M\x03\x03a\xc5\x05O\x1de\x15\x05Q\x1di\x15\x05S\x03\x03\x0b\xc7\x05U\x03\x03q\x9b\x05W\x03\x11u\xc9w\xcby\xcd{\x95}\xcf\x7f\xd1\x81\xd3\x83\xd7\x05Y\x05[\x05]\x05_\x05a\x05c\x05e\x05g\x05i\x03\x03\r\xdb\x03\x03\r\xdd\x03\x03\r\xdf\x03\x03\r\xe1\x03\x05'\x9d)\xe3\x03\x03\x13\xe7\x03\x03\x13\xe9\x03\x01\x1dk\x03\x03\xb3\x1dm\t\x07\x1f%!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f'\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x1b\x03\x05\xa7\xab\r\x03\x97\xa9\x1do\r\x03\x97\xad\x1dq\x1ds\x1du\r\x01#\x1d\x1dw\x13\r\x01\x1f\x07\t\x00\x00\x00\x00\x1f\x1f\x01\x13\r\x05\x07\x05\x1f\x03\t\x00\x00\x00\x00\x1f\x17!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x03\t\x00\x00\x00@\x0b\x05\x1dy\x1d{\x05\x01\x03\x03\x9f\x03\x03\xd5\x15\x03\x01\x01\x01\x03\t\x9f\xa1\xd9\xa1\x1f)\x01\x13\x05\x01\x13\x05\x05\x13\x05\t\x13\x05\r\x07\x01\x1f\x03\t\x00\x00\xc0\x7f\x1f\x17!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00)\x05\x91\x91\t)\x01\t\x1b)\x01\x05\t)\x03\x91\t\x1d\x01)\x05\x91\x91\x05\x13)\x05\x91\x91\x0f)\x03\t\r)\x03\x94\x85\t\t\x11\x01\x05\x01\x0b\x11\x03\x01\x03\x01)\x03\x01\r)\x03\x82(\t/\t\x01\x0b\x07\x19)\x03\t\x13)\x03\x05\x13)\x03\x01\x13)\x01\x0f)\x05\x05\x05\x0f)\x03\x05\x0f)\x03\x91\x0f)\x03\x05\r\x04\xc6\x04\x05\x01\x11\x071\x07\x03\x01\t\r\x11\x075\x05\x035m\t\x03W\x1f\x03!\x15\x06[\x03\x01\x03\x01\x17\x07c_\x03\x01\x03\x03\x0f\x06g\x03\x01\x05\x03\x05\x05\x03\x07k\x03\x03\x03\x07-\x05\x03\x01\x03\t\x19\x06-\x03\x01\x05\x07\x0b\x1b\x07\to\x03\x01\x03\r\x1d\x07\x01s\x03#\x03\x0f\x07\x07\x01\x87\x03\x01\x03\x11\x07\x07\x01\x89\x03\x0b\x03\x11\x07\x07\x01\x8b\x03\x07\x03\x11\x07\x07\x01\x8d\x03\x19\x03\x11\x05\x03\x01#\x03\x07\x03\x07\x01\x05\x03\x07\x03\x1b\x11\x07\x01\x8f\x03+\x05\x17\x1d\x03\x07\x01\x05\x03-\x03\x1f\x05\x03\x01/\x03\x03\x03\x07\x01\x05\x03\x01\x03#\x03\x07\x01\x91\x03\x15\x03!\x0b\x06\x01\x03\x01\x07'\x13%\x03\x07\x01\x05\x03/\x03\x1f\x05\x03\x01/\x03\x03\x03\x07\x01\x05\x03\x0b\x03-\x03\x07\x01\x93\x031\x03+\x0b\x06\x01\x03\x0b\x071\x15/\x13\x04\x07\x05)3\r\x11\t7\x05\x03\x15+\x03\x01\x07\t\x03;\x1f\x03\x11\x05\x03\t#\x03\x07\x03\x07%\x05\x03\x11\x03\x05\x0f\x06%\x03\x11\x05\x03\x07\t\x03CA\x03\x11\x11\x07IG\x03\x15\x05\t\x0b\x05\x03\tM\x03\x03\x03\x07O\x05\x03\x01\x03\x0f\x0b\x06S\x03\x01\x07\r\x01\x11\x13\x04\t\x03\x13\x06\x03\x01\x05\x01\x00.\x1a}\x1f+\x11\x0f\x0b\t\t\x0b!\x7f\x1f/!!)#\x1f\x19\x0f99m\x19\x89\x8dW\xb7K\x9fM\x9f\x96\x04\x1b+\x1b\x1f\x1f\x15\x1d\x15+\x83\x13\r\r\x1f\x11\x15\x1b\x17\x15\x17\x0f\x11\x15\x11+\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00get_tuple_element_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00value\x00index\x00sym_name\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00compare_type\x00comparison_direction\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril in_positional_semantics=(<_PositionalSemantics.GLOBAL: 1>,) out_positional_semantics=_PositionalSemantics.GLOBAL keep_unused=False inline=False]\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(36, 36) dimension=0]\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(36, 36) dimension=1]\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(36, 36) broadcast_dimensions=()]\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota[dtype=float32 shape=(1296,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(36, 36) dimensions=None]\x00permutation\x00jit()/jit(main)/transpose[permutation=(1, 0)]\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00callee\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit()/jit(main)/eigh[lower=True sort_eigenvalues=True]\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00$\x00\x00\x00Y\x98\x00\x00\x00cusolver_syevd\x00", - xla_call_module_version=4, - ), # End paste - - # Pasted from the test output (see back_compat_test.py module docstring) - f64_syevj=dict( - testdata_version=1, - platform='cuda', - custom_call_targets=['cusolver_syevj'], - serialized_date=datetime.date(2023, 3, 17), - inputs=(), - expected_outputs=(array([[ 6.1857700048412179e-01, -7.9870412160195655e-05, - -7.1795133407817180e-02, 7.2651725579187088e-01, - -5.8816812454044016e-04, -1.0752133550364418e-01, - -1.9695247974936425e-01, 1.8446994643771727e-01], - [ 4.7070881487314487e-01, 3.3071017759156432e-05, - -5.9630159401629157e-01, -4.7856902268752244e-01, - -1.4151478943184035e-03, -2.5017522435505674e-01, - 2.8106392345809550e-01, 2.2856669794666581e-01], - [ 3.2284062926217122e-01, -5.1104181032785456e-01, - 2.4098685972870454e-01, -3.2057977627137213e-01, - 6.0128498619340851e-04, 5.5435726441071020e-01, - -3.0349043125069775e-01, 2.7266344945561433e-01], - [ 1.7497244365119549e-01, 4.1809211960021736e-01, - 5.7112844532216078e-01, -3.1146378582869927e-01, - -4.8989605706119613e-04, -4.4689091764000977e-01, - -2.6709076241922963e-01, 3.1676020096456298e-01], - [ 2.7104258040218803e-02, 4.2941995817157164e-01, - 1.1304358388496584e-01, 9.3073375918824142e-02, - -4.7236149166811120e-01, 4.6617552271070906e-01, - 4.7197416944525139e-01, 3.6085695247351168e-01], - [-1.2076392757075657e-01, -3.8434927079561992e-01, - 2.9171425263113138e-01, 1.5624558970245273e-01, - 4.3260383504376299e-01, -2.0278835428567779e-01, - 5.7959048064074936e-01, 4.0495370398246017e-01], - [-2.6863211318173014e-01, 3.6363990709349564e-01, - -3.3163183889685732e-01, 4.2836063092320187e-02, - 5.6343802845177837e-01, 2.7652818360156795e-01, - -2.9700444618985122e-01, 4.4905045549140854e-01], - [-4.1650029879270561e-01, -3.1571410434740910e-01, - -2.1714457524599659e-01, 9.1940300282126255e-02, - -5.2178844473770358e-01, -2.8968513893859849e-01, - -2.6809045393495168e-01, 4.9314720700035708e-01]]), array([-2.4598804776133605e+01, -2.8026300235964570e-15, - -1.8958980326674837e-15, 1.5553235693581772e-15, - 1.6670762548207520e-15, 2.2405283578797194e-15, - 5.4086800892994285e-15, 2.7659880477613365e+02])), - mlir_module_text=""" -module @jit__lambda_ { - func.func public @main() -> (tensor<8x8xf64> {jax.result_info = "[0]"}, tensor<8xf64> {jax.result_info = "[1]"}) { - %0 = stablehlo.iota dim = 0 : tensor<64xf64> - %1 = stablehlo.reshape %0 : (tensor<64xf64>) -> tensor<8x8xf64> - %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<8x8xf64>) -> tensor<8x8xf64> - %3 = stablehlo.add %1, %2 : tensor<8x8xf64> - %4 = stablehlo.constant dense<2.000000e+00> : tensor - %5 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<8x8xf64> - %6 = stablehlo.divide %3, %5 : tensor<8x8xf64> - %7 = call @tril(%6) : (tensor<8x8xf64>) -> tensor<8x8xf64> - %8 = stablehlo.custom_call @cusolver_syevj(%7) {api_version = 2 : i32, backend_config = "\01\00\00\00\00\00\00\00\01\00\00\00\08\00\00\00M\08\00\00", operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor<8x8xf64>) -> tuple, tensor<8xf64>, tensor, tensor<2125xf64>> - %9 = stablehlo.get_tuple_element %8[0] : (tuple, tensor<8xf64>, tensor, tensor<2125xf64>>) -> tensor<8x8xf64> - %10 = stablehlo.get_tuple_element %8[1] : (tuple, tensor<8xf64>, tensor, tensor<2125xf64>>) -> tensor<8xf64> - %11 = stablehlo.get_tuple_element %8[2] : (tuple, tensor<8xf64>, tensor, tensor<2125xf64>>) -> tensor - %12 = stablehlo.get_tuple_element %8[3] : (tuple, tensor<8xf64>, tensor, tensor<2125xf64>>) -> tensor<2125xf64> - %13 = stablehlo.constant dense<0> : tensor - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor - %15 = stablehlo.compare EQ, %11, %14, SIGNED : (tensor, tensor) -> tensor - %16 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor) -> tensor<1x1xi1> - %17 = stablehlo.constant dense<0x7FF8000000000000> : tensor - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor) -> tensor<8x8xf64> - %19 = stablehlo.broadcast_in_dim %16, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<8x8xi1> - %20 = stablehlo.select %19, %9, %18 : tensor<8x8xi1>, tensor<8x8xf64> - %21 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor) -> tensor<1xi1> - %22 = stablehlo.constant dense<0x7FF8000000000000> : tensor - %23 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor) -> tensor<8xf64> - %24 = stablehlo.broadcast_in_dim %21, dims = [0] : (tensor<1xi1>) -> tensor<8xi1> - %25 = stablehlo.select %24, %10, %23 : tensor<8xi1>, tensor<8xf64> - return %20, %25 : tensor<8x8xf64>, tensor<8xf64> - } - func.func private @tril(%arg0: tensor<8x8xf64>) -> tensor<8x8xf64> { - %0 = stablehlo.iota dim = 0 : tensor<8x8xi32> - %1 = stablehlo.constant dense<0> : tensor - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<8x8xi32> - %3 = stablehlo.add %0, %2 : tensor<8x8xi32> - %4 = stablehlo.iota dim = 1 : tensor<8x8xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<8x8xi32>, tensor<8x8xi32>) -> tensor<8x8xi1> - %6 = stablehlo.constant dense<0.000000e+00> : tensor - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<8x8xf64> - %8 = stablehlo.select %5, %arg0, %7 : tensor<8x8xi1>, tensor<8x8xf64> - return %8 : tensor<8x8xf64> - } -} -""", - mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01-\x05\x01\x05\x01\x03\x05\x03\x1d\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!\x03^\x02\xeb5\x01\x95\x0f\x17\x13\x07\x0f\x0b\x0b\x0b\x0b\x0b\x17\x0b\x0b\x0b\x0b\x13\x0b\x13\x0f\x0b\x0b\x17\x0f\x13\x13\x0b33\x0b\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x13\x0bK\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x1b\x13\x13\x03W\x0b\x0b\x0f\x0b\x0bO/\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b/O/\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1b\x0f\x0f\x0f\x0f\x0f\x0b/O/\x035\x17\x0f\x07\x0f\x07\x13\x07\x07\x17\x07\x17\x13\x17\x17\x17\x13\x17\x1b\x13\x13\x13\x0f\x17\x13\x13\x13\x02\xa2\x08\x1d\x85\x03\x17\x116\x04\x01\x03\x03\x13\xbd\x1f\x1d9\x03\x05#\x05%\x05'\x05)\x05+\x17\x112\x04\x01\x05-\x05/\x051\x053\x03\x03!\xb9\x055\x03\x03\x0b\xbb\x1d?\x03\x057\x059\x17\x11*\x04\x01\x1dm\x15\x03\x03\x0b\xe5\x03\x03\x0f3\x05;\x03\x0b\x17\x95\x19\xa3\x1b\xa5\x0f\xaf\x1d\xb1\x03\x0b\x17\x99\x19\xb5\x1b\x99\x0f\x9b\x1d\xb7\x05=\x1d=\x03\x05?\x05A\x03\x03!\xbf\x1dE\x03\x05C\x03\x05'\x9d)\xc1\x1dK\x03\x05E\x03\x03\x0b\xc3\x1dQ\x03\x05G\x1dU\x03\x05I\x1dY+\x05K\x1d]+\x05M\x03\x03a\xc5\x05O\x1de\x15\x05Q\x1di\x15\x05S\x03\x03\x0b\xc7\x05U\x03\x03q\x9b\x05W\x03\x11u\xc9w\xcby\xcd{\x95}\xcf\x7f\xd1\x81\xd3\x83\xd7\x05Y\x05[\x05]\x05_\x05a\x05c\x05e\x05g\x05i\x03\x03\r\xdb\x03\x03\r\xdd\x03\x03\r\xdf\x03\x03\r\xe1\x03\x05'\x9d)\xe3\x03\x03\x13\xe7\x03\x03\x13\xe9\x03\x01\x1dk\x03\x03\xb3\x1dm\t\x07\x1f%!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f'\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x1b\x03\x05\xa7\xab\r\x03\x97\xa9\x1do\r\x03\x97\xad\x1dq\x1ds\x1du\r\x01#\x1d\x1dw\x13\r\x01\x1f\x07\t\x00\x00\x00\x00\x1f\x1f\x01\x13\r\x05\x07\x05\x1f\x03\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x17!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x03\x11\x00\x00\x00\x00\x00\x00\x00@\x0b\x05\x1dy\x1d{\x05\x01\x03\x03\x9f\x03\x03\xd5\x15\x03\x01\x01\x01\x03\t\x9f\xa1\xd9\xa1\x1f)\x01\x13\x05\x01\x13\x05\x05\x13\x05\t\x13\x05\r\x07\x01\x1f\x03\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x17!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00)\x05!!\t)\x01\t\x1b)\x01\x05\x0b)\x03!\t\x1d\x01)\x05!!\x05\x13)\x05!!\x0f)\x03\t\r)\x03jB\t\x11\x01\x05\x01\x0b\x11\x03\x01\x03\x01)\x03\x01\r)\x03\x02\x02\t/\t\x01\x0b\x07\x19)\x03\t\x13)\x03\x05\x13)\x03\x01\x13)\x01\x0f)\x05\x05\x05\x0f)\x03\x05\x0f)\x03!\x0f)\x03\x05\r\x04\xc6\x04\x05\x01\x11\x071\x07\x03\x01\t\r\x11\x075\x05\x035m\t\x03W\x1f\x03!\x15\x06[\x03\x01\x03\x01\x17\x07c_\x03\x01\x03\x03\x0f\x06g\x03\x01\x05\x03\x05\x05\x03\x07k\x03\x03\x03\x07-\x05\x03\x01\x03\t\x19\x06-\x03\x01\x05\x07\x0b\x1b\x07\to\x03\x01\x03\r\x1d\x07\x01s\x03#\x03\x0f\x07\x07\x01\x87\x03\x01\x03\x11\x07\x07\x01\x89\x03\x0b\x03\x11\x07\x07\x01\x8b\x03\x07\x03\x11\x07\x07\x01\x8d\x03\x19\x03\x11\x05\x03\x01#\x03\x07\x03\x07\x01\x05\x03\x07\x03\x1b\x11\x07\x01\x8f\x03+\x05\x17\x1d\x03\x07\x01\x05\x03-\x03\x1f\x05\x03\x01/\x03\x03\x03\x07\x01\x05\x03\x01\x03#\x03\x07\x01\x91\x03\x15\x03!\x0b\x06\x01\x03\x01\x07'\x13%\x03\x07\x01\x05\x03/\x03\x1f\x05\x03\x01/\x03\x03\x03\x07\x01\x05\x03\x0b\x03-\x03\x07\x01\x93\x031\x03+\x0b\x06\x01\x03\x0b\x071\x15/\x13\x04\x07\x05)3\r\x11\t7\x05\x03\x15+\x03\x01\x07\t\x03;\x1f\x03\x11\x05\x03\t#\x03\x07\x03\x07%\x05\x03\x11\x03\x05\x0f\x06%\x03\x11\x05\x03\x07\t\x03CA\x03\x11\x11\x07IG\x03\x15\x05\t\x0b\x05\x03\tM\x03\x03\x03\x07O\x05\x03\x01\x03\x0f\x0b\x06S\x03\x01\x07\r\x01\x11\x13\x04\t\x03\x13\x06\x03\x01\x05\x01\x00\x06\x1a}\x1f+\x11\x0f\x0b\t\t\x0b!\x7f\x1f/!!)#\x1f\x19\x0f99m\x19\x85\x89W\xb3K\x9bM\x9b\x96\x04\x1b+\x1b\x1f\x1f\x15\x1d\x15+\x83\x13\r\r\x1f\x11\x15\x1b\x17\x15\x17\x0f\x11\x15\x11+\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00get_tuple_element_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00value\x00index\x00sym_name\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00compare_type\x00comparison_direction\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril in_positional_semantics=(<_PositionalSemantics.GLOBAL: 1>,) out_positional_semantics=_PositionalSemantics.GLOBAL keep_unused=False inline=False]\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota[dtype=float64 shape=(64,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]\x00permutation\x00jit()/jit(main)/transpose[permutation=(1, 0)]\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00callee\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit()/jit(main)/eigh[lower=True sort_eigenvalues=True]\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x08\x00\x00\x00M\x08\x00\x00\x00cusolver_syevj\x00", - xla_call_module_version=4, - ), # End paste - - # Pasted from the test output (see back_compat_test.py module docstring) - f64_syevd=dict( - testdata_version=1, - platform='cuda', - custom_call_targets=['cusolver_syevd'], - serialized_date=datetime.date(2023, 3, 17), - inputs=(), - expected_outputs=(array([[-3.1486359056225782e-01, 3.7431364158123925e-02, - 6.1831284766658730e-02, -1.2946991231313536e-02, - 1.9330566993707950e-02, 3.1760201896488226e-03, - 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, - 9.4213470166864710e-01, -8.6414847942068732e-02], - [-2.9939200325938797e-01, 8.3501568928299474e-01, - 4.0680107296867257e-01, -4.6573192775473518e-02, - 6.5422207600829785e-02, 2.2099527094683900e-02, - -1.0242349878775975e-02, 4.0829390183091318e-03, - -1.5827725558444371e-02, -8.6793932713605769e-03, - 1.3047005177451432e-03, -5.3573283556152184e-03, - -1.1723085990292578e-02, -3.4282481604778923e-03, - 1.5300655388654032e-03, 1.3010433879291027e-02, - -7.6245808434662662e-03, 5.9569775610370131e-04, - -5.9294293157650772e-03, -1.9734040942842074e-03, - -1.8628968192927392e-02, -1.3034235399858809e-02, - -5.0097004610369401e-03, 2.4749245795903537e-02, - -5.0644358547264675e-03, 3.0532167800601515e-03, - 2.0824661626164857e-02, -1.5147462161617094e-03, - 1.6322395782111299e-02, -1.1236053191734820e-02, - -1.1821960842042806e-02, 3.8822577430670320e-03, - 7.0724820528586508e-04, 1.9906723944256747e-02, - -1.7030338737863057e-01, -9.0661051391036640e-02], - [-2.8392041595652112e-01, -1.0171687781151459e-01, - -1.1816431661072314e-01, 2.9212172394267638e-01, - 3.3294458108354380e-01, 4.2087881292542445e-01, - -2.2194306321456944e-01, 1.2056157631930936e-01, - -1.0764065526585581e-01, 4.4945129933377570e-02, - -1.1518299700192679e-01, -3.1085391640205563e-02, - 3.1385765542768805e-02, -2.2533661915179113e-02, - 9.3053311217867085e-02, -1.6099650538834706e-01, - -3.8639305088265900e-02, 9.2990366329018387e-03, - 4.6666113341746911e-02, -2.1871647987757620e-01, - 1.7703518610745730e-01, 1.5467613762024190e-01, - -7.2294521250116733e-02, 2.3499877830015681e-01, - -5.6829378083033165e-03, -1.0178485446351725e-01, - 1.7877785721217213e-01, 2.1684187554288339e-01, - 7.7233872499541889e-02, 2.2835265304748494e-02, - 3.1080805156356406e-01, 3.1722234078538948e-02, - -7.8092425763001377e-02, 9.4554636051152510e-02, - -9.6031463624110386e-02, -9.4907254840003452e-02], - [-2.6844882865365438e-01, -2.0201860535424061e-02, - -2.0343029420688158e-01, 1.2815855886454322e-01, - 4.8774092445450092e-02, 1.3232562034943543e-01, - -1.8521836621459195e-01, 9.8747816539597660e-02, - 2.7324903486606195e-01, -7.8737437097193080e-02, - 4.9421661772677816e-02, 7.1493931251323112e-02, - 3.5542595611320515e-01, 1.3920746216059152e-01, - -2.8249741974519734e-02, 6.7932896387190703e-02, - -2.3008512044551552e-01, 5.5015746716542496e-02, - -6.0329018554125865e-03, 8.4249901371007491e-02, - -1.0850059549176212e-01, -2.7052679792044718e-02, - 1.7199248671821082e-01, -2.0779039909219962e-01, - 1.1023999772580403e-01, 4.0228126834019268e-01, - -7.1331569093078903e-02, -2.2546040356632324e-01, - -5.6848723613690040e-02, 2.0039103669806510e-01, - -2.2375524112669190e-01, -6.6955463229343037e-02, - -1.4356710092268696e-01, 2.2907198003730800e-01, - -8.4342913246148038e-02, -9.9153458288970819e-02], - [-2.5297724135078736e-01, -9.7633470097019753e-02, - -2.0613664461051402e-02, -4.6575018452204114e-01, - -4.5475545929408095e-01, -1.6835202228307944e-01, - -2.7411043542686481e-01, 1.4382896244553764e-01, - 1.5533482960243880e-01, -7.7897907011887785e-02, - -5.9104799414908579e-02, -5.1049057176047449e-02, - 5.0937034273965797e-03, -2.9920502980456239e-02, - 7.9164430071644656e-02, 6.5334090456028976e-02, - -2.4594170101813598e-01, 4.0287932953704184e-02, - 1.3071075582032446e-01, -5.6912271071735306e-02, - -1.2680756132856946e-02, 3.5044366466197449e-02, - -5.1780762628180410e-03, 1.2325979893038844e-01, - -1.3286387357961091e-01, -1.9718715617446650e-01, - -7.0204376770955132e-02, -9.3710658292701816e-03, - 7.6870928390159760e-03, 1.2623341382152653e-01, - 3.4895566103640097e-01, 7.7553659039143241e-02, - -3.4023999296528072e-02, 8.3074702907895745e-02, - -8.5300072672481381e-02, -1.0339966173793817e-01], - [-2.3750565404792034e-01, -8.2181485614283623e-02, - -2.4796576412755008e-02, 2.6469606244089910e-01, - 2.5136155191565374e-01, -8.5932117879471037e-01, - -6.7801327364868255e-02, 2.3630380146045637e-02, - -6.0339530364635997e-02, 2.4318784991642788e-02, - -2.0157980609574723e-02, 1.3969684905577337e-02, - 5.2064373452097072e-02, -1.3504287072787914e-03, - 1.1948855400414819e-02, -7.7684684576308824e-02, - -1.8126869586737940e-02, -3.2895203661275497e-02, - -4.7194795185232655e-03, -6.2526420481870917e-02, - 7.8353014950393762e-02, 4.3021669650274826e-02, - 4.1123834759705602e-02, 2.1527669096626890e-02, - 3.2298969317449348e-02, 2.3438124417394162e-02, - 3.1518151219115144e-02, 8.9704214482948422e-02, - 7.6821260017619769e-03, -8.5409778343425186e-03, - 1.5521001031338759e-02, -1.3290428648657086e-02, - 1.8906628930454021e-02, -1.2782589525387992e-02, - -8.2979044248598546e-02, -1.0764586518690553e-01], - [-2.2203406674505338e-01, -9.0264475102341105e-02, - 9.0740700176499111e-03, 6.9171384437416147e-02, - -1.3111811612891669e-01, -1.8966507957248607e-02, - 4.0414304307463594e-01, -3.2564666059313241e-02, - 5.6086124244845181e-01, -4.0083205571491060e-02, - -2.4505702319715772e-02, 2.8981348567837486e-02, - -1.8028953963325864e-01, 1.2810669493073431e-02, - -3.0205734928244080e-02, 1.3016546116209483e-03, - 4.1180187675978214e-01, 1.8487430939971340e-03, - 2.1878399115523185e-02, -1.2942737544986772e-02, - 3.1612876215063763e-02, 1.9040590265843902e-02, - -2.9853451951736565e-01, -2.1069261774264141e-02, - 1.2756924052704141e-02, 1.0396556130345047e-02, - 2.0982593071380967e-01, 7.2513245350085284e-02, - 2.6961322653924678e-02, 4.4259057451694346e-02, - 1.3245555422671054e-02, -1.1355432725780245e-02, - -1.6423769454471046e-01, 2.1283797622603673e-01, - -7.7771821344734746e-02, -1.1189206863587289e-01], - [-2.0656247944218639e-01, -7.5555152047925872e-02, - 2.1436004480934572e-02, 1.8519822533150174e-01, - -4.7687267679858099e-02, 1.0893715640778658e-01, - 5.4446388557811642e-01, 6.7864355635107079e-02, - 1.8925675037139755e-01, 3.6392773516755073e-02, - -2.4764455183159433e-02, -3.8468294614801751e-02, - -2.8696444635530814e-02, -1.8823021866307067e-02, - 4.8264052464878845e-02, -3.6882747079153497e-02, - -3.0155420938729255e-01, 1.0404831951207047e-02, - 4.4505477004053171e-03, -4.6873846610364103e-02, - 2.4798470273412251e-02, 2.5891733287640804e-02, - 3.5011544817152707e-01, 8.5903050378751358e-02, - -1.6860450574909990e-02, -3.9052038500091160e-02, - -2.9924661599529656e-01, -1.5823886416275893e-03, - 2.8254484941419005e-03, -4.8861168063938747e-03, - 9.7917302635802658e-02, 2.7710576047465570e-02, - 2.3536560145276611e-01, -3.9600571986552502e-01, - -7.4934893198527877e-02, -1.1613827208484023e-01], - [-1.9109089213931946e-01, -8.4666472598656825e-02, - 5.7740802097843921e-02, 1.9626130737187028e-01, - -2.4601756649487860e-01, 8.1511271167717628e-02, - -4.6530930078469529e-01, 6.8795587726048116e-02, - 5.2415554010200038e-02, -1.7332120317563506e-03, - 3.1251731285109323e-02, 1.5521676381926154e-02, - -1.2359815126908288e-01, 2.7460289856811461e-02, - 1.9114633014954776e-02, 2.8966001347205911e-03, - 4.3487864890462036e-01, -2.2957986155413699e-02, - -1.5357935266312277e-02, 1.0016152245695723e-02, - -4.5019081491420573e-02, -2.4405778384030734e-02, - -7.4832588748429490e-02, -4.4078616914614753e-02, - 3.0809052034342380e-02, 1.1926634983737788e-01, - -8.1517751909305367e-02, -7.7527914203627396e-02, - -3.7123430398910418e-02, 1.3750979135916276e-02, - -9.7457414231716055e-02, -1.7178991628521816e-02, - 2.1304973749867503e-01, -5.4941011823140218e-01, - -6.7860578570392335e-02, -1.2038447553380759e-01], - [-1.7561930483645249e-01, -8.8342789136092309e-02, - -1.1242590243640400e-02, -1.8652768797207359e-01, - -9.8464009205703876e-02, 1.7256713195193910e-02, - 2.9649268724224581e-01, 5.8780632678962143e-02, - -3.4585362321307522e-01, 7.6907763800451081e-03, - 2.5103268120083535e-02, 2.5393826053803564e-02, - 4.3240349879996420e-01, 3.3310696488693933e-02, - 2.1609140330890370e-02, 1.3951456173138647e-03, - -1.2840968480253712e-01, -3.3248191939129826e-02, - -8.9379099725266672e-04, -1.8994911138723630e-03, - -2.3834826680311980e-02, 4.7502947323282011e-03, - -4.4024121870114297e-01, -6.7327999197165686e-02, - 2.9359383382924452e-02, 9.1479482958182867e-02, - 3.8593300484440007e-01, -4.7958512765110956e-02, - -5.1251961259242168e-02, 1.8636628882937378e-02, - -6.5572564769060912e-02, -2.2887842635462220e-02, - -1.6042006104302377e-02, -3.3250776465128573e-01, - -6.6477273291217359e-02, -1.2463067898277495e-01], - [-1.6014771753358550e-01, -8.3434053708109190e-02, - 1.3638599925185501e-02, -2.4158649874087133e-02, - -1.1124755841847851e-01, 4.2695267715302458e-02, - 1.4866152720116035e-01, 4.9700778378845270e-04, - -3.5326388070491549e-01, -1.5745483283003094e-02, - -8.9738221678782072e-03, 1.0993364411347295e-02, - 1.9527915544397639e-01, 1.3259513825918660e-02, - -3.9339417079053149e-03, -3.7389315402467350e-02, - 3.0825337281314197e-01, 2.9465425388143118e-02, - -1.0086552608467406e-04, -2.1130010935818223e-02, - 2.4746795171351338e-02, 1.2876294127766924e-02, - -1.3542161100061775e-01, 2.3491306500478031e-02, - 2.8381089185132442e-02, 5.0060402655999779e-02, - -4.7990645387633185e-01, 1.7841388064942280e-02, - 3.6163722246352295e-02, 2.2692968040711251e-02, - -1.4881297657765719e-03, -1.1068249840362020e-02, - 4.3250260717661632e-01, 4.5393847466427317e-01, - -6.1116215809998306e-02, -1.2887688243174231e-01], - [-1.4467613023071851e-01, -8.5360329958689612e-02, - 3.6773895176301370e-02, 2.8417567832807769e-04, - -1.4251569175101705e-01, 1.8419541161364662e-02, - 1.4739729008583152e-01, -6.2901931512317516e-02, - -4.3820330673251112e-01, -1.1585923923104585e-01, - -4.6526417840431711e-02, 1.2161556905396271e-02, - -8.3388018002128958e-02, 2.3616237126461999e-02, - -9.1086898933490409e-02, 9.6073985629915787e-02, - 3.0200810799555788e-01, 9.9080289536070815e-02, - 4.9921034650103280e-02, 7.6871969202905246e-02, - -8.3377720121475072e-03, -1.7031625806123534e-02, - 4.5636496936456672e-01, -4.0005637071420394e-02, - -1.9891703100641429e-02, 1.2472945837760744e-02, - 5.9697784009368959e-03, -9.5789228620796370e-03, - 6.8806967828826657e-02, 1.5038487697273856e-01, - 6.8452882565985446e-02, 1.3123694381544091e-02, - -5.6226049096551989e-01, -4.1018946243773058e-02, - -5.6717572380307106e-02, -1.3312308588070965e-01], - [-1.2920454292785150e-01, -6.6253352907543861e-02, - -1.0164436321011842e-01, -1.4433060335444364e-01, - 1.6028176487458967e-01, 3.4584483531135940e-02, - 1.9900533500768001e-02, -5.2164178106233798e-02, - -1.2875710620386896e-01, -1.3038955529948765e-01, - -3.1311992664378889e-02, 2.5299917094429910e-02, - -4.1764341929454979e-01, 5.7547077142788963e-02, - -1.1598534347679475e-01, 1.8086109486937549e-01, - -6.3115663671148348e-02, 8.6408791666891471e-02, - 4.0289642159952954e-02, 1.2892059198986330e-01, - -7.5052803928986972e-02, -3.4807004039357006e-02, - 2.0072216849958635e-01, -1.1909118683716058e-01, - -2.6393566026650855e-02, 6.6849035713186178e-02, - 4.7200759534307635e-01, -7.6853961442131774e-02, - 2.6993333821331650e-02, 1.7484304402685918e-01, - 5.3240433359001025e-03, 2.9788042206222785e-03, - 5.1760936987899087e-01, 1.1384037033693235e-01, - -5.1865856323749862e-02, -1.3736928932967699e-01], - [-1.1373295562498452e-01, -5.7235135967154585e-02, - -4.7652965020097103e-02, -1.7627396739100985e-02, - 7.7938405922626644e-02, 2.2087656281477019e-02, - 6.1009605667557178e-03, -5.4981966965685393e-02, - -1.8486086378646865e-01, 3.8911039431433647e-02, - 3.5079519080830110e-02, 1.9272432328556483e-02, - -5.9096451891695889e-01, -7.7247905448605157e-03, - 3.7441325666613741e-02, -4.9165769090891341e-02, - -3.3776276260195798e-01, 1.6606308621317768e-02, - 3.8859102913090936e-02, -1.9047412918711374e-02, - -3.8482634352387676e-02, -4.8755071639337150e-02, - -4.3270527443011519e-01, -9.1999354995766322e-02, - 1.0430914529054176e-01, 1.4978760949122619e-01, - -3.4135100214765429e-01, -2.5289826614278744e-02, - 3.4608873349492607e-02, 8.8085003662463843e-02, - -1.5196825642675141e-01, -9.3051296574294673e-03, - -2.4468277187262805e-01, -2.4348157193486621e-02, - -4.7513567722300747e-02, -1.4161549277864433e-01], - [-9.8261368322117570e-02, -1.6390394385331745e-02, - -5.4742294041798749e-02, -5.8987021949670405e-02, - -1.6882319276059432e-01, 4.3601612172208745e-02, - -2.9911314975774938e-02, 2.3284677199386728e-03, - -3.1808540586289284e-02, 6.9627318822466044e-01, - 1.6271702602637766e-01, 1.5743246880124597e-02, - -4.3195703838658110e-02, -2.2494758789598773e-01, - 7.1399213422553218e-02, -1.3240943946997921e-01, - -8.4980139589052577e-03, -3.2038201094679952e-01, - 6.2407097431780204e-02, -7.6882180114861851e-02, - 2.9470860002467913e-02, -4.2571478756212582e-02, - 2.0163350380724604e-01, -3.2389702717405428e-01, - 6.9711204990479309e-02, -8.1573794801329258e-02, - 1.3304500243627673e-01, 4.0406118875997113e-02, - 8.2477981782237836e-02, -1.1543529624088469e-01, - -1.1014206710642817e-01, 4.2320022953069426e-06, - 3.8041226304310447e-03, 1.3395530894194055e-01, - -3.9467794046677329e-02, -1.4586169622761166e-01], - [-8.2789781019250525e-02, -1.9278711714630567e-01, - 2.2165755909431184e-01, -2.1201546316262262e-01, - 1.4307796989725635e-01, 6.0334342472999250e-02, - -5.5139304406736672e-02, -1.9408969113742302e-02, - 5.4970843704949646e-02, -4.5047658482968128e-01, - -3.3338315762977556e-02, -6.5308425743183532e-02, - 1.4218465309675436e-02, 4.9087218418760230e-02, - 1.8670840217742501e-01, -1.5287462038432642e-01, - -1.3217180940167689e-02, -6.6463048958420534e-02, - 3.8845065361654303e-04, -2.2429929685530131e-01, - -2.6776933696982124e-02, 8.5772405898653856e-02, - 1.1857225379472448e-01, -3.3789334871471582e-01, - 8.3834684881833613e-02, -1.7391265231974168e-01, - -5.9431721332300208e-03, 2.7485104738181495e-02, - 1.6105963634532708e-01, -4.7246605597344127e-01, - -2.3898285645951292e-01, -2.0628986543330220e-02, - -2.1798010578591574e-02, 1.6076906598537423e-02, - -5.4377032852269684e-02, -1.5010789967657906e-01], - [-6.7318193716383562e-02, -1.3247564302860890e-01, - 1.7006921492087917e-01, 1.2398760160260749e-01, - -1.4177630269484331e-01, 1.5422349385403381e-02, - -5.9592326716797428e-02, -3.5882053764316857e-02, - -1.7232432793461348e-02, 2.3701488719579314e-01, - -4.6593215018650616e-02, -6.3082282004145299e-02, - -2.0902723950643357e-02, 5.2050993065408405e-02, - -8.0468326155430828e-02, -5.0880717820819980e-02, - -1.1820152914284968e-01, 5.6506976812092713e-01, - -2.1968735055254530e-02, 1.6529598718631755e-01, - 1.0797738052990204e-01, -3.0113303079001008e-02, - 5.5521405735639642e-03, 2.7802427161516047e-01, - -1.3829193596041753e-01, -1.1466435184415830e-01, - 1.1740546330296046e-01, -1.7311150238082029e-01, - -1.6365530586101310e-01, -3.6819727396673907e-01, - -3.1239015782869367e-01, 6.3966770007709506e-02, - -2.6591619532336051e-02, 1.2885889151522636e-01, - -3.7992961598361283e-02, -1.5435410312554640e-01], - [-5.1846606413516585e-02, -6.0477319044140554e-02, - -7.5750638182608219e-03, -1.0624372654415394e-01, - 8.1266486795481985e-02, 4.0180836057036554e-02, - -3.7783670829837974e-02, 4.6289675320758547e-02, - 3.3808855820936547e-02, -1.9195948450068509e-01, - -5.8196442046703094e-02, 1.7282080569685822e-03, - 1.4755965059760449e-02, -6.0959969133142022e-01, - -2.8239274796445768e-01, 1.2486767782495350e-01, - -1.6812624118941352e-02, -3.1637047991210354e-01, - -3.4329518102613220e-02, 2.9658523886210797e-01, - 2.1095830387260842e-01, -7.1581690223787436e-02, - 1.4902746008909057e-02, 2.5118050689616306e-01, - 1.5960904763919231e-01, 1.6146826320314336e-01, - -3.0778528162015331e-02, -6.0781897242040703e-03, - -1.5766062756371724e-01, -2.2924930849571712e-01, - -2.3919944196342770e-02, 4.0432828090792343e-02, - -3.3603315710298294e-02, 6.6005717038430623e-03, - -3.2237412023528290e-02, -1.5860030657451374e-01], - [-3.6375019110649567e-02, -4.6095054123273631e-02, - 4.1487329226456366e-03, -4.9882330119267008e-02, - 2.6789583798631911e-01, 2.8310263556813459e-02, - -5.0744234427433435e-02, -2.1955670997388516e-01, - 8.8814242427478526e-02, 7.2616405945027329e-02, - 3.7105581486243189e-01, 1.3801726499993164e-01, - 1.2228306569610396e-01, -1.8641957679946289e-01, - -1.7746951776518829e-01, 1.1838468893129621e-01, - 4.1434840944853890e-02, 3.4352445701196649e-01, - -1.3539286248067484e-01, 1.2179016223131671e-01, - -1.4481862254120659e-01, -6.0813770391397334e-02, - -9.5024877677197070e-02, -2.6026144416788322e-01, - 6.7007386100264313e-02, -2.7403316717453452e-01, - -1.2940472617950355e-01, -7.0811325772559455e-02, - 1.0283464270665656e-02, -5.0042226650144100e-02, - 3.9567119578457077e-01, -2.3131183910318670e-01, - -2.4438157021422158e-02, -9.5495078814865603e-02, - -3.1811761848109070e-02, -1.6284651002348108e-01], - [-2.0903431807782615e-02, 7.2327502897265056e-02, - -2.1426834420397733e-01, -2.4971807305411563e-02, - -6.8251303361485452e-02, -3.5176957926268708e-03, - -1.7281098595222758e-02, -2.7919893499292525e-01, - -7.5490419998562163e-03, 8.8933532299955390e-02, - -8.3918077552881970e-02, 4.2946166228858822e-02, - -3.5084337029511685e-02, 5.2484778345047800e-01, - -1.3476341073870199e-01, 8.9651093734304757e-02, - -2.6221874920893444e-02, -3.2081171793188057e-01, - -7.0201683149374666e-02, 9.7920337768921742e-02, - -7.6208072805887969e-02, 2.9964575931518713e-02, - 2.1839138515231137e-03, 2.1907625163481245e-01, - 7.8802565386018458e-02, 1.0637722019900711e-01, - -1.5047419808766808e-02, -1.2522929609505140e-01, - 1.0489044814827699e-01, -4.4452472469644072e-01, - 2.5261973738582033e-01, -1.9360753077714768e-01, - -3.0637038971187570e-02, -3.9473390838082588e-04, - -1.0054456334322568e-02, -1.6709271347244839e-01], - [-5.4318445049156196e-03, -1.1991560506989501e-01, - 1.6016393502783463e-01, -9.0534713898102900e-02, - 1.7803986653673967e-01, 4.2517830558630100e-02, - -6.5595472901773699e-02, -6.9456352075150884e-02, - 7.9849581869208763e-02, 1.4596149872374808e-01, - -3.7448911148165226e-01, 3.0784697110174092e-02, - 1.0212691273921030e-01, 1.2477201433959939e-01, - -2.1170895978207616e-01, 1.9057503902571590e-01, - -1.9885301263116554e-02, -2.1847437899940467e-01, - -1.3659628076825936e-01, 6.2262165446311392e-02, - -1.9622860693073528e-02, 4.1620399347292121e-02, - -3.1648999142503326e-02, 8.2027519954154221e-02, - -7.9260224219164649e-02, -4.4257777757196498e-01, - -1.0450524222584731e-01, 7.1670676847096298e-02, - 4.6620848245388563e-02, 3.5490360494088574e-01, - -3.4694381436297000e-01, -2.2966638374036538e-01, - -2.1349097951285249e-02, -5.0149218417714851e-02, - -2.8318514185483656e-02, -1.7133891692141581e-01], - [ 1.0039742797951326e-02, 1.4486958501002600e-01, - -3.0487486722127227e-01, 1.2108072885929126e-01, - -1.1723298949673400e-01, 9.6017523703054095e-03, - 4.9883113678426960e-03, 3.2018649396693973e-02, - -4.0095882258820964e-02, -2.4528012104090294e-01, - 6.0349817604330003e-01, -6.0025406492642708e-02, - -1.6146280657180472e-02, 1.5798023347451132e-01, - -1.5035528625979958e-02, -2.2434556029665070e-02, - -2.4354754626807390e-02, -1.5308774844201870e-01, - -1.1065734099847921e-02, 5.1339996940509787e-02, - 1.6396255893983677e-01, 2.4722965810338692e-02, - 9.6017297101513074e-03, 1.6662850312888863e-01, - 9.1395453034799151e-02, -4.2004786665153609e-01, - 3.0226599593042958e-02, 3.3204444593892296e-02, - -9.0545811500522586e-02, 1.1327046229049616e-01, - -2.5108979165208944e-01, 1.2687846708619716e-01, - -2.1404901679780933e-03, 2.9977168343317158e-02, - 5.5400172108409033e-03, -1.7558512037038310e-01], - [ 2.5511330100818325e-02, -5.8698168696025753e-02, - 8.0629703301508024e-02, -7.0612253616157819e-02, - 3.2715731475630602e-02, 2.1732269341780134e-02, - -5.6700795470449199e-02, -6.8235752853351661e-01, - 6.4905178300795938e-02, -3.5862976828251472e-02, - 8.8618413873728166e-02, 3.1550620324006268e-01, - 9.2319437517647415e-02, -1.0599662867975553e-01, - 2.6587503059973538e-01, -1.0545080566473539e-01, - -2.2738440485640277e-02, -6.6368929276419075e-02, - -5.1003071286368440e-02, -1.1626185301232636e-01, - 5.4119363471023328e-02, -2.4882466696968256e-02, - 4.6420092314024886e-02, 1.7831888983094824e-01, - -2.7253935859206135e-01, 1.7198911112035339e-01, - 1.3432430343834192e-02, 7.1000954309573148e-03, - -3.8416339301886476e-03, 1.6384316059667964e-01, - -6.0953258543061287e-02, 2.6960776094017469e-01, - 2.0718992188831518e-02, -2.7614704623654989e-02, - -1.2643038301898243e-02, -1.7983132381935049e-01], - [ 4.0982917403685273e-02, -4.7160343894475723e-02, - 7.8787266856851345e-03, -1.6730572778497552e-01, - 2.7113248408711793e-01, 9.8438763801876154e-03, - 2.2608843153598773e-02, 4.0738411310515976e-01, - 3.2355058682223534e-02, 1.1698920368317291e-01, - 1.4072643414054364e-01, 6.7061453574130916e-02, - -1.8930127519950827e-02, 1.9146087806398635e-01, - -2.4250669817151019e-02, 1.1868698006794093e-01, - 1.0317141879348907e-01, -8.5252634874863287e-02, - -2.8010523433118828e-01, 1.3060583612270180e-01, - -9.9969111180962050e-02, -3.4760563118607063e-02, - -1.7994529116745678e-02, -6.0554676763009442e-02, - -4.6559703882739706e-01, 1.1940676107160293e-01, - -1.0161278374127546e-01, 1.3173327834920193e-01, - -2.2709272071986680e-02, -1.1755702148341549e-01, - 3.7441059930431703e-02, 4.4164660080364565e-01, - -6.6992110689447992e-02, -2.5301348191003502e-02, - -9.7262032302421250e-03, -1.8407752726831786e-01], - [ 5.6454504706552257e-02, 7.8158541336176779e-02, - -1.4338657014458589e-01, 1.0703741291078765e-01, - -1.3942580377761906e-03, 2.2695174951015635e-03, - -3.8562621975632518e-02, -3.0965063003047144e-01, - 3.7355997032764349e-02, 1.4990453152525209e-02, - -1.1227058245216649e-01, -7.0287795373175999e-01, - 1.1718292741895955e-01, -5.1035967037226390e-02, - -9.4000621055494157e-02, 1.7518267045374700e-01, - -1.4730348981690847e-02, 5.1783743616797537e-02, - 2.1169018058168132e-01, 5.8597372997689870e-02, - -1.6243455966644404e-01, 5.9497378897041750e-02, - -7.3121464646455983e-02, -1.8084067697810838e-01, - -6.6501694611624321e-02, 4.1097079298917809e-02, - -4.3356588698331838e-02, 2.4444891440205574e-01, - 6.5642952335239826e-03, -9.6906979426258765e-03, - 1.8913630981055121e-03, 2.7008769602574367e-01, - 8.8545125037905337e-03, -3.9988001886776758e-02, - 9.3906452280477001e-03, -1.8832373071728517e-01], - [ 7.1926092009419212e-02, 8.0994217906793439e-02, - -2.0767188447365928e-01, -1.5196436606475891e-01, - 1.3077919554196207e-01, -2.1254474743086713e-02, - 4.5019671597743463e-02, 9.6558458919928689e-02, - 1.2420216348711157e-02, -6.2064238471275191e-03, - 9.8956490118614168e-02, -3.2363738790615754e-01, - -3.2870638207842147e-02, -1.5482218310094722e-01, - 2.9647782998980127e-01, -6.1576762109174010e-02, - 1.2666434428081200e-01, 2.1955834692424955e-02, - -1.8997255642944891e-03, -1.0295835477975461e-01, - 1.8208445909004639e-02, -1.1030261882048981e-01, - 4.3794875217006007e-02, 1.8518198489376456e-01, - -4.0747443172392700e-01, 1.3827664021164707e-01, - -4.1431123873109715e-03, -1.4061023435938111e-01, - 1.3942741953117222e-02, 1.9365617058920072e-02, - -8.4489815015323350e-02, -5.7838799828344145e-01, - -2.8902818751484066e-02, -2.4186610549109096e-02, - 1.2086263962861131e-02, -1.9256993416625251e-01], - [ 8.7397679312286217e-02, 1.5064561887342939e-01, - -2.1080556782941462e-01, 1.5916760566958116e-01, - -1.9624826757584166e-01, 1.5198104896205650e-02, - -1.4330248064956560e-02, 3.3068118190946301e-02, - -3.5714352226646290e-02, -1.4260141979380403e-01, - -2.4115477092387741e-01, 3.4101861982281523e-01, - -1.9029646752241479e-03, -2.7699284020832545e-02, - 1.0920088465260440e-01, -1.5239532632222408e-01, - -8.5144012779746134e-02, 5.5970342531910411e-02, - 6.9106614215268647e-02, 2.4036876137100174e-01, - -1.2301443222654272e-01, -1.1953863304856910e-01, - 3.5171852820881193e-03, -2.1104179481631563e-01, - -1.6652675336533382e-01, -6.9825511877400867e-02, - 7.3611503187800218e-03, 5.1349708686040763e-01, - -3.0172148431909446e-01, -1.0589893886410634e-01, - 3.6783462028334960e-03, -2.0553003985674112e-01, - 1.8790472746182015e-02, 1.9823557204917654e-02, - 2.5168461511062466e-02, -1.9681613761521988e-01], - [ 1.0286926661515323e-01, -5.1095768728277327e-02, - 1.3471859461003702e-01, 3.0500821091821676e-02, - -1.6790235354550213e-02, -7.0308669455806175e-03, - -3.0939649438101019e-03, 2.5665199177927620e-02, - 2.1279168221811904e-03, -2.5037640808915945e-02, - -1.2405085129935786e-01, -2.6231150724568519e-01, - -8.5787446133464614e-03, 3.9627338244596369e-02, - 2.3267441336286346e-01, -4.0743293242468487e-01, - 2.4149661576382757e-04, -6.3680910375172040e-02, - -4.3805185403053759e-01, 2.0300111728111647e-01, - -2.1099142295899803e-01, -3.4325637130492054e-01, - 2.4798870388207689e-02, 5.8652422232119368e-02, - 3.1273508409742873e-01, -6.5663309732651248e-02, - 8.4976320234436575e-02, -1.2972698624062320e-01, - -1.0136590956706468e-01, 1.7606369902531008e-01, - 1.7776135567204221e-01, 1.0742707779456324e-01, - -7.9052346006256245e-03, 7.3493627583932908e-02, - 9.9131943085618447e-03, -2.0106234106418724e-01], - [ 1.1834085391802023e-01, 8.1061946874736585e-02, - -1.6265342280467382e-01, -2.5856375159094996e-01, - 1.4258244531423583e-01, -2.5799424990869069e-02, - 1.9638649342146815e-02, 7.3355921016709083e-02, - 5.9394009978013036e-02, 1.5655633426552953e-01, - -9.8792934500238835e-02, 9.9575902680803088e-02, - 1.8527367488061958e-02, 6.3288806058580380e-02, - 3.7739330071097632e-01, 3.9157302813010320e-01, - 1.3485974151563190e-01, 2.4396726581112591e-01, - 3.6171829433890815e-02, -1.5329124928290030e-01, - 1.0994295285071572e-01, -6.2470988682208468e-02, - 7.2649015124010521e-02, 1.4656583051512045e-01, - 5.0160574613932607e-01, 4.7267639935224738e-02, - -4.2965682291764895e-02, 1.8881658695211850e-01, - -1.0776584277343945e-01, -2.6754374009298049e-02, - -7.7009726198669998e-02, 8.6417047403639091e-02, - -5.3833621971674586e-03, -8.0918819205681225e-02, - 2.1780800232539175e-02, -2.0530854451315456e-01], - [ 1.3381244122088717e-01, 1.5997082437941978e-01, - -2.1906335649966574e-01, 2.3332171765159351e-01, - -6.4994730069703827e-02, -2.7137179321886296e-02, - 4.4299490835366419e-02, 5.4082161016101568e-02, - -6.1822856454263338e-02, -6.6517101749567792e-02, - -2.9376460130324589e-01, 1.1103413626514062e-01, - -3.3806550575053815e-02, -1.8397686746205080e-01, - 3.9400318507963744e-02, 1.8758272608343995e-01, - 2.1898570040268548e-02, -5.7258401311702969e-02, - -1.2054652895121606e-01, -3.3785342949153890e-01, - -3.9112933378476634e-02, 1.2987622324621689e-01, - -7.4850924489854642e-02, -1.8237325410753219e-01, - -1.1058781873480500e-01, -2.0595217802395629e-01, - 7.5757742040963461e-03, -5.3655875317100610e-01, - -2.0896258914648322e-01, -5.5945308120122161e-02, - 8.2455318541596961e-02, 1.7624602710846482e-01, - -2.2489297400574856e-02, 5.2915934277181324e-02, - 3.8152138968863464e-02, -2.0955474796212192e-01], - [ 1.4928402852375416e-01, -4.7103999084602964e-02, - 1.5843017378423407e-01, -1.0471529213101267e-01, - 4.1822224430947852e-02, 4.9674575956627585e-03, - -1.3311898606966285e-03, 4.8322275176183468e-02, - 2.6782623911085109e-02, 1.3647784270166637e-02, - 1.0980857986376788e-01, -5.0748588072257886e-04, - -1.0361251293227987e-02, 1.1049141088458188e-01, - -4.7174567274205670e-01, -2.0220954115377396e-01, - 1.3182956708179594e-02, -1.1843903142333311e-02, - 2.0088578029524848e-01, -5.3080319758187777e-01, - -1.6308626968204651e-01, -1.6901681485606096e-01, - 7.0269705034495436e-02, 9.8708103667137601e-02, - 5.8906260202682963e-02, 1.3406466835766842e-01, - 1.3927440769859889e-02, 9.2483635015958410e-02, - -4.1489874017913597e-01, 5.3520424215223954e-02, - 3.3087563030626183e-02, -4.3491644319790287e-02, - -1.4259433018598195e-05, 8.4993306168228474e-03, - 1.9440725644047020e-02, -2.1380095141108932e-01], - [ 1.6475561582662113e-01, 1.5611161975472390e-01, - -2.2981567498448408e-01, -2.5170242091030143e-01, - 1.2572164509633985e-01, -3.5101394036068920e-02, - 1.4388788465769620e-02, 6.4367254285863956e-02, - 7.9127393952476463e-02, 4.7770236979792664e-02, - -1.4967962998375717e-01, 9.6657597995555136e-02, - 2.8600846275685401e-02, 1.0247100377903102e-03, - -1.7416826445456965e-01, -5.2903452729155642e-01, - 1.2378709088794008e-01, 1.6002483124980629e-01, - 2.3117191956384286e-01, 2.0936710152049257e-01, - 8.2739337123958492e-02, 3.7851995698789648e-01, - 1.9060641335918893e-02, -6.4314540668445608e-03, - 8.4778867125413743e-02, 1.6232730574310308e-02, - -5.8776952506303194e-02, -1.6317833767006093e-01, - 2.0541131812472332e-01, 9.4709191370766388e-02, - -3.0776520624173034e-02, 1.1938827858311649e-01, - 1.2517716200189802e-02, -1.3352132837280375e-01, - 3.8021934168930759e-02, -2.1804715486005660e-01], - [ 1.8022720312948814e-01, 8.6827318631575279e-02, - -7.4501227114099414e-02, 1.2876209736226957e-01, - -2.2037890384696301e-01, -2.5814842105572621e-02, - -3.1406758090893994e-02, 9.6294241223690305e-02, - 1.8240072112824506e-02, -7.8775576899911090e-02, - -3.0389264268442007e-02, 8.6684499299869738e-02, - 5.4365532030452843e-02, -1.1850090448995039e-01, - -2.4574663651253167e-01, 5.0606647353540021e-02, - -1.1179254494673002e-01, 1.5746625930135386e-01, - -2.3653025671773734e-01, -2.4326576699636770e-01, - 5.1089622549619594e-02, -2.8901934374460203e-01, - -2.6451534372339578e-02, 5.4045829899974578e-02, - 7.5844174532653701e-03, 9.5261278786040723e-02, - 6.5117432591824925e-02, 1.5374072905554484e-01, - 6.6944827374030014e-01, 2.5045538719576737e-03, - -5.5672913354967879e-02, 1.2051210553600417e-02, - 3.3658431259863966e-02, -3.1395677687489406e-03, - 4.7661017511192831e-02, -2.2229335830902394e-01], - [ 1.9569879043235514e-01, -9.1769753653107577e-02, - 2.7141769527027171e-01, 2.2785564717029946e-01, - 6.4057719170856758e-02, -3.7788206214948872e-03, - 9.7259287514508460e-03, 1.6918261328737952e-01, - 6.8155784376799586e-04, -1.4846652373116371e-02, - 9.7665427524227605e-02, 1.4020779957899679e-01, - 5.4803013440760974e-02, -3.7770889485239614e-02, - 2.0161818269196646e-01, 1.3431772896192445e-01, - -2.2780324141178667e-02, -1.3299949529057514e-01, - 5.6952253822862586e-01, 1.7551693338628394e-01, - -3.8851158821630960e-01, -8.2597118671349307e-02, - -5.5521724833590726e-02, 1.8126259477529724e-01, - 1.7814975368438311e-02, -6.5528218308503153e-02, - 3.7971760553771383e-02, -1.5071623691597721e-01, - 2.1592446351812103e-01, -5.6402536331480002e-04, - 4.5088070248228272e-02, 2.6712876881033590e-02, - -5.4087768899409383e-03, 6.8686308808012492e-02, - 3.2287080492312645e-02, -2.2653956175799139e-01], - [ 2.1117037773522207e-01, -5.0164247531242157e-02, - 2.6588099000556803e-01, 9.2461134185888125e-02, - -1.8638912752062822e-01, -1.3326201088302150e-02, - -1.5139012219398481e-02, 5.6526342555140038e-02, - -2.1347405801495557e-02, 4.2134620640229903e-03, - 1.6189227618448768e-01, -4.0274584225345120e-02, - -5.6430110607539385e-02, -5.8413256975427548e-02, - 5.2327365554425583e-02, 1.0547316593589447e-01, - -1.0141590903757328e-01, 2.2750086641208328e-03, - -2.9965053997941909e-01, 1.5580924251156411e-03, - -9.8801397992561726e-02, 7.0133690173366392e-01, - 2.9288631311505543e-02, 3.2187639373342534e-02, - 8.5847997795661615e-02, 2.0571325754758280e-01, - 7.4079833507648560e-02, 1.5568547966076893e-01, - -4.9689302197244593e-02, 7.8435365554783448e-02, - 4.8351735020509205e-02, -1.7685071128733182e-01, - 6.5889048949493989e-03, 8.0297089881752479e-02, - 3.9088810533135447e-02, -2.3078576520695873e-01], - [ 2.2664196503808903e-01, -3.6435359235223168e-02, - 2.7461198824493543e-01, 8.5347376974543573e-02, - -2.1059797477235808e-02, 1.1448326379020789e-02, - -2.6592754399652377e-02, 2.5891172442431810e-02, - 2.8366243844641929e-02, -2.0536075588459556e-02, - 6.6444382000443650e-05, -6.6068428617317751e-02, - 2.3676624954254568e-02, 2.2112015932022797e-01, - 3.6011261258148117e-02, 6.3110902119789564e-02, - -6.5129709470743133e-02, -4.8955274099800709e-02, - 1.5625642089103450e-01, 1.1336968441478927e-01, - 7.1887047535547766e-01, -1.4060033754799098e-01, - -4.3732646616641863e-02, -2.9113406474813336e-01, - -5.4252028224128682e-02, 8.5563234976626823e-02, - -9.8842092892354998e-03, -8.6014269752744857e-02, - -5.3867992496449059e-02, 1.0226004671603665e-01, - 2.0616418999784455e-01, -6.6321426514466278e-02, - 1.7485733797709232e-02, 1.0373147806260606e-02, - 3.9178042791043720e-02, -2.3503196865592610e-01]]), array([-1.8988227080038084e+03, -8.1652460579197793e-12, - -6.8293671717855184e-12, -5.0961343548435651e-12, - -4.6422244875241180e-12, -4.0432649621797409e-12, - -4.6750947941168519e-13, -4.2866623066103143e-13, - -3.9638626555876315e-13, -3.4647469398250028e-13, - -3.2765729675497798e-13, -3.0727463002427591e-13, - -2.9879803908775378e-13, -2.4080245315867009e-13, - -2.1775959053373055e-13, -1.8534745675222213e-13, - -1.5959779217062472e-13, -1.0879546752449559e-13, - -9.0067575069985811e-14, -5.3973885458936187e-14, - -4.6064162488080463e-14, 6.1429074771130427e-15, - 1.3659631287864453e-14, 3.4753391317142145e-14, - 8.7547004653142170e-14, 1.2585089324337818e-13, - 1.5745245909745148e-13, 2.0606204849135956e-13, - 2.1792577470203850e-13, 2.6674476798831050e-13, - 3.0421425292401405e-13, 3.1193691330212636e-13, - 3.1270969371399125e-13, 4.3446674157388007e-13, - 1.6764394233642590e-12, 2.5208822708003838e+04])), - mlir_module_text=r""" -module @jit__lambda_ { - func.func public @main() -> (tensor<36x36xf64> {jax.result_info = "[0]"}, tensor<36xf64> {jax.result_info = "[1]"}) { - %0 = stablehlo.iota dim = 0 : tensor<1296xf64> - %1 = stablehlo.reshape %0 : (tensor<1296xf64>) -> tensor<36x36xf64> - %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<36x36xf64>) -> tensor<36x36xf64> - %3 = stablehlo.add %1, %2 : tensor<36x36xf64> - %4 = stablehlo.constant dense<2.000000e+00> : tensor - %5 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<36x36xf64> - %6 = stablehlo.divide %3, %5 : tensor<36x36xf64> - %7 = call @tril(%6) : (tensor<36x36xf64>) -> tensor<36x36xf64> - %8 = stablehlo.custom_call @cusolver_syevd(%7) {api_version = 2 : i32, backend_config = "\01\00\00\00\00\00\00\00\01\00\00\00$\00\00\00Y\98\00\00", operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor<36x36xf64>) -> tuple, tensor<36xf64>, tensor, tensor<39001xf64>> - %9 = stablehlo.get_tuple_element %8[0] : (tuple, tensor<36xf64>, tensor, tensor<39001xf64>>) -> tensor<36x36xf64> - %10 = stablehlo.get_tuple_element %8[1] : (tuple, tensor<36xf64>, tensor, tensor<39001xf64>>) -> tensor<36xf64> - %11 = stablehlo.get_tuple_element %8[2] : (tuple, tensor<36xf64>, tensor, tensor<39001xf64>>) -> tensor - %12 = stablehlo.get_tuple_element %8[3] : (tuple, tensor<36xf64>, tensor, tensor<39001xf64>>) -> tensor<39001xf64> - %13 = stablehlo.constant dense<0> : tensor - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor - %15 = stablehlo.compare EQ, %11, %14, SIGNED : (tensor, tensor) -> tensor - %16 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor) -> tensor<1x1xi1> - %17 = stablehlo.constant dense<0x7FF8000000000000> : tensor - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor) -> tensor<36x36xf64> - %19 = stablehlo.broadcast_in_dim %16, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<36x36xi1> - %20 = stablehlo.select %19, %9, %18 : tensor<36x36xi1>, tensor<36x36xf64> - %21 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor) -> tensor<1xi1> - %22 = stablehlo.constant dense<0x7FF8000000000000> : tensor - %23 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor) -> tensor<36xf64> - %24 = stablehlo.broadcast_in_dim %21, dims = [0] : (tensor<1xi1>) -> tensor<36xi1> - %25 = stablehlo.select %24, %10, %23 : tensor<36xi1>, tensor<36xf64> - return %20, %25 : tensor<36x36xf64>, tensor<36xf64> - } - func.func private @tril(%arg0: tensor<36x36xf64>) -> tensor<36x36xf64> { - %0 = stablehlo.iota dim = 0 : tensor<36x36xi32> - %1 = stablehlo.constant dense<0> : tensor - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<36x36xi32> - %3 = stablehlo.add %0, %2 : tensor<36x36xi32> - %4 = stablehlo.iota dim = 1 : tensor<36x36xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<36x36xi32>, tensor<36x36xi32>) -> tensor<36x36xi1> - %6 = stablehlo.constant dense<0.000000e+00> : tensor - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<36x36xf64> - %8 = stablehlo.select %5, %arg0, %7 : tensor<36x36xi1>, tensor<36x36xf64> - return %8 : tensor<36x36xf64> - } -} -""", - mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01-\x05\x01\x05\x01\x03\x05\x03\x1d\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!\x03^\x02\xeb5\x01\x95\x0f\x17\x13\x07\x0f\x0b\x0b\x0b\x0b\x0b\x17\x0b\x0b\x0b\x0b\x13\x0b\x13\x0f\x0b\x0b\x17\x0f\x13\x13\x0b33\x0b\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x13\x0bK\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x1b\x13\x13\x03W\x0b\x0b\x0f\x0b\x0bO/\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b/O/\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1b\x0f\x0f\x0f\x0f\x0f\x0b/O/\x035\x17\x0f\x07\x0f\x07\x13\x07\x07\x17\x07\x17\x13\x1b\x17\x17\x13\x17\x1b\x13\x13\x13\x0f\x17\x13\x13\x13\x02\xa6\x08\x1d\x85\x03\x17\x11R\x04\x01\x03\x03\x13\xbd\x1f\x1d9\x03\x05#\x05%\x05'\x05)\x05+\x17\x11N\x04\x01\x05-\x05/\x051\x053\x03\x03!\xb9\x055\x03\x03\x0b\xbb\x1d?\x03\x057\x059\x17\x11F\x04\x01\x1dm\x15\x03\x03\x0b\xe5\x03\x03\x0f3\x05;\x03\x0b\x17\x95\x19\xa3\x1b\xa5\x0f\xaf\x1d\xb1\x03\x0b\x17\x99\x19\xb5\x1b\x99\x0f\x9b\x1d\xb7\x05=\x1d=\x03\x05?\x05A\x03\x03!\xbf\x1dE\x03\x05C\x03\x05'\x9d)\xc1\x1dK\x03\x05E\x03\x03\x0b\xc3\x1dQ\x03\x05G\x1dU\x03\x05I\x1dY+\x05K\x1d]+\x05M\x03\x03a\xc5\x05O\x1de\x15\x05Q\x1di\x15\x05S\x03\x03\x0b\xc7\x05U\x03\x03q\x9b\x05W\x03\x11u\xc9w\xcby\xcd{\x95}\xcf\x7f\xd1\x81\xd3\x83\xd7\x05Y\x05[\x05]\x05_\x05a\x05c\x05e\x05g\x05i\x03\x03\r\xdb\x03\x03\r\xdd\x03\x03\r\xdf\x03\x03\r\xe1\x03\x05'\x9d)\xe3\x03\x03\x13\xe7\x03\x03\x13\xe9\x03\x01\x1dk\x03\x03\xb3\x1dm\t\x07\x1f%!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f'\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x1b\x03\x05\xa7\xab\r\x03\x97\xa9\x1do\r\x03\x97\xad\x1dq\x1ds\x1du\r\x01#\x1d\x1dw\x13\r\x01\x1f\x07\t\x00\x00\x00\x00\x1f\x1f\x01\x13\r\x05\x07\x05\x1f\x03\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x17!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x03\x11\x00\x00\x00\x00\x00\x00\x00@\x0b\x05\x1dy\x1d{\x05\x01\x03\x03\x9f\x03\x03\xd5\x15\x03\x01\x01\x01\x03\t\x9f\xa1\xd9\xa1\x1f)\x01\x13\x05\x01\x13\x05\x05\x13\x05\t\x13\x05\r\x07\x01\x1f\x03\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x17!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00)\x05\x91\x91\t)\x01\t\x1b)\x01\x05\x0b)\x03\x91\t\x1d\x01)\x05\x91\x91\x05\x13)\x05\x91\x91\x0f)\x03\t\r)\x03\x94\x85\t\t\x11\x01\x05\x01\x0b\x11\x03\x01\x03\x01)\x03\x01\r)\x03\x82(\t/\t\x01\x0b\x07\x19)\x03\t\x13)\x03\x05\x13)\x03\x01\x13)\x01\x0f)\x05\x05\x05\x0f)\x03\x05\x0f)\x03\x91\x0f)\x03\x05\r\x04\xc6\x04\x05\x01\x11\x071\x07\x03\x01\t\r\x11\x075\x05\x035m\t\x03W\x1f\x03!\x15\x06[\x03\x01\x03\x01\x17\x07c_\x03\x01\x03\x03\x0f\x06g\x03\x01\x05\x03\x05\x05\x03\x07k\x03\x03\x03\x07-\x05\x03\x01\x03\t\x19\x06-\x03\x01\x05\x07\x0b\x1b\x07\to\x03\x01\x03\r\x1d\x07\x01s\x03#\x03\x0f\x07\x07\x01\x87\x03\x01\x03\x11\x07\x07\x01\x89\x03\x0b\x03\x11\x07\x07\x01\x8b\x03\x07\x03\x11\x07\x07\x01\x8d\x03\x19\x03\x11\x05\x03\x01#\x03\x07\x03\x07\x01\x05\x03\x07\x03\x1b\x11\x07\x01\x8f\x03+\x05\x17\x1d\x03\x07\x01\x05\x03-\x03\x1f\x05\x03\x01/\x03\x03\x03\x07\x01\x05\x03\x01\x03#\x03\x07\x01\x91\x03\x15\x03!\x0b\x06\x01\x03\x01\x07'\x13%\x03\x07\x01\x05\x03/\x03\x1f\x05\x03\x01/\x03\x03\x03\x07\x01\x05\x03\x0b\x03-\x03\x07\x01\x93\x031\x03+\x0b\x06\x01\x03\x0b\x071\x15/\x13\x04\x07\x05)3\r\x11\t7\x05\x03\x15+\x03\x01\x07\t\x03;\x1f\x03\x11\x05\x03\t#\x03\x07\x03\x07%\x05\x03\x11\x03\x05\x0f\x06%\x03\x11\x05\x03\x07\t\x03CA\x03\x11\x11\x07IG\x03\x15\x05\t\x0b\x05\x03\tM\x03\x03\x03\x07O\x05\x03\x01\x03\x0f\x0b\x06S\x03\x01\x07\r\x01\x11\x13\x04\t\x03\x13\x06\x03\x01\x05\x01\x00.\x1a}\x1f+\x11\x0f\x0b\t\t\x0b!\x7f\x1f/!!)#\x1f\x19\x0f99m\x19\x89\x8dW\xb7K\x9fM\x9f\x96\x04\x1b+\x1b\x1f\x1f\x15\x1d\x15+\x83\x13\r\r\x1f\x11\x15\x1b\x17\x15\x17\x0f\x11\x15\x11+\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00get_tuple_element_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00value\x00index\x00sym_name\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00compare_type\x00comparison_direction\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril in_positional_semantics=(<_PositionalSemantics.GLOBAL: 1>,) out_positional_semantics=_PositionalSemantics.GLOBAL keep_unused=False inline=False]\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(36, 36) dimension=0]\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(36, 36) dimension=1]\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(36, 36) broadcast_dimensions=()]\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota[dtype=float64 shape=(1296,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(36, 36) dimensions=None]\x00permutation\x00jit()/jit(main)/transpose[permutation=(1, 0)]\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00callee\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit()/jit(main)/eigh[lower=True sort_eigenvalues=True]\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00$\x00\x00\x00Y\x98\x00\x00\x00cusolver_syevd\x00", - xla_call_module_version=4, - ) # End paste -) +array = np.array +float32 = np.float32 +complex64 = np.complex64 data_2024_09_30 = {} diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_lu_cusolver_getrf.py b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_lu_cusolver_getrf.py index 47da841aec0a..1842939705e3 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_lu_cusolver_getrf.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_lu_cusolver_getrf.py @@ -15,7 +15,12 @@ # ruff: noqa import datetime -from numpy import array, int32, float32, complex64 +import numpy as np + +array = np.array +float32 = np.float32 +int32 = np.int32 +complex64 = np.complex64 data_2024_08_19 = {} diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_lu_pivots_to_permutation.py b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_lu_pivots_to_permutation.py index 12285a45b77a..e9ac59ef731b 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_lu_pivots_to_permutation.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_lu_pivots_to_permutation.py @@ -13,14 +13,17 @@ # limitations under the License. import datetime -from numpy import array, int32 +import numpy as np + +array = np.array +int32 = np.int32 # Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_08 = dict( +data_2025_04_01 = dict( testdata_version=1, platform='cuda', custom_call_targets=['cu_lu_pivots_to_permutation'], - serialized_date=datetime.date(2024, 8, 8), + serialized_date=datetime.date(2025, 4, 1), inputs=(), expected_outputs=(array([[[0, 1, 2, 3, 4, 5, 6, 7], [4, 5, 6, 7, 0, 1, 2, 3], @@ -31,25 +34,22 @@ [0, 1, 2, 3, 4, 5, 6, 7]]], dtype=int32),), mlir_module_text=r""" module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<2x3x8xi32> {jax.result_info = "", mhlo.layout_mode = "default"}) { + func.func public @main() -> (tensor<2x3x8xi32> {jax.result_info = "result"}) { %0 = stablehlo.iota dim = 0 : tensor<24xi32> loc(#loc4) %1 = stablehlo.reshape %0 : (tensor<24xi32>) -> tensor<2x3x4xi32> loc(#loc5) - %c = stablehlo.constant dense<2> : tensor loc(#loc6) - %c_0 = stablehlo.constant dense<3> : tensor loc(#loc6) - %c_1 = stablehlo.constant dense<4> : tensor loc(#loc6) - %2 = stablehlo.custom_call @cu_lu_pivots_to_permutation(%1) {mhlo.backend_config = {permutation_size = 8 : i32}, operand_layouts = [dense<[2, 1, 0]> : tensor<3xindex>], result_layouts = [dense<[2, 1, 0]> : tensor<3xindex>]} : (tensor<2x3x4xi32>) -> tensor<2x3x8xi32> loc(#loc6) + %2 = stablehlo.custom_call @cu_lu_pivots_to_permutation(%1) {mhlo.backend_config = {}, mhlo.frontend_attributes = {num_batch_dims = "2"}, operand_layouts = [dense<[2, 1, 0]> : tensor<3xindex>], result_layouts = [dense<[2, 1, 0]> : tensor<3xindex>]} : (tensor<2x3x4xi32>) -> tensor<2x3x8xi32> loc(#loc6) return %2 : tensor<2x3x8xi32> loc(#loc) } loc(#loc) } loc(#loc) #loc = loc(unknown) -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":347:26) -#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":347:14) -#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":348:11) -#loc4 = loc("jit()/jit(main)/iota[dtype=int32 shape=(24,) dimension=0]"(#loc1)) -#loc5 = loc("jit()/jit(main)/reshape[new_sizes=(2, 3, 4) dimensions=None]"(#loc2)) -#loc6 = loc("jit()/jit(main)/lu_pivots_to_permutation[permutation_size=8]"(#loc3)) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":408:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":408:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":409:11) +#loc4 = loc("jit()/jit(main)/iota"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape"(#loc2)) +#loc6 = loc("jit()/jit(main)/lu_pivots_to_permutation"(#loc3)) """, - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1d\x05\x01\x03\x01\x03\x05\x03\r\x07\t\x0b\r\x0f\x11\x03\xa7}\x17\x01Q\x0f\x07\x0b\x0b\x0f\x0b+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x17\x0f\x0b\x17\x13\x0b\x17\x13\x13S\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03-\x0b\x0b\x0f\x0b\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x0f///\x0b\x0b\x0b\x13\x0b\x0fo\x01\x05\x0b\x0f\x03\x13\x0f\x07\x1b\x07\x13\x13\x1b\x13\x07\x02Z\x04\x1d57\x1f\x05\x13\x05\x15\x11\x03\x05\x05\x17\x03\t\x0f\x11\x13\t\x15\t\x0b\x17\x05\x19\x11\x01\x00\x05\x1b\x05\x1d\x05\x1f\x03\x0b\x1bQ\x1dW\x1fY\x0bc!e\x05!\x05#\x05%\x05'\x03\x03%g\x05)\x1d)+\x05+\x17\x05n\x055\x1d/1\x05-\x17\x05n\x05\x1d\x03\x03\x07i\x05/\x17\x05r\x05\x17\x03\x03\x07k\x03\x03\x07m\x03\x13?oASCqEQGsIuKUMQOU\x051\x053\x055\x057\x059\x05;\x05=\x05?\x05A\x03\x01\x1dC\x03\x03{#\r\x03\x03[\r\x05]S_a\x1dE\x1dG\x1dI\x1dK\x1dM\x13\x0b\x01\x1f\x05\x11\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x05\x11\x03\x00\x00\x00\x00\x00\x00\x00\x1f\x05\x11\x04\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1dO\x05\x01\r\x03wy\x1dQ\x13\x07!\x1f\x131\x02\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x01\x0b\x1b)\x07\t\r!\x07\x1d\x11\x01\x03\t)\x03a\x07)\x07\t\r\x11\x07)\x03\r\x15\x13\x04{\x05\x01\x11\x03\r\x07\x03\x01\x05\x05\x11\x03\x19\x07\x03\r\x1d\x07\x03'#\x03\x0f\t\x06-\x03\x11\x03\x01\x03\x03\x013\x03\x05\x03\x03\x019\x03\x05\x03\x03\x01;\x03\x05\x0b\x07\x01=\x03\t\x03\x03\r\x04\x03\x03\x0b\x06\x03\x01\x05\x01\x00f\x0cS#9\x0f\x0b\x11#!\x03\x1f/!)!)#\x1f\x19\x8b\x8b\x85\x1f\x1f\x15\x1d\x15\x1b%)9\x13\ri\x15\x1f\x17\x11\x11\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00value\x00sym_name\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit()/jit(main)/iota[dtype=int32 shape=(24,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(2, 3, 4) dimensions=None]\x00jit()/jit(main)/lu_pivots_to_permutation[permutation_size=8]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00\x00jax.result_info\x00mhlo.layout_mode\x00default\x00main\x00public\x00cu_lu_pivots_to_permutation\x00permutation_size\x00", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.9.3\x00\x01\x1d\x05\x01\x05\r\x01\x03\x0b\x03\x0b\x0f\x13\x17\x1b\x1f\x03yQ\x15\x01+\x07\x0b\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x17\x0f\x0b\x17\x1b\x0b\x0b\x0f\x0b\x17\x03'\x0b\x0f\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0f\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0bo\x01\x05\x0b\x0f\x03\x11\x07\x1b\x13\x13\x07\x1b\x13\x07\x02\x9e\x02\x1f\x05\x11\x11\x03\x05\x03\x07\t\x0b\r\x05\x0f\x05\x05\x13\x11\x01\x00\x05\x15\x05\x17\x05\x19\x1d\x15\x17\x05\x1b\x17\x03b\x065\x1d\x1b\x1d\x05\x1d\x17\x03b\x06\x1d\x03\x05!?#A\x05\x1f\x05!\x1d')\x05#\x17\x03f\x06\x17\x03\x01\x03\x03O#\t\x03\x033\r\x0357\x1d%\x1d'\x1d)\x1d+\x13\r\x01\r\x01\r\x03CE\x1d-\x1d/\x0b\x03\x1d1\x1d3\x05\x01\x1f\x111\x02\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02\x1b)\x07\t\r!\x05\x11\x01\x03\x07)\x03a\x05\x1d)\x07\t\r\x11\x05)\x03\r\x13\x13\x04c\x05\x01Q\x01\x07\x01\x07\x04Q\x03\x01\x05\x03P\x01\x03\x07\x04=\x03\x07\x11\x05B\x13\x05\x03\x0b\x07\x06\x19\x03\x0f\x03\x01\tG%\x1f\x07\x03\x07\x03\x03\x0b\x04\x01\x03\x05\x06\x03\x01\x05\x01\x00J\x0759\x03\x05\x1f\x0f\x0b\x0f!c3)A;\x1b%)9i\x15\x1f\x17\x11\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/iota\x00jit()/jit(main)/reshape\x00mhlo.backend_config\x00mhlo.frontend_attributes\x00jit()/jit(main)/lu_pivots_to_permutation\x00jax.result_info\x00result\x00main\x00public\x00num_batch_dims\x002\x00\x00cu_lu_pivots_to_permutation\x00\x08+\t\x05#\x01\x0b+/19;\x03=\x11GIK+M-+-", xla_call_module_version=9, nr_devices=1, ) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_qr_cusolver_geqrf.py b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_qr_cusolver_geqrf.py index be5c6e01f8d8..9b91a0052d11 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_qr_cusolver_geqrf.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_qr_cusolver_geqrf.py @@ -15,149 +15,14 @@ # ruff: noqa import datetime -from numpy import array, float32, float64, complex64, complex128 +import numpy as np -data_2023_03_18 = {} +array = np.array +float32 = np.float32 +complex64 = np.complex64 -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_03_18["unbatched"] = dict( - testdata_version=1, - platform='cuda', - custom_call_targets=['cusolver_geqrf', 'cusolver_orgqr'], - serialized_date=datetime.date(2023, 3, 18), - inputs=(), - expected_outputs=(array([[ 0. , 0.9128705 , 0.40824863], - [-0.44721356, 0.36514878, -0.8164964 ], - [-0.8944271 , -0.18257457, 0.40824813]], dtype=float32), array([[-6.7082043e+00, -8.0498438e+00, -9.3914843e+00], - [ 0.0000000e+00, 1.0954436e+00, 2.1908882e+00], - [ 0.0000000e+00, 0.0000000e+00, 5.6703755e-08]], dtype=float32)), - mlir_module_text=r""" -module @jit__lambda_ { - func.func public @main() -> (tensor<3x3xf32> {jax.result_info = "[0]"}, tensor<3x3xf32> {jax.result_info = "[1]"}) { - %0 = stablehlo.iota dim = 0 : tensor<9xf32> - %1 = stablehlo.reshape %0 : (tensor<9xf32>) -> tensor<3x3xf32> - %2 = stablehlo.custom_call @cusolver_geqrf(%1) {api_version = 2 : i32, backend_config = "\00\00\00\00\01\00\00\00\03\00\00\00\03\00\00\00\00\00\03\00", operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xf32>) -> tuple, tensor<3xf32>, tensor, tensor<196608xf32>> - %3 = stablehlo.get_tuple_element %2[0] : (tuple, tensor<3xf32>, tensor, tensor<196608xf32>>) -> tensor<3x3xf32> - %4 = stablehlo.get_tuple_element %2[1] : (tuple, tensor<3xf32>, tensor, tensor<196608xf32>>) -> tensor<3xf32> - %5 = stablehlo.get_tuple_element %2[2] : (tuple, tensor<3xf32>, tensor, tensor<196608xf32>>) -> tensor - %6 = stablehlo.get_tuple_element %2[3] : (tuple, tensor<3xf32>, tensor, tensor<196608xf32>>) -> tensor<196608xf32> - %7 = stablehlo.constant dense<0> : tensor - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor - %9 = stablehlo.compare EQ, %5, %8, SIGNED : (tensor, tensor) -> tensor - %10 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1x1xi1> - %11 = stablehlo.constant dense<0x7FC00000> : tensor - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor) -> tensor<3x3xf32> - %13 = stablehlo.broadcast_in_dim %10, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> - %14 = stablehlo.select %13, %3, %12 : tensor<3x3xi1>, tensor<3x3xf32> - %15 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1xi1> - %16 = stablehlo.constant dense<0x7FC00000> : tensor - %17 = stablehlo.broadcast_in_dim %16, dims = [] : (tensor) -> tensor<3xf32> - %18 = stablehlo.broadcast_in_dim %15, dims = [0] : (tensor<1xi1>) -> tensor<3xi1> - %19 = stablehlo.select %18, %4, %17 : tensor<3xi1>, tensor<3xf32> - %20 = stablehlo.constant dense<0.000000e+00> : tensor - %21 = stablehlo.pad %14, %20, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xf32>, tensor) -> tensor<3x3xf32> - %22 = stablehlo.custom_call @cusolver_orgqr(%21, %19) {api_version = 2 : i32, backend_config = "\00\00\00\00\01\00\00\00\03\00\00\00\03\00\00\00\03\00\00\00 \81\00\00", operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xf32>, tensor<3xf32>) -> tuple, tensor, tensor<33056xf32>> - %23 = stablehlo.get_tuple_element %22[0] : (tuple, tensor, tensor<33056xf32>>) -> tensor<3x3xf32> - %24 = stablehlo.get_tuple_element %22[1] : (tuple, tensor, tensor<33056xf32>>) -> tensor - %25 = stablehlo.get_tuple_element %22[2] : (tuple, tensor, tensor<33056xf32>>) -> tensor<33056xf32> - %26 = stablehlo.constant dense<0> : tensor - %27 = stablehlo.broadcast_in_dim %26, dims = [] : (tensor) -> tensor - %28 = stablehlo.compare EQ, %24, %27, SIGNED : (tensor, tensor) -> tensor - %29 = stablehlo.broadcast_in_dim %28, dims = [] : (tensor) -> tensor<1x1xi1> - %30 = stablehlo.constant dense<0x7FC00000> : tensor - %31 = stablehlo.broadcast_in_dim %30, dims = [] : (tensor) -> tensor<3x3xf32> - %32 = stablehlo.broadcast_in_dim %29, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> - %33 = stablehlo.select %32, %23, %31 : tensor<3x3xi1>, tensor<3x3xf32> - %34 = call @triu(%14) : (tensor<3x3xf32>) -> tensor<3x3xf32> - return %33, %34 : tensor<3x3xf32>, tensor<3x3xf32> - } - func.func private @triu(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> - %1 = stablehlo.constant dense<-1> : tensor - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<3x3xi32> - %3 = stablehlo.add %0, %2 : tensor<3x3xi32> - %4 = stablehlo.iota dim = 1 : tensor<3x3xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> - %6 = stablehlo.constant dense<0.000000e+00> : tensor - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<3x3xf32> - %8 = stablehlo.select %5, %7, %arg0 : tensor<3x3xi1>, tensor<3x3xf32> - return %8 : tensor<3x3xf32> - } -} -""", - mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01+\x05\x01\x05\x01\x03\x05\x03\x1b\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f\x03~\x02\xf79\x01\x99\x0f\x0f\x17\x13\x0f\x07\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x13\x17\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x1b\x13\x13\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0bK\x0b\x13\x13\x0f\x0b#\x0b\x0b\x0b\x0f\x0bK\x0b\x13\x0b\x03_O/\x0b/\x0b\x0f\x0b\x0b\x0b\x0b\x0f\x0f\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b\x1f\x0b\x0b\x0f\x17\x1b\x0f\x0f\x0f\x0f\x1f\x0b\x1fO/\x0b\x0b\x13\x17\x039\x17\x0f\x0f\x07\x07\x07\x07\x17\x13\x17\x07\x1b\x0f\x17\x13\x1b\x17\x17\x13\x13\x1b\x13\x13\x13\x13\x13\x13\x17\x02\x06\t\x1d{\x05\x1d\x93\x05\x17\x1f\n\x06\x01\x03\x03\x13\xcb\x1dS\x05\x1f\x05!\x05#\x05%\x05'\x03\x03\r\xe9\x05)\x05+\x05-\x05/\x051\x03\x03#\xc7\x053\x1d[\x05\x055\x057\x03\x03\r\xd1\x17\x1f\x06\x06\x01\x059\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x03\x03\x0f\xdd\x03\x03\x0f\xdf\x03\x03\x0f\xe1\x03\x03\r\xe5\x03\x05'\xa7)\xe7\x03\x03\x13\xeb\x03\x03\x11M\x05I\x03\x0b\x17\x9d\x19\xb1\x1b\xb3\x11\xbd\x1d\xbf\x03\x0b\x17\xa3\x19\xc3\x1b\xa3\x11\xa5\x1d\xc5\x05K\x1dW\x05\x05M\x03\x03\r\xc9\x05O\x03\x03#\xcd\x1da\x05\x05Q\x03\x05'\xa7)\xcf\x1dg\x05\x05S\x1dk\x05\x05U\x1do\x05\x05W\x1ds-\x05Y\x1dw-\x05[\x03\x11/\xa91\xd33\xd55\x9d7\xab9\xd7;\xad=\xdb\x05]\x03\x03\x0f\xe3\x03\x03\x13\xed\x1d\x83\x05\x05_\x03\x07\x87\x9f\x89\x9f\x8b\x9f\x05a\x05c\x05e\x1d\x8f\x05\x05g\x03\x11/\xa91\xef3\xf15\x9d7\xab9\xf3;\xad=\xf5\x05i\x03\x03\x97\xa5\x05k\x1f+!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f-\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dm\x03\x03\xc1\x1do\t\x07\x0b\x05\x05\x01\x03\x03\xd9\x1f/\x01#!\x03\x05\xb5\xb9\r\x03\xa1\xb7\x1dq\r\x03\xa1\xbb\x1ds\x1du\x1dw\r\x01##\x1dy\x13\x0b\x01\x1f\x03\t\xff\xff\xff\xff\x1f%\x01\x13\x0b\x05\x07\x05\x1f\x05\t\x00\x00\x00\x00\x1d{\x1d}\x03\x03\x99\x15\x03\x01\x01\x01\x03\t\x99\x9b\xaf\x9b\x13\t\x01\x13\t\x05\x13\t\t\x13\t\r\x1f\x03\t\x00\x00\x00\x00\x07\x01\x1f\x05\t\x00\x00\xc0\x7f\x1f\x1d!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d\x7f\x1d\x81\x03\x05\x99\x9b\x03\x07\x99\xaf\x9b)\x05\r\r\x07)\x01\t)\x01\x07\t\x1b\x1d\x01)\x05\r\r\t)\x03\r\x07)\x05\r\r\r\x13)\x03\x04\x000\x07)\x01\r)\x05\x05\x05\r)\x03\t\x0b)\x03\x04\x12\x08\x07\x11\x01\x05\x01\x01\x11\x03\x01\x03\x01)\x03\x01\x0b)\x03%\x07/\t\x01\x11\x03\x17)\x03\t\x15)\x03\x05\x15)\x03\x01\x15)\x03\x05\r)\x03\r\r)\x03\x05\x0b/\x07\x01\x03\x1f\x04\xe6\x05\x05\x01\x11\x0bK\x07\x03\x01\t\x0f\x11\x0bO\x05\x03G\x91\x0b\x03q!\x03'\x17\x06u\x03\x01\x03\x01\x13\x07\x01y\x03)\x03\x03\x07\x07\x01?\x03\x01\x03\x05\x07\x07\x01A\x03\x11\x03\x05\x07\x07\x01C\x03\x03\x03\x05\x07\x07\x01}\x03\x17\x03\x05\x05\x03\x01E\x03\x03\x03\x07\x01\x07\x03\x03\x03\x0f\r\x07\x01G\x03\x19\x05\x0b\x11\x03\x07\x01\x07\x03\x1b\x03\x13\x05\x03\x01\x15\x03\x05\x03\x07\x01\x07\x03\x01\x03\x17\x03\x07\x01I\x03\x13\x03\x15\t\x06\x01\x03\x01\x07\x1b\x07\x19\x03\x07\x01\x07\x031\x03\x13\x05\x03\x01\x15\x03\x05\x03\x07\x01\x07\x03\x11\x03!\x03\x07\x01\x7f\x033\x03\x1f\t\x06\x01\x03\x11\x07%\t#\x05\x03\x81+\x03\x05\x19\x07\x8d\x85\x03\x01\x05\x1d)\x13\x07\x03\x91\x037\x05+'\x07\x07\x03?\x03\x01\x03-\x07\x07\x03A\x03\x03\x03-\x07\x07\x03C\x03\x1f\x03-\x05\x03\x03E\x03\x03\x03\x07\x03\x07\x03\x03\x035\r\x07\x03G\x03\x19\x0517\x03\x07\x03\x07\x03\x1b\x039\x05\x03\x03\x15\x03\x05\x03\x07\x03\x07\x03\x01\x03=\x03\x07\x03I\x03\x13\x03;\t\x06\x03\x03\x01\x07A/?\x1b\x07\t\x95\x03\x01\x03\x1d\x11\x04\x0b\x05CE\x0f\x11\tQ\x05\x03\x15+\x03\x01\x0b\x0b\x03U!\x03\x0f\x05\x03\tY\x03\x03\x03\x07%\x07\x03\x0f\x03\x05\x15\x06%\x03\x0f\x05\x03\x07\x0b\x03_]\x03\x0f\r\x07ec\x03\x13\x05\t\x0b\x05\x03\t+\x03\x05\x03\x07i\x07\x03\x01\x03\x0f\t\x06m\x03\x01\x07\r\x11\x01\x11\x04\t\x03\x13\x06\x03\x01\x05\x01\x00\x86\x19\x83\x1f3\x1f+\x11\x0f\x0b\t\t\x0b!\x0fY\x87##%_=\x85\x87W\xb3K\x9bM\x9b\xd2\x02\x1b\x1f/!!)#\x1f\x19+\x1b\x1f\x83\x1f\x15\x1d\x15+\x13\r\r\x11\x0f\x17\x0f\x1f\x15\x11\x17\x11\x15+\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00get_tuple_element_v1\x00select_v1\x00iota_v1\x00compare_v1\x00func_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00index\x00sym_name\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00iota_dimension\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=float32 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]\x00jit()/jit(main)/householder_product\x00callee\x00jax.result_info\x00triu\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00\x00\x00\x00\x01\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x00\x00\x03\x00\x00cusolver_geqrf\x00\x00\x00\x00\x00\x01\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00 \x81\x00\x00\x00cusolver_orgqr\x00", - xla_call_module_version=4, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_03_18["batched"] = dict( - testdata_version=1, - platform='cuda', - custom_call_targets=['cublas_geqrf_batched', 'cusolver_orgqr'], - serialized_date=datetime.date(2023, 3, 18), - inputs=(), - expected_outputs=(array([[[ 0. , 0.91287094, 0.40824836], - [-0.4472136 , 0.36514843, -0.81649655], - [-0.8944272 , -0.18257417, 0.4082483 ]], - - [[-0.42426407, 0.80828977, 0.40824953], - [-0.5656854 , 0.11547142, -0.8164964 ], - [-0.7071068 , -0.5773508 , 0.4082474 ]]], dtype=float32), array([[[-6.7082038e+00, -8.0498447e+00, -9.3914852e+00], - [ 0.0000000e+00, 1.0954450e+00, 2.1908898e+00], - [ 0.0000000e+00, 0.0000000e+00, 4.8374091e-08]], - - [[-2.1213203e+01, -2.2910259e+01, -2.4607319e+01], - [ 0.0000000e+00, 3.4641042e-01, 6.9282258e-01], - [ 0.0000000e+00, 0.0000000e+00, 1.4548683e-06]]], dtype=float32)), - mlir_module_text=r""" -module @jit__lambda_ { - func.func public @main() -> (tensor<2x3x3xf32> {jax.result_info = "[0]"}, tensor<2x3x3xf32> {jax.result_info = "[1]"}) { - %0 = stablehlo.iota dim = 0 : tensor<18xf32> - %1 = stablehlo.reshape %0 : (tensor<18xf32>) -> tensor<2x3x3xf32> - %2 = stablehlo.custom_call @cublas_geqrf_batched(%1) {api_version = 2 : i32, backend_config = "\00\00\00\00\02\00\00\00\03\00\00\00\03\00\00\00", operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x3x3xf32>) -> tuple, tensor<2x3xf32>, tensor<16xi8>, tensor<16xi8>> - %3 = stablehlo.get_tuple_element %2[0] : (tuple, tensor<2x3xf32>, tensor<16xi8>, tensor<16xi8>>) -> tensor<2x3x3xf32> - %4 = stablehlo.get_tuple_element %2[1] : (tuple, tensor<2x3xf32>, tensor<16xi8>, tensor<16xi8>>) -> tensor<2x3xf32> - %5 = stablehlo.get_tuple_element %2[2] : (tuple, tensor<2x3xf32>, tensor<16xi8>, tensor<16xi8>>) -> tensor<16xi8> - %6 = stablehlo.get_tuple_element %2[3] : (tuple, tensor<2x3xf32>, tensor<16xi8>, tensor<16xi8>>) -> tensor<16xi8> - %7 = stablehlo.constant dense<0.000000e+00> : tensor - %8 = stablehlo.pad %3, %7, low = [0, 0, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<2x3x3xf32>, tensor) -> tensor<2x3x3xf32> - %9 = stablehlo.custom_call @cusolver_orgqr(%8, %4) {api_version = 2 : i32, backend_config = "\00\00\00\00\02\00\00\00\03\00\00\00\03\00\00\00\03\00\00\00 \81\00\00", operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x3x3xf32>, tensor<2x3xf32>) -> tuple, tensor<2xi32>, tensor<33056xf32>> - %10 = stablehlo.get_tuple_element %9[0] : (tuple, tensor<2xi32>, tensor<33056xf32>>) -> tensor<2x3x3xf32> - %11 = stablehlo.get_tuple_element %9[1] : (tuple, tensor<2xi32>, tensor<33056xf32>>) -> tensor<2xi32> - %12 = stablehlo.get_tuple_element %9[2] : (tuple, tensor<2xi32>, tensor<33056xf32>>) -> tensor<33056xf32> - %13 = stablehlo.constant dense<0> : tensor - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<2xi32> - %15 = stablehlo.compare EQ, %11, %14, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> - %16 = stablehlo.broadcast_in_dim %15, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> - %17 = stablehlo.constant dense<0x7FC00000> : tensor - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor) -> tensor<2x3x3xf32> - %19 = stablehlo.broadcast_in_dim %16, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x3x3xi1> - %20 = stablehlo.select %19, %10, %18 : tensor<2x3x3xi1>, tensor<2x3x3xf32> - %21 = call @triu(%3) : (tensor<2x3x3xf32>) -> tensor<2x3x3xf32> - return %20, %21 : tensor<2x3x3xf32>, tensor<2x3x3xf32> - } - func.func private @triu(%arg0: tensor<2x3x3xf32>) -> tensor<2x3x3xf32> { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> - %1 = stablehlo.constant dense<-1> : tensor - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<3x3xi32> - %3 = stablehlo.add %0, %2 : tensor<3x3xi32> - %4 = stablehlo.iota dim = 1 : tensor<3x3xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> - %6 = stablehlo.broadcast_in_dim %5, dims = [1, 2] : (tensor<3x3xi1>) -> tensor<2x3x3xi1> - %7 = stablehlo.constant dense<0.000000e+00> : tensor - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<2x3x3xf32> - %9 = stablehlo.select %6, %8, %arg0 : tensor<2x3x3xi1>, tensor<2x3x3xf32> - return %9 : tensor<2x3x3xf32> - } -} -""", - mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01+\x05\x01\x05\x01\x03\x05\x03\x1b\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f\x03\x96\x02\xff=\x01\x9f\x17\x0f\x0f\x0f\x07\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x13\x17\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0bK\x0b\x13\x0f\x0b#\x0b\x0b\x0b\x0f\x0bK\x0b\x13\x1b\x13\x13\x13\x13\x0b\x03ao/\x0b/\x0b\x0f\x0b\x0b\x0b\x0b\x0fO\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0bO\x1f\x0b\x0b\x0f\x17\x1b\x0f\x0f\x0f\x0f\x0b\x0b\x13\x17\x1f\x0b/\x1fo\x03=\x1b\x07\x07\x07\x0f\x17\x0f\x07\x13\x07\x13\x1b\x17\x13\x1b\x17\x17\x13\x17\x13\x13\x1b\x07\x13\x13\x13\x17\x13\x1b\x13\x02\x1a\n\x17\x1d\n\x06\x01\x1d\x8f\x01\x1dK\x01\x1dy\x01\x1f\x05!\x03\x03\x0f\xd1\x05#\x05%\x05'\x05)\x05+\x05-\x05/\x051\x03\x03!\xcd\x053\x1dS\x01\x055\x057\x03\x03\x0b\xd9\x17\x1d\x06\x06\x01\x059\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x03\x03\x11\xe5\x03\x03\x11\xe7\x03\x03\x11\xe9\x03\x03\x13E\x05I\x03\x0b\x15\xa3\x17\xb7\x19\xb9\x13\xc3\x1b\xc5\x03\x0b\x15\xa9\x17\xc9\x19\xa9\x13\xab\x1b\xcb\x05K\x1dO\x01\x05M\x03\x03\x0b\xcf\x05O\x03\x03!\xd3\x1dY\x01\x05Q\x03\x05%\xad'\xd5\x1d_\x01\x05S\x03\x03\x0f\xd7\x1de\x01\x05U\x1di\x01\x05W\x1dm\x01\x05Y\x1dq+\x05[\x1du+\x05]\x03\x11-\xaf/\xdb1\xdd3\xa35\xb17\xdf9\xb3;\xe3\x05_\x03\x03\x11\xeb\x1d\x7f\x01\x05a\x03\x07\x83\xa5\x85\xa5\x87\xa5\x05c\x05e\x05g\x1d\x8b\x01\x05i\x03\x11-\xaf/\xed1\xef3\xa35\xb17\xf19\xb3;\xf3\x05k\x03\x03\x0b\xf5\x03\x05%\xad'\xf7\x03\x03\x0f\xf9\x03\x03\x0b\xfb\x03\x03\x0f\xfd\x03\x03\x9d\xab\x05m\x1f/1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1f\x1b\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1do\x03\x03\xc7\x1dq\t\x07\x0b\x05\x05\x01\x03\x03\xe1\x1f1!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00#\x1f\x03\x05\xbb\xbf\r\x03\xa7\xbd\x1ds\r\x03\xa7\xc1\x1du\x1dw\x1dy\r\x01#!\x1d{\x13\x05\x01\x1f\r\t\xff\xff\xff\xff\x1f#\x01\x13\x05\x05\x07\x05\x1f'!\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f\t\t\x00\x00\x00\x00\x1d}\x1d\x7f\x03\x03\x9f\x15\x03\x01\x01\x01\x03\t\x9f\xb5\xa1\xa1\x13\x03\x01\x13\x03\x05\x13\x03\t\x13\x03\r\x1d\x81\x1d\x83\x03\x05\x9f\xb5\x03\x07\x9f\xa1\xa1\x1f\r\t\x00\x00\x00\x00\x07\x01\x1f;\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\t\t\x00\x00\xc0\x7f\x1f\x1b1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00)\x07\t\r\r\x07\x1b\x1d\t)\x01\x07)\x05\r\r\x03)\x01\x03\x01)\x03A-\x13)\x03\t\x03)\x07\t\r\r\x0f)\x05\t\r\x07)\x03\r\x05)\x03\x04\x12\x08\x07\x11\x01\x05\x01\x01\x11\x03\x01\x03\x01)\x03\x01\x05)\x05\r\r\x0f)\x03\t\x05)\x03I\x07/\t\x01\x19\x11\x11\x17)\x03\r\x13)\x03\t\x13)\x03\x05\x13/\x07\x01\x15\x1d)\x03\t\x0f)\x07\t\x05\x05\x0f)\x03\x05\x05\x04r\x04\x05\x01\x11\tC\x07\x03\x01\t\x0b\x11\tG\x05\x03-]\t\x03o\x1f\x03)\x17\x06s\x03\x01\x03\x01\x13\x07\x07w\x03+\x03\x03\x05\x07\x07=\x03\x01\x03\x05\x05\x07\x07?\x03\x19\x03\x05\x05\x07\x07A\x03\x11\x03\x05\x05\x07\x07{\x03\x11\x03\x05\x07\x03})\x03\t\x19\x07\x89\x81\x03\x01\x05\x07\x0f\x13\x07\x03\x8d\x035\x05\x11\t\x05\x07\x03=\x03\x01\x03\x13\x05\x07\x03?\x03\x15\x03\x13\x05\x07\x03A\x03\x1d\x03\x13\x07\x03\x03\x91\x03\r\x03\x07\x03\r\x03\x15\x03\x1b\r\x07\x03\x93\x037\x05\x17\x1d\x03\x07\x03\x95\x039\x03\x1f\x07\x03\x03\x97\x03\t\x03\x07\x03\r\x03\x01\x03#\x03\x07\x03\x99\x03\x17\x03!\x0f\x06\x03\x03\x01\x07'\x15%\x1b\x07\x05\x9b\x03\x01\x03\x07\x11\x04\t\x05)+\x0b\x11\x05I\x05\x03\x17/\x03\x01\t\t\x03M\x1f\x03\x0b\x07\x03\x05Q\x03\r\x03\x07#\r\x03\x0b\x03\x05\x15\x06#\x03\x0b\x05\x03\x07\t\x03WU\x03\x0b\r\x07][\x03%\x05\t\x0b\x03\x07ca\x03\x17\x03\r\x07\x03\x05)\x03\t\x03\x07g\r\x03\x01\x03\x11\x0f\x06k\x03\x01\x07\x0f\x13\x01\x11\x04\x05\x03\x15\x06\x03\x01\x05\x01\x00Z\x1b\x85\x1f3+#\x11\x0f\x0b\t\t\x0b!\x0fY\x9d##%_=\x8b\x89W\xb9\xc1K\x9bM\x9b\xd2\x02\x1b\x1f/!!)#\x1f\x19+\x1b\x1f\x83\x1f\x15\x1d\x15\x13\r+\r\x11\x0f\x17\x0f\x1f\x15\x15\x17\x11\x11\x19+)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00get_tuple_element_v1\x00constant_v1\x00iota_v1\x00func_v1\x00compare_v1\x00select_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00broadcast_dimensions\x00index\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00iota_dimension\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(2, 3, 3) broadcast_dimensions=(1, 2)]\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(2, 3, 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=float32 shape=(18,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(2, 3, 3) dimensions=None]\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0), (0, 0, 0))]\x00jit()/jit(main)/householder_product\x00callee\x00jax.result_info\x00triu\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x00cublas_geqrf_batched\x00\x00\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00 \x81\x00\x00\x00cusolver_orgqr\x00", - xla_call_module_version=4, -) # End paste data_2024_09_26 = {} - data_2024_09_26["f32"] = dict( testdata_version=1, platform='cuda', diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_svd_cusolver_gesvd.py b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_svd_cusolver_gesvd.py index 11e39f801dc6..7a4113d2852f 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_svd_cusolver_gesvd.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_svd_cusolver_gesvd.py @@ -16,7 +16,11 @@ # ruff: noqa import datetime -from numpy import array, float32, complex64 +import numpy as np + +array = np.array +float32 = np.float32 +complex64 = np.complex64 data_2024_10_08 = {"jacobi": {}, "qr": {}} diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_threefry2x32.py b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_threefry2x32.py index 3aa8e3eeb4cc..e3dcc379f01c 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_threefry2x32.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_threefry2x32.py @@ -13,7 +13,11 @@ # limitations under the License. import datetime -from numpy import array, float32, uint32 +import numpy as np + +array = np.array +float32 = np.float32 +uint32 = np.uint32 # Pasted from the test output (see export_back_compat_test_util.py module docstring) diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_tridiagonal_cusolver_sytrd.py b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_tridiagonal_cusolver_sytrd.py index b45d2c281fc8..8f442a434e8d 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_tridiagonal_cusolver_sytrd.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_tridiagonal_cusolver_sytrd.py @@ -15,7 +15,11 @@ # ruff: noqa import datetime -from numpy import array, float32, complex64 +import numpy as np + +array = np.array +float32 = np.float32 +complex64 = np.complex64 data_2025_01_09 = {} diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_tridiagonal_solve.py b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_tridiagonal_solve.py new file mode 100644 index 000000000000..6665996165ed --- /dev/null +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_tridiagonal_solve.py @@ -0,0 +1,87 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa + +import datetime +import numpy as np + +array = np.array +float32 = np.float32 + +data_2025_06_16 = {} + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2025_06_16["f32"] = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['cusparse_gtsv2_ffi'], + serialized_date=datetime.date(2025, 6, 16), + inputs=(array([0., 2., 3.], dtype=float32), array([1., 1., 1.], dtype=float32), array([1., 2., 0.], dtype=float32), array([[1.], + [1.], + [1.]], dtype=float32)), + expected_outputs=(array([[ 0.57142854], + [ 0.42857146], + [-0.2857143 ]], dtype=float32),), + mlir_module_text=r""" +#loc1 = loc("dl") +#loc2 = loc("d") +#loc3 = loc("du") +#loc4 = loc("b") +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<3xf32> loc("dl"), %arg1: tensor<3xf32> loc("d"), %arg2: tensor<3xf32> loc("du"), %arg3: tensor<3x1xf32> loc("b")) -> (tensor<3x1xf32> {jax.result_info = "result"}) { + %0 = stablehlo.custom_call @cusparse_gtsv2_ffi(%arg0, %arg1, %arg2, %arg3) {mhlo.backend_config = {}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3x1xf32>) -> tensor<3x1xf32> loc(#loc6) + return %0 : tensor<3x1xf32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc5 = loc("third_party/py/jax/tests/export_back_compat_test.py":760:13) +#loc6 = loc("jit(func)/jit(main)/tridiagonal_solve"(#loc5)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.10.4\x00\x01\x19\x05\x01\x05\t\x01\x03\x0b\x03\x07\x0f\x13\x17\x03\x83]\x13\x01/\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x1b\x0b\x0b\x0f\x0b\x17\x0b\x03/\x0b/O\x1b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x1b\x0f\x13\x0f\x01\x05\x0b\x0f\x03\x0f\x13\x17\x07\x07#\x13\x13\x02\xea\x02\x1f\x11\x03\x05\x03\x07\x07\t\x0b\x03\r\x03\x05\r\x11\x01\x00\x05\x0f\x05\x11\x05\x13\x1d\x13\x01\x05\x15\x1d\x17\x01\x05\x17\x1d\x1b\x01\x05\x19\x1d\x1f\x01\x05\x1b\x03\x05#/%E\x05\x1d\x05\x1f\x1d)+\x05!\x17-\xe2\x0b\x1b\x05#\r\x01\x1f\x0f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x11!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\t////#\r\x03\x03;\r\x03=?\x1d%\x1d'\x1d)\x1d+\r\x03GI\x1d-\x1d/\x0b\x03\x1d1\x1d3\x03\x01\x05\x01\x03\t1113\x03\x03Y\x15\x01\r\x01\x03\x033\x01\t\x01\x02\x02)\x03\r\t)\x05\r\x05\t\t\x13\x11\t\x05\x05\x05\x07\x03\x07)\x03\x05\x0b)\x03\t\x0b\x04c\x05\x01Q\x01\x05\x01\x07\x04Q\x03\x01\x05\x03P\x01\x03\x07\x04=\x03\x0b\x0b\t\x0b\x11\x0b\x15\x0b\x19\x0f\x1d\x00\x05G'!\x05\x03\x07\t\x01\x03\x05\x07\x07\x04\x01\x03\t\x06\x03\x01\x05\x01\x00\xd2\x055'\x03\x05\x1f\x0f\x0b\x0f!iM3)\x05\x07\x05\x07\x13%)9\x15\x1f\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00dl\x00d\x00du\x00b\x00mhlo.backend_config\x00mhlo.frontend_attributes\x00jit(func)/jit(main)/tridiagonal_solve\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.result_info\x00result\x00main\x00public\x00num_batch_dims\x000\x00\x00cusparse_gtsv2_ffi\x00\x08'\x07\x05\x1f\x01\x0b579AC\x11KMOQSUW[", + xla_call_module_version=9, + nr_devices=1, +) # End paste + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2025_06_16["f64"] = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['cusparse_gtsv2_ffi'], + serialized_date=datetime.date(2025, 6, 16), + inputs=(array([0., 2., 3.]), array([1., 1., 1.]), array([1., 2., 0.]), array([[1.], + [1.], + [1.]])), + expected_outputs=(array([[ 0.5714285714285714 ], + [ 0.42857142857142855], + [-0.2857142857142857 ]]),), + mlir_module_text=r""" +#loc1 = loc("dl") +#loc2 = loc("d") +#loc3 = loc("du") +#loc4 = loc("b") +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<3xf64> loc("dl"), %arg1: tensor<3xf64> loc("d"), %arg2: tensor<3xf64> loc("du"), %arg3: tensor<3x1xf64> loc("b")) -> (tensor<3x1xf64> {jax.result_info = "result"}) { + %0 = stablehlo.custom_call @cusparse_gtsv2_ffi(%arg0, %arg1, %arg2, %arg3) {mhlo.backend_config = {}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<3xf64>, tensor<3xf64>, tensor<3xf64>, tensor<3x1xf64>) -> tensor<3x1xf64> loc(#loc6) + return %0 : tensor<3x1xf64> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc5 = loc("third_party/py/jax/tests/export_back_compat_test.py":760:13) +#loc6 = loc("jit(func)/jit(main)/tridiagonal_solve"(#loc5)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.10.4\x00\x01\x19\x05\x01\x05\t\x01\x03\x0b\x03\x07\x0f\x13\x17\x03\x83]\x13\x01/\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x1b\x0b\x0b\x0f\x0b\x17\x0b\x03/\x0b/O\x1b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x1b\x0f\x13\x0f\x01\x05\x0b\x0f\x03\x0f\x13\x17\x07\x07#\x13\x13\x02\xea\x02\x1f\x11\x03\x05\x03\x07\x07\t\x0b\x03\r\x03\x05\r\x11\x01\x00\x05\x0f\x05\x11\x05\x13\x1d\x13\x01\x05\x15\x1d\x17\x01\x05\x17\x1d\x1b\x01\x05\x19\x1d\x1f\x01\x05\x1b\x03\x05#/%E\x05\x1d\x05\x1f\x1d)+\x05!\x17-\xe2\x0b\x1b\x05#\r\x01\x1f\x0f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x11!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\t////#\r\x03\x03;\r\x03=?\x1d%\x1d'\x1d)\x1d+\r\x03GI\x1d-\x1d/\x0b\x03\x1d1\x1d3\x03\x01\x05\x01\x03\t1113\x03\x03Y\x15\x01\r\x01\x03\x033\x01\t\x01\x02\x02)\x03\r\t)\x05\r\x05\t\x0b\x13\x11\t\x05\x05\x05\x07\x03\x07)\x03\x05\x0b)\x03\t\x0b\x04c\x05\x01Q\x01\x05\x01\x07\x04Q\x03\x01\x05\x03P\x01\x03\x07\x04=\x03\x0b\x0b\t\x0b\x11\x0b\x15\x0b\x19\x0f\x1d\x00\x05G'!\x05\x03\x07\t\x01\x03\x05\x07\x07\x04\x01\x03\t\x06\x03\x01\x05\x01\x00\xd2\x055'\x03\x05\x1f\x0f\x0b\x0f!iM3)\x05\x07\x05\x07\x13%)9\x15\x1f\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00dl\x00d\x00du\x00b\x00mhlo.backend_config\x00mhlo.frontend_attributes\x00jit(func)/jit(main)/tridiagonal_solve\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.result_info\x00result\x00main\x00public\x00num_batch_dims\x000\x00\x00cusparse_gtsv2_ffi\x00\x08'\x07\x05\x1f\x01\x0b579AC\x11KMOQSUW[", + xla_call_module_version=9, + nr_devices=1, +) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/export_with_memory_space.py b/jax/_src/internal_test_util/export_back_compat_test_data/export_with_memory_space.py new file mode 100644 index 000000000000..9d4500941494 --- /dev/null +++ b/jax/_src/internal_test_util/export_back_compat_test_data/export_with_memory_space.py @@ -0,0 +1,28 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa + +# Pasted from the test output (see export_serialization_back_compat_test.py module docstring) +serializations = [ + dict( + serialization_version=5, + exported_serialized=bytearray(b"0\x00\x00\x00\x00\x00*\x00L\x00J\x00D\x00@\x00<\x008\x004\x00.\x00(\x00$\x00 \x00\x1c\x00\x18\x00\x14\x00\x10\x00\x0e\x00\x08\x00\x07\x00\x00\x000\x00*\x00\x00\x00\x00\x00\x00\x01D\x00\x00\x00\x00\x00\n\x00D\x00\x00\x00\x84\x02\x00\x00\x84\x02\x00\x00\x84\x02\x00\x00X\x02\x00\x00\x80\x02\x00\x00\x88\x02\x00\x00\x00\x00\x02\x00\x02\x00\x00\x00\xa0\x02\x00\x00\xcc\x02\x00\x00\xcc\x02\x00\x00\x04\x03\x00\x00X\x03\x00\x00\x00\x00\x03\x00\x01\x00\x00\x00\x00\x00\x00\x00 \x02\x00\x00ML\xefR\rStableHLO_v1.13.0\x00\x01\x1f\x07\x01\x05\t\t\x01\x03\x0f\x03\x03\x13\x05\x05\x17\x1b\x03kE\x0f\x01\x1b\x07\x0b#\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x0b\x03\r\x13\x0f\x1b\x17\x0f\x13\x05\x1f\x0b\x0b\x13\x13\x0b\x0b\x1b\x0b\x0b\x0f\x1b\x0b\x0b\x0b\x0b\x01\x05\x0f\x0b\x05\x0b\x17\x0f\x1b\x07\x07\x02\xf9\x1f\x05\t\x03\x07\x07\t\x0b\r\x0f\x11\x05\x0f\x11\x03\x01\x05\x11\x11\x01\t\x05\x13\x11\x01\x05\x05\x15\t\x03\x1d\x19\x01\x05\x17\x05\x03\x1d\x01\x03\x17\t\r\x15\x05!%\x01\x0b\x03#\x01\x01\t\x17\x01\x0b\x01\x01\x01\x1d\x19\x1d\x1b\x03\x05-3\r\x03/1\x1d\x1d\x1d\x1f\r\x05\')5\x1f\x1d!#\t\x03\x03;\r\x05=?\')\x1d#\x1d%\x1d\'\x1d)\x01\x02\x02\x01\t)\x05\t\r\r)\x01\x0b\x11\x05\x07\x05\x03\x05\x1b\t\x04I\x05\x01Q\x01\x05\x01\x07\x047\x03\x01\t\x03@\x01\x03\x05P\x01\x05\x07\x04\x1b\x03\x05\x07\x05\r\x0b\x17\x00\x07\x04\x01\x03\x03\x06\x03\x01\x05\x01\x00\x1a\x04+\x0f\x0b\x0f!\x1b!)\x19#\x05\x19%)9\x15\x11\x0b\x0f\x0b\t\x11builtin\x00sdy\x00vhlo\x00module\x00mesh\x00func_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda\x00x\x00mhlo.memory_kind\x00pinned_host\x00jax.global_constant\x00_platform_index\x00sdy.sharding\x00jax.result_info\x00result\x00main\x00public\x00\x08\x1b\x07\x05\'\x01\x05\x1b\x03\x0b+79AC\x02\x00\x00\x00\x14\x00\x00\x00\x04\x00\x00\x00\x04\x00\x00\x00cuda\x00\x00\x00\x00\x03\x00\x00\x00tpu\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x18\xff\xff\xff\x01\x00\x00\x00\x04\x00\x00\x00@\xff\xff\xff\x08\x00\x00\x00\x00\x00\x00\x01\x0c\x00\x00\x00\x08\x03\x1a\x02\x02\x01J\x01\x02R\x01\x00\x01\x00\x00\x00\x04\x00\x00\x00\xcc\xff\xff\xff\x00\x00\x02\n\x04\x00\x00\x00\x02\x00\x00\x00\x10\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x003\x00\x00\x00\x01\x00\x00\x002\x00\x00\x00p\xff\xff\xff\x01\x00\x00\x00\x10\x00\x00\x00\x0c\x00\x0c\x00\x00\x00\x08\x00\x07\x00\x06\x00\x0c\x00\x00\x00\x00\x00\x02\n\x04\x00\x00\x00\x02\x00\x00\x00\x10\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x003\x00\x00\x00\x01\x00\x00\x002\x00\x00\x00\xcc\xff\xff\xff\x08\x00\x00\x00\x00\x00\x00\x02\x02\x00\x00\x00,\x00\x00\x00\x10\x00\x00\x00\x00\x00\n\x00\x0c\x00\x0b\x00\x00\x00\x04\x00\n\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x08\x00\x0c\x00\x0b\x00\x04\x00\x08\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x02\x01\x00\x00\x00\x08\x00\x00\x00\x04\x00\x04\x00\x04\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x00"), + ), + + dict( + serialization_version=6, + exported_serialized=bytearray(b"0\x00\x00\x00\x00\x00*\x00L\x00J\x00D\x00@\x00<\x008\x004\x00.\x00(\x00$\x00 \x00\x1c\x00\x18\x00\x14\x00\x10\x00\x0e\x00\x08\x00\x07\x00\x00\x000\x00*\x00\x00\x00\x00\x00\x00\x01D\x00\x00\x00\x00\x00\n\x00D\x00\x00\x00\x84\x02\x00\x00\x84\x02\x00\x00\x84\x02\x00\x00X\x02\x00\x00\x80\x02\x00\x00\x88\x02\x00\x00\x00\x00\x02\x00\x02\x00\x00\x00X\x03\x00\x00\x84\x03\x00\x00\x84\x03\x00\x00\xbc\x03\x00\x00\x10\x04\x00\x00\x00\x00\x03\x00\x01\x00\x00\x00\x00\x00\x00\x00 \x02\x00\x00ML\xefR\rStableHLO_v1.13.0\x00\x01\x1f\x07\x01\x05\t\t\x01\x03\x0f\x03\x03\x13\x05\x05\x17\x1b\x03kE\x0f\x01\x1b\x07\x0b#\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x0b\x03\r\x13\x0f\x1b\x17\x0f\x13\x05\x1f\x0b\x0b\x13\x13\x0b\x0b\x1b\x0b\x0b\x0f\x1b\x0b\x0b\x0b\x0b\x01\x05\x0f\x0b\x05\x0b\x17\x0f\x1b\x07\x07\x02\xf9\x1f\x05\t\x03\x07\x07\t\x0b\r\x0f\x11\x05\x0f\x11\x03\x01\x05\x11\x11\x01\t\x05\x13\x11\x01\x05\x05\x15\t\x03\x1d\x19\x01\x05\x17\x05\x03\x1d\x01\x03\x17\t\r\x15\x05!%\x01\x0b\x03#\x01\x01\t\x17\x01\x0b\x01\x01\x01\x1d\x19\x1d\x1b\x03\x05-3\r\x03/1\x1d\x1d\x1d\x1f\r\x05\')5\x1f\x1d!#\t\x03\x03;\r\x05=?\')\x1d#\x1d%\x1d\'\x1d)\x01\x02\x02\x01\t)\x05\t\r\r)\x01\x0b\x11\x05\x07\x05\x03\x05\x1b\t\x04I\x05\x01Q\x01\x05\x01\x07\x047\x03\x01\t\x03@\x01\x03\x05P\x01\x05\x07\x04\x1b\x03\x05\x07\x05\r\x0b\x17\x00\x07\x04\x01\x03\x03\x06\x03\x01\x05\x01\x00\x1a\x04+\x0f\x0b\x0f!\x1b!)\x19#\x05\x19%)9\x15\x11\x0b\x0f\x0b\t\x11builtin\x00sdy\x00vhlo\x00module\x00mesh\x00func_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda\x00x\x00mhlo.memory_kind\x00pinned_host\x00jax.global_constant\x00_platform_index\x00sdy.sharding\x00jax.result_info\x00result\x00main\x00public\x00\x08\x1b\x07\x05\'\x01\x05\x1b\x03\x0b+79AC\x02\x00\x00\x00\x14\x00\x00\x00\x04\x00\x00\x00\x04\x00\x00\x00cuda\x00\x00\x00\x00\x03\x00\x00\x00tpu\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00`\xfe\xff\xff\x01\x00\x00\x00\x10\x00\x00\x00\x00\x00\n\x00\x10\x00\x0f\x00\x08\x00\x04\x00\n\x00\x00\x00\x1c\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x01\x0c\x00\x00\x00\x08\x03\x1a\x02\x02\x01J\x01\x02R\x01\x00\x92\xff\xff\xff\x0c\x00\x00\x00\x18\x00\x00\x00l\x00\x00\x00\x0b\x00\x00\x00pinned_host\x00\xb2\xff\xff\xff\x0c\x00\x00\x00\x0c\x00\x00\x00\x0c\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x1c\x00\x00\x00\x04\x00\x00\x00\xf2\xff\xff\xff\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x06\x00\x08\x00\x04\x00\x06\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00x\x00\n\x00\x10\x00\x0c\x00\x08\x00\x04\x00\n\x00\x00\x00\x0c\x00\x00\x00\x10\x00\x00\x00\x1c\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00x\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\xcc\xff\xff\xff\x00\x00\x02\n\x04\x00\x00\x00\x02\x00\x00\x00\x10\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x003\x00\x00\x00\x01\x00\x00\x002\x00\x00\x00p\xff\xff\xff\x01\x00\x00\x00\x10\x00\x00\x00\x0c\x00\x0c\x00\x00\x00\x08\x00\x07\x00\x06\x00\x0c\x00\x00\x00\x00\x00\x02\n\x04\x00\x00\x00\x02\x00\x00\x00\x10\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x003\x00\x00\x00\x01\x00\x00\x002\x00\x00\x00\xcc\xff\xff\xff\x08\x00\x00\x00\x00\x00\x00\x02\x02\x00\x00\x00,\x00\x00\x00\x10\x00\x00\x00\x00\x00\n\x00\x0c\x00\x0b\x00\x00\x00\x04\x00\n\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x08\x00\x0c\x00\x0b\x00\x04\x00\x08\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x02\x01\x00\x00\x00\x08\x00\x00\x00\x04\x00\x04\x00\x04\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x00"), + ), +] diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/export_with_specified_sharding.py b/jax/_src/internal_test_util/export_back_compat_test_data/export_with_specified_sharding.py new file mode 100644 index 000000000000..3e61911e2f08 --- /dev/null +++ b/jax/_src/internal_test_util/export_back_compat_test_data/export_with_specified_sharding.py @@ -0,0 +1,33 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa + +# Pasted from the test output (see export_serialization_back_compat_test.py module docstring) +serializations = [ + dict( + serialization_version=4, + exported_serialized=bytearray(b"(\x00\x00\x00$\x00D\x00B\x00<\x008\x004\x000\x00,\x00*\x00$\x00 \x00\x1c\x00\x18\x00\x14\x00\x10\x00\x0c\x00\n\x00\x04\x00$\x00\x00\x00@\x00\x00\x00\x00\x00\n\x00@\x00\x00\x00H\x06\x00\x00H\x06\x00\x00H\x06\x00\x00,\x06\x00\x00D\x06\x00\x00d\x06\x00\x00\x00\x00\x02\x00\x80\x06\x00\x00\xac\x06\x00\x00\xac\x06\x00\x00\xe4\x06\x00\x008\x07\x00\x00\x00\x00\x02\x00\x01\x00\x00\x00\x00\x00\x00\x00\xf6\x05\x00\x00ML\xefR\rStableHLO_v1.13.0\x00\x01%\x07\x01\x05\t\x0f\x01\x03\x0f\x03\x03\x13\x05\x0b\x17\x1b\x1f#\'\x03\xc7\x9f\x11\x01y\x07\x0b\x0b\x0b\x0b\x0f#\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x0f\x0b\x1b\x0f\x0f\x0b\x1b\x0f\x0f\x0b\x1b\x0f\x0f\x0b\x1f\x0b\x0f\x0f\x0b\x1f\x0f\x0f\x0b\x1b\x0b\x0f\x0f\x0b\x1b\x0b\x0f\x0f\x0b\x1f\x0b\x0f\x0f\x0b\x1f\x0f\x0b\x1f\x03\x0f\x17\x13\x13\x0f\x1b\x0f\x1b\x05\x19\x0b\x0f\x13\x0b\x0f\x1b\x0b\x0b\x0b\x0b\x1f\x0f\x01\x05\x0f\x0b\x05\r\x17\x07\x0f\x17\x13\x07\x02~\x04\x1f\x05\x15\x05\x17\x05\t\t\x07\x1d!#\x03\x07\x0f\x11\x13\x15\x17\x19\x05\x19\x11\x03\x00\x05\x1b\x11\x01\t\x05\x1d\x11\x01\x05\x05\x1f\x1d\x1f\x01\x05!\x05#\x15%+\x1d\')\x05%-\x03\x07\xb1\x1f+\x15-3\x1d/1\x05\'-\x03\x07w\x15]\x155;\x1d79\x05)-\x03\x07\xb7!_\x15=E\x1d?A\x05+-C\x07~\x05!K\x05-\x15GM\x1dIK\x05/-\x05\x07\xca\x02\x11-\x15OW\x1dQS\x051-U\x07\xf35g\x053\x15Ya\x1d[]\x055-_\x07\xf1\x1f\x99\x057\x15ck\x1deg\x059-i\x07\x02\x08\x1f\xab\x05;\x15ms\x1doq\x05=-\x05\x07\xda\x03!_\x1duw\x05?-\x05\x07b\x05KW\x0b\x03\x83\x01\x01\x0b\x01\x01\x01\x05\x03\x7f\x01\x03A\t\r\t\x05y{\x01\tA\x01\r\t\x05{y\x01\x1dC\x03\x03\x8b\r\x03\x87\x81#\x0b\x03\x03\x91\r\x05\x93\x95\x87\x85\x1dE\x1dG\x1dI\x1dK\x1f\t\t\x00\x00\x00@\x1f\r\x01\x01\x02\x02\x01\t)\x05A\x11\x07\t)\x01\x07\x11\x03\x05\x03\x05)\x03\x01\x0f\x1d\x04s\x05\x01Q\x01\r\x01\x07\x04a\x03\x01\t\x03@\x01\x03\x05P\x01\x05\x07\x04E\x03\t\x13\x03\x0b\x1d\x00\x07B\x01\x07\x03\t\tF\x0b\t\x03\x05\x03\x03\x0b\x06\x0b\x03\x05\x05\x01\x05\r\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00\xba\x0fM\x0f\x0b\x0f!\x1b\x05\'E\x9b)\x9f1\x9f\x17)\xa13QAg\x17\x05\r%)9\x9d\x91\x15\x19)\x19\x11\x0b\x0f\x0b\t\x11builtin\x00sdy\x00vhlo\x00module\x00mesh\x00func_v1\x00constant_v1\x00broadcast_in_dim_v1\x00multiply_v1\x00return_v1\x00/Users/necula/Source/jax/tests/export_serialization_back_compat_test.py\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/_pytest/runner.py\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_f\x00b\x00jit(f)/mul\x00CompatTest.test_with_specified_sharding..f\x00CompatTest.export_and_serialize\x00CompatTest.test_with_specified_sharding\x00TestCaseFunction.runtest\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/_pytest/unittest.py\x00pytest_runtest_call\x00_multicall\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/pluggy/_callers.py\x00PluginManager._hookexec\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/pluggy/_manager.py\x00HookCaller.__call__\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/pluggy/_hooks.py\x00call_and_report..\x00CallInfo.from_call\x00x\x00sdy.sharding\x00jax.result_info\x00result\x00main\x00public\x00\x08#\x0b\x057\x01\x05}\x07\x0b\x89\x8d\x8f\x97\x99\x03\x9b\x03\x9d\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x03\x00\x00\x00cpu\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x1c\xff\xff\xff\x08\x00\x00\x00\x00\x00\x00\x01\x0c\x00\x00\x00\x08\x03\x1a\x02\x01\x02J\x01\x02R\x01\x00\x01\x00\x00\x00\x04\x00\x00\x00@\xff\xff\xff\x08\x00\x00\x00\x00\x00\x00\x01\x0c\x00\x00\x00\x08\x03\x1a\x02\x02\x01J\x01\x02R\x01\x00\x01\x00\x00\x00\x04\x00\x00\x00\xca\xff\xff\xff\x00\x00\x00\n\x04\x00\x00\x00\x02\x00\x00\x00\x10\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x004\x00\x00\x00\x02\x00\x00\x0016\x00\x00p\xff\xff\xff\x01\x00\x00\x00\x10\x00\x00\x00\x00\x00\n\x00\x0c\x00\x00\x00\x08\x00\x07\x00\n\x00\x00\x00\x00\x00\x00\n\x04\x00\x00\x00\x02\x00\x00\x00\x10\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x004\x00\x00\x00\x02\x00\x00\x0016\x00\x00\xcc\xff\xff\xff\x08\x00\x00\x00\x00\x00\x00\x02\x02\x00\x00\x00,\x00\x00\x00\x10\x00\x00\x00\x00\x00\n\x00\x0c\x00\x0b\x00\x00\x00\x04\x00\n\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x08\x00\x0c\x00\x0b\x00\x04\x00\x08\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x02\x01\x00\x00\x00\x08\x00\x00\x00\x04\x00\x04\x00\x04\x00\x00\x00\x01\x00\x00\x00f\x00\x00\x00"), + ), # End paste + + dict( + serialization_version=5, + exported_serialized=bytearray(b"0\x00\x00\x00\x00\x00*\x00H\x00F\x00@\x00<\x008\x004\x000\x00*\x00$\x00 \x00\x1c\x00\x18\x00\x14\x00\x10\x00\x0c\x00\n\x00\x04\x00\x00\x00\x00\x00,\x00*\x00\x00\x00D\x00\x00\x00\x00\x00\n\x00D\x00\x00\x00L\x06\x00\x00L\x06\x00\x00L\x06\x00\x000\x06\x00\x00H\x06\x00\x00h\x06\x00\x00\x00\x00\x02\x00\x02\x00\x00\x00\x80\x06\x00\x00\xac\x06\x00\x00\xac\x06\x00\x00\xe4\x06\x00\x008\x07\x00\x00\x00\x00\x03\x00\x01\x00\x00\x00\x00\x00\x00\x00\xf6\x05\x00\x00ML\xefR\rStableHLO_v1.13.0\x00\x01%\x07\x01\x05\t\x0f\x01\x03\x0f\x03\x03\x13\x05\x0b\x17\x1b\x1f#\'\x03\xc7\x9f\x11\x01y\x07\x0b\x0b\x0b\x0b\x0f#\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x0f\x0b\x1b\x0f\x0f\x0b\x1b\x0f\x0f\x0b\x1b\x0f\x0f\x0b\x1f\x0b\x0f\x0f\x0b\x1f\x0f\x0f\x0b\x1b\x0b\x0f\x0f\x0b\x1b\x0b\x0f\x0f\x0b\x1f\x0b\x0f\x0f\x0b\x1f\x0f\x0b\x1f\x03\x0f\x17\x13\x13\x0f\x1b\x0f\x1b\x05\x19\x0b\x0f\x13\x0b\x0f\x1b\x0b\x0b\x0b\x0b\x1f\x0f\x01\x05\x0f\x0b\x05\r\x17\x07\x0f\x17\x13\x07\x02~\x04\x1f\x05\x15\x05\x17\x05\t\t\x07\x1d!#\x03\x07\x0f\x11\x13\x15\x17\x19\x05\x19\x11\x03\x00\x05\x1b\x11\x01\t\x05\x1d\x11\x01\x05\x05\x1f\x1d\x1f\x01\x05!\x05#\x15%+\x1d\')\x05%-\x03\x07\xa5\x1f+\x15-3\x1d/1\x05\'-\x03\x07k\x15]\x155;\x1d79\x05)-\x03\x07\xa9\x1d[\x15=E\x1d?A\x05+-C\x07~\x05!K\x05-\x15GM\x1dIK\x05/-\x05\x07\xca\x02\x11-\x15OW\x1dQS\x051-U\x07\xf35g\x053\x15Ya\x1d[]\x055-_\x07\xf1\x1f\x99\x057\x15ck\x1deg\x059-i\x07\x02\x08\x1f\xab\x05;\x15ms\x1doq\x05=-\x05\x07\xda\x03!_\x1duw\x05?-\x05\x07b\x05KW\x0b\x03\x83\x01\x01\x0b\x01\x01\x01\x05\x03\x7f\x01\x03A\t\r\t\x05y{\x01\tA\x01\r\t\x05{y\x01\x1dC\x03\x03\x8b\r\x03\x87\x81#\x0b\x03\x03\x91\r\x05\x93\x95\x87\x85\x1dE\x1dG\x1dI\x1dK\x1f\t\t\x00\x00\x00@\x1f\r\x01\x01\x02\x02\x01\t)\x05A\x11\x07\t)\x01\x07\x11\x03\x05\x03\x05)\x03\x01\x0f\x1d\x04s\x05\x01Q\x01\r\x01\x07\x04a\x03\x01\t\x03@\x01\x03\x05P\x01\x05\x07\x04E\x03\t\x13\x03\x0b\x1d\x00\x07B\x01\x07\x03\t\tF\x0b\t\x03\x05\x03\x03\x0b\x06\x0b\x03\x05\x05\x01\x05\r\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00\xba\x0fM\x0f\x0b\x0f!\x1b\x05\'E\x9b)\x9f1\x9f\x17)\xa13QAg\x17\x05\r%)9\x9d\x91\x15\x19)\x19\x11\x0b\x0f\x0b\t\x11builtin\x00sdy\x00vhlo\x00module\x00mesh\x00func_v1\x00constant_v1\x00broadcast_in_dim_v1\x00multiply_v1\x00return_v1\x00/Users/necula/Source/jax/tests/export_serialization_back_compat_test.py\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/_pytest/runner.py\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_f\x00b\x00jit(f)/mul\x00CompatTest.test_with_specified_sharding..f\x00CompatTest.export_and_serialize\x00CompatTest.test_with_specified_sharding\x00TestCaseFunction.runtest\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/_pytest/unittest.py\x00pytest_runtest_call\x00_multicall\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/pluggy/_callers.py\x00PluginManager._hookexec\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/pluggy/_manager.py\x00HookCaller.__call__\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/pluggy/_hooks.py\x00call_and_report..\x00CallInfo.from_call\x00x\x00sdy.sharding\x00jax.result_info\x00result\x00main\x00public\x00\x08#\x0b\x057\x01\x05}\x07\x0b\x89\x8d\x8f\x97\x99\x03\x9b\x03\x9d\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x03\x00\x00\x00cpu\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x1c\xff\xff\xff\x08\x00\x00\x00\x00\x00\x00\x01\x0c\x00\x00\x00\x08\x03\x1a\x02\x01\x02J\x01\x02R\x01\x00\x01\x00\x00\x00\x04\x00\x00\x00@\xff\xff\xff\x08\x00\x00\x00\x00\x00\x00\x01\x0c\x00\x00\x00\x08\x03\x1a\x02\x02\x01J\x01\x02R\x01\x00\x01\x00\x00\x00\x04\x00\x00\x00\xcc\xff\xff\xff\x00\x00\x01\n\x04\x00\x00\x00\x02\x00\x00\x00\x10\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x004\x00\x00\x00\x02\x00\x00\x0016\x00\x00p\xff\xff\xff\x01\x00\x00\x00\x10\x00\x00\x00\x0c\x00\x0c\x00\x00\x00\x08\x00\x07\x00\x06\x00\x0c\x00\x00\x00\x00\x00\x01\n\x04\x00\x00\x00\x02\x00\x00\x00\x10\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x004\x00\x00\x00\x02\x00\x00\x0016\x00\x00\xcc\xff\xff\xff\x08\x00\x00\x00\x00\x00\x00\x02\x02\x00\x00\x00,\x00\x00\x00\x10\x00\x00\x00\x00\x00\n\x00\x0c\x00\x0b\x00\x00\x00\x04\x00\n\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x08\x00\x0c\x00\x0b\x00\x04\x00\x08\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x02\x01\x00\x00\x00\x08\x00\x00\x00\x04\x00\x04\x00\x04\x00\x00\x00\x01\x00\x00\x00f\x00\x00\x00"), + ), # End paste + + dict( + serialization_version=6, + exported_serialized=bytearray(b"0\x00\x00\x00\x00\x00*\x00H\x00F\x00@\x00<\x008\x004\x000\x00*\x00$\x00 \x00\x1c\x00\x18\x00\x14\x00\x10\x00\x0c\x00\n\x00\x04\x00\x00\x00\x00\x00,\x00*\x00\x00\x00D\x00\x00\x00\x00\x00\n\x00D\x00\x00\x00\xf0\x06\x00\x00\xf0\x06\x00\x00\xf0\x06\x00\x00\xd4\x06\x00\x00\xec\x06\x00\x00\xa4\x07\x00\x00\x00\x00\x02\x00\x02\x00\x00\x00p\x08\x00\x00\x9c\x08\x00\x00\x9c\x08\x00\x00\xd4\x08\x00\x00(\t\x00\x00\x00\x00\x03\x00\x01\x00\x00\x00\x00\x00\x00\x00\x9b\x06\x00\x00ML\xefR\rStableHLO_v1.13.0\x00\x01%\x07\x01\x05\t\x0f\x01\x03\x0f\x03\x03\x13\x05\x0b\x17\x1b\x1f#\'\x03\xc9\xa1\x11\x01{\x07\x0b\x0b\x0b\x0f\x0b#\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x0f\x0b\x1b\x0f\x0f\x0b\x1b\x0f\x0f\x0b\x1f\x0f\x0f\x0b\x1f\x0b\x0f\x0f\x0b\x1f\x0b\x0f\x0f\x0b\x1f\x0f\x0f\x0b\x1b\x0b\x0f\x0f\x0b\x1b\x0b\x0f\x0f\x0b\x1f\x0b\x0f\x0b\x1f\x03\x0f\x17\x13\x13\x0f\x1b\x0f\x1b\x05\x19\x0b\x0f\x13\x0b\x0f\x1b\x0b\x0b\x0b\x0b\x1f\x0f\x01\x05\x0f\x0b\x05\r\x17\x07\x0f\x17\x13\x07\x02\x8a\x04\x1f\x05\x15\x05\t\t\x05\x1d!#\x05\x17\x03\x07\x0f\x11\x13\x15\x17\x19\x05\x19\x11\x03\x00\x05\x1b\x11\x01\t\x05\x1d\x11\x01\x05\x05\x1f\x1d\x1f\x01\x05!\x05#\x15%+\x1d\')\x05%-\x03\x07\xfd\x1f+\x15-3\x1d/1\x05\'-\x03\x07\xb1\x15\x87\x155;\x1d79\x05)-\x03\x07\n\x02+i\x15=E\x1d?A\x05+-C\x07\x06\x05#k\x05-\x15GO\x1dIK\x05/-M\x07~\x05!K\x051\x15QW\x1dSU\x053-\x0b\x07\xca\x02\x11-\x15Ya\x1d[]\x055-_\x07\xf35g\x057\x15ck\x1deg\x059-i\x07\xf1\x1f\x99\x05;\x15mu\x1doq\x05=-s\x07\x02\x08\x1f\xab\x05?\x1dwy\x05A-\x0b\x07\xda\x03!_\x0b\x03\x85\x01\x01\x0b\x01\x01\x01\x05\x03\x81\x01\x03C\t\r\x07\x05{}\x01\tC\x01\r\x07\x05}{\x01\x1dE\x03\x03\x8d\r\x03\x89\x83#\x0b\x03\x03\x93\r\x05\x95\x97\x89\x87\x1dG\x1dI\x1dK\x1dM\x1f\t\t\x00\x00\x00@\x1f\r\x01\x01\x02\x02\x01\t)\x05A\x11\x07\t)\x01\x07\x11\x03\x05\x03\x05)\x03\x01\x0f\x1d\x04s\x05\x01Q\x01\r\x01\x07\x04a\x03\x01\t\x03@\x01\x03\x05P\x01\x05\x07\x04E\x03\t\x13\x03\x0b\x1d\x00\x07B\x01\x07\x03\t\tF\t\t\x03\x05\x03\x03\x0b\x06\t\x03\x05\x05\x01\x05\r\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00>\x12O\x0f\x0b\x0f!\x1b\x05E\x9b)\x9f1\x9f\x17)\xa13\xb5\xb3QAg\x17\x05\r%)9\x9d\x91\x15\x19)\x19\x11\x0b\x0f\x0b\t\x11builtin\x00sdy\x00vhlo\x00module\x00mesh\x00func_v1\x00constant_v1\x00broadcast_in_dim_v1\x00multiply_v1\x00return_v1\x00/Users/necula/Source/jax/tests/export_serialization_back_compat_test.py\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/_pytest/runner.py\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_f\x00b\x00jit(f)/mul\x00CompatTest.test_with_specified_sharding..f\x00CompatTest.export_and_serialize\x00CompatTest.test_with_specified_sharding\x00_ParameterizedTestIter.__iter__..make_bound_param_test..bound_param_test\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/absl/testing/parameterized.py\x00TestCaseFunction.runtest\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/_pytest/unittest.py\x00pytest_runtest_call\x00_multicall\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/pluggy/_callers.py\x00PluginManager._hookexec\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/pluggy/_manager.py\x00HookCaller.__call__\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/pluggy/_hooks.py\x00call_and_report..\x00x\x00sdy.sharding\x00jax.result_info\x00result\x00main\x00public\x00\x08#\x0b\x057\x01\x05\x7f\x05\x0b\x8b\x8f\x91\x99\x9b\x03\x9d\x03\x9f\x00\x01\x00\x00\x00\x04\x00\x00\x00\x03\x00\x00\x00cpu\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00B\xff\xff\xff\x1c\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x01\x0c\x00\x00\x00\x08\x03\x1a\x02\x01\x02J\x01\x02R\x01\x00\xce\xfe\xff\xff\x0c\x00\x00\x00\x14\x00\x00\x00X\x00\x00\x00\x06\x00\x00\x00device\x00\x00\xea\xfe\xff\xff\x0c\x00\x00\x00\x0c\x00\x00\x00\x0c\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00 \x00\x00\x00\x04\x00\x00\x00*\xff\xff\xff\x04\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00x\x00\x00\x00B\xff\xff\xff\x04\x00\x00\x00\x00\x00\x00\x002\xff\xff\xff\x0c\x00\x00\x00\x10\x00\x00\x00\x1c\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00x\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x01\x00\x00\x00\x10\x00\x00\x00\x00\x00\n\x00\x10\x00\x0f\x00\x08\x00\x04\x00\n\x00\x00\x00\x1c\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x01\x0c\x00\x00\x00\x08\x03\x1a\x02\x02\x01J\x01\x02R\x01\x00\x96\xff\xff\xff\x0c\x00\x00\x00\x14\x00\x00\x00h\x00\x00\x00\x06\x00\x00\x00device\x00\x00\xb2\xff\xff\xff\x0c\x00\x00\x00\x0c\x00\x00\x00\x0c\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x1c\x00\x00\x00\x04\x00\x00\x00\xf2\xff\xff\xff\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x06\x00\x08\x00\x04\x00\x06\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00x\x00\n\x00\x10\x00\x0c\x00\x08\x00\x04\x00\n\x00\x00\x00\x0c\x00\x00\x00\x10\x00\x00\x00\x1c\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00x\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\xcc\xff\xff\xff\x00\x00\x01\n\x04\x00\x00\x00\x02\x00\x00\x00\x10\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x004\x00\x00\x00\x02\x00\x00\x0016\x00\x00p\xff\xff\xff\x01\x00\x00\x00\x10\x00\x00\x00\x0c\x00\x0c\x00\x00\x00\x08\x00\x07\x00\x06\x00\x0c\x00\x00\x00\x00\x00\x01\n\x04\x00\x00\x00\x02\x00\x00\x00\x10\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x004\x00\x00\x00\x02\x00\x00\x0016\x00\x00\xcc\xff\xff\xff\x08\x00\x00\x00\x00\x00\x00\x02\x02\x00\x00\x00,\x00\x00\x00\x10\x00\x00\x00\x00\x00\n\x00\x0c\x00\x0b\x00\x00\x00\x04\x00\n\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x08\x00\x0c\x00\x0b\x00\x04\x00\x08\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x02\x01\x00\x00\x00\x08\x00\x00\x00\x04\x00\x04\x00\x04\x00\x00\x00\x01\x00\x00\x00f\x00\x00\x00"), + ), +] diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/export_with_unspecified_sharding.py b/jax/_src/internal_test_util/export_back_compat_test_data/export_with_unspecified_sharding.py new file mode 100644 index 000000000000..2df76e21016a --- /dev/null +++ b/jax/_src/internal_test_util/export_back_compat_test_data/export_with_unspecified_sharding.py @@ -0,0 +1,34 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa + + +# Pasted from the test output (see export_serialization_back_compat_test.py module docstring) +serializations = [ + dict( + serialization_version=4, + exported_serialized=bytearray(b"(\x00\x00\x00$\x00D\x00B\x00<\x008\x004\x000\x00,\x00*\x00$\x00 \x00\x1c\x00\x18\x00\x14\x00\x10\x00\x0c\x00\n\x00\x04\x00$\x00\x00\x00@\x00\x00\x00\x00\x00\n\x00@\x00\x00\x00D\x06\x00\x00D\x06\x00\x00D\x06\x00\x00(\x06\x00\x00@\x06\x00\x00H\x06\x00\x00\x00\x00\x02\x00d\x06\x00\x00\x90\x06\x00\x00\x90\x06\x00\x00\xc8\x06\x00\x00\x1c\x07\x00\x00\x00\x00\x02\x00\x01\x00\x00\x00\x00\x00\x00\x00\xf1\x05\x00\x00ML\xefR\rStableHLO_v1.13.0\x00\x01%\x07\x01\x05\t\x0f\x01\x03\x0f\x03\x03\x13\x05\x0b\x17\x1b\x1f#\'\x03\xc5\x9d\x11\x01y\x07\x0b\x0b\x0b\x0f#\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x0b\x0b\x0f\x0f\x0b\x1b\x0f\x0f\x0b\x1b\x0f\x0f\x0b\x1b\x0f\x0f\x0b\x1f\x0b\x0f\x0f\x0b\x1f\x0f\x0f\x0b\x1b\x0b\x0f\x0f\x0b\x1b\x0b\x0f\x0f\x0b\x1f\x0b\x0f\x0f\x0b\x1f\x0f\x0b\x1f\x03\r\x13\x0f\x1b\x17\x0f\x13\x05\x19\x0f\x13\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x1f\x0f\x01\x05\x0f\x0b\x05\r\x17\x07\x0f\x17\x13\x07\x02^\x04\x1f\x05\x15\x05\x17\x05\t\x1d!#\x03\x07\r\x0f\x11\x13\x15\x17\x05\x19\x11\x03\x00\x05\x1b\x11\x01\t\x05\x1d\x11\x01\x05\x05\x1f\t\x07\x1d\x1f\x01\x05!\x05#\x15%+\x1d\')\x05%-\x03\x07\xed\x1f+\x15-3\x1d/1\x05\'-\x03\x07w\x15]\x155;\x1d79\x05)-\x03\x07\xf5!_\x15=E\x1d?A\x05+-C\x07~\x05!K\x05-\x15GM\x1dIK\x05/-\x05\x07\xca\x02\x11-\x15OW\x1dQS\x051-U\x07\xf35g\x053\x15Ya\x1d[]\x055-_\x07\xf1\x1f\x99\x057\x15ck\x1deg\x059-i\x07\x02\x08\x1f\xab\x05;\x15ms\x1doq\x05=-\x05\x07\xda\x03!_\x1duw\x05?-\x05\x07b\x05KW\x05\x03{\x01\x03A\t\r\x1b\x05\x7f\x83\x01\x0b\x03\x81\x01\x01\tA\x01\x0b\x01\x01\x01\x03\x03\x87\r\x03\x89}\x1dC#\x0b\x03\x03\x8f\r\x03\x91\x93\x1dE\x1dG\x1dI\x1dK\x1f\t\t\x00\x00\x00@\x1f\r\x01\x01\x02\x02\x01\t)\x05A\x11\x07\t)\x01\x07\x11\x03\x05\x03\x05)\x03\x01\x0f\x1d\x04s\x05\x01Q\x01\x0b\x01\x07\x04a\x03\x01\t\x03@\x01\x03\x05P\x01\x05\x07\x04E\x03\t\x13\x03\x0b\x1d\x00\x07B\x01\x07\x03\t\tF\t\t\x03\x05\x03\x03\x0b\x06\t\x03\x05\x05\x01\x05\r\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00\xca\x0fM\x0f\x0b\x0f!\x1b\x05\'E\x9b)\x9f1\x9f\x17)\xa13UAk\x17\x05\r%)9\x9d\x91\x15\x19)\x19\x11\x0b\x0f\x0b\t\x11builtin\x00sdy\x00vhlo\x00module\x00mesh\x00func_v1\x00constant_v1\x00broadcast_in_dim_v1\x00multiply_v1\x00return_v1\x00/Users/necula/Source/jax/tests/export_serialization_back_compat_test.py\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/_pytest/runner.py\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_f\x00b\x00jit(f)/mul\x00CompatTest.test_with_unspecified_sharding..f\x00CompatTest.export_and_serialize\x00CompatTest.test_with_unspecified_sharding\x00TestCaseFunction.runtest\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/_pytest/unittest.py\x00pytest_runtest_call\x00_multicall\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/pluggy/_callers.py\x00PluginManager._hookexec\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/pluggy/_manager.py\x00HookCaller.__call__\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/pluggy/_hooks.py\x00call_and_report..\x00CallInfo.from_call\x00x\x00sdy.sharding\x00jax.result_info\x00result\x00main\x00public\x00\x08#\x0b\x053\x01\x05y\x07\x0b\x85\x8b\x8d\x95\x97\x03\x99\x03\x9b\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x03\x00\x00\x00cpu\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x18\xff\xff\xff\x01\x00\x00\x00\x04\x00\x00\x00@\xff\xff\xff\x08\x00\x00\x00\x00\x00\x00\x01\x0c\x00\x00\x00\x08\x03\x1a\x02\x02\x01J\x01\x02R\x01\x00\x01\x00\x00\x00\x04\x00\x00\x00\xca\xff\xff\xff\x00\x00\x00\n\x04\x00\x00\x00\x02\x00\x00\x00\x10\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x004\x00\x00\x00\x02\x00\x00\x0016\x00\x00p\xff\xff\xff\x01\x00\x00\x00\x10\x00\x00\x00\x00\x00\n\x00\x0c\x00\x00\x00\x08\x00\x07\x00\n\x00\x00\x00\x00\x00\x00\n\x04\x00\x00\x00\x02\x00\x00\x00\x10\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x004\x00\x00\x00\x02\x00\x00\x0016\x00\x00\xcc\xff\xff\xff\x08\x00\x00\x00\x00\x00\x00\x02\x02\x00\x00\x00,\x00\x00\x00\x10\x00\x00\x00\x00\x00\n\x00\x0c\x00\x0b\x00\x00\x00\x04\x00\n\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x08\x00\x0c\x00\x0b\x00\x04\x00\x08\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x02\x01\x00\x00\x00\x08\x00\x00\x00\x04\x00\x04\x00\x04\x00\x00\x00\x01\x00\x00\x00f\x00\x00\x00"), + ), + + dict( + serialization_version=5, + exported_serialized=bytearray(b"0\x00\x00\x00\x00\x00*\x00H\x00F\x00@\x00<\x008\x004\x000\x00*\x00$\x00 \x00\x1c\x00\x18\x00\x14\x00\x10\x00\x0c\x00\n\x00\x04\x00\x00\x00\x00\x00,\x00*\x00\x00\x00D\x00\x00\x00\x00\x00\n\x00D\x00\x00\x00H\x06\x00\x00H\x06\x00\x00H\x06\x00\x00,\x06\x00\x00D\x06\x00\x00L\x06\x00\x00\x00\x00\x02\x00\x02\x00\x00\x00d\x06\x00\x00\x90\x06\x00\x00\x90\x06\x00\x00\xc8\x06\x00\x00\x1c\x07\x00\x00\x00\x00\x03\x00\x01\x00\x00\x00\x00\x00\x00\x00\xf1\x05\x00\x00ML\xefR\rStableHLO_v1.13.0\x00\x01%\x07\x01\x05\t\x0f\x01\x03\x0f\x03\x03\x13\x05\x0b\x17\x1b\x1f#\'\x03\xc5\x9d\x11\x01y\x07\x0b\x0b\x0b\x0f#\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x0b\x0b\x0f\x0f\x0b\x1b\x0f\x0f\x0b\x1b\x0f\x0f\x0b\x1b\x0f\x0f\x0b\x1f\x0b\x0f\x0f\x0b\x1f\x0f\x0f\x0b\x1b\x0b\x0f\x0f\x0b\x1b\x0b\x0f\x0f\x0b\x1f\x0b\x0f\x0f\x0b\x1f\x0f\x0b\x1f\x03\r\x13\x0f\x1b\x17\x0f\x13\x05\x19\x0f\x13\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x1f\x0f\x01\x05\x0f\x0b\x05\r\x17\x07\x0f\x17\x13\x07\x02^\x04\x1f\x05\x15\x05\x17\x05\t\x1d!#\x03\x07\r\x0f\x11\x13\x15\x17\x05\x19\x11\x03\x00\x05\x1b\x11\x01\t\x05\x1d\x11\x01\x05\x05\x1f\t\x07\x1d\x1f\x01\x05!\x05#\x15%+\x1d\')\x05%-\x03\x07\xdf\x1f+\x15-3\x1d/1\x05\'-\x03\x07w\x15]\x155;\x1d79\x05)-\x03\x07\xe7!_\x15=E\x1d?A\x05+-C\x07~\x05!K\x05-\x15GM\x1dIK\x05/-\x05\x07\xca\x02\x11-\x15OW\x1dQS\x051-U\x07\xf35g\x053\x15Ya\x1d[]\x055-_\x07\xf1\x1f\x99\x057\x15ck\x1deg\x059-i\x07\x02\x08\x1f\xab\x05;\x15ms\x1doq\x05=-\x05\x07\xda\x03!_\x1duw\x05?-\x05\x07b\x05KW\x05\x03{\x01\x03A\t\r\x1b\x05\x7f\x83\x01\x0b\x03\x81\x01\x01\tA\x01\x0b\x01\x01\x01\x03\x03\x87\r\x03\x89}\x1dC#\x0b\x03\x03\x8f\r\x03\x91\x93\x1dE\x1dG\x1dI\x1dK\x1f\t\t\x00\x00\x00@\x1f\r\x01\x01\x02\x02\x01\t)\x05A\x11\x07\t)\x01\x07\x11\x03\x05\x03\x05)\x03\x01\x0f\x1d\x04s\x05\x01Q\x01\x0b\x01\x07\x04a\x03\x01\t\x03@\x01\x03\x05P\x01\x05\x07\x04E\x03\t\x13\x03\x0b\x1d\x00\x07B\x01\x07\x03\t\tF\t\t\x03\x05\x03\x03\x0b\x06\t\x03\x05\x05\x01\x05\r\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00\xca\x0fM\x0f\x0b\x0f!\x1b\x05\'E\x9b)\x9f1\x9f\x17)\xa13UAk\x17\x05\r%)9\x9d\x91\x15\x19)\x19\x11\x0b\x0f\x0b\t\x11builtin\x00sdy\x00vhlo\x00module\x00mesh\x00func_v1\x00constant_v1\x00broadcast_in_dim_v1\x00multiply_v1\x00return_v1\x00/Users/necula/Source/jax/tests/export_serialization_back_compat_test.py\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/_pytest/runner.py\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_f\x00b\x00jit(f)/mul\x00CompatTest.test_with_unspecified_sharding..f\x00CompatTest.export_and_serialize\x00CompatTest.test_with_unspecified_sharding\x00TestCaseFunction.runtest\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/_pytest/unittest.py\x00pytest_runtest_call\x00_multicall\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/pluggy/_callers.py\x00PluginManager._hookexec\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/pluggy/_manager.py\x00HookCaller.__call__\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/pluggy/_hooks.py\x00call_and_report..\x00CallInfo.from_call\x00x\x00sdy.sharding\x00jax.result_info\x00result\x00main\x00public\x00\x08#\x0b\x053\x01\x05y\x07\x0b\x85\x8b\x8d\x95\x97\x03\x99\x03\x9b\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x03\x00\x00\x00cpu\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x18\xff\xff\xff\x01\x00\x00\x00\x04\x00\x00\x00@\xff\xff\xff\x08\x00\x00\x00\x00\x00\x00\x01\x0c\x00\x00\x00\x08\x03\x1a\x02\x02\x01J\x01\x02R\x01\x00\x01\x00\x00\x00\x04\x00\x00\x00\xcc\xff\xff\xff\x00\x00\x01\n\x04\x00\x00\x00\x02\x00\x00\x00\x10\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x004\x00\x00\x00\x02\x00\x00\x0016\x00\x00p\xff\xff\xff\x01\x00\x00\x00\x10\x00\x00\x00\x0c\x00\x0c\x00\x00\x00\x08\x00\x07\x00\x06\x00\x0c\x00\x00\x00\x00\x00\x01\n\x04\x00\x00\x00\x02\x00\x00\x00\x10\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x004\x00\x00\x00\x02\x00\x00\x0016\x00\x00\xcc\xff\xff\xff\x08\x00\x00\x00\x00\x00\x00\x02\x02\x00\x00\x00,\x00\x00\x00\x10\x00\x00\x00\x00\x00\n\x00\x0c\x00\x0b\x00\x00\x00\x04\x00\n\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x08\x00\x0c\x00\x0b\x00\x04\x00\x08\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x02\x01\x00\x00\x00\x08\x00\x00\x00\x04\x00\x04\x00\x04\x00\x00\x00\x01\x00\x00\x00f\x00\x00\x00"), + ), + + dict( + serialization_version=6, + exported_serialized=bytearray(b"0\x00\x00\x00\x00\x00*\x00H\x00F\x00@\x00<\x008\x004\x000\x00*\x00$\x00 \x00\x1c\x00\x18\x00\x14\x00\x10\x00\x0c\x00\n\x00\x04\x00\x00\x00\x00\x00,\x00*\x00\x00\x00D\x00\x00\x00\x00\x00\n\x00D\x00\x00\x00\xec\x06\x00\x00\xec\x06\x00\x00\xec\x06\x00\x00\xd0\x06\x00\x00\xe8\x06\x00\x00\xf0\x06\x00\x00\x00\x00\x02\x00\x02\x00\x00\x00\xbc\x07\x00\x00\xe8\x07\x00\x00\xe8\x07\x00\x00 \x08\x00\x00t\x08\x00\x00\x00\x00\x03\x00\x01\x00\x00\x00\x00\x00\x00\x00\x97\x06\x00\x00ML\xefR\rStableHLO_v1.13.0\x00\x01%\x07\x01\x05\t\x0f\x01\x03\x0f\x03\x03\x13\x05\x0b\x17\x1b\x1f#\'\x03\xc7\x9f\x11\x01{\x07\x0b\x0b\x0f\x0b#\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x0b\x0b\x0f\x0f\x0b\x1f\x0f\x0f\x0b\x1b\x0f\x0f\x0b\x1f\x0f\x0f\x0b\x1f\x0b\x0f\x0f\x0b\x1f\x0b\x0f\x0f\x0b\x1f\x0f\x0f\x0b\x1b\x0b\x0f\x0f\x0b\x1b\x0b\x0f\x0f\x0b\x1f\x0b\x0f\x0b\x1f\x03\r\x13\x0f\x1b\x17\x0f\x13\x05\x19\x0f\x13\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x1f\x0f\x01\x05\x0f\x0b\x05\r\x17\x07\x0f\x17\x13\x07\x02n\x04\x1f\x05\x15\x05\t\x1d!#\x05\x17\x03\x07\r\x0f\x11\x13\x15\x17\x05\x19\x11\x03\x00\x05\x1b\x11\x01\t\x05\x1d\x11\x01\x05\x05\x1f\t\x05\x1d\x1f\x01\x05!\x05#\x15%+\x1d\')\x05%-\x03\x07r\x02\x1f+\x15-3\x1d/1\x05\'-\x03\x07\xb1\x15\x87\x155;\x1d79\x05)-\x03\x07\x82\x02+i\x15=E\x1d?A\x05+-C\x07\x06\x05#k\x05-\x15GO\x1dIK\x05/-M\x07~\x05!K\x051\x15QW\x1dSU\x053-\t\x07\xca\x02\x11-\x15Ya\x1d[]\x055-_\x07\xf35g\x057\x15ck\x1deg\x059-i\x07\xf1\x1f\x99\x05;\x15mu\x1doq\x05=-s\x07\x02\x08\x1f\xab\x05?\x1dwy\x05A-\t\x07\xda\x03!_\x05\x03}\x01\x03C\t\r\x1b\x05\x81\x85\x01\x0b\x03\x83\x01\x01\tC\x01\x0b\x01\x01\x01\x03\x03\x89\r\x03\x8b\x7f\x1dE#\x0b\x03\x03\x91\r\x03\x93\x95\x1dG\x1dI\x1dK\x1dM\x1f\t\t\x00\x00\x00@\x1f\r\x01\x01\x02\x02\x01\t)\x05A\x11\x07\t)\x01\x07\x11\x03\x05\x03\x05)\x03\x01\x0f\x1d\x04s\x05\x01Q\x01\x0b\x01\x07\x04a\x03\x01\t\x03@\x01\x03\x05P\x01\x05\x07\x04E\x03\t\x13\x03\x0b\x1d\x00\x07B\x01\x07\x03\t\tF\x07\t\x03\x05\x03\x03\x0b\x06\x07\x03\x05\x05\x01\x05\r\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00N\x12O\x0f\x0b\x0f!\x1b\x05E\x9b)\x9f1\x9f\x17)\xa13\xb5\xb3UAk\x17\x05\r%)9\x9d\x91\x15\x19)\x19\x11\x0b\x0f\x0b\t\x11builtin\x00sdy\x00vhlo\x00module\x00mesh\x00func_v1\x00constant_v1\x00broadcast_in_dim_v1\x00multiply_v1\x00return_v1\x00/Users/necula/Source/jax/tests/export_serialization_back_compat_test.py\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/_pytest/runner.py\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_f\x00b\x00jit(f)/mul\x00CompatTest.test_with_unspecified_sharding..f\x00CompatTest.export_and_serialize\x00CompatTest.test_with_unspecified_sharding\x00_ParameterizedTestIter.__iter__..make_bound_param_test..bound_param_test\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/absl/testing/parameterized.py\x00TestCaseFunction.runtest\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/_pytest/unittest.py\x00pytest_runtest_call\x00_multicall\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/pluggy/_callers.py\x00PluginManager._hookexec\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/pluggy/_manager.py\x00HookCaller.__call__\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/pluggy/_hooks.py\x00call_and_report..\x00x\x00sdy.sharding\x00jax.result_info\x00result\x00main\x00public\x00\x08#\x0b\x053\x01\x05{\x05\x0b\x87\x8d\x8f\x97\x99\x03\x9b\x03\x9d\x00\x01\x00\x00\x00\x04\x00\x00\x00\x03\x00\x00\x00cpu\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00d\xfe\xff\xff\x01\x00\x00\x00\x10\x00\x00\x00\x00\x00\n\x00\x10\x00\x0f\x00\x08\x00\x04\x00\n\x00\x00\x00\x1c\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x01\x0c\x00\x00\x00\x08\x03\x1a\x02\x02\x01J\x01\x02R\x01\x00\x96\xff\xff\xff\x0c\x00\x00\x00\x14\x00\x00\x00h\x00\x00\x00\x06\x00\x00\x00device\x00\x00\xb2\xff\xff\xff\x0c\x00\x00\x00\x0c\x00\x00\x00\x0c\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x1c\x00\x00\x00\x04\x00\x00\x00\xf2\xff\xff\xff\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x06\x00\x08\x00\x04\x00\x06\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00x\x00\n\x00\x10\x00\x0c\x00\x08\x00\x04\x00\n\x00\x00\x00\x0c\x00\x00\x00\x10\x00\x00\x00\x1c\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00x\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\xcc\xff\xff\xff\x00\x00\x01\n\x04\x00\x00\x00\x02\x00\x00\x00\x10\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x004\x00\x00\x00\x02\x00\x00\x0016\x00\x00p\xff\xff\xff\x01\x00\x00\x00\x10\x00\x00\x00\x0c\x00\x0c\x00\x00\x00\x08\x00\x07\x00\x06\x00\x0c\x00\x00\x00\x00\x00\x01\n\x04\x00\x00\x00\x02\x00\x00\x00\x10\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x004\x00\x00\x00\x02\x00\x00\x0016\x00\x00\xcc\xff\xff\xff\x08\x00\x00\x00\x00\x00\x00\x02\x02\x00\x00\x00,\x00\x00\x00\x10\x00\x00\x00\x00\x00\n\x00\x0c\x00\x0b\x00\x00\x00\x04\x00\n\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x08\x00\x0c\x00\x0b\x00\x04\x00\x08\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x02\x01\x00\x00\x00\x08\x00\x00\x00\x04\x00\x04\x00\x04\x00\x00\x00\x01\x00\x00\x00f\x00\x00\x00"), + ), +] diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/pallas/__init__.py b/jax/_src/internal_test_util/export_back_compat_test_data/pallas/__init__.py new file mode 100644 index 000000000000..3da0dd1fa3ca --- /dev/null +++ b/jax/_src/internal_test_util/export_back_compat_test_data/pallas/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/pallas/mosaic_gpu_add_one.py b/jax/_src/internal_test_util/export_back_compat_test_data/pallas/mosaic_gpu_add_one.py new file mode 100644 index 000000000000..ade64d199ea9 --- /dev/null +++ b/jax/_src/internal_test_util/export_back_compat_test_data/pallas/mosaic_gpu_add_one.py @@ -0,0 +1,141 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import numpy as np + +array = np.array +float32 = np.float32 + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2025_04_22 = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['mosaic_gpu'], + serialized_date=datetime.date(2025, 4, 22), + inputs=(array([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., + 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., + 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., + 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., + 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., + 55., 56., 57., 58., 59., 60., 61., 62., 63., 64., 65., + 66., 67., 68., 69., 70., 71., 72., 73., 74., 75., 76., + 77., 78., 79., 80., 81., 82., 83., 84., 85., 86., 87., + 88., 89., 90., 91., 92., 93., 94., 95., 96., 97., 98., + 99., 100., 101., 102., 103., 104., 105., 106., 107., 108., 109., + 110., 111., 112., 113., 114., 115., 116., 117., 118., 119., 120., + 121., 122., 123., 124., 125., 126., 127., 128., 129., 130., 131., + 132., 133., 134., 135., 136., 137., 138., 139., 140., 141., 142., + 143., 144., 145., 146., 147., 148., 149., 150., 151., 152., 153., + 154., 155., 156., 157., 158., 159., 160., 161., 162., 163., 164., + 165., 166., 167., 168., 169., 170., 171., 172., 173., 174., 175., + 176., 177., 178., 179., 180., 181., 182., 183., 184., 185., 186., + 187., 188., 189., 190., 191., 192., 193., 194., 195., 196., 197., + 198., 199., 200., 201., 202., 203., 204., 205., 206., 207., 208., + 209., 210., 211., 212., 213., 214., 215., 216., 217., 218., 219., + 220., 221., 222., 223., 224., 225., 226., 227., 228., 229., 230., + 231., 232., 233., 234., 235., 236., 237., 238., 239., 240., 241., + 242., 243., 244., 245., 246., 247., 248., 249., 250., 251., 252., + 253., 254., 255.], dtype=float32),), + expected_outputs=(array([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., + 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., + 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., + 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., 44., + 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55., + 56., 57., 58., 59., 60., 61., 62., 63., 64., 65., 66., + 67., 68., 69., 70., 71., 72., 73., 74., 75., 76., 77., + 78., 79., 80., 81., 82., 83., 84., 85., 86., 87., 88., + 89., 90., 91., 92., 93., 94., 95., 96., 97., 98., 99., + 100., 101., 102., 103., 104., 105., 106., 107., 108., 109., 110., + 111., 112., 113., 114., 115., 116., 117., 118., 119., 120., 121., + 122., 123., 124., 125., 126., 127., 128., 129., 130., 131., 132., + 133., 134., 135., 136., 137., 138., 139., 140., 141., 142., 143., + 144., 145., 146., 147., 148., 149., 150., 151., 152., 153., 154., + 155., 156., 157., 158., 159., 160., 161., 162., 163., 164., 165., + 166., 167., 168., 169., 170., 171., 172., 173., 174., 175., 176., + 177., 178., 179., 180., 181., 182., 183., 184., 185., 186., 187., + 188., 189., 190., 191., 192., 193., 194., 195., 196., 197., 198., + 199., 200., 201., 202., 203., 204., 205., 206., 207., 208., 209., + 210., 211., 212., 213., 214., 215., 216., 217., 218., 219., 220., + 221., 222., 223., 224., 225., 226., 227., 228., 229., 230., 231., + 232., 233., 234., 235., 236., 237., 238., 239., 240., 241., 242., + 243., 244., 245., 246., 247., 248., 249., 250., 251., 252., 253., + 254., 255., 256.], dtype=float32),), + mlir_module_text=r""" +#loc1 = loc("args[0]") +module @jit_wrapped attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<256xf32> loc("args[0]")) -> (tensor<256xf32> {jax.result_info = "result"}) { + %0 = stablehlo.custom_call @mosaic_gpu(%arg0) {api_version = 2 : i32, backend_config = "\A9C\FB\81\9A1\C2?\0E\F4\E1\E4\E77\03\B6\97\E5G(]WR\98\EB{\BA\8A\84\01\12'#loc = loc(\22third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py\22:83:4)\0A#loc1 = loc(\22-\22:94:40)\0A#loc2 = loc(\22-\22:94:47)\0A#loc3 = loc(\22-\22:94:54)\0A#loc4 = loc(\22-\22:94:116)\0A#loc5 = loc(\22-\22:94:123)\0A#loc6 = loc(\22-\22:94:130)\0A#loc7 = loc(\22-\22:94:65)\0A#loc8 = loc(\22-\22:94:78)\0A#loc9 = loc(\22-\22:94:91)\0A#loc10 = loc(\22-\22:94:141)\0A#loc11 = loc(\22-\22:94:157)\0A#loc12 = loc(\22-\22:94:174)\0A#loc17 = loc(\22jit(wrapped)/jit(main)/pallas_call\22(#loc))\0A\22builtin.module\22() <{sym_name = \22add_one\22}> ({\0A \22stable_mosaic_gpu.func.func\22() ({\0A }) {function_type = (!llvm.ptr, !llvm.ptr, i64, i64, !llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> (), sym_name = \22mosaic_gpu_init_tma_desc\22, sym_visibility = \22private\22} : () -> () loc(#loc17)\0A \22stable_mosaic_gpu.llvm.mlir.global\22() ({\0A }) {addr_space = 4 : i32, global_type = !llvm.array<0 x i8>, linkage = #llvm.linkage, sym_name = \22global_scratch\22, unnamed_addr = 0 : i64, visibility_ = 0 : i64} : () -> () loc(#loc17)\0A \22stable_mosaic_gpu.func.func\22() ({\0A ^bb0(%arg0: !llvm.ptr loc(\22jit(wrapped)/jit(main)/pallas_call\22(#loc)), %arg1: !llvm.ptr loc(\22jit(wrapped)/jit(main)/pallas_call\22(#loc))):\0A %0 = \22stable_mosaic_gpu.builtin.unrealized_conversion_cast\22(%arg0) : (!llvm.ptr) -> !gpu.async.token loc(#loc17)\0A %1 = \22stable_mosaic_gpu.llvm.getelementptr\22(%arg1) {elem_type = !llvm.ptr, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\0A %2 = \22stable_mosaic_gpu.llvm.load\22(%1) {ordering = 0 : i64} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\0A %3 = \22stable_mosaic_gpu.llvm.mlir.undef\22() : () -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\0A %4 = \22stable_mosaic_gpu.llvm.insertvalue\22(%3, %2) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\0A %5 = \22stable_mosaic_gpu.llvm.insertvalue\22(%4, %2) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\0A %6 = \22stable_mosaic_gpu.llvm.mlir.constant\22() {value = 0 : i64} : () -> i64 loc(#loc17)\0A %7 = \22stable_mosaic_gpu.llvm.insertvalue\22(%5, %6) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\0A %8 = \22stable_mosaic_gpu.llvm.mlir.constant\22() {value = 256 : i64} : () -> i64 loc(#loc17)\0A %9 = \22stable_mosaic_gpu.llvm.insertvalue\22(%7, %8) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\0A %10 = \22stable_mosaic_gpu.llvm.mlir.constant\22() {value = 1 : i64} : () -> i64 loc(#loc17)\0A %11 = \22stable_mosaic_gpu.llvm.insertvalue\22(%9, %10) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\0A %12 = \22stable_mosaic_gpu.builtin.unrealized_conversion_cast\22(%11) : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>) -> memref<256xf32> loc(#loc17)\0A %13 = \22stable_mosaic_gpu.llvm.getelementptr\22(%arg1) {elem_type = !llvm.ptr, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\0A %14 = \22stable_mosaic_gpu.llvm.load\22(%13) {ordering = 0 : i64} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\0A %15 = \22stable_mosaic_gpu.llvm.mlir.undef\22() : () -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\0A %16 = \22stable_mosaic_gpu.llvm.insertvalue\22(%15, %14) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\0A %17 = \22stable_mosaic_gpu.llvm.insertvalue\22(%16, %14) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\0A %18 = \22stable_mosaic_gpu.llvm.mlir.constant\22() {value = 0 : i64} : () -> i64 loc(#loc17)\0A %19 = \22stable_mosaic_gpu.llvm.insertvalue\22(%17, %18) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\0A %20 = \22stable_mosaic_gpu.llvm.mlir.constant\22() {value = 256 : i64} : () -> i64 loc(#loc17)\0A %21 = \22stable_mosaic_gpu.llvm.insertvalue\22(%19, %20) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\0A %22 = \22stable_mosaic_gpu.llvm.mlir.constant\22() {value = 1 : i64} : () -> i64 loc(#loc17)\0A %23 = \22stable_mosaic_gpu.llvm.insertvalue\22(%21, %22) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\0A %24 = \22stable_mosaic_gpu.builtin.unrealized_conversion_cast\22(%23) : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>) -> memref<256xf32> loc(#loc17)\0A %25 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i64} : () -> i64 loc(#loc17)\0A %26 = \22stable_mosaic_gpu.llvm.alloca\22(%25) {alignment = 64 : i64, elem_type = !llvm.array<256 x i8>} : (i64) -> !llvm.ptr loc(#loc17)\0A %27 = \22stable_mosaic_gpu.llvm.getelementptr\22(%26) {elem_type = i8, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\0A %28:4 = \22stable_mosaic_gpu.memref.extract_strided_metadata\22(%12) : (memref<256xf32>) -> (memref, index, index, index) loc(#loc17)\0A %29 = \22stable_mosaic_gpu.memref.extract_aligned_pointer_as_index\22(%12) : (memref<256xf32>) -> index loc(#loc17)\0A %30 = \22stable_mosaic_gpu.arith.index_cast\22(%29) : (index) -> i64 loc(#loc17)\0A %31 = \22stable_mosaic_gpu.llvm.inttoptr\22(%30) : (i64) -> !llvm.ptr loc(#loc17)\0A %32 = \22stable_mosaic_gpu.arith.index_cast\22(%28#1) : (index) -> i64 loc(#loc17)\0A %33 = \22stable_mosaic_gpu.llvm.getelementptr\22(%31, %32) {elem_type = f32, rawConstantIndices = array} : (!llvm.ptr, i64) -> !llvm.ptr loc(#loc17)\0A %34 = \22stable_mosaic_gpu.arith.constant\22() {value = 6 : i64} : () -> i64 loc(#loc17)\0A %35 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i64} : () -> i64 loc(#loc17)\0A %36 = \22stable_mosaic_gpu.arith.index_cast\22(%28#2) : (index) -> i64 loc(#loc17)\0A %37 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i64} : () -> i64 loc(#loc17)\0A %38 = \22stable_mosaic_gpu.llvm.alloca\22(%37) {elem_type = i64} : (i64) -> !llvm.ptr loc(#loc17)\0A %39 = \22stable_mosaic_gpu.llvm.getelementptr\22(%38) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\0A \22stable_mosaic_gpu.llvm.store\22(%36, %39) {ordering = 0 : i64} : (i64, !llvm.ptr) -> () loc(#loc17)\0A %40 = \22stable_mosaic_gpu.arith.index_cast\22(%28#3) : (index) -> i64 loc(#loc17)\0A %41 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i64} : () -> i64 loc(#loc17)\0A %42 = \22stable_mosaic_gpu.llvm.alloca\22(%41) {elem_type = i64} : (i64) -> !llvm.ptr loc(#loc17)\0A %43 = \22stable_mosaic_gpu.llvm.getelementptr\22(%42) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\0A \22stable_mosaic_gpu.llvm.store\22(%40, %43) {ordering = 0 : i64} : (i64, !llvm.ptr) -> () loc(#loc17)\0A %44 = \22stable_mosaic_gpu.arith.constant\22() {value = 16 : i64} : () -> i64 loc(#loc17)\0A %45 = \22stable_mosaic_gpu.arith.constant\22() {value = 256 : i64} : () -> i64 loc(#loc17)\0A %46 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i64} : () -> i64 loc(#loc17)\0A %47 = \22stable_mosaic_gpu.llvm.alloca\22(%46) {elem_type = i64} : (i64) -> !llvm.ptr loc(#loc17)\0A %48 = \22stable_mosaic_gpu.llvm.getelementptr\22(%47) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\0A \22stable_mosaic_gpu.llvm.store\22(%45, %48) {ordering = 0 : i64} : (i64, !llvm.ptr) -> () loc(#loc17)\0A \22stable_mosaic_gpu.func.call\22(%27, %33, %34, %35, %38, %42, %44, %47) {callee = @mosaic_gpu_init_tma_desc} : (!llvm.ptr, !llvm.ptr, i64, i64, !llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> () loc(#loc17)\0A %49 = \22stable_mosaic_gpu.llvm.getelementptr\22(%26) {elem_type = i8, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\0A %50:4 = \22stable_mosaic_gpu.memref.extract_strided_metadata\22(%24) : (memref<256xf32>) -> (memref, index, index, index) loc(#loc17)\0A %51 = \22stable_mosaic_gpu.memref.extract_aligned_pointer_as_index\22(%24) : (memref<256xf32>) -> index loc(#loc17)\0A %52 = \22stable_mosaic_gpu.arith.index_cast\22(%51) : (index) -> i64 loc(#loc17)\0A %53 = \22stable_mosaic_gpu.llvm.inttoptr\22(%52) : (i64) -> !llvm.ptr loc(#loc17)\0A %54 = \22stable_mosaic_gpu.arith.index_cast\22(%50#1) : (index) -> i64 loc(#loc17)\0A %55 = \22stable_mosaic_gpu.llvm.getelementptr\22(%53, %54) {elem_type = f32, rawConstantIndices = array} : (!llvm.ptr, i64) -> !llvm.ptr loc(#loc17)\0A %56 = \22stable_mosaic_gpu.arith.constant\22() {value = 6 : i64} : () -> i64 loc(#loc17)\0A %57 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i64} : () -> i64 loc(#loc17)\0A %58 = \22stable_mosaic_gpu.arith.index_cast\22(%50#2) : (index) -> i64 loc(#loc17)\0A %59 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i64} : () -> i64 loc(#loc17)\0A %60 = \22stable_mosaic_gpu.llvm.alloca\22(%59) {elem_type = i64} : (i64) -> !llvm.ptr loc(#loc17)\0A %61 = \22stable_mosaic_gpu.llvm.getelementptr\22(%60) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\0A \22stable_mosaic_gpu.llvm.store\22(%58, %61) {ordering = 0 : i64} : (i64, !llvm.ptr) -> () loc(#loc17)\0A %62 = \22stable_mosaic_gpu.arith.index_cast\22(%50#3) : (index) -> i64 loc(#loc17)\0A %63 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i64} : () -> i64 loc(#loc17)\0A %64 = \22stable_mosaic_gpu.llvm.alloca\22(%63) {elem_type = i64} : (i64) -> !llvm.ptr loc(#loc17)\0A %65 = \22stable_mosaic_gpu.llvm.getelementptr\22(%64) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\0A \22stable_mosaic_gpu.llvm.store\22(%62, %65) {ordering = 0 : i64} : (i64, !llvm.ptr) -> () loc(#loc17)\0A %66 = \22stable_mosaic_gpu.arith.constant\22() {value = 16 : i64} : () -> i64 loc(#loc17)\0A %67 = \22stable_mosaic_gpu.arith.constant\22() {value = 256 : i64} : () -> i64 loc(#loc17)\0A %68 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i64} : () -> i64 loc(#loc17)\0A %69 = \22stable_mosaic_gpu.llvm.alloca\22(%68) {elem_type = i64} : (i64) -> !llvm.ptr loc(#loc17)\0A %70 = \22stable_mosaic_gpu.llvm.getelementptr\22(%69) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\0A \22stable_mosaic_gpu.llvm.store\22(%67, %70) {ordering = 0 : i64} : (i64, !llvm.ptr) -> () loc(#loc17)\0A \22stable_mosaic_gpu.func.call\22(%49, %55, %56, %57, %60, %64, %66, %69) {callee = @mosaic_gpu_init_tma_desc} : (!llvm.ptr, !llvm.ptr, i64, i64, !llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> () loc(#loc17)\0A %71 = \22stable_mosaic_gpu.llvm.load\22(%26) {ordering = 0 : i64} : (!llvm.ptr) -> !llvm.array<256 x i8> loc(#loc17)\0A %72 = \22stable_mosaic_gpu.arith.constant\22() {value = 2 : index} : () -> index loc(#loc17)\0A %73 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : index} : () -> index loc(#loc17)\0A %74 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : index} : () -> index loc(#loc17)\0A %75 = \22stable_mosaic_gpu.arith.constant\22() {value = 128 : index} : () -> index loc(#loc17)\0A %76 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : index} : () -> index loc(#loc17)\0A %77 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : index} : () -> index loc(#loc17)\0A %78 = \22stable_mosaic_gpu.arith.constant\22() {value = 2056 : i32} : () -> i32 loc(#loc17)\0A %79 = \22stable_mosaic_gpu.gpu.launch\22(%0, %72, %73, %74, %75, %76, %77, %78) ({\0A ^bb0(%arg2: index loc(\22-\22:94:40), %arg3: index loc(\22-\22:94:47), %arg4: index loc(\22-\22:94:54), %arg5: index loc(\22-\22:94:116), %arg6: index loc(\22-\22:94:123), %arg7: index loc(\22-\22:94:130), %arg8: index loc(\22-\22:94:65), %arg9: index loc(\22-\22:94:78), %arg10: index loc(\22-\22:94:91), %arg11: index loc(\22-\22:94:141), %arg12: index loc(\22-\22:94:157), %arg13: index loc(\22-\22:94:174)):\0A %80 = \22stable_mosaic_gpu.gpu.dynamic_shared_memory\22() : () -> memref> loc(#loc17)\0A %81 = \22stable_mosaic_gpu.builtin.unrealized_conversion_cast\22(%71) : (!llvm.array<256 x i8>) -> !llvm.ptr loc(#loc17)\0A %82 = \22stable_mosaic_gpu.llvm.getelementptr\22(%81) {elem_type = i8, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc18)\0A %83 = \22stable_mosaic_gpu.llvm.getelementptr\22(%81) {elem_type = i8, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc19)\0A %84 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : index} : () -> index loc(#loc17)\0A %85 = \22stable_mosaic_gpu.memref.view\22(%80, %84) : (memref>, index) -> memref<2048xi8, #gpu.address_space> loc(#loc17)\0A %86 = \22stable_mosaic_gpu.builtin.unrealized_conversion_cast\22(%80) : (memref>) -> !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\0A %87 = \22stable_mosaic_gpu.llvm.extractvalue\22(%86) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> !llvm.ptr<3> loc(#loc17)\0A %88 = \22stable_mosaic_gpu.llvm.extractvalue\22(%86) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> i64 loc(#loc17)\0A %89 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i64} : () -> i64 loc(#loc17)\0A %90 = \22stable_mosaic_gpu.llvm.mul\22(%88, %89) : (i64, i64) -> i64 loc(#loc17)\0A %91 = \22stable_mosaic_gpu.llvm.ptrtoint\22(%87) : (!llvm.ptr<3>) -> i64 loc(#loc17)\0A %92 = \22stable_mosaic_gpu.llvm.add\22(%91, %90) : (i64, i64) -> i64 loc(#loc17)\0A %93 = \22stable_mosaic_gpu.llvm.inttoptr\22(%92) : (i64) -> !llvm.ptr<3> loc(#loc17)\0A %94 = \22stable_mosaic_gpu.llvm.getelementptr\22(%93) {elem_type = i8, rawConstantIndices = array} : (!llvm.ptr<3>) -> !llvm.ptr<3> loc(#loc17)\0A %95 = \22stable_mosaic_gpu.memref.alloca\22() {operandSegmentSizes = array} : () -> memref loc(#loc17)\0A %96 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : i32} : () -> i32 loc(#loc17)\0A \22stable_mosaic_gpu.memref.store\22(%96, %95) : (i32, memref) -> () loc(#loc17)\0A %97 = \22stable_mosaic_gpu.nvvm.elect.sync\22() : () -> i1 loc(#loc17)\0A %98 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc17)\0A %99 = \22stable_mosaic_gpu.arith.index_cast\22(%98) : (index) -> i32 loc(#loc17)\0A %100 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc17)\0A %101 = \22stable_mosaic_gpu.arith.index_cast\22(%100) : (index) -> i32 loc(#loc17)\0A %102 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc17)\0A %103 = \22stable_mosaic_gpu.arith.index_cast\22(%102) : (index) -> i32 loc(#loc17)\0A %104 = \22stable_mosaic_gpu.arith.muli\22(%103, %101) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\0A %105 = \22stable_mosaic_gpu.arith.addi\22(%99, %104) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\0A %106 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc17)\0A %107 = \22stable_mosaic_gpu.arith.index_cast\22(%106) : (index) -> i32 loc(#loc17)\0A %108 = \22stable_mosaic_gpu.arith.muli\22(%101, %107) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\0A %109 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc17)\0A %110 = \22stable_mosaic_gpu.arith.index_cast\22(%109) : (index) -> i32 loc(#loc17)\0A %111 = \22stable_mosaic_gpu.arith.muli\22(%110, %108) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\0A %112 = \22stable_mosaic_gpu.arith.addi\22(%105, %111) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\0A %113 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc17)\0A %114 = \22stable_mosaic_gpu.arith.index_cast\22(%113) : (index) -> i32 loc(#loc17)\0A %115 = \22stable_mosaic_gpu.arith.muli\22(%108, %114) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\0A %116 = \22stable_mosaic_gpu.arith.constant\22() {value = 5 : i32} : () -> i32 loc(#loc17)\0A %117 = \22stable_mosaic_gpu.arith.shrui\22(%112, %116) : (i32, i32) -> i32 loc(#loc17)\0A %118 = \22stable_mosaic_gpu.arith.constant\22() {value = -1 : i32} : () -> i32 loc(#loc17)\0A %119 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : i32} : () -> i32 loc(#loc17)\0A %120 = \22stable_mosaic_gpu.arith.constant\22() {value = 31 : i32} : () -> i32 loc(#loc17)\0A %121 = \22stable_mosaic_gpu.nvvm.shfl.sync\22(%118, %117, %119, %120) {kind = #nvvm} : (i32, i32, i32, i32) -> i32 loc(#loc17)\0A %122 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : i32} : () -> i32 loc(#loc17)\0A %123 = \22stable_mosaic_gpu.arith.cmpi\22(%121, %122) {predicate = 0 : i64} : (i32, i32) -> i1 loc(#loc17)\0A %124 = \22stable_mosaic_gpu.arith.andi\22(%123, %97) : (i1, i1) -> i1 loc(#loc17)\0A \22stable_mosaic_gpu.scf.if\22(%124) ({\0A %332 = \22stable_mosaic_gpu.llvm.getelementptr\22(%94) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr<3>) -> !llvm.ptr<3> loc(#loc17)\0A %333 = \22stable_mosaic_gpu.arith.constant\22() {value = 128 : i32} : () -> i32 loc(#loc17)\0A \22stable_mosaic_gpu.nvvm.mbarrier.init.shared\22(%332, %333) : (!llvm.ptr<3>, i32) -> () loc(#loc17)\0A \22stable_mosaic_gpu.scf.yield\22() : () -> () loc(#loc13)\0A }, {\0A }) : (i1) -> () loc(#loc17)\0A %125 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : i32} : () -> i32 loc(#loc17)\0A \22stable_mosaic_gpu.nvvm.fence.mbarrier.init\22() : () -> () loc(#loc17)\0A \22stable_mosaic_gpu.gpu.barrier\22() : () -> () loc(#loc17)\0A %126 = \22stable_mosaic_gpu.nvvm.elect.sync\22() : () -> i1 loc(#loc17)\0A %127 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc17)\0A %128 = \22stable_mosaic_gpu.arith.index_cast\22(%127) : (index) -> i32 loc(#loc17)\0A %129 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc17)\0A %130 = \22stable_mosaic_gpu.arith.index_cast\22(%129) : (index) -> i32 loc(#loc17)\0A %131 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc17)\0A %132 = \22stable_mosaic_gpu.arith.index_cast\22(%131) : (index) -> i32 loc(#loc17)\0A %133 = \22stable_mosaic_gpu.arith.muli\22(%132, %130) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\0A %134 = \22stable_mosaic_gpu.arith.addi\22(%128, %133) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\0A %135 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc17)\0A %136 = \22stable_mosaic_gpu.arith.index_cast\22(%135) : (index) -> i32 loc(#loc17)\0A %137 = \22stable_mosaic_gpu.arith.muli\22(%130, %136) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\0A %138 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc17)\0A %139 = \22stable_mosaic_gpu.arith.index_cast\22(%138) : (index) -> i32 loc(#loc17)\0A %140 = \22stable_mosaic_gpu.arith.muli\22(%139, %137) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\0A %141 = \22stable_mosaic_gpu.arith.addi\22(%134, %140) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\0A %142 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc17)\0A %143 = \22stable_mosaic_gpu.arith.index_cast\22(%142) : (index) -> i32 loc(#loc17)\0A %144 = \22stable_mosaic_gpu.arith.muli\22(%137, %143) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\0A %145 = \22stable_mosaic_gpu.arith.constant\22() {value = 5 : i32} : () -> i32 loc(#loc17)\0A %146 = \22stable_mosaic_gpu.arith.shrui\22(%141, %145) : (i32, i32) -> i32 loc(#loc17)\0A %147 = \22stable_mosaic_gpu.arith.constant\22() {value = -1 : i32} : () -> i32 loc(#loc17)\0A %148 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : i32} : () -> i32 loc(#loc17)\0A %149 = \22stable_mosaic_gpu.arith.constant\22() {value = 31 : i32} : () -> i32 loc(#loc17)\0A %150 = \22stable_mosaic_gpu.nvvm.shfl.sync\22(%147, %146, %148, %149) {kind = #nvvm} : (i32, i32, i32, i32) -> i32 loc(#loc17)\0A %151 = \22stable_mosaic_gpu.arith.constant\22() {value = 4 : i32} : () -> i32 loc(#loc17)\0A %152 = \22stable_mosaic_gpu.arith.remui\22(%150, %151) : (i32, i32) -> i32 loc(#loc17)\0A %153 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : i32} : () -> i32 loc(#loc17)\0A %154 = \22stable_mosaic_gpu.arith.cmpi\22(%152, %153) {predicate = 0 : i64} : (i32, i32) -> i1 loc(#loc17)\0A %155 = \22stable_mosaic_gpu.arith.andi\22(%154, %126) : (i1, i1) -> i1 loc(#loc17)\0A %156 = \22stable_mosaic_gpu.nvvm.elect.sync\22() : () -> i1 loc(#loc17)\0A %157 = \22stable_mosaic_gpu.gpu.block_id\22() {dimension = #gpu} : () -> index loc(#loc17)\0A %158 = \22stable_mosaic_gpu.arith.index_cast\22(%157) : (index) -> i32 loc(#loc17)\0A %159 = \22stable_mosaic_gpu.gpu.dynamic_shared_memory\22() : () -> memref> loc(#loc20)\0A %160 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : index} : () -> index loc(#loc20)\0A %161 = \22stable_mosaic_gpu.memref.view\22(%159, %160) : (memref>, index) -> memref<1x256xf32, #gpu.address_space> loc(#loc20)\0A %162 = \22stable_mosaic_gpu.gpu.dynamic_shared_memory\22() : () -> memref> loc(#loc20)\0A %163 = \22stable_mosaic_gpu.arith.constant\22() {value = 1024 : index} : () -> index loc(#loc20)\0A %164 = \22stable_mosaic_gpu.memref.view\22(%162, %163) : (memref>, index) -> memref<1x256xf32, #gpu.address_space> loc(#loc20)\0A %165 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : index} : () -> index loc(#loc19)\0A %166 = \22stable_mosaic_gpu.memref.subview\22(%161, %165) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<1x256xf32, #gpu.address_space>, index) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc19)\0A %167 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : index} : () -> index loc(#loc19)\0A %168 = \22stable_mosaic_gpu.arith.index_castui\22(%167) : (index) -> i32 loc(#loc19)\0A %169 = \22stable_mosaic_gpu.arith.addi\22(%125, %168) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc19)\0A %170 = \22stable_mosaic_gpu.arith.constant\22() {value = 8 : i32} : () -> i32 loc(#loc19)\0A %171 = \22stable_mosaic_gpu.llvm.getelementptr\22(%94, %169) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3> loc(#loc19)\0A \22stable_mosaic_gpu.nvvm.mbarrier.arrive.expect_tx.shared\22(%171, %170) : (!llvm.ptr<3>, i32) -> () loc(#loc19)\0A %172 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : index} : () -> index loc(#loc19)\0A %173 = \22stable_mosaic_gpu.arith.index_cast\22(%172) : (index) -> i32 loc(#loc19)\0A %174 = \22stable_mosaic_gpu.builtin.unrealized_conversion_cast\22(%166) : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> loc(#loc19)\0A %175 = \22stable_mosaic_gpu.llvm.extractvalue\22(%174) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> !llvm.ptr<3> loc(#loc19)\0A %176 = \22stable_mosaic_gpu.llvm.extractvalue\22(%174) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> i64 loc(#loc19)\0A %177 = \22stable_mosaic_gpu.arith.constant\22() {value = 4 : i64} : () -> i64 loc(#loc19)\0A %178 = \22stable_mosaic_gpu.llvm.mul\22(%176, %177) : (i64, i64) -> i64 loc(#loc19)\0A %179 = \22stable_mosaic_gpu.llvm.ptrtoint\22(%175) : (!llvm.ptr<3>) -> i64 loc(#loc19)\0A %180 = \22stable_mosaic_gpu.llvm.add\22(%179, %178) : (i64, i64) -> i64 loc(#loc19)\0A %181 = \22stable_mosaic_gpu.llvm.inttoptr\22(%180) : (i64) -> !llvm.ptr<3> loc(#loc19)\0A %182 = \22stable_mosaic_gpu.arith.constant\22() {value = 1024 : i32} : () -> i32 loc(#loc19)\0A %183 = \22stable_mosaic_gpu.llvm.getelementptr\22(%94, %169) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3> loc(#loc19)\0A \22stable_mosaic_gpu.nvvm.cp.async.bulk.tensor.shared.cluster.global\22(%181, %83, %173, %183, %155) {operandSegmentSizes = array} : (!llvm.ptr<3>, !llvm.ptr, i32, !llvm.ptr<3>, i1) -> () loc(#loc19)\0A %184 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : i32} : () -> i32 loc(#loc21)\0A %185 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : i32} : () -> i32 loc(#loc21)\0A %186 = \22stable_mosaic_gpu.arith.addi\22(%185, %184) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc21)\0A %187 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i32} : () -> i32 loc(#loc22)\0A %188 = \22stable_mosaic_gpu.arith.remsi\22(%186, %187) : (i32, i32) -> i32 loc(#loc22)\0A %189 = \22stable_mosaic_gpu.arith.index_cast\22(%188) : (i32) -> index loc(#loc23)\0A %190 = \22stable_mosaic_gpu.arith.index_castui\22(%189) : (index) -> i32 loc(#loc23)\0A %191 = \22stable_mosaic_gpu.arith.addi\22(%125, %190) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc23)\0A %192 = \22stable_mosaic_gpu.memref.load\22(%95) : (memref) -> i32 loc(#loc23)\0A %193 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i32} : () -> i32 loc(#loc23)\0A %194 = \22stable_mosaic_gpu.arith.shli\22(%193, %191) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc23)\0A %195 = \22stable_mosaic_gpu.arith.andi\22(%192, %194) : (i32, i32) -> i32 loc(#loc23)\0A %196 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : i32} : () -> i32 loc(#loc23)\0A %197 = \22stable_mosaic_gpu.arith.cmpi\22(%195, %196) {predicate = 1 : i64} : (i32, i32) -> i1 loc(#loc23)\0A %198 = \22stable_mosaic_gpu.arith.xori\22(%192, %194) : (i32, i32) -> i32 loc(#loc23)\0A \22stable_mosaic_gpu.memref.store\22(%198, %95) : (i32, memref) -> () loc(#loc23)\0A %199 = \22stable_mosaic_gpu.arith.constant\22() {value = 10000000 : i32} : () -> i32 loc(#loc23)\0A %200 = \22stable_mosaic_gpu.arith.extui\22(%197) : (i1) -> i32 loc(#loc23)\0A %201 = \22stable_mosaic_gpu.llvm.getelementptr\22(%94, %191) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3> loc(#loc23)\0A \22stable_mosaic_gpu.nvvm.mbarrier.try_wait.parity.shared\22(%201, %200, %199) : (!llvm.ptr<3>, i32, i32) -> () loc(#loc23)\0A %202 = \22stable_mosaic_gpu.arith.index_cast\22(%188) : (i32) -> index loc(#loc24)\0A %203 = \22stable_mosaic_gpu.memref.subview\22(%161, %202) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<1x256xf32, #gpu.address_space>, index) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc24)\0A %204 = \22stable_mosaic_gpu.arith.index_cast\22(%188) : (i32) -> index loc(#loc24)\0A %205 = \22stable_mosaic_gpu.memref.subview\22(%164, %204) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<1x256xf32, #gpu.address_space>, index) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc24)\0A %206 = \22stable_mosaic_gpu.gpu.block_id\22() {dimension = #gpu} : () -> index loc(#loc24)\0A %207 = \22stable_mosaic_gpu.arith.index_cast\22(%206) : (index) -> i32 loc(#loc24)\0A %208 = \22stable_mosaic_gpu.memref.subview\22(%203) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc25)\0A %209 = \22stable_mosaic_gpu.memref.collapse_shape\22(%208) {reassociation = [[0]]} : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc25)\0A %210 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc25)\0A %211 = \22stable_mosaic_gpu.arith.constant\22() {value = 128 : index} : () -> index loc(#loc25)\0A %212 = \22stable_mosaic_gpu.arith.remui\22(%210, %211) : (index, index) -> index loc(#loc25)\0A %213 = \22stable_mosaic_gpu.arith.constant\22() {value = 2 : index} : () -> index loc(#loc25)\0A %214 = \22stable_mosaic_gpu.arith.muli\22(%212, %213) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc25)\0A %215 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : index} : () -> index loc(#loc25)\0A %216 = \22stable_mosaic_gpu.arith.addi\22(%214, %215) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc25)\0A %217 = \22stable_mosaic_gpu.vector.load\22(%209, %216) : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>, index) -> vector<2xf32> loc(#loc25)\0A %218 = \22stable_mosaic_gpu.arith.constant\22() {value = 1.000000e+00 : f32} : () -> f32 loc(#loc26)\0A %219 = \22stable_mosaic_gpu.vector.splat\22(%218) : (f32) -> vector<2xf32> loc(#loc26)\0A %220 = \22stable_mosaic_gpu.arith.addf\22(%217, %219) {fastmath = #arith.fastmath} : (vector<2xf32>, vector<2xf32>) -> vector<2xf32> loc(#loc26)\0A %221 = \22stable_mosaic_gpu.memref.subview\22(%205) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc27)\0A %222 = \22stable_mosaic_gpu.memref.collapse_shape\22(%221) {reassociation = [[0]]} : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc27)\0A %223 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc27)\0A %224 = \22stable_mosaic_gpu.arith.constant\22() {value = 128 : index} : () -> index loc(#loc27)\0A %225 = \22stable_mosaic_gpu.arith.remui\22(%223, %224) : (index, index) -> index loc(#loc27)\0A %226 = \22stable_mosaic_gpu.arith.constant\22() {value = 2 : index} : () -> index loc(#loc27)\0A %227 = \22stable_mosaic_gpu.arith.muli\22(%225, %226) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc27)\0A %228 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : index} : () -> index loc(#loc27)\0A %229 = \22stable_mosaic_gpu.arith.addi\22(%227, %228) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc27)\0A %230 = \22stable_mosaic_gpu.vector.load\22(%222, %229) : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>, index) -> vector<2xf32> loc(#loc27)\0A %231 = \22stable_mosaic_gpu.memref.collapse_shape\22(%221) {reassociation = [[0]]} : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc27)\0A %232 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc27)\0A %233 = \22stable_mosaic_gpu.arith.constant\22() {value = 128 : index} : () -> index loc(#loc27)\0A %234 = \22stable_mosaic_gpu.arith.remui\22(%232, %233) : (index, index) -> index loc(#loc27)\0A %235 = \22stable_mosaic_gpu.arith.constant\22() {value = 2 : index} : () -> index loc(#loc27)\0A %236 = \22stable_mosaic_gpu.arith.muli\22(%234, %235) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc27)\0A %237 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : index} : () -> index loc(#loc27)\0A %238 = \22stable_mosaic_gpu.arith.addi\22(%236, %237) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc27)\0A \22stable_mosaic_gpu.vector.store\22(%220, %231, %238) : (vector<2xf32>, memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>, index) -> () loc(#loc27)\0A \22stable_mosaic_gpu.nvvm.cp.async.bulk.commit.group\22() : () -> () loc(#loc28)\0A %239 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i32} : () -> i32 loc(#loc29)\0A %240 = \22stable_mosaic_gpu.arith.addi\22(%186, %239) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc29)\0A %241 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i32} : () -> i32 loc(#loc22)\0A %242 = \22stable_mosaic_gpu.arith.remsi\22(%240, %241) : (i32, i32) -> i32 loc(#loc22)\0A %243 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : i32} : () -> i32 loc(#loc30)\0A %244 = \22stable_mosaic_gpu.arith.cmpi\22(%186, %243) {predicate = 9 : i64} : (i32, i32) -> i1 loc(#loc30)\0A %245 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i32} : () -> i32 loc(#loc31)\0A %246 = \22stable_mosaic_gpu.arith.cmpi\22(%240, %245) {predicate = 6 : i64} : (i32, i32) -> i1 loc(#loc31)\0A %247 = \22stable_mosaic_gpu.arith.andi\22(%244, %246) : (i1, i1) -> i1 loc(#loc32)\0A %248 = \22stable_mosaic_gpu.arith.extui\22(%247) : (i1) -> i32 loc(#loc33)\0A %249 = \22stable_mosaic_gpu.arith.index_cast\22(%248) : (i32) -> index loc(#loc34)\0A \22stable_mosaic_gpu.scf.index_switch\22(%249) ({\0A %313 = \22stable_mosaic_gpu.arith.index_cast\22(%242) : (i32) -> index loc(#loc19)\0A %314 = \22stable_mosaic_gpu.memref.subview\22(%161, %313) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<1x256xf32, #gpu.address_space>, index) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc19)\0A %315 = \22stable_mosaic_gpu.arith.index_cast\22(%242) : (i32) -> index loc(#loc19)\0A %316 = \22stable_mosaic_gpu.arith.index_castui\22(%315) : (index) -> i32 loc(#loc19)\0A %317 = \22stable_mosaic_gpu.arith.addi\22(%125, %316) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc19)\0A %318 = \22stable_mosaic_gpu.arith.constant\22() {value = 8 : i32} : () -> i32 loc(#loc19)\0A %319 = \22stable_mosaic_gpu.llvm.getelementptr\22(%94, %317) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3> loc(#loc19)\0A \22stable_mosaic_gpu.nvvm.mbarrier.arrive.expect_tx.shared\22(%319, %318) : (!llvm.ptr<3>, i32) -> () loc(#loc19)\0A %320 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : index} : () -> index loc(#loc19)\0A %321 = \22stable_mosaic_gpu.arith.index_cast\22(%320) : (index) -> i32 loc(#loc19)\0A %322 = \22stable_mosaic_gpu.builtin.unrealized_conversion_cast\22(%314) : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> loc(#loc19)\0A %323 = \22stable_mosaic_gpu.llvm.extractvalue\22(%322) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> !llvm.ptr<3> loc(#loc19)\0A %324 = \22stable_mosaic_gpu.llvm.extractvalue\22(%322) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> i64 loc(#loc19)\0A %325 = \22stable_mosaic_gpu.arith.constant\22() {value = 4 : i64} : () -> i64 loc(#loc19)\0A %326 = \22stable_mosaic_gpu.llvm.mul\22(%324, %325) : (i64, i64) -> i64 loc(#loc19)\0A %327 = \22stable_mosaic_gpu.llvm.ptrtoint\22(%323) : (!llvm.ptr<3>) -> i64 loc(#loc19)\0A %328 = \22stable_mosaic_gpu.llvm.add\22(%327, %326) : (i64, i64) -> i64 loc(#loc19)\0A %329 = \22stable_mosaic_gpu.llvm.inttoptr\22(%328) : (i64) -> !llvm.ptr<3> loc(#loc19)\0A %330 = \22stable_mosaic_gpu.arith.constant\22() {value = 1024 : i32} : () -> i32 loc(#loc19)\0A %331 = \22stable_mosaic_gpu.llvm.getelementptr\22(%94, %317) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3> loc(#loc19)\0A \22stable_mosaic_gpu.nvvm.cp.async.bulk.tensor.shared.cluster.global\22(%329, %83, %321, %331, %155) {operandSegmentSizes = array} : (!llvm.ptr<3>, !llvm.ptr, i32, !llvm.ptr<3>, i1) -> () loc(#loc19)\0A \22stable_mosaic_gpu.scf.yield\22() : () -> () loc(#loc16)\0A }, {\0A \22stable_mosaic_gpu.scf.yield\22() : () -> () loc(#loc34)\0A }) {cases = array} : (index) -> () loc(#loc34)\0A %250 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i32} : () -> i32 loc(#loc21)\0A %251 = \22stable_mosaic_gpu.arith.addi\22(%184, %250) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc21)\0A \22stable_mosaic_gpu.nvvm.fence.proxy\22() {kind = #nvvm.proxy_kind, space = #nvvm.shared_space} : () -> () loc(#loc35)\0A %252 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc35)\0A %253 = \22stable_mosaic_gpu.arith.index_cast\22(%252) : (index) -> i32 loc(#loc35)\0A %254 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc35)\0A %255 = \22stable_mosaic_gpu.arith.index_cast\22(%254) : (index) -> i32 loc(#loc35)\0A %256 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc35)\0A %257 = \22stable_mosaic_gpu.arith.index_cast\22(%256) : (index) -> i32 loc(#loc35)\0A %258 = \22stable_mosaic_gpu.arith.muli\22(%257, %255) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc35)\0A %259 = \22stable_mosaic_gpu.arith.addi\22(%253, %258) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc35)\0A %260 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc35)\0A %261 = \22stable_mosaic_gpu.arith.index_cast\22(%260) : (index) -> i32 loc(#loc35)\0A %262 = \22stable_mosaic_gpu.arith.muli\22(%255, %261) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc35)\0A %263 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc35)\0A %264 = \22stable_mosaic_gpu.arith.index_cast\22(%263) : (index) -> i32 loc(#loc35)\0A %265 = \22stable_mosaic_gpu.arith.muli\22(%264, %262) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc35)\0A %266 = \22stable_mosaic_gpu.arith.addi\22(%259, %265) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc35)\0A %267 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc35)\0A %268 = \22stable_mosaic_gpu.arith.index_cast\22(%267) : (index) -> i32 loc(#loc35)\0A %269 = \22stable_mosaic_gpu.arith.muli\22(%262, %268) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc35)\0A %270 = \22stable_mosaic_gpu.arith.constant\22() {value = 7 : i32} : () -> i32 loc(#loc35)\0A %271 = \22stable_mosaic_gpu.arith.shrui\22(%266, %270) : (i32, i32) -> i32 loc(#loc35)\0A %272 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i32} : () -> i32 loc(#loc35)\0A %273 = \22stable_mosaic_gpu.arith.addi\22(%271, %272) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc35)\0A %274 = \22stable_mosaic_gpu.llvm.inline_asm\22(%273) {asm_string = \22bar.sync $0, 128;\22, constraints = \22r\22, has_side_effects} : (i32) -> !llvm.void loc(#loc35)\0A %275 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : i32} : () -> i32 loc(#loc22)\0A %276 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i32} : () -> i32 loc(#loc22)\0A %277 = \22stable_mosaic_gpu.arith.remsi\22(%275, %276) : (i32, i32) -> i32 loc(#loc22)\0A %278 = \22stable_mosaic_gpu.arith.index_cast\22(%277) : (i32) -> index loc(#loc18)\0A %279 = \22stable_mosaic_gpu.memref.subview\22(%164, %278) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<1x256xf32, #gpu.address_space>, index) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc18)\0A %280 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : index} : () -> index loc(#loc18)\0A %281 = \22stable_mosaic_gpu.arith.index_cast\22(%280) : (index) -> i32 loc(#loc18)\0A %282 = \22stable_mosaic_gpu.builtin.unrealized_conversion_cast\22(%279) : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> loc(#loc18)\0A %283 = \22stable_mosaic_gpu.llvm.extractvalue\22(%282) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> !llvm.ptr<3> loc(#loc18)\0A %284 = \22stable_mosaic_gpu.llvm.extractvalue\22(%282) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> i64 loc(#loc18)\0A %285 = \22stable_mosaic_gpu.arith.constant\22() {value = 4 : i64} : () -> i64 loc(#loc18)\0A %286 = \22stable_mosaic_gpu.llvm.mul\22(%284, %285) : (i64, i64) -> i64 loc(#loc18)\0A %287 = \22stable_mosaic_gpu.llvm.ptrtoint\22(%283) : (!llvm.ptr<3>) -> i64 loc(#loc18)\0A %288 = \22stable_mosaic_gpu.llvm.add\22(%287, %286) : (i64, i64) -> i64 loc(#loc18)\0A %289 = \22stable_mosaic_gpu.llvm.inttoptr\22(%288) : (i64) -> !llvm.ptr<3> loc(#loc18)\0A \22stable_mosaic_gpu.nvvm.cp.async.bulk.tensor.global.shared.cta\22(%82, %289, %281, %155) {operandSegmentSizes = array} : (!llvm.ptr, !llvm.ptr<3>, i32, i1) -> () loc(#loc18)\0A \22stable_mosaic_gpu.nvvm.cp.async.bulk.commit.group\22() : () -> () loc(#loc28)\0A \22stable_mosaic_gpu.nvvm.cp.async.bulk.wait_group\22() {group = 0 : i32} : () -> () loc(#loc36)\0A %290 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc36)\0A %291 = \22stable_mosaic_gpu.arith.index_cast\22(%290) : (index) -> i32 loc(#loc36)\0A %292 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc36)\0A %293 = \22stable_mosaic_gpu.arith.index_cast\22(%292) : (index) -> i32 loc(#loc36)\0A %294 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc36)\0A %295 = \22stable_mosaic_gpu.arith.index_cast\22(%294) : (index) -> i32 loc(#loc36)\0A %296 = \22stable_mosaic_gpu.arith.muli\22(%295, %293) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc36)\0A %297 = \22stable_mosaic_gpu.arith.addi\22(%291, %296) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc36)\0A %298 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc36)\0A %299 = \22stable_mosaic_gpu.arith.index_cast\22(%298) : (index) -> i32 loc(#loc36)\0A %300 = \22stable_mosaic_gpu.arith.muli\22(%293, %299) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc36)\0A %301 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc36)\0A %302 = \22stable_mosaic_gpu.arith.index_cast\22(%301) : (index) -> i32 loc(#loc36)\0A %303 = \22stable_mosaic_gpu.arith.muli\22(%302, %300) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc36)\0A %304 = \22stable_mosaic_gpu.arith.addi\22(%297, %303) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc36)\0A %305 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc36)\0A %306 = \22stable_mosaic_gpu.arith.index_cast\22(%305) : (index) -> i32 loc(#loc36)\0A %307 = \22stable_mosaic_gpu.arith.muli\22(%300, %306) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc36)\0A %308 = \22stable_mosaic_gpu.arith.constant\22() {value = 7 : i32} : () -> i32 loc(#loc36)\0A %309 = \22stable_mosaic_gpu.arith.shrui\22(%304, %308) : (i32, i32) -> i32 loc(#loc36)\0A %310 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i32} : () -> i32 loc(#loc36)\0A %311 = \22stable_mosaic_gpu.arith.addi\22(%309, %310) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc36)\0A %312 = \22stable_mosaic_gpu.llvm.inline_asm\22(%311) {asm_string = \22bar.sync $0, 128;\22, constraints = \22r\22, has_side_effects} : (i32) -> !llvm.void loc(#loc36)\0A \22stable_mosaic_gpu.gpu.terminator\22() : () -> () loc(#loc17)\0A }) {operandSegmentSizes = array, workgroup_attributions = 0 : i64} : (!gpu.async.token, index, index, index, index, index, index, i32) -> !gpu.async.token loc(#loc17)\0A \22stable_mosaic_gpu.func.return\22() : () -> () loc(#loc17)\0A }) {function_type = (!llvm.ptr, !llvm.ptr) -> (), llvm.emit_c_interface, sym_name = \22mosaic_gpu_body\22} : () -> () loc(#loc17)\0A}) {stable_mosaic_gpu.version = 1 : i64} : () -> () loc(#loc17)\0A#loc13 = loc(\22-\22:141:7)\0A#loc14 = loc(\22third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py\22:78:19)\0A#loc15 = loc(\22third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py\22:78:6)\0A#loc16 = loc(\22-\22:279:7)\0A#loc18 = loc(\22/copy_smem_to_gmem\22(#loc))\0A#loc19 = loc(\22/copy_gmem_to_smem\22(#loc))\0A#loc20 = loc(\22/run_scoped\22(#loc))\0A#loc21 = loc(\22/scan\22(#loc))\0A#loc22 = loc(\22/rem\22(#loc))\0A#loc23 = loc(\22/barrier_wait\22(#loc))\0A#loc24 = loc(\22/jaxpr_call\22(#loc))\0A#loc25 = loc(\22/get\22(#loc14))\0A#loc26 = loc(\22/add\22(#loc14))\0A#loc27 = loc(\22/swap\22(#loc15))\0A#loc28 = loc(\22/commit_group\22(#loc))\0A#loc29 = loc(\22/add\22(#loc))\0A#loc30 = loc(\22/ge\22(#loc))\0A#loc31 = loc(\22/lt\22(#loc))\0A#loc32 = loc(\22/and\22(#loc))\0A#loc33 = loc(\22/convert_element_type\22(#loc))\0A#loc34 = loc(\22/cond\22(#loc))\0A#loc35 = loc(\22/commit_smem\22(#loc))\0A#loc36 = loc(\22/wait_smem_to_gmem\22(#loc))\0A", operand_layouts = [dense<0> : tensor<1xindex>], result_layouts = [dense<0> : tensor<1xindex>]} : (tensor<256xf32>) -> tensor<256xf32> loc(#loc3) + return %0 : tensor<256xf32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc2 = loc("third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py":83:4) +#loc3 = loc("jit(wrapped)/jit(main)/pallas_call"(#loc2)) +""", + mlir_module_serialized=b'ML\xefR\rStableHLO_v1.9.7\x00\x01\x19\x05\x01\x05\t\x01\x03\x0b\x03\x07\x0f\x13\x17\x03_=\x0f\x01\x1d\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x13\x0b\x03!\x0b\x0f\x0f\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b/\x01\x05\x0b\x0f\x03\x0b\x17\x17\x07\x13\x07\x02\xd5\x1f\x11\x03\x05\x03\x07\x07\t\x0b\x03\r\x03\x05\r\x11\x01\x00\x05\x0f\x05\x11\x05\x13\x1d\x13\x01\x05\x15\x1d\x17\x19\x05\x17\x17\x1b\xa7\t\x05\x19\x03\x01\x03\x03;\x03\x03#\r\x01#\x07\x03\x03)\r\x03+-\x1d\x1b\x1d\x1d\x1d\x1f\x1d!\x0b\x05\x1d#\x1d%\x05\x01\x1f\x0b\x11\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x03\x02\x08\t\x11\x03\x05\x03\x05\t)\x03\x05\r\x13\x04O\x05\x01Q\x01\x05\x01\x07\x04=\x03\x01\x05\x03P\x01\x03\x07\x04)\x03\x05\x0b\x03\x0b\x11\x00\x05F\x15\x05\x03\x05\x03\x01\x07\x04\x01\x03\x03\x06\x03\x01\x05\x01\x00D\xae\x05\'\x17\xa4\xa4\x05\x0f\x0b\x0f!\x85G\x11\x19%)9\x15\x1f\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_wrapped\x00args[0]\x00jit(wrapped)/jit(main)/pallas_call\x00third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py\x00jax.result_info\x00result\x00main\x00public\x00\xa9C\xfb\x81\x9a1\xc2?\x0e\xf4\xe1\xe4\xe77\x03\xb6\x97\xe5G(]WR\x98\xeb{\xba\x8a\x84\x01\x12\'#loc = loc("third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py":83:4)\n#loc1 = loc("-":94:40)\n#loc2 = loc("-":94:47)\n#loc3 = loc("-":94:54)\n#loc4 = loc("-":94:116)\n#loc5 = loc("-":94:123)\n#loc6 = loc("-":94:130)\n#loc7 = loc("-":94:65)\n#loc8 = loc("-":94:78)\n#loc9 = loc("-":94:91)\n#loc10 = loc("-":94:141)\n#loc11 = loc("-":94:157)\n#loc12 = loc("-":94:174)\n#loc17 = loc("jit(wrapped)/jit(main)/pallas_call"(#loc))\n"builtin.module"() <{sym_name = "add_one"}> ({\n "stable_mosaic_gpu.func.func"() ({\n }) {function_type = (!llvm.ptr, !llvm.ptr, i64, i64, !llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> (), sym_name = "mosaic_gpu_init_tma_desc", sym_visibility = "private"} : () -> () loc(#loc17)\n "stable_mosaic_gpu.llvm.mlir.global"() ({\n }) {addr_space = 4 : i32, global_type = !llvm.array<0 x i8>, linkage = #llvm.linkage, sym_name = "global_scratch", unnamed_addr = 0 : i64, visibility_ = 0 : i64} : () -> () loc(#loc17)\n "stable_mosaic_gpu.func.func"() ({\n ^bb0(%arg0: !llvm.ptr loc("jit(wrapped)/jit(main)/pallas_call"(#loc)), %arg1: !llvm.ptr loc("jit(wrapped)/jit(main)/pallas_call"(#loc))):\n %0 = "stable_mosaic_gpu.builtin.unrealized_conversion_cast"(%arg0) : (!llvm.ptr) -> !gpu.async.token loc(#loc17)\n %1 = "stable_mosaic_gpu.llvm.getelementptr"(%arg1) {elem_type = !llvm.ptr, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\n %2 = "stable_mosaic_gpu.llvm.load"(%1) {ordering = 0 : i64} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\n %3 = "stable_mosaic_gpu.llvm.mlir.undef"() : () -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\n %4 = "stable_mosaic_gpu.llvm.insertvalue"(%3, %2) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\n %5 = "stable_mosaic_gpu.llvm.insertvalue"(%4, %2) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\n %6 = "stable_mosaic_gpu.llvm.mlir.constant"() {value = 0 : i64} : () -> i64 loc(#loc17)\n %7 = "stable_mosaic_gpu.llvm.insertvalue"(%5, %6) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\n %8 = "stable_mosaic_gpu.llvm.mlir.constant"() {value = 256 : i64} : () -> i64 loc(#loc17)\n %9 = "stable_mosaic_gpu.llvm.insertvalue"(%7, %8) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\n %10 = "stable_mosaic_gpu.llvm.mlir.constant"() {value = 1 : i64} : () -> i64 loc(#loc17)\n %11 = "stable_mosaic_gpu.llvm.insertvalue"(%9, %10) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\n %12 = "stable_mosaic_gpu.builtin.unrealized_conversion_cast"(%11) : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>) -> memref<256xf32> loc(#loc17)\n %13 = "stable_mosaic_gpu.llvm.getelementptr"(%arg1) {elem_type = !llvm.ptr, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\n %14 = "stable_mosaic_gpu.llvm.load"(%13) {ordering = 0 : i64} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\n %15 = "stable_mosaic_gpu.llvm.mlir.undef"() : () -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\n %16 = "stable_mosaic_gpu.llvm.insertvalue"(%15, %14) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\n %17 = "stable_mosaic_gpu.llvm.insertvalue"(%16, %14) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\n %18 = "stable_mosaic_gpu.llvm.mlir.constant"() {value = 0 : i64} : () -> i64 loc(#loc17)\n %19 = "stable_mosaic_gpu.llvm.insertvalue"(%17, %18) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\n %20 = "stable_mosaic_gpu.llvm.mlir.constant"() {value = 256 : i64} : () -> i64 loc(#loc17)\n %21 = "stable_mosaic_gpu.llvm.insertvalue"(%19, %20) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\n %22 = "stable_mosaic_gpu.llvm.mlir.constant"() {value = 1 : i64} : () -> i64 loc(#loc17)\n %23 = "stable_mosaic_gpu.llvm.insertvalue"(%21, %22) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\n %24 = "stable_mosaic_gpu.builtin.unrealized_conversion_cast"(%23) : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>) -> memref<256xf32> loc(#loc17)\n %25 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i64} : () -> i64 loc(#loc17)\n %26 = "stable_mosaic_gpu.llvm.alloca"(%25) {alignment = 64 : i64, elem_type = !llvm.array<256 x i8>} : (i64) -> !llvm.ptr loc(#loc17)\n %27 = "stable_mosaic_gpu.llvm.getelementptr"(%26) {elem_type = i8, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\n %28:4 = "stable_mosaic_gpu.memref.extract_strided_metadata"(%12) : (memref<256xf32>) -> (memref, index, index, index) loc(#loc17)\n %29 = "stable_mosaic_gpu.memref.extract_aligned_pointer_as_index"(%12) : (memref<256xf32>) -> index loc(#loc17)\n %30 = "stable_mosaic_gpu.arith.index_cast"(%29) : (index) -> i64 loc(#loc17)\n %31 = "stable_mosaic_gpu.llvm.inttoptr"(%30) : (i64) -> !llvm.ptr loc(#loc17)\n %32 = "stable_mosaic_gpu.arith.index_cast"(%28#1) : (index) -> i64 loc(#loc17)\n %33 = "stable_mosaic_gpu.llvm.getelementptr"(%31, %32) {elem_type = f32, rawConstantIndices = array} : (!llvm.ptr, i64) -> !llvm.ptr loc(#loc17)\n %34 = "stable_mosaic_gpu.arith.constant"() {value = 6 : i64} : () -> i64 loc(#loc17)\n %35 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i64} : () -> i64 loc(#loc17)\n %36 = "stable_mosaic_gpu.arith.index_cast"(%28#2) : (index) -> i64 loc(#loc17)\n %37 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i64} : () -> i64 loc(#loc17)\n %38 = "stable_mosaic_gpu.llvm.alloca"(%37) {elem_type = i64} : (i64) -> !llvm.ptr loc(#loc17)\n %39 = "stable_mosaic_gpu.llvm.getelementptr"(%38) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\n "stable_mosaic_gpu.llvm.store"(%36, %39) {ordering = 0 : i64} : (i64, !llvm.ptr) -> () loc(#loc17)\n %40 = "stable_mosaic_gpu.arith.index_cast"(%28#3) : (index) -> i64 loc(#loc17)\n %41 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i64} : () -> i64 loc(#loc17)\n %42 = "stable_mosaic_gpu.llvm.alloca"(%41) {elem_type = i64} : (i64) -> !llvm.ptr loc(#loc17)\n %43 = "stable_mosaic_gpu.llvm.getelementptr"(%42) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\n "stable_mosaic_gpu.llvm.store"(%40, %43) {ordering = 0 : i64} : (i64, !llvm.ptr) -> () loc(#loc17)\n %44 = "stable_mosaic_gpu.arith.constant"() {value = 16 : i64} : () -> i64 loc(#loc17)\n %45 = "stable_mosaic_gpu.arith.constant"() {value = 256 : i64} : () -> i64 loc(#loc17)\n %46 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i64} : () -> i64 loc(#loc17)\n %47 = "stable_mosaic_gpu.llvm.alloca"(%46) {elem_type = i64} : (i64) -> !llvm.ptr loc(#loc17)\n %48 = "stable_mosaic_gpu.llvm.getelementptr"(%47) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\n "stable_mosaic_gpu.llvm.store"(%45, %48) {ordering = 0 : i64} : (i64, !llvm.ptr) -> () loc(#loc17)\n "stable_mosaic_gpu.func.call"(%27, %33, %34, %35, %38, %42, %44, %47) {callee = @mosaic_gpu_init_tma_desc} : (!llvm.ptr, !llvm.ptr, i64, i64, !llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> () loc(#loc17)\n %49 = "stable_mosaic_gpu.llvm.getelementptr"(%26) {elem_type = i8, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\n %50:4 = "stable_mosaic_gpu.memref.extract_strided_metadata"(%24) : (memref<256xf32>) -> (memref, index, index, index) loc(#loc17)\n %51 = "stable_mosaic_gpu.memref.extract_aligned_pointer_as_index"(%24) : (memref<256xf32>) -> index loc(#loc17)\n %52 = "stable_mosaic_gpu.arith.index_cast"(%51) : (index) -> i64 loc(#loc17)\n %53 = "stable_mosaic_gpu.llvm.inttoptr"(%52) : (i64) -> !llvm.ptr loc(#loc17)\n %54 = "stable_mosaic_gpu.arith.index_cast"(%50#1) : (index) -> i64 loc(#loc17)\n %55 = "stable_mosaic_gpu.llvm.getelementptr"(%53, %54) {elem_type = f32, rawConstantIndices = array} : (!llvm.ptr, i64) -> !llvm.ptr loc(#loc17)\n %56 = "stable_mosaic_gpu.arith.constant"() {value = 6 : i64} : () -> i64 loc(#loc17)\n %57 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i64} : () -> i64 loc(#loc17)\n %58 = "stable_mosaic_gpu.arith.index_cast"(%50#2) : (index) -> i64 loc(#loc17)\n %59 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i64} : () -> i64 loc(#loc17)\n %60 = "stable_mosaic_gpu.llvm.alloca"(%59) {elem_type = i64} : (i64) -> !llvm.ptr loc(#loc17)\n %61 = "stable_mosaic_gpu.llvm.getelementptr"(%60) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\n "stable_mosaic_gpu.llvm.store"(%58, %61) {ordering = 0 : i64} : (i64, !llvm.ptr) -> () loc(#loc17)\n %62 = "stable_mosaic_gpu.arith.index_cast"(%50#3) : (index) -> i64 loc(#loc17)\n %63 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i64} : () -> i64 loc(#loc17)\n %64 = "stable_mosaic_gpu.llvm.alloca"(%63) {elem_type = i64} : (i64) -> !llvm.ptr loc(#loc17)\n %65 = "stable_mosaic_gpu.llvm.getelementptr"(%64) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\n "stable_mosaic_gpu.llvm.store"(%62, %65) {ordering = 0 : i64} : (i64, !llvm.ptr) -> () loc(#loc17)\n %66 = "stable_mosaic_gpu.arith.constant"() {value = 16 : i64} : () -> i64 loc(#loc17)\n %67 = "stable_mosaic_gpu.arith.constant"() {value = 256 : i64} : () -> i64 loc(#loc17)\n %68 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i64} : () -> i64 loc(#loc17)\n %69 = "stable_mosaic_gpu.llvm.alloca"(%68) {elem_type = i64} : (i64) -> !llvm.ptr loc(#loc17)\n %70 = "stable_mosaic_gpu.llvm.getelementptr"(%69) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\n "stable_mosaic_gpu.llvm.store"(%67, %70) {ordering = 0 : i64} : (i64, !llvm.ptr) -> () loc(#loc17)\n "stable_mosaic_gpu.func.call"(%49, %55, %56, %57, %60, %64, %66, %69) {callee = @mosaic_gpu_init_tma_desc} : (!llvm.ptr, !llvm.ptr, i64, i64, !llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> () loc(#loc17)\n %71 = "stable_mosaic_gpu.llvm.load"(%26) {ordering = 0 : i64} : (!llvm.ptr) -> !llvm.array<256 x i8> loc(#loc17)\n %72 = "stable_mosaic_gpu.arith.constant"() {value = 2 : index} : () -> index loc(#loc17)\n %73 = "stable_mosaic_gpu.arith.constant"() {value = 1 : index} : () -> index loc(#loc17)\n %74 = "stable_mosaic_gpu.arith.constant"() {value = 1 : index} : () -> index loc(#loc17)\n %75 = "stable_mosaic_gpu.arith.constant"() {value = 128 : index} : () -> index loc(#loc17)\n %76 = "stable_mosaic_gpu.arith.constant"() {value = 1 : index} : () -> index loc(#loc17)\n %77 = "stable_mosaic_gpu.arith.constant"() {value = 1 : index} : () -> index loc(#loc17)\n %78 = "stable_mosaic_gpu.arith.constant"() {value = 2056 : i32} : () -> i32 loc(#loc17)\n %79 = "stable_mosaic_gpu.gpu.launch"(%0, %72, %73, %74, %75, %76, %77, %78) ({\n ^bb0(%arg2: index loc("-":94:40), %arg3: index loc("-":94:47), %arg4: index loc("-":94:54), %arg5: index loc("-":94:116), %arg6: index loc("-":94:123), %arg7: index loc("-":94:130), %arg8: index loc("-":94:65), %arg9: index loc("-":94:78), %arg10: index loc("-":94:91), %arg11: index loc("-":94:141), %arg12: index loc("-":94:157), %arg13: index loc("-":94:174)):\n %80 = "stable_mosaic_gpu.gpu.dynamic_shared_memory"() : () -> memref> loc(#loc17)\n %81 = "stable_mosaic_gpu.builtin.unrealized_conversion_cast"(%71) : (!llvm.array<256 x i8>) -> !llvm.ptr loc(#loc17)\n %82 = "stable_mosaic_gpu.llvm.getelementptr"(%81) {elem_type = i8, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc18)\n %83 = "stable_mosaic_gpu.llvm.getelementptr"(%81) {elem_type = i8, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc19)\n %84 = "stable_mosaic_gpu.arith.constant"() {value = 0 : index} : () -> index loc(#loc17)\n %85 = "stable_mosaic_gpu.memref.view"(%80, %84) : (memref>, index) -> memref<2048xi8, #gpu.address_space> loc(#loc17)\n %86 = "stable_mosaic_gpu.builtin.unrealized_conversion_cast"(%80) : (memref>) -> !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\n %87 = "stable_mosaic_gpu.llvm.extractvalue"(%86) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> !llvm.ptr<3> loc(#loc17)\n %88 = "stable_mosaic_gpu.llvm.extractvalue"(%86) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> i64 loc(#loc17)\n %89 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i64} : () -> i64 loc(#loc17)\n %90 = "stable_mosaic_gpu.llvm.mul"(%88, %89) : (i64, i64) -> i64 loc(#loc17)\n %91 = "stable_mosaic_gpu.llvm.ptrtoint"(%87) : (!llvm.ptr<3>) -> i64 loc(#loc17)\n %92 = "stable_mosaic_gpu.llvm.add"(%91, %90) : (i64, i64) -> i64 loc(#loc17)\n %93 = "stable_mosaic_gpu.llvm.inttoptr"(%92) : (i64) -> !llvm.ptr<3> loc(#loc17)\n %94 = "stable_mosaic_gpu.llvm.getelementptr"(%93) {elem_type = i8, rawConstantIndices = array} : (!llvm.ptr<3>) -> !llvm.ptr<3> loc(#loc17)\n %95 = "stable_mosaic_gpu.memref.alloca"() {operandSegmentSizes = array} : () -> memref loc(#loc17)\n %96 = "stable_mosaic_gpu.arith.constant"() {value = 0 : i32} : () -> i32 loc(#loc17)\n "stable_mosaic_gpu.memref.store"(%96, %95) : (i32, memref) -> () loc(#loc17)\n %97 = "stable_mosaic_gpu.nvvm.elect.sync"() : () -> i1 loc(#loc17)\n %98 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc17)\n %99 = "stable_mosaic_gpu.arith.index_cast"(%98) : (index) -> i32 loc(#loc17)\n %100 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc17)\n %101 = "stable_mosaic_gpu.arith.index_cast"(%100) : (index) -> i32 loc(#loc17)\n %102 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc17)\n %103 = "stable_mosaic_gpu.arith.index_cast"(%102) : (index) -> i32 loc(#loc17)\n %104 = "stable_mosaic_gpu.arith.muli"(%103, %101) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\n %105 = "stable_mosaic_gpu.arith.addi"(%99, %104) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\n %106 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc17)\n %107 = "stable_mosaic_gpu.arith.index_cast"(%106) : (index) -> i32 loc(#loc17)\n %108 = "stable_mosaic_gpu.arith.muli"(%101, %107) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\n %109 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc17)\n %110 = "stable_mosaic_gpu.arith.index_cast"(%109) : (index) -> i32 loc(#loc17)\n %111 = "stable_mosaic_gpu.arith.muli"(%110, %108) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\n %112 = "stable_mosaic_gpu.arith.addi"(%105, %111) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\n %113 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc17)\n %114 = "stable_mosaic_gpu.arith.index_cast"(%113) : (index) -> i32 loc(#loc17)\n %115 = "stable_mosaic_gpu.arith.muli"(%108, %114) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\n %116 = "stable_mosaic_gpu.arith.constant"() {value = 5 : i32} : () -> i32 loc(#loc17)\n %117 = "stable_mosaic_gpu.arith.shrui"(%112, %116) : (i32, i32) -> i32 loc(#loc17)\n %118 = "stable_mosaic_gpu.arith.constant"() {value = -1 : i32} : () -> i32 loc(#loc17)\n %119 = "stable_mosaic_gpu.arith.constant"() {value = 0 : i32} : () -> i32 loc(#loc17)\n %120 = "stable_mosaic_gpu.arith.constant"() {value = 31 : i32} : () -> i32 loc(#loc17)\n %121 = "stable_mosaic_gpu.nvvm.shfl.sync"(%118, %117, %119, %120) {kind = #nvvm} : (i32, i32, i32, i32) -> i32 loc(#loc17)\n %122 = "stable_mosaic_gpu.arith.constant"() {value = 0 : i32} : () -> i32 loc(#loc17)\n %123 = "stable_mosaic_gpu.arith.cmpi"(%121, %122) {predicate = 0 : i64} : (i32, i32) -> i1 loc(#loc17)\n %124 = "stable_mosaic_gpu.arith.andi"(%123, %97) : (i1, i1) -> i1 loc(#loc17)\n "stable_mosaic_gpu.scf.if"(%124) ({\n %332 = "stable_mosaic_gpu.llvm.getelementptr"(%94) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr<3>) -> !llvm.ptr<3> loc(#loc17)\n %333 = "stable_mosaic_gpu.arith.constant"() {value = 128 : i32} : () -> i32 loc(#loc17)\n "stable_mosaic_gpu.nvvm.mbarrier.init.shared"(%332, %333) : (!llvm.ptr<3>, i32) -> () loc(#loc17)\n "stable_mosaic_gpu.scf.yield"() : () -> () loc(#loc13)\n }, {\n }) : (i1) -> () loc(#loc17)\n %125 = "stable_mosaic_gpu.arith.constant"() {value = 0 : i32} : () -> i32 loc(#loc17)\n "stable_mosaic_gpu.nvvm.fence.mbarrier.init"() : () -> () loc(#loc17)\n "stable_mosaic_gpu.gpu.barrier"() : () -> () loc(#loc17)\n %126 = "stable_mosaic_gpu.nvvm.elect.sync"() : () -> i1 loc(#loc17)\n %127 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc17)\n %128 = "stable_mosaic_gpu.arith.index_cast"(%127) : (index) -> i32 loc(#loc17)\n %129 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc17)\n %130 = "stable_mosaic_gpu.arith.index_cast"(%129) : (index) -> i32 loc(#loc17)\n %131 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc17)\n %132 = "stable_mosaic_gpu.arith.index_cast"(%131) : (index) -> i32 loc(#loc17)\n %133 = "stable_mosaic_gpu.arith.muli"(%132, %130) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\n %134 = "stable_mosaic_gpu.arith.addi"(%128, %133) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\n %135 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc17)\n %136 = "stable_mosaic_gpu.arith.index_cast"(%135) : (index) -> i32 loc(#loc17)\n %137 = "stable_mosaic_gpu.arith.muli"(%130, %136) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\n %138 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc17)\n %139 = "stable_mosaic_gpu.arith.index_cast"(%138) : (index) -> i32 loc(#loc17)\n %140 = "stable_mosaic_gpu.arith.muli"(%139, %137) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\n %141 = "stable_mosaic_gpu.arith.addi"(%134, %140) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\n %142 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc17)\n %143 = "stable_mosaic_gpu.arith.index_cast"(%142) : (index) -> i32 loc(#loc17)\n %144 = "stable_mosaic_gpu.arith.muli"(%137, %143) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\n %145 = "stable_mosaic_gpu.arith.constant"() {value = 5 : i32} : () -> i32 loc(#loc17)\n %146 = "stable_mosaic_gpu.arith.shrui"(%141, %145) : (i32, i32) -> i32 loc(#loc17)\n %147 = "stable_mosaic_gpu.arith.constant"() {value = -1 : i32} : () -> i32 loc(#loc17)\n %148 = "stable_mosaic_gpu.arith.constant"() {value = 0 : i32} : () -> i32 loc(#loc17)\n %149 = "stable_mosaic_gpu.arith.constant"() {value = 31 : i32} : () -> i32 loc(#loc17)\n %150 = "stable_mosaic_gpu.nvvm.shfl.sync"(%147, %146, %148, %149) {kind = #nvvm} : (i32, i32, i32, i32) -> i32 loc(#loc17)\n %151 = "stable_mosaic_gpu.arith.constant"() {value = 4 : i32} : () -> i32 loc(#loc17)\n %152 = "stable_mosaic_gpu.arith.remui"(%150, %151) : (i32, i32) -> i32 loc(#loc17)\n %153 = "stable_mosaic_gpu.arith.constant"() {value = 0 : i32} : () -> i32 loc(#loc17)\n %154 = "stable_mosaic_gpu.arith.cmpi"(%152, %153) {predicate = 0 : i64} : (i32, i32) -> i1 loc(#loc17)\n %155 = "stable_mosaic_gpu.arith.andi"(%154, %126) : (i1, i1) -> i1 loc(#loc17)\n %156 = "stable_mosaic_gpu.nvvm.elect.sync"() : () -> i1 loc(#loc17)\n %157 = "stable_mosaic_gpu.gpu.block_id"() {dimension = #gpu} : () -> index loc(#loc17)\n %158 = "stable_mosaic_gpu.arith.index_cast"(%157) : (index) -> i32 loc(#loc17)\n %159 = "stable_mosaic_gpu.gpu.dynamic_shared_memory"() : () -> memref> loc(#loc20)\n %160 = "stable_mosaic_gpu.arith.constant"() {value = 0 : index} : () -> index loc(#loc20)\n %161 = "stable_mosaic_gpu.memref.view"(%159, %160) : (memref>, index) -> memref<1x256xf32, #gpu.address_space> loc(#loc20)\n %162 = "stable_mosaic_gpu.gpu.dynamic_shared_memory"() : () -> memref> loc(#loc20)\n %163 = "stable_mosaic_gpu.arith.constant"() {value = 1024 : index} : () -> index loc(#loc20)\n %164 = "stable_mosaic_gpu.memref.view"(%162, %163) : (memref>, index) -> memref<1x256xf32, #gpu.address_space> loc(#loc20)\n %165 = "stable_mosaic_gpu.arith.constant"() {value = 0 : index} : () -> index loc(#loc19)\n %166 = "stable_mosaic_gpu.memref.subview"(%161, %165) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<1x256xf32, #gpu.address_space>, index) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc19)\n %167 = "stable_mosaic_gpu.arith.constant"() {value = 0 : index} : () -> index loc(#loc19)\n %168 = "stable_mosaic_gpu.arith.index_castui"(%167) : (index) -> i32 loc(#loc19)\n %169 = "stable_mosaic_gpu.arith.addi"(%125, %168) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc19)\n %170 = "stable_mosaic_gpu.arith.constant"() {value = 8 : i32} : () -> i32 loc(#loc19)\n %171 = "stable_mosaic_gpu.llvm.getelementptr"(%94, %169) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3> loc(#loc19)\n "stable_mosaic_gpu.nvvm.mbarrier.arrive.expect_tx.shared"(%171, %170) : (!llvm.ptr<3>, i32) -> () loc(#loc19)\n %172 = "stable_mosaic_gpu.arith.constant"() {value = 0 : index} : () -> index loc(#loc19)\n %173 = "stable_mosaic_gpu.arith.index_cast"(%172) : (index) -> i32 loc(#loc19)\n %174 = "stable_mosaic_gpu.builtin.unrealized_conversion_cast"(%166) : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> loc(#loc19)\n %175 = "stable_mosaic_gpu.llvm.extractvalue"(%174) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> !llvm.ptr<3> loc(#loc19)\n %176 = "stable_mosaic_gpu.llvm.extractvalue"(%174) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> i64 loc(#loc19)\n %177 = "stable_mosaic_gpu.arith.constant"() {value = 4 : i64} : () -> i64 loc(#loc19)\n %178 = "stable_mosaic_gpu.llvm.mul"(%176, %177) : (i64, i64) -> i64 loc(#loc19)\n %179 = "stable_mosaic_gpu.llvm.ptrtoint"(%175) : (!llvm.ptr<3>) -> i64 loc(#loc19)\n %180 = "stable_mosaic_gpu.llvm.add"(%179, %178) : (i64, i64) -> i64 loc(#loc19)\n %181 = "stable_mosaic_gpu.llvm.inttoptr"(%180) : (i64) -> !llvm.ptr<3> loc(#loc19)\n %182 = "stable_mosaic_gpu.arith.constant"() {value = 1024 : i32} : () -> i32 loc(#loc19)\n %183 = "stable_mosaic_gpu.llvm.getelementptr"(%94, %169) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3> loc(#loc19)\n "stable_mosaic_gpu.nvvm.cp.async.bulk.tensor.shared.cluster.global"(%181, %83, %173, %183, %155) {operandSegmentSizes = array} : (!llvm.ptr<3>, !llvm.ptr, i32, !llvm.ptr<3>, i1) -> () loc(#loc19)\n %184 = "stable_mosaic_gpu.arith.constant"() {value = 0 : i32} : () -> i32 loc(#loc21)\n %185 = "stable_mosaic_gpu.arith.constant"() {value = 0 : i32} : () -> i32 loc(#loc21)\n %186 = "stable_mosaic_gpu.arith.addi"(%185, %184) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc21)\n %187 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i32} : () -> i32 loc(#loc22)\n %188 = "stable_mosaic_gpu.arith.remsi"(%186, %187) : (i32, i32) -> i32 loc(#loc22)\n %189 = "stable_mosaic_gpu.arith.index_cast"(%188) : (i32) -> index loc(#loc23)\n %190 = "stable_mosaic_gpu.arith.index_castui"(%189) : (index) -> i32 loc(#loc23)\n %191 = "stable_mosaic_gpu.arith.addi"(%125, %190) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc23)\n %192 = "stable_mosaic_gpu.memref.load"(%95) : (memref) -> i32 loc(#loc23)\n %193 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i32} : () -> i32 loc(#loc23)\n %194 = "stable_mosaic_gpu.arith.shli"(%193, %191) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc23)\n %195 = "stable_mosaic_gpu.arith.andi"(%192, %194) : (i32, i32) -> i32 loc(#loc23)\n %196 = "stable_mosaic_gpu.arith.constant"() {value = 0 : i32} : () -> i32 loc(#loc23)\n %197 = "stable_mosaic_gpu.arith.cmpi"(%195, %196) {predicate = 1 : i64} : (i32, i32) -> i1 loc(#loc23)\n %198 = "stable_mosaic_gpu.arith.xori"(%192, %194) : (i32, i32) -> i32 loc(#loc23)\n "stable_mosaic_gpu.memref.store"(%198, %95) : (i32, memref) -> () loc(#loc23)\n %199 = "stable_mosaic_gpu.arith.constant"() {value = 10000000 : i32} : () -> i32 loc(#loc23)\n %200 = "stable_mosaic_gpu.arith.extui"(%197) : (i1) -> i32 loc(#loc23)\n %201 = "stable_mosaic_gpu.llvm.getelementptr"(%94, %191) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3> loc(#loc23)\n "stable_mosaic_gpu.nvvm.mbarrier.try_wait.parity.shared"(%201, %200, %199) : (!llvm.ptr<3>, i32, i32) -> () loc(#loc23)\n %202 = "stable_mosaic_gpu.arith.index_cast"(%188) : (i32) -> index loc(#loc24)\n %203 = "stable_mosaic_gpu.memref.subview"(%161, %202) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<1x256xf32, #gpu.address_space>, index) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc24)\n %204 = "stable_mosaic_gpu.arith.index_cast"(%188) : (i32) -> index loc(#loc24)\n %205 = "stable_mosaic_gpu.memref.subview"(%164, %204) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<1x256xf32, #gpu.address_space>, index) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc24)\n %206 = "stable_mosaic_gpu.gpu.block_id"() {dimension = #gpu} : () -> index loc(#loc24)\n %207 = "stable_mosaic_gpu.arith.index_cast"(%206) : (index) -> i32 loc(#loc24)\n %208 = "stable_mosaic_gpu.memref.subview"(%203) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc25)\n %209 = "stable_mosaic_gpu.memref.collapse_shape"(%208) {reassociation = [[0]]} : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc25)\n %210 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc25)\n %211 = "stable_mosaic_gpu.arith.constant"() {value = 128 : index} : () -> index loc(#loc25)\n %212 = "stable_mosaic_gpu.arith.remui"(%210, %211) : (index, index) -> index loc(#loc25)\n %213 = "stable_mosaic_gpu.arith.constant"() {value = 2 : index} : () -> index loc(#loc25)\n %214 = "stable_mosaic_gpu.arith.muli"(%212, %213) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc25)\n %215 = "stable_mosaic_gpu.arith.constant"() {value = 0 : index} : () -> index loc(#loc25)\n %216 = "stable_mosaic_gpu.arith.addi"(%214, %215) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc25)\n %217 = "stable_mosaic_gpu.vector.load"(%209, %216) : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>, index) -> vector<2xf32> loc(#loc25)\n %218 = "stable_mosaic_gpu.arith.constant"() {value = 1.000000e+00 : f32} : () -> f32 loc(#loc26)\n %219 = "stable_mosaic_gpu.vector.splat"(%218) : (f32) -> vector<2xf32> loc(#loc26)\n %220 = "stable_mosaic_gpu.arith.addf"(%217, %219) {fastmath = #arith.fastmath} : (vector<2xf32>, vector<2xf32>) -> vector<2xf32> loc(#loc26)\n %221 = "stable_mosaic_gpu.memref.subview"(%205) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc27)\n %222 = "stable_mosaic_gpu.memref.collapse_shape"(%221) {reassociation = [[0]]} : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc27)\n %223 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc27)\n %224 = "stable_mosaic_gpu.arith.constant"() {value = 128 : index} : () -> index loc(#loc27)\n %225 = "stable_mosaic_gpu.arith.remui"(%223, %224) : (index, index) -> index loc(#loc27)\n %226 = "stable_mosaic_gpu.arith.constant"() {value = 2 : index} : () -> index loc(#loc27)\n %227 = "stable_mosaic_gpu.arith.muli"(%225, %226) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc27)\n %228 = "stable_mosaic_gpu.arith.constant"() {value = 0 : index} : () -> index loc(#loc27)\n %229 = "stable_mosaic_gpu.arith.addi"(%227, %228) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc27)\n %230 = "stable_mosaic_gpu.vector.load"(%222, %229) : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>, index) -> vector<2xf32> loc(#loc27)\n %231 = "stable_mosaic_gpu.memref.collapse_shape"(%221) {reassociation = [[0]]} : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc27)\n %232 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc27)\n %233 = "stable_mosaic_gpu.arith.constant"() {value = 128 : index} : () -> index loc(#loc27)\n %234 = "stable_mosaic_gpu.arith.remui"(%232, %233) : (index, index) -> index loc(#loc27)\n %235 = "stable_mosaic_gpu.arith.constant"() {value = 2 : index} : () -> index loc(#loc27)\n %236 = "stable_mosaic_gpu.arith.muli"(%234, %235) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc27)\n %237 = "stable_mosaic_gpu.arith.constant"() {value = 0 : index} : () -> index loc(#loc27)\n %238 = "stable_mosaic_gpu.arith.addi"(%236, %237) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc27)\n "stable_mosaic_gpu.vector.store"(%220, %231, %238) : (vector<2xf32>, memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>, index) -> () loc(#loc27)\n "stable_mosaic_gpu.nvvm.cp.async.bulk.commit.group"() : () -> () loc(#loc28)\n %239 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i32} : () -> i32 loc(#loc29)\n %240 = "stable_mosaic_gpu.arith.addi"(%186, %239) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc29)\n %241 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i32} : () -> i32 loc(#loc22)\n %242 = "stable_mosaic_gpu.arith.remsi"(%240, %241) : (i32, i32) -> i32 loc(#loc22)\n %243 = "stable_mosaic_gpu.arith.constant"() {value = 0 : i32} : () -> i32 loc(#loc30)\n %244 = "stable_mosaic_gpu.arith.cmpi"(%186, %243) {predicate = 9 : i64} : (i32, i32) -> i1 loc(#loc30)\n %245 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i32} : () -> i32 loc(#loc31)\n %246 = "stable_mosaic_gpu.arith.cmpi"(%240, %245) {predicate = 6 : i64} : (i32, i32) -> i1 loc(#loc31)\n %247 = "stable_mosaic_gpu.arith.andi"(%244, %246) : (i1, i1) -> i1 loc(#loc32)\n %248 = "stable_mosaic_gpu.arith.extui"(%247) : (i1) -> i32 loc(#loc33)\n %249 = "stable_mosaic_gpu.arith.index_cast"(%248) : (i32) -> index loc(#loc34)\n "stable_mosaic_gpu.scf.index_switch"(%249) ({\n %313 = "stable_mosaic_gpu.arith.index_cast"(%242) : (i32) -> index loc(#loc19)\n %314 = "stable_mosaic_gpu.memref.subview"(%161, %313) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<1x256xf32, #gpu.address_space>, index) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc19)\n %315 = "stable_mosaic_gpu.arith.index_cast"(%242) : (i32) -> index loc(#loc19)\n %316 = "stable_mosaic_gpu.arith.index_castui"(%315) : (index) -> i32 loc(#loc19)\n %317 = "stable_mosaic_gpu.arith.addi"(%125, %316) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc19)\n %318 = "stable_mosaic_gpu.arith.constant"() {value = 8 : i32} : () -> i32 loc(#loc19)\n %319 = "stable_mosaic_gpu.llvm.getelementptr"(%94, %317) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3> loc(#loc19)\n "stable_mosaic_gpu.nvvm.mbarrier.arrive.expect_tx.shared"(%319, %318) : (!llvm.ptr<3>, i32) -> () loc(#loc19)\n %320 = "stable_mosaic_gpu.arith.constant"() {value = 0 : index} : () -> index loc(#loc19)\n %321 = "stable_mosaic_gpu.arith.index_cast"(%320) : (index) -> i32 loc(#loc19)\n %322 = "stable_mosaic_gpu.builtin.unrealized_conversion_cast"(%314) : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> loc(#loc19)\n %323 = "stable_mosaic_gpu.llvm.extractvalue"(%322) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> !llvm.ptr<3> loc(#loc19)\n %324 = "stable_mosaic_gpu.llvm.extractvalue"(%322) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> i64 loc(#loc19)\n %325 = "stable_mosaic_gpu.arith.constant"() {value = 4 : i64} : () -> i64 loc(#loc19)\n %326 = "stable_mosaic_gpu.llvm.mul"(%324, %325) : (i64, i64) -> i64 loc(#loc19)\n %327 = "stable_mosaic_gpu.llvm.ptrtoint"(%323) : (!llvm.ptr<3>) -> i64 loc(#loc19)\n %328 = "stable_mosaic_gpu.llvm.add"(%327, %326) : (i64, i64) -> i64 loc(#loc19)\n %329 = "stable_mosaic_gpu.llvm.inttoptr"(%328) : (i64) -> !llvm.ptr<3> loc(#loc19)\n %330 = "stable_mosaic_gpu.arith.constant"() {value = 1024 : i32} : () -> i32 loc(#loc19)\n %331 = "stable_mosaic_gpu.llvm.getelementptr"(%94, %317) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3> loc(#loc19)\n "stable_mosaic_gpu.nvvm.cp.async.bulk.tensor.shared.cluster.global"(%329, %83, %321, %331, %155) {operandSegmentSizes = array} : (!llvm.ptr<3>, !llvm.ptr, i32, !llvm.ptr<3>, i1) -> () loc(#loc19)\n "stable_mosaic_gpu.scf.yield"() : () -> () loc(#loc16)\n }, {\n "stable_mosaic_gpu.scf.yield"() : () -> () loc(#loc34)\n }) {cases = array} : (index) -> () loc(#loc34)\n %250 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i32} : () -> i32 loc(#loc21)\n %251 = "stable_mosaic_gpu.arith.addi"(%184, %250) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc21)\n "stable_mosaic_gpu.nvvm.fence.proxy"() {kind = #nvvm.proxy_kind, space = #nvvm.shared_space} : () -> () loc(#loc35)\n %252 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc35)\n %253 = "stable_mosaic_gpu.arith.index_cast"(%252) : (index) -> i32 loc(#loc35)\n %254 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc35)\n %255 = "stable_mosaic_gpu.arith.index_cast"(%254) : (index) -> i32 loc(#loc35)\n %256 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc35)\n %257 = "stable_mosaic_gpu.arith.index_cast"(%256) : (index) -> i32 loc(#loc35)\n %258 = "stable_mosaic_gpu.arith.muli"(%257, %255) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc35)\n %259 = "stable_mosaic_gpu.arith.addi"(%253, %258) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc35)\n %260 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc35)\n %261 = "stable_mosaic_gpu.arith.index_cast"(%260) : (index) -> i32 loc(#loc35)\n %262 = "stable_mosaic_gpu.arith.muli"(%255, %261) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc35)\n %263 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc35)\n %264 = "stable_mosaic_gpu.arith.index_cast"(%263) : (index) -> i32 loc(#loc35)\n %265 = "stable_mosaic_gpu.arith.muli"(%264, %262) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc35)\n %266 = "stable_mosaic_gpu.arith.addi"(%259, %265) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc35)\n %267 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc35)\n %268 = "stable_mosaic_gpu.arith.index_cast"(%267) : (index) -> i32 loc(#loc35)\n %269 = "stable_mosaic_gpu.arith.muli"(%262, %268) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc35)\n %270 = "stable_mosaic_gpu.arith.constant"() {value = 7 : i32} : () -> i32 loc(#loc35)\n %271 = "stable_mosaic_gpu.arith.shrui"(%266, %270) : (i32, i32) -> i32 loc(#loc35)\n %272 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i32} : () -> i32 loc(#loc35)\n %273 = "stable_mosaic_gpu.arith.addi"(%271, %272) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc35)\n %274 = "stable_mosaic_gpu.llvm.inline_asm"(%273) {asm_string = "bar.sync $0, 128;", constraints = "r", has_side_effects} : (i32) -> !llvm.void loc(#loc35)\n %275 = "stable_mosaic_gpu.arith.constant"() {value = 0 : i32} : () -> i32 loc(#loc22)\n %276 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i32} : () -> i32 loc(#loc22)\n %277 = "stable_mosaic_gpu.arith.remsi"(%275, %276) : (i32, i32) -> i32 loc(#loc22)\n %278 = "stable_mosaic_gpu.arith.index_cast"(%277) : (i32) -> index loc(#loc18)\n %279 = "stable_mosaic_gpu.memref.subview"(%164, %278) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<1x256xf32, #gpu.address_space>, index) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc18)\n %280 = "stable_mosaic_gpu.arith.constant"() {value = 0 : index} : () -> index loc(#loc18)\n %281 = "stable_mosaic_gpu.arith.index_cast"(%280) : (index) -> i32 loc(#loc18)\n %282 = "stable_mosaic_gpu.builtin.unrealized_conversion_cast"(%279) : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> loc(#loc18)\n %283 = "stable_mosaic_gpu.llvm.extractvalue"(%282) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> !llvm.ptr<3> loc(#loc18)\n %284 = "stable_mosaic_gpu.llvm.extractvalue"(%282) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> i64 loc(#loc18)\n %285 = "stable_mosaic_gpu.arith.constant"() {value = 4 : i64} : () -> i64 loc(#loc18)\n %286 = "stable_mosaic_gpu.llvm.mul"(%284, %285) : (i64, i64) -> i64 loc(#loc18)\n %287 = "stable_mosaic_gpu.llvm.ptrtoint"(%283) : (!llvm.ptr<3>) -> i64 loc(#loc18)\n %288 = "stable_mosaic_gpu.llvm.add"(%287, %286) : (i64, i64) -> i64 loc(#loc18)\n %289 = "stable_mosaic_gpu.llvm.inttoptr"(%288) : (i64) -> !llvm.ptr<3> loc(#loc18)\n "stable_mosaic_gpu.nvvm.cp.async.bulk.tensor.global.shared.cta"(%82, %289, %281, %155) {operandSegmentSizes = array} : (!llvm.ptr, !llvm.ptr<3>, i32, i1) -> () loc(#loc18)\n "stable_mosaic_gpu.nvvm.cp.async.bulk.commit.group"() : () -> () loc(#loc28)\n "stable_mosaic_gpu.nvvm.cp.async.bulk.wait_group"() {group = 0 : i32} : () -> () loc(#loc36)\n %290 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc36)\n %291 = "stable_mosaic_gpu.arith.index_cast"(%290) : (index) -> i32 loc(#loc36)\n %292 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc36)\n %293 = "stable_mosaic_gpu.arith.index_cast"(%292) : (index) -> i32 loc(#loc36)\n %294 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc36)\n %295 = "stable_mosaic_gpu.arith.index_cast"(%294) : (index) -> i32 loc(#loc36)\n %296 = "stable_mosaic_gpu.arith.muli"(%295, %293) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc36)\n %297 = "stable_mosaic_gpu.arith.addi"(%291, %296) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc36)\n %298 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc36)\n %299 = "stable_mosaic_gpu.arith.index_cast"(%298) : (index) -> i32 loc(#loc36)\n %300 = "stable_mosaic_gpu.arith.muli"(%293, %299) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc36)\n %301 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc36)\n %302 = "stable_mosaic_gpu.arith.index_cast"(%301) : (index) -> i32 loc(#loc36)\n %303 = "stable_mosaic_gpu.arith.muli"(%302, %300) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc36)\n %304 = "stable_mosaic_gpu.arith.addi"(%297, %303) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc36)\n %305 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc36)\n %306 = "stable_mosaic_gpu.arith.index_cast"(%305) : (index) -> i32 loc(#loc36)\n %307 = "stable_mosaic_gpu.arith.muli"(%300, %306) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc36)\n %308 = "stable_mosaic_gpu.arith.constant"() {value = 7 : i32} : () -> i32 loc(#loc36)\n %309 = "stable_mosaic_gpu.arith.shrui"(%304, %308) : (i32, i32) -> i32 loc(#loc36)\n %310 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i32} : () -> i32 loc(#loc36)\n %311 = "stable_mosaic_gpu.arith.addi"(%309, %310) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc36)\n %312 = "stable_mosaic_gpu.llvm.inline_asm"(%311) {asm_string = "bar.sync $0, 128;", constraints = "r", has_side_effects} : (i32) -> !llvm.void loc(#loc36)\n "stable_mosaic_gpu.gpu.terminator"() : () -> () loc(#loc17)\n }) {operandSegmentSizes = array, workgroup_attributions = 0 : i64} : (!gpu.async.token, index, index, index, index, index, index, i32) -> !gpu.async.token loc(#loc17)\n "stable_mosaic_gpu.func.return"() : () -> () loc(#loc17)\n }) {function_type = (!llvm.ptr, !llvm.ptr) -> (), llvm.emit_c_interface, sym_name = "mosaic_gpu_body"} : () -> () loc(#loc17)\n}) {stable_mosaic_gpu.version = 1 : i64} : () -> () loc(#loc17)\n#loc13 = loc("-":141:7)\n#loc14 = loc("third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py":78:19)\n#loc15 = loc("third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py":78:6)\n#loc16 = loc("-":279:7)\n#loc18 = loc("/copy_smem_to_gmem"(#loc))\n#loc19 = loc("/copy_gmem_to_smem"(#loc))\n#loc20 = loc("/run_scoped"(#loc))\n#loc21 = loc("/scan"(#loc))\n#loc22 = loc("/rem"(#loc))\n#loc23 = loc("/barrier_wait"(#loc))\n#loc24 = loc("/jaxpr_call"(#loc))\n#loc25 = loc("/get"(#loc14))\n#loc26 = loc("/add"(#loc14))\n#loc27 = loc("/swap"(#loc15))\n#loc28 = loc("/commit_group"(#loc))\n#loc29 = loc("/add"(#loc))\n#loc30 = loc("/ge"(#loc))\n#loc31 = loc("/lt"(#loc))\n#loc32 = loc("/and"(#loc))\n#loc33 = loc("/convert_element_type"(#loc))\n#loc34 = loc("/cond"(#loc))\n#loc35 = loc("/commit_smem"(#loc))\n#loc36 = loc("/wait_smem_to_gmem"(#loc))\n\x00mosaic_gpu\x00\x08\'\x07\x05\x1f\x01\x0b!%\'/1\x11357\x1d9\x1f\x1d\x1f', + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +kernel_data_2025_09_07 = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['AllocateBuffer', 'mosaic_gpu_v2'], + serialized_date=datetime.date(2025, 9, 7), + inputs=(array([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., + 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., + 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., + 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., + 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., + 55., 56., 57., 58., 59., 60., 61., 62., 63., 64., 65., + 66., 67., 68., 69., 70., 71., 72., 73., 74., 75., 76., + 77., 78., 79., 80., 81., 82., 83., 84., 85., 86., 87., + 88., 89., 90., 91., 92., 93., 94., 95., 96., 97., 98., + 99., 100., 101., 102., 103., 104., 105., 106., 107., 108., 109., + 110., 111., 112., 113., 114., 115., 116., 117., 118., 119., 120., + 121., 122., 123., 124., 125., 126., 127.], dtype=float32),), + expected_outputs=(array([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., + 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., + 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., + 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., 44., + 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55., + 56., 57., 58., 59., 60., 61., 62., 63., 64., 65., 66., + 67., 68., 69., 70., 71., 72., 73., 74., 75., 76., 77., + 78., 79., 80., 81., 82., 83., 84., 85., 86., 87., 88., + 89., 90., 91., 92., 93., 94., 95., 96., 97., 98., 99., + 100., 101., 102., 103., 104., 105., 106., 107., 108., 109., 110., + 111., 112., 113., 114., 115., 116., 117., 118., 119., 120., 121., + 122., 123., 124., 125., 126., 127., 128.], dtype=float32),), + mlir_module_text=r""" +#loc1 = loc("operands[0]") +module @jit_wrapper attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<128xf32> loc("operands[0]")) -> (tensor<128xf32> {jax.result_info = "result"}) { + %0 = stablehlo.custom_call @AllocateBuffer() : () -> tensor<128xf32> loc(#loc4) + %1 = stablehlo.custom_call @mosaic_gpu_v2(%arg0, %0) {mhlo.backend_config = {kernel_hash = "\A4,O\92\88\FB;\FC\DFyB@\19\BB)\BC\EF\D2\8Asl\B2N\10\00\B4\91J\A9\FA\1B\ED", module = "#loc = loc(\22pallas_call\22)\0A#loc1 = loc(\22-\22:50:40)\0A#loc2 = loc(\22-\22:50:47)\0A#loc3 = loc(\22-\22:50:54)\0A#loc4 = loc(\22-\22:50:115)\0A#loc5 = loc(\22-\22:50:122)\0A#loc6 = loc(\22-\22:50:129)\0A#loc7 = loc(\22-\22:50:65)\0A#loc8 = loc(\22-\22:50:78)\0A#loc9 = loc(\22-\22:50:91)\0A#loc10 = loc(\22-\22:50:140)\0A#loc11 = loc(\22-\22:50:156)\0A#loc12 = loc(\22-\22:50:172)\0A\22builtin.module\22() <{sym_name = \22add_one\22}> ({\0A \22stable_mosaic_gpu.func.func\22() ({\0A }) {function_type = (!llvm.ptr, !llvm.ptr, i64, i64, !llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> (), sym_name = \22mosaic_gpu_init_tma_desc\22, sym_visibility = \22private\22} : () -> () loc(#loc)\0A \22stable_mosaic_gpu.llvm.mlir.global\22() ({\0A }) {addr_space = 4 : i32, global_type = !llvm.array<0 x i8>, linkage = #llvm.linkage, sym_name = \22global_scratch\22, unnamed_addr = 0 : i64, visibility_ = 0 : i64} : () -> () loc(#loc)\0A \22stable_mosaic_gpu.func.func\22() ({\0A ^bb0(%arg0: !llvm.ptr loc(\22pallas_call\22), %arg1: !llvm.ptr loc(\22pallas_call\22)):\0A %0 = \22stable_mosaic_gpu.builtin.unrealized_conversion_cast\22(%arg0) : (!llvm.ptr) -> !gpu.async.token loc(#loc)\0A %1 = \22stable_mosaic_gpu.llvm.getelementptr\22(%arg1) {elem_type = !llvm.ptr, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc)\0A %2 = \22stable_mosaic_gpu.llvm.load\22(%1) {ordering = 0 : i64} : (!llvm.ptr) -> !llvm.ptr loc(#loc)\0A %3 = \22stable_mosaic_gpu.llvm.mlir.undef\22() : () -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc)\0A %4 = \22stable_mosaic_gpu.llvm.insertvalue\22(%3, %2) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc)\0A %5 = \22stable_mosaic_gpu.llvm.insertvalue\22(%4, %2) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc)\0A %6 = \22stable_mosaic_gpu.llvm.mlir.constant\22() {value = 0 : i64} : () -> i64 loc(#loc)\0A %7 = \22stable_mosaic_gpu.llvm.insertvalue\22(%5, %6) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc)\0A %8 = \22stable_mosaic_gpu.llvm.mlir.constant\22() {value = 128 : i64} : () -> i64 loc(#loc)\0A %9 = \22stable_mosaic_gpu.llvm.insertvalue\22(%7, %8) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc)\0A %10 = \22stable_mosaic_gpu.llvm.mlir.constant\22() {value = 1 : i64} : () -> i64 loc(#loc)\0A %11 = \22stable_mosaic_gpu.llvm.insertvalue\22(%9, %10) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc)\0A %12 = \22stable_mosaic_gpu.builtin.unrealized_conversion_cast\22(%11) : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>) -> memref<128xf32> loc(#loc)\0A %13 = \22stable_mosaic_gpu.llvm.getelementptr\22(%arg1) {elem_type = !llvm.ptr, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc)\0A %14 = \22stable_mosaic_gpu.llvm.load\22(%13) {ordering = 0 : i64} : (!llvm.ptr) -> !llvm.ptr loc(#loc)\0A %15 = \22stable_mosaic_gpu.llvm.mlir.undef\22() : () -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc)\0A %16 = \22stable_mosaic_gpu.llvm.insertvalue\22(%15, %14) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc)\0A %17 = \22stable_mosaic_gpu.llvm.insertvalue\22(%16, %14) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc)\0A %18 = \22stable_mosaic_gpu.llvm.mlir.constant\22() {value = 0 : i64} : () -> i64 loc(#loc)\0A %19 = \22stable_mosaic_gpu.llvm.insertvalue\22(%17, %18) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc)\0A %20 = \22stable_mosaic_gpu.llvm.mlir.constant\22() {value = 128 : i64} : () -> i64 loc(#loc)\0A %21 = \22stable_mosaic_gpu.llvm.insertvalue\22(%19, %20) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc)\0A %22 = \22stable_mosaic_gpu.llvm.mlir.constant\22() {value = 1 : i64} : () -> i64 loc(#loc)\0A %23 = \22stable_mosaic_gpu.llvm.insertvalue\22(%21, %22) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc)\0A %24 = \22stable_mosaic_gpu.builtin.unrealized_conversion_cast\22(%23) : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>) -> memref<128xf32> loc(#loc)\0A %25 = \22stable_mosaic_gpu.llvm.getelementptr\22(%arg1) {elem_type = !llvm.ptr, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc)\0A %26 = \22stable_mosaic_gpu.llvm.load\22(%25) {ordering = 0 : i64} : (!llvm.ptr) -> !llvm.ptr loc(#loc)\0A %27 = \22stable_mosaic_gpu.llvm.mlir.undef\22() : () -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc)\0A %28 = \22stable_mosaic_gpu.llvm.insertvalue\22(%27, %26) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc)\0A %29 = \22stable_mosaic_gpu.llvm.insertvalue\22(%28, %26) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc)\0A %30 = \22stable_mosaic_gpu.llvm.mlir.constant\22() {value = 0 : i64} : () -> i64 loc(#loc)\0A %31 = \22stable_mosaic_gpu.llvm.insertvalue\22(%29, %30) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc)\0A %32 = \22stable_mosaic_gpu.llvm.mlir.constant\22() {value = 128 : i64} : () -> i64 loc(#loc)\0A %33 = \22stable_mosaic_gpu.llvm.insertvalue\22(%31, %32) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc)\0A %34 = \22stable_mosaic_gpu.llvm.mlir.constant\22() {value = 1 : i64} : () -> i64 loc(#loc)\0A %35 = \22stable_mosaic_gpu.llvm.insertvalue\22(%33, %34) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc)\0A %36 = \22stable_mosaic_gpu.builtin.unrealized_conversion_cast\22(%35) : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>) -> memref<128xf32> loc(#loc)\0A %37 = \22stable_mosaic_gpu.arith.constant\22() {value = 2 : index} : () -> index loc(#loc)\0A %38 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : index} : () -> index loc(#loc)\0A %39 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : index} : () -> index loc(#loc)\0A %40 = \22stable_mosaic_gpu.arith.constant\22() {value = 128 : index} : () -> index loc(#loc)\0A %41 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : index} : () -> index loc(#loc)\0A %42 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : index} : () -> index loc(#loc)\0A %43 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : i32} : () -> i32 loc(#loc)\0A %44 = \22stable_mosaic_gpu.gpu.launch\22(%0, %37, %38, %39, %40, %41, %42, %43) ({\0A ^bb0(%arg2: index loc(\22-\22:50:40), %arg3: index loc(\22-\22:50:47), %arg4: index loc(\22-\22:50:54), %arg5: index loc(\22-\22:50:115), %arg6: index loc(\22-\22:50:122), %arg7: index loc(\22-\22:50:129), %arg8: index loc(\22-\22:50:65), %arg9: index loc(\22-\22:50:78), %arg10: index loc(\22-\22:50:91), %arg11: index loc(\22-\22:50:140), %arg12: index loc(\22-\22:50:156), %arg13: index loc(\22-\22:50:172)):\0A %45 = \22stable_mosaic_gpu.gpu.dynamic_shared_memory\22() : () -> memref> loc(#loc)\0A %46 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : index} : () -> index loc(#loc)\0A %47 = \22stable_mosaic_gpu.memref.view\22(%45, %46) : (memref>, index) -> memref<0xi8, #gpu.address_space> loc(#loc)\0A \22stable_mosaic_gpu.nvvm.fence.mbarrier.init\22() : () -> () loc(#loc)\0A \22stable_mosaic_gpu.gpu.barrier\22() : () -> () loc(#loc)\0A %48 = \22stable_mosaic_gpu.nvvm.elect.sync\22() : () -> i1 loc(#loc)\0A %49 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc)\0A %50 = \22stable_mosaic_gpu.arith.index_cast\22(%49) : (index) -> i32 loc(#loc)\0A %51 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc)\0A %52 = \22stable_mosaic_gpu.arith.index_cast\22(%51) : (index) -> i32 loc(#loc)\0A %53 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc)\0A %54 = \22stable_mosaic_gpu.arith.index_cast\22(%53) : (index) -> i32 loc(#loc)\0A %55 = \22stable_mosaic_gpu.arith.muli\22(%54, %52) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc)\0A %56 = \22stable_mosaic_gpu.arith.addi\22(%50, %55) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc)\0A %57 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc)\0A %58 = \22stable_mosaic_gpu.arith.index_cast\22(%57) : (index) -> i32 loc(#loc)\0A %59 = \22stable_mosaic_gpu.arith.muli\22(%52, %58) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc)\0A %60 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc)\0A %61 = \22stable_mosaic_gpu.arith.index_cast\22(%60) : (index) -> i32 loc(#loc)\0A %62 = \22stable_mosaic_gpu.arith.muli\22(%61, %59) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc)\0A %63 = \22stable_mosaic_gpu.arith.addi\22(%56, %62) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc)\0A %64 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc)\0A %65 = \22stable_mosaic_gpu.arith.index_cast\22(%64) : (index) -> i32 loc(#loc)\0A %66 = \22stable_mosaic_gpu.arith.muli\22(%59, %65) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc)\0A %67 = \22stable_mosaic_gpu.arith.constant\22() {value = 5 : i32} : () -> i32 loc(#loc)\0A %68 = \22stable_mosaic_gpu.arith.shrui\22(%63, %67) : (i32, i32) -> i32 loc(#loc)\0A %69 = \22stable_mosaic_gpu.arith.constant\22() {value = -1 : i32} : () -> i32 loc(#loc)\0A %70 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : i32} : () -> i32 loc(#loc)\0A %71 = \22stable_mosaic_gpu.arith.constant\22() {value = 31 : i32} : () -> i32 loc(#loc)\0A %72 = \22stable_mosaic_gpu.nvvm.shfl.sync\22(%69, %68, %70, %71) {kind = #nvvm} : (i32, i32, i32, i32) -> i32 loc(#loc)\0A %73 = \22stable_mosaic_gpu.arith.constant\22() {value = 4 : i32} : () -> i32 loc(#loc)\0A %74 = \22stable_mosaic_gpu.arith.remui\22(%72, %73) : (i32, i32) -> i32 loc(#loc)\0A %75 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : i32} : () -> i32 loc(#loc)\0A %76 = \22stable_mosaic_gpu.arith.cmpi\22(%74, %75) {predicate = 0 : i64} : (i32, i32) -> i1 loc(#loc)\0A %77 = \22stable_mosaic_gpu.arith.andi\22(%76, %48) : (i1, i1) -> i1 loc(#loc)\0A %78 = \22stable_mosaic_gpu.nvvm.elect.sync\22() : () -> i1 loc(#loc)\0A %79 = \22stable_mosaic_gpu.gpu.block_id\22() {dimension = #gpu} : () -> index loc(#loc)\0A %80 = \22stable_mosaic_gpu.arith.index_cast\22(%79) : (index) -> i32 loc(#loc)\0A %81 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : i32} : () -> i32 loc(#loc29)\0A %82 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : i32} : () -> i32 loc(#loc29)\0A %83 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i32} : () -> i32 loc(#loc29)\0A %84 = \22stable_mosaic_gpu.arith.muli\22(%82, %83) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc29)\0A %85 = \22stable_mosaic_gpu.arith.addi\22(%84, %81) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc29)\0A %86 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : i32} : () -> i32 loc(#loc29)\0A %87 = \22stable_mosaic_gpu.arith.addi\22(%85, %86) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc29)\0A %88 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i32} : () -> i32 loc(#loc30)\0A %89 = \22stable_mosaic_gpu.arith.remsi\22(%87, %88) : (i32, i32) -> i32 loc(#loc30)\0A %90 = \22stable_mosaic_gpu.gpu.block_id\22() {dimension = #gpu} : () -> index loc(#loc31)\0A %91 = \22stable_mosaic_gpu.arith.index_cast\22(%90) : (index) -> i32 loc(#loc31)\0A %92 = \22stable_mosaic_gpu.memref.subview\22(%12) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<128xf32>) -> memref<128xf32, strided<[1]>> loc(#loc32)\0A %93 = \22stable_mosaic_gpu.memref.collapse_shape\22(%92) {reassociation = [[0]]} : (memref<128xf32, strided<[1]>>) -> memref<128xf32> loc(#loc32)\0A %94 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc32)\0A %95 = \22stable_mosaic_gpu.arith.constant\22() {value = 128 : index} : () -> index loc(#loc32)\0A %96 = \22stable_mosaic_gpu.arith.remui\22(%94, %95) : (index, index) -> index loc(#loc32)\0A %97 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : index} : () -> index loc(#loc32)\0A %98 = \22stable_mosaic_gpu.arith.muli\22(%96, %97) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc32)\0A %99 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : index} : () -> index loc(#loc32)\0A %100 = \22stable_mosaic_gpu.arith.addi\22(%98, %99) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc32)\0A %101 = \22stable_mosaic_gpu.vector.load\22(%93, %100) : (memref<128xf32>, index) -> vector<1xf32> loc(#loc32)\0A %102 = \22stable_mosaic_gpu.arith.constant\22() {value = 1.000000e+00 : f32} : () -> f32 loc(#loc33)\0A %103 = \22stable_mosaic_gpu.vector.splat\22(%102) : (f32) -> vector<1xf32> loc(#loc33)\0A %104 = \22stable_mosaic_gpu.arith.addf\22(%101, %103) {fastmath = #arith.fastmath} : (vector<1xf32>, vector<1xf32>) -> vector<1xf32> loc(#loc33)\0A %105 = \22stable_mosaic_gpu.memref.subview\22(%24) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<128xf32>) -> memref<128xf32, strided<[1]>> loc(#loc34)\0A %106 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc34)\0A %107 = \22stable_mosaic_gpu.arith.index_cast\22(%106) : (index) -> i32 loc(#loc34)\0A %108 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc34)\0A %109 = \22stable_mosaic_gpu.arith.index_cast\22(%108) : (index) -> i32 loc(#loc34)\0A %110 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc34)\0A %111 = \22stable_mosaic_gpu.arith.index_cast\22(%110) : (index) -> i32 loc(#loc34)\0A %112 = \22stable_mosaic_gpu.arith.muli\22(%111, %109) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc34)\0A %113 = \22stable_mosaic_gpu.arith.addi\22(%107, %112) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc34)\0A %114 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc34)\0A %115 = \22stable_mosaic_gpu.arith.index_cast\22(%114) : (index) -> i32 loc(#loc34)\0A %116 = \22stable_mosaic_gpu.arith.muli\22(%109, %115) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc34)\0A %117 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc34)\0A %118 = \22stable_mosaic_gpu.arith.index_cast\22(%117) : (index) -> i32 loc(#loc34)\0A %119 = \22stable_mosaic_gpu.arith.muli\22(%118, %116) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc34)\0A %120 = \22stable_mosaic_gpu.arith.addi\22(%113, %119) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc34)\0A %121 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc34)\0A %122 = \22stable_mosaic_gpu.arith.index_cast\22(%121) : (index) -> i32 loc(#loc34)\0A %123 = \22stable_mosaic_gpu.arith.muli\22(%116, %122) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc34)\0A %124 = \22stable_mosaic_gpu.arith.constant\22() {value = 7 : i32} : () -> i32 loc(#loc34)\0A %125 = \22stable_mosaic_gpu.arith.shrui\22(%120, %124) : (i32, i32) -> i32 loc(#loc34)\0A %126 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i32} : () -> i32 loc(#loc34)\0A %127 = \22stable_mosaic_gpu.arith.addi\22(%125, %126) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc34)\0A %128 = \22stable_mosaic_gpu.llvm.inline_asm\22(%127) {asm_string = \22bar.sync $0, 128;\22, constraints = \22r\22, has_side_effects, tail_call_kind = #llvm.tailcallkind} : (i32) -> !llvm.void loc(#loc34)\0A %129 = \22stable_mosaic_gpu.memref.collapse_shape\22(%105) {reassociation = [[0]]} : (memref<128xf32, strided<[1]>>) -> memref<128xf32> loc(#loc34)\0A %130 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc34)\0A %131 = \22stable_mosaic_gpu.arith.constant\22() {value = 128 : index} : () -> index loc(#loc34)\0A %132 = \22stable_mosaic_gpu.arith.remui\22(%130, %131) : (index, index) -> index loc(#loc34)\0A %133 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : index} : () -> index loc(#loc34)\0A %134 = \22stable_mosaic_gpu.arith.muli\22(%132, %133) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc34)\0A %135 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : index} : () -> index loc(#loc34)\0A %136 = \22stable_mosaic_gpu.arith.addi\22(%134, %135) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc34)\0A %137 = \22stable_mosaic_gpu.vector.load\22(%129, %136) : (memref<128xf32>, index) -> vector<1xf32> loc(#loc34)\0A %138 = \22stable_mosaic_gpu.memref.collapse_shape\22(%105) {reassociation = [[0]]} : (memref<128xf32, strided<[1]>>) -> memref<128xf32> loc(#loc34)\0A %139 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc34)\0A %140 = \22stable_mosaic_gpu.arith.constant\22() {value = 128 : index} : () -> index loc(#loc34)\0A %141 = \22stable_mosaic_gpu.arith.remui\22(%139, %140) : (index, index) -> index loc(#loc34)\0A %142 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : index} : () -> index loc(#loc34)\0A %143 = \22stable_mosaic_gpu.arith.muli\22(%141, %142) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc34)\0A %144 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : index} : () -> index loc(#loc34)\0A %145 = \22stable_mosaic_gpu.arith.addi\22(%143, %144) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc34)\0A \22stable_mosaic_gpu.vector.store\22(%104, %138, %145) : (vector<1xf32>, memref<128xf32>, index) -> () loc(#loc34)\0A %146 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc34)\0A %147 = \22stable_mosaic_gpu.arith.index_cast\22(%146) : (index) -> i32 loc(#loc34)\0A %148 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc34)\0A %149 = \22stable_mosaic_gpu.arith.index_cast\22(%148) : (index) -> i32 loc(#loc34)\0A %150 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc34)\0A %151 = \22stable_mosaic_gpu.arith.index_cast\22(%150) : (index) -> i32 loc(#loc34)\0A %152 = \22stable_mosaic_gpu.arith.muli\22(%151, %149) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc34)\0A %153 = \22stable_mosaic_gpu.arith.addi\22(%147, %152) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc34)\0A %154 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc34)\0A %155 = \22stable_mosaic_gpu.arith.index_cast\22(%154) : (index) -> i32 loc(#loc34)\0A %156 = \22stable_mosaic_gpu.arith.muli\22(%149, %155) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc34)\0A %157 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc34)\0A %158 = \22stable_mosaic_gpu.arith.index_cast\22(%157) : (index) -> i32 loc(#loc34)\0A %159 = \22stable_mosaic_gpu.arith.muli\22(%158, %156) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc34)\0A %160 = \22stable_mosaic_gpu.arith.addi\22(%153, %159) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc34)\0A %161 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc34)\0A %162 = \22stable_mosaic_gpu.arith.index_cast\22(%161) : (index) -> i32 loc(#loc34)\0A %163 = \22stable_mosaic_gpu.arith.muli\22(%156, %162) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc34)\0A %164 = \22stable_mosaic_gpu.arith.constant\22() {value = 7 : i32} : () -> i32 loc(#loc34)\0A %165 = \22stable_mosaic_gpu.arith.shrui\22(%160, %164) : (i32, i32) -> i32 loc(#loc34)\0A %166 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i32} : () -> i32 loc(#loc34)\0A %167 = \22stable_mosaic_gpu.arith.addi\22(%165, %166) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc34)\0A %168 = \22stable_mosaic_gpu.llvm.inline_asm\22(%167) {asm_string = \22bar.sync $0, 128;\22, constraints = \22r\22, has_side_effects, tail_call_kind = #llvm.tailcallkind} : (i32) -> !llvm.void loc(#loc34)\0A %169 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i32} : () -> i32 loc(#loc35)\0A %170 = \22stable_mosaic_gpu.arith.addi\22(%87, %169) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc35)\0A %171 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i32} : () -> i32 loc(#loc30)\0A %172 = \22stable_mosaic_gpu.arith.remsi\22(%170, %171) : (i32, i32) -> i32 loc(#loc30)\0A %173 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : i32} : () -> i32 loc(#loc36)\0A %174 = \22stable_mosaic_gpu.arith.cmpi\22(%87, %173) {predicate = 5 : i64} : (i32, i32) -> i1 loc(#loc36)\0A %175 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i32} : () -> i32 loc(#loc37)\0A %176 = \22stable_mosaic_gpu.arith.cmpi\22(%170, %175) {predicate = 2 : i64} : (i32, i32) -> i1 loc(#loc37)\0A %177 = \22stable_mosaic_gpu.arith.andi\22(%174, %176) : (i1, i1) -> i1 loc(#loc38)\0A %178 = \22stable_mosaic_gpu.arith.extui\22(%177) : (i1) -> i32 loc(#loc39)\0A %179 = \22stable_mosaic_gpu.arith.index_cast\22(%178) : (i32) -> index loc(#loc40)\0A \22stable_mosaic_gpu.scf.index_switch\22(%179) ({\0A \22stable_mosaic_gpu.scf.yield\22() : () -> () loc(#loc16)\0A }, {\0A \22stable_mosaic_gpu.scf.yield\22() : () -> () loc(#loc40)\0A }) {cases = array} : (index) -> () loc(#loc40)\0A %180 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i32} : () -> i32 loc(#loc29)\0A %181 = \22stable_mosaic_gpu.arith.addi\22(%81, %180) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc29)\0A \22stable_mosaic_gpu.gpu.terminator\22() : () -> () loc(#loc)\0A }) {operandSegmentSizes = array, workgroup_attributions = 0 : i64} : (!gpu.async.token, index, index, index, index, index, index, i32) -> !gpu.async.token loc(#loc)\0A \22stable_mosaic_gpu.func.return\22() : () -> () loc(#loc)\0A }) {function_type = (!llvm.ptr, !llvm.ptr) -> (), llvm.emit_c_interface, sym_name = \22add_one_mosaic_gpu\22} : () -> () loc(#loc)\0A}) {stable_mosaic_gpu.version = 1 : i64} : () -> () loc(#loc)\0A#loc13 = loc(\22third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py\22:102:4)\0A#loc14 = loc(\22third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py\22:98:19)\0A#loc15 = loc(\22third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py\22:98:6)\0A#loc16 = loc(\22-\22:189:7)\0A#loc17 = loc(\22scan\22(#loc13))\0A#loc18 = loc(\22rem\22(#loc13))\0A#loc19 = loc(\22jaxpr_call\22(#loc13))\0A#loc20 = loc(\22get\22(#loc14))\0A#loc21 = loc(\22add\22(#loc14))\0A#loc22 = loc(\22swap\22(#loc15))\0A#loc23 = loc(\22add\22(#loc13))\0A#loc24 = loc(\22ge\22(#loc13))\0A#loc25 = loc(\22lt\22(#loc13))\0A#loc26 = loc(\22and\22(#loc13))\0A#loc27 = loc(\22convert_element_type\22(#loc13))\0A#loc28 = loc(\22cond\22(#loc13))\0A#loc29 = loc(\22scan:\22(#loc17))\0A#loc30 = loc(\22rem:\22(#loc18))\0A#loc31 = loc(\22jaxpr_call:\22(#loc19))\0A#loc32 = loc(\22get:\22(#loc20))\0A#loc33 = loc(\22add:\22(#loc21))\0A#loc34 = loc(\22swap:\22(#loc22))\0A#loc35 = loc(\22add:\22(#loc23))\0A#loc36 = loc(\22ge:\22(#loc24))\0A#loc37 = loc(\22lt:\22(#loc25))\0A#loc38 = loc(\22and:\22(#loc26))\0A#loc39 = loc(\22convert_element_type:\22(#loc27))\0A#loc40 = loc(\22cond:\22(#loc28))\0A", use_custom_barrier = false}, operand_layouts = [dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<0> : tensor<1xindex>]} : (tensor<128xf32>, tensor<128xf32>) -> tensor<128xf32> loc(#loc4) + return %1 : tensor<128xf32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc2 = loc("third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py":102:4) +#loc3 = loc("jit(wrapper)"(#loc2)) +#loc4 = loc(""(#loc3)) +""", + mlir_module_serialized=b'ML\xefR\rStableHLO_v1.12.1\x00\x01\x19\x05\x01\x05\t\x01\x03\x0b\x03\x07\x0f\x13\x17\x03{Y\x0f\x01%\x07\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0b\x0f\x0b\x13\x0b\x13\x0b\x035\x0b\x0b/\x0b\x0b\x0f\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0b#\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0f\x13\x0f\x01\x05\x0b\x0f\x03\x0b\x17\x17\x07\x13\x07\x02R\x02\x1f\x11\x03\x05\x1d\x17\x19\x03\x07\t\x0b\r\x03\x0f\x03\x05\r\x11\x01\x00\x05\x0f\x05\x11\x05\x13\x1d\x15\x01\x05\x15\x05\x17\x1d\x1b\x1d\x05\x19\x17\x1f\xcd\t\x05\x1b\x03\x03#C\x05\x1d\x03\x01\x05\x01\x1f\x0b\x11\x00\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1d\x17\x03\x031\r\x01#\x07\x03\x037\r\x039;\x1d\x1f\x1d!\x1d#\x1d%\x1d\'\r\x07EGIKM\'\x1d)\x1d+\x1d\x05\x1d-\x1d/\x1d1\x03\x05))\x03\x03U\x15\x01\x05\x01\x03\x03)\x01\t\x01\x02\x02)\x03\x02\x04\t\x11\x03\x05\x03\x05\t)\x03\x05\r\x13\x04_\x05\x01Q\x01\x07\x01\x07\x04M\x03\x01\x05\x05P\x01\x03\x07\x049\x03\x07\x0f\x03\x0b\x13\x00\x03B\x05\x05\x03\x05\x03G\x05!\x07\x03\x05\x05\x01\x03\x07\x04\x01\x03\x05\x06\x03\x01\x05\x01\x00\x0c\x14\x033\x1d\'\x94\x07\x03C\x19\x1f\x0f\x0b\x0f!)\x85\x1b\x03\x19\x19%)9\x15\x11\x1f\x0f\x0b\x11builtin\x00vhlo\x00module\x00custom_call_v1\x00func_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_wrapper\x00operands[0]\x00\x00jit(wrapper)\x00third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py\x00mhlo.backend_config\x00jax.result_info\x00result\x00main\x00public\x00AllocateBuffer\x00kernel_hash\x00\xa4,O\x92\x88\xfb;\xfc\xdfyB@\x19\xbb)\xbc\xef\xd2\x8asl\xb2N\x10\x00\xb4\x91J\xa9\xfa\x1b\xed\x00#loc = loc("pallas_call")\n#loc1 = loc("-":50:40)\n#loc2 = loc("-":50:47)\n#loc3 = loc("-":50:54)\n#loc4 = loc("-":50:115)\n#loc5 = loc("-":50:122)\n#loc6 = loc("-":50:129)\n#loc7 = loc("-":50:65)\n#loc8 = loc("-":50:78)\n#loc9 = loc("-":50:91)\n#loc10 = loc("-":50:140)\n#loc11 = loc("-":50:156)\n#loc12 = loc("-":50:172)\n"builtin.module"() <{sym_name = "add_one"}> ({\n "stable_mosaic_gpu.func.func"() ({\n }) {function_type = (!llvm.ptr, !llvm.ptr, i64, i64, !llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> (), sym_name = "mosaic_gpu_init_tma_desc", sym_visibility = "private"} : () -> () loc(#loc)\n "stable_mosaic_gpu.llvm.mlir.global"() ({\n }) {addr_space = 4 : i32, global_type = !llvm.array<0 x i8>, linkage = #llvm.linkage, sym_name = "global_scratch", unnamed_addr = 0 : i64, visibility_ = 0 : i64} : () -> () loc(#loc)\n "stable_mosaic_gpu.func.func"() ({\n ^bb0(%arg0: !llvm.ptr loc("pallas_call"), %arg1: !llvm.ptr loc("pallas_call")):\n %0 = "stable_mosaic_gpu.builtin.unrealized_conversion_cast"(%arg0) : (!llvm.ptr) -> !gpu.async.token loc(#loc)\n %1 = "stable_mosaic_gpu.llvm.getelementptr"(%arg1) {elem_type = !llvm.ptr, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc)\n %2 = "stable_mosaic_gpu.llvm.load"(%1) {ordering = 0 : i64} : (!llvm.ptr) -> !llvm.ptr loc(#loc)\n %3 = "stable_mosaic_gpu.llvm.mlir.undef"() : () -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc)\n %4 = "stable_mosaic_gpu.llvm.insertvalue"(%3, %2) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc)\n %5 = "stable_mosaic_gpu.llvm.insertvalue"(%4, %2) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc)\n %6 = "stable_mosaic_gpu.llvm.mlir.constant"() {value = 0 : i64} : () -> i64 loc(#loc)\n %7 = "stable_mosaic_gpu.llvm.insertvalue"(%5, %6) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc)\n %8 = "stable_mosaic_gpu.llvm.mlir.constant"() {value = 128 : i64} : () -> i64 loc(#loc)\n %9 = "stable_mosaic_gpu.llvm.insertvalue"(%7, %8) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc)\n %10 = "stable_mosaic_gpu.llvm.mlir.constant"() {value = 1 : i64} : () -> i64 loc(#loc)\n %11 = "stable_mosaic_gpu.llvm.insertvalue"(%9, %10) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc)\n %12 = "stable_mosaic_gpu.builtin.unrealized_conversion_cast"(%11) : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>) -> memref<128xf32> loc(#loc)\n %13 = "stable_mosaic_gpu.llvm.getelementptr"(%arg1) {elem_type = !llvm.ptr, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc)\n %14 = "stable_mosaic_gpu.llvm.load"(%13) {ordering = 0 : i64} : (!llvm.ptr) -> !llvm.ptr loc(#loc)\n %15 = "stable_mosaic_gpu.llvm.mlir.undef"() : () -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc)\n %16 = "stable_mosaic_gpu.llvm.insertvalue"(%15, %14) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc)\n %17 = "stable_mosaic_gpu.llvm.insertvalue"(%16, %14) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc)\n %18 = "stable_mosaic_gpu.llvm.mlir.constant"() {value = 0 : i64} : () -> i64 loc(#loc)\n %19 = "stable_mosaic_gpu.llvm.insertvalue"(%17, %18) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc)\n %20 = "stable_mosaic_gpu.llvm.mlir.constant"() {value = 128 : i64} : () -> i64 loc(#loc)\n %21 = "stable_mosaic_gpu.llvm.insertvalue"(%19, %20) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc)\n %22 = "stable_mosaic_gpu.llvm.mlir.constant"() {value = 1 : i64} : () -> i64 loc(#loc)\n %23 = "stable_mosaic_gpu.llvm.insertvalue"(%21, %22) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc)\n %24 = "stable_mosaic_gpu.builtin.unrealized_conversion_cast"(%23) : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>) -> memref<128xf32> loc(#loc)\n %25 = "stable_mosaic_gpu.llvm.getelementptr"(%arg1) {elem_type = !llvm.ptr, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc)\n %26 = "stable_mosaic_gpu.llvm.load"(%25) {ordering = 0 : i64} : (!llvm.ptr) -> !llvm.ptr loc(#loc)\n %27 = "stable_mosaic_gpu.llvm.mlir.undef"() : () -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc)\n %28 = "stable_mosaic_gpu.llvm.insertvalue"(%27, %26) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc)\n %29 = "stable_mosaic_gpu.llvm.insertvalue"(%28, %26) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc)\n %30 = "stable_mosaic_gpu.llvm.mlir.constant"() {value = 0 : i64} : () -> i64 loc(#loc)\n %31 = "stable_mosaic_gpu.llvm.insertvalue"(%29, %30) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc)\n %32 = "stable_mosaic_gpu.llvm.mlir.constant"() {value = 128 : i64} : () -> i64 loc(#loc)\n %33 = "stable_mosaic_gpu.llvm.insertvalue"(%31, %32) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc)\n %34 = "stable_mosaic_gpu.llvm.mlir.constant"() {value = 1 : i64} : () -> i64 loc(#loc)\n %35 = "stable_mosaic_gpu.llvm.insertvalue"(%33, %34) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc)\n %36 = "stable_mosaic_gpu.builtin.unrealized_conversion_cast"(%35) : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>) -> memref<128xf32> loc(#loc)\n %37 = "stable_mosaic_gpu.arith.constant"() {value = 2 : index} : () -> index loc(#loc)\n %38 = "stable_mosaic_gpu.arith.constant"() {value = 1 : index} : () -> index loc(#loc)\n %39 = "stable_mosaic_gpu.arith.constant"() {value = 1 : index} : () -> index loc(#loc)\n %40 = "stable_mosaic_gpu.arith.constant"() {value = 128 : index} : () -> index loc(#loc)\n %41 = "stable_mosaic_gpu.arith.constant"() {value = 1 : index} : () -> index loc(#loc)\n %42 = "stable_mosaic_gpu.arith.constant"() {value = 1 : index} : () -> index loc(#loc)\n %43 = "stable_mosaic_gpu.arith.constant"() {value = 0 : i32} : () -> i32 loc(#loc)\n %44 = "stable_mosaic_gpu.gpu.launch"(%0, %37, %38, %39, %40, %41, %42, %43) ({\n ^bb0(%arg2: index loc("-":50:40), %arg3: index loc("-":50:47), %arg4: index loc("-":50:54), %arg5: index loc("-":50:115), %arg6: index loc("-":50:122), %arg7: index loc("-":50:129), %arg8: index loc("-":50:65), %arg9: index loc("-":50:78), %arg10: index loc("-":50:91), %arg11: index loc("-":50:140), %arg12: index loc("-":50:156), %arg13: index loc("-":50:172)):\n %45 = "stable_mosaic_gpu.gpu.dynamic_shared_memory"() : () -> memref> loc(#loc)\n %46 = "stable_mosaic_gpu.arith.constant"() {value = 0 : index} : () -> index loc(#loc)\n %47 = "stable_mosaic_gpu.memref.view"(%45, %46) : (memref>, index) -> memref<0xi8, #gpu.address_space> loc(#loc)\n "stable_mosaic_gpu.nvvm.fence.mbarrier.init"() : () -> () loc(#loc)\n "stable_mosaic_gpu.gpu.barrier"() : () -> () loc(#loc)\n %48 = "stable_mosaic_gpu.nvvm.elect.sync"() : () -> i1 loc(#loc)\n %49 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc)\n %50 = "stable_mosaic_gpu.arith.index_cast"(%49) : (index) -> i32 loc(#loc)\n %51 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc)\n %52 = "stable_mosaic_gpu.arith.index_cast"(%51) : (index) -> i32 loc(#loc)\n %53 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc)\n %54 = "stable_mosaic_gpu.arith.index_cast"(%53) : (index) -> i32 loc(#loc)\n %55 = "stable_mosaic_gpu.arith.muli"(%54, %52) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc)\n %56 = "stable_mosaic_gpu.arith.addi"(%50, %55) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc)\n %57 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc)\n %58 = "stable_mosaic_gpu.arith.index_cast"(%57) : (index) -> i32 loc(#loc)\n %59 = "stable_mosaic_gpu.arith.muli"(%52, %58) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc)\n %60 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc)\n %61 = "stable_mosaic_gpu.arith.index_cast"(%60) : (index) -> i32 loc(#loc)\n %62 = "stable_mosaic_gpu.arith.muli"(%61, %59) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc)\n %63 = "stable_mosaic_gpu.arith.addi"(%56, %62) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc)\n %64 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc)\n %65 = "stable_mosaic_gpu.arith.index_cast"(%64) : (index) -> i32 loc(#loc)\n %66 = "stable_mosaic_gpu.arith.muli"(%59, %65) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc)\n %67 = "stable_mosaic_gpu.arith.constant"() {value = 5 : i32} : () -> i32 loc(#loc)\n %68 = "stable_mosaic_gpu.arith.shrui"(%63, %67) : (i32, i32) -> i32 loc(#loc)\n %69 = "stable_mosaic_gpu.arith.constant"() {value = -1 : i32} : () -> i32 loc(#loc)\n %70 = "stable_mosaic_gpu.arith.constant"() {value = 0 : i32} : () -> i32 loc(#loc)\n %71 = "stable_mosaic_gpu.arith.constant"() {value = 31 : i32} : () -> i32 loc(#loc)\n %72 = "stable_mosaic_gpu.nvvm.shfl.sync"(%69, %68, %70, %71) {kind = #nvvm} : (i32, i32, i32, i32) -> i32 loc(#loc)\n %73 = "stable_mosaic_gpu.arith.constant"() {value = 4 : i32} : () -> i32 loc(#loc)\n %74 = "stable_mosaic_gpu.arith.remui"(%72, %73) : (i32, i32) -> i32 loc(#loc)\n %75 = "stable_mosaic_gpu.arith.constant"() {value = 0 : i32} : () -> i32 loc(#loc)\n %76 = "stable_mosaic_gpu.arith.cmpi"(%74, %75) {predicate = 0 : i64} : (i32, i32) -> i1 loc(#loc)\n %77 = "stable_mosaic_gpu.arith.andi"(%76, %48) : (i1, i1) -> i1 loc(#loc)\n %78 = "stable_mosaic_gpu.nvvm.elect.sync"() : () -> i1 loc(#loc)\n %79 = "stable_mosaic_gpu.gpu.block_id"() {dimension = #gpu} : () -> index loc(#loc)\n %80 = "stable_mosaic_gpu.arith.index_cast"(%79) : (index) -> i32 loc(#loc)\n %81 = "stable_mosaic_gpu.arith.constant"() {value = 0 : i32} : () -> i32 loc(#loc29)\n %82 = "stable_mosaic_gpu.arith.constant"() {value = 0 : i32} : () -> i32 loc(#loc29)\n %83 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i32} : () -> i32 loc(#loc29)\n %84 = "stable_mosaic_gpu.arith.muli"(%82, %83) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc29)\n %85 = "stable_mosaic_gpu.arith.addi"(%84, %81) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc29)\n %86 = "stable_mosaic_gpu.arith.constant"() {value = 0 : i32} : () -> i32 loc(#loc29)\n %87 = "stable_mosaic_gpu.arith.addi"(%85, %86) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc29)\n %88 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i32} : () -> i32 loc(#loc30)\n %89 = "stable_mosaic_gpu.arith.remsi"(%87, %88) : (i32, i32) -> i32 loc(#loc30)\n %90 = "stable_mosaic_gpu.gpu.block_id"() {dimension = #gpu} : () -> index loc(#loc31)\n %91 = "stable_mosaic_gpu.arith.index_cast"(%90) : (index) -> i32 loc(#loc31)\n %92 = "stable_mosaic_gpu.memref.subview"(%12) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<128xf32>) -> memref<128xf32, strided<[1]>> loc(#loc32)\n %93 = "stable_mosaic_gpu.memref.collapse_shape"(%92) {reassociation = [[0]]} : (memref<128xf32, strided<[1]>>) -> memref<128xf32> loc(#loc32)\n %94 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc32)\n %95 = "stable_mosaic_gpu.arith.constant"() {value = 128 : index} : () -> index loc(#loc32)\n %96 = "stable_mosaic_gpu.arith.remui"(%94, %95) : (index, index) -> index loc(#loc32)\n %97 = "stable_mosaic_gpu.arith.constant"() {value = 1 : index} : () -> index loc(#loc32)\n %98 = "stable_mosaic_gpu.arith.muli"(%96, %97) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc32)\n %99 = "stable_mosaic_gpu.arith.constant"() {value = 0 : index} : () -> index loc(#loc32)\n %100 = "stable_mosaic_gpu.arith.addi"(%98, %99) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc32)\n %101 = "stable_mosaic_gpu.vector.load"(%93, %100) : (memref<128xf32>, index) -> vector<1xf32> loc(#loc32)\n %102 = "stable_mosaic_gpu.arith.constant"() {value = 1.000000e+00 : f32} : () -> f32 loc(#loc33)\n %103 = "stable_mosaic_gpu.vector.splat"(%102) : (f32) -> vector<1xf32> loc(#loc33)\n %104 = "stable_mosaic_gpu.arith.addf"(%101, %103) {fastmath = #arith.fastmath} : (vector<1xf32>, vector<1xf32>) -> vector<1xf32> loc(#loc33)\n %105 = "stable_mosaic_gpu.memref.subview"(%24) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<128xf32>) -> memref<128xf32, strided<[1]>> loc(#loc34)\n %106 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc34)\n %107 = "stable_mosaic_gpu.arith.index_cast"(%106) : (index) -> i32 loc(#loc34)\n %108 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc34)\n %109 = "stable_mosaic_gpu.arith.index_cast"(%108) : (index) -> i32 loc(#loc34)\n %110 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc34)\n %111 = "stable_mosaic_gpu.arith.index_cast"(%110) : (index) -> i32 loc(#loc34)\n %112 = "stable_mosaic_gpu.arith.muli"(%111, %109) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc34)\n %113 = "stable_mosaic_gpu.arith.addi"(%107, %112) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc34)\n %114 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc34)\n %115 = "stable_mosaic_gpu.arith.index_cast"(%114) : (index) -> i32 loc(#loc34)\n %116 = "stable_mosaic_gpu.arith.muli"(%109, %115) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc34)\n %117 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc34)\n %118 = "stable_mosaic_gpu.arith.index_cast"(%117) : (index) -> i32 loc(#loc34)\n %119 = "stable_mosaic_gpu.arith.muli"(%118, %116) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc34)\n %120 = "stable_mosaic_gpu.arith.addi"(%113, %119) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc34)\n %121 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc34)\n %122 = "stable_mosaic_gpu.arith.index_cast"(%121) : (index) -> i32 loc(#loc34)\n %123 = "stable_mosaic_gpu.arith.muli"(%116, %122) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc34)\n %124 = "stable_mosaic_gpu.arith.constant"() {value = 7 : i32} : () -> i32 loc(#loc34)\n %125 = "stable_mosaic_gpu.arith.shrui"(%120, %124) : (i32, i32) -> i32 loc(#loc34)\n %126 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i32} : () -> i32 loc(#loc34)\n %127 = "stable_mosaic_gpu.arith.addi"(%125, %126) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc34)\n %128 = "stable_mosaic_gpu.llvm.inline_asm"(%127) {asm_string = "bar.sync $0, 128;", constraints = "r", has_side_effects, tail_call_kind = #llvm.tailcallkind} : (i32) -> !llvm.void loc(#loc34)\n %129 = "stable_mosaic_gpu.memref.collapse_shape"(%105) {reassociation = [[0]]} : (memref<128xf32, strided<[1]>>) -> memref<128xf32> loc(#loc34)\n %130 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc34)\n %131 = "stable_mosaic_gpu.arith.constant"() {value = 128 : index} : () -> index loc(#loc34)\n %132 = "stable_mosaic_gpu.arith.remui"(%130, %131) : (index, index) -> index loc(#loc34)\n %133 = "stable_mosaic_gpu.arith.constant"() {value = 1 : index} : () -> index loc(#loc34)\n %134 = "stable_mosaic_gpu.arith.muli"(%132, %133) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc34)\n %135 = "stable_mosaic_gpu.arith.constant"() {value = 0 : index} : () -> index loc(#loc34)\n %136 = "stable_mosaic_gpu.arith.addi"(%134, %135) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc34)\n %137 = "stable_mosaic_gpu.vector.load"(%129, %136) : (memref<128xf32>, index) -> vector<1xf32> loc(#loc34)\n %138 = "stable_mosaic_gpu.memref.collapse_shape"(%105) {reassociation = [[0]]} : (memref<128xf32, strided<[1]>>) -> memref<128xf32> loc(#loc34)\n %139 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc34)\n %140 = "stable_mosaic_gpu.arith.constant"() {value = 128 : index} : () -> index loc(#loc34)\n %141 = "stable_mosaic_gpu.arith.remui"(%139, %140) : (index, index) -> index loc(#loc34)\n %142 = "stable_mosaic_gpu.arith.constant"() {value = 1 : index} : () -> index loc(#loc34)\n %143 = "stable_mosaic_gpu.arith.muli"(%141, %142) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc34)\n %144 = "stable_mosaic_gpu.arith.constant"() {value = 0 : index} : () -> index loc(#loc34)\n %145 = "stable_mosaic_gpu.arith.addi"(%143, %144) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc34)\n "stable_mosaic_gpu.vector.store"(%104, %138, %145) : (vector<1xf32>, memref<128xf32>, index) -> () loc(#loc34)\n %146 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc34)\n %147 = "stable_mosaic_gpu.arith.index_cast"(%146) : (index) -> i32 loc(#loc34)\n %148 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc34)\n %149 = "stable_mosaic_gpu.arith.index_cast"(%148) : (index) -> i32 loc(#loc34)\n %150 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc34)\n %151 = "stable_mosaic_gpu.arith.index_cast"(%150) : (index) -> i32 loc(#loc34)\n %152 = "stable_mosaic_gpu.arith.muli"(%151, %149) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc34)\n %153 = "stable_mosaic_gpu.arith.addi"(%147, %152) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc34)\n %154 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc34)\n %155 = "stable_mosaic_gpu.arith.index_cast"(%154) : (index) -> i32 loc(#loc34)\n %156 = "stable_mosaic_gpu.arith.muli"(%149, %155) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc34)\n %157 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc34)\n %158 = "stable_mosaic_gpu.arith.index_cast"(%157) : (index) -> i32 loc(#loc34)\n %159 = "stable_mosaic_gpu.arith.muli"(%158, %156) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc34)\n %160 = "stable_mosaic_gpu.arith.addi"(%153, %159) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc34)\n %161 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc34)\n %162 = "stable_mosaic_gpu.arith.index_cast"(%161) : (index) -> i32 loc(#loc34)\n %163 = "stable_mosaic_gpu.arith.muli"(%156, %162) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc34)\n %164 = "stable_mosaic_gpu.arith.constant"() {value = 7 : i32} : () -> i32 loc(#loc34)\n %165 = "stable_mosaic_gpu.arith.shrui"(%160, %164) : (i32, i32) -> i32 loc(#loc34)\n %166 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i32} : () -> i32 loc(#loc34)\n %167 = "stable_mosaic_gpu.arith.addi"(%165, %166) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc34)\n %168 = "stable_mosaic_gpu.llvm.inline_asm"(%167) {asm_string = "bar.sync $0, 128;", constraints = "r", has_side_effects, tail_call_kind = #llvm.tailcallkind} : (i32) -> !llvm.void loc(#loc34)\n %169 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i32} : () -> i32 loc(#loc35)\n %170 = "stable_mosaic_gpu.arith.addi"(%87, %169) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc35)\n %171 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i32} : () -> i32 loc(#loc30)\n %172 = "stable_mosaic_gpu.arith.remsi"(%170, %171) : (i32, i32) -> i32 loc(#loc30)\n %173 = "stable_mosaic_gpu.arith.constant"() {value = 0 : i32} : () -> i32 loc(#loc36)\n %174 = "stable_mosaic_gpu.arith.cmpi"(%87, %173) {predicate = 5 : i64} : (i32, i32) -> i1 loc(#loc36)\n %175 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i32} : () -> i32 loc(#loc37)\n %176 = "stable_mosaic_gpu.arith.cmpi"(%170, %175) {predicate = 2 : i64} : (i32, i32) -> i1 loc(#loc37)\n %177 = "stable_mosaic_gpu.arith.andi"(%174, %176) : (i1, i1) -> i1 loc(#loc38)\n %178 = "stable_mosaic_gpu.arith.extui"(%177) : (i1) -> i32 loc(#loc39)\n %179 = "stable_mosaic_gpu.arith.index_cast"(%178) : (i32) -> index loc(#loc40)\n "stable_mosaic_gpu.scf.index_switch"(%179) ({\n "stable_mosaic_gpu.scf.yield"() : () -> () loc(#loc16)\n }, {\n "stable_mosaic_gpu.scf.yield"() : () -> () loc(#loc40)\n }) {cases = array} : (index) -> () loc(#loc40)\n %180 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i32} : () -> i32 loc(#loc29)\n %181 = "stable_mosaic_gpu.arith.addi"(%81, %180) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc29)\n "stable_mosaic_gpu.gpu.terminator"() : () -> () loc(#loc)\n }) {operandSegmentSizes = array, workgroup_attributions = 0 : i64} : (!gpu.async.token, index, index, index, index, index, index, i32) -> !gpu.async.token loc(#loc)\n "stable_mosaic_gpu.func.return"() : () -> () loc(#loc)\n }) {function_type = (!llvm.ptr, !llvm.ptr) -> (), llvm.emit_c_interface, sym_name = "add_one_mosaic_gpu"} : () -> () loc(#loc)\n}) {stable_mosaic_gpu.version = 1 : i64} : () -> () loc(#loc)\n#loc13 = loc("third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py":102:4)\n#loc14 = loc("third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py":98:19)\n#loc15 = loc("third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py":98:6)\n#loc16 = loc("-":189:7)\n#loc17 = loc("scan"(#loc13))\n#loc18 = loc("rem"(#loc13))\n#loc19 = loc("jaxpr_call"(#loc13))\n#loc20 = loc("get"(#loc14))\n#loc21 = loc("add"(#loc14))\n#loc22 = loc("swap"(#loc15))\n#loc23 = loc("add"(#loc13))\n#loc24 = loc("ge"(#loc13))\n#loc25 = loc("lt"(#loc13))\n#loc26 = loc("and"(#loc13))\n#loc27 = loc("convert_element_type"(#loc13))\n#loc28 = loc("cond"(#loc13))\n#loc29 = loc("scan:"(#loc17))\n#loc30 = loc("rem:"(#loc18))\n#loc31 = loc("jaxpr_call:"(#loc19))\n#loc32 = loc("get:"(#loc20))\n#loc33 = loc("add:"(#loc21))\n#loc34 = loc("swap:"(#loc22))\n#loc35 = loc("add:"(#loc23))\n#loc36 = loc("ge:"(#loc24))\n#loc37 = loc("lt:"(#loc25))\n#loc38 = loc("and:"(#loc26))\n#loc39 = loc("convert_element_type:"(#loc27))\n#loc40 = loc("cond:"(#loc28))\n\x00use_custom_barrier\x00mosaic_gpu_v2\x00\x089\t\x05#\x01\x0b/35=?\x11+-A%\'%%%\x11+-O%\'QSW', + xla_call_module_version=10, + nr_devices=1, +) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/pallas/mosaic_matmul.py b/jax/_src/internal_test_util/export_back_compat_test_data/pallas/mosaic_matmul.py index 2c94cb777b46..9762ce8aa7c1 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/pallas/mosaic_matmul.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/pallas/mosaic_matmul.py @@ -13,7 +13,10 @@ # limitations under the License. import datetime -from numpy import array, float32 +import numpy as np + +array = np.array +float32 = np.float32 # Pasted from the test output (see export_back_compat_test_util.py module docstring) diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/pallas/mosaic_semaphore_dma.py b/jax/_src/internal_test_util/export_back_compat_test_data/pallas/mosaic_semaphore_dma.py index a44e92846b98..deac893c8f37 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/pallas/mosaic_semaphore_dma.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/pallas/mosaic_semaphore_dma.py @@ -13,7 +13,11 @@ # limitations under the License. import datetime -from numpy import array, float32 +import numpy as np + +array = np.array +float32 = np.float32 + # Pasted from the test output (see export_back_compat_test_util.py module docstring) semaphore_and_dma_2024_04_22 = dict( diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/pallas/triton_add_one.py b/jax/_src/internal_test_util/export_back_compat_test_data/pallas/triton_add_one.py index 4a1bbe63b9be..ca20b459d303 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/pallas/triton_add_one.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/pallas/triton_add_one.py @@ -13,7 +13,10 @@ # limitations under the License. import datetime -from numpy import array, float32 +import numpy as np + +array = np.array +float32 = np.float32 # Pasted from the test output (see export_back_compat_test_util.py module docstring) diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/rocm_eigh_hipsolver_syev.py b/jax/_src/internal_test_util/export_back_compat_test_data/rocm_eigh_hipsolver_syev.py index 299fd8ca33f6..7228e3ecbd83 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/rocm_eigh_hipsolver_syev.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/rocm_eigh_hipsolver_syev.py @@ -15,7 +15,10 @@ # ruff: noqa import datetime -from numpy import array, float32 +import numpy as np + +array = np.array +float32 = np.float32 data_2024_08_05 = {} diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/rocm_qr_hipsolver_geqrf.py b/jax/_src/internal_test_util/export_back_compat_test_data/rocm_qr_hipsolver_geqrf.py deleted file mode 100644 index bd5fa628741e..000000000000 --- a/jax/_src/internal_test_util/export_back_compat_test_data/rocm_qr_hipsolver_geqrf.py +++ /dev/null @@ -1,176 +0,0 @@ -# Copyright 2023 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import datetime -from numpy import array, float32 - -data_2024_08_05 = {} - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_05["unbatched"] = dict( - testdata_version=1, - platform='rocm', - custom_call_targets=['hipsolver_geqrf', 'hipsolver_orgqr'], - serialized_date=datetime.date(2024, 8, 5), - inputs=(), - expected_outputs=(array([[ 0. , 0.9128709 , 0.40824834], - [-0.4472136 , 0.3651484 , -0.81649655], - [-0.8944272 , -0.18257423, 0.40824828]], dtype=float32), array([[-6.7082038e+00, -8.0498447e+00, -9.3914852e+00], - [ 0.0000000e+00, 1.0954450e+00, 2.1908898e+00], - [ 0.0000000e+00, 0.0000000e+00, 1.6371473e-09]], dtype=float32)), - mlir_module_text=r""" -#loc2 = loc("/release/jax/tests/export_back_compat_test.py":346:0) -#loc9 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc2)) -module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<3x3xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<3x3xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { - %0 = stablehlo.iota dim = 0 : tensor<9xf32> loc(#loc3) - %1 = stablehlo.reshape %0 : (tensor<9xf32>) -> tensor<3x3xf32> loc(#loc4) - %2:4 = stablehlo.custom_call @hipsolver_geqrf(%1) {api_version = 2 : i32, backend_config = "\00\00\00\00\01\00\00\00\03\00\00\00\03\00\00\00\00\01\00\00", operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xf32>) -> (tensor<3x3xf32>, tensor<3xf32>, tensor, tensor<256xf32>) loc(#loc5) - %c = stablehlo.constant dense<0> : tensor loc(#loc5) - %3 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor loc(#loc5) - %4 = stablehlo.compare EQ, %2#2, %3, SIGNED : (tensor, tensor) -> tensor loc(#loc5) - %5 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %cst = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc5) - %6 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<3x3xf32> loc(#loc5) - %7 = stablehlo.broadcast_in_dim %5, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc5) - %8 = stablehlo.select %7, %2#0, %6 : tensor<3x3xi1>, tensor<3x3xf32> loc(#loc5) - %9 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<1xi1> loc(#loc5) - %cst_0 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc5) - %10 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor) -> tensor<3xf32> loc(#loc5) - %11 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<1xi1>) -> tensor<3xi1> loc(#loc5) - %12 = stablehlo.select %11, %2#1, %10 : tensor<3xi1>, tensor<3xf32> loc(#loc5) - %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc6) - %13 = stablehlo.pad %8, %cst_1, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xf32>, tensor) -> tensor<3x3xf32> loc(#loc7) - %14:3 = stablehlo.custom_call @hipsolver_orgqr(%13, %12) {api_version = 2 : i32, backend_config = "\00\00\00\00\01\00\00\00\03\00\00\00\03\00\00\00\03\00\00\00\80\00\00\00", operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xf32>, tensor<3xf32>) -> (tensor<3x3xf32>, tensor, tensor<128xf32>) loc(#loc8) - %c_2 = stablehlo.constant dense<0> : tensor loc(#loc8) - %15 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor loc(#loc8) - %16 = stablehlo.compare EQ, %14#1, %15, SIGNED : (tensor, tensor) -> tensor loc(#loc8) - %17 = stablehlo.broadcast_in_dim %16, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc8) - %cst_3 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc8) - %18 = stablehlo.broadcast_in_dim %cst_3, dims = [] : (tensor) -> tensor<3x3xf32> loc(#loc8) - %19 = stablehlo.broadcast_in_dim %17, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc8) - %20 = stablehlo.select %19, %14#0, %18 : tensor<3x3xi1>, tensor<3x3xf32> loc(#loc8) - %21 = call @triu(%8) : (tensor<3x3xf32>) -> tensor<3x3xf32> loc(#loc9) - return %20, %21 : tensor<3x3xf32>, tensor<3x3xf32> loc(#loc) - } loc(#loc) - func.func private @triu(%arg0: tensor<3x3xf32> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc2))) -> (tensor<3x3xf32> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc10) - %c = stablehlo.constant dense<-1> : tensor loc(#loc9) - %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc11) - %2 = stablehlo.add %0, %1 : tensor<3x3xi32> loc(#loc11) - %3 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc12) - %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc13) - %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc9) - %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<3x3xf32> loc(#loc14) - %6 = stablehlo.select %4, %5, %arg0 : tensor<3x3xi1>, tensor<3x3xf32> loc(#loc15) - return %6 : tensor<3x3xf32> loc(#loc9) - } loc(#loc9) -} loc(#loc) -#loc = loc(unknown) -#loc1 = loc("/release/jax/tests/export_back_compat_test.py":345:0) -#loc3 = loc("jit()/jit(main)/iota[dtype=float32 shape=(9,) dimension=0]"(#loc1)) -#loc4 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc1)) -#loc5 = loc("jit()/jit(main)/geqrf"(#loc2)) -#loc6 = loc("jit()/jit(main)/qr[full_matrices=True]"(#loc2)) -#loc7 = loc("jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]"(#loc2)) -#loc8 = loc("jit()/jit(main)/householder_product"(#loc2)) -#loc10 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]"(#loc2)) -#loc11 = loc("jit()/jit(main)/jit(triu)/add"(#loc2)) -#loc12 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]"(#loc2)) -#loc13 = loc("jit()/jit(main)/jit(triu)/ge"(#loc2)) -#loc14 = loc("jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]"(#loc2)) -#loc15 = loc("jit()/jit(main)/jit(triu)/select_n"(#loc2)) -""", - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01)\x05\x01\x03\x01\x03\x05\x03\x19\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x03~\x02\xf39\x01\x99\x0f\x17\x13\x0f\x0f\x0b\x0b\x07\x0b\x13\x0f\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x13\x17\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x13+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0bK\x0b\x13\x0f\x0b#\x0b\x0b\x0b\x0f\x0bK\x0b\x13\x0b\x03[O/\x0b\x0b\x0b/\x0b\x0f\x0b\x0b\x0b\x0b\x0f\x0f\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0b\x1f\x0b\x0b\x0f\x17\x1b\x1f\x0b\x1fO/\x0b\x0b\x13\x17\x01\x05\x0b\x0f\x035\x17\x0f\x0f\x07\x07\x07\x17\x17\x13\x07\x07\x0f\x17\x13\x17\x17\x13\x13\x17\x13\x13\x13\x13\x13\x13\x17\x02\xde\x08\x1d}\x03\x17\x1fj\x05\x01\x03\x03\x11\xcf\x1d\x93\x03\x1dU\x03\x05\x1f\x05!\x1f\x05#\x03\x03\x0b\xe5\x11\x03\x05\x05%\x05'\x05)\x05+\x05-\x03\x03#\xcb\x05/\x1d]\x03\x051\x053\x03\x03\x0b\xd5\x17\x1ff\x05\x01\x055\x057\x059\x05;\x05=\x05?\x05A\x05C\x03\x03\x0b\xe1\x03\x05'\xab)\xe3\x03\x03\x11\xe7\x03\tGIK\x15M\x15\rO\x05E\x11\x01\x00\x05G\x05I\x05K\x03\x0b\x17\x9d\x19\xb5\x1b\xb7\r\xc1\x1d\xc3\x03\x0b\x17\xa7\x19\xc7\x1b\xa7\r\xa9\x1d\xc9\x05M\x1dY\x03\x05O\x03\x03\x0b\xcd\x05Q\x03\x03#\xd1\x1dc\x03\x05S\x03\x05'\xab)\xd3\x1di\x03\x05U\x1dm\x03\x05W\x1dq\x03\x05Y\x1du-\x05[\x1dy-\x05]\x03\x11/\xad1\xd73\xd95\x9d7\xaf9\xdb;\xb1=\xdf\x05_\x03\x03\x11\xe9\x1d\x83\x03\x05a\x03\x07\x87\xa3\x89\xa3\x8b\xa3\x05c\x05e\x05g\x1d\x8f\x03\x05i\x03\x11/\xad1\xeb3\xed5\x9d7\xaf9\xef;\xb1=\xf1\x05k\x03\x03\x97\xa9\x05m\x1f+!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f-\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1do\x1dq\x1f\x1f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1ds\x03\x03\xc5\x1du\t\x07\x0b\x05\x05\x01\x03\x03\xdd\x1f/\x01#!\x03\x05\xb9\xbd\r\x05\xa5\xbb\x9f\xa1\x1dw\r\x05\xa5\xbf\x9f\xa1\x1dy\x1d{\x1d}\r\x03\x9f\xa1##\x1d\x7f\x13\r\x01\x1f\x07\t\xff\xff\xff\xff\x1f%\x01\x13\r\x05\x07\x05\x1f\t\t\x00\x00\x00\x00\x1d\x81\x1d\x83\x03\x03\x99\x15\x03\x01\x01\x01\x03\t\x99\x9b\xb3\x9b\x1f\x07\t\x00\x00\x00\x00\x07\x01\x1f\t\t\x00\x00\xc0\x7f\x1f\x1f!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d\x85\x1d\x87\x03\x05\x99\x9b\x03\x07\x99\xb3\x9b\x01\t\x01\x02\x02)\x05\r\r\x0b)\x01\x19)\x01\x0b\t\x1d\x01)\x05\r\r\x19)\x05\r\r\x0f)\x03\r\x0b\x13\x1b)\x01\x0f)\x05\x05\x05\x0f)\x03\t\r\x11\x01\x05\x05\x05\x11\x03\x05\x03\x05)\x03\x01\r)\x03%\x0b)\x03\x02\x08\x0b)\x03\t\x17)\x03\x05\x17)\x03\x01\x17)\x03\x05\x0f)\x03\r\x0f)\x03\x05\r)\x03\x02\x04\x0b\x04\x1a\x05\x05\x01\x11\x0fE\x07\x03\x01\t\r\x11\x0fQ\x07\x03Cu\t\x03s!\x03'\x15\x06w\x03\x05\x03\x01\x11\x07\x01{\t\x05\x15\x07)\x03\x03\x05\x03\x01?\x03\x07\x03\x07\x01\x05\x03\x07\x03\r\x0b\x07\x01A\x03\x1b\x05\t\x0f\x03\x07\x01\x05\x03\x1d\x03\x11\x05\x03\x01\x13\x03\t\x03\x07\x01\x05\x03\x05\x03\x15\x03\x07\x01C\x03\x13\x03\x13\x07\x06\x01\x03\x05\x07\x19\x05\x17\x03\x07\x01\x05\x031\x03\x11\x05\x03\x01\x13\x03\t\x03\x07\x01\x05\x03\x15\x03\x1f\x03\x07\x01\x7f\x033\x03\x1d\x07\x06\x01\x03\x15\x07#\x07!\x05\x03\x81+\x03\t\x17\x07\x8d\x85\x03\x05\x05\x1b'\x11\x07\x07\x91\x07\x05\x077\x05)%\x05\x03\x07?\x03\x07\x03\x07\x07\x05\x03\x07\x031\x0b\x07\x07A\x03\x1b\x05-3\x03\x07\x07\x05\x03\x1d\x035\x05\x03\x07\x13\x03\t\x03\x07\x07\x05\x03\x05\x039\x03\x07\x07C\x03\x13\x037\x07\x06\x07\x03\x05\x07=+;\x19\x07\t\x95\x03\x05\x03\x1b\x0f\x04\x0f\x05?A\r\x11\tS\x07\x03\x15+\x03\x05\t\t\x03W!\x03\x11\x05\x03\t[\x03\x07\x03\x07%\x05\x03\x11\x03\x05\x13\x06%\x03\x11\x05\x03\x07\t\x03a_\x03\x11\x0b\x07ge\x03\x13\x05\t\x0b\x05\x03\t+\x03\t\x03\x07k\x05\x03\x05\x03\x0f\x07\x06o\x03\x05\x07\r\x11\x01\x0f\x04\t\x03\x13\x06\x03\x01\x05\x01\x00\xea\x1a\x89!3!+\x11\x0f\x0b\t\t\x0b!\x11#\x0fY\x87##%_=\x85\x87W\xb3K\x9bM\x9bn\x03\x1b%)9\x1f/!!)#\x1f\x19+\x1b\x1f]\x1f\x15\x1d\x15+\x13\r\x11\x0f\x17\x0f\x1f\x15\x11\x17\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00iota_v1\x00compare_v1\x00func_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00sym_name\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00/release/jax/tests/export_back_compat_test.py\x00iota_dimension\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=float32 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]\x00jit()/jit(main)/householder_product\x00callee\x00mhlo.layout_mode\x00default\x00jax.result_info\x00triu\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00\x00\x00\x00\x01\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x00\x01\x00\x00\x00hipsolver_geqrf\x00\x00\x00\x00\x00\x01\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x80\x00\x00\x00\x00hipsolver_orgqr\x00", - xla_call_module_version=9, - nr_devices=1, -) # End paste - - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_05["batched"] = dict( - testdata_version=1, - platform='rocm', - custom_call_targets=['hipblas_geqrf_batched', 'hipsolver_orgqr'], - serialized_date=datetime.date(2024, 8, 5), - inputs=(), - expected_outputs=(array([[[ 0. , 0.9128709 , 0.40824834], - [-0.4472136 , 0.3651484 , -0.81649655], - [-0.8944272 , -0.18257423, 0.40824828]], - - [[-0.42426407, 0.8082888 , 0.4082513 ], - [-0.5656854 , 0.11547317, -0.81649613], - [-0.7071068 , -0.5773518 , 0.40824607]]], dtype=float32), array([[[-6.7082038e+00, -8.0498447e+00, -9.3914852e+00], - [ 0.0000000e+00, 1.0954450e+00, 2.1908898e+00], - [ 0.0000000e+00, 0.0000000e+00, 1.6371473e-09]], - - [[-2.1213203e+01, -2.2910259e+01, -2.4607313e+01], - [ 0.0000000e+00, 3.4641036e-01, 6.9281983e-01], - [ 0.0000000e+00, 0.0000000e+00, 8.3555670e-07]]], dtype=float32)), - mlir_module_text=r""" -#loc2 = loc("/release/jax/tests/export_back_compat_test.py":346:0) -#loc9 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc2)) -module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<2x3x3xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x3x3xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { - %0 = stablehlo.iota dim = 0 : tensor<18xf32> loc(#loc3) - %1 = stablehlo.reshape %0 : (tensor<18xf32>) -> tensor<2x3x3xf32> loc(#loc4) - %2:4 = stablehlo.custom_call @hipblas_geqrf_batched(%1) {api_version = 2 : i32, backend_config = "\00\00\00\00\02\00\00\00\03\00\00\00\03\00\00\00", operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x3x3xf32>) -> (tensor<2x3x3xf32>, tensor<2x3xf32>, tensor<16xi8>, tensor<16xi8>) loc(#loc5) - %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc6) - %3 = stablehlo.pad %2#0, %cst, low = [0, 0, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<2x3x3xf32>, tensor) -> tensor<2x3x3xf32> loc(#loc7) - %4:3 = stablehlo.custom_call @hipsolver_orgqr(%3, %2#1) {api_version = 2 : i32, backend_config = "\00\00\00\00\02\00\00\00\03\00\00\00\03\00\00\00\03\00\00\00\80\00\00\00", operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x3x3xf32>, tensor<2x3xf32>) -> (tensor<2x3x3xf32>, tensor<2xi32>, tensor<128xf32>) loc(#loc8) - %c = stablehlo.constant dense<0> : tensor loc(#loc8) - %5 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<2xi32> loc(#loc8) - %6 = stablehlo.compare EQ, %4#1, %5, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc8) - %7 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc8) - %cst_0 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc8) - %8 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor) -> tensor<2x3x3xf32> loc(#loc8) - %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x3x3xi1> loc(#loc8) - %10 = stablehlo.select %9, %4#0, %8 : tensor<2x3x3xi1>, tensor<2x3x3xf32> loc(#loc8) - %11 = call @triu(%2#0) : (tensor<2x3x3xf32>) -> tensor<2x3x3xf32> loc(#loc9) - return %10, %11 : tensor<2x3x3xf32>, tensor<2x3x3xf32> loc(#loc) - } loc(#loc) - func.func private @triu(%arg0: tensor<2x3x3xf32> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc2))) -> (tensor<2x3x3xf32> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc10) - %c = stablehlo.constant dense<-1> : tensor loc(#loc9) - %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc11) - %2 = stablehlo.add %0, %1 : tensor<3x3xi32> loc(#loc11) - %3 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc12) - %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc13) - %5 = stablehlo.broadcast_in_dim %4, dims = [1, 2] : (tensor<3x3xi1>) -> tensor<2x3x3xi1> loc(#loc14) - %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc9) - %6 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<2x3x3xf32> loc(#loc15) - %7 = stablehlo.select %5, %6, %arg0 : tensor<2x3x3xi1>, tensor<2x3x3xf32> loc(#loc16) - return %7 : tensor<2x3x3xf32> loc(#loc9) - } loc(#loc9) -} loc(#loc) -#loc = loc(unknown) -#loc1 = loc("/release/jax/tests/export_back_compat_test.py":345:0) -#loc3 = loc("jit()/jit(main)/iota[dtype=float32 shape=(18,) dimension=0]"(#loc1)) -#loc4 = loc("jit()/jit(main)/reshape[new_sizes=(2, 3, 3) dimensions=None]"(#loc1)) -#loc5 = loc("jit()/jit(main)/geqrf"(#loc2)) -#loc6 = loc("jit()/jit(main)/qr[full_matrices=True]"(#loc2)) -#loc7 = loc("jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0), (0, 0, 0))]"(#loc2)) -#loc8 = loc("jit()/jit(main)/householder_product"(#loc2)) -#loc10 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]"(#loc2)) -#loc11 = loc("jit()/jit(main)/jit(triu)/add"(#loc2)) -#loc12 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]"(#loc2)) -#loc13 = loc("jit()/jit(main)/jit(triu)/ge"(#loc2)) -#loc14 = loc("jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(2, 3, 3) broadcast_dimensions=(1, 2)]"(#loc2)) -#loc15 = loc("jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(2, 3, 3) broadcast_dimensions=()]"(#loc2)) -#loc16 = loc("jit()/jit(main)/jit(triu)/select_n"(#loc2)) -""", - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01)\x05\x01\x03\x01\x03\x05\x03\x19\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x03\x96\x02\xfb=\x01\x9f\x17\x0f\x0f\x0b\x13\x0b\x0b\x07\x0f\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x13\x17\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0bK\x0f\x0b\x0f\x0b#\x0b\x0b\x0b\x0f\x0bK\x0b\x13\x1b\x13\x13\x13\x13\x0b\x03]o/\x0b\x0b\x0b/\x0b\x0f\x0b\x0b\x0b\x0b\x0fO\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0bO\x1f\x0b\x0b\x0f\x17\x1b\x0b\x0b\x13\x17\x1f\x0b/\x1fo\x01\x05\x0b\x0f\x039\x1b\x07\x07\x0f\x17\x0f\x07\x07\x07\x1b\x13\x13\x13\x17\x17\x13\x17\x13\x13\x17\x07\x13\x13\x13\x17\x13\x1b\x13\x02\xf6\t\x17\x1bj\x05\x01\x1d\x8f\x01\x1dK\x01\x05\x1f\x03\x03\x0b\xd5\x05!\x05#\x1f\x11\x03\x05\x05%\x05'\x05)\x05+\x05-\x03\x03\x1f\xd1\x05/\x1dS\x01\x051\x053\x03\x03\x07\xdd\x17\x1bf\x05\x01\x055\x057\x059\x05;\x05=\x05?\x05A\x05C\x03\t=?A\x11C\x11\rE\x05E\x11\x01\x00\x05G\x05I\x05K\x03\x0b\x13\xa3\x15\xbb\x17\xbd\r\xc7\x19\xc9\x03\x0b\x13\xad\x15\xcd\x17\xad\r\xaf\x19\xcf\x05M\x1dO\x01\x05O\x03\x03\x07\xd3\x05Q\x03\x03\x1f\xd7\x1dY\x01\x05S\x03\x05#\xb1%\xd9\x1d_\x01\x05U\x03\x03\x0b\xdb\x1de\x01\x05W\x1di\x01\x05Y\x1dm\x01\x05[\x1dq)\x05]\x1du)\x05_\x03\x11+\xb3-\xdf/\xe11\xa33\xb55\xe37\xb79\xe7\x1d{\x01\x05a\x1d\x7f\x01\x05c\x03\x07\x83\xa9\x85\xa9\x87\xa9\x05e\x05g\x05i\x1d\x8b\x01\x05k\x03\x11+\xb3-\xe9/\xeb1\xa33\xb55\xed7\xb79\xef\x05m\x03\x03\x07\xf1\x03\x05#\xb1%\xf3\x03\x03\x0b\xf5\x03\x03\x07\xf7\x03\x03\x0b\xf9\x03\x03\x9d\xaf\x05o\x1f/1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1dq\x1ds\x1f\x1b\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1du\x03\x03\xcb\x1dw\t\x07\x0b\x05\x05\x01\x03\x03\xe5\x1f1!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00#\x1f\x03\x05\xbf\xc3\r\x05\xab\xc1\xa5\xa7\x1dy\r\x05\xab\xc5\xa5\xa7\x1d{\x1d}\x1d\x7f\r\x03\xa5\xa7#!\x1d\x81\x13\x07\x01\x1f\x0f\t\xff\xff\xff\xff\x1f#\x01\x13\x07\x05\x07\x05\x1f'!\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x0b\t\x00\x00\x00\x00\x1d\x83\x1d\x85\x03\x03\x9f\x15\x03\x01\x01\x01\x03\t\x9f\xb9\xa1\xa1\x1d\x87\x1d\x89\x03\x05\x9f\xb9\x03\x07\x9f\xa1\xa1\x1f\x0f\t\x00\x00\x00\x00\x07\x01\x1f;\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x0b\t\x00\x00\xc0\x7f\x1f\x1b1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\r\r\t\x1d\t)\x01\t)\x05\r\r\x13)\x01\x13\x01\x1b\x13)\x07\t\r\r\x11)\x03A-)\x03\r\x07)\x03\t\x13\x11\x01\x05\x05\x05\x11\x03\x05\x03\x05)\x03\x01\x07)\x05\r\r\x11)\x03\t\x07)\x03I\t)\x05\t\r\t\x17)\x03\r\x15)\x03\t\x15)\x03\x05\x15)\x03\x02\x04\t)\x03\t\x11)\x07\t\x05\x05\x11)\x03\x05\x07\x04\xa6\x03\x05\x01\x11\x0f;\x07\x03\x01\t\t\x11\x0fG\x07\x03)A\x07\x03o\x1d\x03)\x15\x06s\x03\x05\x03\x01\x11\x07yw\t\x05+\x19\x19\x03\x03\x05\x03}'\x03\x0b\x17\x07\x89\x81\x03\x05\x05\x05\r\x11\x07\x03\x8d\x07\x05\x1d5\x05\x0f\x07\x05\x03\x03\x91\x03\x0f\x03\x07\x03\t\x03\x1d\x03\x17\x0b\x07\x03\x93\x037\x05\x13\x19\x03\x07\x03\x95\x039\x03\x1b\x05\x03\x03\x97\x03\x0b\x03\x07\x03\t\x03\x05\x03\x1f\x03\x07\x03\x99\x03\x17\x03\x1d\r\x06\x03\x03\x05\x07#\x11!\x19\x07\x05\x9b\x03\x05\x03\x05\x0f\x04\x0f\x05%'\t\x11\x05I\x07\x03\x17/\x03\x05\x05\x07\x03M\x1d\x03\r\x05\x03\x05Q\x03\x0f\x03\x07!\t\x03\r\x03\x05\x13\x06!\x03\r\x05\x03\x07\x07\x03WU\x03\r\x0b\x07][\x03%\x05\t\x0b\x03\x07ca\x03\x17\x03\r\x05\x03\x05'\x03\x0b\x03\x07g\t\x03\x05\x03\x11\r\x06k\x03\x05\x07\x0f\x13\x01\x0f\x04\x05\x03\x15\x06\x03\x01\x05\x01\x00\xbe\x1c\x8b!3-#\x11\x0f\x0b\t\t\x0b!\x11#\x0fY\x9d##%_=\x8b\x89W\xb9\xc1K\x9bM\x9bn\x03\x1b%)9\x1f/!!)#\x1f\x19+\x1b\x1f]\x1f\x15\x1d\x15\x13+\r\x11\x0f\x17\x0f\x1f\x15\x15\x17\x11\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00iota_v1\x00func_v1\x00compare_v1\x00select_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00broadcast_dimensions\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00/release/jax/tests/export_back_compat_test.py\x00iota_dimension\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(2, 3, 3) broadcast_dimensions=(1, 2)]\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(2, 3, 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=float32 shape=(18,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(2, 3, 3) dimensions=None]\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0), (0, 0, 0))]\x00jit()/jit(main)/householder_product\x00callee\x00mhlo.layout_mode\x00default\x00jax.result_info\x00triu\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x00hipblas_geqrf_batched\x00\x00\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x80\x00\x00\x00\x00hipsolver_orgqr\x00", - xla_call_module_version=9, - nr_devices=1, -) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/shardy_sharding_ops_with_different_meshes.py b/jax/_src/internal_test_util/export_back_compat_test_data/shardy_sharding_ops_with_different_meshes.py index b54234d11cca..c0a732fc47bf 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/shardy_sharding_ops_with_different_meshes.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/shardy_sharding_ops_with_different_meshes.py @@ -15,43 +15,87 @@ # ruff: noqa import datetime -from numpy import array, float32 +import numpy as np +array = np.array +float32 = np.float32 +int32 = np.int32 # Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2025_02_12 = dict( +data_2025_06_30 = dict( testdata_version=1, platform='tpu', - custom_call_targets=['Sharding', 'xla.sdy.GlobalToLocalShape', 'xla.sdy.LocalToGlobalShape'], - serialized_date=datetime.date(2025, 2, 12), + custom_call_targets=[], + serialized_date=datetime.date(2025, 6, 30), inputs=(array([[0., 1., 2., 3.], [4., 5., 6., 7.]], dtype=float32),), expected_outputs=(array([[4., 5., 6., 7.], [0., 1., 2., 3.]], dtype=float32),), mlir_module_text=r""" #loc1 = loc("x") -#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":1052:13) -#loc6 = loc("jit(func)/jit(main)/shard_map"(#loc3)) -module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.frontend_attributes = {xla.sdy.meshes = "{mesh = #sdy.mesh<[\\\22a\\\22=2]>}"}, mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<2x4xf32> loc("x")) -> (tensor<2x4xf32> {jax.result_info = ""}) { - %0 = stablehlo.custom_call @Sharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}, {}]>]>"}, mhlo.sharding = "{devices=[2,1]<=[2]}"} : (tensor<2x4xf32>) -> tensor<2x4xf32> loc(#loc5) - %1 = stablehlo.custom_call @xla.sdy.GlobalToLocalShape(%0) : (tensor<2x4xf32>) -> tensor<1x4xf32> loc(#loc6) - %2 = call @xla.sdy.manual_computation_body(%1) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}, {}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}, {}]>]>"}} : (tensor<1x4xf32>) -> tensor<1x4xf32> loc(#loc6) - %3 = stablehlo.custom_call @xla.sdy.LocalToGlobalShape(%2) : (tensor<1x4xf32>) -> tensor<2x4xf32> loc(#loc6) - return %3 : tensor<2x4xf32> loc(#loc) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":1001:8 to :54) +#loc4 = loc("third_party/py/absl/testing/absltest.py":2900:19 to :56) +#loc5 = loc("third_party/py/absl/testing/absltest.py":2936:35 to 2938:3) +#loc6 = loc("third_party/py/absl/testing/absltest.py":2446:6 to :34) +#loc7 = loc("third_party/py/absl/app.py":404:13 to :23) +#loc8 = loc("third_party/py/absl/app.py":484:6 to :27) +#loc9 = loc("third_party/py/absl/testing/absltest.py":2448:4 to :31) +#loc10 = loc("third_party/py/absl/testing/absltest.py":2330:2 to :38) +#loc11 = loc("third_party/py/jax/tests/export_back_compat_test.py":1005:2 to :47) +#loc12 = loc("third_party/py/jax/tests/export_back_compat_test.py":992:13 to :30) +#loc15 = loc("ShardyCompatTest.test_shardy_sharding_ops_with_different_meshes"(#loc3)) +#loc16 = loc("_run_and_get_tests_result"(#loc4)) +#loc17 = loc("run_tests"(#loc5)) +#loc18 = loc("_run_in_app..main_function"(#loc6)) +#loc19 = loc("_run_main"(#loc7)) +#loc20 = loc("run"(#loc8)) +#loc21 = loc("_run_in_app"(#loc9)) +#loc22 = loc("main"(#loc10)) +#loc23 = loc(""(#loc11)) +#loc24 = loc("ShardyCompatTest.test_shardy_sharding_ops_with_different_meshes..func"(#loc12)) +#loc26 = loc(callsite(#loc22 at #loc23)) +#loc28 = loc(callsite(#loc21 at #loc26)) +#loc30 = loc(callsite(#loc20 at #loc28)) +#loc32 = loc(callsite(#loc19 at #loc30)) +#loc34 = loc(callsite(#loc18 at #loc32)) +#loc36 = loc(callsite(#loc17 at #loc34)) +#loc38 = loc(callsite(#loc16 at #loc36)) +#loc40 = loc(callsite(#loc15 at #loc38)) +#loc43 = loc(callsite(#loc24 at #loc40)) +#loc46 = loc("jit(func)/shard_map"(#loc43)) +#loc49 = loc("shard_map:"(#loc46)) +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 1 : i32} { + sdy.mesh @mesh = <["a"=2]> loc(#loc) + func.func public @main(%arg0: tensor<2x4xf32> loc("x")) -> (tensor<2x4xf32> {jax.result_info = "result"}) { + %0 = sdy.sharding_constraint %arg0 <@mesh, [{"a"}, {}]> : tensor<2x4xf32> loc(#loc48) + %1 = sdy.manual_computation(%0) in_shardings=[<@mesh, [{"a"}, {}]>] out_shardings=[<@mesh, [{"a"}, {}]>] manual_axes={"a"} (%arg1: tensor<1x4xf32> loc("shard_map:"(#loc46))) { + %2 = "stablehlo.collective_permute"(%arg1) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 1], [1, 0]]> : tensor<2x2xi64>}> : (tensor<1x4xf32>) -> tensor<1x4xf32> loc(#loc50) + sdy.return %2 : tensor<1x4xf32> loc(#loc49) + } : (tensor<2x4xf32>) -> tensor<2x4xf32> loc(#loc49) + return %1 : tensor<2x4xf32> loc(#loc) } loc(#loc) - func.func @xla.sdy.manual_computation_body(%arg0: tensor<1x4xf32> loc("jit(func)/jit(main)/shard_map"(#loc3))) -> tensor<1x4xf32> { - %0 = "stablehlo.collective_permute"(%arg0) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 1], [1, 0]]> : tensor<2x2xi64>}> : (tensor<1x4xf32>) -> tensor<1x4xf32> loc(#loc7) - return %0 : tensor<1x4xf32> loc(#loc6) - } loc(#loc6) } loc(#loc) #loc = loc(unknown) -#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":1051:10) -#loc4 = loc("third_party/py/jax/tests/export_back_compat_test.py":1050:15) -#loc5 = loc("jit(func)/jit(main)/sharding_constraint"(#loc2)) -#loc7 = loc("jit(func)/jit(main)/ppermute"(#loc4)) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":991:10 to :73) +#loc13 = loc("third_party/py/jax/tests/export_back_compat_test.py":990:15 to :46) +#loc14 = loc("ShardyCompatTest.test_shardy_sharding_ops_with_different_meshes..func"(#loc2)) +#loc25 = loc("ShardyCompatTest.test_shardy_sharding_ops_with_different_meshes..func..shard_map_func"(#loc13)) +#loc27 = loc(callsite(#loc21 at #loc22)) +#loc29 = loc(callsite(#loc20 at #loc27)) +#loc31 = loc(callsite(#loc19 at #loc29)) +#loc33 = loc(callsite(#loc18 at #loc31)) +#loc35 = loc(callsite(#loc17 at #loc33)) +#loc37 = loc(callsite(#loc16 at #loc35)) +#loc39 = loc(callsite(#loc15 at #loc37)) +#loc41 = loc(callsite(#loc24 at #loc39)) +#loc42 = loc(callsite(#loc14 at #loc40)) +#loc44 = loc(callsite(#loc25 at #loc41)) +#loc45 = loc("jit(func)/sharding_constraint"(#loc42)) +#loc47 = loc("jit(func)/ppermute"(#loc44)) +#loc48 = loc("sharding_constraint:"(#loc45)) +#loc50 = loc("ppermute:"(#loc47)) """, - mlir_module_serialized=b'ML\xefR\rStableHLO_v1.8.8\x00\x01\x1d\x05\x01\x05\r\x01\x03\x0b\x03\x0b\x0f\x13\x17\x1b\x1f\x03\x97q\x13\x019\x0f\x07\x0b\x0b+\x0b\x0f\x13\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0b\x17\x0f\x0b\x17\x0f\x0b\x1b\x0b\x0f\x0b\x17\x13\x039\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0f\x8f\x13\x0b\x0b\x0b\x0b#\x0b\x0b\x0b\x0b\x0b\x01\x05\x0f\x0b\x03\x0f\x17\x17\x07\x07\x17\x17\x17\x02v\x03\x1d\x1f!\x1f\x05\x11\x05\x13\x03\t\x0b\r\x05\x0f\x15\x17\x19\x1b\x05\x15\x11\x03\x00\x03\x03\x11\x13\x05\x17\x05\x19\x05\x1b\x11\x01\t\x05\x1d\x11\x01\x05\x05\x1f\x05!\x17\x07r\x10\x1b\x1d%\'\x05#\x17\x07j\x10\x1f\x1d+\x03\x05%\x03\x05\x05[/_\x05\'\x1d35\x05)\x17\x07n\x10\x15\x03\x03\x05e\x03\x01\x1d+\x1d-\x0b\x03\x05\x01\x1d/\x03\x03G\r\x01#\r\x03\x03M\r\x03O;\x1d1\x1d3\x1d5#\x0f\x13\x0b\x05\x1f\x11A\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\r\x03]=\x1d7\x1d9\x1d;\x1d=\r\x07g=ikm=\x1d?\x1dA\x1dC\x1dE\x1dG\x01\x02\x02\x01\t)\x05\x05\x11\t)\x05\t\x11\t\t\x1d\x11\x03\x07\x03\x07\x11\x03\x05\x03\x05)\x05\t\t\x0b\x04\xb9\x05\x01Q\x03\t\x01\x07\x04\xa7\x03\x01\t\x05P\x03\x03\x07\x04]\x03\x0b\x17\x03\x0f)\x00\x03G1-\x05\x03\x07\x03\x01\x03F\x01\x07\x03\x05\x03\x03\x0bG\x017\t\x03\x05\x03\x05\x03F\x01\x0b\x03\x07\x03\x07\x07\x04\x03\x03\t\x05P\x01\r\x07\x04)\x03\x05\x0b\x03\x0b\x01\x00\tF#\x0f\x03\x05\x03\x01\x07\x04\x01\x03\x03\x06\x03\x01\x05\x01\x00r\x0bI7-3)+7\x13+#\x0f\x0b!Ae\x03Q\x1d\x05;=\x13%)=\x1f9i3\x11-\x15\x11\x1f\x0f\x0b\x11builtin\x00vhlo\x00module\x00custom_call_v1\x00func_v1\x00return_v1\x00collective_permute_v1\x00call_v1\x00mhlo.frontend_attributes\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.uses_shape_polymorphism\x00xla.sdy.meshes\x00{mesh = #sdy.mesh<[\\"a\\"=2]>}\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/shard_map\x00jit(func)/jit(main)/ppermute\x00x\x00mhlo.sharding\x00jit(func)/jit(main)/sharding_constraint\x00\x00#sdy.sharding_per_value<[<@mesh, [{\\"a\\"}, {}]>]>\x00xla.sdy.manual_computation_body\x00jax.result_info\x00main\x00public\x00xla.sdy.sharding\x00{devices=[2,1]<=[2]}\x00Sharding\x00xla.sdy.GlobalToLocalShape\x00xla.sdy.in_shardings\x00xla.sdy.manual_axes\x00#sdy\x00xla.sdy.out_shardings\x00xla.sdy.LocalToGlobalShape\x00\x08a\x11\x05;\x01\x0bEIKQS\x11?;a9A999\x11?;c9A999\x03C\x11?;o9A999\x0b9U9C;\x05WY', + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.10.9\x00\x01)\x07\x01\x05\t\x13\x01\x05\x0f\x13\x03\t\x17\x1b\x1f#\x05\x07'+/\x03\xfb\xcd\x17\x01\xa7\x07\x0b\x0b\x0f\x0f\x0b\x0f\x0b\x0f\x0f\x0f\x0f\x0f\x0f\x0b\x0f\x0f\x0f\x0f#\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0f\x1f\x0b\x1f\x0f\x0b\x1f\x0f\x0b'\x0f\x0b\x1f\x0f\x0b\x1f\x0f\x0b\x1f\x0f\x0b\x1f\x0f\x0b\x1f\x0f\x0b\x1f\x0b\x0b\x0f\x0b\x0f\x1f\x0b\x0f\x0b\x0f\x0f\x0b\x1f\x0f\x0f\x0f\x0f\x0f\x0f\x0f\x0f\x03\x11\x1b\x0f\x13\x0f\x17\x0f\x13\x0f\x05\x17\x0f\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0f\x8f\x01\x0b\x17\x0f\x07\x17\x0b\x05\r\x17\x07\x17\x07\x17\x17\x02N\x06\x1f\x05\x19\x05\x1b\x1d\x7f\x81\x1d\x89\x8b\x05\x0b\x1d7\x01\x05\x1d\x15\x13M\x1dIK\x1dOQ\x1dUW\x1d[]\x1dac\x05\x1f\x1dgi\x1dmo\x1dsu\x1d\x0f\x87\x03\x07)+-/13\x05!\x11\t\x00\x05#\x11\x03\t\x05%\x11\x03\x05\x05'\x05)\t\x0b\x1d=?\x05+\x1dAC\x05-\x15E\x11\x1d\x0fG-\x03\x07~\x0f\x15\x93\x05/-\x03\x07\xa6\x0f\x11m\x15\x15S\x051-\x05\x07R-'q\x15\x17Y\x053-\x05\t\xe2-G\xea-\x07\x15\x19_\x055-\x05\x07:&\rE\x15\x1be\x057-\x1d\x07R\x06\x1b/\x15\x1fk\x059-\x1d\x07\x92\x07\r7\x15!q\x05;-\x05\x07B&\t?\x15#w\x05=-\x05\x07j$\x05M\x1dy{\x05?-\x03\x07\xb6\x0f\x05_\x05A\x05C\x1d\x83\x85\x05E\x15%\x11-\x03\x07\x82\x0f\x1b=\x05G\x1d\x8d\x8f\x05I\x15\x91\x97\x1d\x93\x95\x05K-\x03\x07z\x0f\x1f]\x15%\x99\x15\x13\x9b\x15\x15\x9d\x15\x17\x9f\x15\x19\xa1\x15\x1b\xa3\x15\x1f\xa5\x15!#\r9\x05\xaf\xb3\x01\x0f\x03\xa7\x05\x03\xad\x01\x03A\t\x0b\x03\xb1\x01\x01\tA\x01\x0b\x01\x01\x01\x01\x03}\x03\x03\xb9\r\x01#\x13\x03\x03\xbf\r\x03\xc1\xc3\x1dM\x1dO\x1d=\x1dQ\x13\x11\x05\x1f\x15A\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1b\x05\t\x11\x05\x01\x02\x02\x0b\x1b\x05\x05\x11\x05\x01\t)\x05\t\x11\r\t)\x05\x05\x11\r\x1d\x11\x03\x0b\x03\x0b)\x05\t\t\x11\x04\xcd\x05\x03Q\x01'\x01\x07\x04\xbb\x03\x01\t\x05@\x01\x03\rP\x01\x05\x07\x04\x9f\x03\x0b\x17\x03\x17\r\x00\x01\x06\r\x03\x01\x03\x01\x07F;\x07\x03\x01\x03\x03\tV\x07\t\x03\x01\x03\x05\x07\x04E\x03\t\x13\x03\x0f\x07\x00\x01\x06\t\x03\x0f\x03\x01\x11F\t\x0b\x03\x0f\x03\x03\x01\x06\t\x03\x07\x03\x05\x0b\x04\x07\x03\x07\x01\x06\x01\x03\x0b\x03\x07\x0f\x04\x01\x03\t\x06\x03\x01\x05\x01\x00\xba\rS\x0f\x0f!\xcd'\x15)\x17\x05\x13\x0b\x19\t\x15G\x155\x81=+\x05\x13%)97\x9dQi-\x15\x11\x0f')\x0b\x0f7\x0b\t\x11builtin\x00sdy\x00vhlo\x00unrealized_conversion_cast\x00module\x00mesh\x00sharding_constraint\x00manual_computation\x00return\x00func_v1\x00return_v1\x00collective_permute_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00third_party/py/absl/testing/absltest.py\x00ShardyCompatTest.test_shardy_sharding_ops_with_different_meshes..func\x00third_party/py/absl/app.py\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00x\x00sharding_constraint:\x00jit(func)/sharding_constraint\x00ShardyCompatTest.test_shardy_sharding_ops_with_different_meshes\x00_run_and_get_tests_result\x00run_tests\x00_run_in_app..main_function\x00_run_main\x00run\x00_run_in_app\x00main\x00\x00a\x00shard_map:\x00jit(func)/shard_map\x00ppermute:\x00jit(func)/ppermute\x00ShardyCompatTest.test_shardy_sharding_ops_with_different_meshes..func..shard_map_func\x00jax.result_info\x00result\x00public\x00\x08-\r\x05k\x01\x05\xab\x0b\x0b\xb7\xbb\xbd\xc5\xc7\x03\xa7\x07\xa9\xb5\xa9\x05\xc9\xcb", xla_call_module_version=9, nr_devices=2, ) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/stablehlo_dynamic_approx_top_k.py b/jax/_src/internal_test_util/export_back_compat_test_data/stablehlo_dynamic_approx_top_k.py index b676cc8011d3..cf074bdffd42 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/stablehlo_dynamic_approx_top_k.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/stablehlo_dynamic_approx_top_k.py @@ -15,7 +15,11 @@ # ruff: noqa import datetime -from numpy import array, float32, int32 +import numpy as np + +array = np.array +float32 = np.float32 +int32 = np.int32 # Pasted from the test output (see export_back_compat_test_util.py module docstring) data_2024_05_30 = dict( diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/stablehlo_dynamic_rng_bit_generator.py b/jax/_src/internal_test_util/export_back_compat_test_data/stablehlo_dynamic_rng_bit_generator.py index 0a477d4f5265..913346ed93f3 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/stablehlo_dynamic_rng_bit_generator.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/stablehlo_dynamic_rng_bit_generator.py @@ -15,7 +15,11 @@ # ruff: noqa import datetime -from numpy import array, float32, uint32 +import numpy as np + +array = np.array +float32 = np.float32 +uint32 = np.uint32 # Pasted from the test output (see back_compat_test.py module docstring) diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/stablehlo_dynamic_top_k.py b/jax/_src/internal_test_util/export_back_compat_test_data/stablehlo_dynamic_top_k.py index e245a59086a9..08fc91b92f74 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/stablehlo_dynamic_top_k.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/stablehlo_dynamic_top_k.py @@ -15,7 +15,11 @@ # ruff: noqa import datetime -from numpy import array, float32, int32 +import numpy as np + +array = np.array +float32 = np.float32 +int32 = np.int32 # Pasted from the test output (see back_compat_test_util.py module docstring) diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/tpu_ApproxTopK.py b/jax/_src/internal_test_util/export_back_compat_test_data/tpu_ApproxTopK.py index 0a7a28fc4a72..f28225d62e94 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/tpu_ApproxTopK.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/tpu_ApproxTopK.py @@ -15,7 +15,11 @@ # ruff: noqa import datetime -from numpy import array, float32, int32 +import numpy as np + +array = np.array +float32 = np.float32 +int32 = np.int32 # Pasted from the test output (see back_compat_test.py module docstring) data_2023_04_17 = dict( diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/tpu_Eigh.py b/jax/_src/internal_test_util/export_back_compat_test_data/tpu_Eigh.py index 7e5d0d132000..f4d3d4e57cf1 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/tpu_Eigh.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/tpu_Eigh.py @@ -15,7 +15,10 @@ # ruff: noqa import datetime -from numpy import array, float32 +import numpy as np + +array = np.array +float32 = np.float32 # Pasted from the test output (see module docstring) data = dict( diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/tpu_Lu.py b/jax/_src/internal_test_util/export_back_compat_test_data/tpu_Lu.py index 34656973a072..3e3dcdf45175 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/tpu_Lu.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/tpu_Lu.py @@ -15,7 +15,11 @@ # ruff: noqa import datetime -from numpy import array, float32, int32 +import numpy as np + +array = np.array +float32 = np.float32 +int32 = np.int32 # Pasted from the test output (see back_compat_test.py module docstring) data_2023_03_21 = dict( diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/tpu_Qr.py b/jax/_src/internal_test_util/export_back_compat_test_data/tpu_Qr.py index a001a9754d29..065804dd09cc 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/tpu_Qr.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/tpu_Qr.py @@ -15,7 +15,10 @@ # ruff: noqa import datetime -from numpy import array, float32 +import numpy as np + +array = np.array +float32 = np.float32 # Pasted from the test output (see back_compat_test.py module docstring) data_2023_03_17 = dict( diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/tpu_Sharding.py b/jax/_src/internal_test_util/export_back_compat_test_data/tpu_Sharding.py index f2d8be3b958a..7a11e893db4a 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/tpu_Sharding.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/tpu_Sharding.py @@ -13,10 +13,17 @@ # limitations under the License. import datetime -from numpy import array, float32 +import numpy as np + +array = np.array +float32 = np.float32 + + +data_2025_06_30 = {} + # Pasted from the test output (see module docstring) -data_2023_03_16 = dict( +data_2025_06_30['gspmd'] = dict( testdata_version=1, platform='tpu', custom_call_targets=['SPMDFullToShardShape', 'SPMDShardToFullShape', 'Sharding'], @@ -47,3 +54,39 @@ xla_call_module_version=4, nr_devices=2, ) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2025_06_30['shardy'] = dict( + testdata_version=1, + platform='tpu', + custom_call_targets=[], + serialized_date=datetime.date(2025, 6, 30), + inputs=(array([[0., 1., 2., 3.], + [4., 5., 6., 7.]], dtype=float32),), + expected_outputs=(array([[4., 5., 6., 7.], + [0., 1., 2., 3.]], dtype=float32),), + mlir_module_text=r""" +#loc1 = loc("x") +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":792:6) +#loc4 = loc("jit(func)/shard_map"(#loc2)) +#loc6 = loc("shard_map:"(#loc4)) +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 1 : i32} { + sdy.mesh @mesh = <["a"=2]> loc(#loc) + func.func public @main(%arg0: tensor<2x4xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>} loc("x")) -> (tensor<2x4xf32> {jax.result_info = "result", sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>}) { + %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{"a"}, {}]>] out_shardings=[<@mesh, [{"a"}, {}]>] manual_axes={"a"} (%arg1: tensor<1x4xf32> loc("shard_map:"(#loc4))) { + %1 = "stablehlo.collective_permute"(%arg1) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 1], [1, 0]]> : tensor<2x2xi64>}> : (tensor<1x4xf32>) -> tensor<1x4xf32> loc(#loc7) + sdy.return %1 : tensor<1x4xf32> loc(#loc6) + } : (tensor<2x4xf32>) -> tensor<2x4xf32> loc(#loc6) + return %0 : tensor<2x4xf32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":783:13) +#loc5 = loc("jit(func)/ppermute"(#loc3)) +#loc7 = loc("ppermute:"(#loc5)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.10.9\x00\x01'\x07\x01\x05\t\x11\x01\x05\x0f\x13\x03\x07\x17\x1b\x1f\x05\x07#'+\x03\x89[\x17\x013\x07\x0f\x0f\x0b\x0f\x0b#\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x17\x0b\x0f\x0b\x17\x03\x11\x1b\x0f\x13\x0f\x17\x0f\x13\x0f\x05\x19\x0b\x0f\x13\x0b\x0f\x1b\x0b\x0b\x0b\x0b\x0f\x8f\x01\x0b\x0f\x17\x07\x17\x0b\x05\r\x17\x07\x17\x07\x17\x17\x022\x03\x1f\x1d#%\x1d+-\x05\x0b\x1d\x1f\x01\x05\x17\x03\x07\x0f\x11\x13\x15\x17\x19\x05\x19\x11\t\x00\x05\x1b\x11\x01\t\x05\x1d\x11\x01\x05\x05\x1f\t\x07\x05!\x05#\x05%\x1d')\x05'\x17\x0bb\x0c\r\x05)\x1d/1\x05+\x17\x0b>\x0c\x1b\r\x1d\x05;?\x01\x0f\x033\x05\x039\x01\x03#\t\x0b\x03=\x01\x01\t#\x01\x0b\x01\x01\x01\x01\x03!\x1d-\x03\x03G\r\x03C3#\x13\x03\x03M\r\x05OQC3\x1d/\x1d1\x1d3\x1d5\x13\x11\x05\x1f\x15A\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02\x1b\x05\t\x11\x05\x0b\x1b\x05\x05\x11\x05\x01\t)\x05\t\x11\r\t)\x05\x05\x11\r\x1d\x11\x03\x0b\x03\x0b)\x05\t\t\x11\x04\xbd\x05\x03Q\x01\r\x01\x07\x04\xab\x03\x01\t\x05@\x01\x03\x0bP\x01\x05\x07\x04\x8f\x03\t\x13\x03\x17\t\x00\x01\x06\t\x03\x03\x03\x01\x07V\x03\x07\x03\x03\x03\x03\x07\x04E\x03\t\x13\x03\x0f\x03\x00\x01\x06\x05\x03\x0f\x03\x01\x0fF\x05\t\x03\x0f\x03\x03\x01\x06\x05\x03\x07\x03\x05\t\x04\x03\x03\x07\x01\x06\x01\x03\x0b\x03\x05\r\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00\x16\x067\x0f\x0b\x0f!\x1b'\x15)\x17\x05\x05\x13%)9i-\x15\x11\x0f'\x0b\x0f7\x0b\t\x11builtin\x00sdy\x00vhlo\x00unrealized_conversion_cast\x00module\x00mesh\x00manual_computation\x00return\x00func_v1\x00return_v1\x00collective_permute_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00x\x00a\x00shard_map:\x00jit(func)/shard_map\x00ppermute:\x00jit(func)/ppermute\x00sdy.sharding\x00jax.result_info\x00result\x00main\x00public\x00\x08)\x0b\x057\x01\x057\x07\x0bEIKSU\x075A5\x05WY", + xla_call_module_version=9, + nr_devices=2, +) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/tpu_stablehlo_dynamic_reduce_window.py b/jax/_src/internal_test_util/export_back_compat_test_data/tpu_stablehlo_dynamic_reduce_window.py index 55a4c615334d..4b0268a887b2 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/tpu_stablehlo_dynamic_reduce_window.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/tpu_stablehlo_dynamic_reduce_window.py @@ -15,7 +15,11 @@ # ruff: noqa import datetime -from numpy import array, float32, int32 +import numpy as np + +array = np.array +float32 = np.float32 +int32 = np.int32 # Pasted from the test output (see back_compat_test.py module docstring) diff --git a/jax/_src/internal_test_util/export_back_compat_test_util.py b/jax/_src/internal_test_util/export_back_compat_test_util.py index 5d5e95b5cb9a..ef45a2407778 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_util.py +++ b/jax/_src/internal_test_util/export_back_compat_test_util.py @@ -43,18 +43,27 @@ class CompatTest(bctu.CompatTestBase) def test_foo_call(self): def func(...): ... - inputs = (...,) # Tuple of nd.array, keep it small, perhaps generate the - # inputs in `func`. + inputs = (...,) # Tuple of nd.array, keep it small, perhaps have it + # empty and generate the inputs in `func`. data = self.starter_data(inputs) # This is temporary, just for starting. self.run_one_test(func, data) The test will fail, but will save to a file the test data you will need. The -file name will be printed in the logs. Create a new -file jax/_src/internal_test_util/export_back_compat_test_data/foo_call.py +file name will be printed in the logs. +For Google internal tests, the file will be saved in the Test Artifacts. +Check the file to see if it contains the custom call that you expect. + +Often when we change a lowering, we keep the old one for 30 days for +forward compatibility. If you see the old custom call, you can add +`with config.export_ignore_forward_compatibility(True):` around the +`self.run_one_test` method call to make it generate the new lowering. + +Now create a new file +jax/_src/internal_test_util/export_back_compat_test_data/foo_call.py and paste the test data that you will see printed in the logs. Name the literal `data_YYYYY_MM_DD` to include the date of serialization -(for readability only). Then add to this file: +(used for readability only). Then add to this file: from jax._src.internal_test_util.export_back_compat_test_data import foo_call @@ -80,19 +89,24 @@ def func(...): ... from absl import logging import numpy as np -# Import some NumPy symbols so that we can parse repr(ndarray). -from numpy import array, float32 - -import jax -from jax import tree_util -from jax import export - -from jax.experimental import pjit +from jax._src import api from jax._src import core +from jax._src import pjit +from jax._src import stages from jax._src import test_util as jtu +from jax._src import tree_util from jax._src import xla_bridge as xb +from jax._src.export import _export +from jax._src.export import shape_poly +from jax._src.export import shape_poly_decision +from jax._src.typing import Array + +del shape_poly_decision # Imported for its side-effect only. +# Alias some NumPy symbols so that we can parse repr(ndarray). +array = np.array +float32 = np.float32 CURRENT_TESTDATA_VERSION = 1 @@ -136,7 +150,7 @@ class CompatTestBase(jtu.JaxTestCase): """Base class with helper functions for backward compatibility tests.""" def default_jax_backend(self) -> str: # Canonicalize to turn into "cuda" or "rocm" - return xb.canonicalize_platform(jax.default_backend()) + return xb.canonicalize_platform(xb.default_backend()) def starter_data(self, inputs: Sequence[np.ndarray]) -> CompatTestData: # Helper for starting a test, see module docstring. @@ -165,7 +179,8 @@ def load_testdata_nested(self, testdata_nest) -> Iterable[CompatTestData]: else: assert False, testdata_nest - def run_one_test(self, func: Callable[..., jax.Array], + def run_one_test(self, + func: Callable[..., Array] | stages.Wrapped, data: CompatTestData, polymorphic_shapes: Sequence[str] | None = None, rtol: float | None = None, @@ -176,7 +191,8 @@ def run_one_test(self, func: Callable[..., jax.Array], """Run one compatibility test. Args: - func: the JAX function to serialize and run + func: the JAX function to serialize and run, either as a Python Callable + or as a `jax.jit(callable)`. data: the test data polymorphic_shapes: when using shape polymorphism, the specification for each argument of `func`. @@ -269,19 +285,22 @@ def run_one_test(self, func: Callable[..., jax.Array], expect_current_custom_calls = data.custom_call_targets self.assertItemsEqual(expect_current_custom_calls, current_custom_call_targets) - def run_current(self, func: Callable, data: CompatTestData): + def run_current(self, + func: Callable | stages.Wrapped, + data: CompatTestData): """Lowers and runs the test function at the current JAX version.""" - return jax.jit(func)(*data.inputs) + jit_func = func if isinstance(func, stages.Wrapped) else api.jit(func) + return jit_func(*data.inputs) def serialize(self, - func: Callable, data: CompatTestData, *, + func: Callable | stages.Wrapped, data: CompatTestData, *, polymorphic_shapes: Sequence[str] | None = None, allow_unstable_custom_call_targets: Sequence[str] = () ) -> tuple[bytes, str, int, int]: """Serializes the test function. Args: - func: the function to serialize + func: the function to serialize. polymorphic_shapes: the polymorphic_shapes to use for serialization allow_unstable_custom_call_targets: whether to allow additional custom call targets besides those known as stable. @@ -291,12 +310,13 @@ def serialize(self, (d) the number of devices for which the module was serialized. """ # Use the native exporter, to make sure we get the proper serialization. - args_specs = export.symbolic_args_specs(data.inputs, polymorphic_shapes) - exported = export.export( - jax.jit(func), + args_specs = shape_poly.symbolic_args_specs(data.inputs, polymorphic_shapes) + jit_func = func if isinstance(func, stages.Wrapped) else api.jit(func) + exported = _export.export( + jit_func, platforms=(self.default_jax_backend(),), disabled_checks=tuple( - export.DisabledSafetyCheck.custom_call(target) + _export.DisabledSafetyCheck.custom_call(target) for target in allow_unstable_custom_call_targets) )(*args_specs) @@ -308,13 +328,13 @@ def serialize(self, def run_serialized(self, data: CompatTestData, polymorphic_shapes: Sequence[str] | None = None): - args_specs = export.symbolic_args_specs(data.inputs, polymorphic_shapes) + args_specs = shape_poly.symbolic_args_specs(data.inputs, polymorphic_shapes) def ndarray_to_aval(a: np.ndarray) -> core.ShapedArray: return core.ShapedArray(a.shape, a.dtype) in_avals_tree = tree_util.tree_map(ndarray_to_aval, args_specs) # TODO: we ought to ensure that out_avals are polymorphic if need be. We # could either save the in/out_avals (but we need to first implement that - # support in export), or we can just re-use them from the current + # support in export), or we can just reuse them from the current # exported. out_avals_tree = tree_util.tree_map(ndarray_to_aval, data.expected_outputs) # in_tree must be for (args, kwargs) @@ -323,12 +343,15 @@ def ndarray_to_aval(a: np.ndarray) -> core.ShapedArray: def _get_vjp(_): assert False # We do not have and do not need VJP - exported = export.Exported( + exported = _export.Exported( fun_name="run_serialized", in_tree=in_tree, in_avals=tuple(in_avals), out_tree=out_tree, out_avals=tuple(out_avals), + _has_named_shardings=True, + _in_named_shardings=(None,) * len(in_avals), + _out_named_shardings=(None,) * len(out_avals), in_shardings_hlo=(None,) * len(in_avals), out_shardings_hlo=(None,) * len(out_avals), platforms=(data.platform,), diff --git a/jax/_src/internal_test_util/lax_test_util.py b/jax/_src/internal_test_util/lax_test_util.py index 4e28791e9cee..3337971da0ed 100644 --- a/jax/_src/internal_test_util/lax_test_util.py +++ b/jax/_src/internal_test_util/lax_test_util.py @@ -23,15 +23,15 @@ import itertools from typing import Union, cast -import jax -from jax import lax +from jax._src import config from jax._src import dtypes +from jax._src import lax from jax._src import test_util from jax._src.util import safe_map, safe_zip import numpy as np -jax.config.parse_flags_with_absl() +config.parse_flags_with_absl() map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip @@ -304,7 +304,7 @@ def lax_ops(): float_dtypes, test_util.rand_uniform, { - np.float32: 1e-5, + np.float32: 2e-5, np.float64: 1e-12, }, ), diff --git a/jax/_src/internal_test_util/test_harnesses.py b/jax/_src/internal_test_util/test_harnesses.py index 48c645c4d033..6559dda6440b 100644 --- a/jax/_src/internal_test_util/test_harnesses.py +++ b/jax/_src/internal_test_util/test_harnesses.py @@ -44,21 +44,24 @@ from functools import partial from typing import Any, NamedTuple, Union -from absl import testing +from absl.testing import parameterized as absl_parameterized import numpy as np -import jax -from jax import dtypes -from jax import lax -from jax import numpy as jnp - from jax._src import ad_util +from jax._src import api from jax._src import config from jax._src import dispatch +from jax._src import dtypes +from jax._src import lax +from jax._src import numpy as jnp from jax._src import prng +from jax._src import random from jax._src import test_util as jtu +from jax._src import typing +from jax._src import xla_bridge as xb from jax._src.lax import control_flow as lax_control_flow from jax._src.lax import windowed_reductions as lax_windowed_reductions +from jax._src.numpy import linalg as jnp_linalg from jax._src import random as jax_random # mypy generates a lot of false positive due to re-assigned variables. @@ -400,7 +403,7 @@ def parameterized(harnesses: Iterable[Harness], if not cases: # We filtered out all the harnesses. return jtu.skip_on_devices(jtu.device_under_test()) - return testing.parameterized.named_parameters(*cases) + return absl_parameterized.named_parameters(*cases) ############################################################################### @@ -408,11 +411,11 @@ def parameterized(harnesses: Iterable[Harness], ############################################################################### -def _make_unary_elementwise_harness(*, prim, shape=(20, 20), dtype): +def _make_unary_elementwise_harness(*, prim, shape=(20, 20), dtype, **kwargs): define( str(prim), f"shape={jtu.format_shape_dtype_string(shape, dtype)}", - prim.bind, [RandArg(shape, dtype)], + lambda x: prim.bind(x, **kwargs), [RandArg(shape, dtype)], prim=prim, dtype=dtype, shape=shape) @@ -429,19 +432,19 @@ def _make_unary_elementwise_harness(*, prim, shape=(20, 20), dtype): _make_unary_elementwise_harness(prim=lax.acos_p, dtype=dtype) _make_unary_elementwise_harness(prim=lax.atan_p, dtype=dtype) _make_unary_elementwise_harness(prim=lax.asin_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.cos_p, dtype=dtype) + _make_unary_elementwise_harness(prim=lax.cos_p, dtype=dtype, accuracy=None) _make_unary_elementwise_harness(prim=lax.cosh_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.exp_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.expm1_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.log_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.log1p_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.rsqrt_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.sin_p, dtype=dtype) + _make_unary_elementwise_harness(prim=lax.exp_p, dtype=dtype, accuracy=None) + _make_unary_elementwise_harness(prim=lax.expm1_p, dtype=dtype, accuracy=None) + _make_unary_elementwise_harness(prim=lax.log_p, dtype=dtype, accuracy=None) + _make_unary_elementwise_harness(prim=lax.log1p_p, dtype=dtype, accuracy=None) + _make_unary_elementwise_harness(prim=lax.rsqrt_p, dtype=dtype, accuracy=None) + _make_unary_elementwise_harness(prim=lax.sin_p, dtype=dtype, accuracy=None) _make_unary_elementwise_harness(prim=lax.sinh_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.sqrt_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.tan_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.tanh_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.logistic_p, dtype=dtype) + _make_unary_elementwise_harness(prim=lax.sqrt_p, dtype=dtype, accuracy=None) + _make_unary_elementwise_harness(prim=lax.tan_p, dtype=dtype, accuracy=None) + _make_unary_elementwise_harness(prim=lax.tanh_p, dtype=dtype, accuracy=None) + _make_unary_elementwise_harness(prim=lax.logistic_p, dtype=dtype, accuracy=None) for dtype in jtu.dtypes.all_floating: _make_unary_elementwise_harness(prim=lax.bessel_i0e_p, dtype=dtype) @@ -649,13 +652,13 @@ def _make_device_put_harness(name, shape=(3, 4), dtype=np.float32, device=None): - _device_fn = lambda: jax.devices(device)[0] if device is not None else None + _device_fn = lambda: xb.devices(device)[0] if device is not None else None define( "device_put", f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_{device=}", lambda x: dispatch.device_put_p.bind( - x, devices=[_device_fn()], srcs=[None], - copy_semantics=[dispatch.CopySemantics.ALIAS])[0], + x, devices=(_device_fn(),), srcs=(None,), + copy_semantics=(dispatch.ArrayCopySemantics.REUSE_INPUT,))[0], [RandArg(shape, dtype)], shape=shape, dtype=dtype, @@ -802,7 +805,7 @@ def _make_argminmax_harness(prim, name, *, shape=(15,), - dtype=jnp.float32, + dtype=np.float32, axes=(0,), index_dtype=np.int32, arr=None, @@ -2046,7 +2049,7 @@ def linear_solve(a, b, solve, transpose_solve=None, symmetric=False): return lax.custom_linear_solve(matvec, b, solve, transpose_solve, symmetric) def explicit_jacobian_solve(matvec, b): - return lax.stop_gradient(jnp.linalg.solve(jax.jacobian(matvec)(b), b)) + return lax.stop_gradient(jnp_linalg.solve(api.jacobian(matvec)(b), b)) def _make_harness(name, *, @@ -2345,7 +2348,7 @@ def _make_select_and_scatter_add_harness(name, padding=((0, 0), (0, 0), (0, 0)), nb_inactive_dims=0): ones = (1,) * len(shape) - cotangent_shape = jax.eval_shape( + cotangent_shape = api.eval_shape( lambda x: lax_windowed_reductions._select_and_gather_add( x, x, lax.ge_p, window_dimensions, window_strides, padding, ones, ones), @@ -2719,20 +2722,20 @@ def _make_reducer_harness(prim, define( "random_gamma", f"shape={jtu.format_shape_dtype_string(shape, dtype)}", - jax.jit(lambda x: jax_random.gamma(jax.random.key(42), x)), + api.jit(lambda x: jax_random.gamma(random.key(42), x)), [RandArg(shape, dtype)], dtype=dtype) def wrap_and_split(): - key = jax.random.key(42) - result = jax.random.split(key, 2) - return jax.random.key_data(result) + key = random.key(42) + result = random.split(key, 2) + return random.key_data(result) define( "random_split", "", - jax.jit(wrap_and_split), + api.jit(wrap_and_split), [], dtype=np.uint32) @@ -2743,8 +2746,9 @@ def wrap_and_split(): define( "random_categorical", f"shape={jtu.format_shape_dtype_string(shape, dtype)}_{axis=}", - lambda x, axis: jax.random.categorical( - jax.random.key(42), x, axis), + lambda x, axis: random.categorical( + # TODO(b/416027995): Change this key back to 42. + random.key(1337), x, axis), [RandArg(shape, dtype), StaticArg(axis)], dtype=dtype, @@ -2755,8 +2759,8 @@ def wrap_and_split(): define( "random_uniform", f"shape={jtu.format_shape_dtype_string(shape, dtype)}", - lambda shape, dtype: jax.random.uniform( - jax.random.key(42), shape, dtype), + lambda shape, dtype: random.uniform( + random.key(42), shape, dtype), [StaticArg(shape), StaticArg(dtype)], dtype=dtype) @@ -2768,8 +2772,8 @@ def wrap_and_split(): define( "random_randint", f"shape={jtu.format_shape_dtype_string(shape, dtype)}", - lambda shape, minval, maxval, dtype: jax.random.randint( - jax.random.key(42), shape, minval, maxval, dtype), + lambda shape, minval, maxval, dtype: random.randint( + random.key(42), shape, minval, maxval, dtype), [StaticArg(shape), StaticArg(-5), # minval StaticArg(maxval), @@ -2842,6 +2846,12 @@ def _make_dot_general_harness(name, if preferred_element_type is not None: suffix += f"_preferred={jtu.dtype_str(preferred_element_type)}" + if ( + preferred_element_type in (np.float64, np.int64, np.complex128) + and not config.enable_x64.value + ): + return + define( lax.dot_general_p, f"{name}_lhs={jtu.format_shape_dtype_string(lhs_shape, lhs_dtype)}_rhs={jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)}_dimensionnumbers={dimension_numbers}{suffix}_enable_xla={enable_xla}" @@ -3030,6 +3040,12 @@ def _make_conv_harness(name, works_without_xla=False): enable_xla_cases = [True, False] if works_without_xla else [True] + if ( + preferred_element_type in (np.float64, np.int64, np.complex128) + and not config.enable_x64.value + ): + return + for enable_xla in enable_xla_cases: define( lax.conv_general_dilated_p, @@ -3130,7 +3146,7 @@ def _make_conv_harness(name, # feature_group_count is supported for enable_xla=False only if we are doing a # depthwise convolution, i.e.: in_channels == feature_group_count. # See explanation of depthwise convolution at -# https://www.tensorflow.org/xla/operation_semantics#conv_convolution. +# https://www.openxla.org/xla/operation_semantics#conv_convolution. _make_conv_harness( "depthwise2d", lhs_shape=(2, 3, 9, 9), # "NCHW": in_channels == 3 @@ -3361,7 +3377,7 @@ def _make_conv_harness(name, lhs_dilation=lhs_dilation, rhs_dilation=rhs_dilation) -key_types: list[tuple[tuple[int, ...], jax.typing.DTypeLike]] +key_types: list[tuple[tuple[int, ...], typing.DTypeLike]] key_types = [((4,), np.uint32)] if config.enable_x64.value: key_types.append(((2,), np.uint64)) @@ -3369,14 +3385,15 @@ def _make_conv_harness(name, for algorithm in [lax.RandomAlgorithm.RNG_THREE_FRY, lax.RandomAlgorithm.RNG_PHILOX, lax.RandomAlgorithm.RNG_DEFAULT]: - for dtype in [np.uint32, np.uint64]: + for dtype in jtu.dtypes.unsigned: for shape in [(), (5, 7), (100, 100)]: for key_shape, key_dtype in key_types: define( lax.rng_bit_generator_p, f"{key_dtype=}_shape={jtu.format_shape_dtype_string(shape, dtype)}_{algorithm=}", - lambda key, shape, dtype, algorithm: lax.rng_bit_generator(key, shape, dtype=dtype, - algorithm=algorithm), + lambda key, shape, dtype, algorithm, out_sharding=None: lax.rng_bit_generator( + key, shape, dtype=dtype, algorithm=algorithm, + out_sharding=out_sharding), [RandArg(key_shape, key_dtype), StaticArg(shape), StaticArg(dtype), StaticArg(algorithm)], shape=shape, @@ -3390,7 +3407,7 @@ def _make_iota_2x32_shape_harness(shape): f"shape=({shapestr})", lambda shape: prng.iota_2x32_shape_p.bind(shape=shape), [StaticArg(shape)], - dtype=jnp.uint32, + dtype=np.uint32, shape=shape) for shape in [(3,), (5, 7, 4), (100, 100)]: diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index ddf96af6a010..5d3af4d9ced8 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -21,44 +21,42 @@ from functools import partial from typing import Any +from jax._src import api_util from jax._src import config -from jax._src import dispatch from jax._src import linear_util as lu from jax._src.interpreters import partial_eval as pe -from jax.tree_util import (tree_flatten, tree_unflatten, - register_pytree_node, Partial, PyTreeDef) +from jax._src.tree_util import (tree_flatten, tree_unflatten, + register_pytree_node, PyTreeDef) from jax._src import mesh as mesh_lib from jax._src import core from jax._src import source_info_util from jax._src.ad_util import ( add_jaxvals, replace_internal_symbolic_zeros, replace_rule_output_symbolic_zeros, Zero, zeros_like_aval, SymbolicZero) -from jax._src.ad_util import zeros_like_p, add_jaxvals_p # noqa: F401 +from jax._src.ad_util import add_jaxvals_p from jax._src.api_util import flatten_fun, flatten_fun_nokwargs, debug_info -from jax._src.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal) +from jax._src.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal, + typeof) from jax._src.dtypes import dtype, float0 +from jax._src.state.types import AbstractRef from jax._src.util import (unzip2, safe_map, safe_zip, split_list, wrap_name, as_hashable_function, weakref_lru_cache, partition_list, subs_list2, foreach) +Array = Any +Ref = Any zip = safe_zip map = safe_map def identity(x): return x def _update_annotation( f: lu.WrappedFun, - orig_type: tuple[tuple[core.AbstractValue, bool], ...] | None, - explicit_nonzeros: list[bool] + orig_type: tuple[core.AbstractValue, ...] | None, + nonzeros: list[bool] ) -> lu.WrappedFun: if orig_type is None: return f - # By convention, `explicit_nonzeros` only accounts for explicit arguments. - assert len(explicit_nonzeros) == sum(explicit for _, explicit in orig_type) - # Implicit arguments never have tangents, so generate the tangent part of the - # type annotation from explicit arguments only. - explicit_avals = [aval for aval, explicit in orig_type if explicit] - tan_types = [(aval.to_tangent_aval(), True) - for nz, aval in zip(explicit_nonzeros, explicit_avals) if nz] + tan_types = [aval.to_tangent_aval() for nz, aval in zip(nonzeros, orig_type) if nz] return lu.annotate(f, (*orig_type, *tan_types)) def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True, @@ -73,6 +71,7 @@ def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True, def jvpfun(f: Callable, instantiate, transform_stack, primals, tangents): tag = core.TraceTag() tangents = [Zero.from_primal_value(t) if not isinstance(t, Zero) + and isinstance(typeof(t), core.ShapedArray) and dtype(t) == float0 else t for t in tangents] ctx = (source_info_util.transform_name_stack('jvp') if transform_stack else contextlib.nullcontext()) @@ -89,12 +88,14 @@ def linearize_subtrace(_f: Callable, _store: lu.Store, _tag: core.TraceTag, nzs_in: Sequence[bool], debug_info: core.DebugInfo, *primals, **params): + source_info = source_info_util.current() with core.take_current_trace() as parent_trace: - tangent_trace = pe.DynamicJaxprTrace(debug_info) + tangent_trace = pe.DynamicJaxprTrace(debug_info, auto_dce=True) tangent_trace.tag = _tag linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=_tag) tracers = [LinearizeTracer(linearize_trace, p, - tangent_trace.new_arg(get_aval(p).to_tangent_aval())) + tangent_trace.new_arg(get_aval(p).to_tangent_aval(), + source_info)) if nz else p for p, nz in zip(primals, nzs_in)] with core.set_current_trace(linearize_trace, check_leaks=True): @@ -103,10 +104,8 @@ def linearize_subtrace(_f: Callable, _store: lu.Store, _tag: core.TraceTag, del linearize_trace, ans, tracers nzs_out = tuple(type(t) is not Zero for t in out_tangents) out_tangents = tuple(t for t, nz in zip(out_tangents, nzs_out) if nz) - out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents) # type: ignore[assignment] - jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents, debug_info) - if attrs_tracked: - raise NotImplementedError("TODO: attrs") + out_tangents = map(partial(tangent_trace.to_jaxpr_tracer, source_info=source_info), out_tangents) # type: ignore[assignment] + jaxpr, consts = tangent_trace.to_jaxpr(out_tangents, debug_info.with_unknown_names(), source_info) which_env = [(isinstance(c, pe.DynamicJaxprTracer) and getattr(c._trace, 'tag', None) is _tag) for c in consts] jaxpr = pe.move_envvars(jaxpr, tuple(which_env)) @@ -146,75 +145,105 @@ def jvp_subtrace_aux(f, store, tag, primals, tangents): store.store(aux_primals) return out_primals, out_tangents -def convert_constvars_jaxpr_constvars_at_end(jaxpr: core.Jaxpr) -> core.Jaxpr: - dbg = jaxpr.debug_info._replace( - arg_names=jaxpr.debug_info.arg_names + ("",) * len(jaxpr.constvars)) - return core.Jaxpr(constvars=(), - invars=jaxpr.invars + jaxpr.constvars, - outvars=jaxpr.outvars, eqns=jaxpr.eqns, - effects=jaxpr.effects, debug_info=dbg) - def linearize_jaxpr( jaxpr: core.ClosedJaxpr, - nonzeros: Sequence[bool] - ) -> tuple[core.ClosedJaxpr, int, Sequence[bool], core.ClosedJaxpr]: - return _linearize_jaxpr(jaxpr, tuple(nonzeros)) + nonzeros: Sequence[bool], + instantiate: bool | Sequence[bool] = False, + allow_fwds: bool | Sequence[bool] = True, +) -> tuple[core.ClosedJaxpr, int, Sequence[bool], Sequence[int | None], core.ClosedJaxpr]: + if type(allow_fwds) is bool: + allow_fwds = (allow_fwds,) * (len(jaxpr.consts) + len(jaxpr.jaxpr.invars)) + assert len(allow_fwds) == (len(jaxpr.consts) + len(jaxpr.jaxpr.invars)) + if type(instantiate) is bool: + instantiate = (instantiate,) * len(jaxpr.jaxpr.outvars) + assert len(instantiate) == len(jaxpr.jaxpr.outvars) + return _linearize_jaxpr(jaxpr, tuple(nonzeros), tuple(instantiate), + tuple(allow_fwds)) @weakref_lru_cache @source_info_util.reset_name_stack() def _linearize_jaxpr( jaxpr: core.ClosedJaxpr, - nonzeros: tuple[bool, ...] - ) -> tuple[core.ClosedJaxpr, int, Sequence[bool], core.ClosedJaxpr]: + nonzeros: tuple[bool, ...], + instantiate: tuple[bool, ...], + allow_fwds: tuple[bool, ...], +) -> tuple[core.ClosedJaxpr, int, Sequence[bool], Sequence[int | None], core.ClosedJaxpr]: dbg = jaxpr.jaxpr.debug_info + config.enable_checks.value and dbg.assert_arg_names(len(nonzeros)) primal_trace = pe.DynamicJaxprTrace(dbg) - tangent_trace = pe.DynamicJaxprTrace(dbg) + tangent_trace = pe.DynamicJaxprTrace(dbg, auto_dce=True) lin_trace = LinearizeTrace(primal_trace, tangent_trace) tangent_trace.tag = lin_trace.tag - def new_arg(trace, primal_aval, nz): - primal = primal_trace.new_arg(primal_aval) + def new_arg(trace, primal_aval, nz, source_info): + primal = primal_trace.new_arg(primal_aval, source_info) tangent_aval = primal_aval.to_tangent_aval() - tangent = tangent_trace.new_arg(tangent_aval) if nz else Zero(tangent_aval) + tangent = tangent_trace.new_arg(tangent_aval, source_info) if nz else Zero(tangent_aval) return LinearizeTracer(trace, primal, tangent) - tracers = [new_arg(lin_trace, v.aval, nz) - for (v, nz) in zip(jaxpr.jaxpr.invars, nonzeros)] + source_info = source_info_util.current() + tracers = [new_arg(lin_trace, a, nz, source_info) + for (a, nz) in zip(jaxpr.in_aval_qdds, nonzeros)] + in_primals = [t.primal for t in tracers] with core.set_current_trace(lin_trace, check_leaks=True): ans = core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *tracers) out_primals, out_tangents = unzip2(map(lin_trace.to_primal_tangent_pair, ans)) - del lin_trace, ans, tracers, new_arg + out_tangents = [instantiate_zeros(t) if inst else t + for t, inst in zip(out_tangents, instantiate)] + del lin_trace, ans, new_arg, tracers - debug_info = jaxpr.jaxpr.debug_info + # pe._check_no_returned_refs(debug_info, out_tangents) nzs_out = [type(t) is not Zero for t in out_tangents] - out_tangents = tuple(tangent_trace.to_jaxpr_tracer(t) - for (nz, t) in zip(nzs_out, out_tangents) if nz) - tangent_jaxpr, tangent_consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents, debug_info) + out_tangents = [tangent_trace.to_jaxpr_tracer(t, source_info) + for (nz, t) in zip(nzs_out, out_tangents) if nz] + tangent_jaxpr, tangent_consts = tangent_trace.to_jaxpr( + out_tangents, dbg.with_unknown_names(), source_info) tangent_trace.invalidate() - if attrs_tracked: - raise NotImplementedError("TODO: attrs") - residuals_and_primals = (*tangent_consts, *out_primals) - residuals_and_primals = map(primal_trace.to_jaxpr_tracer, residuals_and_primals) - primal_jaxpr, primal_consts, attrs_tracked = primal_trace.to_jaxpr(residuals_and_primals, debug_info) + tangent_jaxpr, tangent_consts = _dce_consts(tangent_jaxpr, tangent_consts) + tangent_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(tangent_jaxpr)) + + fwd_inputs = (*jaxpr.consts, *in_primals) + id_map = {id(x):i for i, (x,a) in enumerate(zip(fwd_inputs, allow_fwds)) if a} + fwds = [id_map.get(id(c)) for c in tangent_consts] + tangent_consts = [c for c, f in zip(tangent_consts, fwds) if f is None] + del in_primals + + # pe._check_no_returned_refs(debug_info, out_primals) + primals_and_residuals = *out_primals, *tangent_consts + primals_and_residuals = map(partial(primal_trace.to_jaxpr_tracer, source_info=source_info), + primals_and_residuals) + primal_jaxpr, primal_consts = primal_trace.to_jaxpr( + primals_and_residuals, dbg.with_unknown_names(), + source_info) primal_trace.invalidate() - num_residuals = len(tangent_consts) - tangent_jaxpr = pe.close_jaxpr(convert_constvars_jaxpr_constvars_at_end(tangent_jaxpr)) - if attrs_tracked: - raise NotImplementedError("TODO: attrs") - return core.ClosedJaxpr(primal_jaxpr, primal_consts), num_residuals, nzs_out, tangent_jaxpr - -def direct_linearize(traceable: lu.WrappedFun, - primals, kwargs, *, has_aux=False, tag=None): + primal_jaxpr, primal_consts = _dce_consts(primal_jaxpr, primal_consts) + primal_jaxpr = core.ClosedJaxpr(primal_jaxpr, primal_consts) + + num_residuals_out = len(tangent_consts) + return primal_jaxpr, num_residuals_out, nzs_out, fwds, tangent_jaxpr + +def _dce_consts(jaxpr, consts): + jaxpr, used_consts, _ = pe.dce_jaxpr_consts( + jaxpr, [True] * len(jaxpr.outvars), + [False] * len(jaxpr.constvars) + [True] * len(jaxpr.invars)) + return jaxpr, [c for c, used in zip(consts, used_consts) if used] + +def direct_linearize(traceable: lu.WrappedFun, primals, kwargs, *, + has_aux=False, tag=None): + dbg = traceable.debug_info.with_unknown_names() with core.take_current_trace() as parent_trace: - tangent_trace = pe.DynamicJaxprTrace(traceable.debug_info) - tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval()) for p in primals] - tangents = [Zero.from_primal_value(t) if dtype(t) == float0 else t for t in tangents] + source_info = source_info_util.current() + tangent_trace = pe.DynamicJaxprTrace(dbg, auto_dce=True) + tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval(), source_info) for p in primals] + tangents = [Zero.from_primal_value(t) if not isinstance(t, Zero) + and isinstance(typeof(t), core.ShapedArray) + and dtype(t) == float0 else t for t in tangents] linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=tag) tangent_trace.tag = linearize_trace.tag tracers = [LinearizeTracer(linearize_trace, p, t) for p, t in zip(primals, tangents)] tracers = [t.full_lower() for t in tracers] - with (core.set_current_trace(linearize_trace, check_leaks=True), + with (core.set_current_trace(linearize_trace), source_info_util.transform_name_stack('jvp')): if has_aux: ans, aux = traceable.call_wrapped(*tracers) @@ -229,9 +258,11 @@ def direct_linearize(traceable: lu.WrappedFun, del linearize_trace, ans, tracers out_nzs = [type(t) is not Zero for t in out_tangents] out_nz_tangents = [t for t, nz in zip(out_tangents, out_nzs) if nz] - out_nz_tangents = map(tangent_trace.to_jaxpr_tracer, out_nz_tangents) - jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_nz_tangents, traceable.debug_info) + out_nz_tangents = map(partial(tangent_trace.to_jaxpr_tracer, + source_info=source_info), out_nz_tangents) + jaxpr, consts = tangent_trace.to_jaxpr(out_nz_tangents, dbg, source_info) tangent_trace.invalidate() + config.enable_checks.value and core.check_jaxpr(jaxpr) jaxpr, used_consts, _ = pe.dce_jaxpr_consts( jaxpr, [True] * len(jaxpr.outvars), [False] * len(jaxpr.constvars) + [True] * len(jaxpr.invars)) @@ -239,15 +270,12 @@ def direct_linearize(traceable: lu.WrappedFun, out_tangents_pvals = [pe.PartialVal.unknown(core.get_aval(t)) if nz else pe.PartialVal.known(zeros_like_aval(t.aval)) for t, nz in zip(out_tangents, out_nzs)] - if attrs_tracked: - raise NotImplementedError("TODO: attrs") if has_aux: return out_primals, out_tangents_pvals, jaxpr, consts, aux_primals else: return out_primals, out_tangents_pvals, jaxpr, consts -def linearize(traceable: lu.WrappedFun, - *primals, **kwargs): +def linearize(traceable: lu.WrappedFun, *primals, **kwargs): has_aux = kwargs.pop('has_aux', False) if config.use_direct_linearize.value: return direct_linearize(traceable, primals, kwargs, has_aux=has_aux) @@ -267,164 +295,13 @@ def linearize(traceable: lu.WrappedFun, raise ValueError( "Linearization failed to produce known values for all output primals. " "This is typically caused by attempting to differentiate a function " - "uses an operation that does not support reverse-mode autodiff.") + "using an operation that does not support reverse-mode autodiff.") out_primals_consts = [pval.get_known() for pval in out_primals_pvals] if not has_aux: return out_primals_consts, out_tangents_pvals, jaxpr, consts else: return out_primals_consts, out_tangents_pvals, jaxpr, consts, aux() -def vjp(traceable: lu.WrappedFun, primals, has_aux=False): - if not has_aux: - out_primals, pvals, jaxpr, consts = linearize(traceable, *primals) - else: - out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True) - - def unbound_vjp(pvals, jaxpr, consts, *cts): - cts = tuple(ct for ct, pval in zip(cts, pvals) if not pval.is_known()) - dummy_args = [UndefinedPrimal(v.aval) for v in jaxpr.invars] - arg_cts = backward_pass(jaxpr, True, consts, dummy_args, cts) - return map(instantiate_zeros, arg_cts) - - # Ensure that vjp_ is a PyTree so that we can pass it from the forward to the backward - # pass in a custom VJP. - vjp_ = Partial(partial(unbound_vjp, pvals, jaxpr), consts) - if not has_aux: - return out_primals, vjp_ - else: - return out_primals, vjp_, aux - -def unpair_pval(pval): - aval, const = pval - const_1, const_2 = const - if aval is None: - return (None, const_1), (None, const_2) - else: - aval_1, aval_2 = aval - return (aval_1, const_1), (aval_2, const_2) - -# NOTE: The FIXMEs below are caused by primal/tangent mixups (type -# errors if you will) -def backward_pass(jaxpr: core.Jaxpr, transform_stack, - consts, primals_in, cotangents_in): - if all(type(ct) is Zero for ct in cotangents_in) and not jaxpr.effects: - return map(lambda v: Zero(v.aval), jaxpr.invars) - - def write_cotangent(prim, v, ct): - # assert v not in primal_env - assert ct is not Zero, (prim, v.aval) # check for an old harmless type error - if ct is None or type(v) is Literal: - return - if type(ct) is Zero: - # FIXME: This triggers a lot of failures! - # assert v.aval == ct.aval, (prim, v.aval, ct.aval) - return - ct_env[v] = add_tangents(ct_env[v], ct) if v in ct_env else ct - # TODO(mattjj): add back these checks for dynamic shapes - # if config.enable_checks.value: - # ct_aval = core.get_aval(ct_env[v]) - # joined_aval = core.lattice_join(v.aval, ct_aval).strip_weak_type() - # assert v.aval.strip_weak_type() == joined_aval, (prim, v.aval, ct_aval) - - def read_cotangent(v): - return ct_env.pop(v, Zero(v.aval.to_tangent_aval())) - - def read_primal(v): - if type(v) is Literal: - return v.val - else: - a = v.aval - if type(a) is core.DShapedArray: - shape = [primal_env[d] if type(d) is core.Var else d for d in a.shape] - a = a.update(shape=tuple(shape)) - return primal_env.get(v, UndefinedPrimal(a)) - - def write_primal(v, val): - if not is_undefined_primal(val): - primal_env[v] = val - - primal_env: dict[Any, Any] = {} - foreach(write_primal, jaxpr.constvars, consts) - foreach(write_primal, jaxpr.invars, primals_in) - - # Start with a forward pass to evaluate any side-effect-free JaxprEqns that - # only operate on primals. This is required to support primitives with - # linearization rules that include computations on the residuals. - lin_eqns = [] - for eqn in jaxpr.eqns: - # TODO (dfm): The effects check is probably stricter than necessary. - # Consider adding an allowlist of effects here. - if jaxpr.effects or any( - type(x) is not Literal and x not in primal_env for x in eqn.invars): - lin_eqns.append(eqn) - continue - subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params) - name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack - traceback = eqn.source_info.traceback - with source_info_util.user_context( - traceback, name_stack=name_stack), eqn.ctx.manager: - ans = eqn.primitive.bind(*subfuns, *map(read_primal, eqn.invars), **bind_params) - if eqn.primitive.multiple_results: - foreach(write_primal, eqn.outvars, ans) - else: - write_primal(eqn.outvars[0], ans) - - ct_env: dict[Any, Any] = {} - ctx = (source_info_util.transform_name_stack('transpose') if transform_stack - else contextlib.nullcontext()) - with ctx: - foreach(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in) - for eqn in lin_eqns[::-1]: - if eqn.primitive.ref_primitive: - if eqn.primitive is core.mutable_array_p: - val_var, = eqn.invars - ref_var, = eqn.outvars - ref = read_primal(ref_var) - ct_out = core.freeze(ref) - write_cotangent(eqn.primitive, val_var, ct_out) - elif eqn.primitive is core.freeze_p: - val_var, = eqn.outvars - ref_var, = eqn.invars # type: ignore - ct_in = instantiate_zeros(read_cotangent(val_var)) - write_primal(ref_var, core.mutable_array(ct_in)) - continue - - invals = map(read_primal, eqn.invars) - if eqn.primitive.multiple_results: - cts_in = map(read_cotangent, eqn.outvars) - else: - cts_in, = map(read_cotangent, eqn.outvars) - name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack - with source_info_util.user_context( - eqn.source_info.traceback, name_stack=name_stack), eqn.ctx.manager: - if eqn.primitive.call_primitive or eqn.primitive.map_primitive: - cts_in_avals = [v.aval for v in eqn.outvars] - params = dict(eqn.params) - call_jaxpr = params.pop('call_jaxpr') - cts_out = get_primitive_transpose(eqn.primitive)( - params, call_jaxpr, invals, cts_in, cts_in_avals) - else: - try: - cts_out = get_primitive_transpose(eqn.primitive)( - cts_in, *invals, **eqn.params) - except (FloatingPointError, ZeroDivisionError) as e: - msg = "When differentiating the code at the top of the callstack:" - if msg not in e.args[0]: - e.args = e.args[0] + f'\n{msg}', - e.args = e.args[0] + f'\n{source_info_util.summarize(eqn.source_info)}', - raise e from None - cts_out = [Zero(v.aval) for v in eqn.invars] if cts_out is Zero else cts_out - # FIXME: Some invars correspond to primals! - foreach(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out) - - cotangents_out = map(read_cotangent, jaxpr.invars) - return cotangents_out - -def closed_backward_pass(jaxpr: core.ClosedJaxpr, transform_stack, - primals_in, cotangents_in): - return backward_pass(jaxpr.jaxpr, transform_stack, jaxpr.consts, - primals_in, cotangents_in) - class UndefinedPrimal: __slots__ = ['aval'] @@ -448,6 +325,205 @@ def get_primitive_transpose(p): "Transpose rule (for reverse-mode differentiation) for '{}' " "not implemented".format(p)) from err + +def backward_pass3( + jaxpr: core.Jaxpr, transform_stack: bool, + consts: Sequence[Array], primals_in: Sequence[Array | Ref | GradAccum], + cotangents_in: Sequence[Array]) -> None: + if all(type(ct) is Zero for ct in cotangents_in) and not jaxpr.effects: + return + + env: dict = dict(zip((*jaxpr.constvars, *jaxpr.invars), + (*consts, *primals_in))) + + def read(x: core.Atom) -> Array | GradAccum: + return x.val if isinstance(x, Literal) else env[x] + + lin_eqns = [] + for eqn in jaxpr.eqns: + if eqn.primitive.ref_primitive: + v, = eqn.outvars + lin_eqns.append(eqn) + if eqn.primitive is core.ref_p: + env[v] = RefAccum(v.aval.inner_aval) # type: ignore + elif eqn.primitive is core.freeze_p: + env[v] = ValAccum(v.aval) + elif eqn.primitive is core.accum_grad_in_ref_p: + env[v] = RefAccum(v.aval) + else: + assert False + elif any(isinstance(read(x), GradAccum) for x in eqn.invars): + for v in eqn.outvars: + env[v] = ValAccum(v.aval) + lin_eqns.append(eqn) + else: + subfuns, params = eqn.primitive.get_bind_params(eqn.params) + with eqn.ctx.manager, _name_stack_ctx(eqn.source_info): + ans = eqn.primitive.bind(*subfuns, *map(read, eqn.invars), **params) + ans = ans if eqn.primitive.multiple_results else [ans] + foreach(env.setdefault, eqn.outvars, ans) + + ctx = (source_info_util.transform_name_stack('transpose') if transform_stack # type: ignore + else contextlib.nullcontext()) + for acc, ct in zip(map(read, jaxpr.outvars), cotangents_in): + if isinstance(acc, GradAccum): + acc.accum(ct) # jaxpr.outvars can have Literals, env can have inst zeros + with ctx: + for eqn in lin_eqns[::-1]: + with eqn.ctx.manager, _name_stack_ctx(eqn.source_info): + if eqn.primitive.ref_primitive: + ct = env.pop(eqn.outvars[0]).freeze() + acc = read(eqn.invars[0]) + if isinstance(acc, GradAccum): + acc.accum(ct) + else: + cts_in = [env.pop(v).freeze() for v in eqn.outvars] + if not eqn.primitive.multiple_results: + cts_in, = cts_in + if eqn.primitive in fancy_transposes: + rule = fancy_transposes[eqn.primitive] + rule(cts_in, *map(read, eqn.invars), **eqn.params) + else: + rule = get_primitive_transpose(eqn.primitive) + primals = map(read, eqn.invars) + up = lambda x: UndefinedPrimal(x.aval) if isinstance(x, GradAccum) else x + if eqn.primitive.call_primitive or eqn.primitive.map_primitive: + # TODO(mattjj,dougalm): remove this path by revising call/map trans + cts_in_avals = [v.aval for v in eqn.outvars] + params = dict(eqn.params) + call_jaxpr = params.pop('call_jaxpr') + cts_out = rule(params, call_jaxpr, map(up, primals), cts_in, cts_in_avals) + else: + cts_out = rule(cts_in, *map(up, primals), **eqn.params) + for x, ct in zip(primals, cts_out): + if isinstance(x, GradAccum): + x.accum(ct) + +def _name_stack_ctx(src_info): + stack = source_info_util.current_name_stack() + src_info.name_stack + return source_info_util.user_context(src_info.traceback, name_stack=stack) + +class GradAccum: + aval: core.AbstractValue + + def accum(self, x) -> None: + assert False + def freeze(self) -> Array | Zero: + assert False + +class RefAccum(GradAccum): + aval: core.AbstractValue + ref: AbstractRef | None + + def __init__(self, aval, ref=None): + self.aval = aval + self.ref = ref + + def accum(self, x): + assert x is not Zero + if isinstance(x, Zero) or x is None: + return + if self.ref is None: + self.ref = core.new_ref(x) + else: + ct_check(self, x) + self.ref.addupdate(x) + + def freeze(self): + if self.ref is None: + return Zero(self.aval) + else: + return core.freeze(self.ref) + + def inst(self): + if self.ref is None: + self.ref = core.new_ref(zeros_like_aval(self.aval)) + return self + +class ValAccum(GradAccum): + aval: core.AbstractValue + val: Array | Zero + + def __init__(self, aval, val=None): + self.aval = aval + self.val = Zero(aval) if val is None else val + + def accum(self, x): + if x is not None: + ct_check(self, x) + self.val = add_tangents(self.val, x) + + def freeze(self): + return self.val + +def ct_check(primal, ct): + if config.disable_bwd_checks.value: + return + ct_aval = ct.aval if type(ct) is Zero else typeof(ct) + ct_aval_expected = primal.aval.to_cotangent_aval() # type: ignore + if not core.typematch(ct_aval, ct_aval_expected, only_shape_shd_check=True): + # TODO(yashkatariya, mattjj): Add primitive name here for + # better error message? + raise ValueError( + f"Input primal JAX type to VJP function is" + f" {primal.aval.str_short()}. Hence the expected" + f" cotangent type is {ct_aval_expected.str_short()} but" + f" got {ct_aval.str_short()}") + +class NullAccum(GradAccum): + def __init__(self): pass + def accum(self, x): return + def freeze(self): assert False + +fancy_transposes: dict[core.Primitive, Callable] = {} + +def project_accums(args): + result, specs = [], [] + for x in args: + if isinstance(x, ValAccum): + specs.append((ValAccum, x.aval)) + elif isinstance(x, RefAccum): + specs.append((RefAccum, x.aval)) + result.append(x.inst().ref) + else: + specs.append((None, typeof(x))) + result.append(x) + return result, tuple(specs) + +def unproject_accums(specs, result): + args, result_ = [], iter(result) + for k, aval in specs: + if k is ValAccum: + args.append(ValAccum(aval)) + elif k is RefAccum: + args.append(RefAccum(aval, next(result_))) + elif k is None: + args.append(next(result_)) + else: + assert False + assert next(result_, None) is None + return args + +def accum_typeof(x): + if isinstance(x, GradAccum): + return x.aval + else: + return typeof(x) + +# TODO(mattjj): this is for for backward (get it?) compatibility. Remove, maybe. +def backward_pass(jaxpr, transform_stack: bool, consts, primals_in, cts_in): + primals_in = [ValAccum(x.aval) if isinstance(x, UndefinedPrimal) else x + for x in primals_in] + backward_pass3(jaxpr, transform_stack, consts, primals_in, cts_in) + return [x.freeze() if isinstance(x, ValAccum) else None + for x in primals_in] + +def closed_backward_pass(jaxpr: core.ClosedJaxpr, transform_stack, + primals_in, cotangents_in): + return backward_pass(jaxpr.jaxpr, transform_stack, jaxpr.consts, + primals_in, cotangents_in) + + @lu.transformation_with_aux2 def nonzero_tangent_outputs(f, store, *args, **kwargs): results = (_, tangents_out) = f(*args, **kwargs) @@ -460,6 +536,7 @@ def __init__(self, parent_trace, tag): super().__init__() self.tag = tag self.parent_trace = parent_trace + self.requires_low = False def to_primal_tangent_pair(self, val): if isinstance(val, JVPTracer) and val._trace.tag is self.tag: @@ -470,7 +547,9 @@ def to_primal_tangent_pair(self, val): def process_primitive(self, primitive, tracers, params): primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers)) - if all(type(t) is Zero for t in tangents_in): + if (all(type(t) is Zero for t in tangents_in) and + primitive is not core.ref_p and + not any(isinstance(typeof(x), AbstractRef) for x in primals_in)): return primitive.bind_with_trace(self.parent_trace, primals_in, params) jvp = primitive_jvps.get(primitive) if not jvp: @@ -484,6 +563,11 @@ def process_primitive(self, primitive, tracers, params): else: return maybe_jvp_tracer(self, primal_out, tangent_out) + def cur_qdd(self, x): + p, _ = self.to_primal_tangent_pair(x) + with core.set_current_trace(self.parent_trace): + return core.cur_qdd(p) + def process_call(self, call_primitive, f, tracers, params): assert call_primitive.multiple_results primals, tangents = unzip2(map(self.to_primal_tangent_pair, tracers)) @@ -507,7 +591,7 @@ def new_out_axes_thunk(): f_jvp, out_tree = traceable(f_jvp, in_tree) update_params = call_param_updaters.get(call_primitive) new_params = update_params(params, which_nz) if update_params else params - fun_and_args = (_update_annotation(f_jvp, f.in_type, which_nz),) + tuple(args) + fun_and_args = (_update_annotation(f_jvp.with_unknown_names(), f.in_type, which_nz),) + tuple(args) result = call_primitive.bind_with_trace(self.parent_trace, fun_and_args, new_params) primal_out, tangent_out = tree_unflatten(out_tree(), result) tangent_out = [Zero.from_primal_value(p) if t is None else t @@ -547,15 +631,20 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, with core.set_current_trace(self.parent_trace): res_and_primals_out = fwd.call_wrapped(*fwd_in) - _, res_tree = out_trees() - res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves]) + _, res_tree, input_fwds = out_trees() + num_res_out = res_tree.num_leaves - sum(f is not None for f in input_fwds) + res_out, primals_out = split_list(res_and_primals_out, [num_res_out]) + res_out_ = iter(res_out) + res = [next(res_out_) if f is None else primals_in[f] for f in input_fwds] + assert next(res_out_, None) is None + avals_out = [core.get_aval(x).to_tangent_aval() for x in primals_out] - # TODO(frostig,mattjj): avoid instantiating zeros when we don't have to! + in_zeros = [type(t) is Zero for t in tangents_in] + nz_tangents_in = [t for z, t in zip(in_zeros, tangents_in) if not z] with core.set_current_trace(self.parent_trace): - tangents_in = map(instantiate_zeros, tangents_in) tangents_out = custom_lin_p.bind( - *res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd, - out_avals=avals_out, symbolic_zeros=symbolic_zeros) + *res, *nz_tangents_in, num_res=res_tree.num_leaves, bwd=bwd, + out_avals=avals_out, symbolic_zeros=symbolic_zeros, in_zeros=in_zeros) return map(partial(maybe_jvp_tracer, self), primals_out, tangents_out) def process_custom_transpose(self, prim, call, tracers, **params): @@ -591,7 +680,9 @@ def process_custom_transpose(self, prim, call, tracers, **params): return map(partial(maybe_jvp_tracer, self), ps_out, ts_out) def maybe_jvp_tracer(trace, primal, tangent): - if type(tangent) is Zero or dtype(tangent) == float0: + if (type(tangent) is Zero or + isinstance(typeof(tangent), core.ShapedArray) + and dtype(tangent) == float0): return primal else: return JVPTracer(trace, primal, tangent) @@ -606,10 +697,18 @@ def __init__(self, trace, primal, tangent): self.primal = primal self.tangent = tangent + def _short_repr(self): + pp = lambda x: x._short_repr() if isinstance(x, Tracer) else str(x) + primal, tangent = pp(self.primal), pp(self.tangent) + return f'JVPTracer({primal=!s}, {tangent=!s})' + @property def aval(self): return get_aval(self.primal) + def cur_qdd(self): + return core.cur_qdd(self.primal) + def full_lower(self): if type(self.tangent) is Zero: return core.full_lower(self.primal) @@ -622,13 +721,26 @@ def to_concrete_value(self): def get_referent(self): return core.get_referent(self.primal) + def type_state(self): + return self.primal.type_state() + def _primal_tangent_shapes_match(primal, tangent): if type(tangent) is not Zero: primal_aval = get_aval(primal).strip_weak_type() tangent_aval = get_aval(tangent).strip_weak_type() - assert core.definitely_equal_shape(primal_aval.shape, tangent_aval.shape), (primal_aval.shape, tangent_aval.shape) + if not isinstance(primal_aval, core.ShapedArray): + return # TODO(mattjj,dougalm) + assert core.definitely_equal_shape(primal_aval.shape, tangent_aval.shape), ( + primal_aval.shape, tangent_aval.shape) expected_tangent_dtype = core.primal_dtype_to_tangent_dtype(primal_aval.dtype) - assert expected_tangent_dtype == tangent_aval.dtype, (expected_tangent_dtype, tangent_aval.dtype) + assert expected_tangent_dtype == tangent_aval.dtype, ( + expected_tangent_dtype, tangent_aval.dtype) + if (not primal_aval.sharding.mesh.empty and + not tangent_aval.sharding.mesh.empty and + (primal_aval.sharding.mesh._any_axis_explicit or + tangent_aval.sharding.mesh._any_axis_explicit)): + assert primal_aval.sharding == tangent_aval.sharding, ( + primal_aval.sharding, tangent_aval.sharding) call_param_updaters: dict[core.Primitive, Callable] = {} call_linearize_param_updaters: dict[core.Primitive, Callable] = {} @@ -644,6 +756,7 @@ def __init__(self, parent_trace, tangent_trace, tag=None): self.parent_trace = parent_trace self.tangent_trace = tangent_trace self._name_stack_prefix_len = len(source_info_util.current_name_stack()) + self.requires_low = False def _name_stack_suffix(self): return source_info_util.current_name_stack()[self._name_stack_prefix_len:] @@ -658,7 +771,9 @@ def to_primal_tangent_pair(self, val): def process_primitive(self, primitive, args, params): primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, args)) tangent_nzs = [type(t) is not Zero for t in tangents_in] - if all(type(t) is Zero for t in tangents_in): + if (all(type(t) is Zero for t in tangents_in) and + primitive is not core.ref_p and + not any(isinstance(typeof(x), AbstractRef) for x in primals_in)): return primitive.bind_with_trace(self.parent_trace, primals_in, params) fallback = partial(fallback_linearize_rule, primitive) lin = primitive_linearizations.get(primitive, fallback) @@ -674,6 +789,11 @@ def process_primitive(self, primitive, args, params): else: return maybe_linearize_tracer(self, primal_out, tangent_nzs_out, tangent_out) + def cur_qdd(self, x): + p, _ = self.to_primal_tangent_pair(x) + with core.set_current_trace(self.parent_trace): + return core.cur_qdd(p) + def process_custom_jvp_call(self, prim, fun: lu.WrappedFun, f_jvp: lu.WrappedFun, tracers, *, symbolic_zeros: bool): @@ -703,7 +823,7 @@ def _f_jvp(primals, tangents): def process_custom_vjp_call(self, prim, fun, fwd, bwd: lu.WrappedFun, tracers, - out_trees: Callable[[], Sequence[PyTreeDef]], + out_trees: Callable[[], tuple[PyTreeDef, PyTreeDef, list[int | None]]], symbolic_zeros: bool): primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers)) if all(type(t) is Zero for t in tangents_in): @@ -715,20 +835,24 @@ def process_custom_vjp_call(self, prim, fun, fwd, with core.set_current_trace(self.parent_trace): res_and_primals_out = fwd.call_wrapped(*fwd_in_flat) - _, res_tree = out_trees() - res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves]) + _, res_tree, input_fwds = out_trees() + num_res_out = res_tree.num_leaves - sum(f is not None for f in input_fwds) + res_out, primals_out = split_list(res_and_primals_out, [num_res_out]) + res_out_ = iter(res_out) + res = [next(res_out_) if f is None else primals_in[f] for f in input_fwds] + assert next(res_out_, None) is None avals_out = [core.get_aval(x).to_tangent_aval() for x in primals_out] - tangents_in_zeros = map(instantiate_zeros, tangents_in) + in_zeros = [type(t) is Zero for t in tangents_in] + nz_tangents_in = [t for z, t in zip(in_zeros, tangents_in) if not z] with core.set_current_trace(self.tangent_trace): tangents_out = custom_lin_p.bind( - *res, *tangents_in_zeros, num_res=res_tree.num_leaves, bwd=bwd, - out_avals=avals_out, symbolic_zeros=symbolic_zeros) + *res, *nz_tangents_in, num_res=res_tree.num_leaves, bwd=bwd, + out_avals=avals_out, symbolic_zeros=symbolic_zeros, in_zeros=in_zeros) tangent_nzs_out = [type(t) is not Zero for t in tangents_out] return map(partial(maybe_linearize_tracer, self), primals_out, tangent_nzs_out, tangents_out) - def process_call(self, call_primitive, f: lu.WrappedFun, - tracers, params): + def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params): assert call_primitive.multiple_results primals, tangents = unzip2(map(self.to_primal_tangent_pair, tracers)) nzs_in = tuple(type(t) is not Zero for t in tangents) @@ -746,7 +870,8 @@ def new_out_axes_thunk(): else: primal_params = params - all_primal_results = call_primitive.bind_with_trace(self.parent_trace, (f_primal, *primals), primal_params) + all_primal_results = call_primitive.bind_with_trace( + self.parent_trace, (f_primal, *primals), primal_params) residual_avals, nzs_out, lin_jaxpr, env, in_fwd, out_fwd = linearize_outs_thunk() num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)) non_fwd_res = all_primal_results[:num_res_out] @@ -772,18 +897,22 @@ def new_out_axes_thunk(): update_params = call_linearize_param_updaters.get(call_primitive) num_new_args = len(residuals) + len(env) - new_params = update_params(params, num_new_args, nzs_in) if update_params else params + new_params = (update_params(params, num_new_args, nzs_in) + if update_params else params) num_residuals = len(residual_avals) - @as_hashable_function(closure=(num_residuals, lin_jaxpr)) - def f_tangent(*args): - consts = args[:num_residuals] - nz_tangents = args[num_residuals:] - return core.eval_jaxpr(lin_jaxpr, consts, *nz_tangents) # TODO(mattjj,dougalm): this tag is read by DynamicJaxprTrace.process_map to # avoid round-tripping the jaxpr and thus getting grad-of-pmap cache misses. - # Remove when we replace the pmap implementation. - f_tangent._pmap_tag = isinstance(call_primitive, core.MapPrimitive) + # Remove the `if` branch when we replace the pmap implementation. + if isinstance(call_primitive, core.MapPrimitive): + @as_hashable_function(closure=(num_residuals, lin_jaxpr)) + def f_tangent(*args): + consts = args[:num_residuals] + nz_tangents = args[num_residuals:] + return core.eval_jaxpr(lin_jaxpr, consts, *nz_tangents) + f_tangent._pmap_tag = isinstance(call_primitive, core.MapPrimitive) + else: + f_tangent = _get_f_tangent(lin_jaxpr, num_residuals) nz_tangents_in = [t for (t, nz) in zip(tangents, nzs_in) if nz] nz_tangents_out = call_primitive.bind_with_trace( @@ -800,6 +929,16 @@ def f_tangent(*args): # that's handled in process_call. process_map = process_call + +@weakref_lru_cache +def _get_f_tangent(lin_jaxpr, num_residuals): + def _f(*args): + consts = args[:num_residuals] + nz_tangents = args[num_residuals:] + return core.eval_jaxpr(lin_jaxpr, consts, *nz_tangents) + return _f + + def maybe_linearize_tracer(trace, primal, is_nonzero, tangent): if is_nonzero: assert not type(tangent) is Zero @@ -829,6 +968,10 @@ def linearize_from_jvp(jvp: lu.WrappedFun, trace = pe.JaxprTrace(parent_trace, current_name_stack, core.TraceTag()) tangent_avals = [get_aval(p).to_tangent_aval() for p in primals] + # map tangents with float0 dtype to symbolic zeros + nonzeros = [nz and not (isinstance(a, core.ShapedArray) and a.dtype == float0) + for a, nz in zip(tangent_avals, nonzeros)] + def make_zero(aval): if instantiate_input_zeros: return zeros_like_aval(aval) @@ -842,10 +985,11 @@ def make_zero(aval): else: zero_type = Zero # type: ignore[assignment] - tangent_args = tuple(trace.new_arg(pe.PartialVal.unknown(aval)) if nz else make_zero(aval) - for aval, nz in zip(tangent_avals, nonzeros)) with core.set_current_trace(trace): - out_primals, out_tangents = jvp.call_wrapped(primals, tangent_args, **params) + tangent_args = [trace.new_arg(pe.PartialVal.unknown(a)) if nz else make_zero(a) + for a, nz in zip(tangent_avals, nonzeros)] + out_primals, out_tangents = jvp.call_wrapped( + tuple(primals), tuple(tangent_args), **params) if not multiple_results: out_primals = [out_primals] @@ -864,7 +1008,13 @@ def make_zero(aval): out_nz_tracers = [trace.to_jaxpr_tracer(r) for (r, nz) in zip(out_tangents, out_nzs) if nz] in_tracers = [t for t, nz in zip(tangent_args, nonzeros) if nz] - jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers, jvp.debug_info) + jaxpr, out_consts, _ = pe.tracers_to_jaxpr( + in_tracers, out_nz_tracers, trace.effect_handles, + jvp.debug_info.with_unknown_names()) + jaxpr, used_consts, _ = pe.dce_jaxpr_consts( + jaxpr, [True] * len(jaxpr.outvars), + [False] * len(jaxpr.constvars) + [True] * len(jaxpr.invars)) + out_consts = [c for used, c in zip(used_consts, out_consts) if used] def linearized(residuals, *tangents): nz_tangents_in = [t for (t, nz) in zip(tangents, nonzeros) if nz] @@ -895,6 +1045,11 @@ def __init__(self, trace, primal, tangent): self.primal = primal self.tangent = tangent + def _short_repr(self): + pp = lambda x: x._short_repr() if isinstance(x, Tracer) else str(x) + primal, tangent = pp(self.primal), typeof(self.tangent).str_short(True) + return f"GradTracer({primal=!s}, typeof(tangent)={tangent!s})" + @property def aval(self): return get_aval(self.primal) @@ -908,6 +1063,12 @@ def full_lower(self): def to_concrete_value(self): return core.to_concrete_value(self.primal) + def get_referent(self): + return core.get_referent(self.primal) + + def cur_qdd(self): + return core.cur_qdd(self.primal) + # -------------------- Primitives -------------------- @@ -930,7 +1091,11 @@ def linear_jvp(primitive, primals, tangents, **params): return val_out, primitive.bind(*tangents, **params) def linear_transpose(transpose_rule, cotangent, *args, **kwargs): - return Zero if type(cotangent) is Zero else transpose_rule(cotangent, **kwargs) + if type(cotangent) is Zero: + return [Zero(x.aval.to_tangent_aval()) if isinstance(x, UndefinedPrimal) + else None for x in args] + else: + return transpose_rule(cotangent, **kwargs) def deflinear2(primitive, transpose_rule): @@ -938,7 +1103,11 @@ def deflinear2(primitive, transpose_rule): primitive_transposes[primitive] = partial(linear_transpose2, transpose_rule) def linear_transpose2(transpose_rule, cotangent, *args, **kwargs): - return Zero if type(cotangent) is Zero else transpose_rule(cotangent, *args, **kwargs) + if type(cotangent) is Zero: + return [Zero(x.aval.to_tangent_aval()) if isinstance(x, UndefinedPrimal) + else None for x in args] + else: + return transpose_rule(cotangent, *args, **kwargs) def defjvp(primitive, *jvprules): @@ -982,15 +1151,18 @@ def defbilinear(prim, lhs_rule, rhs_rule): def bilinear_transpose(lhs_rule, rhs_rule, cotangent, x, y, **kwargs): assert is_undefined_primal(x) ^ is_undefined_primal(y) - if type(cotangent) is Zero: - return Zero if is_undefined_primal(x): - out = lhs_rule(cotangent, x, y, **kwargs) - return Zero if out is Zero else (out, None) + if type(cotangent) is Zero: + return Zero(x.aval), None + else: + out = lhs_rule(cotangent, x, y, **kwargs) + return out, None else: - out = rhs_rule(cotangent, x, y, **kwargs) - return Zero if out is Zero else (None, out) - + if type(cotangent) is Zero: + return None, Zero(y.aval) + else: + out = rhs_rule(cotangent, x, y, **kwargs) + return None, out def defjvp_zero(primitive): assert isinstance(primitive, Primitive) @@ -1024,50 +1196,53 @@ def traceable(f, store, in_tree, *primals_and_tangents): store.store(out_tree) return out_flat +def call_transpose_fancy(primitive, cts, *args, call_jaxpr, **params): + if call_jaxpr.constvars: raise NotImplementedError + primals_ctrefs, specs = project_accums(args) + flat_args, treedef = tree_flatten((primals_ctrefs, cts)) + cell = lambda: None + + @partial(lu.wrap_init, debug_info=call_jaxpr.debug_info.with_unknown_names()) + def transposed(*flat_args): + primals_ctrefs, cts = tree_unflatten(treedef, flat_args) + args = unproject_accums(specs, primals_ctrefs) + backward_pass3(call_jaxpr, False, (), args, cts) + cts_out = [x.freeze() if isinstance(x, ValAccum) else None for x in args] + cts_out, cell.out_tree = tree_flatten(cts_out) # type: ignore + return cts_out -def call_transpose(primitive, params, call_jaxpr: core.Jaxpr, args, ct, _): - if isinstance(call_jaxpr, core.ClosedJaxpr): - call_jaxpr, consts = call_jaxpr.jaxpr, call_jaxpr.consts - else: - consts = () - all_args, in_tree_def = tree_flatten((consts, args, ct)) - fun = lu.hashable_partial(lu.wrap_init( - backward_pass, debug_info=call_jaxpr.debug_info), call_jaxpr, False) - fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def) update_params = call_transpose_param_updaters.get(primitive) if update_params: - params = update_params(params, map(is_undefined_primal, args), - [type(x) is not Zero for x in ct]) - if config.dynamic_shapes.value: - # TODO(mattjj,dougalm): handle consts, for now assume just args - which_lin = [is_undefined_primal(x) for x in args] - res_invars, _ = partition_list(which_lin, call_jaxpr.invars) - new_invars = [*res_invars, *call_jaxpr.outvars] - dbidx_map = {v: core.DBIdx(i) for i, v in enumerate(new_invars)} - in_type = [(v.aval.update(shape=tuple(dbidx_map.get(d, d) for d in v.aval.shape)) # type: ignore[arg-type] - if type(v.aval) is core.DShapedArray else v.aval, True) for v in new_invars] - fun = lu.annotate(fun, tuple(in_type)) - out_flat = primitive.bind(fun, *all_args, **params) - return tree_unflatten(out_tree(), out_flat) -primitive_transposes[core.call_p] = partial(call_transpose, call_p) + params = update_params(params, [isinstance(x, GradAccum) for x in args], + [type(x) is not Zero for x in cts]) + out_flat = primitive.bind(transposed, *flat_args, **params) + for x, ct in zip(args, tree_unflatten(cell.out_tree, out_flat)): # type: ignore + if isinstance(x, ValAccum): x.accum(ct) +fancy_transposes[core.call_p] = partial(call_transpose_fancy, call_p) -def _closed_call_transpose(params, jaxpr, args, ct, cts_in_avals): - jaxpr_, consts = jaxpr.jaxpr, jaxpr.consts +def _closed_call_transpose(ct, *args, call_jaxpr, **params): + jaxpr_, consts = call_jaxpr.jaxpr, call_jaxpr.consts jaxpr_ = pe.convert_constvars_jaxpr(jaxpr_) - return call_transpose(core.closed_call_p, params, jaxpr_, (*consts, *args), - ct, cts_in_avals) -primitive_transposes[core.closed_call_p] = _closed_call_transpose + call_transpose_fancy(core.closed_call_p, ct, *consts, *args, + call_jaxpr=jaxpr_, **params) +fancy_transposes[core.closed_call_p] = _closed_call_transpose @lu.transformation_with_aux2 def nonzero_outputs(f, store, *args, **kwargs): results = f(*args, **kwargs) - store.store([type(r) is not Zero for r in results]) + store.store([not isinstance(r, (Zero, type(None))) for r in results]) return results +# TODO(mattjj): delete this when the original pmap implementation is removed def map_transpose(primitive: core.Primitive, params, call_jaxpr: core.Jaxpr, args, ct, _): + # TODO(mattjj): we should unmap any Zeros in ct according to out_axes, but + # this code path is not long for this world... + args = [x if type(x) is not UndefinedPrimal else + UndefinedPrimal(core.mapped_aval(params['axis_size'], ax, x.aval)) + for x, ax in zip(args, params['in_axes'])] all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts # TODO(necula): use the right debug_info for the backwards pass fun = lu.hashable_partial(lu.wrap_init( @@ -1090,7 +1265,7 @@ def map_transpose(primitive: core.Primitive, params, @as_hashable_function(closure=(in_axes, tuple(type(c) is Zero for c in ct))) def out_axes_thunk(): return tuple(axis or 0 for axis, nz in zip(in_axes, nz_arg_cts()) if nz) - new_params = dict(params, name=wrap_name(params['name'], 'transpose'), + new_params = dict(params, name=wrap_name('transpose', params['name']), in_axes=new_in_axes, out_axes_thunk=out_axes_thunk) del new_params['out_axes'] update_params = call_transpose_param_updaters.get(primitive) @@ -1100,17 +1275,17 @@ def out_axes_thunk(): try: out_flat = primitive.bind(fun, *all_args, **new_params) - except dispatch.InternalFloatingPointError as e: + except api_util.InternalFloatingPointError as e: print("Invalid nan value encountered in the backward pass of a jax.jit " "function. Calling the de-optimized backward pass.") try: - _ = backward_pass(call_jaxpr, None, {}, args, ct) + _ = backward_pass(call_jaxpr, False, (), args, ct) except (FloatingPointError, ZeroDivisionError) as e2: raise e2 from None else: # If control reaches this line, we got a NaN on the output of `compiled` # but not `fun.call_wrapped` on the same arguments. Let's tell the user. - dispatch._raise_no_nan_in_deoptimized(e) + api_util._raise_no_nan_in_deoptimized(e) arg_cts = tree_unflatten(out_tree(), out_flat) # The freevars are being fanned out (not mapped). During transpose the @@ -1120,7 +1295,7 @@ def unmap_zero(zero, in_axis): return (zero if in_axis is None else Zero(core.unmapped_aval(params['axis_size'], in_axis, zero.aval))) arg_cts = (unmap_zero(arg_ct, in_axis) if type(arg_ct) is Zero else - arg_ct if in_axis is not None else + arg_ct if in_axis is not None or arg_ct is None else arg_ct.sum(0) for arg_ct, in_axis in zip(arg_cts, in_axes)) return tuple(arg_cts) @@ -1138,12 +1313,13 @@ def _jvp_jaxpr(jaxpr: core.ClosedJaxpr, nonzeros: Sequence[bool], instantiate: Sequence[bool]): assert len(jaxpr.in_avals) == len(nonzeros) f = lu.wrap_init(core.jaxpr_as_fun(jaxpr), - debug_info=jaxpr.jaxpr.debug_info) + debug_info=jaxpr.jaxpr.debug_info.with_unknown_names()) f_jvp, out_nonzeros = f_jvp_traceable( jvp(f, instantiate=instantiate, transform_stack=False), nonzeros) - tangent_avals = [aval.to_tangent_aval() for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz] - avals_in = list(it.chain(jaxpr.in_avals, tangent_avals)) - jaxpr_out, avals_out, literals_out, () = pe.trace_to_jaxpr_dynamic( + tangent_avals = [aval.to_tangent_aval() + for aval, nz in zip(jaxpr.in_aval_qdds, nonzeros) if nz] + avals_in = list(it.chain(jaxpr.in_aval_qdds, tangent_avals)) + jaxpr_out, avals_out, literals_out = pe.trace_to_jaxpr_dynamic( f_jvp, avals_in) return core.ClosedJaxpr(jaxpr_out, literals_out), out_nonzeros() @@ -1163,19 +1339,24 @@ def f_jvp_traceable(f, store, nonzeros, *primals_and_nztangents): def rearrange_binders(jaxpr: core.ClosedJaxpr, primals_in, tangents_in, primals_out, tangents_out): new_invars = _perm(primals_in, tangents_in, jaxpr.jaxpr.invars) new_outvars = _perm(primals_out, tangents_out, jaxpr.jaxpr.outvars) - new_debug_info = jaxpr.jaxpr.debug_info - new_arg_names = tuple(_perm(primals_in, tangents_in, - jaxpr.jaxpr.debug_info.safe_arg_names(len(jaxpr.jaxpr.invars)))) - new_result_paths = tuple(_perm(primals_out, tangents_out, - jaxpr.jaxpr.debug_info.safe_result_paths(len(jaxpr.jaxpr.outvars)))) - new_debug_info = new_debug_info._replace( - arg_names=new_arg_names, - result_paths=new_result_paths, - ) - new_jaxpr = core.Jaxpr(jaxpr.jaxpr.constvars, - new_invars, new_outvars, jaxpr.jaxpr.eqns, - jaxpr.jaxpr.effects, - new_debug_info) + if jaxpr.jaxpr.debug_info.arg_names is None: + new_arg_names = None + else: + new_arg_names = tuple(_perm(primals_in, tangents_in, + jaxpr.jaxpr.debug_info.arg_names)) + if jaxpr.jaxpr.debug_info.result_paths is None: + new_result_paths = None + else: + new_result_paths = tuple(_perm(primals_out, tangents_out, + jaxpr.jaxpr.debug_info.result_paths)) + new_debug_info = jaxpr.jaxpr.debug_info._replace( + arg_names=new_arg_names, result_paths=new_result_paths) + constvars = jaxpr.jaxpr.constvars + new_effects = pe._renumber_effects( + (*constvars, *new_invars), (*constvars, *jaxpr.jaxpr.invars), + jaxpr.jaxpr.effects) + new_jaxpr = core.Jaxpr(constvars, new_invars, new_outvars, jaxpr.jaxpr.eqns, + new_effects, new_debug_info) return core.ClosedJaxpr(new_jaxpr, jaxpr.consts) def _perm(primal_counts: Sequence[int], tangent_counts: Sequence[int], @@ -1202,7 +1383,7 @@ def raise_custom_vjp_error_on_jvp(*_, **__): def _custom_lin_transpose(cts_out, *invals, num_res, bwd: lu.WrappedFun, out_avals, - symbolic_zeros): + symbolic_zeros, in_zeros): res, _ = split_list(invals, [num_res]) if symbolic_zeros: cts_out = map(replace_internal_symbolic_zeros, cts_out) @@ -1210,9 +1391,18 @@ def _custom_lin_transpose(cts_out, *invals, num_res, cts_out = map(instantiate_zeros, cts_out) cts_in = bwd.call_wrapped(*res, *cts_out) cts_in = map(replace_rule_output_symbolic_zeros, cts_in) - return [None] * num_res + list(cts_in) + nz_cts_in, _ = partition_list(in_zeros, cts_in) + return [None] * num_res + nz_cts_in primitive_transposes[custom_lin_p] = _custom_lin_transpose +def _custom_lin_pp_rule(eqn: core.JaxprEqn, context: core.JaxprPpContext, + settings: core.JaxprPpSettings) -> core.pp.Doc: + params = dict(eqn.params) + params.pop("out_avals") + params["bwd"] = params.pop("bwd").debug_info.func_name + return core._pp_eqn(eqn.replace(params=params), context, settings) +core.pp_eqn_rules[custom_lin_p] = _custom_lin_pp_rule + class CustomJVPException(Exception): def __init__(self): @@ -1238,3 +1428,21 @@ def __init__(self): # TODO(mattjj): remove this vestigial dict reducing_transposes: dict[core.Primitive, Callable] = {} + +# TODO(mattjj): remove this old code, used by something downstream +def call_transpose(primitive, params, call_jaxpr: core.Jaxpr, args, ct, _): + if isinstance(call_jaxpr, core.ClosedJaxpr): + call_jaxpr, consts = call_jaxpr.jaxpr, call_jaxpr.consts + else: + consts = () + all_args, in_treedef = tree_flatten((consts, args, ct)) + fun = lu.hashable_partial( + lu.wrap_init(backward_pass, debug_info=call_jaxpr.debug_info), + call_jaxpr, False) + fun, out_tree = flatten_fun_nokwargs(fun, in_treedef) + update_params = call_transpose_param_updaters.get(primitive) + if update_params: + params = update_params(params, map(is_undefined_primal, args), + [type(x) is not Zero for x in ct]) + out_flat = primitive.bind(fun, *all_args, **params) + return tree_unflatten(out_tree(), out_flat) diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 03c9a95105d7..53b91a75cf4c 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -13,7 +13,6 @@ # limitations under the License. from __future__ import annotations -import collections from collections.abc import Callable, Sequence import dataclasses from functools import partial @@ -21,223 +20,26 @@ import numpy as np -import jax from jax._src import config from jax._src import core +from jax._src.core import typeof from jax._src import source_info_util from jax._src import linear_util as lu from jax._src.partition_spec import PartitionSpec as P from jax._src.sharding_impls import NamedSharding from jax._src import mesh as mesh_lib -from jax._src.ad_util import (Zero, instantiate, SymbolicZero, - replace_rule_output_symbolic_zeros, - add_jaxvals, add_jaxvals_p) +from jax._src.ad_util import Zero, SymbolicZero, add_jaxvals, add_jaxvals_p from jax._src.core import Trace, Tracer, TraceTag, AxisName from jax._src.interpreters import partial_eval as pe -from jax._src.tree_util import (tree_unflatten, tree_flatten, - register_pytree_node, PyTreeDef) +from jax._src.tree_util import (tree_unflatten, tree_flatten, PyTreeDef) from jax._src.typing import Array from jax._src.util import (unzip2, safe_map, safe_zip, split_list, canonicalize_axis, moveaxis, as_hashable_function, curry, memoize, weakref_lru_cache, tuple_insert) - map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip - -# Jumbles - -# i:(Fin 3) => f32[[3, 1, 4].i] -@dataclasses.dataclass(frozen=True) -class JumbleTy: - binder: core.Var - length: int | Tracer | core.Var - elt_ty: core.DShapedArray - def __repr__(self) -> str: - return f'Var{id(self.binder)}:{self.length} => {self.elt_ty}' - replace = dataclasses.replace - -# [3, 1, 4].i -@dataclasses.dataclass(frozen=True) -class IndexedAxisSize: - idx: core.Var - lengths: Array | core.Var | Tracer - def __repr__(self) -> str: - return f'{self.lengths}.Var{id(self.idx)}' - replace = dataclasses.replace - -# Jumble(aval=a:3 => f32[[3 1 4].a], -# data=Array([0., 1., 2., 0., 0., 1., 2., 3.], dtype=float32)) -@dataclasses.dataclass(frozen=True) -class Jumble: - aval: JumbleTy - data: Array - -# To vmap over a jumble, one must specify the axis as JumbleAxis. -class JumbleAxis: pass -jumble_axis = JumbleAxis() - -# As a temporary measure before we have more general JITable / ADable interfaces -# (analogues to vmappable), to enable Jumbles to be used with other -# transformations and higher-order primitives (primarily jit, though also grad -# with allow_int=True) we register them as pytrees. -# TODO(mattjj): add JITable / ADable interfaces, remove this pytree registration -def _jumble_flatten(jumble): - lengths = [] - new_shape = [lengths.append(d.lengths) or d.replace(lengths=len(lengths)) - if type(d) is IndexedAxisSize else d - for d in jumble.aval.elt_ty.shape] - elt_ty = jumble.aval.elt_ty.update(shape=tuple(new_shape)) - aval = jumble.aval.replace(elt_ty=elt_ty) - return (lengths, jumble.data), aval - - -def _ragged_axis_parts(dim: RaggedAxis) -> tuple[int, int, int]: - stacked_axis = dim.stacked_axis - ragged_axes = dim.ragged_axes - if len(ragged_axes) != 1: - raise ValueError('Multiple ragged axes not yet implemented.') - ragged_axis_dim = ragged_axes[0][0] - ragged_axis_length = ragged_axes[0][1] - return stacked_axis, ragged_axis_dim, ragged_axis_length - - -def _jumble_unflatten(aval, x): - lengths, data = x - new_shape = [d.replace(lengths=lengths[d.lengths - 1]) - if type(d) is IndexedAxisSize else d - for d in aval.elt_ty.shape] - elt_ty = aval.elt_ty.update(shape=tuple(new_shape)) - aval = aval.replace(elt_ty=elt_ty) - return Jumble(aval, data) -register_pytree_node(Jumble, _jumble_flatten, _jumble_unflatten) - -def _jumble_result(axis_size, stacked_axis, ragged_axes, x): - binder = core.Var('', core.ShapedArray((), np.dtype('int32'))) - if stacked_axis != 0: - raise NotImplementedError # TODO Transpose x so the stacked axis is axis 0 - shape = list(x.shape) - del shape[0] - for ragged_axis, segment_lens in ragged_axes: - shape[ragged_axis-1] = IndexedAxisSize(binder, segment_lens) - elt_ty = core.DShapedArray(tuple(shape), x.dtype, x.weak_type) - return Jumble(JumbleTy(binder, axis_size, elt_ty), x) - - -@dataclasses.dataclass(frozen=True) -class RaggedAxis: - stacked_axis: int - # For each axis, we store its index and the corresponding segment lengths. - # For example, the jumble i:(Fin 3) => f32[lens1.i, 7, lens2.i] - # would be represented with ragged_axes = [(1, lens1), (3, lens2)] - ragged_axes: tuple[tuple[int, Any], ...] - - @property - def size(self): - # TODO(mattjj, axch): All the segment lengths arrays better be the - # same length! - return len(self.ragged_axes[0][1]) - - def move_stacked_axis(self: RaggedAxis, dst: int) -> RaggedAxis: - # Assumes that all stored and incoming axes are already canonicalized - def move_axis(ax): - if self.stacked_axis > ax and ax >= dst: - return ax + 1 - if self.stacked_axis < ax and ax <= dst: - return ax - 1 - return ax - new_axes = tuple((move_axis(ax), sizes) for ax, sizes in self.ragged_axes) - return RaggedAxis(dst, new_axes) - - -def transpose_ragged_axes(dim: RaggedAxis, perm: tuple[int, ...]) -> RaggedAxis: - new_ragged_axes = [] - for idx, old_idx in enumerate(perm): - for ax, size in dim.ragged_axes: - if old_idx == ax: - new_ragged_axes.append((idx, size)) - break - return _sorted_ragged_axis(dim.stacked_axis, new_ragged_axes) - -def _sorted_ragged_axis(stacked_axis, ragged_axes): - return RaggedAxis(stacked_axis, tuple(sorted(ragged_axes, key=lambda p: p[0]))) - -def make_batch_axis( - ndim: int, - stacked_axis: int, - ragged_axes: list[tuple[int, Array | core.Var]], -) -> int | RaggedAxis: - if ragged_axes: - canonical = [(canonicalize_axis(ax, ndim), sz) for ax, sz in ragged_axes] - return _sorted_ragged_axis(canonicalize_axis(stacked_axis, ndim), canonical) - else: - return canonicalize_axis(stacked_axis, ndim) - -def bdim_as_shape( - bdim: int | RaggedAxis, data_shape: core.Shape) -> core.Shape: - if isinstance(bdim, RaggedAxis): - result = list(data_shape) - binder = core.Var('', core.ShapedArray((), np.dtype('int32'))) - for ragged_axis, segment_lens in bdim.ragged_axes: - result[ragged_axis] = IndexedAxisSize(binder, segment_lens) - return tuple(result) - else: - return data_shape - -def shape_as_bdim( - stacked_axis: int, data_shape: core.Shape) -> int | RaggedAxis: - # This assumes that there is only one binder in the data_shape. - ragged_axes = [(i, size.lengths) for i, size in enumerate(data_shape) - if isinstance(size, IndexedAxisSize)] - return make_batch_axis(len(data_shape), stacked_axis, ragged_axes) - - -def _update_annotation( - f: lu.WrappedFun, orig_type: core.InputType | None, - axis_size: core.AxisSize, axis_name: AxisName, - explicit_in_dims: Sequence[int | RaggedAxis | None], - segment_lens: Sequence[Array], - ) -> lu.WrappedFun: - if orig_type is None: return f - # By convention, `explicit_in_dims` only accounts for explicit arguments. - assert len(explicit_in_dims) == sum(explicit for _, explicit in orig_type) - # We need to: - # * if `axis_size` is dynamic, add a new implicit binder (type) for it; - # * for each element of `segment_lengths`, add a new explicit binder for it; - # * drop other implicit binders, replacing DBIdx which refer to them with - # Name objects; - # * for each (aval, in_dim) pair: if int-valued in_dim, add batch axis (int - # size if `axis_size` is int, otherwise Name); if RaggedAxis-valued in_dim, - # add batch axis (int if corresponding segment_lengths is concrete, Name if - # not); - # * generate full in_type with implicit args too. - - class Name: - def __init__(self, a): self.a = a - names = [Name(a) for a, _ in orig_type] - avals = [a.update(shape=tuple(names[d.val] if type(d) is pe.DBIdx else d - for d in a.shape)) - if type(a) is core.DShapedArray else a for a, e in orig_type if e] - - new_avals = [core.get_aval(s) for s in segment_lens] - sz = Name(axis_size.aval) if isinstance(axis_size, Tracer) else axis_size - for a, d in zip(avals, explicit_in_dims): - if isinstance(d, RaggedAxis): - raise NotImplementedError - else: - new_avals.append(core.unmapped_aval(sz, d, a)) # type: ignore - - mentioned = {d for a in new_avals if type(a) is core.DShapedArray - for d in a.shape if type(d) is Name} - expl_names = set(map(Name, new_avals)) - impl_names = mentioned - expl_names # type: ignore - impl_part = [(n.a, False) for n in impl_names] # type: ignore - name_map = {n: pe.DBIdx(i) for i, n in enumerate((*impl_names, *expl_names))} - expl_part = [(a.update(shape=tuple(name_map.get(d, d) for d in a.shape)) - if type(a) is core.DShapedArray else a, True) for a in new_avals] - return lu.annotate(f, (*impl_part, *expl_part)) - ### vmappable typeclass Vmappable = Any @@ -254,26 +56,11 @@ def to_elt(trace: Trace, get_idx: GetIdx, x: Vmappable, spec: MapSpec) -> Elt: handler = to_elt_handlers.get(type(x)) if handler: return handler(partial(to_elt, trace, get_idx), get_idx, x, spec) - elif type(x) is Jumble: - if spec is not jumble_axis: - raise TypeError("jumble input without using jumble_axis in_axes spec") - ias: IndexedAxisSize # Not present in the AxisSize union in core.py - (d, ias), = ((i, sz) # type: ignore - for i, sz in enumerate(x.aval.elt_ty.shape) - if type(sz) is IndexedAxisSize) - batch_axis = make_batch_axis(x.data.ndim, 0, [(d+1, ias.lengths)]) - return BatchTracer(trace, x.data, batch_axis) elif isinstance(spec, int) or spec is None: spec = spec and canonicalize_axis(spec, len(np.shape(x))) return (BatchTracer(trace, x, spec, source_info_util.current()) if spec is not None else x) else: - if isinstance(trace, BatchTrace) and isinstance(spec, JumbleAxis): - # TODO(mvoz): A vaguely questionable assumption that it is always - # sound to have a 0 axis here. This is true for the current use cases - # and comes from how we handle intermediary products of jumbles in - # vmap. - return BatchTracer(trace, x, 0, source_info_util.current()) # TODO(mvoz): This is a terrible place to fall into if you pass # a non jumble type in, make it clearer what went wrong. assert False, f'Unexpected type in ELT? {type(x)}' @@ -282,32 +69,30 @@ def to_elt(trace: Trace, get_idx: GetIdx, x: Vmappable, spec: MapSpec) -> Elt: to_elt_handlers: dict[type, ToEltHandler] = {} def from_elt(trace: BatchTrace, axis_size: AxisSize, mesh_axis: MeshAxis, - i: int, x: Elt, spec: MapSpec) -> Vmappable: + sum_match: bool, i: int, x: Elt, spec: MapSpec) -> tuple[Vmappable, MapSpec]: handler = from_elt_handlers.get(type(x)) if handler: def _cont(axis_size, elt, axis): - return from_elt(trace, axis_size, mesh_axis, i, elt, axis) - return handler(_cont, axis_size, x, spec) + return from_elt(trace, axis_size, mesh_axis, sum_match, i, elt, axis)[0] + return handler(_cont, axis_size, x, spec), spec val, bdim = trace.to_batch_info(x) - if type(bdim) is RaggedAxis: - if spec is not jumble_axis: - # TODO(mattjj): improve this error message - raise TypeError("ragged output without using jumble_axis out_axes spec") - return _jumble_result(axis_size, bdim.stacked_axis, bdim.ragged_axes, val) - else: - try: - return matchaxis(trace.axis_data.name, axis_size, mesh_axis, - bdim, spec, val) - except SpecMatchError: - raise SpecMatchError(i, x.batch_dim, spec) from None + bdim_inferred = bdim if spec is infer else spec + try: + return matchaxis(trace.axis_data.name, axis_size, mesh_axis, + bdim, spec, val, sum_match=sum_match), bdim_inferred + except SpecMatchError: + raise SpecMatchError(i, x.batch_dim, spec) from None from_elt_handlers: dict[type, FromEltHandler] = {} def make_iota(axis_size: AxisSize) -> Array: + # Callers of this utility, via batch() or vtile(), must be in a context + # where lax is importable. + from jax import lax # pytype: disable=import-error handler = make_iota_handlers.get(type(axis_size)) if handler: return handler(axis_size) else: - return jax.lax.iota('int32', int(axis_size)) + return lax.iota('int32', int(axis_size)) make_iota_handlers: dict[type, MakeIotaHandler] = {} def register_vmappable(data_type: type, spec_type: type, axis_size_type: type, @@ -319,7 +104,7 @@ def register_vmappable(data_type: type, spec_type: type, axis_size_type: type, from_elt_handlers[data_type] = from_elt if make_iota: make_iota_handlers[axis_size_type] = make_iota vmappables: dict[type, tuple[type, type]] = {} -spec_types: set[type] = {JumbleAxis} +spec_types: set[type] = set() def unregister_vmappable(data_type: type) -> None: _, axis_size_type = vmappables.pop(data_type) @@ -329,11 +114,11 @@ def unregister_vmappable(data_type: type) -> None: del make_iota_handlers[axis_size_type] global spec_types spec_types = ( - {JumbleAxis} | {spec_type for spec_type, _ in vmappables.values()} + set() | {spec_type for spec_type, _ in vmappables.values()} ) def is_vmappable(x: Any) -> bool: - return type(x) is Jumble or type(x) in vmappables + return type(x) in vmappables @lu.transformation_with_aux2 def flatten_fun_for_vmap(f: Callable, @@ -344,44 +129,6 @@ def flatten_fun_for_vmap(f: Callable, store.store(out_tree) return ans -# Propagate ragged masking rules from invars to outvars -# rule([params], [raggedness_per_invar], outvars) -> -# [raggedness_per_invar, raggedness_per_outvar] -RaggedMaskingRule = Callable[ - [list[Any], list[Any], list[Any]], tuple[list[Any], list[Any]] -] - -ragged_prop_rules: dict[core.Primitive, RaggedMaskingRule] = {} - - -def ragged_mask_elementwise_rule(eqn_params, invar_raggedness, outvars): - # TODO(mvoz): A util for getting the ragged representations - first_invar_raggedness = invar_raggedness[0] - for other_invar_raggedness in invar_raggedness[1:]: - if other_invar_raggedness != first_invar_raggedness: - raise ValueError(f'{other_invar_raggedness} != {first_invar_raggedness}') - - outvar_raggedness = [first_invar_raggedness] * len(outvars) - return invar_raggedness, outvar_raggedness - - -def ragged_mask_assert_no_op_rule(eqn_params, invar_raggedness, outvars): - if any(invar_raggedness): - raise ValueError(f'unexpected invar_raggedness: {invar_raggedness}') - return invar_raggedness, [None] * len(outvars) - - -def ragged_mask_no_op_rule(eqn_params, invar_raggedness, outvars): - return invar_raggedness, [None] * len(outvars) - - -def ragged_mask_transfer_identity( - eqn_params, invar_raggedness, outvar_raggedness -): - assert len(invar_raggedness) == 1, invar_raggedness - outvar_raggedness = invar_raggedness - return invar_raggedness, outvar_raggedness - ### tracer @@ -393,10 +140,10 @@ def ragged_mask_transfer_identity( class BatchTracer(Tracer): __slots__ = ['val', 'batch_dim', 'source_info'] - def __init__(self, trace, val, batch_dim: NotMapped | int | RaggedAxis, + def __init__(self, trace, val, batch_dim: NotMapped | int, source_info: source_info_util.SourceInfo | None = None): if config.enable_checks.value: - assert type(batch_dim) in (NotMapped, int, RaggedAxis) + assert type(batch_dim) in (NotMapped, int) if type(batch_dim) is int: aval = core.get_aval(val) assert 0 <= batch_dim < len(aval.shape) @@ -405,24 +152,22 @@ def __init__(self, trace, val, batch_dim: NotMapped | int | RaggedAxis, self.batch_dim = batch_dim self.source_info = source_info + def _short_repr(self): + return f"VmapTracer(aval={self.aval}, batched={typeof(self.val)})" + @property def aval(self): aval = core.get_aval(self.val) + if self._trace.axis_data.spmd_name is not None: + if config._check_vma.value: + aval = aval.update( + vma=aval.vma - frozenset(self._trace.axis_data.spmd_name)) if self.batch_dim is not_mapped: return aval elif type(self.batch_dim) is int: return core.mapped_aval(aval.shape[self.batch_dim], self.batch_dim, aval) - elif type(self.batch_dim) is RaggedAxis: - new_aval = core.mapped_aval( - aval.shape[self.batch_dim.stacked_axis], self.batch_dim.stacked_axis, aval) - shape = list(new_aval.shape) # pytype: disable=attribute-error - for ragged_axis, segment_lengths in self.batch_dim.ragged_axes: - size_tracer = BatchTracer(self._trace, segment_lengths, 0) - if self.batch_dim.stacked_axis < ragged_axis: - ragged_axis -= 1 - shape[ragged_axis] = size_tracer - return core.DShapedArray(shape=tuple(shape), dtype=aval.dtype, - weak_type=aval.weak_type) + else: + raise Exception("batch dim should be int or `not_mapped`") def full_lower(self): if self.batch_dim is not_mapped: @@ -442,7 +187,7 @@ def _contents(self): def get_referent(self): if self.batch_dim is None or type(self.batch_dim) is int: return core.get_referent(self.val) - else: # TODO(mattjj): could handle the RaggedAxis case? + else: return self @dataclasses.dataclass(frozen=True) @@ -451,11 +196,36 @@ class AxisData: size : Any # Only one of spmd_axis_name and explicit_mesh_axis is set. spmd_name : Any - explicit_mesh_axis: Any + # short for private `_explicit_mesh_axis`. The public property is called + # `.explicit_mesh_axis` + _ema: tuple[Any, ...] | None + + @property + def explicit_mesh_axis(self): + assert self._ema is None or isinstance(self._ema, tuple) + if self._ema is None: + return None + cur_mesh = mesh_lib.get_abstract_mesh() + if cur_mesh.empty: + return self._ema + ema0_type = cur_mesh._name_to_type[self._ema[0]] + assert all(cur_mesh._name_to_type[e] == ema0_type for e in self._ema) + if ema0_type != mesh_lib.AxisType.Explicit: + return None + return self._ema + + def __repr__(self): + return (f'AxisData(name={self.name}, size={self.size},' + f' spmd_name={self.spmd_name},' + f' explicit_mesh_axis={self.explicit_mesh_axis})') + + __str__ = __repr__ def get_sharding_for_vmap(axis_data, orig_sharding, axis): val = axis_data.explicit_mesh_axis + # TODO(yashkatariya): Preserve unreduced here using + # `orig_sharding.spec.update` new_spec = P(*tuple_insert(orig_sharding.spec, axis, val)) return NamedSharding(orig_sharding.mesh, new_spec) @@ -468,6 +238,7 @@ def __init__(self, parent_trace, tag, axis_data): assert isinstance(axis_data, AxisData) self.axis_data = axis_data self.tag = tag + self.requires_low = False def to_batch_info(self, val): if isinstance(val, BatchTracer) and val._trace.tag is self.tag: @@ -476,8 +247,6 @@ def to_batch_info(self, val): return val, not_mapped def process_primitive(self, p, tracers, params): - if config.dynamic_shapes.value: - p.abstract_eval(*(map(core.get_aval, tracers)), **params) vals_in, dims_in = unzip2(map(self.to_batch_info, tracers)) args_not_mapped = all(bdim is not_mapped for bdim in dims_in) if p in fancy_primitive_batchers: @@ -485,20 +254,15 @@ def process_primitive(self, p, tracers, params): and p in skippable_batchers and not any(self.axis_data.name == axis_name for axis_name in skippable_batchers[p](params))): - # no-op shortcut return p.bind_with_trace(self.parent_trace, vals_in, params) else: with core.set_current_trace(self.parent_trace): val_out, dim_out = fancy_primitive_batchers[p]( self.axis_data, vals_in, dims_in, **params) elif args_not_mapped: - # no-op shortcut return p.bind_with_trace(self.parent_trace, vals_in, params) - elif p in primitive_batchers: - with core.set_current_trace(self.parent_trace): - val_out, dim_out = primitive_batchers[p](vals_in, dims_in, **params) else: - raise NotImplementedError("Batching rule for '{}' not implemented".format(p)) + raise NotImplementedError(f"Batching rule for '{p}' not implemented") src = source_info_util.current() if p.multiple_results: with core.set_current_trace(self.parent_trace): # val_out may be lazy map @@ -512,16 +276,12 @@ def process_call(self, call_primitive, f, tracers, params): assert call_primitive.multiple_results params = dict(params, name=params.get('name', f.__name__)) vals, dims = unzip2(map(self.to_batch_info, tracers)) - segment_lens, dims = indirectify_ragged_axes(dims) f_, dims_out = batch_subtrace(f, self.tag, self.axis_data, tuple(dims)) - f_ = _update_annotation( - f_, f.in_type, self.axis_data.size, self.axis_data.name, dims, segment_lens) with core.set_current_trace(self.parent_trace): - vals_out = call_primitive.bind(f_, *segment_lens, *vals, **params) - vals_out, dims_out = resolve_ragged_axes(vals_out, dims_out()) + vals_out = call_primitive.bind(f_, *vals, **params) src = source_info_util.current() - return [BatchTracer(self, v, d, src) for v, d in zip(vals_out, dims_out)] + return [BatchTracer(self, v, d, src) for v, d in zip(vals_out, dims_out())] def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params): vals, dims = unzip2(map(self.to_batch_info, tracers)) @@ -565,12 +325,9 @@ def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): in_vals, in_dims = unzip2(map(self.to_batch_info, tracers)) fun, out_dims1 = batch_subtrace(fun, self.tag, self.axis_data, in_dims) jvp, out_dims2 = batch_custom_jvp_subtrace(jvp, self.tag, self.axis_data, in_dims) - out_vals = prim.bind_with_trace(self.parent_trace, (fun, jvp) + tuple(in_vals), + out_vals = prim.bind_with_trace(self.parent_trace, (fun, jvp, *in_vals), dict(symbolic_zeros=symbolic_zeros)) fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2) - if not fst: - assert out_dims == out_dims[:len(out_dims) // 2] * 2 - out_dims = out_dims[:len(out_dims) // 2] src = source_info_util.current() return [BatchTracer(self, v, d, src) for v, d in zip(out_vals, out_dims)] @@ -582,52 +339,61 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, *, out_trees, fun, out_dims1 = batch_subtrace(fun, self.tag, self.axis_data, in_dims) fwd, out_dims2 = batch_subtrace(fwd, self.tag, self.axis_data, fwd_in_dims) - bwd = batch_custom_vjp_bwd(bwd, self.tag, self.axis_data, out_dims2, in_dims) + def bwd_in_dims(): + _, _, input_fwds = out_trees() + pruned_dims = iter(out_dims2()) + full_dims = [next(pruned_dims) if f is None else in_dims[f] for f in input_fwds] + return [*full_dims, *pruned_dims] + + bwd = batch_custom_vjp_bwd(bwd, self.tag, self.axis_data, bwd_in_dims, in_dims) out_vals = prim.bind_with_trace(self.parent_trace, (fun, fwd, bwd) + tuple(in_vals), dict(out_trees=out_trees, symbolic_zeros=symbolic_zeros)) fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2) if not fst: - _, res_tree = out_trees() - _, out_dims = split_list(out_dims, [res_tree.num_leaves]) + _, res_tree, input_fwds = out_trees() + num_res = res_tree.num_leaves - sum(f is not None for f in input_fwds) + _, out_dims = split_list(out_dims, [num_res]) src = source_info_util.current() return [BatchTracer(self, v, d, src) for v, d in zip(out_vals, out_dims)] ### API for batching callables with vmappable inputs and outputs def batch(fun: lu.WrappedFun, axis_data, - in_dims, out_dim_dests) -> lu.WrappedFun: + in_dims, out_dim_dests, sum_match=False) -> lu.WrappedFun: # we split up _batch_inner and _batch_outer for the leak checker - f = _batch_inner(fun, axis_data, out_dim_dests) + f = _batch_inner(fun, axis_data, out_dim_dests, sum_match) return _batch_outer(f, axis_data, in_dims) @lu.transformation2 def _batch_outer(f, axis_data, in_dims, *in_vals): tag = TraceTag() with source_info_util.transform_name_stack('vmap'): - outs, trace = f(tag, in_dims, *in_vals) + outs, out_dim_srcs, trace = f(tag, in_dims, *in_vals) with core.ensure_no_leaks(trace): del trace - return outs + return outs, out_dim_srcs @lu.transformation2 -def _batch_inner(f: Callable, axis_data, out_dim_dests, tag, in_dims, *in_vals): +def _batch_inner(f: Callable, axis_data, out_dim_dests, sum_match, tag, in_dims, *in_vals): in_dims = in_dims() if callable(in_dims) else in_dims with core.take_current_trace() as parent_trace: trace = BatchTrace(parent_trace, tag, axis_data) idx = memoize(lambda: BatchTracer(trace, make_iota(axis_data.size), 0, source_info_util.current())) - in_tracers = map(partial(to_elt, trace, idx), in_vals, in_dims) + with core.set_current_trace(parent_trace): + in_tracers = map(partial(to_elt, trace, idx), in_vals, in_dims) + # TODO(yashkatariya): Instead of `add_explicit_mesh_axis_names`, we should + # create a new mesh by removing the axis_data.explicit_mesh_axis from it. with (core.set_current_trace(trace), core.extend_axis_env_nd([(axis_data.name, axis_data.size)]), - core.add_spmd_axis_names(axis_data.spmd_name)): + core.add_spmd_axis_names(axis_data.spmd_name), + core.add_explicit_mesh_axis_names(axis_data.explicit_mesh_axis)): outs = f(*in_tracers) - - out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests - out_vals = map(partial(from_elt, trace, axis_data.size, - axis_data.explicit_mesh_axis), - range(len(outs)), outs, out_dim_dests) - - return out_vals, trace + out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests + out_vals, out_dim_srcs = unzip2( + map(partial(from_elt, trace, axis_data.size, axis_data.explicit_mesh_axis, sum_match), + range(len(outs)), outs, out_dim_dests)) + return out_vals, out_dim_srcs, trace # NOTE: This divides the in_axes by the tile_size and multiplies the out_axes by it. def vtile(f_flat: lu.WrappedFun, @@ -669,123 +435,46 @@ def batch_subtrace(f, store, tag, axis_data, in_dims, *in_vals): trace = BatchTrace(parent_trace, tag, axis_data) with core.set_current_trace(trace): in_dims = in_dims() if callable(in_dims) else in_dims - in_vals, in_dims = resolve_ragged_axes(in_vals, in_dims) in_tracers = [BatchTracer(trace, x, dim, source_info_util.current()) if dim is not None else x for x, dim in zip(in_vals, in_dims)] outs = f(*in_tracers) out_vals, out_dims = unzip2(map(trace.to_batch_info, outs)) - segment_lens, out_dims = indirectify_ragged_axes(out_dims) store.store(out_dims) - return (*segment_lens, *out_vals) - -def indirectify_ragged_axes(dims): - if not any(type(d) is RaggedAxis for d in dims): - return [], dims - axis_map : dict[int, tuple[Array, pe.DBIdx]] = collections.OrderedDict() - def canonicalize_segment_lengths(d: RaggedAxis) -> RaggedAxis: - new_ragged_axes = [] - for ragged_axis, segment_lengths in d.ragged_axes: - _, dbidx = axis_map.setdefault( - id(core.get_referent(segment_lengths)), - (segment_lengths, pe.DBIdx(len(axis_map)))) - new_ragged_axes.append((ragged_axis, dbidx)) - return RaggedAxis(d.stacked_axis, tuple(new_ragged_axes)) - new_dims = [canonicalize_segment_lengths(d) - if isinstance(d, RaggedAxis) else d for d in dims] - segment_lens = [s for s, _ in axis_map.values()] - return segment_lens, new_dims - -def indirectify_ragged_axes_against_inputs_outputs(dims, in_vals, out_vals): - def canonicalize_segment_lengths(d: RaggedAxis) -> RaggedAxis: - new_ragged_axes = [] - for ragged_axis, segment_lengths in d.ragged_axes: - key = id(core.get_referent(segment_lengths)) - value = _locate_value(key, in_vals, out_vals) - new_ragged_axes.append((ragged_axis, value)) - return RaggedAxis(d.stacked_axis, tuple(new_ragged_axes)) - new_dims = [canonicalize_segment_lengths(d) - if isinstance(d, RaggedAxis) else d for d in dims] - return new_dims - -def _locate_value(key, in_vals, out_vals): - for ix, candidate in enumerate(in_vals): - if key == id(candidate): - return pe.InDBIdx(ix) - for ix, candidate in enumerate(out_vals): - if key == id(candidate): - return pe.OutDBIdx(ix) - assert False, "Could not find segment lengths" - -def resolve_ragged_axes(vals, dims): - idxs = {lengths_idx.val for d in dims if isinstance(d, RaggedAxis) - for (_, lengths_idx) in d.ragged_axes} - dims = [RaggedAxis(d.stacked_axis, - tuple((ragged_axis, vals[lengths_idx.val]) - for ragged_axis, lengths_idx in d.ragged_axes)) - if isinstance(d, RaggedAxis) else d for d in dims] - vals = [x for i, x in enumerate(vals) if i not in idxs] - return vals, dims - -def resolve_ragged_axes_against_inputs_outputs(in_vals, out_vals, dims): - def fetch(idx): - if isinstance(idx, pe.InDBIdx): - return in_vals[idx.val] - else: - assert isinstance(idx, pe.OutDBIdx) - return out_vals[idx.val] - - dims = [RaggedAxis(d.stacked_axis, - tuple((ragged_axis, fetch(lengths_idx)) - for ragged_axis, lengths_idx in d.ragged_axes)) - if isinstance(d, RaggedAxis) else d for d in dims] - return dims + return out_vals ### API for batching jaxprs -# TODO(axch): parameterize RaggedAxis annotations by a type parameter so as to -# indicate whether we're dealing with instances that contain Arrays or DBIdx. -# Can reuse same pattern for all dynamic shape stuff. def batch_jaxpr2( closed_jaxpr: core.ClosedJaxpr, axis_data, - in_axes: tuple[int | NotMapped | RaggedAxis, ...], - ) -> tuple[core.ClosedJaxpr, tuple[int | NotMapped | RaggedAxis, ...]]: - # This is only ever used in pjit. The difference vs batch_jaxpr is that - # batch_jaxpr2 lets the callee decide which outputs are batched and what - # their batch axes are; whereas batch_jaxpr has to obey caller-imposed - # consistency constraints, such as type-agreement across arms of a - # `lax.cond`, or input-output agreement for the body of a `lax.scan`. + in_axes: tuple[int | NotMapped, ...], + ) -> tuple[core.ClosedJaxpr, tuple[int | NotMapped ]]: return _batch_jaxpr2(closed_jaxpr, axis_data, tuple(in_axes)) @weakref_lru_cache def _batch_jaxpr2( closed_jaxpr: core.ClosedJaxpr, axis_data, - in_axes: tuple[int | NotMapped | RaggedAxis, ...], + in_axes: tuple[int | NotMapped ], ) -> tuple[core.ClosedJaxpr, tuple[int | NotMapped, ...]]: f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr), debug_info=closed_jaxpr.jaxpr.debug_info) f, out_axes = _batch_jaxpr_inner(f, axis_data) f = _batch_jaxpr_outer(f, axis_data, in_axes) - in_axes2, avals_in = unzip2([ - handle_ragged(closed_jaxpr.in_avals, dim, aval) - if isinstance(dim, RaggedAxis) else (dim, aval) - for dim, aval in zip(in_axes, closed_jaxpr.in_avals)]) - avals_in2 = [core.unmapped_aval(axis_data.size, b, aval, - axis_data.explicit_mesh_axis) - if b is not not_mapped else aval - for aval, b in unsafe_zip(avals_in, in_axes2)] - jaxpr_out, _, consts, () = pe.trace_to_jaxpr_dynamic(f, avals_in2) + avals_in2 = [] + for aval, b in unsafe_zip(closed_jaxpr.in_avals, in_axes): + if b is not_mapped: + avals_in2.append(aval) + else: + aval = core.unmapped_aval( + axis_data.size, b, aval, axis_data.explicit_mesh_axis) + if axis_data.spmd_name is not None: + if config._check_vma.value: + aval = aval.update(vma=aval.vma | frozenset(axis_data.spmd_name)) # type: ignore + avals_in2.append(aval) + jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(f, avals_in2) return core.ClosedJaxpr(jaxpr_out, consts), out_axes() -def handle_ragged(in_avals: list[core.AbstractValue], dim: RaggedAxis, - aval: core.ShapedArray) -> tuple[int, core.ShapedArray]: - new_shape = list(aval.shape) - for i, dbi in dim.ragged_axes: - new_shape[i - (dim.stacked_axis < i)] = in_avals[dbi.val].dtype.bound - new_aval = aval.update(shape=tuple(new_shape)) - return dim.stacked_axis, new_aval - def batch_jaxpr(closed_jaxpr, axis_data, in_batched, instantiate): inst = tuple(instantiate) if isinstance(instantiate, list) else instantiate return _batch_jaxpr(closed_jaxpr, axis_data, tuple(in_batched), inst) @@ -816,24 +505,24 @@ def _batch_jaxpr_axes(closed_jaxpr: core.ClosedJaxpr, axis_data.explicit_mesh_axis) if b is not not_mapped else aval for aval, b in unsafe_zip(closed_jaxpr.in_avals, in_axes)] - jaxpr_out, _, consts, () = pe.trace_to_jaxpr_dynamic(f, avals_in) + jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(f, avals_in) return core.ClosedJaxpr(jaxpr_out, consts), out_batched() @lu.transformation_with_aux2 def _batch_jaxpr_inner(f, store, axis_data, tag, in_axes, *in_vals): with core.take_current_trace() as parent_trace: trace = BatchTrace(parent_trace, tag, axis_data) - _, in_axes = resolve_ragged_axes(in_vals, in_axes) in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val for val, dim in zip(in_vals, in_axes)] + # TODO(yashkatariya): Instead of `add_explicit_mesh_axis_names`, we should + # create a new mesh by removing the axis_data.explicit_mesh_axis from it. with (core.set_current_trace(trace), core.extend_axis_env_nd([(axis_data.name, axis_data.size)]), - core.add_spmd_axis_names(axis_data.spmd_name)): + core.add_spmd_axis_names(axis_data.spmd_name), + core.add_explicit_mesh_axis_names(axis_data.explicit_mesh_axis)): outs = f(*in_tracers) out_vals, out_axes = unzip2(map(trace.to_batch_info, outs)) - new_out_axes = indirectify_ragged_axes_against_inputs_outputs( - out_axes, in_vals, out_vals) - store.store(new_out_axes) + store.store(out_axes) return out_vals @lu.transformation_with_aux2 @@ -888,20 +577,16 @@ def batch_custom_jvp_subtrace(f, store, tag, axis_data, in_dims, *in_vals): if type(val) is SymbolicZero else BatchTracer(trace, val, dim) for val, dim in zip(in_vals, in_dims * 2)] with core.set_current_trace(trace): - outs = f(*in_tracers) - # TODO(mattjj,frostig): instantiating any SymbolicZero output is easy, but can - # be wasteful in the rare case it actually triggers; handle symbolically! - outs = [instantiate(replace_rule_output_symbolic_zeros(x)) for x in outs] - - out_vals, out_dims = unzip2(map(trace.to_batch_info, outs)) + out_tracers: list[BatchTracer | SymbolicZero] = f(*in_tracers) + out_vals, out_dims = unzip2(map(trace.to_batch_info, out_tracers)) out_primals, out_tangents = split_list(out_vals, [len(out_vals) // 2]) out_primal_bds, out_tangent_bds = split_list(out_dims, [len(out_vals) // 2]) out_dims = map(_merge_bdims, out_primal_bds, out_tangent_bds) out_primals = map(partial(matchaxis, trace.axis_data.name, size, mesh_axis), out_primal_bds, out_dims, out_primals) - out_tangents = map(partial(matchaxis, trace.axis_data.name, size, mesh_axis), + out_tangents = map(partial(_matchaxis_symzeros, trace.axis_data.name, size, mesh_axis), out_tangent_bds, out_dims, out_tangents) - store.store(out_dims * 2) + store.store(out_dims) return out_primals + out_tangents def batch_custom_vjp_bwd(bwd: lu.WrappedFun, tag: core.TraceTag, @@ -929,12 +614,11 @@ def _match_axes_and_sum(f, axis_size, axis_name, mesh_axis, out_dims_thunk, out_dim_dests, *in_vals): # this is like _match_axes, but we do reduce-sums as needed out_vals = f(*in_vals) - return map(partial(_matchaxis_symbolic_zeros, axis_name, axis_size, mesh_axis, - axis_name, sum_match=True), + return map(partial(_matchaxis_symzeros, axis_name, axis_size, mesh_axis, + sum_match=True), out_dims_thunk(), out_dim_dests, out_vals) -def _matchaxis_symbolic_zeros(axis_name, sz, mesh_axis, name, src, dst, x, - sum_match=False): +def _matchaxis_symzeros(axis_name, sz, mesh_axis, src, dst, x, sum_match=False): # Just like `matchaxis`, but handles symbolic zeros using ad_util.py # TODO(mattjj): dedup with matchaxis if isinstance(x, (Zero, SymbolicZero)): @@ -942,11 +626,11 @@ def _matchaxis_symbolic_zeros(axis_name, sz, mesh_axis, name, src, dst, x, return x elif type(src) == type(dst) == int: aval = core.mapped_aval(sz, src, x.aval) - return Zero(core.unmapped_aval(sz, dst, aval, mesh_axis)) + return type(x)(core.unmapped_aval(sz, dst, aval, mesh_axis)) elif src is not_mapped and dst is not not_mapped: - return Zero(core.unmapped_aval(sz, dst, x.aval, mesh_axis)) + return type(x)(core.unmapped_aval(sz, dst, x.aval, mesh_axis)) elif dst is not_mapped and sum_match: - return Zero(core.mapped_aval(sz, src, x.aval)) + return type(x)(core.mapped_aval(sz, src, x.aval)) else: raise ValueError((axis_name, x, src, dst)) else: @@ -959,8 +643,6 @@ def _matchaxis_symbolic_zeros(axis_name, sz, mesh_axis, name, src, dst, x, ..., tuple[Any, Union[int, None, tuple[Union[int, None], ...]]] ] -primitive_batchers : dict[core.Primitive, BatchingRule] = {} -# "fancy" primitive batchers just take a extra leading `AxisData` and "trace type" args fancy_primitive_batchers: dict[core.Primitive, Callable] = {} # backwards compat shim. TODO: delete @@ -969,36 +651,47 @@ def __setitem__(self, prim, batcher): def wrapped(axis_data, vals, dims, **params): return batcher(axis_data.size, axis_data.name, None, vals, dims, **params) fancy_primitive_batchers[prim] = wrapped - axis_primitive_batchers = AxisPrimitiveBatchersProxy() +# backwards compat shim. TODO: delete +class PrimitiveBatchersProxy: + def __setitem__(self, prim, batcher): + def wrapped(axis_data, vals, dims, **params): + del axis_data + if all(d is None for d in dims): + o = prim.bind(*vals, **params) + return (o, [None] * len(o)) if prim.multiple_results else (o, None) + return batcher(vals, dims, **params) + fancy_primitive_batchers[prim] = wrapped + + def __delitem__(self, prim): + del fancy_primitive_batchers[prim] +primitive_batchers = PrimitiveBatchersProxy() + # Presence in this table allows fancy batchers to be skipped by batch traces for -# irrelevant axes. The Callable takes the params and returns a list of relevant -# axes. +# irrelevant axes. The Callable takes params and returns a list of relevant axes +# TODO(yashkatariya): remove this skippable_batchers : dict[core.Primitive, Callable] = {} def defvectorized(prim): - primitive_batchers[prim] = partial(vectorized_batcher, prim) + fancy_primitive_batchers[prim] = partial(vectorized_batcher, prim) -def vectorized_batcher(prim, batched_args, batch_dims, **params): +def vectorized_batcher(prim, axis_data, batched_args, batch_dims, **params): + assert not prim.multiple_results + if all(d is None for d in batch_dims): + return prim.bind(*batched_args, **params), None assert all(batch_dims[0] == bd for bd in batch_dims[1:]), batch_dims return prim.bind(*batched_args, **params), batch_dims[0] def defbroadcasting(prim): - primitive_batchers[prim] = partial(broadcast_batcher, prim) - -def broadcast_batcher(prim, args, dims, **params): - """Process a primitive with built-in broadcasting. + fancy_primitive_batchers[prim] = partial(broadcast_batcher, prim) - Args: - args: the possibly-batched arguments - dims: list or tuple of the same length as `args`, where each - entry indicates the batching state of the corresponding entry to `args`: - either an int indicating the batch dimension, or else `not_mapped` - indicating no batching. - """ +def broadcast_batcher(prim, axis_data, args, dims, **params): assert len(args) > 1 + if all(d is None for d in dims): + o = prim.bind(*args, **params) + return (o, [None] * len(o)) if prim.multiple_results else (o, None) shape, dim = next((x.shape, d) for x, d in zip(args, dims) if d is not not_mapped) if all(core.definitely_equal_shape(shape, x.shape) and d == dim @@ -1018,15 +711,21 @@ def broadcast_batcher(prim, args, dims, **params): return (out, (0,) * len(out)) if prim.multiple_results else (out, 0) def _handle_scalar_broadcasting(nd, x, d): + # Callers of this utility, via broadcast_batcher() or defbroadcasting(), + # must be in a context where lax is importable. + from jax import lax # pytype: disable=import-error if d is not_mapped or nd == np.ndim(x): return x else: - return jax.lax.expand_dims(x, tuple(range(np.ndim(x), nd))) + return lax.expand_dims(x, tuple(range(np.ndim(x), nd))) def defreducer(prim, ident): - primitive_batchers[prim] = partial(reducer_batcher, prim, ident) + fancy_primitive_batchers[prim] = partial(reducer_batcher, prim, ident) -def reducer_batcher(prim, ident, batched_args, batch_dims, axes, **params): +def reducer_batcher(prim, ident, axis_data, batched_args, batch_dims, axes, + **params): + if all(d is None for d in batch_dims): + return prim.bind(*batched_args, axes=axes, **params), None def out_axis(axes, axis): return int(list(np.delete(np.arange(operand.ndim), axes)).index(axis)) operand, = batched_args @@ -1037,23 +736,6 @@ def out_axis(axes, axis): if 'input_shape' in params: params = dict(params, input_shape=operand.shape) return prim.bind(operand, axes=axes, **params), bdim_out - elif isinstance(bdim, RaggedAxis): - assert ident is not None, "TODO Ragged batching a reduction requires an identity" - axes = tuple(np.where(np.less(axes, bdim.stacked_axis), axes, np.add(axes, 1))) - bdim_out = out_axis(axes, bdim.stacked_axis) - # For each ragged_axis, we either mask the operand there or append - # it to the set of axes that will be ragged in the result. - axes_to_mask = [] - ragged_axes_out = [] - for ragged_axis, segment_lengths in bdim.ragged_axes: - if ragged_axis in axes: - axes_to_mask.append((ragged_axis, segment_lengths)) - else: - ragged_axes_out.append((out_axis(axes, ragged_axis), segment_lengths)) - operand = mask_ragged_axes( - operand, ident, RaggedAxis(bdim.stacked_axis, tuple(axes_to_mask))) - result = prim.bind(operand, axes=axes, **params) - return result, make_batch_axis(operand.ndim, bdim_out, ragged_axes_out) else: assert False @@ -1066,73 +748,51 @@ def expand_dims_batcher(prim, args, dims, **params): out = prim.bind(*args, **params) return (out, (0,) * len(out)) if prim.multiple_results else (out, 0) -def mask_ragged_axes(operand: Array, ident, axis_spec: RaggedAxis) -> Array: - # TODO(mattjj, axch) Can we mask multiple axes more efficiently at - # once, rather than one at a time? - for ragged_axis, segment_lengths in axis_spec.ragged_axes: - this_axis_spec = RaggedAxis( - axis_spec.stacked_axis, ((ragged_axis, segment_lengths),)) - operand = _mask_one_ragged_axis(operand, ident, this_axis_spec) - return operand - -def _mask_one_ragged_axis( - operand: Array, ident, axis_spec: RaggedAxis) -> Array: - assert len(axis_spec.ragged_axes) == 1, "Mask just one ragged axis at a time" - ragged_axis, segment_lengths = axis_spec.ragged_axes[0] - value = ident(operand.dtype) - positions = jax.lax.broadcasted_iota('int32', operand.shape, ragged_axis) - # TODO(mattjj, axch) can't get ._data, need to convert it - # lengths = jax.lax.convert_element_type(segment_lengths._data, 'int32') - lengths = jax.lax.convert_element_type(segment_lengths, 'int32') - limits = jax.lax.broadcast_in_dim( - lengths, operand.shape, [axis_spec.stacked_axis]) - mask = positions < limits - return jax.lax.select(mask, operand, jax.lax.broadcast(value, operand.shape)) - -def move_stacked_axis(operand, bdim, dst): - dst = canonicalize_axis(dst, operand.ndim) - if isinstance(bdim, int): - return moveaxis(operand, bdim, dst), dst - elif isinstance(bdim, RaggedAxis): - result = moveaxis(operand, bdim.stacked_axis, dst) - return result, bdim.move_stacked_axis(dst) - else: - raise TypeError(f"Unrecognized batch dimension type {bdim}") - ### general utilities for manipulating axes on jaxpr types (not vmappables) -def broadcast(x, sz, axis, mesh_axis=None): +def broadcast(x, sz, axis, mesh_axis): + # Callers of this utility must be in a context where lax is importable. + from jax import lax # pytype: disable=import-error shape = list(np.shape(x)) shape.insert(axis, sz) broadcast_dims = tuple(np.delete(np.arange(len(shape)), axis)) x_aval = core.get_aval(x) + if x_aval.sharding.mesh.empty: + mesh_axis = None new_spec = P(*tuple_insert(x_aval.sharding.spec, axis, mesh_axis)) - sharding = x_aval.sharding.with_spec(new_spec) + sharding = x_aval.sharding.update(spec=new_spec) # TODO(dougalm, yashkatariya): Delete this context manager once we figure # out how to ensure jaxpr arguments always have the context mesh. with mesh_lib.use_abstract_mesh(sharding.mesh): - return jax.lax.broadcast_in_dim(x, shape, broadcast_dims, - out_sharding=sharding) + x = lax.broadcast_in_dim(x, shape, broadcast_dims, out_sharding=sharding) + if config._check_vma.value: + # TODO(yashkatariya,parkers): don't do this, fix during fixit week 2026 + spmd_names = core.get_axis_env().spmd_axis_names + if len(spmd_names) > 1: + raise NotImplementedError + if spmd_names: + x = core.pvary(x, tuple(spmd_names)) + return x + +def matchaxis2(axis_data, src, dst, x, sum_match=False): + return matchaxis(axis_data.name, axis_data.size, axis_data.explicit_mesh_axis, + src, dst, x, sum_match) def matchaxis(axis_name, sz, mesh_axis, src, dst, x, sum_match=False): - if dst == jumble_axis: - x = bdim_at_front(x, src, sz) - elt_ty = x.aval.update(shape=x.shape[1:]) - aval = JumbleTy(core.Var('', core.ShapedArray((), np.dtype('int32'))), - x.shape[0], elt_ty) - return Jumble(aval, x) try: _ = core.get_aval(x) except TypeError as e: raise TypeError(f"Output from batched function {x!r} with type " f"{type(x)} is not a valid JAX type") from e - if src == dst: + if src == dst or dst is infer: return x elif type(src) == type(dst) == int: return moveaxis(x, src, dst) - elif src is not_mapped and dst is not not_mapped: + elif src is not_mapped and type(dst) is int: return broadcast(x, sz, canonicalize_axis(dst, np.ndim(x) + 1), mesh_axis) - elif dst is not_mapped and sum_match: + elif src is not_mapped and dst is sum_axis: + return x + elif dst is not_mapped and sum_match or dst is sum_axis: return x.sum(src) else: if (not isinstance(axis_name, core._TempAxisName) and @@ -1147,25 +807,42 @@ def __init__(self, leaf_idx, src, dst): self.src = src self.dst = dst -def bdim_at_front(x, bdim, size): +def bdim_at_front(x, bdim, size, mesh_axis=None): if bdim is not_mapped: - return broadcast(x, size, 0) + return broadcast(x, size, 0, mesh_axis=mesh_axis) else: return moveaxis(x, bdim, 0) -def add_batched(batched_args, batch_dims): +def add_batched(axis_data, batched_args, batch_dims): bdx, bdy = batch_dims x, y = batched_args + mesh_axis = axis_data.explicit_mesh_axis if bdx == bdy: return add_jaxvals(x, y), bdx elif bdx is not_mapped: - x = broadcast(x, y.shape[bdy], bdy) + x = broadcast(x, y.shape[bdy], bdy, mesh_axis=mesh_axis) return add_jaxvals(x, y), bdy elif bdy is not_mapped: - y = broadcast(y, x.shape[bdx], bdx) + y = broadcast(y, x.shape[bdx], bdx, mesh_axis=mesh_axis) return add_jaxvals(x, y), bdx else: x = moveaxis(x, bdx, bdy) return add_jaxvals(x, y), bdy -primitive_batchers[add_jaxvals_p] = add_batched + +fancy_primitive_batchers[add_jaxvals_p] = add_batched +skippable_batchers[add_jaxvals_p] = lambda _: () + +### mutable arrays + +defvectorized(core.ref_p) + +### hijax + +class Sum: pass +sum_axis = Sum() +spec_types.add(Sum) + +class Infer: pass +infer = Infer() +spec_types.add(Infer) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 1369f72ac74c..3080b13281be 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -16,29 +16,30 @@ from __future__ import annotations import collections -import contextlib from collections.abc import Callable, Iterable, Iterator, Sequence import dataclasses import functools from functools import partial +import heapq import io import itertools import operator -import os import re import types import typing from typing import Any, NamedTuple, Protocol, Union, cast as type_cast import warnings -import numpy as np - from jax._src import ad_util from jax._src import api_util from jax._src import config from jax._src import core from jax._src import dtypes from jax._src import effects as effects_lib +from jax._src import frozen_dict +from jax._src import hashable_array +from jax._src import literals +from jax._src import jaxpr_util from jax._src import linear_util as lu from jax._src import path from jax._src import sharding_impls @@ -46,20 +47,25 @@ from jax._src import util from jax._src import xla_bridge as xb from jax._src.interpreters import partial_eval as pe -from jax._src.interpreters import xla -from jax._src.layout import AutoLayout, DeviceLocalLayout -from jax._src.partition_spec import PartitionSpec -from jax._src.sharding import Sharding as JSharding -from jax._src.sharding_impls import (AUTO, NamedSharding, - modify_sdy_sharding_wrt_axis_types, - SdyArraySharding, SdyArrayShardingList) -from jax._src.util import foreach +from jax._src.layout import AutoLayout, Layout +from jax._src.lib import _jax +from jax._src.lib import jax_mlir_ext +from jax._src.lib import version as jaxlib_version from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension, xla_extension_version from jax._src.lib.mlir import dialects, ir, passmanager from jax._src.lib.mlir.dialects import func as func_dialect, hlo -from jax._src.lib.mlir import register_jax_dialects +from jax._src.mesh import AxisType +from jax._src.partition_spec import PartitionSpec +from jax._src.sharding import Sharding as JSharding +from jax._src.sharding_impls import ( AUTO, NamedSharding, + SdyArray, SdyArrayList, + modify_sdy_sharding_wrt_axis_types) from jax._src.state.types import AbstractRef +from jax._src.typing import ArrayLike +from jax._src.util import foreach +import numpy as np + +USE_NEW_TPU_CALLBACK_LOWERING = jaxlib_version >= (0, 9, 1) # mypy: ignore-errors @@ -73,33 +79,14 @@ # mypy implicitly sets this variable to true when type checking. MYPY = False -_JAX_DUMP_IR_TO = config.string_flag( - 'jax_dump_ir_to', os.getenv('JAX_DUMP_IR_TO', ''), - help="Path to which the IR that is emitted by JAX should be dumped as " - "text files. If omitted, JAX will not dump IR. " - "Supports the special value 'sponge' to pick the path from the " - "environment variable TEST_UNDECLARED_OUTPUTS_DIR.") - -_JAX_INCLUDE_DEBUG_INFO_IN_DUMPS = config.string_flag( - 'jax_include_debug_info_in_dumps', - os.getenv('JAX_INCLUDE_DEBUG_INFO_IN_DUMPS', "True"), - help="Determine whether or not to keep debug symbols and location information " - "when dumping IR code. By default, debug information will be preserved in " - "the IR dump. To avoid exposing source code and potentially sensitive " - "information, set to false") -lowerable_effects: effects_lib.EffectTypeSet = effects_lib.lowerable_effects - - # IR Helpers IrValues = Union[ir.Value, tuple[ir.Value, ...]] def _is_not_block_argument(x: IrValues) -> bool: - """Returns true if `x` is not a block argument.""" return not isinstance(x, ir.BlockArgument) - def dense_int_elements(xs) -> ir.DenseIntElementsAttr: return ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64)) @@ -107,16 +94,9 @@ def dense_int_elements(xs) -> ir.DenseIntElementsAttr: def dense_bool_elements(xs: Sequence[bool]) -> ir.DenseElementsAttr: a = np.packbits(np.array(xs, np.bool_), bitorder='little') - # TODO(b/209005197): Work around for MLIR crash for non-splat single element - # buffers. - if len(xs) == 1: - a = np.array(0 if a.item() == 0 else 0xff, np.uint8) return ir.DenseElementsAttr.get( a, type=ir.IntegerType.get_signless(1), shape=[len(xs)]) -def dense_bool_array(xs: Sequence[bool]) -> ir.DenseBoolArrayAttr: - return ir.DenseBoolArrayAttr.get(xs) - def i32_attr(i): return ir.IntegerAttr.get(ir.IntegerType.get_signless(32), i) def i64_attr(i): return ir.IntegerAttr.get(ir.IntegerType.get_signless(64), i) @@ -185,24 +165,14 @@ def _is_ir_values(x: IrValues) -> bool: np.dtype(np.float64): ir.F64Type.get, np.dtype(np.complex64): lambda: ir.ComplexType.get(ir.F32Type.get()), np.dtype(np.complex128): lambda: ir.ComplexType.get(ir.F64Type.get()), + np.dtype(dtypes.int2): partial(ir.IntegerType.get_signless, 2), + np.dtype(dtypes.uint2): partial(ir.IntegerType.get_unsigned, 2), + np.dtype(dtypes.float8_e3m4): ir.Float8E3M4Type.get, + np.dtype(dtypes.float8_e4m3): ir.Float8E4M3Type.get, + np.dtype(dtypes.float8_e8m0fnu): ir.Float8E8M0FNUType.get, + np.dtype(dtypes.float4_e2m1fn): ir.Float4E2M1FNType.get, } - -if dtypes.int2 is not None: - assert dtypes.uint2 is not None - _dtype_to_ir_type[np.dtype(dtypes.int2)] = partial(ir.IntegerType.get_signless, 2) - _dtype_to_ir_type[np.dtype(dtypes.uint2)] = partial(ir.IntegerType.get_unsigned, 2) - -if dtypes.float8_e3m4 is not None: - _dtype_to_ir_type[np.dtype(dtypes.float8_e3m4)] = ir.Float8E3M4Type.get -if dtypes.float8_e4m3 is not None: - _dtype_to_ir_type[np.dtype(dtypes.float8_e4m3)] = ir.Float8E4M3Type.get -if dtypes.float8_e8m0fnu is not None: - _dtype_to_ir_type[np.dtype(dtypes.float8_e8m0fnu)] = ir.Float8E8M0FNUType.get - -if dtypes.float4_e2m1fn is not None: - _dtype_to_ir_type[np.dtype(dtypes.float4_e2m1fn)] = ir.Float4E2M1FNType.get - def dtype_to_ir_type(dtype: core.bint | np.dtype | np.generic) -> ir.Type: if isinstance(dtype, core.bint): # TODO Support different-size underlying dtypes to take advantage of the @@ -217,7 +187,7 @@ def dtype_to_ir_type(dtype: core.bint | np.dtype | np.generic) -> ir.Type: f"No dtype_to_ir_type handler for dtype: {dtype}") from err return ir_type_factory() -def _array_ir_types(aval: core.ShapedArray | core.DShapedArray) -> ir.Type: +def _array_ir_types(aval: core.ShapedArray) -> ir.Type: aval = core.physical_aval(aval) # type: ignore if not core.is_constant_shape(aval.shape): return _dynamic_array_ir_types(aval) # type: ignore @@ -242,7 +212,6 @@ def aval_to_ir_type(aval: core.AbstractValue) -> IrTypes: ir_type_handlers[core.ShapedArray] = _array_ir_types ir_type_handlers[core.AbstractToken] = lambda _: hlo.TokenType.get() -ir_type_handlers[core.DShapedArray] = _dynamic_array_ir_types # This is a backwards compatibility shim for external users of jax.mlir apis. def aval_to_ir_types(aval: core.AbstractValue) -> tuple[ir.Type, ...]: @@ -252,7 +221,7 @@ def aval_to_ir_types(aval: core.AbstractValue) -> tuple[ir.Type, ...]: # Constants class ConstantHandler(Protocol): - def __call__(self, val: Any) -> IrValues: + def __call__(self, val: Any, aval: core.AbstractValue | None) -> IrValues: """Builds an IR representation for a constant `val`. A JAX value is represented by zero or more IR values.""" @@ -265,19 +234,32 @@ def register_constant_handler(type_: type, handler_fun: ConstantHandler): def get_constant_handler(type_: type) -> ConstantHandler: return _constant_handlers[type_] -def ir_constant(val: Any) -> IrValues: - """Translate a Python `val` to an IR constant, canonicalizing its dtype. +def ir_constant( + val: Any, *, + const_lowering: dict[tuple[int, core.AbstractValue], IrValues] | None = None, + aval: core.AbstractValue | None = None +) -> IrValues: + """Translate a Python `val` to an IR constant. + See https://docs.jax.dev/en/latest/internals/constants.html. Args: val: a Python value to be translated to a constant. + const_lowering: an optional dictionary with known lowering for some + constants, indexed by `id`. This is used, e.g., when we pass constants + as MLIR function arguments. + aval: the abstract value of `val`, if known. Required where ambiguous, e.g. + for Python scalars. Returns: A representation of the constant as an IR value or sequence of IR values. """ + if const_lowering is not None: + if np.shape(val) and (c_val := const_lowering.get((id(val), aval))) is not None: + return c_val for t in type(val).__mro__: handler = _constant_handlers.get(t) if handler: - out = handler(val) + out = handler(val, aval) assert _is_ir_values(out), (type(val), out) return out if hasattr(val, '__jax_array__'): @@ -300,7 +282,15 @@ def _masked_array_constant_handler(*args, **kwargs): register_constant_handler(np.ma.MaskedArray, _masked_array_constant_handler) -def _ndarray_constant_handler(val: np.ndarray | np.generic) -> IrValues: +def _shape_dtype_struct_constant_handler(*args, **kwargs): + raise TypeError("A ShapeDtypeStruct does not have a value and cannot be " + "used as a constant in a JAX function.") + +register_constant_handler(core.ShapeDtypeStruct, + _shape_dtype_struct_constant_handler) + +def _ndarray_constant_handler(val: np.ndarray | np.generic, + aval: core.AbstractValue | None) -> IrValues: """Constant handler for ndarray literals, handling zero-size strides. In most cases this function calls _numpy_array_constant(val) except it has @@ -334,6 +324,7 @@ def _ndarray_constant_handler(val: np.ndarray | np.generic) -> IrValues: return _numpy_array_constant(val) register_constant_handler(np.ndarray, _ndarray_constant_handler) +register_constant_handler(literals.TypedNdArray, _ndarray_constant_handler) for _scalar_type in [np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, @@ -342,13 +333,15 @@ def _ndarray_constant_handler(val: np.ndarray | np.generic) -> IrValues: np.bool_, np.longlong, dtypes.bfloat16]: register_constant_handler(_scalar_type, _ndarray_constant_handler) # type: ignore -def _python_scalar_handler(dtype, val): - return _numpy_array_constant(np.array(val, dtype)) +def _python_scalar_handler(val, aval: core.AbstractValue | None): + assert isinstance(aval, core.ShapedArray), aval + assert aval.shape == (), aval + return _numpy_array_constant(np.array(val, aval.dtype)) -for ptype, dtype in dtypes.python_scalar_dtypes.items(): - register_constant_handler(ptype, partial(_python_scalar_handler, dtype)) +for ptype in dtypes.python_scalar_types: + register_constant_handler(ptype, _python_scalar_handler) -def _token_constant_handler(val): +def _token_constant_handler(val: core.Token, aval: core.AbstractValue | None): return hlo.create_token() register_constant_handler(core.Token, _token_constant_handler) @@ -373,11 +366,11 @@ def _numpy_scalar_attribute(val: Any) -> ir.Attribute: raise TypeError(f"Unsupported scalar attribute type: {type(val)}") def _numpy_array_attribute(x: np.ndarray | np.generic) -> ir.Attribute: + element_type = dtype_to_ir_type(x.dtype) shape = x.shape if x.dtype == np.bool_: x = np.packbits(x, bitorder='little') # type: ignore x = np.ascontiguousarray(x) - element_type = dtype_to_ir_type(x.dtype) return ir.DenseElementsAttr.get(x, type=element_type, shape=shape) # type: ignore def _numpy_array_attribute_handler(val: np.ndarray | np.generic) -> ir.Attribute: @@ -392,6 +385,8 @@ def _numpy_array_attribute_handler(val: np.ndarray | np.generic) -> ir.Attribute return _numpy_array_attribute(val) register_attribute_handler(np.ndarray, _numpy_array_attribute_handler) +register_attribute_handler(hashable_array.HashableArray, + lambda x: _numpy_array_attribute_handler(x.val)) for _scalar_type in [np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, @@ -409,7 +404,7 @@ def _dtype_attribute_handler(dtype: np.dtype | np.generic) -> ir.Attribute: def _python_scalar_attribute_handler(dtype, val): return _numpy_scalar_attribute(np.array(val, dtype)) -for ptype, dtype in dtypes.python_scalar_dtypes.items(): +for ptype, dtype in dtypes.python_scalar_types_to_dtypes.items(): register_attribute_handler( ptype, partial(_python_scalar_attribute_handler, dtype)) @@ -454,81 +449,40 @@ def get_canonical_source_file(file_name: str, caches: TracebackCaches) -> str: caches.canonical_name_cache[file_name] = file_name return file_name -def _is_user_file(ctx: ModuleContext, file_name: str) -> bool: - is_user = ctx.traceback_caches.is_user_file_cache.get(file_name, None) - if is_user is not None: - return is_user - out = source_info_util.is_user_filename(file_name) - ctx.traceback_caches.is_user_file_cache[file_name] = out - return out - def _traceback_to_location(ctx: ModuleContext, tb: xc.Traceback) -> ir.Location: """Converts a full traceback to a callsite() MLIR location.""" - loc = ctx.traceback_caches.traceback_cache.get(tb, None) - if loc is not None: - return loc + return ctx.traceback_caches.traceback_to_location_cache.get(tb) - frame_locs = [] - frames_limit = config.traceback_in_locations_limit.value - frames_limit = frames_limit if frames_limit >= 0 else 1000 - - codes, lastis = tb.raw_frames() - for i, code in enumerate(codes): - if not _is_user_file(ctx, code.co_filename): - continue - - lasti = lastis[i] - code_lasti = code, lasti - loc = ctx.traceback_caches.location_cache.get(code_lasti, None) - if loc is None: - frame = source_info_util.raw_frame_to_frame(code, lasti) - file_loc = ir.Location.file( - get_canonical_source_file(frame.file_name, ctx.traceback_caches), - frame.start_line, - frame.start_column, - frame.end_line, - frame.end_column, - ) - loc = ir.Location.name(frame.function_name, childLoc=file_loc) - ctx.traceback_caches.location_cache[code_lasti] = loc - frame_locs.append(loc) - if len(frame_locs) >= frames_limit: - break - - n = len(frame_locs) - if n == 0: - loc = ir.Location.unknown() - elif n == 1: - loc = frame_locs[0] - else: - loc = ir.Location.callsite(frame_locs[0], frame_locs[1:]) - ctx.traceback_caches.traceback_cache[tb] = loc - return loc - -def _source_info_to_location( - ctx: ModuleContext, primitive: core.Primitive, - source_info: source_info_util.SourceInfo) -> ir.Location: - eqn_str = f'{source_info.name_stack}/{primitive.name}' +def source_info_to_location( + ctx: ModuleContext, primitive: core.Primitive | None, + name_stack: source_info_util.NameStack, + traceback: xc.Traceback | None) -> ir.Location: if config.include_full_tracebacks_in_locations.value: - if source_info.traceback is None: + if traceback is None: loc = ir.Location.unknown() else: - loc = _traceback_to_location(ctx, source_info.traceback) + loc = _traceback_to_location(ctx, traceback) else: - frame = source_info_util.user_frame(source_info) + frame = source_info_util.user_frame(traceback) if frame is None: loc = ir.Location.unknown() else: loc = ir.Location.file(get_canonical_source_file(frame.file_name, ctx.traceback_caches), frame.start_line, frame.start_column) - loc = ir.Location.name(eqn_str, childLoc=loc) - # TODO(phawkins): also include primitive.name as the operator type. + if primitive is None: + if name_stack.stack: + loc = ir.Location.name(str(name_stack), childLoc=loc) + else: + eqn_str = ( + f"{name_stack}/{primitive.name}" if name_stack.stack else primitive.name + ) + loc = ir.Location.name(eqn_str, childLoc=loc) + loc = ir.Location.name(f"{primitive.name}:", childLoc=loc) return loc upstream_dialects = ir.DialectRegistry() -if register_jax_dialects: - register_jax_dialects.register_dialects(upstream_dialects) +jax_mlir_ext.register_dialects(upstream_dialects) # Dumping MLIR modules _ir_dump_counter = itertools.count() @@ -547,23 +501,17 @@ def dump_module_to_file(module: ir.Module, stage_name: str) -> str | None: The name of the file containing the dump if JAX_DUMP_IR_TO is defined and the module was dumped, `None` otherwise. """ - out_dir_name = _JAX_DUMP_IR_TO.value - if not out_dir_name: + if not (out_dir := path.make_jax_dump_dir(config.jax_dump_ir_to.value)): + return None + modes = config.jax_dump_ir_modes.value.split(',') + if 'stablehlo' not in modes: return None - if out_dir_name == "sponge": - out_dir_name = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR", "") - if not out_dir_name: - raise ValueError("JAX_DUMP_IR_TO='sponge' but " - "TEST_UNDECLARED_OUTPUTS_DIR is not defined") - id = next(_ir_dump_counter) sym_name = module.operation.attributes['sym_name'] module_name = ir.StringAttr(sym_name).value name = f"jax_ir{id:04d}_{_make_string_safe_for_filename(module_name)}_{stage_name}.mlir" - out_dir = path.Path(out_dir_name) - out_dir.mkdir(parents=True, exist_ok=True) full_path = out_dir / name full_path.write_text(module_to_string(module)) return name @@ -581,7 +529,7 @@ def _make_string_safe_for_filename(s: str) -> str: def module_to_string(module: ir.Module, enable_debug_info=None) -> str: output = io.StringIO() if enable_debug_info is None: - enable_debug_flag = str.lower(_JAX_INCLUDE_DEBUG_INFO_IN_DUMPS.value) + enable_debug_flag = str(config.jax_include_debug_info_in_dumps.value).lower() enable_debug_info = enable_debug_flag not in ('false', '0') module.operation.print(file=output, enable_debug_info=enable_debug_info) return output.getvalue() @@ -593,6 +541,11 @@ def module_to_bytecode(module: ir.Module) -> bytes: # Translation rules +# Create one global thread pool that can be shared between multiple ir.Contexts +# and enabling multi-threading +global_thread_pool = ir.ThreadPool() + + class JaxIrContext(ir.Context): def __init__(self, *args, **kwargs): # Note: we're very intentionally *not* calling the __init__() of our @@ -607,18 +560,18 @@ def make_ir_context() -> ir.Context: context.append_dialect_registry(upstream_dialects) context.load_all_available_dialects() - # If threading is enabled, each MLIR context will keep alive a thread pool. - # Since we cache MLIR modules (and hence contexts), this means we might keep - # several threads alive for each cache entry. This is a terrible idea. However - # we don't do any heavy computation on MLIR modules from Python anyway, so we - # just disable threading. - context.enable_multithreading(False) - # TODO(bartchr): Once JAX is released with SDY, remove the if. - if dialects.sdy: - dialects.sdy.register_dialect(context) + context.set_thread_pool(global_thread_pool) + dialects.sdy.register_dialect(context) + dialects.mpmd.register_dialect(context) dialects.mhlo.register_mhlo_dialect(context) dialects.chlo.register_dialect(context) dialects.hlo.register_dialect(context) + # If built in debug mode, and MLIR is in a multithreaded context, enabling + # multi threaded execution aborts the process if we try to register a new + # dialect after this point. The dialect registry in a context is not thread + # safe, and a fatal error is much better than a data race. + # if jaxlib_version >= (0, 8): + # jax_mlir_ext.enter_multi_threaded_execution(context) return context @@ -662,7 +615,7 @@ def __init__(self, @dataclasses.dataclass(frozen=True) class LoweringParameters: # A mapping between primitives and user-defined LoweringRules. - # When lowering a primitive, give priorioty to the rule in this map over + # When lowering a primitive, give priority to the rule in this map over # existing Jax rules. override_lowering_rules: tuple[tuple[core.Primitive, LoweringRule]] | None = None @@ -674,25 +627,57 @@ class LoweringParameters: global_constant_computation: bool = False # Signals that we are lowering for exporting. - for_export: bool = False - # See usage in https://jax.readthedocs.io/en/latest/export/export.html#ensuring-forward-and-backward-compatibility + # See usage in https://docs.jax.dev/en/latest/export/export.html#ensuring-forward-and-backward-compatibility # We have this here to ensure it is reflected in the cache keys export_ignore_forward_compatibility: bool = False + # During lowering hoist the core.Literal constants as args for the main MLIR + # function and all the intermediate functions that need them. + # See https://docs.jax.dev/en/latest/internals/constants.html + # TODO(necula): perhaps we can use `for_export` instead of this additional + # field. + hoist_constants_as_args: bool = config.use_simplified_jaxpr_constants.value + + +def _code_to_filename(code: types.CodeType) -> str | None: + """Returns the canonicalized filename of a code object. + Returns None if the filename should be omitted in tracebacks. + """ + if not source_info_util.is_user_filename(code.co_filename): + return None + pattern = config.hlo_source_file_canonicalization_regex.value + return re.sub(pattern, '', code.co_filename) if pattern else code.co_filename @dataclasses.dataclass class TracebackCaches: - traceback_cache: dict[xc.Traceback, ir.Location] - location_cache: dict[tuple[types.CodeType, int], ir.Location] + traceback_to_location_cache: Any # jax_mlir_ext.TracebackToLocationCache canonical_name_cache: dict[str, str] - is_user_file_cache: dict[str, bool] def __init__(self): - self.traceback_cache = {} - self.location_cache = {} + frame_limit = config.traceback_in_locations_limit.value + frame_limit = frame_limit if frame_limit >= 0 else 1000 + self.traceback_to_location_cache = jax_mlir_ext.TracebackToLocationCache( + code_to_filename=_code_to_filename, frame_limit=frame_limit) self.canonical_name_cache = {} - self.is_user_file_cache = {} + + +@dataclasses.dataclass(frozen=True) +class LoweringCacheKey: + primitive: core.Primitive + eqn_ctx: core.JaxprEqnContext + avals_in: tuple[core.AbstractValue, ...] + effects: effects_lib.Effects + params: frozen_dict.FrozenDict[str, Any] + platforms: tuple[str, ...] + +@dataclasses.dataclass(frozen=True) +class LoweringCacheValue: + func: func_dialect.FuncOp + output_types: Sequence[IrTypes] + const_args: Sequence[ArrayLike] # The hoisted constants expected by `func` + const_arg_avals: Sequence[core.AbstractValue] + inline: bool # Inline calls to this lowered function? @dataclasses.dataclass class ModuleContext: @@ -705,7 +690,7 @@ class ModuleContext: # exporting. platforms: Sequence[str] # See ModuleContext.get_backend() for backend and platforms usage. - backend: xb.XlaBackend | None + backend: xc.Client | None axis_context: AxisContext keepalives: list[Any] channel_iterator: Iterator[int] @@ -715,9 +700,10 @@ class ModuleContext: all_default_mem_kind: bool # Cached primitive lowerings. + lowering_cache: dict[LoweringCacheKey, LoweringCacheValue] cached_primitive_lowerings: dict[Any, func_dialect.FuncOp] - # Cached traceback infromation. + # Cached traceback information. traceback_caches: TracebackCaches lowering_parameters: LoweringParameters @@ -730,7 +716,7 @@ def __init__( self, *, platforms: Sequence[str], - backend: xb.XlaBackend | None, + backend: xc.Client | None, axis_context: AxisContext, keepalives: list[Any], channel_iterator: Iterator[int], @@ -740,7 +726,8 @@ def __init__( module: ir.Module | None = None, ip: ir.InsertionPoint | None = None, symbol_table: ir.SymbolTable | None = None, - cached_primitive_lowerings: None | (dict[Any, func_dialect.FuncOp]) = None, + lowering_cache: None | dict[LoweringCacheKey, Any] = None, + cached_primitive_lowerings: None | dict[Any, func_dialect.FuncOp] = None, traceback_caches: None | TracebackCaches = None, shape_poly_state = None, all_default_mem_kind: bool = True): @@ -752,10 +739,12 @@ def __init__( self.backend = backend self.platforms = platforms self.axis_context = axis_context + self.lowering_cache = ({} if lowering_cache is None else lowering_cache) self.cached_primitive_lowerings = ({} if cached_primitive_lowerings is None else cached_primitive_lowerings) - self.traceback_caches = (TracebackCaches() if traceback_caches is None - else traceback_caches) + with self.context: + self.traceback_caches = (TracebackCaches() if traceback_caches is None + else traceback_caches) self.channel_iterator = channel_iterator self.keepalives = keepalives self.host_callbacks = host_callbacks @@ -764,14 +753,18 @@ def __init__( self.all_default_mem_kind = all_default_mem_kind self.lowering_parameters = lowering_parameters - def get_backend(self) -> xb.XlaBackend: + def get_backend(self, optional: bool = False) -> xc.Client | None: if len(self.platforms) > 1: + if optional: + return None raise NotImplementedError( "accessing .backend in multi-lowering setting. This can occur when " "lowering a primitive that has not been adapted to multi-platform " "lowering") if self.backend is not None: if xb.canonicalize_platform(self.backend.platform) != self.platforms[0]: + if optional: + return None raise ValueError( "the platform for the specified backend " f"{xb.canonicalize_platform(self.backend.platform)} is different " @@ -809,12 +802,23 @@ def replace(self, **kw): return dataclasses.replace(self, **kw) class LoweringRuleContext: """Per-rule context information for MLIR lowering.""" module_context: ModuleContext + # Even though we assigned name_stack entries to each jaxpr equation during + # tracing, we need to propagate name stacks during lowering as well because + # lowering may effectively inline multiple jaxprs into a single HLO function. + # For example, the body of a while loop needs the name stack of the enclosing + # while instruction to be prepended when forming its HLO name. name_stack: source_info_util.NameStack + traceback: xc.Traceback | None primitive: core.Primitive | None avals_in: Sequence[core.AbstractValue] avals_out: Any # Usually Sequence[core.AbstractValue], but sometimes None. tokens_in: TokenSet tokens_out: TokenSet | None # Mutable store for output containers + # The values tobe used for the Literal constants, by id of the const. + # This is used to implement passing along the constants that have been + # hoisted as main function arguments down to where they are used. + # See https://docs.jax.dev/en/latest/internals/constants.html + const_lowering: dict[tuple[int, core.AbstractValue], IrValues] axis_size_env: dict[core.Var, ir.Value] | None = None # Dynamic axis sizes # The values for the dimension variables in same order as # module_context.shape_poly_state.dim_vars @@ -834,11 +838,18 @@ def is_forward_compat(self) -> bool: """Returns true if the lowering parameters are in forward compatibility mode. """ lowering_parameters = self.module_context.lowering_parameters - return ( - lowering_parameters.for_export - and not lowering_parameters.export_ignore_forward_compatibility + + check_platforms: Sequence[str] = ( + self.platforms or self.module_context.platforms + ) + force_forward_compat = any( + p in xb.FORCE_FORWARD_COMPAT_LOWERING_PLATFORMS for p in check_platforms ) + return ( + lowering_parameters.for_export or force_forward_compat + ) and not lowering_parameters.export_ignore_forward_compatibility + if not MYPY: class LoweringRule(Protocol): @@ -849,14 +860,38 @@ def __call__(self, ctx: LoweringRuleContext, else: LoweringRule = Any -_lowerings: dict[core.Primitive, LoweringRule] = {} -_platform_specific_lowerings: dict[str, dict[core.Primitive, LoweringRule]] +@dataclasses.dataclass(frozen=True) +class LoweringRuleEntry: + rule: LoweringRule + inline: bool + +_lowerings: dict[core.Primitive, LoweringRuleEntry] = {} +_platform_specific_lowerings: dict[str, dict[core.Primitive, LoweringRuleEntry]] _platform_specific_lowerings = collections.defaultdict(dict) def register_lowering(prim: core.Primitive, rule: LoweringRule, - platform: str | None = None): + platform: str | None = None, inline: bool = True, + cacheable: bool = True) -> None: + """Registers a lowering rule for a primitive. + + Args: + prim: The primitive to register the rule for. + rule: The lowering rule to register. + platform: The platform to register the rule for. If None, this is a common + rule applicable to all platforms. Platform-specific rules take precedence + over common rules. + inline: Whether to emit the lowering inline. If False, the lowering will be + emitted in a separate function, called by similar instances of the + lowering. + uncacheable: Whether this primitive's lowering can be cached. This is a + temporary flag that will be removed after primitives that have problems + with caching are fixed. + """ + assert not isinstance(rule, LoweringRuleEntry) + if not cacheable: + _uncacheable_primitives.add(prim) if platform is None: - _lowerings[prim] = rule + _lowerings[prim] = LoweringRuleEntry(rule, inline) else: if not xb.is_known_platform(platform): known_platforms = sorted(xb.known_platforms()) @@ -869,8 +904,7 @@ def register_lowering(prim: core.Primitive, rule: LoweringRule, # TODO(phawkins): fix up users to specify either "cuda" or "rocm" and remove # this expansion. for p in xb.expand_platform_alias(platform): - _platform_specific_lowerings[p][prim] = rule - return rule + _platform_specific_lowerings[p][prim] = LoweringRuleEntry(rule, inline) def flatten_ir_values(xs: Iterable[IrValues]) -> list[ir.Value]: @@ -941,27 +975,23 @@ def sharded_aval(aval: core.AbstractValue, return aval if isinstance(aval, core.AbstractToken): return aval - if not isinstance(aval, (core.ShapedArray, core.DShapedArray)): + if not isinstance(aval, core.ShapedArray): raise NotImplementedError return aval.update(sharding.shard_shape(aval.shape), sharding=None) # type: ignore def eval_dynamic_shape(ctx: LoweringRuleContext, shape: core.Shape) -> tuple[int | Value, ...]: - if config.dynamic_shapes.value: - assert ctx.axis_size_env is not None - return tuple(ctx.axis_size_env.get(d, d) for d in shape) # type: ignore - else: - ctx = ctx.replace( - primitive="eval_dynamic_shape", - avals_in=[core.dim_value_aval()] * len(ctx.module_context.shape_poly_state.dim_vars), - tokens_out=None) + ctx = ctx.replace( + primitive="eval_dynamic_shape", + avals_in=[core.dim_value_aval()] * len(ctx.module_context.shape_poly_state.dim_vars), + tokens_out=None) - res = lower_fun( - partial(core.evaluate_shape, shape, ctx.module_context.shape_poly_state.dim_vars), - multiple_results=True)(ctx, *ctx.dim_var_values) - return tuple(operator.index(d) if core.is_constant_dim(d) else d_ir - for d, d_ir in zip(shape, flatten_ir_values(res))) + res = lower_fun( + partial(core.evaluate_shape, shape, ctx.module_context.shape_poly_state.dim_vars), + multiple_results=True)(ctx, *ctx.dim_var_values) + return tuple(operator.index(d) if core.is_constant_dim(d) else d_ir + for d, d_ir in zip(shape, flatten_ir_values(res))) # TODO: replace usage of eval_dynamic_shape_as_vals with eval_dynamic_shape_as_ivals def eval_dynamic_shape_as_vals(ctx: LoweringRuleContext, @@ -1010,24 +1040,35 @@ class LoweringResult(NamedTuple): def add_manual_axes(axis_ctx: sharding_impls.SPMDAxisContext, sharding, ndim): - mesh = axis_ctx.mesh + mesh = axis_ctx.mesh.abstract_mesh + sharding_mesh = sharding.mesh.abstract_mesh if (isinstance(sharding, sharding_impls.NamedSharding) and - sharding.mesh.shape == mesh.shape): - return sharding_impls.NamedSharding( - sharding.mesh, sharding.spec, memory_kind=sharding.memory_kind, - _manual_axes=axis_ctx.manual_axes) + sharding_mesh.shape == mesh.shape): + out_mesh, spec = sharding_mesh, sharding.spec else: - spec = sharding_impls.parse_flatten_op_sharding( + out_mesh, spec = mesh, sharding_impls.parse_flatten_op_sharding( sharding._to_xla_hlo_sharding(ndim), mesh)[0] - return sharding_impls.NamedSharding( - mesh, spec, memory_kind=sharding.memory_kind, - _manual_axes=axis_ctx.manual_axes) + + out_mesh = out_mesh.update_axis_types( + {a: AxisType.Manual for a in axis_ctx.manual_axes}) + out = sharding_impls.NamedSharding(out_mesh, spec, + memory_kind=sharding.memory_kind) + manual_axes = out.mesh.manual_axes + if any(p in manual_axes for s in out.spec + if s is not None and s is not PartitionSpec.UNCONSTRAINED + for p in (s if isinstance(s, tuple) else (s,))): + raise ValueError( + f'pspec {out.spec} contains a manual axes {manual_axes} of mesh' + f' which is not allowed. If you are using a' + ' with_sharding_constraint under a shard_map, only use the' + ' mesh axis in PartitionSpec which are not manual.') + return out def _to_physical_op_sharding( ctx: ModuleContext, aval: core.AbstractValue, sharding: JSharding | AUTO | None, -) -> xc.OpSharding | SdyArraySharding | None: +) -> xc.OpSharding | SdyArray | None: if sharding is None: return None if all_unconstrained(sharding, aval): @@ -1039,7 +1080,7 @@ def _to_physical_op_sharding( assert isinstance(sharding, JSharding) if isinstance(aval, AbstractRef): return _to_physical_op_sharding(ctx, aval.inner_aval, sharding) - assert isinstance(aval, (core.ShapedArray, core.DShapedArray)) + assert isinstance(aval, core.ShapedArray) if dtypes.issubdtype(aval.dtype, dtypes.extended): sharding = sharding_impls.physical_sharding(aval, sharding) aval = core.physical_aval(aval) @@ -1052,7 +1093,7 @@ def _to_physical_op_sharding( return sharding._to_xla_hlo_sharding(aval.ndim).to_proto() # type: ignore -def _to_xla_layout(layout: DeviceLocalLayout | None | AutoLayout, +def _to_xla_layout(layout: Layout | None | AutoLayout, aval: core.AbstractValue) -> str | None: if layout is None: return None @@ -1079,6 +1120,8 @@ def contains_unconstrained(s): def all_unconstrained(s, aval): if isinstance(s, NamedSharding): + if aval.ndim == 0: + return False if aval.ndim != len(s.spec): return False return all(p is PartitionSpec.UNCONSTRAINED for p in s.spec) @@ -1087,33 +1130,83 @@ def all_unconstrained(s, aval): class UnconstrainedVariants(NamedTuple): contains_unconstrained: bool all_unconstrained: bool - unconstrained_dims: set[int] | None def _get_unconstrained_variants(s, aval) -> UnconstrainedVariants: us = contains_unconstrained(s) - unconstrained_dims = ({i for i, p in enumerate(s.spec) - if p is PartitionSpec.UNCONSTRAINED} if us else None) return UnconstrainedVariants( - contains_unconstrained=us, all_unconstrained=all_unconstrained(s, aval), - unconstrained_dims=unconstrained_dims) + contains_unconstrained=us, all_unconstrained=all_unconstrained(s, aval)) + + +def check_jaxpr_constants(closed_jaxpr: core.ClosedJaxpr): + """Check if a JAXPR contains an excessive amount of constants, if so, report where they were captured""" + if config.use_simplified_jaxpr_constants.value: + return + if (threshold := config.captured_constants_warn_bytes.value) == -1: + return + + # need the unaesthetic getter here as some of the consts in the test suite are arbitrary objects + total_iter, nbytes_iter = itertools.tee( + map(lambda c: getattr(c, "nbytes", 0), closed_jaxpr.consts) + ) + + if (total_bytes := sum(total_iter)) < threshold: + return + + message = ( + "A large amount of constants were captured during lowering" + f" ({util.pprint_bytes(total_bytes)} total). If this is intentional," + " disable this warning by setting JAX_CAPTURED_CONSTANTS_WARN_BYTES=-1. " + ) + + if not (num_frames := config.captured_constants_report_frames.value): + message += ( + "To obtain a report of where these constants were encountered, " + "set JAX_CAPTURED_CONSTANTS_REPORT_FRAMES=-1." + ) + warnings.warn(message) + return + + message += ( + "The subsequent report may be disabled by setting JAX_CAPTURED_CONSTANTS_REPORT_FRAMES=0.\n\n" + f"Largest {min(num_frames, len(closed_jaxpr.consts))} allocation(s):\n" + ) + try: + nbytes_var_const = zip(nbytes_iter, closed_jaxpr.jaxpr.constvars, closed_jaxpr.consts) + for nbytes, var, const in heapq.nlargest(5, nbytes_var_const, key=operator.itemgetter(0)): + message += f" Constant {type(const)}, {var.aval.str_short()}, {util.pprint_bytes(nbytes)} captured at:\n" + for eqn in jaxpr_util.eqns_using_var(closed_jaxpr.jaxpr, var): + call_frame_source_info = source_info_util.summarize(eqn.source_info, num_frames) + message += " " * 2 + call_frame_source_info.replace("\n", "\n" + " " * 2) + "\n\n" + + warnings.warn(message) + except Exception as exc: + warnings.warn(message + f" Exception raised while generating report: {exc}") + +# TODO(phawkins): it is my firm belief that: +# a) channel IDs have only a vestigal function when applied to collectives, and +# b) their identity does not matter. The presence or absence of a channel +# changes whether XLA considers collectives to be inter-replica or +# inter-partition, but beyond that we believe they have little effect. +COLLECTIVE_CHANNEL_ID = 1 def lower_jaxpr_to_module( module_name: str, jaxpr: core.ClosedJaxpr, *, + num_const_args: int, + in_avals: Sequence[core.AbstractValue], ordered_effects: list[core.Effect], # See ModuleContext.get_backend() for backend and platforms usage. platforms: Sequence[str], - backend: xb.XlaBackend | None, + backend: xc.Client | None, axis_context: AxisContext, - name_stack: source_info_util.NameStack, donated_args: Sequence[bool], replicated_args: Sequence[bool] | None = None, arg_shardings: Sequence[JSharding | AUTO | None] | None = None, result_shardings: Sequence[JSharding | AUTO | None] | None = None, - in_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None, - out_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None, + in_layouts: Sequence[Layout | None | AutoLayout] | None = None, + out_layouts: Sequence[Layout | None | AutoLayout] | None = None, arg_names: Sequence[str] | None = None, result_names: Sequence[str] | None = None, num_replicas: int = 1, @@ -1127,14 +1220,16 @@ def lower_jaxpr_to_module( Handles the quirks of the argument/return value passing conventions of the runtime. + The inputs already account for the constant arguments. + See https://docs.jax.dev/en/latest/internals/constants.html """ util.test_event("lower_jaxpr_to_module") platforms = tuple(map(xb.canonicalize_platform, platforms)) - in_avals = (jaxpr.in_avals if arg_shardings is None else - map(sharded_aval, jaxpr.in_avals, arg_shardings)) - out_avals = (jaxpr.out_avals if result_shardings is None else - map(sharded_aval, jaxpr.out_avals, result_shardings)) + sharded_in_avals = (in_avals if arg_shardings is None else + map(sharded_aval, in_avals, arg_shardings)) + sharded_out_avals = (jaxpr.out_avals if result_shardings is None else + map(sharded_aval, jaxpr.out_avals, result_shardings)) if all_default_mem_kind: arg_memory_kinds = None result_memory_kinds = None @@ -1156,7 +1251,7 @@ def lower_jaxpr_to_module( f"should support donation. Lowering for {platforms} of which " f"only {platforms_with_donation} support donation") input_output_aliases, donated_args, xla_donated_args = _set_up_aliases( - input_output_aliases, in_avals, out_avals, donated_args, + input_output_aliases, sharded_in_avals, sharded_out_avals, donated_args, arg_memory_kinds, result_memory_kinds, in_layouts, out_layouts, result_shardings if num_partitions > 1 else None) if (num_partitions > 1 and @@ -1170,8 +1265,8 @@ def lower_jaxpr_to_module( xla_donated_args[input_id] = True donated_args[input_id] = False if any(donated_args): - unused_donations = [str(a) for a, d in zip(in_avals, donated_args) if d] - msg = "See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation." + unused_donations = [str(a) for a, d in zip(sharded_in_avals, donated_args) if d] + msg = "See an explanation at https://docs.jax.dev/en/latest/faq.html#buffer-donation." if not platforms_with_donation: msg = f"Donation is not implemented for {platforms}.\n{msg}" if unused_donations: @@ -1180,25 +1275,22 @@ def lower_jaxpr_to_module( # Delete donated_args by default here, since it's not needed beyond this point del donated_args - unlowerable_effects = lowerable_effects.filter_not_in(jaxpr.effects) + unlowerable_effects = effects_lib.lowerable_effects.filter_not_in( + jaxpr.effects) if unlowerable_effects: raise ValueError(f'Cannot lower jaxpr with effects: {jaxpr.effects}') - # HLO channels need to start at 1 - channel_iter = itertools.count(1) + # HLO channels need to start at 1. We reserve 1 for collectives. + channel_iter = itertools.count(COLLECTIVE_CHANNEL_ID + 1) # Create a keepalives list that will be mutated during the lowering. keepalives: list[Any] = [] host_callbacks: list[Any] = [] + # Find the dimension variables + all_dim_poly = [d for aval in sharded_in_avals if hasattr(aval, "shape") + for d in aval.shape if not core.is_constant_dim(d)] + dim_vars = tuple(sorted(functools.reduce(lambda acc, new: acc.union(new._get_vars()), + all_dim_poly, set()))) - dim_vars: Sequence[str] - if not config.dynamic_shapes.value: - # Find the dimension variables - all_dim_poly = [d for aval in jaxpr.in_avals if hasattr(aval, "shape") - for d in aval.shape if not core.is_constant_dim(d)] - dim_vars = tuple(sorted(functools.reduce(lambda acc, new: acc.union(new._get_vars()), - all_dim_poly, set()))) - else: - dim_vars = () ctx = ModuleContext(backend=backend, platforms=platforms, axis_context=axis_context, @@ -1212,15 +1304,16 @@ def lower_jaxpr_to_module( # Remove module name characters that XLA would alter. This ensures that # XLA computation preserves the module name. attrs = ctx.module.operation.attributes - module_name = sanitize_name(module_name) - attrs["sym_name"] = ir.StringAttr.get(module_name) + attrs["sym_name"] = ir.StringAttr.get( + sanitize_name(module_name).rstrip("_")) attrs["mhlo.num_replicas"] = i32_attr(num_replicas) attrs["mhlo.num_partitions"] = i32_attr(num_partitions) lower_jaxpr_to_fun( - ctx, "main", jaxpr, ordered_effects, - name_stack=name_stack, - public=True, + ctx, module_name, jaxpr, ordered_effects, + num_const_args=num_const_args, + main_function=True, replicated_args=replicated_args, + in_avals=in_avals, arg_shardings=arg_shardings, result_shardings=result_shardings, input_output_aliases=input_output_aliases, @@ -1250,8 +1343,13 @@ def emit_diagnostic_info(d): raise ValueError("\n".join(msg_lines) + "\n" + dump_module_message(ctx.module, "verification")) from e - if config.use_shardy_partitioner.value: - with ctx.context: + with ctx.context: + # Cached lowering rule evaluation leaves dead functions. Remove them. + pipeline = passmanager.PassManager.parse( + 'builtin.module(symbol-dce)') + pipeline.run(ctx.module.operation) + + if config.use_shardy_partitioner.value: pipeline = passmanager.PassManager.parse( 'builtin.module(sdy-lift-inlined-meshes)') pipeline.run(ctx.module.operation) @@ -1260,7 +1358,6 @@ def emit_diagnostic_info(d): return LoweringResult(ctx.module, ctx.keepalives, ctx.host_callbacks, ctx.shape_poly_state) - def _set_up_aliases(input_output_aliases, avals_in, avals_out, donated_args, arg_memory_kinds, result_memory_kinds, in_layouts, out_layouts, result_shardings): @@ -1309,7 +1406,7 @@ def _set_up_aliases(input_output_aliases, avals_in, avals_out, raise ValueError( f"Input layout being donated was {in_layouts[input_id]} while" f" output layout was {out_layouts[i]}. Did you mean to set the" - " **output layout** to **DeviceLocalLayout.AUTO**?\nThis will" + " **output layout** to **Layout.AUTO**?\nThis will" " allow for the input and output layout to be chosen by XLA and" " not the layout of the output which might not be optimal.") if (in_out_layout_not_none and @@ -1318,7 +1415,7 @@ def _set_up_aliases(input_output_aliases, avals_in, avals_out, raise ValueError( f"Input layout being donated was {in_layouts[input_id]} while" f" output layout was {out_layouts[i]}. Did you mean to set the" - " **input layout** to **DeviceLocalLayout.AUTO**?\nThis will allow" + " **input layout** to **Layout.AUTO**?\nThis will allow" " for the input and output layout to be chosen by XLA and not the" " layout of the input which might not be optimal.") if (in_layouts is None or out_layouts is None or @@ -1333,6 +1430,33 @@ def _set_up_aliases(input_output_aliases, avals_in, avals_out, xla_donated_args = [False] * len(avals_in) xla_donated_args[input_id] = True + aliased_output_ids = {i for i in input_output_aliases if i is not None} + + results_not_matched = collections.defaultdict(collections.deque) + for i, (aval, rm) in enumerate(zip(avals_out, result_memory_kinds)): + if i not in aliased_output_ids and aval is not core.abstract_token: + results_not_matched[(aval.size, rm)].append(i) + + # For each donated argument that hasn't been aliased or donated to XLA, try to + # find an output array with matching size ignoring shapes. If a matching + # output array is found, then the argument is donated to XLA. + # Similar to the aliasing logic above, an argument is donated to XLA even if + # its layout and the output's layout don't match. This is being done to + # provide more opportunities for XLA to reuse the donated arguments. + for input_idx in range(len(out_donated_args)): + # If the argument is not a token and hasn't been aliased or donated to XLA, + # then try to find an output array with matching size. + if (out_donated_args[input_idx] + and avals_in[input_idx] is not core.abstract_token): + key = (avals_in[input_idx].size, arg_memory_kinds[input_idx]) + if results_not_matched.get(key, ()): + # XLA donate the argument because there's a matching output array. + results_not_matched[key].popleft() + out_donated_args[input_idx] = False + if xla_donated_args is None: + xla_donated_args = [False] * len(avals_in) + xla_donated_args[input_idx] = True + return input_output_aliases, out_donated_args, xla_donated_args Token = ir.Value @@ -1388,28 +1512,31 @@ def lower_jaxpr_to_fun( name: str, jaxpr: core.ClosedJaxpr, effects: Sequence[core.Effect], - name_stack: source_info_util.NameStack, *, - public: bool = False, + num_const_args: int, + main_function: bool = False, replicated_args: Sequence[bool] | None = None, + in_avals: Sequence[core.AbstractValue], arg_shardings: Sequence[JSharding | AUTO | None] | None = None, result_shardings: Sequence[JSharding | AUTO | None] | None = None, use_sharding_annotations: bool = True, input_output_aliases: Sequence[int | None] | None = None, xla_donated_args: Sequence[bool] | None = None, - api_name: str = "jit", arg_names: Sequence[str | None] | None = None, result_names: Sequence[str] | None = None, arg_memory_kinds: Sequence[str | None] | None = None, result_memory_kinds: Sequence[str | None] | None = None, - arg_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None, - result_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None, + arg_layouts: Sequence[Layout | None | AutoLayout] | None = None, + result_layouts: Sequence[Layout | None | AutoLayout] | None = None, propagated_out_mem_kinds: tuple[None | str, ...] | None = None, ) -> func_dialect.FuncOp: """Lowers jaxpr and its callees to an IR function. Assumes that an MLIR context, location, and insertion point are set. + Note: this function does *not* take a name stack. Name stacks do not cross + the boundaries of HLO functions. + Args: ctx: the lowering context. name: the function name. The name will be uniquified by the symbol table, @@ -1417,7 +1544,13 @@ def lower_jaxpr_to_fun( jaxpr: the jaxpr to lower. effects: a sequence of `core.Effect`s corresponding to an ordering of tokens that will be created in or used by the lowered function. - public: if true, the function's visibility is set to "public". + num_const_args: how many constant arguments is this function going to have. + See https://docs.jax.dev/en/latest/internals/constants.html + main_function: if true, this is the main function in the module. This has + several effects: + * the function's visibility is set to "public". + * the function's symbol name will be "main" + * the function's name will be used as the root name stack entry. replicated_args: if present, annotates arguments as replicated. arg_shardings: sharding annotations for each argument (optional). result_shardings: sharding annotations for each result (optional). @@ -1429,68 +1562,84 @@ def lower_jaxpr_to_fun( input_output_aliases: optional sequence that maps argument numbers to the corresponding output that should alias them. xla_donated_args: optional sequence of args to set donation annotations. - api_name: The name of the higher level primitive which should show up in the - name stack. Returns: MLIR func op """ util.test_event("lower_jaxpr_to_fun", name) + if not config.use_simplified_jaxpr_constants.value: + check_jaxpr_constants(jaxpr) + # The first dimension variable may be the platform index num_dim_vars = len(ctx.shape_poly_state.dim_vars) - dim_var_avals = [core.ShapedArray((), dtypes.canonicalize_dtype(np.int64))] * num_dim_vars + dim_var_avals = [core.ShapedArray((), dtypes.default_int_dtype())] * num_dim_vars dim_var_types = map(aval_to_ir_type, dim_var_avals) - # Function inputs: *dim_var_values, *tokens, *actual_inputs - input_types = map(aval_to_ir_type, jaxpr.in_avals) + nr_args = num_const_args + len(jaxpr.in_avals) + + assert nr_args == len(in_avals), (nr_args, in_avals) + assert replicated_args is None or nr_args == len(replicated_args), \ + (nr_args, replicated_args) + assert arg_shardings is None or nr_args == len(arg_shardings), \ + (nr_args, arg_shardings) + assert arg_layouts is None or nr_args == len(arg_layouts), \ + (nr_args, arg_layouts) + assert arg_memory_kinds is None or nr_args == len(arg_memory_kinds), \ + (nr_args, arg_memory_kinds) + assert arg_names is None or nr_args == len(arg_names), (nr_args, arg_names) + + # Function inputs: *dim_var_values, *tokens, *const_args, *actual_inputs + input_types = map(aval_to_ir_type, in_avals) output_types = map(aval_to_ir_type, jaxpr.out_avals) num_tokens = len(effects) token_types = [token_type() for _ in effects] token_avals = [core.abstract_token] * num_tokens - # Order of arguments: dim vars, tokens, array inputs - input_avals = dim_var_avals + token_avals + jaxpr.in_avals + # Order of arguments: dim vars, tokens, const_args, array inputs + input_avals = dim_var_avals + token_avals + list(in_avals) # type: ignore input_types = [*dim_var_types, *token_types, *input_types] output_avals = [core.abstract_token] * num_tokens + jaxpr.out_avals output_types = [*token_types, *output_types] if input_output_aliases is not None: - token_input_output_aliases = [None] * (num_dim_vars + num_tokens) - input_output_aliases = [*token_input_output_aliases, *input_output_aliases] + prefix_input_output_aliases = [None] * (num_dim_vars + num_tokens) + input_output_aliases = [*prefix_input_output_aliases, *input_output_aliases] # Update the existing aliases to account for the new output values input_output_aliases = [None if a is None else a + num_tokens for a in input_output_aliases] if arg_shardings is not None: - token_shardings = [None] * (num_dim_vars + num_tokens) - arg_shardings = [*token_shardings, *arg_shardings] + prefix_shardings = [None] * (num_dim_vars + num_tokens) + arg_shardings = [*prefix_shardings, *arg_shardings] if result_shardings is not None: token_shardings = [None] * num_tokens result_shardings = [*token_shardings, *result_shardings] if replicated_args is not None: - token_replicated_args = [False] * (num_dim_vars + num_tokens) - replicated_args = [*token_replicated_args, *replicated_args] + prefix_replicated_args = [False] * (num_dim_vars + num_tokens) + replicated_args = [*prefix_replicated_args, *replicated_args] if arg_memory_kinds is not None: - token_memory_kinds = [None] * (num_dim_vars + num_tokens) - arg_memory_kinds = [*token_memory_kinds, *arg_memory_kinds] + prefix_memory_kinds = [None] * (num_dim_vars + num_tokens) + arg_memory_kinds = [*prefix_memory_kinds, *arg_memory_kinds] if result_memory_kinds is not None: token_memory_kinds = [None] * num_tokens result_memory_kinds = [*token_memory_kinds, *result_memory_kinds] if arg_layouts is not None: - token_layouts = [None] * (num_dim_vars + num_tokens) - arg_layouts = [*token_layouts, *arg_layouts] + prefix_layouts = [None] * (num_dim_vars + num_tokens) + arg_layouts = [*prefix_layouts, *arg_layouts] if result_layouts is not None: token_layouts = [None] * num_tokens result_layouts = [*token_layouts, *result_layouts] if xla_donated_args is not None: - xla_donated_args = [*([False] * (num_dim_vars + num_tokens)), *xla_donated_args] + xla_donated_args = [*([False] * (num_dim_vars + num_tokens)), + *xla_donated_args] flat_input_types = flatten_ir_types(input_types) flat_output_types = flatten_ir_types(output_types) ftype = ir.FunctionType.get(flat_input_types, flat_output_types) - func_op = func_dialect.FuncOp(name, ftype, ip=ctx.ip) + func_name = "main" if main_function else name + func_op = func_dialect.FuncOp(func_name, ftype, ip=ctx.ip) func_op.attributes["sym_visibility"] = ir.StringAttr.get( - "public" if public else "private") + "public" if main_function else "private") ctx.symbol_table.insert(func_op) ir_arg_shardings = None @@ -1552,6 +1701,7 @@ def lower_jaxpr_to_fun( [[_to_xla_layout(l, a)] * len_ir_types(types) for l, a, types in zip(result_layouts, output_avals, output_types)]) + # Populate arg_attrs if ( replicated_args is not None or ir_arg_shardings is not None @@ -1562,6 +1712,7 @@ def lower_jaxpr_to_fun( or arg_names is not None or num_tokens > 0 or num_dim_vars > 0 + or num_const_args > 0 ): arg_attrs: list[dict[str, ir.Attribute]] = [ {} for _ in range(len(flat_input_types))] @@ -1622,8 +1773,15 @@ def lower_jaxpr_to_fun( for attrs in token_arg_attrs: attrs["jax.token"] = ir.BoolAttr.get(True) + if num_const_args > 0: + const_arg_attrs = arg_attrs[num_dim_vars + num_tokens : + num_dim_vars + num_tokens + num_const_args] + for attrs in const_arg_attrs: + attrs["jax.const"] = ir.BoolAttr.get(True) + func_op.arg_attrs = ir.ArrayAttr.get( [ir.DictAttr.get(attrs) for attrs in arg_attrs]) + # End populate arg_attrs result_attrs: list[dict[str, ir.Attribute]] = [ {} for _ in range(len(flat_output_types))] @@ -1667,26 +1825,46 @@ def lower_jaxpr_to_fun( arg_locs.append(ir.Location.name(n) if n else ir.Location.unknown()) entry_block = func_op.add_entry_block(arg_locs) else: - entry_block = func_op.add_entry_block() + with ir.Location.unknown(): + entry_block = func_op.add_entry_block() + + # When lowering a function out of line, we do not include name context from + # the caller. A function might have multiple callers, and it would be + # incorrect to include any one caller's context. Exception: The main function + # has no caller, so we include its name in the name stack. + name_stack = ( + source_info_util.new_name_stack(name) + if main_function + else source_info_util.new_name_stack() + ) with ir.InsertionPoint(entry_block): flat_args = entry_block.arguments - # We separate out the dimension variable inputs, the token inputs and - # the regular inputs. The dimension variables and token inputs - # will be passed to `jaxpr_subcomp` separately from the `args`. - dim_var_values, _, _ = util.split_list(flat_args, [num_dim_vars, num_tokens]) + dim_var_values, _, const_arg_values, _ = util.split_list( + flat_args, [num_dim_vars, num_tokens, num_const_args]) + const_args_and_avals = core.jaxpr_const_args(jaxpr.jaxpr) + if num_const_args == 0: + # If we did not hoist the constants out of this function, lower them now + const_arg_values = [ir_constant(c, aval=aval) + for c, aval in const_args_and_avals] + const_lowering = { + (id(c), aval): c_arg + for (c, aval), c_arg in zip(const_args_and_avals, const_arg_values) + } + # A lowering context just for function body entry/exit code. entry_lowering_ctx = LoweringRuleContext( - module_context=ctx, name_stack=name_stack, primitive=None, + module_context=ctx, name_stack=name_stack, traceback=None, primitive=None, avals_in=[], avals_out=None, tokens_in=TokenSet.create([]), tokens_out=None, - axis_size_env=None, dim_var_values=dim_var_values) + axis_size_env=None, dim_var_values=dim_var_values, + const_lowering=const_lowering) if not use_sharding_annotations and ir_arg_shardings is not None: flat_args = [ a if s is None else wrap_with_sharding_op(entry_lowering_ctx, a, a_aval, s) for a, s, a_aval in zip(flat_args, ir_arg_shardings, input_avals)] - if ir_arg_shardings is not None and name == "main": + if ir_arg_shardings is not None and main_function: flat_args = [ replicate_trailing_dims(entry_lowering_ctx, o, a) if (a is not core.abstract_token and @@ -1696,19 +1874,21 @@ def lower_jaxpr_to_fun( arg_shardings) # type: ignore ] - _, token_args, unflattened_args = util.split_list( + _, token_args, _, unflattened_args = util.split_list( unflatten_ir_values_like_types(flat_args, input_types), - [num_dim_vars, num_tokens]) + [num_dim_vars, num_tokens, num_const_args]) tokens_in = TokenSet(zip(effects, token_args)) args: list[IrValues] = unflattened_args - if name is not None: - callee_name_stack = name_stack.extend(util.wrap_name(name, api_name)) - else: - callee_name_stack = name_stack - consts = [ir_constant(xla.canonicalize_dtype(x)) for x in jaxpr.consts] + unique_consts = { + id(c): ir_constant(c, aval=var.aval) + for c, var in zip(jaxpr.consts, jaxpr.jaxpr.constvars) + } + consts_for_constvars = [unique_consts[id(c)] for c in jaxpr.consts] + out_vals, tokens_out = jaxpr_subcomp( - ctx, jaxpr.jaxpr, callee_name_stack, tokens_in, - consts, *args, dim_var_values=dim_var_values) + ctx, jaxpr.jaxpr, name_stack, tokens_in, + consts_for_constvars, *args, dim_var_values=dim_var_values, + const_lowering=const_lowering) outs: list[IrValues] = [] for eff in effects: outs.append(tokens_out.get(eff)) @@ -1729,22 +1909,27 @@ def lower_jaxpr_to_fun( not uv.all_unconstrained): if config.use_shardy_partitioner.value: s = modify_sdy_sharding_wrt_axis_types(s, o_aval.sharding.mesh) + unconstrained_dims = None # delete this after shardy is default + else: + unconstrained_dims = ( + set(range(o_aval.ndim)) if o_aval.sharding.mesh._any_axis_auto + else None) temp_flat_outputs.append(wrap_with_sharding_op( entry_lowering_ctx, o, o_aval, s, - unspecified_dims=uv.unconstrained_dims)) + unspecified_dims=unconstrained_dims)) else: temp_flat_outputs.append(o) flat_outputs = temp_flat_outputs # Insert a custom call if output is on host because XLA needs that to do the # transfer. - if custom_call_ir_result_memory_kinds is not None and name == "main": + if custom_call_ir_result_memory_kinds is not None and main_function: flat_outputs = [ o if mk is None else wrap_with_memory_kind(o, mk, o_aval) for o, mk, o_aval in zip( flat_outputs, custom_call_ir_result_memory_kinds, output_avals)] - if ir_result_shardings is not None and name == "main": + if ir_result_shardings is not None and main_function: flat_outputs = [ replicate_trailing_dims(entry_lowering_ctx, o, a) if (a is not core.abstract_token and @@ -1782,13 +1967,13 @@ def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value: # For example: if the key.shape is (8, 2) and key_data(key).shape is (8, 2, 2), # then the sharding will be P(P.UNCONSTRAINED, P.UNCONSTRAINED, None). # The below custom call achieves the sharding like above example. - assert isinstance(aval, (core.ShapedArray, core.DShapedArray)) + assert isinstance(aval, core.ShapedArray) if config.use_shardy_partitioner.value: physical_ndim = core.physical_aval(aval).ndim - s = SdyArraySharding( + s = SdyArray( mesh_shape=None, - dimension_shardings=[ - sharding_impls.SdyDimSharding(axes=[], is_closed=i >= aval.ndim) + dim_shardings=[ + sharding_impls.SdyDim(axes=[], is_open=i < aval.ndim) for i in range(physical_ndim) ]) return wrap_with_sharding_op(ctx, val, aval, s) @@ -1798,117 +1983,35 @@ def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value: unspecified_dims=set(range(aval.ndim))) -def _emit_lowering_rule_as_fun(lowering_rule, - ctx: LoweringRuleContext) -> func_dialect.FuncOp: - """Emits the contents of a lowering rule as a private function.""" - num_dim_vars = len(ctx.module_context.shape_poly_state.dim_vars) - # TODO(necula) maybe only pass the dim_vars if they are needed? - dim_var_types = [ - aval_to_ir_type(core.ShapedArray((), dtypes.canonicalize_dtype(np.int64))) - ] * num_dim_vars - - input_types = map(aval_to_ir_type, ctx.avals_in) - output_types = map(aval_to_ir_type, ctx.avals_out) - effs = list(ctx.tokens_in.effects()) - token_types = [token_type() for _ in effs] - input_types = [*dim_var_types, *token_types, *input_types] - output_types = [*token_types, *output_types] - - flat_input_types = flatten_ir_types(input_types) - flat_output_types = flatten_ir_types(output_types) - ftype = ir.FunctionType.get(flat_input_types, flat_output_types) - assert ctx.primitive is not None - func_op = func_dialect.FuncOp(ctx.primitive.name, ftype, - ip=ctx.module_context.ip) - func_op.attributes["sym_visibility"] = ir.StringAttr.get("private") - ctx.module_context.symbol_table.insert(func_op) - entry_block = func_op.add_entry_block() - with ir.InsertionPoint(entry_block): - unflattened_args = unflatten_ir_values_like_types( - entry_block.arguments, input_types) - dim_var_values, token_args, unflattened_args = util.split_list(unflattened_args, [num_dim_vars, len(ctx.tokens_in)]) - sub_ctx = ctx.replace(tokens_in=TokenSet(zip(effs, token_args)), - dim_var_values=dim_var_values) - outs = lowering_rule(sub_ctx, *unflattened_args) - if sub_ctx.tokens_out: - outs = [*[sub_ctx.tokens_out.get(eff) for eff in effs], outs] - func_dialect.return_(flatten_ir_values(outs)) - return func_op - - -class HashableLiteral: - """Hashable wrapper of core.Literal, used for deduplicating IR constants.""" - - __slots__ = ["value", "data"] - - value: core.Literal - - # Copy of the value suitable for an equality comparison. We are careful to - # avoid floating point comparisons here, because in particular we don't want - # 0.0 and -0.0 to be considered equal, but we are fine with NaNs being equal. - data: bytes | int | bool | None - - def __init__(self, value): - self.value = value - if isinstance(value.val, (np.generic, np.ndarray)): - self.data = value.val.tobytes() - elif isinstance(value.val, (bool, int)): - self.data = value.val - elif isinstance(value.val, float): - self.data = np.float64(value.val).tobytes() - elif isinstance(value.val, complex): - self.data = np.complex128(value.val).tobytes() - else: - self.data = None # Unhandled case. - - def __hash__(self): - return hash(self.data) - - def __eq__(self, other): - if type(self.value.val) != type(other.value.val): - return False - if self.value.aval != other.value.aval: - return False - if self.data is None: - return self is other - return self.data == other.data - +_uncacheable_primitives: set[core.Primitive] = set() def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr, name_stack: source_info_util.NameStack, tokens: TokenSet, - consts: Sequence[IrValues], + consts_for_constvars: Sequence[IrValues], *args: IrValues, - dim_var_values: Sequence[ir.Value] + dim_var_values: Sequence[ir.Value], + const_lowering: dict[tuple[int, core.AbstractValue], IrValues], ) -> tuple[Sequence[IrValues], TokenSet]: """Lowers a jaxpr into MLIR, inlined into an existing function. Assumes that an MLIR context, location, and insertion point are set. + consts_for_constvars: the constants corresponding to jaxpr.constvars. dim_var_values: the list of dimension variables values in the current IR function, in the order of ctx.shape_poly_state.dim_vars. + const_lowering: the lowering for constants, by constant id. + See https://docs.jax.dev/en/latest/internals/constants.html """ assert "gpu" not in ctx.platforms - cached_ir_consts: dict[HashableLiteral, IrValues] = {} def read(v: core.Atom) -> IrValues: if type(v) is core.Literal: - h = HashableLiteral(v) - c = cached_ir_consts.get(h) - if c is None: - c = ir_constant(xla.canonicalize_dtype(v.val)) - cached_ir_consts[h] = c - return c + return ir_constant(v.val, const_lowering=const_lowering, aval=v.aval) else: assert isinstance(v, core.Var) return env[v] - def aval(v: core.Atom) -> core.AbstractValue: - if type(v) is core.Literal: - return core.abstractify(v.val) - else: - return v.aval - def write(v: core.Var, node: IrValues): assert node is not None w: IrValues @@ -1926,94 +2029,253 @@ def write(v: core.Var, node: IrValues): w = tuple(node) env[v] = w - def get_override_lowering_rule(primitive: core.Primitive) -> LoweringRule | None: - if ctx.lowering_parameters.override_lowering_rules is None: - return None - for p, rule in ctx.lowering_parameters.override_lowering_rules: - if primitive is p: - return rule - return None - env: dict[core.Var, IrValues] = {} assert all(_is_ir_values(v) for v in args), args - assert all(_is_ir_values(v) for v in consts), consts + assert all(_is_ir_values(v) for v in consts_for_constvars), \ + consts_for_constvars assert isinstance(name_stack, source_info_util.NameStack), type(name_stack) assert len(args) == len(jaxpr.invars), (jaxpr, args) - assert len(consts) == len(jaxpr.constvars), (jaxpr, consts) - assert len(ctx.shape_poly_state.dim_vars) == len(dim_var_values), (ctx.shape_poly_state.dim_vars, dim_var_values) - foreach(write, jaxpr.constvars, consts) + assert len(consts_for_constvars) == len(jaxpr.constvars), \ + (jaxpr, consts_for_constvars) + assert len(ctx.shape_poly_state.dim_vars) == len(dim_var_values), \ + (ctx.shape_poly_state.dim_vars, dim_var_values) + foreach(write, jaxpr.constvars, consts_for_constvars) foreach(write, jaxpr.invars, args) last_used = core.last_used(jaxpr) for eqn in jaxpr.eqns: - in_nodes = map(read, eqn.invars) - source_info = eqn.source_info.replace( - name_stack=name_stack + eqn.source_info.name_stack) - loc = _source_info_to_location(ctx, eqn.primitive, source_info) + in_nodes = tuple(map(read, eqn.invars)) + assert all(_is_ir_values(v) for v in in_nodes), (eqn, in_nodes) + + avals_in = tuple(v.aval for v in eqn.invars) + ordered_effects = list(effects_lib.ordered_effects.filter_in(eqn.effects)) + tokens_in = tokens.subset(ordered_effects) + + eqn_name_stack = name_stack + eqn.source_info.name_stack + loc = source_info_to_location(ctx, eqn.primitive, eqn_name_stack, + eqn.source_info.traceback) with (source_info_util.user_context(eqn.source_info.traceback), loc, eqn.ctx.manager): - override_rule = get_override_lowering_rule(eqn.primitive) - platform_rules: dict[str, LoweringRule] = {} - default_rule: LoweringRule | None = None - # See mlir.lower_per_platform for meaning of `platform_rules` and `default_rule` - if override_rule is not None: - default_rule = override_rule + # TODO(mattjj, phawkins): support caching for dynamic shapes. + can_cache_lowering = ( + eqn.primitive not in _uncacheable_primitives) + if can_cache_lowering: + loc = source_info_to_location(ctx, None, eqn_name_stack, + eqn.source_info.traceback) + with loc: + out_nodes, tokens_out = _cached_lowering( + ctx, eqn, tokens_in, tuple(dim_var_values), const_lowering, + *in_nodes, **eqn.params) else: - # First the platform-specific rules - for p in _platforms_for_eqn_ctx(eqn.ctx) or ctx.platforms: - if eqn.primitive in _platform_specific_lowerings[p]: - platform_rules[p] = _platform_specific_lowerings[p][eqn.primitive] - # Now the default rule - if eqn.primitive in _lowerings: - default_rule = _lowerings[eqn.primitive] - - effects = list(effects_lib.ordered_effects.filter_in(eqn.effects)) - tokens_in = tokens.subset(effects) - avals_in = map(aval, eqn.invars) - rule_ctx = LoweringRuleContext( - module_context=ctx, primitive=eqn.primitive, - name_stack=source_info.name_stack, - avals_in=avals_in, - avals_out=map(aval, eqn.outvars), tokens_in=tokens_in, - tokens_out=None, jaxpr_eqn_ctx=eqn.ctx, dim_var_values=dim_var_values) - if config.dynamic_shapes.value: - axis_size_env = {d: read(d) - for a in avals_in if type(a) is core.DShapedArray - for d in a.shape if type(d) is core.Var} - rule_ctx = rule_ctx.replace(axis_size_env=axis_size_env) - - assert all(_is_ir_values(v) for v in in_nodes), (eqn, in_nodes) - ans = lower_per_platform(rule_ctx, str(eqn.primitive), - platform_rules, default_rule, - eqn.effects, - *in_nodes, **eqn.params) - - if effects: - # If there were ordered effects in the primitive, there should be output - # tokens we need for subsequent ordered effects. + # If we cannot cache the lowering, lower inline. + axis_size_env = None + rule_ctx = LoweringRuleContext( + module_context=ctx, primitive=eqn.primitive, + name_stack=eqn_name_stack, + traceback=eqn.source_info.traceback, + avals_in=avals_in, + avals_out=tuple(v.aval for v in eqn.outvars), tokens_in=tokens_in, + tokens_out=None, jaxpr_eqn_ctx=eqn.ctx, + dim_var_values=dim_var_values, + axis_size_env=axis_size_env, + const_lowering=const_lowering) + out_nodes, _inline = _uncached_lowering( + eqn.primitive, eqn.ctx, eqn.effects, rule_ctx, *in_nodes, + **eqn.params) tokens_out = rule_ctx.tokens_out - if tokens_out is None: - raise ValueError( - f'Lowering rule for `{eqn.primitive}` needs to set `tokens_out` ' - f'because it has effects: {eqn.effects}.') - if tokens_out.effects() != tokens_in.effects(): - raise ValueError( - f'Lowering rule for `{eqn.primitive}` ' - 'returns incorrect set of output tokens. ' - f'Expected: {tuple(tokens_in.effects())} vs. Actual: {tuple(tokens_out.effects())}') - tokens = tokens.update_tokens(tokens_out) - try: - out_nodes = tuple(ans) - except TypeError as e: - raise ValueError("Output of translation rule must be iterable: " - f"{eqn}, got output {ans}") from e + assert len(out_nodes) == len(eqn.outvars), (out_nodes, eqn) + if ordered_effects: + tokens = tokens.update_tokens(tokens_out) - assert len(ans) == len(eqn.outvars), (ans, eqn) foreach(write, eqn.outvars, out_nodes) core.clean_up_dead_vars(eqn, env, last_used) return tuple(read(v) for v in jaxpr.outvars), tokens +def _cached_lowering( + ctx: ModuleContext, eqn: core.JaxprEqn, + tokens_in: TokenSet, + dim_var_values: tuple[ir.Value, ...], + const_lowering: dict[tuple[int, core.AbstractValue], IrValues], + *args, **params) -> tuple[Sequence[IrValues], TokenSet]: + """Lowers a jaxpr equation, using a cache. + + The jaxpr equation's lowering is emitted as an out-of-line MLIR function, and + that function's construction is cached in the event that we see a similar + equation. For each such equation we either inline the function body or emit + an out-of-line call to it, depending on whether any of the lowering rules + opted out of inlining.""" + avals_in = tuple(v.aval for v in eqn.invars) + ordered_effects = list(effects_lib.ordered_effects.filter_in(eqn.effects)) + cache_key = LoweringCacheKey( + primitive=eqn.primitive, + eqn_ctx=eqn.ctx, + avals_in=avals_in, + effects=frozenset(eqn.effects), + params=frozen_dict.FrozenDict(eqn.params), + platforms=tuple(ctx.platforms), + ) + try: + cache_entry = ctx.lowering_cache.get(cache_key, None) + except TypeError: + print("Unable to hash key: ", eqn) + raise + if cache_entry is None: + avals_out = map(lambda v: v.aval, eqn.outvars) + cache_entry = _emit_lowering_rule_as_fun( + partial(_uncached_lowering, eqn.primitive, eqn.ctx, eqn.effects), ctx, + eqn.ctx, eqn.primitive, ordered_effects, avals_in, + avals_out, **params) + ctx.lowering_cache[cache_key] = cache_entry + + tokens_in_args = tuple(tokens_in.get(eff) for eff in ordered_effects) + const_arg_values = tuple( + ir_constant(c, const_lowering=const_lowering, aval=aval) + for c, aval in zip(cache_entry.const_args, cache_entry.const_arg_avals) + ) + args = flatten_ir_values( + dim_var_values + tokens_in_args + const_arg_values + args) + if cache_entry.inline: + outs = jax_mlir_ext.inlined_func_call( + cache_entry.func, args, ir.InsertionPoint.current.block) + else: + outs = func_dialect.CallOp( + flatten_ir_types(cache_entry.output_types), + ir.FlatSymbolRefAttr.get(cache_entry.func.sym_name.value), + args + ).results + out_nodes = unflatten_ir_values_like_types(outs, cache_entry.output_types) + token_outs, out_nodes = util.split_list(out_nodes, [len(ordered_effects)]) + return out_nodes, TokenSet(zip(ordered_effects, token_outs)) + + +def _emit_lowering_rule_as_fun( + lowering_rule: LoweringRule, + ctx: ModuleContext, + eqn_ctx: core.JaxprEqnContext, + primitive: core.Primitive, + ordered_effects: Sequence[core.Effect], + avals_in: Sequence[core.AbstractValue], + avals_out: Sequence[core.AbstractValue], + **params +) -> LoweringCacheValue: + """Emits the contents of a lowering rule as a private function.""" + num_dim_vars = len(ctx.shape_poly_state.dim_vars) + # TODO(necula) maybe only pass the dim_vars if they are needed? + dim_var_types = [ + aval_to_ir_type(core.ShapedArray((), dtypes.default_int_dtype())) + ] * num_dim_vars + + const_args, const_arg_avals = util.unzip2(core.eqn_params_const_args(params)) + + input_types = map(aval_to_ir_type, const_arg_avals + avals_in) # type: ignore + output_types = map(aval_to_ir_type, avals_out) + token_types = [token_type() for _ in ordered_effects] + input_types = [*dim_var_types, *token_types, *input_types] + output_types = [*token_types, *output_types] + + flat_input_types = flatten_ir_types(input_types) + flat_output_types = flatten_ir_types(output_types) + ftype = ir.FunctionType.get(flat_input_types, flat_output_types) + func_op = func_dialect.FuncOp(primitive.name, ftype, + ip=ctx.ip) + func_op.attributes["sym_visibility"] = ir.StringAttr.get("private") + ctx.symbol_table.insert(func_op).value + entry_block = func_op.add_entry_block() + with ir.InsertionPoint(entry_block): + unflattened_args = unflatten_ir_values_like_types( + entry_block.arguments, input_types) + dim_var_values, token_args, const_arg_values, unflattened_args = \ + util.split_list(unflattened_args, + [num_dim_vars, len(ordered_effects), len(const_args)]) + const_lowering = { + (id(c), aval): c_arg + for c, aval, c_arg in zip(const_args, const_arg_avals, const_arg_values) + } + sub_ctx = LoweringRuleContext( + module_context=ctx, primitive=primitive, + name_stack=source_info_util.new_name_stack(), + traceback=None, + avals_in=avals_in, avals_out=avals_out, + tokens_in=TokenSet(zip(ordered_effects, token_args)), + tokens_out=None, jaxpr_eqn_ctx=eqn_ctx, dim_var_values=dim_var_values, + const_lowering=const_lowering) + with ir.Location.name(str(primitive.name)): + outs, inline = lowering_rule(sub_ctx, *unflattened_args, **params) + if sub_ctx.tokens_out: + outs = [*[sub_ctx.tokens_out.get(eff) for eff in ordered_effects], *outs] + outs = flatten_ir_values(outs) + func_dialect.return_(outs) + return LoweringCacheValue(func_op, output_types, const_args, const_arg_avals, + inline) + + +def _get_override_lowering_rule( + ctx: ModuleContext, primitive: core.Primitive +) -> LoweringRule | None: + if ctx.lowering_parameters.override_lowering_rules is None: + return None + for p, rule in ctx.lowering_parameters.override_lowering_rules: + if primitive is p: + return rule + return None + +def _uncached_lowering( + primitive: core.Primitive, + eqn_ctx: core.JaxprEqnContext, + effects: effects_lib.Effects, + ctx: LoweringRuleContext, + *args, + **params, +): + inline = True # Should calls to this lowering rule be inlined? + override_rule = _get_override_lowering_rule(ctx.module_context, primitive) + platform_rules: dict[str, LoweringRule] = {} + default_rule: LoweringRule | None = None + # See mlir.lower_per_platform for meaning of `platform_rules` and `default_rule` + if override_rule is not None: + default_rule = override_rule + assert not isinstance(default_rule, LoweringRuleEntry) + else: + # First the platform-specific rules + for p in _platforms_for_eqn_ctx(eqn_ctx) or ctx.module_context.platforms: + if primitive in _platform_specific_lowerings[p]: + r = _platform_specific_lowerings[p][primitive] + platform_rules[p] = r.rule + inline = inline and r.inline + # Now the default rule + if primitive in _lowerings: + r = _lowerings[primitive] + default_rule = r.rule + assert not isinstance(default_rule, LoweringRuleEntry) + inline = inline and r.inline + + assert not isinstance(default_rule, LoweringRuleEntry) + assert not any(isinstance(r, LoweringRuleEntry) for r in platform_rules.values()) + ans = lower_per_platform(ctx, str(primitive), platform_rules, default_rule, + effects, *args, **params) + try: + rets = tuple(ans) + except TypeError as e: + raise ValueError("Output of translation rule must be iterable: " + f"{primitive}, got output {ans}") from e + + if ctx.tokens_in.effects(): + # If there were ordered effects in the primitive, there should be output + # tokens we need for subsequent ordered effects. + tokens_out = ctx.tokens_out + if tokens_out is None: + raise ValueError( + f'Lowering rule for `{primitive}` needs to set `tokens_out` ' + f'because it has effects: {effects}.') + if tokens_out.effects() != ctx.tokens_in.effects(): + raise ValueError( + f"Lowering rule for `{primitive}` returns incorrect set of output" + f" tokens. Expected: {tuple(ctx.tokens_in.effects())} vs. Actual:" + f" {tuple(tokens_out.effects())}" + ) + return rets, inline + def _platforms_for_eqn_ctx(eqn_ctx: core.JaxprEqnContext | None ) -> tuple[str, ...]: @@ -2026,6 +2288,14 @@ def _platforms_for_eqn_ctx(eqn_ctx: core.JaxprEqnContext | None return ('tpu',) return () +def _platforms_for_eqn(ctx: LoweringRuleContext) -> tuple[str, ...]: + """The lowering platforms for the current eqn""" + return tuple(_platforms_for_eqn_ctx(ctx.jaxpr_eqn_ctx) or + ctx.platforms or ctx.module_context.platforms) + +def _get_owner(v: ir.Value): + owner = v.owner + return owner.operation if isinstance(owner, ir.OpView) else owner def lower_per_platform(ctx: LoweringRuleContext, description: str, @@ -2068,8 +2338,7 @@ def lower_per_platform(ctx: LoweringRuleContext, rule_args: the args of the lowering rules. rule_kwargs: the kwargs of the lowering rules. """ - platforms: Sequence[str] = (_platforms_for_eqn_ctx(ctx.jaxpr_eqn_ctx) or - ctx.platforms or ctx.module_context.platforms) + platforms: Sequence[str] = _platforms_for_eqn(ctx) # Special case the common case (single-platform lowering) if len(platforms) == 1: rule = platform_rules.get(platforms[0], default_rule) @@ -2103,11 +2372,11 @@ def lower_per_platform(ctx: LoweringRuleContext, if len(kept_rules) == 1: output = kept_rules[0](ctx, *rule_args, **rule_kwargs) foreach( - lambda o: wrap_compute_type_in_place(ctx, o.owner), + lambda o: wrap_compute_type_in_place(ctx, _get_owner(o)), filter(_is_not_block_argument, flatten_ir_values(output)), ) foreach( - lambda o: wrap_xla_metadata_in_place(ctx, o.owner), + lambda o: wrap_xla_metadata_in_place(ctx, _get_owner(o)), flatten_ir_values(output), ) return output @@ -2148,11 +2417,11 @@ def lower_per_platform(ctx: LoweringRuleContext, raise ValueError("Output of translation rule must be iterable: " f"{description}, got output {output}") from e foreach( - lambda o: wrap_compute_type_in_place(ctx, o.owner), + lambda o: wrap_compute_type_in_place(ctx, _get_owner(o)), filter(_is_not_block_argument, out_nodes), ) foreach( - lambda o: wrap_xla_metadata_in_place(ctx, o.owner), + lambda o: wrap_xla_metadata_in_place(ctx, _get_owner(o)), out_nodes, ) if inner_ctx.tokens_out is not None: @@ -2171,14 +2440,11 @@ def lower_per_platform(ctx: LoweringRuleContext, ctx.set_tokens_out(tokens_out) return results -def _ir_consts(consts) -> list[IrValues]: - unique_consts = {id(const): const for const in consts} - ir_consts = { - id_: ir_constant(xla.canonicalize_dtype(const)) - for id_, const in unique_consts.items() +def ir_consts(consts, avals: Sequence[core.AbstractValue]) -> list[IrValues]: + uniq_consts = { + id(c): ir_constant(c, aval=aval) for c, aval in zip(consts, avals) } - return [ir_consts[id(const)] for const in consts] - + return [uniq_consts[id(c)] for c in consts] def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable: """Converts a traceable JAX function `fun` into a lowering rule. @@ -2188,61 +2454,64 @@ def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable: def f_lowered(ctx: LoweringRuleContext, *args, **params): f = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),) wrapped_fun = lu.wrap_init(f, params, - debug_info=api_util.debug_info("lower_fun", fun, args, params)) - manager = (contextlib.nullcontext() if ctx.jaxpr_eqn_ctx is None else - ctx.jaxpr_eqn_ctx.manager) - - with manager: - if config.dynamic_shapes.value: - # We might be applying this function to arguments with dynamic shapes, - # i.e. there might be Vars in the shape tuples of ctx.avals_in. In that - # case, we need to form a jaxpr with leading binders for those axis size - # arguments (by computing an InputType and using trace_to_jaxpr_dynamic2), - # and we need to call jaxpr_subcomp with these arguments made explicit. - assert ctx.axis_size_env is not None - args = (*ctx.axis_size_env.values(), *args) - idx = {d: core.DBIdx(i) for i, d in enumerate(ctx.axis_size_env)} - i32_aval = core.ShapedArray((), np.dtype('int32')) - implicit_args = [(i32_aval, False)] * len(ctx.axis_size_env) - explicit_args = [(a.update(shape=tuple(idx.get(d, d) for d in a.shape)) # type: ignore - if type(a) is core.DShapedArray else a, True) - for a in ctx.avals_in] - wrapped_fun = lu.annotate(wrapped_fun, (*implicit_args, *explicit_args)) - jaxpr, _, consts = pe.trace_to_jaxpr_dynamic2(wrapped_fun) - else: - jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in) - # TODO(frostig,mattjj): check ctx.avals_out against jaxpr avals out? + debug_info=api_util.debug_info("lower_fun", fun, args, {})) - if ctx.platforms is not None: - sub_context = ctx.module_context.replace(platforms=ctx.platforms) - else: - sub_context = ctx.module_context - out, tokens = jaxpr_subcomp( - sub_context, jaxpr, ctx.name_stack, ctx.tokens_in, - _ir_consts(consts), *args, - dim_var_values=ctx.dim_var_values) - ctx.set_tokens_out(tokens) - return out + jaxpr, _, consts_for_constvars = pe.trace_to_jaxpr_dynamic( + wrapped_fun, ctx.avals_in) + + if any(isinstance(e, core.InternalMutableArrayEffect) for e in jaxpr.effects): + from jax._src.interpreters import pxla # type: ignore + closed_jaxpr = core.ClosedJaxpr(jaxpr, consts_for_constvars) + closed_jaxpr = pxla._discharge_internal_refs(closed_jaxpr) + jaxpr, consts_for_constvars = closed_jaxpr.jaxpr, closed_jaxpr.consts + + # TODO(frostig,mattjj): check ctx.avals_out against jaxpr avals out? + + if ctx.platforms is not None: + sub_context = ctx.module_context.replace(platforms=ctx.platforms) + else: + sub_context = ctx.module_context + out, tokens = jaxpr_subcomp( + sub_context, jaxpr, ctx.name_stack, ctx.tokens_in, + ir_consts(consts_for_constvars, [v.aval for v in jaxpr.constvars]), + *args, + dim_var_values=ctx.dim_var_values, + const_lowering=ctx.const_lowering) + ctx.set_tokens_out(tokens) + return out return f_lowered -def _lower_jaxpr_to_fun_cached(ctx, fn_name, call_jaxpr, effects, name_stack, - arg_names=None, result_names=None): +def _lower_jaxpr_to_fun_cached( + ctx: ModuleContext, fn_name, call_jaxpr: core.ClosedJaxpr, + num_const_args: int, effects, in_avals, arg_names=None, result_names=None): + assert num_const_args + len(call_jaxpr.in_avals) == len(in_avals) if not call_jaxpr.consts and arg_names is result_names is None: # Cacheable. key = (fn_name, call_jaxpr.jaxpr, tuple(effects)) try: - func_op = ctx.cached_primitive_lowerings[key] + func_op, _, _ = ctx.cached_primitive_lowerings[key] except KeyError: + num_callbacks = len(ctx.host_callbacks) func_op = lower_jaxpr_to_fun( - ctx, fn_name, call_jaxpr, effects, name_stack, arg_names=arg_names, - result_names=result_names) - ctx.cached_primitive_lowerings[key] = func_op + ctx, fn_name, call_jaxpr, effects, num_const_args=num_const_args, + in_avals=in_avals, arg_names=arg_names, result_names=result_names) + # If this Jaxpr includes callbacks, we can't cache the lowering because + # on TPU under libtpu <= 0.0.34 every callback must have a globally + # unique channel, but the channel gets assigned during lowering. + has_callbacks = len(ctx.host_callbacks) > num_callbacks + if USE_NEW_TPU_CALLBACK_LOWERING or not has_callbacks or "tpu" not in ctx.platforms: + ctx.cached_primitive_lowerings[key] = ( + func_op, + func_op.name.value, + func_op.type.results, + ) else: func_op = lower_jaxpr_to_fun( - ctx, fn_name, call_jaxpr, effects, name_stack, arg_names=arg_names, - result_names=result_names) + ctx, fn_name, call_jaxpr, effects, + num_const_args=num_const_args, in_avals=in_avals, + arg_names=arg_names, result_names=result_names) return func_op @@ -2264,42 +2533,40 @@ def check_backend_matches(inner_backend: str | None, def lower_called_computation( - fn_name, - name_stack, - call_jaxpr, - ctx: ModuleContext, - avals_out, - tokens_in, - backend=None, - arg_names=None, - result_names=None, -): - if isinstance(call_jaxpr, core.Jaxpr): - call_jaxpr = pe.close_jaxpr(call_jaxpr) + fn_name, call_jaxpr: core.ClosedJaxpr, ctx: ModuleContext, + num_const_args: int, in_avals, out_avals, tokens_in, backend=None, + arg_names=None, result_names=None): + assert isinstance(call_jaxpr, core.ClosedJaxpr), type(call_jaxpr) check_backend_matches(backend, ctx.platforms) effects = list(tokens_in.effects()) - output_types = map(aval_to_ir_type, avals_out) + output_types = map(aval_to_ir_type, out_avals) output_types = [token_type()] * len(effects) + output_types func_op = _lower_jaxpr_to_fun_cached( - ctx, - fn_name, - call_jaxpr, - effects, - name_stack, - arg_names=arg_names, - result_names=result_names, - ) + ctx, fn_name, call_jaxpr, num_const_args, effects, in_avals=in_avals, + arg_names=arg_names, result_names=result_names) return func_op, output_types, effects -def call_lowering(fn_name, name_stack, call_jaxpr, backend, - ctx: ModuleContext, avals_in, - avals_out, tokens_in, *args, +def call_lowering(fn_name, call_jaxpr: core.ClosedJaxpr, backend, + ctx: ModuleContext, in_avals, + out_avals, tokens_in, *args, dim_var_values: Sequence[ir.Value], - arg_names=None, result_names=None): - del avals_in + const_lowering: dict[tuple[int, core.AbstractValue], IrValues], + arg_names=None, result_names=None, + attributes: None | dict[str, Any] = None): + assert isinstance(call_jaxpr, core.ClosedJaxpr), type(call_jaxpr) + const_args_and_avals = core.jaxpr_const_args(call_jaxpr.jaxpr) + const_args, const_avals = util.unzip2(const_args_and_avals) + const_arg_values = [ir_constant(c, const_lowering=const_lowering, aval=aval) + for c, aval in const_args_and_avals] + args = tuple(const_arg_values) + args + if arg_names is not None: + arg_names = [""] * len(const_args) + arg_names + in_avals = (*const_avals, *in_avals) + func_op, output_types, effects = lower_called_computation( - fn_name, name_stack, call_jaxpr, ctx, avals_out, tokens_in, + fn_name, call_jaxpr, ctx, len(const_args), in_avals, out_avals, + tokens_in, backend=backend, arg_names=arg_names, result_names=result_names) symbol_name = func_op.name.value flat_output_types = flatten_ir_types(output_types) @@ -2308,60 +2575,68 @@ def call_lowering(fn_name, name_stack, call_jaxpr, backend, call = func_dialect.CallOp(flat_output_types, ir.FlatSymbolRefAttr.get(symbol_name), flatten_ir_values(args)) + if attributes: + call.operation.attributes['mhlo.frontend_attributes'] = ir.DictAttr.get(attributes) out_nodes = unflatten_ir_values_like_types(call.results, output_types) tokens, out_nodes = util.split_list(out_nodes, [len(effects)]) tokens_out = tokens_in.update_tokens(TokenSet(zip(effects, tokens))) return out_nodes, tokens_out def core_call_lowering(ctx: LoweringRuleContext, - *args, name, backend=None, call_jaxpr): + *args, name, backend=None, + call_jaxpr: core.ClosedJaxpr | core.Jaxpr): + if isinstance(call_jaxpr, core.Jaxpr): + call_jaxpr = pe.close_jaxpr(call_jaxpr) out_nodes, tokens = call_lowering( - name, ctx.name_stack, call_jaxpr, backend, ctx.module_context, + name, call_jaxpr, backend, ctx.module_context, ctx.avals_in, ctx.avals_out, ctx.tokens_in, *args, - dim_var_values=ctx.dim_var_values) + dim_var_values=ctx.dim_var_values, + const_lowering=ctx.const_lowering) ctx.set_tokens_out(tokens) return out_nodes register_lowering(core.call_p, partial(core_call_lowering, name="core_call")) +# TODO(phawkins): Not cacheable because of debug_print on TPU. register_lowering(core.closed_call_p, - partial(core_call_lowering, name=None)) - -def map_compute_type(c_type): - if c_type == 'device_host': - return 'host' - elif c_type == 'device': - return 'dense' - elif c_type == 'tpu_sparsecore': - return 'sparse' - raise ValueError(f'Invalid compute type {c_type}. Current supported values ' - 'are `device_host`, `device` and `tpu_sparsecore') - -def wrap_compute_type_in_place(ctx, op): + partial(core_call_lowering, name="closed_call"), + cacheable=False) + +def map_compute_type(c_type: str) -> str: + if c_type == "device_host": + return "host" + elif c_type == "device": + return "dense" + elif c_type == "tpu_sparsecore": + return "sparseoffload" + raise ValueError(f"Invalid compute type {c_type}. Current supported values " + "are `device_host`, `device` and `tpu_sparsecore`") + +def wrap_compute_type_in_place(ctx: LoweringRuleContext, op: ir.Operation) -> None: if ctx.jaxpr_eqn_ctx is not None and ctx.jaxpr_eqn_ctx.compute_type is not None: if ctx.jaxpr_eqn_ctx.compute_type.startswith("gpu_stream:"): stream = ctx.jaxpr_eqn_ctx.compute_type.split(":")[1] - dict_attr = {"_xla_stream_annotation": ir.StringAttr.get(stream)} - op.operation.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(dict_attr) + dict_attr = { + "_xla_stream_annotation": ir.StringAttr.get(stream), + "inlineable": ir.StringAttr.get("false"), + } + op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(dict_attr) else: dict_attr = {"_xla_compute_type": ir.StringAttr.get( map_compute_type(ctx.jaxpr_eqn_ctx.compute_type))} - op.operation.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(dict_attr) + op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(dict_attr) -def wrap_xla_metadata_in_place(ctx, op): - ctx_attributes = {} - existing_attributes = {} +def wrap_xla_metadata_in_place(ctx: LoweringRuleContext, op: ir.Operation) -> None: if ctx.jaxpr_eqn_ctx is not None and ctx.jaxpr_eqn_ctx.xla_metadata: + ctx_attributes, existing_attributes = {}, {} for k, v in ctx.jaxpr_eqn_ctx.xla_metadata.items(): ctx_attributes[k] = ir.StringAttr.get(str(v).lower()) if isinstance(op, ir.Operation): # combine with existing mhlo.frontend_attributes - op_attributes_dict = {attr.name: attr.attr for attr in op.attributes} - for k, attributes in op_attributes_dict.items(): - if k == "mhlo.frontend_attributes": - v_dict = {attr.name: attr.attr for attr in attributes} - for fa_key, fa_val in v_dict.items(): - existing_attributes[fa_key] = fa_val + for attr in op.attributes: + if attr == "mhlo.frontend_attributes": + for a in op.attributes[attr]: + existing_attributes[a.name] = a.attr op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get( ctx_attributes | existing_attributes ) @@ -2393,25 +2668,37 @@ def broadcast_in_dim(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue, out = hlo.broadcast_in_dim( aval_to_ir_type(aval_out), op, dense_int_array(broadcast_dimensions)) - wrap_compute_type_in_place(ctx, out.owner) + wrap_compute_type_in_place(ctx, _get_owner(out)) return out def multi_broadcast_in_dim(ctx: LoweringRuleContext, ops: Sequence[ir.Value], ops_avals: Sequence[core.AbstractValue], - out_shape: core.Shape) -> Sequence[ir.Value]: + out_shape: core.Shape, + out_sharding) -> Sequence[ir.Value]: """Broadcasts multiple ops to the out_shape.""" out = [] for op, op_aval in zip(ops, ops_avals): op_aval_shape = op_aval.shape # type: ignore + op_aval_sharding = op_aval.sharding # type: ignore + out_aval = core.ShapedArray( + out_shape, op_aval.dtype, sharding=out_sharding) # type: ignore if core.definitely_equal_shape(op_aval_shape, out_shape): - out.append(op) + if op_aval_sharding.spec.unreduced or op_aval_sharding.spec.reduced: + out.append(op) + elif op_aval_sharding == out_sharding: + out.append(op) + else: + out.append(lower_with_sharding_in_types(ctx, op, out_aval)) else: + if op_aval_sharding.spec.unreduced or op_aval_sharding.spec.reduced: + raise NotImplementedError() assert len(op_aval_shape) <= len(out_shape), (op_aval_shape, out_shape) broadcast_dimensions = list(range(len(out_shape) - len(op_aval_shape), len(out_shape))) - out.append(broadcast_in_dim(ctx, op, - core.ShapedArray(out_shape, op_aval.dtype), # type: ignore - broadcast_dimensions=broadcast_dimensions)) + b_out = broadcast_in_dim( + ctx, op, out_aval, broadcast_dimensions=broadcast_dimensions) + b_out = lower_with_sharding_in_types(ctx, b_out, out_aval) + out.append(b_out) return out def reshape(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue) -> ir.Value: @@ -2456,8 +2743,7 @@ def dynamic_slice(ctx: LoweringRuleContext, aval_out, x, *, if dtypes.issubdtype(aval_out.dtype, dtypes.extended): elt_shape = core.physical_element_aval(aval_out.dtype).shape index_avals = ctx.avals_in[1:] - dtype = dtypes.canonicalize_dtype( - index_avals[0].dtype if index_avals else 'int64') # type: ignore + dtype = index_avals[0].dtype if index_avals else np.int32 # type: ignore trailing_zeros = [ir_constant(np.array(0, dtype))] * len(elt_shape) start_indices = (*start_indices, *trailing_zeros) aval_out = core.physical_aval(aval_out) @@ -2489,8 +2775,7 @@ def dynamic_update_slice(ctx: LoweringRuleContext, aval_out, x, update, *, if dtypes.issubdtype(aval_out.dtype, dtypes.extended): elt_shape = core.physical_element_aval(aval_out.dtype).shape index_avals = ctx.avals_in[2:] - dtype = dtypes.canonicalize_dtype( - index_avals[0].dtype if index_avals else 'int64') # type: ignore + dtype = index_avals[0].dtype if index_avals else np.int32 # type: ignore zeros = [ir_constant(np.array(0, dtype=dtype))] * len(elt_shape) start_indices = (*start_indices, *zeros) physical_aval_out = core.physical_aval(aval_out) @@ -2530,7 +2815,7 @@ def iota(ctx: LoweringRuleContext, aval_out, *, dimension: int): def full_like_aval(ctx: LoweringRuleContext, value, aval: core.ShapedArray) -> ir.Value: """Returns an IR constant shaped full of `value` shaped like `aval`.""" - zero = ir_constant(np.array(value, dtypes.canonicalize_dtype(aval.dtype))) + zero = ir_constant(np.array(value, aval.dtype)) return broadcast_in_dim(ctx, zero, aval, broadcast_dimensions=()) def add_jaxvals_lowering(ctx, x, y): @@ -2561,7 +2846,7 @@ def compare_hlo(x, y, direction: str, comparison_type: str | None = None): def _minmax_hlo(op, cmp, x, y): """Min/max that compares complex values lexicographically as pairs.""" tensor_type = ir.RankedTensorType(x.type) - if ir.ComplexType.isinstance(tensor_type.element_type): + if isinstance(tensor_type.element_type, ir.ComplexType): rx = hlo.real(x) ry = hlo.real(y) real_eq = compare_hlo(rx, ry, "EQ", "FLOAT") @@ -2597,7 +2882,7 @@ def _wrap_with_spmd_op(name: str, ctx: LoweringRuleContext, x: ir.Value, aval_out: core.AbstractValue, - sharding: xc.OpSharding | SdyArraySharding, + sharding: xc.OpSharding | SdyArray, unspecified_dims: set[int] | None = None, has_side_effect: bool = False, allow_shardy_lowering: bool = False): @@ -2638,9 +2923,9 @@ def lower_with_sharding_in_types(ctx, op, aval, sharding_proto=None): if aval.sharding.mesh.empty: return op # Don't emit a wsc under full manual mode to avoid increasing HLO size. - if aval.sharding.mesh._are_all_axes_manual: + if aval.sharding.mesh.are_all_axes_manual: return op - if aval.sharding.mesh._are_all_axes_auto: + if aval.sharding.mesh.are_all_axes_auto: return op # TODO(yashkatariya): If all the axes in pspec are AUTO or collective, # `return op` early and avoid bloating HLO size. @@ -2656,23 +2941,21 @@ def lower_with_sharding_in_types(ctx, op, aval, sharding_proto=None): if sharding_proto is None else sharding_proto) unspecified_dims = None if aval.sharding.mesh._any_axis_auto: - # TODO(yashkatariya): Maybe if any mesh axis is auto, mark all axes - # as unspecified? - unspecified_dims = {i for i, s in enumerate(aval.sharding.spec) if s is None} + unspecified_dims = set(range(aval.ndim)) return wrap_with_sharding_op(ctx, op, aval, proto, unspecified_dims) -def set_sharding(op, sharding: xc.OpSharding | SdyArraySharding | SdyArrayShardingList): - if config.use_shardy_partitioner.value: +def set_sharding(op, sharding: xc.OpSharding | SdyArray | SdyArrayList): + if isinstance(sharding, (SdyArray, SdyArrayList)): op.attributes["sdy.sharding"] = get_sharding_attr(sharding) else: op.attributes["mhlo.sharding"] = get_sharding_attr(sharding) def get_sharding_attr( - sharding: xc.OpSharding | SdyArraySharding | SdyArrayShardingList + sharding: xc.OpSharding | SdyArray | SdyArrayList ) -> ir.Attribute: - if config.use_shardy_partitioner.value: + if isinstance(sharding, (SdyArray, SdyArrayList)): return sharding.build() # type: ignore else: # If there are very large numbers of devices, use the proto representation. @@ -2687,7 +2970,7 @@ def get_sharding_attr( def wrap_with_layout_op(ctx: LoweringRuleContext, x: ir.Value, aval_out: core.AbstractValue, - layout: DeviceLocalLayout, + layout: Layout, aval_in: core.AbstractValue): result_type = aval_to_ir_type(aval_out) assert isinstance(result_type, ir.Type), result_type @@ -2710,47 +2993,6 @@ def wrap_with_layout_op(ctx: LoweringRuleContext, # MLIR lowerings for lax primitives -def cache_lowering(f): - """Decorator that causes the contents of a lowering rule to be reused. - - The lowering will be emitted out-of-line in a separate function, together with - a call to that function. If the same primitive is called with the same shapes - and parameters, a new call to the original function will be added, without - emitting a new function. We allow for different lowering for the same - primitive for different platforms in the same module. - """ - @functools.wraps(f) - def cached_lowering(ctx, *args, **params): - assert ctx.primitive is not None - key = (f, ctx.primitive, - tuple(ctx.avals_in), tuple(ctx.avals_out), - tuple(params.items())) - try: - func = ctx.module_context.cached_primitive_lowerings.get(key) - except TypeError: - # If the parameters aren't hashable, give up on caching. - # TODO(phawkins): switch to requiring hashability, when XLA fallback - # computations have been ported to MLIR. - return f(ctx, *args, **params) - if func is None: - func = _emit_lowering_rule_as_fun(partial(f, **params), ctx) - ctx.module_context.cached_primitive_lowerings[key] = func - - output_types = map(aval_to_ir_type, ctx.avals_out) - args = tuple(ctx.dim_var_values) + args - flat_output_types = flatten_ir_types(output_types) - call = func_dialect.CallOp(flat_output_types, - ir.FlatSymbolRefAttr.get(func.name.value), - flatten_ir_values(args)) - return unflatten_ir_values_like_types(call.results, output_types) - return cached_lowering - - -def xla_computation_to_mlir_module(xla_computation: xc.XlaComputation - ) -> ir.Module: - module_str = xc._xla.mlir.xla_computation_to_mlir_module(xla_computation) - return ir.Module.parse(module_str) - def merge_mlir_modules(dst_module: ir.Module, sym_name: str, src_module: ir.Module, @@ -2830,18 +3072,20 @@ def merge_mlir_modules(dst_module: ir.Module, def build_mlir_module_helper( closed_jaxpr: core.ClosedJaxpr, *, name: str, platforms: Sequence[str], - backend: xb.XlaBackend | None, + backend: xc.Client | None, axis_context: AxisContext) -> ir.Module: """Helper to generate pmap-style XLA computations for custom partitioners.""" - unlowerable_effects = lowerable_effects.filter_not_in(closed_jaxpr.effects) + unlowerable_effects = effects_lib.lowerable_effects.filter_not_in( + closed_jaxpr.effects) if unlowerable_effects: raise ValueError(f'Cannot lower jaxpr with effects: {closed_jaxpr.effects}') lowering_result = lower_jaxpr_to_module(name, closed_jaxpr, + num_const_args=0, + in_avals=closed_jaxpr.in_avals, backend=backend, ordered_effects=[], - name_stack=source_info_util.NameStack(), donated_args=[False] * len(closed_jaxpr.jaxpr.invars), axis_context=axis_context, platforms=platforms, - lowering_parameters=LoweringParameters()) + lowering_parameters=LoweringParameters(hoist_constants_as_args=False)) return lowering_result.module def custom_call( @@ -3027,15 +3271,12 @@ def refine_polymorphic_shapes(module: ir.Module) -> ir.Module: Then verifies that there are no more dynamic shapes in the module. """ try: - refine_polymorphic_shapes = partial(xla_extension.mlir.refine_polymorphic_shapes, + refine_polymorphic_shapes = partial(_jax.mlir.refine_polymorphic_shapes, mlir_module=module_to_bytecode(module), enable_shape_assertions=True, validate_static_shapes=True) - if xla_extension_version >= 319: - refined_module_str = refine_polymorphic_shapes( - enable_shardy=config.use_shardy_partitioner.value) - else: - refined_module_str = refine_polymorphic_shapes() + refined_module_str = refine_polymorphic_shapes( + enable_shardy=config.use_shardy_partitioner.value) except Exception as e: raise ValueError( "Error refining shapes. " + diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 07c516fd95c7..8a5ba4aa1558 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -11,16 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +# pytype: skip-file from __future__ import annotations from collections import namedtuple -from collections.abc import Callable, Sequence, Hashable -from contextlib import contextmanager +from collections.abc import Callable, Sequence +import contextlib +from dataclasses import dataclass from functools import partial +import logging import itertools as it import operator as op from typing import Any, NamedTuple, Union -from weakref import ref +from weakref import finalize, ref, ReferenceType, WeakValueDictionary import numpy as np @@ -33,21 +37,21 @@ from jax._src import linear_util as lu from jax._src import profiler from jax._src import source_info_util -from jax._src import compute_on -from jax._src import xla_metadata as xla_metadata_lib -from jax._src.core import (Trace, Tracer, TraceTag, Jaxpr, Literal, get_aval, - AbstractValue, ClosedJaxpr, new_jaxpr_eqn, - Var, DropVar, Atom, - JaxprEqn, Primitive, ShapedArray, DShapedArray, - mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx, - InputType, OutputType, get_referent, JaxprEqnContext) +from jax._src import xla_metadata_lib +from jax._src import tree_util +from jax._src.core import ( + Trace, Tracer, TraceTag, Jaxpr, Literal, get_aval, AbstractValue, + ClosedJaxpr, new_jaxpr_eqn, Var, DropVar, Atom, JaxprEqn, Primitive, + mapped_aval, unmapped_aval, get_referent, JaxprEqnContext, typeof) +from jax._src.source_info_util import SourceInfo from jax._src.state.types import AbstractRef, ReadEffect -from jax._src.tree_util import (PyTreeDef, treedef_tuple, - tree_flatten, tree_structure) +from jax._src.tree_util import PyTreeDef, treedef_tuple, FlatTree from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list, merge_lists, partition_list, OrderedSet, - as_hashable_function, weakref_lru_cache, subs_list, - HashableFunction, foreach) + as_hashable_function, weakref_lru_cache, + multi_weakref_lru_cache, subs_list, + HashableFunction, foreach, test_event) +from jax._src.lib import jaxlib_extension_version map, unsafe_map = safe_map, map @@ -58,41 +62,9 @@ def identity(x): return x AvalId = int ConstId = int -def _update_annotation_known( - f: lu.WrappedFun, - orig_type: InputType | None, - in_knowns: list[bool] - ) -> lu.WrappedFun: - if orig_type is None: return f - # orig_type might contain DBIdx, but we're tossing out some args so we have to - # re-index. moreover some of the implicit args may not be needed anymore. - # so we basically just re-infer the lambda input type - if (all(e for _, e in orig_type) and - not any(type(d) is DBIdx for a, _ in orig_type for d in a.shape - if type(a) is DShapedArray)): - new_type = [ty for ty, known in zip(orig_type, in_knowns) if known] - return lu.annotate(f, tuple(new_type)) - - # Replace DBIdx with names, prune down to explicit only. - class Name: - def __init__(self, a): self.a = a - names = [Name(a) for a, _ in orig_type] - avals = [a.update(shape=tuple(names[d.val] if type(d) is DBIdx else d - for d in a.shape)) - if type(a) is DShapedArray else a for a, e in orig_type if e] - avals = [a for a, known in zip(avals, in_knowns) if known] - # Figure out the implicit part: names which aren't explicit and known. - expl_names = [o for o, (_, e) in zip(names, orig_type) if e] - expl_names = [o for o, k in zip(expl_names, in_knowns) if k] - expl_names_ = set(expl_names) - impl_names = {d for a in avals if type(a) is DShapedArray for d in a.shape - if type(d) is Name and d not in expl_names_} - impl_part = [(n.a, False) for n in impl_names] # type: ignore - # Figure out the explicit part: known explicit avals, replacing names w/ dbidx - name_map = {n: DBIdx(i) for i, n in enumerate((*impl_names, *expl_names))} - expl_part = [(a.update(shape=tuple(name_map.get(d, d) for d in a.shape)) - if type(a) is DShapedArray else a, True) for a in avals] - return lu.annotate(f, (*impl_part, *expl_part)) +AttrKind = Any +PyTree = Any +logger = logging.getLogger(__name__) class PartialVal(tuple): """Partial value: either a known value or an unknown (abstract) value. @@ -137,6 +109,10 @@ def get_aval(self) -> AbstractValue: else: return self[0] +@dataclass(frozen=True) +class EffectHandle: + parents : list[Tracer] + recipe : JaxprEqnRecipe class JaxprTrace(Trace['JaxprTracer']): @@ -145,6 +121,9 @@ def __init__(self, parent_trace:Trace, name_stack: source_info_util.NameStack, t self.name_stack = name_stack self.tag = tag self.parent_trace = parent_trace + self.requires_low = False + self.effect_handles : list[EffectHandle] = [] + self.counter = it.count() def to_jaxpr_tracer(self, x): if isinstance(x, JaxprTracer) and x._trace.tag is self.tag: @@ -174,14 +153,6 @@ def new_arg(self, pval: PartialVal) -> JaxprTracer: # known inputs (if it needs them, then they get passed through residuals). if const is None: aval = pval.get_aval() - if type(aval) is DShapedArray: - # TODO(dougalm): Fix the type error and remove the pytype pragmas. - # pytype: disable=attribute-error - shape = [self.new_instantiated_const(d) - if isinstance(d, Tracer) and d._trace.level < self.level else d - for d in aval.shape] - # pytype: enable=attribute-error - aval = aval.update(shape=tuple(shape)) return JaxprTracer(self, PartialVal.unknown(aval), LambdaBinding()) else: return self.new_const(const) @@ -191,18 +162,18 @@ def instantiate_const(self, tracer: JaxprTracer) -> JaxprTracer: if const is None: return tracer else: - if type(const) in core.literalable_types and np.shape(const) == (): + if core.is_literalable(const): return self.new_instantiated_literal(const) else: return self.new_instantiated_const(const) - def instantiate_const_abstracted(self, tracer) -> JaxprTracer: - const = tracer.pval.get_known() + def cur_qdd(self, x): + const = self.to_jaxpr_tracer(x).pval.get_known() if const is None: - return tracer + assert False # TODO: track tangent QDDs else: - aval = get_aval(const).update_weak_type(np.isscalar(const)) - return JaxprTracer(self, PartialVal.unknown(aval), ConstVar(const)) + with core.set_current_trace(self.parent_trace): + return core.cur_qdd(const) def process_primitive(self, primitive, tracers, params): with core.set_current_trace(self.parent_trace): @@ -222,20 +193,25 @@ def default_process_primitive(self, primitive, tracers, params): return primitive.bind_with_trace(self.parent_trace, consts, params) tracers = map(self.instantiate_const, tracers) avals = [t.aval for t in tracers] - out_aval, effects = primitive.abstract_eval(*avals, **params) + out_aval, effs = primitive.abstract_eval(*avals, **params) name_stack = self._current_truncated_name_stack() source = source_info_util.current().replace(name_stack=name_stack) if primitive.multiple_results: out_tracers = [JaxprTracer(self, PartialVal.unknown(aval), None) for aval in out_aval] - eqn = new_eqn_recipe(tracers, out_tracers, primitive, params, effects, + eqn = new_eqn_recipe(self, tracers, out_tracers, primitive, params, effs, source) + if effects.partial_eval_kept_effects.filter_in(effs): + self.effect_handles.append(EffectHandle(tracers, eqn)) for t in out_tracers: t.recipe = eqn return out_tracers else: out_tracer = JaxprTracer(self, PartialVal.unknown(out_aval), None) - out_tracer.recipe = new_eqn_recipe(tracers, [out_tracer], primitive, - params, effects, source) + eqn = new_eqn_recipe(self, tracers, [out_tracer], primitive, + params, effs, source) + if effects.partial_eval_kept_effects.filter_in(effs): + self.effect_handles.append(EffectHandle(tracers, eqn)) + out_tracer.recipe = eqn return out_tracer def process_call(self, primitive, f: lu.WrappedFun, tracers, params): @@ -256,6 +232,7 @@ def process_call(self, primitive, f: lu.WrappedFun, tracers, params): # which were unknown to the first call (corresponding to in_avals). # Wrap f to perform the partial evaluation and plumb out aux data. + f = f.with_unknown_names() f_ = trace_to_subjaxpr_nounits_fwd(f, self.tag, f.debug_info, False) f_, aux = partial_eval_wrapper_nounits(f_, tuple(in_knowns), tuple(in_avals)) @@ -263,27 +240,12 @@ def process_call(self, primitive, f: lu.WrappedFun, tracers, params): const_params = update_params(params, in_knowns, 0) # Run the call, getting known out vals and aux data used for staged-out call - fun_and_args = (_update_annotation_known(f_, f.in_type, in_knowns),) + tuple(in_consts) + fun_and_args = (f_,) + tuple(in_consts) out = primitive.bind_with_trace(self.parent_trace, fun_and_args, const_params) fwds, out_knowns, out_type, jaxpr, env = aux() # Split apart known outputs from the original call and non-fwded residuals. out_consts, non_fwd_res = split_list(out, [sum(out_knowns)]) - - # Form the complete list of residuals by forwarding some inputs. - if config.dynamic_shapes.value: - # With dynamic shapes, we may need to forward implicit arguments. - assert f.in_type is not None, "f must be annotated with lu.annotate()" - in_consts_, in_knowns_ = iter(in_consts), iter(in_knowns) - in_consts_full = [None] * len(f.in_type) - for idx, (aval, explicit) in enumerate(f.in_type): - if explicit and next(in_knowns_): - c = in_consts_full[idx] = next(in_consts_) - if aval.shape: - for d1, d2 in zip(aval.shape, c.shape): - if type(d1) is DBIdx: - in_consts_full[d1.val] = d2 - else: - in_consts_full = in_consts + in_consts_full = in_consts res = subs_list(fwds, in_consts_full, non_fwd_res) # Create the input tracers for the staged-out (unknown-value) call. @@ -292,25 +254,17 @@ def process_call(self, primitive, f: lu.WrappedFun, tracers, params): unknown_arg_tracers = [t for t in tracers if not t.is_known()] # Adjust parameters (e.g. donated_invars) for the staged-out call's args. num_new_args = len(res_tracers) + len(env_tracers) - staged_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr)) + new_jaxpr = convert_constvars_jaxpr(jaxpr) + if isinstance(primitive, core.ClosedCallPrimitive): + new_jaxpr = close_jaxpr(new_jaxpr) # type: ignore + staged_params = dict(params, call_jaxpr=new_jaxpr) staged_params = update_params(staged_params, map(op.not_, in_knowns), num_new_args) - # The outputs of the staged-out call are Tracers with the new eqn as recipe. - if config.dynamic_shapes.value: - # With dynamic shapes, we may need to substitute Tracers into avals. - out_tracers = [] - for aval, _ in out_type: - if type(aval) is DShapedArray: - shape = [[*res_tracers, *env_tracers, *unknown_arg_tracers][d.val] - if type(d) is InDBIdx else d for d in aval.shape] - aval = aval.update(shape=tuple(shape)) - out_tracers.append(JaxprTracer(self, PartialVal.unknown(aval), None)) - else: - out_tracers = [JaxprTracer(self, PartialVal.unknown(a), None) - for a in out_type] + out_tracers = [JaxprTracer(self, PartialVal.unknown(a), None) + for a in out_type] name_stack = self._current_truncated_name_stack() source = source_info_util.current().replace(name_stack=name_stack) - eqn = new_eqn_recipe((*res_tracers, *env_tracers, *unknown_arg_tracers), + eqn = new_eqn_recipe(self, (*res_tracers, *env_tracers, *unknown_arg_tracers), out_tracers, primitive, staged_params, jaxpr.effects, source) for t in out_tracers: t.recipe = eqn @@ -379,7 +333,7 @@ def const_out_axes_thunk(): for a in out_avals] effs = core.filter_named_axis_effects(jaxpr.effects, {params['axis_name']}) src_info = source_info_util.current() - eqn = new_eqn_recipe((*const_tracers, *env_tracers, *unknown_arg_tracers), + eqn = new_eqn_recipe(self, (*const_tracers, *env_tracers, *unknown_arg_tracers), out_tracers, primitive, staged_params, effs, src_info) for t in out_tracers: t.recipe = eqn @@ -395,8 +349,7 @@ def process_custom_jvp_call(self, prim, fun, jvp, tracers, symbolic_zeros): vals = [t.pval[1] for t in tracers] return prim.bind(fun, jvp, *vals, symbolic_zeros=symbolic_zeros) # We assume non-trivial partial evaluation is only performed to build linear - # functions, and hence we don't need to keep the custom JVP rule around - # anymore. + # functions, and hence we don't need to keep the custom JVP rule around. del jvp, symbolic_zeros with core.set_current_trace(self): return fun.call_wrapped(*tracers) @@ -415,7 +368,7 @@ def process_custom_transpose(self, prim, call, tracers, **params): for aval in params['out_types']] in_tracers = map(self.instantiate_const, tracers) new_params = dict(params, call=call) - eqn = new_eqn_recipe(in_tracers, out_tracers, prim, new_params, + eqn = new_eqn_recipe(self, in_tracers, out_tracers, prim, new_params, core.no_effects, source_info_util.current()) for t in out_tracers: t.recipe = eqn return out_tracers @@ -425,49 +378,45 @@ def process_custom_vjp_call(self, prim, f, fwd, bwd, tracers, out_trees, symboli if all(t.is_known() for t in tracers): vals = [t.pval[1] for t in tracers] with core.set_current_trace(self.parent_trace): - return prim.bind(f, fwd, bwd, *vals, out_trees=out_trees, symbolic_zeros=symbolic_zeros) - else: - # TODO(mattjj): remove non-ad users of partial eval, then drop this case. - # We stage out the whole thing, i.e. no nontrivial partial evaluation. - tracers = map(self.instantiate_const_abstracted, tracers) - # Because we instantiate all tracers, in_knowns is all False. - in_knowns, in_avals, () = partition_pvals([t.pval for t in tracers]) - f = trace_to_subjaxpr_nounits(f, self, True, f.debug_info) - f, aux = partial_eval_wrapper_nounits(f, (*in_knowns,), (*in_avals,)) - with core.set_current_trace(self.parent_trace): - out_flat = prim.bind(f, fwd, bwd, out_trees=out_trees, - symbolic_zeros=symbolic_zeros) - out_knowns, out_avals, jaxpr, env = aux() - out_consts, res = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)]) - res_tracers = map(self.new_instantiated_const, res) - env_tracers = map(self.to_jaxpr_tracer, env) - out_tracers = [JaxprTracer(self, PartialVal.unknown(a), None) - for a in out_avals] - closed_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(jaxpr), ()) - - @_memoize - def fwd_jaxpr_thunk(*zeros): - fwd_ = _interleave_fun(fwd, zeros) - fwd_ = trace_to_subjaxpr_nounits(fwd_, self, True, fwd_.debug_info) - fwd_, aux = partial_eval_wrapper_nounits(fwd_, (*in_knowns,), (*in_avals,)) - out_flat = fwd_.call_wrapped() - out_knowns, out_avals, jaxpr, env = aux() - _, res = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)]) - converted_jaxpr = convert_envvars_to_constvars(jaxpr, len(env)) - return converted_jaxpr, (*res, *env) + return prim.bind(f, fwd, bwd, *vals, out_trees=out_trees, + symbolic_zeros=symbolic_zeros) + + tracers = map(self.instantiate_const, tracers) + in_knowns = (False,) * len(tracers) + in_avals = tuple(t.aval for t in tracers) + f_ = trace_to_subjaxpr_nounits2(f, self.tag, f.debug_info, True) + f_, aux = partial_eval_wrapper_nounits(f_, in_knowns, in_avals) + params = dict(out_trees=out_trees, symbolic_zeros=symbolic_zeros) + res = prim.bind_with_trace(self.parent_trace, (f_, fwd, bwd), params) + out_knowns, out_avals, jaxpr, env = aux() + assert not any(out_knowns) + res_tracers = map(self.instantiate_const, map(self.new_const, res)) + env_tracers = map(self.to_jaxpr_tracer, env) + out_tracers = [JaxprTracer(self, PartialVal.unknown(a), None) + for a in out_avals] + closed_jaxpr = close_jaxpr(convert_constvars_jaxpr(jaxpr)) + + @partial(lu.wrap_init, debug_info=fwd.debug_info) + @_memoize + def fwd_jaxpr_thunk(*zeros): + fwd_ = _interleave_fun(fwd.with_unknown_names(), zeros) + fwd_jaxpr, _, consts = trace_to_jaxpr_dynamic(fwd_, in_avals) + return fwd_jaxpr, consts name_stack = self._current_truncated_name_stack() source = source_info_util.current().replace(name_stack=name_stack) - eqn = new_eqn_recipe((*res_tracers, *env_tracers, *tracers), - out_tracers, prim.initial_style, - dict(fun_jaxpr=closed_jaxpr, - fwd_jaxpr_thunk=fwd_jaxpr_thunk, - num_consts=len(res) + len(env), - bwd=bwd, out_trees=out_trees, - symbolic_zeros=symbolic_zeros), - jaxpr.effects, source) + params = dict( + call_jaxpr=closed_jaxpr, + fwd_jaxpr_thunk=fwd_jaxpr_thunk, + num_consts=len(res) + len(env), + bwd=bwd, + out_trees=out_trees, + symbolic_zeros=symbolic_zeros + ) + eqn = new_eqn_recipe(self, (*res_tracers, *env_tracers, *tracers), + out_tracers, prim, params, jaxpr.effects, source) for t in out_tracers: t.recipe = eqn - return merge_lists(out_knowns, out_tracers, out_consts) + return out_tracers def partition_pvals( pvals: list[PartialVal] @@ -494,20 +443,31 @@ def partial_eval_wrapper_nounits( store.store((*maybe_fwds, out_knowns, out_avals, jaxpr, env)) return (*out_consts, *res) +@lu.transformation_with_aux2 +def partial_eval_wrapper_nounits2( + f: Callable, + store: lu.Store, + in_knowns: Sequence[bool], + in_avals: Sequence[AbstractValue], + *in_consts: Any): + in_avals_, in_consts_ = iter(in_avals), iter(in_consts) + in_pvals = [PartialVal.known(next(in_consts_)) if known else + PartialVal.unknown(next(in_avals_)) for known in in_knowns] + sentinel = object() + assert next(in_avals_, sentinel) is next(in_consts_, sentinel) is sentinel + jaxpr, (*maybe_fwds, out_pvals, res, env) = f(in_pvals) + out_knowns, _, out_consts = partition_pvals(out_pvals) + res_avals = [typeof(r) for r in res] + store.store((*maybe_fwds, out_knowns, res_avals, jaxpr, env)) + return (*out_consts, *res) + custom_partial_eval_rules: dict[Primitive, Callable] = {} call_partial_eval_rules: dict[Primitive, Callable] = {} call_param_updaters: dict[Primitive, Callable] = {} -def _closed_call_param_updater(params, _, __): - jaxpr = params.get('call_jaxpr') - if jaxpr is None: return params - assert type(jaxpr) is core.Jaxpr - return dict(params, call_jaxpr=core.ClosedJaxpr(jaxpr, ())) -call_param_updaters[core.closed_call_p] = _closed_call_param_updater - def abstract_eval_fun(fun: Callable, *avals, debug_info: core.DebugInfo, **params): - _, avals_out, _, () = trace_to_jaxpr_dynamic( + _, avals_out, _ = trace_to_jaxpr_dynamic( lu.wrap_init(fun, params, debug_info=debug_info), avals) assert all(isinstance(aval, AbstractValue) for aval in avals_out) return avals_out @@ -540,8 +500,6 @@ def parents(self) -> Sequence[JaxprTracer]: if isinstance(self.recipe, JaxprEqnRecipe): # TODO broadcast_in_dim can create a new tracer... return self.recipe.in_tracers - elif isinstance(self.aval, DShapedArray): - return [d for d in self.aval.shape if isinstance(d, JaxprTracer)] else: return [] @@ -633,7 +591,9 @@ def _trace_to_subjaxpr_nounits(f: Callable, trace: JaxprTrace, out_tracers = [trace.instantiate_const(t) if inst else t for inst, t in zip(instantiate, out_tracers)] out_tracers_ = [t for t in out_tracers if not t.is_known()] - jaxpr, out_consts, env = tracers_to_jaxpr(in_tracers, out_tracers_, debug_info) + jaxpr, out_consts, env = tracers_to_jaxpr( + in_tracers, out_tracers_, trace.effect_handles, + debug_info.with_unknown_names()) return out_tracers, jaxpr, out_consts, env # The below variant implements an optimization where residuals which are also @@ -715,7 +675,8 @@ class JaxprEqnRecipe(NamedTuple): source_info: source_info_util.SourceInfo ctx: JaxprEqnContext -def new_eqn_recipe(in_tracers: Sequence[JaxprTracer], +def new_eqn_recipe(trace: JaxprTrace, + in_tracers: Sequence[JaxprTracer], out_tracers: Sequence[JaxprTracer], primitive: Primitive, params: dict[str, Any], @@ -734,11 +695,11 @@ def new_eqn_recipe(in_tracers: Sequence[JaxprTracer], len(params["donated_invars"]) == len(params["call_jaxpr"].invars)) out_avals = [t.aval for t in out_tracers] ctx = ctx or JaxprEqnContext( - compute_on.current_compute_type(), + config.compute_on_context_manager.value, config.threefry_partitionable.value, xla_metadata_lib.current_xla_metadata(), ) - return JaxprEqnRecipe(object(), tuple(in_tracers), map(ref, out_tracers), + return JaxprEqnRecipe(next(trace.counter), tuple(in_tracers), map(ref, out_tracers), out_avals, primitive, params, effects, source_info, ctx) @@ -756,6 +717,7 @@ def recipe_to_eqn(getvar: Callable[[JaxprTracer], Atom], def tracers_to_jaxpr( in_tracers: Sequence[JaxprTracer], out_tracers: Sequence[JaxprTracer], + effect_handles: Sequence[Any], debug_info: core.DebugInfo, ) -> tuple[Jaxpr, tuple[Any, ...], tuple[Any, ...]]: """Constructs Jaxpr given tracers for inputs and outputs. @@ -782,31 +744,34 @@ def get_atom(t: JaxprTracer) -> Atom: def newvar(t: JaxprTracer | None) -> Var: assert t is not None - var = gensym(type_substitute(t.aval)) + var = gensym(t.aval) var_ = t_to_var.setdefault(id(t), var) assert var is var_ return var - def type_substitute(aval: AbstractValue) -> AbstractValue: - if isinstance(aval, DShapedArray): - # Replace any Tracers in aval.shape with Vars or Literal values - shape = [get_atom(d) if type(d) is JaxprTracer else d for d in aval.shape] - shape = [d.val if type(d) is Literal else d for d in shape] - aval = aval.update(shape=tuple(shape)) - return aval - processed_eqn_ids = set() eqns: list[core.JaxprEqn] = [] - for t in toposort((*in_tracers, *out_tracers)): + is_high = False + + reachable = toposort + tracers = reachable((*in_tracers, *out_tracers, *effect_handles)) + def sort_key(t): + r = t.recipe + return r.eqn_id if isinstance(r, JaxprEqnRecipe) else -1 + tracers = sorted(tracers, key=sort_key) + + for t in tracers: r = t.recipe if isinstance(r, JaxprEqnRecipe): # TODO broadcast_in_dim can create a new tracer, not present in parents if r.eqn_id not in processed_eqn_ids: in_atoms = map(get_atom, r.in_tracers) - outvars = [DropVar(type_substitute(a)) if rf() is None else newvar(rf()) + outvars = [DropVar(a) if rf() is None else newvar(rf()) for a, rf in zip(r.out_avals, r.out_tracer_refs)] eqns.append(new_jaxpr_eqn(in_atoms, outvars, r.primitive, r.params, r.effects, r.source_info, r.ctx)) + in_avals = [x.aval for x in in_atoms] + is_high |= r.primitive.is_high(*in_avals, **r.params) processed_eqn_ids.add(r.eqn_id) elif isinstance(r, LambdaBinding): if not any(t is in_tracer for in_tracer in in_tracers): @@ -832,9 +797,9 @@ def type_substitute(aval: AbstractValue) -> AbstractValue: const_vars, const_vals = unzip2(consts.items()) outvars = map(get_atom, out_tracers) # type: ignore[arg-type] jaxpr_effects = make_jaxpr_effects(const_vars, invars, outvars, eqns) + is_high |= any(x.aval.is_high for x in it.chain(const_vars, invars, outvars)) jaxpr = Jaxpr(const_vars, invars, # type: ignore[arg-type] - outvars, eqns, jaxpr_effects, - debug_info) + outvars, eqns, jaxpr_effects, debug_info, is_high) config.enable_checks.value and core.check_jaxpr(jaxpr) # del getvar # needed to avoid cyclic-reference closure, apparently! return jaxpr, const_vals, env_vals @@ -844,16 +809,22 @@ def move_envvars(jaxpr: Jaxpr, which: tuple[bool, ...]) -> Jaxpr: constvars, envvars = partition_list(which, jaxpr.constvars) return jaxpr.replace(constvars=constvars, invars=[*envvars, *jaxpr.invars]) +@weakref_lru_cache +def separate_consts(jaxpr: ClosedJaxpr) -> tuple[ClosedJaxpr, list[Any]]: + """Moves the constvars to the start of invars and returns the consts explicitly.""" + return close_jaxpr(convert_constvars_jaxpr(jaxpr.jaxpr)), jaxpr.consts + @weakref_lru_cache def convert_constvars_jaxpr(jaxpr: Jaxpr) -> Jaxpr: """Moves the constvars to the start of invars.""" config.enable_checks.value and core.check_jaxpr(jaxpr) - dbg = jaxpr.debug_info._replace( - arg_names=("",) * len(jaxpr.constvars) + jaxpr.debug_info.arg_names) - lifted_jaxpr = Jaxpr(constvars=(), - invars=jaxpr.constvars + jaxpr.invars, - outvars=jaxpr.outvars, eqns=jaxpr.eqns, - effects=jaxpr.effects, debug_info=dbg) + if jaxpr.debug_info.arg_names is None: + arg_names = None + else: + arg_names = ("",) * len(jaxpr.constvars) + (*jaxpr.debug_info.arg_names,) + dbg = jaxpr.debug_info._replace(arg_names=arg_names) + lifted_jaxpr = jaxpr.replace( + constvars=(), invars=jaxpr.constvars + jaxpr.invars, debug_info=dbg) config.enable_checks.value and core.check_jaxpr(lifted_jaxpr) return lifted_jaxpr @@ -864,8 +835,11 @@ def convert_invars_to_constvars(jaxpr: Jaxpr, n: int) -> Jaxpr: return jaxpr.replace() # 'return jaxpr' would create cache reference cycle config.enable_checks.value and core.check_jaxpr(jaxpr) constvars, invars = split_list(jaxpr.invars, [n]) - dbg = jaxpr.debug_info._replace( - arg_names=jaxpr.debug_info.arg_names[n:]) + if jaxpr.debug_info.arg_names is None: + dbg = jaxpr.debug_info + else: + dbg = jaxpr.debug_info._replace( + arg_names=jaxpr.debug_info.arg_names[n:]) lifted_jaxpr = jaxpr.replace(constvars=tuple(constvars), invars=invars, debug_info=dbg) config.enable_checks.value and core.check_jaxpr(lifted_jaxpr) @@ -876,9 +850,8 @@ def convert_envvars_to_constvars(jaxpr: Jaxpr, num_env_vars: int) -> Jaxpr: raise NotImplementedError config.enable_checks.value and core.check_jaxpr(jaxpr) env_vars, invars = split_list(jaxpr.invars, [num_env_vars]) - converted_jaxpr = Jaxpr(constvars=jaxpr.constvars + env_vars, - invars=invars, outvars=jaxpr.outvars, eqns=jaxpr.eqns, - effects=jaxpr.effects, debug_info=jaxpr.debug_info) + converted_jaxpr = jaxpr.replace(constvars=jaxpr.constvars + env_vars, + invars=invars) config.enable_checks.value and core.check_jaxpr(converted_jaxpr) return converted_jaxpr @@ -944,75 +917,78 @@ def partial_eval_jaxpr_nounits( passed to jaxpr_unknown (as leading inputs). """ instantiate = tuple(instantiate) if isinstance(instantiate, list) else instantiate - return _partial_eval_jaxpr_nounits(jaxpr, tuple(unknowns), instantiate) + return _partial_eval_jaxpr_nounits(jaxpr, tuple(unknowns), instantiate, False)[:-1] + +def partial_eval_jaxpr_nounits_fwd( + jaxpr: ClosedJaxpr, unknowns: Sequence[bool], + instantiate: bool | Sequence[bool], + fwd: bool | Sequence[bool] = True, +) -> tuple[ClosedJaxpr, ClosedJaxpr, list[bool], list[AbstractValue], list[int | None]]: + instantiate = tuple(instantiate) if isinstance(instantiate, list) else instantiate + fwd = tuple(fwd) if isinstance(fwd, list) else fwd + return _partial_eval_jaxpr_nounits(jaxpr, tuple(unknowns), instantiate, fwd) @weakref_lru_cache -def _partial_eval_jaxpr_nounits(jaxpr: ClosedJaxpr, - in_unknowns: Sequence[bool], - instantiate: bool | Sequence[bool]): - f = lu.wrap_init(core.jaxpr_as_fun(jaxpr), - debug_info=jaxpr.jaxpr.debug_info) +def _partial_eval_jaxpr_nounits( + jaxpr: ClosedJaxpr, in_unknowns: Sequence[bool], + instantiate: bool | Sequence[bool], fwd: bool | Sequence[bool]): + f = lu.wrap_init(core.jaxpr_as_fun(jaxpr), debug_info=jaxpr.jaxpr.debug_info) cell = [] def fun(*known_vals_in): - known_vals_in = iter(known_vals_in) + known_vals_in_ = iter(known_vals_in) unknown_avals = (a for a, uk in zip(jaxpr.in_avals, in_unknowns) if uk) in_pvals = [PartialVal.unknown(next(unknown_avals)) if uk - else PartialVal.known(next(known_vals_in)) for uk in in_unknowns] - assert next(known_vals_in, None) is next(unknown_avals, None) is None - jaxpr_unknown_, out_pvals, residuals = trace_to_jaxpr_nounits( - f, in_pvals, instantiate=instantiate) + else PartialVal.known(next(known_vals_in_)) for uk in in_unknowns] + assert next(known_vals_in_, None) is next(unknown_avals, None) is None + jaxpr_unknown_, (fwds, out_pvals, residuals, ()) = trace_to_subjaxpr_nounits_fwd( + f, TraceTag(), jaxpr.jaxpr.debug_info, instantiate).call_wrapped(in_pvals) jaxpr_unknown = convert_constvars_jaxpr(jaxpr_unknown_) out_unknowns = [not pval.is_known() for pval in out_pvals] + if type(fwd) is bool and not fwd: + residuals_ = iter(residuals) + residuals = [next(residuals_) if f is None else known_vals_in[f] + for f in fwds] + assert next(residuals_, None) is None + fwds = [None] * len(fwds) + else: + if type(fwd) is tuple: + fwd_ = [f for f, uk in zip(fwd, in_unknowns) if not uk] + residuals_, residuals = iter(residuals), [] + fwds = [residuals.append(next(residuals_)) if f is None else + residuals.append(known_vals_in[f]) if not fwd_[f] else + f for f in fwds] + fwds, residuals = _include_consts_in_fwds(jaxpr.consts, fwds, residuals) res_avals = [core.get_aval(r) for r in residuals] - cell.append((out_unknowns, jaxpr_unknown, res_avals)) + cell.append((out_unknowns, jaxpr_unknown, res_avals, fwds)) known_vals_out = [pval.get_known() for pval in out_pvals if pval.is_known()] return [*known_vals_out, *residuals] - known_avals = [a for a, uk in zip(jaxpr.in_avals, in_unknowns) if not uk] - jaxpr_known, _, consts_known, () = trace_to_jaxpr_dynamic( - lu.wrap_init(fun, debug_info=f.debug_info), + known_avals = [a for a, uk in zip(jaxpr.in_aval_qdds, in_unknowns) if not uk] + jaxpr_known, _, consts_known = trace_to_jaxpr_dynamic( + lu.wrap_init(fun, debug_info=f.debug_info.with_unknown_names()), known_avals) - (out_unknowns, jaxpr_unknown, res_avals), = cell # pytype: disable=bad-unpacking + (out_unknowns, jaxpr_unknown, res_avals, fwds), = cell # pytype: disable=bad-unpacking - # check jaxpr_known and jaxpr_unknown in isolation - # TODO(mattjj): enable weak type checking here if config.enable_checks.value: core.check_jaxpr(jaxpr_known) core.check_jaxpr(jaxpr_unknown) - def check(first, second): - for f, s in zip(first, second): - if (not isinstance(f, core.ShapedArray) and - not isinstance(s, core.ShapedArray)): - assert f == s - elif f.sharding.mesh.empty or s.sharding.mesh.empty: - assert (f.shape, f.dtype) == (s.shape, s.dtype) - else: - assert f == s, (f, s) - - # check jaxpr_known has input type corresponding to known inputs of jaxpr - assert ([v.aval for v in jaxpr_known.invars] == - [a for a, uk in zip(jaxpr.in_avals, in_unknowns) if not uk]) - # check jaxpr_known has out type corresponding to known outs of jaxpr plus res - # Change this to `assert ... == ...` and remove the check function. - # See https://github.com/jax-ml/jax/issues/26474 - check([v.aval.strip_weak_type() for v in jaxpr_known.outvars], - [a.strip_weak_type() for a, uk in zip(jaxpr.out_avals, out_unknowns) - if not uk] + [a.strip_weak_type() for a in res_avals]) - # check jaxpr_unknown has input type corresponding to res plus unknown inputs - assert ([v.aval.strip_weak_type() for v in jaxpr_unknown.invars] == - [a.strip_weak_type() for a in res_avals] + - [a.strip_weak_type() for a, uk in zip(jaxpr.in_avals, in_unknowns) - if uk]) - # check jaxpr_unknown has output type corresponding to unknown outputs - check([v.aval.strip_weak_type() for v in jaxpr_unknown.outvars], - [a.strip_weak_type() for a, uk in zip(jaxpr.out_avals, out_unknowns) - if uk]) - closed_jaxpr_known = ClosedJaxpr(jaxpr_known, consts_known) closed_jaxpr_unknown = ClosedJaxpr(jaxpr_unknown, ()) - return closed_jaxpr_known, closed_jaxpr_unknown, out_unknowns, res_avals + return closed_jaxpr_known, closed_jaxpr_unknown, out_unknowns, res_avals, fwds + +def _include_consts_in_fwds(consts, fwds, residuals): + if all(f is None for f in fwds): + return fwds, residuals + dummys = [object() for _ in range(max(f for f in fwds if f is not None) + 1)] + residuals_ = iter(residuals) + residuals = [next(residuals_) if f is None else dummys[f] for f in fwds] + assert next(residuals_, None) is None + idxs = {id(x): i for i, x in enumerate((*consts, *dummys))} + fwds = [idxs.get(id(r)) for r in residuals] + residuals = [r for r in residuals if id(r) not in idxs] + return fwds, residuals def partial_eval_jaxpr_custom( @@ -1081,9 +1057,9 @@ def ensure_instantiated(inst: bool, x: Atom) -> Atom: return x def has_effects(effects) -> bool: - return bool({e for e in effects if not isinstance(e, core.NamedAxisEffect)}) + not_really_effects = (core.NamedAxisEffect, core.InternalMutableArrayEffect) + return any(not isinstance(e, not_really_effects) for e in effects) - newvar = core.gensym(suffix='_offload') known_eqns, staged_eqns = [], [] foreach(write, in_unknowns, in_inst, jaxpr.invars) foreach(partial(write, False, True), jaxpr.constvars) @@ -1112,25 +1088,31 @@ def has_effects(effects) -> bool: foreach(partial(write, False, False), eqn.outvars) elif isinstance(policy, Offloadable): # TODO(slebedev): This is a legit error which requires a BUILD fix. - from jax._src.dispatch import device_put_p, TransferToMemoryKind, CopySemantics # pytype: disable=import-error - resvars = [newvar(v.aval) for v in eqn.outvars] - outvars_copy = list[Atom](eqn.outvars) + from jax._src.dispatch import device_put_p, ArrayCopySemantics # type: ignore + resvars = [Var(v.aval.update(memory_space=core.mem_kind_to_space(policy.dst))) + for v in eqn.outvars] offload_eqn = core.JaxprEqn( - outvars_copy, resvars, device_put_p, - dict(devices=[TransferToMemoryKind(policy.dst) - ] * len(outvars_copy), srcs=[None], - copy_semantics=[CopySemantics.COPY]), + eqn.outvars, resvars, device_put_p, + dict( + devices=(core.mem_kind_to_space(policy.dst),) * len(eqn.outvars), + srcs=(None,), + copy_semantics=(ArrayCopySemantics.ALWAYS_COPY,), + ), set(), source_info_util.new_source_info(), JaxprEqnContext(None, False)) known_eqns.append(offload_eqn) # resvars are known and available in the backward jaxpr. foreach(partial(write, False, True), resvars) + assert all(o.aval.memory_space == core.mem_kind_to_space(policy.src) # type: ignore + for o in eqn.outvars) residuals.update(resvars) reload_eqn = core.JaxprEqn( resvars, eqn.outvars, device_put_p, - dict(devices=[TransferToMemoryKind(policy.src) - ] * len(resvars), srcs=[None], - copy_semantics=[CopySemantics.COPY]), + dict( + devices=(core.mem_kind_to_space(policy.src),) * len(resvars), + srcs=(None,), + copy_semantics=(ArrayCopySemantics.ALWAYS_COPY,) + ), set(), source_info_util.new_source_info(), JaxprEqnContext(None, False)) staged_eqns.append(reload_eqn) @@ -1158,8 +1140,12 @@ def has_effects(effects) -> bool: known_outvars = [*outs_known, *residuals] known_effects = make_jaxpr_effects(jaxpr.constvars, ins_known_and_ref_res, known_outvars, known_eqns) - jaxpr_known = Jaxpr(jaxpr.constvars, ins_known_and_ref_res, known_outvars, - known_eqns, known_effects, jaxpr.debug_info) + + # TODO(mattjj,necula): debug info should be updated here + jaxpr_known = jaxpr.replace( + invars=ins_known_and_ref_res, outvars=known_outvars, + eqns=known_eqns, effects=known_effects, + debug_info=jaxpr.debug_info.with_unknown_names()) config.enable_checks.value and core.check_jaxpr(jaxpr_known) _, ins_staged = partition_list(in_inst, jaxpr.invars) @@ -1167,9 +1153,11 @@ def has_effects(effects) -> bool: staged_invars = [*residuals, *non_input_res_refs, *ins_staged] staged_effects = make_jaxpr_effects(jaxpr.constvars, staged_invars, outs_staged, staged_eqns) - jaxpr_staged = Jaxpr(jaxpr.constvars, staged_invars, - outs_staged, staged_eqns, staged_effects, - jaxpr.debug_info) + # TODO(mattjj,necula): debug info should be updated here + jaxpr_staged = jaxpr.replace( + invars=staged_invars, outvars=outs_staged, eqns=staged_eqns, + effects=staged_effects, + debug_info=jaxpr.debug_info.with_unknown_names()) config.enable_checks.value and core.check_jaxpr(jaxpr_staged) return (jaxpr_known, jaxpr_staged, out_unknowns, out_inst, len(residuals), @@ -1194,6 +1182,15 @@ class Offloadable(NamedTuple): def ensure_enum(case: bool | RematCases) -> RematCases: if isinstance(case, bool): return Saveable if case else Recompute + if not isinstance(case, (RecomputeType, SaveableType, Offloadable)): + msg = ("Value returned by a remat policy should be a bool or" + " `ad_checkpoint.Recompute`, `ad_checkpoint.Saveable` or" + " `ad_checkpoint.Offloadable(...)`." + f" Got {case} of type {type(case)}.") + if isinstance(case, Offloadable): + msg += ("Did you return `Offloadable` instead of an instantiated" + " `Offloadable(...)`?") + raise TypeError(msg) return case # A primitive rule for policy-driven partial evaluation returns a 5-tuple @@ -1229,14 +1226,12 @@ def _default_res_aval_updater( params: dict[str, Any], aval: AbstractValue) -> AbstractValue: return aval -@contextmanager -def trivial_ctx(_): yield def call_partial_eval_custom_rule( jaxpr_param_name: str, params_updater: ParamsUpdater, saveable: Callable[..., RematCases_], unks_in: list[bool], inst_in: list[bool], eqn: JaxprEqn, *, res_aval: ResAvalUpdater = _default_res_aval_updater, - ctx = trivial_ctx, + ctx = contextlib.nullcontext, ) -> tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], list[Var]]: jaxpr = eqn.params[jaxpr_param_name] with ctx(eqn.params): @@ -1246,20 +1241,20 @@ def call_partial_eval_custom_rule( out_binders_known, _ = partition_list(unks_out, eqn.outvars) _, ins_staged = partition_list(inst_in, eqn.invars) _, out_binders_staged = partition_list(inst_out, eqn.outvars) - newvar = core.gensym() params_known = {**eqn.params, jaxpr_param_name: jaxpr_known} params_staged = {**eqn.params, jaxpr_param_name: jaxpr_staged} params_known, params_staged = params_updater( unks_in, inst_in, map(op.not_, unks_out), inst_out, num_res, params_known, params_staged) - residuals = [newvar(res_aval(params_known, var.aval)) + residuals = [Var(res_aval(params_known, var.aval)) for var in jaxpr_staged.invars[:num_res]] - eqn_known = new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals], - eqn.primitive, params_known, jaxpr_known.effects, - eqn.source_info, eqn.ctx) - eqn_staged = new_jaxpr_eqn([*residuals, *ins_staged], out_binders_staged, - eqn.primitive, params_staged, - jaxpr_staged.effects, eqn.source_info, eqn.ctx) + eqn_known = new_jaxpr_eqn( + ins_known, [*out_binders_known, *residuals], eqn.primitive, params_known, + core.eqn_effects(jaxpr_known), eqn.source_info, eqn.ctx) + eqn_staged = new_jaxpr_eqn( + [*residuals, *ins_staged], out_binders_staged, eqn.primitive, + params_staged, core.eqn_effects(jaxpr_staged), eqn.source_info, + eqn.ctx) assert len(eqn_staged.invars) == len(jaxpr_staged.invars) new_inst = [x for x, inst in zip(eqn.invars, inst_in) if type(x) is Var and not inst] @@ -1285,25 +1280,24 @@ def closed_call_partial_eval_custom_rule( ins_known, _ = partition_list(unks_in, eqn.invars) _, ins_staged = partition_list(inst_in, eqn.invars) _, out_binders_staged = partition_list(inst_out, eqn.outvars) - newvar = core.gensym() params_known = {**eqn.params, jaxpr_param_name: jaxpr_known} params_staged = {**eqn.params, jaxpr_param_name: jaxpr_staged} params_known, params_staged = params_updater( unks_in, inst_in, map(op.not_, unks_out), inst_out, sum(f is None for f in out_fwd), num_res, params_known, params_staged) res_val_binders, res_ref_binders = split_list( - [newvar(res_aval(params_known, v)) + [Var(res_aval(params_known, v)) for v in jaxpr_staged.in_avals[:num_res]], [num_res_val]) res_val_binders = [v for v, f in zip(res_val_binders, out_fwd) if f is None] res_val_vars = subs_list(out_fwd, out_binders_known, res_val_binders) - eqn_known = new_jaxpr_eqn([*ins_known, *res_ref_binders], - [*out_binders_known, *res_val_binders], - eqn.primitive, params_known, jaxpr_known.effects, - eqn.source_info, eqn.ctx) - eqn_staged = new_jaxpr_eqn([*res_val_vars, *res_ref_binders, *ins_staged], - out_binders_staged, - eqn.primitive, params_staged, jaxpr_staged.effects, - eqn.source_info, eqn.ctx) + eqn_known = new_jaxpr_eqn( + [*ins_known, *res_ref_binders], [*out_binders_known, *res_val_binders], + eqn.primitive, params_known, core.eqn_effects(jaxpr_known), + eqn.source_info, eqn.ctx) + eqn_staged = new_jaxpr_eqn( + [*res_val_vars, *res_ref_binders, *ins_staged], out_binders_staged, + eqn.primitive, params_staged, core.eqn_effects(jaxpr_staged), + eqn.source_info, eqn.ctx) assert len(eqn_staged.invars) == len(jaxpr_staged.in_avals) assert len(ins_known) + len(res_ref_binders) == len(jaxpr_known.jaxpr.invars) assert len(ins_staged) + len(res_ref_binders) + len(res_val_vars) == len(jaxpr_staged.jaxpr.invars) @@ -1350,15 +1344,15 @@ def _closed_jaxpr_partial_eval_custom_cached( def _jaxpr_forwarding(jaxpr: Jaxpr) -> list[int | None]: # Compute which inputs are just forwarded to outputs. - fwds: dict[Var, Var] = dict(zip(jaxpr.invars, jaxpr.invars)) + fwds: dict[Var, Atom] = dict(zip(jaxpr.invars, jaxpr.invars)) for eqn in jaxpr.eqns: if eqn.primitive in forwarding_rules: eqn = eqn.replace(invars=[a if type(a) is Literal else fwds.get(a, a) # type: ignore for a in eqn.invars]) - fwd_vars, _ = forwarding_rules[eqn.primitive](eqn) - for v_orig, v_new in zip(eqn.outvars, fwd_vars): - if v_new is not None: - fwds[v_orig] = v_new + fwd_idx, _ = forwarding_rules[eqn.primitive](eqn) + for v_orig, idx in zip(eqn.outvars, fwd_idx): + if idx is not None: + fwds[v_orig] = eqn.invars[idx] idxs: dict[Var, int] = {v: i for i, v in enumerate(jaxpr.invars)} return [None if type(v) is Literal else idxs.get(fwds.get(v)) # type: ignore for v in jaxpr.outvars] @@ -1462,7 +1456,8 @@ def write(x: Atom, b: bool) -> None: jaxpr.debug_info.traced_for, jaxpr.debug_info.func_src_info, jaxpr.debug_info.filter_arg_names(used_inputs), jaxpr.debug_info.filter_result_paths(used_outputs)) - new_jaxpr = Jaxpr(jaxpr.constvars, invars, outvars, eqns, jaxpr_effects, dbg) + new_jaxpr = jaxpr.replace(invars=invars, outvars=outvars, eqns=eqns, + effects=jaxpr_effects, debug_info=dbg) config.enable_checks.value and core.check_jaxpr(new_jaxpr) return new_jaxpr, used_inputs @@ -1515,17 +1510,37 @@ def dce_jaxpr_closed_call_rule(used_outputs: list[bool], eqn: JaxprEqn return [False] * len(eqn.invars), None jaxpr_ = eqn.params['call_jaxpr'] closed_jaxpr, used_inputs = _cached_closed_call_dce(jaxpr_, tuple(used_outputs)) + effects = core.eqn_effects(closed_jaxpr) new_params = dict(eqn.params, call_jaxpr=closed_jaxpr) new_eqn = new_jaxpr_eqn( [v for v, used in zip(eqn.invars, used_inputs) if used], [v for v, used in zip(eqn.outvars, used_outputs) if used], - eqn.primitive, new_params, closed_jaxpr.effects, eqn.source_info, eqn.ctx) + eqn.primitive, new_params, effects, eqn.source_info, eqn.ctx) return used_inputs, new_eqn dce_rules[core.closed_call_p] = dce_jaxpr_closed_call_rule @weakref_lru_cache def close_jaxpr(jaxpr: Jaxpr) -> ClosedJaxpr: - return ClosedJaxpr(jaxpr, ()) + # The `jaxpr.replace()` is making a copy of the Jaxpr, without which + # the cache value would have a strong reference to the same Jaxpr as + # the key, and we would never gc the cache entry. This works because + # Jaxpr is hashed by id, and the cache entry is dead is the key is dead. + return ClosedJaxpr(jaxpr.replace(), ()) + +def move_invars_right(jaxpr: ClosedJaxpr, to_move: Sequence[bool]): + return _move_invars_right(jaxpr, tuple(to_move)) + +@weakref_lru_cache +def _move_invars_right(jaxpr: ClosedJaxpr, to_move: tuple[bool, ...]): + invars, rest = split_list(jaxpr.jaxpr.invars, [len(to_move)]) + left_invars, right_invars = partition_list(to_move, invars) + new_invars = [*left_invars, *right_invars, *rest] + new_effs = _renumber_effects( + (*jaxpr.jaxpr.constvars, *new_invars), + (*jaxpr.jaxpr.constvars, *jaxpr.jaxpr.invars), + jaxpr.jaxpr.effects) + new_jaxpr = jaxpr.jaxpr.replace(invars=new_invars, effects=new_effs) + return jaxpr.replace(jaxpr=new_jaxpr) def move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool] ) -> ClosedJaxpr: @@ -1533,16 +1548,27 @@ def move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool] return _move_binders_to_front(closed_jaxpr, tuple(to_move)) @weakref_lru_cache -def _move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: tuple[bool, ...] +def _move_binders_to_front(jaxpr: ClosedJaxpr, to_move: tuple[bool, ...] ) -> ClosedJaxpr: - assert len(closed_jaxpr.in_avals) == len(to_move) - new_invars = _move_to_front(closed_jaxpr.jaxpr.invars, to_move) - new_jaxpr = Jaxpr(closed_jaxpr.jaxpr.constvars, new_invars, - closed_jaxpr.jaxpr.outvars, closed_jaxpr.jaxpr.eqns, - closed_jaxpr.jaxpr.effects, - closed_jaxpr.jaxpr.debug_info) - new_closed_jaxpr = core.ClosedJaxpr(new_jaxpr, closed_jaxpr.consts) - return new_closed_jaxpr + assert len(jaxpr.in_avals) == len(to_move) + constvars, invars = jaxpr.jaxpr.constvars, jaxpr.jaxpr.invars + new_invars = _move_to_front(invars, to_move) + new_effs = _renumber_effects( + (*constvars, *new_invars), (*constvars, *invars), jaxpr.jaxpr.effects) + if jaxpr.jaxpr.debug_info.arg_names is None: + new_arg_names = None + else: + new_arg_names = tuple(_move_to_front(jaxpr.jaxpr.debug_info.arg_names, to_move)) + dbg = jaxpr.jaxpr.debug_info._replace(arg_names=new_arg_names) + new_jaxpr = jaxpr.jaxpr.replace( + constvars=constvars, invars=new_invars, effects=new_effs, debug_info=dbg) + return core.ClosedJaxpr(new_jaxpr, jaxpr.consts) + +def _renumber_effects(new_vars, old_vars, effs): + newvar_idxs = {id(v): i for i, v in enumerate(new_vars)} + old_to_new = {i: newvar_idxs[id(v)] for i, v in enumerate(old_vars)} + return {e.replace(input_index=old_to_new[e.input_index]) + if isinstance(e, effects.JaxprInputEffect) else e for e in effs} def _move_to_front(lst: Sequence, to_move: Sequence[bool]) -> Sequence: return ([elt for elt, move in zip(lst, to_move) if move] + @@ -1553,23 +1579,63 @@ def move_binders_to_back(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool] """Reorder `invars` by moving those indicated in `to_move` to the back.""" return move_binders_to_front(closed_jaxpr, map(op.not_, to_move)) +def move_outvars_to_back(jaxpr: ClosedJaxpr, to_move: Sequence[bool]) -> ClosedJaxpr: + return _move_outvars_to_back(jaxpr, tuple(to_move)) + +@weakref_lru_cache +def _move_outvars_to_back(jaxpr: core.ClosedJaxpr, to_move): + new_outvars = ([e for e, m in zip(jaxpr.jaxpr.outvars, to_move) if not m] + + [e for e, m in zip(jaxpr.jaxpr.outvars, to_move) if m]) + return jaxpr.replace(jaxpr=jaxpr.jaxpr.replace(outvars=new_outvars)) + + class DynamicJaxprTracer(core.Tracer): - __slots__ = ['aval', '_debug_info'] + __slots__ = ['aval', 'val', 'mutable_qdd', 'parent', '_debug_info'] def __init__(self, trace: DynamicJaxprTrace, - aval: core.AbstractValue, - line_info: source_info_util.SourceInfo | None = None): + aval: core.AbstractValue | core.AvalQDD, + val : Atom, + line_info: source_info_util.SourceInfo | None = None, + parent : TracingEqn | None = None): + # TODO(dougalm): Remove aval. It's redundant now that we have val. + if isinstance(aval, core.AvalQDD): + assert aval.qdd is not None + aval, qdd = aval.aval, aval.qdd + else: + assert not aval.has_qdd + qdd = None self._trace = trace self._line_info = line_info self._debug_info = self._trace.frame.debug_info # for UnexpectedTracerError self.aval = aval # type: ignore[misc] + self.val = val + self.mutable_qdd = core.MutableQuasiDynamicData(qdd) + self.parent = parent + + def _short_repr(self): + return f"JitTracer({self.aval})" + + def cur_qdd(self): + return self.mutable_qdd.cur_val + + @property + def aval_mutable_qdd(self): + aval = self.aval + if aval.has_qdd: + return core.AvalMutableQDD(aval, self.mutable_qdd) + else: + return aval def full_lower(self): - var = self._trace.frame.tracer_to_var.get(id(self)) - if var is None: return self - val = self._trace.frame.constvar_to_val.get(var) - if val is None: return self - return core.full_lower(val) + atom = self.val + if isinstance(atom, Literal): + return self.val.val + else: + maybe_const = self._trace.frame.constvar_to_val.get(atom) + if maybe_const is None: + return self + else: + return core.full_lower(maybe_const.canonical) def _contents(self): return () @@ -1584,7 +1650,8 @@ def _origin_msg(self): f"{dbg.func_src_info} for {dbg.traced_for}. ") if invar_pos: try: - arg_names = [dbg.arg_names[i] for i in invar_pos] + arg_names = [(dbg.arg_names[i] if dbg.arg_names is not None else "unknown") + for i in invar_pos] except IndexError: return "" # TODO(mattjj): figure out when not (invar_pos < len(arg_info)) if len(arg_names) == 1: @@ -1608,10 +1675,14 @@ def _origin_msg(self): origin += "\n\n(Additional originating lines are not shown.)" return "\n" + origin + def get_const(self): + return self._trace.get_const(self) + def get_referent(self): frame = self._trace.frame - val = frame.constvar_to_val.get(frame.tracer_to_var.get(id(self))) - return self if val is None else get_referent(val) + atom = self.val + val = frame.constvar_to_val.get(atom) if isinstance(atom, Var) else None + return self if val is None else get_referent(val.canonical) core.pytype_aval_mappings[DynamicJaxprTracer] = lambda x: x.aval @@ -1621,377 +1692,432 @@ def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects: all_vars = {v: i for i, v in enumerate(it.chain(constvars, invars))} mut_arrays = set() for eqn in eqns: - if eqn.primitive is core.mutable_array_p: + if eqn.primitive in core._ref_allocating_primitives: outvar, = eqn.outvars all_vars[outvar] = None # type: ignore mut_arrays.add(outvar) for eff in eqn.effects: if isinstance(eff, effects.JaxprInputEffect): if eff.input_index >= len(eqn.invars): + # TODO(mattjj): ask for forgiveness + dbg = type('Fake', (), {'resolve_result_paths': lambda self_: self_, + 'assert_arg_names': lambda _, __: None, + 'assert_result_paths': lambda _, __: None, + })() raise ValueError( f"`JaxprInputEffect` {eff} is invalid." f"\n Equation: {eqn}\n" "\n Jaxpr: " - f"{core.Jaxpr(constvars, invars, outvars, eqns, set())}") - invar = eqn.invars[eff.input_index] - if invar in mut_arrays: + f"{core.Jaxpr(constvars, invars, outvars, eqns, set(), dbg)}") # type: ignore + eqn_invar = eqn.invars[eff.input_index] + if type(eqn_invar) is core.Literal or eqn_invar in mut_arrays: continue - if (input_index := all_vars.get(invar, sentinel)) is sentinel: + if (input_index := all_vars.get(eqn_invar, sentinel)) is sentinel: + # TODO(mattjj): ask for forgiveness + dbg = type('Fake', (), {'resolve_result_paths': lambda self_: self_, + 'assert_arg_names': lambda _, __: None, + 'assert_result_paths': lambda _, __: None, + })() raise ValueError( f"`JaxprInputEffect` {eff} does not have " - f"corresponding input: {invar}." + f"corresponding jaxpr input: {eqn_invar=}." f"\n Equation: {eqn}\n" + f"\n Effects: {eqn.effects}\n" "\n Jaxpr: " - f"{core.Jaxpr(constvars, invars, outvars, eqns, set())}") + f"{core.Jaxpr(constvars, invars, outvars, eqns, set(), dbg)}") # type: ignore eff = eff.replace(input_index=input_index) jaxpr_effects.add(eff) return jaxpr_effects +class Constants(NamedTuple): + # A pair of a canonicalized constant and its original form. + # It is important that we keep the original value alive because we use id(c) + # as a key in various dictionaries. If the original value were deleted we + # may confuse constants if the same object ID is reused. + canonical: Any + original: Any + class JaxprStackFrame: gensym: Callable[[AbstractValue], Var] - tracer_to_var: dict[TracerId, Var] - constid_to_tracer: dict[ConstId, Tracer] - constvar_to_val: dict[Var, Any] - tracers: list[DynamicJaxprTracer] # hold onto strong refs for all tracers - eqns: list[JaxprEqn] + constid_to_tracer: WeakValueDictionary[ConstId, DynamicJaxprTracer] + constvar_to_val: dict[Var, Constants] + tracing_eqns: list[Union[ReferenceType[TracingEqn], Callable[[], TracingEqn]]] invars: list[Var] effects: core.Effects - attrs_tracked: list[tuple[Any, str]] - attrs_inits: list - attrs_vars: list[Var] debug_info: core.DebugInfo + is_high: bool + mutable_qdds: list[tuple[Var, core.MutableQuasiDynamicData]] + auto_dce: bool - def __init__(self, debug_info: core.DebugInfo): + def __init__(self, debug_info: core.DebugInfo, auto_dce: bool): self.gensym = core.gensym() - self.tracer_to_var = {} - self.constid_to_tracer = {} + self.constid_to_tracer = WeakValueDictionary() self.constvar_to_val = {} - self.tracers = [] # circ refs, frame->tracer->trace->main->frame, - self.eqns = [] # cleared when we pop frame from main + self.tracing_eqns = [] # cleared when we pop frame from main self.invars = [] self.effects = set() - self.attrs_tracked = [] - self.attrs_inits = [] - self.attrs_vars = [] self.debug_info = debug_info - - def add_eqn(self, eqn: core.JaxprEqn): - self.eqns.append(eqn) - - def to_jaxpr(self, trace: DynamicJaxprTrace, - out_tracers: Sequence[Tracer], - debug_info: core.DebugInfo, - ) -> tuple[Jaxpr, list[Any], list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]: - # It's not necessary, but we keep the tracer-to-var mapping injective: - assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values())) - invars = self.attrs_vars + self.invars - state_ans, end_trees = unzip2( - tree_flatten(t) for t in get_states(self.attrs_tracked)) - state_outvars = [self.tracer_to_var[id(trace.to_jaxpr_tracer(x))] - for xs in state_ans for x in xs] - explicit_outvars = [self.tracer_to_var[id(t)] for t in out_tracers] - outvars = state_outvars + explicit_outvars - constvars, constvals = unzip2(self.constvar_to_val.items()) - jaxpr_effects = make_jaxpr_effects(constvars, self.invars, explicit_outvars, self.eqns) - jaxpr = Jaxpr(constvars, invars, outvars, self.eqns, jaxpr_effects, - debug_info) - jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals) - jaxpr, constvals = _inline_literals(jaxpr, constvals) - init_trees = [tree_structure(init_val) for init_val in self.attrs_inits] - set_states(self.attrs_tracked, self.attrs_inits) - return jaxpr, list(constvals), zip(init_trees, end_trees, self.attrs_tracked) - - def to_jaxpr2(self, out_tracers: Sequence[core.Tracer], - debug_info: core.DebugInfo): - # It's not necessary, but we keep the tracer-to-var mapping injective: - assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values())) - constvars, constvals = unzip2(self.constvar_to_val.items()) - expl_outvars = [self.tracer_to_var[id(t)] for t in out_tracers] - jaxpr_effects = make_jaxpr_effects(constvars, self.invars, expl_outvars, - self.eqns) - jaxpr = Jaxpr(constvars, self.invars, expl_outvars, self.eqns, - jaxpr_effects, debug_info) - # We can't run check_jaxpr until after we normalize. - jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals) - jaxpr, constvals = _inline_literals(jaxpr, constvals) - jaxpr, out_type = _add_implicit_outputs(jaxpr) - config.enable_checks.value and core.check_jaxpr(jaxpr) - return jaxpr, out_type, constvals + self.is_high = False + self.mutable_qdds = [] + self.auto_dce = auto_dce + + def add_eqn(self, eqn: core.TracingEqn): + assert isinstance(eqn, TracingEqn) + r = (lambda: eqn) if (eqn.effects or not self.auto_dce) else ref(eqn) + self.tracing_eqns.append(r) + + def get_eqns(self): + eqns = [] + for tracing_eqn in self.tracing_eqns: + e = tracing_eqn() + if e is None: continue + eqns.append(JaxprEqn( + [t.val for t in e.in_tracers], + e.outvars, e.primitive, e.params, e.effects, e.source_info, e.ctx)) + return eqns + + def to_jaxpr( + self, trace: DynamicJaxprTrace, + out_tracers: Sequence[Tracer], + debug_info: core.DebugInfo, + source_info: SourceInfo, + ) -> tuple[Jaxpr, list[Any]]: + eqns = self.get_eqns() + outvars = [t.val for t in out_tracers] + constvars, constvals = unzip2(self.constvar_to_val.copy().items()) + constvals = [c.canonical for c in constvals] + constvars, constvals = _drop_unused_vars(constvars, constvals, eqns, outvars) + effs = make_jaxpr_effects(constvars, self.invars, outvars, eqns) + + # TODO(dougalm): handle qdd for consts + for v, qdd in self.mutable_qdds: + v.final_qdd = qdd.cur_val + + all_vars = it.chain(constvars, self.invars, outvars) + is_high = self.is_high or any(v.aval.is_high for v in all_vars) + + jaxpr = Jaxpr(constvars, self.invars, outvars, eqns, effs, debug_info, is_high) + return jaxpr, list(constvals) def newvar(self, aval): - if isinstance(aval, DShapedArray): - # this aval may have tracers in it, so we replace those with variables - new_shape = [self.tracer_to_var[id(d)] if isinstance(d, Tracer) else d - for d in aval.shape] - aval = aval.update(shape=tuple(new_shape)) - return self.gensym(aval) + if isinstance(aval, core.AvalQDD): + return self.gensym(aval.aval, initial_qdd=aval.qdd) + else: + return self.gensym(aval) def find_progenitors(self, tracer): - var = self.tracer_to_var.get(id(tracer)) - if not var: + eqns = self.get_eqns() + var = tracer.val + if not var or isinstance(var, Literal): return None, None active_vars = {var} - for eqn in self.eqns[::-1]: + for eqn in eqns[::-1]: produced = set(eqn.outvars) & active_vars if produced: active_vars.difference_update(produced) active_vars.update({v for v in eqn.invars if type(v) is Var}) invar_positions = [i for i, v in enumerate(self.invars) if v in active_vars] - constvars = active_vars & set(self.constvar_to_val) - const_eqns = [eqn for eqn in self.eqns - if {v for v in eqn.invars if type(v) is Var} & constvars] + constvars = active_vars & set(self.constvar_to_val.copy()) + const_eqns = [eqn for eqn in eqns if any( + v in constvars if type(v) is Var else type(v) is Literal + for v in eqn.invars)] return invar_positions, const_eqns -def _const_folding_and_forwarding( - jaxpr: Jaxpr, constvals: Sequence[Any]) -> tuple[Jaxpr, tuple[Any, ...]]: - consts: dict[Var, Any] = dict(zip(jaxpr.constvars, constvals)) - var_subs: dict[Var, Var] = {} # not Dict[Var, Atom] b/c literals not inlined - new_eqns = [] - def apply_var_sub(a: Atom) -> Atom: - return var_subs.get(a, a) if isinstance(a, Var) else a - for eqn in jaxpr.eqns: - # always apply invar substitutions - eqn = eqn.replace(invars=[apply_var_sub(v) for v in eqn.invars]) - # if any inputs are constants and we have a constant-folding rule, apply it - has_input_effect = any(isinstance(eff, effects.JaxprInputEffect) - for eff in eqn.effects) - if (eqn.primitive in const_fold_rules and - any(v in consts for v in eqn.invars if isinstance(v, Var)) and - not has_input_effect): - consts_in = [consts.get(v) if isinstance(v, Var) else None - for v in eqn.invars] - consts_out, new_eqn = const_fold_rules[eqn.primitive](consts_in, eqn) - assert (new_eqn is None) == all(c is not None for c in consts_out) - for v, c in zip(eqn.outvars, consts_out): - if c is not None: consts[v] = c - if new_eqn is None: continue - else: eqn = new_eqn - # if the application trivially maps some inputs to outputs, simplify - if eqn.primitive in forwarding_rules and not has_input_effect: - fwd_vars, new_eqn = forwarding_rules[eqn.primitive](eqn) - for v_orig, v_new in zip(eqn.outvars, fwd_vars): - if v_new is not None: var_subs[v_orig] = v_new - if new_eqn is None: continue - else: eqn = new_eqn - new_eqns.append(eqn) - new_constvars, new_constvals = unzip2(consts.items()) - new_outvars = [apply_var_sub(v) for v in jaxpr.outvars] - jaxpr_effects = make_jaxpr_effects(new_constvars, jaxpr.invars, new_outvars, - new_eqns) - new_jaxpr = Jaxpr(new_constvars, jaxpr.invars, new_outvars, new_eqns, - jaxpr_effects, jaxpr.debug_info) - return new_jaxpr, new_constvals ConstFoldRule = Callable[ - [list[Union[Any, None]], JaxprEqn], + [list[Union[Any, None]], Any, list[AbstractValue]], tuple[list[Union[Any, None]], Union[JaxprEqn, None]], ] const_fold_rules: dict[Primitive, ConstFoldRule] = {} ForwardingRule = Callable[ [JaxprEqn], - tuple[list[Union[Var, None]], Union[JaxprEqn, None]] + tuple[list[Union[int, None]], Union[JaxprEqn, None]] ] forwarding_rules: dict[Primitive, ForwardingRule] = {} -def _inline_literals( - jaxpr: Jaxpr, constvals: Sequence[Any] -) -> tuple[Jaxpr, list[Any]]: - # This function also prunes unused constants and inserts `dropvar` symbols. - input_effects = {eff for eff in jaxpr.effects - if isinstance(eff, effects.JaxprInputEffect)} - # Don't inline any literal with an input effect - has_input_effect = [any(eff.input_index == i for eff in input_effects) - for i in range(len(constvals))] - lits = {v: Literal(c, v.aval) for v, c, e in zip(jaxpr.constvars, constvals, - has_input_effect) - if type(c) in core.literalable_types and not np.shape(c) and not e} - def lit(a: Atom) -> Literal | None: - return (a if isinstance(a, Literal) else lits.get(a) if isinstance(a, Var) - else None) - newname: Callable[[AbstractValue], Var] = core.gensym() - newvars: dict[Var, Var] = {} - newvar = lambda aval: newname(_substitute_vars_in_type(lits, newvars, aval)) - var = lambda v: newvars.get(v) or newvars.setdefault(v, newvar(v.aval)) - lit_or_var = ( - lambda a: a if isinstance(a, Literal) else (lit(a) or var(a)) - ) - dropvar = lambda aval: DropVar(_substitute_vars_in_type(lits, newvars, aval)) - - def vars_in_shape(aval: AbstractValue) -> Sequence[Var]: - if isinstance(aval, DShapedArray): - return [d for d in aval.shape if isinstance(d, Var)] - return [] - - used = {v for eqn in jaxpr.eqns for atom in eqn.invars - for v in it.chain([atom], vars_in_shape(atom.aval)) - if isinstance(atom, Var)} - used |= {v for outvar in jaxpr.outvars - for v in it.chain([outvar], vars_in_shape(outvar.aval))} - new_constvars = [var(v) for v in jaxpr.constvars if v in used and not lit(v)] - new_constvals = [c for v, c in zip(jaxpr.constvars, constvals) - if v in used and not lit(v)] - new_invars = [var(v) for v in jaxpr.invars] - new_eqns = [] - for eqn in jaxpr.eqns: - invars = [lit_or_var(x) for x in eqn.invars] - outvars = [var(v) if v in used else dropvar(v.aval) for v in eqn.outvars] - new_eqns.append(eqn.replace(invars=invars, outvars=outvars)) - new_outvars = [lit_or_var(v) for v in jaxpr.outvars] - jaxpr_effects = make_jaxpr_effects(new_constvars, new_invars, new_outvars, - new_eqns) - new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns, - jaxpr_effects, jaxpr.debug_info) - return new_jaxpr, new_constvals +def _drop_unused_vars(constvars, constvals, eqns, outvars + ) -> tuple[list[Var], list[Any]]: + # modifies eqns in-place! + def vars(atom: Atom) -> list[Var]: + if isinstance(atom, Literal): + return [] + aval = atom.aval + return [atom] + used: set[Var] = {v for atom in outvars for v in vars(atom)} + for eqn in eqns[::-1]: + eqn.outvars = [v if v in used else DropVar(v.aval) for v in eqn.outvars] + used.update(v for atom in eqn.invars for v in vars(atom)) + constvars, constvals = unzip2( + (v, val) for v, val in zip(constvars, constvals) if v in used) + return constvars, constvals + + +@multi_weakref_lru_cache +def _cached_abstract_eval(primitive: core.Primitive, *aval_qdds, **params): + return primitive.abstract_eval(*aval_qdds, **params) + + +def _verify_params_are_hashable( + primitive: core.Primitive, params: dict[str, Any]) -> None: + for k, v in params.items(): + try: + hash(v) + except TypeError as e: + raise TypeError( + "As of JAX v0.7, parameters to jaxpr equations must have __hash__ and " + f"__eq__ methods. In a call to primitive {primitive}, the value of " + f"parameter {k} was not hashable: {v}") from e + +# We use TracingEqn instead JaxprEqn during tracing to allow automatic +# on-the-fly DCE based on Python refcounting. DynamicJaxprTracers point to +# TracingEqns which point to DynamicJaxprTracers and unreachable constants can +# be freed. + +@dataclass +class TracingEqn: + in_tracers: list[DynamicJaxprTracer] + outvars: list[Var] + primitive: Primitive + params: dict[str, Any] + effects: core.Effects + source_info: source_info_util.SourceInfo + ctx: JaxprEqnContext + def __init__(self, in_tracers, outvars, primitive, params, effects, source_info, ctx): + self.in_tracers = in_tracers + self.outvars = outvars + self.primitive = primitive + self.params = params + self.effects = effects + self.source_info = source_info + self.ctx = ctx + + # Allow TracingEqn to duck-type JaxpeEqn because some of the forwarding + # rules need to work with both. TODO(dougalm): remove this once we fix + # forwarding. + @property + def invars(self): + return self.in_tracers class DynamicJaxprTrace(core.Trace): - __slots__ = ("frame", "tag") + __slots__ = ("frame", "tag", "parent_trace") - def __init__(self, debug_info: core.DebugInfo): + def __init__(self, debug_info: core.DebugInfo, parent_trace=None, lower=False, + auto_dce=False): super().__init__() - self.frame = JaxprStackFrame(debug_info) + self.requires_low = lower + self.frame = JaxprStackFrame(debug_info, auto_dce) + self.parent_trace = parent_trace def invalidate(self): + # TODO(mattjj): exposed existing tracer leaks; fix them and re-enable! + # super().invalidate() + # avoid cyclic refs - self.frame.tracers = [] + self.frame.tracing_eqns = [] # thunk -> eqn -> in_tracers -> trace -> + # -> frame -> tracing_eqns -> thunk + + # TODO(dougalm): we might be able to remove these given refcounting dce self.frame.constid_to_tracer = {} + self.frame.constvar_to_val = {} - def to_jaxpr_tracer(self, x): - as_local_var = self.frame.tracer_to_var.get(id(x)) - if as_local_var is None: + def to_jaxpr_tracer(self, x, source_info: SourceInfo): + if isinstance(x, DynamicJaxprTracer) and x._trace is self: + return x + else: if hasattr(x, "dimension_as_value"): # Used for shape_poly._DimExpr with core.set_current_trace(self): x = x.dimension_as_value() - return self.to_jaxpr_tracer(x) + return self.to_jaxpr_tracer(x, source_info) else: - return self.new_const(x) - else: - return x + return self.new_const(x, source_info) - def new_arg(self, aval): - tracer = DynamicJaxprTracer(self, aval, source_info_util.current()) - self.frame.tracers.append(tracer) - self.frame.tracer_to_var[id(tracer)] = var = self.frame.newvar(aval) + def var_to_tracer(self, var, source_info, parent=None): + aval = var.aval + if aval.has_qdd: + aval = core.AvalQDD(aval, var.initial_qdd) + return DynamicJaxprTracer(self, aval, var, source_info, parent) + + def new_arg(self, aval, source_info: SourceInfo): + var = self.frame.newvar(aval) + tracer = DynamicJaxprTracer(self, aval, var, source_info) self.frame.invars.append(var) + self.frame.mutable_qdds.append((var, tracer.mutable_qdd)) return tracer - def new_const(self, c): + def make_eqn(self, in_tracers, out_avals, primitive, params, + effects, source_info=None, ctx = None): + source_info = source_info or source_info_util.new_source_info() + ctx = ctx or JaxprEqnContext( + config.compute_on_context_manager.value, + config.threefry_partitionable.value, + xla_metadata_lib.current_xla_metadata()) + outvars = map(self.frame.newvar, out_avals) + if config.enable_checks.value: + assert all(isinstance(x, DynamicJaxprTracer) for x in in_tracers) + assert all(isinstance(v, Var) for v in outvars) + eqn = TracingEqn(in_tracers, outvars, primitive, params, effects, source_info, ctx) + out_tracers = [self.var_to_tracer(v, source_info, eqn) for v in outvars] + return eqn, out_tracers + + def emit_eqn(self, in_tracers, out_avals, primitive, params, effects, source_info=None, ctx=None): + eqn, out_tracers = self.make_eqn(in_tracers, out_avals, primitive, params, effects, source_info, ctx) + self.frame.add_eqn(eqn) + return out_tracers + + def new_const(self, c, source_info: SourceInfo, + aval: AbstractValue | None = None): # TODO(mattjj): for ints, or hashable consts, don't rely on id tracer = self.frame.constid_to_tracer.get(id(c)) if tracer is None: - aval = get_aval(c) - if hasattr(aval, "weak_type"): - aval = aval.update_weak_type(dtypes.is_weakly_typed(c)) - aval = self._lift_tracers_in_aval(aval) - tracer = self._new_const(aval, c) + if aval is None: + aval = get_aval(c) + if aval.has_qdd: + with core.set_current_trace(self.parent_trace or core.eval_trace): + aval = core.AvalQDD(aval, core.cur_qdd(c)) # type: ignore + tracer = self._new_const(aval, c, source_info) return tracer pure = lift = new_const - def _new_const(self, aval, c) -> DynamicJaxprTracer: - tracer = DynamicJaxprTracer(self, aval, source_info_util.current()) - self.frame.tracers.append(tracer) - self.frame.tracer_to_var[id(tracer)] = var = self.frame.newvar(aval) - self.frame.constid_to_tracer[id(c)] = tracer - self.frame.constvar_to_val[var] = c - return tracer + def _new_const(self, aval, c, source_info: SourceInfo) -> DynamicJaxprTracer: + orig_c = c + id_c = id(c) + if isinstance(c, (int, float, bool, complex, np.generic, np.ndarray)): + c = dtypes.canonicalize_value(c) + if core.is_literalable(c): + val = Literal(c, aval) + return DynamicJaxprTracer(self, aval, val, source_info) + else: + var = self.frame.newvar(aval) + tracer = DynamicJaxprTracer(self, aval, var, source_info) + self.frame.constid_to_tracer[id_c] = tracer + if isinstance(aval, core.AvalQDD): + self.frame.mutable_qdds.append((var, tracer.mutable_qdd)) + self.frame.constvar_to_val[var] = Constants(canonical=c, original=orig_c) + finalize(tracer, self.finalize_const, var, id_c) + return tracer - def _lift_tracers_in_aval(self, aval): - if (not isinstance(aval, DShapedArray) or - not any(isinstance(d, Tracer) for d in aval.shape)): - return aval - shape = [self.to_jaxpr_tracer(d) if isinstance(d, Tracer) else d - for d in aval.shape] - return aval.update(shape=tuple(shape)) - - def getvar(self, tracer): - var = self.frame.tracer_to_var.get(id(tracer)) - if var is None: - raise core.escaped_tracer_error(tracer) - return var + def finalize_const(self, var, constid): + self.frame.constvar_to_val.pop(var, None) - def makevar(self, tracer): - var = self.frame.tracer_to_var.get(id(tracer)) - assert var is None, "a jaxpr variable must be created only once per tracer" - self.frame.tracers.append(tracer) - var = self.frame.tracer_to_var[id(tracer)] = self.frame.newvar(tracer.aval) - return var + def get_const(self, tracer) -> Any: + atom = tracer.val + if isinstance(atom, Literal): + return atom.val + else: + const = self.frame.constvar_to_val.get(atom) + if const is not None: + const = const.canonical + return const - def is_const(self, tracer): - return self.frame.tracer_to_var.get(id(tracer)) is None + def cur_qdd(self, x): + source_info = source_info_util.current() + return self.to_jaxpr_tracer(x, source_info=source_info).mutable_qdd.cur_val def process_primitive(self, primitive, tracers, params): - if (config.eager_constant_folding.value and all(map(self.is_const, tracers))): + self.frame.is_high |= primitive.is_high(*map(typeof, tracers), **params) + if config.eager_constant_folding.value and not any(isinstance(x, Tracer) for x in tracers): return primitive.bind_with_trace(core.eval_trace, tracers, params) - jaxpr_tracers = map(self.to_jaxpr_tracer, tracers) + source_info = source_info_util.current() + to_jaxpr_tracer = partial(self.to_jaxpr_tracer, source_info=source_info) + jaxpr_tracers = map(to_jaxpr_tracer, tracers) if primitive in custom_staging_rules: - return custom_staging_rules[primitive](self, *jaxpr_tracers, **params) - return self.default_process_primitive(primitive, jaxpr_tracers, params) + return custom_staging_rules[primitive](self, source_info, *jaxpr_tracers, + **params) + return self.default_process_primitive( + primitive, jaxpr_tracers, params, source_info) + + def default_process_primitive(self, primitive, tracers, params, + source_info=None): + from jax._src.hijax import call_hi_primitive_p + aval_qdds = [t.aval_mutable_qdd for t in tracers] + # TODO(mattjj): make custom_lin have hashable params. + # TODO(dougalm): add an attribute to primitives to mark primitives with + # effectful abstract_eval rules. + # TODO(mattjj,dougalm): clean up how we check for new-style hi primitives + if primitive is call_hi_primitive_p: + out_avals, effs = params['prim'].out_avals_flat, set() # TODO effs + elif (primitive.name in ("custom_lin", "call_hi_primitive_linearized") or + primitive.is_effectful and primitive.is_effectful(params)): + out_avals, effs = primitive.abstract_eval(*aval_qdds, **params) + else: + try: + out_avals, effs = _cached_abstract_eval(primitive, *aval_qdds, **params) + except Exception as e: + # TODO(phawkins): remove this 3 months after the release of JAX v0.7. + _verify_params_are_hashable(primitive, params) + raise - def default_process_primitive(self, primitive, tracers, params): - avals = [t.aval for t in tracers] - out_avals, effects = primitive.abstract_eval(*avals, **params) if isinstance(out_avals, (tuple, list)) != primitive.multiple_results: raise ValueError(f"{primitive}.abstract_eval() method should return " f"a tuple or a list iff {primitive}.multiple_results.") out_avals = [out_avals] if not primitive.multiple_results else out_avals - source_info = source_info_util.current() - out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals] - invars = map(self.getvar, tracers) - outvars = map(self.makevar, out_tracers) - eqn = new_jaxpr_eqn(invars, outvars, primitive, params, effects, - source_info) - self.frame.add_eqn(eqn) + source_info = source_info or source_info_util.current() + + maybe_consts_out = try_constant_folding(primitive, tracers, params, out_avals) + if maybe_consts_out is not None: + eqn = None + out_tracers = [self.new_const(c, source_info=source_info, aval=aval) + for c, aval in zip(maybe_consts_out, out_avals)] + else: + eqn, out_tracers = self.make_eqn(tracers, out_avals, primitive, params, + effs, source_info=source_info) + # Input-to-output tracer forwarding + no_input_effects = not any(isinstance(e, effects.JaxprInputEffect) for e in effs) + if eqn is not None and no_input_effects and primitive in forwarding_rules: + in_fwd, eqn = forwarding_rules[primitive](eqn) + for out_idx, in_idx in enumerate(in_fwd): + if in_idx is not None: + out_tracers[out_idx] = tracers[in_idx] + + if eqn is not None: + self.frame.add_eqn(eqn) return out_tracers if primitive.multiple_results else out_tracers.pop() - def process_call(self, call_primitive, f: lu.WrappedFun, - explicit_tracers, params): - if f.in_type is None: - f = lu.annotate(f, tuple((get_aval(t), True) for t in explicit_tracers)) - assert f.in_type is not None - implicit_tracers = _extract_implicit_args(self, f.in_type, explicit_tracers) - in_tracers = map(self.to_jaxpr_tracer, [*implicit_tracers, *explicit_tracers]) + def process_call(self, call_primitive, f: lu.WrappedFun, in_tracers, + params): + source_info = source_info_util.current() + to_jaxpr_tracer = partial(self.to_jaxpr_tracer, source_info=source_info) + in_type = (tuple(get_aval(t) for t in in_tracers) if f.in_type is None + else f.in_type) + f.in_type = None + assert in_type is not None + in_tracers = map(to_jaxpr_tracer, in_tracers) # TODO(mattjj): check in_tracers are consistent with f.in_type annotation - jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f) + jaxpr, out_avals, consts = _cached_trace_to_jaxpr(f, in_type) if params.get('inline', False): return core.eval_jaxpr(jaxpr, consts, *in_tracers, propagate_source_info=False) - source_info = source_info_util.current() - out_tracers: list[Tracer] = [] - for aval, _ in out_type: - if type(aval) is DShapedArray: - shape = [[*consts, *in_tracers][d.val] if type(d) is InDBIdx else - out_tracers[d.val] if type(d) is OutDBIdx else - d for d in aval.shape] - aval = aval.update(shape=tuple(get_referent(d) for d in shape)) - out_tracers.append(DynamicJaxprTracer(self, aval, source_info)) - invars = map(self.getvar, in_tracers) - constvars = map(self.getvar, map(self.to_jaxpr_tracer, consts)) - outvars = map(self.makevar, out_tracers) - new_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr)) + + new_jaxpr = convert_constvars_jaxpr(jaxpr) + if isinstance(call_primitive, core.ClosedCallPrimitive): + new_jaxpr = close_jaxpr(new_jaxpr) # type: ignore + new_params = dict(params, call_jaxpr=new_jaxpr) update_params = call_param_updaters.get(call_primitive) if update_params: - new_params = update_params(new_params, [True] * len(explicit_tracers), - len(consts) + len(implicit_tracers)) - eqn = new_jaxpr_eqn([*constvars, *invars], outvars, call_primitive, - new_params, new_params['call_jaxpr'].effects, source_info) - self.frame.add_eqn(eqn) - return [t for t, (_, keep) in zip(out_tracers, out_type) if keep] + new_params = update_params(new_params, [True] * len(in_tracers), + len(consts)) + const_tracers = map(to_jaxpr_tracer, consts) + return self.emit_eqn( + [*const_tracers, *in_tracers], out_avals, call_primitive, + new_params, new_params['call_jaxpr'].effects, source_info=source_info) def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params): - tracers = map(self.to_jaxpr_tracer, tracers) + source_info = source_info_util.current() + to_jaxpr_tracer = partial(self.to_jaxpr_tracer, source_info=source_info) + tracers = map(to_jaxpr_tracer, tracers) in_avals = [t.aval for t in tracers] axis_name, axis_size = params['axis_name'], params['axis_size'] reduced_in_avals = [core.mapped_aval(axis_size, in_axis, a) if in_axis is not None else a for a, in_axis in zip(in_avals, params['in_axes'])] - with core.extend_axis_env_nd([(axis_name, params["global_axis_size"])]): - jaxpr, reduced_out_avals, consts, () = trace_to_jaxpr_dynamic( - f, reduced_in_avals) + jaxpr, reduced_out_avals, consts = trace_to_jaxpr_dynamic( + f.with_unknown_names(), reduced_in_avals) jaxpr, consts = _linearize_of_pmap_hack(f, jaxpr, consts) ordered_effects = effects.ordered_effects.filter_in(jaxpr.effects) if ordered_effects: @@ -2001,11 +2127,7 @@ def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params): out_avals = [core.unmapped_aval(axis_size, out_axis, a) if out_axis is not None else a for a, out_axis in zip(reduced_out_avals, out_axes)] - source_info = source_info_util.current() - out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals] - invars = map(self.getvar, tracers) - constvars = map(self.getvar, map(self.to_jaxpr_tracer, consts)) - outvars = map(self.makevar, out_tracers) + const_tracers = map(to_jaxpr_tracer, consts) new_in_axes = (None,) * len(consts) + params['in_axes'] new_params = dict(params, in_axes=new_in_axes, out_axes=out_axes, call_jaxpr=convert_constvars_jaxpr(jaxpr)) @@ -2014,77 +2136,84 @@ def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params): if update_params: new_params = update_params(new_params, [True] * len(tracers), len(consts)) effs = core.filter_named_axis_effects(jaxpr.effects, {axis_name}) - eqn = new_jaxpr_eqn([*constvars, *invars], outvars, map_primitive, - new_params, effs, source_info) - self.frame.add_eqn(eqn) + out_tracers = self.emit_eqn( + [*const_tracers, *tracers], out_avals, map_primitive, new_params, effs, source_info=source_info) return out_tracers def process_custom_jvp_call(self, prim, fun: lu.WrappedFun, jvp: lu.WrappedFun, tracers, symbolic_zeros: bool): - tracers = map(self.to_jaxpr_tracer, tracers) + if config.eager_constant_folding.value and not any(isinstance(x, Tracer) for x in tracers): + return prim.bind_with_trace(core.eval_trace, (fun, jvp, *tracers), + dict(symbolic_zeros=symbolic_zeros)) + source_info = source_info_util.current() + to_jaxpr_tracer = partial(self.to_jaxpr_tracer, source_info=source_info) + tracers = map(to_jaxpr_tracer, tracers) in_avals = [t.aval for t in tracers] in_tangent_avals = [t.to_tangent_aval() for t in in_avals] - fun_jaxpr, out_avals, consts, () = trace_to_jaxpr_dynamic(fun, in_avals) + fun_jaxpr, out_avals, consts = trace_to_jaxpr_dynamic(fun, in_avals) closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ()) + @partial(lu.wrap_init, debug_info=jvp.debug_info) @_memoize def jvp_jaxpr_thunk(*in_zeros): for store in jvp.stores: store and store.reset() nz_tangent_avals, zero_avals = partition_list(in_zeros, in_tangent_avals) jvp_, out_zeros = _jvp_jaxpr_zeros(jvp, in_zeros, tuple(zero_avals)) in_avals_ = (*in_avals, *nz_tangent_avals) - jaxpr, _, out_consts, () = trace_to_jaxpr_dynamic(jvp_, in_avals_) + jaxpr, _, out_consts = trace_to_jaxpr_dynamic(jvp_.with_unknown_names(), + in_avals_) return jaxpr, out_consts, out_zeros() - out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals] - invars = map(self.getvar, tracers) - constvars = map(self.getvar, map(self.to_jaxpr_tracer, consts)) - outvars = map(self.makevar, out_tracers) - eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim, - dict(call_jaxpr=closed_fun_jaxpr, - jvp_jaxpr_fun=lu.wrap_init(jvp_jaxpr_thunk, - debug_info=jvp.debug_info), - num_consts=len(consts), - symbolic_zeros=symbolic_zeros), - fun_jaxpr.effects, - source_info_util.current()) - self.frame.add_eqn(eqn) - return out_tracers + const_tracers = map(to_jaxpr_tracer, consts) + return self.emit_eqn( + [*const_tracers, *tracers], out_avals, prim, + dict(call_jaxpr=closed_fun_jaxpr, + jvp_jaxpr_fun=jvp_jaxpr_thunk, + num_consts=len(consts), + symbolic_zeros=symbolic_zeros), + fun_jaxpr.effects, + source_info=source_info) def process_custom_vjp_call(self, prim: core.Primitive, fun: lu.WrappedFun, fwd: lu.WrappedFun, bwd: lu.WrappedFun, tracers, - out_trees: Callable[[], Sequence[PyTreeDef]], + out_trees: Callable[[], tuple[PyTreeDef, PyTreeDef, list[int | None]]], symbolic_zeros: bool): - tracers = map(self.to_jaxpr_tracer, tracers) - in_avals = [t.aval for t in tracers] - fun_jaxpr, out_avals, consts, _ = trace_to_jaxpr_dynamic(fun, in_avals) + if config.eager_constant_folding.value and not any(isinstance(x, Tracer) for x in tracers): + return prim.bind_with_trace(core.eval_trace, (fun, fwd, bwd, *tracers), + dict(out_trees=out_trees, symbolic_zeros=symbolic_zeros)) + source_info = source_info_util.current() + to_jaxpr_tracer = partial(self.to_jaxpr_tracer, source_info=source_info) + tracers = map(to_jaxpr_tracer, tracers) + in_avals = [core.AvalQDD(t.aval, core.cur_qdd(t)) if t.aval.has_qdd else t.aval for t in tracers] + fun_jaxpr, out_avals, consts = trace_to_jaxpr_dynamic(fun.with_unknown_names(), in_avals) + num_consts = len(consts) closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ()) + @partial(lu.wrap_init, debug_info=fwd.debug_info) @_memoize def fwd_jaxpr_from_zeros(*zeros): for store in fwd.stores: store and store.reset() - fwd_ = _interleave_fun(fwd, zeros) - jaxpr, _, consts, attrs = trace_to_jaxpr_dynamic(fwd_, in_avals) - if attrs: raise NotImplementedError + fwd_ = _interleave_fun(fwd.with_unknown_names(), zeros) + jaxpr, _, consts = trace_to_jaxpr_dynamic(fwd_, in_avals) return jaxpr, consts - out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals] - invars = map(self.getvar, tracers) - constvars = map(self.getvar, map(self.to_jaxpr_tracer, consts)) - outvars = map(self.makevar, out_tracers) - eqn = new_jaxpr_eqn([*constvars, *invars], outvars, - prim.initial_style, # pytype: disable=attribute-error - dict(fun_jaxpr=closed_fun_jaxpr, - fwd_jaxpr_thunk=fwd_jaxpr_from_zeros, - num_consts=len(consts), - bwd=bwd, out_trees=out_trees, - symbolic_zeros=symbolic_zeros), - fun_jaxpr.effects, - source_info_util.current()) - self.frame.add_eqn(eqn) - return out_tracers + def out_trees_(): + out_tree, res_tree, input_fwds = out_trees() + input_fwds = [f if f is None else f + num_consts for f in input_fwds] + return out_tree, res_tree, input_fwds + + const_tracers = map(to_jaxpr_tracer, consts) + return self.emit_eqn( + [*const_tracers, *tracers], out_avals, prim, + dict(call_jaxpr=closed_fun_jaxpr, + fwd_jaxpr_thunk=fwd_jaxpr_from_zeros, + num_consts=num_consts, + bwd=bwd, out_trees=out_trees_, + symbolic_zeros=symbolic_zeros), + fun_jaxpr.effects, + source_info=source_info) def process_custom_transpose(self, prim: core.Primitive, # type: ignore[override] call: lu.WrappedFun, tracers, *, @@ -2092,13 +2221,15 @@ def process_custom_transpose(self, prim: core.Primitive, # type: ignore[overrid out_types, lin_tree: PyTreeDef, res_tree: PyTreeDef, out_tree: PyTreeDef): - tracers = map(self.to_jaxpr_tracer, tracers) + source_info = source_info_util.current() + to_jaxpr_tracer = partial(self.to_jaxpr_tracer, source_info=source_info) + tracers = map(to_jaxpr_tracer, tracers) tracers_res, tracers_lin = split_list(tracers, [res_tree.num_leaves]) in_avals_p = [t.aval for t in tracers] in_avals_t = [*[t.aval for t in tracers_res], *out_types] - call_jaxpr, out_avals, call_consts, _ = trace_to_jaxpr_dynamic(call, in_avals_p) + call_jaxpr, out_avals, call_consts = trace_to_jaxpr_dynamic(call, in_avals_p) closed_call_jaxpr = core.ClosedJaxpr( convert_constvars_jaxpr(call_jaxpr), ()) @@ -2109,26 +2240,28 @@ def process_custom_transpose(self, prim: core.Primitive, # type: ignore[overrid @_memoize def transpose_jaxpr_thunk(): for store in transpose_flat.stores: store.reset() - jaxpr, _, consts, () = trace_to_jaxpr_dynamic(transpose_flat, in_avals_t) + jaxpr, _, consts = trace_to_jaxpr_dynamic(transpose_flat, in_avals_t) return jaxpr, consts - out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals] - invars = map(self.getvar, tracers) - constvars = map(self.getvar, map(self.to_jaxpr_tracer, call_consts)) - outvars = map(self.makevar, out_tracers) - eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim, - dict(call_jaxpr=closed_call_jaxpr, - transpose_jaxpr_thunk=transpose_jaxpr_thunk, - out_types=out_types, res_tree=res_tree, - lin_tree=lin_tree, out_tree=out_tree), - closed_call_jaxpr.effects, - source_info_util.current()) - self.frame.add_eqn(eqn) - return out_tracers + const_tracers = map(to_jaxpr_tracer, call_consts) + return self.emit_eqn( + [*const_tracers, *tracers], out_avals, prim, + dict(call_jaxpr=closed_call_jaxpr, + transpose_jaxpr_thunk=transpose_jaxpr_thunk, + out_types=out_types, res_tree=res_tree, + lin_tree=lin_tree, out_tree=out_tree), + closed_call_jaxpr.effects, + source_info=source_info) def to_jaxpr(self, out_tracers: Sequence[Tracer], - debug_info: core.DebugInfo): - return self.frame.to_jaxpr(self, out_tracers, debug_info) + debug_info: core.DebugInfo, source_info: SourceInfo): + return self.frame.to_jaxpr(self, out_tracers, debug_info, source_info) + + +@lu.cache +def _cached_trace_to_jaxpr(f, in_type): + jaxpr, out_type, consts = trace_to_jaxpr_dynamic(lu.annotate(f, in_type), in_type) + return jaxpr, out_type, consts custom_staging_rules: dict[Primitive, Callable] = {} @@ -2164,30 +2297,180 @@ def _jvp_jaxpr_zeros(f, store, in_zeros, zero_avals, *primal_tangent_avals): store.store(out_zeros) return [*out_primals, *out_nz_tangents] +callsites_with_tracing_cache_miss: set[str] = set() + +def explain(keys, fun, in_avals, debug_info, *context): + func_filename = debug_info.func_filename + if func_filename and not source_info_util.is_user_filename(func_filename): + return + + msg: list[str] = [] + p = msg.append + + callsite = source_info_util.summarize(source_info_util.current()) + p(f"TRACING CACHE MISS at {callsite}:") + + src_info = "" + if func_filename: + src_info += f" defined at {func_filename}" + if func_lineno := debug_info.func_lineno: + src_info += f":{func_lineno}" + func_name = debug_info.func_name + + # have we seen this function before at all? + keys = [key for fun_ref, *key in keys if fun_ref() is fun] + if not keys: + p(f" never seen function:\n {func_name} id={id(fun)}{src_info}") + if callsite in callsites_with_tracing_cache_miss: + p(" but seen another function defined on the same line; maybe the function is\n" + " being re-defined repeatedly, preventing caching?") + else: + callsites_with_tracing_cache_miss.add(callsite) + return logger.log(logging.WARNING, "\n".join(msg)) + + p(f" for {func_name}{src_info}") + + key = (config.trace_context(), (in_avals, debug_info, *context), {}) + min_diff = min(diff_tracing_cache_keys(key, k) for k in keys)[-1] + p(' all previously seen cache keys differ. For the closest previous key:') + p(' ' + min_diff) + return logger.log(logging.WARNING, "\n".join(msg)) + +def diff_tracing_cache_keys(new_key, old_key) -> tuple[int, int, str] | None: + new_ctx, (new_tree, new_dbg, new_qdd, *_), () = new_key + old_ctx, (old_tree, old_dbg, old_qdd, *_), () = old_key + return (diff_ctx(new_ctx, old_ctx) or + diff_trees(new_tree.tree, old_tree.tree) or + diff_debug(new_dbg, old_dbg) or + diff_types(new_dbg, new_tree.vals, old_tree.vals) or + (4, 0, 'cache miss explanation unavailable')) + +def diff_ctx(new_ctx, old_ctx): + msg = "Tracing context doesn't match, e.g. due to config or context manager." + num_diff = sum(map(op.ne, new_ctx, old_ctx)) + if num_diff: return 0, num_diff, msg + +def diff_trees(new_tree, old_tree): + errs = tree_util.equality_errors_pytreedef(new_tree, old_tree) + tree_diffs = [] + for path, thing1, thing2, explanation in errs: + tree_diffs.append( + f" * at input path {tree_util.keystr(tuple(path))}, now {thing1} and " + f"before {thing2}, so {explanation}") + msg = 'different input pytree:\n' + '\n'.join(tree_diffs) + if tree_diffs: return 1, len(tree_diffs), msg + +def diff_debug(new_dbg, old_dbg): + msg = "Debug info doesn't match." + num_diff = sum(map(op.ne, new_dbg, old_dbg)) + if num_diff: return 2, num_diff, msg + +def diff_types(dbg, new_leaves, old_leaves): + if new_leaves == old_leaves: return + diffs = [] + add_weak_type_hint = False + for name, new_ty, old_ty in zip(dbg.arg_names, new_leaves, old_leaves): + if new_ty != old_ty: + new_str, old_str = new_ty.str_short(True), old_ty.str_short(True) + if type(new_ty) is type(old_ty) is core.ShapedArray: + if new_ty.sharding != old_ty.sharding: + new_str, old_str = new_ty.str_short(True, True), old_ty.str_short(True, True) + if new_ty.weak_type != old_ty.weak_type: + add_weak_type_hint = True + new_str += f'{{weak_type={new_ty.weak_type}}}' + old_str += f'{{weak_type={old_ty.weak_type}}}' + diffs.append(f" * at {name}, now {new_str} and before {old_str}") + msg = 'different input types:\n' + '\n'.join(diffs) + if add_weak_type_hint: + msg += 'https://docs.jax.dev/en/latest/type_promotion.html#weak-types' + if diffs: return 3, len(diffs), msg + + +@partial(weakref_lru_cache, explain=explain, + maxsize=None if jaxlib_extension_version >= 396 else 8192) +def trace_to_jaxpr( + fun: Callable, + in_avals: FlatTree, # (args, kwargs) pair + debug_info: core.DebugInfo, + *context_for_cache_key, +) -> tuple[ClosedJaxpr, FlatTree]: + if config.no_tracing.value: + raise RuntimeError(f"re-tracing function {fun} for " + "`jit`, but 'no_tracing' is set") + del context_for_cache_key # read implicitly, e.g. qdd state + test_event("trace_to_jaxpr") + config.enable_checks.value and debug_info.assert_arg_names(len(in_avals)) + parent_trace = core.trace_ctx.trace + trace = DynamicJaxprTrace(debug_info, parent_trace=parent_trace) + # Name stacks are reset because the name stacks on jaxpr equations should be + # rooted at the enclosing jaxpr. + with core.ensure_no_leaks(trace), source_info_util.reset_name_stack(): + source_info = source_info_util.current() + in_tracers = in_avals.map(partial(trace.new_arg, source_info=source_info)) + with core.set_current_trace(trace): + args, kwargs = in_tracers.unflatten() + ans_pytree = fun(*args, **kwargs) + debug_info = debug_info.set_result_paths(ans_pytree) + ans = FlatTree.flatten(ans_pytree) + del ans_pytree, args, kwargs + + _check_returned_jaxtypes(debug_info, list(ans)) + out_tracers = ans.map(partial(trace.to_jaxpr_tracer, source_info=source_info)) + out_avals = out_tracers.map(lambda t: t.aval) + _check_no_returned_refs(debug_info, list(out_tracers)) + jaxpr, consts = trace.frame.to_jaxpr(trace, list(out_tracers), debug_info, + source_info) + del trace, fun, in_tracers, out_tracers, ans + + config.enable_checks.value and core.check_jaxpr(jaxpr) + return ClosedJaxpr(jaxpr, consts), out_avals +# TODO(dougalm): remove in favor of `trace_to_jaxpr` @profiler.annotate_function def trace_to_jaxpr_dynamic( fun: lu.WrappedFun, - in_avals: Sequence[AbstractValue], + in_avals: Sequence[AbstractValue | core.AvalQDD], *, keep_inputs: list[bool] | None = None, -) -> tuple[Jaxpr, list[AbstractValue], list[Any], - list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]: + lower: bool = False, + auto_dce: bool = False, +) -> tuple[Jaxpr, list[AbstractValue], list[Any]]: + config.enable_checks.value and fun.debug_info.assert_arg_names(len(in_avals)) keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs - trace = DynamicJaxprTrace(fun.debug_info) + parent_trace = core.trace_ctx.trace + trace = DynamicJaxprTrace(fun.debug_info, parent_trace=parent_trace, + lower=lower, auto_dce=auto_dce) + # Name stacks are reset because the name stacks on jaxpr equations should be + # rooted at the enclosing jaxpr. with core.ensure_no_leaks(trace), source_info_util.reset_name_stack(): - in_tracers = _input_type_to_tracers(trace.new_arg, in_avals) + source_info = source_info_util.current() + in_tracers = map(partial(trace.new_arg, source_info=source_info), in_avals) in_tracers = [t for t, keep in zip(in_tracers, keep_inputs) if keep] + with core.set_current_trace(trace): ans = fun.call_wrapped(*in_tracers) - - out_tracers = map(trace.to_jaxpr_tracer, ans) + _check_returned_jaxtypes(fun.debug_info, ans) + out_tracers = map(partial(trace.to_jaxpr_tracer, source_info=source_info), ans) _check_no_returned_refs(fun.debug_info, out_tracers) - jaxpr, consts, attrs_tracked = trace.to_jaxpr(out_tracers, fun.debug_info) + jaxpr, consts = trace.frame.to_jaxpr(trace, out_tracers, fun.debug_info, + source_info) del trace, fun, in_tracers, out_tracers, ans config.enable_checks.value and core.check_jaxpr(jaxpr) - return jaxpr, [v.aval for v in jaxpr.outvars], consts, attrs_tracked + return jaxpr, [v.aval for v in jaxpr.outvars], consts + +def _check_returned_jaxtypes(dbg, out_tracers): + for i, x in enumerate(out_tracers): + try: typeof(x) + except TypeError: + if (dbg and len(paths := dbg.resolve_result_paths()) > i and + (p := paths[i].removeprefix('result'))): + extra = f' at output component {p}' + else: + extra = '' + raise TypeError( + f"function {dbg.func_src_info} traced for {dbg.traced_for} returned a " + f"value of type {type(x)}{extra}, which is not a valid JAX type") from None def _check_no_returned_refs( dbg: core.DebugInfo, @@ -2200,10 +2483,12 @@ def _check_no_returned_refs( result_paths = dbg.resolve_result_paths().safe_result_paths(len(out_tracers)) loc = result_paths[i] and f' at output tree path {result_paths[i]}' frame = t._trace.frame - v = frame.tracer_to_var.get(id(t)) - eqn = next((e for e in frame.eqns if v in e.outvars), None) + v = t.val + eqns = frame.get_eqns() + # TODO(dougalm): something more efficient + eqn = next((e for e in eqns if v in e.outvars), None) if eqn: - assert eqn.primitive is core.mutable_array_p + assert eqn.primitive is core.ref_p origin_info = ('\n\nThe returned mutable array was created on line ' f'{source_info_util.summarize(eqn.source_info)}.') elif v in frame.invars: @@ -2217,179 +2502,6 @@ def _check_no_returned_refs( f"a mutable array reference of type {a.str_short()}{loc}, but " f"mutable array references cannot be returned.{origin_info}") -@profiler.annotate_function -def trace_to_jaxpr_dynamic2( - fun: lu.WrappedFun, - ) -> tuple[Jaxpr, OutputType, list[Any]]: - assert fun.in_type is not None, "fun must be annotated with lu.annotate()" - - trace = DynamicJaxprTrace(fun.debug_info) - with core.ensure_no_leaks(trace), source_info_util.reset_name_stack(): - in_avals, keep_inputs = unzip2(fun.in_type) - in_tracers = _input_type_to_tracers(trace.new_arg, in_avals) - in_tracers = [t for t, keep in zip(in_tracers, keep_inputs) if keep] - with core.set_current_trace(trace): - ans = fun.call_wrapped(*in_tracers) - out_tracers = map(trace.to_jaxpr_tracer, ans) - jaxpr = trace.frame.to_jaxpr2(out_tracers, fun.debug_info) - del trace, in_tracers, out_tracers, ans - - return jaxpr - -AbstractedAxisName = Hashable -AbstractedAxesSpec = Union[ - dict[int, AbstractedAxisName], - tuple[AbstractedAxisName, ...], -] - -AttrsTracked = list[tuple[Any, str]] -AttrStates = list -def set_states(attrs_tracked: AttrsTracked, vals: AttrStates): - for ((obj, attr), val) in zip(attrs_tracked, vals): - setattr(obj, attr, val) - -def get_states(attrs_tracked: AttrsTracked): - return [getattr(obj, attr) for (obj, attr) in attrs_tracked] - - -def infer_lambda_input_type( - axes_specs: Sequence[AbstractedAxesSpec] | None, - args: Sequence[Any] - ) -> InputType: - ndims = [getattr(get_aval(x), 'ndim', 0) for x in args] - partial_specs = _canonicalize_specs(ndims, axes_specs) - specs = _complete_specs(args, partial_specs) - idxs, implicit_types = _collect_implicit(args, specs) - implicit_sig = [(ty, False) for ty in implicit_types] - explicit_sig = [(_arg_type(idxs, x, s), True) for x, s in zip(args, specs)] - input_type = (*implicit_sig, *explicit_sig) - lu._check_input_type(input_type) - return input_type - -def _spec_to_dict(spec: AbstractedAxesSpec) -> dict[int, AbstractedAxisName]: - if isinstance(spec, tuple): - return {i: d for i, d in enumerate(spec) if d is not None} - else: - return spec - -def _canonicalize_specs( - ndims: Sequence[int], specs: Sequence[AbstractedAxesSpec] | None - ) -> list[dict[int, AbstractedAxisName]]: - if specs is None: - return [{}] * len(ndims) - else: - return [_spec_to_dict(s) for n, s in zip(ndims, specs)] - -def _complete_specs( - args: Sequence[Any], partial_specs: list[dict[int, AbstractedAxisName]] - ) -> list[dict[int, AbstractedAxisName]]: - # The abstracted axes specification in `partial_specs` is partial in the sense - # that there could be additional axis abstraction represented in `args` due to - # Tracers existing in the shapes of elements of `args`. The purpose of this - # function is to produce a full specification, for each argument mapping any - # abstracted axis positions to a name, introducing new names as needed for - # Tracers in axis sizes which don't already correspond to abstracted axis - # names (with one new name per unique Tracer object id). - - # Identify each user-supplied name in partial_specs with a size. - sizes: dict[AbstractedAxisName, int | DynamicJaxprTracer] = {} - for x, spec in zip(args, partial_specs): - for i, name in spec.items(): - d = sizes.setdefault(name, x.shape[i]) - if d is not x.shape[i] and d != x.shape[i]: - raise TypeError(f"Provided size {d} for {name} does not match prior associated name for {name} : {x.shape[i]}") - - # Introduce new names as needed for Tracers in shapes. - named_tracers: dict[TracerId, AbstractedAxisName] = { - id(d): name for name, d in sizes.items() if isinstance(d, Tracer)} - specs: list[dict[int, AbstractedAxisName]] = [] - for x, spec in zip(args, partial_specs): - if isinstance(get_aval(x), DShapedArray): - spec = dict(spec) - for i, d in enumerate(x.shape): - if isinstance(d, Tracer): - spec[i] = named_tracers.get(id(d), TracerAsName(d)) - specs.append(spec) - - # Assert that `specs` is now complete in the sense that there are no Tracers - # which don't correspond to an AbstractedAxisName. - assert all(not spec or not any(isinstance(d, Tracer) and i not in spec - for i, d in enumerate(x.shape)) - for x, spec in zip(args, specs)) - return specs - - -def _collect_implicit( - args: Sequence[Any], specs: list[dict[int, AbstractedAxisName]] - ) -> tuple[dict[AbstractedAxisName, DBIdx], list[AbstractValue]]: - # Given an explicit argument list and a specification of abstracted axes, we - # want to produce an InputType by identifying AbstractedAxisNames with DBIdxs - # and figuring out which AbstractedAxisNames correspond to implicit arguments. - - idxs: dict[AbstractedAxisName, DBIdx] = {} - implicit_types: list[AbstractValue] = [] - explicit_tracers: dict[TracerId, int] = {} - counter = it.count() - - # Add implicit arguments to idxs. - for explicit_idx, (x, spec) in enumerate(zip(args, specs)): - for i, name in spec.items(): - if name not in idxs and id(x.shape[i]) not in explicit_tracers: - idxs[name] = DBIdx(next(counter)) - implicit_types.append(get_aval(x.shape[i])) - if isinstance(x, Tracer): - explicit_tracers.setdefault(id(x), explicit_idx) # use the first - - # Now that we know the implicit args, add explicit args to idxs. - offset = len(implicit_types) - for x, spec in zip(args, specs): - for i, name in spec.items(): - if id(x.shape[i]) in explicit_tracers: - idxs.setdefault(name, DBIdx(offset + explicit_tracers[id(x.shape[i])])) - - return idxs, implicit_types - -def _arg_type( - idxs: dict[AbstractedAxisName, DBIdx], x: Any, - spec: dict[int, AbstractedAxisName] - ) -> AbstractValue: - # Produce an AbstractValue by substituting DBIdxs for AbstractedAxisNames. - aval = get_aval(x) # aval.shape could contain Tracers - if not spec: return aval - shape: list[int | DBIdx] = [idxs[spec[i]] if i in spec else d - for i, d in enumerate(aval.shape)] - assert not any(isinstance(d, Tracer) for d in shape) - return DShapedArray(tuple(shape), aval.dtype, False) - -def _add_implicit_outputs(jaxpr: Jaxpr) -> tuple[Jaxpr, OutputType]: - invars = [*jaxpr.constvars, *jaxpr.invars] - expl_outvars = jaxpr.outvars - - # First do a pass to collect implicit outputs, meaning variables which occur - # in explicit_outvars types but not in invars or to the left in outvars. - seen: set[Var] = set(invars) - impl_outvars = [seen.add(d) or d for x in expl_outvars if type(x) is Var and # type: ignore - (seen.add(x) or type(x.aval) is DShapedArray) # type: ignore - for d in x.aval.shape if type(d) is Var and d not in seen] - outvars = [*impl_outvars, *expl_outvars] - - # Now assemble an OutputType by mapping vars in shapes to InDBIdx/OutDBIdx. - in_map : dict[Var, InDBIdx] = {v: InDBIdx(i) for i, v in enumerate( invars)} - out_map: dict[Var, OutDBIdx] = {x: OutDBIdx(i) for i, x in enumerate(outvars) - if type(x) is Var} - out_avals_ = (x.aval for x in outvars) - out_avals = [a.update(shape=tuple(in_map.get(d, out_map.get(d)) - if type(d) is Var else d for d in a.shape)) - if type(a) is DShapedArray else a for a in out_avals_] - kept_outs = [False] * len(impl_outvars) + [True] * len(expl_outvars) - out_type = tuple(zip(out_avals, kept_outs)) - - new_jaxpr = Jaxpr(jaxpr.constvars, jaxpr.invars, outvars, jaxpr.eqns, - jaxpr.effects, jaxpr.debug_info) - config.enable_checks.value and core.check_jaxpr(jaxpr) - return new_jaxpr, out_type - - class TracerAsName: ref: Any def __init__(self, tracer): @@ -2399,165 +2511,9 @@ def __eq__(self, other): def __hash__(self): return id(self.ref) -def _extract_implicit_args( - trace: DynamicJaxprTrace, in_type: Sequence[tuple[AbstractValue, bool]], - explicit_tracers: Sequence[DynamicJaxprTracer] - ) -> Sequence[DynamicJaxprTracer]: - # First, construct a list to represent the full argument list, leaving the - # implicit arguments as Nones for now. - explicit_tracers_ = iter(explicit_tracers) - tracers = [next(explicit_tracers_) if expl else None for _, expl in in_type] - assert next(explicit_tracers_, None) is None - del explicit_tracers_ - - # Next, populate the implicit arguments using DBIdxs in in_type. - for i, (aval, explicit) in enumerate(in_type): - if not explicit or not isinstance(aval, DShapedArray): - continue # can't populate an implicit argument - tracer = tracers[i] - assert tracer is not None - for d1, d2 in zip(aval.shape, tracer.aval.shape): - if isinstance(d1, DBIdx): - if tracers[d1.val] is None: - tracers[d1.val] = trace.to_jaxpr_tracer(d2) - assert tracers[d1.val] is trace.to_jaxpr_tracer(d2) - assert all(t is not None for t in tracers) - return [t for t, (_, e) in zip(tracers, in_type) if not e] # type: ignore - -def _input_type_to_tracers( - new_arg: Callable[[AbstractValue], Tracer], - in_avals: Sequence[AbstractValue] - ) -> Sequence[Tracer]: - # Create input Tracers given input AbstractValues, each of which can contain - # DeBruijn indices which refer to positions in the input argument list. That - # is, each element `a` of `in_avals` can have DBIdx instances in its shape, - # which must refer to positions left of `a`'s. - in_tracers: list[Tracer] = [] - - def _substitute_tracers_in_aval(a: AbstractValue) -> AbstractValue: - if isinstance(a, DShapedArray) and any(type(d) is DBIdx for d in a.shape): - shape = [in_tracers[d.val] if type(d) is DBIdx else d for d in a.shape] - return a.update(shape=tuple(shape)) - return a - - for a in in_avals: - in_tracers.append(new_arg(_substitute_tracers_in_aval(a))) - return in_tracers - -def _substitute_vars_in_type( - consts: dict[Var, Literal], env: dict[Var, Var], a: AbstractValue - ) -> AbstractValue: - if isinstance(a, DShapedArray) and any(isinstance(d, Var) for d in a.shape): - shape = [consts[d].val if d in consts else env[d] # type: ignore - if isinstance(d, Var) else d for d in a.shape] - return a.update(shape=tuple(shape)) - else: - return a - Const = Any Val = Any -def pad_jaxpr(jaxpr: Jaxpr, consts: Sequence[Const] - ) -> tuple[Jaxpr, list[Const]]: - bounds = {v: v.aval.dtype.bound for v in jaxpr.invars - if isinstance(v.aval, core.UnshapedArray) and - type(v.aval.dtype) is core.bint and not v.aval.shape} - idxs = {v: DBIdx(i) for i, v in enumerate(jaxpr.invars)} - - def substitute(aval: AbstractValue) -> AbstractValue: - if (isinstance(aval, core.UnshapedArray) and type(aval.dtype) is core.bint - and not aval.shape): - return ShapedArray((), dtypes._scalar_type_to_dtype(int)) - elif isinstance(aval, DShapedArray): - shape = [bounds.get(d, idxs.get(d, d)) for d in aval.shape] # type: ignore - typ = ShapedArray if all(type(d) is int for d in shape) else DShapedArray - return typ(tuple(shape), aval.dtype, aval.weak_type) - else: - return aval - - in_avals = [substitute(v.aval) for v in jaxpr.invars] - eval_padded = lu.wrap_init(partial(_eval_jaxpr_padded, jaxpr, consts), - debug_info=jaxpr.debug_info) - padded_jaxpr, _, padded_consts, () = trace_to_jaxpr_dynamic(eval_padded, in_avals) - return padded_jaxpr, padded_consts - -class BoundedAxisSize(NamedTuple): - val: int | DynamicJaxprTracer - bound: int - -def _eval_jaxpr_padded( - jaxpr: Jaxpr, consts: Sequence[Const], *args: DynamicJaxprTracer - ) -> list[Const | DynamicJaxprTracer]: - env: dict[Var, Val] = {} - - def read(x): - return x.val if type(x) is Literal else env[x] - - def write(v, val) -> None: - env[v] = val - - foreach(write, jaxpr.constvars, consts) - foreach(write, jaxpr.invars, args) - last_used = core.last_used(jaxpr) - for eqn in jaxpr.eqns: - in_avals = [_substitute_axis_sizes(env, v.aval) for v in eqn.invars] - out_avals = [_substitute_axis_sizes(env, v.aval) for v in eqn.outvars] - rule = padding_rules[eqn.primitive] - outs = rule(in_avals, out_avals, *map(read, eqn.invars), **eqn.params) - foreach(write, eqn.outvars, outs) - core.clean_up_dead_vars(eqn, env, last_used) - return map(read, jaxpr.outvars) - -def _substitute_axis_sizes(env: dict, aval: AbstractValue) -> AbstractValue: - if isinstance(aval, DShapedArray): - shp = [] - for d in aval.shape: - if isinstance(d, core.DArray): - assert not d.shape and type(d.dtype) is core.bint - shp.append(BoundedAxisSize(int(d._data), int(d.dtype.bound))) - elif (type(d) is core.Var and isinstance(d.aval, core.DShapedArray) and - type(d.aval.dtype) is core.bint): - assert not d.aval.shape - shp.append(BoundedAxisSize(env[d], d.aval.dtype.bound)) - else: - shp.append(env.get(d, d)) - return DShapedArray(tuple(shp), aval.dtype, aval.weak_type) - else: - return aval - -def _is_bint_axis_size(d: int | core.DArray | core.Var) -> bool: - if isinstance(d, core.DArray): - assert not d.shape # pytype: disable=attribute-error - return type(d.dtype) is core.bint # pytype: disable=attribute-error - elif isinstance(d, core.Var): - return (isinstance(d.aval, core.DShapedArray) and # pytype: disable=attribute-error - type(d.aval.dtype) is core.bint) # pytype: disable=attribute-error - return False - - -padding_rules: dict[Primitive, Callable] = {} - -def def_trivial_padding(prim: Primitive) -> None: - if prim.multiple_results: - padding_rules[prim] = partial(_trivial_padding_rule_multi, prim) - else: - padding_rules[prim] = partial(_trivial_padding_rule, prim) - -def _trivial_padding_rule(prim, _, __, *args, **params): - return [prim.bind(*args, **params)] - -def _trivial_padding_rule_multi(prim, _, __, *args, **params): - return prim.bind(*args, **params) - -def call_padding_rule(prim, in_avals, out_avals, *args, call_jaxpr, **params): - if call_jaxpr.constvars: raise NotImplementedError - padded_jaxpr, padded_consts = pad_jaxpr(call_jaxpr, ()) - if padded_consts: raise NotImplementedError - new_params = dict(params, call_jaxpr=padded_jaxpr) - subfuns, bind_params = prim.get_bind_params(new_params) - return prim.bind(*subfuns, *args, **bind_params) - - def instantiate_const_at(trace: JaxprTrace, instantiate: bool, tracer): if instantiate: return trace.instantiate_const(tracer) @@ -2565,33 +2521,44 @@ def instantiate_const_at(trace: JaxprTrace, instantiate: bool, tracer): return tracer def inline_jaxpr_into_trace( - trace: DynamicJaxprTrace, jaxpr: Jaxpr, consts: Sequence[Any], - *arg_tracers: DynamicJaxprTracer) -> list[Any]: + trace: DynamicJaxprTrace, src: SourceInfo, jaxpr: Jaxpr, + consts: Sequence[Any], *arg_tracers: DynamicJaxprTracer) -> list[Any]: # This function is conceptually the same thing as just calling eval_jaxpr, - const_tracers = map(trace.new_const, consts) - constvars = map(trace.getvar, const_tracers) - argvars = map(trace.getvar, arg_tracers) - env: dict[Var, Var] = dict(zip([*jaxpr.constvars, *jaxpr.invars], - [*constvars, *argvars])) + const_tracers = map(partial(trace.new_const, source_info=src), consts) + env: dict[Var, DynamicJaxprTracer] = dict( + zip([*jaxpr.constvars, *jaxpr.invars], + [*const_tracers, *arg_tracers])) + + def inline_atom(src_, x): + if isinstance(x, Literal): + return DynamicJaxprTracer(trace, x.aval, x, src_) + else: + return env[x] - src = source_info_util.current() for eqn in jaxpr.eqns: - invars = [x if isinstance(x, Literal) else env[x] for x in eqn.invars] - outvars = [Var('', v.aval) for v in eqn.outvars] src_ = (src if not eqn.source_info.name_stack else src.replace(name_stack=src.name_stack + eqn.source_info.name_stack)) - trace.frame.add_eqn(eqn.replace(invars, outvars, source_info=src_)) - foreach(env.setdefault, eqn.outvars, outvars) - - tracer_env: dict[Var, Any] = dict(zip([*jaxpr.constvars, *jaxpr.invars], - [*consts, *arg_tracers])) - def new_tracer(atom): - tracer = tracer_env[atom] = DynamicJaxprTracer(trace, atom.aval, src) - trace.frame.tracers.append(tracer) - trace.frame.tracer_to_var[id(tracer)] = env[atom] - return tracer - return [x.val if isinstance(x, Literal) else tracer_env[x] if x in tracer_env - else new_tracer(x) for x in jaxpr.outvars] + in_tracers = map(partial(inline_atom, src_), eqn.invars) + out_avals = [v.aval for v in eqn.outvars] + + maybe_consts = try_constant_folding(eqn.primitive, in_tracers, eqn.params, out_avals) + if maybe_consts is not None: + out_tracers = [trace.new_const(c, source_info=src_, aval=aval) + for c, aval in zip(maybe_consts, out_avals)] + else: + out_tracers = trace.emit_eqn(in_tracers, out_avals, eqn.primitive, + eqn.params, eqn.effects, src_, eqn.ctx) + foreach(env.setdefault, eqn.outvars, out_tracers) + + return map(partial(inline_atom, src), jaxpr.outvars) + + +def try_constant_folding(primitive, tracers, params, out_avals): + if primitive in const_fold_rules: + consts_in = [t.get_const() for t in tracers] + if any(c is not None for c in consts_in): + return const_fold_rules[primitive](consts_in, params, out_avals) + return None # TODO(mattjj,dougalm): this special handling is to avoid round-tripping the # jaxpr when we do grad-of-pmap. The tag is set by LinearizeTrace.process_call's @@ -2602,3 +2569,56 @@ def _linearize_of_pmap_hack(f: lu.WrappedFun, jaxpr, consts) -> tuple[Jaxpr, lis _, jaxpr = f.f.closure return convert_constvars_jaxpr(jaxpr), [] return jaxpr, consts + + +@weakref_lru_cache +def lower_jaxpr(hi_jaxpr: core.ClosedJaxpr): + lo_avals = [lo_ty for aval in hi_jaxpr.in_aval_qdds for lo_ty in aval.lo_ty()] + f = lu.wrap_init(partial(lower_traceable, hi_jaxpr), + debug_info=hi_jaxpr.jaxpr.debug_info.with_unknown_names()) + lo_jaxpr, _, lo_consts = trace_to_jaxpr_dynamic(f, lo_avals, lower=True) + return core.ClosedJaxpr(lo_jaxpr, lo_consts) + +def lower_traceable(jaxpr, *lo_args): + lo_args_ = iter(lo_args) + hi_args = [aval.raise_val(*it.islice(lo_args_, len(aval.lo_ty()))) + if not aval.has_qdd else + aval.new_from_loval(*it.islice(lo_args_, len(aval.lo_ty()))) + for aval in jaxpr.in_aval_qdds] + assert (problem := next(lo_args_, None)) is None + hi_outs = core.jaxpr_as_fun(jaxpr)(*hi_args) + mut_outs = [lo_val for aval, hi_arg in zip(jaxpr.final_aval_qdds, hi_args) if aval.has_qdd + for lo_val in aval.read_loval(hi_arg)] + lo_outs = [lo_val for v, hi_val in zip(jaxpr.jaxpr.outvars, hi_outs) + for lo_val in v.aval.lower_val(hi_val)] + return mut_outs + lo_outs + +@weakref_lru_cache +def convert_const_himutables(jaxpr): + move = [typeof(c).has_qdd for c in jaxpr.consts] + constvals, in_mutables = partition_list(move, jaxpr.consts) + constvars, boxvars = partition_list(move, jaxpr.jaxpr.constvars) + invars = *boxvars, *jaxpr.jaxpr.invars + effects = make_jaxpr_effects(constvars, invars, jaxpr.jaxpr.outvars, + jaxpr.jaxpr.eqns) + new_jaxpr = jaxpr.jaxpr.replace(constvars=constvars, invars=invars, + effects=effects) + return jaxpr.replace(jaxpr=new_jaxpr, consts=constvals), in_mutables + +def num_himuts_out(jaxpr): + return sum(len(a.lo_ty()) for a in jaxpr.final_aval_qdds if a.has_qdd) + +def apply_himut(jaxpr: Jaxpr | ClosedJaxpr, hi_args, out_mut): + out_mut_ = iter(out_mut) + for i, v in enumerate(jaxpr.invars): + if v.final_qdd is not None: + qdd = v.final_qdd + lo_vals = it.islice(out_mut_, len(v.aval.lo_ty_qdd(qdd))) + v.aval.update_from_loval(qdd, hi_args[i], *lo_vals) # type: ignore + assert next(out_mut_, None) is None + +def raise_lo_outs(avals, lo_outs): + lo_outs_ = iter(lo_outs) + hi_outs = [t.raise_val(*it.islice(lo_outs_, len(t.lo_ty()))) for t in avals] + assert next(lo_outs_, None) is None + return hi_outs diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index c06eda5214ed..eab069191edd 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -15,12 +15,11 @@ from __future__ import annotations -import enum import collections from collections import namedtuple from collections.abc import Callable, Sequence, Iterable import dataclasses -from functools import partial, lru_cache, cached_property +from functools import partial, cached_property import functools import itertools as it import logging @@ -30,34 +29,35 @@ import numpy as np -import jax - from jax._src import api +from jax._src import array from jax._src import compiler from jax._src import config from jax._src import core from jax._src import dispatch from jax._src import dtypes from jax._src import effects +from jax._src import literals +from jax._src import jaxpr_util from jax._src import linear_util as lu from jax._src import op_shardings from jax._src import sharding_specs +from jax._src import pjit from jax._src import profiler from jax._src import sharding_impls -from jax._src import source_info_util from jax._src import stages from jax._src import tree_util +from jax._src import typing from jax._src import util from jax._src import xla_bridge as xb from jax._src.abstract_arrays import array_types -from jax._src.core import DShapedArray from jax._src.core import ShapedArray from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import mlir -from jax._src.interpreters import xla -from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout +from jax._src.layout import Layout, AutoLayout, Format +from jax._src.lib import _jax from jax._src.lib import xla_client as xc from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo @@ -66,13 +66,14 @@ from jax._src.mesh import (AbstractMesh, Mesh, get_abstract_mesh, get_concrete_mesh) from jax._src.sharding_impls import ( - ArrayMapping, ArrayMappingOrAutoOrUnspecified, AUTO, UnspecifiedValue, - get_array_mapping as _get_array_mapping, array_mapping_to_axis_resources, - SingleDeviceSharding, GSPMDSharding, NamedSharding, PositionalSharding) + ArrayMapping, AUTO, UnspecifiedValue, array_mapping_to_axis_resources, + SingleDeviceSharding, GSPMDSharding, NamedSharding, PartitionSpec as P) from jax._src.util import (safe_map, safe_zip, partition_list, wrap_name, - tuple_update, tuple_delete, distributed_debug_log, - unzip2, HashableFunction, weakref_lru_cache) + tuple_update, distributed_debug_log, + unzip2, HashableFunction, weakref_lru_cache, + tuple_insert) from jax._src.state.types import AbstractRef, RefEffect +from jax._src.typing import ArrayLike # Built in Python lists don't support weak refs but subclasses of lists do. @@ -80,9 +81,8 @@ class WeakRefList(list): pass -xe = xc._xla - unsafe_map, map = map, safe_map # type: ignore +zip, unsafe_zip = safe_zip, zip # type: ignore logger = logging.getLogger(__name__) @@ -103,44 +103,32 @@ class WeakRefList(list): ### util - -def to_xc_copy_semantics(copy_semantics): - out = [] - for cs in copy_semantics: - if cs is None or cs == dispatch.CopySemantics.ALIAS: - out.append(xc.ArrayCopySemantics.REUSE_INPUT) - elif cs == dispatch.CopySemantics.COPY: - out.append(xc.ArrayCopySemantics.ALWAYS_COPY) - elif cs == dispatch.CopySemantics.DONATE: - out.append(xc.ArrayCopySemantics.DONATE_INPUT) - else: - assert isinstance(cs, xc.ArrayCopySemantics) - out.append(cs) - return out - - def identity(x): return x + @profiler.annotate_function -def shard_args(shardings: Sequence[JSharding], layouts, copy_semantics, - args, canonicalize=True) -> Sequence[xc.ArrayImpl]: - xc_copy_semantics = to_xc_copy_semantics(copy_semantics) - del copy_semantics +def shard_args( + shardings: Sequence[JSharding], + layouts: Sequence[Any | None], + copy_semantics: Sequence[xc.ArrayCopySemantics], + args: Sequence[Any], + canonicalize: bool = True, +) -> Sequence[xc.ArrayImpl]: # Fast path for one argument. if len(args) == 1: arg = args[0] if canonicalize: - arg = xla.canonicalize_dtype(arg) + arg = dtypes.canonicalize_value(arg) return shard_arg_handlers[type(arg)]([arg], shardings, layouts, - xc_copy_semantics) + copy_semantics) # type(arg) -> (list[indices], list[args], list[shardings], list[layouts], # list[copy_semantics]) batches = collections.defaultdict(lambda: ([], [], [], [], [])) # type: ignore for i, (arg, sharding, layout, cs) in enumerate( - safe_zip(args, shardings, layouts, xc_copy_semantics)): + safe_zip(args, shardings, layouts, copy_semantics)): if canonicalize: - arg = xla.canonicalize_dtype(arg) + arg = dtypes.canonicalize_value(arg) batch = batches[type(arg)] batch[0].append(i) batch[1].append(arg) @@ -152,9 +140,9 @@ def shard_args(shardings: Sequence[JSharding], layouts, copy_semantics, # from each call in the same order as `args`. Since `batches` is grouped by # types, we cannot simply flatten the results and we have to use the original # indices to put each array back to its original position. - results: list[jax.Array | None] = [None] * len(args) - for t, (indices, a, s, l, cs) in batches.items(): - outs = shard_arg_handlers[t](a, s, l, cs) + results: list[typing.Array | None] = [None] * len(args) + for t, (indices, a, s, l, xcs) in batches.items(): + outs = shard_arg_handlers[t](a, s, l, xcs) for i, out in safe_zip(indices, outs): results[i] = out assert all(result is not None for result in results) @@ -162,12 +150,16 @@ def shard_args(shardings: Sequence[JSharding], layouts, copy_semantics, shard_arg_handlers: dict[ - Any, Callable[[Sequence[Any], Sequence[Any], Sequence[Any], Sequence[Any]], - Sequence[Any]] + Any, + Callable[ + [Sequence[Any], Sequence[Any], Sequence[Any], + Sequence[xc.ArrayCopySemantics]], + Sequence[Any], + ], ] = {} -@lru_cache(maxsize=2048) +@util.cache(max_size=2048, trace_context_in_key=False) def is_default_layout(curr_layout, sharding, aval): if curr_layout is None or sharding is None or isinstance(sharding, UnspecifiedValue): return True @@ -183,9 +175,9 @@ def is_default_layout(curr_layout, sharding, aval): # int4. return is_user_xla_layout_equal( curr_layout, - DeviceLocalLayout.from_pjrt_layout( + Layout.from_pjrt_layout( d.client.get_default_layout(aval.dtype, shard_shape, d))) - except xe.XlaRuntimeError as e: + except _jax.JaxRuntimeError as e: msg, *_ = e.args if isinstance(msg, str) and msg.startswith("UNIMPLEMENTED"): return True @@ -206,7 +198,7 @@ def _shard_np_array(xs, shardings, layouts, copy_semantics): x = np.zeros(x.shape, dtype=np.dtype(bool)) aval = core.shaped_abstractify(x) if layout is not None: - results.append(api.device_put(x, Layout(layout, sharding))) + results.append(api.device_put(x, Format(layout, sharding))) else: if sharding.is_fully_replicated: shards = [x] * len(devices) @@ -218,21 +210,34 @@ def _shard_np_array(xs, shardings, layouts, copy_semantics): for _t in array_types: shard_arg_handlers[_t] = _shard_np_array -def _shard_darray(xs, shardings, layouts, copy_semantics): - return shard_args(shardings, layouts, copy_semantics, [x._data for x in xs]) -shard_arg_handlers[core.DArray] = _shard_darray +shard_arg_handlers[literals.TypedNdArray] = _shard_np_array + +def _shard_python_scalar(xs, shardings, layouts, copy_semantics): + return shard_args(shardings, layouts, copy_semantics, + [np.array(x) for x in xs]) +for _t in dtypes.python_scalar_types: + shard_arg_handlers[_t] = _shard_python_scalar + +def _shard_typed_scalar(xs, shardings, layouts, copy_semantics): + return _shard_np_array( + [literals.TypedNdArray(np.array(x, dtype=x.dtype), weak_type=True) + for x in xs], + shardings, layouts, copy_semantics + ) +for _t in literals.typed_scalar_types: + shard_arg_handlers[_t] = _shard_typed_scalar def _shard_mutable_array(xs, shardings, layouts, copy_semantics): - return shard_args(shardings, layouts, copy_semantics, [x._buf for x in xs]) -shard_arg_handlers[core.MutableArray] = _shard_mutable_array + bufs = [x._refs._buf for x in xs] + return shard_args(shardings, layouts, copy_semantics, bufs) +shard_arg_handlers[core.Ref] = _shard_mutable_array def batched_device_put(aval: core.ShapedArray, sharding: JSharding, xs: Sequence[Any], - devices: Sequence[jax.Device], committed: bool = True): + devices: Sequence[xc.Device], committed: bool = True, + enable_x64: bool | None = None): util.test_event("batched_device_put_start") try: - from jax._src import array - bufs = [x for x, d in safe_zip(xs, devices) if (isinstance(x, array.ArrayImpl) and dispatch.is_single_device_sharding(x.sharding) and @@ -240,7 +245,8 @@ def batched_device_put(aval: core.ShapedArray, if len(bufs) == len(xs) > 0: return array.ArrayImpl( aval, sharding, bufs, committed=committed, _skip_checks=True) - return xc.batched_device_put(aval, sharding, xs, list(devices), committed) + return xc.batched_device_put(aval, sharding, xs, list(devices), committed, + enable_x64=enable_x64) finally: util.test_event("batched_device_put_end") @@ -257,11 +263,8 @@ def _shard_abstract_array(size, axis: int, x): raise ValueError(f"Axis size {size} does not match dimension {axis} of " f"shape {x.shape}") except IndexError: - raise ValueError("Cannot split a {x.dim}D value along axis {axis}") from None - if config.pmap_no_rank_reduction.value: - return x.update(shape=tuple_update(x.shape, axis, 1)) - else: - return x.update(shape=tuple_delete(x.shape, axis)) + raise ValueError(f"Cannot split a {x.dim}D value along axis {axis}") from None + return x.update(shape=tuple_update(x.shape, axis, 1)) _shard_aval_handlers[ShapedArray] = _shard_abstract_array @@ -290,7 +293,7 @@ def local_aval_to_result_handler( raise TypeError( f"No pxla_result_handler for type: {type(aval)}") from err -PxlaResultHandler = Callable[..., Callable[[Any], Any]] +PxlaResultHandler = Callable[..., xc._xla.ResultHandler] local_result_handlers: dict[type[core.AbstractValue], PxlaResultHandler] = {} @@ -337,9 +340,9 @@ def xla_pmap_impl_lazy( out_axes_thunk: Callable[[], Sequence[int | None]], donated_invars: Sequence[bool], is_explicit_global_axis_size: bool, -) -> Callable: - if (config.disable_jit.value and config.eager_pmap.value and - not is_explicit_global_axis_size and not any(d for d in donated_invars)): +) -> tuple[Callable, list[ArrayLike]]: + if (config.disable_jit.value and + not is_explicit_global_axis_size and not any(donated_invars)): def _emap_apply_fn(*args): return _emap_impl(fun, *args, backend=backend, axis_name=axis_name, axis_size=axis_size, global_axis_size=global_axis_size, @@ -347,9 +350,9 @@ def _emap_apply_fn(*args): out_axes_thunk=out_axes_thunk, donated_invars=donated_invars, is_explicit_global_axis_size=is_explicit_global_axis_size) - return _emap_apply_fn + return _emap_apply_fn, [] abstract_args = unsafe_map(core.abstractify, args) - compiled_fun, fingerprint = parallel_callable( + compiled_fun, fingerprint, const_args = parallel_callable( fun, backend, axis_name, axis_size, global_axis_size, devices, name, in_axes, out_axes_thunk, donated_invars, is_explicit_global_axis_size, *abstract_args) @@ -361,11 +364,11 @@ def _emap_apply_fn(*args): ("devices", devices), ("abstract args", map(core.abstractify, args)), ("fingerprint", fingerprint)) - return compiled_fun + return compiled_fun, const_args def xla_pmap_impl(fun: lu.WrappedFun, *args, **params): - compiled_fun = xla_pmap_impl_lazy(fun, *args, **params) - return compiled_fun(*args) + compiled_fun, const_args = xla_pmap_impl_lazy(fun, *args, **params) + return compiled_fun(*const_args, *args) class EmapInfo(NamedTuple): backend: str | None @@ -383,7 +386,6 @@ def _emap_impl(fun: lu.WrappedFun, *args, donated_invars: Sequence[bool], is_explicit_global_axis_size: bool, ): - from jax._src import array # TODO(sharadmv,mattjj): implement these cases if any(d for d in donated_invars): raise NotImplementedError("Buffer donation not supported in eager pmap.") @@ -407,13 +409,14 @@ def _emap_impl(fun: lu.WrappedFun, *args, platform = xb.get_backend(backend).platform donate_argnums = (1,) if platform in {"cuda", "rocm", "tpu"} else () new_outvals = [] + assert len(out_axes_src) == len(out_axes) for out_axis_src, out_axis, outval in zip(out_axes_src, out_axes, outvals): - with jax.disable_jit(False): + with api.disable_jit(False): donate_argnums_ = donate_argnums if isinstance(outval, array.ArrayImpl): # We don't want to donate if it's already sharded. donate_argnums_ = () - out = jax.pmap( + out = api.pmap( lambda _, x: x, in_axes=(0, out_axis_src.get(axis_name)), out_axes=out_axis, @@ -438,7 +441,7 @@ def _map_schedule(idx: tuple[int | None, ...]) -> tuple[int | None, ...]: # still ends up not working, because it has a separate cache per # _function object_. Adding this annotation here lets us reuse the same pmap # callable for all equivalent primitive pmaps. -@lru_cache +@util.cache(max_size=None, trace_context_in_key=False) def _multi_pmap(f: Callable, info: EmapInfo, names: list[core.AxisName], all_axes: list[tuple[int | None, ...]] ) -> tuple[Callable, dict[core.AxisName, int]]: @@ -446,7 +449,7 @@ def _multi_pmap(f: Callable, info: EmapInfo, names: list[core.AxisName], for i, name in reversed(list(enumerate(names))): in_axes = tuple(arg_axis[i] for arg_axis in all_axes) if any(in_axis is not None for in_axis in in_axes): - f = jax.pmap( + f = api.pmap( f, in_axes=in_axes, axis_name=name, @@ -474,12 +477,14 @@ def to_map_tracer(self, val): return MapTracer(self, val, {}) def process_primitive(self, primitive, tracers, params): - if primitive is jax._src.lax.parallel.axis_index_p: - return self.process_axis_index(**params) - if primitive is jax._src.lax.parallel.psum_p: + from jax._src.lax import parallel # pytype: disable=import-error + if primitive is parallel.axis_index_p: + return self.process_axis_index(**params) # pytype: disable=missing-parameter + if primitive is parallel.psum_p: f = HashableFunction( - lambda *xs: jax._src.lax.parallel.psum( - xs, axis_name=params['axes'], axis_index_groups=params['axis_index_groups']), + lambda x: parallel.psum( + x, axis_name=params['axes'], + axis_index_groups=params['axis_index_groups']), (primitive, tuple(params.items()))) else: f = HashableFunction(lambda *args: primitive.bind(*args, **params), @@ -490,7 +495,7 @@ def process_primitive(self, primitive, tracers, params): names = core.get_axis_env().axis_names() all_axes = tuple(_map_schedule(map(s.get, names)) for s in shard_axes) # pytype: disable=wrong-arg-types # always-use-return-annotations f_mapped, out_shard_axes = _multi_pmap(f, self.emap_info, names, all_axes) - with core.eval_context(), jax.disable_jit(False): + with core.eval_context(), api.disable_jit(False): outvals = f_mapped(*vals) if primitive.multiple_results: return [MapTracer(self, val, out_shard_axes) for val in outvals] @@ -544,11 +549,12 @@ def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, return fun.call_wrapped(*tracers) def process_axis_index(self, axis_name): + from jax._src.lax import lax, parallel # pytype: disable=import-error bind = HashableFunction( - lambda _: jax.lax.axis_index(axis_name), - (jax.lax.axis_index, axis_name)) + lambda _: parallel.axis_index(axis_name), + (parallel.axis_index, axis_name)) fake_primitive = FakePrimitive(multiple_results=False, bind=bind) - range = jax.lax.iota(np.int32, core.get_axis_env().axis_size(axis_name)) + range = lax.iota(np.int32, core.get_axis_env().axis_size(axis_name)) dummy_tracer = MapTracer(self, range, {axis_name: 0}) return self.process_primitive(fake_primitive, (dummy_tracer,), {}) @@ -573,7 +579,7 @@ def _match_annot(axis_name: core.AxisName, axis_size: int, val: Any, outval = batching.moveaxis(val, src, dst) shard_axis_out = _moveaxis(np.ndim(val), shard_axis_src, src, dst) elif src is None and dst is not None: - outval = batching.broadcast(val, axis_size, dst) + outval = batching.broadcast(val, axis_size, dst, None) shard_axis_out = {n: d + (dst <= d) for n, d in shard_axis_out.items()} else: raise NotImplementedError @@ -638,7 +644,8 @@ def parallel_callable(fun: lu.WrappedFun, closed_jaxpr=closed_jaxpr, backend=xc_backend, replicas=replicas, shards=shards, pci=pci) pmap_executable = pmap_computation.compile() - return WeakRefList([pmap_executable.unsafe_call, pmap_executable.fingerprint]) + return WeakRefList([pmap_executable.unsafe_call, pmap_executable.fingerprint, + pmap_computation.const_args]) @dataclasses.dataclass(frozen=True) @@ -682,25 +689,56 @@ class ReplicaInfo(NamedTuple): num_global_replicas: int +_initial_style_primitives: set[core.Primitive] = set() + + +def register_initial_style_primitive(prim: core.Primitive): + _initial_style_primitives.add(prim) + +def _jaxpr_replicas(jaxpr: core.Jaxpr) -> int: + """The number of replicas needed for a jaxpr. + + For a eqn, multiply the `axis_size` with the `jaxpr_replicas` of the + subjaxprs. For a list of eqns, take the maximum number of replicas. + """ + return max(unsafe_map(_eqn_replicas, jaxpr.eqns), default=1) + +# TODO(mattjj): this function assumes that only pmap has a parameter named +# axis_size, and that it corresponds to cross-replica mapping +def _eqn_replicas(eqn: core.JaxprEqn) -> int: + call_jaxpr = eqn.params.get("call_jaxpr") + if call_jaxpr: + return eqn.params.get('axis_size', 1) * _jaxpr_replicas(call_jaxpr) + elif eqn.primitive in _initial_style_primitives: + return _initial_style_primitive_replicas(eqn.params) + else: + return 1 + +def _initial_style_primitive_replicas(params: dict[str, Any]) -> int: + return max(core.traverse_jaxpr_params(_jaxpr_replicas, params).values(), + default=1) + + def find_replicas( jaxpr: core.Jaxpr, axis_size: int, global_axis_size: int ) -> ReplicaInfo: # TODO(skyewm): replace this with a chain of pmaps and/or sharded_jits - jaxpr_replicas = dispatch.jaxpr_replicas(jaxpr) + jaxpr_replicas = _jaxpr_replicas(jaxpr) num_local_replicas = axis_size * jaxpr_replicas num_global_replicas = global_axis_size * jaxpr_replicas return ReplicaInfo(jaxpr_replicas, num_local_replicas, num_global_replicas) @lu.transformation2 def _change_argument_ranks(f, in_axes, out_axes_thunk, *args): + from jax._src.lax import lax # pytype: disable=import-error args = tuple( - arg if in_axis is None else jax.lax.squeeze(arg, dimensions=(in_axis,)) + arg if in_axis is None else lax.squeeze(arg, dimensions=(in_axis,)) for in_axis, arg in zip(in_axes, args) ) results = f(*args) out_axes = out_axes_thunk() return tuple( - x if axis is None else jax.lax.expand_dims(x, dimensions=(axis,)) + x if axis is None else lax.expand_dims(x, dimensions=(axis,)) for x, axis in zip(results, out_axes) ) @@ -713,16 +751,13 @@ def stage_parallel_callable( for axis, aval in safe_zip(pci.in_axes, pci.avals)) orig_fun = fun - if config.pmap_no_rank_reduction.value: - fun = _change_argument_ranks(fun, pci.in_axes, pci.out_axes_thunk) - else: - fun = orig_fun + fun = _change_argument_ranks(fun, pci.in_axes, pci.out_axes_thunk) with core.extend_axis_env_nd([(pci.axis_name, pci.global_axis_size)]): with dispatch.log_elapsed_time( "Finished tracing + transforming {fun_name} for pmap in {elapsed_time} sec", fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT): - jaxpr, out_sharded_avals, consts, _ = pe.trace_to_jaxpr_dynamic( - fun, sharded_avals) + jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_dynamic( + fun.with_unknown_names(), sharded_avals) assert len(out_sharded_avals) == len(pci.out_axes), ( len(out_sharded_avals), len(pci.out_axes)) @@ -795,6 +830,27 @@ def lower_parallel_callable( f"axis_size {axis_size}.") jaxpr = closed_jaxpr.jaxpr + arg_names = jaxpr._debug_info.safe_arg_names(len(closed_jaxpr.in_avals)) + const_args: Sequence[ArrayLike] + if lowering_parameters.hoist_constants_as_args: + const_args_and_avals = core.jaxpr_const_args(jaxpr) + const_args, const_arg_avals = unzip2(const_args_and_avals) + num_const_args = len(const_arg_avals) + in_axes = (None,) * num_const_args + in_axes # type: ignore + donated_invars = (False,) * num_const_args + donated_invars # type: ignore + jaxpr_avals = list(const_arg_avals) + closed_jaxpr.in_avals # type: ignore + shards = ShardInfo( + tuple(const_arg_avals) + shards.sharded_avals, # type: ignore + shards.out_sharded_avals, + tuple(const_arg_avals) + shards.global_sharded_avals, # type: ignore + shards.num_local_shards, shards.num_global_shards) + pci = dataclasses.replace(pci, in_axes=in_axes, + avals=tuple(const_arg_avals) + tuple(pci.avals)) + arg_names = ("",) * num_const_args + arg_names + else: + jaxpr_avals = closed_jaxpr.in_avals + const_args = [] + num_const_args = 0 no_nested_sharding = False must_run_on_all_devices = False @@ -851,11 +907,10 @@ def lower_parallel_callable( axis_env = sharding_impls.AxisEnv( replicas.num_global_replicas, (axis_name,), (global_axis_size,)) - name_stack = source_info_util.new_name_stack(wrap_name(name, 'pmap')) replicated_args = [axis is None for axis in in_axes] tuple_args = dispatch.should_tuple_args(len(shards.global_sharded_avals), backend.platform) - module_name = f"pmap_{fun.__name__}" + module_name = wrap_name('pmap', name) platforms = lowering_platforms or (backend.platform,) with core.extend_axis_env_nd([(axis_name, global_axis_size)]): ordered_effects = list( @@ -866,24 +921,26 @@ def lower_parallel_callable( effects.ordered_effects.filter_not_in(closed_jaxpr.effects)) with dispatch.log_elapsed_time( "Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time:.9f} sec", - fun_name=str(name_stack), event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT): + fun_name=module_name, event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT): lowering_result = mlir.lower_jaxpr_to_module( module_name, closed_jaxpr, + num_const_args=num_const_args, + in_avals=jaxpr_avals, ordered_effects=ordered_effects, backend=backend, platforms=platforms, axis_context=sharding_impls.ReplicaAxisContext(axis_env), - name_stack=name_stack, donated_args=donated_invars, replicated_args=replicated_args, arg_shardings=None, result_shardings=None, - arg_names=jaxpr._debug_info.safe_arg_names(len(jaxpr.invars)), + arg_names=arg_names, result_names=jaxpr._debug_info.safe_result_paths(len(jaxpr.outvars)), num_replicas=replicas.num_global_replicas, lowering_parameters=lowering_parameters) return PmapComputation(lowering_result.module, + list(const_args), platforms=platforms, pci=pci, replicas=replicas, shards=shards, tuple_args=tuple_args, @@ -911,9 +968,6 @@ def _pmap_unmap_shaped_array(size: int, axis: int | None, aval: ShapedArray def _pmap_unmapped_aval(size: core.AxisSize, axis: int | None, aval: core.AbstractValue) -> core.AbstractValue: - if not config.pmap_no_rank_reduction.value: - return core.unmapped_aval(size, axis, aval) - _, handler = _pmap_aval_mapping_handlers.get(type(aval), (None, None)) if handler is not None: return handler(size, axis, aval) @@ -921,22 +975,25 @@ def _pmap_unmapped_aval(size: core.AxisSize, axis: int | None, raise TypeError(f"no unmapping handler for {aval} of type {type(aval)}") -class PmapComputation(stages.XlaLowering): +class PmapComputation(stages.Lowering): _hlo: ir.Module _executable: PmapExecutable | None - def __init__(self, hlo: ir.Module, **compile_args): + def __init__(self, hlo: ir.Module, const_args: list[ArrayLike], + **compile_args): self._executable = None self._hlo = hlo + self.const_args = const_args self.compile_args = compile_args - # -- stages.XlaLowering overrides + # -- stages.Lowering overrides def stablehlo(self) -> ir.Module: return self._hlo @profiler.annotate_function - def compile(self, compiler_options=None) -> PmapExecutable: + def compile(self, compiler_options=None, *, device_assignment=None + ) -> PmapExecutable: if self._executable is None or compiler_options is not None: executable = UnloadedPmapExecutable.from_hlo( self._hlo, **self.compile_args, @@ -953,7 +1010,7 @@ def _cast_to_shaped_array(aval: core.AbstractValue) -> ShapedArray: @dataclasses.dataclass class UnloadedPmapExecutable: compiled: Any - backend: xb.XlaBackend + backend: xc.Client local_input_avals: Sequence[core.AbstractValue] input_shardings: Sequence[JSharding] local_output_avals: Sequence[ShapedArray] @@ -1097,9 +1154,15 @@ def from_hlo(hlo: ir.Module, with dispatch.log_elapsed_time( "Finished XLA compilation of {fun_name} in {elapsed_time:.9f} sec", fun_name=pci.name, event=dispatch.BACKEND_COMPILE_EVENT): + # `executable_devices` contains devices for output shardings of a pmapped + # function. It contains only local devices for correspondence with + # `PmapSharding`s, which also contain only local devices. + executable_devices = _create_device_list( + tuple(local_device_assignment.flat)) + assert executable_devices is not None compiled = compiler.compile_or_get_cached( pci.backend, hlo, device_assignment, compile_options, - host_callbacks) + host_callbacks, executable_devices) return UnloadedPmapExecutable( compiled=compiled, @@ -1115,7 +1178,7 @@ def from_hlo(hlo: ir.Module, jaxpr_debug_info=jaxpr_debug_info).load() -class PmapExecutable(stages.XlaExecutable): +class PmapExecutable(stages.Executable): __slots__ = ["xla_executable", "_unsafe_call", "build_unsafe_call", "fingerprint", "in_avals", "_unloaded_executable"] @@ -1135,7 +1198,7 @@ def unsafe_call(self) -> Callable[..., Any]: self._unsafe_call = self.build_unsafe_call() return self._unsafe_call # type: ignore - # -- stages.XlaExecutable overrides + # -- stages.Executable overrides def xla_extension_executable(self): return self.xla_executable @@ -1159,8 +1222,9 @@ class InputsHandler: def __init__(self, in_shardings, in_layouts, local_devices=None, input_indices=None): - self.handler = partial(shard_args, in_shardings, in_layouts, - [None] * len(in_shardings)) + self.handler = partial( + shard_args, in_shardings, in_layouts, + [xc.ArrayCopySemantics.REUSE_INPUT] * len(in_shardings)) self.in_shardings = in_shardings self.in_layouts = in_layouts self.local_devices = local_devices @@ -1267,8 +1331,8 @@ def _handle_token_bufs(self, token_bufs, sharded_token): for token in token_buf: assert isinstance(token.sharding, sharding_impls.SingleDeviceSharding) token_devices.append(token.sharding._device_assignment[0]) - s = PositionalSharding(token_devices) - global_token_array = jax.make_array_from_single_device_arrays( + s = NamedSharding(Mesh(token_devices, 'x'), P('x')) + global_token_array = array.make_array_from_single_device_arrays( (0,), s, token_buf ) dispatch.runtime_tokens.set_token_result( @@ -1277,6 +1341,10 @@ def _handle_token_bufs(self, token_bufs, sharded_token): @profiler.annotate_function def __call__(self, *args): + if config.no_execution.value: + raise RuntimeError( + f"JAX tried to execute function {self.name}, but the no_execution config " + "option is set") args = [x for i, x in enumerate(args) if i in self.kept_var_idx] if self.mut: args = [*args, *self.mut.in_mut] @@ -1285,24 +1353,21 @@ def __call__(self, *args): if (self.ordered_effects or self.has_unordered_effects or self.has_host_callbacks): input_bufs = self._add_tokens_to_inputs(input_bufs) - results = self.xla_executable.execute_sharded( - input_bufs, with_tokens=True - ) + results = self.xla_executable.execute_sharded(input_bufs, with_tokens=True) - result_token_bufs = results.disassemble_prefix_into_single_device_arrays( - len(self.ordered_effects)) + result_token_bufs = results.consume_with_handlers( + [lambda xs: xs] * len(self.ordered_effects), strict=False) sharded_runtime_token = results.consume_token() self._handle_token_bufs(result_token_bufs, sharded_runtime_token) else: results = self.xla_executable.execute_sharded(input_bufs) + handlers = self.out_handler.handlers if dispatch.needs_check_special(): - out_arrays = results.disassemble_into_single_device_arrays() - for arrays in out_arrays: - dispatch.check_special(self.name, arrays) - out = self.out_handler(out_arrays) - else: - out = results.consume_with_handlers(self.out_handler.handlers) + special_check = functools.partial( + dispatch.check_special_array, self.name) + handlers = [h.pre_wrap(special_check) for h in handlers] + out = results.consume_with_handlers(handlers) if (self.pgle_profiler is not None and self.pgle_profiler.is_running() and len(out) > 0): @@ -1314,7 +1379,8 @@ def __call__(self, *args): out_ = [] for i, o in zip(self.mut.out_mut, out): if i is not None: - args[i]._buf = o + try: args[i]._refs._buf._replace_with(o) # type: ignore + except AttributeError: pass # TODO(mattjj): remove float0 else: out_.append(o) return out_ @@ -1510,9 +1576,9 @@ def _extend_axis_env(env: sharding_impls.AxisEnv, name, size: int): env.sizes + (size,)) -def _pmap_lowering(ctx, *in_nodes, axis_name, +def _pmap_lowering(ctx: mlir.LoweringRuleContext, *in_nodes, axis_name, axis_size, global_axis_size, devices, name, - call_jaxpr, backend=None, in_axes, out_axes, + call_jaxpr: core.Jaxpr, backend=None, in_axes, out_axes, donated_invars, is_explicit_global_axis_size): del donated_invars # Unused. mlir.check_backend_matches(backend, ctx.module_context.platforms) @@ -1534,9 +1600,9 @@ def _pmap_lowering(ctx, *in_nodes, axis_name, axis_context=sharding_impls.ReplicaAxisContext(new_env)) sharded_outs, _ = mlir.jaxpr_subcomp( sub_ctx, call_jaxpr, - ctx.name_stack.extend(util.wrap_name(name, 'pmap')), + ctx.name_stack.extend(util.wrap_name('pmap', name)), mlir.TokenSet(), (), *in_nodes_sharded, - dim_var_values=ctx.dim_var_values) + dim_var_values=ctx.dim_var_values, const_lowering=ctx.const_lowering) out_avals = [v.aval for v in call_jaxpr.outvars] outs = [_hlo_unshard(ctx, aval, new_env, out_axis, shard) for aval, out_axis, shard in zip(out_avals, out_axes, sharded_outs)] @@ -1651,67 +1717,10 @@ def check_if_any_auto( return True return False -class MismatchType(enum.Enum): - ARG_SHARDING = 0 - OUT_SHARDING = 1 - SHARDING_INSIDE_COMPUTATION = 2 - CONTEXT_DEVICES = 3 - IN_SHARDING = 4 - - def __str__(self): - if self.name == 'IN_SHARDING': - return 'explicit input sharding' - elif self.name == 'OUT_SHARDING': - return 'explicit output sharding' - elif self.name == 'CONTEXT_DEVICES': - return 'context mesh' - return f'{self.name}' - - -@dataclasses.dataclass -class DeviceAssignmentMismatch: - da: Sequence[xc.Device] - m_type: MismatchType - source_info: dispatch.SourceInfo | None - - @property - def device_ids(self) -> Sequence[int]: - return [d.id for d in self.da] - - @property - def platform(self) -> str: - return self.da[0].platform.upper() - - def _maybe_api_name(self, api_name) -> str: - return f" {api_name}'s" if self.m_type == MismatchType.CONTEXT_DEVICES else "" - - @property - def source_info_str(self): - return ( - "" if self.source_info is None - else f" at {source_info_util.summarize(self.source_info.source_info)}" - ) - - @property - def _dev_ids_plat_str(self): - return f"device ids {self.device_ids} on platform {self.platform}" - - def m_type_str(self, api_name): - return (f'{self.source_info and self.source_info.eqn_name} inside {api_name}' - if self.m_type == MismatchType.SHARDING_INSIDE_COMPUTATION else self.m_type) - - def _str(self, api_name): - return (f"{self._maybe_api_name(api_name)} {self.m_type_str(api_name)} with " - f"{self._dev_ids_plat_str}{self.source_info_str}") - - -class DeviceAssignmentMismatchError(Exception): - pass - ShardingInfo = tuple[ Union[JSharding, UnspecifiedValue, AUTO], - MismatchType, + stages.MismatchType, Union[Any, None], # Any is dispatch.SourceInfo to avoid circular imports ] @@ -1725,39 +1734,60 @@ def get_default_device() -> xc.Device: def _get_and_check_device_assignment( shardings: Iterable[ShardingInfo], - devices: Sequence[xc.Device] | None, -) -> tuple[xc.Client, tuple[xc.Device, ...]]: + context_devices: Sequence[xc.Device] | None, +) -> tuple[xc.Client, tuple[xc.Device, ...] | None, int]: first_sharding_info = None - devices = () if devices is None else tuple(devices) + context_devices = () if context_devices is None else tuple(context_devices) + abstract_mesh = None + any_concrete_sharding = True if context_devices else False for sh, s_type, source_info in shardings: if isinstance(sh, UnspecifiedValue): continue - if isinstance(sh, NamedSharding) and isinstance(sh.mesh, AbstractMesh): - continue - if first_sharding_info is None: - first_sharding_info = ( - (sh.mesh._flat_devices_tuple, s_type, source_info) if isinstance(sh, AUTO) - else (sh._device_assignment, s_type, source_info)) - arr_device_assignment = (sh.mesh._flat_devices_tuple if isinstance(sh, AUTO) - else sh._device_assignment) - if not devices: - if first_sharding_info[0] != arr_device_assignment: - raise DeviceAssignmentMismatchError([ - DeviceAssignmentMismatch(*first_sharding_info), - DeviceAssignmentMismatch(arr_device_assignment, s_type, source_info)]) + elif isinstance(sh, NamedSharding) and isinstance(sh.mesh, AbstractMesh): + if (abstract_mesh is not None and not sh.mesh.empty and + abstract_mesh.size != sh.mesh.size): + raise ValueError("AbstractMesh should be of the same size across all " + f"shardings. Got {abstract_mesh} and {sh.mesh}") + abstract_mesh = sh.mesh else: - if devices != arr_device_assignment: - raise DeviceAssignmentMismatchError([ - DeviceAssignmentMismatch(devices, MismatchType.CONTEXT_DEVICES, None), - DeviceAssignmentMismatch(arr_device_assignment, s_type, source_info)]) - if first_sharding_info is None and devices: - final_device_assignment = devices + any_concrete_sharding = True + arr_device_assignment = sh._device_assignment + if first_sharding_info is None: + first_sharding_info = (arr_device_assignment, s_type, source_info) + if not context_devices: + if first_sharding_info[0] != arr_device_assignment: + raise stages.DeviceAssignmentMismatchError([ + stages.DeviceAssignmentMismatch(*first_sharding_info), + stages.DeviceAssignmentMismatch( + arr_device_assignment, s_type, source_info)]) + else: + if context_devices != arr_device_assignment: + raise stages.DeviceAssignmentMismatchError([ + stages.DeviceAssignmentMismatch( + context_devices, stages.MismatchType.CONTEXT_DEVICES, None), + stages.DeviceAssignmentMismatch( + arr_device_assignment, s_type, source_info)]) + + if first_sharding_info is None and context_devices: + device_assignment = context_devices elif first_sharding_info is None: - final_device_assignment = (get_default_device(),) + device_assignment = (get_default_device(),) else: - final_device_assignment = first_sharding_info[0] # type: ignore - return xb.get_device_backend(final_device_assignment[0]), final_device_assignment + device_assignment = first_sharding_info[0] # type: ignore + + backend = xb.get_device_backend(device_assignment[0]) + + if (any_concrete_sharding and abstract_mesh is not None and + len(device_assignment) != abstract_mesh.size): + raise ValueError( + f"AbstractMesh size: {abstract_mesh.size} does not match the" + f" device assignment size: {len(device_assignment)}") + + if any_concrete_sharding or abstract_mesh is None: + return backend, device_assignment, len(device_assignment) # type: ignore + else: + return backend, None, abstract_mesh.size MaybeSharding = Union[JSharding, UnspecifiedValue] @@ -1773,10 +1803,7 @@ def prune_unused_inputs( @weakref_lru_cache -def _dce_jaxpr(closed_jaxpr, api_name, fun_name, - keep_unused, donated_invars, auto_spmd_lowering): - name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name)) - +def _dce_jaxpr(closed_jaxpr, keep_unused, donated_invars, auto_spmd_lowering): assert isinstance(closed_jaxpr, core.ClosedJaxpr) jaxpr = closed_jaxpr.jaxpr consts = closed_jaxpr.consts @@ -1793,19 +1820,21 @@ def _dce_jaxpr(closed_jaxpr, api_name, fun_name, del kept_const_idx closed_jaxpr = core.ClosedJaxpr(jaxpr, consts) - return closed_jaxpr, donated_invars, kept_var_idx, name_stack + return closed_jaxpr, donated_invars, kept_var_idx class MutationData(NamedTuple): - in_mut: list[core.MutableArray] + in_mut: list[core.Ref] + # out_mut[o_idx] = i_idx, when the output[o_idx] corresponds to the + # mutable array args[i_idx]. None when it does not correspond to a mutable array. out_mut: list[int | None] @weakref_lru_cache def _discharge_refs( jaxpr: core.ClosedJaxpr ) -> tuple[core.ClosedJaxpr, Sequence[int | None], MutationData]: - from jax._src.state.discharge import discharge_state + from jax._src.state.discharge import discharge_state2 # pytype: disable=import-error jaxpr, in_mut = _move_mutable_consts(jaxpr) - new_jaxpr = core.ClosedJaxpr(*discharge_state(jaxpr.jaxpr, jaxpr.consts)) + new_jaxpr = discharge_state2(jaxpr) count = it.count(len(jaxpr.out_avals)) # new outputs are appended to the end inout_map = {i: next(count) for i, a in enumerate(jaxpr.in_avals) if isinstance(a, AbstractRef)} @@ -1817,20 +1846,21 @@ def _discharge_refs( @weakref_lru_cache def _move_mutable_consts( closed_jaxpr: core.ClosedJaxpr, -) -> tuple[core.ClosedJaxpr, list[core.MutableArray]]: +) -> tuple[core.ClosedJaxpr, list[core.Ref]]: jaxpr = closed_jaxpr.jaxpr - hoist = [isinstance(c, core.MutableArray) for c in closed_jaxpr.consts] + hoist = [isinstance(c, core.Ref) for c in closed_jaxpr.consts] consts, in_mut = partition_list(hoist, closed_jaxpr.consts) constvars, mutvars = partition_list(hoist, jaxpr.constvars) invars = (*jaxpr.invars, *mutvars) effects = pe.make_jaxpr_effects(constvars, invars, jaxpr.outvars, jaxpr.eqns) + # TODO(mattjj): debug_info must be updated... jaxpr = core.Jaxpr(constvars, invars, jaxpr.outvars, jaxpr.eqns, - effects, closed_jaxpr.jaxpr.debug_info) + effects, closed_jaxpr.jaxpr.debug_info.with_unknown_names()) return core.ClosedJaxpr(jaxpr, consts), in_mut @weakref_lru_cache def _discharge_internal_refs(jaxpr: core.ClosedJaxpr) -> core.ClosedJaxpr: - from jax._src.state.discharge import discharge_state + from jax._src.state.discharge import discharge_state # pytype: disable=import-error jaxpr_, consts = discharge_state(jaxpr.jaxpr, jaxpr.consts) jaxpr_._debug_info = jaxpr.jaxpr._debug_info return core.ClosedJaxpr(jaxpr_, consts) @@ -1839,7 +1869,7 @@ def _discharge_internal_refs(jaxpr: core.ClosedJaxpr) -> core.ClosedJaxpr: class SemanticallyEqualShardings: def __init__(self, shardings: tuple[GSPMDSharding | UnspecifiedValue, ...], - avals: tuple[core.AbstractValue]): + avals: Sequence[core.AbstractValue]): gspmd_shardings = [ s if (isinstance(s, (UnspecifiedValue, AUTO)) or (isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh))) @@ -1847,7 +1877,6 @@ def __init__(self, shardings: tuple[GSPMDSharding | UnspecifiedValue, ...], for s, a in zip(shardings, avals)] self._gspmd_shardings = gspmd_shardings self.shardings = shardings - self.avals = avals def __hash__(self): return hash(tuple( @@ -1858,7 +1887,7 @@ def __eq__(self, other): if not isinstance(other, SemanticallyEqualShardings): return False return all( - (op_shardings.are_op_shardings_equal(s._hlo_sharding, o._hlo_sharding) + (op_shardings.are_hlo_shardings_equal(s._hlo_sharding, o._hlo_sharding) and s.memory_kind == o.memory_kind) if (isinstance(s, GSPMDSharding) and isinstance(o, GSPMDSharding)) else s == o @@ -1870,20 +1899,20 @@ def _raise_warnings_or_errors_for_jit_of_pmap( nreps: int, backend: xc.Client, name: str, jaxpr: core.Jaxpr) -> None: if nreps > 1: warnings.warn( - f"The jitted function {name} includes a pmap. Using " + f"The function {name} includes a pmap. Using " "jit-of-pmap can lead to inefficient data movement, as the outer jit " "does not preserve sharded data representations and instead collects " "input and output arrays onto a single device. " "Consider removing the outer jit unless you know what you're doing. " "See https://github.com/jax-ml/jax/issues/2926. Or " - "use jax.experimental.shard_map instead of pmap under jit compilation.") + "use jax.shard_map instead of pmap under jit compilation.") if nreps > xb.device_count(backend): raise ValueError( f"compiling computation `{name}` that requires {nreps} replicas, but " f"only {xb.device_count(backend)} XLA devices are available.") - if xb.process_count() > 1 and ( + if xb.process_count(backend) > 1 and ( nreps > 1 or dispatch.jaxpr_has_primitive(jaxpr, "xla_pmap") ): raise NotImplementedError( @@ -1892,44 +1921,46 @@ def _raise_warnings_or_errors_for_jit_of_pmap( @weakref_lru_cache -def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend, +def _cached_lowering_to_hlo(closed_jaxpr: core.ClosedJaxpr, module_name, backend, + num_const_args: int, + in_avals, semantic_in_shardings, semantic_out_shardings, in_layouts, out_layouts, num_devices, device_assignment, - donated_invars, name_stack, all_default_mem_kind, + donated_invars, all_default_mem_kind, inout_aliases: None | tuple[None | int, ...], propagated_out_mem_kinds: tuple[None | str, ...], platforms: tuple[str, ...], lowering_parameters: mlir.LoweringParameters, abstract_mesh: AbstractMesh | None): + # in_avals, in_shardings, in_layouts include the jaxpr_const_args(jaxpr) + out_avals = closed_jaxpr.out_avals jaxpr = closed_jaxpr.jaxpr in_shardings = semantic_in_shardings.shardings out_shardings = semantic_out_shardings.shardings - global_in_avals = closed_jaxpr.in_avals - global_out_avals = closed_jaxpr.out_avals log_priority = logging.WARNING if config.log_compiles.value else logging.DEBUG if logger.isEnabledFor(log_priority): logger.log(log_priority, "Compiling %s with global shapes and types %s. " "Argument mapping: %s.", - fun_name, global_in_avals, in_shardings) + module_name, in_avals, in_shardings) # Look at the number of replcas present in the jaxpr. In # lower_sharding_computation, nreps > 1 during `jit(pmap)` cases. This is # handled here so as to deprecate the lower_xla_callable codepath when # `jax.Array` is turned on by default. # TODO(yashkatariya): Remove this when `jit(pmap)` is removed. - nreps = dispatch.jaxpr_replicas(jaxpr) - _raise_warnings_or_errors_for_jit_of_pmap(nreps, backend, fun_name, jaxpr) + nreps = _jaxpr_replicas(jaxpr) + _raise_warnings_or_errors_for_jit_of_pmap(nreps, backend, module_name, jaxpr) in_mlir_shardings: list[JSharding | AUTO | None] | None out_mlir_shardings: list[JSharding | AUTO | None] | None axis_ctx: mlir.AxisContext if nreps == 1: - in_mlir_shardings = map(_to_logical_sharding, global_in_avals, in_shardings) - out_mlir_shardings = map(_to_logical_sharding, global_out_avals, out_shardings) - replicated_args = [False] * len(global_in_avals) + in_mlir_shardings = map(_to_logical_sharding, in_avals, in_shardings) + out_mlir_shardings = map(_to_logical_sharding, out_avals, out_shardings) + replicated_args = [False] * len(in_avals) axis_ctx = sharding_impls.ShardingContext(num_devices, device_assignment, abstract_mesh) num_partitions = num_devices @@ -1942,7 +1973,6 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend, axis_ctx = sharding_impls.ReplicaAxisContext(axis_env) num_partitions = 1 - module_name = f"{api_name}_{fun_name}" if num_devices > 1: unsupported_effects = effects.ordered_effects.filter_in(closed_jaxpr.effects) @@ -1953,32 +1983,34 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend, "The following ordered effects are not supported for " f"more than 1 device: {unsupported_effects}") ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects)) + arg_names = ("",) * num_const_args + jaxpr._debug_info.safe_arg_names(len(in_avals) - num_const_args) with dispatch.log_elapsed_time( "Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time:.9f} sec", - fun_name=str(name_stack), event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT): + fun_name=module_name, event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT): lowering_result = mlir.lower_jaxpr_to_module( module_name, closed_jaxpr, + num_const_args=num_const_args, ordered_effects=ordered_effects, backend=backend, platforms=platforms, axis_context=axis_ctx, - name_stack=name_stack, + in_avals=in_avals, donated_args=donated_invars, replicated_args=replicated_args, arg_shardings=in_mlir_shardings, result_shardings=out_mlir_shardings, in_layouts=in_layouts, out_layouts=out_layouts, - arg_names=jaxpr._debug_info.safe_arg_names(len(jaxpr.invars)), - result_names=jaxpr._debug_info.safe_result_paths(len(jaxpr.outvars)), + arg_names=arg_names, + result_names=jaxpr._debug_info.safe_result_paths(len(out_avals)), num_replicas=nreps, num_partitions=num_partitions, all_default_mem_kind=all_default_mem_kind, input_output_aliases=inout_aliases, propagated_out_mem_kinds=propagated_out_mem_kinds, lowering_parameters=lowering_parameters) - tuple_args = dispatch.should_tuple_args(len(global_in_avals), backend.platform) + tuple_args = dispatch.should_tuple_args(len(in_avals), backend.platform) unordered_effects = list( effects.ordered_effects.filter_not_in(closed_jaxpr.effects)) return (lowering_result.module, lowering_result.keepalive, @@ -1986,86 +2018,50 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend, nreps, tuple_args, lowering_result.shape_poly_state) -@lru_cache(maxsize=2048) -def _create_da_object( # pytype: disable=invalid-annotation - device_assignment: tuple[xc.Device, ...]) -> xc.DeviceList: +@util.cache(max_size=2048, trace_context_in_key=False) +def _create_device_list_cached(device_assignment: tuple[xc.Device, ...] + ) -> xc.DeviceList: return xc.DeviceList(device_assignment) +def _create_device_list( + device_assignment: tuple[xc.Device, ...] | xc.DeviceList | None + ) -> xc.DeviceList | None: + if device_assignment is None or isinstance(device_assignment, xc.DeviceList): + return device_assignment # type: ignore + return _create_device_list_cached(device_assignment) + @weakref_lru_cache -def jaxpr_transfer_mem_kinds( - jaxpr: core.Jaxpr) -> Sequence[sharding_impls.TransferToMemoryKind]: +def jaxpr_transfer_mem_kinds(jaxpr: core.Jaxpr): out = [] # type: ignore for eqn in jaxpr.eqns: if eqn.primitive is dispatch.device_put_p: out.extend(d for d in eqn.params['devices'] - if isinstance(d, sharding_impls.TransferToMemoryKind)) + if isinstance(d, core.MemorySpace)) + elif eqn.primitive.name == 'call_exported': + out.extend(aval.memory_space for aval in eqn.params['exported'].out_avals) + for subjaxpr in core.subjaxprs(jaxpr): out.extend(jaxpr_transfer_mem_kinds(subjaxpr)) return out -def are_all_shardings_default_mem_kind( - da_object: xc.DeviceList | None, shardings -): - if da_object is None: - return True - try: - default_mem_kind = da_object.default_memory_kind - except: - return True +def are_all_shardings_default_mem_kind(shardings): for i in shardings: if isinstance(i, (UnspecifiedValue, AUTO)): continue - if i.memory_kind is None: # pytype: disable=attribute-error + mem_kind = (core.mem_space_to_kind(i) if isinstance(i, core.MemorySpace) + else i.memory_kind) + if mem_kind is None: continue - if i.memory_kind != default_mem_kind: + if mem_kind != 'device': return False return True -memory_kind_propagate_rule: dict[Any, Any] = {} - -@weakref_lru_cache -def get_out_memory_kinds_via_propagation(closed_jaxpr: core.ClosedJaxpr, - in_shardings=None) -> tuple[None | str]: - env = {} # type: ignore - jaxpr = closed_jaxpr.jaxpr - - def read(var): - if type(var) is core.Literal: - return None - return env[var] - - def write(var, val): - env[var] = val - - def _default_rule(prim, num_outvars, *_, **__): - return [None] * num_outvars if prim.multiple_results else None - - if in_shardings is None: - invar_mem_kind = [None] * len(jaxpr.invars) - else: - invar_mem_kind = [None if isinstance(s, (UnspecifiedValue, AUTO)) else s.memory_kind - for s in in_shardings] - safe_map(write, jaxpr.invars, invar_mem_kind) - safe_map(write, jaxpr.constvars, [None] * len(jaxpr.constvars)) - - for eqn in jaxpr.eqns: - in_mem_kinds = safe_map(read, eqn.invars) - rule = memory_kind_propagate_rule.get( - eqn.primitive, partial(_default_rule, eqn.primitive, len(eqn.outvars))) - out_mem_kinds = rule(*in_mem_kinds, **eqn.params) - if not eqn.primitive.multiple_results: - out_mem_kinds = [out_mem_kinds] - safe_map(write, eqn.outvars, out_mem_kinds) - return tuple(safe_map(read, jaxpr.outvars)) - @weakref_lru_cache def get_out_layouts_via_propagation(closed_jaxpr: core.ClosedJaxpr - ) -> tuple[None | DeviceLocalLayout]: - from jax._src import pjit - + ) -> tuple[None | Layout]: env = {} # type: ignore jaxpr = closed_jaxpr.jaxpr @@ -2091,37 +2087,7 @@ def write(var, val): return tuple(safe_map(read, jaxpr.outvars)) -def _get_num_devices( - shardings, device_assignment - ) -> tuple[int, tuple[xc.Device, ...] | None]: - """Number of lowering devices, and the device_assignment to use. - - If all the specified shardings have an abstract mesh, then we are compiling - with abstract devices, and the returned device_assignment is None. - """ - abstract_mesh, any_concrete_sharding = None, False - for s in shardings: - if isinstance(s, UnspecifiedValue): - continue - elif (isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh) and - not s.mesh.empty): - if abstract_mesh is not None and abstract_mesh != s.mesh: - raise ValueError("AbstractMesh should be the same across all " - f"shardings. Got {abstract_mesh} and {s.mesh}") - abstract_mesh = s.mesh - else: - any_concrete_sharding = True - if (any_concrete_sharding and abstract_mesh is not None and - len(device_assignment) != abstract_mesh.size): - raise ValueError( - f"AbstractMesh size: {abstract_mesh.size} does not match the" - f" device assignment size: {len(device_assignment)}") - if any_concrete_sharding or abstract_mesh is None: - return len(device_assignment), device_assignment - return abstract_mesh.size, None - - -MaybeLayout = Sequence[Union[DeviceLocalLayout, AutoLayout, None]] +MaybeLayout = Sequence[Union[Layout, AutoLayout, None]] class AllArgsInfo(NamedTuple): @@ -2130,21 +2096,23 @@ class AllArgsInfo(NamedTuple): debug_info: core.DebugInfo -@lru_cache(maxsize=2048) +@util.cache(max_size=2048, trace_context_in_key=False) def to_gspmd_sharding(s: JSharding, ndim: int) -> GSPMDSharding: if isinstance(s, GSPMDSharding): return s - return GSPMDSharding(s._device_assignment, s._to_xla_hlo_sharding(ndim), - memory_kind=s.memory_kind, - _device_list=getattr(s, '_internal_device_list', None)) + return GSPMDSharding(s._internal_device_list, s._to_xla_hlo_sharding(ndim), + memory_kind=s.memory_kind) def _discharge_refs_jaxpr(closed_jaxpr, in_shardings, in_layouts, donated_invars, out_shardings, out_layouts): - if any(isinstance(e, RefEffect) for e in closed_jaxpr.effects): + if (any(isinstance(e, RefEffect) for e in closed_jaxpr.effects) or + any(isinstance(a, AbstractRef) for a in closed_jaxpr.in_avals)): closed_jaxpr, inout_aliases, mut = _discharge_refs(closed_jaxpr) - in_shardings = (*in_shardings, *(c.sharding for c in mut.in_mut)) - in_layouts = (*in_layouts,) + (None,) * len(mut.in_mut) # TODO(mattjj) + in_shardings = (*in_shardings, *( + pjit.finalize_arg_sharding(c.sharding, c.committed) for c in mut.in_mut)) + in_layouts = (*in_layouts, *(c.format.layout if hasattr(c, 'format') + else None for c in mut.in_mut)) donated_invars = (*donated_invars,) + (False,) * len(mut.in_mut) out_layouts_ = iter(zip(out_shardings, out_layouts)) out_shardings, out_layouts = unzip2( @@ -2159,26 +2127,63 @@ def _discharge_refs_jaxpr(closed_jaxpr, in_shardings, in_layouts, return (closed_jaxpr, inout_aliases, mut, in_shardings, in_layouts, donated_invars, out_shardings, out_layouts) -@lru_cache(maxsize=1024) + +def hoist_constants_as_args( + closed_jaxpr: core.ClosedJaxpr, global_in_avals, in_shardings, in_layouts, + donated_invars, kept_var_idx: set[int], inout_aliases, mut, + all_args_info: AllArgsInfo): + const_args, const_arg_avals = unzip2( + core.jaxpr_const_args(closed_jaxpr.jaxpr) + ) + num_const_args = len(const_args) + if num_const_args: + global_in_avals = list(const_arg_avals) + global_in_avals # type: ignore + ca_shardings = pjit.const_args_shardings(const_args) + in_shardings = ca_shardings + in_shardings # type: ignore + ca_layouts = pjit.const_args_layouts(const_args, const_arg_avals, + ca_shardings) + in_layouts = ca_layouts + in_layouts # type: ignore + + donated_invars = (False,) * num_const_args + donated_invars + kept_var_idx = set(range(num_const_args)).union( + {kv + num_const_args for kv in kept_var_idx}) + if inout_aliases is not None: + inout_aliases = (None,) * num_const_args + inout_aliases + if mut is not None: + mut = MutationData( + in_mut=mut.in_mut, + out_mut=[None if i_idx is None else i_idx + num_const_args + for i_idx in mut.out_mut]) + if all_args_info.debug_info.arg_names is None: + arg_names = None + else: + arg_names = (("",) * num_const_args + all_args_info.debug_info.arg_names) + all_args_info = AllArgsInfo( + list(const_arg_avals) + all_args_info.in_avals, # type: ignore + all_args_info.debug_info._replace(arg_names=arg_names)) + + return (const_args, global_in_avals, in_shardings, in_layouts, donated_invars, + kept_var_idx, inout_aliases, mut, all_args_info) + + +@util.cache(max_size=1024, trace_context_in_key=False) def _abstract_to_concrete_mesh(abstract_mesh, device_assignment): np_dev = np.vectorize(lambda i: device_assignment[i], otypes=[object])(np.arange(len(device_assignment))) return Mesh(np_dev.reshape(abstract_mesh.axis_sizes), - abstract_mesh.axis_names, axis_types=abstract_mesh._axis_types) + abstract_mesh.axis_names, axis_types=abstract_mesh.axis_types) def _concretize_abstract_out_shardings(shardings, avals, device_assignment, out_mem_kinds): if device_assignment is None: return shardings - if len(device_assignment) == 1: - return shardings out = [] for s, a, mem_kind in zip(shardings, avals, out_mem_kinds): - if isinstance(s, UnspecifiedValue) and a.sharding is not None: + if isinstance(s, UnspecifiedValue) and isinstance(a, core.ShapedArray): if a.sharding.mesh.empty: out.append(s) - elif a.sharding.mesh._are_all_axes_auto: + elif a.sharding.mesh._are_all_axes_auto_or_manual: out.append(s) else: spec = (PartitionSpec(*[PartitionSpec.UNCONSTRAINED if sp is None else sp @@ -2192,19 +2197,18 @@ def _concretize_abstract_out_shardings(shardings, avals, device_assignment, return tuple(out) -def _get_context_mesh(context_mesh: Mesh | None) -> Mesh | None: - if context_mesh is None: - return context_mesh +def _get_context_mesh(context_mesh: Mesh) -> Mesh: # Don't update the mesh because the old `with mesh` ctx mgr is set. - if get_concrete_mesh() is None: + if get_concrete_mesh().empty: return context_mesh cur_mesh = get_abstract_mesh() if cur_mesh.empty or context_mesh.empty: return context_mesh if cur_mesh == context_mesh.abstract_mesh: return context_mesh - return Mesh(context_mesh.devices, context_mesh.axis_names, - axis_types=cur_mesh._axis_types) + assert context_mesh.size == cur_mesh.size + return Mesh(context_mesh.devices.reshape(cur_mesh.axis_sizes), + cur_mesh.axis_names, cur_mesh.axis_types) @profiler.annotate_function @@ -2219,7 +2223,7 @@ def lower_sharding_computation( donated_invars: Sequence[bool], *, keep_unused: bool, - context_mesh: Mesh | None, + context_mesh: Mesh, compiler_options_kvs: tuple[tuple[str, Any], ...], lowering_platforms: tuple[str, ...] | None, lowering_parameters: mlir.LoweringParameters, @@ -2230,16 +2234,14 @@ def lower_sharding_computation( The caller of this code can pass in a singleton UNSPECIFIED because the number of out_avals might not be known at that time and lower_sharding_computation calculates the number of out_avals so it can apply - the singleton UNSPECIFIED to all out_avals. - """ + the singleton UNSPECIFIED to all out_avals.""" auto_spmd_lowering = check_if_any_auto( it.chain.from_iterable([in_shardings, out_shardings])) all_args_info = AllArgsInfo(closed_jaxpr.in_avals, closed_jaxpr.jaxpr._debug_info) - closed_jaxpr, donated_invars, kept_var_idx, name_stack = _dce_jaxpr( - closed_jaxpr, api_name, fun_name, keep_unused, donated_invars, - auto_spmd_lowering) + closed_jaxpr, donated_invars, kept_var_idx = _dce_jaxpr( + closed_jaxpr, keep_unused, donated_invars, auto_spmd_lowering) in_shardings = tuple(s for i, s in enumerate(in_shardings) if i in kept_var_idx) in_layouts = tuple(l for i, l in enumerate(in_layouts) if i in kept_var_idx) @@ -2252,11 +2254,19 @@ def lower_sharding_computation( global_in_avals = closed_jaxpr.in_avals global_out_avals = closed_jaxpr.out_avals + if lowering_parameters.hoist_constants_as_args: + (const_args, global_in_avals, in_shardings, in_layouts, donated_invars, + kept_var_idx, inout_aliases, mut, all_args_info) = hoist_constants_as_args( + closed_jaxpr, global_in_avals, in_shardings, in_layouts, + donated_invars, kept_var_idx, inout_aliases, mut, all_args_info) + else: + const_args = [] + # If layout is propagated, then set the out_layout in the top module to AUTO # so that XLA can override the entry_computation_layout. The propagated # layout will be set via a custom call. out_layouts_via_prop = get_out_layouts_via_propagation(closed_jaxpr) - out_layouts = tuple(DeviceLocalLayout.AUTO if p is not None else o + out_layouts = tuple(Layout.AUTO if p is not None else o for o, p in safe_zip(out_layouts, out_layouts_via_prop)) assert len(out_shardings) == len(out_layouts) == len(global_out_avals), ( @@ -2264,40 +2274,30 @@ def lower_sharding_computation( context_mesh = _get_context_mesh(context_mesh) - devices_from_context = (None if context_mesh is None or context_mesh.empty - else context_mesh._flat_devices_tuple) + devices_from_context = (None if context_mesh.empty else + context_mesh._flat_devices_tuple) # Device assignment across all inputs, outputs and shardings inside jaxpr # should be the same. unique_intermediate_shardings = util.stable_unique( dispatch.get_intermediate_shardings(jaxpr)) - unique_in_shardings = util.stable_unique(in_shardings) + unique_const_shardings = util.stable_unique(in_shardings[:len(const_args)]) + unique_in_shardings = util.stable_unique(in_shardings[len(const_args):]) unique_out_shardings = util.stable_unique(out_shardings) - backend, device_assignment = _get_and_check_device_assignment( + # TODO(necula): Replace `None` with `source_info` for unique_const_shardings + backend, device_assignment, num_devices = _get_and_check_device_assignment( it.chain( - ((i, MismatchType.ARG_SHARDING, None) for i in unique_in_shardings), - ((o, MismatchType.OUT_SHARDING, None) for o in unique_out_shardings), - ((js, MismatchType.SHARDING_INSIDE_COMPUTATION, source_info) + ((i, stages.MismatchType.ARG_SHARDING, None) for i in unique_in_shardings), + ((c, stages.MismatchType.CONST_SHARDING, None) for c in unique_const_shardings), + ((o, stages.MismatchType.OUT_SHARDING, None) for o in unique_out_shardings), + ((js, stages.MismatchType.SHARDING_INSIDE_COMPUTATION, source_info) for js, source_info in unique_intermediate_shardings)), devices_from_context) unique_intermediate_shardings = [js for js, _ in unique_intermediate_shardings] - - # TODO(parkers): One _raw_platform has been unified with platform, - # change this back to just read platform. - platforms = lowering_platforms or ( - getattr(backend, "_raw_platform", backend.platform),) + unique_in_shardings = unique_in_shardings | unique_const_shardings # type: ignore + del unique_const_shardings prim_requires_devices = dispatch.jaxpr_has_prim_requiring_devices(jaxpr) - # TODO(yashkatariya): All device specific logic should go in compilation - # but this requires a big refactor. The current `_get_num_devices` logic - # is good enough to lower with AbstractMesh but cannot be compiled. Once - # I refactor, this will also work well with mesh being provided at - # compile time. - # Sets device_assignment to None if only abstractMesh and unspecified exists. - num_devices, device_assignment = _get_num_devices( - it.chain(unique_in_shardings, unique_out_shardings, - unique_intermediate_shardings), - device_assignment) if device_assignment is None: if lowering_platforms is None: raise ValueError( @@ -2309,36 +2309,68 @@ def lower_sharding_computation( "AbstractMesh cannot be used when jaxpr contains primitives that" " require devices to be present during lowering.") + # For device_assignment == 1, this doesn't matter. + if device_assignment is not None and len(device_assignment) > 1: + rep_gs = GSPMDSharding.get_replicated(device_assignment) + in_shardings = tuple( + rep_gs if (isinstance(s, UnspecifiedValue) and + aval is not core.abstract_token and aval.ndim == 0) + else s for s, aval in zip(in_shardings, global_in_avals)) + + for a in global_out_avals: + if (a is not core.abstract_token and not a.sharding.mesh.empty and + a.sharding.mesh.are_all_axes_explicit and + device_assignment is not None and + len(device_assignment) != a.sharding.mesh.size): + raise ValueError( + f"Length of device assignment {len(device_assignment)} is not equal" + f" to the size of the mesh {a.sharding.mesh.size} of aval" + f" {a.str_short(True, True)}. Please enter your `jit` into a mesh" + " context via `jax.set_mesh`.") + + # TODO(parkers): One _raw_platform has been unified with platform, + # change this back to just read platform. + platforms = lowering_platforms or ( + getattr(backend, "_raw_platform", backend.platform),) + + device_list = _create_device_list(device_assignment) + transfer_mem_kind_in_jaxpr = jaxpr_transfer_mem_kinds(jaxpr) + committed = bool( devices_from_context or num_devices > 1 or any(not isinstance(s, UnspecifiedValue) for s in it.chain( - unique_in_shardings, unique_out_shardings, unique_intermediate_shardings))) - - da_object = (_create_da_object(tuple(device_assignment)) - if device_assignment is not None else None) + unique_in_shardings, unique_out_shardings, + unique_intermediate_shardings)) + or transfer_mem_kind_in_jaxpr + ) - transfer_mem_kind_in_jaxpr = jaxpr_transfer_mem_kinds(jaxpr) all_default_mem_kind = are_all_shardings_default_mem_kind( - da_object, it.chain(unique_in_shardings, unique_out_shardings, unique_intermediate_shardings, transfer_mem_kind_in_jaxpr)) # pytype: disable=wrong-arg-types if all_default_mem_kind: propagated_out_mem_kinds = (None,) * len(global_out_avals) else: - propagated_out_mem_kinds = get_out_memory_kinds_via_propagation( - closed_jaxpr, in_shardings) + propagated_out_mem_kinds = tuple( + core.mem_space_to_kind(o.memory_space) for o in closed_jaxpr.out_avals) # type: ignore out_shardings = _concretize_abstract_out_shardings( out_shardings, global_out_avals, device_assignment, propagated_out_mem_kinds) - # 2. Build up the HLO + global_in_avals = [core.update_aval_with_sharding(a, sh) + if isinstance(a, core.ShapedArray) else a + for a, sh in zip(global_in_avals, in_shardings)] + global_out_avals = [core.update_aval_with_sharding(a, sh) + if isinstance(a, core.ShapedArray) else a + for a, sh in zip(global_out_avals, out_shardings)] + + ############################ Build up the stableHLO ###################### abstract_mesh = None if prim_requires_devices: - assert da_object is not None + assert device_list is not None for sharding in it.chain(unique_in_shardings, unique_out_shardings, unique_intermediate_shardings): if isinstance(sharding, NamedSharding): @@ -2355,12 +2387,17 @@ def lower_sharding_computation( semantic_out_shardings = SemanticallyEqualShardings( out_shardings, global_out_avals) + jaxpr_util.maybe_dump_jaxpr_to_file(fun_name, closed_jaxpr.jaxpr) + module_name = util.wrap_name(api_name, fun_name) + (module, keepalive, host_callbacks, unordered_effects, ordered_effects, nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo( - closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings, - semantic_out_shardings, in_layouts, out_layouts, num_devices, - tuple(da_object) if prim_requires_devices else None, # type: ignore[arg-type] - donated_invars, name_stack, all_default_mem_kind, inout_aliases, + closed_jaxpr, module_name, backend, + len(const_args), tuple(global_in_avals), + semantic_in_shardings, semantic_out_shardings, + in_layouts, out_layouts, num_devices, + tuple(device_list) if prim_requires_devices else None, # type: ignore[arg-type] + donated_invars, all_default_mem_kind, inout_aliases, propagated_out_mem_kinds, platforms, lowering_parameters=lowering_parameters, abstract_mesh=abstract_mesh) @@ -2371,11 +2408,13 @@ def lower_sharding_computation( # because we calculate the device_assignment and backend before in_shardings, # etc are pruned. return MeshComputation( - str(name_stack), + module_name, module, + const_args, donated_invars, platforms, compiler_options_kvs, + device_list, global_in_avals=global_in_avals, global_out_avals=global_out_avals, in_shardings=in_shardings, @@ -2390,7 +2429,6 @@ def lower_sharding_computation( kept_var_idx=kept_var_idx, mut=mut, backend=backend, - device_assignment=da_object, num_devices=num_devices, committed=committed, in_layouts=in_layouts, @@ -2410,7 +2448,7 @@ def _to_logical_sharding( return None if isinstance(sharding, AUTO): return sharding - elif isinstance(aval, (ShapedArray, DShapedArray, AbstractRef)): + elif isinstance(aval, (ShapedArray, AbstractRef)): assert isinstance(sharding, JSharding) return sharding elif isinstance(aval, core.AbstractToken): @@ -2419,35 +2457,60 @@ def _to_logical_sharding( raise TypeError(aval) -class MeshComputation(stages.XlaLowering): +class MeshComputation(stages.Lowering): _hlo: ir.Module _executable: MeshExecutable | None def __init__(self, name: str, hlo: ir.Module, + const_args: list[ArrayLike], donated_invars: Sequence[bool], platforms: Sequence[str], compiler_options_kvs: tuple[tuple[str, Any], ...], + device_assignment: xc.DeviceList | tuple[xc.Device, ...] | None, **compile_args): self._name = name self._hlo = hlo + self.const_args = const_args self._donated_invars = donated_invars self._platforms = platforms self._compiler_options_kvs = compiler_options_kvs + self._device_list = _create_device_list(device_assignment) self.compile_args = compile_args self._executable = None - # -- stages.XlaLowering overrides + # -- stages.Lowering overrides def stablehlo(self) -> ir.Module: return self._hlo - def compile(self, compiler_options=None) -> MeshExecutable: + def compile(self, compiler_options=None, *, device_assignment=None, + ) -> MeshExecutable: t_compiler_options = (() if compiler_options is None else tuple(compiler_options.items())) compiler_options_kvs = self._compiler_options_kvs + t_compiler_options - if self._executable is None or compiler_options_kvs: + + device_list = _create_device_list(device_assignment) + if device_list is None: + compilation_device_list = self._device_list + else: + if (self._device_list is not None and + self._device_list != device_list): + raise ValueError( + "device_assignment passed to `.compile` must match the" + " device_assignment calculated from array shardings and" + " out_shardings. Got device ids passed to compile" + f" {[d.id for d in device_list]} on platform" + f" {device_list[0].platform.upper()} and devices ids" + " calculated from array shardings and out_shardings" + f" {[d.id for d in self._device_list]} on platform" + f" {self._device_list[0].platform.upper()}") + compilation_device_list = device_list + assert isinstance(compilation_device_list, (type(None), xc.DeviceList)) + + if self._executable is None or compiler_options_kvs or device_assignment: executable = UnloadedMeshExecutable.from_hlo( self._name, self._hlo, **self.compile_args, - compiler_options_kvs=compiler_options_kvs) + compiler_options_kvs=compiler_options_kvs, + device_list=compilation_device_list) if not compiler_options_kvs: self._executable = executable return executable @@ -2460,16 +2523,45 @@ def cost_analysis(self) -> dict[str, float]: "Lowered.cost_analysis not implemented on platform " f"'{backend.platform}'. Use compile().cost_analysis() for " "post-compilation cost estimates.") - return xe.hlo_module_cost_analysis(backend, self.hlo().as_hlo_module()) + return _jax.hlo_module_cost_analysis(backend, self.hlo().as_hlo_module()) + + +def get_op_sharding_from_executable( + executable) -> tuple[Sequence[xc.OpSharding], Sequence[xc.OpSharding]]: + in_op_shardings: list[xc.OpSharding] = [] + parameter_shardings_from_xla = executable.get_parameter_shardings() + if parameter_shardings_from_xla is not None: + in_op_shardings = parameter_shardings_from_xla + + out_op_shardings: list[xc.OpSharding] = [] + output_shardings_from_xla = executable.get_output_shardings() + if output_shardings_from_xla is not None: + out_op_shardings = output_shardings_from_xla + + return in_op_shardings, out_op_shardings + + +def get_pspec_from_executable( + executable, mesh: Mesh +) -> tuple[tuple[PartitionSpec, ...], tuple[PartitionSpec, ...]]: + input_op_s, output_op_s = get_op_sharding_from_executable(executable) + in_pspec: list[PartitionSpec] = [] + for s in input_op_s: + in_pspec.extend(sharding_impls.parse_flatten_op_sharding(s, mesh)) + + out_pspec: list[PartitionSpec] = [] + for s in output_op_s: + out_pspec.extend(sharding_impls.parse_flatten_op_sharding(s, mesh)) + return tuple(in_pspec), tuple(out_pspec) def get_out_shardings_from_executable( xla_executable, - device_assignment: Sequence[xc.Device], + device_list: xc.DeviceList, num_out_avals: int, num_ordered_effects: int, ) -> Sequence[sharding_impls.GSPMDSharding] | None: - from jax._src import pjit + assert isinstance(device_list, xc.DeviceList) try: omk = xla_executable.get_output_memory_kinds()[0] @@ -2482,11 +2574,11 @@ def get_out_shardings_from_executable( # When the device assignment only has 1 device, SPMD partitioner will not run. # Hence the op shardings will not be set on the `hlo_module`. - if len(device_assignment) == 1: - return [sharding_impls.GSPMDSharding.get_replicated(device_assignment, memory_kind=mk) + if len(device_list) == 1: + return [sharding_impls.GSPMDSharding.get_replicated(device_list, memory_kind=mk) for mk in omk] - _, out_op_shardings = pjit.get_op_sharding_from_executable(xla_executable) + _, out_op_shardings = get_op_sharding_from_executable(xla_executable) if not out_op_shardings: return None @@ -2508,23 +2600,22 @@ def get_out_shardings_from_executable( assert len(out_op_shardings) == num_out_avals == len(omk), ( len(out_op_shardings), num_out_avals, len(omk)) - return [sharding_impls.GSPMDSharding(device_assignment, os, memory_kind=mk) + return [sharding_impls.GSPMDSharding(device_list, os, memory_kind=mk) for os, mk in safe_zip(out_op_shardings, omk)] def _get_in_shardings_from_xla( - xla_executable, device_assignment: Sequence[xc.Device], num_in_avals: int, + xla_executable, device_list: xc.DeviceList, num_in_avals: int, num_ordered_effects: int ) -> Sequence[GSPMDSharding] | None: """Returns input shardings from XLA.""" - from jax._src import pjit - # When the device assignment only has 1 device, SPMD partitioner will not run. # Hence the op shardings will not be set on the `hlo_module`. - if len(device_assignment) == 1: - return [GSPMDSharding.get_replicated(device_assignment)] * num_in_avals + assert isinstance(device_list, xc.DeviceList) + if len(device_list) == 1: + return [GSPMDSharding.get_replicated(device_list)] * num_in_avals - in_op_shardings, _ = pjit.get_op_sharding_from_executable(xla_executable) + in_op_shardings, _ = get_op_sharding_from_executable(xla_executable) if not in_op_shardings: return None @@ -2534,8 +2625,7 @@ def _get_in_shardings_from_xla( assert len(in_op_shardings) == num_in_avals, ( len(in_op_shardings), num_in_avals) - return [GSPMDSharding(device_assignment, os) - for os in in_op_shardings] + return [GSPMDSharding(device_list, os) for os in in_op_shardings] # TODO(yashkatariya): Remove this function after `AUTO` can return shardings @@ -2543,9 +2633,7 @@ def _get_in_shardings_from_xla( def _get_mesh_pspec_shardings_from_executable( xla_executable, mesh: Mesh ) -> tuple[Sequence[NamedSharding], Sequence[NamedSharding]]: - from jax._src import pjit - - in_pspec, out_pspec = pjit.get_pspec_from_executable(xla_executable, mesh) + in_pspec, out_pspec = get_pspec_from_executable(xla_executable, mesh) return ([NamedSharding(mesh, i) for i in in_pspec], [NamedSharding(mesh, o) for o in out_pspec]) @@ -2557,7 +2645,8 @@ def _gspmd_to_named_sharding( assert isinstance(out_s, GSPMDSharding) assert isinstance(orig_in_s, NamedSharding) assert isinstance(orig_in_s.mesh, Mesh) - if out_aval is not None and not out_aval.sharding.mesh.empty: + if (out_aval is not None and not out_aval.sharding.mesh.empty and + not out_aval.sharding.mesh._any_axis_manual): mesh = _abstract_to_concrete_mesh( out_aval.sharding.mesh, out_s._device_assignment) else: @@ -2565,15 +2654,6 @@ def _gspmd_to_named_sharding( return sharding_impls._gspmd_to_named_sharding_via_mesh(out_s, mesh) _orig_out_sharding_handlers[NamedSharding] = _gspmd_to_named_sharding -def _gspmd_to_positional_sharding( - out_s: GSPMDSharding, out_aval, orig_in_s: PositionalSharding - ) -> PositionalSharding: - assert isinstance(out_s, GSPMDSharding) - assert isinstance(orig_in_s, PositionalSharding) - return sharding_impls._op_sharding_to_pos_sharding( - out_s._hlo_sharding, orig_in_s._device_assignment, out_s.memory_kind) -_orig_out_sharding_handlers[PositionalSharding] = _gspmd_to_positional_sharding # type: ignore - def _gspmd_to_single_device_sharding( out_s: GSPMDSharding, out_aval, orig_in_s: SingleDeviceSharding ) -> SingleDeviceSharding: @@ -2639,9 +2719,9 @@ def maybe_recover_user_shardings( return new_shardings -def is_user_xla_layout_equal(ul: DeviceLocalLayout | AutoLayout, - xl: DeviceLocalLayout) -> bool: - if isinstance(ul, DeviceLocalLayout) and not ul._tiling: +def is_user_xla_layout_equal(ul: Layout | AutoLayout, + xl: Layout) -> bool: + if isinstance(ul, Layout) and not ul.tiling: return ul.major_to_minor == xl.major_to_minor else: return ul == xl @@ -2649,7 +2729,7 @@ def is_user_xla_layout_equal(ul: DeviceLocalLayout | AutoLayout, def _get_layouts_from_executable( xla_executable, in_layouts, out_layouts, num_ordered_effects -) -> tuple[Sequence[DeviceLocalLayout | None], Sequence[DeviceLocalLayout | None]]: +) -> tuple[Sequence[Layout | None], Sequence[Layout | None]]: try: in_layouts_xla = xla_executable.get_parameter_layouts() out_layouts_xla = xla_executable.get_output_layouts() @@ -2662,8 +2742,8 @@ def _get_layouts_from_executable( new_in_layouts = [] for x, l in safe_zip(in_layouts_xla, in_layouts): - x = DeviceLocalLayout.from_pjrt_layout(x) - if isinstance(l, DeviceLocalLayout) and not is_user_xla_layout_equal(l, x): + x = Layout.from_pjrt_layout(x) + if isinstance(l, Layout) and not is_user_xla_layout_equal(l, x): raise AssertionError( f"Unexpected XLA layout override: (XLA) {x} != {l} " f"(User input layout)") @@ -2673,8 +2753,8 @@ def _get_layouts_from_executable( new_out_layouts = [] for x, l in safe_zip(out_layouts_xla, out_layouts): - x = DeviceLocalLayout.from_pjrt_layout(x) - if isinstance(l, DeviceLocalLayout) and not is_user_xla_layout_equal(l, x): + x = Layout.from_pjrt_layout(x) + if isinstance(l, Layout) and not is_user_xla_layout_equal(l, x): raise AssertionError( f"Unexpected XLA layout override: (XLA) {x} != {l} " f"(User output layout)") @@ -2682,8 +2762,8 @@ def _get_layouts_from_executable( # (tiling, etc) even if the user layout does not specify tiling. new_out_layouts.append(x) - assert all(isinstance(i, DeviceLocalLayout) for i in new_in_layouts) - assert all(isinstance(o, DeviceLocalLayout) for o in new_out_layouts) + assert all(isinstance(i, Layout) for i in new_in_layouts) + assert all(isinstance(o, Layout) for o in new_out_layouts) return new_in_layouts, new_out_layouts @@ -2716,7 +2796,6 @@ def create_compile_options( num_partitions=num_partitions, device_assignment=xla_device_assignment, use_spmd_partitioning=spmd_lowering, - use_shardy_partitioner=config.use_shardy_partitioner.value, use_auto_spmd_partitioning=auto_spmd_lowering, env_options_overrides=compiler_options, fdo_profile=fdo_profile, @@ -2757,13 +2836,13 @@ def _cached_compilation(computation, name, mesh, spmd_lowering, fun_name=name, event=dispatch.BACKEND_COMPILE_EVENT): xla_executable = compiler.compile_or_get_cached( backend, computation, dev, compile_options, host_callbacks, - pgle_profiler) + da, pgle_profiler) return xla_executable def _maybe_get_and_check_in_shardings( - xla_executable, in_shardings, device_assignment, - global_in_avals, num_ordered_effects): + xla_executable, in_shardings, device_list, global_in_avals, + num_ordered_effects): """Returns in_shardings extracted from XLA or checks and returns original shardings. @@ -2774,8 +2853,7 @@ def _maybe_get_and_check_in_shardings( If in_sharding is unspecified, then the sharding returned by XLA is returned. """ in_shardings_xla = _get_in_shardings_from_xla( - xla_executable, device_assignment, len(global_in_avals), - num_ordered_effects) + xla_executable, device_list, len(global_in_avals), num_ordered_effects) if in_shardings_xla is None: return in_shardings @@ -2793,7 +2871,7 @@ def _maybe_get_and_check_in_shardings( # MANUAL HloSharding comes from other partitioning frameworks. if (not dtypes.issubdtype(aval.dtype, dtypes.extended) and not xla_hlo_s.is_manual() and - (not op_shardings.are_op_shardings_equal(xla_hlo_s, orig_hlo_s))): + (not op_shardings.are_hlo_shardings_equal(xla_hlo_s, orig_hlo_s))): raise AssertionError( f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} " "(User sharding)") @@ -2806,11 +2884,11 @@ def _maybe_get_and_check_in_shardings( def _maybe_get_and_check_out_shardings( - xla_executable, out_shardings, device_assignment, global_out_avals, + xla_executable, out_shardings, device_list, global_out_avals, num_ordered_effects ): out_shardings_xla = get_out_shardings_from_executable( - xla_executable, device_assignment, len(global_out_avals), + xla_executable, device_list, len(global_out_avals), num_ordered_effects) if out_shardings_xla is None: return out_shardings @@ -2837,7 +2915,7 @@ def _maybe_get_and_check_out_shardings( # MANUAL HloSharding comes from other partitioning frameworks. if (not dtypes.issubdtype(aval.dtype, dtypes.extended) and not xla_hlo_s.is_manual() and - (not op_shardings.are_op_shardings_equal(xla_hlo_s, orig_hlo_s) or + (not op_shardings.are_hlo_shardings_equal(xla_hlo_s, orig_hlo_s) or xla_s.memory_kind != orig.memory_kind)): # pytype: disable=attribute-error raise AssertionError( f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} " @@ -2863,11 +2941,22 @@ def get_prop_to_input_output(in_shardings, out_shardings, return allow_prop_to_inputs, allow_prop_to_outputs +def maybe_concretize_mesh(sharding, da: xc.DeviceList): + if (isinstance(sharding, NamedSharding) and + isinstance(sharding.mesh, AbstractMesh)): + if sharding.mesh.size != len(da): + raise ValueError( + f"The size of abstract mesh {sharding.mesh.size} in {sharding} must" + f" match the length of device assignment: {len(da)}") + return sharding.update(mesh=_abstract_to_concrete_mesh(sharding.mesh, da)) + return sharding + + @dataclasses.dataclass class UnloadedMeshExecutable: xla_executable: Any - device_assignment: xc.DeviceList | Sequence[xc.Device] - backend: xb.XlaBackend + device_list: xc.DeviceList + backend: xc.Client input_avals: Sequence[ShapedArray] input_shardings: Sequence[JSharding] output_avals: Sequence[ShapedArray] @@ -2881,9 +2970,9 @@ class UnloadedMeshExecutable: kept_var_idx: set[int] mut: MutationData | None auto_spmd_lowering: bool - xla_in_layouts: Sequence[DeviceLocalLayout | None] - dispatch_in_layouts: Sequence[DeviceLocalLayout | None] - xla_out_layouts: Sequence[DeviceLocalLayout | None] + xla_in_layouts: Sequence[Layout | None] + dispatch_in_layouts: Sequence[Layout | None] + xla_out_layouts: Sequence[Layout | None] all_args_info: AllArgsInfo | None pgle_profiler: profiler.PGLEProfiler | None @@ -2905,7 +2994,8 @@ def load(self) -> MeshExecutable: self.input_shardings, self.output_shardings, self.auto_spmd_lowering, self.kept_var_idx, self.xla_in_layouts, self.dispatch_in_layouts, - self.xla_out_layouts, self.all_args_info, self) + self.xla_out_layouts, self.mut, self.all_args_info, + self) @staticmethod def from_hlo(name: str, @@ -2922,8 +3012,8 @@ def from_hlo(name: str, host_callbacks: list[Any], keepalive: Any, kept_var_idx: set[int], - backend: xb.XlaBackend, - device_assignment: xc.DeviceList | Sequence[xc.Device] | None, + backend: xc.Client, + device_list: xc.DeviceList | None, committed: bool, in_layouts: MaybeLayout, out_layouts: MaybeLayout, @@ -2938,20 +3028,19 @@ def from_hlo(name: str, context_mesh: Mesh | None = None, ) -> MeshExecutable: del num_devices # For compilation, we have an actual device_assignment - if (device_assignment is None or - any(isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh) - for s in it.chain(in_shardings, out_shardings))): + if device_list is None: raise RuntimeError( - "A jitted computation cannot contain AbstractMesh in in_shardings and" - " out_shardings during compilation. You can use `jax.export` to " - " lower with an AbstractMesh and later compile with concrete devices.") + "device_assignment cannot be `None` during compilation. Please pass a" + " tuple of devices to `.compile(device_assignment=)`") + + assert isinstance(device_list, xc.DeviceList) + in_shardings = tuple(maybe_concretize_mesh(i, device_list) + for i in in_shardings) + out_shardings = tuple(maybe_concretize_mesh(o, device_list) + for o in out_shardings) + if shape_poly_state is not None and shape_poly_state.uses_dim_vars: hlo = mlir.refine_polymorphic_shapes(hlo) - if isinstance(device_assignment, xc.DeviceList): - da = device_assignment - else: - da = _create_da_object(tuple(device_assignment)) - del device_assignment allow_prop_to_inputs, allow_prop_to_outputs = get_prop_to_input_output( in_shardings, out_shardings, len(ordered_effects)) @@ -2967,34 +3056,32 @@ def from_hlo(name: str, xla_executable = _cached_compilation( hlo, name, mesh, spmd_lowering, tuple_args, auto_spmd_lowering, allow_prop_to_inputs, - allow_prop_to_outputs, tuple(host_callbacks), backend, da, pmap_nreps, - compiler_options_kvs, pgle_profiler) - - orig_out_shardings = out_shardings + allow_prop_to_outputs, tuple(host_callbacks), backend, device_list, + pmap_nreps, compiler_options_kvs, pgle_profiler) if auto_spmd_lowering: assert mesh is not None in_shardings_xla, out_shardings_xla = _get_mesh_pspec_shardings_from_executable( xla_executable, mesh) - in_shardings = [x if isinstance(i, AUTO) else i + in_shardings = [x if isinstance(i, AUTO) else i # type: ignore for x, i in safe_zip(in_shardings_xla, in_shardings)] - out_shardings = [x if isinstance(o, AUTO) else o + out_shardings = [x if isinstance(o, AUTO) else o # type: ignore for x, o in safe_zip(out_shardings_xla, out_shardings)] else: if pmap_nreps == 1: assert mesh is None in_shardings = _maybe_get_and_check_in_shardings( - xla_executable, in_shardings, tuple(da), global_in_avals, + xla_executable, in_shardings, device_list, global_in_avals, len(ordered_effects)) out_shardings = _maybe_get_and_check_out_shardings( - xla_executable, out_shardings, tuple(da), global_out_avals, + xla_executable, out_shardings, device_list, global_out_avals, len(ordered_effects)) else: - in_shardings, out_shardings, committed, da = _get_metadata_jit_pmap( + in_shardings, out_shardings, committed, device_list = _get_metadata_jit_pmap( xla_executable.local_devices(), len(in_shardings), len(out_shardings)) - # xla_in_layouts are all either None or DeviceLocalLayout. Even default - # layout are concrete layouts and they are used in `compiled.input_layouts` + # xla_in_layouts are all either None or Layout. Even default + # layout are concrete layouts and they are used in `compiled.input_formats` # to return concrete layouts to users. # `dispatch_in_layouts` replaces default layouts with `None` to simplify # dispatch logic downstream. @@ -3010,12 +3097,12 @@ def from_hlo(name: str, in_shardings, out_shardings, global_in_avals, global_out_avals, intermediate_shardings, context_mesh) - in_shardings = finalize_shardings(in_shardings, da) - out_shardings = finalize_shardings(out_shardings, da) + in_shardings = finalize_shardings(in_shardings, device_list) + out_shardings = finalize_shardings(out_shardings, device_list) return UnloadedMeshExecutable( xla_executable=xla_executable, - device_assignment=da, + device_list=device_list, backend=backend, input_avals=global_in_avals, input_shardings=in_shardings, @@ -3045,7 +3132,8 @@ class MeshExecutableFastpathData(NamedTuple): out_avals: Sequence[ShapedArray] out_committed: Sequence[bool] kept_var_bitvec: Iterable[bool] - in_device_local_layouts: Sequence[DeviceLocalLayout | None] + in_device_local_layouts: Sequence[Layout | None] + const_args: Sequence[ArrayLike] @dataclasses.dataclass(frozen=True, kw_only=True) @@ -3085,24 +3173,24 @@ def reflatten_outputs_for_dispatch(out_tree, out_flat): return tree_util.dispatch_registry.flatten(out_unflat, None) -class MeshExecutable(stages.XlaExecutable): +class MeshExecutable(stages.Executable): __slots__ = [ "xla_executable", "_unsafe_call", "build_unsafe_call", "in_avals", "out_avals", "_in_shardings", "_out_shardings", "_auto_spmd_lowering", "_kept_var_idx", "_xla_in_layouts", "_dispatch_in_layouts", - "_xla_out_layouts", "_all_args_info", "_unloaded_executable", + "_xla_out_layouts", "_mut", "_all_args_info", "_unloaded_executable", ] def __init__(self, xla_executable, build_unsafe_call, in_avals, out_avals, in_shardings, out_shardings, auto_spmd_lowering, kept_var_idx, - xla_in_layouts, dispatch_in_layouts, xla_out_layouts, + xla_in_layouts, dispatch_in_layouts, xla_out_layouts, mut, all_args_info: AllArgsInfo | None = None, unloaded_executable=None): self.xla_executable = xla_executable self.build_unsafe_call = build_unsafe_call # in_avals is a list of global and local avals. Aval is global if input # is a GDA or jax.Array else local. - self.in_avals = in_avals + self.in_avals = in_avals # includes the const_args self.out_avals = out_avals self._unsafe_call = None self._in_shardings = in_shardings @@ -3112,6 +3200,7 @@ def __init__(self, xla_executable, build_unsafe_call, in_avals, out_avals, self._xla_in_layouts = xla_in_layouts self._dispatch_in_layouts = dispatch_in_layouts self._xla_out_layouts = xla_out_layouts + self._mut = mut self._all_args_info = all_args_info self._unloaded_executable = unloaded_executable @@ -3121,66 +3210,68 @@ def unsafe_call(self) -> Callable[..., Any]: self._unsafe_call = self.build_unsafe_call() return self._unsafe_call # type: ignore - # -- stages.XlaExecutable overrides + # -- stages.Executable overrides def xla_extension_executable(self): return self.xla_executable def call(self, *args): args_after_dce = [a for i, a in enumerate(args) if i in self._kept_var_idx] - if self._all_args_info is None: - kept_args = args_after_dce - ref_avals = self.in_avals - # TODO(necula): ensure we have actual debug info; need debug info - # before DCE. - # See https://github.com/jax-ml/jax/issues/26480. - debug_info = core.DebugInfo( - "MeshExecutable", "", - tuple(f"args[{i}]" for i in range(len(args))), ()) + if (self._all_args_info is not None and + self._all_args_info.debug_info.arg_names is not None): + arg_names_after_dce = [ + n for i, n in enumerate(self._all_args_info.debug_info.arg_names) + if i in self._kept_var_idx] else: - kept_args = args - ref_avals = self._all_args_info.in_avals - debug_info = self._all_args_info.debug_info - - all_arg_avals = map(core.abstractify, kept_args) - check_arg_avals_for_call(ref_avals, all_arg_avals, debug_info) - check_array_xla_sharding_layout_match( - args_after_dce, self._in_shardings, self._xla_in_layouts, debug_info, - self._kept_var_idx) - return self.unsafe_call(*args) # pylint: disable=not-callable - - def input_shardings(self) -> Sequence[JSharding]: - return self._in_shardings - - def output_shardings(self) -> Sequence[JSharding]: - return self._out_shardings + arg_names_after_dce = ("",) * len(args_after_dce) - def input_layouts(self): - return [Layout(l, s) - for l, s in safe_zip(self._xla_in_layouts, self._in_shardings)] - - def output_layouts(self): - return [Layout(l, s) - for l, s in safe_zip(self._xla_out_layouts, self._out_shardings)] + if self._all_args_info is not None: + # We check all args before DCE + check_arg_avals_for_call(self._all_args_info.in_avals, + map(core.shaped_abstractify, args), + self._all_args_info.debug_info) + else: + # We can only check the args after DCE + check_arg_avals_for_call(self.in_avals, + map(core.shaped_abstractify, args_after_dce), + core.DebugInfo("MeshExecutable", "", + arg_names_after_dce, None)) + if not self._mut: + check_array_xla_sharding_layout_match( + args_after_dce, self._in_shardings, self._xla_in_layouts, + arg_names_after_dce) + else: + args_after_dce = [*args_after_dce, *self._mut.in_mut] + arg_names_after_dce += (("",) * len(self._mut.in_mut)) + check_array_xla_sharding_layout_match( + args_after_dce, self._in_shardings, self._xla_in_layouts, + arg_names_after_dce) + return self.unsafe_call(*args) # pylint: disable=not-callable - def create_cpp_call(self, no_kwargs, in_tree, out_tree): + def create_cpp_call(self, params: stages.CompiledCallParams): if not (isinstance(self.unsafe_call, ExecuteReplicated) and not self.unsafe_call.has_unordered_effects and not self.unsafe_call.has_host_callbacks): return None def aot_cache_miss(*args, **kwargs): - params = stages.CompiledCallParams(self, no_kwargs, in_tree, out_tree) + # args do not include the const args. + # See https://docs.jax.dev/en/latest/internals/constants.html. outs, out_flat, args_flat = stages.Compiled.call(params, *args, **kwargs) - out_flat, out_tree_dispatch = reflatten_outputs_for_dispatch( - out_tree, out_flat) - use_fastpath = (all(isinstance(x, xc.ArrayImpl) for x in out_flat)) + + if not params.is_high: + out_flat, out_tree_dispatch = reflatten_outputs_for_dispatch( + params.out_tree, out_flat) + use_fastpath = (all(isinstance(x, xc.ArrayImpl) for x in out_flat) + and not self._mut) + else: + use_fastpath = False if use_fastpath: out_avals = [o.aval for o in out_flat] out_committed = [o._committed for o in out_flat] kept_var_bitvec = [i in self._kept_var_idx - for i in range(len(args_flat))] + for i in range(len(params.const_args) + len(args_flat))] in_shardings = [ sharding_impls.physical_sharding(a, s) if a is not core.abstract_token and dtypes.issubdtype(a.dtype, dtypes.extended) @@ -3190,7 +3281,7 @@ def aot_cache_miss(*args, **kwargs): fastpath_data = MeshExecutableFastpathData( self.xla_executable, out_tree_dispatch, in_shardings, self._out_shardings, out_avals, out_committed, kept_var_bitvec, - self._dispatch_in_layouts) + self._dispatch_in_layouts, params.const_args) else: fastpath_data = None return outs, fastpath_data, False # Do not remove cache entry @@ -3200,7 +3291,8 @@ def aot_cache_miss(*args, **kwargs): JitGlobalCppCacheKeys(), tree_util.dispatch_registry, cc_shard_arg) def cc_shard_arg(x, sharding, layout): - return shard_args([sharding], [layout], [None], [x])[0] + return shard_args([sharding], [layout], [xc.ArrayCopySemantics.REUSE_INPUT], + [x])[0] def check_arg_avals_for_call(ref_avals, arg_avals, @@ -3230,7 +3322,9 @@ def check_arg_avals_for_call(ref_avals, arg_avals, num_mismatch_str = "The" raise TypeError( "Argument types differ from the types for which this computation was " - f"compiled. {num_mismatch_str} mismatches are:\n{str_errors}") + "compiled. Perhaps you are calling the compiled executable with a " + "different enable_x64 mode than when it was AOT compiled? " + f"{num_mismatch_str} mismatches are:\n{str_errors}") def _get_metadata_jit_pmap(local_devices, num_in_shardings, num_out_shardings): @@ -3243,7 +3337,8 @@ def _get_metadata_jit_pmap(local_devices, num_in_shardings, num_out_shardings): # It is unsupported for these shardings to be uncommitted, so force # the outputs to be committed. committed = True - return in_shardings, out_shardings, committed, tuple(local_devices) + return (in_shardings, out_shardings, committed, + _create_device_list(tuple(local_devices))) def check_device_backend_on_shardings(shardings) -> bool: @@ -3256,22 +3351,15 @@ def check_device_backend_on_shardings(shardings) -> bool: def check_array_xla_sharding_layout_match( - args_after_dce, - in_xla_shardings: Sequence[JSharding], - in_xla_layouts: Sequence[DeviceLocalLayout], - jaxpr_debug_info: core.DebugInfo, - kept_var_idx: set[int]) -> None: - from jax._src.array import ArrayImpl - # jaxpr_debug_info.arg_names are before DCE, so need to DCE them. - arg_names = ( - [a for i, a in enumerate(jaxpr_debug_info.arg_names) - if i in kept_var_idx] - ) + args, + in_shardings: Sequence[JSharding], + in_layouts: Sequence[Layout], + arg_names: Sequence[str] +) -> None: errors = [] num_errors = 5 - for arg, xs, xl, name in safe_zip( - args_after_dce, in_xla_shardings, in_xla_layouts, arg_names): - if not isinstance(arg, ArrayImpl): + for arg, xs, xl, name in zip(args, in_shardings, in_layouts, arg_names): + if not isinstance(arg, array.ArrayImpl): continue if isinstance(xs, (UnspecifiedValue, AUTO)): continue @@ -3280,41 +3368,41 @@ def check_array_xla_sharding_layout_match( if (not db_xs and arg._committed and not arg.sharding.is_equivalent_to(xs, arg.ndim)): - errors.append( - ("Got input sharding(s) that compiled object was called with: " - f"{arg.sharding} and sharding(s) the computation was compiled " - f"with: {xs} for arg {name} with shape: {arg.aval.str_short()}", - 'sharding')) + errors.append(( + f"Argument {name} with shape {arg.aval.str_short()}:\n" + f" Passed sharding: {arg.sharding}\n" + f" Required sharding: {xs}", + "sharding")) if (not db_xs and arg._committed and - arg.layout.device_local_layout is not None and xl is not None and - arg.layout.device_local_layout != xl): - errors.append( - ("Got input layout(s) that compiled object was called with: " - f"{arg.layout.device_local_layout} and layout(s) the computation was " - f"compiled with: {xl} for arg {name} with " - f"shape: {arg.aval.str_short()}", - 'layout')) + arg.format.layout is not None and xl is not None and + arg.format.layout != xl): + errors.append(( + f"Argument {name} with shape {arg.aval.str_short()}:\n" + f" Passed layout: {arg.format.layout}\n" + f" Required layout: {xl}", + "layout")) if errors: first_errors, error_kinds = unzip2(errors[:num_errors]) str_errors = '\n'.join(first_errors) if all(k == 'sharding' for k in error_kinds): - kind_str = r'sharding(s)' + kind_str = r'shardings' elif all(k == 'layout' for k in error_kinds): - kind_str = 'layout(s)' + kind_str = 'layouts' else: - kind_str = 'sharding(s) and layout(s)' + kind_str = 'shardings and layouts' num_mismatch_str = ( - f'the {len(errors)} mismatches' if len(errors) < num_errors else + f"the {len(errors)} mismatches" if len(errors) < num_errors else f"{num_errors} mismatches out of {len(errors)}") raise ValueError( - f"Compiled object called with input {kind_str} does " - f"not match the {kind_str} the computation was " - "compiled with. " - f"Here are {num_mismatch_str}:\n{str_errors}") - - -def get_array_mapping(pspec: PartitionSpec) -> ArrayMappingOrAutoOrUnspecified: - pspec = sharding_impls.prepare_axis_resources(pspec, "pspec to array_mapping") - return _get_array_mapping(pspec) + f"Computation was compiled for input {kind_str} that disagree with the " + f"{kind_str} of arguments passed to it. " + f"Here are {num_mismatch_str}:\n{str_errors}") + +def batch_spec(spec, dim, val): + too_short = dim - len(spec) + if too_short > 0: + spec += (None,) * too_short + new_partitions = tuple_insert(spec, dim, val) # type: ignore + return PartitionSpec(*new_partitions) diff --git a/jax/_src/interpreters/xla.py b/jax/_src/interpreters/xla.py deleted file mode 100644 index 33a8992a8be4..000000000000 --- a/jax/_src/interpreters/xla.py +++ /dev/null @@ -1,152 +0,0 @@ -# Copyright 2018 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Lowering of jaxprs into XLA (HLO) computations. - -from __future__ import annotations - -from collections.abc import Callable, Sequence -from functools import partial -from typing import Any, Union - -import numpy as np - -from jax._src import core -from jax._src import dtypes -from jax._src.abstract_arrays import numpy_scalar_types -from jax._src.core import ShapedArray -from jax._src.util import safe_zip, safe_map - -from jax._src.typing import Shape - -from jax._src.lib import xla_client as xc - -map, unsafe_map = safe_map, map -zip, unsafe_zip = safe_zip, zip - -# Types - -def identity(x): return x - -_scalar_types = dtypes.python_scalar_dtypes.keys() - -def _make_array_shape(aval: ShapedArray) -> Sequence[xc.Shape]: - aval = core.physical_aval(aval) - dtype = np.dtype('bool') if aval.dtype == dtypes.float0 else aval.dtype - return (xc.Shape.array_shape(dtype, aval.shape),) - -# Utilities - -# HLO instructions optionally can be annotated to say how the output should be -# spatially partitioned (represented in XLA as OpSharding protos, see -# sharding_to_proto). For array outputs, the annotation is either an int per -# dimension specifying the number of ways that dimension divided (i.e. the total -# number of shards is the product), or None to indicate the array should be -# replicated. Tuple outputs are represented as tuples thereof. XLA supports -# arbitrary tuple nesting, but JAX only uses one level of tupling (and our type -# checkers don't support recursive types), so we only represent one level of -# nesting in this type definition. -SpatialSharding = Union[Shape, None, tuple[Union[Shape, None], ...]] - - -def sharding_to_proto(sharding: SpatialSharding): - """Converts a SpatialSharding to an OpSharding. - - See - https://github.com/tensorflow/tensorflow/blob/main/tensorflow/compiler/xla/xla_data.proto#L601 - for details on the OpSharding proto. - """ - proto = xc.OpSharding() - if isinstance(sharding, tuple) and not isinstance(sharding[0], int): - assert all(s is None or isinstance(s, tuple) for s in sharding) - return tuple_sharding_proto(list(map(sharding_to_proto, sharding))) - - if sharding is None: - proto.type = xc.OpSharding.Type.REPLICATED - else: - proto.type = xc.OpSharding.Type.OTHER - proto.tile_assignment_dimensions = list(sharding) # type: ignore - proto.tile_assignment_devices = list(range(np.prod(sharding))) # type: ignore - return proto - -def tuple_sharding_proto(elems): - proto = xc.OpSharding() - assert all(isinstance(e, type(proto)) for e in elems) - proto.type = xc.OpSharding.Type.TUPLE - proto.tuple_shardings = elems - return proto - - -### handlers - -# JAX abstract values -> XLA shapes - -def aval_to_xla_shapes(aval: core.AbstractValue) -> Sequence[xc.Shape]: - try: - return _xla_shape_handlers[type(aval)](aval) - except KeyError as err: - raise TypeError(f"No xla_shape_handler for type: {type(aval)}") from err - -_xla_shape_handlers: dict[type[core.AbstractValue], - Callable[[Any], Sequence[xc.Shape]]] = { - ShapedArray: _make_array_shape, -} -_xla_shape_handlers[core.AbstractToken] = lambda _: (xc.Shape.token_shape(),) - - -# IR constants - -class InvalidInputException(Exception): - pass - - -# TODO(mattjj): try to remove this canonicalize_dtype stuff -def canonicalize_dtype(x): - typ = type(x) - handler = canonicalize_dtype_handlers.get(typ) - if handler: return handler(x) - for typ in typ.__mro__: - handler = canonicalize_dtype_handlers.get(typ) - if handler: return handler(x) - if hasattr(x, '__jax_array__'): - return canonicalize_dtype(x.__jax_array__()) - raise InvalidInputException( - f"Argument '{x}' of type {type(x)} is not a valid JAX type.") - -def _canonicalize_masked_array_dtype(x): - raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. " - "Use arr.filled() to convert the value to a standard numpy array.") - -def _canonicalize_ndarray_dtype(x): - return np.asarray(x, dtypes.canonicalize_dtype(x.dtype)) - -def _canonicalize_python_scalar_dtype(typ, x): - return np.asarray( - x, dtypes.canonicalize_dtype(dtypes._scalar_type_to_dtype(typ, x))) - -canonicalize_dtype_handlers: dict[Any, Callable] = {} -canonicalize_dtype_handlers.update( - (t, _canonicalize_ndarray_dtype) for t in numpy_scalar_types) -canonicalize_dtype_handlers[np.ndarray] = _canonicalize_ndarray_dtype -canonicalize_dtype_handlers[np.ma.MaskedArray] = _canonicalize_masked_array_dtype -canonicalize_dtype_handlers.update( - (t, partial(_canonicalize_python_scalar_dtype, t)) for t in _scalar_types) -canonicalize_dtype_handlers[core.Token] = identity -canonicalize_dtype_handlers[core.DArray] = identity -canonicalize_dtype_handlers[core.MutableArray] = identity - -initial_style_primitives: set[core.Primitive] = set() - -def register_initial_style_primitive(prim: core.Primitive): - initial_style_primitives.add(prim) diff --git a/jax/_src/jaxpr_util.py b/jax/_src/jaxpr_util.py index ab72634d3bdf..30795b9b318c 100644 --- a/jax/_src/jaxpr_util.py +++ b/jax/_src/jaxpr_util.py @@ -21,10 +21,14 @@ import gzip import itertools import json +import logging import types -from typing import Any, Iterator, Union +from typing import Any, Union +from collections.abc import Iterator +from jax._src import config from jax._src import core +from jax._src import path from jax._src import util from jax._src import source_info_util from jax._src.lib import xla_client @@ -32,12 +36,26 @@ map, unsafe_map = util.safe_map, map zip, unsafe_zip = util.safe_zip, zip +logger = logging.getLogger(__name__) -def all_eqns(jaxpr: core.Jaxpr) -> Iterator[tuple[core.Jaxpr, core.JaxprEqn]]: + +def _all_eqns( + jaxpr: core.Jaxpr, visited: set[core.Jaxpr] | None, +) -> Iterator[tuple[core.Jaxpr, core.JaxprEqn]]: for eqn in jaxpr.eqns: yield (jaxpr, eqn) for subjaxpr in core.subjaxprs(jaxpr): - yield from all_eqns(subjaxpr) + if visited is None: + yield from _all_eqns(subjaxpr, visited) + elif subjaxpr not in visited: + visited.add(subjaxpr) + yield from _all_eqns(subjaxpr, visited) + +def all_eqns( + jaxpr: core.Jaxpr, revisit_inner_jaxprs: bool = True +) -> Iterator[tuple[core.Jaxpr, core.JaxprEqn]]: + yield from _all_eqns(jaxpr, None if revisit_inner_jaxprs else set()) + def collect_eqns(jaxpr: core.Jaxpr, key: Callable): d = defaultdict(list) @@ -130,8 +148,16 @@ def print_histogram(histogram: dict[Any, int]): print(count_fmt.format(count), name) +DEFAULT_WORKSPACE_ROOT: str | None = None + +def _strip_workspace_root(filename: str, workspace_root: str) -> str: + i = filename.rfind(workspace_root) + return filename[i+len(workspace_root):] if i >= 0 else filename + + def _pprof_profile( - profile: dict[tuple[xla_client.Traceback | None, core.Primitive], int] + profile: dict[tuple[xla_client.Traceback | None, core.Primitive], int], + workspace_root: str | None = None, ) -> bytes: """Converts a profile into a compressed pprof protocol buffer. @@ -169,14 +195,19 @@ def _pprof_profile( "line": xla_client.Traceback.code_addr2line(code, lasti)}]} for (code, lasti), loc_id in loc.items() ] - functions = [ - {"id": func_id, - "name": s[code.co_name], - "system_name": s[code.co_name], - "filename": s[code.co_filename], - "start_line": code.co_firstlineno} - for code, func_id in func.items() - ] + functions = [] + for code, func_id in func.items(): + filename = code.co_filename + name = code.co_qualname + if workspace_root is not None: + filename = _strip_workspace_root(filename, workspace_root) + name = f"{filename.removesuffix('.py').replace('/', '.')}.{name}" + functions.append( + {"id": func_id, + "name": s[name], + "filename": s[filename], + "start_line": code.co_firstlineno} + ) sample_type = [{"type": s["equations"], "unit": s["count"]}] # This is the JSON encoding of a pprof profile protocol buffer. See: # https://github.com/google/pprof/blob/master/proto/profile.proto for a @@ -191,7 +222,8 @@ def _pprof_profile( return gzip.compress(xla_client._xla.json_to_pprof_profile(json_profile)) -def pprof_equation_profile(jaxpr: core.Jaxpr) -> bytes: +def pprof_equation_profile(jaxpr: core.Jaxpr, *, + workspace_root: str | None = None) -> bytes: """Generates a pprof profile that maps jaxpr equations to Python stack traces. By visualizing the profile using pprof, one can identify Python code that is @@ -199,6 +231,8 @@ def pprof_equation_profile(jaxpr: core.Jaxpr) -> bytes: Args: jaxpr: a Jaxpr. + workspace_root: the root of the workspace. If specified, function names + will be fully qualified, with respect to the workspace root. Returns: A gzip-compressed pprof Profile protocol buffer, suitable for passing to @@ -206,6 +240,76 @@ def pprof_equation_profile(jaxpr: core.Jaxpr) -> bytes: """ d = Counter( (eqn.source_info.traceback, eqn.primitive) - for _, eqn in all_eqns(jaxpr) + for _, eqn in all_eqns(jaxpr, revisit_inner_jaxprs=False) ) - return _pprof_profile(d) + return _pprof_profile(d, workspace_root or DEFAULT_WORKSPACE_ROOT) + +def eqns_using_var_with_invar_index(jaxpr: core.Jaxpr, invar: core.Var) -> Iterator[tuple[core.JaxprEqn, int]]: + """Find all the equations which use invar and the positional index of its binder""" + for eqn in jaxpr.eqns: + for invar_index, eqn_var in enumerate(eqn.invars): + if eqn_var == invar: + yield eqn, invar_index + break # we found the var, no need to keep looking in this eqn + +def jaxpr_and_binder_in_params(params, index: int) -> Iterator[tuple[core.Jaxpr, core.Var]]: + for val in params.values(): + vals = val if isinstance(val, tuple) else (val,) + for v in vals: + if isinstance(v, core.Jaxpr): + if index >= len(v.invars): + raise RuntimeError(f"Failed to find index {index} in jaxpr.invars while building report") + yield v, v.invars[index] + elif isinstance(v, core.ClosedJaxpr): + if index >= len(v.jaxpr.invars): + raise RuntimeError(f"Failed to find index {index} in jaxpr.invars while building report") + yield v.jaxpr, v.jaxpr.invars[index] + +def eqns_using_var(jaxpr: core.Jaxpr, invar: core.Var) -> Iterator[core.JaxprEqn]: + """Find the leaf equations using a variable""" + # The complexity of this call is because the invar might originate from a nested jaxpr + for eqn, invar_index in eqns_using_var_with_invar_index(jaxpr, invar): + if (child_jaxprs_and_vars := tuple(jaxpr_and_binder_in_params(eqn.params, invar_index))): + for (jaxpr, invar) in child_jaxprs_and_vars: + yield from eqns_using_var(jaxpr, invar) + else: + # if the previous condition fails, there is no deeper jaxpr to explore =( + yield eqn + + +_jaxpr_id_counter = itertools.count() + +def maybe_dump_jaxpr_to_file( + fun_name: str, jaxpr: core.Jaxpr +) -> str | None: + """Maybe dumps the `jaxpr` to a file. + + Dumps the jaxpr if JAX_DUMP_JAXPR_TO is defined. + + Args: + fn: The name of the function whose jaxpr is being dumped. + jaxpr: The jaxpr to dump. + + Returns: + The path to the file where the jaxpr was dumped, or None if no file was + dumped. + """ + if not (out_dir := path.make_jax_dump_dir(config.jax_dump_ir_to.value)): + return None + modes = config.jax_dump_ir_modes.value.split(",") + if "jaxpr" not in modes and "eqn_count_pprof" not in modes: + return None + id = next(_jaxpr_id_counter) + if "jaxpr" in modes: + logging.log( + logging.INFO, "Dumping jaxpr for %s to %s.", fun_name, out_dir + ) + jaxpr_path = out_dir / f"jax_{id:06d}_{fun_name}.jaxpr.txt" + jaxpr_path.write_text(jaxpr.pretty_print()) + if "eqn_count_pprof" in modes: + logging.log( + logging.INFO, "Dumping eqn count pprof for %s to %s.", fun_name, out_dir + ) + eqn_prof_path = out_dir / f"jax_{id:06d}_{fun_name}.eqn_count_pprof" + eqn_prof_path.write_bytes(pprof_equation_profile(jaxpr)) + return fun_name diff --git a/jax/_src/lax/__init__.py b/jax/_src/lax/__init__.py index a06eea6941cb..bcad73dd800f 100644 --- a/jax/_src/lax/__init__.py +++ b/jax/_src/lax/__init__.py @@ -16,3 +16,301 @@ from jax._src import traceback_util traceback_util.register_exclusion(os.path.dirname(__file__)) +del os, traceback_util + +# Import a subset of objects from `lax` for internal use. + +from jax._src.lax.lax import ( + DotAlgorithmPreset as DotAlgorithmPreset, + Precision as Precision, + RandomAlgorithm as RandomAlgorithm, + RoundingMethod as RoundingMethod, + abs as abs, + abs_p as abs_p, + acos_p as acos_p, + acosh_p as acosh_p, + add as add, + add_p as add_p, + after_all_p as after_all_p, + and_p as and_p, + argmax_p as argmax_p, + argmin_p as argmin_p, + asin_p as asin_p, + asinh_p as asinh_p, + atan_p as atan_p, + atan2_p as atan2_p, + atanh_p as atanh_p, + bitcast_convert_type_p as bitcast_convert_type_p, + bitwise_and as bitwise_and, + bitwise_not as bitwise_not, + bitwise_or as bitwise_or, + bitwise_xor as bitwise_xor, + broadcast_in_dim as broadcast_in_dim, + broadcast_in_dim_p as broadcast_in_dim_p, + broadcasted_iota as broadcasted_iota, + cbrt as cbrt, + cbrt_p as cbrt_p, + ceil as ceil, + ceil_p as ceil_p, + clamp as clamp, + clamp_p as clamp_p, + clz_p as clz_p, + complex as complex, + complex_p as complex_p, + concatenate as concatenate, + concatenate_p as concatenate_p, + conj as conj, + conj_p as conj_p, + convert_element_type as convert_element_type, + convert_element_type_p as convert_element_type_p, + copy_p as copy_p, + dce_sink_p as dce_sink_p, + cos as cos, + dce_sink as dce_sink, + cos_p as cos_p, + cosh as cosh, + cosh_p as cosh_p, + create_token_p as create_token_p, + div as div, + div_p as div_p, + dot as dot, + dot_general as dot_general, + dot_general_p as dot_general_p, + dtype as dtype, + empty as empty, + eq as eq, + eq_p as eq_p, + eq_to_p as eq_to_p, + expand_dims as expand_dims, + exp as exp, + exp_p as exp_p, + exp2 as exp2, + exp2_p as exp2_p, + expm1 as expm1, + expm1_p as expm1_p, + floor as floor, + floor_p as floor_p, + full as full, + full_like as full_like, + ge as ge, + ge_p as ge_p, + gt as gt, + gt_p as gt_p, + imag as imag, + imag_p as imag_p, + integer_pow as integer_pow, + integer_pow_p as integer_pow_p, + iota as iota, + iota_p as iota_p, + is_finite_p as is_finite_p, + le as le, + le_p as le_p, + le_to_p as le_to_p, + log1p as log1p, + log1p_p as log1p_p, + log as log, + log_p as log_p, + logistic as logistic, + logistic_p as logistic_p, + lt as lt, + lt_p as lt_p, + lt_to_p as lt_to_p, + max as max, + max_p as max_p, + min as min, + min_p as min_p, + mul as mul, + mul_p as mul_p, + ne as ne, + ne_p as ne_p, + neg as neg, + neg_p as neg_p, + nextafter_p as nextafter_p, + not_p as not_p, + optimization_barrier_p as optimization_barrier_p, + or_p as or_p, + pad as pad, + pad_p as pad_p, + padtype_to_pads as padtype_to_pads, + population_count as population_count, + population_count_p as population_count_p, + pow as pow, + pow_p as pow_p, + real as real, + real_p as real_p, + reduce_and as reduce_and, + reduce_and_p as reduce_and_p, + reduce_max as reduce_max, + reduce_max_p as reduce_max_p, + reduce_min as reduce_min, + reduce_min_p as reduce_min_p, + reduce_or as reduce_or, + reduce_or_p as reduce_or_p, + reduce as reduce, + reduce_p as reduce_p, + reduce_precision as reduce_precision, + reduce_precision_p as reduce_precision_p, + reduce_prod as reduce_prod, + reduce_prod_p as reduce_prod_p, + reduce_sum as reduce_sum, + reduce_sum_p as reduce_sum_p, + reduce_xor as reduce_xor, + reduce_xor_p as reduce_xor_p, + rem_p as rem_p, + reshape as reshape, + reshape_p as reshape_p, + rev as rev, + rev_p as rev_p, + rng_bit_generator as rng_bit_generator, + rng_bit_generator_p as rng_bit_generator_p, + rng_uniform_p as rng_uniform_p, + round as round, + round_p as round_p, + rsqrt as rsqrt, + rsqrt_p as rsqrt_p, + select as select, + select_n as select_n, + select_n_p as select_n_p, + shift_left_p as shift_left_p, + shift_right_arithmetic_p as shift_right_arithmetic_p, + shift_right_logical_p as shift_right_logical_p, + sign as sign, + sign_p as sign_p, + sin as sin, + sin_p as sin_p, + sinh as sinh, + sinh_p as sinh_p, + sort_key_val as sort_key_val, + sort_p as sort_p, + split as split, + split_p as split_p, + sqrt as sqrt, + sqrt_p as sqrt_p, + square as square, + square_p as square_p, + squeeze as squeeze, + squeeze_p as squeeze_p, + stop_gradient as stop_gradient, + sub as sub, + sub_p as sub_p, + tan as tan, + tan_p as tan_p, + tanh as tanh, + tanh_p as tanh_p, + top_k as top_k, + top_k_p as top_k_p, + transpose as transpose, + transpose_p as transpose_p, + xor_p as xor_p, +) +from jax._src.lax.other import ( + conv_general_dilated_patches as conv_general_dilated_patches, +) +from jax._src.lax.special import ( + bessel_i0e as bessel_i0e, + bessel_i0e_p as bessel_i0e_p, + bessel_i1e as bessel_i1e, + bessel_i1e_p as bessel_i1e_p, + betainc as betainc, + digamma as digamma, + digamma_p as digamma_p, + erfc as erfc, + erfc_p as erfc_p, + erf_inv as erf_inv, + erf_inv_p as erf_inv_p, + erf as erf, + erf_p as erf_p, + igammac as igammac, + igammac_p as igammac_p, + igamma_grad_a_p as igamma_grad_a_p, + igamma as igamma, + igamma_p as igamma_p, + lgamma as lgamma, + lgamma_p as lgamma_p, + polygamma as polygamma, + polygamma_p as polygamma_p, + regularized_incomplete_beta_p as regularized_incomplete_beta_p, + zeta as zeta, + zeta_p as zeta_p, +) +from jax._src.lax.slicing import ( + GatherDimensionNumbers as GatherDimensionNumbers, + GatherScatterMode as GatherScatterMode, + ScatterDimensionNumbers as ScatterDimensionNumbers, + dynamic_index_in_dim as dynamic_index_in_dim, + dynamic_slice as dynamic_slice, + dynamic_slice_in_dim as dynamic_slice_in_dim, + dynamic_slice_p as dynamic_slice_p, + dynamic_update_index_in_dim as dynamic_update_index_in_dim, + dynamic_update_slice as dynamic_update_slice, + dynamic_update_slice_in_dim as dynamic_update_slice_in_dim, + dynamic_update_slice_p as dynamic_update_slice_p, + index_in_dim as index_in_dim, + gather as gather, + gather_p as gather_p, + scatter_add as scatter_add, + scatter_add_p as scatter_add_p, + scatter_max as scatter_max, + scatter_max_p as scatter_max_p, + scatter_min as scatter_min, + scatter_min_p as scatter_min_p, + scatter_mul as scatter_mul, + scatter_mul_p as scatter_mul_p, + scatter_sub as scatter_sub, + scatter_sub_p as scatter_sub_p, + scatter as scatter, + scatter_p as scatter_p, + slice as slice, + slice_in_dim as slice_in_dim, + slice_p as slice_p, +) +from jax._src.lax.convolution import ( + conv_general_dilated as conv_general_dilated, + conv_general_dilated_p as conv_general_dilated_p, +) +from jax._src.lax.windowed_reductions import ( + reduce_window as reduce_window, + reduce_window_max_p as reduce_window_max_p, + reduce_window_min_p as reduce_window_min_p, + reduce_window_p as reduce_window_p, + reduce_window_sum_p as reduce_window_sum_p, + select_and_gather_add_p as select_and_gather_add_p, + select_and_scatter_p as select_and_scatter_p, + select_and_scatter_add_p as select_and_scatter_add_p, +) +from jax._src.lax.control_flow import ( + cond as cond, + cond_p as cond_p, + cumlogsumexp_p as cumlogsumexp_p, + cummax_p as cummax_p, + cummin_p as cummin_p, + cumprod_p as cumprod_p, + cumsum_p as cumsum_p, + custom_linear_solve as custom_linear_solve, + fori_loop as fori_loop, + linear_solve_p as linear_solve_p, + scan as scan, + scan_p as scan_p, + switch as switch, + while_loop as while_loop, + while_p as while_p, +) +from jax._src.lax.fft import ( + fft_p as fft_p, + FftType as FftType, +) +from jax._src.lax.parallel import ( + all_gather_p as all_gather_p, + all_to_all_p as all_to_all_p, + axis_index as axis_index, + axis_index_p as axis_index_p, + axis_size as axis_size, + pmax_p as pmax_p, + pmin_p as pmin_p, + ppermute_p as ppermute_p, + psum_p as psum_p, + ragged_all_to_all_p as ragged_all_to_all_p, +) +from jax._src.lax.ann import ( + approx_top_k_p as approx_top_k_p +) diff --git a/jax/_src/lax/ann.py b/jax/_src/lax/ann.py index 0e037ec774b5..009ab4ce9025 100644 --- a/jax/_src/lax/ann.py +++ b/jax/_src/lax/ann.py @@ -78,11 +78,11 @@ def pmap_mips(qy, db, db_offset, db_size, k, recall_target): from jax._src import core from jax._src import dispatch from jax._src import dtypes +from jax._src.numpy.indexing import take_along_axis from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir -from jax._src.lax import lax -from jax._src.lib import xla_client as xc +from jax._src.lib import _jax from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import func from jax._src.lib.mlir.dialects import hlo @@ -231,7 +231,7 @@ def _approx_top_k_abstract_eval(operand, *, k, reduction_dimension, if aggregate_to_topk: dims[reduction_dimension] = k elif core.is_constant_shape((reduction_input_size, k)): - dims[reduction_dimension] = xc.ops.ApproxTopKReductionOutputSize( + dims[reduction_dimension] = _jax.approx_top_k_reduction_output_size( reduction_input_size, len(dims), k, recall_target, aggregate_to_topk, reduction_input_size_override)[0] else: @@ -239,9 +239,17 @@ def _approx_top_k_abstract_eval(operand, *, k, reduction_dimension, "approx_top_k with aggregate_to_topk=False not yet implemented when " f"either the `k` ({k}) or the " f" reduction dimension size ({reduction_input_size}) are symbolic") + operand_s = operand.sharding + if operand_s.spec[reduction_dimension] is not None: + raise core.ShardingTypeError( + f"reduction dimension {reduction_dimension} in operand" + f" {operand.str_short()} should be unsharded i.e. the spec of that dim" + " should be `None`.") return (operand.update(shape=dims, dtype=operand.dtype, - weak_type=operand.weak_type), - operand.update(shape=dims, dtype=np.dtype(np.int32))) + weak_type=operand.weak_type, vma=operand.vma, + sharding=operand_s), + operand.update(shape=dims, dtype=np.dtype(np.int32), vma=operand.vma, + sharding=operand_s)) def _get_init_val_literal(op_type, is_max_k): return np.array(-np.inf if is_max_k else np.inf, dtype=op_type) @@ -379,12 +387,7 @@ def _approx_top_k_jvp(primals, tangents, *, k, reduction_dimension, rank = len(arg_shape) if reduction_dimension < 0: reduction_dimension += rank - iotas = [ - lax.broadcasted_iota(arg_out.dtype, arg_shape, i) for i in range(rank) - ] - idx = tuple( - arg_out if i == reduction_dimension else iotas[i] for i in range(rank)) - tangent_out = tangent[idx] + tangent_out = take_along_axis(tangent, arg_out, axis=reduction_dimension) return (val_out, arg_out), (tangent_out, ad_util.Zero.from_primal_value(arg_out)) diff --git a/jax/_src/lax/control_flow/__init__.py b/jax/_src/lax/control_flow/__init__.py index f89e4d53a476..79488d89e3b7 100644 --- a/jax/_src/lax/control_flow/__init__.py +++ b/jax/_src/lax/control_flow/__init__.py @@ -34,6 +34,7 @@ while_p as while_p, ) from jax._src.lax.control_flow.conditionals import ( + BranchesPlatforms as BranchesPlatforms, cond as cond, cond_p as cond_p, switch as switch, @@ -49,11 +50,8 @@ # Private utilities used elsewhere in JAX # TODO(sharadmv): lift them into a more common place from jax._src.lax.control_flow.common import ( - _initial_style_open_jaxpr as _initial_style_open_jaxpr, - _initial_style_jaxpr as _initial_style_jaxpr, - _initial_style_jaxprs_with_common_consts as _initial_style_jaxprs_with_common_consts, _check_tree_and_avals as _check_tree_and_avals, - + _merge_common_consts as _merge_common_consts, ) # TODO(mattjj): fix dependent library which expects optimization_barrier_p here from jax._src.lax.lax import optimization_barrier_p as optimization_barrier_p diff --git a/jax/_src/lax/control_flow/common.py b/jax/_src/lax/control_flow/common.py index b75cbf6ac708..9518b4484bd9 100644 --- a/jax/_src/lax/control_flow/common.py +++ b/jax/_src/lax/control_flow/common.py @@ -15,27 +15,22 @@ from __future__ import annotations -from collections.abc import Callable, Sequence +from collections.abc import Sequence import os from functools import partial from typing import Any -from jax._src import api_util +from jax._src import ad_util from jax._src import core +from jax._src import config from jax._src import linear_util as lu -from jax._src.lax import lax -from jax._src import effects -from jax._src import ad_util -from jax._src import state -from jax._src.util import weakref_lru_cache, safe_map, partition_list +from jax._src.util import weakref_lru_cache, safe_map from jax._src.interpreters import partial_eval as pe -from jax.tree_util import tree_map, tree_unflatten, keystr, PyTreeDef -from jax._src.tree_util import equality_errors_pytreedef +from jax._src.tree_util import (equality_errors_pytreedef, tree_map, + tree_unflatten, keystr) map, unsafe_map = safe_map, map -effects.control_flow_allowed_effects.add_type(lax.InOutFeedEffect) - def _typecheck_param(prim, param, name, msg_required, pred): if not pred: @@ -48,155 +43,56 @@ def _typecheck_param(prim, param, name, msg_required, pred): msg = sep.join([msg, param_str]) raise core.JaxprTypeError(msg) -@weakref_lru_cache -def _initial_style_open_jaxpr(fun: Callable, - in_tree: PyTreeDef, - in_avals: Sequence[core.AbstractValue], - debug_info: core.DebugInfo): - wrapped_fun, out_tree = api_util.flatten_fun_nokwargs( - lu.wrap_init(fun, debug_info=debug_info), - in_tree) - jaxpr, _, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic( - wrapped_fun, in_avals) - return jaxpr, consts, out_tree(), attrs_tracked +# TODO(dougalm): this seems way too complicated. Why not allow different consts for each +# branch of a switch? +def _merge_common_consts( + jaxprs: Sequence[core.ClosedJaxpr], + all_consts: Sequence[Sequence[Any]] + ) -> tuple[Sequence[core.ClosedJaxpr], Sequence[Any]]: + # Jaxprs must share consts, so we concat consts and pad the jaxprs' constvars. + lens = map(len, all_consts) + consts = [c for cs in all_consts for c in cs] + avalqdds = tuple(map(core.cur_aval_qdd, consts)) + num_constss = [len(cs) for cs in all_consts] + jaxprs = [_pad_constvars(jaxpr, num_consts, avalqdds[:sum(lens[:i])], avalqdds[sum(lens[:i+1]):]) + for i, (jaxpr, num_consts) in enumerate(zip(jaxprs, num_constss))] + # De-duplicate shared constants. + const_ids = tuple(id(c) for c in consts) + seen = set() + dd_consts = [c for c in consts if id(c) not in seen and not seen.add(id(c))] # type: ignore + jaxprs = [_dedup_consts(jaxpr, len(consts), const_ids) for jaxpr in jaxprs] + return jaxprs, dd_consts @weakref_lru_cache -def _initial_style_jaxpr(fun: Callable, - in_tree: PyTreeDef, - in_avals: Sequence[core.AbstractValue], - debug_info: core.DebugInfo): - jaxpr, consts, out_tree, () = _initial_style_open_jaxpr( - fun, in_tree, in_avals, debug_info) - closed_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr)) - return closed_jaxpr, consts, out_tree - -def _initial_style_jaxpr_attrs(fun: Callable, - in_tree: PyTreeDef, - in_avals: Sequence[core.AbstractValue], - debug_info: core.DebugInfo): - jaxpr, consts, out_tree, attrs_tracked = _initial_style_open_jaxpr( - fun, in_tree, in_avals, debug_info) - closed_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr)) - return closed_jaxpr, consts, out_tree, attrs_tracked - -def _initial_style_jaxprs_with_common_consts( - funs: Sequence[Callable], - in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue], - debug_infos: Sequence[core.DebugInfo]): - # When staging the branches of a conditional into jaxprs, constants are - # extracted from each branch and converted to jaxpr arguments. To use the - # staged jaxprs as the branches to a conditional *primitive*, we need for - # their (input) signatures to match. This function "joins" the staged jaxprs: - # for each one, it makes another that accepts *all* constants, but only uses - # those that it needs (dropping the rest). - jaxpr_data = [_initial_style_open_jaxpr(fn, in_tree, in_avals, debug_info) - for fn, debug_info in zip(funs, debug_infos)] - if not jaxpr_data: - return [], [], [] - - jaxprs, all_consts, all_out_trees, all_attrs_tracked = zip(*jaxpr_data) - all_const_avals = [map(core.get_aval, consts) for consts in all_consts] - - # TODO(sharadmv,mattjj): we could dedup *all consts* instead of just the Refs. - - # We don't want two different Refs in a jaxpr's input to refer to the same - # Ref in the caller. We call this the "Ref aliasing problem" and it introduces - # difficulties when discharging Refs and when reasoning about programs with - # state effects. When unifying the arguments to each branch in a cond, - # however, we might naively pass the same Ref in multiple times. - # - # Here we dedup any `Ref`s that were closed over across the branches and - # pad out constants used across different branches. - # Let's consider an example case. For the following branch jaxprs, we will - # produce the following const lists, where `t_` indicates a tracer (a Ref). - # { lambda x:i32[] a:Ref{float64[]} c:Ref[float64[]}; . let - # a[] <- 1.0 - # c[] <- 3.14 - # in () } - - # { lambda d:Ref[float64[]} b:Ref{float64[]} y:i32[]; . let - # d[] <- 6.28 - # b[] <- 2.0 - # in () } - # consts = [[0, t_e, t_f], [t_g, t_e, 1]] - # - # Notice how `t_e` is duplicated. To deduplicate the `Ref`s we first - # 1) Detecting duplicate `Ref` tracers. We keep track of duplicates in - # `tracer_id_to_canonical_id.` We store the deduped `Ref` tracers in a - # list called `canonical_refs`. We remove the `Ref`s from the consts. - # We should have the following lists: - # canonical_refs = [t_e, t_f, t_g] - # consts = [[0], [1]] - # 2) We need to munge the branch jaxprs to take in *all* the canonical Refs - # and ignore the ones it doesn't actually use. We do this by keeping track - # for each jaxpr for each of its input Refs which canonical_ref it - # corresponds to, producing the following list: - # canonical_ref_indices = [[0, 1], [2, 0]] - # - # Afterwards, we proceed by rewriting the jaxprs to be the following: - # { lambda a:Ref{float64[]} c:Ref[float64[]} b_:Ref{float64[]} x:i32[]; . let - # a[] <- 1.0 - # c[] <- 3.14 - # in () } - # { lambda b:Ref{float64[]} _:Ref{float64[]} d:Ref{float64[]} y:i32[]; . let - # d[] <- 6.28 - # b[] <- 2.0 - # in () } - canonical_ref_indices = [] - canonical_non_ref_indices = [] - canonical_refs: list[Any] = [] - canonical_non_refs: list[Any] = [] - tracer_id_to_canonical_ref_id = {} - tracer_id_to_canonical_non_ref_id = {} - canonical_ref_avals = [] - canonical_non_ref_avals = [] - for consts, consts_avals in zip(all_consts, all_const_avals): - ref_indices = [] - non_ref_indices = [] - for c, aval in zip(consts, consts_avals): - tracer_id = id(c) - if isinstance(aval, state.AbstractRef): - if tracer_id not in tracer_id_to_canonical_ref_id: - canonical_id = len(canonical_refs) - canonical_refs.append(c) - tracer_id_to_canonical_ref_id[tracer_id] = canonical_id - canonical_ref_avals.append(aval) - canonical_id = tracer_id_to_canonical_ref_id[tracer_id] - ref_indices.append(canonical_id) - else: - if tracer_id not in tracer_id_to_canonical_non_ref_id: - canonical_id = len(canonical_non_refs) - canonical_non_refs.append(c) - tracer_id_to_canonical_non_ref_id[tracer_id] = canonical_id - canonical_non_ref_avals.append(aval) - canonical_id = tracer_id_to_canonical_non_ref_id[tracer_id] - non_ref_indices.append(canonical_id) - canonical_ref_indices.append(tuple(ref_indices)) - canonical_non_ref_indices.append(tuple(non_ref_indices)) - - consts = [*canonical_refs, *canonical_non_refs] - jaxprs = tuple(_pad_jaxpr_constvars(jaxpr, i, (*canonical_ref_avals,), (*canonical_ref_indices,), (*canonical_non_ref_avals,), (*canonical_non_ref_indices,)) - for i, jaxpr in enumerate(jaxprs)) - return jaxprs, consts, all_out_trees +def _pad_constvars(jaxpr: core.ClosedJaxpr, num_consts: int, + left: tuple[core.AvalQDD, ...], + right: tuple[core.AbstractValue, ...]) -> core.ClosedJaxpr: + def make_var(aq): + return core.Var(aq.aval, initial_qdd=aq.qdd, final_qdd=aq.qdd) + invars = [*map(make_var, left), *jaxpr.invars[:num_consts], + *map(make_var, right), *jaxpr.invars[num_consts:]] + effs = pe._renumber_effects(invars, jaxpr.invars, jaxpr.effects) + jaxpr = jaxpr.replace(jaxpr=jaxpr.jaxpr.replace(invars=invars, effects=effs)) + config.enable_checks.value and core.check_jaxpr(jaxpr.jaxpr) + return jaxpr @weakref_lru_cache -def _pad_jaxpr_constvars(jaxpr, i, canonical_ref_avals, canonical_ref_indices, - canonical_non_ref_avals, canonical_non_ref_indices): - is_ref = [isinstance(v.aval, state.AbstractRef) for v in jaxpr.constvars] - nonref_constvars, ref_constvars = partition_list(is_ref, jaxpr.constvars) - newvar = core.gensym(suffix='_') - padded_ref_constvars = map(newvar, canonical_ref_avals) - padded_non_ref_constvars = map(newvar, canonical_non_ref_avals) - for canonical_id, ref_var in zip(canonical_ref_indices[i], ref_constvars): - padded_ref_constvars[canonical_id] = ref_var - for canonical_id, non_ref_var in zip(canonical_non_ref_indices[i], nonref_constvars): - padded_non_ref_constvars[canonical_id] = non_ref_var - constvars = [*padded_ref_constvars, *padded_non_ref_constvars] - jaxpr = jaxpr.replace(constvars=constvars) - effects = pe.make_jaxpr_effects(jaxpr.constvars, jaxpr.invars, - jaxpr.outvars, jaxpr.eqns) - jaxpr = jaxpr.replace(effects=effects) - return core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ()) +def _dedup_consts(jaxpr, num_consts, const_ids): + newvars = {} + canonicalize = {v: newvars.setdefault(constid, v) + for constid, v in zip(const_ids, jaxpr.invars[:num_consts])} + eqns = [e.replace(invars=[canonicalize.get(x, x) if isinstance(x, core.Var) + else x for x in e.invars]) for e in jaxpr.eqns] + outvars = [canonicalize.get(x, x) if isinstance(x, core.Var) else x + for x in jaxpr.outvars] + invars = [*list(newvars.values()), *jaxpr.invars[num_consts:]] + effs = pe._renumber_effects(invars, + [*map(canonicalize.get, jaxpr.invars[:num_consts]), *jaxpr.invars[num_consts:]], + jaxpr.effects) + jaxpr = jaxpr.replace(jaxpr=jaxpr.jaxpr.replace(invars=invars, eqns=eqns, outvars=outvars, + effects=effs)) + config.enable_checks.value and core.check_jaxpr(jaxpr) + return jaxpr def _check_tree_and_avals(what1, tree1, avals1, what2, tree2, avals2): """Raises TypeError if (tree1, avals1) does not match (tree2, avals2). @@ -244,14 +140,9 @@ def _prune_zeros(ts): def _make_closed_jaxpr(traceable: lu.WrappedFun, in_avals: Sequence[core.AbstractValue]): - jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(traceable, in_avals) + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(traceable, in_avals) return core.ClosedJaxpr(jaxpr, consts) -def _make_closed_jaxpr_attrs(traceable: lu.WrappedFun, in_avals: Sequence[core.AbstractValue]): - jaxpr, _, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(traceable, in_avals) - return core.ClosedJaxpr(jaxpr, consts), attrs_tracked - - def _show_diff(array1, array2): if core.typematch(array1, array2): return f"{array1}" diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 63896cc2a0bf..641e9beed58f 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -18,12 +18,13 @@ from collections.abc import Callable, Sequence import functools from functools import partial -import inspect import itertools import operator from typing import Any, TypeVar -from jax.tree_util import tree_flatten, tree_unflatten +from jax._src.tree_util import ( + tree_flatten, tree_unflatten, tree_flatten_with_path, keystr, + equality_errors_pytreedef, FlatTree) from jax._src import ad_util from jax._src import api_util from jax._src import config @@ -36,29 +37,26 @@ from jax._src import util from jax._src.state.discharge import register_partial_discharge_rule, discharge_state from jax._src.state.types import AbstractRef, RefEffect -from jax._src.core import replace_jaxpr_effects +from jax._src.core import replace_jaxpr_effects, typeof, cur_qdd from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe -from jax._src.interpreters import xla +from jax._src.interpreters import pxla from jax._src.lax import lax from jax._src.traceback_util import api_boundary -from jax._src.util import (safe_map, split_list, partition_list) +from jax._src.typing import ArrayLike +from jax._src.util import safe_map, safe_zip, split_list, partition_list, unzip2 from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo import numpy as np from jax._src.lax.control_flow.common import ( - _avals_short, - _check_tree_and_avals, - _initial_style_jaxprs_with_common_consts, - _make_closed_jaxpr, - _prune_zeros, - _typecheck_param, - ) + _avals_short, _typecheck_param, _merge_common_consts, + _make_closed_jaxpr, _prune_zeros) map, unsafe_map = safe_map, map +zip, unsafe_zip = safe_zip, zip # For backward compatibility with a previous switch/cond calling convention, @@ -80,7 +78,7 @@ def switch(index, branches, *operands): return branches[index](*operands) Internally this wraps XLA's `Conditional - `_ + `_ operator. However, when transformed with :func:`~jax.vmap` to operate over a batch of predicates, ``cond`` is converted to :func:`~jax.lax.select`. @@ -130,42 +128,70 @@ def switch(index, branches, *operands): lo = np.array(0, np.int32) hi = np.array(len(branches) - 1, np.int32) index = lax.clamp(lo, index, hi) + return _switch_internal(index, branches, operands, + branches_platforms=None) + +def _switch_internal( + index: ArrayLike, + branches: Sequence[Callable], + operands: Sequence[ArrayLike], *, + branches_platforms: BranchesPlatforms | None): if (config.disable_jit.value and core.is_concrete(index)): - return branches[int(index)](*operands) + return branches[int(index)](*operands) # type: ignore dbgs = [api_util.debug_info("switch", branch, operands, {}) for branch in branches] - ops, ops_tree = tree_flatten(operands) - ops_avals = tuple(map(core.get_aval, ops)) + args = FlatTree.flatten((operands, {})) + avals = args.map(core.get_aval) if config.mutable_array_checks.value: - api_util._check_no_aliased_ref_args(dbgs[0], ops_avals, ops) + api_util.check_no_aliased_ref_args(lambda: dbgs[0], list(avals), list(args)) + + jaxprs_, out_avalss = zip(*[pe.trace_to_jaxpr(branch, avals, dbg) + for branch, dbg in zip(branches, dbgs)]) + jaxprs_, all_consts = zip(*[pe.separate_consts(j) for j in jaxprs_]) + jaxprs, consts = _merge_common_consts(jaxprs_, all_consts) - jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts( - branches, ops_tree, ops_avals, dbgs) if config.mutable_array_checks.value: - api_util._check_no_aliased_closed_over_refs(dbgs[0], (*jaxprs[0].consts, *consts), ops) - for i, (out_tree, jaxpr) in enumerate(zip(out_trees[1:], jaxprs[1:])): - _check_tree_and_avals("branch 0 output", - out_trees[0], jaxprs[0].out_avals, - f"branch {i + 1} output", - out_tree, jaxpr.out_avals) + api_util._check_no_aliased_closed_over_refs(dbgs[0], (*jaxprs[0].consts, *consts), list(args)) + for i, (out_avals, jaxpr) in enumerate(zip(out_avalss[1:], jaxprs[1:])): + _check_branch_outputs( + "switch", "branch 0", f"branch{i+1}", branches[0], branches[i+1], + out_avalss[0], out_avals) + # prune passthrough outputs + fwds = [pe._jaxpr_forwarding(jaxpr.jaxpr) for jaxpr in jaxprs] + in_fwd = [xs[0] if len(set(xs)) == 1 else None for xs in zip(*fwds)] + keep = [f is None for f in in_fwd] + jaxprs = [pe.prune_closed_jaxpr_outputs(jaxpr, keep) for jaxpr in jaxprs] + joined_effects = core.join_effects(*(jaxpr.effects for jaxpr in jaxprs)) disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(joined_effects) if disallowed_effects: raise NotImplementedError( f'Effects not supported in `switch`: {disallowed_effects}') - out = cond_p.bind(index, *consts, *ops, branches=tuple(jaxprs)) - return tree_unflatten(out_trees[0], out) + jaxprs = [replace_jaxpr_effects(jaxpr, joined_effects) for jaxpr in jaxprs] + params = dict(branches=tuple(jaxprs)) + if branches_platforms is not None: + params["branches_platforms"] = branches_platforms + out = cond_p.bind(index, *consts, *args, **params) + out_ = iter(out) + all_inputs = [*consts, *args] + out = [ + next(out_) if fwd is None else lax.asarray(all_inputs[fwd]) + for fwd in in_fwd + ] + assert next(out_, None) is None + return out_avalss[0].update(out).unflatten() -def _cond(pred, true_fun: Callable, false_fun: Callable, *operands, +@partial(api_boundary, repro_api_name="jax_cond") +def cond(pred, true_fun: Callable, false_fun: Callable, *operands, operand=_no_operand_sentinel): """Conditionally apply ``true_fun`` or ``false_fun``. Wraps XLA's `Conditional - `_ + `_ operator. Provided arguments are correctly typed, ``cond()`` has equivalent @@ -200,7 +226,12 @@ def cond(pred, true_fun, false_fun, *operands): pytree (nested Python tuple/list/dict) thereof. """ if not (callable(true_fun) and callable(false_fun)): - raise TypeError("lax.cond: true_fun and false_fun arguments should be callable.") + # try falling back to the old, deprecated version of `cond` + if callable(false_fun) and len(operands) == 2 and callable(operands[1]): + x_true, f_true, x_false, f_false = true_fun, false_fun, *operands + return cond(pred, lambda x, _: f_true(x), lambda _, x: f_false(x), x_true, x_false) + else: + raise TypeError("lax.cond: true_fun and false_fun arguments should be callable.") if operand is not _no_operand_sentinel: if operands: raise TypeError("if 'operand' keyword is passed then no positional " @@ -236,30 +267,35 @@ def cond(pred, true_fun, false_fun, *operands): else: return false_fun(*operands) - ops, ops_tree = tree_flatten(operands) - ops_avals = tuple(map(core.get_aval, ops)) - - dbg_true_fun = api_util.debug_info("cond", true_fun, operands, {}) + args = FlatTree.flatten((operands, {})) + avals = args.map(core.get_aval) + avals = avals.map2( + lambda a, x: core.AvalQDD(a, cur_qdd(x)) if a.has_qdd else a, + args) + dbg_true = api_util.debug_info("cond", true_fun, operands, {}) if config.mutable_array_checks.value: - api_util._check_no_aliased_ref_args(dbg_true_fun, ops_avals, ops) - dbg_false_fun = api_util.debug_info("cond", false_fun, operands, {}) - jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts( - (true_fun, false_fun), ops_tree, ops_avals, - [dbg_true_fun, dbg_false_fun]) - true_jaxpr, false_jaxpr = jaxprs + api_util.check_no_aliased_ref_args(lambda: dbg_true, list(avals), list(args)) + dbg_false = api_util.debug_info("cond", false_fun, operands, {}) + + true_jaxpr_, out_avals = pe.trace_to_jaxpr(true_fun, avals, dbg_true) + true_jaxpr_, true_consts = pe.separate_consts(true_jaxpr_) + false_jaxpr_, false_out_avals = pe.trace_to_jaxpr(false_fun, avals, dbg_false) + false_jaxpr_, false_consts = pe.separate_consts(false_jaxpr_) + (true_jaxpr, false_jaxpr), consts = _merge_common_consts( + (true_jaxpr_, false_jaxpr_), (true_consts, false_consts)) if config.mutable_array_checks.value: - api_util._check_no_aliased_closed_over_refs(dbg_true_fun, (*true_jaxpr.consts, *consts), ops) + api_util._check_no_aliased_closed_over_refs( + dbg_true, (*true_jaxpr.consts, *consts), list(args)) - out_tree, false_out_tree = out_trees if any(isinstance(out_aval, AbstractRef) for out_aval in true_jaxpr.out_avals + false_jaxpr.out_avals): raise ValueError("Cannot return `Ref`s from `cond`.") - _check_tree_and_avals("true_fun output", - out_tree, true_jaxpr.out_avals, - "false_fun output", - false_out_tree, false_jaxpr.out_avals) - # prune passhtrough outputs + _check_branch_outputs( + 'cond', 'true_fun', 'false_fun', + true_fun, false_fun, out_avals, false_out_avals) + + # prune passthrough outputs true_fwds = pe._jaxpr_forwarding(true_jaxpr.jaxpr) false_fwds = pe._jaxpr_forwarding(false_jaxpr.jaxpr) in_fwd = [i if i == j else None for i, j in zip(true_fwds, false_fwds)] @@ -277,58 +313,99 @@ def cond(pred, true_fun, false_fun, *operands): false_jaxpr = replace_jaxpr_effects(false_jaxpr, joined_effects) true_jaxpr = replace_jaxpr_effects(true_jaxpr, joined_effects) - out = cond_p.bind(index, *consts, *ops, branches=(false_jaxpr, true_jaxpr)) - num_consts = len(consts) + out = cond_p.bind(index, *consts, *args, branches=(false_jaxpr, true_jaxpr)) out_ = iter(out) - all_inputs = [*consts, *ops] + all_inputs = [*consts, *args] out = [ next(out_) if fwd is None else lax.asarray(all_inputs[fwd]) for fwd in in_fwd ] assert next(out_, None) is None - return tree_unflatten(out_tree, out) + return out_avals.update(out).unflatten() -@api_boundary -@functools.wraps(_cond) -def cond(*args, **kwargs): - # detect an attempt to call the former, deprecated cond +def _check_branch_outputs( + api_name, name1, name2, f1, f2, out_avals1, out_avals2) -> None: + info1 = api_util.fun_sourceinfo(f1) + info2 = api_util.fun_sourceinfo(f2) try: - ba = inspect.signature(_cond_with_per_branch_args).bind(*args, **kwargs) - except TypeError: - pass + outs1 = out_avals1.unflatten() + except: + paths = [None] * len(out_avals1) + component = lambda _: '' else: - assert not ba.kwargs # no catch-all **kwargs in _cond_with_per_branch - _, true_operand, true_fun, false_operand, false_fun = ba.args - if callable(true_operand) and callable(true_fun): - # treat this as modern cond (with two operands) - return _cond(*args, **kwargs) - if callable(true_fun) and callable(false_fun): - return _cond_with_per_branch_args(*ba.args) - - return _cond(*args, **kwargs) - -def _cond_with_per_branch_args(pred, - true_operand, true_fun: Callable, - false_operand, false_fun: Callable): - """Conditionally apply ``true_fun`` or ``false_fun``. + leaves_and_paths, _ = tree_flatten_with_path(outs1) + paths, _ = unzip2(leaves_and_paths) # type: ignore + component = lambda p: f' at path {keystr(p)}' if p else '' + + if out_avals1.tree != out_avals2.tree: + diffs = [f'{name1} output{component(p)} is a {thing1} but ' + f'{name2} output{component(p)} is a {thing2}, so {expl}' + for p, thing1, thing2, expl + in equality_errors_pytreedef(out_avals1.tree, out_avals2.tree)] + + if len(diffs) == 0: + return # the trees may have different aux data, but structures are same + elif len(diffs) == 1: + differences = f'{diffs[0]}.\n' + else: + differences = ('\n'.join(f' * {d};\n' for d in diffs[:-1]) + + f' * {diffs[-1]}.\n') - Has equivalent semantics to this Python implementation:: + raise TypeError( + f'{api_name} branch outputs must have the same pytree structure, but ' + 'they differ:\n\n' + f'{name1} is {info1}\n' + f'{name2} is {info2}\n\n' + f'{differences}\n' + f'Revise {name1} and/or {name2} so that they have the same pytree ' + 'structure.') + + if not all(map(core.typematch, out_avals1, out_avals2)): + diffs = [f'the output of {name1}{component(p)} has type {a1.str_short()}' + f' but the corresponding output of {name2} has type ' + f'{a2.str_short()}{core.aval_mismatch_extra(a1, a2)}' + for p, a1, a2 in zip(paths, out_avals1, out_avals2) + if not core.typematch(a1, a2)] + if len(diffs) == 0: + return # seems unreachable but in any case we don't have a good error msg + elif len(diffs) == 1: + differences = f'{_capitalize(diffs[0])}.\n' + else: + differences = ('\n'.join(f' * {d};' for d in diffs[:-1]) + + f'\n * {diffs[-1]}.\n') + + pvary_applications = [ + f"applying `jax.lax.pcast(..., {tuple(a1.vma - a2.vma)}, to='varying')` " + f"to the output of {n}{component(p)}" + for p, aval1, aval2 in zip(paths, out_avals1, out_avals2) + for n, a1, a2 in [(name1, aval2, aval1), (name2, aval1, aval2)] + if not core.typematch(a1, a2) and + isinstance(a1, core.ShapedArray) and isinstance(a2, core.ShapedArray) + and a1.vma != a2.vma and a2.vma - a1.vma] + + if not pvary_applications: + pvary_msg = '' + elif len(pvary_applications) == 1: + pvary_msg = f'This might be fixed by {pvary_applications[0]}.\n' + else: + pvary_msg = ('This might be fixed by:\n' + + '\n'.join(f' * {d};' for d in pvary_applications[:-1]) + + f'\n * {pvary_applications[-1]}.\n') + if pvary_msg: + pvary_msg += ("See https://docs.jax.dev/en/latest/notebooks/shard_map.html#scan-vma " + "for more information.\n\n") - def cond(pred, true_operand, true_fun, false_operand, false_fun): - if pred: - return true_fun(true_operand) - else: - return false_fun(false_operand) + raise TypeError( + f'{api_name} branches must have equal output types but they differ.\n\n' + f'{name1} is {info1}\n' + f'{name2} is {info2}\n\n' + f'{differences}\n' + f'{pvary_msg}' + f'Revise {name1} and/or {name2} so that all output types match.') - Pred has to be a scalar type, collection types (list, tuple) are not supported - """ - if not (callable(true_fun) and callable(false_fun)): - raise TypeError("lax.cond: true_fun and false_fun arguments should be callable.") - return _cond(pred, - lambda op: true_fun(op[0]), - lambda op: false_fun(op[1]), - (true_operand, false_operand)) + +def _capitalize(s): + # s.capitalize() converts s[1:] to lowercase which we don't want. + return s[0].capitalize() + s[1:] def _join_cond_effects(branches: Sequence[core.ClosedJaxpr]) -> effects.Effects: joined_effects = set() @@ -347,6 +424,15 @@ def _cond_abstract_eval(*avals: core.AbstractValue, if disallowed_effects: raise NotImplementedError( f'Effects not supported in `cond`: {disallowed_effects}') + b0_vma = [o.vma for o in branches[0].out_avals] + for branch in branches[1:]: + b_vma = [o.vma for o in branch.out_avals] + if b0_vma != b_vma: + raise Exception("The branches of cond produced mismatched varying manual " + f"axes. Got {b0_vma} and {b_vma}. Please open an issue " + "at https://github.com/jax-ml/jax/issues, and as a " + "temporary workaround pass the check_vma=False argument " + "to `jax.shard_map`") return branches[0].out_avals, joined_effects def _bcast_select(pred, on_true, on_false): @@ -361,7 +447,7 @@ def _bcast_select_n(pred, *cases): pred = lax.broadcast_in_dim(pred, np.shape(cases[0]), idx) return lax.select_n(pred, *cases) -def _cond_batching_rule(axis_data, args, dims, branches): +def _cond_batching_rule(axis_data, args, dims, *, branches, **params): index, *ops = args index_dim, *op_dims = dims # TODO(sharadmv): clean this up by adding a specific blocklist @@ -375,6 +461,11 @@ def _cond_batching_rule(axis_data, args, dims, branches): raise NotImplementedError( "IO effect not supported in vmap-of-cond.") + if "branches_platforms" in params and (index_dim is not batching.not_mapped): + # If we end up with a mapped index for a platform_dependent cond, we can + # replace the index with a fresh call to platform_index. See #29329. + index = platform_index_p.bind(platforms=params["branches_platforms"]) + index_dim = batching.not_mapped if index_dim is not batching.not_mapped: # Convert to a lax.select. While we could get away with not broadcasting @@ -383,7 +474,10 @@ def _cond_batching_rule(axis_data, args, dims, branches): # optimizations to XLA. # TODO(mattjj,frostig): assumes branches are side-effect-free, revise! index, *ops = ( - batching.bdim_at_front(x, d, axis_data.size) for x, d in zip(args, dims)) + batching.bdim_at_front(x, d, axis_data.size, + mesh_axis=axis_data.explicit_mesh_axis) + for x, d in zip(args, dims) + ) in_batched = [True] * len(branches[0].in_avals) out_batched = [True] * len(branches[0].out_avals) @@ -415,10 +509,51 @@ def _cond_batching_rule(axis_data, args, dims, branches): for jaxpr in branches) out_dims = [0 if b else batching.not_mapped for b in out_bat] - out = cond_p.bind(index, *ops, branches=branches_batched) + out = cond_p.bind(index, *ops, branches=branches_batched, + **params) return out, out_dims -def _cond_jvp(primals, tangents, branches): +def _cond_linearize(nzs, *primals_in, branches, **params): + idx_nz, *nzs = nzs + assert not idx_nz + nzs_out = [ad.linearize_jaxpr(jaxpr, nzs, allow_fwds=False)[2] + for jaxpr in branches] + nzs_out = map(any, zip(*nzs_out)) + primal_jaxprs, tangent_jaxprs, branch_res_avals = [], [], [] + for jaxpr in branches: + primal_jaxpr, num_res_out, _, _, tangent_jaxpr = \ + ad.linearize_jaxpr(jaxpr, nzs, instantiate=nzs_out, allow_fwds=False) + res_avals = primal_jaxpr.out_avals[len(primal_jaxpr.out_avals)-num_res_out:] + primal_jaxprs.append(primal_jaxpr) + tangent_jaxprs.append(tangent_jaxpr) + branch_res_avals.append(res_avals) + + all_res_avals, res_avals_per_branch = _merge_branch_residuals(branch_res_avals) + num_res = len(all_res_avals) + primal_jaxprs = _join_cond_outputs( + primal_jaxprs, all_res_avals, res_avals_per_branch, len(nzs_out)) + tangent_jaxprs = _join_cond_pe_staged_jaxpr_inputs( + tangent_jaxprs, all_res_avals, res_avals_per_branch) + tangent_avals_out = [a.to_tangent_aval() for a in jaxpr.out_avals] + + primals_res_out = cond_p.bind(*primals_in, branches=primal_jaxprs, **params) + primals, res = split_list(primals_res_out, [len(nzs_out)]) + + def tangent_fun(res, *tangents_in): + nz_tangents_in = [t for t in tangents_in if not isinstance(t, ad.Zero)] + nz_tangents_out = cond_p.bind(*res, *nz_tangents_in, + branches=tangent_jaxprs, **params) + nz_tangents_out_ = iter(nz_tangents_out) + tangents_out = [next(nz_tangents_out_) if nz else ad.Zero(aval) + for (aval, nz) in zip(tangent_avals_out, nzs_out)] + assert next(nz_tangents_out_, None) is None + return tangents_out + + idx, *_ = primals_in + return primals, nzs_out, [idx, *res], tangent_fun + + +def _cond_jvp(primals, tangents, *, branches, **params): nonzeros = [type(t) is not ad_util.Zero for t in tangents] index_nz, *ops_nz = nonzeros @@ -435,15 +570,17 @@ def _cond_jvp(primals, tangents, branches): _, *ops_dot = tangents ops_dot = _prune_zeros(ops_dot) - out = cond_p.bind(index, *ops, *ops_dot, branches=branches_jvp) + out = cond_p.bind(index, *ops, *ops_dot, branches=branches_jvp, + **params) out_primals, out_tangents = split_list(out, [len(out_nz)]) out_tangents_iter = iter(out_tangents) - out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_primal_value(p) + out_tangents = [next(out_tangents_iter) if nz else + ad_util.Zero.from_primal_value(p) for p, nz in zip(out_primals, out_nz)] return out_primals, out_tangents -def _cond_partial_eval(trace, *tracers, branches): - in_unknowns = [t.pval[0] is not None for t in tracers] +def _cond_partial_eval(trace, *tracers, branches, **params): + in_unknowns = [not t.pval.is_known() for t in tracers] index_uk, *ops_uk = in_unknowns if any(isinstance(eff, RefEffect) for branch in branches for eff in branch.jaxpr.effects): @@ -453,7 +590,7 @@ def _cond_partial_eval(trace, *tracers, branches): if index_uk: # When the branch index is unknown, we stage out the whole cond. # TODO(mattjj): remove this path when old remat is removed - params = dict(branches=branches) + params = dict(branches=branches, **params) return trace.default_process_primitive(cond_p, tracers, params) branches_out_uks = [] @@ -483,7 +620,8 @@ def _cond_partial_eval(trace, *tracers, branches): for j in branches_known[1:]) in_consts = [t.pval.get_known() for t in tracers if t.pval.is_known()] - out_consts_res = cond_p.bind(*in_consts, branches=branches_known) + out_consts_res = cond_p.bind(*in_consts, branches=branches_known, + **params) out_consts, res = split_list(out_consts_res, [len(out_consts_res) - num_res]) index_tracer = trace.instantiate_const(tracers[0]) @@ -492,11 +630,11 @@ def _cond_partial_eval(trace, *tracers, branches): res_tracers = map(trace.new_instantiated_const, res) out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None) for aval in branches_unknown[0].out_avals] - params = dict(branches=branches_unknown) + params = dict(branches=branches_unknown, **params) name_stack = source_info_util.current_name_stack()[len(trace.name_stack):] source = source_info_util.current().replace(name_stack=name_stack) eqn = pe.new_eqn_recipe( - [index_tracer] + res_tracers + ops_tracers, out_tracers, cond_p, params, + trace, [index_tracer] + res_tracers + ops_tracers, out_tracers, cond_p, params, core.join_effects(*(j.effects for j in branches_unknown)), source) for t in out_tracers: t.recipe = eqn return util.merge_lists(out_uks, out_consts, out_tracers) @@ -505,6 +643,7 @@ def _cond_partial_eval(trace, *tracers, branches): def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn): index_uk, *ops_uk = unks_in branches = eqn.params['branches'] + eqn_rest_params = dict(k_v for k_v in eqn.params.items() if k_v[0] != 'branches') # Instantiate all inputs (b/c jaxpr_staged will take all inputs). new_inst = [x for x, inst in zip(eqn.invars, inst_in) @@ -555,13 +694,12 @@ def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn): for j in branches_known[1:]) # Create residual variables. - newvar = core.gensym() - res_binders = map(newvar, all_res_avals) + res_binders = map(core.Var, all_res_avals) # Build the known eqn. ins_known, _ = partition_list(unks_in, eqn.invars) # includes index invar out_binders_known, _ = partition_list(unks_out, eqn.outvars) - params_known = dict(branches=branches_known) + params_known = dict(branches=branches_known, **eqn_rest_params) effects_known = _join_cond_effects(branches_known) eqn_known = pe.new_jaxpr_eqn( ins_known, [*out_binders_known, *res_binders], cond_p, params_known, @@ -569,7 +707,7 @@ def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn): # Build the staged eqn. _, out_binders_staged = partition_list(inst_out, eqn.outvars) - params_staged = dict(branches=branches_staged) + params_staged = dict(branches=branches_staged, **eqn_rest_params) effects_staged = _join_cond_effects(branches_staged) eqn_staged = pe.new_jaxpr_eqn( [eqn.invars[0], *res_binders, *eqn.invars[1:]], out_binders_staged, @@ -620,7 +758,7 @@ def enumerate_equal(xs): # residual outputs that it does not populate. def _join_cond_outputs(jaxprs: Sequence[core.ClosedJaxpr], all_res_avals, res_aval_indices_per_jaxpr, - num_non_res_outputs): + num_non_res_outputs) -> tuple[core.ClosedJaxpr, ...]: def augment_jaxpr(jaxpr: core.ClosedJaxpr, res_indices): def f_aug(*args): @@ -638,11 +776,10 @@ def f_aug(*args): # This function augments branch inputs to agree with the merged residual format: # each branch is made to accept all residuals, even though it will ignore those # that it does not read. -def _join_cond_pe_staged_jaxpr_inputs(jaxprs: Sequence[core.ClosedJaxpr], - all_res_avals, - res_aval_indices_per_jaxpr): - newvar = core.gensym(suffix='_') - all_res_vars = map(newvar, all_res_avals) +def _join_cond_pe_staged_jaxpr_inputs( + jaxprs: Sequence[core.ClosedJaxpr], all_res_avals, + res_aval_indices_per_jaxpr) -> tuple[core.ClosedJaxpr, ...]: + all_res_vars = map(core.Var, all_res_avals) def augment_jaxpr(jaxpr: core.ClosedJaxpr, res_indices) -> core.ClosedJaxpr: num_res = len(res_indices) @@ -699,51 +836,45 @@ def _cond_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn, return [True, *used_inputs], new_eqn -def _transpose_cond_jaxpr(jaxpr: core.ClosedJaxpr, - num_res: int): - res_avals, primal_avals = split_list(jaxpr.in_avals, [num_res]) - - def transposed(*args): - res, cts_out = split_list(args, [num_res]) - primals = res + [ad.UndefinedPrimal(aval) for aval in primal_avals] - cts_in = ad.backward_pass( - jaxpr.jaxpr, False, jaxpr.consts, primals, cts_out) - _, cts_in = split_list(cts_in, [num_res]) - return map(ad.instantiate_zeros, cts_in) - - return _make_closed_jaxpr(lu.wrap_init(transposed, - debug_info=jaxpr.jaxpr.debug_info), - res_avals + jaxpr.out_avals) - -def _cond_transpose(cts, *args, branches): - index, *ops = args - assert type(index) is not ad.UndefinedPrimal - linear = [type(x) is ad.UndefinedPrimal for x in ops] - in_avals = branches[0].in_avals - num_res = len(ops) - sum(linear) - if any(isinstance(eff, RefEffect) for branch in branches for eff in - branch.jaxpr.effects): - raise NotImplementedError("State effect not supported in cond transpose.") - - branches_trans = tuple( - _transpose_cond_jaxpr(jaxpr, num_res) for jaxpr in branches) - lin_in_avals = [a.strip_weak_type() for a, l in zip(in_avals, linear) if l] - assert all(core.typematch(out_aval, lin_in_aval) - for jaxpr in branches_trans - for out_aval, lin_in_aval in zip(jaxpr.out_avals, lin_in_avals)) - - res = ops[:num_res] - cts = map(ad.instantiate_zeros, cts) - - out = cond_p.bind(index, *res, *cts, branches=branches_trans) - assert all(map(core.typecheck, lin_in_avals, out)) - - out_iter = iter(out) - out = [next(out_iter) if l else None for l in linear] - assert next(out_iter, None) is None - return [None] + out - -def _cond_typecheck(bind_time, *in_atoms, branches): +def _cond_transpose_fancy(cts_in, index, *args, branches, **params): + assert not isinstance(index, ad.GradAccum) + primals_ctrefs, specs = ad.project_accums(args) + in_flat, in_tree = tree_flatten((primals_ctrefs, cts_in)) + in_avals = tuple(core.AvalQDD(a, cur_qdd(x)) if (a := typeof(x)).has_qdd # type: ignore + else a for x in in_flat) + trans_branches, out_trees = unzip2( + _transpose_jaxpr_fancy(j, in_tree, in_avals, specs, (False,) * len(args)) + for j in branches) + out_nzs = [[not isinstance(x, ad.Zero) for x in tree_unflatten(t, j.out_avals)] + for t, j in zip(out_trees, trans_branches)] + out_nz = tuple(map(partial(functools.reduce, operator.or_), zip(*out_nzs))) + trans_branches, out_trees = unzip2( + _transpose_jaxpr_fancy(j, in_tree, in_avals, specs, out_nz) for j in branches) + out_tree, = set(out_trees) + cts_out = cond_p.bind(index, *in_flat, branches=(*trans_branches,), **params) + for x, ct in zip(args, tree_unflatten(out_tree, cts_out)): + if isinstance(x, ad.ValAccum): x.accum(ct) + +@util.weakref_lru_cache +def _transpose_jaxpr_fancy(jaxpr, in_tree, in_avals, specs, inst_out): + cell = lambda: None + maybe_inst = lambda x, inst: ad.instantiate_zeros(x) if inst else x + def transposed(*in_flat): + primals_ctrefs, cts_in = tree_unflatten(in_tree, in_flat) + args = ad.unproject_accums(specs, primals_ctrefs) + ad.backward_pass3(jaxpr.jaxpr, False, jaxpr.consts, args, cts_in) + cts_out = [maybe_inst(x.freeze(), inst) if isinstance(x, ad.ValAccum) + else None for x, inst in zip(args, inst_out)] + cts_out, cell.out_tree = tree_flatten(cts_out) # type: ignore + return cts_out + dbg = jaxpr.jaxpr.debug_info.with_unknown_names() + trans_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(transposed, debug_info=dbg), in_avals) + return core.ClosedJaxpr(trans_jaxpr, consts), cell.out_tree # type: ignore + + +def _cond_typecheck(bind_time, *in_atoms, branches, **params): + del params if not bind_time: _, *in_atoms = in_atoms avals = [x.aval for x in in_atoms] @@ -797,22 +928,97 @@ def _cond_typecheck(bind_time, *in_atoms, branches): f'called with operands of type {_avals_short(op_avals)}') return jaxpr0.out_avals, joined_effects + +BranchesPlatforms = tuple[tuple[str, ...] | None, ...] +# cond_p takes an optional branches_platforms param of type `BranchesPlatforms` +# when it is a `platform_dependent` conditional. +# In that case, `branches_platforms` is a tuple as long +# as `branches` and for each branch it specifies the lowering platforms it +# corresponds to. The last element, corresponding to the last branch, +# can be `None` to represent a default match-all-lowering-platforms. +# The index argument of a `platform_dependent` cond is always a +# `platform_index` primitive. cond_p = core.Primitive('cond') cond_p.multiple_results = True cond_p.skip_canonicalization = True cond_p.def_impl(partial(dispatch.apply_primitive, cond_p)) cond_p.def_effectful_abstract_eval(_cond_abstract_eval) ad.primitive_jvps[cond_p] = _cond_jvp -ad.primitive_transposes[cond_p] = _cond_transpose +ad.primitive_linearizations[cond_p] = _cond_linearize +ad.fancy_transposes[cond_p] = _cond_transpose_fancy pe.custom_partial_eval_rules[cond_p] = _cond_partial_eval batching.fancy_primitive_batchers[cond_p] = _cond_batching_rule -xla.register_initial_style_primitive(cond_p) +pxla.register_initial_style_primitive(cond_p) core.custom_typechecks[cond_p] = partial(_cond_typecheck, False) pe.partial_eval_jaxpr_custom_rules[cond_p] = _cond_partial_eval_custom pe.dce_rules[cond_p] = _cond_dce_rule -batching.ragged_prop_rules[cond_p] = batching.ragged_mask_assert_no_op_rule -def _cond_lowering(ctx, index, *args, branches): +def _cond_is_high(*_, branches, **__) -> bool: + return any(j.jaxpr.is_high for j in branches) +cond_p.is_high = _cond_is_high # type: ignore + +def _cond_to_lojax(pred, *hi_args, branches): + jaxpr = branches[0] + lo_branches = tuple(pe.lower_jaxpr(j) for j in branches) + lo_args = [lo_val for aval, x in zip(branches[0].in_aval_qdds, hi_args) + for lo_val in (aval.read_loval(x) if aval.has_qdd + else aval.lower_val(x))] + all_outs = cond_p.bind(pred, *lo_args, branches=lo_branches) + lo_muts_out = sum(len(aval.lo_ty()) for aval in branches[0].final_aval_qdds if aval.has_qdd) + out_mut, lo_outs = split_list(all_outs, [lo_muts_out]) + + # collect and apply mutations + out_mut_ = iter(out_mut) + in_idx = {v: i for i, v in enumerate(jaxpr.jaxpr.invars)} + + for v in jaxpr.jaxpr.invars: + if v.final_qdd is not None: + qdd = v.final_qdd + lo_vals = itertools.islice(out_mut_, len(v.aval.lo_ty_qdd(qdd))) + v.aval.update_from_loval(qdd, hi_args[in_idx[v]], *lo_vals) + + lo_outs_ = iter(lo_outs) + + hi_outs = [t.raise_val(*itertools.islice(lo_outs_, len(t.lo_ty()))) + for t in jaxpr.out_avals] + assert next(lo_outs_, None) is None + return hi_outs + +cond_p.to_lojax = _cond_to_lojax + +def _cond_lowering(ctx, index, *args, branches, + **params): + if (branches_platforms := params.get("branches_platforms", None)) is not None: + branches_kept: list[core.ClosedJaxpr] = [] + index_to_kept_index: dict[int, int] = {} + for p in mlir._platforms_for_eqn(ctx): + # Each `p` must appear in exactly one branches_platforms, or in the + # last default branch. Otherwise, platform_index lowering would have + # failed already. + for b_idx, b_platforms in enumerate(branches_platforms): + if b_platforms is None or p in b_platforms: + if b_idx not in index_to_kept_index: + index_to_kept_index[b_idx] = len(branches_kept) + branches_kept.append(branches[b_idx]) + break + else: + assert False, p + + # Compute the new index into branches_keep + i32_type = ir.RankedTensorType.get([], mlir.dtype_to_ir_type(dtypes.dtype(np.int32))) + kept_index_case_op = hlo.CaseOp([i32_type], + index=index, + num_branches=len(branches)) + for i in range(len(branches)): + branch = kept_index_case_op.regions[i].blocks.append() + with ir.InsertionPoint(branch): + kept_i = np.int32(index_to_kept_index.get(i, 0)) + hlo.return_([mlir.ir_constant(kept_i)]) + + index = kept_index_case_op + branches = branches_kept + assert branches, "platform_index lowering should have failed first" + joined_effects = core.join_effects(*(branch.effects for branch in branches)) ordered_effects = list(effects.ordered_effects.filter_in(joined_effects)) num_tokens = len(ordered_effects) @@ -831,11 +1037,14 @@ def _cond_lowering(ctx, index, *args, branches): for i, jaxpr in enumerate(branches): branch = case_op.regions[i].blocks.append() with ir.InsertionPoint(branch): - consts = [mlir.ir_constant(xla.canonicalize_dtype(x)) for x in jaxpr.consts] + consts = [ + mlir.ir_constant(x, aval=var.aval) + for x, var in zip(jaxpr.consts, jaxpr.jaxpr.constvars) + ] out_vals, tokens_out = mlir.jaxpr_subcomp( ctx.module_context, jaxpr.jaxpr, name_stack.extend(f'branch_{i}_fun'), tokens_in, consts, *args, - dim_var_values=ctx.dim_var_values) + dim_var_values=ctx.dim_var_values, const_lowering=ctx.const_lowering) out_tokens = [tokens_out.get(eff) for eff in ordered_effects] out_vals = [*out_tokens, *out_vals] hlo.return_(mlir.flatten_ir_values(out_vals)) @@ -849,36 +1058,33 @@ def _cond_lowering(ctx, index, *args, branches): mlir.register_lowering(cond_p, _cond_lowering) @register_partial_discharge_rule(cond_p) -def _cond_state_discharge_rule(should_discharge, in_avals, out_avals, index, *args, branches): +def _cond_state_discharge_rule(should_discharge, in_avals, out_avals, index, *args, + branches, **params): assert not should_discharge[0], "Can't discharge the index." - discharged_branches = tuple( - discharge_state(branch.jaxpr, (), should_discharge=should_discharge[1:])[0] - for branch in branches - ) + discharged_branches, discharged_consts = unzip2( + discharge_state(branch.jaxpr, branch.consts, should_discharge=should_discharge[1:]) + for branch in branches) # Don't thread the ref values through the cond if they never change. forwarded_outvars = None for branch in discharged_branches: invar_pos = {v: i for i, v in enumerate(branch.invars)} branch_forwarding = [ invar_pos.get(v, None) if isinstance(v, core.Var) else None - for v in branch.outvars[len(out_avals) :] - ] + for v in branch.outvars[len(out_avals) :]] if forwarded_outvars is None: forwarded_outvars = branch_forwarding else: forwarded_outvars = [ i if i == j else None - for i, j in zip(forwarded_outvars, branch_forwarding) - ] + for i, j in zip(forwarded_outvars, branch_forwarding)] assert forwarded_outvars is not None all_outvars_fwd = [None] * len(out_avals) + forwarded_outvars - new_branches = tuple( - core.ClosedJaxpr( + new_branches = tuple(core.ClosedJaxpr( branch.replace(outvars=[v for v, fwd in zip(branch.outvars, all_outvars_fwd) - if fwd is None]), ()) - for branch in discharged_branches - ) - out_vals_no_fwd = cond_p.bind(index, *args, branches=new_branches) + if fwd is None]), consts) + for branch, consts in zip(discharged_branches, discharged_consts)) + out_vals_no_fwd = cond_p.bind(index, *args, branches=new_branches, + **params) out_vals, out_ref_vals_no_fwd = util.split_list(out_vals_no_fwd, [len(out_avals)]) # Insert forwarded values into reference outputs ref_val_no_fwd_iter = iter(out_ref_vals_no_fwd) @@ -943,50 +1149,41 @@ def other_platforms_code(*args): ... The value ``per_platform[execution_platform](*args)``. """ # Join identical branches - platform_branches: list[tuple[list[str], Callable]] = [] + branches_platforms_list: list[tuple[list[str], Callable]] = [] for pname, pbranch in per_platform.items(): + if not callable(pbranch): + raise TypeError(f"lax.platform_dependent: the '{pname}' branch must " + "be a callable.") if pname == "gpu": raise ValueError("Use 'cuda' or 'rocm' for lax.platform_dependent.") - for ps, b in platform_branches: + for ps, b in branches_platforms_list: if b == pbranch: ps.append(pname) break else: - platform_branches.append(([pname], pbranch)) - - platforms_lists, branches = util.unzip2(platform_branches) - platform_index = platform_index_p.bind( - platforms=tuple(tuple(ps) for ps in platforms_lists), - has_default=(default is not None)) + branches_platforms_list.append(([pname], pbranch)) + platforms_lists, branches = util.unzip2(branches_platforms_list) + branches_platforms: BranchesPlatforms = tuple(tuple(ps) for ps in platforms_lists) if default is not None: + if not callable(default): + raise TypeError("lax.platform_dependent: the 'default' branch must " + "be a callable.") branches = branches + (default,) - # Use a switch, to get the proper transformation rules for free. Since - # platform index has no dependence on the input data, it won't be vectorized - # under vmap. - # If the switch and the platform_index_p above are in the same compilation - # unit then constant-folding will remove the unnecessary branches. However, - # if we run in eager mode the switch below cannot be constant-folded and - # the compilation may fail if some of the branches contain custom calls not - # recognized on the compilation platform. Detect eager mode and keep only the - # needed branch. - try: - # Note/TODO(mvoz): This actually rarely seems to concretize - we could look into - # core.ensure_compile_time_eval to get better single-branch selection. - platform_index_concrete = core.concrete_or_error(operator.index, platform_index) - except core.ConcretizationTypeError: - return switch(platform_index, branches, *args) - else: - assert 0 <= platform_index_concrete < len(branches) - return branches[platform_index_concrete](*args) + branches_platforms = branches_platforms + (None,) # type: ignore + platform_index = platform_index_p.bind(platforms=branches_platforms) + + if core.is_concrete(platform_index): + return branches[int(platform_index)](*args) + return _switch_internal(platform_index, branches, args, + branches_platforms=branches_platforms) + # A primitive to compute the index of a platform into a list of platforms. # Args: -# platforms: Sequence[Sequence[str]]: a sequence of sequences of platform -# names. If the current lowering platform is in one of the inner sequences -# returns the index of that inner sequence in the outer sequence. -# has_default: if True, and if the lowering platform is not found in -# `platforms` then return `len(platforms)`. Otherwise, raise an error. +# platforms: BranchesPlatforms. If the current lowering +# platform is in one of the inner tuples returns the index of that inner +# tuple in the outer tuple. platform_index_p = core.Primitive("platform_index") platform_index_p.multiple_results = False platform_index_p.def_impl(functools.partial(dispatch.apply_primitive, @@ -998,25 +1195,25 @@ def _platform_index_aval(*_, **__): def _platform_index_lowering(ctx: mlir.LoweringRuleContext, *, - platforms: Sequence[Sequence[str]], - has_default: bool): - def lower_constant( - ctx: mlir.LoweringRuleContext, *, i: int - ) -> Sequence[ir.Value]: + platforms: BranchesPlatforms): + def lower_constant(ctx: mlir.LoweringRuleContext, *, + i: int) -> Sequence[ir.Value]: v = mlir.ir_constant(np.int32(i)) - assert isinstance(v, ir.Value), v return [v] + platform_rules: dict[str, mlir.LoweringRule] = {} + default_rule = None for i, ps in enumerate(platforms): rule = partial(lower_constant, i=i) - for p in ps: - platform_rules[p] = rule + if ps is None: + default_rule = rule + else: + for p in ps: + platform_rules[p] = rule - default_rule = ( - partial(lower_constant, i=len(platforms)) if has_default else None) return mlir.lower_per_platform( ctx, - f"platform_index(platforms={platforms}, has_default={has_default})", + f"platform_index(platforms={platforms})", platform_rules, default_rule, effects.no_effects) mlir.register_lowering(platform_index_p, _platform_index_lowering) diff --git a/jax/_src/lax/control_flow/for_loop.py b/jax/_src/lax/control_flow/for_loop.py deleted file mode 100644 index fc7ebde4cbea..000000000000 --- a/jax/_src/lax/control_flow/for_loop.py +++ /dev/null @@ -1,777 +0,0 @@ -# Copyright 2022 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Module for the `for_loop` primitive.""" - -from __future__ import annotations - -from collections.abc import Callable, Sequence -import functools -import operator -from typing import Any, Generic, TypeVar - -from jax import lax -from jax._src import api_util -from jax._src.interpreters import ad -from jax._src.interpreters import batching -from jax._src.interpreters import mlir -from jax._src.interpreters import partial_eval as pe -from jax.tree_util import (tree_flatten, tree_structure, tree_unflatten, - treedef_tuple, tree_map, tree_leaves, PyTreeDef) - -from jax._src import ad_util -from jax._src import core -from jax._src import dispatch -from jax._src import dtypes -from jax._src import linear_util as lu -from jax._src import source_info_util -from jax._src.state.types import (ReadEffect, AbstractRef, StateEffect) -from jax._src.state import discharge as state_discharge -from jax._src.state import primitives as state_primitives -from jax._src.state import utils as state_utils -from jax._src.state import types as state_types -from jax._src.typing import Array -from jax._src.util import (partition_list, merge_lists, safe_map, safe_zip, - split_list, split_dict, weakref_lru_cache) -from jax._src.lax.control_flow import loops -from jax._src.lax.control_flow.common import _initial_style_jaxpr -import numpy as np - -## JAX utilities - -map, unsafe_map = safe_map, map -zip, unsafe_zip = safe_zip, zip - -## Helpful type aliases -S = TypeVar('S') -T = TypeVar('T') -class Ref(Generic[T]): pass - -ref_set = state_primitives.ref_set -ref_get = state_primitives.ref_get -ref_addupdate = state_primitives.ref_addupdate -discharge_state = state_discharge.discharge_state - - -## `for_loop` implementation - -for_p = core.Primitive('for') -for_p.multiple_results = True -for_p.skip_canonicalization = True - -### Tracing utilities - -def _trace_to_jaxpr_with_refs(f: Callable, state_tree: PyTreeDef, - state_avals: Sequence[core.AbstractValue], - debug_info: core.DebugInfo, - ) -> tuple[core.Jaxpr, list[Any], PyTreeDef]: - f, out_tree_thunk = api_util.flatten_fun_nokwargs( - lu.wrap_init(f, debug_info=debug_info), - treedef_tuple((tree_structure(0), state_tree))) - jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( - f, state_avals) - return jaxpr, consts, out_tree_thunk() - -def for_loop(nsteps: int | Sequence[int], - body: Callable[[Array, Ref[S]], None], init_state: S, - *, reverse: bool = False, unroll: int = 1) -> S: - """A for-loop combinator that allows read/write semantics in the loop body. - - `for_loop` is a higher-order function that enables writing loops that can be - staged out in JIT-ted JAX computations. Unlike `jax.lax.fori_loop`, it allows - mutation in its body using `Ref`s. - - `for_loop` will initialize `Ref`s with the values in `init_state`. Each - iteration, `body` will be called with the current `Ref`s, which can be read - from and written to using `ref_get` and `ref_set`. - - `for_loop` is semantically equivalent to the following Python code: - - ```python - def for_loop(nsteps, body, init_state): - refs = tree_map(make_ref, init_state) - for i in range(nsteps): - body(i, refs) - return tree_map(ref_get, refs) - ``` - - Args: - nsteps: Number of iterations - body: A callable that takes in the iteration number as its first argument - and `Ref`s corresponding to `init_state` as its second argument. - `body` is free to read from and write to its `Ref`s. `body` should - not return anything. - init_state: A Pytree of JAX-compatible values used to initialize the `Ref`s - that will be passed into the for loop body. - unroll: A positive int specifying, in the underlying operation of the - `for` primitive, how many iterations to unroll within a single iteration - of a loop. Higher values may speed up execution time at the cost of longer - compilation time. - Returns: - A Pytree of values representing the output of the for loop. - """ - if unroll < 1: - raise ValueError("`unroll` must be a positive integer.") - if isinstance(nsteps, int): - nsteps = [nsteps] - if len(nsteps) > 1: - outer_step, *rest_steps = nsteps - def wrapped_body(i, refs): - vals = tree_map(lambda ref: ref_get(ref, ()), refs) - vals = for_loop( - rest_steps, functools.partial(body, i), vals, unroll=unroll) - tree_map(lambda ref, val: ref_set(ref, (), val), refs, vals) - return for_loop(outer_step, wrapped_body, init_state, unroll=unroll) - dbg = api_util.debug_info("for_loop", body, (0, init_state), {}) - nsteps, = nsteps - flat_state, state_tree = tree_flatten(init_state) - state_avals = map(state_utils.val_to_ref_aval, flat_state) - idx_aval = core.ShapedArray((), dtypes.canonicalize_dtype(np.int64)) - jaxpr, consts, out_tree = _trace_to_jaxpr_with_refs( - body, state_tree, [idx_aval, *state_avals], dbg) - if out_tree != tree_structure(None): - raise Exception("`body` should not return anything.") - jaxpr = state_utils.hoist_consts_to_refs(jaxpr, index=1) - which_linear = (False,) * (len(consts) + len(flat_state)) - out_flat = for_p.bind(*consts, *flat_state, jaxpr=jaxpr, nsteps=int(nsteps), - reverse=reverse, which_linear=which_linear, - unroll=unroll) - # Consts are `Ref`s so they are both inputs and outputs. We remove them from - # the outputs. - out_flat = out_flat[len(consts):] - return tree_unflatten(state_tree, out_flat) - -Carry = TypeVar('Carry') -X = TypeVar('X') -Y = TypeVar('Y') - -def scan(f: Callable[[Carry, X], tuple[Carry, Y]], - init: Carry, - xs: X | None = None, - length: int | None = None, - reverse: bool = False, - unroll: int = 1) -> tuple[Carry, Y]: - if not callable(f): - raise TypeError("scan: f argument should be a callable.") - if unroll < 1: - raise ValueError("`unroll` must be a positive integer.") - xs_flat, xs_tree = tree_flatten(xs) - - try: - lengths = [x.shape[0] for x in xs_flat] - except AttributeError as err: - msg = "scan got value with no leading axis to scan over: {}." - raise ValueError( - msg.format(', '.join(str(x) for x in xs_flat - if not hasattr(x, 'shape')))) from err - - if length is not None: - length = int(length) - if not all(length == l for l in lengths): - msg = ("scan got `length` argument of {} which disagrees with " - "leading axis sizes {}.") - raise ValueError(msg.format(length, [x.shape[0] for x in xs_flat])) - else: - unique_lengths = set(lengths) - if len(unique_lengths) > 1: - msg = "scan got values with different leading axis sizes: {}." - raise ValueError(msg.format(', '.join(str(x.shape[0]) for x in xs_flat))) - elif len(unique_lengths) == 0: - msg = "scan got no values to scan over and `length` not provided." - raise ValueError(msg) - else: - length, = unique_lengths - - x_shapes = [x.shape[1:] for x in xs_flat] - x_dtypes = [dtypes.canonicalize_dtype(x.dtype) for x in xs_flat] - x_avals = tuple(map(core.ShapedArray, x_shapes, x_dtypes)) - - def _create_jaxpr(init): - init_flat = tree_leaves(init) - _, in_tree = tree_flatten((init, xs)) - dbg = api_util.debug_info("scan", f, (init, xs), {}) - carry_avals = tuple(map(core.get_aval, init_flat)) - jaxpr, _, out_tree = _initial_style_jaxpr( - f, in_tree, carry_avals + x_avals, dbg) - return jaxpr, out_tree - jaxpr, out_tree = _create_jaxpr(init) - _, ys_avals = tree_unflatten(out_tree, jaxpr.out_avals) - ys = tree_map(lambda aval: lax.full([length, *aval.shape], 0, aval.dtype), - ys_avals) - def for_body(i, refs): - carry_refs, xs_refs, ys_refs = refs - carry = tree_map(lambda x: x[()], carry_refs) - x = tree_map(lambda x: x[i], xs_refs) - carry, y = f(carry, x) - tree_map(lambda c_ref, c: ref_set(c_ref, (), c), carry_refs, carry) - tree_map(lambda y_ref, y: ref_set(y_ref, (i,), y), ys_refs, y) - assert isinstance(length, int) - api_util.save_wrapped_fun_sourceinfo(for_body, f) - init, _, ys = for_loop(length, for_body, (init, xs, ys), reverse=reverse, - unroll=unroll) - return init, ys - -@for_p.def_effectful_abstract_eval -def _for_abstract_eval(*avals, jaxpr, **__): - # Find out for each of the `Ref`s in our jaxpr what effects they have. - jaxpr_aval_effects = state_types.get_ref_state_effects( - [v.aval for v in jaxpr.invars], jaxpr.effects)[1:] - aval_effects = [{eff.replace(input_index=eff.input_index - 1) - for eff in effs} for aval, effs - in zip(avals, jaxpr_aval_effects) - if isinstance(aval, AbstractRef)] - nonlocal_state_effects = core.join_effects(*aval_effects) - return list(avals), nonlocal_state_effects - -@state_discharge.register_discharge_rule(for_p) -def _for_discharge_rule(in_avals, _, *args: Any, jaxpr: core.Jaxpr, - reverse: bool, which_linear: Sequence[bool], - nsteps: int, unroll: int - ) -> tuple[Sequence[Any | None], Sequence[Any]]: - out_vals = for_p.bind(*args, jaxpr=jaxpr, reverse=reverse, - which_linear=which_linear, nsteps=nsteps, - unroll=unroll) - new_invals = [] - for aval, out_val in zip(in_avals, out_vals): - new_invals.append(out_val if isinstance(aval, AbstractRef) else None) - return new_invals, out_vals - -def _for_impl(*args, jaxpr, nsteps, reverse, which_linear, unroll): - del which_linear - discharged_jaxpr, consts = discharge_state(jaxpr, ()) - def body(i, state): - i_ = nsteps - i - 1 if reverse else i - return core.eval_jaxpr(discharged_jaxpr, consts, i_, *state) - return _for_impl_unrolled(body, nsteps, unroll, *args) - -def _for_impl_unrolled(body, nsteps, unroll, *args): - remainder = nsteps % unroll - i = lax.full((), 0, dtypes.canonicalize_dtype(np.int64)) - state = list(args) - - for _ in range(remainder): - state = body(i, state) - i = i + 1 - - def cond(carry): - i, _ = carry - return i < nsteps - def while_body(carry): - i, state = carry - for _ in range(unroll): - state = body(i, state) - i = i + 1 - return i, state - _, state = lax.while_loop(cond, while_body, (i, state)) - return state - -mlir.register_lowering(for_p, mlir.lower_fun(_for_impl, multiple_results=True)) -for_p.def_impl(functools.partial(dispatch.apply_primitive, for_p)) - -@weakref_lru_cache -def _cached_for_jaxpr(jaxpr): - discharged_jaxpr, body_consts = discharge_state(jaxpr, ()) - return core.ClosedJaxpr(discharged_jaxpr, body_consts) - -def _for_vmap(axis_data, args, dims, *, - jaxpr, nsteps, reverse, which_linear, unroll): - init_batched = [d is not batching.not_mapped for d in dims] - closed_jaxpr = _cached_for_jaxpr(jaxpr) - batched = init_batched - for _ in range(len(batched)): - _, out_batched = batching.batch_jaxpr( - closed_jaxpr, axis_data, [False] + batched, instantiate=batched) - if out_batched == batched: - break - batched = map(operator.or_, batched, out_batched) - else: - raise Exception("Invalid fixpoint") - args = [batching.broadcast(x, axis_data.size, 0, axis_data.explicit_mesh_axis) - if now_bat and not was_bat - else batching.moveaxis(x, d, 0) if now_bat else x - for x, d, was_bat, now_bat in zip(args, dims, init_batched, batched)] - batched_jaxpr_, _ = batching.batch_jaxpr( - pe.close_jaxpr(jaxpr), axis_data, [False] + batched, []) - batched_jaxpr, () = batched_jaxpr_.jaxpr, batched_jaxpr_.consts # TODO consts - out_flat = for_p.bind(*args, jaxpr=batched_jaxpr, nsteps=nsteps, - reverse=reverse, which_linear=which_linear, - unroll=unroll) - return out_flat, [0 if b else batching.not_mapped for b in batched] -batching.fancy_primitive_batchers[for_p] = _for_vmap - -def _for_jvp(primals, tangents, *, jaxpr, nsteps, reverse, which_linear, - unroll): - nonzero_tangents = [not isinstance(t, ad_util.Zero) for t in tangents] - # We need to find out which `Ref`s have nonzero tangents after running the - # for loop. Ordinarily we do this with a fixed point on the body jaxpr but - # a `for` body jaxpr is stateful and has no outputs. We therefore discharge - # the state effect from the jaxpr and we will now have a "symmetric" jaxpr - # where the inputs line up with the outputs. We use this discharged jaxpr - # for the fixed point. - closed_jaxpr = _cached_for_jaxpr(jaxpr) - for _ in range(len(nonzero_tangents)): - _, out_nonzero_tangents = ad.jvp_jaxpr( - closed_jaxpr, - [False] + nonzero_tangents, instantiate=nonzero_tangents) - if out_nonzero_tangents == nonzero_tangents: - break - nonzero_tangents = map(operator.or_, nonzero_tangents, out_nonzero_tangents) - else: - raise Exception("Invalid fixpoint") - tangents = [ad.instantiate_zeros(t) if inst else t - for t, inst in zip(tangents, nonzero_tangents)] - tangents = [t for t in tangents if type(t) is not ad_util.Zero] - closed_jaxpr = pe.close_jaxpr(jaxpr) - jvp_jaxpr_, _ = ad.jvp_jaxpr(closed_jaxpr, [False] + nonzero_tangents, []) - jvp_jaxpr, () = jvp_jaxpr_.jaxpr, jvp_jaxpr_.consts # TODO consts - jvp_which_linear = which_linear + (True,) * len(tangents) - out_flat = for_p.bind(*primals, *tangents, jaxpr=jvp_jaxpr, - nsteps=nsteps, reverse=reverse, - which_linear=jvp_which_linear, unroll=unroll) - # `out_flat` includes constant inputs into the `for_loop` which are converted - # into outputs as well. We don't care about these in AD so we throw them out. - out_primals, out_tangents = split_list(out_flat, [len(primals)]) - out_tangents_iter = iter(out_tangents) - out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_primal_value(p) - for p, nz in zip(out_primals, nonzero_tangents)] - return out_primals, out_tangents -ad.primitive_jvps[for_p] = _for_jvp - - -def _partial_eval_jaxpr_custom(jaxpr, in_unknowns, policy): - # A simple wrapper around `pe.partial_eval_jaxpr_custom` that assumes all - # inputs are instantiated and doesn't ensure any outputs are unknown or - # instantiated. - return pe.partial_eval_jaxpr_custom( - jaxpr, in_unknowns, [True] * len(in_unknowns), False, False, policy) - -_save_everything = lambda *_, **__: True - -def _is_read_only(ref_effects: set[StateEffect]) -> bool: - assert len(ref_effects) > 0 - if len(ref_effects) > 1: - # Means we must have a write or accum effect so not read-only - return False - eff, = ref_effects - return isinstance(eff, ReadEffect) - -def _loop_invariant_outputs(jaxpr: core.Jaxpr) -> list[bool]: - # Get effects for each of the jaxpr inputs and remove the loop index. - ref_effects = state_types.get_ref_state_effects( - [v.aval for v in jaxpr.invars], jaxpr.effects)[1:] - # We first assume that *read-only `Ref`s* are loop-invariant. We can safely do - # this because the only way something can be loop-varying is if we write to it - # at some point. It's *possible* that read-write `Ref`s are loop-invariant but - # we conservatively assume they aren't. - loop_invar_refs = [_is_read_only(effs) if effs else True - for effs in ref_effects] - loop_var_refs = map(operator.not_, loop_invar_refs) - - # We'd like to detect if the outputs of the jaxpr are loop-invariant. An - # output is loop-invariant if it is downstream of only loop-invariant values - # (seeded by the read-only `Ref`s). If at any point, a loop-varying value - # interacts with a loop-invariant value, we produce a loop-varying value. We - # can use `partial_eval` to perform this analysis by treating loop-varying - # values as "unknown" and loop-invariant values as "known", since when a known - # and unknown value interact, they produce an unknown value. - loop_var_inputs = [True, *loop_var_refs] - _, _, loop_var_outputs, _, _, = _partial_eval_jaxpr_custom( - jaxpr, loop_var_inputs, _save_everything) - return map(operator.not_, loop_var_outputs) - - -def _for_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer, - jaxpr: core.Jaxpr, nsteps: int, reverse: bool, - which_linear: tuple[bool, ...], - unroll: int) -> list[pe.JaxprTracer]: - num_inputs = len(tracers) - assert num_inputs == len(jaxpr.invars) - 1 - in_unknowns = [not t.pval.is_known() for t in tracers] - # We first need to run a fixpoint to determine which of the `Ref`s are unknown - # after running the for loop. We want to use the jaxpr to determine which - # `Ref`s are unknown after executing the for loop body given which `Ref`s are - # unknown before. However, the jaxpr has no outputs. Instead, we discharge - # the body and run the fixpoint with the discharged jaxpr. We can do this - # because the outputs of the jaxpr are one-to-one with the inputs. - discharged_jaxpr, discharged_consts = discharge_state(jaxpr, ()) - discharged_jaxpr = discharged_jaxpr.replace( - invars=discharged_jaxpr.constvars + discharged_jaxpr.invars, - constvars=[]) - for _ in range(num_inputs): - jaxpr_in_unknowns = [False] * len(discharged_consts) + [False, *in_unknowns] - _, _, out_unknowns, _, _, = pe.partial_eval_jaxpr_custom( - discharged_jaxpr, jaxpr_in_unknowns, [True] * len(jaxpr_in_unknowns), - in_unknowns, False, _save_everything) - out_unknowns = list(out_unknowns) - if out_unknowns == in_unknowns: - break - in_unknowns = map(operator.or_, in_unknowns, out_unknowns) - else: - raise Exception("Invalid fixpoint") - del out_unknowns # redundant since it's the same as `in_unknowns` - tracers = tuple(trace.instantiate_const(t) if uk else t - for t, uk in zip(tracers, in_unknowns)) - - # We use `partial_eval_jaxpr_custom` here because it won't remove effectful - # primitives like `get`/`set`. - jaxpr_known_resout, jaxpr_unknown_resin_, uk_out, inst_out, num_res = \ - _partial_eval_jaxpr_custom(jaxpr, [False, *in_unknowns], - _save_everything) - # # `partial_eval_jaxpr_custom` will give us jaxprs that have hybrid `Ref` and - # regular valued input/outputs. However, we'd like to bind these jaxprs to a - # `for`, which expects only `Ref` inputs and no output. We need to convert - # both of these jaxprs into ones that are compatible with `for`. - # TODO(sharadmv,mattjj): implement "passthrough" optimization. - # TODO(sharadmv,mattjj): rematerialize loop-dependent values instead of - # passing the loop index as a residual - - # `jaxpr_known_resout` is a jaxpr that maps from all the input `Refs` - # to output residual values (none of them should be `Ref`s). We'll need to - # convert the output residual values into `Ref`s that are initially empty - # `Ref`s that are written to at the end of the jaxpr. - - # # Loop-invariant residual optimization - # Here we are interested in finding out which of the residuals are *not* - # dependent on the loop index. If a residual is not dependent on the loop - # index, we don't need add an extra loop dimension we're reading from when we - # convert it from an output into a write. - loop_invar_res = _loop_invariant_outputs(jaxpr_known_resout) - - jaxpr_known, res_avals = _convert_outputs_to_writes(nsteps, - jaxpr_known_resout, - loop_invar_res) - # We now run the known jaxpr to obtain our residual values. - known_tracers, _ = partition_list(in_unknowns, tracers) - known_vals = [t.pval.get_known() for t in known_tracers] - empty_res = map(ad_util.zeros_like_aval, res_avals) - jaxpr_known_args = [*known_vals, *empty_res] - # We assume the known inputs are nonlinear which is okay to do for AD but not - # necessarily okay for general partial eval. - jaxpr_known_which_linear = (False,) * len(jaxpr_known_args) - out_flat = for_p.bind(*jaxpr_known_args, jaxpr=jaxpr_known, nsteps=nsteps, - reverse=reverse, which_linear=jaxpr_known_which_linear, - unroll=unroll) - known_outputs, residuals = split_list(out_flat, [len(known_tracers)]) - residuals = map(trace.new_instantiated_const, residuals) - - # Now we handle the `jaxpr_unknown` that expects residual values as inputs. - # This jaxpr is the output of `partial_eval_jaxpr_custom` that marks which - # inputs are actually used. - # `partial_eval_jaxpr_custom` doesn't remove extra inputs/outputs for you - # so we use `dce_jaxpr` here to do that. - jaxpr_unknown_resin, used_inputs = pe.dce_jaxpr( - jaxpr_unknown_resin_, [], [True] * num_res + [True, *in_unknowns]) - used_res, (used_i,), used_refs = split_list(used_inputs, [num_res, 1]) - assert all(used_res), "All residuals should be used" - # To make it compatible with `for`, we need to convert those residual values - # into `Ref`s. - jaxpr_unknown = _convert_inputs_to_reads(nsteps, len(res_avals), - jaxpr_unknown_resin, - loop_invar_res) - # Since not all inputs are used in jaxpr_unknown, we filter the input tracers - # down using the output of `dce_jaxpr`. - used_and_known = map(operator.and_, used_refs, map(operator.not_, in_unknowns)) - tracers = [trace.instantiate_const(t) if u_and_k else t for t, u_and_k - in zip(tracers, used_and_known)] - _, known_used = partition_list(used_refs, used_and_known) - _, used_tracers = partition_list(used_refs, tracers) - _, used_which_linear = partition_list(used_refs, which_linear) - which_linear_unknown = (False,) * num_res + tuple(used_which_linear) - unknown_inputs = [*residuals, *used_tracers] - # Outputs match inputs so we construct output tracers that look like the input - # tracers. - res_ref_unknown_outputs = [ - pe.JaxprTracer(trace, pe.PartialVal.unknown(t.aval), None) - for t in unknown_inputs] - name_stack = source_info_util.current_name_stack()[len(trace.name_stack):] - source = source_info_util.current().replace(name_stack=name_stack) - - assert len(unknown_inputs) == len(res_ref_unknown_outputs) - assert len(unknown_inputs) == len(jaxpr_unknown.invars) - 1 - eqn = pe.new_eqn_recipe(unknown_inputs, res_ref_unknown_outputs, - for_p, dict(jaxpr=jaxpr_unknown, nsteps=nsteps, - reverse=reverse, - which_linear=which_linear_unknown, - unroll=unroll), - core.no_effects, source) - for t in res_ref_unknown_outputs: t.recipe = eqn - _, unknown_outputs = split_list(res_ref_unknown_outputs, [num_res]) - unknown_outputs, _ = partition_list(known_used, unknown_outputs) - return merge_lists(in_unknowns, known_outputs, unknown_outputs) -pe.custom_partial_eval_rules[for_p] = _for_partial_eval - -def _for_partial_eval_custom(saveable, in_unknowns, in_inst, eqn): - jaxpr, nsteps, reverse, which_linear, unroll = split_dict( - eqn.params, ["jaxpr", "nsteps", "reverse", "which_linear", "unroll"]) - num_inputs = len(eqn.invars) - # We first need to run a fixpoint to determine which of the `Ref`s are unknown - # after running the for loop. However, the jaxpr has no outputs. Instead, we - # discharge the body and run the fixpoint with the discharged jaxpr. We can do - # this because the outputs of the discharged jaxpr are one-to-one with the - # inputs. - discharged_jaxpr, discharged_consts = discharge_state(jaxpr, ()) - discharged_jaxpr = discharged_jaxpr.replace( - invars=discharged_jaxpr.constvars + discharged_jaxpr.invars, - constvars=[]) - in_unknowns, in_inst = list(in_unknowns), list(in_inst) - out_unknowns, out_inst = in_unknowns, in_inst - for _ in range(num_inputs): - jaxpr_in_unknowns = [False] * len(discharged_consts) + [False, *in_unknowns] - _, _, out_unknowns, out_inst, _, = pe.partial_eval_jaxpr_custom( - discharged_jaxpr, jaxpr_in_unknowns, True, - ensure_out_unknowns=in_unknowns, ensure_out_inst=True, - saveable=saveable) - out_unknowns = list(out_unknowns) - if out_unknowns == in_unknowns: - break - in_unknowns = map(operator.or_, in_unknowns, out_unknowns) - else: - if num_inputs > 0: raise Exception("Invalid fixpoint") - del out_unknowns # Redundant since it's the same as `in_unknowns` - new_inst = [x for x, inst in zip(eqn.invars, in_inst) - if type(x) is core.Var and not inst] - in_inst = [True] * len(eqn.invars) - - # We use `partial_eval_jaxpr_custom` here because it won't remove effectful - # primitives like `get`/`set`. - jaxpr_known_resout, jaxpr_staged_resin_, _, _, num_res = \ - pe.partial_eval_jaxpr_custom(jaxpr, [False, *in_unknowns], - [True, *in_inst], [], [], saveable) - - # `partial_eval_jaxpr_custom` will give us jaxprs that have hybrid `Ref` and - # non-Ref input/outputs. However, we'd like to bind these jaxprs to a - # `for`, which expects only `Ref` inputs and no output. We need to convert - # both of these jaxprs into ones that are compatible with `for`. - # TODO(sharadmv,mattjj): implement "passthrough" optimization. - # TODO(sharadmv,mattjj): rematerialize loop-dependent values instead of - # passing the loop index as a residual - - # `jaxpr_known_resout` is a jaxpr that maps from all the input `Refs` - # to output residual values (none of them should be `Ref`s). We'll need to - # convert the output residual values into `Ref`s that are initially empty - # `Ref`s that are written to at the end of the jaxpr. - - # # Loop-invariant residual optimization - # Here we are interested in finding out which of the residuals are *not* - # dependent on the loop index. If a residual is not dependent on the loop - # index, we don't need add an extra loop dimension we're reading from when we - # convert it from an output into a write. - loop_invar_res = _loop_invariant_outputs(jaxpr_known_resout) - - jaxpr_known, res_avals = _convert_outputs_to_writes(nsteps, - jaxpr_known_resout, - loop_invar_res) - - known_invars, _ = partition_list(in_unknowns, eqn.invars) - known_outvars, _ = partition_list(in_unknowns, eqn.outvars) - newvar = core.gensym() - resvars = map(newvar, res_avals) - - def known(*known_vals): - empty_res = map(ad_util.zeros_like_aval, res_avals) - jaxpr_known_args = [*known_vals, *empty_res] - jaxpr_known_which_linear = (False,) * len(jaxpr_known_args) - return for_p.bind(*jaxpr_known_args, jaxpr=jaxpr_known, nsteps=nsteps, - reverse=reverse, which_linear=jaxpr_known_which_linear, - unroll=unroll) - call_jaxpr_, _, call_jaxpr_consts, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(known, debug_info=jaxpr.debug_info), - [v.aval for v in known_invars]) - call_jaxpr = core.ClosedJaxpr(call_jaxpr_, call_jaxpr_consts) - eqn_known = pe.new_jaxpr_eqn(known_invars, [*known_outvars, *resvars], - core.closed_call_p, dict(call_jaxpr=call_jaxpr), - call_jaxpr.effects, eqn.source_info, eqn.ctx) - - jaxpr_staged = _convert_inputs_to_reads(nsteps, len(res_avals), - jaxpr_staged_resin_, - loop_invar_res) - which_linear_unknown = (False,) * num_res + tuple(which_linear) - params_staged = dict(eqn.params, jaxpr=jaxpr_staged, reverse=reverse, - nsteps=nsteps, - which_linear=which_linear_unknown, - unroll=unroll) - - def staged(*res_and_refs): - out_flat = for_p.bind(*res_and_refs, **params_staged) - _, ans = split_list(out_flat, [num_res]) - _, ans = partition_list(out_inst, ans) - return ans - call_jaxpr_, _, call_jaxpr_consts, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(staged, debug_info=jaxpr_staged.debug_info), - [v.aval for v in [*resvars, *eqn.invars]]) - assert len(jaxpr_staged.invars) - 1 == len(call_jaxpr_.invars) - call_jaxpr = core.ClosedJaxpr(call_jaxpr_, call_jaxpr_consts) - _, outvars = partition_list(out_inst, eqn.outvars) - eqn_staged = pe.new_jaxpr_eqn([*resvars, *eqn.invars], outvars, - core.closed_call_p, dict(call_jaxpr=call_jaxpr), - call_jaxpr.effects, eqn.source_info, eqn.ctx) - new_vars = [*new_inst, *resvars] - return eqn_known, eqn_staged, in_unknowns, out_inst, new_vars - -pe.partial_eval_jaxpr_custom_rules[for_p] = _for_partial_eval_custom - -def _convert_outputs_to_writes( - nsteps: int, jaxpr: core.Jaxpr, loop_invar_res: Sequence[bool] - ) -> tuple[core.Jaxpr, list[core.ShapedArray]]: - assert not jaxpr.constvars, "Jaxpr shouldn't have constvars." - - in_avals = [v.aval for v in jaxpr.invars] # [i, *orig_ref_avals] - - def eval_jaxpr(i, *refs): - # We split the refs into the original input refs and the dummy residual - # refs. - orig_refs, residual_refs = split_list(refs, [len(in_avals) - 1]) - residual_vals = core.eval_jaxpr(jaxpr, (), i, *orig_refs) - for res_ref, res_val, loop_invar in zip(residual_refs, residual_vals, - loop_invar_res): - if loop_invar: - res_ref[()] = res_val - else: - res_ref[i] = res_val - return [] - # TODO(mattjj, sharadmv): better handling of tokens, which don't have shape/dtype - res_ref_avals: list[core.AbstractValue] = [ - AbstractRef(v.aval) if loop_invar else # pytype: disable=attribute-error - AbstractRef(core.ShapedArray((nsteps, *v.aval.shape), # pytype: disable=attribute-error - v.aval.dtype)) # pytype: disable=attribute-error - for v, loop_invar in zip(jaxpr.outvars, loop_invar_res)] - jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(eval_jaxpr, debug_info=jaxpr.debug_info), - [*in_avals, *res_ref_avals]) - assert not consts - return jaxpr, [core.ShapedArray(a.shape, a.dtype) for a in res_ref_avals] # pytype: disable=attribute-error - -def _convert_inputs_to_reads( - nsteps: int, num_res: int, jaxpr: core.Jaxpr, - loop_invar_res: Sequence[bool]) -> core.Jaxpr: - assert not jaxpr.constvars, "Jaxpr should not have constvars" - - def eval_jaxpr(i, *refs): - residual_refs, orig_refs = split_list(refs, [num_res]) - residual_vals = [r[()] if loop_invar else r[i] for r, loop_invar - in zip(residual_refs, loop_invar_res)] - () = core.eval_jaxpr(jaxpr, (), *residual_vals, i, *orig_refs) - return [] - - res_val_avals, (i_aval,), orig_ref_avals = \ - split_list([v.aval for v in jaxpr.invars], [num_res, 1]) - res_ref_avals: list[core.AbstractValue] = [ - AbstractRef(aval) if loop_invar else # pytype: disable=attribute-error - AbstractRef(core.ShapedArray((nsteps, *aval.shape), # pytype: disable=attribute-error - aval.dtype)) # pytype: disable=attribute-error - for aval, loop_invar in zip(res_val_avals, loop_invar_res)] - - jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(eval_jaxpr, debug_info=jaxpr.debug_info), - [i_aval, *res_ref_avals, *orig_ref_avals]) - return jaxpr - -def transpose_jaxpr(jaxpr: core.Jaxpr, which_linear: list[bool]) -> core.Jaxpr: - def trans(i, *args): - # First we want to run the computation to read all the residual refs. We can - # do that by using partial evaluation with all linear inputs unknown. - res_jaxpr, tangent_jaxpr_, *_ = \ - _partial_eval_jaxpr_custom(jaxpr, [False, *which_linear], - _save_everything) - res_args = [x for x, lin in zip(args, which_linear) if not lin] - res = core.eval_jaxpr(res_jaxpr, (), i, *res_args) - - # Now that we have residual values, we run the tangent jaxpr. It takes as - # input the residuals, the loop index, and all the refs (at least, the ones - # that are used in the body). Luckily, `tangent_jaxpr_` has all known and - # unknown inputs! - tangent_jaxpr, used = pe.dce_jaxpr(tangent_jaxpr_, []) - used_res, (used_i,), used_ct = split_list(used, [len(res), 1]) - primals_args = [*(r for u, r in zip(used_res, res) if u)] - if used_i: - primals_args = [*primals_args, i] - ct_args = [x for x, u in zip(args, used_ct) if u] - ad.backward_pass(tangent_jaxpr, False, (), (*primals_args, *ct_args), ()) - return [] - jaxpr_trans, _, _, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(trans, debug_info=jaxpr.debug_info), - [v.aval for v in jaxpr.invars]) - return jaxpr_trans - -def _for_transpose(in_cts, *args, jaxpr, nsteps, reverse, which_linear, unroll): - # if any in_ct is nonzero, we definitely want it in args_ (and the - # corresponding x in args could be an undefined primal, but doesn't have to be) - # for non-res stuff: - # getting and setting => (nonzero ct, UndefinedPrimal arg) - # just setting => (nonzero ct, not UndefinedPrimal, dummy value) - # just getting => (zero ct , UndefinedPrimal arg) - # for res stuff: - # (zero ct , not UndefinedPrimal) - args_ = [] - which_linear_transpose = [] - for x, ct in zip(args, in_cts): - if type(ct) is ad_util.Zero and not ad.is_undefined_primal(x): - # this is a residual, take x! - args_.append(x) - which_linear_transpose.append(False) - elif type(ct) is ad_util.Zero and ad.is_undefined_primal(x): - # the loop was 'just getting', plug in a zero - args_.append(ad_util.zeros_like_aval(x.aval)) - which_linear_transpose.append(False) - elif type(ct) is not ad_util.Zero and not ad.is_undefined_primal(x): - # the loop was 'just setting', grab that cotangent! x is dummy - args_.append(ct) - which_linear_transpose.append(False) - elif type(ct) is not ad_util.Zero and ad.is_undefined_primal(x): - # the loop was 'getting and setting', grab that cotangent! - args_.append(ct) - which_linear_transpose.append(True) - - jaxpr_transpose = transpose_jaxpr(jaxpr, which_linear) - assert len(args_) == len(jaxpr_transpose.invars) - 1 - all_outs = for_p.bind(*args_, jaxpr=jaxpr_transpose, nsteps=nsteps, - reverse=not reverse, - which_linear=tuple(which_linear_transpose), - unroll=unroll) - ct_outs = [ct if ad.is_undefined_primal(x) else None - for x, ct in zip(args, all_outs)] - return ct_outs -ad.primitive_transposes[for_p] = _for_transpose - -### Testing utility - -def discharged_for_loop(nsteps, body, init_state, *, reverse: bool = False): - """A `for_loop` implementation that discharges its body right away. - - Potentially useful for testing and benchmarking. - """ - flat_state, state_tree = tree_flatten(init_state) - state_avals = map(state_utils.val_to_ref_aval, flat_state) - idx_aval = core.ShapedArray((), dtypes.canonicalize_dtype(np.int64)) - debug = api_util.debug_info("discharged_for_loop", body, (0, init_state), {}) - jaxpr, consts, out_tree = _trace_to_jaxpr_with_refs( - body, state_tree, [idx_aval, *state_avals], debug) - if out_tree != tree_structure(None): - raise Exception("`body` should not return anything.") - discharged_jaxpr, discharged_consts = discharge_state(jaxpr, consts) - - def fori_body(i, carry): - i = lax.convert_element_type(i, dtypes.canonicalize_dtype(np.int64)) - if reverse: - i = nsteps - i - 1 - out_flat = core.eval_jaxpr(discharged_jaxpr, discharged_consts, - i, *carry) - return out_flat - out_flat = loops.fori_loop(0, nsteps, fori_body, flat_state) - return tree_unflatten(state_tree, out_flat) - -def run_state(f, init_state): - @functools.wraps(f) - def wrapped_body(_, *args): - return f(*args) - return for_loop(1, wrapped_body, init_state) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 3084fa722977..61d9a77e5cb5 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -17,7 +17,7 @@ from collections.abc import Callable, Sequence from functools import partial import inspect -import itertools +import itertools as it import operator from typing import Any, TypeVar import weakref @@ -32,52 +32,43 @@ from jax._src import dtypes from jax._src import effects from jax._src import linear_util as lu +from jax._src import literals from jax._src import source_info_util from jax._src import state from jax._src import util from jax._src.api_util import ( - _check_no_aliased_ref_args, _check_no_aliased_closed_over_refs) -from jax._src.core import ShapedArray + check_no_aliased_ref_args, _check_no_aliased_closed_over_refs) +from jax._src.core import ( + ShapedArray, typeof, cur_qdd, ClosedJaxpr, AbstractValue) from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import pxla from jax._src import sharding_impls as sharding -from jax._src.interpreters import xla +from jax._src.mesh import use_abstract_mesh from jax._src.lax import lax from jax._src.lax import slicing from jax._src.lax import windowed_reductions from jax._src.lax.control_flow.common import ( - _avals_short, _initial_style_jaxpr, - _initial_style_jaxpr_attrs, _make_closed_jaxpr_attrs, _prune_zeros, - _typecheck_param) + _avals_short, _prune_zeros, _typecheck_param, + _make_closed_jaxpr) from jax._src.lax.other import logaddexp +from jax._src.pjit import auto_axes, PartitionSpec as P, reshard from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo -from jax._src.state import discharge as state_discharge +from jax._src.sharding_impls import canonicalize_sharding +from jax._src.state import discharge as state_discharge, AbstractRef from jax._src.traceback_util import api_boundary from jax._src.tree_util import equality_errors from jax._src.typing import Array from jax._src.util import ( - merge_lists, - partition_list, - safe_map, - safe_zip, - split_list, - split_list_checked, - unzip2, - weakref_lru_cache, -) + merge_lists, partition_list, safe_map, safe_zip, split_list, + split_list_checked, unzip2, weakref_lru_cache, subs_list) from jax._src import xla_bridge as xb -from jax.tree_util import ( - keystr, - tree_flatten, - tree_flatten_with_path, - tree_map, - tree_unflatten, - treedef_is_leaf, -) +from jax._src.tree_util import ( + keystr, tree_flatten, tree_map, tree_unflatten, + treedef_is_leaf, FlatTree) import numpy as np _map = safe_map @@ -91,29 +82,14 @@ def _stack(arrs: Sequence[Array], axis: int=0) -> Array: return lax.concatenate([lax.expand_dims(arr, (axis,)) for arr in arrs], dimension=axis) -def _promote_weak_typed_inputs(in_vals, in_avals, out_avals): - """Promote weakly-typed in_vals to be compatible with out_avals. - - Args: - in_vals : flattened list of input values. - in_avals : corresponding list of avals. - out_avals : list of target output avals. - Returns: - in_vals_new : flattened list of modified in_vals with no weak types. - changed : bool; true if in_vals required modification. - """ - if len(in_vals) != len(in_avals) or len(in_avals) != len(out_avals): - # Calling function is responsible for catching this. - return in_vals, False - weak_mismatches = [i for i, (a1, a2) in enumerate(zip(in_avals, out_avals)) - if getattr(a1, 'weak_type', False) and not core.typematch(a1, a2)] - if not weak_mismatches: - return in_vals, False - for i in weak_mismatches: - new_dtype = dtypes.result_type(in_vals[i], out_avals[i]) - in_vals[i] = lax.convert_element_type(in_vals[i], new_dtype) - return in_vals, True - +def _promote_weak_typed_input( + in_val:Any, in_aval:AbstractValue, out_aval:AbstractValue + ) -> tuple[Any, bool]: + if getattr(in_aval, 'weak_type', False) and not core.typematch(in_aval, out_aval): + new_dtype = dtypes.result_type(in_val, out_aval) + return lax.convert_element_type(in_val, new_dtype), True + else: + return in_val, False ### scan @@ -121,7 +97,7 @@ def _promote_weak_typed_inputs(in_vals, in_avals, out_avals): X = TypeVar('X') Y = TypeVar('Y') -@api_boundary +@partial(api_boundary, repro_api_name="jax.lax.scan") def scan(f: Callable[[Carry, X], tuple[Carry, Y]], init: Carry, xs: X | None = None, @@ -178,6 +154,11 @@ def scan(f, init, xs, length=None): :py:func:`scan` compiles ``f``, so while it can be combined with :py:func:`jit`, it's usually unnecessary. + .. note:: + :func:`scan` is designed for iterating with a static number of iterations. + For iteration with a dynamic number of iterations, use :func:`fori_loop` + or :func:`while_loop`. + Args: f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning that ``f`` accepts two arguments where the first is a value of the loop @@ -197,12 +178,13 @@ def scan(f, init, xs, length=None): reverse: optional boolean specifying whether to run the scan iteration forward (the default) or in reverse, equivalent to reversing the leading axes of the arrays in both ``xs`` and in ``ys``. - unroll: optional positive int or bool specifying, in the underlying + unroll: optional non-negative int or bool specifying, in the underlying operation of the scan primitive, how many scan iterations to unroll within a single iteration of a loop. If an integer is provided, it determines how many unrolled loop iterations to run within a single rolled iteration of - the loop. If a boolean is provided, it will determine if the loop is - competely unrolled (i.e. `unroll=True`) or left completely rolled (i.e. + the loop. `unroll=0` unrolls the entire loop. + If a boolean is provided, it will determine if the loop is + completely unrolled (i.e. `unroll=True`) or left completely rolled (i.e. `unroll=False`). _split_transpose: experimental optional bool specifying whether to further split the transpose into a scan (computing activation gradients), and a @@ -219,97 +201,70 @@ def scan(f, init, xs, length=None): """ if not callable(f): raise TypeError("lax.scan: f argument should be a callable.") - xs_flat, xs_tree = tree_flatten(xs) - - try: - lengths = [x.shape[0] for x in xs_flat] - except AttributeError as err: - msg = "scan got value with no leading axis to scan over: {}." - raise ValueError( - msg.format(', '.join(str(x) for x in xs_flat - if not hasattr(x, 'shape')))) from err - xs_avals = [core.get_aval(x) for x in xs_flat] + dbg_body = api_util.debug_info("scan", f, (init, xs), {}) + init = FlatTree.flatten(init) + xs = FlatTree.flatten(xs) + args = FlatTree.pack((init, xs)) - if not all(a.sharding.spec[0] is None for a in xs_avals): - raise ValueError('0th dimension of all xs should be replicated. Got ' - f'{", ".join(str(a.sharding.spec) for a in xs_avals)}') + args_avals = args.map(core.get_aval) + init_avals, xs_avals = args_avals.unpack() - if length is not None: - try: - length = int(length) - except core.ConcretizationTypeError as err: - msg = 'The `length` argument to `scan` expects a concrete `int` value.' - raise core.ConcretizationTypeError(length, msg) from None # type: ignore[arg-type] - if not all(length == l for l in lengths): - msg = ("scan got `length` argument of {} which disagrees with " - "leading axis sizes {}.") - raise ValueError(msg.format(length, [x.shape[0] for x in xs_flat])) - else: - unique_lengths = set(lengths) - if len(unique_lengths) > 1: - msg = "scan got values with different leading axis sizes: {}." - raise ValueError(msg.format(', '.join(str(x.shape[0]) for x in xs_flat))) - elif len(unique_lengths) == 0: - msg = "scan got no values to scan over and `length` not provided." - raise ValueError(msg) - else: - length, = unique_lengths + length = _infer_scan_length(list(xs), list(xs_avals), length) if config.disable_jit.value: if length == 0: raise ValueError("zero-length scan is not supported in disable_jit() " "mode because the output type is unknown.") - carry = init + carry = init.unflatten() ys = [] maybe_reversed = reversed if reverse else lambda x: x for i in maybe_reversed(range(length)): - xs_slice = [slicing.index_in_dim(x, i, keepdims=False) for x in xs_flat] - carry, y = f(carry, tree_unflatten(xs_tree, xs_slice)) + xs_slice = xs.map(lambda x: slicing.index_in_dim(x, i, keepdims=False)) + carry, y = f(carry, xs_slice.unflatten()) ys.append(y) stack = lambda *ys: _stack(ys) stacked_y = tree_map(stack, *maybe_reversed(ys)) return carry, stacked_y - x_avals = [core.mapped_aval(length, 0, aval) for aval in xs_avals] - dbg_body = api_util.debug_info("scan", f, (init, xs), {}) - if config.mutable_array_checks.value: - in_flat, in_tree = tree_flatten((init, xs)) - in_avals = tuple(_map(core.get_aval, in_flat)) - _check_no_aliased_ref_args(dbg_body, in_avals, in_flat) - - def _create_jaxpr(init): - init_flat, init_tree = tree_flatten(init) - in_flat, in_tree = tree_flatten((init, xs)) - carry_avals = tuple(_map(core.get_aval, init_flat)) - jaxpr, consts, out_tree, attrs_tracked = _initial_style_jaxpr_attrs( - f, in_tree, (*carry_avals, *x_avals), debug_info=dbg_body) - if config.mutable_array_checks.value: - _check_no_aliased_closed_over_refs(dbg_body, (*jaxpr.consts, *consts), in_flat) - out_tree_children = out_tree.children() - if len(out_tree_children) != 2: + check_no_aliased_ref_args(lambda: dbg_body, list(args_avals), list(args)) + + x_avals = xs_avals.map(lambda aval: core.mapped_aval(length, 0, aval)) + def _create_jaxpr(carry_avals): + new_arg_avals = FlatTree.pack(((carry_avals, x_avals), {})) + jaxpr, out_avals = pe.trace_to_jaxpr(f, new_arg_avals, dbg_body) + jaxpr, consts = pe.separate_consts(jaxpr) + if len(out_avals.unpack()) != 2: msg = "scan body output must be a pair, got {}." - raise TypeError(msg.format(tree_unflatten(out_tree, jaxpr.out_avals))) - _, carry_avals_out, _ = split_list( - jaxpr.out_avals, [len(attrs_tracked), out_tree_children[0].num_leaves]) - return (init_flat, carry_avals, carry_avals_out, init_tree, in_flat, jaxpr, - consts, out_tree, out_tree_children, attrs_tracked) + raise TypeError(msg.format(out_avals.unflatten())) + return jaxpr, out_avals, consts # The carry input and output avals must match exactly. However, we want to account for # the case when init contains weakly-typed values (e.g. Python scalars), with avals that # may not match the output despite being compatible by virtue of their weak type. # To do this, we compute the jaxpr in two passes: first with the raw inputs, and if # necessary, a second time with modified init values. - init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init) - new_init_flat, changed = _promote_weak_typed_inputs(init_flat, carry_avals, carry_avals_out) - if changed: - init = tree_unflatten(init_tree, new_init_flat) - init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init) - in_flat, jaxpr, consts, out_tree, out_tree_children, attrs_tracked = rest - num_carry = len(init_flat) - - _check_carry_type('scan body', f, init, out_tree_children[0], carry_avals_out) + # TODO(dougalm): this two-pass stuff is expensive (exponential in scan nesting + # depth) and incomplete (because in the general case it takes more than two passes). + # Let's get rid of it, perhaps after getting rid of weak types altogether. + jaxpr, out_avals, consts = _create_jaxpr(init_avals) + if config.mutable_array_checks.value: + _check_no_aliased_closed_over_refs(dbg_body, consts, list(args)) + carry_out_avals, ys_avals = out_avals.unpack() + if len(carry_out_avals) != len(init_avals): + _check_carry_type('scan body', f, init_avals, carry_out_avals) + init, changed = init.map3( + _promote_weak_typed_input, + init_avals, carry_out_avals).unzip2() + num_carry, num_xs, num_ys = len(init), len(xs), len(ys_avals) + if any(changed): + init_avals = init.map(core.get_aval) + jaxpr, out_avals, consts = _create_jaxpr(init_avals) + carry_out_avals, ys_avals = out_avals.unpack() + + _check_carry_type('scan body', f, init_avals, carry_out_avals) + disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(jaxpr.effects) if disallowed_effects: raise NotImplementedError( @@ -321,46 +276,100 @@ def _create_jaxpr(init): "value.") if isinstance(unroll, bool): unroll = max(length, 1) if unroll else 1 - if unroll < 1: - raise ValueError("`unroll` must be a `bool` or a positive `int`.") - if attrs_tracked: - in_state = _get_states(attrs_tracked) - in_carry, in_ext = split_list(in_flat, [num_carry]) - in_flat = [*in_state, *in_carry, *in_ext] - num_carry += len(attrs_tracked) - out = scan_p.bind(*consts, *in_flat, + if unroll < 0: + raise ValueError("`unroll` must be a `bool` or a non-negative `int`.") + + args_flat = [*init.vals, *xs.vals] + + # If the body forwards an input carry to an output carry, that input is + # read-only and can be moved to be a const. Doing so can lead to efficiency + # wins, e.g. if the scan is inside a cond with a batched predicate. + num_ys = len(jaxpr.out_avals) - num_carry + carry_fwd, ext_fwd = split_list(pe._jaxpr_forwarding(jaxpr.jaxpr), [num_carry]) + move_to_const = [len(consts) + i == f for i, f in enumerate(carry_fwd)] + if any(move_to_const): + jaxpr = pe.prune_closed_jaxpr_outputs( + jaxpr, [not m for m in move_to_const] + [True] * num_ys) + jaxpr = pe.move_binders_to_front( + jaxpr, [False] * len(consts) + move_to_const + [False] * num_xs) + args_flat, new_consts = partition_list(move_to_const + [False] * num_xs, args_flat) + consts = [*new_consts, *consts] + num_carry -= len(new_consts) + + # When an extensive output is forwarded from an extensive input, we can + # avoid copying it by pruning it from the jaxpr and forwarding manually. We + # don't need to update the indexing based on the optimization above since it + # doesn't change the total number of consts and carries combined, and + # `ext_fwd` already only includes the extensive outputs. But, we do remove + # the number of consts from the index since we're going to use it to index + # into `in_flat`, which doesn't include consts. + ext_to_ext_fwd = [ + in_idx - len(consts) if in_idx is not None and + in_idx >= num_carry + len(consts) else None for in_idx in ext_fwd] + jaxpr = pe.prune_closed_jaxpr_outputs( + jaxpr, [True] * num_carry + [i is None for i in ext_to_ext_fwd]) + + out = scan_p.bind(*consts, *args_flat, reverse=reverse, length=length, jaxpr=jaxpr, num_consts=len(consts), num_carry=num_carry, - linear=(False,) * (len(consts) + len(in_flat)), - unroll=unroll, - _split_transpose=_split_transpose) - if attrs_tracked: - out_state, out = split_list(out, [len(attrs_tracked)]) - _set_states(attrs_tracked, out_state) - return tree_unflatten(out_tree, out) - -def _set_states(attrs_tracked, vals): - from jax.experimental.attrs import jax_setattr - valss = split_list_checked(vals, [td.num_leaves for _, td, _ in attrs_tracked]) - for ((_, treedef, (obj, attr)), leaves) in zip(attrs_tracked, valss): - val = tree_unflatten(treedef, leaves) - jax_setattr(obj, attr, val) - -def _get_states(attrs_tracked): - from jax.experimental.attrs import jax_getattr - vals = [] - for treedef, _, (obj, attr) in attrs_tracked: - tree = jax_getattr(obj, attr) - leaves, treedef_ = tree_flatten(tree) - assert treedef == treedef_ - vals.extend(leaves) - return vals + linear=(False,) * (len(consts) + len(args_flat)), + unroll=unroll, _split_transpose=_split_transpose) + + # Apply input to output forwarding that was computed above. + carry_out, out = split_list(out, [num_carry]) + out_ = iter(out) + out = [next(out_) if f is None else _maybe_put(args_flat[f]) for f in ext_to_ext_fwd] + assert next(out_, None) is None + out = [*carry_out, *out] + + if any(move_to_const): + out = pe.merge_lists(move_to_const + [False] * num_ys, out, new_consts) + + return out_avals.update(out).unflatten() + +def _infer_scan_length( + xs_flat: list[Any], xs_avals: list[AbstractValue], + length: int | None) -> int: + try: + lengths = [x.shape[0] for x in xs_flat] + except AttributeError as err: + msg = "scan got value with no leading axis to scan over: {}." + raise ValueError( + msg.format(', '.join(str(x) for x in xs_flat + if not hasattr(x, 'shape')))) from err + + if not all(a.sharding.spec[0] is None for a in xs_avals): + raise ValueError('0th dimension of all xs should be replicated. Got ' + f'{", ".join(str(a.sharding.spec) for a in xs_avals)}') + + if length is not None: + try: + return int(length) + except core.ConcretizationTypeError as err: + msg = ('The `length` argument to `scan` expects a concrete `int` value.' + ' For scan-like iteration with a dynamic length, use `while_loop`' + ' or `fori_loop`.') + raise core.ConcretizationTypeError(length, msg) from None # type: ignore[arg-type] + if not all(length == l for l in lengths): + msg = ("scan got `length` argument of {} which disagrees with " + "leading axis sizes {}.") + raise ValueError(msg.format(length, [x.shape[0] for x in xs_flat])) + else: + unique_lengths = set(lengths) + if len(unique_lengths) > 1: + msg = "scan got values with different leading axis sizes: {}." + raise ValueError(msg.format(', '.join(str(x.shape[0]) for x in xs_flat))) + elif len(unique_lengths) == 0: + msg = "scan got no values to scan over and `length` not provided." + raise ValueError(msg) + else: + return list(unique_lengths)[0] def _capitalize(s): # s.capitalize() converts s[1:] to lowercase which we don't want. return s[0].capitalize() + s[1:] -def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals): +def _check_carry_type(name, body_fun, in_carry, out_carry): try: sig = inspect.signature(body_fun) except (ValueError, TypeError): @@ -372,27 +381,23 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals): else: component = lambda p: (f'the input carry at path {keystr(p)}' if p else 'the input carry') - leaves_and_paths, in_carry_tree = tree_flatten_with_path(in_carry) - paths, in_carry_flat = unzip2(leaves_and_paths) - in_avals = _map(core.get_aval, in_carry_flat) - if in_carry_tree != out_carry_tree: + if in_carry.tree != out_carry.tree: try: - out_carry = tree_unflatten(out_carry_tree, out_avals) + out_carry_unflat = out_carry.unflatten() except: - out_carry = None + out_carry_unflat = None - if out_carry is None: - differences = [f'the input tree structure is:\n{in_carry_tree}\n', - f'the output tree structure is:\n{out_carry_tree}\n'] + if out_carry_unflat is None: + differences = (f'the input tree structure is:\n{in_carry.tree}\n' + + f'the output tree structure is:\n{out_carry.tree}\n') else: diffs = [f'{component(path)} is a {thing1} but the corresponding component ' f'of the carry output is a {thing2}, so {explanation}' for path, thing1, thing2, explanation - in equality_errors(in_carry, out_carry)] + in equality_errors(in_carry.unflatten(), out_carry.unflatten())] if len(diffs) == 0: - # The trees may have different aux data but structures are the same. - return - if len(diffs) == 1: + return # the trees may have different aux data, but structures are same + elif len(diffs) == 1: differences = f'{_capitalize(diffs[0])}.\n' else: differences = ('\n'.join(f' * {d};\n' for d in diffs[:-1]) @@ -403,38 +408,50 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals): f"{differences}\n" "Revise the function so that the carry output has the same pytree " "structure as the carry input.") - if not all(_map(core.typematch, in_avals, out_avals)): + if not all(_map(core.typematch, in_carry, out_carry)): diffs = [f'{component(path)} has type {in_aval.str_short()}' ' but the corresponding output carry component has type ' - f'{out_aval.str_short()}{_aval_mismatch_extra(in_aval, out_aval)}' - for path, in_aval, out_aval in zip(paths, in_avals, out_avals) + f'{out_aval.str_short()}' + f'{core.aval_mismatch_extra(in_aval, out_aval)}' + for path, in_aval, out_aval in zip(in_carry.paths, in_carry, out_carry) if not core.typematch(in_aval, out_aval)] + if len(diffs) == 0: - # The trees may have different aux data but structures are the same. - return + return # seems unreachable but in any case we don't have a good error msg if len(diffs) == 1: differences = f'{_capitalize(diffs[0])}.\n' else: differences = ('\n'.join(f' * {d};\n' for d in diffs[:-1]) + f' * {diffs[-1]}.\n') + + pvary_applications = [ + f'applying `jax.lax.pcast(..., {tuple(out_aval.vma - in_aval.vma)},' + " to='varying')` to the initial carry value corresponding to" + f' {component(path)}' + for path, in_aval, out_aval in zip(in_carry.paths, in_carry, out_carry) + if not core.typematch(in_aval, out_aval) and + isinstance(in_aval, ShapedArray) and isinstance(out_aval, ShapedArray) + and in_aval.vma != out_aval.vma and out_aval.vma - in_aval.vma] + + if not pvary_applications: + pvary_msg = '' + elif len(pvary_applications) == 1: + pvary_msg = f'This might be fixed by {pvary_applications[0]}.\n' + else: + pvary_msg = ('This might be fixed by:\n' + + '\n'.join(f' * {d};\n' for d in pvary_applications[:-1]) + + f' * {pvary_applications[-1]}.\n') + if pvary_msg: + pvary_msg += ("See https://docs.jax.dev/en/latest/notebooks/shard_map.html#scan-vma " + "for more information.\n\n") + raise TypeError( - f"{name} function carry input and carry output must have equal types " - "(e.g. shapes and dtypes of arrays), " + f"{name} function carry input and carry output must have equal types, " "but they differ:\n\n" f"{differences}\n" - "Revise the function so that all output types (e.g. shapes " - "and dtypes) match the corresponding input types.") - -def _aval_mismatch_extra(a1: core.AbstractValue, a2: core.AbstractValue) -> str: - assert not core.typematch(a1, a2) - if isinstance(a1, core.ShapedArray) and isinstance(a2, core.ShapedArray): - dtype_mismatch = a1.dtype != a2.dtype - shape_mismatch = a1.shape != a2.shape - return (', so ' * (dtype_mismatch or shape_mismatch) + - 'the dtypes do not match' * dtype_mismatch + - ' and also ' * (dtype_mismatch and shape_mismatch) + - 'the shapes do not match' * shape_mismatch) - return '' + f"{pvary_msg}" + "Revise the function so that all output types match the corresponding " + "input types.") # TODO(mattjj): re-land #19819 version? simpler, but caused ~1 perf regression. def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear, @@ -442,7 +459,10 @@ def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear, del _split_transpose consts, carry, xs_ = split_list(args, [num_consts, num_carry]) _, y_avals = split_list(jaxpr.out_avals, [num_carry]) - num_trips, remainder = divmod(length, unroll) + if unroll == 0: + num_trips, remainder = 0, length + else: + num_trips, remainder = divmod(length, unroll) if unroll != 1 and num_trips == 1 and remainder == 0: # In that case, we explicitly want to fully unroll the loop. Put everything @@ -479,15 +499,22 @@ def inner(n, carry, xs): def body_fun(while_carry): i_, carry, yss = while_carry - i = num_trips - i_ - 1 if reverse else i_ - xs = [slicing.dynamic_index_in_dim(xs, i, keepdims=False, - allow_negative_indices=False) - for xs in xss] + with use_abstract_mesh(core.typeof(i_).sharding.mesh): + i = num_trips - i_ - 1 if reverse else i_ + xs = [] + for x in xss: + with use_abstract_mesh(core.typeof(x).sharding.mesh): + o = slicing.dynamic_index_in_dim( + x, i, keepdims=False, allow_negative_indices=False) + xs.append(o) carry, ys = inner(unroll, carry, xs) - yss = [slicing.dynamic_update_index_in_dim(y, upd, i, 0, - allow_negative_indices=False) - for y, upd in zip(yss, ys)] - return i_ + 1, carry, yss + out_yss = [] + for y, upd in zip(yss, ys): + with use_abstract_mesh(core.typeof(y).sharding.mesh): + o = slicing.dynamic_update_index_in_dim( + y, upd, i, 0, allow_negative_indices=False) + out_yss.append(o) + return i_ + 1, carry, out_yss def cond_fun(while_carry): i, _, _ = while_carry @@ -503,8 +530,18 @@ def cond_fun(while_carry): if remainder: carry, ys_rem = inner(remainder, carry, xs_rem) ys = _map(_concat, ys, ys_rem) if not reverse else _map(_concat, ys_rem, ys) + # If any carry leaf is unreduced, we need to add a reshard to + # typeof(carry).sharding which inserts a sharding_constraint so that shardy + # knows not to AR at the boundary of while. This is a no-op at the trace level + # but during lowering time, it inserts an extra sharding constraint. + carry = tree_map(_constrain_unreduced, carry) + ys = tree_map(_constrain_unreduced, ys) return [*carry, *ys] +def _constrain_unreduced(val): + val_s = core.typeof(val).sharding + return reshard(val, val_s) if val_s.spec.unreduced else val + def _split_leading(sz, x): return (slicing.slice_in_dim(x, 0, sz), slicing.slice_in_dim(x, sz, x.shape[0])) @@ -512,15 +549,28 @@ def _split_leading(sz, x): def _concat(a, b): return lax.concatenate([a, b], 0) def _empty_array(prefix, length_spec, aval): - sharding = aval.sharding.with_spec((*length_spec, *aval.sharding.spec)) - return lax.broadcast(lax.empty(aval.dtype), (*prefix, *aval.shape), - out_sharding=sharding) + sharding = aval.sharding.update(spec=aval.sharding.spec.update( + partitions=(*length_spec, *aval.sharding.spec))) + # TODO(yashkatariya): Replace `lax.empty2` with `lax.empty` once + # AllocateBuffer issues are fixed. Also delete `empty2` after this usage is + # removed. Basically uncomment the following 2 lines. + # lax.empty will also need to take a memory_space argument. + # empty = lax.empty((*prefix, *aval.shape), aval.dtype, out_sharding=sharding, + # memory_space=aval.memory_space) + # return core.pvary(empty, tuple(aval.vma)) + empty = core.pvary(lax.empty2(aval.dtype, memory_space=aval.memory_space), + tuple(aval.vma)) + with use_abstract_mesh(sharding.mesh): + out = lax.broadcast(empty, (*prefix, *aval.shape), out_sharding=sharding) + return out eval_jaxpr_p = core.Primitive('eval_jaxpr') eval_jaxpr_p.multiple_results = True -def _stage_jaxpr(trace: pe.JaxprTrace, *tracers, jaxpr: core.ClosedJaxpr): +def _stage_jaxpr(trace: pe.DynamicJaxprTrace, source_info, *tracers, + jaxpr: ClosedJaxpr): params = dict(call_jaxpr=jaxpr) - return trace.default_process_primitive(core.closed_call_p, tracers, params) + return trace.default_process_primitive(core.closed_call_p, tracers, params, + source_info=source_info) pe.custom_staging_rules[eval_jaxpr_p] = _stage_jaxpr @eval_jaxpr_p.def_effectful_abstract_eval # abstract eval only used for jax2tf @@ -532,9 +582,20 @@ def _prepend_dim_to_aval(sz, aval): def _scan_abstract_eval(*args, reverse, length, num_consts, num_carry, jaxpr, linear, unroll, _split_transpose): - carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry]) + if len(args) != len(jaxpr.in_avals): + raise ValueError("scan number of arguments doesn't match the number " + "of jaxpr arguments: {len(args)} vs {len(jaxpr.in_avals)}") + out_carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry]) + _, in_carry_avals, _ = split_list(args, [num_consts, num_carry]) + if [i.vma for i in in_carry_avals] != [o.vma for o in out_carry_avals]: + raise ValueError( + 'Scan carry input and output got mismatched varying manual axes ' + f'{in_carry_avals} and {out_carry_avals}. Please open an ' + 'issue at https://github.com/jax-ml/jax/issues, and as a ' + 'temporary workaround pass the check_vma=False argument to ' + '`jax.shard_map`') ys_avals = _map(partial(_prepend_dim_to_aval, length), y_avals) - return carry_avals + ys_avals, jaxpr.effects + return out_carry_avals + ys_avals, core.eqn_effects(jaxpr) def _scan_jvp(primals, tangents, reverse, length, jaxpr, num_consts, num_carry, linear, unroll, _split_transpose): @@ -593,9 +654,119 @@ def _scan_jvp(primals, tangents, reverse, length, jaxpr, num_consts, num_carry, for p, nz in zip(primals_out, nonzeros_out)] return primals_out, tangents_out +def _scan_linearize(nzs, *primals_in, reverse: bool, length: int, num_consts: + int, num_carry: int, jaxpr: ClosedJaxpr, linear: + Sequence[bool], unroll: int, _split_transpose: bool): + const_nz, init_nz, xs_nz = split_list(nzs, [num_consts, num_carry]) + num_ys = len(jaxpr.out_avals) - num_carry + carry_nz = init_nz + allow_fwds = [True] * len(jaxpr.consts) + [ + (i < num_consts or i >= num_consts + num_carry) + and not isinstance(x, (np.ndarray, literals.TypedNdArray)) + for i, x in enumerate(primals_in) + ] + for _ in range(1 + num_carry): + nzs = const_nz + carry_nz + xs_nz + primal_jaxpr, num_res_out, nzs_out, in_fwd_res, tangent_jaxpr = \ + ad.linearize_jaxpr(jaxpr, nzs, allow_fwds=allow_fwds, + instantiate=carry_nz + [False] * num_ys) + carry_nz_out = nzs_out[:num_carry] + if carry_nz_out == carry_nz: + break + else: + carry_nz = _map(operator.or_, carry_nz, carry_nz_out) + else: + assert False, "Fixpoint not reached" + num_res_in = len(in_fwd_res) + num_primals_out = len(primal_jaxpr.out_avals) - num_res_out + + # At this point all non-forwarded residuals produced by primal_jaxpr are at + # the end. We want to hoist out loop-invariant ones: + # Before: + # [*const_primals_in , *carry_ext_primals_in] -> [*primals_out, *non_fwd_res] + # After: + # [*const_primals_in_, *carry_ext_primals_in] -> [*primals_out, *ext_res] + # where, modulo hoisted res not being broadcasted by the scan, + # non_fwd_res = merge_lists(which_hoisted, ext_res, hoisted_res) + const_primals_in, carry_ext_primals_in = split_list(primals_in, [num_consts]) + primal_jaxpr, const_primals_in_, which_hoisted, hoisted_res = \ + _scan_known_hoisting(primal_jaxpr, const_primals_in, num_res_out) + del num_res_out + + # To make tangent_jaxpr match the scan calling convention, move to the back + # binders that don't correspond to hoisted or const-forwarded residuals. + # Before: [*res, *tangents_in] -> [*tangents_out] + # After: [*int_res, *tangents_in, *ext_res] -> [*tangents_out] + num_tangents_in = len(tangent_jaxpr.in_avals) - num_res_in + which_hoisted_ = iter(which_hoisted) + res_to_move = [not next(which_hoisted_) if f is None else + f >= len(jaxpr.consts) + num_consts + num_carry + for f in in_fwd_res] + assert next(which_hoisted_, None) is None + tangent_jaxpr = pe.move_binders_to_back( + tangent_jaxpr, res_to_move + [False] * num_tangents_in) + + # Run the primal scan (if it has any outputs or effects). + if not primal_jaxpr.out_avals and not primal_jaxpr.effects: + out = [] + else: + linear_ = (False,) * len(primal_jaxpr.in_avals) # TODO conservative + out = scan_p.bind(*const_primals_in_, *carry_ext_primals_in, + jaxpr=primal_jaxpr, reverse=reverse, length=length, + num_consts=len(const_primals_in_), num_carry=num_carry, + linear=linear_, unroll=unroll, + _split_transpose=_split_transpose) + primals_out, ext_res = split_list(out, [num_primals_out]) + + # Complete res using hoisted_res and input forwards. + res = subs_list(in_fwd_res, [*jaxpr.consts, *primals_in], + merge_lists(which_hoisted, ext_res, hoisted_res)) + + def tangent_fun(res, *tangents): + int_res, ext_res = partition_list(res_to_move, res) + nz_tangents = [ad.instantiate_zeros(x) for nz, x in zip(nzs, tangents) if nz] + tangent_linear = ((False,) * len(int_res) + (True,) * len(nz_tangents) + + (False,) * len(ext_res)) + tangent_num_consts = len(int_res) + sum(nzs[:num_consts]) + tangent_num_carry = sum(nzs[num_consts:num_consts + num_carry]) + nz_tangents_out = scan_p.bind( + *int_res, *nz_tangents, *ext_res, jaxpr=tangent_jaxpr, reverse=reverse, + length=length, num_consts=tangent_num_consts, + num_carry=tangent_num_carry, linear=tangent_linear, unroll=unroll, + _split_transpose=_split_transpose) + tangent_avals_out = [v.aval.to_tangent_aval() for v in jaxpr.jaxpr.outvars] + nz_tangents_out_ = iter(nz_tangents_out) + tangents_out = [next(nz_tangents_out_) if nz else ad.Zero(aval) + for aval, nz in zip(tangent_avals_out, nzs_out)] + assert next(nz_tangents_out_, None) is None + return tangents_out + + return primals_out, nzs_out, res, tangent_fun + +def _scan_known_hoisting(jaxpr_known, known_consts, num_res): + # To disable: + # return jaxpr_known, known_consts, [False] * num_res, [] + + consts = [pe.PartialVal.unknown(a) if isinstance(a := typeof(c), AbstractRef) + else pe.PartialVal.known(c) for c in known_consts] + others = _map(pe.PartialVal.unknown, jaxpr_known.in_avals[len(consts):]) + num_known_outs = len(jaxpr_known.out_avals) - num_res + with source_info_util.reset_name_stack(): + jaxpr_known_, pvals_out, new_known_consts = pe.trace_to_jaxpr_nounits( + lu.wrap_init(core.jaxpr_as_fun(jaxpr_known), + debug_info=jaxpr_known.jaxpr.debug_info), + consts + others, instantiate=[True] * num_known_outs + [False] * num_res) + jaxpr_known = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr_known_)) + res_pvals = pvals_out[num_known_outs:] + which_hoisted = [pval.is_known() for pval in res_pvals] + hoisted_res = [pval.get_known() for pval in res_pvals if pval.is_known()] + mut_consts = [c for c in known_consts if isinstance(typeof(c), AbstractRef)] + return jaxpr_known, [*new_known_consts, *mut_consts], which_hoisted, hoisted_res + + def _scan_partial_eval(trace, *tracers, reverse: bool, length: int, num_consts: int, num_carry: int, - jaxpr: core.ClosedJaxpr, linear: Sequence[bool], + jaxpr: ClosedJaxpr, linear: Sequence[bool], unroll: int, _split_transpose: bool): num_ys = len(jaxpr.out_avals) - num_carry unknowns = [not t.pval.is_known() for t in tracers] @@ -603,13 +774,19 @@ def _scan_partial_eval(trace, *tracers, reverse: bool, # Fixpoint computation of which carry elements are unknown. Each iteration # promotes at least one carry to unknown. We need at most len(carry) - # iterations, but we need one last iteration to prepare the jaxpr based on the - # final carry_uk. + # iterations to decide carry_uk, plus one to prepare the jaxpr. carry_uk = init_uk + # Don't allow forwarding from the carry or numpy.ndarrays. + fwd = [ + (i < num_consts or i >= num_consts + num_carry) and + not isinstance(t.pval.get_known(), (np.ndarray, literals.TypedNdArray)) + for i, t in enumerate(tracers) + ] for _ in range(1 + len(carry_uk)): unknowns = const_uk + carry_uk + xs_uk - jaxpr_known, jaxpr_unknown, out_uk, res_avals = pe.partial_eval_jaxpr_nounits( - jaxpr, unknowns, instantiate=carry_uk + [False] * num_ys) + jaxpr_known, jaxpr_unknown, out_uk, res_avals, in_fwd_res = \ + pe.partial_eval_jaxpr_nounits_fwd( + jaxpr, unknowns, instantiate=carry_uk + [False] * num_ys, fwd=fwd) carry_uk_out, ys_uk = split_list(out_uk, [num_carry]) if carry_uk_out == carry_uk: break @@ -617,323 +794,225 @@ def _scan_partial_eval(trace, *tracers, reverse: bool, carry_uk = _map(operator.or_, carry_uk, carry_uk_out) else: assert False, "Fixpoint not reached" - num_res = len(res_avals) + num_res_out, num_res_in = len(res_avals), len(in_fwd_res) + num_knowns_out = len(jaxpr_known.out_avals) - num_res_out + num_consts_known = num_consts - sum(const_uk) + num_carry_known = num_carry - sum(carry_uk) del res_avals, carry_uk_out # Instantiate those inputs which must be treated as unknown from the fixpoint. - tracers = tuple(trace.instantiate_const(t) if uk else t - for t, uk in zip(tracers, unknowns)) - - # The residual inputs and outputs of the jaxprs produced haven't yet been - # adapted to the scan calling convention; in particular, jaxpr_known has its - # residual outputs all at the end, meaning they're extensive outputs (which is - # fully general but may be wasteful for residuals which are loop-invariant) - # while jaxpr_unknown has its corresponding residual inputs at the front (just - # as a convention with partial_eval_jaxpr_nounits), making them constant - # inputs. To make them consistent, we move the residual inputs on - # jaxpr_unknown to the end, even though we may move some back in the sequel. + tracers = [trace.instantiate_const(t) if uk else t + for t, uk in zip(tracers, unknowns)] + known_ins = [t.pval.get_known() for t in tracers if t.pval.is_known()] + unknown_ins = [t for t in tracers if not t.pval.is_known()] + + # At this point all non-forwarded residuals are treated as extensive outputs + # of jaxpr_known. Hoist out those that only depend on consts. + # Before: jaxpr_known: [*known_ins] -> [*known_outs, *non_fwd_res] + # After: jaxpr_known: [*known_consts_, *known_ins] -> [*known_outs, *ext_res] + # where, modulo hoisted res not being broadcast, we have + # non_fwd_res = merge_lists(which_hoisted, ext_res, hoisted_res) + known_consts, known_ins = split_list(known_ins, [num_consts_known]) + jaxpr_known, known_consts_, which_hoisted, hoisted_res = \ + _scan_known_hoisting(jaxpr_known, known_consts, num_res_out) + del num_res_out # changed + + # To make jaxpr_unknown match the scan calling convention, move to the back + # binders that don't correspond to hoisted or const-forwarded residuals. + # Before: jaxpr_unknown: [*res, *unknown_ins] -> [*unkown_outs] + # After: jaxpr_unkonwn: [*int_res, *unknown_ins, *ext_res] -> [*unknown_outs] + num_unk_in = len(jaxpr_unknown.in_avals) - num_res_in + which_hoisted_ = iter(which_hoisted) + res_to_move = [not next(which_hoisted_) if f is None else + f >= len(jaxpr.consts) + num_consts_known + num_carry_known + for f in in_fwd_res] + assert next(which_hoisted_, None) is None jaxpr_unknown = pe.move_binders_to_back( - jaxpr_unknown, [True] * num_res + [False] * sum(unknowns)) - - # At this point, all residuals are treated as extensive outputs of jaxpr_known - # (and extensive inputs to jaxpr_unknown). But residuals that are loop- - # invariant can be hoisted out of the scan, rather than letting them get - # broadcast (as in e.g. scanning multiplication by a constant matrix; we don't - # want to broadcast the matrix!). So, outside the loop we perform a partial - # evaluation with known 'const' inputs (but all other inputs unknown). - const_pvals = [pe.PartialVal.known(t.pval.get_known()) - for t in tracers[:num_consts] if t.pval.is_known()] - other_pvals = [pe.PartialVal.unknown(aval) - for aval in jaxpr_known.in_avals[len(const_pvals):]] - with source_info_util.reset_name_stack(): - jaxpr_known_, invar_pvals_out, jaxpr_known_consts = pe.trace_to_jaxpr_nounits( - lu.wrap_init(core.jaxpr_as_fun(jaxpr_known), - debug_info=jaxpr_known.jaxpr.debug_info), - const_pvals + other_pvals, - instantiate=[True] * (len(out_uk) - sum(out_uk)) + [False] * num_res) - jaxpr_known = pe.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr_known_), ()) - # The above trace_to_jaxpr_nounits call computed loop-invariant residuals - # (known values in invar_pvals_out) and also computed loop-invariant values - # needed by the new jaxpr_known (in jaxpr_known_consts, which replace the - # previous consts). We need to collect the computed inteisive residuals, and - # move corresponding intensive residual binders in jaxpr_unknown to the front. - res_pvals = invar_pvals_out[len(invar_pvals_out) - num_res:] - intensive_res = [pval.get_known() for pval in res_pvals if pval.is_known()] - jaxpr_unknown = pe.move_binders_to_front( - jaxpr_unknown, - [False] * sum(unknowns) + [pval.is_known() for pval in res_pvals]) - del const_pvals, other_pvals, invar_pvals_out, jaxpr_known_, res_pvals - # We use `jaxpr_known_consts` when we call scan_p.bind with jaxpr_known, and - # we use `intensive_res` when we build the jaxpr eqn with jaxpr_unknown. - - # As another optimization, for any extensive inputs that are just forwarded to - # extensive outputs, to avoid a copy (which would be looping over - # dynamic-update-slice) we'd rather forward the input tracer/value. That means - # pruning some outputs from jaxpr_known here, and updating `out_flat` below. - fwds_known = pe._jaxpr_forwarding(jaxpr_known.jaxpr) - # Prune fwds_known to include only extensive input to extensive output. - fwds_known = [in_idx if out_idx >= num_carry - sum(carry_uk) and - in_idx is not None and - in_idx >= len(jaxpr_known_consts) + num_carry - sum(carry_uk) - else None for out_idx, in_idx in enumerate(fwds_known)] - # Drop any extensive output we can instead get by forwarding an input. - # TODO(mattjj): use pe.dce_jaxpr here, though need a fixpoint - jaxpr_known_, () = jaxpr_known.jaxpr, jaxpr_known.consts - jaxpr_known_ = jaxpr_known_.replace( - outvars=[x for x, i in zip(jaxpr_known_.outvars, fwds_known) if i is None]) - jaxpr_known = core.ClosedJaxpr(jaxpr_known_, ()) - del jaxpr_known_ - # We use `fwds_known` below when forming the output of scanning jaxpr_known. + jaxpr_unknown, res_to_move + [False] * num_unk_in) # Run the known part of the scan (if it has any outputs or effects). - known_inputs = (list(jaxpr_known_consts) + - [t.pval.get_known() for t in tracers[num_consts:] - if t.pval.is_known()]) + linear_known, linear_unknown = partition_list(unknowns, linear) if not jaxpr_known.out_avals and not jaxpr_known.effects: - out_known = [] + known_outs_ext_res = [] else: - linear_known = [False] * len(known_inputs) # conservative! - out_known = scan_p.bind( - *known_inputs, reverse=reverse, length=length, jaxpr=jaxpr_known, - num_consts=len(jaxpr_known_consts), num_carry=num_carry - sum(carry_uk), - linear=tuple(linear_known), unroll=unroll, + linear_known = [False] * len(jaxpr_known.in_avals) # TODO conservative + assert len(known_consts_) + len(known_ins) == len(jaxpr_known.in_avals) + known_outs_ext_res = scan_p.bind( + *known_consts_, *known_ins, jaxpr=jaxpr_known, reverse=reverse, + length=length, num_consts=len(known_consts_), + num_carry=num_carry_known, linear=(*linear_known,), unroll=unroll, _split_transpose=_split_transpose) - del linear_known - # Complete the known output by filling in forwarded values using fwds_known. - out_known_iter = iter(out_known) - out_known = [next(out_known_iter) if f is None - else _maybe_put(known_inputs[f]) for f in fwds_known] - assert next(out_known_iter, None) is None - del known_inputs, out_known_iter - - # Split known outputs from residuals. - out_known, extensive_res = split_list(out_known, [len(out_uk) - sum(out_uk)]) - assert len(intensive_res) + len(extensive_res) == num_res + known_outs, ext_res = split_list(known_outs_ext_res, [num_knowns_out]) + + # Complete non_fwd_res and then res, then split to match binders. + non_fwd_res = merge_lists(which_hoisted, ext_res, hoisted_res) + non_fwd_res_ = iter(non_fwd_res) + res = [next(non_fwd_res_) if f is None + else [*jaxpr.consts, *known_consts, *known_ins][f] for f in in_fwd_res] + assert next(non_fwd_res_, None) is None + int_res, ext_res = partition_list(res_to_move, res) # Create input tracers for jaxpr_unknown bind. unknown_inputs = [t for t in tracers if not t.pval.is_known()] - intensive_res = _map(trace.new_instantiated_const, intensive_res) - extensive_res = _map(trace.new_instantiated_const, extensive_res) + int_res = _map(trace.new_instantiated_const, int_res) + ext_res = _map(trace.new_instantiated_const, ext_res) # Create output tracers for jaxpr_unknown bind, adapting extensive shapes. carry_avals, y_avals = split_list(jaxpr_unknown.out_avals, [sum(carry_uk)]) - ys_avals = [core.unmapped_aval(length, 0, y_aval) - for y_aval in y_avals] + ys_avals = [core.unmapped_aval(length, 0, y_aval) for y_aval in y_avals] out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None) - for a in itertools.chain(carry_avals, ys_avals)] + for a in it.chain(carry_avals, ys_avals)] del carry_avals, y_avals # Create equation. - linear_unknown = tuple([False] * len(intensive_res) + - [l for l, uk in zip(linear, unknowns) if uk] + - [False] * len(extensive_res)) + linear_unknown = [False] * len(int_res) + linear_unknown + [False] * len(ext_res) + assert len(linear_unknown) == len(jaxpr_unknown.in_avals) name_stack = source_info_util.current_name_stack()[len(trace.name_stack):] source = source_info_util.current().replace(name_stack=name_stack) - assert len(out_tracers) == len(jaxpr_unknown.out_avals) - eqn = pe.new_eqn_recipe([*intensive_res, *unknown_inputs, *extensive_res], - out_tracers, scan_p, + unknown_tracers_in = [*int_res, *unknown_inputs, *ext_res] + eqn = pe.new_eqn_recipe(trace, unknown_tracers_in, out_tracers, scan_p, dict(reverse=reverse, length=length, unroll=unroll, - jaxpr=jaxpr_unknown, linear=linear_unknown, - num_consts=len(intensive_res) + sum(const_uk), + jaxpr=jaxpr_unknown, linear=(*linear_unknown,), + num_consts=len(int_res) + sum(const_uk), num_carry=sum(carry_uk), _split_transpose=_split_transpose), jaxpr_unknown.effects, source) for t in out_tracers: t.recipe = eqn + if effects.partial_eval_kept_effects.filter_in(jaxpr_unknown.effects): + trace.effect_handles.append(pe.EffectHandle(unknown_tracers_in, eqn)) # Merge known and unknown outputs into final result. - return util.merge_lists(out_uk, out_known, out_tracers) + return util.merge_lists(out_uk, known_outs, out_tracers) def _maybe_put(x): - if isinstance(x, np.ndarray): + if isinstance(x, (np.ndarray, literals.TypedNdArray)): aval = core.shaped_abstractify(x) s = sharding.SingleDeviceSharding(xb.local_devices(backend='cpu')[0]) result_handler = pxla.global_aval_to_result_handler(aval, s, False) - return result_handler(pxla.shard_args([s], [None], [None], [x])) + return result_handler( + pxla.shard_args( + [s], [None], [dispatch.ArrayCopySemantics.REUSE_INPUT], [x] + ) + ) else: return x -def _scan_transpose(cts, *args, reverse, length, num_consts, - num_carry, jaxpr, linear, unroll, _split_transpose): - # we've only implemented transposing scans with specific lin/nonlin patterns +@weakref_lru_cache +def _rearrange_mutable_binders( + jaxpr: ClosedJaxpr, num_prefix: int, num_binders: int +) -> ClosedJaxpr: + fst, invars, rst = split_list(jaxpr.jaxpr.invars, [num_prefix, num_binders]) + is_mutable = [isinstance(v.aval, AbstractRef) for v in invars] + immut_invars, mut_invars = partition_list(is_mutable, invars) + new_invars = [*fst, *mut_invars, *immut_invars, *rst] + if jaxpr.jaxpr.debug_info.arg_names is None: + new_arg_names = None + else: + fst, names, rst = split_list(jaxpr.jaxpr.debug_info.arg_names, + [num_prefix, num_binders]) + immut_names, mut_names = partition_list(is_mutable, names) + new_arg_names = [*fst, *mut_names, *immut_names, *rst] + dbg = jaxpr.jaxpr.debug_info._replace(arg_names=new_arg_names) + + # TODO(mattjj): don't we need to re-number effects? test coverage? + new_effs = pe._renumber_effects((*jaxpr.jaxpr.constvars, *new_invars), + (*jaxpr.jaxpr.constvars, *jaxpr.jaxpr.invars), + jaxpr.jaxpr.effects) + new_jaxpr = jaxpr.jaxpr.replace(invars=new_invars, effects=new_effs, + debug_info=dbg) + if config.enable_checks.value: core.check_jaxpr(new_jaxpr) + return ClosedJaxpr(new_jaxpr, jaxpr.consts) + +def _scan_transpose_fancy(cts, *args, reverse, length, num_consts, + num_carry, jaxpr, linear, unroll, _split_transpose): consts_lin, init_lin, xs_lin = split_list(linear, [num_consts, num_carry]) num_ires = len(consts_lin) - sum(consts_lin) num_eres = len(xs_lin) - sum(xs_lin) - if consts_lin != [False] * num_ires + [True] * (len(consts_lin) - num_ires): - raise NotImplementedError - if xs_lin != [True] * (len(xs_lin) - num_eres) + [False] * num_eres: - raise NotImplementedError - if not all(init_lin): - pass # TODO(mattjj): error check https://github.com/jax-ml/jax/issues/1963 - - consts, _, xs = split_list(args, [num_consts, num_carry]) - ires, _ = split_list(consts, [num_ires]) - _, eres = split_list(xs, [sum(xs_lin)]) - assert not any(ad.is_undefined_primal(r) for r in ires) - assert not any(ad.is_undefined_primal(r) for r in eres) - - carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry]) + + # Rearrange jaxpr binders to separate out refs since we in/out swap pure vals: + # Before: [ires, T d, T c, T a, eres] -> [T c, T b] + # After: [ires, T d_mut, T d_pure, T c, T a_mut, T a_pure, eres] -> [T c, T b] + # where + # * `ires` means intensive (not scanned over / const) residuals, all Arrays; + # * `T d` means the intensive tangents, each a linear GradAccum or nonlinear + # plumbing ref or linear (zero) Array; + # * `T c` means the carry tangents; + # * `T a` means the extensive (scanned over) input tangents; + # * `eres` means the extensive residuals; + # * `T b` means the extensive tangent outputs. + ires, consts_dot, carry_dot, xs_dot, eres = split_list( + args, [num_ires, num_consts - num_ires, num_carry, sum(xs_lin)]) + is_mutable = [isinstance(x, ad.RefAccum) or not isinstance(x, ad.GradAccum) + and isinstance(typeof(x), AbstractRef) for x in consts_dot] + immut_consts_dot, mut_consts_bar = partition_list(is_mutable, consts_dot) + jaxpr = _rearrange_mutable_binders(jaxpr, num_ires, num_consts - num_ires) + is_mutable_ = [isinstance(x, ad.RefAccum) or not isinstance(x, ad.GradAccum) + and isinstance(typeof(x), AbstractRef) for x in xs_dot] + immut_xs_dot, mut_xs_bar = partition_list(is_mutable_, xs_dot) + jaxpr = _rearrange_mutable_binders(jaxpr, num_consts + num_carry, sum(xs_lin)) + del consts_dot, xs_dot, args + + # prepare cotangent values to be passed in to transpose ct_carry, ct_ys = split_list(cts, [num_carry]) - ct_carry = _map(ad.instantiate_zeros, ct_carry) - ct_ys_is_zeros = tuple(type(ct_y) is ad.Zero for ct_y in ct_ys) - ct_ys = [x for x in ct_ys if type(x) is not ad.Zero] - - ct_consts = _map(ad_util.zeros_like_aval, jaxpr.in_avals[num_ires:num_consts]) - - # jaxpr :: [ires, T d] -> [T c] -> [T a, eres] -> ([T c], [T b]) - # jaxpr_trans :: [ires] -> [CT d, CT c] -> [CT b, eres] -> ([CT d, CT c], [CT a]) - jaxpr_trans, attrs_tracked = _transpose_scan_jaxpr( - jaxpr, num_ires, num_consts - num_ires, num_eres, ct_ys_is_zeros) - linear_trans = ([False] * num_ires + [False] * len(attrs_tracked) + - [True] * (len(ct_consts) + len(ct_carry) + len(ct_ys)) + + ct_carry = _map(ad.instantiate_zeros, ct_carry) # TODO(mattjj): fixpoint + ct_ys_nz = [x for x in ct_ys if type(x) is not ad.Zero] + + # initialize values to be used to accumulate pure constant gradients + immut_const_avals = jaxpr.in_avals[num_ires+len(mut_consts_bar):num_consts] + ct_immut_consts = _map(ad_util.zeros_like_aval, immut_const_avals) + + # prepare transpose inputs, unboxing RefAccums while noting which are linear + trans_in, trans_tree = tree_flatten([ires, mut_consts_bar, ct_immut_consts, + ct_carry, mut_xs_bar, ct_ys, eres]) + lin_refs = tuple(isinstance(x, ad.RefAccum) for x in trans_in) + trans_in = [x.inst().ref if l else x for l, x in zip(lin_refs, trans_in)] + + # prepare transposed jaxpr + trans_avals, ext_avals = split_list(_map(ad.accum_typeof, trans_in), [num_consts+num_carry]) + trans_avals = trans_avals + [core.mapped_aval(length, 0, a) for a in ext_avals] + xs_avals = tuple(core.mapped_aval(length, 0, ad.accum_typeof(x)) for x in immut_xs_dot) + jaxpr_trans = _transpose_scan_jaxpr_fancy( + jaxpr, trans_tree, tuple(trans_avals), lin_refs, xs_avals) + + # run it + linear_trans = ([False] * num_ires + + [True] * (len(mut_consts_bar) + len(immut_consts_dot) + + len(carry_dot) + len(mut_xs_bar) + len(ct_ys_nz)) + [False] * num_eres) - in_state = _get_states(attrs_tracked) - - transpose_inputs = *ires, *in_state, *ct_consts, *ct_carry, *ct_ys, *eres - transpose_num_out_carry = num_consts-num_ires+num_carry+len(attrs_tracked) - - if not _split_transpose: - outs = scan_p.bind( - *transpose_inputs, - reverse=not reverse, length=length, jaxpr=jaxpr_trans, - num_consts=num_ires, - num_carry=transpose_num_out_carry, - linear=tuple(linear_trans), unroll=unroll, - _split_transpose=False) - else: - inst_mask = [False] * transpose_num_out_carry + [True] * ( - len(jaxpr_trans.out_avals) - transpose_num_out_carry) - - unknowns_mask = [False] * (len(transpose_inputs) - len(eres)) + [ - True - ] * len(eres) - - # The residuals may contain original parameters (e.g. forwarded extensive - # array arguments) and residuals from the primal. Hence we iterate and - # update all values of the mask that we've set to True (i.e. 'unknown') to - # see if we should actually push them to the known computation in order to - # perform the scan (known) - map (unknown) split. The test effectively is - # done by comparing the output masks. - # - # TODO(dvytin): improve performance by doing backwards abstract eval. - # - # For example, a mask arising from a relu() is an extensive residual, yet - # only really used in the backpropagation scan, not in the unknown map. But - # an intermediate activation of a matmul will be used only in the map part. - # If we were to erroneously push the relu mask to the unknown part, then, - # in the output, the partial evaluator will also pull the loop-carried state - # to the unknown, and that is something we can test by comparing the output - # mask of pe against our intended inst mask. - for index in range(len(jaxpr_trans.in_avals)): - if unknowns_mask[index]: - mask_for_dependence = [False]*len(jaxpr_trans.in_avals) - mask_for_dependence[index] = True # try moving this to unknown - _, _, outs_for_dependence, _ = pe.partial_eval_jaxpr_nounits( - jaxpr_trans, mask_for_dependence, inst_mask) - if inst_mask != outs_for_dependence: - unknowns_mask[index] = False - - jaxpr_known_body, jaxpr_unknown_body, outs_mask, res_avals = ( - pe.partial_eval_jaxpr_nounits(jaxpr_trans, unknowns_mask, inst_mask) - ) - - num_knowns = len(outs_mask) - sum(outs_mask) - - linear_list = list(linear_trans) - known_linear = [ - l for mask, l in zip(unknowns_mask, linear_list) if not mask - ] - unknown_linear = [l for mask, l in zip(unknowns_mask, linear_list) if mask] - unknown_linear = [False] * len(res_avals) + unknown_linear - - known_args = [ - arg for mask, arg in zip(unknowns_mask, transpose_inputs) if not mask - ] - unknown_args = [ - arg for mask, arg in zip(unknowns_mask, transpose_inputs) if mask - ] - # 1. Apply the known scan. - knowns_and_residual = scan_p.bind( - *known_args, - reverse=not reverse, - length=length, - num_consts=num_ires, - num_carry=transpose_num_out_carry, - jaxpr=jaxpr_known_body, - linear=tuple(known_linear), - unroll=unroll, - _split_transpose=False, # Just generate the loop now. - ) - known_results, residuals = split_list(knowns_and_residual, [num_knowns]) - - # 2. Apply the unknown map to residuals and unknown arguments. - unknown_results = scan_p.bind( - *residuals, *unknown_args, - reverse=reverse, # Keep reverse as is for better scheduling. - length=length, - num_consts=0, - num_carry=0, - jaxpr=jaxpr_unknown_body, - linear=tuple(unknown_linear), - unroll=unroll, - _split_transpose=False, # Just generate the loop now. - ) - known_results_iter = iter(known_results) - unknown_results_iter = iter(unknown_results) - outs = [ - next(known_results_iter) if not mask else next(unknown_results_iter) - for mask in outs_mask - ] - - out_state, outs = split_list(outs, [len(attrs_tracked)]) - _set_states(attrs_tracked, out_state) - ct_consts, ct_init, ct_xs = split_list(outs, [num_consts - num_ires, num_carry]) - return [None] * num_ires + ct_consts + ct_init + ct_xs + [None] * num_eres + outs = scan_p.bind( + *trans_in, reverse=not reverse, length=length, jaxpr=jaxpr_trans, + num_consts=num_ires + len(mut_consts_bar), + num_carry=len(immut_consts_dot) + len(carry_dot), + linear=tuple(linear_trans), unroll=unroll, _split_transpose=False) + for a, x in zip([*immut_consts_dot, *carry_dot, *immut_xs_dot], outs): + if isinstance(a, ad.GradAccum): a.accum(x) -# transpose_scan_jaxpr :: ([res1, c, a, res2] -> b) -# -> ([res1, CT c, CT b, res2] -> [CT c, CT a]) +# transpose_scan_jaxpr converts the jaxpr signature: +# Before: [(ires, T d_mut T d_pure), T c, (CT a_mut, T a, eres)] -> [T c, T b] +# ---------- consts ----------- --------- ext ------- +# +# After: [(ires, CT d_mut), (CT d_pure, CT c), (CT a_mut, CT b, eres)] -> [(CT d_pure, CT c), CT a] +# --- consts ---- ----- carry ------ --------- ext -------- @weakref_lru_cache -def _transpose_scan_jaxpr(jaxpr: core.ClosedJaxpr, - num_res1: int, num_c: int, num_res2: int, - ct_ys_is_zeros: Sequence[bool]): - num_a = len(jaxpr.in_avals) - num_res1 - num_c - num_res2 - # TODO: allow input cotangent avals to be batched relative to jaxpr.in_avals - # if an axis isn't reduced - res1_avals, c_avals, a_avals, res2_avals = split_list( - jaxpr.in_avals, [num_res1, num_c, num_a]) - - num_ys = len(ct_ys_is_zeros) - num_b = len(jaxpr.out_avals) - num_ys - # TODO: Also propagate ad.Zero through b_carry_avals until fixed point. - b_carry_avals, b_ys_avals = split_list(list(jaxpr.out_avals), [num_b]) - b_ys_avals_stripped = [ - aval for aval, is_zero in zip(b_ys_avals, ct_ys_is_zeros) if not is_zero - ] - - def transposed(*res1_cbar_bbar_res2): - res1, c_bar, b_bar, ys_bar_stripped, res2 = split_list( - res1_cbar_bbar_res2, - [num_res1, num_c, num_b, len(b_ys_avals_stripped)]) - ys_bar_stripped_iter = iter(ys_bar_stripped) - ys_bar = [ - ad.Zero(aval) if is_zero else next(ys_bar_stripped_iter) - for aval, is_zero in zip(b_ys_avals, ct_ys_is_zeros) - ] - # TODO(mattjj): c_avals should be _tangent_ types here... - primals = (res1 + [ad.UndefinedPrimal(aval) for aval in c_avals] + - [ad.UndefinedPrimal(aval) for aval in a_avals] + res2) - cbar_abar = ad.backward_pass( - jaxpr.jaxpr, False, jaxpr.consts, primals, b_bar + ys_bar) - _, new_c_bar, a_bar, _ = split_list(cbar_abar, [num_res1, num_c, num_a]) - a_bar = _map(ad.instantiate_zeros, a_bar) - c_bar = _map(ad.instantiate_zeros, _map(ad.add_tangents, c_bar, new_c_bar)) - return c_bar + a_bar - - # TODO(necula): fix arg names and results for transposed - transposed_wrapped = lu.wrap_init(transposed, - debug_info=jaxpr.jaxpr.debug_info) - return _make_closed_jaxpr_attrs( - transposed_wrapped, - tuple(res1_avals + c_avals + b_carry_avals + - b_ys_avals_stripped + res2_avals)) +def _transpose_scan_jaxpr_fancy( + jaxpr, trans_tree, trans_avals, lin_refs, immut_xs_avals +) -> core.ClosedJaxpr: + def transposed(*args): + args = [ad.RefAccum(typeof(x).inner_aval, x) if l else x + for l, x in zip(lin_refs, args)] + ires, mut_consts_bar, ct_immut_consts, ct_carry, mut_xs_bar, ct_ys, eres = \ + tree_unflatten(trans_tree, args) + immut_consts_dot = [ad.ValAccum(core.get_aval(x), x) for x in ct_immut_consts] + carry_dot = [ad.ValAccum(core.get_aval(x)) for x in ct_carry] + immut_xs_dot = [ad.ValAccum(a) for a in immut_xs_avals] + primals = (ires + mut_consts_bar + immut_consts_dot + carry_dot + mut_xs_bar + + immut_xs_dot + eres) + ad.backward_pass3(jaxpr.jaxpr, False, jaxpr.consts, primals, ct_carry + ct_ys) + return [ad.instantiate_zeros(x.freeze()) for x in primals + if isinstance(x, ad.ValAccum)] + + dbg = jaxpr.jaxpr.debug_info.with_unknown_names() + transposed_wrapped = lu.wrap_init(transposed, debug_info=dbg) + return _make_closed_jaxpr(transposed_wrapped, trans_avals) def _scan_batching_rule(axis_data, args, @@ -986,10 +1065,7 @@ def _scan_batching_rule(axis_data, args, @weakref_lru_cache def _cached_scan_pad_jaxpr(jaxpr): - return core.ClosedJaxpr(*pe.pad_jaxpr(jaxpr.jaxpr, jaxpr.consts)) - -def _scan_padding_rule(in_avals, out_avals, *args, jaxpr, **params): - return scan_p.bind(*args, jaxpr=_cached_scan_pad_jaxpr(jaxpr), **params) + return ClosedJaxpr(*pe.pad_jaxpr(jaxpr.jaxpr, jaxpr.consts)) def _scan_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn ) -> tuple[list[bool], core.JaxprEqn | None]: @@ -1017,12 +1093,12 @@ def _scan_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn new_linear = [l for l, u in zip(eqn.params['linear'], used_inputs) if u] new_params = dict(eqn.params, num_consts=sum(used_consts), num_carry=sum(used_carry_in), linear=tuple(new_linear), - jaxpr=core.ClosedJaxpr(jaxpr_dce, jaxpr.consts)) + jaxpr=ClosedJaxpr(jaxpr_dce, jaxpr.consts)) # TODO(mattjj,sharadmv): don't assume effects are never DCE'd? new_invars = [v for v, used in zip(eqn.invars, used_inputs) if used] new_outvars = [v for v, used in zip(eqn.outvars, used_outputs) if used] - _, new_effects = eqn.primitive.abstract_eval(*[v.aval for v in new_invars], - **new_params) + _, new_effects = eqn.primitive.abstract_eval( + *[v.aval for v in new_invars], **new_params) new_eqn = pe.new_jaxpr_eqn( new_invars, new_outvars, @@ -1054,8 +1130,8 @@ def _scan_partial_eval_custom(saveable, unks_in, inst_in, eqn): carry_uk = _map(operator.or_, carry_uk, carry_uk_out) else: assert False, "Fixpoint not reached" - jaxpr_known = core.ClosedJaxpr(jaxpr_known_ , jaxpr.consts) - jaxpr_staged = core.ClosedJaxpr(jaxpr_staged_, jaxpr.consts) + jaxpr_known = ClosedJaxpr(jaxpr_known_ , jaxpr.consts) + jaxpr_staged = ClosedJaxpr(jaxpr_staged_, jaxpr.consts) # Move all residual binders to the back of jaxpr_staged so they're extensive. # TODO(mattjj): make jaxpr_staged only take instantiated inputs @@ -1074,10 +1150,12 @@ def _scan_partial_eval_custom(saveable, unks_in, inst_in, eqn): num_const_known = len(const_uk) - sum(const_uk) num_carry_known = len(carry_uk) - sum(carry_uk) num_xs_known = len( xs_uk) - sum( xs_uk) + const_donthoist = [isinstance(a, state.AbstractRef) + for a in jaxpr_known.in_avals[:num_const_known]] jaxpr_known_hoist, jaxpr_known_loop, loop_dep, consts_known_lp_avals = \ pe.partial_eval_jaxpr_nounits( jaxpr_known, - [False] * num_const_known + [True] * (num_carry_known + num_xs_known), + const_donthoist + [True] * (num_carry_known + num_xs_known), [True] * (len(unks_out) - sum(unks_out)) + [False] * num_res) # jaxpr_known_hoist produces intensive residuals followed by the constants for # jaxpr_known_loop. We adjust jaxpr_staged to accept intensive res as consts. @@ -1103,26 +1181,29 @@ def _scan_partial_eval_custom(saveable, unks_in, inst_in, eqn): # (corresponding to consts_known_lp_avals) followed by known carry and xs. linear_known_ = [l for l, uk in zip(eqn.params['linear'], unks_in) if not uk] _, linear_known_ = split_list(linear_known_, [num_const_known]) - linear_known = [False] * len(consts_known_lp_avals) + linear_known_ params_known = dict(eqn.params, jaxpr=jaxpr_known_loop, - num_consts=len(consts_known_lp_avals), - num_carry=len(carry_uk)-sum(carry_uk), - linear=tuple(linear_known)) + num_carry=len(carry_uk)-sum(carry_uk)) def known(*ins_known): - consts_known_hoist, ins_known_lp = split_list(ins_known, [num_const_known]) + consts_known_maybehoist, ins_known_lp = split_list(ins_known, [num_const_known]) + consts_known_hoist, consts_known_donthoist = \ + partition_list(const_donthoist, consts_known_maybehoist) out_hoist = core.jaxpr_as_fun(jaxpr_known_hoist)(*consts_known_hoist) intensive_res, consts_known_lp = split_list(out_hoist, [num_intensive_res]) - out_loop = scan_p.bind(*consts_known_lp, *ins_known_lp, **params_known) + num_consts = len(consts_known_lp) + len(consts_known_donthoist) + linear_known = (False,) * num_consts + (False,) * len(ins_known_lp) + out_loop = scan_p.bind( + *consts_known_lp, *consts_known_donthoist, *ins_known_lp, + **dict(params_known, linear=linear_known, num_consts=num_consts)) return [*intensive_res, *out_loop] - call_jaxpr_, _, call_jaxpr_consts, () = pe.trace_to_jaxpr_dynamic( + call_jaxpr_, _, call_jaxpr_consts = pe.trace_to_jaxpr_dynamic( lu.wrap_init(known, debug_info=jaxpr_known_hoist.jaxpr.debug_info), [v.aval for v in ins_known]) - call_jaxpr = core.ClosedJaxpr(call_jaxpr_, call_jaxpr_consts) + call_jaxpr = ClosedJaxpr(call_jaxpr_, call_jaxpr_consts) eqn_known = pe.new_jaxpr_eqn( ins_known, [*intensive_res, *out_binders_known, *extensive_res], - core.closed_call_p, dict(call_jaxpr=call_jaxpr), call_jaxpr.effects, - eqn.source_info, eqn.ctx) + core.closed_call_p, dict(call_jaxpr=call_jaxpr), + core.eqn_effects(call_jaxpr), eqn.source_info, eqn.ctx) # Create the staged eqn. _, out_binders_staged = partition_list(inst_out, eqn.outvars) @@ -1131,12 +1212,17 @@ def known(*ins_known): params_staged = dict(eqn.params, jaxpr=jaxpr_staged, num_consts=len(intensive_res) + eqn.params['num_consts'], linear=tuple(linear_staged)) - eqn_staged = pe.new_jaxpr_eqn([*intensive_res, *eqn.invars, *extensive_res], - out_binders_staged, eqn.primitive, - params_staged, jaxpr_staged.effects, - eqn.source_info, eqn.ctx) + eqn_staged = pe.new_jaxpr_eqn( + [*intensive_res, *eqn.invars, *extensive_res], out_binders_staged, + eqn.primitive, params_staged, core.eqn_effects(jaxpr_staged), + eqn.source_info, eqn.ctx) new_vars = [*new_inst, *intensive_res, *extensive_res] + assert len(eqn_staged.invars) == len(eqn_staged.params['linear']) + for e in [eqn_known, eqn_staged]: + for eff in e.effects: + if isinstance(eff, effects.JaxprInputEffect): + assert isinstance(e.invars[eff.input_index].aval, AbstractRef) return eqn_known, eqn_staged, unks_out, inst_out, new_vars def _scan_typecheck(bind_time, *in_atoms, reverse, length, num_consts, @@ -1151,10 +1237,10 @@ def _scan_typecheck(bind_time, *in_atoms, reverse, length, num_consts, type(num_consts) is int and num_consts >= 0) tc(num_carry, 'num_carry', 'non-negative int', type(num_carry) is int and num_carry >= 0) - tc(jaxpr, 'jaxpr', 'ClosedJaxpr', type(jaxpr) is core.ClosedJaxpr) + tc(jaxpr, 'jaxpr', 'ClosedJaxpr', type(jaxpr) is ClosedJaxpr) tc(linear, 'linear', 'tuple of bool', type(linear) is tuple and all(type(x) is bool for x in linear)) - tc(unroll, 'unroll', 'positive int', type(unroll) is int and unroll > 0) + tc(unroll, 'unroll', 'non-negative int', type(unroll) is int and unroll >= 0) tc(length, 'length', 'non-negative int', length >= 0) @@ -1186,128 +1272,140 @@ def _scan_typecheck(bind_time, *in_atoms, reverse, length, num_consts, raise core.JaxprTypeError( f'scan jaxpr takes input sequence types\n{_avals_short(x_avals_jaxpr)},\n' f'called with sequence whose items have type\n{_avals_short(x_avals_mapped)}') - return [*init_avals, *y_avals], jaxpr.effects - -def _scan_state_partial_discharge_rule(should_discharge, in_avals, out_avals, *args, jaxpr, num_consts, - num_carry, linear, unroll, reverse, length, - _split_transpose): - # We're shuffling parameters between three signatures for the scan body: - # jaxpr : (n_consts, n_carry, n_xs) -> (n_carry, n_ys) - # discharged : (n_consts, n_carry, n_xs) -> (n_carry, n_ys, n_ref_consts, n_ref_xs) - # wrapped : (n_val_consts, (n_ref_consts, n_carry), (n_val_xs, n_ref_xs)) - # -> ((n_ref_consts, n_carry), (n_ys, n_ref_xs)) - # where we partition consts and xs between ref and non-ref versions: - # n_carry = (n_val_consts, n_ref_consts) - # n_xs = (n_val_xs, n_ref_xs) - - # avals from jaxpr (i.e. rank-reduced) rather than from caller - jaxpr, in_avals, out_avals, consts = jaxpr.jaxpr, jaxpr.in_avals, jaxpr.out_avals, jaxpr.consts - if consts: raise NotImplementedError - n_consts = num_consts - n_carry = num_carry - n_xs = len(in_avals) - n_consts - n_carry - n_ys = len(out_avals) - n_carry - consts_avals, carry_avals, xs_avals = split_list_checked(in_avals, - [n_consts, n_carry, n_xs]) - consts_discharge, carry_discharge, xs_discharge = split_list_checked(should_discharge, - [n_consts, n_carry, n_xs]) - - is_ref_const = [s and isinstance(a, state.AbstractRef) for s, a in zip(consts_discharge, consts_avals)] - assert not any(isinstance(a, state.AbstractRef) for a in carry_avals) - assert not any(carry_discharge) - is_ref_xs = [s and isinstance(a, state.AbstractRef) for s, a in zip(xs_discharge, xs_avals)] - n_ref_consts = sum(is_ref_const) - n_val_consts = n_consts - n_ref_consts - n_ref_xs = sum(is_ref_xs) - n_val_xs = n_xs - n_ref_xs - discharged_jaxpr, discharged_consts = state_discharge.discharge_state(jaxpr, (), should_discharge=should_discharge) - if discharged_consts: - raise NotImplementedError("Discharged jaxpr has consts. If you see this, " - "please open an issue at " - "https://github.com/jax-ml/jax/issues") - def wrapped(*wrapped_args): - val_consts, carry_in, ref_consts_in, val_xs, ref_xs_in = split_list_checked(wrapped_args, - [n_val_consts, n_carry, n_ref_consts, n_val_xs, n_ref_xs]) - consts = merge_lists(is_ref_const, val_consts, ref_consts_in) - xs = merge_lists(is_ref_xs, val_xs, ref_xs_in) - outs = core.eval_jaxpr(discharged_jaxpr, (), *consts, *carry_in, *xs) - carry_out, ys, ref_consts_out, ref_xs_out = split_list_checked(outs, - [n_carry, n_ys, n_ref_consts, n_ref_xs]) - return [*carry_out, *ref_consts_out, *ys, *ref_xs_out] - - def arrange_jaxpr_args_for_wrapped(args): - consts, carry_in, xs = split_list_checked(args, [n_consts, n_carry, n_xs]) - val_consts, ref_consts_in = partition_list(is_ref_const, consts) - val_xs, ref_xs_in = partition_list(is_ref_xs, xs) - return *val_consts, *carry_in, *ref_consts_in, *val_xs, *ref_xs_in - - # Rearrange the arguments such that they are: - # val_consts, carry, ref_consts, val_xs, ref_xs - # - # It is important that carry is immediately after the val_consts - # because pallas pattern matches the leading argument type to figure - # out if a scan_p eqn is equivalent to a fori loop (see - # `pallas.utils.pattern_match_scan_to_fori_loop()`). - args_for_wrapped = arrange_jaxpr_args_for_wrapped(args) - linear_for_wrapped = arrange_jaxpr_args_for_wrapped(linear) - avals_for_wrapped = arrange_jaxpr_args_for_wrapped(in_avals) - # Get the const avals that we need to discharge and leave the rest as-is. - deref_const_avals = tuple(c.inner_aval for c in avals_for_wrapped[n_val_consts + n_carry:n_consts + n_carry]) - deref_xs_avals = tuple(x.inner_aval for x in avals_for_wrapped[n_consts + n_carry + n_val_xs:]) - avals_for_wrapped_no_refs = ( - avals_for_wrapped[: n_val_consts + n_carry] - + deref_const_avals - + avals_for_wrapped[n_consts + n_carry :n_consts + n_carry + n_val_xs] - + deref_xs_avals - ) - # TODO(cperivol): avoid tracing the jaxpr twice. When doing so don't - # forget to manage the effects. - new_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(wrapped, debug_info=discharged_jaxpr.debug_info), - avals_for_wrapped_no_refs) - all_out = scan_p.bind(*args_for_wrapped, - jaxpr=core.ClosedJaxpr(new_jaxpr, ()), - length=length, - num_consts=n_val_consts, - num_carry=n_ref_consts + n_carry, - unroll=unroll, - reverse=reverse, - linear=linear_for_wrapped, _split_transpose=_split_transpose) - carry_out, ref_consts_out, ys, ref_xs_out = split_list_checked(all_out, - [n_carry, n_ref_consts, n_ys, n_ref_xs]) - refs_out_matching_in_avals = [ - *merge_lists(is_ref_const, [None] * n_val_consts, ref_consts_out), - *[None] * n_carry, - *merge_lists(is_ref_xs, [None] * n_val_xs, ref_xs_out)] - assert len(refs_out_matching_in_avals) == len(in_avals) - return refs_out_matching_in_avals, [*carry_out, *ys] + return [*init_avals, *y_avals], core.eqn_effects(jaxpr) + +def _scan_state_partial_discharge_rule( + should_discharge, in_avals, out_avals, *args, jaxpr, num_consts, num_carry, + linear, unroll, reverse, length, _split_transpose): + # jaxpr: [*consts, *pure_carry, *xs] -> [*pure_carry, *pure_ys] + # jaxpr_: [*consts, *pure_carry, *xs] -> [*pure_carry, *pure_ys, *ref_outs] + discharged_jaxpr = state_discharge.discharge_state2(jaxpr, should_discharge) + + num_xs = len(args) - num_consts - num_carry + is_ref = [isinstance(a, AbstractRef) and s for a, s in zip(jaxpr.in_avals, should_discharge)] + is_ref_const, _, is_ref_xs = split_list_checked(is_ref, [num_consts, num_carry, num_xs]) + num_const_refs = sum(is_ref_const) + num_xs_refs = sum(is_ref_xs) + num_pure_consts = num_consts - num_const_refs + num_ys = len(jaxpr.out_avals) - num_carry + + ds = partial(slicing.dynamic_index_in_dim, keepdims=False, allow_negative_indices=False) + dus = partial(slicing.dynamic_update_index_in_dim, axis=0, allow_negative_indices=False) + + def body(*consts_carry_xs): + pure_consts, [i_], const_refvals, carry, xs_refvals_, pure_xs = split_list( + consts_carry_xs, [num_pure_consts, 1, num_const_refs, num_carry, num_xs_refs]) + i = length - i_ - 1 if reverse else i_ + xs_refvals = [ds(x, i) for x in xs_refvals_] + consts = merge_lists(is_ref_const, pure_consts, const_refvals) + xs = merge_lists(is_ref_xs, pure_xs, xs_refvals) + outs = eval_jaxpr_p.bind(*consts, *carry, *xs, jaxpr=discharged_jaxpr) + carry, ys, const_refvals, xs_updates = split_list_checked( + outs, [num_carry, num_ys, num_const_refs, num_xs_refs]) + xs_refvals = [dus(x, u, i) for x, u in zip(xs_refvals_, xs_updates)] + return [i_ + 1, *const_refvals, *carry, *xs_refvals, *ys] + + def rearrange(lst): + consts, carry, xs = split_list_checked(lst, [num_consts, num_carry, num_xs]) + pure_consts, ref_consts = partition_list(is_ref_const, consts) + pure_xs, ref_xs = partition_list(is_ref_xs, xs) + return *pure_consts, *ref_consts, *carry, *ref_xs, *pure_xs + + in_avals = rearrange([core.typeof(a) for a in args]) + pure_const_avals, carry_avals, pure_xs_avals = split_list( + in_avals, [num_pure_consts, num_const_refs + num_carry + num_xs_refs]) + pure_x_avals = [core.mapped_aval(length, 0, a) for a in pure_xs_avals] + in_avals = [*pure_const_avals, core.typeof(0), *carry_avals, *pure_x_avals] + + if jaxpr.jaxpr.debug_info.arg_names is None: + arg_names = None + else: + arg_names = rearrange(jaxpr.jaxpr.debug_info.arg_names) + pure_const_names, carry_names, pure_xs_names = split_list( + arg_names, [num_pure_consts, num_const_refs + num_carry + num_xs_refs]) + arg_names = (*pure_const_names, 'iter', *carry_names, *pure_xs_names) + + dbg = jaxpr.jaxpr.debug_info._replace(arg_names=arg_names, result_paths=None) + + new_jaxpr_, _, new_consts = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(body, debug_info=dbg), in_avals) + new_jaxpr = core.ClosedJaxpr(new_jaxpr_, new_consts) + + pure_consts, carry, pure_xs = split_list( + rearrange(args), [num_pure_consts, num_const_refs + num_carry + num_xs_refs]) + _, *outs = scan_p.bind( + *pure_consts, 0, *carry, *pure_xs, jaxpr=new_jaxpr, length=length, + unroll=unroll, reverse=reverse, num_consts=num_pure_consts, + num_carry=1 + num_const_refs + num_carry + num_xs_refs, + linear=(False, *rearrange(linear)), _split_transpose=_split_transpose) + + const_refvals, carry, xs_refvals, ys = split_list( + outs, [num_const_refs, num_carry, num_xs_refs]) + refvals_iter = it.chain(const_refvals, xs_refvals) + refvals_out = [next(refvals_iter) if r else None for r in is_ref] + assert next(refvals_iter, None) is None + return refvals_out, [*carry, *ys] scan_p = core.Primitive("scan") +scan_p.is_effectful = lambda params: bool(params['jaxpr'].effects) # type: ignore scan_p.multiple_results = True scan_p.skip_canonicalization = True scan_p.def_impl(partial(dispatch.apply_primitive, scan_p)) scan_p.def_effectful_abstract_eval(_scan_abstract_eval) ad.primitive_jvps[scan_p] = _scan_jvp -ad.primitive_transposes[scan_p] = _scan_transpose +ad.fancy_transposes[scan_p] = _scan_transpose_fancy +ad.primitive_linearizations[scan_p] = _scan_linearize pe.custom_partial_eval_rules[scan_p] = _scan_partial_eval -xla.register_initial_style_primitive(scan_p) +pxla.register_initial_style_primitive(scan_p) mlir.register_lowering(scan_p, mlir.lower_fun(_scan_impl, multiple_results=True)) batching.fancy_primitive_batchers[scan_p] = _scan_batching_rule core.custom_typechecks[scan_p] = partial(_scan_typecheck, False) pe.partial_eval_jaxpr_custom_rules[scan_p] = _scan_partial_eval_custom -pe.padding_rules[scan_p] = _scan_padding_rule pe.dce_rules[scan_p] = _scan_dce_rule state_discharge.register_partial_discharge_rule(scan_p)(_scan_state_partial_discharge_rule) -def _propagate_mem_kind_scan(*xm, reverse, length, num_consts, num_carry, jaxpr, - linear, unroll, _split_transpose): - return pxla.get_out_memory_kinds_via_propagation(jaxpr) -pxla.memory_kind_propagate_rule[scan_p] = _propagate_mem_kind_scan +def _scan_is_high(*_, jaxpr, **__) -> bool: + return jaxpr.jaxpr.is_high +scan_p.is_high = _scan_is_high # type: ignore + +def _scan_to_lojax(*hi_args, jaxpr, num_carry, num_consts, linear, **params): + # move qdd binders and corresponding hi_args from consts slots to carry slots + to_move = [t.has_qdd for t in jaxpr.in_aval_qdds[:num_consts]] + jaxpr = pe.move_invars_right(jaxpr, to_move) + hi_args = _move_right(hi_args, to_move) + num_consts -= sum(to_move) + num_carry += sum(to_move) + + # expand num_consts, num_carry, linear according to lo types + const_in_avals, carry_in_avals, _ = split_list(jaxpr.in_aval_qdds, [num_consts, num_carry]) + num_consts = sum(len(aval.lo_ty()) for aval in const_in_avals) + num_carry = sum(len(aval.lo_ty()) for aval in carry_in_avals) + linear = [l for aval, l_ in zip(jaxpr.in_aval_qdds, linear) + for l in (l_,) * len(aval.lo_ty())] + + # collect lo input values + lo_args = [lo_val for aval, x in zip(jaxpr.in_aval_qdds, hi_args) + for lo_val in (aval.read_loval(x) if aval.has_qdd + else aval.lower_val(x))] + + # lower the jaxpr and bind it using lo input values + lo_jaxpr = pe.lower_jaxpr(jaxpr) + all_outs = scan_p.bind(*lo_args, jaxpr=lo_jaxpr, num_consts=num_consts, + num_carry=num_carry, linear=tuple(linear), **params) + out_mut, lo_outs = split_list(all_outs, [pe.num_himuts_out(jaxpr)]) + pe.apply_himut(jaxpr, hi_args, out_mut) + return pe.raise_lo_outs(jaxpr.out_avals, lo_outs) +scan_p.to_lojax = _scan_to_lojax + +def _move_right(lst, to_move): + lst, rest = split_list(lst, [len(to_move)]) + left, right = partition_list(to_move, lst) + return [*left, *right, *rest] ### while_loop -@api_boundary +@partial(api_boundary, repro_api_name="jax.lax.while_loop") def while_loop(cond_fun: Callable[[T], BooleanNumeric], body_fun: Callable[[T], T], init_val: T) -> T: @@ -1373,50 +1471,86 @@ def while_loop(cond_fun, body_fun, init_val): # transformation on it), so we fall back to the primitive version. pass - def _create_jaxpr(init_val): - init_vals, in_tree = tree_flatten((init_val,)) - init_avals = tuple(_map(core.get_aval, init_vals)) - cond_dbg = api_util.debug_info("while_cond", cond_fun, (init_val,), {}) - cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr( - cond_fun, in_tree, init_avals, cond_dbg) - body_dbg = api_util.debug_info("while_body", body_fun, (init_val,), {}) - body_jaxpr, body_consts, body_tree = _initial_style_jaxpr( - body_fun, in_tree, init_avals, body_dbg) - if not treedef_is_leaf(cond_tree) or len(cond_jaxpr.out_avals) != 1: + def _create_jaxpr(init_avals): + args_avals = FlatTree.pack(((init_avals,), {})) + cond_jaxpr, cond_out_avals = pe.trace_to_jaxpr(cond_fun, args_avals, cond_dbg) + body_jaxpr, body_out_avals = pe.trace_to_jaxpr(body_fun, args_avals, body_dbg) + if not treedef_is_leaf(cond_out_avals.tree) or len(cond_jaxpr.out_avals) != 1: msg = "cond_fun must return a boolean scalar, but got pytree {}." - raise TypeError(msg.format(cond_tree)) + raise TypeError(msg.format(cond_out_avals.tree)) + pred_aval = cond_jaxpr.out_avals[0] if (not isinstance(pred_aval, ShapedArray) or ShapedArray(pred_aval.shape, pred_aval.dtype) != ShapedArray((), np.bool_)): msg = "cond_fun must return a boolean scalar, but got output type(s) {}." raise TypeError(msg.format(cond_jaxpr.out_avals)) - return init_vals, init_avals, body_jaxpr, in_tree, cond_jaxpr, cond_consts, body_consts, body_tree + + return cond_jaxpr, body_jaxpr, body_out_avals + + cond_dbg = api_util.debug_info("while_cond", cond_fun, (init_val,), {}) + body_dbg = api_util.debug_info("while_body", body_fun, (init_val,), {}) + init_val = FlatTree.flatten(init_val) # type: ignore + init_aval = init_val.map(core.get_aval) # The body input and output avals must match exactly. However, we want to account for # the case when init contains weakly-typed values (e.g. Python scalars), with avals that # may not match the output despite being compatible by virtue of their weak type. # To do this, we compute the jaxpr in two passes: first with the raw inputs, and if # necessary, a second time with modified init values. - init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(init_val) - new_init_vals, changed = _promote_weak_typed_inputs(init_vals, init_avals, body_jaxpr.out_avals) - new_init_val, = tree_unflatten(in_tree, new_init_vals) - if changed: - init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(new_init_val) - cond_jaxpr, cond_consts, body_consts, body_tree = rest - - in_tree_children = in_tree.children() - assert len(in_tree_children) == 1 - _check_carry_type('while_loop body', body_fun, new_init_val, body_tree, - body_jaxpr.out_avals) + cond_jaxpr, body_jaxpr, body_out_avals = _create_jaxpr(init_aval) + if len(body_out_avals) != len(init_aval): + _check_carry_type('while_loop body', body_fun, init_aval, body_out_avals) + assert False, "shouldn't get here" + + init_val, changed = init_val.map3( + _promote_weak_typed_input, + init_aval, body_out_avals).unzip2() + if any(changed): + init_aval = init_val.map(core.get_aval) + cond_jaxpr, body_jaxpr, body_out_avals = _create_jaxpr(init_aval) + + cond_jaxpr, cond_consts = pe.separate_consts(cond_jaxpr) + body_jaxpr, body_consts = pe.separate_consts(body_jaxpr) + _check_carry_type('while_loop body', body_fun, init_aval, body_out_avals) + + if not all(not v.aval.has_qdd or v.initial_qdd == v.final_qdd for v in + body_jaxpr.jaxpr.invars): + raise TypeError("type-changing mutations not allowed in while_loop body") joined_effects = core.join_effects(cond_jaxpr.effects, body_jaxpr.effects) disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(joined_effects) if disallowed_effects: raise NotImplementedError( f'Effects not supported in `while`: {disallowed_effects}') + + # If the body forwards an input carry to an output carry, *and* it's not used + # by the cond fun, it can be moved to be a body const. Doing so can lead to + # efficiency wins: if e.g. we vmap the loop with a batched predicate, we batch + # the carry too, but not the body consts. + body_fwd = pe._jaxpr_forwarding(body_jaxpr.jaxpr) + carry_nofwd = [len(body_consts) + i != f for i, f in enumerate(body_fwd)] + cond_jaxpr_, keep_cond = pe.dce_jaxpr( + cond_jaxpr.jaxpr, [True], [True] * len(cond_consts) + carry_nofwd) + _, keep_cond_carry = split_list(keep_cond, [len(cond_consts)]) + move_to_const = _map(operator.not_, keep_cond_carry) + + init_vals = list(init_val) # type: ignore + if any(move_to_const): + cond_jaxpr = pe.close_jaxpr(cond_jaxpr_) + body_jaxpr = pe.prune_closed_jaxpr_outputs( + body_jaxpr, [not m for m in move_to_const]) + body_jaxpr = pe.move_binders_to_front( + body_jaxpr, [False] * len(body_consts) + move_to_const) + init_vals, new_body_consts = partition_list(move_to_const, init_vals) + body_consts = [*new_body_consts, *body_consts] + outs = while_p.bind(*cond_consts, *body_consts, *init_vals, cond_nconsts=len(cond_consts), cond_jaxpr=cond_jaxpr, body_nconsts=len(body_consts), body_jaxpr=body_jaxpr) - return tree_unflatten(body_tree, outs) + + if any(move_to_const): + outs = pe.merge_lists(move_to_const, outs, new_body_consts) + + return body_out_avals.update(outs).unflatten() def _join_while_effects(body_jaxpr, cond_jaxpr, body_nconsts, cond_nconsts @@ -1438,7 +1572,29 @@ def _join_while_effects(body_jaxpr, cond_jaxpr, body_nconsts, cond_nconsts def _while_loop_abstract_eval(*avals, cond_jaxpr, body_jaxpr, body_nconsts, cond_nconsts): - del avals + cond_consts_avals, body_consts_avals, in_avals = \ + util.split_list(avals, [cond_nconsts, body_nconsts]) + + if len(cond_jaxpr.in_avals) != len(cond_consts_avals) + len(in_avals): + raise core.JaxprTypeError( + f"while_loop {len(cond_jaxpr.in_avals)=} but {len(cond_consts_avals) + len(in_avals)=}") + if len(body_jaxpr.in_avals) != len(body_consts_avals) + len(in_avals): + raise core.JaxprTypeError( + f"while_loop {len(body_jaxpr.in_avals)=} but {len(body_consts_avals) + len(in_avals)=}") + # TODO(mattjj): check body carry type + # TODO(mattjj): make these typecompat checks work with bints + # if not all(_map(core.typecompat, [*cond_consts_avals, *in_avals], cond_jaxpr.in_avals)): # type: ignore + # cond_avals = [*cond_consts_avals, *in_avals] + # a1, a2 = next((a1, a2) for a1, a2 in zip(cond_avals, cond_jaxpr.in_avals) + # if not core.typecompat(a1, a2)) + # raise core.JaxprTypeError(f"while_loop cond function input type error: {a1} != {a2}") + # if not all(_map(core.typecompat, [*body_consts_avals, *in_avals], body_jaxpr.in_avals)): # type: ignore + # body_avals = [*body_consts_avals, *in_avals] + # a1, a2 = next((a1, a2) for a1, a2 in zip(body_avals, body_jaxpr.in_avals) + # if not core.typecompat(a1, a2)) + # raise core.JaxprTypeError(f"while_loop body function input type error: {a1} != {a2}") + + joined_effects = _join_while_effects(body_jaxpr, cond_jaxpr, body_nconsts, cond_nconsts) disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(joined_effects) @@ -1572,7 +1728,7 @@ def _while_loop_jvp(primals, tangents, cond_nconsts, cond_jaxpr, body_nconsts, cond_jaxpr.jaxpr.eqns, cond_jaxpr.jaxpr.effects, augmented_debug) - cond_jaxpr_augmented = core.ClosedJaxpr(cond_jaxpr_augmented, cond_jaxpr.consts) + cond_jaxpr_augmented = ClosedJaxpr(cond_jaxpr_augmented, cond_jaxpr.consts) out = while_p.bind( *(cconst + bconst + bconst_dot + init + init_dot), @@ -1678,8 +1834,8 @@ def _while_partial_eval_custom(saveable, unks_in, inst_in, eqn): else: assert False, "Fixpoint not reached" assert not num_res - body_jaxpr_known = core.ClosedJaxpr(jaxpr_known_, body_jaxpr.consts) - del jaxpr_known_, carry_uk_out, num_res + body_jaxpr_known = ClosedJaxpr(jaxpr_known_, body_jaxpr.consts) + del jaxpr_known_, carry_uk_out, num_res, unks_in # Instantiate all inputs (b/c jaxpr_staged will take all inputs). new_inst = [x for x, inst in zip(eqn.invars, inst_in) @@ -1697,10 +1853,11 @@ def _while_partial_eval_custom(saveable, unks_in, inst_in, eqn): # we handle it: if it is unknown, stage out the whole cond function. if cond_uk: return None, eqn, [True] * len(carry_uk), [True] * len(carry_uk), new_inst - cond_jaxpr_known = core.ClosedJaxpr(cond_jaxpr_known_, cond_jaxpr.consts) + cond_jaxpr_known = ClosedJaxpr(cond_jaxpr_known_, cond_jaxpr.consts) del cond_uk # Build the known eqn. + unks_in = [*cond_consts_uk, *body_consts_uk, *carry_uk] # fixpoint carry_uk ins_known, _ = partition_list(unks_in, eqn.invars) out_binders_known, _ = partition_list(carry_uk, eqn.outvars) params_known = dict(cond_jaxpr=cond_jaxpr_known, body_jaxpr=body_jaxpr_known, @@ -1711,6 +1868,11 @@ def _while_partial_eval_custom(saveable, unks_in, inst_in, eqn): eqn_known = pe.new_jaxpr_eqn(ins_known, out_binders_known, while_p, params_known, effects_known, eqn.source_info, eqn.ctx) + # Typecheck known eqn. + _while_loop_abstract_eval( + *[v.aval for v in eqn_known.invars], cond_jaxpr=cond_jaxpr_known, + body_jaxpr=body_jaxpr_known, body_nconsts=params_known['body_nconsts'], + cond_nconsts=params_known['cond_nconsts']) # Staged eqn is same as input eqn. eqn_staged = eqn @@ -1763,18 +1925,19 @@ def cond(args): pred = lax.reduce_or(pred, tuple(range(len(pred_aval.shape)))) return pred def body(args): - return tuple(core.eval_jaxpr(body_jaxpr.jaxpr, body_jaxpr.consts, *args)) + return core.eval_jaxpr(body_jaxpr.jaxpr, body_jaxpr.consts, *args) def new_cond(pred_args): - pred, _ = pred_args + pred, *_ = pred_args return pred def new_body(pred_args): - _, args = pred_args - args = body(args) - pred = cond(args) - return pred, args + _, cond_consts, body_consts, carry = pred_args + carry = body((*body_consts, *carry)) + pred = cond((*cond_consts, *carry)) + return pred, cond_consts, body_consts, carry def fun(*args): - pred = cond(args) - _, out = while_loop(new_cond, new_body, (pred, args)) + cond_consts, body_consts, carry = split_list(args, [cond_nconsts, body_nconsts]) + pred = cond((*cond_consts, *carry)) + *_, out = while_loop(new_cond, new_body, (pred, cond_consts, body_consts, carry)) return out return mlir.lower_fun(fun)(ctx, *args) @@ -1798,11 +1961,11 @@ def fun(*args): cond_block.arguments[i] for i in range(len(flat_loop_carry_types)) ] cond_args = mlir.unflatten_ir_values_like_types(flat_cond_args, loop_carry_types) - # Remove tokens from cond args - cond_args = cond_args[num_tokens:] + cond_args = cond_args[num_tokens:] # Remove tokens from cond args x, _, z = util.split_list(cond_args, [cond_nconsts, body_nconsts]) cond_consts = [ - mlir.ir_constant(xla.canonicalize_dtype(x)) for x in cond_jaxpr.consts + mlir.ir_constant(x, aval=var.aval) + for x, var in zip(cond_jaxpr.consts, cond_jaxpr.jaxpr.constvars) ] cond_name_stack = name_stack.extend('cond') (pred,), _ = mlir.jaxpr_subcomp( @@ -1813,16 +1976,21 @@ def fun(*args): cond_consts, *(x + z), dim_var_values=ctx.dim_var_values, + const_lowering=ctx.const_lowering, ) if batched: pred_ctx = mlir.LoweringRuleContext( module_context=ctx.module_context, name_stack=cond_name_stack, + traceback=ctx.traceback, primitive=None, avals_in=[pred_aval], - avals_out=[pred_aval.update(shape=())], + avals_out=[pred_aval.update( + shape=(), sharding=pred_aval.sharding.update(spec=()))], tokens_in=mlir.TokenSet(), - tokens_out=None) + tokens_out=None, + dim_var_values=ctx.dim_var_values, + const_lowering=ctx.const_lowering) pred, = lax._unary_reduce_lower( hlo.OrOp, lambda dtype: np.array(False, dtype), @@ -1843,26 +2011,32 @@ def fun(*args): tokens_in = mlir.TokenSet(zip(body_effects, token_args)) x, y, z = util.split_list(body_args, [cond_nconsts, body_nconsts]) body_name_stack = name_stack.extend('body') - body_consts = [mlir.ir_constant(xla.canonicalize_dtype(x)) - for x in body_jaxpr.consts] + body_consts = [ + mlir.ir_constant(x, aval=var.aval) + for x, var in zip(body_jaxpr.consts, body_jaxpr.jaxpr.constvars) + ] new_z, tokens_out = mlir.jaxpr_subcomp( ctx.module_context, body_jaxpr.jaxpr, body_name_stack, - tokens_in, body_consts, *(y + z), dim_var_values=ctx.dim_var_values) + tokens_in, body_consts, *(y + z), + dim_var_values=ctx.dim_var_values, const_lowering=ctx.const_lowering) out_tokens = [tokens_out.get(eff) for eff in body_effects] if batched: body_pred_name_stack = name_stack.extend('body_pred') - cond_consts = [mlir.ir_constant(xla.canonicalize_dtype(x)) - for x in cond_jaxpr.consts] + cond_consts = [ + mlir.ir_constant(x, aval=var.aval) + for x, var in zip(cond_jaxpr.consts, cond_jaxpr.jaxpr.constvars) + ] (body_pred,), _ = mlir.jaxpr_subcomp( ctx.module_context, cond_jaxpr.jaxpr, body_pred_name_stack, mlir.TokenSet(), cond_consts, *(x + z), - dim_var_values=ctx.dim_var_values) + dim_var_values=ctx.dim_var_values, const_lowering=ctx.const_lowering) new_z = _map( partial(_pred_bcast_select_hlo, ctx, pred_aval, body_pred), new_z, z, body_jaxpr.out_avals) - hlo.return_([*mlir.flatten_ir_values(out_tokens), *mlir.flatten_ir_values(x), *mlir.flatten_ir_values(y), - *mlir.flatten_ir_values(new_z)]) + hlo.return_([*mlir.flatten_ir_values(out_tokens), + *mlir.flatten_ir_values(x), *mlir.flatten_ir_values(y), + *mlir.flatten_ir_values(new_z)]) outputs = mlir.unflatten_ir_values_like_types(while_op.results, loop_carry_types) tokens, _, _, z = util.split_list(outputs, [num_tokens, cond_nconsts, body_nconsts]) @@ -1881,38 +2055,70 @@ def _while_typecheck(_, *in_atoms, cond_jaxpr, body_jaxpr, cond_nconsts, f'Effects not supported in `while`: {disallowed_effects}') return body_jaxpr.out_avals, joined_effects -def _while_partial_discharge_rule(should_discharge, in_avals, out_avals, *args, cond_jaxpr, body_jaxpr, - cond_nconsts, body_nconsts): - # TODO(sharadmv): enable supporting state effects in the cond - if any(isinstance(eff, state.RefEffect) for eff in cond_jaxpr.effects): - raise NotImplementedError +def _while_partial_discharge_rule(should_discharge, in_avals, out_avals, *args, + cond_jaxpr, body_jaxpr, cond_nconsts, body_nconsts): + del out_avals cond_consts_discharge, body_consts_discharge, carry_discharge = split_list( should_discharge, [cond_nconsts, body_nconsts]) - - if any(cond_consts_discharge): - raise NotImplementedError cond_consts, body_consts, carry = split_list(args, [cond_nconsts, body_nconsts]) cond_consts_avals, body_consts_avals, carry_avals = split_list(in_avals, [cond_nconsts, body_nconsts]) - # There shouldn't be any `Ref`s in the `cond` (because of our check above). - assert not any(isinstance(aval, state.AbstractRef) for aval in cond_consts_avals) - is_ref = [ + + # Check if the same Ref is written to in both cond and body. + cond_write_ids = {id(cond_consts_avals[effect.input_index]) + for effect in cond_jaxpr.effects if isinstance(effect, state.WriteEffect)} + cond_has_writes = len(cond_write_ids) > 0 + body_write_ids = {id(body_consts_avals[effect.input_index]) + for effect in body_jaxpr.effects if isinstance(effect, state.WriteEffect)} + write_to_both_ids = cond_write_ids & body_write_ids + if write_to_both_ids: + raise NotImplementedError( + "Cannot write to the same ref in both cond and body of while loop.") + + cond_is_ref = [ + isinstance(aval, state.AbstractRef) and should + for aval, should in zip(cond_consts_avals, cond_consts_discharge) + ] + remaining_cond_consts, cond_refs = partition_list(cond_is_ref, cond_consts) + remaining_cond_const_avals, cond_ref_avals = partition_list(cond_is_ref, + cond_consts_avals) + num_cond_refs = sum(cond_is_ref) + num_remaining_cond_consts = cond_nconsts - num_cond_refs + body_is_ref = [ isinstance(aval, state.AbstractRef) and should for aval, should in zip(body_consts_avals, body_consts_discharge) ] - remaining_body_consts, refs = partition_list(is_ref, body_consts) - remaining_body_const_avals, ref_avals = partition_list(is_ref, + remaining_body_consts, body_refs = partition_list(body_is_ref, body_consts) + remaining_body_const_avals, body_ref_avals = partition_list(body_is_ref, body_consts_avals) - num_refs = sum(is_ref) - num_remaining_consts = body_nconsts - num_refs + num_body_refs = sum(body_is_ref) + num_remaining_body_consts = body_nconsts - num_body_refs + num_out_body_consts = num_remaining_body_consts + if cond_has_writes: + # If the cond has writes, we need to add the cond consts into the body + # consts since we need to evaluate the cond condition in the body. + remaining_body_consts = [*remaining_cond_consts, *remaining_body_consts] + remaining_body_const_avals = [*remaining_cond_const_avals, + *remaining_body_const_avals] + num_remaining_body_consts += num_remaining_cond_consts + num_carry = len(in_avals) - body_nconsts - cond_nconsts body_jaxpr, body_jaxpr_consts = body_jaxpr.jaxpr, body_jaxpr.consts - cond_jaxpr, cond_jaxpr_consts = cond_jaxpr.jaxpr, cond_jaxpr.consts if body_jaxpr_consts: raise NotImplementedError("Body jaxpr has consts. If you see this error, " "please open an issue at " "https://github.com/jax-ml/jax/issues") + cond_jaxpr, cond_jaxpr_consts = cond_jaxpr.jaxpr, cond_jaxpr.consts + if cond_jaxpr_consts: + raise NotImplementedError("Cond jaxpr has consts. If you see this error, " + "please open an issue at " + "https://github.com/jax-ml/jax/issues") + (discharged_cond_jaxpr, discharged_cond_consts + ) = state_discharge.discharge_state( + cond_jaxpr, (), + should_discharge=[*cond_consts_discharge, *carry_discharge]) + if discharged_cond_consts: raise NotImplementedError # body_jaxpr has the signature (*body_consts, *carry) -> carry. # Some of these body_consts are actually `Ref`s so when we discharge # them, they also turn into outputs, effectively turning those consts into @@ -1924,16 +2130,38 @@ def _while_partial_discharge_rule(should_discharge, in_avals, out_avals, *args, if discharged_consts: raise NotImplementedError def new_body(*consts_refs_carry): - consts, refs, carry = split_list( - consts_refs_carry, [num_remaining_consts, num_refs]) - consts_and_refs = merge_lists(is_ref, consts, refs) - carry_refs = core.eval_jaxpr(discharged_body_jaxpr, (), *consts_and_refs, + consts, body_refs, cond_refs, carry = split_list( + consts_refs_carry, + [num_remaining_body_consts, num_body_refs, num_cond_refs]) + if cond_has_writes: + # We run the cond jaxpr in the body so that Refs that are updated + # in the cond jaxpr are persisted via the carry. + cond_consts, body_consts = split_list(consts, [num_remaining_cond_consts]) + cond_consts_and_refs = merge_lists(cond_is_ref, cond_consts, cond_refs) + cond_carry_refs = core.eval_jaxpr(discharged_cond_jaxpr, (), + *cond_consts_and_refs, + *carry) + # Note: in order to handle the same Ref being updated in both the cond + # and body, we would need to interleave the updated cond_carry_refs into + # body_refs here. + # Currently we disallow this so we don't need to handle it. + _, cond_refs_out = split_list(cond_carry_refs, [1]) + assert len(cond_refs_out) == len(cond_refs) + else: + body_consts = consts + cond_refs_out = cond_refs + + body_consts_and_refs = merge_lists(body_is_ref, body_consts, body_refs) + body_carry_refs = core.eval_jaxpr(discharged_body_jaxpr, (), + *body_consts_and_refs, *carry) - carry, refs_out = split_list(carry_refs, [num_carry]) - return [*refs_out, *carry] - new_body_jaxpr, _, new_body_consts, () = pe.trace_to_jaxpr_dynamic( + carry, body_refs_out = split_list(body_carry_refs, [num_carry]) + return [*body_refs_out, *cond_refs_out, *carry] + new_body_jaxpr, _, new_body_consts = pe.trace_to_jaxpr_dynamic( lu.wrap_init(new_body, debug_info=discharged_body_jaxpr.debug_info), - [*remaining_body_const_avals, *[a.inner_aval for a in ref_avals], + [*remaining_body_const_avals, + *[a.inner_aval for a in body_ref_avals], + *[a.inner_aval for a in cond_ref_avals], *carry_avals]) if new_body_consts: raise NotImplementedError @@ -1941,25 +2169,41 @@ def new_body(*consts_refs_carry): # deal with them (i.e. ignore them) in the `cond`, so we need to rewrite the # cond_jaxpr as well. def new_cond(*consts_refs_carry): - consts, refs, carry = split_list( - consts_refs_carry, [cond_nconsts, num_refs]) - del refs # We don't use them here! - return core.eval_jaxpr(cond_jaxpr, cond_jaxpr_consts, *consts, *carry) - new_cond_jaxpr, _, new_cond_consts, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(new_cond, debug_info=cond_jaxpr.debug_info), - [*cond_consts_avals, *[a.inner_aval for a in ref_avals], *carry_avals]) + consts, body_refs, cond_refs, carry = split_list( + consts_refs_carry, [num_remaining_cond_consts, num_body_refs, num_cond_refs]) + # We don't use them here! + del body_refs + cond_consts_and_refs = merge_lists(cond_is_ref, consts, cond_refs) + results = core.eval_jaxpr( + discharged_cond_jaxpr, (), *cond_consts_and_refs, *carry) + predicate, refs_out = split_list(results, [1]) + assert len(refs_out) == len(cond_refs) + return predicate + + new_cond_jaxpr, _, new_cond_consts = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(new_cond, debug_info=cond_jaxpr.debug_info.with_unknown_names()), + [*remaining_cond_const_avals, + *[a.inner_aval for a in body_ref_avals], + *[a.inner_aval for a in cond_ref_avals], + *carry_avals]) if new_cond_consts: raise NotImplementedError - out = while_p.bind(*cond_consts, *remaining_body_consts, *refs, *carry, - body_jaxpr=core.ClosedJaxpr(new_body_jaxpr, ()), - cond_jaxpr=core.ClosedJaxpr(new_cond_jaxpr, ()), - body_nconsts=num_remaining_consts, - cond_nconsts=cond_nconsts) - refs_out, carry_out = split_list(out, [num_refs]) - updated_body_consts = merge_lists(is_ref, [None] * num_remaining_consts, - refs_out) + out = while_p.bind(*remaining_cond_consts, *remaining_body_consts, + *body_refs, *cond_refs, *carry, + body_jaxpr=ClosedJaxpr(new_body_jaxpr, ()), + cond_jaxpr=ClosedJaxpr(new_cond_jaxpr, ()), + body_nconsts=num_remaining_body_consts, + cond_nconsts=num_remaining_cond_consts) + body_refs_out, cond_refs_out, carry_out = split_list( + out, [num_body_refs, num_cond_refs]) + updated_cond_consts = merge_lists(cond_is_ref, + [None] * num_remaining_cond_consts, + cond_refs_out) + updated_body_consts = merge_lists(body_is_ref, + [None] * num_out_body_consts, + body_refs_out) invals_out = [ - *[None] * cond_nconsts, + *updated_cond_consts, *updated_body_consts, *[None] * num_carry] return invals_out, carry_out @@ -1971,14 +2215,62 @@ def new_cond(*consts_refs_carry): while_p.def_effectful_abstract_eval(_while_loop_abstract_eval) ad.primitive_jvps[while_p] = _while_loop_jvp pe.custom_partial_eval_rules[while_p] = _while_partial_eval -xla.register_initial_style_primitive(while_p) +pxla.register_initial_style_primitive(while_p) ad.primitive_transposes[while_p] = _while_transpose_error batching.fancy_primitive_batchers[while_p] = _while_loop_batching_rule pe.partial_eval_jaxpr_custom_rules[while_p] = _while_partial_eval_custom -mlir.register_lowering(while_p, _while_lowering) core.custom_typechecks[while_p] = _while_typecheck +mlir.register_lowering(while_p, _while_lowering) state_discharge.register_partial_discharge_rule(while_p)(_while_partial_discharge_rule) +def _while_is_high(*_, cond_jaxpr, body_jaxpr, **__): + return cond_jaxpr.is_high or body_jaxpr.is_high +while_p.is_high = _while_is_high # type: ignore + +def _while_to_lojax(*hi_args, cond_jaxpr, body_jaxpr, cond_nconsts, body_nconsts): + if any(a.has_qdd for a in cond_jaxpr.in_avals[:cond_nconsts]): + raise NotImplementedError # TODO(mattjj,dougalm) + assert not any(a.has_qdd for a in cond_jaxpr.in_avals[cond_nconsts:]) + + hi_cconsts, hi_bconsts, hi_carry = split_list(hi_args, [cond_nconsts, body_nconsts]) + + # move qdd binders and corresponding hi_args from consts slots to carry slots + to_move = [t.has_qdd for t in body_jaxpr.in_aval_qdds[:body_nconsts]] + body_jaxpr = pe.move_invars_right(body_jaxpr, to_move) + hi_bconsts, hi_bconsts_qdd = partition_list(to_move, hi_bconsts) + hi_carry = [*hi_bconsts_qdd, *hi_carry] + body_nconsts -= sum(to_move) + cond_jaxpr = _insert_binders(cond_jaxpr, cond_nconsts, hi_bconsts_qdd) + del hi_bconsts_qdd + + # collect input values + loval = lambda a, x: a.read_loval(x) if a.has_qdd else a.lower_val(x) + lovals = lambda avals, xs: [lo for a, x in zip(avals, xs) for lo in loval(a, x)] + lo_cconsts = lovals(cond_jaxpr.in_aval_qdds[:cond_nconsts], hi_cconsts) + lo_bconsts = lovals(body_jaxpr.in_aval_qdds[:body_nconsts], hi_bconsts) + lo_carry = lovals(body_jaxpr.in_aval_qdds[body_nconsts:], hi_carry) + + # expand cond_nconsts and body_nconsts according to lo types + cond_nconsts = sum(len(typeof(x).lo_ty()) for x in hi_cconsts) + body_nconsts = sum(len(typeof(x).lo_ty()) for x in hi_bconsts) + + # lower jaxprs and bind + all_outs = while_p.bind(*lo_cconsts, *lo_bconsts, *lo_carry, + cond_jaxpr=pe.lower_jaxpr(cond_jaxpr), + body_jaxpr=pe.lower_jaxpr(body_jaxpr), + cond_nconsts=cond_nconsts, body_nconsts=body_nconsts) + out_mut, lo_outs = split_list(all_outs, [pe.num_himuts_out(body_jaxpr)]) + pe.apply_himut(body_jaxpr, [*hi_bconsts, *hi_carry], out_mut) + return pe.raise_lo_outs(body_jaxpr.out_avals, lo_outs) +while_p.to_lojax = _while_to_lojax # type: ignore + +def _insert_binders(jaxpr, n_after, vals): + avals = _map(typeof, vals) + invars = [core.Var(lo_ty) for a, x in zip(avals, vals) for lo_ty in + (a.lo_ty_qdd(cur_qdd(x)) if a.has_qdd else a.lo_ty())] + invars = jaxpr.jaxpr.invars[:n_after] + invars + jaxpr.jaxpr.invars[n_after:] + return jaxpr.replace(jaxpr=jaxpr.jaxpr.replace(invars=invars)) + def _pred_bcast_select_hlo(ctx, pred_aval: core.ShapedArray, pred: ir.Value, x: mlir.IrValues, @@ -1994,7 +2286,7 @@ def _pred_bcast_select_hlo(ctx, pred_aval.shape, x_y_aval) x_y_aval = core.physical_aval(x_y_aval) bcast_pred = mlir.broadcast_in_dim( - ctx, pred, core.DShapedArray(x_y_aval.shape, np.dtype(np.bool_)), + ctx, pred, core.ShapedArray(x_y_aval.shape, np.dtype(np.bool_)), broadcast_dimensions=list(range(len(pred_aval.shape)))) return hlo.SelectOp(bcast_pred, x, y).results @@ -2005,23 +2297,34 @@ def _fori_cond_fun(loop_carry): return lax.lt(i, upper) @weakref_lru_cache -def _fori_body_fun(body_fun): +def _fori_body_fun(body_fun: Callable, body_fun_dbg: core.DebugInfo) -> Callable: body_fun_ref = weakref.ref(body_fun) def while_body_fun(loop_carry): i, upper, x = loop_carry return lax.add(i, lax._const(i, 1)), upper, body_fun_ref()(i, x) + if body_fun_dbg.arg_names is not None: + arg_names = (body_fun_dbg.arg_names[0], + "", # upper, + * body_fun_dbg.arg_names[1:]) + else: + arg_names = None + api_util.save_wrapped_fun_debug_info( + while_body_fun, + body_fun_dbg._replace(arg_names=arg_names)) return while_body_fun @weakref_lru_cache -def _fori_scan_body_fun(body_fun): +def _fori_scan_body_fun(body_fun: Callable, body_fun_dbg: core.DebugInfo) -> Callable: body_fun_ref = weakref.ref(body_fun) def scanned_fun(loop_carry, _): i, x = loop_carry return (i + 1, body_fun_ref()(i, x)), None + api_util.save_wrapped_fun_debug_info( + scanned_fun, body_fun_dbg._replace(result_paths=None)) return scanned_fun -@api_boundary +@partial(api_boundary, repro_api_name="jax.lax.fori_loop") def fori_loop(lower, upper, body_fun, init_val, *, unroll: int | bool | None = None): """Loop from ``lower`` to ``upper`` by reduction to :func:`jax.lax.while_loop`. @@ -2071,7 +2374,7 @@ def fori_loop(lower, upper, body_fun, init_val): unroll: An optional integer or boolean that determines how much to unroll the loop. If an integer is provided, it determines how many unrolled loop iterations to run within a single rolled iteration of the loop. If a - boolean is provided, it will determine if the loop is competely unrolled + boolean is provided, it will determine if the loop is completely unrolled (i.e. `unroll=True`) or left completely unrolled (i.e. `unroll=False`). This argument is only applicable if the loop bounds are statically known. @@ -2084,8 +2387,8 @@ def fori_loop(lower, upper, body_fun, init_val): raise TypeError("lax.fori_loop: body_fun argument should be callable.") # TODO(phawkins): perhaps do more type checking here, better error messages. - lower_dtype = dtypes.canonicalize_dtype(lax.dtype(lower)) - upper_dtype = dtypes.canonicalize_dtype(lax.dtype(upper)) + lower_dtype = lax.dtype(lower) + upper_dtype = lax.dtype(upper) if lower_dtype == upper_dtype: dtype = lower_dtype else: @@ -2119,6 +2422,9 @@ def fori_loop(lower, upper, body_fun, init_val): else: use_scan = False + body_fun_dbg = api_util.debug_info("fori_loop", body_fun, + (0, init_val), {}) + if use_scan: if unroll is None: unroll = False @@ -2126,9 +2432,7 @@ def fori_loop(lower, upper, body_fun, init_val): if config.disable_jit.value and length == 0: # non-jit implementation of scan does not support length=0 return init_val - - scan_body = _fori_scan_body_fun(body_fun) - api_util.save_wrapped_fun_sourceinfo(scan_body, body_fun) + scan_body = _fori_scan_body_fun(body_fun, body_fun_dbg) (_, result), _ = scan( scan_body, (lower_, init_val), @@ -2137,7 +2441,7 @@ def fori_loop(lower, upper, body_fun, init_val): unroll=unroll, ) return result - if unroll is not None: + if unroll is not None and unroll is not False and unroll != 1: raise ValueError("Can only use `unroll` in `fori_loop` if the loop bounds " "are statically known.") @@ -2145,37 +2449,63 @@ def fori_loop(lower, upper, body_fun, init_val): lower = lax.convert_element_type(lower, dtype) # type: ignore if upper_dtype != dtype: upper = lax.convert_element_type(upper, dtype) # type: ignore - while_body_fun = _fori_body_fun(body_fun) - api_util.save_wrapped_fun_sourceinfo(while_body_fun, body_fun) + while_body_fun = _fori_body_fun(body_fun, body_fun_dbg) _, _, result = while_loop(_fori_cond_fun, while_body_fun, (lower, upper, init_val)) return result ### map and miscellaneous rules -def _batch_and_remainder(x, batch_size: int): - leaves, treedef = tree_flatten(x) - - scan_leaves = [] - remainder_leaves = [] +def _scan_leaf(leaf, batch_elems, num_batches, batch_size): + def f(l): + return l[:batch_elems].reshape(num_batches, batch_size, *leaf.shape[1:]) - for leaf in leaves: - num_batches, _ = divmod(leaf.shape[0], batch_size) - total_batch_elems = num_batches * batch_size - scan_leaves.append(leaf[:total_batch_elems].reshape(num_batches, batch_size, *leaf.shape[1:])) - remainder_leaves.append(leaf[total_batch_elems:]) + aval = core.typeof(leaf) + if aval.sharding.spec[0] is not None: + raise ValueError( + '0th dimension of leaf passed to `jax.lax.map` should be replicated.' + f' Got {aval.str_short(True, True)}') + + out_s = aval.sharding.update(spec=P(None, None, *aval.sharding.spec[1:])) + out_s = canonicalize_sharding(out_s, 'lax.map') + if out_s is not None and out_s.mesh._any_axis_explicit: + return auto_axes(f, out_sharding=out_s, axes=out_s.mesh.explicit_axes)(leaf) + return f(leaf) + +def _remainder_leaf(leaf, batch_elems): + def f(l): + return l[batch_elems:] + sharding = canonicalize_sharding(core.typeof(leaf).sharding, 'lax.map') + if sharding is not None and sharding.mesh._any_axis_explicit: + return auto_axes( + f, out_sharding=sharding, axes=sharding.mesh.explicit_axes + )(leaf) + return f(leaf) - scan_tree = treedef.unflatten(scan_leaves) - remainder_tree = treedef.unflatten(remainder_leaves) - return scan_tree, remainder_tree +def _batch_and_remainder(x, batch_size: int): + leaves, treedef = tree_flatten(x) + if not leaves: + return x, None + if batch_size == 0: + num_batches, remainder = 0, leaves[0].shape[0] + else: + num_batches, remainder = divmod(leaves[0].shape[0], batch_size) + batch_elems = num_batches * batch_size + if num_batches == 0: + remainder_leaves = [_remainder_leaf(leaf, batch_elems) for leaf in leaves] + return None, treedef.unflatten(remainder_leaves) + elif remainder: + scan_leaves, remainder_leaves = unzip2( # type: ignore + [(_scan_leaf(leaf, batch_elems, num_batches, batch_size), + _remainder_leaf(leaf, batch_elems)) for leaf in leaves]) + return treedef.unflatten(scan_leaves), treedef.unflatten(remainder_leaves) + else: + scan_leaves = tuple(_scan_leaf(leaf, batch_elems, num_batches, batch_size) + for leaf in leaves) + return treedef.unflatten(scan_leaves), None @api_boundary -def map( - f, - xs, - *, - batch_size: int | None = None, -): +def map(f, xs, *, batch_size: int | None = None): """Map a function over leading array axes. Like Python's builtin map, except inputs and outputs are in the form of @@ -2200,6 +2530,8 @@ def map(f, xs): divisible by the batch size, the remainder is processed in a separate ``vmap`` and concatenated to the result. + ``batch_size=0`` is equivalent to applying a ``vmap``. That is, it uses a full batch. + >>> x = jnp.ones((10, 3, 4)) >>> def f(x): ... print('inner shape:', x.shape) @@ -2226,28 +2558,42 @@ def map(f, xs): if batch_size is not None: scan_xs, remainder_xs = _batch_and_remainder(xs, batch_size) g = lambda _, x: ((), api.vmap(f)(x)) - _, scan_ys = scan(g, (), scan_xs) - remainder_ys = api.vmap(f)(remainder_xs) + if scan_xs is not None: + _, scan_ys = scan(g, (), scan_xs) + else: + scan_ys = None + flatten = lambda x: x.reshape(-1, *x.shape[2:]) - ys = tree_map( - lambda x, y: lax.concatenate([flatten(x), y], dimension=0), scan_ys, remainder_ys, - ) + if scan_ys is None: + ys = api.vmap(f)(remainder_xs) + elif remainder_xs is not None: + remainder_ys = api.vmap(f)(remainder_xs) + ys = tree_map( + lambda x, y: lax.concatenate([flatten(x), y], dimension=0), scan_ys, + remainder_ys) + else: + ys = tree_map(flatten, scan_ys) else: g = lambda _, x: ((), f(x)) _, ys = scan(g, (), xs) return ys -def _rng_bit_generator_batching_rule(batched_args, batch_dims, *, shape, dtype, algorithm): +def _rng_bit_generator_batching_rule(batched_args, batch_dims, *, shape, dtype, + algorithm, out_sharding): keys, = batched_args bd, = batch_dims if bd is batching.not_mapped: - return lax.rng_bit_generator_p.bind(keys, shape=shape, dtype=dtype, - algorithm=algorithm), (None, None) + return lax.rng_bit_generator_p.bind( + keys, shape=shape, dtype=dtype, algorithm=algorithm, + out_sharding=out_sharding), (None, None) keys = batching.moveaxis(keys, bd, 0) batch_size = keys.shape[0] + out_s = (out_sharding.update(spec=(keys.aval.sharding.spec[0], *out_sharding.spec)) + if out_sharding is not None else None) key = keys[0] - new_key, bits = lax.rng_bit_generator_p.bind(key, shape=(batch_size, *shape), - dtype=dtype, algorithm=algorithm) + new_key, bits = lax.rng_bit_generator_p.bind( + key, shape=(batch_size, *shape), dtype=dtype, algorithm=algorithm, + out_sharding=out_s) new_keys = slicing.dynamic_update_index_in_dim(keys, new_key, 0, axis=0) return (new_keys, bits), (0, 0) @@ -2288,6 +2634,9 @@ def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0): of ``elems`` along ``axis``. For example, given ``elems = [a, b, c, ...]``, the result would be ``[a, fn(a, b), fn(fn(a, b), c), ...]``. + If ``elems = [..., x, y, z]`` and ``reverse`` is true, the result is + ``[..., f(f(z, y), x), f(z, y), z]``. + Example 1: partial sums of an array of numbers: >>> lax.associative_scan(jnp.add, jnp.arange(0, 4)) @@ -2482,21 +2831,23 @@ def _cumred_dtype_rule(name, operand, *args, **kw): if not dtypes.issubdtype(operand.dtype, np.number): raise TypeError("{} does not accept dtype {}. Accepted dtypes are subtypes " "of number.".format(name, np.dtype(operand.dtype).name)) - return dtypes.canonicalize_dtype(operand.dtype) + return operand.dtype def _cumulative_reduction_primitive(name, reduce_fn, reduce_window_fn): reducer_p = lax.standard_primitive( _cumred_shape_rule, partial(_cumred_dtype_rule, name), - name, sharding_rule=_cumred_sharding_rule) + name, sharding_rule=_cumred_sharding_rule, + vma_rule=partial(core.standard_vma_rule, name)) batching.primitive_batchers[reducer_p] = partial(_cumred_batch_rule, reducer_p) def register_lowering(fn, platform=None): mlir.register_lowering( reducer_p, - mlir.cache_lowering(mlir.lower_fun(fn, multiple_results=False)), - platform=platform) + mlir.lower_fun(fn, multiple_results=False), + platform=platform, + inline=False) # For jax-metal, until reduce_window legalization is better supported. register_lowering(partial(associative_scan, reduce_fn), 'METAL') diff --git a/jax/_src/lax/control_flow/solves.py b/jax/_src/lax/control_flow/solves.py index acfcfd7ff3d3..9de1ec5b08e8 100644 --- a/jax/_src/lax/control_flow/solves.py +++ b/jax/_src/lax/control_flow/solves.py @@ -15,10 +15,9 @@ import collections from functools import partial import operator -from typing import Any, Callable +from typing import Any +from collections.abc import Callable -from jax.tree_util import (tree_flatten, treedef_children, tree_leaves, - tree_unflatten, treedef_tuple) from jax._src import ad_util from jax._src import api from jax._src import api_util @@ -28,14 +27,15 @@ from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir -from jax._src.interpreters import xla +from jax._src.interpreters import partial_eval as pe +from jax._src.interpreters import pxla from jax._src.traceback_util import api_boundary +from jax._src.tree_util import tree_leaves, FlatTree from jax._src.util import split_list, safe_map import numpy as np from jax._src.lax.control_flow.common import ( _check_tree, - _initial_style_jaxpr, ) _map = safe_map @@ -91,42 +91,44 @@ def custom_root(f: Callable, The result of calling solve(f, initial_guess) with gradients defined via implicit differentiation assuming ``f(solve(f, initial_guess)) == 0``. """ - guess_flat, in_args_tree = tree_flatten((initial_guess,)) - guess_avals = tuple(_map(core.get_aval, guess_flat)) + guess_flat = FlatTree.flatten(initial_guess) + guess_avals = guess_flat.map(core.get_aval) f_debug = api_util.debug_info("custom_root", f, (initial_guess,), {}) - f_jaxpr, f_consts, out_tree = _initial_style_jaxpr( - f, in_args_tree, guess_avals, f_debug) + args_avals = FlatTree.pack(((guess_avals,),{})) + f_jaxpr, out_avals = pe.trace_to_jaxpr(f, args_avals, f_debug) + f_jaxpr, f_consts = pe.separate_consts(f_jaxpr) - in_tree, = treedef_children(in_args_tree) - _check_tree("f", "initial_guess", out_tree, in_tree, False) + _check_tree("f", "initial_guess", out_avals.tree, guess_avals.tree, False) solve_debug = api_util.debug_info("custom_root solve", solve, (f, initial_guess), {}, static_argnums=(0,)) - solve_jaxpr, solve_consts, solution_tree = _initial_style_jaxpr( - partial(solve, f), in_args_tree, guess_avals, solve_debug) - _check_tree("solve", "initial_guess", solution_tree, in_tree, has_aux) + solve_jaxpr, solution_avals = pe.trace_to_jaxpr( + partial(solve, f), args_avals, solve_debug) + solve_jaxpr, solve_consts = pe.separate_consts(solve_jaxpr) + _check_tree("solve", "initial_guess", solution_avals.tree, guess_flat.tree, has_aux) def linearize_and_solve(x, b): unchecked_zeros, f_jvp = api.linearize(f, x) return tangent_solve(f_jvp, b) - tangent_solve_debug = api_util.debug_info("custom_root tangent_solve", - tangent_solve, - (f, initial_guess), {}, - static_argnums=(0,)) - l_and_s_jaxpr, l_and_s_consts, out_tree = _initial_style_jaxpr( - linearize_and_solve, treedef_tuple((in_tree,) * 2), guess_avals * 2, - tangent_solve_debug) - _check_tree("tangent_solve", "x", out_tree, in_tree, False) + linearize_and_solve_dbg = api_util.debug_info("custom_root tangent_solve", + tangent_solve, (initial_guess, initial_guess), {}) + + + linearize_and_solve_avals = FlatTree.pack(((guess_avals, guess_avals), {})) + l_and_s_jaxpr, out_avals = pe.trace_to_jaxpr( + linearize_and_solve, linearize_and_solve_avals, linearize_and_solve_dbg) + l_and_s_jaxpr, l_and_s_consts = pe.separate_consts(l_and_s_jaxpr) + _check_tree("tangent_solve", "x", out_avals.tree, guess_flat.tree, False) all_consts = [f_consts, solve_consts, l_and_s_consts] const_lengths = _RootTuple(*_map(len, all_consts)) jaxprs = _RootTuple(f_jaxpr, solve_jaxpr, l_and_s_jaxpr) solution_flat = _custom_root( - const_lengths, jaxprs, *(_flatten(all_consts) + guess_flat)) - return tree_unflatten(solution_tree, solution_flat) + const_lengths, jaxprs, *_flatten(all_consts), *guess_flat) + return solution_avals.update(solution_flat).unflatten() @partial(custom_derivatives.custom_jvp, nondiff_argnums=(0, 1)) @@ -196,15 +198,15 @@ def _flatten(args): def _check_shapes(func_name, expected_name, actual, expected): - actual_shapes = _map(np.shape, tree_leaves(actual)) - expected_shapes = _map(np.shape, tree_leaves(expected)) + actual_shapes = _map(np.shape, actual) + expected_shapes = _map(np.shape, expected) if actual_shapes != expected_shapes: raise ValueError( f"{func_name}() output shapes must match {expected_name}, " f"got {actual_shapes} and {expected_shapes}") -@api_boundary +@partial(api_boundary, repro_api_name="jax.custom_linear_solve") def custom_linear_solve( matvec: Callable, b: Any, @@ -248,20 +250,19 @@ def custom_linear_solve( if transpose_solve is None and symmetric: transpose_solve = solve - b_flat, in_args_tree = tree_flatten((b,)) - b_avals = tuple(_map(core.get_aval, b_flat)) - - tree, = treedef_children(in_args_tree) + b_flat = FlatTree.flatten(b) + b_avals = b_flat.map(core.get_aval) + tree = b_flat.tree def _shape_checked(fun, name, has_aux): def f(x): y = fun(x) - _check_shapes(name, "b", y, b_flat) + _check_shapes(name, "b", tree_leaves(y), b_flat) return y def f_aux(x): y, aux = fun(x) - _check_shapes(name, "b", y, b_flat) + _check_shapes(name, "b", tree_leaves(y), b_flat) return y, aux return f_aux if has_aux else f @@ -269,18 +270,21 @@ def f_aux(x): matvec_debug = api_util.debug_info("custom_linear_solve", matvec, (b,), {}) # no auxiliary data assumed for matvec - matvec_jaxpr, matvec_consts, out_tree = _initial_style_jaxpr( - _shape_checked(matvec, "matvec", False), in_args_tree, b_avals, + args_avals = FlatTree.pack(((b_avals,),{})) + matvec_jaxpr, out_avals = pe.trace_to_jaxpr( + _shape_checked(matvec, "matvec", False), args_avals, matvec_debug) - _check_tree("matvec", "b", out_tree, tree, False) + matvec_jaxpr, matvec_consts = pe.separate_consts(matvec_jaxpr) + _check_tree("matvec", "b", out_avals.tree, tree, False) solve_debug = api_util.debug_info("custom_linear_solve solve", solve, (matvec, b), {}, static_argnums=(0,)) - solve_jaxpr, solve_consts, out_tree = _initial_style_jaxpr( - _shape_checked(partial(solve, matvec), "solve", has_aux), in_args_tree, b_avals, + solve_jaxpr, out_avals = pe.trace_to_jaxpr( + _shape_checked(partial(solve, matvec), "solve", has_aux), args_avals, solve_debug) - _check_tree("solve", "b", out_tree, tree, has_aux) + solve_jaxpr, solve_consts = pe.separate_consts(solve_jaxpr) + _check_tree("solve", "b", out_avals.tree, tree, has_aux) if transpose_solve is None: vecmat_jaxpr = tr_solve_jaxpr = None @@ -295,38 +299,40 @@ def f_aux(x): vecmat_consts = matvec_consts else: vecmat = _transpose_one_output(matvec, b) - vecmat_jaxpr, vecmat_consts, out_tree = _initial_style_jaxpr( - vecmat, in_args_tree, b_avals, transpose_solve_debug) - assert out_tree == tree + vecmat_jaxpr, out_avals = pe.trace_to_jaxpr( + vecmat, args_avals, transpose_solve_debug) + vecmat_jaxpr, vecmat_consts = pe.separate_consts(vecmat_jaxpr) + assert out_avals.tree == tree - tr_solve_jaxpr, tr_solve_consts, out_tree = _initial_style_jaxpr( + tr_solve_jaxpr, out_avals = pe.trace_to_jaxpr( _shape_checked(partial(transpose_solve, vecmat), "transpose_solve", has_aux), - in_args_tree, b_avals, transpose_solve_debug) - _check_tree("transpose_solve", "b", out_tree, tree, has_aux) + args_avals, transpose_solve_debug) + tr_solve_jaxpr, tr_solve_consts = pe.separate_consts(tr_solve_jaxpr) + _check_tree("transpose_solve", "b", out_avals.tree, tree, has_aux) all_consts = [matvec_consts, vecmat_consts, solve_consts, tr_solve_consts] const_lengths = _LinearSolveTuple(*_map(len, all_consts)) jaxprs = _LinearSolveTuple( matvec_jaxpr, vecmat_jaxpr, solve_jaxpr, tr_solve_jaxpr) - out_flat = linear_solve_p.bind( - *(_flatten(all_consts) + b_flat), - const_lengths=const_lengths, jaxprs=jaxprs) + args = _flatten(all_consts) + list(b_flat) + args = core.standard_insert_pvary(*args) + out_flat = linear_solve_p.bind(*args, const_lengths=const_lengths, jaxprs=jaxprs) - return tree_unflatten(out_tree, out_flat) + return out_avals.update(out_flat).unflatten() def _linear_solve_abstract_eval(*args, const_lengths, jaxprs): args_to_raise = args[sum(const_lengths):] - # raise aux_args to shaped arrays as well if present # number of aux args is the difference in out_avals # of solve and matvec (since they map to the same vector space) - num_aux = len(jaxprs.solve.out_avals) - len(jaxprs.matvec.out_avals) if num_aux > 0: args_to_raise += tuple(jaxprs.solve.out_avals[-num_aux:]) - return args_to_raise, jaxprs.solve.effects + out_vma = core.standard_vma_rule('linear_solve', *args_to_raise) + return (tuple(a.update(vma=out_vma) for a in args_to_raise), + jaxprs.solve.effects) def _custom_linear_solve_impl(*args, const_lengths, jaxprs): @@ -394,13 +400,17 @@ def _linear_solve_transpose_rule(cotangent, *primals, const_lengths, jaxprs): 'differentiation of custom_linear_solve') params, b = _split_linear_solve_args(primals, const_lengths) - # split off symbolic zeros in the cotangent if present - x_cotangent, _ = split_list(cotangent, [len(b)]) - assert all(ad.is_undefined_primal(x) for x in b) + if any(ad.is_undefined_primal(x) for xs in params for x in xs): + raise NotImplementedError("open an issue at https://github.com/google/jax !!") + assert all(ad.is_undefined_primal(x) for x in b) # TODO(mattjj): why? + x_cotangent, other_cotangents = split_list(cotangent, [len(b)]) + if any(type(ct) is not ad_util.Zero for ct in other_cotangents): + raise NotImplementedError("open an issue at https://github.com/google/jax !!") + del other_cotangents + x_cotangent_ = _map(ad_util.instantiate, x_cotangent) cotangent_b_full = linear_solve_p.bind( - *(_flatten(params.transpose()) + x_cotangent), + *_flatten(params.transpose()), *x_cotangent_, const_lengths=const_lengths.transpose(), jaxprs=jaxprs.transpose()) - # drop aux values in cotangent computation cotangent_b, _ = split_list(cotangent_b_full, [len(b)]) return [None] * sum(const_lengths) + cotangent_b @@ -488,7 +498,7 @@ def _linear_solve_batching_rule(axis_data, args, dims, const_lengths, jaxprs): linear_solve_p.def_impl(_custom_linear_solve_impl) linear_solve_p.def_effectful_abstract_eval(_linear_solve_abstract_eval) ad.primitive_jvps[linear_solve_p] = _custom_linear_solve_jvp -xla.register_initial_style_primitive(linear_solve_p) +pxla.register_initial_style_primitive(linear_solve_p) mlir.register_lowering( linear_solve_p, mlir.lower_fun(_custom_linear_solve_impl, multiple_results=True)) diff --git a/jax/_src/lax/convolution.py b/jax/_src/lax/convolution.py index 290d027cc6bc..7324263d01c7 100644 --- a/jax/_src/lax/convolution.py +++ b/jax/_src/lax/convolution.py @@ -27,7 +27,10 @@ from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir +from jax._src.sharding_impls import ( + NamedSharding, PartitionSpec as P, canonicalize_sharding) from jax._src.lax import lax +from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.typing import Array, DTypeLike @@ -53,6 +56,8 @@ class ConvDimensionNumbers(NamedTuple): None, ] +# TODO(yashkatariya): conv_general_dilated should take `out_sharding` argument +# similar to `dot_general` def conv_general_dilated( lhs: Array, rhs: Array, window_strides: Sequence[int], padding: str | Sequence[tuple[int, int]], @@ -61,11 +66,12 @@ def conv_general_dilated( dimension_numbers: ConvGeneralDilatedDimensionNumbers = None, feature_group_count: int = 1, batch_group_count: int = 1, precision: lax.PrecisionLike = None, - preferred_element_type: DTypeLike | None = None) -> Array: + preferred_element_type: DTypeLike | None = None, + out_sharding: NamedSharding | P | None = None) -> Array: """General n-dimensional convolution operator, with optional dilation. Wraps XLA's `Conv - `_ + `_ operator. Args: @@ -101,6 +107,14 @@ def conv_general_dilated( preferred_element_type: Optional. Either ``None``, which means the default accumulation type for the input types, or a datatype, indicating to accumulate results to and return a result with that datatype. + out_sharding: Optional. Specifies how the output array should be sharded + across devices in multi-device computation. Can be a + :class:`~jax.sharding.NamedSharding`, a :class:`~jax.sharding.PartitionSpec` + (``P``), or ``None`` (default). When specified, the output will be sharded + according to the given sharding specification. Primarily used in explicit + sharding mode. + See the `explicit sharding tutorial `_ + for more details. Returns: An array containing the convolution result. @@ -130,6 +144,7 @@ def conv_general_dilated( 'NCHW')`` (for a 2D convolution). """ dnums = conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers) + out_sharding = canonicalize_sharding(out_sharding, 'dot_general') if lhs_dilation is None: lhs_dilation = (1,) * (lhs.ndim - 2) elif isinstance(padding, str) and not len(lhs_dilation) == lhs_dilation.count(1): @@ -156,8 +171,12 @@ def conv_general_dilated( f"sequence of (low, high) pairs, got {padding}") from e preferred_element_type = ( - None if preferred_element_type is None else - dtypes.canonicalize_dtype(np.dtype(preferred_element_type))) + None if preferred_element_type is None + else dtypes.check_and_canonicalize_user_dtype( + preferred_element_type, "conv_general_dilated" + ) + ) + lhs, rhs = core.standard_insert_pvary(lhs, rhs) return conv_general_dilated_p.bind( lhs, rhs, window_strides=tuple(window_strides), padding=tuple(padding), lhs_dilation=tuple(lhs_dilation), rhs_dilation=tuple(rhs_dilation), @@ -165,7 +184,8 @@ def conv_general_dilated( feature_group_count=feature_group_count, batch_group_count=batch_group_count, precision=lax.canonicalize_precision(precision), - preferred_element_type=preferred_element_type) + preferred_element_type=preferred_element_type, + out_sharding=out_sharding) ### convenience wrappers around traceables @@ -243,7 +263,7 @@ def _conv_transpose_padding(k, s, padding): Args: k: int: kernel dimension. s: int: dimension stride value. - padding: 'same' or 'valid' padding mode for original forward conv. + padding: tuple of ints or 'same' or 'valid' padding mode for original forward conv. Returns: 2-tuple: ints: before and after padding for transposed convolution. @@ -257,12 +277,15 @@ def _conv_transpose_padding(k, s, padding): elif padding == 'VALID': pad_len = k + s - 2 + max(k - s, 0) pad_a = k - 1 + elif isinstance(padding, tuple): + pads = tuple(k - p - 1 for p in padding) + pad_a = pads[0] + pad_len = sum(pads) else: - raise ValueError('Padding mode must be `SAME` or `VALID`.') + raise ValueError(f"Invalid padding mode: {padding}") pad_b = pad_len - pad_a return pad_a, pad_b - def _flip_axes(x, axes): """Flip ndarray 'x' along each axis specified in axes tuple.""" for axis in axes: @@ -276,19 +299,32 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int], dimension_numbers: ConvGeneralDilatedDimensionNumbers = None, transpose_kernel: bool = False, precision: lax.PrecisionLike = None, - preferred_element_type: DTypeLike | None = None) -> Array: + preferred_element_type: DTypeLike | None = None, + use_consistent_padding: bool = False) -> Array: """Convenience wrapper for calculating the N-d convolution "transpose". This function directly calculates a fractionally strided conv rather than indirectly calculating the gradient (transpose) of a forward convolution. + Notes: + TensorFlow/Keras Compatibility: By default, JAX does NOT reverse the + kernel's spatial dimensions. This differs from TensorFlow's "Conv2DTranspose" + and similar frameworks, which flip spatial axes and swap input/output channels. + + To match TensorFlow/Keras behavior, set "transpose_kernel=True" . + Args: lhs: a rank `n+2` dimensional input array. rhs: a rank `n+2` dimensional array of kernel weights. strides: sequence of `n` integers, sets fractional stride. - padding: 'SAME', 'VALID' will set as transpose of corresponding forward - conv, or a sequence of `n` integer 2-tuples describing before-and-after - padding for each `n` spatial dimension. + padding: 'SAME', 'VALID', or a sequence of `n` integer 2-tuples describing before-and-after + padding for each spatial dimension. If `use_consistent_padding=True`, this is interpreted + as the padding of the corresponding forward conv, which effectively adds + `dilation * (kernel_size - 1) - padding` zero padding to each side + of the input so that `conv_transpose` becomes the gradient of `conv` when given the same padding + and stride arguments. This is the behavior in PyTorch. If `use_consistent_padding=False`, + the 'SAME' and 'VALID' strings are interpreted as the padding of the corresponding forward conv, + but integer tuples are interpreted as padding for the transposed convolution. rhs_dilation: `None`, or a sequence of `n` integers, giving the dilation factor to apply in each spatial dimension of `rhs`. RHS dilation is also known as atrous convolution. @@ -306,7 +342,10 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int], preferred_element_type: Optional. Either ``None``, which means the default accumulation type for the input types, or a datatype, indicating to accumulate results to and return a result with that datatype. - + use_consistent_padding : In older versions of jax, the `padding` argument was interpreted differently + depending on whether it was a string or a sequence of integers. Strings were interpreted as padding + for the forward convolution, while integers were interpreted as padding for the transposed convolution. + If `use_consistent_padding` is False, this inconsistent behavior is preserved for backwards compatibility. Returns: Transposed N-d convolution, with output padding following the conventions of keras.layers.Conv2DTranspose. @@ -330,15 +369,16 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int], k_shape = np.take(rhs.shape, dn.rhs_spec) k_sdims = k_shape[2:] # Calculate correct output shape given padding and strides. + if rhs_dilation is None: + rhs_dilation = (1,) * (rhs.ndim - 2) pads: str | Sequence[tuple[int, int]] - if isinstance(padding, str) and padding in {'SAME', 'VALID'}: - if rhs_dilation is None: - rhs_dilation = (1,) * (rhs.ndim - 2) + if use_consistent_padding or (isinstance(padding, str) and padding in {'SAME', 'VALID'}): effective_k_size = map(lambda k, r: core.dilate_dim(k, r), k_sdims, rhs_dilation) - pads = [_conv_transpose_padding(k, s, padding) - for k,s in zip(effective_k_size, strides)] + replicated_padding = [padding] * len(strides) if isinstance(padding, str) else padding + pads = tuple(_conv_transpose_padding(k, s, p) + for k,s,p in zip(effective_k_size, strides, replicated_padding)) else: - pads = padding + pads = padding if transpose_kernel: # flip spatial dims and swap input / output channel axes rhs = _flip_axes(rhs, np.array(dn.rhs_spec)[2:]) @@ -414,10 +454,32 @@ def _conv_general_dilated_shape_rule( return tuple(np.take(out_trans, np.argsort(out_perm))) +def _conv_general_dilated_sharding_rule( + lhs: core.ShapedArray, rhs: core.ShapedArray, *, window_strides, padding, + lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count, + batch_group_count, out_sharding, **unused_kwargs): + if out_sharding is not None: + assert isinstance(out_sharding, NamedSharding) + return out_sharding + # Only allow if rhs is fully replicated and lhs's feature dim is not sharded + if ((rhs.sharding.mesh.empty or rhs.sharding.is_fully_replicated) and + lhs.sharding.spec[dimension_numbers.lhs_spec[1]] is None): + out_shape = _conv_general_dilated_shape_rule( + lhs, rhs, window_strides=window_strides, padding=padding, + lhs_dilation=lhs_dilation, rhs_dilation=rhs_dilation, + dimension_numbers=dimension_numbers, + feature_group_count=feature_group_count, + batch_group_count=batch_group_count) + return lax.slicing._get_sharding_for_varying_out_shape( + out_shape, lhs, "conv_general_dilated") + raise core.ShardingTypeError( + "Please specify the output sharding via `out_sharding` parameter of" + " `conv_general_dilated`") + def _conv_general_dilated_dtype_rule( lhs, rhs, *, window_strides, padding, lhs_dilation, rhs_dilation, dimension_numbers, preferred_element_type, **unused_kwargs): - result_dtype = lax.naryop_dtype_rule(lax._input_dtype, [lax._any, lax._any], + result_dtype = lax.naryop_dtype_rule(lax.input_dtype, [lax._any, lax._any], 'conv_general_dilated', lhs, rhs) if preferred_element_type is None: return result_dtype @@ -457,7 +519,7 @@ def _conv_general_dilated_dtype_rule( def _conv_general_dilated_transpose_lhs( g, lhs, rhs, *, window_strides, padding, lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count, batch_group_count, - precision, preferred_element_type): + precision, preferred_element_type, out_sharding): assert type(dimension_numbers) is ConvDimensionNumbers assert batch_group_count == 1 or feature_group_count == 1 rhs_shape = rhs.shape @@ -486,7 +548,8 @@ def _conv_general_dilated_transpose_lhs( dimension_numbers=trans_dimension_numbers, feature_group_count=feature_group_count, batch_group_count=1, precision=precision, - preferred_element_type=preferred_element_type) + preferred_element_type=preferred_element_type, + out_sharding=lhs.aval.sharding) if batch_group_count > 1: out = _reshape_axis_out_of(lhs_spec[1], batch_group_count, out) out = _reshape_axis_into(lhs_spec[1], lhs_spec[0], out) @@ -495,7 +558,7 @@ def _conv_general_dilated_transpose_lhs( def _conv_general_dilated_transpose_rhs( g, lhs, rhs, *, window_strides, padding, lhs_dilation, rhs_dilation, dimension_numbers: ConvDimensionNumbers, feature_group_count: int, - batch_group_count: int, precision, preferred_element_type): + batch_group_count: int, precision, preferred_element_type, out_sharding): assert type(dimension_numbers) is ConvDimensionNumbers if np.size(g) == 0: # Avoids forming degenerate convolutions where the RHS has spatial size 0. @@ -522,13 +585,17 @@ def _conv_general_dilated_transpose_rhs( dimension_numbers=trans_dimension_numbers, feature_group_count=feature_group_count, batch_group_count=batch_group_count, precision=precision, - preferred_element_type=preferred_element_type) + preferred_element_type=preferred_element_type, + out_sharding=rhs.aval.sharding) def _conv_general_dilated_batch_rule( - batched_args, batch_dims, *, window_strides, padding, + axis_data, + batched_args, + batch_dims, + *, window_strides, padding, lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count, batch_group_count, precision, - preferred_element_type, **unused_kwargs): + preferred_element_type, out_sharding, **unused_kwargs): assert batch_group_count == 1 or feature_group_count == 1 lhs, rhs = batched_args lhs_bdim, rhs_bdim = batch_dims @@ -551,11 +618,25 @@ def _conv_general_dilated_batch_rule( rhs_dilation=rhs_dilation, dimension_numbers=dimension_numbers, feature_group_count=feature_group_count, batch_group_count=batch_group_count) + if out_sharding is not None: + out_sharding = batching.get_sharding_for_vmap(axis_data, out_sharding, 0) return lax.full( (0,) + shape, 0, dtype=lhs.dtype if preferred_element_type is None - else preferred_element_type), 0 - + else preferred_element_type, + sharding=out_sharding), 0 + + def get_out_sharding(axis): + if out_sharding is None: + return None + val = axis_data.explicit_mesh_axis + if not val: + return out_sharding + if out_sharding.spec[axis] is not None: + # Batch dim must not already be sharded. + raise NotImplementedError + return NamedSharding(out_sharding.mesh, + P(*util.tuple_update(out_sharding.spec, axis, val))) if lhs_bdim is not None and rhs_bdim is not None: assert lhs.shape[lhs_bdim] == rhs.shape[rhs_bdim] @@ -570,7 +651,8 @@ def _conv_general_dilated_batch_rule( new_lhs, new_rhs, window_strides, padding, lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count=feature_group_count, batch_group_count=batch_group_count, precision=precision, - preferred_element_type=preferred_element_type) + preferred_element_type=preferred_element_type, + out_sharding=get_out_sharding(out_spec[1])) out = _reshape_axis_out_of(out_spec[1], lhs.shape[lhs_bdim], out) return out, out_spec[1] @@ -580,7 +662,8 @@ def _conv_general_dilated_batch_rule( out = conv_general_dilated(new_lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count, precision=precision, - preferred_element_type=preferred_element_type) + preferred_element_type=preferred_element_type, + out_sharding=get_out_sharding(out_spec[0])) out = _reshape_axis_out_of(out_spec[0], lhs.shape[lhs_bdim], out) return out, out_spec[0] else: @@ -594,7 +677,8 @@ def _conv_general_dilated_batch_rule( lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count, batch_group_count, precision=precision, - preferred_element_type=preferred_element_type) + preferred_element_type=preferred_element_type, + out_sharding=get_out_sharding(out_spec[0])) out = _reshape_axis_out_of(out_spec[0], lhs.shape[lhs_bdim], out) return out, out_spec[0] @@ -605,10 +689,13 @@ def _conv_general_dilated_batch_rule( lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count, batch_group_count, precision=precision, - preferred_element_type=preferred_element_type) + preferred_element_type=preferred_element_type, + out_sharding=get_out_sharding(out_spec[1])) out = _reshape_axis_out_of(out_spec[1], rhs.shape[rhs_bdim], out) return out, out_spec[1] else: + if out_sharding is not None: + raise NotImplementedError # groups need to be outermost, so we need to factor them out of the # rhs output feature dim, then factor the batch dim into the remaining rhs # output feature dim, then put groups back in. We do something @@ -625,7 +712,8 @@ def _conv_general_dilated_batch_rule( lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count, batch_group_count, precision=precision, - preferred_element_type=preferred_element_type) + preferred_element_type=preferred_element_type, + out_sharding=get_out_sharding(out_spec[1])) out = _reshape_axis_out_of(out_spec[1], group_count, out) out = _reshape_axis_out_of(out_spec[1] + 1, rhs.shape[rhs_bdim], out) out = _reshape_axis_into(out_spec[1], out_spec[1] + 1, out) @@ -633,13 +721,16 @@ def _conv_general_dilated_batch_rule( conv_general_dilated_p = lax.standard_primitive( _conv_general_dilated_shape_rule, _conv_general_dilated_dtype_rule, - 'conv_general_dilated') + 'conv_general_dilated', + sharding_rule=_conv_general_dilated_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'conv_general_dilated')) ad.defbilinear(conv_general_dilated_p, _conv_general_dilated_transpose_lhs, _conv_general_dilated_transpose_rhs) -batching.primitive_batchers[conv_general_dilated_p] = \ - _conv_general_dilated_batch_rule + +batching.fancy_primitive_batchers[conv_general_dilated_p] = _conv_general_dilated_batch_rule +batching.skippable_batchers[conv_general_dilated_p] = lambda _: () def _complex_mul(mul, x, y): # We use a trick for complex multiplication sometimes attributed to Gauss @@ -665,7 +756,7 @@ def _complex_mul(mul, x, y): def _conv_general_dilated_lower( ctx, lhs, rhs, *, window_strides, padding, lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count, - batch_group_count, precision, preferred_element_type, + batch_group_count, precision, preferred_element_type, out_sharding, expand_complex_convolutions=False, **unused_kwargs): lhs_aval, rhs_aval = ctx.avals_in aval_out, = ctx.avals_out @@ -684,7 +775,8 @@ def _conv_general_dilated_lower( rhs_dilation=rhs_dilation, dimension_numbers=dimension_numbers, feature_group_count=feature_group_count, batch_group_count=batch_group_count, precision=precision, - preferred_element_type=preferred_element_type)), + preferred_element_type=preferred_element_type, + out_sharding=out_sharding)), multiple_results=False) return complex_conv(ctx, lhs, rhs) @@ -702,7 +794,7 @@ def _conv_general_dilated_lower( num_spatial_dims = len(rhs_spec) - 2 if len(padding) == 0: padding = np.zeros((0, 2), dtype=np.int64) - window_reversal = mlir.dense_bool_array([False] * num_spatial_dims) + window_reversal = ir.DenseBoolArrayAttr.get([False] * num_spatial_dims) if (not core.is_constant_shape(window_strides) or not core.is_constant_shape(lhs_dilation) or not core.is_constant_shape(rhs_dilation) or @@ -711,21 +803,18 @@ def _conv_general_dilated_lower( # TODO(https://github.com/openxla/stablehlo/issues/1268) raise NotImplementedError("Convolutions with non-static strides, dilation, feature_group_count, or batch_group_count") if all(core.is_constant_shape(p) for p in padding): - return [ - hlo.convolution( - mlir.aval_to_ir_type(aval_out), - lhs, - rhs, - dimension_numbers=dnums, - feature_group_count=mlir.i64_attr(feature_group_count), - batch_group_count=mlir.i64_attr(batch_group_count), - window_strides=mlir.dense_int_array(window_strides), - padding=mlir.dense_int_elements(padding), - lhs_dilation=mlir.dense_int_array(lhs_dilation), - rhs_dilation=mlir.dense_int_array(rhs_dilation), - window_reversal=window_reversal, - precision_config=lax.precision_attr(precision)) - ] + out = hlo.convolution( + mlir.aval_to_ir_type(aval_out), lhs, rhs, + dimension_numbers=dnums, + feature_group_count=mlir.i64_attr(feature_group_count), + batch_group_count=mlir.i64_attr(batch_group_count), + window_strides=mlir.dense_int_array(window_strides), + padding=mlir.dense_int_elements(padding), + lhs_dilation=mlir.dense_int_array(lhs_dilation), + rhs_dilation=mlir.dense_int_array(rhs_dilation), + window_reversal=window_reversal, + precision_config=lax.precision_attr(precision)) + return [mlir.lower_with_sharding_in_types(ctx, out, aval_out)] else: # d_padding will be an array i32[N, 2] with pad_lo and pad_hi for each # spatial dimension. diff --git a/jax/_src/lax/fft.py b/jax/_src/lax/fft.py index 6ca1a4abd193..5d0b4b14fd0d 100644 --- a/jax/_src/lax/fft.py +++ b/jax/_src/lax/fft.py @@ -21,8 +21,6 @@ import numpy as np -from jax import lax - from jax._src import dispatch from jax._src import dtypes from jax._src.api import jit, linear_transpose, ShapeDtypeStruct @@ -30,6 +28,7 @@ from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir +from jax._src.lax import lax from jax._src.lib.mlir.dialects import hlo __all__ = [ @@ -65,7 +64,7 @@ def _str_to_fft_type(s: str) -> FftType: else: raise ValueError(f"Unknown FFT type '{s}'") -@partial(jit, static_argnums=(1, 2)) +@jit(static_argnums=(1, 2)) def fft(x, fft_type: FftType | str, fft_lengths: Sequence[int]): if isinstance(fft_type, str): typ = _str_to_fft_type(fft_type) @@ -124,7 +123,7 @@ def fft_abstract_eval(x, fft_type, fft_lengths): f"be equal to fft_lengths {fft_lengths}") shape = x.shape dtype = x.dtype - return x.update(shape=shape, dtype=dtype) + return x.update(shape=shape, dtype=dtype, vma=x.vma) def _fft_lowering(ctx, x, *, fft_type, fft_lengths): if not is_constant_shape(fft_lengths): @@ -141,7 +140,7 @@ def _naive_rfft(x, fft_lengths): n = fft_lengths[-1] return y[..., : n//2 + 1] -@partial(jit, static_argnums=1) +@jit(static_argnums=1) def _rfft_transpose(t, fft_lengths): # The transpose of RFFT can't be expressed only in terms of irfft. Instead of # manually building up larger twiddle matrices (which would increase the diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 86a75ada63ad..492f2db014e5 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -28,10 +28,6 @@ import numpy as np -from jax import tree_util -from jax.sharding import Sharding -from jax.tree_util import tree_map - from jax._src import ad_util from jax._src import api from jax._src import api_util @@ -41,38 +37,41 @@ from jax._src import dispatch from jax._src import dtypes from jax._src import effects +from jax._src import literals from jax._src import linear_util as lu from jax._src import pjit from jax._src import pretty_printer as pp from jax._src import source_info_util from jax._src import state +from jax._src import tree_util from jax._src import util from jax._src.abstract_arrays import array_types -from jax._src.core import (Primitive, UnshapedArray, ShapedArray, - abstract_token, canonicalize_shape) +from jax._src.core import (Primitive, ShapedArray, abstract_token, + canonicalize_shape) from jax._src.errors import UnexpectedTracerError +from jax._src.hashable_array import HashableArray from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import pxla -from jax._src.interpreters import xla -from jax._src.interpreters.batching import RaggedAxis from jax._src.lax import slicing -from jax._src import mesh as mesh_lib +from jax._src.lax import utils as lax_utils +from jax._src.mesh import get_abstract_mesh, get_concrete_mesh from jax._src.lax.utils import ( - _input_dtype, dtype_to_string, standard_abstract_eval, - standard_multi_result_abstract_eval, standard_primitive) + input_dtype, dtype_to_string, standard_multi_result_abstract_eval, + standard_primitive) +from jax._src.core import typeof from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import chlo from jax._src.lib.mlir.dialects import hlo -from jax._src.lib import xla_extension_version -from jax._src.sharding_impls import (PmapSharding, NamedSharding, - PartitionSpec as P, canonicalize_sharding) -from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DTypeLike, Shape -from jax._src.util import (NumpyComplexWarning, cache, canonicalize_axis, +from jax._src.sharding import Sharding +from jax._src.sharding_impls import ( + PmapSharding, NamedSharding, PartitionSpec as P, canonicalize_sharding, flatten_spec) +from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DType, DTypeLike, Shape +from jax._src.util import (cache, canonicalize_axis, safe_map, safe_zip, split_list, weakref_lru_cache, - foreach) + foreach, tuple_insert) _max = builtins.max _min = builtins.min @@ -103,11 +102,7 @@ def _check_static_shape(shape: Shape): raise TypeError(msg) assert shapes - if config.dynamic_shapes.value: - # pass dynamic shapes through unchecked - return - else: - foreach(_check_static_shape, shapes) + foreach(_check_static_shape, shapes) def _try_broadcast_shapes(*shapes: tuple[int, ...], name: str) -> tuple[int, ...]: """ @@ -141,7 +136,9 @@ def asarray(x: ArrayLike) -> Array: if isinstance(x, Array): return x elif isinstance(x, (bool, np.ndarray, np.generic)): - return _convert_element_type(x, weak_type=False) # pytype: disable=bad-return-type + return _convert_element_type(x, weak_type=False) + elif isinstance(x, literals.TypedNdArray): + return _convert_element_type(x, weak_type=x.weak_type) elif isinstance(x, (int, float, builtins.complex)): return _convert_element_type(dtypes.coerce_to_array(x), weak_type=True) else: @@ -238,43 +235,11 @@ def broadcast_shardings(*avals): new_spec = P(*(None,) * (ndim - a.ndim) + a.sharding.spec) new_shape = (1,) * (ndim - a.ndim) + a.shape aval_list.append(a.update(shape=new_shape, - sharding=a.sharding.with_spec(new_spec))) + sharding=a.sharding.update(spec=new_spec))) return broadcasting_sharding_rule('broadcast_shardings', *aval_list) def _identity(x, **_): return x -def _extract_tracers_dyn_shape( - shape: Sequence[int | core.Tracer] - ) -> tuple[list[core.Tracer], list[int | None]]: - # Given a sequence representing a shape, pull out Tracers, replacing with None - if config.dynamic_shapes.value: - # We must gate this behavior under a flag because otherwise the errors - # raised are different (and have worse source provenance information). - dyn_shape = [d for d in shape if isinstance(d, core.Tracer)] - static_shape = [None if isinstance(d, core.Tracer) else d for d in shape] - return dyn_shape, static_shape - else: - return [], list(shape) # type: ignore - -def _merge_dyn_shape( - static_shape: Sequence[int | None], - dyn_shape: Sequence[Any], - ) -> tuple[int | mlir.Value | core.Tracer, ...]: - # Replace Nones in static_shape with elements of dyn_shape, in order - dyn_shape_it = iter(dyn_shape) - shape = tuple(next(dyn_shape_it) if d is None else d for d in static_shape) - assert next(dyn_shape_it, None) is None - return shape - -def _dyn_shape_staging_rule(trace, prim, out_aval, *args, **params): - source_info = source_info_util.current() - out_tracer = pe.DynamicJaxprTracer(trace, out_aval, source_info) - eqn = pe.new_jaxpr_eqn([trace.getvar(x) for x in args], - [trace.makevar(out_tracer)], - prim, params, core.no_effects, source_info) - trace.frame.add_eqn(eqn) - return out_tracer - ### traceables @@ -369,6 +334,7 @@ def nextafter(x1: ArrayLike, x2: ArrayLike) -> Array: For the smallest usable (i.e. normal) float, use ``tiny`` of ``jnp.finfo``. """ + x1, x2 = core.standard_insert_pvary(x1, x2) return nextafter_p.bind(x1, x2) @export @@ -483,14 +449,41 @@ def is_finite(x: ArrayLike) -> Array: """ return is_finite_p.bind(x) +class Tolerance: + """Specify the tolerances used for computing unary functions. + + Maximum two tolerances can be specified: (atol and rtol) or (atol and ulps). + """ + + def __init__(self, atol: float = 0.0, rtol: float = 0.0, ulps: int = 0): + if atol < 0.0 or rtol < 0.0 or ulps < 0.0: + raise ValueError('Tolerances must be non-negative.') + if atol == 0.0 and rtol == 0.0 and ulps == 0: + raise ValueError('At least one of atol, rtol, or ulps must be set.') + + self.atol = atol + self.rtol = rtol + self.ulps = ulps + + +class AccuracyMode(enum.Enum): + HIGHEST = 1 + DEFAULT = 2 + @export -def exp(x: ArrayLike) -> Array: +def exp(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise exponential: :math:`e^x`. This function lowers directly to the `stablehlo.exponential`_ operation. Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -502,10 +495,10 @@ def exp(x: ArrayLike) -> Array: .. _stablehlo.exponential: https://openxla.org/stablehlo/spec#exponential """ - return exp_p.bind(x) + return exp_p.bind(x, accuracy=accuracy) -@export -def exp2(x: ArrayLike) -> Array: + +def exp2(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise base-2 exponential: :math:`2^x`. This function is implemented in terms of the `stablehlo.exponential`_ @@ -513,6 +506,12 @@ def exp2(x: ArrayLike) -> Array: Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -525,10 +524,10 @@ def exp2(x: ArrayLike) -> Array: .. _stablehlo.exponential: https://openxla.org/stablehlo/spec#exponential .. _stablehlo.multiply: https://openxla.org/stablehlo/spec#multiply """ - return exp2_p.bind(x) + return exp2_p.bind(x, accuracy=accuracy) @export -def expm1(x: ArrayLike) -> Array: +def expm1(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise :math:`e^{x} - 1`. This function lowers directly to the `stablehlo.exponential_minus_one`_ @@ -537,6 +536,12 @@ def expm1(x: ArrayLike) -> Array: Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -548,16 +553,22 @@ def expm1(x: ArrayLike) -> Array: .. _stablehlo.exponential_minus_one: https://openxla.org/stablehlo/spec#exponential_minus_one """ - return expm1_p.bind(x) + return expm1_p.bind(x, accuracy=accuracy) @export -def log(x: ArrayLike) -> Array: +def log(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise natural logarithm: :math:`\mathrm{log}(x)`. This function lowers directly to the `stablehlo.log`_ operation. Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -568,10 +579,10 @@ def log(x: ArrayLike) -> Array: .. _stablehlo.log: https://openxla.org/stablehlo/spec#log """ - return log_p.bind(x) + return log_p.bind(x, accuracy=accuracy) @export -def log1p(x: ArrayLike) -> Array: +def log1p(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise :math:`\mathrm{log}(1 + x)`. This function lowers directly to the `stablehlo.log_plus_one`_ operation. @@ -580,6 +591,12 @@ def log1p(x: ArrayLike) -> Array: Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -591,16 +608,22 @@ def log1p(x: ArrayLike) -> Array: .. _stablehlo.log_plus_one: https://openxla.org/stablehlo/spec#log_plus_one """ - return log1p_p.bind(x) + return log1p_p.bind(x, accuracy=accuracy) @export -def tanh(x: ArrayLike) -> Array: +def tanh(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise hyperbolic tangent: :math:`\mathrm{tanh}(x)`. This function lowers directly to the `stablehlo.tanh`_ operation. Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -613,10 +636,11 @@ def tanh(x: ArrayLike) -> Array: .. _stablehlo.tanh: https://openxla.org/stablehlo/spec#tanh """ - return tanh_p.bind(x) + return tanh_p.bind(x, accuracy=accuracy) @export -def logistic(x: ArrayLike) -> Array: + +def logistic(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`. There is no HLO logistic/sigmoid primitive, so this lowers to a sequence @@ -632,10 +656,10 @@ def logistic(x: ArrayLike) -> Array: See also: - :func:`jax.nn.sigmoid`: an alternative API for this functionality. """ - return logistic_p.bind(x) + return logistic_p.bind(x, accuracy=accuracy) @export -def sin(x: ArrayLike) -> Array: +def sin(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise sine: :math:`\mathrm{sin}(x)`. For floating-point inputs, this function lowers directly to the @@ -644,6 +668,12 @@ def sin(x: ArrayLike) -> Array: Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -656,10 +686,10 @@ def sin(x: ArrayLike) -> Array: .. _stablehlo.sine: https://openxla.org/stablehlo/spec#sine """ - return sin_p.bind(x) + return sin_p.bind(x, accuracy=accuracy) @export -def cos(x: ArrayLike) -> Array: +def cos(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise cosine: :math:`\mathrm{cos}(x)`. For floating-point inputs, this function lowers directly to the @@ -668,6 +698,12 @@ def cos(x: ArrayLike) -> Array: Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -680,7 +716,7 @@ def cos(x: ArrayLike) -> Array: .. _stablehlo.cosine: https://openxla.org/stablehlo/spec#cosine """ - return cos_p.bind(x) + return cos_p.bind(x, accuracy=accuracy) @export def atan2(x: ArrayLike, y: ArrayLike) -> Array: @@ -704,6 +740,7 @@ def atan2(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.atan2: https://openxla.org/stablehlo/spec#atan2 """ + x, y = core.standard_insert_pvary(x, y) return atan2_p.bind(x, y) @export @@ -773,6 +810,7 @@ def complex(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.complex: https://openxla.org/stablehlo/spec#complex """ + x, y = core.standard_insert_pvary(x, y) return complex_p.bind(x, y) @export @@ -844,6 +882,7 @@ def pow(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.convert: https://openxla.org/stablehlo/spec#convert .. _stablehlo.pow: https://openxla.org/stablehlo/spec#pow """ + x, y = core.standard_insert_pvary(x, y) return pow_p.bind(x, y) @export @@ -861,20 +900,27 @@ def integer_pow(x: ArrayLike, y: int) -> Array: An array of the same shape and dtype as ``x`` containing the elementwise power. See also: - :func:`jax.lax.pow`: Elementwise pwoer where ``y`` is an array. + :func:`jax.lax.pow`: Elementwise power where ``y`` is an array. .. _stablehlo.multiply: https://openxla.org/stablehlo/spec#multiply """ return integer_pow_p.bind(x, y=y) + @export -def sqrt(x: ArrayLike) -> Array: +def sqrt(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise square root: :math:`\sqrt{x}`. This function lowers directly to the `stablehlo.sqrt`_ operation. Args: x: Input array. Must have floating or complex dtype. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: An array of the same shape and dtype as ``x`` containing the square root. @@ -886,16 +932,22 @@ def sqrt(x: ArrayLike) -> Array: .. _stablehlo.sqrt: https://openxla.org/stablehlo/spec#sqrt """ - return sqrt_p.bind(x) + return sqrt_p.bind(x, accuracy=accuracy) @export -def rsqrt(x: ArrayLike) -> Array: +def rsqrt(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise reciprocal square root: :math:`1 \over \sqrt{x}`. This function lowers directly to the `stablehlo.rsqrt`_ operation. Args: x: Input array. Must have floating or complex dtype. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: An array of the same shape and dtype as ``x`` containing the @@ -908,16 +960,22 @@ def rsqrt(x: ArrayLike) -> Array: .. _stablehlo.rsqrt: https://openxla.org/stablehlo/spec#rsqrt """ - return rsqrt_p.bind(x) + return rsqrt_p.bind(x, accuracy=accuracy) @export -def cbrt(x: ArrayLike) -> Array: +def cbrt(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise cube root: :math:`\sqrt[3]{x}`. This function lowers directly to the `stablehlo.cbrt`_ operation. Args: x: Input array. Must have floating or complex dtype. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: An array of the same shape and dtype as ``x`` containing the cube root. @@ -929,7 +987,7 @@ def cbrt(x: ArrayLike) -> Array: .. _stablehlo.cbrt: https://openxla.org/stablehlo/spec#cbrt """ - return cbrt_p.bind(x) + return cbrt_p.bind(x, accuracy=accuracy) @export def bitwise_not(x: ArrayLike) -> Array: @@ -979,6 +1037,7 @@ def bitwise_and(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.and: https://openxla.org/stablehlo/spec#and """ + x, y = core.standard_insert_pvary(x, y) return and_p.bind(x, y) @export @@ -1005,6 +1064,7 @@ def bitwise_or(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.or: https://openxla.org/stablehlo/spec#or """ + x, y = core.standard_insert_pvary(x, y) return or_p.bind(x, y) @export @@ -1031,6 +1091,7 @@ def bitwise_xor(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.xor: https://openxla.org/stablehlo/spec#xor """ + x, y = core.standard_insert_pvary(x, y) return xor_p.bind(x, y) @export @@ -1095,6 +1156,7 @@ def add(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.add: https://openxla.org/stablehlo/spec#add """ + x, y = core.standard_insert_pvary(x, y) return add_p.bind(x, y) @export @@ -1118,6 +1180,7 @@ def sub(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.subtract: https://openxla.org/stablehlo/spec#subtract """ + x, y = core.standard_insert_pvary(x, y) return sub_p.bind(x, y) @export @@ -1141,6 +1204,7 @@ def mul(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.multiply: https://openxla.org/stablehlo/spec#multiply """ + x, y = core.standard_insert_pvary(x, y) return mul_p.bind(x, y) @export @@ -1170,6 +1234,7 @@ def div(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.divide: https://openxla.org/stablehlo/spec#divide """ + x, y = core.standard_insert_pvary(x, y) return div_p.bind(x, y) @export @@ -1197,6 +1262,7 @@ def rem(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.remainder: https://openxla.org/stablehlo/spec#remainder """ + x, y = core.standard_insert_pvary(x, y) return rem_p.bind(x, y) @export @@ -1222,6 +1288,7 @@ def max(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.maximum: https://openxla.org/stablehlo/spec#maximum """ + x, y = core.standard_insert_pvary(x, y) return max_p.bind(x, y) @export @@ -1247,6 +1314,7 @@ def min(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.minimum: https://openxla.org/stablehlo/spec#minimum """ + x, y = core.standard_insert_pvary(x, y) return min_p.bind(x, y) @export @@ -1272,6 +1340,7 @@ def shift_left(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.shift_left: https://openxla.org/stablehlo/spec#shift_left """ + x, y = core.standard_insert_pvary(x, y) return shift_left_p.bind(x, y) @export @@ -1298,6 +1367,7 @@ def shift_right_arithmetic(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.shift_right_arithmetic: https://openxla.org/stablehlo/spec#shift_right_arithmetic """ + x, y = core.standard_insert_pvary(x, y) return shift_right_arithmetic_p.bind(x, y) @export @@ -1324,6 +1394,7 @@ def shift_right_logical(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.shift_right_logical: https://openxla.org/stablehlo/spec#shift_right_logical """ + x, y = core.standard_insert_pvary(x, y) return shift_right_logical_p.bind(x, y) @export @@ -1354,6 +1425,7 @@ def eq(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare """ + x, y = core.standard_insert_pvary(x, y) return eq_p.bind(x, y) @export @@ -1384,6 +1456,7 @@ def ne(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare """ + x, y = core.standard_insert_pvary(x, y) return ne_p.bind(x, y) @export @@ -1414,6 +1487,7 @@ def ge(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare """ + x, y = core.standard_insert_pvary(x, y) return ge_p.bind(x, y) @export @@ -1444,6 +1518,7 @@ def gt(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare """ + x, y = core.standard_insert_pvary(x, y) return gt_p.bind(x, y) @export @@ -1474,6 +1549,7 @@ def le(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare """ + x, y = core.standard_insert_pvary(x, y) return le_p.bind(x, y) @export @@ -1504,6 +1580,7 @@ def lt(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare """ + x, y = core.standard_insert_pvary(x, y) return lt_p.bind(x, y) @export @@ -1539,20 +1616,20 @@ def convert_element_type(operand: ArrayLike, .. _stablehlo.convert: https://openxla.org/stablehlo/spec#convert .. _x64 mode: https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision """ + new_dtype = dtypes.check_and_canonicalize_user_dtype( + new_dtype, 'convert_element_type') return _convert_element_type(operand, new_dtype, weak_type=False) # type: ignore[unused-ignore,bad-return-type] def _convert_element_type( - operand: ArrayLike, - new_dtype: DTypeLike | dtypes.ExtendedDType | None = None, + operand: ArrayLike | literals.TypedNdArray, + new_dtype: DType | None = None, weak_type: bool = False, sharding: Sharding | None = None, warn_on_complex_to_real_cast: bool = True): if hasattr(operand, '__jax_array__'): operand = operand.__jax_array__() - # Don't canonicalize old_dtype because x64 context might cause - # un-canonicalized operands to be passed in. - old_dtype = dtypes.dtype(operand, canonicalize=False) + old_dtype = dtypes.dtype(operand) if (isinstance(new_dtype, dtypes.ExtendedDType) or isinstance(old_dtype, dtypes.ExtendedDType)): @@ -1573,18 +1650,16 @@ def _convert_element_type( "Instead, convert to and from their representation dtypes, e.g.:\n" f"{dtype_to_string(old_dtype)} -> {dtype_to_string(old_rep_dtype)} " f"-> {dtype_to_string(new_rep_dtype)} -> {dtype_to_string(new_dtype)}") + if isinstance(new_dtype, dtypes.ExtendedDType): return to_edtype_p.bind(operand, edtype=new_dtype) return from_edtype_p.bind(operand, dtype=np.dtype(new_dtype)) - new_dtype = type_cast(DTypeLike | None, new_dtype) - old_weak_type = dtypes.is_weakly_typed(operand) if new_dtype is None: new_dtype = old_dtype else: - new_dtype = np.dtype(new_dtype) - new_dtype = dtypes.dtype(new_dtype, canonicalize=True) + assert isinstance(new_dtype, DType), new_dtype if sharding is not None and not isinstance(sharding, Sharding): raise ValueError(f'{sharding=} must be an instance of jax.sharding.Sharding') @@ -1593,16 +1668,16 @@ def _convert_element_type( dtypes.issubdtype(old_dtype, np.complexfloating) and not dtypes.issubdtype(new_dtype, np.complexfloating)): msg = "Casting complex values to real discards the imaginary part" - warnings.warn(msg, NumpyComplexWarning, stacklevel=2) + warnings.warn(msg, np.exceptions.ComplexWarning, stacklevel=2) # Python has big integers, but convert_element_type(2 ** 100, np.float32) need # not be an error since the target dtype fits the value. Handle this case by # converting to a NumPy array before calling bind. Without this step, we'd # first canonicalize the input to a value of dtype int32 or int64, leading to # an overflow error. - if type(operand) is int: - operand = np.asarray(operand).astype(new_dtype) - old_weak_type = False + if type(operand) is int and new_dtype != dtypes.float0: + operand = literals.TypedNdArray(np.asarray(operand).astype(new_dtype), + weak_type) if ((old_dtype, old_weak_type) == (new_dtype, weak_type) and isinstance(operand, Array) and @@ -1646,7 +1721,8 @@ def bitcast_convert_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array: .. _stablehlo.bitcast_convert: https://openxla.org/stablehlo/spec#bitcast_convert """ - new_dtype = dtypes.canonicalize_dtype(new_dtype) + new_dtype = dtypes.check_and_canonicalize_user_dtype( + new_dtype, 'bitcast_convert_type') return bitcast_convert_type_p.bind(operand, new_dtype=new_dtype) def clamp(min: ArrayLike, x: ArrayLike, max: ArrayLike) -> Array: @@ -1658,6 +1734,7 @@ def clamp(min: ArrayLike, x: ArrayLike, max: ArrayLike) -> Array: x & \text{otherwise} \end{cases}`. """ + min, x, max = core.standard_insert_pvary(min, x, max) return clamp_p.bind(min, x, max) @@ -1669,7 +1746,7 @@ def _trace_composite_to_jaxpr(fun: Callable, debug_info: core.DebugInfo): flat_fun, out_tree = api_util.flatten_fun_nokwargs( lu.wrap_init(fun, debug_info=debug_info), in_tree) - jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals) + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals) if any(isinstance(c, core.Tracer) for c in consts): raise UnexpectedTracerError( "Found a JAX Tracer as a constant in the decomposition for the " @@ -1677,8 +1754,8 @@ def _trace_composite_to_jaxpr(fun: Callable, "closes over a value that is involved in a JAX transformation. " "Any values that aren't explicitly known at compile time must be " "explicitly passed as arguments to the composite.") - closed_jaxpr = core.ClosedJaxpr(jaxpr, consts) - return closed_jaxpr, out_tree + closed_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr)) + return closed_jaxpr, consts, out_tree def composite( @@ -1749,7 +1826,7 @@ def composite( @functools.wraps(decomposition) def _decorator(*args, **kwargs): debug_info = api_util.debug_info("composite", decomposition, - args, kwargs) + args, {}) flat_args, in_tree = tree_util.tree_flatten(args) in_avals = tuple(core.get_aval(x) for x in flat_args) if any(isinstance(v, core.Tracer) for v in kwargs.values()): @@ -1761,13 +1838,21 @@ def _decorator(*args, **kwargs): "explicitly passed as arguments to the composite." "\n\nNote: If you are passing jax arrays as attributes, use numpy " "arrays instead.") - closed_jaxpr, out_tree = _trace_composite_to_jaxpr( + closed_jaxpr, consts, out_tree = _trace_composite_to_jaxpr( partial(decomposition, **kwargs), in_tree, in_avals, name, debug_info ) + attributes = [] + for k, v in kwargs.items(): + leaves, treedef = tree_util.tree_flatten(v) + leaves = tuple( + HashableArray(v) if isinstance(v, np.ndarray) else v for v in leaves + ) + attributes.append((k, leaves, treedef)) + flat_consts_and_args = core.standard_insert_pvary(*consts, *flat_args) out_flat = composite_p.bind( - *flat_args, + *flat_consts_and_args, name=name, - attributes=tuple((k, v) for k, v in kwargs.items()), + attributes=tuple(attributes), version=version, jaxpr=closed_jaxpr, ) @@ -1780,7 +1865,7 @@ def _composite_lowering( ctx: mlir.LoweringRuleContext, *args: Any, name: str, - attributes: Sequence[tuple[str, Any]], + attributes: Sequence[tuple[str, tuple[Any, ...], tree_util.PyTreeDef]], version: int, jaxpr: core.ClosedJaxpr, ): @@ -1799,23 +1884,32 @@ def _composite_lowering( Returns: The results of the composite. """ + const_args_and_avals = core.jaxpr_const_args(jaxpr.jaxpr) + const_args, const_avals = util.unzip2(const_args_and_avals) + const_arg_values = tuple( + mlir.ir_constant(c, const_lowering=ctx.const_lowering, aval=aval) + for c, aval in const_args_and_avals + ) + in_avals = (*const_avals, *ctx.avals_in) func_op, _, _ = mlir.lower_called_computation( name, - ctx.name_stack, jaxpr, ctx.module_context, + len(const_args), + in_avals, ctx.avals_out, ctx.tokens_in, ) - composite_attrs = { - k : mlir.ir_attribute(v) - for k, v in attributes - if v is not None - } + + composite_attrs = {} + for k, leaves, treedef in attributes: + v = treedef.unflatten(leaves) + if v is not None: + composite_attrs[k] = mlir.ir_attribute(v) symbol_name = func_op.name.value composite = hlo.CompositeOp( func_op.type.results, - mlir.flatten_ir_values(args), + mlir.flatten_ir_values(const_arg_values + args), name=ir.StringAttr.get(name), decomposition=ir.FlatSymbolRefAttr.get(symbol_name), composite_attributes=ir.DictAttr.get(composite_attrs), @@ -1838,7 +1932,7 @@ def composite_jvp(*args, **_): raise ValueError( "JVP rule for composite not implemented. You can use `jax.custom_jvp` to " "add support. See " - "https://jax.readthedocs.io/en/latest/_autosummary/jax.custom_jvp.html" + "https://docs.jax.dev/en/latest/_autosummary/jax.custom_jvp.html" ) @@ -1847,7 +1941,7 @@ def composite_transpose(*args, **_): raise ValueError( "Transpose rule for composite not implemented. You can use" "`jax.custom_jvp` or `jax.custom_vjp` to add support. See " - "https://jax.readthedocs.io/en/latest/_autosummary/jax.custom_jvp.html" + "https://docs.jax.dev/en/latest/_autosummary/jax.custom_jvp.html" ) @@ -1864,7 +1958,7 @@ def concatenate(operands: Array | Sequence[ArrayLike], dimension: int) -> Array: """Concatenates a sequence of arrays along `dimension`. Wraps XLA's `Concatenate - `_ + `_ operator. Args: @@ -1881,6 +1975,7 @@ def concatenate(operands: Array | Sequence[ArrayLike], dimension: int) -> Array: op, = operands if isinstance(op, Array): return op + operands = core.standard_insert_pvary(*operands) return concatenate_p.bind(*operands, dimension=dimension) @@ -1986,7 +2081,7 @@ class DotAlgorithm(NamedTuple): The `StableHLO spec `_ for the dot operation doesn't require that the precision types be the same as the - storage types for the inputs or outputs, but some plaforms may require that + storage types for the inputs or outputs, but some platforms may require that these types match. Furthermore, the return type of :func:`~jax.lax.dot_general` is always defined by the ``accumulation_type`` parameter of the input algorithm, if specified. @@ -2230,13 +2325,10 @@ def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike, np.dtype(dtypes.float8_e4m3fnuz), np.dtype(dtypes.float8_e5m2), np.dtype(dtypes.float8_e5m2fnuz), + np.dtype(dtypes.float8_e3m4), + np.dtype(dtypes.float8_e4m3), + np.dtype(dtypes.float8_e8m0fnu), ] - if dtypes.float8_e3m4 is not None: - fp8_dtypes += [np.dtype(dtypes.float8_e3m4)] - if dtypes.float8_e4m3 is not None: - fp8_dtypes += [np.dtype(dtypes.float8_e4m3)] - if dtypes.float8_e8m0fnu is not None: - fp8_dtypes += [np.dtype(dtypes.float8_e8m0fnu)] if lhs_dtype not in fp8_dtypes or rhs_dtype not in fp8_dtypes: raise ValueError( f"The dot algorithm '{self}' requires both inputs to have float8 " @@ -2267,11 +2359,6 @@ def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike, case DotAlgorithmPreset.BF16_BF16_F32_X6: return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 6, False) case DotAlgorithmPreset.BF16_BF16_F32_X9: - if xla_extension_version < 320: - raise ValueError( - "The dot algorithm BF16_BF16_F32_X9 requires XLA extension " - "version >= 320." - ) return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 9, False) case DotAlgorithmPreset.TF32_TF32_F32: return hlo.DotAlgorithm.get(tf32, tf32, f32, 1, 1, 1, False) @@ -2302,77 +2389,52 @@ def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike, ] -def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None, - preferred_element_type: DTypeLike | None = None) -> Array: - """Vector/vector, matrix/vector, and matrix/matrix multiplication. - - Wraps XLA's `Dot `_ - operator. - - For more general contraction, see the :func:`jax.lax.dot_general` operator. - - Args: - lhs: an array of dimension 1 or 2. - rhs: an array of dimension 1 or 2. - precision: Optional. This parameter controls the numerics of the - computation, and it can be one of the following: - - - ``None``, which means the default precision for the current backend, - - a :class:`~jax.lax.Precision` enum value or a tuple of two - :class:`~jax.lax.Precision` enums indicating precision of ``lhs``` and - ``rhs``, or - - a :class:`~jax.lax.DotAlgorithm` or a - :class:`~jax.lax.DotAlgorithmPreset` indicating the algorithm that - must be used to accumulate the dot product. - - preferred_element_type: Optional. This parameter controls the data type - output by the dot product. By default, the output element type of this - operation will match the ``lhs`` and ``rhs`` input element types under - the usual type promotion rules. Setting ``preferred_element_type`` to a - specific ``dtype`` will mean that the operation returns that element type. - When ``precision`` is not a :class:`~jax.lax.DotAlgorithm` or - :class:`~jax.lax.DotAlgorithmPreset`, ``preferred_element_type`` provides - a hint to the compiler to accumulate the dot product using this data type. - - Returns: - An array containing the product. - """ - if 1 <= lhs.ndim <= 2 and 1 <= rhs.ndim <= 2 and core.definitely_equal(lhs.shape[-1], rhs.shape[0]): - return dot_general(lhs, rhs, (((lhs.ndim - 1,), (0,)), ((), ())), - precision=precision, - preferred_element_type=preferred_element_type) - else: - raise TypeError("Incompatible shapes for dot: got {} and {}.".format( - lhs.shape, rhs.shape)) - - DotDimensionNumbers = tuple[tuple[Sequence[int], Sequence[int]], tuple[Sequence[int], Sequence[int]]] -def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionNumbers, + +# TODO(jakevdp): consider deprecating jax.lax.dot_general. +def dot_general(lhs: ArrayLike, rhs: ArrayLike, + dimension_numbers: DotDimensionNumbers, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None, + *, out_sharding=None) -> Array: + """Alias of :func:`jax.lax.dot`. + + Prefer use of :func:`jax.lax.dot` directly, but note that it requires + all arguments after ``lhs`` and ``rhs`` to be specified by keyword + rather than position. + """ + return dot(lhs, rhs, dimension_numbers=dimension_numbers, precision=precision, + preferred_element_type=preferred_element_type, out_sharding=out_sharding) + + +# TODO(jakevdp): replace `*args`` with `*` in v0.10.0 +def dot(lhs: ArrayLike, rhs: ArrayLike, *args, + dimension_numbers: DotDimensionNumbers | None = None, + precision: PrecisionLike = None, + preferred_element_type: DTypeLike | None = None, + out_sharding=None) -> Array: """General dot product/contraction operator. - Wraps XLA's `DotGeneral - `_ - operator. + This operation lowers directly to the `stablehlo.dot_general`_ operation. The semantics of ``dot_general`` are complicated, but most users should not have to use it directly. Instead, you can use higher-level functions like :func:`jax.numpy.dot`, :func:`jax.numpy.matmul`, :func:`jax.numpy.tensordot`, :func:`jax.numpy.einsum`, and others which will construct appropriate calls to ``dot_general`` under the hood. If you really want to understand ``dot_general`` itself, we recommend reading XLA's - `DotGeneral `_ - operator documentation. + DotGeneral_ operator documentation. Args: lhs: an array rhs: an array - dimension_numbers: a tuple of tuples of sequences of ints of the form + dimension_numbers: an optional tuple of tuples of sequences of ints of the form ``((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, - rhs_batch_dims))`` + rhs_batch_dims))``. This may be left unspecified in the common case of + un-batched matrix-matrix, matrix-vector, or vector-vector dot products, as + determined by the shape of ``lhs`` and ``rhs``. precision: Optional. This parameter controls the numerics of the computation, and it can be one of the following: @@ -2392,16 +2454,40 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN When ``precision`` is not a :class:`~jax.lax.DotAlgorithm` or :class:`~jax.lax.DotAlgorithmPreset`, ``preferred_element_type`` provides a hint to the compiler to accumulate the dot product using this data type. + out_sharding: an optional sharding specification for the output. If not specified, + it will be determined automatically by the compiler. Returns: An array whose first dimensions are the (shared) batch dimensions, followed by the ``lhs`` non-contracting/non-batch dimensions, and finally the ``rhs`` non-contracting/non-batch dimensions. + + .. _stablehlo.dot_general: https://openxla.org/stablehlo/spec#dot_general + .. _DotGeneral: https://www.openxla.org/xla/operation_semantics#dotgeneral """ - if out_sharding is not None and not isinstance(out_sharding, NamedSharding): - raise NotImplementedError( - '`out_sharding` argument of `dot_general` only supports NamedSharding ' - 'instances. Please file a bug if this is not enough for your use case.') + if args: + raise TypeError( + f"dot() takes 2 positional arguments but {2 + len(args)} were given." + " Passing precision or preferred_element_type by position is not allowed" + " as of JAX v0.9.0; pass them by keyword instead." + ) + del args + + lhs_shape = np.shape(lhs) + lhs_ndim = len(lhs_shape) + + rhs_shape = np.shape(rhs) + rhs_ndim = len(rhs_shape) + + if dimension_numbers is None: + if 1 <= lhs_ndim <= 2 and 1 <= rhs_ndim <= 2 and core.definitely_equal(lhs_shape[-1], rhs_shape[0]): + dimension_numbers = (((lhs_ndim - 1,), (0,)), ((), ())) + else: + raise ValueError( + "jax.lax.dot: dimension_numbers must be specified when not performing simple" + " un-batched matrix-matrix, matrix-vector, or vector-vector products;" + f" got {lhs_shape=} {rhs_shape=}") + out_sharding = canonicalize_sharding(out_sharding, 'dot_general') (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers cdims = (api_util._ensure_index_tuple(lhs_contract), @@ -2410,7 +2496,8 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN api_util._ensure_index_tuple(rhs_batch)) preferred_element_type = ( None if preferred_element_type is None else - dtypes.canonicalize_dtype(np.dtype(preferred_element_type))) + dtypes.check_and_canonicalize_user_dtype(preferred_element_type, 'dot')) + lhs, rhs = core.standard_insert_pvary(lhs, rhs) return dot_general_p.bind(lhs, rhs, dimension_numbers=(cdims, bdims), precision=canonicalize_precision(precision), @@ -2515,16 +2602,18 @@ def ragged_dot_general( Let `g` be the number of groups in the lhs ragged dimension. Ragged dot has three modes, depending on the kind of the lhs ragged dimension: - 1. `[b...,m...,k...], [g,b...,k...,n...], [b...,x...,g] -> [b...,m...,n...]`. - Here the ragged dimension is a non-contracting dimension (`m`) of ``lhs``, - and `x...` are the lhs non-contracting dims outer to the ragged dim. - 2. `[b...,m...,k...], [b...,k...,n...], [b...,x...,g] -> [g,b...,m...,n...]`. - Here the ragged dimension is a contracting dimension (`k`) of ``lhs`` and + + 1. ``[b...,m...,k...], [g,b...,k...,n...], [b...,x...,g] -> [b...,m...,n...]``. + Here the ragged dimension is a non-contracting dimension (``m``) of ``lhs``, + and ``x...`` are the lhs non-contracting dims outer to the ragged dim. + 2. ``[b...,m...,k...], [b...,k...,n...], [b...,x...,g] -> [g,b...,m...,n...]``. + Here the ragged dimension is a contracting dimension (``k``) of ``lhs`` and ``rhs``, and `x...` are the lhs contracting dims outer to the ragged dim. - 3. `[b...,m...,k...], [b...,k...,n...], [x...,g] -> [b...,m...,n...]`. - Here the ragged dimension is a batch dimension (`b`) of ``lhs`` and - ``rhs``, and `x...` are the lhs batch dims outer to the ragged dim. - If ``group_sizes`` is passed-in with shape `[g]`, it is broadcasted according + 3. ``[b...,m...,k...], [b...,k...,n...], [x...,g] -> [b...,m...,n...]``. + Here the ragged dimension is a batch dimension (``b``) of ``lhs`` and + ``rhs``, and ``x...`` are the lhs batch dims outer to the ragged dim. + + If ``group_sizes`` is passed-in with shape ``[g]``, it is broadcasted according to the rules above. Args: @@ -2546,6 +2635,7 @@ def ragged_dot_general( extra leading dimension of size `g` in the case where the lhs ragged dimension is a contracting dimension. """ + lhs, rhs, group_sizes = core.standard_insert_pvary(lhs, rhs, group_sizes) return ragged_dot_general_p.bind( lhs, rhs, @@ -2557,7 +2647,7 @@ def ragged_dot_general( ) -def broadcast(operand: ArrayLike, sizes: Sequence[int], out_sharding=None +def broadcast(operand: ArrayLike, sizes: Sequence[int], *, out_sharding=None ) -> Array: """Broadcasts an array, adding new leading dimensions @@ -2579,17 +2669,17 @@ def broadcast(operand: ArrayLike, sizes: Sequence[int], out_sharding=None out_sharding=out_sharding) def broadcast_in_dim(operand: ArrayLike, shape: Shape, - broadcast_dimensions: Sequence[int], out_sharding=None + broadcast_dimensions: Sequence[int], *, out_sharding=None ) -> Array: """Wraps XLA's `BroadcastInDim - `_ + `_ operator. Args: operand: an array shape: the shape of the target array broadcast_dimensions: to which dimension in the target shape each dimension - of the operand shape corresponds to. That is, dimension i of the operand + of the operand shape corresponds to. That is, dimension i of the operand becomes dimension broadcast_dimensions[i] of the result. Returns: @@ -2602,16 +2692,14 @@ def broadcast_in_dim(operand: ArrayLike, shape: Shape, if (np.ndim(operand) == len(shape) and not len(broadcast_dimensions) and isinstance(operand, Array) and out_sharding is None): return operand - if config.dynamic_shapes.value: - # We must gate this behavior under a flag because otherwise the errors - # raised are different (and have worse source provenance information). - dyn_shape, static_shape = _extract_tracers_dyn_shape(shape) - else: - dyn_shape, static_shape = [], shape # type: ignore + operand_aval = typeof(operand) + if (operand_aval.shape == shape and + list(broadcast_dimensions) == list(range(operand_aval.ndim)) and + out_sharding is not None and operand_aval.sharding != out_sharding): + return pjit.reshard(operand, out_sharding) return broadcast_in_dim_p.bind( - operand, *dyn_shape, shape=tuple(static_shape), - broadcast_dimensions=tuple(broadcast_dimensions), - sharding=out_sharding) + operand, shape=tuple(shape), + broadcast_dimensions=tuple(broadcast_dimensions), sharding=out_sharding) def broadcast_to_rank(x: ArrayLike, rank: int) -> Array: """Adds leading dimensions of ``1`` to give ``x`` rank ``rank``.""" @@ -2620,11 +2708,44 @@ def broadcast_to_rank(x: ArrayLike, rank: int) -> Array: return asarray(x) return broadcast(x, (1,) * (rank - ndim)) + +def tile(operand: ArrayLike, reps: Sequence[int]) -> Array: + """Tiles an array by repeating it along each dimension. + + Args: + operand: an array to tile. + reps: a sequence of integers representing the number of repeats for each + dimension. Must have the same length as ``operand.ndim``. + + Returns: + A tiled array with shape ``(operand.shape[0] * reps[0], ..., + operand.shape[-1] * reps[-1])``. + + Examples: + >>> x = jnp.array([[1, 2], [3, 4]]) + >>> lax.tile(x, (2, 3)) + Array([[1, 2, 1, 2, 1, 2], + [3, 4, 3, 4, 3, 4], + [1, 2, 1, 2, 1, 2], + [3, 4, 3, 4, 3, 4]], dtype=int32) + + >>> y = jnp.array([1, 2, 3]) + >>> lax.tile(y, (2,)) + Array([1, 2, 3, 1, 2, 3], dtype=int32) + + >>> z = jnp.array([[1], [2]]) + >>> lax.tile(z, (1, 3)) + Array([[1, 1, 1], + [2, 2, 2]], dtype=int32) + """ + return tile_p.bind(operand, reps=tuple(reps)) + + def reshape(operand: ArrayLike, new_sizes: Shape, dimensions: Sequence[int] | None = None, - out_sharding: NamedSharding | P | None = None) -> Array: + *, out_sharding: NamedSharding | P | None = None) -> Array: """Wraps XLA's `Reshape - `_ + `_ operator. For inserting/removing dimensions of size 1, prefer using ``lax.squeeze`` / @@ -2669,13 +2790,16 @@ def reshape(operand: ArrayLike, new_sizes: Shape, else: dims = api_util._ensure_index_tuple(dimensions) same_dims = tuple(dims) == tuple(range(np.ndim(operand))) - if np.shape(operand) and same_shape and same_dims and isinstance(operand, Array): + out_sharding = canonicalize_sharding(out_sharding, 'reshape') + same_sharding = (out_sharding is None or + typeof(operand).sharding == out_sharding) + + if (np.shape(operand) and same_shape and same_dims and same_sharding and + isinstance(operand, Array)): return operand else: - dyn_shape, static_new_sizes = _extract_tracers_dyn_shape(new_sizes) - out_sharding = canonicalize_sharding(out_sharding, 'reshape') return reshape_p.bind( - operand, *dyn_shape, new_sizes=tuple(static_new_sizes), + operand, new_sizes=tuple(new_sizes), dimensions=None if dims is None or same_dims else dims, sharding=out_sharding) @@ -2684,7 +2808,7 @@ def pad(operand: ArrayLike, padding_value: ArrayLike, """Applies low, high, and/or interior padding to an array. Wraps XLA's `Pad - `_ + `_ operator. Args: @@ -2693,7 +2817,8 @@ def pad(operand: ArrayLike, padding_value: ArrayLike, as ``operand``. padding_config: a sequence of ``(low, high, interior)`` tuples of integers, giving the amount of low, high, and interior (dilation) padding to insert - in each dimension. + in each dimension. Negative values for ``low`` and ``high`` are allowed + and remove elements from the edges of the array. Returns: The ``operand`` array with padding value ``padding_value`` inserted in each @@ -2728,12 +2853,19 @@ def pad(operand: ArrayLike, padding_value: ArrayLike, [-1, -1, 4, 5, 6, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], dtype=int32) + + Use negative padding to remove elements from the edges of an array: + + >>> x = jnp.array([1, 2, 3, 4, 5], dtype=jnp.int32) + >>> lax.pad(x, 0, [(-1, -2, 0)]) + Array([2, 3], dtype=int32) """ + operand, padding_value = core.standard_insert_pvary(operand, padding_value) return pad_p.bind(operand, padding_value, padding_config=tuple(padding_config)) def rev(operand: ArrayLike, dimensions: Sequence[int]) -> Array: """Wraps XLA's `Rev - `_ + `_ operator. """ return rev_p.bind(operand, dimensions=tuple(dimensions)) @@ -2742,7 +2874,7 @@ def select(pred: ArrayLike, on_true: ArrayLike, on_false: ArrayLike) -> Array: """Selects between two branches based on a boolean predicate. Wraps XLA's `Select - `_ + `_ operator. In general :func:`~jax.lax.select` leads to evaluation of both branches, although @@ -2761,13 +2893,15 @@ def select(pred: ArrayLike, on_true: ArrayLike, on_false: ArrayLike) -> Array: """ # Caution! The select_n_p primitive has the *opposite* order of arguments to # select(). This is because it implements `select_n`. + pred, on_false, on_true = core.standard_insert_pvary( + pred, on_false, on_true) return select_n_p.bind(pred, on_false, on_true) def select_n(which: ArrayLike, *cases: ArrayLike) -> Array: """Selects array values from multiple cases. Generalizes XLA's `Select - `_ + `_ operator. Unlike XLA's version, the operator is variadic and can select from many cases using an integer `pred`. @@ -2786,13 +2920,14 @@ def select_n(which: ArrayLike, *cases: ArrayLike) -> Array: """ if len(cases) == 0: raise ValueError("select_n() must have at least one case") + which, *cases = core.standard_insert_pvary(which, *cases) return select_n_p.bind(which, *cases) def transpose(operand: ArrayLike, permutation: Sequence[int] | np.ndarray) -> Array: """Wraps XLA's `Transpose - `_ + `_ operator. """ permutation = tuple(operator.index(d) for d in permutation) @@ -2804,21 +2939,22 @@ def transpose(operand: ArrayLike, def argmin(operand: ArrayLike, axis: int, index_dtype: DTypeLike) -> Array: """Computes the index of the minimum element along ``axis``.""" - return argmin_p.bind(operand, axes=(axis,), - index_dtype=dtypes.canonicalize_dtype(index_dtype)) + index_dtype = dtypes.check_and_canonicalize_user_dtype(index_dtype, 'argmin') + return argmin_p.bind(operand, axes=(axis,), index_dtype=index_dtype) def argmax(operand: ArrayLike, axis: int, index_dtype: DTypeLike) -> Array: """Computes the index of the maximum element along ``axis``.""" - return argmax_p.bind(operand, axes=(axis,), - index_dtype=dtypes.canonicalize_dtype(index_dtype)) + index_dtype = dtypes.check_and_canonicalize_user_dtype(index_dtype, 'argmax') + return argmax_p.bind(operand, axes=(axis,), index_dtype=index_dtype) def reduce(operands: Any, init_values: Any, computation: Callable[[Any, Any], Any], - dimensions: Sequence[int]) -> Any: + dimensions: Sequence[int], + out_sharding: NamedSharding | P | None = None) -> Any: """Wraps XLA's `Reduce - `_ + `_ operator. ``init_values`` and ``computation`` together must form a `monoid @@ -2841,13 +2977,20 @@ def reduce(operands: Any, monoid_reducer = _get_monoid_reducer(computation, flat_init_values) if monoid_reducer: # monoid reducers bypass the weak_type_rule, so we set it explicitly. - weak_type = dtypes.is_weakly_typed(*flat_operands) and dtypes.is_weakly_typed(*flat_init_values) - return _convert_element_type(monoid_reducer(*flat_operands, dimensions), - weak_type=weak_type) + weak_type = (dtypes.is_weakly_typed(*flat_operands) and + dtypes.is_weakly_typed(*flat_init_values)) + if out_sharding is not None and monoid_reducer is not reduce_sum: + raise NotImplementedError + out_sharding_dict = ({'out_sharding': out_sharding} + if out_sharding is not None else {}) + out = monoid_reducer(*flat_operands, dimensions, **out_sharding_dict) + return _convert_element_type(out, weak_type=weak_type) else: flat_init_avals = safe_map(core.get_aval, flat_init_values) closed_jaxpr, out_tree = _variadic_reduction_jaxpr( computation, comp_debug, tuple(flat_init_avals), init_value_tree) + flat_operands = core.standard_insert_pvary(*flat_operands) + flat_init_values = core.standard_insert_pvary(*flat_init_values) out = reduce_p.bind(*flat_operands, *flat_init_values, computation=computation, jaxpr=closed_jaxpr, dimensions=tuple(dimensions)) return tree_util.tree_unflatten(out_tree, out) @@ -2867,7 +3010,7 @@ def comp(x, y): comp, debug_info=api_util.debug_info("reduction_jaxpr", computation, (aval, aval), {})) - jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(comp_wrapped, (aval, aval)) + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(comp_wrapped, (aval, aval)) if any(isinstance(c, core.Tracer) for c in consts): raise NotImplementedError( "Reduction computations can't close over Tracers. Please open an issue " @@ -2883,7 +3026,7 @@ def _variadic_reduction_jaxpr(computation: Callable[[Any, Any], Any], flat_in_avals, in_tree = tree_util.tree_flatten((avals, avals)) comp = lu.wrap_init(computation, debug_info=debug_info) flat_comp, out_tree = api_util.flatten_fun_nokwargs(comp, in_tree) - jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_comp, tuple(flat_in_avals)) + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_comp, tuple(flat_in_avals)) if any(isinstance(c, core.Tracer) for c in consts): raise NotImplementedError( "Reduction computations can't close over Tracers. Please open an issue " @@ -2951,7 +3094,8 @@ def _get_min_identity(dtype: DTypeLike) -> np.ndarray: else: raise ValueError(f"Unsupported dtype for min: {dtype}") -def reduce_sum(operand: ArrayLike, axes: Sequence[int]) -> Array: +def reduce_sum(operand: ArrayLike, axes: Sequence[int], *, + out_sharding=None) -> Array: """Compute the sum of elements over one or more array axes. Args: @@ -2975,7 +3119,8 @@ def reduce_sum(operand: ArrayLike, axes: Sequence[int]) -> Array: :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_max`, :func:`jax.lax.reduce_min`, :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_or`, :func:`jax.lax.reduce_xor`. """ - return reduce_sum_p.bind(operand, axes=tuple(axes)) + out_sharding = canonicalize_sharding(out_sharding, 'reduce_sum') + return reduce_sum_p.bind(operand, axes=tuple(axes), out_sharding=out_sharding) def reduce_prod(operand: ArrayLike, axes: Sequence[int]) -> Array: """Compute the product of elements over one or more array axes. @@ -3122,7 +3267,7 @@ def sort(operand: Sequence[Array], dimension: int = -1, def sort(operand: Array | Sequence[Array], dimension: int = -1, is_stable: bool = True, num_keys: int = 1) -> Array | tuple[Array, ...]: """Wraps XLA's `Sort - `_ operator. + `_ operator. For floating point inputs, -0.0 and 0.0 are treated as equivalent, and NaN values are sorted to the end of the array. For complex inputs, the sort order is @@ -3146,6 +3291,7 @@ def sort(operand: Array | Sequence[Array], dimension: int = -1, if not (1 <= num_keys <= len(operand)): raise ValueError(f"{num_keys=} must be between 1 and {len(operand)=}") dimension = canonicalize_axis(dimension, len(operand[0].shape)) + operand = core.standard_insert_pvary(*operand) return tuple(sort_p.bind(*operand, dimension=dimension, is_stable=is_stable, num_keys=num_keys)) @@ -3162,12 +3308,14 @@ def sort_key_val(keys: Array, values: ArrayLike, dimension: int = -1, k, v = sort_p.bind(keys, values, dimension=dimension, is_stable=is_stable, num_keys=1) return k, v -def top_k(operand: ArrayLike, k: int) -> tuple[Array, Array]: - """Returns top ``k`` values and their indices along the last axis of ``operand``. +def top_k(operand: ArrayLike, k: int, *, axis: int = -1) -> tuple[Array, Array]: + """Returns top ``k`` values and their indices along the specified axis of ``operand``. Args: operand: N-dimensional array of non-complex type. k: integer specifying the number of top entries. + axis: optional integer specifying the axis along which to compute the top + ``k`` entries. Default is -1, indicating the last axis. Returns: A tuple ``(values, indices)`` where @@ -3175,6 +3323,11 @@ def top_k(operand: ArrayLike, k: int) -> tuple[Array, Array]: - ``values`` is an array containing the top k values along the last axis. - ``indices`` is an array containing the indices corresponding to values. + ``values[..., i, ...]`` is the ``i``-th largest entry in ``operand`` along the + specified axis, and its index is ``indices[..., i, ...]``. + + If two elements are equal, the lower-index element appears first. + See also: - :func:`jax.lax.approx_max_k` - :func:`jax.lax.approx_min_k` @@ -3193,7 +3346,8 @@ def top_k(operand: ArrayLike, k: int) -> tuple[Array, Array]: k = int(k) if k < 0: raise ValueError(f"k argument to top_k must be nonnegative, got {k}") - return top_k_p.bind(operand, k=k) + axis = canonicalize_axis(axis, np.ndim(operand)) + return top_k_p.bind(operand, k=k, axis=axis) def tie_in(x: Any, y: T) -> T: """Deprecated. Ignores ``x`` and returns ``y``.""" @@ -3216,17 +3370,22 @@ def full(shape: Shape, fill_value: ArrayLike, dtype: DTypeLike | None = None, *, if np.shape(fill_value): msg = "full must be called with scalar fill_value, got fill_value.shape {}." raise TypeError(msg.format(np.shape(fill_value))) - if dtypes.issubdtype(dtype, dtypes.extended): - return dtype._rules.full(shape, fill_value, dtype) # type: ignore[union-attr] - weak_type = dtype is None and dtypes.is_weakly_typed(fill_value) - dtype = dtypes.canonicalize_dtype(dtype or _dtype(fill_value)) - fill_value = _convert_element_type(fill_value, dtype, weak_type) + if dtype is None: + weak_type = dtypes.is_weakly_typed(fill_value) + fill_dtype = _dtype(fill_value) + else: + if dtypes.issubdtype(dtype, dtypes.extended): + return dtype._rules.full(shape, fill_value, dtype) # type: ignore[union-attr] + weak_type = False + fill_dtype = dtypes.check_and_canonicalize_user_dtype(dtype, "full") + fill_value = _convert_element_type(fill_value, fill_dtype, weak_type) if (sharding is not None and not isinstance(sharding, PmapSharding) and isinstance(fill_value, array.ArrayImpl) and sharding._is_concrete): broadcast_shape = sharding.shard_shape(shape) shard = broadcast(fill_value, broadcast_shape) shard = shard.addressable_data(0) - return array.make_array_from_callback(shape, sharding, lambda _: shard) + return array.make_array_from_callback( + shape, sharding, lambda _: shard, dtype=fill_dtype) if sharding is not None and not sharding._is_concrete: return broadcast(fill_value, shape, out_sharding=sharding) @@ -3241,13 +3400,13 @@ def zeros_like_shaped_array(aval: ShapedArray) -> Array: scalar_zero = np.zeros((), dtype=aval.dtype) else: scalar_zero = _convert_element_type(0, aval.dtype, aval.weak_type) - return broadcast(scalar_zero, aval.shape, out_sharding=aval.sharding) - + out = broadcast(scalar_zero, aval.shape, out_sharding=aval.sharding) + return core.pvary(out, tuple(aval.vma)) ad_util.aval_zeros_likers[ShapedArray] = zeros_like_shaped_array -def zeros_like_abstract_ref(aval: state.AbstractRef) -> core.MutableArray: +def zeros_like_abstract_ref(aval: state.AbstractRef) -> core.Ref: val = ad_util.zeros_like_aval(aval.inner_aval) - return core.mutable_array(val) + return core.new_ref(val) # TODO(dougalm): this is nonsense but it's here because in places like # custom_vjp we assume that all arguments have tangent spaces. We could have @@ -3256,29 +3415,27 @@ def zeros_like_abstract_ref(aval: state.AbstractRef) -> core.MutableArray: def iota(dtype: DTypeLike, size: int) -> Array: """Wraps XLA's `Iota - `_ + `_ operator. """ return broadcasted_iota(dtype, (size,), 0) def broadcasted_iota(dtype: DTypeLike, shape: Shape, dimension: int, - out_sharding=None) -> Array: + *, out_sharding=None) -> Array: """Convenience wrapper around ``iota``.""" - dtype = dtypes.canonicalize_dtype(dtype) + dtype = dtypes.check_and_canonicalize_user_dtype(dtype, "broadcasted_iota") shape = canonicalize_shape(shape) - dynamic_shape = [d for d in shape if isinstance(d, core.Tracer)] - static_shape = [None if isinstance(d, core.Tracer) else d for d in shape] dimension = core.concrete_or_error( int, dimension, "dimension argument of lax.broadcasted_iota") out_sharding = canonicalize_sharding(out_sharding, 'broadcasted_iota') - return iota_p.bind(*dynamic_shape, dtype=dtype, shape=tuple(static_shape), + return iota_p.bind(dtype=dtype, shape=shape, dimension=dimension, sharding=out_sharding) def _eye(dtype: DTypeLike, shape: Shape, offset: DimSize = 0) -> Array: """Like numpy.eye, create a 2D array with ones on a diagonal.""" offset = _clip_int_to_valid_range(offset, np.int32, "argument `offset` of jax.numpy.eye") - dtype = dtypes.canonicalize_dtype(dtype) + dtype = dtypes.check_and_canonicalize_user_dtype(dtype, "eye") bool_eye = eq(add(broadcasted_iota(np.int32, shape, 0), np.int32(offset)), broadcasted_iota(np.int32, shape, 1)) return convert_element_type_p.bind(bool_eye, new_dtype=dtype, weak_type=False, @@ -3287,7 +3444,7 @@ def _eye(dtype: DTypeLike, shape: Shape, offset: DimSize = 0) -> Array: def _delta(dtype: DTypeLike, shape: Shape, axes: Sequence[int]) -> Array: """This utility function exists for creating Kronecker delta arrays.""" axes = map(int, axes) - dtype = dtypes.canonicalize_dtype(dtype) + dtype = dtypes.check_and_canonicalize_user_dtype(dtype, "delta") base_shape = tuple(np.take(shape, axes)) iotas = [broadcasted_iota(np.uint32, base_shape, i) for i in range(len(base_shape))] @@ -3299,12 +3456,19 @@ def _delta(dtype: DTypeLike, shape: Shape, axes: Sequence[int]) -> Array: def _tri(dtype: DTypeLike, shape: Shape, offset: DimSize) -> Array: """Like numpy.tri, create a 2D array with ones below a diagonal.""" - offset = _clip_int_to_valid_range(offset, np.int32, - "argument `offset` of jax.numpy.tri") - dtype = dtypes.canonicalize_dtype(dtype) - bool_tri = ge(add(broadcasted_iota(np.int32, shape, 0), - asarray(core.dimension_as_value(offset)).astype(np.int32)), - broadcasted_iota(np.int32, shape, 1)) + offset = asarray(core.dimension_as_value(offset)) + if not dtypes.issubdtype(offset.dtype, np.integer): + raise TypeError(f"offset must be an integer, got {offset!r}") + shape_dtype = lax_utils.int_dtype_for_shape(shape, signed=True) + if ( + np.iinfo(offset.dtype).min < np.iinfo(shape_dtype).min + or np.iinfo(offset.dtype).max > np.iinfo(shape_dtype).max + ): + shape_dtype = np.dtype(np.int64) + dtype = dtypes.check_and_canonicalize_user_dtype(dtype, "tri") + bool_tri = ge(add(broadcasted_iota(shape_dtype, shape, 0), + offset.astype(shape_dtype)), + broadcasted_iota(shape_dtype, shape, 1)) return convert_element_type_p.bind(bool_tri, new_dtype=dtype, weak_type=False, sharding=None) @@ -3354,31 +3518,27 @@ def stop_gradient(x: T) -> T: the applicability of ``stop_gradient``. """ def stop(x): - # only bind primitive on inexact dtypes, to avoid some staging if dtypes.issubdtype(core.get_aval(x).dtype, dtypes.extended): return x - elif (dtypes.issubdtype(_dtype(x), np.floating) or - dtypes.issubdtype(_dtype(x), np.complexfloating)): - # break abstractions to support legacy leaked tracer use cases - if isinstance(x, ad.JVPTracer): - return stop(x.primal) - return ad_util.stop_gradient_p.bind(x) + elif isinstance(x, ad.JVPTracer): + return stop(x.primal) else: - return x - return tree_map(stop, x) + return ad_util.stop_gradient_p.bind(x) + return tree_util.tree_map(stop, x) def reduce_precision(operand: float | ArrayLike, exponent_bits: int, mantissa_bits: int) -> Array: """Wraps XLA's `ReducePrecision - `_ + `_ operator. """ exponent_bits = core.concrete_or_error( operator.index, exponent_bits, "exponent_bits argument of lax.reduce_precision") mantissa_bits = core.concrete_or_error( operator.index, mantissa_bits, "mantissa_bits argument of lax.reduce_precision") - return reduce_precision_p.bind(operand, exponent_bits=exponent_bits, mantissa_bits=mantissa_bits) + return reduce_precision_p.bind(operand, exponent_bits=exponent_bits, + mantissa_bits=mantissa_bits) def squeeze(array: ArrayLike, dimensions: Sequence[int]) -> Array: """Squeeze any number of size 1 dimensions from an array.""" @@ -3429,7 +3589,7 @@ def full_like(x: ArrayLike | DuckTypedArray, """ fill_shape = np.shape(x) if shape is None else canonicalize_shape(shape) # type: ignore[arg-type] weak_type = dtype is None and dtypes.is_weakly_typed(x) - dtype = dtype or _dtype(x) + dtype = _dtype(dtype) if dtype is not None else _dtype(x) if dtypes.issubdtype(dtype, dtypes.extended): return dtype._rules.full(fill_shape, fill_value, dtype) # type: ignore[union-attr] @@ -3445,15 +3605,21 @@ def full_like(x: ArrayLike | DuckTypedArray, # This bypasses the check. and not isinstance(x, core.Tracer) and hasattr(x, 'sharding') + and x.sharding is not None + and (x.sharding._is_concrete or not get_concrete_mesh().empty) and getattr(x, '_committed', True) and not weak_type and fill_shape == np.shape(x) # type: ignore[arg-type] ) if use_x_sharding: - # TODO(yashkatariya): Use shard_alike in tracing_mode once it is supported. sharding = x.sharding # type: ignore val = full(fill_shape, _convert_element_type(fill_value, dtype, weak_type), sharding=sharding) + if config._check_vma.value: + # TODO(yashkatariya): Maybe use `shaped_abstractify` here instead of + # `typeof` because `x` can be anything that implements the + # `DuckTypedArray` protocol. + val = core.pvary(val, tuple(typeof(x).vma)) return val @@ -3513,13 +3679,19 @@ def reciprocal(x: ArrayLike) -> Array: return integer_pow(x, -1) @export -def tan(x: ArrayLike) -> Array: +def tan(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise tangent: :math:`\mathrm{tan}(x)`. This function lowers directly to the `stablehlo.tangent`_ operation. Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -3533,7 +3705,7 @@ def tan(x: ArrayLike) -> Array: .. _stablehlo.tangent: https://openxla.org/stablehlo/spec#tangent """ - return tan_p.bind(x) + return tan_p.bind(x, accuracy=accuracy) @export def asin(x: ArrayLike) -> Array: @@ -3714,11 +3886,6 @@ def _iter(tracer): else: return (slicing.index_in_dim(tracer, i, keepdims=False) for i in range(n)) ShapedArray._iter = staticmethod(_iter) -core.DShapedArray._iter = staticmethod(_iter) - -def zeros_like_array(x: ArrayLike) -> Array: - return full_like(x, 0) - def _add_arrays(x, y): if (isinstance(a := core.get_aval(x), ShapedArray) and @@ -3727,7 +3894,8 @@ def _add_arrays(x, y): return add(x, y) for t in itertools.chain( - dtypes.python_scalar_dtypes.keys(), array_types, [array.ArrayImpl]): + dtypes.python_scalar_types, array_types, [array.ArrayImpl], + literals.typed_scalar_types): ad_util.raw_jaxval_adders[t] = _add_arrays @@ -3735,13 +3903,14 @@ def _add_arrays(x, y): _fixed_dtype = \ - lambda dtype: lambda *args, **kwargs: dtypes.canonicalize_dtype(dtype) + lambda dtype: lambda *args, **kwargs: np.dtype(dtype) _complex_basetype = lambda dtype, **kwargs: np.abs(np.zeros((), dtype)).dtype _strip_weak_type = lambda *args, **_: False -def unop_dtype_rule(result_dtype, accepted_dtypes, name, aval, **kwargs): +def unop_dtype_rule(result_dtype, accepted_dtypes, name, aval, + supports_narrow_ints=True, **kwargs): if aval.dtype == dtypes.float0: raise TypeError( f"Called {name} with a float0 array. " @@ -3756,15 +3925,24 @@ def unop_dtype_rule(result_dtype, accepted_dtypes, name, aval, **kwargs): typename = dtype_to_string(aval.dtype) accepted_typenames = (t.__name__ for t in accepted_dtypes) raise TypeError(msg.format(name, typename, ', '.join(accepted_typenames))) + if (not supports_narrow_ints) and aval.dtype in [dtypes.uint2, dtypes.int2, dtypes.uint4, dtypes.int4]: + raise TypeError(f'{name} does not accept dtype {dtype_to_string(aval.dtype)}.' + ' Support for narrow-width integers is platform-dependent' + ' and limited to a few specific operations, e.g. basic' + ' arithmetic and type casting.') return result_dtype(aval.dtype, **kwargs) +def unop_reduced_rule(out_s, aval, **kwargs): + return out_s.update(spec=out_s.spec.update(reduced=aval.sharding.spec.reduced)) -def unop(result_dtype, accepted_dtypes, name): - dtype_rule = partial(unop_dtype_rule, result_dtype, accepted_dtypes, name) +def unop(result_dtype, accepted_dtypes, name, supports_narrow_ints=True): + dtype_rule = partial(unop_dtype_rule, result_dtype, accepted_dtypes, name, + supports_narrow_ints=supports_narrow_ints) prim = standard_primitive(_attrgetter('shape'), dtype_rule, name, - sharding_rule=_attrgetter('sharding')) + sharding_rule=_attrgetter('sharding'), + vma_rule=_attrgetter('vma'), + reduced_rule=unop_reduced_rule) batching.defvectorized(prim) - pe.def_trivial_padding(prim) return prim standard_unop = partial(unop, _identity) @@ -3811,11 +3989,11 @@ def broadcasting_sharding_rule(name, *avals): for a in avals: if a.sharding is not None and not a.sharding.mesh.empty: if mesh is not None and mesh != a.sharding.mesh: - raise ValueError( + raise core.ShardingTypeError( f'Mesh for all inputs should be equal. Got one mesh: {mesh} and' f' another mesh: {a.sharding.mesh}') mesh = a.sharding.mesh - mesh = mesh_lib.get_abstract_mesh() if mesh is None else mesh + mesh = get_abstract_mesh() if mesh is None else mesh shapes = [aval.shape for aval in avals if aval.shape] if not shapes: @@ -3828,7 +4006,7 @@ def broadcasting_sharding_rule(name, *avals): result_specs = [None] * len(shapes[0]) for i, (ss, ds) in enumerate(zip(zip(*specs), zip(*shapes))): - if all(s == ss[0] for s in ss[1:]): + if all(ss[0] == s for s in ss[1:]): # if all dimension shardings are same, the resulting dimension sharding is # the same. result_specs[i] = ss[0] @@ -3845,54 +4023,88 @@ def broadcasting_sharding_rule(name, *avals): result_specs[i] = s elif (result_specs[i] is not None and s is not None and result_specs[i] != s): - raise TypeError( + raise core.ShardingTypeError( f'{name} got incompatible shardings for broadcasting: ' f'{", ".join(map(str, map(tuple, specs)))}.') return NamedSharding(mesh, P(*result_specs)) +def nary_reduced_rule(out_s, *avals, **params): + non_empty_avals = [a for a in avals if a.shape] + specs = [a.sharding.spec for a in non_empty_avals] + + reduced_spec = {s.reduced for s in specs if s.reduced} + if len(reduced_spec) > 1: + raise core.ShardingTypeError( + 'All inputs should be reduced across the same mesh axes. Got specs:' + f' {reduced_spec}') + reduced_s, = reduced_spec if reduced_spec else (frozenset(),) + if reduced_s: + for a in non_empty_avals: + s = a.sharding.spec + flat_spec = flatten_spec(s) + if a.sharding.replicated_axes & reduced_s: + raise core.ShardingTypeError( + 'Inputs cannot be replicated on the same axes that another input' + f' is reduced on. Got input spec: {s} and reduced spec: {reduced_s}') + if frozenset(flat_spec) & reduced_s: + raise core.ShardingTypeError( + 'Inputs cannot be sharded on the same axes that another input is' + ' reduced on. Reshard the input which is reduced to be sharded on' + ' the mesh axes it is reduced on via `jax.sharding.reshard(inp,' + f' jax.P(...))`. Got input spec: {s} and reduced spec: {reduced_s}') + return out_s.update(spec=out_s.spec.update(reduced=reduced_s)) + def naryop(result_dtype, accepted_dtypes, name, allow_extended_dtype=False, - require_same_dtypes=True): + require_same_dtypes=True, unreduced_rule=None, reduced_rule=None): dtype_rule = partial(naryop_dtype_rule, result_dtype, accepted_dtypes, name, allow_extended_dtype=allow_extended_dtype, require_same=require_same_dtypes) shape_rule = partial(broadcasting_shape_rule, name) sharding_rule = partial(broadcasting_sharding_rule, name) - prim = standard_primitive(shape_rule, dtype_rule, name, - sharding_rule=sharding_rule) + prim = standard_primitive( + shape_rule, dtype_rule, name, sharding_rule=sharding_rule, + vma_rule=partial(core.standard_vma_rule, name), + unreduced_rule=unreduced_rule, reduced_rule=nary_reduced_rule) batching.defbroadcasting(prim) - pe.def_trivial_padding(prim) return prim -standard_naryop = partial(naryop, _input_dtype) +standard_naryop = partial(naryop, input_dtype) # Like autograd.numpy.numpy_vjps.unbroadcast, this utility handles transposition # involving linear primitives with implicit broadcasting. def _unbroadcast(aval, x): - if not isinstance(aval, (core.DShapedArray, ShapedArray)): + if not isinstance(aval, ShapedArray): raise TypeError("transpose with implicit broadcasting of unshaped values") x_shape = np.shape(x) - if core.definitely_equal_shape(aval.shape, x_shape): + if (core.definitely_equal_shape(aval.shape, x_shape) and + aval.sharding == typeof(x).sharding): return x assert not aval.shape or len(x_shape) == len(aval.shape) if not aval.shape: return reduce_sum(x, list(range(len(x_shape)))) else: - dims = [i for i, (a, b) in enumerate(zip(x_shape, aval.shape)) if not core.definitely_equal(a, b)] - if config.enable_checks.value: assert all(aval.shape[i] == 1 for i in dims) - return reshape(reduce_sum(x, dims), aval.shape) - -def _maybe_broadcast(target_shape, x): + dims = [i for i, (a, b) in enumerate(zip(x_shape, aval.shape)) + if not core.definitely_equal(a, b)] + if config.enable_checks.value: + assert all(aval.shape[i] == 1 for i in dims) + x = reduce_sum(x, dims) if dims else x + return reshape(x, aval.shape, out_sharding=aval.to_cotangent_aval().sharding) + +def _maybe_broadcast(target_shape, x, target_sharding): x_shape = np.shape(x) - if core.definitely_equal_shape(x_shape, target_shape): + x_sharding = typeof(x).sharding + if (core.definitely_equal_shape(x_shape, target_shape) and + x_sharding == target_sharding): return x elif not x_shape: - return broadcast_in_dim(x, target_shape, ()) + return broadcast_in_dim(x, target_shape, (), out_sharding=target_sharding) else: dims = [i for i, (a, b) in enumerate(zip(x_shape, target_shape)) if core.definitely_equal(a, b)] squeeze_shape = [x_shape[i] for i in dims] - return broadcast_in_dim(reshape(x, squeeze_shape), target_shape, dims) + return broadcast_in_dim(reshape(x, squeeze_shape), target_shape, dims, + out_sharding=target_sharding) def broadcast_hlo( aval_out: core.ShapedArray, avals: Sequence[core.ShapedArray], @@ -3916,28 +4128,27 @@ def broadcast_hlo( out.append(arg) return out -def multi_sharding_in_dim(ctx, ops, in_avals, out_aval): - out = [] - for op, in_aval in zip(ops, in_avals): - if in_aval.sharding == out_aval.sharding or in_aval.sharding is None: - out.append(op) - else: - out.append(mlir.lower_with_sharding_in_types(ctx, op, out_aval)) - return out - -def _nary_lower_hlo(op: Callable, ctx, - *args: ir.Value, **params) -> Sequence[ir.Value]: +def _nary_lower_hlo( + op: Callable, ctx, *args: ir.Value, accuracy=None, **params +) -> Sequence[ir.Value]: """Lowers an elementwise operator to its MLIR equivalent. """ del params avals_in, (aval_out,) = ctx.avals_in, ctx.avals_out - args = mlir.multi_broadcast_in_dim(ctx, args, avals_in, aval_out.shape) - args = multi_sharding_in_dim(ctx, args, avals_in, aval_out) + args = mlir.multi_broadcast_in_dim(ctx, args, avals_in, aval_out.shape, + aval_out.sharding) out = op(*args) + if accuracy: + out = op(*args, result_accuracy=accuracy_attr(accuracy)) return [mlir.lower_with_sharding_in_types(ctx, out, aval_out)] +def _unary_with_accuracy_pp_rule(eqn, context, settings): + params = dict(eqn.params) + if 'accuracy' in params and params['accuracy'] is None: + del params['accuracy'] + return core._pp_eqn(eqn.replace(params=params), context, settings) _float = {np.floating} _complex = {np.complexfloating} @@ -3997,48 +4208,84 @@ def _round_lower(ctx, x, *, rounding_method): mlir.register_lowering(is_finite_p, partial(_nary_lower_hlo, hlo.is_finite)) exp_p = standard_unop(_float | _complex, 'exp') -ad.defjvp2(exp_p, lambda g, ans, x: mul(g, ans)) +ad.defjvp2(exp_p, lambda g, ans, x, **kwargs: mul(g, ans)) mlir.register_lowering(exp_p, partial(_nary_lower_hlo, hlo.exponential)) -batching.ragged_prop_rules[exp_p] = batching.ragged_mask_elementwise_rule +core.pp_eqn_rules[exp_p] = _unary_with_accuracy_pp_rule exp2_p = standard_unop(_float | _complex, 'exp2') -ad.defjvp2(exp2_p, lambda g, ans, x: mul(log(_const(x, 2)), mul(g, ans))) -def _exp2_lower(ctx, x): +ad.defjvp2( + exp2_p, lambda g, ans, x, **kwargs: mul(log(_const(x, 2)), mul(g, ans)) +) + +def _exp2_lower(ctx, x, accuracy): x_aval, = ctx.avals_in log2 = mlir.ir_constant(np.array(np.log(2), x_aval.dtype)) log2 = mlir.broadcast_in_dim(ctx, log2, x_aval, broadcast_dimensions=()) - return [hlo.exponential(hlo.multiply(log2, x))] + return [ + hlo.exponential( + hlo.multiply(log2, x), result_accuracy=accuracy_attr(accuracy) + ) + ] + mlir.register_lowering(exp2_p, _exp2_lower) +core.pp_eqn_rules[exp2_p] = _unary_with_accuracy_pp_rule log_p = standard_unop(_float | _complex, 'log') -ad.defjvp(log_p, lambda g, x: div(g, x)) +ad.defjvp(log_p, lambda g, x, **kwargs: div(g, x)) mlir.register_lowering(log_p, partial(_nary_lower_hlo, hlo.log)) +core.pp_eqn_rules[log_p] = _unary_with_accuracy_pp_rule expm1_p = standard_unop(_float | _complex, 'expm1') -ad.defjvp2(expm1_p, lambda g, ans, x: mul(g, add(ans, _one(ans)))) +ad.defjvp2( + expm1_p, + lambda g, ans, x, accuracy: ( + mul(g, exp(x, accuracy=accuracy)) + if accuracy is AccuracyMode.HIGHEST + else mul(g, add(ans, _one(ans))) + ), +) mlir.register_lowering(expm1_p, partial(_nary_lower_hlo, hlo.exponential_minus_one)) +core.pp_eqn_rules[expm1_p] = _unary_with_accuracy_pp_rule log1p_p = standard_unop(_float | _complex, 'log1p') -ad.defjvp(log1p_p, lambda g, x: div(g, add(x, _one(x)))) +ad.defjvp(log1p_p, lambda g, x, **kwargs: div(g, add(x, _one(x)))) mlir.register_lowering(log1p_p, partial(_nary_lower_hlo, hlo.log_plus_one)) +core.pp_eqn_rules[log1p_p] = _unary_with_accuracy_pp_rule tanh_p = standard_unop(_float | _complex, 'tanh') -ad.defjvp2(tanh_p, lambda g, ans, x: mul(add(g, mul(g, ans)), - sub(_one(x), ans))) +ad.defjvp2( + tanh_p, + lambda g, ans, x, accuracy: mul(g, + mul(_const(x, 4), + mul(logistic(mul(_const(x, 2), x), accuracy=accuracy), + logistic(mul(_const(x, -2), x), accuracy=accuracy)), + ), + ) + if accuracy is AccuracyMode.HIGHEST + else mul(add(g, mul(g, ans)), sub(_one(x), ans)), +) mlir.register_lowering(tanh_p, partial(_nary_lower_hlo, hlo.tanh)) +core.pp_eqn_rules[tanh_p] = _unary_with_accuracy_pp_rule logistic_p = standard_unop(_float | _complex, 'logistic') -ad.defjvp2(logistic_p, lambda g, ans, x: mul(g, mul(ans, sub(_one(ans), ans)))) +ad.defjvp2( + logistic_p, + lambda g, ans, x, accuracy: mul(g, mul(ans, logistic(neg(x)))) + if accuracy is AccuracyMode.HIGHEST + else mul(g, mul(ans, sub(_one(ans), ans))), +) # TODO(phawkins): switch to LogisticOp lowering; debug numerical problems. # mlir.register_lowering(logistic_p, partial(_nary_lower_hlo, hlo.logistic)) -def logistic_impl(x): +def logistic_impl(x, accuracy): + del accuracy one = _const(x, 1) return div(one, add(one, exp(neg(x)))) mlir.register_lowering(logistic_p, mlir.lower_fun(logistic_impl, multiple_results=False)) +core.pp_eqn_rules[logistic_p] = _unary_with_accuracy_pp_rule def _sin_complex(x): # use expm1 instead of exp to avoid cancellation when abs(x) is small @@ -4056,22 +4303,23 @@ def _sin_complex(x): # avoid nan value when real(x) is zero and abs(x) is so large that abs(expm1(x)) is inf return select(a_is_zero, complex(_const(a, 0), im), complex(re, im)) -def _sin_lowering(ctx, x): +def _sin_lowering(ctx, x, accuracy): if dtypes.issubdtype(ctx.avals_in[0].dtype, np.complexfloating): sine = mlir.lower_fun(_sin_complex, multiple_results=False) return sine(ctx, x) - return _nary_lower_hlo(hlo.sine, ctx, x) + return _nary_lower_hlo(hlo.sine, ctx, x, accuracy=accuracy) -def _sin_lin(nzs, x): + +def _sin_lin(nzs, x, accuracy): nz, = nzs - cos_x = cos(x) # TODO: allow this to happen in the linearized computation (need to fix backward_pass) - return (sin_p.bind(x), nz, cos_x, lambda cos_x_, t: mul(t, cos_x_)) + return (sin_p.bind(x, accuracy=accuracy), nz, cos(x), + lambda cos_x, t: mul(t, cos_x)) sin_p = standard_unop(_float | _complex, 'sin') -ad.defjvp(sin_p, lambda g, x: mul(g, cos(x))) +ad.defjvp(sin_p, lambda g, x, accuracy: mul(g, cos(x, accuracy=accuracy))) ad.primitive_linearizations[sin_p] = _sin_lin mlir.register_lowering(sin_p, _sin_lowering) -batching.ragged_prop_rules[sin_p] = batching.ragged_mask_elementwise_rule +core.pp_eqn_rules[sin_p] = _unary_with_accuracy_pp_rule def _cos_complex(x): # cos(x) = complex(cos(real(x)) * cosh(imag(x)), -sin(real(x)) * sinh(imag(x))) @@ -4085,19 +4333,23 @@ def _cos_complex(x): re, im = mul(cs, csh), mul(neg(sn), snh) return select(a_is_zero, complex(re, _const(a, 0)), complex(re, im)) -def _cos_lowering(ctx, x): +def _cos_lowering(ctx, x, accuracy): if dtypes.issubdtype(ctx.avals_in[0].dtype, np.complexfloating): cosine = mlir.lower_fun(_cos_complex, multiple_results=False) return cosine(ctx, x) - return _nary_lower_hlo(hlo.cosine, ctx, x) + return _nary_lower_hlo(hlo.cosine, ctx, x, accuracy=accuracy) cos_p = standard_unop(_float | _complex, 'cos') -ad.defjvp(cos_p, lambda g, x: neg(mul(g, sin(x)))) +ad.defjvp( + cos_p, lambda g, x, accuracy: neg(mul(g, sin(x, accuracy=accuracy))) +) mlir.register_lowering(cos_p, _cos_lowering) +core.pp_eqn_rules[cos_p] = _unary_with_accuracy_pp_rule tan_p = standard_unop(_float | _complex, 'tan') -ad.defjvp2(tan_p, lambda g, ans, x: mul(g, add(_const(x, 1), square(ans)))) +ad.defjvp2(tan_p, lambda g, ans, x, **kwargs: mul(g, add(_const(x, 1), square(ans)))) mlir.register_lowering(tan_p, partial(_nary_lower_hlo, hlo.tan)) +core.pp_eqn_rules[tan_p] = _unary_with_accuracy_pp_rule asin_p = standard_unop(_float | _complex, 'asin') ad.defjvp(asin_p, lambda g, x: mul(g, rsqrt(sub(_const(x, 1), square(x))))) @@ -4134,7 +4386,9 @@ def atan_impl(x): acosh_p = standard_unop(_float | _complex, 'acosh') ad.defjvp(acosh_p, - lambda g, x: mul(g, rsqrt(mul(sub(x, _one(x)), add(x, _one(x)))))) + # We use x^2-1 rather than (x+1)(x-1). The latter is more accurate + # for x near zero, but the function domain is x>=1. + lambda g, x: mul(g, rsqrt(sub(square(x), _one(x))))) mlir.register_lowering(acosh_p, partial(_nary_lower_hlo, chlo.acosh)) atanh_p = standard_unop(_float | _complex, 'atanh') @@ -4169,7 +4423,8 @@ def _complex_transpose_rule(t, x, y): else: return [None, _unbroadcast(y.aval, imag(neg(t)))] -_complex_dtype = lambda dtype, *args, **kwargs: (np.zeros((), dtype) + np.zeros((), np.complex64)).dtype +def _complex_dtype(dtype, *args, **kwargs): + return (np.zeros((), dtype) + np.zeros((), np.complex64)).dtype complex_p = naryop(_complex_dtype, [_complex_elem_types, _complex_elem_types], 'complex') ad.deflinear2(complex_p, _complex_transpose_rule) @@ -4199,7 +4454,8 @@ def _conj_transpose_rule(t, x, *, input_dtype): ad.primitive_jvps[conj_p] = partial(ad.linear_jvp, conj_p) ad.primitive_transposes[conj_p] = _conj_transpose_rule -abs_p = unop(_complex_basetype, _signedint | _float | _complex, 'abs') +abs_p = unop(_complex_basetype, _signedint | _float | _complex, 'abs', + supports_narrow_ints=False) mlir.register_lowering(abs_p, partial(_nary_lower_hlo, hlo.abs)) def _abs_jvp_rule(g, ans, x): @@ -4213,40 +4469,37 @@ def _abs_jvp_rule(g, ans, x): _maybe_real = lambda x: real(x) if _iscomplex(x) else x sqrt_p = standard_unop(_float | _complex, 'sqrt') -ad.defjvp2(sqrt_p, lambda g, ans, x: mul(g, div(_const(x, 0.5), ans))) +ad.defjvp2(sqrt_p, lambda g, ans, x, **kwargs: mul(g, div(_const(x, 0.5), ans))) mlir.register_lowering(sqrt_p, partial(_nary_lower_hlo, hlo.sqrt)) +core.pp_eqn_rules[sqrt_p] = _unary_with_accuracy_pp_rule rsqrt_p = standard_unop(_float | _complex, 'rsqrt') -ad.defjvp2(rsqrt_p, - lambda g, ans, x: - mul(g, mul(_const(x, -0.5), div(ans, x)))) +ad.defjvp2( + rsqrt_p, + lambda g, ans, x, **kwargs: mul(g, mul(_const(x, -0.5), div(ans, x))), +) mlir.register_lowering(rsqrt_p, partial(_nary_lower_hlo, hlo.rsqrt)) +core.pp_eqn_rules[rsqrt_p] = _unary_with_accuracy_pp_rule cbrt_p = standard_unop(_float, 'cbrt') -ad.defjvp2(cbrt_p, - lambda g, ans, x: mul(g, mul(_const(x, 1/3), integer_pow(ans, -2)))) +ad.defjvp2( + cbrt_p, + lambda g, ans, x, **kwargs: mul( + g, mul(_const(x, 1 / 3), integer_pow(ans, -2)) + ), +) mlir.register_lowering(cbrt_p, partial(_nary_lower_hlo, hlo.cbrt)) +core.pp_eqn_rules[cbrt_p] = _unary_with_accuracy_pp_rule square_p = standard_unop(_int | _float | _complex, 'square') -def _square_complex(x): - a, b = real(x), imag(x) - # zero square(x).real is handled explicitly for abs(a)==abs(b) cases - # where for finite a, 2 * a is non-finite: - zero_re = is_finite(a) & (eq(a, b) | eq(a, neg(b))) - # equivalent to a**2 - b**2 but avoids overflow errors for large a - # and large b cases: - re = mul(sub(a, b), add(a, b)) - im = mul(mul(a, b), _const(a, 2)) - return select(zero_re, complex(_const(a, 0), im), complex(re, im)) - def _square_lower_hlo(ctx, x): - if dtypes.issubdtype(ctx.avals_in[0].dtype, np.complexfloating): - return mlir.lower_fun(_square_complex, multiple_results=False)(ctx, x) - return [hlo.multiply(x, x)] + if dtypes.issubdtype(ctx.avals_in[0].dtype, np.integer): + return [hlo.multiply(x, x)] + return [chlo.square(x)] ad.defjvp2(square_p, lambda g, ans, x: mul(g, mul(_const(x, 2), x))) -mlir.register_lowering(square_p, _square_lower_hlo) # TODO(pearu): use chlo.square +mlir.register_lowering(square_p, _square_lower_hlo) def _pow_dtype_rule(x, y): if (dtypes.issubdtype(x.dtype, np.inexact) and @@ -4271,8 +4524,9 @@ def _pow_jvp_lhs(g, ans, x, y): if dtypes.issubdtype(y_dtype, np.integer): if x.shape != y.shape: shape = broadcast_shapes(x.shape, y.shape) - x = _maybe_broadcast(shape, x) - y = _maybe_broadcast(shape, y) + sharding = broadcast_shardings(typeof(x), typeof(y)) + x = _maybe_broadcast(shape, x, sharding) + y = _maybe_broadcast(shape, y, sharding) jac = select(eq(y, _const(y, 0)), _zeros(y), mul(_replace_zero(y), pow(x, sub(y, _ones(y))))) else: @@ -4307,10 +4561,9 @@ def _integer_pow_jvp(g, x, *, y): integer_pow_p = standard_primitive( _attrgetter('shape'), _integer_pow_dtype_rule, 'integer_pow', - sharding_rule=_attrgetter('sharding')) + sharding_rule=_attrgetter('sharding'), vma_rule=_attrgetter('vma')) batching.defvectorized(integer_pow_p) ad.defjvp(integer_pow_p, _integer_pow_jvp) -pe.def_trivial_padding(integer_pow_p) def _integer_pow(x, *, y): # This should be kept in sync with the jax2tf translation rule. @@ -4343,8 +4596,6 @@ def _integer_pow_lowering(ctx, x, *, y): out = hlo.divide(mlir.full_like_aval(ctx, 1, ctx.avals_in[0]), x) else: lowering = mlir.lower_fun(_integer_pow, multiple_results=False) - if builtins.abs(y) >= 3: - lowering = mlir.cache_lowering(lowering) out, = lowering(ctx, x, y=y) aval_out, = ctx.avals_out return [mlir.lower_with_sharding_in_types(ctx, out, aval_out)] @@ -4382,9 +4633,11 @@ def _add_jvp(primals, tangents): if type(xdot) is type(ydot) is ad_util.Zero: return primal_out, ad_util.Zero.from_primal_value(primal_out) if type(xdot) is ad_util.Zero: - return primal_out, _maybe_broadcast(primal_out.shape, ydot) + return (primal_out, _maybe_broadcast(primal_out.shape, ydot, + typeof(primal_out).sharding)) elif type(ydot) is ad_util.Zero: - return primal_out, _maybe_broadcast(primal_out.shape, xdot) + return (primal_out, _maybe_broadcast(primal_out.shape, xdot, + typeof(primal_out).sharding)) else: return primal_out, add(xdot, ydot) @@ -4400,12 +4653,33 @@ def _add_transpose(t, x, y): else: return [_unbroadcast(x_aval, t), _unbroadcast(y_aval, t)] -# TODO(slebedev): Why does mypy fail to infer the type here? -add_p: Primitive = standard_naryop([_num, _num], 'add') +def _add_unreduced_rule(out_sharding, x, y): + x_ur, y_ur = x.sharding.spec.unreduced, y.sharding.spec.unreduced + if x_ur and y_ur: + if x_ur != y_ur: + raise core.ShardingTypeError( + 'lhs and rhs to `add` must be unreduced along the same mesh axes. ' + f'Got lhs={x_ur}, rhs={y_ur}') + res_unreduced = x_ur + elif x_ur or y_ur: + if x_ur and not y_ur: + lhs_str, rhs_str = 'lhs', 'rhs' + else: + assert not x_ur and y_ur + lhs_str, rhs_str = 'rhs', 'lhs' + raise core.ShardingTypeError( + f'{lhs_str} is unreduced while {rhs_str} is not. `add` operation does' + ' not allow this because there will be implicit communication. Please' + f' reduce {lhs_str} via `reshard` before calling `add`.') + else: + res_unreduced = frozenset() + return out_sharding.update(spec=out_sharding.spec.update(unreduced=res_unreduced)) + +add_p: Primitive = naryop(input_dtype, [_num, _num], 'add', + unreduced_rule=_add_unreduced_rule) ad.primitive_jvps[add_p] = _add_jvp ad.primitive_transposes[add_p] = _add_transpose mlir.register_lowering(add_p, partial(_nary_lower_hlo, hlo.add)) -batching.ragged_prop_rules[add_p] = batching.ragged_mask_elementwise_rule def _sub_jvp(primals, tangents): x, y = primals @@ -4414,9 +4688,11 @@ def _sub_jvp(primals, tangents): if type(xdot) is type(ydot) is ad_util.Zero: return primal_out, ad_util.Zero.from_primal_value(primal_out) if type(xdot) is ad_util.Zero: - return primal_out, _maybe_broadcast(primal_out.shape, neg(ydot)) + return (primal_out, _maybe_broadcast(primal_out.shape, neg(ydot), + typeof(primal_out).sharding)) elif type(ydot) is ad_util.Zero: - return primal_out, _maybe_broadcast(primal_out.shape, xdot) + return (primal_out, _maybe_broadcast(primal_out.shape, xdot, + typeof(primal_out).sharding)) else: return primal_out, sub(xdot, ydot) @@ -4435,29 +4711,43 @@ def _sub_transpose(t, x, y): ad.primitive_jvps[sub_p] = _sub_jvp ad.primitive_transposes[sub_p] = _sub_transpose mlir.register_lowering(sub_p, partial(_nary_lower_hlo, hlo.subtract)) -batching.ragged_prop_rules[sub_p] = batching.ragged_mask_elementwise_rule - -def _mul_transpose(ct, x, y): - assert ad.is_undefined_primal(x) ^ ad.is_undefined_primal(y) - if ad.is_undefined_primal(x): - if type(ct) is ad_util.Zero: - return [ad_util.Zero(x.aval), None] - else: - return [_unbroadcast(x.aval, mul(ct, y)), None] +def _mul_unreduced_rule(out_sharding, x, y): + x_ur, y_ur = x.sharding.spec.unreduced, y.sharding.spec.unreduced + if x_ur and y_ur: + raise core.ShardingTypeError( + 'lhs and rhs to `mul` cannot be unreduced since mul is bilinear. ' + f'Got lhs={x_ur}, rhs={y_ur}') + elif x_ur and not y_ur: + if x_ur != y.sharding.spec.reduced: + raise core.ShardingTypeError( + 'RHS should be reduced along the same axes LHS is unreduced on. Got' + f' lhs={x} and rhs={y}') + out_unreduced = x_ur + elif not x_ur and y_ur: + if x.sharding.spec.reduced != y_ur: + raise core.ShardingTypeError( + 'LHS should be reduced along the same axes RHS is unreduced on. Got' + f' lhs={x} and rhs={y}') + out_unreduced = y_ur else: - if type(ct) is ad_util.Zero: - return [None, ad_util.Zero(y.aval)] - else: - return [None, _unbroadcast(y.aval, mul(x, ct))] + assert not x_ur and not y_ur + out_unreduced = frozenset() + if out_unreduced: + assert out_sharding.spec.reduced == out_unreduced + out_reduced = frozenset() # if both are equal, set difference is empty. + else: + out_reduced = out_sharding.spec.reduced + return out_sharding.update(spec=out_sharding.spec.update( + unreduced=out_unreduced, reduced=out_reduced)) -mul_p = standard_naryop([_num, _num], 'mul') +mul_p = standard_naryop([_num, _num], 'mul', unreduced_rule=_mul_unreduced_rule) ad.defjvp(mul_p, lambda xdot, x, y: mul(xdot, y), lambda ydot, x, y: mul(x, ydot)) -ad.primitive_transposes[mul_p] = _mul_transpose +ad.defbilinear(mul_p, lambda ct, x, y: _unbroadcast(x.aval, mul(ct, y)), + lambda ct, x, y: _unbroadcast(y.aval, mul(x, ct))) mlir.register_lowering(mul_p, partial(_nary_lower_hlo, hlo.multiply)) -batching.ragged_prop_rules[mul_p] = batching.ragged_mask_elementwise_rule def _div_transpose_rule(cotangent, x, y): assert ad.is_undefined_primal(x) and not ad.is_undefined_primal(y) @@ -4471,19 +4761,21 @@ def _div_transpose_rule(cotangent, x, y): lambda g, x, y: mul(mul(neg(g), x), integer_pow(y, -2))) ad.primitive_transposes[div_p] = _div_transpose_rule mlir.register_lowering(div_p, partial(_nary_lower_hlo, hlo.divide)) -batching.ragged_prop_rules[div_p] = batching.ragged_mask_elementwise_rule rem_p = standard_naryop([_int | _float, _int | _float], 'rem') ad.defjvp( rem_p, - lambda g, x, y: _maybe_broadcast(broadcast_shapes(np.shape(x), np.shape(y)), g), + lambda g, x, y: _maybe_broadcast( + broadcast_shapes(np.shape(x), np.shape(y)), g, + broadcast_shardings(typeof(x), typeof(y))), lambda g, x, y: mul(neg(g), mul(sign(div(x, y)), floor(abs(div(x, y)))))) mlir.register_lowering(rem_p, partial(_nary_lower_hlo, hlo.remainder)) def _minmax_complex_lowering(x, y, *, lax_cmp_pick_x): result_shape = broadcast_shapes(np.shape(x), np.shape(y)) - x = _maybe_broadcast(result_shape, x) - y = _maybe_broadcast(result_shape, y) + result_sharding = broadcast_shardings(typeof(x), typeof(y)) + x = _maybe_broadcast(result_shape, x, result_sharding) + y = _maybe_broadcast(result_shape, y, result_sharding) rx = real(x) ry = real(y) pick_x = select(eq(rx, ry), lax_cmp_pick_x(imag(x), imag(y)), @@ -4495,14 +4787,12 @@ def _minmax_complex_lowering(x, y, *, lax_cmp_pick_x): lambda g, ans, x, y: mul(g, _balanced_eq(x, ans, y)), lambda g, ans, x, y: mul(g, _balanced_eq(y, ans, x))) mlir.register_lowering(max_p, partial(_nary_lower_hlo, mlir.max_hlo)) -batching.ragged_prop_rules[max_p] = batching.ragged_mask_elementwise_rule min_p: core.Primitive = standard_naryop([_any, _any], 'min') ad.defjvp2(min_p, lambda g, ans, x, y: mul(g, _balanced_eq(x, ans, y)), lambda g, ans, x, y: mul(g, _balanced_eq(y, ans, x))) mlir.register_lowering(min_p, partial(_nary_lower_hlo, mlir.min_hlo)) -batching.ragged_prop_rules[min_p] = batching.ragged_mask_elementwise_rule shift_left_p = standard_naryop([_int, _int], 'shift_left') ad.defjvp_zero(shift_left_p) @@ -4553,7 +4843,8 @@ def _compare_lower_hlo_opaque(direction: str, ctx, avals_in, aval_out, x, y): def _compare_lower_hlo(direction: str, total_order: bool, ctx, x, y): avals_in, (aval_out,) = ctx.avals_in, ctx.avals_out x_dtype = avals_in[0].dtype - x, y = mlir.multi_broadcast_in_dim(ctx, (x, y), avals_in, aval_out.shape) + x, y = mlir.multi_broadcast_in_dim(ctx, (x, y), avals_in, aval_out.shape, + aval_out.sharding) if dtypes.issubdtype(x_dtype, dtypes.extended): assert not total_order return _compare_lower_hlo_opaque(direction, ctx, avals_in, aval_out, x, y) @@ -4568,7 +4859,6 @@ def _compare_lower_hlo(direction: str, total_order: bool, ctx, x, y): eq_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'eq', allow_extended_dtype=True) ad.defjvp_zero(eq_p) mlir.register_lowering(eq_p, partial(_compare_lower_hlo, "EQ", False)) -batching.ragged_prop_rules[eq_p] = batching.ragged_mask_elementwise_rule ne_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'ne', allow_extended_dtype=True) ad.defjvp_zero(ne_p) @@ -4589,7 +4879,6 @@ def _compare_lower_hlo(direction: str, total_order: bool, ctx, x, y): lt_p = naryop(_fixed_dtype(np.bool_), [_ordered, _ordered], 'lt') ad.defjvp_zero(lt_p) mlir.register_lowering(lt_p, partial(_compare_lower_hlo, "LT", False)) -batching.ragged_prop_rules[lt_p] = batching.ragged_mask_elementwise_rule eq_to_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'eq_to') ad.defjvp_zero(eq_to_p) @@ -4619,6 +4908,16 @@ def _convert_element_type_sharding_rule(operand, *, new_dtype, weak_type, return core.get_cur_mesh_sharding() return sharding +def _convert_element_type_unreduced_rule(out_s, operand, *, new_dtype, + weak_type, sharding): + return out_s.update(spec=out_s.spec.update( + unreduced=operand.sharding.spec.unreduced)) + +def _convert_element_type_reduced_rule(out_s, operand, *, new_dtype, + weak_type, sharding): + return out_s.update(spec=out_s.spec.update( + reduced=operand.sharding.spec.reduced)) + def _convert_element_type_dtype_rule(operand, *, new_dtype, weak_type, sharding): return new_dtype @@ -4637,8 +4936,10 @@ def _convert_element_type_transpose_rule(ct, operand, *, new_dtype, weak_type, elif core.primal_dtype_to_tangent_dtype(old_dtype) == dtypes.float0: return [ad_util.Zero(operand.aval.update(dtype=dtypes.float0, weak_type=False))] else: - return [convert_element_type_p.bind( - ct, new_dtype=old_dtype, weak_type=old_weak_type, sharding=sharding)] + out = convert_element_type_p.bind( + ct, new_dtype=old_dtype, weak_type=old_weak_type, + sharding=operand.aval.to_cotangent_aval().sharding) + return [out] def _convert_element_type_jvp_rule(tangent, primal_result, operand, *, new_dtype, weak_type, sharding): @@ -4649,7 +4950,14 @@ def _convert_element_type_jvp_rule(tangent, primal_result, operand, *, return convert_element_type_p.bind(tangent, new_dtype=new_tangent_dtype, weak_type=weak_type, sharding=sharding) -def _convert_elt_type_folding_rule(consts, eqn): +_foldable_types = { + literals.TypedNdArray, + np.ndarray, + *dtypes.python_scalar_types, + *literals.typed_scalar_types, +} + +def _convert_elt_type_folding_rule(consts, params, out_avals): # We constant-fold convert_element_types applied to constants if those # constants are Python builtin numeric types or numpy.ndarrays (so as not # to perform any device operations when constant-folding) and if the output @@ -4659,30 +4967,27 @@ def _convert_elt_type_folding_rule(consts, eqn): # we output a Python builtin numeric type. # TODO(mattjj): allow constant-folding CPU-backed JAX arrays c, = consts - o, = eqn.outvars - new_dtype = eqn.params['new_dtype'] - if (type(c) in {np.ndarray, *dtypes.python_scalar_dtypes} and - isinstance(o.aval, core.UnshapedArray) and not np.shape(c) and - not dtypes.issubdtype(new_dtype, dtypes.extended)): - out = np.array(c) + out_aval, = out_avals + new_dtype = params['new_dtype'] + if (type(c) in _foldable_types and isinstance(out_aval, ShapedArray) + and not np.shape(c) + and not dtypes.issubdtype(new_dtype, dtypes.extended)): + out = np.asarray(c) if (dtypes.issubdtype(out.dtype, np.complexfloating) and not dtypes.issubdtype(new_dtype, np.complexfloating)): out = out.real out = out.astype(new_dtype) - if not o.aval.weak_type: - return [out], None - out = out.item() - if core.get_aval(out).dtype is o.aval.dtype: - return [out], None - return [None], eqn + return [literals.TypedNdArray(out, weak_type=out_aval.weak_type)] + return None def _convert_elt_type_fwd_rule(eqn): - v, = eqn.invars - if (not dtypes.issubdtype(eqn.params['new_dtype'], dtypes.extended) and - not dtypes.issubdtype(v.aval.dtype, dtypes.extended) and - v.aval.dtype == eqn.params['new_dtype'] and - v.aval.weak_type == eqn.params['weak_type']): - return [v], None + t, = eqn.invars + aval = t.aval + if (aval.dtype == eqn.params['new_dtype'] and + aval.weak_type == eqn.params['weak_type'] and + not dtypes.issubdtype(aval.dtype, dtypes.extended) and + (eqn.params['sharding'] is None or eqn.params['sharding'] == aval.sharding)): + return [0], None else: return [None], eqn @@ -4692,7 +4997,13 @@ def _convert_elt_type_pp_rule(eqn, context, settings): del params['sharding'] # don't show trivial case return core._pp_eqn(eqn.replace(params=params), context, settings) -convert_element_type_p = Primitive('convert_element_type') +convert_element_type_p = standard_primitive( + _convert_element_type_shape_rule, _convert_element_type_dtype_rule, + 'convert_element_type', weak_type_rule=_convert_element_type_weak_type_rule, + sharding_rule=_convert_element_type_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'convert_element_type'), + unreduced_rule=_convert_element_type_unreduced_rule, + reduced_rule=_convert_element_type_reduced_rule) # TODO(dougalm): I'm overriding bind_with_trace here because that's the closest thing to # the old "custom bind" but it might not be the best way to do this. @@ -4706,21 +5017,21 @@ def _convert_element_type_bind_with_trace(trace, args, params): convert_element_type_p.def_bind_with_trace(_convert_element_type_bind_with_trace) convert_element_type_p.def_impl(partial(dispatch.apply_primitive, convert_element_type_p)) -convert_element_type_p.def_abstract_eval( - partial(standard_abstract_eval, convert_element_type_p, - _convert_element_type_shape_rule, _convert_element_type_dtype_rule, - _convert_element_type_weak_type_rule, - _convert_element_type_sharding_rule)) ad.defjvp2(convert_element_type_p, _convert_element_type_jvp_rule) ad.primitive_transposes[convert_element_type_p] = _convert_element_type_transpose_rule -batching.defvectorized(convert_element_type_p) + +def _convert_element_type_batching_rule( + axis_data, batched_args, batch_dims, *, new_dtype, weak_type, sharding): + if sharding is not None: + sharding = batching.get_sharding_for_vmap(axis_data, sharding, 0) + new_params = dict(new_dtype=new_dtype, weak_type=weak_type, sharding=sharding) + return convert_element_type_p.bind(*batched_args, **new_params), batch_dims[0] +batching.fancy_primitive_batchers[convert_element_type_p] = _convert_element_type_batching_rule +batching.skippable_batchers[convert_element_type_p] = lambda _: () + pe.const_fold_rules[convert_element_type_p] = _convert_elt_type_folding_rule pe.forwarding_rules[convert_element_type_p] = _convert_elt_type_fwd_rule -pe.def_trivial_padding(convert_element_type_p) core.pp_eqn_rules[convert_element_type_p] = _convert_elt_type_pp_rule -batching.ragged_prop_rules[convert_element_type_p] = ( - batching.ragged_mask_elementwise_rule -) def _real_dtype(dtype): return np.finfo(dtype).dtype @@ -4743,6 +5054,9 @@ def _to_edtype_abstract_eval(x, *, edtype): not isinstance(x.dtype, dtypes.ExtendedDType)) # For backward compatibility, if the edtype rules have a `convert_to` method, # use that rather than looking for an `allow_conversion: bool` attribute. + if not isinstance(x, ShapedArray): + raise TypeError("can only convert to an extended dtype on an array type," + f"but got {type(x)}") if convert_to := getattr(edtype._rules, 'convert_to', None): allow_conversion = convert_to(x.dtype, edtype) else: @@ -4752,6 +5066,7 @@ def _to_edtype_abstract_eval(x, *, edtype): f"Cannot convert_element_type from {dtype_to_string(x.dtype)} " f"to {dtype_to_string(edtype)}") rep_aval = core.physical_element_aval(edtype) + assert tuple(rep_aval.sharding.spec) == (None,) * rep_aval.ndim if x.dtype != rep_aval.dtype: raise ValueError( "can only convert to extended dtype from its representation dtype, " @@ -4774,7 +5089,18 @@ def _to_edtype_abstract_eval(x, *, edtype): f" has a representation shape {rep_aval.shape} while the given " f"representation array has shape {x.shape}, so the shape suffix " f"does not match: given {shape_suffix} but required {rep_aval.shape}.") - return x.update(shape=shape_prefix, dtype=edtype) + if isinstance(x, ShapedArray): + spec_prefix, spec_suffix = x.sharding.spec[:n], x.sharding.spec[n:] + if tuple(spec_suffix) != (None,) * len(spec_suffix): + raise ValueError( + "can only convert to extended dtype from an array with trailing " + "axes that are not explicitly sharded, but tried to convert from " + f"{x.str_short(short_dtypes=True)} to an extended dtype with element " + f"shape {rep_aval.shape}") + return x.update(shape=shape_prefix, dtype=edtype, + sharding=x.sharding.update(spec=spec_prefix)) + else: + assert False # unreachable, see isinstance check above to_edtype_p = Primitive('to_edtype') to_edtype_p.def_impl(partial(dispatch.apply_primitive, to_edtype_p)) @@ -4791,6 +5117,9 @@ def _to_edtype_abstract_eval(x, *, edtype): def _from_edtype_abstract_eval(x, *, dtype): assert (isinstance(x.dtype, dtypes.ExtendedDType) and not isinstance(dtype, dtypes.ExtendedDType)) + if not isinstance(x, ShapedArray): + raise TypeError("can only convert from an extended dtype on an array type," + f"but got {type(x)}") if convert_from := getattr(x.dtype._rules, 'convert_from', None): allow_conversion = convert_from(x.dtype, dtype) else: @@ -4800,16 +5129,17 @@ def _from_edtype_abstract_eval(x, *, dtype): f"Cannot convert_element_type from {dtype_to_string(x.dtype)} " f"to {dtype_to_string(dtype)}") rep_aval = core.physical_element_aval(x.dtype) + assert tuple(rep_aval.sharding.spec) == (None,) * rep_aval.ndim if rep_aval.dtype != dtype: raise ValueError( "can only convert from extended dtype to its representation dtype, " f"but tried to convert from {dtype_to_string(x.dtype)} to " f"{dtype_to_string(dtype)} which doesn't match the representation type " f"{dtype_to_string(rep_aval.dtype)}.") - if all(isinstance(d, int) for d in x.shape): - return core.ShapedArray(shape=(*x.shape, *rep_aval.shape), dtype=dtype) + if isinstance(x, ShapedArray): + return x.update(shape=(*x.shape, *rep_aval.shape), dtype=dtype) else: - raise NotImplementedError + assert False # unreachable, see isinstance check above from_edtype_p = Primitive('from_edtype') from_edtype_p.def_impl(partial(dispatch.apply_primitive, from_edtype_p)) @@ -4824,11 +5154,10 @@ def _from_edtype_abstract_eval(x, *, dtype): def _bitcast_convert_type_shape_rule(operand, *, new_dtype): - old_dtype = dtypes.canonicalize_dtype(operand.dtype) - new_dtype = dtypes.canonicalize_dtype(new_dtype) + old_dtype = operand.dtype - old_nbits = dtypes.bit_width(old_dtype) - new_nbits = dtypes.bit_width(new_dtype) + old_nbits = dtypes.itemsize_bits(old_dtype) + new_nbits = dtypes.itemsize_bits(new_dtype) if old_nbits == new_nbits: return operand.shape @@ -4845,22 +5174,20 @@ def _bitcast_convert_type_shape_rule(operand, *, new_dtype): return operand.shape[:-1] def _bitcast_convert_type_sharding_rule(operand, *, new_dtype): - old_dtype = dtypes.canonicalize_dtype(operand.dtype) - new_dtype = dtypes.canonicalize_dtype(new_dtype) + old_dtype = operand.dtype - old_nbits = dtypes.bit_width(old_dtype) - new_nbits = dtypes.bit_width(new_dtype) + old_nbits = dtypes.itemsize_bits(old_dtype) + new_nbits = dtypes.itemsize_bits(new_dtype) if old_nbits == new_nbits: return operand.sharding elif old_nbits > new_nbits: - return operand.sharding.with_spec((*operand.sharding.spec, None)) + return operand.sharding.update(spec=(*operand.sharding.spec, None)) else: - return operand.sharding.with_spec(operand.sharding.spec[:-1]) + return operand.sharding.update(spec=operand.sharding.spec[:-1]) def _bitcast_convert_type_dtype_rule(operand, *, new_dtype): - old_dtype = dtypes.canonicalize_dtype(operand.dtype) - new_dtype = dtypes.canonicalize_dtype(new_dtype) + old_dtype = operand.dtype if (dtypes.issubdtype(old_dtype, np.bool_) or dtypes.issubdtype(old_dtype, np.complexfloating) or dtypes.issubdtype(new_dtype, np.bool_) or @@ -4875,7 +5202,8 @@ def _bitcast_convert_type_dtype_rule(operand, *, new_dtype): bitcast_convert_type_p = standard_primitive( _bitcast_convert_type_shape_rule, _bitcast_convert_type_dtype_rule, 'bitcast_convert_type', weak_type_rule=_strip_weak_type, - sharding_rule=_bitcast_convert_type_sharding_rule) + sharding_rule=_bitcast_convert_type_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'bitcast_convert_type')) ad.defjvp_zero(bitcast_convert_type_p) batching.defvectorized(bitcast_convert_type_p) @@ -4985,7 +5313,7 @@ def _dot_general_shape_computation(lhs_shape, rhs_shape, dimension_numbers): lhs_tensored_shape = tuple_delete(lhs_shape, lhs_contract_or_batch) rhs_group = () if isinstance(dimension_numbers, RaggedDotDimensionNumbers): - rhs_group = tuple(dimension_numbers.rhs_group_dimensions) + rhs_group = tuple(dimension_numbers.rhs_group_dimensions) # pytype: disable=attribute-error rhs_contract_or_batch_or_group = tuple( sorted(tuple(rhs_contracting) + tuple(rhs_batch) + rhs_group) ) @@ -4996,13 +5324,14 @@ def _dot_general_shape_computation(lhs_shape, rhs_shape, dimension_numbers): def _check_specs_match(lhs_spec, rhs_spec, msg): for l, r in zip(lhs_spec, rhs_spec): if l is not None and r is not None and l != r: - raise TypeError(msg) + raise core.ShardingTypeError(msg) def _dot_general_sharding_rule(lhs, rhs, *, dimension_numbers, precision, preferred_element_type: DTypeLike | None, out_sharding): - if lhs.sharding.mesh != rhs.sharding.mesh: - raise ValueError( + if (not lhs.sharding.mesh.empty and not rhs.sharding.mesh.empty and + lhs.sharding.mesh != rhs.sharding.mesh): + raise core.ShardingTypeError( 'Mesh of both lhs and rhs should match. Got lhs:' f' {lhs.sharding.mesh} and rhs: {rhs.sharding.mesh}') @@ -5011,6 +5340,9 @@ def _dot_general_sharding_rule(lhs, rhs, *, dimension_numbers, precision, return out_sharding (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers + lhs_contracting_spec = tuple(lhs.sharding.spec[i] for i in lhs_contracting) + rhs_contracting_spec = tuple(rhs.sharding.spec[i] for i in rhs_contracting) + lhs_batch_spec = tuple(lhs.sharding.spec[i] for i in lhs_batch) rhs_batch_spec = tuple(rhs.sharding.spec[i] for i in rhs_batch) msg = ("dot_general requires lhs batch dimensions and rhs batch dimensions " @@ -5018,23 +5350,24 @@ def _dot_general_sharding_rule(lhs, rhs, *, dimension_numbers, precision, f"{rhs_batch_spec}.") _check_specs_match(lhs_batch_spec, rhs_batch_spec, msg) - lhs_contracting_spec = tuple(lhs.sharding.spec[i] for i in lhs_contracting) - rhs_contracting_spec = tuple(rhs.sharding.spec[i] for i in rhs_contracting) msg = ("dot_general requires contracting dimensions to have consistent " f"sharding, got {lhs_contracting_spec} and {rhs_contracting_spec}.") _check_specs_match(lhs_contracting_spec, rhs_contracting_spec, msg) for l, r in zip(lhs_contracting_spec, rhs_contracting_spec): if l is not None and r is not None: - raise ValueError( + raise core.ShardingTypeError( 'Contracting dimensions are sharded and it is ambiguous how the' ' output should be sharded. Please specify the output sharding via' - ' the `out_sharding` parameter of einsum. Or reshard your input via' - ' `jax.experimental.shard.reshard` so that the dot is conflict free.' + ' the `out_sharding` parameter.' f' Got {lhs_contracting_spec=} and {rhs_contracting_spec=}') + if lhs.sharding.mesh.empty and not rhs.sharding.mesh.empty: + mesh = rhs.sharding.mesh + else: + mesh = lhs.sharding.mesh return _dot_general_sharding_computation( - lhs.sharding.spec, rhs.sharding.spec, dimension_numbers, lhs.sharding.mesh) + lhs.sharding.spec, rhs.sharding.spec, dimension_numbers, mesh) def _dot_general_sharding_computation(lhs_spec, rhs_spec, dimension_numbers, mesh): @@ -5046,6 +5379,35 @@ def _dot_general_sharding_computation(lhs_spec, rhs_spec, rhs_tensored_spec = tuple_delete(rhs_spec, rhs_contract_or_batch) return NamedSharding(mesh, P(*(batch_spec + lhs_tensored_spec + rhs_tensored_spec))) + +def _dot_general_unreduced_rule(out_s, lhs, rhs, *, dimension_numbers, + **kwargs): + if lhs.sharding.spec.unreduced or rhs.sharding.spec.unreduced: + raise core.ShardingTypeError( + f'lhs or rhs passed to dot_general cannot be unreduced. Got {lhs=} and' + f' {rhs=}') + if out_s.spec.unreduced: + (lhs_contracting, rhs_contracting), _ = dimension_numbers + lhs_contracting_spec = tuple(lhs.sharding.spec[i] for i in lhs_contracting) + rhs_contracting_spec = tuple(rhs.sharding.spec[i] for i in rhs_contracting) + if lhs_contracting_spec != rhs_contracting_spec: + raise core.ShardingTypeError( + 'lhs and rhs contracting dims should be sharded identically when' + ' out_sharding provided to dot_general mentions unreduced_axes.' + f' Got {out_s=}, {lhs_contracting_spec=},' + f' {rhs_contracting_spec=}') + flat_spec = [s for s in flatten_spec(lhs_contracting_spec) if s is not None] + if out_s.spec.unreduced != frozenset(flat_spec): + raise core.ShardingTypeError( + "out_sharding's unreduced axes should be equal to the contracting" + f' specs. Got unreduced axes={out_s.spec.unreduced} and' + f' contracting spec={lhs_contracting_spec}') + return out_s + +def _dot_general_reduced_rule(out_s, lhs, rhs, *, dimension_numbers, **kwargs): + return out_s + + def tuple_delete(tup, idx): idx_ = set(idx) return tuple(tup[i] for i in range(len(tup)) if i not in idx_) @@ -5114,15 +5476,16 @@ def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision, x_contract_sorted_by_y = list(np.take(x_contract, np.argsort(y_contract))) unsorted_axes = list(x_batch) + x_kept + x_contract_sorted_by_y out_axes = np.argsort(unsorted_axes) - xs = x.aval.sharding + xs = x.aval.to_cotangent_aval().sharding inverse_spec = tuple(xs.spec[o] for o in unsorted_axes) - ds = xs.with_spec(inverse_spec) + ds = xs.update(spec=xs.spec.update(partitions=inverse_spec)) dot_general_out = dot_general(g, y, dims, precision=precision, preferred_element_type=preferred_element_type, out_sharding=ds) x_bar = transpose(dot_general_out, tuple(out_axes)) if x_bar.dtype != x.aval.dtype: - x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type) + x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type, + warn_on_complex_to_real_cast=False) return x_bar def _dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision, @@ -5130,13 +5493,10 @@ def _dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision, out_sharding): (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch)) - y_bar = _dot_general_transpose_lhs( + return _dot_general_transpose_lhs( g, y, x, dimension_numbers=swapped_dimension_numbers, precision=precision, preferred_element_type=preferred_element_type, out_sharding=out_sharding, swap_ans=True) - if y_bar.dtype != y.aval.dtype: - y_bar = _convert_element_type(y_bar, y.aval.dtype, y.aval.weak_type) - return y_bar def _dot_batch_rule( @@ -5156,30 +5516,14 @@ def _dot_batch_rule( lhs, rhs = unpack_args(batched_args) lbd, rbd = unpack_dims(batch_dims) - left_stack_dim = lbd.stacked_axis if type(lbd) is RaggedAxis else lbd - right_stack_dim = rbd.stacked_axis if type(rbd) is RaggedAxis else rbd new_dimension_numbers, result_stack_dim = _dot_general_batch_dim_nums( - (np.ndim(lhs), np.ndim(rhs)), (left_stack_dim, right_stack_dim), + (np.ndim(lhs), np.ndim(rhs)), (lbd, rbd), dimension_numbers) - # TODO Should probably check that any ragged dimensions have corresponding - # sizes, because otherwise the dot product is technically undefined. - # - # This masking is not strictly necessary for non-contraction dimensions; - # we could micro-optimize here by avoiding computing that mask. - if type(lbd) is RaggedAxis: - lhs = batching.mask_ragged_axes(lhs, _get_sum_identity, lbd) - lhs_shape = batching.bdim_as_shape(lbd, lhs.shape) - else: - lhs_shape = np.shape(lhs) - if type(rbd) is RaggedAxis: - rhs = batching.mask_ragged_axes(rhs, _get_sum_identity, rbd) - rhs_shape = batching.bdim_as_shape(rbd, rhs.shape) - else: - rhs_shape = np.shape(rhs) - result_batch_dim = batching.shape_as_bdim( - result_stack_dim, - _dot_general_shape_computation(lhs_shape, rhs_shape, new_dimension_numbers)) + lhs_shape = np.shape(lhs) + rhs_shape = np.shape(rhs) + result_shape = _dot_general_shape_computation(lhs_shape, rhs_shape, new_dimension_numbers) + result_batch_dim = canonicalize_axis(result_stack_dim, len(result_shape)) if out_sharding is not None: out_sharding = batching.get_sharding_for_vmap( @@ -5267,15 +5611,6 @@ def bump_dims(dims, b): ) return new_dimension_numbers, result_batch_dim -def _dot_general_padding_rule(in_avals, out_avals, lhs, rhs, *, - dimension_numbers, **params): - lhs_aval, _ = in_avals - (lhs_contract, _), _ = dimension_numbers - padded_axes = [(i, lhs_aval.shape[i].val) for i in lhs_contract - if isinstance(lhs_aval.shape[i], pe.BoundedAxisSize)] - lhs_ = _replace_masked_values(lhs, 0, padded_axes) - return [dot_general(lhs_, rhs, dimension_numbers=dimension_numbers, **params)] - def _dot_general_pp_rule(eqn, context, settings) -> pp.Doc: # * suppress printing precision or preferred_element_type when None. # * print dimension_numbers as list-of-lists to be shorter. @@ -5286,64 +5621,14 @@ def _dot_general_pp_rule(eqn, context, settings) -> pp.Doc: return core._pp_eqn(eqn.replace(params=printed_params), context, settings) -def _dot_general_ragged_prop_rule(eqn_params, invar_raggedness, outvars): - assert len(invar_raggedness) == 2 - assert len(outvars) == 1 - invar_raggedness_lhs = invar_raggedness[0] - invar_raggedness_rhs = invar_raggedness[1] - - dimension_numbers = eqn_params['dimension_numbers'] - (lhs_contracting, rhs_contracting), (_, _) = dimension_numbers - - if not invar_raggedness_lhs and not invar_raggedness_rhs: - # Both are dense - it is valid to reach here, because dense operations - # are legal in code running under ragged prop. - return invar_raggedness, [None] - - if not invar_raggedness_lhs or not invar_raggedness_rhs: - # One ragged, one dense - if not invar_raggedness_lhs: - # left is dense, right is ragged - _, ragged_axis_dim_rhs, _, _ = invar_raggedness_rhs - if rhs_contracting != ragged_axis_dim_rhs: - # Contraction is on a dense dimension, this is valid! - return invar_raggedness, [None] - if not invar_raggedness_rhs: - # left is ragged, right is dense - _, ragged_axis_dim_lhs, _, _ = invar_raggedness_lhs - if lhs_contracting != ragged_axis_dim_lhs: - # Contraction is on a dense dimension, this is valid! - return invar_raggedness, [None] - - raise NotImplementedError('NYI - dense and ragged dim contraction') - - stacked_axis_lhs, ragged_axis_dim_lhs, _, _ = invar_raggedness_lhs - stacked_axis_rhs, ragged_axis_dim_rhs, _, _ = invar_raggedness_rhs - - if stacked_axis_rhs != 0 or stacked_axis_lhs != 0: - raise NotImplementedError( - 'Dot general ragged prop for non 0 stacked axis, NYI' - ) - - # We only support ragged k atm, that is, lhs is (m, ragged_k) and rhs is - # (ragged_k, n), meaning the output is dense. - if ragged_axis_dim_lhs != 2 or ragged_axis_dim_rhs != 1: - raise NotImplementedError( - 'Dot general ragged prop for non contraction raggedness, NYI' - ) - - assert len(outvars) == 1 - - # TODO(mvoz): A constant on batching.* ? - # Dense (m, n) - no jumble only atm - return invar_raggedness, [None] - - dot_general_p = standard_primitive( _dot_general_shape_rule, _dot_general_dtype_rule, 'dot_general', sharding_rule=_dot_general_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'dot_general'), + unreduced_rule=_dot_general_unreduced_rule, + reduced_rule=_dot_general_reduced_rule, ) @@ -5367,19 +5652,28 @@ def _dot_general_batch_unpack_dims(batch_dims): ) batching.fancy_primitive_batchers[dot_general_p] = _dot_general_batch_rule batching.skippable_batchers[dot_general_p] = lambda _: () -pe.padding_rules[dot_general_p] = _dot_general_padding_rule core.pp_eqn_rules[dot_general_p] = _dot_general_pp_rule -batching.ragged_prop_rules[dot_general_p] = _dot_general_ragged_prop_rule -def precision_attr(precision: Precision) -> ir.ArrayAttr: + +def _full_precision(precision: Precision) -> tuple[Precision, Precision]: if precision is None or isinstance(precision, (DotAlgorithm, DotAlgorithmPreset)): - full_precision = (Precision.DEFAULT, Precision.DEFAULT) + return (Precision.DEFAULT, Precision.DEFAULT) elif not isinstance(precision, tuple): - full_precision = (precision, precision) + return (precision, precision) else: - full_precision = precision + return precision + + +def precision_attr(precision: Precision) -> ir.ArrayAttr: return ir.ArrayAttr.get( - [hlo.PrecisionAttr.get(str(p)) for p in full_precision]) + [hlo.PrecisionAttr.get(str(p)) for p in _full_precision(precision)] + ) + + +def chlo_precision_attr(precision: Precision) -> ir.ArrayAttr: + return ir.ArrayAttr.get( + [chlo.PrecisionAttr.get(str(p)) for p in _full_precision(precision)] + ) def dot_algorithm_attr(precision: CanonicalPrecision, lhs_dtype: DTypeLike, @@ -5417,32 +5711,30 @@ def maybe_convert_dtype(input_dtype, target_dtypes): return lhs_dtype, rhs_dtype, out_type -def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers, - precision, preferred_element_type: np.dtype | None, - out_sharding, platform: str = "default"): +def accuracy_attr(accuracy) -> hlo.ResultAccuracyAttr: + if isinstance(accuracy, AccuracyMode): + return hlo.ResultAccuracyAttr.get(0.0, 0.0, int(0), str(accuracy.name)) + elif isinstance(accuracy, Tolerance): + return hlo.ResultAccuracyAttr.get( + atol=accuracy.atol, + rtol=accuracy.rtol, + ulps=accuracy.ulps, + mode='TOLERANCE', + ) + +def _handle_dot_precision(ctx, lhs, rhs, precision, platform): def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes): fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2, - dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz) - if dtypes.float8_e3m4 is not None: - fp8_dtypes += (dtypes.float8_e3m4,) - if dtypes.float8_e4m3 is not None: - fp8_dtypes += (dtypes.float8_e4m3,) - if dtypes.float8_e8m0fnu is not None: - fp8_dtypes += (dtypes.float8_e8m0fnu,) + dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz, + dtypes.float8_e3m4, dtypes.float8_e4m3, + dtypes.float8_e8m0fnu) return _lhs_dtypes in fp8_dtypes and _rhs_dtypes in fp8_dtypes - del preferred_element_type # Implied by the output aval - lhs_aval, rhs_aval = ctx.avals_in + + # The *_ lets us reuse this for ragged_dot_general, which has group_sizes. + lhs_aval, rhs_aval, *_ = ctx.avals_in lhs_dtype, rhs_dtype = lhs_aval.dtype, rhs_aval.dtype aval_out, = ctx.avals_out accumulation_aval = aval_out - (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers - - dot_dnums = hlo.DotDimensionNumbers.get( - lhs_batching_dimensions=list(lhs_batch), - rhs_batching_dimensions=list(rhs_batch), - lhs_contracting_dimensions=list(lhs_contracting), - rhs_contracting_dimensions=list(rhs_contracting)) - algorithm_kwarg = {} if isinstance(precision, (DotAlgorithm, DotAlgorithmPreset)): # The CPU backend silently ignores the algorithm spec, so we check here to @@ -5500,7 +5792,22 @@ def maybe_convert_dtype(operand, operand_aval, target_dtype): core.ShapedArray(lhs_aval.shape, aval_out.dtype)) rhs = mlir.convert_hlo(ctx, rhs, rhs_aval, core.ShapedArray(rhs_aval.shape, aval_out.dtype)) + return lhs, rhs, accumulation_aval, algorithm_kwarg + +def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers, + precision, preferred_element_type: np.dtype | None, + out_sharding, platform: str = "default"): + del preferred_element_type # Implied by the output aval + lhs, rhs, accumulation_aval, algorithm_kwarg = _handle_dot_precision( + ctx, lhs, rhs, precision, platform + ) + (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers + dot_dnums = hlo.DotDimensionNumbers.get( + lhs_batching_dimensions=list(lhs_batch), + rhs_batching_dimensions=list(rhs_batch), + lhs_contracting_dimensions=list(lhs_contracting), + rhs_contracting_dimensions=list(rhs_contracting)) result = hlo.dot_general( mlir.aval_to_ir_type(accumulation_aval), lhs, @@ -5509,7 +5816,7 @@ def maybe_convert_dtype(operand, operand_aval, target_dtype): precision_config=precision_attr(precision), **algorithm_kwarg, ) - + aval_out, = ctx.avals_out result = mlir.lower_with_sharding_in_types(ctx, result, aval_out) if accumulation_aval.dtype != aval_out.dtype: result = mlir.convert_hlo(ctx, result, accumulation_aval, aval_out) @@ -5805,7 +6112,7 @@ def grad_x_dims(): unsorted_axes = list(x_batch) + x_kept + x_contract_sorted_by_y case RaggedDotMode.RAGGED_CONTRACTING | RaggedDotMode.RAGGED_BATCH: raise unimplemented('grad_x_dims', mode) - return dims, unsorted_axes + return dims, unsorted_axes # pytype: disable=name-error def grad_y_dims(): match mode: @@ -5824,7 +6131,7 @@ def grad_y_dims(): ) case RaggedDotMode.RAGGED_CONTRACTING | RaggedDotMode.RAGGED_BATCH: raise unimplemented('grad_y_dims', mode) - return dims, unsorted_axes + return dims, unsorted_axes # pytype: disable=name-error def _ragged_dot_grad(lhs, rhs, dims_fn, aval): dims, unsorted_axes = dims_fn() @@ -5918,6 +6225,7 @@ def _ragged_dot_general_batch_rule( _ragged_dot_general_shape_rule, _ragged_dot_general_dtype_rule, 'ragged_dot_general', + vma_rule=partial(core.standard_vma_rule, 'ragged_dot') ) ad.primitive_jvps[ragged_dot_general_p] = _ragged_dot_general_jvp_rule ad.primitive_transposes[ragged_dot_general_p] = _ragged_dot_general_transpose_rule @@ -6025,13 +6333,80 @@ def expand(x, dim, gs, *axes): lhs, rhs, dimension_numbers=ragged_dot_dimension_numbers.dot_dimension_numbers, - ) + ) # pytype: disable=bad-return-type + + +def _ragged_dot_general_lower( + ctx, + lhs, + rhs, + group_sizes, + *, + ragged_dot_dimension_numbers, + precision, + preferred_element_type: np.dtype | None, + group_offset: Array | None = None, + platform: str = 'default', +): + if group_offset is not None: + raise NotImplementedError('Unimplemented group_offset support.') + + if not config.jax_ragged_dot_use_ragged_dot_instruction.value: + result = mlir.lower_fun(_ragged_dot_general_impl, multiple_results=False)( + ctx, lhs, rhs, group_sizes, + ragged_dot_dimension_numbers=ragged_dot_dimension_numbers, + precision=precision, + preferred_element_type=preferred_element_type, + group_offset=group_offset + ) + (aval_out,) = ctx.avals_out + return mlir.lower_with_sharding_in_types(ctx, result, aval_out) + + del preferred_element_type # Implied by the output aval + lhs, rhs, accumulation_aval, _ = _handle_dot_precision( + ctx, lhs, rhs, precision, platform + ) + (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = ( + ragged_dot_dimension_numbers.dot_dimension_numbers + ) + ragged_dot_dnums = chlo.RaggedDotDimensionNumbers.get( + lhs_batching_dimensions=list(lhs_batch), + rhs_batching_dimensions=list(rhs_batch), + lhs_contracting_dimensions=list(lhs_contracting), + rhs_contracting_dimensions=list(rhs_contracting), + lhs_ragged_dimensions=list( + ragged_dot_dimension_numbers.lhs_ragged_dimensions + ), + rhs_group_dimensions=list( + ragged_dot_dimension_numbers.rhs_group_dimensions + ), + ) + result = chlo.ragged_dot( + mlir.aval_to_ir_type(accumulation_aval), + lhs, + rhs, + group_sizes, + ragged_dot_dnums, + precision_config=chlo_precision_attr(precision), + ) + (aval_out,) = ctx.avals_out + result = mlir.lower_with_sharding_in_types(ctx, result, aval_out) + if accumulation_aval.dtype != aval_out.dtype: + result = mlir.convert_hlo(ctx, result, accumulation_aval, aval_out) + return [result] mlir.register_lowering(ragged_dot_general_p, mlir.lower_fun(_ragged_dot_general_impl, multiple_results=False)) +for platform in ['tpu', 'gpu']: + mlir.register_lowering( + ragged_dot_general_p, + partial(_ragged_dot_general_lower, platform=platform), + platform=platform, + ) + def _broadcast_in_dim_shape_rule(operand, *, shape, broadcast_dimensions, sharding): @@ -6077,197 +6452,109 @@ def _broadcast_in_dim_sharding_rule(operand, *, shape, broadcast_dimensions, orig_spec = iter(operand.sharding.spec) new_spec = [next(orig_spec) if i in bds else None for i in range(len(shape))] assert next(orig_spec, None) is None - return operand.sharding.with_spec(new_spec) + mesh = (get_abstract_mesh() if operand.sharding.mesh.empty else + operand.sharding.mesh) + return operand.sharding.update( + mesh=mesh, spec=operand.sharding.spec.update(partitions=new_spec)) def _broadcast_in_dim_typecheck_rule( - _, operand, *dyn_shape, shape, broadcast_dimensions, sharding): - if not dyn_shape: - out_aval, effects = broadcast_in_dim_p.abstract_eval( - operand.aval, shape=shape, broadcast_dimensions=broadcast_dimensions, - sharding=sharding) - return [out_aval], effects - else: - # TODO(mattjj): perform more checks like _broadcast_in_dim_shape_rule - out_shape = _merge_dyn_shape(shape, dyn_shape) - out_shape = [x.val if type(x) is core.Literal else x for x in out_shape] # pytype: disable=attribute-error - out_aval = core.DShapedArray(tuple(out_shape), operand.aval.dtype, - operand.aval.weak_type) - return [out_aval], core.no_effects - -def _broadcast_in_dim_transpose_rule(ct, operand, *dyn_shape, + _, operand, shape, broadcast_dimensions, sharding): + out_aval, effects = broadcast_in_dim_p.abstract_eval( + operand.aval, shape=shape, broadcast_dimensions=broadcast_dimensions, + sharding=sharding) + return [out_aval], effects + +def _broadcast_in_dim_transpose_rule(ct, operand, shape, broadcast_dimensions, sharding): if type(ct) is ad_util.Zero: return [ad_util.Zero(operand.aval)] + if not isinstance(operand, ad.UndefinedPrimal): + return [None] # transpose wrt literal unit_dims = [i for i, s in enumerate(operand.aval.shape) if core.definitely_equal(s, 1)] bdims = tuple(np.delete(broadcast_dimensions, unit_dims)) axes = tuple(np.delete(range(len(shape)), bdims)) - return ([expand_dims(reduce_sum(ct, axes), unit_dims)] + - [None] * len(dyn_shape)) + return [expand_dims(reduce_sum(ct, axes), unit_dims)] def _broadcast_in_dim_batch_rule(axis_data, batched_args, batch_dims, shape, broadcast_dimensions, sharding): - # `dyn_shape` is the dynamic portion of the target shape. `shape` - # is the target shape, with `None` for dynamic sections. - # broadcast_dimensions gives indices where dimensions of the input - # have to go: dimension i of the input becomes dimension - # broadcast_dimensions[i] of the output. - operand, *dyn_shape = batched_args - operand_bdim, *dyn_shape_bdims = batch_dims - - stacked_size = None - if operand_bdim is not None: - if isinstance(operand_bdim, RaggedAxis): - stacked_axis = operand_bdim.stacked_axis - stacked_size = operand_bdim.size - else: - stacked_axis = operand_bdim - stacked_size = operand.shape[stacked_axis] - new_operand = batching.moveaxis(operand, stacked_axis, 0) - new_broadcast_dimensions = (0,) + tuple(np.add(1, broadcast_dimensions)) - else: - new_operand = operand - new_broadcast_dimensions = tuple(np.add(1, broadcast_dimensions)) - - # TODO(mattjj,axch) This section assumes that the shape of the operand is - # broadcast-compatible with the requested shape. We should tweak vmap to run - # the abstract_eval rule so this can be checked while the raggedness - # information is available. - dyn_limits = [] - out_ragged_sizes = [] - for sizes, bdim in zip(dyn_shape, dyn_shape_bdims): - if bdim is None: - # TODO(mattjj,axch) Is this what bdim == None means? - assert isinstance(sizes, int) - bound = sizes - else: - bound = sizes.dtype.bound - out_ragged_sizes.append(sizes) - if stacked_size is None: - stacked_size = len(sizes) - else: - msg = "All segments lengths arrays must be the same length" - assert len(sizes) == stacked_size, msg - dyn_limits.append(bound) - new_shape = (stacked_size,) + _merge_dyn_shape(shape, dyn_limits) + # `shape` is the target shape. broadcast_dimensions gives indices where + # dimensions of the input have to go: dimension i of the input becomes + # dimension broadcast_dimensions[i] of the output. + operand, = batched_args + operand_bdim, = batch_dims + assert operand_bdim is not None + new_operand = batching.moveaxis(operand, operand_bdim, 0) + new_broadcast_dimensions = (0,) + tuple(np.add(1, broadcast_dimensions)) + new_shape = (operand.shape[operand_bdim],) + shape if sharding is not None: sharding = batching.get_sharding_for_vmap(axis_data, sharding, 0) result = broadcast_in_dim(new_operand, new_shape, new_broadcast_dimensions, out_sharding=sharding) - out_ragged_axes = [idx+1 for idx, s in enumerate(shape) if s is None] - out_bdim = batching.make_batch_axis( - result.ndim, 0, zip(out_ragged_axes, out_ragged_sizes)) - return result, out_bdim + return result, 0 def _broadcast_in_dim_fwd_rule(eqn): - v, *dyn = eqn.invars - if not dyn and core.definitely_equal_shape(eqn.params['shape'], v.aval.shape): - return [v], None + v, = eqn.invars + if (core.definitely_equal_shape(eqn.params['shape'], v.aval.shape) + and (eqn.params['sharding'] is None or + eqn.params['sharding'] == v.aval.sharding)): + return [0], None else: return [None], eqn def _broadcast_in_dim_staging_rule( - trace, x, *dyn, shape, broadcast_dimensions, sharding): + trace, source_info, x, shape, broadcast_dimensions, sharding): params = dict(shape=shape, broadcast_dimensions=broadcast_dimensions, sharding=sharding) - if not dyn: - return trace.default_process_primitive(broadcast_in_dim_p, (x,), params) - aval = core.DShapedArray(_merge_dyn_shape(shape, dyn), x.dtype, x.weak_type) - return _dyn_shape_staging_rule(trace, broadcast_in_dim_p, aval, x, *dyn, - **params) - -def _broadcast_in_dim_padding_rule(in_avals, out_avals, x, *dyn_shape, - shape, broadcast_dimensions): - del in_avals, dyn_shape - out_aval, = out_avals - new_shape = [] - new_dyn_shape = [] - for d in out_aval.shape: - if type(d) is pe.BoundedAxisSize: - new_shape.append(d.bound) - elif type(d) is int: - new_shape.append(d) - else: - assert isinstance(d, core.Tracer) - new_shape.append(None) - new_dyn_shape.append(d) - return [broadcast_in_dim_p.bind(x, *new_dyn_shape, shape=tuple(new_shape), - broadcast_dimensions=broadcast_dimensions)] + return trace.default_process_primitive(broadcast_in_dim_p, (x,), params, + source_info=source_info) def _broadcast_in_dim_jvp_rule(primals, tangents, *, shape, broadcast_dimensions, sharding): - operand, *dyn_shape = primals + operand, = primals operand_dot, *_ = tangents - y = broadcast_in_dim_p.bind(operand, *dyn_shape, shape=shape, + y = broadcast_in_dim_p.bind(operand, shape=shape, broadcast_dimensions=broadcast_dimensions, sharding=sharding) if type(operand_dot) is ad_util.Zero: y_dot = ad_util.Zero.from_primal_value(y) else: - y_dot = broadcast_in_dim_p.bind(operand_dot, *dyn_shape, shape=shape, + y_dot = broadcast_in_dim_p.bind(operand_dot, shape=shape, broadcast_dimensions=broadcast_dimensions, sharding=sharding) return y, y_dot def _broadcast_in_dim_partial_eval( - trace, operand, *dyn_shape, shape, broadcast_dimensions, sharding): - if not dyn_shape: - return trace.default_process_primitive( - broadcast_in_dim_p, (operand, *dyn_shape), - dict(shape=shape, broadcast_dimensions=broadcast_dimensions, - sharding=sharding)) - assert all(t.pval.is_known() for t in dyn_shape) - operand_tracer = trace.instantiate_const(operand) - dyn_shape_tracers = map(trace.instantiate_const, dyn_shape) - dyn_shape_tracers_ = iter(dyn_shape_tracers) - shape_ = [next(dyn_shape_tracers_) if d is None else d for d in shape] - out_aval = core.DShapedArray(tuple(shape_), operand.dtype, operand.weak_type) - out_tracer = pe.JaxprTracer(trace, pe.PartialVal.unknown(out_aval), None) - eqn = pe.new_eqn_recipe( - [operand_tracer, *dyn_shape_tracers], [out_tracer], broadcast_in_dim_p, + trace, operand, shape, broadcast_dimensions, sharding): + return trace.default_process_primitive( + broadcast_in_dim_p, (operand,), dict(shape=shape, broadcast_dimensions=broadcast_dimensions, - sharding=None), - core.no_effects, source_info_util.current()) - out_tracer.recipe = eqn - return out_tracer + sharding=sharding)) -def _broadcast_in_dim_lower(ctx, x, *dyn_shape, shape, broadcast_dimensions, +def _broadcast_in_dim_lower(ctx, x, shape, broadcast_dimensions, sharding) -> Sequence[ir.Value]: aval_out, = ctx.avals_out - if dyn_shape: - aval_out = aval_out.update(shape=_merge_dyn_shape(shape, dyn_shape)) out = mlir.broadcast_in_dim(ctx, x, aval_out, broadcast_dimensions=broadcast_dimensions) return [mlir.lower_with_sharding_in_types(ctx, out, aval_out)] -def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions, +def _broadcast_in_dim_abstract_eval(x, shape, broadcast_dimensions, sharding): - if (not dyn_shape and - not any(isinstance(d, core.DArray) and - type(core.get_aval(d).dtype) is core.bint for d in shape)): - shape = _broadcast_in_dim_shape_rule( # error checking - x, shape=shape, broadcast_dimensions=broadcast_dimensions, sharding=None) - new_sharding = _broadcast_in_dim_sharding_rule( - x, shape=shape, broadcast_dimensions=broadcast_dimensions, - sharding=sharding) - return core.ShapedArray(shape, x.dtype, x.weak_type, sharding=new_sharding) - # If any BInts in shape, or Tracers in dyn_shape, produce a DShapedArray - # (even if x is a ShapedArray) - # TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code - return core.DShapedArray(_merge_dyn_shape(shape, dyn_shape), x.dtype, x.weak_type) - - -def _broadcast_in_dim_ragged_prop_rule(eqn_params, invar_raggedness, outvars): - assert len(invar_raggedness) == 1 - assert not isinstance(invar_raggedness[0], core.Var) - return invar_raggedness, [None] * len(outvars) + shape = _broadcast_in_dim_shape_rule( # error checking + x, shape=shape, broadcast_dimensions=broadcast_dimensions, sharding=None) + new_sharding = _broadcast_in_dim_sharding_rule( + x, shape=shape, broadcast_dimensions=broadcast_dimensions, + sharding=sharding) + new_vma = core.standard_vma_rule('broadcast_in_dim', x) + return core.ShapedArray(shape, x.dtype, x.weak_type, sharding=new_sharding, + vma=new_vma, memory_space=x.memory_space) -broadcast_in_dim_p = standard_primitive( - _broadcast_in_dim_shape_rule, _input_dtype, 'broadcast_in_dim') +broadcast_in_dim_p = core.Primitive('broadcast_in_dim') broadcast_in_dim_p.def_abstract_eval(_broadcast_in_dim_abstract_eval) +broadcast_in_dim_p.def_impl(partial(dispatch.apply_primitive, broadcast_in_dim_p)) ad.primitive_jvps[broadcast_in_dim_p] = _broadcast_in_dim_jvp_rule ad.primitive_transposes[broadcast_in_dim_p] = _broadcast_in_dim_transpose_rule batching.fancy_primitive_batchers[broadcast_in_dim_p] = _broadcast_in_dim_batch_rule @@ -6275,12 +6562,61 @@ def _broadcast_in_dim_ragged_prop_rule(eqn_params, invar_raggedness, outvars): pe.forwarding_rules[broadcast_in_dim_p] = _broadcast_in_dim_fwd_rule pe.custom_partial_eval_rules[broadcast_in_dim_p] = _broadcast_in_dim_partial_eval pe.custom_staging_rules[broadcast_in_dim_p] = _broadcast_in_dim_staging_rule -pe.padding_rules[broadcast_in_dim_p] = _broadcast_in_dim_padding_rule core.custom_typechecks[broadcast_in_dim_p] = _broadcast_in_dim_typecheck_rule mlir.register_lowering(broadcast_in_dim_p, _broadcast_in_dim_lower) -batching.ragged_prop_rules[broadcast_in_dim_p] = ( - _broadcast_in_dim_ragged_prop_rule -) + + +def _tile_lower(ctx, x, reps) -> Sequence[ir.Value]: + aval_out, = ctx.avals_out + x_aval, = ctx.avals_in + expand_shape = tuple(j for i in x_aval.shape for j in [1, i]) + expand_sharding = NamedSharding( + x_aval.sharding.mesh.abstract_mesh, + P(*tuple(s for d in x_aval.sharding.spec for s in [None, d])), + ) + reshaped_aval = x_aval.update(shape=expand_shape, sharding=expand_sharding) + reshaped = mlir.reshape(ctx, x, reshaped_aval) + reshaped = mlir.lower_with_sharding_in_types(ctx, reshaped, reshaped_aval) + broadcast_shape = tuple(k for pair in zip(reps, x_aval.shape) for k in pair) + broadcasted_aval = x_aval.update( + shape=broadcast_shape, sharding=expand_sharding) + broadcasted = mlir.broadcast_in_dim(ctx, reshaped, + broadcasted_aval, broadcast_dimensions=tuple(range(2 * x_aval.ndim))) + broadcasted = mlir.lower_with_sharding_in_types( + ctx, broadcasted, broadcasted_aval) + out = mlir.reshape(ctx, broadcasted, aval_out) + return [mlir.lower_with_sharding_in_types(ctx, out, aval_out)] + +def _tile_abstract_eval(x, reps): + if x.ndim != len(reps): + raise TypeError( + f"reps length must be equal to the ndim of x, got {len(reps)=} " + f"and {x.ndim=}.") + return x.update(shape=tuple(np.multiply(x.shape, reps))) + +def _tile_transpose_rule(ct, operand, *, reps): + if type(ct) is ad_util.Zero: + return [ad_util.Zero(operand.aval)] + if not isinstance(operand, ad.UndefinedPrimal): + return [None] # transpose wrt literal + ct_reshaped = reshape( + ct, tuple(k for pair in zip(reps, operand.aval.shape) for k in pair)) + axes = tuple(2*i for i in range(operand.aval.ndim)) + return [reduce_sum(ct_reshaped, axes)] + +def _tile_batch_rule(batched_args, batch_dims, *, reps): + operand, = batched_args + bdim, = batch_dims + new_reps = list(reps) + new_reps.insert(bdim, 1) + return tile(operand, reps=new_reps), bdim + +tile_p = core.Primitive('tile') +tile_p.def_abstract_eval(_tile_abstract_eval) +tile_p.def_impl(partial(dispatch.apply_primitive, tile_p)) +ad.deflinear2(tile_p, _tile_transpose_rule) +batching.primitive_batchers[tile_p] = _tile_batch_rule +mlir.register_lowering(tile_p, _tile_lower) def _clamp_shape_rule(min, operand, max): @@ -6295,7 +6631,7 @@ def _clamp_shape_rule(min, operand, max): def _clamp_sharding_rule(min, operand, max): return operand.sharding -_clamp_dtype_rule = partial(naryop_dtype_rule, _input_dtype, [_any, _any, _any], +_clamp_dtype_rule = partial(naryop_dtype_rule, input_dtype, [_any, _any, _any], 'clamp') def _clamp_batch_rule(batched_args, batch_dims, **params): @@ -6336,7 +6672,8 @@ def _clamp_batch_rule(batched_args, batch_dims, **params): return clamp_p.bind(min, x, max), 0 clamp_p = standard_primitive(_clamp_shape_rule, _clamp_dtype_rule, 'clamp', - sharding_rule=_clamp_sharding_rule) + sharding_rule=_clamp_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'clamp')) ad.defjvp(clamp_p, lambda g, min, operand, max: select(bitwise_and(gt(min, operand), lt(min, max)), @@ -6348,16 +6685,15 @@ def _clamp_batch_rule(batched_args, batch_dims, **params): select(lt(max, operand), g, _zeros(operand))) batching.primitive_batchers[clamp_p] = _clamp_batch_rule mlir.register_lowering(clamp_p, partial(_nary_lower_hlo, hlo.clamp)) -pe.def_trivial_padding(clamp_p) def _concatenate_shape_rule(*operands, **kwargs): dimension = kwargs.pop('dimension') if not operands: msg = "concatenate expects at least one operand, got 0." raise TypeError(msg) - if not all(isinstance(operand, UnshapedArray) for operand in operands): + if not all(isinstance(operand, ShapedArray) for operand in operands): msg = "All objects to concatenate must be arrays, got {}." - op = next(op for op in operands if not isinstance(op, UnshapedArray)) + op = next(op for op in operands if not isinstance(op, ShapedArray)) raise TypeError(msg.format(type(op))) if len({operand.ndim for operand in operands}) != 1: msg = "Cannot concatenate arrays with different numbers of dimensions: got {}." @@ -6384,10 +6720,20 @@ def _concatenate_sharding_rule(*operands, **kwargs): return core.get_cur_mesh_sharding() if not all(s == non_empty_s[0] for s in non_empty_s): ss = ", ".join(str(o.sharding) for o in operands) - raise TypeError( + raise core.ShardingTypeError( f"All operands should have the same sharding. Got shardings {ss}") return non_empty_s[0] +def _concatenate_reduced_rule(out_s, *operands, **kwargs): + reduced_specs = {o.sharding.spec.reduced + for o in operands if o.sharding.spec.reduced} + if len(reduced_specs) > 1: + raise core.ShardingTypeError( + 'All operands should be reduced along the same mesh axes. Got reduced' + f' specs {reduced_specs}') + reduced_s, = reduced_specs if reduced_specs else (frozenset(),) + return out_s.update(spec=out_s.spec.update(reduced=reduced_s)) + def _concatenate_dtype_rule(*operands, **kwargs): check_same_dtypes('concatenate', *operands) return operands[0].dtype @@ -6409,25 +6755,19 @@ def _concatenate_batch_rule(batched_args, batch_dims, *, dimension): for op, bdim in zip(batched_args, batch_dims) if bdim is not None) operands = [batching.moveaxis(op, bdim, 0) if bdim is not None else broadcast( - op, (size,), out_sharding=core.get_aval(op).sharding.with_spec( + op, (size,), out_sharding=core.get_aval(op).sharding.update(spec= (spec, *core.get_aval(op).sharding.spec))) for op, bdim in zip(batched_args, batch_dims)] return concatenate(operands, dimension + 1), 0 -def _concatenate_pad_rule(in_avals, out_avals, *operands, dimension): - if all(isinstance(a.shape[dimension], (int, np.integer)) - for a in in_avals): - return [concatenate(operands, dimension)] - else: - raise NotImplementedError # TODO(mattjj) - concatenate_p = standard_primitive( _concatenate_shape_rule, _concatenate_dtype_rule, 'concatenate', - sharding_rule=_concatenate_sharding_rule) + sharding_rule=_concatenate_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'concatenate'), + reduced_rule=_concatenate_reduced_rule) ad.deflinear2(concatenate_p, _concatenate_transpose_rule) ad.primitive_transposes[concatenate_p] = _concatenate_transpose_rule batching.primitive_batchers[concatenate_p] = _concatenate_batch_rule -pe.padding_rules[concatenate_p] = _concatenate_pad_rule def _concatenate_lower(ctx, *xs, dimension): aval_out, = ctx.avals_out @@ -6461,10 +6801,8 @@ def _split_transpose_rule(cotangents, operand, *, sizes, axis): assert ad.is_undefined_primal(operand) if all(type(t) is ad_util.Zero for t in cotangents): return ad_util.Zero(operand.aval), - cotangents = [ - _zeros(t.aval) if type(t) is ad_util.Zero else t - for t in cotangents - ] + cotangents = [t.instantiate() if type(t) is ad_util.Zero else t + for t in cotangents] return concatenate(cotangents, dimension=axis), def _split_batch_rule(batched_args, batch_dims, *, sizes, axis): @@ -6495,11 +6833,17 @@ def _split_sharding_rule(operand, *, sizes, axis): return [slicing._get_sharding_for_varying_out_shape(out_sh, operand, 'split') for out_sh in out_shapes] +def _split_vma_rule(operand, *, sizes, axis): + out_vma = core.standard_vma_rule('split', operand) + out_shapes = _split_shape_rule(operand, sizes=sizes, axis=axis) + return [out_vma] * len(out_shapes) + split_p = core.Primitive('split') split_p.multiple_results = True split_p.def_abstract_eval( partial(standard_multi_result_abstract_eval, split_p, _split_shape_rule, - _split_dtype_rule, _split_weak_type_rule, _split_sharding_rule)) + _split_dtype_rule, _split_weak_type_rule, _split_sharding_rule, + _split_vma_rule)) split_p.def_impl(partial(dispatch.apply_primitive, split_p)) ad.deflinear2(split_p, _split_transpose_rule) batching.primitive_batchers[split_p] = _split_batch_rule @@ -6510,7 +6854,7 @@ def _pad_dtype_rule(operand, padding_value, *, padding_config): msg = "pad operand and padding_value must be same dtype: got {} and {}." raise TypeError(msg.format(operand.dtype, padding_value.dtype)) - return _input_dtype(operand, padding_value) + return input_dtype(operand, padding_value) def _pad_shape_rule(operand, padding_value, *, padding_config): if np.ndim(padding_value) != 0: @@ -6581,7 +6925,8 @@ def _pad_batch_rule(batched_args, batch_dims, *, padding_config): return select(mask, x, broadcasted_padding), operand_bdim pad_p = standard_primitive(_pad_shape_rule, _pad_dtype_rule, 'pad', - sharding_rule=_pad_sharding_rule) + sharding_rule=_pad_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'pad')) ad.deflinear2(pad_p, _pad_transpose) batching.primitive_batchers[pad_p] = _pad_batch_rule @@ -6604,7 +6949,6 @@ def _pad_lower(ctx, x, padding_value, *, padding_config): # JAXpr is that we are reshaping from (1, 1) to (1,). # In contrast, squeeze[ dimensions=(0,) ] is unambiguous. - def _squeeze_dtype_rule(operand, *, dimensions): return operand.dtype @@ -6615,7 +6959,12 @@ def _squeeze_sharding_rule(operand, *, dimensions): dims_set = set(dimensions) new_spec = tuple(s for i, s in enumerate(operand.sharding.spec) if i not in dims_set) - return operand.sharding.with_spec(new_spec) + return operand.sharding.update( + spec=operand.sharding.spec.update(partitions=new_spec)) + +def _squeeze_reduced_rule(out_s, operand, *, dimensions): + return out_s.update(spec=out_s.spec.update( + reduced=operand.sharding.spec.reduced)) def _compute_squeeze_shape(shape, dimensions): dims_set = set(dimensions) @@ -6636,20 +6985,20 @@ def _squeeze_transpose_rule(t, operand, *, dimensions): def _squeeze_batch_rule(batched_args, batch_dims, *, dimensions): operand, = batched_args bdim, = batch_dims - operand, bdim = batching.move_stacked_axis(operand, bdim, 0) + operand = batching.moveaxis(operand, bdim, 0) dimensions = tuple(np.add(1, dimensions)) - out_stack_dim = bdim.stacked_axis if isinstance(bdim, RaggedAxis) else bdim - bdim_out = batching.shape_as_bdim( - out_stack_dim, - _compute_squeeze_shape(batching.bdim_as_shape(bdim, operand.shape), dimensions)) + + result_shape = _compute_squeeze_shape(operand.shape, dimensions) + bdim_out = canonicalize_axis(0, len(result_shape)) return squeeze(operand, dimensions=dimensions), bdim_out -squeeze_p = standard_primitive(_squeeze_shape_rule, _squeeze_dtype_rule, - 'squeeze', sharding_rule=_squeeze_sharding_rule) +squeeze_p = standard_primitive( + _squeeze_shape_rule, _squeeze_dtype_rule, 'squeeze', + sharding_rule=_squeeze_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'squeeze'), + reduced_rule=_squeeze_reduced_rule) ad.deflinear2(squeeze_p, _squeeze_transpose_rule) batching.primitive_batchers[squeeze_p] = _squeeze_batch_rule -pe.def_trivial_padding(squeeze_p) -batching.ragged_prop_rules[squeeze_p] = batching.ragged_mask_no_op_rule def _squeeze_lower(ctx, operand, *, dimensions): del dimensions # Implied by the output aval. @@ -6662,12 +7011,13 @@ def _squeeze_lower(ctx, operand, *, dimensions): def shape_as_value(shape: core.Shape): """Converts a shape that may contain Poly values into a JAX value.""" + dtype = lax_utils.int_dtype_for_shape(shape, signed=True) if len(shape) == 0: - return full((0,), np.array(0, np.int64)) + return full((0,), np.array(0, dtype=dtype)) if core.is_constant_shape(shape): - return np.asarray(shape, dtype=np.int64) + return np.asarray(shape, dtype=dtype) dims = [ - expand_dims(convert_element_type(core.dimension_as_value(d), np.int64), + expand_dims(convert_element_type(core.dimension_as_value(d), dtype), (0,)) for d in shape ] @@ -6680,12 +7030,6 @@ def _reshape_shape_rule(operand, *, new_sizes, dimensions, sharding): # TODO(necula): re-enable this check operand_size = math.prod(np.shape(operand)) new_size = math.prod(new_sizes) - if (not config.dynamic_shapes.value and - not operand_size == new_size): - msg = (f"reshape total size must be unchanged, got new_sizes {new_sizes} " - f"(of total size {new_size}) for shape {np.shape(operand)} " - f"(of total size {operand_size}).") - raise TypeError(msg) if dimensions is not None: if set(dimensions) != set(range(np.ndim(operand))): msg = ('reshape dimensions must be a permutation of operand dimensions, ' @@ -6696,6 +7040,14 @@ def _reshape_shape_rule(operand, *, new_sizes, dimensions, sharding): def _split_on_one_axis(op_shape, new_sizes, name): if len(new_sizes) <= len(op_shape): return False, [] + orig_op_shape, orig_new_sizes = op_shape, new_sizes + + num_1s = 0 + while op_shape[-1] == 1 and new_sizes[-1] == 1: + num_1s += 1 + op_shape = op_shape[:-1] + new_sizes = new_sizes[:-1] + i, j, count, out = 0, 0, 0, [] while j < len(new_sizes): if op_shape[i] == new_sizes[j]: @@ -6703,20 +7055,28 @@ def _split_on_one_axis(op_shape, new_sizes, name): else: count += 1 if count > 1: - raise ValueError( - f'{name} on more than 1 axis is not supported. Please specify' - ' the sharding of the output via the `sharding` argument of' - f' jax.lax.reshape. Got operand.shape={op_shape} and {new_sizes=}') + raise core.ShardingTypeError( + f'{name} on more than 1 axis is not supported. Please specify the' + ' sharding of the output via the `sharding` argument of' + f' jax.lax.reshape. Got operand.shape={orig_op_shape} and' + f' {orig_new_sizes=}') temp = [new_sizes[j]] - while math.prod(temp) != op_shape[i]: + next_j = j + 1 + while (math.prod(temp) != op_shape[i] or + (next_j < len(new_sizes) and new_sizes[next_j] == 1)): if math.prod(temp) > op_shape[i]: return False, [] j += 1 + if j >= len(new_sizes): + return False, [] temp.append(new_sizes[j]) + next_j += 1 out.append(temp) i += 1 j += 1 - assert len(op_shape) == len(out) + out.extend([1] * num_1s) + + assert len(orig_op_shape) == len(out) return True, out @@ -6729,6 +7089,8 @@ def _merge_on_one_axis(operand, new_sizes): def _reshape_sharding_rule(operand, *, new_sizes, dimensions, sharding): if sharding is not None: return sharding + if all(s is None for s in operand.sharding.spec): + return operand.sharding non_1s_op_shape = [s for s in operand.shape if s != 1] non_1s_new_shape = [s for s in new_sizes if s != 1] if non_1s_op_shape == non_1s_new_shape: @@ -6744,11 +7106,10 @@ def _reshape_sharding_rule(operand, *, new_sizes, dimensions, sharding): return _merge_an_axis_sharding_rule(operand, operand_merge, new_sizes, dimensions) - raise ValueError( + raise core.ShardingTypeError( 'This reshape is not supported. Please specify the sharding of' ' the output via the `out_sharding` argument of jax.lax.reshape. Got' - f' operand shape: {operand.shape}, new sizes: {new_sizes} and' - f' operand spec: {operand.sharding.spec}') + f' operand type: {operand}, new sizes: {new_sizes}') def _split_merge_singleton_dim_sharding_rule(operand, new_sizes): filtered_spec = [sp for sh, sp in zip(operand.shape, operand.sharding.spec) @@ -6761,7 +7122,7 @@ def _split_merge_singleton_dim_sharding_rule(operand, new_sizes): else: sp = next(fs) new_spec.append(sp) - return operand.sharding.with_spec(new_spec) + return operand.sharding.update(spec=new_spec) def _get_spec_size(sp, mesh): tup_sp = sp if isinstance(sp, tuple) else (sp,) @@ -6777,15 +7138,14 @@ def _split_an_axis_sharding_rule(operand, out_split, new_sizes, dimensions): elif dimensions is None and out[0] % _get_spec_size(sp, mesh) == 0: new_spec.extend([sp] + [None] * (len(out) - 1)) else: - raise ValueError( + raise core.ShardingTypeError( 'This reshape is not supported. Please specify the sharding of the' - ' output via the `sharding` argument of jax.lax.reshape. Got' - f' operand shape: {operand.shape}, new sizes: {new_sizes} and' - f' operand spec: {operand.sharding.spec}') + ' output via the `out_sharding` argument of jax.lax.reshape. Got' + f' operand type: {operand}, new sizes: {new_sizes}') else: new_spec.append(sp) assert len(new_spec) == len(new_sizes), (new_spec, new_sizes) - return operand.sharding.with_spec(new_spec) + return operand.sharding.update(spec=new_spec) def _merge_an_axis_sharding_rule(operand, operand_merge, new_sizes, dimensions): @@ -6802,32 +7162,23 @@ def _merge_an_axis_sharding_rule(operand, operand_merge, new_sizes, dimensions): assert new_size % _get_spec_size(sp[0], mesh) == 0 new_spec.append(sp[0]) else: - raise ValueError( + raise core.ShardingTypeError( 'This reshape is not supported. Please specify the sharding of the' - ' output via the `sharding` argument of jax.lax.reshape. Got' - f' operand shape: {operand.shape}, new sizes: {new_sizes} and' - f' operand spec: {operand.sharding.spec}') + ' output via the `out_sharding` argument of jax.lax.reshape. Got' + f' operand type: {operand}, new sizes: {new_sizes}') else: new_spec.append(next(op_spec)) assert next(op_spec, None) is None assert len(new_spec) == len(new_sizes), (new_spec, new_sizes) - return operand.sharding.with_spec(new_spec) + return operand.sharding.update(spec=new_spec) -def _reshape_typecheck_rule(_, operand, *dyn_shape, new_sizes, dimensions, +def _reshape_typecheck_rule(_, operand, new_sizes, dimensions, sharding): - if not dyn_shape: - out_aval, effects = reshape_p.abstract_eval( - operand.aval, new_sizes=new_sizes, dimensions=dimensions, - sharding=sharding) - return [out_aval], effects - else: - # TODO(mattjj, necula): perform more checks like _reshape_shape_rule - out_shape = _merge_dyn_shape(new_sizes, dyn_shape) - out_shape = [x.val if type(x) is core.Literal else x for x in out_shape] # pytype: disable=attribute-error - out_aval = core.DShapedArray(tuple(out_shape), operand.aval.dtype, - operand.aval.weak_type) - return [out_aval], core.no_effects + out_aval, effects = reshape_p.abstract_eval( + operand.aval, new_sizes=new_sizes, dimensions=dimensions, + sharding=sharding) + return [out_aval], effects def _reshape_dtype_rule(operand, *, new_sizes, dimensions, sharding): @@ -6838,7 +7189,7 @@ def _reshape_transpose_rule(t, operand, *, new_sizes, dimensions, sharding): if dimensions is None: return [reshape(t, operand.aval.shape, out_sharding=operand.aval.sharding)] else: - t_s = operand.aval.sharding.with_spec( + t_s = operand.aval.sharding.update(spec= tuple(map(lambda s: s if s is None else str(s), np.take(operand.aval.sharding.spec, dimensions)))) return [transpose(reshape(t, np.take(operand.aval.shape, dimensions), @@ -6861,25 +7212,22 @@ def _reshape_batch_rule(axis_data, batched_args, batch_dims, *, new_sizes, return out, 0 -def _reshape_lower(ctx, x, *dyn_shape, new_sizes, dimensions, sharding): +def _reshape_lower(ctx, x, new_sizes, dimensions, sharding): aval_out, = ctx.avals_out if dimensions is not None: x = hlo.transpose(x, mlir.dense_int_array(dimensions)) - if dyn_shape: - aval_out = aval_out.update(shape=_merge_dyn_shape(new_sizes, dyn_shape)) out = mlir.reshape(ctx, x, aval_out) return [mlir.lower_with_sharding_in_types(ctx, out, aval_out)] def _reshape_staging_rule( - trace, x, *dyn, new_sizes, dimensions, sharding): + trace, source_info, x, new_sizes, dimensions, sharding): params = dict(new_sizes=new_sizes, dimensions=dimensions, sharding=sharding) - if not dyn: - return trace.default_process_primitive(reshape_p, (x,), params) - av = core.DShapedArray(_merge_dyn_shape(new_sizes, dyn), x.dtype, x.weak_type) - return _dyn_shape_staging_rule(trace, reshape_p, av, x, *dyn, **params) + return trace.default_process_primitive(reshape_p, (x,), params, + source_info=source_info) reshape_p = standard_primitive(_reshape_shape_rule, _reshape_dtype_rule, - 'reshape', sharding_rule=_reshape_sharding_rule) + 'reshape', sharding_rule=_reshape_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'reshape')) ad.deflinear2(reshape_p, _reshape_transpose_rule) batching.fancy_primitive_batchers[reshape_p] = _reshape_batch_rule batching.skippable_batchers[reshape_p] = lambda _: () @@ -6910,8 +7258,9 @@ def _rev_batch_rule(batched_args, batch_dims, *, dimensions): new_dimensions = [i + 1 if i >= bdim else i for i in dimensions] return rev(operand, new_dimensions), bdim -rev_p = standard_primitive(_rev_shape_rule, _input_dtype, 'rev', - sharding_rule=_rev_sharding_rule) +rev_p = standard_primitive(_rev_shape_rule, input_dtype, 'rev', + sharding_rule=_rev_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'rev')) ad.deflinear2(rev_p, lambda t, _, dimensions: [rev(t, dimensions)]) batching.primitive_batchers[rev_p] = _rev_batch_rule @@ -6935,17 +7284,21 @@ def _transpose_shape_rule(operand, *, permutation): def _transpose_sharding_rule(operand, *, permutation): o_spec = operand.sharding.spec new_spec = [o_spec[old_idx] for old_idx in permutation] - return operand.sharding.with_spec(new_spec) + return operand.sharding.update(spec=o_spec.update(partitions=new_spec)) + +def _transpose_unreduced_rule(out_s, operand, *, permutation): + return out_s.update(spec=out_s.spec.update( + unreduced=operand.sharding.spec.unreduced)) + +def _transpose_reduced_rule(out_s, operand, *, permutation): + return out_s.update(spec=out_s.spec.update( + reduced=operand.sharding.spec.reduced)) def _transpose_batch_rule(batched_args, batch_dims, *, permutation): operand, = batched_args bdim, = batch_dims - stack_dim = bdim.stacked_axis if isinstance(bdim, RaggedAxis) else bdim - perm = (stack_dim,) + tuple(i if i < stack_dim else i+1 for i in permutation) - if isinstance(bdim, RaggedAxis): - res_bdim = batching.transpose_ragged_axes(bdim.move_stacked_axis(0), perm) - else: - res_bdim = 0 + perm = (bdim,) + tuple(i if i < bdim else i+1 for i in permutation) + res_bdim = 0 return transpose(operand, perm), res_bdim def _transpose_lower(ctx, x, *, permutation): @@ -6958,13 +7311,15 @@ def _transpose_lower(ctx, x, *, permutation): return [mlir.lower_with_sharding_in_types(ctx, out, aval_out)] transpose_p = standard_primitive( - _transpose_shape_rule, _input_dtype, 'transpose', - sharding_rule=_transpose_sharding_rule) + _transpose_shape_rule, input_dtype, 'transpose', + sharding_rule=_transpose_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'transpose'), + unreduced_rule=_transpose_unreduced_rule, + reduced_rule=_transpose_reduced_rule) ad.deflinear2(transpose_p, lambda t, _, permutation: [transpose(t, np.argsort(permutation))]) batching.primitive_batchers[transpose_p] = _transpose_batch_rule mlir.register_lowering(transpose_p, _transpose_lower) -pe.def_trivial_padding(transpose_p) def _select_shape_rule(which, *cases): @@ -6985,10 +7340,11 @@ def _select_sharding_rule(which, *cases): return core.get_cur_mesh_sharding() if any(s != non_empty_s[0] for s in non_empty_s[1:]): msg = "select cases must have the same shardings, got [{}]." - raise TypeError(msg.format(", ".join([str(c.sharding) for c in cases]))) + raise core.ShardingTypeError( + msg.format(", ".join([str(c.sharding) for c in cases]))) if (which.shape and not which.sharding.mesh.empty and which.sharding != non_empty_s[0]): - raise TypeError( + raise core.ShardingTypeError( 'select `which` must be scalar or have the same sharding as cases, got' f' `which` sharding {which.sharding} but case sharding' f' {cases[0].sharding}.') @@ -7025,7 +7381,7 @@ def _select_transpose_rule(t, which, *cases): if ad.is_undefined_primal(case) else None for i, case in enumerate(cases) ] -def _select_batch_rule(batched_args, batch_dims, **unused_kwargs): +def _select_batch_rule(axis_data, batched_args, batch_dims, **unused_kwargs): which, *cases = batched_args which_bdim, *case_bdims = batch_dims size = next(x.shape[i] for x, i in zip(batched_args, batch_dims) @@ -7038,7 +7394,8 @@ def _select_batch_rule(batched_args, batch_dims, **unused_kwargs): else: # vmapped function had a scalar which with nonscalar args assert np.ndim(which) == 1 - which = broadcast_in_dim(which, cases[0].shape, [which_bdim]) + which = broadcast_in_dim(which, cases[0].shape, [which_bdim], + out_sharding=typeof(cases[0]).sharding) return select_n(which, *cases), which_bdim elif np.ndim(which) == 0 and all(bdim is not None for bdim in case_bdims): if all(case_bdims[0] == bdim for bdim in case_bdims[1:]): @@ -7049,16 +7406,18 @@ def _select_batch_rule(batched_args, batch_dims, **unused_kwargs): for c, c_bdim in zip(cases[1:], case_bdims[1:])] return select_n(which, cases[0], *other_cases), bdim - which = (batching.bdim_at_front(which, which_bdim, size) if np.shape(which) - else which) + which = (batching.bdim_at_front(which, which_bdim, size, + axis_data.explicit_mesh_axis) + if np.shape(which) else which) if not all(() == np.shape(c) for c in cases): - cases = [batching.bdim_at_front(c, bdim, size) + cases = [batching.bdim_at_front(c, bdim, size, axis_data.explicit_mesh_axis) for c, bdim in zip(cases, case_bdims)] assert all(np.shape(cases[0]) == np.shape(c) for c in cases[1:]) if 0 < np.ndim(which) < np.ndim(cases[0]): # vmapped function had a scalar which with nonscalar args assert np.ndim(which) == 1 - which = broadcast_in_dim(which, cases[0].shape, [0]) + which = broadcast_in_dim(which, cases[0].shape, [0], + out_sharding=typeof(cases[0]).sharding) if np.ndim(which) > np.ndim(cases[0]): assert np.ndim(cases[0]) == 0 cases = [broadcast(c, which.shape) for c in cases] @@ -7079,7 +7438,11 @@ def _select_jvp(primals, tangents): def _select_hlo_lowering_opaque(ctx, which, *cases): avals_in = ctx.avals_in aval_out, = ctx.avals_out - assert all(aval_case == aval_out for aval_case in avals_in[1:]) + assert all((aval_case.shape, aval_case.dtype) == (aval_out.shape, aval_out.dtype) + for aval_case in avals_in[1:]) + assert all( + aval_case == aval_out for aval_case in avals_in[1:] + if not aval_case.sharding.mesh.empty and not aval_out.sharding.mesh.empty) select_lower = _select_hlo_lowering physical_aval_out = core.physical_aval(aval_out) @@ -7134,12 +7497,13 @@ def _select(offset, cases): select_n_p = standard_primitive( _select_shape_rule, _select_dtype_rule, 'select_n', - weak_type_rule=_select_weak_type_rule, sharding_rule=_select_sharding_rule) + weak_type_rule=_select_weak_type_rule, sharding_rule=_select_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'select_n')) ad.primitive_jvps[select_n_p] = _select_jvp ad.primitive_transposes[select_n_p] = _select_transpose_rule -batching.primitive_batchers[select_n_p] = _select_batch_rule +batching.fancy_primitive_batchers[select_n_p] = _select_batch_rule +batching.skippable_batchers[select_n_p] = lambda _: () mlir.register_lowering(select_n_p, _select_hlo_lowering) -pe.def_trivial_padding(select_n_p) def _reduce_shape_rule(*avals, computation, jaxpr, dimensions): @@ -7151,13 +7515,18 @@ def _reduce_shape_rule(*avals, computation, jaxpr, dimensions): def _reduce_sharding_rule(*avals, computation, jaxpr, dimensions): operand_avals, _ = split_list(avals, [len(avals) // 2]) - return [op.sharding.with_spec(tuple_delete(op.sharding.spec, dimensions)) + return [op.sharding.update(spec=tuple_delete(op.sharding.spec, dimensions)) for op in operand_avals] +def _reduce_vma_rule(*avals, computation, jaxpr, dimensions): + operand_avals, _ = split_list(avals, [len(avals) // 2]) + out_vma = core.standard_vma_rule('reduce', *operand_avals) + return [out_vma] * len(operand_avals) + def _reduce_dtype_rule(*avals, computation, jaxpr, dimensions): operand_avals, init_val_avals = split_list(avals, [len(avals) // 2]) - operand_dtypes = [dtypes.canonicalize_dtype(op.dtype) for op in operand_avals] - init_val_dtypes = [dtypes.canonicalize_dtype(init.dtype) for init in init_val_avals] + operand_dtypes = [op.dtype for op in operand_avals] + init_val_dtypes = [init.dtype for init in init_val_avals] if operand_dtypes != init_val_dtypes: raise TypeError( "reduce operand dtypes should match corresponding initial value dtypes, " @@ -7240,11 +7609,13 @@ def _reduce_jvp_rule(primals, tangents, *, computation, jaxpr, dimensions): reduce_p.def_impl(partial(dispatch.apply_primitive, reduce_p)) reduce_p.def_abstract_eval( partial(standard_multi_result_abstract_eval, reduce_p, _reduce_shape_rule, - _reduce_dtype_rule, _reduce_weak_type_rule, _reduce_sharding_rule)) + _reduce_dtype_rule, _reduce_weak_type_rule, _reduce_sharding_rule, + _reduce_vma_rule)) batching.primitive_batchers[reduce_p] = _reduce_batch_rule ad.primitive_jvps[reduce_p] = _reduce_jvp_rule -def _reduce_lower(ctx, *values, computation, jaxpr, dimensions): +def _reduce_lower(ctx: mlir.LoweringRuleContext, *values, + computation, jaxpr: core.ClosedJaxpr, dimensions): assert all(isinstance(x, core.ShapedArray) for x in ctx.avals_in), ctx.avals_in operands, init_values = util.split_list(values, [len(values) // 2]) init_value_avals = ctx.avals_in[len(values) // 2:] @@ -7260,7 +7631,8 @@ def _reduce_lower(ctx, *values, computation, jaxpr, dimensions): name_stack, mlir.TokenSet(), jaxpr.consts, *reducer.arguments, - dim_var_values=ctx.dim_var_values) + dim_var_values=ctx.dim_var_values, + const_lowering=ctx.const_lowering) hlo.return_(mlir.flatten_ir_values(out_nodes)) return [mlir.lower_with_sharding_in_types(ctx, r, aval) for r, aval in safe_zip(op.results, ctx.avals_out)] @@ -7272,32 +7644,19 @@ def _reduce_number_dtype_rule(name, operand, *args, **kw): if not dtypes.issubdtype(operand.dtype, np.number): raise TypeError("{} does not accept dtype {}. Accepted dtypes are subtypes " "of number.".format(name, dtype_to_string(operand.dtype))) - return dtypes.canonicalize_dtype(operand.dtype) + return operand.dtype -def _reduce_sum_transpose_rule(cotangent, operand, *, axes): +def _reduce_sum_transpose_rule(cotangent, operand, *, axes, out_sharding): assert ad.is_undefined_primal(operand) input_shape = operand.aval.shape broadcast_dimensions = tuple(np.delete(np.arange(len(input_shape)), axes)) - result = broadcast_in_dim(cotangent, input_shape, broadcast_dimensions, - out_sharding=operand.aval.sharding) + result = broadcast_in_dim( + cotangent, input_shape, broadcast_dimensions, + out_sharding=operand.aval.to_cotangent_aval().sharding) assert result.shape == input_shape return [result] -def _reducer_padding(traceable, ident, in_avals, out_avals, operand, *, axes): - del out_avals - aval, = in_avals - padded_axes = [(i, d.val) for i, d in enumerate(aval.shape) - if isinstance(d, pe.BoundedAxisSize)] - operand_ = _replace_masked_values(operand, ident(aval.dtype), padded_axes) - return [traceable(operand_, axes)] - -def _replace_masked_values(x, val, padded_axes): - if not padded_axes: return x - dtype = dtypes._scalar_type_to_dtype(int) - masks = [broadcasted_iota(dtype, x.shape, i) < d for i, d in padded_axes] - return select(_reduce(operator.and_, masks), x, full_like(x, val)) - -def _reduce_op_shape_rule(operand, *, axes, input_shape=None): +def _reduce_op_shape_rule(operand, *, axes, input_shape=None, **kwargs): del input_shape # Unused. if len(axes) != len(set(axes)): raise ValueError(f"duplicate value in 'axes' of reduction: {axes}") @@ -7306,20 +7665,43 @@ def _reduce_op_shape_rule(operand, *, axes, input_shape=None): axes = frozenset(axes) return tuple(d for i, d in enumerate(operand.shape) if i not in axes) -def _reduce_op_sharding_rule(operand, *, axes): +def _reduce_sum_sharding_rule(operand, *, axes, out_sharding): + if out_sharding is not None: + assert isinstance(out_sharding, NamedSharding) + return out_sharding axes = frozenset(axes) new_spec = P(*tuple(s for i, s in enumerate(operand.sharding.spec) if i not in axes)) - return operand.sharding.with_spec(new_spec) + return operand.sharding.update(spec=new_spec) + +def _reduce_sum_unreduced_rule(out_s, operand, *, axes, **kwargs): + if operand.sharding.spec.unreduced: + raise core.ShardingTypeError( + f'operand passed to reduce_sum cannot be unreduced. Got {operand=}') + if unreduced_spec := out_s.spec.unreduced: + axes = frozenset(axes) + reduced_spec = frozenset( + s for i, spec in enumerate(operand.sharding.spec) if i in axes + for s in (spec if isinstance(spec, tuple) else (spec,))) + if not all(u in reduced_spec for u in unreduced_spec): + raise core.ShardingTypeError( + "out_sharding's unreduced axes should be in operand's specs that" + f' were summed over. Got {operand=}, {axes=},' + f' unreduced_spec={unreduced_spec}') + return out_s + +def _reduce_sum_reduced_rule(out_s, operand, *, axes, **kwargs): + return out_s.update(spec=out_s.spec.update( + reduced=operand.sharding.spec.reduced)) reduce_sum_p = standard_primitive( _reduce_op_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_sum'), - 'reduce_sum', sharding_rule=_reduce_op_sharding_rule) + 'reduce_sum', sharding_rule=_reduce_sum_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'reduce_sum'), + unreduced_rule=_reduce_sum_unreduced_rule, + reduced_rule=_reduce_sum_reduced_rule) ad.deflinear2(reduce_sum_p, _reduce_sum_transpose_rule) batching.defreducer(reduce_sum_p, _get_sum_identity) -pe.padding_rules[reduce_sum_p] = partial(_reducer_padding, reduce_sum, - _get_sum_identity) -batching.ragged_prop_rules[reduce_sum_p] = batching.ragged_mask_elementwise_rule def _reduce_prod_jvp_rule(primals, tangents, *, axes): reducer = lambda x, y: [mul(x, y)] @@ -7327,14 +7709,18 @@ def _reduce_prod_jvp_rule(primals, tangents, *, axes): primals, tangents, axes) return primals_out[0], tangents_out[0] +def _reduce_op_sharding_rule(operand, *, axes): + axes = frozenset(axes) + new_spec = P(*tuple(s for i, s in enumerate(operand.sharding.spec) + if i not in axes)) + return operand.sharding.update(spec=new_spec) + reduce_prod_p = standard_primitive( _reduce_op_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_prod'), - 'reduce_prod', sharding_rule=_reduce_op_sharding_rule) + 'reduce_prod', sharding_rule=_reduce_op_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'reduce_prod')) ad.primitive_jvps[reduce_prod_p] = _reduce_prod_jvp_rule batching.defreducer(reduce_prod_p, _get_prod_identity) -pe.padding_rules[reduce_prod_p] = partial(_reducer_padding, reduce_prod, - _get_prod_identity) - def _reduce_chooser_jvp_rule(g, ans, operand, *, axes): # TODO(mattjj): an alternative is to use variadic reduce to compute the chosen @@ -7348,23 +7734,19 @@ def _reduce_chooser_jvp_rule(g, ans, operand, *, axes): reduce_max_p = standard_primitive( - _reduce_op_shape_rule, _input_dtype, 'reduce_max', - sharding_rule=_reduce_op_sharding_rule) + _reduce_op_shape_rule, input_dtype, 'reduce_max', + sharding_rule=_reduce_op_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'reduce_max')) ad.defjvp2(reduce_max_p, _reduce_chooser_jvp_rule) batching.defreducer(reduce_max_p, _get_max_identity) -pe.padding_rules[reduce_max_p] = partial(_reducer_padding, reduce_max, - _get_max_identity) -batching.ragged_prop_rules[reduce_max_p] = batching.ragged_mask_elementwise_rule reduce_min_p = standard_primitive( - _reduce_op_shape_rule, _input_dtype, 'reduce_min', - sharding_rule=_reduce_op_sharding_rule) + _reduce_op_shape_rule, input_dtype, 'reduce_min', + sharding_rule=_reduce_op_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'reduce_min')) ad.defjvp2(reduce_min_p, _reduce_chooser_jvp_rule) batching.defreducer(reduce_min_p, _get_min_identity) -pe.padding_rules[reduce_min_p] = partial(_reducer_padding, reduce_min, - _get_min_identity) - def _argminmax_shape_rule(operand, *, axes, index_dtype): axis, = axes @@ -7377,7 +7759,7 @@ def _argminmax_shape_rule(operand, *, axes, index_dtype): def _argminmax_sharding_rule(operand, *, axes, index_dtype): axis, = axes - return operand.sharding.with_spec( + return operand.sharding.update(spec= util.tuple_delete(operand.sharding.spec, axis)) def _argminmax_dtype_rule(operand, *, axes, index_dtype): @@ -7426,24 +7808,35 @@ def _compute_argminmax(value_comparator, get_identity, argmin_p = standard_primitive(_argminmax_shape_rule, _argminmax_dtype_rule, 'argmin', weak_type_rule=_strip_weak_type, - sharding_rule=_argminmax_sharding_rule) + sharding_rule=_argminmax_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'argmin')) batching.defreducer(argmin_p, _get_min_identity) ad.defjvp_zero(argmin_p) argmax_p = standard_primitive(_argminmax_shape_rule, _argminmax_dtype_rule, 'argmax', weak_type_rule=_strip_weak_type, - sharding_rule=_argminmax_sharding_rule) + sharding_rule=_argminmax_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'argmax')) batching.defreducer(argmax_p, _get_max_identity) ad.defjvp_zero(argmax_p) -mlir.register_lowering(argmin_p, mlir.cache_lowering( - mlir.lower_fun(partial(_compute_argminmax, lt, _get_min_identity), - multiple_results=False))) - -mlir.register_lowering(argmax_p, mlir.cache_lowering( - mlir.lower_fun(partial(_compute_argminmax, gt, _get_max_identity), - multiple_results=False))) +mlir.register_lowering( + argmin_p, + mlir.lower_fun( + partial(_compute_argminmax, lt, _get_min_identity), + multiple_results=False, + ), + inline=False, +) +mlir.register_lowering( + argmax_p, + mlir.lower_fun( + partial(_compute_argminmax, gt, _get_max_identity), + multiple_results=False, + ), + inline=False, +) def _reduce_logical_shape_rule(operand, *, axes): if operand.dtype != np.bool_ and not np.issubdtype(operand.dtype, np.integer): @@ -7451,28 +7844,37 @@ def _reduce_logical_shape_rule(operand, *, axes): return tuple(np.delete(operand.shape, axes)) def _reduce_logical_sharding_rule(operand, *, axes): - return operand.sharding.with_spec(tuple_delete(operand.sharding.spec, axes)) + return operand.sharding.update(spec=tuple_delete(operand.sharding.spec, axes)) + +def _reduce_or_lin(nzs, x, *, axes): + nz, = nzs + y = reduce_or_p.bind(x, axes=axes) + aval = typeof(y).to_tangent_aval() + return y, False, (), lambda _, t: ad_util.Zero(aval) reduce_or_p = standard_primitive( - _reduce_logical_shape_rule, _input_dtype, 'reduce_or', - weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule) + _reduce_logical_shape_rule, input_dtype, 'reduce_or', + weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'reduce_or')) batching.defreducer(reduce_or_p, _get_bitwise_or_identity) +ad.primitive_linearizations[reduce_or_p] = _reduce_or_lin reduce_and_p = standard_primitive( - _reduce_logical_shape_rule, _input_dtype, 'reduce_and', - weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule) + _reduce_logical_shape_rule, input_dtype, 'reduce_and', + weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'reduce_and')) batching.defreducer(reduce_and_p, _get_bitwise_and_identity) -batching.ragged_prop_rules[reduce_and_p] = batching.ragged_mask_elementwise_rule reduce_xor_p = standard_primitive( - _reduce_logical_shape_rule, _input_dtype, 'reduce_xor', - weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule) + _reduce_logical_shape_rule, input_dtype, 'reduce_xor', + weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'reduce_xor')) batching.defreducer(reduce_xor_p, _get_bitwise_or_identity) -def _unary_reduce_lower(reducer, unit_factory, ctx, x, *, axes): +def _unary_reduce_lower(reducer, unit_factory, ctx, x, *, axes, **kwargs): aval_out, = ctx.avals_out dtype = aval_out.dtype op = hlo.ReduceOp([mlir.aval_to_ir_type(aval_out)], [x], @@ -7515,7 +7917,8 @@ def _reduce_precision_sharding_rule(operand, *, exponent_bits, mantissa_bits): reduce_precision_p = standard_primitive( _reduce_precision_shape_rule, partial(unop_dtype_rule, _identity, _float, 'reduce_precision'), - name='reduce_precision', sharding_rule=_reduce_precision_sharding_rule) + name='reduce_precision', sharding_rule=_reduce_precision_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'reduce_precision')) ad.deflinear(reduce_precision_p, lambda t, **kwargs: [reduce_precision_p.bind(t, **kwargs)]) batching.defvectorized(reduce_precision_p) @@ -7541,16 +7944,23 @@ def _reduce_precision_lower(ctx, operand, *, exponent_bits, mantissa_bits): } -def _sort_abstract_eval(*args, **kwargs): - args = tuple(args) - if any(arg.shape != args[0].shape for arg in args[1:]): - shapes = " ".join(str(a.shape) for a in args) +def _sort_abstract_eval(*avals, **kwargs): + avals = tuple(avals) + if any(arg.shape != avals[0].shape for arg in avals[1:]): + shapes = " ".join(str(a.shape) for a in avals) raise TypeError(f"Arguments to sort must have equal shapes, got: {shapes}") - return args + non_empty_s = [ + a.sharding for a in avals + if not a.sharding.mesh.empty and a.sharding.mesh._any_axis_explicit] + if any(s != non_empty_s[0] for s in non_empty_s[1:]): + shardings = " ".join(str(s) for s in non_empty_s) + raise core.ShardingTypeError( + f'Arguments to sort must have equal shardings, got: {shardings}') + return avals def _canonicalize_float_for_sort(x): - # In the sort comparator, we are going to use a comparision operator where -0 + # In the sort comparator, we are going to use a comparison operator where -0 # would be before 0, and -NaN and NaN appear at the beginning and end of the # ordering. In this scheme, -0 would be before 0, and -NaN and NaN appear at # the beginning and end of the ordering. This causes issues for stable @@ -7577,6 +7987,7 @@ def _sort_lt_comparator(*operands, num_keys=1): x_keys, y_keys = _operands_to_keys(*operands, num_keys=num_keys) p = None for xk, yk in zip(x_keys[::-1], y_keys[::-1]): + xk, yk = core.standard_insert_pvary(xk, yk) p = (bitwise_or(lt_to_p.bind(xk, yk), bitwise_and(eq_to_p.bind(xk, yk), p)) if p is not None else lt_to_p.bind(xk, yk)) return p @@ -7587,6 +7998,7 @@ def _sort_le_comparator(*operands, num_keys=1): x_keys, y_keys = _operands_to_keys(*operands, num_keys=num_keys) p = None for xk, yk in zip(x_keys[::-1], y_keys[::-1]): + xk, yk = core.standard_insert_pvary(xk, yk) p = (bitwise_or(lt_to_p.bind(xk, yk), bitwise_and(eq_to_p.bind(xk, yk), p)) if p is not None else le_to_p.bind(xk, yk)) return p @@ -7611,8 +8023,10 @@ def _operands_to_keys(*operands, num_keys=1): def _sort_jvp(primals, tangents, *, dimension, is_stable, num_keys): shape = primals[0].shape + index_dtype = lax_utils.int_dtype_for_shape(shape, signed=False) sorted_primals_and_idx = sort_p.bind( - *primals, broadcasted_iota(np.uint64, shape, dimension), + *primals, + broadcasted_iota(index_dtype, shape, dimension), dimension=dimension, is_stable=is_stable, num_keys=num_keys) batch_dims = tuple(np.delete(np.arange(len(shape), dtype=np.int64), dimension)) @@ -7640,7 +8054,9 @@ def _sort_batch_rule(batched_args, batch_dims, *, dimension, is_stable, num_keys for arg, bdim in zip(batched_args, batch_dims): if bdim is None: dims = np.delete(np.arange(prototype_arg.ndim), new_bdim) - new_args.append(broadcast_in_dim(arg, prototype_arg.shape, dims)) + new_args.append(broadcast_in_dim( + arg, prototype_arg.shape, dims, + out_sharding=typeof(prototype_arg).sharding)) else: new_args.append(batching.moveaxis(arg, bdim, new_bdim)) new_dimension = dimension + (new_bdim <= dimension) @@ -7663,7 +8079,7 @@ def _sort_lower(ctx, *operands, dimension, is_stable, num_keys): mlir.flatten_ir_values(operands), dimension=mlir.i64_attr(dimension), is_stable=ir.BoolAttr.get(is_stable)) - scalar_s = lambda a: a.sharding.with_spec(P()) + scalar_s = lambda a: a.sharding.update(spec=P()) scalar_avals = [aval.update(shape=(), sharding=scalar_s(aval)) for aval in ctx.avals_in] scalar_types = safe_map(mlir.aval_to_ir_type, scalar_avals) @@ -7683,7 +8099,7 @@ def _sort_lower(ctx, *operands, dimension, is_stable, num_keys): mlir.register_lowering(sort_p, _sort_lower) -def _top_k_abstract_eval(operand, *, k): +def _top_k_abstract_eval(operand, *, k, axis): if dtypes.issubdtype(operand.dtype, np.complexfloating): raise ValueError("top_k is not compatible with complex inputs.") if k < 0: @@ -7691,19 +8107,30 @@ def _top_k_abstract_eval(operand, *, k): if len(operand.shape) == 0: raise TypeError("top_k operand must have >= 1 dimension, got {}" .format(operand.shape)) + if not (0 <= axis < len(operand.shape)): + raise ValueError(f"axis argument out of range: {axis=} for {operand.shape=}") shape = list(operand.shape) - if shape[-1] < k: - msg = "k argument to top_k must be no larger than minor dimension; {} vs {}" - raise ValueError(msg.format(k, shape)) - shape[-1] = k + if shape[axis] < k: + raise ValueError("k argument to top_k must be no larger than size along axis;" + f" got {k=} with {shape=} and {axis=}") + int32_max = dtypes.iinfo('int32').max + try: + too_large = (shape[axis] > int32_max + 1) + except core.InconclusiveDimensionOperation: + pass + else: + if too_large: + raise ValueError("top_k returns int32 indices, which will overflow for array dimensions " + f"larger than the maximum int32 ({int32_max}). Got {operand.shape=}") + shape[axis] = k return (operand.update(shape=shape, dtype=operand.dtype, weak_type=operand.weak_type), operand.update(shape=shape, dtype=np.dtype(np.int32))) -def _top_k_jvp(primals, tangents, *, k): +def _top_k_jvp(primals, tangents, *, k, axis): operand, = primals tangent, = tangents - primals_out = top_k(operand, k) + primals_out = top_k(operand, k, axis=axis) if type(tangent) is ad_util.Zero: tangent_out = ad_util.Zero.from_primal_value(primals_out[0]) else: @@ -7715,40 +8142,52 @@ def _top_k_jvp(primals, tangents, *, k): slice_sizes = (1,) * rank dnums = slicing.GatherDimensionNumbers( offset_dims=(), - collapsed_slice_dims=(rank - 1,), - operand_batching_dims=tuple(range(rank - 1)), - start_indices_batching_dims=tuple(range(rank - 1)), - start_index_map=(rank - 1,), + collapsed_slice_dims=(axis,), + operand_batching_dims=tuple(i for i in range(rank) if i != axis), + start_indices_batching_dims=tuple(i for i in range(rank) if i != axis), + start_index_map=(axis,), ) tangent_out = slicing.gather(tangent, gather_indices, dnums, slice_sizes) return primals_out, (tangent_out, ad_util.Zero.from_primal_value(primals_out[1])) -def _top_k_batch_rule(batched_args, batch_dims, *, k): +def _top_k_batch_rule(batched_args, batch_dims, *, k, axis): operand, = batched_args bdim, = batch_dims - if bdim == operand.ndim-1: - perm = np.arange(operand.ndim) - perm[bdim-1], perm[bdim] = perm[bdim], perm[bdim-1] - top_k_v, top_k_i = top_k(transpose(operand, perm), k=k) - return (transpose(top_k_v, perm), - transpose(top_k_i, perm)), (bdim, bdim) - else: - return top_k(operand, k=k), (bdim, bdim) + if bdim <= axis: + axis += 1 + return top_k(operand, k=k, axis=axis), (bdim, bdim) top_k_p = Primitive('top_k') top_k_p.multiple_results = True top_k_p.def_impl(partial(dispatch.apply_primitive, top_k_p)) top_k_p.def_abstract_eval(_top_k_abstract_eval) -def _top_k_lower(ctx, operand, k): +def _top_k_lower(ctx, operand, k, axis): + # Move axis to last dimension: + ndim = len(ctx.avals_in[0].shape) + if axis != ndim - 1: + perm = list(range(ndim)) + perm[axis], perm[-1] = perm[-1], perm[axis] + operand = hlo.transpose(operand, mlir.dense_int_array(perm)) + else: + perm = None + + # Compute the top-k along the last dimension if core.is_constant_dim(k): - return chlo.TopKOp(operand, mlir.i64_attr(k)).results - k_value, = mlir.eval_dynamic_shape_as_vals(ctx, (k,)) - out_values_aval, out_indices_aval, = ctx.avals_out - return mlir.custom_call( - "stablehlo.dynamic_top_k", - result_types=[mlir.aval_to_ir_type(out_values_aval), - mlir.aval_to_ir_type(out_indices_aval)], - operands=[operand, k_value]).results + results = chlo.TopKOp(operand, mlir.i64_attr(k)).results + else: + k_value, = mlir.eval_dynamic_shape_as_vals(ctx, (k,)) + out_values_aval, out_indices_aval, = ctx.avals_out + results = mlir.custom_call( + "stablehlo.dynamic_top_k", + result_types=[mlir.aval_to_ir_type(out_values_aval), + mlir.aval_to_ir_type(out_indices_aval)], + operands=[operand, k_value]).results + + # Move last dimension back into place + if perm is not None: + results = [hlo.transpose(result, mlir.dense_int_array(perm)) + for result in results] + return results mlir.register_lowering(top_k_p, _top_k_lower) ad.primitive_jvps[top_k_p] = _top_k_jvp @@ -7766,7 +8205,6 @@ def _stop_gradient_batch_rule(batched_args, batch_dims): ad.primitive_jvps[ad_util.stop_gradient_p] = _stop_gradient_jvp_rule batching.primitive_batchers[ad_util.stop_gradient_p] = _stop_gradient_batch_rule -pe.def_trivial_padding(ad_util.stop_gradient_p) def create_token(_=None): @@ -7791,7 +8229,8 @@ def _create_token_lowering(ctx, *operands): def after_all(*operands): """Merges one or more XLA token values. Experimental. - Wraps the XLA AfterAll operator.""" + Wraps the XLA after all operator.""" + operands = core.standard_insert_pvary(*operands) return after_all_p.bind(*operands) def _after_all_abstract_eval(*operands): @@ -7810,110 +8249,6 @@ def _after_all_lowering(ctx, *operands): mlir.register_lowering(after_all_p, _after_all_lowering) -class InOutFeedEffect(effects.Effect): - pass -infeed_effect = InOutFeedEffect() -outfeed_effect = InOutFeedEffect() - - -def infeed(token, shape=None, partitions=None): - """Consumes an infeed value of `shape` from the host. Experimental. - - `token` is used to sequence infeed and outfeed effects. - `partitions` may be specified inside a `sharded_jit` function. - """ - flat_shapes, treedef = tree_util.tree_flatten(shape) - for shape in flat_shapes: - if not isinstance(shape, ShapedArray): - raise TypeError("shape argument to infeed must be a pytree of " - "ShapedArray values, got {}".format(shape)) - if partitions is not None: - # Always replicate token. - # We specifically use type() to raise an error for PartitionSpecs. - if type(partitions) != tuple: # pylint: disable=unidiomatic-typecheck - raise ValueError(f"'partitions' argument to infeed should be a tuple, " - f"got {partitions}") - partitions = partitions + (None,) - xs_and_token = infeed_p.bind(token, shapes=tuple(flat_shapes), - partitions=partitions) - return (treedef.unflatten(xs_and_token[:-1]), xs_and_token[-1]) - -def _infeed_abstract_eval(token, *, shapes, partitions): - if token is not abstract_token: - raise TypeError("First argument to infeed must be a token") - return (*shapes, abstract_token), {infeed_effect} - - -infeed_p = Primitive("infeed") -infeed_p.multiple_results = True -infeed_p.def_impl(partial(dispatch.apply_primitive, infeed_p)) -infeed_p.def_effectful_abstract_eval(_infeed_abstract_eval) -mlir.lowerable_effects.add_type(InOutFeedEffect) - - -def _infeed_lowering(ctx, token, *, shapes, partitions): - output_types = safe_map(mlir.aval_to_ir_type, ctx.avals_out[:-1]) - flat_output_types = mlir.flatten_ir_types(output_types) - # TODO(phawkins): verify `shapes` have a major-to-minor layout. - layouts = ir.ArrayAttr.get([ - ir.ArrayAttr.get( - [mlir.i64_attr(i) - for i in range(len(aval.shape) - 1, -1, -1)]) - for aval in shapes - ]) - infeed = hlo.InfeedOp( - flat_output_types + [hlo.TokenType.get()], - token, - infeed_config=ir.StringAttr.get(''), - layout=layouts) - if partitions is not None: - mlir.set_sharding(infeed, xla.sharding_to_proto(partitions)) - token = infeed.results[-1] - outs = infeed.results[:-1] - return mlir.unflatten_ir_values_like_types(outs, output_types) + [ - token, - ] - -mlir.register_lowering(infeed_p, _infeed_lowering) - - -def outfeed(token, xs, partitions = None): - """Outfeeds value `xs` to the host. Experimental. - - `token` is used to sequence infeed and outfeed effects. - `partitions` may be specified inside a `sharded_jit` or `pjit` function. - """ - if partitions is not None: - # We specifically use type() to raise an error for PartitionSpecs. - if type(partitions) != tuple: # pylint: disable=unidiomatic-typecheck - raise ValueError(f"'partitions' argument to outfeed should be a tuple, " - f"got {partitions}") - flat_xs, _ = tree_util.tree_flatten(xs) - return outfeed_p.bind(token, *flat_xs, partitions=partitions) - -def _outfeed_abstract_eval(token, *xs, partitions): - if token is not abstract_token: - raise TypeError("First argument to outfeed must be a token") - return abstract_token, {outfeed_effect} - -outfeed_p = Primitive("outfeed") -outfeed_p.def_impl(partial(dispatch.apply_primitive, outfeed_p)) -outfeed_p.def_effectful_abstract_eval(_outfeed_abstract_eval) -mlir.lowerable_effects.add_type(InOutFeedEffect) - - -def _outfeed_lowering(ctx, token, *xs, partitions): - outfeed = hlo.OutfeedOp( - mlir.flatten_ir_values(xs), - token, - outfeed_config=ir.StringAttr.get('')) - if partitions is not None: - mlir.set_sharding(outfeed, xla.sharding_to_proto(partitions)) - return outfeed.results - -mlir.register_lowering(outfeed_p, _outfeed_lowering) - - def rng_uniform(a, b, shape): """Stateful PRNG generator. Experimental and its use is discouraged. @@ -7926,6 +8261,7 @@ def rng_uniform(a, b, shape): This API may be removed at any time. """ + a, b = core.standard_insert_pvary(a, b) return rng_uniform_p.bind(a, b, shape=tuple(shape)) def _rng_uniform_abstract_eval(a, b, *, shape): @@ -7952,15 +8288,23 @@ def _rng_uniform_lowering(ctx, a, b, *, shape): mlir.register_lowering(rng_uniform_p, _rng_uniform_lowering) -def _rng_bit_generator_shape_rule(key, *, shape, dtype, algorithm): +def _rng_bit_generator_shape_rule(key, *, shape, dtype, algorithm, out_sharding): del dtype, algorithm return (key.shape, tuple(shape)) -def _rng_bit_generator_dtype_rule(key, *, shape, dtype, algorithm): +def _rng_bit_generator_sharding_rule(key, *, shape, dtype, algorithm, + out_sharding): + return (key.sharding, out_sharding) + +def _rng_bit_generator_vma_rule(key, *, shape, dtype, algorithm, out_sharding): + return (key.vma, frozenset()) + +def _rng_bit_generator_dtype_rule(key, *, shape, dtype, algorithm, out_sharding): del shape, algorithm return (key.dtype, dtype) -def _rng_bit_generator_weak_type_rule(key, *, shape, dtype, algorithm): +def _rng_bit_generator_weak_type_rule(key, *, shape, dtype, algorithm, + out_sharding): del shape, dtype, algorithm return (key.weak_type, False) @@ -7991,7 +8335,7 @@ def _rng_algorithm(algorithm: RandomAlgorithm): assert False def _rng_bit_generator_lowering( - ctx, key, *, shape, dtype, algorithm): + ctx, key, *, shape, dtype, algorithm, out_sharding): key_type = ir.RankedTensorType(key.type) key_shape, key_etype = key_type.shape, key_type.element_type # While the RngBitGenerator HLO accepts a u64[2] key on all backends, we @@ -8020,7 +8364,7 @@ def _rng_bit_generator_lowering( ir.RankedTensorType.get([2], u64_type), hlo.reshape(ir.RankedTensorType.get([2, 2], u32_type), key)) algorithm_attr = _rng_algorithm(algorithm) - _, out_vals_aval = ctx.avals_out + out_key_aval, out_vals_aval = ctx.avals_out if any(not core.is_constant_shape(a.shape) for a in ctx.avals_out): output_shape = mlir.shape_tensor( mlir.eval_dynamic_shape(ctx, out_vals_aval.shape)) @@ -8044,7 +8388,8 @@ def _rng_bit_generator_lowering( out_vals = hlo.convert( ir.RankedTensorType.get(ir.RankedTensorType(out_vals.type).shape, etype), out_vals) - return [out_key, out_vals] + return [mlir.lower_with_sharding_in_types(ctx, out_key, out_key_aval), + mlir.lower_with_sharding_in_types(ctx, out_vals, out_vals_aval)] rng_bit_generator_p = Primitive("rng_bit_generator") @@ -8054,7 +8399,8 @@ def _rng_bit_generator_lowering( rng_bit_generator_p.def_abstract_eval( partial(standard_multi_result_abstract_eval, rng_bit_generator_p, _rng_bit_generator_shape_rule, _rng_bit_generator_dtype_rule, - _rng_bit_generator_weak_type_rule, None)) + _rng_bit_generator_weak_type_rule, _rng_bit_generator_sharding_rule, + _rng_bit_generator_vma_rule)) mlir.register_lowering(rng_bit_generator_p, _rng_bit_generator_lowering) @@ -8111,14 +8457,30 @@ def _copy_impl(prim, *args, **kwargs): copy_p.def_abstract_eval(lambda x: x) mlir.register_lowering(copy_p, lambda ctx, x: [x]) ad.deflinear(copy_p, lambda t: [copy_p.bind(t)]) -pe.def_trivial_padding(copy_p) batching.defvectorized(copy_p) -def _propagate_mem_kind_copy(in_mem_kind): - return in_mem_kind -pxla.memory_kind_propagate_rule[copy_p] = _propagate_mem_kind_copy + +# The dce_sink_p primitive marks a value as "used" from the perspective of DCE +# so the computation producing it won't be eliminated. +def dce_sink(val): + tree_util.tree_map(dce_sink_p.bind, val) + +class NoDCEEffect(effects.Effect): + pass +no_dce_effect = NoDCEEffect() +effects.control_flow_allowed_effects.add_type(NoDCEEffect) +effects.lowerable_effects.add_type(NoDCEEffect) + +dce_sink_p = core.Primitive('dce_sink') +dce_sink_p.def_impl(lambda _: []) +dce_sink_p.multiple_results = True +dce_sink_p.def_effectful_abstract_eval(lambda _: ([], {no_dce_effect})) +mlir.register_lowering(dce_sink_p, lambda ctx, _: []) +ad.deflinear(dce_sink_p, lambda _: []) +batching.primitive_batchers[dce_sink_p] = lambda x, bd: (x, bd) def rng_bit_generator(key, shape, dtype=np.uint32, - algorithm=RandomAlgorithm.RNG_DEFAULT): + algorithm=RandomAlgorithm.RNG_DEFAULT, + *, out_sharding=None): """Stateless PRNG bit generator. Experimental and its use is discouraged. Returns uniformly distributed random bits with the specified shape and dtype @@ -8126,26 +8488,27 @@ def rng_bit_generator(key, shape, dtype=np.uint32, default algorithm or the one specified. It provides direct access to the RngBitGenerator primitive exposed by XLA - (https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator) for low + (https://www.openxla.org/xla/operation_semantics#rngbitgenerator) for low level API access. Most users should use `jax.random` instead for a stable and more user friendly API. """ shape = core.canonicalize_shape(shape) - dtype = dtypes.canonicalize_dtype(dtype) + dtype = dtypes.check_and_canonicalize_user_dtype(dtype, 'rng_bit_generator') + out_sharding = canonicalize_sharding(out_sharding, 'rng_bit_generator') if np.dtype(dtype) not in {np.dtype('uint8'), np.dtype('uint16'), np.dtype('uint32'), np.dtype('uint64')}: raise TypeError(f'rng_bit_generator: unsupported dtype {dtype}') return tuple( rng_bit_generator_p.bind( - key, shape=shape, dtype=dtype, algorithm=algorithm)) + key, shape=shape, dtype=dtype, algorithm=algorithm, + out_sharding=out_sharding)) -def _iota_abstract_eval(*dyn_shape, dtype, shape, dimension, sharding): - if not dyn_shape: - # TODO(mattjj) Generalize shape_like checking to permit dynamic shapes - _check_shapelike("iota", "shape", shape) +def _iota_abstract_eval(dtype, shape, dimension, sharding): + # TODO(mattjj) Generalize shape_like checking to permit dynamic shapes + _check_shapelike("iota", "shape", shape) if not any(dtypes.issubdtype(dtype, t) for t in _num): msg = 'iota does not accept dtype {}. Accepted dtypes are subtypes of {}.' typename = dtype_to_string(dtype) @@ -8154,85 +8517,35 @@ def _iota_abstract_eval(*dyn_shape, dtype, shape, dimension, sharding): if not 0 <= dimension < len(shape): raise ValueError("iota dimension must be between 0 and len(shape), got " f"{dimension=} for {shape=}") - if (not dyn_shape and - not any(isinstance(d, core.DArray) and - type(core.get_aval(d).dtype) is core.bint for d in shape)): - if sharding is None: - sharding = core.get_cur_mesh_sharding(spec=core.P(*[None] * len(shape))) - return ShapedArray(shape, dtype, sharding=sharding) - # TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code - return core.DShapedArray(_merge_dyn_shape(shape, dyn_shape), dtype, False) - + if sharding is None: + sharding = core.get_cur_mesh_sharding(spec=core.P(*[None] * len(shape))) + return ShapedArray(shape, dtype, sharding=sharding) iota_p = Primitive('iota') iota_p.def_impl(partial(dispatch.apply_primitive, iota_p)) iota_p.def_abstract_eval(_iota_abstract_eval) -batching.ragged_prop_rules[iota_p] = batching.ragged_mask_no_op_rule -def _iota_staging_rule(trace, *dyn_shape, dtype, shape, dimension, sharding): +def _iota_staging_rule(trace, source_info, dtype, shape, dimension, + sharding): params = dict(dtype=dtype, shape=shape, dimension=dimension, sharding=sharding) - if not dyn_shape: - return trace.default_process_primitive(iota_p, (), params) - aval = core.DShapedArray(_merge_dyn_shape(shape, dyn_shape), dtype, False) - return _dyn_shape_staging_rule(trace, iota_p, aval, *dyn_shape, **params) + return trace.default_process_primitive(iota_p, (), params, + source_info=source_info) pe.custom_staging_rules[iota_p] = _iota_staging_rule -def _iota_typecheck_rule(_, *dyn_shape, dtype, shape, dimension, sharding): - if not dyn_shape: - out_aval, effects = iota_p.abstract_eval( - dtype=dtype, shape=shape, dimension=dimension, sharding=sharding) - return [out_aval], effects - else: - out_shape = _merge_dyn_shape(shape, dyn_shape) - out_shape = [x.val if type(x) is core.Literal else x for x in out_shape] # pytype: disable=attribute-error - out_aval = core.DShapedArray(tuple(out_shape), dtype, False) - return [out_aval], core.no_effects +def _iota_typecheck_rule(_, dtype, shape, dimension, sharding): + out_aval, effects = iota_p.abstract_eval( + dtype=dtype, shape=shape, dimension=dimension, sharding=sharding) + return [out_aval], effects core.custom_typechecks[iota_p] = _iota_typecheck_rule -def _iota_lower(ctx, *dyn_shape, dtype, shape, dimension, sharding): +def _iota_lower(ctx, dtype, shape, dimension, sharding): del dtype aval_out, = ctx.avals_out - if dyn_shape: - aval_out = aval_out.update(shape=_merge_dyn_shape(shape, dyn_shape)) out = mlir.iota(ctx, aval_out, dimension=dimension) return [mlir.lower_with_sharding_in_types(ctx, out, aval_out)] mlir.register_lowering(iota_p, _iota_lower) -def _iota_batching_rule(in_vals, in_dims, *, dtype, shape, dimension, - sharding): - (segment_lengths,), (ax,) = in_vals, in_dims - assert ax == 0 - bound = segment_lengths.dtype.bound - ragged_axis, = (i for i, dim in enumerate(shape) if dim is None) - shape = (len(segment_lengths),) + _merge_dyn_shape(shape, (bound,)) - if sharding is not None: - raise NotImplementedError('Please file an issue if you want this support') - iota = broadcasted_iota(dtype, shape, dimension+1) - return iota, batching.RaggedAxis(ax, ((ragged_axis+1, segment_lengths),)) -batching.primitive_batchers[iota_p] = _iota_batching_rule - -def _iota_padding_rule(in_avals, out_avals, *dyn_shape, dtype, shape, dimension, - sharding): - out_aval, = out_avals - new_shape = [] - new_dyn_shape = [] - for d in out_aval.shape: - if type(d) is pe.BoundedAxisSize: - new_shape.append(d.bound) - elif type(d) is int: - new_shape.append(d) - else: - assert isinstance(d, core.Tracer) - new_shape.append(None) - new_dyn_shape.append(d) - if sharding is not None: - raise NotImplementedError('Please file an issue if you want this support') - return [iota_p.bind(*new_dyn_shape, shape=tuple(new_shape), - dtype=dtype, dimension=dimension, sharding=sharding)] -pe.padding_rules[iota_p] = _iota_padding_rule - - ### util _ndim = np.ndim @@ -8256,22 +8569,50 @@ class PaddingType(enum.Enum): SAME_LOWER = 3 -def padtype_to_pads(in_shape, window_shape, window_strides, padding): - """Convert padding string to list of pairs of pad values.""" +def padtype_to_pads( + in_shape: Sequence[int] | np.ndarray, + window_shape: Sequence[int] | np.ndarray, + window_strides: Sequence[int] | np.ndarray, + padding: str | PaddingType) -> list[tuple[int, int]]: + """Convert a padding specification to a list of pad value pairs. + + This utility resolves abstract convolution padding modes into concrete + per-dimension integer padding values based on the input and window geometry. + + Args: + in_shape: Sequence of integers specifying the input spatial shape. + window_shape: Sequence of integers specifying the kernel/window spatial shape. + window_strides: Sequence of integers specifying the spatial strides. + padding: Either a padding string (``'SAME'``, ``'SAME_LOWER'``, or ``'VALID'``) + or a ``PaddingType`` enum value. Other values will result in an error. + + Returns: + A list of ``(low, high)`` integer tuples, one for each spatial dimension, + specifying the padding to apply before and after each dimension. + + Raises: + RuntimeError: If ``padding`` is a string but not one of the supported values. + TypeError: If ``padding`` is not a supported string or ``PaddingType`` value. + Notes: + - ``'VALID'``: Returns zero padding ``(0, 0)`` for all dimensions. + - ``'SAME'``: Pads such that the output spatial shape is computed via + ceiling division of ``in_shape`` by ``window_strides``. If the required + padding amount is odd, the extra padding is added to the **end** + (high side) of the dimension. + - ``'SAME_LOWER'``: Similar to ``'SAME'``, but if the required padding + amount is odd, the extra padding is added to the **start** + (low side) of the dimension. + """ if isinstance(padding, str): - mapping = { - 'VALID': PaddingType.VALID, - 'SAME': PaddingType.SAME, - 'SAME_LOWER': PaddingType.SAME_LOWER, - } try: - padding = mapping[padding.upper()] + padding = PaddingType[padding.upper()] except KeyError as err: - msg = "Unrecognized padding type: expected 'VALID' or 'SAME', got {}." - raise RuntimeError(msg.format(padding)) from err + raise RuntimeError( + f"Unrecognized padding type: expected 'VALID', 'SAME', or 'SAME_LOWER', got {padding}." + ) from err - if padding == PaddingType.SAME or padding == PaddingType.SAME_LOWER: + if padding in (PaddingType.SAME, PaddingType.SAME_LOWER): out_shape = _ceil_divide(in_shape, window_strides) pad_sizes = (core.max_dim(d, 0) for d in (out_shape - 1) * window_strides + @@ -8285,12 +8626,11 @@ def padtype_to_pads(in_shape, window_shape, window_strides, padding): (pad_size - pad_size // 2, pad_size // 2) for pad_size in pad_sizes ] # Avoids verbose numpy scalars in jaxprs. - return [p.item() if isinstance(p, np.generic) else p for p in pads] + return tree_util.tree_map(lambda x: x.item() if isinstance(x, np.generic) else x, pads) elif padding == PaddingType.VALID: return [(0, 0)] * len(in_shape) else: - msg = "Unknown padding type: {}." - raise TypeError(msg.format(padding)) + raise TypeError(f"Unknown padding type: {padding}.") # Map of lax function to equivalent jax.numpy function for use in error string below. @@ -8349,7 +8689,7 @@ def padtype_to_pads(in_shape, window_shape, window_strides, padding): 'tanh': 'tanh' } -def check_same_dtypes(name: str, *avals: core.UnshapedArray) -> None: +def check_same_dtypes(name: str, *avals: ShapedArray) -> None: """Check that dtypes agree, possibly ignoring float precision.""" # the `ignore_fp_precision` flag exists because the XLA shape inference logic # allows mixed floating point precision, but the HLO verifier often rejects it @@ -8358,8 +8698,8 @@ def check_same_dtypes(name: str, *avals: core.UnshapedArray) -> None: if len(avals) < 2: return - dtype = dtypes.canonicalize_dtype(avals[0].dtype) - if any(dtypes.canonicalize_dtype(aval.dtype) != dtype for aval in avals[1:]): + dtype = avals[0].dtype + if any(aval.dtype != dtype for aval in avals[1:]): msg = "lax.{} requires arguments to have the same dtypes, got {}." if name in _JNP_FUNCTION_EQUIVALENTS: equiv = _JNP_FUNCTION_EQUIVALENTS[name] @@ -8375,9 +8715,6 @@ def _check_shapelike(fun_name, arg_name, obj, non_zero_shape=False): # bool(obj) for an ndarray raises an error, so we check len if not len(obj): # pylint: disable=g-explicit-length-test return - if (config.dynamic_shapes.value and isinstance(obj, (tuple, list)) and - any(isinstance(d, (core.Tracer, core.DArray)) for d in obj)): - return # TODO(mattjj): handle more checks in the dynamic shape case obj_arr = np.array(obj) if obj_arr.ndim != 1: msg = "{} {} must be 1-dimensional, got {}." @@ -8399,27 +8736,35 @@ def _const(example, val): if dtypes.is_python_scalar(example): val = dtypes.scalar_type_of(example)(val) return val if dtype == _dtype(val) else np.array(val, dtype) - return np.array(val, dtype) + return literals.TypedNdArray(np.array(val, dtype), weak_type=False) _zeros: Callable = partial(full_like, fill_value=0) def _zero(x): x_aval = core.get_aval(x) - return full_like(x, shape=(), fill_value=0, - sharding=x_aval.sharding.with_spec(P())) + out = full_like(x, shape=(), fill_value=0, + sharding=x_aval.sharding.update(spec=P())) + return out _ones: Callable = partial(full_like, fill_value=1) def _one(x): x_aval = core.get_aval(x) - return full_like(x, shape=(), fill_value=1, - sharding=x_aval.sharding.with_spec(P())) + out = full_like(x, shape=(), fill_value=1, + sharding=x_aval.sharding.update(spec=P())) + return out + +def _one_vjp(x): + x_aval = core.get_aval(x) + ct_s = core.primal_sharding_to_cotangent_sharding(x_aval.sharding) + ct_s = ct_s.update(spec=ct_s.spec.update(partitions=())) + return full_like(x, shape=(), fill_value=1, sharding=ct_s) _twos: Callable = partial(full_like, fill_value=2) _two: Callable = partial(full_like, shape=(), fill_value=2) -dtype: Callable = partial(dtypes.dtype, canonicalize=True) -_dtype: Callable = partial(dtypes.dtype, canonicalize=True) +dtype: Callable = dtypes.dtype +_dtype: Callable = dtypes.dtype def _isnan(x: ArrayLike) -> Array: return ne(x, x) @@ -8506,16 +8851,92 @@ def _eq_meet(a, b): return eq(a, b) -def empty(dtype): - return empty_p.bind(dtype=dtype) +def empty(shape, dtype, *, out_sharding=None): + """Create an empty array of possibly uninitialized values. + + This initialization is backend dependent. + + Args: + shape: int or sequence of ints specifying the shape of the created array. + dtype: dtype for the created array. + out_sharding: (optional) :class:`~jax.sharding.PartitionSpec` or + :class:`~jax.NamedSharding` representing the sharding of the created + array (see `explicit sharding`_ for more details). + + Returns: + Uninitialized array of the specified shape, dtype, and sharding. + + Examples: + >>> jnp.empty(3, jnp.float32) # doctest: +SKIP + Array([-5.7326739e+29 -7.7323739e+29 -3.14159256e-29], dtype=float32) + + .. _explicit sharding: https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html + """ + out_sharding = canonicalize_sharding(out_sharding, 'lax.empty') + return empty_p.bind(shape=shape, dtype=dtype, out_sharding=out_sharding) + empty_p = core.Primitive('empty') -empty_p.def_abstract_eval(lambda *, dtype: core.ShapedArray((), dtype)) -def _empty_lower(ctx, *, dtype): +empty_p.def_impl(partial(dispatch.apply_primitive, empty_p)) + +def _empty_abstract_eval(*, shape, dtype, out_sharding): + return core.ShapedArray(shape, dtype, sharding=out_sharding) +empty_p.def_abstract_eval(_empty_abstract_eval) + +def _empty_custom_call_lower(ctx, *, shape, dtype, out_sharding): + if not core.is_constant_shape(shape): + return _empty_lower(ctx, shape=shape, dtype=dtype, out_sharding=out_sharding) dtype = dtype if dtypes.issubdtype(dtype, dtypes.extended) else np.dtype(dtype) - phys_aval = core.physical_aval(core.ShapedArray((), dtype)) - return mlir.ir_constant(np.zeros(phys_aval.shape, phys_aval.dtype)), + aval_out = core.ShapedArray(shape, dtype, sharding=out_sharding) + phys_aval = core.physical_aval(aval_out) + custom_call_op = hlo.CustomCallOp( + [mlir.ir.RankedTensorType.get( + list(phys_aval.shape), mlir.dtype_to_ir_type(phys_aval.dtype))], + [], + call_target_name=mlir.ir.StringAttr.get("AllocateBuffer"), + has_side_effect=mlir.ir.BoolAttr.get(False), + ) + assert len(custom_call_op.results) == 1 + res = custom_call_op.results[0] + return [mlir.lower_with_sharding_in_types(ctx, res, phys_aval)] +mlir.register_lowering(empty_p, _empty_custom_call_lower, 'tpu') +mlir.register_lowering(empty_p, _empty_custom_call_lower, 'gpu') + +def _empty_lower(ctx, *, shape, dtype, out_sharding): + dtype = dtype if dtypes.issubdtype(dtype, dtypes.extended) else np.dtype(dtype) + aval_out = core.ShapedArray(shape, dtype, sharding=out_sharding) + phys_aval = core.physical_aval(aval_out) + out = mlir.ir_constant(np.zeros((), phys_aval.dtype)) + out = mlir.broadcast_in_dim(ctx, out, phys_aval, broadcast_dimensions=[]) + return [mlir.lower_with_sharding_in_types(ctx, out, phys_aval)] mlir.register_lowering(empty_p, _empty_lower) +def _empty_batcher(axis_data, vals_in, dims_in, *, shape, dtype, out_sharding): + batched_shape = tuple_insert(shape, 0, axis_data.size) + batched_out_sharding = ( + None if out_sharding is None else + batching.get_sharding_for_vmap(axis_data, out_sharding, 0)) + y = empty_p.bind(shape=batched_shape, dtype=dtype, + out_sharding=batched_out_sharding) + return y, 0 +batching.fancy_primitive_batchers[empty_p] = _empty_batcher + +# TODO(yashkatariya): Delete `empty2` and replace scan's usage with `empty` once +# AllocateBuffer issues are fixed +def empty2(dtype, *, memory_space): + return empty2_p.bind(dtype=dtype, memory_space=memory_space) +empty2_p = core.Primitive('empty2') +dispatch.simple_impl(empty2_p) + +def _empty2_abstract_eval(*, dtype, memory_space): + return core.ShapedArray((), dtype, memory_space=memory_space) +empty2_p.def_abstract_eval(_empty2_abstract_eval) + +def _empty2_lower(ctx, *, dtype, memory_space): + dtype = dtype if dtypes.issubdtype(dtype, dtypes.extended) else np.dtype(dtype) + phys_aval = core.physical_aval(core.ShapedArray((), dtype)) + return [mlir.ir_constant(np.zeros(phys_aval.shape, phys_aval.dtype))] +mlir.register_lowering(empty2_p, _empty2_lower) + tie_p = core.Primitive('tie') tie_p.def_impl(lambda x, y: y) @@ -8524,51 +8945,23 @@ def _empty_lower(ctx, *, dtype): ad.primitive_jvps[tie_p] = \ lambda primals, tangents: (tie_p.bind(*primals), tangents[-1]) ad.primitive_transposes[tie_p] = lambda ct, x, _: [None, ct] -pe.def_trivial_padding(tie_p) batching.defvectorized(tie_p) -class BIntRules: - allow_conversion: bool = True - - @staticmethod - def physical_element_aval(dtype) -> core.ShapedArray: - return core.ShapedArray((), np.dtype('int32')) - - @staticmethod - def result_handler(sticky_device, aval): - def handler(_, buf): - buf.aval = core.ShapedArray(buf.shape, buf.dtype) - return core.DArray(aval, buf) - return handler - - @staticmethod - def global_sharded_result_handler(aval, out_sharding, committed): - phys_aval = core.physical_aval(aval) - phys_handler_maker = pxla.global_result_handlers[core.ShapedArray] - - if not dispatch.is_single_device_sharding(out_sharding): - raise NotImplementedError # TODO(mattjj) - else: - phys_sharding = out_sharding - phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed) - - def handler(bufs): - return core.DArray(aval, phys_handler(bufs)) - return handler - - -core.bint._rules = BIntRules - - def optimization_barrier(operand, /): """Prevents the compiler from moving operations across the barrier. Optimization barriers have a number of possible uses: - * An optimization barrier ensures that all inputs are evaluated before any - operators that depend on the barrier's outputs. This can be used to enforce - a particular order of operations. + * An optimization barrier ensures that every output of the barrier that is + used by any operator, has been evaluated before any operator that depends + on one of the barrier's outputs. This can be used to enforce a particular + order of operations. + + Note that all operands must be used through the barrier for this to work. + There are no ordering constraints between an operator that uses one of the + barrier's outputs, and an operator that directly (not through the barrier) + uses one of the barrier's inputs. * An optimization barrier prevents common subexpression elimination. This is used by JAX to implement rematerialization. * Optimization barriers prevent compiler fusions. That is, operations before @@ -8594,11 +8987,13 @@ def optimization_barrier(operand, /): Array(0., dtype=float32, weak_type=True) """ flat_args, treedef = tree_util.tree_flatten(operand) - return tree_util.tree_unflatten( - treedef, optimization_barrier_p.bind(*flat_args)) + flat_args = core.standard_insert_pvary(*flat_args) + out = optimization_barrier_p.bind(*flat_args) + return tree_util.tree_unflatten(treedef, out) def _optimization_barrier_abstract_eval(*args): + core.standard_vma_rule('optimization_barrier', *args) return args def _optimization_barrier_lowering_rule(ctx, *args): diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index c674401fb80d..94d7b18694d0 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -20,11 +20,10 @@ import math import string from typing import Any, Literal, overload +import warnings import numpy as np -from jax import lax - from jax._src import ad_util from jax._src import api from jax._src import config @@ -39,16 +38,14 @@ from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.lax import control_flow -from jax._src.lax import eigh as lax_eigh -from jax._src.lax import lax as lax_internal -from jax._src.lax import svd as lax_svd +from jax._src.lax import lax from jax._src.lax import utils as lax_utils from jax._src.lax.lax import _float, _complex, _int +from jax._src.lib import cuda_versions from jax._src.lib import gpu_linalg from jax._src.lib import gpu_solver from jax._src.lib import gpu_sparse from jax._src.lib import lapack -from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import chlo from jax._src.lib.mlir.dialects import hlo @@ -121,14 +118,23 @@ def cholesky_update(r_matrix: ArrayLike, w_vector: ArrayLike) -> Array: A new upper-triangular matrix :math:`R` defining the Cholesky decomposition of :math:`A + w \, w^T`. """ + r_matrix, w_vector = core.standard_insert_pvary(r_matrix, w_vector) return cholesky_update_p.bind(r_matrix, w_vector) +class EigImplementation(enum.Enum): + """Enum for eigendecomposition algorithm.""" + CUSOLVER = "cusolver" + MAGMA = "magma" + LAPACK = "lapack" + + def eig( x: ArrayLike, *, compute_left_eigenvectors: bool = True, compute_right_eigenvectors: bool = True, + implementation: EigImplementation | None = None, use_magma: bool | None = None, ) -> list[Array]: """Eigendecomposition of a general matrix. @@ -163,11 +169,22 @@ def eig( compute_left_eigenvectors: If true, the left eigenvectors will be computed. compute_right_eigenvectors: If true, the right eigenvectors will be computed. - use_magma: Locally override the ``jax_use_magma`` flag. If ``True``, the - eigendecomposition is computed using MAGMA. If ``False``, the computation - is done using LAPACK on to the host CPU. If ``None`` (default), the - behavior is controlled by the ``jax_use_magma`` flag. This argument - is only used on GPU. + use_magma: Deprecated, please use ``implementation`` instead. Locally + override the ``jax_use_magma`` flag. If ``True``, the eigendecomposition + is computed using MAGMA. If ``False``, the computation is done using + LAPACK on to the host CPU. If ``None`` (default), the behavior is + controlled by the ``jax_use_magma`` flag. This argument is only used on + GPU. Will be removed in JAX 0.9. + implementation: Controls the choice of eigendecomposition algorithm. If + ``LAPACK``, the computation will be performed using LAPACK on the host CPU. + If ``MAGMA``, the computation will be performed using the MAGMA library on + the GPU. If ``CUSOLVER``, the computation will be performed using the + Cusolver library on the GPU. The ``CUSOLVER`` implementation requires + Cusolver 11.7.1 (from CUDA 12.6 update 2) to be installed, and does not + support computing left eigenvectors. + If ``None`` (default), an automatic choice will be made, depending on the + Cusolver version, whether left eigenvectors were requested, and the + ``jax_use_magma`` configuration variable. Returns: The eigendecomposition of ``x``, which is a tuple of the form @@ -179,9 +196,26 @@ def eig( If the eigendecomposition fails, then arrays full of NaNs will be returned for that batch element. """ + if use_magma is not None: + warnings.warn( + "use_magma is deprecated, please use" + " implementation=EigImplementation.MAGMA instead.", + DeprecationWarning, + stacklevel=2, + ) + implementation = ( + EigImplementation.MAGMA if use_magma else EigImplementation.LAPACK + ) return eig_p.bind(x, compute_left_eigenvectors=compute_left_eigenvectors, compute_right_eigenvectors=compute_right_eigenvectors, - use_magma=use_magma) + implementation=implementation) + + +class EighImplementation(enum.Enum): + """Implementation for symmetric/Hermitian eigendecomposition.""" + QR = "qr" + JACOBI = "jacobi" + QDWH = "qdwh" def eigh( @@ -191,6 +225,7 @@ def eigh( symmetrize_input: bool = True, sort_eigenvalues: bool = True, subset_by_index: tuple[int, int] | None = None, + implementation: EighImplementation | None = None, ) -> tuple[Array, Array]: r"""Eigendecomposition of a Hermitian matrix. @@ -213,6 +248,10 @@ def eigh( indices of eigenvalues to compute. For example, is ``range_select`` = [n-2,n], then ``eigh`` computes the two largest eigenvalues and their eigenvectors. + implementation: Optional implementation selection. ``QR`` uses QR-based + decomposition (default for CPU/GPU). ``JACOBI`` uses Jacobi iteration + (GPU/TPU only). ``QDWH`` uses QDWH spectral divide-and-conquer + (default on TPU, TPU only). Returns: A tuple ``(v, w)``. @@ -233,6 +272,7 @@ def eigh( lower=lower, sort_eigenvalues=sort_eigenvalues, subset_by_index=subset_by_index, + algorithm=implementation, ) return v, w @@ -268,6 +308,7 @@ def householder_product(a: ArrayLike, taus: ArrayLike) -> Array: A batch of orthogonal (unitary) matrices with the same shape as ``a``, containing the products of the elementary Householder reflectors. """ + a, taus = core.standard_insert_pvary(a, taus) return householder_product_p.bind(a, taus) @@ -432,6 +473,7 @@ class SvdAlgorithm(enum.Enum): DEFAULT = "default" QR = "QR" JACOBI = "Jacobi" + POLAR = "polar" @overload @@ -526,7 +568,7 @@ def symmetric_product( Computes the symmetric product - ..math:: + .. math:: \alpha \, A \, A^T + \beta \, C where :math:`A` is a rectangular matrix and :math:`C` is a symmetric matrix. @@ -545,6 +587,7 @@ def symmetric_product( ``symmetrize_output`` is ``True``, the upper triangle is filled with the transpose of the lower triangle, and the whole matrix is valid. """ + a_matrix, c_matrix = core.standard_insert_pvary(a_matrix, c_matrix) result = symmetric_product_p.bind(a_matrix, c_matrix, alpha=alpha, beta=beta) if symmetrize_output: upper_half = lax.transpose( @@ -602,6 +645,7 @@ def triangular_solve( singleton = np.ndim(b) == np.ndim(a) - 1 if singleton: b = lax.expand_dims(b, (-1 if left_side else -2,)) + a, b = core.standard_insert_pvary(a, b) out = triangular_solve_p.bind( a, b, left_side=left_side, lower=lower, transpose_a=transpose_a, conjugate_a=conjugate_a, unit_diagonal=unit_diagonal) @@ -635,7 +679,7 @@ def tridiagonal( superdiagonal. ``taus`` contains the scalar factors of the elementary Householder reflectors. """ - return tridiagonal_p.bind(lax_internal.asarray(a), lower=lower) + return tridiagonal_p.bind(lax.asarray(a), lower=lower) def tridiagonal_solve(dl: Array, d: Array, du: Array, b: Array) -> Array: @@ -661,6 +705,7 @@ def tridiagonal_solve(dl: Array, d: Array, du: Array, b: Array) -> Array: Returns: Solution ``X`` of tridiagonal system. """ + dl, d, du, b = core.standard_insert_pvary(dl, d, du, b) return tridiagonal_solve_p.bind(dl, d, du, b) @@ -717,34 +762,42 @@ def linalg_sharding_rule( spec = aval.sharding.spec batch_spec, rest_spec = spec[:len(spec) - rank], spec[len(spec) - rank:] if not all(s is None for s in rest_spec): - raise ValueError( + raise core.ShardingTypeError( f"Input {i} to {name} must be unsharded on non-batch dimensions, " f"but got {spec}." ) batch_specs.append(batch_spec) batch_spec = batch_specs[0] if any(b != batch_spec for b in batch_specs[1:]): - raise ValueError( + raise core.ShardingTypeError( f"All inputs to {name} must have the same batch sharding, but got " f"{batch_specs}." ) sharding = avals[0].sharding if multiple_results: return [ - sharding.with_spec( + sharding.update(spec= P(*(tuple(batch_spec) + (None,) * (len(s) - len(batch_spec)))) ) for s in output_shapes ] else: ndim = len(output_shapes) - len(batch_spec) - return sharding.with_spec(P(*(tuple(batch_spec) + (None,) * ndim))) + return sharding.update(spec=P(*(tuple(batch_spec) + (None,) * ndim))) + +def linalg_vma_rule(multiple_results, shape_rule, name, *avals, **kwargs): + output_shapes = shape_rule(*avals, **kwargs) + out_vma = core.standard_vma_rule(name, *avals) + if multiple_results: + return [out_vma] * len(output_shapes) + else: + return out_vma def linalg_primitive(result_dtype, accepted_dtypes, ranks, result_shape, name, multiple_results=False, supports_batching=True, require_same=True): dtype_rule = partial( - lax_internal.naryop_dtype_rule, result_dtype, accepted_dtypes, name, + lax.naryop_dtype_rule, result_dtype, accepted_dtypes, name, require_same=require_same) shape_rule = partial( linalg_shape_rule, multiple_results, supports_batching, ranks, @@ -754,6 +807,7 @@ def linalg_primitive(result_dtype, accepted_dtypes, ranks, result_shape, name, linalg_sharding_rule, multiple_results, shape_rule, ranks, name) else: sharding_rule = None + vma_rule = partial(linalg_vma_rule, multiple_results, shape_rule, name) prim = core.Primitive(name) prim.multiple_results = multiple_results prim.def_impl(partial(dispatch.apply_primitive, prim)) @@ -761,17 +815,18 @@ def linalg_primitive(result_dtype, accepted_dtypes, ranks, result_shape, name, prim.def_abstract_eval( partial(lax_utils.standard_multi_result_abstract_eval, prim, shape_rule, dtype_rule, lax_utils._standard_weak_type_rule, - sharding_rule)) + sharding_rule, vma_rule)) else: prim.def_abstract_eval( partial(lax_utils.standard_abstract_eval, prim, shape_rule, dtype_rule, - lax_utils._standard_weak_type_rule, sharding_rule)) + lax_utils._standard_weak_type_rule, sharding_rule, + partial(core.standard_vma_rule, name), None, None, None)) if supports_batching: batching.primitive_batchers[prim] = partial( batching.expand_dims_batcher, prim) return prim -standard_linalg_primitive = partial(linalg_primitive, lax_internal._input_dtype) +standard_linalg_primitive = partial(linalg_primitive, lax.input_dtype) # Primitive implementations @@ -794,7 +849,7 @@ def _cholesky_jvp_rule(primals, tangents): def phi(X): l = _tril(X) return l / lax.expand_dims( - lax_internal._const(X, 1) + lax_internal._eye(X.dtype, (X.shape[-1], X.shape[-1])), + lax._const(X, 1) + lax._eye(X.dtype, (X.shape[-1], X.shape[-1])), range(l.ndim - 2)) tmp = triangular_solve(L, sigma_dot, left_side=False, transpose_a=True, @@ -824,11 +879,27 @@ def _cholesky_cpu_lowering(ctx, operand): return [_replace_not_ok_with_nan(ctx, batch_dims, ok, result, out_aval)] +def _cholesky_gpu_lowering(ctx, operand, *, target_name_prefix): + operand_aval, = ctx.avals_in + out_aval, = ctx.avals_out + batch_dims = operand_aval.shape[:-2] + info_aval = ShapedArray(batch_dims, np.int32) + rule = _linalg_ffi_lowering(f"{target_name_prefix}solver_potrf_ffi", + avals_out=[operand_aval, info_aval], + operand_output_aliases={0: 0}) + result, info = rule(ctx, operand, lower=True) + ok = mlir.compare_hlo(info, mlir.full_like_aval(ctx, 0, info_aval), "EQ", + "SIGNED") + return [_replace_not_ok_with_nan(ctx, batch_dims, ok, result, out_aval)] + + cholesky_p = standard_linalg_primitive( (_float | _complex,), (2,), _cholesky_shape_rule, "cholesky") ad.primitive_jvps[cholesky_p] = _cholesky_jvp_rule mlir.register_lowering(cholesky_p, _cholesky_lowering) mlir.register_lowering(cholesky_p, _cholesky_cpu_lowering, platform="cpu") +register_cpu_gpu_lowering(cholesky_p, _cholesky_gpu_lowering, + supported_platforms=("cuda", "rocm")) # Cholesky update @@ -857,7 +928,8 @@ def _drotg_nonzero(x, y): np.array(1., dtype=x.dtype), np.array(0., dtype=x.dtype), ) - return lax.cond(y == 0, lambda x, y: one_and_zero, _drotg_nonzero, x, y) + return control_flow.cond( + y == 0, lambda x, y: one_and_zero, _drotg_nonzero, x, y) def _drot( first_vector: Array, second_vector: Array, @@ -896,7 +968,7 @@ def _cholesky_update_gpu_lowering_rule(target_name_prefix, ctx, r_matrix, def _eig_dtype_rule( a_dtype, *, compute_left_eigenvectors, compute_right_eigenvectors, **_ ): - dtype = dtypes.to_complex_dtype(dtypes.canonicalize_dtype(a_dtype)) + dtype = dtypes.to_complex_dtype(a_dtype) return (dtype,) * (1 + compute_left_eigenvectors + compute_right_eigenvectors) def _eig_shape_rule( @@ -915,8 +987,9 @@ def _eig_compute_attr(compute): ) def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors, - compute_right_eigenvectors, use_magma): - del use_magma # unused + compute_right_eigenvectors, implementation): + if implementation and implementation != EigImplementation.LAPACK: + raise ValueError("Only the lapack implementation is supported on CPU.") operand_aval, = ctx.avals_in out_aval = ctx.avals_out[0] batch_dims = operand_aval.shape[:-2] @@ -950,48 +1023,134 @@ def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors, output.append(vr) return output +def _unpack_conjugate_pairs(w, vr): + # cusolver, like LAPACK, uses a packed representation of the complex + # eigenvectors, where the (re, im) vectors are adjacent and shared by the + # conjugate pair: + # https://docs.nvidia.com/cuda/cusolver/index.html?highlight=geev#cusolverdnxgeev + if w.size == 0: + return lax.complex(vr, lax.full_like(vr, 0)) + + is_real = ((w.imag == 0) | (w.imag == np.nan)) + # Finds the positions at which each conjugate pair starts, via the parity of + # the count of the number of complex numbers seen. + conj_pair_start = control_flow.cumsum((~is_real).astype(int), + axis=len(w.shape) - 1) + conj_pair_start = conj_pair_start % 2 == 1 + pads = [(0, 0, 0)] * (len(vr.shape)) + pads[-1] = (-1, 1, 0) + vr_shifted_left = lax.pad(vr, lax._zero(vr), pads) + pads[-1] = (1, -1, 0) + vr_shifted_right = lax.pad(vr, lax._zero(vr), pads) + dims = np.delete(np.arange(len(vr.shape), dtype=np.int32), -2) + is_real = lax.broadcast_in_dim(is_real, vr.shape, broadcast_dimensions=dims) + conj_pair_start = lax.broadcast_in_dim(conj_pair_start, vr.shape, + broadcast_dimensions=dims) + re = lax.select(is_real | conj_pair_start, vr, vr_shifted_right) + im = lax.select(conj_pair_start, vr_shifted_left, -vr) + im = lax.select(is_real, lax.full_like(vr, 0), im) + return lax.complex(re, im) + + def _eig_gpu_lowering(ctx, operand, *, compute_left_eigenvectors, compute_right_eigenvectors, - use_magma, target_name_prefix): + implementation, target_name_prefix): operand_aval, = ctx.avals_in batch_dims = operand_aval.shape[:-2] n, m = operand_aval.shape[-2:] assert n == m - gpu_solver.initialize_hybrid_kernels() dtype = operand_aval.dtype - is_real = dtype == np.float32 or dtype == np.float64 - if is_real: - target_name = f"{target_name_prefix}hybrid_eig_real" - complex_dtype = np.complex64 if dtype == np.float32 else np.complex128 + complex_dtype = np.result_type(dtype, 1j) + if dtype in (np.float32, np.float64): + is_real = True + elif dtype in (np.complex64, np.complex128): + is_real = False else: - target_name = f"{target_name_prefix}hybrid_eig_comp" - assert dtype == np.complex64 or dtype == np.complex128 - complex_dtype = dtype - - avals_out = [ - ShapedArray(batch_dims + (n,), dtype), - ShapedArray(batch_dims + (n, n), complex_dtype), - ShapedArray(batch_dims + (n, n), complex_dtype), - ShapedArray(batch_dims, np.int32), - ] - if is_real: - avals_out = [ShapedArray(batch_dims + (n,), dtype)] + avals_out + raise ValueError(f"Unsupported dtype: {dtype}") - magma = config.gpu_use_magma.value - if use_magma is not None: - magma = "on" if use_magma else "off" + have_cusolver_geev = ( + target_name_prefix == "cu" + and cuda_versions + and cuda_versions.cusolver_get_version() >= 11701 + ) - rule = _linalg_ffi_lowering(target_name, avals_out=avals_out) - *w, vl, vr, info = rule(ctx, operand, magma=magma, - left=compute_left_eigenvectors, - right=compute_right_eigenvectors) - if is_real: - assert len(w) == 2 - w = hlo.complex(*w) + if ( + implementation is None and have_cusolver_geev + and not compute_left_eigenvectors + ) or implementation == EigImplementation.CUSOLVER: + if not have_cusolver_geev: + raise RuntimeError( + "Nonsymmetric eigendecomposition requires cusolver 11.7.1 or newer" + ) + if compute_left_eigenvectors: + raise NotImplementedError( + "Left eigenvectors are not supported by cusolver") + target_name = f"{target_name_prefix}solver_geev_ffi" + avals_out = [ + ShapedArray(batch_dims + (n, n), dtype), + ShapedArray(batch_dims + (n,), complex_dtype), + ShapedArray(batch_dims + (n, n), dtype), + ShapedArray(batch_dims + (n, n), dtype), + ShapedArray(batch_dims, np.int32), + ] + + rule = _linalg_ffi_lowering(target_name, avals_out=avals_out) + _, w, vl, vr, info = rule(ctx, operand, left=compute_left_eigenvectors, + right=compute_right_eigenvectors) + if is_real: + unpack = mlir.lower_fun(_unpack_conjugate_pairs, multiple_results=False) + if compute_left_eigenvectors: + sub_ctx = ctx.replace( + primitive=None, + avals_in=[ + ShapedArray(batch_dims + (n,), complex_dtype), + ShapedArray(batch_dims + (n, n), dtype), + ], + avals_out=[ShapedArray(batch_dims + (n, n), complex_dtype)], + ) + vl, = unpack(sub_ctx, w, vl) + if compute_right_eigenvectors: + sub_ctx = ctx.replace( + primitive=None, + avals_in=[ + ShapedArray(batch_dims + (n,), complex_dtype), + ShapedArray(batch_dims + (n, n), dtype), + ], + avals_out=[ShapedArray(batch_dims + (n, n), complex_dtype)], + ) + vr, = unpack(sub_ctx, w, vr) else: - assert len(w) == 1 - w = w[0] + magma = config.gpu_use_magma.value + if implementation is not None: + magma = "on" if implementation == EigImplementation.MAGMA else "off" + gpu_solver.initialize_hybrid_kernels() + if is_real: + target_name = f"{target_name_prefix}hybrid_eig_real" + complex_dtype = np.complex64 if dtype == np.float32 else np.complex128 + else: + target_name = f"{target_name_prefix}hybrid_eig_comp" + assert dtype == np.complex64 or dtype == np.complex128 + complex_dtype = dtype + + avals_out = [ + ShapedArray(batch_dims + (n,), dtype), + ShapedArray(batch_dims + (n, n), complex_dtype), + ShapedArray(batch_dims + (n, n), complex_dtype), + ShapedArray(batch_dims, np.int32), + ] + if is_real: + avals_out = [ShapedArray(batch_dims + (n,), dtype)] + avals_out + rule = _linalg_ffi_lowering(target_name, avals_out=avals_out) + *w, vl, vr, info = rule(ctx, operand, magma=magma, + left=compute_left_eigenvectors, + right=compute_right_eigenvectors) + if is_real: + assert len(w) == 2 + w = hlo.complex(*w) + else: + assert len(w) == 1 + w = w[0] zeros = mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.int32)) ok = mlir.compare_hlo(info, zeros, "EQ", "SIGNED") w_aval = ShapedArray(batch_dims + (n,), complex_dtype) @@ -1008,18 +1167,17 @@ def _eig_gpu_lowering(ctx, operand, *, return output def eig_jvp_rule(primals, tangents, *, compute_left_eigenvectors, - compute_right_eigenvectors, use_magma): - del use_magma # unused + compute_right_eigenvectors, implementation): if compute_left_eigenvectors or compute_right_eigenvectors: raise NotImplementedError( - 'The derivatives of eigenvectors are not implemented, only ' - 'eigenvalues. See ' + 'The derivatives of non-symmetric eigenvectors are not supported. ' + 'Only first-order derivatives of eigenvalues are supported. See ' 'https://github.com/jax-ml/jax/issues/2748 for discussion.') # Formula for derivative of eigenvalues w.r.t. a is eqn 4.60 in # https://arxiv.org/abs/1701.00392 a, = primals da, = tangents - l, v = eig(a, compute_left_eigenvectors=False) + l, v = eig(a, compute_left_eigenvectors=False, implementation=implementation) return [l], [(_solve(v, da.astype(v.dtype)) * _T(v)).sum(-1)] eig_p = linalg_primitive( @@ -1032,67 +1190,6 @@ def eig_jvp_rule(primals, tangents, *, compute_left_eigenvectors, # Symmetric/Hermitian eigendecomposition -def eigh_jacobi(x: ArrayLike, *, lower: bool = True, - sort_eigenvalues: bool = True) -> tuple[Array, Array]: - """Helper Jacobi eigendecomposition implemented by XLA. - - Used as a subroutine of QDWH-eig on TPU. - """ - return eigh_jacobi_p.bind(x, lower=lower, sort_eigenvalues=sort_eigenvalues) - -def _eigh_jacobi_shape_rule(shape, **_): - if shape[0] != shape[-1]: - raise ValueError( - "Argument to symmetric eigendecomposition must have shape [..., n, n], " - f"got shape {shape}" - ) - n = shape[0] - return (n,), (n, n) - -def _eigh_jacobi_dtype_rule(dtype, **_): - dtype = dtypes.canonicalize_dtype(dtype) - return lax_internal._complex_basetype(dtype), dtype - -def _eigh_jacobi_lowering_rule(ctx, operand, lower, sort_eigenvalues): - operand_aval, = ctx.avals_in - if operand_aval.shape[-1] == 0: - reshape_aval = operand_aval.update(shape=operand_aval.shape[:-1]) - return [ - hlo.real(mlir.reshape(ctx, operand, reshape_aval)), - operand, - ] - - eigvals_type = mlir.aval_to_ir_type(ctx.avals_out[0]) - eigvecs_type = mlir.aval_to_ir_type(ctx.avals_out[1]) - result_types = [eigvecs_type, eigvals_type] - - backend_config = f"{int(lower)},{int(sort_eigenvalues)},100,1e-6" - - if any(not is_constant_shape(aval_out.shape) - for aval_out in ctx.avals_out): - result_shapes = [ - mlir.eval_dynamic_shape_as_tensor(ctx, aval_out.shape) - # The custom call returns the results swapped - for aval_out in list(reversed(ctx.avals_out)) - ] - else: - result_shapes = None - op = mlir.custom_call( - "Eigh", - result_types=result_types, - operands=[operand], - backend_config=backend_config, - api_version=1, - result_shapes=result_shapes, - ) - return op.results[1], op.results[0] - -eigh_jacobi_p = linalg_primitive( - _eigh_jacobi_dtype_rule, (_float | _complex,), (2,), - _eigh_jacobi_shape_rule, "eigh_jacobi", multiple_results=True) -mlir.register_lowering(eigh_jacobi_p, _eigh_jacobi_lowering_rule) - - def _eigh_shape_rule(shape, *, subset_by_index, **_): if shape[0] != shape[-1]: raise ValueError( @@ -1105,11 +1202,10 @@ def _eigh_shape_rule(shape, *, subset_by_index, **_): return (n, d), (d,) def _eigh_dtype_rule(dtype, **_): - dtype = dtypes.canonicalize_dtype(dtype) - return dtype, lax_internal._complex_basetype(dtype) + return dtype, lax._complex_basetype(dtype) def _eigh_cpu_gpu_lowering( - ctx, operand, *, lower, sort_eigenvalues, subset_by_index, + ctx, operand, *, lower, sort_eigenvalues, subset_by_index, algorithm, target_name_prefix: str ): del sort_eigenvalues # The CPU/GPU implementations always sort. @@ -1119,6 +1215,12 @@ def _eigh_cpu_gpu_lowering( if not (subset_by_index is None or subset_by_index == (0, n)): raise NotImplementedError("subset_by_index not supported on CPU and GPU") batch_dims = operand_aval.shape[:-2] + + if algorithm == EighImplementation.QDWH: + raise NotImplementedError("QDWH implementation is only supported on TPU") + if algorithm == EighImplementation.JACOBI and target_name_prefix == "cpu": + raise NotImplementedError("Jacobi implementation is not supported on CPU") + if target_name_prefix == "cpu": dtype = operand_aval.dtype prefix = "he" if dtypes.issubdtype(dtype, np.complexfloating) else "sy" @@ -1130,7 +1232,12 @@ def _eigh_cpu_gpu_lowering( } else: target_name = f"{target_name_prefix}solver_syevd_ffi" - kwargs = {"lower": lower, "algorithm": np.uint8(0)} + # Use Jacobi (algorithm=2) if requested, otherwise use QR (algorithm=1) + if algorithm is None: + algo_int = 0 + else: + algo_int = 2 if algorithm == EighImplementation.JACOBI else 1 + kwargs = {"lower": lower, "algorithm": np.uint8(algo_int)} info_aval = ShapedArray(batch_dims, np.int32) avals_out = [v_aval, w_aval, info_aval] @@ -1145,59 +1252,8 @@ def _eigh_cpu_gpu_lowering( return [v, w] -def _eigh_tpu_impl(x, *, lower, sort_eigenvalues, subset_by_index): - *_, m, n = x.shape - assert m == n, (m, n) - - termination_size = 256 - if not is_constant_dim(m): - # TODO: maybe we can relax the check below for shape polymorphism? - raise NotImplementedError( - "Shape polymorphism for native lowering for eigh is implemented " - f"only for the batch dimensions: {x.shape}") - if m <= termination_size and ( - subset_by_index is None or subset_by_index == (0, n) - ): - eig_vals, eig_vecs = eigh_jacobi(x, lower=lower, - sort_eigenvalues=sort_eigenvalues) - return eig_vecs, eig_vals - - def eigh_qdwh(x): - if len(x.shape) > 2: - return control_flow.map(eigh_qdwh, x) - - # We should only look at elements from the lower/upper triangle. Reflects - # that triangle into the other triangle to form a Hermitian matrix. - if lower: - mask = lax_internal._tri(bool, (n, n), 0) - else: - mask = lax.bitwise_not(lax_internal._tri(bool, (n, n), -1)) - if dtypes.issubdtype(x.dtype, np.complexfloating): - re = lax.select(mask, lax.real(x), _T(lax.real(x))) - if lower: - im_mask = lax_internal._tri(bool, (n, n), -1) - else: - im_mask = lax.bitwise_not(lax_internal._tri(bool, (n, n), 0)) - im = lax.imag(x) - im = lax.select(im_mask, im, lax.full_like(im, 0)) - im = lax.select(mask, im, -_T(im)) - x = lax.complex(re, im) - else: - x = lax.select(mask, x, _T(x)) - - return lax_eigh.eigh( - x, - sort_eigenvalues=sort_eigenvalues, - termination_size=termination_size, - subset_by_index=subset_by_index, - ) - - eig_vals, eig_vecs = eigh_qdwh(x) - return eig_vecs, eig_vals - - def _eigh_jvp_rule( - primals, tangents, *, lower, sort_eigenvalues, subset_by_index + primals, tangents, *, lower, sort_eigenvalues, subset_by_index, algorithm ): (a,) = primals n = a.shape[-1] @@ -1220,11 +1276,12 @@ def _eigh_jvp_rule( lower=lower, sort_eigenvalues=sort_eigenvalues, subset_by_index=subset_by_index, + algorithm=algorithm, ) # for complex numbers we need eigenvalues to be full dtype of v, a: w = w_real.astype(a.dtype) - eye_n = lax_internal._eye(a.dtype, (n, n)) + eye_n = lax._eye(a.dtype, (n, n)) # carefully build reciprocal delta-eigenvalue matrix, avoiding NaNs. with config.numpy_rank_promotion("allow"): Fmat = lax.integer_pow(eye_n + w[..., np.newaxis, :] - w[..., np.newaxis], -1) - eye_n @@ -1241,9 +1298,6 @@ def _eigh_jvp_rule( _eigh_dtype_rule, (_float | _complex,), (2,), _eigh_shape_rule, "eigh", multiple_results=True) ad.primitive_jvps[eigh_p] = _eigh_jvp_rule -mlir.register_lowering( - eigh_p, mlir.lower_fun(_eigh_tpu_impl, multiple_results=True), - platform='tpu') register_cpu_gpu_lowering(eigh_p, _eigh_cpu_gpu_lowering) @@ -1259,7 +1313,6 @@ def _hessenberg_shape_rule(shape, **_): def _hessenberg_dtype_rule(dtype, **_): - dtype = dtypes.canonicalize_dtype(dtype) return dtype, dtype @@ -1374,7 +1427,7 @@ def body(k, state): # a[k+1:, k+1:] -= jnp.outer(a[k+1:, k], a[k, k+1:]) a_outer = a[:, k, None] * a[k, None] a = a - lax.select((m_idx[:, None] > k) & (n_idx[None, :] > k), - a_outer, lax_internal._zeros(a_outer)) + a_outer, lax._zeros(a_outer)) return pivot, perm, a pivot = lax.full((min(m, n),), 0, dtype=np.int32) @@ -1383,7 +1436,7 @@ def body(k, state): # If the array is empty, the loop body never executes but tracing it to a # jaxpr fails because the indexing cannot succeed. return (pivot, perm, a) - return lax.fori_loop(0, min(m, n), body, (pivot, perm, a)) + return control_flow.fori_loop(0, min(m, n), body, (pivot, perm, a)) def _lu_blocked(a, block_size=128): @@ -1425,7 +1478,6 @@ def _lu_shape_rule(shape): def _lu_dtype_rule(dtype, **_): - dtype = dtypes.canonicalize_dtype(dtype) return dtype, dtypes.dtype(np.int32), dtypes.dtype(np.int32) @@ -1447,10 +1499,10 @@ def _lu_jvp_inner(lu, a_dot, permutation): l_padding = [(0, 0, 0)] * 2 l_padding[-1] = (0, m - k, 0) - zero = lax_internal._const(lu, 0) + zero = lax._const(lu, 0) l = lax.pad(_tril(lu[:, :k], -1), zero, l_padding) - l = l + lax_internal._eye(dtype, (m, m)) - u_eye = lax.pad(lax_internal._eye(dtype, (n - k, n - k)), zero, + l = l + lax._eye(dtype, (m, m)) + u_eye = lax.pad(lax._eye(dtype, (n - k, n - k)), zero, ((k, 0, 0), (k, 0, 0))) u_padding = [(0, 0, 0)] * 2 u_padding[-2] = (0, n - k, 0) @@ -1563,7 +1615,7 @@ def _lu_solve_core(lu: Array, permutation: Array, b: Array, trans: int) -> Array return lax.reshape(x, b.shape) -@partial(api.jit, static_argnums=(3,)) +@api.jit(static_argnums=(3,)) def _lu_solve(lu: Array, permutation: Array, b: Array, trans: int) -> Array: if len(lu.shape) < 2 or lu.shape[-1] != lu.shape[-2]: raise ValueError("last two dimensions of LU decomposition must be equal, " @@ -1635,16 +1687,25 @@ def _generic_lu_pivots_to_permutation(swaps, permutation_size): """ assert len(swaps.shape) >= 1 batch_dims = swaps.shape[:-1] + swaps_sharding = core.typeof(swaps).sharding + batch_spec = swaps_sharding.spec[:-1] + if swaps_sharding.spec[-1] != None: + raise ValueError( + "The last dim of swaps should be unsharded but got:" + f" {swaps_sharding.spec[-1]} for type {core.typeof(swaps)}") + permutation_sharding = swaps_sharding.update(spec=batch_spec + (None,)) k = swaps.shape[-1] m = permutation_size - permutation = lax.broadcasted_iota(np.int32, batch_dims + (m,), - len(batch_dims)) + permutation = lax.broadcasted_iota( + np.int32, batch_dims + (m,), len(batch_dims), + out_sharding=permutation_sharding) if m == 0 or k == 0: return permutation upper = np.array(k, np.int32) if is_constant_dim(k) else k - result, _ = lax.fori_loop(np.array(0, np.int32), upper, _lu_pivots_body_fn, - (permutation, swaps)) + permutation, swaps = core.standard_insert_pvary(permutation, swaps) + result, _ = control_flow.fori_loop(np.array(0, np.int32), upper, + _lu_pivots_body_fn, (permutation, swaps)) return result @@ -1701,7 +1762,6 @@ def _geqrf_shape_rule(shape): return shape, (core.min_dim(m, n),) def _geqrf_dtype_rule(dtype): - dtype = dtypes.canonicalize_dtype(dtype) return dtype, dtype def _geqrf_lowering_rule(ctx, operand): @@ -1758,6 +1818,7 @@ def geqp3(a: ArrayLike, jpvt: ArrayLike, *, elementary Householder reflectors, and ``jpvt`` is the column-pivot indices such that ``a[:, jpvt] = q @ r``. """ + a, jpvt = core.standard_insert_pvary(a, jpvt) a_out, jpvt_out, taus = geqp3_p.bind(a, jpvt, use_magma=use_magma) return a_out, jpvt_out, taus @@ -1766,8 +1827,6 @@ def _geqp3_shape_rule(a_shape, jpvt_shape, **_): return a_shape, jpvt_shape, (core.min_dim(m, n),) def _geqp3_dtype_rule(dtype, jpvt_dtype, *_, **__): - dtype = dtypes.canonicalize_dtype(dtype) - jpvt_dtype = dtypes.canonicalize_dtype(jpvt_dtype) return dtype, jpvt_dtype, dtype def _geqp3_cpu_gpu_lowering(ctx, a, jpvt, *, use_magma, target_name_prefix): @@ -1797,7 +1856,6 @@ def _qr_shape_rule(shape, *, pivoting, full_matrices, **_): return ((m, k), (k, n), (n,)) if pivoting else ((m, k), (k, n)) def _qr_dtype_rule(dtype, *, pivoting, **_): - dtype = dtypes.canonicalize_dtype(dtype) return (dtype, dtype, dtypes.dtype(np.int32)) if pivoting else (dtype, dtype) def qr_jvp_rule(primals, tangents, *, pivoting, full_matrices, use_magma): @@ -1816,7 +1874,7 @@ def qr_jvp_rule(primals, tangents, *, pivoting, full_matrices, use_magma): qt_dx_rinv_lower = _tril(qt_dx_rinv, -1) do = qt_dx_rinv_lower - _H(qt_dx_rinv_lower) # This is skew-symmetric # The following correction is necessary for complex inputs - I = lax.expand_dims(lax_internal._eye(do.dtype, (n, n)), range(qt_dx_rinv.ndim - 2)) + I = lax.expand_dims(lax._eye(do.dtype, (n, n)), range(qt_dx_rinv.ndim - 2)) do = do + I * (qt_dx_rinv - qt_dx_rinv.real.astype(qt_dx_rinv.dtype)) dq = q @ (do - qt_dx_rinv) + dx_rinv dr = (qt_dx_rinv - do) @ r @@ -1829,7 +1887,7 @@ def _qr_lowering(a, *, pivoting, full_matrices, use_magma): *batch_dims, m, n = a.shape if m == 0 or n == 0: k = m if full_matrices else core.min_dim(m, n) - q = lax.broadcast_in_dim(lax_internal._eye(a.dtype, (m, k)), + q = lax.broadcast_in_dim(lax._eye(a.dtype, (m, k)), (*batch_dims, m, k), (len(batch_dims), len(batch_dims) + 1)) r = lax.full((*batch_dims, k, n), 0, dtype=a.dtype) @@ -1849,7 +1907,7 @@ def _qr_lowering(a, *, pivoting, full_matrices, use_magma): q = householder_product(r[..., :m, :m], taus) elif full_matrices: pads = [(0, 0, 0)] * (len(batch_dims) + 1) + [(0, m - n, 0)] - q = lax.pad(r, lax_internal._zero(r), pads) + q = lax.pad(r, lax._zero(r), pads) q = householder_product(q, taus) else: q = householder_product(r, taus) @@ -1875,7 +1933,6 @@ def _schur_shape_rule(shape, *, compute_schur_vectors, **_): return (shape, shape) if compute_schur_vectors else (shape,) def _schur_dtype_rule(dtype, *, compute_schur_vectors, **_): - dtype = dtypes.canonicalize_dtype(dtype) return (dtype, dtype) if compute_schur_vectors else (dtype,) def _schur_cpu_lowering(ctx, operand, *, compute_schur_vectors, sort_eig_vals, @@ -1948,8 +2005,7 @@ def _svd_shape_rule(shape, *, full_matrices, compute_uv, subset_by_index, **_): return (rank,), def _svd_dtype_rule(dtype, *, compute_uv, **_): - dtype = dtypes.canonicalize_dtype(dtype) - real_dtype = lax_internal._complex_basetype(dtype) + real_dtype = lax._complex_basetype(dtype) if compute_uv: return real_dtype, dtype, dtype else: @@ -1967,7 +2023,11 @@ def _svd_jvp_rule( algorithm=algorithm, ) - if compute_uv and full_matrices: + if ( + compute_uv + and full_matrices + and not core.definitely_equal(A.shape[-2], A.shape[-1]) + ): # TODO: implement full matrices case, documented here: https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf raise NotImplementedError( "Singular value decomposition JVP not implemented for full matrices") @@ -1981,7 +2041,7 @@ def _svd_jvp_rule( return (s,), (ds,) s_diffs = (s_dim + _T(s_dim)) * (s_dim - _T(s_dim)) - s_diffs_zeros = lax_internal._eye(s.dtype, (s.shape[-1], s.shape[-1])) # jnp.ones((), dtype=A.dtype) * (s_diffs == 0.) # is 1. where s_diffs is 0. and is 0. everywhere else + s_diffs_zeros = lax._eye(s.dtype, (s.shape[-1], s.shape[-1])) # jnp.ones((), dtype=A.dtype) * (s_diffs == 0.) # is 1. where s_diffs is 0. and is 0. everywhere else s_diffs_zeros = lax.expand_dims(s_diffs_zeros, range(s_diffs.ndim - 2)) F = 1 / (s_diffs + s_diffs_zeros) - s_diffs_zeros dSS = s_dim.astype(A.dtype) * dS # dS.dot(jnp.diag(s)) @@ -2007,12 +2067,12 @@ def _svd_jvp_rule( def _empty_svd(a, *, full_matrices, compute_uv): batch_shape = a.shape[:-2] m, n = a.shape[-2:] - s = lax.full(batch_shape + (0,), 0, dtype=lax_internal._complex_basetype(a.dtype)) + s = lax.full(batch_shape + (0,), 0, dtype=lax._complex_basetype(a.dtype)) if not compute_uv: return (s,) if full_matrices: size = max(m, n) - u = lax.broadcast_in_dim(lax_internal._eye(a.dtype, (size, size)), + u = lax.broadcast_in_dim(lax._eye(a.dtype, (size, size)), (*batch_shape, size, size), (len(batch_shape), len(batch_shape) + 1)) else: @@ -2064,7 +2124,7 @@ def _svd_cpu_gpu_lowering( target_name = lapack.prepare_lapack_call("gesvd_ffi", operand_aval.dtype) else: raise NotImplementedError( - "The SVD Jacobi algorithm is not implemented on CPU.") + "The SVD Jacobi and Polar algorithms are not implemented on CPU.") mode = _svd_computation_attr(compute_uv, full_matrices) info_aval = ShapedArray(batch_dims, np.dtype(np.int32)) if compute_uv: @@ -2130,16 +2190,24 @@ def _svd_gpu_sub_lowering(ctx, operand, *, full_matrices, compute_uv, # default QR algorithm, but users can (in principle) override this behavior # by passing `use_jacobi=True`. # - # TODO(danfm): Since this was originally implemented, hipSolver appers to + # TODO(danfm): Since this was originally implemented, hipSolver appears to # have added support for the Jacobi algorithm, so we should investigate # removing this condition. + # TODO(phawkins): Consider making polar decomposition the default. + use_jacobi = False + use_polar = False if algorithm is None or algorithm == SvdAlgorithm.DEFAULT: try: - use_jacobi = target_name_prefix == "cu" and m <= 1024 and n <= 1024 + gpu_available = target_name_prefix == "cu" or \ + target_name_prefix == "hip" + use_jacobi = gpu_available and m <= 1024 and n <= 1024 except core.InconclusiveDimensionOperation: use_jacobi = False - else: - use_jacobi = algorithm == SvdAlgorithm.JACOBI + elif algorithm == SvdAlgorithm.JACOBI: + use_jacobi = True + elif algorithm == SvdAlgorithm.POLAR: + use_polar = True + column_major = True if use_jacobi: target_name = f"{target_name_prefix}solver_gesvdj_ffi" @@ -2150,18 +2218,22 @@ def _svd_gpu_sub_lowering(ctx, operand, *, full_matrices, compute_uv, econ = not full_matrices and m > 32 and n > 32 except core.InconclusiveDimensionOperation: econ = False + elif use_polar: + target_name = f"{target_name_prefix}solver_gesvdp_ffi" + econ = not full_matrices else: target_name = f"{target_name_prefix}solver_gesvd_ffi" econ = not full_matrices - # Because the base gesvd kernel only supports matrices where m >= n, we. + # Because the base gesvd kernel only supports matrices where m >= n, we + # conceptually transpose the matrix if m < n. transposed = m < n kwargs = {"transposed": transposed} if transposed: column_major = False - if use_jacobi: - # When using the Jacobi algorithm, the U and V matrices must always be - # allocated even if compute_uv is False. + if use_jacobi or use_polar: + # When using the Jacobi or polar algorithms, the U and V matrices must + # always be allocated even if compute_uv is False. u_aval = ShapedArray((*batch_dims, m, k if econ else m), u_aval.dtype) v_aval = ShapedArray((*batch_dims, n, k if econ else n), vt_aval.dtype) avals_out = [operand_aval, s_aval, u_aval, v_aval, info_aval] @@ -2169,12 +2241,13 @@ def _svd_gpu_sub_lowering(ctx, operand, *, full_matrices, compute_uv, avals_out = [operand_aval, s_aval, vt_aval, u_aval, info_aval] else: avals_out = [operand_aval, s_aval, u_aval, vt_aval, info_aval] + rule = _linalg_ffi_lowering(target_name, avals_out=avals_out, operand_output_aliases={0: 0}, column_major=column_major) _, s, u, vt, info = rule(ctx, operand, full_matrices=not econ, compute_uv=compute_uv, **kwargs) - if use_jacobi and compute_uv: + if (use_jacobi or use_polar) and compute_uv: vt = hlo.transpose( vt, mlir.dense_int_array(np.array(tuple(range(nb)) + (nb + 1, nb)))) @@ -2195,57 +2268,12 @@ def _svd_gpu_sub_lowering(ctx, operand, *, full_matrices, compute_uv, else: return s, u, vt, info -def _svd_tpu(a, *, full_matrices, compute_uv, subset_by_index, algorithm=None): - if algorithm is not None and algorithm != SvdAlgorithm.DEFAULT: - raise NotImplementedError( - "The SVD algorithm parameter is not implemented on TPU.") - - batch_dims = a.shape[:-2] - fn = partial( - lax_svd.svd, - full_matrices=full_matrices, - compute_uv=compute_uv, - subset_by_index=subset_by_index, - ) - for _ in range(len(batch_dims)): - fn = api.vmap(fn) - - if compute_uv: - u, s, vh = fn(a) - return [s, u, vh] - else: - s = fn(a) - return [s] - -def _svd_tpu_lowering_rule( - ctx, operand, *, full_matrices, compute_uv, subset_by_index, algorithm=None -): - del algorithm # unused - operand_aval, = ctx.avals_in - m, n = operand_aval.shape[-2:] - - if m == 0 or n == 0: - return mlir.lower_fun(_empty_svd, multiple_results=True)( - ctx, - operand, - full_matrices=full_matrices, - compute_uv=compute_uv, - ) - - return mlir.lower_fun(_svd_tpu, multiple_results=True)( - ctx, - operand, - full_matrices=full_matrices, - compute_uv=compute_uv, - subset_by_index=subset_by_index, - ) svd_p = linalg_primitive( _svd_dtype_rule, (_float | _complex,), (2,), _svd_shape_rule, "svd", multiple_results=True) ad.primitive_jvps[svd_p] = _svd_jvp_rule register_cpu_gpu_lowering(svd_p, _svd_cpu_gpu_lowering) -mlir.register_lowering(svd_p, _svd_tpu_lowering_rule) # Symmetric product @@ -2302,7 +2330,7 @@ def _triangular_solve_shape_rule(a_shape, b_shape, *, left_side=False, **_): return b_shape def _triangular_solve_dtype_rule(dtype, *_, **__): - return dtypes.canonicalize_dtype(dtype) + return dtype def _triangular_solve_jvp_rule_a( g_a, ans, a, b, *, left_side, lower, transpose_a, conjugate_a, @@ -2321,7 +2349,7 @@ def a_inverse(rhs): transpose_a=transpose_a, conjugate_a=conjugate_a, unit_diagonal=unit_diagonal) - # triangular_solve is about the same cost as matrix multplication (~n^2 FLOPs + # triangular_solve is about the same cost as matrix multiplication (~n^2 FLOPs # for matrix/vector inputs). Order these operations in whichever order is # cheaper. if left_side: @@ -2410,15 +2438,7 @@ def _triangular_solve_cpu_lower( conjugate_a = False if np.dtype(a_aval.dtype) in _cpu_lapack_types: target_name = lapack.prepare_lapack_call("trsm_ffi", a_aval.dtype) - # TODO(b/397715595): Remove forward_compat check no earlier than 2025-03-18. - if ctx.is_forward_compat() or jaxlib_version <= (0, 5, 1): - alpha = mlir.ir_constant(np.array(1, dtype=a_aval.dtype)), - alpha_aval = ShapedArray((), a_aval.dtype), - batch_partitionable = False - else: - alpha = () - alpha_aval = () - batch_partitionable = True + alpha, alpha_aval, batch_partitionable = (), (), True rule = _linalg_ffi_lowering(target_name, [a_aval, b_aval, *alpha_aval], operand_output_aliases={1: 0}, @@ -2463,8 +2483,7 @@ def _tridiagonal_shape_rule(shape, **_): return shape, (n,), (n - 1,), (n - 1,) def _tridiagonal_dtype_rule(dtype, **_): - dtype = dtypes.canonicalize_dtype(dtype) - real_dtype = lax_internal._complex_basetype(dtype) + real_dtype = lax._complex_basetype(dtype) return dtype, real_dtype, real_dtype, dtype def _tridiagonal_cpu_gpu_lowering(ctx, a, *, lower, target_name_prefix): @@ -2511,16 +2530,10 @@ def _tridiagonal_solve_shape_rule(dl_shape, d_shape, du_shape, b_shape, **_): "equal the dimensions of the diagonal arguments.") return b_shape -def _tridiagonal_solve_gpu_lowering(lowering, ctx, dl, d, du, b): - _, _, _, b_aval = ctx.avals_in - if b_aval.dtype != np.float32 and b_aval.dtype != np.float64: - raise NotImplementedError( - "tridiagonal_solve is only implemented for float32 and float64 on GPU.") - m, n = b_aval.shape[-2:] - b_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, b_aval.shape) - return [lowering( - dl, d, du, b, m=m, n=n, ldb=m, t=b_aval.dtype, - b_shape_vals=b_shape_vals)] +def _tridiagonal_solve_gpu_lowering(ctx, dl, d, du, b, *, target_name_prefix): + target_name = f"{target_name_prefix}sparse_gtsv2_ffi" + rule = _linalg_ffi_lowering(target_name, operand_output_aliases={3: 0}) + return rule(ctx, dl, d, du, b) def _tridiagonal_solve_cpu_lowering(ctx, dl, d, du, b, **kwargs): del kwargs # unused @@ -2562,9 +2575,9 @@ def _tridiagonal_solve_transpose_rule(cotangent, dl, d, du, b): if type(cotangent) is ad_util.Zero: cotangent_b = ad_util.Zero(b.aval) else: - dl_trans = lax.concatenate((lax.zeros_like_array(du[..., -1:]), du[..., :-1]), + dl_trans = lax.concatenate((lax.full_like(du[..., -1:], 0), du[..., :-1]), du.ndim-1) - du_trans = lax.concatenate((dl[..., 1:], lax.zeros_like_array(dl[..., :1])), + du_trans = lax.concatenate((dl[..., 1:], lax.full_like(dl[..., :1], 0)), dl.ndim-1) cotangent_b = tridiagonal_solve(dl_trans, d, du_trans, cotangent) return [None, None, None, cotangent_b] @@ -2598,7 +2611,7 @@ def fwd(carry, args): dp_next = (d - a * dp) / (b - a * cp) return (cp_next, dp_next), (cp, dp) - (_, final), (cp, dp) = lax.scan( + (_, final), (cp, dp) = control_flow.scan( fwd, (du[0] / d[0], b[0] / d[0]), (dl[1:], d[1:], du[1:], b[1:, :]), unroll=32) @@ -2607,7 +2620,7 @@ def bwd(xn, args): x = dp - cp * xn return x, xn - end, ans = lax.scan(bwd, final, (cp, dp), unroll=32, reverse=True) + end, ans = control_flow.scan(bwd, final, (cp, dp), unroll=32, reverse=True) return lax.concatenate((end[None], ans), 0) def _tridiagonal_solve_jax(dl, d, du, b, **_): @@ -2628,11 +2641,11 @@ def _tridiagonal_solve_jax(dl, d, du, b, **_): platform='cpu') mlir.register_lowering( tridiagonal_solve_p, - partial(_tridiagonal_solve_gpu_lowering, gpu_sparse.cuda_gtsv2), + partial(_tridiagonal_solve_gpu_lowering, target_name_prefix='cu'), platform='cuda') mlir.register_lowering( tridiagonal_solve_p, - partial(_tridiagonal_solve_gpu_lowering, gpu_sparse.rocm_gtsv2), + partial(_tridiagonal_solve_gpu_lowering, target_name_prefix='hip'), platform='rocm') mlir.register_lowering(tridiagonal_solve_p, mlir.lower_fun( _tridiagonal_solve_jax, multiple_results=False)) @@ -2672,7 +2685,7 @@ def _solve(a: Array, b: Array) -> Array: # computing sensitivities. This is considerably faster. lu_, _, permutation = lu(lax.stop_gradient(a)) custom_solve = partial( - lax.custom_linear_solve, + control_flow.custom_linear_solve, lambda x: _broadcasted_matvec(a, x), solve=lambda _, x: lu_solve(lu_, permutation, x, trans=0), transpose_solve=lambda _, x: lu_solve(lu_, permutation, x, trans=1)) @@ -2693,13 +2706,13 @@ def symmetrize(x: Array) -> Array: return (x + _H(x)) / 2 def _tril(m: Array, k:int = 0) -> Array: *_, N, M = m.shape - mask = lax_internal._tri(bool, (N, M), k) - return lax.select(lax.broadcast(mask, m.shape[:-2]), m, lax.zeros_like_array(m)) + mask = lax._tri(bool, (N, M), k) + return lax.select(lax.broadcast(mask, m.shape[:-2]), m, lax.full_like(m, 0)) def _triu(m: Array, k:int = 0) -> Array: *_, N, M = m.shape - mask = lax_internal._tri(bool, (N, M), k - 1) - return lax.select(lax.broadcast(mask, m.shape[:-2]), lax.zeros_like_array(m), m) + mask = lax._tri(bool, (N, M), k - 1) + return lax.select(lax.broadcast(mask, m.shape[:-2]), lax.full_like(m, 0), m) def _construct_diagonal(s: Array) -> Array: """Construct a (batched) diagonal matrix""" @@ -2723,11 +2736,12 @@ def _nan_like_hlo(ctx: mlir.LoweringRuleContext, aval) -> ir.Value: def _broadcasting_select_hlo(ctx, which, which_aval, x, x_aval, y, y_aval) -> ir.Value: """Wrapper around XLA `Select` that broadcasts its arguments.""" - out_shapes = list(lax_internal.broadcast_shapes( + out_shapes = list(lax.broadcast_shapes( tuple(which_aval.shape), tuple(x_aval.shape), tuple(y_aval.shape))) + out_sharding = lax.broadcast_shardings(which_aval, x_aval, y_aval) which, x, y = mlir.multi_broadcast_in_dim(ctx, (which, x, y), (which_aval, x_aval, y_aval), - out_shapes) + out_shapes, out_sharding) return hlo.select(which, x, y) def _replace_not_ok_with_nan(ctx, batch_dims, ok, x, x_aval): @@ -2763,9 +2777,9 @@ def _column_major_matrix_layout(dim: int) -> tuple[int, ...]: return (dim - 2, dim - 1) + tuple(range(dim - 3, -1, -1)) def _sdy_rule_for_aval(letters, num_batch_dims, aval): - return " ".join( - ("...", *(next(letters) for _ in range(len(aval.shape) - num_batch_dims))) - ) + d = len(aval.shape) - num_batch_dims + prefix = "... " if num_batch_dims and d >= 0 else "" + return prefix + " ".join(next(letters) for _ in range(d)) def _build_sdy_sharding_rule(num_batch_dims, avals_in, avals_out): letters = iter(string.ascii_letters) diff --git a/jax/_src/lax/other.py b/jax/_src/lax/other.py index 00e15ef6a91d..b3d54064f9b4 100644 --- a/jax/_src/lax/other.py +++ b/jax/_src/lax/other.py @@ -54,7 +54,7 @@ def conv_general_dilated_patches( Docstring below adapted from `jax.lax.conv_general_dilated`. See Also: - https://www.tensorflow.org/xla/operation_semantics#conv_convolution + https://www.openxla.org/xla/operation_semantics#conv_convolution Args: lhs: a rank `n+2` dimensional input array. @@ -141,7 +141,7 @@ def conv_general_dilated_local( spatial location. Docstring below adapted from `jax.lax.conv_general_dilated`. See Also: - https://www.tensorflow.org/xla/operation_semantics#conv_convolution + https://www.openxla.org/xla/operation_semantics#conv_convolution Args: lhs: a rank `n+2` dimensional input array. @@ -284,6 +284,39 @@ def _logaddexp_jvp(primals, tangents): x1, x2 = primals t1, t2 = tangents primal_out = logaddexp(x1, x2) - tangent_out = lax.add(lax.mul(t1, lax.exp(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))), - lax.mul(t2, lax.exp(lax.sub(_replace_inf(x2), _replace_inf(primal_out))))) + tangent_out = lax.add( + lax.mul(t1, lax.exp(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))), + lax.mul(t2, lax.exp(lax.sub(_replace_inf(x2), _replace_inf(primal_out))))) + return primal_out, tangent_out + + +@custom_jvp +def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array: + """Compute log2(exp2(x1) + exp2(x2)) avoiding overflow.""" + x1_arr = lax.asarray(x1) + x2_arr = lax.asarray(x2) + assert x1_arr.dtype == x2_arr.dtype + + amax = lax.max(x1_arr, x2_arr) + invln2 = lax._const(amax, 1/np.log(2)) + if dtypes.isdtype(x1_arr.dtype, "real floating"): + delta = lax.sub(x1_arr, x2_arr) + return lax.select(lax._isnan(delta), + lax.add(x1_arr, x2_arr), # NaNs or infinities of the same sign. + lax.add(amax, lax.mul(invln2, lax.log1p(lax.exp2(lax.neg(lax.abs(delta))))))) + elif dtypes.isdtype(x1_arr.dtype, "complex floating"): + delta = lax.sub(lax.add(x1_arr, x2_arr), lax.mul(amax, lax._const(amax, 2))) + out = lax.add(amax, lax.mul(invln2, lax.log1p(lax.exp2(delta)))) + return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi / np.log(2))) + else: + raise ValueError(f"logaddexp2 requires floating-point or complex inputs; got {x1_arr.dtype}") + + +@logaddexp2.defjvp +def _logaddexp2_jvp(primals, tangents): + x1, x2 = primals + t1, t2 = tangents + primal_out = logaddexp2(x1, x2) + tangent_out = lax.add(lax.mul(t1, lax.exp2(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))), + lax.mul(t2, lax.exp2(lax.sub(_replace_inf(x2), _replace_inf(primal_out))))) return primal_out, tangent_out diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 221fe2a9e87a..e156c617f5ec 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -19,14 +19,16 @@ from collections.abc import Sequence from functools import partial +from dataclasses import dataclass import itertools import math -import jax -from jax import tree_util from jax._src import core +from jax._src import config from jax._src import dispatch from jax._src import dtypes +from jax._src import effects as effects_lib +from jax._src import tree_util from jax._src.sharding_impls import (SPMDAxisContext, ShardingContext, NamedSharding, PartitionSpec as P) from jax._src.core import AxisName, ShapedArray @@ -34,10 +36,16 @@ from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.interpreters import pxla +from jax._src.core import check_unreduced_args +from jax._src.mesh import get_abstract_mesh +from jax._src.core import abstract_token, pvary +from jax._src.lax import control_flow from jax._src.lax import lax from jax._src.lax import slicing from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo +from jax._src.lib import xla_client as xc +from jax._src.typing import Array from jax._src.util import (canonicalize_axis, moveaxis, safe_map, safe_zip, unzip2) import numpy as np @@ -113,8 +121,28 @@ def psum(x, axis_name, *, axis_index_groups=None): [20 22 24 26] [20 22 24 26]] """ + axes = ((axis_name,) if not isinstance(axis_name, (tuple, list)) else + tuple(axis_name)) + # TODO(yashkatariya): Remove this handling and remove_size_one_mesh_axis_from_type + # generally from JAX. + axes = _maybe_skip_one_sized_axes(axes) + if not axes: + return x + def bind(leaf): + from_ = _get_from(core.typeof(leaf), axes, 'jax.lax.psum') + if from_ == 'unreduced': + if axis_index_groups is not None: + raise NotImplementedError + return unreduced_psum(leaf, axes) + else: + return _psum(leaf, axes, axis_index_groups=axis_index_groups) + return tree_util.tree_map(bind, x) + +def _psum(x, axis_name, *, axis_index_groups): if not isinstance(axis_name, (tuple, list)): axis_name = (axis_name,) + if not axis_name: + return x if any(isinstance(axis, int) for axis in axis_name) and axis_index_groups is not None: raise ValueError("axis_index_groups only supported for sums over just named axes") _validate_reduce_axis_index_groups(axis_index_groups) @@ -139,10 +167,25 @@ def pos_reduce(x): size = math.prod([core.get_axis_env().axis_size(name) for name in named_axes]) out_flat = tuple(lax._const(leaf, size) * pos_reduce(leaf) for leaf in leaves) else: - out_flat = psum_p.bind( - *leaves, axes=tuple(axis_name), axis_index_groups=axis_index_groups) + if config._check_vma.value: + out_flat = [bind_psum_invariant(leaf, axes=tuple(axis_name), + axis_index_groups=axis_index_groups) + for leaf in leaves] + else: + out_flat = [psum_p.bind(leaf, axes=tuple(axis_name), + axis_index_groups=axis_index_groups) + for leaf in leaves] return tree_util.tree_unflatten(treedef, out_flat) + +def _maybe_skip_one_sized_axes(axes): + if config.remove_size_one_mesh_axis_from_type.value: + cur_mesh = get_abstract_mesh() + return tuple(i for i in axes + if (size := cur_mesh.shape.get(i)) is None or size != 1) + return axes + + def pmean(x, axis_name, *, axis_index_groups=None): """Compute an all-reduce mean on ``x`` over the pmapped axis ``axis_name``. @@ -173,7 +216,7 @@ def pmean(x, axis_name, *, axis_index_groups=None): [0. 0.6666667 1.3333334 2. ] """ x = psum(x, axis_name=axis_name, axis_index_groups=axis_index_groups) - n = psum(1, axis_name=axis_name, axis_index_groups=axis_index_groups) + n = _axis_size(axis_name, axis_index_groups) return tree_util.tree_map(lambda v: v / n, x) def pmax(x, axis_name, *, axis_index_groups=None): @@ -200,11 +243,12 @@ def pmax(x, axis_name, *, axis_index_groups=None): if any(isinstance(axis, int) for axis in axis_name) and axis_index_groups is not None: raise ValueError("axis_index_groups only supported for sums over just named axes") _validate_reduce_axis_index_groups(axis_index_groups) - leaves, treedef = tree_util.tree_flatten(x) axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups) - out_flat = pmax_p.bind(*leaves, axes=axis_name, - axis_index_groups=axis_index_groups) - return tree_util.tree_unflatten(treedef, out_flat) + def bind(leaf): + leaf = insert_collective_pvary(axis_name, leaf) + return pmax_p.bind(leaf, axes=axis_name, axis_index_groups=axis_index_groups) + return tree_util.tree_map(bind, x) + def pmin(x, axis_name, *, axis_index_groups=None): """Compute an all-reduce min on ``x`` over the pmapped axis ``axis_name``. @@ -230,11 +274,11 @@ def pmin(x, axis_name, *, axis_index_groups=None): if any(isinstance(axis, int) for axis in axis_name) and axis_index_groups is not None: raise ValueError("axis_index_groups only supported for sums over just named axes") _validate_reduce_axis_index_groups(axis_index_groups) - leaves, treedef = tree_util.tree_flatten(x) axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups) - out_flat = pmin_p.bind(*leaves, axes=axis_name, - axis_index_groups=axis_index_groups) - return tree_util.tree_unflatten(treedef, out_flat) + def bind(leaf): + leaf = insert_collective_pvary(axis_name, leaf) + return pmin_p.bind(leaf, axes=axis_name, axis_index_groups=axis_index_groups) + return tree_util.tree_map(bind, x) # TODO(mattjj): add a pargmin_p, or add named axis support to lax.argmin_p def pargmin(x, axis_name): @@ -253,7 +297,7 @@ def _axis_index_of_val(x, val, axis_name): mask = (val == x) validx = lax.select(mask, lax.full(mask.shape, idx), - lax.full(mask.shape, dtypes.iinfo(dtypes.dtype(idx)).max, dtypes.dtype(idx))) + lax.full(mask.shape, dtypes.iinfo(idx.dtype).max, idx.dtype)) return pmin(validx, axis_name) def _validate_reduce_axis_index_groups(axis_index_groups): @@ -325,9 +369,84 @@ def ppermute(x, axis_name, perm): """ if not isinstance(axis_name, (list, tuple)): axis_name = (axis_name,) - return tree_util.tree_map( - partial(ppermute_p.bind, axis_name=axis_name, - perm=tuple(map(tuple, perm))), x) + def bind(leaf): + leaf = insert_collective_pvary(axis_name, leaf) + return ppermute_p.bind(leaf, axis_name=axis_name, perm=tuple(map(tuple, perm))) + return tree_util.tree_map(bind, x) + + +def psend(x, axis_name, perm): + """Perform a collective send according to the permutation ``perm``. + + If ``x`` is a pytree then the result is equivalent to mapping this function to + each leaf in the tree. + + This function is an analog of the Send HLO. + + Args: + x: array(s) with a mapped axis named ``axis_name``. + axis_name: hashable Python object used to name a pmapped axis (see the + :func:`jax.pmap` documentation for more details). + perm: list of pairs of ints, representing ``(source_index, + destination_index)`` pairs that encode how the mapped axis named + ``axis_name`` should be shuffled. The integer values are treated as + indices into the mapped axis ``axis_name``. Any two pairs should not have + the same source index or the same destination index. For each index of the + axis ``axis_name`` that does not correspond to a destination index in + ``perm``, the corresponding values in the result are filled with zeros of + the appropriate type. The semantics here are platform-specific, and for + GPU they correspond to NCCL send. + + Returns: + A compiler token that can be used by precv and lax.optimzation_barrier to + enforce ordering of collective ops. + """ + axis_name = tuple(axis_name) if isinstance(axis_name, (list, tuple)) else (axis_name,) + + def bind(leaf): + leaf = insert_collective_pvary(axis_name, leaf) + return psend_p.bind(leaf, axis_name=axis_name, perm=tuple(map(tuple, perm))) + + return tree_util.tree_map(bind, x) + + +def precv(token, out_shape, axis_name, perm): + """Perform a collective recv according to the permutation ``perm``. + + This function is an analog of the Recv HLO. + + Args: + token: a compiler token, either generated by a matching psend or + lax.create_token(). This is used to enforce control dependencies between + collectives. + out_shape: ShapeDtypeStruct(s) containing the dtype and shape + of the result. + axis_name: hashable Python object used to name a pmapped axis (see the + :func:`jax.pmap` documentation for more details). + perm: list of pairs of ints, representing ``(source_index, + destination_index)`` pairs that encode how the mapped axis named + ``axis_name`` should be shuffled. The integer values are treated as + indices into the mapped axis ``axis_name``. Any two pairs should not have + the same source index or the same destination index. For each index of the + axis ``axis_name`` that does not correspond to a destination index in + ``perm``, the corresponding values in the result are filled with zeros of + the appropriate type. The semantics here are platform-specific, and for + GPU they correspond to NCCL recv. + + Returns: + Array(s) with the same shape as ``out_shape``. + """ + axis_name = tuple(axis_name) if isinstance(axis_name, (list, tuple)) else (axis_name,) + + return precv_p.bind( + token, + out_shape=core.ShapedArray( + out_shape.shape, out_shape.dtype + ), + axis_name=axis_name, + perm=tuple(map(tuple, perm)), + ) + def pshuffle(x, axis_name, perm): """Convenience wrapper of jax.lax.ppermute with alternate permutation encoding @@ -421,14 +540,14 @@ def all_to_all(x, axis_name, split_axis, concat_axis, *, axis_index_groups=None, np.insert(np.delete(x.shape, split_axis), concat_axis, axis_size) where ``axis_size`` is the size of the mapped axis named ``axis_name`` in - the input ``x``, i.e. ``axis_size = lax.psum(1, axis_name)``. + the input ``x``. Otherwise array with shape similar to the input shape, except with split_axis divided by axis size and concat_axis multiplied by axis size. """ axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups) def bind(x, split_axis=split_axis, concat_axis=concat_axis): - group_size = psum(1, axis_name, axis_index_groups=axis_index_groups) + group_size = _axis_size(axis_name, axis_index_groups) if tiled: if x.shape[split_axis] % group_size != 0: raise ValueError(f"The size of all_to_all split_axis ({x.shape[split_axis]}) " @@ -447,6 +566,7 @@ def bind(x, split_axis=split_axis, concat_axis=concat_axis): else: # concat_axis < split_axis x = lax.expand_dims(x, (concat_axis,)) # insert the new axis split_axis += 1 # we have a new axis before split_axis now + x = insert_collective_pvary(axis_name, x) result = all_to_all_p.bind(x, split_axis=split_axis, concat_axis=concat_axis, axis_name=axis_name, axis_index_groups=axis_index_groups, @@ -476,14 +596,17 @@ def ragged_all_to_all( That is, we can represent ragged data contiguously using a triple of dense arrays ``(data, offsets, sizes)``: + * ``data``: the concatenated component arrays, * ``offsets``: 1D array of indices into the leading axis of ``data`` indicating where the data for each component array begins, * ``sizes``: 1D array of sizes of the leading axis of each component array. + We refer to this triple as a ragged array. (Offsets can't be computed from sizes in general to allow for internal padding.) For example:: + data: f32[8,3] = jnp.array([ [a,b,c], [d,e,f], [g,h,i], [j,k,l], [m,n,o], [p,q,r], [s,t,u], [v,w,x], ]) @@ -612,7 +735,7 @@ def ragged_all_to_all( axis_index_groups=axis_index_groups) -def axis_index(axis_name): +def axis_index(axis_name: AxisName) -> Array: """Return the index along the mapped axis ``axis_name``. Args: @@ -623,42 +746,77 @@ def axis_index(axis_name): For example, with 8 XLA devices available: - >>> from functools import partial - >>> @partial(jax.pmap, axis_name='i') - ... def f(_): - ... return lax.axis_index('i') + >>> mesh = jax.make_mesh((8,), 'i', axis_types=(jax.sharding.AxisType.Explicit,)) + >>> @jax.shard_map(mesh=mesh, in_specs=(), out_specs=jax.P('i')) + ... def f(): + ... return lax.axis_index('i')[None] ... - >>> f(np.zeros(4)) - Array([0, 1, 2, 3], dtype=int32) - >>> f(np.zeros(8)) + >>> f() Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32) - >>> @partial(jax.pmap, axis_name='i') - ... @partial(jax.pmap, axis_name='j') - ... def f(_): - ... return lax.axis_index('i'), lax.axis_index('j') + + >>> mesh = jax.make_mesh((4, 2), ('i', 'j'), + ... axis_types=(jax.sharding.AxisType.Explicit,) * 2) + >>> @jax.shard_map(mesh=mesh, in_specs=(), out_specs=jax.P('i', 'j')) + ... def f(): + ... return lax.axis_index(('i', 'j'))[None, None] ... - >>> x, y = f(np.zeros((4, 2))) - >>> print(x) - [[0 0] - [1 1] - [2 2] - [3 3]] - >>> print(y) - [[0 1] - [0 1] - [0 1] - [0 1]] + >>> f() + Array([[0, 1], + [2, 3], + [4, 5], + [6, 7]], dtype=int32) """ if not isinstance(axis_name, (tuple, list)): return axis_index_p.bind(axis_name=axis_name) else: inner_size = 1 - index = 0 + index = lax.asarray(0) for name in reversed(axis_name): index += axis_index(name) * inner_size - inner_size *= psum(1, name) + inner_size *= axis_size(name) return index + +def axis_size(axis_name: AxisName) -> int: + """Return the size of the mapped axis ``axis_name``. + + Args: + axis_name: hashable Python object used to name the mapped axis. + + Returns: + An integer representing the size. + + For example, with 8 XLA devices available: + + >>> mesh = jax.make_mesh((8,), 'i', axis_types=(jax.sharding.AxisType.Explicit,)) + >>> @jax.shard_map(mesh=mesh, in_specs=jax.P('i'), out_specs=jax.P()) + ... def f(_): + ... return lax.axis_size('i') + ... + >>> f(jnp.zeros(16)) + Array(8, dtype=int32, weak_type=True) + + >>> mesh = jax.make_mesh((4, 2), ('i', 'j'), + ... axis_types=(jax.sharding.AxisType.Explicit,) * 2) + >>> @jax.shard_map(mesh=mesh, in_specs=jax.P('i', 'j'), out_specs=jax.P()) + ... def f(_): + ... return lax.axis_size(('i', 'j')) + ... + >>> f(jnp.zeros((16, 8))) + Array(8, dtype=int32, weak_type=True) + """ + return _axis_size(axis_name) + + +def _axis_size( + axis_name: AxisName, + axis_index_groups: Sequence[Sequence[int]] | None = None, + /, +) -> int: + axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups) + return psum(1, axis_name, axis_index_groups=axis_index_groups) + + def pgather(src, idx, axes: int | AxisName): """Uses the last positional axis of idx to index into src's axes.""" if not isinstance(axes, (tuple, list)): @@ -666,7 +824,6 @@ def pgather(src, idx, axes: int | AxisName): # TODO: Canonicalize exes! return pgather_p.bind(src, idx, axes=tuple(axes)) - ### parallel primitives def _names_in_param(pname: str, params: core.ParamDict) -> tuple[str]: @@ -676,74 +833,70 @@ def _names_in_param(pname: str, params: core.ParamDict) -> tuple[str]: else: return (axis_names,) -def _constant_reduction(prim, axis_data, args, axes, axis_index_groups): +def _constant_reduction(prim, axis_data, arg, axes, axis_index_groups): assert axis_data.name in axes if axis_index_groups: raise NotImplementedError new_axes = tuple(n for n in axes if n != axis_data.name) if new_axes: - args = prim.bind(*args, axes=new_axes, axis_index_groups=axis_index_groups) + arg = (prim.bind(arg, axes=new_axes) if prim is psum_invariant_p else + prim.bind(arg, axes=new_axes, axis_index_groups=axis_index_groups)) if prim is psum_p: - outs = [lax._const(x, axis_data.size) * x for x in args] + out = lax._const(arg, axis_data.size) * arg elif prim in (pmin_p, pmax_p): - outs = args + out = arg else: raise Exception(f"Unrecognized reducer: {prim}") - - return outs, [None] * len(outs) + return out, None def _reduction_with_positional_batcher( - prim, vals_in, dims_in, axis_index_groups, - transform_unmapped, transform_mapped): + prim, v, d, axis_index_groups, transform_unmapped, transform_mapped): if axis_index_groups is not None: raise NotImplementedError("axis_index_groups not supported in vmap collectives. " "Please open a feature request!") - vals_in = [val if d is batching.not_mapped or d == 0 else _moveaxis(d, 0, val) - for val, d in zip(vals_in, dims_in)] - mapped_vals_in, unmapped_vals_in = partitioned_vals_in = [], [] - mapped_idxs, unmapped_idxs = partitioned_idxs = [], [] - for i, (val, d) in enumerate(zip(vals_in, dims_in)): - partitioned_vals_in[d is batching.not_mapped].append(val) - partitioned_idxs[d is batching.not_mapped].append(i) - vals_out = [None] * len(vals_in) - if unmapped_vals_in: - unmapped_axes, unmapped_vals_in = transform_unmapped(0, unmapped_vals_in) - unmapped_vals_out = prim.bind(*unmapped_vals_in, axes=unmapped_axes, axis_index_groups=None) - for i, val in zip(unmapped_idxs, unmapped_vals_out): - vals_out[i] = val - if mapped_vals_in: - mapped_axes, mapped_vals_in = transform_mapped(0, mapped_vals_in) - mapped_vals_out = prim.bind(*mapped_vals_in, axes=mapped_axes, axis_index_groups=None) - for i, val in zip(mapped_idxs, mapped_vals_out): - vals_out[i] = val - assert all(v is not None for v in vals_out) - return vals_out - -def _reduction_batcher(prim, vals_in, dims_in, *, axes, axis_index_groups): - assert prim.multiple_results + v = v if d is batching.not_mapped or d == 0 else _moveaxis(d, 0, v) + if d is batching.not_mapped: + unmapped_axes, unmapped_vals_in = transform_unmapped(0, v) + return (prim.bind(unmapped_vals_in, axes=unmapped_axes) + if prim is psum_invariant_p else + prim.bind(unmapped_vals_in, axes=unmapped_axes, axis_index_groups=None)) + + mapped_axes, mapped_vals_in = transform_mapped(0, v) + return (prim.bind(mapped_vals_in, axes=mapped_axes) + if prim is psum_invariant_p else + prim.bind(mapped_vals_in, axes=mapped_axes, axis_index_groups=None)) + +def _reduction_batcher(prim, v, d, *, axes, axis_index_groups): + assert not prim.multiple_results if not any(isinstance(axis, int) for axis in axes): - return prim.bind(*vals_in, axes=axes, axis_index_groups=axis_index_groups), dims_in - vals_out = _reduction_with_positional_batcher( - prim, vals_in, dims_in, axis_index_groups, - lambda d, d_vals_in: (axes, d_vals_in), - lambda d, d_vals_in: (tuple(axis + (axis >= d) if isinstance(axis, int) else axis - for axis in axes), - d_vals_in)) + out = (prim.bind(v, axes=axes) if prim is psum_invariant_p else + prim.bind(v, axes=axes, axis_index_groups=axis_index_groups)) + return out, d + val_out = _reduction_with_positional_batcher( + prim, v, d, axis_index_groups, + lambda d, v: (axes, v), + lambda d, v: (tuple(axis + (axis >= d) if isinstance(axis, int) else axis + for axis in axes), + v)) # _reduction_with_positional_batcher moves all map dims to 0 - return vals_out, [d if d is batching.not_mapped else 0 for d in dims_in] + return val_out, d if d is batching.not_mapped else 0 -def _batched_reduction_collective( - prim, if_unmapped, axis_data, vals_in, dims_in, axes, - axis_index_groups): - assert prim.multiple_results - if all(d is None for d in dims_in): +def _batched_reduction_collective(prim, if_unmapped, axis_data, vals_in, + dims_in, axes, axis_index_groups): + assert not prim.multiple_results + (v,), (d,) = vals_in, dims_in + del vals_in, dims_in + + if d is None: if axis_data.name in axes: - return _constant_reduction(prim, axis_data, vals_in, axes, axis_index_groups) + return _constant_reduction(prim, axis_data, v, axes, axis_index_groups) else: - return prim.bind(*vals_in, axes=axes, axis_index_groups=axis_index_groups), dims_in + out = (prim.bind(v, axes=axes) if prim is psum_invariant_p else + prim.bind(v, axes=axes, axis_index_groups=axis_index_groups)) + return out, d if axis_data.name not in axes: - return _reduction_batcher(prim, vals_in, dims_in, axes=axes, - axis_index_groups=axis_index_groups) + return _reduction_batcher( + prim, v, d, axes=axes, axis_index_groups=axis_index_groups) # Note that we have a choice here. We can either unfuse the reduction into one # that handles the batched dims and then another one that handles the rest. @@ -751,15 +904,14 @@ def _batched_reduction_collective( # we have to split the primitive into one for unmapped inputs and another # one for mapped, because they differ in their `axes` parameter. # We choose the second strategy here. - vals_out = _reduction_with_positional_batcher( - prim, vals_in, dims_in, axis_index_groups, - lambda d, d_vals_in: (tuple(axis for axis in axes if axis != axis_data.name), - [if_unmapped(v, axis_data.size) for v in d_vals_in]), - lambda d, d_vals_in: (tuple(axis + (axis >= d) if isinstance(axis, int) else - axis if axis != axis_data.name else - d for axis in axes), - d_vals_in)) - return vals_out, [batching.not_mapped] * len(vals_out) + val_out = _reduction_with_positional_batcher( + prim, v, d, axis_index_groups, + lambda d, v: (tuple(axis for axis in axes if axis != axis_data.name), + if_unmapped(v, axis_data.size)), + lambda d, v: (tuple(axis + (axis >= d) if isinstance(axis, int) else axis + if axis != axis_data.name else d for axis in axes), + v)) + return val_out, batching.not_mapped def _replica_groups(axis_env, axis_name, axis_index_groups): replica_groups = pxla.axis_groups(axis_env, axis_name) @@ -776,38 +928,47 @@ def _replica_groups_hlo(replica_groups: Sequence[Sequence[int]] dtype=np.int64).T return ir.DenseIntElementsAttr.get(np.ascontiguousarray(groups)) -def _allreduce_impl(prim, pos_reducer, *args, axes, axis_index_groups): +def _allreduce_impl(prim, pos_reducer, arg, *, axes, axis_index_groups): assert axis_index_groups is None if not all(isinstance(axis, int) for axis in axes): - return dispatch.apply_primitive(prim, *args, axes=axes, + return dispatch.apply_primitive(prim, arg, axes=axes, axis_index_groups=axis_index_groups) assert all(isinstance(axis, int) for axis in axes) - return [pos_reducer(arg, axes) for arg in args] + return pos_reducer(arg, axes) -def _allreduce_effectful_abstract_eval(*args, axes, axis_index_groups): - _check_axis_names(axes) +def _allreduce_effectful_abstract_eval(aval, *, axes, axis_index_groups): + _check_axis_names(axes, 'psum') named_axes = tuple(axis for axis in axes if not isinstance(axis, int)) pos_axes = tuple(axis for axis in axes if isinstance(axis, int)) if axis_index_groups is not None: if len(pos_axes) != 0: raise ValueError(f"axis_index_groups can only be used with reductions over " f"named axes, but got: {axes}") - core.check_avals_context_mesh(args, 'all_reduce') - out_avals = [ - ShapedArray(lax._reduce_op_shape_rule(arg, axes=pos_axes), arg.dtype, - sharding=lax._reduce_op_sharding_rule(arg, axes=pos_axes)) - for arg in args - ] - return out_avals, {core.NamedAxisEffect(axis) for axis in named_axes} - -def _check_axis_names(axes): + core.check_avals_context_mesh([aval], 'psum') + check_unreduced_args([aval], 'psum') + out_aval = ShapedArray( + lax._reduce_op_shape_rule(aval, axes=pos_axes), aval.dtype, + sharding=lax._reduce_op_sharding_rule(aval, axes=pos_axes)) + return out_aval, {core.NamedAxisEffect(axis) for axis in named_axes} + +# TODO(yashkatariya): Replace this with _psum_invariant_abstract_eval +def _pmin_pmax_abstract_eval(name, aval, *, axes, axis_index_groups): + if not config._check_vma.value: + return _allreduce_effectful_abstract_eval( + aval, axes=axes, axis_index_groups=axis_index_groups) + return _psum_invariant_abstract_eval(name, aval, axes=axes) + +def _check_axis_names(axes, api_name): named_axes = tuple(axis for axis in axes if not isinstance(axis, int)) axis_env = core.get_axis_env() for name in named_axes: if not axis_env.axis_exists(name): - raise NameError(f"unbound axis name: {name}") + raise NameError( + f"Found an unbound axis name: {name}. To fix this, please call" + f" {api_name} under `jax.shard_map`.") -def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups): +def _allreduce_lowering(prim, pos_fn, ctx, arg, *, axes, axis_index_groups): + aval_in, = ctx.avals_in if axis_index_groups is not None and ("tpu" in ctx.module_context.platforms): len_0 = len(axis_index_groups[0]) if any(len(g) != len_0 for g in axis_index_groups): @@ -825,9 +986,9 @@ def _positional_reduce(aval, arg): reducer_ctx = ctx.replace(primitive=None, avals_in=[aval], avals_out=[aval_out]) out, = reducer(reducer_ctx, arg, axes=tuple(positional_axes)) return out - args = map(_positional_reduce, ctx.avals_in, args) + arg = _positional_reduce(aval_in, arg) if not named_axes: - return args + return [arg] replica_groups = _replica_groups_hlo( _replica_groups(ctx.module_context.axis_env, named_axes, @@ -837,20 +998,15 @@ def _positional_reduce(aval, arg): def all_reduce(aval, x): if is_spmd: - channel = ctx.module_context.new_channel() other_args = dict( channel_handle=hlo.ChannelHandle.get( - channel, mlir.DEVICE_TO_DEVICE_TYPE), + mlir.COLLECTIVE_CHANNEL_ID, mlir.DEVICE_TO_DEVICE_TYPE), use_global_device_ids=ir.BoolAttr.get(True)) else: other_args = {} - if hlo.get_api_version() < 8: - op = hlo.AllReduceOp( - x.type, x, replica_groups=replica_groups, **other_args) - else: - op = hlo.AllReduceOp( - [x.type], [x], replica_groups=replica_groups, **other_args) + op = hlo.AllReduceOp( + [x.type], [x], replica_groups=replica_groups, **other_args) scalar_aval = core.ShapedArray( (), aval.dtype, sharding=NamedSharding(aval.sharding.mesh, P())) scalar_type = mlir.aval_to_ir_type(scalar_aval) @@ -862,11 +1018,9 @@ def all_reduce(aval, x): out_nodes = lower_reducer(reducer_ctx, *reducer_block.arguments) hlo.return_(mlir.flatten_ir_values(out_nodes)) return op.result + return [all_reduce(aval_in, arg)] - return [all_reduce(aval, x) for aval, x in zip(ctx.avals_in, args)] - - -def _psum_transpose_rule(cts, *args, axes, axis_index_groups): +def _psum_transpose_rule(cts, arg, *, axes, axis_index_groups): named_axes, pos_axes = axes_partition = [], [] for axis in axes: axes_partition[isinstance(axis, int)].append(axis) @@ -875,18 +1029,16 @@ def _psum_transpose_rule(cts, *args, axes, axis_index_groups): def broadcast_positional(ct, arg): assert ad.is_undefined_primal(arg) if type(ct) is ad.Zero: return ad.Zero(arg.aval) - return lax._reduce_sum_transpose_rule(ct, arg, axes=pos_axes)[0] - cts = map(broadcast_positional, cts, args) + return lax._reduce_sum_transpose_rule(ct, arg, axes=pos_axes, + out_sharding=None)[0] + cts = broadcast_positional(cts, arg) # We treat psum as psum + pbroadcast, which is why the transpose reduces # over the named axes again (unlike for positional axes). - nonzero_out_cts, treedef = tree_util.tree_flatten(cts) - nonzero_in_cts = psum_p.bind(*nonzero_out_cts, axes=tuple(named_axes), - axis_index_groups=axis_index_groups) - return tree_util.tree_unflatten(treedef, nonzero_in_cts) + return (psum_p.bind(cts, axes=tuple(named_axes), + axis_index_groups=axis_index_groups),) psum_p = core.Primitive('psum') -psum_p.multiple_results = True psum_p.def_impl(partial(_allreduce_impl, psum_p, lax.reduce_sum)) psum_p.def_effectful_abstract_eval(_allreduce_effectful_abstract_eval) mlir.register_lowering( @@ -897,9 +1049,8 @@ def broadcast_positional(ct, arg): batching.skippable_batchers[psum_p] = partial(_names_in_param, 'axes') pmax_p = core.Primitive('pmax') -pmax_p.multiple_results = True pmax_p.def_impl(partial(_allreduce_impl, pmax_p, lax.reduce_max)) -pmax_p.def_effectful_abstract_eval(_allreduce_effectful_abstract_eval) +pmax_p.def_effectful_abstract_eval(partial(_pmin_pmax_abstract_eval, 'pmax')) mlir.register_lowering( pmax_p, partial(_allreduce_lowering, lax.max_p, lax.reduce_max)) batching.fancy_primitive_batchers[pmax_p] = \ @@ -908,9 +1059,8 @@ def broadcast_positional(ct, arg): pmin_p = core.Primitive('pmin') -pmin_p.multiple_results = True pmin_p.def_impl(partial(_allreduce_impl, pmin_p, lax.reduce_min)) -pmin_p.def_effectful_abstract_eval(_allreduce_effectful_abstract_eval) +pmin_p.def_effectful_abstract_eval(partial(_pmin_pmax_abstract_eval, 'pmin')) mlir.register_lowering( pmin_p, partial(_allreduce_lowering, lax.min_p, lax.reduce_min)) batching.fancy_primitive_batchers[pmin_p] = \ @@ -918,12 +1068,12 @@ def broadcast_positional(ct, arg): batching.skippable_batchers[pmin_p] = partial(_names_in_param, 'axes') -def _ppermute_lowering(ctx, x, *, axis_name, perm): +def _pcollectives_lowering_common(ctx, *, axis_name, perm, op_name): replica_groups = _replica_groups(ctx.module_context.axis_env, axis_name, None) group_size = len(replica_groups[0]) srcs, dsts = unzip2((src % group_size, dst % group_size) for src, dst in perm) if not (len(srcs) == len(set(srcs)) and len(dsts) == len(set(dsts))): - msg = "ppermute sources and destinations must be unique, got {}." + msg = f"{op_name} sources and destinations must be unique, got {{}}." raise ValueError(msg.format(perm)) full_perm = np.zeros((len(replica_groups), len(perm), 2), np.int64) @@ -940,15 +1090,24 @@ def _ppermute_lowering(ctx, x, *, axis_name, perm): and axis_context.manual_axes ) if is_manual: - channel = ctx.module_context.new_channel() other_args = dict( - channel_handle=hlo.ChannelHandle.get(channel, mlir.DEVICE_TO_DEVICE_TYPE)) + channel_handle=hlo.ChannelHandle.get( + mlir.COLLECTIVE_CHANNEL_ID, mlir.DEVICE_TO_DEVICE_TYPE + ) + ) else: other_args = {} + return full_perm, other_args + +def _ppermute_lowering(ctx, x, *, axis_name, perm): + full_perm, other_args = _pcollectives_lowering_common( + ctx, axis_name=axis_name, perm=perm, op_name="ppermute" + ) return hlo.CollectivePermuteOp( x, mlir.dense_int_elements(full_perm), **other_args).results + def _ppermute_transpose_rule(t, x, perm, axis_name): srcs, dsts = unzip2(perm) inverse_perm = list(zip(dsts, srcs)) @@ -974,7 +1133,9 @@ def _ppermute_batcher(axis_data, vals_in, dims_in, axis_name, perm): return v.take(perm_indices, d), d def _raise_to_shaped_abstract_eval(x, *, axis_name, **params): - _check_axis_names(axis_name) + _check_axis_names(axis_name, 'ppermute') + collective_vma_rule('ppermute', axis_name, x) + check_unreduced_args([x], 'ppermute') return x ppermute_p = core.Primitive('ppermute') @@ -984,6 +1145,108 @@ def _raise_to_shaped_abstract_eval(x, *, axis_name, **params): batching.fancy_primitive_batchers[ppermute_p] = _ppermute_batcher batching.skippable_batchers[ppermute_p] = partial(_names_in_param, 'axis_name') + +@dataclass(frozen=True) +class SingleSideCollectiveEffect(core.Effect): + __str__ = lambda _: "one-sided communication" + def __hash__(self): + return hash(SingleSideCollectiveEffect) + def __eq__(self, other): + return isinstance(other, SingleSideCollectiveEffect) + + +single_side_collective_effect = SingleSideCollectiveEffect() +core.effects.control_flow_allowed_effects.add_type(SingleSideCollectiveEffect) + +def _psend_lowering_gpu(ctx, x, *, axis_name, perm): + if all(p not in ctx.module_context.platforms for p in ("cuda", "rocm")): + raise NotImplementedError("psend is currently only implemented on GPUs") + + full_perm, other_args = _pcollectives_lowering_common( + ctx, axis_name=axis_name, perm=perm, op_name="psend" + ) + token = hlo.create_token() + send_op = hlo.SendOp( + [x], + token, + source_target_pairs=mlir.dense_int_elements(full_perm), + **other_args, + ) + axis_ctx = ctx.module_context.axis_context + if not isinstance(axis_ctx, SPMDAxisContext): + raise NotImplementedError("psend currently only supports manual sharding") + + sharding = xc.OpSharding() + sharding.type = xc.OpSharding.Type.MANUAL + mlir.set_sharding(send_op, sharding) + return send_op.results + + +effects_lib.lowerable_effects.add_type(SingleSideCollectiveEffect) + + +def _psend_abstract_eval(x, *, axis_name, **params): + _check_axis_names(axis_name, 'psend') + return abstract_token, { + *map(core.NamedAxisEffect, axis_name), + single_side_collective_effect, + } + + +psend_p = core.Primitive("psend") +psend_p.def_impl(partial(dispatch.apply_primitive, psend_p)) +psend_p.def_effectful_abstract_eval(_psend_abstract_eval) +mlir.register_lowering(psend_p, _psend_lowering_gpu, platform="gpu") + +def _psend_lowering(ctx, x, *, axis_name, perm): + raise NotImplementedError("psend is currently only implemented on GPU") +mlir.register_lowering(psend_p, _psend_lowering) + +batching.fancy_primitive_batchers[psend_p] = _ppermute_batcher +batching.skippable_batchers[psend_p] = partial(_names_in_param, "axis_name") + + +def _precv_lowering_gpu(ctx, token, *, out_shape, axis_name, perm): + full_perm, other_args = _pcollectives_lowering_common( + ctx, axis_name=axis_name, perm=perm, op_name="precv" + ) + recv_op = hlo.RecvOp( + [mlir.aval_to_ir_type(out_shape), token.type], + token, + source_target_pairs=mlir.dense_int_elements(full_perm), + **other_args, + ) + axis_ctx = ctx.module_context.axis_context + if not isinstance(axis_ctx, SPMDAxisContext): + raise NotImplementedError("precv currently only supports manual sharding") + + sharding = xc.OpSharding() + sharding.type = xc.OpSharding.Type.MANUAL + mlir.set_sharding(recv_op, sharding) + + # recv_op should return an array of [RankedTensorType, StableHlo.token]; we + # only need the tensor. + results = recv_op.results + return [results[0]] + + +def _precv_abstract_eval( + token, *, out_shape, axis_name, **params +): + return out_shape, {*map(core.NamedAxisEffect, axis_name), + single_side_collective_effect} + +precv_p = core.Primitive("precv") +precv_p.def_effectful_abstract_eval(_precv_abstract_eval) +mlir.register_lowering(precv_p, _precv_lowering_gpu, platform='gpu') + +def _precv_lowering(ctx, token, *, out_shape, axis_name, perm): + raise NotImplementedError("precv is currently only implemented on GPU") +mlir.register_lowering(precv_p, _precv_lowering) + +batching.fancy_primitive_batchers[precv_p] = _ppermute_batcher +batching.skippable_batchers[precv_p] = partial(_names_in_param, "axis_name") + def _pbroadcast_transpose_rule(t, x, source, axis_name): is_source = axis_index(axis_name) == source tsum = psum(t, axis_name) @@ -1012,14 +1275,27 @@ def _pbroadcast_lowering(ctx, x, *, axis_name, source): def source_to_front(group): return [group[source]] + list(group[:source]) + list(group[source + 1:]) replica_groups = [source_to_front(group) for group in replica_groups] - channel = ctx.module_context.new_channel() + is_spmd = isinstance( + ctx.module_context.axis_context, + (SPMDAxisContext, ShardingContext), + ) + if is_spmd: + # We want to emit the collective-broadcast with global device IDs and a + # channel ID, as otherwise it interprets the devices as replicas instead + # of partitions - and XLA is configured with only a single replica. + channel_handle = hlo.ChannelHandle.get(mlir.COLLECTIVE_CHANNEL_ID, + mlir.DEVICE_TO_DEVICE_TYPE) + other_args = dict(channel_handle=channel_handle) + else: + other_args = {} return hlo.CollectiveBroadcastOp( - x, replica_groups=_replica_groups_hlo(replica_groups)).results + x, replica_groups=_replica_groups_hlo(replica_groups), **other_args + ).results pbroadcast_p = core.Primitive('pbroadcast') pbroadcast_p.def_abstract_eval(_raise_to_shaped_abstract_eval) ad.deflinear2(pbroadcast_p, _pbroadcast_transpose_rule) -mlir.register_lowering(pbroadcast_p, _pbroadcast_lowering) +mlir.register_lowering(pbroadcast_p, _pbroadcast_lowering, platform='gpu') batching.fancy_primitive_batchers[pbroadcast_p] = _pbroadcast_batcher batching.skippable_batchers[pbroadcast_p] = partial(_names_in_param, 'axis_name') @@ -1057,22 +1333,14 @@ def _all_to_all_lowering( (SPMDAxisContext, ShardingContext), ) if is_spmd: - # We want to emit the all-gather with global device IDs and a unique + # We want to emit the all-gather with global device IDs and a # channel ID, as otherwise it interprets the devices as replicas instead # of partitions - and XLA is configured with only a single replica. - channel = ctx.module_context.new_channel() - channel_handle = hlo.ChannelHandle.get(channel, mlir.DEVICE_TO_DEVICE_TYPE) + channel_handle = hlo.ChannelHandle.get(mlir.COLLECTIVE_CHANNEL_ID, + mlir.DEVICE_TO_DEVICE_TYPE) other_args = dict(channel_handle=channel_handle) else: other_args = {} - if hlo.get_api_version() < 8: - return hlo.AllToAllOp( - x, - split_dimension=mlir.i64_attr(split_axis), - concat_dimension=mlir.i64_attr(concat_axis), - split_count=mlir.i64_attr(split_count), - replica_groups=_replica_groups_hlo(replica_groups), - **other_args).results return hlo.AllToAllOp( [x], split_dimension=mlir.i64_attr(split_axis), @@ -1109,15 +1377,15 @@ def _all_to_all_batcher(vals_in, dims_in, *, axis_name, split_axis, concat_axis, def _all_to_all_batched_collective(axis_data, vals_in, dims_in, axis_name, split_axis, concat_axis, axis_index_groups, tiled): - axis_size, frame_name = axis_data.size, axis_data.name if axis_index_groups is not None: raise NotImplementedError("Please open a feature request!") + axis_size, frame_name = axis_data.size, axis_data.name if isinstance(axis_name, (list, tuple)): axes_names = axis_name else: axes_names = [axis_name] - if axis_data.name not in axes_names: + if frame_name not in axes_names: return _all_to_all_batcher( vals_in, dims_in, axis_name=axis_name, split_axis=split_axis, concat_axis=concat_axis, axis_index_groups=axis_index_groups, tiled=tiled) @@ -1157,6 +1425,7 @@ def _all_to_all_batched_collective(axis_data, vals_in, dims_in, axis_index_groups=axis_index_groups, tiled=tiled) # Split out the local part into axis new_d (NOTE: d is already in axis 1) + assert d == 1 x = _splitaxis(split_axis, axis_size, x) new_d = split_axis concat_axis += (split_axis <= concat_axis) # Offset the existing axes by the new batch axis @@ -1182,18 +1451,28 @@ def _all_to_all_effectful_abstract_eval( del tiled # expand_dims and squeeze is done in `all_to_all` if `True` if not isinstance(axis_name, (list, tuple)): axis_name = (axis_name,) - _check_axis_names(axis_name) + _check_axis_names(axis_name, 'all_to_all') + check_unreduced_args([input_aval], 'all_to_all') shape = list(input_aval.shape) - axis_size = psum(1, axis_name) if axis_index_groups is None else len(axis_index_groups[0]) + axis_size = ( + _axis_size(axis_name) + if axis_index_groups is None + else len(axis_index_groups[0]) + ) assert shape[split_axis] % axis_size == 0, (shape[split_axis], axis_size) shape[split_axis] //= axis_size shape[concat_axis] *= axis_size - out_aval = input_aval.update(shape=tuple(shape), weak_type=False) + vma = collective_vma_rule('all_to_all', axis_name, input_aval) + out_aval = input_aval.update(shape=tuple(shape), weak_type=False, vma=vma) effects = {*map(core.NamedAxisEffect, axis_name)} return out_aval, effects +def _all_to_all_impl(*args, **kwargs): + raise RuntimeError("all_to_all must be used within a mapped context" + " like vmap or shard_map.") all_to_all_p = core.Primitive('all_to_all') +all_to_all_p.def_impl(_all_to_all_impl) all_to_all_p.def_effectful_abstract_eval(_all_to_all_effectful_abstract_eval) mlir.register_lowering(all_to_all_p, _all_to_all_lowering) ad.deflinear2(all_to_all_p, _all_to_all_transpose_rule) @@ -1220,7 +1499,7 @@ def _ragged_all_to_all_lowering( ctx.module_context.axis_context, (SPMDAxisContext, ShardingContext)) if is_spmd: ragged_all_to_all_attrs['channel_id'] = ir.IntegerAttr.get( - ir.IntegerType.get_signless(64), ctx.module_context.new_channel() + ir.IntegerType.get_signless(64), mlir.COLLECTIVE_CHANNEL_ID ) return hlo.CustomCallOp( @@ -1266,7 +1545,7 @@ def _ragged_all_to_all_effectful_abstract_eval( " size, but got shape {}".format(recv_sizes.shape) ) - _check_axis_names(axis_name) + _check_axis_names(axis_name, 'ragged_all_to_all') out_aval = output.update(shape=output.shape, weak_type=False) effects = {*map(core.NamedAxisEffect, axis_name)} return out_aval, effects @@ -1298,21 +1577,65 @@ def _ragged_all_to_all_transpose( operand_t = ragged_all_to_all_p.bind( t, zero, output_offsets_, recv_sizes, input_offsets_, send_sizes, axis_name=axis_name, axis_index_groups=axis_index_groups) - mask = jax.numpy.cumsum( - jax.numpy.zeros(t.shape[0], dtype='int32').at[output_offsets_].set(1)\ + mask = control_flow.cumsum( + lax.full(t.shape[0], 0, dtype='int32').at[output_offsets_].set(1) .at[output_offsets_ + recv_sizes].add(-1)) - mask = jax.numpy.expand_dims(mask, (*range(1, t.ndim),)) - output_t = jax.numpy.where(mask, 0, t) + mask = lax.expand_dims(mask, (*range(1, t.ndim),)) + mask = lax.broadcast_in_dim(mask, shape=t.shape, broadcast_dimensions=tuple(range(t.ndim))) + output_t = lax.select(mask, lax._zeros(t), t) return [operand_t, output_t] + [None] * 4 +def _ragged_all_to_all_batched_collective(axis_data, vals_in, dims_in, + axis_name, axis_index_groups): + if axis_data.name in axis_name: + raise NotImplementedError("Please open a feature request!") + if axis_index_groups: + raise NotImplementedError("Please open a feature request!") + size = axis_data.size + + def bdim_at_second(x, d): + assert x.ndim == 2 + return (batching.broadcast(x, size, 1, None) if d is None else + x if d == 1 else x.T) + def merge(x): return x.reshape(-1, *x.shape[2:]) + def split(x): return x.reshape(size, -1, *x.shape[1:]) + + operand, output = map(partial(batching.bdim_at_front, size=size), vals_in[:2], dims_in[:2]) + N, M = operand.shape[1], output.shape[1] + input_offsets, send_sizes, output_offsets, recv_sizes = \ + map(bdim_at_second, vals_in[2:], dims_in[2:]) + input_offsets += lax.iota(input_offsets.dtype, size)[None, :] * N + output_offsets += lax.iota(output_offsets.dtype, size)[None, :] * M + vals_in = operand, output, input_offsets, send_sizes, output_offsets, recv_sizes + result = split(ragged_all_to_all(*map(merge, vals_in), axis_name=axis_name)) + return result, 0 + +def _ragged_all_to_all_impl(*args, **kwargs): + raise RuntimeError("ragged_all_to_all must be used within a mapped context" + " like vmap or shard_map.") + ragged_all_to_all_p = core.Primitive('ragged_all_to_all') +ragged_all_to_all_p.def_impl(_ragged_all_to_all_impl) ragged_all_to_all_p.def_effectful_abstract_eval(_ragged_all_to_all_effectful_abstract_eval) ad.primitive_jvps[ragged_all_to_all_p] = _ragged_all_to_all_jvp ad.primitive_transposes[ragged_all_to_all_p] = _ragged_all_to_all_transpose mlir.register_lowering(ragged_all_to_all_p, _ragged_all_to_all_lowering) +batching.fancy_primitive_batchers[ragged_all_to_all_p] = _ragged_all_to_all_batched_collective batching.skippable_batchers[ragged_all_to_all_p] = partial(_names_in_param, 'axis_name') -def all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False): + +def insert_collective_pvary(axis_name, x): + if not config._check_vma.value: + return x + + axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name + aval = core.get_aval(x) + names_union = set(axis_name) | aval.vma + x = pvary(x, tuple(n for n in names_union if n not in aval.vma)) + return x + +def all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False, + to: str = 'varying'): """Gather values of x across all replicas. If ``x`` is a pytree then the result is equivalent to mapping this function to @@ -1376,17 +1699,36 @@ def all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False): [[12 13 14 15] [ 4 5 6 7]]] """ + _allowed_ag_to = {'varying', 'reduced'} + if to not in _allowed_ag_to: + raise ValueError( + "Got unexpected `to` value for `jax.lax.all_gather`. Allowed `to`" + f" values are: {_allowed_ag_to}") + if to == 'varying': + return _all_gather(x, axis_name, axis_index_groups=axis_index_groups, + axis=axis, tiled=tiled) + else: + assert to == 'reduced' + if axis_index_groups is not None: + raise NotImplementedError + return all_gather_reduced(x, axis_name, axis=axis, tiled=tiled) + + +def _all_gather(x, axis_name, *, axis_index_groups, axis, tiled): if not isinstance(axis_name, tuple): - axis_name = axis_name, + axis_name = (axis_name,) + if not axis_name: + return x axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups) - axis_size = psum(1, axis_name, axis_index_groups=axis_index_groups) + axis_size = _axis_size(axis_name, axis_index_groups) def bind(leaf): + leaf = insert_collective_pvary(axis_name, leaf) return all_gather_p.bind( leaf, all_gather_dimension=canonicalize_axis( axis, np.ndim(leaf) if tiled else np.ndim(leaf) + 1), axis_name=axis_name, axis_index_groups=axis_index_groups, - axis_size=int(axis_size), tiled=tiled) + axis_size=axis_size, tiled=tiled) return tree_util.tree_map(bind, x) def _all_gather_impl(x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled): @@ -1409,23 +1751,16 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name, replica_groups = _replica_groups(ctx.module_context.axis_env, axis_name, axis_index_groups) if is_spmd: - # We want to emit the all-gather with global device IDs and a unique + # We want to emit the all-gather with global device IDs and a # channel ID, as otherwise it interprets the devices as replicas instead # of partitions - and XLA is configured with only a single replica. - channel = ctx.module_context.new_channel() other_args = dict( channel_handle=hlo.ChannelHandle.get( - channel, mlir.DEVICE_TO_DEVICE_TYPE), + mlir.COLLECTIVE_CHANNEL_ID, mlir.DEVICE_TO_DEVICE_TYPE), use_global_device_ids=ir.BoolAttr.get(True)) else: other_args = {} - if hlo.get_api_version() < 8: - return hlo.AllGatherOp( - mlir.aval_to_ir_type(out_aval), - x, all_gather_dim=mlir.i64_attr(all_gather_dimension), - replica_groups=_replica_groups_hlo(replica_groups), - **other_args).results return hlo.AllGatherOp( [mlir.aval_to_ir_type(out_aval)], [x], all_gather_dim=mlir.i64_attr(all_gather_dimension), @@ -1433,50 +1768,70 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name, **other_args).results +def collective_vma_rule(prim_name, axis_name, x_aval): + if not config._check_vma.value: + return frozenset() + axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name + if any(a not in x_aval.vma for a in axis_name): + raise ValueError( + f"Collective {prim_name} must be applied to a device-varying " + f" type, but got {x_aval.vma} for collective acting " + f"over axis name {axis_name}. Please open an issue at " + "https://github.com/jax-ml/jax/issues and as a temporary " + "workaround pass the check_vma=False argument to `jax.shard_map`") + return x_aval.vma + def _all_gather_effectful_abstract_eval( x_aval, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled ): if not isinstance(axis_name, (list, tuple)): axis_name = (axis_name,) - _check_axis_names(axis_name) + _check_axis_names(axis_name, 'all_gather') + check_unreduced_args([x_aval], 'all_gather') new_shape = list(x_aval.shape) if tiled: new_shape[all_gather_dimension] *= axis_size else: new_shape.insert(all_gather_dimension, axis_size) - return x_aval.update(shape=new_shape), {*map(core.NamedAxisEffect, axis_name)} + out_vma = collective_vma_rule('all_gather', axis_name, x_aval) + return (x_aval.update(shape=new_shape, vma=out_vma), + {*map(core.NamedAxisEffect, axis_name)}) -def _all_gather_transpose_rule(cts, x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled): +def _all_gather_transpose_rule(cts, x, *, all_gather_dimension, axis_name, + axis_index_groups, axis_size, tiled): return (psum_scatter(cts, axis_name=axis_name, scatter_dimension=all_gather_dimension, axis_index_groups=axis_index_groups, tiled=tiled),) - # TODO(sharadmv,apaszke): re-enable this when we can properly detect replication. - # return (lax.dynamic_index_in_dim(cts, idx, axis=all_gather_dimension, keepdims=False) * axis_size,) -def _all_gather_batcher(vals_in, dims_in, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled): +def _all_gather_batcher(prim, vals_in, dims_in, *, all_gather_dimension, axis_name, + axis_index_groups, axis_size, tiled): (x,), (d,) = vals_in, dims_in if d is not batching.not_mapped: if d <= all_gather_dimension: all_gather_dimension += 1 elif not tiled: # Tiled all-gather doesn't modify the set of dimensions d += 1 - result = all_gather_p.bind( - x, - all_gather_dimension=all_gather_dimension, - axis_name=axis_name, - axis_index_groups=axis_index_groups, - axis_size=axis_size, - tiled=tiled) - return result, d + if prim is all_gather_p: + result = all_gather_p.bind( + x, all_gather_dimension=all_gather_dimension, axis_name=axis_name, + axis_index_groups=axis_index_groups, axis_size=axis_size, + tiled=tiled) + return result, d + else: + assert prim is all_gather_invariant_p + result = all_gather_invariant_p.bind( + x, all_gather_dimension=all_gather_dimension, axis_name=axis_name, + axis_size=axis_size, tiled=tiled) + return result, d -def _all_gather_batched_collective(axis_data, vals_in, dims_in, +def _all_gather_batched_collective(prim, axis_data, vals_in, dims_in, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled): frame_size, frame_name = axis_data.size, axis_data.name if frame_name not in axis_name: return _all_gather_batcher( - vals_in, dims_in, all_gather_dimension=all_gather_dimension, + prim, vals_in, dims_in, all_gather_dimension=all_gather_dimension, axis_name=axis_name, axis_index_groups=axis_index_groups, axis_size=axis_size, tiled=tiled) if axis_index_groups is not None: @@ -1508,10 +1863,103 @@ def _all_gather_batched_collective(axis_data, vals_in, dims_in, partial(_all_gather_lowering, platform=p), platform=p) ad.deflinear2(all_gather_p, _all_gather_transpose_rule) -batching.fancy_primitive_batchers[all_gather_p] = _all_gather_batched_collective +batching.fancy_primitive_batchers[all_gather_p] = partial( + _all_gather_batched_collective, all_gather_p) batching.skippable_batchers[all_gather_p] = partial(_names_in_param, 'axis_name') +def all_gather_invariant(x, axis_name, *, axis: int = 0, tiled: bool = False): + """Gather values of x across all replicas. + + If ``x`` is a pytree then the result is equivalent to mapping this function to + each leaf in the tree. + + all_gather_invariant differs from all_gather in the following ways: + + * all_gather_invariant is Varying -> Invariant. + For example: `out: f32[8] = all_gather_invariant(inp: f32[4]{V: x}, 'x')` + where the size of mesh axis `x` is 2. + While all_gather is Varying -> Varying. + + * all_gather_invariant transposes to dynamic_slice which is + Invariant -> Varying. While all_gather transposes to reduce_scatter + which is Varying -> Varying. + """ + if not isinstance(axis_name, tuple): + axis_name = (axis_name,) + if not axis_name: + return x + axis_size = _axis_size(axis_name, None) + axes_ = frozenset(axis_name) + def bind(leaf): + in_vma = core.typeof(leaf).vma + if vary_names := axes_ - in_vma: + leaf = pvary(leaf, tuple(vary_names)) + return all_gather_invariant_p.bind( + leaf, + all_gather_dimension=canonicalize_axis(axis, np.ndim(leaf) if tiled else + np.ndim(leaf) + 1), + axis_name=axis_name, axis_size=axis_size, tiled=tiled) + return tree_util.tree_map(bind, x) + +all_gather_invariant_p = core.Primitive('all_gather_invariant') + +def _all_gather_invariant_effectful_abstract_eval( + x_aval, *, all_gather_dimension, axis_name, axis_size, tiled +): + _check_axis_names(axis_name, 'all_gather_invariant') + check_unreduced_args([x_aval], 'all_gather_invariant') + new_shape = list(x_aval.shape) + if tiled: + new_shape[all_gather_dimension] *= axis_size + else: + new_shape.insert(all_gather_dimension, axis_size) + out_vma = frozenset(v for v in x_aval.vma if v not in axis_name) + return (x_aval.update(shape=new_shape, vma=out_vma), + {*map(core.NamedAxisEffect, axis_name)}) + +all_gather_invariant_p.def_effectful_abstract_eval( + _all_gather_invariant_effectful_abstract_eval) + +def _all_gather_invariant_impl(x, *, all_gather_dimension, axis_name, axis_size, + tiled): + raise NotImplementedError +all_gather_invariant_p.def_impl(_all_gather_invariant_impl) + + +def _all_gather_invariant_lowering( + ctx, x, *, all_gather_dimension, axis_name, axis_size, tiled, platform=None): + return _all_gather_lowering( + ctx, x, all_gather_dimension=all_gather_dimension, axis_name=axis_name, + axis_index_groups=None, axis_size=axis_size, tiled=tiled, + platform=platform) + +mlir.register_lowering(all_gather_invariant_p, _all_gather_invariant_lowering) +for p in ("cuda", "rocm", "tpu"): + mlir.register_lowering(all_gather_invariant_p, + partial(_all_gather_invariant_lowering, platform=p), + platform=p) + +def _all_gather_invariant_transpose_rule( + cts, x, *, all_gather_dimension, axis_name, axis_size, tiled): + slice_size, rem = divmod(cts.shape[all_gather_dimension], axis_size) + assert not rem + idx = axis_index(axis_name) * slice_size + out = slicing.dynamic_slice_in_dim( + cts, idx, slice_size=slice_size, axis=all_gather_dimension) + return (out,) if tiled else (lax.squeeze(out, [all_gather_dimension]),) +ad.deflinear2(all_gather_invariant_p, _all_gather_invariant_transpose_rule) + +def _all_gather_invariant_batched_collective( + axis_data, vals_in, dims_in, all_gather_dimension, axis_name, axis_size, + tiled): + return _all_gather_batched_collective( + all_gather_invariant_p, axis_data, vals_in, dims_in, all_gather_dimension, + axis_name, None, axis_size, tiled) +batching.fancy_primitive_batchers[all_gather_invariant_p] = _all_gather_invariant_batched_collective +batching.skippable_batchers[all_gather_invariant_p] = partial(_names_in_param, 'axis_name') + + def _reduce_scatter_lowering( prim, ctx, x, *, scatter_dimension, axis_name, @@ -1529,13 +1977,12 @@ def _reduce_scatter_lowering( (SPMDAxisContext, ShardingContext), ) if is_spmd: - # We want to emit the all-gather with global device IDs and a unique + # We want to emit the all-gather with global device IDs and a # channel ID, as otherwise it interprets the devices as replicas instead # of partitions - and XLA is configured with only a single replica. - channel = ctx.module_context.new_channel() other_args = dict( channel_handle=hlo.ChannelHandle.get( - channel, mlir.DEVICE_TO_DEVICE_TYPE), + mlir.COLLECTIVE_CHANNEL_ID, mlir.DEVICE_TO_DEVICE_TYPE), use_global_device_ids=ir.BoolAttr.get(True)) else: other_args = {} @@ -1566,7 +2013,8 @@ def _reduce_scatter_effectful_abstract_eval( ): if not isinstance(axis_name, (list, tuple)): axis_name = (axis_name,) - _check_axis_names(axis_name) + _check_axis_names(axis_name, 'reduce_scatter') + check_unreduced_args([x_aval], 'reduce_scatter') new_shape = list(x_aval.shape) scatter_dim_input_size = x_aval.shape[scatter_dimension] if tiled: @@ -1581,7 +2029,9 @@ def _reduce_scatter_effectful_abstract_eval( f"{scatter_dim_input_size} must match shard count " f"{axis_size}") del new_shape[scatter_dimension] - return x_aval.update(shape=new_shape), {*map(core.NamedAxisEffect, axis_name)} + vma = collective_vma_rule('reduce_scatter', axis_name, x_aval) + return (x_aval.update(shape=new_shape, vma=vma), + {*map(core.NamedAxisEffect, axis_name)}) def _reduce_scatter_transpose_rule(cts, x, *, axis_name, scatter_dimension, @@ -1721,21 +2171,42 @@ def psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None, [12 14] [16 18]] """ + axes = (axis_name,) if not isinstance(axis_name, tuple) else axis_name + # TODO(yashkatariya): Remove this handling and remove_size_one_mesh_axis_from_type + # generally from JAX. + axes = _maybe_skip_one_sized_axes(axes) + if not axes: + return x + def bind(leaf): + from_ = _get_from(core.typeof(leaf), axes, 'jax.lax.psum_scatter') + if from_ == 'unreduced': + if axis_index_groups is not None: + raise NotImplementedError + return unreduced_psum_scatter( + leaf, axes, scatter_dimension=scatter_dimension, tiled=tiled) + else: + return _psum_scatter(leaf, axes, scatter_dimension=scatter_dimension, + axis_index_groups=axis_index_groups, tiled=tiled) + return tree_util.tree_map(bind, x) + +def _psum_scatter(x, axis_name, *, scatter_dimension, axis_index_groups, tiled): if not isinstance(axis_name, tuple): - axis_name = axis_name, - axis_size = psum(1, axis_name, axis_index_groups=axis_index_groups) + axis_name = (axis_name,) + if not axis_name: + return x + axis_size = _axis_size(axis_name, axis_index_groups) axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups) - bind = partial( - reduce_scatter_p.bind, - axis_name=axis_name, - scatter_dimension=scatter_dimension, - axis_index_groups=axis_index_groups, - axis_size=axis_size, - tiled=tiled) + def bind(leaf): + leaf = insert_collective_pvary(axis_name, leaf) + return reduce_scatter_p.bind( + leaf, axis_name=axis_name, scatter_dimension=scatter_dimension, + axis_index_groups=axis_index_groups, axis_size=axis_size, tiled=tiled) return tree_util.tree_map(bind, x) def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env): + from jax._src.shard_map import shard_map # pytype: disable=import-error + if isinstance(axis_name, tuple): assert axis_name, 'empty axis name' if len(axis_name) > 1: @@ -1753,12 +2224,11 @@ def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env): axis_context.manual_axes != frozenset(axis_context.mesh.axis_names)): if axis_env.sizes[axis_pos] == 1: return hlo.constant(ir.DenseElementsAttr.get(np.asarray(0, dtype=np.int32))) - from jax.experimental.shard_map import shard_map def f(): return axis_index_p.bind(axis_name=axis_name) return mlir.lower_fun( - lambda: [shard_map(f, axis_context.mesh, check_rep=False, - in_specs=(), out_specs=P())()])(ctx)[0] + lambda: [shard_map(f, check_vma=False, in_specs=(), + out_specs=P())()])(ctx)[0] nreplicas = axis_env.nreps // math.prod(axis_env.sizes) div = mlir.ir_constant( @@ -1781,8 +2251,14 @@ def _axis_index_lowering(ctx, *, axis_name): ctx.module_context.axis_env)] def _axis_index_effectful_abstract_eval(*, axis_name): - _check_axis_names([axis_name]) - return ShapedArray((), np.int32), {core.NamedAxisEffect(axis_name)} + effect = {core.NamedAxisEffect(axis_name)} + axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name + _check_axis_names(axis_name, 'axis_index') + mesh = get_abstract_mesh() + sharding = NamedSharding(mesh, P()) + vma = ((frozenset(axis_name) if mesh._any_axis_manual else frozenset()) + if config._check_vma.value else frozenset()) + return ShapedArray((), np.int32, sharding=sharding, vma=vma), effect def _axis_index_batcher(axis_data, vals_in, dims_in, *, axis_name): return lax.iota(np.int32, axis_data.size), 0 @@ -1813,7 +2289,7 @@ def _pgather_impl(src, idx, *, axes): def _pgather_abstract_eval(src, idx, *, axes): # TODO: Avals with names rule: remove all axes from src, insert those from idx # The order is important, because it is ok to re-insert one of the deleted axes! - _check_axis_names(axes) + _check_axis_names(axes, 'pgather') shape = list(src.shape) for axis in sorted((a for a in axes if isinstance(a, int)), reverse=True): del shape[axis] @@ -1856,3 +2332,530 @@ def _pgather_collective_batcher(axis_size, frame_name, _, vals_in, dims_in, *, a # TODO: Transpose? That requires adding pscatter... batching.fancy_primitive_batchers[pgather_p] = _pgather_collective_batcher batching.skippable_batchers[pgather_p] = partial(_names_in_param, 'axes') + +######################## psum_invariant_p #################################### + +def bind_psum_invariant(leaf, *, axes, axis_index_groups): + if axis_index_groups is not None: + raise NotImplementedError + axes_ = frozenset(axes) + in_vma = core.get_aval(leaf).vma + arg = (pvary(leaf, tuple(pbroadcast_names)) + if (pbroadcast_names := axes_ - in_vma) else leaf) + return psum_invariant_p.bind(arg, axes=axes) + +psum_invariant_p = core.Primitive('psum_invariant') + +def _psum_invariant_impl(arg, *, axes): + return _allreduce_impl(psum_invariant_p, lax.reduce_sum, arg, axes=axes, + axis_index_groups=None) +psum_invariant_p.def_impl(_psum_invariant_impl) + +def _psum_invariant_abstract_eval(name, aval, *, axes): + assert isinstance(axes, tuple) + _check_axis_names(axes, 'psum') + if not set(axes).intersection(aval.vma): + raise ValueError( + "psum is a variant->invariant collective. This means that the axis" + " names mentioned in `axes` passed to `psum` must be present in" + f" `jax.typeof(inp).vma`. Got axes={axes} and" + f" jax.typeof(inp).vma={aval.vma}") + + named_axes = tuple(axis for axis in axes if not isinstance(axis, int)) + pos_axes = tuple(axis for axis in axes if isinstance(axis, int)) + core.check_avals_context_mesh([aval], name) + check_unreduced_args([aval], name) + out_aval = core.ShapedArray( + lax._reduce_op_shape_rule(aval, axes=pos_axes), aval.dtype, + sharding=lax._reduce_op_sharding_rule(aval, axes=pos_axes), + vma=frozenset(a for a in aval.vma if a not in named_axes)) + return out_aval, {core.NamedAxisEffect(axis) for axis in named_axes} +psum_invariant_p.def_effectful_abstract_eval( + partial(_psum_invariant_abstract_eval, psum_invariant_p.name)) + +def _psum_invariant_lowering_rule(ctx, arg, *, axes): + return _allreduce_lowering(lax.add_p, lax.reduce_sum, ctx, arg, axes=axes, + axis_index_groups=None) +mlir.register_lowering(psum_invariant_p, _psum_invariant_lowering_rule) + +def _psum_invariant_batching_rule(axis_data, vals_in, dims_in, axes): + return _batched_reduction_collective( + psum_invariant_p, lambda v, axis_size: axis_size * v, + axis_data, vals_in, dims_in, axes, None) +batching.fancy_primitive_batchers[psum_invariant_p] = _psum_invariant_batching_rule +batching.skippable_batchers[psum_invariant_p] = partial(_names_in_param, 'axes') + +def _psum_invariant_transpose_rule(cts, arg, *, axes): + assert ad.is_undefined_primal(arg) + return (core.pvary(cts, axis_name=axes),) +ad.deflinear2(psum_invariant_p, _psum_invariant_transpose_rule) + +########################### pvary ################################## + +core.pvary_p.def_impl(lambda arg, *, axes: arg) +mlir.register_lowering(core.pvary_p, lambda ctx, x, *, axes: [x]) + +def _pvary_abstract_eval(aval, *, axes): + if not config._check_vma.value: + return aval + _check_axis_names(axes, 'pvary') + check_unreduced_args([aval], 'pvary') + assert isinstance(axes, tuple) + if set(axes).intersection(aval.vma): + raise ValueError( + "pvary is a invariant->variant collective. This means that the axis" + " names mentioned in `axes` passed to `pvary` must not be present in" + f" `jax.typeof(inp).vma`. Got axes={axes} and" + f" jax.typeof(inp)={aval}") + return aval.update(sharding=aval.sharding.update(mesh=get_abstract_mesh()), + vma=aval.vma.union(frozenset(axes))) +core.pvary_p.def_abstract_eval(_pvary_abstract_eval) + +def _pvary_transpose_rule(cts, arg, *, axes): + assert ad.is_undefined_primal(arg) + return (psum_invariant_p.bind(cts, axes=axes),) +ad.deflinear2(core.pvary_p, _pvary_transpose_rule) + +def _pvary_batcher(vals_in, dims_in, *, axes): + if any(type(axis) is int for axis in axes): + raise NotImplementedError + (x,), (d,) = vals_in, dims_in + y = core.pvary_p.bind(x, axes=axes) + return y, d +batching.primitive_batchers[core.pvary_p] = _pvary_batcher + +####################### all_gather_reduced ########################### + +# Varying -> Reduced collective +def all_gather_reduced(x, axis_name, *, axis: int = 0, tiled: bool = False): + if not isinstance(axis_name, tuple): + axis_name = (axis_name,) + if not axis_name: + return x + axis_size = _axis_size(axis_name, None) + def bind(leaf): + return all_gather_reduced_p.bind( + leaf, + all_gather_dimension=canonicalize_axis( + axis, np.ndim(leaf) if tiled else np.ndim(leaf) + 1), + axis_name=axis_name, axis_size=axis_size, tiled=tiled) + return tree_util.tree_map(bind, x) + +all_gather_reduced_p = core.Primitive('all_gather_reduced') + +def _all_gather_reduced_effectful_abstract_eval( + x_aval, *, all_gather_dimension, axis_name, axis_size, tiled +): + _check_axis_names(axis_name, 'all_gather_reduced') + if not x_aval.vma: + raise ValueError('all_gather_reduced only accepts inputs that are' + f' varying. Got {x_aval.str_short(True)}') + # If the intersection between x.vma and axis_name is empty, error + if not (x_aval.vma & set(axis_name)): + raise ValueError( + 'all_gather_reduced is a Varying -> Reduced collective. This means ' + f'that the {axis_name=} passed to `all_gather_reduced` must be present ' + f'in jax.typeof(x).vma={x_aval.vma}') + if x_aval.sharding.spec.reduced & set(axis_name): + raise ValueError( + "all_gather_reduced's input cannot be reduced across the axis_name" + f" provided. Got x={x_aval.str_short(True)} and {axis_name=}") + + new_shape = list(x_aval.shape) + if tiled: + new_shape[all_gather_dimension] *= axis_size + else: + new_shape.insert(all_gather_dimension, axis_size) + + x_aval_s = x_aval.sharding + new_reduced = x_aval_s.spec.reduced | frozenset(axis_name) + out_sharding = x_aval_s.update(spec=x_aval_s.spec.update(reduced=new_reduced)) + out_vma = frozenset(v for v in x_aval.vma if v not in axis_name) + return (x_aval.update(shape=new_shape, vma=out_vma, sharding=out_sharding), + {*map(core.NamedAxisEffect, axis_name)}) +all_gather_reduced_p.def_effectful_abstract_eval( + _all_gather_reduced_effectful_abstract_eval) + + +def _all_gather_reduced_impl(x, *, all_gather_dimension, axis_name, axis_size, + tiled): + raise NotImplementedError +all_gather_reduced_p.def_impl(_all_gather_reduced_impl) + + +def _all_gather_reduced_lowering( + ctx, x, *, all_gather_dimension, axis_name, axis_size, tiled, + platform=None): + return _all_gather_lowering( + ctx, x, all_gather_dimension=all_gather_dimension, axis_name=axis_name, + axis_index_groups=None, axis_size=axis_size, tiled=tiled, + platform=platform) + +mlir.register_lowering(all_gather_reduced_p, _all_gather_reduced_lowering) +for p in ("cuda", "rocm", "tpu"): + mlir.register_lowering(all_gather_reduced_p, + partial(_all_gather_reduced_lowering, platform=p), + platform=p) + +def _all_gather_reduced_transpose_rule( + cts, x, *, all_gather_dimension, axis_name, axis_size, tiled): + return (unreduced_psum_scatter(cts, axis_name=axis_name, + scatter_dimension=all_gather_dimension, + tiled=tiled),) +ad.deflinear2(all_gather_reduced_p, _all_gather_reduced_transpose_rule) + +def _all_gather_reduced_batched_collective( + axis_data, vals_in, dims_in, all_gather_dimension, axis_name, axis_size, + tiled): + raise NotImplementedError( + "Please file an issue at https://github.com/jax-ml/jax/issues") +batching.fancy_primitive_batchers[all_gather_reduced_p] = _all_gather_reduced_batched_collective +batching.skippable_batchers[all_gather_reduced_p] = partial(_names_in_param, 'axis_name') + +####################### unreduced_psum_scatter ########################### + +# Unreduced -> Varying collective +def unreduced_psum_scatter(x, axis_name, *, scatter_dimension=0, tiled=False): + if not isinstance(axis_name, tuple): + axis_name = (axis_name,) + if not axis_name: + return x + axis_size = _axis_size(axis_name, None) + def bind(leaf): + return unreduced_reduce_scatter_p.bind( + leaf, axis_name=axis_name, scatter_dimension=scatter_dimension, + axis_size=axis_size, tiled=tiled) + return tree_util.tree_map(bind, x) + +unreduced_reduce_scatter_p = core.Primitive('unreduced_reduce_scatter') + +def _unreduced_reduce_scatter_effectful_abstract_eval( + x_aval, *, axis_name, scatter_dimension, axis_size, tiled +): + _check_axis_names(axis_name, 'reduce_scatter') + if not x_aval.sharding.spec.unreduced: + raise ValueError('unreduced_psum_scatter only accepts inputs that are' + f' unreduced. Got {x_aval.str_short(True)}') + # If intersection between x.unreduced & axis_name is empty, error + if not (x_aval.sharding.spec.unreduced & frozenset(axis_name)): + raise ValueError( + "unreduced_psum_scatter is a Unreduced -> Varying collective. This" + f" means that the {axis_name=} passed to `unreduced_psum_scatter` must" + " be present in" + f" jax.typeof(x).sharding.spec.unreduced={x_aval.sharding.spec.unreduced}" + ) + if x_aval.vma & set(axis_name): + raise ValueError( + "unreduced_psum_scatter's input cannot be varying across the axis_name" + f" provided. Got x={x_aval.str_short(True)} and {axis_name=}") + + new_shape = list(x_aval.shape) + scatter_dim_input_size = x_aval.shape[scatter_dimension] + if tiled: + if scatter_dim_input_size % axis_size != 0: + raise ValueError(f"tiled reduce_scatter operand scatter dimension size " + f"{scatter_dim_input_size} must be divisible by " + f"shard_count {axis_size}") + new_shape[scatter_dimension] = scatter_dim_input_size // axis_size + else: + if scatter_dim_input_size != axis_size: + raise ValueError(f"reduce_scatter operand scatter dimension size " + f"{scatter_dim_input_size} must match shard count " + f"{axis_size}") + del new_shape[scatter_dimension] + + x_aval_s = x_aval.sharding + out_sharding = x_aval_s.update(spec=x_aval_s.spec.update( + unreduced=frozenset(i for i in x_aval_s.spec.unreduced if i not in axis_name))) + out_vma = x_aval.vma | set(axis_name) + return (x_aval.update(shape=new_shape, vma=out_vma, sharding=out_sharding), + {*map(core.NamedAxisEffect, axis_name)}) +unreduced_reduce_scatter_p.def_effectful_abstract_eval( + _unreduced_reduce_scatter_effectful_abstract_eval) + + +def _unreduced_reduce_scatter_impl( + x, *, axis_name, scatter_dimension, axis_size, tiled): + raise NotImplementedError +unreduced_reduce_scatter_p.def_impl(_unreduced_reduce_scatter_impl) + +def _unreduced_reduce_scatter_transpose_rule( + cts, x, *, axis_name, scatter_dimension, axis_size, tiled): + return (all_gather_reduced(cts, axis_name=axis_name, axis=scatter_dimension, + tiled=tiled),) +ad.deflinear2(unreduced_reduce_scatter_p, _unreduced_reduce_scatter_transpose_rule) + +def _unreduced_reduce_scatter_batcher( + axis_data, vals_in, dims_in, axis_name, scatter_dimension, axis_size, + tiled): + raise NotImplementedError( + "Please file an issue at https://github.com/jax-ml/jax/issues") +batching.fancy_primitive_batchers[unreduced_reduce_scatter_p] = _unreduced_reduce_scatter_batcher +batching.skippable_batchers[unreduced_reduce_scatter_p] = partial(_names_in_param, 'axis_name') + +def _unreduced_reduce_scatter_lowering( + prim, ctx, x, *, axis_name, scatter_dimension, axis_size, tiled): + return _reduce_scatter_lowering( + prim, ctx, x, axis_name=axis_name, scatter_dimension=scatter_dimension, + axis_size=axis_size, tiled=tiled, axis_index_groups=None) +mlir.register_lowering(unreduced_reduce_scatter_p, + partial(_unreduced_reduce_scatter_lowering, lax.add_p)) + +############################## unreduced_psum ########################### + +# Unreduced -> Invariant collective +def unreduced_psum(x, axis_name): + if not isinstance(axis_name, (tuple, list)): + axis_name = (axis_name,) + if not axis_name: + return x + return tree_util.tree_map( + lambda leaf: unreduced_psum_p.bind(leaf, axes=tuple(axis_name)), x) + +unreduced_psum_p = core.Primitive('unreduced_psum') + +def _unreduced_psum_abstract_eval(aval, *, axes): + _check_axis_names(axes, 'psum') + if not aval.sharding.spec.unreduced: + raise ValueError('unreduced_psum only accepts inputs that are' + f' unreduced. Got {aval.str_short(True)}') + # If intersection between x.unreduced & axis_name is empty, error + if not (aval.sharding.spec.unreduced & frozenset(axes)): + raise ValueError( + "unreduced_psum is a Unreduced -> Invariant collective. This" + f" means that the {axes=} passed to `unreduced_psum` must" + " be present in" + f" jax.typeof(x).sharding.spec.unreduced={aval.sharding.spec.unreduced}") + if aval.vma & set(axes): + raise ValueError( + "unreduced_psum's input cannot be varying across the " + f" axis_name provided. Got x={aval.str_short(True)} and {axes=}") + + if any(isinstance(a, int) for a in axes): + raise ValueError('unreduced_psum does not accept integer axis_name.' + f' Got axis_name={axes}') + + core.check_avals_context_mesh([aval], 'unreduced_psum') + a_s = aval.sharding + out_sharding = a_s.update(spec=a_s.spec.update( + unreduced=frozenset(u for u in a_s.spec.unreduced if u not in axes))) + out_aval = aval.update(sharding=out_sharding) + return out_aval, {core.NamedAxisEffect(axis) for axis in axes} +unreduced_psum_p.def_effectful_abstract_eval(_unreduced_psum_abstract_eval) + +def _unreduced_psum_lowering(ctx, arg, *, axes): + return _allreduce_lowering(lax.add_p, lax.reduce_sum, ctx, arg, + axes=axes, axis_index_groups=None) +mlir.register_lowering(unreduced_psum_p, _unreduced_psum_lowering) + +def _unreduced_psum_batcher(axis_data, vals_in, dims_in, axes): + raise NotImplementedError +batching.fancy_primitive_batchers[unreduced_psum_p] = _unreduced_psum_batcher +batching.skippable_batchers[unreduced_psum_p] = partial(_names_in_param, 'axes') + +def _unreduced_psum_transpose_rule(cts, arg, *, axes): + assert ad.is_undefined_primal(arg) + return (preduced(cts, axis_name=axes),) +ad.deflinear2(unreduced_psum_p, _unreduced_psum_transpose_rule) + +############################## preduced ################################# + +# Invariant -> Reduced no-op cast. It's the transpose of unreduced_psum. +def preduced(x, axis_name): + axes = (axis_name,) if not isinstance(axis_name, tuple) else axis_name + if not axes: + return x + cur_mesh = get_abstract_mesh() + new_axes = axes if cur_mesh.empty else core.order_wrt_mesh(cur_mesh, axes) + assert set(new_axes) == set(axes) + del axes + return tree_util.tree_map(lambda l: preduced_p.bind(l, axes=new_axes), x) + +preduced_p = core.Primitive('preduced') +preduced_p.def_impl(lambda arg, *, axes: arg) +mlir.register_lowering(preduced_p, lambda ctx, x, *, axes: [x]) + +def _preduced_abstract_eval(aval, *, axes): + assert isinstance(axes, tuple) + _check_axis_names(axes, 'preduced') + + if aval.vma.intersection(set(axes)): + raise ValueError( + "preduced is a Invariant->Reduced collective. This means that the" + " axis names mentioned in `axes` passed to `preduced` must not be" + f" present in `jax.typeof(inp).vma`. Got axes={axes} and" + f" jax.typeof(inp).vma={aval.vma}") + if aval.sharding.spec.reduced & set(axes): + raise ValueError( + "preduced input cannot be reduced across the axis_name" + f" provided. Got x={aval.str_short(True)} and axis_name={axes}") + + a_s = aval.sharding + new_reduced = a_s.spec.reduced | frozenset(axes) + out_sharding = a_s.update(mesh=get_abstract_mesh(), + spec=a_s.spec.update(reduced=new_reduced)) + out_aval = aval.update(sharding=out_sharding) + return out_aval +preduced_p.def_abstract_eval(_preduced_abstract_eval) + +def _preduced_transpose_rule(cts, arg, *, axes): + assert ad.is_undefined_primal(arg) + return (unreduced_psum(cts, axis_name=axes),) +ad.deflinear2(preduced_p, _preduced_transpose_rule) + +def _preduced_batcher(vals_in, dims_in, *, axes): + raise NotImplementedError +batching.primitive_batchers[preduced_p] = _preduced_batcher + +######################## vary_unreduced_cast ####################### + +# Varying -> Unreduced no-op cast +def vary_unreduced_cast(x, axis_name): + axes = (axis_name,) if not isinstance(axis_name, tuple) else axis_name + if not axis_name: + return x + return tree_util.tree_map( + lambda leaf: vary_unreduced_cast_p.bind(leaf, axes=axes), x) + +vary_unreduced_cast_p = core.Primitive('vary_unreduced_cast_p') +vary_unreduced_cast_p.def_impl(lambda arg, *, axes: arg) +mlir.register_lowering(vary_unreduced_cast_p, lambda ctx, x, *, axes: [x]) + +def _vary_unreduced_cast_abstract_eval(aval, *, axes): + assert isinstance(axes, tuple) + _check_axis_names(axes, 'vary_unreduced_cast') + check_unreduced_args([aval], 'vary_unreduced_cast') + if not aval.vma: + raise ValueError('vary_unreduced_cast only accepts inputs that are' + f' varying. Got {aval.str_short(True)}') + # If the intersection between aval.vma and axes is empty, error + if not (aval.vma & set(axes)): + raise ValueError( + "vary_unreduced_cast is a Varying->Unreduced collective. This" + " means that the axis names mentioned in `axes` passed to" + " `vary_unreduced_cast` must be present in" + f" `jax.typeof(x).vma`. Got axes={axes} and" + f" jax.typeof(x).vma={aval.vma}") + if aval.sharding.spec.unreduced & set(axes): + raise ValueError( + "vary_unreduced_cast input cannot be unreduced across the axis_name" + f" provided. Got x={aval.str_short(True)} and axis_name={axes}") + + aval_s = aval.sharding + new_unreduced = aval_s.spec.unreduced | frozenset(axes) + out_sharding = aval_s.update(mesh=get_abstract_mesh(), + spec=aval_s.spec.update(unreduced=new_unreduced)) + out_vma = frozenset(i for i in aval.vma if i not in axes) + return aval.update(sharding=out_sharding, vma=out_vma) +vary_unreduced_cast_p.def_abstract_eval(_vary_unreduced_cast_abstract_eval) + +def _vary_unreduced_cast_transpose_rule(cts, x, *, axes): + assert ad.is_undefined_primal(x) + return (core.reduced_vary_cast(cts, axis_name=axes),) +ad.deflinear2(vary_unreduced_cast_p, _vary_unreduced_cast_transpose_rule) + +def _vary_unreduced_cast_batcher(vals_in, dims_in, *, axes): + raise NotImplementedError +batching.primitive_batchers[vary_unreduced_cast_p] = _vary_unreduced_cast_batcher + +####################### reduced_vary_cast ############################# + +# Reduced -> Varying no-op cast +# Traceable defined in core.py to avoid circular imports +core.reduced_vary_cast_p.def_impl(lambda arg, *, axes: arg) +mlir.register_lowering(core.reduced_vary_cast_p, lambda ctx, x, *, axes: [x]) + +def _reduced_vary_cast_abstract_eval(aval, *, axes): + assert isinstance(axes, tuple) + _check_axis_names(axes, 'reduced_vary_cast') + if not aval.sharding.spec.reduced: + raise ValueError('reduced_vary_cast only accepts inputs that are' + f' reduced. Got {aval.str_short(True)}') + # If the intersection between aval.spec.reduced and axes is empty, error + if not (aval.sharding.spec.reduced & set(axes)): + raise ValueError( + "reduced_vary_cast is a Reduced->Varying collective. This" + " means that the axis names mentioned in `axes` passed to" + " `reduced_vary_cast` must be present in" + f" `jax.typeof(x).sharding.spec.reduced`. Got axes={axes} and" + f" jax.typeof(x).sharding.spec.reduced={aval.sharding.spec.reduced}") + if aval.vma & set(axes): + raise ValueError( + "reduced_vary_cast input cannot be varying across the axis_name" + f" provided. Got x={aval.str_short(True)} and axis_name={axes}") + + aval_s = aval.sharding + new_reduced = frozenset(i for i in aval_s.spec.reduced if i not in axes) + out_sharding = aval_s.update(mesh=get_abstract_mesh(), + spec=aval_s.spec.update(reduced=new_reduced)) + out_vma = aval.vma | frozenset(axes) + return aval.update(sharding=out_sharding, vma=out_vma) +core.reduced_vary_cast_p.def_abstract_eval(_reduced_vary_cast_abstract_eval) + +def _reduced_vary_cast_transpose_rule(cts, x, *, axes): + assert ad.is_undefined_primal(x) + return (vary_unreduced_cast(cts, axis_name=axes),) +ad.deflinear2(core.reduced_vary_cast_p, _reduced_vary_cast_transpose_rule) + +def _reduced_vary_cast_batcher(vals_in, dims_in, *, axes): + raise NotImplementedError +batching.primitive_batchers[core.reduced_vary_cast_p] = _reduced_vary_cast_batcher + +################################## pcast ############################# + +def _get_from(aval, axes: tuple[AxisName, ...], name) -> str: + vma = aval.vma + unreduced = aval.sharding.spec.unreduced + reduced = aval.sharding.spec.reduced + vma_ur = vma | unreduced | reduced + assert not (vma & unreduced & reduced) # intersection is empty + + out = set() + for a in axes: + if a in vma: + out.add('varying') + elif a in unreduced: + out.add('unreduced') + elif a in reduced: + out.add('reduced') + else: + assert a not in vma_ur + out.add('invarying') + + if len(out) > 1: + raise ValueError( + f"{name} can only accept axis_name which corresponds to one of" + " varying, unreduced, reduced or invarying state of the input. Got" + f" input type: {aval}, axes: {axes} and input state: {out}") + o, = out + return o + + +_pcast_funcs = { + ('invarying', 'varying'): core.pvary, + ('invarying', 'reduced'): preduced, + ('varying', 'unreduced'): vary_unreduced_cast, + ('reduced', 'varying'): core.reduced_vary_cast, +} + +_allowed_pcast_to = {'unreduced', 'reduced', 'varying'} + +def pcast(x, axis_name, *, to: str): + if isinstance(axis_name, (set, frozenset)): + raise TypeError(f"{axis_name=} must be a tuple or a str. Got {axis_name}") + axes = (axis_name,) if not isinstance(axis_name, tuple) else axis_name + if not axis_name: + return x + + if to not in _allowed_pcast_to: + raise ValueError( + "Got unexpected `to` value. Allowed `to` values are:" + f" {_allowed_pcast_to}") + + def bind(leaf): + from_ = _get_from(core.typeof(leaf), axes, 'jax.lax.pcast') + func = _pcast_funcs.get((from_, to), None) + if func is None: + raise ValueError(f"Unsupported pcast from={from_}, {to=}") + return func(leaf, axes) + return tree_util.tree_map(bind, x) diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index c26de99c7374..bf72ac81e297 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -26,11 +26,11 @@ from jax._src import ad_util from jax._src import api -from jax._src import config from jax._src import core from jax._src import dispatch from jax._src import dtypes from jax._src import source_info_util +from jax._src.traceback_util import api_boundary from jax._src import util from jax._src import mesh as mesh_lib from jax._src.interpreters import ad @@ -38,27 +38,32 @@ from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.lax import lax +from jax._src.lax import utils as lax_utils from jax._src.lax.utils import ( _argnum_weak_type, - _input_dtype, + input_dtype, standard_primitive, ) from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo +from jax._src.named_sharding import NamedSharding +from jax._src.partition_spec import PartitionSpec as P from jax._src.typing import Array, ArrayLike, Shape +from jax._src.state.indexing import ds from jax._src.util import safe_map, safe_zip map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip -_dtype = partial(dtypes.dtype, canonicalize=True) +_dtype = dtypes.dtype +_slice = slice def slice(operand: ArrayLike, start_indices: Sequence[int], limit_indices: Sequence[int], strides: Sequence[int] | None = None) -> Array: """Wraps XLA's `Slice - `_ + `_ operator. Args: @@ -118,7 +123,7 @@ def dynamic_slice( allow_negative_indices: bool | Sequence[bool] = True ) -> Array: """Wraps XLA's `DynamicSlice - `_ + `_ operator. Args: @@ -168,13 +173,9 @@ def dynamic_slice( """ start_indices = _dynamic_slice_indices( operand, start_indices, allow_negative_indices) - if config.dynamic_shapes.value: - dynamic_sizes, static_sizes = lax._extract_tracers_dyn_shape(slice_sizes) - else: - dynamic_sizes = [] - static_sizes = core.canonicalize_shape(slice_sizes) # type: ignore - return dynamic_slice_p.bind(operand, *start_indices, *dynamic_sizes, - slice_sizes=tuple(static_sizes)) + sizes = core.canonicalize_shape(slice_sizes) # type: ignore + operand, *start_indices = core.standard_insert_pvary(operand, *start_indices) + return dynamic_slice_p.bind(operand, *start_indices, slice_sizes=tuple(sizes)) def dynamic_update_slice( @@ -184,7 +185,7 @@ def dynamic_update_slice( allow_negative_indices: bool | Sequence[bool] = True ) -> Array: """Wraps XLA's `DynamicUpdateSlice - `_ + `_ operator. Args: @@ -234,13 +235,15 @@ def dynamic_update_slice( """ start_indices = _dynamic_slice_indices( operand, start_indices, allow_negative_indices) + operand, update, *start_indices = core.standard_insert_pvary( + operand, update, *start_indices) return dynamic_update_slice_p.bind(operand, update, *start_indices) class GatherDimensionNumbers(NamedTuple): """ Describes the dimension number arguments to an `XLA's Gather operator - `_. See the XLA + `_. See the XLA documentation for more details of what the dimension numbers mean. Args: @@ -303,7 +306,7 @@ class GatherScatterMode(enum.Enum): ONE_HOT = enum.auto() @staticmethod - def from_any(s: str | GatherScatterMode | None): + def from_any(s: str | GatherScatterMode | None) -> GatherScatterMode: if isinstance(s, GatherScatterMode): return s if s == "clip": @@ -329,7 +332,7 @@ def gather(operand: ArrayLike, start_indices: ArrayLike, """Gather operator. Wraps `XLA's Gather operator - `_. + `_. :func:`gather` is a low-level operator with complicated semantics, and most JAX users will never need to call it directly. Instead, you should prefer using @@ -412,10 +415,13 @@ def gather(operand: ArrayLike, start_indices: ArrayLike, fill_value = dtypes.iinfo(dtype).max elif dtype == dtypes.bool_: fill_value = True + elif dtypes.issubdtype(dtype, dtypes.prng_key): + fill_value = np.iinfo('uint32').max else: raise ValueError(f"Unsupported dtype for gather fill_value {dtype}") else: fill_value = None + operand, start_indices = core.standard_insert_pvary(operand, start_indices) return gather_p.bind( operand, start_indices, dimension_numbers=dimension_numbers, slice_sizes=core.canonicalize_shape(slice_sizes), @@ -428,7 +434,7 @@ def gather(operand: ArrayLike, start_indices: ArrayLike, class ScatterDimensionNumbers(NamedTuple): """ Describes the dimension number arguments to an `XLA's Scatter operator - `_. See the XLA + `_. See the XLA documentation for more details of what the dimension numbers mean. Args: @@ -464,6 +470,7 @@ class ScatterDimensionNumbers(NamedTuple): operand_batching_dims: Sequence[int] = () scatter_indices_batching_dims: Sequence[int] = () +@partial(api_boundary, repro_api_name="lax.scatter_add") def scatter_add( operand: ArrayLike, scatter_indices: ArrayLike, updates: ArrayLike, dimension_numbers: ScatterDimensionNumbers, *, @@ -472,7 +479,7 @@ def scatter_add( """Scatter-add operator. Wraps `XLA's Scatter operator - `_, where + `_, where addition is used to combine updates and values from `operand`. The semantics of scatter are complicated, and its API might change in the @@ -485,33 +492,71 @@ def scatter_add( scatter_indices: an array that gives the indices in `operand` to which each update in `updates` should be applied. updates: the updates that should be scattered onto `operand`. - dimension_numbers: a `lax.ScatterDimensionNumbers` object that describes - how dimensions of `operand`, `scatter_indices`, `updates` and the output + dimension_numbers: a `lax.ScatterDimensionNumbers` object that describes how + dimensions of `operand`, `scatter_indices`, `updates` and the output relate. indices_are_sorted: whether `scatter_indices` is known to be sorted. If true, may improve performance on some backends. unique_indices: whether the elements to be updated in ``operand`` are - guaranteed to not overlap with each other. If true, may improve performance on - some backends. JAX does not check this promise: if the updated elements - overlap when ``unique_indices`` is ``True`` the behavior is undefined. + guaranteed to not overlap with each other. If true, may improve + performance on some backends. JAX does not check this promise: if the + updated elements overlap when ``unique_indices`` is ``True`` the behavior + is undefined. mode: how to handle indices that are out of bounds: when set to 'clip', - indices are clamped so that the slice is within bounds, and when - set to 'fill' or 'drop' out-of-bounds updates are dropped. The behavior - for out-of-bounds indices when set to 'promise_in_bounds' is + indices are clamped so that the slice is within bounds, and when set to + 'fill' or 'drop' out-of-bounds updates are dropped. The behavior for + out-of-bounds indices when set to 'promise_in_bounds' is implementation-defined. Returns: An array containing the sum of `operand` and the scattered updates. + + Examples: + As mentioned above, you should basically never use :func:`scatter_add` + directly, and instead perform scatter-style operations using NumPy-style + indexing expressions via :attr:`jax.numpy.ndarray.at`. + + Here is and example of updating entries in an array using + :attr:`jax.numpy.ndarray.at`, which lowers to an XLA Scatter operation: + + >>> x = jnp.ones(5) + >>> indices = jnp.array([1, 2, 4]) + >>> values = jnp.array([2.0, 3.0, 4.0]) + + >>> x.at[indices].add(values) + Array([1., 3., 4., 1., 5.], dtype=float32) + + This syntax also supports several of the optional arguments to + :func:`scatter_add`, for example: + + >>> x.at[indices].add(values, indices_are_sorted=True, + ... mode='promise_in_bounds') + Array([1., 3., 4., 1., 5.], dtype=float32) + + By comparison, here is the equivalent function call using + :func:`scatter_add` directly, which is not something typical users should + ever need to do: + + >>> lax.scatter_add(x, indices[:, None], values, + ... dimension_numbers=lax.ScatterDimensionNumbers( + ... update_window_dims=(), + ... inserted_window_dims=(0,), + ... scatter_dims_to_operand_dims=(0,)), + ... indices_are_sorted=True, + ... mode=lax.GatherScatterMode.PROMISE_IN_BOUNDS) + Array([1., 3., 4., 1., 5.], dtype=float32) """ jaxpr, consts = lax._reduction_jaxpr(lax.add, core.get_aval(lax._const(operand, 0))) + operand, scatter_indices, updates = core.standard_insert_pvary( + operand, scatter_indices, updates) return scatter_add_p.bind( operand, scatter_indices, updates, update_jaxpr=jaxpr, update_consts=consts, dimension_numbers=dimension_numbers, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=GatherScatterMode.from_any(mode)) - +@partial(api_boundary, repro_api_name="lax.scatter_sub") def scatter_sub( operand: ArrayLike, scatter_indices: ArrayLike, @@ -525,7 +570,7 @@ def scatter_sub( """Scatter-sub operator. Wraps `XLA's Scatter operator - `_, where + `_, where subtraction is used to combine updates and values from `operand`. The semantics of scatter are complicated, and its API might change in the @@ -554,11 +599,14 @@ def scatter_sub( implementation-defined. Returns: - An array containing the sum of `operand` and the scattered updates. + An array containing the difference between `operand` and the scattered + updates. """ jaxpr, consts = lax._reduction_jaxpr( lax.sub, core.get_aval(lax._const(operand, 0)) ) + operand, scatter_indices, updates = core.standard_insert_pvary( + operand, scatter_indices, updates) return scatter_sub_p.bind( operand, scatter_indices, @@ -572,6 +620,7 @@ def scatter_sub( ) +@partial(api_boundary, repro_api_name="lax.scatter_mul") def scatter_mul( operand: ArrayLike, scatter_indices: ArrayLike, updates: ArrayLike, dimension_numbers: ScatterDimensionNumbers, *, @@ -580,7 +629,7 @@ def scatter_mul( """Scatter-multiply operator. Wraps `XLA's Scatter operator - `_, where + `_, where multiplication is used to combine updates and values from `operand`. The semantics of scatter are complicated, and its API might change in the @@ -609,16 +658,19 @@ def scatter_mul( implementation-defined. Returns: - An array containing the sum of `operand` and the scattered updates. + An array containing the product of `operand` and the scattered updates. """ jaxpr, consts = lax._reduction_jaxpr(lax.mul, core.get_aval(lax._const(operand, 1))) + operand, scatter_indices, updates = core.standard_insert_pvary( + operand, scatter_indices, updates) return scatter_mul_p.bind( operand, scatter_indices, updates, update_jaxpr=jaxpr, update_consts=consts, dimension_numbers=dimension_numbers, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=GatherScatterMode.from_any(mode)) +@partial(api_boundary, repro_api_name="lax.scatter_min") def scatter_min( operand: ArrayLike, scatter_indices: ArrayLike, updates: ArrayLike, dimension_numbers: ScatterDimensionNumbers, *, @@ -627,7 +679,7 @@ def scatter_min( """Scatter-min operator. Wraps `XLA's Scatter operator - `_, where + `_, where the `min` function is used to combine updates and values from `operand`. The semantics of scatter are complicated, and its API might change in the @@ -656,16 +708,19 @@ def scatter_min( implementation-defined. Returns: - An array containing the sum of `operand` and the scattered updates. + An array containing the min of `operand` and the scattered updates. """ jaxpr, consts = lax._reduction_jaxpr(lax.min, core.get_aval(lax._const(operand, 0))) + operand, scatter_indices, updates = core.standard_insert_pvary( + operand, scatter_indices, updates) return scatter_min_p.bind( operand, scatter_indices, updates, update_jaxpr=jaxpr, update_consts=consts, dimension_numbers=dimension_numbers, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=GatherScatterMode.from_any(mode)) +@partial(api_boundary, repro_api_name="lax.scatter_max") def scatter_max( operand: ArrayLike, scatter_indices: ArrayLike, updates: ArrayLike, dimension_numbers: ScatterDimensionNumbers, *, @@ -674,7 +729,7 @@ def scatter_max( """Scatter-max operator. Wraps `XLA's Scatter operator - `_, where + `_, where the `max` function is used to combine updates and values from `operand`. The semantics of scatter are complicated, and its API might change in the @@ -703,10 +758,12 @@ def scatter_max( implementation-defined. Returns: - An array containing the sum of `operand` and the scattered updates. + An array containing the max of `operand` and the scattered updates. """ jaxpr, consts = lax._reduction_jaxpr(lax.max, core.get_aval(lax._const(operand, 0))) + operand, scatter_indices, updates = core.standard_insert_pvary( + operand, scatter_indices, updates) return scatter_max_p.bind( operand, scatter_indices, updates, update_jaxpr=jaxpr, update_consts=consts, dimension_numbers=dimension_numbers, @@ -726,7 +783,7 @@ def scatter_apply( """Scatter-apply operator. Wraps `XLA's Scatter operator - `_, where values + `_, where values from ``operand`` are replaced with ``func(operand)``, with duplicate indices resulting in multiple applications of ``func``. @@ -771,6 +828,8 @@ def scatter_apply( pass jaxpr, consts = lax._reduction_jaxpr(_apply, core.get_aval(lax._zero(operand))) # TODO: implement this via its own primitive so we can define appropriate autodiff rules. + operand, scatter_indices, unused = core.standard_insert_pvary( + operand, scatter_indices, unused) return scatter_p.bind( operand, scatter_indices, unused, update_jaxpr=jaxpr, update_consts=consts, dimension_numbers=dimension_numbers, @@ -788,7 +847,7 @@ def scatter( """Scatter-update operator. Wraps `XLA's Scatter operator - `_, where updates + `_, where updates replace values from `operand`. If multiple updates are performed to the same index of operand, they may be @@ -819,7 +878,7 @@ def scatter( implementation-defined. Returns: - An array containing the sum of `operand` and the scattered updates. + An array containing the values of `operand` and the scattered updates. Examples: As mentioned above, you should basically never use :func:`scatter` directly, @@ -829,18 +888,18 @@ def scatter( Here is and example of updating entries in an array using :attr:`jax.numpy.ndarray.at`, which lowers to an XLA Scatter operation: - >>> x = jnp.zeros(5) + >>> x = jnp.ones(5) >>> indices = jnp.array([1, 2, 4]) >>> values = jnp.array([2.0, 3.0, 4.0]) >>> x.at[indices].set(values) - Array([0., 2., 3., 0., 4.], dtype=float32) + Array([1., 2., 3., 1., 4.], dtype=float32) This syntax also supports several of the optional arguments to :func:`scatter`, for example: >>> x.at[indices].set(values, indices_are_sorted=True, mode='promise_in_bounds') - Array([0., 2., 3., 0., 4.], dtype=float32) + Array([1., 2., 3., 1., 4.], dtype=float32) By comparison, here is the equivalent function call using :func:`scatter` directly, which is not something typical users should ever need to do: @@ -852,8 +911,10 @@ def scatter( ... scatter_dims_to_operand_dims=(0,)), ... indices_are_sorted=True, ... mode=lax.GatherScatterMode.PROMISE_IN_BOUNDS) - Array([0., 2., 3., 0., 4.], dtype=float32) + Array([1., 2., 3., 1., 4.], dtype=float32) """ + operand, scatter_indices, updates = core.standard_insert_pvary( + operand, scatter_indices, updates) return scatter_p.bind( operand, scatter_indices, updates, update_jaxpr=None, update_consts=(), dimension_numbers=dimension_numbers, @@ -1081,7 +1142,7 @@ def dynamic_slice_in_dim(operand: Array | np.ndarray, def dynamic_index_in_dim(operand: Array | np.ndarray, - index: int | Array, + index: ArrayLike, axis: int = 0, keepdims: bool = True, *, allow_negative_indices: bool = True) -> Array: @@ -1301,11 +1362,10 @@ def _slice_shape_rule(operand, *, start_indices, limit_indices, strides): msg = ("slice start_indices must be greater than or equal to zero, " "got start_indices of {}.") raise TypeError(msg.format(start_indices)) - if not config.dynamic_shapes.value: - if not all(map(operator.ge, limit_indices, start_indices)): - msg = ("slice limit_indices must be greater than or equal to start_indices," - " got start_indices {} and limit_indices {}.") - raise TypeError(msg.format(start_indices, limit_indices)) + if not all(map(operator.ge, limit_indices, start_indices)): + msg = ("slice limit_indices must be greater than or equal to start_indices," + " got start_indices {} and limit_indices {}.") + raise TypeError(msg.format(start_indices, limit_indices)) diff = tuple(map(operator.sub, limit_indices, start_indices)) if strides is None or tuple(strides) == (1,) * len(operand.shape): return diff @@ -1333,10 +1393,11 @@ def _get_sharding_for_varying_out_shape(out_shape, operand, name): operand.shape, out_shape, operand.sharding.spec): if (op_sh != out_sh and op_spec is not None and out_sh % _get_sub_spec_size(mesh, op_spec) != 0): - raise NotImplementedError( - f"{name} on sharded dims where out dim ({out_sh}) is not divisble by" + raise core.ShardingTypeError( + f"{name} on sharded dims where out dim ({out_sh}) is not divisible by" f" mesh axes ({_get_sub_spec_size(mesh, op_spec)}) with spec" - f" ({op_spec}) is not implemented.") + f" ({op_spec}) is not implemented." + ) # TODO(yashkatariya): Returning operand.sharding as is may or may not move # data. So think about how to avoid it which might include creating a new # mesh? For example: @@ -1355,7 +1416,6 @@ def _slice_sharding_rule(operand, *, start_indices, limit_indices, strides): return _get_sharding_for_varying_out_shape(out_shape, operand, 'slicing') def _slice_transpose_rule(t, operand, *, start_indices, limit_indices, strides): - assert ad.is_undefined_primal(operand) operand_shape = operand.aval.shape if strides is None or np.all(np.equal(strides, 1)): pads = zip(start_indices, np.subtract(operand_shape, limit_indices), @@ -1371,6 +1431,25 @@ def _slice_transpose_rule(t, operand, *, start_indices, limit_indices, strides): assert result.shape == operand_shape, f"{result.shape=} {operand_shape=}" return [result] +def _slice_transpose_fancy(out_ct, operand, *, start_indices, limit_indices, strides): + assert isinstance(operand, ad.GradAccum) + if type(out_ct) is ad_util.Zero: return + if isinstance(operand, ad.RefAccum): + slices = map(_slice, start_indices, limit_indices, strides) + operand.ref.addupdate(out_ct, tuple(slices)) + else: + if strides is None or np.all(np.equal(strides, 1)): + pads = zip(start_indices, np.subtract(operand.aval.shape, limit_indices), + (0,) * len(start_indices)) + else: + real_limits = np.add( + start_indices, + np.where(np.array(out_ct.shape) == 0, 0, + np.add(1, np.multiply(np.subtract(out_ct.shape, 1), strides)))) + pads = zip(start_indices, np.subtract(operand.aval.shape, real_limits), + np.subtract(strides, 1)) + operand.accum(lax.pad(out_ct, lax._const(out_ct, 0), pads)) + def _slice_batching_rule(batched_args, batch_dims, *, start_indices, limit_indices, strides): @@ -1392,13 +1471,12 @@ def _slice_batching_rule(batched_args, batch_dims, *, start_indices, out = slice(operand, new_start_indices, new_limit_indices, new_strides) return out, bdim -slice_p = standard_primitive(_slice_shape_rule, _input_dtype, 'slice', - sharding_rule=_slice_sharding_rule) +slice_p = standard_primitive(_slice_shape_rule, input_dtype, 'slice', + sharding_rule=_slice_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'slice')) ad.deflinear2(slice_p, _slice_transpose_rule) +ad.fancy_transposes[slice_p] = _slice_transpose_fancy batching.primitive_batchers[slice_p] = _slice_batching_rule -# TODO(mvoz): A better slice rule for ragged prop, enforcing boundaries -# or supporting nested jumbles. NYI. -batching.ragged_prop_rules[slice_p] = batching.ragged_mask_no_op_rule # Override the standard impl to defer to dynamic_slice whenever possible. # This lets us reuse the same program for many applications of slicing for as @@ -1425,34 +1503,29 @@ def _slice_lower(ctx, x, *, start_indices, limit_indices, strides): mlir.register_lowering(slice_p, _slice_lower) -def _dynamic_slice_shape_rule(operand, *starts_and_dyn_sizes, slice_sizes): - start_indices, dyn = util.split_list(starts_and_dyn_sizes, [operand.ndim]) - if operand.ndim != len(start_indices): - msg = ("dynamic_slice start_indices must have length equal to the number " - "of dimensions of the operand, got indices {} for operand shape {}.") - raise TypeError(msg.format(start_indices, operand.shape)) - if len(start_indices) != len(slice_sizes): - msg = ("dynamic_slice slice_sizes must have the same length as " - "start_indices, got start_indices length {} and slice_sizes {}.") - raise TypeError(msg.format(len(start_indices), slice_sizes)) - if not dyn and not all(map(operator.ge, operand.shape, slice_sizes)): +def _dynamic_slice_shape_rule(operand, *start_indices, slice_sizes): + if not all(map(operator.ge, operand.shape, slice_sizes)): msg = ("slice slice_sizes must be less than or equal to operand shape, " "got slice_sizes {} for operand shape {}.") raise TypeError(msg.format(slice_sizes, operand.shape)) - if not dyn and not all(ssz >= 0 for ssz in slice_sizes): + if not all(ssz >= 0 for ssz in slice_sizes): msg = ("slice slice_sizes must be greater than or equal to zero, " "got slice_sizes of {}.") raise TypeError(msg.format(slice_sizes)) if any(idx.ndim != 0 for idx in start_indices): raise TypeError("start_indices arguments to dynamic_slice must be scalars, " f" got indices {start_indices}") - return tuple(lax._merge_dyn_shape(slice_sizes, dyn)) + return tuple(slice_sizes) def _dynamic_slice_sharding_rule(operand, *starts_and_dyn_sizes, slice_sizes): out_shape = _dynamic_slice_shape_rule( operand, *starts_and_dyn_sizes, slice_sizes=slice_sizes) return _get_sharding_for_varying_out_shape(out_shape, operand, 'dynamic_slice') +def _dynamic_slice_reduced_rule(out_s, operand, *starts_and_dyn_sizes, + slice_sizes): + return out_s.update(spec=out_s.spec.update( + reduced=operand.sharding.spec.reduced)) def _dynamic_slice_dtype_rule(operand, *starts_and_dyn_sizes, slice_sizes): start_indices, dyn = util.split_list(starts_and_dyn_sizes, [operand.ndim]) @@ -1472,27 +1545,43 @@ def _dynamic_slice_jvp(primals, tangents, *, slice_sizes): def _dynamic_slice_transpose_rule(t, operand, *start_indices, slice_sizes): assert ad.is_undefined_primal(operand) assert all(not ad.is_undefined_primal(s) for s in start_indices) - operand_shape, operand_dtype = operand.aval.shape, operand.aval.dtype if type(t) is ad_util.Zero: return [ad_util.Zero(operand.aval)] + [None] * len(start_indices) else: - zeros = lax.full(operand_shape, 0, operand_dtype) + zeros = lax.full(operand.aval.shape, 0, operand.aval.dtype, + sharding=operand.aval.sharding) + zeros = core.pvary(zeros, tuple(operand.aval.vma)) return ([dynamic_update_slice_p.bind(zeros, t, *start_indices)] + [None] * len(start_indices)) +def _dynamic_slice_transpose_fancy(out_ct, operand, *start_indices, slice_sizes): + assert isinstance(operand, ad.GradAccum) + assert all(not isinstance(s, ad.GradAccum) for s in start_indices) + if type(out_ct) is ad_util.Zero: return + if isinstance(operand, ad.RefAccum): + operand.ref.addupdate(out_ct, tuple(map(ds, start_indices, slice_sizes))) + else: + zeros = lax.full(operand.aval.shape, 0, operand.aval.dtype, + sharding=operand.aval.sharding) + zeros = core.pvary(zeros, tuple(operand.aval.vma)) + operand.accum(dynamic_update_slice_p.bind(zeros, out_ct, *start_indices)) + def _batch_dynamic_slice_indices(indices, bdims): if len(indices) == 0: return np.array([], 'int32'), None empty_marker = object() size = next((x.shape[i] for x, i in zip(indices, bdims) if i is not None), empty_marker) + out = next(((core.typeof(x).sharding.mesh, core.typeof(x).sharding.spec[i]) + for x, i in zip(indices, bdims) if i is not None), None) if size is empty_marker: return lax.concatenate([lax.broadcast(i, (1,)) for i in indices], 0), None + out_s = None if out is None else NamedSharding(out[0], P(out[1], None)) indices = lax.concatenate( - [lax.broadcast_in_dim(x, (size, 1), - broadcast_dimensions=((0,) if i is not None else ())) - for x, i in zip(indices, bdims)], - dimension=1) + [lax.broadcast_in_dim( + x, (size, 1), broadcast_dimensions=((0,) if i is not None else ()), + out_sharding=out_s) + for x, i in zip(indices, bdims)], dimension=1) return indices, 0 def _dynamic_slice_batching_rule(batched_args, batch_dims, *, slice_sizes): @@ -1518,73 +1607,40 @@ def _dynamic_slice_batching_rule(batched_args, batch_dims, *, slice_sizes): [operand, index, *dyn_slice_sizes], [operand_bd, index_bdim, *dyn_slice_size_bds], dimension_numbers=dnums, slice_sizes=slice_sizes, unique_indices=True, indices_are_sorted=True, - mode=GatherScatterMode.PROMISE_IN_BOUNDS, fill_value=None) - -def _dynamic_slice_staging_rule(trace, x, *starts_and_dyn_sizes, slice_sizes): - start_indices, dyn = util.split_list(starts_and_dyn_sizes, [x.ndim]) - if not dyn: - return trace.default_process_primitive(dynamic_slice_p, (x, *start_indices), - dict(slice_sizes=slice_sizes)) - shape = lax._merge_dyn_shape(slice_sizes, dyn) - aval = core.DShapedArray(shape, x.dtype, False) - return lax._dyn_shape_staging_rule(trace, dynamic_slice_p, aval, x, - *starts_and_dyn_sizes, - slice_sizes=slice_sizes) - -def _dynamic_slice_typecheck_rule(_, x, *starts_and_dyn_sizes, slice_sizes): - start_indices, dyn = util.split_list(starts_and_dyn_sizes, [x.aval.ndim]) - if not dyn: - out_aval, effects = dynamic_slice_p.abstract_eval( - x.aval, *(d.aval for d in start_indices), slice_sizes=slice_sizes) - return [out_aval], effects - else: - # TODO(mattjj): perform more checks - out_shape = lax._merge_dyn_shape(slice_sizes, dyn) - out_shape = [d.val if type(d) is core.Literal else d for d in out_shape] - out_aval = core.DShapedArray(tuple(out_shape), x.aval.dtype, - x.aval.weak_type) - return [out_aval], core.no_effects - -def _dynamic_slice_padding_rule(in_avals, out_avals, x, *starts_and_dyn, + mode=GatherScatterMode.CLIP, fill_value=None) + +def _dynamic_slice_staging_rule(trace, source_info, x, *start_indices, slice_sizes): - x_aval, start_indices_avals, dyn_avals = util.split_list(in_avals, [1, x.ndim]) - start_indices, dyn = util.split_list(starts_and_dyn, [x.ndim]) - dyn_ = [a.dtype.bound if type(a.dtype) is core.bint else d - for a, d in zip(dyn_avals, dyn)] - slice_sizes_ = lax._merge_dyn_shape(slice_sizes, dyn_) - start_idx = [d.val if type(d) is core.DArray else d for d in start_indices] - return [dynamic_slice(x, start_idx, slice_sizes_)] + return trace.default_process_primitive( + dynamic_slice_p, (x, *start_indices), dict(slice_sizes=slice_sizes), + source_info=source_info) + +def _dynamic_slice_typecheck_rule(_, x, *start_indices, slice_sizes): + out_aval, effects = dynamic_slice_p.abstract_eval( + x.aval, *(d.aval for d in start_indices), slice_sizes=slice_sizes) + return [out_aval], effects dynamic_slice_p = standard_primitive( _dynamic_slice_shape_rule, _dynamic_slice_dtype_rule, 'dynamic_slice', weak_type_rule=_argnum_weak_type(0), - sharding_rule=_dynamic_slice_sharding_rule) + sharding_rule=_dynamic_slice_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'dynamic_slice'), + reduced_rule=_dynamic_slice_reduced_rule) ad.primitive_jvps[dynamic_slice_p] = _dynamic_slice_jvp ad.primitive_transposes[dynamic_slice_p] = _dynamic_slice_transpose_rule +ad.fancy_transposes[dynamic_slice_p] = _dynamic_slice_transpose_fancy batching.primitive_batchers[dynamic_slice_p] = _dynamic_slice_batching_rule pe.custom_staging_rules[dynamic_slice_p] = _dynamic_slice_staging_rule core.custom_typechecks[dynamic_slice_p] = _dynamic_slice_typecheck_rule -pe.padding_rules[dynamic_slice_p] = _dynamic_slice_padding_rule -def _dynamic_slice_lower(ctx, x, *starts_and_dyn_sizes, slice_sizes): +def _dynamic_slice_lower(ctx, x, *start_indices, slice_sizes): x_aval, *_ = ctx.avals_in - start_indices, dyn = util.split_list(starts_and_dyn_sizes, [x_aval.ndim]) aval_out, = ctx.avals_out - if dyn: - aval_out = aval_out.update(shape=lax._merge_dyn_shape(slice_sizes, dyn)) out = mlir.dynamic_slice(ctx, aval_out, x, start_indices=start_indices) return [mlir.lower_with_sharding_in_types(ctx, out, aval_out)] mlir.register_lowering(dynamic_slice_p, _dynamic_slice_lower) -# def _getslice_lower(ctx, x, lo, hi): -# aval_out, = ctx.avals_out -# return hlo.RealDynamicSliceOp( -# mlir.aval_to_ir_type(aval_out), x, -# mlir.shape_tensor([lo]), mlir.shape_tensor([hi]), mlir.shape_tensor([1]) -# ).results -# mlir.register_lowering(getslice_p, _getslice_lower) - def _dynamic_update_slice_shape_rule(operand, update, *start_indices): if operand.ndim != update.ndim: @@ -1606,13 +1662,33 @@ def _dynamic_update_slice_shape_rule(operand, update, *start_indices): def _dynamic_update_slice_sharding_rule(operand, update, *start_indices): if operand.sharding != update.sharding: - raise TypeError( - "dynamic_update_slice update sharding must be equal to operand" - " sharding, got update sharding" - f" {update.str_short(mesh_axis_types=True)} for operand sharding" - f" {operand.str_short(mesh_axis_types=True)}.") + raise core.ShardingTypeError( + "dynamic_update_slice operand sharding must be equal to update" + " sharding, got operand sharding" + f" {operand.str_short(mesh_axis_types=True)} and update sharding" + f" {update.str_short(mesh_axis_types=True)}.") return operand.sharding +def _dynamic_update_slice_unreduced_rule(out_s, operand, update, *start_indices): + if operand.sharding.spec.unreduced != update.sharding.spec.unreduced: + raise core.ShardingTypeError( + "dynamic_update_slice operand and update must be unreduced along the" + " same axes. Got operand sharding" + f" {operand.str_short(mesh_axis_types=True)} and update sharding" + f" {update.str_short(mesh_axis_types=True)}.") + return out_s.update(spec=out_s.spec.update( + unreduced=operand.sharding.spec.unreduced)) + +def _dynamic_update_slice_reduced_rule(out_s, operand, update, *start_indices): + if operand.sharding.spec.reduced != update.sharding.spec.reduced: + raise core.ShardingTypeError( + "dynamic_update_slice operand and update must be reduced along the" + " same axes. Got operand sharding" + f" {operand.str_short(mesh_axis_types=True)} and update sharding" + f" {update.str_short(mesh_axis_types=True)}.") + return out_s.update(spec=out_s.spec.update( + reduced=operand.sharding.spec.reduced)) + def _dynamic_update_slice_dtype_rule(operand, update, *start_indices): lax.check_same_dtypes("dynamic_update_slice", operand, update) if any(i.dtype != start_indices[0].dtype or @@ -1637,19 +1713,18 @@ def _dynamic_update_slice_jvp(primals, tangents): def _dynamic_update_slice_transpose_rule(t, operand, update, *start_indices): assert all(not ad.is_undefined_primal(x) for x in start_indices) - if ad.is_undefined_primal(update): - update_shape = update.aval.shape - else: - update_shape = update.shape + update_shape = (update.aval.shape if ad.is_undefined_primal(update) else + update.shape) + update_sharding = update.aval.sharding if type(t) is ad_util.Zero: operand_t = ad_util.Zero(operand.aval) if ad.is_undefined_primal(operand) else None update_t = ad_util.Zero(update.aval) if ad.is_undefined_primal(update) else None else: - dus = dynamic_update_slice_p.bind - ds = dynamic_slice_p.bind - zeros = lax._zeros(t, shape=update_shape) - operand_t = dus(t, zeros, *start_indices) if ad.is_undefined_primal(operand) else None - update_t = ds(t, *start_indices, slice_sizes=update_shape) if ad.is_undefined_primal(update) else None + zeros = lax._zeros(t, shape=update_shape, sharding=update_sharding) + operand_t = (dynamic_update_slice_p.bind(t, zeros, *start_indices) + if ad.is_undefined_primal(operand) else None) + update_t = (dynamic_slice_p.bind(t, *start_indices, slice_sizes=update_shape) + if ad.is_undefined_primal(update) else None) return [operand_t, update_t] + [None] * len(start_indices) def _dynamic_update_slice_batching_rule(batched_args, batch_dims): @@ -1678,7 +1753,10 @@ def _dynamic_update_slice_batching_rule(batched_args, batch_dims): dynamic_update_slice_p = standard_primitive( _dynamic_update_slice_shape_rule, _dynamic_update_slice_dtype_rule, - 'dynamic_update_slice', sharding_rule=_dynamic_update_slice_sharding_rule) + 'dynamic_update_slice', sharding_rule=_dynamic_update_slice_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'dynamic_update_slice'), + unreduced_rule=_dynamic_update_slice_unreduced_rule, + reduced_rule=_dynamic_update_slice_reduced_rule) ad.primitive_jvps[dynamic_update_slice_p] = _dynamic_update_slice_jvp ad.primitive_transposes[dynamic_update_slice_p] = \ _dynamic_update_slice_transpose_rule @@ -1697,7 +1775,7 @@ def _dynamic_update_slice_lower(ctx, x, update, *start_indices): def _gather_dtype_rule(operand, indices, *, fill_value, **kwargs): if not dtypes.issubdtype(indices.dtype, np.integer): raise ValueError("indices must have an integer type") - return dtypes.canonicalize_dtype(operand.dtype, allow_extended_dtype=True) + return operand.dtype _rank = lambda arr: len(arr.shape) @@ -1739,7 +1817,7 @@ def _gather_shape_rule(operand, indices, *, dimension_numbers, """Validates the well-formedness of the arguments to Gather. The code implements the checks based on the detailed operation semantics of - XLA's `Gather `_ + XLA's `Gather `_ operator and following the outline of the implementation of ShapeInference::InferGatherShape in TensorFlow. """ @@ -1870,8 +1948,9 @@ def _gather_shape_rule(operand, indices, *, dimension_numbers, slice_size = slice_sizes[i] corresponding_input_size = operand.shape[i] - if not (slice_size >= 0 and - corresponding_input_size >= slice_size): + if not core.is_empty_shape(indices.shape) and not ( + slice_size >= 0 and corresponding_input_size >= slice_size + ): raise TypeError(f"Slice size at index {i} in gather op is out of range, " f"must be within [0, {corresponding_input_size} + 1), " f"got {slice_size}.") @@ -1921,39 +2000,154 @@ def _gather_shape_computation(indices, dimension_numbers, slice_sizes): else next(indices_shape_gen) for i in range(output_shape_rank)) return ans -class GatherShardingError(Exception): - pass + +def _gather_spec_computation(operand, indices, dimension_numbers, slice_sizes): + """Returns gather output sharding spec if unambiguous, else None. + + Operand dimensions can be split into: + 1. Batching dims which must resolve unambiguously with the corresponding + indices batching dims. `operand_batching_dims` in GatherDimensionNumbers. + 2. Sliced dims which must be replicated. These are the subset of dims in + `start_index_map` where the slice size is not equal to the operand size. + A further subset of these (`collapsed_slice_dims`) are collapsed in the + output. + 3. Unsliced dims, these correspond directly to dimensions in the + output and so propagate their shardings. These are the dimensions not + batched or sliced. + + Indices dimensions can be split into: + 1. Batching dims which must resolve unambiguously with the corresponding + operand batching dims. `start_indices_batching_dims` in + GatherDimensionNumbers. + 2. Index vector dim which contains the start indices for each dimension of + the operand sliced in to. This must be replicated. It is the last + dimension of indices. + 3. Other dims which correspond directly to dimensions in the output and + so propagate their shardings. These are the dimensions not present in + index batching dims and index vector dim. + + If the axes of the corresponding batching dims between operand and indices are + both not None and do not match, then sharding propagation cannot be resolved + unambiguously and so we return None. + """ + offset_dims = dimension_numbers.offset_dims + start_index_map = dimension_numbers.start_index_map + collapsed_slice_dims = dimension_numbers.collapsed_slice_dims + operand_batching_dims = dimension_numbers.operand_batching_dims + start_indices_batching_dims = dimension_numbers.start_indices_batching_dims + output_shape_rank = len(offset_dims) + indices.ndim - 1 + index_vector_dim = indices.ndim - 1 + + operand_spec = operand.sharding.spec + indices_spec = list(indices.sharding.spec) + + if (all(s is None for s in operand_spec) and + all(s is None for s in indices_spec)): + return P() + + assert all(i in start_index_map for i in collapsed_slice_dims) + + # start_index_map defined which operand dimensions are sliced in to. However, + # if that slice is a full slice then we can propagate the sharding. + operand_index_dims_full_slice_or_replicated = all( + slice_sizes[i] == operand.shape[i] or operand_spec[i] is None + for i in start_index_map) + batching_dims_resolve_unambiguously = all( + indices_spec[indices_dim] == operand_spec[operand_dim] + or operand_spec[operand_dim] is None + or indices_spec[indices_dim] is None + for (operand_dim, indices_dim) in zip( + operand_batching_dims, start_indices_batching_dims) + ) + # Leads to comms on the bwd pass in scatter + indices_non_batch_dims_replicated = all( + s is None for i, s in enumerate(indices_spec) + if i not in start_indices_batching_dims + ) + + if (operand_index_dims_full_slice_or_replicated + and batching_dims_resolve_unambiguously + and indices_non_batch_dims_replicated): + # Resolve any batched shardings into indices shardings. + for operand_dim, indices_dim in zip( + operand_batching_dims, start_indices_batching_dims): + assert (indices_spec[indices_dim] == operand_spec[operand_dim] + or indices_spec[indices_dim] is None + or operand_spec[operand_dim] is None) + # Resolution and propagation of batching_dims is handled in indices, + # operand_batching_dims spec is resolved into indices spec (to match how + # the gather shape rule resolves output dimensions). + indices_spec[indices_dim] = indices_spec[indices_dim] or operand_spec[operand_dim] + + slice_sizes_gen = ( + (i, s) for i, s in enumerate(slice_sizes) + if i not in collapsed_slice_dims and i not in operand_batching_dims) + indices_spec_gen = iter(indices_spec) + + out_spec = [] + for i in range(output_shape_rank): + if i in offset_dims: + # The offset dims are the set of dimensions in the `gather` output that + # derive solely from the operand. + operand_dim, slice_size = next(slice_sizes_gen) + # Due to the `operand_index_dims_full_slice_or_replicated` check, if a + # slice is not full, it must be replicated. + assert (slice_size == operand.shape[operand_dim] + or operand_spec[operand_dim] is None) + out_spec.append(operand_spec[operand_dim]) + else: + # The other dimensions are either batching dims (which derive from both + # indices and operand, and we resolved above) or solely from indices. + out_spec.append(next(indices_spec_gen)) + return P(*out_spec) + return None + + +def _resolve_mesh(*meshes) -> mesh_lib.AbstractMesh: + """Resolves the mesh between given meshes.""" + unique_meshes = {mesh for mesh in meshes if not mesh.empty} + if len(unique_meshes) > 1: + raise core.ShardingTypeError( + f"Conflicting meshes received. Got: {unique_meshes=}" + ) + if not unique_meshes: + return mesh_lib.get_abstract_mesh() + mesh, = unique_meshes + return mesh + def _gather_sharding_rule(operand, indices, *, dimension_numbers, slice_sizes, unique_indices, indices_are_sorted, mode, fill_value): - # TODO(yashkatariya): Write a proper gather sharding rule. - cur_mesh = mesh_lib.get_abstract_mesh() - if cur_mesh.empty or cur_mesh._are_all_axes_auto or cur_mesh._are_all_axes_manual: - return core.get_cur_mesh_sharding() - if (cur_mesh._are_all_axes_explicit and - all(s is None for s in operand.sharding.spec) and - all(s is None for s in indices.sharding.spec)): - return core.get_cur_mesh_sharding() - raise GatherShardingError( - "Use `.at[...].get(out_sharding=)` to provide output PartitionSpec for" - " the gather indexing.") + out_mesh = _resolve_mesh(operand.sharding.mesh, indices.sharding.mesh) + out_spec = _gather_spec_computation(operand, indices, dimension_numbers, + slice_sizes) + if out_spec is None: + raise core.ShardingTypeError( + "Use `.at[...].get(out_sharding=)` to provide output PartitionSpec for" + " the gather indexing as out sharding could not be resolved" + " unambiguously (or would require collectives on inputs). Got" + f" {operand=}, {indices=}") + return NamedSharding(out_mesh, out_spec) def _gather_fill(operand, indices, *, dimension_numbers, slice_sizes, unique_indices, indices_are_sorted, fill_value, output_shape): """Lowers a FILL_OR_DROP gather as a PROMISE_IN_BOUNDS gather with masking.""" dnums = dimension_numbers - intarray = partial(np.array, dtype=np.int64) - operand_dims = lax.shape_as_value(operand.shape) - indices = lax.convert_element_type(indices, np.int64) + index_dtype = lax_utils.int_dtype_for_shape(operand.shape, signed=True) + intarray = partial(np.array, dtype=index_dtype) + operand_dims = lax.shape_as_value(operand.shape).astype(index_dtype) + indices = lax.convert_element_type(indices, index_dtype) num_batch_dims = len(indices.shape) - 1 - upper_bound = ( - operand_dims[intarray(dnums.start_index_map)] - - lax.shape_as_value(slice_sizes)[intarray(dnums.start_index_map)]) + upper_bound = operand_dims[ + intarray(dnums.start_index_map) + ] - lax.shape_as_value(slice_sizes)[intarray(dnums.start_index_map)].astype( + index_dtype + ) mask = lax.bitwise_and( - lax.ge(indices, np.int64(0)), + lax.ge(indices, index_dtype.type(0)), lax.le(indices, lax.expand_dims(upper_bound, tuple(range(num_batch_dims))))) mask = lax.reduce_and(mask, [num_batch_dims]) @@ -1969,7 +2163,8 @@ def _gather_fill(operand, indices, *, dimension_numbers, slice_sizes, indices_are_sorted=indices_are_sorted, mode=GatherScatterMode.PROMISE_IN_BOUNDS) return lax.select( - lax.broadcast_in_dim(mask, output_shape, batch_dims_in_output), + lax.broadcast_in_dim(mask, output_shape, batch_dims_in_output, + out_sharding=gather_out.aval.sharding), gather_out, lax.full_like(gather_out, fill_value=fill_value)) @@ -1985,11 +2180,12 @@ def _gather_transpose_rule(t, operand, indices, *, dimension_numbers, slice_sizes, unique_indices, indices_are_sorted, mode, fill_value): assert ad.is_undefined_primal(operand) - operand_shape = operand.aval.shape if type(t) is ad_util.Zero: out = ad_util.Zero(operand.aval) else: - zeros = lax.full(operand_shape, lax._zero(t)) + zeros = lax.full(operand.aval.shape, 0, core.typeof(t).dtype, + sharding=operand.aval.sharding) + zeros = core.pvary(zeros, tuple(operand.aval.vma)) scatter_dnums = ScatterDimensionNumbers( update_window_dims=dimension_numbers.offset_dims, inserted_window_dims=dimension_numbers.collapsed_slice_dims, @@ -2006,12 +2202,12 @@ def _gather_transpose_rule(t, operand, indices, *, dimension_numbers, def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers, slice_sizes, unique_indices, indices_are_sorted, mode, fill_value): - operand, indices, *dyn_slice_sizes = batched_args - operand_bdim, indices_bdim, *dyn_slice_size_bds = batch_dims - dyn_slice_size_bounds = [b.dtype.bound for b in dyn_slice_sizes] + operand, indices = batched_args + operand_bdim, indices_bdim = batch_dims if operand_bdim is not None and indices_bdim is None: - operand, operand_bdim = batching.move_stacked_axis(operand, operand_bdim, 0) + operand = batching.moveaxis(operand, operand_bdim, 0) + operand_bdim = 0 slice_sizes = (operand.shape[0],) + slice_sizes offset_dims = (0,) + tuple(np.add(1, dimension_numbers.offset_dims)) collapsed_slice_dims = tuple(np.add(1, dimension_numbers.collapsed_slice_dims)) @@ -2026,29 +2222,10 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers, operand_batching_dims=operand_batching_dims, start_indices_batching_dims=dimension_numbers.start_indices_batching_dims, ) - if isinstance(operand_bdim, batching.RaggedAxis): - ragged_slice_sizes = batching.bdim_as_shape(operand_bdim, slice_sizes) - for orig, fabricated in zip( - lax._merge_dyn_shape(slice_sizes, dyn_slice_sizes), - ragged_slice_sizes): - if isinstance(fabricated, batching.IndexedAxisSize): - if not core.same_referent(orig, fabricated.lengths): - # Don't know what to do when slicing a ragged dimension with a - # different size. To wit, if the client tries to index outside the - # ragged size, the resulting element should be determined by the - # out of bounds `mode`, but the underlying gather will only do that - # if the client tries to index outside the _padded_ array. I guess - # we should read the mode and apply a mask that writes the correct - # fill element into all out-of-bounds locations? - raise NotImplementedError - bdim_out = batching.shape_as_bdim( - operand_bdim.stacked_axis, - _gather_shape_computation(indices, dnums, ragged_slice_sizes)) - else: - bdim_out = operand_bdim + bdim_out = operand_bdim return gather( operand, indices, dimension_numbers=dnums, - slice_sizes=lax._merge_dyn_shape(slice_sizes, dyn_slice_size_bounds), + slice_sizes=slice_sizes, unique_indices=unique_indices, indices_are_sorted=indices_are_sorted, mode=mode, fill_value=fill_value), bdim_out @@ -2105,25 +2282,13 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers, indices_are_sorted=indices_are_sorted, mode=mode, fill_value=fill_value), 0 -def _gather_pad_rule(in_avals, out_avals, operand, indices, *, - dimension_numbers, slice_sizes, unique_indices, - indices_are_sorted, mode, fill_value): - operand_aval, indices_aval = in_avals - if any(isinstance(d, pe.BoundedAxisSize) for d in operand_aval.shape): - raise NotImplementedError - if mode != GatherScatterMode.PROMISE_IN_BOUNDS: - # with fill, jnp.where on operand; with clip, jnp.where on indices - raise NotImplementedError - return [gather(operand, indices, dimension_numbers=dimension_numbers, - slice_sizes=slice_sizes, mode=mode, fill_value=fill_value)] - gather_p = standard_primitive( _gather_shape_rule, _gather_dtype_rule, 'gather', - weak_type_rule=_argnum_weak_type(0), sharding_rule=_gather_sharding_rule) + weak_type_rule=_argnum_weak_type(0), sharding_rule=_gather_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'gather')) ad.defjvp(gather_p, _gather_jvp_rule, None) ad.primitive_transposes[gather_p] = _gather_transpose_rule batching.primitive_batchers[gather_p] = _gather_batching_rule -pe.padding_rules[gather_p] = _gather_pad_rule def _gather_lower_opaque(ctx, operand, indices, *, @@ -2149,6 +2314,7 @@ def _gather_lower_opaque(ctx, operand, indices, *, def _gather_lower(ctx, operand, indices, *, dimension_numbers, slice_sizes, unique_indices, indices_are_sorted, mode, fill_value): + _, indices_aval = ctx.avals_in aval_out, = ctx.avals_out if dtypes.issubdtype(aval_out.dtype, dtypes.extended): return [_gather_lower_opaque( @@ -2173,7 +2339,7 @@ def _gather_lower(ctx, operand, indices, *, start_indices_batching_dims=list( dimension_numbers.start_indices_batching_dims ), - index_vector_dim=len(ctx.avals_in[1].shape) - 1, + index_vector_dim=len(indices_aval.shape) - 1, offset_dims=list(dimension_numbers.offset_dims), start_index_map=list(dimension_numbers.start_index_map), ) @@ -2192,13 +2358,13 @@ def _gather_lower(ctx, operand, indices, *, } return hlo.DynamicGatherOp.build_generic( results=results, operands=operands, attributes=attributes).results + elif core.is_empty_shape(aval_out.shape): + out = mlir.full_like_aval(ctx, 0, aval_out) + return [mlir.lower_with_sharding_in_types(ctx, out, aval_out)] else: - return [hlo.gather( - operand, - indices, - dnums, - mlir.dense_int_array(slice_sizes), - indices_are_sorted=ir.BoolAttr.get(indices_are_sorted))] + out = hlo.gather(operand, indices, dnums, mlir.dense_int_array(slice_sizes), + indices_are_sorted=ir.BoolAttr.get(indices_are_sorted)) + return [mlir.lower_with_sharding_in_types(ctx, out, aval_out)] mlir.register_lowering(gather_p, _gather_lower) @@ -2206,7 +2372,15 @@ def _scatter_dtype_rule(operand, indices, updates, **kwargs): if not dtypes.issubdtype(indices.dtype, np.integer): raise ValueError("indices must have an integer type") lax.check_same_dtypes("scatter", operand, updates) - return dtypes.canonicalize_dtype(operand.dtype, allow_extended_dtype=True) + return operand.dtype + +def _get_updates_batching_dims(indices_batching_dims, update_window_dims, + index_vector_dim, updates_shape): + scatter_dim_in_updates = list(range(index_vector_dim)) + for i in update_window_dims: + scatter_dim_in_updates.insert(i, None) # type: ignore + assert len(scatter_dim_in_updates) == len(updates_shape) + return tuple(scatter_dim_in_updates.index(i) for i in indices_batching_dims) def _scatter_shape_rule(operand, indices, updates, *, update_jaxpr, update_consts, dimension_numbers, indices_are_sorted, @@ -2215,7 +2389,7 @@ def _scatter_shape_rule(operand, indices, updates, *, update_jaxpr, Scatter. The code implements the checks based on the detailed operation semantics of - XLA's `Scatter `_ + XLA's `Scatter `_ operator and following the outline of the implementation of ShapeInference::InferScatterShape in TensorFlow. """ @@ -2303,7 +2477,16 @@ def _scatter_shape_rule(operand, indices, updates, *, update_jaxpr, f"dimensions to have the same shape, got {operand_batch_shape} and " f"{indices_batch_shape}." ) - + updates_batching_dims = _get_updates_batching_dims( + scatter_indices_batching_dims, update_window_dims, index_vector_dim, + updates.shape) + updates_batch_shape = tuple(updates.shape[i] for i in updates_batching_dims) + if not core.definitely_equal_shape(operand_batch_shape, updates_batch_shape): + raise TypeError( + "Scatter op requires operand batching dimensions and updates batching " + f"dimensions to have the same shape, got {operand_batch_shape} and " + f"{updates_batch_shape}." + ) # Validate window_size window_size = ( len(update_window_dims) + @@ -2368,6 +2551,173 @@ def _scatter_shape_rule(operand, indices, updates, *, update_jaxpr, return operand.shape +def _is_resolvable(*axis_names: str | None) -> bool: + """Checks if given sharding axis names resolve unambiguously.""" + assert len(axis_names) >= 2, "At least two axis names expected." + return len({a for a in axis_names if a is not None}) <= 1 + + +def _scatter_spec_computation( + operand, indices, updates, dimension_numbers) -> P | None: + """For a scatter, we consider the gather rules in inverse. + + We consider a gather, then convert to scatter as the inverse. + A gather is a group of queries, each query slices a window out of `operand`. + + The `operand` dims sliced in to are those in `start_index_map`. Some of size 1 + dims in `start_index_map` are squeezed out of the gather output, termed + `collapsed_slice_dims`. Hence `collapsed_slice_dims` is a subset of + `start_index_map`. Confusingly, some of the slice sizes into `operand` can be + full slices, hence it can be inferred that the corresponding start index of + the window in that dim is 0 and hence the `operand` is not sliced in to. + + All dims of `indices`, except the final dim (`index_vector_dim`) are different + queries in to `operand`. These queries may be batched with `operand`, and so + are effectively N (query, operand) pairs so the queries are non-overlapping, + each into different matching sized slices of `operand` (N being the batch + dimension size), or unbatched where N queries are into the same `operand`. The + actual indices of the start of the (fixed size) windows are contained in the + final dimension of `indices`, the `index_vector_dim`. + + Some dims of `operand` are not sliced in to, these are `offset_dims` + [offset_dims is defined from the gather output perspective (the `updates` in + scatter) and by construction are the frontmost dims in `operand` after + disregarding `collapsed_slice_dims` and `operand_batching_dims`]. + + For a gather: + - Batch dimensions must be resolvable unambiguously between `operand` and + `indices` + - Dims the `operand` is sliced in to are either full slices, or the `operand` + is replicated in that dim (all the data for each query is present) + - The `index_vector_dim` containing the actual start indices of the windows + in to `operand` must be replicated + + A scatter is a group of queries, each of which has a corresponding window of + `updates` to put in to `operand`. + + For a scatter: + 1 - Batch dimensions must resolve unambiguously between `operand`, `indices` + and `updates` + 2 - Full slice dims must resolve unambiguously between `operand` and + `updates` + 3 - Sub slice dims in `operand` and `updates` must be replicated. + 4 - `Indices` and `updates` must be replicated in dims where both are + updating, possibly overlapping, subslices of operand. These are referred to + as unbatched queries in the above description of gather. + 5 - The `index_vector_dim` containing the actual start indices of the windows + in to `operand` must be replicated + + Not all dimensions updating slices in `operand` exist in `updates`. The + `inserted_window_dims` provide size 1 updates into `operand` and are inserted + into `updates`. We handle `operand` slices present in `updates`, then handle + the inserted dims separately. + + Correspondance: + Gather <-> Scatter + offset_dims <-> update_window_dims + collapsed_slice_dims <-> inserted_window_dims + start_index_map <-> scatter_dims_to_operand_dims + """ + operand_batching_dims = dimension_numbers.operand_batching_dims + indices_batching_dims = dimension_numbers.scatter_indices_batching_dims + update_window_dims = dimension_numbers.update_window_dims + inserted_window_dims = dimension_numbers.inserted_window_dims + index_vector_dim = indices.ndim - 1 + + operand_spec = operand.sharding.spec + indices_spec = indices.sharding.spec + updates_spec = updates.sharding.spec + + if (all(s is None for s in operand_spec) and + all(s is None for s in indices_spec) and + all(s is None for s in updates_spec)): + return P() + + updates_batching_dims = _get_updates_batching_dims( + indices_batching_dims, update_window_dims, index_vector_dim, updates.shape) + + # Work out the corresponding operand dim for update window dims + operand_window_dims = [ + i for i in range(operand.ndim) + if i not in inserted_window_dims and i not in operand_batching_dims] + + # 1 - Batch dimensions must resolve unambiguously between `operand`, `indices` + # and `updates` + batch_dims_resolvable = all( + _is_resolvable(operand_spec[od], indices_spec[id], updates_spec[ud]) + for od, id, ud in zip( + operand_batching_dims, indices_batching_dims, updates_batching_dims)) + + # 2 - Full slice dims must resolve unambiguously between `operand` and `updates` + # 3 - Sub slice dims in `operand` and `updates` must be replicated. + update_and_operand_window_dims_resolvable = all( + (updates.shape[update_dim] == operand.shape[operand_dim] and + updates_spec[update_dim] == operand_spec[operand_dim]) + or (updates_spec[update_dim] is None and operand_spec[operand_dim] is None) + for update_dim, operand_dim in zip( + update_window_dims, operand_window_dims)) + inserted_window_dims_replicated_in_operand = all( + spec is None for i, spec in enumerate(operand_spec) + if i in inserted_window_dims) + + # 4 - `Indices` and `updates` must be replicated in dims where both are + # updating, possibly overlapping, slices of operand. These are referred to as + # unbatched queries in the above description of gather. + unbatched_query_dims_in_updates_replicated = all( + spec is None for i, spec in enumerate(updates_spec) + if i not in updates_batching_dims and i not in update_window_dims) + unbatched_query_dims_in_indices_replicated = all( + spec is None for i, spec in enumerate(indices_spec[:index_vector_dim]) + if i not in indices_batching_dims) + + # 5 - The `index_vector_dim` containing the actual start indices of the + # windows in to `operand` must be replicated + index_vector_dim_is_replicated = indices_spec[index_vector_dim] is None + + if (batch_dims_resolvable and + update_and_operand_window_dims_resolvable and + inserted_window_dims_replicated_in_operand and + unbatched_query_dims_in_updates_replicated and + unbatched_query_dims_in_indices_replicated and + index_vector_dim_is_replicated): + out_spec = list(operand_spec) + + # 1 - Batch dims + for operand_dim, indices_dim, updates_dim in zip( + operand_batching_dims, indices_batching_dims, updates_batching_dims): + if out_spec[operand_dim] is None: + out_spec[operand_dim] = ( + indices_spec[indices_dim] or updates_spec[updates_dim]) + # 2, 3 - Full/sub slices of operand dims present in updates + for update_dim, operand_dim in zip(update_window_dims, operand_window_dims): + if out_spec[operand_dim] is None: + out_spec[operand_dim] = updates_spec[update_dim] + return P(*out_spec) + + return None + + +def _scatter_memory_space_rule( + operand, indices, updates, *, update_jaxpr, update_consts, + dimension_numbers, indices_are_sorted, unique_indices, mode): + return operand.memory_space + + +def _scatter_sharding_rule( + operand, indices, updates, *, update_jaxpr, update_consts, + dimension_numbers, indices_are_sorted, unique_indices, mode): + out_mesh = _resolve_mesh( + *(x.sharding.mesh for x in (operand, indices, updates))) + out_spec = _scatter_spec_computation(operand, indices, updates, + dimension_numbers) + if out_spec is None: + raise core.ShardingTypeError( + "Use `.at[...].set/add/mul/...(out_sharding=)` to provide output" + " PartitionSpec for the scatter update as out sharding could not be" + " resolved unambiguously (or would require collectives on inputs). Got" + f" {operand=}, {indices=}, {updates=}") + return NamedSharding(out_mesh, out_spec) + def _clamp_scatter_indices(operand, indices, updates, *, dnums): """Clamps `indices` to be in-range for a scatter.""" slice_sizes = [] @@ -2381,78 +2731,48 @@ def _clamp_scatter_indices(operand, indices, updates, *, dnums): upper_bounds: core.Shape = tuple(operand.shape[i] - slice_sizes[i] for i in dnums.scatter_dims_to_operand_dims) + # Stack upper_bounds into a Array[n] upper_bound = lax.shape_as_value(upper_bounds) # This fix fails lax_test_no_jax_array - upper_bound = lax.min(upper_bound, - lax.convert_element_type(np.uint64(np.iinfo(indices.dtype).max), - np.int64)) - + upper_bound = lax.min( + upper_bound, + upper_bound.dtype.type( + min(np.iinfo(upper_bound.dtype).max, np.iinfo(indices.dtype).max) + ), + ) + upper_bound = lax.convert_element_type(upper_bound, indices.dtype) upper_bound = lax.broadcast_in_dim(upper_bound, indices.shape, (len(indices.shape) - 1,)) - return lax.clamp(np.int64(0), lax.convert_element_type(indices, np.int64), - upper_bound) - + return lax.clamp(indices.dtype.type(0), indices, upper_bound) def _scatter_addsub_jvp( - prim, - primals, - tangents, - *, - update_jaxpr, - update_consts, - dimension_numbers, - indices_are_sorted, - unique_indices, - mode, -): + prim, primals, tangents, *, update_jaxpr, update_consts, dimension_numbers, + indices_are_sorted, unique_indices, mode): operand, indices, updates = primals g_operand, g_indices, g_updates = tangents del g_indices # ignored val_out = prim.bind( - operand, - indices, - updates, - update_jaxpr=update_jaxpr, - update_consts=update_consts, - dimension_numbers=dimension_numbers, - indices_are_sorted=indices_are_sorted, - unique_indices=unique_indices, - mode=mode, - ) + operand, indices, updates, update_jaxpr=update_jaxpr, + update_consts=update_consts, dimension_numbers=dimension_numbers, + indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, + mode=mode) if type(g_operand) is ad_util.Zero and type(g_updates) is ad_util.Zero: tangent_out = ad_util.Zero.from_primal_value(val_out) else: g_operand = ad.instantiate_zeros(g_operand) g_updates = ad.instantiate_zeros(g_updates) tangent_out = prim.bind( - g_operand, - indices, - g_updates, - update_jaxpr=update_jaxpr, - update_consts=update_consts, - dimension_numbers=dimension_numbers, - indices_are_sorted=indices_are_sorted, - unique_indices=unique_indices, - mode=mode, - ) + g_operand, indices, g_updates, update_jaxpr=update_jaxpr, + update_consts=update_consts, dimension_numbers=dimension_numbers, + indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, + mode=mode) return val_out, tangent_out def _scatter_addsub_transpose_rule( - prim, - t, - operand, - indices, - updates, - *, - update_jaxpr, - update_consts, - dimension_numbers, - indices_are_sorted, - unique_indices, - mode, -): + prim, t, operand, indices, updates, *, update_jaxpr, update_consts, + dimension_numbers, indices_are_sorted, unique_indices, mode): assert not ad.is_undefined_primal(indices) if ad.is_undefined_primal(updates): updates_shape = updates.aval.shape @@ -2537,7 +2857,7 @@ def _scatter_mul_transpose_rule(t, operand, indices, updates, *, return [operand_t, None, update_t] -def _scatter_batching_rule(scatter_op, batched_args, batch_dims, *, +def _scatter_batching_rule(scatter_op, axis_data, batched_args, batch_dims, *, update_jaxpr, update_consts, dimension_numbers, indices_are_sorted, unique_indices, mode): operand, indices, updates = batched_args @@ -2547,9 +2867,10 @@ def _scatter_batching_rule(scatter_op, batched_args, batch_dims, *, # it at the front (so that we can scatter into it) size = next(x.shape[ax] for x, ax in zip(batched_args, batch_dims) if ax is not None) - operand = batching.bdim_at_front(operand, operand_bdim, size) - - updates = batching.bdim_at_front(updates, updates_bdim, size) + operand = batching.bdim_at_front(operand, operand_bdim, size, + axis_data.explicit_mesh_axis) + updates = batching.bdim_at_front(updates, updates_bdim, size, + axis_data.explicit_mesh_axis) if indices_bdim is None: inserted_window_dims = tuple(np.add(1, dimension_numbers.inserted_window_dims)) @@ -2571,29 +2892,23 @@ def _scatter_batching_rule(scatter_op, batched_args, batch_dims, *, mode=mode, update_jaxpr=update_jaxpr, update_consts=update_consts), 0 # see the third case in _gather_batching_rule for comparison and comments - indices = batching.bdim_at_front(indices, indices_bdim, size) + indices = batching.bdim_at_front(indices, indices_bdim, size, + axis_data.explicit_mesh_axis) update_window_dims = tuple(np.add(1, dimension_numbers.update_window_dims)) - inserted_window_dims = tuple( - np.add(1, dimension_numbers.inserted_window_dims) - ) + inserted_window_dims = tuple(np.add(1, dimension_numbers.inserted_window_dims)) operand_batching_dims = (0,) + tuple( - np.add(1, dimension_numbers.operand_batching_dims) - ) + np.add(1, dimension_numbers.operand_batching_dims)) scatter_indices_batching_dims = (0,) + tuple( - np.add(1, dimension_numbers.scatter_indices_batching_dims) - ) + np.add(1, dimension_numbers.scatter_indices_batching_dims)) scatter_dims_to_operand_dims = tuple( - np.add(1, dimension_numbers.scatter_dims_to_operand_dims) - ) - + np.add(1, dimension_numbers.scatter_dims_to_operand_dims)) dnums = ScatterDimensionNumbers( update_window_dims=update_window_dims, inserted_window_dims=inserted_window_dims, scatter_dims_to_operand_dims=scatter_dims_to_operand_dims, operand_batching_dims=operand_batching_dims, - scatter_indices_batching_dims=scatter_indices_batching_dims, - ) + scatter_indices_batching_dims=scatter_indices_batching_dims) return scatter_op.bind( operand, indices, updates, dimension_numbers=dnums, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, @@ -2601,27 +2916,31 @@ def _scatter_batching_rule(scatter_op, batched_args, batch_dims, *, scatter_add_p = standard_primitive( _scatter_shape_rule, _scatter_dtype_rule, 'scatter-add', - weak_type_rule=_argnum_weak_type(0)) + weak_type_rule=_argnum_weak_type(0), sharding_rule=_scatter_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'scatter_add'), + memory_space_rule=_scatter_memory_space_rule) ad.primitive_jvps[scatter_add_p] = partial(_scatter_addsub_jvp, scatter_add_p) ad.primitive_transposes[scatter_add_p] = partial(_scatter_addsub_transpose_rule, scatter_add_p) -batching.primitive_batchers[scatter_add_p] = ( - partial(_scatter_batching_rule, scatter_add_p)) +batching.fancy_primitive_batchers[scatter_add_p] = partial(_scatter_batching_rule, scatter_add_p) +batching.skippable_batchers[scatter_add_p] = lambda _: () scatter_sub_p = standard_primitive( - _scatter_shape_rule, - _scatter_dtype_rule, - "scatter-sub", - weak_type_rule=_argnum_weak_type(0), + _scatter_shape_rule, _scatter_dtype_rule, 'scatter-sub', + weak_type_rule=_argnum_weak_type(0), sharding_rule=_scatter_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'scatter_sub'), + memory_space_rule=_scatter_memory_space_rule ) ad.primitive_jvps[scatter_sub_p] = partial(_scatter_addsub_jvp, scatter_sub_p) ad.primitive_transposes[scatter_sub_p] = partial(_scatter_addsub_transpose_rule, scatter_sub_p) -batching.primitive_batchers[scatter_sub_p] = partial( - _scatter_batching_rule, scatter_sub_p -) +batching.fancy_primitive_batchers[scatter_sub_p] = partial( + _scatter_batching_rule, scatter_sub_p) +batching.skippable_batchers[scatter_sub_p] = lambda _: () scatter_mul_p = standard_primitive( _scatter_shape_rule, _scatter_dtype_rule, 'scatter-mul', - weak_type_rule=_argnum_weak_type(0)) + weak_type_rule=_argnum_weak_type(0), sharding_rule=_scatter_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'scatter_mul'), + memory_space_rule=_scatter_memory_space_rule) def _scatter_mul_jvp_rhs(g, x, i, y, *, dimension_numbers, indices_are_sorted, unique_indices, mode, **kw): @@ -2638,8 +2957,9 @@ def _scatter_mul_jvp_rhs(g, x, i, y, *, dimension_numbers, None, _scatter_mul_jvp_rhs) ad.primitive_transposes[scatter_mul_p] = _scatter_mul_transpose_rule -batching.primitive_batchers[scatter_mul_p] = ( +batching.fancy_primitive_batchers[scatter_mul_p] = ( partial(_scatter_batching_rule, scatter_mul_p)) +batching.skippable_batchers[scatter_mul_p] = lambda _: () def _scatter_extremal_jvp(scatter_op, primals, tangents, update_jaxpr, update_consts, dimension_numbers, @@ -2750,16 +3070,22 @@ def _scatter_extremal_jvp(scatter_op, primals, tangents, update_jaxpr, scatter_min_p = standard_primitive( _scatter_shape_rule, _scatter_dtype_rule, 'scatter-min', - weak_type_rule=_argnum_weak_type(0)) -batching.primitive_batchers[scatter_min_p] = ( + weak_type_rule=_argnum_weak_type(0), sharding_rule=_scatter_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'scatter_min'), + memory_space_rule=_scatter_memory_space_rule) +batching.fancy_primitive_batchers[scatter_min_p] = ( partial(_scatter_batching_rule, scatter_min_p)) +batching.skippable_batchers[scatter_min_p] = lambda _: () ad.primitive_jvps[scatter_min_p] = partial(_scatter_extremal_jvp, scatter_min_p) scatter_max_p = standard_primitive( _scatter_shape_rule, _scatter_dtype_rule, 'scatter-max', - weak_type_rule=_argnum_weak_type(0)) -batching.primitive_batchers[scatter_max_p] = ( + weak_type_rule=_argnum_weak_type(0), sharding_rule=_scatter_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'scatter_max'), + memory_space_rule=_scatter_memory_space_rule) +batching.fancy_primitive_batchers[scatter_max_p] = ( partial(_scatter_batching_rule, scatter_max_p)) +batching.skippable_batchers[scatter_max_p] = lambda _: () ad.primitive_jvps[scatter_max_p] = partial(_scatter_extremal_jvp, scatter_max_p) def _scatter_jvp(primals, tangents, *, update_jaxpr, update_consts, @@ -2812,10 +3138,7 @@ def _scatter_jvp(primals, tangents, *, update_jaxpr, update_consts, for update_dim in dnums.update_window_dims: ids_shape[update_dim] = 1 num_ids = math.prod(ids_shape) - if core.is_constant_dim(num_ids): - id_dtype = np.uint32 if (num_ids + 1) < np.iinfo(np.uint32).max else np.uint64 - else: - id_dtype = np.uint64 + id_dtype = lax_utils.int_dtype_for_dim(num_ids, signed=False) update_ids = lax.add(lax.reshape(lax.iota(id_dtype, num_ids), ids_shape), lax._ones(updates, dtype=id_dtype)) @@ -2874,8 +3197,10 @@ def _scatter_transpose_rule(t, operand, indices, updates, *, assert not ad.is_undefined_primal(indices) if ad.is_undefined_primal(updates): updates_shape = updates.aval.shape + updates_sharding = updates.aval.sharding else: updates_shape = updates.shape + updates_sharding = core.typeof(updates).sharding if type(t) is ad_util.Zero: operand_t = ad_util.Zero(operand.aval) if ad.is_undefined_primal(operand) else None update_t = ad_util.Zero(updates.aval) if ad.is_undefined_primal(updates) else None @@ -2883,10 +3208,11 @@ def _scatter_transpose_rule(t, operand, indices, updates, *, operand_t = update_t = None if ad.is_undefined_primal(operand): # Zero out gradient entries that correspond to updated indices. - operand_t = scatter(t, indices, lax.full(updates_shape, 0, dtype=t.dtype), - dimension_numbers=dimension_numbers, - indices_are_sorted=indices_are_sorted, - unique_indices=True, mode=mode) + operand_t = scatter( + t, indices, + lax.full(updates_shape, 0, dtype=t.dtype, sharding=updates_sharding), + dimension_numbers=dimension_numbers, + indices_are_sorted=indices_are_sorted, unique_indices=True, mode=mode) if ad.is_undefined_primal(updates): gather_dnums = GatherDimensionNumbers( @@ -2915,11 +3241,14 @@ def _scatter_transpose_rule(t, operand, indices, updates, *, scatter_p = standard_primitive( _scatter_shape_rule, _scatter_dtype_rule, 'scatter', - weak_type_rule=_argnum_weak_type(0)) + weak_type_rule=_argnum_weak_type(0), sharding_rule=_scatter_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'scatter'), + memory_space_rule=_scatter_memory_space_rule) ad.primitive_jvps[scatter_p] = _scatter_jvp ad.primitive_transposes[scatter_p] = _scatter_transpose_rule -batching.primitive_batchers[scatter_p] = ( +batching.fancy_primitive_batchers[scatter_p] = ( partial(_scatter_batching_rule, scatter_p)) +batching.skippable_batchers[scatter_p] = lambda _: () def _scatter_lower_opaque(ctx, operand, indices, updates, *, @@ -2944,8 +3273,8 @@ def _scatter_lower_opaque(ctx, operand, indices, updates, *, return res -def _scatter_lower(ctx, operand, indices, updates, *, - update_jaxpr, update_consts, dimension_numbers, +def _scatter_lower(ctx: mlir.LoweringRuleContext, operand, indices, updates, *, + update_jaxpr: core.Jaxpr, update_consts, dimension_numbers, indices_are_sorted, unique_indices, mode): if update_jaxpr is None: assert not update_consts @@ -2978,14 +3307,9 @@ def _scatter_lower(ctx, operand, indices, updates, *, result = mlir.aval_to_ir_type(aval_out) operand = [operand] updates = [updates] - op = hlo.ScatterOp( - (result,), - operand, - indices, - updates, - scatter_dnums, - indices_are_sorted=ir.BoolAttr.get(indices_are_sorted), - unique_indices=ir.BoolAttr.get(unique_indices)) + op = hlo.ScatterOp((result,), operand, indices, updates, scatter_dnums, + indices_are_sorted=ir.BoolAttr.get(indices_are_sorted), + unique_indices=ir.BoolAttr.get(unique_indices)) scalar_type = mlir.aval_to_ir_type(core.ShapedArray((), aval_out.dtype)) update = op.update_computation.blocks.append(scalar_type, scalar_type) with ir.InsertionPoint(update): @@ -2995,9 +3319,10 @@ def _scatter_lower(ctx, operand, indices, updates, *, out_nodes, _ = mlir.jaxpr_subcomp( ctx.module_context, update_jaxpr, name_stack, mlir.TokenSet(), update_consts, update.arguments[0], update.arguments[1], - dim_var_values=ctx.dim_var_values) + dim_var_values=ctx.dim_var_values, const_lowering=ctx.const_lowering) hlo.return_(mlir.flatten_ir_values(out_nodes)) - return op.results + return [mlir.lower_with_sharding_in_types(ctx, r, aval) + for r, aval in safe_zip(op.results, ctx.avals_out)] mlir.register_lowering(scatter_p, _scatter_lower) mlir.register_lowering(scatter_add_p, _scatter_lower) @@ -3011,19 +3336,9 @@ def _real_dtype(dtype): return np.finfo(dtype).dtype def _scatter_addsub_lower_gpu( - ctx, - operand, - indices, - updates, - *, - update_jaxpr, - update_consts, - dimension_numbers, - indices_are_sorted, - unique_indices, - mode, - reduce_op, -): + ctx, operand, indices, updates, *, update_jaxpr, update_consts, + dimension_numbers, indices_are_sorted, unique_indices, mode, + reduce_op): operand_aval_in, _, updates_aval_in = ctx.avals_in if operand_aval_in.dtype != np.complex128: return _scatter_lower(ctx, operand, indices, updates, @@ -3057,18 +3372,14 @@ def _scatter(operand_part, updates_part): updates_part = [updates_part] scatter = hlo.ScatterOp( - (operand_type_part,), - operand_part, - indices, - updates_part, - scatter_dnums, + (operand_type_part,), operand_part, indices, updates_part, scatter_dnums, indices_are_sorted=ir.BoolAttr.get(indices_are_sorted), unique_indices=ir.BoolAttr.get(unique_indices)) scalar_type = mlir.aval_to_ir_type(core.ShapedArray((), real_dtype)) reducer = scatter.regions[0].blocks.append(scalar_type, scalar_type) with ir.InsertionPoint(reducer): hlo.return_([reduce_op(*reducer.arguments).result]) - return scatter.result + return mlir.lower_with_sharding_in_types(ctx, scatter.result, aval_out) real = _scatter(hlo.real(operand), hlo.real(updates)) imag = _scatter(hlo.imag(operand), hlo.imag(updates)) diff --git a/jax/_src/lax/special.py b/jax/_src/lax/special.py index b70513bc2d20..f2ab8f3f0ce5 100644 --- a/jax/_src/lax/special.py +++ b/jax/_src/lax/special.py @@ -19,8 +19,9 @@ from enum import Enum import numpy as np -from functools import partial +from functools import partial, reduce as _reduce +from jax._src import core from jax._src.lax.lax import (add, bitwise_and, bitwise_not, bitwise_or, broadcast_in_dim, broadcast_shapes, convert_element_type, div, eq, exp, full_like, ge, @@ -28,8 +29,8 @@ reduce, select, sign, sqrt, square, standard_naryop, standard_unop, sub, _const, _dtype, - _float, _nary_lower_hlo, _ones, _isnan, _reduce) -from jax._src.lax.control_flow import while_loop + _float, _nary_lower_hlo, _ones, _isnan) +from jax._src.lax.control_flow.loops import while_loop from jax._src import dtypes from jax._src.interpreters import ad @@ -37,8 +38,28 @@ from jax._src.lib.mlir.dialects import chlo from jax._src.typing import Array, ArrayLike +# TODO(mattjj): this function sucks, delete it +def _up_and_broadcast(doit): + def up_and_broadcast(*args): + broadcasted_shape = broadcast_shapes(*(a.shape for a in args)) + args = [broadcast_in_dim(a, broadcasted_shape, list(range(a.ndim))) for a in args] + + a_dtype = args[0].dtype + needs_upcast = a_dtype == dtypes.bfloat16 or a_dtype == np.float16 + if needs_upcast: + args = [convert_element_type(a, np.float32) for a in args] + a_x_type = np.float32 + else: + a_x_type = a_dtype + result = doit(*args, dtype=a_x_type) + if needs_upcast: + result = convert_element_type(result, a_dtype) + return result + return up_and_broadcast + def betainc(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array: r"""Elementwise regularized incomplete beta integral.""" + a, b, x = core.standard_insert_pvary(a, b, x) return regularized_incomplete_beta_p.bind(a, b, x) def lgamma(x: ArrayLike) -> Array: @@ -51,26 +72,33 @@ def digamma(x: ArrayLike) -> Array: def polygamma(m: ArrayLike, x: ArrayLike) -> Array: r"""Elementwise polygamma: :math:`\psi^{(m)}(x)`.""" + m, x = core.standard_insert_pvary(m, x) return polygamma_p.bind(m, x) def igamma(a: ArrayLike, x: ArrayLike) -> Array: r"""Elementwise regularized incomplete gamma function.""" + a, x = core.standard_insert_pvary(a, x) return igamma_p.bind(a, x) def igammac(a: ArrayLike, x: ArrayLike) -> Array: r"""Elementwise complementary regularized incomplete gamma function.""" + a, x = core.standard_insert_pvary(a, x) return igammac_p.bind(a, x) def igamma_grad_a(a: ArrayLike, x: ArrayLike) -> Array: r"""Elementwise derivative of the regularized incomplete gamma function.""" + a, x = core.standard_insert_pvary(a, x) return igamma_grad_a_p.bind(a, x) -def random_gamma_grad(a: ArrayLike, x: ArrayLike) -> Array: +@_up_and_broadcast +def random_gamma_grad(a: ArrayLike, x: ArrayLike, *, dtype) -> Array: r"""Elementwise derivative of samples from `Gamma(a, 1)`.""" - return random_gamma_grad_p.bind(a, x) + a, x = core.standard_insert_pvary(a, x) + return random_gamma_grad_impl(a, x, dtype=dtype) def zeta(x: ArrayLike, q: ArrayLike) -> Array: r"""Elementwise Hurwitz zeta function: :math:`\zeta(x, q)`""" + x, q = core.standard_insert_pvary(x, q) return zeta_p.bind(x, q) def bessel_i0e(x: ArrayLike) -> Array: @@ -194,12 +222,18 @@ def nth_partial_betainc_numerator(iteration, a, b, x): iteration_is_one = eq(iteration_bcast, full_like(iteration_bcast, 1)) iteration_minus_one = iteration_bcast - full_like(iteration_bcast, 1) m = iteration_minus_one // full_like(iteration_minus_one, 2) + m_is_zero = eq(m, full_like(m, 0)) m = convert_element_type(m, dtype) one = full_like(a, 1) two = full_like(a, 2.0) # Partial numerator terms - even_numerator = -(a + m) * (a + b + m) * x / ( - (a + two * m) * (a + two * m + one)) + + # When a is close to zero and m == 0, using zero_numerator avoids + # inaccuracies when FTZ or DAZ is enabled: + zero_numerator = -(a + b) * x / (a + one) + even_numerator = select(m_is_zero, zero_numerator, + -(a + m) * (a + b + m) * x / ( + (a + two * m) * (a + two * m + one))) odd_numerator = m * (b - m) * x / ((a + two * m - one) * (a + two * m)) one_numerator = full_like(x, 1.0) numerator = select(iteration_is_even, even_numerator, odd_numerator) @@ -210,12 +244,24 @@ def nth_partial_betainc_denominator(iteration, a, b, x): return select(eq(iteration_bcast, full_like(iteration_bcast, 0)), full_like(x, 0), full_like(x, 1)) + a_is_zero = bitwise_or(eq(a, full_like(a, 0)), eq(b, full_like(b, float('inf')))) + b_is_zero = bitwise_or(eq(b, full_like(b, 0)), eq(a, full_like(a, float('inf')))) + x_is_zero = eq(x, full_like(x, 0)) + x_is_one = eq(x, full_like(x, 1)) + x_is_not_zero = bitwise_not(x_is_zero) + x_is_not_one = bitwise_not(x_is_one) + is_nan = bitwise_or(bitwise_or(_isnan(a), _isnan(b)), _isnan(x)) + + result_is_zero = bitwise_or(bitwise_and(b_is_zero, x_is_not_one), bitwise_and(a_is_zero, x_is_zero)) + result_is_one = bitwise_or(bitwise_and(a_is_zero, x_is_not_zero), bitwise_and(b_is_zero, x_is_one)) + result_is_nan = bitwise_or(bitwise_or(bitwise_or( - le(a, full_like(a, 0)), le(b, full_like(b, 0))), + lt(a, full_like(a, 0)), lt(b, full_like(b, 0))), lt(x, full_like(x, 0))), gt(x, full_like(x, 1))) + result_is_nan = bitwise_or(result_is_nan, bitwise_or(bitwise_and(a_is_zero, b_is_zero), is_nan)) - # The continued fraction will converge rapidly when x < (a+1)/(a+b+2) - # as per: http://dlmf.nist.gov/8.17.E23 + # The continued fraction will converge rapidly when x < + # (a+1)/(a+b+2) as per: http://dlmf.nist.gov/8.17.E23. # # Otherwise, we can rewrite using the symmetry relation as per: # http://dlmf.nist.gov/8.17.E4 @@ -234,10 +280,21 @@ def nth_partial_betainc_denominator(iteration, a, b, x): inputs=[a, b, x] ) - lbeta_ab = lgamma(a) + lgamma(b) - lgamma(a + b) - result = continued_fraction * exp(log(x) * a + log1p(-x) * b - lbeta_ab) / a + # For very small a and to avoid division by zero, we'll use + # a * gamma(a) = gamma(a + 1) -> 1 as a -> 0+. + very_small = (dtypes.finfo(dtype).tiny * 2).astype(dtype) + lbeta_ab_small_a = lgamma(b) - lgamma(a + b) + lbeta_ab = lgamma(a) + lbeta_ab_small_a + factor = select(lt(a, full_like(a, very_small)), + exp(log1p(-x) * b - lbeta_ab_small_a), + exp(log(x) * a + log1p(-x) * b - lbeta_ab) / a) + result = continued_fraction * factor + result = select(converges_rapidly, result, sub(full_like(result, 1), result)) + + result = select(result_is_zero, full_like(a, 0), result) + result = select(result_is_one, full_like(a, 1), result) result = select(result_is_nan, full_like(a, float('nan')), result) - return select(converges_rapidly, result, sub(full_like(result, 1), result)) + return result class IgammaMode(Enum): VALUE = 1 @@ -494,24 +551,6 @@ def random_gamma_grad_impl(a, x, *, dtype): full_like(a, float('nan')), output) return output -def _up_and_broadcast(doit): - def up_and_broadcast(*args): - broadcasted_shape = broadcast_shapes(*(a.shape for a in args)) - args = [broadcast_in_dim(a, broadcasted_shape, list(range(a.ndim))) for a in args] - - a_dtype = args[0].dtype - needs_upcast = a_dtype == dtypes.bfloat16 or a_dtype == np.float16 - if needs_upcast: - args = [convert_element_type(a, np.float32) for a in args] - a_x_type = np.float32 - else: - a_x_type = a_dtype - result = doit(*args, dtype=a_x_type) - if needs_upcast: - result = convert_element_type(result, a_dtype) - return result - return up_and_broadcast - def evaluate_chebyshev_polynomial(x, coefficients): b0 = full_like(x,0) @@ -657,11 +696,6 @@ def bessel_i0e_impl(x): ad.defjvp(igammac_p, igammac_grada, igammac_gradx) -random_gamma_grad_p = standard_naryop([_float, _float], 'random_gamma_grad') -mlir.register_lowering(random_gamma_grad_p, - mlir.lower_fun(_up_and_broadcast(random_gamma_grad_impl), - multiple_results=False)) - zeta_p = standard_naryop([_float, _float], 'zeta') mlir.register_lowering(zeta_p, partial(_nary_lower_hlo, chlo.zeta)) diff --git a/jax/_src/lax/utils.py b/jax/_src/lax/utils.py index f39d925ac2ad..669ffc510ae0 100644 --- a/jax/_src/lax/utils.py +++ b/jax/_src/lax/utils.py @@ -18,32 +18,38 @@ from functools import partial +import numpy as np + from jax._src import core from jax._src import dispatch from jax._src import dtypes from jax._src import mesh as mesh_lib -from jax._src.util import safe_zip +from jax._src import state +from jax._src.named_sharding import DuplicateSpecError, NamedSharding from jax._src.partition_spec import PartitionSpec as P -from jax._src.named_sharding import NamedSharding, DuplicateSpecError +from jax._src.util import safe_zip +from jax._src.typing import DimSize, DType, Shape zip, unsafe_zip = safe_zip, zip -import numpy as np -def _input_dtype(x, *_, **__): - return dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True) +def input_dtype(x, *_, **__): + return x.dtype def _argnum_weak_type(*argnums): return lambda *args, **_: all(args[i].weak_type for i in argnums) def standard_primitive(shape_rule, dtype_rule, name, - weak_type_rule=None, sharding_rule=None): + weak_type_rule=None, sharding_rule=None, vma_rule=None, + unreduced_rule=None, reduced_rule=None, + memory_space_rule=None): weak_type_rule = weak_type_rule or _standard_weak_type_rule prim = core.Primitive(name) prim.def_impl(partial(dispatch.apply_primitive, prim)) prim.def_abstract_eval( partial(standard_abstract_eval, prim, shape_rule, dtype_rule, - weak_type_rule, sharding_rule)) + weak_type_rule, sharding_rule, vma_rule, unreduced_rule, + reduced_rule, memory_space_rule)) return prim def _get_array_abstraction_level(a): return a.array_abstraction_level @@ -56,7 +62,7 @@ def _get_abstract_mesh_from_avals(in_avals) -> mesh_lib.AbstractMesh: if a.sharding.mesh.empty: continue if m is not None and m != a.sharding.mesh: - if m._are_all_axes_auto and a.sharding.mesh._are_all_axes_auto: + if m.are_all_axes_auto and a.sharding.mesh.are_all_axes_auto: return mesh_lib.empty_abstract_mesh raise ValueError( f'Mesh for all inputs should be equal. Got one mesh: {m} and' @@ -64,92 +70,161 @@ def _get_abstract_mesh_from_avals(in_avals) -> mesh_lib.AbstractMesh: m = a.sharding.mesh return mesh_lib.empty_abstract_mesh if m is None else m +def call_reduced_rule(prim, reduced_rule, out_s, num_out, *avals, **kwargs): + if reduced_rule is not None: + return reduced_rule(out_s, *avals, **kwargs) + if any(a.sharding.spec.reduced for a in avals): + raise NotImplementedError( + f'reduced rule for {prim.name} is not implemented. Please file an' + ' issue at https://github.com/jax-ml/jax/issues') + if any(s.spec.reduced for s in ([out_s] if num_out is None else out_s) + if s is not None): + raise NotImplementedError( + f'reduced rule for {prim.name} is not implemented. Please file an' + ' issue at https://github.com/jax-ml/jax/issues') + return out_s + +def call_unreduced_rule(prim, unreduced_rule, out_s, num_out, *avals, **kwargs): + if unreduced_rule is not None: + return unreduced_rule(out_s, *avals, **kwargs) -def call_sharding_rule(prim, rule, num_out, *avals, **kwargs): + if any(a.sharding.spec.unreduced for a in avals): + raise NotImplementedError( + f'unreduced rule for {prim.name} is not implemented. Please file an' + ' issue at https://github.com/jax-ml/jax/issues') + if any(s.spec.unreduced for s in ([out_s] if num_out is None else out_s) + if s is not None): + raise NotImplementedError( + f'unreduced rule for {prim.name} is not implemented. Please file an' + ' issue at https://github.com/jax-ml/jax/issues') + return out_s + +def call_sharding_rule(prim, sh_rule, unreduced_rule, reduced_rule, num_out, + *avals, **kwargs): cur_mesh = mesh_lib.get_abstract_mesh() aval_mesh = _get_abstract_mesh_from_avals(avals) if ((cur_mesh.empty or cur_mesh._are_all_axes_auto_or_manual) and (aval_mesh.empty or aval_mesh._are_all_axes_auto_or_manual)): aval_mesh = cur_mesh if aval_mesh.empty else aval_mesh - s = NamedSharding(aval_mesh, P()) - return s if num_out is None else [s] * num_out - if rule is None: - raise ValueError( - f'sharding rule for {prim.name} is not implemented. Please file a' - ' bug at https://github.com/jax-ml/jax/issues. You can work around' + out_s = NamedSharding(aval_mesh, P()) + out_s = out_s if num_out is None else [out_s] * num_out + out_s = call_reduced_rule( + prim, reduced_rule, out_s, num_out, *avals, **kwargs) + out_s = call_unreduced_rule( + prim, unreduced_rule, out_s, num_out, *avals, **kwargs) + return out_s + if sh_rule is None: + raise core.ShardingTypeError( + f'sharding rule for {prim.name} is not implemented. Please file an' + ' issue at https://github.com/jax-ml/jax/issues. You can work around' ' this error by dropping that operation into full auto sharding' - ' mode via: `jax.experimental.shard.auto_axes(fun, out_shardings=...)`') - return rule(*avals, **kwargs) + ' mode via: `jax.sharding.auto_axes(fun, out_shardings=...)`') + out_sharding = sh_rule(*avals, **kwargs) + out_sharding = call_reduced_rule( + prim, reduced_rule, out_sharding, num_out, *avals, **kwargs) + out_sharding = call_unreduced_rule( + prim, unreduced_rule, out_sharding, num_out, *avals, **kwargs) + return out_sharding -def call_shape_dtype_sharding_rule(prim, shape_rule, dtype_rule, sharding_rule, - multi_out, *avals, **kwargs): +def call_shape_dtype_sharding_rule( + prim, shape_rule, dtype_rule, sharding_rule, unreduced_rule, reduced_rule, + multi_out, *avals, **kwargs): out_shapes = shape_rule(*avals, **kwargs) out_dtypes = dtype_rule(*avals, **kwargs) num_out = len(out_shapes) if multi_out else None try: out_shardings = call_sharding_rule( - prim, sharding_rule, num_out, *avals, **kwargs) + prim, sharding_rule, unreduced_rule, reduced_rule, num_out, + *avals, **kwargs) except DuplicateSpecError as e: if multi_out: raise avals_str = ', '.join(i.str_short(short_dtypes=True) for i in avals) mesh = mesh_lib.empty_abstract_mesh if e.mesh is None else e.mesh - out_aval_str = core.str_short_aval(out_shapes, out_dtypes, mesh, e.pspec, - short_dtypes=True) - raise TypeError( + out_aval_str = core.str_short_aval( + out_shapes, out_dtypes, mesh, e.pspec, frozenset(), + core.MemorySpace.Device, short_dtypes=True) + raise core.ShardingTypeError( f'{prim} operation with inputs: {avals_str} produces an illegally' f' sharded result: {out_aval_str}') from e return out_shapes, out_dtypes, out_shardings -def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule, - sharding_rule, *avals, **kwargs): - assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals +def _default_memory_space_rule(prim, *avals, **kwargs): + if all(a.memory_space == core.MemorySpace.Any for a in avals): + return core.MemorySpace.Any + prev_aval = None + for a in avals: + if not a.ndim: + continue + if prev_aval is not None and prev_aval.memory_space != a.memory_space: + raise ValueError( + f'memory_space of all inputs passed to `{prim.name}` must be the' + f' same. Got one operand with type: {prev_aval.str_short()} and' + f' another operand with type: {a.str_short()}') + prev_aval = a + if prev_aval is None: + return core.MemorySpace.Device + return prev_aval.memory_space + +def multi_mem_space_rule(prim, num_out, *avals, **kwargs): + out_mem_space = _default_memory_space_rule(prim, *avals, **kwargs) + return [out_mem_space] * num_out + + +def standard_abstract_eval( + prim, shape_rule, dtype_rule, weak_type_rule, sharding_rule, vma_rule, + unreduced_rule, reduced_rule, memory_space_rule, *avals, **kwargs): + for a in avals: + if isinstance(a, state.AbstractRef): + raise ValueError(f'Attempting to pass a Ref {a} to a primitive: ' + f'{prim} -- did you forget to unpack ([...]) the ref?') + if not isinstance(a, core.ShapedArray): + raise ValueError(f'Attempting to pass an unexpected type {a} to a ' + f'primitive: {prim}') + assert all(isinstance(aval, core.ShapedArray) for aval in avals), avals assert not prim.multiple_results weak_type = weak_type_rule(*avals, **kwargs) least_specialized = type(max(avals, key=_get_array_abstraction_level)) if least_specialized is core.ShapedArray: core.check_avals_context_mesh(avals, prim.name) out_shape, out_dtype, out_sharding = call_shape_dtype_sharding_rule( - prim, shape_rule, dtype_rule, sharding_rule, False, - *avals, **kwargs) + prim, shape_rule, dtype_rule, sharding_rule, unreduced_rule, + reduced_rule, False, *avals, **kwargs) + out_vma = vma_rule(*avals, **kwargs) + out_mem_space = (_default_memory_space_rule(prim, *avals, **kwargs) + if memory_space_rule is None else + memory_space_rule(*avals, **kwargs)) out_aval = core.ShapedArray( - out_shape, out_dtype, weak_type=weak_type, sharding=out_sharding) + out_shape, out_dtype, weak_type=weak_type, sharding=out_sharding, + vma=out_vma, memory_space=out_mem_space) core.check_avals_context_mesh([out_aval], prim.name) return out_aval - elif least_specialized is core.DShapedArray: - shape = shape_rule(*avals, **kwargs) - ty = (core.ShapedArray if all(type(d) is int for d in shape) - else core.DShapedArray) - return ty(shape, dtype_rule(*avals, **kwargs), weak_type) - elif least_specialized is core.UnshapedArray: - return core.UnshapedArray(dtype_rule(*avals, **kwargs), weak_type=weak_type) else: raise TypeError(avals, least_specialized) def standard_multi_result_abstract_eval( - prim, shape_rule, dtype_rule, weak_type_rule, sharding_rule, + prim, shape_rule, dtype_rule, weak_type_rule, sharding_rule, vma_rule, *avals, **kwargs): assert prim.multiple_results - assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals + assert all(isinstance(aval, core.ShapedArray) for aval in avals), avals least_specialized = max(map(type, avals), key=_get_array_abstraction_level) weak_types = weak_type_rule(*avals, **kwargs) if least_specialized is core.ShapedArray: core.check_avals_context_mesh(avals, prim.name) out_shapes, out_dtypes, out_shardings = call_shape_dtype_sharding_rule( - prim, shape_rule, dtype_rule, sharding_rule, True, *avals, **kwargs) + prim, shape_rule, dtype_rule, sharding_rule, None, None, True, + *avals, **kwargs) + out_vmas = vma_rule(*avals, **kwargs) + out_mem_spaces = multi_mem_space_rule(prim, len(out_shapes), *avals, **kwargs) if isinstance(weak_types, bool): weak_types = (weak_types,) * len(out_shapes) - out_avals = [core.ShapedArray(s, d, weak_type=weak_type, sharding=sh) - for s, d, weak_type, sh in zip(out_shapes, out_dtypes, - weak_types, out_shardings)] + out_avals = [core.ShapedArray(s, d, weak_type=weak_type, sharding=sh, + vma=vma, memory_space=ms) + for s, d, weak_type, sh, vma, ms in zip( + out_shapes, out_dtypes, weak_types, out_shardings, + out_vmas, out_mem_spaces)] core.check_avals_context_mesh(out_avals, prim.name) return out_avals - elif least_specialized is core.UnshapedArray: - out_dtypes = dtype_rule(*avals, **kwargs) - if isinstance(weak_types, bool): - weak_types = (weak_types,) * len(out_dtypes) - return [core.UnshapedArray(dtype, weak_type=weak_type) - for dtype, weak_type in zip(out_dtypes, weak_types)] else: raise TypeError(avals, least_specialized) @@ -167,3 +242,36 @@ def dtype_to_string(dtype): except AttributeError: pass return str(dtype) + +_int32_max = np.iinfo(np.int32).max +_uint32_max = np.iinfo(np.uint32).max + +def int_dtype_for_dim(d: DimSize, *, signed: bool) -> DType: + """Returns a integer dtype large enough to contain indices in dimension d.""" + if signed: + if not core.is_constant_dim(d): + return dtypes.default_int_dtype() + return np.dtype(np.int64) if d > _int32_max else np.dtype(np.int32) + else: + if not core.is_constant_dim(d): + return dtypes.default_uint_dtype() + return np.dtype(np.uint64) if d > _uint32_max else np.dtype(np.uint32) + +def int_dtype_for_shape(shape: Shape, *, signed: bool) -> DType: + """Returns a integer dtype large enough to contain indices in `shape`.""" + if signed: + for d in shape: + if core.is_constant_dim(d): + if d > _int32_max: + return np.dtype(np.int64) + else: + return dtypes.default_int_dtype() + return np.dtype(np.int32) + else: + for d in shape: + if core.is_constant_dim(d): + if d > _uint32_max: + return np.dtype(np.uint64) + else: + return dtypes.default_uint_dtype() + return np.dtype(np.uint32) diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py index 400646f6238f..00047f5d33c3 100644 --- a/jax/_src/lax/windowed_reductions.py +++ b/jax/_src/lax/windowed_reductions.py @@ -16,15 +16,17 @@ from collections.abc import Callable, Sequence from functools import partial +from typing import Any import warnings -from jax import tree_util +from jax._src import ad_util from jax._src import api_util from jax._src import core from jax._src import dispatch from jax._src import dtypes +from jax._src import tree_util from jax._src import util -from jax._src.core import ShapedArray +from jax._src.core import ClosedJaxpr, ShapedArray, jaxpr_as_fun from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -35,11 +37,8 @@ from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.typing import Array + import numpy as np -from jax._src.core import ClosedJaxpr -from jax._src.core import jaxpr_as_fun -from jax._src.interpreters.ad import jvp_jaxpr -from jax._src import ad_util map = util.safe_map zip = util.safe_zip @@ -50,7 +49,7 @@ def _reduce_window( init_value, computation, window_dimensions: core.Shape, - window_strides: Sequence[int], + window_strides: Sequence[int] | None, padding: str | Sequence[tuple[int, int]], base_dilation: Sequence[int] | None = None, window_dilation: Sequence[int] | None = None, @@ -77,9 +76,11 @@ def _reduce_window( window_dimensions if window_dilation is None else lax._dilate_shape(window_dimensions, window_dilation)) padding = tuple(lax.padtype_to_pads( - flat_operands[0].shape, dilated_window_dims, window_strides, padding)) + flat_operands[0].shape, dilated_window_dims, window_strides or [], padding)) else: - padding = tuple(padding) + padding = tuple((x, y) for x, y in padding) + if window_strides is None: + window_strides = (1,) * len(window_dimensions) if base_dilation is None: base_dilation = (1,) * len(window_dimensions) if window_dilation is None: @@ -97,6 +98,7 @@ def _reduce_window( raise ValueError( 'reduce_window output must have the same tree structure as the operands' f' {operand_tree} vs. {out_tree}') + flat_operands = core.standard_insert_pvary(*flat_operands) out_flat = reduce_window_p.bind( *flat_operands, *flat_init_values, @@ -112,18 +114,55 @@ def _reduce_window( def reduce_window( - operand, - init_value, + operand: Any, + init_value: Any, computation: Callable, window_dimensions: core.Shape, - window_strides: Sequence[int], - padding: str | Sequence[tuple[int, int]], + window_strides: Sequence[int] | None = None, + padding: str | Sequence[tuple[int, int]] = "VALID", base_dilation: Sequence[int] | None = None, window_dilation: Sequence[int] | None = None, -) -> Array: - """Wraps XLA's `ReduceWindowWithGeneralPadding - `_ - operator. +) -> Any: + """Reduction over padded windows. + + Wraps XLA's ReduceWindowWithGeneralPadding_ operator. + + Args: + operand: input array or tree of arrays. + init_value: value or tree of values. Tree structure must match that + of ``operand``. + computation: callable function over which to reduce. Input and output must be + a tree of the same structure as ``operand``. + window_dimensions: sequence of integers specifying the window size. + window_strides: optional sequence of integers specifying the strides, of + the same length as ``window_dimensions``. Default (``None``) indicates + a unit stride in each window dimension. + padding: string or sequence of integer tuples specifying the type of padding + to use (default: "VALID"). If a string, must be one of "VALID", "SAME", or + "SAME_LOWER". See the :func:`jax.lax.padtype_to_pads` utility. + base_dilation: optional sequence of integers for base dilation values, of + the same length as ``window_dimensions``. Default (``None``) indicates unit + dilation in each window dimension. + window_dilation: optional sequence of integers for window dilation values, of + the same length as ``window_dimensions``. Default (``None``) indicates unit + dilation in each window dimension. + + Returns: + A tree of arrays with the same structure as ``operand``. + + Example: + Here is a simple example of a windowed product over pairs in a 1-dimensional array: + + >>> import jax + >>> x = jax.numpy.arange(10, dtype='float32') + >>> x + Array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.], dtype=float32) + + >>> initial = jax.numpy.float32(1) + >>> jax.lax.reduce_window(x, initial, jax.lax.mul, window_dimensions=(2,)) + Array([ 0., 2., 6., 12., 20., 30., 42., 56., 72.], dtype=float32) + + .. _ReduceWindowWithGeneralPadding: https://www.openxla.org/xla/operation_semantics#reducewindow """ return _reduce_window( operand, @@ -250,6 +289,8 @@ def _select_and_scatter(operand: Array, select: Callable, select, core.get_aval(init_value)) scatter_jaxpr, scatter_consts = lax._reduction_jaxpr( scatter, core.get_aval(init_value)) + operand, source, init_value = core.standard_insert_pvary( + operand, source, init_value) return select_and_scatter_p.bind( operand, source, init_value, select_jaxpr=select_jaxpr, select_consts=select_consts, scatter_jaxpr=scatter_jaxpr, @@ -261,6 +302,7 @@ def _select_and_scatter_add(source: Array, operand: Array, window_dimensions: core.Shape, window_strides: Sequence[int], padding: Sequence[tuple[int, int]]) -> Array: + source, operand = core.standard_insert_pvary(source, operand) return select_and_scatter_add_p.bind( source, operand, select_prim=select_prim, window_dimensions=tuple(window_dimensions), @@ -277,7 +319,7 @@ def _select_and_gather_add(tangents: Array, operand: Array, each window of the `operand` array. Wraps XLA's `ReduceWindow - `_ + `_ operator, which applies a reduction function to all elements in each window of the input multi-dimensional array. In this case, the input multi-dimensional array is built by packing each element in the `operand` array with its @@ -296,6 +338,7 @@ def _select_and_gather_add(tangents: Array, operand: Array, An array containing the elements in `tangents` corresponding to the output of the reduction of `operand` fin each window. """ + tangents, operand = core.standard_insert_pvary(tangents, operand) return select_and_gather_add_p.bind( tangents, operand, select_prim=select_prim, window_dimensions=tuple(window_dimensions), @@ -332,7 +375,8 @@ def _reduce_window_abstract_eval_rule( out_sharding = reduce_window_sharding_rule( operand_avals[0], window_dimensions, window_strides, padding, base_dilation, window_dilation) - return tuple(ShapedArray(out_shape, op.dtype, sharding=out_sharding) + vma = core.standard_vma_rule('reduce_window', *operand_avals) + return tuple(ShapedArray(out_shape, op.dtype, sharding=out_sharding, vma=vma) for op in operand_avals) @@ -398,7 +442,7 @@ def reduce_window_jvp( init_value_tangent = map(ad_util.instantiate, init_value_tangent) c_reduction_jaxpr = ClosedJaxpr(reduction_jaxpr, consts) - jvp_reduction = jvp_jaxpr(c_reduction_jaxpr, (True,) * len(tangents), [False] * len(init_value_tangent))[0] + jvp_reduction = ad.jvp_jaxpr(c_reduction_jaxpr, (True,) * len(tangents), [False] * len(init_value_tangent))[0] def wrapper(left, right): pl, tl = util.split_list(left, [n]) @@ -426,7 +470,7 @@ def wrapper(left, right): def _generic_reduce_window_lower( - ctx, + ctx: mlir.LoweringRuleContext, *args, jaxpr, consts, @@ -444,7 +488,7 @@ def reducer_body(reducer: ir.Block) -> Sequence[ir.Value]: raise NotImplementedError('Cannot lower effectful `reduce_window`.') out_nodes, _ = mlir.jaxpr_subcomp(ctx.module_context, jaxpr, ctx.name_stack, mlir.TokenSet(), consts, *reducer.arguments, # type: ignore[misc] - dim_var_values=ctx.dim_var_values) + dim_var_values=ctx.dim_var_values, const_lowering=ctx.const_lowering) return mlir.flatten_ir_values(out_nodes) return mlir.reduce_window( @@ -514,25 +558,16 @@ def _reduce_window_batch_rule(reduce_window, batched_args, bdims, *, def reduce_window_sharding_rule(operand, window_dimensions, window_strides, padding, base_dilation, window_dilation): - if base_dilation is None: - base_dilation = [1] * operand.ndim - if window_dilation is None: - window_dilation = [1] * operand.ndim - - for spec, wdim, ws, pd, bd, wdil in zip( - operand.sharding.spec, window_dimensions, window_strides, padding, - base_dilation, window_dilation): - if spec is None: - continue - if not (wdim == 1 and ws == 1 and pd == 1 and bd == 1 and wdil == 1): - raise NotImplementedError( - "Only trivial windowing is supported along non-replicated" - f" dimensions. Got {operand.sharding.spec=}") - return operand.sharding + out_shape = reduce_window_shape_tuple( + operand.shape, window_dimensions, window_strides, padding, base_dilation, + window_dilation) + return lax.slicing._get_sharding_for_varying_out_shape( + out_shape, operand, 'reduce_window') reduce_window_sum_p = lax.standard_primitive( - _reduce_window_sum_shape_rule, lax._input_dtype, 'reduce_window_sum', - sharding_rule=reduce_window_sharding_rule) + _reduce_window_sum_shape_rule, lax.input_dtype, 'reduce_window_sum', + sharding_rule=reduce_window_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'reduce_window_sum')) ad.deflinear2(reduce_window_sum_p, _reduce_window_sum_transpose_rule) batching.primitive_batchers[reduce_window_sum_p] = partial( _reduce_window_batch_rule, _reduce_window_sum) @@ -597,16 +632,18 @@ def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides, reduce_window_max_p = lax.standard_primitive( - _common_reduce_window_shape_rule, lax._input_dtype, 'reduce_window_max', - sharding_rule=reduce_window_sharding_rule) + _common_reduce_window_shape_rule, lax.input_dtype, 'reduce_window_max', + sharding_rule=reduce_window_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'reduce_window_max')) ad.defjvp(reduce_window_max_p, partial(_reduce_window_chooser_jvp_rule, lax.max_p)) batching.primitive_batchers[reduce_window_max_p] = partial( _reduce_window_batch_rule, _reduce_window_max) reduce_window_min_p = lax.standard_primitive( - _common_reduce_window_shape_rule, lax._input_dtype, 'reduce_window_min', - sharding_rule=reduce_window_sharding_rule) + _common_reduce_window_shape_rule, lax.input_dtype, 'reduce_window_min', + sharding_rule=reduce_window_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'reduce_window_min')) ad.defjvp(reduce_window_min_p, partial(_reduce_window_chooser_jvp_rule, lax.min_p)) @@ -630,7 +667,8 @@ def _reduce_window_lower( ): operand_aval, = ctx.avals_in - scalar_aval = operand_aval.update(shape=()) + scalar_aval = operand_aval.update( + shape=(), sharding=operand_aval.sharding.update(spec=())) return mlir.reduce_window( ctx, @@ -670,16 +708,25 @@ def _select_and_scatter_shape_rule( raise TypeError(msg.format(window_strides, window_dimensions)) return operand.shape +def _select_and_scatter_sharding_rule( + operand, source, init_value, *, select_jaxpr, select_consts, scatter_jaxpr, + scatter_consts, window_dimensions, window_strides, padding): + return operand.sharding + select_and_scatter_p = lax.standard_primitive( - _select_and_scatter_shape_rule, lax._input_dtype, 'select_and_scatter') + _select_and_scatter_shape_rule, lax.input_dtype, 'select_and_scatter', + sharding_rule=_select_and_scatter_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'select_and_scatter')) def _select_and_scatter_lower( - ctx, operand, source, init_value, *, select_jaxpr, - select_consts, scatter_jaxpr, scatter_consts, window_dimensions, + ctx: mlir.LoweringRuleContext, operand, source, init_value, *, + select_jaxpr: core.Jaxpr, select_consts, + scatter_jaxpr: core.Jaxpr, scatter_consts, window_dimensions, window_strides, padding): operand_aval, source_aval, init_value_aval = ctx.avals_in aval_out, = ctx.avals_out - scalar_aval = operand_aval.update(shape=()) + scalar_aval = operand_aval.update( + shape=(), sharding=operand_aval.sharding.update(spec=())) scalar_type = mlir.aval_to_ir_type(scalar_aval) op = hlo.SelectAndScatterOp( mlir.aval_to_ir_type(aval_out), @@ -698,7 +745,8 @@ def _select_and_scatter_lower( ctx.name_stack, mlir.TokenSet(), select_consts, *select.arguments, - dim_var_values=ctx.dim_var_values) + dim_var_values=ctx.dim_var_values, + const_lowering=ctx.const_lowering) hlo.return_(mlir.flatten_ir_values(out_nodes)) scatter = op.scatter.blocks.append(scalar_type, scalar_type) with ir.InsertionPoint(scatter): @@ -708,9 +756,11 @@ def _select_and_scatter_lower( ctx.name_stack, mlir.TokenSet(), scatter_consts, *scatter.arguments, - dim_var_values=ctx.dim_var_values) + dim_var_values=ctx.dim_var_values, + const_lowering=ctx.const_lowering) hlo.return_(mlir.flatten_ir_values(out_nodes)) - return op.results + return [mlir.lower_with_sharding_in_types(ctx, r, aval) + for r, aval in zip(op.results, ctx.avals_out)] mlir.register_lowering(select_and_scatter_p, _select_and_scatter_lower) @@ -719,6 +769,11 @@ def _select_and_scatter_add_shape_rule( padding): return operand.shape +def _select_and_scatter_add_sharding_rule( + source, operand, *, select_prim, window_dimensions, window_strides, + padding): + return operand.sharding + def _select_and_scatter_add_jvp( primals, tangents, *, select_prim, window_dimensions, window_strides, padding): @@ -765,8 +820,10 @@ def _select_and_scatter_add_batch_rule( return out, 0 select_and_scatter_add_p = lax.standard_primitive( - _select_and_scatter_add_shape_rule, lax._input_dtype, - 'select_and_scatter_add') + _select_and_scatter_add_shape_rule, lax.input_dtype, + 'select_and_scatter_add', + sharding_rule=_select_and_scatter_add_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'select_and_scatter_add')) ad.primitive_transposes[select_and_scatter_add_p] = \ _select_and_scatter_add_transpose @@ -826,7 +883,7 @@ def _select_and_gather_add_sharding_rule( tangents, operand, *, select_prim, window_dimensions, window_strides, padding, base_dilation, window_dilation): if tangents.sharding != operand.sharding: - raise TypeError( + raise core.ShardingTypeError( "select_and_gather_add tangents and operand shardings must match, " f"got {tangents.sharding} and {operand.sharding}.") return reduce_window_sharding_rule( @@ -1038,8 +1095,9 @@ def _select_and_gather_add_batching_rule( select_and_gather_add_p = lax.standard_primitive( - _select_and_gather_add_shape_rule, lax._input_dtype, - 'select_and_gather_add', sharding_rule=_select_and_gather_add_sharding_rule) + _select_and_gather_add_shape_rule, lax.input_dtype, + 'select_and_gather_add', sharding_rule=_select_and_gather_add_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'select_and_gather_add')) ad.primitive_jvps[select_and_gather_add_p] = _select_and_gather_add_jvp ad.primitive_transposes[select_and_gather_add_p] = \ _select_and_gather_add_transpose diff --git a/jax/_src/lax_reference.py b/jax/_src/lax_reference.py index 4d4c24b0500e..ab035fecc31a 100644 --- a/jax/_src/lax_reference.py +++ b/jax/_src/lax_reference.py @@ -189,8 +189,8 @@ def _bitcast_uint8_to_uint4(operand): def bitcast_convert_type(operand, dtype): operand = np.asarray(operand) - nbits_in = dtypes.bit_width(operand.dtype) - nbits_out = dtypes.bit_width(dtype) + nbits_in = dtypes.itemsize_bits(operand.dtype) + nbits_out = dtypes.itemsize_bits(dtype) if nbits_out > nbits_in: assert operand.shape[-1] == nbits_out // nbits_in @@ -285,7 +285,8 @@ def ragged_dot( out = np.zeros((m, n), dtype=lhs.dtype) result_iota = np.expand_dims(np.arange(out.shape[0]), list(range(1, out.ndim))) - start = 0 + result_iota = result_iota.astype(group_sizes.dtype) + start = np.asarray(0, dtype=group_sizes.dtype) for i, size in enumerate(group_sizes): out += np.where( np.logical_and(start <= result_iota, result_iota < (start + size)), @@ -319,7 +320,7 @@ def reshape(operand, new_sizes, dimensions=None): return np.reshape(np.transpose(operand, dimensions), new_sizes) def pad(operand, padding_value, padding_config): - # https://www.tensorflow.org/xla/operation_semantics#pad + # https://www.openxla.org/xla/operation_semantics#pad lo, hi, interior = util.unzip3(padding_config) # Handle first the positive edge padding and interior lo_pos, hi_pos = np.clip(lo, 0, None), np.clip(hi, 0, None) @@ -528,3 +529,14 @@ def reducer(operand, axis=0): result[out_idx] = py_binop(result[out_idx], operand[idx]) return result return reducer + +def top_k(operand, k, axis=-1): + if axis < 0: + axis = operand.ndim + axis + assert 0 <= axis < operand.ndim + operand_flipped = np.flip(operand, axis) + indices_flipped = np.argsort(operand_flipped, axis=axis, kind="stable") + indices_all = (operand.shape[axis] - 1 - np.flip(indices_flipped, axis)).astype(np.int32) + indices = indices_all[(_slice(None),) * axis + (_slice(k),)] + values = np.take_along_axis(operand, indices, axis=axis) + return values, indices diff --git a/jax/_src/layout.py b/jax/_src/layout.py index 5309f0b1fd9c..1634b02838b7 100644 --- a/jax/_src/layout.py +++ b/jax/_src/layout.py @@ -19,7 +19,8 @@ import numpy as np from jax._src.dtypes import iinfo, issubdtype from jax._src.sharding import Sharding -from jax._src.sharding_impls import AUTO as AutoSharding +from jax._src.named_sharding import AUTO as AutoSharding +from jax._src.util import tuple_insert from jax._src.lib import xla_client as xc Shape = tuple[int, ...] @@ -30,47 +31,57 @@ def __repr__(self): return "AUTO" -class DeviceLocalLayout: +class Layout: major_to_minor: tuple[int, ...] - _tiling: tuple[tuple[int, ...], ...] | None - _sub_byte_element_size_in_bits: int + tiling: tuple[tuple[int, ...], ...] | None + sub_byte_element_size_in_bits: int AUTO = AutoLayout() def __init__(self, major_to_minor: tuple[int, ...], - _tiling: tuple[tuple[int, ...], ...] | None = None, - _sub_byte_element_size_in_bits: int = 0): + tiling: tuple[tuple[int, ...], ...] | None = None, + sub_byte_element_size_in_bits: int = 0): self.major_to_minor = tuple(major_to_minor) - self._tiling = None if _tiling is None else tuple(map(tuple, _tiling)) - self._sub_byte_element_size_in_bits = _sub_byte_element_size_in_bits + self.tiling = None if tiling is None else tuple(map(tuple, tiling)) + self._sub_byte_element_size_in_bits = sub_byte_element_size_in_bits @staticmethod def from_pjrt_layout(pjrt_layout: xc.PjRtLayout): xla_layout = pjrt_layout._xla_layout() - return DeviceLocalLayout(xla_layout.minor_to_major()[::-1], # pytype: disable=wrong-arg-types - xla_layout.tiling(), # type: ignore[arg-type] - xla_layout.element_size_in_bits()) + return Layout(xla_layout.minor_to_major()[::-1], # pytype: disable=wrong-arg-types + xla_layout.tiling(), # type: ignore[arg-type] + xla_layout.element_size_in_bits()) def __repr__(self): return ( - f'DeviceLocalLayout(major_to_minor={self.major_to_minor},' - f' _tiling={self._tiling},' - f' _sub_byte_element_size_in_bits={self._sub_byte_element_size_in_bits})' + f'Layout(major_to_minor={self.major_to_minor},' + f' tiling={self.tiling},' + f' sub_byte_element_size_in_bits={self._sub_byte_element_size_in_bits})' ) def __hash__(self): - return hash((self.major_to_minor, self._tiling, + return hash((self.major_to_minor, self.tiling, self._sub_byte_element_size_in_bits)) def __eq__(self, other): - if not isinstance(other, DeviceLocalLayout): + if not isinstance(other, Layout): return False return (self.major_to_minor == other.major_to_minor and - self._tiling == other._tiling and + self.tiling == other.tiling and self._sub_byte_element_size_in_bits == other._sub_byte_element_size_in_bits) + def update(self, **kwargs): + if 'major_to_minor' not in kwargs: + kwargs['major_to_minor'] = self.major_to_minor + if 'tiling' not in kwargs: + kwargs['tiling'] = self.tiling + if 'sub_byte_element_size_in_bits' not in kwargs: + kwargs['sub_byte_element_size_in_bits'] = self._sub_byte_element_size_in_bits + return Layout(kwargs['major_to_minor'], kwargs['tiling'], + kwargs['sub_byte_element_size_in_bits']) + def _to_xla_layout(self, dtype) -> xc.Layout: - if self._tiling is None: + if self.tiling is None: xla_layout = xc.Layout(self.major_to_minor[::-1]) else: if self._sub_byte_element_size_in_bits != 0: @@ -79,7 +90,7 @@ def _to_xla_layout(self, dtype) -> xc.Layout: sub_byte_size = iinfo(dtype).bits if iinfo(dtype).bits < 8 else 0 else: sub_byte_size = 0 - xla_layout = xc.Layout(self.major_to_minor[::-1], self._tiling, + xla_layout = xc.Layout(self.major_to_minor[::-1], self.tiling, sub_byte_size) return xla_layout @@ -90,32 +101,32 @@ def check_compatible_aval(self, aval_shape: Shape): f' Got major_to_minor={self.major_to_minor} and shape={aval_shape}') -LayoutOptions = Union[DeviceLocalLayout, None, AutoLayout] # pytype: disable=invalid-annotation +LayoutOptions = Union[Layout, None, AutoLayout] # pytype: disable=invalid-annotation ShardingOptions = Union[Sharding, None, AutoSharding] -class Layout: - __slots__ = ['device_local_layout', 'sharding'] +class Format: + __slots__ = ['layout', 'sharding'] - def __init__(self, device_local_layout: LayoutOptions = None, + def __init__(self, layout: LayoutOptions = None, sharding: ShardingOptions = None): # If layout is concrete and sharding is not, error. - if (isinstance(device_local_layout, DeviceLocalLayout) and + if (isinstance(layout, Layout) and (sharding is None or isinstance(sharding, AutoSharding))): raise ValueError( 'Sharding has to be concrete when layout is of type' - f' {type(device_local_layout)}. Please pass a' - ' `jax.sharding.NamedSharding`, `jax.sharding.PositionalSharding` or' + f' {type(layout)}. Please pass a' + ' `jax.sharding.NamedSharding` or' ' `jax.sharding.SingleDeviceSharding` to the sharding argument. Got' f' sharding {sharding}' ) if not isinstance( - device_local_layout, (DeviceLocalLayout, type(None), AutoLayout)): + layout, (Layout, type(None), AutoLayout)): raise TypeError( - 'Invalid value received for the device_local_layout argument.' - ' Expected values are `None`, `DeviceLocalLayout.AUTO` or an' - f' instance of `DeviceLocalLayout`. Got {device_local_layout} of' - f' type {type(device_local_layout)}' + 'Invalid value received for the layout argument.' + ' Expected values are `None`, `Layout.AUTO` or an' + f' instance of `Layout`. Got {layout} of' + f' type {type(layout)}' ) if not isinstance( sharding, (Sharding, type(None), AutoSharding)): @@ -124,18 +135,24 @@ def __init__(self, device_local_layout: LayoutOptions = None, ' are `None`, `pjit.AUTO` or an instance of `jax.Sharding`. Got' f' {sharding} of type {type(sharding)}') - self.device_local_layout = device_local_layout + self.layout = layout self.sharding = sharding def __repr__(self): - return (f'Layout(device_local_layout={self.device_local_layout},' - f' sharding={self.sharding})') + return f'Format(layout={self.layout}, sharding={self.sharding})' def __hash__(self): - return hash((self.device_local_layout, self.sharding)) + return hash((self.layout, self.sharding)) def __eq__(self, other): - if not isinstance(other, Layout): + if not isinstance(other, Format): return False - return (self.device_local_layout == other.device_local_layout and + return (self.layout == other.layout and self.sharding == other.sharding) + + +def get_layout_for_vmap(dim: int, layout: Layout) -> Layout: + # Make the new dim major-most and shift all other dims by 1 in major_to_minor + new_m2m = tuple(m + 1 for m in layout.major_to_minor) + vmapped_major_to_minor = tuple_insert(new_m2m, dim, 0) + return layout.update(major_to_minor=vmapped_major_to_minor) diff --git a/jax/_src/lazy_loader.py b/jax/_src/lazy_loader.py index 14822bff3eff..9476f4681062 100644 --- a/jax/_src/lazy_loader.py +++ b/jax/_src/lazy_loader.py @@ -48,7 +48,7 @@ def __getattr__(name: str) -> Any: # for this ``name``. setattr(sys.modules[owner_name], name, value) return value - raise AttributeError(f"module '{package_name}' has no attribute '{name}") + raise AttributeError(f"module '{package_name}' has no attribute '{name}'") def __dir__() -> list[str]: return __all__ diff --git a/jax/_src/lib/BUILD b/jax/_src/lib/BUILD index 1fcbd4b6b7ef..798457a2586d 100644 --- a/jax/_src/lib/BUILD +++ b/jax/_src/lib/BUILD @@ -20,10 +20,7 @@ load( "pytype_strict_library", ) -package( - default_applicable_licenses = [], - default_visibility = ["//jax:internal"], -) +package(default_applicable_licenses = []) py_library_providing_imports_info( name = "lib", @@ -38,28 +35,8 @@ py_library_providing_imports_info( visibility = ["//jax:internal"] + jax_visibility("lib"), deps = [ "//jax:version", + "//jax/_src:lazy_loader", ] + if_building_jaxlib([ "//jaxlib", - "//jaxlib/mosaic/python:gpu_dialect", - "//jaxlib/mosaic/python:tpu_dialect", - "//jaxlib:cpu_feature_guard", - "//jaxlib:utils", - "//jaxlib/triton", - "//jaxlib/mlir/_mlir_libs:register_jax_dialects", - "//jaxlib/mlir:arithmetic_dialect", - "//jaxlib/mlir:builtin_dialect", - "//jaxlib/mlir:chlo_dialect", - "//jaxlib/mlir:func_dialect", - "//jaxlib/mlir:ir", - "//jaxlib/mlir:math_dialect", - "//jaxlib/mlir:memref_dialect", - "//jaxlib/mlir:mhlo_dialect", - "//jaxlib/mlir:pass_manager", - "//jaxlib/mlir:scf_dialect", - "//jaxlib/mlir:sdy_dialect", - "//jaxlib/mlir:sparse_tensor_dialect", - "//jaxlib/mlir:stablehlo_dialect", - "//jaxlib/mlir:vector_dialect", - # xla_client ]), ) diff --git a/jax/_src/lib/__init__.py b/jax/_src/lib/__init__.py index 7933bb769733..c611b58c3a09 100644 --- a/jax/_src/lib/__init__.py +++ b/jax/_src/lib/__init__.py @@ -17,10 +17,13 @@ from __future__ import annotations +import importlib import gc import os import pathlib import re +from types import ModuleType + try: import jaxlib as jaxlib @@ -33,14 +36,14 @@ import jax.version from jax.version import _minimum_jaxlib_version as _minimum_jaxlib_version_str try: - import jaxlib.version + import jaxlib.version # noqa: F401 except Exception as err: # jaxlib is too old to have version number. msg = f'This version of jax requires jaxlib version >= {_minimum_jaxlib_version_str}.' raise ImportError(msg) from err -# Checks the jaxlib version before importing anything else from jaxlib. +# Checks the jaxlib version before importing anything else. # Returns the jaxlib version string. def check_jaxlib_version(jax_version: str, jaxlib_version: str, minimum_jaxlib_version: str) -> tuple[int, ...]: @@ -77,45 +80,66 @@ def _parse_version(v: str) -> tuple[int, ...]: jaxlib_version=jaxlib.version.__version__, minimum_jaxlib_version=jax.version._minimum_jaxlib_version) -# Before importing any C compiled modules from jaxlib, first import the CPU +# Before importing any C compiled modules, first import the CPU # feature guard module to verify that jaxlib was compiled in a way that only # uses instructions that are present on this machine. import jaxlib.cpu_feature_guard as cpu_feature_guard cpu_feature_guard.check_cpu_features() -import jaxlib.utils as utils # noqa: F401 -import jaxlib.xla_client as xla_client +import jaxlib.xla_client as xla_client # noqa: F401 + +# Jaxlib code is split between the Jax and the XLA repositories. +# Only for the internal usage of the JAX developers, we expose a version +# number that can be used to perform changes without breaking the main +# branch on the Jax github. +jaxlib_extension_version: int = getattr(xla_client, '_version', 0) +ifrt_version: int = getattr(xla_client, '_ifrt_version', 0) + import jaxlib.lapack as lapack # noqa: F401 +import jaxlib.utils as utils # noqa: F401 +import jaxlib._jax as _jax # noqa: F401 + + + +import jaxlib.mlir._mlir_libs._jax_mlir_ext as jax_mlir_ext # noqa: F401 +from jaxlib._jax import guard_lib as guard_lib # noqa: F401 +from jaxlib._jax import jax_jit as jax_jit # noqa: F401 +from jaxlib._jax import pmap_lib as pmap_lib # noqa: F401 +from jaxlib._jax import pytree as pytree # noqa: F401 +from jaxlib._jax import Device as Device # noqa: F401 +from jaxlib import _profiler as _profiler # noqa: F401 +from jaxlib import _profile_data as _profile_data # noqa: F401 + +from jaxlib._jax import ffi as ffi # noqa: F401 +import jaxlib.cpu_sparse as cpu_sparse # noqa: F401 +has_cpu_sparse = True + +import jaxlib.weakref_lru_cache as weakref_lru_cache # noqa: F401 +import jaxlib._pretty_printer as _pretty_printer # noqa: F401 + +import jaxlib._ifrt_proxy as ifrt_proxy # noqa: F401 -xla_extension = xla_client._xla -pytree = xla_client._xla.pytree -jax_jit = xla_client._xla.jax_jit -pmap_lib = xla_client._xla.pmap_lib # XLA garbage collection: see https://github.com/jax-ml/jax/issues/14882 def _xla_gc_callback(*args): xla_client._xla.collect_garbage() gc.callbacks.append(_xla_gc_callback) -try: - import jaxlib.cuda._versions as cuda_versions # pytype: disable=import-error # noqa: F401 -except ImportError: +cuda_versions: ModuleType | None +for pkg_name in ['jax_cuda13_plugin', 'jax_cuda12_plugin', 'jaxlib.cuda']: try: - import jax_cuda12_plugin._versions as cuda_versions # pytype: disable=import-error # noqa: F401 + cuda_versions = importlib.import_module( + f'{pkg_name}._versions' + ) except ImportError: cuda_versions = None + else: + break import jaxlib.gpu_solver as gpu_solver # pytype: disable=import-error # noqa: F401 import jaxlib.gpu_sparse as gpu_sparse # pytype: disable=import-error # noqa: F401 import jaxlib.gpu_prng as gpu_prng # pytype: disable=import-error # noqa: F401 import jaxlib.gpu_linalg as gpu_linalg # pytype: disable=import-error # noqa: F401 -import jaxlib.hlo_helpers as hlo_helpers # pytype: disable=import-error # noqa: F401 - -# Jaxlib code is split between the Jax and the Tensorflow repositories. -# Only for the internal usage of the JAX developers, we expose a version -# number that can be used to perform changes without breaking the main -# branch on the Jax github. -xla_extension_version: int = getattr(xla_client, '_version', 0) import jaxlib.gpu_rnn as gpu_rnn # pytype: disable=import-error # noqa: F401 import jaxlib.gpu_triton as gpu_triton # pytype: disable=import-error # noqa: F401 @@ -123,9 +147,6 @@ def _xla_gc_callback(*args): import jaxlib.mosaic.python.mosaic_gpu as mosaic_gpu_dialect # pytype: disable=import-error # noqa: F401 import jaxlib.mosaic.python.tpu as tpu # pytype: disable=import-error # noqa: F401 -# Version number for MLIR:Python APIs, provided by jaxlib. -mlir_api_version = xla_client.mlir_api_version - # TODO(rocm): check if we need the same for rocm. def _cuda_path() -> str | None: @@ -141,17 +162,18 @@ def _try_cuda_nvcc_import() -> str | None: `nvvm/libdevice/libdevice.10.bc`. """ try: - from nvidia import cuda_nvcc # pytype: disable=import-error + nvcc_module = importlib.import_module('nvidia.cu13') except ImportError: - return None - - if hasattr(cuda_nvcc, '__file__') and cuda_nvcc.__file__ is not None: - # `cuda_nvcc` is a regular package. - cuda_nvcc_path = pathlib.Path(cuda_nvcc.__file__).parent - elif hasattr(cuda_nvcc, '__path__') and cuda_nvcc.__path__ is not None: - # `cuda_nvcc` is a namespace package, which might have multiple paths. - cuda_nvcc_path = None - for path in cuda_nvcc.__path__: + try: + nvcc_module = importlib.import_module('nvidia.cuda_nvcc') + except ImportError: + return None + + cuda_nvcc_path = None + if hasattr(nvcc_module, '__file__') and nvcc_module.__file__ is not None: + cuda_nvcc_path = pathlib.Path(nvcc_module.__file__).parent + elif hasattr(nvcc_module, '__path__') and nvcc_module.__path__ is not None: + for path in nvcc_module.__path__: if (pathlib.Path(path) / 'bin' / 'ptxas').exists(): cuda_nvcc_path = pathlib.Path(path) break @@ -160,14 +182,23 @@ def _try_cuda_nvcc_import() -> str | None: return str(cuda_nvcc_path) + def _try_bazel_runfiles() -> str | None: + """Try to get the path to the cuda installation in bazel runfiles.""" + python_runfiles = os.environ.get('PYTHON_RUNFILES') + if not python_runfiles: + return None + cuda_nvcc_root = os.path.join(python_runfiles, 'cuda_nvcc') + if os.path.exists(cuda_nvcc_root): + return cuda_nvcc_root + return None + if (path := _try_cuda_root_environment_variable()) is not None: return path elif (path := _try_cuda_nvcc_import()) is not None: return path + elif (path := _try_bazel_runfiles()) is not None: + return path return None cuda_path = _cuda_path() - -guard_lib = xla_client._xla.guard_lib -Device = xla_client._xla.Device diff --git a/jax/_src/lib/mlir/__init__.py b/jax/_src/lib/mlir/__init__.py index 5fc9dff3ac49..b87042d2dd72 100644 --- a/jax/_src/lib/mlir/__init__.py +++ b/jax/_src/lib/mlir/__init__.py @@ -16,4 +16,3 @@ import jaxlib.mlir.ir as ir import jaxlib.mlir.passmanager as passmanager -from jaxlib.mlir._mlir_libs import register_jax_dialects diff --git a/jax/_src/lib/mlir/dialects/__init__.py b/jax/_src/lib/mlir/dialects/__init__.py index a9bae8821db5..6bf689728144 100644 --- a/jax/_src/lib/mlir/dialects/__init__.py +++ b/jax/_src/lib/mlir/dialects/__init__.py @@ -17,8 +17,10 @@ from typing import Any, TYPE_CHECKING if TYPE_CHECKING: + from jaxlib.mlir.dialects import _gpu_ops_gen as _gpu_ops_gen from jaxlib.mlir.dialects import arith as arith from jaxlib.mlir.dialects import builtin as builtin + from jaxlib.mlir.dialects import cf as cf from jaxlib.mlir.dialects import chlo as chlo from jaxlib.mlir.dialects import func as func from jaxlib.mlir.dialects import gpu as gpu @@ -34,8 +36,10 @@ else: from jax._src import lazy_loader as _lazy __getattr__, __dir__, __all__ = _lazy.attach("jaxlib.mlir.dialects", [ + "_gpu_ops_gen", "arith", "builtin", + "cf", "chlo", "func", "gpu", @@ -51,11 +55,10 @@ ]) del _lazy -# TODO(bartchr): Once JAX is released with SDY, remove the try/except. -try: - from jaxlib.mlir.dialects import sdy as sdy -except ImportError: - sdy: Any = None # type: ignore[no-redef] +from jaxlib.mlir.dialects import mpmd +from jaxlib.mlir.dialects import sdy # Alias that is set up to abstract away the transition from MHLO to StableHLO. from jaxlib.mlir.dialects import stablehlo as hlo + +from jax._src import lib diff --git a/jax/_src/lib/mosaic_gpu.py b/jax/_src/lib/mosaic_gpu.py index 494112093029..37c190a409c5 100644 --- a/jax/_src/lib/mosaic_gpu.py +++ b/jax/_src/lib/mosaic_gpu.py @@ -18,6 +18,9 @@ try: from jaxlib.mosaic.gpu import _mosaic_gpu_ext # pytype: disable=import-error except ImportError: - from jax_cuda12_plugin import _mosaic_gpu_ext # pytype: disable=import-error + try: + from jax_cuda12_plugin import _mosaic_gpu_ext # pytype: disable=import-error + except ImportError: + from jax_cuda13_plugin import _mosaic_gpu_ext # pytype: disable=import-error except ImportError as e: raise ModuleNotFoundError("Failed to import the Mosaic GPU bindings") from e diff --git a/jax/_src/linear_util.py b/jax/_src/linear_util.py index 1497597ebd62..bf45604c2f5e 100644 --- a/jax/_src/linear_util.py +++ b/jax/_src/linear_util.py @@ -67,15 +67,17 @@ def trans1(static_arg, *dynamic_args, **kwargs): from collections.abc import Callable, Sequence from functools import partial import re -from typing import Any, Hashable, NamedTuple +import time +from typing import Any, NamedTuple +from collections.abc import Hashable import warnings import weakref from jax._src import config from jax._src import core from jax._src import traceback_util -from jax._src.tree_util import keystr, KeyPath, generate_key_paths -from jax._src.util import curry, cache_clearing_funs, HashableFunction +from jax._src.tree_util import KeyPath, generate_key_paths, keystr +from jax._src.util import curry, fun_name, register_cache traceback_util.register_exclusion(__file__) @@ -185,7 +187,7 @@ def __init__(self, f: Callable, @property def __name__(self): - return getattr(self.f, '__name__', '') + return fun_name(self.f, "") def wrap(self, gen, gen_static_args, out_store: Store | EqualStore | None) -> WrappedFun: @@ -225,6 +227,14 @@ def __eq__(self, other): self.params == other.params and self.in_type == other.in_type and self.debug_info == other.debug_info) + def replace_debug_info(self, dbg: core.DebugInfo) -> WrappedFun: + return WrappedFun(self.f, self.f_transformed, self.transforms, + self.stores, self.params, self.in_type, + dbg) + + def with_unknown_names(self) -> WrappedFun: + return self.replace_debug_info(self.debug_info.with_unknown_names()) + @curry def transformation2(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun: """Adds one more transformation to a WrappedFun. @@ -265,12 +275,9 @@ def transformation_with_aux2( out_thunk = lambda: out_store.val return fun.wrap(gen, gen_static_args, out_store), out_thunk -def fun_name(f): - try: - return f.__name__ - except: - return str(f) - +class InitialResultPaths: + pass +initial_result_paths = InitialResultPaths() class DebugInfo(NamedTuple): """Debugging info about a func, its arguments, and results.""" @@ -282,39 +289,39 @@ class DebugInfo(NamedTuple): which may be ''. """ - arg_names: tuple[str, ...] + arg_names: tuple[str, ...] | None """The paths of the flattened non-static argnames, e.g. `('x', 'dict_arg["a"]', ... )`. Uses the empty string for the args that do not correspond to user-named arguments, e.g., tangent args in `jax.jvp`, or for arguments that - we are not yet tracking properly. + we are not yet tracking properly. The value `None` denotes argument names. + At the moment, `arg_names` accuracy is best-effort. Use `safe_arg_names` to detect and handle an unexpected number of elements in `arg_names`. """ - result_paths: tuple[str, ...] | Callable[[], tuple[str, ...]] | None + result_paths: tuple[str, ...] | InitialResultPaths | Callable[[], tuple[str, ...]] | None """The paths to the flattened results, e.g., `('result[0]', result[1])` for a function that returns a tuple of arrays, or `(result,)` for a function that - returns a single array. - The result paths are not available while we are tracing the function, - instead we keep a thunk. It is possible for the result paths to be `None` - only when we first create a `DebugInfo`, before we put it in `lu.WrappedFun` - and before we start tracing. - Inside a `lu.WrappedFun` it can be only a thunk or a tuple of strings. - Once we are done tracing, we use - `self.resolve_result_paths()` to execute the thunk and replace the - actual result paths. - At the moment, `result_paths` accuracy is best-effort. + returns a single array. The value `None` denotes unknown paths. + + When we first create a `DebugInfo`, we may use the value + `initial_result_paths`, which we replace with a thunk when we put the + debug info into a `lu.WrappedFun`, before we start tracing. After tracing, + we call `self.resolve_result_paths()` to execute the thunk and replace + the result paths with a tuple. + Use `safe_result_paths` to detect and handle an unexpected number of elements in `result_paths`. """ def resolve_result_paths(self) -> DebugInfo: """Return a debug info with resolved result paths.""" - assert self.result_paths is not None + assert self.result_paths is not initial_result_paths if callable(self.result_paths): - return self._replace(result_paths=tuple(self.result_paths())) + paths = tuple(self.result_paths()) + return self._replace(result_paths=paths) return self @property @@ -326,32 +333,64 @@ def replace_func_name(self, name: str) -> DebugInfo: func_src_comps[0] = name return self._replace(func_src_info=" ".join(func_src_comps)) - def safe_arg_names(self, expected: int) -> tuple[str, ...]: + def set_result_paths(self, ans): + result_paths = tuple(f"result{_clean_keystr_arg_names(path)}" + for path, _ in generate_key_paths(ans)) + return self._replace(result_paths=result_paths) + + @property + def func_filename(self) -> str | None: + m = _re_func_src_info.match(self.func_src_info) + if not m: return None + return m.group(3) + + @property + def func_lineno(self) -> int | None: + m = _re_func_src_info.match(self.func_src_info) + if not m or m.group(4) is None: return None + return int(m.group(4)) + + def safe_arg_names(self, expected_count: int) -> tuple[str, ...]: """Get the arg_names with a safety check.""" - if len(self.arg_names) == expected: + self.assert_arg_names(expected_count) + if self.arg_names is not None: return self.arg_names - else: - # TODO(necula): this should not happen - return ("",) * expected + return ("",) * expected_count + + def assert_arg_names(self, expected_count: int): + assert self.arg_names is None or len(self.arg_names) == expected_count, ( + expected_count, self) - def filter_arg_names(self, keep: Sequence[bool]) -> tuple[str, ...]: + def filter_arg_names(self, keep: Sequence[bool]) -> tuple[str, ...] | None: """Keep only the arg_names for which `keep` is True.""" + if self.arg_names is None: + return None return tuple(v for v, b in zip(self.safe_arg_names(len(keep)), keep) if b) - def safe_result_paths(self, expected: int) -> tuple[str, ...]: - """Get the result paths with a safety check.""" - assert self.result_paths is not None and not callable(self.result_paths), self - if self.result_paths is not None and len(self.result_paths) == expected: - return self.result_paths - else: - # TODO(necula): this should not happen - return ("",) * expected + def safe_result_paths(self, expected_count: int) -> tuple[str, ...]: + """Get the result paths with a safety check. Empty paths mean unknown.""" + assert self.result_paths is not initial_result_paths and not callable(self.result_paths), self + self.assert_result_paths(expected_count) + if self.result_paths is not None: + return self.result_paths # type: ignore + + return ("",) * expected_count + + def assert_result_paths(self, expected_count: int): + assert self.result_paths is None or len(self.result_paths) == expected_count, ( # type: ignore + expected_count, self) - def filter_result_paths(self, keep: Sequence[bool]) -> tuple[str, ...]: + def filter_result_paths(self, keep: Sequence[bool]) -> tuple[str, ...] | None: """Keep only the result_paths for which `keep` is True.""" - assert self.result_paths is not None and not callable(self.result_paths), self - return tuple(v for v, b in zip(self.safe_result_paths(len(keep)), keep) if b) + assert self.result_paths is not initial_result_paths and not callable(self.result_paths), self + if self.result_paths is None: return None + return tuple(v for v, b in zip(self.result_paths, keep) if b) # type: ignore + def with_unknown_names(self) -> DebugInfo: + return self._replace(arg_names=None, result_paths=None) + + +_re_func_src_info = re.compile(r"([^ ]+)( at (.+):(\d+))?$") def _missing_debug_info(for_what: str) -> DebugInfo: warnings.warn( @@ -360,20 +399,14 @@ def _missing_debug_info(for_what: str) -> DebugInfo: "construct a proper DebugInfo object and propagate it to this function. " "See https://github.com/jax-ml/jax/issues/26480 for more details.", DeprecationWarning, stacklevel=2) - return DebugInfo("missing_debug_info", "", (), ()) + return DebugInfo("missing_debug_info", "", None, None) -def wrap_init(f: Callable, params=None, *, - debug_info: DebugInfo) -> WrappedFun: +def wrap_init(f: Callable, params=None, *, debug_info: DebugInfo) -> WrappedFun: """Wraps function `f` as a `WrappedFun`, suitable for transformation.""" params_dict = {} if params is None else params params = () if params is None else tuple(sorted(params.items())) + debug_info = debug_info._replace(result_paths=None) fun = WrappedFun(f, partial(f, **params_dict), (), (), params, None, debug_info) - if debug_info.result_paths is None: - fun, result_paths_thunk = _get_result_paths_thunk(fun) - debug_info = debug_info._replace( - result_paths=HashableFunction(result_paths_thunk, closure=())) - fun = WrappedFun(fun.f, fun.f_transformed, fun.transforms, fun.stores, - fun.params, fun.in_type, debug_info) return fun @@ -383,57 +416,21 @@ def _clean_keystr_arg_names(k: KeyPath) -> str: res = keystr(k) return _re_clean_keystr_arg_names.sub(r"\1", res) -@transformation_with_aux2 -def _get_result_paths_thunk(_fun: Callable, _store: Store, *args, **kwargs): - ans = _fun(*args, **kwargs) - result_paths = tuple(f"result{_clean_keystr_arg_names(path)}" for path, _ in generate_key_paths(ans)) - if _store: - # In some instances a lu.WrappedFun is called multiple times, e.g., - # the bwd function in a custom_vjp - assert _store.val == result_paths, (_store, result_paths) - else: - _store.store(result_paths) - return ans - def annotate(f: WrappedFun, in_type: core.InputType | None) -> WrappedFun: assert f.in_type is None if in_type is None: return f _check_input_type(in_type) - return WrappedFun(f.f, f.f_transformed, f.transforms, f.stores, f.params, in_type, f.debug_info) + return WrappedFun(f.f, f.f_transformed, f.transforms, f.stores, f.params, + in_type, f.debug_info) def _check_input_type(in_type: core.InputType) -> None: # Check that in_type is syntactically well-formed - assert type(in_type) is tuple and all(type(e) is tuple for e in in_type) - assert all(isinstance(a, core.AbstractValue) and type(b) is bool - for a, b in in_type) - - def valid_size(d) -> bool: - if isinstance(d, core.DBIdx) and type(d.val) is int and d.val >= 0: - return True - return (isinstance(d, (int, core.DBIdx, core.DArray)) and - (not isinstance(d, core.DArray) or type(d) is core.bint and not d.shape)) - assert all(valid_size(d) for a, _ in in_type if type(a) is core.DShapedArray - for d in a.shape) - - # Check that all DBIdx point to positions to the left of the input on which - # they appear. - assert all(d.val < i for i, (aval, _) in enumerate(in_type) - if isinstance(aval, core.DShapedArray) for d in aval.shape - if isinstance(d, core.DBIdx)) - - # Check that all implicit arguments have at least one DBIdx pointing to them. - provided = [e for _, e in in_type] - for aval, _ in in_type: - if type(aval) is core.DShapedArray: - for d in aval.shape: - if isinstance(d, core.DBIdx): - provided[d.val] = True - assert all(provided) - + assert type(in_type) is tuple + assert all(isinstance(a, core.AbstractValue) for a in in_type) def cache(call: Callable, *, - explain: Callable[[WrappedFun, bool, dict, tuple], None] | None = None): + explain: Callable[[WrappedFun, bool, dict, tuple, float], None] | None = None): """Memoization decorator for functions taking a WrappedFun as first argument. Args: @@ -442,7 +439,8 @@ def cache(call: Callable, *, memoization cache key. explain: a function that is invoked upon cache misses to log an explanation - of the miss. Invoked with `(fun, is_cache_first_use, cache, key)`. + of the miss. + Invoked with `(fun, is_cache_first_use, cache, key, elapsed_sec)`. Returns: A memoized version of ``call``. @@ -457,9 +455,11 @@ def memoized_fun(fun: WrappedFun, *args): ans, stores = result fun.populate_stores(stores) else: + if do_explain := explain and config.explain_cache_misses.value: + start = time.time() ans = call(fun, *args) - if explain and config.explain_cache_misses.value: - explain(fun, cache is new_cache, cache, key) + if do_explain: + explain(fun, cache is new_cache, cache, key, time.time() - start) # type: ignore cache[key] = (ans, fun.stores) return ans @@ -469,7 +469,7 @@ def _evict_function(f): memoized_fun.cache_clear = fun_caches.clear # type: ignore memoized_fun.evict_function = _evict_function # type: ignore - cache_clearing_funs.add(memoized_fun.cache_clear) + register_cache(memoized_fun, str(call)) return memoized_fun @transformation2 diff --git a/jax/_src/literals.py b/jax/_src/literals.py new file mode 100644 index 000000000000..5aed0f3c3256 --- /dev/null +++ b/jax/_src/literals.py @@ -0,0 +1,267 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Sequence +from jax._src.lib import _jax + +import numpy as np + +# TypedInt, TypedFloat, and TypedComplex are subclasses of int, float, and +# complex that carry a JAX dtype. Canonicalization forms these types from int, +# float, and complex. Repeated canonicalization, including under different +# jax_enable_x64 modes, preserves the dtype. + + +class TypedInt(int): + + dtype: np.dtype + + def __new__(cls, value: int, dtype: np.dtype): + v = super(TypedInt, cls).__new__(cls, value) + v.dtype = dtype + return v + + def __repr__(self): + return f'TypedInt({int(self)}, dtype={self.dtype.name})' + + def __getnewargs__(self): + return (int(self), self.dtype) + + +class TypedFloat(float): + + dtype: np.dtype + + def __new__(cls, value: float, dtype: np.dtype): + v = super(TypedFloat, cls).__new__(cls, value) + v.dtype = dtype + return v + + def __repr__(self): + return f'TypedFloat({float(self)}, dtype={self.dtype.name})' + + def __str__(self): + return str(float(self)) + + def __getnewargs__(self): + return (float(self), self.dtype) + + +class TypedComplex(complex): + + dtype: np.dtype + + def __new__(cls, value: complex, dtype: np.dtype): + v = super(TypedComplex, cls).__new__(cls, value) + v.dtype = dtype + return v + + def __repr__(self): + return f'TypedComplex({complex(self)}, dtype={self.dtype.name})' + + def __getnewargs__(self): + return (complex(self), self.dtype) + + +_jax.set_typed_int_type(TypedInt) +_jax.set_typed_float_type(TypedFloat) +_jax.set_typed_complex_type(TypedComplex) + + +typed_scalar_types: set[type] = {TypedInt, TypedFloat, TypedComplex} + + +class TypedNdArray: + """A TypedNdArray is a host-side array used by JAX during tracing. + + To most intents and purposes a TypedNdArray is a thin wrapper around a numpy + array and should act like it. The primary differences are that a TypedNdArray + carries a JAX type: + * its type is not canonicalized by JAX, irrespective of the jax_enable_x64 + mode + * it can be weakly typed. + """ + + __slots__ = ('val', 'weak_type') + + val: np.ndarray + weak_type: bool + + def __init__(self, val: np.ndarray, weak_type: bool): + self.val = val + self.weak_type = weak_type + + @property + def dtype(self) -> np.dtype: + return self.val.dtype + + @property + def shape(self) -> tuple[int, ...]: + return self.val.shape + + @property + def strides(self) -> Sequence[int]: + return self.val.strides + + @property + def ndim(self) -> int: + return self.val.ndim + + @property + def size(self) -> int: + return self.val.size + + def __len__(self) -> int: + return self.val.__len__() + + def __repr__(self): + prefix = 'TypedNdArray(' + if self.weak_type: + dtype_str = f'dtype={self.val.dtype.name}, weak_type=True)' + else: + dtype_str = f'dtype={self.val.dtype.name})' + + line_width = np.get_printoptions()['linewidth'] + if self.size == 0: + s = f'[], shape={self.val.shape}' + else: + s = np.array2string( + self.val, + prefix=prefix, + suffix=',', + separator=', ', + max_line_width=line_width, + ) + last_line_len = len(s) - s.rfind('\n') + 1 + sep = ' ' + if last_line_len + len(dtype_str) + 1 > line_width: + sep = ' ' * len(prefix) + return f'{prefix}{s},{sep}{dtype_str}' + + def __array__(self, dtype=None, copy=None): + # You might think that we can do the following here: + # return self.val.__array__(dtype=dtype, copy=copy) + # Unfortunately __array__ appears to be buggy on NumPy < 2.3 and interprets + # the "dtype=None" as "the default float type". + # TODO(phawkins): revert to the above form once NumPy 2.3 is the minimum + # supported version. + return np.asarray(self.val, dtype=dtype, copy=copy) # pytype: disable=wrong-keyword-args + + def __add__(self, other): + return self.val.__add__(other) + + def __sub__(self, other): + return self.val.__sub__(other) + + def __mul__(self, other): + return self.val.__mul__(other) + + def __floordiv__(self, other): + return self.val.__floordiv__(other) + + def __truediv__(self, other): + return self.val.__truediv__(other) + + def __mod__(self, other): + return self.val.__mod__(other) + + def __pow__(self, other): + return self.val.__pow__(other) + + def __radd__(self, other): + return self.val.__radd__(other) + + def __rsub__(self, other): + return self.val.__rsub__(other) + + def __rmul__(self, other): + return self.val.__rmul__(other) + + def __rtruediv__(self, other): + return self.val.__rtruediv__(other) + + def __rfloordiv__(self, other): + return self.val.__rfloordiv__(other) + + def __rmod__(self, other): + return self.val.__rmod__(other) + + def __rpow__(self, other): + return self.val.__rpow__(other) + + def __getitem__(self, index): + return self.val.__getitem__(index) + + def __bool__(self): + return self.val.__bool__() + + def __int__(self): + return self.val.__int__() + + def __float__(self): + return self.val.__float__() + + def __complex__(self): + return self.val.__complex__() + + def __index__(self): + return self.val.__index__() + + def __lt__(self, other): + return self.val.__lt__(other) + + def __le__(self, other): + return self.val.__le__(other) + + def __eq__(self, other): + return self.val.__eq__(other) + + def __ne__(self, other): + return self.val.__ne__(other) + + def __gt__(self, other): + return self.val.__gt__(other) + + def __ge__(self, other): + return self.val.__ge__(other) + + def __abs__(self): + return self.val.__abs__() + + def reshape(self, *args, **kw): + return self.val.reshape(*args, **kw) + + def item(self, *args): + return self.val.item(*args) + + @property + def T(self): + return self.val.T + + @property + def mT(self): + return self.val.mT + + def clip(self, *args, **kwargs): + return self.val.clip(*args, **kwargs) + + def astype(self, dtype, order='K', casting='unsafe', subok=True, copy=True): + return self.val.astype( + dtype, order=order, casting=casting, subok=subok, copy=copy + ) + + def tobytes(self, order='C'): + return self.val.tobytes(order=order) + +_jax.set_typed_ndarray_type(TypedNdArray) diff --git a/jax/_src/logging_config.py b/jax/_src/logging_config.py index bdf588d2054a..6b9ba1dd7bf4 100644 --- a/jax/_src/logging_config.py +++ b/jax/_src/logging_config.py @@ -13,8 +13,8 @@ # limitations under the License. import logging -import os import sys +from jax._src.lib import utils # Example log message: # DEBUG:2023-06-07 00:14:40,280:jax._src.xla_bridge:590: Initializing backend 'cpu' @@ -22,7 +22,6 @@ "{levelname}:{asctime}:{name}:{lineno}: {message}", style='{') _logging_level_set: dict[str, int] = {} -_default_TF_CPP_MIN_LOG_LEVEL = os.environ.get("TF_CPP_MIN_LOG_LEVEL", "1") _jax_logger_handler = logging.StreamHandler(sys.stderr) _jax_logger_handler.setFormatter(logging_formatter) @@ -48,19 +47,17 @@ 'DEBUG': 0, } -def _set_TF_CPP_MIN_LOG_LEVEL(logging_level: str | None = None): +def _set_cpp_min_log_level(logging_level: str | None = None): if logging_level in (None, "NOTSET"): - # resetting to user-default TF_CPP_MIN_LOG_LEVEL - # this is typically "1", but if the user overrode it, it can be != "1" - os.environ["TF_CPP_MIN_LOG_LEVEL"] = _default_TF_CPP_MIN_LOG_LEVEL - else: - # set cpp runtime logging level if the level is anything but NOTSET - if logging_level not in _tf_cpp_map: - raise ValueError(f"Attempting to set log level \"{logging_level}\" which" - f" isn't one of the supported:" - f" {list(_tf_cpp_map.keys())}.") - # config the CPP logging level 0 - debug, 1 - info, 2 - warning, 3 - error - os.environ["TF_CPP_MIN_LOG_LEVEL"] = str(_tf_cpp_map[logging_level]) + return + # set cpp runtime logging level if the level is anything but NOTSET + if logging_level not in _tf_cpp_map: + raise ValueError(f"Attempting to set log level \"{logging_level}\" which" + f" isn't one of the supported:" + f" {list(_tf_cpp_map.keys())}.") + # config the CPP logging level 0 - debug, 1 - info, 2 - warning, 3 - error + log_level = _tf_cpp_map[logging_level] + utils.absl_set_min_log_level(log_level) def update_logging_level_global(logging_level: str | None) -> None: # remove previous handlers @@ -69,7 +66,7 @@ def update_logging_level_global(logging_level: str | None) -> None: logger.removeHandler(_jax_logger_handler) logger.setLevel(level) _logging_level_set.clear() - _set_TF_CPP_MIN_LOG_LEVEL(logging_level) + _set_cpp_min_log_level(logging_level) if logging_level is None: return diff --git a/jax/_src/memory.py b/jax/_src/memory.py new file mode 100644 index 000000000000..8efeba37126c --- /dev/null +++ b/jax/_src/memory.py @@ -0,0 +1,24 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import enum + + +class Space(enum.Enum): + Device = enum.auto() + Host = enum.auto() + Any = enum.auto() + + def __repr__(self): + return f"MemorySpace.{self.name}" diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index b490febf7b0c..f0210d5d9b82 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -11,13 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Definitions of Mesh and ResourceEnv.""" +"""Definitions of Mesh and AbstractMesh""" from __future__ import annotations import collections from collections.abc import Hashable, Sequence import contextlib +import dataclasses import enum import functools import math @@ -32,6 +33,7 @@ from jax._src.lib import xla_client as xc zip, unsafe_zip = safe_zip, zip +config_ext = xc._xla.config MeshAxisName = Any ResourceAxisName = Hashable @@ -111,12 +113,16 @@ class AxisType(enum.Enum): def __repr__(self): return self.name -def _normalize_axis_types(axis_names, axis_types): +def _normalize_axis_types(axis_names, axis_types, name): axis_types = ((AxisType.Auto,) * len(axis_names) if axis_types is None else axis_types) if not isinstance(axis_types, tuple): - assert isinstance(axis_types, AxisType), axis_types axis_types = (axis_types,) + + if not all(isinstance(a, AxisType) for a in axis_types): + raise TypeError( + f"axis_types passed to {name} must be of type `jax.sharding.AxisType`." + f" Got {axis_types} of type {tuple(type(a) for a in axis_types)}") if len(axis_names) != len(axis_types): raise ValueError( "Number of axis names should match the number of axis_types. Got" @@ -134,76 +140,87 @@ def any_axis_types_match(axis_types, ty: AxisType) -> bool: return any(t == ty for t in axis_types) -class _BaseMesh: +class BaseMesh: axis_names: tuple[MeshAxisName, ...] shape_tuple: tuple[tuple[str, int], ...] - _axis_types: tuple[AxisType, ...] - - @property - def axis_types(self) -> tuple[AxisType, ...]: - return self._axis_types + axis_types: tuple[AxisType, ...] @functools.cached_property - def _are_all_axes_manual(self) -> bool: - return all_axis_types_match(self._axis_types, AxisType.Manual) + def are_all_axes_manual(self) -> bool: + return all_axis_types_match(self.axis_types, AxisType.Manual) @functools.cached_property - def _are_all_axes_auto(self) -> bool: - return all_axis_types_match(self._axis_types, AxisType.Auto) + def are_all_axes_auto(self) -> bool: + return all_axis_types_match(self.axis_types, AxisType.Auto) @functools.cached_property - def _are_all_axes_explicit(self) -> bool: - return all_axis_types_match(self._axis_types, AxisType.Explicit) + def are_all_axes_explicit(self) -> bool: + return all_axis_types_match(self.axis_types, AxisType.Explicit) @functools.cached_property def _are_all_axes_auto_or_manual(self) -> bool: - if not self._axis_types: + if not self.axis_types: return False return all(t == AxisType.Auto or t == AxisType.Manual - for t in self._axis_types) + for t in self.axis_types) + + @functools.cached_property + def _are_all_axes_explicit_or_manual(self) -> bool: + if not self.axis_types: + return False + return all(t == AxisType.Explicit or t == AxisType.Manual + for t in self.axis_types) @functools.cached_property def _any_axis_manual(self) -> bool: - return any_axis_types_match(self._axis_types, AxisType.Manual) + return any_axis_types_match(self.axis_types, AxisType.Manual) @functools.cached_property def _any_axis_auto(self) -> bool: - return any_axis_types_match(self._axis_types, AxisType.Auto) + return any_axis_types_match(self.axis_types, AxisType.Auto) @functools.cached_property def _any_axis_explicit(self) -> bool: - return any_axis_types_match(self._axis_types, AxisType.Explicit) + return any_axis_types_match(self.axis_types, AxisType.Explicit) + + @functools.cached_property + def _any_axis_auto_or_manual(self) -> bool: + if not self.axis_types: + return False + return any(t == AxisType.Auto or t == AxisType.Manual + for t in self.axis_types) + + @functools.cached_property + def auto_axes(self): + return tuple(n for n, t in safe_zip(self.axis_names, self.axis_types) + if t == AxisType.Auto) + + @functools.cached_property + def explicit_axes(self): + return tuple(n for n, t in safe_zip(self.axis_names, self.axis_types) + if t == AxisType.Explicit) @functools.cached_property - def _axis_types_dict(self): - if not self.axis_names: - return {} - d = collections.defaultdict(list) - for n, t in safe_zip(self.axis_names, self._axis_types): - d[t].append(n) - return {t: tuple(n) for t, n in d.items()} + def manual_axes(self): + return tuple(n for n, t in safe_zip(self.axis_names, self.axis_types) + if t == AxisType.Manual) @functools.cached_property def _name_to_type(self): - return dict(safe_zip(self.axis_names, self._axis_types)) + return dict(safe_zip(self.axis_names, self.axis_types)) +def _unpicke_mesh(devices, axis_names, axis_types): + return Mesh(devices, axis_names, axis_types) + _mesh_object_dict = {} # type: ignore -class Mesh(_BaseMesh, contextlib.ContextDecorator): +class Mesh(BaseMesh, contextlib.ContextDecorator): """Declare the hardware resources available in the scope of this manager. - In particular, all ``axis_names`` become valid resource names inside the - managed block and can be used e.g. in the ``in_axis_resources`` argument of - :py:func:`jax.experimental.pjit.pjit`. Also see JAX's multi-process programming - model (https://jax.readthedocs.io/en/latest/multi_process.html) - and the Distributed arrays and automatic parallelization tutorial - (https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) - - If you are compiling in multiple threads, make sure that the - ``with Mesh`` context manager is inside the function that the threads will - execute. + See `Distributed arrays and automatic parallelization`_ and + `Explicit Sharding`_ tutorials. Args: devices: A NumPy ndarray object containing JAX device objects (as @@ -211,42 +228,32 @@ class Mesh(_BaseMesh, contextlib.ContextDecorator): axis_names: A sequence of resource axis names to be assigned to the dimensions of the ``devices`` argument. Its length should match the rank of ``devices``. + axis_types: and optional tuple of :class:`jax.sharding.AxisType` entries corresponding to + the ``axis_names``. See `Explicit Sharding`_ for more information. Examples: - >>> from jax.experimental.pjit import pjit >>> from jax.sharding import Mesh - >>> from jax.sharding import PartitionSpec as P + >>> from jax.sharding import PartitionSpec as P, NamedSharding >>> import numpy as np ... - >>> inp = np.arange(16).reshape((8, 2)) - >>> devices = np.array(jax.devices()).reshape(4, 2) - ... >>> # Declare a 2D mesh with axes `x` and `y`. - >>> global_mesh = Mesh(devices, ('x', 'y')) - >>> # Use the mesh object directly as a context manager. - >>> with global_mesh: - ... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp) - - >>> # Initialize the Mesh and use the mesh as the context manager. - >>> with Mesh(devices, ('x', 'y')) as global_mesh: - ... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp) - - >>> # Also you can use it as `with ... as ...`. - >>> global_mesh = Mesh(devices, ('x', 'y')) - >>> with global_mesh as m: - ... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp) - - >>> # You can also use it as `with Mesh(...)`. - >>> with Mesh(devices, ('x', 'y')): - ... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp) + >>> devices = np.array(jax.devices()).reshape(4, 2) + >>> mesh = Mesh(devices, ('x', 'y')) + >>> inp = np.arange(16).reshape(8, 2) + >>> arr = jax.device_put(inp, NamedSharding(mesh, P('x', 'y'))) + >>> out = jax.jit(lambda x: x * 2)(arr) + >>> assert out.sharding == NamedSharding(mesh, P('x', 'y')) + + .. _Distributed arrays and automatic parallelization: https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html + .. _Explicit Sharding: https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html """ devices: np.ndarray axis_names: tuple[MeshAxisName, ...] def __new__(cls, devices: np.ndarray | Sequence[xc.Device], - axis_names: str | Sequence[MeshAxisName], *, + axis_names: str | Sequence[MeshAxisName], axis_types: tuple[AxisType, ...] | None = None): if not isinstance(devices, np.ndarray): devices = np.array(devices) @@ -263,7 +270,7 @@ def __new__(cls, devices: np.ndarray | Sequence[xc.Device], f"devices.ndim == {devices.ndim} and " f"len(axis_names) == {len(axis_names)}.") - axis_types = _normalize_axis_types(axis_names, axis_types) + axis_types = _normalize_axis_types(axis_names, axis_types, 'Mesh') key = (axis_names, devices.shape, tuple(devices.flat), axis_types) val = _mesh_object_dict.get(key, None) @@ -274,14 +281,13 @@ def __new__(cls, devices: np.ndarray | Sequence[xc.Device], self.devices = devices.copy() self.devices.flags.writeable = False self.axis_names = axis_names - self._axis_types = axis_types + self.axis_types = axis_types self._size = math.prod(self.shape.values()) if self.devices.ndim else 0 _mesh_object_dict[key] = self return self def __reduce__(self): - return (type(self), (self.devices, self.axis_names), - {'axis_types': self._axis_types}) + return (_unpicke_mesh, (self.devices, self.axis_names, self.axis_types)) def __eq__(self, other): # This is a performance optimization. Comparing thousands of devices @@ -292,14 +298,14 @@ def __eq__(self, other): return False return (self.axis_names == other.axis_names and self.devices.shape == other.devices.shape and - self._axis_types == other._axis_types and + self.axis_types == other.axis_types and self._internal_device_list == other._internal_device_list) def __hash__(self): if not hasattr(self, '_hash'): self._hash = hash( (self.axis_names, self._internal_device_list, self.devices.shape, - self._axis_types)) + self.axis_types)) return self._hash def __setattr__(self, name, value): @@ -332,6 +338,15 @@ def __exit__(self, exc_type, exc_value, traceback): if not t.physical_mesh.empty)) return False + def update(self, devices=None, axis_names=None, axis_types=None): + if devices is None: + devices = self.devices + if axis_names is None: + axis_names = self.axis_names + if axis_types is None: + axis_types = self.axis_types + return Mesh(devices, axis_names, axis_types) + @functools.cached_property def shape(self): return collections.OrderedDict( @@ -356,16 +371,10 @@ def size(self): def empty(self): return self.size == 0 - # TODO(emilyaf): Remove this when the `enable_empty_arrays` flag is - # removed. @functools.cached_property def is_multi_process(self): return self.devices.size != len(self.local_devices) - @functools.cached_property - def _process_indices(self): - return {d.process_index for d in self._flat_devices_tuple} - @property def local_mesh(self): return self._local_mesh(xb.process_index()) @@ -395,16 +404,19 @@ def _flat_devices_set(self): return set(self.devices.flat) def __str__(self): + if self.empty: + return "Mesh()" mesh_str = ", ".join(f"'{k}': {v}" for k, v in self.shape.items()) - atr = f", axis_types={self._axis_types}" + atr = f", axis_types={self.axis_types}" return f"Mesh({mesh_str}{atr})" @functools.cached_property def _repr(self): if self.empty: - return "Mesh(device_ids=[], axis_names=())" - atr = f", axis_types={self._axis_types}" - return f"Mesh(device_ids={self.device_ids!r}, axis_names={self.axis_names!r}{atr})" + return "Mesh(axis_sizes=(), axis_names=())" + atr = f", axis_types={self.axis_types}" + return (f"Mesh(axis_sizes={self.device_ids.shape}, " + f"axis_names={self.axis_names!r}{atr})") def __repr__(self): return self._repr @@ -416,8 +428,21 @@ def local_devices(self): @functools.cached_property def abstract_mesh(self): - return AbstractMesh(self.axis_sizes, self.axis_names, - axis_types=self._axis_types) + d = self.devices.flat[0] + if d is None: + abstract_device = None + else: + if d.platform == 'tpu': + num_cores = getattr(d, 'num_cores', None) + elif d.platform == 'gpu': + num_cores = getattr(d, 'core_count', None) + else: + num_cores = None + abstract_device = AbstractDevice( + device_kind=d.device_kind, num_cores=num_cores) + return AbstractMesh( + self.axis_sizes, self.axis_names, axis_types=self.axis_types, + abstract_device=abstract_device) EMPTY_ENV = ResourceEnv(Mesh(np.empty((), dtype=object), ())) @@ -431,7 +456,19 @@ def __init__(self): thread_resources = _ThreadResourcesLocalState() -class AbstractMesh(_BaseMesh): +@dataclasses.dataclass(frozen=True) +class AbstractDevice: + device_kind: str + num_cores: int | None + + def __repr__(self): + return (f"AbstractDevice({self._repr()})") + + def _repr(self): + return f"device_kind={self.device_kind}, num_cores={self.num_cores}" + + +class AbstractMesh(BaseMesh): """AbstractMesh contains only axis names and axis sizes. It does not contain concrete devices compared to `jax.sharding.Mesh`. You @@ -440,15 +477,29 @@ class AbstractMesh(_BaseMesh): your mesh shape and axis names stay the same but the devices change. See the description of https://github.com/jax-ml/jax/pull/23022 for more details. + + Args: + axis_sizes: A tuple of integers specifying the size of each resource axis. + axis_names: A tuple of resource axis names to be assigned to the + dimensions of the ``devices`` argument. Its length should match the + rank of ``devices``. + axis_types: and optional tuple of :class:`jax.sharding.AxisType` entries corresponding to + the ``axis_names``. See `Explicit Sharding`_ for more information. + + .. _Explicit Sharding: https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html """ def __init__(self, axis_sizes: tuple[int, ...], axis_names: tuple[str, ...], - *, axis_types: AxisType | tuple[AxisType, ...] | None = None): + axis_types: AxisType | tuple[AxisType, ...] | None = None, + *, abstract_device=None): self.axis_sizes = axis_sizes self.axis_names = axis_names - self._size = math.prod(self.axis_sizes) if self.axis_sizes else 0 - self._axis_types = _normalize_axis_types(self.axis_names, axis_types) - self._hash = hash((self.axis_sizes, self.axis_names, self._axis_types)) + self.axis_types = _normalize_axis_types( + self.axis_names, axis_types, 'AbstractMesh') + self.abstract_device = abstract_device + self.size = math.prod(self.axis_sizes) if self.axis_sizes else 0 + self._hash = hash((self.axis_sizes, self.axis_names, self.axis_types, + self.abstract_device)) def __hash__(self): return self._hash @@ -460,17 +511,27 @@ def __eq__(self, other): return False return (self.axis_sizes == other.axis_sizes and self.axis_names == other.axis_names and - self._axis_types == other._axis_types) + self.axis_types == other.axis_types and + self.abstract_device == other.abstract_device) def __repr__(self): mesh_repr = (", ".join(f"'{n}': {v}" for n, v in self.shape_tuple) if self.shape_tuple else "()") - atr = f", axis_types={self._axis_types}" - return f"AbstractMesh({mesh_repr}{atr})" - - @property - def size(self): - return self._size + atr = f", axis_types={self.axis_types}" + ad = ("" if self.abstract_device is None else + f", {self.abstract_device._repr()}") + return f"AbstractMesh({mesh_repr}{atr}{ad})" + + def update(self, axis_sizes=None, axis_names=None, axis_types=None, **kwargs): + if axis_sizes is None: + axis_sizes = self.axis_sizes + if axis_names is None: + axis_names = self.axis_names + if axis_types is None: + axis_types = self.axis_types + if 'abstract_device' not in kwargs: + kwargs['abstract_device'] = self.abstract_device + return AbstractMesh(axis_sizes, axis_names, axis_types, **kwargs) @functools.cached_property def shape(self): @@ -496,9 +557,8 @@ def abstract_mesh(self): def update_axis_types(self, name_to_type: dict[MeshAxisName, AxisType]): new_axis_types = tuple(name_to_type[n] if n in name_to_type else a - for n, a in zip(self.axis_names, self._axis_types)) - return AbstractMesh(self.axis_sizes, self.axis_names, - axis_types=new_axis_types) + for n, a in zip(self.axis_names, self.axis_types)) + return self.update(axis_types=new_axis_types) @property def devices(self): @@ -526,11 +586,6 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): _raise_value_error("__exit__") - @staticmethod - def _extremely_unsafe_enter_tracing_context(mesh: AbstractMesh): - prev = jax_config.abstract_mesh_context_manager.swap_local(mesh) - return prev - # Create this indirection because pytype fails to recognize a property if a # property raises an exception unconditionally. Remove this once that is fixed. @@ -538,8 +593,9 @@ def _raise_value_error(name): raise ValueError(f"AbstractMesh does not implement {name}") empty_abstract_mesh = AbstractMesh((), ()) +empty_concrete_mesh = Mesh(np.empty((), dtype=object), ()) -class UseAbstractMeshContextManager: +class use_abstract_mesh: __slots__ = ['mesh', 'prev'] def __init__(self, mesh: AbstractMesh): @@ -551,15 +607,23 @@ def __init__(self, mesh: AbstractMesh): def __enter__(self): self.prev = jax_config.abstract_mesh_context_manager.swap_local(self.mesh) + if (self.prev is not config_ext.unset and + not self.prev.empty and not self.mesh.empty and + self.prev.size != self.mesh.size): + jax_config.abstract_mesh_context_manager.set_local(self.prev) + raise ValueError( + "use_abstract_mesh cannot change the size of the mesh. Got new mesh:" + f" {self.mesh} with size={self.mesh.size} and prev mesh:" + f" {self.prev} with size={self.prev.size}") def __exit__(self, exc_type, exc_value, traceback): jax_config.abstract_mesh_context_manager.set_local(self.prev) -use_abstract_mesh = UseAbstractMeshContextManager -def get_abstract_mesh(): +def get_abstract_mesh() -> AbstractMesh: val = jax_config.abstract_mesh_context_manager.value return empty_abstract_mesh if val is None else val -def get_concrete_mesh() -> Mesh | None: - return jax_config.device_context.value +def get_concrete_mesh() -> Mesh: + val = jax_config.device_context.value + return empty_concrete_mesh if val is None else val diff --git a/jax/_src/mesh_utils.py b/jax/_src/mesh_utils.py index ccc75af8c84f..995b978d00e4 100644 --- a/jax/_src/mesh_utils.py +++ b/jax/_src/mesh_utils.py @@ -1,4 +1,4 @@ -# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# Copyright 2021 The JAX Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,21 +19,21 @@ import collections from collections.abc import Callable, Generator, MutableMapping, Sequence import itertools -import logging import math from typing import Any from jax._src import xla_bridge as xb import numpy as np -logger = logging.getLogger(__name__) - _TPU_V2 = 'TPU v2' _TPU_V3 = 'TPU v3' _TPU_V4 = 'TPU v4' +_TPU_V4_LITE = "TPU v4 lite" _TPU_V5_LITE = "TPU v5 lite" _TPU_V5E = "TPU v5e" _TPU_V5P = "TPU v5p" +_TPU_V6_LITE = "TPU v6 lite" +_TPU_7X = "TPU7x" # Maps physical topology -> mesh shape -> transpose to use for jekbradbury's # famous contiguous mesh trick. @@ -72,6 +72,8 @@ _V5E_TRAY_RING_ORDER = (0, 1, 2, 3, 7, 6, 5, 4) _V5E_TRAY_IOTA_ORDER = (0, 4, 2, 6, 1, 5, 3, 7) _V5P_2x2x2_ORDER = (0, 1, 3, 2, 6, 7, 5, 4) +_7X_TRAY_2x2x2_RING_ORDER = (0, 1, 2, 3, 6, 7, 4, 5) + def _tpu_v2_v3_create_device_mesh( mesh_shape: Sequence[int], @@ -79,18 +81,12 @@ def _tpu_v2_v3_create_device_mesh( **unused_kwargs, ) -> np.ndarray: if len(devices) == 8: - logger.info( - 'Reordering mesh to physical ring order on single-tray TPU v2/v3.' - ) device_mesh = np.asarray(devices) device_mesh = device_mesh[np.array(_TRAY_RING_ORDER)] device_mesh = device_mesh.reshape(mesh_shape) return device_mesh elif mesh_shape[-1] == 8: device_mesh = np.asarray(devices).reshape(mesh_shape) - logger.info( - 'Reordering mesh to physical ring order on each TPU v2/v3 tray.' - ) perm = np.array(_TRAY_RING_ORDER) device_mesh = device_mesh[..., perm] return device_mesh @@ -101,6 +97,19 @@ def _tpu_v2_v3_create_device_mesh( return np.asarray(devices).reshape(mesh_shape) +# TODO(b/303712469): Unit test these handler functions. +# Creates a physical ring 0->1->3->2 if on v4i. +def _v4i_create_device_mesh( + mesh_shape: Sequence[int], devices: Sequence[Any], **unused_kwargs +) -> np.ndarray | None: + if len(devices) == 4: + device_mesh = np.asarray(devices) + device_mesh = device_mesh[np.array(_TRAY_2x2_RING_ORDER)] + device_mesh = device_mesh.reshape(mesh_shape) + return device_mesh + return None + + def _v5e_create_device_mesh( mesh_shape: Sequence[int], devices: Sequence[Any], **unused_kwargs ) -> np.ndarray | None: @@ -172,13 +181,50 @@ def _v5p_create_device_mesh( devices, key=lambda d: tuple(reversed(getattr(d, "coords", (0, 0, 0))))) - if bound_x == bound_y == 2 and bound_z == 2: + if bound_x == bound_y == bound_z == 2 and len(devices) == 8: device_mesh = np.asarray(sequential_devices) device_mesh = device_mesh[np.array(_V5P_2x2x2_ORDER)] device_mesh = device_mesh.reshape(mesh_shape) return device_mesh return None +def _7x_create_device_mesh( + mesh_shape: Sequence[int], devices: Sequence[Any], **unused_kwargs +) -> np.ndarray | None: + """Creates device assignment for small 7x topologies. + + The device assignment attempts to minimize the number of hops between + neighbors by allocating rings of devices, and assigns the core axis + preferentially due to its higher bandwidth. + + Args: + mesh_shape: Logical mesh shape used by the model. + devices: TPU devices. + **unused_kwargs: ... + + Returns: + None or reordered devices reshaped as `mesh_shape`. + """ + if len(devices) % 8 != 0 or len(devices) > 32: + return None + + physical_mesh_shape = _get_physical_tpu_mesh(devices).shape + # For the x and y axes, we only support at most 2x2 since we can make one ring + # along those axes and repeat with other separate rings along the z axis. + if physical_mesh_shape[0] > 2 or physical_mesh_shape[1] > 2: + return None + + indices = [] + for i in range(0, len(devices), 8): + new_indices = [x + i for x in _7X_TRAY_2x2x2_RING_ORDER] + indices.extend(new_indices) + + device_mesh = np.asarray(devices) + device_mesh = device_mesh[np.array(indices)] + device_mesh = device_mesh.reshape(mesh_shape) + return device_mesh + + # Registers functions to create device mesh for specific device kinds. Takes # precedence over the more general logic in create_device_mesh(). Handler may # return None; in that case, it will fall back to using the default logic. @@ -188,8 +234,11 @@ def _v5p_create_device_mesh( ] = { _TPU_V2: _tpu_v2_v3_create_device_mesh, _TPU_V3: _tpu_v2_v3_create_device_mesh, + _TPU_V4_LITE: _v4i_create_device_mesh, _TPU_V5_LITE: _v5e_create_device_mesh, _TPU_V5P: _v5p_create_device_mesh, + _TPU_V6_LITE: _v5e_create_device_mesh, + _TPU_7X: _7x_create_device_mesh, } @@ -253,7 +302,7 @@ def _create_device_mesh_for_nd_torus( list(enumerate(mesh_shape)) ): # Preferentially map to more physical axes first for higher bandwidth. - for num_axes in range(3, 0, -1): + for num_axes in range(len(physical_mesh.shape), 0, -1): # Try assign to any subset of size num_axes. Generate all candidates. indices_and_axes = itertools.combinations( enumerate(assignable_physical_mesh), num_axes @@ -652,6 +701,17 @@ def _get_physical_tpu_mesh(jax_devices: Sequence[Any]) -> np.ndarray: coords[1] - min_coords[1], d.core_on_chip - min_cores_per_chip, ] = d + elif (device_kind in (_TPU_7X,) or + (device_kind in (_TPU_V5P,) and cores_per_chip == 2)): + out = np.empty(dims + (cores_per_chip,), dtype=object) + for d in jax_devices: + coords = d.coords + out[ + coords[0] - min_coords[0], + coords[1] - min_coords[1], + coords[2] - min_coords[2], + d.core_on_chip - min_cores_per_chip, + ] = d else: out = np.empty(dims, dtype=object) for d in jax_devices: diff --git a/jax/_src/monitoring.py b/jax/_src/monitoring.py index 99e957733ba2..76c31d3c860a 100644 --- a/jax/_src/monitoring.py +++ b/jax/_src/monitoring.py @@ -46,10 +46,18 @@ def __call__( ) -> None: ... +class ScalarListenerWithMetadata(Protocol): + + def __call__( + self, event: str, value: float | int, **kwargs: str | int, + ) -> None: + ... + _event_listeners: list[EventListenerWithMetadata] = [] _event_duration_secs_listeners: list[EventDurationListenerWithMetadata] = [] _event_time_span_listeners: list[EventTimeSpanListenerWithMetadata] = [] +_scalar_listeners: list[ScalarListenerWithMetadata] = [] def record_event(event: str, **kwargs: str | int) -> None: @@ -81,6 +89,14 @@ def record_event_time_span( callback(event, start_time, end_time, **kwargs) +def record_scalar( + event: str, value: float | int, **kwargs: str | int +) -> None: + """Record a scalar summary value.""" + for callback in _scalar_listeners: + callback(event, value, **kwargs) + + def register_event_listener( callback: EventListenerWithMetadata, ) -> None: @@ -100,6 +116,14 @@ def register_event_duration_secs_listener( """Register a callback to be invoked during record_event_duration_secs().""" _event_duration_secs_listeners.append(callback) + +def register_scalar_listener( + callback : ScalarListenerWithMetadata, +) -> None: + """Register a callback to be invoked during record_scalar().""" + _scalar_listeners.append(callback) + + def get_event_duration_listeners() -> list[EventDurationListenerWithMetadata]: """Get event duration listeners.""" return list(_event_duration_secs_listeners) @@ -114,48 +138,48 @@ def get_event_listeners() -> list[EventListenerWithMetadata]: """Get event listeners.""" return list(_event_listeners) + +def get_scalar_listeners() -> list[ScalarListenerWithMetadata]: + """Get scalar event listeners.""" + return list(_scalar_listeners) + + def clear_event_listeners(): """Clear event listeners.""" global _event_listeners, _event_duration_secs_listeners, _event_time_span_listeners _event_listeners = [] _event_duration_secs_listeners = [] _event_time_span_listeners = [] + _scalar_listeners = [] -def _unregister_event_duration_listener_by_callback( - callback: EventDurationListenerWithMetadata) -> None: - """Unregister an event duration listener by callback. - This function is supposed to be called for testing only. - """ +def unregister_event_duration_listener( + callback: EventDurationListenerWithMetadata, +) -> None: + """Unregister an event duration listener by callback.""" assert callback in _event_duration_secs_listeners _event_duration_secs_listeners.remove(callback) -def _unregister_event_duration_listener_by_index(index: int) -> None: - """Unregister an event duration listener by index. - - This function is supposed to be called for testing only. - """ - size = len(_event_duration_secs_listeners) - assert -size <= index < size - del _event_duration_secs_listeners[index] - -def _unregister_event_time_span_listener_by_callback( +def unregister_event_time_span_listener( callback: EventTimeSpanListenerWithMetadata, ) -> None: - """Unregister an event time span listener by callback. - - This function is supposed to be called for testing only. - """ + """Unregister an event time span listener by callback.""" assert callback in _event_time_span_listeners _event_time_span_listeners.remove(callback) -def _unregister_event_listener_by_callback( - callback: EventListenerWithMetadata) -> None: - """Unregister an event listener by callback. - - This function is supposed to be called for testing only. - """ +def unregister_event_listener( + callback: EventListenerWithMetadata, +) -> None: + """Unregister an event listener by callback.""" assert callback in _event_listeners _event_listeners.remove(callback) + + +def unregister_scalar_listener( + callback: ScalarListenerWithMetadata, +) -> None: + """Unregister a scalar event listener by callback.""" + assert callback in _scalar_listeners + _scalar_listeners.remove(callback) diff --git a/jax/_src/named_sharding.py b/jax/_src/named_sharding.py index 5accdd880a79..5581e6fd5f2e 100644 --- a/jax/_src/named_sharding.py +++ b/jax/_src/named_sharding.py @@ -20,14 +20,13 @@ import functools from typing import Any, Union -from jax._src import config -from jax._src.util import use_cpp_class, cache, use_cpp_method, tuple_insert +from jax._src.util import use_cpp_class, cache, use_cpp_method from jax._src.lib import xla_client as xc from jax._src.lib.mlir.dialects import sdy from jax._src import mesh as mesh_lib -from jax._src.partition_spec import PartitionSpec, UnconstrainedSingleton +from jax._src.mesh import AxisType +from jax._src.partition_spec import PartitionSpec from jax._src import sharding as JSharding -from jax._src import xla_bridge as xb import numpy as np Shape = tuple[int, ...] @@ -41,10 +40,16 @@ class AUTO: def __init__(self, mesh: mesh_lib.Mesh): self.mesh = mesh - def _to_sdy_sharding(self, ndim: int) -> SdyArraySharding: - dim_shardings = [SdyDimSharding(axes=[], is_closed=False) + def _to_sdy_sharding(self, ndim: int) -> SdyArray: + dim_shardings = [SdyDim(axes=[], is_open=True) for _ in range(ndim)] - return SdyArraySharding(self.mesh.shape_tuple, dim_shardings) + return SdyArray(mesh_shape=self.mesh.shape_tuple, + dim_shardings=dim_shardings) + + @property + def _device_assignment(self): + return self.mesh._flat_devices_tuple + class UnspecifiedValue: def __repr__(self): @@ -73,6 +78,11 @@ def __repr__(self): ArrayMappingOrAutoOrUnspecified = Union[ArrayMapping, AUTO, UnspecifiedValue] +def _unpickle_named_sharding(mesh, spec, memory_kind, logical_device_ids): + return NamedSharding(mesh, spec, memory_kind=memory_kind, + _logical_device_ids=logical_device_ids) + + @use_cpp_class(xc.NamedSharding) class NamedSharding(JSharding.Sharding): r"""A :class:`NamedSharding` expresses sharding using named axes. @@ -91,14 +101,14 @@ class NamedSharding(JSharding.Sharding): is sharded across ``x`` axis of the mesh, and the second dimension is sharded across ``y`` axis of the mesh. - The Distributed arrays and automatic parallelization - (https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#namedsharding-gives-a-way-to-express-shardings-with-names) - tutorial has more details and diagrams that explain how - :class:`Mesh` and :class:`PartitionSpec` are used. + The `Distributed arrays and automatic parallelization`_ + and `Explicit Sharding`_ tutorials have more details and diagrams that + explain how :class:`Mesh` and :class:`PartitionSpec` are used. Args: mesh: A :class:`jax.sharding.Mesh` object. spec: A :class:`jax.sharding.PartitionSpec` object. + memory_kind: A string indicating the memory kind of the sharding. Examples: @@ -107,25 +117,25 @@ class NamedSharding(JSharding.Sharding): >>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y')) >>> spec = P('x', 'y') >>> named_sharding = jax.sharding.NamedSharding(mesh, spec) + + .. _Distributed arrays and automatic parallelization: https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html + .. _Explicit Sharding: https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html """ mesh: mesh_lib.Mesh | mesh_lib.AbstractMesh spec: PartitionSpec _memory_kind: str | None - _manual_axes: frozenset[MeshAxisName] _logical_device_ids: tuple[int, ...] | None @use_cpp_method() def __init__( self, mesh: mesh_lib.Mesh | mesh_lib.AbstractMesh, spec: PartitionSpec, *, - memory_kind: str | None = None, _manual_axes=frozenset(), - _logical_device_ids=None): + memory_kind: str | None = None, _logical_device_ids=None): self.mesh = mesh self.spec = spec self._memory_kind = memory_kind - self._manual_axes = _manual_axes self._logical_device_ids = _logical_device_ids - check_pspec(self.mesh, self.spec, self._manual_axes) + check_pspec(self.mesh, self.spec) def __repr__(self): mem = '' if self.memory_kind is None else f', memory_kind={self.memory_kind}' @@ -135,22 +145,21 @@ def __repr__(self): return f'NamedSharding(mesh={mesh_repr}, spec={self.spec}{mem}{ldi})' def __reduce__(self): - return (type(self), (self.mesh, self.spec), - {'memory_kind': self.memory_kind, - '_manual_axes': self._manual_axes, - '_logical_device_ids': self._logical_device_ids}) + return (_unpickle_named_sharding, + (self.mesh, self.spec, self.memory_kind, self._logical_device_ids)) @property def memory_kind(self) -> str | None: return self._memory_kind + @use_cpp_method() def __hash__(self): if not hasattr(self, '_hash'): self._hash = hash( - (self.mesh, self.memory_kind, self.spec, self._manual_axes, - self._logical_device_ids)) + (self.mesh, self.memory_kind, self.spec, self._logical_device_ids)) return self._hash + @use_cpp_method() def __eq__(self, other): if not isinstance(other, NamedSharding): return False @@ -158,7 +167,6 @@ def __eq__(self, other): return True if (self.spec != other.spec or self.memory_kind != other.memory_kind - or self._manual_axes != other._manual_axes or self._logical_device_ids != other._logical_device_ids): return False return self.mesh is other.mesh or self.mesh == other.mesh @@ -195,14 +203,8 @@ def is_fully_addressable(self) -> bool: if isinstance(self.mesh, mesh_lib.AbstractMesh): raise ValueError('is_fully_addressable is not implemented for ' '`jax.sharding.AbstractMesh`.') - # Speed up `is_fully_addressable` since there is a high chance that the - # mesh across multiple NamedSharding objects will be the same. - if config.enable_empty_arrays.value: - client = self._internal_device_list[0].client - return (len(self.mesh._process_indices) == 1 and - next(iter(self.mesh._process_indices)) == - xb.process_index(client)) - return not self.mesh.is_multi_process + # return False if addressable_device_list is empty. + return self._internal_device_list.is_fully_addressable # type: ignore @property def _is_concrete(self) -> bool: @@ -230,31 +232,58 @@ def is_fully_replicated(self) -> bool: num_partitions *= mesh_shape[name] return num_partitions == 1 + @functools.cached_property + def replicated_axes(self) -> frozenset[MeshAxisName]: + flat_spec = frozenset( + s for s in flatten_spec(self.spec) + if s is not None and s is not PartitionSpec.UNCONSTRAINED) + return frozenset(self.mesh.axis_names) - ( + flat_spec | self.spec.unreduced | self.spec.reduced) + def with_memory_kind(self, kind: str) -> NamedSharding: - return NamedSharding(self.mesh, self.spec, memory_kind=kind) + return self.update(memory_kind=kind) - def with_spec(self, spec: PartitionSpec | Sequence[Any]) -> NamedSharding: + def update(self, **kwargs) -> NamedSharding: + spec = kwargs.pop("spec", self.spec) if not isinstance(spec, PartitionSpec): spec = PartitionSpec(*spec) - return NamedSharding(self.mesh, spec, memory_kind=self.memory_kind) + return NamedSharding( + mesh=kwargs.pop("mesh", self.mesh), + spec=spec, + memory_kind=kwargs.pop("memory_kind", self.memory_kind), + _logical_device_ids=kwargs.pop("_logical_device_ids", + self._logical_device_ids)) def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding: return named_sharding_to_xla_hlo_sharding(self, num_dimensions) - def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding: - dim_shardings = [SdyDimSharding(axes=[], is_closed=True) + def _to_sdy_sharding(self, num_dimensions: int) -> SdyArray: + dim_shardings = [SdyDim(axes=[], is_open=False) for _ in range(num_dimensions)] for i, dim_spec in enumerate(self.spec): if dim_spec is PartitionSpec.UNCONSTRAINED: - dim_shardings[i].is_closed = False + dim_shardings[i].is_open = True elif dim_spec is None: # Already empty and closed sharding. pass else: dim_spec = dim_spec if isinstance(dim_spec, tuple) else (dim_spec,) dim_shardings[i].axes = dim_spec - return SdyArraySharding(self.mesh.shape_tuple, dim_shardings, - self._logical_device_ids) + return SdyArray(mesh_shape=self.mesh.shape_tuple, + dim_shardings=dim_shardings, + logical_device_ids=self._logical_device_ids, + unreduced_axes=self.spec.unreduced) + +NamedSharding.__module__ = 'jax.sharding' + +def flatten_spec(spec): + out = [] + for s in spec: + if isinstance(s, tuple): + out.extend(s) + else: + out.append(s) + return out def get_array_mapping( @@ -272,35 +301,40 @@ def get_array_mapping( return d @dataclasses.dataclass -class SdyDimSharding: +class SdyDim: axes: Sequence[str] - is_closed: bool - priority: int | None = None + is_open: bool def build(self) -> sdy.DimensionShardingAttr: return sdy.DimensionShardingAttr.get( [sdy.AxisRefAttr.get(axis) for axis in self.axes], - is_closed=self.is_closed, - priority=self.priority) + is_closed=not self.is_open) def __repr__(self): - return f'SdyDimSharding({self._custom_repr()})' + return f'SdyDim({self._custom_repr()})' def _custom_repr(self): axes_repr = ', '.join(f"'{a}'" for a in self.axes) open_repr = '' - if not self.is_closed: + if self.is_open: open_repr = ', ?' if self.axes else '?' - priority_repr = '' if self.priority is None else f'p{self.priority}' - return f'{{{axes_repr}{open_repr}}}{priority_repr}' - - -@dataclasses.dataclass -class SdyArraySharding: + return f'{{{axes_repr}{open_repr}}}' + +def _get_axes(axes, mesh_shape): + if not axes: + return () + assert mesh_shape is not None + # Sort wrt mesh axis names so order is deterministic and doesn't hang in + # McJAX. + return tuple(n for n, _ in mesh_shape if n in axes) + +@dataclasses.dataclass(kw_only=True) +class SdyArray: mesh_shape: tuple[tuple[str, int], ...] | None - dimension_shardings: Sequence[SdyDimSharding] + dim_shardings: Sequence[SdyDim] logical_device_ids: tuple[int, ...] | None = None replicated_axes: tuple[str, ...] = () + unreduced_axes: frozenset[str] = frozenset() def build(self) -> sdy.TensorShardingAttr: if self.mesh_shape is None: @@ -311,94 +345,42 @@ def build(self) -> sdy.TensorShardingAttr: mesh_attr = sdy.MeshAttr.get( [sdy.MeshAxisAttr.get(name, size) for name, size in self.mesh_shape], ldi) + + replicated_axes = _get_axes(self.replicated_axes, self.mesh_shape) + unreduced_axes = _get_axes(self.unreduced_axes, self.mesh_shape) return sdy.TensorShardingAttr.get( mesh_attr, - [dim_sharding.build() for dim_sharding in self.dimension_shardings], - replicated_axes=[sdy.AxisRefAttr.get(axis) for axis in self.replicated_axes]) + [dim_sharding.build() for dim_sharding in self.dim_shardings], + replicated_axes=[sdy.AxisRefAttr.get(axis) for axis in replicated_axes], + unreduced_axes=[sdy.AxisRefAttr.get(axis) for axis in unreduced_axes]) def __repr__(self): dim_sharding_repr = ', '.join( - d._custom_repr() for d in self.dimension_shardings) + d._custom_repr() for d in self.dim_shardings) device_id_repr = (f', device_ids={self.logical_device_ids}' if self.logical_device_ids is not None else '') rar = (f', replicated_axes={self.replicated_axes}' if self.replicated_axes else '') - return f"SdyArraySharding([{dim_sharding_repr}]{device_id_repr}{rar})" - -# TODO(yashkatariya): Remove this after jax 0.5.2 release -class ParsedPartitionSpec: - __slots__ = ('_user_spec', 'partitions') - - _user_spec: PartitionSpec | None - partitions: tuple[tuple[MeshAxisName, ...] | UnconstrainedSingleton, ...] - - def __init__(self, user_spec, partitions): - self._user_spec = user_spec - assert None not in partitions, partitions - self.partitions = tuple(partitions) - - def get_partition_spec(self) -> PartitionSpec: - if isinstance(self._user_spec, PartitionSpec): - return self._user_spec - else: - return get_single_pspec(self) - - def insert_axis_partitions(self, dim, val): - parts = self.partitions - too_short = dim - len(parts) - if too_short > 0: - parts += ((),) * too_short - new_partitions = tuple_insert(parts, dim, val) - return ParsedPartitionSpec(None, new_partitions) - - @classmethod - def from_user_input( - cls, - entry: PartitionSpec | None, - arg_name: str, - allow_unconstrained_dims: bool = False, - ) -> ParsedPartitionSpec: - if entry is None: - return cls(entry, ()) - if not isinstance(entry, PartitionSpec): - raise TypeError(f"{arg_name} are expected to be " - f"PartitionSpec instances or None, but got {entry}") - axis_specs = [] - for axis_spec in entry: - if axis_spec is None: - axis_spec = () - elif isinstance(axis_spec, (list, tuple)): - axis_spec = tuple(axis_spec) - elif axis_spec is PartitionSpec.UNCONSTRAINED: - if not allow_unconstrained_dims: - raise ValueError(f"Unconstrained dims are not allowed: {entry}") - axis_spec = PartitionSpec.UNCONSTRAINED - else: - axis_spec = (axis_spec,) - axis_specs.append(axis_spec) - new_entry = PartitionSpec( - *[tuple(e) if isinstance(e, (list, tuple)) else e for e in entry]) - return cls(new_entry, axis_specs) + return f"SdyArray([{dim_sharding_repr}]{device_id_repr}{rar})" + + +# TODO(yashkatariya): Upstream this into `_to_sdy_sharding` maybe with an extra +# parameter to it `_to_sdy_sharding(self, ndim, modify_wrt_axis_types=False)` +def modify_sdy_sharding_wrt_axis_types(sdy_sharding: SdyArray, mesh): + if mesh._any_axis_auto: + dim_shardings, used_axes = [], [] # type: ignore + for d in sdy_sharding.dim_shardings: + dim_shardings.append(SdyDim(axes=d.axes, is_open=True)) + used_axes.extend(d.axes) + remaining_axes = set(mesh.axis_names) - set(used_axes) + replicated_axes = tuple(r for r in remaining_axes + if mesh._name_to_type[r] == mesh_lib.AxisType.Explicit) + return SdyArray(mesh_shape=sdy_sharding.mesh_shape, + dim_shardings=dim_shardings, + logical_device_ids=sdy_sharding.logical_device_ids, + replicated_axes=replicated_axes) + return sdy_sharding - def __hash__(self): - return hash(self.partitions) - - def __eq__(self, other): - if not isinstance(other, ParsedPartitionSpec): - return False - return self.partitions == other.partitions - - def __len__(self): - return len(self.partitions) - - def __getitem__(self, i): - return self.partitions[i] - - def __iter__(self): - return iter(self.partitions) - - def __repr__(self): - return f"ParsedPartitionSpec(partitions={self.partitions})" @cache(max_size=4096, trace_context_in_key=False) def named_sharding_to_xla_hlo_sharding( @@ -408,14 +390,18 @@ def named_sharding_to_xla_hlo_sharding( mesh_axis_pos = {name: i for i, name in enumerate(self.mesh.axis_names)} special_axes = {} - mesh_manual_axes = {n for n, t in self.mesh._name_to_type.items() - if t == mesh_lib.AxisType.Manual} - manual_axes = self._manual_axes.union(mesh_manual_axes) + manual_axes = frozenset(self.mesh.manual_axes) if manual_axes: axis_names = self.mesh.axis_names for manual_axis in manual_axes: special_axes[axis_names.index(manual_axis)] = xc.OpSharding.Type.MANUAL + unreduced_axes = self.spec.unreduced + if unreduced_axes: + axis_names = self.mesh.axis_names + for u in unreduced_axes: + special_axes[axis_names.index(u)] = xc.OpSharding.Type.UNREDUCED + replicated_mesh_axes = [] for i, (axis_name, axis_val) in enumerate(mesh_shape.items()): if axis_name not in array_mapping: # type: ignore @@ -432,7 +418,7 @@ def named_sharding_to_xla_hlo_sharding( last_tile_dims = [] if replicated_mesh_axes: - axes_by_type = collections.defaultdict(list) + axes_by_type: dict[Any, list[int]] = collections.defaultdict(list) size_by_type = collections.defaultdict(lambda: 1) # type: ignore assert {x[0] for x in replicated_mesh_axes}.issuperset(set(special_axes.keys())) for i, size in replicated_mesh_axes: @@ -491,21 +477,12 @@ def array_mapping_to_axis_resources(array_mapping: ArrayMapping): partitions.append(None) return PartitionSpec(*partitions) -get_single_pspec = lambda p: array_mapping_to_axis_resources(get_array_mapping(p)) # type: ignore - -# TODO(yashkatariya): Remove this after jax 0.5.2 release -def preprocess(mesh, spec, parsed_pspec, _manual_axes=frozenset()): - if parsed_pspec is None: - spec = PartitionSpec() if spec is None else spec - parsed_pspec = ParsedPartitionSpec.from_user_input( - spec, "NamedSharding spec", allow_unconstrained_dims=True) - _check_unique_resources(parsed_pspec, "NamedSharding spec", mesh) - _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes) - return parsed_pspec +@cache(max_size=128, trace_context_in_key=False) def check_pspec(mesh, spec, _manual_axes=frozenset()): _check_unique_resources(spec, "NamedSharding spec", mesh) - _check_mesh_resource_axis(mesh, spec, _manual_axes) + _check_mesh_resource_axis(mesh, spec) + _check_mesh_unreduced(mesh, spec) class DuplicateSpecError(Exception): def __init__(self, message, mesh, pspec): @@ -517,13 +494,10 @@ def __init__(self, message, mesh, pspec): def __str__(self): return f"{self.message}" -def _check_unique_resources( - pspec: ParsedPartitionSpec | PartitionSpec, arg_name: str, mesh=None, -) -> None: +def _check_unique_resources(pspec: PartitionSpec, arg_name: str, mesh=None + ) -> None: resource_counts: dict[MeshAxisName, int] = {} duplicate = False - pspec = (pspec.get_partition_spec() if isinstance(pspec, ParsedPartitionSpec) - else pspec) for d in pspec: if d is PartitionSpec.UNCONSTRAINED or d is None: continue @@ -542,31 +516,42 @@ def _check_unique_resources( f' for {mesh_lib.show_axes(multiple_uses)}'), mesh=mesh, pspec=pspec) -@cache(max_size=128, trace_context_in_key=False) -def _check_mesh_resource_axis(mesh, pspec, _manual_axes): - pspec = (pspec.get_partition_spec() if isinstance(pspec, ParsedPartitionSpec) - else pspec) +def _check_mesh_resource_axis(mesh, pspec): for p in pspec: if p is PartitionSpec.UNCONSTRAINED or p is None: continue p = p if isinstance(p, tuple) else (p,) for r in p: - if r not in mesh.shape: + if r not in mesh.axis_names: raise ValueError( f"Resource axis: {r} of {pspec} " f"is not found in mesh: {tuple(mesh.shape.keys())}.") - if r in _manual_axes: - raise ValueError( - f"Axis: {r} of {pspec} " - f"is also found in manual_axes: {_manual_axes}.") from None - if not all(mesh._name_to_type[p[0]] == mesh._name_to_type[r] for r in p): - raise ValueError( - 'AxisTypes should be the same in a tuple subset of PartitionSpec:' - f' {pspec}. Got subset {p} with axis' - f' types: ({", ".join(str(mesh._name_to_type[r]) for r in p)})') - if (mesh_lib.AxisType.Auto not in mesh._axis_types_dict and + if (AxisType.Auto not in mesh.axis_types and PartitionSpec.UNCONSTRAINED in pspec): raise ValueError( f'{pspec} cannot contain' ' `P.UNCONSTRAINED` when no mesh axis_types are `Auto`. Got mesh' - f' axis_types: {mesh._axis_types_dict}') + f' axis_types: {mesh.axis_types}') + +def _check_mesh_unreduced(mesh, pspec): + for u in pspec.unreduced: + if u not in mesh.axis_names: + raise ValueError( + f'Unreduced axes {u} is not found in {mesh.axis_names=}. ' + f'Got {pspec=}') + if mesh._name_to_type[u] == AxisType.Auto: + raise ValueError( + 'Unreduced axes can only refer to mesh axes that are of type' + f' `Explicit` or `Manual`. Got unreduced axes: {pspec.unreduced} and' + f' mesh: {mesh}') + + for u in pspec.reduced: + if u not in mesh.axis_names: + raise ValueError( + f'Reduced axes {u} is not found in {mesh.axis_names=}. ' + f'Got {pspec=}') + if mesh._name_to_type[u] == AxisType.Auto: + raise ValueError( + 'Reduced axes can only refer to mesh axes that are of type' + f' `Explicit` or `Manual`. Got reduced axes: {pspec.reduced} and' + f' mesh: {mesh}') diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index 7df0a638e566..050ec7f28f3a 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -19,44 +19,57 @@ from collections.abc import Sequence from functools import partial import operator -import math import numpy as np -from typing import Any, List, Literal +from typing import Any, Literal, overload +import warnings -import jax -import jax.numpy as jnp -from jax import custom_jvp -from jax import lax +from jax._src import api from jax._src import config from jax._src import core +from jax._src import custom_derivatives from jax._src import deprecations from jax._src import dtypes +from jax._src import lax +from jax._src import numpy as jnp from jax._src import util from jax._src.core import AxisName -from jax._src.sharding_impls import NamedSharding, PartitionSpec as P from jax._src.cudnn.fused_attention_stablehlo import ( dot_product_attention as cudnn_dot_product_attention, MaskType) from jax._src.cudnn.scaled_matmul_stablehlo import ( scaled_matmul_wrapper as cudnn_scaled_matmul, scaled_dot_general_wrapper as cudnn_scaled_dot_general, BlockScaleConfig) -from jax._src.interpreters import batching -from jax._src.interpreters import mlir +from jax._src.numpy import einsum as jnp_einsum from jax._src.numpy import util as numpy_util +from jax._src.numpy.reductions import _count +from jax._src.numpy.reductions import Axis +from jax._src.sharding_impls import NamedSharding, PartitionSpec as P from jax._src.typing import Array, ArrayLike, DType, DTypeLike from jax._src.ops.special import logsumexp as _logsumexp -class Unspecified: - def __repr__(self): - return "_UNSPECIFIED" -_UNSPECIFIED = Unspecified() +# activations +@api.jit +def identity(x: ArrayLike) -> Array: + r"""Identity activation function. + Returns the argument unmodified. -# activations + Args: + x : input array + + Returns: + The argument `x` unmodified. -@custom_jvp -@jax.jit + Examples: + >>> jax.nn.identity(jax.numpy.array([-2., -1., -0.5, 0, 0.5, 1., 2.])) + Array([-2. , -1. , -0.5, 0. , 0.5, 1. , 2. ], dtype=float32) + + """ + return numpy_util.ensure_arraylike("identity", x) + +@custom_derivatives.custom_jvp +@api.jit def relu(x: ArrayLike) -> Array: r"""Rectified linear unit activation function. @@ -92,7 +105,7 @@ def relu(x: ArrayLike) -> Array: # For behavior at 0, see https://dl.acm.org/doi/10.5555/3540261.3540297 relu.defjvps(lambda g, ans, x: lax.select(x > 0, g, lax.full_like(g, 0))) -@jax.jit +@api.jit def squareplus(x: ArrayLike, b: ArrayLike = 4) -> Array: r"""Squareplus activation function. @@ -107,14 +120,11 @@ def squareplus(x: ArrayLike, b: ArrayLike = 4) -> Array: x : input array b : smoothness parameter """ - numpy_util.check_arraylike("squareplus", x) - numpy_util.check_arraylike("squareplus", b) - x = jnp.asarray(x) - b = jnp.asarray(b) + x, b = numpy_util.ensure_arraylike("squareplus", x, b) y = x + jnp.sqrt(jnp.square(x) + b) return y / 2 -@jax.jit +@api.jit def softplus(x: ArrayLike) -> Array: r"""Softplus activation function. @@ -128,7 +138,7 @@ def softplus(x: ArrayLike) -> Array: """ return jnp.logaddexp(x, 0) -@jax.jit +@api.jit def sparse_plus(x: ArrayLike) -> Array: r"""Sparse plus function. @@ -150,11 +160,10 @@ def sparse_plus(x: ArrayLike) -> Array: Args: x: input (float) """ - numpy_util.check_arraylike("sparse_plus", x) - x = jnp.asarray(x) + x = numpy_util.ensure_arraylike("sparse_plus", x) return jnp.where(x <= -1.0, 0.0, jnp.where(x >= 1.0, x, (x + 1.0)**2/4)) -@jax.jit +@api.jit def soft_sign(x: ArrayLike) -> Array: r"""Soft-sign activation function. @@ -166,11 +175,10 @@ def soft_sign(x: ArrayLike) -> Array: Args: x : input array """ - numpy_util.check_arraylike("soft_sign", x) - x_arr = jnp.asarray(x) + x_arr = numpy_util.ensure_arraylike("soft_sign", x) return x_arr / (jnp.abs(x_arr) + 1) -@partial(jax.jit, inline=True) +@api.jit(inline=True) def sigmoid(x: ArrayLike) -> Array: r"""Sigmoid activation function. @@ -191,7 +199,7 @@ def sigmoid(x: ArrayLike) -> Array: """ return lax.logistic(x) -@jax.jit +@api.jit def sparse_sigmoid(x: ArrayLike) -> Array: r"""Sparse sigmoid activation function. @@ -223,7 +231,7 @@ def sparse_sigmoid(x: ArrayLike) -> Array: """ return 0.5 * jnp.clip(x + 1.0, 0.0, 2.0) -@jax.jit +@api.jit def silu(x: ArrayLike) -> Array: r"""SiLU (aka swish) activation function. @@ -243,13 +251,12 @@ def silu(x: ArrayLike) -> Array: See also: :func:`sigmoid` """ - numpy_util.check_arraylike("silu", x) - x_arr = jnp.asarray(x) + x_arr = numpy_util.ensure_arraylike("silu", x) return x_arr * sigmoid(x_arr) swish = silu -@jax.jit +@api.jit def mish(x: ArrayLike) -> Array: r"""Mish activation function. @@ -268,11 +275,10 @@ def mish(x: ArrayLike) -> Array: Returns: An array. """ - numpy_util.check_arraylike("mish", x) - x_arr = jnp.asarray(x) + x_arr = numpy_util.ensure_arraylike("mish", x) return x_arr * jnp.tanh(softplus(x_arr)) -@jax.jit +@api.jit def log_sigmoid(x: ArrayLike) -> Array: r"""Log-sigmoid activation function. @@ -290,11 +296,10 @@ def log_sigmoid(x: ArrayLike) -> Array: See also: :func:`sigmoid` """ - numpy_util.check_arraylike("log_sigmoid", x) - x_arr = jnp.asarray(x) + x_arr = numpy_util.ensure_arraylike("log_sigmoid", x) return -softplus(-x_arr) -@jax.jit +@api.jit def elu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Array: r"""Exponential linear unit activation function. @@ -316,13 +321,12 @@ def elu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Array: See also: :func:`selu` """ - numpy_util.check_arraylike("elu", x) - x_arr = jnp.asarray(x) + x_arr = numpy_util.ensure_arraylike("elu", x) return jnp.where(x_arr > 0, x_arr, alpha * jnp.expm1(jnp.where(x_arr > 0, 0., x_arr))) -@jax.jit +@api.jit def leaky_relu(x: ArrayLike, negative_slope: ArrayLike = 1e-2) -> Array: r"""Leaky rectified linear unit activation function. @@ -346,11 +350,10 @@ def leaky_relu(x: ArrayLike, negative_slope: ArrayLike = 1e-2) -> Array: See also: :func:`relu` """ - numpy_util.check_arraylike("leaky_relu", x) - x_arr = jnp.asarray(x) + x_arr = numpy_util.ensure_arraylike("leaky_relu", x) return jnp.where(x_arr >= 0, x_arr, negative_slope * x_arr) -@jax.jit +@api.jit def hard_tanh(x: ArrayLike) -> Array: r"""Hard :math:`\mathrm{tanh}` activation function. @@ -369,11 +372,10 @@ def hard_tanh(x: ArrayLike) -> Array: Returns: An array. """ - numpy_util.check_arraylike("hard_tanh", x) - x_arr = jnp.asarray(x) + x_arr = numpy_util.ensure_arraylike("hard_tanh", x) return jnp.where(x_arr > 1, 1, jnp.where(x_arr < -1, -1, x_arr)) -@jax.jit +@api.jit def celu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Array: r"""Continuously-differentiable exponential linear unit activation. @@ -398,7 +400,7 @@ def celu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Array: """ return jnp.maximum(x, 0.0) + alpha * jnp.expm1(jnp.minimum(x, 0.0) / alpha) -@jax.jit +@api.jit def selu(x: ArrayLike) -> Array: r"""Scaled exponential linear unit activation. @@ -431,7 +433,7 @@ def selu(x: ArrayLike) -> Array: return scale * elu(x, alpha) # TODO(phawkins): this jit was found to change numerics in a test. Debug this. -# @partial(jax.jit, static_argnames=("approximate",)) +# @api.jit(static_argnames=("approximate",)) def gelu(x: ArrayLike, approximate: bool = True) -> Array: r"""Gaussian error linear unit activation function. @@ -466,7 +468,7 @@ def gelu(x: ArrayLike, approximate: bool = True) -> Array: 0.5 * x_arr * (lax.erfc(-x_arr * sqrt_half)), dtype=x_arr.dtype ) -@partial(jax.jit, static_argnames=("axis",)) +@api.jit(static_argnames=("axis",)) def glu(x: ArrayLike, axis: int = -1) -> Array: r"""Gated linear unit activation function. @@ -490,8 +492,7 @@ def glu(x: ArrayLike, axis: int = -1) -> Array: See also: :func:`sigmoid` """ - numpy_util.check_arraylike("glu", x) - x_arr = jnp.asarray(x) + x_arr = numpy_util.ensure_arraylike("glu", x) size = x_arr.shape[axis] assert size % 2 == 0, "axis size must be divisible by 2" x1, x2 = jnp.split(x_arr, 2, axis) @@ -502,11 +503,39 @@ def glu(x: ArrayLike, axis: int = -1) -> Array: logsumexp = _logsumexp -@partial(jax.jit, static_argnames=("axis",)) +@api.jit(static_argnames=("axis", "keepdims")) +def logmeanexp( + x: ArrayLike, + axis: Axis = None, + where: ArrayLike | None = None, + keepdims: bool = False, +) -> Array: + r"""Log mean exp. + + Computes the function: + + .. math:: + \text{logmeanexp}(x) = \log \frac{1}{n} \sum_{i=1}^n \exp x_i = \text{logsumexp}(x) - \log n + + Args: + x: Input array. + axis: Axis or axes along which to reduce. + where: Elements to include in the reduction. Optional. + keepdims: Preserve the dimensions of the input. + Returns: + An array. + See also: + :func:`jax.nn.logsumexp` + """ + lse = _logsumexp(x, axis=axis, where=where, keepdims=keepdims) + count = _count(x, axis=axis, where=where, keepdims=keepdims, dtype=lse.dtype) + return lse - jnp.log(count) + + +@api.jit(static_argnames=("axis",)) def log_softmax(x: ArrayLike, - axis: int | tuple[int, ...] | None = -1, - where: ArrayLike | None = None, - initial: Unspecified = _UNSPECIFIED) -> Array: + axis: Axis = -1, + where: ArrayLike | None = None) -> Array: r"""Log-Softmax function. Computes the logarithm of the :code:`softmax` function, which rescales @@ -519,8 +548,9 @@ def log_softmax(x: ArrayLike, Args: x : input array axis: the axis or axes along which the :code:`log_softmax` should be - computed. Either an integer or a tuple of integers. - where: Elements to include in the :code:`log_softmax`. + computed. Either an integer, tuple of integers, or ``None`` (all axes). + where: Elements to include in the :code:`log_softmax`. The output for any + masked-out element is minus infinity. Returns: An array. @@ -532,29 +562,23 @@ def log_softmax(x: ArrayLike, See also: :func:`softmax` """ - # TODO(jakevdp): remove the initial argument after JAX v0.4.40. - if initial is not _UNSPECIFIED: - raise TypeError("The initial argument to jax.nn.log_softmax was removed in JAX v0.4.36.") - del initial - numpy_util.check_arraylike("log_softmax", x) - x_arr = jnp.asarray(x) - x_max = jnp.max(x_arr, axis, where=where, initial=-jnp.inf, keepdims=True) - x_safe = x_arr if where is None else jnp.where(where, x_arr, -jnp.inf) + x_arr = numpy_util.ensure_arraylike("log_softmax", x) + x_max = jnp.max(x_arr, axis, where=where, initial=-np.inf, keepdims=True) + x_safe = x_arr if where is None else jnp.where(where, x_arr, -np.inf) shifted = x_safe - lax.stop_gradient(x_max) shifted_logsumexp = jnp.log( jnp.sum(jnp.exp(shifted), axis, where=where, keepdims=True)) result = shifted - shifted_logsumexp if where is not None: - return jnp.where(where, result, -jnp.inf) + return jnp.where(where, result, -np.inf) return result # TODO(phawkins): this jit was found to change numerics in a test. Debug this. -# @partial(jax.jit, static_argnames=("axis",)) +# @api.jit(static_argnames=("axis",)) def softmax(x: ArrayLike, - axis: int | tuple[int, ...] | None = -1, - where: ArrayLike | None = None, - initial: Unspecified = _UNSPECIFIED) -> Array: + axis: Axis = -1, + where: ArrayLike | None = None) -> Array: r"""Softmax function. Computes the function which rescales elements to the range :math:`[0, 1]` @@ -567,8 +591,9 @@ def softmax(x: ArrayLike, x : input array axis: the axis or axes along which the softmax should be computed. The softmax output summed across these dimensions should sum to :math:`1`. - Either an integer or a tuple of integers. - where: Elements to include in the :code:`softmax`. + Either an integer, tuple of integers, or ``None`` (all axes). + where: Elements to include in the :code:`softmax`. The output for any + masked-out element is zero. Returns: An array. @@ -580,10 +605,6 @@ def softmax(x: ArrayLike, See also: :func:`log_softmax` """ - # TODO(jakevdp): remove the initial argument after JAX v0.4.40. - if initial is not _UNSPECIFIED: - raise TypeError("The initial argument to jax.nn.softmax was removed in JAX v0.4.36.") - del initial if config.softmax_custom_jvp.value: # mypy is confused by the `functools.partial` application in the definition # of `_softmax` and incorrectly concludes that `_softmax` returns @@ -593,12 +614,12 @@ def softmax(x: ArrayLike, return _softmax_deprecated(x, axis, where) # TODO(mattjj): replace softmax with _softmax when deprecation flag is removed -@partial(jax.custom_jvp, nondiff_argnums=(1,)) +@partial(custom_derivatives.custom_jvp, nondiff_argnums=(1,)) def _softmax( x: ArrayLike, - axis: int | tuple[int, ...] | None = -1, + axis: Axis = -1, where: ArrayLike | None = None, - initial: ArrayLike | None = -jnp.inf) -> Array: + initial: ArrayLike = -np.inf) -> Array: x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True) x_safe = x if where is None else jnp.where(where, x, initial) unnormalized = jnp.exp(x_safe - x_max) @@ -615,9 +636,9 @@ def _softmax_jvp(axis, primals, tangents): def _softmax_deprecated( x: ArrayLike, - axis: int | tuple[int, ...] | None = -1, + axis: Axis = -1, where: ArrayLike | None = None, - initial: ArrayLike | None = -jnp.inf) -> Array: + initial: ArrayLike = -np.inf) -> Array: x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True) x_safe = x if where is None else jnp.where(where, x, initial) unnormalized = jnp.exp(x_safe - lax.stop_gradient(x_max)) @@ -627,14 +648,40 @@ def _softmax_deprecated( return result -@partial(jax.jit, static_argnames=("axis",)) +@api.jit(static_argnames=("axis",)) def standardize(x: ArrayLike, - axis: int | tuple[int, ...] | None = -1, - mean: ArrayLike | None = None, - variance: ArrayLike | None = None, - epsilon: ArrayLike = 1e-5, - where: ArrayLike | None = None) -> Array: - r"""Normalizes an array by subtracting ``mean`` and dividing by :math:`\sqrt{\mathrm{variance}}`.""" + axis: Axis = -1, + mean: ArrayLike | None = None, + variance: ArrayLike | None = None, + epsilon: ArrayLike = 1e-5, + where: ArrayLike | None = None) -> Array: + r"""Standardizes input to zero mean and unit variance. + + The standardization is given by: + + .. math:: + + x_{std} = \frac{x - \langle x\rangle}{\sqrt{\langle(x - \langle x\rangle)^2\rangle + \epsilon}} + + where :math:`\langle x\rangle` indicates the mean of :math:`x`, and :math:`\epsilon` is + a small correction factor introduced to avoid division by zero. + + Args: + x: input array to be standardized. + axis: integer, tuple of integers, or ``None`` (all axes), representing the + axes along which to standardize. Defaults to the last axis (``-1``). + mean: optionally specify the mean used for standardization. If not specified, + then ``x.mean(axis, where=where)`` will be used. + variance: optionally specify the variance used for standardization. If not + specified, then ``x.var(axis, where=where)`` will be used. + epsilon: correction factor added to variance to avoid division by zero; defaults + to ``1E-5``. + where: optional boolean mask specifying which elements to use when computing + the mean and variance. + + Returns: + An array of the same shape as ``x`` containing the standardized input. + """ numpy_util.check_arraylike("standardize", x) numpy_util.check_arraylike_or_none("standardize", mean, variance, where) if mean is None: @@ -646,25 +693,29 @@ def standardize(x: ArrayLike, # when used in neural network normalization layers variance = jnp.mean( jnp.square(x), axis, keepdims=True, where=where) - jnp.square(mean) - return jnp.subtract(x, jnp.asarray(mean)) * lax.rsqrt(jnp.asarray(variance) + epsilon) + # Because we're using a less accurate variance definition, it may return + # negative values. This is problematic for the rsqrt, so we clip to 0. + # Note that this clipping only matters when the variance is vanishingly + # small compared to the mean of x, so the gradient should be unaffected. + variance = jnp.clip(variance, 0) + return jnp.subtract(x, mean) * lax.rsqrt(variance + epsilon) # TODO(slebedev): Change the type of `x` to `ArrayLike`. -@partial(jax.jit, static_argnames=("num_classes", "dtype", "axis")) +@api.jit(static_argnames=("num_classes", "dtype", "axis")) def _one_hot(x: Array, num_classes: int, *, - dtype: Any, axis: int | AxisName) -> Array: + dtype: DTypeLike, axis: int | AxisName) -> Array: num_classes = core.concrete_dim_or_error( num_classes, "The error arose in jax.nn.one_hot argument `num_classes`.") - dtype = dtypes.canonicalize_dtype(dtype) try: - output_pos_axis = util.canonicalize_axis(axis, x.ndim + 1) + output_pos_axis = util.canonicalize_axis(axis, x.ndim + 1) # type: ignore[arg-type] except TypeError: - axis_size = lax.psum(1, axis) + axis_size = lax.axis_size(axis) if num_classes != axis_size: raise ValueError(f"Expected num_classes to match the size of axis {axis}, " f"but {num_classes} != {axis_size}") from None axis_idx = lax.axis_index(axis) - return jnp.asarray(_dot_product_attention_xla == axis_idx, dtype=dtype) + return jnp.asarray(x == axis_idx, dtype=dtype) axis = operator.index(axis) # type: ignore[arg-type] lhs = lax.expand_dims(x, (axis,)) rhs_shape = [1] * x.ndim @@ -677,7 +728,7 @@ def _one_hot(x: Array, num_classes: int, *, # TODO(slebedev): Change the type of `x` to `ArrayLike`. def one_hot(x: Any, num_classes: int, *, - dtype: Any = jnp.float_, axis: int | AxisName = -1) -> Array: + dtype: Any | None = None, axis: int | AxisName = -1) -> Array: """One-hot encodes the given indices. Each index in the input ``x`` is encoded as a vector of zeros of length @@ -705,17 +756,18 @@ def one_hot(x: Any, num_classes: int, *, num_classes, "The error arose in jax.nn.one_hot argument `num_classes`.") x_arr = jnp.asarray(x) - if not jnp.isdtype(x_arr.dtype, "integral"): + if not dtypes.isdtype(x_arr.dtype, "integral"): # Deprecated 2024-12-18 deprecations.warn( 'jax-nn-one-hot-float-input', f"jax.nn.one_hot input should be integer-typed; got dtype={x_arr.dtype}", stacklevel=1) + dtype = dtypes.default_float_dtype() if dtype is None else dtype return _one_hot(x_arr, num_classes, dtype=dtype, axis=axis) -@jax.custom_jvp -@jax.jit +@custom_derivatives.custom_jvp +@api.jit def relu6(x: ArrayLike) -> Array: r"""Rectified Linear Unit 6 activation function. @@ -747,7 +799,7 @@ def relu6(x: ArrayLike) -> Array: relu6.defjvps(lambda g, ans, x: lax.select((x > 0) & (x < 6), g, lax.full_like(g, 0))) -@jax.jit +@api.jit def hard_sigmoid(x: ArrayLike) -> Array: r"""Hard Sigmoid activation function. @@ -767,7 +819,7 @@ def hard_sigmoid(x: ArrayLike) -> Array: """ return relu6(x + 3.) / 6. -@jax.jit +@api.jit def hard_silu(x: ArrayLike) -> Array: r"""Hard SiLU (swish) activation function @@ -788,18 +840,17 @@ def hard_silu(x: ArrayLike) -> Array: See also: :func:`hard_sigmoid` """ - numpy_util.check_arraylike("hard_silu", x) - x_arr = jnp.asarray(x) + x_arr = numpy_util.ensure_arraylike("hard_silu", x) return x_arr * hard_sigmoid(x_arr) hard_swish = hard_silu def _get_large_negative(dtype): - dtype_max = jnp.finfo(dtype).max + dtype_max = dtypes.finfo(dtype).max return jnp.asarray(-0.7 * dtype_max, dtype=dtype) def _get_causal_mask(T, S): - mask = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_)) + mask = jnp.tril(jnp.ones((T, S), dtype=bool)) return mask[None, None, :, :] def _get_window_mask(T: int, S: int, local_window_size: tuple[int, int]): @@ -829,12 +880,12 @@ def _get_padding_mask_encoded(T, q_seqlen): def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen, local_window_size): - if mask is None and not is_causal and q_seqlen is None and kv_seqlen is None: + if mask is None and not is_causal and q_seqlen is None and kv_seqlen is None and local_window_size is None: return logits - combined_mask = jnp.ones_like(logits, dtype=jnp.bool_) + combined_mask = jnp.ones_like(logits, dtype=bool) if mask is not None: - assert mask.dtype == jnp.bool_ + assert mask.dtype == np.dtype(bool) combined_mask = jnp.logical_and(combined_mask, mask) T, S = logits.shape[2], logits.shape[3] @@ -856,16 +907,17 @@ def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen, return padded_logits def _dot_product_attention_core(query, key, value, bias, mask, is_causal, - scale, q_seqlen, kv_seqlen, local_window_size): - logits_dtype = jnp.promote_types(query.dtype, jnp.float32) + scale, q_seqlen, kv_seqlen, local_window_size, + return_residual): + logits_dtype = jnp.promote_types(query.dtype, np.float32) # If the query and logits dtypes are different, then the default precision # can use inconsistent types in the backwards pass # (see https://github.com/jax-ml/jax/issues/24047). - if query.dtype == jnp.bfloat16: - precision = jax.lax.DotAlgorithmPreset.BF16_BF16_F32 - elif query.dtype == jnp.float16: - precision = jax.lax.DotAlgorithmPreset.F16_F16_F32 + if query.dtype == dtypes.bfloat16: + precision = lax.DotAlgorithmPreset.BF16_BF16_F32 + elif query.dtype == np.float16: + precision = lax.DotAlgorithmPreset.F16_F16_F32 # TODO(sbodenstein): Implement this fix for all dtypes. else: precision = None @@ -874,7 +926,7 @@ def _dot_product_attention_core(query, key, value, bias, mask, is_causal, # some GPUs do not support BF16_BF16_F32, and TPU does not support F16_F16_F32. # Use the default precision as a fallback in these cases. try: - logits = jnp.einsum( + logits = jnp_einsum.einsum( "BTNH,BSNH->BNTS", query, key, @@ -882,7 +934,7 @@ def _dot_product_attention_core(query, key, value, bias, mask, is_causal, preferred_element_type=logits_dtype, ) except: # pylint: disable=bare-except - logits = jnp.einsum( + logits = jnp_einsum.einsum( "BTNH,BSNH->BNTS", query, key, @@ -899,13 +951,19 @@ def _dot_product_attention_core(query, key, value, bias, mask, is_causal, local_window_size) # Softmax and it is always carried out in fp32. - padded_logits = padded_logits.astype(jnp.float32) - probs = jax.nn.softmax(padded_logits, axis=-1).astype(key.dtype) + padded_logits = padded_logits.astype(np.float32) + probs = softmax(padded_logits, axis=-1).astype(key.dtype) - encoded = jnp.einsum('BNTS,BSNH->BTNH', probs, value) - if q_seqlen is not None and kv_seqlen is not None: + encoded = jnp_einsum.einsum('BNTS,BSNH->BTNH', probs, value) + if q_seqlen is not None: mask = _get_padding_mask_encoded(encoded.shape[1], q_seqlen) encoded *= mask.astype(encoded.dtype) + + if return_residual: + lse_residual = logsumexp(padded_logits, axis=-1).astype(key.dtype) + lse_residual = jnp.transpose(lse_residual, (0, 2, 1)) # B N T -> B T N + return encoded, lax.stop_gradient(lse_residual) + return encoded def _dot_product_attention_xla( @@ -918,7 +976,8 @@ def _dot_product_attention_xla( scale: float, q_seqlen: Array | None, kv_seqlen: Array | None, - local_window_size: tuple[int, int] | None): + local_window_size: tuple[int, int] | None, + return_residual: bool = False): B, T, N, H = query.shape _, S, K, _ = key.shape @@ -936,78 +995,25 @@ def _reshape_to_grouped(t): return t bias = _reshape_to_grouped(bias) mask = _reshape_to_grouped(mask) - vmapped_fn = jax.vmap( + vmapped_fn = api.vmap( _dot_product_attention_core, - in_axes=(3, None, None, 2, 2, None, None, None, None, None), + in_axes=(3, None, None, 2, 2, None, None, None, None, None, None), out_axes=3, ) - encoded = vmapped_fn(query, key, value, bias, mask, is_causal, scale, - q_seqlen, kv_seqlen, local_window_size) - encoded = jnp.reshape(encoded, (B, T, N, H)) + output = vmapped_fn(query, key, value, bias, mask, is_causal, scale, + q_seqlen, kv_seqlen, local_window_size, return_residual) + + if return_residual: + encoded, lse_residual = output + encoded = jnp.reshape(encoded, (B, T, N, H)) + lse_residual = jnp.reshape(lse_residual, (B, T, N)) + return encoded, lse_residual + + encoded = jnp.reshape(output, (B, T, N, H)) return encoded -def bias_fwd_rule(a, query_head_num): - return bias_fwd_p.bind(a, query_head_num), a -def bias_bwd_rule(query_head_num, res, g): - a = res - if a.shape[0] > 1 or a.shape[-3] != query_head_num: - raise ValueError("cuDNN only supports bias gradient when the batch size is " - f"1 and the head number matches the query, but got " - f"B={a.shape[0]}, N={a.shape[-3]}.") - return (bias_bwd_p.bind(g, a, query_head_num),) - -# This function uses two custom primitives, `bias_fwd` and `bias_bwd`, to work -# around a cuDNN issue where bias gradients are only supported when the batch -# size is 1 and the number of heads matches the query. -# TODO(kaixih@nvidia): Remove this workaround once cuDNN resolves the issue. -@partial(jax.custom_vjp, nondiff_argnums=(1,)) -def check_valid_bias_batch(x, query_head_num): - output, _ = bias_fwd_rule(x, query_head_num) - return output -check_valid_bias_batch.defvjp(bias_fwd_rule, bias_bwd_rule) - -bias_fwd_p = core.Primitive('bias_fwd') -bias_fwd_p.multiple_results = False -bias_bwd_p = core.Primitive('bias_bwd') -bias_bwd_p.multiple_results = False - -def bias_fwd_impl(a, query_head_num): - return a -def bias_bwd_impl(g, a, query_head_num): - return g -bias_fwd_p.def_impl(bias_fwd_impl) -bias_bwd_p.def_impl(bias_bwd_impl) - -def bias_fwd_abstract_eval(a, query_head_num): - return core.ShapedArray(a.shape, a.dtype) -def bias_bwd_abstract_eval(g, a, query_head_num): - return core.ShapedArray(g.shape, g.dtype) -bias_fwd_p.def_abstract_eval(bias_fwd_abstract_eval) -bias_bwd_p.def_abstract_eval(bias_bwd_abstract_eval) - -def bias_fwd_lowering(ctx, a, query_head_num): - return [a] -def bias_bwd_lowering(ctx, g, a, query_head_num): - return [g] -mlir.register_lowering(bias_fwd_p, bias_fwd_lowering) -mlir.register_lowering(bias_bwd_p, bias_bwd_lowering) - -def bias_fwd_batch_rule(batched_args, batch_dims): - x, query_head_num = batched_args - a = batch_dims[0] - output, _ = bias_fwd_rule(x, query_head_num) - return output, a -def bias_bwd_batch_rule(batched_args, batch_dims): - g, x, query_head_num = batched_args - b = batch_dims[0] - *Bs, _, _, _ = x.shape - B = math.prod(Bs) - x = jnp.reshape(x, (B,) + x.shape[-3:]) - output, = bias_bwd_rule(query_head_num, x, g) - return output, b -batching.primitive_batchers[bias_fwd_p] = bias_fwd_batch_rule -batching.primitive_batchers[bias_bwd_p] = bias_bwd_batch_rule +@overload def dot_product_attention( query: ArrayLike, key: ArrayLike, @@ -1020,17 +1026,56 @@ def dot_product_attention( query_seq_lengths: ArrayLike | None = None, key_value_seq_lengths: ArrayLike | None = None, local_window_size: int | tuple[int, int] | None = None, - implementation: Literal['xla', 'cudnn'] | None = None) -> Array: + implementation: Literal['xla', 'cudnn'] | None = None, + return_residual: Literal[False] = ..., +) -> Array: ... + +@overload +def dot_product_attention( + query: ArrayLike, + key: ArrayLike, + value: ArrayLike, + bias: ArrayLike | None = None, + mask: ArrayLike | None = None, + *, + scale: float | None = None, + is_causal: bool = False, + query_seq_lengths: ArrayLike | None = None, + key_value_seq_lengths: ArrayLike | None = None, + local_window_size: int | tuple[int, int] | None = None, + implementation: Literal['xla', 'cudnn'] | None = None, + return_residual: Literal[True] = ..., +) -> tuple[Array, Array]: ... + +def dot_product_attention( + query: ArrayLike, + key: ArrayLike, + value: ArrayLike, + bias: ArrayLike | None = None, + mask: ArrayLike | None = None, + *, + scale: float | None = None, + is_causal: bool = False, + query_seq_lengths: ArrayLike | None = None, + key_value_seq_lengths: ArrayLike | None = None, + local_window_size: int | tuple[int, int] | None = None, + implementation: Literal['xla', 'cudnn'] | None = None, + return_residual: bool = False, +): r"""Scaled dot product attention function. - Computes the attention function on Query, Key, and Value tensors: + Computes the following for each head: .. math:: - \mathrm{Attention}(Q, K, V)=\mathrm{softmax}(\frac{QK^T}{\sqrt{d_k}})V + \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left( \frac{QK^T}{\sqrt{d}} + B \right) V - If we define :code:`logits` as the output of :math:`QK^T` and the - :code:`probs` as the output of :math:`softmax`. + where + :math:`Q` is the query matrix, + :math:`K` is the key matrix, + :math:`V` is the value matrix, + :math:`d` is the dimension of each individual query and key, + and :math:`B` is the bias matrix (optional). Throughout this function, we utilize the following uppercase letters to represent the shape of array:: @@ -1073,18 +1118,24 @@ def dot_product_attention( token's local window. If set, this specifies the (left_window_size, right_window_size) for each token. E.g., if local_window_size == (3, 2) and the sequence is [0, 1, 2, 3, 4, 5, c, 7, 8, 9], token `c` can attend - to [3, 4, 5, c, 7, 8]. If a single int is given, it will be intepreted as + to [3, 4, 5, c, 7, 8]. If a single int is given, it will be interpreted as a symmetric window (window_size, window_size). + return_residual: Whether to return the logsumexp tensor of shape BTN + or BNT to users. See section 3.1.1 in the FlashAttention-2 paper: + https://arxiv.org/pdf/2307.08691 to find the definition of logsumexp. implementation: A string to control which implementation backend to use. Supported strings are `xla`, `cudnn` (cuDNN flash attention). It defaults - to `None`, which will automatically select the best available backend. + to `None`, which currently falls back to `xla`. Note, `cudnn` supports only a subset of shapes/dtypes, and an exception will be thrown if its not supported. Returns: - An array of the attention output with the same shape as :code:`query`. + If return_residual is False, returns an array of the attention output with + the same shape as :code:`query`. If return_residual is True, returns a tuple + of (output, residual). The residual is the shape of BTN|TN. """ output_shape = jnp.asarray(query).shape + residual_shape = output_shape[:-1] def _ensure_4d(t): t = jnp.asarray(t) dims_to_add = 4 - t.ndim @@ -1119,11 +1170,11 @@ def _check_shape_and_dtype(t: Array | None, shape: Sequence[int], B, S, K, H = key_arr.shape _check_shape_and_dtype(value_arr, [B, S, K, H], key_arr.dtype, 'value') _check_shape_and_dtype(query_arr, [B, -1, -1, H], key_arr.dtype, 'query') - _check_shape_and_dtype(mask, [-1] * 4, jnp.bool_, 'mask') + _check_shape_and_dtype(mask, [-1] * 4, np.dtype(bool), 'mask') _check_shape_and_dtype(bias, [-1] * 4, None, 'bias') - _check_shape_and_dtype(query_seq_lengths, [B], jnp.int32, + _check_shape_and_dtype(query_seq_lengths, [B], np.dtype('int32'), 'query_seq_lengths') - _check_shape_and_dtype(key_value_seq_lengths, [B], jnp.int32, + _check_shape_and_dtype(key_value_seq_lengths, [B], np.dtype('int32'), 'key_value_seq_lengths') if query_arr.shape[-2] % K != 0: raise ValueError(f"The number of query heads must be a multiple of " @@ -1138,20 +1189,18 @@ def _check_shape_and_dtype(t: Array | None, shape: Sequence[int], scale=scale_val, q_seqlen=query_seq_lengths, kv_seqlen=key_value_seq_lengths, local_window_size=local_window_size, + return_residual=return_residual, ) case 'cudnn': - if bias is not None: - bias = check_valid_bias_batch(bias, query_arr.shape[-2]) - bias = jnp.asarray(bias) use_padding = ( query_seq_lengths is not None or key_value_seq_lengths is not None ) if use_padding: if query_seq_lengths is None: T = query_arr.shape[1] - query_seq_lengths = jnp.full((B,), T, dtype=jnp.int32) + query_seq_lengths = jnp.full((B,), T, dtype=np.int32) if key_value_seq_lengths is None: - key_value_seq_lengths = jnp.full((B,), S, dtype=jnp.int32) + key_value_seq_lengths = jnp.full((B,), S, dtype=np.int32) mask_type = MaskType.NO_MASK if use_padding and is_causal: @@ -1174,20 +1223,30 @@ def _check_shape_and_dtype(t: Array | None, shape: Sequence[int], out = cudnn_dot_product_attention( query_arr, key_arr, value_arr, bias, mask, query_seq_lengths, key_value_seq_lengths, scale=scale_val, mask_type=mask_type, - sliding_window_length=sliding_window, + sliding_window_length=sliding_window, return_residual=return_residual, ) + if return_residual: + # Regardless of input layout, cudnn always returns residual with + # (B N T) layout. + out, residual = out + residual = jnp.transpose(residual, (0, 2, 1)).astype(out.dtype) + out = (out, residual) case None: - # TODO(kaixih@nvidia) Defaults to XLA for now. Will automatically select - # best backend. + # TODO(kaixih@nvidia) Automatically select the best backend (defaults to XLA for now). out = _dot_product_attention_xla( query_arr, key_arr, value_arr, bias, mask, is_causal=is_causal, scale=scale_val, q_seqlen=query_seq_lengths, kv_seqlen=key_value_seq_lengths, local_window_size=local_window_size, + return_residual=return_residual, ) case _: raise ValueError(f"Unsupported implementation option: {implementation}") + if return_residual: + out, residual = out + return jnp.reshape(out, output_shape), jnp.reshape(residual, residual_shape) + return jnp.reshape(out, output_shape) def scaled_matmul( @@ -1195,103 +1254,244 @@ def scaled_matmul( rhs: Array, lhs_scales: Array, rhs_scales: Array, - preferred_element_type: DTypeLike = jnp.float32, + preferred_element_type: DTypeLike = np.float32, ) -> Array: - r""" - Performs scaled matrix multiplication between two 3D arrays, with scaling - factors applied to the matrices. - .. math:: - \mathrm{ScaledMatmul}(lhs, rhs, lhs_scales, rhs_scales)=lhs_scales \cdot rhs_scales \cdot \mathrm{dot}(lhs, rhs) + r"""Scaled matrix multiplication function. + + Performs block-scaled matmul of `a` and `b` using `a_scales` and `b_scales`. + The last dim is the contracting dim, and block size is inferred. + + Mathematically, this operation is equivalent to:: + + a_block_size = a.shape[-1] // a_scales.shape[-1] + b_block_size = b.shape[-1] // b_scales.shape[-1] + a_scaled = a * jnp.repeat(a_scales, a_block_size, axis=-1) + b_scaled = b * jnp.repeat(b_scales, b_block_size, axis=-1) + jnp.einsum('BMK,BNK->BMN', a_scaled, b_scaled) + Args: - lhs (Array): A 3D array of shape (B, M, K). - rhs (Array): A 3D array of shape (B, N, K). - lhs_scales (Array): A 3D array of shape (B, M, K_block). - rhs_scales (Array): A 3D array of shape (B, N, K_block). - preferred_element_type (DTypeLike, optional): The preferred data type - for the computation. Defaults to `jnp.float32`. + lhs (Array): Operand a, shape (B, M, K). + rhs (Array): Operand b, shape (B, N, K). + lhs_scales (Array): Shape (B, M, K_a), where `K % K_a == 0`. + rhs_scales (Array): Shape (B, N, K_b), where `K % K_b == 0`. + preferred_element_type (DTypeLike, optional): Defaults to `jnp.float32`. + Returns: - Array: A 3D array of shape (B, M, N) representing the scaled matrix - multiplication result. - Raises: - AssertionError: If the number of columns in `lhs` (`lhs_K`) does not - match the number of columns in `rhs` (`rhs_K`). + Array of shape (B, M, N). + Notes: - - The function ensures that the `preferred_element_type` is - danonicalized before passing it to the underlying computation. - - Scaling is applied to the matrices based on the `lhs_scales` and - `rhs_scales` arrays, enabling efficient computations in blocks. + - We currently do not support user-defined `precision` for customizing the + compute data type. It is fixed to `jnp.float32`. + - Block size is inferred as `K // K_a` for `a` and `K // K_b` for `b`. + - To use cuDNN with Nvidia Blackwell GPUs, inputs must match:: + + # mxfp8 + a, b: jnp.float8_e4m3fn | jnp.float8_e5m2 + a_scales, b_scales: jnp.float8_e8m0fnu + block_size: 32 + # nvfp4 + a, b: jnp.float4_e2m1fn + a_scales, b_scales: jnp.float8_e4m3fn + block_size: 16 + + Examples: + + Basic case: + + >>> a = jnp.array([1, 2, 3]).reshape((1, 1, 3)) + >>> b = jnp.array([4, 5, 6]).reshape((1, 1, 3)) + >>> a_scales = jnp.array([0.5]).reshape((1, 1, 1)) + >>> b_scales = jnp.array([0.5]).reshape((1, 1, 1)) + >>> scaled_matmul(a, b, a_scales, b_scales) # doctest: +SKIP + Array([[[8.]]], dtype=float32) + + Using fused cuDNN call on Blackwell GPUs: + + >>> dtype = jnp.float8_e4m3fn + >>> a = jax.random.normal(jax.random.PRNGKey(1), (3, 128, 64), dtype=dtype) + >>> b = jax.random.normal(jax.random.PRNGKey(2), (3, 128, 64), dtype=dtype) + >>> a_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu) + >>> b_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu) + >>> scaled_matmul(a, b, a_scales, b_scales) # doctest: +SKIP """ - B, M, lhs_K = lhs.shape - _, N, rhs_K = rhs.shape - assert lhs_K == rhs_K - _, _, K_block = lhs_scales.shape - - preferred_element_type = dtypes.canonicalize_dtype( - np.dtype(preferred_element_type) + a, b, a_scales, b_scales = lhs, rhs, lhs_scales, rhs_scales + if not all(x.ndim == 3 for x in (a, b, a_scales, b_scales)): + raise ValueError( + "scaled_matmul requires all inputs to be 3-dimensional arrays" + ) + + B_a, M_a, K_a = a.shape + B_b, N_b, K_b = b.shape + if K_a != K_b or B_a != B_b: + raise ValueError( + "scaled_matmul requires inputs a and b to have matching batch (B) " + f"and contract (K) dimensions, but got shapes {a.shape} and " + f"{b.shape}" + ) + + B_as, M_as, K_as = a_scales.shape + B_bs, N_bs, K_bs = b_scales.shape + if K_as != K_bs or B_as != B_bs: + raise ValueError( + "scaled_matmul requires scales to have matching batch (B) and " + f"contract (K) dimensions, but got shapes {a_scales.shape} and " + f"{b_scales.shape}" + ) + + if M_as != M_a or N_bs != N_b: + raise ValueError( + "scaled_matmul requires scales to match non-contract dimensions of " + f"inputs, but got shapes a: {a.shape}, b: {b.shape}, a_scales: " + f"{a_scales.shape}, b_scales: {b_scales.shape}" + ) + + preferred_element_type = dtypes.check_and_canonicalize_user_dtype( + preferred_element_type, "scaled_matmul" ) out = cudnn_scaled_matmul( - lhs, - rhs, - lhs_scales, - rhs_scales, + a, + b, + a_scales, + b_scales, preferred_element_type=preferred_element_type, ) return out +def get_scaled_dot_general_config(mode: Literal['nvfp4', 'mxfp8'], + global_scale: Array | None = None): + r"""Get quantization configs for scaled_dot_general. + + Create quantization configs for the `jax.nn.scaled_dot_general`. + + See Also: + - :func:`jax.nn.scaled_dot_general`: Scaled dot general function. + """ + + if mode == 'nvfp4': + one = jnp.ones((1,), dtype=np.float32) + return BlockScaleConfig( + mode='nvfp4', + block_size=16, + data_type=dtypes.float4_e2m1fn, + scale_type=dtypes.float8_e4m3fn, + global_scale=one if global_scale is None else global_scale, + infer_only=False + ) + elif mode == 'mxfp8': + return BlockScaleConfig( + mode='mxfp8', + block_size=32, + data_type=dtypes.float8_e4m3fn, + scale_type=dtypes.float8_e8m0fnu, + global_scale=None, + infer_only=False + ) + else: + raise ValueError(f"Unsupported mode: {mode}") + def scaled_dot_general( lhs, rhs, dimension_numbers, - preferred_element_type=jnp.float32, - configs: List[BlockScaleConfig] | None = None, + preferred_element_type=np.float32, + configs: list[BlockScaleConfig] | None = None, implementation: Literal['cudnn'] | None = None, ): r"""Scaled dot general operation. - Computes the scaled dot general on lhs, rhs with quanitzation specified by configs: - .. math:: - \widehat{lhs}, s_a=\mathrm{quantize}(lhs) \\ - \widehat{rhs}, s_b=\mathrm{quantize}(rhs) \\ - \mathrm{ScaledDot}(lhs, rhs)=s_a \cdot s_b \cdot \mathrm{dot}(\widehat{lhs}, \widehat{rhs}) + + Performs a generalized dot product with block-scaled quantization on the + lhs and rhs inputs. This operation extends `lax.dot_general` to support + user-defined scaling configurations. + + Essentially, the operation follows:: + + a, a_scales = quantize(lhs, configs[0]) + b, b_scales = quantize(rhs, configs[1]) + c = jax.nn.scaled_matmul(a, b, a_scales, b_scales) + Args: - lhs: Left-hand side input tensor. - rhs: Right-hand side input tensor. - dimension_numbers: A tuple specifying the contraction and batch dimensions - for the dot general operation. Must follow the format: - `((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))`. - preferred_element_type: The preferred output data type. Supported types are - `jnp.float32`, `jnp.bfloat16`, and `jnp.float16`. Defaults to `jnp.float32`. - configs: A list of `BlockScaleConfig` specifying the scaling - configurations for the operation. Defaults to `mxfp8`. - implementation: A string to control which implementation backend to use. - Supported strings are `cudnn` (cuDNN block scaled dot). It defaults - to `None`, which will automatically select the best available backend. + lhs (ArrayLike): Input array. + rhs (ArrayLike): Input array. + dimension_numbers (DotDimensionNumbers): A tuple of two tuples specifying + the contraction and batch dimensions: + `((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))`. + preferred_element_type (DTypeLike, optional): Output data type of the dot + product. Defaults to `jnp.float32`. Other valid types include + `jnp.bfloat16` and `jnp.float16`. + configs (list of BlockScaleConfig, optional): Scaling configurations for + lhs, rhs, and gradients. Users can obtain valid configurations via + `jax.nn.get_scaled_dot_general_config`. Currently, `nvfp4` and `mxfp8` + are supported. If `None`, falls back to `lax.dot_general`. + implementation: str + (Deprecated) Backend selector, now ignored. The system chooses the backend + automatically. Scheduled for removal in future releases. + Returns: - The result of the scaled dot general operation. + Array: The resulting tensor, with batch dimensions first, followed by + non-contracting/non-batch dimensions of lhs, and then those of rhs. + + See Also: + - :func:`jax.nn.scaled_matmul`: Scaled matmul function. + - :func:`jax.lax.dot_general`: General dot product operator. + + Notes: + - Unlike `nn.scaled_matmul`, which assumes quantized low-precision + inputs with explicit scaling factors, this operator takes high-precision + inputs, applies quantization internally, and handles the backward pass. + + Examples: + + Creating config for mxfp8: + + >>> configs = [jax.nn.get_scaled_dot_general_config('mxfp8')] * 3 + + Creating config for nvfp4: + + >>> global_scale = jnp.array([0.5], jnp.float32) + >>> configs = [jax.nn.get_scaled_dot_general_config('nvfp4', global_scale)] * 3 + + Using scaled_dot_general with the configs: + + >>> import functools + >>> scaled_dot_general_fn = functools.partial(jax.nn.scaled_dot_general, configs=configs) + >>> lhs = jax.random.normal(jax.random.PRNGKey(1), (3, 128, 64)) + >>> rhs = jax.random.normal(jax.random.PRNGKey(2), (3, 128, 64)) + >>> out = scaled_dot_general_fn(lhs, rhs, (((2,), (2,)), ((0,), (0,)))) # doctest: +SKIP """ - # Create configs if not provided - if configs is None: - if dtypes.float8_e8m0fnu is None: - raise ValueError("Requires >= ml_dtypes 0.5.0 to support float8_e8m0fnu") - mxfp8_config = BlockScaleConfig( - mode='mxfp8', - block_size=32, - data_type=jnp.float8_e4m3fn, - scale_type=jnp.float8_e8m0fnu, - global_scale=None, - infer_only=False - ) - configs = [mxfp8_config for _ in range(3)] + if implementation is not None: + warnings.warn("Backend selector, now ignored. The system chooses the " + "backend automatically.", DeprecationWarning) - if implementation is None: - implementation = 'cudnn' + if configs is None: + return lax.dot_general(lhs, rhs, dimension_numbers, + preferred_element_type=preferred_element_type) - match implementation: - case 'cudnn': - out = cudnn_scaled_dot_general( - lhs, rhs, dimension_numbers, - preferred_element_type=preferred_element_type, - configs=configs - ) - case _: - raise ValueError(f"Unsupported implementation option: {implementation}") + out = cudnn_scaled_dot_general( + lhs, rhs, dimension_numbers, + preferred_element_type=preferred_element_type, + configs=configs + ) return out + +@custom_derivatives.custom_jvp +@api.jit +def log1mexp(x: ArrayLike) -> Array: + r"""Numerically stable calculation of :math:`\log(1 - \exp(-x))`. + + This function is undefined for :math:`x < 0`. + + Based on `TensorFlow's implementation `_. + + References: + .. [1] Martin Mächler. `Accurately Computing log(1 − exp(−|a|)) Assessed by the Rmpfr package. + `_. + """ + x = numpy_util.ensure_arraylike("log1mexp", x) + c = jnp.log(2.0) + return jnp.where( + x < c, + jnp.log(-jnp.expm1(-x)), + jnp.log1p(-jnp.exp(-x)), + ) + +log1mexp.defjvps(lambda g, ans, x: g / jnp.expm1(x)) diff --git a/jax/_src/nn/initializers.py b/jax/_src/nn/initializers.py index 287e8f039e1d..ab34857d627d 100644 --- a/jax/_src/nn/initializers.py +++ b/jax/_src/nn/initializers.py @@ -22,15 +22,18 @@ from collections.abc import Sequence import math import typing -from typing import Any, Literal, Protocol +from typing import Any, Literal, Protocol, TypeAlias import numpy as np -import jax.numpy as jnp -from jax import random from jax._src import core from jax._src import dtypes -from jax._src.typing import Array, ArrayLike +from jax._src import numpy as jnp +from jax._src import random +from jax._src.named_sharding import NamedSharding +from jax._src.partition_spec import PartitionSpec +from jax._src.sharding_impls import canonicalize_sharding +from jax._src.typing import Array, ArrayLike, DType from jax._src.util import set_module export = set_module('jax.nn.initializers') @@ -41,20 +44,24 @@ DTypeLikeComplex = Any DTypeLikeInexact = Any # DTypeLikeFloat | DTypeLikeComplex RealNumeric = Any # Scalar jnp array or float +OutShardingType: TypeAlias = NamedSharding | PartitionSpec | None @export @typing.runtime_checkable class Initializer(Protocol): + """Protocol for initializers returned by :mod:`jax.nn.initializers` APIs.""" def __call__(self, key: Array, shape: core.Shape, - dtype: DTypeLikeInexact = jnp.float_) -> Array: + dtype: DTypeLikeInexact | None = None, + out_sharding: OutShardingType = None) -> Array: raise NotImplementedError @export def zeros(key: Array, shape: core.Shape, - dtype: DTypeLikeInexact = jnp.float_) -> Array: + dtype: DTypeLikeInexact | None = None, + out_sharding: OutShardingType = None) -> Array: """An initializer that returns a constant array full of zeros. The ``key`` argument is ignored. @@ -64,12 +71,14 @@ def zeros(key: Array, Array([[0., 0., 0.], [0., 0., 0.]], dtype=float32) """ - return jnp.zeros(shape, dtypes.canonicalize_dtype(dtype)) + dtype = dtypes.default_float_dtype() if dtype is None else dtype + return jnp.zeros(shape, dtype, out_sharding=out_sharding) @export def ones(key: Array, shape: core.Shape, - dtype: DTypeLikeInexact = jnp.float_) -> Array: + dtype: DTypeLikeInexact | None = None, + out_sharding: OutShardingType = None) -> Array: """An initializer that returns a constant array full of ones. The ``key`` argument is ignored. @@ -80,12 +89,12 @@ def ones(key: Array, [1., 1.], [1., 1.]], dtype=float32) """ - return jnp.ones(shape, dtypes.canonicalize_dtype(dtype)) + dtype = dtypes.default_float_dtype() if dtype is None else dtype + return jnp.ones(shape, dtype, out_sharding=out_sharding) @export def constant(value: ArrayLike, - dtype: DTypeLikeInexact = jnp.float_ - ) -> Initializer: + dtype: DTypeLikeInexact | None = None) -> Initializer: """Builds an initializer that returns arrays full of a constant ``value``. Args: @@ -100,14 +109,16 @@ def constant(value: ArrayLike, """ def init(key: Array, shape: core.Shape, - dtype: DTypeLikeInexact = dtype) -> Array: - dtype = dtypes.canonicalize_dtype(dtype) - return jnp.full(shape, value, dtype=dtype) + dtype: DTypeLikeInexact | None = dtype, + out_sharding: OutShardingType = None) -> Array: + dtype = dtypes.default_float_dtype() if dtype is None else dtype + out_sharding = canonicalize_sharding(out_sharding, 'nn.initializers.constant') + return jnp.full(shape, value, dtype=dtype, device=out_sharding) return init @export def uniform(scale: RealNumeric = 1e-2, - dtype: DTypeLikeInexact = jnp.float_) -> Initializer: + dtype: DTypeLikeInexact | None = None) -> Initializer: """Builds an initializer that returns real uniformly-distributed random arrays. Args: @@ -126,14 +137,16 @@ def uniform(scale: RealNumeric = 1e-2, """ def init(key: Array, shape: core.Shape, - dtype: DTypeLikeInexact = dtype) -> Array: - dtype = dtypes.canonicalize_dtype(dtype) - return random.uniform(key, shape, dtype) * jnp.array(scale, dtype) + dtype: DTypeLikeInexact | None = dtype, + out_sharding: OutShardingType = None) -> Array: + dtype = dtypes.default_float_dtype() if dtype is None else dtype + return random.uniform(key, shape, dtype, + out_sharding=out_sharding) * jnp.array(scale, dtype) return init @export def normal(stddev: RealNumeric = 1e-2, - dtype: DTypeLikeInexact = jnp.float_) -> Initializer: + dtype: DTypeLikeInexact | None = None) -> Initializer: """Builds an initializer that returns real normally-distributed random arrays. Args: @@ -152,14 +165,16 @@ def normal(stddev: RealNumeric = 1e-2, """ def init(key: Array, shape: core.Shape, - dtype: DTypeLikeInexact = dtype) -> Array: - dtype = dtypes.canonicalize_dtype(dtype) - return random.normal(key, shape, dtype) * jnp.array(stddev, dtype) + dtype: DTypeLikeInexact | None = dtype, + out_sharding: OutShardingType = None) -> Array: + dtype = dtypes.default_float_dtype() if dtype is None else dtype + return random.normal(key, shape, dtype, + out_sharding=out_sharding) * jnp.array(stddev, dtype) return init @export def truncated_normal(stddev: RealNumeric = 1e-2, - dtype: DTypeLikeInexact = jnp.float_, + dtype: DTypeLikeInexact | None = None, lower: RealNumeric = -2.0, upper: RealNumeric = 2.0) -> Initializer: r"""Builds an initializer that returns truncated-normal random arrays. @@ -186,13 +201,14 @@ def truncated_normal(stddev: RealNumeric = 1e-2, Array([[ 2.9047365, 5.2338114, 5.29852 ], [-3.836303 , -4.192359 , 0.6022964]], dtype=float32) """ - def init(key: Array, shape: core.Shape, - dtype: DTypeLikeInexact = dtype) -> Array: - dtype = dtypes.canonicalize_dtype(dtype) + dtype: DTypeLikeInexact | None = dtype, + out_sharding: OutShardingType = None) -> Array: + dtype = dtypes.default_float_dtype() if dtype is None else dtype return random.truncated_normal( - key, lower, upper, shape, dtype) * jnp.array(stddev, dtype) + key, lower, upper, shape, dtype, + out_sharding=out_sharding) * jnp.array(stddev, dtype) return init @export @@ -207,9 +223,12 @@ def _compute_fans(shape: Sequence[int], Axes not in in_axis, out_axis, or batch_axis are assumed to constitute the "receptive field" of a convolution (kernel spatial dimensions). """ - if len(shape) <= 1: - raise ValueError(f"Can't compute input and output sizes of a {len(shape)}" - "-dimensional weights tensor. Must be at least 2D.") + if isinstance(in_axis, int) and in_axis == -2 and len(shape) <= 1: + raise ValueError( + f"Can't compute input and output sizes of a {len(shape)}-dimensional" + " weights tensor with default in_axis. Must be at least 2D or specify" + " in_axis explicitly." + ) if isinstance(in_axis, int): in_size = shape[in_axis] @@ -230,7 +249,7 @@ def _compute_fans(shape: Sequence[int], def _complex_uniform(key: Array, shape: Sequence[int], - dtype: DTypeLikeInexact) -> Array: + dtype: DType) -> Array: """ Sample uniform random values within a disk on the complex plane, with zero mean and unit variance. @@ -239,12 +258,12 @@ def _complex_uniform(key: Array, real_dtype = np.array(0, dtype).real.dtype dtype = dtypes.to_complex_dtype(real_dtype) r = jnp.sqrt(2 * random.uniform(key_r, shape, real_dtype)).astype(dtype) - theta = 2 * jnp.pi * random.uniform(key_theta, shape, real_dtype).astype(dtype) + theta = 2 * np.pi * random.uniform(key_theta, shape, real_dtype).astype(dtype) return r * jnp.exp(1j * theta) def _complex_truncated_normal(key: Array, upper: ArrayLike, shape: Sequence[int], - dtype: DTypeLikeInexact) -> Array: + dtype: DType) -> Array: """ Sample random values from a centered normal distribution on the complex plane, whose modulus is truncated to `upper`, and the variance before the truncation @@ -256,7 +275,7 @@ def _complex_truncated_normal(key: Array, upper: ArrayLike, t = ((1 - jnp.exp(jnp.array(-(upper ** 2), dtype))) * random.uniform(key_r, shape, real_dtype).astype(dtype)) r = jnp.sqrt(-jnp.log(1 - t)) - theta = 2 * jnp.pi * random.uniform(key_theta, shape, real_dtype).astype(dtype) + theta = 2 * np.pi * random.uniform(key_theta, shape, real_dtype).astype(dtype) return r * jnp.exp(1j * theta) @export @@ -267,8 +286,8 @@ def variance_scaling( Literal["uniform"]), in_axis: int | Sequence[int] = -2, out_axis: int | Sequence[int] = -1, - batch_axis: Sequence[int] = (), - dtype: DTypeLikeInexact = jnp.float_ + batch_axis: int | Sequence[int] = (), + dtype: DTypeLikeInexact | None = None ) -> Initializer: r""" Initializer that adapts its scale to the shape of the weights tensor. @@ -312,12 +331,12 @@ def variance_scaling( ignored. dtype: the dtype of the weights. """ - def init(key: Array, shape: core.Shape, - dtype: DTypeLikeInexact = dtype) -> Array: + dtype: DTypeLikeInexact | None = dtype, + out_sharding: OutShardingType = None) -> Array: shape = core.canonicalize_shape(shape) - dtype = dtypes.canonicalize_dtype(dtype) + dtype = dtypes.default_float_dtype() if dtype is None else dtype fan_in, fan_out = _compute_fans(shape, in_axis, out_axis, batch_axis) if mode == "fan_in": denominator = fan_in elif mode == "fan_out": denominator = fan_out @@ -329,19 +348,22 @@ def init(key: Array, variance = jnp.array(scale / denominator, dtype=dtype) if distribution == "truncated_normal": - if jnp.issubdtype(dtype, jnp.floating): + if dtypes.issubdtype(dtype, np.floating): # constant is stddev of standard normal truncated to (-2, 2) stddev = jnp.sqrt(variance) / jnp.array(.87962566103423978, dtype) - return random.truncated_normal(key, -2, 2, shape, dtype) * stddev + return random.truncated_normal(key, -2, 2, shape, dtype, + out_sharding=out_sharding) * stddev else: # constant is stddev of complex standard normal truncated to 2 stddev = jnp.sqrt(variance) / jnp.array(.95311164380491208, dtype) return _complex_truncated_normal(key, 2, shape, dtype) * stddev elif distribution == "normal": - return random.normal(key, shape, dtype) * jnp.sqrt(variance) + return random.normal(key, shape, dtype, + out_sharding=out_sharding) * jnp.sqrt(variance) elif distribution == "uniform": - if jnp.issubdtype(dtype, jnp.floating): - return random.uniform(key, shape, dtype, -1) * jnp.sqrt(3 * variance) + if dtypes.issubdtype(dtype, np.floating): + return random.uniform(key, shape, dtype, -1, + out_sharding=out_sharding) * jnp.sqrt(3 * variance) else: return _complex_uniform(key, shape, dtype) * jnp.sqrt(variance) else: @@ -352,8 +374,8 @@ def init(key: Array, @export def glorot_uniform(in_axis: int | Sequence[int] = -2, out_axis: int | Sequence[int] = -1, - batch_axis: Sequence[int] = (), - dtype: DTypeLikeInexact = jnp.float_) -> Initializer: + batch_axis: int | Sequence[int] = (), + dtype: DTypeLikeInexact | None = None) -> Initializer: """Builds a Glorot uniform initializer (aka Xavier uniform initializer). A `Glorot uniform initializer`_ is a specialization of @@ -390,8 +412,8 @@ def glorot_uniform(in_axis: int | Sequence[int] = -2, @export def glorot_normal(in_axis: int | Sequence[int] = -2, out_axis: int | Sequence[int] = -1, - batch_axis: Sequence[int] = (), - dtype: DTypeLikeInexact = jnp.float_) -> Initializer: + batch_axis: int | Sequence[int] = (), + dtype: DTypeLikeInexact | None = None) -> Initializer: """Builds a Glorot normal initializer (aka Xavier normal initializer). A `Glorot normal initializer`_ is a specialization of @@ -428,8 +450,8 @@ def glorot_normal(in_axis: int | Sequence[int] = -2, @export def lecun_uniform(in_axis: int | Sequence[int] = -2, out_axis: int | Sequence[int] = -1, - batch_axis: Sequence[int] = (), - dtype: DTypeLikeInexact = jnp.float_) -> Initializer: + batch_axis: int | Sequence[int] = (), + dtype: DTypeLikeInexact | None = None) -> Initializer: """Builds a Lecun uniform initializer. A `Lecun uniform initializer`_ is a specialization of @@ -464,8 +486,8 @@ def lecun_uniform(in_axis: int | Sequence[int] = -2, @export def lecun_normal(in_axis: int | Sequence[int] = -2, out_axis: int | Sequence[int] = -1, - batch_axis: Sequence[int] = (), - dtype: DTypeLikeInexact = jnp.float_) -> Initializer: + batch_axis: int | Sequence[int] = (), + dtype: DTypeLikeInexact | None = None) -> Initializer: """Builds a Lecun normal initializer. A `Lecun normal initializer`_ is a specialization of @@ -500,8 +522,8 @@ def lecun_normal(in_axis: int | Sequence[int] = -2, @export def he_uniform(in_axis: int | Sequence[int] = -2, out_axis: int | Sequence[int] = -1, - batch_axis: Sequence[int] = (), - dtype: DTypeLikeInexact = jnp.float_) -> Initializer: + batch_axis: int | Sequence[int] = (), + dtype: DTypeLikeInexact | None = None) -> Initializer: """Builds a He uniform initializer (aka Kaiming uniform initializer). A `He uniform initializer`_ is a specialization of @@ -538,8 +560,8 @@ def he_uniform(in_axis: int | Sequence[int] = -2, @export def he_normal(in_axis: int | Sequence[int] = -2, out_axis: int | Sequence[int] = -1, - batch_axis: Sequence[int] = (), - dtype: DTypeLikeInexact = jnp.float_) -> Initializer: + batch_axis: int | Sequence[int] = (), + dtype: DTypeLikeInexact | None = None) -> Initializer: """Builds a He normal initializer (aka Kaiming normal initializer). A `He normal initializer`_ is a specialization of @@ -576,7 +598,7 @@ def he_normal(in_axis: int | Sequence[int] = -2, @export def orthogonal(scale: RealNumeric = 1.0, column_axis: int = -1, - dtype: DTypeLikeInexact = jnp.float_) -> Initializer: + dtype: DTypeLikeInexact | None = None) -> Initializer: """ Builds an initializer that returns uniformly distributed orthogonal matrices. @@ -601,8 +623,11 @@ def orthogonal(scale: RealNumeric = 1.0, """ def init(key: Array, shape: core.Shape, - dtype: DTypeLikeInexact = dtype) -> Array: - dtype = dtypes.canonicalize_dtype(dtype) + dtype: DTypeLikeInexact | None = dtype, + out_sharding: OutShardingType = None) -> Array: + if out_sharding is not None: + raise NotImplementedError + dtype = dtypes.default_float_dtype() if dtype is None else dtype if len(shape) < 2: raise ValueError("orthogonal initializer requires at least a 2D shape") n_rows, n_cols = math.prod(shape) // shape[column_axis], shape[column_axis] @@ -616,7 +641,7 @@ def init(key: Array, def delta_orthogonal( scale: RealNumeric = 1.0, column_axis: int = -1, - dtype: DTypeLikeInexact = jnp.float_) -> Initializer: + dtype: DTypeLikeInexact | None = None) -> Initializer: """ Builds an initializer for delta orthogonal kernels. @@ -651,8 +676,11 @@ def delta_orthogonal( """ def init(key: Array, shape: core.Shape, - dtype: DTypeLikeInexact = dtype) -> Array: - dtype = dtypes.canonicalize_dtype(dtype) + dtype: DTypeLikeInexact | None = dtype, + out_sharding: OutShardingType = None) -> Array: + if out_sharding is not None: + raise NotImplementedError + dtype = dtypes.default_float_dtype() if dtype is None else dtype if len(shape) not in [3, 4, 5]: raise ValueError("Delta orthogonal initializer requires a 3D, 4D or 5D " "shape.") diff --git a/jax/_src/numpy/__init__.py b/jax/_src/numpy/__init__.py index 0a0a8260408e..1f1d115964c2 100644 --- a/jax/_src/numpy/__init__.py +++ b/jax/_src/numpy/__init__.py @@ -11,3 +11,224 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +# The following are a subset of the full jax.numpy functionality used by +# internal imports. + +from jax._src.numpy.array_constructors import ( + asarray as asarray, + array as array, +) + +from jax._src.numpy.array_creation import ( + empty as empty, + empty_like as empty_like, + full as full, + full_like as full_like, + linspace as linspace, + ones as ones, + ones_like as ones_like, + zeros as zeros, + zeros_like as zeros_like, +) + +from jax._src.numpy.indexing import ( + take as take, +) + +from numpy import ( + nan as nan, + inf as inf, + floating as floating, + integer as integer, + inexact as inexact, + complexfloating as complexfloating, + number as number, + character as character, + generic as generic, + dtype as dtype, + unsignedinteger as unsignedinteger, +) +from jax._src import dtypes as dtypes +from jax._src.numpy.scalar_types import ( + bfloat16 as bfloat16, + bool_ as bool, # Array API alias for bool_ # noqa: F401 + bool_ as bool_, + cdouble as cdouble, + csingle as csingle, + complex128 as complex128, + complex64 as complex64, + complex_ as complex_, + double as double, + float16 as float16, + float32 as float32, + float4_e2m1fn as float4_e2m1fn, + float64 as float64, + float8_e3m4 as float8_e3m4, + float8_e4m3 as float8_e4m3, + float8_e4m3b11fnuz as float8_e4m3b11fnuz, + float8_e4m3fn as float8_e4m3fn, + float8_e4m3fnuz as float8_e4m3fnuz, + float8_e5m2 as float8_e5m2, + float8_e5m2fnuz as float8_e5m2fnuz, + float8_e8m0fnu as float8_e8m0fnu, + float_ as float_, + int2 as int2, + int4 as int4, + int8 as int8, + int16 as int16, + int32 as int32, + int64 as int64, + int_ as int_, + single as single, + uint as uint, + uint2 as uint2, + uint4 as uint4, + uint8 as uint8, + uint16 as uint16, + uint32 as uint32, + uint64 as uint64, +) + +from jax._src.numpy.lax_numpy import ( + apply_along_axis as apply_along_axis, + arange as arange, + argmax as argmax, + argmin as argmin, + argsort as argsort, + astype as astype, + atleast_1d as atleast_1d, + atleast_2d as atleast_2d, + block as block, + broadcast_arrays as broadcast_arrays, + broadcast_shapes as broadcast_shapes, + broadcast_to as broadcast_to, + clip as clip, + concatenate as concatenate, + cov as cov, + cross as cross, + diag as diag, + diag_indices as diag_indices, + diff as diff, + digitize as digitize, + expand_dims as expand_dims, + eye as eye, + flip as flip, + hstack as hstack, + iscomplexobj as iscomplexobj, + isscalar as isscalar, + issubdtype as issubdtype, + iinfo as iinfo, + meshgrid as meshgrid, + moveaxis as moveaxis, + nonzero as nonzero, + pad as pad, + permute_dims as permute_dims, + piecewise as piecewise, + promote_types as promote_types, + ravel as ravel, + reshape as reshape, + repeat as repeat, + roll as roll, + round as round, + result_type as result_type, + searchsorted as searchsorted, + select as select, + split as split, + squeeze as squeeze, + stack as stack, + trace as trace, + transpose as transpose, + tril as tril, + triu as triu, + triu_indices as triu_indices, + unravel_index as unravel_index, + vstack as vstack, + where as where, +) + +from jax._src.numpy.index_tricks import ( + ogrid as ogrid, +) + +from jax._src.numpy.polynomial import ( + polyval as polyval +) + +from jax._src.numpy.reductions import ( + all as all, + amax as amax, + amin as amin, + any as any, + cumprod as cumprod, + cumsum as cumsum, + max as max, + mean as mean, + median as median, + nanstd as nanstd, + prod as prod, + sum as sum, +) + +from jax._src.numpy.setops import ( + unique as unique, +) + +from jax._src.numpy.tensor_contractions import ( + dot as dot, + matmul as matmul, + vdot as vdot, +) + +from jax._src.numpy.ufuncs import ( + abs as abs, + arctan2 as arctan2, + bitwise_and as bitwise_and, + cbrt as cbrt, + ceil as ceil, + conj as conj, + conjugate as conjugate, + cos as cos, + deg2rad as deg2rad, + divide as divide, + equal as equal, + exp as exp, + expm1 as expm1, + floor as floor, + floor_divide as floor_divide, + fmod as fmod, + greater as greater, + greater_equal as greater_equal, + hypot as hypot, + imag as imag, + isinf as isinf, + isfinite as isfinite, + isnan as isnan, + less as less, + less_equal as less_equal, + log as log, + log1p as log1p, + log2 as log2, + logaddexp as logaddexp, + logical_and as logical_and, + logical_not as logical_not, + logical_or as logical_or, + maximum as maximum, + minimum as minimum, + mod as mod, + not_equal as not_equal, + power as power, + rad2deg as rad2deg, + real as real, + reciprocal as reciprocal, + sign as sign, + signbit as signbit, + sin as sin, + sqrt as sqrt, + square as square, + subtract as subtract, + tanh as tanh +) +from jax._src.numpy.ufuncs import ( + multiply as multiply, +) diff --git a/jax/_src/numpy/array_api_metadata.py b/jax/_src/numpy/array_api_metadata.py index 4a01f579a67e..a4a5496bdc73 100644 --- a/jax/_src/numpy/array_api_metadata.py +++ b/jax/_src/numpy/array_api_metadata.py @@ -21,13 +21,14 @@ from types import ModuleType -import jax -from jax._src.sharding import Sharding +from jax._src import config +from jax._src import dtypes as _dtypes +from jax._src import xla_bridge as xb from jax._src.lib import xla_client as xc -from jax._src import dtypes as _dtypes, config +from jax._src.sharding import Sharding -__array_api_version__ = '2023.12' +__array_api_version__ = '2024.12' def __array_namespace__(self, *, api_version: None | str = None) -> ModuleType: @@ -38,7 +39,8 @@ def __array_namespace__(self, *, api_version: None | str = None) -> ModuleType: if api_version is not None and api_version != __array_api_version__: raise ValueError(f"{api_version=!r} is not available; " f"available versions are: {[__array_api_version__]}") - return jax.numpy + import jax.numpy # pytype: disable=import-error + return jax.numpy # pytype: disable=module-attr def __array_namespace_info__() -> ArrayNamespaceInfo: @@ -51,8 +53,9 @@ class ArrayNamespaceInfo: .. _Python array API: https://data-apis.org/array-api/ """ _capabilities = { - "boolean indexing": True, - "data-dependent shapes": False, + "boolean indexing": False, # within transformations + "data-dependent shapes": False, # within transformations + "max dimensions": 64, # XLA limitation } def _build_dtype_dict(self): @@ -72,7 +75,10 @@ def default_device(self): return None def devices(self): - return jax.devices() + out = [None] # None indicates "uncommitted" + for backend in xb.backends(): + out.extend(xb.devices(backend)) + return out def capabilities(self): return self._capabilities @@ -80,16 +86,11 @@ def capabilities(self): def default_dtypes(self, *, device: xc.Device | Sharding | None = None): # Array API supported dtypes are device-independent in JAX del device - default_dtypes = { - "real floating": "f", - "complex floating": "c", - "integral": "i", - "indexing": "i", - } return { - dtype_name: _dtypes.canonicalize_dtype( - _dtypes._default_types.get(kind) - ) for dtype_name, kind in default_dtypes.items() + "real floating": _dtypes.default_float_dtype(), + "complex floating": _dtypes.default_complex_dtype(), + "integral": _dtypes.default_int_dtype(), + "indexing": _dtypes.default_int_dtype(), } def dtypes( diff --git a/jax/_src/numpy/array_constructors.py b/jax/_src/numpy/array_constructors.py new file mode 100644 index 000000000000..759286392966 --- /dev/null +++ b/jax/_src/numpy/array_constructors.py @@ -0,0 +1,428 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import logging +from typing import Any + +import numpy as np + +from jax._src import api +from jax._src import config +from jax._src import core +from jax._src import dtypes +from jax._src import literals +from jax._src import tree_util +from jax._src import xla_bridge +from jax._src.lax import lax +from jax._src.lib import xla_client as xc +from jax._src.numpy import util +from jax._src.typing import Array, ArrayLike, DTypeLike +from jax._src.sharding import Sharding +from jax._src.sharding_impls import NamedSharding, PartitionSpec as P + +logger = logging.getLogger(__name__) + +export = util.set_module('jax.numpy') + +for pkg_name in ['jax_cuda13_plugin', 'jax_cuda12_plugin', 'jaxlib.cuda']: + try: + cuda_plugin_extension = importlib.import_module( + f'{pkg_name}.cuda_plugin_extension' + ) + except ImportError: + cuda_plugin_extension = None # type: ignore + else: + break + +# Dynamically find and load ROCm plugin extension +rocm_plugin_extension = None +try: + from importlib.metadata import distributions + for dist in distributions(): + name = dist.metadata.get('Name', '') + if name.startswith('jax-rocm') and name.endswith('-plugin'): + module_name = name.replace('-', '_') + try: + rocm_plugin_extension = importlib.import_module( + f'{module_name}.rocm_plugin_extension' + ) + break + except ImportError: + continue +except Exception as e: + logger.debug("ROCm plugin discovery failed: %s", e) + + +def _supports_buffer_protocol(obj): + try: + view = memoryview(obj) + except TypeError: + return False + else: + return True + + +def _make_string_array( + object: np.ndarray, + dtype: DTypeLike | None = None, + ndmin: int = 0, + device: xc.Device | Sharding | None = None, +) -> Array: + if not isinstance(object, np.ndarray): + raise TypeError( + "Currently, string arrays can only be made from NumPy" + f" arrays. Got: {type(object)}." + ) + if dtype is not None and ( + dtypes.is_string_dtype(object.dtype) != dtypes.is_string_dtype(dtype) + ): + raise TypeError( + f"Cannot make an array with dtype {dtype} from an object with dtype" + f" {object.dtype}." + ) + if ndmin > object.ndim: + raise TypeError( + f"ndmin {ndmin} cannot be greater than object's ndims" + f" {object.ndim} for string arrays." + ) + + # Just do a device_put since XLA does not support string as a data type. + return api.device_put(x=object, device=device) + + +@export +def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, + order: str | None = "K", ndmin: int = 0, + *, device: xc.Device | Sharding | None = None, + out_sharding: NamedSharding | P | None = None) -> Array: + """Convert an object to a JAX array. + + JAX implementation of :func:`numpy.array`. + + Args: + object: an object that is convertible to an array. This includes JAX + arrays, NumPy arrays, Python scalars, Python collections like lists + and tuples, objects with a ``__jax_array__`` method, and objects + supporting the Python buffer protocol. + dtype: optionally specify the dtype of the output array. If not + specified it will be inferred from the input. + copy: specify whether to force a copy of the input. Default: True. + order: not implemented in JAX + ndmin: integer specifying the minimum number of dimensions in the + output array. + device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. + out_sharding: (optional) :class:`~jax.sharding.PartitionSpec` or :class:`~jax.NamedSharding` + representing the sharding of the created array (see `explicit sharding`_ for more details). + This argument exists for consistency with other array creation routines across JAX. + Specifying both ``out_sharding`` and ``device`` will result in an error. + + Returns: + A JAX array constructed from the input. + + See also: + - :func:`jax.numpy.asarray`: like `array`, but by default only copies + when necessary. + - :func:`jax.numpy.from_dlpack`: construct a JAX array from an object + that implements the dlpack interface. + - :func:`jax.numpy.frombuffer`: construct a JAX array from an object + that implements the buffer interface. + + Examples: + Constructing JAX arrays from Python scalars: + + >>> jnp.array(True) + Array(True, dtype=bool) + >>> jnp.array(42) + Array(42, dtype=int32, weak_type=True) + >>> jnp.array(3.5) + Array(3.5, dtype=float32, weak_type=True) + >>> jnp.array(1 + 1j) + Array(1.+1.j, dtype=complex64, weak_type=True) + + Constructing JAX arrays from Python collections: + + >>> jnp.array([1, 2, 3]) # list of ints -> 1D array + Array([1, 2, 3], dtype=int32) + >>> jnp.array([(1, 2, 3), (4, 5, 6)]) # list of tuples of ints -> 2D array + Array([[1, 2, 3], + [4, 5, 6]], dtype=int32) + >>> jnp.array(range(5)) + Array([0, 1, 2, 3, 4], dtype=int32) + + Constructing JAX arrays from NumPy arrays: + + >>> jnp.array(np.linspace(0, 2, 5)) + Array([0. , 0.5, 1. , 1.5, 2. ], dtype=float32) + + Constructing a JAX array via the Python buffer interface, using Python's + built-in :mod:`array` module. + + >>> from array import array + >>> pybuffer = array('i', [2, 3, 5, 7]) + >>> jnp.array(pybuffer) + Array([2, 3, 5, 7], dtype=int32) + + .. _explicit sharding: https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html + """ + if order is not None and order != "K": + raise NotImplementedError("Only implemented for order='K'") + + # check if the given dtype is compatible with JAX + if dtype is not None: + dtype = dtypes.check_and_canonicalize_user_dtype(dtype, "array") + + # Here we make a judgment call: we only return a weakly-typed array when the + # input object itself is weakly typed. That ensures asarray(x) is a no-op + # whenever x is weak, but avoids introducing weak types with something like + # array([1, 2, 3]) + weak_type = dtype is None and dtypes.is_weakly_typed(object) + + if device is None and out_sharding is None and isinstance(object, core.Tracer): + sharding = object.aval.sharding + sharding = None if sharding.mesh.empty else sharding + else: + sharding = util.choose_device_or_out_sharding(device, out_sharding, "jnp.array") + + # Use device_put to avoid a copy for ndarray inputs. + if (not copy and isinstance(object, np.ndarray) and + (dtype is None or dtype == object.dtype) and (ndmin <= object.ndim) and + device is None): + if dtype is not None: + # If there is an explicit dtype, we've already canonicalized things and + # device_put should not canonicalize again. + object = literals.TypedNdArray(object, weak_type=False) + # Keep the output uncommitted. + return api.device_put(object) + + # String arrays need separate handling because XLA does not support string + # as a data type. + if dtypes.is_string_dtype(dtype) or ( + hasattr(object, "dtype") and dtypes.is_string_dtype(object.dtype) + ): + return _make_string_array( + object=object, dtype=dtype, ndmin=ndmin, device=device + ) + + # For Python scalar literals, call coerce_to_array to catch any overflow + # errors. We don't use dtypes.is_python_scalar because we don't want this + # triggering for traced values. We do this here because it matters whether or + # not dtype is None. We don't assign the result because we want the raw object + # to be used for type inference below. + if isinstance(object, (bool, int, float, complex)): + _ = dtypes.coerce_to_array(object, dtype) + elif not isinstance(object, Array): + # Check if object supports any of the data exchange protocols + # (except dlpack, see data-apis/array-api#301). If it does, + # consume the object as jax array and continue (but not return) so + # that other array() arguments get processed against the input + # object. + # + # Notice that data exchange protocols define dtype in the + # corresponding data structures and it may not be available as + # object.dtype. So, we'll resolve the protocols here before + # evaluating object.dtype. + if hasattr(object, '__jax_array__'): + object = object.__jax_array__() + elif hasattr(object, '__cuda_array_interface__'): + cai = object.__cuda_array_interface__ + backend = xla_bridge.get_backend() + if 'rocm' in backend.platform_version.lower(): + gpu_plugin_extension = rocm_plugin_extension + elif 'cuda' in backend.platform_version.lower(): + gpu_plugin_extension = cuda_plugin_extension + else: + gpu_plugin_extension = None + if gpu_plugin_extension is None: + device_id = None + else: + device_id = gpu_plugin_extension.get_device_ordinal(cai["data"][0]) + object = xc._xla.cuda_array_interface_to_buffer( + cai=cai, gpu_backend=backend, device_id=device_id) + + # To handle nested lists & tuples, flatten the tree and process each leaf. + leaves, treedef = tree_util.tree_flatten( + object, is_leaf=lambda x: not isinstance(x, (list, tuple))) + if any(leaf is None for leaf in leaves): + raise ValueError("None is not a valid value for jnp.array") + leaves = [ + leaf + if (leaf_jax_array := getattr(leaf, "__jax_array__", None)) is None + else leaf_jax_array() + for leaf in leaves + ] + if dtype is None: + # Use lattice_result_type rather than result_type to avoid canonicalization. + # Otherwise, weakly-typed inputs would have their dtypes canonicalized. + try: + dtype = ( + dtypes.lattice_result_type(*leaves)[0] + if leaves + else dtypes.default_float_dtype() + ) + except TypeError: + # This happens if, e.g. one of the entries is a memoryview object. + # This is rare, so we only handle it if the normal path fails. + leaves = [_convert_to_array_if_dtype_fails(leaf) for leaf in leaves] + dtype = dtypes.lattice_result_type(*leaves)[0] + + object = treedef.unflatten(leaves) + out: ArrayLike + if all(not isinstance(leaf, Array) for leaf in leaves): + # TODO(jakevdp): falling back to numpy here fails to overflow for lists + # containing large integers; see discussion in + # https://github.com/jax-ml/jax/pull/6047. More correct would be to call + # coerce_to_array on each leaf, but this may have performance implications. + out = np.asarray(object, dtype=dtype) + elif isinstance(object, Array): + assert object.aval is not None + out = lax._array_copy(object) if copy else object + elif isinstance(object, (list, tuple)): + if object: + arrs = (array(elt, dtype=dtype, copy=False) for elt in object) + arrays_out = [lax.expand_dims(arr, [0]) for arr in arrs] + # lax.concatenate can be slow to compile for wide concatenations, so form a + # tree of concatenations as a workaround especially for op-by-op mode. + # (https://github.com/jax-ml/jax/issues/653). + k = 16 + while len(arrays_out) > k: + arrays_out = [lax.concatenate(arrays_out[i:i+k], 0) + for i in range(0, len(arrays_out), k)] + out = lax.concatenate(arrays_out, 0) + else: + out = np.array([], dtype=dtype) + elif _supports_buffer_protocol(object): + object = memoryview(object) + # TODO(jakevdp): update this once we support NumPy 2.0 semantics for the copy arg. + out = np.array(object) if copy else np.asarray(object) + else: + raise TypeError(f"Unexpected input type for array: {type(object)}") + out_array: Array = lax._convert_element_type( + out, dtype, weak_type=weak_type, sharding=sharding) + if ndmin > np.ndim(out_array): + out_array = lax.expand_dims(out_array, range(ndmin - np.ndim(out_array))) + return out_array + + +def _get_platform( + device_or_sharding: xc.Device | Sharding | None | str) -> str: + """Get device_or_sharding platform or look up config.default_device.value.""" + if isinstance(device_or_sharding, xc.Device): + return device_or_sharding.platform + elif isinstance(device_or_sharding, Sharding): + return list(device_or_sharding.device_set)[0].platform + elif isinstance(device_or_sharding, str): + return device_or_sharding + elif device_or_sharding is None: + if config.default_device.value is None: + return xla_bridge.default_backend() + else: + return _get_platform(config.default_device.value) + else: + raise ValueError(f"`{device_or_sharding = }` was passed to" + "`canonicalize_or_get_default_platform`, only xc.Device," + " Sharding, None or str values are supported.") + + +def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike: + try: + dtypes.dtype(x) + except TypeError: + return np.asarray(x) + else: + return x + + +@export +def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None, + *, copy: bool | None = None, + device: xc.Device | Sharding | None = None, + out_sharding: NamedSharding | P | None = None) -> Array: + """Convert an object to a JAX array. + + JAX implementation of :func:`numpy.asarray`. + + Args: + a: an object that is convertible to an array. This includes JAX + arrays, NumPy arrays, Python scalars, Python collections like lists + and tuples, objects with a ``__jax_array__`` method, and objects + supporting the Python buffer protocol. + dtype: optionally specify the dtype of the output array. If not + specified it will be inferred from the input. + order: not implemented in JAX + copy: optional boolean specifying the copy mode. If True, then always + return a copy. If False, then error if a copy is necessary. Default is + None, which will only copy when necessary. + device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. + + Returns: + A JAX array constructed from the input. + + See also: + - :func:`jax.numpy.array`: like `asarray`, but defaults to `copy=True`. + - :func:`jax.numpy.from_dlpack`: construct a JAX array from an object + that implements the dlpack interface. + - :func:`jax.numpy.frombuffer`: construct a JAX array from an object + that implements the buffer interface. + + Examples: + Constructing JAX arrays from Python scalars: + + >>> jnp.asarray(True) + Array(True, dtype=bool) + >>> jnp.asarray(42) + Array(42, dtype=int32, weak_type=True) + >>> jnp.asarray(3.5) + Array(3.5, dtype=float32, weak_type=True) + >>> jnp.asarray(1 + 1j) + Array(1.+1.j, dtype=complex64, weak_type=True) + + Constructing JAX arrays from Python collections: + + >>> jnp.asarray([1, 2, 3]) # list of ints -> 1D array + Array([1, 2, 3], dtype=int32) + >>> jnp.asarray([(1, 2, 3), (4, 5, 6)]) # list of tuples of ints -> 2D array + Array([[1, 2, 3], + [4, 5, 6]], dtype=int32) + >>> jnp.asarray(range(5)) + Array([0, 1, 2, 3, 4], dtype=int32) + + Constructing JAX arrays from NumPy arrays: + + >>> jnp.asarray(np.linspace(0, 2, 5)) + Array([0. , 0.5, 1. , 1.5, 2. ], dtype=float32) + + Constructing a JAX array via the Python buffer interface, using Python's + built-in :mod:`array` module. + + >>> from array import array + >>> pybuffer = array('i', [2, 3, 5, 7]) + >>> jnp.asarray(pybuffer) + Array([2, 3, 5, 7], dtype=int32) + """ + # For copy=False, the array API specifies that we raise a ValueError if the input supports + # the buffer protocol but a copy is required. Since array() supports the buffer protocol + # via numpy, this is only the case when the default device is not 'cpu' + if (copy is False and not isinstance(a, Array) + and _get_platform(device) != "cpu" + and _supports_buffer_protocol(a)): + raise ValueError(f"jnp.asarray: cannot convert object of type {type(a)} to JAX Array " + f"on platform={_get_platform(device)} with " + "copy=False. Consider using copy=None or copy=True instead.") + if dtype is not None: + dtype = dtypes.check_and_canonicalize_user_dtype(dtype, "asarray") + return array(a, dtype=dtype, copy=bool(copy), order=order, device=device, + out_sharding=out_sharding) diff --git a/jax/_src/numpy/array_creation.py b/jax/_src/numpy/array_creation.py index 67418e7322c9..2941bbcfe4d2 100644 --- a/jax/_src/numpy/array_creation.py +++ b/jax/_src/numpy/array_creation.py @@ -13,19 +13,23 @@ # limitations under the License. import types -from typing import Any +import operator +from typing import Any, Literal, overload import numpy as np -import jax -from jax import lax +from jax._src import api from jax._src import core from jax._src import dtypes +from jax._src.lax import lax from jax._src.lib import xla_client as xc +from jax._src.numpy.array_constructors import asarray +from jax._src.numpy import ufuncs from jax._src.numpy import util +from jax._src.sharding import Sharding +from jax._src.sharding_impls import NamedSharding, PartitionSpec as P from jax._src.typing import Array, ArrayLike, DuckTypedArray, DTypeLike -from jax._src.util import set_module -from jax.sharding import Sharding +from jax._src.util import canonicalize_axis, set_module export = set_module('jax.numpy') @@ -43,19 +47,26 @@ def canonicalize_shape(shape: Any, context: str="") -> core.Shape: @export def zeros(shape: Any, dtype: DTypeLike | None = None, *, - device: xc.Device | Sharding | None = None) -> Array: + device: xc.Device | Sharding | None = None, + out_sharding: NamedSharding | P | None = None) -> Array: """Create an array full of zeros. JAX implementation of :func:`numpy.zeros`. Args: shape: int or sequence of ints specifying the shape of the created array. - dtype: optional dtype for the created array; defaults to floating point. + dtype: optional dtype for the created array; defaults to float32 or float64 + depending on the X64 configuration (see :ref:`default-dtypes`). device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` - to which the created array will be committed. + to which the created array will be committed. This argument exists for + compatibility with the :ref:`python-array-api`. + out_sharding: (optional) :class:`~jax.sharding.PartitionSpec` or :class:`~jax.NamedSharding` + representing the sharding of the created array (see `explicit sharding`_ for more details). + This argument exists for consistency with other array creation routines across JAX. + Specifying both ``out_sharding`` and ``device`` will result in an error. Returns: - Array of the specified shape and dtype, on the specified device if specified. + Array of the specified shape and dtype, with the given device/sharding if specified. See also: - :func:`jax.numpy.zeros_like` @@ -69,30 +80,42 @@ def zeros(shape: Any, dtype: DTypeLike | None = None, *, >>> jnp.zeros((2, 3), dtype=bool) Array([[False, False, False], [False, False, False]], dtype=bool) + + .. _explicit sharding: https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html """ if isinstance(shape, types.GeneratorType): raise TypeError("expected sequence object with len >= 0 or a single integer") if (m := _check_forgot_shape_tuple("zeros", shape, dtype)): raise TypeError(m) - dtypes.check_user_dtype_supported(dtype, "zeros") + dtype = dtypes.check_and_canonicalize_user_dtype( + float if dtype is None else dtype, "zeros") shape = canonicalize_shape(shape) - return lax.full(shape, 0, dtypes.jax_dtype(dtype), sharding=util.normalize_device_to_sharding(device)) + sharding = util.choose_device_or_out_sharding( + device, out_sharding, 'jnp.zeros') + return lax.full(shape, 0, dtype, sharding=sharding) @export def ones(shape: Any, dtype: DTypeLike | None = None, *, - device: xc.Device | Sharding | None = None) -> Array: + device: xc.Device | Sharding | None = None, + out_sharding: NamedSharding | P | None = None) -> Array: """Create an array full of ones. JAX implementation of :func:`numpy.ones`. Args: shape: int or sequence of ints specifying the shape of the created array. - dtype: optional dtype for the created array; defaults to floating point. + dtype: optional dtype for the created array; defaults to float32 or float64 + depending on the X64 configuration (see :ref:`default-dtypes`). device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` - to which the created array will be committed. + to which the created array will be committed. This argument exists for + compatibility with the :ref:`python-array-api`. + out_sharding: (optional) :class:`~jax.sharding.PartitionSpec` or :class:`~jax.NamedSharding` + representing the sharding of the created array (see `explicit sharding`_ for more details). + This argument exists for consistency with other array creation routines across JAX. + Specifying both ``out_sharding`` and ``device`` will result in an error. Returns: - Array of the specified shape and dtype, on the specified device if specified. + Array of the specified shape and dtype, with the given device/sharding if specified. See also: - :func:`jax.numpy.ones_like` @@ -106,18 +129,24 @@ def ones(shape: Any, dtype: DTypeLike | None = None, *, >>> jnp.ones((2, 3), dtype=bool) Array([[ True, True, True], [ True, True, True]], dtype=bool) + + .. _explicit sharding: https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html """ if isinstance(shape, types.GeneratorType): raise TypeError("expected sequence object with len >= 0 or a single integer") if (m := _check_forgot_shape_tuple("ones", shape, dtype)): raise TypeError(m) shape = canonicalize_shape(shape) - dtypes.check_user_dtype_supported(dtype, "ones") - return lax.full(shape, 1, dtypes.jax_dtype(dtype), sharding=util.normalize_device_to_sharding(device)) + dtype = dtypes.check_and_canonicalize_user_dtype( + float if dtype is None else dtype, "ones") + sharding = util.choose_device_or_out_sharding( + device, out_sharding, 'jnp.ones') + return lax.full(shape, 1, dtype, sharding=sharding) @export def empty(shape: Any, dtype: DTypeLike | None = None, *, - device: xc.Device | Sharding | None = None) -> Array: + device: xc.Device | Sharding | None = None, + out_sharding: NamedSharding | P | None = None) -> Array: """Create an empty array. JAX implementation of :func:`numpy.empty`. Because XLA cannot create an @@ -126,12 +155,18 @@ def empty(shape: Any, dtype: DTypeLike | None = None, *, Args: shape: int or sequence of ints specifying the shape of the created array. - dtype: optional dtype for the created array; defaults to floating point. + dtype: optional dtype for the created array; defaults to float32 or float64 + depending on the X64 configuration (see :ref:`default-dtypes`). device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` - to which the created array will be committed. + to which the created array will be committed. This argument exists for + compatibility with the :ref:`python-array-api`. + out_sharding: (optional) :class:`~jax.sharding.PartitionSpec` or :class:`~jax.NamedSharding` + representing the sharding of the created array (see `explicit sharding`_ for more details). + This argument exists for consistency with other array creation routines across JAX. + Specifying both ``out_sharding`` and ``device`` will result in an error. Returns: - Array of the specified shape and dtype, on the specified device if specified. + Array of the specified shape and dtype, with the given device/sharding if specified. See also: - :func:`jax.numpy.empty_like` @@ -145,10 +180,13 @@ def empty(shape: Any, dtype: DTypeLike | None = None, *, >>> jnp.empty((2, 3), dtype=bool) Array([[False, False, False], [False, False, False]], dtype=bool) + + .. _explicit sharding: https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html """ if (m := _check_forgot_shape_tuple("empty", shape, dtype)): raise TypeError(m) - dtypes.check_user_dtype_supported(dtype, "empty") - return zeros(shape, dtype, device=device) + dtype = dtypes.check_and_canonicalize_user_dtype( + float if dtype is None else dtype, "empty") + return zeros(shape, dtype, device=device, out_sharding=out_sharding) def _check_forgot_shape_tuple(name, shape, dtype) -> str | None: # type: ignore @@ -197,22 +235,25 @@ def full(shape: Any, fill_value: ArrayLike, Array([[0, 1, 2], [0, 1, 2]], dtype=int32) """ - dtypes.check_user_dtype_supported(dtype, "full") + if dtype is not None: + dtype = dtypes.check_and_canonicalize_user_dtype(dtype, "full") util.check_arraylike("full", fill_value) if np.ndim(fill_value) == 0: shape = canonicalize_shape(shape) - return lax.full(shape, fill_value, dtype, sharding=util.normalize_device_to_sharding(device)) + return lax.full(shape, fill_value, dtype, + sharding=util.canonicalize_device_to_sharding(device)) else: - return jax.device_put( - util._broadcast_to(jax.numpy.asarray(fill_value, dtype=dtype), shape), device) + return api.device_put( + util._broadcast_to(asarray(fill_value, dtype=dtype), shape), device) @export def zeros_like(a: ArrayLike | DuckTypedArray, dtype: DTypeLike | None = None, shape: Any = None, *, - device: xc.Device | Sharding | None = None) -> Array: + device: xc.Device | Sharding | None = None, + out_sharding: NamedSharding | P | None = None) -> Array: """Create an array full of zeros with the same shape and dtype as an array. JAX implementation of :func:`numpy.zeros_like`. @@ -244,18 +285,24 @@ def zeros_like(a: ArrayLike | DuckTypedArray, [0, 0, 0]], dtype=int32) """ if not (hasattr(a, 'dtype') and hasattr(a, 'shape')): # support duck typing + if hasattr(a, '__jax_array__'): + a = a.__jax_array__() util.check_arraylike("zeros_like", a) - dtypes.check_user_dtype_supported(dtype, "zeros_like") + if dtype is not None: + dtype = dtypes.check_and_canonicalize_user_dtype(dtype, "zeros_like") if shape is not None: shape = canonicalize_shape(shape) - return lax.full_like(a, 0, dtype, shape, sharding=util.normalize_device_to_sharding(device)) + sharding = util.choose_device_or_out_sharding( + device, out_sharding, "jnp.zeros_like") + return lax.full_like(a, 0, dtype, shape, sharding=sharding) @export def ones_like(a: ArrayLike | DuckTypedArray, dtype: DTypeLike | None = None, shape: Any = None, *, - device: xc.Device | Sharding | None = None) -> Array: + device: xc.Device | Sharding | None = None, + out_sharding: NamedSharding | P | None = None) -> Array: """Create an array of ones with the same shape and dtype as an array. JAX implementation of :func:`numpy.ones_like`. @@ -287,11 +334,16 @@ def ones_like(a: ArrayLike | DuckTypedArray, [1, 1, 1]], dtype=int32) """ if not (hasattr(a, 'dtype') and hasattr(a, 'shape')): # support duck typing + if hasattr(a, '__jax_array__'): + a = a.__jax_array__() util.check_arraylike("ones_like", a) - dtypes.check_user_dtype_supported(dtype, "ones_like") + if dtype is not None: + dtype = dtypes.check_and_canonicalize_user_dtype(dtype, "ones_like") if shape is not None: shape = canonicalize_shape(shape) - return lax.full_like(a, 1, dtype, shape, sharding=util.normalize_device_to_sharding(device)) + sharding = util.choose_device_or_out_sharding( + device, out_sharding, "jnp.ones_like") + return lax.full_like(a, 1, dtype, shape, sharding=sharding) @export @@ -332,9 +384,15 @@ def empty_like(prototype: ArrayLike | DuckTypedArray, [0, 0, 0]], dtype=int32) """ if not (hasattr(prototype, 'dtype') and hasattr(prototype, 'shape')): # support duck typing - util.check_arraylike("empty_like", prototype) - dtypes.check_user_dtype_supported(dtype, "empty_like") - return zeros_like(prototype, dtype=dtype, shape=shape, device=device) + if hasattr(prototype, '__jax_array__'): + prototype = prototype.__jax_array__() + util.check_arraylike("ones_like", prototype) + if dtype is not None: + dtype = dtypes.check_and_canonicalize_user_dtype(dtype, "ones_like") + if shape is not None: + shape = canonicalize_shape(shape) + return lax.full_like(prototype, 0, dtype, shape, + sharding=util.canonicalize_device_to_sharding(device)) @export @@ -382,13 +440,329 @@ def full_like(a: ArrayLike | DuckTypedArray, util.check_arraylike("full_like", 0, fill_value) else: util.check_arraylike("full_like", a, fill_value) - dtypes.check_user_dtype_supported(dtype, "full_like") + if hasattr(a, '__jax_array__'): + a = a.__jax_array__() + if dtype is not None: + dtype = dtypes.check_and_canonicalize_user_dtype(dtype, "full_like") if shape is not None: shape = canonicalize_shape(shape) if np.ndim(fill_value) == 0: - return lax.full_like(a, fill_value, dtype, shape, sharding=util.normalize_device_to_sharding(device)) + return lax.full_like(a, fill_value, dtype, shape, + sharding=util.canonicalize_device_to_sharding(device)) else: shape = np.shape(a) if shape is None else shape # type: ignore[arg-type] dtype = dtypes.result_type(a) if dtype is None else dtype - return jax.device_put( - util._broadcast_to(jax.numpy.asarray(fill_value, dtype=dtype), shape), device) + return api.device_put( + util._broadcast_to(asarray(fill_value, dtype=dtype), shape), device) + +@overload +def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, + endpoint: bool = True, retstep: Literal[False] = False, + dtype: DTypeLike | None = None, + axis: int = 0, + *, device: xc.Device | Sharding | None = None) -> Array: ... +@overload +def linspace(start: ArrayLike, stop: ArrayLike, num: int, + endpoint: bool, retstep: Literal[True], + dtype: DTypeLike | None = None, + axis: int = 0, + *, device: xc.Device | Sharding | None = None) -> tuple[Array, Array]: ... +@overload +def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, + endpoint: bool = True, *, retstep: Literal[True], + dtype: DTypeLike | None = None, + axis: int = 0, + device: xc.Device | Sharding | None = None) -> tuple[Array, Array]: ... +@overload +def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, + endpoint: bool = True, retstep: bool = False, + dtype: DTypeLike | None = None, + axis: int = 0, + *, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]: ... +@export +def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, + endpoint: bool = True, retstep: bool = False, + dtype: DTypeLike | None = None, + axis: int = 0, + *, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]: + """Return evenly-spaced numbers within an interval. + + JAX implementation of :func:`numpy.linspace`. + + Args: + start: scalar or array of starting values. + stop: scalar or array of stop values. + num: number of values to generate. Default: 50. + endpoint: if True (default) then include the ``stop`` value in the result. + If False, then exclude the ``stop`` value. + retstep: If True, then return a ``(result, step)`` tuple, where ``step`` is the + interval between adjacent values in ``result``. + axis: integer axis along which to generate the linspace. Defaults to zero. + device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. + + Returns: + An array ``values``, or a tuple ``(values, step)`` if ``retstep`` is True, where: + + - ``values`` is an array of evenly-spaced values from ``start`` to ``stop`` + - ``step`` is the interval between adjacent values. + + See also: + - :func:`jax.numpy.arange`: Generate ``N`` evenly-spaced values given a starting + point and a step + - :func:`jax.numpy.logspace`: Generate logarithmically-spaced values. + - :func:`jax.numpy.geomspace`: Generate geometrically-spaced values. + + Examples: + List of 5 values between 0 and 10: + + >>> jnp.linspace(0, 10, 5) + Array([ 0. , 2.5, 5. , 7.5, 10. ], dtype=float32) + + List of 8 values between 0 and 10, excluding the endpoint: + + >>> jnp.linspace(0, 10, 8, endpoint=False) + Array([0. , 1.25, 2.5 , 3.75, 5. , 6.25, 7.5 , 8.75], dtype=float32) + + List of values and the step size between them + + >>> vals, step = jnp.linspace(0, 10, 9, retstep=True) + >>> vals + Array([ 0. , 1.25, 2.5 , 3.75, 5. , 6.25, 7.5 , 8.75, 10. ], dtype=float32) + >>> step + Array(1.25, dtype=float32) + + Multi-dimensional linspace: + + >>> start = jnp.array([0, 5]) + >>> stop = jnp.array([5, 10]) + >>> jnp.linspace(start, stop, 5) + Array([[ 0. , 5. ], + [ 1.25, 6.25], + [ 2.5 , 7.5 ], + [ 3.75, 8.75], + [ 5. , 10. ]], dtype=float32) + """ + num = core.concrete_dim_or_error(num, "'num' argument of jnp.linspace") + axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.linspace") + return _linspace(start, stop, num, endpoint, retstep, dtype, axis, device=device) + +@api.jit(static_argnames=('num', 'endpoint', 'retstep', 'dtype', 'axis', 'device')) +def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, + endpoint: bool = True, retstep: bool = False, + dtype: DTypeLike | None = None, + axis: int = 0, + *, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]: + """Implementation of linspace differentiable in start and stop args.""" + if num < 0: + raise ValueError(f"Number of samples, {num}, must be non-negative.") + start, stop = util.ensure_arraylike("linspace", start, stop) + + if dtype is None: + dtype = dtypes.to_inexact_dtype(dtypes.result_type(start, stop)) + else: + dtype = dtypes.check_and_canonicalize_user_dtype(dtype, "linspace") + computation_dtype = dtypes.to_inexact_dtype(dtype) + start = start.astype(computation_dtype) + stop = stop.astype(computation_dtype) + + bounds_shape = list(lax.broadcast_shapes(np.shape(start), np.shape(stop))) + broadcast_start = util._broadcast_to(start, bounds_shape) + broadcast_stop = util._broadcast_to(stop, bounds_shape) + axis = len(bounds_shape) + axis + 1 if axis < 0 else axis + bounds_shape.insert(axis, 1) + div = (num - 1) if endpoint else num + if num > 1: + delta: Array = lax.convert_element_type(stop - start, computation_dtype) / asarray(div, dtype=computation_dtype) + iota_shape = [1,] * len(bounds_shape) + iota_shape[axis] = div + # This approach recovers the endpoints with float32 arithmetic, + # but can lead to rounding errors for integer outputs. + real_dtype = dtypes.finfo(computation_dtype).dtype + step = lax.iota(real_dtype, div).reshape(iota_shape) / asarray(div, real_dtype) + step = step.astype(computation_dtype) + out = (broadcast_start.reshape(bounds_shape) * (1 - step) + + broadcast_stop.reshape(bounds_shape) * step) + + if endpoint: + out = lax.concatenate([out, lax.expand_dims(broadcast_stop, (axis,))], + canonicalize_axis(axis, out.ndim)) + + elif num == 1: + delta = asarray(np.nan if endpoint else stop - start, dtype=computation_dtype) + out = broadcast_start.reshape(bounds_shape) + else: # num == 0 degenerate case, match numpy behavior + empty_shape = list(lax.broadcast_shapes(np.shape(start), np.shape(stop))) + empty_shape.insert(axis, 0) + delta = full((), np.nan, computation_dtype) + out = empty(empty_shape, dtype) + + if dtypes.issubdtype(dtype, np.integer) and not dtypes.issubdtype(out.dtype, np.integer): + out = lax.floor(out) + + sharding = util.canonicalize_device_to_sharding(device) + result = lax._convert_element_type(out, dtype, sharding=sharding) + return (result, delta) if retstep else result + + +@export +def logspace(start: ArrayLike, stop: ArrayLike, num: int = 50, + endpoint: bool = True, base: ArrayLike = 10.0, + dtype: DTypeLike | None = None, axis: int = 0) -> Array: + """Generate logarithmically-spaced values. + + JAX implementation of :func:`numpy.logspace`. + + Args: + start: scalar or array. Used to specify the start value. The start value is + ``base ** start``. + stop: scalar or array. Used to specify the stop value. The end value is + ``base ** stop``. + num: int, optional, default=50. Number of values to generate. + endpoint: bool, optional, default=True. If True, then include the ``stop`` value + in the result. If False, then exclude the ``stop`` value. + base: scalar or array, optional, default=10. Specifies the base of the logarithm. + dtype: optional. Specifies the dtype of the output. + axis: int, optional, default=0. Axis along which to generate the logspace. + + Returns: + An array of logarithm. + + See also: + - :func:`jax.numpy.arange`: Generate ``N`` evenly-spaced values given a starting + point and a step value. + - :func:`jax.numpy.linspace`: Generate evenly-spaced values. + - :func:`jax.numpy.geomspace`: Generate geometrically-spaced values. + + Examples: + List 5 logarithmically spaced values between 1 (``10 ** 0``) and 100 + (``10 ** 2``): + + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.logspace(0, 2, 5) + Array([ 1. , 3.162, 10. , 31.623, 100. ], dtype=float32) + + List 5 logarithmically-spaced values between 1(``10 ** 0``) and 100 + (``10 ** 2``), excluding endpoint: + + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.logspace(0, 2, 5, endpoint=False) + Array([ 1. , 2.512, 6.31 , 15.849, 39.811], dtype=float32) + + List 7 logarithmically-spaced values between 1 (``2 ** 0``) and 4 (``2 ** 2``) + with base 2: + + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.logspace(0, 2, 7, base=2) + Array([1. , 1.26 , 1.587, 2. , 2.52 , 3.175, 4. ], dtype=float32) + + Multi-dimensional logspace: + + >>> start = jnp.array([0, 5]) + >>> stop = jnp.array([5, 0]) + >>> base = jnp.array([2, 3]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.logspace(start, stop, 5, base=base) + Array([[ 1. , 243. ], + [ 2.378, 61.547], + [ 5.657, 15.588], + [ 13.454, 3.948], + [ 32. , 1. ]], dtype=float32) + """ + num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.logspace") + axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.logspace") + return _logspace(start, stop, num, endpoint, base, dtype, axis) + +@api.jit(static_argnames=('num', 'endpoint', 'dtype', 'axis')) +def _logspace(start: ArrayLike, stop: ArrayLike, num: int = 50, + endpoint: bool = True, base: ArrayLike = 10.0, + dtype: DTypeLike | None = None, axis: int = 0) -> Array: + """Implementation of logspace differentiable in start and stop args.""" + if dtype is None: + dtype = dtypes.to_inexact_dtype(dtypes.result_type(start, stop)) + else: + dtype = dtypes.check_and_canonicalize_user_dtype(dtype, "logspace") + computation_dtype = dtypes.to_inexact_dtype(dtype) + start, stop = util.ensure_arraylike("logspace", start, stop) + start = start.astype(computation_dtype) + stop = stop.astype(computation_dtype) + lin = linspace(start, stop, num, + endpoint=endpoint, retstep=False, dtype=None, axis=axis) + return lax.convert_element_type(ufuncs.power(base, lin), dtype) + + +@export +def geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, + dtype: DTypeLike | None = None, axis: int = 0) -> Array: + """Generate geometrically-spaced values. + + JAX implementation of :func:`numpy.geomspace`. + + Args: + start: scalar or array. Specifies the starting values. + stop: scalar or array. Specifies the stop values. + num: int, optional, default=50. Number of values to generate. + endpoint: bool, optional, default=True. If True, then include the ``stop`` value + in the result. If False, then exclude the ``stop`` value. + dtype: optional. Specifies the dtype of the output. + axis: int, optional, default=0. Axis along which to generate the geomspace. + + Returns: + An array containing the geometrically-spaced values. + + See also: + - :func:`jax.numpy.arange`: Generate ``N`` evenly-spaced values given a starting + point and a step value. + - :func:`jax.numpy.linspace`: Generate evenly-spaced values. + - :func:`jax.numpy.logspace`: Generate logarithmically-spaced values. + + Examples: + List 5 geometrically-spaced values between 1 and 16: + + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.geomspace(1, 16, 5) + Array([ 1., 2., 4., 8., 16.], dtype=float32) + + List 4 geomtrically-spaced values between 1 and 16, with ``endpoint=False``: + + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.geomspace(1, 16, 4, endpoint=False) + Array([1., 2., 4., 8.], dtype=float32) + + Multi-dimensional geomspace: + + >>> start = jnp.array([1, 1000]) + >>> stop = jnp.array([27, 1]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.geomspace(start, stop, 4) + Array([[ 1., 1000.], + [ 3., 100.], + [ 9., 10.], + [ 27., 1.]], dtype=float32) + """ + num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.geomspace") + axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.geomspace") + return _geomspace(start, stop, num, endpoint, dtype, axis) + +@api.jit(static_argnames=('num', 'endpoint', 'dtype', 'axis')) +def _geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, + dtype: DTypeLike | None = None, axis: int = 0) -> Array: + """Implementation of geomspace differentiable in start and stop args.""" + if dtype is None: + dtype = dtypes.to_inexact_dtype(dtypes.result_type(start, stop)) + else: + dtype = dtypes.check_and_canonicalize_user_dtype(dtype, "geomspace") + computation_dtype = dtypes.to_inexact_dtype(dtype) + start, stop = util.ensure_arraylike("geomspace", start, stop) + start = start.astype(computation_dtype) + stop = stop.astype(computation_dtype) + + sign = ufuncs.sign(start) + res = sign * logspace(ufuncs.log10(start / sign), ufuncs.log10(stop / sign), + num, endpoint=endpoint, base=10.0, + dtype=computation_dtype, axis=0) + axis = canonicalize_axis(axis, res.ndim) + if axis != 0: + # res = moveaxis(res, 0, axis) + res = lax.transpose(res, permutation=(*range(1, axis + 1), 0, *range(axis + 1, res.ndim))) + return lax.convert_element_type(res, dtype) diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index e9e097c85aff..e0496d157b67 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -24,29 +24,32 @@ __all__ = ['register_jax_array_methods'] import abc -from functools import partial, wraps +from functools import wraps import math -from typing import Any, Sequence +from typing import Any +from collections.abc import Callable, Sequence import numpy as np -import jax -from jax import lax -from jax.sharding import Sharding + from jax._src import api from jax._src import core from jax._src import dtypes +from jax._src import literals from jax._src.api_util import _ensure_index_tuple from jax._src.array import ArrayImpl -from jax._src.lax import lax as lax_internal +from jax._src.lax import lax +from jax._src.lax import slicing as lax_slicing from jax._src.lib import xla_client as xc from jax._src.numpy import array_api_metadata +from jax._src.numpy import array_creation from jax._src.numpy import indexing from jax._src.numpy import lax_numpy from jax._src.numpy import tensor_contractions -from jax._src.pjit import PartitionSpec -from jax._src.sharding_impls import canonicalize_sharding, NamedSharding from jax._src.numpy import reductions from jax._src.numpy import ufuncs +from jax._src.pjit import PartitionSpec +from jax._src.sharding import Sharding +from jax._src.sharding_impls import canonicalize_sharding, NamedSharding from jax._src.ops import scatter from jax._src.typing import Array, ArrayLike, DimSize, DTypeLike, Shape, StaticScalar from jax._src.util import safe_zip, safe_map @@ -189,7 +192,7 @@ def _diagonal(self: Array, offset: int = 0, axis1: int = 0, axis2: int = 1) -> A """ return lax_numpy.diagonal(self, offset=offset, axis1=axis1, axis2=axis2) -def _dot(self: Array, b: ArrayLike, *, precision: lax_internal.PrecisionLike = None, +def _dot(self: Array, b: ArrayLike, *, precision: lax.PrecisionLike = None, preferred_element_type: DTypeLike | None = None) -> Array: """Compute the dot product of two arrays. @@ -197,12 +200,12 @@ def _dot(self: Array, b: ArrayLike, *, precision: lax_internal.PrecisionLike = N """ return tensor_contractions.dot(self, b, precision=precision, preferred_element_type=preferred_element_type) -def _flatten(self: Array, order: str = "C") -> Array: +def _flatten(self: Array, order: str = "C", *, out_sharding=None) -> Array: """Flatten array into a 1-dimensional shape. Refer to :func:`jax.numpy.ravel` for the full documentation. """ - return lax_numpy.ravel(self, order=order) + return lax_numpy.ravel(self, order=order, out_sharding=out_sharding) def _imag_property(self: Array) -> Array: """Return the imaginary part of the array.""" @@ -217,7 +220,7 @@ def _item(self: Array, *args: int) -> bool | int | float | complex: def _itemsize_property(self: Array) -> int: """Length of one array element in bytes.""" - return dtypes.dtype(self, canonicalize=True).itemsize + return self.dtype.itemsize def _matrix_transpose_property(self: Array): """Compute the (batched) matrix transpose. @@ -259,7 +262,7 @@ def _min(self: Array, axis: reductions.Axis = None, out: None = None, def _nbytes_property(self: Array) -> int: """Total bytes consumed by the elements of the array.""" - return np.size(self) * dtypes.dtype(self, canonicalize=True).itemsize + return np.size(self) * self.dtype.itemsize def _nonzero(self: Array, *, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None, size: int | None = None) -> tuple[Array, ...]: @@ -293,14 +296,18 @@ def _real_property(self: Array) -> Array: return ufuncs.real(self) def _repeat(self: Array, repeats: ArrayLike, axis: int | None = None, *, - total_repeat_length: int | None = None) -> Array: + total_repeat_length: int | None = None, + out_sharding: NamedSharding | PartitionSpec | None = None) -> Array: """Construct an array from repeated elements. Refer to :func:`jax.numpy.repeat` for the full documentation. """ - return lax_numpy.repeat(self, repeats=repeats, axis=axis, total_repeat_length=total_repeat_length) + return lax_numpy.repeat(self, repeats=repeats, axis=axis, + total_repeat_length=total_repeat_length, + out_sharding=out_sharding) -def _reshape(self: Array, *args: Any, order: str = "C") -> Array: +def _reshape(self: Array, *args: Any, order: str = "C", out_sharding=None + ) -> Array: """Returns an array containing the same data with a new shape. Refer to :func:`jax.numpy.reshape` for full documentation. @@ -308,10 +315,14 @@ def _reshape(self: Array, *args: Any, order: str = "C") -> Array: __tracebackhide__ = True newshape = _compute_newshape(self, args[0] if len(args) == 1 else args) if order == "C": - return lax.reshape(self, newshape, None) + return lax.reshape(self, newshape, None, out_sharding=out_sharding) elif order == "F": dims = list(range(self.ndim)[::-1]) - return lax.reshape(self, newshape[::-1], dims).T + out_sharding = canonicalize_sharding(out_sharding, "jnp.reshape") + out_sharding = ( + None if out_sharding is None else out_sharding.update( + spec=out_sharding.spec.update(partitions=out_sharding.spec[::-1]))) + return lax.reshape(self, newshape[::-1], dims, out_sharding=out_sharding).T elif order == "A": raise NotImplementedError("np.reshape order=A is not implemented.") else: @@ -499,22 +510,32 @@ def _view(self: Array, dtype: DTypeLike | None = None, type: None = None) -> Arr Array([ True, False, True], dtype=bool) However, there are no guarantees about the results of any expression involving - a view such as this: `jnp.array([1, 2, 3], dtype=jnp.int8).view(jnp.bool_)`. + a view such as this: ``jnp.array([1, 2, 3], dtype=jnp.int8).view(jnp.bool_)``. In particular, the results may change between JAX releases and depending on the platform. To safely convert such an array to a boolean array, compare it with `0`:: >>> jnp.array([1, 2, 0], dtype=jnp.int8) != 0 Array([ True, True, False], dtype=bool) + + Args: + dtype: An optional output dtype. If not specified, the output dtype is the + same as the input dtype. + type: Not implemented; accepted for NumPy compatibility. + Returns: + The array, viewed as the new dtype. Unlike NumPy, the array may or may not + be a copy of the input array. """ if type is not None: raise NotImplementedError("`type` argument of array.view() is not supported.") - dtypes.check_user_dtype_supported(dtype, "view") - dtype = dtypes.canonicalize_dtype(dtype) + if dtype is None: + return self + + dtype = dtypes.check_and_canonicalize_user_dtype(dtype, "view") - nbits_in = dtypes.bit_width(self.dtype) - nbits_out = dtypes.bit_width(dtype) + nbits_in = dtypes.itemsize_bits(self.dtype) + nbits_out = dtypes.itemsize_bits(dtype) if self.ndim == 0: if nbits_in != nbits_out: @@ -536,9 +557,10 @@ def _view(self: Array, dtype: DTypeLike | None = None, type: None = None) -> Arr if lax_numpy.issubdtype(self.dtype, np.complexfloating): new_shape = (*self.shape[:-1], self.shape[-1] * 2) new_dtype = lax_numpy.finfo(self.dtype).dtype - self = (lax_numpy.zeros(new_shape, new_dtype) - .at[..., 0::2].set(self.real) - .at[..., 1::2].set(self.imag)) + new_sharding = core.typeof(self).sharding + self = (array_creation.zeros(new_shape, new_dtype, out_sharding=new_sharding) + .at[..., 0::2].set(self.real) + .at[..., 1::2].set(self.imag)) return _view(self, dtype) if dtype == bool: @@ -566,7 +588,15 @@ def _notimplemented_flat(self): raise NotImplementedError("JAX Arrays do not implement the arr.flat property: " "consider arr.flatten() instead.") -_accepted_binop_types = (int, float, complex, np.generic, np.ndarray, Array) +_accepted_binop_types = ( + int, + float, + complex, + np.generic, + np.ndarray, + Array, + literals.TypedNdArray, +) _rejected_binop_types = (list, tuple, set, dict) def _defer_to_unrecognized_arg(opchar, binary_op, swap=False): @@ -588,7 +618,7 @@ def deferring_binary_op(self, other): def _unimplemented_setitem(self, i, x): msg = ("JAX arrays are immutable and do not support in-place item assignment." " Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method:" - " https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html") + " https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html") raise TypeError(msg.format(type(self))) def _operator_round(number: ArrayLike, ndigits: int | None = None) -> Array: @@ -608,12 +638,13 @@ def _deepcopy(self: Array, memo: Any) -> Array: def __array_module__(self, types): if all(issubclass(t, _HANDLED_ARRAY_TYPES) for t in types): + import jax.numpy # pytype: disable=import-error return jax.numpy else: return NotImplemented -@partial(jax.jit, static_argnums=(1,2,3)) +@api.jit(static_argnums=(1,2,3)) def _multi_slice(self: Array, start_indices: tuple[tuple[int, ...]], limit_indices: tuple[tuple[int, ...]], @@ -625,7 +656,7 @@ def _multi_slice(self: Array, """ results: list[Array] = [] for starts, limits, removed in zip(start_indices, limit_indices, removed_dims): - sliced = lax.slice(self, starts, limits) + sliced = lax_slicing.slice(self, starts, limits) if removed: sliced = lax.squeeze(sliced, removed) results.append(sliced) @@ -633,7 +664,7 @@ def _multi_slice(self: Array, # The next two functions are related to iter(array), implemented here to # avoid circular imports. -@jax.jit +@api.jit def _unstack(x: Array) -> list[Array]: dims = (0,) return [lax.squeeze(t, dims) for t in lax.split(x, (1,) * x.shape[0])] @@ -644,13 +675,14 @@ def _chunk_iter(x, size): else: num_chunks, tail = ufuncs.divmod(x.shape[0], size) for i in range(num_chunks): - yield lax.dynamic_slice_in_dim(x, i * size, size) + yield lax_slicing.dynamic_slice_in_dim(x, i * size, size) if tail: - yield lax.dynamic_slice_in_dim(x, num_chunks * size, tail) + yield lax_slicing.dynamic_slice_in_dim(x, num_chunks * size, tail) def _getitem(self, item): return indexing.rewriting_take(self, item) + # Syntactic sugar for scatter operations. class _IndexUpdateHelper: # Note: this docstring will appear as the docstring for the `at` property. @@ -689,10 +721,8 @@ class _IndexUpdateHelper: By default, JAX assumes that all indices are in-bounds. Alternative out-of-bound index semantics can be specified via the ``mode`` parameter (see below). - Arguments - --------- - mode : str - Specify out-of-bound indexing mode. Options are: + Args: + mode: string specifying out-of-bound indexing mode. Options are: - ``"promise_in_bounds"``: (default) The user promises that indices are in bounds. No additional checking will be performed. In practice, this means that @@ -703,50 +733,68 @@ class _IndexUpdateHelper: - ``"fill"``: alias for ``"drop"``. For `get()`, the optional ``fill_value`` argument specifies the value that will be returned. - See :class:`jax.lax.GatherScatterMode` for more details. - - indices_are_sorted : bool - If True, the implementation will assume that the indices passed to ``at[]`` - are sorted in ascending order, which can lead to more efficient execution - on some backends. - unique_indices : bool - If True, the implementation will assume that the indices passed to ``at[]`` - are unique, which can result in more efficient execution on some backends. - fill_value : Any - Only applies to the ``get()`` method: the fill value to return for out-of-bounds - slices when `mode` is ``'fill'``. Ignored otherwise. Defaults to ``NaN`` for - inexact types, the largest negative value for signed types, the largest positive - value for unsigned types, and ``True`` for booleans. - - Examples - -------- - >>> x = jnp.arange(5.0) - >>> x - Array([0., 1., 2., 3., 4.], dtype=float32) - >>> x.at[2].add(10) - Array([ 0., 1., 12., 3., 4.], dtype=float32) - >>> x.at[10].add(10) # out-of-bounds indices are ignored - Array([0., 1., 2., 3., 4.], dtype=float32) - >>> x.at[20].add(10, mode='clip') - Array([ 0., 1., 2., 3., 14.], dtype=float32) - >>> x.at[2].get() - Array(2., dtype=float32) - >>> x.at[20].get() # out-of-bounds indices clipped - Array(4., dtype=float32) - >>> x.at[20].get(mode='fill') # out-of-bounds indices filled with NaN - Array(nan, dtype=float32) - >>> x.at[20].get(mode='fill', fill_value=-1) # custom fill value - Array(-1., dtype=float32) + See :class:`jax.lax.GatherScatterMode` for more details. + wrap_negative_indices: If True (default) then negative indices indicate position + from the end of the array, similar to Python and NumPy indexing. If False, then + negative indices are considered out-of-bounds and behave according to the + ``mode`` parameter. + fill_value: Only applies to the ``get()`` method: the fill value to return for + out-of-bounds slices when ``mode`` is ``'fill'``. Ignored otherwise. Defaults + to ``NaN`` for inexact types, the largest negative value for signed types, the + largest positive value for unsigned types, and ``True`` for booleans. + indices_are_sorted: If True, the implementation will assume that the (normalized) + indices passed to ``at[]`` are sorted in ascending order, which can lead to more + efficient execution on some backends. If True but the indices are not actually + sorted, the output is undefined. + unique_indices: If True, the implementation will assume that the (normalized) indices + passed to ``at[]`` are unique, which can result in more efficient execution on some + backends. If True but the indices are not actually unique, the output is undefined. + + Examples: + >>> x = jnp.arange(5.0) + >>> x + Array([0., 1., 2., 3., 4.], dtype=float32) + >>> x.at[2].get() + Array(2., dtype=float32) + >>> x.at[2].add(10) + Array([ 0., 1., 12., 3., 4.], dtype=float32) + + By default, out-of-bound indices are ignored in updates, but this behavior + can be controlled with the ``mode`` parameter: + + >>> x.at[10].add(10) # dropped + Array([0., 1., 2., 3., 4.], dtype=float32) + >>> x.at[20].add(10, mode='clip') # clipped + Array([ 0., 1., 2., 3., 14.], dtype=float32) + + For ``get()``, out-of-bound indices are clipped by default: + + >>> x.at[20].get() # out-of-bounds indices clipped + Array(4., dtype=float32) + >>> x.at[20].get(mode='fill') # out-of-bounds indices filled with NaN + Array(nan, dtype=float32) + >>> x.at[20].get(mode='fill', fill_value=-1) # custom fill value + Array(-1., dtype=float32) + + Negative indices count from the end of the array, but this behavior can + be disabled by setting ``wrap_negative_indices = False``: + + >>> x.at[-1].set(99) + Array([ 0., 1., 2., 3., 99.], dtype=float32) + >>> x.at[-1].set(99, wrap_negative_indices=False, mode='drop') # dropped! + Array([0., 1., 2., 3., 4.], dtype=float32) """ __slots__ = ("array",) - def __init__(self, array): + array: Array + + def __init__(self, array: Array): self.array = array - def __getitem__(self, index): + def __getitem__(self, index: scatter.Index) -> _IndexUpdateRef: return _IndexUpdateRef(self.array, index) - def __repr__(self): + def __repr__(self) -> str: return f"_IndexUpdateHelper({self.array!r})" @@ -759,15 +807,21 @@ class _IndexUpdateRef: """ __slots__ = ("array", "index") - def __init__(self, array, index): + array: Array + index: scatter.Index + + def __init__(self, array: Array, index: scatter.Index): self.array = array self.index = index def __repr__(self) -> str: return f"_IndexUpdateRef({self.array!r}, {self.index!r})" - def get(self, *, indices_are_sorted=False, unique_indices=False, - mode=None, fill_value=None, out_sharding=None): + def get(self, *, indices_are_sorted: bool = False, unique_indices: bool = False, + mode: str | lax_slicing.GatherScatterMode | None = None, + fill_value: ArrayLike | None = None, + out_sharding: NamedSharding | PartitionSpec | None = None, + wrap_negative_indices: bool = True): """Equivalent to ``x[idx]``. Returns the value of ``x`` that would result from the NumPy-style @@ -775,7 +829,7 @@ def get(self, *, indices_are_sorted=False, unique_indices=False, the usual array indexing syntax in that it allows additional keyword arguments ``indices_are_sorted`` and ``unique_indices`` to be passed. - See :mod:`jax.ops` for details. + See :func:`jax.numpy.ndarray.at` for details. """ if out_sharding is not None: assert isinstance(out_sharding, (NamedSharding, PartitionSpec)) @@ -784,23 +838,34 @@ def get(self, *, indices_are_sorted=False, unique_indices=False, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=mode, fill_value=fill_value, + normalize_indices=wrap_negative_indices, out_sharding=out_sharding) - def set(self, values, *, indices_are_sorted=False, unique_indices=False, - mode=None): + def set(self, values: ArrayLike, *, indices_are_sorted: bool = False, + unique_indices: bool = False, + mode: str | lax_slicing.GatherScatterMode | None = None, + out_sharding: NamedSharding | PartitionSpec | None = None, + wrap_negative_indices: bool = True) -> None: """Pure equivalent of ``x[idx] = y``. Returns the value of ``x`` that would result from the NumPy-style :mod:`indexed assignment ` ``x[idx] = y``. - See :mod:`jax.ops` for details. + See :func:`jax.numpy.ndarray.at` for details. """ - return scatter._scatter_update(self.array, self.index, values, lax.scatter, - indices_are_sorted=indices_are_sorted, - unique_indices=unique_indices, mode=mode) - - def apply(self, func, *, indices_are_sorted=False, unique_indices=False, - mode=None): + if out_sharding is not None: + assert isinstance(out_sharding, (NamedSharding, PartitionSpec)) + out_sharding = canonicalize_sharding(out_sharding, '.set') + return scatter._scatter_update( + self.array, self.index, values, lax_slicing.scatter, + indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, + mode=mode, out_sharding=out_sharding, # type: ignore + normalize_indices=wrap_negative_indices) + + def apply(self, func: Callable[[ArrayLike], Array], *, + indices_are_sorted: bool = False, unique_indices: bool = False, + mode: str | lax_slicing.GatherScatterMode | None = None, + wrap_negative_indices: bool = True) -> Array: """Pure equivalent of ``func.at(x, idx)`` for a unary ufunc ``func``. Returns the value of ``x`` that would result from applying the unary @@ -812,121 +877,144 @@ def apply(self, func, *, indices_are_sorted=False, unique_indices=False, Note that in the current implementation, ``scatter_apply`` is not compatible with automatic differentiation. - See :mod:`jax.ops` for details. + See :func:`jax.numpy.ndarray.at` for details. """ def _scatter_apply(x, indices, y, dims, **kwargs): - return lax.scatter_apply(x, indices, func, dims, update_shape=y.shape, **kwargs) - return scatter._scatter_update(self.array, self.index, - lax_internal._zero(self.array), - _scatter_apply, - indices_are_sorted=indices_are_sorted, - unique_indices=unique_indices, mode=mode) - - def add(self, values, *, indices_are_sorted=False, unique_indices=False, - mode=None): + return lax_slicing.scatter_apply(x, indices, func, dims, update_shape=y.shape, **kwargs) + return scatter._scatter_update( + self.array, self.index, lax._zero(self.array), _scatter_apply, + indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, + mode=mode, normalize_indices=wrap_negative_indices) + + def add(self, values: ArrayLike, *, + indices_are_sorted: bool = False, unique_indices: bool = False, + mode: str | lax_slicing.GatherScatterMode | None = None, + out_sharding: NamedSharding | PartitionSpec | None = None, + wrap_negative_indices: bool = True) -> Array: """Pure equivalent of ``x[idx] += y``. Returns the value of ``x`` that would result from the NumPy-style :mod:indexed assignment ` ``x[idx] += y``. - See :mod:`jax.ops` for details. + See :func:`jax.numpy.ndarray.at` for details. """ - return scatter._scatter_update(self.array, self.index, values, - lax.scatter_add, - indices_are_sorted=indices_are_sorted, - unique_indices=unique_indices, mode=mode) - - def subtract(self, values, *, indices_are_sorted=False, unique_indices=False, - mode=None): + if out_sharding is not None: + assert isinstance(out_sharding, (NamedSharding, PartitionSpec)) + out_sharding = canonicalize_sharding(out_sharding, '.add') + return scatter._scatter_update( + self.array, self.index, values, lax_slicing.scatter_add, + indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, + mode=mode, out_sharding=out_sharding, # type: ignore + normalize_indices=wrap_negative_indices) + + def subtract(self, values: ArrayLike, *, + indices_are_sorted: bool = False, unique_indices: bool = False, + mode: str | lax_slicing.GatherScatterMode | None = None, + wrap_negative_indices: bool = True) -> Array: """Pure equivalent of ``x[idx] -= y``. Returns the value of ``x`` that would result from the NumPy-style :mod:indexed assignment ` ``x[idx] -= y``. - See :mod:`jax.ops` for details. + See :func:`jax.numpy.ndarray.at` for details. """ return scatter._scatter_update(self.array, self.index, values, - lax.scatter_sub, + lax_slicing.scatter_sub, indices_are_sorted=indices_are_sorted, - unique_indices=unique_indices, mode=mode) + unique_indices=unique_indices, mode=mode, + normalize_indices=wrap_negative_indices) - def multiply(self, values, *, indices_are_sorted=False, unique_indices=False, - mode=None): + def multiply(self, values: ArrayLike, *, + indices_are_sorted: bool = False, unique_indices: bool = False, + mode: str | lax_slicing.GatherScatterMode | None = None, + wrap_negative_indices: bool = True) -> Array: """Pure equivalent of ``x[idx] *= y``. Returns the value of ``x`` that would result from the NumPy-style :mod:indexed assignment ` ``x[idx] *= y``. - See :mod:`jax.ops` for details. + See :func:`jax.numpy.ndarray.at` for details. """ return scatter._scatter_update(self.array, self.index, values, - lax.scatter_mul, + lax_slicing.scatter_mul, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, - mode=mode) + mode=mode, normalize_indices=wrap_negative_indices) mul = multiply - def divide(self, values, *, indices_are_sorted=False, unique_indices=False, - mode=None): + def divide(self, values: ArrayLike, *, + indices_are_sorted: bool = False, unique_indices: bool = False, + mode: str | lax_slicing.GatherScatterMode | None = None, + wrap_negative_indices: bool = True) -> Array: """Pure equivalent of ``x[idx] /= y``. Returns the value of ``x`` that would result from the NumPy-style :mod:indexed assignment ` ``x[idx] /= y``. - See :mod:`jax.ops` for details. + See :func:`jax.numpy.ndarray.at` for details. """ return ufuncs.divide( self.array, - scatter._scatter_update(lax_numpy.ones_like(self.array), self.index, values, - lax.scatter_mul, + scatter._scatter_update(array_creation.ones_like(self.array), self.index, values, + lax_slicing.scatter_mul, indices_are_sorted=indices_are_sorted, - unique_indices=unique_indices, mode=mode)) + unique_indices=unique_indices, mode=mode, + normalize_indices=wrap_negative_indices)) - def power(self, values, *, indices_are_sorted=False, unique_indices=False, - mode=None): + def power(self, values: ArrayLike, *, + indices_are_sorted: bool = False, unique_indices: bool = False, + mode: str | lax_slicing.GatherScatterMode | None = None, + wrap_negative_indices: bool = True) -> Array: """Pure equivalent of ``x[idx] **= y``. Returns the value of ``x`` that would result from the NumPy-style :mod:indexed assignment ` ``x[idx] **= y``. - See :mod:`jax.ops` for details. + See :func:`jax.numpy.ndarray.at` for details. """ return ufuncs.power( self.array, - scatter._scatter_update(lax_numpy.ones_like(self.array), self.index, values, - lax.scatter_mul, + scatter._scatter_update(array_creation.ones_like(self.array), self.index, values, + lax_slicing.scatter_mul, indices_are_sorted=indices_are_sorted, - unique_indices=unique_indices, mode=mode)) + unique_indices=unique_indices, mode=mode, + normalize_indices=wrap_negative_indices)) - def min(self, values, *, indices_are_sorted=False, unique_indices=False, - mode=None): + def min(self, values: ArrayLike, *, + indices_are_sorted: bool = False, unique_indices: bool = False, + mode: str | lax_slicing.GatherScatterMode | None = None, + wrap_negative_indices: bool = True) -> Array: """Pure equivalent of ``x[idx] = minimum(x[idx], y)``. Returns the value of ``x`` that would result from the NumPy-style :mod:indexed assignment ` ``x[idx] = minimum(x[idx], y)``. - See :mod:`jax.ops` for details. + See :func:`jax.numpy.ndarray.at` for details. """ return scatter._scatter_update(self.array, self.index, values, - lax.scatter_min, + lax_slicing.scatter_min, indices_are_sorted=indices_are_sorted, - unique_indices=unique_indices, mode=mode) + unique_indices=unique_indices, mode=mode, + normalize_indices=wrap_negative_indices) - def max(self, values, *, indices_are_sorted=False, unique_indices=False, - mode=None): + def max(self, values: ArrayLike, *, + indices_are_sorted: bool = False, unique_indices: bool = False, + mode: str | lax_slicing.GatherScatterMode | None = None, + wrap_negative_indices: bool = True) -> Array: """Pure equivalent of ``x[idx] = maximum(x[idx], y)``. Returns the value of ``x`` that would result from the NumPy-style :mod:indexed assignment ` ``x[idx] = maximum(x[idx], y)``. - See :mod:`jax.ops` for details. + See :func:`jax.numpy.ndarray.at` for details. """ return scatter._scatter_update(self.array, self.index, values, - lax.scatter_max, + lax_slicing.scatter_max, indices_are_sorted=indices_are_sorted, - unique_indices=unique_indices, mode=mode) + unique_indices=unique_indices, mode=mode, + normalize_indices=wrap_negative_indices) _array_operators = { "getitem": _getitem, @@ -948,8 +1036,6 @@ def max(self, values, *, indices_are_sorted=False, unique_indices=False, "rsub": _defer_to_unrecognized_arg("-", ufuncs.subtract, swap=True), "mul": _defer_to_unrecognized_arg("*", ufuncs.multiply), "rmul": _defer_to_unrecognized_arg("*", ufuncs.multiply, swap=True), - "div": _defer_to_unrecognized_arg("/", ufuncs.divide), - "rdiv": _defer_to_unrecognized_arg("/", ufuncs.divide, swap=True), "truediv": _defer_to_unrecognized_arg("/", ufuncs.true_divide), "rtruediv": _defer_to_unrecognized_arg("/", ufuncs.true_divide, swap=True), "floordiv": _defer_to_unrecognized_arg("//", ufuncs.floor_divide), @@ -1126,7 +1212,6 @@ def _set_array_abstract_methods(basearray): def register_jax_array_methods(): """Call this function once to register methods of JAX arrays""" _set_shaped_array_attributes(core.ShapedArray) - _set_shaped_array_attributes(core.DShapedArray) _set_array_base_attributes(ArrayImpl, exclude={'__getitem__'}) _set_tracer_aval_forwarding(core.Tracer, exclude={*_impl_only_array_methods, "at"}) diff --git a/jax/_src/numpy/einsum.py b/jax/_src/numpy/einsum.py index 9d745643b596..b8f72081df62 100644 --- a/jax/_src/numpy/einsum.py +++ b/jax/_src/numpy/einsum.py @@ -13,19 +13,20 @@ # limitations under the License. import collections -from typing import overload, Any, Callable, Sequence +from typing import overload, Any +from collections.abc import Callable, Sequence import numpy as np import opt_einsum -from jax._src import config +from jax._src import api from jax._src import core from jax._src import dtypes -from jax._src.api import jit, named_call +from jax._src.export import shape_poly from jax._src.lax import lax -from jax._src.lax.lax import PrecisionLike from jax._src.numpy import util -from jax._src.sharding_impls import canonicalize_sharding, NamedSharding, PartitionSpec as P +from jax._src.pjit import auto_axes +from jax._src.sharding_impls import canonicalize_sharding, NamedSharding from jax._src.typing import Array, ArrayLike, DTypeLike from jax._src.util import partition_list, set_module, unzip2 @@ -44,7 +45,7 @@ def einsum( *operands: ArrayLike, out: None = None, optimize: str | bool | list[tuple[int, ...]] = "auto", - precision: PrecisionLike = None, + precision: lax.PrecisionLike = None, preferred_element_type: DTypeLike | None = None, _dot_general: Callable[..., Array] = lax.dot_general, out_sharding=None, @@ -57,7 +58,7 @@ def einsum( *operands: ArrayLike | Sequence[Any], out: None = None, optimize: str | bool | list[tuple[int, ...]] = "auto", - precision: PrecisionLike = None, + precision: lax.PrecisionLike = None, preferred_element_type: DTypeLike | None = None, _dot_general: Callable[..., Array] = lax.dot_general, out_sharding=None, @@ -69,7 +70,7 @@ def einsum( *operands, out: None = None, optimize: str | bool | list[tuple[int, ...]] = "auto", - precision: PrecisionLike = None, + precision: lax.PrecisionLike = None, preferred_element_type: DTypeLike | None = None, _dot_general: Callable[..., Array] = lax.dot_general, out_sharding=None, @@ -288,6 +289,10 @@ def einsum( spec = operands[0] if isinstance(operands[0], str) else None path_type = 'optimal' if optimize is True else Unoptimized() if optimize is False else optimize + # Extract __jax_array__ before passing to contract_path() + operands = tuple(op.__jax_array__() if hasattr(op, "__jax_array__") else op + for op in operands) + # Allow handling of shape polymorphism non_constant_dim_types = { type(d) for op in operands if not isinstance(op, str) @@ -303,13 +308,33 @@ def einsum( *operands, einsum_call=True, use_blas=True, optimize=path_type) contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions) # pytype: disable=attribute-error + num_contractions = len(contractions) - jit_einsum = jit(_einsum, static_argnums=(1, 2, 3, 4, 5), inline=True) + out_sharding = canonicalize_sharding(out_sharding, 'einsum') + if out_sharding is not None and not isinstance(out_sharding, NamedSharding): + raise NotImplementedError( + "`out_sharding` argument of `einsum` only supports NamedSharding" + " instances.") + + jit_einsum = api.jit(_einsum, static_argnums=(1, 2, 3, 4, 5), inline=True) if spec is not None: - jit_einsum = named_call(jit_einsum, name=spec) + jit_einsum = api.named_call(jit_einsum, name=spec) operand_arrays = list(util.ensure_arraylike_tuple("einsum", operands)) - return jit_einsum(operand_arrays, contractions, precision, - preferred_element_type, _dot_general, out_sharding) + + if num_contractions > 1 and out_sharding is not None: + # TODO(yashkatariya): If the out_sharding is unreduced, figure out a way to + # run the dot_general unreduced_rule on these einsums because right now we + # drop into Auto mode skipping the checks happening in the rule. + return auto_axes( + jit_einsum, + axes=out_sharding.mesh.explicit_axes, + out_sharding=out_sharding, + )(operand_arrays, contractions=contractions, precision=precision, + preferred_element_type=preferred_element_type, _dot_general=_dot_general, + out_sharding=None) + else: + return jit_einsum(operand_arrays, contractions, precision, + preferred_element_type, _dot_general, out_sharding) # Enable other modules to override einsum_contact_path. @@ -393,10 +418,8 @@ def einsum_path( .. _opt_einsum: https://github.com/dgasmith/opt_einsum """ - if optimize is True: - optimize = 'optimal' - elif optimize is False: - optimize = Unoptimized() + if isinstance(optimize, bool): + optimize = 'optimal' if optimize else Unoptimized() return opt_einsum.contract_path(subscripts, *operands, optimize=optimize) def _removechars(s, chars): @@ -411,22 +434,21 @@ def _einsum( _dot_general=lax.dot_general, out_sharding=None, ): - out_sharding = canonicalize_sharding(out_sharding, 'einsum') - if out_sharding is not None and not isinstance(out_sharding, NamedSharding): - raise NotImplementedError( - "`out_sharding` argument of `einsum` only supports NamedSharding" - " instances. Please file a bug if this is not enough for your use case.") - dtypes.check_user_dtype_supported(preferred_element_type, "einsum") if preferred_element_type is None: - preferred_element_type, output_weak_type = dtypes.result_type(*operands, return_weak_type_flag=True) + preferred_element_type, output_weak_type = dtypes.result_type( + *operands, return_weak_type_flag=True) else: + preferred_element_type = dtypes.check_and_canonicalize_user_dtype( + preferred_element_type, 'einsum' + ) output_weak_type = False def sum(x, axes): if dtypes.result_type(x, preferred_element_type) != x.dtype: x = x.astype(preferred_element_type) - return lax.reduce(x, np.array(0, x.dtype), - lax.add if x.dtype != bool else lax.bitwise_or, axes) + return lax.reduce( + x, np.array(0, x.dtype), lax.add if x.dtype != bool else lax.bitwise_or, + axes, out_sharding) def sum_uniques(operand, names, uniques): if uniques: @@ -456,8 +478,7 @@ def filter_singleton_dims(operand, names, other_shape, other_names): sqez_axes, keep_axes = partition_list(keep, list(range(operand.ndim))) return lax.squeeze(operand, sqez_axes), "".join(names[i] for i in keep_axes) - for i, (operand_indices, contracted_names_set, einstr) in enumerate(contractions): - last_contraction = i == len(contractions) - 1 + for operand_indices, contracted_names_set, einstr in contractions: contracted_names = sorted(contracted_names_set) input_str, result_names = einstr.split('->') input_names = input_str.split(',') @@ -515,7 +536,7 @@ def filter_singleton_dims(operand, names, other_shape, other_names): # NOTE(mattjj): this can fail non-deterministically in python3, maybe # due to opt_einsum - assert config.dynamic_shapes.value or all( + assert all( name in lhs_names and name in rhs_names and lhs.shape[lhs_names.index(name)] == rhs.shape[rhs_names.index(name)] for name in contracted_names), ( @@ -537,33 +558,27 @@ def filter_singleton_dims(operand, names, other_shape, other_names): names = batch_names_str + remaining_rhs_names + remaining_lhs_names if names == result_names: dimension_numbers = ((rhs_cont, lhs_cont), (rhs_batch, lhs_batch)) - k_out_sharding = ({} if out_sharding is None else - {'out_sharding': out_sharding}) + dot_out_sharding = ({} if out_sharding is None else + {'out_sharding': out_sharding}) operand = _dot_general(rhs, lhs, dimension_numbers, precision, preferred_element_type=preferred_element_type, - **k_out_sharding) + **dot_out_sharding) else: names = batch_names_str + remaining_lhs_names + remaining_rhs_names - if not last_contraction: - dot_general_out_sharding = None - elif out_sharding is not None and names != result_names: - if len(result_names) > len(out_sharding.spec): - out_sharding = out_sharding.with_spec( - out_sharding.spec._normalized_spec_for_aval(len(result_names))) - spec = out_sharding.spec - inverse_spec = tuple(spec[result_names.index(name)] for name in names) - dot_general_out_sharding = NamedSharding( - out_sharding.mesh, P(*inverse_spec)) - else: - dot_general_out_sharding = out_sharding # type: ignore dimension_numbers = ((lhs_cont, rhs_cont), (lhs_batch, rhs_batch)) - dot_general_out_sharding = ({} if dot_general_out_sharding is None else # type: ignore - {'out_sharding': dot_general_out_sharding}) + out_sharding = (_get_inverse_sharding(out_sharding, names, result_names) + if out_sharding is not None and names != result_names + else out_sharding) + dot_out_sharding = ({} if out_sharding is None else # type: ignore + {'out_sharding': out_sharding}) operand = _dot_general(lhs, rhs, dimension_numbers, precision, - preferred_element_type=preferred_element_type, - **dot_general_out_sharding) + preferred_element_type=preferred_element_type, + **dot_out_sharding) else: - raise NotImplementedError # if this is actually reachable, open an issue! + raise NotImplementedError( + "jax.numpy.einsum does not support simultaneous contraction of 3 or more" + " operands. Typically this means you've passed an unsupported path to" + " the einsum optimize parameter.") # the resulting 'operand' with axis labels 'names' should be a permutation # of the desired result @@ -576,3 +591,14 @@ def filter_singleton_dims(operand, names, other_shape, other_names): return lax._convert_element_type(operands[0], preferred_element_type, output_weak_type) + +def _get_inverse_sharding(out_sharding, names, result_names): + if len(result_names) > len(out_sharding.spec): + out_sharding = out_sharding.update(spec= + out_sharding.spec._normalized_spec_for_aval(len(result_names))) + spec = out_sharding.spec + inverse_spec = tuple(spec[result_names.index(name)] for name in names) + return NamedSharding(out_sharding.mesh, spec.update(partitions=inverse_spec)) + + +_poly_einsum_handlers[shape_poly._DimExpr] = shape_poly._einsum_contract_path diff --git a/jax/_src/numpy/error.py b/jax/_src/numpy/error.py new file mode 100644 index 000000000000..c0e48f56fc9c --- /dev/null +++ b/jax/_src/numpy/error.py @@ -0,0 +1,198 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +from typing import Literal +from collections.abc import Sequence + +import numpy as np + +from jax._src import config +from jax._src import dtypes +from jax._src import error_check as error_check_lib +from jax._src.numpy import array_constructors +from jax._src.numpy import array_creation +from jax._src.numpy import lax_numpy +from jax._src.numpy import ufuncs +from jax._src.numpy import reductions +from jax._src.typing import Array, ArrayLike + +Category = Literal["nan", "divide", "oob"] + + +def _is_category_disabled( + category: Category | None, +) -> bool: + """Check if the error checking behavior for the given category is disabled.""" + if category is None: + return False + if category == "nan": + raise ValueError("nan is deprecated. Use `_set_error_if_nan` instead.") + if category == "divide": + raise ValueError( + "divide is deprecated. Use `_set_error_if_divide_by_zero` instead." + ) + if category == "oob": + return config.error_checking_behavior_oob.value == "ignore" + raise ValueError(f"Invalid category: {category}") + + +def _set_error_if_with_category( + pred: Array, + /, + msg: str, + category: Category | None = None, +) -> None: + """Set the internal error state if any element of `pred` is `True`. + + This function is similar to :func:`set_error_if`, but it also takes a category + argument. The category can be "nan", "divide", or "oob". The error checking + behavior for each category can be configured using + :func:`set_error_checking_behavior`. If not provided, there will be no + category. + + This function is intended for use in JAX internal APIs (e.g., `jax.numpy`) + to perform category-specific runtime checks tied to the operation being + performed. + """ + if _is_category_disabled(category): + return + + error_check_lib.set_error_if(pred, msg) + + +def _set_error_if_nan(pred: Array, /): + """Set the internal error state if any element of `pred` is `NaN`. + + This function is disabled if the `jax_error_checking_behavior_nan` flag is + set to "ignore". + """ + if config.error_checking_behavior_nan.value == "ignore": + return + + if not dtypes.issubdtype(pred.dtype, np.floating): # only check floats + return + + error_check_lib.set_error_if(ufuncs.isnan(pred), "NaN encountered") + + +def _set_error_if_divide_by_zero(pred: Array, /): + """Set the internal error state if any element of `pred` is zero. + + This function is intended for checking if the denominator of a division is + zero. + + This function is disabled if the `jax_error_checking_behavior_divide` flag is + set to "ignore". + """ + if config.error_checking_behavior_divide.value == "ignore": + return + + zero = array_creation.zeros_like(pred, shape=()) + error_check_lib.set_error_if(pred == zero, "Division by zero encountered") + + +def _check_precondition_oob_gather( + shape: tuple[int, ...], gather_indices: ArrayLike +) -> None: + """Check for out of bounds errors before calling `lax.gather`.""" + if config.error_checking_behavior_oob.value == "ignore": + return + if not np.size(gather_indices): + return + + gather_indices = array_constructors.array(gather_indices) + shape = array_constructors.array(shape, dtype=gather_indices.dtype) + error_check_lib.set_error_if( + ufuncs.logical_or( + reductions.min(gather_indices) < -shape, + reductions.max(gather_indices) >= shape, + ), + "Out of bounds encountered before calling `lax.gather`", + ) + + +def _check_precondition_oob_dynamic_slice( + shape: tuple[int, ...], + start_indices: Sequence[ArrayLike], + slice_sizes: list[int], + allow_negative_indices: list[bool], +) -> None: + """Check for out of bounds errors before calling `lax.dynamic_slice`.""" + if config.error_checking_behavior_oob.value == "ignore": + return + + start_indices = array_constructors.array(start_indices) + shape = array_constructors.array(shape, dtype=start_indices.dtype) + slice_sizes = array_constructors.array(slice_sizes, dtype=start_indices.dtype) + allow_negative_indices = array_constructors.array(allow_negative_indices, dtype='bool') + + lower_bound = lax_numpy.where(allow_negative_indices, -shape, 0) + error_check_lib.set_error_if( + ufuncs.logical_or( + ufuncs.minimum(start_indices, start_indices + slice_sizes) < lower_bound, + ufuncs.maximum(start_indices, start_indices + slice_sizes) >= shape, + ), + "Out of bounds encountered before calling `lax.dynamic_slice`", + ) + + +Behavior = Literal["ignore", "raise"] + + +class error_checking_behavior: + """A context manager to set the error checking behavior. + + If both `all` and a category are provided, the category will override the + `all` setting. + + When the error checking behavior is set to "ignore", all errors will be + ignored. When set to "raise", errors will be detected and recorded, but an + exception will not be raised immediately. Users must call + :func:`raise_if_error` to at the end of the computation to raise the + exception. + """ + + def __init__( + self, + *, + all: Behavior | None = None, + nan: Behavior | None = None, + divide: Behavior | None = None, + oob: Behavior | None = None, + ) -> None: + new_settings = {} + if all is not None: + new_settings["nan"] = new_settings["divide"] = new_settings["oob"] = all + if nan is not None: + new_settings["nan"] = nan + if divide is not None: + new_settings["divide"] = divide + if oob is not None: + new_settings["oob"] = oob + self.new_settings = new_settings + self.stack = contextlib.ExitStack() + + def __enter__(self): + config_flags = { + "nan": config.error_checking_behavior_nan, + "divide": config.error_checking_behavior_divide, + "oob": config.error_checking_behavior_oob, + } + for key, value in self.new_settings.items(): + self.stack.enter_context(config_flags[key](value)) + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.stack.close() diff --git a/jax/_src/numpy/fft.py b/jax/_src/numpy/fft.py index f962438f23bb..bb9d02079b25 100644 --- a/jax/_src/numpy/fft.py +++ b/jax/_src/numpy/fft.py @@ -16,10 +16,11 @@ from collections.abc import Sequence import operator + import numpy as np -from jax import lax from jax._src import dtypes +from jax._src.lax import fft as lax_fft from jax._src.lib import xla_client from jax._src.util import safe_zip from jax._src.numpy.util import ensure_arraylike, promote_dtypes_inexact @@ -45,7 +46,7 @@ def _fft_norm(s: Array, func_name: str, norm: str) -> Array: '"ortho" or "forward".') -def _fft_core(func_name: str, fft_type: lax.FftType, a: ArrayLike, +def _fft_core(func_name: str, fft_type: lax_fft.FftType, a: ArrayLike, s: Shape | None, axes: Sequence[int] | None, norm: str | None) -> Array: full_name = f"jax.numpy.fft.{func_name}" @@ -53,8 +54,8 @@ def _fft_core(func_name: str, fft_type: lax.FftType, a: ArrayLike, if s is not None: s = tuple(map(operator.index, s)) - if np.any(np.less(s, 0)): - raise ValueError("Shape should be non-negative.") + if np.any(np.less_equal(s, 0)): + raise ValueError("Shape should be positive.") if s is not None and axes is not None and len(s) != len(axes): # Same error as numpy. @@ -80,14 +81,14 @@ def _fft_core(func_name: str, fft_type: lax.FftType, a: ArrayLike, in_s = list(arr.shape) for axis, x in safe_zip(axes, s): in_s[axis] = x - if fft_type == lax.FftType.IRFFT: + if fft_type == lax_fft.FftType.IRFFT: in_s[-1] = (in_s[-1] // 2 + 1) # Cropping arr = arr[tuple(map(slice, in_s))] # Padding arr = jnp.pad(arr, [(0, x-y) for x, y in zip(in_s, arr.shape)]) else: - if fft_type == lax.FftType.IRFFT: + if fft_type == lax_fft.FftType.IRFFT: s = [arr.shape[axis] for axis in axes[:-1]] if axes: s += [max(0, 2 * (arr.shape[axes[-1]] - 1))] @@ -103,10 +104,10 @@ def _fft_core(func_name: str, fft_type: lax.FftType, a: ArrayLike, return transformed -def _fft_core_nd(arr: Array, fft_type: lax.FftType, s: Shape) -> Array: +def _fft_core_nd(arr: Array, fft_type: lax_fft.FftType, s: Shape) -> Array: # XLA supports N-D transforms up to N=3 so we use XLA's FFT N-D directly. if len(s) <= 3: - return lax.fft(arr, fft_type, tuple(s)) + return lax_fft.fft(arr, fft_type, tuple(s)) # For larger N, we repeatedly apply N<=3 transforms until we reach the # requested dimension. We special case N=4 to use two 2-D transforms instead @@ -115,16 +116,16 @@ def _fft_core_nd(arr: Array, fft_type: lax.FftType, s: Shape) -> Array: n = 2 if len(s) == 4 else 3 src = tuple(range(arr.ndim - len(s), arr.ndim - n)) dst = tuple(range(arr.ndim - len(s) + n, arr.ndim)) - if fft_type in {lax.FftType.RFFT, lax.FftType.FFT}: - arr = lax.fft(arr, fft_type, tuple(s)[-n:]) + if fft_type in {lax_fft.FftType.RFFT, lax_fft.FftType.FFT}: + arr = lax_fft.fft(arr, fft_type, tuple(s)[-n:]) arr = jnp.moveaxis(arr, src, dst) - arr = _fft_core_nd(arr, lax.FftType.FFT, s[:-n]) + arr = _fft_core_nd(arr, lax_fft.FftType.FFT, s[:-n]) arr = jnp.moveaxis(arr, dst, src) else: arr = jnp.moveaxis(arr, src, dst) - arr = _fft_core_nd(arr, lax.FftType.IFFT, s[:-n]) + arr = _fft_core_nd(arr, lax_fft.FftType.IFFT, s[:-n]) arr = jnp.moveaxis(arr, dst, src) - arr = lax.fft(arr, fft_type, tuple(s)[-n:]) + arr = lax_fft.fft(arr, fft_type, tuple(s)[-n:]) return arr @@ -199,7 +200,7 @@ def fftn(a: ArrayLike, s: Shape | None = None, >>> jnp.allclose(x, jnp.fft.ifftn(x_fftn)) Array(True, dtype=bool) """ - return _fft_core('fftn', lax.FftType.FFT, a, s, axes, norm) + return _fft_core('fftn', lax_fft.FftType.FFT, a, s, axes, norm) def ifftn(a: ArrayLike, s: Shape | None = None, @@ -267,7 +268,7 @@ def ifftn(a: ArrayLike, s: Shape | None = None, [[ 2.5 +0.j 0. -0.58j 0. +0.58j] [ 0.17+0.j -0.83-0.29j -0.83+0.29j]] """ - return _fft_core('ifftn', lax.FftType.IFFT, a, s, axes, norm) + return _fft_core('ifftn', lax_fft.FftType.IFFT, a, s, axes, norm) def rfftn(a: ArrayLike, s: Shape | None = None, @@ -358,7 +359,7 @@ def rfftn(a: ArrayLike, s: Shape | None = None, >>> jnp.fft.rfftn(x1) Array([10.+0.j, -2.+2.j, -2.+0.j], dtype=complex64) """ - return _fft_core('rfftn', lax.FftType.RFFT, a, s, axes, norm) + return _fft_core('rfftn', lax_fft.FftType.RFFT, a, s, axes, norm) def irfftn(a: ArrayLike, s: Shape | None = None, @@ -435,7 +436,7 @@ def irfftn(a: ArrayLike, s: Shape | None = None, [[-2., -2., -2.], [-2., -2., -2.]]], dtype=float32) """ - return _fft_core('irfftn', lax.FftType.IRFFT, a, s, axes, norm) + return _fft_core('irfftn', lax_fft.FftType.IRFFT, a, s, axes, norm) def _axis_check_1d(func_name: str, axis: int | None): @@ -446,7 +447,7 @@ def _axis_check_1d(func_name: str, axis: int | None): "Got axis = %r." % (full_name, full_name, axis) ) -def _fft_core_1d(func_name: str, fft_type: lax.FftType, +def _fft_core_1d(func_name: str, fft_type: lax_fft.FftType, a: ArrayLike, n: int | None, axis: int | None, norm: str | None) -> Array: _axis_check_1d(func_name, axis) @@ -514,7 +515,7 @@ def fft(a: ArrayLike, n: int | None = None, >>> jnp.allclose(x, jnp.fft.ifft(x_fft)) Array(True, dtype=bool) """ - return _fft_core_1d('fft', lax.FftType.FFT, a, n=n, axis=axis, + return _fft_core_1d('fft', lax_fft.FftType.FFT, a, n=n, axis=axis, norm=norm) @@ -570,7 +571,7 @@ def ifft(a: ArrayLike, n: int | None = None, [ 0.67+0.58j -0.5 +1.44j 0.17+2.02j 1.83+0.29j] [ 0.67-0.58j -0.5 -1.44j 0.17-2.02j 1.83-0.29j]] """ - return _fft_core_1d('ifft', lax.FftType.IFFT, a, n=n, axis=axis, + return _fft_core_1d('ifft', lax_fft.FftType.IFFT, a, n=n, axis=axis, norm=norm) @@ -631,7 +632,7 @@ def rfft(a: ArrayLike, n: int | None = None, [ 1.-2.j, 3.-4.j, 5.-6.j], [-1.+0.j, -1.+0.j, -1.+0.j]], dtype=complex64) """ - return _fft_core_1d('rfft', lax.FftType.RFFT, a, n=n, axis=axis, + return _fft_core_1d('rfft', lax_fft.FftType.RFFT, a, n=n, axis=axis, norm=norm) @@ -691,7 +692,7 @@ def irfft(a: ArrayLike, n: int | None = None, [-0.75, -1.25, -1.75], [ 0.25, 0.75, 1.25]], dtype=float32) """ - return _fft_core_1d('irfft', lax.FftType.IRFFT, a, n=n, axis=axis, + return _fft_core_1d('irfft', lax_fft.FftType.IRFFT, a, n=n, axis=axis, norm=norm) @@ -712,7 +713,7 @@ def hfft(a: ArrayLike, n: int | None = None, are supported. Default is "backward". Returns: - A real-valued array containing the one-dimensional discret Fourier transform + A real-valued array containing the one-dimensional discrete Fourier transform of ``a`` by exploiting its inherent Hermitian-symmetry, having a dimension of ``n`` along ``axis``. @@ -781,7 +782,7 @@ def hfft(a: ArrayLike, n: int | None = None, conj_a = ufuncs.conj(a) _axis_check_1d('hfft', axis) nn = (conj_a.shape[axis] - 1) * 2 if n is None else n - return _fft_core_1d('hfft', lax.FftType.IRFFT, conj_a, n=n, axis=axis, + return _fft_core_1d('hfft', lax_fft.FftType.IRFFT, conj_a, n=n, axis=axis, norm=norm) * nn @@ -831,12 +832,12 @@ def ihfft(a: ArrayLike, n: int | None = None, _axis_check_1d('ihfft', axis) arr = jnp.asarray(a) nn = arr.shape[axis] if n is None else n - output = _fft_core_1d('ihfft', lax.FftType.RFFT, arr, n=n, axis=axis, + output = _fft_core_1d('ihfft', lax_fft.FftType.RFFT, arr, n=n, axis=axis, norm=norm) return ufuncs.conj(output) * (1 / nn) -def _fft_core_2d(func_name: str, fft_type: lax.FftType, a: ArrayLike, +def _fft_core_2d(func_name: str, fft_type: lax_fft.FftType, a: ArrayLike, s: Shape | None, axes: Sequence[int], norm: str | None) -> Array: full_name = f"jax.numpy.fft.{func_name}" @@ -923,7 +924,7 @@ def fft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1), >>> jnp.allclose(x, jnp.fft.ifft2(x_fft2)) Array(True, dtype=bool) """ - return _fft_core_2d('fft2', lax.FftType.FFT, a, s=s, axes=axes, + return _fft_core_2d('fft2', lax_fft.FftType.FFT, a, s=s, axes=axes, norm=norm) @@ -995,7 +996,7 @@ def ifft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1), [-0.33-0.58j, -0.33-0.58j], [-0.33+0.58j, -0.33+0.58j]]], dtype=complex64) """ - return _fft_core_2d('ifft2', lax.FftType.IFFT, a, s=s, axes=axes, + return _fft_core_2d('ifft2', lax_fft.FftType.IFFT, a, s=s, axes=axes, norm=norm) @@ -1074,7 +1075,7 @@ def rfft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1), [ 3.47+10.11j, 6.43+11.42j, 9.38+12.74j], [ 3.19 +1.63j, 4.4 +1.38j, 5.61 +1.12j]]], dtype=complex64) """ - return _fft_core_2d('rfft2', lax.FftType.RFFT, a, s=s, axes=axes, + return _fft_core_2d('rfft2', lax_fft.FftType.RFFT, a, s=s, axes=axes, norm=norm) @@ -1149,7 +1150,7 @@ def irfft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1), [ 0. , 0. , 0. ], [ 0. , 0. , 0. ]]], dtype=float32) """ - return _fft_core_2d('irfft2', lax.FftType.IRFFT, a, s=s, axes=axes, + return _fft_core_2d('irfft2', lax_fft.FftType.IRFFT, a, s=s, axes=axes, norm=norm) @@ -1175,7 +1176,8 @@ def fftfreq(n: int, d: ArrayLike = 1.0, *, dtype: DTypeLike | None = None, - :func:`jax.numpy.fft.rfftfreq`: frequencies for use with :func:`~jax.numpy.fft.rfft` and :func:`~jax.numpy.fft.irfft`. """ - dtype = dtype or dtypes.canonicalize_dtype(dtypes.float_) + dtype = dtype or dtypes.default_float_dtype() + if isinstance(n, (list, tuple)): raise ValueError( "The n argument of jax.numpy.fft.fftfreq only takes an int. " @@ -1186,22 +1188,13 @@ def fftfreq(n: int, d: ArrayLike = 1.0, *, dtype: DTypeLike | None = None, "The d argument of jax.numpy.fft.fftfreq only takes a single value. " "Got d = %s." % list(d)) - k = jnp.zeros(n, dtype=dtype, device=device) - if n % 2 == 0: - # k[0: n // 2 - 1] = jnp.arange(0, n // 2 - 1) - k = k.at[0: n // 2].set(jnp.arange(0, n // 2, dtype=dtype)) - - # k[n // 2:] = jnp.arange(-n // 2, -1) - k = k.at[n // 2:].set(jnp.arange(-n // 2, 0, dtype=dtype)) - - else: - # k[0: (n - 1) // 2] = jnp.arange(0, (n - 1) // 2) - k = k.at[0: (n - 1) // 2 + 1].set(jnp.arange(0, (n - 1) // 2 + 1, dtype=dtype)) - - # k[(n - 1) // 2 + 1:] = jnp.arange(-(n - 1) // 2, -1) - k = k.at[(n - 1) // 2 + 1:].set(jnp.arange(-(n - 1) // 2, 0, dtype=dtype)) + out_dtype = dtype + dtype = dtypes.finfo(dtypes.to_inexact_dtype(dtype)).dtype - return k / jnp.array(d * n, dtype=dtype, device=device) + i = jnp.arange(n, dtype=dtype, device=device) + k = ((i + n//2) % n - n//2) + result = k.astype(dtype) / jnp.array(d * n, dtype=dtype, device=device) + return result.astype(out_dtype) def rfftfreq(n: int, d: ArrayLike = 1.0, *, dtype: DTypeLike | None = None, @@ -1227,7 +1220,7 @@ def rfftfreq(n: int, d: ArrayLike = 1.0, *, dtype: DTypeLike | None = None, - :func:`jax.numpy.fft.fftfreq`: frequencies for use with :func:`~jax.numpy.fft.fft` and :func:`~jax.numpy.fft.ifft`. """ - dtype = dtype or dtypes.canonicalize_dtype(dtypes.float_) + dtype = dtype or dtypes.default_float_dtype() if isinstance(n, (list, tuple)): raise ValueError( "The n argument of jax.numpy.fft.rfftfreq only takes an int. " diff --git a/jax/_src/numpy/index_tricks.py b/jax/_src/numpy/index_tricks.py index ec67d7489f30..b34be9e3223e 100644 --- a/jax/_src/numpy/index_tricks.py +++ b/jax/_src/numpy/index_tricks.py @@ -17,17 +17,19 @@ from collections.abc import Iterable from typing import Any, Union -import jax +import numpy as np + +from jax._src import config from jax._src import core +from jax._src.numpy.array_constructors import array from jax._src.numpy.util import promote_dtypes +from jax._src.numpy.array_creation import linspace from jax._src.numpy.lax_numpy import ( - arange, array, concatenate, expand_dims, linspace, meshgrid, stack, transpose + arange, concatenate, expand_dims, meshgrid, stack, transpose ) from jax._src.typing import Array, ArrayLike from jax._src.util import set_module -import numpy as np - export = set_module('jax.numpy') @@ -83,7 +85,7 @@ def __getitem__(self, key: slice | tuple[slice, ...]) -> Array: if isinstance(key, slice): return _make_1d_grid_from_slice(key, op_name="mgrid") output: Iterable[Array] = (_make_1d_grid_from_slice(k, op_name="mgrid") for k in key) - with jax.numpy_dtype_promotion('standard'): + with config.numpy_dtype_promotion('standard'): output = promote_dtypes(*output) output_arr = meshgrid(*output, indexing='ij', sparse=False) if len(output_arr) == 0: @@ -128,7 +130,7 @@ def __getitem__( if isinstance(key, slice): return _make_1d_grid_from_slice(key, op_name="ogrid") output: Iterable[Array] = (_make_1d_grid_from_slice(k, op_name="ogrid") for k in key) - with jax.numpy_dtype_promotion('standard'): + with config.numpy_dtype_promotion('standard'): output = promote_dtypes(*output) return meshgrid(*output, indexing='ij', sparse=True) @@ -228,10 +230,8 @@ class RClass(_AxisConcat): An imaginary value for ``step`` will create a ``jnp.linspace`` object instead, which includes the right endpoint: - >>> jnp.r_[-1:1:6j, 0, jnp.array([1,2,3])] - Array([-1. , -0.6 , -0.20000002, 0.20000005, - 0.6 , 1. , 0. , 1. , - 2. , 3. ], dtype=float32) + >>> jnp.r_[-1:1:6j, 0, jnp.array([1,2,3])] # doctest: +SKIP + Array([-1. , -0.6, -0.2, 0.2, 0.6, 1. , 0. , 1. , 2. , 3. ], dtype=float32) Use a string directive of the form ``"axis,dims,trans1d"`` as the first argument to specify concatenation axis, minimum number of dimensions, and the position of the diff --git a/jax/_src/numpy/indexing.py b/jax/_src/numpy/indexing.py index 5d59bb53b457..4df769f9fc70 100644 --- a/jax/_src/numpy/indexing.py +++ b/jax/_src/numpy/indexing.py @@ -15,36 +15,568 @@ # pytype: skip-file """Indexing code for jax.numpy.""" +from __future__ import annotations + +import dataclasses +import enum from functools import partial import operator import string -from typing import Any, NamedTuple, Sequence +from typing import Any, NamedTuple +from collections.abc import Sequence import numpy as np -import jax -from jax import lax +from jax._src import api from jax._src import array from jax._src import config from jax._src import core from jax._src import dispatch from jax._src import dtypes from jax._src import errors -from jax._src.api import jit -from jax._src.lax import lax as lax_internal +from jax._src import literals +from jax._src.lax import lax +from jax._src.lax import slicing +from jax._src.lax import utils as lax_utils +from jax._src.numpy import array_constructors from jax._src.numpy import einsum -from jax._src import mesh as mesh_lib -from jax._src.pjit import auto_axes +from jax._src.numpy import error as jnp_error from jax._src.numpy import lax_numpy from jax._src.numpy import ufuncs from jax._src.numpy import util -from jax._src.tree_util import tree_flatten -from jax._src.typing import Array, ArrayLike, StaticScalar -from jax._src.util import canonicalize_axis, set_module, tuple_replace, safe_zip +from jax._src.partition_spec import PartitionSpec +from jax._src.pjit import auto_axes +from jax._src.sharding_impls import canonicalize_sharding, NamedSharding +from jax._src.tree_util import tree_flatten, tree_unflatten, register_pytree_node_class +from jax._src.typing import Array, ArrayLike, Index, StaticScalar +from jax._src.util import canonicalize_axis, safe_zip, set_module, tuple_update, unzip3 export = set_module('jax.numpy') +# Internal utilities for parsing and validating NumPy-style indices. + +class IndexType(enum.Enum): + """Enum for tracking the type of an index.""" + NONE = "none" + SLICE = "slice" + ELLIPSIS = "ellipsis" + INTEGER = "integer" + BOOLEAN = "boolean" + ARRAY = "array" + + @classmethod + def from_index(cls, idx: Index) -> IndexType: + """Create an IndexType enum from a supported JAX array index.""" + if idx is None: + return cls.NONE + elif idx is Ellipsis: + return cls.ELLIPSIS + elif isinstance(idx, slice): + return cls.SLICE + elif _is_integer_index(idx): + return cls.INTEGER + elif _is_boolean_index(idx): + return cls.BOOLEAN + elif isinstance(idx, (Array, np.ndarray, literals.TypedNdArray)): + if dtypes.issubdtype(idx.dtype, np.integer): + return cls.ARRAY + else: + raise TypeError( + f"Indexer must have integer or boolean type, got indexer with type {idx.dtype}") + elif isinstance(idx, str): + # TODO(jakevdp): this TypeError is for backward compatibility. + # We should switch to IndexError for consistency. + raise TypeError(f"JAX does not support string indexing; got {idx=}") + elif isinstance(idx, Sequence): + if not idx: # empty indices default to float, so special-case this. + return cls.ARRAY + idx_aval = api.eval_shape(array_constructors.asarray, idx) + if idx_aval.dtype == bool: + return cls.BOOLEAN + elif dtypes.issubdtype(idx_aval.dtype, np.integer): + return cls.ARRAY + else: + raise TypeError( + f"Indexer must have integer or boolean type, got indexer with type {idx_aval.dtype}") + elif isinstance(idx, (float, complex, np.generic)): + raise TypeError( + f"Indexer must have integer or boolean type, got indexer with type {np.dtype(type(idx))}") + else: + raise IndexError("only integers, slices (`:`), ellipsis (`...`), newaxis (`None`)" + f" and integer or boolean arrays are valid indices. Got {idx}") + + +class ParsedIndex(NamedTuple): + """Structure for tracking an indexer parsed within the context of an array shape.""" + index: Index # type: ignore[assignment] # seems to be a strange misfire by mypy. + typ: IndexType + consumed_axes: tuple[int, ...] + + +def _parse_indices( + indices: tuple[Index, ...], + shape: tuple[int, ...], +) -> list[ParsedIndex]: + """Parse indices in the context of an array shape. + + Args: + indices: a tuple of user-supplied indices to be parsed. + shape: the shape of the array being indexed. + + Returns: + The list of parsed indices stored in :class:`ParsedIndex` objects. + This list will have the same length as ``indices``. + + Raises: + IndexError: if any unrecognized index types are present or if there + are too many indices, or too many ellipses. + """ + # 1. go through indices to count the number of consumed dimensions. + # This is required to determine the effect of any ellipses. + dimensions_consumed: list[int] = [] + ellipses_indices: list[int] = [] + index_types: list[IndexType] = [] + for i, idx in enumerate(indices): + typ = IndexType.from_index(idx) + index_types.append(typ) + + if typ == IndexType.NONE: + dimensions_consumed.append(0) + elif typ == IndexType.ELLIPSIS: + # We don't yet know how many dimensions are consumed, so set to zero + # for now and update later. + dimensions_consumed.append(0) + ellipses_indices.append(i) + elif typ == IndexType.BOOLEAN: + dimensions_consumed.append(np.ndim(idx)) # type: ignore[arg-type] + elif typ in [IndexType.INTEGER, IndexType.ARRAY, IndexType.SLICE]: + dimensions_consumed.append(1) + else: + raise IndexError(f"Unrecognized index type: {typ}") + + # 2. Validate the consumed dimensions and ellipses. + if len(ellipses_indices) > 1: + raise IndexError("an index can only have a single ellipsis ('...')") + total_consumed = sum(dimensions_consumed) + if total_consumed > len(shape): + raise IndexError(f"Too many indices: array is {len(shape)}-dimensional," + f" but {total_consumed} were indexed") + if ellipses_indices: + dimensions_consumed[ellipses_indices[0]] = len(shape) - total_consumed + + # 3. Generate the final sequence of parsed indices. + result: list[ParsedIndex] = [] + current_dim = 0 + for index, typ, n_consumed in safe_zip(indices, index_types, dimensions_consumed): + consumed_axes = tuple(range(current_dim, current_dim + n_consumed)) + current_dim += len(consumed_axes) + result.append(ParsedIndex(index=index, typ=typ, consumed_axes=consumed_axes)) + return result + + +@register_pytree_node_class +@dataclasses.dataclass(frozen=True, kw_only=True) +class NDIndexer: + """Object that implements NumPy-style indexing operations on top of JAX. + + Generally this will be constructed via the :meth:`NDIndexer.from_raw_indices` + method. + + Attributes: + shape: the shape of the array being indexed. + indices: a list of :class:`ParsedIndex` objects. + """ + shape: tuple[int, ...] + indices: list[ParsedIndex] + + @classmethod + def from_raw_indices(cls, indices: Index | tuple[Index, ...], shape: tuple[int, ...]) -> NDIndexer: + """Create an NDIndexer object from raw user-supplied indices.""" + indices = eliminate_deprecated_list_indexing(indices) + indices = _parse_indices(indices, shape) + return cls(shape=shape, indices=indices) + + def validate_static_indices(self, normalize_indices: bool = True) -> None: + """Check that all static integer indices are in-bounds. + + Raises an IndexError in case of out-of-bound indices + """ + for position, idx in enumerate(self.indices): + if idx.typ == IndexType.INTEGER: + assert isinstance(idx.index, (int, np.integer)) + i = operator.index(idx.index) + axis, = idx.consumed_axes + size = self.shape[axis] + normed_idx = i + size if normalize_indices and i < 0 else i + if not 0 <= normed_idx < size: + raise IndexError(f"index {i} out of bounds for axis {axis} with size {size}" + f" ({normalize_indices=})") + + def validate_slices(self) -> None: + """Check that all slices have static start/stop/step values. + + Raises an IndexError in case of non-static entries. + """ + for position, idx in enumerate(self.indices): + if idx.typ == IndexType.SLICE: + assert isinstance(idx.index, slice) + if not all(_is_slice_element_none_or_constant_or_symbolic(val) + for val in [idx.index.start, idx.index.stop, idx.index.step]): + raise IndexError("Slice entries must be static integers." + f" Got {idx.index} at position {position}") + + @staticmethod + def is_sharded(arr) -> bool: + """Check whether the array is sharded.""" + return isinstance(arr, array.ArrayImpl) and not dispatch.is_single_device_sharding(arr.sharding) + + def has_partial_slices(self) -> bool: + """Check whether the indexer contains partial slices. + + For sharded arrays, partial slices cannot automatically propagate + sharding. + """ + for idx in self.indices: + if idx.typ == IndexType.INTEGER: + return True + if idx.typ == IndexType.SLICE: + slc = idx.index + assert isinstance(slc, slice) + axis, = idx.consumed_axes + size = self.shape[axis] + start, stop, step = slc.indices(self.shape[axis]) + if abs(step) != 1 or abs(stop - start) != size: + return True + return False + + def expand_bool_indices(self) -> NDIndexer: + """Returns a new NDIndexer with boolean indices replaced by array indices. + + The only exception are scalar boolean indices, which are left in-place. + """ + expanded_indices: list[ParsedIndex] = [] + + for position, idx in enumerate(self.indices): + if idx.typ != IndexType.BOOLEAN: + expanded_indices.append(idx) + continue + if not core.is_concrete(idx.index): + # TODO(mattjj): improve this error by tracking _why_ the indices are not concrete + raise errors.NonConcreteBooleanIndexError(core.get_aval(idx.index)) + assert isinstance(idx.index, (bool, np.ndarray, Array, literals.TypedNdArray, list)) + if np.ndim(idx.index) == 0: + # Scalar booleans + assert idx.consumed_axes == () + expanded_indices.append(ParsedIndex(index=bool(idx.index), typ=idx.typ, consumed_axes=())) + continue + idx_shape = np.shape(idx.index) + expected_shape = [self.shape[i] for i in idx.consumed_axes] + if not all(s1 in (0, s2) for s1, s2 in zip(idx_shape, expected_shape)): + raise IndexError("boolean index did not match shape of indexed array in index" + f" {position}: got {idx_shape}, expected {expected_shape}") + expanded_indices_raw = np.where(np.asarray(idx.index)) + expanded_indices.extend(ParsedIndex(index=i, typ=IndexType.ARRAY, consumed_axes=(axis,)) + for i, axis in safe_zip(expanded_indices_raw, idx.consumed_axes)) + return NDIndexer(shape=self.shape, indices=expanded_indices) + + def expand_scalar_bool_indices(self, sharding_spec: Any = None) -> tuple[NDIndexer, Any]: + new_shape = list(self.shape) + new_sharding_spec = list((None for _ in self.shape) if sharding_spec is None else sharding_spec) + new_indices = list(self.indices) + current_dim = 0 + for i, idx in enumerate(self.indices): + if idx.typ == IndexType.BOOLEAN and np.ndim(idx.index) == 0: # type: ignore[arg-type] + new_shape.insert(i, 1) + new_sharding_spec.insert(i, None) + new_indices[i] = ParsedIndex( + np.arange(int(idx.index)), typ=IndexType.ARRAY, consumed_axes=(current_dim,)) # type: ignore[arg-type] + current_dim += 1 + else: + n_consumed = len(idx.consumed_axes) + new_indices[i] = ParsedIndex( + index=idx.index, + typ=idx.typ, + consumed_axes = tuple(range(current_dim, current_dim + n_consumed)) + ) + current_dim += n_consumed + new_sharding_spec = None if sharding_spec is None else tuple(new_sharding_spec) + return NDIndexer(indices=new_indices, shape=tuple(new_shape)), new_sharding_spec + + def convert_sequences_to_arrays(self) -> NDIndexer: + new_indices = [ParsedIndex(lax_numpy.asarray(idx.index), typ=idx.typ, consumed_axes=idx.consumed_axes) + if isinstance(idx.index, Sequence) else idx for idx in self.indices] + return NDIndexer(indices=new_indices, shape=self.shape) + + def expand_ellipses(self) -> NDIndexer: + """ + Returns a new indexer with ellipsis and implicit trailing slices + replaced by explicit empty slices. + """ + expanded: list[ParsedIndex] = [] + consumed = 0 + for idx in self.indices: + consumed += len(idx.consumed_axes) + if idx.typ == IndexType.ELLIPSIS: + for axis in idx.consumed_axes: + expanded.append(ParsedIndex(index=slice(None), typ=IndexType.SLICE, consumed_axes=(axis,))) + else: + expanded.append(idx) + for axis in range(consumed, len(self.shape)): + expanded.append(ParsedIndex(index=slice(None), typ=IndexType.SLICE, consumed_axes=(axis,))) + return NDIndexer(shape=self.shape, indices=expanded) + + def normalize_indices(self) -> NDIndexer: + new_indices: list[ParsedIndex] = [] + for idx in self.indices: + if idx.typ == IndexType.INTEGER: + axis, = idx.consumed_axes + size: ArrayLike = self.shape[axis] + if isinstance(idx.index, np.unsignedinteger): + normed_index: Index = idx.index + else: + normed_index = idx.index + size if idx.index < 0 else idx.index # type: ignore[assignment,operator] + new_indices.append(ParsedIndex(normed_index, typ=idx.typ, consumed_axes=idx.consumed_axes)) + elif idx.typ == IndexType.ARRAY: + assert isinstance(idx.index, (Array, np.ndarray, literals.TypedNdArray)) + axis, = idx.consumed_axes + if dtypes.issubdtype(idx.index.dtype, np.unsignedinteger): + normed_index = idx.index + else: + size = self.shape[axis] + if core.is_constant_dim(size): + size = lax._const(idx.index, size) + else: + size = lax.convert_element_type(core.dimension_as_value(size), + idx.index.dtype) + normed_index = lax.select(idx.index < 0, lax.add(idx.index, size), idx.index) + new_indices.append(ParsedIndex(normed_index, typ=idx.typ, consumed_axes=idx.consumed_axes)) + else: + new_indices.append(idx) + return NDIndexer(indices=new_indices, shape=self.shape) + + def compute_via_static_slice(self, arr: Array, *, + normalize_indices: bool = True, + mode: str | slicing.GatherScatterMode | None) -> Array: + """Equivalent of arr[idx] implemented in terms of static :func:`lax.slice` operations. + + This supports only INTEGER, ELLIPSIS, NONE, and SLICE indices, and will raise a + TypeError if other indices are present. + """ + if mode is None: + parsed_mode = slicing.GatherScatterMode.PROMISE_IN_BOUNDS + else: + parsed_mode = slicing.GatherScatterMode.from_any(mode) + + if parsed_mode not in [ + slicing.GatherScatterMode.PROMISE_IN_BOUNDS, slicing.GatherScatterMode.CLIP]: + raise ValueError("static_slice requires mode='promise_in_bounds' or mode='clip'") + + # Validation of the unmodified user indices. + if parsed_mode == slicing.GatherScatterMode.PROMISE_IN_BOUNDS: + self.validate_static_indices(normalize_indices=normalize_indices) + self.validate_slices() + + # For sharded inputs, indexing (like x[0]) and partial slices (like x[:2] as + # opposed to x[:]) lead to incorrect sharding semantics when computed via slice. + # TODO(yashkatariya): fix slice with sharding + if self.is_sharded(arr) and self.has_partial_slices(): + raise ValueError("static_slice with partial slices does not support nontrivial array sharding.") + + for position, pidx in enumerate(self.indices): + if pidx.typ in [IndexType.INTEGER, IndexType.ELLIPSIS, IndexType.SLICE, IndexType.NONE]: + pass + elif pidx.typ in [IndexType.ARRAY, IndexType.BOOLEAN]: + raise TypeError("static_slice: indices must be static scalars or slices." + f" Got {pidx.index} at position {position}") + else: + raise TypeError(f"static_slice: unrecognized index {pidx.index} at position {position}.") + + # Now re-iterate to generate static slices. + start_indices: list[int] = [] + limit_indices: list[int] = [] + strides: list[int] = [] + rev_axes: list[int] = [] + squeeze_axes: list[int] = [] + newaxis_dims: list[int] = [] + + expanded = self.expand_ellipses() + for pidx in expanded.indices: + if pidx.typ in [IndexType.ARRAY, IndexType.BOOLEAN, IndexType.ELLIPSIS]: + raise RuntimeError(f"Internal: unexpected index encountered: {pidx}") + elif pidx.typ == IndexType.NONE: + # Expanded axes indices are based on the rank of the array after slicing + # (tracked by start_indices) and squeezing (tracked by squeeze_axes), and + # expand_dims inserts dimensions in order, so we must also account for + # previous expanded dimensions. + newaxis_dims.append(len(start_indices) - len(squeeze_axes) + len(newaxis_dims) ) + elif pidx.typ == IndexType.INTEGER: + assert isinstance(pidx.index, (int, np.integer)) + axis, = pidx.consumed_axes + start_index = int(pidx.index) + if normalize_indices and start_index < 0: + start_index += arr.shape[axis] + # Normalization & validation have already been handled, so clip start_index + # to valid range + start_index = min(max(start_index, 0), arr.shape[axis] - 1) + start_indices.append(start_index) + limit_indices.append(start_index + 1) + strides.append(1) + squeeze_axes.append(axis) + elif pidx.typ == IndexType.SLICE: + assert isinstance(pidx.index, slice) + axis, = pidx.consumed_axes + size = arr.shape[axis] + start, stop, stride = pidx.index.indices(size) + if stride < 0: + new_start = stop + 1 + abs(start - stop - 1) % abs(stride) + start_indices.append(new_start) + limit_indices.append(max(new_start, start + 1)) + strides.append(abs(stride)) + rev_axes.append(axis) + else: + start_indices.append(start) + limit_indices.append(stop) + strides.append(stride) + else: + raise TypeError(f"static_slice: unrecognized index {pidx.index}") + result = arr + optional_strides: list[int] | None = None if all(s == 1 for s in strides) else strides + is_trivial_slice = optional_strides is None and all( + (start, stop) == (0, size) + for start, stop, size in zip(start_indices, limit_indices, arr.shape) + ) + if not is_trivial_slice: + result = slicing.slice(result, start_indices, limit_indices, optional_strides) + if rev_axes: + result = lax.rev(result, rev_axes) + if squeeze_axes: + result = lax.squeeze(result, squeeze_axes) + if newaxis_dims: + result = lax.expand_dims(result, newaxis_dims) + return result + + def compute_via_dynamic_slice(self, arr: Array, *, + normalize_indices: bool = True, + mode: str | slicing.GatherScatterMode | None) -> Array: + """Equivalent of arr[idx] implemented in terms of static :func:`lax.dynamic_slice`. + + This supports only INTEGER, ELLIPSIS, NONE, SLICE, and scalar ARRAY indices, + and will raise a TypeError if other indices are present. + """ + if mode is not None: + parsed_mode = slicing.GatherScatterMode.from_any(mode) + if parsed_mode not in [ + slicing.GatherScatterMode.PROMISE_IN_BOUNDS, slicing.GatherScatterMode.CLIP]: + raise ValueError("dynamic_slice requires mode='promise_in_bounds' or mode='clip'") + + # For sharded inputs, indexing (like x[0]) and partial slices (like x[:2] as + # opposed to x[:]) lead to incorrect sharding semantics when computed via slice. + # TODO(yashkatariya): fix slice with sharding + if self.is_sharded(arr) and self.has_partial_slices(): + raise ValueError("dynamic_slice with partial slices does not support nontrivial array sharding.") + + for position, pidx in enumerate(self.indices): + if pidx.typ in [IndexType.INTEGER, IndexType.ELLIPSIS, IndexType.NONE]: + pass + elif pidx.typ == IndexType.SLICE: + assert isinstance(pidx.index, slice) + if pidx.index.step is not None and pidx.index.step not in [-1, 1]: + raise TypeError("dynamic_slice: only unit steps supported in slice." + f" Got {pidx.index} at position {position}") + elif pidx.typ == IndexType.ARRAY: + if isinstance(pidx.index, Sequence) or np.shape(pidx.index) != (): # type: ignore[arg-type] + raise TypeError("dynamic_slice: only scalar indices allowed." + f" Got {pidx.index} at position {position}") + elif pidx.typ == IndexType.BOOLEAN: + raise TypeError("dynamic_slice: indices must be scalars or slices." + f" Got {pidx.index} at position {position}") + else: + raise TypeError(f"dynamic_slice: unrecognized index {pidx.index} at position {position}.") + + start_indices: list[ArrayLike] = [] + slice_sizes: list[int] = [] + rev_axes: list[int] = [] + squeeze_axes: list[int] = [] + newaxis_dims: list[int] = [] + + expanded = self.expand_ellipses() + for pidx in expanded.indices: + if pidx.typ in [IndexType.BOOLEAN, IndexType.ELLIPSIS]: + raise RuntimeError(f"Internal: unexpected index encountered: {pidx}") + elif pidx.typ == IndexType.NONE: + # Expanded axes indices are based on the rank of the array after slicing + # (tracked by start_indices) and squeezing (tracked by squeeze_axes), and + # expand_dims inserts dimensions in order, so we must also account for + # previous expanded dimensions. + newaxis_dims.append(len(start_indices) - len(squeeze_axes) + len(newaxis_dims)) + elif pidx.typ in [IndexType.INTEGER, IndexType.ARRAY]: + index = lax_numpy.asarray(pidx.index) + assert index.shape == () # Validated above. + axis, = pidx.consumed_axes + start_indices.append(index) + slice_sizes.append(1) + squeeze_axes.append(axis) + elif pidx.typ == IndexType.SLICE: + assert isinstance(pidx.index, slice) + axis, = pidx.consumed_axes + size = arr.shape[axis] + start, stop, stride = pidx.index.indices(size) + assert stride in [-1, 1] # validated above + if stride < 0: + new_start = stop + 1 + abs(start - stop - 1) % abs(stride) + start_indices.append(new_start) + slice_sizes.append(max(0, start + 1 - new_start)) + rev_axes.append(axis) + else: + start_indices.append(start) + slice_sizes.append(stop - start) + else: + raise TypeError(f"dynamic_slice: unrecognized index {pidx.index}") + result = arr + is_trivial_slice = all( + (slice_size == axis_size) + for slice_size, axis_size in zip(slice_sizes, arr.shape) + ) + if not is_trivial_slice: + result = slicing.dynamic_slice(arr, start_indices, slice_sizes, + allow_negative_indices=normalize_indices) + if rev_axes: + result = lax.rev(result, rev_axes) + if squeeze_axes: + result = lax.squeeze(result, squeeze_axes) + if newaxis_dims: + result = lax.expand_dims(result, newaxis_dims) + return result + + def is_advanced_int_indexer(self): + """Returns True if idx should trigger int array indexing, False otherwise.""" + # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing + return any(idx.typ in [IndexType.ARRAY, IndexType.BOOLEAN] and np.ndim(idx.index) > 0 + for idx in self.indices) + + def to_gather(self, x_sharding: NamedSharding | Any, + normalize_indices: bool = True) -> _GatherIndexer: + return _index_to_gather(self, x_sharding=x_sharding, normalize_indices=normalize_indices) + + def tree_flatten(self): + # split dynamic and static indices + def is_dynamic(i: ParsedIndex): + return i.typ in [IndexType.INTEGER, IndexType.ARRAY, IndexType.BOOLEAN] + raw_dynamic_indices = [i.index if is_dynamic(i) else None for i in self.indices] + static_metadata = [ + ParsedIndex(index=None, typ=i.typ, consumed_axes=i.consumed_axes) if is_dynamic(i) else i + for i in self.indices] + return raw_dynamic_indices, (self.shape, static_metadata) + + @classmethod + def tree_unflatten(cls, aux_data, children): + shape, static_metadata = aux_data + indices = [idx if dyn_index is None else ParsedIndex(dyn_index, idx.typ, idx.consumed_axes) + for dyn_index, idx in safe_zip(children, static_metadata)] + return cls(indices=indices, shape=shape) + + @export def take( a: ArrayLike, @@ -74,12 +606,14 @@ def take( fill_value: The fill value to return for out-of-bounds slices when mode is 'fill'. Ignored otherwise. Defaults to NaN for inexact types, the largest negative value for signed types, the largest positive value for unsigned types, and True for booleans. - unique_indices: If True, the implementation will assume that the indices are unique, - which can result in more efficient execution on some backends. If set to True and - indices are not unique, the output is undefined. + unique_indices: If True, the implementation will assume that the indices are unique + after normalization of negative indices, which lets the compiler emit more efficient + code during the backward pass. If set to True and normalized indices are not unique, + the result is implementation-defined and may be non-deterministic. indices_are_sorted : If True, the implementation will assume that the indices are - sorted in ascending order, which can lead to more efficient execution on some - backends. If set to True and indices are not sorted, the output is undefined. + sorted in ascending order after normalization of negative indices, which can lead + to more efficient execution on some backends. If set to True and normalized indices + are not sorted, the output is implementation-defined. Returns: Array of values extracted from ``a``. @@ -133,7 +667,7 @@ def take( fill_value=fill_value) -@partial(jit, static_argnames=('axis', 'mode', 'unique_indices', 'indices_are_sorted', 'fill_value')) +@api.jit(static_argnames=('axis', 'mode', 'unique_indices', 'indices_are_sorted', 'fill_value')) def _take(a, indices, axis: int | None = None, out=None, mode=None, unique_indices=False, indices_are_sorted=False, fill_value=None): if out is not None: @@ -147,17 +681,17 @@ def _take(a, indices, axis: int | None = None, out=None, mode=None, axis_idx = canonicalize_axis(axis, np.ndim(a)) if mode is None or mode == "fill": - gather_mode = lax.GatherScatterMode.FILL_OR_DROP + gather_mode = slicing.GatherScatterMode.FILL_OR_DROP # lax.gather() does not support negative indices, so we wrap them here indices = util._where(indices < 0, indices + a.shape[axis_idx], indices) elif mode == "raise": # TODO(phawkins): we have no way to report out of bounds errors yet. raise NotImplementedError("The 'raise' mode to jnp.take is not supported.") elif mode == "wrap": - indices = ufuncs.mod(indices, lax_internal._const(indices, a.shape[axis_idx])) - gather_mode = lax.GatherScatterMode.PROMISE_IN_BOUNDS + indices = ufuncs.mod(indices, lax._const(indices, a.shape[axis_idx])) + gather_mode = slicing.GatherScatterMode.PROMISE_IN_BOUNDS elif mode == "clip": - gather_mode = lax.GatherScatterMode.CLIP + gather_mode = slicing.GatherScatterMode.CLIP else: raise ValueError(f"Invalid mode '{mode}' for np.take") @@ -174,27 +708,27 @@ def _take(a, indices, axis: int | None = None, out=None, mode=None, return lax.full_like(a, 0, shape=out_shape) slice_sizes[axis_idx] = 1 - dnums = lax.GatherDimensionNumbers( + dnums = slicing.GatherDimensionNumbers( offset_dims=tuple( list(range(axis_idx)) + list(range(axis_idx + index_dims, len(a.shape) + index_dims - 1))), collapsed_slice_dims=(axis_idx,), start_index_map=(axis_idx,)) - return lax.gather(a, indices[..., None], dimension_numbers=dnums, - slice_sizes=tuple(slice_sizes), - mode=gather_mode, unique_indices=unique_indices, - indices_are_sorted=indices_are_sorted, fill_value=fill_value) + return slicing.gather(a, indices[..., None], dimension_numbers=dnums, + slice_sizes=tuple(slice_sizes), + mode=gather_mode, unique_indices=unique_indices, + indices_are_sorted=indices_are_sorted, fill_value=fill_value) def _normalize_index(index, axis_size): """Normalizes an index value in the range [-N, N) to the range [0, N).""" - if dtypes.issubdtype(dtypes.dtype(index, canonicalize=True), np.unsignedinteger): + if dtypes.issubdtype(dtypes.dtype(index), np.unsignedinteger): return index if core.is_constant_dim(axis_size): - axis_size_val = lax_internal._const(index, axis_size) + axis_size_val = lax._const(index, axis_size) else: axis_size_val = lax.convert_element_type(core.dimension_as_value(axis_size), - dtypes.dtype(index, canonicalize=True)) + dtypes.dtype(index)) if isinstance(index, (int, np.integer)): return lax.add(index, axis_size_val) if index < 0 else index else: @@ -202,12 +736,12 @@ def _normalize_index(index, axis_size): @export -@partial(jit, static_argnames=('axis', 'mode', 'fill_value')) +@api.jit(static_argnames=('axis', 'mode', 'fill_value')) def take_along_axis( arr: ArrayLike, indices: ArrayLike, - axis: int | None, - mode: str | lax.GatherScatterMode | None = None, + axis: int | None = -1, + mode: str | slicing.GatherScatterMode | None = None, fill_value: StaticScalar | None = None, ) -> Array: """Take elements from an array. @@ -282,7 +816,7 @@ def take_along_axis( [2]], dtype=int32) """ a, indices = util.ensure_arraylike("take_along_axis", arr, indices) - index_dtype = dtypes.dtype(indices) + index_dtype = indices.dtype idx_shape = np.shape(indices) if not dtypes.issubdtype(index_dtype, np.integer): raise TypeError("take_along_axis indices must be of integer type, got " @@ -304,8 +838,7 @@ def replace(tup, val): lst[axis_int] = val return tuple(lst) - use_64bit_index = any(not core.is_constant_dim(d) or d >= (1 << 31) for d in a.shape) - index_dtype = np.dtype('int64' if use_64bit_index else 'int32') + index_dtype = lax_utils.int_dtype_for_dim(a.shape, signed=True) indices = lax.convert_element_type(indices, index_dtype) axis_size = a.shape[axis_int] @@ -315,8 +848,10 @@ def replace(tup, val): return lax.full(out_shape, 0, a.dtype) if mode == "one_hot": + from jax import nn # pytype: disable=import-error + indices = _normalize_index(indices, axis_size) - hot = jax.nn.one_hot(indices, axis_size, dtype=np.bool_) + hot = nn.one_hot(indices, axis_size, dtype=np.bool_) if a.ndim == 1: return einsum.einsum("...b,b->...", hot, a, preferred_element_type=a.dtype) if axis_int > len(string.ascii_letters) - 2: @@ -386,22 +921,24 @@ def replace(tup, val): # Squeeze a to remove singleton dimensions. a = lax.squeeze(a, dims_to_squeeze) gather_indices_arr = lax.concatenate(gather_indices, dimension=j) - dnums = lax.GatherDimensionNumbers( + dnums = slicing.GatherDimensionNumbers( offset_dims=tuple(offset_dims), collapsed_slice_dims=tuple(collapsed_slice_dims), start_index_map=tuple(start_index_map), operand_batching_dims=tuple(operand_batching_dims), start_indices_batching_dims=tuple(start_indices_batching_dims)) - return lax.gather(a, gather_indices_arr, dnums, tuple(slice_sizes), - mode="fill" if mode is None else mode, fill_value=fill_value) + return slicing.gather(a, gather_indices_arr, dnums, tuple(slice_sizes), + mode="fill" if mode is None else mode, fill_value=fill_value) def _make_along_axis_idx(shape, indices, axis): - return tuple_replace(lax_numpy.indices(shape, sparse=True), axis, indices) + if axis < 0: + axis += len(shape) + return tuple_update(lax_numpy.indices(shape, sparse=True), axis, indices) @export -@partial(jit, static_argnames=('axis', 'inplace', 'mode')) +@api.jit(static_argnames=('axis', 'inplace', 'mode')) def put_along_axis( arr: ArrayLike, indices: ArrayLike, @@ -506,7 +1043,7 @@ def _is_valid_integer_index_for_slice(idx, size, mode): if _is_integer_index(idx): return -size <= idx < size try: - shape, dtype = np.shape(idx), dtypes.dtype(idx, canonicalize=True) + shape, dtype = np.shape(idx), dtypes.dtype(idx) except: return False if shape == () and np.issubdtype(dtype, np.integer): @@ -520,14 +1057,18 @@ def _is_contiguous_slice(idx): (idx.stop is None or _is_integer_index(idx.stop)) and (idx.step is None or (_is_integer_index(idx.step) and idx.step == 1))) -def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None) -> Array | None: +def _attempt_rewriting_take_via_slice( + arr: Array, indexer: NDIndexer, *, + mode: str | slicing.GatherScatterMode | None, + out_sharding: NamedSharding | PartitionSpec | None = None) -> Array | None: # attempt to compute _rewriting_take via lax.slice(); return None if not possible. - idx = idx if isinstance(idx, tuple) else (idx,) + + # TODO(jakevdp): update implementation to use indexer directly, and to reuse code + # from compute_via_static_slice + idx = tuple(i.index for i in indexer.indices) if not all(isinstance(i, int) for i in arr.shape): return None - if len(idx) > arr.ndim: - return None if any(i is None for i in idx): return None # TODO(jakevdp): handle newaxis case # For symbolic dimensions fallback to gather @@ -535,10 +1076,13 @@ def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None) -> for i in idx if isinstance(i, slice) for elt in (i.start, i.stop, i.step)): return None - if any(i is Ellipsis for i in idx): - # Remove ellipses and add trailing `slice(None)`. + # Remove ellipses and pad with trailing `slice(None)` if necessary. + # Do this before checking against rank of `arr` so that `...` can + # count as no dimensions at all (e.g. `my_1d_array[:, ...]` succeeds) idx = _canonicalize_tuple_index(arr.ndim, idx=idx) + if len(idx) > arr.ndim: + return None simple_revs = {i for i, ind in enumerate(idx) if _is_simple_reverse_slice(ind)} int_indices = {i for i, (ind, size) in enumerate(zip(idx, arr.shape)) @@ -551,7 +1095,7 @@ def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None) -> # TODO(yashkatariya): fix dynamic_slice with sharding is_sharded = (isinstance(arr, array.ArrayImpl) and not dispatch.is_single_device_sharding(arr.sharding)) - has_partial_slices = any(idx[i].indices(arr.shape[i]) != (0, arr.shape[i], 1) + has_partial_slices = any(idx[i].indices(arr.shape[i]) != (0, arr.shape[i], 1) # type: ignore[union-attr] for i in contiguous_slices) if is_sharded and (int_indices or has_partial_slices): return None @@ -570,7 +1114,7 @@ def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None) -> idx += (arr.ndim - len(idx)) * (slice(None),) start_indices: Sequence[ArrayLike] = [] - slice_sizes: Sequence[int] = [] + slice_sizes: list[int] = [] allow_negative_indices: list[bool] = [] for ind, size in safe_zip(idx, arr.shape): @@ -582,64 +1126,112 @@ def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None) -> allow_negative_indices.append(start < 0 or stop < 0) else: assert np.issubdtype(dtypes.dtype(ind), np.integer) # checked above - assert np.shape(ind) == () # checked above + assert np.shape(ind) == () # type: ignore[arg-type] # checked above start_indices.append(ind) slice_sizes.append(1) allow_negative_indices.append( not isinstance(ind, (int, np.integer)) or bool(ind < 0)) + # Try to use static slicing when possible. if all(isinstance(i, (int, np.integer)) and i >= 0 for i in start_indices): int_start_indices = [int(i) for i in start_indices] # type: ignore int_limit_indices = [i + s for i, s in zip(int_start_indices, slice_sizes)] - arr = lax.slice( + arr = slicing.slice( arr, start_indices=int_start_indices, limit_indices=int_limit_indices) else: # We must be careful with dtypes because dynamic_slice requires all # start indices to have matching types. if len(start_indices) > 1: - start_indices = util.promote_dtypes(*start_indices) - arr = lax.dynamic_slice( - arr, start_indices=start_indices, slice_sizes=slice_sizes, - allow_negative_indices=allow_negative_indices) + index_dtype = lax_utils.int_dtype_for_shape(arr.shape, signed=True) + start_indices = [lax.convert_element_type(idx, index_dtype) for idx in start_indices] + jnp_error._check_precondition_oob_dynamic_slice( + arr.shape, start_indices, slice_sizes, allow_negative_indices + ) + internal_ds = partial(slicing.dynamic_slice, slice_sizes=slice_sizes, + allow_negative_indices=allow_negative_indices) + if out_sharding is not None: + out_sharding = canonicalize_sharding(out_sharding, 'take') + arr = auto_axes( + internal_ds, + out_sharding=out_sharding, + axes=out_sharding.mesh.explicit_axes, # type: ignore + )(arr, start_indices) + else: + arr = internal_ds(arr, start_indices) if int_indices: arr = lax.squeeze(arr, tuple(int_indices)) return arr -def rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False, - mode=None, fill_value=None, out_sharding=None): +class IndexingStrategy(enum.Enum): + AUTO = 'auto' + GATHER = 'gather' + SCATTER = 'scatter' + STATIC_SLICE = 'static_slice' + DYNAMIC_SLICE = 'dynamic_slice' + + +def rewriting_take( + arr: Array, + idx: Index | tuple[Index, ...], *, + indices_are_sorted: bool = False, + unique_indices: bool = False, + mode: str | slicing.GatherScatterMode | None = None, + fill_value: ArrayLike | None = None, + normalize_indices: bool = True, + out_sharding: NamedSharding | PartitionSpec | None = None, + strategy: IndexingStrategy = IndexingStrategy.AUTO, +) -> Array: # Computes arr[idx]. # All supported cases of indexing can be implemented as an XLA gather, # followed by an optional reverse and broadcast_in_dim. + indexer = NDIndexer.from_raw_indices(idx, arr.shape) + + if not isinstance(strategy, IndexingStrategy): + raise TypeError(f"Expected strategy to be IndexingStrategy; got {strategy}") + + if config.check_static_indices.value and (mode is None or slicing.GatherScatterMode.from_any(mode) == slicing.GatherScatterMode.PROMISE_IN_BOUNDS): + indexer.validate_static_indices(normalize_indices=normalize_indices) + + if strategy == IndexingStrategy.STATIC_SLICE: + return indexer.compute_via_static_slice( + arr, mode=mode, normalize_indices=normalize_indices) + + if strategy == IndexingStrategy.DYNAMIC_SLICE: + return indexer.compute_via_dynamic_slice( + arr, mode=mode, normalize_indices=normalize_indices) + + # For simplicity of generated primitives, we call lax.slice or lax.dynamic_slice + # in the simplest cases: i.e. non-dynamic arrays indexed with integers and slices. + # TODO(jakevdp): lower to slice even when normalize_indices is False + if strategy == IndexingStrategy.AUTO and normalize_indices: + result = _attempt_rewriting_take_via_slice(arr, indexer, mode=mode, out_sharding=out_sharding) + if result is not None: + return result + + indexer = indexer.expand_bool_indices() + dynamic_idx, treedef = tree_flatten(indexer) + internal_gather = partial( + _gather, treedef=treedef, + indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, + mode=mode, fill_value=fill_value, normalize_indices=normalize_indices) + if out_sharding is not None: + out_sharding = canonicalize_sharding(out_sharding, 'take') + return auto_axes(internal_gather, out_sharding=out_sharding, + axes=out_sharding.mesh.explicit_axes, # type: ignore + )(arr, dynamic_idx) + return internal_gather(arr, dynamic_idx) - # For simplicity of generated primitives, we call lax.dynamic_slice in the - # simplest cases: i.e. non-dynamic arrays indexed with integers and slices. - - if (result := _attempt_rewriting_take_via_slice(arr, idx, mode)) is not None: - return result - - # TODO(mattjj,dougalm): expand dynamic shape indexing support - if config.dynamic_shapes.value and arr.ndim > 0: - try: aval = core.get_aval(idx) - except: pass - else: - if (isinstance(aval, core.DShapedArray) and aval.shape == () and - dtypes.issubdtype(aval.dtype, np.integer) and - not dtypes.issubdtype(aval.dtype, dtypes.bool_) and - isinstance(arr.shape[0], int)): - return lax.dynamic_index_in_dim(arr, idx, keepdims=False) - - treedef, static_idx, dynamic_idx = split_index_for_jit(idx, arr.shape) - return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, - unique_indices, mode, fill_value, out_sharding) # TODO(phawkins): re-enable jit after fixing excessive recompilation for # slice indexes (e.g., slice(0, 5, None), slice(10, 15, None), etc.). -# @partial(jit, static_argnums=(1, 2)) -def _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, - unique_indices, mode, fill_value, out_sharding): - idx = merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx) - indexer = index_to_gather(np.shape(arr), idx) # shared with _scatter_update +# @api.jit(static_argnums=(1, 2)) +def _gather(arr, dynamic_idx, *, treedef, indices_are_sorted, + unique_indices, mode, fill_value, normalize_indices): + parsed_idx = tree_unflatten(treedef, dynamic_idx) + indexer = parsed_idx.to_gather(core.typeof(arr).sharding, + normalize_indices=normalize_indices) + jnp_error._check_precondition_oob_gather(arr.shape, indexer.gather_indices) y = arr if fill_value is not None: @@ -647,7 +1239,7 @@ def _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, "fill_value argument to indexed get()") if np.ndim(fill_value) != 0: raise ValueError("fill_value argument to indexed get() must be a scalar") - if isinstance(fill_value, np.ndarray): + if isinstance(fill_value, (np.ndarray, literals.TypedNdArray)): fill_value = fill_value.item() if indexer.scalar_bool_dims: @@ -660,27 +1252,20 @@ def _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, # We avoid generating a gather when indexer.gather_indices.size is empty. if not core.is_empty_shape(indexer.gather_indices.shape): - internal_gather = partial( - lax.gather, - dimension_numbers=indexer.dnums, - slice_sizes=indexer.gather_slice_shape, + y = slicing.gather( + y, indexer.gather_indices, indexer.dnums, indexer.gather_slice_shape, unique_indices=unique_indices or indexer.unique_indices, indices_are_sorted=indices_are_sorted or indexer.indices_are_sorted, mode=mode, fill_value=fill_value) - if out_sharding is not None: - internal_gather = auto_axes( - internal_gather, axes=mesh_lib.get_abstract_mesh().axis_names, - out_shardings=out_sharding) - y = internal_gather(y, indexer.gather_indices) # Reverses axes with negative strides. if indexer.reversed_y_dims: y = lax.rev(y, indexer.reversed_y_dims) - # This adds np.newaxis/None dimensions. return lax.expand_dims(y, indexer.newaxis_dims) -class _Indexer(NamedTuple): + +class _GatherIndexer(NamedTuple): # The expected shape of the slice output. slice_shape: Sequence[int] # The slice shape to pass to lax.gather(). @@ -688,7 +1273,7 @@ class _Indexer(NamedTuple): # The gather indices to use. gather_indices: ArrayLike # A GatherDimensionNumbers object describing the gather to perform. - dnums: lax.GatherDimensionNumbers + dnums: slicing.GatherDimensionNumbers # Are the gather_indices known to be non-overlapping and/or sorted? # (In practice, these translate to "there no advanced indices", because @@ -708,113 +1293,47 @@ class _Indexer(NamedTuple): # for gathers before performing other index operations. scalar_bool_dims: Sequence[int] + # The expected sharding of the slice output. + slice_sharding: NamedSharding | None = None -def split_index_for_jit(idx, shape): - """Splits indices into necessarily-static and dynamic parts. - Used to pass indices into `jit`-ted function. - """ - # Convert list indices to tuples in cases (deprecated by NumPy.) - idx = eliminate_deprecated_list_indexing(idx) - if any(isinstance(i, str) for i in idx): - raise TypeError(f"JAX does not support string indexing; got {idx=}") - - # Expand any (concrete) boolean indices. We can then use advanced integer - # indexing logic to handle them. - idx = _expand_bool_indices(idx, shape) - - leaves, treedef = tree_flatten(idx) - dynamic = [None] * len(leaves) - static = [None] * len(leaves) - for i, x in enumerate(leaves): - if x is Ellipsis: - static[i] = x - elif isinstance(x, slice): - # slice objects aren't hashable. - static[i] = (x.start, x.stop, x.step) - else: - dynamic[i] = x - return treedef, tuple(static), dynamic - -def merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx): - """Recombines indices that were split by split_index_for_jit.""" - idx = [] - for s, d in zip(static_idx, dynamic_idx): - if d is not None: - idx.append(d) - elif isinstance(s, tuple): - idx.append(slice(s[0], s[1], s[2])) - else: - idx.append(s) - return treedef.unflatten(idx) - -def _int(aval): - return not aval.shape and dtypes.issubdtype(aval.dtype, np.integer) +def _index_to_gather(indexer: NDIndexer, *, x_sharding: NamedSharding | Any, + normalize_indices: bool = True) -> _GatherIndexer: + indexer.validate_slices() + indexer = indexer.convert_sequences_to_arrays() -def _aval_or_none(x): - try: - return core.get_aval(x) - except: - return None - -def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any], - normalize_indices: bool = True) -> _Indexer: - # Convert sequences to arrays - idx = tuple(lax_numpy.asarray(i, dtype=None if i else int) - if isinstance(i, Sequence) else i for i in idx) - abstract_idx = [_aval_or_none(i) for i in idx] - float_indices = [(i, val, aval) for i, (val, aval) in enumerate(zip(idx, abstract_idx)) - if aval is not None and dtypes.issubdtype(aval, np.inexact)] - - # Check for float or complex indices: - if float_indices: - i, val, aval = float_indices[0] - msg = ("Indexer must have integer or boolean type, got indexer " - "with type {} at position {}, indexer value {}") - raise TypeError(msg.format(aval.dtype.name, i, val)) - - # Check whether advanced indices are contiguous. We must do this before - # removing ellipses (https://github.com/jax-ml/jax/issues/25109) - # If advanced idexing axes do not appear contiguously, NumPy semantics - # move the advanced axes to the front. - is_advanced, = np.nonzero([isinstance(e, (int, np.integer, Array, np.ndarray)) - or lax_numpy.isscalar(e) for e in idx]) + is_advanced = np.nonzero([idx.typ in {IndexType.ARRAY, IndexType.INTEGER} for idx in indexer.indices]) advanced_axes_are_contiguous = np.all(np.diff(is_advanced) == 1) - # Remove ellipses and add trailing slice(None)s. - idx = _canonicalize_tuple_index(len(x_shape), idx) + indexer = indexer.expand_ellipses() + + scalar_bool_dims: Sequence[int] = [n for n, i in enumerate(indexer.indices) if i.typ == IndexType.BOOLEAN] + indexer, x_spec = indexer.expand_scalar_bool_indices(x_sharding.spec) - # Check for scalar boolean indexing: this requires inserting extra dimensions - # before performing the rest of the logic. - scalar_bool_dims: Sequence[int] = [n for n, i in enumerate(idx) if isinstance(i, bool)] - if scalar_bool_dims: - idx = tuple(np.arange(int(i)) if isinstance(i, bool) else i for i in idx) - x_shape = list(x_shape) - for i in sorted(scalar_bool_dims): - x_shape.insert(i, 1) - x_shape = tuple(x_shape) + if normalize_indices: + indexer = indexer.normalize_indices() # Check for advanced indexing: # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing - advanced_indexes: Sequence[Array | np.ndarray] | None = None + # The advanced indices. + advanced_indexes: Sequence[Array] = [] # The positions of the advanced indexing axes in `idx`. idx_advanced_axes: Sequence[int] = [] # The positions of the advanced indexes in x's shape. # collapsed, after None axes have been removed. See below. - x_advanced_axes: Sequence[int] | None = None + x_advanced_axes: Sequence[int] = [] - if _is_advanced_int_indexer(idx): - idx_no_nones = [(i, d) for i, d in enumerate(idx) if d is not None] + if indexer.is_advanced_int_indexer(): + idx_without_none = [(i, d) for i, d in enumerate(indexer.indices) if d.typ != IndexType.NONE] advanced_pairs = ( - (lax_numpy.asarray(e), i, j) for j, (i, e) in enumerate(idx_no_nones) - if lax_numpy.isscalar(e) or isinstance(e, (Sequence, Array, np.ndarray))) - if normalize_indices: - advanced_pairs = ((_normalize_index(e, x_shape[j]), i, j) - for e, i, j in advanced_pairs) - advanced_indexes, idx_advanced_axes, x_advanced_axes = zip(*advanced_pairs) + (lax_numpy.asarray(e.index), i, j) + for j, (i, e) in enumerate(idx_without_none) + if e.typ in [IndexType.ARRAY, IndexType.INTEGER] + ) + advanced_indexes, idx_advanced_axes, x_advanced_axes = unzip3(advanced_pairs) x_axis = 0 # Current axis in x. y_axis = 0 # Current axis in y, before collapsing. See below. @@ -825,10 +1344,7 @@ def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any], collapsed_slice_dims: list[int] = [] start_index_map: list[int] = [] - use_64bit_index = ( - any(not core.is_constant_dim(d) or d >= (1 << 31) for d in x_shape) and - config.enable_x64.value) - index_dtype = np.dtype('int64') if use_64bit_index else np.dtype('int32') + index_dtype = lax_utils.int_dtype_for_shape(indexer.shape, signed=True) # Gather indices. # Pairs of (array, start_dim) values. These will be broadcast into @@ -841,25 +1357,25 @@ def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any], # First, y is broadcast to slice_shape. In general `y` only need broadcast to # the right shape. slice_shape: list[int] = [] - # Next, y is squeezed to remove newaxis_dims. This removes np.newaxis/`None` # indices, which the scatter cannot remove itself. newaxis_dims: list[int] = [] - # Finally, we reverse reversed_y_dims to handle slices with negative strides. reversed_y_dims: list[int] = [] gather_slice_shape: list[int] = [] + slice_spec = [] - for idx_pos, i in enumerate(idx): + for idx_pos, index in enumerate(indexer.indices): # Handle the advanced indices here if: # * the advanced indices were not contiguous and we are the start. # * we are at the position of the first advanced index. - if (advanced_indexes is not None and + if (advanced_indexes and (advanced_axes_are_contiguous and idx_pos == idx_advanced_axes[0] or not advanced_axes_are_contiguous and idx_pos == 0)): advanced_index_arrs = util._broadcast_arrays(*advanced_indexes) shape = advanced_index_arrs[0].shape + aia_spec = core.typeof(advanced_index_arrs[0]).sharding.spec ndim = len(shape) start_dim = len(gather_indices_shape) @@ -873,6 +1389,7 @@ def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any], start_index_map.extend(x_advanced_axes) collapsed_slice_dims.extend(x_advanced_axes) slice_shape.extend(shape) + slice_spec.extend(aia_spec) y_axis += ndim collapsed_y_axis += ndim @@ -882,44 +1399,35 @@ def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any], gather_slice_shape.append(1) continue - # Handle basic int indexes. - abstract_i = _aval_or_none(i) - if isinstance(abstract_i, core.ShapedArray) and _int(abstract_i): - if core.definitely_equal(x_shape[x_axis], 0): + if index.typ in [IndexType.INTEGER, IndexType.ARRAY] and np.ndim(index.index) == 0: # type: ignore[arg-type] + # Basic scalar int indices + if core.definitely_equal(indexer.shape[x_axis], 0): # XLA gives error when indexing into an axis of size 0 raise IndexError(f"index is out of bounds for axis {x_axis} with size 0") - i = _normalize_index(i, x_shape[x_axis]) if normalize_indices else i - i_converted = lax.convert_element_type(i, index_dtype) + i_converted = lax.convert_element_type(index.index, index_dtype) # type: ignore[arg-type] gather_indices.append((i_converted, len(gather_indices_shape))) collapsed_slice_dims.append(x_axis) gather_slice_shape.append(1) start_index_map.append(x_axis) x_axis += 1 - # Handle np.newaxis (None) - elif i is None: + + elif index.typ == IndexType.NONE: + # None indexing: add a dimension. slice_shape.append(1) + slice_spec.append(None) newaxis_dims.append(y_axis) y_axis += 1 - elif isinstance(i, slice): - # Handle slice index (only static, otherwise an error is raised) - if not all(_is_slice_element_none_or_constant_or_symbolic(elt) - for elt in (i.start, i.stop, i.step)): - msg = ("Array slice indices must have static start/stop/step to be used " - "with NumPy indexing syntax. " - f"Found slice({i.start}, {i.stop}, {i.step}). " - "To index a statically sized " - "array at a dynamic position, try lax.dynamic_slice/" - "dynamic_update_slice (JAX does not support dynamically sized " - "arrays within JIT compiled functions).") - raise IndexError(msg) - - start, step, slice_size = core.canonicalize_slice(i, x_shape[x_axis]) + elif index.typ == IndexType.SLICE: + # Handle static slice index. + assert isinstance(index.index, slice) + start, step, slice_size = core.canonicalize_slice(index.index, indexer.shape[x_axis]) slice_shape.append(slice_size) + slice_spec.append(x_spec[x_axis]) if core.definitely_equal(step, 1): - # Avoid generating trivial gather (an optimization) - if not core.definitely_equal(slice_size, x_shape[x_axis]): + # Optimization: avoid generating trivial gather. + if not core.definitely_equal(slice_size, indexer.shape[x_axis]): gather_indices.append((lax.convert_element_type(start, index_dtype), len(gather_indices_shape))) start_index_map.append(x_axis) @@ -942,14 +1450,7 @@ def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any], y_axis += 1 x_axis += 1 else: - if (abstract_i is not None and - not (dtypes.issubdtype(abstract_i.dtype, np.integer) or dtypes.issubdtype(abstract_i.dtype, np.bool_))): - msg = ("Indexer must have integer or boolean type, got indexer " - "with type {} at position {}, indexer value {}") - raise TypeError(msg.format(abstract_i.dtype.name, idx_pos, i)) - - raise IndexError("Indexing mode not yet supported. Got unsupported indexer " - f"at position {idx_pos}: {i!r}") + raise IndexError(f"Got unsupported indexer at position {idx_pos}: {index!r}") if len(gather_indices) == 0: gather_indices_array: ArrayLike = np.zeros((0,), dtype=index_dtype) @@ -964,25 +1465,28 @@ def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any], for g, i in gather_indices], last_dim) - dnums = lax.GatherDimensionNumbers( + dnums = slicing.GatherDimensionNumbers( offset_dims = tuple(offset_dims), collapsed_slice_dims = tuple(sorted(collapsed_slice_dims)), start_index_map = tuple(start_index_map) ) - return _Indexer( + slice_sharding = x_sharding.update(spec=slice_spec) + return _GatherIndexer( slice_shape=slice_shape, newaxis_dims=tuple(newaxis_dims), gather_slice_shape=gather_slice_shape, reversed_y_dims=reversed_y_dims, dnums=dnums, gather_indices=gather_indices_array, - unique_indices=advanced_indexes is None, - indices_are_sorted=advanced_indexes is None, - scalar_bool_dims=scalar_bool_dims) + unique_indices=not advanced_indexes, + indices_are_sorted=not advanced_indexes, + scalar_bool_dims=scalar_bool_dims, + slice_sharding=slice_sharding) def _should_unpack_list_index(x): """Helper for eliminate_deprecated_list_indexing.""" - return (isinstance(x, (np.ndarray, Array)) and np.ndim(x) != 0 + return (isinstance(x, (np.ndarray, Array, literals.TypedNdArray)) + and np.ndim(x) != 0 or isinstance(x, (Sequence, slice)) or x is Ellipsis or x is None) @@ -991,7 +1495,9 @@ def eliminate_deprecated_list_indexing(idx): # non-tuple sequence containing slice objects, [Ellipses, or newaxis # objects]". Detects this and raises a TypeError. if not isinstance(idx, tuple): - if isinstance(idx, Sequence) and not isinstance(idx, (Array, np.ndarray, str)): + if isinstance(idx, Sequence) and not isinstance( + idx, (Array, np.ndarray, literals.TypedNdArray, str) + ): # As of numpy 1.16, some non-tuple sequences of indices result in a warning, while # others are converted to arrays, based on a set of somewhat convoluted heuristics # (See https://github.com/numpy/numpy/blob/v1.19.2/numpy/core/src/multiarray/mapping.c#L179-L343) @@ -1018,52 +1524,6 @@ def _is_boolean_index(i): or isinstance(i, list) and i and all(_is_scalar(e) and dtypes.issubdtype(dtypes.dtype(e), np.bool_) for e in i)) -def _expand_bool_indices(idx, shape): - """Converts concrete bool indexes into advanced integer indexes.""" - out = [] - total_dims = len(shape) - num_ellipsis = sum(e is Ellipsis for e in idx) - if num_ellipsis > 1: - raise IndexError("an index can only have a single ellipsis ('...')") - elif num_ellipsis == 1: - total_dims = sum(np.ndim(e) if _is_boolean_index(e) else 1 for e in idx - if e is not None and e is not Ellipsis) - ellipsis_offset = 0 - newaxis_offset = 0 - for dim_number, i in enumerate(idx): - try: - abstract_i = core.get_aval(i) - except TypeError: - abstract_i = None - if _is_boolean_index(i): - if isinstance(i, list): - i = lax_numpy.array(i) - abstract_i = core.get_aval(i) - - if not core.is_concrete(i): - # TODO(mattjj): improve this error by tracking _why_ the indices are not concrete - raise errors.NonConcreteBooleanIndexError(abstract_i) - elif np.ndim(i) == 0: - out.append(bool(i)) - else: - i_shape = np.shape(i) - start = len(out) + ellipsis_offset - newaxis_offset - expected_shape = shape[start: start + np.ndim(i)] - if len(i_shape) != len(expected_shape): - raise IndexError(f"too many boolean indices at index {dim_number}: got mask of shape " - f"{i_shape}, but only {len(expected_shape)} dimensions remain.") - if not all(s1 in (0, s2) for s1, s2 in zip(i_shape, expected_shape)): - raise IndexError("boolean index did not match shape of indexed array in index " - f"{dim_number}: got {i_shape}, expected {expected_shape}") - out.extend(np.where(i)) - else: - out.append(i) - if i is Ellipsis: - ellipsis_offset = len(shape) - total_dims - 1 - if i is None: - newaxis_offset += 1 - return tuple(out) - def _is_slice_element_none_or_constant_or_symbolic(elt): """Return True if elt is a constant or None.""" @@ -1074,27 +1534,12 @@ def _is_slice_element_none_or_constant_or_symbolic(elt): except TypeError: return False -# TODO(mattjj): clean up this logic -def _is_advanced_int_indexer(idx): - """Returns True if idx should trigger int array indexing, False otherwise.""" - # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing - assert isinstance(idx, tuple) - if all(e is None or e is Ellipsis or isinstance(e, slice) - or _is_scalar(e) and dtypes.issubdtype(dtypes.dtype(e), np.integer) for e in idx): - return False - return all(e is None or e is Ellipsis or isinstance(e, slice) - or _is_int_arraylike(e) for e in idx) - -def _is_int_arraylike(x): - """Returns True if x is array-like with integer dtype, False otherwise.""" - return (isinstance(x, int) and not isinstance(x, bool) - or dtypes.issubdtype(getattr(x, "dtype", None), np.integer) - or isinstance(x, (list, tuple)) and all(_is_int_arraylike(e) for e in x)) - def _is_scalar(x): """Checks if a Python or NumPy scalar.""" - return np.isscalar(x) or (isinstance(x, (np.ndarray, Array)) - and np.ndim(x) == 0) + return np.isscalar(x) or ( + isinstance(x, (np.ndarray, literals.TypedNdArray, Array)) + and np.ndim(x) == 0 + ) def _canonicalize_tuple_index(arr_ndim, idx): """Helper to remove Ellipsis and add in the implicit trailing slice(None).""" @@ -1259,16 +1704,16 @@ def put(a: ArrayLike, ind: ArrayLike, v: ArrayLike, [ 0, 0, 20, 0, 0], [ 0, 0, 0, 0, 30]], dtype=int32) """ + if inplace: + raise ValueError( + "jax.numpy.put cannot modify arrays in-place, because JAX arrays are immutable. " + "Pass inplace=False to instead return an updated array.") arr, ind_arr, _ = util.ensure_arraylike("put", a, ind, v) ind_arr = ind_arr.ravel() v_arr = lax_numpy.ravel(v) if not arr.size or not ind_arr.size or not v_arr.size: return arr v_arr = lax_numpy._tile_to_size(v_arr, len(ind_arr)) - if inplace: - raise ValueError( - "jax.numpy.put cannot modify arrays in-place, because JAX arrays are immutable. " - "Pass inplace=False to instead return an updated array.") if mode is None: scatter_mode = "drop" elif mode == "clip": diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 96efc48062e1..42b9e7a151b1 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -28,29 +28,28 @@ import builtins from collections.abc import Callable, Sequence from functools import partial -import importlib import math import operator import os -from typing import (Any, IO, Literal, Protocol, TypeVar, Union, overload) -import warnings +from typing import Any, IO, Literal, Protocol, TypeVar, Union, overload -import jax -from jax import jit -from jax import lax -from jax._src import config +import numpy as np + +from jax._src import api from jax._src import core from jax._src import deprecations from jax._src import dtypes -from jax._src import xla_bridge from jax._src.api_util import _ensure_index_tuple from jax._src.custom_derivatives import custom_jvp -from jax._src.lax import lax as lax_internal -from jax._src.lax.lax import (PrecisionLike,_array_copy, - _sort_le_comparator, _sort_lt_comparator) +from jax._src.lax import control_flow +from jax._src.lax import convolution as lax_conv +from jax._src.lax import lax +from jax._src.lax import slicing as lax_slicing +from jax._src.lax import special as lax_special +from jax._src.lax import utils as lax_utils from jax._src.lib import xla_client as xc -from jax._src.numpy.array_creation import (empty, empty_like, full, - ones, ones_like, zeros, zeros_like) +from jax._src.numpy.array_constructors import array, asarray +from jax._src.numpy import array_creation from jax._src.numpy import indexing from jax._src.numpy import reductions from jax._src.numpy import tensor_contractions @@ -58,29 +57,22 @@ from jax._src.numpy import util from jax._src.numpy.sorting import argsort, sort from jax._src.numpy.vectorize import vectorize +from jax._src.sharding_impls import canonicalize_sharding from jax._src.typing import ( - Array, ArrayLike, DType, DTypeLike, DeprecatedArg, DimSize, Shape + Array, ArrayLike, DType, DTypeLike, DeprecatedArg, DimSize, Shape, SupportsShape ) from jax._src.util import ( - NumpyComplexWarning, canonicalize_axis as _canonicalize_axis, + canonicalize_axis as _canonicalize_axis, + canonicalize_axis_tuple as _canonicalize_axis_tuple, ceil_of_ratio, safe_zip, set_module, unzip2) -from jax.sharding import Sharding -from jax._src.sharding_impls import SingleDeviceSharding -from jax.tree_util import tree_leaves, tree_map -import numpy as np +from jax._src.sharding import Sharding +from jax._src.sharding_impls import NamedSharding, PartitionSpec as P +from jax._src.mesh import get_abstract_mesh +from jax._src.pjit import auto_axes +from jax._src.tree_util import tree_map export = set_module('jax.numpy') -for pkg_name in ['jax_cuda12_plugin', 'jax.jaxlib.cuda']: - try: - cuda_plugin_extension = importlib.import_module( - f'{pkg_name}.cuda_plugin_extension' - ) - except ImportError: - cuda_plugin_extension = None # type: ignore - else: - break - T = TypeVar('T') # Wrappers for NumPy printoptions @@ -150,8 +142,12 @@ def iscomplexobj(x: Any) -> bool: >>> jnp.iscomplexobj(jnp.array([0, 1+2j])) True """ - if x is None: + # Fast path for common types. + if isinstance(x, (complex, np.complexfloating)): + return True + if x is None or isinstance(x, (bool, int, float, str, np.generic)): return False + # Fall back to dtype attribute lookup. try: typ = x.dtype.type except AttributeError: @@ -159,8 +155,7 @@ def iscomplexobj(x: Any) -> bool: return issubdtype(typ, np.complexfloating) -def _dtype(x: Any) -> DType: - return dtypes.dtype(x, canonicalize=True) +_dtype = dtypes.dtype # Dtype-related functions iinfo = dtypes.iinfo @@ -169,50 +164,7 @@ def _dtype(x: Any) -> DType: can_cast = dtypes.can_cast promote_types = dtypes.promote_types -ComplexWarning = NumpyComplexWarning - -_lax_const = lax_internal._const - - -def _convert_and_clip_integer(val: ArrayLike, dtype: DType) -> Array: - """ - Convert integer-typed val to specified integer dtype, clipping to dtype - range rather than wrapping. - - Args: - val: value to be converted - dtype: dtype of output - - Returns: - equivalent of val in new dtype - - Examples - -------- - Normal integer type conversion will wrap: - - >>> val = jnp.uint32(0xFFFFFFFF) - >>> val.astype('int32') - Array(-1, dtype=int32) - - This function clips to the values representable in the new type: - - >>> _convert_and_clip_integer(val, 'int32') - Array(2147483647, dtype=int32) - """ - val = val if isinstance(val, Array) else asarray(val) - dtype = dtypes.canonicalize_dtype(dtype) - if not (issubdtype(dtype, np.integer) and issubdtype(val.dtype, np.integer)): - raise TypeError("_convert_and_clip_integer only accepts integer dtypes.") - - val_dtype = dtypes.canonicalize_dtype(val.dtype) - if val_dtype != val.dtype: - # TODO(jakevdp): this is a weird corner case; need to figure out how to handle it. - # This happens in X32 mode and can either come from a jax value created in another - # context, or a Python integer converted to int64. - pass - min_val = _lax_const(val, max(iinfo(dtype).min, iinfo(val_dtype).min)) - max_val = _lax_const(val, min(iinfo(dtype).max, iinfo(val_dtype).max)) - return clip(val, min_val, max_val).astype(dtype) +ComplexWarning = np.exceptions.ComplexWarning @export @@ -266,11 +218,11 @@ def load(file: IO[bytes] | str | os.PathLike[Any], *args: Any, **kwargs: Any) -> ### implementations of numpy functions in terms of lax @export -@jit +@api.jit def fmin(x1: ArrayLike, x2: ArrayLike) -> Array: """Return element-wise minimum of the input arrays. - JAX implemtentation of :func:`numpy.fmin`. + JAX implementation of :func:`numpy.fmin`. Args: x1: input array or scalar. @@ -318,7 +270,7 @@ def fmin(x1: ArrayLike, x2: ArrayLike) -> Array: @export -@jit +@api.jit def fmax(x1: ArrayLike, x2: ArrayLike) -> Array: """Return element-wise maximum of the input arrays. @@ -499,14 +451,14 @@ def isscalar(element: Any) -> bool: False >>> jnp.isscalar([1]) False - >>> jnp.isscalar(tuple()) + >>> jnp.isscalar(()) False >>> jnp.isscalar(slice(10)) False """ if np.isscalar(element): return True - elif isinstance(element, (np.ndarray, jax.Array)): + elif isinstance(element, (np.ndarray, Array)): return element.ndim == 0 elif hasattr(element, '__jax_array__'): return asarray(element).ndim == 0 @@ -547,18 +499,18 @@ def result_type(*args: Any) -> DType: of the ``jax_enable_x64`` configuration flag, meaning that 64-bit types may be downcast to 32-bit: - >>> jnp.result_type('float64') + >>> jnp.result_type('float64') # doctest: +SKIP dtype('float32') For details on 64-bit values, refer to `Sharp bits - double precision`_: - .. _Sharp bits - double precision: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision + .. _Sharp bits - double precision: https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision """ return dtypes.result_type(*args) @export -@jit +@api.jit def trunc(x: ArrayLike) -> Array: """Round input to the nearest integer towards zero. @@ -589,13 +541,13 @@ def trunc(x: ArrayLike) -> Array: [-8., 5., 3.]], dtype=float32) """ x = util.ensure_arraylike('trunc', x) - if dtypes.isdtype(dtypes.dtype(x), ('integral', 'bool')): + if dtypes.isdtype(x.dtype, ('integral', 'bool')): return x - return where(lax.lt(x, _lax_const(x, 0)), ufuncs.ceil(x), ufuncs.floor(x)) + return where(lax.lt(x, lax._const(x, 0)), ufuncs.ceil(x), ufuncs.floor(x)) -@partial(jit, static_argnames=['mode', 'op', 'precision', 'preferred_element_type']) -def _conv(x: Array, y: Array, mode: str, op: str, precision: PrecisionLike, +@api.jit(static_argnames=['mode', 'op', 'precision', 'preferred_element_type']) +def _conv(x: Array, y: Array, mode: str, op: str, precision: lax.PrecisionLike, preferred_element_type: DTypeLike | None = None) -> Array: if np.ndim(x) != 1 or np.ndim(y) != 1: raise ValueError(f"{op}() only support 1-dimensional inputs.") @@ -628,16 +580,16 @@ def _conv(x: Array, y: Array, mode: str, op: str, precision: PrecisionLike, else: raise ValueError("mode must be one of ['full', 'same', 'valid']") - result = lax.conv_general_dilated(x[None, None, :], y[None, None, :], (1,), - padding, precision=precision, - preferred_element_type=preferred_element_type) + result = lax_conv.conv_general_dilated(x[None, None, :], y[None, None, :], (1,), + padding, precision=precision, + preferred_element_type=preferred_element_type) return result[0, 0, out_order] @export -@partial(jit, static_argnames=('mode', 'precision', 'preferred_element_type')) +@api.jit(static_argnames=('mode', 'precision', 'preferred_element_type')) def convolve(a: ArrayLike, v: ArrayLike, mode: str = 'full', *, - precision: PrecisionLike = None, + precision: lax.PrecisionLike = None, preferred_element_type: DTypeLike | None = None) -> Array: r"""Convolution of two one dimensional arrays. @@ -711,9 +663,9 @@ def convolve(a: ArrayLike, v: ArrayLike, mode: str = 'full', *, @export -@partial(jit, static_argnames=('mode', 'precision', 'preferred_element_type')) +@api.jit(static_argnames=('mode', 'precision', 'preferred_element_type')) def correlate(a: ArrayLike, v: ArrayLike, mode: str = 'valid', *, - precision: PrecisionLike = None, + precision: lax.PrecisionLike = None, preferred_element_type: DTypeLike | None = None) -> Array: r"""Correlation of two one dimensional arrays. @@ -845,7 +797,7 @@ def histogram_bin_edges(a: ArrayLike, bins: ArrayLike = 10, range = (where(reductions.ptp(range) == 0, range[0] - 0.5, range[0]), where(reductions.ptp(range) == 0, range[1] + 0.5, range[1])) assert range is not None - return linspace(range[0], range[1], bins_int + 1, dtype=dtype) + return array_creation.linspace(range[0], range[1], bins_int + 1, dtype=dtype) @export @@ -911,11 +863,11 @@ def histogram(a: ArrayLike, bins: ArrayLike = 10, Array(True, dtype=bool) """ if weights is None: - util.check_arraylike("histogram", a, bins) + a, _ = util.ensure_arraylike("histogram", a, bins) a, = util.promote_dtypes_inexact(a) - weights = ones_like(a) + weights = array_creation.ones_like(a) else: - util.check_arraylike("histogram", a, bins, weights) + a, _, weights = util.ensure_arraylike("histogram", a, bins, weights) if np.shape(a) != np.shape(weights): raise ValueError("weights should have the same shape as a.") a, weights = util.promote_dtypes_inexact(a, weights) @@ -923,7 +875,7 @@ def histogram(a: ArrayLike, bins: ArrayLike = 10, bin_edges = histogram_bin_edges(a, bins, range, weights) bin_idx = searchsorted(bin_edges, a, side='right') bin_idx = where(a == bin_edges[-1], len(bin_edges) - 1, bin_idx) - counts = zeros(len(bin_edges), weights.dtype).at[bin_idx].add(weights)[1:] + counts = array_creation.zeros(len(bin_edges), weights.dtype).at[bin_idx].add(weights)[1:] if density: bin_widths = diff(bin_edges) counts = counts / bin_widths / counts.sum() @@ -1005,7 +957,7 @@ def histogram2d(x: ArrayLike, y: ArrayLike, bins: ArrayLike | list[ArrayLike] = >>> jnp.allclose(normed_sum, 1.0) Array(True, dtype=bool) """ - util.check_arraylike("histogram2d", x, y) + x, y = util.ensure_arraylike("histogram2d", x, y) try: N = len(bins) # type: ignore[arg-type] except TypeError: @@ -1077,10 +1029,10 @@ def histogramdd(sample: ArrayLike, bins: ArrayLike | list[ArrayLike] = 10, Array(True, dtype=bool) """ if weights is None: - util.check_arraylike("histogramdd", sample) + sample = util.ensure_arraylike("histogramdd", sample) sample, = util.promote_dtypes_inexact(sample) else: - util.check_arraylike("histogramdd", sample, weights) + sample, weights = util.ensure_arraylike("histogramdd", sample, weights) if np.shape(weights) != np.shape(sample)[:1]: raise ValueError("should have one weight for each sample.") sample, weights = util.promote_dtypes_inexact(sample, weights) @@ -1203,8 +1155,8 @@ def transpose(a: ArrayLike, axes: Sequence[int] | None = None) -> Array: Array([[1, 3], [2, 4]], dtype=int32) """ - util.check_arraylike("transpose", a) - axes_ = list(range(np.ndim(a))[::-1]) if axes is None else axes + a = util.ensure_arraylike("transpose", a) + axes_ = list(range(a.ndim)[::-1]) if axes is None else axes axes_ = [_canonicalize_axis(i, np.ndim(a)) for i in axes_] return lax.transpose(a, axes_) @@ -1235,7 +1187,7 @@ def permute_dims(a: ArrayLike, /, axes: tuple[int, ...]) -> Array: [2, 5], [3, 6]], dtype=int32) """ - util.check_arraylike("permute_dims", a) + a = util.ensure_arraylike("permute_dims", a) return lax.transpose(a, axes) @@ -1285,8 +1237,8 @@ def matrix_transpose(x: ArrayLike, /) -> Array: [[5, 7], [6, 8]]], dtype=int32) """ - util.check_arraylike("matrix_transpose", x) - ndim = np.ndim(x) + x = util.ensure_arraylike("matrix_transpose", x) + ndim = x.ndim if ndim < 2: raise ValueError(f"x must be at least two-dimensional for matrix_transpose; got {ndim=}") axes = (*range(ndim - 2), ndim - 1, ndim - 2) @@ -1294,7 +1246,7 @@ def matrix_transpose(x: ArrayLike, /) -> Array: @export -@partial(jit, static_argnames=('k', 'axes')) +@api.jit(static_argnames=('k', 'axes')) def rot90(m: ArrayLike, k: int = 1, axes: tuple[int, int] = (0, 1)) -> Array: """Rotate an array by 90 degrees counterclockwise in the plane specified by axes. @@ -1353,7 +1305,7 @@ def rot90(m: ArrayLike, k: int = 1, axes: tuple[int, int] = (0, 1)) -> Array: [11, 8], [12, 9]]], dtype=int32) """ - util.check_arraylike("rot90", m) + m = util.ensure_arraylike("rot90", m) if np.ndim(m) < 2: raise ValueError("rot90 requires its first argument to have ndim at least " f"two, but got first argument of shape {np.shape(m)}, " @@ -1437,7 +1389,7 @@ def flip(m: ArrayLike, axis: int | Sequence[int] | None = None) -> Array: arr = util.ensure_arraylike("flip", m) return _flip(arr, reductions._ensure_optional_axes(axis)) -@partial(jit, static_argnames=('axis',)) +@api.jit(static_argnames=('axis',)) def _flip(m: Array, axis: int | tuple[int, ...] | None = None) -> Array: if axis is None: return lax.rev(m, list(range(len(np.shape(m))))) @@ -1500,7 +1452,7 @@ def flipud(m: ArrayLike) -> Array: @export -@jit +@api.jit def iscomplex(x: ArrayLike) -> Array: """Return boolean array showing where the input is complex. @@ -1521,11 +1473,11 @@ def iscomplex(x: ArrayLike) -> Array: Array([False, False, False, True, True], dtype=bool) """ i = ufuncs.imag(x) - return lax.ne(i, _lax_const(i, 0)) + return lax.ne(i, lax._const(i, 0)) @export -@jit +@api.jit def isreal(x: ArrayLike) -> Array: """Return boolean array showing where the input is real. @@ -1546,11 +1498,11 @@ def isreal(x: ArrayLike) -> Array: Array([ True, True, True, True, False], dtype=bool) """ i = ufuncs.imag(x) - return lax.eq(i, _lax_const(i, 0)) + return lax.eq(i, lax._const(i, 0)) @export -@partial(jit, static_argnames=['deg']) +@api.jit(static_argnames=['deg']) def angle(z: ArrayLike, deg: bool = False) -> Array: """Return the angle of a complex valued number or array. @@ -1589,12 +1541,13 @@ def angle(z: ArrayLike, deg: bool = False) -> Array: [[ 71.57 -68.2 ] [-36.87 33.69]] """ + z = util.ensure_arraylike('angle', z) re = ufuncs.real(z) im = ufuncs.imag(z) - dtype = _dtype(re) + dtype = re.dtype if not issubdtype(dtype, np.inexact) or ( - issubdtype(_dtype(z), np.floating) and np.ndim(z) == 0): - dtype = dtypes.canonicalize_dtype(dtypes.float_) + issubdtype(z.dtype, np.floating) and np.ndim(z) == 0): + dtype = dtypes.default_float_dtype() re = lax.convert_element_type(re, dtype) im = lax.convert_element_type(im, dtype) result = lax.atan2(im, re) @@ -1602,7 +1555,7 @@ def angle(z: ArrayLike, deg: bool = False) -> Array: @export -@partial(jit, static_argnames=('n', 'axis')) +@api.jit(static_argnames=('n', 'axis')) def diff(a: ArrayLike, n: int = 1, axis: int = -1, prepend: ArrayLike | None = None, append: ArrayLike | None = None) -> Array: @@ -1714,7 +1667,7 @@ def diff(a: ArrayLike, n: int = 1, axis: int = -1, @export -@jit +@api.jit def ediff1d(ary: ArrayLike, to_end: ArrayLike | None = None, to_begin: ArrayLike | None = None) -> Array: """Compute the differences of the elements of the flattened array. @@ -1777,7 +1730,7 @@ def ediff1d(ary: ArrayLike, to_end: ArrayLike | None = None, @export -@partial(jit, static_argnames=("axis", "edge_order")) +@api.jit(static_argnames=("axis", "edge_order")) def gradient( f: ArrayLike, *varargs: ArrayLike, @@ -1850,7 +1803,7 @@ def gradient( a, *spacing = util.promote_dtypes_inexact(f, *varargs) def gradient_along_axis(a, h, axis): - sliced = partial(lax.slice_in_dim, a, axis=axis) + sliced = partial(lax_slicing.slice_in_dim, a, axis=axis) upper_edge = sliced(1, 2) - sliced(0, 1) lower_edge = sliced(-1, None) - sliced(-2, -1) @@ -1868,7 +1821,7 @@ def gradient_along_axis(a, h, axis): h_shape = [1] * a.ndim h_shape[axis] = len(h) h = h.reshape(h_shape) - sliced_x = partial(lax.slice_in_dim, h, axis=axis) + sliced_x = partial(lax_slicing.slice_in_dim, h, axis=axis) upper_edge /= sliced_x(1, 2) - sliced_x(0, 1) lower_edge /= sliced_x(-1, None) - sliced_x(-2, -1) @@ -1944,9 +1897,8 @@ def isrealobj(x: Any) -> bool: @export def reshape( - a: ArrayLike, shape: DimSize | Shape | None = None, order: str = "C", *, - newshape: DimSize | Shape | DeprecatedArg = DeprecatedArg(), - copy: bool | None = None) -> Array: + a: ArrayLike, shape: DimSize | Shape, order: str = "C", *, + copy: bool | None = None, out_sharding=None) -> Array: """Return a reshaped copy of an array. JAX implementation of :func:`numpy.reshape`, implemented in terms of @@ -1962,8 +1914,6 @@ def reshape( JAX does not support ``order="A"``. copy: unused by JAX; JAX always returns a copy, though under JIT the compiler may optimize such copies away. - newshape: deprecated alias of the ``shape`` argument. Will result in a - :class:`DeprecationWarning` if used. Returns: reshaped copy of input array with the specified shape. @@ -2021,25 +1971,18 @@ def reshape( __tracebackhide__ = True util.check_arraylike("reshape", a) - # TODO(jakevdp): finalized 2024-12-2; remove argument after JAX v0.4.40. - if not isinstance(newshape, DeprecatedArg): - raise TypeError("The newshape argument to jnp.reshape was removed in JAX v0.4.36." - " Use shape instead.") - if shape is None: - raise TypeError( - "jnp.shape requires passing a `shape` argument, but none was given." - ) try: - # forward to method for ndarrays - return a.reshape(shape, order=order) # type: ignore[call-overload,union-attr] + if out_sharding is None: + # forward to method for ndarrays + return a.reshape(shape, order=order) # type: ignore[call-overload,union-attr] except AttributeError: pass - return asarray(a).reshape(shape, order=order) + return asarray(a).reshape(shape, order=order, out_sharding=out_sharding) @export -@partial(jit, static_argnames=('order',), inline=True) -def ravel(a: ArrayLike, order: str = "C") -> Array: +@api.jit(static_argnames=('order', 'out_sharding'), inline=True) +def ravel(a: ArrayLike, order: str = "C", *, out_sharding=None) -> Array: """Flatten array into a 1-dimensional shape. JAX implementation of :func:`numpy.ravel`, implemented in terms of @@ -2085,10 +2028,10 @@ def ravel(a: ArrayLike, order: str = "C") -> Array: >>> x.ravel() Array([1, 2, 3, 4, 5, 6], dtype=int32) """ - util.check_arraylike("ravel", a) + a = util.ensure_arraylike("ravel", a) if order == "K": raise NotImplementedError("Ravel not implemented for order='K'.") - return reshape(a, (np.size(a),), order) + return reshape(a, (np.size(a),), order, out_sharding=out_sharding) @export @@ -2150,8 +2093,7 @@ def ravel_multi_index(multi_index: Sequence[ArrayLike], dims: Sequence[int], """ assert len(multi_index) == len(dims), f"len(multi_index)={len(multi_index)} != len(dims)={len(dims)}" dims = tuple(core.concrete_or_error(operator.index, d, "in `dims` argument of ravel_multi_index().") for d in dims) - util.check_arraylike("ravel_multi_index", *multi_index) - multi_index_arr = [asarray(i) for i in multi_index] + multi_index_arr = list(util.ensure_arraylike_tuple("ravel_multi_index", multi_index)) for index in multi_index_arr: if mode == 'raise': core.concrete_or_error(array, index, @@ -2176,8 +2118,7 @@ def ravel_multi_index(multi_index: Sequence[ArrayLike], dims: Sequence[int], else: raise ValueError(f"invalid order={order!r}. Expected 'C' or 'F'") - result = array(0, dtype=(multi_index_arr[0].dtype if multi_index_arr - else dtypes.canonicalize_dtype(dtypes.int_))) + result = array(0, dtype=multi_index_arr[0].dtype if multi_index_arr else int) for i, s in zip(multi_index_arr, strides): result = result + i * int(s) return result @@ -2241,13 +2182,17 @@ def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]: for i, s in reversed(list(enumerate(shape))): indices_arr, out_indices[i] = ufuncs.divmod(indices_arr, s) oob_pos = indices_arr > 0 - oob_neg = indices_arr < -1 + if dtypes.issubdtype(indices_arr.dtype, np.unsignedinteger): + # Unsigned integers can't be out of bounds at the low end. + oob_neg = asarray(False) + else: + oob_neg = indices_arr < -1 return tuple(where(oob_pos, s - 1, where(oob_neg, 0, i)) for s, i in safe_zip(shape, out_indices)) @export -@partial(jit, static_argnames=('new_shape',)) +@api.jit(static_argnames=('new_shape',)) def resize(a: ArrayLike, new_shape: Shape) -> Array: """Return a new array with specified shape. @@ -2259,7 +2204,7 @@ def resize(a: ArrayLike, new_shape: Shape) -> Array: Returns: A resized array with specified shape. The elements of ``a`` are repeated in - the resized array, if the resized array is larger than the original aray. + the resized array, if the resized array is larger than the original array. See also: - :func:`jax.numpy.reshape`: Returns a reshaped copy of an array. @@ -2290,7 +2235,7 @@ def resize(a: ArrayLike, new_shape: Shape) -> Array: new_size = math.prod(new_shape) if arr.size == 0 or new_size == 0: - return zeros_like(arr, shape=new_shape) + return array_creation.zeros_like(arr, shape=new_shape) repeats = ceil_of_ratio(new_size, arr.size) arr = tile(arr, repeats)[:new_size] @@ -2358,7 +2303,7 @@ def squeeze(a: ArrayLike, axis: int | Sequence[int] | None = None) -> Array: arr = util.ensure_arraylike("squeeze", a) return _squeeze(arr, _ensure_index_tuple(axis) if axis is not None else None) -@partial(jit, static_argnames=('axis',), inline=True) +@api.jit(static_argnames=('axis',), inline=True) def _squeeze(a: Array, axis: tuple[int, ...]) -> Array: if axis is None: a_shape = np.shape(a) @@ -2435,13 +2380,13 @@ def expand_dims(a: ArrayLike, axis: int | Sequence[int]) -> Array: [2], [3]]]], dtype=int32) """ - util.check_arraylike("expand_dims", a) + a = util.ensure_arraylike("expand_dims", a) axis = _ensure_index_tuple(axis) return lax.expand_dims(a, axis) @export -@partial(jit, static_argnames=('axis1', 'axis2'), inline=True) +@api.jit(static_argnames=('axis1', 'axis2'), inline=True) def swapaxes(a: ArrayLike, axis1: int, axis2: int) -> Array: """Swap two axes of an array. @@ -2482,7 +2427,7 @@ def swapaxes(a: ArrayLike, axis1: int, axis2: int) -> Array: >>> a.transpose(0, 3, 2, 1).shape (2, 5, 4, 3) """ - util.check_arraylike("swapaxes", a) + a = util.ensure_arraylike("swapaxes", a) perm = np.arange(np.ndim(a)) perm[axis1], perm[axis2] = perm[axis2], perm[axis1] return lax.transpose(a, list(perm)) @@ -2541,7 +2486,7 @@ def moveaxis(a: ArrayLike, source: int | Sequence[int], return _moveaxis(arr, _ensure_index_tuple(source), _ensure_index_tuple(destination)) -@partial(jit, static_argnames=('source', 'destination'), inline=True) +@api.jit(static_argnames=('source', 'destination'), inline=True) def _moveaxis(a: Array, source: tuple[int, ...], destination: tuple[int, ...]) -> Array: source = tuple(_canonicalize_axis(i, np.ndim(a)) for i in source) destination = tuple(_canonicalize_axis(i, np.ndim(a)) for i in destination) @@ -2555,7 +2500,7 @@ def _moveaxis(a: Array, source: tuple[int, ...], destination: tuple[int, ...]) - @export -@partial(jit, static_argnames=('equal_nan',)) +@api.jit(static_argnames=('equal_nan',)) def isclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05, atol: ArrayLike = 1e-08, equal_nan: bool = False) -> Array: r"""Check if the elements of two arrays are approximately equal within a tolerance. @@ -2600,48 +2545,30 @@ def isclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05, atol: ArrayLike Array([ True, True, True], dtype=bool) """ a, b = util.promote_args("isclose", a, b) - dtype = _dtype(a) + dtype = a.dtype if dtypes.issubdtype(dtype, dtypes.extended): return lax.eq(a, b) a, b = util.promote_args_inexact("isclose", a, b) - dtype = _dtype(a) + dtype = a.dtype if issubdtype(dtype, np.complexfloating): - dtype = util._complex_elem_type(dtype) + dtype = np.array(0, dtype).real.dtype rtol = lax.convert_element_type(rtol, dtype) atol = lax.convert_element_type(atol, dtype) - out = lax.le( + both_nan = ufuncs.logical_and(ufuncs.isnan(a), ufuncs.isnan(b)) + check_fin = ufuncs.isfinite(b) + in_range = lax.le( lax.abs(lax.sub(a, b)), lax.add(atol, lax.mul(rtol, lax.abs(b)))) - # This corrects the comparisons for infinite and nan values - a_inf = ufuncs.isinf(a) - b_inf = ufuncs.isinf(b) - any_inf = ufuncs.logical_or(a_inf, b_inf) - both_inf = ufuncs.logical_and(a_inf, b_inf) - # Make all elements where either a or b are infinite to False - out = ufuncs.logical_and(out, ufuncs.logical_not(any_inf)) - # Make all elements where both a or b are the same inf to True - same_value = lax.eq(a, b) - same_inf = ufuncs.logical_and(both_inf, same_value) - out = ufuncs.logical_or(out, same_inf) - - # Make all elements where either a or b is NaN to False - a_nan = ufuncs.isnan(a) - b_nan = ufuncs.isnan(b) - any_nan = ufuncs.logical_or(a_nan, b_nan) - out = ufuncs.logical_and(out, ufuncs.logical_not(any_nan)) - if equal_nan: - # Make all elements where both a and b is NaN to True - both_nan = ufuncs.logical_and(a_nan, b_nan) - out = ufuncs.logical_or(out, both_nan) - return out + out = ufuncs.logical_or(lax.eq(a, b), ufuncs.logical_and(check_fin, in_range)) + return ufuncs.logical_or(out, both_nan) if equal_nan else out def _interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike, left: ArrayLike | str | None = None, right: ArrayLike | str | None = None, period: ArrayLike | None = None) -> Array: - util.check_arraylike("interp", x, xp, fp) + x, xp, fp = util.ensure_arraylike("interp", x, xp, fp) if np.shape(xp) != np.shape(fp) or np.ndim(xp) != 1: raise ValueError("xp and fp must be one-dimensional arrays of equal size") x_arr, xp_arr = util.promote_dtypes_inexact(x, xp) @@ -2759,7 +2686,7 @@ def interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike, static_argnames.append('right') if period is None: static_argnames.append('period') - jitted_interp = jit(_interp, static_argnames=static_argnames) + jitted_interp = api.jit(_interp, static_argnames=static_argnames) return jitted_interp(x, xp, fp, left, right, period) @@ -2825,7 +2752,7 @@ def where(condition, x=None, y=None, /, *, size=None, fill_value=None): (reverse-mode differentiation), a NaN in either ``x`` or ``y`` will propagate into the gradient, regardless of the value of ``condition``. More information on this behavior and workarounds is available in the `JAX FAQ - `_. + `_. Examples: When ``x`` and ``y`` are not provided, ``where`` behaves equivalently to @@ -2918,6 +2845,12 @@ def select( raise ValueError(msg.format(len(condlist), len(choicelist))) if len(condlist) == 0: raise ValueError("condlist must be non-empty") + + util.check_arraylike("select", *condlist, *choicelist, default) + condlist = [asarray(cond) for cond in condlist] + choicelist = [asarray(choice) for choice in choicelist] + default = asarray(default) + # Put the default at front with condition False because # argmax returns zero for an array of False values. choicelist = util.promote_dtypes(default, *choicelist) @@ -2934,7 +2867,7 @@ def bincount(x: ArrayLike, weights: ArrayLike | None = None, JAX implementation of :func:`numpy.bincount`. - For an array of positive integers ``x``, this function returns an array ``counts`` + For an array of non-negative integers ``x``, this function returns an array ``counts`` of size ``x.max() + 1``, such that ``counts[i]`` contains the number of occurrences of the value ``i`` in ``x``. @@ -2947,7 +2880,7 @@ def bincount(x: ArrayLike, weights: ArrayLike | None = None, like :func:`jax.jit`. In this case, items larger than `length + 1` will be dropped. Args: - x : N-dimensional array of positive integers + x : 1-dimensional array of non-negative integers weights: optional array of weights associated with ``x``. If not specified, the weight for each entry will be ``1``. minlength: the minimum length of the output counts array. @@ -2989,11 +2922,11 @@ def bincount(x: ArrayLike, weights: ArrayLike | None = None, >>> jnp.bincount(x, length=5) Array([2, 1, 0, 1, 0], dtype=int32) """ - util.check_arraylike("bincount", x) - if _dtype(x) == bool: + x = util.ensure_arraylike("bincount", x) + if x.dtype == bool: x = lax.convert_element_type(x, 'int32') - if not issubdtype(_dtype(x), np.integer): - raise TypeError(f"x argument to bincount must have an integer type; got {_dtype(x)}") + if not issubdtype(x.dtype, np.integer): + raise TypeError(f"x argument to bincount must have an integer type; got {x.dtype}") if np.ndim(x) != 1: raise ValueError("only 1-dimensional input supported.") minlength = core.concrete_or_error(operator.index, minlength, @@ -3010,7 +2943,7 @@ def bincount(x: ArrayLike, weights: ArrayLike | None = None, weights = np.array(1, dtype=dtypes.int_) elif np.shape(x) != np.shape(weights): raise ValueError("shape of weights must match shape of x.") - return zeros(length, _dtype(weights)).at[clip(x, 0)].add(weights, mode='drop') + return array_creation.zeros(length, _dtype(weights)).at[clip(x, 0)].add(weights, mode='drop') @overload def broadcast_shapes(*shapes: Sequence[int]) -> tuple[int, ...]: ... @@ -3097,11 +3030,13 @@ def broadcast_arrays(*args: ArrayLike) -> list[Array]: .. _NumPy broadcasting: https://numpy.org/doc/stable/user/basics.broadcasting.html """ + args = util.ensure_arraylike_tuple("broadcast_arrays", args) return util._broadcast_arrays(*args) @export -def broadcast_to(array: ArrayLike, shape: DimSize | Shape) -> Array: +def broadcast_to(array: ArrayLike, shape: DimSize | Shape, + *, out_sharding: NamedSharding | P | None = None) -> Array: """Broadcast an array to a specified shape. JAX implementation of :func:`numpy.broadcast_to`. JAX uses NumPy-style @@ -3135,7 +3070,7 @@ def broadcast_to(array: ArrayLike, shape: DimSize | Shape) -> Array: .. _NumPy broadcasting: https://numpy.org/doc/stable/user/basics.broadcasting.html """ - return util._broadcast_to(array, shape) + return util._broadcast_to(array, shape, sharding=out_sharding) def _split(op: str, ary: ArrayLike, @@ -3377,7 +3312,7 @@ def array_split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | Array @export -@jit +@api.jit def clip( arr: ArrayLike | None = None, /, @@ -3410,6 +3345,7 @@ def clip( Returns: An array containing values from ``arr``, with values smaller than ``min`` set to ``min``, and values larger than ``max`` set to ``max``. + Wherever ``min`` is larger than ``max``, the value of ``max`` is returned. See also: - :func:`jax.numpy.minimum`: Compute the element-wise minimum value of two arrays. @@ -3435,7 +3371,7 @@ def clip( ) util.check_arraylike("clip", arr) - if any(jax.numpy.iscomplexobj(t) for t in (arr, min, max)): + if any(iscomplexobj(t) for t in (arr, min, max)): raise ValueError( "Clip received a complex value either through the input or the min/max " "keywords. Complex values have no ordering and cannot be clipped. " @@ -3444,12 +3380,12 @@ def clip( if min is not None: arr = ufuncs.maximum(min, arr) if max is not None: - arr = ufuncs.minimum(max, arr) + arr = ufuncs.minimum(max, arr) # type: ignore return asarray(arr) @export -@partial(jit, static_argnames=('decimals',)) +@api.jit(static_argnames=('decimals',)) def round(a: ArrayLike, decimals: int = 0, out: None = None) -> Array: """Round input evenly to the given number of decimals. @@ -3493,7 +3429,7 @@ def round(a: ArrayLike, decimals: int = 0, out: None = None) -> Array: decimals = core.concrete_or_error(operator.index, decimals, "'decimals' argument of jnp.round") if out is not None: raise NotImplementedError("The 'out' argument to jnp.round is not supported.") - dtype = _dtype(a) + dtype = a.dtype if issubdtype(dtype, np.integer): if decimals < 0: raise NotImplementedError( @@ -3509,11 +3445,14 @@ def _round_float(x: ArrayLike) -> Array: # end due to precision problems. As a workaround for float16, convert to # float32, x = lax.convert_element_type(x, np.float32) if dtype == np.float16 else x - factor = _lax_const(x, 10 ** decimals) + factor = lax._const(x, 10 ** decimals) out = lax.div(lax.round(lax.mul(x, factor), lax.RoundingMethod.TO_NEAREST_EVEN), factor) return lax.convert_element_type(out, dtype) if dtype == np.float16 else out + if decimals > np.log10(dtypes.finfo(dtype).max): + # Rounding beyond the input precision is a no-op. + return lax.asarray(a) if issubdtype(dtype, np.complexfloating): return lax.complex(_round_float(lax.real(a)), _round_float(lax.imag(a))) else: @@ -3521,19 +3460,24 @@ def _round_float(x: ArrayLike) -> Array: @export -@partial(jit, static_argnames=('decimals',)) +@api.jit(static_argnames=('decimals',)) def around(a: ArrayLike, decimals: int = 0, out: None = None) -> Array: """Alias of :func:`jax.numpy.round`""" return round(a, decimals, out) @export -@jit +@api.jit def fix(x: ArrayLike, out: None = None) -> Array: """Round input to the nearest integer towards zero. JAX implementation of :func:`numpy.fix`. + .. warning:: + + :func:`jax.numpy.fix` is deprecated and will be removed + in JAX v0.10.0. Use :func:`jax.numpy.trunc` instead. + Args: x: input array. out: unused by JAX. @@ -3554,20 +3498,20 @@ def fix(x: ArrayLike, out: None = None) -> Array: [[ 4.48 4.79 -1.68] [-0.31 0.7 -3.34] [-1.9 1.89 2.47]] - >>> jnp.fix(x) + >>> jnp.fix(x) # doctest: +SKIP Array([[ 4., 4., -1.], [-0., 0., -3.], [-1., 1., 2.]], dtype=float32) """ - util.check_arraylike("fix", x) + x = util.ensure_arraylike("fix", x) if out is not None: raise NotImplementedError("The 'out' argument to jnp.fix is not supported.") - zero = _lax_const(x, 0) + zero = lax._const(x, 0) return where(lax.ge(x, zero), ufuncs.floor(x), ufuncs.ceil(x)) @export -@jit +@api.jit def nan_to_num(x: ArrayLike, copy: bool = True, nan: ArrayLike = 0.0, posinf: ArrayLike | None = None, neginf: ArrayLike | None = None) -> Array: @@ -3616,14 +3560,14 @@ def nan_to_num(x: ArrayLike, copy: bool = True, nan: ArrayLike = 0.0, """ del copy x = util.ensure_arraylike("nan_to_num", x) - dtype = _dtype(x) + dtype = x.dtype if not issubdtype(dtype, np.inexact): return x if issubdtype(dtype, np.complexfloating): return lax.complex( nan_to_num(lax.real(x), nan=nan, posinf=posinf, neginf=neginf), nan_to_num(lax.imag(x), nan=nan, posinf=posinf, neginf=neginf)) - info = finfo(dtypes.canonicalize_dtype(dtype)) + info = finfo(dtype) posinf = info.max if posinf is None else posinf neginf = info.min if neginf is None else neginf out = where(ufuncs.isnan(x), asarray(nan, dtype=dtype), x) @@ -3633,7 +3577,7 @@ def nan_to_num(x: ArrayLike, copy: bool = True, nan: ArrayLike = 0.0, @export -@partial(jit, static_argnames=('equal_nan',)) +@api.jit(static_argnames=('equal_nan',)) def allclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05, atol: ArrayLike = 1e-08, equal_nan: bool = False) -> Array: r"""Check if two arrays are element-wise approximately equal within a tolerance. @@ -3774,10 +3718,12 @@ def nonzero(a: ArrayLike, *, size: int | None = None, "The size argument of jnp.nonzero must be statically specified " "to use jnp.nonzero within JAX transformations.") if arr.size == 0 or calculated_size == 0: - return tuple(zeros(calculated_size, int) for dim in arr.shape) + return tuple(array_creation.zeros(calculated_size, int) for dim in arr.shape) flat_indices = reductions.cumsum( bincount(reductions.cumsum(mask), length=calculated_size)) - strides: np.ndarray = (np.cumprod(arr.shape[::-1])[::-1] // arr.shape).astype(dtypes.int_) + strides: np.ndarray = np.cumprod(arr.shape[::-1])[::-1] // arr.shape + if all(core.is_constant_dim(d) for d in strides): + strides = strides.astype(flat_indices.dtype) out = tuple((flat_indices // stride) % size for stride, size in zip(strides, arr.shape)) if fill_value is not None: fill_value_tup = fill_value if isinstance(fill_value, tuple) else arr.ndim * (fill_value,) @@ -3835,7 +3781,7 @@ def flatnonzero(a: ArrayLike, *, size: int | None = None, @export -@partial(jit, static_argnames=('axis',)) +@api.jit(static_argnames=('axis',)) def unwrap(p: ArrayLike, discont: ArrayLike | None = None, axis: int = -1, period: ArrayLike = 2 * np.pi) -> Array: """Unwrap a periodic signal. @@ -3852,10 +3798,15 @@ def unwrap(p: ArrayLike, discont: ArrayLike | None = None, Returns: An unwrapped copy of ``p``. + Notes: + This implementation follows that of :func:`numpy.unwrap`, and is not + well-suited for integer-period unwrapping of narrow-width integers + (e.g. `int8`, `int16`) or unsigned integers. + Examples: Consider a situation in which you are making measurements of the position of a rotating disk via the ``x`` and ``y`` locations of some point on that disk. - The underlying variable is an always-increating angle which we'll generate + The underlying variable is an always-increasing angle which we'll generate this way, using degrees for ease of representation: >>> rng = np.random.default_rng(0) @@ -3893,13 +3844,20 @@ def unwrap(p: ArrayLike, discont: ArrayLike | None = None, that satisfy this assumption, :func:`unwrap` can recover the original phased signal. """ p = util.ensure_arraylike("unwrap", p) + p, period = util.promote_dtypes(p, period) + if issubdtype(p.dtype, np.complexfloating): raise ValueError("jnp.unwrap does not support complex inputs.") if p.shape[axis] == 0: - return util.promote_dtypes_inexact(p)[0] + return p + if discont is None: discont = period / 2 - interval = period / 2 + if dtypes.issubdtype(p.dtype, np.integer): + interval = period // 2 + else: + interval = period / 2 + dd = diff(p, axis=axis) ddmod = ufuncs.mod(dd + interval, period) - interval ddmod = where((ddmod == -interval) & (dd > 0), interval, ddmod) @@ -3907,8 +3865,8 @@ def unwrap(p: ArrayLike, discont: ArrayLike | None = None, ph_correct = where(ufuncs.abs(dd) < discont, 0, ddmod - dd) up = concatenate(( - lax.slice_in_dim(p, 0, 1, axis=axis), - lax.slice_in_dim(p, 1, None, axis=axis) + reductions.cumsum(ph_correct, axis=axis) + lax_slicing.slice_in_dim(p, 0, 1, axis=axis), + lax_slicing.slice_in_dim(p, 1, None, axis=axis) + reductions.cumsum(ph_correct, axis=axis) ), axis=axis) return up @@ -3974,7 +3932,7 @@ def _check_no_padding(axis_padding: tuple[Any, Any], mode: str): def _pad_constant(array: Array, pad_width: PadValue[int], constant_values: Array) -> Array: nd = np.ndim(array) - constant_values = lax_internal._convert_element_type( + constant_values = lax._convert_element_type( constant_values, array.dtype, dtypes.is_weakly_typed(array)) constant_values_nd = np.ndim(constant_values) @@ -4014,10 +3972,10 @@ def _pad_wrap(array: Array, pad_width: PadValue[int]) -> Array: total_repeats = left_repeats + right_repeats + 1 parts = [] if left_remainder > 0: - parts += [lax.slice_in_dim(array, size - left_remainder, size, axis=i)] + parts += [lax_slicing.slice_in_dim(array, size - left_remainder, size, axis=i)] parts += total_repeats * [array] if right_remainder > 0: - parts += [lax.slice_in_dim(array, 0, right_remainder, axis=i)] + parts += [lax_slicing.slice_in_dim(array, 0, right_remainder, axis=i)] array = lax.concatenate(parts, dimension=i) return array @@ -4028,17 +3986,18 @@ def _pad_symmetric_or_reflect(array: Array, pad_width: PadValue[int], assert reflect_type in ("even", "odd") for i in range(np.ndim(array)): - if array.shape[i] == 0: + axis_size = array.shape[i] + if axis_size == 0: _check_no_padding(pad_width[i], mode) continue - - axis_size = array.shape[i] + if pad_width[i][0] == 0 and pad_width[i][1] == 0: + continue def build_padding(array, padding, before): if before: - edge = lax.slice_in_dim(array, 0, 1, axis=i) + edge = lax_slicing.slice_in_dim(array, 0, 1, axis=i) else: - edge = lax.slice_in_dim(array, -1, None, axis=i) + edge = lax_slicing.slice_in_dim(array, -1, None, axis=i) # Try to give nicer error messages for unsupported shape polymorphic uses shape_poly_error_msg = lambda: ( @@ -4069,16 +4028,16 @@ def build_padding(array, padding, before): start = -(curr_pad + offset) stop = None if (mode == "symmetric" or axis_size == 1) else -1 - x = lax.slice_in_dim(array, start, stop, axis=i) + x = lax_slicing.slice_in_dim(array, start, stop, axis=i) x = flip(x, axis=i) if reflect_type == 'odd': x = 2 * edge - x if axis_size > 1: if before: - edge = lax.slice_in_dim(x, 0, 1, axis=i) + edge = lax_slicing.slice_in_dim(x, 0, 1, axis=i) else: - edge = lax.slice_in_dim(x, -1, None, axis=i) + edge = lax_slicing.slice_in_dim(x, -1, None, axis=i) if before: array = lax.concatenate([x, array], dimension=i) @@ -4101,10 +4060,10 @@ def _pad_edge(array: Array, pad_width: PadValue[int]) -> Array: n = array.shape[i] npad_before, npad_after = pad_width[i] - edge_before = lax.slice_in_dim(array, 0, 1, axis=i) + edge_before = lax_slicing.slice_in_dim(array, 0, 1, axis=i) pad_before = repeat(edge_before, npad_before, axis=i) - edge_after = lax.slice_in_dim(array, n-1, n, axis=i) + edge_after = lax_slicing.slice_in_dim(array, n-1, n, axis=i) pad_after = repeat(edge_after, npad_after, axis=i) array = lax.concatenate([pad_before, array, pad_after], dimension=i) @@ -4114,9 +4073,9 @@ def _pad_edge(array: Array, pad_width: PadValue[int]) -> Array: def _pad_linear_ramp(array: Array, pad_width: PadValue[int], end_values: PadValue[ArrayLike]) -> Array: for axis in range(np.ndim(array)): - edge_before = lax.slice_in_dim(array, 0, 1, axis=axis) - edge_after = lax.slice_in_dim(array, -1, None, axis=axis) - ramp_before = linspace( + edge_before = lax_slicing.slice_in_dim(array, 0, 1, axis=axis) + edge_after = lax_slicing.slice_in_dim(array, -1, None, axis=axis) + ramp_before = array_creation.linspace( start=end_values[axis][0], stop=edge_before.squeeze(axis), # Dimension is replaced by linspace num=pad_width[axis][0], @@ -4124,9 +4083,9 @@ def _pad_linear_ramp(array: Array, pad_width: PadValue[int], dtype=array.dtype, axis=axis ) - ramp_before = lax_internal._convert_element_type( + ramp_before = lax._convert_element_type( ramp_before, weak_type=dtypes.is_weakly_typed(array)) - ramp_after = linspace( + ramp_after = array_creation.linspace( start=end_values[axis][1], stop=edge_after.squeeze(axis), # Dimension is replaced by linspace num=pad_width[axis][1], @@ -4134,7 +4093,7 @@ def _pad_linear_ramp(array: Array, pad_width: PadValue[int], dtype=array.dtype, axis=axis ) - ramp_after = lax_internal._convert_element_type( + ramp_after = lax._convert_element_type( ramp_after, weak_type=dtypes.is_weakly_typed(array)) # Reverse linear space in appropriate dimension @@ -4162,8 +4121,8 @@ def _pad_stats(array: Array, pad_width: PadValue[int], length_before = min(length_before, array_length) length_after = min(length_after, array_length) - slice_before = lax.slice_in_dim(array, 0, length_before, axis=i) - slice_after = lax.slice_in_dim(array, -length_after, None, axis=i) + slice_before = lax_slicing.slice_in_dim(array, 0, length_before, axis=i) + slice_after = lax_slicing.slice_in_dim(array, -length_after, None, axis=i) stat_before = stat_func(slice_before, axis=i, keepdims=True) stat_after = stat_func(slice_after, axis=i, keepdims=True) @@ -4171,9 +4130,9 @@ def _pad_stats(array: Array, pad_width: PadValue[int], stat_before = round(stat_before) stat_after = round(stat_after) - stat_before = lax_internal._convert_element_type( + stat_before = lax._convert_element_type( stat_before, array.dtype, dtypes.is_weakly_typed(array)) - stat_after = lax_internal._convert_element_type( + stat_after = lax._convert_element_type( stat_after, array.dtype, dtypes.is_weakly_typed(array)) npad_before, npad_after = pad_width[i] @@ -4188,10 +4147,10 @@ def _pad_empty(array: Array, pad_width: PadValue[int]) -> Array: # Note: jax.numpy.empty = jax.numpy.zeros for i in range(np.ndim(array)): shape_before = array.shape[:i] + (pad_width[i][0],) + array.shape[i + 1:] - pad_before = empty_like(array, shape=shape_before) + pad_before = array_creation.empty_like(array, shape=shape_before) shape_after = array.shape[:i] + (pad_width[i][1],) + array.shape[i + 1:] - pad_after = empty_like(array, shape=shape_after) + pad_after = array_creation.empty_like(array, shape=shape_after) array = lax.concatenate([pad_before, array, pad_after], dimension=i) return array @@ -4204,7 +4163,7 @@ def _pad_func(array: Array, pad_width: PadValue[int], func: Callable[..., Any], return padded -@partial(jit, static_argnums=(1, 2, 4, 5, 6)) +@api.jit(static_argnums=(1, 2, 4, 5, 6)) def _pad(array: ArrayLike, pad_width: PadValueLike[int], mode: str, constant_values: ArrayLike, stat_length: PadValueLike[int], end_values: PadValueLike[ArrayLike], reflect_type: str): @@ -4376,7 +4335,7 @@ def pad_func(row: Array, pad_width: tuple[int, int], Array([-10, -10, 2, 3, 4, 10, 10], dtype=int32) """ - util.check_arraylike("pad", array) + array = util.ensure_arraylike("pad", array) pad_width = _broadcast_to_pairs(pad_width, np.ndim(array), "pad_width") if pad_width and not all(core.is_dim(p[0]) and core.is_dim(p[1]) for p in pad_width): @@ -4471,7 +4430,7 @@ def stack(arrays: np.ndarray | Array | Sequence[ArrayLike], axis = _canonicalize_axis(axis, arrays.ndim) return concatenate(expand_dims(arrays, axis + 1), axis=axis, dtype=dtype) else: - util.check_arraylike("stack", *arrays) + arrays = util.ensure_arraylike_tuple("stack", arrays) shape0 = np.shape(arrays[0]) axis = _canonicalize_axis(axis, len(shape0) + 1) new_arrays = [] @@ -4483,7 +4442,7 @@ def stack(arrays: np.ndarray | Array | Sequence[ArrayLike], @export -@partial(jit, static_argnames="axis") +@api.jit(static_argnames="axis") def unstack(x: ArrayLike, /, *, axis: int = 0) -> tuple[Array, ...]: """Unstack an array along an axis. @@ -4560,7 +4519,7 @@ def tile(A: ArrayLike, reps: DimSize | Sequence[DimSize]) -> Array: [1, 2], [3, 4]], dtype=int32) """ - util.check_arraylike("tile", A) + A = util.ensure_arraylike("tile", A) try: iter(reps) # type: ignore[arg-type] except TypeError: @@ -4569,11 +4528,13 @@ def tile(A: ArrayLike, reps: DimSize | Sequence[DimSize]) -> Array: reps_tup = tuple(reps) # type: ignore[arg-type] reps_tup = tuple(operator.index(rep) if core.is_constant_dim(rep) else rep for rep in reps_tup) - A_shape = (1,) * (len(reps_tup) - np.ndim(A)) + np.shape(A) - reps_tup = (1,) * (len(A_shape) - len(reps_tup)) + reps_tup - result = broadcast_to(reshape(A, [j for i in A_shape for j in [1, i]]), - [k for pair in zip(reps_tup, A_shape) for k in pair]) - return reshape(result, tuple(np.multiply(A_shape, reps_tup))) + # lax.tile expects reps and A.shape to have the same rank. + reps_tup = (1,) * (A.ndim - len(reps_tup)) + reps_tup + if len(reps_tup) > np.ndim(A): + A = lax.expand_dims( + A, dimensions=tuple(range(len(reps_tup) - np.ndim(A)))) + return lax.tile(A, reps_tup) + def _concatenate_array(arr: ArrayLike, axis: int | None, dtype: DTypeLike | None = None) -> Array: @@ -4603,7 +4564,8 @@ def concatenate(arrays: np.ndarray | Array | Sequence[ArrayLike], except along the specified axis. If a single array is given it will be treated equivalently to `arrays = unstack(arrays)`, but the implementation will avoid explicit unstacking. - axis: specify the axis along which to concatenate. + axis: specify the axis along which to concatenate. If None, the arrays are + flattened before concatenation. dtype: optional dtype of the resulting array. If not specified, the dtype will be determined via type promotion rules described in :ref:`type-promotion`. @@ -4633,7 +4595,7 @@ def concatenate(arrays: np.ndarray | Array | Sequence[ArrayLike], """ if isinstance(arrays, (np.ndarray, Array)): return _concatenate_array(arrays, axis, dtype=dtype) - util.check_arraylike("concatenate", *arrays) + arrays = util.ensure_arraylike_tuple("concatenate", arrays) if not len(arrays): raise ValueError("Need at least one array to concatenate.") if axis is None: @@ -4693,7 +4655,7 @@ def concat(arrays: Sequence[ArrayLike], /, *, axis: int | None = 0) -> Array: [1., 1., 1., 0.]], dtype=float32) """ util.check_arraylike("concat", *arrays) - return jax.numpy.concatenate(arrays, axis=axis) + return concatenate(arrays, axis=axis) @export @@ -4749,7 +4711,7 @@ def vstack(tup: np.ndarray | Array | Sequence[ArrayLike], """ arrs: Array | list[Array] if isinstance(tup, (np.ndarray, Array)): - arrs = jax.vmap(atleast_2d)(tup) + arrs = api.vmap(atleast_2d)(tup) else: # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error. util.check_arraylike("vstack", *tup, emit_warning=True) @@ -4808,7 +4770,7 @@ def hstack(tup: np.ndarray | Array | Sequence[ArrayLike], """ arrs: Array | list[Array] if isinstance(tup, (np.ndarray, Array)): - arrs = jax.vmap(atleast_1d)(tup) + arrs = api.vmap(atleast_1d)(tup) arr0_ndim = arrs.ndim - 1 else: # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error. @@ -4871,10 +4833,11 @@ def dstack(tup: np.ndarray | Array | Sequence[ArrayLike], """ arrs: Array | list[Array] if isinstance(tup, (np.ndarray, Array)): - arrs = jax.vmap(atleast_3d)(tup) + arrs = api.vmap(atleast_3d)(tup) else: # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error. util.check_arraylike("dstack", *tup, emit_warning=True) + tup = util.ensure_arraylike_tuple("dstack", tup) arrs = [atleast_3d(m) for m in tup] return concatenate(arrs, axis=2, dtype=dtype) @@ -4904,7 +4867,7 @@ def column_stack(tup: np.ndarray | Array | Sequence[ArrayLike]) -> Array: - :func:`jax.numpy.concatenate`: concatenation along existing axes. - :func:`jax.numpy.vstack`: stack vertically, i.e. along axis 0. - :func:`jax.numpy.hstack`: stack horizontally, i.e. along axis 1. - - :func:`jax.numpy.hstack`: stack depth=wise, i.e. along axis 2. + - :func:`jax.numpy.dstack`: stack depth-wise, i.e. along axis 2. Examples: Scalar values: @@ -4932,7 +4895,7 @@ def column_stack(tup: np.ndarray | Array | Sequence[ArrayLike]) -> Array: """ arrs: Array | list[Array] | np.ndarray if isinstance(tup, (np.ndarray, Array)): - arrs = jax.vmap(lambda x: atleast_2d(x).T)(tup) if tup.ndim < 3 else tup + arrs = api.vmap(lambda x: atleast_2d(x).T)(tup) if tup.ndim < 3 else tup else: # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error. util.check_arraylike("column_stack", *tup, emit_warning=True) @@ -5022,8 +4985,8 @@ def choose(a, choices): """ if out is not None: raise NotImplementedError("The 'out' argument to jnp.choose is not supported.") - util.check_arraylike('choose', a, *choices) - if not issubdtype(_dtype(a), np.integer): + a, *choices = util.ensure_arraylike_tuple('choose', (a, *choices)) + if not issubdtype(a.dtype, np.integer): raise ValueError("`a` array must be integer typed") N = len(choices) @@ -5066,7 +5029,7 @@ def _block(xs: ArrayLike | list[ArrayLike]) -> tuple[Array, int]: @export -@jit +@api.jit def block(arrays: ArrayLike | list[ArrayLike]) -> Array: """Create an array from a list of blocks. @@ -5150,7 +5113,7 @@ def atleast_1d(x: ArrayLike, /) -> Array: def atleast_1d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]: ... @export -@jit +@api.jit def atleast_1d(*arys: ArrayLike) -> Array | list[Array]: """Convert inputs to arrays with at least 1 dimension. @@ -5205,7 +5168,7 @@ def atleast_2d(x: ArrayLike, /) -> Array: def atleast_2d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]: ... @export -@jit +@api.jit def atleast_2d(*arys: ArrayLike) -> Array | list[Array]: """Convert inputs to arrays with at least 2 dimensions. @@ -5269,7 +5232,7 @@ def atleast_3d(x: ArrayLike, /) -> Array: def atleast_3d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]: ... @export -@jit +@api.jit def atleast_3d(*arys: ArrayLike) -> Array | list[Array]: """Convert inputs to arrays with at least 3 dimensions. @@ -5336,252 +5299,13 @@ def atleast_3d(*arys: ArrayLike) -> Array | list[Array]: return [atleast_3d(arr) for arr in arys] -def _supports_buffer_protocol(obj): - try: - view = memoryview(obj) - except TypeError: - return False - else: - return True - - -def _make_string_array( - object: np.ndarray, - dtype: DTypeLike | None = None, - ndmin: int = 0, - device: xc.Device | Sharding | None = None, -) -> Array: - if not isinstance(object, np.ndarray): - raise TypeError( - "Currently, string arrays can only be made from NumPy" - f" arrays. Got: {type(object)}." - ) - if dtype is not None and ( - dtypes.is_string_dtype(object.dtype) != dtypes.is_string_dtype(dtype) - ): - raise TypeError( - f"Cannot make an array with dtype {dtype} from an object with dtype" - f" {object.dtype}." - ) - if ndmin > object.ndim: - raise TypeError( - f"ndmin {ndmin} cannot be greater than object's ndims" - f" {object.ndim} for string arrays." - ) - - # Just do a device_put since XLA does not support string as a data type. - return jax.device_put(x=object, device=device) - - -@export -def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, - order: str | None = "K", ndmin: int = 0, - *, device: xc.Device | Sharding | None = None) -> Array: - """Convert an object to a JAX array. - - JAX implementation of :func:`numpy.array`. - - Args: - object: an object that is convertible to an array. This includes JAX - arrays, NumPy arrays, Python scalars, Python collections like lists - and tuples, objects with an ``__array__`` method, and objects - supporting the Python buffer protocol. - dtype: optionally specify the dtype of the output array. If not - specified it will be inferred from the input. - copy: specify whether to force a copy of the input. Default: True. - order: not implemented in JAX - ndmin: integer specifying the minimum number of dimensions in the - output array. - device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding` - to which the created array will be committed. - - Returns: - A JAX array constructed from the input. - - See also: - - :func:`jax.numpy.asarray`: like `array`, but by default only copies - when necessary. - - :func:`jax.numpy.from_dlpack`: construct a JAX array from an object - that implements the dlpack interface. - - :func:`jax.numpy.frombuffer`: construct a JAX array from an object - that implements the buffer interface. - - Examples: - Constructing JAX arrays from Python scalars: - - >>> jnp.array(True) - Array(True, dtype=bool) - >>> jnp.array(42) - Array(42, dtype=int32, weak_type=True) - >>> jnp.array(3.5) - Array(3.5, dtype=float32, weak_type=True) - >>> jnp.array(1 + 1j) - Array(1.+1.j, dtype=complex64, weak_type=True) - - Constructing JAX arrays from Python collections: - - >>> jnp.array([1, 2, 3]) # list of ints -> 1D array - Array([1, 2, 3], dtype=int32) - >>> jnp.array([(1, 2, 3), (4, 5, 6)]) # list of tuples of ints -> 2D array - Array([[1, 2, 3], - [4, 5, 6]], dtype=int32) - >>> jnp.array(range(5)) - Array([0, 1, 2, 3, 4], dtype=int32) - - Constructing JAX arrays from NumPy arrays: - - >>> jnp.array(np.linspace(0, 2, 5)) - Array([0. , 0.5, 1. , 1.5, 2. ], dtype=float32) - - Constructing a JAX array via the Python buffer interface, using Python's - built-in :mod:`array` module. - - >>> from array import array - >>> pybuffer = array('i', [2, 3, 5, 7]) - >>> jnp.array(pybuffer) - Array([2, 3, 5, 7], dtype=int32) - """ - if order is not None and order != "K": - raise NotImplementedError("Only implemented for order='K'") - - # check if the given dtype is compatible with JAX - dtypes.check_user_dtype_supported(dtype, "array") - - # Here we make a judgment call: we only return a weakly-typed array when the - # input object itself is weakly typed. That ensures asarray(x) is a no-op - # whenever x is weak, but avoids introducing weak types with something like - # array([1, 2, 3]) - weak_type = dtype is None and dtypes.is_weakly_typed(object) - if device is None and isinstance(object, core.Tracer): - sharding = object.aval.sharding - sharding = None if sharding.mesh.empty else sharding - else: - sharding = canonicalize_device_to_sharding(device) - - # Use device_put to avoid a copy for ndarray inputs. - if (not copy and isinstance(object, np.ndarray) and - (dtype is None or dtype == object.dtype) and (ndmin <= object.ndim) and - device is None): - # Keep the output uncommitted. - return jax.device_put(object) - - # String arrays need separate handling because XLA does not support string - # as a data type. - if dtypes.is_string_dtype(dtype) or ( - hasattr(object, "dtype") and dtypes.is_string_dtype(object.dtype) - ): - return _make_string_array( - object=object, dtype=dtype, ndmin=ndmin, device=device - ) - - # For Python scalar literals, call coerce_to_array to catch any overflow - # errors. We don't use dtypes.is_python_scalar because we don't want this - # triggering for traced values. We do this here because it matters whether or - # not dtype is None. We don't assign the result because we want the raw object - # to be used for type inference below. - if isinstance(object, (bool, int, float, complex)): - _ = dtypes.coerce_to_array(object, dtype) - elif not isinstance(object, Array): - # Check if object supports any of the data exchange protocols - # (except dlpack, see data-apis/array-api#301). If it does, - # consume the object as jax array and continue (but not return) so - # that other array() arguments get processed against the input - # object. - # - # Notice that data exchange protocols define dtype in the - # corresponding data structures and it may not be available as - # object.dtype. So, we'll resolve the protocols here before - # evaluating object.dtype. - if hasattr(object, '__jax_array__'): - object = object.__jax_array__() - elif hasattr(object, '__cuda_array_interface__'): - cai = object.__cuda_array_interface__ - backend = xla_bridge.get_backend("cuda") - if cuda_plugin_extension is None: - device_id = None - else: - device_id = cuda_plugin_extension.get_device_ordinal(cai["data"][0]) - object = xc._xla.cuda_array_interface_to_buffer( - cai=cai, gpu_backend=backend, device_id=device_id) - - object = tree_map(lambda leaf: leaf.__jax_array__() - if hasattr(leaf, "__jax_array__") else leaf, object) - leaves = tree_leaves(object, is_leaf=lambda x: x is None) - if any(leaf is None for leaf in leaves): - # Added Nov 16 2023 - if deprecations.is_accelerated("jax-numpy-array-none"): - raise TypeError("None is not a valid value for jnp.array") - warnings.warn( - "None encountered in jnp.array(); this is currently treated as NaN. " - "In the future this will result in an error.", - FutureWarning, stacklevel=2) - leaves = tree_leaves(object) - if dtype is None: - # Use lattice_result_type rather than result_type to avoid canonicalization. - # Otherwise, weakly-typed inputs would have their dtypes canonicalized. - try: - dtype = dtypes._lattice_result_type(*leaves)[0] if leaves else dtypes.float_ - except TypeError: - # This happens if, e.g. one of the entries is a memoryview object. - # This is rare, so we only handle it if the normal path fails. - leaves = [_convert_to_array_if_dtype_fails(leaf) for leaf in leaves] - dtype = dtypes._lattice_result_type(*leaves)[0] - - if not weak_type: - dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True) # type: ignore[assignment] - - out: ArrayLike - - if all(not isinstance(leaf, Array) for leaf in leaves): - # TODO(jakevdp): falling back to numpy here fails to overflow for lists - # containing large integers; see discussion in - # https://github.com/jax-ml/jax/pull/6047. More correct would be to call - # coerce_to_array on each leaf, but this may have performance implications. - out = np.asarray(object, dtype=dtype) - elif isinstance(object, Array): - assert object.aval is not None - out = _array_copy(object) if copy else object - elif isinstance(object, (list, tuple)): - if object: - out = stack([asarray(elt, dtype=dtype) for elt in object]) - else: - out = np.array([], dtype=dtype) - elif _supports_buffer_protocol(object): - object = memoryview(object) - # TODO(jakevdp): update this once we support NumPy 2.0 semantics for the copy arg. - out = np.array(object) if copy else np.asarray(object) - else: - raise TypeError(f"Unexpected input type for array: {type(object)}") - out_array: Array = lax_internal._convert_element_type( - out, dtype, weak_type=weak_type, sharding=sharding) - if ndmin > np.ndim(out_array): - out_array = lax.expand_dims(out_array, range(ndmin - np.ndim(out_array))) - return out_array - - -def canonicalize_device_to_sharding(device: xc.Device | Sharding | None - ) -> Sharding | None: - if isinstance(device, xc.Device): - return SingleDeviceSharding(device) - return device - - -def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike: - try: - dtypes.dtype(x) - except TypeError: - return np.asarray(x) - else: - return x - - @export def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = False, device: xc.Device | Sharding | None = None) -> Array: """Convert an array to a specified dtype. - JAX imlementation of :func:`numpy.astype`. + JAX implementation of :func:`numpy.astype`. This is implemented via :func:`jax.lax.convert_element_type`, which may have slightly different behavior than :func:`numpy.astype` in some cases. @@ -5617,8 +5341,9 @@ def astype(x: ArrayLike, dtype: DTypeLike | None, x_arr = util.ensure_arraylike("astype", x) if dtype is None: - dtype = dtypes.canonicalize_dtype(dtypes.float_) - dtypes.check_user_dtype_supported(dtype, "astype") + dtype = dtypes.default_float_dtype() + else: + dtype = dtypes.check_and_canonicalize_user_dtype(dtype, "astype") if issubdtype(x_arr.dtype, np.complexfloating): if dtypes.isdtype(dtype, ("integral", "real floating")): deprecations.warn( @@ -5629,96 +5354,14 @@ def astype(x: ArrayLike, dtype: DTypeLike | None, stacklevel=2) elif np.dtype(dtype) == bool: # convert_element_type(complex, bool) has the wrong semantics. - x_arr = (x_arr != _lax_const(x_arr, 0)) + x_arr = (x_arr != lax._const(x_arr, 0)) # We offer a more specific warning than the usual ComplexWarning so we prefer # to issue our warning. - result = lax_internal._convert_element_type( - x_arr, dtype, sharding=util.normalize_device_to_sharding(device), + result = lax._convert_element_type( + x_arr, dtype, sharding=util.canonicalize_device_to_sharding(device), warn_on_complex_to_real_cast=False) - return _array_copy(result) if copy else result - - -@export -def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None, - *, copy: bool | None = None, - device: xc.Device | Sharding | None = None) -> Array: - """Convert an object to a JAX array. - - JAX implementation of :func:`numpy.asarray`. - - Args: - a: an object that is convertible to an array. This includes JAX - arrays, NumPy arrays, Python scalars, Python collections like lists - and tuples, objects with an ``__array__`` method, and objects - supporting the Python buffer protocol. - dtype: optionally specify the dtype of the output array. If not - specified it will be inferred from the input. - order: not implemented in JAX - copy: optional boolean specifying the copy mode. If True, then always - return a copy. If False, then error if a copy is necessary. Default is - None, which will only copy when necessary. - device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding` - to which the created array will be committed. - - Returns: - A JAX array constructed from the input. - - See also: - - :func:`jax.numpy.array`: like `asarray`, but defaults to `copy=True`. - - :func:`jax.numpy.from_dlpack`: construct a JAX array from an object - that implements the dlpack interface. - - :func:`jax.numpy.frombuffer`: construct a JAX array from an object - that implements the buffer interface. - - Examples: - Constructing JAX arrays from Python scalars: - - >>> jnp.asarray(True) - Array(True, dtype=bool) - >>> jnp.asarray(42) - Array(42, dtype=int32, weak_type=True) - >>> jnp.asarray(3.5) - Array(3.5, dtype=float32, weak_type=True) - >>> jnp.asarray(1 + 1j) - Array(1.+1.j, dtype=complex64, weak_type=True) - - Constructing JAX arrays from Python collections: - - >>> jnp.asarray([1, 2, 3]) # list of ints -> 1D array - Array([1, 2, 3], dtype=int32) - >>> jnp.asarray([(1, 2, 3), (4, 5, 6)]) # list of tuples of ints -> 2D array - Array([[1, 2, 3], - [4, 5, 6]], dtype=int32) - >>> jnp.asarray(range(5)) - Array([0, 1, 2, 3, 4], dtype=int32) - - Constructing JAX arrays from NumPy arrays: - - >>> jnp.asarray(np.linspace(0, 2, 5)) - Array([0. , 0.5, 1. , 1.5, 2. ], dtype=float32) - - Constructing a JAX array via the Python buffer interface, using Python's - built-in :mod:`array` module. - - >>> from array import array - >>> pybuffer = array('i', [2, 3, 5, 7]) - >>> jnp.asarray(pybuffer) - Array([2, 3, 5, 7], dtype=int32) - """ - # For copy=False, the array API specifies that we raise a ValueError if the input supports - # the buffer protocol but a copy is required. Since array() supports the buffer protocol - # via numpy, this is only the case when the default device is not 'cpu' - if (copy is False and not isinstance(a, Array) - and jax.default_backend() != 'cpu' - and _supports_buffer_protocol(a)): - raise ValueError(f"jnp.asarray: cannot convert object of type {type(a)} to JAX Array " - f"on backend={jax.default_backend()!r} with copy=False. " - "Consider using copy=None or copy=True instead.") - dtypes.check_user_dtype_supported(dtype, "asarray") - if dtype is not None: - dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True) # type: ignore[assignment] - return array(a, dtype=dtype, copy=bool(copy), order=order, device=device) + return lax._array_copy(result) if copy else result @export @@ -5910,14 +5553,14 @@ def fromfile(*args, **kwargs): ``jnp.asarray(np.fromfile(...))`` instead, although care should be taken if ``np.fromfile`` is used within jax transformations because of its potential side-effect of consuming the file object; for more information see `Common Gotchas: Pure Functions - `_. + `_. """ raise NotImplementedError( "jnp.fromfile() is not implemented because it may be non-pure and thus unsafe for use " "with JIT and other JAX transformations. Consider using jnp.asarray(np.fromfile(...)) " "instead, although care should be taken if np.fromfile is used within a jax transformations " "because of its potential side-effect of consuming the file object; for more information see " - "https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions") + "https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions") @export @@ -5929,14 +5572,14 @@ def fromiter(*args, **kwargs): ``jnp.asarray(np.fromiter(...))`` instead, although care should be taken if ``np.fromiter`` is used within jax transformations because of its potential side-effect of consuming the iterable object; for more information see `Common Gotchas: Pure Functions - `_. + `_. """ raise NotImplementedError( "jnp.fromiter() is not implemented because it may be non-pure and thus unsafe for use " "with JIT and other JAX transformations. Consider using jnp.asarray(np.fromiter(...)) " "instead, although care should be taken if np.fromiter is used within a jax transformations " "because of its potential side-effect of consuming the iterable object; for more information see " - "https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions") + "https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions") @export @@ -5963,7 +5606,7 @@ def from_dlpack(x: Any, /, *, device: xc.Device | Sharding | None = None, if needed for a device transfer. Returns: - A JAX array of the imput buffer. + A JAX array of the input buffer. Note: While JAX arrays are always immutable, dlpack buffers cannot be marked as @@ -6083,7 +5726,7 @@ def fromfunction(function: Callable[..., Array], shape: Any, shape = core.canonicalize_shape(shape, context="shape argument of jnp.fromfunction()") for i in range(len(shape)): in_axes = [0 if i == j else None for j in range(len(shape))] - function = jax.vmap(function, in_axes=tuple(in_axes[::-1])) + function = api.vmap(function, in_axes=tuple(in_axes[::-1])) return function(*(arange(s, dtype=dtype) for s in shape), **kwargs) @@ -6172,16 +5815,17 @@ def eye(N: DimSize, M: DimSize | None = None, # instead of putting it on default device and then on the specific device output = _eye(N, M=M, k=k, dtype=dtype) if device is not None: - return jax.device_put(output, device=device) + return api.device_put(output, device=device) return output def _eye(N: DimSize, M: DimSize | None = None, k: int | ArrayLike = 0, dtype: DTypeLike | None = None) -> Array: - dtypes.check_user_dtype_supported(dtype, "eye") + dtype = dtypes.check_and_canonicalize_user_dtype( + float if dtype is None else dtype, "eye") if isinstance(k, int): - k = lax_internal._clip_int_to_valid_range(k, np.int32, + k = lax._clip_int_to_valid_range(k, np.int32, "`argument `k` of jax.numpy.eye") offset = util.ensure_arraylike("eye", k) if not (offset.shape == () and dtypes.issubdtype(offset.dtype, np.integer)): @@ -6225,14 +5869,16 @@ def identity(n: DimSize, dtype: DTypeLike | None = None) -> Array: Array([[1, 0], [0, 1]], dtype=int32) """ - dtypes.check_user_dtype_supported(dtype, "identity") + if dtype is not None: + dtype = dtypes.check_and_canonicalize_user_dtype(dtype, "identity") return eye(n, dtype=dtype) @export def arange(start: ArrayLike | DimSize, stop: ArrayLike | DimSize | None = None, step: ArrayLike | None = None, dtype: DTypeLike | None = None, - *, device: xc.Device | Sharding | None = None) -> Array: + *, device: xc.Device | Sharding | None = None, + out_sharding: NamedSharding | P | None = None) -> Array: """Create an array of evenly-spaced values. JAX implementation of :func:`numpy.arange`, implemented in terms of @@ -6259,6 +5905,10 @@ def arange(start: ArrayLike | DimSize, stop: ArrayLike | DimSize | None = None, be determined via type promotion of `start`, `stop`, and `step`. device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` to which the created array will be committed. + out_sharding: (optional) :class:`~jax.NamedSharding` or :class:`~jax.P` to + which the created array will be committed. Use `out_sharding` argument, + if using explicit sharding + (https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html) Returns: Array of evenly-spaced values from ``start`` to ``stop``, separated by ``step``. @@ -6301,53 +5951,78 @@ def arange(start: ArrayLike | DimSize, stop: ArrayLike | DimSize | None = None, - :func:`jax.numpy.linspace`: generate a fixed number of evenly-spaced values. - :func:`jax.lax.iota`: directly generate integer sequences in XLA. """ - # TODO(vfdev-5): optimize putting the array directly on the device specified - # instead of putting it on default device and then on the specific device - output = _arange(start, stop=stop, step=step, dtype=dtype) - if device is not None: - return jax.device_put(output, device=device) - return output + sharding = util.choose_device_or_out_sharding( + device, out_sharding, 'jnp.arange') + if sharding is None or not sharding._is_concrete: + assert sharding is None or isinstance(sharding, NamedSharding) + return _arange(start, stop=stop, step=step, dtype=dtype, + out_sharding=sharding) + else: + output = _arange(start, stop=stop, step=step, dtype=dtype) + return api.device_put(output, sharding) def _arange(start: ArrayLike | DimSize, stop: ArrayLike | DimSize | None = None, - step: ArrayLike | None = None, dtype: DTypeLike | None = None) -> Array: - dtypes.check_user_dtype_supported(dtype, "arange") - if not config.dynamic_shapes.value: - util.check_arraylike("arange", start) - if stop is None and step is None: - start = core.concrete_or_error(None, start, "It arose in the jnp.arange argument 'stop'") - else: - start = core.concrete_or_error(None, start, "It arose in the jnp.arange argument 'start'") - util.check_arraylike_or_none("arange", None, stop, step) + step: ArrayLike | None = None, dtype: DTypeLike | None = None, + out_sharding: NamedSharding | None = None) -> Array: + # Validate inputs + if dtype is not None: + dtype = dtypes.check_and_canonicalize_user_dtype(dtype, "arange") + util.check_arraylike_or_none("arange", start, stop, step) + + # Ensure start/stop/step are concrete + start_name = "stop" if stop is None and step is None else "start" + start = core.concrete_or_error(None, start, f"It arose in the jnp.arange argument '{start_name}'") stop = core.concrete_or_error(None, stop, "It arose in the jnp.arange argument 'stop'") step = core.concrete_or_error(None, step, "It arose in the jnp.arange argument 'step'") - start_name = "stop" if stop is None and step is None else "start" + + # Ensure start/stop/step are scalars for name, val in [(start_name, start), ("stop", stop), ("step", step)]: if val is not None and np.ndim(val) != 0: raise ValueError(f"jax.numpy.arange: arguments must be scalars; got {name}={val}") + + # Handle symbolic dimensions if any(core.is_symbolic_dim(v) for v in (start, stop, step)): - # Some dynamic shapes - if stop is None and step is None: - stop = start - start = 0 + if stop is None: + start, stop = 0, start + if step is None: step = 1 - elif stop is not None and step is None: - step = 1 - return _arange_dynamic(start, stop, step, dtype or dtypes.canonicalize_dtype(np.int64)) + return _arange_dynamic(start, stop, step, dtype or dtypes.default_int_dtype()) + if dtype is None: - dtype = result_type(start, *(x for x in [stop, step] if x is not None)) + dtype = dtypes.result_type(start, *(x for x in [stop, step] if x is not None)) dtype = dtypes.jax_dtype(dtype) - if stop is None and step is None: - start_dtype = _dtype(start) - if (not dtypes.issubdtype(start_dtype, np.integer) and - not dtypes.issubdtype(start_dtype, dtypes.extended)): - ceil_ = ufuncs.ceil if isinstance(start, core.Tracer) else np.ceil - start = ceil_(start).astype(int) - return lax.iota(dtype, start) # type: ignore[arg-type] + + if iscomplexobj(start) or iscomplexobj(stop) or iscomplexobj(step): + deprecations.warn( + "jax-numpy-arange-complex", + ( + "Passing complex start/stop/step to jnp.arange is deprecated;" + " in the future this will result in a ValueError." + ), + stacklevel=3 + ) + # Complex arange is poorly defined; fall back to NumPy here. + # TODO(jakevdp): deprecate the complex case. + return array(np.arange(start, stop, step, dtype=dtype), device=out_sharding) + + if step is not None: + # arange(N, M, K): when step is specified, fall back to NumPy. + return array(np.arange(start, stop, step, dtype=dtype), device=out_sharding) + + if stop is None: + start, stop = 0, start + + if start == 0: + # arange(M) or arange(0, M) + size = max(0, int(np.ceil(stop))) + return lax.broadcasted_iota(dtype, (size,), 0, out_sharding=out_sharding) + else: - if step is None and start == 0 and stop is not None: - return lax.iota(dtype, np.ceil(stop).astype(int)) - return array(np.arange(start, stop=stop, step=step, dtype=dtype)) + # arange(N, M) + size = max(0, int(np.ceil(stop - start))) + return lax.add(lax.convert_element_type(start, dtype), + lax.broadcasted_iota(dtype, (size,), 0, out_sharding=out_sharding)) def _arange_dynamic( @@ -6373,316 +6048,6 @@ def _arange_dynamic( return (array(start, dtype=dtype) + array(step, dtype=dtype) * lax.iota(dtype, size)) -@overload -def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, - endpoint: bool = True, retstep: Literal[False] = False, - dtype: DTypeLike | None = None, - axis: int = 0, - *, device: xc.Device | Sharding | None = None) -> Array: ... -@overload -def linspace(start: ArrayLike, stop: ArrayLike, num: int, - endpoint: bool, retstep: Literal[True], - dtype: DTypeLike | None = None, - axis: int = 0, - *, device: xc.Device | Sharding | None = None) -> tuple[Array, Array]: ... -@overload -def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, - endpoint: bool = True, *, retstep: Literal[True], - dtype: DTypeLike | None = None, - axis: int = 0, - device: xc.Device | Sharding | None = None) -> tuple[Array, Array]: ... -@overload -def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, - endpoint: bool = True, retstep: bool = False, - dtype: DTypeLike | None = None, - axis: int = 0, - *, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]: ... -@export -def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, - endpoint: bool = True, retstep: bool = False, - dtype: DTypeLike | None = None, - axis: int = 0, - *, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]: - """Return evenly-spaced numbers within an interval. - - JAX implementation of :func:`numpy.linspace`. - - Args: - start: scalar or array of starting values. - stop: scalar or array of stop values. - num: number of values to generate. Default: 50. - endpoint: if True (default) then include the ``stop`` value in the result. - If False, then exclude the ``stop`` value. - retstep: If True, then return a ``(result, step)`` tuple, where ``step`` is the - interval between adjacent values in ``result``. - axis: integer axis along which to generate the linspace. Defaults to zero. - device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding` - to which the created array will be committed. - - Returns: - An array ``values``, or a tuple ``(values, step)`` if ``retstep`` is True, where: - - - ``values`` is an array of evenly-spaced values from ``start`` to ``stop`` - - ``step`` is the interval between adjacent values. - - See also: - - :func:`jax.numpy.arange`: Generate ``N`` evenly-spaced values given a starting - point and a step - - :func:`jax.numpy.logspace`: Generate logarithmically-spaced values. - - :func:`jax.numpy.geomspace`: Generate geometrically-spaced values. - - Examples: - List of 5 values between 0 and 10: - - >>> jnp.linspace(0, 10, 5) - Array([ 0. , 2.5, 5. , 7.5, 10. ], dtype=float32) - - List of 8 values between 0 and 10, excluding the endpoint: - - >>> jnp.linspace(0, 10, 8, endpoint=False) - Array([0. , 1.25, 2.5 , 3.75, 5. , 6.25, 7.5 , 8.75], dtype=float32) - - List of values and the step size between them - - >>> vals, step = jnp.linspace(0, 10, 9, retstep=True) - >>> vals - Array([ 0. , 1.25, 2.5 , 3.75, 5. , 6.25, 7.5 , 8.75, 10. ], dtype=float32) - >>> step - Array(1.25, dtype=float32) - - Multi-dimensional linspace: - - >>> start = jnp.array([0, 5]) - >>> stop = jnp.array([5, 10]) - >>> jnp.linspace(start, stop, 5) - Array([[ 0. , 5. ], - [ 1.25, 6.25], - [ 2.5 , 7.5 ], - [ 3.75, 8.75], - [ 5. , 10. ]], dtype=float32) - """ - num = core.concrete_dim_or_error(num, "'num' argument of jnp.linspace") - axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.linspace") - return _linspace(start, stop, num, endpoint, retstep, dtype, axis, device=device) - -@partial(jit, static_argnames=('num', 'endpoint', 'retstep', 'dtype', 'axis', 'device')) -def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, - endpoint: bool = True, retstep: bool = False, - dtype: DTypeLike | None = None, - axis: int = 0, - *, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]: - """Implementation of linspace differentiable in start and stop args.""" - dtypes.check_user_dtype_supported(dtype, "linspace") - if num < 0: - raise ValueError(f"Number of samples, {num}, must be non-negative.") - start, stop = util.ensure_arraylike("linspace", start, stop) - - if dtype is None: - dtype = dtypes.to_inexact_dtype(result_type(start, stop)) - dtype = dtypes.jax_dtype(dtype) - computation_dtype = dtypes.to_inexact_dtype(dtype) - start = start.astype(computation_dtype) - stop = stop.astype(computation_dtype) - - bounds_shape = list(lax.broadcast_shapes(np.shape(start), np.shape(stop))) - broadcast_start = broadcast_to(start, bounds_shape) - broadcast_stop = broadcast_to(stop, bounds_shape) - axis = len(bounds_shape) + axis + 1 if axis < 0 else axis - bounds_shape.insert(axis, 1) - div = (num - 1) if endpoint else num - if num > 1: - delta: Array = lax.convert_element_type(stop - start, computation_dtype) / array(div, dtype=computation_dtype) - iota_shape = [1,] * len(bounds_shape) - iota_shape[axis] = div - # This approach recovers the endpoints with float32 arithmetic, - # but can lead to rounding errors for integer outputs. - real_dtype = finfo(computation_dtype).dtype - step = reshape(lax.iota(real_dtype, div), iota_shape) / array(div, real_dtype) - step = step.astype(computation_dtype) - out = (reshape(broadcast_start, bounds_shape) * (1 - step) + - reshape(broadcast_stop, bounds_shape) * step) - - if endpoint: - out = lax.concatenate([out, lax.expand_dims(broadcast_stop, (axis,))], - _canonicalize_axis(axis, out.ndim)) - - elif num == 1: - delta = asarray(np.nan if endpoint else stop - start, dtype=computation_dtype) - out = reshape(broadcast_start, bounds_shape) - else: # num == 0 degenerate case, match numpy behavior - empty_shape = list(lax.broadcast_shapes(np.shape(start), np.shape(stop))) - empty_shape.insert(axis, 0) - delta = asarray(np.nan, dtype=computation_dtype) - out = reshape(array([], dtype=dtype), empty_shape) - - if issubdtype(dtype, np.integer) and not issubdtype(out.dtype, np.integer): - out = lax.floor(out) - - sharding = canonicalize_device_to_sharding(device) - result = lax_internal._convert_element_type(out, dtype, sharding=sharding) - return (result, delta) if retstep else result - - -@export -def logspace(start: ArrayLike, stop: ArrayLike, num: int = 50, - endpoint: bool = True, base: ArrayLike = 10.0, - dtype: DTypeLike | None = None, axis: int = 0) -> Array: - """Generate logarithmically-spaced values. - - JAX implementation of :func:`numpy.logspace`. - - Args: - start: scalar or array. Used to specify the start value. The start value is - ``base ** start``. - stop: scalar or array. Used to specify the stop value. The end value is - ``base ** stop``. - num: int, optional, default=50. Number of values to generate. - endpoint: bool, optional, default=True. If True, then include the ``stop`` value - in the result. If False, then exclude the ``stop`` value. - base: scalar or array, optional, default=10. Specifies the base of the logarithm. - dtype: optional. Specifies the dtype of the output. - axis: int, optional, default=0. Axis along which to generate the logspace. - - Returns: - An array of logarithm. - - See also: - - :func:`jax.numpy.arange`: Generate ``N`` evenly-spaced values given a starting - point and a step value. - - :func:`jax.numpy.linspace`: Generate evenly-spaced values. - - :func:`jax.numpy.geomspace`: Generate geometrically-spaced values. - - Examples: - List 5 logarithmically spaced values between 1 (``10 ** 0``) and 100 - (``10 ** 2``): - - >>> with jnp.printoptions(precision=3, suppress=True): - ... jnp.logspace(0, 2, 5) - Array([ 1. , 3.162, 10. , 31.623, 100. ], dtype=float32) - - List 5 logarithmically-spaced values between 1(``10 ** 0``) and 100 - (``10 ** 2``), excluding endpoint: - - >>> with jnp.printoptions(precision=3, suppress=True): - ... jnp.logspace(0, 2, 5, endpoint=False) - Array([ 1. , 2.512, 6.31 , 15.849, 39.811], dtype=float32) - - List 7 logarithmically-spaced values between 1 (``2 ** 0``) and 4 (``2 ** 2``) - with base 2: - - >>> with jnp.printoptions(precision=3, suppress=True): - ... jnp.logspace(0, 2, 7, base=2) - Array([1. , 1.26 , 1.587, 2. , 2.52 , 3.175, 4. ], dtype=float32) - - Multi-dimensional logspace: - - >>> start = jnp.array([0, 5]) - >>> stop = jnp.array([5, 0]) - >>> base = jnp.array([2, 3]) - >>> with jnp.printoptions(precision=3, suppress=True): - ... jnp.logspace(start, stop, 5, base=base) - Array([[ 1. , 243. ], - [ 2.378, 61.547], - [ 5.657, 15.588], - [ 13.454, 3.948], - [ 32. , 1. ]], dtype=float32) - """ - num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.logspace") - axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.logspace") - return _logspace(start, stop, num, endpoint, base, dtype, axis) - -@partial(jit, static_argnames=('num', 'endpoint', 'dtype', 'axis')) -def _logspace(start: ArrayLike, stop: ArrayLike, num: int = 50, - endpoint: bool = True, base: ArrayLike = 10.0, - dtype: DTypeLike | None = None, axis: int = 0) -> Array: - """Implementation of logspace differentiable in start and stop args.""" - dtypes.check_user_dtype_supported(dtype, "logspace") - if dtype is None: - dtype = dtypes.to_inexact_dtype(result_type(start, stop)) - dtype = dtypes.jax_dtype(dtype) - computation_dtype = dtypes.to_inexact_dtype(dtype) - start, stop = util.ensure_arraylike("logspace", start, stop) - start = start.astype(computation_dtype) - stop = stop.astype(computation_dtype) - lin = linspace(start, stop, num, - endpoint=endpoint, retstep=False, dtype=None, axis=axis) - return lax.convert_element_type(ufuncs.power(base, lin), dtype) - - -@export -def geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, - dtype: DTypeLike | None = None, axis: int = 0) -> Array: - """Generate geometrically-spaced values. - - JAX implementation of :func:`numpy.geomspace`. - - Args: - start: scalar or array. Specifies the starting values. - stop: scalar or array. Specifies the stop values. - num: int, optional, default=50. Number of values to generate. - endpoint: bool, optional, default=True. If True, then include the ``stop`` value - in the result. If False, then exclude the ``stop`` value. - dtype: optional. Specifies the dtype of the output. - axis: int, optional, default=0. Axis along which to generate the geomspace. - - Returns: - An array containing the geometrically-spaced values. - - See also: - - :func:`jax.numpy.arange`: Generate ``N`` evenly-spaced values given a starting - point and a step value. - - :func:`jax.numpy.linspace`: Generate evenly-spaced values. - - :func:`jax.numpy.logspace`: Generate logarithmically-spaced values. - - Examples: - List 5 geometrically-spaced values between 1 and 16: - - >>> with jnp.printoptions(precision=3, suppress=True): - ... jnp.geomspace(1, 16, 5) - Array([ 1., 2., 4., 8., 16.], dtype=float32) - - List 4 geomtrically-spaced values between 1 and 16, with ``endpoint=False``: - - >>> with jnp.printoptions(precision=3, suppress=True): - ... jnp.geomspace(1, 16, 4, endpoint=False) - Array([1., 2., 4., 8.], dtype=float32) - - Multi-dimensional geomspace: - - >>> start = jnp.array([1, 1000]) - >>> stop = jnp.array([27, 1]) - >>> with jnp.printoptions(precision=3, suppress=True): - ... jnp.geomspace(start, stop, 4) - Array([[ 1., 1000.], - [ 3., 100.], - [ 9., 10.], - [ 27., 1.]], dtype=float32) - """ - num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.geomspace") - axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.geomspace") - return _geomspace(start, stop, num, endpoint, dtype, axis) - -@partial(jit, static_argnames=('num', 'endpoint', 'dtype', 'axis')) -def _geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, - dtype: DTypeLike | None = None, axis: int = 0) -> Array: - """Implementation of geomspace differentiable in start and stop args.""" - dtypes.check_user_dtype_supported(dtype, "geomspace") - if dtype is None: - dtype = dtypes.to_inexact_dtype(result_type(start, stop)) - dtype = dtypes.jax_dtype(dtype) - computation_dtype = dtypes.to_inexact_dtype(dtype) - start, stop = util.ensure_arraylike("geomspace", start, stop) - start = start.astype(computation_dtype) - stop = stop.astype(computation_dtype) - - sign = ufuncs.sign(start) - res = sign * logspace(ufuncs.log10(start / sign), ufuncs.log10(stop / sign), - num, endpoint=endpoint, base=10.0, - dtype=computation_dtype, axis=0) - if axis != 0: - res = moveaxis(res, 0, axis) - return lax.convert_element_type(res, dtype) - @export def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False, @@ -6766,7 +6131,7 @@ def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False, @export -@jit +@api.jit def i0(x: ArrayLike) -> Array: r"""Calculate modified Bessel function of first kind, zeroth order. @@ -6801,18 +6166,18 @@ def i0(x: ArrayLike) -> Array: """ x_arr, = util.promote_args_inexact("i0", x) if not issubdtype(x_arr.dtype, np.floating): - raise ValueError(f"Unsupported input type to jax.numpy.i0: {_dtype(x)}") + raise ValueError(f"Unsupported input type to jax.numpy.i0: {x_arr.dtype}") return _i0(x_arr) @custom_jvp def _i0(x): abs_x = lax.abs(x) - return lax.mul(lax.exp(abs_x), lax.bessel_i0e(abs_x)) + return lax.mul(lax.exp(abs_x), lax_special.bessel_i0e(abs_x)) @_i0.defjvp def _i0_jvp(primals, tangents): - primal_out, tangent_out = jax.jvp(_i0.fun, primals, tangents) + primal_out, tangent_out = api.jvp(_i0.fun, primals, tangents) return primal_out, where(primals[0] == 0, 0.0, tangent_out) @export @@ -6856,7 +6221,7 @@ def ix_(*args: ArrayLike) -> tuple[Array, ...]: if len(a.shape) != 1: msg = "Arguments to jax.numpy.ix_ must be 1-dimensional, got shape {}" raise ValueError(msg.format(a.shape)) - if _dtype(a) == bool: + if a.dtype == bool: raise NotImplementedError( "Boolean arguments to jax.numpy.ix_ are not implemented") shape = [1] * n @@ -6911,8 +6276,8 @@ def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None, (Array([[0], [1]], dtype=int32), Array([[0, 1, 2]], dtype=int32)) """ - dtypes.check_user_dtype_supported(dtype, "indices") - dtype = dtype or dtypes.canonicalize_dtype(dtypes.int_) + dtype = dtypes.check_and_canonicalize_user_dtype( + int if dtype is None else dtype, "indices") dimensions = tuple( core.concrete_or_error(operator.index, d, "dimensions argument of jnp.indices") for d in dimensions) @@ -6931,7 +6296,8 @@ def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None, @export def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *, - total_repeat_length: int | None = None) -> Array: + total_repeat_length: int | None = None, + out_sharding: NamedSharding | P | None = None) -> Array: """Construct an array from repeated elements. JAX implementation of :func:`numpy.repeat`. @@ -6995,8 +6361,49 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *, Array([[1, 1, 2, 2, 2, 2, 2], [3, 3, 4, 4, 4, 4, 4]], dtype=int32) """ - arr = util.ensure_arraylike("repeat", a) - core.is_dim(repeats) or util.check_arraylike("repeat", repeats) + if out_sharding is not None: + return _auto_repeat(_repeat, a, repeats, axis, total_repeat_length, + out_sharding) + ctx_mesh = get_abstract_mesh() + if ctx_mesh._any_axis_explicit: + aval = core.typeof(a) + if axis is None or aval.sharding.spec[axis] is not None: + raise ValueError( + "Please pass sharding to `jnp.repeat` via `out_sharding` parameter.") + assert axis is not None and aval.sharding.spec[axis] is None + out_sharding = (NamedSharding(ctx_mesh, P()) + if aval.sharding.mesh.empty else aval.sharding) + return _auto_repeat(_repeat, a, repeats, axis, total_repeat_length, + out_sharding) + try: + return _repeat(repeats, a, axis=axis, + total_repeat_length=total_repeat_length) + except core.ShardingTypeError as e: + raise ValueError( + "Please pass sharding to `jnp.repeat` via `out_sharding` parameter.") + +def _auto_repeat(fun, a, repeats, axis, total_repeat_length, out_sharding): + out_sharding = canonicalize_sharding(out_sharding, 'repeat') + if total_repeat_length is None: + return auto_axes(partial(fun, repeats, axis=axis, + total_repeat_length=total_repeat_length), + out_sharding=out_sharding, + axes=out_sharding.mesh.explicit_axes # type: ignore + )(a) + else: + return auto_axes( + partial(fun, axis=axis, total_repeat_length=total_repeat_length), + out_sharding=out_sharding, + axes=out_sharding.mesh.explicit_axes # type: ignore + )(repeats, a) + +def _repeat(repeats: ArrayLike, a: ArrayLike, *, axis: int | None = None, + total_repeat_length: int | None = None) -> Array: + if core.is_dim(repeats): + util.check_arraylike("repeat", a) + else: + util.check_arraylike("repeat", a, repeats) + arr = asarray(a) if axis is None: arr = arr.ravel() @@ -7042,7 +6449,7 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *, # Special case when a is a scalar. if arr.ndim == 0: if np.shape(repeats) == (1,): - return full([total_repeat_length], arr) + return array_creation.full([total_repeat_length], arr) else: raise ValueError('`repeat` with a scalar parameter `a` is only ' 'implemented for scalar values of the parameter `repeats`.') @@ -7065,7 +6472,7 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *, # Cumsum to get indices of new number in repeated tensor, e.g. [0, 1, 3, 3] scatter_indices = reductions.cumsum(exclusive_repeats) # Scatter these onto a zero buffer, e.g. [1,1,0,2,0,0,0,0] - block_split_indicators = zeros([total_repeat_length], dtype='int32') + block_split_indicators = array_creation.zeros([total_repeat_length], dtype='int32') block_split_indicators = block_split_indicators.at[scatter_indices].add(1) # Cumsum again to get scatter indices for repeat, e.g. [0,1,1,3,3,3,3,3] gather_indices = reductions.cumsum(block_split_indicators) - 1 @@ -7073,7 +6480,7 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *, @export -@partial(jit, static_argnames=('axis',)) +@api.jit(static_argnames=('axis',)) def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0, axis: int = -1) -> Array: r""" @@ -7118,11 +6525,11 @@ def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0, # TODO(phawkins): remove this annotation after fixing jnp types. dx_array: Array if x is None: - util.check_arraylike('trapezoid', y) + y = util.ensure_arraylike('trapezoid', y) y_arr, = util.promote_dtypes_inexact(y) dx_array = asarray(dx) else: - util.check_arraylike('trapezoid', y, x) + y, x = util.ensure_arraylike('trapezoid', y, x) y_arr, x_arr = util.promote_dtypes_inexact(y, x) if x_arr.ndim == 1: dx_array = diff(x_arr) @@ -7183,14 +6590,17 @@ def tri(N: int, M: int | None = None, k: int = 0, dtype: DTypeLike | None = None [1., 0., 0., 0.], [1., 1., 0., 0.]], dtype=float32) """ - dtypes.check_user_dtype_supported(dtype, "tri") + if dtype is None: + # TODO(phawkins): this is a strange default. + dtype = np.dtype(np.float32) + else: + dtype = dtypes.check_and_canonicalize_user_dtype(dtype, "tri") M = M if M is not None else N - dtype = dtype or np.dtype('float32') - return lax_internal._tri(dtype, (N, M), k) + return lax._tri(dtype, (N, M), k) @export -@partial(jit, static_argnames=('k',)) +@api.jit(static_argnames=('k',)) def tril(m: ArrayLike, k: int = 0) -> Array: r"""Return lower triangle of an array. @@ -7243,17 +6653,17 @@ def tril(m: ArrayLike, k: int = 0) -> Array: [[5, 0], [7, 8]]], dtype=int32) """ - util.check_arraylike("tril", m) + m = util.ensure_arraylike("tril", m) m_shape = np.shape(m) if len(m_shape) < 2: raise ValueError("Argument to jax.numpy.tril must be at least 2D") N, M = m_shape[-2:] mask = tri(N, M, k=k, dtype=bool) - return lax.select(lax.broadcast(mask, m_shape[:-2]), m, zeros_like(m)) + return lax.select(lax.broadcast(mask, m_shape[:-2]), m, array_creation.zeros_like(m)) @export -@partial(jit, static_argnames=('k',)) +@api.jit(static_argnames=('k',)) def triu(m: ArrayLike, k: int = 0) -> Array: r"""Return upper triangle of an array. @@ -7310,17 +6720,17 @@ def triu(m: ArrayLike, k: int = 0) -> Array: [[5, 6], [0, 8]]], dtype=int32) """ - util.check_arraylike("triu", m) + m = util.ensure_arraylike("triu", m) m_shape = np.shape(m) if len(m_shape) < 2: raise ValueError("Argument to jax.numpy.triu must be at least 2D") N, M = m_shape[-2:] mask = tri(N, M, k=k - 1, dtype=bool) - return lax.select(lax.broadcast(mask, m_shape[:-2]), zeros_like(m), m) + return lax.select(lax.broadcast(mask, m_shape[:-2]), array_creation.zeros_like(m), m) @export -@partial(jit, static_argnames=('axis1', 'axis2', 'dtype')) +@api.jit(static_argnames=('axis1', 'axis2', 'dtype')) def trace(a: ArrayLike, offset: int | ArrayLike = 0, axis1: int = 0, axis2: int = 1, dtype: DTypeLike | None = None, out: None = None) -> Array: """Calculate sum of the diagonal of input along the given axes. @@ -7367,21 +6777,22 @@ def trace(a: ArrayLike, offset: int | ArrayLike = 0, axis1: int = 0, axis2: int >>> jnp.trace(x, offset=1, axis1=1, axis2=2) Array([2, 6], dtype=int32) """ - util.check_arraylike("trace", a) + a = util.ensure_arraylike("trace", a) if out is not None: raise NotImplementedError("The 'out' argument to jnp.trace is not supported.") if _canonicalize_axis(axis1, np.ndim(a)) == _canonicalize_axis(axis2, np.ndim(a)): raise ValueError(f"axis1 and axis2 can not be same. axis1={axis1} and axis2={axis2}") - dtypes.check_user_dtype_supported(dtype, "trace") + if dtype is not None: + dtype = dtypes.check_and_canonicalize_user_dtype(dtype, "trace") a_shape = np.shape(a) a = moveaxis(a, (axis1, axis2), (-2, -1)) # Mask out the diagonal and reduce. a = where(eye(a_shape[axis1], a_shape[axis2], k=offset, dtype=bool), - a, zeros_like(a)) + a, array_creation.zeros_like(a)) return reductions.sum(a, axis=(-2, -1), dtype=dtype) @@ -7431,7 +6842,7 @@ def mask_indices(n: int, >>> jnp.mask_indices(3, mask_func) (Array([0, 1, 1, 2, 2], dtype=int32), Array([0, 0, 1, 0, 2], dtype=int32)) """ - i, j = nonzero(mask_func(ones((n, n)), k), size=size) + i, j = nonzero(mask_func(array_creation.ones((n, n)), k), size=size) return (i, j) @@ -7441,12 +6852,12 @@ def _triu_size(n, m, k): elif k >= m: return 0 else: - mk = min(n, m - k) + mk = core.min_dim(n, m - k) return mk * (mk + 1) // 2 + mk * (m - k - mk) @export -def triu_indices(n: int, k: int = 0, m: int | None = None) -> tuple[Array, Array]: +def triu_indices(n: DimSize, k: DimSize = 0, m: DimSize | None = None) -> tuple[Array, Array]: """Return the indices of upper triangle of an array of size ``(n, m)``. JAX implementation of :func:`numpy.triu_indices`. @@ -7497,15 +6908,15 @@ def triu_indices(n: int, k: int = 0, m: int | None = None) -> tuple[Array, Array >>> jnp.triu_indices(3, k=-1) (Array([0, 0, 0, 1, 1, 1, 2, 2], dtype=int32), Array([0, 1, 2, 0, 1, 2, 1, 2], dtype=int32)) """ - n = core.concrete_or_error(operator.index, n, "n argument of jnp.triu_indices") - k = core.concrete_or_error(operator.index, k, "k argument of jnp.triu_indices") - m = n if m is None else core.concrete_or_error(operator.index, m, "m argument of jnp.triu_indices") - i, j = nonzero(triu(ones((n, m)), k=k), size=_triu_size(n, m, k)) + n = core.concrete_dim_or_error(n, "n argument of jnp.triu_indices") + k = core.concrete_dim_or_error(k, "k argument of jnp.triu_indices") + m = n if m is None else core.concrete_dim_or_error(m, "m argument of jnp.triu_indices") + i, j = nonzero(triu(array_creation.ones((n, m)), k=k), size=_triu_size(n, m, k)) return i, j @export -def tril_indices(n: int, k: int = 0, m: int | None = None) -> tuple[Array, Array]: +def tril_indices(n: DimSize, k: DimSize = 0, m: DimSize | None = None) -> tuple[Array, Array]: """Return the indices of lower triangle of an array of size ``(n, m)``. JAX implementation of :func:`numpy.tril_indices`. @@ -7556,15 +6967,15 @@ def tril_indices(n: int, k: int = 0, m: int | None = None) -> tuple[Array, Array >>> jnp.tril_indices(3, k=-1) (Array([1, 2, 2], dtype=int32), Array([0, 0, 1], dtype=int32)) """ - n = core.concrete_or_error(operator.index, n, "n argument of jnp.triu_indices") - k = core.concrete_or_error(operator.index, k, "k argument of jnp.triu_indices") - m = n if m is None else core.concrete_or_error(operator.index, m, "m argument of jnp.triu_indices") - i, j = nonzero(tril(ones((n, m)), k=k), size=_triu_size(m, n, -k)) + n = core.concrete_dim_or_error(n, "n argument of jnp.triu_indices") + k = core.concrete_dim_or_error(k, "k argument of jnp.triu_indices") + m = n if m is None else core.concrete_dim_or_error(m, "m argument of jnp.triu_indices") + i, j = nonzero(tril(array_creation.ones((n, m)), k=k), size=_triu_size(m, n, -k)) return i, j @export -def triu_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]: +def triu_indices_from(arr: ArrayLike | SupportsShape, k: int = 0) -> tuple[Array, Array]: """Return the indices of upper triangle of a given array. JAX implementation of :func:`numpy.triu_indices_from`. @@ -7615,14 +7026,18 @@ def triu_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]: >>> jnp.triu_indices_from(arr, k=-1) (Array([0, 0, 0, 1, 1, 1, 2, 2], dtype=int32), Array([0, 1, 2, 0, 1, 2, 1, 2], dtype=int32)) """ - arr_shape = np.shape(arr) + if hasattr(arr, "shape"): + arr_shape = arr.shape + else: + arr = util.ensure_arraylike("triu_indices_from", arr) + arr_shape = arr.shape if len(arr_shape) != 2: raise ValueError("Only 2-D inputs are accepted") return triu_indices(arr_shape[0], k=k, m=arr_shape[1]) @export -def tril_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]: +def tril_indices_from(arr: ArrayLike | SupportsShape, k: int = 0) -> tuple[Array, Array]: """Return the indices of lower triangle of a given array. JAX implementation of :func:`numpy.tril_indices_from`. @@ -7673,7 +7088,11 @@ def tril_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]: >>> jnp.tril_indices_from(arr, k=-1) (Array([1, 2, 2], dtype=int32), Array([0, 0, 1], dtype=int32)) """ - arr_shape = np.shape(arr) + if hasattr(arr, "shape"): + arr_shape = arr.shape + else: + arr = util.ensure_arraylike("tril_indices_from", arr) + arr_shape = arr.shape if len(arr_shape) != 2: raise ValueError("Only 2-D inputs are accepted") return tril_indices(arr_shape[0], k=k, m=arr_shape[1]) @@ -7696,6 +7115,7 @@ def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap: bool = False, *, dimensions must be the same size. val: scalar or array with which to fill the diagonal. If an array, it will be flattened and repeated to fill the diagonal entries. + wrap: Not implemented by JAX. Only the default value of ``False`` is supported. inplace: must be set to False to indicate that the input is not modified in-place, but rather a modified copy is returned. @@ -7793,7 +7213,12 @@ def diag_indices(n: int, ndim: int = 2) -> tuple[Array, ...]: if ndim < 0: raise ValueError("ndim argument to diag_indices must be nonnegative, got {}" .format(ndim)) - return (lax.iota(dtypes.int_, n),) * ndim + index_dtype = lax_utils.int_dtype_for_dim(n, signed=True) + # We'd give the correct output values with int32, but use the default dtype to + # match NumPy type semantics if x64 mode is enabled for now. + if index_dtype == np.dtype(np.int32): + index_dtype = dtypes.default_int_dtype() + return (lax.iota(index_dtype, n),) * ndim @export @@ -7827,7 +7252,7 @@ def diag_indices_from(arr: ArrayLike) -> tuple[Array, ...]: Array([0, 1], dtype=int32), Array([0, 1], dtype=int32)) """ - util.check_arraylike("diag_indices_from", arr) + arr = util.ensure_arraylike("diag_indices_from", arr) nd = np.ndim(arr) if not np.ndim(arr) >= 2: raise ValueError("input array must be at least 2-d") @@ -7840,7 +7265,7 @@ def diag_indices_from(arr: ArrayLike) -> tuple[Array, ...]: @export -@partial(jit, static_argnames=('offset', 'axis1', 'axis2')) +@api.jit(static_argnames=('offset', 'axis1', 'axis2')) def diagonal(a: ArrayLike, offset: int = 0, axis1: int = 0, axis2: int = 1) -> Array: """Returns the specified diagonal of an array. @@ -7876,7 +7301,7 @@ def diagonal(a: ArrayLike, offset: int = 0, axis1: int = 0, >>> jnp.diagonal(x, offset=-1) Array([4, 8], dtype=int32) """ - util.check_arraylike("diagonal", a) + a = util.ensure_arraylike("diagonal", a) if np.ndim(a) < 2: raise ValueError("diagonal requires an array of at least two dimensions.") @@ -7897,21 +7322,21 @@ def _default_diag(a): # The mosaic lowering rule for diag is only defined for square arrays. # TODO(mvoz): Add support for offsets. - if np.shape(a)[0] != np.shape(a)[1] or np.ndim(a) != 2 or offset != 0 or _dtype(a) == bool: + if np.shape(a)[0] != np.shape(a)[1] or np.ndim(a) != 2 or offset != 0 or a.dtype == bool: return _default_diag(a) else: - a_shape_eye = eye(np.shape(a)[0], dtype=_dtype(a)) + a_shape_eye = eye(np.shape(a)[0], dtype=a.dtype) def _mosaic_diag(a): def _sum(x, axis): return lax.reduce( x, - np.array(0, _dtype(x)), - lax.add if _dtype(x) != bool else lax.bitwise_or, + np.array(0, x.dtype), + lax.add if x.dtype != bool else lax.bitwise_or, (axis,), ) return _sum(lax.mul(a_shape_eye, a), axis=0) - return lax.platform_dependent(a, default=_default_diag, mosaic=_mosaic_diag) + return control_flow.platform_dependent(a, default=_default_diag, mosaic=_mosaic_diag) @export @@ -7962,17 +7387,17 @@ def diag(v: ArrayLike, k: int = 0) -> Array: >>> jnp.diag(x) Array([1, 5, 9], dtype=int32) """ + v = util.ensure_arraylike("diag", v) return _diag(v, operator.index(k)) -@partial(jit, static_argnames=('k',)) -def _diag(v, k): - util.check_arraylike("diag", v) +@api.jit(static_argnames=('k',)) +def _diag(v: Array, k: int): v_shape = np.shape(v) if len(v_shape) == 1: zero = lambda x: lax.full_like(x, shape=(), fill_value=0) n = v_shape[0] + abs(k) v = lax.pad(v, zero(v), ((max(0, k), max(0, -k), 0),)) - return where(eye(n, k=k, dtype=bool), v, zeros_like(v)) + return where(eye(n, k=k, dtype=bool), v, array_creation.zeros_like(v)) elif len(v_shape) == 2: return diagonal(v, offset=k) else: @@ -8024,7 +7449,7 @@ def diagflat(v: ArrayLike, k: int = 0) -> Array: v_ravel = ravel(v) v_length = len(v_ravel) adj_length = v_length + abs(k) - res = zeros(adj_length*adj_length, dtype=v_ravel.dtype) + res = array_creation.zeros(adj_length*adj_length, dtype=v_ravel.dtype) i = arange(0, adj_length-abs(k)) if (k >= 0): fi = i+k+i*adj_length @@ -8037,13 +7462,14 @@ def diagflat(v: ArrayLike, k: int = 0) -> Array: # TODO(jakevdp): add support for N-dimensional inputs as in NumPy v2.2 @export -def trim_zeros(filt: ArrayLike, trim: str ='fb') -> Array: +def trim_zeros(filt: ArrayLike, trim: str ='fb', + axis: int | Sequence[int] | None = None) -> Array: """Trim leading and/or trailing zeros of the input array. JAX implementation of :func:`numpy.trim_zeros`. Args: - filt: input array. Must have ``filt.ndim == 1``. + filt: N-dimensional input array. trim: string, optional, default = ``fb``. Specifies from which end the input is trimmed. @@ -8051,34 +7477,63 @@ def trim_zeros(filt: ArrayLike, trim: str ='fb') -> Array: - ``b`` - trims only the trailing zeros. - ``fb`` - trims both leading and trailing zeros. + axis: optional axis or axes along which to trim. If not specified, trim along + all axes of the array. + Returns: An array containing the trimmed input with same dtype as ``filt``. Examples: + One-dimensional input: + >>> x = jnp.array([0, 0, 2, 0, 1, 4, 3, 0, 0, 0]) >>> jnp.trim_zeros(x) Array([2, 0, 1, 4, 3], dtype=int32) + >>> jnp.trim_zeros(x, trim='f') + Array([2, 0, 1, 4, 3, 0, 0, 0], dtype=int32) + >>> jnp.trim_zeros(x, trim='b') + Array([0, 0, 2, 0, 1, 4, 3], dtype=int32) + + Two-dimensional input: + + >>> x = jnp.zeros((4, 5)).at[1:3, 1:4].set(1) + >>> x + Array([[0., 0., 0., 0., 0.], + [0., 1., 1., 1., 0.], + [0., 1., 1., 1., 0.], + [0., 0., 0., 0., 0.]], dtype=float32) + >>> jnp.trim_zeros(x) + Array([[1., 1., 1.], + [1., 1., 1.]], dtype=float32) + >>> jnp.trim_zeros(x, trim='f') + Array([[1., 1., 1., 0.], + [1., 1., 1., 0.], + [0., 0., 0., 0.]], dtype=float32) + >>> jnp.trim_zeros(x, axis=0) + Array([[0., 1., 1., 1., 0.], + [0., 1., 1., 1., 0.]], dtype=float32) + >>> jnp.trim_zeros(x, axis=1) + Array([[0., 0., 0.], + [1., 1., 1.], + [1., 1., 1.], + [0., 0., 0.]], dtype=float32) """ - # Non-array inputs are deprecated 2024-09-11 - util.check_arraylike("trim_zeros", filt, emit_warning=True) + filt = util.ensure_arraylike("trim_zeros", filt) core.concrete_or_error(None, filt, "Error arose in the `filt` argument of trim_zeros()") - filt_arr = jax.numpy.asarray(filt) - del filt - if filt_arr.ndim != 1: - # Added on 2024-09-11 - if deprecations.is_accelerated("jax-numpy-trimzeros-not-1d-array"): - raise TypeError(f"'filt' must be 1-D array, but received {filt_arr.ndim}-D array.") - warnings.warn( - "Passing arrays with ndim != 1 to jnp.trim_zeros() is deprecated. Currently, it " - "works with Arrays having ndim != 1. In the future this will result in an error.", - DeprecationWarning, stacklevel=2) - nz = (filt_arr == 0) - if reductions.all(nz): - return empty(0, filt_arr.dtype) - start: Array | int = argmin(nz) if 'f' in trim.lower() else 0 - end: Array | int = argmin(nz[::-1]) if 'b' in trim.lower() else 0 - return filt_arr[start:len(filt_arr) - end] + axis_set = set(_canonicalize_axis_tuple(axis, filt.ndim)) + if not axis_set or ('f' not in trim.lower() and 'b' not in trim.lower()): + return filt + def _get_slice(x: Array, ax: int) -> slice: + if ax not in axis_set: + return slice(None) + mask = x.any(axis=[i for i in range(x.ndim) if i != ax]) + if not mask.any(): + return slice(0, 0) + start = int(mask.argmax()) if 'f' in trim.lower() else None + stop = x.shape[ax] - int(mask[::-1].argmax()) if 'b' in trim.lower() else None + return slice(start, stop) + return filt[*(_get_slice(filt, ax) for ax in range(filt.ndim))] def trim_zeros_tol(filt, tol, trim='fb'): @@ -8086,14 +7541,14 @@ def trim_zeros_tol(filt, tol, trim='fb'): "Error arose in the `filt` argument of trim_zeros_tol()") nz = (ufuncs.abs(filt) < tol) if reductions.all(nz): - return empty(0, _dtype(filt)) + return array_creation.empty(0, _dtype(filt)) start = argmin(nz) if 'f' in trim.lower() else 0 end = argmin(nz[::-1]) if 'b' in trim.lower() else 0 return filt[start:len(filt) - end] @export -@partial(jit, static_argnames=('axis',)) +@api.jit(static_argnames=('axis',)) def append( arr: ArrayLike, values: ArrayLike, axis: int | None = None ) -> Array: @@ -8243,15 +7698,18 @@ def delete( # Case 3: obj is an array # NB: pass both arrays to check for appropriate error message. util.check_arraylike("delete", a, obj) + # Can't use ensure_arraylike here because obj may be static. + if hasattr(obj, "__jax_array__"): + obj = obj.__jax_array__() # Case 3a: unique integer indices; delete in a JIT-compatible way if issubdtype(_dtype(obj), np.integer) and assume_unique_indices: obj = asarray(obj).ravel() obj = clip(where(obj < 0, obj + a.shape[axis], obj), 0, a.shape[axis]) obj = sort(obj) - obj -= arange(len(obj)) # type: ignore[arg-type,operator] - i = arange(a.shape[axis] - obj.size) - i += (i[None, :] >= obj[:, None]).sum(0) + obj -= arange(len(obj), dtype=obj.dtype) # type: ignore + i = arange(a.shape[axis] - obj.size, dtype=obj.dtype) + i += (i[None, :] >= obj[:, None]).sum(0, dtype=i.dtype) return a[(slice(None),) * axis + (i,)] # Case 3b: non-unique indices: must be static. @@ -8349,18 +7807,18 @@ def insert(arr: ArrayLike, obj: ArrayLike | slice, values: ArrayLike, index = ravel(indices)[0] if indices.ndim == 0: values_arr = moveaxis(values_arr, 0, axis) - indices = full(values_arr.shape[axis], index) + indices = array_creation.full(values_arr.shape[axis], index) n_input = a.shape[axis] n_insert = broadcast_shapes(indices.shape, (values_arr.shape[axis],))[0] out_shape = list(a.shape) out_shape[axis] += n_insert - out = zeros_like(a, shape=tuple(out_shape)) + out = array_creation.zeros_like(a, shape=tuple(out_shape)) indices = where(indices < 0, indices + n_input, indices) indices = clip(indices, 0, n_input) values_ind = indices.at[argsort(indices)].add(arange(n_insert, dtype=indices.dtype)) - arr_mask = ones(n_input + n_insert, dtype=bool).at[values_ind].set(False) + arr_mask = array_creation.ones(n_input + n_insert, dtype=bool).at[values_ind].set(False) arr_ind = where(arr_mask, size=n_input)[0] out = out.at[(slice(None),) * axis + (values_ind,)].set(values_arr) @@ -8441,9 +7899,9 @@ def apply_along_axis( axis = _canonicalize_axis(axis, num_dims) func = lambda arr: func1d(arr, *args, **kwargs) for i in range(1, num_dims - axis): - func = jax.vmap(func, in_axes=i, out_axes=-1) + func = api.vmap(func, in_axes=i, out_axes=-1) for i in range(axis): - func = jax.vmap(func, in_axes=0, out_axes=0) + func = api.vmap(func, in_axes=0, out_axes=0) return func(arr) @@ -8504,7 +7962,7 @@ def apply_over_axes(func: Callable[[ArrayLike, int], Array], a: ArrayLike, @export -@partial(jit, static_argnames=('axisa', 'axisb', 'axisc', 'axis')) +@api.jit(static_argnames=('axisa', 'axisb', 'axisc', 'axis')) def cross(a, b, axisa: int = -1, axisb: int = -1, axisc: int = -1, axis: int | None = None): r"""Compute the (batched) cross product of two arrays. @@ -8596,16 +8054,16 @@ def cross(a, b, axisa: int = -1, axisb: int = -1, axisc: int = -1, a0 = a[..., 0] a1 = a[..., 1] - a2 = a[..., 2] if a.shape[-1] == 3 else zeros_like(a0) + a2 = a[..., 2] if a.shape[-1] == 3 else array_creation.zeros_like(a0) b0 = b[..., 0] b1 = b[..., 1] - b2 = b[..., 2] if b.shape[-1] == 3 else zeros_like(b0) + b2 = b[..., 2] if b.shape[-1] == 3 else array_creation.zeros_like(b0) c = array([a1 * b2 - a2 * b1, a2 * b0 - a0 * b2, a0 * b1 - a1 * b0]) return moveaxis(c, 0, axisc) @export -@jit +@api.jit def kron(a: ArrayLike, b: ArrayLike) -> Array: """Compute the Kronecker product of two input arrays. @@ -8651,7 +8109,7 @@ def kron(a: ArrayLike, b: ArrayLike) -> Array: @export -@partial(jit, static_argnames=('N', 'increasing')) +@api.jit(static_argnames=('N', 'increasing')) def vander( x: ArrayLike, N: int | None = None, increasing: bool = False ) -> Array: @@ -8687,7 +8145,7 @@ def vander( [3, 1], [4, 1]], dtype=int32) - Generates the Vandermonde matrix in increaing order of powers, when + Generates the Vandermonde matrix in increasing order of powers, when ``increasing=True``. >>> jnp.vander(x, increasing=True) @@ -8706,7 +8164,7 @@ def vander( iota = lax.iota(x.dtype, N) if not increasing: - iota = lax.sub(_lax_const(iota, N - 1), iota) + iota = lax.sub(lax._const(iota, N - 1), iota) return ufuncs.power(x[..., None], expand_dims(iota, tuple(range(x.ndim)))) @@ -8773,6 +8231,7 @@ def argwhere( >>> jnp.argwhere(0) Array([], shape=(0, 0), dtype=int32) """ + a = util.ensure_arraylike("argwhere", a) result = transpose(vstack(nonzero(atleast_1d(a), size=size, fill_value=fill_value))) if np.ndim(a) == 0: return result[:0].reshape(result.shape[0], 0) @@ -8801,6 +8260,10 @@ def argmax(a: ArrayLike, axis: int | None = None, out: None = None, - :func:`jax.numpy.argmin`: return the index of the minimum value. - :func:`jax.numpy.nanargmax`: compute ``argmax`` while ignoring NaN values. + Note: + When the maximum value occurs more than once along a particular axis, the + smallest index is returned. + Examples: >>> x = jnp.array([1, 3, 5, 4, 2]) >>> jnp.argmax(x) @@ -8821,7 +8284,7 @@ def argmax(a: ArrayLike, axis: int | None = None, out: None = None, return _argmax(arr, None if axis is None else operator.index(axis), keepdims=bool(keepdims)) -@partial(jit, static_argnames=('axis', 'keepdims'), inline=True) +@api.jit(static_argnames=('axis', 'keepdims'), inline=True) def _argmax(a: Array, axis: int | None = None, keepdims: bool = False) -> Array: if axis is None: dims = list(range(np.ndim(a))) @@ -8831,7 +8294,8 @@ def _argmax(a: Array, axis: int | None = None, keepdims: bool = False) -> Array: dims = [axis] if a.shape[axis] == 0: raise ValueError("attempt to get argmax of an empty sequence") - result = lax.argmax(a, _canonicalize_axis(axis, a.ndim), dtypes.canonicalize_dtype(dtypes.int_)) + # TODO(phawkins): use an int64 index if the dimension is large enough. + result = lax.argmax(a, _canonicalize_axis(axis, a.ndim), int) return expand_dims(result, dims) if keepdims else result @@ -8853,6 +8317,10 @@ def argmin(a: ArrayLike, axis: int | None = None, out: None = None, Returns: an array containing the index of the minimum value along the specified axis. + Note: + When the minimum value occurs more than once along a particular axis, the + smallest index is returned. + See also: - :func:`jax.numpy.argmax`: return the index of the maximum value. - :func:`jax.numpy.nanargmin`: compute ``argmin`` while ignoring NaN values. @@ -8877,7 +8345,7 @@ def argmin(a: ArrayLike, axis: int | None = None, out: None = None, return _argmin(arr, None if axis is None else operator.index(axis), keepdims=bool(keepdims)) -@partial(jit, static_argnames=('axis', 'keepdims'), inline=True) +@api.jit(static_argnames=('axis', 'keepdims'), inline=True) def _argmin(a: Array, axis: int | None = None, keepdims: bool = False) -> Array: if axis is None: dims = list(range(np.ndim(a))) @@ -8887,7 +8355,8 @@ def _argmin(a: Array, axis: int | None = None, keepdims: bool = False) -> Array: dims = [axis] if a.shape[axis] == 0: raise ValueError("attempt to get argmin of an empty sequence") - result = lax.argmin(a, _canonicalize_axis(axis, a.ndim), dtypes.canonicalize_dtype(dtypes.int_)) + # TODO(phawkins): use an int64 index if the dimension is large enough. + result = lax.argmin(a, _canonicalize_axis(axis, a.ndim), int) return expand_dims(result, dims) if keepdims else result @@ -8945,13 +8414,13 @@ def nanargmax( """ if out is not None: raise NotImplementedError("The 'out' argument to jnp.nanargmax is not supported.") + a = util.ensure_arraylike("nanargmax", a) return _nanargmax(a, None if axis is None else operator.index(axis), keepdims=bool(keepdims)) -@partial(jit, static_argnames=('axis', 'keepdims')) -def _nanargmax(a, axis: int | None = None, keepdims: bool = False): - util.check_arraylike("nanargmax", a) - if not issubdtype(_dtype(a), np.inexact): +@api.jit(static_argnames=('axis', 'keepdims')) +def _nanargmax(a: Array, axis: int | None = None, keepdims: bool = False): + if not issubdtype(a.dtype, np.inexact): return argmax(a, axis=axis, keepdims=keepdims) nan_mask = ufuncs.isnan(a) a = where(nan_mask, -np.inf, a) @@ -9006,13 +8475,13 @@ def nanargmin( """ if out is not None: raise NotImplementedError("The 'out' argument to jnp.nanargmin is not supported.") + a = util.ensure_arraylike("nanargmin", a) return _nanargmin(a, None if axis is None else operator.index(axis), keepdims=bool(keepdims)) -@partial(jit, static_argnames=('axis', 'keepdims')) -def _nanargmin(a, axis: int | None = None, keepdims : bool = False): - util.check_arraylike("nanargmin", a) - if not issubdtype(_dtype(a), np.inexact): +@api.jit(static_argnames=('axis', 'keepdims')) +def _nanargmin(a: Array, axis: int | None = None, keepdims : bool = False): + if not issubdtype(a.dtype, np.inexact): return argmin(a, axis=axis, keepdims=keepdims) nan_mask = ufuncs.isnan(a) a = where(nan_mask, np.inf, a) @@ -9020,7 +8489,7 @@ def _nanargmin(a, axis: int | None = None, keepdims : bool = False): return where(reductions.all(nan_mask, axis=axis, keepdims=keepdims), -1, res) -@partial(jit, static_argnums=(2,)) +@api.jit(static_argnums=(2,)) def _roll_dynamic(a: Array, shift: Array, axis: Sequence[int]) -> Array: b_shape = lax.broadcast_shapes(shift.shape, np.shape(axis)) if len(b_shape) != 1: @@ -9033,17 +8502,17 @@ def _roll_dynamic(a: Array, shift: Array, axis: Sequence[int]) -> Array: x = ufuncs.remainder(lax.convert_element_type(x, np.int32), lax.max(a_shape_i, np.int32(1))) a_concat = lax.concatenate((a, a), i) - a = lax.dynamic_slice_in_dim(a_concat, a_shape_i - x, a.shape[i], axis=i) + a = lax_slicing.dynamic_slice_in_dim(a_concat, a_shape_i - x, a.shape[i], axis=i) return a -@partial(jit, static_argnums=(1, 2)) +@api.jit(static_argnums=(1, 2)) def _roll_static(a: Array, shift: Sequence[int], axis: Sequence[int]) -> Array: for ax, s in zip(*np.broadcast_arrays(axis, shift)): if a.shape[ax] == 0: continue i = (-s) % a.shape[ax] - a = lax.concatenate([lax.slice_in_dim(a, i, a.shape[ax], axis=ax), - lax.slice_in_dim(a, 0, i, axis=ax)], + a = lax.concatenate([lax_slicing.slice_in_dim(a, i, a.shape[ax], axis=ax), + lax_slicing.slice_in_dim(a, 0, i, axis=ax)], dimension=ax) return a @@ -9102,7 +8571,7 @@ def roll(a: ArrayLike, shift: ArrayLike | Sequence[int], @export -@partial(jit, static_argnames=('axis', 'start')) +@api.jit(static_argnames=('axis', 'start')) def rollaxis(a: ArrayLike, axis: int, start: int = 0) -> Array: """Roll the specified axis to a given position. @@ -9154,7 +8623,7 @@ def rollaxis(a: ArrayLike, axis: int, start: int = 0) -> Array: >>> jnp.moveaxis(a, 1, -1).shape (2, 4, 5, 3) """ - util.check_arraylike("rollaxis", a) + a = util.ensure_arraylike("rollaxis", a) start = core.concrete_or_error(operator.index, start, "'start' argument of jnp.rollaxis()") a_ndim = np.ndim(a) axis = _canonicalize_axis(axis, a_ndim) @@ -9168,7 +8637,7 @@ def rollaxis(a: ArrayLike, axis: int, start: int = 0) -> Array: @export -@partial(jit, static_argnames=('axis', 'bitorder')) +@api.jit(static_argnames=('axis', 'bitorder')) def packbits(a: ArrayLike, axis: int | None = None, bitorder: str = "big") -> Array: """Pack array of bits into a uint8 array. @@ -9231,7 +8700,7 @@ def packbits(a: ArrayLike, axis: int | None = None, bitorder: str = "big") -> Ar raise TypeError('Expected an input array of integer or boolean data type') if bitorder not in ['little', 'big']: raise ValueError("'order' must be either 'little' or 'big'") - arr = lax.gt(arr, _lax_const(a, 0)).astype('uint8') + arr = lax.ne(arr, lax._const(arr, 0)).astype('uint8') bits = arange(8, dtype='uint8') if bitorder == 'big': bits = bits[::-1] @@ -9252,7 +8721,7 @@ def packbits(a: ArrayLike, axis: int | None = None, bitorder: str = "big") -> Ar @export -@partial(jit, static_argnames=('axis', 'count', 'bitorder')) +@api.jit(static_argnames=('axis', 'count', 'bitorder')) def unpackbits( a: ArrayLike, axis: int | None = None, @@ -9349,12 +8818,12 @@ def _gcd_cond_fn(xs: tuple[Array, Array]) -> Array: def _gcd_body_fn(xs: tuple[Array, Array]) -> tuple[Array, Array]: x1, x2 = xs x1, x2 = (where(x2 != 0, x2, x1), - where(x2 != 0, lax.rem(x1, x2), _lax_const(x2, 0))) + where(x2 != 0, lax.rem(x1, x2), lax._const(x2, 0))) return (where(x1 < x2, x2, x1), where(x1 < x2, x1, x2)) @export -@jit +@api.jit def gcd(x1: ArrayLike, x2: ArrayLike) -> Array: """Compute the greatest common divisor of two arrays. @@ -9391,17 +8860,17 @@ def gcd(x1: ArrayLike, x2: ArrayLike) -> Array: >>> jnp.gcd(x1, x2) Array([ 6, 3, 12], dtype=int32) """ - util.check_arraylike("gcd", x1, x2) + x1, x2 = util.ensure_arraylike("gcd", x1, x2) x1, x2 = util.promote_dtypes(x1, x2) - if not issubdtype(_dtype(x1), np.integer): + if not issubdtype(x1.dtype, np.integer): raise ValueError("Arguments to jax.numpy.gcd must be integers.") x1, x2 = broadcast_arrays(x1, x2) - gcd, _ = lax.while_loop(_gcd_cond_fn, _gcd_body_fn, (ufuncs.abs(x1), ufuncs.abs(x2))) + gcd, _ = control_flow.while_loop(_gcd_cond_fn, _gcd_body_fn, (ufuncs.abs(x1), ufuncs.abs(x2))) return gcd @export -@jit +@api.jit def lcm(x1: ArrayLike, x2: ArrayLike) -> Array: """Compute the least common multiple of two arrays. @@ -9438,13 +8907,13 @@ def lcm(x1: ArrayLike, x2: ArrayLike) -> Array: >>> jnp.lcm(x1, x2) Array([12, 36, 12], dtype=int32) """ - util.check_arraylike("lcm", x1, x2) + x1, x2 = util.ensure_arraylike("lcm", x1, x2) x1, x2 = util.promote_dtypes(x1, x2) x1, x2 = ufuncs.abs(x1), ufuncs.abs(x2) - if not issubdtype(_dtype(x1), np.integer): + if not issubdtype(x1.dtype, np.integer): raise ValueError("Arguments to jax.numpy.lcm must be integers.") d = gcd(x1, x2) - return where(d == 0, _lax_const(d, 0), + return where(d == 0, lax._const(d, 0), ufuncs.multiply(x1, ufuncs.floor_divide(x2, d))) @@ -9606,11 +9075,12 @@ def compress(condition: ArrayLike, a: ArrayLike, axis: int | None = None, @export -@partial(jit, static_argnames=('rowvar', 'bias', 'ddof')) +@api.jit(static_argnames=('rowvar', 'bias', 'ddof', 'dtype')) def cov(m: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True, bias: bool = False, ddof: int | None = None, fweights: ArrayLike | None = None, - aweights: ArrayLike | None = None) -> Array: + aweights: ArrayLike | None = None, + dtype: DTypeLike | None = None) -> Array: r"""Estimate the weighted sample covariance. JAX implementation of :func:`numpy.cov`. @@ -9653,6 +9123,8 @@ def cov(m: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True, a relative weight specifying the "importance" of each observation. In the ``ddof=0`` case, it is equivalent to assigning probabilities to each observation. + dtype: optional data type of the result. Must be a float or complex type; + if not specified, it will be determined based on the dtype of the input. Returns: A covariance matrix of shape ``(M, M)``, or a scalar with shape ``()`` if ``M = 1``. @@ -9712,8 +9184,11 @@ def cov(m: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True, if m.ndim > 2: raise ValueError("m has more than 2 dimensions") # same as numpy error + if dtype is not None and not dtypes.issubdtype(dtype, np.inexact): + raise ValueError(f"cov: dtype must be a subclass of float or complex; got {dtype=}") + X = atleast_2d(m) - if not rowvar and X.shape[0] != 1: + if not rowvar and m.ndim != 1: X = X.T if X.shape[0] == 0: return array([]).reshape(0, 0) @@ -9723,6 +9198,10 @@ def cov(m: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True, if not rowvar and y_arr.shape[0] != 1: y_arr = y_arr.T X = concatenate((X, y_arr), axis=0) + if X.shape[1] == 0: + cov_shape = () if X.shape[0] == 1 else (X.shape[0], X.shape[0]) + return array_creation.full(cov_shape, np.nan, dtype=X.dtype) + if ddof is None: ddof = 1 if bias == 0 else 0 @@ -9733,7 +9212,7 @@ def cov(m: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True, raise RuntimeError("cannot handle multidimensional fweights") if np.shape(fweights)[0] != X.shape[1]: raise RuntimeError("incompatible numbers of samples and fweights") - if not issubdtype(_dtype(fweights), np.integer): + if not issubdtype(fweights.dtype, np.integer): raise TypeError("fweights must be integer.") # Ensure positive fweights; note that numpy raises an error on negative fweights. w = abs(fweights) @@ -9747,6 +9226,10 @@ def cov(m: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True, aweights = abs(aweights) w = aweights if w is None else w * aweights + if dtype is not None: + X = X.astype(dtype) + w = w.astype(dtype) if w is not None else w + avg, w_sum = reductions.average(X, axis=1, weights=w, returned=True) w_sum = w_sum[0] @@ -9765,8 +9248,9 @@ def cov(m: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True, @export -@partial(jit, static_argnames=('rowvar',)) -def corrcoef(x: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True) -> Array: +@api.jit(static_argnames=('rowvar', 'dtype')) +def corrcoef(x: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True, + dtype: DTypeLike | None = None) -> Array: r"""Compute the Pearson correlation coefficients. JAX implementation of :func:`numpy.corrcoef`. @@ -9790,6 +9274,8 @@ def corrcoef(x: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True) -> A ``rowvar = True`` case, ``m`` becomes ``jnp.vstack([m, y])``. rowvar: if True (default) then each row of ``m`` represents a variable. If False, then each column represents a variable. + dtype: optional data type of the result. Must be a float or complex type; + if not specified, it will be determined based on the dtype of the input. Returns: A covariance matrix of shape ``(M, M)``. @@ -9841,7 +9327,9 @@ def corrcoef(x: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True) -> A [0.12 0.01 1. ]] """ util.check_arraylike("corrcoef", x) - c = cov(x, y, rowvar) + if dtype is not None and not dtypes.issubdtype(dtype, np.inexact): + raise ValueError(f"corrcoef: dtype must be a subclass of float or complex; got {dtype=}") + c = cov(x, y, rowvar, dtype=dtype) if len(np.shape(c)) == 0: # scalar - this should yield nan for values (nan/nan, inf/inf, 0/0), 1 otherwise return ufuncs.divide(c, c) @@ -9859,27 +9347,31 @@ def corrcoef(x: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True) -> A @partial(vectorize, excluded={0, 1, 3, 4}) -def _searchsorted_via_scan(unrolled: bool, sorted_arr: Array, query: Array, side: str, dtype: type) -> Array: - op = _sort_le_comparator if side == 'left' else _sort_lt_comparator +def _searchsorted_via_scan(unrolled: bool, sorted_arr: Array, query: Array, + side: str, dtype: type) -> Array: + op = lax._sort_le_comparator if side == 'left' else lax._sort_lt_comparator unsigned_dtype = np.uint32 if dtype == np.int32 else np.uint64 def body_fun(state, _): low, high = state mid = low.astype(unsigned_dtype) + high.astype(unsigned_dtype) - mid = lax.div(mid, unsigned_dtype(2)).astype(dtype) + mid = lax.div(mid, array(2, dtype=unsigned_dtype)).astype(dtype) go_left = op(query, sorted_arr[mid]) return (where(go_left, low, mid), where(go_left, mid, high)), () n_levels = int(np.ceil(np.log2(len(sorted_arr) + 1))) init = (array(0, dtype=dtype), array(len(sorted_arr), dtype=dtype)) - carry, _ = lax.scan(body_fun, init, (), length=n_levels, - unroll=n_levels if unrolled else 1) + vma = core.typeof(sorted_arr).vma + init = tuple(core.pvary(i, tuple(vma)) for i in init) + carry, _ = control_flow.scan(body_fun, init, (), length=n_levels, + unroll=n_levels if unrolled else 1) return carry[1] def _searchsorted_via_sort(sorted_arr: Array, query: Array, side: str, dtype: type) -> Array: - working_dtype = np.dtype('int32') if sorted_arr.size + query.size < np.iinfo(np.int32).max else np.dtype('int64') + working_dtype = lax_utils.int_dtype_for_dim(sorted_arr.size + query.size, + signed=False) def _rank(x): idx = lax.iota(working_dtype, x.shape[0]) - return zeros_like(idx).at[argsort(x)].set(idx) + return array_creation.zeros_like(idx).at[argsort(x)].set(idx) query_flat = query.ravel() if side == 'left': index = _rank(lax.concatenate([query_flat, sorted_arr], 0))[:query.size] @@ -9889,13 +9381,13 @@ def _rank(x): def _searchsorted_via_compare_all(sorted_arr: Array, query: Array, side: str, dtype: type) -> Array: - op = _sort_lt_comparator if side == 'left' else _sort_le_comparator - comparisons = jax.vmap(op, in_axes=(0, None))(sorted_arr, query) + op = lax._sort_lt_comparator if side == 'left' else lax._sort_le_comparator + comparisons = api.vmap(op, in_axes=(0, None))(sorted_arr, query) return comparisons.sum(dtype=dtype, axis=0) @export -@partial(jit, static_argnames=('side', 'method')) +@api.jit(static_argnames=('side', 'method')) def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left', sorter: ArrayLike | None = None, *, method: str = 'scan') -> Array: """Perform a binary search within a sorted array. @@ -9957,9 +9449,9 @@ def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left', Array([0, 2, 5, 1, 1], dtype=int32) """ if sorter is None: - util.check_arraylike("searchsorted", a, v) + a, v = util.ensure_arraylike("searchsorted", a, v) else: - util.check_arraylike("searchsorted", a, v, sorter) + a, v, sorter = util.ensure_arraylike("searchsorted", a, v, sorter) if side not in ['left', 'right']: raise ValueError(f"{side!r} is an invalid value for keyword 'side'. " "Expected one of ['left', 'right'].") @@ -9972,20 +9464,21 @@ def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left', a, v = util.promote_dtypes(a, v) if sorter is not None: a = a[sorter] - dtype = np.dtype('int32') if a.shape[0] <= np.iinfo(np.int32).max else np.dtype('int64') + dtype = lax_utils.int_dtype_for_dim(a.shape[0], signed=True) if a.shape[0] == 0: - return zeros_like(v, dtype=dtype) + return array_creation.zeros_like(v, dtype=dtype) impl = { 'scan': partial(_searchsorted_via_scan, False), 'scan_unrolled': partial(_searchsorted_via_scan, True), 'sort': _searchsorted_via_sort, 'compare_all': _searchsorted_via_compare_all, }[method] + a, v = core.standard_insert_pvary(a, v) return impl(a, v, side, dtype) # type: ignore @export -@partial(jit, static_argnames=('right', 'method')) +@api.jit(static_argnames=('right', 'method')) def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False, *, method: str | None = None) -> Array: """Convert an array to bin indices. @@ -10029,7 +9522,7 @@ def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False, if bins_arr.ndim != 1: raise ValueError(f"digitize: bins must be a 1-dimensional array; got {bins=}") if bins_arr.shape[0] == 0: - return zeros_like(x, dtype=np.int32) + return array_creation.zeros_like(x, dtype=np.int32) side = 'right' if not right else 'left' kwds: dict[str, str] = {} if method is None else {'method': method} return where( @@ -10123,20 +9616,21 @@ def piecewise(x: ArrayLike, condlist: Array | Sequence[ArrayLike], frozenset(funcs.items()), # dict is not hashable. *args, **kw) -@partial(jit, static_argnames=['funcs']) +@api.jit(static_argnames=['funcs']) def _piecewise(x: Array, condlist: Array, consts: dict[int, ArrayLike], funcs: frozenset[tuple[int, Callable[..., Array]]], *args, **kw) -> Array: funcdict = dict(funcs) funclist = [consts.get(i, funcdict.get(i)) for i in range(len(condlist) + 1)] - indices = argmax(reductions.cumsum(concatenate([zeros_like(condlist[:1]), condlist], 0), 0), 0) - dtype = _dtype(x) + indices = argmax(reductions.cumsum(concatenate( + [array_creation.zeros_like(condlist[:1]), condlist], 0), 0), 0) + dtype = x.dtype def _call(f): return lambda x: f(x, *args, **kw).astype(dtype) def _const(v): return lambda x: array(v, dtype=dtype) funclist = [_call(f) if callable(f) else _const(f) for f in funclist] - return vectorize(lax.switch, excluded=(1,))(indices, funclist, x) + return vectorize(control_flow.switch, excluded=(1,))(indices, funclist, x) def _tile_to_size(arr: Array, size: int) -> Array: diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index 23f2a58b09f6..cd000862bacc 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -18,50 +18,56 @@ from functools import partial import itertools import math - -import numpy as np import operator from typing import Literal, NamedTuple, overload -import jax -from jax import jit, custom_jvp -from jax import lax +import numpy as np -from jax._src import deprecations -from jax._src.lax import lax as lax_internal -from jax._src.lax.lax import PrecisionLike +from jax._src import api +from jax._src import core +from jax._src import config +from jax._src.custom_derivatives import custom_jvp +from jax._src.lax import lax from jax._src.lax import linalg as lax_linalg +from jax._src.lax import utils as lax_utils +from jax._src.numpy import array_creation from jax._src.numpy import einsum from jax._src.numpy import indexing from jax._src.numpy import lax_numpy as jnp +from jax._src.sharding_impls import NamedSharding, PartitionSpec as P from jax._src.numpy import reductions, tensor_contractions, ufuncs from jax._src.numpy.util import promote_dtypes_inexact, ensure_arraylike from jax._src.util import canonicalize_axis, set_module -from jax._src.typing import ArrayLike, Array, DTypeLike, DeprecatedArg +from jax._src.typing import ArrayLike, Array, DTypeLike export = set_module('jax.numpy.linalg') class EighResult(NamedTuple): - eigenvalues: jax.Array - eigenvectors: jax.Array + eigenvalues: Array + eigenvectors: Array + + +class EigResult(NamedTuple): + eigenvalues: Array + eigenvectors: Array class QRResult(NamedTuple): - Q: jax.Array - R: jax.Array + Q: Array + R: Array class SlogdetResult(NamedTuple): - sign: jax.Array - logabsdet: jax.Array + sign: Array + logabsdet: Array class SVDResult(NamedTuple): - U: jax.Array - S: jax.Array - Vh: jax.Array + U: Array + S: Array + Vh: Array def _H(x: ArrayLike) -> Array: @@ -72,8 +78,8 @@ def _symmetrize(x: Array) -> Array: return (x + _H(x)) / 2 @export -@partial(jit, static_argnames=['upper']) -def cholesky(a: ArrayLike, *, upper: bool = False) -> Array: +@api.jit(static_argnames=['upper', 'symmetrize_input']) +def cholesky(a: ArrayLike, *, upper: bool = False, symmetrize_input: bool = True) -> Array: """Compute the Cholesky decomposition of a matrix. JAX implementation of :func:`numpy.linalg.cholesky`. @@ -98,10 +104,14 @@ def cholesky(a: ArrayLike, *, upper: bool = False) -> Array: Must have shape ``(..., N, N)``. upper: if True, compute the upper Cholesky decomposition `U`. if False (default), compute the lower Cholesky decomposition `L`. + symmetrize_input: if True (default) then input is symmetrized, which leads + to better behavior under automatic differentiation. Note that when this + is set to True, both the upper and lower triangles of the input will + be used in computing the decomposition. Returns: array of shape ``(..., N, N)`` representing the Cholesky decomposition - of the input. If the input is not Hermitian positive-definite, The result + of the input. If the input is not Hermitian positive-definite, the result will contain NaN entries. @@ -135,7 +145,7 @@ def cholesky(a: ArrayLike, *, upper: bool = False) -> Array: """ a = ensure_arraylike("jnp.linalg.cholesky", a) a, = promote_dtypes_inexact(a) - L = lax_linalg.cholesky(a) + L = lax_linalg.cholesky(a, symmetrize_input=symmetrize_input) return L.mT.conj() if upper else L @@ -198,7 +208,7 @@ def svd( @export @partial( - jit, + api.jit, static_argnames=( "full_matrices", "compute_uv", @@ -289,7 +299,9 @@ def svd( s = lax.abs(v) if compute_uv: sign = lax.sign(v) - idxs = lax.broadcasted_iota(np.int64, s.shape, dimension=s.ndim - 1) + idx_dtype = lax_utils.int_dtype_for_dim( + s.shape[s.ndim - 1], signed=False) + idxs = lax.broadcasted_iota(idx_dtype, s.shape, dimension=s.ndim - 1) s, idxs, sign = lax.sort((s, idxs, sign), dimension=-1, num_keys=1) s = lax.rev(s, dimensions=[s.ndim - 1]) idxs = lax.rev(idxs, dimensions=[s.ndim - 1]) @@ -318,7 +330,7 @@ def svd( @export -@partial(jit, static_argnames=('n',)) +@api.jit(static_argnames=('n',)) def matrix_power(a: ArrayLike, n: int) -> Array: """Raise a square matrix to an integer power. @@ -363,8 +375,7 @@ def matrix_power(a: ArrayLike, n: int) -> Array: Array([[ 5.5 , -2.5 ], [-3.75, 1.75]], dtype=float32) """ - a = ensure_arraylike("jnp.linalg.matrix_power", a) - arr, = promote_dtypes_inexact(a) + arr = ensure_arraylike("jnp.linalg.matrix_power", a) if arr.ndim < 2: raise TypeError("{}-dimensional array given. Array must be at least " @@ -400,10 +411,9 @@ def matrix_power(a: ArrayLike, n: int) -> Array: @export -@jit +@api.jit def matrix_rank( - M: ArrayLike, rtol: ArrayLike | None = None, *, - tol: ArrayLike | DeprecatedArg | None = DeprecatedArg()) -> Array: + M: ArrayLike, rtol: ArrayLike | None = None, *, tol: ArrayLike | None = None) -> Array: """Compute the rank of a matrix. JAX implementation of :func:`numpy.linalg.matrix_rank`. @@ -417,8 +427,8 @@ def matrix_rank( smaller than `rtol * largest_singular_value` are considered to be zero. If ``rtol`` is None (the default), a reasonable default is chosen based the floating point precision of the input. - tol: deprecated alias of the ``rtol`` argument. Will result in a - :class:`DeprecationWarning` if used. + tol: alias of the ``rtol`` argument present for backward compatibility. + Only one of `rtol` or `tol` may be specified. Returns: array of shape ``a.shape[-2]`` giving the matrix rank. @@ -440,16 +450,11 @@ def matrix_rank( Array(1, dtype=int32) """ M = ensure_arraylike("jnp.linalg.matrix_rank", M) - # TODO(micky774): deprecated 2024-5-14, remove after deprecation expires. - if not isinstance(tol, DeprecatedArg): + if tol is not None: + if rtol is not None: + raise ValueError("matrix_rank: only one of tol or rtol may be specified.") rtol = tol - del tol - deprecations.warn( - "jax-numpy-linalg-matrix_rank-tol", - ("The tol argument for linalg.matrix_rank is deprecated. " - "Please use rtol instead."), - stacklevel=2 - ) + del tol M, = promote_dtypes_inexact(M) if M.ndim < 2: return (M != 0).any().astype(np.int32) @@ -505,7 +510,7 @@ def _slogdet_qr(a: Array) -> tuple[Array, Array]: @export -@partial(jit, static_argnames=('method',)) +@api.jit(static_argnames=('method',)) def slogdet(a: ArrayLike, *, method: str | None = None) -> SlogdetResult: """ Compute the sign and (natural) logarithm of the determinant of an array. @@ -559,7 +564,7 @@ def _slogdet_jvp(primals, tangents): sign_dot = (ans_dot - ufuncs.real(ans_dot).astype(ans_dot.dtype)) * sign ans_dot = ufuncs.real(ans_dot) else: - sign_dot = jnp.zeros_like(sign) + sign_dot = array_creation.zeros_like(sign) return (sign, ans), (sign_dot, ans_dot) _slogdet_lu.defjvp(_slogdet_jvp) @@ -641,18 +646,18 @@ def _cofactor_solve(a: ArrayLike, b: ArrayLike) -> tuple[Array, Array]: permutation = jnp.broadcast_to(permutation, (*batch_dims, a_shape[-1])) iotas = jnp.ix_(*(lax.iota(np.int32, b) for b in (*batch_dims, 1))) # filter out any matrices that are not full rank - d = jnp.ones(x.shape[:-1], x.dtype) + d = array_creation.ones(x.shape[:-1], x.dtype) d = lax_linalg.triangular_solve(lu, d, left_side=True, lower=False) d = reductions.any(ufuncs.logical_or(ufuncs.isnan(d), ufuncs.isinf(d)), axis=-1) d = jnp.tile(d[..., None, None], d.ndim*(1,) + x.shape[-2:]) - x = jnp.where(d, jnp.zeros_like(x), x) # first filter + x = jnp.where(d, array_creation.zeros_like(x), x) # first filter x = x[iotas[:-1] + (permutation, slice(None))] x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=True, unit_diagonal=True) x = jnp.concatenate((x[..., :-1, :] * partial_det[..., -1, None, None], x[..., -1:, :]), axis=-2) x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=False) - x = jnp.where(d, jnp.zeros_like(x), x) # second filter + x = jnp.where(d, array_creation.zeros_like(x), x) # second filter return partial_det[..., -1], x @@ -686,7 +691,7 @@ def _det_jvp(primals, tangents): @export -@jit +@api.jit def det(a: ArrayLike) -> Array: """ Compute the determinant of an array. @@ -723,7 +728,7 @@ def det(a: ArrayLike) -> Array: @export -def eig(a: ArrayLike) -> tuple[Array, Array]: +def eig(a: ArrayLike) -> EigResult: """ Compute the eigenvalues and eigenvectors of a square array. @@ -733,7 +738,7 @@ def eig(a: ArrayLike) -> tuple[Array, Array]: a: array of shape ``(..., M, M)`` for which to compute the eigenvalues and vectors. Returns: - A tuple ``(eigenvalues, eigenvectors)`` with + A namedtuple ``(eigenvalues, eigenvectors)``. The namedtuple has fields: - ``eigenvalues``: an array of shape ``(..., M)`` containing the eigenvalues. - ``eigenvectors``: an array of shape ``(..., M, M)``, where column ``v[:, i]`` is the @@ -765,11 +770,11 @@ def eig(a: ArrayLike) -> tuple[Array, Array]: a = ensure_arraylike("jnp.linalg.eig", a) a, = promote_dtypes_inexact(a) w, v = lax_linalg.eig(a, compute_left_eigenvectors=False) - return w, v + return EigResult(w, v) @export -@jit +@api.jit def eigvals(a: ArrayLike) -> Array: """ Compute the eigenvalues of a general matrix. @@ -807,7 +812,7 @@ def eigvals(a: ArrayLike) -> Array: @export -@partial(jit, static_argnames=('UPLO', 'symmetrize_input')) +@api.jit(static_argnames=('UPLO', 'symmetrize_input')) def eigh(a: ArrayLike, UPLO: str | None = None, symmetrize_input: bool = True) -> EighResult: """ @@ -821,7 +826,9 @@ def eigh(a: ArrayLike, UPLO: str | None = None, UPLO: specifies whether the calculation is done with the lower triangular part of ``a`` (``'L'``, default) or the upper triangular part (``'U'``). symmetrize_input: if True (default) then input is symmetrized, which leads - to better behavior under automatic differentiation. + to better behavior under automatic differentiation. Note that when this + is set to True, both the upper and lower triangles of the input will + be used in computing the decomposition. Returns: A namedtuple ``(eigenvalues, eigenvectors)`` where @@ -863,8 +870,9 @@ def eigh(a: ArrayLike, UPLO: str | None = None, @export -@partial(jit, static_argnames=('UPLO',)) -def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array: +@api.jit(static_argnames=('UPLO', 'symmetrize_input')) +def eigvalsh(a: ArrayLike, UPLO: str | None = 'L', *, + symmetrize_input: bool = True) -> Array: """ Compute the eigenvalues of a Hermitian matrix. @@ -875,6 +883,10 @@ def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array: or symmetric (if real) matrix. UPLO: specifies whether the calculation is done with the lower triangular part of ``a`` (``'L'``, default) or the upper triangular part (``'U'``). + symmetrize_input: if True (default) then input is symmetrized, which leads + to better behavior under automatic differentiation. Note that when this + is set to True, both the upper and lower triangles of the input will + be used in computing the decomposition. Returns: An array of shape ``(..., M)`` containing the eigenvalues, sorted in @@ -894,15 +906,14 @@ def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array: """ a = ensure_arraylike("jnp.linalg.eigvalsh", a) a, = promote_dtypes_inexact(a) - w, _ = eigh(a, UPLO) + w, _ = eigh(a, UPLO, symmetrize_input=symmetrize_input) return w # TODO(micky774): deprecated 2024-5-14, remove wrapper after deprecation expires. @export def pinv(a: ArrayLike, rtol: ArrayLike | None = None, - hermitian: bool = False, *, - rcond: ArrayLike | DeprecatedArg | None = DeprecatedArg()) -> Array: + hermitian: bool = False, *, rcond: ArrayLike | None = None) -> Array: """Compute the (Moore-Penrose) pseudo-inverse of a matrix. JAX implementation of :func:`numpy.linalg.pinv`. @@ -916,8 +927,8 @@ def pinv(a: ArrayLike, rtol: ArrayLike | None = None, determined based on the floating point precision of the dtype. hermitian: if True, then the input is assumed to be Hermitian, and a more efficient algorithm is used (default: False) - rcond: deprecated alias of the ``rtol`` argument. Will result in a - :class:`DeprecationWarning` if used. + rcond: alias of the `rtol` argument, present for backward compatibility. + Only one of `rtol` and `rcond` may be specified. Returns: An array of shape ``(..., N, M)`` containing the pseudo-inverse of ``a``. @@ -945,21 +956,16 @@ def pinv(a: ArrayLike, rtol: ArrayLike | None = None, >>> jnp.allclose(a_pinv @ a, jnp.eye(2), atol=1E-4) Array(True, dtype=bool) """ - if not isinstance(rcond, DeprecatedArg): + if rcond is not None: + if rtol is not None: + raise ValueError("pinv: only one of rtol and rcond may be specified.") rtol = rcond - del rcond - deprecations.warn( - "jax-numpy-linalg-pinv-rcond", - ("The rcond argument for linalg.pinv is deprecated. " - "Please use rtol instead."), - stacklevel=2 - ) - + del rcond return _pinv(a, rtol, hermitian) @partial(custom_jvp, nondiff_argnums=(1, 2)) -@partial(jit, static_argnames=('hermitian')) +@api.jit(static_argnames=('hermitian')) def _pinv(a: ArrayLike, rtol: ArrayLike | None = None, hermitian: bool = False) -> Array: # Uses same algorithm as # https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979 @@ -967,7 +973,7 @@ def _pinv(a: ArrayLike, rtol: ArrayLike | None = None, hermitian: bool = False) arr, = promote_dtypes_inexact(a) m, n = arr.shape[-2:] if m == 0 or n == 0: - return jnp.empty(arr.shape[:-2] + (n, m), arr.dtype) + return array_creation.empty(arr.shape[:-2] + (n, m), arr.dtype) arr = ufuncs.conj(arr) if rtol is None: max_rows_cols = max(arr.shape[-2:]) @@ -985,7 +991,7 @@ def _pinv(a: ArrayLike, rtol: ArrayLike | None = None, hermitian: bool = False) @_pinv.defjvp -@jax.default_matmul_precision("float32") +@config.default_matmul_precision("float32") def _pinv_jvp(rtol, hermitian, primals, tangents): # The Differentiation of Pseudo-Inverses and Nonlinear Least Squares Problems # Whose Variables Separate. Author(s): G. H. Golub and V. Pereyra. SIAM @@ -1014,7 +1020,7 @@ def _pinv_jvp(rtol, hermitian, primals, tangents): @export -@jit +@api.jit def inv(a: ArrayLike) -> Array: """Return the inverse of a square matrix @@ -1074,7 +1080,7 @@ def inv(a: ArrayLike) -> Array: @export -@partial(jit, static_argnames=('ord', 'axis', 'keepdims')) +@api.jit(static_argnames=('ord', 'axis', 'keepdims')) def norm(x: ArrayLike, ord: int | str | None = None, axis: None | tuple[int, ...] | int = None, keepdims: bool = False) -> Array: @@ -1210,12 +1216,16 @@ def norm(x: ArrayLike, ord: int | str | None = None, " compute a vector-norm, or two axes to compute a matrix-norm.") @overload +def qr(a: ArrayLike, + mode: Literal["reduced", "complete", "raw", "full"] = "reduced", + ) -> QRResult: ... +@overload def qr(a: ArrayLike, mode: Literal["r"]) -> Array: ... @overload -def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult: ... +def qr(a: ArrayLike, mode: str) -> Array | QRResult: ... @export -@partial(jit, static_argnames=('mode',)) +@api.jit(static_argnames=('mode',)) def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult: """Compute the QR decomposition of an array @@ -1299,7 +1309,7 @@ def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult: @export -@jit +@api.jit def solve(a: ArrayLike, b: ArrayLike) -> Array: """Solve a linear system of equations. @@ -1357,6 +1367,7 @@ def solve(a: ArrayLike, b: ArrayLike) -> Array: " To recover this behavior, use solve(a, b[..., None]).squeeze(-1).") signature = "(m,m),(m)->(m)" if b.ndim == 1 else "(m,m),(m,n)->(m,n)" + a, b = core.standard_insert_pvary(a, b) return jnp.vectorize(lax_linalg._solve, signature=signature)(a, b) @@ -1379,16 +1390,16 @@ def _lstsq(a: ArrayLike, b: ArrayLike, rcond: float | None, *, m, n = a.shape dtype = a.dtype if a.size == 0: - s = jnp.empty(0, dtype=a.dtype) + s = array_creation.empty(0, dtype=a.dtype) rank = jnp.array(0, dtype=int) - x = jnp.empty((n, *b.shape[1:]), dtype=a.dtype) + x = array_creation.empty((n, *b.shape[1:]), dtype=a.dtype) else: if rcond is None: rcond = float(jnp.finfo(dtype).eps) * max(n, m) else: rcond = jnp.where(rcond < 0, jnp.finfo(dtype).eps, rcond) u, s, vt = svd(a, full_matrices=False) - mask = s >= jnp.array(rcond, dtype=s.dtype) * s[0] + mask = (s > 0) & (s >= jnp.array(rcond, dtype=s.dtype) * s[0]) rank = mask.sum() safe_s = jnp.where(mask, s, 1).astype(a.dtype) s_inv = jnp.where(mask, 1 / safe_s, 0)[:, np.newaxis] @@ -1405,7 +1416,7 @@ def _lstsq(a: ArrayLike, b: ArrayLike, rcond: float | None, *, x = x.ravel() return x, resid, rank, s -_jit_lstsq = jit(partial(_lstsq, numpy_resid=False)) +_jit_lstsq = api.jit(partial(_lstsq, numpy_resid=False)) @export @@ -1606,8 +1617,8 @@ def matrix_transpose(x: ArrayLike, /) -> Array: x_arr = ensure_arraylike('jnp.linalg.matrix_transpose', x) ndim = x_arr.ndim if ndim < 2: - raise ValueError(f"matrix_transpose requres at least 2 dimensions; got {ndim=}") - return jax.lax.transpose(x_arr, (*range(ndim - 2), ndim - 1, ndim - 2)) + raise ValueError(f"matrix_transpose requires at least 2 dimensions; got {ndim=}") + return lax.transpose(x_arr, (*range(ndim - 2), ndim - 1, ndim - 2)) @export @@ -1672,14 +1683,14 @@ def vector_norm(x: ArrayLike, /, *, axis: int | tuple[int, ...] | None = None, k raise ValueError(msg) else: abs_x = ufuncs.abs(x) - ord_arr = lax_internal._const(abs_x, ord) - ord_inv = lax_internal._const(abs_x, 1. / ord_arr) + ord_arr = lax._const(abs_x, ord) + ord_inv = lax._const(abs_x, 1. / ord_arr) out = reductions.sum(abs_x ** ord_arr, axis=axis, keepdims=keepdims) return ufuncs.power(out, ord_inv) @export def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1, - precision: PrecisionLike = None, + precision: lax.PrecisionLike = None, preferred_element_type: DTypeLike | None = None) -> Array: """Compute the (batched) vector conjugate dot product of two arrays. @@ -1730,7 +1741,7 @@ def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1, @export def matmul(x1: ArrayLike, x2: ArrayLike, /, *, - precision: PrecisionLike = None, + precision: lax.PrecisionLike = None, preferred_element_type: DTypeLike | None = None) -> Array: """Perform a matrix multiplication. @@ -1792,8 +1803,9 @@ def matmul(x1: ArrayLike, x2: ArrayLike, /, *, @export def tensordot(x1: ArrayLike, x2: ArrayLike, /, *, axes: int | tuple[Sequence[int], Sequence[int]] = 2, - precision: PrecisionLike = None, - preferred_element_type: DTypeLike | None = None) -> Array: + precision: lax.PrecisionLike = None, + preferred_element_type: DTypeLike | None = None, + out_sharding: NamedSharding | P | None = None) -> Array: """Compute the tensor dot product of two N-dimensional arrays. JAX implementation of :func:`numpy.linalg.tensordot`. @@ -1867,8 +1879,9 @@ def tensordot(x1: ArrayLike, x2: ArrayLike, /, *, [2, 4, 6]], dtype=int32) """ x1, x2 = ensure_arraylike('jnp.linalg.tensordot', x1, x2) - return tensor_contractions.tensordot(x1, x2, axes=axes, precision=precision, - preferred_element_type=preferred_element_type) + return tensor_contractions.tensordot( + x1, x2, axes=axes, precision=precision, + preferred_element_type=preferred_element_type, out_sharding=out_sharding) @export @@ -2029,7 +2042,7 @@ def tensorsolve(a: ArrayLike, b: ArrayLike, axes: tuple[int, ...] | None = None) @export -def multi_dot(arrays: Sequence[ArrayLike], *, precision: PrecisionLike = None) -> Array: +def multi_dot(arrays: Sequence[ArrayLike], *, precision: lax.PrecisionLike = None) -> Array: """Efficiently compute matrix products between a sequence of arrays. JAX implementation of :func:`numpy.linalg.multi_dot`. @@ -2121,7 +2134,7 @@ def multi_dot(arrays: Sequence[ArrayLike], *, precision: PrecisionLike = None) - @export -@partial(jit, static_argnames=['p']) +@api.jit(static_argnames=['p']) def cond(x: ArrayLike, p=None): """Compute the condition number of a matrix. diff --git a/jax/_src/numpy/polynomial.py b/jax/_src/numpy/polynomial.py index 81d320cb7403..71856cb5ac21 100644 --- a/jax/_src/numpy/polynomial.py +++ b/jax/_src/numpy/polynomial.py @@ -14,20 +14,19 @@ from __future__ import annotations -from functools import partial import operator import numpy as np -from jax import jit -from jax import lax +from jax._src import api from jax._src import dtypes from jax._src import core -from jax._src.lax import lax as lax_internal +from jax._src.lax import control_flow +from jax._src.lax import lax +from jax._src.numpy.array_creation import full, ones, zeros from jax._src.numpy.lax_numpy import ( arange, argmin, array, atleast_1d, concatenate, convolve, - diag, finfo, full, ones, roll, trim_zeros, - trim_zeros_tol, vander, zeros) + diag, finfo, roll, trim_zeros, trim_zeros_tol, vander) from jax._src.numpy.tensor_contractions import dot, outer from jax._src.numpy.ufuncs import maximum, true_divide, sqrt from jax._src.numpy.reductions import all @@ -41,7 +40,7 @@ export = set_module('jax.numpy') -@jit +@api.jit def _roots_no_zeros(p: Array) -> Array: # build companion matrix and find its eigenvalues (the roots) if p.size < 2: @@ -51,7 +50,7 @@ def _roots_no_zeros(p: Array) -> Array: return linalg.eigvals(A) -@jit +@api.jit def _roots_with_zeros(p: Array, num_leading_zeros: Array | int) -> Array: # Avoid lapack errors when p is all zero p = _where(len(p) == num_leading_zeros, 1.0, p) @@ -124,7 +123,7 @@ def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array: @export -@partial(jit, static_argnames=('deg', 'rcond', 'full', 'cov')) +@api.jit(static_argnames=('deg', 'rcond', 'full', 'cov')) def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None, full: bool = False, w: ArrayLike | None = None, cov: bool = False ) -> Array | tuple[Array, ...]: @@ -146,7 +145,7 @@ def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None, rcond: Relative condition number of the fit. Default value is ``len(x) * eps``. It must be specified statically. full: Switch that controls the return value. Default is ``False`` which - restricts the return value to the array of polynomail coefficients ``p``. + restricts the return value to the array of polynomial coefficients ``p``. If ``True``, the function returns a tuple ``(p, resids, rank, s, rcond)``. It must be specified statically. w: Array of weights of shape ``(M,)``. If None, all data points are considered @@ -154,8 +153,8 @@ def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None, unsquared residual of :math:`y_i - \widehat{y}_i` at :math:`x_i`, where :math:`\widehat{y}_i` is the fitted value of :math:`y_i`. Default is None. cov: Boolean or string. If ``True``, returns the covariance matrix scaled - by ``resids/(M-deg-1)`` along with ploynomial coefficients. If - ``cov='unscaled'``, returns the unscaaled version of covariance matrix. + by ``resids/(M-deg-1)`` along with polynomial coefficients. If + ``cov='unscaled'``, returns the unscaled version of covariance matrix. Default is ``False``. ``cov`` is ignored if ``full=True``. It must be specified statically. @@ -224,7 +223,7 @@ def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None, >>> p, C = jnp.polyfit(x, y, 2, cov=True) >>> p.shape, C.shape - ((3, 3), (3, 3, 1)) + ((3, 3), (3, 3, 3)) """ if w is None: x_arr, y_arr = ensure_arraylike("polyfit", x, y) @@ -233,7 +232,6 @@ def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None, del x, y deg = core.concrete_or_error(int, deg, "deg must be int") order = deg + 1 - # check arguments if deg < 0: raise ValueError("expected deg >= 0") if x_arr.ndim != 1: @@ -245,7 +243,6 @@ def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None, if x_arr.shape[0] != y_arr.shape[0]: raise TypeError("expected x and y to have same length") - # set rcond if rcond is None: rcond = len(x_arr) * float(finfo(x_arr.dtype).eps) rcond = core.concrete_or_error(float, rcond, "rcond must be float") @@ -268,34 +265,45 @@ def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None, # scale lhs to improve condition number and solve scale = sqrt((lhs*lhs).sum(axis=0)) - lhs /= scale[np.newaxis,:] + lhs /= scale[np.newaxis, :] c, resids, rank, s = linalg.lstsq(lhs, rhs, rcond) - c = (c.T/scale).T # broadcast scale coefficients + + # Broadcasting scale coefficients + if c.ndim > 1: + # For multi-dimensional output, make scale (1, order) to divide + # across the c.T of shape (num_rhs, order) + c = (c.T / scale[np.newaxis, :]).T + else: + # Simple case for 1D output + c = c / scale if full: assert rcond is not None - return c, resids, rank, s, lax_internal.asarray(rcond) + return c, resids, rank, s, lax.asarray(rcond) elif cov: Vbase = linalg.inv(dot(lhs.T, lhs)) Vbase /= outer(scale, scale) + if cov == "unscaled": - fac = 1 + fac = array(1.0) else: if len(x_arr) <= order: - raise ValueError("the number of data points must exceed order " - "to scale the covariance matrix") + raise ValueError("the number of data points must exceed order" + " to scale the covariance matrix") fac = resids / (len(x_arr) - order) - fac = fac[0] #making np.array() of shape (1,) to int + if y_arr.ndim == 1: + fac = atleast_1d(fac)[np.newaxis] + # For 1D output, simple scalar multiplication return c, Vbase * fac else: - return c, Vbase[:, :, np.newaxis] * fac + # For multiple rhs, broadcast fac to match shape + return c, Vbase[:, :, np.newaxis] * atleast_1d(fac)[np.newaxis, np.newaxis, :] else: return c - @export -@jit +@api.jit def poly(seq_of_zeros: ArrayLike) -> Array: r"""Returns the coefficients of a polynomial for the given sequence of roots. @@ -378,7 +386,7 @@ def poly(seq_of_zeros: ArrayLike) -> Array: @export -@partial(jit, static_argnames=['unroll']) +@api.jit(static_argnames=['unroll']) def polyval(p: ArrayLike, x: ArrayLike, *, unroll: int = 16) -> Array: r"""Evaluates the polynomial at specific values. @@ -437,12 +445,12 @@ def polyval(p: ArrayLike, x: ArrayLike, *, unroll: int = 16) -> Array: del p, x shape = lax.broadcast_shapes(p_arr.shape[1:], x_arr.shape) y = lax.full_like(x_arr, 0, shape=shape, dtype=x_arr.dtype) - y, _ = lax.scan(lambda y, p: (y * x_arr + p, None), y, p_arr, unroll=unroll) # type: ignore[misc] + y, _ = control_flow.scan(lambda y, p: (y * x_arr + p, None), y, p_arr, unroll=unroll) # type: ignore[misc] return y @export -@jit +@api.jit def polyadd(a1: ArrayLike, a2: ArrayLike) -> Array: r"""Returns the sum of the two polynomials. @@ -500,7 +508,7 @@ def polyadd(a1: ArrayLike, a2: ArrayLike) -> Array: @export -@partial(jit, static_argnames=('m',)) +@api.jit(static_argnames=('m',)) def polyint(p: ArrayLike, m: int = 1, k: int | ArrayLike | None = None) -> Array: r"""Returns the coefficients of the integration of specified order of a polynomial. @@ -569,7 +577,7 @@ def polyint(p: ArrayLike, m: int = 1, k: int | ArrayLike | None = None) -> Array @export -@partial(jit, static_argnames=('m',)) +@api.jit(static_argnames=('m',)) def polyder(p: ArrayLike, m: int = 1) -> Array: r"""Returns the coefficients of the derivative of specified order of a polynomial. @@ -747,7 +755,7 @@ def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) -> @export -@jit +@api.jit def polysub(a1: ArrayLike, a2: ArrayLike) -> Array: r"""Returns the difference of two polynomials. diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 985b296bc06f..1d6869d2140d 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -16,33 +16,31 @@ import builtins from collections.abc import Callable, Sequence -from functools import partial import math import operator from typing import overload, Any, Literal, Protocol, Union import numpy as np -import jax -from jax import lax from jax._src import api +from jax._src import config from jax._src import core -from jax._src import deprecations from jax._src import dtypes from jax._src.numpy.util import ( - _broadcast_to, check_arraylike, _complex_elem_type, + _broadcast_to, ensure_arraylike, promote_dtypes_inexact, promote_dtypes_numeric, _where) -from jax._src.lax import lax as lax_internal -from jax._src.typing import Array, ArrayLike, DType, DTypeLike, DeprecatedArg -from jax._src.util import ( - canonicalize_axis as _canonicalize_axis, maybe_named_axis, - set_module) +from jax._src.lax import control_flow +from jax._src.lax import lax as lax +from jax._src.lax import other as lax_other +from jax._src.lax import parallel as lax_parallel +from jax._src.lax import slicing as lax_slicing +from jax._src.typing import Array, ArrayLike, DType, DTypeLike +from jax._src.util import canonicalize_axis, canonicalize_axis_tuple, maybe_named_axis, set_module export = set_module('jax.numpy') _all = builtins.all -_lax_const = lax_internal._const Axis = Union[int, Sequence[int], None] @@ -54,10 +52,9 @@ def _isscalar(element: Any) -> bool: def _moveaxis(a: ArrayLike, source: int, destination: int) -> Array: # simplified version of jnp.moveaxis() for local use. - check_arraylike("moveaxis", a) - a = lax_internal.asarray(a) - source = _canonicalize_axis(source, np.ndim(a)) - destination = _canonicalize_axis(destination, np.ndim(a)) + a = ensure_arraylike("moveaxis", a) + source = canonicalize_axis(source, np.ndim(a)) + destination = canonicalize_axis(destination, np.ndim(a)) perm = [i for i in range(np.ndim(a)) if i != source] perm.insert(destination, source) return lax.transpose(a, perm) @@ -67,33 +64,30 @@ def _upcast_f16(dtype: DTypeLike) -> DType: return np.dtype('float32') return np.dtype(dtype) -def _promote_integer_dtype(dtype: DTypeLike) -> DTypeLike: +def _promote_integer_dtype(dtype: DType) -> DType: # Note: NumPy always promotes to 64-bit; jax instead promotes to the # default dtype as defined by dtypes.int_ or dtypes.uint. if dtypes.issubdtype(dtype, np.bool_): - return dtypes.int_ + return dtypes.default_int_dtype() elif dtypes.issubdtype(dtype, np.unsignedinteger): - if np.iinfo(dtype).bits < np.iinfo(dtypes.uint).bits: - return dtypes.uint + default_uint_dtype = dtypes.default_uint_dtype() + if np.iinfo(dtype).bits < np.iinfo(default_uint_dtype).bits: + return default_uint_dtype elif dtypes.issubdtype(dtype, np.integer): - if np.iinfo(dtype).bits < np.iinfo(dtypes.int_).bits: - return dtypes.int_ + default_int_dtype = dtypes.default_int_dtype() + if np.iinfo(dtype).bits < np.iinfo(default_int_dtype).bits: + return default_int_dtype return dtype def check_where(name: str, where: ArrayLike | None) -> Array | None: if where is None: return where - check_arraylike(name, where) - where_arr = lax_internal.asarray(where) - if where_arr.dtype != bool: - # Deprecation added 2024-12-05 - deprecations.warn( - 'jax-numpy-reduction-non-boolean-where', - f"jnp.{name}: where must be None or a boolean array; got dtype={where_arr.dtype}.", - stacklevel=2) - return where_arr.astype(bool) - return where_arr - + where = ensure_arraylike(name, where) + if where.dtype != bool: + raise ValueError( + f"jnp.{name}: where must be None or a boolean array; got {where.dtype=}." + ) + return where ReductionOp = Callable[[Any, Any], Any] @@ -113,16 +107,14 @@ def _reduction(a: ArrayLike, name: str, op: ReductionOp, init_val: ArrayLike, # exists, passing along all its arguments. if out is not None: raise NotImplementedError(f"The 'out' argument to jnp.{name} is not supported.") - check_arraylike(name, a) + a = ensure_arraylike(name, a) where_ = check_where(name, where_) - dtypes.check_user_dtype_supported(dtype, name) axis = core.concrete_or_error(None, axis, f"axis argument to jnp.{name}().") if initial is None and not has_identity and where_ is not None: raise ValueError(f"reduction operation {name} does not have an identity, so to use a " f"where mask one has to specify 'initial'") - a = a if isinstance(a, Array) else lax_internal.asarray(a) a = preproc(a) if preproc else a pos_dims, dims = _reduction_dims(a, axis) @@ -131,12 +123,13 @@ def _reduction(a: ArrayLike, name: str, op: ReductionOp, init_val: ArrayLike, if not _all(shape[d] >= 1 for d in pos_dims): raise ValueError(f"zero-size array to reduction operation {name} which has no identity") - result_dtype = dtype or dtypes.dtype(a) - - if dtype is None and promote_integers: - result_dtype = _promote_integer_dtype(result_dtype) - - result_dtype = dtypes.canonicalize_dtype(result_dtype) + result_dtype: DType + if dtype is None: + result_dtype = a.dtype + if promote_integers: + result_dtype = _promote_integer_dtype(result_dtype) + else: + result_dtype = dtypes.check_and_canonicalize_user_dtype(dtype, name) if upcast_f16_for_computation and dtypes.issubdtype(result_dtype, np.inexact): computation_dtype = _upcast_f16(result_dtype) @@ -156,7 +149,7 @@ def _reduction(a: ArrayLike, name: str, op: ReductionOp, init_val: ArrayLike, else: result = lax.reduce(a, init_val, op, dims) if initial is not None: - initial_arr = lax.convert_element_type(initial, lax_internal.asarray(a).dtype) + initial_arr = lax.convert_element_type(initial, lax.asarray(a).dtype) if initial_arr.shape != (): raise ValueError("initial value must be a scalar. " f"Got array of shape {initial_arr.shape}") @@ -166,7 +159,7 @@ def _reduction(a: ArrayLike, name: str, op: ReductionOp, init_val: ArrayLike, return lax.convert_element_type(result, dtype or result_dtype) def _canonicalize_axis_allow_named(x, rank): - return maybe_named_axis(x, lambda i: _canonicalize_axis(i, rank), lambda name: name) + return maybe_named_axis(x, lambda i: canonicalize_axis(i, rank), lambda name: name) def _reduction_dims(a: ArrayLike, axis: Axis): if axis is None: @@ -183,10 +176,10 @@ def _reduction_dims(a: ArrayLike, axis: Axis): else: return canon_axis, canon_axis -def _reduction_init_val(a: ArrayLike, init_val: Any) -> np.ndarray: +def _reduction_init_val(a: Array, init_val: Any) -> np.ndarray: # This function uses np.* functions because lax pattern matches against the # specific concrete values of the reduction inputs. - a_dtype = dtypes.canonicalize_dtype(dtypes.dtype(a)) + a_dtype = a.dtype if a_dtype == 'bool': return np.array(init_val > 0, dtype=a_dtype) if (np.isinf(init_val) and dtypes.issubdtype(a_dtype, np.floating) @@ -209,7 +202,7 @@ def _cast_to_numeric(operand: Array) -> Array: return promote_dtypes_numeric(operand)[0] def _require_integer(arr: Array) -> Array: - if not dtypes.isdtype(arr, ("bool", "integral")): + if not dtypes.isdtype(arr.dtype, ("bool", "integral")): raise ValueError(f"integer argument required; got dtype={arr.dtype}") return arr @@ -225,7 +218,7 @@ def force(x): force, x, "The axis argument must be known statically.") -@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims', 'promote_integers'), inline=True) +@api.jit(static_argnames=('axis', 'dtype', 'keepdims', 'promote_integers'), inline=True) def _reduce_sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None, @@ -233,7 +226,7 @@ def _reduce_sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, return _reduction(a, "sum", lax.add, 0, preproc=_cast_to_numeric, bool_op=lax.bitwise_or, upcast_f16_for_computation=(dtype is None), axis=axis, dtype=dtype, out=out, keepdims=keepdims, - initial=initial, where_=where, parallel_reduce=lax.psum, + initial=initial, where_=where, parallel_reduce=lax_parallel.psum, promote_integers=promote_integers) @@ -313,7 +306,7 @@ def sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, -@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims', 'promote_integers'), inline=True) +@api.jit(static_argnames=('axis', 'dtype', 'keepdims', 'promote_integers'), inline=True) def _reduce_prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None, @@ -400,13 +393,13 @@ def prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, promote_integers=promote_integers) -@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True) -def _reduce_max(a: ArrayLike, axis: Axis = None, out: None = None, - keepdims: bool = False, initial: ArrayLike | None = None, - where: ArrayLike | None = None) -> Array: +@api.jit(static_argnames=('axis', 'keepdims'), inline=True) +def _reduce_max(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, + initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: return _reduction(a, "max", lax.max, -np.inf, has_identity=False, - axis=axis, out=out, keepdims=keepdims, - initial=initial, where_=where, parallel_reduce=lax.pmax) + axis=axis, dtype=dtype, out=out, keepdims=keepdims, + initial=initial, where_=where, parallel_reduce=lax_parallel.pmax) @export @@ -483,13 +476,13 @@ def max(a: ArrayLike, axis: Axis = None, out: None = None, return _reduce_max(a, axis=_ensure_optional_axes(axis), out=out, keepdims=keepdims, initial=initial, where=where) -@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True) -def _reduce_min(a: ArrayLike, axis: Axis = None, out: None = None, - keepdims: bool = False, initial: ArrayLike | None = None, - where: ArrayLike | None = None) -> Array: +@api.jit(static_argnames=('axis', 'keepdims', 'dtype'), inline=True) +def _reduce_min(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, + initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: return _reduction(a, "min", lax.min, np.inf, has_identity=False, - axis=axis, out=out, keepdims=keepdims, - initial=initial, where_=where, parallel_reduce=lax.pmin) + axis=axis, dtype=dtype, out=out, keepdims=keepdims, + initial=initial, where_=where, parallel_reduce=lax_parallel.pmin) @export @@ -565,7 +558,7 @@ def min(a: ArrayLike, axis: Axis = None, out: None = None, return _reduce_min(a, axis=_ensure_optional_axes(axis), out=out, keepdims=keepdims, initial=initial, where=where) -@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True) +@api.jit(static_argnames=('axis', 'keepdims'), inline=True) def _reduce_all(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: return _reduction(a, "all", lax.bitwise_and, True, preproc=_cast_to_bool, @@ -622,7 +615,7 @@ def all(a: ArrayLike, axis: Axis = None, out: None = None, return _reduce_all(a, axis=_ensure_optional_axes(axis), out=out, keepdims=keepdims, where=where) -@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True) +@api.jit(static_argnames=('axis', 'keepdims'), inline=True) def _reduce_any(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: return _reduction(a, "any", lax.bitwise_or, False, preproc=_cast_to_bool, @@ -680,18 +673,18 @@ def any(a: ArrayLike, axis: Axis = None, out: None = None, keepdims=keepdims, where=where) -@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True) +@api.jit(static_argnames=('axis', 'keepdims', 'dtype'), inline=True) def _reduce_bitwise_and(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: - arr = lax_internal.asarray(a) - init_val = np.array(-1, dtype=dtype or arr.dtype) + arr = lax.asarray(a) + init_val = np.array(-1).astype(dtype or arr.dtype) return _reduction(arr, name="reduce_bitwise_and", op=lax.bitwise_and, init_val=init_val, preproc=_require_integer, axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims, initial=initial, where_=where) -@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True) +@api.jit(static_argnames=('axis', 'keepdims', 'dtype'), inline=True) def _reduce_bitwise_or(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: @@ -700,7 +693,7 @@ def _reduce_bitwise_or(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None initial=initial, where_=where) -@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True) +@api.jit(static_argnames=('axis', 'keepdims', 'dtype'), inline=True) def _reduce_bitwise_xor(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: @@ -709,7 +702,7 @@ def _reduce_bitwise_xor(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None initial=initial, where_=where) -@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True) +@api.jit(static_argnames=('axis', 'keepdims', 'dtype'), inline=True) def _reduce_logical_and(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: @@ -718,7 +711,7 @@ def _reduce_logical_and(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None initial=initial, where_=where) -@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True) +@api.jit(static_argnames=('axis', 'keepdims', 'dtype'), inline=True) def _reduce_logical_or(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: @@ -727,7 +720,7 @@ def _reduce_logical_or(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None initial=initial, where_=where) -@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True) +@api.jit(static_argnames=('axis', 'keepdims', 'dtype'), inline=True) def _reduce_logical_xor(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: @@ -742,8 +735,11 @@ def _logsumexp(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, """Compute log(sum(exp(a))) while avoiding precision loss.""" if out is not None: raise NotImplementedError("The 'out' argument to jnp.logaddexp.reduce is not supported.") - dtypes.check_user_dtype_supported(dtype, "jnp.logaddexp.reduce") - check_arraylike("logsumexp", a) + if dtype is not None: + dtype = dtypes.check_and_canonicalize_user_dtype(dtype, "jnp.logaddexp.reduce") + # TODO(phawkins): dtype isn't used here. That seems like a bug! + del dtype + a = ensure_arraylike("logsumexp", a) where = check_where("logsumexp", where) a_arr, = promote_dtypes_inexact(a) pos_dims, dims = _reduction_dims(a_arr, axis) @@ -753,7 +749,7 @@ def _logsumexp(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, exp_a = lax.exp(lax.sub(a_arr, amax_with_dims.astype(a_arr.dtype))) sumexp = exp_a.sum(axis=dims, keepdims=keepdims, where=where) result = lax.add(lax.log(sumexp), amax.astype(sumexp.dtype)) - return result if initial is None else lax.logaddexp(initial, result) + return result if initial is None else lax_other.logaddexp(initial, result) def _logsumexp2(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, @@ -762,8 +758,10 @@ def _logsumexp2(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, """Compute log2(sum(2 ** a)) via logsumexp.""" if out is not None: raise NotImplementedError("The 'out' argument to jnp.logaddexp2.reduce is not supported.") - dtypes.check_user_dtype_supported(dtype, "jnp.logaddexp2.reduce") - check_arraylike("logsumexp2", a) + if dtype is not None: + dtype = dtypes.check_and_canonicalize_user_dtype( + dtype, "jnp.logaddexp2.reduce") + a = ensure_arraylike("logsumexp2", a) where = check_where("logsumexp2", where) ln2 = float(np.log(2)) if initial is not None: @@ -771,7 +769,6 @@ def _logsumexp2(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, return _logsumexp(a * ln2, axis=axis, dtype=dtype, keepdims=keepdims, where=where, initial=initial) / ln2 - @export def amin(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, @@ -796,7 +793,7 @@ def _axis_size(a: ArrayLike, axis: int | Sequence[int]): size = 1 a_shape = np.shape(a) for a in axis_seq: - size *= maybe_named_axis(a, lambda i: a_shape[i], lambda name: lax.psum(1, name)) + size *= maybe_named_axis(a, lambda i: a_shape[i], lax_parallel.axis_size) return size @@ -867,40 +864,55 @@ def mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, return _mean(a, _ensure_optional_axes(axis), dtype, out, keepdims, where=where, upcast_f16_for_computation=(dtype is None)) -@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims', 'upcast_f16_for_computation'), +def _count( + a: ArrayLike, + axis: Axis, + keepdims: bool, + where: ArrayLike | None, + dtype: DTypeLike, +): + if where is None: + if axis is None: + count = core.dimension_as_value(np.size(a)) + else: + count = core.dimension_as_value(_axis_size(a, axis)) + count = lax.convert_element_type(count, dtype) + else: + count = sum(_broadcast_to(where, np.shape(a)), axis, dtype=dtype, keepdims=keepdims) + return count + +@api.jit(static_argnames=('axis', 'dtype', 'keepdims', 'upcast_f16_for_computation'), inline=True) def _mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, *, upcast_f16_for_computation: bool = True, where: ArrayLike | None = None) -> Array: - check_arraylike("mean", a) + a = ensure_arraylike("mean", a) where = check_where("mean", where) if out is not None: raise NotImplementedError("The 'out' argument to jnp.mean is not supported.") if dtype is None: - result_dtype = dtypes.to_inexact_dtype(dtypes.dtype(a, canonicalize=True)) + result_dtype = dtypes.to_inexact_dtype(a.dtype) else: - dtypes.check_user_dtype_supported(dtype, "mean") - result_dtype = dtypes.canonicalize_dtype(dtype) + result_dtype = dtypes.check_and_canonicalize_user_dtype(dtype, "mean") if upcast_f16_for_computation and dtypes.issubdtype(result_dtype, np.inexact): computation_dtype = _upcast_f16(result_dtype) else: computation_dtype = result_dtype - if where is None: - if axis is None: - normalizer = core.dimension_as_value(np.size(a)) - else: - normalizer = core.dimension_as_value(_axis_size(a, axis)) - else: - normalizer = sum(_broadcast_to(where, np.shape(a)), axis, - dtype=computation_dtype, keepdims=keepdims) + normalizer = _count( + a, + axis=axis, + keepdims=keepdims, + where=where, + dtype=computation_dtype, + ) return lax.div( sum(a, axis, dtype=computation_dtype, keepdims=keepdims, where=where), - lax.convert_element_type(normalizer, computation_dtype) + normalizer, ).astype(result_dtype) @overload @@ -923,8 +935,9 @@ def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, a: array to be averaged axis: an optional integer or sequence of integers specifying the axis along which the mean to be computed. If not specified, mean is computed along all the axes. - weights: an optional array of weights for a weighted average. Must be - broadcast-compatible with ``a``. + weights: an optional array of weights for a weighted average. This must either exactly + match the shape of `a`, or if `axis` is specified, it must have shape ``a.shape[axis]`` + for a single axis, or shape ``tuple(a.shape[ax] for ax in axis)`` for multiple axes. returned: If False (default) then return only the average. If True then return both the average and the normalization factor (i.e. the sum of weights). keepdims: If True, reduced axes are left in the result with size 1. If False (default) @@ -968,50 +981,32 @@ def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, """ return _average(a, _ensure_optional_axes(axis), weights, returned, keepdims) -@partial(api.jit, static_argnames=('axis', 'returned', 'keepdims'), inline=True) +@api.jit(static_argnames=('axis', 'returned', 'keepdims'), inline=True) def _average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, returned: bool = False, keepdims: bool = False) -> Array | tuple[Array, Array]: + axis = None if axis is None else canonicalize_axis_tuple(axis, np.ndim(a)) + if weights is None: # Treat all weights as 1 - check_arraylike("average", a) + a = ensure_arraylike("average", a) a, = promote_dtypes_inexact(a) avg = mean(a, axis=axis, keepdims=keepdims) if axis is None: weights_sum = lax.full((), core.dimension_as_value(a.size), dtype=avg.dtype) elif isinstance(axis, tuple): - weights_sum = lax.full_like(avg, math.prod(core.dimension_as_value(a.shape[d]) for d in axis)) - else: - weights_sum = lax.full_like(avg, core.dimension_as_value(a.shape[axis])) # type: ignore[index] + weights_sum = lax.full((), math.prod(core.dimension_as_value(a.shape[d]) for d in axis), dtype=avg.dtype) else: - check_arraylike("average", a, weights) + a, weights = ensure_arraylike("average", a, weights) a, weights = promote_dtypes_inexact(a, weights) - a_shape = np.shape(a) - a_ndim = len(a_shape) - weights_shape = np.shape(weights) - - if axis is None: - pass - elif isinstance(axis, tuple): - axis = tuple(_canonicalize_axis(d, a_ndim) for d in axis) - else: - axis = _canonicalize_axis(axis, a_ndim) - - if a_shape != weights_shape: - # Make sure the dimensions work out - if len(weights_shape) != 1: - raise ValueError("1D weights expected when shapes of a and " - "weights differ.") + if a.shape != weights.shape: if axis is None: raise ValueError("Axis must be specified when shapes of a and " "weights differ.") - elif isinstance(axis, tuple): - raise ValueError("Single axis expected when shapes of a and weights differ") - elif not core.definitely_equal(weights_shape[0], a_shape[axis]): - raise ValueError("Length of weights not " - "compatible with specified axis.") - - weights = _broadcast_to(weights, (a_ndim - 1) * (1,) + weights_shape) - weights = _moveaxis(weights, -1, axis) + if weights.shape != tuple(a.shape[ax] for ax in axis): + raise ValueError("Shape of weights must be consistent with shape " + "of a along specified axis.") + new_shape = tuple(dim if i in axis else 1 for i, dim in enumerate(a.shape)) + weights = lax.reshape(weights, new_shape, dimensions=tuple(np.argsort(axis))) weights_sum = sum(weights, axis=axis, keepdims=keepdims) avg = sum(a * weights, axis=axis, keepdims=keepdims) / weights_sum @@ -1026,7 +1021,8 @@ def _average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, @export def var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, ddof: int = 0, keepdims: bool = False, *, - where: ArrayLike | None = None, correction: int | float | None = None) -> Array: + where: ArrayLike | None = None, mean: ArrayLike | None = None, + correction: int | float | None = None) -> Array: r"""Compute the variance along a given axis. JAX implementation of :func:`numpy.var`. @@ -1042,6 +1038,11 @@ def var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, with size 1. where: optional, boolean array, default=None. The elements to be used in the variance. Array should be broadcast compatible to the input. + mean: optional, mean of the input array, computed along the given axis. + If provided, it will be used to compute the variance instead of + computing it from the input array. If specified, mean must be broadcast-compatible + with the input array. In the general case, this can be achieved by computing the mean with + ``keepdims=True`` and ``axis`` matching this function's ``axis`` argument. correction: int or float, default=None. Alternative name for ``ddof``. Both ddof and correction can't be provided simultaneously. out: Unused by JAX. @@ -1104,22 +1105,27 @@ def var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, correction = ddof elif not isinstance(ddof, int) or ddof != 0: raise ValueError("ddof and correction can't be provided simultaneously.") - return _var(a, _ensure_optional_axes(axis), dtype, out, correction, keepdims, - where=where) - -@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) -def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, - out: None = None, correction: int | float = 0, keepdims: bool = False, *, - where: ArrayLike | None = None) -> Array: - check_arraylike("var", a) + a = ensure_arraylike("var", a) + return _var(a, axis=_ensure_optional_axes(axis), dtype=dtype, out=out, correction=correction, keepdims=keepdims, + where=where, a_mean=mean) + +@api.jit(static_argnames=('axis', 'dtype', 'keepdims')) +def _var(a: Array, *, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None, correction: int | float = 0, keepdims: bool = False, + where: ArrayLike | None = None, a_mean: ArrayLike | None = None) -> Array: where = check_where("var", where) - dtypes.check_user_dtype_supported(dtype, "var") + if dtype is not None: + dtype = dtypes.check_and_canonicalize_user_dtype(dtype, "var") if out is not None: raise NotImplementedError("The 'out' argument to jnp.var is not supported.") - computation_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype) - a = lax_internal.asarray(a).astype(computation_dtype) - a_mean = mean(a, axis, dtype=computation_dtype, keepdims=True, where=where) + computation_dtype, dtype = _var_promote_types(a.dtype, dtype) + a = lax.asarray(a).astype(computation_dtype) + if a_mean is None: + a_mean = mean(a, axis, dtype=computation_dtype, keepdims=True, where=where) + else: + a_mean = ensure_arraylike("var", a_mean).astype(computation_dtype) + centered = lax.sub(a, a_mean) if dtypes.issubdtype(computation_dtype, np.complexfloating): centered = lax.real(lax.mul(centered, lax.conj(centered))) @@ -1127,19 +1133,18 @@ def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, else: centered = lax.square(centered) - if where is None: - if axis is None: - normalizer = core.dimension_as_value(np.size(a)) - else: - normalizer = core.dimension_as_value(_axis_size(a, axis)) - normalizer = lax.convert_element_type(normalizer, computation_dtype) - else: - normalizer = sum(_broadcast_to(where, np.shape(a)), axis, - dtype=computation_dtype, keepdims=keepdims) + normalizer = _count( + a, + axis=axis, + keepdims=keepdims, + where=where, + dtype=computation_dtype, + ) + normalizer = lax.sub(normalizer, lax.convert_element_type(correction, computation_dtype)) result = sum(centered, axis, dtype=computation_dtype, keepdims=keepdims, where=where) result = lax.div(result, normalizer).astype(dtype) - with jax.debug_nans(False): + with config.debug_nans(False): result = _where(normalizer > 0, result, np.nan) return result @@ -1160,7 +1165,7 @@ def _var_promote_types(a_dtype: DTypeLike, dtype: DTypeLike | None) -> tuple[DTy dtype = dtypes.to_inexact_dtype(a_dtype) computation_dtype = dtype else: - dtype = _complex_elem_type(a_dtype) + dtype = np.array(0, a_dtype).real.dtype computation_dtype = a_dtype return _upcast_f16(computation_dtype), np.dtype(dtype) @@ -1168,7 +1173,8 @@ def _var_promote_types(a_dtype: DTypeLike, dtype: DTypeLike | None) -> tuple[DTy @export def std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, ddof: int = 0, keepdims: bool = False, *, - where: ArrayLike | None = None, correction: int | float | None = None) -> Array: + where: ArrayLike | None = None, mean: ArrayLike | None = None, + correction: int | float | None = None) -> Array: r"""Compute the standard deviation along a given axis. JAX implementation of :func:`numpy.std`. @@ -1185,6 +1191,11 @@ def std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, with size 1. where: optional, boolean array, default=None. The elements to be used in the standard deviation. Array should be broadcast compatible to the input. + mean: optional, mean of the input array, computed along the given axis. + If provided, it will be used to compute the standard deviation instead of + computing it from the input array. If specified, mean must be broadcast-compatible + with the input array. In the general case, this can be achieved by computing the mean with + ``keepdims=True`` and ``axis`` matching this function's ``axis`` argument. correction: int or float, default=None. Alternative name for ``ddof``. Both ddof and correction can't be provided simultaneously. out: Unused by JAX. @@ -1242,21 +1253,23 @@ def std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, correction = ddof elif not isinstance(ddof, int) or ddof != 0: raise ValueError("ddof and correction can't be provided simultaneously.") - return _std(a, _ensure_optional_axes(axis), dtype, out, correction, keepdims, - where=where) - -@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) -def _std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, - out: None = None, correction: int | float = 0, keepdims: bool = False, *, - where: ArrayLike | None = None) -> Array: - check_arraylike("std", a) + a = ensure_arraylike("std", a) + return _std(a, axis=_ensure_optional_axes(axis), dtype=dtype, out=out, correction=correction, keepdims=keepdims, + where=where, mean=mean) + +@api.jit(static_argnames=('axis', 'dtype', 'keepdims')) +def _std(a: Array, *, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None, correction: int | float = 0, keepdims: bool = False, + where: ArrayLike | None = None, mean: ArrayLike | None = None) -> Array: where = check_where("std", where) - dtypes.check_user_dtype_supported(dtype, "std") - if dtype is not None and not dtypes.issubdtype(dtype, np.inexact): - raise ValueError(f"dtype argument to jnp.std must be inexact; got {dtype}") + if dtype is not None: + dtype = dtypes.check_and_canonicalize_user_dtype(dtype, "std") + if not dtypes.issubdtype(dtype, np.inexact): + raise ValueError(f"dtype argument to jnp.std must be inexact; got {dtype}") if out is not None: raise NotImplementedError("The 'out' argument to jnp.std is not supported.") - return lax.sqrt(var(a, axis=axis, dtype=dtype, correction=correction, keepdims=keepdims, where=where)) + return lax.sqrt(var(a, axis=axis, dtype=dtype, correction=correction, + keepdims=keepdims, where=where, mean=mean)) @export @@ -1298,12 +1311,12 @@ def ptp(a: ArrayLike, axis: Axis = None, out: None = None, [7], [6]], dtype=int32) """ + a = ensure_arraylike("ptp", a) return _ptp(a, _ensure_optional_axes(axis), out, keepdims) -@partial(api.jit, static_argnames=('axis', 'keepdims')) -def _ptp(a: ArrayLike, axis: Axis = None, out: None = None, +@api.jit(static_argnames=('axis', 'keepdims')) +def _ptp(a: Array, axis: Axis = None, out: None = None, keepdims: bool = False) -> Array: - check_arraylike("ptp", a) if out is not None: raise NotImplementedError("The 'out' argument to jnp.ptp is not supported.") x = amax(a, axis=axis, keepdims=keepdims) @@ -1312,7 +1325,7 @@ def _ptp(a: ArrayLike, axis: Axis = None, out: None = None, @export -@partial(api.jit, static_argnames=('axis', 'keepdims')) +@api.jit(static_argnames=('axis', 'keepdims')) def count_nonzero(a: ArrayLike, axis: Axis = None, keepdims: bool = False) -> Array: r"""Return the number of nonzero elements along a given axis. @@ -1350,31 +1363,31 @@ def count_nonzero(a: ArrayLike, axis: Axis = None, [1], [3]], dtype=int32) """ - check_arraylike("count_nonzero", a) - return sum(lax.ne(a, _lax_const(a, 0)), axis=axis, - dtype=dtypes.canonicalize_dtype(int), keepdims=keepdims) + a = ensure_arraylike("count_nonzero", a) + return sum(lax.ne(a, lax._const(a, 0)), axis=axis, + dtype=dtypes.default_int_dtype(), keepdims=keepdims) def _nan_reduction(a: ArrayLike, name: str, jnp_reduction: Callable[..., Array], init_val: ArrayLike, nan_if_all_nan: bool, axis: Axis = None, keepdims: bool = False, where: ArrayLike | None = None, **kwargs) -> Array: - check_arraylike(name, a) + a = ensure_arraylike(name, a) where = check_where(name, where) - if not dtypes.issubdtype(dtypes.dtype(a), np.inexact): + if not dtypes.issubdtype(a.dtype, np.inexact): return jnp_reduction(a, axis=axis, keepdims=keepdims, where=where, **kwargs) - out = jnp_reduction(_where(lax_internal._isnan(a), _reduction_init_val(a, init_val), a), + out = jnp_reduction(_where(lax._isnan(a), _reduction_init_val(a, init_val), a), axis=axis, keepdims=keepdims, where=where, **kwargs) if nan_if_all_nan: - return _where(all(lax_internal._isnan(a), axis=axis, keepdims=keepdims), - _lax_const(a, np.nan), out) + return _where(all(lax._isnan(a), axis=axis, keepdims=keepdims), + lax._const(a, np.nan), out) else: return out @export -@partial(api.jit, static_argnames=('axis', 'keepdims')) +@api.jit(static_argnames=('axis', 'keepdims')) def nanmin(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: @@ -1457,7 +1470,7 @@ def nanmin(a: ArrayLike, axis: Axis = None, out: None = None, @export -@partial(api.jit, static_argnames=('axis', 'keepdims')) +@api.jit(static_argnames=('axis', 'keepdims')) def nanmax(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: @@ -1540,8 +1553,9 @@ def nanmax(a: ArrayLike, axis: Axis = None, out: None = None, @export -@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) -def nansum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, +@api.jit(static_argnames=('axis', 'dtype', 'keepdims')) +def nansum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: r"""Return the sum of the array elements along a given axis, ignoring NaNs. @@ -1617,14 +1631,15 @@ def nansum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: >>> jnp.nansum(x, axis=0, keepdims=True, where=where) Array([[0., 0., 0., 0.]], dtype=float32) """ - dtypes.check_user_dtype_supported(dtype, "nanprod") + if dtype is not None: + dtype = dtypes.check_and_canonicalize_user_dtype(dtype, "nanprod") return _nan_reduction(a, 'nansum', sum, 0, nan_if_all_nan=False, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where) @export -@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) +@api.jit(static_argnames=('axis', 'dtype', 'keepdims')) def nanprod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: @@ -1701,14 +1716,15 @@ def nanprod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out >>> jnp.nanprod(x, axis=0, keepdims=True, where=where) Array([[1., 1., 1., 1.]], dtype=float32) """ - dtypes.check_user_dtype_supported(dtype, "nanprod") + if dtype is not None: + dtype = dtypes.check_and_canonicalize_user_dtype(dtype, "nanprod") return _nan_reduction(a, 'nanprod', prod, 1, nan_if_all_nan=False, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where) @export -@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) +@api.jit(static_argnames=('axis', 'dtype', 'keepdims')) def nanmean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, where: ArrayLike | None = None) -> Array: r"""Return the mean of the array elements along a given axis, ignoring NaNs. @@ -1783,28 +1799,27 @@ def nanmean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out >>> jnp.nanmean(x, axis=0, keepdims=True, where=where) Array([[nan, nan, nan, nan]], dtype=float32) """ - check_arraylike("nanmean", a) + a = ensure_arraylike("nanmean", a) where = check_where("nanmean", where) if out is not None: raise NotImplementedError("The 'out' argument to jnp.nanmean is not supported.") - if dtypes.issubdtype(dtypes.dtype(a), np.bool_) or dtypes.issubdtype(dtypes.dtype(a), np.integer): + if dtypes.issubdtype(a.dtype, np.bool_) or dtypes.issubdtype(a.dtype, np.integer): return mean(a, axis, dtype, out, keepdims, where=where) if dtype is None: - dtype = dtypes.to_inexact_dtype(dtypes.dtype(a, canonicalize=True)) + dtype = dtypes.to_inexact_dtype(a.dtype) else: - dtypes.check_user_dtype_supported(dtype, "mean") - dtype = dtypes.canonicalize_dtype(dtype) - nan_mask = lax_internal.bitwise_not(lax_internal._isnan(a)) + dtype = dtypes.check_and_canonicalize_user_dtype(dtype, "mean") + nan_mask = lax.bitwise_not(lax._isnan(a)) normalizer = sum(nan_mask, axis=axis, dtype=dtype, keepdims=keepdims, where=where) td = lax.div(nansum(a, axis, dtype=dtype, keepdims=keepdims, where=where), normalizer) return td @export -@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) +@api.jit(static_argnames=('axis', 'dtype', 'keepdims')) def nanvar(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, ddof: int = 0, keepdims: bool = False, - where: ArrayLike | None = None) -> Array: + where: ArrayLike | None = None, mean: ArrayLike | None = None) -> Array: r"""Compute the variance of array elements along a given axis, ignoring NaNs. JAX implementation of :func:`numpy.nanvar`. @@ -1820,6 +1835,11 @@ def nanvar(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: with size 1. where: optional, boolean array, default=None. The elements to be used in the variance. Array should be broadcast compatible to the input. + mean: optional, mean of the input array, computed along the given axis. + If provided, it will be used to compute the variance instead of + computing it from the input array. If specified, mean must be broadcast-compatible + with the input array. In the general case, this can be achieved by computing the mean with + ``keepdims=True`` and ``axis`` matching this function's ``axis`` argument. out: Unused by JAX. Returns: @@ -1877,26 +1897,34 @@ def nanvar(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: [0. ], [4. ]], dtype=float32) """ - check_arraylike("nanvar", a) + a = ensure_arraylike("nanvar", a) where = check_where("nanvar", where) - dtypes.check_user_dtype_supported(dtype, "nanvar") + if dtype is not None: + dtype = dtypes.check_and_canonicalize_user_dtype(dtype, "nanvar") if out is not None: raise NotImplementedError("The 'out' argument to jnp.nanvar is not supported.") + return _nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof, keepdims=keepdims, where=where, a_mean=mean) - computation_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype) - a = lax_internal.asarray(a).astype(computation_dtype) - a_mean = nanmean(a, axis, dtype=computation_dtype, keepdims=True, where=where) +def _nanvar(a: Array, *, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, + ddof: int = 0, keepdims: bool = False, + where: ArrayLike | None = None, a_mean: ArrayLike | None = None) -> Array: + computation_dtype, dtype = _var_promote_types(a.dtype, dtype) + a = lax.asarray(a).astype(computation_dtype) + if a_mean is None: + a_mean = nanmean(a, axis, dtype=computation_dtype, keepdims=True, where=where) + else: + a_mean = ensure_arraylike("nanvar", a_mean).astype(computation_dtype) - centered = _where(lax_internal._isnan(a), 0, lax.sub(a, a_mean)) # double-where trick for gradients. + centered = _where(lax._isnan(a), 0, lax.sub(a, a_mean)) # double-where trick for gradients. if dtypes.issubdtype(centered.dtype, np.complexfloating): centered = lax.real(lax.mul(centered, lax.conj(centered))) else: centered = lax.square(centered) - normalizer = sum(lax_internal.bitwise_not(lax_internal._isnan(a)), + normalizer = sum(lax.bitwise_not(lax._isnan(a)), axis=axis, keepdims=keepdims, where=where) normalizer = normalizer - ddof - normalizer_mask = lax.le(normalizer, lax_internal._zero(normalizer)) + normalizer_mask = lax.le(normalizer, lax._zero(normalizer)) result = sum(centered, axis, keepdims=keepdims, where=where) result = _where(normalizer_mask, np.nan, result) divisor = _where(normalizer_mask, 1, normalizer) @@ -1905,10 +1933,10 @@ def nanvar(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: @export -@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) +@api.jit(static_argnames=('axis', 'dtype', 'keepdims')) def nanstd(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, ddof: int = 0, keepdims: bool = False, - where: ArrayLike | None = None) -> Array: + where: ArrayLike | None = None, mean: ArrayLike | None = None) -> Array: r"""Compute the standard deviation along a given axis, ignoring NaNs. JAX implementation of :func:`numpy.nanstd`. @@ -1925,6 +1953,11 @@ def nanstd(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: with size 1. where: optional, boolean array, default=None. The elements to be used in the standard deviation. Array should be broadcast compatible to the input. + mean: optional, mean of the input array, computed along the given axis. + If provided, it will be used to compute the standard deviation instead of + computing it from the input array. If specified, mean must be broadcast-compatible + with the input array. In the general case, this can be achieved by computing the mean with + ``keepdims=True`` and ``axis`` matching this function's ``axis`` argument. out: Unused by JAX. Returns: @@ -1973,12 +2006,14 @@ def nanstd(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: >>> jnp.nanstd(x, axis=0, keepdims=True, where=where) Array([[0.5, 0.5, 0. , 0. ]], dtype=float32) """ - check_arraylike("nanstd", a) + a = ensure_arraylike("nanstd", a) where = check_where("nanstd", where) - dtypes.check_user_dtype_supported(dtype, "nanstd") + if dtype is not None: + dtype = dtypes.check_and_canonicalize_user_dtype(dtype, "nanstd") if out is not None: raise NotImplementedError("The 'out' argument to jnp.nanstd is not supported.") - return lax.sqrt(nanvar(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where)) + return lax.sqrt(nanvar(a, axis=axis, dtype=dtype, ddof=ddof, + keepdims=keepdims, where=where, mean=mean)) class CumulativeReduction(Protocol): @@ -1992,10 +2027,9 @@ def _cumulative_reduction( fill_nan: bool = False, fill_value: ArrayLike = 0, promote_integers: bool = False) -> Array: """Helper function for implementing cumulative reductions.""" - check_arraylike(name, a) + a = ensure_arraylike(name, a) if out is not None: raise NotImplementedError(f"The 'out' argument to jnp.{name} is not supported") - dtypes.check_user_dtype_supported(dtype, name) if axis is None or _isscalar(a): a = lax.reshape(a, (np.size(a),)) @@ -2004,19 +2038,24 @@ def _cumulative_reduction( a_shape = list(np.shape(a)) num_dims = len(a_shape) - axis = _canonicalize_axis(axis, num_dims) + axis = canonicalize_axis(axis, num_dims) if fill_nan: - a = _where(lax_internal._isnan(a), _lax_const(a, fill_value), a) + a = _where(lax._isnan(a), lax._const(a, fill_value), a) - a_type: DType = dtypes.dtype(a) - result_type: DTypeLike = dtypes.dtype(dtype or a) - if dtype is None and promote_integers or dtypes.issubdtype(result_type, np.bool_): - result_type = _promote_integer_dtype(result_type) - result_type = dtypes.canonicalize_dtype(result_type) + a_type: DType = a.dtype + result_type: DType + if dtype is None: + result_type = a_type + if promote_integers or dtypes.issubdtype(result_type, np.bool_): + result_type = _promote_integer_dtype(result_type) + else: + result_type = dtypes.check_and_canonicalize_user_dtype(dtype, name) + if dtypes.issubdtype(result_type, np.bool_): + result_type = _promote_integer_dtype(result_type) if a_type != np.bool_ and dtype == np.bool_: - a = lax_internal.asarray(a).astype(np.bool_) + a = lax.asarray(a).astype(np.bool_) a = lax.convert_element_type(a, result_type) result = reduction(a, axis) @@ -2028,7 +2067,7 @@ def _cumulative_reduction( @export -@partial(api.jit, static_argnames=('axis', 'dtype')) +@api.jit(static_argnames=('axis', 'dtype')) def cumsum(a: ArrayLike, axis: int | None = None, dtype: DTypeLike | None = None, out: None = None) -> Array: """Cumulative sum of elements along an axis. @@ -2061,11 +2100,11 @@ def cumsum(a: ArrayLike, axis: int | None = None, Array([[ 1, 3, 6], [ 4, 9, 15]], dtype=int32) """ - return _cumulative_reduction("cumsum", lax.cumsum, a, axis, dtype, out) + return _cumulative_reduction("cumsum", control_flow.cumsum, a, axis, dtype, out) @export -@partial(api.jit, static_argnames=('axis', 'dtype')) +@api.jit(static_argnames=('axis', 'dtype')) def cumprod(a: ArrayLike, axis: int | None = None, dtype: DTypeLike | None = None, out: None = None) -> Array: """Cumulative product of elements along an axis. @@ -2097,11 +2136,11 @@ def cumprod(a: ArrayLike, axis: int | None = None, Array([[ 1, 2, 6], [ 4, 20, 120]], dtype=int32) """ - return _cumulative_reduction("cumprod", lax.cumprod, a, axis, dtype, out) + return _cumulative_reduction("cumprod", control_flow.cumprod, a, axis, dtype, out) @export -@partial(api.jit, static_argnames=('axis', 'dtype')) +@api.jit(static_argnames=('axis', 'dtype')) def nancumsum(a: ArrayLike, axis: int | None = None, dtype: DTypeLike | None = None, out: None = None) -> Array: """Cumulative sum of elements along an axis, ignoring NaN values. @@ -2146,12 +2185,12 @@ def nancumsum(a: ArrayLike, axis: int | None = None, Array([[ 1., 3., 3.], [ 4., 4., 10.]], dtype=float32) """ - return _cumulative_reduction("nancumsum", lax.cumsum, a, axis, dtype, out, + return _cumulative_reduction("nancumsum", control_flow.cumsum, a, axis, dtype, out, fill_nan=True, fill_value=0) @export -@partial(api.jit, static_argnames=('axis', 'dtype')) +@api.jit(static_argnames=('axis', 'dtype')) def nancumprod(a: ArrayLike, axis: int | None = None, dtype: DTypeLike | None = None, out: None = None) -> Array: """Cumulative product of elements along an axis, ignoring NaN values. @@ -2195,15 +2234,15 @@ def nancumprod(a: ArrayLike, axis: int | None = None, Array([[ 1., 2., 2.], [ 4., 4., 24.]], dtype=float32) """ - return _cumulative_reduction("nancumprod", lax.cumprod, a, axis, dtype, out, + return _cumulative_reduction("nancumprod", control_flow.cumprod, a, axis, dtype, out, fill_nan=True, fill_value=1) -@partial(api.jit, static_argnames=('axis', 'dtype')) +@api.jit(static_argnames=('axis', 'dtype')) def _cumsum_with_promotion(a: ArrayLike, axis: int | None = None, dtype: DTypeLike | None = None, out: None = None) -> Array: """Utility function to compute cumsum with integer promotion.""" - return _cumulative_reduction("_cumsum_with_promotion", lax.cumsum, + return _cumulative_reduction("_cumsum_with_promotion", control_flow.cumsum, a, axis, dtype, out, promote_integers=True) @@ -2242,8 +2281,7 @@ def cumulative_sum( Array([[ 0, 1, 3, 6], [ 0, 4, 9, 15]], dtype=int32) """ - check_arraylike("cumulative_sum", x) - x = lax_internal.asarray(x) + x = ensure_arraylike("cumulative_sum", x) if x.ndim == 0: raise ValueError( "The input must be non-scalar to take a cumulative sum, however a " @@ -2257,14 +2295,15 @@ def cumulative_sum( "explicit value. The axis argument is only optional for one-dimensional " "arrays.") - axis = _canonicalize_axis(axis, x.ndim) - dtypes.check_user_dtype_supported(dtype) + axis = canonicalize_axis(axis, x.ndim) + if dtype is not None: + dtype = dtypes.check_and_canonicalize_user_dtype(dtype) out = _cumsum_with_promotion(x, axis=axis, dtype=dtype) if include_initial: zeros_shape = list(x.shape) zeros_shape[axis] = 1 - out = lax_internal.concatenate( - [lax_internal.full(zeros_shape, 0, dtype=out.dtype), out], + out = lax.concatenate( + [lax.full(zeros_shape, 0, dtype=out.dtype), out], dimension=axis) return out @@ -2304,8 +2343,7 @@ def cumulative_prod( Array([[ 1, 1, 2, 6], [ 1, 4, 20, 120]], dtype=int32) """ - check_arraylike("cumulative_prod", x) - x = lax_internal.asarray(x) + x = ensure_arraylike("cumulative_prod", x) if x.ndim == 0: raise ValueError( "The input must be non-scalar to take a cumulative product, however a " @@ -2319,25 +2357,25 @@ def cumulative_prod( "explicit value. The axis argument is only optional for one-dimensional " "arrays.") - axis = _canonicalize_axis(axis, x.ndim) - dtypes.check_user_dtype_supported(dtype) - out = _cumulative_reduction("cumulative_prod", lax.cumprod, x, axis, dtype) + axis = canonicalize_axis(axis, x.ndim) + if dtype is not None: + dtype = dtypes.check_and_canonicalize_user_dtype(dtype) + out = _cumulative_reduction("cumulative_prod", control_flow.cumprod, x, axis, dtype) if include_initial: zeros_shape = list(x.shape) zeros_shape[axis] = 1 - out = lax_internal.concatenate( - [lax_internal.full(zeros_shape, 1, dtype=out.dtype), out], + out = lax.concatenate( + [lax.full(zeros_shape, 1, dtype=out.dtype), out], dimension=axis) return out # Quantiles -# TODO(jakevdp): interpolation argument deprecated 2024-05-16 @export -@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) +@api.jit(static_argnames=('axis', 'overwrite_input', 'keepdims', 'method')) def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", - keepdims: bool = False, *, interpolation: DeprecatedArg | str = DeprecatedArg()) -> Array: + keepdims: bool = False) -> Array: """Compute the quantile of the data along the specified axis. JAX implementation of :func:`numpy.quantile`. @@ -2354,8 +2392,6 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = No default is ``linear``. keepdims: if True, then the returned array will have the same number of dimensions as the input. Default is False. - interpolation: deprecated alias of the ``method`` argument. Will result - in a :class:`DeprecationWarning` if used. Returns: An array containing the specified quantiles along the specified axes. @@ -2377,24 +2413,18 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = No >>> jnp.quantile(x, q, method='nearest') Array([2., 4., 7.], dtype=float32) """ - check_arraylike("quantile", a, q) + a, q = ensure_arraylike("quantile", a, q) if overwrite_input or out is not None: raise ValueError("jax.numpy.quantile does not support overwrite_input=True " "or out != None") - if not isinstance(interpolation, DeprecatedArg): - deprecations.warn( - "jax-numpy-quantile-interpolation", - ("The interpolation= argument to 'quantile' is deprecated. " - "Use 'method=' instead."), stacklevel=2) - method = interpolation - return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, method, keepdims, False) - -# TODO(jakevdp): interpolation argument deprecated 2024-05-16 + return _quantile(lax.asarray(a), lax.asarray(q), axis, method, keepdims, False) + + @export -@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) +@api.jit(static_argnames=('axis', 'overwrite_input', 'keepdims', 'method')) def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", - keepdims: bool = False, *, interpolation: DeprecatedArg | str = DeprecatedArg()) -> Array: + keepdims: bool = False) -> Array: """Compute the quantile of the data along the specified axis, ignoring NaNs. JAX implementation of :func:`numpy.nanquantile`. @@ -2411,8 +2441,6 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = default is ``linear``. keepdims: if True, then the returned array will have the same number of dimensions as the input. Default is False. - interpolation: deprecated alias of the ``method`` argument. Will result - in a :class:`DeprecationWarning` if used. Returns: An array containing the specified quantiles along the specified axes. @@ -2435,18 +2463,12 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = >>> jnp.nanquantile(x, q) Array([1.5, 3. , 4.5], dtype=float32) """ - check_arraylike("nanquantile", a, q) + a, q = ensure_arraylike("nanquantile", a, q) if overwrite_input or out is not None: msg = ("jax.numpy.nanquantile does not support overwrite_input=True or " "out != None") raise ValueError(msg) - if not isinstance(interpolation, DeprecatedArg): - deprecations.warn( - "jax-numpy-quantile-interpolation", - ("The interpolation= argument to 'nanquantile' is deprecated. " - "Use 'method=' instead."), stacklevel=2) - method = interpolation - return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, method, keepdims, True) + return _quantile(lax.asarray(a), lax.asarray(q), axis, method, keepdims, True) def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, method: str, keepdims: bool, squash_nans: bool) -> Array: @@ -2464,7 +2486,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, elif isinstance(axis, tuple): keepdim = list(a.shape) nd = a.ndim - axis = tuple(_canonicalize_axis(ax, nd) for ax in axis) + axis = tuple(canonicalize_axis(ax, nd) for ax in axis) if len(set(axis)) != len(axis): raise ValueError('repeated axis') for ax in axis: @@ -2478,9 +2500,9 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, do_not_touch_shape = tuple(x for idx,x in enumerate(a.shape) if idx not in axis) touch_shape = tuple(x for idx,x in enumerate(a.shape) if idx in axis) a = lax.reshape(a, do_not_touch_shape + (math.prod(touch_shape),), dimensions) - axis = _canonicalize_axis(-1, a.ndim) + axis = canonicalize_axis(-1, a.ndim) else: - axis = _canonicalize_axis(axis, a.ndim) + axis = canonicalize_axis(axis, a.ndim) q_shape = q.shape q_ndim = q.ndim @@ -2490,21 +2512,21 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, a_shape = a.shape if squash_nans: - a = _where(lax_internal._isnan(a), np.nan, a) # Ensure nans are positive so they sort to the end. + a = _where(lax._isnan(a), np.nan, a) # Ensure nans are positive so they sort to the end. a = lax.sort(a, dimension=axis) - counts = sum(lax_internal.bitwise_not(lax_internal._isnan(a)), axis=axis, dtype=q.dtype, keepdims=keepdims) + counts = sum(lax.bitwise_not(lax._isnan(a)), axis=axis, dtype=q.dtype, keepdims=keepdims) shape_after_reduction = counts.shape q = lax.expand_dims( q, tuple(range(q_ndim, len(shape_after_reduction) + q_ndim))) counts = lax.expand_dims(counts, tuple(range(q_ndim))) - q = lax.mul(q, lax.sub(counts, _lax_const(q, 1))) + q = lax.mul(q, lax.sub(counts, lax._const(q, 1))) low = lax.floor(q) high = lax.ceil(q) high_weight = lax.sub(q, low) - low_weight = lax.sub(_lax_const(high_weight, 1), high_weight) + low_weight = lax.sub(lax._const(high_weight, 1), high_weight) - low = lax.max(_lax_const(low, 0), lax.min(low, counts - 1)) - high = lax.max(_lax_const(high, 0), lax.min(high, counts - 1)) + low = lax.max(lax._const(low, 0), lax.min(low, counts - 1)) + high = lax.max(lax._const(high, 0), lax.min(high, counts - 1)) low = lax.convert_element_type(low, int) high = lax.convert_element_type(high, int) out_shape = q_shape + shape_after_reduction @@ -2518,33 +2540,33 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, index[axis] = high high_value = a[tuple(index)] else: - with jax.debug_nans(False): - a = _where(any(lax_internal._isnan(a), axis=axis, keepdims=True), np.nan, a) + with config.debug_nans(False): + a = _where(any(lax._isnan(a), axis=axis, keepdims=True), np.nan, a) a = lax.sort(a, dimension=axis) - n = lax.convert_element_type(a_shape[axis], lax_internal._dtype(q)) + n = lax.convert_element_type(a_shape[axis], lax._dtype(q)) q = lax.mul(q, n - 1) low = lax.floor(q) high = lax.ceil(q) high_weight = lax.sub(q, low) - low_weight = lax.sub(_lax_const(high_weight, 1), high_weight) + low_weight = lax.sub(lax._const(high_weight, 1), high_weight) - low = lax.clamp(_lax_const(low, 0), low, n - 1) - high = lax.clamp(_lax_const(high, 0), high, n - 1) + low = lax.clamp(lax._const(low, 0), low, n - 1) + high = lax.clamp(lax._const(high, 0), high, n - 1) low = lax.convert_element_type(low, int) high = lax.convert_element_type(high, int) slice_sizes = list(a_shape) slice_sizes[axis] = 1 - dnums = lax.GatherDimensionNumbers( + dnums = lax_slicing.GatherDimensionNumbers( offset_dims=tuple(range( q_ndim, len(a_shape) + q_ndim if keepdims else len(a_shape) + q_ndim - 1)), collapsed_slice_dims=() if keepdims else (axis,), start_index_map=(axis,)) - low_value = lax.gather(a, low[..., None], dimension_numbers=dnums, - slice_sizes=slice_sizes) - high_value = lax.gather(a, high[..., None], dimension_numbers=dnums, - slice_sizes=slice_sizes) + low_value = lax_slicing.gather(a, low[..., None], dimension_numbers=dnums, + slice_sizes=slice_sizes) + high_value = lax_slicing.gather(a, high[..., None], dimension_numbers=dnums, + slice_sizes=slice_sizes) if q_ndim == 1: low_weight = lax.broadcast_in_dim(low_weight, low_value.shape, broadcast_dimensions=(0,)) @@ -2559,10 +2581,10 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, elif method == "higher": result = high_value elif method == "nearest": - pred = lax.le(high_weight, _lax_const(high_weight, 0.5)) + pred = lax.le(high_weight, lax._const(high_weight, 0.5)) result = lax.select(pred, low_value, high_value) elif method == "midpoint": - result = lax.mul(lax.add(low_value, high_value), _lax_const(low_value, 0.5)) + result = lax.mul(lax.add(low_value, high_value), lax._const(low_value, 0.5)) else: raise ValueError(f"{method=!r} not recognized") if keepdims and keepdim: @@ -2572,13 +2594,12 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, return lax.convert_element_type(result, a.dtype) -# TODO(jakevdp): interpolation argument deprecated 2024-05-16 @export -@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) +@api.jit(static_argnames=('axis', 'overwrite_input', 'keepdims', 'method')) def percentile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", - keepdims: bool = False, *, interpolation: str | DeprecatedArg = DeprecatedArg()) -> Array: + keepdims: bool = False) -> Array: """Compute the percentile of the data along the specified axis. JAX implementation of :func:`numpy.percentile`. @@ -2595,8 +2616,6 @@ def percentile(a: ArrayLike, q: ArrayLike, default is ``linear``. keepdims: if True, then the returned array will have the same number of dimensions as the input. Default is False. - interpolation: deprecated alias of the ``method`` argument. Will result - in a :class:`DeprecationWarning` if used. Returns: An array containing the specified percentiles along the specified axes. @@ -2618,25 +2637,18 @@ def percentile(a: ArrayLike, q: ArrayLike, >>> jnp.percentile(x, q, method='nearest') Array([1., 3., 4.], dtype=float32) """ - check_arraylike("percentile", a, q) + a, q = ensure_arraylike("percentile", a, q) q, = promote_dtypes_inexact(q) - if not isinstance(interpolation, DeprecatedArg): - deprecations.warn( - "jax-numpy-quantile-interpolation", - ("The interpolation= argument to 'percentile' is deprecated. " - "Use 'method=' instead."), stacklevel=2) - method = interpolation return quantile(a, q / 100, axis=axis, out=out, overwrite_input=overwrite_input, method=method, keepdims=keepdims) -# TODO(jakevdp): interpolation argument deprecated 2024-05-16 @export -@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) +@api.jit(static_argnames=('axis', 'overwrite_input', 'keepdims', 'method')) def nanpercentile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", - keepdims: bool = False, *, interpolation: str | DeprecatedArg = DeprecatedArg()) -> Array: + keepdims: bool = False) -> Array: """Compute the percentile of the data along the specified axis, ignoring NaN values. JAX implementation of :func:`numpy.nanpercentile`. @@ -2653,8 +2665,6 @@ def nanpercentile(a: ArrayLike, q: ArrayLike, default is ``linear``. keepdims: if True, then the returned array will have the same number of dimensions as the input. Default is False. - interpolation: deprecated alias of the ``method`` argument. Will result - in a :class:`DeprecationWarning` if used. Returns: An array containing the specified percentiles along the specified axes. @@ -2678,21 +2688,15 @@ def nanpercentile(a: ArrayLike, q: ArrayLike, >>> jnp.nanpercentile(x, q) Array([1.5, 3. , 4.5], dtype=float32) """ - check_arraylike("nanpercentile", a, q) + a, q = ensure_arraylike("nanpercentile", a, q) q, = promote_dtypes_inexact(q) q = q / 100 - if not isinstance(interpolation, DeprecatedArg): - deprecations.warn( - "jax-numpy-quantile-interpolation", - ("The interpolation= argument to 'nanpercentile' is deprecated. " - "Use 'method=' instead."), stacklevel=2) - method = interpolation return nanquantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input, method=method, keepdims=keepdims) @export -@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'keepdims')) +@api.jit(static_argnames=('axis', 'overwrite_input', 'keepdims')) def median(a: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, keepdims: bool = False) -> Array: @@ -2738,13 +2742,13 @@ def median(a: ArrayLike, axis: int | tuple[int, ...] | None = None, [4. ], [4.5]], dtype=float32) """ - check_arraylike("median", a) + a = ensure_arraylike("median", a) return quantile(a, 0.5, axis=axis, out=out, overwrite_input=overwrite_input, keepdims=keepdims, method='midpoint') @export -@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'keepdims')) +@api.jit(static_argnames=('axis', 'overwrite_input', 'keepdims')) def nanmedian(a: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, keepdims: bool = False) -> Array: @@ -2795,7 +2799,7 @@ def nanmedian(a: ArrayLike, axis: int | tuple[int, ...] | None = None, [5. ], [3. ]], dtype=float32) """ - check_arraylike("nanmedian", a) + a = ensure_arraylike("nanmedian", a) return nanquantile(a, 0.5, axis=axis, out=out, overwrite_input=overwrite_input, keepdims=keepdims, method='midpoint') diff --git a/jax/_src/numpy/scalar_types.py b/jax/_src/numpy/scalar_types.py index 2f9954488b41..4ebf75b020a7 100644 --- a/jax/_src/numpy/scalar_types.py +++ b/jax/_src/numpy/scalar_types.py @@ -22,11 +22,12 @@ from typing import Any -import jax +import numpy as np + from jax._src.typing import Array from jax._src import core from jax._src import dtypes -import numpy as np +from jax._src.numpy.array_constructors import asarray # Some objects below rewrite their __module__ attribute to this name. @@ -36,6 +37,11 @@ class _ScalarMeta(type): dtype: np.dtype + @property + def __numpy_dtype__(self) -> np.dtype: + # __numpy_dtype__ protocol added in NumPy v2.4.0. + return self.dtype + def __hash__(self) -> int: return hash(self.dtype.type) @@ -46,7 +52,7 @@ def __ne__(self, other: Any) -> bool: return not (self == other) def __call__(self, x: Any) -> Array: - return jax.numpy.asarray(x, dtype=self.dtype) + return asarray(x, dtype=self.dtype) def __instancecheck__(self, instance: Any) -> bool: return isinstance(instance, self.dtype.type) @@ -68,33 +74,27 @@ def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta: return meta bool_ = _make_scalar_type(np.bool_) -if dtypes.uint2 is not None: - uint2 = _make_scalar_type(dtypes.uint2) +uint2 = _make_scalar_type(dtypes.uint2) uint4 = _make_scalar_type(dtypes.uint4) uint8 = _make_scalar_type(np.uint8) uint16 = _make_scalar_type(np.uint16) uint32 = _make_scalar_type(np.uint32) uint64 = _make_scalar_type(np.uint64) -if dtypes.int2 is not None: - int2 = _make_scalar_type(dtypes.int2) +int2 = _make_scalar_type(dtypes.int2) int4 = _make_scalar_type(dtypes.int4) int8 = _make_scalar_type(np.int8) int16 = _make_scalar_type(np.int16) int32 = _make_scalar_type(np.int32) int64 = _make_scalar_type(np.int64) -if dtypes.float8_e3m4 is not None: - float8_e3m4 = _make_scalar_type(dtypes.float8_e3m4) -if dtypes.float8_e4m3 is not None: - float8_e4m3 = _make_scalar_type(dtypes.float8_e4m3) -if dtypes.float8_e8m0fnu is not None: - float8_e8m0fnu = _make_scalar_type(dtypes.float8_e8m0fnu) +float4_e2m1fn = _make_scalar_type(dtypes.float4_e2m1fn) +float8_e3m4 = _make_scalar_type(dtypes.float8_e3m4) +float8_e4m3 = _make_scalar_type(dtypes.float8_e4m3) +float8_e8m0fnu = _make_scalar_type(dtypes.float8_e8m0fnu) float8_e4m3fn = _make_scalar_type(dtypes.float8_e4m3fn) float8_e4m3fnuz = _make_scalar_type(dtypes.float8_e4m3fnuz) float8_e5m2 = _make_scalar_type(dtypes.float8_e5m2) float8_e5m2fnuz = _make_scalar_type(dtypes.float8_e5m2fnuz) float8_e4m3b11fnuz = _make_scalar_type(dtypes.float8_e4m3b11fnuz) -if dtypes.float4_e2m1fn is not None: - float4_e2m1fn = _make_scalar_type(dtypes.float4_e2m1fn) bfloat16 = _make_scalar_type(dtypes.bfloat16) float16 = _make_scalar_type(np.float16) float32 = single = _make_scalar_type(np.float32) @@ -102,7 +102,7 @@ def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta: complex64 = csingle = _make_scalar_type(np.complex64) complex128 = cdouble = _make_scalar_type(np.complex128) -int_ = int32 if dtypes.int_ == np.int32 else int64 -uint = uint32 if dtypes.uint == np.uint32 else uint64 -float_: Any = float32 if dtypes.float_ == np.float32 else float64 -complex_ = complex64 if dtypes.complex_ == np.complex64 else complex128 +int_ = int64 +uint = uint64 +float_ = float64 +complex_ = complex128 diff --git a/jax/_src/numpy/setops.py b/jax/_src/numpy/setops.py index d4a8e41dd317..ce1b950c3c0c 100644 --- a/jax/_src/numpy/setops.py +++ b/jax/_src/numpy/setops.py @@ -14,20 +14,18 @@ from __future__ import annotations -from functools import partial import math import operator from typing import cast, NamedTuple import numpy as np -import jax -from jax import jit -from jax import lax - +from jax._src import api from jax._src import core from jax._src import dtypes -from jax._src.lax import lax as lax_internal +from jax._src.lax import lax +from jax._src.lax import slicing as lax_slicing +from jax._src.lax import utils as lax_utils from jax._src.numpy.array_creation import empty, full, full_like, ones, zeros from jax._src.numpy.lax_numpy import ( append, arange, concatenate, diff, @@ -42,10 +40,8 @@ export = set_module('jax.numpy') -_lax_const = lax_internal._const - -@partial(jit, static_argnames=('assume_unique', 'invert', 'method')) +@api.jit(static_argnames=('assume_unique', 'invert', 'method')) def _in1d(ar1: ArrayLike, ar2: ArrayLike, invert: bool, method='auto', assume_unique=False) -> Array: ar1, ar2 = ensure_arraylike("in1d", ar1, ar2) @@ -59,8 +55,10 @@ def _in1d(ar1: ArrayLike, ar2: ArrayLike, invert: bool, else: return (arr1[:, None] == arr2[None, :]).any(-1) elif method == 'binary_search': + from jax._src.numpy.lax_numpy import searchsorted + arr2 = lax.sort(arr2) - ind = jax.numpy.searchsorted(arr2, arr1) + ind = searchsorted(arr2, arr1) if invert: return arr1 != arr2[ind] else: @@ -86,8 +84,8 @@ def _concat_unique(arr1: Array, arr2: Array) -> tuple[Array, Array]: arr1, num_unique1 = _unique(arr1, axis=0, size=arr1.size, return_true_size=True) arr2, num_unique2 = _unique(arr2, axis=0, size=arr2.size, return_true_size=True) arr = zeros(arr1.size + arr2.size, dtype=dtypes.result_type(arr1, arr2)) - arr = lax.dynamic_update_slice(arr, arr1, (0,)) - arr = lax.dynamic_update_slice(arr, arr2, (num_unique1,)) + arr = lax_slicing.dynamic_update_slice(arr, arr1, (0,)) + arr = lax_slicing.dynamic_update_slice(arr, arr2, (num_unique1,)) return arr, num_unique1 + num_unique2 @@ -158,6 +156,9 @@ def setdiff1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, Array([1, 2, 0, 0], dtype=int32) """ arr1, arr2 = ensure_arraylike("setdiff1d", ar1, ar2) + arr1 = arr1.ravel() + arr2 = arr2.ravel() + if size is None: core.concrete_or_error(None, ar1, "The error arose in setdiff1d()") else: @@ -252,7 +253,7 @@ def union1d(ar1: ArrayLike, ar2: ArrayLike, return cast(Array, out) -@partial(jit, static_argnames=['assume_unique', 'size']) +@api.jit(static_argnames=['assume_unique', 'size']) def _setxor1d_size(arr1: Array, arr2: Array, fill_value: ArrayLike | None, *, assume_unique: bool, size: int, ) -> Array: # Ensured by caller @@ -339,14 +340,15 @@ def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, *, return aux[flag[1:] & flag[:-1]] -@partial(jit, static_argnames=['return_indices']) +@api.jit(static_argnames=['return_indices']) def _intersect1d_sorted_mask(arr1: Array, arr2: Array, return_indices: bool) -> tuple[Array, Array, Array | None]: """JIT-compatible helper function for intersect1d""" assert arr1.ndim == arr2.ndim == 1 arr = concatenate((arr1, arr2)) if return_indices: - iota = lax.broadcasted_iota(np.int64, np.shape(arr), dimension=0) + idx_dtype = lax_utils.int_dtype_for_dim(arr.shape[0], signed=True) + iota = lax.broadcasted_iota(idx_dtype, np.shape(arr), dimension=0) aux, indices = lax.sort_key_val(arr, iota) else: aux = sort(arr) @@ -355,7 +357,7 @@ def _intersect1d_sorted_mask(arr1: Array, arr2: Array, return aux, mask, indices -@partial(jit, static_argnames=['fill_value', 'assume_unique', 'size', 'return_indices']) +@api.jit(static_argnames=['fill_value', 'assume_unique', 'size', 'return_indices']) def _intersect1d_size(arr1: Array, arr2: Array, fill_value: ArrayLike | None, assume_unique: bool, size: int, return_indices: bool) -> Array | tuple[Array, Array, Array]: """Jit-compatible helper function for intersect1d with size specified.""" @@ -382,8 +384,8 @@ def _intersect1d_size(arr1: Array, arr2: Array, fill_value: ArrayLike | None, as arr1, ind1, num_unique1 = _unique(arr1, 0, size=arr1.size, return_index=True, return_true_size=True, fill_value=0) arr2, ind2, num_unique2 = _unique(arr2, 0, size=arr2.size, return_index=True, return_true_size=True, fill_value=0) arr = zeros(arr1.size + arr2.size, dtype=dtypes.result_type(arr1, arr2)) - arr = lax.dynamic_update_slice(arr, arr1, (0,)) - arr = lax.dynamic_update_slice(arr, arr2, (num_unique1,)) + arr = lax_slicing.dynamic_update_slice(arr, arr1, (0,)) + arr = lax_slicing.dynamic_update_slice(arr, arr2, (num_unique1,)) mask = arange(arr.size) < num_unique1 + num_unique2 _, aux, aux_sort_indices = lax.sort([~mask, arr, arange(arr.size)], is_stable=True, num_keys=2) @@ -572,14 +574,14 @@ def isin(element: ArrayLike, test_elements: ArrayLike, "To make jnp.unique() compatible with JIT and other transforms, you can specify " "a concrete value for the size argument, which will determine the output size.") -@partial(jit, static_argnames=['axis', 'equal_nan']) +@api.jit(static_argnames=['axis', 'equal_nan']) def _unique_sorted_mask(ar: Array, axis: int, equal_nan: bool) -> tuple[Array, Array, Array]: aux = moveaxis(ar, axis, 0) if np.issubdtype(aux.dtype, np.complexfloating): # Work around issue in sorting of complex numbers with Nan only in the # imaginary component. This can be removed if sorting in this situation # is fixed to match numpy. - aux = where(isnan(aux), _lax_const(aux, np.nan), aux) + aux = where(isnan(aux), lax._const(aux, np.nan), aux) size, *out_shape = aux.shape if math.prod(out_shape) == 0: size = 1 @@ -620,7 +622,7 @@ def _unique(ar: Array, axis: int, return_index: bool = False, return_inverse: bo ind = nonzero(mask, size=size)[0] result = aux[ind] if aux.size else aux if size is not None and fill_value is not None: - fill_value = lax_internal.asarray(fill_value).astype(result.dtype) + fill_value = lax.asarray(fill_value).astype(result.dtype) if result.shape[0]: valid = lax.expand_dims(arange(size) < mask.sum(), tuple(range(1, result.ndim))) result = where(valid, result, fill_value) @@ -637,7 +639,7 @@ def _unique(ar: Array, axis: int, return_index: bool = False, return_inverse: bo if return_inverse: if aux.size: imask = cumsum(mask) - 1 - inv_idx = zeros(mask.shape, dtype=dtypes.canonicalize_dtype(dtypes.int_)) + inv_idx = zeros(mask.shape, dtype=int) inv_idx = inv_idx.at[perm].set(imask) else: inv_idx = zeros(ar.shape[axis], dtype=int) @@ -651,7 +653,7 @@ def _unique(ar: Array, axis: int, return_index: bool = False, return_inverse: bo idx = idx.at[1:].set(where(idx[1:], idx[1:], mask.size)) ret += (diff(idx),) elif ar.shape[axis]: - ret += (full((1,), ar.shape[axis], dtype=dtypes.canonicalize_dtype(dtypes.int_)),) + ret += (full((1,), ar.shape[axis], dtype=int),) else: ret += (empty(0, dtype=int),) if return_true_size: diff --git a/jax/_src/numpy/sorting.py b/jax/_src/numpy/sorting.py index a0f368e2ef07..bb2a14acf541 100644 --- a/jax/_src/numpy/sorting.py +++ b/jax/_src/numpy/sorting.py @@ -12,24 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -from functools import partial -from typing import Sequence +from collections.abc import Sequence import numpy as np -import jax from jax._src import api -from jax._src import core from jax._src import dtypes +from jax._src.lax import lax +from jax._src.lax import utils as lax_utils from jax._src.numpy import util from jax._src.util import canonicalize_axis, set_module from jax._src.typing import Array, ArrayLike -from jax import lax export = set_module('jax.numpy') @export -@partial(api.jit, static_argnames=('axis', 'kind', 'order', 'stable', 'descending')) +@api.jit(static_argnames=('axis', 'kind', 'order', 'stable', 'descending')) def sort( a: ArrayLike, axis: int | None = -1, @@ -90,7 +88,7 @@ def sort( return lax.rev(result, dimensions=[dimension]) if descending else result @export -@partial(api.jit, static_argnames=('axis', 'kind', 'order', 'stable', 'descending')) +@api.jit(static_argnames=('axis', 'kind', 'order', 'stable', 'descending')) def argsort( a: ArrayLike, axis: int | None = -1, @@ -155,8 +153,12 @@ def argsort( arr = arr.ravel() axis = 0 dimension = canonicalize_axis(axis, arr.ndim) - use_64bit_index = not core.is_constant_dim(arr.shape[dimension]) or arr.shape[dimension] >= (1 << 31) - iota = lax.broadcasted_iota(np.dtype('int64') if use_64bit_index else dtypes.int_, arr.shape, dimension) + idx_dtype = lax_utils.int_dtype_for_dim(arr.shape[dimension], signed=True) + # We'd give the correct output values with int32, but use the default dtype to + # match NumPy type semantics if x64 mode is enabled for now. + if idx_dtype == np.dtype(np.int32): + idx_dtype = dtypes.default_int_dtype() + iota = lax.broadcasted_iota(idx_dtype, arr.shape, dimension) # For stable descending sort, we reverse the array and indices to ensure that # duplicates remain in their original order when the final indices are reversed. # For non-stable descending sort, we can avoid these extra operations. @@ -168,7 +170,7 @@ def argsort( @export -@partial(api.jit, static_argnames=['kth', 'axis']) +@api.jit(static_argnames=['kth', 'axis']) def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array: """Returns a partially-sorted copy of an array. @@ -226,7 +228,7 @@ def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array: axis = canonicalize_axis(axis, arr.ndim) kth = canonicalize_axis(kth, arr.shape[axis]) - arr = jax.numpy.swapaxes(arr, axis, -1) + arr = arr.swapaxes(axis, -1) if dtypes.isdtype(arr.dtype, "unsigned integer"): # Here, we apply a trick to handle correctly 0 values for unsigned integers bottom = -lax.top_k(-(arr + 1), kth + 1)[0] - 1 @@ -234,11 +236,11 @@ def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array: bottom = -lax.top_k(-arr, kth + 1)[0] top = lax.top_k(arr, arr.shape[-1] - kth - 1)[0] out = lax.concatenate([bottom, top], dimension=arr.ndim - 1) - return jax.numpy.swapaxes(out, -1, axis) + return out.swapaxes(-1, axis) @export -@partial(api.jit, static_argnames=['kth', 'axis']) +@api.jit(static_argnames=['kth', 'axis']) def argpartition(a: ArrayLike, kth: int, axis: int = -1) -> Array: """Returns indices that partially sort an array. @@ -297,7 +299,7 @@ def argpartition(a: ArrayLike, kth: int, axis: int = -1) -> Array: axis = canonicalize_axis(axis, arr.ndim) kth = canonicalize_axis(kth, arr.shape[axis]) - arr = jax.numpy.swapaxes(arr, axis, -1) + arr = arr.swapaxes(axis, -1) if dtypes.isdtype(arr.dtype, "unsigned integer"): # Here, we apply a trick to handle correctly 0 values for unsigned integers bottom_ind = lax.top_k(-(arr + 1), kth + 1)[1] @@ -307,11 +309,11 @@ def argpartition(a: ArrayLike, kth: int, axis: int = -1) -> Array: # To avoid issues with duplicate values, we compute the top indices via a proxy set_to_zero = lambda a, i: a.at[i].set(0) for _ in range(arr.ndim - 1): - set_to_zero = jax.vmap(set_to_zero) - proxy = set_to_zero(jax.numpy.ones(arr.shape), bottom_ind) + set_to_zero = api.vmap(set_to_zero) + proxy = set_to_zero(lax.full(arr.shape, 1.0), bottom_ind) top_ind = lax.top_k(proxy, arr.shape[-1] - kth - 1)[1] out = lax.concatenate([bottom_ind, top_ind], dimension=arr.ndim - 1) - return jax.numpy.swapaxes(out, -1, axis) + return out.swapaxes(-1, axis) @export @@ -353,7 +355,7 @@ def sort_complex(a: ArrayLike) -> Array: @export -@partial(api.jit, static_argnames=('axis',)) +@api.jit(static_argnames=('axis',)) def lexsort(keys: Array | np.ndarray | Sequence[ArrayLike], axis: int = -1) -> Array: """Sort a sequence of keys in lexicographic order. @@ -421,9 +423,13 @@ def lexsort(keys: Array | np.ndarray | Sequence[ArrayLike], axis: int = -1) -> A if len({np.shape(key) for key in key_arrays}) > 1: raise ValueError("all keys need to be the same shape") if np.ndim(key_arrays[0]) == 0: - return jax.numpy.array(0, dtype=dtypes.canonicalize_dtype(dtypes.int_)) + return lax.full((), 0, dtypes.default_int_dtype()) axis = canonicalize_axis(axis, np.ndim(key_arrays[0])) - use_64bit_index = key_arrays[0].shape[axis] >= (1 << 31) - iota = lax.broadcasted_iota(np.dtype('int64') if use_64bit_index else dtypes.int_, - np.shape(key_arrays[0]), axis) + idx_dtype = lax_utils.int_dtype_for_dim(key_arrays[0].shape[axis], + signed=True) + # We'd give the correct output values with int32, but use the default dtype to + # match NumPy type semantics if x64 mode is enabled for now. + if idx_dtype == np.dtype(np.int32): + idx_dtype = dtypes.default_int_dtype() + iota = lax.broadcasted_iota(idx_dtype, np.shape(key_arrays[0]), axis) return lax.sort((*key_arrays[::-1], iota), dimension=axis, num_keys=len(key_arrays))[-1] diff --git a/jax/_src/numpy/tensor_contractions.py b/jax/_src/numpy/tensor_contractions.py index 850eb90cf1d2..4b4238dce765 100644 --- a/jax/_src/numpy/tensor_contractions.py +++ b/jax/_src/numpy/tensor_contractions.py @@ -20,15 +20,13 @@ import numpy as np -import jax -from jax import lax +from jax._src import api from jax._src import core from jax._src import dtypes -from jax._src.api import jit -from jax._src.lax import lax as lax_internal -from jax._src.lax.lax import PrecisionLike +from jax._src.lax import lax from jax._src.numpy import ufuncs from jax._src.numpy import util +from jax._src.sharding_impls import NamedSharding, PartitionSpec as P from jax._src.numpy.vectorize import vectorize from jax._src.typing import Array, ArrayLike, DTypeLike from jax._src.util import canonicalize_axis, set_module @@ -36,10 +34,12 @@ export = set_module('jax.numpy') @export -@partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True) +@api.jit(static_argnames=('precision', 'preferred_element_type', 'out_sharding'), + inline=True) def dot(a: ArrayLike, b: ArrayLike, *, - precision: PrecisionLike = None, - preferred_element_type: DTypeLike | None = None) -> Array: + precision: lax.PrecisionLike = None, + preferred_element_type: DTypeLike | None = None, + out_sharding=None) -> Array: """Compute the dot product of two arrays. JAX implementation of :func:`numpy.dot`. @@ -102,10 +102,12 @@ def dot(a: ArrayLike, b: ArrayLike, *, (3, 2, 1) """ a, b = util.ensure_arraylike("dot", a, b) - dtypes.check_user_dtype_supported(preferred_element_type, "dot") if preferred_element_type is None: - preferred_element_type, output_weak_type = dtypes.result_type(a, b, return_weak_type_flag=True) + preferred_element_type, output_weak_type = dtypes.result_type( + a, b, return_weak_type_flag=True) else: + preferred_element_type = dtypes.check_and_canonicalize_user_dtype( + preferred_element_type, "dot") output_weak_type = False batch_dims = ((), ()) @@ -119,16 +121,22 @@ def dot(a: ArrayLike, b: ArrayLike, *, contract_dims = ((a_ndim - 1,), (b_ndim - 2,)) result = lax.dot_general(a, b, dimension_numbers=(contract_dims, batch_dims), precision=precision, - preferred_element_type=preferred_element_type) - return lax_internal._convert_element_type(result, preferred_element_type, - output_weak_type) + preferred_element_type=preferred_element_type, + out_sharding=out_sharding) + return lax._convert_element_type(result, preferred_element_type, + output_weak_type) @export -@partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True) +@partial( + api.jit, + static_argnames=('precision', 'preferred_element_type', 'out_sharding'), + inline=True, +) def matmul(a: ArrayLike, b: ArrayLike, *, - precision: PrecisionLike = None, + precision: lax.PrecisionLike = None, preferred_element_type: DTypeLike | None = None, + out_sharding: NamedSharding | P | None = None, ) -> Array: """Perform a matrix multiplication. @@ -184,7 +192,6 @@ def matmul(a: ArrayLike, b: ArrayLike, *, [49, 64]], dtype=int32) """ a, b = util.ensure_arraylike("matmul", a, b) - dtypes.check_user_dtype_supported(preferred_element_type, "matmul") for i, x in enumerate((a, b)): if np.ndim(x) < 1: msg = (f"matmul input operand {i} must have ndim at least 1, " @@ -193,6 +200,8 @@ def matmul(a: ArrayLike, b: ArrayLike, *, if preferred_element_type is None: preferred_element_type, output_weak_type = dtypes.result_type(a, b, return_weak_type_flag=True) else: + preferred_element_type = dtypes.check_and_canonicalize_user_dtype( + preferred_element_type, "matmul") output_weak_type = False a_is_mat, b_is_mat = (np.ndim(a) > 1), (np.ndim(b) > 1) @@ -234,21 +243,24 @@ def matmul(a: ArrayLike, b: ArrayLike, *, raise ValueError("Incompatible shapes for matmul arguments: {} and {}" .format(np.shape(a), np.shape(b))) - if a_is_mat: idx_a_other.append(num_batch_dims) - if b_is_mat: idx_b_other.append(num_batch_dims + a_is_mat) + if a_is_mat: + idx_a_other.append(num_batch_dims) + if b_is_mat: + idx_b_other.append(num_batch_dims + a_is_mat) perm = np.argsort(np.concatenate([idx_batch, idx_a_other, idx_b_other])) a = lax.squeeze(a, tuple(a_squeeze)) b = lax.squeeze(b, tuple(b_squeeze)) out = lax.dot_general( a, b, (((np.ndim(a) - 1,), (np.ndim(b) - 1 - b_is_mat,)), (a_batch, b_batch)), - precision=precision, preferred_element_type=preferred_element_type) + precision=precision, preferred_element_type=preferred_element_type, + out_sharding=out_sharding) result = lax.transpose(out, perm) - return lax_internal._convert_element_type(result, preferred_element_type, output_weak_type) + return lax._convert_element_type(result, preferred_element_type, output_weak_type) @export -@jit +@api.jit def matvec(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Batched matrix-vector product. @@ -284,12 +296,12 @@ def matvec(x1: ArrayLike, x2: ArrayLike, /) -> Array: Array([[ 50, 122], [ 38, 92]], dtype=int32) """ - util.check_arraylike("matvec", x1, x2) + x1, x2 = util.ensure_arraylike("matvec", x1, x2) return vectorize(matmul, signature="(n,m),(m)->(n)")(x1, x2) @export -@jit +@api.jit def vecmat(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Batched conjugate vector-matrix product. @@ -326,15 +338,15 @@ def vecmat(x1: ArrayLike, x2: ArrayLike, /) -> Array: Array([[ 40, 46], [ 94, 109]], dtype=int32) """ - util.check_arraylike("matvec", x1, x2) + x1, x2 = util.ensure_arraylike("matvec", x1, x2) return vectorize(matmul, signature="(n),(n,m)->(m)")(ufuncs.conj(x1), x2) @export -@partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True) +@api.jit(static_argnames=('precision', 'preferred_element_type'), inline=True) def vdot( a: ArrayLike, b: ArrayLike, *, - precision: PrecisionLike = None, + precision: lax.PrecisionLike = None, preferred_element_type: DTypeLike | None = None, ) -> Array: """Perform a conjugate multiplication of two 1D vectors. @@ -372,16 +384,16 @@ def vdot( >>> jnp.dot(x, y) Array(0.+14.j, dtype=complex64) """ - util.check_arraylike("vdot", a, b) - if dtypes.issubdtype(dtypes.dtype(a, canonicalize=True), np.complexfloating): + a, b = util.ensure_arraylike("vdot", a, b) + if dtypes.issubdtype(a.dtype, np.complexfloating): a = ufuncs.conj(a) - return dot(jax.numpy.ravel(a), jax.numpy.ravel(b), precision=precision, + return dot(a.ravel(), b.ravel(), precision=precision, preferred_element_type=preferred_element_type) @export def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1, - precision: PrecisionLike = None, + precision: lax.PrecisionLike = None, preferred_element_type: DTypeLike | None = None) -> Array: """Perform a conjugate multiplication of two batched vectors. @@ -426,11 +438,13 @@ def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1, >>> jnp.linalg.vecdot(a, b, axis=-1) Array([20, 47], dtype=int32) """ + from jax._src.numpy.lax_numpy import moveaxis + x1_arr, x2_arr = util.ensure_arraylike("jnp.vecdot", x1, x2) if x1_arr.shape[axis] != x2_arr.shape[axis]: raise ValueError(f"axes must match; got shapes {x1_arr.shape} and {x2_arr.shape} with {axis=}") - x1_arr = jax.numpy.moveaxis(x1_arr, axis, -1) - x2_arr = jax.numpy.moveaxis(x2_arr, axis, -1) + x1_arr = moveaxis(x1_arr, axis, -1) + x2_arr = moveaxis(x2_arr, axis, -1) return vectorize(partial(vdot, precision=precision, preferred_element_type=preferred_element_type), signature="(n),(n)->()")(x1_arr, x2_arr) @@ -438,8 +452,9 @@ def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1, @export def tensordot(a: ArrayLike, b: ArrayLike, axes: int | Sequence[int] | Sequence[Sequence[int]] = 2, - *, precision: PrecisionLike = None, - preferred_element_type: DTypeLike | None = None) -> Array: + *, precision: lax.PrecisionLike = None, + preferred_element_type: DTypeLike | None = None, + out_sharding: NamedSharding | P | None = None) -> Array: """Compute the tensor dot product of two N-dimensional arrays. JAX implementation of :func:`numpy.linalg.tensordot`. @@ -512,13 +527,14 @@ def tensordot(a: ArrayLike, b: ArrayLike, [2, 4, 6]], dtype=int32) """ a, b = util.ensure_arraylike("tensordot", a, b) - dtypes.check_user_dtype_supported(preferred_element_type, "tensordot") a_ndim = np.ndim(a) b_ndim = np.ndim(b) if preferred_element_type is None: preferred_element_type, output_weak_type = dtypes.result_type(a, b, return_weak_type_flag=True) else: + preferred_element_type = dtypes.check_and_canonicalize_user_dtype( + preferred_element_type, "tensordot") output_weak_type = False if type(axes) is int: @@ -545,16 +561,18 @@ def tensordot(a: ArrayLike, b: ArrayLike, msg = ("tensordot axes argument must be an int, a pair of ints, or a pair " "of lists/tuples of ints.") raise TypeError(msg) - result = lax.dot_general(a, b, (contracting_dims, ((), ())), - precision=precision, preferred_element_type=preferred_element_type) - return lax_internal._convert_element_type(result, preferred_element_type, output_weak_type) + result = lax.dot_general( + a, b, (contracting_dims, ((), ())), precision=precision, + preferred_element_type=preferred_element_type, + out_sharding=out_sharding) + return lax._convert_element_type(result, preferred_element_type, output_weak_type) @export -@partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True) +@api.jit(static_argnames=('precision', 'preferred_element_type'), inline=True) def inner( - a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = None, + a: ArrayLike, b: ArrayLike, *, precision: lax.PrecisionLike = None, preferred_element_type: DTypeLike | None = None, ) -> Array: """Compute the inner product of two arrays. @@ -601,15 +619,16 @@ def inner( """ a, b = util.ensure_arraylike("inner", a, b) if np.ndim(a) == 0 or np.ndim(b) == 0: - a = jax.numpy.asarray(a, dtype=preferred_element_type) - b = jax.numpy.asarray(b, dtype=preferred_element_type) + if preferred_element_type is not None: + a = a.astype(preferred_element_type) + b = b.astype(preferred_element_type) return a * b return tensordot(a, b, (-1, -1), precision=precision, preferred_element_type=preferred_element_type) @export -@partial(jit, inline=True) +@api.jit(inline=True) def outer(a: ArrayLike, b: ArrayLike, out: None = None) -> Array: """Compute the outer product of two arrays. @@ -638,6 +657,6 @@ def outer(a: ArrayLike, b: ArrayLike, out: None = None) -> Array: """ if out is not None: raise NotImplementedError("The 'out' argument to jnp.outer is not supported.") - util.check_arraylike("outer", a, b) + a, b = util.ensure_arraylike("outer", a, b) a, b = util.promote_dtypes(a, b) - return jax.numpy.ravel(a)[:, None] * jax.numpy.ravel(b)[None, :] + return a.ravel()[:, None] * b.ravel()[None, :] diff --git a/jax/_src/numpy/ufunc_api.py b/jax/_src/numpy/ufunc_api.py index c488855b70fa..b3278b12976c 100644 --- a/jax/_src/numpy/ufunc_api.py +++ b/jax/_src/numpy/ufunc_api.py @@ -17,21 +17,23 @@ from __future__ import annotations from collections.abc import Callable -from functools import partial import math import operator from typing import Any -import jax +import numpy as np + +from jax._src import api from jax._src.typing import Array, ArrayLike, DTypeLike -from jax._src.lax import lax as lax_internal +from jax._src.lax import control_flow +from jax._src.lax import slicing +from jax._src.lax import lax from jax._src.numpy import indexing -import jax._src.numpy.lax_numpy as jnp +from jax._src.numpy import lax_numpy as jnp from jax._src.numpy.reductions import _moveaxis from jax._src.numpy.util import check_arraylike, _broadcast_to, _where from jax._src.numpy.vectorize import vectorize from jax._src.util import canonicalize_axis, set_module -import numpy as np export = set_module("jax.numpy") @@ -90,7 +92,7 @@ class ufunc: [ 5, 6, 7, 8, 9], [ 6, 7, 8, 9, 10]], dtype=int32) - The :meth:`ufunc.reduce` method perfoms a reduction over the array. + The :meth:`ufunc.reduce` method performs a reduction over the array. For example, :meth:`jnp.add.reduce` is equivalent to ``jnp.sum``: >>> jnp.add.reduce(x) @@ -110,7 +112,7 @@ class ufunc: Array([101, 2, 3, 4, 5], dtype=int32) And the :meth:`ufunc.reduceat` method performs a number of ``reduce`` - operations bewteen specified indices of an array; for ``jnp.add`` the + operations between specified indices of an array; for ``jnp.add`` the operation is similar to :func:`jax.ops.segment_sum`: >>> jnp.add.reduceat(x, jnp.array([0, 2])) @@ -179,12 +181,12 @@ def __call__(self, *args: ArrayLike, out: None = None, where: None = None) -> An call = self.__static_props['call'] or self._call_vectorized return call(*args) - @partial(jax.jit, static_argnames=['self']) + @api.jit(static_argnames=['self']) def _call_vectorized(self, *args): return vectorize(self._func)(*args) - @partial(jax.jit, static_argnames=['self', 'axis', 'dtype', 'out', 'keepdims']) - def reduce(self, a: ArrayLike, axis: int = 0, + @api.jit(static_argnames=['self', 'axis', 'dtype', 'out', 'keepdims']) + def reduce(self, a: ArrayLike, axis: int | None = 0, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: @@ -249,8 +251,8 @@ def reduce(self, a: ArrayLike, axis: int = 0, if self.identity is None and initial is None: raise ValueError(f"reduction operation {self.__name__!r} does not have an identity, " "so to use a where mask one has to specify 'initial'.") - if lax_internal._dtype(where) != bool: - raise ValueError(f"where argument must have dtype=bool; got dtype={lax_internal._dtype(where)}") + if lax._dtype(where) != bool: + raise ValueError(f"where argument must have dtype=bool; got dtype={lax._dtype(where)}") reduce = self.__static_props['reduce'] or self._reduce_via_scan return reduce(a, axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where) @@ -258,11 +260,11 @@ def _reduce_via_scan(self, arr: ArrayLike, axis: int | None = 0, dtype: DTypeLik keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: assert self.nin == 2 and self.nout == 1 - arr = lax_internal.asarray(arr) + arr = lax.asarray(arr) if initial is None: initial = self.identity if dtype is None: - dtype = jax.eval_shape(self._func, lax_internal._one(arr), lax_internal._one(arr)).dtype + dtype = api.eval_shape(self._func, lax._one(arr), lax._one(arr)).dtype if where is not None: where = _broadcast_to(where, arr.shape) if isinstance(axis, tuple): @@ -290,8 +292,10 @@ def _reduce_via_scan(self, arr: ArrayLike, axis: int | None = 0, dtype: DTypeLik if where is not None: where = _moveaxis(where, axis, 0) - if initial is None and arr.shape[0] == 0: - raise ValueError("zero-size array to reduction operation {self.__name__} which has no ideneity") + if arr.shape[0] == 0: + if initial is None: + raise ValueError(f"zero-size array to reduction operation {self.__name__} which has no identity") + return lax.full(final_shape, initial, dtype) def body_fun(i, val): if where is None: @@ -306,15 +310,15 @@ def body_fun(i, val): else: start_index = 0 start_value = initial - start_value = _broadcast_to(lax_internal.asarray(start_value).astype(dtype), arr.shape[1:]) + start_value = _broadcast_to(lax.asarray(start_value).astype(dtype), arr.shape[1:]) - result = jax.lax.fori_loop(start_index, arr.shape[0], body_fun, start_value) + result = control_flow.fori_loop(start_index, arr.shape[0], body_fun, start_value) if keepdims: result = result.reshape(final_shape) return result - @partial(jax.jit, static_argnames=['self', 'axis', 'dtype']) + @api.jit(static_argnames=['self', 'axis', 'dtype']) def accumulate(self, a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None, out: None = None) -> Array: """Accumulate operation derived from binary ufunc. @@ -376,24 +380,26 @@ def _accumulate_via_scan(self, arr: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None) -> Array: assert self.nin == 2 and self.nout == 1 check_arraylike(f"{self.__name__}.accumulate", arr) - arr = lax_internal.asarray(arr) + arr = lax.asarray(arr) if dtype is None: - dtype = jax.eval_shape(self._func, lax_internal._one(arr), lax_internal._one(arr)).dtype + dtype = api.eval_shape(self._func, lax._one(arr), lax._one(arr)).dtype if axis is None or isinstance(axis, tuple): raise ValueError("accumulate does not allow multiple axes") axis = canonicalize_axis(axis, np.ndim(arr)) + if arr.size == 0: + return lax.full(arr.shape, 0, dtype) arr = _moveaxis(arr, axis, 0) def scan_fun(carry, _): i, x = carry y = _where(i == 0, arr[0].astype(dtype), self(x.astype(dtype), arr[i].astype(dtype))) return (i + 1, y), y - _, result = jax.lax.scan(scan_fun, (0, arr[0].astype(dtype)), None, length=arr.shape[0]) + _, result = control_flow.scan(scan_fun, (0, arr[0].astype(dtype)), None, length=arr.shape[0]) return _moveaxis(result, 0, axis) - @partial(jax.jit, static_argnums=[0], static_argnames=['inplace']) + @api.jit(static_argnums=[0], static_argnames=['inplace']) def at(self, a: ArrayLike, indices: Any, b: ArrayLike | None = None, /, *, inplace: bool = True) -> Array: """Update elements of an array via the specified unary or binary ufunc. @@ -440,15 +446,15 @@ def at(self, a: ArrayLike, indices: Any, b: ArrayLike | None = None, /, *, def _at_via_scan(self, a: ArrayLike, indices: Any, *args: Any) -> Array: assert len(args) in {0, 1} check_arraylike(f"{self.__name__}.at", a, *args) - dtype = jax.eval_shape(self._func, lax_internal._one(a), *(lax_internal._one(arg) for arg in args)).dtype - a = lax_internal.asarray(a).astype(dtype) - args = tuple(lax_internal.asarray(arg).astype(dtype) for arg in args) + dtype = api.eval_shape(self._func, lax._one(a), *(lax._one(arg) for arg in args)).dtype + a = lax.asarray(a).astype(dtype) + args = tuple(lax.asarray(arg).astype(dtype) for arg in args) indices = indexing.eliminate_deprecated_list_indexing(indices) if not indices: return a shapes = [np.shape(i) for i in indices if not isinstance(i, slice)] - shape = shapes and jax.lax.broadcast_shapes(*shapes) + shape = shapes and lax.broadcast_shapes(*shapes) if not shape: return a.at[indices].set(self(a.at[indices].get(), *args)) @@ -462,10 +468,10 @@ def scan_fun(carry, x): idx = tuple(ind if isinstance(ind, slice) else ind[i] for ind in indices) a = a.at[idx].set(self(a.at[idx].get(), *(arg[i] for arg in args))) return (i + 1, a), x - carry, _ = jax.lax.scan(scan_fun, (0, a), None, len(indices[0])) # type: ignore[arg-type] + carry, _ = control_flow.scan(scan_fun, (0, a), None, len(indices[0])) # type: ignore[arg-type] return carry[1] - @partial(jax.jit, static_argnames=['self', 'axis', 'dtype']) + @api.jit(static_argnames=['self', 'axis', 'dtype']) def reduceat(self, a: ArrayLike, indices: Any, axis: int = 0, dtype: DTypeLike | None = None, out: None = None) -> Array: """Reduce an array between specified indices via a binary ufunc. @@ -517,7 +523,7 @@ def reduceat(self, a: ArrayLike, indices: Any, axis: int = 0, def _reduceat_via_scan(self, a: ArrayLike, indices: Any, axis: int = 0, dtype: DTypeLike | None = None) -> Array: check_arraylike(f"{self.__name__}.reduceat", a, indices) - a = lax_internal.asarray(a) + a = lax.asarray(a) idx_tuple = indexing.eliminate_deprecated_list_indexing(indices) assert len(idx_tuple) == 1 indices = idx_tuple[0] @@ -531,17 +537,17 @@ def _reduceat_via_scan(self, a: ArrayLike, indices: Any, axis: int = 0, raise ValueError("reduceat requires a single integer axis.") axis = canonicalize_axis(axis, a.ndim) out = indexing.take(a, indices, axis=axis) - ind = jax.lax.expand_dims(jnp.append(indices, a.shape[axis]), - list(np.delete(np.arange(out.ndim), axis))) - ind_start = jax.lax.slice_in_dim(ind, 0, ind.shape[axis] - 1, axis=axis) - ind_end = jax.lax.slice_in_dim(ind, 1, ind.shape[axis], axis=axis) + ind = lax.expand_dims(jnp.append(indices, a.shape[axis]), + list(np.delete(np.arange(out.ndim), axis))) + ind_start = slicing.slice_in_dim(ind, 0, ind.shape[axis] - 1, axis=axis) + ind_end = slicing.slice_in_dim(ind, 1, ind.shape[axis], axis=axis) def loop_body(i, out): return _where((i > ind_start) & (i < ind_end), - self(out, indexing.take(a, jax.lax.expand_dims(i, (0,)), axis=axis)), + self(out, indexing.take(a, lax.expand_dims(i, (0,)), axis=axis)), out) - return jax.lax.fori_loop(0, a.shape[axis], loop_body, out) + return control_flow.fori_loop(0, a.shape[axis], loop_body, out) - @partial(jax.jit, static_argnums=[0]) + @api.jit(static_argnums=[0]) def outer(self, A: ArrayLike, B: ArrayLike, /) -> Array: """Apply the function to all pairs of values in ``A`` and ``B``. @@ -572,7 +578,7 @@ def outer(self, A: ArrayLike, B: ArrayLike, /) -> Array: [ 10 20 30 40 50 60 70 80 90 100]] For input arrays with ``N`` and ``M`` dimensions respectively, the output - will have dimesion ``N + M``: + will have dimension ``N + M``: >>> x = jnp.ones((1, 3, 5)) >>> y = jnp.ones((2, 4)) @@ -584,8 +590,8 @@ def outer(self, A: ArrayLike, B: ArrayLike, /) -> Array: if self.nout != 1: raise ValueError("outer only supported for functions returning a single value") check_arraylike(f"{self.__name__}.outer", A, B) - _ravel = lambda A: jax.lax.reshape(A, (np.size(A),)) - result = jax.vmap(jax.vmap(self, (None, 0)), (0, None))(_ravel(A), _ravel(B)) + _ravel = lambda A: lax.reshape(A, (np.size(A),)) + result = api.vmap(api.vmap(self, (None, 0)), (0, None))(_ravel(A), _ravel(B)) return result.reshape(*np.shape(A), *np.shape(B)) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 91191d24a12e..8ee55a6c4b55 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -32,12 +32,14 @@ from jax._src.lax import lax from jax._src.lax import other as lax_other from jax._src.typing import Array, ArrayLike +from jax._src.numpy import array_constructors +from jax._src.numpy import error as jnp_error +from jax._src.numpy import reductions +from jax._src.numpy.ufunc_api import ufunc from jax._src.numpy.util import ( - check_arraylike, promote_args, promote_args_inexact, + check_arraylike, ensure_arraylike, promote_args, promote_args_inexact, promote_args_numeric, promote_dtypes_inexact, promote_dtypes_numeric, promote_shapes, _where, check_no_float0s) -from jax._src.numpy.ufunc_api import ufunc -from jax._src.numpy import reductions from jax._src.util import set_module @@ -52,7 +54,7 @@ } def _constant_like(x, const): - return np.array(const, dtype=dtypes.dtype(x)) + return array_constructors.array(const, dtype=dtypes.dtype(x)) def _replace_inf(x: ArrayLike) -> Array: return lax.select(isposinf(real(x)), lax._zeros(x), x) @@ -80,7 +82,7 @@ def decorator(func: Callable[[ArrayLike, ArrayLike], Array]) -> ufunc: @export -@partial(jit, inline=True) +@jit(inline=True) def fabs(x: ArrayLike, /) -> Array: """Compute the element-wise absolute values of the real-valued input. @@ -118,28 +120,28 @@ def fabs(x: ArrayLike, /) -> Array: >>> jnp.fabs(x2) Array([1., 0.], dtype=float32) """ - check_arraylike('fabs', x) - if dtypes.issubdtype(dtypes.dtype(x), np.complexfloating): + x = ensure_arraylike('fabs', x) + if dtypes.issubdtype(x.dtype, np.complexfloating): raise TypeError("ufunc 'fabs' does not support complex dtypes") return lax.abs(*promote_args_inexact('fabs', x)) @export -@partial(jit, inline=True) +@jit(inline=True) def bitwise_invert(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.invert`.""" return lax.bitwise_not(*promote_args('bitwise_invert', x)) @export -@partial(jit, inline=True) +@jit(inline=True) def bitwise_not(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.invert`.""" return lax.bitwise_not(*promote_args('bitwise_not', x)) @export -@partial(jit, inline=True) +@jit(inline=True) def invert(x: ArrayLike, /) -> Array: """Compute the bitwise inversion of an input. @@ -232,7 +234,7 @@ def negative(x: ArrayLike, /) -> Array: @export -@partial(jit, inline=True) +@jit(inline=True) def positive(x: ArrayLike, /) -> Array: """Return element-wise positive values of the input. @@ -281,7 +283,7 @@ def positive(x: ArrayLike, /) -> Array: @export -@partial(jit, inline=True) +@jit(inline=True) def sign(x: ArrayLike, /) -> Array: r"""Return an element-wise indication of sign of the input. @@ -296,7 +298,7 @@ def sign(x: ArrayLike, /) -> Array: -1, & x < 0 \end{cases} - For complex valued input, ``jnp.sign`` returns a unit vector repesenting the + For complex valued input, ``jnp.sign`` returns a unit vector representing the phase. For generalized case, the sign of ``x`` is given by: .. math:: @@ -332,7 +334,7 @@ def sign(x: ArrayLike, /) -> Array: @export -@partial(jit, inline=True) +@jit(inline=True) def floor(x: ArrayLike, /) -> Array: """Round input to the nearest integer downwards. @@ -346,8 +348,8 @@ def floor(x: ArrayLike, /) -> Array: the nearest integer that is less than or equal to the value itself. See also: - - :func:`jax.numpy.fix`: Rounds the input to the nearest interger towards zero. - - :func:`jax.numpy.trunc`: Rounds the input to the nearest interger towards + - :func:`jax.numpy.fix`: Rounds the input to the nearest integer towards zero. + - :func:`jax.numpy.trunc`: Rounds the input to the nearest integer towards zero. - :func:`jax.numpy.ceil`: Rounds the input up to the nearest integer. @@ -364,14 +366,14 @@ def floor(x: ArrayLike, /) -> Array: [ 0., -1., 0.], [-5., 2., 1.]], dtype=float32) """ - check_arraylike('floor', x) - if dtypes.isdtype(dtypes.dtype(x), ('integral', 'bool')): - return lax.asarray(x) + x = ensure_arraylike('floor', x) + if dtypes.isdtype(x.dtype, ('integral', 'bool')): + return x return lax.floor(*promote_args_inexact('floor', x)) @export -@partial(jit, inline=True) +@jit(inline=True) def ceil(x: ArrayLike, /) -> Array: """Round input to the nearest integer upwards. @@ -385,8 +387,8 @@ def ceil(x: ArrayLike, /) -> Array: the nearest integer that is greater than or equal to the value itself. See also: - - :func:`jax.numpy.fix`: Rounds the input to the nearest interger towards zero. - - :func:`jax.numpy.trunc`: Rounds the input to the nearest interger towards + - :func:`jax.numpy.fix`: Rounds the input to the nearest integer towards zero. + - :func:`jax.numpy.trunc`: Rounds the input to the nearest integer towards zero. - :func:`jax.numpy.floor`: Rounds the input down to the nearest integer. @@ -403,14 +405,14 @@ def ceil(x: ArrayLike, /) -> Array: [-0., 4., 1.], [ 5., 4., -1.]], dtype=float32) """ - check_arraylike('ceil', x) - if dtypes.isdtype(dtypes.dtype(x), ('integral', 'bool')): + x = ensure_arraylike('ceil', x) + if dtypes.isdtype(x.dtype, ('integral', 'bool')): return lax.asarray(x) return lax.ceil(*promote_args_inexact('ceil', x)) @export -@partial(jit, inline=True) +@jit(inline=True) def exp(x: ArrayLike, /) -> Array: """Calculate element-wise exponential of the input. @@ -452,7 +454,7 @@ def exp(x: ArrayLike, /) -> Array: @export -@partial(jit, inline=True) +@jit(inline=True) def log(x: ArrayLike, /) -> Array: """Calculate element-wise natural logarithm of the input. @@ -486,11 +488,13 @@ def log(x: ArrayLike, /) -> Array: >>> jnp.allclose(jnp.log(x1*x2), jnp.log(x1)+jnp.log(x2)) Array(True, dtype=bool) """ - return lax.log(*promote_args_inexact('log', x)) + out = lax.log(*promote_args_inexact('log', x)) + jnp_error._set_error_if_nan(out) + return out @export -@partial(jit, inline=True) +@jit(inline=True) def expm1(x: ArrayLike, /) -> Array: """Calculate ``exp(x)-1`` of each element of the input. @@ -535,7 +539,7 @@ def expm1(x: ArrayLike, /) -> Array: @export -@partial(jit, inline=True) +@jit(inline=True) def log1p(x: ArrayLike, /) -> Array: """Calculates element-wise logarithm of one plus input, ``log(x+1)``. @@ -572,11 +576,13 @@ def log1p(x: ArrayLike, /) -> Array: >>> jnp.expm1(jnp.log(x1+1)) # doctest: +SKIP Array([1.000166e-04, 9.536743e-07, 0.000000e+00], dtype=float32) """ - return lax.log1p(*promote_args_inexact('log1p', x)) + out = lax.log1p(*promote_args_inexact('log1p', x)) + jnp_error._set_error_if_nan(out) + return out @export -@partial(jit, inline=True) +@jit(inline=True) def sin(x: ArrayLike, /) -> Array: """Compute a trigonometric sine of each element of input. @@ -604,11 +610,13 @@ def sin(x: ArrayLike, /) -> Array: ... print(jnp.sin(x)) [ 0.707 1. 0.707 -0. ] """ - return lax.sin(*promote_args_inexact('sin', x)) + out = lax.sin(*promote_args_inexact('sin', x)) + jnp_error._set_error_if_nan(out) + return out @export -@partial(jit, inline=True) +@jit(inline=True) def cos(x: ArrayLike, /) -> Array: """Compute a trigonometric cosine of each element of input. @@ -635,11 +643,13 @@ def cos(x: ArrayLike, /) -> Array: ... print(jnp.cos(x)) [ 0.707 -0. -0.707 -0.866] """ - return lax.cos(*promote_args_inexact('cos', x)) + out = lax.cos(*promote_args_inexact('cos', x)) + jnp_error._set_error_if_nan(out) + return out @export -@partial(jit, inline=True) +@jit(inline=True) def tan(x: ArrayLike, /) -> Array: """Compute a trigonometric tangent of each element of input. @@ -666,11 +676,13 @@ def tan(x: ArrayLike, /) -> Array: ... print(jnp.tan(x)) [ 0. 0.577 1. -1. -0.577] """ - return lax.tan(*promote_args_inexact('tan', x)) + out = lax.tan(*promote_args_inexact('tan', x)) + jnp_error._set_error_if_nan(out) + return out @export -@partial(jit, inline=True) +@jit(inline=True) def arcsin(x: ArrayLike, /) -> Array: r"""Compute element-wise inverse of trigonometric sine of input. @@ -708,11 +720,13 @@ def arcsin(x: ArrayLike, /) -> Array: ... jnp.arcsin(3+4j) Array(0.634+2.306j, dtype=complex64, weak_type=True) """ - return lax.asin(*promote_args_inexact('arcsin', x)) + out = lax.asin(*promote_args_inexact('arcsin', x)) + jnp_error._set_error_if_nan(out) + return out @export -@partial(jit, inline=True) +@jit(inline=True) def arccos(x: ArrayLike, /) -> Array: """Compute element-wise inverse of trigonometric cosine of input. @@ -751,11 +765,13 @@ def arccos(x: ArrayLike, /) -> Array: ... jnp.arccos(4-1j) Array(0.252+2.097j, dtype=complex64, weak_type=True) """ - return lax.acos(*promote_args_inexact('arccos', x)) + out = lax.acos(*promote_args_inexact('arccos', x)) + jnp_error._set_error_if_nan(out) + return out @export -@partial(jit, inline=True) +@jit(inline=True) def arctan(x: ArrayLike, /) -> Array: """Compute element-wise inverse of trigonometric tangent of input. @@ -796,7 +812,7 @@ def arctan(x: ArrayLike, /) -> Array: @export -@partial(jit, inline=True) +@jit(inline=True) def sinh(x: ArrayLike, /) -> Array: r"""Calculate element-wise hyperbolic sine of input. @@ -851,7 +867,7 @@ def sinh(x: ArrayLike, /) -> Array: @export -@partial(jit, inline=True) +@jit(inline=True) def cosh(x: ArrayLike, /) -> Array: r"""Calculate element-wise hyperbolic cosine of input. @@ -905,7 +921,7 @@ def cosh(x: ArrayLike, /) -> Array: @export -@partial(jit, inline=True) +@jit(inline=True) def arcsinh(x: ArrayLike, /) -> Array: r"""Calculate element-wise inverse of hyperbolic sine of input. @@ -1005,13 +1021,14 @@ def arccosh(x: ArrayLike, /) -> Array: # Note: arccosh is multi-valued for complex input, and lax.acosh # uses a different convention than np.arccosh. result = lax.acosh(*promote_args_inexact("arccosh", x)) + jnp_error._set_error_if_nan(result) if dtypes.issubdtype(result.dtype, np.complexfloating): result = _where(real(result) < 0, lax.neg(result), result) return result @export -@partial(jit, inline=True) +@jit(inline=True) def tanh(x: ArrayLike, /) -> Array: r"""Calculate element-wise hyperbolic tangent of input. @@ -1065,7 +1082,7 @@ def tanh(x: ArrayLike, /) -> Array: @export -@partial(jit, inline=True) +@jit(inline=True) def arctanh(x: ArrayLike, /) -> Array: r"""Calculate element-wise inverse of hyperbolic tangent of input. @@ -1110,11 +1127,13 @@ def arctanh(x: ArrayLike, /) -> Array: ... jnp.arctanh(x1) Array([-0.549+1.571j, 0.347+1.571j, 0.239-1.509j], dtype=complex64) """ - return lax.atanh(*promote_args_inexact('arctanh', x)) + out = lax.atanh(*promote_args_inexact('arctanh', x)) + jnp_error._set_error_if_nan(out) + return out @export -@partial(jit, inline=True) +@jit(inline=True) def sqrt(x: ArrayLike, /) -> Array: """Calculates element-wise non-negative square root of the input array. @@ -1143,11 +1162,13 @@ def sqrt(x: ArrayLike, /) -> Array: >>> jnp.sqrt(-1) Array(nan, dtype=float32, weak_type=True) """ - return lax.sqrt(*promote_args_inexact('sqrt', x)) + out = lax.sqrt(*promote_args_inexact('sqrt', x)) + jnp_error._set_error_if_nan(out) + return out @export -@partial(jit, inline=True) +@jit(inline=True) def cbrt(x: ArrayLike, /) -> Array: """Calculates element-wise cube root of the input array. @@ -1212,7 +1233,11 @@ def add(x: ArrayLike, y: ArrayLike, /) -> Array: Array([10, 11, 12, 13], dtype=int32) """ x, y = promote_args("add", x, y) - return lax.add(x, y) if x.dtype != bool else lax.bitwise_or(x, y) + if x.dtype == bool: + return lax.bitwise_or(x, y) + out = lax.add(x, y) + jnp_error._set_error_if_nan(out) + return out def _multiply_at(a: Array, indices: Any, b: ArrayLike) -> Array: @@ -1347,7 +1372,7 @@ def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: @export -@partial(jit, inline=True) +@jit(inline=True) def left_shift(x: ArrayLike, y: ArrayLike, /) -> Array: r"""Shift bits of ``x`` to left by the amount specified in ``y``, element-wise. @@ -1403,14 +1428,14 @@ def left_shift(x: ArrayLike, y: ArrayLike, /) -> Array: @export -@partial(jit, inline=True) +@jit(inline=True) def bitwise_left_shift(x: ArrayLike, y: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.left_shift`.""" return lax.shift_left(*promote_args_numeric("bitwise_left_shift", x, y)) @export -@partial(jit, inline=True) +@jit(inline=True) def equal(x: ArrayLike, y: ArrayLike, /) -> Array: """Returns element-wise truth value of ``x == y``. @@ -1460,7 +1485,7 @@ def equal(x: ArrayLike, y: ArrayLike, /) -> Array: @export -@partial(jit, inline=True) +@jit(inline=True) def not_equal(x: ArrayLike, y: ArrayLike, /) -> Array: """Returns element-wise truth value of ``x != y``. @@ -1541,11 +1566,13 @@ def subtract(x: ArrayLike, y: ArrayLike, /) -> Array: >>> x - 10 Array([-10, -9, -8, -7], dtype=int32) """ - return lax.sub(*promote_args("subtract", x, y)) + out = lax.sub(*promote_args("subtract", x, y)) + jnp_error._set_error_if_nan(out) + return out @export -@partial(jit, inline=True) +@jit(inline=True) def arctan2(x1: ArrayLike, x2: ArrayLike, /) -> Array: r"""Compute the arctangent of x1/x2, choosing the correct quadrant. @@ -1579,7 +1606,7 @@ def arctan2(x1: ArrayLike, x2: ArrayLike, /) -> Array: :math:`\tan(\theta) = y / x`, and compute :math:`\theta = \tan^{-1}(y/x)`. Unfortunately, this does not recover the input angle: - >>> with jnp.printoptions(precision=2, suppress=True): + >>> with jnp.printoptions(precision=2, suppress=True): # doctest: +SKIP ... print(jnp.arctan(y / x)) [-0. 0.79 1.57 -0.79 0. 0.79 1.57 -0.79 0. ] @@ -1595,13 +1622,12 @@ def arctan2(x1: ArrayLike, x2: ArrayLike, /) -> Array: The results match the input ``theta``, except at the endpoints where :math:`+\pi` and :math:`-\pi` represent indistinguishable points on the unit circle. By convention, - :func:`arctan2` alwasy returns values between :math:`-\pi` and :math:`+\pi` inclusive. + :func:`arctan2` always returns values between :math:`-\pi` and :math:`+\pi` inclusive. """ return lax.atan2(*promote_args_inexact("arctan2", x1, x2)) -@export -@partial(jit, inline=True) +@binary_ufunc(identity=None, reduce=reductions._reduce_min) def minimum(x: ArrayLike, y: ArrayLike, /) -> Array: """Return element-wise minimum of the input arrays. @@ -1661,8 +1687,7 @@ def minimum(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.min(*promote_args("minimum", x, y)) -@export -@partial(jit, inline=True) +@binary_ufunc(identity=None, reduce=reductions._reduce_max) def maximum(x: ArrayLike, y: ArrayLike, /) -> Array: """Return element-wise maximum of the input arrays. @@ -1686,7 +1711,7 @@ def maximum(x: ArrayLike, y: ArrayLike, /) -> Array: arrays. - :func:`jax.numpy.fmax`: Returns element-wise maximum of the input arrays, ignoring NaNs. - - :func:`jax.numpy.amax`: Retruns the maximum of array elements along a given + - :func:`jax.numpy.amax`: Returns the maximum of array elements along a given axis. - :func:`jax.numpy.nanmax`: Returns the maximum of the array elements along a given axis, ignoring NaNs. @@ -1722,7 +1747,7 @@ def maximum(x: ArrayLike, y: ArrayLike, /) -> Array: @export -@partial(jit, inline=True) +@jit(inline=True) def float_power(x: ArrayLike, y: ArrayLike, /) -> Array: """Calculate element-wise base ``x`` exponential of ``y``. @@ -1750,7 +1775,7 @@ def float_power(x: ArrayLike, y: ArrayLike, /) -> Array: >>> jnp.float_power(x, y) Array([ 9. , 1. , -0.2], dtype=float32) - Inputs with broacast compatibility: + Inputs with broadcast compatibility: >>> x1 = jnp.array([[2, -4, 1], ... [-1, 2, 3]]) @@ -1765,11 +1790,13 @@ def float_power(x: ArrayLike, y: ArrayLike, /) -> Array: >>> jnp.float_power(-3, 1.7) Array(nan, dtype=float32, weak_type=True) """ - return lax.pow(*promote_args_inexact("float_power", x, y)) + out = lax.pow(*promote_args_inexact("float_power", x, y)) + jnp_error._set_error_if_nan(out) + return out @export -@partial(jit, inline=True) +@jit(inline=True) def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array: """Return element-wise next floating point value after ``x`` towards ``y``. @@ -1797,7 +1824,7 @@ def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array: @export -@partial(jit, inline=True) +@jit(inline=True) def spacing(x: ArrayLike, /) -> Array: """Return the spacing between ``x`` and the next adjacent number. @@ -1905,7 +1932,7 @@ def logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array: @export -@partial(jit, inline=True) +@jit(inline=True) def logical_not(x: ArrayLike, /) -> Array: """Compute NOT bool(x) element-wise. @@ -1952,7 +1979,7 @@ def _complex_comparison(lax_op: Callable[[ArrayLike, ArrayLike], Array], @export -@partial(jit, inline=True) +@jit(inline=True) def greater_equal(x: ArrayLike, y: ArrayLike, /) -> Array: """Return element-wise truth value of ``x >= y``. @@ -1998,7 +2025,7 @@ def greater_equal(x: ArrayLike, y: ArrayLike, /) -> Array: @export -@partial(jit, inline=True) +@jit(inline=True) def greater(x: ArrayLike, y: ArrayLike, /) -> Array: """Return element-wise truth value of ``x > y``. @@ -2045,7 +2072,7 @@ def greater(x: ArrayLike, y: ArrayLike, /) -> Array: @export -@partial(jit, inline=True) +@jit(inline=True) def less_equal(x: ArrayLike, y: ArrayLike, /) -> Array: """Return element-wise truth value of ``x <= y``. @@ -2092,7 +2119,7 @@ def less_equal(x: ArrayLike, y: ArrayLike, /) -> Array: @export -@partial(jit, inline=True) +@jit(inline=True) def less(x: ArrayLike, y: ArrayLike, /) -> Array: """Return element-wise truth value of ``x < y``. @@ -2140,49 +2167,49 @@ def less(x: ArrayLike, y: ArrayLike, /) -> Array: # Array API aliases @export -@partial(jit, inline=True) +@jit(inline=True) def acos(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.arccos`""" return arccos(*promote_args('acos', x)) @export -@partial(jit, inline=True) +@jit(inline=True) def acosh(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.arccosh`""" return arccosh(*promote_args('acosh', x)) @export -@partial(jit, inline=True) +@jit(inline=True) def asin(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.arcsin`""" return arcsin(*promote_args('asin', x)) @export -@partial(jit, inline=True) +@jit(inline=True) def asinh(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.arcsinh`""" return arcsinh(*promote_args('asinh', x)) @export -@partial(jit, inline=True) +@jit(inline=True) def atan(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.arctan`""" return arctan(*promote_args('atan', x)) @export -@partial(jit, inline=True) +@jit(inline=True) def atanh(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.arctanh`""" return arctanh(*promote_args('atanh', x)) @export -@partial(jit, inline=True) +@jit(inline=True) def atan2(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.arctan2`""" return arctan2(*promote_args('atan2', x1, x2)) @@ -2226,7 +2253,7 @@ def bitwise_count(x: ArrayLike, /) -> Array: @export -@partial(jit, inline=True) +@jit(inline=True) def right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array: r"""Right shift the bits of ``x1`` to the amount specified in ``x2``. @@ -2278,14 +2305,14 @@ def right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array: @export -@partial(jit, inline=True) +@jit(inline=True) def bitwise_right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.right_shift`.""" return right_shift(x1, x2) @export -@partial(jit, inline=True) +@jit(inline=True) def absolute(x: ArrayLike, /) -> Array: r"""Calculate the absolute value element-wise. @@ -2315,13 +2342,13 @@ def absolute(x: ArrayLike, /) -> Array: >>> jnp.absolute(x3) Array([17., 5., 5.], dtype=float32) """ - check_arraylike('absolute', x) - dt = dtypes.dtype(x) + x = ensure_arraylike('absolute', x) + dt = x.dtype return lax.asarray(x) if dt == np.bool_ or dtypes.issubdtype(dt, np.unsignedinteger) else lax.abs(x) @export -@partial(jit, inline=True) +@jit(inline=True) def abs(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.absolute`.""" return absolute(x) @@ -2358,10 +2385,10 @@ def rint(x: ArrayLike, /) -> Array: >>> jnp.rint(x3) Array([-2.+4.j, 4.-0.j], dtype=complex64) """ - check_arraylike('rint', x) - dtype = dtypes.dtype(x) + x = ensure_arraylike('rint', x) + dtype = x.dtype if dtype == bool or dtypes.issubdtype(dtype, np.integer): - return lax.convert_element_type(x, dtypes.float_) + return lax.convert_element_type(x, dtypes.default_float_dtype()) if dtypes.issubdtype(dtype, np.complexfloating): return lax.complex(rint(lax.real(x)), rint(lax.imag(x))) return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN) @@ -2402,13 +2429,13 @@ def copysign(x1: ArrayLike, x2: ArrayLike, /) -> Array: [ 2., 3.]], dtype=float32) """ x1, x2 = promote_args_inexact("copysign", x1, x2) - if dtypes.issubdtype(dtypes.dtype(x1), np.complexfloating): + if dtypes.issubdtype(x1.dtype, np.complexfloating): raise TypeError("copysign does not support complex-valued inputs") return _where(signbit(x2).astype(bool), -lax.abs(x1), lax.abs(x1)) @export -@partial(jit, inline=True) +@jit(inline=True) def true_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Calculates the division of x1 by x2 element-wise @@ -2443,7 +2470,10 @@ def true_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: :func:`jax.numpy.floor_divide` for integer division """ x1, x2 = promote_args_inexact("true_divide", x1, x2) - return lax.div(x1, x2) + jnp_error._set_error_if_divide_by_zero(x2) + out = lax.div(x1, x2) + jnp_error._set_error_if_nan(out) + return out @export @@ -2493,7 +2523,8 @@ def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: Array([3., 2., 2.], dtype=float32) """ x1, x2 = promote_args_numeric("floor_divide", x1, x2) - dtype = dtypes.dtype(x1) + jnp_error._set_error_if_divide_by_zero(x2) + dtype = x1.dtype if dtypes.issubdtype(dtype, np.unsignedinteger): return lax.div(x1, x2) elif dtypes.issubdtype(dtype, np.integer): @@ -2544,9 +2575,10 @@ def divmod(x1: ArrayLike, x2: ArrayLike, /) -> tuple[Array, Array]: Array([0.30000007, 1. , 2.9 ], dtype=float32)) """ x1, x2 = promote_args_numeric("divmod", x1, x2) - if dtypes.issubdtype(dtypes.dtype(x1), np.integer): + if dtypes.issubdtype(x1.dtype, np.integer): return floor_divide(x1, x2), remainder(x1, x2) else: + jnp_error._set_error_if_divide_by_zero(x2) return _float_divmod(x1, x2) @@ -2582,8 +2614,9 @@ def power(x1: ArrayLike, x2: ArrayLike, /) -> Array: :func:`jax.lax.integer_pow`. - When ``x2`` is a traced scalar or an array, ``jnp.power`` lowers to :func:`jax.lax.pow`. - - ``jnp.power`` raises a ``TypeError`` for integer type raised to negative - integer power. + - ``jnp.power`` raises a ``TypeError`` for integer type raised to a concrete + negative integer power. For a non-concrete power, the operation is invalid + and the returned value is implementation-defined. - ``jnp.power`` returns ``nan`` for negative value raised to the power of non-integer values. @@ -2619,6 +2652,11 @@ def power(x1: ArrayLike, x2: ArrayLike, /) -> Array: [nan, 27., 1.]], dtype=float32) """ check_arraylike("power", x1, x2) + + # Must do __jax_array__ conversion prior to dtype check. + x1 = x1.__jax_array__() if hasattr(x1, "__jax_array__") else x1 + x2 = x2.__jax_array__() if hasattr(x2, "__jax_array__") else x2 + check_no_float0s("power", x1, x2) # We apply special cases, both for algorithmic and autodiff reasons: @@ -2645,26 +2683,28 @@ def power(x1: ArrayLike, x2: ArrayLike, /) -> Array: return lax.integer_pow(x1, x2) # Handle cases #2 and #3 under a jit: - return _power(x1, x2) + out = _power(x1, x2) + jnp_error._set_error_if_nan(out) + return out @export def pow(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.power`""" return power(x1, x2) -@partial(jit, inline=True) +@jit(inline=True) def _power(x1: ArrayLike, x2: ArrayLike) -> Array: x1, x2 = promote_shapes("power", x1, x2) # not dtypes # Case 2: bool/integer result x1_, x2_ = promote_args_numeric("power", x1, x2) - if (dtypes.issubdtype(dtypes.dtype(x1_), np.integer) or - dtypes.issubdtype(dtypes.dtype(x1_), np.bool_)): - assert np.iinfo(dtypes.dtype(x1_)).bits <= 64 # _pow_int_int assumes <=64bit + if (dtypes.issubdtype(x1_.dtype, np.integer) or + dtypes.issubdtype(x1_.dtype, np.bool_)): + assert np.iinfo(x1_.dtype).bits <= 64 # _pow_int_int assumes <=64bit return _pow_int_int(x1_, x2_) # Case 3: float/complex base with integer power (special autodiff behavior) - d1, d2 = dtypes.dtype(x1), dtypes.dtype(x2) + d1, d2 = x1.dtype, x2.dtype if dtypes.issubdtype(d1, np.inexact) and dtypes.issubdtype(d2, np.integer): return lax.pow(x1, x2) @@ -2741,12 +2781,11 @@ def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array: Array(True, dtype=bool) """ x1, x2 = promote_args_inexact("logaddexp2", x1, x2) - ln2 = float(np.log(2)) - return logaddexp(x1 * ln2, x2 * ln2) / ln2 + return lax_other.logaddexp2(x1, x2) @export -@partial(jit, inline=True) +@jit(inline=True) def log2(x: ArrayLike, /) -> Array: """Calculates the base-2 logarithm of ``x`` element-wise. @@ -2771,11 +2810,13 @@ def log2(x: ArrayLike, /) -> Array: im = lax.imag(r) ln2 = lax.log(_constant_like(re, 2)) return lax.complex(lax.div(re, ln2), lax.div(im, ln2)) - return lax.div(lax.log(x), lax.log(_constant_like(x, 2))) + out = lax.div(lax.log(x), lax.log(_constant_like(x, 2))) + jnp_error._set_error_if_nan(out) + return out @export -@partial(jit, inline=True) +@jit(inline=True) def log10(x: ArrayLike, /) -> Array: """Calculates the base-10 logarithm of x element-wise @@ -2795,17 +2836,20 @@ def log10(x: ArrayLike, /) -> Array: [-2. -1. 0. 1. 2. 3.] """ x, = promote_args_inexact("log10", x) + one_over_log10 = np.array(0.4342944819032518, # exact value of 1 / log(10) + dtype=dtypes.finfo(x.dtype).dtype) if dtypes.issubdtype(x.dtype, np.complexfloating): r = lax.log(x) re = lax.real(r) im = lax.imag(r) - ln10 = lax.log(_constant_like(re, 10)) - return lax.complex(lax.div(re, ln10), lax.div(im, ln10)) - return lax.div(lax.log(x), lax.log(_constant_like(x, 10))) + return lax.complex(lax.mul(re, one_over_log10), lax.mul(im, one_over_log10)) + out = lax.mul(lax.log(x), one_over_log10) + jnp_error._set_error_if_nan(out) + return out @export -@partial(jit, inline=True) +@jit(inline=True) def exp2(x: ArrayLike, /) -> Array: """Calculate element-wise base-2 exponential of input. @@ -2885,7 +2929,7 @@ def signbit(x: ArrayLike, /) -> Array: Array([False, True, False, True], dtype=bool) """ x, = promote_args("signbit", x) - dtype = dtypes.dtype(x) + dtype = x.dtype if dtypes.issubdtype(dtype, np.integer): return lax.lt(x, _constant_like(x, 0)) elif dtypes.issubdtype(dtype, np.bool_): @@ -2950,15 +2994,24 @@ def ldexp(x1: ArrayLike, x2: ArrayLike, /) -> Array: >>> jnp.ldexp(m, e) Array([ 2., 3., 5., 11.], dtype=float32) """ - check_arraylike("ldexp", x1, x2) - x1_dtype = dtypes.dtype(x1) - x2_dtype = dtypes.dtype(x2) + x1, x2 = ensure_arraylike("ldexp", x1, x2) + x1_dtype = x1.dtype + x2_dtype = x2.dtype if (dtypes.issubdtype(x1_dtype, np.complexfloating) or dtypes.issubdtype(x2_dtype, np.inexact)): raise ValueError(f"ldexp not supported for input types {(x1_dtype, x2_dtype)}") x1, = promote_args_inexact("ldexp", x1) - x2 = lax.convert_element_type(x2, dtypes.dtype(x1)) - x = x1 * (2 ** x2) + x2 = lax.convert_element_type(x2, x1.dtype) + + # Split off the exponent to avoid overflow for small x1 and large x2. + m, e = frexp(x1) + e = (e.astype(x2.dtype) + x2).astype(x1.dtype) + + # exponent may overflow by 1 and still have a finite result. + m = _where(e > 0, m * 2, m) + e = _where(e > 0, e - 1, e) + + x = m * (2 ** e.astype(m.dtype)) return _where(isinf(x1) | (x1 == 0), x1, x) @@ -2995,11 +3048,14 @@ def frexp(x: ArrayLike, /) -> tuple[Array, Array]: >>> m * 2 ** e Array([1., 2., 3., 4., 5.], dtype=float32) """ - check_arraylike("frexp", x) + x = ensure_arraylike("frexp", x) x, = promote_dtypes_inexact(x) if dtypes.issubdtype(x.dtype, np.complexfloating): raise TypeError("frexp does not support complex-valued inputs") + return _frexp(x) +@custom_jvp +def _frexp(x): dtype = dtypes.dtype(x) info = dtypes.finfo(dtype) mask = (1 << info.nexp) - 1 @@ -3016,6 +3072,16 @@ def frexp(x: ArrayLike, /) -> tuple[Array, Array]: return _where(cond, x, x1), lax.convert_element_type(x2, np.int32) +@_frexp.defjvp +def _frexp_jvp(primals, tangents): + x, = primals + t, = tangents + m, e = frexp(x) + mdot = t * exp2(-e.astype(t.dtype)) + edot = lax.full_like(e, fill_value=0, dtype=dtypes.float0) + return (m, e), (mdot, edot) + + @export @jit def remainder(x1: ArrayLike, x2: ArrayLike, /) -> Array: @@ -3054,6 +3120,7 @@ def remainder(x1: ArrayLike, x2: ArrayLike, /) -> Array: [ 0., 2., -2.]], dtype=float32) """ x1, x2 = promote_args_numeric("remainder", x1, x2) + jnp_error._set_error_if_divide_by_zero(x2) zero = _constant_like(x1, 0) if dtypes.issubdtype(x2.dtype, np.integer): x2 = _where(x2 == 0, lax._ones(x2), x2) @@ -3061,7 +3128,9 @@ def remainder(x1: ArrayLike, x2: ArrayLike, /) -> Array: trunc_mod_not_zero = lax.ne(trunc_mod, zero) do_plus = lax.bitwise_and( lax.ne(lax.lt(trunc_mod, zero), lax.lt(x2, zero)), trunc_mod_not_zero) - return lax.select(do_plus, lax.add(trunc_mod, x2), trunc_mod) + out = lax.select(do_plus, lax.add(trunc_mod, x2), trunc_mod) + jnp_error._set_error_if_nan(out) + return out @export @@ -3087,7 +3156,7 @@ def fmod(x1: ArrayLike, x2: ArrayLike, /) -> Array: operation of ``x1`` and ``x2`` with same sign as the elements of ``x1``. Note: - The result of ``jnp.fmod`` is equivalent to ``x1 - x2 * jnp.fix(x1 / x2)``. + The result of ``jnp.fmod`` is equivalent to ``x1 - x2 * jnp.trunc(x1 / x2)``. See also: - :func:`jax.numpy.mod` and :func:`jax.numpy.remainder`: Returns the element-wise @@ -3102,18 +3171,20 @@ def fmod(x1: ArrayLike, x2: ArrayLike, /) -> Array: >>> jnp.fmod(x1, x2) Array([[ 1, -1, 4], [ 0, 2, -2]], dtype=int32) - >>> x1 - x2 * jnp.fix(x1 / x2) + >>> x1 - x2 * jnp.trunc(x1 / x2) Array([[ 1., -1., 4.], [ 0., 2., -2.]], dtype=float32) """ - check_arraylike("fmod", x1, x2) + x1, x2 = ensure_arraylike("fmod", x1, x2) if dtypes.issubdtype(dtypes.result_type(x1, x2), np.integer): x2 = _where(x2 == 0, lax._ones(x2), x2) - return lax.rem(*promote_args_numeric("fmod", x1, x2)) + out = lax.rem(*promote_args_numeric("fmod", x1, x2)) + jnp_error._set_error_if_nan(out) + return out @export -@partial(jit, inline=True) +@jit(inline=True) def square(x: ArrayLike, /) -> Array: """Calculate element-wise square of the input array. @@ -3157,13 +3228,13 @@ def square(x: ArrayLike, /) -> Array: >>> jnp.square(x2) Array([-8.-6.j, -1.+0.j, 4.+0.j], dtype=complex64) """ - check_arraylike("square", x) + x = ensure_arraylike("square", x) x, = promote_dtypes_numeric(x) return lax.square(x) @export -@partial(jit, inline=True) +@jit(inline=True) def deg2rad(x: ArrayLike, /) -> Array: r"""Convert angles from degrees to radians. @@ -3198,7 +3269,7 @@ def deg2rad(x: ArrayLike, /) -> Array: @export -@partial(jit, inline=True) +@jit(inline=True) def rad2deg(x: ArrayLike, /) -> Array: r"""Convert angles from radians to degrees. @@ -3246,7 +3317,7 @@ def radians(x: ArrayLike, /) -> Array: @export -@partial(jit, inline=True) +@jit(inline=True) def conjugate(x: ArrayLike, /) -> Array: """Return element-wise complex-conjugate of the input. @@ -3271,7 +3342,7 @@ def conjugate(x: ArrayLike, /) -> Array: >>> jnp.conjugate(x) Array([2.+1.j, 3.-5.j, 7.-0.j], dtype=complex64) """ - check_arraylike("conjugate", x) + x = ensure_arraylike("conjugate", x) return lax.conj(x) if np.iscomplexobj(x) else lax.asarray(x) @@ -3282,7 +3353,7 @@ def conj(x: ArrayLike, /) -> Array: @export -@partial(jit, inline=True) +@jit(inline=True) def imag(val: ArrayLike, /) -> Array: """Return element-wise imaginary of part of the complex argument. @@ -3309,12 +3380,12 @@ def imag(val: ArrayLike, /) -> Array: >>> jnp.imag(x) Array([ 3., -1., 0.], dtype=float32) """ - check_arraylike("imag", val) + val = ensure_arraylike("imag", val) return lax.imag(val) if np.iscomplexobj(val) else lax.full_like(val, 0) @export -@partial(jit, inline=True) +@jit(inline=True) def real(val: ArrayLike, /) -> Array: """Return element-wise real part of the complex argument. @@ -3341,7 +3412,7 @@ def real(val: ArrayLike, /) -> Array: >>> jnp.real(x) Array([ 3., 4., -0.], dtype=float32) """ - check_arraylike("real", val) + val = ensure_arraylike("real", val) return lax.real(val) if np.iscomplexobj(val) else lax.asarray(val) @@ -3371,7 +3442,7 @@ def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]: >>> jnp.modf(x) (Array([-0.4000001 , -0.6999998 , 0.6 , 0.5 , 0.29999995], dtype=float32), Array([-3., -5., 0., 1., 2.], dtype=float32)) """ - check_arraylike("modf", x) + x = ensure_arraylike("modf", x) x, = promote_dtypes_inexact(x) if out is not None: raise NotImplementedError("The 'out' argument to jnp.modf is not supported.") @@ -3380,7 +3451,7 @@ def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]: @export -@partial(jit, inline=True) +@jit(inline=True) def isfinite(x: ArrayLike, /) -> Array: """Return a boolean array indicating whether each element of input is finite. @@ -3410,8 +3481,8 @@ def isfinite(x: ArrayLike, /) -> Array: >>> jnp.isfinite(3-4j) Array(True, dtype=bool, weak_type=True) """ - check_arraylike("isfinite", x) - dtype = dtypes.dtype(x) + x = ensure_arraylike("isfinite", x) + dtype = x.dtype if dtypes.issubdtype(dtype, np.floating): return lax.is_finite(x) elif dtypes.issubdtype(dtype, np.complexfloating): @@ -3451,8 +3522,8 @@ def isinf(x: ArrayLike, /) -> Array: >>> jnp.isinf(x) Array([False, True, False, True, False], dtype=bool) """ - check_arraylike("isinf", x) - dtype = dtypes.dtype(x) + x = ensure_arraylike("isinf", x) + dtype = x.dtype if dtypes.issubdtype(dtype, np.floating): return lax.eq(lax.abs(x), _constant_like(x, np.inf)) elif dtypes.issubdtype(dtype, np.complexfloating): @@ -3464,10 +3535,10 @@ def isinf(x: ArrayLike, /) -> Array: return lax.full_like(x, False, dtype=np.bool_) -def _isposneginf(infinity: float, x: ArrayLike, out) -> Array: +def _isposneginf(infinity: float, x: Array, out) -> Array: if out is not None: raise NotImplementedError("The 'out' argument to isneginf/isposinf is not supported.") - dtype = dtypes.dtype(x) + dtype = x.dtype if dtypes.issubdtype(dtype, np.floating): return lax.eq(x, _constant_like(x, infinity)) elif dtypes.issubdtype(dtype, np.complexfloating): @@ -3507,6 +3578,7 @@ def isposinf(x, /, out=None): >>> jnp.isposinf(x) Array([False, False, True, False, False], dtype=bool) """ + x = ensure_arraylike("isposinf", x) return _isposneginf(np.inf, x, out) @@ -3541,11 +3613,12 @@ def isneginf(x, /, out=None): >>> jnp.isneginf(x) Array([ True, False, False, False, False], dtype=bool) """ + x = ensure_arraylike("isneginf", x) return _isposneginf(-np.inf, x, out) @export -@partial(jit, inline=True) +@jit(inline=True) def isnan(x: ArrayLike, /) -> Array: """Returns a boolean array indicating whether each element of input is ``NaN``. @@ -3575,7 +3648,7 @@ def isnan(x: ArrayLike, /) -> Array: >>> jnp.isnan(x) Array([False, False, False, True], dtype=bool) """ - check_arraylike("isnan", x) + x = ensure_arraylike("isnan", x) return lax.ne(x, x) @@ -3591,9 +3664,9 @@ def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array: .. math:: \mathrm{heaviside}(x1, x2) = \begin{cases} - 0., & x < 0\\ - x2, & x = 0\\ - 1., & x > 0. + 0, & x1 < 0\\ + x2, & x1 = 0\\ + 1, & x1 > 0. \end{cases} Args: @@ -3622,7 +3695,7 @@ def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array: >>> jnp.heaviside(-3, x2) Array([0., 0., 0.], dtype=float32) """ - check_arraylike("heaviside", x1, x2) + x1, x2 = ensure_arraylike("heaviside", x1, x2) x1, x2 = promote_dtypes_inexact(x1, x2) zero = _lax_const(x1, 0) return _where(lax.lt(x1, zero), zero, @@ -3679,7 +3752,7 @@ def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array: @export -@partial(jit, inline=True) +@jit(inline=True) def reciprocal(x: ArrayLike, /) -> Array: """Calculate element-wise reciprocal of the input. @@ -3707,7 +3780,7 @@ def reciprocal(x: ArrayLike, /) -> Array: >>> jnp.reciprocal(x) Array([1. , 0.2 , 0.25], dtype=float32) """ - check_arraylike("reciprocal", x) + x = ensure_arraylike("reciprocal", x) x, = promote_dtypes_inexact(x) return lax.integer_pow(x, -1) @@ -3760,7 +3833,7 @@ def sinc(x: ArrayLike, /) -> Array: (d/dx)^4 f(0.0) = 19.48 (d/dx)^5 f(0.0) = 0.00 """ - check_arraylike("sinc", x) + x = ensure_arraylike("sinc", x) x, = promote_dtypes_inexact(x) eq_zero = lax.eq(x, _lax_const(x, 0)) pi_x = lax.mul(_lax_const(x, np.pi), x) diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index e281c63ae654..53a78cf6c8ac 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -14,30 +14,34 @@ from __future__ import annotations from collections.abc import Sequence -from functools import partial from typing import Any, overload - +import math import warnings +import numpy as np + from jax._src import api from jax._src import config from jax._src import core from jax._src import dtypes +from jax._src import literals from jax._src.lax import lax from jax._src.lib import xla_client as xc from jax._src.sharding_impls import SingleDeviceSharding -from jax._src.util import safe_zip, safe_map, set_module -from jax._src.typing import Array, ArrayLike, DimSize, DType, DTypeLike, Shape -from jax.sharding import Sharding - -import numpy as np +from jax._src.util import safe_zip, safe_map, set_module, canonicalize_axis_tuple +from jax._src.sharding import Sharding +from jax._src.sharding_impls import (NamedSharding, PartitionSpec as P, + canonicalize_sharding) +from jax._src.typing import ( + Array, ArrayLike, DimSize, Shape, SupportsNdim, SupportsShape, SupportsSize) zip, unsafe_zip = safe_zip, zip map, unsafe_map = safe_map, map export = set_module('jax.numpy') -_dtype = partial(dtypes.dtype, canonicalize=True) +_dtype = dtypes.dtype + def promote_shapes(fun_name: str, *args: ArrayLike) -> list[Array]: """Apply NumPy-style broadcasting, making args shape-compatible for lax.py.""" @@ -45,23 +49,16 @@ def promote_shapes(fun_name: str, *args: ArrayLike) -> list[Array]: return [lax.asarray(arg) for arg in args] else: shapes = [np.shape(arg) for arg in args] - if config.dynamic_shapes.value: - # With dynamic shapes we don't support singleton-dimension broadcasting; - # we instead broadcast out to the full shape as a temporary workaround. - # TODO(mattjj): revise this workaround - res_shape = lax.broadcast_shapes(*shapes) # Can raise an error! - return [_broadcast_to(arg, res_shape) for arg, shp in zip(args, shapes)] + if all(len(shapes[0]) == len(s) for s in shapes[1:]): + return [lax.asarray(arg) for arg in args] # no need for rank promotion, so rely on lax promotion + nonscalar_ranks = {len(shp) for shp in shapes if shp} + if len(nonscalar_ranks) < 2: + return [lax.asarray(arg) for arg in args] # rely on lax scalar promotion else: - if all(len(shapes[0]) == len(s) for s in shapes[1:]): - return [lax.asarray(arg) for arg in args] # no need for rank promotion, so rely on lax promotion - nonscalar_ranks = {len(shp) for shp in shapes if shp} - if len(nonscalar_ranks) < 2: - return [lax.asarray(arg) for arg in args] # rely on lax scalar promotion - else: - if config.numpy_rank_promotion.value != "allow": - _rank_promotion_warning_or_error(fun_name, shapes) - result_rank = len(lax.broadcast_shapes(*shapes)) - return [lax.broadcast_to_rank(arg, result_rank) for arg in args] + if config.numpy_rank_promotion.value != "allow": + _rank_promotion_warning_or_error(fun_name, shapes) + result_rank = len(lax.broadcast_shapes(*shapes)) + return [lax.broadcast_to_rank(arg, result_rank) for arg in args] def _rank_promotion_warning_or_error(fun_name: str, shapes: Sequence[Shape]): @@ -69,13 +66,13 @@ def _rank_promotion_warning_or_error(fun_name: str, shapes: Sequence[Shape]): msg = ("Following NumPy automatic rank promotion for {} on shapes {}. " "Set the jax_numpy_rank_promotion config option to 'allow' to " "disable this warning; for more information, see " - "https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.") + "https://docs.jax.dev/en/latest/rank_promotion_warning.html.") warnings.warn(msg.format(fun_name, ' '.join(map(str, shapes)))) elif config.numpy_rank_promotion.value == "raise": msg = ("Operands could not be broadcast together for {} on shapes {} " "and with the config option jax_numpy_rank_promotion='raise'. " "For more information, see " - "https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.") + "https://docs.jax.dev/en/latest/rank_promotion_warning.html.") raise ValueError(msg.format(fun_name, ' '.join(map(str, shapes)))) @@ -85,8 +82,7 @@ def promote_dtypes(*args: ArrayLike) -> list[Array]: if len(args) < 2: return [lax.asarray(arg) for arg in args] else: - to_dtype, weak_type = dtypes._lattice_result_type(*args) - to_dtype = dtypes.canonicalize_dtype(to_dtype, allow_extended_dtype=True) + to_dtype, weak_type = dtypes.lattice_result_type(*args) return [lax._convert_element_type(x, to_dtype, weak_type) for x in args] @@ -94,8 +90,7 @@ def promote_dtypes_inexact(*args: ArrayLike) -> list[Array]: """Convenience function to apply Numpy argument dtype promotion. Promotes arguments to an inexact type.""" - to_dtype, weak_type = dtypes._lattice_result_type(*args) - to_dtype = dtypes.canonicalize_dtype(to_dtype, allow_extended_dtype=True) + to_dtype, weak_type = dtypes.lattice_result_type(*args) to_dtype_inexact = dtypes.to_inexact_dtype(to_dtype) # type: ignore[arg-type] return [lax._convert_element_type(x, to_dtype_inexact, weak_type) for x in args] @@ -105,8 +100,7 @@ def promote_dtypes_numeric(*args: ArrayLike) -> list[Array]: """Convenience function to apply Numpy argument dtype promotion. Promotes arguments to a numeric (non-bool) type.""" - to_dtype, weak_type = dtypes._lattice_result_type(*args) - to_dtype = dtypes.canonicalize_dtype(to_dtype) + to_dtype, weak_type = dtypes.lattice_result_type(*args) to_dtype_numeric = dtypes.to_numeric_dtype(to_dtype) return [lax._convert_element_type(x, to_dtype_numeric, weak_type) for x in args] @@ -116,20 +110,16 @@ def promote_dtypes_complex(*args: ArrayLike) -> list[Array]: """Convenience function to apply Numpy argument dtype promotion. Promotes arguments to a complex type.""" - to_dtype, weak_type = dtypes._lattice_result_type(*args) - to_dtype = dtypes.canonicalize_dtype(to_dtype) + to_dtype, weak_type = dtypes.lattice_result_type(*args) to_dtype_complex = dtypes.to_complex_dtype(to_dtype) return [lax._convert_element_type(x, to_dtype_complex, weak_type) for x in args] -def _complex_elem_type(dtype: DTypeLike) -> DType: - """Returns the float type of the real/imaginary parts of a complex dtype.""" - return np.abs(np.zeros((), dtype)).dtype - +_arraylike_types = (np.ndarray, Array, literals.TypedNdArray) def _arraylike(x: ArrayLike) -> bool: - return (isinstance(x, np.ndarray) or isinstance(x, Array) or + return (isinstance(x, _arraylike_types) or hasattr(x, '__jax_array__') or np.isscalar(x)) @@ -140,6 +130,10 @@ def _arraylike_asarray(x: Any) -> Array: return lax.asarray(x) +def _check_jax_array_protocol(x: Any) -> Any: + return x.__jax_array__() if hasattr(x, '__jax_array__') else x + + @overload def ensure_arraylike(fun_name: str, /) -> tuple[()]: ... @overload @@ -158,7 +152,7 @@ def ensure_arraylike(fun_name: str, /, *args: Any) -> Array | tuple[Array, ...]: return tuple(_arraylike_asarray(arg) for arg in args) # pytype: disable=bad-return-type -def ensure_arraylike_tuple(fun_name: str, tup: tuple[Any, ...]) -> tuple[Array, ...]: +def ensure_arraylike_tuple(fun_name: str, tup: Sequence[Any]) -> tuple[Array, ...]: """Check that argument elements are arraylike and convert to a tuple of arrays. This is useful because ensure_arraylike with a single argument returns a single array. @@ -222,6 +216,7 @@ def check_for_prngkeys(fun_name: str, *args: Any): def promote_args(fun_name: str, *args: ArrayLike) -> list[Array]: """Convenience function to apply Numpy argument shape and dtype promotion.""" check_arraylike(fun_name, *args) + args = tuple(_check_jax_array_protocol(arg) for arg in args) _check_no_float0s(fun_name, *args) check_for_prngkeys(fun_name, *args) return promote_shapes(fun_name, *promote_dtypes(*args)) @@ -229,6 +224,7 @@ def promote_args(fun_name: str, *args: ArrayLike) -> list[Array]: def promote_args_numeric(fun_name: str, *args: ArrayLike) -> list[Array]: check_arraylike(fun_name, *args) + args = tuple(_check_jax_array_protocol(arg) for arg in args) _check_no_float0s(fun_name, *args) check_for_prngkeys(fun_name, *args) return promote_shapes(fun_name, *promote_dtypes_numeric(*args)) @@ -239,12 +235,13 @@ def promote_args_inexact(fun_name: str, *args: ArrayLike) -> list[Array]: Promotes non-inexact types to an inexact type.""" check_arraylike(fun_name, *args) + args = tuple(_check_jax_array_protocol(arg) for arg in args) _check_no_float0s(fun_name, *args) check_for_prngkeys(fun_name, *args) return promote_shapes(fun_name, *promote_dtypes_inexact(*args)) -@partial(api.jit, inline=True) +@api.jit(inline=True) def _broadcast_arrays(*args: ArrayLike) -> list[Array]: """Like Numpy's broadcast_arrays but doesn't return views.""" avals = [core.shaped_abstractify(arg) for arg in args] @@ -258,14 +255,15 @@ def _broadcast_arrays(*args: ArrayLike) -> list[Array]: def _broadcast_to(arr: ArrayLike, shape: DimSize | Shape, sharding=None ) -> Array: - check_arraylike("broadcast_to", arr) + arr = ensure_arraylike("broadcast_to", arr) arr = arr if isinstance(arr, Array) else lax.asarray(arr) if not isinstance(shape, tuple) and np.ndim(shape) == 0: shape = (shape,) # check that shape is concrete shape = core.canonicalize_shape(shape) # type: ignore[arg-type] arr_shape = np.shape(arr) - if core.definitely_equal_shape(arr_shape, shape): + if (core.definitely_equal_shape(arr_shape, shape) and + (sharding is None or core.typeof(arr).sharding == sharding)): return arr elif len(shape) < len(arr_shape): raise ValueError(f"Cannot broadcast to shape with fewer dimensions: {arr_shape=} {shape=}") @@ -286,6 +284,7 @@ def _broadcast_to(arr: ArrayLike, shape: DimSize | Shape, sharding=None # materialize the broadcast forms of scalar arguments. @api.jit def _where(condition: ArrayLike, x: ArrayLike, y: ArrayLike) -> Array: + condition, x, y = ensure_arraylike("where", condition, x, y) if x is None or y is None: raise ValueError("Either both or neither of the x and y arguments should " "be provided to jax.numpy.where, got {} and {}." @@ -304,16 +303,28 @@ def _where(condition: ArrayLike, x: ArrayLike, y: ArrayLike) -> Array: is_always_empty = False # can fail with dynamic shapes return lax.select(condition, x_arr, y_arr) if not is_always_empty else x_arr - -def normalize_device_to_sharding(device: xc.Device | Sharding | None) -> Sharding | None: +def canonicalize_device_to_sharding(device: xc.Device | Sharding | None + ) -> Sharding | None: if isinstance(device, xc.Device): return SingleDeviceSharding(device) - else: - return device + return device + +def choose_device_or_out_sharding(device: xc.Device | Sharding | None, + out_sharding: NamedSharding | P | None, + name: str) -> Sharding | NamedSharding | None: + if device is not None and out_sharding is not None: + raise ValueError( + f"Only one of `device` or `out_sharding` can be set. Got {device=} and" + f" {out_sharding=}") + if device is not None and out_sharding is None: + return canonicalize_device_to_sharding(device) + if device is None and out_sharding is not None: + return canonicalize_sharding(out_sharding, name) + return None @export -def ndim(a: ArrayLike) -> int: +def ndim(a: ArrayLike | SupportsNdim) -> int: """Return the number of dimensions of an array. JAX implementation of :func:`numpy.ndim`. Unlike ``np.ndim``, this function @@ -321,7 +332,7 @@ def ndim(a: ArrayLike) -> int: tuple. Args: - a: array-like object. + a: array-like object, or any object with an ``ndim`` attribute. Returns: An integer specifying the number of dimensions of ``a``. @@ -346,13 +357,18 @@ def ndim(a: ArrayLike) -> int: >>> x.ndim 1 """ + if hasattr(a, "ndim"): + return a.ndim # Deprecation warning added 2025-2-20. check_arraylike("ndim", a, emit_warning=True) - return np.ndim(a) # NumPy dispatches to a.ndim if available. + if hasattr(a, "__jax_array__"): + a = a.__jax_array__() + # NumPy dispatches to a.ndim if available. + return np.ndim(a) # type: ignore[arg-type] @export -def shape(a: ArrayLike) -> tuple[int, ...]: +def shape(a: ArrayLike | SupportsShape) -> tuple[int, ...]: """Return the shape an array. JAX implementation of :func:`numpy.shape`. Unlike ``np.shape``, this function @@ -360,7 +376,7 @@ def shape(a: ArrayLike) -> tuple[int, ...]: tuple. Args: - a: array-like object. + a: array-like object, or any object with a ``shape`` attribute. Returns: An tuple of integers representing the shape of ``a``. @@ -385,13 +401,18 @@ def shape(a: ArrayLike) -> tuple[int, ...]: >>> x.shape (10,) """ + if hasattr(a, "shape"): + return a.shape # Deprecation warning added 2025-2-20. check_arraylike("shape", a, emit_warning=True) - return np.shape(a) # NumPy dispatches to a.shape if available. + if hasattr(a, "__jax_array__"): + a = a.__jax_array__() + # NumPy dispatches to a.shape if available. + return np.shape(a) # type: ignore[arg-type] @export -def size(a: ArrayLike, axis: int | None = None) -> int: +def size(a: ArrayLike | SupportsSize | SupportsShape, axis: int | Sequence[int] | None = None) -> int: """Return number of elements along a given axis. JAX implementation of :func:`numpy.size`. Unlike ``np.size``, this function @@ -399,9 +420,10 @@ def size(a: ArrayLike, axis: int | None = None) -> int: tuple. Args: - a: array-like object - axis: optional integer along which to count elements. By default, return - the total number of elements. + a: array-like object, or any object with a ``size`` attribute when ``axis`` is not + specified, or with a ``shape`` attribute when ``axis`` is specified. + axis: optional integer or sequence of integers indicating which axis or axes to count + elements along. ``None`` (the default) returns the total number of elements. Returns: An integer specifying the number of elements in ``a``. @@ -417,6 +439,10 @@ def size(a: ArrayLike, axis: int | None = None) -> int: 6 >>> jnp.size(y, axis=1) 3 + >>> jnp.size(y, axis=(1,)) + 3 + >>> jnp.size(y, axis=(0, 1)) + 6 This also works for scalars: @@ -428,6 +454,9 @@ def size(a: ArrayLike, axis: int | None = None) -> int: >>> y.size 6 """ - # Deprecation warning added 2025-2-20. check_arraylike("size", a, emit_warning=True) - return np.size(a, axis=axis) # NumPy dispatches to a.size if available. + if axis is None and hasattr(a, "size"): + return a.size + _shape = shape(a) # type: ignore[arg-type] + axis = canonicalize_axis_tuple(axis, len(_shape), allow_duplicate=False) + return math.prod(_shape[i] for i in axis) diff --git a/jax/_src/numpy/vectorize.py b/jax/_src/numpy/vectorize.py index e6ad1386a52e..f166a96a4693 100644 --- a/jax/_src/numpy/vectorize.py +++ b/jax/_src/numpy/vectorize.py @@ -23,7 +23,7 @@ from jax._src import api from jax._src import config -from jax import lax +from jax._src.lax import lax from jax._src.numpy import lax_numpy as jnp from jax._src.util import set_module, safe_map as map, safe_zip as zip @@ -144,18 +144,15 @@ def wrapped(*args): out = func(*args) out_shapes = map(np.shape, out if isinstance(out, tuple) else [out]) - if expected_output_core_dims is None: - output_core_dims = [()] * len(out_shapes) - else: - output_core_dims = expected_output_core_dims - if len(output_core_dims) > 1 and not isinstance(out, tuple): - raise TypeError( - "output must be a tuple when multiple outputs are expected, " - "got: {!r}\n{}".format(out, error_context)) - if len(out_shapes) != len(output_core_dims): - raise TypeError( - 'wrong number of output arguments: expected %r, got %r %s' - % (len(output_core_dims), len(out_shapes), error_context)) + output_core_dims = expected_output_core_dims + if len(output_core_dims) > 1 and not isinstance(out, tuple): + raise TypeError( + "output must be a tuple when multiple outputs are expected, " + "got: {!r}\n{}".format(out, error_context)) + if len(out_shapes) != len(output_core_dims): + raise TypeError( + 'wrong number of output arguments: expected %r, got %r %s' + % (len(output_core_dims), len(out_shapes), error_context)) sizes = dict(dim_sizes) for shape, core_dims in zip(out_shapes, output_core_dims): @@ -215,7 +212,8 @@ def vectorize(pyfunc, *, excluded=frozenset(), signature=None): ``(m,n),(n)->(m)`` for vectorized matrix-vector multiplication. If provided, ``pyfunc`` will be called with (and expected to return) arrays with shapes given by the size of corresponding core dimensions. By - default, pyfunc is assumed to take scalars arrays as input and output. + default, pyfunc is assumed to take scalar arrays as input, and if + ``signature`` is ``None``, ``pyfunc`` can produce outputs of any shape. Returns: Vectorized version of the given function. @@ -294,8 +292,11 @@ def wrapped(*args, **kwargs): broadcast_shape, dim_sizes = _parse_input_dimensions( args, input_core_dims, error_context) - checked_func = _check_output_dims( - excluded_func, dim_sizes, output_core_dims, error_context) + if output_core_dims is None: + checked_func = excluded_func + else: + checked_func = _check_output_dims( + excluded_func, dim_sizes, output_core_dims, error_context) # Detect implicit rank promotion: if config.numpy_rank_promotion.value != "allow": @@ -307,7 +308,7 @@ def wrapped(*args, **kwargs): f" promotion for jnp.vectorize function with signature {signature}." " Set the jax_numpy_rank_promotion config option to 'allow' to" " disable this message; for more information, see" - " https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.") + " https://docs.jax.dev/en/latest/rank_promotion_warning.html.") if config.numpy_rank_promotion.value == "warn": warnings.warn(msg) elif config.numpy_rank_promotion.value == "raise": diff --git a/jax/_src/numpy/window_functions.py b/jax/_src/numpy/window_functions.py index 96a15db777a8..f7ca9e38ea8f 100644 --- a/jax/_src/numpy/window_functions.py +++ b/jax/_src/numpy/window_functions.py @@ -16,11 +16,11 @@ from jax._src import core from jax._src import dtypes +from jax._src.lax import lax from jax._src.numpy import lax_numpy from jax._src.numpy import ufuncs from jax._src.typing import Array, ArrayLike from jax._src.util import set_module -from jax import lax export = set_module('jax.numpy') @@ -49,7 +49,7 @@ def blackman(M: int) -> Array: - :func:`jax.numpy.kaiser`: return a Kaiser window of size M. """ M = core.concrete_or_error(int, M, "M argument of jnp.blackman") - dtype = dtypes.canonicalize_dtype(dtypes.float_) + dtype = dtypes.default_float_dtype() if M <= 1: return lax.full((M,), 1, dtype) n = lax.iota(dtype, M) @@ -80,7 +80,7 @@ def bartlett(M: int) -> Array: - :func:`jax.numpy.kaiser`: return a Kaiser window of size M. """ M = core.concrete_or_error(int, M, "M argument of jnp.bartlett") - dtype = dtypes.canonicalize_dtype(dtypes.float_) + dtype = dtypes.default_float_dtype() if M <= 1: return lax.full((M,), 1, dtype) n = lax.iota(dtype, M) @@ -111,7 +111,7 @@ def hamming(M: int) -> Array: - :func:`jax.numpy.kaiser`: return a Kaiser window of size M. """ M = core.concrete_or_error(int, M, "M argument of jnp.hamming") - dtype = dtypes.canonicalize_dtype(dtypes.float_) + dtype = dtypes.default_float_dtype() if M <= 1: return lax.full((M,), 1, dtype) n = lax.iota(dtype, M) @@ -142,7 +142,7 @@ def hanning(M: int) -> Array: - :func:`jax.numpy.kaiser`: return a Kaiser window of size M. """ M = core.concrete_or_error(int, M, "M argument of jnp.hanning") - dtype = dtypes.canonicalize_dtype(dtypes.float_) + dtype = dtypes.default_float_dtype() if M <= 1: return lax.full((M,), 1, dtype) n = lax.iota(dtype, M) @@ -174,7 +174,7 @@ def kaiser(M: int, beta: ArrayLike) -> Array: - :func:`jax.numpy.hanning`: return a Hanning window of size M. """ M = core.concrete_or_error(int, M, "M argument of jnp.kaiser") - dtype = dtypes.canonicalize_dtype(dtypes.float_) + dtype = dtypes.default_float_dtype() if M <= 1: return lax.full((M,), 1, dtype) n = lax.iota(dtype, M) diff --git a/jax/_src/op_shardings.py b/jax/_src/op_shardings.py index ed53b52a2c6d..fd05c26503ff 100644 --- a/jax/_src/op_shardings.py +++ b/jax/_src/op_shardings.py @@ -25,42 +25,50 @@ def get_num_ways_dim_sharded( - hlo_sharding: xc.HloSharding) -> tuple[list[int], int]: + hlo_sharding: xc.HloSharding, allow_partial_manual: bool = False +) -> tuple[list[int], int]: + assert not hlo_sharding.is_manual() if hlo_sharding.is_replicated(): return [], 1 + if hlo_sharding.is_unreduced(): + return [], 1 partitions = hlo_sharding.tile_assignment_dimensions() subgroup_types = hlo_sharding.subgroup_types() if subgroup_types == [xc.OpSharding.Type.REPLICATED]: - replicate_on_last_tile_dim = True + return list(partitions[:-1]), partitions[-1] + elif subgroup_types == [xc.OpSharding.Type.UNREDUCED]: + return list(partitions[:-1]), 1 + elif set(subgroup_types) == {xc.OpSharding.Type.REPLICATED, + xc.OpSharding.Type.UNREDUCED}: + replicated_loc = subgroup_types.index(xc.OpSharding.Type.REPLICATED) + return list(partitions[:-2]), partitions[-2:][replicated_loc] + elif allow_partial_manual and xc.OpSharding.Type.MANUAL in subgroup_types: + if subgroup_types == [xc.OpSharding.Type.MANUAL]: + return list(partitions[:-1]), 1 + else: + assert (set(subgroup_types) == + {xc.OpSharding.Type.REPLICATED, xc.OpSharding.Type.MANUAL}) + replicated_loc = subgroup_types.index(xc.OpSharding.Type.REPLICATED) + return list(partitions[:-2]), partitions[-2:][replicated_loc] + elif hlo_sharding.replicate_on_last_tile_dim(): + return list(partitions[:-1]), partitions[-1] else: - replicate_on_last_tile_dim = hlo_sharding.replicate_on_last_tile_dim() if subgroup_types: - raise NotImplementedError( - "Unhandled OpSharding type. Please open a bug report!") - num_replicas = 1 - if replicate_on_last_tile_dim: - num_replicas = partitions[-1] - partitions = partitions[:-1] - return list(partitions), num_replicas - - -def is_op_sharding_replicated(op: xc.OpSharding | xc.HloSharding) -> bool: - if isinstance(op, xc.OpSharding): - op = xc.HloSharding.from_proto(op) - if op.num_devices() == 1: - return True - return op.is_replicated() + raise NotImplementedError(f"Unhandled OpSharding type: {hlo_sharding}. " + "Please open a bug report!") + return list(partitions), 1 + + +def is_hlo_sharding_replicated(hc: xc.HloSharding) -> bool: + return True if hc.num_devices() == 1 else hc.is_replicated() -def are_op_shardings_equal(op1: xc.OpSharding | xc.HloSharding, - op2: xc.OpSharding | xc.HloSharding) -> bool: - if op1 is op2: +def are_hlo_shardings_equal(hc1: xc.HloSharding, hc2: xc.HloSharding) -> bool: + if hc1 is hc2: return True - if is_op_sharding_replicated(op1) and is_op_sharding_replicated(op2): + if is_hlo_sharding_replicated(hc1) and is_hlo_sharding_replicated(hc2): return True - hc1 = xc.HloSharding.from_proto(op1) if isinstance(op1, xc.OpSharding) else op1 - hc2 = xc.HloSharding.from_proto(op2) if isinstance(op2, xc.OpSharding) else op2 return hc1 == hc2 @@ -75,7 +83,7 @@ def op_sharding_to_numpy_indices( # num_devices is required as an argument when hlo_sharding is # REPLICATED. `jax.device_count()` cannot be used because you can create # an opsharding with less number of devices than `jax.device_count()`. - if is_op_sharding_replicated(hlo_sharding): + if is_hlo_sharding_replicated(hlo_sharding): indices.fill((slice(None),) * len(shape)) return indices @@ -98,7 +106,7 @@ def op_sharding_to_numpy_indices( device_it = iter(hlo_sharding.tile_assignment_devices()) - for i, idxs in enumerate(itertools.product(*axis_indices)): + for idxs in itertools.product(*axis_indices): for _ in range(num_replicas): indices[next(device_it)] = idxs return indices diff --git a/jax/_src/ops/scatter.py b/jax/_src/ops/scatter.py index e19be6622168..f120d386b5fd 100644 --- a/jax/_src/ops/scatter.py +++ b/jax/_src/ops/scatter.py @@ -16,34 +16,34 @@ from __future__ import annotations -from collections.abc import Callable, Sequence -from typing import Union +from collections.abc import Callable +from functools import partial +from typing import Any import warnings import numpy as np -from jax import lax - from jax._src import config from jax._src import core from jax._src import dtypes +from jax._src import numpy as jnp +from jax._src import tree_util from jax._src import util -from jax._src.lax import lax as lax_internal +from jax._src.lax import lax +from jax._src.lax import slicing from jax._src.numpy import indexing -from jax._src.numpy import lax_numpy as jnp from jax._src.numpy import reductions from jax._src.numpy.util import check_arraylike, promote_dtypes -from jax._src.typing import Array, ArrayLike - - -from types import EllipsisType -SingleIndex = int | slice | Sequence[int] | Array | EllipsisType | None -Index = Union[SingleIndex, tuple[SingleIndex, ...]] -Scalar = Union[complex, float, int, np.number] +from jax._src.pjit import auto_axes +from jax._src.sharding_impls import NamedSharding +from jax._src.typing import Array, ArrayLike, Index -def _scatter_update(x, idx, y, scatter_op, indices_are_sorted, - unique_indices, mode=None, normalize_indices=True): +def _scatter_update(x: ArrayLike, idx: Index | tuple[Index, ...], + y: ArrayLike, scatter_op: Callable[..., Array], + indices_are_sorted: bool, unique_indices: bool, + mode: slicing.GatherScatterMode | str | None = None, normalize_indices: bool = True, + out_sharding: NamedSharding | None = None): """Helper for indexed updates. Computes the value of x that would result from computing:: @@ -73,18 +73,29 @@ def _scatter_update(x, idx, y, scatter_op, indices_are_sorted, # XLA gathers and scatters are very similar in structure; the scatter logic # is more or less a transpose of the gather equivalent. - treedef, static_idx, dynamic_idx = indexing.split_index_for_jit(idx, x.shape) - return _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx, - indices_are_sorted, unique_indices, mode, - normalize_indices) + indexer = indexing.NDIndexer.from_raw_indices(idx, x.shape).expand_bool_indices() + dynamic_idx, treedef = tree_util.tree_flatten(indexer) + + internal_scatter = partial( + _scatter_impl, scatter_op=scatter_op, treedef=treedef, + indices_are_sorted=indices_are_sorted, + unique_indices=unique_indices, mode=mode, + normalize_indices=normalize_indices) + if out_sharding is not None: + return auto_axes(internal_scatter, out_sharding=out_sharding, + axes=out_sharding.mesh.explicit_axes # type: ignore + )(x, y, dynamic_idx) + return internal_scatter(x, y, tuple(dynamic_idx)) # TODO(phawkins): re-enable jit after fixing excessive recompilation for # slice indexes (e.g., slice(0, 5, None), slice(10, 15, None), etc.). -# @partial(jit, static_argnums=(2, 3, 4)) -def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx, - indices_are_sorted, unique_indices, mode, - normalize_indices): +# @jit(static_argnums=(2, 3, 4)) +def _scatter_impl(x: ArrayLike, y: ArrayLike, dynamic_idx: tuple[Any, ...], *, + scatter_op: Callable[..., Array], + treedef: tree_util.PyTreeDef, + indices_are_sorted: bool, unique_indices: bool, + mode: slicing.GatherScatterMode | str | None, normalize_indices: bool): dtype = lax.dtype(x) weak_type = dtypes.is_weakly_typed(x) @@ -93,13 +104,12 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx, warnings.warn( "scatter inputs have incompatible types: cannot safely cast value " f"from dtype={lax.dtype(y)} to dtype={lax.dtype(x)} with " - f"jax_numpy_dtype_promotion={config.numpy_dtype_promotion.value!r}. " + f"jax_numpy_dtype_promotion={config.numpy_dtype_promotion.value}. " "In future JAX releases this will result in an error.", FutureWarning) - idx = indexing.merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx) - indexer = indexing.index_to_gather(np.shape(x), idx, - normalize_indices=normalize_indices) + general_indexer = tree_util.tree_unflatten(treedef, dynamic_idx) + indexer = general_indexer.to_gather(core.typeof(x).sharding, normalize_indices=normalize_indices) # Avoid calling scatter if the slice shape is empty, both as a fast path and # to handle cases like zeros(0)[array([], int32)]. @@ -109,7 +119,8 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx, x, y = promote_dtypes(x, y) # Broadcast `y` to the slice output shape. - y = jnp.broadcast_to(y, tuple(indexer.slice_shape)) + y = jnp.broadcast_to(y, tuple(indexer.slice_shape), + out_sharding=indexer.slice_sharding) # Collapse any `None`/`np.newaxis` dimensions. y = jnp.squeeze(y, axis=indexer.newaxis_dims) if indexer.reversed_y_dims: @@ -120,7 +131,7 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx, # Transpose the gather dimensions into scatter dimensions (cf. # lax._gather_transpose_rule) - dnums = lax.ScatterDimensionNumbers( + dnums = slicing.ScatterDimensionNumbers( update_window_dims=indexer.dnums.offset_dims, inserted_window_dims=indexer.dnums.collapsed_slice_dims, scatter_dims_to_operand_dims=indexer.dnums.start_index_map, @@ -134,22 +145,22 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx, mode=mode) if indexer.scalar_bool_dims: out = lax.squeeze(out, indexer.scalar_bool_dims) - return lax_internal._convert_element_type(out, dtype, weak_type) + return lax._convert_element_type(out, dtype, weak_type) def _get_identity(op, dtype): """Get an appropriate identity for a given operation in a given dtype.""" - if op is lax.scatter_add: + if op is slicing.scatter_add: return 0 - elif op is lax.scatter_mul: + elif op is slicing.scatter_mul: return 1 - elif op is lax.scatter_min: + elif op is slicing.scatter_min: if dtype == dtypes.bool_: return True elif dtypes.issubdtype(dtype, np.integer): return dtypes.iinfo(dtype).max return float('inf') - elif op is lax.scatter_max: + elif op is slicing.scatter_max: if dtype == dtypes.bool_: return False elif dtypes.issubdtype(dtype, np.integer): @@ -168,9 +179,9 @@ def _segment_update(name: str, unique_indices: bool = False, bucket_size: int | None = None, reducer: Callable | None = None, - mode: lax.GatherScatterMode | None = None) -> Array: + mode: slicing.GatherScatterMode | str | None = None) -> Array: check_arraylike(name, data, segment_ids) - mode = lax.GatherScatterMode.FILL_OR_DROP if mode is None else mode + mode = slicing.GatherScatterMode.FILL_OR_DROP if mode is None else mode data = jnp.asarray(data) segment_ids = jnp.asarray(segment_ids) dtype = data.dtype @@ -207,7 +218,7 @@ def segment_sum(data: ArrayLike, indices_are_sorted: bool = False, unique_indices: bool = False, bucket_size: int | None = None, - mode: lax.GatherScatterMode | None = None) -> Array: + mode: slicing.GatherScatterMode | str | None = None) -> Array: """Computes the sum within segments of an array. Similar to TensorFlow's `segment_sum @@ -252,7 +263,7 @@ def segment_sum(data: ArrayLike, Array([1, 5, 4], dtype=int32) """ return _segment_update( - "segment_sum", data, segment_ids, lax.scatter_add, num_segments, + "segment_sum", data, segment_ids, slicing.scatter_add, num_segments, indices_are_sorted, unique_indices, bucket_size, reductions.sum, mode=mode) @@ -262,7 +273,7 @@ def segment_prod(data: ArrayLike, indices_are_sorted: bool = False, unique_indices: bool = False, bucket_size: int | None = None, - mode: lax.GatherScatterMode | None = None) -> Array: + mode: slicing.GatherScatterMode | str | None = None) -> Array: """Computes the product within segments of an array. Similar to TensorFlow's `segment_prod @@ -272,8 +283,7 @@ def segment_prod(data: ArrayLike, data: an array with the values to be reduced. segment_ids: an array with integer dtype that indicates the segments of `data` (along its leading axis) to be reduced. Values can be repeated and - need not be sorted. Values outside of the range [0, num_segments) are - dropped and do not contribute to the result. + need not be sorted. num_segments: optional, an int with nonnegative value indicating the number of segments. The default is set to be the minimum number of segments that would support all indices in ``segment_ids``, calculated as @@ -283,11 +293,11 @@ def segment_prod(data: ArrayLike, indices_are_sorted: whether ``segment_ids`` is known to be sorted. unique_indices: whether `segment_ids` is known to be free of duplicates. bucket_size: size of bucket to group indices into. ``segment_prod`` is - performed on each bucket separately to improve numerical stability of - addition. Default ``None`` means no bucketing. + performed on each bucket separately to improve numerical stability. + Default ``None`` means no bucketing. mode: a :class:`jax.lax.GatherScatterMode` value describing how out-of-bounds indices should be handled. By default, values outside of the - range [0, num_segments) are dropped and do not contribute to the sum. + range [0, num_segments) are dropped and do not contribute to the result. Returns: An array with shape :code:`(num_segments,) + data.shape[1:]` representing the @@ -308,7 +318,7 @@ def segment_prod(data: ArrayLike, Array([ 0, 6, 20], dtype=int32) """ return _segment_update( - "segment_prod", data, segment_ids, lax.scatter_mul, num_segments, + "segment_prod", data, segment_ids, slicing.scatter_mul, num_segments, indices_are_sorted, unique_indices, bucket_size, reductions.prod, mode=mode) @@ -318,7 +328,7 @@ def segment_max(data: ArrayLike, indices_are_sorted: bool = False, unique_indices: bool = False, bucket_size: int | None = None, - mode: lax.GatherScatterMode | None = None) -> Array: + mode: slicing.GatherScatterMode | str | None = None) -> Array: """Computes the maximum within segments of an array. Similar to TensorFlow's `segment_max @@ -328,8 +338,7 @@ def segment_max(data: ArrayLike, data: an array with the values to be reduced. segment_ids: an array with integer dtype that indicates the segments of `data` (along its leading axis) to be reduced. Values can be repeated and - need not be sorted. Values outside of the range [0, num_segments) are - dropped and do not contribute to the result. + need not be sorted. num_segments: optional, an int with nonnegative value indicating the number of segments. The default is set to be the minimum number of segments that would support all indices in ``segment_ids``, calculated as @@ -342,7 +351,7 @@ def segment_max(data: ArrayLike, performed on each bucket separately. Default ``None`` means no bucketing. mode: a :class:`jax.lax.GatherScatterMode` value describing how out-of-bounds indices should be handled. By default, values outside of the - range [0, num_segments) are dropped and do not contribute to the sum. + range [0, num_segments) are dropped and do not contribute to the result. Returns: An array with shape :code:`(num_segments,) + data.shape[1:]` representing the @@ -363,7 +372,7 @@ def segment_max(data: ArrayLike, Array([1, 3, 5], dtype=int32) """ return _segment_update( - "segment_max", data, segment_ids, lax.scatter_max, num_segments, + "segment_max", data, segment_ids, slicing.scatter_max, num_segments, indices_are_sorted, unique_indices, bucket_size, reductions.max, mode=mode) @@ -373,7 +382,7 @@ def segment_min(data: ArrayLike, indices_are_sorted: bool = False, unique_indices: bool = False, bucket_size: int | None = None, - mode: lax.GatherScatterMode | None = None) -> Array: + mode: slicing.GatherScatterMode | str | None = None) -> Array: """Computes the minimum within segments of an array. Similar to TensorFlow's `segment_min @@ -383,8 +392,7 @@ def segment_min(data: ArrayLike, data: an array with the values to be reduced. segment_ids: an array with integer dtype that indicates the segments of `data` (along its leading axis) to be reduced. Values can be repeated and - need not be sorted. Values outside of the range [0, num_segments) are - dropped and do not contribute to the result. + need not be sorted. num_segments: optional, an int with nonnegative value indicating the number of segments. The default is set to be the minimum number of segments that would support all indices in ``segment_ids``, calculated as @@ -397,7 +405,7 @@ def segment_min(data: ArrayLike, performed on each bucket separately. Default ``None`` means no bucketing. mode: a :class:`jax.lax.GatherScatterMode` value describing how out-of-bounds indices should be handled. By default, values outside of the - range [0, num_segments) are dropped and do not contribute to the sum. + range [0, num_segments) are dropped and do not contribute to the result. Returns: An array with shape :code:`(num_segments,) + data.shape[1:]` representing the @@ -418,5 +426,5 @@ def segment_min(data: ArrayLike, Array([0, 2, 4], dtype=int32) """ return _segment_update( - "segment_min", data, segment_ids, lax.scatter_min, num_segments, + "segment_min", data, segment_ids, slicing.scatter_min, num_segments, indices_are_sorted, unique_indices, bucket_size, reductions.min, mode=mode) diff --git a/jax/_src/ops/special.py b/jax/_src/ops/special.py index fe4c46394832..1f24d609e7be 100644 --- a/jax/_src/ops/special.py +++ b/jax/_src/ops/special.py @@ -16,12 +16,15 @@ from typing import overload, Literal -import jax -from jax import lax -from jax import numpy as jnp +from jax._src import config +from jax._src.lax import lax +from jax._src.numpy import lax_numpy as jnp +from jax._src.numpy import reductions +from jax._src.numpy import ufuncs from jax._src.numpy.reductions import _reduction_dims, Axis from jax._src.numpy.util import promote_args_inexact from jax._src.typing import Array, ArrayLike + import numpy as np # The definition of logsumexp is shared between jax.nn and jax.scipy, and @@ -47,16 +50,15 @@ def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, JAX implementation of :func:`scipy.special.logsumexp`. .. math:: - \mathrm{logsumexp}(a) = \mathrm{log} \sum_j b \cdot \mathrm{exp}(a_{ij}) + \operatorname{logsumexp} a = \log \sum_i b_i \exp a_i - where the :math:`j` indices range over one or more dimensions to be reduced. + where the :math:`i` indices range over one or more dimensions to be reduced. Args: a: the input array axis: int or sequence of ints, default=None. Axis along which the sum to be computed. If None, the sum is computed along all the axes. - b: scaling factors for :math:`\mathrm{exp}(a)`. Must be broadcastable to the - shape of `a`. + b: scaling factors for the exponentials. Must be broadcastable to the shape of `a`. keepdims: If ``True``, the axes that are reduced are left in the output as dimensions of size 1. return_sign: If ``True``, the output will be a ``(result, sign)`` pair, @@ -68,18 +70,21 @@ def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, Returns: Either an array ``result`` or a pair of arrays ``(result, sign)``, depending on the value of the ``return_sign`` argument. + + See also: + :func:`jax.nn.logmeanexp` """ if where is not None: a = jnp.where(where, a, 0) if b is not None: a_arr, b_arr = promote_args_inexact("logsumexp", a, b) - a_arr = jnp.where(b_arr != 0, a_arr, -jnp.inf) + a_arr = jnp.where(b_arr != 0, a_arr, -np.inf) else: a_arr, = promote_args_inexact("logsumexp", a) b_arr = a_arr # for type checking pos_dims, dims = _reduction_dims(a_arr, axis) - amax = jnp.max(a_arr.real, axis=dims, keepdims=keepdims, where=where, initial=-jnp.inf) - amax = lax.stop_gradient(lax.select(jnp.isfinite(amax), amax, lax.full_like(amax, 0))) + amax = reductions.max(a_arr.real, axis=dims, keepdims=keepdims, where=where, initial=-np.inf) + amax = lax.stop_gradient(lax.select(ufuncs.isfinite(amax), amax, lax.full_like(amax, 0))) amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims) exp_a = lax.exp(lax.sub(a_arr, amax_with_dims.astype(a_arr.dtype))) @@ -94,6 +99,6 @@ def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, if return_sign: return (out, sign) if b is not None and not np.issubdtype(out.dtype, np.complexfloating): - with jax.debug_nans(False): + with config.debug_nans(False): out = jnp.where(sign < 0, jnp.array(np.nan, dtype=out.dtype), out) return out diff --git a/jax/_src/pallas/BUILD b/jax/_src/pallas/BUILD index 91987167512c..469b31d9b326 100644 --- a/jax/_src/pallas/BUILD +++ b/jax/_src/pallas/BUILD @@ -38,19 +38,42 @@ py_library( "utils.py", ], deps = [ - "//jax", - "//jax:ad_util", - "//jax:api_util", - "//jax:config", - "//jax:core", - "//jax:dtypes", - "//jax:effects", - "//jax:mlir", - "//jax:partial_eval", - "//jax:pretty_printer", - "//jax:source_info_util", - "//jax:tree_util", - "//jax:util", + "//jax/_src:ad", + "//jax/_src:ad_util", + "//jax/_src:api", + "//jax/_src:api_util", + "//jax/_src:basearray", + "//jax/_src:batching", + "//jax/_src:checkify", + "//jax/_src:config", + "//jax/_src:core", + "//jax/_src:custom_derivatives", + "//jax/_src:debugging", + "//jax/_src:dtypes", + "//jax/_src:effects", + "//jax/_src:export", + "//jax/_src:frozen_dict", + "//jax/_src:hijax", + "//jax/_src:lax", + "//jax/_src:mlir", + "//jax/_src:numpy", + "//jax/_src:partial_eval", + "//jax/_src:pretty_printer", + "//jax/_src:source_info_util", + "//jax/_src:tree_util", + "//jax/_src:typing", + "//jax/_src:util", "//jax/_src/lib", ] + py_deps("numpy"), ) + +py_library( + name = "pallas_test_util", + srcs = [ + "pallas_test_util.py", + ], + deps = [ + ":pallas", + "//jax/_src:test_util", + ], +) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 206c2a73fbed..792559751a92 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -17,6 +17,7 @@ import collections from collections.abc import Callable, Iterable, Iterator, Sequence +from collections.abc import Hashable, Mapping import contextlib import copy import dataclasses @@ -24,24 +25,30 @@ import functools import itertools import threading -from typing import Any, ClassVar, Hashable, Protocol, Union, runtime_checkable +from typing import Any, ClassVar, Literal, Protocol, TypeAlias, Union, runtime_checkable + -import jax from jax._src import api_util +from jax._src.api import jit from jax._src import config from jax._src import core as jax_core from jax._src import dtypes +from jax._src import effects +from jax._src import frozen_dict from jax._src import linear_util as lu from jax._src import state from jax._src import tree_util +from jax._src import typing as jax_typing from jax._src import util from jax._src.export._export import export from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.state import discharge as state_discharge +from jax._src.state import indexing from jax._src.state import types as state_types from jax._src.state.types import TransformedRef -import jax.numpy as jnp +from jax._src import numpy as jnp + class DynamicGridDim: def __repr__(self): @@ -50,7 +57,7 @@ def __repr__(self): partial = functools.partial -GridElement = int | jax_core.Array +GridElement = int | jax_typing.Array GridName = Hashable GridNames = tuple[Hashable, ...] | None NamedGrid = tuple[tuple[GridName, int], ...] @@ -67,18 +74,79 @@ def __repr__(self): SEMAPHORE_INTERPRET_DTYPE = jnp.int16 SEMAPHORE_MAX_VALUE = jnp.iinfo(SEMAPHORE_INTERPRET_DTYPE).max +class AbstractSemaphoreTyRules: + @staticmethod + def pallas_interpret_element_aval(_) -> jax_core.ShapedArray: + return jax_core.ShapedArray((), SEMAPHORE_INTERPRET_DTYPE) + + @staticmethod + def physical_element_aval(_) -> jax_core.ShapedArray: + return jax_core.ShapedArray((), jnp.int32) + +# TODO(sharadmv): implement dtype rules for AbstractSemaphoreTy +class AbstractSemaphoreTy(dtypes.ExtendedDType): + name: str + _rules = AbstractSemaphoreTyRules + + def __repr__(self) -> str: + return self.name + + def __eq__(self, other): + return self.__class__ == other.__class__ + + def __hash__(self) -> int: + return hash(self.__class__) + +class semaphore_dtype(dtypes.extended): + """Common dtype for all kinds of semaphore dtypes. + + This is an abstract class that should never be instantiated, but rather + exists for the sake of ``jnp.issubdtype``. + """ + +class semaphore(semaphore_dtype): + """Regular semaphore dtype. + + Like its superclass, this class should never be instantiated. + """ + +class Semaphore(AbstractSemaphoreTy): + name = "semaphore" + type = semaphore + +class barrier_semaphore(semaphore_dtype): + """Barrier semaphore dtype. + + Like its superclass, this class should never be instantiated. + """ + +class BarrierSemaphore(AbstractSemaphoreTy): + name = "barrier_semaphore" + type = barrier_semaphore + +Backend = Literal["mosaic_tpu", "triton", "mosaic_gpu"] @runtime_checkable class CompilerParams(Protocol): """Base class for compiler parameters.""" - PLATFORM: ClassVar[str] + BACKEND: ClassVar[Backend] # Subclasses must be dataclasses. __dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]] @dataclasses.dataclass(frozen=True) class Buffered: + """Specifies how a block should be buffered for a pipeline. + + Attributes: + buffer_count: The number of buffers to use for multiple buffering. + use_lookahead: optional bool, indicates whether to use lookahead on the + buffer. Enabling lookahead allows the pipeline to begin fetching the next + changed block as soon as a slot is available, no matter how many + iterations ahead that block is. + """ buffer_count: int + use_lookahead: bool = False split_list = util.split_list @@ -90,35 +158,29 @@ class ShapedArrayWithMemorySpace(jax_core.ShapedArray): __slots__ = ["memory_space"] def __init__(self, shape, dtype, weak_type=False, sharding=None, - memory_space=None): - super().__init__(shape, dtype, weak_type=weak_type, sharding=sharding) + vma=frozenset(), memory_space=None): + super().__init__(shape, dtype, weak_type=weak_type, sharding=sharding, + vma=vma) self.memory_space = memory_space def __eq__(self, other): return super().__eq__(other) and self.memory_space == other.memory_space def __hash__(self): - return hash(( + return hash((self.shape, self.dtype, self.weak_type, self.sharding, + self.vma, self.memory_space)) + + def str_short(self, short_dtypes=False, mesh_axis_types=False): + return jax_core.str_short_aval( self.shape, self.dtype, - self.weak_type, - getattr(self, "sharding", None), + self.sharding.mesh, + self.sharding.spec, + self.vma, self.memory_space, - )) - - def str_short(self, short_dtypes=False): - dt_str = \ - dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name - dt_str = dt_str.replace("void", "float0") - shapestr = ",".join(map(str, self.shape)) - if hasattr(self, "sharding"): - sharding_str = f"{dt_str}[{shapestr}]({self.sharding})" - else: - sharding_str = "" - memoryspace_str = ( - "" if self.memory_space is None else f"<{self.memory_space}>" + short_dtypes, + mesh_axis_types, ) - return f"{dt_str}{memoryspace_str}[{shapestr}]{sharding_str}" def update( self, @@ -126,6 +188,7 @@ def update( dtype=None, weak_type=None, sharding=None, + vma=None, memory_space=None, ): if shape is None: @@ -135,11 +198,14 @@ def update( if weak_type is None: weak_type = self.weak_type if sharding is None: - sharding = getattr(self, "sharding", None) + sharding = self.sharding + if vma is None: + vma = self.vma if memory_space is None: memory_space = self.memory_space return ShapedArrayWithMemorySpace( - shape, dtype, weak_type, sharding=sharding, memory_space=memory_space + shape, dtype, weak_type, sharding=sharding, vma=vma, + memory_space=memory_space ) mlir.ir_type_handlers[ShapedArrayWithMemorySpace] = mlir._array_ir_types @@ -147,67 +213,38 @@ def update( @dataclasses.dataclass(frozen=True) class MemoryRef: """Like jax.ShapeDtypeStruct but with memory spaces.""" - shape: tuple[int, ...] - dtype: jnp.dtype + inner_aval: jax_core.AbstractValue # TODO(b/368122763): Unify memory space types across backends memory_space: Any def get_array_aval(self) -> jax_core.ShapedArray: - dtype = self.dtype + if not isinstance(self.inner_aval, jax_core.ShapedArray): + raise ValueError( + f"MemoryRef type must be a ShapedArray, got {type(self.inner_aval)}" + ) + dtype = self.inner_aval.dtype if not isinstance(dtype, (jnp.dtype, dtypes.ExtendedDType)): dtype = jnp.dtype(dtype) return ShapedArrayWithMemorySpace( - self.shape, dtype, memory_space=self.memory_space + self.inner_aval.shape, dtype, memory_space=self.memory_space ) - def get_ref_aval(self) -> TransformedRef | AbstractMemoryRef: + def get_ref_aval(self) -> TransformedRef | state.AbstractRef: # TODO(sharadmv): Clean this up. ShapedArrayWithMemorySpace fails when we # try to apply JAX ops to it. - return AbstractMemoryRef( - jax_core.ShapedArray(self.shape, self.dtype), self.memory_space) - - -class AbstractMemoryRef(state.AbstractRef): - __slots__ = ["inner_aval", "memory_space"] + return state.AbstractRef(self.inner_aval, self.memory_space) - inner_aval: jax_core.ShapedArray - - def __init__(self, inner_aval: jax_core.ShapedArray, memory_space: Any): - if isinstance(inner_aval, ShapedArrayWithMemorySpace): - if inner_aval.memory_space is not None: - assert inner_aval.memory_space == memory_space, ( - f"Mismatched memory spaces: {inner_aval.memory_space=}," - f" {memory_space=}" - ) - self.inner_aval = inner_aval - self.memory_space = memory_space - - def __repr__(self) -> str: - return f'MemRef<{self.memory_space}>{{{self.inner_aval.str_short()}}}' - - def update_weak_type(self, weak_type): - return AbstractMemoryRef( - self.inner_aval.update_weak_type(weak_type), self.memory_space) - - def update(self, inner_aval=None, memory_space=None): - inner_aval = self.inner_aval if inner_aval is None else inner_aval - memory_space = self.memory_space if memory_space is None else memory_space - return AbstractMemoryRef(inner_aval, memory_space) - - def to_tangent_aval(self): - return AbstractMemoryRef( - self.inner_aval.to_tangent_aval(), self.memory_space) - - # TODO(dougalm, sharadmv): figure out how to avoid needing this - def normalize(self): - return state.AbstractRef(self.inner_aval).normalize() + @property + def dtype(self): + return self.inner_aval.dtype - def __eq__(self, other): - return (type(self) is type(other) and self.inner_aval == other.inner_aval - and self.memory_space == other.memory_space) + @property + def shape(self): + return self.inner_aval.shape - def __hash__(self): - return hash((self.__class__, self.inner_aval, self.memory_space)) + def __lt__(self, other): + return (self.shape, self.dtype, self.memory_space) < ( + other.shape, other.dtype, other.memory_space) class MemorySpace(enum.Enum): @@ -219,6 +256,16 @@ class MemorySpace(enum.Enum): ANY = "any" # Unrestricted memory space (usually HBM) ERROR = "error" # Memory space for checkify errors. INDEX = "index" # Memory space for scalar prefetch arguments. + KEY = "key" # Memory space for PRNG keys. + HOST = "host" # Host memory space. + + def from_type(self, type: jax_core.AbstractValue) -> MemoryRef: + return MemoryRef(type, memory_space=self) + + def __call__(self, shape: tuple[int, ...], dtype: jnp.dtype): + # A convenience function for constructing MemoryRef types of ShapedArrays. + return self.from_type(jax_core.ShapedArray(shape, dtype)) + def __str__(self) -> str: return self.value @@ -262,7 +309,7 @@ def axis_frame() -> PallasGridContext: @dataclasses.dataclass(frozen=True) class GridAxis: - index: jax.Array + index: jax_typing.Array size: int # Stores the kernel execution position and the size along grid axes. @@ -283,49 +330,167 @@ def current_grid_env() -> GridEnv | None: return _pallas_tracing_env.grid_env_stack[-1] -class Mapped: - """Used as a block shape dimension to denote a mapped dimension. - A mapped dimension behaves like `1` except it is squeezed from the block. - See :ref:`pallas_blockspec` for more details. - """ - def __repr__(self): - return "Mapped" -mapped = Mapped() +@dataclasses.dataclass(frozen=True) +class Element: + """Use to index an array using an elementwise start index.""" + block_size: int + padding: tuple[int, int] = (0, 0) + def __str__(self): + if self.padding == (0, 0): + return f"Element({self.block_size})" + return f"Element({self.block_size}, padding={self.padding})" @dataclasses.dataclass(frozen=True) -class Unblocked: - padding: tuple[tuple[int, int], ...] | None = None - - def __repr__(self): - return f"Unblocked(padding={self.padding})" -unblocked = Unblocked() +class Squeezed: + """Represents a one-sized block dimension that is squeezed out in the kernel.""" +squeezed = Squeezed() +@dataclasses.dataclass(frozen=True) class Blocked: + """The default BlockShape type.""" + block_size: int + + def __str__(self): + return f"Blocked({self.block_size})" + +@dataclasses.dataclass(frozen=True) +class BoundedSlice: + """Allows to specify a bounded slice of a dimension. + + Specifically, the index_map need to return a ``pl.Slice/pl.ds`` for this + dimension. The start and size may be dynamic, as long as the size <= + block_size. + """ + block_size: int + def __repr__(self): - return "Blocked" -blocked = Blocked() + return f"BoundedSlice({self.block_size})" +BlockDim: TypeAlias = Element | Squeezed | Blocked | BoundedSlice -IndexingMode = Union[Blocked, Unblocked] def default_index_map(ndim: int) -> Callable: return lambda *args: (0,) * ndim + +def _canonicalize_block_dim(dim: BlockDim | int | None) -> BlockDim: + match dim: + case None: + return squeezed + case int(): + return Blocked(int(dim)) + case Squeezed() | Blocked() | Element() | BoundedSlice(): + return dim + case _: + # Handle case where the dim is a symbolic dimension so we assume it is + # Blocked. + if jax_core.is_symbolic_dim(dim): + return Blocked(dim) + try: + return Blocked(int(dim)) + except Exception as e: + raise ValueError( + f"Unsupported block dimension type: {type(dim)}. Allowed types:" + " `pl.Squeezed`, `pl.Blocked`, `pl.Element`, `int`, `None`." + ) from e + +def _canonicalize_block_shape(block_shape: Sequence[BlockDim | int | None] + ) -> tuple[BlockDim, ...]: + return tuple(_canonicalize_block_dim(dim) for dim in block_shape) + + +def _get_block_dim_size(dim: BlockDim) -> int: + match dim: + case Squeezed(): + return 1 + case Blocked(block_size): + return block_size + case Element(): + return dim.block_size + case BoundedSlice(block_size): + return block_size + case _: + raise ValueError(f"Unsupported block shape type: {type(dim)}") + +def get_block_size(dim: BlockDim | int | None) -> int: + match dim: + case int(): + return dim + case Squeezed() | None: + return 1 + case ( + Blocked(block_size) | Element(block_size, _) | BoundedSlice(block_size) + ): + return block_size + case _: + raise ValueError(f"Unsupported block shape type: {type(dim)}") + + +def _get_block_shape(block_shape: tuple[BlockDim, ...]) -> tuple[int, ...]: + return tuple(_get_block_dim_size(dim) for dim in block_shape) + +def _get_ref_block_shape(block_shape: tuple[BlockDim, ...]) -> tuple[int, ...]: + # Special handling for squeezed here (don't include Squeezed dims in the Ref + # shape). + return tuple( + _get_block_dim_size(dim) + for dim in block_shape + if not isinstance(dim, Squeezed) + ) + + +class _IndexMapFunc: + """Helper class that checks for index_map equality.""" + + def __init__(self, index_map): + self.index_map = index_map + functools.update_wrapper(self, self.index_map) + + def __eq__(self, other: object): + if not isinstance(other, _IndexMapFunc): + return NotImplemented + return self.index_map == other.index_map + + def __call__(self, *args, **kwargs): + out_indices = self.index_map(*args, **kwargs) + if isinstance(out_indices, list): + out_indices = tuple(out_indices) + if not isinstance(out_indices, tuple): + out_indices = (out_indices,) + return out_indices + + @dataclasses.dataclass class BlockSpec: """Specifies how an array should be sliced for each invocation of a kernel. - See :ref:`pallas_blockspec` for more details. + The `block_shape` is a sequence of `int | None`s, or `BlockDim` types (e.g. + `pl.Element`, `pl.Squeezed`, `pl.Blocked`, `pl.BoundedSlice`). Each of these + types specify the size of the block dimension. `None` is used to specify a + dimension that is squeezed out of the kernel. The `BlockDim` types allow for + more fine-grained control over the indexing of the dimension. The `index_map` + needs to return a tuple of the same length as `block_shape`, which each entry + depending on the type of `BlockDim`. + + See :ref:`pallas_blockspec` and the individual `BlockDim` type docstrings for + more details. """ # An internal canonicalized version is in BlockMapping. - block_shape: Sequence[int | None] | None = None + block_shape: Sequence[BlockDim | int | None] | None = None index_map: Callable[..., Any] | None = None memory_space: Any | None = dataclasses.field(kw_only=True, default=None) - indexing_mode: IndexingMode = dataclasses.field(kw_only=True, default=blocked) pipeline_mode: Buffered | None = None + def __post_init__(self): + if self.index_map is not None: + # TODO(sharadmv): Add this once we have a better way to handle + # index_map equality. + # self.index_map = _IndexMapFunc( + # traceback_util.api_boundary(self.index_map, repro_user_func=True)) + self.index_map = _IndexMapFunc(self.index_map) + def to_block_mapping( self, origin: OriginStr, @@ -335,17 +500,27 @@ def to_block_mapping( index_map_avals: Sequence[jax_core.AbstractValue], index_map_tree: tree_util.PyTreeDef, grid: GridMappingGrid, - mapped_dims: tuple[int, ...], + vmapped_dims: tuple[int, ...], + debug: bool = False, ) -> BlockMapping: + if self.block_shape is not None: + if not hasattr(array_aval, "shape"): + raise ValueError( + "Array type must have a `shape` attribute, but got" + f" {type(array_aval)}" + ) if self.index_map is None: index_map_func = default_index_map(len(array_aval.shape)) - api_util.save_wrapped_fun_sourceinfo(index_map_func, default_index_map) + index_map_dbg = api_util.debug_info("pallas_call index_map", + default_index_map, (),{} + )._replace(arg_names=("",) * len(index_map_avals)) + api_util.save_wrapped_fun_debug_info(index_map_func, index_map_dbg) else: index_map_func = self.index_map if self.block_shape is None: - block_shape = array_aval.shape + block_shape = _canonicalize_block_shape(array_aval.shape) else: - block_shape = self.block_shape # type: ignore + block_shape = _canonicalize_block_shape(self.block_shape) if len(array_aval.shape) != len(block_shape): raise ValueError( f"Block shape for {origin} (= {block_shape}) " @@ -353,16 +528,18 @@ def to_block_mapping( f"array shape {array_aval.shape}." ) - unmapped_block_shape = tuple(s for s in block_shape if s is not None) - block_array_aval = array_aval.update(shape=unmapped_block_shape) - if isinstance(array_aval, jax_core.DShapedArray): - # Get the "max" shape for the ragged array. + ref_block_shape = _get_ref_block_shape(block_shape) + if isinstance(array_aval, ShapedArrayWithMemorySpace): block_array_aval = jax_core.ShapedArray( - block_array_aval.shape, - block_array_aval.dtype, - block_array_aval.weak_type, + ref_block_shape, array_aval.dtype, array_aval.weak_type ) - block_aval = AbstractMemoryRef(block_array_aval, self.memory_space) + elif isinstance(array_aval, state_types.AbstractLinVal): + if not isinstance(array_aval.inner_aval, jax_core.ShapedArray): + raise NotImplementedError # TODO(mattjj,sharadmv) + block_array_aval = array_aval.inner_aval.update(shape=ref_block_shape) + else: + block_array_aval = array_aval.update(shape=ref_block_shape) + block_aval = state.AbstractRef(block_array_aval, self.memory_space) if ( not jax_core.is_constant_shape(block_aval.shape) @@ -376,50 +553,72 @@ def to_block_mapping( fake_index_map_args, fake_index_map_kwargs = \ index_map_tree.unflatten([False] * index_map_tree.num_leaves) - debug = api_util.debug_info("pallas_call index_map", - index_map_func, fake_index_map_args, - fake_index_map_kwargs) + debug_info = api_util.debug_info( + "pallas_call index_map", + index_map_func, + fake_index_map_args, + fake_index_map_kwargs, + ) flat_index_map_fun, index_map_out_tree_thunk = api_util.flatten_fun( - lu.wrap_init(index_map_func, debug_info=debug), index_map_tree) - with tracing_grid_env(grid, mapped_dims): - jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(index_map_func, debug_info=debug_info), index_map_tree + ) + with tracing_grid_env(grid, vmapped_dims): + jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic( flat_index_map_fun, index_map_avals ) + index_map_out_tree = index_map_out_tree_thunk() + unflat_avals = tree_util.tree_unflatten(index_map_out_tree, out_avals) - mapped_block_shape = tuple(mapped if s is None else s for s in block_shape) - if len(out_avals) != len(block_shape): + if len(unflat_avals) != len(block_shape): raise ValueError( - f"Index map function {debug.func_src_info} for " + f"Index map function {debug_info.func_src_info} for " f"{origin} must return " f"{len(block_shape)} values to match {block_shape=}. " - f"Currently returning {len(out_avals)} values." + f"Currently returning {len(unflat_avals)} values:" ) + # Verify types match + for i, (idx_aval, bd) in enumerate(zip(unflat_avals, block_shape)): + match bd: + case BoundedSlice(): + if not isinstance(idx_aval, indexing.Slice): + raise ValueError( + "index_map returned a value of type" + f" {type(idx_aval)} at position {i} with block dimension" + f" {bd} when it should be pl.Slice" + ) + case Blocked() | Element() | Squeezed() | int(): + if ( + not isinstance(idx_aval, jax_core.ShapedArray) + and not idx_aval.shape + ): + raise ValueError( + "index_map returned a value of type" + f" {type(idx_aval)} at position {i} with block dimension" + f" {bd} when it should be a scalar" + ) for i, ov in enumerate(out_avals): if ov.shape or ov.dtype not in [jnp.int32, jnp.int64]: raise ValueError( - f"Index map function {debug.func_src_info} for " + f"Index map function {debug_info.func_src_info} for " f"{origin} must return integer scalars. Output[{i}] has type " f"{ov}." ) if consts: raise ValueError( - f"Index map function {debug.func_src_info} for " + f"Index map function {debug_info.func_src_info} for " f"{origin} must not capture constants: {consts}" ) - array_aval_shape = _max_shape_from_aval(array_aval) - mapping = BlockMapping( - block_shape=mapped_block_shape, + block_shape=block_shape, transformed_block_aval=block_aval, # There are no transforms by default index_map_jaxpr=jax_core.ClosedJaxpr(jaxpr, consts), - indexing_mode=self.indexing_mode, - array_shape_dtype=jax.ShapeDtypeStruct( - array_aval_shape, array_aval.dtype - ), + index_map_out_tree=index_map_out_tree, + array_aval=array_aval, origin=origin, pipeline_mode=self.pipeline_mode, + debug=debug, ) mapping.check_invariants() return mapping @@ -453,30 +652,27 @@ class BlockMapping: """ # TODO(apaszke,sharadmv): Replace mapped dims in block_shape with a transform. # After all, it's just indexing out singleton dimensions. - block_shape: tuple[Mapped | int, ...] - transformed_block_aval: AbstractMemoryRef + block_shape: tuple[BlockDim, ...] + transformed_block_aval: state.AbstractRef index_map_jaxpr: jax_core.ClosedJaxpr - indexing_mode: IndexingMode - array_shape_dtype: jax.ShapeDtypeStruct # The whole array + index_map_out_tree: tree_util.PyTreeDef + array_aval: jax_core.ShapedArray # The whole array origin: OriginStr transforms: Sequence[MemoryRefTransform] = () pipeline_mode: Buffered | None = None + debug: bool = False def check_invariants(self) -> None: if not config.enable_checks.value: return - unmapped_block_shape = tuple(s for s in self.block_shape if s is not mapped) - assert unmapped_block_shape == self.ref_aval.shape, ( + ref_block_shape = _get_ref_block_shape(self.block_shape) + assert ref_block_shape == self.ref_aval.shape, ( self.block_shape, self.ref_aval.shape) - assert len(self.block_shape) == len(self.array_shape_dtype.shape), ( - self.block_shape, self.array_shape_dtype + assert len(self.block_shape) == len(self.array_aval.shape), ( + self.block_shape, self.array_aval ) assert not self.index_map_jaxpr.consts - assert len(self.block_shape) == len(self.index_map_jaxpr.out_avals), ( - self.block_shape, - self.index_map_jaxpr.out_avals, - ) assert all(ov.shape == () and (ov.dtype == jnp.int32 or ov.dtype == jnp.int64) for ov in self.index_map_jaxpr.out_avals), ( @@ -488,14 +684,14 @@ def replace(self, **kwargs): return new_self @property - def block_aval(self) -> AbstractMemoryRef: + def block_aval(self) -> state.AbstractRef: # If you hit this, make sure you take transforms into account and use either # ref_aval or transformed_block_aval. assert not self.transforms, "Lowering failed to handle transforms" return self.transformed_block_aval @property - def ref_aval(self) -> AbstractMemoryRef | TransformedRef: + def ref_aval(self) -> state.AbstractRef | TransformedRef: """Returns the abstract value of the Ref after transformations.""" if not self.transforms: return self.transformed_block_aval @@ -514,24 +710,86 @@ def compute_start_indices_interpret(self, loop_idx, *args): # updated values since we only care about the return values. block_indices, _ = split_list(block_indices_and_rest, [len(self.block_shape)]) - if isinstance(self.indexing_mode, Blocked): - return tuple(i if b is mapped else b * i - for b, i in zip(self.block_shape, block_indices)) - elif isinstance(self.indexing_mode, Unblocked): - return block_indices - else: - raise RuntimeError(f"Unknown indexing mode: {self.indexing_mode}") + def _get_start_index(i, b): + match b: + case Squeezed() | Element(): + return i + case Blocked(block_size): + return block_size * i + case _: + raise ValueError(f"Unsupported block dim type: {type(b)}") + return tuple( + _get_start_index(i, b) for i, b in zip(block_indices, self.block_shape) + ) def has_trivial_window(self): """If block shape is same as the array shape and index_map returns 0s.""" - for b, s in zip(self.block_shape, self.array_shape_dtype.shape): - if b != s and not (b is mapped and s == 1): + for b, s in zip(self.block_shape, self.array_aval.shape): + if _get_block_dim_size(b) != s: return False for atom in self.index_map_jaxpr.jaxpr.outvars: if not (isinstance(atom, jax_core.Literal) and atom.val == 0): return False return True + def to_block_spec(self) -> BlockSpec: + def index_map(*args): + flat_args = tree_util.tree_leaves(args) + return jax_core.jaxpr_as_fun(self.index_map_jaxpr)(*flat_args) + return BlockSpec( + self.block_shape, + index_map, + memory_space=self.block_aval.memory_space, + pipeline_mode=self.pipeline_mode, + ) + + def to_lojax( + self, index_map_avals, index_map_tree, grid, vmapped_dims + ) -> list[BlockMapping]: + block_aval = self.transformed_block_aval + if not block_aval.inner_aval.is_high: + return [self] + assert self.array_aval.is_high + lo_array_avals = self.array_aval.lo_ty() + block_spec = self.to_block_spec() + if not hasattr(block_aval.inner_aval, "lower_block_spec"): + raise ValueError( + f"Cannot lower block spec {block_spec} on {block_aval.inner_aval}." + " Need to define lower_block_spec method on the type." + ) + lo_block_specs = block_aval.inner_aval.lower_block_spec(block_spec) + return [ + _convert_block_spec_to_block_mapping( + bs, + self.origin, + lo_array_aval, + index_map_avals=index_map_avals, + index_map_tree=index_map_tree, + grid=grid, + vmapped_dims=vmapped_dims, + debug=self.debug, + ) + for bs, lo_array_aval in zip(lo_block_specs, lo_array_avals) + ] + + def __repr__(self): + if self.debug: + return ( + f"BlockMapping(block_shape={self.block_shape}, " + f"transformed_block_aval={self.transformed_block_aval}, " + f"index_map_jaxpr={self.index_map_jaxpr}, " + f"index_map_out_tree={self.index_map_out_tree}, " + f"array_aval={self.array_aval}, " + f"origin={self.origin}, " + f"transforms={self.transforms}, " + f"pipeline_mode={self.pipeline_mode}, " + f"debug={self.debug})" + ) + return f"BlockMapping(block_shape={self.block_shape})" + + def __str__(self): + return self.__repr__() + @contextlib.contextmanager def tracing_grid_env(grid: GridMappingGrid, mapped_dims: tuple[int, ...]): @@ -586,16 +844,18 @@ class GridMapping: block_mappings: tuple[BlockMapping, ...] # The inputs for tracing the index map: the tree and the flat avals index_map_tree: tree_util.PyTreeDef - index_map_avals: tuple[jax_core.AbstractValue] + index_map_avals: tuple[jax_core.AbstractValue, ...] # Which dimensions in `grid` are vmapped. vmapped_dims: tuple[int, ...] + scratch_avals: tuple[jax_core.AbstractValue, ...] num_index_operands: int num_inputs: int num_outputs: int - num_scratch_operands: int get_grid_indices: Callable | None = None local_grid_env: Callable | None = None + # Primarily dictates how much debugging information is printed. + debug: bool = False def check_invariants(self) -> None: if not config.enable_checks.value: return @@ -641,15 +901,14 @@ def replace(self, **kwargs) -> GridMapping: new_self.check_invariants() return new_self - @property - # TODO(necula): deprecate and then remove this property. - def mapped_dims(self) -> tuple[int, ...]: - return self.vmapped_dims - @property def num_dynamic_grid_bounds(self): return sum(b is dynamic_grid_dim for b in self.grid) + @property + def num_scratch_operands(self): + return len(self.scratch_avals) + @property def static_grid(self) -> StaticGrid: if self.num_dynamic_grid_bounds: @@ -696,14 +955,14 @@ def slice_scratch_ops(self): return slice(0, 0) @property - def in_shapes(self) -> Iterable[jax.ShapeDtypeStruct]: + def in_shapes(self) -> Iterable[jax_core.ShapeDtypeStruct]: """The shapes of *index, *inputs.""" index_shapes = ( - jax.ShapeDtypeStruct(ia.shape, ia.dtype) + jax_core.ShapeDtypeStruct(ia.shape, ia.dtype) for ia in self.index_map_avals[len(self.grid) :] ) inputs_shapes = ( - bm.array_shape_dtype + jax_core.ShapeDtypeStruct(bm.array_aval.shape, bm.array_aval.dtype) for bm in self.block_mappings[:self.num_inputs]) return itertools.chain(index_shapes, inputs_shapes) @@ -715,33 +974,77 @@ def block_mappings_output(self) -> Iterable[BlockMapping]: self.num_inputs + self.num_outputs) @property - def out_shapes(self) -> Iterable[jax.ShapeDtypeStruct]: + def out_shapes(self) -> Iterable[jax_core.ShapeDtypeStruct]: return tuple( - bm.array_shape_dtype for bm in self.block_mappings_output) + jax_core.ShapeDtypeStruct(bm.array_aval.shape, bm.array_aval.dtype) + for bm in self.block_mappings_output) + def to_lojax(self): + input_block_mappings, output_block_mappings, () = split_list( + self.block_mappings, + [self.num_inputs, self.num_inputs + self.num_outputs], + ) + updated_input_block_mappings = [ + lo_mapping + for bm in input_block_mappings + for lo_mapping in bm.to_lojax( + self.index_map_avals, + self.index_map_tree, + self.grid, + self.vmapped_dims, + ) + ] + updated_output_block_mappings = [ + lo_mapping + for bm in output_block_mappings + for lo_mapping in bm.to_lojax( + self.index_map_avals, + self.index_map_tree, + self.grid, + self.vmapped_dims, + ) + ] + new_num_inputs = len(updated_input_block_mappings) + new_num_outputs = len(updated_output_block_mappings) + updated_scratch_avals = [ + lo_aval + for aval in self.scratch_avals + for lo_aval in (aval.lo_ty() if aval.is_high else [aval]) + ] + updated_block_mappings = updated_input_block_mappings + updated_output_block_mappings + return self.replace(block_mappings=tuple(updated_block_mappings), + num_inputs=new_num_inputs, + num_outputs=new_num_outputs, + scratch_avals=tuple(updated_scratch_avals)) -def _is_valid_grid_dim(dim: int | jax.Array) -> bool: - if isinstance(dim, jax.Array): - return True - return jax_core.is_dim(dim) + def __repr__(self): + if self.debug: + return ( + f"GridMapping(grid={self.grid}, grid_names={self.grid_names}, " + f"block_mappings={self.block_mappings}, " + f"index_map_tree={self.index_map_tree}, " + f"index_map_avals={self.index_map_avals}, " + f"vmapped_dims={self.vmapped_dims}, " + f"num_index_operands={self.num_index_operands}, " + f"num_inputs={self.num_inputs}, " + f"num_outputs={self.num_outputs}, " + f"num_scratch_operands={self.num_scratch_operands}, " + f"get_grid_indices={self.get_grid_indices}, " + f"local_grid_env={self.local_grid_env}, " + f"debug={self.debug})" + ) + return ( + f"GridMapping(grid={self.grid}, block_mappings={self.block_mappings})" + ) + def __str__(self): + return self.__repr__() -def _max_shape_from_aval(array_aval: jax_core.ShapedArray): - array_aval_shape = list(array_aval.shape) - for i, s in enumerate(array_aval.shape): - try: - aval = jax_core.get_aval(s) - if isinstance(aval, jax_core.DShapedArray): - array_aval_shape[i] = aval.dtype.bound - except OverflowError as e: - # Note - there are annoying cases where on 32 bit hardware, - # a flattened index space may overflow - for these cases, - # we just take the shape as is. - # In most places, this is totally sound to do. - # For ragged/jumble inputs, this will fail downstream. - return array_aval.shape - return tuple(array_aval_shape) +def _is_valid_grid_dim(dim: int | jax_typing.Array) -> bool: + if isinstance(dim, jax_typing.Array): + return True + return jax_core.is_dim(dim) def _convert_block_spec_to_block_mapping( @@ -753,7 +1056,8 @@ def _convert_block_spec_to_block_mapping( index_map_avals: Sequence[jax_core.AbstractValue], index_map_tree: tree_util.PyTreeDef, grid: GridMappingGrid, - mapped_dims: tuple[int, ...], + vmapped_dims: tuple[int, ...], + debug: bool = False, ) -> BlockMapping: if block_spec is no_block_spec: block_spec = BlockSpec(None, None) @@ -763,20 +1067,25 @@ def _convert_block_spec_to_block_mapping( index_map_avals=index_map_avals, index_map_tree=index_map_tree, grid=grid, - mapped_dims=mapped_dims, + vmapped_dims=vmapped_dims, + debug=debug, ) + index_map_grid_aval = jax_core.ShapedArray((), jnp.int32) class ScratchShape(Protocol): def get_array_aval(self) -> jax_core.AbstractValue: ... - def get_ref_aval(self) -> state.AbstractRef: + def get_ref_aval(self) -> state.AbstractRef | TransformedRef: ... -ScratchShapeTree = Sequence[Union[ScratchShape, "ScratchShapeTree"]] +ScratchShapeTree = ( + Sequence[Union[ScratchShape, "ScratchShapeTree"]] + | Mapping[str, Union[ScratchShape, "ScratchShapeTree"]] +) @dataclasses.dataclass(init=False, kw_only=True) @@ -839,8 +1148,8 @@ def get_grid_mapping( out_avals: Sequence[jax_core.AbstractValue], out_tree: tree_util.PyTreeDef, out_origins: Sequence[OriginStr], -) -> tuple[tuple[jax_core.AbstractValue, ...], - GridMapping]: + debug: bool = False, +) -> tuple[tuple[jax_core.AbstractValue, ...], GridMapping]: if dynamic_shapes_export_enabled(): dim_check : Any = jax_core.is_dim else: @@ -879,15 +1188,14 @@ def get_grid_mapping( if grid_spec.scratch_shapes: flat_scratch_shapes, scratch_tree = tree_util.tree_flatten( grid_spec.scratch_shapes) - flat_scratch_avals = map(lambda s: s.get_ref_aval(), flat_scratch_shapes) - num_flat_scratch_operands = len(flat_scratch_avals) + flat_scratch_avals = tuple(s.get_ref_aval() for s in flat_scratch_shapes) jaxpr_scratch_avals = tree_util.tree_unflatten( scratch_tree, flat_scratch_avals) if not isinstance(jaxpr_scratch_avals, (tuple, list)): jaxpr_scratch_avals = (jaxpr_scratch_avals,) - del flat_scratch_avals, flat_scratch_shapes, scratch_tree + del flat_scratch_shapes, scratch_tree else: - num_flat_scratch_operands = 0 + flat_scratch_avals = () jaxpr_scratch_avals = () if grid_spec.in_specs is not no_block_spec: @@ -895,7 +1203,7 @@ def get_grid_mapping( if in_specs_tree != in_tree: raise ValueError( pytreedef_mismatch_err_msg("`in_specs`", in_specs_tree, - "inputs", in_tree)) + "`inputs`", in_tree)) else: flat_in_specs = [no_block_spec] * len(in_avals) @@ -905,7 +1213,8 @@ def get_grid_mapping( index_map_avals=index_map_avals, index_map_tree=index_map_tree, grid=grid_mapping_grid, # type: ignore[arg-type] - mapped_dims=(), + vmapped_dims=(), + debug=debug, ), flat_in_specs, in_origins[num_flat_scalar_prefetch:], @@ -927,7 +1236,8 @@ def get_grid_mapping( index_map_avals=index_map_avals, index_map_tree=index_map_tree, grid=grid_mapping_grid, # type: ignore[arg-type] - mapped_dims=(), + vmapped_dims=(), + debug=debug, ), flat_out_specs, out_origins, @@ -943,7 +1253,8 @@ def get_grid_mapping( num_index_operands=num_flat_scalar_prefetch, num_inputs=len(flat_in_specs), num_outputs=len(flat_out_specs), - num_scratch_operands=num_flat_scratch_operands, + scratch_avals=flat_scratch_avals, + debug=debug, ) grid_mapping.check_invariants() in_ref_avals = [bm.ref_aval for bm in in_block_mappings] @@ -993,23 +1304,74 @@ class CostEstimate: flops: int transcendentals: int bytes_accessed: int + remote_bytes_transferred: int = 0 def __post_init__(self): for k, v in dataclasses.asdict(self).items(): if not isinstance(v, int): - raise ValueError("All fields in CostEstimate must be ints. " - f"{k} is not an int: {type(v)}({v})") + raise ValueError( + "All fields in CostEstimate must be ints. " + f"{k} is not an int: {type(v)}({v})" + ) def to_json(self) -> bytes: return ( f'{{"flops": {self.flops}, "transcendentals": {self.transcendentals},' - f' "bytes_accessed": {self.bytes_accessed}}}' + f' "bytes_accessed": {self.bytes_accessed},' + f' "remote_bytes_transferred": {self.remote_bytes_transferred}}}' ).encode("ascii") +def get_memory_space_aval(aval: jax_core.AbstractValue) -> Any: + """Queries the memory space of an array.""" + if isinstance(aval, ShapedArrayWithMemorySpace): + return aval.memory_space + if isinstance(aval, state.AbstractRef): + if aval.memory_space is not None: + return aval.memory_space + return get_memory_space_aval(aval.inner_aval) + return None + +def _get_sds(aval: jax_core.AbstractValue): + match aval: + case state.AbstractRef(inner_aval=inner_aval): + if aval.memory_space is not None: + return aval.memory_space(aval.shape, aval.dtype) + return _get_sds(inner_aval) + case ShapedArrayWithMemorySpace(): + return aval.memory_space(aval.shape, aval.dtype) + case jax_core.ShapedArray(): + return jax_core.ShapeDtypeStruct( + aval.shape, aval.dtype, vma=aval.vma, sharding=aval.sharding + ) + case _: + raise ValueError(f"Unsupported abstract value: {aval}") + + core_map_p = jax_core.Primitive("core_map") core_map_p.multiple_results = True +def _core_map_is_high(*avals, jaxpr, **params): + del avals, params + return jaxpr.is_high +core_map_p.is_high = _core_map_is_high # type: ignore[method-assign] + +def _core_map_to_lojax(*consts, jaxpr, mesh, **params): + closed_hi_jaxpr = jax_core.ClosedJaxpr(jaxpr, consts) + with ( + tracing_grid_env(tuple(mesh.shape.values()), mapped_dims=()), + jax_core.extend_axis_env_nd(mesh.shape.items()), + ): + closed_lo_jaxpr = pe.lower_jaxpr(closed_hi_jaxpr) + assert not closed_lo_jaxpr.is_high + return core_map_p.bind( + *closed_lo_jaxpr.consts, + jaxpr=closed_lo_jaxpr.jaxpr, + mesh=mesh, + **params, + ) +core_map_p.to_lojax = _core_map_to_lojax + def core_map( mesh, @@ -1019,6 +1381,7 @@ def core_map( debug: bool = False, cost_estimate: CostEstimate | None = None, name: str | None = None, + metadata: dict[str, str] | None = None, ): """Runs a function on a mesh, mapping it over the devices in the mesh. @@ -1031,36 +1394,86 @@ def core_map( interpret: Whether to run the function in interpret mode. debug: Whether or not to out helpful debugging information. cost_estimate: The cost estimate of the function. + name: The (optional) name of the kernel. + metadata: Optional dictionary of information about the kernel that will be + serialized as JSON in the HLO. Can be used for debugging and analysis. """ def wrapped(f): - name_ = name or f.__name__ flat_args, in_tree = tree_util.tree_flatten(((), {})) + debug_info = api_util.debug_info("pallas_core_map", f, (), {}) flat_fun, out_tree_thunk = api_util.flatten_fun( - lu.wrap_init(f, - debug_info=api_util.debug_info("pallas_core_map", f, - (), {})), - in_tree) - with jax_core.extend_axis_env_nd(mesh.shape.items()): - jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, flat_args) - out = core_map_p.bind(*consts, jaxpr=jaxpr, mesh=mesh, - compiler_params=compiler_params, - interpret=interpret, - debug=debug, - cost_estimate=cost_estimate, name=name_) - if out: - raise ValueError("core_map-ped functions must not return any outputs.") - return tree_util.tree_unflatten(out_tree_thunk(), out) + lu.wrap_init(f, debug_info=debug_info), in_tree + ) + with ( + tracing_grid_env(tuple(mesh.shape.values()), mapped_dims=()), + jax_core.extend_axis_env_nd(mesh.shape.items()), + ): + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, flat_args) + + out_tree = out_tree_thunk() + if out_tree != tree_util.tree_structure(None): + raise ValueError( + f"The kernel function in core_map {debug_info.func_src_info} should" + f" return None. It returns a PyTree: {out_tree}." + ) + + out = core_map_p.bind( + *consts, + jaxpr=jaxpr, + debug_info=debug_info, + mesh=mesh, + compiler_params=compiler_params, + interpret=( + config.pallas_tpu_interpret_mode_context_manager.value or interpret + ), + debug=debug, + cost_estimate=cost_estimate, + name=name or util.fun_name(f), + metadata=frozen_dict.FrozenDict(metadata) + if metadata is not None + else None, + ) + return tree_util.tree_unflatten(out_tree, out) + return wrapped +# TODO(sharadmv,ivyzheng): remove this once we use axis dicts primarily +class CommsEffect(effects.Effect): + pass + +comms_effect = CommsEffect() +effects.lowerable_effects.add_type(CommsEffect) +effects.control_flow_allowed_effects.add_type(CommsEffect) +effects.remat_allowed_effects.add_type(CommsEffect) +effects.custom_derivatives_allowed_effects.add_type(CommsEffect) + +kernel_local_effects: effects.EffectTypeSet = effects.EffectTypeSet() + @core_map_p.def_effectful_abstract_eval -def _core_map_abstract_eval(*args, jaxpr, mesh, **_): +def _core_map_abstract_eval(*args, jaxpr, mesh, **kwargs): del args if jaxpr.outvars: raise ValueError("core_map must not return any outputs.") + interpret = kwargs.get('interpret', False) effs = set() + if interpret: + try: + from jax._src.pallas.mosaic.interpret import interpret_pallas_call as mosaic_tpu_interpret # Avoid circular dependency. + if isinstance(interpret, mosaic_tpu_interpret.InterpretParams): + effs = mosaic_tpu_interpret.get_interpret_effects() + except ImportError: + pass + try: + from jax._src.pallas.mosaic_gpu.interpret import interpret_pallas_call as mosaic_gpu_interpret # Avoid circular dependency. + if isinstance(interpret, mosaic_gpu_interpret.InterpretParams): + effs = mosaic_gpu_interpret.get_interpret_effects() + except ImportError: + pass for eff in jaxpr.effects: - if mesh.discharges_effect(eff): + if mesh.discharges_effect(eff) or isinstance(eff, CommsEffect): + continue + if kernel_local_effects.contains(eff): continue if not isinstance(eff, jax_core.NamedAxisEffect): effs.add(eff) @@ -1070,10 +1483,23 @@ def _core_map_abstract_eval(*args, jaxpr, mesh, **_): return [], effs +def core_map_lowering_rule(ctx: mlir.LoweringRuleContext, + *args, + jaxpr, + **kwargs + ): + del ctx, args, kwargs + raise ValueError( + "Attempted to lower core_map without discharging. This can happen if " + "the core_map body does not modify any Refs or have other observable " + f"side-effects.\n Jaxpr of the body: {jaxpr}") +mlir.register_lowering(core_map_p, core_map_lowering_rule) + + class Mesh(Protocol): @property - def backend(self) -> str: + def backend(self) -> Backend: ... @property @@ -1084,6 +1510,32 @@ def shape(self) -> collections.OrderedDict[object, int]: _core_map_mesh_rules: dict[type[Any], Callable[..., Any]] = {} +with_memory_space_constraint_p = jax_core.Primitive( + 'with_memory_space_constraint') + +@with_memory_space_constraint_p.def_impl +def with_memory_space_constraint_impl(x, *, memory_space): + del x, memory_space + raise ValueError("Cannot eagerly run with_memory_space_constraint.") + + +@with_memory_space_constraint_p.def_abstract_eval +def with_memory_space_constraint_abstract_eval(x, *, memory_space): + if not isinstance(x, jax_core.ShapedArray): + raise NotImplementedError("with_memory_space_constraint only supports " + "arrays.") + return ShapedArrayWithMemorySpace( + x.shape, x.dtype, memory_space=memory_space + ) + +def with_memory_space_constraint_lowering_rule(ctx, x, *, memory_space): + del ctx, memory_space + return [x] +mlir.register_lowering( + with_memory_space_constraint_p, with_memory_space_constraint_lowering_rule +) + + def default_mesh_discharge_rule( in_avals, out_avals, @@ -1095,9 +1547,18 @@ def default_mesh_discharge_rule( interpret, cost_estimate, name, + memory_space=MemorySpace.ANY, + metadata, + scratch_shapes, ): """Discharges a ``core_map`` over a mesh to a ``pallas_call``.""" - del out_avals # Unused. + default_memory_space = memory_space + if not all( + isinstance(aval, state.AbstractRef) for aval in (in_avals + out_avals) + ): + raise ValueError( + "default_mesh_discharge_rule only supports Ref inputs/outputs." + ) def body(*args): # Due to aliasing, ``args`` contains aliased inputs and outputs so we @@ -1111,26 +1572,40 @@ def body(*args): for eff in jaxpr.effects if isinstance(eff, state_types.WriteEffect) ) - any_spec = BlockSpec(memory_space=MemorySpace.ANY) - grid_spec = GridSpec( - grid=tuple(mesh.shape.items()), - in_specs=[any_spec] * len(in_avals), - out_specs=[any_spec] * len(modified_idxs), - ) + in_memory_spaces = [get_memory_space_aval(aval) for aval in in_avals] + in_memory_spaces = [ + memory_space if m is None else m for m in in_memory_spaces + ] + args = [ + with_memory_space_constraint_p.bind(arg, memory_space=memory_space) + if memory_space is not None and memory_space is not default_memory_space else arg + for arg, memory_space in zip(args, in_memory_spaces) + ] + in_specs = [ + BlockSpec(memory_space=memory_space) for memory_space in in_memory_spaces + ] + out_specs = [in_specs[idx] for idx in modified_idxs] + out_shapes = [_get_sds(in_avals[idx]) for idx in modified_idxs] from jax._src.pallas import pallas_call # Avoid circular dependency. outs = pallas_call._pallas_call( body, name=name, - out_shape=[in_avals[idx] for idx in modified_idxs], + out_shape=out_shapes, input_output_aliases={ in_idx: out_idx for out_idx, in_idx in enumerate(modified_idxs) }, - grid_spec=grid_spec, + grid_spec=GridSpec( + grid=tuple(mesh.shape.items()), + in_specs=in_specs, + out_specs=out_specs, + scratch_shapes=scratch_shapes, + ), mesh=mesh, compiler_params=compiler_params, interpret=interpret, debug=debug, cost_estimate=cost_estimate, + metadata=metadata, )(*args) # ``outs`` lacks the unmodified inputs. Add them back in. all_outs = [None] * len(args) @@ -1140,21 +1615,65 @@ def body(*args): @state_discharge.register_discharge_rule(core_map_p) -def _core_map_discharge_rule(in_avals, out_avals, *args_flat, jaxpr, mesh, **kwargs): +def _core_map_discharge_rule(in_avals, out_avals, *args_flat, jaxpr, debug_info, mesh, **kwargs): if type(mesh) not in _core_map_mesh_rules: raise NotImplementedError(f"Mesh type {type(mesh)} not supported.") + if jaxpr.constvars: + # The mapped jaxpr can only close over refs. Closing over anything else, + # including arrays, is not allowed -- these must be passed into the jaxpr + # as inputs. + consts_avals = [ + aval + for var in jaxpr.constvars + if not isinstance(aval := var.aval, state.AbstractRef) + ] + is_scalar_const_aval = [ + isinstance(aval, jax_core.ShapedArray) and not aval.shape + for aval in consts_avals + ] + if not all(is_scalar_const_aval): + ctx = jax_core.JaxprPpContext() + non_scalar_const_avals = [ + aval + for aval, is_scalar in zip(consts_avals, is_scalar_const_aval) + if not is_scalar + ] + non_scalar_const_pp_avals = ", ".join( + jax_core.pp_aval(aval, ctx) for aval in non_scalar_const_avals + ) + raise ValueError( + "The kernel function in core_map" + f" {debug_info.func_src_info} captures non-scalar constants" + f" [{non_scalar_const_pp_avals}]. You should pass them as inputs." + ) return _core_map_mesh_rules[type(mesh)]( in_avals, out_avals, *args_flat, jaxpr=jaxpr, mesh=mesh, **kwargs ) def _core_map_typecheck_rule(_, *in_atoms, jaxpr, mesh, **kwargs): - del in_atoms, kwargs + del in_atoms with jax_core.extend_axis_env_nd(tuple(mesh.shape.items())): jax_core.check_jaxpr(jaxpr) + interpret = kwargs.get('interpret', False) effs = set() + if interpret: + try: + from jax._src.pallas.mosaic.interpret import interpret_pallas_call as mosaic_tpu_interpret # Avoid circular dependency. + if isinstance(interpret, mosaic_tpu_interpret.InterpretParams): + effs = mosaic_tpu_interpret.get_interpret_effects() + except ImportError: + pass + try: + from jax._src.pallas.mosaic_gpu.interpret import interpret_pallas_call as mosaic_gpu_interpret # Avoid circular dependency. + if isinstance(interpret, mosaic_gpu_interpret.InterpretParams): + effs = mosaic_gpu_interpret.get_interpret_effects() + except ImportError: + pass for eff in jaxpr.effects: - if mesh.discharges_effect(eff): + if mesh.discharges_effect(eff) or isinstance(eff, CommsEffect): + continue + if kernel_local_effects.contains(eff): continue if not isinstance(eff, jax_core.NamedAxisEffect): effs.add(eff) @@ -1166,11 +1685,19 @@ def _core_map_typecheck_rule(_, *in_atoms, jaxpr, mesh, **kwargs): def lower_as_mlir( - f, *args, dynamic_shapes=False, device=None, static_argnames=(), **kwargs + f, + *args, + dynamic_shapes=False, + device=None, + static_argnames=(), + platforms=None, + **kwargs, ) -> mlir.ir.Module: with pallas_export_experimental(dynamic_shapes): - f = jax.jit(f, device=device, static_argnames=static_argnames) - exported = export(f, platforms=["tpu"])(*args, **kwargs) + f = jit(f, device=device, static_argnames=static_argnames) + if platforms is None: + platforms = ["tpu"] + exported = export(f, platforms=platforms)(*args, **kwargs) stablehlo = exported.mlir_module() return stablehlo # type: ignore[return-value] @@ -1179,3 +1706,18 @@ def lower_as_mlir( _out_shape_to_aval_mapping: dict[ type[Any], Callable[[Any], jax_core.AbstractValue] ] = {} + + +def _core_map_partial_eval_custom(saveable, unks_in, inst_in, eqn): + assert all(inst_in) + if all(unks_in): + return None, eqn, [], [], [] # purely unknown + elif not any(unks_in): + return eqn, eqn, [], [], [] # full remat + else: + # Some values, e.g. empty refs or refs initialized to constant zero, can be + # 'known', but really they belong in the staged/tangent computation. We + # encounter them here as known inputs mixed in with unknown/tangent inputs, + # which tells us that this core_map is really a purely tangent computation. + return None, eqn, [], [], [] +pe.partial_eval_jaxpr_custom_rules[core_map_p] = _core_map_partial_eval_custom diff --git a/jax/_src/pallas/cost_estimate.py b/jax/_src/pallas/cost_estimate.py index 73db4a2e2d4a..83d35e2ae977 100644 --- a/jax/_src/pallas/cost_estimate.py +++ b/jax/_src/pallas/cost_estimate.py @@ -15,9 +15,10 @@ import dataclasses import functools import math -from typing import Any, Sequence +from typing import Any +from collections.abc import Sequence -import jax +from jax._src import tree_util from jax._src import api_util from jax._src import core as jax_core from jax._src import custom_derivatives @@ -64,12 +65,11 @@ def cost_estimate_jaxpr( total_cost = CostEstimate(flops=0, transcendentals=0, bytes_accessed=0) for eqn in jaxpr.eqns: - _, bind_params = eqn.primitive.get_bind_params(eqn.params) rule = _cost_rules.get(eqn.primitive, None) if rule is not None: context = Context(avals_in=[v.aval for v in eqn.invars], avals_out=[v.aval for v in eqn.outvars]) - op_cost = rule(context, **bind_params) + op_cost = rule(context, **eqn.params) total_cost = total_cost + op_cost return pallas_core.CostEstimate( flops=total_cost.flops, @@ -89,15 +89,15 @@ def estimate_cost(fun, *args, **kwargs) -> pallas_core.CostEstimate: Returns: A pallas_core.CostEstimate object containing the cost estimate. """ - flattened_args, treedef = jax.tree.flatten(args) + flattened_args, treedef = tree_util.tree_flatten(args) partial_fun = functools.partial(fun, **kwargs) wrapped_fun, _ = api_util.flatten_fun_nokwargs( lu.wrap_init(partial_fun, debug_info=api_util.debug_info("cost_estimate", fun, - args, kwargs)), + args, {})), treedef) avals = [jax_core.ShapedArray(a.shape, a.dtype) for a in flattened_args] - jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals) + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals) estimate = cost_estimate_jaxpr(jax_core.ClosedJaxpr(jaxpr, consts)) input_bytes = sum( math.prod(a.shape) * a.dtype.itemsize for a in flattened_args) @@ -205,10 +205,12 @@ def dot_general_cost_rule(ctx: Context, assert len(lhs_batch_dims) == len(rhs_batch_dims) flops = 1 # Flops along a contracting dim is 2*dim (addition and multiplication) + contracting_flops = 1 for i in range(len(lhs_contracting_dims)): lhs_dim, rhs_dim = lhs_contracting_dims[i], rhs_contracting_dims[i] assert x_shape[lhs_dim] == y_shape[rhs_dim] - flops *= 2 * x_shape[lhs_dim] + contracting_flops *= x_shape[lhs_dim] + flops *= 2 * contracting_flops # Now we handle all other dimensions. for i, lhs_dim in enumerate(x_shape): if i in lhs_contracting_dims: @@ -237,17 +239,17 @@ def _pjit_cost_rule(ctx, *, jaxpr: jax_core.ClosedJaxpr, **_): transcendentals=inner_cost.transcendentals, bytes_accessed=inner_cost.bytes_accessed, ) -register_cost_rule(pjit.pjit_p, _pjit_cost_rule) +register_cost_rule(pjit.jit_p, _pjit_cost_rule) -def _custom_vjp_rule(ctx, *, fun_jaxpr: jax_core.ClosedJaxpr, **_): +def _custom_vjp_rule(ctx, *, call_jaxpr: jax_core.ClosedJaxpr, **_): del ctx - inner_cost = cost_estimate_jaxpr(fun_jaxpr) + inner_cost = cost_estimate_jaxpr(call_jaxpr) return CostEstimate( flops=inner_cost.flops, transcendentals=inner_cost.transcendentals, bytes_accessed=inner_cost.bytes_accessed, ) -register_cost_rule(custom_derivatives.custom_vjp_call_jaxpr_p, _custom_vjp_rule) +register_cost_rule(custom_derivatives.custom_vjp_call_p, _custom_vjp_rule) def _run_state_rule(*_, jaxpr: jax_core.Jaxpr, **_2): inner_cost = cost_estimate_jaxpr(pe.close_jaxpr(jaxpr)) diff --git a/jax/_src/pallas/fuser/BUILD b/jax/_src/pallas/fuser/BUILD index 66bbac33aabb..fa52ef43ed4c 100644 --- a/jax/_src/pallas/fuser/BUILD +++ b/jax/_src/pallas/fuser/BUILD @@ -33,7 +33,8 @@ pytype_strict_library( deps = [ ":block_spec", ":custom_evaluate", - ":fusable", + ":custom_fusion", + ":fusible", ":fusion", ":jaxpr_fusion", ], @@ -47,30 +48,60 @@ pytype_strict_library( deps = [ ":fuser_utils", "//jax", - "//jax:ad_util", - "//jax:api_util", - "//jax:core", - "//jax:partial_eval", - "//jax:tree_util", - "//jax:util", + "//jax/_src:ad_util", + "//jax/_src:api", + "//jax/_src:api_util", + "//jax/_src:core", + "//jax/_src:custom_derivatives", + "//jax/_src:hijax", + "//jax/_src:lax", + "//jax/_src:partial_eval", + "//jax/_src:random", + "//jax/_src:traceback_util", + "//jax/_src:tree_util", + "//jax/_src:typing", + "//jax/_src:util", "//jax/_src/pallas", ] + py_deps("numpy"), ) pytype_strict_library( - name = "fusable", + name = "custom_fusion", srcs = [ - "fusable.py", + "custom_fusion_lib.py", + ], + deps = [ + ":block_spec", + "//jax/_src:api_util", + "//jax/_src:core", + "//jax/_src:custom_api_util", + "//jax/_src:lax", + "//jax/_src:mlir", + "//jax/_src:partial_eval", + "//jax/_src:traceback_util", + "//jax/_src:tree_util", + "//jax/_src:util", + "//jax/_src/pallas", + "//jax/_src/pallas/mosaic:lowering", + ], +) + +pytype_strict_library( + name = "fusible", + srcs = [ + "fusible.py", ], deps = [ ":fusion", "//jax", - "//jax:api_util", - "//jax:core", - "//jax:mlir", - "//jax:partial_eval", - "//jax:tree_util", - "//jax:util", + "//jax/_src:api_util", + "//jax/_src:batching", + "//jax/_src:core", + "//jax/_src:mlir", + "//jax/_src:partial_eval", + "//jax/_src:traceback_util", + "//jax/_src:tree_util", + "//jax/_src:util", ], ) @@ -81,7 +112,7 @@ pytype_strict_library( ], deps = [ "//jax", - "//jax:util", + "//jax/_src:util", ], ) @@ -91,33 +122,37 @@ pytype_strict_library( "jaxpr_fusion.py", ], deps = [ - ":fusable", - ":fusable_dtype", + ":fusible", + ":fusible_dtype", ":fusion", "//jax", - "//jax:api_util", - "//jax:core", - "//jax:partial_eval", - "//jax:tree_util", + "//jax/_src:api_util", + "//jax/_src:core", + "//jax/_src:partial_eval", + "//jax/_src:traceback_util", + "//jax/_src:tree_util", + "//jax/_src:util", ], ) pytype_strict_library( - name = "fusable_dtype", + name = "fusible_dtype", srcs = [ - "fusable_dtype.py", + "fusible_dtype.py", ], deps = [ ":block_spec", - ":fusable", + ":fusible", "//jax", - "//jax:api_util", - "//jax:core", - "//jax:dtypes", - "//jax:partial_eval", - "//jax:source_info_util", - "//jax:tree_util", - "//jax:util", + "//jax/_src:api_util", + "//jax/_src:core", + "//jax/_src:custom_derivatives", + "//jax/_src:dtypes", + "//jax/_src:lax", + "//jax/_src:partial_eval", + "//jax/_src:source_info_util", + "//jax/_src:tree_util", + "//jax/_src:util", "//jax/_src/pallas", ], ) @@ -128,10 +163,10 @@ pytype_strict_library( deps = [ ":fuser_utils", "//jax", - "//jax:core", - "//jax:source_info_util", - "//jax:tree_util", - "//jax:util", + "//jax/_src:core", + "//jax/_src:source_info_util", + "//jax/_src:tree_util", + "//jax/_src:util", ], ) @@ -139,9 +174,9 @@ pytype_strict_library( name = "fuser_utils", srcs = ["fuser_utils.py"], deps = [ - "//jax:api_util", - "//jax:core", - "//jax:partial_eval", - "//jax:tree_util", + "//jax/_src:api_util", + "//jax/_src:core", + "//jax/_src:partial_eval", + "//jax/_src:tree_util", ], ) diff --git a/jax/_src/pallas/fuser/__init__.py b/jax/_src/pallas/fuser/__init__.py index 3295c8f1061a..fb4a47679293 100644 --- a/jax/_src/pallas/fuser/__init__.py +++ b/jax/_src/pallas/fuser/__init__.py @@ -17,6 +17,7 @@ from jax._src.pallas.fuser.block_spec import pull_block_spec as pull_block_spec from jax._src.pallas.fuser.block_spec import push_block_spec as push_block_spec from jax._src.pallas.fuser.custom_evaluate import evaluate as evaluate -from jax._src.pallas.fuser.fusable import fusable as fusable +from jax._src.pallas.fuser.custom_fusion_lib import custom_fusion as custom_fusion +from jax._src.pallas.fuser.fusible import fusible as fusible from jax._src.pallas.fuser.fusion import Fusion as Fusion from jax._src.pallas.fuser.jaxpr_fusion import fuse as fuse diff --git a/jax/_src/pallas/fuser/block_spec.py b/jax/_src/pallas/fuser/block_spec.py index de0cdd204f3c..1d9a1f028c38 100644 --- a/jax/_src/pallas/fuser/block_spec.py +++ b/jax/_src/pallas/fuser/block_spec.py @@ -16,24 +16,33 @@ from __future__ import annotations +from collections.abc import Callable, Sequence import contextlib import dataclasses import enum import functools import threading -from typing import Any, Callable, Protocol, Sequence +from typing import Any, Protocol import jax from jax import lax from jax._src import ad_util from jax._src import core from jax._src import custom_derivatives +from jax._src import hijax from jax._src import pjit +from jax._src import prng +from jax._src import state from jax._src import tree_util +from jax._src import typing from jax._src import util from jax._src.interpreters import partial_eval as pe from jax._src.pallas import core as pallas_core +from jax._src.pallas import utils as pallas_utils from jax._src.pallas.fuser import fuser_utils +from jax._src.state import indexing +from jax._src.state import primitives as state_primitives +from jax._src.traceback_util import api_boundary import jax.numpy as jnp import numpy as np @@ -67,6 +76,7 @@ class PushRuleContext: avals_out: tuple[core.AbstractValue, ...] +@functools.partial(api_boundary, repro_api_name="fuser.make_scalar_prefetch_handler") def make_scalar_prefetch_handler(*args): def scalar_prefetch_getter(*sp_inputs): result = sp_inputs @@ -77,21 +87,18 @@ def scalar_prefetch_getter(*sp_inputs): return scalar_prefetch_getter -def _default_eval_fn(eqn, eval_ctx, *args): - del eval_ctx - out = eqn.primitive.bind(*args, **eqn.params) - if eqn.primitive.multiple_results: - return out - return [out] - - -def _wrap_eval_fn(primitive, eval_fn): - def wrapped(*args): - if primitive.multiple_results: - return eval_fn(*args) - return [eval_fn(*args)] - - return wrapped +def _block_size(dim: pallas_core.Element | int | None) -> int | None: + match dim: + case ( + pallas_core.Element() + | pallas_core.BoundedSlice() + | pallas_core.Blocked() + ): + return dim.block_size + case pallas_core.Squeezed() | None: + return None + case _: + return dim # pytype: disable=bad-return-type @dataclasses.dataclass @@ -144,8 +151,8 @@ class KernelEvalContext: program_ids: tuple[int | jax.Array, ...] | None avals_in: tuple[core.AbstractValue, ...] | None avals_out: tuple[core.AbstractValue, ...] | None - in_block_specs: tuple[pallas_core.BlockSpec, ...] | None - out_block_specs: tuple[pallas_core.BlockSpec, ...] | None + in_block_specs: tuple[pallas_core.BlockSpec, ...] + out_block_specs: tuple[pallas_core.BlockSpec, ...] grid: tuple[int | jax.Array, ...] | None scalar_prefetch_handler: Any | None out_usages: tuple[set[Usage], ...] | None @@ -170,8 +177,14 @@ def get_out_block_indices(self): _illegal = object() -_sp_env = threading.local() -_sp_env.scalar_prefetch = None + +class _SpEnv(threading.local): + + def __init__(self): + self.scalar_prefetch = None + + +_sp_env = _SpEnv() @contextlib.contextmanager @@ -192,7 +205,7 @@ def _wrap_block_spec_scalar_prefetch( block_spec: pallas_core.BlockSpec, num_grid_args: int, ) -> pallas_core.BlockSpec: - if block_spec is pallas_core.no_block_spec: + if block_spec is pallas_core.no_block_spec or block_spec.index_map is None: return block_spec def new_index_map(*args_and_scalar_prefetch): @@ -225,6 +238,7 @@ def new_index_map(*args): return out_block_spec +@functools.partial(api_boundary, repro_api_name="fuser.pull_block_spec") def pull_block_spec( f: Callable, out_block_specs: pallas_core.BlockSpec | tuple[pallas_core.BlockSpec, ...], @@ -236,9 +250,7 @@ def wrapped(*args, **kwargs): jaxpr, consts, in_tree, out_tree_ = fuser_utils.make_jaxpr( f, *args, **kwargs ) - # TODO(sharadmv): handle these consts better, they should correspond to - # scalar prefetch. - del consts, out_tree_ + del out_tree_ jaxpr_out_usages = [{Usage.REGULAR}] * len(jaxpr.outvars) block_specs_ = jax.tree.map( _unwrap_block_spec_scalar_prefetch, out_block_specs @@ -251,15 +263,17 @@ def wrapped(*args, **kwargs): ) assert all(used_invars) assert all(used_consts) + read_usage_env = compute_usage(jaxpr, jaxpr_out_usages) in_block_specs, env, read_usage_env = _pull_block_spec( jaxpr, tuple(flat_block_specs), - jaxpr_out_usages, scalar_prefetch_handler=scalar_prefetch_handler, + read_usage_env=read_usage_env, grid=grid, ) kernel_fn = make_kernel_function( jaxpr, + consts, in_tree, out_tree, read_usage_env, @@ -282,11 +296,39 @@ def wrapped(*args, **kwargs): return wrapped +def _block_dim_equal( + b1: int | pallas_core.BlockDim | None, b2: int | pallas_core.BlockDim | None +) -> bool: + block_size1 = pallas_core.get_block_size(b1) + block_size2 = pallas_core.get_block_size(b2) + match (b1, b2): + case (None, _) | (_, None): + return b1 == b2 + case ( + (pallas_core.Blocked(), int()) + | (int(), pallas_core.Blocked()) + | (pallas_core.Blocked(), pallas_core.Blocked()) + | (int(), int()) + ): + return block_size1 == block_size2 + case _: + return type(b1) == type(b2) and (block_size1 == block_size2) + + +def _block_shapes_equal( + bs1: tuple[int | pallas_core.BlockDim | None] | None, + bs2: tuple[int | pallas_core.BlockDim | None] | None, +) -> bool: + if bs1 is None or bs2 is None: + return bs1 == bs2 + return all(_block_dim_equal(b1, b2) for b1, b2 in zip(bs1, bs2)) + + def _pull_block_spec( jaxpr: core.Jaxpr, out_block_specs: tuple[pallas_core.BlockSpec, ...], - out_usages, *, + read_usage_env: Callable[[core.Var], set[Usage]], scalar_prefetch_handler: Any | None = None, grid: tuple[int | jax.Array, ...], ) -> tuple[ @@ -294,7 +336,6 @@ def _pull_block_spec( tuple[dict[core.Var, pallas_core.BlockSpec], dict[int, Any]], Any, ]: - read_usage_env = compute_usage(jaxpr, out_usages) jaxpr_invar_usages = util.safe_map(read_usage_env, jaxpr.invars) env: dict[core.Var, pallas_core.BlockSpec] = {} scalar_prefetch_fn_env = {} @@ -306,7 +347,7 @@ def _pull_block_spec( def _read_block_spec(atom: core.Atom) -> pallas_core.BlockSpec | Any: if isinstance(atom, core.Literal): return pallas_core.no_block_spec - return env[atom] + return env.get(atom, pallas_core.no_block_spec) def _write_block_spec(atom: core.Atom, block_spec: pallas_core.BlockSpec): if isinstance(atom, core.Literal): @@ -315,9 +356,11 @@ def _write_block_spec(atom: core.Atom, block_spec: pallas_core.BlockSpec): for i, eqn in reversed(list(enumerate(jaxpr.eqns))): eqn_out_block_specs = tuple(util.safe_map(_read_block_spec, eqn.outvars)) + if all(bs is pallas_core.no_block_spec for bs in eqn_out_block_specs): + continue rule = pull_block_spec_rules.get(eqn.primitive, None) if not rule: - raise NotImplementedError(eqn.primitive) + raise NotImplementedError(eqn.primitive, eqn_out_block_specs) ctx = PullRuleContext( avals_in=tuple(v.aval for v in eqn.invars), avals_out=tuple(v.aval for v in eqn.outvars), @@ -348,7 +391,7 @@ def _write_block_spec(atom: core.Atom, block_spec: pallas_core.BlockSpec): jaxpr.invars, needed_invars, jaxpr.eqns[: jaxpr.eqns.index(eqn)], - debug_info=jaxpr.debug_info, + debug_info=jaxpr.debug_info._replace(result_paths=None), ) scalar_prefetch_jaxpr, used_consts, used_invars = pe.dce_jaxpr_consts( scalar_prefetch_jaxpr_no_dce, @@ -358,6 +401,7 @@ def _write_block_spec(atom: core.Atom, block_spec: pallas_core.BlockSpec): scalar_prefetch_jaxpr = scalar_prefetch_jaxpr.replace( constvars=[], invars=jaxpr.constvars, + debug_info=scalar_prefetch_jaxpr.debug_info.with_unknown_names(), ) def _scalar_prefetch_fn(jaxpr): @@ -373,14 +417,14 @@ def _scalar_prefetch_fn(jaxpr): ) ctx.scalar_prefetch_fn = scalar_prefetch_fn_env[i] = scalar_prefetch_fn for v, in_block_spec in zip(eqn.invars, in_block_specs, strict=True): + # TODO(cjfj): Check that index map functions are equivalent (in jaxpr). if ( not isinstance(v, core.Literal) and v in env - and (bs := env[v]) != in_block_spec + and not _block_shapes_equal(env[v].block_shape, + in_block_spec.block_shape) ): - if bs.block_shape != in_block_spec.block_shape: - in_block_spec = in_block_spec.replace(block_shape=_illegal) - in_block_spec = in_block_spec.replace(index_map=_illegal) + in_block_spec = pallas_core.BlockSpec(_illegal, _illegal) # pytype: disable=wrong-arg-types _write_block_spec(v, in_block_spec) def _get_in_block_spec(v, usage): @@ -405,6 +449,7 @@ def _get_in_block_spec(v, usage): def make_kernel_function( jaxpr: core.Jaxpr, + consts, in_tree, out_tree, read_usage_env, @@ -417,15 +462,21 @@ def make_kernel_function( invar_usages = util.safe_map(read_usage_env, jaxpr.invars) bs_env, scalar_prefetch_fn_env = block_spec_env - def _remove_nones(shape: tuple[int | None, ...] | None) -> tuple[int, ...]: - assert shape is not None - return tuple(s for s in shape if s is not None) + def _remove_nones( + shape: tuple[pallas_core.BlockDim | int | None, ...] | None, + ) -> tuple[int, ...]: + new_shape = tuple(_block_size(s) for s in shape) + return tuple(s for s in new_shape if s is not None) _no_aval = object() def _get_block_aval(bs, aval): + if isinstance(aval, state.AbstractRef): + return aval if bs is pallas_core.no_block_spec or bs is None: return _no_aval + if bs.block_shape is None: + return aval return aval.update(shape=_remove_nones(bs.block_shape)) # pytype: disable=attribute-error in_block_avals = [ @@ -435,58 +486,15 @@ def _get_block_aval(bs, aval): unflat_in_block_arg_avals, unflat_in_block_kwarg_avals = ( tree_util.tree_unflatten(in_tree, in_block_avals) ) - unflat_arg_usages, unflat_kwarg_usages = tree_util.tree_unflatten( - in_tree, invar_usages - ) - def sds_like(x): - if x is _no_aval: - return _no_aval - return jax.ShapeDtypeStruct(x.shape, x.dtype) - - kernel_in_type = jax.tree.map( - sds_like, (unflat_in_block_arg_avals, unflat_in_block_kwarg_avals) - ) + kernel_in_type = (unflat_in_block_arg_avals, unflat_in_block_kwarg_avals) def _read_block_spec(atom: core.Atom) -> pallas_core.BlockSpec | Any: if isinstance(atom, core.Literal): return pallas_core.no_block_spec - return bs_env[atom] + return bs_env.get(atom, pallas_core.no_block_spec) def kernel_fn(program_ids, scalar_prefetch, *args, **kwargs): - def _check_args(prefix, path, x, y, usage): - if usage == {Usage.SCALAR_PREFETCH}: - return - if y is _no_aval: - return - x_aval, y_aval = core.get_aval(x), core.get_aval(y) - if x_aval.shape != y_aval.shape: - raise ValueError( - f'Shapes do not match: actual={x_aval.shape} !=' - f' expected={y_aval.shape}. Path:' - f' {prefix}{jax.tree_util.keystr(path)}. Expected type:' - f' {kernel_in_type}. Actual args: {(args, kwargs)}' - ) - if x_aval.dtype != y_aval.dtype: - raise ValueError( - f'DTypes do not match: actual={x_aval.dtype} !=' - f' expected={y_aval.dtype}. Path:' - f' {prefix}{jax.tree_util.keystr(path)}. Expected type:' - f' {kernel_in_type}. Actual args: {(args, kwargs)}' - ) - - jax.tree_util.tree_map_with_path( - functools.partial(_check_args, 'args'), - args, - kernel_in_type[0], - unflat_arg_usages, - ) - jax.tree_util.tree_map_with_path( - functools.partial(_check_args, 'kwargs'), - kwargs, - kernel_in_type[1], - unflat_kwarg_usages, - ) flat_args, in_tree_ = tree_util.tree_flatten((args, kwargs)) if in_tree_ != tree_util.tree_structure(kernel_in_type): raise ValueError(f'Expected {kernel_in_type} PyTree, got {in_tree_}') @@ -502,6 +510,8 @@ def read_env(atom): def write_env(var, val): env[var] = val + for const, constvar in zip(consts, jaxpr.constvars): + env[constvar] = const for invar, arg, usage in zip(jaxpr.invars, flat_args, invar_usages): if Usage.REGULAR in usage: env[invar] = arg @@ -556,9 +566,12 @@ def write_env(var, val): return kernel_fn +@functools.partial(api_boundary, repro_api_name="fuser.get_fusion_values") def get_fusion_values( fusion: Callable, *args, **kwargs -) -> tuple[Callable, tuple[jax.Array, ...], tuple[jax.Array, ...]]: +) -> tuple[ + Callable, tuple[typing.SupportsShape, ...], tuple[typing.SupportsShape, ...] +]: jaxpr, values, in_tree, out_tree = fuser_utils.make_jaxpr( fusion, *args, **kwargs ) @@ -705,19 +718,42 @@ def _eltwise_usage_rule( return [used_out] -def _bcast_block_spec( +def _pull_bcast_block_spec( block_spec: pallas_core.BlockSpec, i: int ) -> pallas_core.BlockSpec: - def new_index_map(i, *args): + def new_index_map(*args): idx = block_spec.index_map(*args) assert len(idx) == len(block_spec.block_shape) idx = util.tuple_update(idx, i, 0) return idx - new_block_shape = util.tuple_update(block_spec.block_shape, i, 1) - return pallas_core.BlockSpec( - new_block_shape, functools.partial(new_index_map, i) + if block_spec.block_shape[i] is None: + return pallas_core.BlockSpec(block_spec.block_shape, new_index_map) + + # TODO(wdvi): This is a hack needed since lowering rules require block shape + # to contain either all pl.Element or none + bcast_dim_block_shape = 1 + if isinstance(block_spec.block_shape[i], pallas_core.Element): + bcast_dim_block_shape = pallas_core.Element(1) + new_block_shape = util.tuple_update( # pytype: disable=wrong-arg-types + block_spec.block_shape, i, bcast_dim_block_shape ) + return pallas_core.BlockSpec(new_block_shape, new_index_map) + + +def _push_bcast_block_spec( + block_spec: pallas_core.BlockSpec, + i: int, + size: int, +) -> pallas_core.BlockSpec: + + bcast_dim_block_shape = size + if isinstance(block_spec.block_shape[i], pallas_core.Element): + bcast_dim_block_shape = pallas_core.Element(size) + new_block_shape = util.tuple_update( # pytype: disable=wrong-arg-types + block_spec.block_shape, i, bcast_dim_block_shape + ) + return pallas_core.BlockSpec(new_block_shape, block_spec.index_map) def _binop_usage_rule(prim, ctx, used_out: set[Usage]): @@ -761,19 +797,29 @@ def _eval_function(_, x, y): zip(left_aval.shape, right_aval.shape, strict=True) ): if l == 1 and r != 1: - l_block_spec = _bcast_block_spec(l_block_spec, i) + l_block_spec = _pull_bcast_block_spec(l_block_spec, i) if r == 1 and l != 1: - r_block_spec = _bcast_block_spec(r_block_spec, i) + r_block_spec = _pull_bcast_block_spec(r_block_spec, i) return [l_block_spec, r_block_spec] +def register_default_eval_rule(prim: core.Primitive): + def default_rule(ctx, *args, **params): + assert all(bs is pallas_core.no_block_spec for bs in ctx.out_block_specs) + return prim.bind(*args, **params) + + register_eval_rule(prim)(default_rule) + + def register_binop_rule(prim: core.Primitive): register_pull_block_spec_rule(prim)(functools.partial(_binop_pull_rule, prim)) register_usage_rule(prim)(functools.partial(_binop_usage_rule, prim)) register_eval_rule(prim)(functools.partial(_binop_eval_rule, prim)) +register_default_eval_rule(state_primitives.get_p) + register_binop_rule(lax.mul_p) register_binop_rule(lax.add_p) register_binop_rule(lax.sub_p) @@ -784,8 +830,12 @@ def register_binop_rule(prim: core.Primitive): register_binop_rule(lax.eq_p) register_binop_rule(lax.gt_p) register_binop_rule(lax.ge_p) +register_binop_rule(lax.or_p) +register_binop_rule(lax.xor_p) register_binop_rule(lax.and_p) +register_binop_rule(lax.shift_right_logical_p) register_binop_rule(ad_util.add_any_p) +register_binop_rule(lax.pow_p) @register_eval_rule(lax.select_n_p) @@ -839,10 +889,74 @@ def new_index_map(*args): def _slice_eval_rule(ctx, x, **params): del params out_block_shape = ctx.out_block_specs[0].block_shape - assert len(x.shape) == sum(1 for bs in out_block_shape if bs is not None) + assert len(x.shape) == sum( + 1 + for bs in out_block_shape + if not (bs is None or isinstance(bs, pallas_core.Squeezed)) + ) return x +def _offset_indexer( + bs: pallas_core.BlockDim | int | None, + indexer, + slice_start, + slice_size, +): + # Short-circuit if the slice start is just at zero. + if isinstance(slice_start, int) and slice_start == 0: + return indexer + match bs: + case None | pallas_core.Squeezed(): + return indexer + slice_start + case pallas_core.Element(block_size): + _maybe_static_check( + slice_start % block_size == 0, + f'slice_start is not a multiple of block_size {block_size}', + ) + _maybe_static_check( + slice_size % block_size == 0, + f'slice_size is not a multiple of block_size {block_size}', + ) + return indexer + slice_start + case int() | pallas_core.Blocked(): + block_size = _block_size(bs) + _maybe_static_check( + slice_start % block_size == 0, + f'slice_start is not a multiple of block_size {block_size}', + ) + _maybe_static_check( + slice_size % block_size == 0, + f'slice_size is not a multiple of block_size {block_size}', + ) + # indexer is a block index so we need to offset it by the block offset. + return indexer + slice_start // block_size + case pallas_core.BoundedSlice(block_size): + assert isinstance(indexer, indexing.Slice) + _maybe_static_check( + indexer.start % block_size == 0, + f'slice_start is not a multiple of block_size {block_size}', + ) + _maybe_static_check( + indexer.size % block_size == 0, + f'slice_size is not a multiple of block_size {block_size}', + ) + return indexing.ds(indexer.start + slice_start, indexer.size) + case _: + raise ValueError(f'Unsupported block size {bs}') + + +def _maybe_static_check(pred: bool, msg: str): + # Tries to emit a static error if possible, otherwise falls back to runtime. + from jax.experimental import checkify + + if isinstance(pred, jax.Array): + checkify.check(pred, msg, debug=True) + else: + if not pred: + raise ValueError(msg) + + @register_pull_block_spec_rule(lax.slice_p) def _slice_rule( ctx: PullRuleContext, @@ -853,30 +967,47 @@ def _slice_rule( strides: tuple[int, ...] | None, ): del ctx - if strides is not None: + if strides is not None and not all(stride == 1 for stride in strides): raise NotImplementedError('strides are not supported yet') slice_sizes = tuple( int(end - start) for start, end in zip(start_indices, limit_indices) ) + # Do some basic checks for bs, slice_start, slice_size in zip( block_spec.block_shape, start_indices, slice_sizes ): - if bs is None: - continue - assert slice_start % bs == 0, (start_indices, block_spec.block_shape) - assert slice_size % bs == 0, (slice_sizes, block_spec.block_shape) - offsets = tuple( - slice_start // bs if bs is not None else slice_start - for slice_start, bs in zip(start_indices, block_spec.block_shape) - ) - - def _offset(x, i): - return x + i if i != 0 else x + match bs: + case None | pallas_core.Squeezed(): + continue + case pallas_core.BoundedSlice() | pallas_core.Element(): + block_size = _block_size(bs) + # Require that block_size no bigger than the slice. + if block_size > slice_size: + raise ValueError( + f'Block size {block_size} is larger than the slice size' + f' {slice_size}' + ) + case _: + block_size = _block_size(bs) + assert slice_start % block_size == 0, ( + start_indices, + block_spec.block_shape, + ) + assert slice_size % block_size == 0, ( + slice_sizes, + block_spec.block_shape, + ) def new_index_map(*args): idx = block_spec.index_map(*args) assert len(idx) == len(block_spec.block_shape) - return tuple(_offset(i, o) for i, o in zip(idx, offsets)) + idx = tuple( + _offset_indexer(bs, i, start, size) + for bs, i, start, size in zip( + block_spec.block_shape, idx, start_indices, slice_sizes, strict=True + ) + ) + return idx return [pallas_core.BlockSpec(block_spec.block_shape, new_index_map)] @@ -893,20 +1024,6 @@ def _dynamic_slice_usage_rule(ctx, used_out: set[Usage], **params): return [set()] * len(ctx.avals_in) -def _offset(x, i, s): - from jax.experimental import checkify - - if s is not None: - pred = i % s == 0 - if isinstance(pred, jax.Array): - checkify.check(i % s == 0, 'Invalid index', debug=True) - else: - if not pred: - raise ValueError('Invalid index') - offset = jax.lax.div(i, s) if s is not None else i - return x + offset - - @register_eval_rule(lax.dynamic_slice_p) def _dynamic_slice_eval_rule(ctx, x, *args, **params): del ctx, params @@ -920,7 +1037,6 @@ def _dynamic_slice_rule( *, slice_sizes: tuple[int, ...], ): - del slice_sizes def new_index_map(*args): slice_starts = ctx.scalar_prefetch_fn() @@ -942,11 +1058,11 @@ def new_index_map(*args): # multiples of the block sizes. The indices of the block that correspond to # the slice are then given by (i // b_l, j // b_m, k // b_n). # We then add these block indices to block indices produced by the index - # map. + # map block_indices = tuple( - _offset(i, o, s) - for i, o, s in zip( - idx, slice_starts, block_spec.block_shape, strict=True + _offset_indexer(s, i, start, size) + for i, s, start, size in zip( + idx, block_spec.block_shape, slice_starts, slice_sizes, strict=True ) ) return block_indices @@ -957,12 +1073,187 @@ def new_index_map(*args): ) +@register_pull_block_spec_rule(state_primitives.swap_p) +def _swap_pull_rule( + ctx: PullRuleContext, + block_spec: pallas_core.BlockSpec, + **kwargs, +): + del ctx, kwargs + # The output and val block spec are the same. + return [block_spec, block_spec] + + +@register_eval_rule(state_primitives.swap_p) +def _swap_eval_rule(ctx: KernelEvalContext, ref, val, *idx, tree): + indexers = tree_util.tree_unflatten(tree, idx) + ref_aval, _ = ctx.avals_in[:2] + indexers_avals = tree_util.tree_unflatten(tree, ctx.avals_in[2:]) + assert hasattr(ref_aval, 'shape') + if len(indexers) > 1: + raise NotImplementedError('swap not supported yet') + if not indexers_avals: + indexer_aval = indexing.NDIndexer.make_trivial_indexer(ref_aval.shape) + else: + indexer_aval = indexers_avals[0] + for idx_aval, size in zip(indexer_aval.indices, ref_aval.shape, strict=True): + if not isinstance(idx_aval, indexing.Slice): + raise NotImplementedError('swap not supported yet') + if not isinstance(idx_aval.start, int): + raise NotImplementedError('swap not supported yet') + if not isinstance(idx_aval.size, int): + raise NotImplementedError('swap not supported yet') + if idx_aval.stride != 1: + raise NotImplementedError('swap not supported yet') + if idx_aval.start != 0: + raise NotImplementedError('swap not supported yet') + if idx_aval.size != size: + raise NotImplementedError('swap not supported yet') + # We have a pure slice so now we can just re-index the ref according to the + # block indices. + block_spec = ctx.out_block_specs[0] + block_idx = ctx.get_out_block_indices()[0] + + def _slice(i, b): + if not isinstance(b, int): + raise NotImplementedError('swap not supported yet') + return i if b is None else indexing.ds(i * b, b) + + indexer = tuple( + _slice(i, b) + for i, b in zip(block_idx, block_spec.block_shape, strict=True) + ) + return ref.swap(val, idx=indexer) + + +@register_pull_block_spec_rule(state_primitives.get_p) +def _get_pull_rule( + ctx: PullRuleContext, block_spec: pallas_core.BlockSpec, *, tree +): + ref_aval = ctx.avals_in[0] + assert hasattr(ref_aval, 'shape') + indexers_avals = tree_util.tree_unflatten(tree, ctx.avals_in[1:]) + if len(indexers_avals) > 1: + raise NotImplementedError('get not supported yet') + if not indexers_avals: + indexer_aval = indexing.NDIndexer.make_trivial_indexer(ref_aval.shape) + else: + indexer_aval = indexers_avals[0] + block_shape_iter = iter(block_spec.block_shape) + block_shape = [] + if not all( + bd is None + or isinstance(bd, (int, pallas_core.Blocked, pallas_core.Squeezed)) + for bd in block_spec.block_shape + ): + raise NotImplementedError('get not supported yet') + for idx_aval, size in zip(indexer_aval.indices, ref_aval.shape, strict=True): + if not isinstance(idx_aval, indexing.Slice): + assert hasattr(idx_aval, 'shape') and not idx_aval.shape + block_shape.append(pallas_core.Squeezed()) + continue + if not isinstance(idx_aval.start, int): + raise NotImplementedError('get not supported yet') + if not isinstance(idx_aval.size, int): + raise NotImplementedError('get not supported yet') + if idx_aval.stride != 1: + raise NotImplementedError('get not supported yet') + if idx_aval.start != 0: + raise NotImplementedError('get not supported yet') + if idx_aval.size != size: + raise NotImplementedError('get not supported yet') + bd = next(block_shape_iter) + block_shape.append(_block_size(bd)) + assert next(block_shape_iter, None) is None + + def new_index_map(*args): + idx = block_spec.index_map(*args) + idx_iter = iter(idx) + indices = tuple( + 0 + if (bd is None or isinstance(bd, pallas_core.Squeezed)) + else next(idx_iter) + for bd in range(len(block_shape)) + ) + assert next(idx_iter, None) is None + return indices + + new_block_spec = pallas_core.BlockSpec(block_shape, new_index_map) + return ([new_block_spec] + + [pallas_core.no_block_spec] * (len(ctx.avals_in) - 1)) + + +@register_eval_rule(state_primitives.get_p) +def _get_eval_rule(ctx: KernelEvalContext, ref, *idx, tree): + indexers = tree_util.tree_unflatten(tree, idx) + ref_aval = ctx.avals_in[0] + indexers_avals = tree_util.tree_unflatten(tree, ctx.avals_in[1:]) + ref_block_spec = ctx.in_block_specs[0] + assert hasattr(ref_aval, 'shape') + if len(indexers) > 1: + raise NotImplementedError('get not supported yet') + if not indexers: + indexer = indexing.NDIndexer.make_trivial_indexer(ref_aval.shape) + indexer_aval = indexer + else: + indexer = indexers[0] + indexer_aval = indexers_avals[0] + block_indexer = [] + + def _slice(i, b): + match b: + case int(): + return indexing.ds(i * b, b) + case pallas_core.Blocked(bs): + return indexing.ds(i * bs, bs) + case pallas_core.Squeezed() | None: + return i + case _: + raise NotImplementedError('get not supported yet') + + if ref_block_spec is pallas_core.no_block_spec: + # Short-circuit if the ref is not blocked. + return state_primitives.get_p.bind(ref, *idx, tree=tree) + block_idx_iter = iter(ctx.get_out_block_indices()[0]) + for idx_aval, size, idx, bd in zip( + indexer_aval.indices, + ref_aval.shape, + indexer.indices, + ref_block_spec.block_shape, + strict=True, + ): + if not isinstance(idx_aval, indexing.Slice): + assert hasattr(idx_aval, 'shape') and not idx_aval.shape, idx_aval + assert bd is None or isinstance(bd, pallas_core.Squeezed) + block_indexer.append(idx) + continue + if not isinstance(idx_aval.start, int): + raise NotImplementedError('get not supported yet') + if not isinstance(idx_aval.size, int): + raise NotImplementedError('get not supported yet') + if idx_aval.stride != 1: + raise NotImplementedError('get not supported yet') + if idx_aval.start != 0: + raise NotImplementedError('get not supported yet') + if idx_aval.size != size: + raise NotImplementedError('get not supported yet') + bidx = next(block_idx_iter) + block_indexer.append(_slice(bidx, bd)) + assert next(block_idx_iter, None) is None + return ref.get(idx=tuple(block_indexer)) + + @register_eval_rule(lax.concatenate_p) def _concatenate_eval_rule(ctx: KernelEvalContext, *args, dimension): # We now handle the case where each of the concatenated array dimensions # divides the block size. block_spec = ctx.out_block_specs[0] block_shape = block_spec.block_shape + is_element_block = [isinstance(bd, pallas_core.Element) for bd in block_shape] + if any(is_element_block): + raise NotImplementedError( + 'Concatenation with Element indexing is not yet supported.' + ) block_dim = block_shape[dimension] if block_dim is None: block_dim = 1 @@ -1006,15 +1297,20 @@ def _concatenate_rule( dimension: int, ): block_shape = block_spec.block_shape + is_element_block = [isinstance(bd, pallas_core.Element) for bd in block_shape] + if any(is_element_block): + raise NotImplementedError( + 'Concatenation with Element indexing is not yet supported.' + ) num_blocks = [] block_dim = block_shape[dimension] - if block_dim is None: + if block_dim is None or isinstance(block_dim, pallas_core.Squeezed): block_dim = 1 if block_dim == sum(aval.shape[dimension] for aval in ctx.avals_in): # pytype: disable=attribute-error # Handle special case if the block contains all of the concatenated # array. new_shapes = [ - util.tuple_update( + util.tuple_update( # pytype: disable=wrong-arg-types block_spec.block_shape, dimension, aval.shape[dimension] # pytype: disable=attribute-error ) for aval in ctx.avals_in @@ -1075,14 +1371,21 @@ def _broadcast_in_dim_usage_rule(ctx, used_out: set[Usage], **params): @register_eval_rule(lax.broadcast_in_dim_p) def _broadcast_in_dim_eval_rule( - eval_ctx: KernelEvalContext, x, broadcast_dimensions, **params + eval_ctx: KernelEvalContext, x, broadcast_dimensions, shape, **params ): - if not eval_ctx.avals_in[0].shape: # pytype: disable=attribute-error - # Scalar -> Array broadcast - block_spec = eval_ctx.out_block_specs[0] - shape = tuple(s for s in block_spec.block_shape if s is not None) - return jax.lax.broadcast_in_dim(x, broadcast_dimensions=(), shape=shape) - return x + del params # Unused. + in_shape = eval_ctx.avals_in[0].shape # pytype: disable=attribute-error + if in_shape == shape: + # Dummy broadcast + return x + shape = tuple(map(_block_size, eval_ctx.out_block_specs[0].block_shape)) + dims = tuple( + d - sum(s is None for s in shape[:d]) + for d in broadcast_dimensions + if shape[d] is not None + ) + shape = tuple(s for s in shape if s is not None) + return jax.lax.broadcast_in_dim(x, broadcast_dimensions=dims, shape=shape) @register_pull_block_spec_rule(lax.broadcast_in_dim_p) @@ -1096,15 +1399,20 @@ def _broadcast_in_dim_pull_rule( ): del shape, sharding - if not ctx.avals_in[0].shape: # pytype: disable=attribute-error + shape = ctx.avals_in[0].shape # pytype: disable=attribute-error + if not shape: return [pallas_core.no_block_spec] def new_index_map(*args): idx = block_spec.index_map(*args) - return tuple(idx[i] for i in broadcast_dimensions) + return tuple( + 0 if (d == 1) else idx[i] + for i, d in zip(broadcast_dimensions, shape, strict=True) + ) new_block_shape = tuple( - block_spec.block_shape[i] for i in broadcast_dimensions + b if ((b := block_spec.block_shape[i]) is None) or (d != 1) else 1 + for i, d in zip(broadcast_dimensions, shape, strict=True) ) return [pallas_core.BlockSpec(new_block_shape, new_index_map)] @@ -1115,10 +1423,17 @@ def _transpose_eval_rule( ): block_spec = eval_ctx.out_block_specs[0] block_shape = block_spec.block_shape - block_shape_no_nones = tuple(bs for bs in block_shape if bs is not None) + block_shape_no_nones = tuple( + bs + for bs in block_shape + if not (bs is None or isinstance(bs, pallas_core.Squeezed)) + ) block_dims_iter = iter(range(len(block_shape_no_nones))) expanded_block_dims = [ - None if bs is None else next(block_dims_iter) for bs in block_shape + None + if (bs is None or isinstance(bs, pallas_core.Squeezed)) + else next(block_dims_iter) + for bs in block_shape ] assert next(block_dims_iter, None) is None permuted_block_dims = [expanded_block_dims[p] for p in permutation] @@ -1171,6 +1486,67 @@ def _convert_element_type_pull_rule( return [block_spec] +@register_eval_rule(lax.bitcast_convert_type_p) +def _bitcast_convert_type_eval_rule(eval_ctx: KernelEvalContext, x, new_dtype): + return jax.lax.bitcast_convert_type(x, new_dtype) + + +@register_pull_block_spec_rule(lax.bitcast_convert_type_p) +def _bitcast_convert_type_pull_rule( + ctx: PullRuleContext, + block_spec: pallas_core.BlockSpec, + *, + new_dtype: jnp.dtype, +): + old_dtype = ctx.avals_in[0].dtype # pytype: disable=attribute-error + if old_dtype.itemsize != new_dtype.itemsize: + raise NotImplementedError( + 'bitcast_convert_type with different bitwidths not supported yet:' + f' {old_dtype=}, {new_dtype=}' + ) + return [block_spec] + + +@register_eval_rule(prng.random_bits_p) +def _random_bits_eval_rule(eval_ctx: KernelEvalContext, key, bit_width, shape): + del shape + block_spec = eval_ctx.out_block_specs[0] + indices = eval_ctx.get_out_block_indices()[0] + block_shape = block_spec.block_shape + # This is the important part here: we fold in block indices into the key so + # each block gets different random numbers. + for idx in indices: + key = jax.random.fold_in(key, idx) + return prng.random_bits(key, bit_width=bit_width, shape=block_shape) + + +@register_pull_block_spec_rule(prng.random_bits_p) +def _random_bits_pull_rule( + ctx: PullRuleContext, + block_spec: pallas_core.BlockSpec, + **_, +): + del ctx, block_spec + key_block_spec = pallas_core.BlockSpec( + block_shape=None, memory_space=pallas_core.MemorySpace.KEY + ) + return [key_block_spec] + + +@register_eval_rule(prng.random_wrap_p) +def _random_wrap_eval_rule(eval_ctx: KernelEvalContext, arr, *, impl): + del eval_ctx + return jax.random.wrap_key_data(arr, impl=impl) + + +@register_pull_block_spec_rule(prng.random_wrap_p) +def _random_wrap_pull_rule( + ctx: PullRuleContext, block_spec: pallas_core.BlockSpec, *, impl +): + del ctx, block_spec, impl + return [pallas_core.BlockSpec(block_shape=None)] + + @register_eval_rule(lax.iota_p) def _iota_eval_rule( eval_ctx: KernelEvalContext, *, dimension, shape, dtype, sharding @@ -1179,10 +1555,16 @@ def _iota_eval_rule( block_spec = eval_ctx.out_block_specs[0] block_idx = eval_ctx.get_out_block_indices()[0] assert len(block_idx) == len(shape) - iota_shape = tuple(s for s in block_spec.block_shape if s is not None) - dim_ = dimension - sum(s is None for s in block_spec.block_shape[:dimension]) + iota_shape = tuple( + _block_size(s) for s in block_spec.block_shape if s is not None + ) + dim_ = dimension - sum( + _block_size(s) is None for s in block_spec.block_shape[:dimension] + ) local_iota = jax.lax.broadcasted_iota(dtype, iota_shape, dim_) - return local_iota + block_idx[dimension] * block_spec.block_shape[dimension] + return local_iota + block_idx[dimension] * _block_size( + block_spec.block_shape[dimension] + ) @register_pull_block_spec_rule(lax.iota_p) @@ -1203,7 +1585,228 @@ def _iota_pull_rule( return [] -@register_usage_rule(pjit.pjit_p) +def _pattern_match_lanes_to_sublanes_reshape( + aval_in: core.ShapedArray, + aval_out: core.ShapedArray, +) -> bool: + # Pattern matches a reshape of the form (..., n * l) -> (..., n, l) + # where l is a multiple of 128. + + *leading_out, last_dim_in = aval_in.shape + *leading_in, second_to_last_dim_out, last_dim = aval_out.shape + if leading_in != leading_out: + return False + if second_to_last_dim_out * last_dim != last_dim_in: + return False + if last_dim % 128 != 0: + return False + return True + + +@register_pull_block_spec_rule(lax.reshape_p) +def _reshape_pull_rule( + ctx: PullRuleContext, + block_spec: pallas_core.BlockSpec, + *, + dimensions: tuple[int, ...] | None, + new_sizes: tuple[int, ...], + sharding: jax.sharding.Sharding, +): + del sharding, new_sizes + if dimensions is not None: + raise NotImplementedError('reshape with None dimensions not supported yet') + aval_in = ctx.avals_in[0] + assert isinstance(aval_in, core.ShapedArray) + aval_out = ctx.avals_out[0] + assert isinstance(aval_out, core.ShapedArray) + + block_shape = block_spec.block_shape + shape_in = aval_in.shape + shape_out = aval_out.shape + assert np.prod(shape_in) == np.prod(shape_out) + + # Handle merged dims; i.e. (..., m, n, ...) -> (..., m * n, ...). + i = 0 + j = 0 + merged_dims = [] + + while i < len(shape_in) and j < len(shape_out): + merged = [] + while not merged or np.prod(merged) < shape_out[j]: + merged.append(shape_in[i]) + i += 1 + + if np.prod(merged) > shape_out[j]: + break # Dimension has been split (or something more complex). + + merged_dims.append(merged) + j += 1 + + if (i == len(shape_in)) and np.prod(shape_out[j:]) == 1: + new_block_shape = [] + new_grids = [] + + for d, bd, merged in zip(shape_out, block_shape, merged_dims): + bs = pallas_core.get_block_size(bd) + + if len(merged) == 1: + new_grids.append((merged[0] // bs,)) + new_block_shape.append(bd if bd is not None else 1) + continue + + if not isinstance(bd, (int, pallas_core.Blocked)): + raise NotImplementedError('reshape merge must use `Blocked` block size') + + num_blocks = pallas_utils.cdiv(d, bs) + new_block_dims = [] + for md in reversed(merged): + if bs % md == 0: + new_block_dims.append(md) + bs //= md + elif md % bs == 0: + new_block_dims.append(bs) + bs = 1 + else: + raise NotImplementedError('unsupported reshape merge') + + new_block_dims.reverse() + new_block_shape.extend(new_block_dims) + new_grid = [ + np.int32(md // pallas_core.get_block_size(bd)) + for md, bd in zip(merged, new_block_dims) + ] + new_grids.append(tuple(new_grid)) + + if np.prod(new_grid) != num_blocks: + raise NotImplementedError('reshape merge must maintain grid size') + + def new_index_map(*args): + # NOTE: The `zip` will drop indices for any trailing `1` dims. + idxs = ( + jnp.unravel_index(idx, new_grid) if len(new_grid) > 1 else (idx,) + for idx, new_grid in zip(block_spec.index_map(*args), new_grids) + ) + return sum(idxs, ()) + + return [pallas_core.BlockSpec(tuple(new_block_shape), new_index_map)] + + # Handle the case where we reshape from (..., n * l) -> (..., n, l) + if _pattern_match_lanes_to_sublanes_reshape(aval_in, aval_out): + if not isinstance(block_shape[-1], (int, pallas_core.Blocked)): + raise NotImplementedError( + f'reshape must use Blocked block size on lanes: {block_shape}' + ) + if not isinstance(block_shape[-2], (int, pallas_core.Blocked)): + raise NotImplementedError( + f'reshape must use Blocked block size on sublanes: {block_shape}' + ) + last_dim = aval_out.shape[-1] + block_sublane_dim, block_lane_dim = ( + _block_size(block_shape[-2]), + _block_size(block_shape[-1]), + ) + total_block_size = block_sublane_dim * block_lane_dim + if total_block_size % 128 != 0: + raise NotImplementedError( + 'reshape with non-128 aligned block size on lanes not supported yet' + ) + if block_lane_dim != last_dim: + raise NotImplementedError( + 'reshape with non-matching block size on lanes not supported yet:' + f' {block_shape}' + ) + new_block_shape = block_shape[:-2] + (total_block_size,) + + def new_index_map(*args): # pylint: disable=function-redefined + *idx, second_to_last, last = block_spec.index_map(*args) + # last should always be 0 + if not isinstance(last, int) and last != 0: + raise NotImplementedError( + 'Must select entire block on last dimension for reshape' + ) + return *idx, second_to_last + + return [pallas_core.BlockSpec(new_block_shape, new_index_map)] + + raise NotImplementedError(f'reshape not supported yet: {aval_in}, {aval_out}') + + +@register_eval_rule(lax.reshape_p) +def _reshape_eval_rule( + eval_ctx: KernelEvalContext, x, *, dimensions, new_sizes, sharding +): + del sharding, dimensions, new_sizes + out_shape_nones = tuple( + _block_size(s) for s in eval_ctx.out_block_specs[0].block_shape + ) + out_shape = tuple(s for s in out_shape_nones if s is not None) + # Because we have restricted the pull block spec rule, we can just apply a + # basic reshape here. + x = x.reshape(out_shape) + return x + + +@register_pull_block_spec_rule(lax.reduce_sum_p) +def _reduce_sum_pull_rule( + ctx: PullRuleContext, + block_spec: pallas_core.BlockSpec, + *, + axes: tuple[int, ...], + out_sharding, +): + aval_in = ctx.avals_in[0] + assert isinstance(aval_in, core.ShapedArray) + new_block_shape = [] + block_shape = iter(block_spec.block_shape) + for i, d in enumerate(aval_in.shape): + if i in axes: + new_block_shape.append(pallas_core.Blocked(d)) + else: + new_block_shape.append(next(block_shape)) + assert next(block_shape, None) is None + + def new_index_map(*args): + idx = block_spec.index_map(*args) + new_idx = [] + idx_iter = iter(idx) + for i in range(len(aval_in.shape)): + if i in axes: + new_idx.append(0) + else: + new_idx.append(next(idx_iter)) + assert next(idx_iter, None) is None + return tuple(new_idx) + + new_block_spec = block_spec.replace( + block_shape=tuple(new_block_shape), index_map=new_index_map + ) + return [new_block_spec] + + +@register_eval_rule(lax.reduce_sum_p) +def _reduce_sum_eval_rule( + ctx: KernelEvalContext, + x, + *, + axes: tuple[int, ...], + out_sharding, +): + aval_in = ctx.avals_in[0] + assert isinstance(aval_in, core.ShapedArray) + block_shape = tuple(ctx.in_block_specs[0].block_shape) + for i in axes: + if _block_size(block_shape[i]) != aval_in.shape[i]: + raise NotImplementedError( + f'reduce_sum on partial blocks not supported: {aval_in=},' + f' {block_shape=}' + ) + return jax.lax.reduce_sum(x, axes=axes) + + +# Higher order primitives + + +@register_usage_rule(pjit.jit_p) def _jit_usage_rule( ctx, used_out: list[set[Usage]], *, jaxpr: core.ClosedJaxpr, **_ ): @@ -1212,23 +1815,27 @@ def _jit_usage_rule( return in_usages -@register_eval_rule(pjit.pjit_p) +@register_eval_rule(pjit.jit_p) def _jit_eval_rule(ctx: KernelEvalContext, *args, jaxpr, **kwargs): jaxpr, consts = jaxpr.jaxpr, jaxpr.consts if consts: raise NotImplementedError('pjit with consts not supported yet') out_tree = tree_util.tree_structure(tuple(jaxpr.outvars)) in_tree = tree_util.tree_structure((tuple(jaxpr.invars), {})) - read_usage_env = compute_usage(jaxpr, ctx.out_usages) + + def read_usage_env(_: core.Var): + return {Usage.REGULAR} + _, env, _ = _pull_block_spec( jaxpr, ctx.out_block_specs, - ctx.out_usages, scalar_prefetch_handler=ctx.scalar_prefetch_handler, + read_usage_env=read_usage_env, grid=ctx.grid, ) kernel_fn = make_kernel_function( jaxpr, + (), in_tree, out_tree, read_usage_env, @@ -1240,18 +1847,22 @@ def _jit_eval_rule(ctx: KernelEvalContext, *args, jaxpr, **kwargs): return kernel_fn(ctx.get_program_ids(), ctx.scalar_prefetch, *args) -@register_pull_block_spec_rule(pjit.pjit_p) +@register_pull_block_spec_rule(pjit.jit_p) def _jit_pull_block_spec_rule( ctx: PullRuleContext, out_block_specs, *, jaxpr, **kwargs ): jaxpr, consts = jaxpr.jaxpr, jaxpr.consts if consts: raise NotImplementedError('pjit with consts not supported yet') + + def read_usage_env(_: core.Var): + return {Usage.REGULAR} + in_block_specs, _, _ = _pull_block_spec( jaxpr, out_block_specs, - ctx.out_usages, scalar_prefetch_handler=ctx.scalar_prefetch_handler, + read_usage_env=read_usage_env, grid=ctx.grid, ) return in_block_specs @@ -1276,16 +1887,20 @@ def _custom_jvp_call_eval_rule( raise NotImplementedError('custom_jvp_call with consts not supported yet') out_tree = tree_util.tree_structure(tuple(jaxpr.outvars)) in_tree = tree_util.tree_structure((tuple(jaxpr.invars), {})) - read_usage_env = compute_usage(jaxpr, ctx.out_usages) + + def read_usage_env(_: core.Var): + return {Usage.REGULAR} + _, env, _ = _pull_block_spec( jaxpr, ctx.out_block_specs, - ctx.out_usages, scalar_prefetch_handler=ctx.scalar_prefetch_handler, grid=ctx.grid, + read_usage_env=read_usage_env, ) kernel_fn = make_kernel_function( jaxpr, + (), in_tree, out_tree, read_usage_env, @@ -1304,16 +1919,99 @@ def _custom_jvp_call_pull_block_spec_rule( jaxpr, consts = call_jaxpr.jaxpr, call_jaxpr.consts if consts: raise NotImplementedError('custom_jvp_call with consts not supported yet') + + def read_usage_env(_: core.Var): + return {Usage.REGULAR} + + in_block_specs, _, _ = _pull_block_spec( + jaxpr, + out_block_specs, + scalar_prefetch_handler=ctx.scalar_prefetch_handler, + grid=ctx.grid, + read_usage_env=read_usage_env, + ) + return in_block_specs + + +@register_usage_rule(custom_derivatives.custom_vjp_call_p) +def _custom_vjp_call_usage_rule( + ctx, used_out: list[set[Usage]], *, call_jaxpr: core.ClosedJaxpr, **_ +): + del ctx + read_usage_env = compute_usage(call_jaxpr.jaxpr, used_out) + in_usages = util.safe_map(read_usage_env, call_jaxpr.jaxpr.invars) + return in_usages + + +@register_eval_rule(custom_derivatives.custom_vjp_call_p) +def _custom_vjp_call_eval_rule( + ctx: KernelEvalContext, *args, call_jaxpr: core.ClosedJaxpr, **kwargs +): + jaxpr, consts = call_jaxpr.jaxpr, call_jaxpr.consts + if consts: + raise NotImplementedError('custom_vjp_call with consts not supported yet') + out_tree = tree_util.tree_structure(tuple(jaxpr.outvars)) + in_tree = tree_util.tree_structure((tuple(jaxpr.invars), {})) + + def read_usage_env(_: core.Var): + return {Usage.REGULAR} + + _, env, _ = _pull_block_spec( + jaxpr, + ctx.out_block_specs, + scalar_prefetch_handler=ctx.scalar_prefetch_handler, + grid=ctx.grid, + read_usage_env=read_usage_env, + ) + kernel_fn = make_kernel_function( + jaxpr, + (), + in_tree, + out_tree, + read_usage_env, + ctx.in_block_specs, + env, + ctx.scalar_prefetch_handler, + ctx.grid, + ) + return kernel_fn(ctx.get_program_ids(), ctx.scalar_prefetch, *args) + + +@register_pull_block_spec_rule(custom_derivatives.custom_vjp_call_p) +def _custom_vjp_call_pull_block_spec_rule( + ctx: PullRuleContext, out_block_specs, *, call_jaxpr, **kwargs +): + jaxpr, consts = call_jaxpr.jaxpr, call_jaxpr.consts + if consts: + raise NotImplementedError('custom_vjp_call with consts not supported yet') + + def read_usage_env(_: core.Var): + return {Usage.REGULAR} + in_block_specs, _, _ = _pull_block_spec( jaxpr, out_block_specs, - ctx.out_usages, scalar_prefetch_handler=ctx.scalar_prefetch_handler, grid=ctx.grid, + read_usage_env=read_usage_env, ) return in_block_specs +@register_pull_block_spec_rule(hijax.call_hi_primitive_p) +def _custom_call_hi_primitive_pull_block_spec_rule( + ctx: PullRuleContext, out_block_specs, *, prim +): + return prim.pull_block_spec_rule(ctx, out_block_specs) + +@register_eval_rule(hijax.call_hi_primitive_p) +def _custom_call_hi_primitive_eval_rule( + ctx: KernelEvalContext, *args, prim +): + return jax.tree.leaves(prim.block_eval_rule(ctx, *args)) + + +@functools.partial(api_boundary, repro_api_name="fuser.push_block_spec") def push_block_spec( f: Callable, *in_spec_args, @@ -1434,6 +2132,23 @@ def _binop_push_rule( left_aval, right_aval = ctx.avals_in assert isinstance(left_aval, core.ShapedArray) assert isinstance(right_aval, core.ShapedArray) + if not right_aval.shape: + return left_block_spec + if not left_aval.shape: + return right_block_spec + lhs_has_block_spec = left_block_spec is not pallas_core.no_block_spec + rhs_has_block_spec = right_block_spec is not pallas_core.no_block_spec + if not (lhs_has_block_spec ^ rhs_has_block_spec): + # We can only do a push if one of the block specs is unspecified + # or they are identical. + if left_block_spec is right_block_spec: + return left_block_spec + raise ValueError('Illegal binary push. One of the block specs must be no_block_spec.') + for l, r in zip(left_aval.shape, right_aval.shape, strict=True): + if l == 1 and r != 1 and lhs_has_block_spec: + raise ValueError('Cannot propagate block spec through LHS broadcast.') + if r == 1 and l != 1 and rhs_has_block_spec: + raise ValueError('Cannot propagate block spec through RHS broadcast.') if left_block_spec is pallas_core.no_block_spec: return right_block_spec if right_block_spec is pallas_core.no_block_spec: @@ -1455,6 +2170,7 @@ def _binop_push_rule( register_binop_push_rule(lax.eq_p) register_binop_push_rule(lax.gt_p) register_binop_push_rule(lax.and_p) +register_binop_push_rule(lax.pow_p) register_binop_push_rule(ad_util.add_any_p) @@ -1477,7 +2193,7 @@ def _transpose_push_rule( ) -> pallas_core.BlockSpec: del ctx block_shape = block_spec.block_shape - new_shape = [block_shape[i] for i in permutation] + new_shape = tuple(block_shape[i] for i in permutation) if set(permutation[-2:]) != {permutation[-1], permutation[-2]}: raise NotImplementedError( 'Cannot permute last two dimensions with leading dimensions.' @@ -1510,9 +2226,14 @@ def _select_n_push_rule( ): del ctx block_specs = [b for b in args if b is not pallas_core.no_block_spec] + assert len(block_specs) > 0 + block_spec = block_specs[0] if len(block_specs) > 1: - raise NotImplementedError('select_n with multiple inputs not supported yet') - return block_specs[0] + if any(b is not block_spec for b in block_specs): + raise NotImplementedError( + 'select_n with multiple differing inputs not supported yet' + ) + return block_spec @register_push_block_spec_rule(custom_derivatives.custom_jvp_call_p) @@ -1523,7 +2244,28 @@ def _custom_jvp_call_push_rule( return _push_block_spec_jaxpr(call_jaxpr.jaxpr, *block_specs) -@register_push_block_spec_rule(pjit.pjit_p) +@register_push_block_spec_rule(custom_derivatives.custom_vjp_call_p) +def _custom_vjp_call_push_rule( + ctx, + *block_specs, + call_jaxpr: core.ClosedJaxpr, + num_consts, + fwd_jaxpr_thunk, + bwd, + out_trees, + symbolic_zeros, +): + del ctx, num_consts, fwd_jaxpr_thunk, bwd, out_trees, symbolic_zeros + return _push_block_spec_jaxpr(call_jaxpr.jaxpr, *block_specs) + +@register_push_block_spec_rule(hijax.call_hi_primitive_p) +def _custom_call_hi_primitive_push_block_spec_rule( + ctx: PullRuleContext, *block_specs, prim +): + return prim.push_block_spec_rule(ctx, block_specs) + + +@register_push_block_spec_rule(pjit.jit_p) def _pjit_push_rule(ctx, *block_specs, jaxpr: core.ClosedJaxpr, **_): assert not jaxpr.consts return _push_block_spec_jaxpr(jaxpr.jaxpr, *block_specs) @@ -1546,5 +2288,174 @@ def register_eltwise_rule(prim: core.Primitive): register_eltwise_rule(lax.cos_p) register_eltwise_rule(lax.sqrt_p) register_eltwise_rule(lax.rsqrt_p) +register_eltwise_rule(lax.square_p) register_eltwise_rule(lax.log_p) register_eltwise_rule(lax.integer_pow_p) +register_eltwise_rule(lax.logistic_p) + + +@register_push_block_spec_rule(lax.reshape_p) +def _reshape_push_rule( + ctx: PullRuleContext, + block_spec: pallas_core.BlockSpec, + *, + dimensions: tuple[int, ...] | None, + new_sizes: tuple[int, ...], + sharding: jax.sharding.Sharding, +): + del sharding, new_sizes + if dimensions is not None: + raise NotImplementedError('reshape with None dimensions not supported yet') + aval_in = ctx.avals_in[0] + assert isinstance(aval_in, core.ShapedArray) + aval_out = ctx.avals_out[0] + assert isinstance(aval_out, core.ShapedArray) + if _pattern_match_lanes_to_sublanes_reshape(aval_in, aval_out): + block_shape = tuple(block_spec.block_shape) + if not isinstance(block_shape[-1], (int, pallas_core.Blocked)): + raise NotImplementedError( + f'reshape must use Blocked block size on lanes: {block_shape}' + ) + last_dim = aval_out.shape[-1] + last_block_dim = _block_size(block_shape[-1]) + if last_block_dim % 128 != 0: + raise NotImplementedError( + 'reshape with non-128 aligned block size on lanes not supported yet' + ) + if last_block_dim % last_dim != 0: + raise NotImplementedError( + 'reshape with non-divisible block size on lanes not supported yet' + ) + num_last_dim_blocks = last_block_dim // last_dim + new_block_shape = block_shape[:1] + (num_last_dim_blocks, last_dim) + + def new_index_map(*args): + *idx, last = block_spec.index_map(*args) + return *idx, last, 0 + + return pallas_core.BlockSpec(new_block_shape, new_index_map) + raise NotImplementedError(f'reshape not supported yet: {aval_in}, {aval_out}') + + +@register_push_block_spec_rule(lax.reduce_sum_p) +def _reduce_sum_push_rule( + ctx: PushRuleContext, + block_spec: pallas_core.BlockSpec, + *, + axes: tuple[int, ...], + out_sharding, +): + del out_sharding + aval_in = ctx.avals_in[0] + assert isinstance(aval_in, core.ShapedArray) + if not all( + aval_in.shape[i] == pallas_core.get_block_size(block_spec.block_shape[i]) + for i in axes + ): + raise NotImplementedError( + f'reduce_sum over partial blocks not supported yet: {aval_in.shape=},' + f' {block_spec.block_shape=}, {axes=}' + ) + new_block_shape = tuple( + bd for i, bd in enumerate(block_spec.block_shape) if i not in axes + ) + + def new_index_map(*args): + idx = block_spec.index_map(*args) + return tuple(idx[i] for i in range(len(idx)) if i not in axes) + + return block_spec.replace( + block_shape=tuple(new_block_shape), index_map=new_index_map + ) + + +@register_push_block_spec_rule(lax.broadcast_in_dim_p) +def _broadcast_in_dim_push_rule( + ctx: PushRuleContext, + block_spec: pallas_core.BlockSpec, + *, + shape: tuple[int, ...], + broadcast_dimensions: tuple[int, ...], + sharding: jax.sharding.Sharding, +): + del sharding + in_aval = ctx.avals_in[0] + assert isinstance(in_aval, core.ShapedArray) + in_shape = in_aval.shape + + dim_map = { + out_dim: in_dim + for in_dim, out_dim in enumerate(broadcast_dimensions) + } + + new_block_shape = [] + for i, s in enumerate(shape): + if i in dim_map: + in_dim = dim_map[i] + if in_shape[in_dim] != s: + assert pallas_core.get_block_size(block_spec.block_shape[in_dim]) == 1 + new_block_shape.append(s) + else: + new_block_shape.append(block_spec.block_shape[in_dim]) + else: + new_block_shape.append(s) + + def new_index_map(*args): + idx = block_spec.index_map(*args) + return tuple( + idx[dim_map[i]] if i in dim_map else 0 for i in range(len(shape)) + ) + + return pallas_core.BlockSpec(tuple(new_block_shape), new_index_map) + + +@register_push_block_spec_rule(lax.concatenate_p) +def _concatenate_push_rule( + ctx: PushRuleContext, + *block_specs: pallas_core.BlockSpec, + dimension: int, +): + avals_in = ctx.avals_in + block_shapes = [ + pallas_core._canonicalize_block_shape(block_spec.block_shape) + for block_spec in block_specs + ] + # We only support concatenation if the entirety of the concat dimension is blocked. + assert all(hasattr(aval_in, 'shape') for aval_in in avals_in) + if not all( + block_shape[dimension] == pallas_core.Blocked(avals_in.shape[dimension]) # pytype: disable=attribute-error + for block_shape, avals_in in zip(block_shapes, avals_in) + ): + raise NotImplementedError( + f'concatenate not supported yet: {block_shapes=}, {avals_in=}' + ) + def _new_index_map(*args): + all_indices = [block_spec.index_map(*args) for block_spec in block_specs] + # This is a very important check. We cannot actually construct a single BlockSpec + # for the output of concatenate if the indices are not identical across all the + # inputs. This is not something we can always enforce statically, but to be conservative + # we apply a very aggressive check. We can consider relaxing this later. + if not all( + (all_indices[0][i] is all_indices[j][i]) + for i in range(len(all_indices[0])) + for j in range(len(all_indices)) + ): + raise ValueError( + 'Cannot statically prove that all input blocks to concatenate are the' + ' same.' + ) + # If all block indices are the same, we are materializing the full concatenation along + # the concat dimension, so we use index 0. + base_indices = list(all_indices[0]) + base_indices[dimension] = 0 + return tuple(base_indices) + + new_block_shape = list(block_specs[0].block_shape) + # Since the entirety of the concat dimension is materialized in the blocks, + # the new block size is the sum of the block sizes of the inputs along that + # dimension. + new_block_shape[dimension] = sum( + pallas_core.get_block_size(block_shape[dimension]) + for block_shape in block_shapes + ) + return pallas_core.BlockSpec(tuple(new_block_shape), _new_index_map) diff --git a/jax/_src/pallas/fuser/custom_fusion_lib.py b/jax/_src/pallas/fuser/custom_fusion_lib.py new file mode 100644 index 000000000000..fb2955daf19f --- /dev/null +++ b/jax/_src/pallas/fuser/custom_fusion_lib.py @@ -0,0 +1,264 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Callable, Sequence +import dataclasses +import functools +from typing import Any, Protocol + +from jax._src import api_util +from jax._src import core +from jax._src import custom_api_util +from jax._src import linear_util as lu +from jax._src.traceback_util import api_boundary +from jax._src import tree_util +from jax._src import util +from jax._src.interpreters import mlir +from jax._src.interpreters import partial_eval as pe +from jax._src.pallas.mosaic import lowering as mosaic_lowering +from jax._src.pallas import core as pallas_core +from jax._src.pallas.fuser import block_spec as block_spec_lib + + +custom_fusion_p = core.Primitive('custom_fusion') +custom_fusion_p.multiple_results = True + +CustomPullBlockSpecRuleFn = Callable[[tuple[pallas_core.BlockSpec, ...]], + Sequence[pallas_core.BlockSpec]] + +CustomPushBlockSpecRuleFn = Callable[[tuple[pallas_core.BlockSpec, ...]], + tuple[pallas_core.BlockSpec, ...]] + +@dataclasses.dataclass(frozen=True) +class CustomEvalContext: + out_block_specs: tuple[pallas_core.BlockSpec, ...] + out_block_indices: tuple[Any, ...] + +class CustomEvalRuleFn(Protocol): + + def __call__( + self, + ctx: CustomEvalContext, + *args: Any, + ) -> Sequence[Any]: + ... + + +@custom_api_util.register_custom_decorator_type +class custom_fusion: + fun: Callable[..., Any] + + eval_rule: CustomEvalRuleFn | None = None + + pull_block_spec_rule: CustomPullBlockSpecRuleFn | None = None + + # Optional if this custom_fusion is only used as an input fusion. + push_block_spec_rule: CustomPushBlockSpecRuleFn | None = None + + # Optional alternative implementation to use instead of `fun` for when this + # custom fusion is run inside a Pallas kernel. + pallas_impl: Callable[..., Any] | None = None + + def __init__(self, fun: Callable[..., Any]): + functools.update_wrapper(self, fun) + self.fun = fun + + def def_pallas_impl(self, pallas_impl): + self.pallas_impl = pallas_impl + return pallas_impl + + def def_pull_block_spec( + self, pull_block_spec_rule: CustomPullBlockSpecRuleFn): + self.pull_block_spec_rule = pull_block_spec_rule + return pull_block_spec_rule + + def def_push_block_spec( + self, push_block_spec_rule: CustomPushBlockSpecRuleFn): + self.push_block_spec_rule = push_block_spec_rule + return push_block_spec_rule + + def def_eval_rule(self, eval_rule: CustomEvalRuleFn): + self.eval_rule = eval_rule + return eval_rule + + @functools.partial(api_boundary, + repro_api_name="jax.pallas.custom_fusion.__call__") + def __call__(self, *args, **kwargs): + debug_fun = api_util.debug_info("custom_fusion fun", self.fun, args, kwargs) + + # TODO(jburnim): Better error messages here. + assert self.eval_rule is not None + assert self.pull_block_spec_rule is not None + + try: + args = api_util.resolve_kwargs(self.fun, args, kwargs) + except TypeError as e: + raise TypeError( + "The input arguments to the custom_fusion-decorated function " + f"{debug_fun.func_name} could not be resolved to positional-only " + f"arguments. Binding failed with the error:\n{e}" + ) from e + + # flatten and get jaxpr + args_flat, in_tree = tree_util.tree_flatten(args) + in_avals = [core.get_aval(x) for x in args_flat] + flat_fun, out_tree = api_util.flatten_fun_nokwargs( + lu.wrap_init(self.fun, debug_info=debug_fun.with_unknown_names()), + in_tree) + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals) + + # if a Pallas implementation was provided, get its jaxpr + if self.pallas_impl is not None: + debug_pallas_impl = api_util.debug_info( + "custom_fusion pallas_impl", self.pallas_impl, args, kwargs) + + flat_pallas_impl, pallas_out_tree = api_util.flatten_fun_nokwargs( + lu.wrap_init(self.pallas_impl, debug_info=debug_pallas_impl), + in_tree) + # TODO(jburnim): Error if out_tree() and kernel_out_tree() are different? + del pallas_out_tree + pallas_jaxpr, _, pallas_consts = ( + pe.trace_to_jaxpr_dynamic(flat_pallas_impl, in_avals)) + else: + pallas_jaxpr = None + pallas_consts = [] + + # debug_info for rules + out_flat = custom_fusion_p.bind( + *consts, + *pallas_consts, + *args_flat, + jaxpr=jaxpr, + num_consts=len(consts), + eval_rule=self.eval_rule, + pull_block_spec_rule=self.pull_block_spec_rule, + push_block_spec_rule=self.push_block_spec_rule, + pallas_jaxpr=pallas_jaxpr, + pallas_num_consts=len(pallas_consts), + in_tree=in_tree, + out_tree=out_tree(), + kernel_out_tree=out_tree()) + + return tree_util.tree_unflatten(out_tree(), out_flat) + + +@custom_fusion_p.def_impl +def _custom_fusion_impl( + *args, + jaxpr: core.Jaxpr, + num_consts: int, + pallas_num_consts: int, + **_): + consts, _, args = util.split_list(args, [num_consts, pallas_num_consts]) # type: ignore[assignment] + return core.eval_jaxpr(jaxpr, consts, *args) + +mlir.register_lowering(custom_fusion_p, mlir.lower_fun( + _custom_fusion_impl, multiple_results=True)) + + +@custom_fusion_p.def_effectful_abstract_eval +def _custom_fusion_effectful_abstract_eval( + *args, + jaxpr: core.Jaxpr, + pallas_jaxpr: core.Jaxpr | None, + **_): + del args + # TODO(jburnim): Error if pallas_jaxpr has different number of outputs, or + # different shapes and types of outputs? + if jaxpr.effects: + raise NotImplementedError( + "custom_fusion-decorated function {jaxpr.debug_info.func_src_info} " + "has effects, which is not yet supported: {jaxpr.effects}") + if pallas_jaxpr is not None and pallas_jaxpr.effects: + raise NotImplementedError( + "custom_fusion-decorated function {jaxpr.debug_info.func_src_info} " + "has a pallas_impl with effects, which is not yet supported: " + f"{pallas_jaxpr.effects}") + return jaxpr.out_avals, jaxpr.effects + + +@block_spec_lib.register_eval_rule(custom_fusion_p) +def _custom_fusion_eval_rule( + ctx: block_spec_lib.KernelEvalContext, + *args, + eval_rule: CustomEvalRuleFn, + num_consts: int, + pallas_num_consts: int, + **_): + args = args[num_consts + pallas_num_consts:] + return eval_rule(CustomEvalContext( + out_block_specs=ctx.out_block_specs, + out_block_indices=ctx.get_out_block_indices(), + ), *args) + + +# TODO(jburnim): Lowering rules for SC and Mosaic GPU. + +@mosaic_lowering.register_lowering_rule(custom_fusion_p) +def _custom_fusion_mosaic_lowering_rule( + ctx: mosaic_lowering.LoweringRuleContext, + *args, + jaxpr: core.Jaxpr, + num_consts: int, + pallas_jaxpr: core.Jaxpr | None, + pallas_num_consts: int, + **_): + consts, pallas_consts, args = util.split_list( + args, [num_consts, pallas_num_consts]) + if pallas_jaxpr is None: + pallas_jaxpr = jaxpr + pallas_consts = consts + lowering_context = ctx.lowering_context.replace(block_shapes=ctx.block_shapes) + return mosaic_lowering.jaxpr_subcomp( + lowering_context, pallas_jaxpr, *pallas_consts, *args) + + +@block_spec_lib.register_pull_block_spec_rule(custom_fusion_p) # type: ignore[arg-type] +def _custom_fusion_pull_block_spec_rule( + ctx : block_spec_lib.PullRuleContext, + out_block_specs : tuple[pallas_core.BlockSpec, ...], + *, + pull_block_spec_rule : CustomPullBlockSpecRuleFn, + **_, +) -> Sequence[pallas_core.BlockSpec]: + del ctx + return pull_block_spec_rule(out_block_specs) + + +@block_spec_lib.register_push_block_spec_rule(custom_fusion_p) # type: ignore[arg-type] +def _custom_fusion_push_block_spec_rule( + ctx : block_spec_lib.PushRuleContext, + *block_specs : pallas_core.BlockSpec, + push_block_spec_rule : CustomPushBlockSpecRuleFn, + **_ +) -> tuple[pallas_core.BlockSpec, ...]: + del ctx + # TODO(jburnim): Better error message if push_block_spec_rule is None. + return push_block_spec_rule(block_specs) + + +@block_spec_lib.register_usage_rule(custom_fusion_p) # type: ignore[arg-type] +def _custom_fusion_usage_rule( + ctx : block_spec_lib.UsageRuleContext, + used_out: Sequence[set[block_spec_lib.Usage]], + *, + jaxpr: core.Jaxpr, + **_ +) -> Sequence[set[block_spec_lib.Usage]]: + del ctx + # TODO(jburnim): Error if jaxpr.jaxpr gives different usage than pallas_jaxpr? + read_usage_env = block_spec_lib.compute_usage(jaxpr, used_out) + return util.safe_map(read_usage_env, jaxpr.invars) diff --git a/jax/_src/pallas/fuser/fusable.py b/jax/_src/pallas/fuser/fusable.py deleted file mode 100644 index b075c6d136c9..000000000000 --- a/jax/_src/pallas/fuser/fusable.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright 2025 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Fusable primitive.""" - -import jax -from jax._src import api_util -from jax._src import core as jax_core -from jax._src import linear_util as lu -from jax._src import tree_util -from jax._src import util -from jax._src.interpreters import mlir -from jax._src.interpreters import partial_eval as pe -from jax._src.pallas.fuser import fusion as fusion_lib - -fusable_p = jax_core.Primitive('fusable') -fusable_p.multiple_results = True - - -def _get_aval(x): - return jax_core.raise_to_shaped(jax_core.get_aval(x)) - - -def _make_trivial_fusion(x: jax.Array) -> fusion_lib.Fusion: - return fusion_lib.Fusion( - func=lambda: x, - in_type=((), {}), - out_type=jax.ShapeDtypeStruct(x.shape, x.dtype), - ) - - -def fusable(f): - def wrapper(*args): - def wrapped(*args): - in_fusions = tree_util.tree_map(_make_trivial_fusion, args) - return f(*in_fusions, None) - - flat_args, in_tree = tree_util.tree_flatten(args) - debug_info = api_util.debug_info('fusable', wrapped, args, {}) - flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs( - lu.wrap_init(wrapped, debug_info=debug_info), in_tree - ) - flat_avals = [_get_aval(x) for x in flat_args] - jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals) - out_tree = out_tree_thunk() - out = fusable_p.bind( - *consts, - *flat_args, - jaxpr=jaxpr, - num_consts=len(consts), - in_tree=in_tree, - out_tree=out_tree, - func=f, - ) - return tree_util.tree_unflatten(out_tree, out) - - return wrapper - - -@fusable_p.def_impl -def _(*consts_and_args, jaxpr, num_consts, **_): - consts, args = util.split_list(consts_and_args, [num_consts]) - return jax_core.eval_jaxpr(jaxpr, consts, *args) - - -mlir.register_lowering(fusable_p, mlir.lower_fun(fusable_p.impl)) - - -@fusable_p.def_abstract_eval -def _(*args, jaxpr, **kwargs): - del args, kwargs - return [v.aval for v in jaxpr.outvars] diff --git a/jax/_src/pallas/fuser/fuser_utils.py b/jax/_src/pallas/fuser/fuser_utils.py index ff44725bb958..d6eaebd0024f 100644 --- a/jax/_src/pallas/fuser/fuser_utils.py +++ b/jax/_src/pallas/fuser/fuser_utils.py @@ -20,14 +20,13 @@ from jax._src.interpreters import partial_eval as pe - def make_jaxpr(f, *args, **kwargs): flat_args, in_tree = tree_util.tree_flatten((args, kwargs)) - flat_avals = [core.get_aval(x) for x in flat_args] + flat_avals = [core.shaped_abstractify(x) for x in flat_args] debug_info = api_util.debug_info('make_jaxpr', f, args, kwargs) flat_fun, out_tree_thunk = api_util.flatten_fun( lu.wrap_init(f, debug_info=debug_info), in_tree ) - jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals) + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals) out_tree = out_tree_thunk() return jaxpr, consts, in_tree, out_tree diff --git a/jax/_src/pallas/fuser/fusible.py b/jax/_src/pallas/fuser/fusible.py new file mode 100644 index 000000000000..cee23edb5230 --- /dev/null +++ b/jax/_src/pallas/fuser/fusible.py @@ -0,0 +1,135 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Fusible primitive.""" +from functools import partial +from typing import Any + +import jax +from jax._src import api_util +from jax._src import core as jax_core +from jax._src.interpreters import batching +from jax._src import linear_util as lu +from jax._src.traceback_util import api_boundary +from jax._src import tree_util +from jax._src import util +from jax._src.interpreters import mlir +from jax._src.interpreters import partial_eval as pe +from jax._src.pallas.fuser import fusion as fusion_lib + +fusible_p = jax_core.Primitive('fusible') +fusible_p.multiple_results = True + +def _fusible_is_high(*_, jaxpr, **params): + del params + return jaxpr.is_high + +fusible_p.is_high = _fusible_is_high # type: ignore + + +def _make_trivial_fusion(x: jax.Array) -> fusion_lib.Fusion: + return fusion_lib.Fusion( + func=lambda: x, + in_type=((), {}), + out_type=jax.typeof(x), + ) + + +@partial(api_boundary, repro_api_name="fuser.fusible") +def fusible(f=None, *, output_fusion_prefix: Any = True): + def decorator(f): + def wrapper(*args): + def wrapped(*args): + in_fusions = tree_util.tree_map(_make_trivial_fusion, args) + return f(*in_fusions, None) + + flat_args, in_tree = tree_util.tree_flatten(args) + debug_info = api_util.debug_info('fusible', wrapped, args, {}) + flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs( + lu.wrap_init(wrapped, debug_info=debug_info), in_tree + ) + flat_avals = [jax_core.get_aval(x) for x in flat_args] + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals) + out_tree = out_tree_thunk() + out = fusible_p.bind( + *consts, + *flat_args, + jaxpr=jaxpr, + num_consts=len(consts), + in_tree=in_tree, + out_tree=out_tree, + func=f, + output_fusion_prefix=output_fusion_prefix, + ) + return tree_util.tree_unflatten(out_tree, out) + + return wrapper + + if f is not None: + return decorator(f) + return decorator + + +@fusible_p.def_impl +def _(*consts_and_args, jaxpr, num_consts, **_): + consts, args = util.split_list(consts_and_args, [num_consts]) + return jax_core.eval_jaxpr(jaxpr, consts, *args) + + +mlir.register_lowering(fusible_p, mlir.lower_fun(fusible_p.impl)) + + +@fusible_p.def_effectful_abstract_eval +def _(*args, jaxpr, **kwargs): + del args, kwargs + return [v.aval for v in jaxpr.outvars], jaxpr.effects + + +def _fusible_trivial_batching_rule(axis_data, args, dims, **kwargs): + if axis_data.size != 1: + raise NotImplementedError('fusible does not support non-trivial batching') + + unbatched_args = tuple( + a if (d is batching.not_mapped or d is None) else a[d] + for a, d in zip(args, dims, strict=True) + ) + out_unbatched = fusible_p.bind(*unbatched_args, **kwargs) + out = tuple(o[None] for o in out_unbatched) + + return out, (0,) * len(out) + +batching.fancy_primitive_batchers[fusible_p] = _fusible_trivial_batching_rule + + +def _fusible_to_lojax(*hi_args, jaxpr, num_consts, **_): + const_in_avals = jaxpr.in_aval_qdds[:num_consts] + num_lo_consts = sum(len(aval.lo_ty()) for aval in const_in_avals) + + lo_args = [ + lo_val + for aval, x in util.safe_zip(jaxpr.in_aval_qdds, hi_args) + for lo_val in (aval.read_loval(x) if aval.has_qdd else aval.lower_val(x)) + ] + + closed_jaxpr = jax_core.ClosedJaxpr(jaxpr, lo_args[:num_lo_consts]) + + lo_jaxpr = pe.lower_jaxpr(closed_jaxpr) + all_outs = fusible_p.bind(*lo_args, jaxpr=lo_jaxpr.jaxpr, num_consts=num_lo_consts) + + out_mut, lo_outs = util.split_list(all_outs, [pe.num_himuts_out(jaxpr)]) + pe.apply_himut(jaxpr, hi_args, out_mut) + return pe.raise_lo_outs(jaxpr.out_avals, lo_outs) + + +fusible_p.to_lojax = _fusible_to_lojax diff --git a/jax/_src/pallas/fuser/fusable_dtype.py b/jax/_src/pallas/fuser/fusible_dtype.py similarity index 65% rename from jax/_src/pallas/fuser/fusable_dtype.py rename to jax/_src/pallas/fuser/fusible_dtype.py index e5bc9ab683ab..7a2628080fe1 100644 --- a/jax/_src/pallas/fuser/fusable_dtype.py +++ b/jax/_src/pallas/fuser/fusible_dtype.py @@ -12,16 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Custom fusable dtypes.""" +"""Custom fusible dtypes.""" import abc import dataclasses import functools -from typing import Any, Sequence, TypeVar +import itertools as it +from typing import Any, TypeVar +from collections.abc import Sequence import jax from jax._src import api_util from jax._src import core +from jax._src import custom_derivatives from jax._src import dtypes from jax._src import linear_util as lu from jax._src import source_info_util @@ -34,7 +37,7 @@ from jax._src.pallas import pallas_call from jax._src.pallas import primitives as pallas_primitives from jax._src.pallas.fuser import block_spec -from jax._src.pallas.fuser.fusable import fusable_p +from jax._src.pallas.fuser.fusible import fusible_p from jax._src.state import discharge as state_discharge from jax._src.state import primitives as state_primitives from jax._src.util import foreach @@ -54,7 +57,7 @@ @pack_dtype_p.def_abstract_eval def pack_dtype_abstract_eval(*xs, dtype): - if dtypes.issubdtype(dtype, FusableElementDType): + if dtypes.issubdtype(dtype, FusibleElementDType): return dtype.abstract_pack(*xs) raise ValueError("Attempted to pack non-fusion dtype: {dtype}") @@ -69,33 +72,31 @@ def pack(*xs, dtype): @unpack_dtype_p.def_abstract_eval def unpack_dtype_abstract_eval(x): - if dtypes.issubdtype(x.dtype, FusableElementDType): + if dtypes.issubdtype(x.dtype, FusibleElementDType): return x.dtype.abstract_unpack(x) - elif isinstance(x.dtype, pallas_core.AbstractMemoryRef): + elif isinstance(x.dtype, state.AbstractRef): raise NotImplementedError() raise ValueError("Attempted to unpack non-fusion dtype: {dtype}") def unpack(x): - return unpack_dtype_p.bind(x) + return tuple(unpack_dtype_p.bind(x)) -class FusableElementDType(dtypes.extended): - """Scalar dtype for fusable dtypes.""" +class FusibleElementDType(dtypes.extended): + """Scalar dtype for fusible dtypes.""" - pass - -class FusableTyRules: +class FusibleTyRules: allow_conversion: bool = False class FusionDType(dtypes.ExtendedDType, metaclass=abc.ABCMeta): - """Base class for fusable extended dtypes.""" + """Base class for fusible extended dtypes.""" _op_registry = {} - _rules = FusableTyRules - type = FusableElementDType + _rules = FusibleTyRules + type = FusibleElementDType @abc.abstractmethod def abstract_unpack(self, x) -> Sequence[Any]: @@ -121,12 +122,20 @@ def name(self): return str(self) @abc.abstractmethod - def pull_block_spec_one_step(self, *args, **kwargs): + def pull_block_spec_one_step(self, aval_out, *args, **kwargs): + raise NotImplementedError() + + @abc.abstractmethod + def unpack_push_block_spec(self, aval_in, *args, **kwargs): + raise NotImplementedError() + + @abc.abstractmethod + def unpack_pull_block_spec(self, aval_in, *args, **kwargs): raise NotImplementedError() def physicalize(f): - """Runs a function that contains fusable extended dtypes.""" + """Runs a function that contains fusible extended dtypes.""" def wrapper(*args, **kwargs): if kwargs: @@ -136,8 +145,8 @@ def wrapper(*args, **kwargs): wrapped_fun, out_tree_thunk = api_util.flatten_fun_nokwargs( lu.wrap_init(f, debug_info=debug_info), treedef ) - avals = [core.ShapedArray(a.shape, a.dtype) for a in flattened_args] - jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals) + avals = [core.get_aval(a) for a in flattened_args] + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals) new_jaxpr = physicalize_closed_jaxpr(core.ClosedJaxpr(jaxpr, consts)) out_flat = core.eval_jaxpr( new_jaxpr.jaxpr, new_jaxpr.consts, *flattened_args @@ -155,18 +164,20 @@ def physicalize_closed_jaxpr(jaxpr: core.ClosedJaxpr) -> core.ClosedJaxpr: flat_avals, treedef = tree_util.tree_flatten(in_avals) debug_info = api_util.debug_info("physicalize_closed_jaxpr", fun, (), {}) wrapped_fun, _ = api_util.flatten_fun_nokwargs( - lu.wrap_init(fun, debug_info=debug_info), treedef + lu.wrap_init(fun, debug_info=debug_info.with_unknown_names()), treedef ) - new_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, flat_avals) + new_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, flat_avals) assert len(new_jaxpr.constvars) == len(consts), "Mismatched consts" return core.ClosedJaxpr(new_jaxpr, consts) def _physical_aval(aval): if isinstance(aval, core.ShapedArray): + if isinstance(aval.dtype, FusionDType): + return aval.dtype.abstract_unpack(aval) return core.ShapedArray(aval.shape, aval.dtype) if isinstance(aval, state.AbstractRef): - if isinstance(aval.dtype, FusionDType): + if _is_fusion_type(aval): unpacked = aval.dtype.abstract_unpack(aval.inner_aval) return tuple(aval.update(inner_aval=u) for u in unpacked) return aval @@ -183,12 +194,11 @@ def _flat_jaxpr_eval(consts, args): const_avals = [_physical_aval(v.aval) for v in jaxpr.constvars] flat_avals, treedef = jax.tree.flatten((const_avals, in_avals)) debug_info = api_util.debug_info( - "physicalize_jaxpr", _flat_jaxpr_eval, (), {} - ) + "physicalize_jaxpr", _flat_jaxpr_eval, (const_avals, in_avals), {}) wrapped_fun, _ = api_util.flatten_fun_nokwargs( lu.wrap_init(_flat_jaxpr_eval, debug_info=debug_info), treedef ) - new_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, flat_avals) + new_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, flat_avals) assert not consts new_jaxpr = pe.convert_invars_to_constvars( new_jaxpr, len(tree_util.tree_leaves(const_avals)) @@ -205,7 +215,7 @@ class Context: def physicalize_interp( jaxpr: core.Jaxpr, consts: Sequence[core.Value], *args: core.Value ): - """Physicalizes a jaxpr by replacing fusable dtypes with physical types.""" + """Physicalizes a jaxpr by replacing fusible dtypes with physical types.""" # TODO: Merge into JAX core. env: dict[core.Var, Any] = {} @@ -236,11 +246,10 @@ def write_env(var: core.Var, val: Any): eqn.ctx.manager, ): # need to check types and then invoke the correct rule. - in_types = [aval.dtype for aval in avals_in] # pytype: disable=attribute-error ctx = Context( avals_in=avals_in, avals_out=[var.aval for var in eqn.outvars] ) - custom_rule = _phys_find_rule(eqn.primitive, in_types) + custom_rule = _phys_find_rule(eqn.primitive, avals_in) if custom_rule: outvals = custom_rule(ctx, *invals, **eqn.params) else: @@ -248,7 +257,7 @@ def write_env(var: core.Var, val: Any): outvals = eqn.primitive.bind(*subfuns, *invals, **bind_params) if eqn.primitive.multiple_results: - assert len(outvals) == len(eqn.outvars) + assert len(outvals) == len(eqn.outvars), eqn foreach(write_env, eqn.outvars, outvals) else: write_env(eqn.outvars[0], outvals) @@ -256,11 +265,21 @@ def write_env(var: core.Var, val: Any): return map(read_env, jaxpr.outvars) -def _phys_find_rule(primitive, types: Sequence[dtypes.DType]): +def _is_fusion_type(aval: core.AbstractValue): + """Returns whether an aval is an array containing fusion types.""" + return ( + isinstance(aval, (core.ShapedArray, state.AbstractRef)) + and hasattr(aval, 'dtype') + and isinstance(aval.dtype, FusionDType) + ) + + +def _phys_find_rule(primitive, avals: Sequence[core.AbstractValue]): """Finds the physicalization rule for a primitive.""" if primitive in _physicalize_rules: return _physicalize_rules[primitive] - fusion_types = {type_ for type_ in types if isinstance(type_, FusionDType)} + + fusion_types = {aval.dtype for aval in avals if _is_fusion_type(aval)} # pytype: disable=attribute-error if len(fusion_types) == 0: return None elif len(fusion_types) > 1: @@ -275,10 +294,8 @@ def _phys_find_rule(primitive, types: Sequence[dtypes.DType]): def _assert_no_fusion_types(avals: Sequence[core.AbstractValue]): - for aval in avals: - if isinstance(aval, (core.ShapedArray, state.AbstractRef)): - if isinstance(aval.dtype, FusionDType): - raise NotImplementedError(f"Fusion type found in avals: {avals}") + if any(_is_fusion_type(aval) for aval in avals): + raise NotImplementedError(f"Fusion type found in avals: {avals}") def _pallas_call_physicalize_rule( @@ -288,10 +305,13 @@ def _pallas_call_physicalize_rule( _assert_no_fusion_types(ctx.avals_out) with grid_mapping.trace_env(): new_jaxpr = physicalize_closed_jaxpr(core.ClosedJaxpr(jaxpr, ())) - num_new_vals = len(new_jaxpr.jaxpr.invars) - len(jaxpr.invars) - grid_mapping = grid_mapping.replace( - num_scratch_operands=grid_mapping.num_scratch_operands + num_new_vals - ) + if diff := len(new_jaxpr.jaxpr.invars) - len(jaxpr.invars): + num_scratch_avals = len(grid_mapping.scratch_avals) + diff + new_scratch_avals = tuple(v.aval for v in + new_jaxpr.jaxpr.invars[-num_scratch_avals:]) + grid_mapping = grid_mapping.replace( + scratch_avals=new_scratch_avals + ) return pallas_call.pallas_call_p.bind( *args, jaxpr=new_jaxpr.jaxpr, grid_mapping=grid_mapping, **kwargs ) @@ -302,9 +322,9 @@ def _pallas_call_physicalize_rule( def _cond_physicalize_rule(ctx: Context, *args, branches, **kwargs): _assert_no_fusion_types(ctx.avals_out) - physicalized_branches = [ + physicalized_branches = tuple( physicalize_closed_jaxpr(branch) for branch in branches - ] + ) flat_args = jax.tree.leaves(args) return conditionals.cond_p.bind( *flat_args, branches=physicalized_branches, **kwargs @@ -314,6 +334,41 @@ def _cond_physicalize_rule(ctx: Context, *args, branches, **kwargs): _physicalize_rules[conditionals.cond_p] = _cond_physicalize_rule +@lu.transformation2 +def _physicalize_transform(f, *args): + vals, zeros = args[::2], args[1::2] + assert len(vals) == len(zeros) + wrapper = lambda *inner_vals: f( + *it.chain.from_iterable(zip(inner_vals, zeros)) + ) + return physicalize(wrapper)(*vals) + + +@lu.transformation2 +def _physicalize_transform_bwd(f, const_avals, *args): + return [custom_derivatives.Zero(a) for a in const_avals] + list( + physicalize(f)(*args) + ) + + +def _custom_vjp_call_physicalize_rule( + ctx: Context, *args, call_jaxpr, num_consts, fwd_jaxpr_thunk, bwd, **kwargs +): + _assert_no_fusion_types(ctx.avals_out) + new_jaxpr = physicalize_closed_jaxpr(call_jaxpr) + fun = lu.wrap_init(core.jaxpr_as_fun(new_jaxpr), + debug_info=call_jaxpr.jaxpr.debug_info) + fwd = custom_derivatives.lift_fwd(num_consts, fwd_jaxpr_thunk) + fwd_physicalized = _physicalize_transform(fwd) + const_avals, _ = util.split_list(new_jaxpr.in_avals, [num_consts]) + bwd_physicalized = _physicalize_transform_bwd(bwd, const_avals) + return custom_derivatives.custom_vjp_call_p.bind( + fun, fwd_physicalized, bwd_physicalized, *args, **kwargs + ) + +_physicalize_rules[custom_derivatives.custom_vjp_call_p] = _custom_vjp_call_physicalize_rule + + def _run_state_rule(ctx: Context, *args, jaxpr, which_linear, is_initialized): _assert_no_fusion_types(ctx.avals_in) _assert_no_fusion_types(ctx.avals_out) @@ -365,11 +420,20 @@ def _scan_rule(ctx: Context, *args, jaxpr, **params): def _while_rule( - ctx: Context, *args, body_jaxpr, cond_jaxpr, body_nconsts, **params + ctx: Context, *args, body_jaxpr, cond_jaxpr, body_nconsts, + cond_nconsts, **params ): _assert_no_fusion_types(ctx.avals_out) cond_avals = [v.aval for v in cond_jaxpr.jaxpr.invars] - _assert_no_fusion_types(cond_avals) + _, cond_in_avals = util.split_list(cond_avals, [cond_nconsts]) + _assert_no_fusion_types(cond_in_avals) + new_cond_jaxpr = physicalize_closed_jaxpr(cond_jaxpr) + new_num_cond_consts = ( + cond_nconsts + + len(new_cond_jaxpr.jaxpr.invars) + - len(cond_jaxpr.jaxpr.invars) + ) + body_avals = [v.aval for v in body_jaxpr.jaxpr.invars] _, body_in_avals = util.split_list(body_avals, [body_nconsts]) _assert_no_fusion_types(body_in_avals) @@ -380,15 +444,25 @@ def _while_rule( - len(body_jaxpr.jaxpr.invars) ) flat_args = tree_util.tree_leaves(args) - assert len(flat_args) == len(new_body_jaxpr.jaxpr.invars), ( - f"Length mismatch: {len(flat_args)=} !=" + cond_consts, body_consts, flat_args = util.split_list( + flat_args, [new_num_cond_consts, new_num_body_consts] + ) + assert len(flat_args) + len(body_consts) == len( + new_body_jaxpr.jaxpr.invars), ( + f"Length mismatch: {len(flat_args) + len(body_consts)} !=" f" {len(new_body_jaxpr.jaxpr.invars)=}" ) + assert len(flat_args) + len(cond_consts) == len( + new_cond_jaxpr.jaxpr.invars), ( + f"Length mismatch: {len(flat_args) + len(cond_consts)} !=" + f" {len(new_cond_jaxpr.jaxpr.invars)=}" + ) return jax.lax.while_p.bind( - *flat_args, + *(cond_consts + body_consts + flat_args), body_jaxpr=new_body_jaxpr, - cond_jaxpr=cond_jaxpr, + cond_jaxpr=new_cond_jaxpr, body_nconsts=new_num_body_consts, + cond_nconsts=new_num_cond_consts, **params, ) @@ -413,7 +487,7 @@ def _unpack_rule(_, arg): def _swap_rule(ctx: Context, ref, val, *args, tree): ref_aval, *_ = ctx.avals_in - if not isinstance(ref_aval.dtype, FusionDType): + if not _is_fusion_type(ref_aval): return state_primitives.swap_p.bind(ref, val, *args, tree=tree) return ref_aval.dtype.swap(ref, val, *args, tree=tree) @@ -423,7 +497,7 @@ def _swap_rule(ctx: Context, ref, val, *args, tree): def _get_rule(ctx: Context, ref, *args, tree): ref_aval, *_ = ctx.avals_in - if not isinstance(ref_aval.dtype, FusionDType): + if not _is_fusion_type(ref_aval): return state_primitives.get_p.bind(ref, *args, tree=tree) return ref_aval.dtype.get(ref, *args, tree=tree) @@ -433,8 +507,7 @@ def _get_rule(ctx: Context, ref, *args, tree): @block_spec.register_eval_rule(pack_dtype_p) def _pack_dtype_eval_rule(eval_ctx: block_spec.KernelEvalContext, *args, dtype): - del eval_ctx - return pack_dtype_p.bind(*args, dtype=dtype) + return dtype.pack_eval_rule(eval_ctx, *args) @block_spec.register_pull_block_spec_rule(pack_dtype_p) @@ -444,16 +517,46 @@ def _pack_dtype_pull_rule( *, dtype: FusionDType, ): - del ctx - return dtype.pull_block_spec_one_step(block_spec) # pytype: disable=attribute-error + aval_out = ctx.avals_out[0] + return dtype.pull_block_spec_one_step(aval_out, block_spec) # pytype: disable=attribute-error + + +@block_spec.register_push_block_spec_rule(unpack_dtype_p) +def _unpack_dtype_push_rule( + ctx: block_spec.PushRuleContext, + block_spec: pallas_core.BlockSpec, +): + aval_in = ctx.avals_in[0] + assert isinstance(aval_in, core.ShapedArray) + assert isinstance(aval_in.dtype, FusionDType), aval_in.dtype + return aval_in.dtype.unpack_push_block_spec(aval_in, block_spec) # pytype: disable=attribute-error + + +@block_spec.register_pull_block_spec_rule(unpack_dtype_p) +def _unpack_dtype_pull_rule( + ctx: block_spec.PushRuleContext, + block_specs: pallas_core.BlockSpec, +): + aval_in = ctx.avals_in[0] + assert isinstance(aval_in, core.ShapedArray) + assert isinstance(aval_in.dtype, FusionDType), aval_in.dtype + return aval_in.dtype.unpack_pull_block_spec(aval_in, *block_specs) + + +@block_spec.register_eval_rule(unpack_dtype_p) +def _unpack_dtype_eval_rule(eval_ctx: block_spec.KernelEvalContext, *args): + aval_in = eval_ctx.avals_in[0] + assert isinstance(aval_in, core.ShapedArray) + assert isinstance(aval_in.dtype, FusionDType), aval_in.dtype + return aval_in.dtype.unpack_eval_rule(eval_ctx, *args) -def _fusable_physicalize_rule( +def _fusible_physicalize_rule( _, *consts_and_args, jaxpr, num_consts, in_tree, out_tree, func ): consts, _ = util.split_list(consts_and_args, [num_consts]) new_jaxpr = physicalize_closed_jaxpr(core.ClosedJaxpr(jaxpr, consts)) - return fusable_p.bind( + return fusible_p.bind( *consts_and_args, jaxpr=new_jaxpr.jaxpr, num_consts=num_consts, @@ -463,4 +566,4 @@ def _fusable_physicalize_rule( ) -_physicalize_rules[fusable_p] = _fusable_physicalize_rule +_physicalize_rules[fusible_p] = _fusible_physicalize_rule diff --git a/jax/_src/pallas/fuser/fusion.py b/jax/_src/pallas/fuser/fusion.py index eff8c36ddb08..6319722a9823 100644 --- a/jax/_src/pallas/fuser/fusion.py +++ b/jax/_src/pallas/fuser/fusion.py @@ -17,7 +17,8 @@ from __future__ import annotations import dataclasses -from typing import Any, Callable, Generic, ParamSpec, TypeVar +from typing import Any, Generic, ParamSpec, TypeVar +from collections.abc import Callable import jax from jax._src import util diff --git a/jax/_src/pallas/fuser/jaxpr_fusion.py b/jax/_src/pallas/fuser/jaxpr_fusion.py index 3d36b8f3e2fd..837df0412abc 100644 --- a/jax/_src/pallas/fuser/jaxpr_fusion.py +++ b/jax/_src/pallas/fuser/jaxpr_fusion.py @@ -14,35 +14,34 @@ """Fuses a function.""" +from collections.abc import Sequence +import functools from typing import Any - import jax from jax._src import api_util from jax._src import core as jax_core from jax._src import linear_util as lu +from jax._src.traceback_util import api_boundary from jax._src import tree_util from jax._src.interpreters import partial_eval as pe - -from jax._src.pallas.fuser import fusable_dtype +from jax._src.pallas.fuser import fusible_dtype from jax._src.pallas.fuser import fusion as fusion_lib -from jax._src.pallas.fuser.fusable import fusable_p - - -def _get_aval(x): - return jax_core.raise_to_shaped(jax_core.get_aval(x)) +from jax._src.pallas.fuser.fusible import fusible_p -def fuse(f=None, *, physicalize: bool = False, debug: bool = False): - """Fuses a function into a single fusable. +@functools.partial(api_boundary, repro_api_name="fuser.fuse") +def fuse(f=None, *, resolve_fusion_dtypes: bool = True, debug: bool = False): + """Fuses a function into a single fusible. Args: f: The function to fuse. - physicalize: (experimental) whether to physicalize the function. + resolve_fusion_dtypes: (experimental) whether or not to resolve fusion + dtypes (which don't correspond to physical dtypes) debug: Whether to print debug information. - There should be a single call to a `fusable` inside the body of `f`. `fuse` + There should be a single call to a `fusible` inside the body of `f`. `fuse` returns a transformed function that will fuse the surrounding computation into - the fusable and invoke it. + the fusible and invoke it. """ def decorator(f): @@ -52,8 +51,8 @@ def wrapper(*args, **kwargs): flat_fun, out_tree_thunk = api_util.flatten_fun( lu.wrap_init(f, debug_info=debug_info), in_tree ) - flat_avals = [_get_aval(x) for x in flat_args] - jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals) + flat_avals = [jax_core.get_aval(x) for x in flat_args] + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals) if debug: print("Jaxpr before fusion:") print(jaxpr) @@ -61,8 +60,8 @@ def wrapper(*args, **kwargs): out_flat = fuse_jaxpr(jaxpr, out_tree, consts, *flat_args) return tree_util.tree_unflatten(out_tree, out_flat) - if physicalize: - wrapper = fusable_dtype.physicalize(wrapper) + if resolve_fusion_dtypes: + wrapper = fusible_dtype.physicalize(wrapper) return wrapper if f is not None: @@ -70,18 +69,19 @@ def wrapper(*args, **kwargs): return decorator -_fusable: dict[jax_core.Primitive, Any] = {} +_fusible: dict[jax_core.Primitive, Any] = {} -def construct_fusion( +def _construct_fusion_jaxpr( candidate_values, jaxpr: jax_core.Jaxpr, outvars, *invars, **kwargs -) -> fusion_lib.Fusion: +): flat_outvars, out_tree = tree_util.tree_flatten(outvars) flat_invars, in_tree = tree_util.tree_flatten((invars, kwargs)) new_jaxpr_no_dce = jaxpr.replace( outvars=flat_outvars, constvars=jaxpr.constvars + jaxpr.invars, invars=flat_invars, + debug_info=jaxpr.debug_info.with_unknown_names() ) new_jaxpr, used_consts, used_invars = pe.dce_jaxpr_consts( new_jaxpr_no_dce, @@ -94,23 +94,158 @@ def construct_fusion( c for used, c in zip(used_consts, candidate_values, strict=True) if used ) kernel_in_tree = tree_util.tree_structure((invars, kwargs)) + flat_in_type = [x.aval for x in flat_invars] + in_type = tree_util.tree_unflatten(kernel_in_tree, flat_in_type) + out_type = tree_util.tree_unflatten( + out_tree, + [x.aval for x in flat_outvars], + ) + return new_jaxpr, new_values, in_type, out_type, out_tree + + +def construct_fusion( + candidate_values, jaxpr: jax_core.Jaxpr, outvars, *invars, **kwargs +) -> fusion_lib.Fusion: + new_jaxpr, new_values, in_type, out_type, out_tree = _construct_fusion_jaxpr( + candidate_values, jaxpr, outvars, *invars, **kwargs + ) def _fn(*args, **kwargs): flat_args, _ = tree_util.tree_flatten((args, kwargs)) out_flat = jax_core.eval_jaxpr(new_jaxpr, new_values, *flat_args) return tree_util.tree_unflatten(out_tree, out_flat) - flat_in_type = [ - jax.ShapeDtypeStruct(x.aval.shape, x.aval.dtype) for x in flat_invars - ] - in_type = tree_util.tree_unflatten(kernel_in_tree, flat_in_type) - out_type = tree_util.tree_unflatten( - out_tree, - [jax.ShapeDtypeStruct(x.aval.shape, x.aval.dtype) for x in flat_outvars], - ) return fusion_lib.Fusion(_fn, in_type, out_type) +def _find_downstream( + jaxpr: jax_core.Jaxpr, in_used: Sequence[bool] +) -> tuple[bool, ...]: + # TODO(sharadmv): We use partial_eval to query downstream dependencies which + # is not an officially sanctioned way to do so, since PE is really used for + # AD. In the future, we should have a special Jaxpr API that queries this. + _, _, out_used, *_ = pe.partial_eval_jaxpr_custom( + jaxpr, + in_unknowns=in_used, + in_inst=in_used, + ensure_out_unknowns=False, + ensure_out_inst=False, + saveable=lambda *_, **__: False, + ) + return tuple(out_used) + + +def _construct_output_permutation( + used: list[tuple[bool, ...]], +) -> list[int]: + order = [] + for u in used: + true_vals = [i for i in range(len(u)) if u[i]] + order.extend(true_vals) + return [order.index(i) for i in range(len(order))] + + +def _construct_output_fusions( + candidate_values, + jaxpr, + out_tree, + fusion_eqn_index, + fusion_eqn_outvars, # Flat list of vars output by the fusible eqn + fusion_eqn_out_tree, # Tree structure of the fusible eqn outputs + output_fusion_prefix, # Pytree defining output groups +): + # 1. Create jaxpr_out: represents computation *after* the fusible + # Inputs: fusion_eqn_outvars + # Outputs: jaxpr.outvars + jaxpr_out, all_values, _, _, _ = _construct_fusion_jaxpr( + candidate_values, + jaxpr.replace( + eqns=jaxpr.eqns[:fusion_eqn_index] + + jaxpr.eqns[fusion_eqn_index + 1 :] + ), + tree_util.tree_unflatten(out_tree, jaxpr.outvars), # Original outputs + tree_util.tree_unflatten( + fusion_eqn_out_tree, fusion_eqn_outvars + ), # Fusible outputs as inputs + ) + + # 2. Group fusible outputs based on the mask + unflat_fusible_outvars = jax.tree.unflatten( + fusion_eqn_out_tree, fusion_eqn_outvars + ) + partial_flat = jax.tree.structure(output_fusion_prefix).flatten_up_to( + unflat_fusible_outvars + ) + + # 3. Calculate dependencies and check disjointedness + downstream_outputs_used_masks = [] # List of bool tuples, one per group + already_used_final_outputs = set() # Indices of final outputs already claimed + for outvars_group in partial_flat: + # Identify vars in this group + used_fusible_outvars = set(jax.tree.leaves(outvars_group)) + # Create mask for jaxpr_out inputs corresponding to this group + in_used_mask = [ + True if v in used_fusible_outvars else False for v in jaxpr_out.invars + ] + # Trace dependencies through jaxpr_out to find which final outputs are affected + downstream_used_mask = _find_downstream( + jaxpr_out, in_used_mask + ) # Mask for jaxpr_out.outvars (== jaxpr.outvars) + + # Check for overlap in final output usage across groups + for i, used in enumerate(downstream_used_mask): + if used: + if i in already_used_final_outputs: + raise ValueError( + "Outputs must be disjoint in order to use separate output fusions" + ) + already_used_final_outputs.add(i) + downstream_outputs_used_masks.append(downstream_used_mask) + + # 4. Construct output permutation needed to restore original output order + output_permutation = _construct_output_permutation( + downstream_outputs_used_masks + ) + + # Construct fusions for each group by DCEing the jaxpr_out + output_fusions = [] + for i, outvars_group in enumerate(partial_flat): + flat_group_vars, _ = tree_util.tree_flatten(outvars_group) + downstream_used_mask = downstream_outputs_used_masks[i] + + used_jaxpr_invars = [False] * len(all_values) + [ + v in flat_group_vars for v in jaxpr_out.invars + ] + jaxpr_out_for_group, used_consts, _ = pe.dce_jaxpr_consts( + jaxpr_out, downstream_used_mask, instantiate=used_jaxpr_invars + ) + values_for_jaxpr = tuple( + c for used, c in zip(used_consts, all_values, strict=True) if used + ) + + def _fn(jaxpr, vals, *args, **kwargs): + flat_args, _ = tree_util.tree_flatten((args, kwargs)) + out_flat = jax_core.eval_jaxpr(jaxpr, vals, *flat_args) + return tuple(out_flat) + + fn = functools.partial(_fn, jaxpr_out_for_group, values_for_jaxpr) + in_type = jax.tree.map(lambda x: x.aval, outvars_group) + out_type = tuple(v.aval for v in jaxpr_out_for_group.outvars) + fusion = fusion_lib.Fusion( + fn, + (in_type, {}), + out_type, + ) + output_fusions.append(fusion) + + return ( + tree_util.tree_unflatten( + tree_util.tree_structure(output_fusion_prefix), output_fusions + ), + output_permutation, + ) + + def fuse_jaxpr( jaxpr: jax_core.Jaxpr, out_tree: tree_util.PyTreeDef, consts, *args ): @@ -118,16 +253,33 @@ def fuse_jaxpr( # Collect input fusions for i, eqn in enumerate(jaxpr.eqns): - if eqn.primitive is fusable_p: + if eqn.primitive is fusible_p: fusion_eqn_index = i break if fusion_eqn_index is None: - raise ValueError("No fusable eqn found") + raise ValueError("No fusible eqn found") fusion_eqn = jaxpr.eqns[fusion_eqn_index] + # Now let's check if we need to do any fusion at all, e.g. do the outputs of + # the jaxpr have any dependence on the fusion at all? candidate_values = [*consts, *args] + independent_jaxpr, _, out_used, *_ = pe.partial_eval_jaxpr_custom( + jaxpr.replace( + eqns=(jaxpr.eqns[:fusion_eqn_index] + + jaxpr.eqns[fusion_eqn_index + 1 :]), + constvars=jaxpr.constvars + jaxpr.invars, + invars=fusion_eqn.outvars, + debug_info=jaxpr.debug_info.with_unknown_names()), + in_unknowns=[True] * len(fusion_eqn.outvars), + in_inst=[True] * len(fusion_eqn.outvars), + ensure_out_unknowns=False, + ensure_out_inst=False, + saveable=lambda *_, **__: False) + if not any(out_used): + # Short circuit if there is no need to run the fusible at all. + return jax_core.eval_jaxpr(independent_jaxpr, candidate_values) - # Construct fusions for non-constant inputs to the fusable. + # Construct fusions for non-constant inputs to the fusible. in_fusions_flat = [ construct_fusion( candidate_values, @@ -141,21 +293,20 @@ def fuse_jaxpr( in_fusions = tree_util.tree_unflatten( fusion_eqn.params["in_tree"], in_fusions_flat ) - out_fusion = construct_fusion( + output_fusions, output_permutation = _construct_output_fusions( candidate_values, - jaxpr.replace( - eqns=jaxpr.eqns[:fusion_eqn_index] - + jaxpr.eqns[fusion_eqn_index + 1 :] - ), - tree_util.tree_unflatten(out_tree, jaxpr.outvars), - tree_util.tree_unflatten( - fusion_eqn.params["out_tree"], fusion_eqn.outvars - ), + jaxpr, + out_tree, + fusion_eqn_index, + fusion_eqn.outvars, + fusion_eqn.params["out_tree"], + fusion_eqn.params["output_fusion_prefix"], ) - # Run the fusable. - out = fusion_eqn.params["func"](*in_fusions, out_fusion) - - # Now return the flattened output (the fuse_jaxpr caller should unflatten). - out_flat = tree_util.tree_leaves(out) - assert len(out_flat) == len(jaxpr.outvars) - return out_flat + out = fusion_eqn.params["func"](*in_fusions, output_fusions) + flat_out = jax.tree.leaves(out) + permuted_out = [flat_out[i] for i in output_permutation] + assert len(permuted_out) == len(jaxpr.outvars), ( + len(permuted_out), + len(jaxpr.outvars), + ) + return permuted_out diff --git a/jax/_src/pallas/helpers.py b/jax/_src/pallas/helpers.py index 1b2649d4e987..026abbbe5731 100644 --- a/jax/_src/pallas/helpers.py +++ b/jax/_src/pallas/helpers.py @@ -13,49 +13,248 @@ # limitations under the License. """Pallas helper functions.""" -from typing import Any, Protocol +from collections.abc import Callable, Mapping, Sequence +import functools -import jax -import jax.numpy as jnp -from jax._src.pallas import pallas_call +from jax._src import api +from jax._src import checkify +from jax._src import config +from jax._src import core as jax_core +from jax._src import tree_util +from jax._src import typing as jax_typing +from jax._src import util +import jax._src.lax as lax +from jax._src.lax.control_flow import conditionals from jax._src.pallas import core as pl_core +from jax._src.pallas import primitives as pl_primitives +from jax._src.pallas import utils as pl_utils +from jax._src import numpy as jnp -@jax.named_call -def empty( - shape: tuple[int, ...], dtype: jnp.dtype, *, memory_space: Any = None -): - def _empty_kernel(_): - # No-op to leave the out_ref uninitialized - pass +empty = api.named_call(lax.empty) - if memory_space is None: - kernel_memory_space = pl_core.MemorySpace.ANY - memory_space = jax.ShapeDtypeStruct - else: - kernel_memory_space = memory_space - return pallas_call.pallas_call( - _empty_kernel, - in_specs=[], - out_specs=pl_core.BlockSpec(memory_space=kernel_memory_space), - out_shape=memory_space(shape, dtype), - )() +@api.named_call +def empty_like(x: object): + """Create an empty PyTree of possibly uninitialized values. + + Args: + x: A PyTree with leaves specifying the shape and dtype of + the uninitialized object. + + Returns: + A PyTree with the same structure as ``x``, but with uninitialized + values. + + See Also: + :func:`jax.lax.empty` + """ + return tree_util.tree_map(lambda leaf: empty(leaf.shape, leaf.dtype), x) -class ArrayLike(Protocol): - shape: tuple[int, ...] - dtype: jnp.dtype +def empty_ref_like(x: object) -> jax_typing.Array: + """Returns an empty array Ref with same shape/dtype/memory space as x.""" + match x: + case pl_core.MemoryRef(): + memory_space = x.memory_space + case jax_core.ShapeDtypeStruct(): + memory_space = pl_core.MemorySpace.ANY + case _: + raise ValueError(f'empty_ref_like does not support {type(x)}') + return jax_core.new_ref(empty_like(x), memory_space=memory_space) -def empty_like(x: ArrayLike, *, memory_space: Any = None): - return empty(x.shape, x.dtype, memory_space=memory_space) +def when( + condition: bool | jax_typing.ArrayLike, / +) -> Callable[[Callable[[], None]], Callable[[], None]]: + """Calls the decorated function when the condition is met. -def when(condition): + Args: + condition: If a boolean, this is equivalent to ``if condition: f()``. If an + array, ``when`` produces a :func:`jax.lax.cond` with the decorated + function as the true branch. + + Returns: + A decorator. + """ def _wrapped(f): if isinstance(condition, bool): if condition: f() else: - jax.lax.cond(condition, f, lambda: None) + conditionals.cond(condition, f, lambda: None) return _wrapped + + +def loop( + lower: jax_typing.ArrayLike, + upper: jax_typing.ArrayLike, + *, + step: jax_typing.ArrayLike = 1, + unroll: int | bool | None = None, +) -> Callable[[Callable[[jax_typing.Array], None]], None]: + """Returns a decorator that calls the decorated function in a loop.""" + zero: jax_typing.ArrayLike + if not all(map(jax_core.is_concrete, (lower, upper, step))): + idx_type = jnp.result_type(lower, upper, step) + lower = lax.convert_element_type(lower, idx_type) + upper = lax.convert_element_type(upper, idx_type) + step = lax.convert_element_type(step, idx_type) + zero = jnp.array(0, dtype=idx_type) + else: + zero = 0 + + def decorator(body): + lax.fori_loop( + zero, + pl_utils.cdiv(upper - lower, step), + lambda idx, _: body(lower + idx * step), + init_val=None, + unroll=unroll, + ) + + return decorator + + +_ENABLE_DEBUG_CHECKS = config.bool_state( + "jax_pallas_enable_debug_checks", + default=False, + help=( + "If set, ``pl.debug_check`` calls are checked at runtime. Otherwise," + " they are a noop." + ), +) + + +enable_debug_checks = _ENABLE_DEBUG_CHECKS + + +def debug_checks_enabled() -> bool: + """Returns runtime checks are enabled.""" + return _ENABLE_DEBUG_CHECKS.value + + +def debug_check(condition, message): + """Check the condition if + :func:`~jax.experimental.pallas.enable_debug_checks` is set, otherwise + do nothing. + """ + return checkify.debug_check(condition, message) + + +def _make_kernel(body, + out_shape: object, + mesh: pl_core.Mesh, + scratch_shapes: pl_core.ScratchShapeTree = (), + name: str | None = None, + **mesh_kwargs + ): + if unwrap_out := not isinstance(out_shape, (tuple, list)): + out_shape = (out_shape,) + + @api.jit + def wrapper(*operands): + arg_refs = tree_util.tree_map(jax_core.new_ref, operands) + out_refs = tree_util.tree_map( + lambda out: jax_core.new_ref( + lax.empty(out.shape, out.dtype), + memory_space=( + ms + if hasattr(out, "memory_space") + and not isinstance( + ms := out.memory_space, jax_core.MemorySpace + ) + else None + ), + ), + out_shape, + ) + + + @pl_core.core_map(mesh, **mesh_kwargs, name=name or util.fun_name(body)) + def _(): + return pl_primitives.run_scoped( + functools.partial(body, *arg_refs, *out_refs), + *scratch_shapes if isinstance(scratch_shapes, Sequence) else (), + **scratch_shapes if isinstance(scratch_shapes, Mapping) else {}, + ) + + outs = tree_util.tree_map(lambda ref: ref[...], out_refs) + return outs[0] if unwrap_out else outs + return wrapper + + +def kernel(body: Callable | api.NotSpecified = api.NotSpecified(), # pylint: disable=g-bare-generic + out_shape: object | None = None, + *, + mesh: pl_core.Mesh, + scratch_shapes: pl_core.ScratchShapeTree = (), + compiler_params: pl_core.CompilerParams | None = None, + interpret: bool = False, + cost_estimate: pl_core.CostEstimate | None = None, + debug: bool = False, + name: str | None = None, + metadata: dict[str, str] | None = None, +): + """Entry point for creating a Pallas kernel. + + This is a convenience wrapper around ``core_map`` for executing a kernel + over a mesh and ``run_scoped`` for allocating scratch memory. + + If ``body`` is provided, this function behaves as a decorator: + + .. code-block:: python + + def kernel_body(in_ref, out_ref): + ... + kernel = pl.kernel(kernel_body, out_shape=...) + + If ``body`` is omitted, this function behaves as a decorator factory and + will return a decorator that can be used to annotate a kernel body: + + .. code-block:: python + + @pl.kernel(out_shape=...) + def kernel(in_ref, out_ref): + ... + + Args: + body: The body of the kernel. If provided, this function behaves as a + decorator, and if omitted, this function behaves as a decorator factory. + out_shape: The shape of the output. Should be a PyTree of + ``jax.ShapeDtypeStruct`` or ``jax.Array`` s. + mesh: The mesh to run the kernel on. + scratch_shapes: The shapes of the scratch arrays. + compiler_params: The compiler parameters to pass to the backend. + interpret: Whether to run the function in interpret mode. + debug: Whether or not to out helpful debugging information. + cost_estimate: The cost estimate of the function. + name: The (optional) name of the kernel. + metadata: Optional dictionary of information about the kernel that will be + serialized as JSON in the HLO. Can be used for debugging and analysis. + + Returns: + If ``body`` is provided, returns a function that runs the kernel. + It should take any number of input operands and returns an output with the + same PyTree structure as `out_shape`. + If ``body`` is omitted, returns a decorator that can be used to annotate + a kernel body. + """ + # Note we default out_shape to None to allow `body` to come before it + # in the function signature, but `body` itself is optional. + if out_shape is None: + raise ValueError('out_shape must be provided.') + kwds = dict( + out_shape=out_shape, + mesh=mesh, + scratch_shapes=scratch_shapes, + compiler_params=compiler_params, + interpret=interpret, + cost_estimate=cost_estimate, + debug=debug, + name=name, + metadata=metadata) + if isinstance(body, api.NotSpecified): + return lambda fun: _make_kernel(fun, **kwds) # type: ignore[arg-type] + else: + return _make_kernel(body, **kwds) # type: ignore[arg-type] diff --git a/jax/_src/pallas/hlo_interpreter.py b/jax/_src/pallas/hlo_interpreter.py index 6fbe5e914bfe..f3faddbc13dd 100644 --- a/jax/_src/pallas/hlo_interpreter.py +++ b/jax/_src/pallas/hlo_interpreter.py @@ -27,25 +27,32 @@ from collections.abc import Iterable, Sequence from functools import reduce, partial import itertools -from typing import Any, Callable +from typing import Any +from collections.abc import Callable -import jax -from jax import lax +from jax._src.lax import lax +from jax._src.lax import slicing +from jax._src.lax.control_flow import conditionals +from jax._src.lax.control_flow import loops from jax._src import core as jax_core +from jax._src import frozen_dict from jax._src import linear_util as lu from jax._src import source_info_util from jax._src.interpreters import partial_eval as pe from jax._src.pallas import core as pallas_core from jax._src.pallas import primitives +from jax._src import state from jax._src.state import discharge as state_discharge +from jax._src import typing as jax_typing from jax._src import util + from jax._src.util import ( foreach, safe_map, safe_zip, split_list, ) -import jax.numpy as jnp +from jax._src import numpy as jnp import numpy as np map, unsafe_map = safe_map, map @@ -74,7 +81,7 @@ def _logical_to_interpret_mode_dtype(dtype): def _logical_aval_to_interpret_mode_aval(aval): - if isinstance(aval, pallas_core.AbstractMemoryRef): + if isinstance(aval, state.AbstractRef): inner_aval = _logical_aval_to_interpret_mode_aval(aval.inner_aval) return aval.update(inner_aval=inner_aval) if isinstance(aval, jax_core.ShapedArray): @@ -83,22 +90,23 @@ def _logical_aval_to_interpret_mode_aval(aval): return aval -def _dynamic_slice(start_idx, block_shape, value, is_indexing): +def _dynamic_slice( + start_idx, block_shape: tuple[int, ...], value, is_squeeze, +): start_idx = tuple(jnp.asarray(s, dtype=jnp.int32) for s in start_idx) - output = lax.dynamic_slice(value, start_idx, slice_sizes=block_shape) - squeeze_dims = tuple(np.arange(len(is_indexing))[np.array(is_indexing, - dtype=np.bool_)]) - return lax.squeeze(output, squeeze_dims) + output = slicing.dynamic_slice(value, start_idx, slice_sizes=block_shape) + squeeze_dims = tuple(np.arange(len(is_squeeze))[np.array(is_squeeze, + dtype=np.bool_)]) + return lax.squeeze(output, squeeze_dims) # type: ignore[arg-type] -def _dynamic_update_slice(start_idx, block_shape, value, update, - is_indexing): +def _dynamic_update_slice(start_idx, block_shape, value, update, is_squeeze): start_idx = tuple(jnp.asarray(s, dtype=jnp.int32) for s in start_idx) - broadcast_dims = tuple(i for i, b in enumerate(is_indexing) + broadcast_dims = tuple(i for i, b in enumerate(is_squeeze) if not b) update = lax.broadcast_in_dim(update, block_shape, broadcast_dims) assert update.shape == block_shape - return lax.dynamic_update_slice(value, update, start_idx) + return slicing.dynamic_update_slice(value, update, start_idx) # TODO(justinfu): Move this to a common utility file. @@ -112,8 +120,7 @@ def _get_next_indices(grid, indices): return tuple(reversed(next_indices)) -def _pad_to_block_dimension(value, - block_shape): +def _pad_to_block_dimension(value, block_shape: tuple[int, ...]): """Pads values so the shape evenly divides into block dimensions. For example, if values has a shape of (33, 2, 5) with a block_shape of @@ -121,8 +128,7 @@ def _pad_to_block_dimension(value, Args: value: Array to be padded. - block_shape: Block shapes to use for padding. If None, no padding will - be performed. + block_shape: Block shapes to use for padding. Returns: A padded array. @@ -139,7 +145,7 @@ def _pad_to_block_dimension(value, def _initialize_output_vals( block_mappings_output: Iterable[BlockMapping], - input_args, input_output_aliases) -> Sequence[jax.Array]: + input_args, input_output_aliases) -> Sequence[jax_typing.Array]: oi_map = {v: k for k, v in input_output_aliases} output_vals = [] for i, bm in enumerate(block_mappings_output): @@ -147,8 +153,8 @@ def _initialize_output_vals( output_vals.append(input_args[oi_map[i]]) else: output_vals.append(primitives.uninitialized_value( - bm.array_shape_dtype.shape, - bm.array_shape_dtype.dtype)) + bm.array_aval.shape, + bm.array_aval.dtype)) return output_vals @@ -190,7 +196,7 @@ def eval_jaxpr_recursive( consts: Consts that ``jaxpr`` closes over. *args: Input arguments to the ``jaxpr``. recurse_hop_rule: A Jaxpr interpreter to call on sub-jaxprs of - higher-order primtives. + higher-order primitives. propagate_source_info: Whether to propagate source info. """ def read(v: jax_core.Atom) -> Any: @@ -236,8 +242,7 @@ def pad_jaxpr_constvars(jaxpr: jax_core.Jaxpr, to pad each Jaxpr with all consts from all branches so the signatures match, but only use the consts for this branch. """ - newvar = jax_core.gensym(suffix='_') - unused_const_vars = [tuple(map(newvar, const_avals)) + unused_const_vars = [tuple(map(jax_core.Var, const_avals)) for const_avals in all_const_avals] const_prefix = util.concatenate(unused_const_vars[:i]) const_suffix = util.concatenate(unused_const_vars[i + 1:]) @@ -308,14 +313,20 @@ def rule(interpreter, *args, **params): return primitive.bind(*args, **params) return rule -_eval_jaxpr_hop_rules[lax.scan_p] = make_hop_rule(lax.scan_p, 'jaxpr') -_eval_jaxpr_hop_rules[lax.while_p] = make_hop_rule( - lax.while_p, 'body_jaxpr', 'cond_jaxpr') -_eval_jaxpr_hop_rules[lax.cond_p] = make_hop_rule(lax.cond_p, 'branches') +_eval_jaxpr_hop_rules[loops.scan_p] = make_hop_rule(loops.scan_p, 'jaxpr') +_eval_jaxpr_hop_rules[loops.while_p] = make_hop_rule( + loops.while_p, 'body_jaxpr', 'cond_jaxpr') +_eval_jaxpr_hop_rules[conditionals.cond_p] = make_hop_rule(conditionals.cond_p, 'branches') def _run_scoped_physicalize_rule( - interpreter, *consts, jaxpr: jax_core.Jaxpr): + interpreter, *consts, jaxpr: jax_core.Jaxpr, collective_axes): + if collective_axes: + raise NotImplementedError( + "run_scoped interpret rule does not support collective axes" + ) physical_jaxpr, physical_consts = interpreter(jaxpr, consts) - return primitives.run_scoped_p.bind(*physical_consts, jaxpr=physical_jaxpr) + return primitives.run_scoped_p.bind( + *physical_consts, jaxpr=physical_jaxpr, collective_axes=collective_axes + ) _eval_jaxpr_hop_rules[primitives.run_scoped_p] = _run_scoped_physicalize_rule @@ -328,7 +339,7 @@ def resolve_physical_types(jaxpr: jax_core.Jaxpr, consts: Sequence[Any]): eval_jaxpr_recursive, jaxpr, consts, recurse_hop_rule=resolve_physical_types) wrapped = lu.wrap_init(interp_fun, debug_info=jaxpr.debug_info) - new_jaxpr, _, new_consts, () = pe.trace_to_jaxpr_dynamic( + new_jaxpr, _, new_consts = pe.trace_to_jaxpr_dynamic( wrapped, kernel_avals) return new_jaxpr, new_consts @@ -344,8 +355,10 @@ def pallas_call_hlo_interpret( compiler_params: Any, cost_estimate: CostEstimate, out_avals: tuple[jax_core.AbstractValue, ...], + metadata: frozen_dict.FrozenDict[str, str] | None, + name: str | None, ): - del mesh, compiler_params, cost_estimate, out_avals + del mesh, compiler_params, cost_estimate, out_avals, metadata, name debug_info = jaxpr.debug_info # If we're in interpret mode, we *scan* over the grid and eval the # discharged jaxpr. @@ -377,32 +390,27 @@ def pallas_call_hlo_interpret( carry = [] for x, bm in zip(itertools.chain(block_args, out), grid_mapping.block_mappings): - if isinstance(bm.indexing_mode, pallas_core.Unblocked): - padding = bm.indexing_mode.padding - if padding is not None and any(p != (0, 0) for p in padding): - if input_output_aliases: - raise NotImplementedError("Padding with aliasing not supported.") - pad_value = primitives.uninitialized_value(shape=(), dtype=x.dtype) - x = lax.pad(x, pad_value, [(*p, 0) for p in padding]) + padding = [bd.padding if isinstance(bd, pallas_core.Element) else (0, 0) + for bd in bm.block_shape] + if padding is not None and any(p != (0, 0) for p in padding): + if input_output_aliases: + raise NotImplementedError("Padding with aliasing not supported.") + pad_value = primitives.uninitialized_value(shape=(), dtype=x.dtype) + x = lax.pad(x, pad_value, [(*p, 0) for p in padding]) carry.append(x) - is_indexing_dim = [ - tuple(b is pallas_core.mapped for b in bm.block_shape) + block_shapes = [pallas_core._get_block_shape(bm.block_shape) + for bm in grid_mapping.block_mappings] + is_squeeze_dim = [ + tuple(isinstance(bd, pallas_core.Squeezed) for bd in bm.block_shape) for bm in grid_mapping.block_mappings ] - block_shapes = [ - tuple(1 if i else b for i, b in zip(iid, bm.block_shape)) - for iid, bm in zip(is_indexing_dim, grid_mapping.block_mappings) - ] # Pad values to evenly divide into block dimensions. This matches the # behavior of the non-interpret mode. We pad with NaN, to make it easier # to catch OOB accesses. for carry_element in carry: aval = carry_element.aval - if isinstance(aval, jax_core.DShapedArray): - aval = jax_core.ShapedArray(aval.shape, aval.dtype) - carry_element.aval = aval carry = map(_pad_to_block_dimension, carry, block_shapes) carry.extend(scratch_values) @@ -416,7 +424,7 @@ def pallas_call_hlo_interpret( num_iterations = 1 # The scan carry: (i, loop_idx, *consts, *ins, *outs, *scratch) - # i:int32 is the interation index + # i:int32 is the iteration index # loop_idx: tuple[int32] are the program ids for each grid axis def cond(carry): i, *_ = carry @@ -444,7 +452,7 @@ def body(carry): for bm in grid_mapping.block_mappings ] blocks = map(_dynamic_slice, start_indices, block_shapes, - carry_consts_ins, is_indexing_dim) + carry_consts_ins, is_squeeze_dim) with pallas_core.grid_env(local_grid_env): assert len(discharged_jaxpr.invars) == len(scalars) + len(blocks) + len( scratch_values @@ -462,26 +470,26 @@ def body(carry): _, out_inout, out_scratch = split_list( blocks, [grid_mapping.num_index_operands, num_inout_blocks]) out_carry = map(_dynamic_update_slice, start_indices, block_shapes, - carry_consts_ins, out_inout, is_indexing_dim) + carry_consts_ins, out_inout, is_squeeze_dim) return (i + 1, _get_next_indices(grid, loop_idx), *out_carry, *out_scratch) - (_, _, *carry) = lax.while_loop( + (_, _, *carry) = loops.while_loop( cond, body, (jnp.int32(0), grid_start_indices, *carry) ) out_out = carry[len(block_args):len(block_args) + len(out)] out_nopad = [] for o, bm in zip(out_out, grid_mapping.block_mappings_output): - if isinstance(bm.indexing_mode, pallas_core.Unblocked): - padding = bm.indexing_mode.padding - if padding is not None and any(p != (0, 0) for p in padding): - if input_output_aliases: - raise NotImplementedError("Padding with aliasing not supported.") - pad_low, pad_high = zip(*padding) - limit_indices = [s - p for s, p in zip(o.shape, pad_high)] - o = lax.slice(o, pad_low, limit_indices) - if o.shape != bm.array_shape_dtype.shape: - o = lax.slice(o, (0,) * o.ndim, bm.array_shape_dtype.shape) + padding = [bd.padding if isinstance(bd, pallas_core.Element) else (0, 0) + for bd in bm.block_shape] + if padding is not None and any(p != (0, 0) for p in padding): + if input_output_aliases: + raise NotImplementedError("Padding with aliasing not supported.") + pad_low, pad_high = zip(*padding) + limit_indices = [s - p for s, p in zip(o.shape, pad_high)] + o = slicing.slice(o, pad_low, limit_indices) + if o.shape != bm.array_aval.shape: + o = slicing.slice(o, (0,) * o.ndim, bm.array_aval.shape) out_nopad.append(o) return out_nopad diff --git a/jax/_src/pallas/mosaic/BUILD b/jax/_src/pallas/mosaic/BUILD index 24e8341046b0..afcdefffa81e 100644 --- a/jax/_src/pallas/mosaic/BUILD +++ b/jax/_src/pallas/mosaic/BUILD @@ -15,7 +15,7 @@ # Package for Mosaic-specific Pallas extensions load("@rules_python//python:defs.bzl", "py_library") -load("//jaxlib:jax.bzl", "py_deps") +load("//jaxlib:jax.bzl", "py_deps", "pytype_strict_library") package( default_applicable_licenses = [], @@ -33,16 +33,7 @@ py_library( deps = [ "//jax", "//jax/_src/pallas", - ], -) - -py_library( - name = "verification", - srcs = ["verification.py"], - deps = [ - "//jax", - "//jax:mlir", - "//jax/_src/lib", + "//jax/extend:backend", ], ) @@ -50,8 +41,8 @@ py_library( name = "error_handling", srcs = ["error_handling.py"], deps = [ - "//jax:compiler", - "//jax:traceback_util", + "//jax/_src:compiler", + "//jax/_src:traceback_util", "//jax/_src/lib", ], ) @@ -62,13 +53,13 @@ py_library( deps = [ ":core", "//jax", - "//jax:core", - "//jax:dtypes", - "//jax:mlir", - "//jax:pretty_printer", - "//jax:tree_util", - "//jax:typing", - "//jax:util", + "//jax/_src:core", + "//jax/_src:dtypes", + "//jax/_src:mlir", + "//jax/_src:pretty_printer", + "//jax/_src:tree_util", + "//jax/_src:typing", + "//jax/_src:util", "//jax/_src/pallas", ], ) @@ -79,54 +70,137 @@ py_library( deps = [ ":core", ":lowering", - ":verification", + ":sc_lowering", "//jax", - "//jax:config", - "//jax:core", - "//jax:mlir", - "//jax:mosaic", - "//jax:sharding_impls", - "//jax:source_info_util", - "//jax:tpu_custom_call", + "//jax/_src:config", + "//jax/_src:core", + "//jax/_src:frozen_dict", + "//jax/_src:lax", + "//jax/_src:mlir", + "//jax/_src:sharding_impls", + "//jax/_src:source_info_util", + "//jax/_src:tpu_custom_call", "//jax/_src/lib", "//jax/_src/pallas", + "//jax/experimental:mosaic", ] + py_deps("numpy"), ) -py_library( +pytype_strict_library( name = "lowering", srcs = ["lowering.py"], deps = [ ":core", ":error_handling", ":primitives", + ":random", "//jax", - "//jax:ad_util", - "//jax:core", - "//jax:dtypes", - "//jax:mesh", - "//jax:mlir", - "//jax:mosaic", - "//jax:partial_eval", - "//jax:source_info_util", - "//jax:util", - "//jax:xla", + "//jax/_src:ad_util", + "//jax/_src:api", + "//jax/_src:checkify", + "//jax/_src:cloud_tpu_init", + "//jax/_src:config", + "//jax/_src:core", + "//jax/_src:custom_derivatives", + "//jax/_src:debugging", + "//jax/_src:dtypes", + "//jax/_src:export", + "//jax/_src:lax", + "//jax/_src:literals", + "//jax/_src:mesh", + "//jax/_src:mlir", + "//jax/_src:partial_eval", + "//jax/_src:random", + "//jax/_src:source_info_util", + "//jax/_src:traceback_util", + "//jax/_src:typing", + "//jax/_src:util", + "//jax/_src:xla_bridge", "//jax/_src/lib", "//jax/_src/pallas", + "//jax/experimental:mosaic", ] + py_deps("numpy"), ) +pytype_strict_library( + name = "sc_lowering", + srcs = [ + "sc_lowering.py", + ], + visibility = ["//jax/experimental:pallas_sc_users"], + deps = [ + ":core", + ":lowering", + ":primitives", + ":sc_core", + "//jax/_src:api_util", + "//jax/_src:core", + "//jax/_src:debugging", + "//jax/_src:lax", + "//jax/_src:mesh", + "//jax/_src:mlir", + "//jax/_src:numpy", + "//jax/_src:partial_eval", + "//jax/_src:source_info_util", + "//jax/_src:tree_util", + "//jax/_src:util", + "//jax/_src/lib", + "//jax/_src/pallas", + "//jax/experimental:mosaic", + ], +) + +pytype_strict_library( + name = "sc_core", + srcs = [ + "sc_core.py", + ], + visibility = ["//jax/experimental:pallas_sc_users"], + deps = [ + ":core", + ":tpu_info", + "//jax", + "//jax/_src:core", + "//jax/_src:lax", + "//jax/_src:tree_util", + "//jax/_src/pallas", + ], +) + +pytype_strict_library( + name = "sc_primitives", + srcs = [ + "sc_primitives.py", + ], + deps = [ + ":core", + ":lowering", + ":sc_core", + ":sc_lowering", + "//jax", + "//jax/_src:core", + "//jax/_src:dtypes", + "//jax/_src:effects", + "//jax/_src:lax", + "//jax/_src:partial_eval", + "//jax/_src/lib", + "//jax/_src/pallas", + "//jax/experimental:mosaic", + ], +) + py_library( name = "pipeline", srcs = ["pipeline.py"], deps = [ ":core", ":primitives", + ":tpu_info", "//jax", - "//jax:api_util", - "//jax:pallas", - "//jax:util", + "//jax/_src:api_util", + "//jax/_src:util", "//jax/_src/pallas", + "//jax/experimental:pallas", "//jax/extend:backend", ] + py_deps("numpy"), ) @@ -137,7 +211,8 @@ py_library( deps = [ ":primitives", "//jax", - "//jax:typing", + "//jax/_src:blocked_sampler", + "//jax/_src:typing", ] + py_deps("numpy"), ) @@ -152,17 +227,13 @@ py_library( ], ) -py_library( - name = "interpret", - srcs = ["interpret.py"], +pytype_strict_library( + name = "tpu_info", + srcs = ["tpu_info.py"], deps = [ ":core", - ":primitives", "//jax", - "//jax:core", - "//jax:source_info_util", - "//jax:util", - "//jax/_src/lib", - "//jax/_src/pallas", - ] + py_deps("numpy"), + "//jax/_src:dtypes", + "//jax/_src:util", + ], ) diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index f582248ee7c3..f9fe91fe14b1 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -16,18 +16,22 @@ from __future__ import annotations import collections +from collections.abc import Mapping from collections.abc import Sequence import dataclasses import enum -import functools from typing import Any, ClassVar, Literal import jax -from jax._src import config from jax._src import core as jax_core -from jax._src import dtypes +from jax._src import deprecations +from jax._src import linear_util as lu +from jax._src import state from jax._src import util +from jax._src.frozen_dict import FrozenDict +from jax._src.interpreters import partial_eval as pe from jax._src.pallas import core as pallas_core +from jax.extend import backend as jex_backend import jax.numpy as jnp import numpy as np @@ -35,124 +39,183 @@ map, unsafe_map = util.safe_map, map zip, unsafe_zip = util.safe_zip, zip -partial = functools.partial -Grid = pallas_core.Grid -TupleGrid = pallas_core.TupleGrid -BlockSpec = pallas_core.BlockSpec -BlockSpecTree = pallas_core.BlockSpecTree -GridMapping = pallas_core.GridMapping -NoBlockSpec = pallas_core.NoBlockSpec -ScratchShapeTree = pallas_core.ScratchShapeTree -AbstractMemoryRef = pallas_core.AbstractMemoryRef no_block_spec = pallas_core.no_block_spec -_convert_block_spec_to_block_mapping = pallas_core._convert_block_spec_to_block_mapping _out_shape_to_aval_mapping = pallas_core._out_shape_to_aval_mapping -split_list = util.split_list - -_ENABLE_RUNTIME_ASSERT = config.bool_state( - "jax_pallas_enable_runtime_assert", - default=False, - help=( - "If set, enables runtime assertions in the kernel via checkify.check." - " Otherwise, runtime asserts will be ignored unless functionalized" - " using checkify.checkify." - ), + + +class KernelType(enum.Enum): + TC = 0 + SC_SCALAR_SUBCORE = 1 + SC_VECTOR_SUBCORE = 2 + + +class GridDimensionSemantics(enum.Enum): + PARALLEL = "parallel" + CORE_PARALLEL = "core_parallel" + SUBCORE_PARALLEL = "subcore_parallel" + ARBITRARY = "arbitrary" + +PARALLEL = GridDimensionSemantics.PARALLEL +CORE_PARALLEL = GridDimensionSemantics.CORE_PARALLEL +SUBCORE_PARALLEL = GridDimensionSemantics.SUBCORE_PARALLEL +ARBITRARY = GridDimensionSemantics.ARBITRARY + + +DimensionSemantics = ( + Literal["parallel", "core_parallel", "subcore_parallel", "arbitrary"] + | GridDimensionSemantics ) +class SideEffectType(enum.Enum): + # No side effects, can be deduplicated / removed if unused. + PURE = "pure" + # Cannot be deduplicated, but can be removed if unused. + DATAFLOW_SIDE_EFFECTING = "dataflow_side_effecting" + # Cannot be deduplicated or removed. + SIDE_EFFECTING = "side_effecting" + + @dataclasses.dataclass(frozen=True) -class TPUCompilerParams(pallas_core.CompilerParams): +class CompilerParams(pallas_core.CompilerParams): """Mosaic TPU compiler parameters. Attributes: - dimension_semantics: A list of dimension semantics for each grid - dimension of the kernel. Either "parallel" for dimensions that can - execute in any order, or "arbitrary" for dimensions that must be - executed sequentially. + dimension_semantics: A list of dimension semantics for each grid dimension + of the kernel. Either "parallel" for dimensions that can execute in any + order, or "arbitrary" for dimensions that must be executed sequentially. allow_input_fusion: A list of booleans indicating whether input fusion is allowed for each argument. - vmem_limit_bytes: Overrides the default VMEM limit for a kernel. Note - that this must be used in conjunction with the + vmem_limit_bytes: Overrides the default VMEM limit for a kernel. Note that + this must be used in conjunction with the --xla_tpu_scoped_vmem_limit_kib=N flag with N*1kib > vmem_limit_bytes. - collective_id: Indicates which barrier semaphore to use for the kernel. - Note that using the same collective_id does not guarantee that - the same barrier semaphore will be allocated between kernels. + collective_id: Indicates which barrier semaphore to use for the kernel. Note + that using the same collective_id does not guarantee that the same barrier + semaphore will be allocated between kernels. + has_side_effects: Set to True to prevent kernel being CSEd by XLA. + flags: A dictionary of command line flags for the kernel. internal_scratch_in_bytes: The size of the internal scratch space used by Mosaic. - flags: A dictionary of command line flags for the kernel. serialization_format: The serialization format for the kernel body. - device_type: The device type to compile for. + kernel_type: Specify if the kernel is meant to run on TensorCore or one of + the SparseCores + disable_bounds_checks: Disable bounds checks in the kernel. + skip_device_barrier: Skip the default device barrier for the kernel. + allow_collective_id_without_custom_barrier: Allow the use of collective_id + without a custom barrier. + use_tc_tiling_on_sc: Use TensorCore tiling for SparseCore. This flag is + only used for ``SC_*_SUBCORE`` kernels. """ - PLATFORM: ClassVar[str] = "mosaic" - dimension_semantics: ( - Sequence[Literal["parallel", "arbitrary"] | GridDimensionSemantics] | None - ) = None - allow_input_fusion: Sequence[bool] | None = None + BACKEND: ClassVar[pallas_core.Backend] = "mosaic_tpu" + dimension_semantics: tuple[DimensionSemantics, ...] | None = None + allow_input_fusion: tuple[bool, ...] | None = None vmem_limit_bytes: int | None = None collective_id: int | None = None - has_side_effects: bool = False + has_side_effects: bool | SideEffectType = False flags: dict[str, Any] | None = None internal_scratch_in_bytes: int | None = None serialization_format: int = 1 - device_type: str | None = None + kernel_type: KernelType = KernelType.TC + disable_bounds_checks: bool = False + skip_device_barrier: bool = False + allow_collective_id_without_custom_barrier: bool = False + shape_invariant_numerics: bool = True + use_tc_tiling_on_sc: bool | None = None + + def __init__( + self, + dimension_semantics: Sequence[DimensionSemantics] | None = None, + allow_input_fusion: Sequence[bool] | None = None, + vmem_limit_bytes: int | None = None, + collective_id: int | None = None, + has_side_effects: bool | SideEffectType = False, + flags: Mapping[str, Any] | None = None, + internal_scratch_in_bytes: int | None = None, + serialization_format: int = 1, + kernel_type: KernelType = KernelType.TC, + disable_bounds_checks: bool = False, + skip_device_barrier: bool = False, + allow_collective_id_without_custom_barrier: bool = False, + shape_invariant_numerics: bool = True, + use_tc_tiling_on_sc: bool | None = None, + ): + object.__setattr__( + self, + "dimension_semantics", + None if dimension_semantics is None else tuple(dimension_semantics), + ) + object.__setattr__( + self, + "allow_input_fusion", + None if allow_input_fusion is None else tuple(allow_input_fusion), + ) + object.__setattr__(self, "vmem_limit_bytes", vmem_limit_bytes) + object.__setattr__(self, "collective_id", collective_id) + object.__setattr__(self, "has_side_effects", has_side_effects) + object.__setattr__( + self, "flags", None if flags is None else FrozenDict(flags) + ) + object.__setattr__( + self, "internal_scratch_in_bytes", internal_scratch_in_bytes + ) + object.__setattr__(self, "serialization_format", serialization_format) + object.__setattr__(self, "kernel_type", kernel_type) + object.__setattr__(self, "disable_bounds_checks", disable_bounds_checks) + object.__setattr__(self, "skip_device_barrier", skip_device_barrier) + object.__setattr__( + self, + "allow_collective_id_without_custom_barrier", + allow_collective_id_without_custom_barrier, + ) + object.__setattr__( + self, "shape_invariant_numerics", shape_invariant_numerics + ) + object.__setattr__(self, "use_tc_tiling_on_sc", use_tc_tiling_on_sc) + # Replace is a method, not a field. replace = dataclasses.replace -class TPUMemorySpace(enum.Enum): - ANY = "any" # TODO(b/368401328): Remove this and just use pl.ANY. + +class MemorySpace(enum.Enum): VMEM = "vmem" + VMEM_SHARED = "vmem_shared" SMEM = "smem" CMEM = "cmem" SEMAPHORE = "semaphore_mem" + HBM = "hbm" + HOST = "host" def __str__(self) -> str: return self.value - def __call__(self, shape: tuple[int, ...], dtype: jnp.dtype): - # A convenience function for constructing MemoryRef types. - return pallas_core.MemoryRef(shape, dtype, self) - -class semaphore_dtype(dtypes.extended): pass -class semaphore(semaphore_dtype): pass -class dma_semaphore(semaphore_dtype): pass -class barrier_semaphore(semaphore_dtype): pass - -class AbstractSemaphoreTyRules: - @staticmethod - def pallas_interpret_element_aval(_) -> jax_core.ShapedArray: - return jax_core.ShapedArray((), pallas_core.SEMAPHORE_INTERPRET_DTYPE) - - @staticmethod - def physical_element_aval(_) -> jax_core.ShapedArray: - return jax_core.ShapedArray((), jnp.int32) + def from_type(self, ty): + return pallas_core.MemoryRef(ty, memory_space=self) -class AbstractSemaphoreTy(dtypes.ExtendedDType): - name: str - _rules = AbstractSemaphoreTyRules + def __call__(self, shape: Sequence[int], dtype: jnp.dtype): + # A convenience function for constructing MemoryRef types of ShapedArrays. + return self.from_type(jax_core.ShapedArray(tuple(shape), dtype)) - def __repr__(self) -> str: - return self.name + def __getattr__(self, name): + if name == "ANY": + # Deprecated on Dec 10, 2025. + deprecations.warn( + "pltpu-memory-space-any", + "pltpu.MemorySpace.ANY is deprecated. Use pl.ANY instead.", + stacklevel=2, + ) + return pallas_core.MemorySpace.ANY + return super().__getattr__(name) # type: ignore - def __eq__(self, other): - return self.__class__ == other.__class__ - def __hash__(self) -> int: - return hash(self.__class__) - -# TODO(sharadmv): implement dtype rules for AbstractSemaphoreTy +# TODO(slebedev): Remove this after +MemorySpace.ANY = pallas_core.MemorySpace.ANY -class SemaphoreTy(AbstractSemaphoreTy): - type = semaphore - name = "sem" +class dma_semaphore(pallas_core.semaphore_dtype): pass -class DmaSemaphoreTy(AbstractSemaphoreTy): +class DMASemaphore(pallas_core.AbstractSemaphoreTy): type = dma_semaphore name = "dma_sem" -class BarrierSemaphoreTy(AbstractSemaphoreTy): - type = barrier_semaphore - name = "barrier_sem" - class SemaphoreType(enum.Enum): REGULAR = "regular" DMA = "dma" @@ -161,17 +224,18 @@ class SemaphoreType(enum.Enum): def __call__(self, shape: tuple[int, ...]): dtype: Any if self == SemaphoreType.DMA: - dtype = DmaSemaphoreTy() + dtype = DMASemaphore() elif self == SemaphoreType.BARRIER: - dtype = BarrierSemaphoreTy() + dtype = pallas_core.BarrierSemaphore() else: - dtype = SemaphoreTy() - return pallas_core.MemoryRef(shape, dtype, TPUMemorySpace.SEMAPHORE) + dtype = pallas_core.Semaphore() + return pallas_core.MemoryRef(jax_core.ShapedArray(shape, dtype), + MemorySpace.SEMAPHORE) def get_array_aval(self) -> pallas_core.ShapedArrayWithMemorySpace: return self(()).get_array_aval() - def get_ref_aval(self) -> AbstractMemoryRef: + def get_ref_aval(self) -> state.AbstractRef: return self(()).get_ref_aval() @dataclasses.dataclass(frozen=True) @@ -186,18 +250,18 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec): def __init__( self, num_scalar_prefetch: int, - grid: Grid = (), - in_specs: BlockSpecTree = no_block_spec, - out_specs: BlockSpecTree = no_block_spec, - scratch_shapes: ScratchShapeTree = () + grid: pallas_core.Grid = (), + in_specs: pallas_core.BlockSpecTree = no_block_spec, + out_specs: pallas_core.BlockSpecTree = no_block_spec, + scratch_shapes: pallas_core.ScratchShapeTree = () ): super().__init__(grid, in_specs, out_specs, scratch_shapes) self.num_scalar_prefetch = num_scalar_prefetch self.scratch_shapes = tuple(scratch_shapes) def _make_scalar_ref_aval(self, aval): - return AbstractMemoryRef(jax_core.ShapedArray(aval.shape, aval.dtype), - TPUMemorySpace.SMEM) + return state.AbstractRef(jax_core.ShapedArray(aval.shape, aval.dtype), + MemorySpace.SMEM) @dataclasses.dataclass(frozen=True) @@ -211,6 +275,17 @@ class TensorCoreMesh: devices: np.ndarray axis_names: Sequence[str] + def __init__(self, devices: np.ndarray, axis_names: Sequence[str]): + devices = np.copy(devices) + devices.setflags(write=False) + object.__setattr__(self, "devices", devices) + object.__setattr__(self, "axis_names", tuple(axis_names)) + + def __hash__(self) -> int: + return hash( + (self.devices.shape, tuple(np.ravel(self.devices)), self.axis_names) + ) + @property def backend(self) -> str: return "mosaic_tpu" @@ -225,23 +300,26 @@ def discharges_effect(self, effect: jax_core.Effect): def create_tensorcore_mesh( - axis_name: str, devices: Sequence[jax.Device] | None = None + axis_name: str, + devices: Sequence[jax.Device] | None = None, + num_cores: int | None = None, ) -> TensorCoreMesh: - # TODO(b/355036384): emit a better error if we don't have tensorcores. - if devices is None: - devices = jax.devices() - num_cores = devices[0].num_cores + if devices is not None and num_cores is not None: + raise ValueError('cannot specify both devices and num_cores') + if num_cores is None: + if devices is None: + abstract_device = jax.sharding.get_abstract_mesh().abstract_device + if abstract_device is None: + devices = [jax.devices()[0]] + else: + devices = [abstract_device] + num_cores = devices[0].num_cores return TensorCoreMesh( np.array([TensorCore(i) for i in range(num_cores)]), [axis_name], ) -def runtime_assert_enabled() -> bool: - """Returns whether runtime asserts are enabled.""" - return _ENABLE_RUNTIME_ASSERT.value - - def _tensorcore_mesh_discharge_rule( in_avals, out_avals, @@ -249,24 +327,81 @@ def _tensorcore_mesh_discharge_rule( mesh, jaxpr, compiler_params: Any | None, - interpret: bool, + interpret: Any, debug: bool, cost_estimate: pallas_core.CostEstimate | None, name: str, + metadata: FrozenDict[str, str] | None, ): assert isinstance(mesh, TensorCoreMesh) - if compiler_params and not isinstance(compiler_params, TPUCompilerParams): + if compiler_params and not isinstance(compiler_params, CompilerParams): raise ValueError( - "compiler_params must be a pltpu.TPUCompilerParams" + "compiler_params must be a pltpu.CompilerParams" ) if not compiler_params: - compiler_params = TPUCompilerParams() + compiler_params = CompilerParams() if len(mesh.shape) > 1: raise NotImplementedError("Mesh must be 1D") if compiler_params.dimension_semantics is not None: raise ValueError( "dimension_semantics must be None for TensorCoreMesh" ) + num_cores = len(mesh.devices) + if num_cores > 1: + # Since each core will have its own VMEM, we currently disallow VMEM inputs + # and outputs since other ops might not agree on how they are sharded across + # cores by the (core-mapped) kernel. + if any( + pallas_core.get_memory_space_aval(aval) == MemorySpace.VMEM + for aval in in_avals + ): + raise NotImplementedError( + "TensorCoreMesh does not support VMEM inputs/outputs when there are" + " >1 cores. Use HBM or ANY instead." + ) + def allowed_aval(aval): + if isinstance(aval, state.AbstractRef): + return True + if isinstance(aval, jax_core.ShapedArray): + # Only scalars are allowed. + return not aval.shape + return False + assert all(allowed_aval(v.aval) for v in jaxpr.constvars + jaxpr.invars) + + is_scalar_const = [ + isinstance(v.aval, jax_core.ShapedArray) and not v.aval.shape + for v in jaxpr.constvars + ] + if any(is_scalar_const): + # Rewrite body jaxpr to take in scalar values as Refs. + def new_body(*args): + args = [ + a[0] if is_scalar else a + for a, is_scalar in zip(args, is_scalar_const) + ] + return jax_core.eval_jaxpr(jaxpr, args) + # TODO(sharadmv): Remove this once Mosaic support passing scalars as values. + new_trace_avals = [ + state.AbstractRef( # pylint: disable=g-long-ternary + jax_core.ShapedArray((1,), v.aval.dtype), + memory_space=MemorySpace.SMEM, + ) + if is_scalar + else v.aval + for v, is_scalar in zip(jaxpr.constvars, is_scalar_const) + ] + new_jaxpr, _, _ = pe.trace_to_jaxpr_dynamic( + lu.wrap_init( + new_body, debug_info=jaxpr.debug_info.with_unknown_names() + ), + new_trace_avals, + ) + jaxpr = new_jaxpr.replace(invars=[], constvars=new_jaxpr.invars) + args = tuple( + a[None] if is_scalar else a + for a, is_scalar in zip(args, is_scalar_const) + ) + in_avals, out_avals = util.split_list(new_trace_avals, [len(in_avals)]) return pallas_core.default_mesh_discharge_rule( in_avals, out_avals, @@ -280,6 +415,8 @@ def _tensorcore_mesh_discharge_rule( interpret=interpret, cost_estimate=cost_estimate, name=name, + metadata=metadata, + scratch_shapes=[], ) pallas_core._core_map_mesh_rules[TensorCoreMesh] = ( @@ -298,8 +435,12 @@ def _convert_semaphore_type_to_aval( ) -class GridDimensionSemantics(enum.Enum): - PARALLEL = "parallel" - ARBITRARY = "arbitrary" -PARALLEL = GridDimensionSemantics.PARALLEL -ARBITRARY = GridDimensionSemantics.ARBITRARY +def get_device_kind() -> str: + if abstract_device := jax.sharding.get_abstract_mesh().abstract_device: + return abstract_device.device_kind + return jex_backend.get_default_device().device_kind + +def get_num_device_cores() -> int: + if abstract_device := jax.sharding.get_abstract_mesh().abstract_device: + return abstract_device.num_cores + return jex_backend.get_default_device().num_cores diff --git a/jax/_src/pallas/mosaic/error_handling.py b/jax/_src/pallas/mosaic/error_handling.py index 7b4e03e4a3ff..3d8714945e90 100644 --- a/jax/_src/pallas/mosaic/error_handling.py +++ b/jax/_src/pallas/mosaic/error_handling.py @@ -18,6 +18,7 @@ import types from jax._src import compiler from jax._src import traceback_util +from jax._src.lib import _jax from jax._src.lib import xla_client from jax._src.lib.mlir import ir @@ -35,7 +36,7 @@ r'( to (?P[0-9]+)?:(?P[0-9]+))?\)' ) MLIR_ERR_PREFIX = ( - 'Pallas encountered an internal verification error.' + 'Pallas encountered an internal verification error. ' 'Please file a bug at https://github.com/jax-ml/jax/issues. ' 'Error details: ' ) @@ -55,9 +56,9 @@ def __init__(self, message: str): def _handle_xla_runtime_error( - base_err: xla_client.XlaRuntimeError, + base_err: _jax.JaxRuntimeError, ) -> MosaicError | None: - """Reformats XLARuntimeError to include a Python traceback.""" + """Reformats JaxRuntimeError to include a Python traceback.""" if 'Mosaic' not in str(base_err): return None try: diff --git a/jax/_src/pallas/mosaic/helpers.py b/jax/_src/pallas/mosaic/helpers.py index 76421cec3340..80bb4ef4abed 100644 --- a/jax/_src/pallas/mosaic/helpers.py +++ b/jax/_src/pallas/mosaic/helpers.py @@ -60,7 +60,7 @@ def _copy_start_or_wait(action, src_ref, dst_ref): def run_on_first_core(core_axis_name: str): """Runs a function on the first core in a given axis.""" - num_cores = jax.lax.psum(1, core_axis_name) + num_cores = jax.lax.axis_size(core_axis_name) if num_cores == 1: return lambda f: f() @@ -77,7 +77,7 @@ def _(): def core_barrier(sem, *, core_axis_name: str): """Synchronizes all cores in a given axis.""" - num_cores = jax.lax.psum(1, core_axis_name) + num_cores = jax.lax.axis_size(core_axis_name) core_id = jax.lax.axis_index(core_axis_name) @pl_helpers.when(num_cores > 1) @@ -88,8 +88,8 @@ def signal_core(i): # Don't signal ourself @pl_helpers.when(core_id != i) def _(): - plm_primitives.semaphore_signal(sem, 1, core_index=i) + pl_primitives.semaphore_signal(sem, 1, core_index=i) for i in range(num_cores): signal_core(i) - plm_primitives.semaphore_wait(sem, num_cores - 1) + pl_primitives.semaphore_wait(sem, num_cores - 1) diff --git a/jax/_src/pallas/mosaic/interpret.py b/jax/_src/pallas/mosaic/interpret.py deleted file mode 100644 index 1ad7be8154cd..000000000000 --- a/jax/_src/pallas/mosaic/interpret.py +++ /dev/null @@ -1,1651 +0,0 @@ -# Copyright 2024 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import collections -from collections.abc import Iterable, Sequence -import dataclasses -import enum -import functools -import itertools -import math -import threading -from typing import Any, Literal - -import jax -from jax import lax -from jax._src import callback -from jax._src import core as jax_core -from jax._src.lax.control_flow import for_loop -from jax._src import linear_util as lu -from jax._src import source_info_util -from jax._src.pallas.mosaic import primitives as mosaic_primitives -from jax._src.pallas.mosaic import core as mosaic_core -from jax._src.pallas import core as pallas_core -from jax._src.pallas import primitives -from jax._src import pjit -from jax._src.state import discharge as state_discharge -from jax._src.state import indexing -from jax._src.state import primitives as state_primitives -from jax._src.util import ( - safe_map, - safe_zip, - split_list -) -from jax.interpreters import partial_eval as pe -import jax.numpy as jnp -import numpy as np - - -map, unsafe_map = safe_map, map -zip, unsafe_zip = safe_zip, zip - -Grid = pallas_core.Grid -TupleGrid = pallas_core.TupleGrid -GridSpec = pallas_core.GridSpec -BlockMapping = pallas_core.BlockMapping -GridMapping = pallas_core.GridMapping -BlockSpec = pallas_core.BlockSpec -BlockSpecTree = pallas_core.BlockSpecTree -NoBlockSpec = pallas_core.NoBlockSpec -no_block_spec = pallas_core.no_block_spec -ScratchShapeTree = pallas_core.ScratchShapeTree -CostEstimate = pallas_core.CostEstimate - - -@dataclasses.dataclass(frozen=True) -class TPUInterpretParams: - """Parameters for Mosaic TPU interpret mode. - - Attributes: - dma_execution_mode: If "eager", DMAs are executed as soon as they are - issued. If "on_wait", DMA reads or writes are only executed when a device - is waiting on a DMA semaphore that will be signaled when the read or write - is complete. - Default: "on_wait". - detect_races: If True, a dynamic, happens-before race detector will be - used to detect data races during kernel interpretation. If any races are - detected, a message will be printed and `races.races_found` will be set - to True. - Default: False. - skip_floating_point_ops: If True, operations that produce only floating - point values will not be interpreted; instead, their results will be - replaced with arrays all of `jnp.inf`. Additionaly any floating point - operands to any operation will be replaced with (arrays of) `jnp.inf`. - Default: False. - uninitialized_memory: If "nan", allocated buffers are initialized to - to contain all NaNs (or to their maximum possible value for integers). - If "zero", allocated buffers are initialized to all zeros. - Default: "nan". - """ - dma_execution_mode: Literal["eager", "on_wait"] = "on_wait" - detect_races: bool = False - skip_floating_point_ops: bool = False - uninitialized_memory: Literal["nan", "zero"] = "nan" - - -VectorClock = np.ndarray - -# Conceptually, each DMA runs on its own, independent device. Representing -# this precisely would require vector clocks to have sizes linear in the number -# of DMAs. -# -# Instead, we use approximate vector clocks of fixed size. We assign each DMA -# a virtual device ID in the range [num_devices + 1, NUM_VIRTUAL_DEVICES] -- -# and each operation of a DMA increments the corresponding coordinate in its -# vector clock. (So the "virtual" part of a vector clock is effectively -# counting, for each virtual device, the number of DMAs that happened-before -# the vector clock and were assigned to that virtual device.) -# -# If two approximate clocks are unordered, then their corresponding events are -# not ordered by the happens-before relation. So this approximation will not -# introduce any false positives in detecting data races. But we may fail to -# detect some true data races because there can be cases where two approximate -# clocks are ordered, and we will treat the corresponding events as ordered -# by the happens-before relation, but the corresponding events are not -# actually ordered. -NUM_VIRTUAL_DEVICES = 32 - -def make_vector_clock(num_devices: int) -> VectorClock: - del num_devices - return np.zeros(NUM_VIRTUAL_DEVICES, dtype=np.int32) - -def copy_vector_clock(x: VectorClock) -> VectorClock: - if x is None: - return None - return x.copy() - -def update_vector_clock(x: VectorClock, y: VectorClock): - x[:] = np.maximum(x, y) - -def lt(x: VectorClock, y: VectorClock) -> bool: - return bool((x <= y).all() & (x < y).any()) - -def ordered(x: VectorClock, y: VectorClock) -> bool: - return lt(x, y) | lt(y, x) - -def inc_vector_clock(x: VectorClock, device_id: int): - if device_id >= len(x): - raise ValueError(f'device_id={device_id} is out of range for x={x}') - assert device_id < len(x) - x[device_id] += 1 - - -class Semaphore: - def __init__(self, semaphore_id=None): - shared_memory = _get_shared_memory() - - self.id = semaphore_id - - # TODO(jburnim): Use one Condition variable per device. (Which will be - # easier to do when we're using single integer device IDs.) - self.cv = threading.Condition() - - self.counts = np.zeros(shared_memory.num_devices, dtype=np.int32) - - self.interpret_params = shared_memory.interpret_params - if self.interpret_params.detect_races: - # We associate a vector clock with each count in self.counts. Whenever - # self.counts[i] is signaled, self.clocks[i] is updated with the vector - # clock of the signaling device. Whenever device i successfully waits on - # self.counts[i], the vector clock of device i is updated with - # self.clocks[i]. - # - # TODO(jburnim): Model happens-before more precisely for the case where - # semaphores are over-signaled. - self.clocks = [None] * shared_memory.num_devices - - def signal(self, inc, device_id, clock): - """Signal the semaphore on `device_id` by `inc`. - - Args: - inc: A positive integer. The amount by which to increment the semaphore - on the target device. - device_id: The ID of the target device. - clock: The vector clock of the signaling device at the time of the signal. - """ - device_id = int(device_id) - with self.cv: - self.counts[device_id] += inc - if self.interpret_params.detect_races: - if self.clocks[device_id] is None: - self.clocks[device_id] = copy_vector_clock(clock) - else: - update_vector_clock(self.clocks[device_id], clock) - self.cv.notify_all() - - def read(self, device_id): - with self.cv: - return self.counts[device_id] - - def wait(self, value, device_id, *, is_dma=False): - device_id = int(device_id) - shared_memory = _get_shared_memory() - - # TODO(jburnim): - # - If the count is larger than value, raise an error? - # - If the count is equal to value, but there DMAs waiting to signal us, - # raise an error? - - # Simple implementation for non-DMA semaphores. - if not is_dma or (self.interpret_params.dma_execution_mode == "eager"): - with self.cv: - while self.counts[device_id] < value: - self.cv.wait() - self.counts[device_id] -= value - if self.interpret_params.detect_races: - clock = copy_vector_clock(self.clocks[device_id]) - if self.interpret_params.detect_races: - with shared_memory.lock: - update_vector_clock(shared_memory.clocks[device_id], clock) - return - - # For DMA semaphores (when dma_execution_mode=='on_wait'), while our count - # is not large enough we will select and partially execute pending DMAs - # until our count is large enough. - # - # This approach will tend to run DMAs as late as possible, as well as - # out-of-order. This approach also lets us avoid the complexity of spinning - # up separate threads to handle executing DMAs. - shared_memory = _get_shared_memory() - while True: - clock = None - with self.cv: - if self.counts[device_id] >= value: - self.counts[device_id] -= value - if self.interpret_params.detect_races: - clock = copy_vector_clock(self.clocks[device_id]) - else: - return - if clock is not None: - with shared_memory.lock: - update_vector_clock(shared_memory.clocks[device_id], clock) - return - - with shared_memory.lock: - dma_queue = shared_memory.dmas_by_sem[self.id] - if len(dma_queue) > 0: - dma = dma_queue.pop() - else: - continue - - # Only execute the DMA as far as necessary to signal us. - assert (dma.src_sem is self) or (dma.dst_sem is self) - with dma.lock: - if dma.virtual_device_id is None: - dma.virtual_device_id = np.random.randint( - shared_memory.num_devices, NUM_VIRTUAL_DEVICES) - - if dma.state == DmaState.STARTED: - # Do the read. - if self.interpret_params.detect_races: - inc_vector_clock(dma.clock, dma.virtual_device_id) - dma.data = get(dma.src_device_id, - dma.src_memory_space, - dma.src_buffer_id, - dma.src_transforms, - clock=copy_vector_clock(dma.clock), - src_device_id=dma.id, - source_info=dma.source_info) - if self.interpret_params.detect_races: - inc_vector_clock(dma.clock, dma.virtual_device_id) - if dma.src_sem is not None: - data_size = dma.data.itemsize * dma.data.size - dma.src_sem.signal( - data_size, device_id=dma.src_device_id, clock=dma.clock) - dma.state = DmaState.READ - - if dma.src_sem is self: - # We were only waiting for the DMA read (i.e., we're the send - # semaphore), so leave the DMA write for later. - continue - assert dma.state == DmaState.READ - - # Do the write. - assert dma.dst_sem is self - if self.interpret_params.detect_races: - inc_vector_clock(dma.clock, dma.virtual_device_id) - store(dma.dst_device_id, - dma.dst_memory_space, - dma.dst_buffer_id, - dma.dst_transforms, - dma.data, - clock=copy_vector_clock(dma.clock), - src_device_id=dma.id, - source_info=dma.source_info) - if self.interpret_params.detect_races: - inc_vector_clock(dma.clock, dma.virtual_device_id) - data_size = dma.data.itemsize * dma.data.size - dma.dst_sem.signal( - data_size, device_id=dma.dst_device_id, clock=dma.clock) - - dma.data = None - dma.state = DmaState.COMPLETED - - -class DmaState(enum.Enum): - STARTED = 0 - READ = 1 - COMPLETED = 2 - -@dataclasses.dataclass -class DMA: - id: int - - src_device_id: int - src_memory_space: int - src_buffer_id: int - src_transforms: tuple[Any, ...] - dst_device_id: int - dst_memory_space: int - dst_buffer_id: int - dst_transforms: tuple[Any, ...] - src_sem: Semaphore - dst_sem: Semaphore - - clock: VectorClock - - source_info: source_info_util.SourceInfo | None = None - - state: DmaState = DmaState.STARTED - data: np.ndarray | None = None - virtual_device_id: int | None = None - lock: threading.Lock = dataclasses.field(default_factory=threading.Lock) - - -@dataclasses.dataclass -class RaceDetectionState: - num_devices: int - - # (memory_space, buffer_id, device_id) -> [(device_id, VectorClock, range)] - reads: dict = dataclasses.field( - default_factory=lambda: collections.defaultdict(list)) - - # (memory_space, buffer_id, device_id) -> [(device_id, VectorClock, range)] - writes: dict = dataclasses.field( - default_factory=lambda: collections.defaultdict(list)) - - lock: threading.Lock = dataclasses.field(default_factory=threading.Lock) - - races_found: bool = False - -def _is_empty_slice(slice_or_idx: slice | int): - if isinstance(slice_or_idx, int) or (slice_or_idx == slice(None)): - return False - - # NOTE: All slices here will have known size. - start = int(slice_or_idx.start) if slice_or_idx.start is not None else 0 - stop = int(slice_or_idx.stop) - return (start < stop) - -def slices_overlap(slice_or_idx1: slice | int, slice_or_idx2: slice | int): - if isinstance(slice_or_idx1, int): - slice_or_idx1 = slice(slice_or_idx1, slice_or_idx1 + 1) - if isinstance(slice_or_idx2, int): - slice_or_idx2 = slice(slice_or_idx2, slice_or_idx2 + 1) - - if slice_or_idx1 == slice(None): - return _is_empty_slice(slice_or_idx2) - if slice_or_idx2 == slice(None): - return _is_empty_slice(slice_or_idx1) - - # TODO(jburnim): Handle non-zero steps. - assert (slice_or_idx1.step == 1) or (slice_or_idx1.step is None) - assert (slice_or_idx2.step == 1) or (slice_or_idx2.step is None) - - # NOTE: We are only comparing slices with known stops (and sizes). - # Do we need to handle zero-length slices? - return ((slice_or_idx1.start <= slice_or_idx2.start < slice_or_idx1.stop) - | (slice_or_idx2.start <= slice_or_idx1.start < slice_or_idx2.stop)) - -def ranges_overlap(range1: tuple[slice | int, ...], - range2: tuple[slice | int, ...]) -> bool: - return all(slices_overlap(r1, r2) for r1, r2 - in itertools.zip_longest(range1, range2, fillvalue=slice(None))) - -def check_read(device_id, clock, buffer_key, rnge, source_info=None): - if source_info is not None: - user_frame = source_info_util.summarize(source_info) - else: - user_frame = 'pallas_call' - - with races.lock: - writes = races.writes[buffer_key] - num_writes = len(writes) - races.reads[buffer_key].append((device_id, clock, rnge, user_frame)) - - for i in range(num_writes): - write_device_id, write_clock, write_range, write_frame = writes[i] - if ordered(write_clock, clock): - continue - if not ranges_overlap(rnge, write_range): - continue - # TODO(jburnim): When printing device IDs for reads/writes, distinguish - # between real device IDs vs. DMA IDs. - print('RACE DETECTED\n' - f' read of {buffer_key}[{rnge}] from {device_id}, {user_frame}\n' - f' write of {buffer_key}[{write_range}] from {write_device_id}, {write_frame}') - with races.lock: - races.races_found = True - return - -def check_write(device_id, clock, buffer_key, rnge, source_info=None): - if source_info is not None: - user_frame = source_info_util.summarize(source_info) - else: - user_frame = 'pallas_call' - - with races.lock: - writes = races.writes[buffer_key] - reads = races.reads[buffer_key] - num_writes = len(writes) - num_reads = len(reads) - races.writes[buffer_key].append((device_id, clock, rnge, user_frame)) - - # TODO(jburnim): For performance, we should also probably remove any - # conflicting reads and writes that happened-before the current write. - - for i in range(num_writes): - write_device_id, write_clock, write_range, write_frame = writes[i] - if ordered(write_clock, clock): - continue - if not ranges_overlap(rnge, write_range): - continue - # TODO(jburnim): When printing device IDs for reads/writes, distinguish - # between real device IDs vs. DMA IDs. - print('RACE DETECTED\n' - f' write of {buffer_key}[{rnge}] from {device_id}, {user_frame}\n' - f' write of {buffer_key}[{write_range}] from {write_device_id}, {write_frame}') - with races.lock: - races.races_found = True - break - - for i in range(num_reads): - read_device_id, read_clock, read_range, read_frame = reads[i] - if ordered(read_clock, clock): - continue - if not ranges_overlap(rnge, read_range): - continue - # TODO(jburnim): When printing device IDs for reads/writes, distinguish - # between real device IDs vs. DMA IDs. - print('RACE DETECTED\n' - f' write of {buffer_key}[{rnge}] from {device_id}, {user_frame}\n' - f' read of {buffer_key}[{read_range}] from {read_device_id}, {read_frame}') - with races.lock: - races.races_found = True - return - - -@dataclasses.dataclass -class SharedMemory: - interpret_params: TPUInterpretParams - num_devices: int - clocks: list[VectorClock] - barrier: threading.Barrier - - # (memory_space, buffer_id, device_id) -> NumPy array - # TODO(jburnim): Handle Megacore. - mem: dict[tuple[int, int, int], np.ndarray] = dataclasses.field( - default_factory=dict) - - # semaphore_id -> Semaphore - sem: dict[int, Semaphore] = dataclasses.field(default_factory=dict) - - # (semaphore_id, device_id) - # -> list of DMAs that will signal the semaphore on the given device - dmas_by_sem: dict[tuple[int, int], list[DMA]] = dataclasses.field( - default_factory=lambda: collections.defaultdict(list)) - - lock: threading.Lock = dataclasses.field(default_factory=threading.Lock) - - # device_id -> next buffer ID - next_buffer_id: dict[int, int] = dataclasses.field( - default_factory=lambda: collections.defaultdict(lambda: 100)) - # device_id -> next semaphore ID - next_semaphore_id: dict[int, int] = dataclasses.field( - default_factory=lambda: collections.defaultdict(lambda: 2000)) - - next_dma_id: int = 100 - - -# TODO(jburnim): Do we want to support multiple instances of SharedMemory? -# Maybe for running multiple distinct interpreted computations in parallel? -_shared_memory : SharedMemory | None = None -_shared_memory_init_lock = threading.Lock() -races : RaceDetectionState | None = None - -def _get_shared_memory() -> SharedMemory: - assert _shared_memory is not None - return _shared_memory - -def _clear_shared_memory(): - global _shared_memory - with _shared_memory_init_lock: - _shared_memory = None - -def _initialize_shared_memory(device_id, num_devices, *, interpret_params): - global _shared_memory - del device_id - num_devices = int(num_devices) - with _shared_memory_init_lock: - if _shared_memory is None: - _shared_memory = SharedMemory( - interpret_params=interpret_params, - num_devices=num_devices, - clocks=[make_vector_clock(num_devices) for _ in range(num_devices)], - barrier=threading.Barrier(num_devices)) - assert _shared_memory.num_devices == num_devices - - global races - races = RaceDetectionState(num_devices=num_devices) - -def _clean_up_shared_memory(device_id): - device_id = int(device_id) - shared_memory = _get_shared_memory() - shared_memory.barrier.wait() - if device_id == 0: - _clear_shared_memory() - -def _validate(device_id): - device_id = int(device_id) - - shared_memory = _get_shared_memory() - with shared_memory.lock: - for sem in shared_memory.sem.values(): - with sem.cv: - if sem.counts[device_id] != 0: - # TODO(jburnim): Make this raise an error, but in a way that doesn't - # cause other devices to hang later in `_clean_up_shared_memory`. - print( - f'Semaphore {sem.id} has non-zero count for {device_id} at ' - f'kernel exit: {sem.counts[device_id]}') - -def _allocate_buffer(device_id, memory_space, val): - device_id = int(device_id) - memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)] - val = np.array(val) - - shared_memory = _get_shared_memory() - with shared_memory.lock: - buffer_id = shared_memory.next_buffer_id[device_id] - shared_memory.next_buffer_id[device_id] = buffer_id + 1 - # TODO(jburnim): Add options for initializing memory (e.g., with NaNs, - # with zeros, or with the buffer ID). - shared_memory.mem[(memory_space, buffer_id, device_id)] = val - - # TODO(jburnim): Raise an error if buffer_id is too big for int16. - return np.int16(buffer_id) - -def _deallocate_buffer(device_id, memory_space, buffer_id): - device_id = int(device_id) - memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)] - buffer_id = int(buffer_id) - - shared_memory = _get_shared_memory() - with shared_memory.lock: - # TODO(jburnim): Error if buffer doesn't exist? - shared_memory.mem.pop((memory_space, buffer_id, device_id), None) - -def _allocate_semaphores(device_id, shape): - device_id = int(device_id) - shape = tuple(map(int, shape)) - num_semaphores = math.prod(shape) - - shared_memory = _get_shared_memory() - with shared_memory.lock: - semaphore_id = shared_memory.next_semaphore_id[device_id] - shared_memory.next_semaphore_id[device_id] = semaphore_id + num_semaphores - for i in range(semaphore_id, semaphore_id + num_semaphores): - if i not in shared_memory.sem: - shared_memory.sem[i] = Semaphore(i) - - # NOTE: For now, we use a relatively uncommon datatype (int16) for - # semaphore (and buffer) IDs, so these values are more easily identifiable - # in kernels. - # - # TODO(jburnim): Raise an error if any IDs are too big for int16. - return np.int16( - range(semaphore_id, semaphore_id + num_semaphores) - ).reshape(shape) - - -TPU_MEMORY_SPACE_IDXS : dict[mosaic_core.TPUMemorySpace | None, int] = { - v: i for i, v in enumerate(mosaic_core.TPUMemorySpace)} -TPU_MEMORY_SPACE_NAMES = { - i: v.value for i, v in enumerate(mosaic_core.TPUMemorySpace)} - -# Default to VMEM when no memory space is specified. -TPU_MEMORY_SPACE_IDXS[None] = ( - TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.VMEM]) - -def get_barrier_semaphore(device_id, collective_id): - del device_id - collective_id = int(collective_id) - - # TODO(jburnim): Check/fix so that IDs for barrier semaphores do not conflict - # with IDs for regular or DMA semaphores. (For example, store them in a - # different table.) - shared_memory = _get_shared_memory() - with shared_memory.lock: - semaphore_id = collective_id - if semaphore_id not in shared_memory.sem: - shared_memory.sem[semaphore_id] = Semaphore() - - return np.int16(semaphore_id) - -def _transform_slice_or_index(slice_or_idx): - if isinstance(slice_or_idx, int): - return slice_or_idx - else: - start = int(slice_or_idx.start) - size = int(slice_or_idx.size) - stride = int(slice_or_idx.stride) - return slice(start, start + size * stride, stride) - -def _compose_slice_or_index(slice_or_idx1, slice_or_idx2): - ret = [] - i = 0 - j = 0 - while True: - if i == len(slice_or_idx1): - ret.extend(slice_or_idx2[j:]) - return tuple(ret) - elif j == len(slice_or_idx2): - ret.extend(slice_or_idx1[i:]) - return tuple(ret) - elif isinstance(slice_or_idx1[i], int): - ret.append(slice_or_idx1[i]) - i += 1 - elif isinstance(slice_or_idx2[j], int): - ret.append(slice_or_idx1[i].start + slice_or_idx2[j] * slice_or_idx1[i].step) - i += 1 - j += 1 - else: - ret.append(slice( - slice_or_idx1[i].start + slice_or_idx2[j].start * slice_or_idx1[i].step, - slice_or_idx1[i].start + slice_or_idx2[j].stop * slice_or_idx1[i].step, - slice_or_idx1[i].step * slice_or_idx2[j].step - )) - i += 1 - j += 1 - -def _to_range(transforms) -> tuple[slice | int, ...]: - ret = () - for transform in transforms: - # For now, assume only NDIndexer transforms. - ret = _compose_slice_or_index( - ret, tuple(_transform_slice_or_index(i) for i in transform.indices)) - return ret - -def get(device_id, memory_space, buffer_id, transforms, *, - src_device_id=None, clock=None, source_info=None): - device_id = int(device_id) - memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)] - buffer_id = int(buffer_id) - try: - transforms = jax.tree.map(int, transforms) - except: - raise ValueError('Advanced indexers are not supported on TPU') - - shared_memory = _get_shared_memory() - with shared_memory.lock: - read_range = _to_range(transforms) - if shared_memory.interpret_params.detect_races: - inc_vector_clock(shared_memory.clocks[device_id], device_id) - if clock is None: - clock = copy_vector_clock(shared_memory.clocks[device_id]) - buffer = shared_memory.mem[(memory_space, buffer_id, device_id)] - ret = buffer[read_range].copy() - if transforms: - # TODO(jburnim): Instead of using NDIndexer, do the computation ourselves - # with buffer.shape and read_range? - expected_shape = transforms[-1].get_indexer_shape() - if expected_shape != ret.shape[:len(expected_shape)]: - raise ValueError( - f'Out-of-bounds read of ({device_id} {memory_space} {buffer_id}): ' - f'reading [{read_range}] but bufer has shape {buffer.shape} .') - - if shared_memory.interpret_params.detect_races: - if src_device_id is None: - src_device_id = device_id - check_read(src_device_id, clock, (memory_space, buffer_id, device_id), - read_range, source_info=source_info) - - return ret - -def store(device_id, memory_space, buffer_id, transforms, val, *, - src_device_id=None, clock=None, source_info=None): - device_id = int(device_id) - memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)] - buffer_id = int(buffer_id) - try: - transforms = jax.tree.map(int, transforms) - except: - raise ValueError('Advanced indexers are not supported on TPU') - val = np.array(val) - - shared_memory = _get_shared_memory() - with shared_memory.lock: - if shared_memory.interpret_params.detect_races: - inc_vector_clock(shared_memory.clocks[device_id], device_id) - if clock is None: - clock = copy_vector_clock(shared_memory.clocks[device_id]) - - buff = shared_memory.mem[(memory_space, buffer_id, device_id)] - assert buff.dtype == val.dtype # TODO(jburnim): Catch this statically. - write_range = _to_range(transforms) - # TODO(jburnim): Better error message if this raises? - in_bounds_shape = buff[write_range].shape - if in_bounds_shape != val.shape: - raise ValueError( - f'Out-of-bounds write of ({device_id} {memory_space} {buffer_id}): ' - f'writing [{write_range}] but buffer has shape {buff.shape} .') - buff[write_range] = val - - if shared_memory.interpret_params.detect_races: - if src_device_id is None: - src_device_id = device_id - check_write(src_device_id, clock, (memory_space, buffer_id, device_id), - write_range, source_info=source_info) - -def swap(device_id, memory_space, buffer_id, transforms, val, mask, *, - source_info=None): - device_id = int(device_id) - memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)] - buffer_id = int(buffer_id) - try: - transforms = jax.tree.map(int, transforms) - except: - raise ValueError('Advanced indexers are not supported on TPU') - val = np.array(val) - mask = np.array(mask) if mask is not None else None - if mask is not None: - assert mask.shape == val.shape - - shared_memory = _get_shared_memory() - with shared_memory.lock: - if shared_memory.interpret_params.detect_races: - inc_vector_clock(shared_memory.clocks[device_id], device_id) - clock = copy_vector_clock(shared_memory.clocks[device_id]) - buff = shared_memory.mem[(memory_space, buffer_id, device_id)] - assert buff.dtype == val.dtype # TODO(jburnim): Catch this statically. - read_write_range = _to_range(transforms) - # TODO(jburnim): Better error message if this raises? - raw_result = buff[read_write_range] - in_bounds_shape = raw_result.shape - if mask is None: - if in_bounds_shape != val.shape: - raise ValueError( - f'Out-of-bounds swap of ({device_id} {memory_space} {buffer_id}): ' - f'swapping [{read_write_range}] but buffer has shape {buff.shape} .') - buff[read_write_range] = val - return raw_result.copy() - - in_bounds_mask = np.full(mask.shape, True) - for i in range(len(in_bounds_shape)): - in_bounds_mask[in_bounds_shape[i]:] = False - if (~in_bounds_mask & mask).any(): - # TODO(jburnim): Include indices of out-of-bounds locations where mask - # is True. - raise ValueError( - f'Out-of-bounds masked swap of ({device_id} {memory_space} {buffer_id}): ' - f'swapping [{read_write_range}] but buffer has shape {buff.shape} . ') - - in_bounds_idx = tuple(slice(i) for i in in_bounds_shape) - result = val.copy() - result[in_bounds_idx] = np.where( - mask[in_bounds_idx], raw_result, val[in_bounds_idx]) - buff[read_write_range] = np.where( - mask[in_bounds_idx], val[in_bounds_idx], raw_result) - - if shared_memory.interpret_params.detect_races: - check_write(device_id, clock, (memory_space, buffer_id, device_id), - read_write_range, source_info=source_info) - return result - -def execute_dma(dma): - # TODO(jburnim) Eliminate duplicate code here and in Semaphore.wait. - shared_memory = _get_shared_memory() - with dma.lock: - assert dma.state == DmaState.STARTED - - if dma.virtual_device_id is None: - # See comment in Semaphore.wait . - dma.virtual_device_id = np.random.randint( - shared_memory.num_devices, NUM_VIRTUAL_DEVICES) - - # Do the read. - if shared_memory.interpret_params.detect_races: - inc_vector_clock(dma.clock, dma.virtual_device_id) - dma.data = get(dma.src_device_id, - dma.src_memory_space, - dma.src_buffer_id, - dma.src_transforms, - clock=copy_vector_clock(dma.clock), - src_device_id=dma.id, - source_info=dma.source_info) - data_size = dma.data.itemsize * dma.data.size - - # Signal the send semaphore. - if shared_memory.interpret_params.detect_races: - inc_vector_clock(dma.clock, dma.virtual_device_id) - if dma.src_sem is not None: - dma.src_sem.signal( - data_size, device_id=dma.src_device_id, clock=dma.clock) - dma.state = DmaState.READ - - # Do the write. - if shared_memory.interpret_params.detect_races: - inc_vector_clock(dma.clock, dma.virtual_device_id) - store(dma.dst_device_id, - dma.dst_memory_space, - dma.dst_buffer_id, - dma.dst_transforms, - dma.data, - clock=copy_vector_clock(dma.clock), - src_device_id=dma.id, - source_info=dma.source_info) - - # Signal the receive semaphore. - if shared_memory.interpret_params.detect_races: - inc_vector_clock(dma.clock, dma.virtual_device_id) - if dma.dst_sem is not None: - dma.dst_sem.signal( - data_size, device_id=dma.dst_device_id, clock=dma.clock) - - dma.data = None - dma.state = DmaState.COMPLETED - -def print_memory(device_id): - device_id = int(device_id) - if all(d == 0 for d in device_id): - shared_memory = _get_shared_memory() - with shared_memory.lock: - print(shared_memory.mem) - -def dma_start(device_id, src_memory_space, src_id, src_transforms, - dst_memory_space, dst_id, dst_transforms, - dst_sem_id, src_sem_id, dst_device_id, - source_info=None): - device_id = int(device_id) - src_memory_space, src_id = int(src_memory_space), int(src_id) - src_transforms = jax.tree.map(int, src_transforms) - dst_memory_space, dst_id = int(dst_memory_space), int(dst_id) - dst_transforms = jax.tree.map(int, dst_transforms) - dst_sem_id = int(dst_sem_id) - src_sem_id = int(src_sem_id) if src_sem_id is not None else None - if dst_device_id is not None: - dst_device_id = int(dst_device_id) - else: - dst_device_id = device_id - - shared_memory = _get_shared_memory() - with shared_memory.lock: - dst_sem = shared_memory.sem[dst_sem_id] - src_sem = shared_memory.sem[src_sem_id] if src_sem_id is not None else None - - clock = None - if shared_memory.interpret_params.detect_races: - inc_vector_clock(shared_memory.clocks[device_id], device_id) - clock = copy_vector_clock(shared_memory.clocks[device_id]) - dma_id = shared_memory.next_dma_id - shared_memory.next_dma_id += 1 - - dma = DMA( - dma_id, - device_id, src_memory_space, src_id, src_transforms, - dst_device_id, dst_memory_space, dst_id, dst_transforms, - src_sem, - dst_sem, - clock=clock, - source_info=source_info, - ) - - if shared_memory.interpret_params.dma_execution_mode == 'on_wait': - shared_memory.dmas_by_sem[dst_sem_id].append(dma) - if src_sem_id is not None: - shared_memory.dmas_by_sem[src_sem_id].append(dma) - return - - assert shared_memory.interpret_params.dma_execution_mode == 'eager' - execute_dma(dma) - -def dma_wait(device_id, sem_id, size): - device_id = int(device_id) - sem_id = int(sem_id) - size = int(size) - - shared_memory = _get_shared_memory() - with shared_memory.lock: - if shared_memory.interpret_params.detect_races: - inc_vector_clock(shared_memory.clocks[device_id], device_id) - sem = shared_memory.sem[sem_id] - sem.wait(size, device_id, is_dma=True) - -def semaphore_signal(device_id, sem_id, inc, target_device_id, - target_core_index): - device_id = int(device_id) - sem_id = int(sem_id) - inc = int(inc) - if target_device_id is None: - target_device_id = device_id - else: - target_device_id = int(target_device_id) - - if target_core_index is not None: - if int(target_core_index) != 0: - raise NotImplementedError('semaphore_signal with target_core_index != 0') - - shared_memory = _get_shared_memory() - with shared_memory.lock: - clock = None - if shared_memory.interpret_params.detect_races: - inc_vector_clock(shared_memory.clocks[device_id], device_id) - clock = copy_vector_clock(shared_memory.clocks[device_id]) - sem = shared_memory.sem[sem_id] - sem.signal(inc, target_device_id, clock) - -def semaphore_wait(device_id, sem_id, value): - device_id = int(device_id) - sem_id = int(sem_id) - value = int(value) - - shared_memory = _get_shared_memory() - with shared_memory.lock: - if shared_memory.interpret_params.detect_races: - inc_vector_clock(shared_memory.clocks[device_id], device_id) - sem = shared_memory.sem[sem_id] - sem.wait(value, device_id) - -def _compute_transformed_shape_and_dtype(shape, dtype, transforms): - for transform in transforms: - if transform is None: - continue - shape = transform.transform_shape(shape) - dtype = transform.transform_dtype(dtype) - return shape, dtype - -def _device_coords_to_logical_id(device_coords, axis_sizes): - if not isinstance(device_coords, tuple): - device_coords = (device_coords,) - assert len(device_coords) == len(axis_sizes) - sizes = list(axis_sizes.values()) - ret = 0 - for i in range(len(device_coords)): - ret += device_coords[i] * math.prod(sizes[i+1:]) - return ret - -def _device_id_to_logical(device_id, device_id_type, axis_sizes): - if device_id is None: - return None - if device_id_type == mosaic_primitives.DeviceIdType.MESH: - return _device_coords_to_logical_id(device_id, axis_sizes) - elif device_id_type == mosaic_primitives.DeviceIdType.LOGICAL: - return device_id - else: - raise ValueError(f'Unsupported device ID type: {device_id_type}') - -@lu.cache -def _to_jaxpr(flat_fun, in_avals): - new_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals) - new_jaxpr = jax_core.ClosedJaxpr(new_jaxpr, consts) - return new_jaxpr - -def _is_any(memory_space): - return ((memory_space == mosaic_core.TPUMemorySpace.ANY) or - (memory_space == pallas_core.MemorySpace.ANY)) - -def _is_float(dtype): - return jnp.issubdtype(dtype, jnp.floating) - -_SENTINEL = jnp.inf - -@dataclasses.dataclass(frozen=True) -class Placeholder: - """Placeholder for use in `_interpret_jaxpr` below instead of putting a concrete value into `env`.""" - shape: tuple[int, ...] - dtype: jnp.dtype - -def _interpret_jaxpr(jaxpr, *args, compiler_params, interpret_params): - env = {} - - def read(var): - if isinstance(var, jax_core.Literal): - result = var.val - else: - result = env[var] - if isinstance(result, Placeholder): - result = jax.lax.full(result.shape, _SENTINEL, result.dtype) - return result - - def write(var, value): - if interpret_params.skip_floating_point_ops and _is_float(value.dtype): - value = Placeholder(value.shape, value.dtype) - env[var] = value - - jax.util.safe_map(write, jaxpr.constvars + jaxpr.invars, args) - - # Get the device ID. - axis_sizes = jax_core.get_axis_env().axis_sizes - device_id = _device_coords_to_logical_id( - tuple(lax.axis_index(s) for s in axis_sizes.keys()), - axis_sizes) - # TODO(jburnim): Pass the device ID around, instead of re-fetching/computing - # it for each sub-jaxpr. - - # TODO(jburnim): Clean up and finish this evaluation loop. For example: - # - Replace the big if-statement with a dictionary of rules. - # - Handle other higher-order primitives? - # - Megacore. - _interpret = functools.partial( - _interpret_jaxpr, compiler_params=compiler_params, - interpret_params=interpret_params) - for eqn in jaxpr.eqns: - with source_info_util.user_context( - eqn.source_info.traceback, name_stack=eqn.source_info.name_stack): - prim = eqn.primitive - # We defer reading the values for `eqn.invars` into each of the branches - # of the if-elif-else statement below. This is because the else branch may - # not need to do any reads if `interpret_params.skip_floating_point_ops` - # is True. If this is the case, we want to avoid materializing the read - # array into the jaxpr when this function is traced. - deferred_invals = functools.partial(jax.util.safe_map, read, eqn.invars) - - if prim is primitives.load_p: - (ref, transforms, mask, _) = jax.tree.unflatten( - eqn.params['args_tree'], deferred_invals()) - if mask is not None: - raise NotImplementedError('masked load_p') - out = callback.io_callback( - functools.partial(get, source_info=eqn.source_info), - eqn.outvars[0].aval, - device_id, - TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space], - ref, - transforms, - ordered=True) - - elif prim is primitives.swap_p: - (ref, transforms, val, mask) = jax.tree.unflatten( - eqn.params['args_tree'], deferred_invals()) - out = callback.io_callback( - functools.partial(swap, source_info=eqn.source_info), - eqn.outvars[0].aval, - device_id, - TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space], - ref, - transforms, - val, - mask, - ordered=True) - - elif prim is mosaic_primitives.delay_p: - out = [] - - elif prim is lax.cond_p: - def _make_branch(jaxpr): - return lambda *args: _interpret(jaxpr, *args) - invals = deferred_invals() - out = lax.switch( - invals[0], - [_make_branch(branch_jaxpr.jaxpr) - for branch_jaxpr in eqn.params['branches']], - *invals[1:]) - - elif prim is lax.scan_p: - consts, init_carry, xs = split_list( - deferred_invals(), - [eqn.params['num_consts'], eqn.params['num_carry']], - ) - def _scan_body(c, a): - return split_list( - _interpret(eqn.params['jaxpr'].jaxpr, *consts, *c, *a), - [eqn.params['num_carry']]) - carry, out = lax.scan(_scan_body, init_carry, xs=xs, - length=eqn.params.get('length', None)) - out = carry + out - - elif prim is lax.while_p: - cond_consts, body_consts, init_vals = split_list( - deferred_invals(), - [eqn.params['cond_nconsts'], eqn.params['body_nconsts']], - ) - out = lax.while_loop( - lambda args: _interpret( - eqn.params['cond_jaxpr'].jaxpr, *cond_consts, *args)[0], - lambda args: _interpret( - eqn.params['body_jaxpr'].jaxpr, *body_consts, *args), - init_vals) - - elif prim is for_loop.for_p: - raise NotImplementedError('for_p') - - elif prim is pjit.pjit_p: - def f(*args, jaxpr): - return _interpret(jaxpr.jaxpr, *jaxpr.consts, *args) - invals = deferred_invals() - in_avals = tuple(jax_core.shaped_abstractify(i) for i in invals) - new_jaxpr = _to_jaxpr( - lu.wrap_init(functools.partial(f, jaxpr=eqn.params['jaxpr']), - debug_info=eqn.params['jaxpr'].jaxpr.debug_info), - in_avals) - out = pjit.pjit_p.bind(*invals, **(eqn.params | {'jaxpr': new_jaxpr})) - - elif prim is primitives.run_scoped_p: - # Allocate a buffer or semaphore for each element of - # eqn.params['jaxpr'].invars . - allocs = [] - for v in eqn.params['jaxpr'].invars: - if v.aval.memory_space == mosaic_core.TPUMemorySpace.SEMAPHORE: - allocs.append(callback.io_callback( - _allocate_semaphores, - jax.ShapeDtypeStruct(v.aval.shape, jnp.int16), - device_id, - v.aval.shape, - ordered=True)) - else: - allocs.append(callback.io_callback( - _allocate_buffer, - jax.ShapeDtypeStruct((), jnp.int16), - device_id, - TPU_MEMORY_SPACE_IDXS[v.aval.memory_space], - _uninitialized_value( - v.aval.shape, v.aval.dtype, interpret_params), - ordered=True)) - - out = _interpret(eqn.params['jaxpr'], *deferred_invals(), *allocs) - - for a in allocs: - if isinstance(a, tuple): - callback.io_callback( - _deallocate_buffer, - None, - device_id, - TPU_MEMORY_SPACE_IDXS[v.aval.memory_space], - a, - ordered=True) - else: - # TODO(jburnim): De-allocate semaphores. - # callback.io_callback( - # _deallocate_semaphores, - # None, - # device_id, - # a, - # ordered=True) - pass - - elif prim is state_primitives.get_p: - invals = deferred_invals() - out = callback.io_callback( - functools.partial(get, source_info=eqn.source_info), - eqn.outvars[0].aval, - device_id, - TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space], - invals[0], - jax.tree.unflatten(eqn.params['tree'], invals[1:]), - ordered=True) - - elif prim is state_primitives.swap_p: - invals = deferred_invals() - out = callback.io_callback( - functools.partial(swap, source_info=eqn.source_info), - eqn.outvars[0].aval, - device_id, - TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space], - invals[0], - jax.tree.unflatten(eqn.params['tree'], invals[2:]), - invals[1], - None, - ordered=True) - - elif prim is mosaic_primitives.dma_start_p: - ( - src, - src_transforms, - dst, - dst_transforms, - dst_sem, - dst_sem_transforms, - src_sem, - src_sem_transforms, - target_device_id, - ) = jax.tree.unflatten(eqn.params['tree'], deferred_invals()) - target_device_id = _device_id_to_logical( - target_device_id, eqn.params['device_id_type'], axis_sizes) - (orig_src_ref, _, orig_dst_ref, *_ - ) = jax.tree.unflatten(eqn.params['tree'], eqn.invars) - callback.io_callback( - functools.partial(dma_start, source_info=eqn.source_info), - (), - device_id, - TPU_MEMORY_SPACE_IDXS[getattr(orig_src_ref.aval, 'memory_space', mosaic_core.TPUMemorySpace.ANY)], - src, src_transforms, - TPU_MEMORY_SPACE_IDXS[getattr(orig_dst_ref.aval, 'memory_space', mosaic_core.TPUMemorySpace.ANY)], - dst, dst_transforms, - state_discharge.transform_array(dst_sem, dst_sem_transforms), - state_discharge.transform_array(src_sem, src_sem_transforms), - target_device_id, - ordered=True) - out = [] - - elif prim is mosaic_primitives.dma_wait_p: - ( - src, - src_transforms, - dst, - dst_transforms, - dst_sem, - dst_sem_transforms, - src_sem, - src_sem_transforms, - target_device_id, - ) = jax.tree.unflatten(eqn.params['tree'], deferred_invals()) - read_shape, read_dtype = _compute_transformed_shape_and_dtype( - eqn.invars[0].aval.shape, eqn.invars[0].aval.dtype, src_transforms) - callback.io_callback( - dma_wait, - (), - device_id, - state_discharge.transform_array(dst_sem, dst_sem_transforms), - math.prod(read_shape) * read_dtype.itemsize, - ordered=True) - out = [] - - elif prim is mosaic_primitives.get_barrier_semaphore_p: - out = callback.io_callback( - get_barrier_semaphore, - jax.ShapeDtypeStruct((), jnp.int16), - device_id, - compiler_params['mosaic']['collective_id'], - ordered=True) - - elif prim is mosaic_primitives.semaphore_signal_p: - sem, sem_transforms, inc, target_device_id, core_index = ( - jax.tree.unflatten(eqn.params['args_tree'], deferred_invals())) - target_device_id = _device_id_to_logical( - target_device_id, eqn.params['device_id_type'], axis_sizes) - callback.io_callback( - semaphore_signal, - (), - device_id, - state_discharge.transform_array(sem, sem_transforms), - inc, - target_device_id, - core_index, - ordered=True) - out = [] - - elif prim is mosaic_primitives.semaphore_wait_p: - sem, sem_transforms, value = ( - jax.tree.unflatten(eqn.params['args_tree'], deferred_invals())) - callback.io_callback( - semaphore_wait, - (), - device_id, - state_discharge.transform_array(sem, sem_transforms), - value, - ordered=True) - out = [] - - elif prim is primitives.atomic_rmw_p: - raise NotImplementedError('atomic_rmw_p') - - elif prim is primitives.atomic_cas_p: - raise NotImplementedError('atomic_cas_p') - - else: - if interpret_params.skip_floating_point_ops and all( - _is_float(ovar.aval.dtype) for ovar in eqn.outvars - ): - # Skip `prim.bind` since `prim` only produces floating-point values. - # It is safe to populate `out` with avals since mapping `write` over - # `out` below only relies on the shape and dtype (for writing - # `Placeholder`s). - out = [ovar.aval for ovar in eqn.outvars] - if not prim.multiple_results: - out = out[0] - else: - subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params) - out = prim.bind(*subfuns, *deferred_invals(), **bind_params) - - out = out if prim.multiple_results else [out] - jax.util.safe_map(write, eqn.outvars, out) - - return jax.util.safe_map(read, jaxpr.outvars) - -def _initialize_output_vals( - block_mappings_output: Iterable[BlockMapping], - input_args, input_output_aliases, - interpret_params: TPUInterpretParams, -) -> Sequence[jax.Array]: - oi_map = {v: k for k, v in input_output_aliases} - output_vals = [] - for i, bm in enumerate(block_mappings_output): - if i in oi_map: - output_vals.append(input_args[oi_map[i]]) - else: - output_vals.append(_uninitialized_value( - bm.array_shape_dtype.shape, - bm.array_shape_dtype.dtype, - interpret_params)) - return output_vals - -def _compute_start_indices(block_mapping, loop_idx, *args): - block_indices = ( - jax_core.jaxpr_as_fun(block_mapping.index_map_jaxpr)(*loop_idx, *args)) - if isinstance(block_mapping.indexing_mode, pallas_core.Blocked): - ret = tuple(i if b is pallas_core.mapped else b * i - for b, i in zip(block_mapping.block_shape, block_indices)) - elif isinstance(block_mapping.indexing_mode, pallas_core.Unblocked): - ret = block_indices - else: - raise RuntimeError(f"Unknown indexing mode: {block_mapping.indexing_mode}") - return ret - -def _get_next_indices(grid, indices): - next_indices = [] - carry = True - for dim_size, index in reversed(list(zip(grid, indices))): - i = jnp.where(carry, index + 1, index) - carry = dim_size == i - next_indices.append(jnp.where(carry, 0, i)) - return tuple(reversed(next_indices)) - -def _maybe_dynamic_slice(start_idx, block_shape, value, is_indexing): - start_idx = tuple(jnp.array(s, dtype=jnp.int32) for s in start_idx) - output = lax.dynamic_slice(value, start_idx, slice_sizes=block_shape) - squeeze_dims = tuple(np.arange(len(is_indexing))[np.array(is_indexing, - dtype=np.bool_)]) - return lax.squeeze(output, squeeze_dims) - -def _uninitialized_value(shape, dtype, interpret_params): - if interpret_params.uninitialized_memory == 'nan': - if jnp.issubdtype(dtype, jnp.floating): - return jnp.full(shape, jnp.nan, dtype) - elif jnp.issubdtype(dtype, jnp.integer): - return jnp.full(shape, jnp.iinfo(dtype).max, dtype) - elif jnp.issubdtype(dtype, jnp.bool): - return jnp.full(shape, False, dtype) - if interpret_params.uninitialized_memory == 'zero': - return jnp.full(shape, 0, dtype) - raise NotImplementedError( - interpret_params.uninitialized_memory + ' + ' + str(dtype)) - -def _pad_to_block_dimension(value, block_shape, interpret_params): - """Pads values so the shape evenly divides into block dimensions. - - For example, if values has a shape of (33, 2, 5) with a block_shape of - (32, 2, 4), this function will pad the value of shape to (64, 2, 8). - - Args: - value: Array to be padded. - block_shape: Block shapes to use for padding. If None, no padding will - be performed. - - Returns: - A padded array. - """ - padded_shape = tuple( - ((v - 1) // b + 1) * b for v, b in zip(value.shape, block_shape) - ) - if padded_shape != value.shape: - pad_width = tuple((0, a-b) for a, b in zip(padded_shape, value.shape)) - pad_value = _uninitialized_value((), value.dtype, interpret_params) - value = jnp.pad(value, pad_width, constant_values=pad_value) - return value - -def get_interpret_effects(): - return {callback._OrderedIOEffect} - -def interpret_pallas_call( - *args, - jaxpr: jax_core.Jaxpr, - debug: bool, - input_output_aliases: tuple[tuple[int, int], ...], - grid_mapping: GridMapping, - mesh: pallas_core.Mesh | None, - compiler_params: Any, - cost_estimate: CostEstimate, - out_avals: tuple[jax_core.AbstractValue, ...], - interpret_params: TPUInterpretParams, -): - del debug, mesh, cost_estimate, out_avals - - # args contains: *dynamic_grid_sizes, *index, *inputs. (No consts?) - dynamic_grid_args, scalars, input_args = split_list( - args, - [grid_mapping.num_dynamic_grid_bounds, grid_mapping.num_index_operands], - ) - dynamic_grid_args_iter = iter(dynamic_grid_args) - grid = tuple( - a if a is not pallas_core.dynamic_grid_dim - else next(dynamic_grid_args_iter) - for a in grid_mapping.grid - ) - assert next(dynamic_grid_args_iter, None) is None - - axis_sizes = jax_core.get_axis_env().axis_sizes - num_devices = functools.reduce( - jnp.multiply, axis_sizes.values(), jnp.int32(1)) - device_id = _device_coords_to_logical_id( - tuple(lax.axis_index(s) for s in axis_sizes.keys()), - axis_sizes) - callback.io_callback( - functools.partial( - _initialize_shared_memory, interpret_params=interpret_params), - (), - device_id, - num_devices, - ordered=True) - - # Pad input arguments. - is_indexing_dim = [ - tuple(b is pallas_core.mapped for b in bm.block_shape) - for bm in grid_mapping.block_mappings - ] - block_shapes = [ - tuple(1 if i else b for i, b in zip(iid, bm.block_shape)) - for iid, bm in zip(is_indexing_dim, grid_mapping.block_mappings) - ] - num_inputs = grid_mapping.num_inputs - input_args = [ - _pad_to_block_dimension(a, bs, interpret_params) - for a, bs in zip(input_args, block_shapes[:num_inputs]) - ] - - # Allocate buffers in HBM for outputs. - output_buffer_ids = [] - output_buffer_shapes = [] - output_vals = _initialize_output_vals( - grid_mapping.block_mappings_output, - scalars + input_args, - input_output_aliases, - interpret_params) - num_outputs = grid_mapping.num_outputs - output_block_shapes = block_shapes[num_inputs : num_inputs + num_outputs] - for out_val, bs in zip(output_vals, output_block_shapes): - padded_val = _pad_to_block_dimension(out_val, bs, interpret_params) - output_buffer_shapes.append(padded_val.shape) - output_buffer_ids.append(callback.io_callback( - _allocate_buffer, - jax.ShapeDtypeStruct((), jnp.int16), - device_id, - TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY], - padded_val, - ordered=True)) - # Allocate buffers for all kernel arguments (e.g., scalars, inputs, - # outputs, scratch). - io_alias_map = dict(input_output_aliases) - oi_alias_map = {v: k for k, v in input_output_aliases} - kernel_buffer_ids = [] - for _, val in zip(jaxpr.invars[grid_mapping.slice_index_ops], scalars): - kernel_buffer_ids.append(callback.io_callback( - _allocate_buffer, - jax.ShapeDtypeStruct((), jnp.int16), - device_id, - TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.SMEM], - val, - ordered=True)) - for i, var in enumerate(jaxpr.invars[grid_mapping.num_index_operands:]): - output_idx = i - grid_mapping.num_inputs - is_input = i < grid_mapping.num_inputs - is_output = (output_idx >= 0) and (output_idx < grid_mapping.num_outputs) - if var.aval.memory_space == mosaic_core.TPUMemorySpace.SEMAPHORE: - kernel_buffer_ids.append(callback.io_callback( - _allocate_semaphores, - jax.ShapeDtypeStruct(var.aval.shape, jnp.int16), - device_id, - var.aval.shape, - ordered=True)) - elif is_output and _is_any(var.aval.memory_space): - # Use the already-allocated HBM output buffer. - # - # TODO(jburnim): For kernel args in HBM, check that block shape is the - # same as for the corresponding pallas_call input, and that the index_map - # is trivial. - kernel_buffer_ids.append(output_buffer_ids[output_idx]) - elif is_output and (output_idx in oi_alias_map): - # Use the already-allocated (non-HBM) input buffer. - kernel_buffer_ids.append(kernel_buffer_ids[oi_alias_map[output_idx]]) - elif is_input and (i in io_alias_map) and _is_any(var.aval.memory_space): - # Use the already-allocated HBM output buffer. - kernel_buffer_ids.append(output_buffer_ids[io_alias_map[i]]) - else: - # TODO(jburnim): For kernel args in HBM, check that block shape is the - # same as for the corresponding pallas_call input, and that the index_map - # is trivial. - kernel_buffer_ids.append(callback.io_callback( - _allocate_buffer, - jax.ShapeDtypeStruct((), jnp.int16), - device_id, - TPU_MEMORY_SPACE_IDXS[var.aval.memory_space], - _uninitialized_value( - var.aval.shape, var.aval.dtype, interpret_params), - ordered=True)) - - _, input_ids, kernel_output_ids, _ = split_list( - kernel_buffer_ids, - [grid_mapping.num_index_operands, num_inputs, grid_mapping.num_outputs]) - input_vars, output_vars = split_list( - jaxpr.invars[grid_mapping.slice_block_ops], [num_inputs]) - - # For kernel inputs that are in HBM, we populate the buffer once before - # any kernel invocations. - for buffer_id, var, val in zip(input_ids, input_vars, input_args): - if not _is_any(var.aval.memory_space): - continue - if (val.shape != var.aval.shape) or (val.dtype != var.aval.dtype): - # TODO(jburnim): Also check that the index_map is trivial. - raise ValueError() - callback.io_callback( - store, - (), - device_id, - TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY], - buffer_id, - (), - val, - ordered=True) - - if grid: - num_iterations = functools.reduce(jnp.multiply, grid) # type: ignore[arg-type] - else: - # Base case is always one iteration when grid is () - num_iterations = 1 - - def body(carry): - # The loop carry: (i, loop_idx) -- - # - i:int32 is the interation index - # - loop_idx: tuple[int32] are the program ids for each grid axis - i, loop_idx = carry - - if grid_mapping.local_grid_env is not None: - local_grid_env = grid_mapping.local_grid_env(loop_idx, grid) - else: - local_grid_env = tuple( - pallas_core.GridAxis(idx, b) - for dim, (idx, b) in enumerate(zip(loop_idx, grid)) - if dim not in grid_mapping.vmapped_dims - ) - - with pallas_core.grid_env(local_grid_env): - # Copy slices of the input to the kernel buffers. - # - # TODO(jburnim): Only copy slices when the index mapping has changed? - start_indices = [_compute_start_indices(bm, loop_idx, *scalars) - for bm in grid_mapping.block_mappings] - for j, var in enumerate(input_vars): - if _is_any(var.aval.memory_space): - continue - sliced_val = _maybe_dynamic_slice(start_indices[j], block_shapes[j], - input_args[j], is_indexing_dim[j]) - assert(sliced_val.shape == var.aval.shape) - callback.io_callback( - # TODO(jburnim): Pass source_info from the pallas_call, in case this - # store is involved in a data race. - store, - (), - device_id, - TPU_MEMORY_SPACE_IDXS[var.aval.memory_space], - input_ids[j], - (), - sliced_val, - ordered=True) - - # Invoke the kernel. - _interpret_jaxpr(jaxpr, *kernel_buffer_ids, - compiler_params=compiler_params, - interpret_params=interpret_params) - - # Copy from the kernel buffers to slices of the output in HBM. - # - # TODO(jburnim): Only copy if the index mapping will change in the - # next iteration (or if this is the last iteration)? - for j, var in enumerate(output_vars): - if _is_any(var.aval.memory_space): - continue - kernel_output_val = callback.io_callback( - # TODO(jburnim): Pass source_info from the pallas_call, in case this - # get is involved in a data race. - get, - var.aval, - device_id, - TPU_MEMORY_SPACE_IDXS[var.aval.memory_space], - kernel_output_ids[j], - (), - ordered=True) - transform = indexing.NDIndexer( - indices=tuple(indexing.ds(st, sz) if not iid else st - for st, sz, iid in zip(start_indices[num_inputs + j], - block_shapes[num_inputs + j], - is_indexing_dim[num_inputs + j])), - shape=output_vals[j].shape, - int_indexer_shape=()) - callback.io_callback( - # TODO(jburnim): Pass source_info from the pallas_call, in case this - # store is involved in a data race. - store, - (), - device_id, - TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY], - output_buffer_ids[j], - (transform,), - kernel_output_val, - ordered=True) - - return i + 1, _get_next_indices(grid, loop_idx) - - # TODO(jburnim): Handle parallel grid dimensions + megacore. - _ = lax.while_loop( - lambda carry: carry[0] < num_iterations, - body, - (jnp.int32(0), (jnp.int32(0),) * len(grid)) - ) - - # Read the output from the allocated output buffers. - ret = [ - callback.io_callback( - # TODO(jburnim): Pass source_info from the pallas_call, in case this - # get is involved in a data race. - get, - val, - device_id, - TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY], - output_buffer_id, - (indexing.NDIndexer.from_indices_shape( - tuple(indexing.ds(0, s) for s in val.shape), - output_buffer_shape),), - ordered=True) - for val, output_buffer_id, output_buffer_shape in zip( - output_vals, output_buffer_ids, output_buffer_shapes) - ] - - callback.io_callback( - _validate, - (), - device_id, - ordered=True) - - # For now, when we're done with a pallas_call, we delete the shared memory. - # We use a barrier to ensure that all devices are done running the kernel. - # - # TODO(jburnim): Get rid of this barrier. And figure out how this should - # work if we want to invoke successive pallas_calls that use the same - # shared memory. - callback.io_callback( - _clean_up_shared_memory, - (), - device_id, - ordered=True) - - return ret diff --git a/jax/_src/pallas/mosaic/interpret/BUILD b/jax/_src/pallas/mosaic/interpret/BUILD new file mode 100644 index 000000000000..2a86f2258032 --- /dev/null +++ b/jax/_src/pallas/mosaic/interpret/BUILD @@ -0,0 +1,103 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Package for Pallas TPU Interpret Mode + +load("@rules_python//python:defs.bzl", "py_library") +load("//jaxlib:jax.bzl", "py_deps", "pytype_strict_library") + +package( + default_applicable_licenses = [], + default_visibility = [ + "//jax:internal", + ], +) + +py_library( + name = "interpret_pallas_call", + srcs = [ + "__init__.py", + "interpret_pallas_call.py", + ], + deps = [ + ":race_detection_state", + ":shared_memory", + ":thread_map", + ":utils", + ":vector_clock", + "//jax", + "//jax/_src:api", + "//jax/_src:callback", + "//jax/_src:config", + "//jax/_src:core", + "//jax/_src:frozen_dict", + "//jax/_src:lax", + "//jax/_src:mlir", + "//jax/_src:source_info_util", + "//jax/_src:typing", + "//jax/_src:util", + "//jax/_src/pallas", + "//jax/_src/pallas/mosaic:core", + "//jax/_src/pallas/mosaic:primitives", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "vector_clock", + srcs = ["vector_clock.py"], + deps = py_deps("numpy"), +) + +pytype_strict_library( + name = "shared_memory", + srcs = ["shared_memory.py"], + deps = [ + ":race_detection_state", + ":vector_clock", + "//jax", + "//jax/_src:source_info_util", + "//jax/_src:typing", + "//jax/_src/pallas", + "//jax/_src/pallas/mosaic:core", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "race_detection_state", + srcs = ["race_detection_state.py"], + deps = [ + ":vector_clock", + "//jax/_src:source_info_util", + ], +) + +pytype_strict_library( + name = "thread_map", + srcs = ["thread_map.py"], + deps = [ + "//jax", + "//jax/_src:callback", + ], +) + +pytype_strict_library( + name = "utils", + srcs = ["utils.py"], + deps = [ + "//jax", + "//jax/_src:core", + "//jax/_src:util", + "//jax/_src/pallas", + ] + py_deps("numpy"), +) diff --git a/jax/_src/pallas/mosaic/interpret/__init__.py b/jax/_src/pallas/mosaic/interpret/__init__.py new file mode 100644 index 000000000000..1337256a5074 --- /dev/null +++ b/jax/_src/pallas/mosaic/interpret/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py b/jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py new file mode 100644 index 000000000000..aa5d881818f7 --- /dev/null +++ b/jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py @@ -0,0 +1,2265 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable +import contextlib +import dataclasses +import enum +import functools +import itertools +import math +import threading +from typing import Any, Literal, cast + +import jax +from jax import lax +from jax._src import callback +from jax._src import config +from jax._src import core as jax_core +from jax._src import frozen_dict +from jax._src import linear_util as lu +from jax._src import pjit +from jax._src import source_info_util +from jax._src.interpreters import mlir +from jax._src.pallas import core as pallas_core +from jax._src.pallas import primitives +from jax._src.pallas.mosaic import core as mosaic_core +from jax._src.pallas.mosaic import primitives as mosaic_primitives +from jax._src.pallas.mosaic.interpret import shared_memory as memory +from jax._src.pallas.mosaic.interpret import vector_clock as vc +from jax._src.pallas.mosaic.interpret.race_detection_state import RaceDetectionState +from jax._src.pallas.mosaic.interpret.thread_map import thread_map +import jax._src.pallas.mosaic.interpret.utils as interpret_utils +from jax._src.state import discharge as state_discharge +from jax._src.state import indexing +from jax._src.state import primitives as state_primitives +from jax._src.typing import Array +from jax._src.util import ( + safe_map, + safe_zip, + split_list +) +from jax.interpreters import partial_eval as pe +import jax.numpy as jnp +import numpy as np + + +map, unsafe_map = safe_map, map +zip, unsafe_zip = safe_zip, zip + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class InterpretParams(interpret_utils.InterpretParams): + """Parameters for TPU interpret mode. + + TPU interpret mode is a way run Pallas TPU kernels on CPU, while simulating + a TPU's shared memory (HBM, VMEM, etc.), communication (remote and local + DMAs), and synchronization operations (semaphores, barriers, etc.). This mode + is intended for debugging and testing. + + To run a kernel under TPU interpret mode, pass an instance of + ``InterpretParams`` as an argument for the ``interpret`` parameter of + :func:`jax.experimental.pallas.pallas_call` or + :func:`jax.experimental.pallas.core_map`. + + NOTE: If an exception is raised while interpreting a kernel, you must call + :func:`reset_tpu_interpret_mode_state` before using TPU interpret mode + again in the same process. + + Attributes: + dma_execution_mode: If "eager", DMAs are executed as soon as they are + issued. If "on_wait", DMA reads or writes are only executed when a device + is waiting on a DMA semaphore that will be signaled when the read or write + is complete. + Default: "on_wait". + random_seed: Seed for random number generator used during interpretation. + Currently random numbers are used to randomize the grid coordinates along + dimensions with 'parallel' semantics. + Default: None. + grid_point_recorder: Callback that is invoked by the interpreter for each + grid point in the order in which the grid points are traversed. The + callback is invoked with two arguments: - A tuple of grid coordinates. - + The local core ID of the core that is processing the grid point. This + callback is intended for inspecting - the randomization of coordinates + along grid dimensions with 'parallel' semantics and - the mapping of grid + points to local (i.e. per-device) cores. + Default: None. + allow_hbm_allocation_in_run_scoped: If `True`, allows the allocation of HBM + buffers (which are then shared across the cores in a device) in + `run_scoped`. While this behavior can be enabled in the interpreter, + allocating HBM buffers with `run_scoped` is not supported when executing + Pallas kernels on a real TPU. + Default: `False`. + """ + + dma_execution_mode: Literal["eager", "on_wait"] = "on_wait" + random_seed: int | None = None + grid_point_recorder: ( + Callable[[tuple[np.int32, ...], np.int32], None] | None + ) = None + allow_hbm_allocation_in_run_scoped: bool = False + + @property + def num_cores_per_device(self) -> int: + return self.num_cores_or_threads + + +@contextlib.contextmanager +def force_tpu_interpret_mode(params: InterpretParams = InterpretParams()): + """Context manager that forces TPU interpret mode under its dynamic context. + + TPU interpret mode is a way run Pallas TPU kernels on CPU, while simulating + a TPU's shared memory (HBM, VMEM, etc.), communication (remote and local + DMAs), and synchronization operations (semaphores, barriers, etc.). This mode + is intended for debugging and testing. See :class:`InterpretParams` for + additional information. + + Args: + params: an instance of :class:`InterpretParams`. Any call to + :func:`jax.experimental.pallas.pallas_call` or + :func:`jax.experimental.pallas.core_map` that is traced under this context + manager will be run with ``interpret=params``. When ``params`` is not + ``None``, this will cause those calls to run with TPU interpret mode. + """ + prev = config.pallas_tpu_interpret_mode_context_manager.swap_local(params) + try: + yield + finally: + config.pallas_tpu_interpret_mode_context_manager.set_local(prev) + +def set_tpu_interpret_mode(params: InterpretParams = InterpretParams()): + config.pallas_tpu_interpret_mode_context_manager.set_global(params) # type: ignore[arg-type] + + +# TODO(jburnim): Do we want to support multiple instances of SharedMemory? +# Maybe for running multiple distinct interpreted computations in parallel? +_shared_memory: memory.SharedMemory | None = None +_shared_memory_init_lock = threading.Lock() +races: RaceDetectionState | None = None +dma_id_counter: interpret_utils.Counter | None = None + +def reset_tpu_interpret_mode_state(): + """Resets all global, shared state used by TPU interpret mode. + + TPU interpret mode uses global, shared state for simulating memory buffers + and semaphores, for race detection, etc., when interpreting a kernel. + Normally, this shared state is cleaned up after a kernel is interpreted. + + But if an exception is thrown while interpreting a kernel, the shared state + is not cleaned up, allowing the simulated TPU state to be examined for + debugging purposes. In this case, the shared state must be reset before + any further kernels are interpreted. + """ + global _shared_memory, races, dma_id_counter + with _shared_memory_init_lock: + _shared_memory = None + races = None + dma_id_counter = None + + +def _get_shared_memory() -> memory.SharedMemory: + assert _shared_memory is not None + return _shared_memory + + +def _clear_shared_memory(): + global _shared_memory + with _shared_memory_init_lock: + _shared_memory = None + + +def _initialize_shared_memory( + device_id, num_devices, num_cores_per_device, *, interpret_params +): + global _shared_memory, races, dma_id_counter + del device_id + + num_devices = int(num_devices) + num_cores_per_device = int(num_cores_per_device) + num_cores = num_devices * num_cores_per_device + + with _shared_memory_init_lock: + if _shared_memory is None: + vector_clock_size = interpret_params.get_vector_clock_size(num_devices) + races = RaceDetectionState(num_cores=num_cores) + dma_id_counter = interpret_utils.Counter(100) + _shared_memory = memory.SharedMemory( + num_devices=num_devices, + num_cores_per_device=num_cores_per_device, + out_of_bounds_reads=interpret_params.out_of_bounds_reads, + dma_execution_mode=interpret_params.dma_execution_mode, + uninitialized_memory=interpret_params.uninitialized_memory, + detect_races=interpret_params.detect_races, + vector_clock_size=vector_clock_size, + clocks=[ + vc.make_vector_clock(vector_clock_size) for _ in range(num_cores) + ], + barrier=threading.Barrier( + num_devices, action=_update_clocks_for_global_barrier + ), + clean_up_barrier=threading.Barrier( + num_devices, action=_clear_shared_memory + ), + ) + assert _shared_memory.num_cores == num_cores + + +def _update_clocks_for_device_barrier(device_id): + """Synchronizes the vector clocks for the cores on the given device.""" + shared_memory = _get_shared_memory() + shared_memory.update_clocks_for_device_barrier(device_id) + + +def _update_clocks_for_global_barrier(): + """Synchronizes all vector clocks.""" + shared_memory = _get_shared_memory() + shared_memory.update_clocks(0, shared_memory.num_cores) + + +def _barrier(device_id): + del device_id + shared_memory = _get_shared_memory() + if shared_memory.num_devices > 1: + shared_memory.barrier.wait() + + +def _clean_up_shared_memory(device_id): + del device_id + shared_memory = _get_shared_memory() + shared_memory.clean_up_barrier.wait() + + +def _check_for_revisiting(device_id, local_core_id, loop_idx, output_blocks): + device_id = int(device_id) + local_core_id = int(local_core_id) + loop_idx = tuple(int(x) for x in loop_idx) + try: + output_blocks = jax.tree.map(int, output_blocks) + except: + raise ValueError('Advanced indexers are not supported on TPU') + output_ranges = [ + interpret_utils.to_range(b) if b is not None else None + for b in output_blocks + ] + + shared_memory = _get_shared_memory() + past_output_ranges = shared_memory.output_ranges[(device_id, local_core_id)] + if not past_output_ranges: + past_output_ranges.append((loop_idx, output_ranges)) + return + + for i in range(len(output_ranges)): + if output_ranges[i] is None: + continue + if past_output_ranges[-1][1][i] == output_ranges[i]: + continue + # TODO(jburnim): Do something constant time instead of linear here. + past_idxs = [ + j + for j, ors in enumerate(past_output_ranges) + if ors[1][i] == output_ranges[i] + ] + if past_idxs: + raise RuntimeError( + f'Revisited block {output_ranges[i]} of output {i} in iteration ' + f'{loop_idx}. The block was previously visited in iterations ' + f'{past_output_ranges[past_idxs[0]][0]} through ' + f'{past_output_ranges[past_idxs[-1]][0]} .' + ) + + past_output_ranges.append((loop_idx, output_ranges)) + + +def _validate(device_id): + device_id = int(device_id) + + shared_memory = _get_shared_memory() + semaphores = shared_memory.get_sempahores_with_nonzero_count(device_id) + if semaphores: + sem, global_core_id = semaphores[0] + # TODO(jburnim): Make this raise an error, but in a way that doesn't + # cause other devices to hang later in `_clean_up_shared_memory`. + print( + f'Semaphore {sem.id} has non-zero count for {device_id} (global core' + f' {global_core_id}) at kernel exit:' + f' {sem.count_by_core[global_core_id]}' + ) + + +def _allocate_buffer( + device_id: Array, + local_core_id: Array | None, + memory_space: Array, + val: Array, +): + """Allocates a memory buffer on the device with id `device_id` and core with id `local_core_id`. + + Args: + device_id: Singleton array holding the device id where the buffer will be + allocated. + local_core_id: None or singleton array holding the core id where the buffer + will be allocated. If None, a buffer will be allocated on each cores on + the device. + memory_space: Singleton array indicating the memory space to allocate the + buffer in. If the corresponding memory space is "any" (i.e. HBM), at most + one buffer will be allocated and it will belong to (local) core id 0. + val: Array of values to initialize the allocated buffer with. + + Returns: + Integer id for the allocated buffer. + """ + device_id = int(device_id) + memory_space_str = TPU_MEMORY_SPACE_NAMES[int(memory_space)] + del memory_space + val = np.array(val) + + shared_memory = _get_shared_memory() + + if local_core_id is None: + local_core_id_int = 0 + local_core_ids = tuple(range(shared_memory.num_cores_per_device)) + else: + local_core_id_int = int(local_core_id) + local_core_ids = (local_core_id_int,) + del local_core_id + + local_core_id_to_buffer_id: dict[int, int] = {} + for lci in local_core_ids: + buffer_id = shared_memory.get_next_buffer_id(device_id, lci) + if memory_space_str in ['any', 'hbm']: + # If allocating in HBM, only actually allocate a buffer once. The first + # local core (i.e. thread) that gets here allocates the buffer, but the + # buffer is still keyed in the shared memory with core ID 0. However, + # since the buffer is shared across all cores, we initialize the buffer's + # `ref_count` with the number of cores per device. This ensures that the + # buffer is not deallocated until all cores have exited the scope of the + # allocation (e.g. have exited the body of a `run_scoped`). + key = (memory_space_str, buffer_id, device_id, 0) + ref_count = shared_memory.num_cores_per_device + else: + key = (memory_space_str, buffer_id, device_id, lci) + ref_count = 1 + if len(local_core_id_to_buffer_id) > 0: + # If we are allocating more than one buffer, we must make additional + # copies of `val` so that each buffer is a distinct ndarray. + val = val.copy() + + shared_memory.allocate_buffer(key, ref_count=ref_count, value=val) + local_core_id_to_buffer_id[lci] = buffer_id + + # The buffer ids should always be kept in sync across all cores. + assert all( + buffer_id == local_core_id_to_buffer_id[local_core_id_int] + for buffer_id in local_core_id_to_buffer_id.values() + ) + # TODO(jburnim): Raise an error if buffer_id is too big for int16. + return np.int16(local_core_id_to_buffer_id[local_core_id_int]) + + +def _local_core_id_or_zero_if_hbm(local_core_id: int, memory_space: str) -> int: + if memory_space in ['any', 'hbm']: + return 0 + return local_core_id + + +def _deallocate_buffer(device_id, local_core_id, memory_space, buffer_id): + device_id = int(device_id) + local_core_id = int(local_core_id) + memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)] + buffer_id = int(buffer_id) + + local_core_id = _local_core_id_or_zero_if_hbm(local_core_id, memory_space) + + shared_memory = _get_shared_memory() + key = (memory_space, buffer_id, device_id, local_core_id) + shared_memory.deallocate_buffer(key) + + +def _allocate_semaphores( + device_id: Array, local_core_id: Array | None, shape: Array +): + """Allocates semaphores on the device with id `device_id` and core with id `local_core_id`. + + The number of semaphores allocated is given by the product of the entries in + `shape`. + + Since for each semaphore id there is really only one global `Semaphore` + object, 'allocation' of semaphores per device and core here means that the + internal counter of semaphore ids that is held by `SharedMemory` is + incremented for each the device and core (or for all cores on the dive if + argument `local_core_id` is None, see below). + + Args: + device_id: Singleton array holding the id for the device where the + semaphores will be allocated. + local_core_id: None or singleton array holding the id for the core where the + semaphores will be allocated. If None, semaphores will be allocated on all + cores on the device. + shape: Shape of the semaphore array to allocate. + + Returns: + Array of semaphore ids. + """ + device_id = int(device_id) + shape = tuple(map(int, shape)) + num_semaphores = math.prod(shape) + + shared_memory = _get_shared_memory() + + if local_core_id is None: + local_core_id_int = 0 + global_core_ids = shared_memory.get_global_core_ids(device_id) + else: + local_core_id_int = int(local_core_id) + global_core_ids = ( + shared_memory.get_global_core_id(device_id, local_core_id_int), + ) + del local_core_id + + global_core_id_to_semaphore_id = {} + for gci in global_core_ids: + semaphore_id = shared_memory.allocate_semaphores(gci, num_semaphores) + global_core_id_to_semaphore_id[gci] = semaphore_id + + global_core_id = shared_memory.get_global_core_id( + device_id, local_core_id_int + ) + # The semaphore ids should always be kept in sync across all cores. + assert all( + semaphore_id == global_core_id_to_semaphore_id[global_core_id] + for semaphore_id in global_core_id_to_semaphore_id.values() + ) + + # NOTE: For now, we use a relatively uncommon datatype (int16) for + # semaphore (and buffer) IDs, so these values are more easily identifiable + # in kernels. + # + # TODO(jburnim): Raise an error if any IDs are too big for int16. + semaphore_id = global_core_id_to_semaphore_id[global_core_id] + return np.arange( + semaphore_id, semaphore_id + num_semaphores, dtype=np.int16 + ).reshape(shape) + + +TPU_MEMORY_SPACE_IDXS: dict[ + mosaic_core.MemorySpace | pallas_core.MemorySpace | None, int +] = {v: i for i, v in enumerate(mosaic_core.MemorySpace)} +TPU_MEMORY_SPACE_NAMES = { + i: v.value for i, v in enumerate(mosaic_core.MemorySpace) +} + +# Inject ANY as the last memory space. +TPU_MEMORY_SPACE_NAMES[len(TPU_MEMORY_SPACE_IDXS)] = ( + pallas_core.MemorySpace.ANY.value +) +TPU_MEMORY_SPACE_IDXS[pallas_core.MemorySpace.ANY] = len(TPU_MEMORY_SPACE_IDXS) + +# Default to VMEM when no memory space is specified. +TPU_MEMORY_SPACE_IDXS[None] = TPU_MEMORY_SPACE_IDXS[ + mosaic_core.MemorySpace.VMEM +] + + +def get_barrier_semaphore(device_id, collective_id): + del device_id + collective_id = int(collective_id) + shared_memory = _get_shared_memory() + shared_memory.guarantee_semaphore_with_fixed_id(collective_id) + return np.int16(collective_id) + + +def _to_int(x: int | Array | None) -> int | None: + """Converts a value to an integer, or returns None if the value is None.""" + if x is None: + return None + return int(x) + + +def get( + device_id, + local_core_id, + memory_space, + buffer_id, + transforms, + block_indices=None, + grid_loop_idx=None, + *, + src_device_id=None, + src_local_core_id=None, + clock=None, + source_info=None, + input_name=None, +) -> np.ndarray: + device_id = int(device_id) + local_core_id = int(local_core_id) + memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)] + buffer_id = int(buffer_id) + try: + transforms = jax.tree.map(int, transforms) + except: + raise ValueError('Advanced indexers are not supported on TPU') + src_device_id = _to_int(src_device_id) + src_local_core_id = _to_int(src_local_core_id) + if input_name is not None: + # NOTE: input_name, block_indices, and grid_loop_idx are set only if this + # function is being called to read a block from a pallas_call input (at the + # start of one iteration of the kernel body). + block_indices = tuple(int(x) for x in block_indices) + grid_loop_idx = tuple(int(x) for x in tuple(grid_loop_idx)) + + shared_memory = _get_shared_memory() + + local_core_id_for_buffer = _local_core_id_or_zero_if_hbm( + local_core_id, memory_space + ) + global_core_id = shared_memory.get_global_core_id(device_id, local_core_id) + + key = (memory_space, buffer_id, device_id, local_core_id_for_buffer) + read_range = interpret_utils.to_range(transforms) + ret, (shape, dtype), clock_ = shared_memory.get_buffer_content( + key, read_range, global_core_id + ) + clock = clock if clock is not None else clock_ + + # Compute the shape of the read value, assuming the read is fully in-bounds. + # TODO(jburnim): We already know this shape in the Jaxpr where we insert a + # callback to `get`. Should we just pass the shape to `get`? + # TODO(jburnim): Move to a helper function? + full_read_shape = [] + assert len(read_range) <= len(shape) + for dim_size, idx_or_slice in itertools.zip_longest( + shape, read_range, fillvalue=None + ): + assert isinstance(dim_size, int) + if idx_or_slice is None: + full_read_shape.append(dim_size) + elif isinstance(idx_or_slice, int): + continue + else: + dim_size = (idx_or_slice.stop - idx_or_slice.start) // idx_or_slice.step + assert isinstance(dim_size, int) + full_read_shape.append(dim_size) + full_read_shape = tuple(full_read_shape) + + if (ret is None) or (full_read_shape != ret.shape): + if shared_memory.out_of_bounds_reads == 'raise': + if source_info is None: + ctx = contextlib.nullcontext() + else: + ctx = source_info_util.user_context( + traceback=source_info.traceback, name_stack=source_info.name_stack + ) # type: ignore[assignment] + with ctx: + if input_name is None: + raise IndexError( + 'Out-of-bounds read of' + f' ({device_id} {local_core_id} {memory_space} {buffer_id}):' + f' reading [{read_range}] but buffer has shape {shape}.' + ) + else: + # Different error message when we are reading a block of an input, + # to copy it to a buffer before invoking the kernel body. + raise IndexError( + f'Out-of-bounds block index {block_indices} for' + f' input "{input_name}" in iteration {grid_loop_idx}' + f' on device {device_id} (core {local_core_id}):' + f' reading [{read_range}] but input has shape {shape}.' + ) + # out_of_bounds_reads == "uninitialized" + uninit_array = np.full( + full_read_shape, + interpret_utils.get_uninitialized_value( + dtype, shared_memory.uninitialized_memory + ), + dtype=dtype, + ) + if ret is None: + ret = uninit_array + else: + uninit_array[tuple(slice(s) for s in ret.shape)] = ret + ret = uninit_array + + if shared_memory.detect_races: + if src_device_id is None: + src_device_id = device_id + if src_local_core_id is None: + src_local_core_id = local_core_id + assert races is not None + races.check_read( + src_device_id, + src_local_core_id, + clock, + (memory_space, buffer_id, device_id, local_core_id_for_buffer), + read_range, + source_info=source_info, + ) + + return ret + + +def store( + device_id, + local_core_id, + memory_space, + buffer_id, + transforms, + val, + block_indices=None, + grid_loop_idx=None, + *, + src_device_id=None, + src_local_core_id=None, + clock=None, + source_info=None, + output_name=None, +): + device_id = int(device_id) + local_core_id = int(local_core_id) + memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)] + buffer_id = int(buffer_id) + try: + transforms = jax.tree.map(int, transforms) + except: + raise ValueError('Advanced indexers are not supported on TPU') + val = np.array(val) + src_device_id = _to_int(src_device_id) + src_local_core_id = _to_int(src_local_core_id) + if output_name is not None: + # NOTE: output_name, block_indices, and grid_loop_idx are set only if this + # function is being called to store a block into a pallas_call output (at + # the end of one iteration of the kernel body). + block_indices = tuple(int(x) for x in block_indices) + grid_loop_idx = tuple(int(x) for x in tuple(grid_loop_idx)) + + shared_memory = _get_shared_memory() + + local_core_id_for_buffer = _local_core_id_or_zero_if_hbm( + local_core_id, memory_space + ) + global_core_id = shared_memory.get_global_core_id(device_id, local_core_id) + + key = (memory_space, buffer_id, device_id, local_core_id_for_buffer) + write_range = interpret_utils.to_range(transforms) + in_bounds, (shape, _), clock_ = shared_memory.store_buffer_content( + key, write_range, val, global_core_id + ) + clock = clock if clock is not None else clock_ + + if not in_bounds: + if output_name is None: + raise ValueError( + 'Out-of-bounds write of' + f' ({device_id} {local_core_id} {memory_space} {buffer_id}):' + f' writing [{write_range}] but buffer has shape {shape} .' + ) + else: + # Different error message when we are copying a kernel buffer to a + # block of an output (just after a kernel invocation). + raise IndexError( + f'Out-of-bounds block index {block_indices} for' + f' output "{output_name}" in iteration {grid_loop_idx}' + f' on device {device_id} (core {local_core_id}):' + f' reading [{write_range}] but output has shape {shape}.' + ) + + if shared_memory.detect_races: + if src_device_id is None: + src_device_id = device_id + if src_local_core_id is None: + src_local_core_id = local_core_id + assert races is not None + races.check_write( + src_device_id, + src_local_core_id, + clock, + (memory_space, buffer_id, device_id, local_core_id_for_buffer), + write_range, + source_info=source_info, + ) + + +def swap( + device_id, + local_core_id, + memory_space, + buffer_id, + transforms, + val, + mask, + *, + source_info=None, +): + device_id = int(device_id) + local_core_id = int(local_core_id) + memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)] + buffer_id = int(buffer_id) + try: + transforms = jax.tree.map(int, transforms) + except: + raise ValueError('Advanced indexers are not supported on TPU') + val = np.array(val) + mask = np.array(mask) if mask is not None else None + if mask is not None: + assert mask.shape == val.shape + + shared_memory = _get_shared_memory() + + local_core_id_for_buffer = _local_core_id_or_zero_if_hbm( + local_core_id, memory_space + ) + global_core_id = shared_memory.get_global_core_id(device_id, local_core_id) + + key = (memory_space, buffer_id, device_id, local_core_id_for_buffer) + read_write_range = interpret_utils.to_range(transforms) + ret, (shape, _), clock = shared_memory.swap_buffer_content( + key, read_write_range, val, mask, global_core_id + ) + + if ret is None: + if mask is None: + raise ValueError( + 'Out-of-bounds swap of' + f' ({device_id} {local_core_id} {memory_space} {buffer_id}):' + f' swapping [{read_write_range}] but buffer has shape' + f' {shape} .' + ) + else: + # TODO(jburnim): Include indices of out-of-bounds locations where mask + # is True. + raise ValueError( + 'Out-of-bounds masked swap of' + f' ({device_id} {local_core_id} {memory_space} {buffer_id}): swapping' + f' [{read_write_range}] but buffer has shape {shape} . ' + ) + + if shared_memory.detect_races: + assert races is not None + races.check_write( + device_id, + local_core_id, + clock, + (memory_space, buffer_id, device_id, local_core_id_for_buffer), + read_write_range, + source_info=source_info, + ) + return ret + + +class DmaState(enum.Enum): + STARTED = 0 + READ = 1 + COMPLETED = 2 + + +@dataclasses.dataclass +class DMA: + id: int + + src_device_id: int + src_local_core_id: int + src_memory_space: int + src_buffer_id: int + src_transforms: tuple[Any, ...] + dst_device_id: int + dst_local_core_id: int + dst_memory_space: int + dst_buffer_id: int + dst_transforms: tuple[Any, ...] + src_sem: memory.Semaphore | None + dst_sem: memory.Semaphore + virtual_device_id: int + clock: vc.VectorClock + + source_info: source_info_util.SourceInfo | None = None + + state: DmaState = DmaState.STARTED + data: np.ndarray | None = None + lock: threading.Lock = dataclasses.field(default_factory=threading.Lock) + + @property + def data_size(self) -> int: + assert self.data is not None + return self.data.itemsize * self.data.size + + @property + def detect_races(self) -> bool: + return self.dst_sem.detect_races + + @property + def src_global_core_id(self) -> int: + return self.dst_sem.get_global_core_id( + self.src_device_id, self.src_local_core_id + ) + + @property + def dst_global_core_id(self) -> int: + return self.dst_sem.get_global_core_id( + self.dst_device_id, self.dst_local_core_id + ) + + def execute_read(self): + """Executes the reading part of this DMA. + + Note that the caller must not hold the lock on the shared memory (because + `get` is called in this method). + """ + # Must acquire the lock on `self` because: + # - `self.state` is inspected and modified in this method. + # - `self.data` is assigned in this method. + with self.lock: + if self.state != DmaState.STARTED: + return + + if self.detect_races: + vc.inc_vector_clock(self.clock, self.virtual_device_id) + + self.data = get( + self.src_device_id, + self.src_local_core_id, + self.src_memory_space, + self.src_buffer_id, + self.src_transforms, + clock=vc.copy_vector_clock(self.clock), + src_device_id=self.id, + src_local_core_id=0, + source_info=self.source_info, + ) + + if self.detect_races: + vc.inc_vector_clock(self.clock, self.virtual_device_id) + + # Signal the send semaphore. + if self.src_sem is not None: + self.src_sem.signal( + self.data_size, self.src_global_core_id, clock=self.clock + ) + + self.state = DmaState.READ + + def execute_write(self): + """Executes the writing part of this DMA. + + Note that the caller must not hold the lock on the shared memory (because + `store` is called in this method). + """ + # Must acquire the lock on `self` because: + # - `self.state` is inspected and modified in this method. + # - `self.data` is assigned in this method. + with self.lock: + assert self.state in (DmaState.READ, DmaState.COMPLETED) + if self.state == DmaState.COMPLETED: + return + assert self.data is not None + + if self.detect_races: + vc.inc_vector_clock(self.clock, self.virtual_device_id) + + store( + self.dst_device_id, + self.dst_local_core_id, + self.dst_memory_space, + self.dst_buffer_id, + self.dst_transforms, + self.data, + clock=vc.copy_vector_clock(self.clock), + src_device_id=self.id, + src_local_core_id=0, + source_info=self.source_info, + ) + + if self.detect_races: + vc.inc_vector_clock(self.clock, self.virtual_device_id) + + self.dst_sem.signal( + self.data_size, self.dst_global_core_id, clock=self.clock + ) + + self.data = None + self.state = DmaState.COMPLETED + + def execute_read_and_write(self): + """Executes this DMA, bot the reading and writing parts. + + Note that the caller must not hold the lock on the shared memory. + """ + self.execute_read() + self.execute_write() + + +def dma_start( + device_id, + src_local_core_id, + src_memory_space, + src_id, + src_transforms, + dst_memory_space, + dst_id, + dst_transforms, + dst_sem_id, + src_sem_id, + dst_device_id, + source_info=None, +): + shared_memory = _get_shared_memory() + device_id = int(device_id) + src_local_core_id = int(src_local_core_id) + src_global_core_id = shared_memory.get_global_core_id( + device_id, src_local_core_id + ) + src_memory_space, src_id = int(src_memory_space), int(src_id) + src_transforms = jax.tree.map(int, src_transforms) + dst_memory_space, dst_id = int(dst_memory_space), int(dst_id) + dst_transforms = jax.tree.map(int, dst_transforms) + dst_sem_id = int(dst_sem_id) + src_sem_id = int(src_sem_id) if src_sem_id is not None else None + if dst_device_id is not None: + dst_device_id = int(dst_device_id) + else: + dst_device_id = device_id + dst_global_core_id = shared_memory.get_global_core_id( + dst_device_id, src_local_core_id # Same core on destination device as on source. + ) + + (src_sem, dst_sem), clock = shared_memory.get_semaphores_and_increment_clock( + (src_sem_id, dst_sem_id), src_global_core_id + ) + + assert dma_id_counter is not None + id = dma_id_counter.get_next() + + dma = DMA( + id, + device_id, + src_local_core_id, + src_memory_space, + src_id, + src_transforms, + dst_device_id, + src_local_core_id, # Same core on destination device as on source. + dst_memory_space, + dst_id, + dst_transforms, + src_sem, + dst_sem, + virtual_device_id = shared_memory.get_random_virtual_device_id(), + clock=clock, + source_info=source_info, + ) + + if shared_memory.dma_execution_mode == 'on_wait': + if src_sem_id is None: + shared_memory.append_semaphore_task( + dst_sem_id, dst_global_core_id, dma.execute_read_and_write + ) + else: + shared_memory.append_semaphore_task( + src_sem_id, src_global_core_id, dma.execute_read + ) + shared_memory.append_semaphore_task( + dst_sem_id, + dst_global_core_id, + # This task for the waiting semaphore with ID `dst_sem_id` may be + # executed before the corresponding DMA task for the sending semaphore + # that does the DMA read. We therefore have to append a read-and-write + # task here, instead of just a write task. If the reading for the DMA + # has already been executed, the DMA's state will indicate this and + # the read-write-task appended here will do the write only. + # (Alternatively, we could have the DMA write task wait on the + # `send_semphore`. This issue with this approach is that we do not + # know the number of bytes transferred that `send_semaphore` should be + # waiting for until after the reader task is done.) + dma.execute_read_and_write, + ) + return + + assert shared_memory.dma_execution_mode == 'eager' + dma.execute_read_and_write() + + +def dma_wait(device_id, local_core_id, sem_id, size): + shared_memory = _get_shared_memory() + + device_id = int(device_id) + local_core_id = int(local_core_id) + sem_id = int(sem_id) + size = int(size) + + global_core_id = shared_memory.get_global_core_id(device_id, local_core_id) + + (sem,), _ = shared_memory.get_semaphores_and_increment_clock( + {sem_id}, global_core_id + ) + assert sem is not None + sem.wait(size, global_core_id, has_tasks=True) + + +def semaphore_signal( + device_id, + local_core_id, + sem_id, + inc, + target_device_id, + target_local_core_id, +): + shared_memory = _get_shared_memory() + + device_id = int(device_id) + local_core_id = int(local_core_id) + sem_id = int(sem_id) + inc = int(inc) + src_global_core_id = shared_memory.get_global_core_id( + device_id, local_core_id + ) + if target_device_id is None: + target_device_id = device_id + else: + target_device_id = int(target_device_id) + if target_local_core_id is None: + target_local_core_id = 0 + + (sem,), clock = shared_memory.get_semaphores_and_increment_clock( + {sem_id}, src_global_core_id + ) + assert sem is not None + sem.signal( + inc, + shared_memory.get_global_core_id(target_device_id, target_local_core_id), + clock, + ) + + +def semaphore_wait(device_id, local_core_id, sem_id, value): + shared_memory = _get_shared_memory() + + device_id = int(device_id) + local_core_id = int(local_core_id) + sem_id = int(sem_id) + value = int(value) + global_core_id = shared_memory.get_global_core_id(device_id, local_core_id) + + (sem,), _ = shared_memory.get_semaphores_and_increment_clock( + {sem_id}, global_core_id + ) + assert sem is not None + sem.wait(value, global_core_id) + + +def _compute_transformed_shape_and_dtype(shape, dtype, transforms): + for transform in transforms: + if transform is None: + continue + shape = transform.transform_shape(shape) + dtype = transform.transform_dtype(dtype) + return shape, dtype + + +@lu.cache +def _to_jaxpr(flat_fun, in_avals): + new_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals) + new_jaxpr = jax_core.ClosedJaxpr(new_jaxpr, consts) + return new_jaxpr + +def _is_any(memory_space): + return memory_space is pallas_core.MemorySpace.ANY + + +_SENTINEL = jnp.inf + + +def _get_memory_space_and_raise_if_hbm(aval, primitive_name, message=None): + memory_space = aval.memory_space + if memory_space in [mosaic_core.MemorySpace.HBM, pallas_core.MemorySpace.ANY]: + if message is None: + message = ( + f'{primitive_name}: Buffers with a memory space of HBM or ANY cannot' + ' be referenced directly. Instead, use `pltpu.sync_copy` or' + ' `pltpu.async_copy`.' + ) + raise ValueError(message) + return memory_space + + +def _interpret_jaxpr( + jaxpr, + *args, + axis_sizes, + mesh, + axis_indices, + device_id, + local_core_id, + compiler_params, + interpret_params +): + sentinel_for_floating_point_values = ( + _SENTINEL if interpret_params.skip_floating_point_ops else None + ) + env = interpret_utils.JaxprEnv( + vars=jaxpr.constvars + jaxpr.invars, + values=args, + sentinel_for_floating_point_values=sentinel_for_floating_point_values, + ) + + # TODO(jburnim): Clean up and finish this evaluation loop. For example: + # - Replace the big if-statement with a dictionary of rules. + # - Handle other higher-order primitives? + _interpret = functools.partial( + _interpret_jaxpr, + axis_sizes=axis_sizes, + mesh=mesh, + axis_indices=axis_indices, + device_id=device_id, + local_core_id=local_core_id, + compiler_params=compiler_params, + interpret_params=interpret_params, + ) + for eqn in jaxpr.eqns: + with source_info_util.user_context( + eqn.source_info.traceback, name_stack=eqn.source_info.name_stack): + prim = eqn.primitive + # We defer reading the values for `eqn.invars` into each of the branches + # of the if-elif-else statement below. This is because the else branch may + # not need to do any reads if `interpret_params.skip_floating_point_ops` + # is True. If this is the case, we want to avoid materializing the read + # array into the jaxpr when this function is traced. + deferred_invals = functools.partial(env.read_many, eqn.invars) + + if prim is primitives.load_p: + (ref, transforms, mask, _) = jax.tree.unflatten( + eqn.params['args_tree'], deferred_invals()) + if mask is not None: + raise NotImplementedError('masked load_p') + memory_space = _get_memory_space_and_raise_if_hbm( + eqn.invars[0].aval, 'load_p' + ) + out = callback.io_callback( + functools.partial(get, source_info=eqn.source_info), + eqn.outvars[0].aval, + device_id, + local_core_id, + TPU_MEMORY_SPACE_IDXS[memory_space], + ref, + transforms, + ordered=True, + ) + + elif prim is primitives.swap_p: + (ref, transforms, val, mask) = jax.tree.unflatten( + eqn.params['args_tree'], deferred_invals()) + memory_space = _get_memory_space_and_raise_if_hbm( + eqn.invars[0].aval, 'swap_p' + ) + out = callback.io_callback( + functools.partial(swap, source_info=eqn.source_info), + eqn.outvars[0].aval, + device_id, + local_core_id, + TPU_MEMORY_SPACE_IDXS[memory_space], + ref, + transforms, + val, + mask, + ordered=True, + ) + + elif prim is primitives.delay_p: + # TODO(jburnim): Implement this properly? + out = [] + + elif prim is mosaic_primitives.prng_seed_p: + # TODO(jburnim): Implement this properly? + out = [] + + elif prim is mosaic_primitives.prng_random_bits_p: + # TODO(jburnim): Implement this properly? + out = jnp.zeros(eqn.params['shape'], jnp.int32) + + elif ((prim is lax.axis_index_p) + and (mesh is not None) and (eqn.params['axis_name'] in mesh.shape)): + # We are interpreting a core_map, and this lax.axis_index call is + # querying our index along the core axis, so return our core ID. + out = local_core_id + + elif ((prim is lax.axis_index_p) + and (eqn.params['axis_name'] in axis_indices)): + # We replace lax.axis_index calls in the kernel body, so that the + # kernel body jaxpr can be run on other threads (via an io_callback) + # without having to recreate the axis environment in those threads. + out = axis_indices[eqn.params['axis_name']] + + elif prim is lax.cond_p: + def _make_branch(jaxpr): + return lambda *args: _interpret(jaxpr, *args) + invals = deferred_invals() + out = lax.switch( + invals[0], + [_make_branch(branch_jaxpr.jaxpr) + for branch_jaxpr in eqn.params['branches']], + *invals[1:]) + + elif prim is lax.scan_p: + consts, init_carry, xs = split_list( + deferred_invals(), + [eqn.params['num_consts'], eqn.params['num_carry']], + ) + def _scan_body(c, a): + return split_list( + _interpret(eqn.params['jaxpr'].jaxpr, *consts, *c, *a), + [eqn.params['num_carry']]) + carry, out = lax.scan(_scan_body, init_carry, xs=xs, + length=eqn.params.get('length', None)) + out = carry + out + + elif prim is lax.while_p: + cond_consts, body_consts, init_vals = split_list( + deferred_invals(), + [eqn.params['cond_nconsts'], eqn.params['body_nconsts']], + ) + out = lax.while_loop( + lambda args: _interpret( + eqn.params['cond_jaxpr'].jaxpr, *cond_consts, *args)[0], + lambda args: _interpret( + eqn.params['body_jaxpr'].jaxpr, *body_consts, *args), + init_vals) + + elif prim is pjit.jit_p: + def f(*args, jaxpr): + return _interpret(jaxpr.jaxpr, *jaxpr.consts, *args) + invals = deferred_invals() + in_avals = tuple(jax_core.shaped_abstractify(i) for i in invals) + new_jaxpr = _to_jaxpr( + lu.wrap_init(functools.partial(f, jaxpr=eqn.params['jaxpr']), + debug_info=eqn.params['jaxpr'].jaxpr.debug_info), + in_avals) + out = pjit.jit_p.bind(*invals, **(eqn.params | {'jaxpr': new_jaxpr})) + + elif prim is primitives.run_scoped_p: + if eqn.params['collective_axes']: + raise NotImplementedError( + 'run_scoped_p with collective axes is not supported' + ) + # Allocate a buffer or semaphore for each element of + # eqn.params['jaxpr'].invars. It is assumed that each core + # runs the same sequence of `run_scoped`s. + allocs = [] + for v in eqn.params['jaxpr'].invars: + if v.aval.memory_space == mosaic_core.MemorySpace.SEMAPHORE: + allocs.append( + callback.io_callback( + _allocate_semaphores, + jax.ShapeDtypeStruct(v.aval.shape, jnp.int16), + device_id, + local_core_id, + v.aval.shape, + ordered=True, + ) + ) + else: + if not interpret_params.allow_hbm_allocation_in_run_scoped: + memory_space = _get_memory_space_and_raise_if_hbm( + v.aval, 'run_scoped_p', "Cannot allocate HBM in `run_scoped`." + ) + else: + memory_space = v.aval.memory_space + allocs.append( + callback.io_callback( + _allocate_buffer, + jax.ShapeDtypeStruct((), jnp.int16), + device_id, + local_core_id, + TPU_MEMORY_SPACE_IDXS[memory_space], + interpret_params.get_uninitialized_array( + v.aval.shape, v.aval.dtype + ), + ordered=True, + ) + ) + + out = _interpret(eqn.params['jaxpr'], *deferred_invals(), *allocs) + + for a, v in zip(allocs, eqn.params['jaxpr'].invars): + if v.aval.memory_space == mosaic_core.MemorySpace.SEMAPHORE: + # TODO(jburnim): De-allocate semaphores. + # callback.io_callback( + # _deallocate_semaphores, + # None, + # device_id, + # a, + # ordered=True) + pass + else: + callback.io_callback( + _deallocate_buffer, + None, + device_id, + local_core_id, + TPU_MEMORY_SPACE_IDXS[v.aval.memory_space], + a, + ordered=True, + ) + + elif prim is state_primitives.get_p: + memory_space = _get_memory_space_and_raise_if_hbm( + eqn.invars[0].aval, 'get_p' + ) + invals = deferred_invals() + out = callback.io_callback( + functools.partial(get, source_info=eqn.source_info), + eqn.outvars[0].aval, + device_id, + local_core_id, + TPU_MEMORY_SPACE_IDXS[memory_space], + invals[0], + jax.tree.unflatten(eqn.params['tree'], invals[1:]), + ordered=True, + ) + + elif prim is state_primitives.swap_p: + memory_space = _get_memory_space_and_raise_if_hbm( + eqn.invars[0].aval, 'swap_p' + ) + invals = deferred_invals() + out = callback.io_callback( + functools.partial(swap, source_info=eqn.source_info), + eqn.outvars[0].aval, + device_id, + local_core_id, + TPU_MEMORY_SPACE_IDXS[memory_space], + invals[0], + jax.tree.unflatten(eqn.params['tree'], invals[2:]), + invals[1], + None, + ordered=True, + ) + + elif prim is mosaic_primitives.dma_start_p: + ( + src, + src_transforms, + dst, + dst_transforms, + dst_sem, + dst_sem_transforms, + src_sem, + src_sem_transforms, + target_device_id, + ) = jax.tree.unflatten(eqn.params['tree'], deferred_invals()) + target_device_id = interpret_utils._device_id_to_logical( + target_device_id, eqn.params['device_id_type'], axis_sizes, + axis_indices) + (orig_src_ref, _, orig_dst_ref, *_ + ) = jax.tree.unflatten(eqn.params['tree'], eqn.invars) + src_memory_space = getattr(orig_src_ref.aval, 'memory_space', None) + if src_memory_space is None: + src_memory_space = pallas_core.MemorySpace.ANY + dst_memory_space = getattr(orig_dst_ref.aval, 'memory_space', None) + if dst_memory_space is None: + dst_memory_space = pallas_core.MemorySpace.ANY + callback.io_callback( + functools.partial(dma_start, source_info=eqn.source_info), + (), + device_id, + local_core_id, + TPU_MEMORY_SPACE_IDXS[src_memory_space], + src, + src_transforms, + TPU_MEMORY_SPACE_IDXS[dst_memory_space], + dst, + dst_transforms, + state_discharge.transform_array(dst_sem, dst_sem_transforms), + state_discharge.transform_array(src_sem, src_sem_transforms), + target_device_id, + ordered=True, + ) + out = [] + + elif prim is mosaic_primitives.dma_wait_p: + ( + src, + src_transforms, + dst, + dst_transforms, + dst_sem, + dst_sem_transforms, + src_sem, + src_sem_transforms, + target_device_id, + ) = jax.tree.unflatten(eqn.params['tree'], deferred_invals()) + read_shape, read_dtype = _compute_transformed_shape_and_dtype( + eqn.invars[0].aval.shape, eqn.invars[0].aval.dtype, src_transforms) + callback.io_callback( + dma_wait, + (), + device_id, + local_core_id, + state_discharge.transform_array(dst_sem, dst_sem_transforms), + math.prod(read_shape) * read_dtype.itemsize, + ordered=True, + ) + out = [] + + elif prim is mosaic_primitives.get_barrier_semaphore_p: + out = callback.io_callback( + get_barrier_semaphore, + jax.ShapeDtypeStruct((), jnp.int16), + device_id, + _get_mosaic_params(compiler_params).collective_id, + ordered=True, + ) + + elif prim is primitives.semaphore_signal_p: + sem, sem_transforms, inc, target_device_id, core_index = ( + jax.tree.unflatten(eqn.params['args_tree'], deferred_invals())) + target_device_id = interpret_utils._device_id_to_logical( + target_device_id, eqn.params['device_id_type'], axis_sizes, + axis_indices) + callback.io_callback( + semaphore_signal, + (), + device_id, + local_core_id, + state_discharge.transform_array(sem, sem_transforms), + inc, + target_device_id, + core_index, + ordered=True, + ) + out = [] + + elif prim is primitives.semaphore_wait_p: + sem, sem_transforms, value, decrement = ( + jax.tree.unflatten(eqn.params['args_tree'], deferred_invals())) + if not decrement: + raise NotImplementedError('Non-decrementing wait is not supported.') + callback.io_callback( + semaphore_wait, + (), + device_id, + local_core_id, + state_discharge.transform_array(sem, sem_transforms), + value, + ordered=True, + ) + out = [] + + elif prim is primitives.atomic_rmw_p: + raise NotImplementedError('atomic_rmw_p') + + elif prim is primitives.atomic_cas_p: + raise NotImplementedError('atomic_cas_p') + + else: + if interpret_params.skip_floating_point_ops and all( + interpret_utils.is_float(ovar.aval.dtype) for ovar in eqn.outvars + ): + # Skip `prim.bind` since `prim` only produces floating-point values. + # It is safe to populate `out` with avals since mapping `write` over + # `out` below only relies on the shape and dtype (for writing + # `Placeholder`s). + out = [ovar.aval for ovar in eqn.outvars] + if not prim.multiple_results: + out = out[0] + else: + subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params) + out = prim.bind(*subfuns, *deferred_invals(), **bind_params) + + out = out if prim.multiple_results else [out] + env.write_many(eqn.outvars, out) + + return env.read_many(jaxpr.outvars) + +def _compute_start_indices( + block_mapping, loop_idx, *args, + axis_sizes, mesh, axis_indices, device_id, local_core_id, + compiler_params, interpret_params): + jaxpr = block_mapping.index_map_jaxpr + block_indices = _interpret_jaxpr( + jaxpr.jaxpr, + *jaxpr.consts, + *loop_idx, + *args, + axis_sizes=axis_sizes, + mesh=mesh, + axis_indices=axis_indices, + device_id=device_id, + local_core_id=local_core_id, + compiler_params=compiler_params, + interpret_params=interpret_params, + ) + def _get_start_index(i, b): + match b: + case pallas_core.Squeezed(): + return i + case pallas_core.Element(): + return i + case pallas_core.Blocked(): + return i * b.block_size + case _: + raise ValueError(f"Unsupported block dim type: {type(b)}") + ret = jnp.array( + tuple( + _get_start_index(i, b) + for i, b in zip(block_indices, block_mapping.block_shape) + ), + dtype=jnp.int32, + ) + return block_indices, ret + +def _get_next_indices(grid, indices): + next_indices = [] + carry = True + for dim_size, index in reversed(list(zip(grid, indices))): + i = jnp.where(carry, index + 1, index) + carry = dim_size == i + next_indices.append(jnp.where(carry, 0, i)) + return tuple(reversed(next_indices)) + +def _get_indices(grid, loop_index): + indices = [] + for dim_size in reversed(grid): + i = loop_index % dim_size + loop_index = loop_index // dim_size + indices.append(i) + return tuple(reversed(indices)) + +def _get_mosaic_params(compiler_params: dict[str, pallas_core.CompilerParams]) -> mosaic_core.CompilerParams: + try: + return cast(mosaic_core.CompilerParams, compiler_params['mosaic_tpu']) + except KeyError: + return mosaic_core.CompilerParams() + + +def _get_parallel_dim_semantics( + compiler_params: dict[str, Any], num_dimensions_in_grid: int, +) -> tuple[bool, ...]: + """Returns a tuple indicating which grid dimensions have parallel semantics. + + Args: + compiler_params: Representation of a `mosaic_core.CompilerParams` object + as a dictionary. + num_dimensions_in_grid: The number of dimensions in the grid. + + Returns: + A tuple of booleans where the entry at index `i` is `True` precisely if the + `i`-th dimension in the grid has parallel semantics. + + Raises: + ValueError: If the dimensions with parallel semantics do not form a prefix + of the grid. + """ + mosaic_params = _get_mosaic_params(compiler_params) + if mosaic_params.dimension_semantics is None: + return (False,) * num_dimensions_in_grid + result = tuple(ds in ('parallel', mosaic_core.PARALLEL) + for ds in mosaic_params.dimension_semantics) + for ds0, ds1 in zip(result[:-1], result[1:]): + if ds1 and not ds0: + raise ValueError( + 'Dimensions with parallel semantics must form a prefix of the grid.' + ) + return result + + +def _get_parallel_subgrid_size( + parallel_semantics_per_dim: tuple[bool, ...], grid: tuple[int, ...] +) -> int: + """Returns the size of the subgrid along the parallel dimensions.""" + return math.prod( + dim_size if parallel_dim else 1 + for dim_size, parallel_dim in zip(grid, parallel_semantics_per_dim) + ) + +_GridPointCoordinatesPerDim = tuple[Array, ...] + +def _get_randomized_grid_coordinates( + grid: tuple[int, ...], + compiler_params: dict[str, Any], + random_seed: int | None, +) -> _GridPointCoordinatesPerDim: + """Returns a tuple of randomized coordinates for each 'parallel' dimension in `grid`. + + For a dimension with 'parallel' semantics at position `d` in the grid, the + returned tuple contains a random permutation of the sequence `[0,..., + grid[d] - 1]` at index `d`. For each dimension with 'arbitrary' semantics, + the resulting tuple contains an empty array. (Inserting an empty array for an + 'arbitrary' dimension at position `d` in the grid, instead of the sequence + `[0,..., grid[d] - 1]`, allows `grid[d]` to be a dynamic value, i.e. a value + not known at Jax trace time.) + + Args: + grid: Tuple of sizes of the dimensions in the grid. + compiler_params: Representation of a `mosaic_core.CompilerParams` object + as a dictionary. + parallel_semantics_per_dim: A tuple of booleans indicating whether the + corresponding dimension in the grid has parallel semantics. + random_seed: The seed to use for randomizing coordinates in parallel + dimensions. + """ + parallel_semantics_per_dim = _get_parallel_dim_semantics( + compiler_params, len(grid) + ) + + key = jax.random.key(random_seed or 0) + grid_point_coordinates = [] + for dim_size, parallel_dim in zip(grid, parallel_semantics_per_dim): + if parallel_dim: + # The size of a dimension with `parallel` semantics must be known at Jax + # trace time. This ensures that the arguments to `jnp.arange` and + # `jax.random.permutation` below are valid. + dim_size = jax_core.concrete_or_error(None, dim_size) + + coordindates_along_dim = jnp.arange(dim_size, dtype=jnp.int32) + key, subkey = jax.random.split(key) + coordindates_along_dim = jax.random.permutation( + subkey, coordindates_along_dim + ) + grid_point_coordinates.append(coordindates_along_dim) + else: + grid_point_coordinates.append(jnp.array((), dtype=jnp.int32)) + + return tuple(grid_point_coordinates) + +# TODO(sharadmv, jburnim): add support for memory space constraints +remove_memory_space_p = jax_core.Primitive('remove_memory_space') + +@remove_memory_space_p.def_abstract_eval +def _remove_memory_space_abstract_eval(x): + if isinstance(x, pallas_core.ShapedArrayWithMemorySpace): + if ( + x.memory_space is None + or x.memory_space is pallas_core.MemorySpace.ANY + or x.memory_space is mosaic_core.MemorySpace.HBM + ): + return jax_core.ShapedArray(x.shape, x.dtype) + raise NotImplementedError(f'Unsupported memory space: {x.memory_space}') + return x + +@remove_memory_space_p.def_impl +def _remove_memory_space_impl(x): + return x + +def _remove_memory_space_lowering(_, x): + return [x] +mlir.register_lowering(remove_memory_space_p, _remove_memory_space_lowering) + + +def _get_grid_point( + loop_indices: tuple[Array, ...], + grid_point_coordinates: _GridPointCoordinatesPerDim, +) -> Array: + """Indexes each entry in `grid_point_coordinates` with the corresponding entry in `loop_indices`. + + If an entry in `grid_point_coordinates` is an empty array, the corresponding + entry in the returned array is the corresponding entry in `loop_indices`. + Otherwise, the returned array contains the entry in `grid_point_coordinates` + indexed with the corresponding entry in `loop_indices`. + + Args: + loop_indices: A tuple of loop indices. + grid_point_coordinates: A tuple of coordinate arrays for each dimension in + the grid. Dimensions with 'arbitrary' semantics are represented by empty + arrays. Dimensions with 'parallel' semantics are represented by arrays of + randomized coordinates. + + Returns: + A 1-dimensional array containing the coordinates for the grid point + corresponding to the specified `loop_indices`. + """ + grid_point = [] + for li, coords in zip(loop_indices, grid_point_coordinates): + grid_point.append(li if jnp.size(coords) == 0 else coords[li]) + return jnp.array(grid_point, dtype=np.int32) + + +def get_interpret_effects(): + return {callback._OrderedIOEffect} + + +def interpret_pallas_call( + *args, + jaxpr: jax_core.Jaxpr, + debug: bool, + input_output_aliases: tuple[tuple[int, int], ...], + grid_mapping: pallas_core.GridMapping, + mesh: pallas_core.Mesh | None, + compiler_params: dict[str, Any], + cost_estimate: pallas_core.CostEstimate, + out_avals: tuple[jax_core.AbstractValue, ...], + interpret_params: InterpretParams, + metadata: frozen_dict.FrozenDict[str, str] | None, + name: str | None, +): + del debug, cost_estimate, out_avals, name + del metadata # TODO(sharadmv): Add metadata to HLO. + + if isinstance(mesh, mosaic_core.TensorCoreMesh): + # As a convenience for users, if we are interpreting a pl.core_map over a + # TensorCoreMesh, we automatically set the number of cores per device so + # that users don't have to specify it in the InterpretParams. + assert len(mesh.shape) == 1 + interpret_params = dataclasses.replace( + interpret_params, num_cores_or_threads=mesh.devices.shape[0] + ) + + args = [remove_memory_space_p.bind(a) for a in args] + # args contains: *dynamic_grid_sizes, *index, *inputs. (No consts?) + dynamic_grid_args, scalars, input_args = split_list( + args, + [grid_mapping.num_dynamic_grid_bounds, grid_mapping.num_index_operands], + ) + dynamic_grid_args_iter = iter(dynamic_grid_args) + grid = tuple( + a if a is not pallas_core.dynamic_grid_dim + else next(dynamic_grid_args_iter) + for a in grid_mapping.grid + ) + assert next(dynamic_grid_args_iter, None) is None + + axis_sizes = jax_core.get_axis_env().axis_sizes + num_devices = functools.reduce( + jnp.multiply, axis_sizes.values(), jnp.int32(1)) + axis_indices = {k: lax.axis_index(k) for k in axis_sizes.keys()} + device_id = interpret_utils.device_coords_to_logical_id( + tuple(axis_indices.values()), axis_sizes, axis_indices + ) + callback.io_callback( + functools.partial( + _initialize_shared_memory, interpret_params=interpret_params + ), + (), + device_id, + num_devices, + interpret_params.num_cores_per_device, + ordered=True, + ) + + # Pad input arguments. + is_squeeze_dim = [ + tuple(isinstance(b, pallas_core.Squeezed) for b in bm.block_shape) + for bm in grid_mapping.block_mappings + ] + block_shapes = [ + pallas_core._get_block_shape(bm.block_shape) + for bm in grid_mapping.block_mappings + ] + num_inputs = grid_mapping.num_inputs + input_args = [ + interpret_params.pad_to_block_dimension(a, bs) + for a, bs in zip(input_args, block_shapes[:num_inputs]) + ] + + # Allocate HBM buffers for pallas_call inputs. + # + # TODO(jburnim): As an optimization, skip allocating buffers for inputs that + # are neither aliased nor passed to the kernel in HBM? + input_buffer_ids = [] + for i, var in enumerate( + jaxpr.invars[grid_mapping.num_index_operands:][:grid_mapping.num_inputs]): + assert var.aval.dtype == input_args[i].dtype + input_buffer_ids.append( + callback.io_callback( + _allocate_buffer, + jax.ShapeDtypeStruct((), jnp.int16), + device_id, + None, # local_core_id + TPU_MEMORY_SPACE_IDXS[pallas_core.MemorySpace.ANY], + input_args[i], + ordered=True, + ) + ) + + # Allocate buffers in HBM for pallas_call outputs. + oi_alias_map = {v: k - len(scalars) for k, v in input_output_aliases} + if any(i < 0 for i in oi_alias_map.keys()): + raise ValueError('Aliasing of scalar prefetch arguments is not currently ' + 'supported in TPU interpret mode.') + output_buffer_ids = [] + output_buffer_shapes = [] + output_vals = [] + num_outputs = grid_mapping.num_outputs + output_block_shapes = block_shapes[num_inputs : num_inputs + num_outputs] + for i, bm in enumerate(grid_mapping.block_mappings_output): + if i in oi_alias_map: + # Reuse the HBM buffer for the aliased pallas_call input. + output_buffer_ids.append(input_buffer_ids[oi_alias_map[i]]) + output_buffer_shapes.append(input_args[oi_alias_map[i]].shape) + output_vals.append(input_args[oi_alias_map[i]]) + else: + out_val = interpret_params.get_uninitialized_array( + bm.array_aval.shape, bm.array_aval.dtype + ) + padded_val = interpret_params.pad_to_block_dimension( + out_val, output_block_shapes[i] + ) + output_buffer_ids.append( + callback.io_callback( + _allocate_buffer, + jax.ShapeDtypeStruct((), jnp.int16), + device_id, + None, # local_core_id + TPU_MEMORY_SPACE_IDXS[pallas_core.MemorySpace.ANY], + padded_val, + ordered=True, + ) + ) + output_buffer_shapes.append(padded_val.shape) + output_vals.append(out_val) + + # Allocate buffers for non-HBM kernel arguments (e.g., scalars, inputs, + # outputs, scratch). + scalar_buffer_ids = [] + for var, val in zip(jaxpr.invars[grid_mapping.slice_index_ops], scalars): + assert var.aval.shape == val.shape + assert var.aval.dtype == val.dtype + scalar_buffer_ids.append( + callback.io_callback( + _allocate_buffer, + jax.ShapeDtypeStruct((), jnp.int16), + device_id, + None, # local_core_id, + TPU_MEMORY_SPACE_IDXS[mosaic_core.MemorySpace.SMEM], + val, + ordered=True, + ) + ) + + kernel_buffer_ids = scalar_buffer_ids.copy() + for i, var in enumerate(jaxpr.invars[grid_mapping.num_index_operands:]): + output_idx = i - grid_mapping.num_inputs + is_input = i < grid_mapping.num_inputs + is_output = (output_idx >= 0) and (output_idx < grid_mapping.num_outputs) + if var.aval.memory_space == mosaic_core.MemorySpace.SEMAPHORE: + kernel_buffer_ids.append( + callback.io_callback( + _allocate_semaphores, + jax.ShapeDtypeStruct(var.aval.shape, jnp.int16), + device_id, + None, # local_core_id + var.aval.shape, + ordered=True, + ) + ) + elif _is_any(var.aval.memory_space): + # Use the already-allocated HBM input or output buffer. + # + # TODO(jburnim): For kernel args in HBM, check that block shape equals the + # shape of the corresponding pallas_call input, and that the index_map + # is trivial. + assert is_input ^ is_output + if is_input: + kernel_buffer_ids.append(input_buffer_ids[i]) + if is_output: + kernel_buffer_ids.append(output_buffer_ids[output_idx]) + else: + kernel_buffer_ids.append( + callback.io_callback( + _allocate_buffer, + jax.ShapeDtypeStruct((), jnp.int16), + device_id, + None, # local_core_id, + TPU_MEMORY_SPACE_IDXS[var.aval.memory_space], + interpret_params.get_uninitialized_array( + var.aval.shape, var.aval.dtype + ), + ordered=True, + ) + ) + + if _get_mosaic_params(compiler_params).collective_id is None: + # The kernel doesn't specify its own barrier semaphore, so we do a global + # barrier before running the first iteration of the kernel. + callback.io_callback(_barrier, (), device_id, ordered=True) + + _, input_ids, kernel_output_ids, _ = split_list( + kernel_buffer_ids, + [grid_mapping.num_index_operands, num_inputs, grid_mapping.num_outputs]) + input_vars, output_vars = split_list( + jaxpr.invars[grid_mapping.slice_block_ops], [num_inputs]) + + if grid: + num_iterations = functools.reduce(jnp.multiply, grid) # type: ignore[arg-type] + else: + # Base case is always one iteration when grid is () + num_iterations = 1 + + if isinstance(mesh, mosaic_core.TensorCoreMesh): + # We are interpreting a pl.core_map over a TensorCoreMesh, so we use a + # fixed division of the grid between cores, instead of a random division. + randomized_grid_coordinates = (jnp.array((), dtype=jnp.int32),) * len(grid) + else: + randomized_grid_coordinates = _get_randomized_grid_coordinates( + grid, compiler_params, interpret_params.random_seed # type: ignore[arg-type] + ) + + parallel_dim_semantics = _get_parallel_dim_semantics( + compiler_params, len(grid) + ) + parallel_subgrid_size = _get_parallel_subgrid_size( + parallel_dim_semantics, grid # type: ignore[arg-type] + ) + num_points_in_parallel_subgrid_per_core = ( + parallel_subgrid_size + interpret_params.num_cores_per_device - 1 + ) // interpret_params.num_cores_per_device # We round up here. + num_iterations_per_point_in_parallel_subgrid = ( + # This is evenly divisible. + num_iterations // parallel_subgrid_size # type: ignore[operator] + ) + num_iterations_per_core = ( + num_points_in_parallel_subgrid_per_core + * num_iterations_per_point_in_parallel_subgrid + ) + def _get_local_grid_env(grid_point): + if grid_mapping.local_grid_env is not None: + return grid_mapping.local_grid_env(grid_point, grid) + else: + return tuple( + pallas_core.GridAxis(idx, b) + for dim, (idx, b) in enumerate(zip(grid_point, grid)) + if dim not in grid_mapping.vmapped_dims + ) + + def _execute_grid_for_core(core_index): + # NOTE: We assume here that all parallel dimensions appear before all + # arbitrary dimensions in the grid. (We will have raised an error earlier + # if this is not the case.) + # + # TODO(jburnim): Are we overusing nested local functions here? + initial_iteration_idx = core_index * num_iterations_per_core + loop_bound = jnp.minimum( + (core_index + 1) * num_iterations_per_core, num_iterations) + + def _body( + carry: tuple[ + jnp.int32, + tuple[jnp.int32, ...], + jnp.ndarray, + list[jnp.ndarray], + list[jnp.ndarray], + list[jnp.ndarray], + ], + ) -> tuple[ + jnp.int32, + tuple[jnp.int32, ...], + jnp.ndarray, + list[jnp.ndarray], + list[jnp.ndarray], + list[jnp.ndarray], + ]: + """Performs one execution of the kernel body. + + Execution of `jaxpr` is preceded by reading kernel input buffers and + followed by writing kernel output buffers. + + Args: + carry: (iteration_idx, loop_idx, grid_point, prev_start_indices, + cur_start_indices). + - iteration_idx: the iteration index. + - loop_idx: internal indices for looping over the grid. + - grid_point: the current positions along all axes of the grid. + - prev_start_indices: a rank-1 array that contains the start indices + for the slices of inputs and outputs processed in the previous loop + iteration. + - cur_start_indices: a rank-1 array that contains the start indices + for the slices of inputs and outputs processed in the current loop + iteration. + + Note that by carrying the previous *and* current start indices between + loop iterations, it suffices to compute only one list of start indices, + i.e. `next_start_indices` (see below), per iteration. + + Returns: + The carry for the next iteration. + """ + ( + iteration_idx, + loop_idx, + grid_point, + prev_start_indices, + cur_block_indices, + cur_start_indices, + ) = carry + if interpret_params.grid_point_recorder is not None: + callback.io_callback( + interpret_params.grid_point_recorder, + (), + grid_point, + core_index, + ) + + with pallas_core.grid_env(_get_local_grid_env(grid_point)): + next_loop_idx = _get_next_indices(grid, loop_idx) + next_grid_point = _get_grid_point( + next_loop_idx, randomized_grid_coordinates + ) + next_block_indices, next_start_indices = zip(*[ + _compute_start_indices( + bm, + next_grid_point, + *scalar_buffer_ids, + axis_sizes=axis_sizes, + mesh=mesh, + axis_indices=axis_indices, + device_id=device_id, + local_core_id=core_index, + compiler_params=compiler_params, + interpret_params=interpret_params, + ) + for bm in grid_mapping.block_mappings + ]) + if jaxpr.debug_info.arg_names is not None: + input_names, output_names = split_list( + jaxpr.debug_info.arg_names[grid_mapping.slice_block_ops], [num_inputs]) + else: + input_names = ["unknown",] * grid_mapping.num_inputs + output_names = ["unknown",] * grid_mapping.num_outputs + + # Copy slices of the input to the kernel buffers. + def _store_slice_to_kernel_input(index, input_var): + # Copy from the HBM buffer for the pallas_call input to the kernel + # input buffer. + # TODO(jburnim): Just use input_args[j] when the input is not aliased? + transform = indexing.NDIndexer( + indices=tuple( + indexing.ds(st, sz) if not iid else st + for st, sz, iid in zip( + cur_start_indices[index], + block_shapes[index], + is_squeeze_dim[index], + ) + ), + shape=input_args[index].shape, + int_indexer_shape=(), + ) + sliced_val = callback.io_callback( + # TODO(jburnim): Pass source_info from the pallas_call, in case this + # read is involved in a data race. + functools.partial(get, input_name=input_names[index]), + jax.ShapeDtypeStruct(input_var.aval.shape, input_var.aval.dtype), + device_id, + core_index, + TPU_MEMORY_SPACE_IDXS[pallas_core.MemorySpace.ANY], + input_buffer_ids[index], + (transform,), + cur_block_indices[index], + grid_point, + ordered=True, + ) + callback.io_callback( + # TODO(jburnim): Pass source_info from the pallas_call, in case this + # store is involved in a data race. + store, + (), + device_id, + core_index, + TPU_MEMORY_SPACE_IDXS[input_var.aval.memory_space], + input_ids[index], + (), + sliced_val, + ordered=True, + ) + + for j, var in enumerate(input_vars): + if _is_any(var.aval.memory_space): + continue + assert len(cur_start_indices[j].shape) == 1 + assert len(prev_start_indices[j].shape) == 1 + jax.lax.cond( + (iteration_idx == initial_iteration_idx) + | jax.lax.reduce_or( + cur_start_indices[j] != prev_start_indices[j], axes=(0,) + ), + functools.partial(_store_slice_to_kernel_input, j, var), + lambda: None, + ) + + # Invoke the kernel. + _interpret_jaxpr( + jaxpr, + *kernel_buffer_ids, + axis_sizes=axis_sizes, + mesh=mesh, + axis_indices=axis_indices, + device_id=device_id, + local_core_id=core_index, + compiler_params=compiler_params, + interpret_params=interpret_params, + ) + + # Copy from the kernel buffers to slices of the output in HBM. + def _store_to_output_buffer(index, output_var, transform): + kernel_output_val = callback.io_callback( + # TODO(jburnim): Pass source_info from the pallas_call, in case this + # get is involved in a data race. + get, + output_var.aval, + device_id, + core_index, + TPU_MEMORY_SPACE_IDXS[output_var.aval.memory_space], + kernel_output_ids[index], + (), + ordered=True, + ) + callback.io_callback( + # TODO(jburnim): Pass source_info from the pallas_call, in case this + # store is involved in a data race. + functools.partial(store, output_name=output_names[index]), + (), + device_id, + core_index, + TPU_MEMORY_SPACE_IDXS[pallas_core.MemorySpace.ANY], + output_buffer_ids[index], + (transform,), + kernel_output_val, + cur_block_indices[num_inputs + index], + grid_point, + ordered=True, + ) + + output_slices : list[Any] = [] + for j, var in enumerate(output_vars): + if _is_any(var.aval.memory_space): + output_slices.append(None) + continue + assert len(cur_start_indices[num_inputs + j].shape) == 1 + assert len(next_start_indices[num_inputs + j].shape) == 1 + transform = indexing.NDIndexer( + indices=tuple( + indexing.ds(st, sz) if not iid else st # type: ignore[misc] + for st, sz, iid in zip( + cur_start_indices[num_inputs + j], + block_shapes[num_inputs + j], + is_squeeze_dim[num_inputs + j], + ) + ), + shape=output_vals[j].shape, + int_indexer_shape=(), + ) + if j in oi_alias_map: + # Suppress revisiting check for output buffers that are aliased to + # input buffers. + output_slices.append(None) + else: + output_slices.append((transform,)) + jax.lax.cond( + (iteration_idx + 1 == loop_bound) + | jax.lax.reduce_or( + cur_start_indices[num_inputs + j] + != next_start_indices[num_inputs + j], + axes=(0,), + ), + functools.partial(_store_to_output_buffer, j, var, transform), + lambda: None, + ) + callback.io_callback( + _check_for_revisiting, + (), + device_id, + core_index, + loop_idx, + output_slices, + ordered=True, + ) + + return ( + iteration_idx + 1, + next_loop_idx, + next_grid_point, + cur_start_indices, + next_block_indices, + next_start_indices, + ) + + initial_loop_idx = _get_indices(grid, initial_iteration_idx) + initial_grid_point = _get_grid_point( + initial_loop_idx, randomized_grid_coordinates) + with pallas_core.grid_env(_get_local_grid_env(initial_grid_point)): + initial_block_indices, initial_start_indices = zip(*[ + _compute_start_indices( + bm, + initial_grid_point, + *scalar_buffer_ids, + axis_sizes=axis_sizes, + mesh=mesh, + axis_indices=axis_indices, + device_id=device_id, + local_core_id=core_index, + compiler_params=compiler_params, + interpret_params=interpret_params, + ) + for bm in grid_mapping.block_mappings + ]) + + _ = lax.while_loop( + lambda carry: carry[0] < loop_bound, + _body, + ( + initial_iteration_idx, + initial_loop_idx, + initial_grid_point, + initial_start_indices, # Previous start indices are ignored on the first iteration. + initial_block_indices, + initial_start_indices, + ), + ) + + # TODO(jburnim): Should we only create happens-before here from core 0 to + # the other cores? + callback.io_callback( + _update_clocks_for_device_barrier, (), device_id, ordered=True + ) + + thread_map(_execute_grid_for_core, interpret_params.num_cores_per_device) + + # TODO(jburnim): Should we only create happens-before here from the other + # # cores to core 0? + callback.io_callback( + _update_clocks_for_device_barrier, (), device_id, ordered=True + ) + + # Read the output from the allocated output buffers. + ret = [ + callback.io_callback( + # TODO(jburnim): Pass source_info from the pallas_call, in case this + # get is involved in a data race. + get, + val, + device_id, + 0, # local_core_id + TPU_MEMORY_SPACE_IDXS[pallas_core.MemorySpace.ANY], + output_buffer_id, + ( + indexing.NDIndexer.from_indices_shape( + tuple(indexing.ds(0, s) for s in val.shape), + output_buffer_shape, + ), + ), + ordered=True, + ) + for val, output_buffer_id, output_buffer_shape in zip( + output_vals, output_buffer_ids, output_buffer_shapes + ) + ] + + callback.io_callback(_validate, (), device_id, ordered=True) + + # For now, when we're done with a pallas_call, we delete the shared memory. + # We use a barrier to ensure that all devices are done running the kernel. + # + # TODO(jburnim): Get rid of this barrier. And figure out how this should + # work if we want to invoke successive pallas_calls that use the same + # shared memory. + callback.io_callback( + _clean_up_shared_memory, (), device_id, ordered=True + ) + + return ret diff --git a/jax/_src/pallas/mosaic/interpret/race_detection_state.py b/jax/_src/pallas/mosaic/interpret/race_detection_state.py new file mode 100644 index 000000000000..ff76778119d3 --- /dev/null +++ b/jax/_src/pallas/mosaic/interpret/race_detection_state.py @@ -0,0 +1,187 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +import dataclasses +import itertools +import threading + +from jax._src import source_info_util +from jax._src.pallas.mosaic.interpret import vector_clock as vc + + +def _is_empty_slice(slice_or_idx: slice | int): + if isinstance(slice_or_idx, int) or (slice_or_idx == slice(None)): + return False + + # NOTE: All slices here will have known size. + start = int(slice_or_idx.start) if slice_or_idx.start is not None else 0 + stop = int(slice_or_idx.stop) + return start < stop + + +def _slices_overlap(slice_or_idx1: slice | int, slice_or_idx2: slice | int): + if isinstance(slice_or_idx1, int): + slice_or_idx1 = slice(slice_or_idx1, slice_or_idx1 + 1) + if isinstance(slice_or_idx2, int): + slice_or_idx2 = slice(slice_or_idx2, slice_or_idx2 + 1) + + if slice_or_idx1 == slice(None): + return _is_empty_slice(slice_or_idx2) + if slice_or_idx2 == slice(None): + return _is_empty_slice(slice_or_idx1) + + # TODO(jburnim): Handle non-zero steps. + assert (slice_or_idx1.step == 1) or (slice_or_idx1.step is None) + assert (slice_or_idx2.step == 1) or (slice_or_idx2.step is None) + + assert slice_or_idx1.start is not None + assert slice_or_idx1.stop is not None + assert slice_or_idx2.start is not None + assert slice_or_idx2.stop is not None + + # NOTE: We are only comparing slices with known stops (and sizes). + # Do we need to handle zero-length slices? + return (slice_or_idx1.start <= slice_or_idx2.start < slice_or_idx1.stop) | ( + slice_or_idx2.start <= slice_or_idx1.start < slice_or_idx2.stop + ) + + +def _ranges_overlap( + range1: tuple[slice | int, ...], range2: tuple[slice | int, ...] +) -> bool: + return all( + _slices_overlap(r1, r2) + for r1, r2 in itertools.zip_longest(range1, range2, fillvalue=slice(None)) + ) + + +@dataclasses.dataclass +class RaceDetectionState: + num_cores: int + + # (memory_space, buffer_id, device_id, local_core_id) -> [(device_id, local_core_id, VectorClock, range)] + reads: dict = dataclasses.field( + default_factory=lambda: collections.defaultdict(list) + ) + + # (memory_space, buffer_id, device_id, local_core_id) -> [(device_id, local_core_id, VectorClock, range)] + writes: dict = dataclasses.field( + default_factory=lambda: collections.defaultdict(list) + ) + + lock: threading.Lock = dataclasses.field(default_factory=threading.Lock) + + races_found: bool = False + + def check_read( + self, device_id, local_core_id, clock, buffer_key, rnge, source_info=None + ): + if source_info is not None: + user_frame = source_info_util.summarize(source_info) + else: + user_frame = 'pallas_call' + + with self.lock: + writes = self.writes[buffer_key] + num_writes = len(writes) + self.reads[buffer_key].append( + (device_id, local_core_id, clock, rnge, user_frame) + ) + + for i in range(num_writes): + ( + write_device_id, + write_local_core_id, + write_clock, + write_range, + write_frame, + ) = writes[i] + if vc.ordered(write_clock, clock): + continue + if not _ranges_overlap(rnge, write_range): + continue + # TODO(jburnim): When printing device IDs for reads/writes, distinguish + # between real device IDs vs. DMA IDs. + print( + f'RACE DETECTED\n read of {buffer_key}[{rnge}] from {device_id},' + f' {local_core_id}, {user_frame}\n clock: {clock}\n write of' + f' {buffer_key}[{write_range}] from {write_device_id},' + f' {write_local_core_id} {write_frame}\n clock: {write_clock}\n' + ) + with self.lock: + self.races_found = True + return + + def check_write( + self, device_id, local_core_id, clock, buffer_key, rnge, source_info=None + ): + if source_info is not None: + user_frame = source_info_util.summarize(source_info) + else: + user_frame = 'pallas_call' + + with self.lock: + writes = self.writes[buffer_key] + reads = self.reads[buffer_key] + num_writes = len(writes) + num_reads = len(reads) + self.writes[buffer_key].append((device_id, local_core_id, clock, rnge, user_frame)) + + # TODO(jburnim): For performance, we should also probably remove any + # conflicting reads and writes that happened-before the current write. + + for i in range(num_writes): + ( + write_device_id, + write_local_core_id, + write_clock, + write_range, + write_frame, + ) = writes[i] + if vc.ordered(write_clock, clock): + continue + if not _ranges_overlap(rnge, write_range): + continue + # TODO(jburnim): When printing device IDs for reads/writes, distinguish + # between real device IDs vs. DMA IDs. + print( + f'RACE DETECTED\n write of {buffer_key}[{rnge}] from {device_id},' + f' {local_core_id}, {user_frame}\n clock: {clock}\n write of' + f' {buffer_key}[{write_range}] from {write_device_id},' + f' {write_local_core_id}, {write_frame}\n clock: {write_clock}\n' + ) + with self.lock: + self.races_found = True + break + + for i in range(num_reads): + read_device_id, read_local_core_id, read_clock, read_range, read_frame = ( + reads[i] + ) + if vc.ordered(read_clock, clock): + continue + if not _ranges_overlap(rnge, read_range): + continue + # TODO(jburnim): When printing device IDs for reads/writes, distinguish + # between real device IDs vs. DMA IDs. + print( + f'RACE DETECTED\n write of {buffer_key}[{rnge}] from {device_id},' + f' {local_core_id}, {user_frame}\n clock: {clock}\n read of' + f' {buffer_key}[{read_range}] from {read_device_id},' + f' {read_local_core_id}, {read_frame}\n clock: {read_clock}\n' + ) + with self.lock: + self.races_found = True + return diff --git a/jax/_src/pallas/mosaic/interpret/shared_memory.py b/jax/_src/pallas/mosaic/interpret/shared_memory.py new file mode 100644 index 000000000000..21fd6600928a --- /dev/null +++ b/jax/_src/pallas/mosaic/interpret/shared_memory.py @@ -0,0 +1,591 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import collections +from collections.abc import Sequence +import dataclasses +import gc +import threading +from typing import Any, Callable, Literal + +from jax._src.pallas.mosaic.interpret import vector_clock as vc +import numpy as np + + +class Semaphore: + + def __init__( + self, + shared_memory: SharedMemory, + semaphore_id: int, + ): + self.shared_memory = shared_memory + self.id: int = semaphore_id + + # TODO(jburnim): Use one Condition variable per device. (Which will be + # easier to do when we're using single integer device IDs.) + self.cv = threading.Condition() + + self.count_by_core = np.zeros(self.shared_memory.num_cores, dtype=np.int32) + + if self.shared_memory.detect_races: + # We associate a vector clock with each count in self.counts. Whenever + # self.count_by_core[i] is signaled, self.clocks[i] is updated with the + # vector clock of the signaling core. Whenever core i successfully waits + # on self.count_by_core[i], the vector clock of core i is updated with + # self.clocks[i]. + # + # TODO(jburnim): Model happens-before more precisely for the case where + # semaphores are over-signaled. + self.clocks: list[vc.VectorClock | None] = [ + None + ] * self.shared_memory.num_cores + + @property + def num_cores(self) -> int: + return self.shared_memory.num_cores + + @property + def detect_races(self) -> bool: + return self.shared_memory.detect_races + + @property + def dma_execution_mode(self) -> str: + return self.shared_memory.dma_execution_mode + + def get_global_core_id(self, device_id: int, local_core_id: int) -> int: + return self.shared_memory.get_global_core_id(device_id, local_core_id) + + def signal(self, inc, global_core_id, clock): + """Signal the semaphore on `(device_id, core_id)` by `inc`. + + Args: + inc: A positive integer. The amount by which to increment the semaphore + on the target device. + global_core_id: The ID of the target core. + clock: The vector clock of the signaling device at the time of the signal. + """ + global_core_id = int(global_core_id) + with self.cv: + self.count_by_core[global_core_id] += inc + if self.shared_memory.detect_races: + if self.clocks[global_core_id] is None: + self.clocks[global_core_id] = vc.copy_vector_clock(clock) + else: + vc.update_vector_clock(self.clocks[global_core_id], clock) + self.cv.notify_all() + + def read(self, global_core_id): + with self.cv: + return self.count_by_core[global_core_id] + + def wait(self, value, global_core_id, *, has_tasks=False): + global_core_id = int(global_core_id) + + # TODO(jburnim): + # - If the count is larger than value, raise an error? + # - If the count is equal to value, but there DMAs waiting to signal us, + # raise an error? + + # Simple implementation for semaphores that have no tasks that can signal + # them. + clock = None + if not has_tasks: + with self.cv: + while self.count_by_core[global_core_id] < value: + self.cv.wait() + self.count_by_core[global_core_id] -= value + if self.detect_races: + assert self.clocks[global_core_id] is not None + clock = vc.copy_vector_clock(self.clocks[global_core_id]) + if self.detect_races: + with self.shared_memory.lock: + assert clock is not None + vc.update_vector_clock( + self.shared_memory.clocks[global_core_id], clock + ) + return + + # TODO(nrink): Update the comment below to generalize from DMAs and DMA + # semaphores. We now have the concept of 'tasks' that can signal a + # semaphore. At the moment, DMAs are the only tasks that occur; and what is + # allowed to be a task may still change (because it should probably be more + # restricted than allowing tasks to be arbitrary callables, as is currently + # done). + # + # For DMA semaphores (when shared_memory.dma_execution_mode=='on_wait'), + # while our count is not large enough we will select and partially execute + # pending DMAs until our count is large enough. + # + # This approach will tend to run DMAs as late as possible, as well as + # out-of-order. This approach also lets us avoid the complexity of spinning + # up separate threads to handle executing DMAs. + while True: + clock = None + with self.cv: + if self.count_by_core[global_core_id] >= value: + self.count_by_core[global_core_id] -= value + if self.detect_races: + assert self.clocks[global_core_id] is not None + clock = vc.copy_vector_clock(self.clocks[global_core_id]) + else: + return + if clock is not None: + with self.shared_memory.lock: + vc.update_vector_clock( + self.shared_memory.clocks[global_core_id], clock + ) + return + + with self.shared_memory.lock: + task_queue = self.shared_memory.tasks_by_sem[(self.id, global_core_id)] + if len(task_queue) > 0: + task = task_queue.pop() + else: + continue + + task() + + +# A `SemaphoreTask` is called when a semaphore is waiting to be signalled on a +# specific core. A `SemaphoreTask` will typically capture the `Semaphore` object +# that is waiting, so that when the task is called, it can signal the semaphore +# (by calling `Semaphore.signal` from within the task). When a `SemaphoreTask` +# object is called, it can be assumed that the call stack of the task will +# *not* hold the lock on the shared memory in the captured `Semaphore` object. +# This allows the task to use methods from `SharedMemory` to access and modify +# the global shared memory object. +SemaphoreTask = Callable[[], None] + + +@dataclasses.dataclass +class Buffer: + content: np.ndarray + _: dataclasses.KW_ONLY + ref_count: int = 1 + + def decrease_ref_count(self): + # We should never decrease the `ref_count` to below zero. + assert self.ref_count > 0 + self.ref_count -= 1 + + def has_zero_ref_count(self) -> bool: + return self.ref_count == 0 + + def size(self) -> int: + return self.content.itemsize * self.content.size + + +@dataclasses.dataclass(frozen=True) +class ShapeAndDtype: + shape: Sequence[int] + dtype: np.dtype + + def __iter__(self): + return iter((self.shape, self.dtype)) + + +@dataclasses.dataclass +class SharedMemory: + num_devices: int + num_cores_per_device: int + out_of_bounds_reads: str + dma_execution_mode: str + uninitialized_memory: Literal["nan", "zero"] + detect_races: bool + vector_clock_size: int + + clocks: list[vc.VectorClock] + barrier: threading.Barrier + clean_up_barrier: threading.Barrier + + # (memory_space, buffer_id, device_id, local_core_id) -> NumPy array + mem: dict[tuple[str, int, int, int], Buffer] = dataclasses.field( + default_factory=dict + ) + + # semaphore_id -> Semaphore + sem: dict[int, Semaphore] = dataclasses.field(default_factory=dict) + + # (semaphore_id, global_core_id) + # -> tasks that will signal the semaphore on the core with the given ID and + # that should therefore be considered for execution when the semaphore is + # waiting (to be signalled). + tasks_by_sem: dict[tuple[int, int], list[SemaphoreTask]] = dataclasses.field( + default_factory=lambda: collections.defaultdict(list) + ) + + lock: threading.Lock = dataclasses.field(default_factory=threading.Lock) + + # (device_id, local_core_id) -> next buffer ID + next_buffer_id: dict[tuple[int, int], int] = dataclasses.field( + default_factory=lambda: collections.defaultdict(lambda: 100) + ) + # global_core_id -> next semaphore ID + next_semaphore_id: dict[int, int] = dataclasses.field( + default_factory=lambda: collections.defaultdict(lambda: 2000) + ) + + deallocated_bytes: int = 0 + + # (device_id, local_core_id) -> [(grid_index, [range])] + output_ranges: dict[tuple[int, int], list] = dataclasses.field( + default_factory=lambda: collections.defaultdict(list) + ) + + # semaphore_id -> Semaphore, where the semaphore_id is user-specified. + fixed_id_sem: dict[int, Semaphore] = dataclasses.field( + default_factory=dict + ) + + @property + def num_cores(self) -> int: + return self.num_devices * self.num_cores_per_device + + def get_global_core_id(self, device_id: int, local_core_id: int) -> int: + """Computes the global core ID from the given device and local core ID.""" + return device_id * self.num_cores_per_device + local_core_id + + def get_global_core_ids(self, device_id: int) -> Sequence[int]: + """Computes the global core IDs for all cores in the given device.""" + return tuple( + self.get_global_core_id(device_id, core_id) + for core_id in range(self.num_cores_per_device) + ) + + def append_semaphore_task( + self, + semaphore_id: int, + global_core_id: int, + task: SemaphoreTask, + ): + """Appends a task to be executed if the semaphore with the given sempahore ID is waiting to be signalled on the core with the given global core ID.""" + with self.lock: + self.tasks_by_sem[(semaphore_id, global_core_id)].append(task) + + def get_random_virtual_device_id(self) -> int: + # Virtual device IDs are needed for DMAs. Conceptually, each DMA runs on its + # own, independent device. Representing this precisely would require vector + # clocks to have sizes linear in the number of DMAs. + # + # Instead, we use approximate vector clocks of fixed size. We assign each + # DMA a virtual core ID in the range + # + # [num_cores, self.vector_clock_size - 1], + # + # and each operation of a DMA increments the corresponding coordinate in its + # vector clock. (So the "virtual" part of a vector clock is effectively + # counting, for each virtual core, the number of DMAs that happened-before + # the vector clock and were assigned to that virtual core.) + # + # If two approximate clocks are unordered, then their corresponding events + # are not ordered by the happens-before relation. So this approximation will + # not introduce any false positives in detecting data races. But we may fail + # to detect some true data races because there can be cases where two + # approximate clocks are ordered, and we will treat the corresponding events + # as ordered by the happens-before relation, but the corresponding events + # are not actually ordered. + return np.random.randint(self.num_cores, self.vector_clock_size) + + def print(self, device_id: int): + device_id = int(device_id) + if device_id == 0: + with self.lock: + print(self.mem) + + def get_semaphores_and_increment_clock( + self, sem_ids: Sequence[int | None], global_core_id: int + ) -> tuple[list[Semaphore | None], vc.VectorClock | None]: + """Returns the semaphores with the given `sem_ids` and increments the vector clock for the core with `global_core_id`. + + If race detection is enabled, this method increments the vector clock for + the core with the given `global_core_id` (while holding the lock on `self`). + We do this so that we can associate a (vector clock) time with the shared + memory operation of looking up the semaphores, which in turn can be used as + a proxy for the time when the returned semaphores are used by the client of + the `SharedMemory` class without acquiring the lock on `self`. (For the + purpose of encapsulation, we prefer to think of `self.lock` as a private + attribute of the `SharedMemory` class; hence clients of the class should not + attempt to acquire this lock explicitly.) + + Args: + sem_ids: The IDs of the semaphores to return or None. + global_core_id: The ID of the core whose vector clock should be + incremented (if race detection is enabled). + + Returns: + - The semaphores with the given `sem_ids` or None if the corresponding + entry in `sem_ids` is None. + - The incremented vector clock for the core with the given + `global_core_id`, or None if race detection is not enabled. + """ + clock = None + with self.lock: + if self.detect_races: + vc.inc_vector_clock(self.clocks[global_core_id], global_core_id) + clock = vc.copy_vector_clock(self.clocks[global_core_id]) + + sems = [] + for sem_id in sem_ids: + if sem_id is None: + sem = None + elif sem_id in self.fixed_id_sem: + if sem_id in self.sem: + # TODO(nrink): For now we make it the responsibility of the client to + # ensure that fixed-ID semaphores do not collide with internal + # semaphore IDs. + raise ValueError( + f'Semaphore {sem_id} occurs as both fixed-id and internal.' + ) + sem = self.fixed_id_sem[sem_id] + else: + sem = self.sem[sem_id] + sems.append(sem) + + return sems, clock + + def get_sempahores_with_nonzero_count( + self, device_id: int + ) -> list[tuple[Semaphore, int]]: + """Returns tuples (semaphore, global_core_id) for all semaphores with a nonzero count for the core with `global_core_id`.""" + result = [] + with self.lock: + for _, sem in self.sem.items() | self.fixed_id_sem.items(): + with sem.cv: + for gci in self.get_global_core_ids(device_id): + if sem.count_by_core[gci] != 0: + result.append((sem, gci)) + return result + + def get_next_buffer_id(self, device_id: int, local_core_id: int) -> int: + """Returns the next buffer ID for the given device and local core ID.""" + with self.lock: + buffer_id = self.next_buffer_id[(device_id, local_core_id)] + self.next_buffer_id[(device_id, local_core_id)] = buffer_id + 1 + return buffer_id + + def allocate_buffer( + self, + key: Any, + ref_count: int, + value: np.ndarray, + ): + """Allocates a memory buffer with the given key unless it already exists.""" + with self.lock: + if key not in self.mem: + self.mem[key] = Buffer(value, ref_count=ref_count) + + def deallocate_buffer(self, key: Any): + """Decreases the ref count for the buffer with `key` and deallocates the buffer if the ref count is zero.""" + with self.lock: + buff = self.mem[key] + buff.decrease_ref_count() + if buff.has_zero_ref_count(): + self.mem.pop(key) + self.deallocated_bytes += buff.size() + del buff + + should_collect = self.deallocated_bytes > 100_000_000 + if should_collect: + self.deallocated_bytes = 0 + + if should_collect: + # Periodic garbage collection here prevents OOMs -- although it's not clear + # why arrays are not getting freed without this. + gc.collect() + + def allocate_semaphores(self, key: Any, num_semaphores: int) -> int: + """Returns the next semaphore ID and ensures that the next `num_semaphores` are allocated.""" + with self.lock: + semaphore_id = self.next_semaphore_id[key] + self.next_semaphore_id[key] = semaphore_id + num_semaphores + + for i in range(semaphore_id, semaphore_id + num_semaphores): + if i not in self.sem: + self.sem[i] = Semaphore(shared_memory=self, semaphore_id=i) + + return semaphore_id + + def guarantee_semaphore_with_fixed_id(self, semaphore_id: int): + """Ensures that a semaphore with the given `semaphore_id` exists. + + If the semaphore with the given ID does not exist, it is allocated. Note + that semaphores that are allocated with this method live in their own + address space (internally, they are mapped in a separate dictionary) from + the sempahores allocated with the `allocate_sempahores` method above. + + This methods is intended to be used for barrier semaphores, where the + _collective_ semaphore ID is specified by the interpreter (i.e. by the + client of the `SharedMemory` class). This simulates sempahores that exist + prior to any Pallas kernels being run. + + Args: + semaphore_id: The ID of the semaphore to ensure exists, i.e. is allocated. + """ + with self.lock: + if semaphore_id not in self.fixed_id_sem: + self.fixed_id_sem[semaphore_id] = Semaphore( + semaphore_id=semaphore_id, shared_memory=self + ) + + def get_buffer_content( + self, key: Any, rnge: tuple[slice | int, ...], global_core_id: int + ) -> tuple[np.ndarray | None, ShapeAndDtype, vc.VectorClock | None]: + """Reads contents of a memory buffer. + + Args: + key: The key of the buffer to read. + rnge: The range to read within the buffer. + global_core_id: The global core ID of the core reading the buffer. + + Returns: + - The contents of the read range of the buffer, or None if reading out of + bounds. + - The shape and dtype of the full content array of the buffer. + - The incremented vector clock for the core with the given global core ID, + or None if race detection is not enabled. + """ + clock = None + with self.lock: + if self.detect_races: + vc.inc_vector_clock(self.clocks[global_core_id], global_core_id) + clock = vc.copy_vector_clock(self.clocks[global_core_id]) + array = self.mem[key].content + + try: + result = array[rnge].copy() + except: + result = None + + shape_and_dtype = ShapeAndDtype(array.shape, array.dtype) + return result, shape_and_dtype, clock + + def store_buffer_content( + self, + key: Any, + rnge: tuple[slice | int, ...], + value: np.ndarray, + global_core_id: int, + ) -> tuple[bool, ShapeAndDtype, vc.VectorClock | None]: + """Stores contents into a memory buffer. + + Args: + key: The key of the buffer to store into. + rnge: The range within the buffer contents that `value` is written to. + value: The array to store into the buffer. + global_core_id: The global core ID of the core writing into the buffer. + + Returns: + - True of the store was in bounds, False otherwise. + - The shape and dtype of the full content array of the buffer. + - The incremented vector clock for the core with the given global core ID, + or None if race detection is not enabled. + """ + clock = None + with self.lock: + if self.detect_races: + vc.inc_vector_clock(self.clocks[global_core_id], global_core_id) + clock = vc.copy_vector_clock(self.clocks[global_core_id]) + array = self.mem[key].content + shape_and_dtype = ShapeAndDtype(array.shape, array.dtype) + + assert array.dtype == value.dtype # TODO(jburnim): Catch this statically. + # TODO(jburnim): Better error message if this raises? + in_bounds_shape = array[rnge].shape + if in_bounds_shape == value.shape: + is_in_bounds = True + array[rnge] = value + else: + is_in_bounds = False + + return is_in_bounds, shape_and_dtype, clock + + def swap_buffer_content( + self, + key: Any, + rnge: tuple[slice | int, ...], + value: np.ndarray, + mask: np.ndarray | None, + global_core_id: int, + ) -> tuple[np.ndarray | None, ShapeAndDtype, vc.VectorClock | None]: + """Swaps contents of a memory buffer. + + Args: + key: The key of the buffer to swap into. + rnge: The range within the buffer contents that `value` is swapped into. + value: The array to be written into the buffer. + mask: The mask to apply to the swap operation. + global_core_id: The global core ID of the core writing into the buffer. + + Returns: + - The contents of the range of the buffer (prior to the swap), or None if + accessing buffer contents bounds. + - The shape and dtype of the full content array of the buffer. + - The incremented vector clock for the core with the given global core ID, + or None if race detection is not enabled. + """ + clock = None + with self.lock: + if self.detect_races: + vc.inc_vector_clock(self.clocks[global_core_id], global_core_id) + clock = vc.copy_vector_clock(self.clocks[global_core_id]) + array = self.mem[key].content + shape_and_dtype = ShapeAndDtype(array.shape, array.dtype) + + assert array.dtype == value.dtype # TODO(jburnim): Catch this statically. + # TODO(jburnim): Better error message if this raises? + raw_result = array[rnge] + in_bounds_shape = raw_result.shape + + if mask is None: + if in_bounds_shape == value.shape: + array[rnge] = value + return raw_result.copy(), shape_and_dtype, clock + else: + return None, shape_and_dtype, clock + else: + in_bounds_mask = np.full(mask.shape, True) + for i in range(len(in_bounds_shape)): + in_bounds_mask[in_bounds_shape[i] :] = False + if (~in_bounds_mask & mask).any(): + return None, shape_and_dtype, clock + else: + in_bounds_idx = tuple(slice(i) for i in in_bounds_shape) + result = value.copy() + result[in_bounds_idx] = np.where( + mask[in_bounds_idx], raw_result, value[in_bounds_idx] + ) + array[rnge] = np.where( + mask[in_bounds_idx], value[in_bounds_idx], raw_result + ) + return result.copy(), shape_and_dtype, clock + + def update_clocks(self, low_global_core_id, high_global_core_id): + """Synchronizes the vector clocks for the cores with ids in the range between the two arguments.""" + # Despite only updating the vector clocks for some cores, we still need to + # hold the global lock to ensure that no other devices are concurrently + # accessing the same vector clocks. + with self.lock: + for c in self.clocks[low_global_core_id + 1 : high_global_core_id]: + vc.update_vector_clock(self.clocks[low_global_core_id], c) + for c in self.clocks[low_global_core_id + 1 : high_global_core_id]: + vc.update_vector_clock(c, self.clocks[low_global_core_id]) + + def update_clocks_for_device_barrier(self, device_id): + """Synchronizes the vector clocks for the cores on the given device.""" + low_core_id = device_id * self.num_cores_per_device + high_core_id = (device_id + 1) * self.num_cores_per_device + self.update_clocks(low_core_id, high_core_id) diff --git a/jax/_src/pallas/mosaic/interpret/thread_map.py b/jax/_src/pallas/mosaic/interpret/thread_map.py new file mode 100644 index 000000000000..3b162a70daeb --- /dev/null +++ b/jax/_src/pallas/mosaic/interpret/thread_map.py @@ -0,0 +1,80 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from concurrent import futures +import functools + +import jax +from jax._src import callback +import jax.core as jax_core +import jax.numpy as jnp + + +def _run_jaxpr(jaxpr, consts, *args): + def _run(jaxpr, consts, *args): + jax_core.eval_jaxpr(jaxpr, consts, *args) + + traced = jax.jit(_run, static_argnums=(0,)).trace(jaxpr, consts, *args) + traced.lower().compile()(consts, *args) + return + + +def _thread_map_callback(jaxpr, num_threads, consts): + num_threads = int(num_threads) + threads = [] + with futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + for i in range(num_threads): + threads.append(executor.submit(_run_jaxpr, jaxpr, consts, jnp.int32(i))) + exceptions = [] + for i in range(num_threads): + try: + threads[i].result() + except Exception as e: + exceptions.append(e) + if exceptions: + # TODO(jburnim): Use ExceptionGroup once JAX requires Python 3.11. + # raise ExceptionGroup('Exceptions raised during _thread_map', exceptions) + raise exceptions[0] + + +def _call_threadmap_callback(jaxpr, num_threads, *consts): + # NOTE: At runtime, _thread_map_callback will lower and compile the + # given jaxpr. (JAX's caches should ensure the jaxpr is only lowered and + # compiled once.) + # + # TODO(jburnim): Would it be worth trying to lower/compile the jaxpr at + # lowering/compilation time? E.g., by using a custom primitive here, could + # we lower/compile jaxpr at lowering time, and then pass the compiled + # function to the callback? + return callback.io_callback( + functools.partial(_thread_map_callback, jaxpr), + (), + num_threads, + consts, + ordered=True, + ) + + +def thread_map(f, num_threads): + if num_threads == 1: + f(jnp.int32(0)) + return + + def _f(core_index): + f(core_index) + return () + + jaxpr = jax.make_jaxpr(_f)(jnp.int32(0)) + + _call_threadmap_callback(jaxpr.jaxpr, num_threads, *jaxpr.consts) diff --git a/jax/_src/pallas/mosaic/interpret/utils.py b/jax/_src/pallas/mosaic/interpret/utils.py new file mode 100644 index 000000000000..f9575409f745 --- /dev/null +++ b/jax/_src/pallas/mosaic/interpret/utils.py @@ -0,0 +1,346 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Sequence +import dataclasses +import math +import threading +from typing import Any, Literal + +from jax import lax +from jax._src import core as jax_core +from jax._src.pallas import primitives +from jax._src.util import safe_map +import jax.numpy as jnp +import numpy as np + + +def get_uninitialized_value( + dtype, uninitialized_memory: Literal["nan", "zero"] +): + if uninitialized_memory == "nan": + if jnp.issubdtype(dtype, jnp.floating): + return np.nan + elif jnp.issubdtype(dtype, jnp.integer): + return jnp.iinfo(dtype).max + elif jnp.issubdtype(dtype, jnp.bool): + return True + if uninitialized_memory == "zero": + return 0 + raise NotImplementedError(uninitialized_memory + " + " + str(dtype)) + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class InterpretParams: + """Parameters for kernel interpret mode. + + Interpret mode is a way to run Pallas kernels on CPU, while simulating TPU/GPU + shared memory, communication, and synchronization operations. + + Attributes: + detect_races: If True, a dynamic, happens-before race detector will be used + to detect data races during kernel interpretation. If any races are + detected, a message will be printed and `races.races_found` will be set to + True. + Default: False. + out_of_bounds_reads: If "raise", an exception will be raised on any + out-of-bounds read of a buffer. If "uninitialized_value", any parts of + the read that are out-of-bounds will return the value used to fill + uninitialized memory, which can be configured via the + "uninitialized_memory". + Default: "raise". + skip_floating_point_ops: If True, operations that produce only floating + point values will not be interpreted; instead, their results will be + replaced with arrays all of `jnp.inf`. Additionally any floating point + operands to any operation will be replaced with (arrays of) `jnp.inf`. + Default: False. + uninitialized_memory: If "nan", allocated buffers are initialized to contain + all NaNs (or to their maximum possible value for integers). If "zero", + allocated buffers are initialized to all zeros. + Default: "nan". + num_cores_or_threads: The number of cores per device (TPU) or threads per + block (GPU). Note that for interpreting GPU kernels, we currently only + support a single block in the grid. (So the number of threads per block on + the GPU can be thought of as the number of threads that runs concurrently + on the GPU.) + Default: 1. + vector_clock_size: The number of entries in the vector clocks. This should + be an integer bigger then the total number of cores, i.e. bigger than + `number of devices * num_cores_per_device`. If `None`, the vector clock + size that is used in the interpreter will default to twice the total + number of cores. + Default: None. + """ + + detect_races: bool = False + out_of_bounds_reads: Literal["raise", "uninitialized"] = "raise" + skip_floating_point_ops: bool = False + uninitialized_memory: Literal["nan", "zero"] = "nan" + num_cores_or_threads: int = 1 + vector_clock_size: int | None = None + + def __post_init__(self): + if self.num_cores_or_threads < 1: + raise ValueError( + "Number of cores or threads must be at least 1, but got" + f" {self.num_cores_or_threads}." + ) + + def get_vector_clock_size(self, num_devices) -> int: + """Returns the number of vector clocks to use.`""" + num_cores_or_threads = num_devices * self.num_cores_or_threads + if self.vector_clock_size is not None: + if num_cores_or_threads >= self.vector_clock_size: + raise ValueError( + f"Vector clock size ({self.vector_clock_size}) must be greater than" + f" the total number of cores/threads ({num_cores_or_threads})." + ) + return self.vector_clock_size + else: + # Default to twice the total number of cores/threads. + return 2 * num_cores_or_threads + + def get_uninitialized_array(self, shape, dtype): + return jnp.full( + shape, + get_uninitialized_value(dtype, self.uninitialized_memory), + dtype, + ) + + def pad_to_block_dimension(self, value, block_shape): + """Pads values so the shape evenly divides into block dimensions. + + For example, if values has a shape of (33, 2, 5) with a block_shape of + (32, 2, 4), this function will pad the value of shape to (64, 2, 8). + + Args: + value: Array to be padded. + block_shape: Block shapes to use for padding. If None, no padding will be + performed. + + Returns: + A padded array. + """ + padded_shape = tuple( + ((v - 1) // b + 1) * b for v, b in zip(value.shape, block_shape) + ) + if padded_shape != value.shape: + pad_width = tuple((0, a - b) for a, b in zip(padded_shape, value.shape)) + pad_value = self.get_uninitialized_array((), value.dtype) + value = jnp.pad(value, pad_width, constant_values=pad_value) + return value + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class InterpretGPUParams(InterpretParams): + ... + + +class Counter: + """A simple counter that is thread-safe.""" + + def __init__(self, initial_value: int): + self.value = initial_value + self.lock = threading.Lock() + + def get_next(self): + with self.lock: + result = self.value + self.value += 1 + return result + + +# TODO(sharadmv): De-dup this w/ the impl in primitives.py. +def _device_id_dict_to_mesh(device_id_dict, axis_sizes, axis_indices): + physical_axis_dict = {} + axis_names = axis_sizes.keys() + for axis, idx in device_id_dict.items(): + if isinstance(axis, tuple) and any(a in axis_names for a in axis): + if not all(a in axis_names for a in axis): + raise NotImplementedError( + f"{axis} mixes JAX mesh and Pallas mesh grid axes" + ) + axes_dimensions = [axis_sizes[name] for name in axis] + for axis_index, axis_name in enumerate(axis): + axis_size = axis_sizes[axis_name] + inner_mesh_size = math.prod(axes_dimensions[axis_index + 1 :]) + minor_divisor = inner_mesh_size + + # Fast path for power of 2s + if inner_mesh_size & (inner_mesh_size - 1) == 0: + shift_len = (inner_mesh_size & -inner_mesh_size).bit_length() - 1 + partial_device_idx = idx >> shift_len + else: + partial_device_idx = idx // minor_divisor + + if axis_size & (axis_size - 1) == 0: + device_idx = partial_device_idx & (axis_size - 1) + else: + device_idx = partial_device_idx % axis_size + physical_axis_dict[axis_name] = device_idx + else: + physical_axis_dict[axis] = idx + device_id = [] + for axis in axis_names: + if axis in physical_axis_dict: + device_id.append(physical_axis_dict[axis]) + else: + device_id.append(axis_indices[axis]) + non_mesh_axes = { + k: v for k, v in physical_axis_dict.items() if k not in axis_names + } + return tuple(device_id), non_mesh_axes + + +def device_coords_to_logical_id(device_coords, axis_sizes, axis_indices): + if isinstance(device_coords, dict): + device_coords, non_mesh_axes = _device_id_dict_to_mesh( + device_coords, axis_sizes, axis_indices + ) + if non_mesh_axes: + raise NotImplementedError(non_mesh_axes) + if not isinstance(device_coords, tuple): + device_coords = (device_coords,) + assert len(device_coords) == len(axis_sizes) + sizes = list(axis_sizes.values()) + ret = 0 + for i in range(len(device_coords)): + ret += device_coords[i] * math.prod(sizes[i + 1 :]) + return ret + + +def _device_id_to_logical(device_id, device_id_type, axis_sizes, axis_indices): + if device_id is None: + return None + if device_id_type == primitives.DeviceIdType.MESH: + return device_coords_to_logical_id(device_id, axis_sizes, axis_indices) + elif device_id_type == primitives.DeviceIdType.LOGICAL: + return device_id + else: + raise ValueError(f"Unsupported device ID type: {device_id_type}") + + +def is_int(dtype): + return jnp.issubdtype(dtype, jnp.integer) + + +def is_float(dtype): + return jnp.issubdtype(dtype, jnp.floating) + + +@dataclasses.dataclass(frozen=True) +class Placeholder: + """Placeholder for use in `JaxprEnv` below instead of storing a concrete value.""" + + shape: tuple[int, ...] + dtype: jnp.dtype + + +class JaxprEnv: + """An environment for interpreting jaxprs, mapping variables to values.""" + + def __init__( + self, + *, + vars: Sequence[jax_core.Var] | None = None, + values: Sequence[Any] | None = None, + sentinel_for_floating_point_values: Any = None, + ): + self._sentinel_for_floating_point_values = ( + sentinel_for_floating_point_values + ) + self._env: dict[jax_core.Var, Any] = {} + + if vars is None and values is None: + return + + vars = vars or [] + values = values or [] + self.write_many(vars, values) + + def read(self, var): + if isinstance(var, jax_core.Literal): + result = var.val + else: + result = self._env[var] + if isinstance(result, Placeholder): + result = lax.full( + result.shape, self._sentinel_for_floating_point_values, result.dtype + ) + return result + + def read_many(self, vars): + return safe_map(self.read, vars) + + def write(self, var, value): + if self._sentinel_for_floating_point_values and is_float(value.dtype): + value = Placeholder(value.shape, value.dtype) + self._env[var] = value + + def write_many(self, vars, values): + safe_map(self.write, vars, values) + + +def _transform_slice_or_index(slice_or_idx): + if isinstance(slice_or_idx, int): + return slice_or_idx + else: + start = int(slice_or_idx.start) + size = int(slice_or_idx.size) + stride = int(slice_or_idx.stride) + return slice(start, start + size * stride, stride) + + +def _compose_slice_or_index(slice_or_idx1, slice_or_idx2): + ret = [] + i = 0 + j = 0 + while True: + if i == len(slice_or_idx1): + ret.extend(slice_or_idx2[j:]) + return tuple(ret) + elif j == len(slice_or_idx2): + ret.extend(slice_or_idx1[i:]) + return tuple(ret) + elif isinstance(slice_or_idx1[i], int): + ret.append(slice_or_idx1[i]) + i += 1 + elif isinstance(slice_or_idx2[j], int): + ret.append( + slice_or_idx1[i].start + slice_or_idx2[j] * slice_or_idx1[i].step + ) + i += 1 + j += 1 + else: + ret.append( + slice( + slice_or_idx1[i].start + + slice_or_idx2[j].start * slice_or_idx1[i].step, + slice_or_idx1[i].start + + slice_or_idx2[j].stop * slice_or_idx1[i].step, + slice_or_idx1[i].step * slice_or_idx2[j].step, + ) + ) + i += 1 + j += 1 + + +def to_range(transforms) -> tuple[slice | int, ...]: + ret = () + for transform in transforms: + # For now, assume only NDIndexer transforms. + ret = _compose_slice_or_index( + ret, tuple(_transform_slice_or_index(i) for i in transform.indices) + ) + return ret diff --git a/jax/_src/pallas/mosaic/interpret/vector_clock.py b/jax/_src/pallas/mosaic/interpret/vector_clock.py new file mode 100644 index 000000000000..c0c2c96b55e1 --- /dev/null +++ b/jax/_src/pallas/mosaic/interpret/vector_clock.py @@ -0,0 +1,46 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +VectorClock = np.ndarray + + +def make_vector_clock(vector_clock_size: int) -> VectorClock: + return np.zeros(vector_clock_size, dtype=np.int32) + + +def copy_vector_clock(x: VectorClock) -> VectorClock: + if x is None: + return None + return x.copy() + + +def update_vector_clock(x: VectorClock, y: VectorClock): + x[:] = np.maximum(x[:], y[:]) + + +def lt(x: VectorClock, y: VectorClock) -> bool: + return bool((x <= y).all() & (x < y).any()) + + +def ordered(x: VectorClock, y: VectorClock) -> bool: + return lt(x, y) | lt(y, x) + + +def inc_vector_clock(x: VectorClock, global_core_id: int): + if global_core_id >= len(x): + raise ValueError(f'device_id={global_core_id} is out of range for x={x}') + assert global_core_id < len(x) + x[global_core_id] += 1 diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 10b9de7487eb..4dee33a31cc7 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -15,12 +15,12 @@ """Module for lowering JAX to Mosaic-compatible MLIR dialects.""" from __future__ import annotations -from collections.abc import Callable, Sequence +from collections.abc import Callable, Collection, Hashable, Sequence import contextlib import dataclasses import functools import string -from typing import Any, Hashable +from typing import Any, Literal, Protocol, Self, TypeVar, cast import jax from jax import api_util @@ -28,62 +28,66 @@ from jax import tree_util from jax._src import ad_util from jax._src import checkify +from jax._src import config from jax._src import core as jax_core from jax._src import custom_derivatives from jax._src import debugging from jax._src import dtypes from jax._src import linear_util as lu +from jax._src import literals from jax._src import mesh as mesh_lib from jax._src import pjit from jax._src import prng from jax._src import source_info_util from jax._src import state from jax._src import traceback_util +from jax._src import xla_bridge from jax._src.cloud_tpu_init import is_cloud_tpu_older_than +from jax._src.export import shape_poly from jax._src.export._export import export from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe +from jax._src.lax import control_flow from jax._src.lax import lax as lax_internal -from jax._src.lax.control_flow import for_loop +from jax._src.lax.control_flow import BranchesPlatforms +from jax._src.lib import xla_client from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith +from jax._src.lib.mlir.dialects import cf from jax._src.lib.mlir.dialects import func from jax._src.lib.mlir.dialects import math from jax._src.lib.mlir.dialects import memref from jax._src.lib.mlir.dialects import scf from jax._src.lib.mlir.dialects import vector from jax._src.pallas import core as pallas_core -from jax._src.pallas import pallas_call +from jax._src.pallas import helpers as pallas_helpers from jax._src.pallas import primitives from jax._src.pallas import utils as pallas_utils from jax._src.pallas.mosaic import core as tpu_core from jax._src.pallas.mosaic import error_handling from jax._src.pallas.mosaic import primitives as tpu_primitives from jax._src.pallas.mosaic import random as pl_random -from jax._src.state import discharge as state_discharge from jax._src.state import indexing from jax._src.state import primitives as state_primitives from jax._src.state.types import RefBitcaster, RefReshaper -from jax._src.state.utils import dtype_bitwidth from jax._src.typing import Array, DTypeLike from jax._src.util import foreach from jax._src.util import safe_map from jax._src.util import safe_zip from jax._src.util import split_list -from jax._src.util import unzip2 from jax.experimental.mosaic.dialects import tpu import jax.numpy as jnp -from jaxlib.mlir.ir import Module import numpy as np # TODO(sharadmv): enable type checking # mypy: ignore-errors NDIndexer = indexing.NDIndexer -TPUMemorySpace = tpu_core.TPUMemorySpace -MemorySpace = pallas_core.MemorySpace | TPUMemorySpace -VMEM = tpu_core.TPUMemorySpace.VMEM -SMEM = tpu_core.TPUMemorySpace.SMEM +TPUMemorySpace = tpu_core.MemorySpace +AnyMemorySpace = pallas_core.MemorySpace | TPUMemorySpace +VMEM = TPUMemorySpace.VMEM +SMEM = TPUMemorySpace.SMEM +ANY = pallas_core.MemorySpace.ANY # Booleans are stored as the following type in memrefs. BOOL_MEMREF_TYPE = np.dtype('int32') @@ -99,12 +103,24 @@ map, unsafe_map = safe_map, map # pylint: disable=redefined-builtin zip, unsafe_zip = safe_zip, zip # pylint: disable=redefined-builtin +# Extended types that should not be converted to physical types in lowering. +PHYSICAL_EXTENDED_DTYPES = {pallas_core.semaphore_dtype} + + +def should_physicalize_dtype(dtype) -> bool: + """Returns whether a dtype should be lowered to a physical type.""" + return ( + jnp.issubdtype(dtype, dtypes.extended) and + not any(jnp.issubdtype(dtype, t) for t in PHYSICAL_EXTENDED_DTYPES) + ) + + +def _maybe_physicalize_block_shape(aval, block_shape): + if should_physicalize_dtype(aval.dtype): + physical_element_aval = jax_core.physical_element_aval(aval.dtype) # pytype: disable=wrong-arg-types + block_shape += physical_element_aval.shape + return block_shape -@dataclasses.dataclass -class MeshContext: - mesh_shape: tuple[int, ...] - axis_names: tuple[str, ...] - mesh_strides: tuple[int, ...] # Note - On Export Placeholders # @@ -130,8 +146,10 @@ class MeshContext: # SHLO functions that compute the symbolic dimension expression for the # placeholder. class LoweringDynamicShapeEnv: - dim_expr_to_placeholder: dict[shape_poly._DimExpr, int] = {} - placeholder_to_dim_expr: dict[int, shape_poly._DimExpr] = {} + + def __init__(self): + self.dim_expr_to_placeholder: dict[shape_poly._DimExpr, int] = {} + self.placeholder_to_dim_expr: dict[int, shape_poly._DimExpr] = {} def to_placeholder(self, dim_expr: Any) -> ir.Value: if jax_core.is_constant_dim(dim_expr): @@ -154,23 +172,29 @@ def to_placeholder(self, dim_expr: Any) -> ir.Value: return self.dim_expr_to_placeholder[dim_expr] +DynamicShapeReplacementFn = Callable[ + [tuple[jax_core.DimSize, ...]], tuple[int, ...] +] + + @dataclasses.dataclass class LoweringContext: - ir_context: ir.Context grid_sizes: tuple[int, ...] # Includes both user and vmap axes. grid_names: tuple[Hashable, ...] | None - mapped_dims: tuple[int, ...] # Indices of vmapped grid dimensions. + vmapped_dims: tuple[int, ...] # Indices of vmapped grid dimensions. user_grid_indices: Sequence[ir.Value] | None - block_shapes: list[tuple[int | pallas_core.Mapped, ...]] + block_shapes: list[tuple[int | pallas_core.Squeezed, ...]] name_stack: source_info_util.NameStack - mesh_context: MeshContext | None - replace = dataclasses.replace + mesh_context: pallas_utils.MeshInfo | None + kernel_type: tpu_core.KernelType traceback_caches: mlir.TracebackCaches - for_verification: bool forward_compatible: bool - dynamic_shape_replacement_fn: Callable[ - [tuple[jax.DimSize, ...]], tuple[int, ...] - ] + backend: xla_client.Client | None + dynamic_shape_replacement_fn: DynamicShapeReplacementFn + + def replace(self, **changes: Any) -> LoweringContext: + # The wrapper is necessary to convince pytype that this is a method. + return dataclasses.replace(self, **changes) @property def grid_rank(self): @@ -184,28 +208,54 @@ def grid_name_context(self): return grid_names = self.grid_names valid_grid_sizes = tuple( - d for i, d in enumerate(self.grid_sizes) if i not in self.mapped_dims + d for i, d in enumerate(self.grid_sizes) if i not in self.vmapped_dims ) grid_env = zip(grid_names, valid_grid_sizes) with jax_core.extend_axis_env_nd(grid_env): yield +# This is morally ``ShapedArray | state.AbstractRef``, but pytype does not +# allow calling methods on a union type, making ``update`` non-callable, so +# we use a protocol instead of a union. +class ShapedAbstractValue(Protocol): + shape: tuple[jax_core.DimSize, ...] + dtype: jnp.dtype + weak_type: bool + + def update(self, **kwargs: Any) -> Self: + raise NotImplementedError + + @dataclasses.dataclass class LoweringRuleContext: lowering_context: LoweringContext - avals_in: Sequence[jax_core.AbstractValue] - avals_out: Sequence[jax_core.AbstractValue] - block_shapes: Sequence[tuple[int | pallas_core.Mapped, ...] | None] - replace = dataclasses.replace + avals_in: Sequence[ShapedAbstractValue] + avals_out: Sequence[ShapedAbstractValue] + block_shapes: Sequence[tuple[int | pallas_core.Squeezed, ...] | None] + + def replace(self, **changes: Any) -> LoweringRuleContext: + # The wrapper is necessary to convince pytype that this is a method. + return dataclasses.replace(self, **changes) @property def forward_compatible(self): return self.lowering_context.forward_compatible + def is_cloud_tpu_older_than(self, year: int, month: int, day: int): + # No way for us to query the version, so assume the oldest possible backend. + if self.lowering_context.backend is None: + return True + backend = self.lowering_context.backend + return is_cloud_tpu_older_than(year, month, day, backend) + + +def _memory_space_to_tpu_memory_space( + memory_space: AnyMemorySpace | None, +) -> TPUMemorySpace | Literal[ANY]: + if memory_space == jax_core.MemorySpace.Device: + return ANY -def _memory_space_to_tpu_memory_space(memory_space: MemorySpace | None - ) -> TPUMemorySpace: match memory_space: case None: # We pick VMEM as the default one when no memory space is @@ -213,8 +263,14 @@ def _memory_space_to_tpu_memory_space(memory_space: MemorySpace | None return TPUMemorySpace.VMEM case pallas_core.MemorySpace.ANY: # Map the general ANY memory space to TPU ANY memory space - return TPUMemorySpace.ANY - case pallas_core.MemorySpace.ERROR | pallas_core.MemorySpace.INDEX: + return ANY + case pallas_core.MemorySpace.HOST: + return TPUMemorySpace.HOST + case ( + pallas_core.MemorySpace.ERROR + | pallas_core.MemorySpace.INDEX + | pallas_core.MemorySpace.KEY + ): return TPUMemorySpace.SMEM case TPUMemorySpace(): # Leave the memory space unchanged @@ -223,27 +279,29 @@ def _memory_space_to_tpu_memory_space(memory_space: MemorySpace | None raise ValueError(f"Invalid memory space: {memory_space}") -def _memory_space_to_mosaic_attribute(memory_space: MemorySpace | None +def _memory_space_to_mosaic_attribute(memory_space: AnyMemorySpace | None ) -> ir.Attribute: tpu_memory_space = _memory_space_to_tpu_memory_space(memory_space) return ir.Attribute.parse(f"#tpu.memory_space<{tpu_memory_space}>") -def _dtype_to_ir_type(dtype: jnp.dtype, +def _dtype_to_ir_type(dtype: DTypeLike, is_kernel_boundary: bool = False) -> ir.Type: - if jnp.issubdtype(dtype, tpu_core.semaphore_dtype): + if jnp.issubdtype(dtype, pallas_core.semaphore_dtype): if jnp.issubdtype(dtype, tpu_core.dma_semaphore): return ir.Type.parse("!tpu.dma_semaphore") - elif jnp.issubdtype(dtype, tpu_core.semaphore): + elif jnp.issubdtype(dtype, pallas_core.semaphore): return ir.Type.parse("!tpu.semaphore") - elif jnp.issubdtype(dtype, tpu_core.barrier_semaphore): + elif jnp.issubdtype(dtype, pallas_core.barrier_semaphore): return ir.Type.parse("!tpu.semaphore") else: raise NotImplementedError - if is_kernel_boundary and jnp.issubdtype(dtype, jnp.dtype('bool')): + if jnp.issubdtype(dtype, dtypes.extended): + raise NotImplementedError(f"Extended dtype {dtype} is unsupported.") + if is_kernel_boundary and jnp.issubdtype(dtype, jnp.bool): dtype = BOOL_MEMREF_TYPE # TODO(justinfu): Remove after mosaic supports unsigned types. # This conversion makes mosaic interpret all unsigned types as signed types. - type = mlir.dtype_to_ir_type(dtype) + type = mlir.dtype_to_ir_type(jnp.dtype(dtype)) if isinstance(type, ir.IntegerType): return ir.IntegerType.get_signless(type.width) else: @@ -254,9 +312,23 @@ def aval_to_ir_type( dynamic_shape_replacement_fn, aval, shape=None, - memory_space: MemorySpace | None = None, + memory_space: AnyMemorySpace | None = None, is_kernel_boundary: bool = False, + allow_extended_types: bool = True, ): + if allow_extended_types and should_physicalize_dtype(aval.dtype): + if isinstance(aval, state.AbstractRef): + inner_aval = jax_core.physical_aval(aval.inner_aval) # pytype: disable=wrong-arg-types + physical_aval = aval.update(inner_aval=inner_aval) + else: + physical_aval = jax_core.physical_aval(aval) # pytype: disable=wrong-arg-types + if shape is not None: + shape = jax_core.physical_shape(shape, aval.dtype) + return aval_to_ir_type(dynamic_shape_replacement_fn, + aval=physical_aval, + shape=shape, memory_space=memory_space, + is_kernel_boundary=is_kernel_boundary, + allow_extended_types=False) if isinstance(aval, tpu_core.AbstractSemaphore): if aval.sem_type is tpu_core.SemaphoreType.DMA: sem_type = ir.Type.parse("!tpu.dma_semaphore") @@ -268,21 +340,11 @@ def aval_to_ir_type( raise ValueError(f"Cannot allocate {aval.sem_type}.") memspace = _memory_space_to_mosaic_attribute(TPUMemorySpace.SEMAPHORE) return ir.MemRefType.get((), sem_type, memory_space=memspace) - if dtypes.issubdtype(aval.dtype, dtypes.prng_key): - shape = aval.dtype._impl.key_shape - if pl_random.is_pallas_impl(aval.dtype._impl): - if memory_space is None: - memory_space = TPUMemorySpace.SMEM - if memory_space != TPUMemorySpace.SMEM: - raise ValueError( - f"PRNG keys must be stored in SMEM. Got {memory_space}" - ) - memspace = _memory_space_to_mosaic_attribute(memory_space) - return ir.MemRefType.get(shape, _dtype_to_ir_type(np.dtype(np.uint32)), - memory_space=memspace) if isinstance(aval, state.AbstractRef): if shape is None: shape = aval.shape + if memory_space is None: + memory_space = aval.memory_space memspace = _memory_space_to_mosaic_attribute(memory_space) shape = dynamic_shape_replacement_fn(shape) return ir.MemRefType.get(shape, @@ -301,7 +363,7 @@ def aval_to_ir_type( raise NotImplementedError(aval) -def ir_constant(x, mlir_type=None): +def ir_constant(x: Any, mlir_type: ir.Type | None = None) -> ir.Value: if not hasattr(x, "dtype"): if isinstance(x, int): x = np.array(x, np.int32) @@ -309,59 +371,56 @@ def ir_constant(x, mlir_type=None): x = np.array(x, np.float32) if not mlir_type: mlir_type = _dtype_to_ir_type(x.dtype) - if isinstance(x, int) or np.issubdtype(x.dtype, np.integer): + if isinstance(x, int) or jnp.issubdtype(x.dtype, np.integer): return arith.constant(mlir_type, ir.IntegerAttr.get(mlir_type, int(x))) - elif isinstance(x, float) or x.dtype == np.float32: - return arith.constant(mlir_type, ir.FloatAttr.get(mlir_type, float(x))) - elif x.dtype == jnp.bfloat16: + elif isinstance(x, float) or jnp.issubdtype(x.dtype, jnp.floating): return arith.constant(mlir_type, ir.FloatAttr.get(mlir_type, float(x))) elif x.dtype == jnp.bool_: return arith.constant(mlir_type, ir.BoolAttr.get(bool(x))) raise NotImplementedError(x.dtype) -lowering_rules = {} +lowering_rules = {kernel_type: {} for kernel_type in tpu_core.KernelType} skip_mlir_conversions = set() + +T = TypeVar("T") + + +def register_lowering_rule( + prim: jax_core.Primitive, + *, + kernel_types: Collection[tpu_core.KernelType] = (tpu_core.KernelType.TC,), + ensure_mlir_values: bool = True, +) -> Callable[[T], T]: + def decorator(rule: T) -> T: + for kernel_type in kernel_types: + lowering_rules[kernel_type][prim] = rule + if not ensure_mlir_values: + skip_mlir_conversions.add((prim, kernel_type)) + return rule + + return decorator + + def _get_aval_physical_dtype_shape(aval): - dtype_physical_shape = jax_core.physical_aval(aval).shape[ - len(aval.shape) : - ] - return dtype_physical_shape + if should_physicalize_dtype(aval.dtype): # pytype: disable=attribute-error + physical_aval = jax_core.physical_aval(aval) + return physical_aval.shape[len(aval.shape) :] # pytype: disable=attribute-error + else: + return () def _get_arg_type( - dynamic_shape_replacement_fn: Callable[ - [tuple[jax.DimSize, ...]], tuple[jax.DimSize, ...] - ], - aval, - block_mapping: pallas_core.BlockMapping | None, -): + dynamic_shape_replacement_fn: DynamicShapeReplacementFn, + aval: ShapedAbstractValue, + shape: tuple[int, ...] | None = None, +) -> ir.Type: memory_space = None - if isinstance(aval, pallas_core.AbstractMemoryRef): - memory_space = aval.memory_space - # We assume unannotated memory refs are in VMEM - if memory_space is None: - memory_space = TPUMemorySpace.VMEM - if isinstance(aval, tpu_core.AbstractSemaphore): - return aval_to_ir_type(dynamic_shape_replacement_fn, aval), None - # TODO(necula): clean this None block_mapping - if block_mapping is None: - return ( - aval_to_ir_type( - dynamic_shape_replacement_fn, aval, memory_space=memory_space - ), - aval.shape, - ) - shape = tuple(1 if b is pallas_core.mapped else b for b in block_mapping.block_shape) - return ( - aval_to_ir_type( - dynamic_shape_replacement_fn, - aval, - shape=shape, - memory_space=memory_space, - ), - block_mapping.block_shape, + if isinstance(aval, state.AbstractRef): + memory_space = _memory_space_to_tpu_memory_space(aval.memory_space) + return aval_to_ir_type( + dynamic_shape_replacement_fn, aval, shape=shape, memory_space=memory_space ) @@ -375,39 +434,37 @@ def _canonicalize_dimension_semantic( @dataclasses.dataclass(init=False) class MosaicGridMapping: - grid: tuple[int, ...] | None + grid: pallas_core.GridMappingGrid | None grid_names: tuple[Hashable, ...] | None jaxpr: jax_core.Jaxpr - block_mappings: tuple[pallas_core.BlockMapping | None, ...] - mapped_dims: tuple[int, ...] + block_mappings: tuple[pallas_core.BlockMapping, ...] + vmapped_dims: tuple[int, ...] scalar_prefetch_types: tuple[ir.Type, ...] operand_types: tuple[ir.Type, ...] scratch_types: tuple[ir.Type, ...] grid_types: tuple[ir.Type, ...] scalar_prefetch_block_shapes: tuple[tuple[int, ...], ...] - operand_block_shapes: tuple[tuple[int, ...], ...] + operand_block_shapes: tuple[tuple[int | pallas_core.Squeezed, ...], ...] scratch_block_shapes: tuple[tuple[int, ...], ...] - mesh_info: MeshInfo | None - get_grid_indices: Callable | None + mesh_info: pallas_utils.MeshInfo | None + get_grid_indices: Callable[..., Any] def __init__( self, jaxpr: jax_core.Jaxpr, grid_mapping: pallas_core.GridMapping, - dimension_semantics: tuple[str | tpu_core.GridDimensionSemantics, ...] | None, + dimension_semantics: Sequence[tpu_core.DimensionSemantics] | None, mesh: mesh_lib.Mesh | None, - dynamic_shape_replacement_fn: Callable[ - [tuple[jax.DimSize, ...]], tuple[int, ...] - ], + dynamic_shape_replacement_fn: DynamicShapeReplacementFn, ): self.grid = grid_mapping.grid self.grid_names = grid_mapping.grid_names self.jaxpr = jaxpr self.block_mappings = grid_mapping.block_mappings - self.mapped_dims = grid_mapping.vmapped_dims + self.vmapped_dims = grid_mapping.vmapped_dims # TODO(mvoz): Generalize to not need this user_grid = tuple( - g for i, g in enumerate(self.grid) if i not in self.mapped_dims + g for i, g in enumerate(self.grid) if i not in self.vmapped_dims ) if dimension_semantics is None: dimension_semantics = ("arbitrary",) * len(user_grid) @@ -416,67 +473,84 @@ def __init__( ) if len(user_grid) != len(dimension_semantics): raise ValueError( - "Must have dimension semantics for each dimension of the grid." + "Length of grid does not match length of dimension semantics." + f" len(grid)={len(user_grid)}, {len(dimension_semantics)=}" ) - assert len(self.mapped_dims) + len(dimension_semantics) == len( + assert len(self.vmapped_dims) + len(dimension_semantics) == len( self.grid ), ( - f"Misconfigured grid: {self.mapped_dims=}, {dimension_semantics=}," + f"Misconfigured grid: {self.vmapped_dims=}, {dimension_semantics=}," f" {self.grid=}" ) # dimension_semantics is user provided and won't take into account vmap # dimensions. Here we add in parallel dimensions for the vmaps. semantics_iter = iter(dimension_semantics) self._dimension_semantics = tuple( - next(semantics_iter) if i not in self.mapped_dims else "parallel" + next(semantics_iter) if i not in self.vmapped_dims else "parallel" for i in range(len(self.grid)) ) - in_avals = [invar.aval for invar in self.jaxpr.invars] + in_avals = [ + cast(ShapedAbstractValue, invar.aval) for invar in self.jaxpr.invars + ] # jaxpr has signature [*scalar_prefetch, *consts, *in_ops, *out_ops, *scratch] scalar_prefetch_avals = in_avals[grid_mapping.slice_index_ops] operand_avals = in_avals[grid_mapping.slice_block_ops] scratch_avals = in_avals[grid_mapping.slice_scratch_ops] - self.scalar_prefetch_types, _ = unzip2([ - _get_arg_type(dynamic_shape_replacement_fn, aval, None) + self.scalar_prefetch_types = tuple( + _get_arg_type(dynamic_shape_replacement_fn, aval) for aval in scalar_prefetch_avals - ]) - self.scalar_prefetch_block_shapes = tuple( + ) + scalar_prefetch_block_shapes = tuple( aval.shape for aval in scalar_prefetch_avals) - self.operand_types, self.operand_block_shapes = unzip2([ - _get_arg_type(dynamic_shape_replacement_fn, aval, block_mapping) - for aval, block_mapping in zip(operand_avals, self.block_mappings) - ]) - self.scratch_types, _ = unzip2([ - _get_arg_type(dynamic_shape_replacement_fn, aval, None) + self.scalar_prefetch_block_shapes = tuple( + map(_maybe_physicalize_block_shape, scalar_prefetch_avals, + scalar_prefetch_block_shapes)) + operands_types = [] + operand_block_shapes = [] + for aval, bm in zip(operand_avals, self.block_mappings): + shape = pallas_core._get_block_shape(bm.block_shape) + # Keep around squeezed as a sentinel for the lowering rules. + block_shape = tuple( + pallas_core.squeezed + if isinstance(b, pallas_core.Squeezed) + else pallas_core._get_block_dim_size(b) + for b in bm.block_shape + ) + block_shape = _maybe_physicalize_block_shape(aval, block_shape) + operands_types.append( + _get_arg_type(dynamic_shape_replacement_fn, aval, shape=shape) + ) + operand_block_shapes.append(block_shape) + self.operand_types = tuple(operands_types) + self.operand_block_shapes = tuple(operand_block_shapes) + self.scratch_types = tuple( + _get_arg_type(dynamic_shape_replacement_fn, aval) for aval in scratch_avals - ]) + ) self.scratch_block_shapes = tuple( aval.shape if not isinstance(aval, tpu_core.AbstractSemaphore) else None for aval in scratch_avals ) - self.grid_types, _ = unzip2([ + self.grid_types = ( _get_arg_type( - dynamic_shape_replacement_fn, - pallas_core.index_map_grid_aval, - None, - ) - for _ in range(len(self.grid)) - ]) + dynamic_shape_replacement_fn, pallas_core.index_map_grid_aval + ), + ) * len(self.grid) + self._prepare_mesh_info(mesh) if grid_mapping.get_grid_indices is None: - - # Avoid using self.mapped_dims within the function, since doing so will + # Avoid using self.vmapped_dims within the function, since doing so will # introduce a self->_get_grid_indices->self reference cycle that means # MosaicGridMapping instances can only ever be deleted by GC, rather than # by their reference counts going to 0. - mapped_dims = self.mapped_dims + vmapped_dims = self.vmapped_dims def _get_grid_indices(indices, maybe_include_mapped_dims: bool): if maybe_include_mapped_dims: return indices return tuple( - idx for i, idx in enumerate(indices) if i not in mapped_dims + idx for i, idx in enumerate(indices) if i not in vmapped_dims ) self.get_grid_indices = _get_grid_indices @@ -498,13 +572,7 @@ def _prepare_mesh_info(self, mesh: mesh_lib.Mesh | None): "Cannot shadow axis mesh axis names with grid names. mesh axis" f" names: {mesh.axis_names}, grid names: {self.grid_names}" ) - # We need mesh <-> logical translation tables. Since the logical IDs are - # just linearized versions of the mesh IDs, we create those tables. - mesh_strides = pallas_utils.strides_from_shape(tuple( - mesh.shape[a] for a in axis_names - )) - mesh_shape = tuple(mesh.shape.values()) - self.mesh_info = MeshInfo(mesh_shape, axis_names, mesh_strides) + self.mesh_info = pallas_utils.MeshInfo.from_mesh(mesh) def maybe_compress_grid(self): # If we have many leading parallel dimensions, we should "compress" them @@ -516,21 +584,28 @@ def maybe_compress_grid(self): def has_communication(self) -> bool: nonlocal_axis_names = set() def _get_nonlocal_axis_names(jaxpr: jax_core.Jaxpr): - return { + axis_name_effects = { e.name for e in jaxpr.effects if isinstance(e, jax_core.NamedAxisEffect) and (not self.grid_names or e.name not in self.grid_names) } + # Comms effects catch the case where we have comms but don't actually have + # a named axis effect. We can remove these once all comms primitives + # require an axis name. + comms_effects = { + "comms" + for e in jaxpr.effects + if isinstance(e, pallas_core.CommsEffect) + } + return comms_effects | axis_name_effects nonlocal_axis_names.update(_get_nonlocal_axis_names(self.jaxpr)) for bm in self.block_mappings: - if bm is not None: - nonlocal_axis_names.update(_get_nonlocal_axis_names(bm.index_map_jaxpr)) + nonlocal_axis_names.update( + _get_nonlocal_axis_names(bm.index_map_jaxpr.jaxpr) + ) return bool(nonlocal_axis_names) - def get_extra_args(self) -> tuple[Any, ...]: - return () - def get_dimension_semantics(self) -> ir.ArrayAttr: def _get_semantics(s: str | None) -> str: @@ -545,12 +620,6 @@ def _get_semantics(s: str | None) -> str: ) ) -@dataclasses.dataclass -class MeshInfo: - mesh_shape: tuple[int, ...] - axis_names: list[str] - mesh_strides: tuple[int, ...] - def _check_block_mappings( block_mappings: tuple[pallas_core.BlockMapping, ...], @@ -559,40 +628,53 @@ def _check_block_mappings( ) -> None: del lowering_context # originally needed for forward compat for bm in block_mappings: - rank = len(bm.block_shape) + dtype = bm.array_aval.dtype + array_shape = bm.array_aval.shape + if should_physicalize_dtype(dtype): + physical_element_aval = jax_core.physical_element_aval(dtype) + physical_dtype = physical_element_aval.dtype + physical_array_shape = jax_core.physical_shape(array_shape, dtype) + physical_block_shape = bm.block_shape + tuple( + pallas_core.Blocked(i) for i in physical_element_aval.shape) + else: + physical_dtype = dtype + physical_array_shape = array_shape + physical_block_shape = bm.block_shape + + rank = len(physical_block_shape) # TODO(necula): add tests for SMEM blocks with trivial windowing # We support scalars too - if (bm.block_aval.memory_space == tpu_core.TPUMemorySpace.SMEM and - bm.has_trivial_window()): + memory_space = _memory_space_to_tpu_memory_space(bm.block_aval.memory_space) + if memory_space == tpu_core.MemorySpace.SMEM and bm.has_trivial_window(): continue - if bm.block_aval.memory_space == tpu_core.TPUMemorySpace.SEMAPHORE: + if memory_space == tpu_core.MemorySpace.SEMAPHORE: continue def err_details(): return (f"Block spec for {bm.origin} in pallas_call {debug_info.func_src_info} " "has block shape " - f"{bm.block_shape}, array shape {bm.array_shape_dtype.shape}, " + f"{physical_block_shape}, array shape {physical_array_shape}, " # TODO(necula): add index_map source location info f"and index_map {bm.index_map_jaxpr.jaxpr}, in " f"memory space {bm.block_aval.memory_space}." - "\nSee details at https://jax.readthedocs.io/en/latest/pallas/grid_blockspec.html#pallas-blockspec") + "\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec") if rank < 1: raise ValueError( "The Pallas TPU lowering currently supports only blocks of " "rank >= 1. " + err_details()) - if (bm.block_aval.memory_space == tpu_core.TPUMemorySpace.ANY and - not bm.has_trivial_window()): + if ( + memory_space is ANY or memory_space == tpu_core.MemorySpace.HBM + ) and not bm.has_trivial_window(): raise ValueError( "The Pallas TPU lowering currently supports in memory space ANY " "only blocks having the same block shape as the array shape " "and a trivial index_map (returning all 0s)." + err_details()) - unmapped_bs = [ - 1 if bs is pallas_core.mapped else bs for bs in bm.block_shape] - bs0, as0 = unmapped_bs[-1], bm.array_shape_dtype.shape[-1] + unmapped_bs = pallas_core._get_block_shape(physical_block_shape) + bs0, as0 = unmapped_bs[-1], physical_array_shape[-1] if rank >= 2: - bs1, as1 = unmapped_bs[-2], bm.array_shape_dtype.shape[-2] + bs1, as1 = unmapped_bs[-2], physical_array_shape[-2] else: bs1, as1 = 1, 1 @@ -620,11 +702,10 @@ def err_details(): ) else: assert rank == 1 - # bools get a bitwidth of 32 due to how mosaic handles them - if bm.array_shape_dtype.dtype == jnp.bool_: - bitwidth = 32 + if bm.array_aval.dtype == jnp.bool_: + bitwidth = dtypes.itemsize_bits(BOOL_MEMREF_TYPE) else: - bitwidth = lax_internal._bit_width(bm.array_shape_dtype.dtype) + bitwidth = dtypes.itemsize_bits(physical_dtype) packing = 32 // bitwidth tiling_size = 128 * packing evenly_divisible = (bs0 == as0 or bs0 % tiling_size == 0) @@ -635,7 +716,7 @@ def err_details(): " shape is equal to the first (and only) dimension of the array" " shape, or 2) the first (and only) dimension of the block shape" f" is a multiple of the tiling size ({tiling_size} = 128 * (32 //" - f" {lax_internal._bit_width(bm.array_shape_dtype.dtype)})) of the" + f" {dtypes.itemsize_bits(physical_dtype)})) of the" " array shape. " + err_details() ) @@ -643,21 +724,21 @@ def err_details(): def lower_jaxpr_to_module( lowering_context: mlir.LoweringRuleContext, - ctx: ir.Context, grid_mapping: pallas_core.GridMapping, jaxpr: jax_core.Jaxpr, *, - dimension_semantics: ( - tuple[str | tpu_core.GridDimensionSemantics, None, ...] | None - ), + dimension_semantics: Sequence[tpu_core.DimensionSemantics] | None, + kernel_type: tpu_core.KernelType, mesh: mesh_lib.Mesh | None = None, - for_verification: bool = False, dynamic_shape_replacement_enabled: bool = False, -) -> tuple[Module, tuple[Any, ...]]: +) -> ir.Module: + backend = lowering_context.module_context.get_backend(optional=True) # NOTE: We should bump this periodically - if is_cloud_tpu_older_than(2025, 1, 10): + if backend is not None and is_cloud_tpu_older_than(2025, 8, 1, backend): + platform_version = xla_bridge.get_backend().platform_version raise RuntimeError( - "Pallas TPU requires a libTPU version that's at most a month old" + "Pallas TPU requires a libtpu version that's at most a month old. Found" + f" version string:\n{platform_version}" ) debug_info = jaxpr.debug_info _mosaic_lowering_dynamic_shape_env = None @@ -695,20 +776,26 @@ def dynamic_shape_replacement_fn( sym_tab = ir.SymbolTable(m.operation) func_op = lower_jaxpr_to_func( - ctx, jaxpr, mosaic_grid_mapping=mosaic_grid_mapping, name="main", - for_verification=for_verification, + kernel_type=kernel_type, forward_compatible=lowering_context.is_forward_compat(), dynamic_shape_replacement_fn=dynamic_shape_replacement_fn, dynamic_shape_replacement_enabled=dynamic_shape_replacement_enabled, + backend=backend, ) m.body.append(func_op) sym_tab.insert(func_op) window_params = [] static_grid = None grid = mosaic_grid_mapping.grid + if not grid and any( + not bm.has_trivial_window() for bm in grid_mapping.block_mappings + ): + raise NotImplementedError( + "Non-trivial windowing is not supported for grid-free pallas_call." + ) if grid: for i, bm in enumerate(grid_mapping.block_mappings): func_name = f"transform_{i}" @@ -716,27 +803,35 @@ def dynamic_shape_replacement_fn( tpu_memory_space = _memory_space_to_tpu_memory_space( bm.block_aval.memory_space) if ( - tpu_memory_space == tpu_core.TPUMemorySpace.ANY - or tpu_memory_space == tpu_core.TPUMemorySpace.SEMAPHORE + tpu_memory_space is ANY + or tpu_memory_space == tpu_core.MemorySpace.HBM + or tpu_memory_space == tpu_core.MemorySpace.SEMAPHORE ): # We checked above that the block does not require windowing. window_params.append(ir.DictAttr.get()) continue mlir_func = lower_jaxpr_to_transform_func( - ctx, bm.index_map_jaxpr.jaxpr, bm.block_aval, name=func_name, mosaic_grid_mapping=mosaic_grid_mapping, - for_verification=for_verification, + kernel_type=kernel_type, forward_compatible=lowering_context.is_forward_compat(), dynamic_shape_replacement_fn=dynamic_shape_replacement_fn, + backend=backend, ) assert mlir_func.verify(), mlir_func - block_shape = [ - 1 if b is pallas_core.mapped else b for b in bm.block_shape - ] + block_shape = list(pallas_core._get_block_shape(bm.block_shape)) + + # Force single-buffering pipelining for trivial windowing in VMEM. + pipeline_mode = bm.pipeline_mode + if ( + tpu_memory_space == tpu_core.MemorySpace.VMEM + and bm.has_trivial_window() + ): + pipeline_mode = pallas_core.Buffered(1) + # If we have an extended dtype, we need to add the block shape for the # remaining physical dtype. block_shape += list(_get_aval_physical_dtype_shape(bm.block_aval.inner_aval)) @@ -746,28 +841,52 @@ def dynamic_shape_replacement_fn( window_bounds=window_shape, transform_indices=ir.FlatSymbolRefAttr.get(func_name), ) - if isinstance(bm.indexing_mode, pallas_core.Unblocked): - if bm.indexing_mode.padding is None: - pad_low = pad_high = [0] * len(bm.block_shape) - else: - pad_low, pad_high = map(list, zip(*bm.indexing_mode.padding)) + for bd in bm.block_shape: + if not isinstance( + bd, (pallas_core.Element, pallas_core.Squeezed, pallas_core.Blocked) + ): + raise NotImplementedError( + "Unsupported block dimension type: " + f"{type(bd)} for block shape: {bm.block_shape}" + ) + is_element_block = [isinstance(bd, pallas_core.Element) + for bd in bm.block_shape] + if any(is_element_block): + is_element_or_squeezed_block = [ + isinstance(bd, (pallas_core.Element, pallas_core.Squeezed)) + for bd in bm.block_shape + ] + if not all(is_element_or_squeezed_block): + raise NotImplementedError( + "All block dimensions must be Elements or none of them can be" + " Elements." + ) + padding = [ + bd.padding if isinstance(bd, pallas_core.Element) else (0, 0) + for bd in bm.block_shape + ] + pad_low, pad_high = map(list, zip(*padding)) block_params["window_kind"] = ir.Attribute.parse( f"#tpu.element_window<{pad_low},{pad_high}>" ) - if bm.pipeline_mode is not None: - if not isinstance(bm.pipeline_mode, pallas_core.Buffered): + if pipeline_mode is not None: + if not isinstance(pipeline_mode, pallas_core.Buffered): raise LoweringException( - f"Unsupported pipeline mode: {bm.pipeline_mode}." + f"Unsupported pipeline mode: {pipeline_mode}." + ) + if pipeline_mode.use_lookahead: + raise NotImplementedError( + "Lookahead is not supported for XLA pipeline emitter lowering." ) - buffer_count = bm.pipeline_mode.buffer_count + buffer_count = pipeline_mode.buffer_count if buffer_count < 1 or buffer_count > 2: raise LoweringException( "Only single (1) and double (2) buffering are supported. Got" f" {buffer_count}." ) - pipeline_mode = "synchronous" if buffer_count == 1 else "double_buffered" + pipeline_mode_str = "synchronous" if buffer_count == 1 else "double_buffered" block_params["pipeline_mode"] = ir.Attribute.parse( - f"#tpu.pipeline_mode<{pipeline_mode}>" + f"#tpu.pipeline_mode<{pipeline_mode_str}>" ) window_params.append(ir.DictAttr.get(block_params)) m.body.append(mlir_func) @@ -802,9 +921,11 @@ def dynamic_shape_replacement_fn( else: grid_vars = [] - invars = [invar.aval for invar in jaxpr.invars] + invars = cast( + list[jax_core.ShapedArray], [invar.aval for invar in jaxpr.invars] + ) # Faux shape for grid, just to get the avals - invars.append(jax.ShapeDtypeStruct(grid_vars, jax.numpy.int32)) + invars.append(jax_core.ShapedArray(grid_vars, jnp.int32)) args_dimvars = shape_poly.all_dim_vars(invars) # This is dimexpr var -> placeholder value for when we jit the dim expr @@ -840,21 +961,19 @@ def dynamic_shape_replacement_fn( m.operation.attributes[ "tpu.dynamic_dimension_mapping_arg_name_" + str(placeholder) ] = ir.StringAttr.get(arg_name_str) - return m, mosaic_grid_mapping.get_extra_args() + return m def lower_jaxpr_to_transform_func( - ctx: ir.Context, jaxpr: jax_core.Jaxpr, aval: jax_core.AbstractValue, *, name: str, mosaic_grid_mapping: MosaicGridMapping, - for_verification: bool, - forward_compatible: bool, - dynamic_shape_replacement_fn: ( - Callable[[tuple[jax.DimSize, ...]], tuple[int, ...]] | None - ) = None, + kernel_type: tpu_core.KernelType, + forward_compatible: bool, + backend: Any | None, + dynamic_shape_replacement_fn: DynamicShapeReplacementFn | None = None, ) -> func.FuncOp: num_grid = len(mosaic_grid_mapping.grid_types) arg_types = [ @@ -871,25 +990,18 @@ def body_func(*args): *mosaic_grid_mapping.scalar_prefetch_block_shapes, ] - mesh_info = mosaic_grid_mapping.mesh_info - if mesh_info is not None: - mesh_context = MeshContext( - mesh_info.mesh_shape, mesh_info.axis_names, mesh_info.mesh_strides - ) - else: - mesh_context = None lowering_context = LoweringContext( - ctx, mosaic_grid_mapping.grid, mosaic_grid_mapping.grid_names, - mosaic_grid_mapping.mapped_dims, + mosaic_grid_mapping.vmapped_dims, None, arg_block_shapes, source_info_util.NameStack(), - mesh_context=mesh_context, + mesh_context=mosaic_grid_mapping.mesh_info, + kernel_type=kernel_type, traceback_caches=mlir.TracebackCaches(), - for_verification=for_verification, forward_compatible=forward_compatible, + backend=backend, dynamic_shape_replacement_fn=dynamic_shape_replacement_fn, ) out = jaxpr_subcomp(lowering_context, jaxpr, *jaxpr_indices, @@ -897,13 +1009,13 @@ def body_func(*args): assert isinstance(aval, state.AbstractRef), aval # If we have an extended dtype, we need to add 0s for the block indices # for the remaining physical dtype. - out += [ - ir_constant(0, mlir_type=_dtype_to_ir_type(jnp.dtype("int32"))) - ] * len(_get_aval_physical_dtype_shape(aval.inner_aval)) + out += [ir_constant(0, mlir_type=_dtype_to_ir_type(jnp.int32))] * len( + _get_aval_physical_dtype_shape(aval.inner_aval) + ) return out body_func.__name__ = name - body = func.FuncOp.from_py_func(*arg_types, name=name)(body_func) + body: Any = func.FuncOp.from_py_func(*arg_types, name=name)(body_func) try: body.func_op.verify() except ir.MLIRError as e: @@ -912,16 +1024,14 @@ def body_func(*args): def lower_jaxpr_to_func( - ctx: ir.Context, jaxpr: jax_core.Jaxpr, *, mosaic_grid_mapping: MosaicGridMapping, name: str, - for_verification: bool, + kernel_type: tpu_core.KernelType, forward_compatible: bool, - dynamic_shape_replacement_fn: ( - Callable[[tuple[jax.DimSize, ...]], tuple[int, ...]] | None - ) = None, + backend: Any | None, + dynamic_shape_replacement_fn: DynamicShapeReplacementFn | None = None, dynamic_shape_replacement_enabled: bool = False, ) -> func.FuncOp: num_grid = len(mosaic_grid_mapping.grid_types) @@ -943,32 +1053,25 @@ def body_func(*args): jaxpr_indices = mosaic_grid_mapping.get_grid_indices( grid_indices, maybe_include_mapped_dims=False ) - mesh_info = mosaic_grid_mapping.mesh_info - if mesh_info is not None: - mesh_context = MeshContext( - mesh_info.mesh_shape, mesh_info.axis_names, mesh_info.mesh_strides - ) - else: - mesh_context = None lowering_context = LoweringContext( - ctx, mosaic_grid_mapping.grid, mosaic_grid_mapping.grid_names, - mosaic_grid_mapping.mapped_dims, + mosaic_grid_mapping.vmapped_dims, jaxpr_indices, arg_block_shapes, source_info_util.NameStack(), - mesh_context=mesh_context, + mesh_context=mosaic_grid_mapping.mesh_info, + kernel_type=kernel_type, traceback_caches=mlir.TracebackCaches(), - for_verification=for_verification, forward_compatible=forward_compatible, + backend=backend, dynamic_shape_replacement_fn=dynamic_shape_replacement_fn, ) return jaxpr_subcomp( lowering_context, jaxpr, *scalar_prefetch, *operands_and_scratch ) body_func.__name__ = name - body = func.FuncOp.from_py_func(*arg_types, name=name)(body_func) + body: Any = func.FuncOp.from_py_func(*arg_types, name=name)(body_func) if dynamic_shape_replacement_enabled: # Skip verification for dynamic shape replacement - you can potentially # produce ir like ex: add(x[placeholder_0, placeholder_1], y[128, 128]) @@ -988,8 +1091,8 @@ def f_lowered(ctx: LoweringRuleContext, *args, **params): wrapped_fun = lu.wrap_init( f, params, debug_info=api_util.debug_info("mosaic lower_fun", f, - args, params)) - jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in) + args, {})) + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in) if consts: raise NotImplementedError jaxpr = pe.convert_constvars_jaxpr(jaxpr) @@ -1032,7 +1135,7 @@ def _compute_name_stack_updates( def jaxpr_subcomp( ctx: LoweringContext, jaxpr: jax_core.Jaxpr, *args: ir.Value -) -> Sequence[ir.Value]: +) -> list[ir.Value]: assert not jaxpr.constvars env = {} block_shape_env = {} @@ -1060,42 +1163,53 @@ def write_env(var: jax_core.Var, val): current_name_stack.extend(initial_name_stack) for eqn in jaxpr.eqns: invals = map(read_env, eqn.invars) - source_info = eqn.source_info.replace( - name_stack=ctx.name_stack + eqn.source_info.name_stack + # Skip lowering equations that don't have used outputs and no side-effects. + # This allows us to avoid lowering eqns that have unlowerable types (e.g. + # float0) in them. + # TODO(sharadmv): Remove this when DCEing the jaxpr works properly. + if ( + all(isinstance(v, jax_core.DropVar) for v in eqn.outvars) + and not eqn.effects + ): + continue + eqn_name_stack = ctx.name_stack + eqn.source_info.name_stack + loc = mlir.source_info_to_location( # pytype: disable=wrong-arg-types + ctx, eqn.primitive, eqn_name_stack, eqn.source_info.traceback ) - loc = mlir._source_info_to_location(ctx, eqn.primitive, source_info) with (source_info_util.user_context(eqn.source_info.traceback), loc, eqn.ctx.manager): - if eqn.primitive in lowering_rules: - if eqn.primitive not in skip_mlir_conversions: - invals = [_ensure_mlir_value(x, v.aval) - for x, v in zip(invals, eqn.invars)] + if eqn.primitive in lowering_rules[ctx.kernel_type]: + if (eqn.primitive, ctx.kernel_type) not in skip_mlir_conversions: + invals = [ + _ensure_mlir_value(x, cast(ShapedAbstractValue, v.aval)) + for x, v in zip(invals, eqn.invars) + ] block_shapes = map(read_block_shape, eqn.invars) rule_context = LoweringRuleContext( ctx, - [v.aval for v in eqn.invars], - [v.aval for v in eqn.outvars], + cast(Sequence[ShapedAbstractValue], [v.aval for v in eqn.invars]), + cast(Sequence[ShapedAbstractValue], [v.aval for v in eqn.outvars]), block_shapes, ) # Insert trace_start and trace_stop ops on named_scope boundaries. - name_stack = [scope.name for scope in source_info.name_stack.stack] + name_stack = [scope.name for scope in eqn_name_stack.stack] popped, pushed = _compute_name_stack_updates( current_name_stack, name_stack) current_name_stack = name_stack for _ in popped: - tpu.TraceStopOp() + tpu.trace_stop() for name in pushed: - tpu.TraceStartOp(message=name, level=10) + tpu.trace_start(message=name, level=10) try: - ans = lowering_rules[eqn.primitive]( + ans = lowering_rules[ctx.kernel_type][eqn.primitive]( rule_context, *invals, **eqn.params ) except LoweringException: raise # We only add the extra info to the innermost exception. except Exception as e: - if not pallas_call._verbose_errors_enabled(): + if not config.jax_pallas_verbose_errors.value: raise msg = (f"{type(e).__name__}: {e}\n" + "Additional diagnostics: \n" + @@ -1103,15 +1217,16 @@ def write_env(var: jax_core.Var, val): new_error = LoweringException(msg) # We insert the traceback here so that the user code shows # up in the traceback for the post-transform error. - if source_info.traceback is not None: - tb = source_info.traceback.as_python_traceback() + if eqn.source_info.traceback is not None: + tb = eqn.source_info.traceback.as_python_traceback() new_error.__traceback__ = traceback_util.filter_traceback(tb) raise new_error from e else: raise NotImplementedError( - "Unimplemented primitive in Pallas TPU lowering: " - f"{eqn.primitive.name}. " - "Please file an issue on https://github.com/jax-ml/jax/issues.") + "Unimplemented primitive in Pallas TPU lowering for" + f" {ctx.kernel_type}: {eqn.primitive.name}. Please file an issue on" + " https://github.com/jax-ml/jax/issues." + ) if eqn.primitive.multiple_results: foreach(write_env, eqn.outvars, ans) else: @@ -1121,7 +1236,7 @@ def write_env(var: jax_core.Var, val): popped, pushed = _compute_name_stack_updates( current_name_stack, initial_name_stack) for _ in popped: - tpu.TraceStopOp() + tpu.trace_stop() assert len(pushed) == 0 outvals = map(read_env, jaxpr.outvars) @@ -1132,12 +1247,14 @@ def write_env(var: jax_core.Var, val): return outvals -def _ensure_mlir_value(val, aval): +def _ensure_mlir_value(val: object, aval: ShapedAbstractValue) -> Any: if isinstance(val, ir.Value): return val if isinstance(val, KeyScalarBundle): + # TODO(slebedev): Drop this branch and change the return type to ir.Value. return val - elif isinstance(val, (np.generic, np.ndarray, int, float)): + elif isinstance(val, (np.generic, np.ndarray, int, float, + literals.TypedNdArray)): return ir_constant(val, _dtype_to_ir_type(aval.dtype)) else: raise RuntimeError( @@ -1145,6 +1262,7 @@ def _ensure_mlir_value(val, aval): ) +@register_lowering_rule(state_primitives.get_p, ensure_mlir_values=False) def _get_lowering_rule( ctx: LoweringRuleContext, ref, *idx, tree, ): @@ -1161,10 +1279,7 @@ def _get_lowering_rule( return _load_lowering_rule(ctx, *args_flat, args_tree=args_tree) -lowering_rules[state_primitives.get_p] = _get_lowering_rule -skip_mlir_conversions.add(state_primitives.get_p) - - +@register_lowering_rule(state_primitives.swap_p, ensure_mlir_values=False) def _swap_lowering_rule( ctx: LoweringRuleContext, ref, @@ -1186,12 +1301,9 @@ def _swap_lowering_rule( ) return _masked_swap_lowering_rule(ctx, *args_flat, args_tree=args_tree) -lowering_rules[state_primitives.swap_p] = _swap_lowering_rule -skip_mlir_conversions.add(state_primitives.swap_p) - def _make_index(s): - if isinstance(s, (int, np.ndarray)): + if isinstance(s, (int, np.ndarray, literals.TypedNdArray)): return ir_constant(s, ir.IndexType.get()) if s.type == ir.IndexType.get(): return s @@ -1230,7 +1342,7 @@ def _index_to_start_size_stride( def _indexer_to_start_size_stride( indexer: NDIndexer, - ref_block_shape: tuple[int | pallas_core.Mapped, ...], + ref_block_shape: tuple[int | pallas_core.Squeezed, ...], *, cast_to_index: bool, ) -> tuple[ @@ -1238,21 +1350,21 @@ def _indexer_to_start_size_stride( tuple[int | ir.Value, ...], tuple[int, ...], tuple[bool, ...], - tuple[int | pallas_core.Mapped, ...], + tuple[int | pallas_core.Squeezed, ...], ]: indices_iter = iter(indexer.indices) starts, sizes, strides, squeeze_dims = [], [], [], [] for s in ref_block_shape: - start, size, stride, squeeze_dim = ( - ( - _maybe_cast_to_index(cast_to_index, 0), - 1, - 1, - True, + match s: + case pallas_core.Squeezed(): + start = _maybe_cast_to_index(cast_to_index, 0) + size = 1 + stride = 1 + squeeze_dim = True + case _: + start, size, stride, squeeze_dim = _index_to_start_size_stride( + next(indices_iter), cast_to_index # pytype: disable=wrong-arg-types ) - if s is pallas_core.mapped - else _index_to_start_size_stride(next(indices_iter), cast_to_index) - ) starts.append(start) sizes.append(size) strides.append(stride) @@ -1274,10 +1386,9 @@ def _slice_memref( ref: ir.Value, indexer: NDIndexer, ref_dtype: DTypeLike, - ref_block_shape: tuple[int | pallas_core.Mapped, ...], -) -> tuple[ir.Value, tuple[int | pallas_core.Mapped, ...]]: + ref_block_shape: tuple[int | pallas_core.Squeezed, ...], +) -> tuple[ir.Value, tuple[int | pallas_core.Squeezed, ...]]: assert ref_block_shape is not None - target_shape = indexer.get_indexer_shape() starts, sizes, strides, squeeze_dims, ref_block_shape = ( _indexer_to_start_size_stride( indexer, @@ -1287,26 +1398,43 @@ def _slice_memref( ) if not all((s is None or s == 1) for s in strides): raise NotImplementedError("Strided slices of references are unsupported.") - dynamic_sizes = tuple(s for s in sizes if isinstance(s, ir.Value)) + ir_dynamic_size = ir.ShapedType.get_dynamic_size() - static_sizes = tuple(s if not isinstance(s, ir.Value) - else ir_dynamic_size for s in sizes) - target_ref_ty = ir.MemRefType.get( - static_sizes, - _dtype_to_ir_type(ref_dtype), - memory_space=ref.type.memory_space, + static_starts = [] + for s in starts: + if not isinstance(s, ir.Value): + static_starts.append(s) + elif (v := _fold_and_get_constant_value(s)) is not None: + static_starts.append(v) + else: + static_starts.append(ir_dynamic_size) + + static_sizes = [] + dynamic_sizes = [] + for s in sizes: + if not isinstance(s, ir.Value): + static_sizes.append(s) + elif (v := _fold_and_get_constant_value(s)) is not None: + static_sizes.append(v) + else: + static_sizes.append(ir_dynamic_size) + dynamic_sizes.append(s) + + ref_ty = ir.MemRefType(ref.type) + out_ty = ir.MemRefType.get( + static_sizes, ref_ty.element_type, memory_space=ref_ty.memory_space ) - out = tpu.memref_slice(target_ref_ty, ref, starts, dynamic_sizes) + out = tpu.memref_slice(out_ty, ref, starts, dynamic_sizes) if any(squeeze_dims): - # We need to squeeze out some dimensions - static_sizes = tuple(s if not isinstance(s, ir.Value) - else ir_dynamic_size for s in target_shape) - squeezed_ref_ty = ir.MemRefType.get( - static_sizes, - _dtype_to_ir_type(ref_dtype), - memory_space=ref.type.memory_space, - ) - out = tpu.memref_squeeze(squeezed_ref_ty, out) + # We need to squeeze out some dimensions. + ref_ty = out_ty + del out_ty + out_ty = ir.MemRefType.get( + [dim for i, dim in enumerate(ref_ty.shape) if not squeeze_dims[i]], + ref_ty.element_type, + memory_space=ref_ty.memory_space + ) + out = tpu.memref_squeeze(out_ty, out) return out, ref_block_shape @@ -1314,30 +1442,31 @@ def _bitcast_memref( ref: ir.Value, bitcaster: RefBitcaster, ref_dtype: DTypeLike, - ref_block_shape: tuple[int | pallas_core.Mapped, ...], -) -> tuple[ir.Value, DTypeLike, tuple[int | pallas_core.Mapped, ...]]: - src_bitwidth = dtype_bitwidth(ref_dtype) - dst_bitwidth = dtype_bitwidth(bitcaster.dtype) + ref_block_shape: tuple[int | pallas_core.Squeezed, ...], +) -> tuple[ir.Value, DTypeLike, tuple[int | pallas_core.Squeezed, ...]]: + src_bitwidth = dtypes.itemsize_bits(ref_dtype) + dst_bitwidth = dtypes.itemsize_bits(bitcaster.dtype) if src_bitwidth != dst_bitwidth: if len(ref_block_shape) < 2: raise NotImplementedError( "Bitcast 1D ref with bitwidth change is not supported." ) - if ref_block_shape[-2] is pallas_core.mapped: + if ref_block_shape[-2] is pallas_core.squeezed: raise NotImplementedError( "Bitcast a ref whose 2nd minormost dimension is squeezed when" " bitwidth changes." ) new_ref_dtype = bitcaster.dtype + ref_ty = ir.MemRefType(ref.type) target_ref_ty = ir.MemRefType.get( bitcaster.shape, _dtype_to_ir_type(new_ref_dtype), - memory_space=ref.type.memory_space, + memory_space=ref_ty.memory_space, ) new_ref_block_shape = list(ref_block_shape) if ( len(new_ref_block_shape) >= 2 - and new_ref_block_shape[-2] is not pallas_core.mapped + and new_ref_block_shape[-2] is not pallas_core.squeezed ): new_ref_block_shape[-2] = ( new_ref_block_shape[-2] * src_bitwidth // dst_bitwidth @@ -1353,8 +1482,8 @@ def _reshape_memref( ref: ir.Value, reshaper: RefReshaper, ref_dtype: DTypeLike, - ref_block_shape: tuple[int | pallas_core.Mapped, ...], -) -> tuple[ir.Value, DTypeLike, tuple[int | pallas_core.Mapped, ...]]: + ref_block_shape: tuple[int | pallas_core.Squeezed, ...], +) -> tuple[ir.Value, tuple[int, ...]]: if ref_dtype != reshaper.dtype: raise ValueError( f"Reshape a ref with dtype change: {reshaper.dtype} vs {ref_dtype}" @@ -1362,8 +1491,8 @@ def _reshape_memref( if len(ref_block_shape) < 2: raise NotImplementedError("Reshape 1D ref is not supported.") if ( - ref_block_shape[-2] is pallas_core.mapped - or ref_block_shape[-1] is pallas_core.mapped + ref_block_shape[-2] is pallas_core.squeezed + or ref_block_shape[-1] is pallas_core.squeezed ): raise NotImplementedError( "Reshape a ref with squeezed dimension on last two dimensions." @@ -1373,10 +1502,11 @@ def _reshape_memref( f"Reshape a ref with different number of elements: {ref_block_shape} " f"vs {reshaper.shape}" ) + ref_ty = ir.MemRefType(ref.type) target_ref_ty = ir.MemRefType.get( reshaper.shape, _dtype_to_ir_type(reshaper.dtype), - memory_space=ref.type.memory_space, + memory_space=ref_ty.memory_space, ) return ( tpu.memref_reshape(target_ref_ty, ref), @@ -1419,15 +1549,38 @@ class KeyScalarBundle: lowering pass. """ key_shape: tuple[int, ...] - scalars: list[ir.OpResult] + scalars: Sequence[ir.OpResult] + +def _canonicalize_transforms_to_indexer( + ref_aval, + transforms, + transforms_avals, +): + if not transforms: + prev_transforms, idx = [], NDIndexer.make_trivial_indexer(ref_aval.shape) + else: + if not isinstance(transforms[-1], NDIndexer): + ref_shape = state.get_transforms_shape(transforms, ref_aval.shape) + idx = NDIndexer.make_trivial_indexer(ref_shape) + prev_transforms = transforms + else: + (*prev_transforms, idx) = transforms + (*_, idx_aval) = transforms_avals + if any( + (not isinstance(a, primitives.Slice) and a.shape) + for a in idx_aval.indices + ): + raise ValueError("Cannot do int indexing on TPU") + return prev_transforms, idx + +@register_lowering_rule(primitives.load_p, ensure_mlir_values=False) def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_): ref, transforms, mask, _ = args_tree.unflatten(args_flat) ref_aval, transforms_avals, _, _ = args_tree.unflatten(ctx.avals_in) - (*prev_transforms, idx) = transforms - # Select last aval, which is the one that will be used for the load. - (*_, idx_aval) = transforms_avals - + prev_transforms, idx = _canonicalize_transforms_to_indexer( + ref_aval, transforms, transforms_avals + ) if mask is not None: raise NotImplementedError @@ -1441,19 +1594,34 @@ def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_): if isinstance(aval_out.dtype, prng.KeyTy) and pl_random.is_pallas_impl( aval_out.dtype._impl ): + # TODO(justinfu): Merge this with standard extended dtype handling. if not is_smem_load: raise ValueError("PRNG keys must be loaded from SMEM. Did you set " - "the memory space to TPUMemorySpace.SMEM in the " + "the memory space to MemorySpace.SMEM in the " "BlockSpec for the PRNG key input?") return _prng_key_load_lowering_rule(ctx, *args_flat, args_tree=args_tree) + if should_physicalize_dtype(aval_out.dtype): + physical_element_aval = jax_core.physical_element_aval(aval_out.dtype) # pytype: disable=wrong-arg-types + idx = cast(NDIndexer, idx) + if idx.int_indexer_shape: + raise NotImplementedError() + elt_slices = [ + indexing.Slice(0, size) for size in physical_element_aval.shape] + idx = NDIndexer( + indices=idx.indices + tuple(elt_slices), + shape=idx.shape + physical_element_aval.shape, + int_indexer_shape=(), + ) + physical_out_dtype = physical_element_aval.dtype + physical_out_shape = jax_core.physical_shape( + aval_out.shape, aval_out.dtype + ) + else: + physical_out_dtype = aval_out.dtype + physical_out_shape = aval_out.shape if not is_smem_load and not ref_block_shape: raise NotImplementedError( "Indexing into a ()-shaped Ref not yet supported on TPU.") - if any( - (not isinstance(a, primitives.Slice) and a.shape) - for a in idx_aval.indices - ): - raise ValueError("Cannot do int indexing on TPU") starts, sizes, strides, _, _ = _indexer_to_start_size_stride( idx, ref_block_shape, @@ -1471,7 +1639,7 @@ def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_): raise ValueError( "Loads are only allowed on VMEM and SMEM references." + extra ) - load_aval = jax_core.ShapedArray(sizes, dtype=aval_out.dtype) + load_aval = jax_core.ShapedArray(sizes, dtype=physical_out_dtype) if need_stride: load_val = tpu.strided_load( aval_to_ir_type( @@ -1494,10 +1662,13 @@ def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_): starts, ) if load_aval != aval_out: - vec_type = ir.VectorType.get(aval_out.shape, - _dtype_to_ir_type(aval_out.dtype, - is_kernel_boundary=True)) - load_val = vector.shape_cast(vec_type, load_val) + if physical_out_shape: + vec_type = ir.VectorType.get(physical_out_shape, + _dtype_to_ir_type(physical_out_dtype, + is_kernel_boundary=True)) + load_val = vector.shape_cast(vec_type, load_val) + else: + load_val = vector.extract(load_val, [], [0] * len(load_aval.shape)) return _maybe_cast_load_to_bool(ctx, aval_out, load_val) def _prng_key_load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree) -> KeyScalarBundle: @@ -1508,36 +1679,50 @@ def _prng_key_load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree We store these scalars in a bundle type called KeyScalarBundle, which has special case handling for functions that consume the key such as set_seed. """ - ref, _, _, _ = args_tree.unflatten(args_flat) + ref, transforms, _, _ = args_tree.unflatten(args_flat) + ref_aval, transforms_avals, _, _ = args_tree.unflatten( + ctx.avals_in + ) + prev_transforms, idx = _canonicalize_transforms_to_indexer( + ref_aval, transforms, transforms_avals + ) (aval_out,) = ctx.avals_out assert isinstance(aval_out.dtype, prng.KeyTy) - ref_block_shape = aval_out.dtype._impl.key_shape + key_shape = aval_out.dtype._impl.key_shape + ref_block_shape, *_ = ctx.block_shapes + idx = cast(NDIndexer, idx) + inner_aval = jax_core.physical_aval(ref_aval.inner_aval) # pytype: disable=wrong-arg-types + ref, ref_block_shape = _transform_ref( + ref, inner_aval.dtype, ref_block_shape, prev_transforms + ) - if len(ref_block_shape) != 2: - raise NotImplementedError("Seed key_data must be 2D.") - if tuple(ref_block_shape) != (1, 1): - raise NotImplementedError( - f"Seed key_data of shape != (1, 1) not supported. Got: {ref_block_shape}") + if len(key_shape) != 2: + raise NotImplementedError("Seed key_data must be 1D.") + if key_shape[0] != 1: + raise NotImplementedError("Leading dimension of seed key_data must be 1.") + if not all(s == 1 for s in idx.shape): + raise NotImplementedError("Can only load a single key per load.") + assert ref_block_shape[-2:] == key_shape, f"{ref_block_shape=} {key_shape=}" load_ops = [] - for i in range(ref_block_shape[0]): - idx = NDIndexer(indices=(0, i), shape=ref_block_shape, - int_indexer_shape=tuple()) + for i in range(key_shape[1]): + ref_shape = tuple( + dim for dim in ref_block_shape if dim is not pallas_core.squeezed + ) + scalar_idx = NDIndexer( + indices=(*idx.indices, 0, i), shape=ref_shape, int_indexer_shape=() + ) starts, _, _, _, _ = _indexer_to_start_size_stride( - idx, + scalar_idx, ref_block_shape, cast_to_index=True, ) load_ops.append(memref.load(ref, starts)) - return KeyScalarBundle(scalars=load_ops, key_shape=tuple(ref_block_shape)) - - -lowering_rules[primitives.load_p] = _load_lowering_rule -skip_mlir_conversions.add(primitives.load_p) + return KeyScalarBundle(scalars=load_ops, key_shape=tuple(key_shape)) def _maybe_cast_load_to_bool( - ctx, out_aval, val: ir.Value + ctx: LoweringRuleContext, out_aval, val: ir.Value ) -> tuple[ir.Value, jnp.dtype]: """Casts a memref load value to bool if the requested value is a bool. @@ -1564,13 +1749,13 @@ def _maybe_cast_load_to_bool( out_aval, is_kernel_boundary=True, ) - vector_zeros = arith.ConstantOp( + vector_zeros = arith.constant( load_vector_type, ir.DenseElementsAttr.get_splat(load_vector_type, const_zero) ) return arith.cmpi(predicate, val, vector_zeros) else: # Scalar case. - const_zero = arith.ConstantOp(load_scalar_type, const_zero) + const_zero = arith.constant(load_scalar_type, const_zero) return arith.cmpi(predicate, val, const_zero) @@ -1588,6 +1773,7 @@ def _maybe_cast_store_to_memref_type( return arith.extui(int_out_type, val) +@register_lowering_rule(primitives.swap_p, ensure_mlir_values=False) def _masked_swap_lowering_rule( ctx: LoweringRuleContext, *args_flat, args_tree, **_ ): @@ -1595,8 +1781,9 @@ def _masked_swap_lowering_rule( ref_aval, transforms_avals, val_aval, mask_aval = args_tree.unflatten( ctx.avals_in ) - (*prev_transforms, idx) = transforms - (*_, idx_aval) = transforms_avals + prev_transforms, idx = _canonicalize_transforms_to_indexer( + ref_aval, transforms, transforms_avals + ) if mask is not None: if val_aval.dtype.itemsize != 4: @@ -1619,11 +1806,6 @@ def _masked_swap_lowering_rule( (aval_out,) = ctx.avals_out if not isinstance(val, ir.Value): val = ir_constant(val, mlir_type=_dtype_to_ir_type(val_aval.dtype)) - if any( - (not isinstance(a, primitives.Slice) and a.shape) - for a in idx_aval.indices - ): - raise ValueError("Cannot do int indexing on TPU") if not is_smem_store and not ref_block_shape: raise NotImplementedError( "Indexing into a ()-shaped Ref not yet supported on TPU.") @@ -1643,7 +1825,7 @@ def _masked_swap_lowering_rule( result = memref.load(ref, starts) result = _maybe_cast_load_to_bool(ctx, val_aval, result) val = _maybe_cast_store_to_memref_type(ctx, val_aval, val) - memref.StoreOp(val, ref, starts) + memref.store(val, ref, starts) return result if not is_vmem_store: @@ -1659,12 +1841,12 @@ def _masked_swap_lowering_rule( raise ValueError("Cannot store scalars to VMEM") mem_slice_shape = list(aval_out.shape) - for i, a in enumerate(idx_aval.indices): + for i, a in enumerate(idx.indices): if not isinstance(a, primitives.Slice): mem_slice_shape.insert(i, 1) mem_slice_shape_iter = iter(mem_slice_shape) mem_slice_shape = [ - 1 if b is pallas_core.mapped else next(mem_slice_shape_iter) + 1 if b is pallas_core.squeezed else next(mem_slice_shape_iter) for b in ref_block_shape ] mem_aval = aval_out.update( @@ -1682,6 +1864,8 @@ def _masked_swap_lowering_rule( result = vector.load(mem_aval_vec_type, ref, starts) val = _maybe_cast_store_to_memref_type(ctx, val_aval, val) if mem_aval != aval_out: + if not aval_out.shape: + raise ValueError("Cannot swap scalars to VMEM.") # We are slicing a scalar so provided dummy 1 indices result_vec_type = ir.VectorType.get(aval_out.shape, _dtype_to_ir_type(aval_out.dtype, is_kernel_boundary=True)) @@ -1689,21 +1873,25 @@ def _masked_swap_lowering_rule( val_vec_type = ir.VectorType.get(mem_aval.shape, _dtype_to_ir_type(mem_aval.dtype, is_kernel_boundary=True)) val = vector.shape_cast(val_vec_type, val) + if mask is not None: + mask_vec_type = ir.VectorType.get( + mem_aval.shape, _dtype_to_ir_type(mask_aval.dtype) + ) + mask = vector.shape_cast(mask_vec_type, mask) result = _maybe_cast_load_to_bool(ctx, val_aval, result) if need_stride: if mask is not None: raise NotImplementedError("masked swap with strided store") - tpu.StridedStoreOp(val, ref, starts, strides) + tpu.strided_store(val, ref, starts, strides) else: - tpu.VectorStoreOp(val, ref, starts, [], mask=mask) + tpu.vector_store(val, ref, starts, [], mask=mask) return result -lowering_rules[primitives.swap_p] = _masked_swap_lowering_rule -skip_mlir_conversions.add(primitives.swap_p) - - +@register_lowering_rule( + primitives.multiple_of_p, kernel_types=[*tpu_core.KernelType] +) def _multiple_of_lowering_rule(ctx: LoweringRuleContext, val, *, values): del ctx for multiple in values: @@ -1711,11 +1899,8 @@ def _multiple_of_lowering_rule(ctx: LoweringRuleContext, val, *, values): return val -lowering_rules[primitives.multiple_of_p] = _multiple_of_lowering_rule - - def reduce_lowering_rule(reduce_fn, type_to_kind, type_to_identity): - def _lowering_rule(ctx: LoweringRuleContext, x, *, axes): + def _lowering_rule(ctx: LoweringRuleContext, x, *, axes, **kwargs): (x_aval,) = ctx.avals_in if not ctx.avals_out[0].shape: # If reducing to a scalar, we reduce by adding a leading singleton @@ -1759,7 +1944,7 @@ def _proxy_fun(val, *, axes): ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_out[0] ) identity = ir.DenseElementsAttr.get_splat(out_type, val) - acc = arith.ConstantOp(out_type, identity) + acc = arith.constant(out_type, identity) return vector.multi_reduction(kind, x, acc, axes) return _lowering_rule @@ -1775,7 +1960,7 @@ def _proxy_fun(val, *, axes): } _reduce_max_lowering_rule = reduce_lowering_rule( jnp.max, REDUCE_MAX_KINDS, REDUCE_MAX_IDENTITY) -lowering_rules[lax.reduce_max_p] = _reduce_max_lowering_rule +register_lowering_rule(lax.reduce_max_p)(_reduce_max_lowering_rule) REDUCE_MIN_KINDS = { @@ -1789,7 +1974,7 @@ def _proxy_fun(val, *, axes): } _reduce_min_lowering_rule = reduce_lowering_rule( jnp.min, REDUCE_MIN_KINDS, REDUCE_MIN_IDENTITY) -lowering_rules[lax.reduce_min_p] = _reduce_min_lowering_rule +register_lowering_rule(lax.reduce_min_p)(_reduce_min_lowering_rule) REDUCE_SUM_KINDS = { @@ -1803,9 +1988,10 @@ def _proxy_fun(val, *, axes): } _reduce_sum_lowering_rule = reduce_lowering_rule( jnp.sum, REDUCE_SUM_KINDS, REDUCE_SUM_IDENTITY) -lowering_rules[lax.reduce_sum_p] = _reduce_sum_lowering_rule +register_lowering_rule(lax.reduce_sum_p)(_reduce_sum_lowering_rule) +@register_lowering_rule(lax.reduce_and_p) def _reduce_and_lowering_rule(ctx: LoweringRuleContext, x, *, axes): def _proxy_reduce(arg, *, axes): # Mosaic currently only supports float reductions, so we cast the boolean @@ -1818,9 +2004,8 @@ def _proxy_reduce(arg, *, axes): _proxy_reduce, multiple_results=False) return proxy_lowering(ctx, x, axes=axes) -lowering_rules[lax.reduce_and_p] = _reduce_and_lowering_rule - +@register_lowering_rule(lax.reduce_or_p) def _reduce_or_lowering_rule(ctx: LoweringRuleContext, x, *, axes): def _proxy_reduce(arg, *, axes): # Mosaic currently only supports float reductions, so we cast the boolean @@ -1833,9 +2018,8 @@ def _proxy_reduce(arg, *, axes): _proxy_reduce, multiple_results=False) return proxy_lowering(ctx, x, axes=axes) -lowering_rules[lax.reduce_or_p] = _reduce_or_lowering_rule - +@register_lowering_rule(state_primitives.broadcast_to_p) def _broadcast_to_lowering_rule( ctx: LoweringRuleContext, x, shape: Sequence[int] ): @@ -1845,29 +2029,33 @@ def _broadcast_to_lowering_rule( ) -lowering_rules[state_primitives.broadcast_to_p] = _broadcast_to_lowering_rule - - +@register_lowering_rule( + lax.broadcast_in_dim_p, kernel_types=[*tpu_core.KernelType] +) def _broadcast_in_dim_lowering_rule( ctx: LoweringRuleContext, val, *, shape, broadcast_dimensions, sharding ): del sharding (aval_in,) = ctx.avals_in (aval_out,) = ctx.avals_out + if aval_in.shape == shape: + return val - if jnp.issubdtype(aval_in.dtype, jnp.bool_): + if jnp.issubdtype(aval_in.dtype, jnp.bool_) and ( + ctx.forward_compatible or ctx.is_cloud_tpu_older_than(2025, 6, 3) + ): # Direct broadcasts for bools are not supported in Mosaic due to booleans # living in mask registers and broadcast operating on vregs. Broadcast as an # integer instead and cast back to a bool. - # TODO(b/351019164): Implement this logic in Mosaic BroadcastOp instead. def _proxy_fun(val, *, shape, broadcast_dimensions): int_val = jnp.where(val, 1, 0) bcast_val = jax.lax.broadcast_in_dim(int_val, shape, broadcast_dimensions) return bcast_val == 1 - proxy_lowering = lower_fun( - _proxy_fun, multiple_results=False) + + proxy_lowering = lower_fun(_proxy_fun, multiple_results=False) return proxy_lowering( - ctx, val, shape=shape, broadcast_dimensions=broadcast_dimensions) + ctx, val, shape=shape, broadcast_dimensions=broadcast_dimensions + ) if broadcast_dimensions: out_shape_list = [1] * len(shape) @@ -1886,9 +2074,6 @@ def _proxy_fun(val, *, shape, broadcast_dimensions): return vector.broadcast(out_type, val) -lowering_rules[lax.broadcast_in_dim_p] = _broadcast_in_dim_lowering_rule - - def jax_dot_dims_to_tpu_dot_dot_dims(dimension_numbers, lhs_shape, rhs_shape): """Converts a jax dot dimension numbers to a tpu dot dimension numbers. @@ -1954,6 +2139,7 @@ def format_dims(dims): return ir.Attribute.parse(tpu_dim_numbers_str) +@register_lowering_rule(lax.dot_general_p) def _dot_general_lowering_rule( ctx: LoweringRuleContext, x, @@ -1968,9 +2154,10 @@ def _dot_general_lowering_rule( out_type = aval_to_ir_type( ctx.lowering_context.dynamic_shape_replacement_fn, aval_out ) - val_type = out_type.element_type + assert isinstance(out_type, ir.ShapedType) + val_type = ir.ShapedType(out_type).element_type if any( - cls.isinstance(val_type) + isinstance(val_type, cls) for cls in [ ir.BF16Type, ir.F32Type, @@ -1980,13 +2167,25 @@ def _dot_general_lowering_rule( ] ): val = ir.FloatAttr.get(val_type, 0.0) - elif ir.IntegerType.isinstance(val_type): + elif isinstance(val_type, ir.IntegerType): val = ir.IntegerAttr.get(val_type, 0) else: raise NotImplementedError(ctx.avals_out[0].dtype) lhs_aval, rhs_aval = ctx.avals_in # This is really a matrix-vector product. It only looks like matrix-matrix. - if lhs_dims == (1,) and rhs_dims == (1,) and ctx.avals_in[1].shape[0] == 1: + if ( + lhs_dims == (1,) + and rhs_dims == (1,) + and ctx.avals_in[1].shape[0] == 1 + and len(ctx.avals_in[0].shape) == 2 + and len(ctx.avals_in[1].shape) == 2 + and ( + lhs_aval.dtype != jnp.float32 + or rhs_aval.dtype != jnp.float32 + or ctx.forward_compatible + or ctx.is_cloud_tpu_older_than(2025, 8, 10) + ) + ): if ctx.avals_in[0].shape != ctx.avals_in[1].shape: bcast_shape = jnp.broadcast_shapes( ctx.avals_in[0].shape, ctx.avals_out[0].shape @@ -2026,12 +2225,12 @@ def _dot_general_lowering_rule( else: raise NotImplementedError(f"Unsupported {preferred_element_type=}") - acc = arith.ConstantOp( + acc = arith.constant( red_type, ir.DenseElementsAttr.get_splat(red_type, val) ) - red = vector.MultiDimReductionOp( + red = vector.multi_reduction( ir.Attribute.parse("#vector.kind"), - arith.MulFOp(x, y), + arith.mulf(x, y), acc, [1] ) @@ -2053,7 +2252,7 @@ def _dot_general_lowering_rule( ) else: raise NotImplementedError(f"Unsupported dot precision: {precision}") - out_tile = arith.ConstantOp( + out_tile = arith.constant( out_type, ir.DenseElementsAttr.get_splat(out_type, val) ) return tpu.matmul( @@ -2066,11 +2265,11 @@ def _dot_general_lowering_rule( ) -lowering_rules[lax.dot_general_p] = _dot_general_lowering_rule - -def _convert_helper(x, *, to_dtype): +def _convert_helper(x: Array, *, to_dtype: jnp.dtype) -> Array: # Helper function for dtype conversion from_dtype = x.dtype + from_bitwidth = dtypes.itemsize_bits(from_dtype) + to_bitwidth = dtypes.itemsize_bits(to_dtype) if from_dtype == jnp.bool_: x = x.astype(jnp.int32) return _convert_helper(x, to_dtype=to_dtype) @@ -2078,36 +2277,30 @@ def _convert_helper(x, *, to_dtype): # Lower float32 or (u)int32 -> bool to cmp neq %in, 0 # TODO(apaszke,mvoz): Move the upcasts for cmpi to the Mosaic canonicalizer. if jnp.issubdtype(from_dtype, jnp.floating): - if from_dtype.itemsize < 4: + if from_bitwidth < 32: x = x.astype(jnp.float32) elif jnp.issubdtype(from_dtype, jnp.integer): - if from_dtype.itemsize < 4: + if from_bitwidth < 32: x = x.astype(jnp.int32) return x != jnp.asarray(0, dtype=x.dtype) if jnp.issubdtype(from_dtype, jnp.signedinteger): - if from_dtype.itemsize < 4: + if from_bitwidth < 32: x = x.astype(jnp.int32) - if jnp.issubdtype(to_dtype, jnp.floating) and to_dtype.itemsize < 4: + if jnp.issubdtype(to_dtype, jnp.floating) and to_bitwidth < 32: x = x.astype(jnp.float32) return x.astype(to_dtype) if jnp.issubdtype(from_dtype, jnp.unsignedinteger): - if from_dtype.itemsize < 4: + if from_bitwidth < 32: x = x.astype(jnp.uint32) # unsigned -> float is unsupported. We fall through and raise at the bottom. if not jnp.issubdtype(to_dtype, jnp.floating): return x.astype(to_dtype) - if jnp.issubdtype(from_dtype, jnp.floating) and jnp.issubdtype( - to_dtype, jnp.signedinteger - ): - if from_dtype.itemsize < 4: - x = x.astype(jnp.float32) - if to_dtype.itemsize < 4: - # Need to clip values to match XLA - minval, maxval = jnp.iinfo(to_dtype).min, jnp.iinfo(to_dtype).max - x = jnp.clip(x, minval, maxval) - return x.astype(jnp.int32).astype(to_dtype) raise NotImplementedError(f"Unsupported cast: {from_dtype} -> {to_dtype}") + +@register_lowering_rule( + lax.convert_element_type_p, kernel_types=[*tpu_core.KernelType] +) def _convert_element_type_lowering_rule( ctx: LoweringRuleContext, x, *, new_dtype, weak_type, sharding ): @@ -2131,39 +2324,48 @@ def _convert_element_type_lowering_rule( floating = jnp.floating integer = jnp.integer signed = jnp.signedinteger - both_32bit = old_dtype.itemsize == 4 and new_dtype.itemsize == 4 + unsigned = jnp.unsignedinteger + old_bitwidth = dtypes.itemsize_bits(old_dtype) + new_bitwidth = dtypes.itemsize_bits(new_dtype) + both_32bit = old_bitwidth == 32 and new_bitwidth == 32 if _from(floating) and _to(floating): - if old_dtype.itemsize < new_dtype.itemsize and new_dtype.itemsize == 4: + forward_compat = ctx.forward_compatible or ctx.is_cloud_tpu_older_than( + 2025, 6, 29 + ) + if old_bitwidth < new_bitwidth and ( + new_bitwidth == 32 or not forward_compat + ): return arith.extf(out_type, x) - elif old_dtype.itemsize > new_dtype.itemsize and old_dtype.itemsize == 4: + elif old_bitwidth > new_bitwidth and ( + old_bitwidth == 32 or not forward_compat + ): return arith.truncf(out_type, x) elif _from(integer) and _to(integer): - if old_dtype.itemsize < new_dtype.itemsize and new_dtype.itemsize == 4: - if not (_from(signed) and _to(signed)): - raise NotImplementedError(f"Unsupported cast: {old_dtype} -> {new_dtype}") - return arith.extsi(out_type, x) - elif old_dtype.itemsize > new_dtype.itemsize and old_dtype.itemsize == 4: + if old_bitwidth < new_bitwidth and new_bitwidth == 32: + if _from(unsigned): + return arith.extui(out_type, x) + if _from(signed): + return arith.extsi(out_type, x) + elif old_bitwidth > new_bitwidth and old_bitwidth == 32: return arith.trunci(out_type, x) elif jnp.iinfo(old_dtype).bits == jnp.iinfo(new_dtype).bits: # This case triggers when casting signed to unsigned or vice versa. return x - # TODO(apaszke): Remove both_32bit constraints using the Mosaic canonicalizer. elif _from(floating) and _to(signed): - # TODO(apaszke): Remove once a month has passed, along with the - # _convert_helper float -> signed conversion above. - if not ctx.forward_compatible or both_32bit: - return arith.fptosi(out_type, x) - elif _from(signed) and _to(floating) and both_32bit: - return arith.sitofp(out_type, x) - elif old_dtype == jnp.bool_ and _to(integer) and new_dtype.itemsize == 4: + return arith.fptosi(out_type, x) + elif _from(signed) and _to(floating): + if ( + not (ctx.forward_compatible or ctx.is_cloud_tpu_older_than(2025, 5, 12)) + or both_32bit + ): + return arith.sitofp(out_type, x) + elif old_dtype == jnp.bool_ and _to(integer) and new_bitwidth == 32: return arith.extui(out_type, x) return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype), multiple_results=False)(ctx, x) -lowering_rules[lax.convert_element_type_p] = _convert_element_type_lowering_rule - - +@register_lowering_rule(lax.reshape_p, kernel_types=[*tpu_core.KernelType]) def _reshape_lowering_rule(ctx: LoweringRuleContext, x, new_sizes, dimensions, sharding): if dimensions is not None: @@ -2177,6 +2379,8 @@ def _reshape_lowering_rule(ctx: LoweringRuleContext, x, new_sizes, dimensions, ), x, ) + if not ctx.avals_out[0].shape: + return vector.extract(x, [], [0] * len(ctx.avals_in[0].shape)) return vector.shape_cast( aval_to_ir_type( ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_out[0] @@ -2185,9 +2389,7 @@ def _reshape_lowering_rule(ctx: LoweringRuleContext, x, new_sizes, dimensions, ) -lowering_rules[lax.reshape_p] = _reshape_lowering_rule - - +@register_lowering_rule(lax.squeeze_p, kernel_types=[*tpu_core.KernelType]) def _squeeze_lowering_rule(ctx: LoweringRuleContext, x, dimensions): del dimensions # Unused. (aval_in,) = ctx.avals_in @@ -2208,19 +2410,13 @@ def _squeeze_lowering_rule(ctx: LoweringRuleContext, x, dimensions): ) -lowering_rules[lax.squeeze_p] = _squeeze_lowering_rule - - +@register_lowering_rule(lax.concatenate_p, kernel_types=[*tpu_core.KernelType]) def _concatenate_lowering_rule(ctx: LoweringRuleContext, *xs, dimension): - out_type = aval_to_ir_type( - ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_out[0] - ) - return tpu.concatenate(out_type, xs, dimension=dimension) - - -lowering_rules[lax.concatenate_p] = _concatenate_lowering_rule + del ctx # Unused. + return tpu.concatenate(xs, dimension=dimension) +@register_lowering_rule(lax.split_p, kernel_types=[*tpu_core.KernelType]) def _split_lowering_rule( ctx: LoweringRuleContext, x, *, sizes, axis ): @@ -2245,20 +2441,27 @@ def _split_lowering_rule( starts[axis] += size return outs -lowering_rules[lax.split_p] = _split_lowering_rule - +@register_lowering_rule(lax.iota_p) def _iota_lowering_rule(ctx: LoweringRuleContext, dtype, shape, dimension, sharding): + if len(shape) == 1: + if dimension != 0: + raise ValueError("Dimension must be 0 for 1D iota.") + def _1d_iota_helper(): + iota_2d = lax.iota_p.bind(dtype=dtype, + shape=(1,) + shape, + dimension=1, + sharding=sharding) + return iota_2d[0] + return lower_fun(_1d_iota_helper, multiple_results=False)(ctx) out_type = aval_to_ir_type( ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_out[0] ) - return tpu.iota(out_type, dimension=dimension) - - -lowering_rules[lax.iota_p] = _iota_lowering_rule + return tpu.iota(out_type, dimensions=[dimension]) +@register_lowering_rule(lax.gather_p) def _gather_lowering_rule( ctx: LoweringRuleContext, x, @@ -2277,8 +2480,6 @@ def _gather_lowering_rule( if len(in_aval.shape) != 2: raise NotImplementedError("Only 2D gather is supported") - if pallas_utils.dtype_bitwidth(in_aval.dtype) != 32: - raise NotImplementedError("Only 32-bit gather is supported") if in_aval.shape != indices_aval.shape[:-1] != out_aval.shape: raise ValueError("Shape mismatch in input, indices and output") @@ -2287,18 +2488,17 @@ def _gather_lowering_rule( ) # During lowering jnp.take_along_axis to lax.gather, we append extra dimension # to the end of the indices array. We should reshape it back to the original - # shape before lowering to Mosaic and rely on MLIR CSE to remove the reshapes. + # shape before lowering to Mosaic and rely on MLIR canonicalization to remove + # the reshapes. assert indices_aval.shape == in_aval.shape + (1,) recovered_indices = vector.shape_cast( - ir.VectorType.get(in_aval.shape, ir.IntegerType.get_signless(32)), + ir.VectorType.get(in_aval.shape, indices.type.element_type), indices, ) # Note: current support for lax.gather is still very limited. del fill_value if ( slice_sizes == (1, 1) - and not unique_indices - and not indices_are_sorted and mode in ( lax.GatherScatterMode.FILL_OR_DROP, @@ -2312,7 +2512,7 @@ def _gather_lowering_rule( operand_batching_dims=(1,), start_indices_batching_dims=(1,), ): - return tpu.dynamic_gather(out_type, x, recovered_indices, 0) + return tpu.dynamic_gather(x, recovered_indices, [0]) if dimension_numbers == lax.GatherDimensionNumbers( offset_dims=(), collapsed_slice_dims=(1,), @@ -2320,41 +2520,43 @@ def _gather_lowering_rule( operand_batching_dims=(0,), start_indices_batching_dims=(0,), ): - return tpu.dynamic_gather(out_type, x, recovered_indices, 1) + return tpu.dynamic_gather(x, recovered_indices, [1]) raise NotImplementedError("Unsupported gather") -lowering_rules[lax.gather_p] = _gather_lowering_rule - - +@register_lowering_rule(lax.transpose_p) def _transpose_lowering_rule(ctx: LoweringRuleContext, x, *, permutation): - if permutation != (1, 0): - raise NotImplementedError out_type = aval_to_ir_type( ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_out[0] ) - return vector.transpose(out_type, x, permutation) - - -lowering_rules[lax.transpose_p] = _transpose_lowering_rule + if ctx.forward_compatible or ctx.is_cloud_tpu_older_than(2025, 5, 8): + return vector.transpose(out_type, x, permutation) + else: + return tpu.transpose(out_type, x, permutation) -def _bcast(x, y, x_aval, y_aval, out_aval): +def _bcast( + x: ir.Value | object, + y: ir.Value | object, + x_aval: ShapedAbstractValue, + y_aval: ShapedAbstractValue, + out_aval: ShapedAbstractValue, +) -> tuple[ir.Value, ir.Value]: x_dtype = x_aval.dtype y_dtype = y_aval.dtype if y_aval.weak_type: y_dtype = x_aval.dtype elif x_aval.weak_type: x_dtype = y_aval.dtype - if isinstance(x, (np.ndarray, np.number, int, float)): - if getattr(y, "type", None) == ir.IndexType.get(): - mlir_type = y.type + if not isinstance(x, ir.Value): + if (y_type := getattr(y, "type", None)) == ir.IndexType.get(): + mlir_type = y_type else: mlir_type = _dtype_to_ir_type(x_dtype) x = ir_constant(x, mlir_type) - if isinstance(y, (np.ndarray, np.number, int, float)): - if getattr(x, "type", None) == ir.IndexType.get(): - mlir_type = x.type + if not isinstance(y, ir.Value): + if (x_type := getattr(x, "type", None)) == ir.IndexType.get(): + mlir_type = x_type else: mlir_type = _dtype_to_ir_type(y_dtype) y = ir_constant(y, mlir_type) @@ -2368,6 +2570,10 @@ def _bcast(x, y, x_aval, y_aval, out_aval): return x, y +@register_lowering_rule( + lax.add_p, kernel_types=[*tpu_core.KernelType], ensure_mlir_values=False +) +@register_lowering_rule(ad_util.add_any_p, ensure_mlir_values=False) def _add_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0]) (aval_out,) = ctx.avals_out @@ -2378,12 +2584,6 @@ def _add_lowering_rule(ctx: LoweringRuleContext, x, y): raise NotImplementedError(aval_out.dtype) -lowering_rules[lax.add_p] = _add_lowering_rule -skip_mlir_conversions.add(lax.add_p) -lowering_rules[ad_util.add_any_p] = _add_lowering_rule -skip_mlir_conversions.add(ad_util.add_any_p) - - class FoldingError(Exception): pass @@ -2391,22 +2591,22 @@ class FoldingError(Exception): def _fold_and_get_constant_value(x): def _fold(x, fuel): if fuel <= 0: - raise FoldingError("Folding depth exceeded") + raise FoldingError() op_name = getattr(x.owner, "name", None) binop_folds = { "arith.maxsi": max, "arith.minsi": min, } if op_name == "arith.constant": - if ir.IntegerType.isinstance(x.type): + if isinstance(x.type, ir.IntegerType): return ir.IntegerAttr(x.owner.attributes["value"]).value - elif ir.FloatType.isinstance(x.type): + elif isinstance(x.type, ir.FloatType): return ir.FloatAttr(x.owner.attributes["value"]).value else: raise ValueError(f"Unsupported constant type: {x.type}") if op_name in binop_folds: return binop_folds[op_name](_fold(v, fuel - 1) for v in x.owner.operands) - raise FoldingError(f"Folding not supported for {x.owner}") + raise FoldingError() try: return _fold(x, 10) @@ -2414,6 +2614,13 @@ def _fold(x, fuel): return None +@register_lowering_rule(lax.stop_gradient_p) +def _stop_gradient_lowering_rule(_: LoweringRuleContext, x): + return x + +@register_lowering_rule( + lax.max_p, ensure_mlir_values=False, kernel_types=[*tpu_core.KernelType] +) def _max_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0]) (aval_out,) = ctx.avals_out @@ -2426,10 +2633,9 @@ def _max_lowering_rule(ctx: LoweringRuleContext, x, y): raise NotImplementedError(aval_out.dtype) -lowering_rules[lax.max_p] = _max_lowering_rule -skip_mlir_conversions.add(lax.max_p) - - +@register_lowering_rule( + lax.min_p, ensure_mlir_values=False, kernel_types=[*tpu_core.KernelType] +) def _min_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0]) (aval_out,) = ctx.avals_out @@ -2441,11 +2647,54 @@ def _min_lowering_rule(ctx: LoweringRuleContext, x, y): return arith.minimumf(x, y) raise NotImplementedError(aval_out.dtype) +def _reduce_index_helper( + ctx: LoweringRuleContext, x, axes, index_dtype, reduction_kind): + (x_aval,) = ctx.avals_in + (out_aval,) = ctx.avals_out + if x_aval.dtype != jnp.float32: + raise NotImplementedError("Only float32 is supported") + if len(axes) != 1: + raise NotImplementedError("Only single axis reduction supported") + if index_dtype != jnp.int32: + raise NotImplementedError("Only index_dtype=int32 is supported") + + axis = axes[0] + # TODO(b/460843515): Support 1D inputs in Mosaic. + is_1d = len(x_aval.shape) == 1 + if is_1d: + x_2d_aval = jax_core.ShapedArray((1, *x_aval.shape), x_aval.dtype) + x_2d_type = aval_to_ir_type( + ctx.lowering_context.dynamic_shape_replacement_fn, x_2d_aval + ) + out_aval = jax_core.ShapedArray((1, *out_aval.shape), out_aval.dtype) + x = vector.shape_cast(x_2d_type, x) + axis += 1 + + out_type = aval_to_ir_type( + ctx.lowering_context.dynamic_shape_replacement_fn, out_aval + ) + result = tpu.reduce_index(out_type, x, axis, reduction_kind) + if is_1d: + return vector.extract(result, [], [0]) + return result -lowering_rules[lax.min_p] = _min_lowering_rule -skip_mlir_conversions.add(lax.min_p) +@register_lowering_rule(lax.argmax_p, ensure_mlir_values=False) +def _argmax_lowering_rule(ctx: LoweringRuleContext, x, axes, index_dtype): + return _reduce_index_helper( + ctx, x, axes, index_dtype, + ir.Attribute.parse("#tpu.reduction_kind") + ) +@register_lowering_rule(lax.argmin_p, ensure_mlir_values=False) +def _argmin_lowering_rule(ctx: LoweringRuleContext, x, axes, index_dtype): + return _reduce_index_helper( + ctx, x, axes, index_dtype, + ir.Attribute.parse("#tpu.reduction_kind") + ) +@register_lowering_rule( + lax.sub_p, kernel_types=[*tpu_core.KernelType], ensure_mlir_values=False +) def _sub_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0]) (aval_out,) = ctx.avals_out @@ -2456,10 +2705,9 @@ def _sub_lowering_rule(ctx: LoweringRuleContext, x, y): raise NotImplementedError(aval_out.dtype) -lowering_rules[lax.sub_p] = _sub_lowering_rule -skip_mlir_conversions.add(lax.sub_p) - - +@register_lowering_rule( + lax.mul_p, kernel_types=[*tpu_core.KernelType], ensure_mlir_values=False +) def _mul_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0]) (aval_out,) = ctx.avals_out @@ -2470,10 +2718,9 @@ def _mul_lowering_rule(ctx: LoweringRuleContext, x, y): raise NotImplementedError(aval_out.dtype) -lowering_rules[lax.mul_p] = _mul_lowering_rule -skip_mlir_conversions.add(lax.mul_p) - - +@register_lowering_rule( + lax.div_p, kernel_types=[*tpu_core.KernelType], ensure_mlir_values=False +) def _div_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0]) (aval_out,) = ctx.avals_out @@ -2486,10 +2733,9 @@ def _div_lowering_rule(ctx: LoweringRuleContext, x, y): raise NotImplementedError(aval_out.dtype) -lowering_rules[lax.div_p] = _div_lowering_rule -skip_mlir_conversions.add(lax.div_p) - - +@register_lowering_rule( + lax.rem_p, kernel_types=[*tpu_core.KernelType], ensure_mlir_values=False +) def _rem_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0]) (aval_out,) = ctx.avals_out @@ -2502,10 +2748,7 @@ def _rem_lowering_rule(ctx: LoweringRuleContext, x, y): raise NotImplementedError(aval_out.dtype) -lowering_rules[lax.rem_p] = _rem_lowering_rule -skip_mlir_conversions.add(lax.rem_p) - - +@register_lowering_rule(lax.abs_p, kernel_types=[*tpu_core.KernelType]) def _abs_lowering_rule(ctx: LoweringRuleContext, x): (aval_out,) = ctx.avals_out if jnp.issubdtype(aval_out.dtype, jnp.integer): @@ -2515,9 +2758,9 @@ def _abs_lowering_rule(ctx: LoweringRuleContext, x): raise NotImplementedError(aval_out.dtype) -lowering_rules[lax.abs_p] = _abs_lowering_rule - - +@register_lowering_rule( + lax.neg_p, kernel_types=[*tpu_core.KernelType], ensure_mlir_values=False +) def _neg_lowering_rule(ctx: LoweringRuleContext, x): (x_aval,) = ctx.avals_in new_ctx = ctx.replace( @@ -2527,58 +2770,49 @@ def _neg_lowering_rule(ctx: LoweringRuleContext, x): return _sub_lowering_rule(new_ctx, np.array(0, dtype=x_aval.dtype), x) -lowering_rules[lax.neg_p] = _neg_lowering_rule -skip_mlir_conversions.add(lax.neg_p) - - +@register_lowering_rule(lax.sign_p, kernel_types=[*tpu_core.KernelType]) def _sign_lowering_rule(ctx: LoweringRuleContext, x): return lower_fun( pallas_utils.sign_lowering_helper, multiple_results=False, )(ctx, x) -lowering_rules[lax.sign_p] = _sign_lowering_rule - - +@register_lowering_rule(lax.nextafter_p) def _nextafter_lowering_rule(ctx: LoweringRuleContext, x, y): return lower_fun( pallas_utils.nextafter_lowering_helper, multiple_results=False, )(ctx, x, y) -lowering_rules[lax.nextafter_p] = _nextafter_lowering_rule - - -def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x): +@register_lowering_rule(lax.rsqrt_p) +def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.rsqrt(x) -lowering_rules[lax.rsqrt_p] = _rsqrt_lowering_rule - - -def _sqrt_lowering_rule(ctx: LoweringRuleContext, x): +@register_lowering_rule(lax.sqrt_p) +def _sqrt_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.sqrt(x) -lowering_rules[lax.sqrt_p] = _sqrt_lowering_rule - - +@register_lowering_rule(lax.square_p) def _square_lowering_rule(ctx: LoweringRuleContext, x): if jnp.issubdtype(ctx.avals_in[0].dtype, jnp.integer): return arith.muli(x, x) return arith.mulf(x, x) -lowering_rules[lax.square_p] = _square_lowering_rule - - -def _exp_lowering_rule(ctx: LoweringRuleContext, x): +@register_lowering_rule(lax.exp_p, kernel_types=[*tpu_core.KernelType]) +def _exp_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.exp(x) -lowering_rules[lax.exp_p] = _exp_lowering_rule - - +@register_lowering_rule(lax.pow_p, ensure_mlir_values=False) def _pow_lowering_rule(ctx: LoweringRuleContext, x, y): # jax accepts float base (x) and integer/float exponent (y), and integer # exponent is casted to float. @@ -2593,39 +2827,37 @@ def _pow_lowering_rule(ctx: LoweringRuleContext, x, y): return math.powf(x, y) -lowering_rules[lax.pow_p] = _pow_lowering_rule -skip_mlir_conversions.add(lax.pow_p) - - +@register_lowering_rule(lax.integer_pow_p) def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, *, y): return lower_fun(lax_internal._integer_pow, multiple_results=False)( ctx, x, y=y) -lowering_rules[lax.integer_pow_p] = _integer_pow_lowering_rule - - -def _exp2_lowering_rule(ctx: LoweringRuleContext, x): - # exp2 in JAX lowers to exp(ln2 * x), not to pow2. We match that behavior - # here. - return lower_fun( - lambda x: jnp.exp(jnp.astype(np.log(2), x.dtype) * x), - multiple_results=False, - )(ctx, x) - - -lowering_rules[lax.exp2_p] = _exp2_lowering_rule -skip_mlir_conversions.add(lax.exp2_p) +@register_lowering_rule(lax.exp2_p, ensure_mlir_values=False) +def _exp2_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") + if ctx.forward_compatible or ctx.is_cloud_tpu_older_than(2025, 7, 26): + # exp2 in JAX lowers to exp(ln2 * x), not to pow2. We match that behavior + # here. + return lower_fun( + lambda x: jnp.exp(jnp.astype(np.log(2), x.dtype) * x), + multiple_results=False, + )(ctx, x) + return math.exp2(x) -def _logistic_lowering_rule(ctx: LoweringRuleContext, x): +@register_lowering_rule(lax.logistic_p) +def _logistic_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") neg_x = arith.negf(x) exp_neg_x = math.exp(neg_x) aval_out = ctx.avals_out[0] out_type = aval_to_ir_type( ctx.lowering_context.dynamic_shape_replacement_fn, aval_out ) - if aval_out.shape == (): + if not aval_out.shape: one = ir_constant(1.0, mlir_type=out_type) else: one = vector.broadcast(out_type, ir_constant(1.0)) @@ -2633,51 +2865,49 @@ def _logistic_lowering_rule(ctx: LoweringRuleContext, x): return arith.divf(one, denom) -lowering_rules[lax.logistic_p] = _logistic_lowering_rule - - -def _sin_lowering_rule(ctx: LoweringRuleContext, x): +@register_lowering_rule(lax.sin_p) +def _sin_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.sin(x) -lowering_rules[lax.sin_p] = _sin_lowering_rule - - -def _cos_lowering_rule(ctx: LoweringRuleContext, x): +@register_lowering_rule(lax.cos_p) +def _cos_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.cos(x) -lowering_rules[lax.cos_p] = _cos_lowering_rule - - -def _tan_lowering_rule(ctx: LoweringRuleContext, x): +@register_lowering_rule(lax.tan_p) +def _tan_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.tan(x) -lowering_rules[lax.tan_p] = _tan_lowering_rule - - -def _tanh_lowering_rule(ctx: LoweringRuleContext, x): +@register_lowering_rule(lax.tanh_p) +def _tanh_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.tanh(x) -lowering_rules[lax.tanh_p] = _tanh_lowering_rule - - -def _log_lowering_rule(ctx: LoweringRuleContext, x): +@register_lowering_rule(lax.log_p) +def _log_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.log(x) -lowering_rules[lax.log_p] = _log_lowering_rule - - -def _log1p_lowering_rule(ctx: LoweringRuleContext, x): +@register_lowering_rule(lax.log1p_p) +def _log1p_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.log1p(x) -lowering_rules[lax.log1p_p] = _log1p_lowering_rule - - +@register_lowering_rule(lax.round_p) def _round_lowering_rule(ctx: LoweringRuleContext, x, *, rounding_method): if rounding_method == 0: return math.round(x) @@ -2687,37 +2917,28 @@ def _round_lowering_rule(ctx: LoweringRuleContext, x, *, rounding_method): raise NotImplementedError(f"Unsupported rounding method: {rounding_method}") -lowering_rules[lax.round_p] = _round_lowering_rule - - +@register_lowering_rule(lax.ceil_p) def _ceil_lowering_rule(ctx: LoweringRuleContext, x): return math.ceil(x) -lowering_rules[lax.ceil_p] = _ceil_lowering_rule - - +@register_lowering_rule(lax.floor_p) def _floor_lowering_rule(ctx: LoweringRuleContext, x): return math.floor(x) -lowering_rules[lax.floor_p] = _floor_lowering_rule - - +@register_lowering_rule(lax.clz_p) def _clz_lowering_rule(ctx: LoweringRuleContext, x): return math.ctlz(x) -lowering_rules[lax.clz_p] = _clz_lowering_rule - +@register_lowering_rule(lax.population_count_p) def _population_count_lowering_rule(ctx: LoweringRuleContext, x): aval_out = ctx.avals_out[0] - if aval_out.shape == (): + if not aval_out.shape: raise ValueError("Population count is not supported on scalars") return math.ctpop(x) -lowering_rules[lax.population_count_p] = _population_count_lowering_rule - # Mapping for signed integer comparisons. _cmpsi_lowering_types = { @@ -2823,23 +3044,21 @@ def _cmp_lowering_rule(primitive, ctx: LoweringRuleContext, x, y): raise NotImplementedError(f"Unsupported dtype in cmp: {dtype}") -lowering_rules[lax.eq_p] = functools.partial(_cmp_lowering_rule, lax.eq_p) -lowering_rules[lax.ne_p] = functools.partial(_cmp_lowering_rule, lax.ne_p) -lowering_rules[lax.lt_p] = functools.partial(_cmp_lowering_rule, lax.lt_p) -lowering_rules[lax.le_p] = functools.partial(_cmp_lowering_rule, lax.le_p) -lowering_rules[lax.gt_p] = functools.partial(_cmp_lowering_rule, lax.gt_p) -lowering_rules[lax.ge_p] = functools.partial(_cmp_lowering_rule, lax.ge_p) +for prim in [lax.eq_p, lax.ne_p, lax.lt_p, lax.le_p, lax.gt_p, lax.ge_p]: + register_lowering_rule(prim, kernel_types=[*tpu_core.KernelType])( + functools.partial(_cmp_lowering_rule, prim) + ) +@register_lowering_rule( + lax.and_p, kernel_types=[*tpu_core.KernelType], ensure_mlir_values=False +) def _and_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out) return arith.andi(x, y) -lowering_rules[lax.and_p] = _and_lowering_rule -skip_mlir_conversions.add(lax.and_p) - - +@register_lowering_rule(lax.is_finite_p) def _is_finite_lowering_rule(ctx: LoweringRuleContext, x): out_aval, = ctx.avals_out out_type = aval_to_ir_type( @@ -2848,18 +3067,15 @@ def _is_finite_lowering_rule(ctx: LoweringRuleContext, x): return _not_lowering_rule(ctx, tpu.weird(out_type, x)) -lowering_rules[lax.is_finite_p] = _is_finite_lowering_rule - - +@register_lowering_rule( + lax.or_p, kernel_types=[*tpu_core.KernelType], ensure_mlir_values=False +) def _or_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out) return arith.ori(x, y) -lowering_rules[lax.or_p] = _or_lowering_rule -skip_mlir_conversions.add(lax.or_p) - - +@register_lowering_rule(lax.not_p, kernel_types=[*tpu_core.KernelType]) def _not_lowering_rule(ctx: LoweringRuleContext, x): # The primitive not_p is lowered to # https://github.com/openxla/stablehlo/blob/main/docs/spec.md#not @@ -2878,14 +3094,13 @@ def _not_lowering_rule(ctx: LoweringRuleContext, x): ctx.lowering_context.dynamic_shape_replacement_fn, out_aval ) scalar_minus_one = ir.IntegerAttr.get(out_scalar_type, -1) - minus_one = arith.ConstantOp( + minus_one = arith.constant( out_type, ir.DenseElementsAttr.get_splat(out_type, scalar_minus_one) ) return arith.xori(x, minus_one) -lowering_rules[lax.not_p] = _not_lowering_rule - +@register_lowering_rule(lax.select_n_p, kernel_types=[*tpu_core.KernelType]) def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, x, *args): if len(args) > 1: raise NotImplementedError("select_n only supported with <= 2 arguments") @@ -2905,56 +3120,17 @@ def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, x, *args): return arith.select(pred, y, x) -lowering_rules[lax.select_n_p] = _select_n_lowering_rule - - def _clamp(min, operand, max): res = jnp.maximum(operand, min) return jnp.minimum(res, max) +@register_lowering_rule(lax.clamp_p) def _clamp_lowering_rule(ctx: LoweringRuleContext, min, operand, max): """Compute minimum_p(maximum_p(min, operand), max).""" return lower_fun(_clamp, multiple_results=False)(ctx, min, operand, max) -lowering_rules[lax.clamp_p] = _clamp_lowering_rule - - -def _for_lowering_rule( - ctx: LoweringRuleContext, - *args, - jaxpr, - nsteps, - reverse, - unroll, - which_linear, -): - should_discharge = [ - not isinstance(aval, state.AbstractRef) for aval in ctx.avals_in - ] - jaxpr, () = state_discharge.discharge_state( - jaxpr, (), should_discharge=[False, *should_discharge] - ) - for i in range(nsteps): - if reverse: - i = nsteps - i - 1 - i = ir_constant(i) - lowering_context = ctx.lowering_context.replace( - block_shapes=[(), *ctx.block_shapes], - ) - non_ref_args = jaxpr_subcomp(lowering_context, jaxpr, i, *args) - non_ref_args_iter = iter(non_ref_args) - args = [ - next(non_ref_args_iter) if s else a - for a, s in zip(args, should_discharge) - ] - return args - - -lowering_rules[for_loop.for_p] = _for_lowering_rule - - def _lower_jaxpr_to_for_loop(ctx: LoweringRuleContext, jaxpr: jax_core.Jaxpr, start: int | ir.Value, num_steps: int | ir.Value, consts, *args, @@ -2968,8 +3144,10 @@ def _run_body(i, args): else: del i lowering_context = ctx.lowering_context.replace( - block_shapes=ctx.block_shapes[:len(consts)] - + ctx.block_shapes[len(consts) + 1:], + block_shapes=( + *ctx.block_shapes[: len(consts)], + *ctx.block_shapes[len(consts) + 1 :], + ), ) args = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args) return args @@ -2982,8 +3160,7 @@ def _run_body(i, args): # No need for an scf.For. We can just unroll completely for i in range(start, start + num_steps): args = _run_body( - ir_constant(i, mlir_type=_dtype_to_ir_type(jnp.dtype("int32"))), - args, + ir_constant(i, mlir_type=_dtype_to_ir_type(jnp.int32)), args ) return args if unroll != 1: @@ -2991,16 +3168,19 @@ def _run_body(i, args): f"Only unroll={num_steps=} and unroll=1 supported. Got {unroll=}.") lbd = _ensure_mlir_value(start, pallas_core.index_map_grid_aval) ubd = arith.addi(lbd, _ensure_mlir_value(num_steps, pallas_core.index_map_grid_aval)) - step = ir_constant(1, mlir_type=_dtype_to_ir_type(jnp.dtype("int32"))) + step = ir_constant(1, mlir_type=_dtype_to_ir_type(jnp.int32)) for_op = scf.ForOp(lbd, ubd, step, args) with ir.InsertionPoint(for_op.body): iv = for_op.induction_variable inner_args = for_op.inner_iter_args inner_out = _run_body(iv, inner_args) - scf.YieldOp(inner_out) + scf.yield_(inner_out) return for_op.results +@register_lowering_rule( + lax.scan_p, kernel_types=[*tpu_core.KernelType], ensure_mlir_values=False +) def _scan_lowering_rule( ctx: LoweringRuleContext, *args, @@ -3031,6 +3211,7 @@ def _scan_lowering_rule( consts_avals, args_avals = split_list(ctx.avals_in, [num_consts]) if has_loop_index: loop_index_start, *args = args + loop_index_start = loop_index_start args_avals = args_avals[1:] else: loop_index_start = 0 @@ -3041,12 +3222,8 @@ def _scan_lowering_rule( consts, *args, has_loop_index=has_loop_index, unroll=unroll) if has_loop_index: - out = [ir_constant(length, - mlir_type=_dtype_to_ir_type(jnp.dtype('int32'))), - *out] + out = [ir_constant(length, mlir_type=_dtype_to_ir_type(jnp.int32)), *out] return out -lowering_rules[lax.scan_p] = _scan_lowering_rule -skip_mlir_conversions.add(lax.scan_p) def _lower_while_via_fori( @@ -3062,8 +3239,10 @@ def _lower_while_via_fori( (lb, ub), args = carry[:2], carry[2:] for_out = _lower_jaxpr_to_for_loop( ctx.replace( - block_shapes=ctx.block_shapes[: body_nconsts + 1] - + ctx.block_shapes[body_nconsts + 2 :], + block_shapes=( + *ctx.block_shapes[: body_nconsts + 1], + *ctx.block_shapes[body_nconsts + 2 :], + ), ), fori_jaxpr, lb, @@ -3076,6 +3255,7 @@ def _lower_while_via_fori( return [ub, ub, *for_out] +@register_lowering_rule(lax.while_p, kernel_types=[*tpu_core.KernelType]) def _while_lowering_rule( ctx: LoweringRuleContext, *args, @@ -3136,9 +3316,8 @@ def _while_lowering_rule( return list(while_op.results) -lowering_rules[lax.while_p] = _while_lowering_rule - -def _cond_lowering_rule(ctx: LoweringRuleContext, *args, branches): +@register_lowering_rule(lax.cond_p, kernel_types=[*tpu_core.KernelType]) +def _cond_lowering_rule(ctx: LoweringRuleContext, *args, branches, **params): index, *args = args constant_index = _fold_and_get_constant_value(index) @@ -3153,7 +3332,7 @@ def _cond_lowering_rule(ctx: LoweringRuleContext, *args, branches): pred = arith.cmpi( arith.CmpIPredicate.ne, index, ir_constant(0, index.type) ) - if_op = scf.IfOp(pred, out_types, hasElse=True) + if_op = scf.IfOp(pred, out_types, has_else=True) lowering_context = ctx.lowering_context.replace( block_shapes=ctx.block_shapes[1:], ) @@ -3169,33 +3348,30 @@ def _cond_lowering_rule(ctx: LoweringRuleContext, *args, branches): ) else: out = jaxpr_subcomp(lowering_context, branches[1].jaxpr, *args) - scf.YieldOp(out) + scf.yield_(out) with ir.InsertionPoint(if_op.else_block): out = jaxpr_subcomp(lowering_context, branches[0].jaxpr, *args) - scf.YieldOp(out) + scf.yield_(out) return if_op.results -lowering_rules[lax.cond_p] = _cond_lowering_rule - - +@register_lowering_rule(pjit.jit_p, kernel_types=[*tpu_core.KernelType]) def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_): lowering_context = ctx.lowering_context.replace(block_shapes=ctx.block_shapes) return jaxpr_subcomp(lowering_context, jaxpr.jaxpr, *args) -lowering_rules[pjit.pjit_p] = _pjit_lowering_rule - - -def _mesh_cast_lowering_rule(ctx, x, dst_sharding): +@register_lowering_rule(pjit.reshard_p) +def _reshard_lowering_rule(ctx: LoweringRuleContext, x, *, dst_sharding, + concrete_mesh): return x -lowering_rules[pjit.mesh_cast_p] = _mesh_cast_lowering_rule +@register_lowering_rule(custom_derivatives.custom_jvp_call_p) def _custom_jvp_call_lowering_rule( ctx: LoweringRuleContext, *args, - call_jaxpr: jax_core.Jaxpr, + call_jaxpr: jax_core.ClosedJaxpr, jvp_jaxpr_fun: lu.WrappedFun, num_consts: int, symbolic_zeros: bool, @@ -3208,50 +3384,65 @@ def _custom_jvp_call_lowering_rule( return jaxpr_subcomp(lowering_context, call_jaxpr.jaxpr, *args) -lowering_rules[custom_derivatives.custom_jvp_call_p] = ( - _custom_jvp_call_lowering_rule) +@register_lowering_rule(custom_derivatives.custom_vjp_call_p) +def _custom_vjp_call_lowering_rule( + ctx: LoweringRuleContext, + *args, + call_jaxpr, + fwd_jaxpr_thunk, + out_trees, + symbolic_zeros, + bwd, + num_consts, +): + if num_consts: raise NotImplementedError + lowering_context = ctx.lowering_context.replace(block_shapes=ctx.block_shapes) + return jaxpr_subcomp(lowering_context, call_jaxpr.jaxpr, *args) +@register_lowering_rule(debugging.debug_callback_p) def _debug_callback_lowering_rule(ctx: LoweringRuleContext, *args, **kwargs): del ctx, args, kwargs # No-op debug callbacks in Mosaic for now return [] -lowering_rules[debugging.debug_callback_p] = _debug_callback_lowering_rule - - +@register_lowering_rule( + primitives.program_id_p, kernel_types=[*tpu_core.KernelType] +) def _program_id_lowering_rule(ctx: LoweringRuleContext, *, axis: int): - if ctx.lowering_context.user_grid_indices is None: raise ValueError( f"program id: {axis} was passed, but user did not provide a grid." ) length = len(ctx.lowering_context.user_grid_indices) - if not (0 <= axis < length): + if axis not in range(length): raise ValueError( f"user passed in program id with axis: {axis}, but grid only has" f" length: {length}" ) return ctx.lowering_context.user_grid_indices[axis] -lowering_rules[primitives.program_id_p] = _program_id_lowering_rule + +@register_lowering_rule( + primitives.num_programs_p, kernel_types=[*tpu_core.KernelType] +) def _num_programs_lowering_rule(ctx: LoweringRuleContext, *, axis: int): - mapped_axes = set(ctx.lowering_context.mapped_dims) + vmapped_axes = set(ctx.lowering_context.vmapped_dims) seen_user_axes = 0 for i in range(ctx.lowering_context.grid_rank): - seen_user_axes += int(i not in mapped_axes) + seen_user_axes += int(i not in vmapped_axes) if seen_user_axes == axis + 1: break else: raise ValueError( f"user passed in program id with axis: {axis}, but grid only has" - f" length: {len(ctx.lowering_context.grid_rank)}" + f" length: {ctx.lowering_context.grid_rank}" ) return tpu.iteration_bound(i) -lowering_rules[primitives.num_programs_p] = _num_programs_lowering_rule +@register_lowering_rule(tpu_primitives.repeat_p) def _repeat_lowering_rule(ctx: LoweringRuleContext, x, *, repeats, axis): (out_aval,) = ctx.avals_out return tpu.repeat( @@ -3264,9 +3455,25 @@ def _repeat_lowering_rule(ctx: LoweringRuleContext, x, *, repeats, axis): ) -lowering_rules[tpu_primitives.repeat_p] = _repeat_lowering_rule +@register_lowering_rule(lax.tile_p) +def _tile_lowering_rule(ctx: LoweringRuleContext, x, *, reps): + (x_aval,) = ctx.avals_in + newshape = list(x_aval.shape) + for axis, repeats in enumerate(reps): + newshape[axis] *= repeats + x = tpu.repeat( + aval_to_ir_type( + ctx.lowering_context.dynamic_shape_replacement_fn, + x_aval.update(shape=tuple(newshape)) + ), + x, + axis, + repeats, + ) + return x +@register_lowering_rule(tpu_primitives.roll_p) def _roll_lowering_rule( ctx: LoweringRuleContext, x, shift, *, axis, stride, stride_axis ): @@ -3283,9 +3490,7 @@ def _roll_lowering_rule( ) -lowering_rules[tpu_primitives.roll_p] = _roll_lowering_rule - - +@register_lowering_rule(lax.slice_p, kernel_types=[*tpu_core.KernelType]) def _slice_lowering_rule( ctx: LoweringRuleContext, x, limit_indices, start_indices, strides ): @@ -3302,62 +3507,119 @@ def _slice_lowering_rule( ) -lowering_rules[lax.slice_p] = _slice_lowering_rule - - +@register_lowering_rule( + lax.xor_p, kernel_types=[*tpu_core.KernelType], ensure_mlir_values=False +) def _xor_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out) return arith.xori(x, y) -lowering_rules[lax.xor_p] = _xor_lowering_rule -skip_mlir_conversions.add(lax.xor_p) - - +@register_lowering_rule( + lax.shift_left_p, + kernel_types=[*tpu_core.KernelType], + ensure_mlir_values=False, +) def _shift_left_lowering_rule(ctx: LoweringRuleContext, x, d): x, d = _bcast(x, d, *ctx.avals_in, *ctx.avals_out) return arith.shli(x, d) -lowering_rules[lax.shift_left_p] = _shift_left_lowering_rule -skip_mlir_conversions.add(lax.shift_left_p) - - +@register_lowering_rule( + lax.shift_right_arithmetic_p, + kernel_types=[*tpu_core.KernelType], + ensure_mlir_values=False, +) def _shift_right_arithmetic_lowering_rule(ctx: LoweringRuleContext, x, d): x, d = _bcast(x, d, *ctx.avals_in, *ctx.avals_out) return arith.shrsi(x, d) -lowering_rules[lax.shift_right_arithmetic_p] = _shift_right_arithmetic_lowering_rule -skip_mlir_conversions.add(lax.shift_right_arithmetic_p) - - -def _shift_right_logical_lowering_rules(ctx: LoweringRuleContext, x, d): +@register_lowering_rule( + lax.shift_right_logical_p, + kernel_types=[*tpu_core.KernelType], + ensure_mlir_values=False, +) +def _shift_right_logical_lowering_rule(ctx: LoweringRuleContext, x, d): x, d = _bcast(x, d, *ctx.avals_in, *ctx.avals_out) return arith.shrui(x, d) -lowering_rules[lax.shift_right_logical_p] = _shift_right_logical_lowering_rules -skip_mlir_conversions.add(lax.shift_right_logical_p) - - +@register_lowering_rule(lax.erf_inv_p) def _erf_inv_lowering_rule(ctx: LoweringRuleContext, x): return lower_fun( pallas_utils.erf_inv_lowering_helper, multiple_results=False, )(ctx, x) -lowering_rules[lax.erf_inv_p] = _erf_inv_lowering_rule - - +@register_lowering_rule(primitives.reciprocal_p) def _reciprocal_lowering_rule(ctx: LoweringRuleContext, x, *, approx): if not isinstance(x.type.element_type, ir.F32Type): raise ValueError("Only float32 is supported.") return tpu.reciprocal(x, approx=approx) -lowering_rules[primitives.reciprocal_p] = _reciprocal_lowering_rule +@register_lowering_rule(tpu_primitives.stochastic_round_p) +def _stochastic_round_lowering_rule( + ctx: LoweringRuleContext, x, random_bits, *, target_dtype +): + if not isinstance(x.type.element_type, ir.F32Type): + raise ValueError("Only float32 input is supported.") + if target_dtype not in [ + jnp.bfloat16, + jnp.float8_e5m2, + jnp.float8_e4m3fn, + jnp.float8_e4m3b11fnuz, + ]: + raise ValueError( + "Only bfloat16, float8_e5m2, float8_e4m3fn, and float8_e4m3b11fnuz " + "are supported as target dtypes." + ) + (_, in_aval,) = ctx.avals_in + out_type = ir.VectorType.get( + in_aval.shape, mlir.dtype_to_ir_type(jnp.dtype(target_dtype)) + ) + return tpu.stochastic_convert(out_type, x, random_bits) + + +def _check_elementwise_unpack_dtypes(unpacked_dtype, packed_dtype): + if unpacked_dtype == jnp.float32 and packed_dtype == jnp.bfloat16: + return + if unpacked_dtype == jnp.int32 and packed_dtype in [ + jnp.int16, jnp.int8, jnp.int4 + ]: + return + raise ValueError( + f"Unsupported elementwise packing: {unpacked_dtype} -> {packed_dtype}. " + "Only f32 <-> bf16 and i32 <-> i16/i8/i4 are supported." + ) + + +@register_lowering_rule(tpu_primitives.pack_elementwise_p) +def _pack_elementwise_lowering_rule( + ctx: LoweringRuleContext, *xs, packed_dtype +): + in_aval = ctx.avals_in[0] + out_aval = ctx.avals_out[0] + packed_ir_type = _dtype_to_ir_type(packed_dtype) + out_type = ir.VectorType.get(in_aval.shape, _dtype_to_ir_type(out_aval.dtype)) + return tpu.pack_elementwise(out_type, xs, target_type=packed_ir_type) + + +@register_lowering_rule(tpu_primitives.unpack_elementwise_p) +def _unpack_elementwise_lowering_rule( + ctx: LoweringRuleContext, x, index, packed_dtype, unpacked_dtype +): + in_aval = ctx.avals_in[0] + _check_elementwise_unpack_dtypes(unpacked_dtype, packed_dtype) + out_type = ir.VectorType.get( + in_aval.shape, _dtype_to_ir_type(unpacked_dtype) + ) + return tpu.unpack_elementwise( + out_type, x, source_type=_dtype_to_ir_type(packed_dtype), index=index) + +@register_lowering_rule(tpu_primitives.bitcast_p) def _bitcast_lowering_rule(ctx: LoweringRuleContext, x, *, ty): del ty (out_aval,) = ctx.avals_out @@ -3368,14 +3630,17 @@ def _bitcast_lowering_rule(ctx: LoweringRuleContext, x, *, ty): x, ) -lowering_rules[tpu_primitives.bitcast_p] = _bitcast_lowering_rule +@register_lowering_rule( + lax.bitcast_convert_type_p, kernel_types=[*tpu_core.KernelType] +) def _bitcast_convert_type_lowering_rule( - ctx: LoweringRuleContext, x, *, new_dtype): + ctx: LoweringRuleContext, x, *, new_dtype +): (in_aval, ) = ctx.avals_in (out_aval,) = ctx.avals_out - old_bitwidth = pallas_utils.dtype_bitwidth(in_aval.dtype) - new_bitwidth = pallas_utils.dtype_bitwidth(new_dtype) + old_bitwidth = dtypes.itemsize_bits(in_aval.dtype) + new_bitwidth = dtypes.itemsize_bits(new_dtype) if old_bitwidth != new_bitwidth: raise NotImplementedError("Changing bitwidths not supported.") return tpu.bitcast( @@ -3384,15 +3649,13 @@ def _bitcast_convert_type_lowering_rule( ), x, ) -lowering_rules[lax.bitcast_convert_type_p] = _bitcast_convert_type_lowering_rule def _alloc_value( aval: jax_core.AbstractValue, *, ctx: LoweringRuleContext ) -> ir.Value: - if isinstance(aval, pallas_core.AbstractMemoryRef): - memspace = _memory_space_to_mosaic_attribute(aval.memory_space) - if jnp.issubdtype(aval.dtype, tpu_core.semaphore_dtype): + if isinstance(aval, state.AbstractRef): + if jnp.issubdtype(aval.dtype, pallas_core.semaphore_dtype): assert aval.memory_space == TPUMemorySpace.SEMAPHORE memref_type = aval_to_ir_type( ctx.lowering_context.dynamic_shape_replacement_fn, @@ -3401,11 +3664,14 @@ def _alloc_value( ) return tpu.sem_alloc(memref_type) else: - out_type = ir.MemRefType.get( - aval.shape, - _dtype_to_ir_type(aval.dtype, is_kernel_boundary=True), - memory_space=memspace) - return memref.alloca(out_type, [], []) + memref_type = aval_to_ir_type( + ctx.lowering_context.dynamic_shape_replacement_fn, + aval, + is_kernel_boundary=True, + memory_space=aval.memory_space, + ) + assert isinstance(memref_type, ir.MemRefType) + return memref.alloca(memref_type, [], []) elif isinstance(aval, tpu_core.AbstractSemaphore): memref_type = aval_to_ir_type( ctx.lowering_context.dynamic_shape_replacement_fn, @@ -3416,7 +3682,16 @@ def _alloc_value( raise NotImplementedError(f"Cannot allocate {type(aval)}.") -def _run_scoped_lowering_rule(ctx: LoweringRuleContext, *consts, jaxpr): +@register_lowering_rule(primitives.run_scoped_p) +def _run_scoped_lowering_rule( + ctx: LoweringRuleContext, + *consts, + jaxpr, + collective_axes, + alloc_fn=_alloc_value, +): + if collective_axes: + raise NotImplementedError("run_scoped lowering does not support collective axes") out_type = [ aval_to_ir_type(ctx.lowering_context.dynamic_shape_replacement_fn, aval) for aval in ctx.avals_out @@ -3426,61 +3701,76 @@ def _run_scoped_lowering_rule(ctx: LoweringRuleContext, *consts, jaxpr): with ctx.lowering_context.grid_name_context(): jaxpr = pe.convert_constvars_jaxpr(jaxpr) with ir.InsertionPoint(region.body): - alloc_fn = functools.partial(_alloc_value, ctx=ctx) - args = map(alloc_fn, in_avals) + args = map(lambda aval: alloc_fn(aval, ctx=ctx), in_avals) block_shapes = tuple(a.shape if isinstance(a, state.AbstractRef) else None for a in in_avals) + block_shapes = tuple(map(_maybe_physicalize_block_shape, + in_avals, block_shapes)) ctx = ctx.lowering_context.replace( block_shapes=(*ctx.block_shapes, *block_shapes) ) out = jaxpr_subcomp(ctx, jaxpr, *consts, *args) - tpu.YieldOp(out) + tpu.yield_(out) return region.results -lowering_rules[primitives.run_scoped_p] = _run_scoped_lowering_rule - def _device_id_to_logical( ctx: LoweringRuleContext, device_id, - device_id_type: tpu_primitives.DeviceIdType): - if device_id_type is tpu_primitives.DeviceIdType.MESH: - # Mesh means we are passed the mesh coordinates for the device - device_ids = tree_util.tree_leaves(device_id) - mesh_strides = ctx.lowering_context.mesh_context.mesh_strides - - i32 = ir.IntegerType.get_signless(32) - if len(device_ids) == 0: - return arith.constant(i32, 0) - return functools.reduce( - arith.addi, - ( - arith.muli(a, arith.constant(i32, b)) - for a, b in zip(device_ids, mesh_strides) - ), + device_id_type: primitives.DeviceIdType): + logical_device_id, non_mesh_axes = primitives.device_id_to_logical( + ctx.lowering_context.mesh_context, + device_id, + device_id_type, + lambda name: _axis_index_rule(ctx, axis_name=name), + ) + core_index = None + if grid_names := ctx.lowering_context.grid_names: + if len(grid_names) > 1: + raise NotImplementedError( + "Unable to determine core axis name if grid_names is more than 1." + ) + core_axis_name = grid_names[0] + core_index = non_mesh_axes.pop(core_axis_name, None) + if non_mesh_axes: + raise ValueError( + f"Unrecognized axes in device_id: {non_mesh_axes}" ) - elif device_id_type is tpu_primitives.DeviceIdType.LOGICAL: - return device_id - raise NotImplementedError(f"Unsupported device id type: {device_id_type}") + return logical_device_id, core_index +@register_lowering_rule( + primitives.semaphore_read_p, kernel_types=[*tpu_core.KernelType] +) def _semaphore_read_lowering_rule( ctx: LoweringRuleContext, *args, args_tree, ): - sem_aval, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in) + sem_aval, sem_transforms_avals = tree_util.tree_unflatten(args_tree, ctx.avals_in) + primitives.check_sem_avals( + sem_aval, + sem_transforms_avals, + "read", + allowed_semaphore_types={ + tpu_core.dma_semaphore, + pallas_core.semaphore, + pallas_core.barrier_semaphore, + pallas_core.SEMAPHORE_INTERPRET_DTYPE, + }, + ) sem, transforms = tree_util.tree_unflatten(args_tree, args) sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, transforms) return tpu.sem_read(sem) -lowering_rules[tpu_primitives.semaphore_read_p] = _semaphore_read_lowering_rule - +@register_lowering_rule( + primitives.semaphore_signal_p, kernel_types=[*tpu_core.KernelType] +) def _semaphore_signal_lowering_rule( ctx: LoweringRuleContext, *args, args_tree, - device_id_type: tpu_primitives.DeviceIdType, + device_id_type: primitives.DeviceIdType, ): sem_aval, _, _, _, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in) sem, transforms, value, device_id, core_index = tree_util.tree_unflatten( @@ -3488,25 +3778,41 @@ def _semaphore_signal_lowering_rule( ) sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, transforms) if device_id is not None: - device_id = _device_id_to_logical(ctx, device_id, device_id_type) + device_id, core_id = _device_id_to_logical(ctx, device_id, device_id_type) + if core_id is not None: + if core_index is not None: + raise ValueError( + "Cannot specify both `core_index` and the core axis in `device_id`." + ) + core_index = core_id tpu.sem_signal(sem, value, device_id=device_id, core_id=core_index) return [] -lowering_rules[tpu_primitives.semaphore_signal_p] = ( - _semaphore_signal_lowering_rule) - - +@register_lowering_rule( + primitives.semaphore_wait_p, kernel_types=[*tpu_core.KernelType] +) def _semaphore_wait_lowering_rule(ctx: LoweringRuleContext, *args, args_tree): - sem_aval, _, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in) - sem, transforms, value = tree_util.tree_unflatten(args_tree, args) + sem_aval, _, _, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in) + sem, transforms, value, decrement = tree_util.tree_unflatten(args_tree, args) + if not decrement: + raise NotImplementedError("Non-decrementing wait is not supported.") sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, transforms) tpu.sem_wait(sem, value) return [] -lowering_rules[tpu_primitives.semaphore_wait_p] = _semaphore_wait_lowering_rule -def _dma_start_lowering_rule(ctx: LoweringRuleContext, *args, tree, - device_id_type: tpu_primitives.DeviceIdType): + +@register_lowering_rule(tpu_primitives.dma_start_p) +def _dma_start_lowering_rule( + ctx: LoweringRuleContext, + *args, + tree, + device_id_type: primitives.DeviceIdType, + priority: int, + add: bool, +): + if add: + raise NotImplementedError("DMA with add=True is not supported.") ( src_ref, src_transforms, @@ -3536,21 +3842,35 @@ def _dma_start_lowering_rule(ctx: LoweringRuleContext, *args, tree, dst_ref, dst_ref_aval.dtype, dst_ref_block_shape, dst_transforms ) sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, sem_transforms) + core_id = None if device_id is not None: - device_id = _device_id_to_logical(ctx, device_id, device_id_type) - tpu.enqueue_dma(src_ref, dst_ref, sem, source_semaphore=src_sem, - device_id=device_id) - + device_id, core_id = _device_id_to_logical(ctx, device_id, device_id_type) + tpu.enqueue_dma( + src_ref, + dst_ref, + sem, + source_semaphore=src_sem, + device_id=device_id, + core_id=core_id, + priority=priority, + ) return [] -lowering_rules[tpu_primitives.dma_start_p] = _dma_start_lowering_rule +@register_lowering_rule(tpu_primitives.dma_wait_p) def _dma_wait_lowering_rule(ctx: LoweringRuleContext, *args, tree, - device_id_type: tpu_primitives.DeviceIdType): - del device_id_type - (src, src_transforms, dst, transforms, sem, sem_transforms, _, _, _) = ( - tree_util.tree_unflatten(tree, args) - ) + device_id_type: primitives.DeviceIdType): + ( + src, + src_transforms, + dst, + transforms, + sem, + sem_transforms, + _, + _, + device_id, + ) = tree_util.tree_unflatten(tree, args) (src_aval, _, dst_aval, _, sem_aval, _, _, _, _) = tree_util.tree_unflatten( tree, ctx.avals_in ) @@ -3559,27 +3879,19 @@ def _dma_wait_lowering_rule(ctx: LoweringRuleContext, *args, tree, src, _ = _transform_ref(src, src_aval.dtype, src_aval.shape, src_transforms) dst, _ = _transform_ref(dst, dst_aval.dtype, ref_block_shape, transforms) sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, sem_transforms) - if ctx.forward_compatible or is_cloud_tpu_older_than(2025, 2, 12): - # TODO(mvoz): Remove once six months have passed. b/395630795 - if hasattr(src_aval, "memory_space"): - src_memory_space = _memory_space_to_mosaic_attribute(src_aval.memory_space) - smem_space = ir.Attribute.parse("#tpu.memory_space") - src_is_smem = src_memory_space == smem_space - wait_ref = src if src_is_smem else dst - else: - wait_ref = dst - # Legacy instruction backwards compatibility. - tpu.wait_dma(sem, wait_ref) + + core_id = None + if device_id is not None: + device_id, core_id = _device_id_to_logical(ctx, device_id, device_id_type) + + if ctx.forward_compatible or ctx.is_cloud_tpu_older_than(2025, 7, 27): + tpu.wait_dma2(sem, src, dst, core_id=core_id) else: - tpu.wait_dma2(sem, src, dst) + tpu.wait_dma2(sem, src, dst, device_id=device_id, core_id=core_id) return [] -lowering_rules[tpu_primitives.dma_wait_p] = _dma_wait_lowering_rule - -def _device_id_lowering_rule(ctx: LoweringRuleContext): - return tpu.device_id() -lowering_rules[tpu_primitives.device_id_p] = _device_id_lowering_rule +@register_lowering_rule(lax.axis_index_p, kernel_types=[*tpu_core.KernelType]) def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable): grid_names = ctx.lowering_context.grid_names if grid_names and axis_name in grid_names: @@ -3598,28 +3910,47 @@ def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable): np.prod(mesh_shape[axis_index + 1 :], dtype=np.int32) ) return arith.remsi(arith.divsi(device_id, minor_divisor), axis_size) -lowering_rules[lax.axis_index_p] = _axis_index_rule + +@register_lowering_rule( + tpu_primitives.get_barrier_semaphore_p, kernel_types=[*tpu_core.KernelType] +) def _get_barrier_semaphore_rule(ctx: LoweringRuleContext): memref_type = aval_to_ir_type( ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_out[0] ) return tpu.sem_barrier(memref_type) -lowering_rules[tpu_primitives.get_barrier_semaphore_p] = _get_barrier_semaphore_rule +@register_lowering_rule(primitives.delay_p) def _delay_rule(ctx: LoweringRuleContext, nanos: int): tpu.delay(nanos) return [] -lowering_rules[tpu_primitives.delay_p] = _delay_rule - - +@register_lowering_rule(debugging.debug_print_p) def _debug_print_rule( - ctx: LoweringRuleContext, *args, fmt: str, has_placeholders: bool + ctx: LoweringRuleContext, + *dyn_args, + fmt: str, + ordered, + partitioned, + in_tree, + static_args, + np_printoptions, + has_placeholders, + logging_record, ): - is_scalar_inputs = [aval.shape == () for aval in ctx.avals_in] + del partitioned, np_printoptions + if ordered: + raise NotImplementedError("Ordered debug_print is not supported on Pallas.") + args, kwargs = debugging.merge_callback_args(in_tree, dyn_args, static_args) + if kwargs: + raise ValueError( + "Only positional arguments are supported by debug_print on Pallas." + ) + + is_scalar_inputs = [not aval.shape for aval in ctx.avals_in] is_all_scalars = all(is_scalar_inputs) is_single_vector = len(is_scalar_inputs) == 1 and not is_scalar_inputs[0] if not (is_all_scalars or is_single_vector): @@ -3630,8 +3961,8 @@ def _debug_print_rule( # Scalar case. if is_all_scalars: - primitives.check_debug_print_format(fmt, *args) if has_placeholders: + primitives.check_debug_print_format(fmt, *args) if not all( isinstance(arg.type, ir.IntegerType) and arg.type.width == 32 for arg in args @@ -3642,10 +3973,12 @@ def _debug_print_rule( " remove placeholders from the format string." ) - # TPU expects $0, $1 etc as placeholders. + # TPU expects $0, $1 etc as placeholders. fmt = "".join( - f"{text}${idx}" - for idx, (text, _, _, _) in enumerate(string.Formatter().parse(fmt)) + f"{text}${spec}{idx}" if field is not None else text + for idx, (text, field, spec, _) in enumerate( + string.Formatter().parse(fmt) + ) ) tpu.log(args, fmt, formatted=has_placeholders) @@ -3687,9 +4020,7 @@ def _debug_print_rule( return () -lowering_rules[primitives.debug_print_p] = _debug_print_rule - - +@register_lowering_rule(tpu_primitives.prng_seed_p) def _prng_seed_lowering_rule(ctx: LoweringRuleContext, *seeds): del ctx # In the KeyScalarBundle case we unpack the bundle and set the seed with @@ -3705,9 +4036,9 @@ def _prng_seed_lowering_rule(ctx: LoweringRuleContext, *seeds): raise ValueError(f"All seed data must be scalar integers. Got {seed_types}") tpu.prng_set_seed_32(seeds) return [] -lowering_rules[tpu_primitives.prng_seed_p] = _prng_seed_lowering_rule +@register_lowering_rule(tpu_primitives.prng_random_bits_p) def _prng_random_bits_lowering_rule(ctx: LoweringRuleContext, *, shape): if len(shape) <= 1: # TODO(b/342054464): Support implicit dims for PRNGRandomBitsOp. @@ -3717,18 +4048,19 @@ def _prng_random_bits_lowering_rule(ctx: LoweringRuleContext, *, shape): ctx.lowering_context.dynamic_shape_replacement_fn, out_aval ) return tpu.prng_random_bits(out_type) -lowering_rules[tpu_primitives.prng_random_bits_p] = _prng_random_bits_lowering_rule -def random_seed_lowering(ctx, seeds, *, impl): +@register_lowering_rule(prng.random_seed_p) +def random_seed_lowering(ctx: LoweringRuleContext, seeds, *, impl): seed_lowering = lower_fun(impl.seed, multiple_results=False) return seed_lowering(ctx, seeds) -lowering_rules[prng.random_seed_p] = random_seed_lowering -def random_bits_lowering(ctx, keys, *, bit_width, shape): +@register_lowering_rule(prng.random_bits_p) +def random_bits_lowering(ctx: LoweringRuleContext, keys, *, bit_width, shape): assert bit_width == 32, "Only 32-bit PRNG supported." aval, = ctx.avals_in + assert isinstance(aval.dtype, prng.KeyTy) impl = aval.dtype._impl _proxy_fn = impl.random_bits if not pl_random.is_pallas_impl(impl): @@ -3738,100 +4070,93 @@ def new_lowering(key, bit_width, shape): _proxy_fn = new_lowering bits_lowering = lower_fun(_proxy_fn, multiple_results=False) return bits_lowering(ctx, keys, bit_width=bit_width, shape=shape) -lowering_rules[prng.random_bits_p] = random_bits_lowering -def random_fold_in_lowering(ctx, keys, msgs): - keys_aval, _ = ctx.avals_in +@register_lowering_rule(prng.random_fold_in_p) +def random_fold_in_lowering(ctx: LoweringRuleContext, keys, msgs): + keys_aval, msgs_aval = ctx.avals_in + assert isinstance(keys_aval.dtype, prng.KeyTy) impl = keys_aval.dtype._impl fold_in_lowering = lower_fun(impl.fold_in, multiple_results=False) - return fold_in_lowering(ctx, keys, msgs) -lowering_rules[prng.random_fold_in_p] = random_fold_in_lowering + if pl_random.is_pallas_impl(impl): + return fold_in_lowering(ctx, keys, msgs) + else: + ctx = dataclasses.replace(ctx, + avals_in=[jax_core.physical_aval(keys_aval), msgs_aval], + avals_out=map(jax_core.physical_aval, ctx.avals_out)) + return fold_in_lowering(ctx, keys, msgs) -def random_unwrap_lowering(ctx, key): +@register_lowering_rule(prng.random_unwrap_p) +def random_unwrap_lowering(ctx: LoweringRuleContext, key): keys_aval = ctx.avals_in[0] + assert isinstance(keys_aval.dtype, prng.KeyTy) impl = keys_aval.dtype._impl if not pl_random.is_pallas_impl(impl): return key - assert isinstance(key, KeyScalarBundle) - # Convert to a vector. - if tuple(key.key_shape) != (1, 1): - raise NotImplementedError( - "Seed key_data of shape != (1, 1) not supported. " - f"Got: {key.key_shape}") - scalar = key.scalars[0] - out_type = ir.VectorType.get( - key.key_shape, _dtype_to_ir_type(jnp.dtype('int32')) + raise ValueError( + "key_data not support for Pallas PRNG keys. Use" + " split_pallas_seed instead." ) - val = vector.broadcast(out_type, scalar) - return val -lowering_rules[prng.random_unwrap_p] = random_unwrap_lowering -def random_wrap_lowering(ctx, key_data, *, impl): +@register_lowering_rule(prng.random_wrap_p) +def random_wrap_lowering(ctx: LoweringRuleContext, key_data, *, impl): del ctx if not pl_random.is_pallas_impl(impl): return key_data - if isinstance(key_data.type, ir.VectorType): - # If the key data lives in vregs, need to unpack it to sregs. - key_data_list = [] - key_data_shape = key_data.type.shape - if len(key_data_shape) != 2: - raise NotImplementedError("Seed key_data must be 2D.") - if tuple(key_data_shape) != (1, 1): - raise NotImplementedError( - "Seed key_data of shape != (1, 1) not supported. " - f"Got: {key_data_shape}") - for i in range(key_data_shape[1]): - key_data_list.append(vector.ExtractOp(key_data, [], [0, i])) - return KeyScalarBundle( - scalars=key_data_list, key_shape=tuple(key_data_shape)) - if isinstance(key_data, KeyScalarBundle): - return key_data - else: - raise NotImplementedError(f"key_data wrap {type(key_data)}") + raise ValueError( + "wrap_key_data not support for Pallas PRNG keys. Use" + " wrap_pallas_seed instead." + ) -lowering_rules[prng.random_wrap_p] = random_wrap_lowering -def _checkify_lowering_rule( - ctx: LoweringRuleContext, *err_args, err_tree, debug): - if not tpu_core.runtime_assert_enabled(): - if debug: - return [] - else: - raise LoweringException("Non-debug check must be functionalized. " - "Enable runtime asserts with " - "--jax_pallas_enable_runtime_assert " - "or functionalize with checkify.check.") - - assert ctx.lowering_context.ir_context.allow_unregistered_dialects, ( - "allow_unregistered_dialects must be set to True for " - "runtime assert check.") +@register_lowering_rule(tpu_primitives.split_key_p) +def _split_key_lowering_rule( + ctx: LoweringRuleContext, key_data: KeyScalarBundle +): + return key_data.scalars + + +@register_lowering_rule(tpu_primitives.join_key_p) +def _join_key_lowering_rule(ctx: LoweringRuleContext, *scalars, impl): + if not pl_random.is_pallas_impl(impl): + return ValueError(f"Can only join Pallas keys. Got impl={impl}") + return KeyScalarBundle(scalars=scalars, key_shape=tuple(impl.key_shape)) + + +@register_lowering_rule(checkify.check_p, kernel_types=[*tpu_core.KernelType]) +def _check_lowering_rule( + ctx: LoweringRuleContext, *err_args, err_tree, debug +): + del ctx # Unused. + + if not debug: + raise NotImplementedError( + "Non-debug checks are not supported by the Mosaic backend." + " Functionalize them via `jax.experimental.checkify`." + ) + if not pallas_helpers.debug_checks_enabled(): + return [] + error = jax.tree.unflatten(err_tree, err_args) - assert len(error._pred) == 1 - assert len(error._metadata) == 1 - assert len(error._payload) == 1 - pred = list(error._pred.items())[0][1] - metadata = list(error._metadata.items())[0] - payload = list(error._payload.items())[0][1] - exception_tree = metadata[1] + [pred] = error._pred.values() + [exception_tree] = error._metadata.values() + [payload] = error._payload.values() exception = jax.tree.unflatten(exception_tree, payload) assert isinstance(exception, checkify.FailedCheckError) + assert isinstance(exception, checkify.FailedCheckError) - # check_p has an inverted predicate compared to assert, - # so we need to compute not(pred) here. - out_scalar_type = _dtype_to_ir_type(jnp.dtype('bool')) - minus_one = ir_constant(-1, out_scalar_type) + # check_p has an inverted predicate compared to assert, so we need to compute + # ``not pred`` here. + minus_one = ir_constant(-1, _dtype_to_ir_type(jnp.bool)) not_pred = arith.xori(pred, minus_one) - attrs = {"msg": ir.StringAttr.get(exception.fmt_string)} - ir.Operation.create("cf.assert", - operands=(not_pred,), - attributes=attrs) + cf.assert_(not_pred, exception.fmt_string) return [] -lowering_rules[checkify.check_p] = _checkify_lowering_rule -def _threefry2x32_lowering(ctx, k1, k2, m1, m2): + +@register_lowering_rule(prng.threefry2x32_p) +def _threefry2x32_lowering(ctx: LoweringRuleContext, k1, k2, m1, m2): def _lower_fun(k1, k2, m1, m2): with jax.named_scope("threefry2x32"): res = prng._threefry2x32_lowering(k1, k2, m1, m2, use_rolled_loops=False) @@ -3841,10 +4166,8 @@ def _lower_fun(k1, k2, m1, m2): return threefry_lowering(ctx, k1, k2, m1, m2) -lowering_rules[prng.threefry2x32_p] = _threefry2x32_lowering - - -def _iota_2x32_shape_lowering(ctx, *, shape): +@register_lowering_rule(prng.iota_2x32_shape_p) +def _iota_2x32_shape_lowering(ctx: LoweringRuleContext, *, shape): total_elements = np.prod(shape) if total_elements > np.iinfo(jnp.int32).max: raise NotImplementedError(f"Iota with >{np.iinfo(jnp.int32).max} items.") @@ -3865,14 +4188,12 @@ def _lower_fun(shape): return iota_lowering(ctx, shape=shape) -lowering_rules[prng.iota_2x32_shape_p] = _iota_2x32_shape_lowering - - +@register_lowering_rule(lax.pad_p) def _pad_lowering_rule(ctx: LoweringRuleContext, *args, **kwargs): operand, padding_value = args padding_config = kwargs["padding_config"] - out_type: ir.VectorType = aval_to_ir_type( + out_type = aval_to_ir_type( ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_in[0] ) if not isinstance(out_type, ir.VectorType): @@ -3882,7 +4203,7 @@ def _pad_lowering_rule(ctx: LoweringRuleContext, *args, **kwargs): if low == 0 and high == 0 and interior == 0: continue - def _pad(val): + def _pad(val, axis=axis): shape = list(operand.type.shape) shape[axis] = val pad_vec_type = ir.VectorType.get( @@ -3894,34 +4215,17 @@ def _pad(val): pad = vector.broadcast(pad_vec_type, padding_value) else: scalar_attr = ir.FloatAttr.get(operand.type.element_type, padding_value) - pad = arith.ConstantOp( + pad = arith.constant( pad_vec_type, - ir.DenseElementsAttr.get_splat( - pad_vec_type, - scalar_attr, - ), - ).result + ir.DenseElementsAttr.get_splat(pad_vec_type, scalar_attr), + ) return pad if low != 0: - pad_low = _pad(low) - new_shape = out_type.shape - new_shape[axis] += low - out_type = ir.VectorType.get( - new_shape, - out_type.element_type, - ) - operand = tpu.concatenate(out_type, [pad_low, operand], dimension=axis) + operand = tpu.concatenate([_pad(low), operand], dimension=axis) if high != 0: - pad_high = _pad(high) - new_shape = out_type.shape - new_shape[axis] += high - out_type = ir.VectorType.get( - new_shape, - out_type.element_type, - ) - operand = tpu.concatenate(out_type, [operand, pad_high], dimension=axis) + operand = tpu.concatenate([operand, _pad(high)], dimension=axis) if interior > 0: raise NotImplementedError("Not implemented: interior padding") @@ -3929,38 +4233,37 @@ def _pad(val): return operand -lowering_rules[lax.pad_p] = _pad_lowering_rule - - +@register_lowering_rule(control_flow.platform_index_p) def _platform_index_lowering( ctx: mlir.LoweringRuleContext, *, - platforms: Sequence[Sequence[str]], - has_default: bool, + platforms: BranchesPlatforms, ): for i, ps in enumerate(platforms): # note - slightly odd structure here, as platforms is a seq[seq[str]] - if "mosaic" in ps: + if "mosaic" in ps or ps is None: return ir_constant(i) - if has_default: - return ir_constant(len(platforms)) - raise NotImplementedError( "No mosaic or default platform indexing rule found." ) -lowering_rules[jax._src.lax.control_flow.platform_index_p] = _platform_index_lowering - - -def _dim_as_value_lowering(ctx: mlir.LoweringRuleContext, *, dim): +@register_lowering_rule(shape_poly.dim_as_value_p) +def _dim_as_value_lowering(ctx: LoweringRuleContext, *, dim): placeholder = ctx.lowering_context.dynamic_shape_replacement_fn((dim,))[0] - return ir_constant( - placeholder, mlir_type=_dtype_to_ir_type(jnp.dtype("int32")) - ) + return ir_constant(placeholder, mlir_type=_dtype_to_ir_type(jnp.int32)) + +@register_lowering_rule(tpu_primitives.touch_p) +def _touch_lowering_rule(ctx: LoweringRuleContext, x: jax.Array): + del ctx, x + return [] -import jax._src.export.shape_poly as shape_poly -lowering_rules[shape_poly.dim_as_value_p] = _dim_as_value_lowering +@register_lowering_rule(tpu_primitives.trace_value_p) +def _trace_value_lowering_rule(ctx: LoweringRuleContext, value, *, label: str): + """Lower trace_value to tpu.trace_value.""" + del ctx + tpu.trace_value(value, label) + return [] diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index 896af0c464c5..dd8bb00dcc78 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -16,23 +16,25 @@ from __future__ import annotations -import os -import tempfile -from typing import Any +from collections.abc import Sequence +import dataclasses +import json +from typing import cast import jax from jax import dtypes -from jax._src import config from jax._src import core as jax_core +from jax._src import frozen_dict from jax._src import sharding_impls from jax._src import tpu_custom_call from jax._src.interpreters import mlir from jax._src.lib.mlir import ir -from jax._src.pallas import core +from jax._src.lib.mlir import passmanager from jax._src.pallas import core as pallas_core from jax._src.pallas.mosaic import core as tpu_core from jax._src.pallas.mosaic import lowering -from jax._src.pallas.mosaic import verification +from jax._src.pallas.mosaic import sc_lowering +from jax._src.state import types as state_types from jax.experimental import mosaic from jax.experimental.mosaic.dialects import tpu @@ -45,7 +47,7 @@ def _maybe_cast_to_int(x: jax.Array | jax_core.AbstractValue): after loading from a memref inside of the kernel. """ assert isinstance( - x, (jax.Array, jax_core.ShapedArray, jax_core.DShapedArray) + x, (jax.Array, jax_core.ShapedArray, state_types.AbstractLinVal) ), type(x) if isinstance(x, jax.Array): if dtypes.issubdtype(x.dtype, jax.numpy.bool_): @@ -53,81 +55,86 @@ def _maybe_cast_to_int(x: jax.Array | jax_core.AbstractValue): return x else: if dtypes.issubdtype(x.dtype, jax.numpy.bool_): + if isinstance(x, state_types.AbstractLinVal): + raise NotImplementedError # TODO(mattjj,sharadmv) return jax_core.ShapedArray(x.shape, lowering.BOOL_MEMREF_TYPE) return x -_DUMP_PROMELA_TO = config.string_flag( - "jax_pallas_dump_promela_to", - default=os.getenv("JAX_PALLAS_DUMP_PROMELA_TO", ""), - help=( - "If set, dumps a Promela model of the kernel to the specified" - " directory. The model can verify that the kernel is free of data" - " races, deadlocks, etc." - ), -) - def _get_memory_space_from_aval( - out_aval: jax_core.AbstractValue, + out_aval: jax_core.AbstractValue, kernel_type: tpu_core.KernelType ) -> tpu_custom_call.MemorySpace | None: if not isinstance(out_aval, jax_core.ShapedArray): - raise ValueError('Memory spaces not defined for non-ShapedArrays') - if not isinstance(out_aval, core.ShapedArrayWithMemorySpace): + raise ValueError("Memory spaces not defined for non-ShapedArrays") + if not isinstance(out_aval, pallas_core.ShapedArrayWithMemorySpace): # If we are passed a regular old ShapedArray, we don't constrain the # memory space return None # If we are passed an aval with an explicit memory space tag, we use it # to constrain the memory space. match out_aval.memory_space: - case None: - return None - case tpu_core.TPUMemorySpace.ANY: - return None - case tpu_core.TPUMemorySpace.VMEM: + case tpu_core.MemorySpace.HBM: + return tpu_custom_call.MemorySpace.HBM + case tpu_core.MemorySpace.VMEM: return tpu_custom_call.MemorySpace.VMEM - case tpu_core.TPUMemorySpace.SMEM: + case tpu_core.MemorySpace.SMEM: return tpu_custom_call.MemorySpace.SMEM - case tpu_core.TPUMemorySpace.SEMAPHORE: - return tpu_custom_call.MemorySpace.SEMAPHORE_MEM + case tpu_core.MemorySpace.SEMAPHORE: + match kernel_type: + case tpu_core.KernelType.SC_SCALAR_SUBCORE: + return tpu_custom_call.MemorySpace.SC_SCALAR_SEMAPHORE_MEM + case tpu_core.KernelType.TC: + return tpu_custom_call.MemorySpace.SEMAPHORE_MEM + case _: + raise ValueError(f"Invalid kernel type for semaphore: {kernel_type}") + case tpu_core.MemorySpace.HOST: + return tpu_custom_call.MemorySpace.HOST return None def _get_memory_spaces_from_avals( - out_avals: tuple[jax_core.AbstractValue, ...], + avals: Sequence[jax_core.AbstractValue], kernel_type: tpu_core.KernelType ) -> tuple[tpu_custom_call.MemorySpace | None, ...] | None: - output_memory_spaces = None + memory_spaces = None if any( - isinstance(out_aval, core.ShapedArrayWithMemorySpace) - for out_aval in out_avals + isinstance(aval, pallas_core.ShapedArrayWithMemorySpace) for aval in avals ): - output_memory_spaces = tuple(map(_get_memory_space_from_aval, out_avals)) - return output_memory_spaces + memory_spaces = tuple( + _get_memory_space_from_aval(aval, kernel_type=kernel_type) + for aval in avals + ) + return memory_spaces + def pallas_call_tpu_lowering_rule( ctx: mlir.LoweringRuleContext, *in_nodes, jaxpr: jax_core.Jaxpr, - grid_mapping: core.GridMapping, + grid_mapping: pallas_core.GridMapping, mesh: pallas_core.Mesh | None, input_output_aliases: tuple[tuple[int, int], ...], debug: bool, interpret: bool, - compiler_params: dict[str, Any], - cost_estimate: core.CostEstimate | None, + compiler_params: dict[str, pallas_core.CompilerParams], + cost_estimate: pallas_core.CostEstimate | None, out_avals: tuple[jax_core.AbstractValue, ...], + metadata: frozen_dict.FrozenDict[str, str] | None, + name: str | None, ): """Lowers a pallas_call to a Mosaic TPU custom call.""" - del mesh, interpret # Unused. + del interpret # Unused. - debug_info = jaxpr._debug_info + debug_info = jaxpr.debug_info if debug: print(f"\nThe kernel jaxpr for pallas_call {debug_info.func_src_info}:") print(jaxpr) - if "mosaic" in compiler_params: - mosaic_params = compiler_params["mosaic"] + + if "mosaic_tpu" in compiler_params: + mosaic_params = cast(tpu_core.CompilerParams, compiler_params["mosaic_tpu"]) else: - mosaic_params = {} + mosaic_params = tpu_core.CompilerParams() + del mesh jax_mesh = None axis_context = ctx.module_context.axis_context if axis_context is not None: @@ -138,64 +145,38 @@ def pallas_call_tpu_lowering_rule( mlir_ctx.load_all_available_dialects() tpu.register_dialect(mlir_ctx) - def lower_module(for_verification: bool): - if for_verification or tpu_core.runtime_assert_enabled(): - mlir_ctx.allow_unregistered_dialects = True - with mlir_ctx, ir.Location.unknown(mlir_ctx): - dimension_semantics = mosaic_params.get("dimension_semantics", None) - return lowering.lower_jaxpr_to_module( - ctx, - mlir_ctx, - grid_mapping, - jaxpr, - dimension_semantics=dimension_semantics, - mesh=jax_mesh, - for_verification=for_verification, - dynamic_shape_replacement_enabled=pallas_core.dynamic_shapes_export_enabled(), + match (kernel_type := mosaic_params.kernel_type): + case tpu_core.KernelType.TC: + lower_jaxpr_to_module = lowering.lower_jaxpr_to_module + case tpu_core.KernelType.SC_SCALAR_SUBCORE | tpu_core.KernelType.SC_VECTOR_SUBCORE: + lower_jaxpr_to_module = sc_lowering.lower_jaxpr_to_module + case _: + raise ValueError( + f"Unsupported kernel type: {mosaic_params.kernel_type}" ) - mosaic_module, extra_args = lower_module(for_verification=False) + with mlir_ctx, ir.Location.unknown(mlir_ctx): + mosaic_module = lower_jaxpr_to_module( + ctx, + grid_mapping, + jaxpr, + dimension_semantics=mosaic_params.dimension_semantics, + kernel_type=kernel_type, + mesh=jax_mesh, + dynamic_shape_replacement_enabled=pallas_core.dynamic_shapes_export_enabled(), + ) + if debug: + pm = passmanager.PassManager.parse("builtin.module(canonicalize)", mlir_ctx) + pm.run(mosaic_module.operation) print(f"\nThe Mosaic module for pallas_call {debug_info.func_src_info}:") print(mosaic_module) - num_extra_args = len(extra_args) num_dyn_bounds = grid_mapping.num_dynamic_grid_bounds input_output_aliases = tuple( - (a[0] + num_dyn_bounds + num_extra_args, a[1]) + (a[0] + num_dyn_bounds, a[1]) for a in input_output_aliases ) - if promela_dump_path := _DUMP_PROMELA_TO.value: - num_devices = 1 if jax_mesh is None else jax_mesh.devices.size - num_cores = ( - jax.devices()[0].num_cores - if jax_mesh is None - else jax_mesh.devices[0].num_cores - ) - verification_module, _ = lower_module(for_verification=True) - model = verification.export_promela_model( - verification_module, num_devices, num_cores - ) - if promela_dump_path == "stdout": - print(f"The Promela model for pallas_call {debug_info.func_src_info}:") - print(model) - else: - if promela_dump_path == "sponge": - promela_dump_path = os.getenv("TEST_UNDECLARED_OUTPUTS_DIR", "") - if not promela_dump_path: - raise ValueError( - "TEST_UNDECLARED_OUTPUTS_DIR must be set when" - " --jax_pallas_dump_promela_to=sponge" - ) - dump_ctx = tempfile.NamedTemporaryFile( - mode="w", - prefix=mlir.sanitize_name(debug_info.func_name) + "-", - suffix=".pml", - dir=promela_dump_path, delete=False, - ) - with dump_ctx as f: - f.write(model) - # Replace in_avals to physical avals. # This step is required for mapping logical types to physical types. # (e.g. PRNG key -> uint32[2]) @@ -206,48 +187,136 @@ def lower_module(for_verification: bool): def _maybe_cast_inputs(*args): args = [_maybe_cast_to_int(x) for x in args] return args + kernel_in_avals = [_maybe_cast_to_int(x) for x in ctx.avals_in] - kernel_out_avals = [_maybe_cast_to_int(x) for x in out_avals] + kernel_out_avals = [_maybe_cast_to_int(x) for x in ctx.avals_out] cast_ctx = ctx.replace(avals_out=kernel_in_avals) in_nodes = mlir.lower_fun(_maybe_cast_inputs)(cast_ctx, *in_nodes) # Dynamic grid bounds have to go at the front. dynamic_grid_args, args = in_nodes[:num_dyn_bounds], in_nodes[num_dyn_bounds:] kernel_ctx = ctx.replace(avals_in=kernel_in_avals, avals_out=kernel_out_avals) - output_memory_spaces = _get_memory_spaces_from_avals(out_avals) + output_memory_spaces = _get_memory_spaces_from_avals( + out_avals, kernel_type=kernel_type + ) + input_memory_spaces = None + if any( + isinstance(aval, pallas_core.ShapedArrayWithMemorySpace) + for aval in ctx.avals_in + ): + input_memory_spaces = _get_memory_spaces_from_avals( + ctx.avals_in, kernel_type=kernel_type + ) if cost_estimate is not None: - mosaic_cost_estimate = tpu_custom_call.CostEstimate( - flops=cost_estimate.flops, - bytes_accessed=cost_estimate.bytes_accessed, - transcendentals=cost_estimate.transcendentals, + mosaic_cost_estimate = cast( + tpu_custom_call.CostEstimate, dataclasses.asdict(cost_estimate) ) else: mosaic_cost_estimate = None + if input_memory_spaces is None and output_memory_spaces is not None: + input_memory_spaces_list: list[tpu_custom_call.MemorySpace | None] = [ + None, + ] * len(ctx.avals_in) + for input_output_alias in input_output_aliases: + input_memory_spaces_list[input_output_alias[0]] = output_memory_spaces[ + input_output_alias[1] + ] + input_memory_spaces = tuple(input_memory_spaces_list) + if input_memory_spaces is not None: + # Filter out the memory spaces that are not supported for input memory + # spaces. + input_memory_spaces = tuple( + i + if i + in { # pylint: disable=g-long-ternary + tpu_custom_call.MemorySpace.HBM, + tpu_custom_call.MemorySpace.VMEM, + tpu_custom_call.MemorySpace.SMEM, + } + else None + for i in input_memory_spaces + ) + has_side_effects: bool | tpu_custom_call.TpuSideEffectType + match mosaic_params.has_side_effects: + case bool(): + has_side_effects = mosaic_params.has_side_effects + case tpu_core.SideEffectType.PURE: + has_side_effects = tpu_custom_call.TpuSideEffectType.PURE + case tpu_core.SideEffectType.DATAFLOW_SIDE_EFFECTING: + has_side_effects = ( + tpu_custom_call.TpuSideEffectType.DATAFLOW_SIDE_EFFECTING + ) + case tpu_core.SideEffectType.SIDE_EFFECTING: + has_side_effects = tpu_custom_call.TpuSideEffectType.SIDE_EFFECTING + case _: + raise ValueError(f"Invalid side effect type: {mosaic_params.has_side_effects}") + tiling: tpu_custom_call.Tiling | None = None + if mosaic_params.use_tc_tiling_on_sc is not None: + if kernel_type not in ( + tpu_core.KernelType.SC_SCALAR_SUBCORE, + tpu_core.KernelType.SC_VECTOR_SUBCORE, + ): + raise ValueError( + "use_tc_tiling_on_sc= is only supported for SC_*_SUBCORE kernels" + ) + + tiling = ( + tpu_custom_call.Tiling.COMPACT + if mosaic_params.use_tc_tiling_on_sc + else tpu_custom_call.Tiling.SPARSE_CORE + ) + dict_metadata = dict(metadata) if metadata is not None else {} + del metadata + if jax_mesh is not None: + mesh_axes = { + e.name + for e in jaxpr.effects + if isinstance(e, jax_core.NamedAxisEffect) + # Filter for only device mesh axis name effects + and e.name in jax_mesh.axis_names + } + # Only put mesh axes in metadata if there are any. + if mesh_axes: + if "mesh_axes" in dict_metadata: + raise ValueError("Metadata already contains mesh axes.") + mesh_axes_list = list(mesh_axes) + if all(isinstance(a, str) for a in mesh_axes): + mesh_axes_list = sorted(mesh_axes) # type: ignore + dict_metadata["mesh_axes"] = json.dumps(mesh_axes_list) out_nodes = mosaic.lower_module_to_custom_call( kernel_ctx, *dynamic_grid_args, - *extra_args, *args, module=mosaic_module, out_type=kernel_out_avals, - backend="tpu", - kernel_name=mlir.sanitize_name(debug_info.func_name), + kernel_name=mlir.sanitize_name(name or debug_info.func_name), cost_estimate=mosaic_cost_estimate, - vmem_limit_bytes=mosaic_params.get("vmem_limit_bytes"), - flags=mosaic_params.get("flags"), - allow_input_fusion=mosaic_params.get("allow_input_fusion"), + vmem_limit_bytes=mosaic_params.vmem_limit_bytes, + flags=mosaic_params.flags, + allow_input_fusion=mosaic_params.allow_input_fusion, input_output_aliases=input_output_aliases, - serialization_format=mosaic_params.get("serialization_format", 1), - device_type=mosaic_params.get("device_type"), - internal_scratch_in_bytes=mosaic_params.get("internal_scratch_in_bytes"), - collective_id=mosaic_params.get("collective_id", None), - has_side_effects=mosaic_params.get("has_side_effects", False), + serialization_format=mosaic_params.serialization_format, + internal_scratch_in_bytes=mosaic_params.internal_scratch_in_bytes, + collective_id=mosaic_params.collective_id, + has_side_effects=has_side_effects, output_memory_spaces=output_memory_spaces, + disable_bounds_checks=mosaic_params.disable_bounds_checks, + input_memory_spaces=input_memory_spaces, + metadata=dict_metadata, + skip_device_barrier=mosaic_params.skip_device_barrier, + allow_collective_id_without_custom_barrier=mosaic_params.allow_collective_id_without_custom_barrier, + shape_invariant_numerics=mosaic_params.shape_invariant_numerics, + tiling=tiling, ) - _maybe_cast_to_bool = lambda x, aval: x.astype( - jax.numpy.bool_) if aval.dtype == jax.numpy.bool_ else x + _maybe_cast_to_bool = ( + lambda x, aval: x.astype(jax.numpy.bool_) + if aval.dtype == jax.numpy.bool_ + else x + ) + def _maybe_cast_outputs(*args): args = [_maybe_cast_to_bool(x, aval) for x, aval in zip(args, out_avals)] return args + cast_ctx = ctx.replace(avals_in=kernel_out_avals) return mlir.lower_fun(_maybe_cast_outputs)(cast_ctx, *out_nodes) diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index 184b1497adf9..d5ce51e6cea3 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -20,27 +20,28 @@ import dataclasses import enum import functools -import itertools -import operator -from typing import Any, Union +from typing import Any, Literal, Union import jax +from jax import core as jax_core from jax import lax from jax import tree_util from jax._src import util as jax_util from jax._src.pallas import core as pallas_core from jax._src.pallas import primitives as primitives from jax._src.pallas.mosaic import core as tpu_core +from jax._src.pallas.mosaic import helpers as tpu_helpers from jax._src.pallas.mosaic import primitives as tpu_primitives +from jax._src.pallas.mosaic import tpu_info +from jax._src.state import types as state_types from jax.experimental import pallas as pl -from jax.extend.backend import get_default_device import jax.numpy as jnp -import numpy as np -SMEM = tpu_core.TPUMemorySpace.SMEM -VMEM = tpu_core.TPUMemorySpace.VMEM -DMA = tpu_core.SemaphoreType.DMA +SMEM = tpu_core.MemorySpace.SMEM +VMEM = tpu_core.MemorySpace.VMEM +HBM = tpu_core.MemorySpace.HBM +ANY = pallas_core.MemorySpace.ANY REF = pallas_core.MemoryRef GridDimensionSemantics = tpu_core.GridDimensionSemantics PARALLEL = tpu_core.PARALLEL @@ -55,8 +56,19 @@ PipelineRefs = Union[Sequence[REF], Any] -# TODO(sharadmv): make this a parameter and make it queryable from the Device. -_TILING = (8, 128) +class Tiling(enum.Enum): + COMPACT = enum.auto() + SPARSE_CORE = enum.auto() + + @property + def shape(self) -> tuple[int, ...]: + # TODO(slebedev): Use ``get_tpu_info()`` instead of hardcoding the values. + match self: + case Tiling.COMPACT: + return (8, 128) + case Tiling.SPARSE_CORE: + return (8,) + def _broadcast_pytree_to(from_pytree, to_pytree): """Broadcast a prefix pytree to a given full tree.""" @@ -77,40 +89,63 @@ def add_leaves(i, x): return tree_util.tree_unflatten(treedef, broadcast_leaves) -@jax_util.cache(trace_context_in_key=False) def _get_tpu_generation() -> int: - kind = get_default_device().device_kind - if kind.endswith(' lite'): - kind = kind[:-len(' lite')] - assert kind[:5] == "TPU v", kind - return int(kind[5]) - -def _make_tiling(shape: tuple[int, ...], dtype: np.dtype) -> tuple[int, ...]: - # For a n-dimensional shape, returns (8, 128) for the last 2 dimensions - # and 1 for the leading n - 2. For example, (256, 256) -> (8, 128) and - # (2, 3, 128, 128) -> (1, 1, 8, 128). - if len(shape) < 2: - raise ValueError(f"Shape must have at least 2 dimensions: {shape=}") - leading_dims, final_dims = shape[:-2], shape[-2:] - # We want to find the minimum power of 2 that fits the second-minor dimension - # of shape, with maximum value 8. - second_minor, _ = final_dims - packing = 4 // dtype.itemsize - max_tiling = _TILING[0] - second_minor_tiling = (1 + int(_get_tpu_generation() < 4)) * packing - while second_minor_tiling < min(second_minor, max_tiling): - second_minor_tiling *= 2 - return (*(1,) * len(leading_dims), second_minor_tiling, _TILING[1]) - - -def _round_up_to_nearest_multiple(s: int, multiple: int) -> int: - if s % multiple == 0: + return tpu_info.get_tpu_info().generation + + +def _make_tiling( + shape: tuple[int, ...], + ty: jax_core.AbstractValue, + tiling: Tiling | None = None, +) -> tuple[int | None, ...]: + """Compute a tiling for the given shape and type. + + For an n-dimensional shape, returns the tiling for the last + ``len(tiling.shape)`` dimensions and 1 for the leading dims. For example: + - 2D tiling: (256, 256) -> (8, 128) and (2, 3, 128, 128) -> (1, 1, 8, 128). + - 1D tiling: (16,) -> (8,) and (2, 3, 8) -> (1, 1, 8). + + Types are not required to have a dtype, so for such types we return None for + all dimensions because their tiling is unknown. + """ + if not hasattr(ty, "dtype"): + return (None,) * len(shape) + packing = 4 // ty.dtype.itemsize + + if tiling is None: + tiling = Tiling.COMPACT + tiling_rank = len(tiling.shape) + if len(shape) < tiling_rank: + raise ValueError( + f"Shape must have at least {tiling_rank} dimensions: {shape=}" + ) + + leading_dims, final_dims = shape[:-tiling_rank], shape[-tiling_rank:] + match tiling: + case Tiling.COMPACT: + # We want to find the minimum power of 2 that fits the second-minor + # dimension of shape, with maximum value equal to ``tiling.shape[0]``. + second_minor, _ = final_dims + max_tiling = tiling.shape[0] + second_minor_tiling = (1 + int(_get_tpu_generation() < 4)) * packing + while second_minor_tiling < min(second_minor, max_tiling): + second_minor_tiling *= 2 + return (*(1,) * len(leading_dims), second_minor_tiling, tiling.shape[1]) + case Tiling.SPARSE_CORE: + [tile_size] = tiling.shape + return (*(1,) * len(leading_dims), tile_size * packing) + + +def _round_up_to_nearest_multiple( + s: int | jax.Array, multiple: int +) -> int | jax.Array: + if isinstance(s, int) and s % multiple == 0: return s # Subtract off the remainder, then add multiple return s - s % multiple + multiple -def _make_ds( +def _make_block_ds( idx: jax.Array | int, size: jax.Array | int ) -> pl.Slice: """Make a DMA slice with mosaic size hints.""" @@ -119,34 +154,103 @@ def _make_ds( return out -def _make_block_slice( - block_index: jax.Array, block_size: int, size: int, tiling: int -) -> pl.Slice | slice: - # Computes a slice given a block index and block size. In the default case, - # we return slice(block_index * block_size, (block_index + 1) * block_size). - # However, if the total size of the ref does not divide block size and we are - # selecting the last block, we need to pick the lowest tiling size multiple - # that contains the block. - if size % block_size == 0: - return _make_ds(block_index, block_size) +def _create_blocked_slice( + block_index: jax.Array | int, + block_size: int, + dim_size: int, + tiling: int | None, +): + block_start = block_size * block_index + if (dim_rem := dim_size % block_size) == 0: + return pl.ds(block_start, block_size) + if tiling is None: + raise ValueError("If tiling is None, block_size must divide dim_size.") if block_size % tiling != 0: raise ValueError(f"Block size must divide tiling: {block_size=}, {tiling=}") - num_blocks = pl.cdiv(size, block_size) + num_blocks = pl.cdiv(dim_size, block_size) is_last = block_index == num_blocks - 1 rounded_size = jnp.where( is_last, - _round_up_to_nearest_multiple(size % block_size, tiling), + _round_up_to_nearest_multiple(dim_rem % block_size, tiling), block_size, ) rounded_size = pl.multiple_of(rounded_size, tiling) return pl.ds(block_index * block_size, rounded_size) +def _create_bounded_slice(slice_start: jax.Array | int, + slice_size: jax.Array | int, + block_size: int, + dim_size: int, + tiling: int | None): + if tiling is not None and block_size % tiling != 0: + raise ValueError(f"Block size must divide tiling: {block_size=}, {tiling=}") + # We assume by construction that slice_size <= block_size. We also assume + # that the slice_start is already aligned to the tiling. + + if tiling is None: + return pl.ds(slice_start, slice_size) + + # If we are out of bound, we need to round the slice size down to the nearest + # multiple of the tiling. + is_oob = slice_start + slice_size > dim_size + remaining = dim_size - slice_start + rounded_size = jnp.where( + is_oob, + _round_up_to_nearest_multiple(remaining, tiling), + slice_size, + ) + rounded_size = pl.multiple_of(rounded_size, tiling) + return pl.ds(slice_start, rounded_size) + +def _make_block_slice( + block_index: jax.Array, block_size: pl.BlockDim | int | None, size: int, + tiling: int | None +) -> pl.Slice | slice | int | jax.Array: + # Computes a slice given a block index and block size. In the default case, + # we return slice(block_index * block_size, (block_index + 1) * block_size). + # However, if the total size of the ref does not divide block size and we are + # selecting the last block, we need to pick the lowest tiling size multiple + # that contains the block. + match block_size: + case pl.Blocked(): + return _create_blocked_slice(block_index, block_size.block_size, size, tiling) + case int(): + return _create_blocked_slice(block_index, block_size, size, tiling) + case pl.Element(): + block_start = block_index + block_size = block_size.block_size + return _create_bounded_slice( + block_start, block_size, block_size, size, tiling + ) + case pl.BoundedSlice(block_size): + if not isinstance(block_index, pl.Slice): + raise ValueError( + "Must return a pl.ds from the index_map for a BoundedSlice" + " dimension." + ) + slice_start = block_index.start + slice_size = block_index.size + return _create_bounded_slice( + slice_start, slice_size, block_size, size, tiling + ) + case None | pl.Squeezed(): + return block_index + case _: + raise ValueError(f"Unsupported block dimension type: {block_size}") + def _tuples_differ(xs, ys): """Dynamic index-tuple comparison calculation.""" - differences = jax.tree.map(lambda x, y: x != y, xs, ys) + differences = jax.tree.leaves(jax.tree.map(lambda x, y: x != y, xs, ys)) return functools.reduce(lambda x, y: x | y, differences, False) +def _tuple_all_binop(binop, xs, ys): + """Dynamic reduce_all calculation with a user-provided comparison op.""" + differences = jax.tree.leaves(jax.tree.map(lambda x, y: binop(x, y), xs, ys)) + return functools.reduce(lambda x, y: x & y, differences, True) + +_tuple_lt = functools.partial(_tuple_all_binop, lambda x, y: x < y) + def _grid_size(grid): """Dynamic grid size calculation.""" @@ -156,20 +260,6 @@ def _grid_size(grid): return size -def _get_indices(step, grid, offsets): - """Get indices for a given step and grid.""" - # TODO(enriqueps): Implement using bitwise ops, avoid div/rem since they are - # expensive. - extended_grid = grid + (1,) - strides = tuple( - itertools.accumulate(extended_grid[::-1], func=operator.mul))[::-1] - indices = tuple( - lax.div(lax.rem(step, a), b) - for a, b in zip(strides[:-1], strides[1:]) - ) - return tuple(a + b for a, b in zip(indices, offsets, strict=True)) - - class BufferType(enum.Enum): """Buffer type for the arguments to an emitted pipeline.""" INPUT = 1 @@ -179,25 +269,223 @@ class BufferType(enum.Enum): MANUAL = 5 + @property + def is_input(self): + return self in [ + BufferType.INPUT, + BufferType.ACCUMULATOR, + BufferType.INPUT_OUTPUT, + ] + + @property + def is_output(self): + return self in [ + BufferType.OUTPUT, + BufferType.ACCUMULATOR, + BufferType.INPUT_OUTPUT, + ] + + +def _get_block_shape(spec: pl.BlockSpec) -> tuple[int, ...]: + """Get the block shape for a given block spec.""" + def _get_dim_size(bd): + match bd: + case pl.Blocked(block_size): + return block_size + case pl.Element(block_size): + return block_size + case pl.BoundedSlice(block_size): + return block_size + case int(): + return bd + case None | pl.Squeezed(): + return None + case _: + raise ValueError(f"Unsupported block dimension type: {bd}") + if spec.block_shape is None: + raise ValueError("Block shape must be specified.") + block_shape_nones = tuple(_get_dim_size(x) for x in spec.block_shape) + return tuple(x for x in block_shape_nones if x is not None) + + +class BufferedRefBase: + """Abstract interface for BufferedRefs.""" + + @property + def spec(self) -> pl.BlockSpec: + raise NotImplementedError() + + @property + def buffer_type(self) -> BufferType: + raise NotImplementedError() + + @property + def is_buffered(self) -> bool: + return False + + @property + def is_input(self): + return self.buffer_type.is_input + + @property + def is_output(self): + return self.buffer_type.is_output + + @property + def is_accumulator(self): + return self.buffer_type == BufferType.ACCUMULATOR + + @property + def is_input_output(self): + return self.buffer_type == BufferType.INPUT_OUTPUT + + @property + def is_manual(self): + return self.buffer_type == BufferType.MANUAL + + def init_slots(self): + """Initialize slot indices.""" + raise NotImplementedError() + + def advance_copy_in_slot(self, predicate: bool = True) -> BufferedRefBase: + """Advance the copy in slot.""" + raise NotImplementedError() + + def advance_wait_in_slot(self, predicate: bool = True) -> BufferedRefBase: + """Advance the wait in slot.""" + raise NotImplementedError() + + def advance_copy_out_slot(self, predicate: bool = True) -> BufferedRefBase: + """Advance the copy out slot.""" + raise NotImplementedError() -@tree_util.register_pytree_node_class + def advance_wait_out_slot(self, predicate: bool = True) -> BufferedRefBase: + """Advance the wait out slot.""" + raise NotImplementedError() + + def load_slots(self, predicate: bool | jax.Array = True) -> BufferedRefBase: + """Load slot information into registers.""" + raise NotImplementedError() + + def save_slots(self, predicate: bool | jax.Array = True): + """Save slot information from registers.""" + raise NotImplementedError() + + @property + def block_shape(self) -> Sequence[pl.BlockDim | int | None] | None: + return self.spec.block_shape + + @property + def compute_index(self): + return self.spec.index_map + + def get_dma_slice(self, src_ty, grid_indices): + # We need to handle blocks that might go OOB in the src array. An in bounds + # block looks like this (for array shape (600, 600) and block shape + # (256, 256)): + # + # +--------------+------------------| + # | Block (0,0) | | + # | (256, 256) | | + # +--------------+ | + # | A (600, 600) | + # | | + # +---------------------------------+ + # + # For in-bounds blocks, we don't need to do anything special. + # An out-of-bounds block looks like this: + # + # +--------------+------------------| + # | | + # | | + # + | + # | A (600, 600) | + # +--------------+ | + # | Block (2,0) | | + # + --------------------------------| + # | XXXXXXXXXX | + # +--------------+ + # where the X's indicate where the block is out of bounds. + # + # When we have an out of bounds block like this, we need to truncate it to + # a tile boundary (tiles are (8, 128) along the two minormost dimensions). + # In this case, we'll have a block that is indexing the + # 512:768 elements of A along the first dimension. We need to convert 768 + # into 600 (600 % 8 == 0), so our indexing will look like this: + + # +--------------+------------------| + # | | + # | | + # + | + # | A (600, 600) | + # +--------------+ | + # | Block (2,0) | | + # + --------------------------------| + # where it is now a (88, 256) sized block. + # + # Suppose A is now (601, 600), instead of picking a (88, 256)-sized block + # for the last iteration on that dimension, we will pick the next highest + # tile multiple, i.e. (96, 256). + + if (src_shape := getattr(src_ty, "shape", None)) is None: + raise ValueError(f"Type {src_ty} does not have a shape") + + tiling = _make_tiling(src_shape, src_ty, getattr(self, "tiling", None)) + block_indices = self.compute_index(*grid_indices) + return tuple( + _make_block_slice(bi, bs, ss, t) + for bi, bs, ss, t in zip( + block_indices, self.block_shape, src_shape, tiling, strict=True + ) + ) + + def bind_existing_ref(self, window_ref, indices): + """For handling VMEM references, the pipeline aliases the existing ref.""" + del window_ref, indices + return self + + def unbind_refs(self): + return self + + def with_spec(self, spec: pl.BlockSpec) -> BufferedRefBase: + """Returns a new BufferedRefBase with the given block spec.""" + raise NotImplementedError() + +def _ref_to_value_aval(ref): + """Return the inner of a ref, or a ShapedArray for TransformedRefs.""" + return ( + jax_core.ShapedArray(shape=ref.shape, dtype=ref.dtype) + if isinstance(ref, state_types.TransformedRef) + else jax.typeof(ref).inner_aval + ) + + +# TODO(justinfu): Refactor and rename slot fields to reflect cumulative values +# instead of slot index. +@tree_util.register_dataclass @dataclasses.dataclass(frozen=True) -class BufferedRef: +class BufferedRef(BufferedRefBase): """A helper class to automate VMEM double buffering in pallas pipelines. Attributes: spec: pallas blockspec. - dtype: dtype for buffers. buffer_type: enum indicating whether this is an input, output, or in/out accumulator buffered reference. - window_ref: a double-buffer to hold a working buffer and a dirty buffer used + window_ref: a multiple-buffer to hold the working and dirty buffers used to copy into and out of. In the case of a BufferedRef targeting a VMEM reference, this simply points to the existing ref. accum_ref: accumulating buffer used by accumulator BufferedRefs. - current_slot: current slot index to the working buffer. - next_slot: slot that will point to the working buffer in the next iteration. - sem_recvs: Double buffered semaphores for input DMAs. - sem_sends: Double buffered semaphores for output DMAs. + copy_in_slot: current slot to copy in for the working buffer. + copy_out_slot: current slot to copy out for the working buffer. + wait_in_slot: current slot to wait in for the working buffer. + wait_out_slot: current slot to wait out for the working buffer. + next_fetch_smem: Holds the next grid indices to fetch for lookahead. This + is the SMEM backing buffer used to persist state between pipeline + invocations. + next_fetch_sreg: Holds the next grid indices to fetch for lookahead. This + is the register state used to track the indices within the pipeline loop. + sem_recvs: Multiple buffered semaphores for input DMAs. + sem_sends: Multiple buffered semaphores for output DMAs. block_shape: passthrough property for the BlockSpec's block_shape. compute_index: passthrough property for the BlockSpec's compute_index. memory_space: passthrough property for the BlockSpec's memory_space. @@ -209,117 +497,202 @@ class BufferedRef: automatic accumulation. swap: Tracks whether the BufferedRef slots need to be swapped before next copy. + tiling: The tiling to assume for the buffers. """ - spec: pl.BlockSpec # static metadata - dtype: Any # static metadata - buffer_type: BufferType # static metadata - window_ref: REF | None - accum_ref: REF | None - current_slot: ArrayRef | None - # TODO(ramiroleal): Unused by class. Remove argument from - # BufferedRef instantiations. - next_slot: ArrayRef | None + _spec: pl.BlockSpec = dataclasses.field(metadata=dict(static=True)) + _buffer_type: BufferType = dataclasses.field(metadata=dict(static=True)) + window_ref: ArrayRef | None + accum_ref: ArrayRef | None + copy_in_slot: ArrayRef | None + wait_in_slot: ArrayRef | None + copy_out_slot: ArrayRef | None + wait_out_slot: ArrayRef | None + _copy_in_slot_reg: int | jax.Array | None + _wait_in_slot_reg: int | jax.Array | None + _copy_out_slot_reg: int | jax.Array | None + _wait_out_slot_reg: int | jax.Array | None + next_fetch_smem: Sequence[jax.Array] | None + next_fetch_sreg: Sequence[jax.Array] | None sem_recvs: SemaphoreTuple | None sem_sends: SemaphoreTuple | None # TODO(ramiroleal): Improve prefetch/postyeet interface to avoid # using this ref. swap: ArrayRef | None + tiling: Tiling | None = dataclasses.field(metadata=dict(static=True)) - def tree_flatten(self): - return ( - ( - self.window_ref, - self.accum_ref, - self.current_slot, - self.next_slot, - self.sem_recvs, - self.sem_sends, - self.swap, - ), - (self.spec, self.dtype, self.buffer_type), - ) + def __post_init__(self): + if self.is_buffered and self.buffer_count < 1: + raise ValueError( + f"buffer_count must be at least 1, got {self.buffer_count}" + ) + if self.is_output: + if self.is_buffered and self.buffer_count > 2: + raise NotImplementedError( + "Buffer count >2 not supported for output buffered refs." + ) - @classmethod - def tree_unflatten(cls, meta, data): - return cls(*meta, *data) + @property + def spec(self): + return self._spec + + @property + def buffer_type(self): + return self._buffer_type + + @property + def is_buffered(self) -> bool: + """Whether this buffer is multiple-buffered.""" + slots = [self.copy_in_slot, self.wait_in_slot, + self.copy_out_slot, self.wait_out_slot] + return any(x is not None for x in slots) + + @property + def use_lookahead(self) -> bool: + """Whether this buffer allows lookahead for fetching blocks.""" + return self.next_fetch_smem is not None + + @property + def buffer_count(self) -> int: + """Returns the number of buffers used for multiple buffering.""" + if not self.is_buffered: + raise ValueError("buffer count is undefined") + return self.window_ref.shape[0] # type: ignore[union-attr] @staticmethod def buffer_types() -> type[BufferType]: return BufferType @classmethod - def create(cls, spec, dtype, buffer_type, needs_swap_ref=True) -> BufferedRef: + def create( + cls, + spec: pl.BlockSpec, + dtype_or_type, + buffer_type, + buffer_count, + needs_swap_ref=True, + grid_rank=None, + use_lookahead=False, + source_memory_space: tpu_core.MemorySpace | Literal[ANY] = ANY, # type: ignore[valid-type] + tiling: Tiling | None = None, + ) -> BufferedRef: """Create a BufferedRef. Args: spec: pallas blockspec. - dtype: dtype for buffers. + dtype_or_type: dtype or aval for buffers. If an aval, the shape is + ignored. buffer_type: enum indicating whether this is an input, output, or in/out accumulator buffered reference. needs_swap_ref: whether a swap slots tracker needs to be allocated. + grid_rank: rank of the pipeline grid. + use_lookahead: whether to enable pipeline lookahead. + source_memory_space: The memory space of the backing source Ref. + tiling: The tiling to assume for the buffers. Returns: Initialized BufferedRef """ - block_shape = tuple(1 if x is None else x for x in spec.block_shape) + + # (123, 456) is a dummy shape since we never use ty without + # calling .update(shape=...) first. + ty = ( + dtype_or_type + if isinstance(dtype_or_type, jax_core.AbstractValue) + else jax_core.ShapedArray((123, 456), dtype_or_type) + ) + + block_shape = _get_block_shape(spec) if buffer_type is BufferType.ACCUMULATOR: - accum_ref = VMEM(block_shape, dtype) + accum_ref = VMEM.from_type(ty.update(shape=block_shape)) else: accum_ref = None - if spec.memory_space == VMEM: - # We don't need to do any double-buffering in the case that our pipeline - # reference is already in VMEM, we just need allocate the accumulation - # buffer and we will refer to the original reference slices directly. + buffer_memory_space = ( + VMEM if spec.memory_space is None else spec.memory_space) + if buffer_memory_space not in (SMEM, VMEM, HBM): + raise ValueError( + f"Unsupported buffer memory space: {buffer_memory_space}" + ) + if source_memory_space is buffer_memory_space: return cls( - spec=spec, - dtype=dtype, - buffer_type=buffer_type, + _spec=spec, + _buffer_type=buffer_type, window_ref=None, # to be bound to existing ref by the pipeline routine accum_ref=accum_ref, - current_slot=None, - next_slot=None, + copy_in_slot=None, + wait_in_slot=None, + copy_out_slot=None, + wait_out_slot=None, + _copy_in_slot_reg=None, + _wait_in_slot_reg=None, + _copy_out_slot_reg=None, + _wait_out_slot_reg=None, + next_fetch_smem=None, + next_fetch_sreg=None, sem_recvs=None, sem_sends=None, swap=None, + tiling=None, ) else: - memory_space = SMEM if spec.memory_space == SMEM else VMEM + if use_lookahead and grid_rank is None: + raise ValueError( + "grid_rank must be specified when use_lookahead is True." + ) + + buffer_ty = ty.update(shape=(buffer_count, *block_shape)) return cls( - spec=spec, - dtype=dtype, - buffer_type=buffer_type, - window_ref=memory_space((2,) + block_shape, dtype), + _spec=spec, + _buffer_type=buffer_type, + window_ref=buffer_memory_space.from_type(buffer_ty), accum_ref=accum_ref, - current_slot=SMEM((1,), jnp.int32), - next_slot=None, + copy_in_slot=SMEM((1,), jnp.uint32) if buffer_type.is_input else None, + wait_in_slot=SMEM((1,), jnp.uint32) if buffer_type.is_input else None, + copy_out_slot=SMEM((1,), jnp.uint32) if buffer_type.is_output else None, + wait_out_slot=SMEM((1,), jnp.uint32) if buffer_type.is_output else None, + _copy_in_slot_reg=None, + _wait_in_slot_reg=None, + _copy_out_slot_reg=None, + _wait_out_slot_reg=None, + next_fetch_smem=[SMEM((1,), jnp.int32) for _ in range( + grid_rank)] if use_lookahead else None, + next_fetch_sreg=None, sem_recvs=( None if buffer_type is BufferType.OUTPUT - else SemaphoreType.DMA((2,)) + else SemaphoreType.DMA((buffer_count,)) ), sem_sends=( None if buffer_type is BufferType.INPUT - else SemaphoreType.DMA((2,)) + else SemaphoreType.DMA((buffer_count,)) ), swap=SMEM((1,), jnp.bool) if needs_swap_ref else None, + tiling=tiling, ) @classmethod - def input(cls, spec, dtype, needs_swap_ref=True): - return cls.create(spec, dtype, BufferType.INPUT, needs_swap_ref) + def input(cls, spec, dtype_or_type, buffer_count=2, **kwargs): + return cls.create( + spec, dtype_or_type, BufferType.INPUT, buffer_count, **kwargs + ) @classmethod - def output(cls, spec, dtype, needs_swap_ref=True): - return cls.create(spec, dtype, BufferType.OUTPUT, needs_swap_ref) + def output(cls, spec, dtype_or_type, buffer_count=2, **kwargs): + return cls.create( + spec, dtype_or_type, BufferType.OUTPUT, buffer_count, **kwargs + ) @classmethod - def accumulator(cls, spec, dtype, needs_swap_ref=True): - return cls.create(spec, dtype, BufferType.ACCUMULATOR, needs_swap_ref) + def accumulator(cls, spec, dtype_or_type, buffer_count=2, **kwargs): + return cls.create( + spec, dtype_or_type, BufferType.ACCUMULATOR, buffer_count, **kwargs + ) @classmethod - def input_output(cls, spec, dtype, needs_swap_ref=True): - return cls.create(spec, dtype, BufferType.INPUT_OUTPUT, needs_swap_ref) + def input_output(cls, spec, dtype_or_type, buffer_count=2, **kwargs): + return cls.create( + spec, dtype_or_type, BufferType.INPUT_OUTPUT, buffer_count, **kwargs + ) @property def block_shape(self): @@ -329,162 +702,332 @@ def block_shape(self): def compute_index(self): return self.spec.index_map - @property - def memory_space(self): - return self.spec.memory_space + def with_spec(self, spec: pl.BlockSpec) -> BufferedRef: + """Returns a new BufferedRef with the given block spec.""" + return dataclasses.replace(self, _spec=spec) + + def with_next_fetch( + self, next_fetch: Sequence[jax.Array] | None = None, + ): + return dataclasses.replace(self, next_fetch_sreg=next_fetch) + + def with_slot_index( + self, + copy_in_slot: int | jax.Array | None = None, + copy_out_slot: int | jax.Array | None = None, + wait_in_slot: int | jax.Array | None = None, + wait_out_slot: int | jax.Array | None = None, + ) -> "BufferedRef": + """Returns a new BufferedRef with the given slot index.""" + new_buf = self + if copy_in_slot is not None: + new_buf = dataclasses.replace(new_buf, _copy_in_slot_reg=copy_in_slot) + if copy_out_slot is not None: + new_buf = dataclasses.replace(new_buf, _copy_out_slot_reg=copy_out_slot) + if wait_in_slot is not None: + new_buf = dataclasses.replace(new_buf, _wait_in_slot_reg=wait_in_slot) + if wait_out_slot is not None: + new_buf = dataclasses.replace(new_buf, _wait_out_slot_reg=wait_out_slot) + return new_buf @property def current_ref(self): buffer_slice = tuple( - 0 if x is None else slice(None) for x in self.block_shape) - if self.memory_space == VMEM: + slice(None) + for x in self.block_shape + if not (x is None or isinstance(x, pl.Squeezed)) + ) + assert not (self.window_ref is None or isinstance(self.window_ref, REF)) + if not self.is_buffered: return self.window_ref.at[buffer_slice] else: - return self.window_ref.at[(self.current_slot_index, *buffer_slice)] + if self.is_output: + slot = self.current_copy_out_slot + else: + slot = self.current_wait_in_slot + return self.window_ref.at[(slot, *buffer_slice)] @property - def is_input(self): - return self.buffer_type in [ - BufferType.INPUT, - BufferType.ACCUMULATOR, - BufferType.INPUT_OUTPUT, - ] + def cumulative_copy_in(self): + """The cumulative number of copy_ins issued on this buffer.""" + if self._copy_in_slot_reg is not None: + val = self._copy_in_slot_reg + else: + val = self.copy_in_slot[0] + return val @property - def is_output(self): - return self.buffer_type in [ - BufferType.OUTPUT, - BufferType.ACCUMULATOR, - BufferType.INPUT_OUTPUT, - ] + def current_copy_in_slot(self): + """Index in multiple buffer corresponding to the current slot.""" + return lax.rem(self.cumulative_copy_in, jnp.uint32(self.buffer_count)) @property - def is_accumulator(self): - return self.buffer_type == BufferType.ACCUMULATOR + def cumulative_copy_out(self): + """The cumulative number of copy_outs issued on this buffer.""" + if self._copy_out_slot_reg is not None: + val = self._copy_out_slot_reg + else: + val = self.copy_out_slot[0] + return val @property - def is_input_output(self): - return self.buffer_type == BufferType.INPUT_OUTPUT + def current_copy_out_slot(self): + """Index in multiple buffer corresponding to the current copy slot.""" + return lax.rem(self.cumulative_copy_out, jnp.uint32(self.buffer_count)) + + @property + def cumulative_wait_in(self): + """The cumulative number of wait_ins issued on this buffer.""" + if self._wait_in_slot_reg is not None: + val = self._wait_in_slot_reg + else: + val = self.wait_in_slot[0] + return val + + @property + def current_wait_in_slot(self): + """Index in multiple buffer corresponding to the current wait slot.""" + return lax.rem(self.cumulative_wait_in, jnp.uint32(self.buffer_count)) @property - def current_slot_index(self): - return self.current_slot[0] + def cumulative_wait_out(self): + """The cumulative number of wait_outs issued on this buffer.""" + if self._wait_out_slot_reg is not None: + val = self._wait_out_slot_reg + else: + val = self.wait_out_slot[0] + return val @property - def next_slot_index(self): - return lax.rem(self.current_slot_index + 1, 2) + def current_wait_out_slot(self): + """Index in multiple buffer corresponding to the current wait slot.""" + return lax.rem(self.cumulative_wait_out, jnp.uint32(self.buffer_count)) + + @property + def next_fetch_indices(self): + """Returns the next grid indices to fetch from if using lookahead.""" + if not self.use_lookahead: + raise ValueError("Can only get fetch indices if using lookahead.") + if self.next_fetch_sreg is not None: + return self.next_fetch_sreg + return tuple(smem[0] for smem in self.next_fetch_smem) def bind_existing_ref(self, window_ref, indices): """For handling VMEM references, the pipeline aliases the existing ref.""" - if self.memory_space == VMEM: + if not self.is_buffered: return dataclasses.replace( self, window_ref=window_ref.at[self.compute_slice(indices)] ) return self + def unbind_refs(self): + if not self.is_buffered: + return dataclasses.replace(self, window_ref=None) + return self + def compute_slice(self, grid_indices): """Compute DMA slice from grid indices.""" - block_shape = tuple(1 if x is None else x for x in self.block_shape) indices = self.compute_index(*grid_indices) - return jax.tree.map(_make_ds, indices, block_shape) + assert len(self.block_shape) == len(indices) + indexer = [] + for bd, idx in zip(self.block_shape, indices, strict=True): + match bd: + case None | pl.Squeezed(): + # Dimension is squeezed out so we don't do anything. + indexer.append(idx) + case pl.Element(): + raise ValueError( + "Element block dimensions are not supported." + ) + case pl.BoundedSlice(): + raise ValueError( + "BoundedSlice block dimensions are not supported." + ) + case pl.Blocked(block_size): + indexer.append(_make_block_ds(idx, block_size)) + case int(): + indexer.append(_make_block_ds(idx, bd)) + case _: + raise ValueError(f"Unsupported block dimension type: {type(bd)}") + return tuple(indexer) def init_slots(self): """Initialize slot indices.""" - if self.memory_space == VMEM: return - self.current_slot[0] = 0 + if not self.is_buffered: return + if self.is_input: + self.copy_in_slot[0] = 0 + self.wait_in_slot[0] = 0 + if self.use_lookahead: + for i in range(len(self.next_fetch_smem)): + self.next_fetch_smem[i][0] = 0 + if self.is_output: + self.copy_out_slot[0] = 0 + self.wait_out_slot[0] = 0 if self.swap is not None: self.swap[0] = False - def swap_slots(self): - """Switch to the next slot.""" - if self.memory_space == VMEM: return - self.current_slot[0] = self.next_slot_index - if self.swap is not None: - self.swap[0] = False + def advance_copy_in_slot(self, predicate: bool | jax.Array = True) -> "BufferedRef": + """Switch to the next copy slot.""" + if not self.is_buffered: return self + if not self.is_input: + return self + current_slot = (self.copy_in_slot[0] if # type: ignore[index] + self._copy_in_slot_reg is None else self._copy_in_slot_reg) + new_current_slot = lax.select(predicate, current_slot + 1, current_slot) + if self._copy_in_slot_reg is not None: + return self.with_slot_index(copy_in_slot=new_current_slot) + assert isinstance(self.copy_in_slot, jax.Array) + self.copy_in_slot[0] = new_current_slot + return self - def get_dma_slice(self, src_shape, src_dtype, grid_indices): - # We need to handle blocks that might go OOB in the src array. An in bounds - # block looks like this (for array shape (600, 600) and block shape - # (256, 256)): - # - # +--------------+------------------| - # | Block (0,0) | | - # | (256, 256) | | - # +--------------+ | - # | A (600, 600) | - # | | - # +---------------------------------+ - # - # For in-bounds blocks, we don't need to do anything special. - # An out-of-bounds block looks like this: - # - # +--------------+------------------| - # | | - # | | - # + | - # | A (600, 600) | - # +--------------+ | - # | Block (2,0) | | - # + --------------------------------| - # | XXXXXXXXXX | - # +--------------+ - # where the X's indicate where the block is out of bounds. - # - # When we have an out of bounds block like this, we need to truncate it to - # a tile boundary (tiles are (8, 128) along the two minormost dimensions). - # In this case, we'll have a block that is indexing the - # 512:768 elements of A along the first dimension. We need to convert 768 - # into 600 (600 % 8 == 0), so our indexing will look like this: + def advance_wait_in_slot(self, predicate: bool | jax.Array = True) -> "BufferedRef": + """Switch to the next wait slot.""" + if not self.is_buffered: return self + if not self.is_input: + return self + current_slot = (self.wait_in_slot[0] if # type: ignore[index] + self._wait_in_slot_reg is None else self._wait_in_slot_reg) + new_current_slot = lax.select(predicate, current_slot + 1, current_slot) + if self._wait_in_slot_reg is not None: + return self.with_slot_index(wait_in_slot=new_current_slot) + assert isinstance(self.wait_in_slot, jax.Array) + self.wait_in_slot[0] = new_current_slot + return self - # +--------------+------------------| - # | | - # | | - # + | - # | A (600, 600) | - # +--------------+ | - # | Block (2,0) | | - # + --------------------------------| - # where it is now a (88, 256) sized block. - # - # Suppose A is now (601, 600), instead of picking a (88, 256)-sized block - # for the last iteration on that dimension, we will pick the next highest - # tile multiple, i.e. (96, 256). - if len(src_shape) < 2: - raise NotImplementedError("Must use >1D values.") + def advance_copy_out_slot(self, predicate: bool | jax.Array = True) -> "BufferedRef": + """Switch to the next copy slot.""" + if not self.is_buffered: return self + if not self.is_output: + return self + current_slot = (self.copy_out_slot[0] if self._copy_out_slot_reg # type: ignore[index] + is None else self._copy_out_slot_reg) + new_current_slot = lax.select(predicate, current_slot + 1, current_slot) + if self._copy_out_slot_reg is not None: + return self.with_slot_index(copy_out_slot=new_current_slot) + assert isinstance(self.copy_out_slot, jax.Array) + self.copy_out_slot[0] = new_current_slot + return self - tiling = _make_tiling(src_shape, src_dtype) - block_shape = tuple(1 if b is None else b for b in self.block_shape) - block_indices = self.compute_index(*grid_indices) - return jax.tree.map( - _make_block_slice, block_indices, block_shape, src_shape, tiling + def advance_wait_out_slot(self, predicate: bool | jax.Array = True) -> "BufferedRef": + """Switch to the next wait slot.""" + if not self.is_buffered: return self + if not self.is_output: + return self + current_slot = (self.wait_out_slot[0] if self._wait_out_slot_reg # type: ignore[index] + is None else self._wait_out_slot_reg) + new_current_slot = lax.select(predicate, current_slot + 1, current_slot) + if self._wait_out_slot_reg is not None: + return self.with_slot_index(wait_out_slot=new_current_slot) + assert isinstance(self.wait_out_slot, jax.Array) + self.wait_out_slot[0] = new_current_slot + return self + + def load_slots(self, predicate: bool | jax.Array = True) -> BufferedRef: + """Load slot information into registers.""" + if not self.is_buffered: + return self + def _do_load(): + copy_in = self.copy_in_slot[0] if self.is_input else None + copy_out = self.copy_out_slot[0] if self.is_output else None + wait_in = self.wait_in_slot[0] if self.is_input else None + wait_out = self.wait_out_slot[0] if self.is_output else None + if self.use_lookahead: + next_fetch = tuple(self.next_fetch_smem[i][0] for i in range( + len(self.next_fetch_smem))) + else: + next_fetch = None + return (copy_in, copy_out, wait_in, wait_out, next_fetch) + def _no_load(): + copy_in = copy_out = wait_in = wait_out = None + # Need to make sure that we return a non-none value to make sure + # the pytrees for both branches match. + _ensure_not_none = lambda x: x if x is not None else jnp.uint32(0) + if self.is_input: + copy_in = _ensure_not_none(self._copy_in_slot_reg) + wait_in = _ensure_not_none(self._wait_in_slot_reg) + if self.is_output: + copy_out = _ensure_not_none(self._copy_out_slot_reg) + wait_out = _ensure_not_none(self._wait_out_slot_reg) + if self.use_lookahead: + if self.next_fetch_sreg is None: + next_fetch = tuple(jnp.int32(0) for _ in range( + len(self.next_fetch_smem))) + else: + next_fetch = self.next_fetch_sreg + else: + next_fetch = None + return (copy_in, copy_out, wait_in, wait_out, next_fetch) + (copy_in_slot, copy_out_slot, wait_in_slot, wait_out_slot, + next_fetch) = lax.cond(predicate, _do_load, _no_load) + bref = self.with_slot_index( + copy_in_slot=copy_in_slot, + copy_out_slot=copy_out_slot, + wait_in_slot=wait_in_slot, + wait_out_slot=wait_out_slot, ) + if bref.next_fetch_smem is not None: + bref = bref.with_next_fetch(next_fetch=next_fetch) + return bref + + def save_slots(self, predicate: bool | jax.Array = True): + """Save slot information from registers.""" + if not self.is_buffered: + return + @pl.when(predicate) + def _(): + if self.is_input: + assert self._copy_in_slot_reg is not None + self.copy_in_slot[0] = self._copy_in_slot_reg + assert self._wait_in_slot_reg is not None + self.wait_in_slot[0] = self._wait_in_slot_reg + if self.use_lookahead: + assert self.next_fetch_sreg is not None + for i in range(len(self.next_fetch_smem)): + self.next_fetch_smem[i][0] = self.next_fetch_sreg[i] + if self.is_output: + assert self._copy_out_slot_reg is not None + self.copy_out_slot[0] = self._copy_out_slot_reg + assert self._wait_out_slot_reg is not None + self.wait_out_slot[0] = self._wait_out_slot_reg def copy_in(self, src_ref, grid_indices): """Starts copy of HBM dma slice into the current slot.""" assert self.is_input - if self.memory_space == VMEM: return + if not self.is_buffered: return + assert not (self.window_ref is None or isinstance(self.window_ref, REF)) + assert self.sem_recvs is not None if self.swap is not None: self.swap[0] = True - next_slot = self.next_slot_index - src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices) - dst_slice = tuple(pl.ds(0, s.size) for s in src_slice) + slot = self.current_copy_in_slot + src_slice = self.get_dma_slice(_ref_to_value_aval(src_ref), grid_indices) + dst_slice = tuple( + pl.ds(0, s.size) + for s, bd in zip(src_slice, self.block_shape) + if not (bd is None or isinstance(bd, pl.Squeezed)) + ) tpu_primitives.make_async_copy( src_ref.at[src_slice], - self.window_ref.at[next_slot].at[dst_slice], - self.sem_recvs.at[next_slot], + self.window_ref.at[(slot, *dst_slice)], + self.sem_recvs.at[slot], ).start() def copy_out(self, dst_ref, grid_indices): """Starts copy of HBM dma slice from the current slot.""" assert self.is_output - if self.memory_space == VMEM: return + if not self.is_buffered: return + assert not (self.window_ref is None or isinstance(self.window_ref, REF)) + assert self.sem_sends is not None if self.swap is not None: self.swap[0] = True - slot = self.current_slot_index - dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices) - src_slice = tuple(pl.ds(0, s.size) for s in dst_slice) + slot = self.current_copy_out_slot + dst_slice = self.get_dma_slice(_ref_to_value_aval(dst_ref), grid_indices) + src_slice = tuple( + pl.ds(0, s.size) + for s, bd in zip(dst_slice, self.block_shape) + if not (bd is None or isinstance(bd, pl.Squeezed)) + ) tpu_primitives.make_async_copy( - self.window_ref.at[slot].at[src_slice], + self.window_ref.at[(slot, *src_slice)], dst_ref.at[dst_slice], self.sem_sends.at[slot], ).start() @@ -492,30 +1035,41 @@ def copy_out(self, dst_ref, grid_indices): def wait_in(self, src_ref, grid_indices): """Waits for input copy to finish.""" assert self.is_input - if self.memory_space == VMEM: return - src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices) - dst_slice = tuple(pl.ds(0, s.size) for s in src_slice) - current_slot = self.current_slot_index + if not self.is_buffered: return + assert not (self.window_ref is None or isinstance(self.window_ref, REF)) + assert self.sem_recvs is not None + src_slice = self.get_dma_slice(_ref_to_value_aval(src_ref), grid_indices) + dst_slice = tuple( + pl.ds(0, s.size) + for s, bd in zip(src_slice, self.block_shape) + if not (bd is None or isinstance(bd, pl.Squeezed)) + ) + wait_slot = self.current_wait_in_slot tpu_primitives.make_async_copy( src_ref.at[src_slice], # nb: doesn't matter - self.window_ref.at[current_slot].at[ - dst_slice + self.window_ref.at[ + (wait_slot, *dst_slice) ], # only dst shape is important - self.sem_recvs.at[current_slot], + self.sem_recvs.at[wait_slot], ).wait() def wait_out(self, dst_ref, grid_indices): """Waits for output copy to finish.""" assert self.is_output - if self.memory_space == VMEM: return - # In a double buffer, previous slot is the same as next slot. - prev_slot = self.next_slot_index - dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices) - src_slice = tuple(pl.ds(0, s.size) for s in dst_slice) + if not self.is_buffered: return + assert not (self.window_ref is None or isinstance(self.window_ref, REF)) + assert self.sem_sends is not None + wait_slot = self.current_wait_out_slot + dst_slice = self.get_dma_slice(_ref_to_value_aval(dst_ref), grid_indices) + src_slice = tuple( + pl.ds(0, s.size) + for s, bd in zip(dst_slice, self.block_shape) + if not (bd is None or isinstance(bd, pl.Squeezed)) + ) tpu_primitives.make_async_copy( - self.window_ref.at[prev_slot].at[src_slice], # nb: doesn't matter + self.window_ref.at[(wait_slot, *src_slice)], # nb: doesn't matter dst_ref.at[dst_slice], # only dst shape is important - self.sem_sends.at[prev_slot], + self.sem_sends.at[wait_slot], ).wait() # Accumulator methods @@ -533,16 +1087,18 @@ def set_accumulator(self, init=False): """Set accumulator or zero it out to initialize.""" assert self.is_accumulator if self.accum_ref is not None: + accum_dtype = self.accum_ref.dtype def _init(): self.accum_ref[...] = jnp.zeros_like(self.accum_ref[...]) def _set(): - self.accum_ref[...] = self.current_ref[...].astype(self.accum_ref.dtype) + self.accum_ref[...] = self.current_ref[...].astype(accum_dtype) lax.cond(init, _init, _set) def accumulate(self): """Add into the current slot.""" assert self.is_accumulator if self.accum_ref is not None: + assert self.window_ref is not None accum_dtype = jnp.float32 if self.window_ref.dtype == jnp.int32: accum_dtype = jnp.int32 @@ -554,10 +1110,100 @@ def accumulate(self): ).astype(self.window_ref.dtype) +def fetch_with_lookahead(buffered_ref, src_ref, + grid, + grid_offsets, + predicate: jax.Array | bool = True, + max_num_fetches: int | None = None, + update_slots: bool = True): + """Fetch future blocks using unbounded lookahead. + + Args: + buffered_ref: the BufferedRef to fetch for. + src_ref: the source Ref. + grid: the grid bounds. + grid_offsets: the grid offsets (used for megacore). + predicate: a boolean predicate for whether to perform the fetch. + max_num_fetches: the maximum number of fetches to perform. If None, + this will continually fetch until all copy_in slots are full. + update_slots: whether to update the register slot indices. + """ + assert buffered_ref.use_lookahead + add_offset = lambda x: tuple( + i + j for i, j in zip(x, grid_offsets, strict=True)) + index_inbound = lambda x: _tuple_lt(x, grid) + increment_indices = lambda x: _next_index(x, grid, allow_overflow=True) + def as_uint32(x): + if isinstance(x, bool): + return jnp.uint32(x) + else: + return x.astype(jnp.uint32) + + fetch_limit = buffered_ref.cumulative_wait_in + buffered_ref.buffer_count + if max_num_fetches is not None: + fetch_once_limit = buffered_ref.cumulative_copy_in + max_num_fetches + # We would like to write jnp.minimum(fetch_limit, fetch_once_limit) + # but this does not compile in Mosaic. + fetch_limit = lax.select(fetch_limit < fetch_once_limit, + fetch_limit, fetch_once_limit) + + + def _loop_cond(carry): + _, next_indices, cumulative_copy_in = carry + # Don't fetch more blocks than we have buffers. + within_limit = cumulative_copy_in < fetch_limit + # Don't fetch past the end of the grid. + in_bounds = index_inbound(next_indices) + return predicate & within_limit & in_bounds + + def _loop_body(carry): + current_indices, next_indices, cumulative_copy_in = carry + cur_indices_offset = add_offset(current_indices) + next_indices_offset = add_offset(next_indices) + block_indices = buffered_ref.compute_index(*cur_indices_offset) + next_block_indices = buffered_ref.compute_index(*next_indices_offset) + will_change = _tuples_differ(block_indices, next_block_indices) + pred = will_change + bref = buffered_ref.with_slot_index(copy_in_slot=cumulative_copy_in) + @pl.when(pred) + def _start(): + bref.copy_in(src_ref, next_indices_offset) # pylint: disable=cell-var-from-loop + next_copy_in = cumulative_copy_in + as_uint32(pred) + next_next_indices = increment_indices(next_indices) + return next_indices, next_next_indices, next_copy_in + current_indices = buffered_ref.next_fetch_indices + next_fetch = increment_indices(current_indices) + final_indices, _, final_copy_in_slot = lax.while_loop( + _loop_cond, _loop_body, + (current_indices, next_fetch, buffered_ref.cumulative_copy_in)) + + buffered_ref = buffered_ref.with_next_fetch(final_indices) + if update_slots: + buffered_ref = buffered_ref.with_slot_index(copy_in_slot=final_copy_in_slot) + return buffered_ref, final_copy_in_slot + + # Helper to tree map over BufferedRefs as leaves. map_brefs = functools.partial( jax.tree.map, - is_leaf=lambda x: isinstance(x, BufferedRef)) + is_leaf=lambda x: isinstance(x, BufferedRefBase) +) + +def map_inputs(f, *args): + """Maps over all input BufferedRefs.""" + def fmap(bref, *f_args): + if bref.is_input: + return f(bref, *f_args) + return bref + return map_brefs(fmap, *args) + +def map_outputs(f, *args): + """Maps over all output BufferedRefs.""" + def fmap(bref, *f_args): + if bref.is_output: + return f(bref, *f_args) + return bref + return map_brefs(fmap, *args) def _filter_indices( @@ -570,15 +1216,36 @@ def _filter_indices( def _next_index( - indices: tuple[int | jax.Array, ...], grid: tuple[int | jax.Array, ...] + indices: tuple[int | jax.Array, ...], grid: tuple[int | jax.Array, ...], + allow_overflow: bool = False, ) -> tuple[int | jax.Array, ...]: + """Increments the grid indices by one. + + Args: + indices: the current grid indices. + grid: the pallas grid. + allow_overflow: whether to allow the indices to overflow the grid. + If False (default), indices will wrap around to zero after reaching the + maximum grid size. If True, the bounds on the first grid position + will be ignored. + + Returns: + The next grid indices. + """ out = [] carry: bool | jax.Array = True - for i, g in reversed(list(zip(indices, grid, strict=True))): + for position, (i, g) in enumerate( + reversed(list(zip(indices, grid, strict=True)))): inc = jax.lax.select(carry, i + 1, i) - carry = inc == g + if allow_overflow and (position == len(grid) - 1): + carry = False + else: + carry = inc == g out.append(jax.lax.select(carry, 0, inc)) - return _filter_indices(tuple(reversed(out)), grid) + if allow_overflow: + return tuple(reversed(out)) + else: + return _filter_indices(tuple(reversed(out)), grid) def _prev_index( @@ -602,6 +1269,7 @@ def __init__( indices: tuple[int | jax.Array, ...], grid: tuple[int | jax.Array, ...], grid_offsets: tuple[int | jax.Array, ...], + num_stages: int, first_cycle=None, last_cycle=None, init_accumulators=None, @@ -614,6 +1282,7 @@ def __init__( indices: current grid indices. grid: pallas grid for BufferedRefs. grid_offsets: offsets for grid indices (used for megacore). + num_stages: number of stages in the pipeline. first_cycle: whether this is the first invocation of the pipeline. last_cycle: whether this is the last invocation of the pipeline. init_accumulators: do we zero-initialize accumulator state for this @@ -622,6 +1291,8 @@ def __init__( """ self.step = step self.grid = grid + self.grid_offsets = grid_offsets + self.num_stages = num_stages self.first_cycle = first_cycle self.last_cycle = last_cycle self.init_accumulators = init_accumulators @@ -647,10 +1318,25 @@ def __init__( i + j for i, j in zip(_prev_index(indices, grid), grid_offsets, strict=True) ) + next_indices = _next_index(indices, grid) self.next_indices = tuple( i + j - for i, j in zip(_next_index(indices, grid), grid_offsets, strict=True) + for i, j in zip(next_indices, grid_offsets, strict=True) ) + self.add_offset = lambda x: tuple(i + j for i, j in zip(x, grid_offsets, + strict=True)) + # TODO(justinfu): Don't recompute these on each iteration. + # fetch_indices stores the grid indices indexed by the amount of lookahead. + # i.e. fetch_indices[2] contains the grid indices 2 iterations + # ahead. + self.fetch_indices = [self.indices, self.next_indices] + fetch_indices = next_indices + for _ in range(self.num_stages-1): + fetch_indices = _next_index(fetch_indices, grid) + self.fetch_indices.append(tuple( + i + j + for i, j in zip(fetch_indices, grid_offsets, strict=True) + )) @contextmanager def _named_scope(self, name): @@ -664,41 +1350,106 @@ def grid_env(self): return pallas_core.grid_env( list(map(pallas_core.GridAxis, self.indices, self.grid))) + def out_of_fetch(self, buffered_ref): + """Returns whether there are no more blocks to fetch.""" + # Currently this is based on the iteration, but if we want to support + # lookahead this will depend on whether the lookahead reached the end. + if not buffered_ref.is_buffered: + return jnp.bool(False) + return self.step >= (self.num_steps - buffered_ref.buffer_count + 1) + def has_changed(self, buffered_ref): + if not buffered_ref.is_buffered: + return False indices = buffered_ref.compute_index(*self.indices) prev_indices = buffered_ref.compute_index(*self.prev_indices) return _tuples_differ(indices, prev_indices) - def will_change(self, buffered_ref): + def will_change_current(self, buffered_ref): + if not buffered_ref.is_buffered: + return False indices = buffered_ref.compute_index(*self.indices) next_indices = buffered_ref.compute_index(*self.next_indices) return _tuples_differ(indices, next_indices) + def will_change_fetch(self, buffered_ref): + if not buffered_ref.is_buffered: + return False + if buffered_ref.buffer_count < 2: + raise NotImplementedError() + indices = buffered_ref.compute_index( + *self.fetch_indices[buffered_ref.buffer_count-2]) + next_indices = buffered_ref.compute_index( + *self.fetch_indices[buffered_ref.buffer_count-1]) + return _tuples_differ(indices, next_indices) + def alias_local_refs(self, buffered_ref, ref): return buffered_ref.bind_existing_ref(ref, self.indices) + def unalias_local_refs(self, buffered_ref): + return buffered_ref.unbind_refs() + # SCHEDULE ---------------------------------------------------------------- # Below is the sequence of conditional waits and copies used for inputs, # outputs, and in-out accumulators. - def initialize(self, buffered_ref, src_ref, schedule=None): + def initialize_step(self, buffered_ref, src_ref, schedule=None, step=0): if schedule is None: schedule = _default_schedule - pred = schedule["prologue_copy_in"](self, buffered_ref, src_ref) - - with self._named_scope("ep_initialize"): - @pl.when(self.first_step_ever) - def _init_slots(): - buffered_ref.init_slots() - - @pl.when(pred) - def _start(): - if buffered_ref.is_input: - buffered_ref.copy_in(src_ref, self.indices) - buffered_ref.swap_slots() - - def wait_in(self, buffered_ref, src_ref, schedule=None): + # TODO(justinfu): Should cache this, but it doesn't actually do computation + # in both default & fixed schedules right now so it doesn't increase + # the Jaxpr size. + do_copy = schedule["prologue_copy_in"](self, buffered_ref, src_ref) + + with self._named_scope(f"ep_initialize_{step}"): + if step == 0: + @pl.when(self.first_step_ever) + def _init_slots(): + buffered_ref.init_slots() + buffered_ref = buffered_ref.load_slots() + + if not buffered_ref.is_input or not buffered_ref.is_buffered: + return buffered_ref + + if (step + 1) >= buffered_ref.buffer_count: + return buffered_ref + + if buffered_ref.use_lookahead: + if step == 0: + # We always fetch the first block. + @pl.when(do_copy) + def _start(): + buffered_ref.copy_in(src_ref, + self.add_offset(buffered_ref.next_fetch_indices)) # pylint: disable=cell-var-from-loop + buffered_ref = buffered_ref.advance_copy_in_slot(do_copy) + else: + buffered_ref, _ = fetch_with_lookahead( + buffered_ref, + src_ref, + self.grid, + self.grid_offsets, + predicate=self.first_step_ever & do_copy, + max_num_fetches=1, + ) + else: + if step == 0: + predicate = do_copy + fetch_indices = self.fetch_indices[step] + else: + fetch_indices = self.fetch_indices[step] + prev_grid_indices = self.fetch_indices[step - 1] + block_indices = buffered_ref.compute_index(*fetch_indices) + prev_block_indices = buffered_ref.compute_index(*prev_grid_indices) + block_changed = _tuples_differ(block_indices, prev_block_indices) + predicate = do_copy & block_changed + @pl.when(predicate) # pylint: disable=cell-var-from-loop + def _start(): + buffered_ref.copy_in(src_ref, fetch_indices) # pylint: disable=cell-var-from-loop + buffered_ref = buffered_ref.advance_copy_in_slot(predicate) + return buffered_ref + + def wait_in(self, buffered_ref, src_ref, schedule=None) -> "BufferedRef": if schedule is None: schedule = _default_schedule pred = schedule["wait_in"](self, buffered_ref, src_ref) @@ -722,20 +1473,29 @@ def _set_accumulator(): # so this is usually just setting the accumulator to 0. buffered_ref.set_accumulator(self.init_accumulators) lax.cond(pred, _wait, _no_wait) + return buffered_ref - def copy_in(self, buffered_ref, src_ref, schedule=None): + def copy_in(self, buffered_ref, src_ref, schedule=None) -> "BufferedRef": if schedule is None: schedule = _default_schedule pred = schedule['copy_in'](self, buffered_ref, src_ref) + if not buffered_ref.is_input: + return buffered_ref - @pl.when(pred) - @self._named_scope("ep_copy_in") - def _send(): - if buffered_ref.is_input: - # We skip the last step because that's what prefetch is for. - @pl.when(~self.last_step) - def _copy_in(): - buffered_ref.copy_in(src_ref, self.next_indices) + if buffered_ref.use_lookahead: + buffered_ref, _ = fetch_with_lookahead( + buffered_ref, src_ref, self.grid, self.grid_offsets, predicate=True + ) + else: + @pl.when(pred) + @self._named_scope("ep_copy_in") + def _send(): + if buffered_ref.is_input and buffered_ref.is_buffered: + buffered_ref.copy_in(src_ref, + self.fetch_indices[buffered_ref.buffer_count-1]) + buffered_ref = buffered_ref.advance_copy_in_slot( + pred & buffered_ref.is_input) + return buffered_ref # --> Call prefetch here to grab the first inputs of next cycle. @@ -745,16 +1505,53 @@ def prefetch(self, buffered_ref, src_ref, schedule=None): schedule = _default_schedule pred = schedule['prefetch'](self, buffered_ref, src_ref) - @pl.when(pred) - @self._named_scope("ep_prefetch") - def _send(): - if buffered_ref.is_input: - # Prefetch should only run on the last step. - @pl.when(self.last_step) - def _prefetch_in(): - buffered_ref.copy_in(src_ref, self.next_indices) + if not buffered_ref.is_input or not buffered_ref.is_buffered: + return - def wait_out(self, buffered_ref, dst_ref, schedule=None): + if buffered_ref.use_lookahead: + buffered_ref = buffered_ref.with_next_fetch( + jax.tree.map(jnp.zeros_like, buffered_ref.next_fetch_sreg)) + @pl.when(pred) + def _start(): + buffered_ref.copy_in( + src_ref, self.add_offset(buffered_ref.next_fetch_sreg)) # pylint: disable=cell-var-from-loop + buffered_ref = buffered_ref.advance_copy_in_slot(pred) + + buffered_ref, final_copy_in_slot = fetch_with_lookahead( + buffered_ref, + src_ref, + self.grid, + self.grid_offsets, + predicate=pred, + update_slots=False, + ) + @pl.when(pred) + def _(): + bref = buffered_ref.with_slot_index(copy_in_slot=final_copy_in_slot) + bref.save_slots() + else: + pred = pred & self.last_step + grid_indices = self.indices + for i in range(buffered_ref.buffer_count - 1): + next_grid_indices = self.fetch_indices[i+1] + block_indices = buffered_ref.compute_index(*grid_indices) + next_block_indices = buffered_ref.compute_index(*next_grid_indices) + if i == 0: + # If the prefetch predicate triggers, we already know that the + # first block needs to be copied. + should_prefetch = True + else: + should_prefetch = _tuples_differ(block_indices, next_block_indices) + + @pl.when(pred & should_prefetch) + def _(): + buffered_ref.copy_in(src_ref, next_grid_indices) # pylint: disable=cell-var-from-loop + buffered_ref = buffered_ref.advance_copy_in_slot(pred & should_prefetch) + grid_indices = next_grid_indices + buffered_ref.save_slots() + return + + def wait_out(self, buffered_ref, dst_ref, schedule=None) -> "BufferedRef": if schedule is None: schedule = _default_schedule pred = schedule['wait_out'](self, buffered_ref, dst_ref) @@ -763,12 +1560,19 @@ def wait_out(self, buffered_ref, dst_ref, schedule=None): @self._named_scope("ep_wait_out") def _wait(): if buffered_ref.is_output: + # Note: As implemented, the current scheduler cannot support multiple + # buffering on outputs. In order to do so properly, we need to save + # the indices for which the copy_out was issued, and wait on them + # here. In the current schedule we always immediately wait_out + # on the iteration after the copy_out, so the prev_indices is always + # the correct grid index to wait on. buffered_ref.wait_out(dst_ref, self.prev_indices) + return buffered_ref.advance_wait_out_slot(pred & buffered_ref.is_output) # --> Call "postyeet" here, after last output copy is finished from previous # cycle - def copy_out(self, buffered_ref, dst_ref, schedule=None): + def copy_out(self, buffered_ref, dst_ref, schedule=None) -> "BufferedRef": if schedule is None: schedule = _default_schedule pred = schedule['copy_out'](self, buffered_ref, dst_ref) @@ -792,6 +1596,7 @@ def _just_accumulate(): def _accumulate(): buffered_ref.accumulate() lax.cond(pred, _copy_out_and_accumulate, _just_accumulate) + return buffered_ref.advance_copy_out_slot(pred & buffered_ref.is_output) def finalize(self, buffered_ref, dst_ref, schedule=None): if schedule is None: @@ -804,29 +1609,18 @@ def _end(): if buffered_ref.is_output: buffered_ref.wait_out(dst_ref, self.indices) - def swap_slots(self, buffered_ref, hbm_ref, schedule=None): - if buffered_ref.swap is not None: - swap = buffered_ref.swap[0] - else: - # If we are not using an SMEM `swap` tensor to keep track of - # swaps needed, then all the copies into and out of BufferedRefs - # are done by direct calls to the `copy_in` and `copy_out` - # methods in the pipeline loop. To determine if the BufferedRef - # needs a swap of slots, we recalculate the copy-in/copy-out - # conditions. - if schedule is None: - schedule = _default_schedule - pred_in = schedule["copy_in"](self, buffered_ref, hbm_ref) - pred_out = schedule["copy_out"](self, buffered_ref, hbm_ref) - - copied_in = pred_in & buffered_ref.is_input & ~self.last_step - copied_out = pred_out & buffered_ref.is_output - swap = copied_in | copied_out - - @pl.when(swap) - @self._named_scope("ep_swap") - def _swap(): - buffered_ref.swap_slots() + buffered_ref.save_slots() + + def advance_slots(self, buffered_ref, schedule=None): + if schedule is None: + schedule = _default_schedule + + if buffered_ref.is_input: + pred = schedule['advance_wait_in'](self, buffered_ref, schedule) + buffered_ref = buffered_ref.advance_wait_in_slot(pred) + # Currently we advance copy_in and output slots after their respective + # operation. + return buffered_ref # END SCHEDULE -------------------------------------------------------------- @@ -846,17 +1640,18 @@ def _swap(): prologue_copy_in=lambda s, bref, _: s.first_step_ever, # We assume that the source ref changed for prefetch. wait_in=lambda s, bref, _: s.has_changed(bref) | s.first_step, - copy_in=lambda s, bref, _: s.will_change(bref) & ~s.last_step_ever, + advance_wait_in=lambda s, bref, _: ( + s.will_change_current(bref) | s.last_step), + copy_in=lambda s, bref, _: s.will_change_fetch(bref) & ~s.out_of_fetch( + bref), # We assume that the source ref changed. E.g. because of a CM DMA. prefetch=lambda s, bref, _: ( - (s.will_change(bref) | s.last_step) & ~s.last_step_ever + (s.will_change_fetch(bref) | s.last_step) & ~s.last_step_ever ), # We assume that the target ref changed. E.g. because of a CM DMA. - wait_out=lambda s, bref, _: ( - (s.has_changed(bref) | s.first_step) & ~s.first_step_ever - ), + wait_out=lambda s, bref, _: (s.has_changed(bref) | s.first_step) & ~s.first_step_ever, # We assume that the target ref is changing. E.g. because of a CM DMA. - copy_out=lambda s, bref, _: s.will_change(bref) | s.last_step, + copy_out=lambda s, bref, _: s.will_change_current(bref) | s.last_step, epilogue_wait_out=lambda s, bref, _: s.last_step_ever, ) @@ -869,13 +1664,15 @@ def _swap(): prologue_copy_in=lambda s, bref, _: s.first_step_ever, # We don't assume that the source ref changed for prefetch. wait_in=lambda s, bref, _: s.has_changed(bref) | s.first_step_ever, - copy_in=lambda s, bref, _: s.will_change(bref) & ~s.last_step_ever, + advance_wait_in=lambda s, bref, _: s.will_change_current(bref), + copy_in=lambda s, bref, _: s.will_change_fetch(bref) & ~s.out_of_fetch( + bref), # We don't assume that the source ref changed. - prefetch=lambda s, bref, _: s.will_change(bref) & ~s.last_step_ever, + prefetch=lambda s, bref, _: s.will_change_fetch(bref) & ~s.last_step_ever, # We don't assume that the target ref changed. - wait_out=lambda s, bref, _: s.has_changed(bref) & ~s.first_step_ever, + wait_out=lambda s, bref, _: (s.has_changed(bref) & ~s.first_step_ever), # We don't assume that the target ref is changing. - copy_out=lambda s, bref, _: s.will_change(bref) | s.last_step_ever, + copy_out=lambda s, bref, _: s.will_change_current(bref) | s.last_step_ever, epilogue_wait_out=lambda s, bref, _: s.last_step_ever, ) @@ -888,7 +1685,7 @@ def skip_input_copies_when_init_accumulators(schedule) -> Any: def new_pred(original_pred_fn, *a): pred = original_pred_fn(*a) if a[1].is_accumulator or a[1].is_input_output: - pred &= ~a[0].init_accumulators + pred &= jnp.logical_not(a[0].init_accumulators) return pred new_schedule[k] = functools.partial( @@ -917,10 +1714,12 @@ def get_pipeline_schedule(schedule) -> Any: def make_pipeline_allocations( *refs, - in_specs=None, - out_specs=None, + in_specs=(), + out_specs=(), + tiling: Tiling | None = None, should_accumulate_out=False, needs_swap_ref=True, + grid=None, ): """Create BufferedRefs for the pipeline. @@ -934,6 +1733,7 @@ def make_pipeline_allocations( should_accumulate_out: booleans to indicate which outputs should be treated as accumulators. needs_swap_ref: whether a swap slots tracker needs to be allocated. + grid: grid to use for the pipeline. Returns: A list of BufferedRefs, one corresponding to each ref specified in the @@ -952,12 +1752,52 @@ def make_pipeline_allocations( in_refs = refs[:num_in_specs] out_refs = refs[num_in_specs:] def make_input_bref(in_spec, in_ref): - return BufferedRef.input(in_spec, in_ref.dtype, needs_swap_ref) + buffer_count = 2 + use_lookahead = False + if in_spec.pipeline_mode is not None: + buffer_count = in_spec.pipeline_mode.buffer_count + use_lookahead = in_spec.pipeline_mode.use_lookahead + if use_lookahead and grid is None: + raise ValueError("Grid must be specified when using lookahead.") + + in_aval = _ref_to_value_aval(in_ref) + return BufferedRef.input( + in_spec, + in_aval, + buffer_count, + needs_swap_ref=needs_swap_ref, + grid_rank=len(grid), + use_lookahead=use_lookahead, + source_memory_space=in_ref.memory_space, + tiling=tiling, + ) in_brefs = jax.tree.map(make_input_bref, in_specs, in_refs) def make_output_bref(out_spec, out_ref, accumulate): + buffer_count = 2 + if out_spec.pipeline_mode is not None: + buffer_count = out_spec.pipeline_mode.buffer_count + if out_spec.pipeline_mode.use_lookahead: + raise ValueError("Output buffering does not support lookahead.") + + out_aval = _ref_to_value_aval(out_ref) + if accumulate: - return BufferedRef.accumulator(out_spec, out_ref.dtype, needs_swap_ref) - return BufferedRef.output(out_spec, out_ref.dtype, needs_swap_ref) + return BufferedRef.accumulator( + out_spec, + out_aval, + buffer_count, + needs_swap_ref=needs_swap_ref, + source_memory_space=out_ref.memory_space, + tiling=tiling, + ) + return BufferedRef.output( + out_spec, + out_aval, + buffer_count, + needs_swap_ref=needs_swap_ref, + source_memory_space=out_ref.memory_space, + tiling=tiling, + ) out_brefs = jax.tree.map( make_output_bref, out_specs, out_refs, should_accumulate_out) return (*in_brefs, *out_brefs) @@ -975,7 +1815,7 @@ def _partition_grid( num_cores = pl.num_programs(core_axis) core_id = pl.program_id(core_axis) else: - num_cores = jax.lax.psum(1, core_axis) + num_cores = jax.lax.axis_size(core_axis) core_id = jax.lax.axis_index(core_axis) # Check that num_cores is statically known if not isinstance(num_cores, int): @@ -1054,20 +1894,52 @@ def _partition_grid( offsets = jax_util.tuple_update( (0,) * len(grid), partition_dimension, grid_offset ) - return new_grid, offsets + return new_grid, offsets # type: ignore[return-value] + + +def sync_copy(src: REF | BufferedRef, dst: REF | BufferedRef, indices): + """Perform a synchronous copy from src to dst.""" + bref: BufferedRef + hbm_ref: REF + if isinstance(src, BufferedRef): + bref = src + if isinstance(dst, BufferedRef): + raise ValueError("Only one of src or dst can be a BufferedRef.") + hbm_ref = dst + copy_in = False + else: + if not isinstance(dst, BufferedRef): + raise ValueError("One of src or dst must be a BufferedRef.") + bref = dst + hbm_ref = src + copy_in = True + hbm_slice = bref.get_dma_slice(_ref_to_value_aval(hbm_ref), indices) + bref_slice = tuple( + pl.ds(0, s.size) + for s, bd in zip(hbm_slice, bref.block_shape) + if not (bd is None or isinstance(bd, pl.Squeezed)) + ) + if copy_in: + tpu_helpers.sync_copy(hbm_ref.at[hbm_slice], + bref.current_ref.at[bref_slice]) # type: ignore[union-attr] + else: + tpu_helpers.sync_copy(bref.current_ref.at[bref_slice], # type: ignore[union-attr] + hbm_ref.at[hbm_slice]) def emit_pipeline( body, *, grid: tuple[int | jax.Array, ...], - in_specs=None, - out_specs=None, + in_specs=(), + out_specs=(), + tiling: Tiling | None = None, should_accumulate_out: bool = False, core_axis: int | None = None, core_axis_name: str | None = None, dimension_semantics: tuple[GridDimensionSemantics, ...] | None = None, trace_scopes: bool = True, + no_pipelining: bool = False, ): """Creates a function to emit a manual pallas pipeline. @@ -1084,6 +1956,7 @@ def emit_pipeline( grid: a pallas grid definition. in_specs: input pallas block specs out_specs: output pallas block specs + tiling: optional tiling to assume for the refs. should_accumulate_out: booleans to indicate which outputs should be treated as accumulators. core_axis: optional int, indicates whether or not to partition the grid @@ -1094,6 +1967,8 @@ def emit_pipeline( or ARBITRARY). trace_scopes: optional bool, indicates whether to annotate each region in the pipeline using named_scope. + no_pipelining: If True, turns off pipelining and all copies will be made + synchronous. This is useful for debugging multiple-buffering related bugs. """ if any(not isinstance(d, (int, jax.Array)) for d in grid): grid_types = tuple(type(d) for d in grid) @@ -1115,6 +1990,10 @@ def emit_pipeline( if isinstance(out_specs, list): out_specs = tuple(out_specs) should_accumulate_out = _broadcast_pytree_to(should_accumulate_out, out_specs) + get_buffer_count = lambda spec: (spec.pipeline_mode.buffer_count if + (spec is not None and spec.pipeline_mode is not None) else 2) + flattened_specs = jax.tree.leaves((in_specs, out_specs)) + max_buffer_count = max((2, *map(get_buffer_count, flattened_specs))) def pipeline( *refs: Any, @@ -1182,6 +2061,8 @@ def pipeline( out_specs=out_specs, should_accumulate_out=should_accumulate_out, needs_swap_ref=needs_swap_ref, + grid=grid, + tiling=tiling, ), ) if isinstance(allocations, list): @@ -1202,29 +2083,33 @@ def make_scheduler(step, indices): indices, grid, grid_offsets=grid_offsets, + num_stages=max_buffer_count, first_cycle=first_cycle, last_cycle=last_cycle, init_accumulators=init_accumulators, trace_scopes=trace_scopes, ) - def loop_body(step, indices): + def loop_body(step, carry): + unaliased_brefs, indices = carry scheduler = make_scheduler(step, indices) with scheduler.grid_env(): # prepare any local VMEM aliases - brefs = map_brefs(scheduler.alias_local_refs, allocations, refs) - + brefs = map_brefs(scheduler.alias_local_refs, unaliased_brefs, refs) # loop input handling phase - map_brefs(scheduler.copy_in, brefs, refs, schedule) - map_brefs(scheduler.wait_in, brefs, refs, schedule) + brefs = map_brefs(scheduler.copy_in, brefs, refs, schedule) + brefs = map_brefs(scheduler.wait_in, brefs, refs, schedule) # prefetch inputs for the *next* invocation of this pipeline with scheduler._named_scope("ep_prefetch"): if prefetch is not None: - lax.cond(step == num_steps - 1, + do_prefetch = step == num_steps - 1 + map_brefs(lambda x: x.save_slots(do_prefetch), brefs) + lax.cond(do_prefetch, lambda: prefetch(*brefs, scheduler), lambda: None) + brefs = map_brefs(lambda x: x.load_slots(do_prefetch), brefs) # run the kernel! if body_prologue is not None: @@ -1234,34 +2119,83 @@ def loop_body(step, indices): body(*current_refs, *scratches) # loop output handling phase - map_brefs(scheduler.copy_out, brefs, refs, schedule) - map_brefs(scheduler.wait_out, brefs, refs, schedule) + brefs = map_brefs(scheduler.copy_out, brefs, refs, schedule) + brefs = map_brefs(scheduler.wait_out, brefs, refs, schedule) # handle writes for the *last* invocation of this pipeline's outputs with scheduler._named_scope("ep_postyeet"): if postyeet is not None: - lax.cond(step == 0, + do_postyeet = step == 0 + map_brefs(lambda x: x.save_slots(do_postyeet), brefs) + lax.cond(do_postyeet, lambda: postyeet(*brefs, scheduler), lambda: None) + brefs = map_brefs(lambda x: x.load_slots(do_postyeet), brefs) - map_brefs(scheduler.swap_slots, brefs, refs, schedule) - return _next_index(indices, grid) + brefs = map_brefs(scheduler.advance_slots, brefs, schedule) + # Unbind window_refs for VMEM-backed buffers. Without this + # we will be returning TransformedRefs which are not valid + # JAX types. + brefs = map_brefs(scheduler.unalias_local_refs, brefs) + return brefs, _next_index(indices, grid) - @pl.when(num_steps > 0) - def _(): - # pipeline prologue + if no_pipelining: + # Debugging mode where all copies are synchronous. initial_indices = (0,) * len(grid) scheduler = make_scheduler(0, initial_indices) brefs = map_brefs(scheduler.alias_local_refs, allocations, refs) - map_brefs(scheduler.initialize, brefs, refs, schedule) - - # pipeline loop - next_indices = lax.fori_loop(0, num_steps, loop_body, initial_indices) - - # pipeline epilogue - final_indices = _prev_index(next_indices, grid) - scheduler = make_scheduler(num_steps - 1, final_indices) - brefs = map_brefs(scheduler.alias_local_refs, allocations, refs) - map_brefs(scheduler.finalize, brefs, refs, schedule) + map_brefs(lambda bref: bref.init_slots(), brefs) + if postyeet is not None or prefetch is not None: + raise NotImplementedError("Prefetch/Postyeet not supported") + if any(bref.is_accumulator for bref in brefs): + raise NotImplementedError("Accumulators not supported") + @functools.partial(jax.lax.fori_loop, 0, num_steps, + init_val=(brefs, initial_indices)) + def _loop_body(step, carry): + brefs, indices = carry + scheduler = make_scheduler(step, indices) + with scheduler.grid_env(): + # prepare any local VMEM aliases + brefs = map_brefs(scheduler.alias_local_refs, brefs, refs) + # loop input handling phase + copy_in = lambda bref, ref: sync_copy(ref, bref, indices) + map_inputs(copy_in, brefs, refs) + # run the kernel! + if body_prologue is not None: + body_prologue() + current_refs = map_brefs(lambda x: x.current_ref, brefs) + with scheduler._named_scope("ep_run_kernel"): + body(*current_refs, *scratches) + # loop output handling phase + copy_out = lambda bref, ref: sync_copy(bref, ref, indices) + map_outputs(copy_out, brefs, refs) + brefs = map_brefs(scheduler.unalias_local_refs, brefs) + return brefs, _next_index(indices, grid) + else: + @pl.when(num_steps > 0) + def _(): + # pipeline prologue + initial_indices = (0,) * len(grid) + scheduler = make_scheduler(0, initial_indices) + brefs = allocations + with scheduler.grid_env(): + # We issue num_stages-1 prefetch copies per buffer. + # We iterate over steps in the outer loop because we want to + # queue all iteration 0 prefetches before iteration 1, and so on. + for step in range(scheduler.num_stages - 1): + brefs = map_brefs(functools.partial( + scheduler.initialize_step, step=step), + brefs, refs, schedule) + + # pipeline loop + brefs, next_indices = lax.fori_loop( + 0, num_steps, loop_body, (brefs, initial_indices) + ) + + # pipeline epilogue + final_indices = _prev_index(next_indices, grid) + scheduler = make_scheduler(num_steps - 1, final_indices) + with scheduler.grid_env(): + map_brefs(scheduler.finalize, brefs, refs, schedule) return pipeline @@ -1270,8 +2204,8 @@ def emit_pipeline_with_allocations( body, *, grid, - in_specs=None, - out_specs=None, + in_specs=(), + out_specs=(), should_accumulate_out=False, ): """Creates pallas pipeline and top-level allocation preparation functions. @@ -1285,17 +2219,17 @@ def emit_pipeline_with_allocations( as accumulators. Returns: - (emit_pipeline, make_allocations) function pair, where: - emit_pipeline is the pallas pipeline function. - make_allocations is a function to create buffered refs for the inner - pipeline that can be created at the top-level of a pallas call to be - reused across multiple invocations of the inner pipeline. - + (emit_pipeline, make_allocations) function pair, where + - emit_pipeline is the pallas pipeline function. + - make_allocations is a function to create buffered refs for the inner + pipeline that can be created at the top-level of a pallas call to be + reused across multiple invocations of the inner pipeline. """ make_allocations = functools.partial(make_pipeline_allocations, in_specs=in_specs, out_specs=out_specs, - should_accumulate_out=should_accumulate_out) + should_accumulate_out=should_accumulate_out, + grid=grid) pipeline = emit_pipeline( body, grid=grid, diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py index fb0e0c2c55e3..3045fe7c4edb 100644 --- a/jax/_src/pallas/mosaic/primitives.py +++ b/jax/_src/pallas/mosaic/primitives.py @@ -16,19 +16,23 @@ from __future__ import annotations import dataclasses -import enum +import functools +import logging from typing import Any import jax from jax._src import core as jax_core from jax._src import dtypes +from jax._src import effects from jax._src import pretty_printer as pp +from jax._src import prng as jax_prng +from jax._src import random as jax_random from jax._src import state from jax._src import tree_util from jax._src import util from jax._src.interpreters import mlir from jax._src.pallas import core as pl_core -from jax._src.pallas import utils as pallas_utils +from jax._src.pallas import primitives from jax._src.pallas.mosaic import core as tpu_core from jax._src.state import discharge as state_discharge from jax._src.state import indexing @@ -42,33 +46,47 @@ map, unsafe_map = util.safe_map, map zip, unsafe_zip = util.safe_zip, zip +IntDeviceId = int | jax.Array +MultiDimDeviceId = tuple[IntDeviceId, ...] | dict[str | tuple[str, ...], IntDeviceId] +Ref = state.AbstractRef | state.TransformedRef + repeat_p = jax_core.Primitive('repeat') -def repeat(x, repeats, axis): +def repeat(x: jax.Array, repeats: int, axis: int) -> jax.Array: + axis = util.canonicalize_axis(axis, x.ndim) return repeat_p.bind(x, repeats=repeats, axis=axis) @repeat_p.def_abstract_eval def _repeat_abstract_eval(x, *, repeats, axis): + if axis < 0 or axis >= len(x.shape): + raise ValueError(f"axis: {axis} is out of range [0, {len(x.shape)})") shape = list(x.shape) shape[axis] *= repeats return jax_core.ShapedArray(shape, x.dtype) +@repeat_p.def_impl +def repeat_impl(x: jax.Array, *, repeats: int, axis: int): + reps = [repeats if i == axis else 1 for i in range(x.ndim)] + return jnp.tile(x, reps) + + def _repeat_lowering_rule(ctx: mlir.LoweringRuleContext, x, *, repeats, axis): - def _repeat(x): - return jnp.repeat(x, repeats, axis) - return mlir.lower_fun(_repeat, multiple_results=False)(ctx, x) + return mlir.lower_fun( + functools.partial(repeat_impl, repeats=repeats, axis=axis), + multiple_results=False, + )(ctx, x) mlir.register_lowering(repeat_p, _repeat_lowering_rule) bitcast_p = jax_core.Primitive("bitcast") -def bitcast(x, ty: DTypeLike): - ty = dtypes.canonicalize_dtype(ty) +def bitcast(x: jax.Array, ty: DTypeLike) -> jax.Array: + ty = dtypes.check_and_canonicalize_user_dtype(ty) if len(x.shape) < 2: raise ValueError("Not implemented: bitcast 1D") - src_bitwidth = pallas_utils.dtype_bitwidth(x.dtype) - dst_bitwidth = pallas_utils.dtype_bitwidth(ty) + src_bitwidth = dtypes.itemsize_bits(x.dtype) + dst_bitwidth = dtypes.itemsize_bits(ty) if x.shape[-2] * src_bitwidth % dst_bitwidth: raise ValueError( "Not implemented: the 2nd minor dim can not be perfectly packed or" @@ -80,16 +98,16 @@ def bitcast(x, ty: DTypeLike): @bitcast_p.def_abstract_eval def _bitcast_abstract_eval(x, *, ty): shape = list(x.shape) - src_bitwidth = pallas_utils.dtype_bitwidth(x.dtype) - dst_bitwidth = pallas_utils.dtype_bitwidth(ty) + src_bitwidth = dtypes.itemsize_bits(x.dtype) + dst_bitwidth = dtypes.itemsize_bits(ty) shape[-2] = shape[-2] * src_bitwidth // dst_bitwidth return jax_core.ShapedArray(shape, ty) def _bitcast_lowering_rule(ctx: mlir.LoweringRuleContext, x, *, ty): def _bitcast(x): - src_bitwidth = pallas_utils.dtype_bitwidth(x.dtype) - dst_bitwidth = pallas_utils.dtype_bitwidth(ty) + src_bitwidth = dtypes.itemsize_bits(x.dtype) + dst_bitwidth = dtypes.itemsize_bits(ty) if src_bitwidth < dst_bitwidth: *leading, m, n = x.shape packing = dst_bitwidth // src_bitwidth @@ -111,13 +129,13 @@ def _bitcast(x): def roll( - x, - shift, + x: jax.Array, + shift: jax.Array | int, axis: int, *, stride: int | None = None, stride_axis: int | None = None, -): +) -> jax.Array: if isinstance(shift, int) and shift < 0: raise ValueError("shift must be non-negative.") if axis < 0 or axis >= len(x.shape): @@ -160,255 +178,6 @@ def _roll(x, shift): mlir.register_lowering(roll_p, _roll_lowering_rule) -class DeviceIdType(enum.Enum): - MESH = "mesh" - LOGICAL = "logical" - - -def check_sem_avals( - sem_aval, sem_transforms_avals, name, allowed_semaphore_types=None -): - if allowed_semaphore_types is None: - allowed_semaphore_types = { - tpu_core.semaphore, - tpu_core.barrier_semaphore, - # For interpret mode. - pl_core.SEMAPHORE_INTERPRET_DTYPE, - } - if not isinstance(sem_aval, state.AbstractRef): - raise ValueError(f"Cannot {name} on a non-semaphore Ref: {sem_aval}") - sem_shape = sem_aval.shape - if sem_transforms_avals: - sem_shape = sem_transforms_avals[-1].get_indexer_shape() - if sem_shape: - raise ValueError(f"Cannot {name} on a non-()-shaped semaphore: {sem_shape}") - sem_dtype = sem_aval.dtype - if not any( - jnp.issubdtype(sem_dtype, sem_type) - for sem_type in allowed_semaphore_types - ): - raise ValueError( - f"Must {name} semaphores of the following types:" - f" {allowed_semaphore_types}. Got {sem_dtype}." - ) - - -def _transform_semaphore(ref_value, transforms, ref_aval): - """Helper function for indexing into a semaphore during state_discharge.""" - if ref_value.shape == ref_aval.shape: - return state_discharge.transform_array(ref_value, transforms) - elif len(ref_value.shape) == 0: - return ref_value - else: - raise ValueError( - f"Semaphore value shape {ref_value.shape} does not match aval shape" - f" {ref_aval.shape}" - ) - - -semaphore_read_p = jax_core.Primitive("semaphore_read") -semaphore_read_p.multiple_results = False - - -def semaphore_read(sem_or_view): - ref, transforms = _get_ref_and_transforms(sem_or_view) - args = [ref, transforms] - flat_args, args_tree = tree_util.tree_flatten(args) - return semaphore_read_p.bind(*flat_args, args_tree=args_tree) - -@semaphore_read_p.def_abstract_eval -def _semaphore_read_abstract_eval( - *avals, - args_tree, -): - sem_aval, sem_transforms_avals = tree_util.tree_unflatten(args_tree, avals) - check_sem_avals( - sem_aval, - sem_transforms_avals, - "read", - allowed_semaphore_types={ - tpu_core.dma_semaphore, - tpu_core.semaphore, - tpu_core.barrier_semaphore, - pl_core.SEMAPHORE_INTERPRET_DTYPE, - }, - ) - return jax_core.ShapedArray((), jnp.dtype("int32")) - -def _semaphore_read_discharge_rule(in_avals, - out_avals, - *flat_args, - args_tree): - del out_avals - [ref, transforms] = args_tree.unflatten(flat_args) - sem_value = _transform_semaphore(ref, transforms, in_avals[0]) - sem_value = sem_value.astype(jnp.int32) - return (None,) * len(in_avals), sem_value -state_discharge.register_discharge_rule(semaphore_read_p)( - _semaphore_read_discharge_rule -) - - -semaphore_signal_p = jax_core.Primitive('semaphore_signal') -semaphore_signal_p.multiple_results = True - - -def semaphore_signal( - sem_or_view, - inc: int | jax.Array = 1, - *, - device_id: int | jax.Array | None | tuple[int | jax.Array, ...] = None, - device_id_type: DeviceIdType = DeviceIdType.MESH, - core_index: int | jax.Array | None = None, -): - ref, transforms = _get_ref_and_transforms(sem_or_view) - inc = jnp.asarray(inc, dtype=jnp.int32) - args = [ref, transforms, inc, device_id, core_index] - flat_args, args_tree = tree_util.tree_flatten(args) - semaphore_signal_p.bind( - *flat_args, - args_tree=args_tree, - device_id_type=device_id_type, - ) - - -@semaphore_signal_p.def_abstract_eval -def _semaphore_signal_abstract_eval( - *avals, - args_tree, - device_id_type: DeviceIdType, -): - del device_id_type - ( - sem_aval, - sem_transforms_avals, - value_aval, - device_id_avals, - core_index_aval, - ) = tree_util.tree_unflatten(args_tree, avals) - check_sem_avals(sem_aval, sem_transforms_avals, "signal") - if value_aval.dtype != jnp.dtype("int32"): - raise ValueError("Must signal an int32 value.") - if device_id_avals is not None: - device_id_flat_avals = tree_util.tree_leaves(device_id_avals) - for aval in device_id_flat_avals: - if aval.dtype != jnp.dtype("int32"): - raise ValueError("`device_id`s must be an int32 value.") - return [] - - -def _semaphore_signal_pp_eqn(eqn: jax_core.JaxprEqn, - context: jax_core.JaxprPpContext, - settings: jax_core.JaxprPpSettings): - del settings - invars = eqn.invars - tree = eqn.params["args_tree"] - ( - sem, - sem_transforms, - value, - device_ids, - _, - ) = tree_util.tree_unflatten(tree, invars) - out = pp.concat([ - pp.text("semaphore_signal"), - pp.text(" "), - sp.pp_ref_transforms(context, sem, sem_transforms), - pp.text(" "), - pp.text(jax_core.pp_var(value, context)), - ]) - if device_ids is not None: - flat_device_ids = tree_util.tree_leaves(device_ids) - if not flat_device_ids: - return out - device_ids_pp = [pp.text(jax_core.pp_var(flat_device_ids[0], context))] - for device_id in flat_device_ids[1:]: - device_ids_pp.append(pp.text(" ")) - device_ids_pp.append(pp.text(jax_core.pp_var(device_id, context))) - out = pp.concat([out, pp.concat(device_ids_pp)]) - return out -jax_core.pp_eqn_rules[semaphore_signal_p] = _semaphore_signal_pp_eqn - - -def _semaphore_signal_discharge_rule(in_avals, - out_avals, - *flat_args, - args_tree, - device_id_type): - del out_avals, device_id_type - [ref, transforms, inc, device_id, core_index] = args_tree.unflatten(flat_args) - if device_id is not None: - raise NotImplementedError("Remote signal not implemented.") - if core_index is not None: - raise NotImplementedError("Multiple core support not implemented.") - sem_value = _transform_semaphore(ref, transforms, in_avals[0]) - inc = inc.astype(pl_core.SEMAPHORE_INTERPRET_DTYPE) - _, new_sem_value = state_discharge.transform_swap_array( - ref, transforms, sem_value + inc - ) - return (new_sem_value,) + (None,) * (len(in_avals) - 1), () -state_discharge.register_discharge_rule(semaphore_signal_p)( - _semaphore_signal_discharge_rule -) - - -semaphore_wait_p = jax_core.Primitive('semaphore_wait') -semaphore_wait_p.multiple_results = True - -def semaphore_wait(sem_or_view, dec: int | jax.Array = 1): - ref, transforms = _get_ref_and_transforms(sem_or_view) - dec = jnp.asarray(dec, dtype=jnp.int32) - args = [ref, transforms, dec] - flat_args, args_tree = tree_util.tree_flatten(args) - semaphore_wait_p.bind(*flat_args, args_tree=args_tree) - -@semaphore_wait_p.def_abstract_eval -def _semaphore_wait_abstract_eval(*avals, args_tree): - sem_aval, sem_transforms_avals, value_aval = tree_util.tree_unflatten( - args_tree, avals - ) - check_sem_avals(sem_aval, sem_transforms_avals, "wait") - if value_aval.dtype != jnp.dtype("int32"): - raise ValueError("Must wait an int32 value.") - return [] - -def _semaphore_wait_pp_eqn(eqn: jax_core.JaxprEqn, - context: jax_core.JaxprPpContext, - settings: jax_core.JaxprPpSettings): - del settings - invars = eqn.invars - tree = eqn.params["args_tree"] - ( - sem, - sem_transforms, - value, - ) = tree_util.tree_unflatten(tree, invars) - return pp.concat([ - pp.text("semaphore_wait"), - pp.text(" "), - sp.pp_ref_transforms(context, sem, sem_transforms), - pp.text(" "), - pp.text(jax_core.pp_var(value, context)), - ]) -jax_core.pp_eqn_rules[semaphore_wait_p] = _semaphore_wait_pp_eqn - -def _semaphore_wait_discharge_rule(in_avals, - out_avals, - *flat_args, - args_tree): - del out_avals - [ref, transforms, dec] = args_tree.unflatten(flat_args) - sem_value = _transform_semaphore(ref, transforms, in_avals[0]) - dec = dec.astype(pl_core.SEMAPHORE_INTERPRET_DTYPE) - _, new_sem_value = state_discharge.transform_swap_array( - ref, transforms, sem_value - dec - ) - return (new_sem_value,) + (None,) * (len(in_avals) - 1), () -state_discharge.register_discharge_rule(semaphore_wait_p)( - _semaphore_wait_discharge_rule -) - - @dataclasses.dataclass class AsyncCopyDescriptor: src_ref: Any @@ -419,21 +188,32 @@ class AsyncCopyDescriptor: dst_sem_transforms: tuple[Transform, ...] src_sem: int | jax.Array | None src_sem_transforms: tuple[Transform, ...] | None - device_id: int | jax.Array | None - device_id_type: DeviceIdType = DeviceIdType.MESH + device_id: MultiDimDeviceId | IntDeviceId | None + device_id_type: primitives.DeviceIdType = primitives.DeviceIdType.MESH + _used: bool = dataclasses.field( + default=False, init=False, compare=False, hash=False + ) def __post_init__(self): if (self.src_sem is None) ^ (self.device_id is None): raise ValueError("Either both or neither `src_sem` and `device_id` " "can be set.") + def __del__(self): + if not self._used: + # Exceptions in ``__del__`` are ignored, so logging is our only option. + logging.error( + "AsyncCopyDescriptor was not used." + " Did you mean to call `start` or `wait` on it?" + ) + @property def is_remote(self): return self.src_sem is not None def _get_args_and_tree(self, swap_src_and_dst: bool = False): if swap_src_and_dst: - return tree_util.tree_flatten(( + return _dma_flatten( self.dst_ref, self.dst_transforms, self.src_ref, @@ -443,9 +223,9 @@ def _get_args_and_tree(self, swap_src_and_dst: bool = False): self.dst_sem, self.dst_sem_transforms, self.device_id, - )) + ) else: - return tree_util.tree_flatten(( + return _dma_flatten( self.src_ref, self.src_transforms, self.dst_ref, @@ -455,11 +235,18 @@ def _get_args_and_tree(self, swap_src_and_dst: bool = False): self.src_sem, self.src_sem_transforms, self.device_id, - )) + ) - def start(self): + def start(self, priority: int = 0, *, add: bool = False): + self._used = True flat_args, tree = self._get_args_and_tree() - dma_start_p.bind(*flat_args, tree=tree, device_id_type=self.device_id_type) + dma_start_p.bind( + *flat_args, + tree=tree, + device_id_type=self.device_id_type, + priority=priority, + add=add, + ) def wait(self): if self.is_remote: @@ -467,12 +254,14 @@ def wait(self): self.wait_recv() def wait_recv(self): + self._used = True flat_args, tree = self._get_args_and_tree() dma_wait_p.bind( *flat_args, tree=tree, device_id_type=self.device_id_type ) def wait_send(self): + self._used = True if not self.is_remote: raise ValueError("Cannot `wait_send` on a local copy.") # We swap src and dst since by default dma_wait_p waits on the dst_sem @@ -484,11 +273,178 @@ def wait_send(self): ) +def _dma_flatten( + src_ref, + src_transforms, + dst_ref, + dst_transforms, + dst_sem, + dst_sem_transforms, + src_sem, + src_sem_transforms, + device_id, +): + return tree_util.tree_flatten(( + src_ref, + _maybe_wrap_transformed_refs(src_transforms), + dst_ref, + _maybe_wrap_transformed_refs(dst_transforms), + dst_sem, + dst_sem_transforms, + src_sem, + src_sem_transforms, + device_id, + )) + + +def _dma_unflatten(tree, flat_args): + ( + src_ref, + src_transforms, + dst_ref, + dst_transforms, + dst_sem, + dst_sem_transforms, + src_sem, + src_sem_transforms, + device_id, + ) = tree_util.tree_unflatten(tree, flat_args) + return ( + src_ref, + _maybe_unwrap_transformed_refs(src_transforms), + dst_ref, + _maybe_unwrap_transformed_refs(dst_transforms), + dst_sem, + dst_sem_transforms, + src_sem, + src_sem_transforms, + device_id, + ) + + +def _maybe_wrap_transformed_refs(transforms: Any) -> Any: + return jax.tree.map( + lambda obj: _maybe_wrap_transformed_refs(TransformedRefTree.wrap(obj)) + if isinstance(obj, state.TransformedRef) + else obj, + transforms, + ) + + +def _maybe_unwrap_transformed_refs(transforms: Any) -> Any: + return jax.tree.map( + lambda obj: _maybe_unwrap_transformed_refs(obj.unwrap()) + if isinstance(obj, TransformedRefTree) + else obj, + transforms, + is_leaf=lambda obj: isinstance(obj, TransformedRefTree), + ) + + +@jax.tree_util.register_dataclass +@dataclasses.dataclass(frozen=True) +class TransformedRefTree(state.TransformedRef): + """A PyTree wrapper for a ``TransformedRef``. + + The wrapper is necessary to support the case when a ``TransformedRef`` is + indexed with other ``TransformedRef``s. + """ + + @classmethod + def wrap(cls, ref: state.TransformedRef) -> TransformedRefTree: + return cls(ref.ref, ref.transforms) + + def unwrap(self) -> state.TransformedRef: + return state.TransformedRef(self.ref, self.transforms) + + +def _get_dma_effects( + src_transforms_avals, + dst_transforms_avals, + dst_sem_transforms_avals, + src_sem_aval, + device_id_aval, + device_id_type, +): + n_src_transforms = len(tree_util.tree_leaves(src_transforms_avals)) + n_dst_transforms = len(tree_util.tree_leaves(dst_transforms_avals)) + n_dst_sem_transforms = len(tree_util.tree_leaves(dst_sem_transforms_avals)) + dst_sem_index = 1 + n_src_transforms + 1 + n_dst_transforms + effs = { + state.ReadEffect(0), # Read from src ref + state.WriteEffect(n_src_transforms + 1), # Write to dst ref + state.WriteEffect(dst_sem_index), # Write to dst sem + } + if src_sem_aval is not None: + src_sem_index = ( + 1 + n_src_transforms + 1 + n_dst_transforms + 1 + n_dst_sem_transforms + ) + effs.add(state.WriteEffect(src_sem_index)) + if device_id_aval is not None: + if device_id_type is primitives.DeviceIdType.MESH and isinstance( + device_id_aval, dict + ): + for k in device_id_aval: + if not isinstance(k, tuple): + k = (k,) + for k_ in k: + effs.add(jax_core.NamedAxisEffect(k_)) + return effs + + dma_start_p = jax_core.Primitive('dma_start') dma_start_p.multiple_results = True +def _dma_is_high(*avals, **params): + return any(aval.is_high for aval in avals) + +dma_start_p.is_high = _dma_is_high # type: ignore[method-assign] + +def _dma_start_to_lojax(*args, tree, device_id_type, priority, add): + ( + src_ref, + src_transforms, + dst_ref, + dst_transforms, + dst_sem, + dst_sem_transforms, + src_sem, + src_sem_transforms, + device_id, + ) = tree_util.tree_unflatten(tree, args) + src_ref_aval = jax_core.get_aval(src_ref) + dst_ref_aval = jax_core.get_aval(dst_ref) + if not (src_ref_aval.is_high and dst_ref_aval.is_high): + raise NotImplementedError("dma_start not implemented in LoJAX yet.") + dst_sem_aval = jax_core.get_aval(dst_sem) + if dst_sem_aval.is_high: + raise NotImplementedError("dma_start not implemented in LoJAX yet.") + if src_sem is not None: + if jax_core.get_aval(src_sem).is_high: + raise NotImplementedError("dma_start not implemented in LoJAX yet.") + src_transformed_ref = state.TransformedRef(src_ref, src_transforms) + dst_transformed_ref = state.TransformedRef(dst_ref, dst_transforms) + if src_sem is not None: + src_sem = state.TransformedRef(src_sem, src_sem_transforms) + dst_sem = state.TransformedRef(dst_sem, dst_sem_transforms) + + src_ref_aval.inner_aval.dma_start( + src_transformed_ref, + dst_transformed_ref, + src_sem, + dst_sem, + device_id=device_id, + priority=priority, + device_id_type=device_id_type, + add=add + ) + return [] +dma_start_p.to_lojax = _dma_start_to_lojax + @dma_start_p.def_effectful_abstract_eval -def _dma_start_abstract_eval(*args, tree, device_id_type): +def _dma_start_abstract_eval(*args, tree, device_id_type, priority, add): + if priority < 0: + raise ValueError(f"DMA start priority must be non-negative: {priority}") ( src_ref_aval, src_transforms_avals, @@ -499,7 +455,11 @@ def _dma_start_abstract_eval(*args, tree, device_id_type): src_sem_aval, src_sem_transforms_avals, device_id_aval, - ) = tree_util.tree_unflatten(tree, args) + ) = _dma_unflatten(tree, args) + if not all(isinstance(x, state.AbstractRef) for x in [ + src_ref_aval, dst_ref_aval, dst_sem_aval]): + raise ValueError( + "DMA source/destination/semaphore arguments must be Refs.") dst_sem_shape = dst_sem_aval.shape if dst_sem_transforms_avals: dst_sem_shape = dst_sem_transforms_avals[-1].get_indexer_shape() @@ -508,6 +468,9 @@ def _dma_start_abstract_eval(*args, tree, device_id_type): f"Cannot signal on a non-()-shaped semaphore: {dst_sem_shape}" ) if src_sem_aval is not None: + if not isinstance(src_sem_aval, state.AbstractRef): + raise ValueError( + "DMA source semaphore must be a Ref.") src_sem_shape = src_sem_aval.shape if src_sem_transforms_avals: src_sem_shape = src_sem_transforms_avals[-1].get_indexer_shape() @@ -515,14 +478,22 @@ def _dma_start_abstract_eval(*args, tree, device_id_type): raise ValueError( f"Cannot signal on a non-()-shaped semaphore: {src_sem_shape}" ) - n_src_transforms = len(tree_util.tree_leaves(src_transforms_avals)) - return [], {state.ReadEffect(0), state.WriteEffect(n_src_transforms + 1)} + return [], _get_dma_effects( + src_transforms_avals, + dst_transforms_avals, + dst_sem_transforms_avals, + src_sem_aval, + device_id_aval, + device_id_type, + ) def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn, context: jax_core.JaxprPpContext, settings: jax_core.JaxprPpSettings): invars = eqn.invars tree = eqn.params["tree"] + priority = eqn.params["priority"] + add = eqn.params["add"] ( src_ref, src_transforms, @@ -533,13 +504,13 @@ def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn, src_sem, src_sem_transforms, device_id, - ) = tree_util.tree_unflatten(tree, invars) + ) = _dma_unflatten(tree, invars) del src_sem_transforms # TODO(sharadmv): pretty print source semaphores and device id if src_sem or device_id: return jax_core._pp_eqn(eqn, context, settings) return pp.concat([ - pp.text("dma_start"), + pp.text(f"dma_start(p{priority}{', add' if add else ''})"), pp.text(" "), sp.pp_ref_transforms(context, src_ref, src_transforms), pp.text(" -> "), @@ -550,8 +521,16 @@ def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn, jax_core.pp_eqn_rules[dma_start_p] = _dma_start_pp_eqn -def dma_start_partial_discharge_rule(should_discharge, in_avals, out_avals, - *args, tree, device_id_type): + +def dma_start_partial_discharge_rule( + should_discharge, in_avals, out_avals, *args, tree, device_id_type, + priority, add +): + # Note: we ignore the DMA priority in discharge rules. + del priority + if add: + raise NotImplementedError( + "DMA partial discharge add=True not yet implemented.") ( src_ref, src_transforms, @@ -562,7 +541,7 @@ def dma_start_partial_discharge_rule(should_discharge, in_avals, out_avals, src_sem, src_sem_transforms, device_id, - ) = tree_util.tree_unflatten(tree, args) + ) = _dma_unflatten(tree, args) ( _, src_transforms_avals, @@ -573,7 +552,7 @@ def dma_start_partial_discharge_rule(should_discharge, in_avals, out_avals, src_sem_aval, src_sem_transforms_avals, _, - ) = tree_util.tree_unflatten(tree, in_avals) + ) = _dma_unflatten(tree, in_avals) del out_avals ( @@ -584,7 +563,7 @@ def dma_start_partial_discharge_rule(should_discharge, in_avals, out_avals, dst_sem_discharge, _, *maybe_src_sem_discharge, - ) = tree_util.tree_unflatten(tree, should_discharge) + ) = _dma_unflatten(tree, should_discharge) is_remote = device_id is not None src_sem_discharge = None @@ -610,14 +589,24 @@ def dma_start_partial_discharge_rule(should_discharge, in_avals, out_avals, # TODO(justinfu): Verify that code only works in SPMD mode. axis_env = jax_core.get_axis_env() nonempty_axes = [name for name in axis_env.axis_sizes if name is not None] - if device_id_type == DeviceIdType.LOGICAL: + if isinstance(device_id, dict): + if device_id_type is not primitives.DeviceIdType.MESH: + raise ValueError( + "`device_id_type` must be MESH if `device_id` is a dict," + f" got: {device_id_type = }." + ) + device_id_list = [] + for axis in nonempty_axes: + device_id_list.append(device_id.get(axis, jax.lax.axis_index(axis))) + device_id = tuple(device_id_list) + if device_id_type == primitives.DeviceIdType.LOGICAL: if len(nonempty_axes) > 1: raise NotImplementedError("Sharding with more than one named axis not " "implemented in dma_start_p for LOGICAL " "device_id_type.") shard_axis = nonempty_axes[0] my_axis = jax.lax.axis_index(shard_axis) - elif device_id_type == DeviceIdType.MESH: + elif device_id_type == primitives.DeviceIdType.MESH: device_id_len = 1 if isinstance(device_id, jax.Array): device_id_len = device_id.size @@ -667,7 +656,7 @@ def do_discharge_dst(dst_ref=dst_ref): def do_discharge_dst_sem(dst_sem=dst_sem): recv_size = jnp.minimum(updates.size, pl_core.SEMAPHORE_MAX_VALUE) recv_size = jnp.array(recv_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE) - dst_sem_value = _transform_semaphore( + dst_sem_value = primitives._transform_semaphore( dst_sem, dst_sem_transforms, dst_sem_aval ) _, ret = state_discharge.transform_swap_array( @@ -678,7 +667,7 @@ def do_discharge_dst_sem(dst_sem=dst_sem): def do_discharge_src_sem(src_sem=src_sem): send_size = jnp.minimum(local_src.size, pl_core.SEMAPHORE_MAX_VALUE) send_size = jnp.array(send_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE) - src_sem_value = _transform_semaphore( + src_sem_value = primitives._transform_semaphore( src_sem, src_sem_transforms, src_sem_aval ) _, ret = state_discharge.transform_swap_array( @@ -710,16 +699,74 @@ def do_discharge_src_sem(src_sem=src_sem): return new_vals, [] + state_discharge.register_partial_discharge_rule(dma_start_p)(dma_start_partial_discharge_rule) dma_wait_p = jax_core.Primitive('dma_wait') dma_wait_p.multiple_results = True -@dma_wait_p.def_abstract_eval -def _dma_wait_abstract_eval(*args, tree, device_id_type): - del args, tree, device_id_type +dma_wait_p.is_high = _dma_is_high # type: ignore[method-assign] + +def _dma_wait_to_lojax(*args, tree, device_id_type): + ( + src_ref, + src_transforms, + dst_ref, + dst_transforms, + dst_sem, + dst_sem_transforms, + src_sem, + src_sem_transforms, + device_id, + ) = tree_util.tree_unflatten(tree, args) + src_ref_aval = jax_core.get_aval(src_ref) + dst_ref_aval = jax_core.get_aval(dst_ref) + if not (src_ref_aval.is_high and dst_ref_aval.is_high): + raise NotImplementedError("dma_wait not implemented in LoJAX yet.") + dst_sem_aval = jax_core.get_aval(dst_sem) + if dst_sem_aval.is_high: + raise NotImplementedError("dma_wait not implemented in LoJAX yet.") + if src_sem is not None: + if jax_core.get_aval(src_sem).is_high: + raise NotImplementedError("dma_wait not implemented in LoJAX yet.") + src_transformed_ref = state.TransformedRef(src_ref, src_transforms) + dst_transformed_ref = state.TransformedRef(dst_ref, dst_transforms) + if src_sem is not None: + src_sem = state.TransformedRef(src_sem, src_sem_transforms) + dst_sem = state.TransformedRef(dst_sem, dst_sem_transforms) + src_ref_aval.inner_aval.dma_wait( + src_transformed_ref, + dst_transformed_ref, + src_sem, + dst_sem, + device_id=device_id, + device_id_type=device_id_type, + ) return [] +dma_wait_p.to_lojax = _dma_wait_to_lojax + +@dma_wait_p.def_effectful_abstract_eval +def _dma_wait_abstract_eval(*args, tree, device_id_type): + ( + src_ref_aval, + src_transforms_avals, + dst_ref_aval, + dst_transforms_avals, + dst_sem_aval, + dst_sem_transforms_avals, + src_sem_aval, + src_sem_transforms_avals, + device_id_aval, + ) = _dma_unflatten(tree, args) + return [], _get_dma_effects( + src_transforms_avals, + dst_transforms_avals, + dst_sem_transforms_avals, + src_sem_aval, + device_id_aval, + device_id_type, + ) def _dma_wait_pp_eqn(eqn: jax_core.JaxprEqn, context: jax_core.JaxprPpContext, @@ -737,7 +784,7 @@ def _dma_wait_pp_eqn(eqn: jax_core.JaxprEqn, _, _, _, - ) = tree_util.tree_unflatten(tree, invars) + ) = _dma_unflatten(tree, invars) return pp.concat([ pp.text("dma_wait"), pp.text(" "), @@ -754,8 +801,10 @@ def dma_wait_partial_discharge_rule(should_discharge, # TODO(b/370563115): perform ref update in dma_wait discharge rule instead of dma_start del out_avals, device_id_type _, _, dst_ref, dst_ref_transforms, dst_sem, dst_sem_transforms, _, _, _ = ( - tree_util.tree_unflatten(tree, args)) - (_, + _dma_unflatten(tree, args) + ) + ( + _, src_ref_transforms_avals, _, dst_ref_transforms_avals, @@ -764,21 +813,21 @@ def dma_wait_partial_discharge_rule(should_discharge, src_sem_aval, src_sem_transforms_avals, device_id_aval, - ) = tree_util.tree_unflatten(tree, in_avals) + ) = _dma_unflatten(tree, in_avals) # The only one we can discharge is the dst semaphore. The provided # buffers are only specified for their types and not their value so # it's completely irrelevant for us here if they are discharged. - should_discharge_unflattened = tree_util.tree_unflatten(tree, should_discharge) + should_discharge_unflattened = _dma_unflatten(tree, should_discharge) if not should_discharge_unflattened[4]: return (None,) * len(in_avals), [] num_sem_transforms = len(tree_util.tree_leaves(dst_sem_transforms_avals)) num_transforms = len(tree_util.tree_leaves(dst_ref_transforms_avals)) - updates = state_discharge.transform_array(dst_ref, dst_ref_transforms) + updates = state_discharge.transform_array(dst_ref[...], dst_ref_transforms) copy_size = jnp.minimum(updates.size, pl_core.SEMAPHORE_MAX_VALUE) copy_size = jnp.array(copy_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE) - sem_value = _transform_semaphore(dst_sem, dst_sem_transforms, dst_sem_aval) + sem_value = primitives._transform_semaphore(dst_sem, dst_sem_transforms, dst_sem_aval) _, new_sem = state_discharge.transform_swap_array( dst_sem, dst_sem_transforms, sem_value - copy_size ) @@ -799,8 +848,18 @@ def _get_ref_and_transforms(ref): return ref.ref, ref.transforms return ref, () -def make_async_copy(src_ref, dst_ref, sem): - """Issues a DMA copying from src_ref to dst_ref.""" + +def make_async_copy(src_ref, dst_ref, sem) -> AsyncCopyDescriptor: + """Creates a description of an asynchronous copy operation. + + Args: + src_ref: The source Reference. + dst_ref: The destination Reference. + sem: The semaphore used to track completion of the copy. + + Returns: + An AsyncCopyDescriptor. + """ src_ref, src_transforms = _get_ref_and_transforms(src_ref) dst_ref, dst_transforms = _get_ref_and_transforms(dst_ref) sem, sem_transforms = _get_ref_and_transforms(sem) @@ -814,17 +873,27 @@ def make_async_copy(src_ref, dst_ref, sem): None, None, None, - DeviceIdType.MESH, + primitives.DeviceIdType.MESH, ) -def async_copy(src_ref, dst_ref, sem): + +def async_copy( + src_ref, dst_ref, sem, *, priority: int = 0, add: bool = False, +) -> AsyncCopyDescriptor: """Issues a DMA copying from src_ref to dst_ref.""" copy_descriptor = make_async_copy(src_ref, dst_ref, sem) - copy_descriptor.start() + copy_descriptor.start(priority=priority, add=add) return copy_descriptor -def make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id, - device_id_type: DeviceIdType = DeviceIdType.MESH): + +def make_async_remote_copy( + src_ref, + dst_ref, + send_sem, + recv_sem, + device_id: MultiDimDeviceId | IntDeviceId | None, + device_id_type: primitives.DeviceIdType = primitives.DeviceIdType.MESH, +) -> AsyncCopyDescriptor: """Creates a description of a remote copy operation. Copies data from src_ref on the current device to dst_ref on the device @@ -838,8 +907,10 @@ def make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id, dst_ref: The destination Reference. send_sem: The semaphore on the source device. recv_sem: The semaphore on the destination device. - device_id: The device id of the destination device. + device_id: The device id of the destination device. It could be a tuple, or + a dictionary specifying the communication axis and destination index. device_id_type: The type of the device id. + Returns: An AsyncCopyDescriptor. """ @@ -847,6 +918,11 @@ def make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id, send_sem, send_sem_transforms = _get_ref_and_transforms(send_sem) dst_ref, dst_transforms = _get_ref_and_transforms(dst_ref) recv_sem, recv_sem_transforms = _get_ref_and_transforms(recv_sem) + if device_id_type == primitives.DeviceIdType.LOGICAL: + assert not isinstance( + device_id, tuple | dict + ), "LOGICAL device_id_type does not support device_id as a tuple or dict." + return AsyncCopyDescriptor( src_ref, src_transforms, @@ -860,28 +936,29 @@ def make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id, device_id_type=device_id_type, ) -def async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id, - device_id_type: DeviceIdType = DeviceIdType.MESH): + +def async_remote_copy( + src_ref, + dst_ref, + send_sem, + recv_sem, + device_id, + device_id_type: primitives.DeviceIdType = primitives.DeviceIdType.MESH, +) -> AsyncCopyDescriptor: + """Issues a remote DMA copying from src_ref to dst_ref.""" copy_descriptor = make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id, device_id_type) copy_descriptor.start() return copy_descriptor -device_id_p = jax_core.Primitive('device_id') - -@device_id_p.def_abstract_eval -def _device_id_abstract_eval(): - return jax_core.ShapedArray((), jnp.dtype("int32")) - -device_id = device_id_p.bind get_barrier_semaphore_p = jax_core.Primitive('get_barrier_semaphore') @get_barrier_semaphore_p.def_abstract_eval def _get_barrier_semaphore_abstract_eval(): - return pl_core.AbstractMemoryRef( - jax_core.ShapedArray((), tpu_core.BarrierSemaphoreTy()), - tpu_core.TPUMemorySpace.SEMAPHORE, + return state.AbstractRef( + jax_core.ShapedArray((), pl_core.BarrierSemaphore()), + tpu_core.MemorySpace.SEMAPHORE, ) def get_barrier_semaphore(): @@ -902,33 +979,27 @@ def get_barrier_semaphore(): to share a collective_id. However, if in doubt, prefer not sharing collective_ids, as doing so incorrectly can lead to silent data corruption or crashes. - Note that re-using the same collective_id doesn't guarantee that the same + Note that reusing the same collective_id doesn't guarantee that the same semaphore is provided by XLA. """ return get_barrier_semaphore_p.bind() -delay_p = jax_core.Primitive("delay") -delay_p.multiple_results = True - - -@delay_p.def_abstract_eval -def _delay_abstract_eval(nanos): - del nanos - return [] - - -def delay(nanos): - """Delays vector execution for the given number of nanosconds.""" - delay_p.bind(nanos) - # RNG Ops prng_seed_p = jax_core.Primitive("prng_seed") prng_seed_p.multiple_results = True -@prng_seed_p.def_abstract_eval -def _(*_): - return [] + +class PRNGEffect(effects.Effect): + pass +prng_effect = PRNGEffect() +effects.control_flow_allowed_effects.add_type(PRNGEffect) +pl_core.kernel_local_effects.add_type(PRNGEffect) + + +@prng_seed_p.def_effectful_abstract_eval +def _prng_seed_abstract_eval(*_): + return [], {prng_effect} def prng_seed(*seeds: int | jax.Array) -> None: @@ -945,8 +1016,261 @@ def prng_seed(*seeds: int | jax.Array) -> None: 'prng_random_bits') @prng_random_bits_p.def_abstract_eval -def _(*, shape): +def _prng_random_bits_abstract_eval(*, shape): return jax_core.ShapedArray(shape, jnp.dtype("int32")) + def prng_random_bits(shape): return prng_random_bits_p.bind(shape=shape) + +# PRNG wrap/unwrap ops. +# We cannot use JAX's key_data and wrap_key_data because they return +# vectors, and Pallas keys are represented as lists of scalars. + +split_key_p = jax_core.Primitive("prng_split") +split_key_p.multiple_results = True + + +@split_key_p.def_abstract_eval +def _split_key_scalar_abstract_eval(seed): + key_shape = seed.dtype._impl.key_shape + if len(key_shape) != 2 or key_shape[0] != 1: + raise ValueError(f"Key shape must be (1, N), got {key_shape}") + return [jax_core.ShapedArray((), jnp.dtype("uint32"))] * key_shape[1] + + +def unwrap_pallas_seed(seed): + """Splits a PRNG key into it's scalar components.""" + return split_key_p.bind(seed) + + +join_key_p = jax_core.Primitive("prng_join") + + +@join_key_p.def_abstract_eval +def _join_key_scalar_abstract_eval(*seeds, impl): + if len(impl.key_shape) != 2 or impl.key_shape[0] != 1: + raise ValueError(f"Key shape must be (1, N), got {impl.key_shape}") + if len(seeds) != impl.key_shape[1]: + raise ValueError( + f"Number of seeds must match key shape, got {len(seeds)}" + f" != {impl.key_shape[1]}." + ) + return jax_core.ShapedArray((), dtype=jax_prng.KeyTy(impl)) + + +def wrap_pallas_seed(*seeds, impl): + """Joins scalar into a single PRNG key.""" + impl = jax_random.resolve_prng_impl(impl) + return join_key_p.bind(*seeds, impl=impl) + + +stochastic_round_p = jax_core.Primitive("stochastic_round") + + +def stochastic_round(x, random_bits, *, target_dtype): + return stochastic_round_p.bind(x, random_bits, target_dtype=target_dtype) + + +@stochastic_round_p.def_abstract_eval +def _stochastic_round_abstract_eval(x, random_bits, *, target_dtype): + if random_bits.shape != x.shape: + raise ValueError( + "The shape of `random_bits` must match the shape of `x` for " + f"stochastic_round, but got {random_bits.shape} and {x.shape}" + ) + if random_bits.dtype != jnp.dtype("uint32"): + raise ValueError( + "The dtype of `random_bits` must be uint32 for stochastic_round, " + f"but got {random_bits.dtype}" + ) + return jax_core.ShapedArray(x.shape, target_dtype) + + +def _get_elementwise_packing_factor(unpacked_dtype, packed_dtype): + unpacked_bitwidth = dtypes.itemsize_bits(unpacked_dtype) + packed_bitwidth = dtypes.itemsize_bits(packed_dtype) + if unpacked_bitwidth % packed_bitwidth != 0: + raise ValueError( + "Unpacked bitwidth must be a multiple of packed bitwidth, got " + f"{unpacked_bitwidth} and {packed_bitwidth}" + ) + return unpacked_bitwidth // packed_bitwidth + +pack_elementwise_p = jax_core.Primitive("pack_elementwise") + + +def pack_elementwise(xs, *, packed_dtype): + return pack_elementwise_p.bind(*xs, packed_dtype=packed_dtype) + + +@pack_elementwise_p.def_abstract_eval +def _pack_elementwise_abstract_eval(*xs, packed_dtype): + if not xs: + raise ValueError("At least one source is required") + first = xs[0] + if not all(x.shape == first.shape for x in xs): + raise ValueError("All sources must have the same shape") + if not all(x.dtype == first.dtype for x in xs): + raise ValueError("All sources must have the same dtype") + if not (first.dtype == jnp.float32 and packed_dtype == jnp.bfloat16) and not ( + jnp.issubdtype(first.dtype, jnp.integer) + and jnp.issubdtype(packed_dtype, jnp.integer) + ): + raise ValueError( + "Only f32 -> bf16 and int -> int are supported. Got" + f" {first.dtype} and {packed_dtype}" + ) + packing_factor = _get_elementwise_packing_factor(first.dtype, packed_dtype) + if len(xs) != packing_factor: + raise ValueError( + "The number of sources must match the packing factor " + f"({packing_factor}), got {len(xs)}" + ) + out_dtype = jnp.dtype(f"uint{dtypes.itemsize_bits(first.dtype)}") + return jax_core.ShapedArray(first.shape, out_dtype) + + +unpack_elementwise_p = jax_core.Primitive("unpack_elementwise") + + +def unpack_elementwise(x, *, index, packed_dtype, unpacked_dtype): + return unpack_elementwise_p.bind( + x, index=index, packed_dtype=packed_dtype, unpacked_dtype=unpacked_dtype + ) + + +@unpack_elementwise_p.def_abstract_eval +def _unpack_elementwise_abstract_eval(x, *, index, packed_dtype, unpacked_dtype): + if x.dtype != jnp.uint32: + raise ValueError(f"Source must be uint32, got {x.dtype}") + packing_factor = _get_elementwise_packing_factor(unpacked_dtype, packed_dtype) + if index < 0 or index >= packing_factor: + raise ValueError( + f"Index {index} is out of bounds for packing factor {packing_factor}") + return jax_core.ShapedArray(x.shape, unpacked_dtype) + + +def with_memory_space_constraint( + x: jax.Array, memory_space: Any +) -> jax.Array: + """Constrains the memory space of an array. + + This primitive does not change the value of ``x``, but it constrains the + memory space where it should be allocated. This is useful to force + Pallas to allocate an array in a specific memory space. + + As of now, this only operates on the inputs pallas_calls, as in you can + apply this to the arguments of a pallas_call and it will constrain them, but + other operations will not respect this constraint. + + Args: + x: The array to constrain. + memory_space: The memory space to constrain to. + + Returns: + The array ``x`` with the memory space constraint. + """ + if memory_space is pl_core.MemorySpace.ANY: + return x + if memory_space not in { + tpu_core.MemorySpace.HBM, + tpu_core.MemorySpace.VMEM, + tpu_core.MemorySpace.SMEM, + }: + raise NotImplementedError( + "with_memory_space_constraint only supports HBM, VMEM and SMEM." + ) + return pl_core.with_memory_space_constraint_p.bind( + x, memory_space=memory_space) + + +def load(ref: Ref, *, mask: jax.Array | None = None) -> jax.Array: + """Loads an array from the given ref. + + If ``mask`` is not specified, this function has the same semantics as + ``ref[idx]`` in JAX. + + Args: + ref: The ref to load from. + mask: An optional boolean mask specifying which indices to load. + + Returns: + The loaded array. + """ + return primitives.load(ref, None, mask=mask) + + +def store(ref: Ref, val: jax.Array, *, mask: jax.Array | None = None) -> None: + """Stores a value to the given ref. + + If ``mask`` is not specified, this function has the same semantics as + ``ref[idx] = val`` in JAX. + + Args: + ref: The ref to store to. + val: The value to store. + mask: An optional boolean mask specifying which indices to store. + """ + return primitives.store(ref, None, val, mask=mask) + + +touch_p = jax_core.Primitive("add_dependency") +touch_p.multiple_results = True + + +def touch(ref: jax.Array | state.TransformedRef) -> None: + """Adds a fake read-write dependency to the given ref.""" + ref_leaves = jax.tree.leaves(ref) + ref_leaves = [ref.ref if isinstance(ref, state.TransformedRef) else ref + for ref in ref_leaves] + for ref in ref_leaves: + touch_p.bind(ref) + + +@touch_p.def_effectful_abstract_eval +def _touch_abstract_eval(ref: jax.Array): + return [], {state.ReadEffect(0), state.WriteEffect(0)} + + +trace_value_p = jax_core.Primitive("trace_value") +trace_value_p.multiple_results = True + + +def trace_value(label: str, value: jax.Array) -> None: + """Emit a scalar value to the current xprof trace scope. + + This appends a dynamic scalar value to the enclosing trace region. + The value will appear in xprof trace viewer associated with the trace event. + + Args: + label: A string label for this value in xprof. + value: A scalar i32 or f32 value to emit. + + Example: + # Inside a Pallas kernel: + x = jnp.sum(y > 0) + pltpu.trace_value("my_x", x) + """ + trace_value_p.bind(value, label=label) + + +class TraceEffect(effects.Effect): + pass + + +trace_effect = TraceEffect() +effects.control_flow_allowed_effects.add_type(TraceEffect) +pl_core.kernel_local_effects.add_type(TraceEffect) + + +@trace_value_p.def_effectful_abstract_eval +def _trace_value_abstract_eval(value, *, label): + del label + if value.shape: + raise ValueError( + f"trace_value requires a scalar value, got shape {value.shape}" + ) + if value.dtype not in (jnp.int32, jnp.float32): + raise ValueError(f"trace_value requires i32 or f32, got {value.dtype}") + return [], {trace_effect} diff --git a/jax/_src/pallas/mosaic/random.py b/jax/_src/pallas/mosaic/random.py index fd8dcc720f07..3751b5611655 100644 --- a/jax/_src/pallas/mosaic/random.py +++ b/jax/_src/pallas/mosaic/random.py @@ -13,18 +13,18 @@ # limitations under the License. from collections.abc import Callable - import functools import jax from jax import numpy as jnp from jax import random as jax_api_random from jax._src import blocked_sampler from jax._src import dtypes +from jax._src import prng as jax_prng from jax._src import typing -from jax._src.pallas.mosaic.primitives import prng_seed -from jax._src.pallas.mosaic.primitives import prng_random_bits from jax._src.pallas import primitives -from jax._src import prng as jax_prng +from jax._src.pallas.mosaic import primitives as tpu_primitives +from jax._src.pallas.mosaic.primitives import prng_random_bits +from jax._src.pallas.mosaic.primitives import prng_seed Shape = jax_prng.Shape @@ -32,8 +32,8 @@ KeylessSampleFnType = Callable[..., jax.Array] set_seed = prng_seed - -FOLD_IN_ROUNDS = 128 +unwrap_pallas_seed = tpu_primitives.unwrap_pallas_seed +wrap_pallas_seed = tpu_primitives.wrap_pallas_seed def to_pallas_key(key: jax.Array) -> jax.Array: @@ -63,7 +63,7 @@ def is_pallas_impl(impl: jax_prng.PRNGImpl) -> bool: def _seed_func(seed: jnp.int32): seed_data = jnp.zeros(tpu_key_impl.key_shape, dtype=jnp.int32) - return (seed_data + seed).astype(jnp.uint32) + return (seed_data + seed).astype(jnp.uint32) # Broadcast the seed. def _random_bits(key: typing.Array, bit_width: int, shape: Shape): if bit_width != 32: @@ -72,42 +72,26 @@ def _random_bits(key: typing.Array, bit_width: int, shape: Shape): return prng_random_bits(shape) def _fold_in(key: jax_prng.PRNGKeyArray, data: typing.Array): - # Roughly, we compute the new key as follows: - # new_key = random_bits(data)[..., 127] ^ random_bits(old_key)[..., 127] - # Because the TPU generates random numbers in (8, 128) blocks at once, we - # can generate that many values without additional cost which will reduce - # correlation between the old and new keys. - - # TODO(justinfu): The underlying TPU hardware PRNG doesn't produce robust - # random bits when applied in rounds such as below (measured via crush). - # We should consider a different strategy for generating keys. - key_shape = tpu_key_impl.key_shape - - prng_seed(data) - data_bits = prng_random_bits( - key_shape + (FOLD_IN_ROUNDS,)).astype(jnp.uint32) - prng_seed(key) - key_bits = prng_random_bits( - key_shape + (FOLD_IN_ROUNDS,)).astype(jnp.uint32) - - mixed = key_bits[..., FOLD_IN_ROUNDS-1] ^ data_bits[..., FOLD_IN_ROUNDS-1] - assert mixed.shape == key_shape - return jax.random.wrap_key_data(mixed, impl="pallas_tpu") + key0, key1 = unwrap_pallas_seed(key) + # Perform a cheap mixing of data into the key. + key1 = key1 + data + [key0, key1] = jax_prng.apply_round([key0, key1], 13) + return wrap_pallas_seed(key0, key1, impl="pallas_tpu") def _split(key: typing.Array, shape: Shape): del key, shape - raise NotImplementedError() + raise NotImplementedError( + "Cannot split a Pallas key. Use fold_in instead to generate new keys." + ) tpu_key_impl = jax_prng.PRNGImpl( - # Pallas currently only supports 2D+ windows, so set the key_shape - # to be 2D to have better compatibility with setting BlockSpecs. - key_shape=(1, 1), - seed=_seed_func, - split=_split, - random_bits=_random_bits, - fold_in=_fold_in, - name="pallas_tpu", - tag="pl" + key_shape=(1, 2), + seed=_seed_func, + split=_split, + random_bits=_random_bits, + fold_in=_fold_in, + name="pallas_tpu", + tag="pl", ) jax_prng.register_prng(tpu_key_impl) @@ -172,9 +156,10 @@ def new_sampler(*args, **kwargs): new_sampler.__doc__ = "\n".join(doc_lines) return new_sampler -bits = _make_stateful_sampler(jax_api_random.bits) # type: ignore -uniform = _make_stateful_sampler(jax_api_random.uniform) # type: ignore -bernoulli = _make_stateful_sampler(jax_api_random.bernoulli) # type: ignore +stateful_bits = _make_stateful_sampler(jax_api_random.bits) # type: ignore +stateful_uniform = _make_stateful_sampler(jax_api_random.uniform) # type: ignore +stateful_bernoulli = _make_stateful_sampler(jax_api_random.bernoulli) # type: ignore +stateful_normal = _make_stateful_sampler(jax_api_random.normal) # type: ignore def sample_block(sampler_fn: SampleFnType, @@ -186,14 +171,14 @@ def sample_block(sampler_fn: SampleFnType, **kwargs) -> jax.Array: """Samples a block of random values with invariance guarantees. - `sample_block` allows the sampling of identical blocks of random values + ``sample_block`` allows the sampling of identical blocks of random values across kernels with different block shapes and iteration orders. Each call to `sample_block` returns a `block_size`-shaped array of random samples corresponding to the `block_index`. - `tile_size` should be chosen such that it is a divisor to all block sizes - one needs to be invariant to. The larger the `tile_size`, the more - efficient the sampling process wil be and therefore the best choice is + ``tile_size`` should be chosen such that it is a divisor to all block sizes + one needs to be invariant to. The larger the ``tile_size``, the more + efficient the sampling process will be and therefore the best choice is typically the greatest common divisor between all possible block sizes. Args: @@ -201,7 +186,7 @@ def sample_block(sampler_fn: SampleFnType, random samples. global_key: The global key to use for sampling. block_size: The shape of an individual block. - tile_size: The shape of a `tile`, which is the smallest unit at + tile_size: The shape of a ``tile``, which is the smallest unit at which samples are generated. This should be selected to be a divisor of all block sizes one needs to be invariant to. total_size: The total size of the array to sample. @@ -210,8 +195,8 @@ def sample_block(sampler_fn: SampleFnType, **kwargs: Additional arguments to pass to the sampler_fn. Returns: - A `block_size` shaped array of samples for the current block corresponding - to `block_index`. + A ``block_size`` shaped array of samples for the current block corresponding + to ``block_index``. """ if len(block_size) != len(tile_size): raise ValueError(f"block_size ({len(block_size)}) and tile_size " diff --git a/jax/_src/pallas/mosaic/sc_core.py b/jax/_src/pallas/mosaic/sc_core.py new file mode 100644 index 000000000000..8f3001f25730 --- /dev/null +++ b/jax/_src/pallas/mosaic/sc_core.py @@ -0,0 +1,387 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains SparseCore-specific Pallas abstractions.""" + +from __future__ import annotations + +import collections +from collections.abc import Sequence +import dataclasses +import math +from typing import Any, TypeAlias + +import jax +from jax._src import core as jax_core +from jax._src import state +from jax._src import tree_util +from jax._src.pallas import core as pallas_core +from jax._src.pallas import primitives as pallas_primitives +from jax._src.pallas.mosaic import core as tpu_core +from jax._src.pallas.mosaic import tpu_info +import jax.numpy as jnp + + +Tiling: TypeAlias = Sequence[Sequence[int]] + + +@dataclasses.dataclass(frozen=True) +class MemoryRef(pallas_core.MemoryRef): + """A MemoryRef for SparseCore.""" + + tiling: Tiling | None = None + + def __init__( + self, + shape: Sequence[int], + dtype: jax.typing.DTypeLike, + memory_space: tpu_core.MemorySpace, + tiling: Tiling | None = None, + ): + super().__init__(jax_core.ShapedArray(shape, dtype), memory_space) + + for tile in tiling or (): + if len(tile) > len(shape): + raise ValueError( + f"Tile rank must not exceed shape rank: {tile=} vs {shape=}" + ) + + object.__setattr__(self, "tiling", tiling) + + def get_ref_aval(self) -> state.TransformedRef | state.AbstractRef: + # TODO(sharadmv): Clean this up. ShapedArrayWithMemorySpace fails when we + # try to apply JAX ops to it. + return AbstractRef(self.inner_aval, self.memory_space, self.tiling) + + +class AbstractRef(state.AbstractRef): + """An AbstractRef for SparseCore.""" + + tiling: Tiling | None = None + + def __init__( + self, + aval: jax_core.AbstractValue, + memory_space: tpu_core.MemorySpace, + tiling: Tiling | None, + ): + super().__init__(aval, memory_space) + + self.tiling = tiling + + def update( # type: ignore[override] + self, + inner_aval: Any | None = None, + memory_space: Any | None = None, + tiling: Tiling | None = None, + ) -> AbstractRef: + return AbstractRef( + inner_aval if inner_aval is not None else self.inner_aval, + memory_space if memory_space is not None else self.memory_space, + tiling if tiling is not None else self.tiling, + ) + + +@dataclasses.dataclass +class BlockSpec(pallas_core.BlockSpec): + """A BlockSpec for SparseCore. + + Attributes: + indexed_by: The optional index of a parameter to use as the indexer. If set, + the pipeline emitter will issue and indirect stream indexing into the + value of this parameter as part of the pipeline. + indexed_dim: The dimension to index into. Optional unless ``indexed_by`` is + set. + + See also: + :class:`jax.experimental.pallas.BlockSpec` + """ + + indexed_by: int | None = None + indexed_dim: int | None = None + + def __post_init__(self): + if (self.indexed_by is None) != (self.indexed_dim is None): + raise ValueError( + "indexed_by and indexed_dim must both be set or both unset" + ) + + def to_block_mapping( + self, + origin: pallas_core.OriginStr, + array_aval: jax_core.ShapedArray, + *, + index_map_avals: Sequence[jax_core.AbstractValue], + index_map_tree: tree_util.PyTreeDef, + grid: pallas_core.GridMappingGrid, + vmapped_dims: tuple[int, ...], + debug: bool = False, + ) -> BlockMapping: + bm = super().to_block_mapping( + origin, + array_aval, + index_map_avals=index_map_avals, + index_map_tree=index_map_tree, + grid=grid, + vmapped_dims=vmapped_dims, + debug=debug, + ) + return BlockMapping( + **{f.name: getattr(bm, f.name) for f in dataclasses.fields(bm)}, + indexed_by=self.indexed_by, + indexed_dim=self.indexed_dim, + ) + + +@dataclasses.dataclass(frozen=True) +class BlockMapping(pallas_core.BlockMapping): + indexed_by: int | None = None + indexed_dim: int | None = None + + +def get_sparse_core_info() -> tpu_info.SparseCoreInfo: + """Returns the SparseCore information for the current device.""" + return tpu_info.get_tpu_info().sparse_core or tpu_info.SparseCoreInfo( + num_cores=0, num_subcores=0, num_lanes=0, dma_granule_size_bytes=0, + ) + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class ScalarSubcoreMesh: + axis_name: str + num_cores: int + + @property + def backend(self) -> str: + return "mosaic_tpu" + + @property + def shape(self): + return collections.OrderedDict(core=self.num_cores) + + def discharges_effect(self, effect): + del effect # Unused. + return False + + +def gather_global_allocations(jaxpr): + + def _gather_from_eqns(*, eqn=None, jaxpr=None): + if eqn is not None: + if eqn.primitive is pallas_primitives.get_global_p: + what = eqn.params["what"] + yield pallas_core.MemoryRef(what.inner_aval, what.memory_space) + for subjaxpr in jax_core.jaxprs_in_params(eqn.params): + yield from _gather_from_eqns(jaxpr=subjaxpr) + else: + for eqn in jaxpr.eqns: + yield from _gather_from_eqns(eqn=eqn) + + allocations = collections.defaultdict(list) + for memref in _gather_from_eqns(jaxpr=jaxpr): + allocations[memref].append(memref) + return allocations + + +def _scalar_subcore_mesh_discharge_rule( + in_avals, + out_avals, + *args, + mesh, + jaxpr, + compiler_params, + interpret, + debug, + cost_estimate, + name, + metadata, +): + if not isinstance(mesh, ScalarSubcoreMesh): + raise TypeError(f"Mesh must be a ScalarSubcoreMesh, got {type(mesh)}") + assert len(mesh.shape) == 1 + sc_info = get_sparse_core_info() + if mesh.num_cores > (num_expected := sc_info.num_cores): + raise ValueError( + f"Mesh has {mesh.num_cores} cores, but the current TPU chip has only" + f" {num_expected} SparseCores" + ) + if compiler_params is None: + compiler_params = tpu_core.CompilerParams() + if compiler_params.dimension_semantics is not None: + raise ValueError("ScalarSubcoreMesh does not support dimension_semantics=") + sa_avals = [a for a in in_avals if isinstance(a, jax_core.ShapedArray)] + if sa_avals: + raise NotImplementedError( + f"Cannot close over values in core_map: {sa_avals}" + ) + return pallas_core.default_mesh_discharge_rule( + in_avals, + out_avals, + *args, + mesh=mesh, + jaxpr=jaxpr, + compiler_params=dataclasses.replace( + compiler_params, + dimension_semantics=["core_parallel"], + kernel_type=tpu_core.KernelType.SC_SCALAR_SUBCORE, + ), + interpret=interpret, + debug=debug, + cost_estimate=cost_estimate, + name=name, + memory_space=tpu_core.MemorySpace.HBM, + metadata=metadata, + scratch_shapes=tree_util.tree_leaves(gather_global_allocations(jaxpr)), + ) + + +pallas_core._core_map_mesh_rules[ScalarSubcoreMesh] = ( + _scalar_subcore_mesh_discharge_rule +) + +def _get_num_cores() -> int: + """Returns the number of cores for the current SparseCore.""" + return get_sparse_core_info().num_cores + +def _get_num_subcores() -> int: + """Returns the number of subcores for the current SparseCore.""" + return get_sparse_core_info().num_subcores + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class VectorSubcoreMesh: + core_axis_name: str + subcore_axis_name: str + num_cores: int = dataclasses.field(default_factory=_get_num_cores) + num_subcores: int = dataclasses.field( + default_factory=_get_num_subcores, init=False + ) + + def __post_init__(self): + sc_info = get_sparse_core_info() + if self.num_cores > (num_expected := sc_info.num_cores): + raise ValueError( + f"Mesh has {self.num_cores} cores, but the current TPU chip has only" + f" {num_expected} SparseCores" + ) + if self.num_subcores != sc_info.num_subcores: + raise ValueError( + f"Mesh has {self.num_subcores} subcores, but the current TPU chip has" + f" only {num_expected} subcores" + ) + + @property + def backend(self) -> str: + return "mosaic_tpu" + + @property + def shape(self): + return collections.OrderedDict( + core=self.num_cores, subcore=self.num_subcores) + + def discharges_effect(self, effect): + del effect # Unused. + return False + + +def _vector_subcore_mesh_discharge_rule( + in_avals, + out_avals, + *args, + mesh, + jaxpr, + compiler_params, + interpret, + debug, + cost_estimate, + name, + metadata, +): + if not isinstance(mesh, VectorSubcoreMesh): + raise TypeError(f"Mesh must be a VectorSubcoreMesh, got {type(mesh)}") + assert len(mesh.shape) == 2 + sc_info = get_sparse_core_info().num_cores + if mesh.num_cores > (num_expected := sc_info): + raise ValueError( + f"Mesh has {mesh.num_cores} cores, but the current TPU chip has only" + f" {num_expected} SparseCores" + ) + if compiler_params is None: + compiler_params = tpu_core.CompilerParams() + if compiler_params.dimension_semantics is not None: + raise ValueError("VectorSubcoreMesh does not support dimension_semantics=") + return pallas_core.default_mesh_discharge_rule( + in_avals, + out_avals, + *args, + mesh=mesh, + jaxpr=jaxpr, + compiler_params=dataclasses.replace( + compiler_params, + dimension_semantics=["core_parallel", "subcore_parallel"], + kernel_type=tpu_core.KernelType.SC_VECTOR_SUBCORE, + ), + interpret=interpret, + debug=debug, + cost_estimate=cost_estimate, + name=name, + memory_space=tpu_core.MemorySpace.HBM, + metadata=metadata, + scratch_shapes=tree_util.tree_leaves(gather_global_allocations(jaxpr)), + ) + + +pallas_core._core_map_mesh_rules[VectorSubcoreMesh] = ( + _vector_subcore_mesh_discharge_rule +) + + +# TODO(slebedev): Only keep the shapes which do not require unrolling. +SUPPORTED_VECTOR_SHAPES = collections.defaultdict(list) +for dtype in [jnp.int32, jnp.uint32, jnp.float32]: + SUPPORTED_VECTOR_SHAPES[jnp.dtype(dtype)].extend([ + # fmt: off + (8,), (16,), (32,), (64,), + (1, 8), (1, 16), + (2, 8), (2, 16), + (4, 8), (4, 16), + # fmt: on + ]) +for dtype in [jnp.int16, jnp.uint16, jnp.float16, jnp.bfloat16]: + SUPPORTED_VECTOR_SHAPES[jnp.dtype(dtype)].extend([ + # fmt: off + (16,), (32,), (64,), + (2, 8), (2, 16), + # fmt: on + ]) +for dtype in [jnp.float16, jnp.bfloat16]: + SUPPORTED_VECTOR_SHAPES[jnp.dtype(dtype)].extend([ + # fmt: off + (4, 8), (4, 16), + # fmt: on + ]) +for dtype in [jnp.int8, jnp.uint8]: + SUPPORTED_VECTOR_SHAPES[jnp.dtype(dtype)].extend([ + # fmt: off + (32,), (64,), + (4, 8), (4, 16), + # fmt: on + ]) + + +# Make sure all combinations are divisible by the vector register size. +supported_shapes: list[Any] = [] +for dtype, supported_shapes in SUPPORTED_VECTOR_SHAPES.items(): + for shape in supported_shapes: + assert (math.prod(shape) * dtype.itemsize) % 32 == 0 +del dtype, supported_shapes diff --git a/jax/_src/pallas/mosaic/sc_lowering.py b/jax/_src/pallas/mosaic/sc_lowering.py new file mode 100644 index 000000000000..2817ffb53860 --- /dev/null +++ b/jax/_src/pallas/mosaic/sc_lowering.py @@ -0,0 +1,1027 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Lowering for Pallas TPU SparseCore.""" + +from typing import Any, NoReturn, cast +from collections.abc import Sequence +import contextlib +import dataclasses +import functools + +from jax._src import api_util +from jax._src import core as jax_core +from jax._src import debugging +from jax._src import lax +from jax._src import linear_util as lu +from jax._src import mesh as mesh_lib +from jax._src import numpy as jnp +from jax._src import source_info_util +from jax._src import state +from jax._src import util +from jax._src import tree_util +from jax._src.interpreters import mlir +from jax._src.interpreters import partial_eval as pe +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import arith +from jax._src.lib.mlir.dialects import func +from jax._src.lib.mlir.dialects import memref +from jax._src.lib.mlir.dialects import vector +from jax._src.pallas import core as pallas_core +from jax._src.pallas import primitives as pallas_primitives +from jax._src.pallas.mosaic import core as tpu_core +from jax._src.pallas.mosaic import lowering as tc_lowering +from jax._src.pallas.mosaic import primitives as tpu_primitives +from jax._src.pallas.mosaic import sc_core +from jax._src.state import discharge as state_discharge +from jax._src.state import indexing +from jax._src.state import primitives as state_primitives +from jax.experimental.mosaic.dialects import tpu + + +map, unsafe_map = util.safe_map, map +zip, unsafe_zip = util.safe_zip, zip + + +MemorySpace = tpu_core.MemorySpace + + +class GlobalAllocations: + """Hands out global allocations sequentially during lowering.""" + def __init__(self, allocations: dict[pallas_core.MemoryRef, list[ir.Value]]): + self._allocations = {k: list(v) for k, v in allocations.items()} + + def next_allocation(self, what: state.AbstractRef | pallas_core.TransformedRef) -> Any: + """Returns the next available allocation for the given shape.""" + what = pallas_core.MemoryRef(what.inner_aval, what.memory_space) + if what not in self._allocations: + raise LookupError(f"No allocations are available for {what}.") + if not self._allocations[what]: + raise LookupError(f"No more allocations available for {what}.") + return self._allocations[what].pop() + + @contextlib.contextmanager + def verify_usage(self): + """Scope that verifies all allocations are used.""" + try: + yield + finally: + unused = [k for k, v in self._allocations.items() if v] + if unused: + raise AssertionError(f"Some allocations unused ({unused}).") + + +@dataclasses.dataclass +class ScLoweringContext(tc_lowering.LoweringContext): + """Lowering context for SparseCore.""" + global_allocations: GlobalAllocations + +LoweringRuleContext = tc_lowering.LoweringRuleContext + +_transform_ref = tc_lowering._transform_ref +_dtype_to_ir_type = tc_lowering._dtype_to_ir_type + +# pylint: disable=protected-access + + +def dynamic_shape_replacement_fn(x): + return x + + +def lower_jaxpr_to_module( + lowering_context: mlir.LoweringRuleContext, + grid_mapping: pallas_core.GridMapping, + jaxpr: jax_core.Jaxpr, + *, + dimension_semantics: Sequence[tpu_core.DimensionSemantics] | None, + kernel_type: tpu_core.KernelType, + mesh: mesh_lib.Mesh | None = None, + dynamic_shape_replacement_enabled: bool = False, +) -> ir.Module: + """Lowers a Jaxpr to a Mosaic SparseCore module.""" + if dynamic_shape_replacement_enabled: + raise NotImplementedError( + "Dynamic shape replacement is not supported for SparseCore." + ) + if ( + lowering_context.is_forward_compat() + or tc_lowering.is_cloud_tpu_older_than( + 2026, 1, 18, lowering_context.module_context.backend + ) + ) and not grid_mapping.grid: + # TODO(slebedev): Remove this branch after Jan 18th 2026. + index_map_avals, index_map_tree = tree_util.tree_flatten( + ((jax_core.ShapedArray((), jnp.int32),), {}) + ) + if grid_mapping.num_index_operands: + raise ValueError( + "Index operands not supported for SparseCore when grid is empty." + ) + new_grid = (1,) + new_block_mappings = [] + for bm in grid_mapping.block_mappings: + + def new_index_map(*args, bm=bm): + return jax_core.eval_jaxpr( + # Discard the leading grid index. + bm.index_map_jaxpr.jaxpr, + bm.index_map_jaxpr.consts, + *args[1:], + ) + + debug_info = bm.index_map_jaxpr.jaxpr.debug_info + if debug_info.arg_names is not None: + debug_info = debug_info._replace( + arg_names=("idx", *debug_info.arg_names) + ) + flat_fun, _ = api_util.flatten_fun( + lu.wrap_init(new_index_map, debug_info=debug_info), index_map_tree + ) + with pallas_core.tracing_grid_env(new_grid, grid_mapping.vmapped_dims): + index_map_jaxpr, _, index_map_jaxpr_consts = pe.trace_to_jaxpr_dynamic( + flat_fun, index_map_avals + ) + new_block_mappings.append( + bm.replace( + index_map_jaxpr=jax_core.ClosedJaxpr( + index_map_jaxpr, index_map_jaxpr_consts + ) + ) + ) + + grid_mapping = grid_mapping.replace( + grid=new_grid, + index_map_avals=index_map_avals, + index_map_tree=index_map_tree, + block_mappings=tuple(new_block_mappings), + ) + dimension_semantics = ("arbitrary",) + + for bm in grid_mapping.block_mappings: + for bd in bm.block_shape: + if not isinstance(bd, pallas_core.Blocked): + raise NotImplementedError( + "Unsupported block dimension type: " + f"{type(bd)} for block shape: {bm.block_shape}" + ) + + backend = lowering_context.module_context.get_backend(optional=True) + mosaic_grid_mapping = MosaicGridMapping( + jaxpr, grid_mapping, dimension_semantics, mesh=mesh + ) + m = ir.Module.create() + sym_tab = ir.SymbolTable(m.operation) + func_op = lower_jaxpr_to_func( + jaxpr, + name="main", + kernel_type=kernel_type, + mosaic_grid_mapping=mosaic_grid_mapping, + forward_compatible=lowering_context.is_forward_compat(), + backend=backend, + ) + m.body.append(func_op) + sym_tab.insert(func_op) + func_op.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get( + mosaic_grid_mapping.grid + ) + func_op.attributes["dimension_semantics"] = ( + mosaic_grid_mapping.get_dimension_semantics() + ) + if not mosaic_grid_mapping.grid: + # No need for "window_params" if the grid is empty. + return m + window_params = [] + for i, bm in enumerate(grid_mapping.block_mappings): + func_name = f"transform_{i}" + mlir_func = tc_lowering.lower_jaxpr_to_transform_func( + bm.index_map_jaxpr.jaxpr, + bm.block_aval, + name=func_name, + mosaic_grid_mapping=mosaic_grid_mapping, + kernel_type=kernel_type, + forward_compatible=lowering_context.is_forward_compat(), + backend=backend, + ) + assert mlir_func.verify(), mlir_func + m.body.append(mlir_func) + sym_tab.insert(mlir_func) + + block_shape = list(pallas_core._get_block_shape(bm.block_shape)) + block_params = dict( + window_bounds=ir.DenseI64ArrayAttr.get(block_shape), + transform_indices=ir.FlatSymbolRefAttr.get(func_name), + ) + window_params.append(ir.DictAttr.get(block_params)) + func_op.attributes["window_params"] = ir.ArrayAttr.get(window_params) + return m + + +@dataclasses.dataclass(init=False) +class MosaicGridMapping(tc_lowering.MosaicGridMapping): + """Abstracts a grid mapping for Mosaic SparseCore.""" + + def __init__( + self, + jaxpr: jax_core.Jaxpr, + grid_mapping: pallas_core.GridMapping, + dimension_semantics: Sequence[tpu_core.DimensionSemantics] | None, + mesh: mesh_lib.Mesh | None, + ): + for bm in grid_mapping.block_mappings: + shape = pallas_core._get_block_shape(bm.block_shape) + if len(shape) > 1 and shape[-1] % 8: + raise ValueError( + f"The minormost dimension of a block for {bm.origin} must be a" + f" multiple of 8, got shape {shape}" + ) + if any( + isinstance(var.aval, sc_core.AbstractRef) + for var in jaxpr.invars[grid_mapping.slice_scratch_ops] + ): + # TODO(slebedev): Support tiling annotations for kernel operands. + raise NotImplementedError( + "`plsc.MemoryRef`s are not supported as scratch operands to the" + " kernel. Allocate them in the kernel body via `pl.run_scoped`." + ) + super().__init__( + jaxpr, + grid_mapping, + dimension_semantics, + mesh, + dynamic_shape_replacement_fn=dynamic_shape_replacement_fn, + ) + + +def lower_jaxpr_to_func( + jaxpr: jax_core.Jaxpr, + *, + name: str, + kernel_type: tpu_core.KernelType, + mosaic_grid_mapping: MosaicGridMapping, + forward_compatible: bool, + backend: Any | None, +) -> func.FuncOp: + """Lowers a Jaxpr to a Mosaic SparseCore function.""" + num_grid = len(mosaic_grid_mapping.grid_types) + num_scalar_prefetch = len(mosaic_grid_mapping.scalar_prefetch_types) + if num_scalar_prefetch: + raise NotImplementedError("Scalar prefetch not supported.") + num_scratch = len(mosaic_grid_mapping.scratch_types) + arg_types = [ + *mosaic_grid_mapping.grid_types, + *mosaic_grid_mapping.scalar_prefetch_types, + *mosaic_grid_mapping.operand_types, + *mosaic_grid_mapping.scratch_types, + ] + arg_block_shapes = [ + *mosaic_grid_mapping.scalar_prefetch_block_shapes, + *mosaic_grid_mapping.operand_block_shapes, + *mosaic_grid_mapping.scratch_block_shapes, + ] + + def body_func(*args: ir.Value): + grid_indices, scalar_prefetch, operands_and_scratch = util.split_list( + args, [num_grid, num_scalar_prefetch] + ) + grid_indices = mosaic_grid_mapping.get_grid_indices( + grid_indices, maybe_include_mapped_dims=False + ) + jaxpr_indices = tuple( + idx + for i, idx in enumerate(grid_indices) + if i not in mosaic_grid_mapping.vmapped_dims + ) + + allocations = sc_core.gather_global_allocations(jaxpr) + flat_allocations, allocations_tree = tree_util.tree_flatten(allocations) + allocation_operands = operands_and_scratch[ + len(operands_and_scratch) - len(flat_allocations):] + allocations = allocations_tree.unflatten(allocation_operands) + lowering_context = ScLoweringContext( + mosaic_grid_mapping.grid, # type: ignore + mosaic_grid_mapping.grid_names, + mosaic_grid_mapping.vmapped_dims, + jaxpr_indices, + arg_block_shapes, + source_info_util.NameStack(), + mesh_context=mosaic_grid_mapping.mesh_info, + traceback_caches=mlir.TracebackCaches(), + kernel_type=kernel_type, + forward_compatible=forward_compatible, + backend=backend, + dynamic_shape_replacement_fn=dynamic_shape_replacement_fn, + global_allocations=GlobalAllocations(allocations), + ) + with lowering_context.global_allocations.verify_usage(): + return tc_lowering.jaxpr_subcomp( + lowering_context, jaxpr, *scalar_prefetch, *operands_and_scratch + ) + + body = func.FuncOp.from_py_func(*arg_types, name=name)(body_func) + func_op = cast(func.FuncOp, body.func_op) + func_op.attributes["tpu.core_type"] = ir.Attribute.parse( + f"#tpu.core_type<{kernel_type.name.lower()}>" + ) + func_op.attributes["scratch_operands"] = ir.IntegerAttr.get( + ir.IntegerType.get_signless(64), num_scratch + ) + arg_attrs = [ir.DictAttr.get({})] * num_grid + for arg, bm in zip( + func_op.arguments[num_grid : len(func_op.arguments) - num_scratch], + mosaic_grid_mapping.block_mappings, + ): + d = {} + if ( + str(arg.type.memory_space) == "#tpu.memory_space" + or str(arg.type.memory_space) == "#tpu.memory_space" + ): + d["sc.persistent"] = ir.UnitAttr.get() + if isinstance(bm, sc_core.BlockMapping) and bm.indexed_by is not None: + d["sc.indexed_by"] = mlir.i32_attr(bm.indexed_by) + d["sc.indexed_dim"] = mlir.i32_attr(bm.indexed_dim) + arg_attrs.append(ir.DictAttr.get(d)) + arg_attrs.extend([ir.DictAttr.get({})] * num_scratch) + + func_op.arg_attrs = ir.ArrayAttr.get(arg_attrs) + try: + func_op.verify() + except Exception as e: + raise ValueError( + f"Body failed to verify: {func_op}.\nThis is an internal error." + " Please report a bug at:" + " https://github.com/jax-ml/jax/issues/new?assignees=sharadmv." + ) from e + return func_op + + +register_lowering_rule = functools.partial( + tc_lowering.register_lowering_rule, + kernel_types=( + tpu_core.KernelType.SC_SCALAR_SUBCORE, + tpu_core.KernelType.SC_VECTOR_SUBCORE, + ), +) + +@register_lowering_rule(pallas_primitives.get_global_p) +def _lower_get_global(ctx: LoweringRuleContext, *, what): + lctx = ctx.lowering_context + assert isinstance(lctx, ScLoweringContext) + return lctx.global_allocations.next_allocation(what) + + +@register_lowering_rule(state_primitives.get_p) +def _get_lowering_rule(ctx: LoweringRuleContext, ref, *flat_transforms, tree): + return _load_lowering_rule(ctx, ref, None, *flat_transforms, tree=tree) + + +def _load_lowering_rule( + ctx: LoweringRuleContext, ref, mask, *flat_transforms, tree +): + ref_aval, *_flat_index_avals = ctx.avals_in + assert isinstance(ref_aval, state.AbstractRef) + [out_aval] = ctx.avals_out + assert isinstance(out_aval, jax_core.ShapedArray) + + if ( + (ref_memory_space := ref_aval.memory_space) is MemorySpace.HBM or + ref_memory_space is MemorySpace.VMEM_SHARED + ): + raise NotImplementedError( + f"Get does not support loading from {ref_memory_space.name}." + " Copy the data to a core-local memory space, e.g. VMEM," + " via `pltpu.async_copy`." + ) + + transforms = list(tree_util.tree_unflatten(tree, flat_transforms)) + if not transforms or not isinstance(transforms[-1], indexing.NDIndexer): + ref_shape = state.get_transforms_shape(transforms, ref_aval.shape) + transforms.append(indexing.NDIndexer.make_trivial_indexer(ref_shape)) + *prev_transforms, indexer = transforms + ref_block_shape, *_ = ctx.block_shapes + ref, ref_block_shape = _transform_ref( + ref, ref_aval.dtype, ref_block_shape, prev_transforms + ) + starts, sizes, strides, _, _ = tc_lowering._indexer_to_start_size_stride( + indexer, ref_block_shape, cast_to_index=True + ) + del sizes # Currently unused. + if not all(s == 1 for s in strides): + raise NotImplementedError( + "Get only supports slices with stride 1, got {strides}" + ) + + if not out_aval.ndim: + if mask is not None: + raise NotImplementedError("Get does not support masked scalar loads") + return memref.load(ref, starts) + + if ref_memory_space is MemorySpace.SMEM: + raise NotImplementedError("Get can only load scalars from SMEM") + else: + _check_aval_is_supported("Get", out_aval) + + vec_type = ir.VectorType.get( + out_aval.shape, _dtype_to_ir_type(ref_aval.dtype) + ) + return tpu.vector_load(vec_type, ref, indices=starts, strides=[], mask=mask) + + +@register_lowering_rule(state_primitives.swap_p) +def _swap_lowering_rule( + ctx: LoweringRuleContext, ref, val, *flat_transforms, tree +): + return _store_lowering_rule( + ctx, ref, val, None, *flat_transforms, tree=tree, add=False + ) + + +def _store_lowering_rule( + ctx: LoweringRuleContext, ref, val, mask, *flat_transforms, tree, add +): + ref_aval, _, *_flat_index_avals = ctx.avals_in + assert isinstance(ref_aval, state.AbstractRef) + [out_aval] = ctx.avals_out + assert isinstance(out_aval, jax_core.ShapedArray) + + if ( + (ref_memory_space := ref_aval.memory_space) is MemorySpace.HBM or + ref_memory_space is MemorySpace.VMEM_SHARED + ): + raise NotImplementedError( + f"Swap does not support storing to {ref_memory_space.name}." + " Copy the data to a core-local memory space, e.g. VMEM," + " via `pltpu.async_copy`." + ) + + transforms = list(tree_util.tree_unflatten(tree, flat_transforms)) + if not transforms or not isinstance(transforms[-1], indexing.NDIndexer): + ref_shape = state.get_transforms_shape(transforms, ref_aval.shape) + transforms.append(indexing.NDIndexer.make_trivial_indexer(ref_shape)) + *prev_transforms, indexer = transforms + ref_block_shape, *_ = ctx.block_shapes + ref, ref_block_shape = _transform_ref( + ref, ref_aval.dtype, ref_block_shape, prev_transforms + ) + starts, sizes, strides, _, _ = tc_lowering._indexer_to_start_size_stride( + indexer, ref_block_shape, cast_to_index=True + ) + del sizes # Currently unused. + if not all(s == 1 for s in strides): + raise NotImplementedError( + "Swap only supports slices with stride 1, got {strides}" + ) + + if not out_aval.ndim: + if mask is not None: + raise NotImplementedError("Swap does not support masked scalar stores") + if add: + # TODO(slebedev): We can use memref.atomic_rmw here, but the SC compiler + # doesn't support it yet. + raise NotImplementedError("Swap does not support atomic scalar adds") + old_val = memref.load(ref, starts) + memref.store(val, ref, starts) + return old_val + + if ref_memory_space is MemorySpace.SMEM: + raise NotImplementedError("Swap can only store scalars to SMEM") + else: + _check_aval_is_supported("Swap", out_aval) + + vec_type = ir.VectorType.get( + out_aval.shape, _dtype_to_ir_type(ref_aval.dtype) + ) + old_val = tpu.vector_load(vec_type, ref, starts, strides=[], mask=mask) + tpu.vector_store(val, ref, starts, strides=[], mask=mask, add=add) + return old_val + + +@register_lowering_rule(lax.iota_p, + kernel_types=[tpu_core.KernelType.SC_VECTOR_SUBCORE]) +def _iota_lowering_rule_sc(ctx: LoweringRuleContext, dtype, shape, dimension, + sharding): + sc_info = sc_core.get_sparse_core_info() + if shape != (sc_info.num_lanes,): + raise ValueError( + f"Unsupported iota shape for SC vector subcore. Got {shape}, supported " + f"shape is {(sc_info.num_lanes,)}." + ) + [out_aval] = ctx.avals_out + out_type = ir.VectorType.get( + [sc_info.num_lanes], _dtype_to_ir_type(out_aval.dtype) + ) + return tpu.iota(out_type, dimensions=[dimension]) + + +def _check_aval_is_supported(caller: str, aval: jax_core.ShapedArray) -> None: + if aval.shape in sc_core.SUPPORTED_VECTOR_SHAPES.get(aval.dtype, []): + return + supported_shapes = ", ".join( + map(repr, sc_core.SUPPORTED_VECTOR_SHAPES[aval.dtype]) + ) + if not supported_shapes: + raise NotImplementedError(f"{caller} does not support {aval.dtype} arrays") + else: + raise NotImplementedError( + f"{caller} only supports {aval.dtype} arrays of shapes" + f" [{supported_shapes}], got {aval.shape}" + ) + + +@register_lowering_rule(debugging.debug_print_p) +def _debug_print_lowering_rule( + ctx: LoweringRuleContext, + *args, + fmt: str, + ordered, + partitioned, + in_tree, + static_args, + np_printoptions, + has_placeholders, + logging_record, +): + del partitioned, np_printoptions, in_tree, static_args + def fail(reason: str) -> NoReturn: + raise NotImplementedError( + f"pl.debug_print() {reason} when lowering to SparseCore" + ) + + if ordered: + fail("does not support ordered print") + if has_placeholders: + fail("does not support placeholders") + + match args: + case []: + tpu.log(inputs=[], tag=fmt) + case [arg] if isinstance(arg.type, ir.MemRefType): + tpu.log_buffer(arg, ctx.avals_in[0].shape, fmt) # pytype: disable=attribute-error + case [arg]: + tpu.log(inputs=[arg], tag=fmt) + case _: + fail("does not support multiple inputs") + return [] + + +def _memref_memory_space(ref: ir.Value) -> MemorySpace: + match str(ir.MemRefType(ref.type).memory_space): + case "#tpu.memory_space": + return MemorySpace.HBM + case "#tpu.memory_space": + return MemorySpace.VMEM + case "#tpu.memory_space": + return MemorySpace.VMEM_SHARED + case "#tpu.memory_space": + return MemorySpace.SMEM + case _: + raise LookupError(f"Unknown memory space: {ref.type}") + + +def _prepare_dma_refs( + src_ref, + src_transforms, + dst_ref, + dst_transforms, + src_aval, + dst_aval, + is_add: bool = False, +): + """Prepares the DMA source and destination references.""" + src_memory_space = _memref_memory_space(src_ref) + dst_memory_space = _memref_memory_space(dst_ref) + match src_memory_space, dst_memory_space: + case MemorySpace.HBM | MemorySpace.VMEM_SHARED, MemorySpace.VMEM: + if _has_indirect_offsets(dst_transforms): + raise ValueError( + "Only the source ref can be indexed when doing a gather via" + " `pltpu.async_copy`" + ) + dst_ref, _ = _transform_ref( + dst_ref, dst_aval.dtype, dst_aval.shape, dst_transforms + ) + dst_ref_shape = ir.MemRefType(dst_ref.type).shape + indirect_offsets, src_transforms = _extract_indirect_offsets( + src_transforms, tuple(dst_ref_shape) + ) + src_ref, _ = _transform_ref( + src_ref, src_aval.dtype, src_aval.shape, src_transforms + ) + indirect_offsets_ref_str = "src_ref" + case MemorySpace.VMEM, MemorySpace.HBM | MemorySpace.VMEM_SHARED: + if _has_indirect_offsets(src_transforms): + raise ValueError( + "Only the destination ref can be indexed when doing a scatter via" + " `pltpu.async_copy`" + ) + src_ref, _ = _transform_ref( + src_ref, src_aval.dtype, src_aval.shape, src_transforms + ) + src_ref_shape = ir.MemRefType(src_ref.type).shape + indirect_offsets, dst_transforms = _extract_indirect_offsets( + dst_transforms, tuple(src_ref_shape) + ) + dst_ref, _ = _transform_ref( + dst_ref, dst_aval.dtype, dst_aval.shape, dst_transforms + ) + indirect_offsets_ref_str = "dst_ref" + case _: # Indirect DMA is not supported. + if ( + # fmt: off + _has_indirect_offsets(src_transforms) or + _has_indirect_offsets(dst_transforms) + # fmt: on + ): + raise NotImplementedError( + "Scatter/gather via `pltpu.async_copy` from" + f" {src_memory_space.name} to {dst_memory_space.name} is not" + " supported" + ) + if is_add: + raise ValueError( + "DMAs with `add=True` are only supported between VMEM and " + f"HBM/VMEM_SHARED. " + f"Got (src, dst)={(src_aval.memory_space, dst_aval.memory_space)}" + ) + src_ref, _ = _transform_ref( + src_ref, src_aval.dtype, src_aval.shape, src_transforms + ) + dst_ref, _ = _transform_ref( + dst_ref, dst_aval.dtype, dst_aval.shape, dst_transforms + ) + indirect_offsets = None + indirect_offsets_ref_str = "" + if is_add and indirect_offsets is None: + raise NotImplementedError( + "DMAs with `add=True` must (for now) specify offsets of the" + " majormost dimension. You can do this by writing" + " `pltpu.async_copy(..., {ref}={ref}.at[jnp.arange(vec_dim)], ...)`" + " or `pltpu.async_copy(..., {ref}={ref}.at[indices_ref]," + " ...)`.".format(ref=indirect_offsets_ref_str) + ) + return src_ref, dst_ref, indirect_offsets + + +# TODO(slebedev): Use the TC rule once we align the ``LoweringRuleContext`` +# with the TC lowering. +@register_lowering_rule(tpu_primitives.dma_start_p) +def _dma_start_lowering_rule( + ctx: LoweringRuleContext, + *args, + tree, + device_id_type: pallas_primitives.DeviceIdType, + priority: int, + add: bool, +): + ( + src_ref, + src_transforms, + dst_ref, + dst_transforms, + sem, + sem_transforms, + src_sem, + src_sem_transforms, + device_id, + ) = tpu_primitives._dma_unflatten(tree, args) + src_aval, _, dst_aval, _, sem_aval, _, src_sem_aval, _, _ = ( + tpu_primitives._dma_unflatten(tree, ctx.avals_in) + ) + + src_ref, dst_ref, indirect_offsets = _prepare_dma_refs( + src_ref, src_transforms, dst_ref, dst_transforms, src_aval, dst_aval, add + ) + if add and indirect_offsets is None: + # TODO: Support regular DMA with add=True. + raise NotImplementedError( + "DMAs with `add=True` must (for now) specify offsets of the majormost " + "dimension. You can do this by writing " + "`pltpu.async_copy(..., dst_ref=ref.at[jnp.arange(vec_dim)], ...)` or " + "`pltpu.async_copy(..., dst_ref=ref.at[iota_ref], ...)`." + ) + sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, sem_transforms) + if src_sem is not None: + src_sem, _ = _transform_ref( + src_sem, src_sem_aval.dtype, src_sem_aval.shape, src_sem_transforms + ) + + # If not ``None``, we lower to an indirect DMA instead. + if indirect_offsets is None: + if device_id is not None: + device_id, _ = tc_lowering._device_id_to_logical( + ctx, device_id, device_id_type + ) + tpu.enqueue_dma( + src_ref, + dst_ref, + sem, + source_semaphore=src_sem, + device_id=device_id, + priority=priority, + ) + return [] + + if device_id is not None: + raise NotImplementedError( + "Scatter/gather to or from a remote device via `pltpu.async_copy` is" + " not supported" + ) + del priority # Unused by indirect DMAs. + tpu.enqueue_indirect_dma(src_ref, dst_ref, indirect_offsets, sem, add=add) + return [] + + +# TODO(slebedev): Use the TC rule once we align the ``LoweringRuleContext`` +# with the TC lowering. +@register_lowering_rule(tpu_primitives.dma_wait_p) +def _dma_wait_lowering_rule( + ctx: LoweringRuleContext, + *args, + tree, + device_id_type: pallas_primitives.DeviceIdType, +): + ( + src_ref, + src_transforms, + dst_ref, + dst_transforms, + sem, + sem_transforms, + _, + _, + device_id, + ) = tpu_primitives._dma_unflatten(tree, args) + src_aval, _, dst_aval, _, sem_aval, _, _, _, _ = ( + tpu_primitives._dma_unflatten(tree, ctx.avals_in) + ) + + src_ref, dst_ref, indirect_offsets = _prepare_dma_refs( + src_ref, src_transforms, dst_ref, dst_transforms, src_aval, dst_aval, + ) + sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, sem_transforms) + + # If not ``None``, we lower to an indirect DMA instead of a regular DMA. + if indirect_offsets is None: + if device_id is not None: + device_id, _ = tc_lowering._device_id_to_logical( + ctx, device_id, device_id_type + ) + tpu.wait_dma2(sem, src_ref, dst_ref, device_id=device_id) + return [] + + if device_id is not None: + raise NotImplementedError( + "Scatter/gather to or from a remote device via `pltpu.async_copy` is" + " not supported" + ) + tpu.wait_indirect_dma(sem, src_ref, dst_ref) + return [] + + +def _extract_indirect_offsets_from_indexer( + indexer: indexing.NDIndexer, expected_shape: tuple[int, ...] | None = None +) -> ir.Value | None: + offsets_ref: Any # Make mypy happy. + match indexer.indices: + case [ir.Value() as offsets, *_] if ( + # fmt: off + isinstance(offsets.type, ir.MemRefType) or + isinstance(offsets.type, ir.VectorType) + ): # fmt: on + shape = indexer.get_indexer_shape() + if expected_shape is not None and shape != expected_shape: + raise NotImplementedError( + "The indexer shape in scatter/gather via `pltpu.async_copy` does" + f" not match the expected shape. Want: {expected_shape}, got:" + f" {shape}." + ) + case [state.TransformedRef() as offsets_ref, *_]: + offsets_type = ir.MemRefType(offsets_ref.ref.type) + if offsets_type.element_type != ir.IntegerType.get_signless(32): + raise NotImplementedError( + "Only int32 indices are supported by scatter/gather via" + " `pltpu.async_copy` with a dynamically-shaped indexer" + ) + offsets, _ = _transform_ref( + offsets_ref.ref, + jnp.int32, + offsets_type.shape, # The shape before the indexing. + offsets_ref.transforms, + ) + case _: + return None + + if isinstance(offsets.type, ir.MemRefType): + offsets_memory_space = _memref_memory_space(offsets) + if offsets_memory_space is not MemorySpace.VMEM: + raise NotImplementedError( + "Indices for scatter/gather via `pltpu.async_copy` must be in VMEM," + f" got {offsets_memory_space.name}" + ) + if not state_discharge._is_trivial_indexer( + indexing.NDIndexer(indexer.indices[1:], indexer.shape[1:], ()) + ): + # TODO(slebedev): Consider lifting this restriction. + raise NotImplementedError( + "Only indexing along the major dimension is supported in scatter/gather" + " via `pltpu.async_copy`" + ) + return offsets + + +def _extract_indirect_offsets( + transforms: Sequence[ir.Value], expected_shape: tuple[int, ...] +) -> tuple[ir.Value | None, Sequence[pallas_core.MemoryRefTransform]]: + for i, indexer in enumerate(transforms): + if not isinstance(indexer, indexing.NDIndexer): + continue + offsets = _extract_indirect_offsets_from_indexer(indexer, expected_shape) + if offsets is None: + continue + if i != len(transforms) - 1: + raise NotImplementedError( + "The indexed ref in scatter/gather via `pltpu.async_copy` cannot have" + " any transforms following the indexer" + ) + return offsets, transforms[:i] + + return None, transforms + + +def _has_indirect_offsets(transforms: Sequence[ir.Value]) -> bool: + return any( + _extract_indirect_offsets_from_indexer(indexer) is not None + for indexer in transforms + if isinstance(indexer, indexing.NDIndexer) + ) + + +@register_lowering_rule(pallas_primitives.run_scoped_p) +def _run_scoped_lowering_rule( + ctx: LoweringRuleContext, *consts, jaxpr, collective_axes +): + return tc_lowering._run_scoped_lowering_rule( + ctx, + *consts, + jaxpr=jaxpr, + collective_axes=collective_axes, + alloc_fn=_alloc_value, + ) + + +@register_lowering_rule( + lax.sort_p, kernel_types=[tpu_core.KernelType.SC_VECTOR_SUBCORE] +) +def _sort_lowering_rule( + ctx: LoweringRuleContext, *xs, dimension, is_stable, num_keys +): + del is_stable # Unused, always stable. + if dimension not in (0, -1): + raise ValueError(f"Unsupported dimension: {dimension}") + if num_keys != 1: + raise NotImplementedError("Multiple sort keys not supported") + sc_info = sc_core.get_sparse_core_info() + supported_shape = (sc_info.num_lanes,) + for i, aval in enumerate(ctx.avals_in): + if aval.shape != supported_shape: + raise NotImplementedError( + f"Unsupported shape for operand {i} of SC sort: Got {aval.shape}, " + f"expected {supported_shape}" + ) + keys = xs[0] + values = xs[1:] + mask_type = ir.VectorType.get( + [sc_info.num_lanes], ir.IntegerType.get_signless(1)) + mask = arith.constant(mask_type, ir.DenseElementsAttr.get_splat( + mask_type, ir.BoolAttr.get(True))) + if not values: + _, sorted_keys, _ = tpu.sort( + mask_type, keys.type, keys.type, keys, keys, mask=mask + ) + return (sorted_keys,) + results: list[ir.Value] = [] + for value in values: + _, sorted_keys, sorted_value = tpu.sort( + mask_type, keys.type, value.type, keys, value, mask=mask + ) + if not results: + results.append(sorted_keys) + results.append(sorted_value) + return tuple(results) + + +@register_lowering_rule( + lax.gather_p, kernel_types=[tpu_core.KernelType.SC_VECTOR_SUBCORE] +) +def _gather_lowering_rule( + ctx: LoweringRuleContext, + x, + indices, + *, + dimension_numbers, + slice_sizes, + unique_indices, + indices_are_sorted, + mode, + fill_value, +): + + in_aval, indices_aval = ctx.avals_in + out_aval, = ctx.avals_out + + if len(in_aval.shape) != 1: + raise NotImplementedError("Only 1D gather is supported") + if in_aval.shape != indices_aval.shape[:-1] != out_aval.shape: + raise ValueError( + "Shape mismatch in input, indices and output:" + f" {in_aval.shape}, {indices_aval.shape[:-1]}, {out_aval.shape}" + ) + + # During lowering jnp.take_along_axis to lax.gather, we append extra dimension + # to the end of the indices array. We should reshape it back to the original + # shape before lowering to Mosaic and rely on MLIR canonicalization to remove + # the reshapes. + assert indices_aval.shape == in_aval.shape + (1,) + recovered_indices = vector.shape_cast( + ir.VectorType.get(in_aval.shape, indices.type.element_type), + indices, + ) + # Note: current support for lax.gather is still very limited. + del fill_value + if slice_sizes == (1,) and mode == lax.GatherScatterMode.PROMISE_IN_BOUNDS: + if dimension_numbers == lax.GatherDimensionNumbers( + offset_dims=(), + collapsed_slice_dims=(0,), + start_index_map=(0,), + operand_batching_dims=(), + start_indices_batching_dims=(), + ): + return tpu.dynamic_gather(x, recovered_indices, [0]) + raise NotImplementedError("Unsupported gather") + + +@register_lowering_rule( + lax.rev_p, kernel_types=[tpu_core.KernelType.SC_VECTOR_SUBCORE] +) +def _rev_lowering_rule(ctx: LoweringRuleContext, x, dimensions): + del ctx # Unused. + if dimensions != (0,): + raise NotImplementedError(f"Invalid dimensions for SC lax.rev: {dimensions}") + i32 = ir.IntegerType.get_signless(32) + vec_dim = sc_core.get_sparse_core_info().num_lanes + cdim = arith.constant(i32, ir.IntegerAttr.get(i32, vec_dim - 1)) + cdim_vec = vector.broadcast(ir.VectorType.get((vec_dim,), cdim.type), cdim) + return tpu.dynamic_gather( + x, + arith.subi(cdim_vec, tpu.iota(cdim_vec.type, dimensions=[0])), + dimensions=[0], + ) + + +def _default_tile_strides( + tiling: sc_core.Tiling, shape: Sequence[int] +) -> Sequence[int]: + """Returns default tile strides for a given shape and tiling.""" + assert tiling + + cdiv = lambda a, b: (a + b - 1) // b + + strides = [0] * len(shape) + stride = 1 + first_tile, *_ = tiling + for d in reversed(range(len(shape))): + assert shape[d] != ir.ShapedType.get_dynamic_size() + strides[d] = stride + if d >= len(shape) - len(first_tile): + tile_d = d - (len(shape) - len(first_tile)) + stride *= cdiv(shape[d], first_tile[tile_d]) + else: + stride *= shape[d] + return strides + + +def _alloc_value( + aval: jax_core.AbstractValue, *, ctx: LoweringRuleContext +) -> ir.Value: + if isinstance(aval, sc_core.AbstractRef) and aval.tiling is not None: + tiling = "".join(f"({','.join(map(str, tile))})" for tile in aval.tiling) + strides = _default_tile_strides(aval.tiling, aval.shape) + out_type = ir.MemRefType.get( + aval.shape, + _dtype_to_ir_type(aval.dtype, is_kernel_boundary=True), + layout=ir.Attribute.parse(f"#tpu.tiled<{tiling},{strides}>"), + memory_space=tc_lowering._memory_space_to_mosaic_attribute( + aval.memory_space or MemorySpace.VMEM + ), + ) + return memref.alloca(out_type, [], []) + return tc_lowering._alloc_value(aval, ctx=ctx) diff --git a/jax/_src/pallas/mosaic/sc_primitives.py b/jax/_src/pallas/mosaic/sc_primitives.py new file mode 100644 index 000000000000..7f06e37d8474 --- /dev/null +++ b/jax/_src/pallas/mosaic/sc_primitives.py @@ -0,0 +1,1148 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Pallas primitives for SparseCore.""" + +from collections.abc import Callable, Sequence +import enum +import functools +from typing import TypeAlias, TypeVar, overload + +import jax +from jax import api_util +from jax import lax +from jax._src import core as jax_core +from jax._src import dtypes +from jax._src import effects +from jax._src import linear_util as lu +from jax._src.interpreters import partial_eval as pe +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import arith +from jax._src.lib.mlir.dialects import scf +from jax._src.lib.mlir.dialects import vector +from jax._src.pallas import core as pallas_core +from jax._src.pallas.mosaic import core as tpu_core +from jax._src.pallas.mosaic import lowering as tc_lowering +from jax._src.pallas.mosaic import sc_core +from jax._src.pallas.mosaic import sc_lowering +from jax._src.state import primitives as state_primitives +from jax._src.state import types as state_types +from jax.experimental.mosaic.dialects import tpu +import jax.numpy as jnp + + +_ensure_ir_value = tc_lowering._ensure_mlir_value +aval_to_ir_type = functools.partial( + tc_lowering.aval_to_ir_type, sc_lowering.dynamic_shape_replacement_fn +) + +TransformedRef: TypeAlias = state_types.TransformedRef +Ref: TypeAlias = state_types.AbstractRef | TransformedRef + +_T = TypeVar("_T") + +load_p = jax_core.Primitive("load") +load_p.is_effectful = lambda params: True # type: ignore + + +@load_p.def_effectful_abstract_eval +def _load_abstract_eval(ref, *args, has_mask, tree): + flat_transforms = args[:-1] if has_mask else args + tref = state_types.TransformedRef( + ref, jax.tree.unflatten(tree, flat_transforms)) + if has_mask: + mask = args[-1] + if mask.dtype != jnp.bool: + raise TypeError(f"Mask must be a boolean array, got {mask.dtype}") + if mask.shape != tref.shape: + raise ValueError(f"Mask must have shape {tref.shape}, got {mask.shape}") + return ( + jax_core.ShapedArray(tref.shape, ref.dtype), {state_types.ReadEffect(0)}) + + +@sc_lowering.register_lowering_rule(load_p) +def _load_lowering_rule( + ctx: sc_lowering.LoweringRuleContext, ref, *args, has_mask, tree +): + if has_mask: + *flat_transforms, mask = args + else: + flat_transforms, mask = list(args), None + return sc_lowering._load_lowering_rule( + ctx, ref, mask, *flat_transforms, tree=tree + ) + + +def load_expanded(ref: Ref, *, mask: jax.Array) -> jax.Array: + """Performs and expanded masked load from a ref. + + Elements from ``ref`` are placed into positions where ``mask`` is ``True``. + The elements are taken from ``ref`` sequentially, meaning that the i-th + ``True`` value in ``mask`` corresponds to accessing ``ref[i]``. The result is + expanded into the shape of the ``mask``. + + For example, if the mask is ``[True, False, True, True]``, the result is + ```[ref[0], , ref[2], ref[3]]``, where ```` is an undefined value. + + Args: + ref: The ref to load from. + mask: A boolean mask specifying which elements to load into. + + Returns: + The loaded array, with the same shape as the mask. No assumptions can be + made about the elements at the indices where the mask is ``False``. + """ + if not isinstance(ref, Ref): + raise TypeError(f"ref must be an AbstractRef or TransformedRef, got {ref}") + if not isinstance(ref, TransformedRef): + ref = ref.at[...] # type: ignore + assert isinstance(ref, TransformedRef) + flat_transforms, tree = jax.tree.flatten(ref.transforms) + return load_p.bind(ref.ref, *flat_transforms, mask, has_mask=True, tree=tree) + + +swap_p = jax_core.Primitive("swap") +swap_p.is_effectful = lambda params: True # type: ignore + + +@swap_p.def_effectful_abstract_eval +def _swap_abstract_eval(ref, x, *args, has_mask, tree, add): + flat_transforms = args[:-1] if has_mask else args + tref = state_types.TransformedRef( + ref, jax.tree.unflatten(tree, flat_transforms)) + if has_mask: + mask = args[-1] + if mask.dtype != jnp.bool: + raise TypeError(f"Mask must be a boolean array, got {mask.dtype}") + if mask.shape != tref.shape: + raise ValueError(f"Mask must have shape {tref.shape}, got {mask.shape}") + if ref.dtype != x.dtype: + raise TypeError( + f"Ref and value must have the same dtype, got {ref.dtype} and {x.dtype}" + ) + if tref.shape != x.shape: + raise ValueError(f"Value must have shape {tref.shape}, got {x.shape}") + effects = {state_types.WriteEffect(0)} + if add: + effects.add(state_types.ReadEffect(0)) + return x, effects + + +@sc_lowering.register_lowering_rule(swap_p) +def _swap_lowering_rule( + ctx: sc_lowering.LoweringRuleContext, ref, x, *args, has_mask, tree, add +): + if has_mask: + *flat_transforms, mask = args + else: + flat_transforms, mask = list(args), None + return sc_lowering._store_lowering_rule( + ctx, ref, x, mask, *flat_transforms, tree=tree, add=add + ) + + +def store_compressed(ref: Ref, x: jax.Array, *, mask: jax.Array) -> None: + """Performs a compressed masked store to a ref. + + Elements from ``x`` where ``mask`` is ``True`` are placed into ``ref``. + The elements are written to ``ref`` sequentially, meaning the i-th ``True`` + value in ``mask`` corresponds to writing to ``ref[i]``. + + For example, if the mask is ``[True, False, True, True]``, the elements + ``x[0]``, ``x[2]``, and ``x[3]`` are written to ``ref[0]``, ``ref[1]``, and + ``ref[2]`` respectively. + + Args: + ref: The ref to store into. + x: The array to store. Must have the same shape as ``ref``. + mask: A boolean mask specifying which elements from ``x`` to store. + """ + if not isinstance(ref, Ref): + raise TypeError(f"ref must be an AbstractRef or TransformedRef, got {ref}") + if not isinstance(ref, TransformedRef): + ref = ref.at[...] # type: ignore + assert isinstance(ref, TransformedRef) + flat_transforms, tree = jax.tree.flatten(ref.transforms) + _ = swap_p.bind( + ref.ref, + x, + *flat_transforms, + mask, + has_mask=True, + tree=tree, + add=False, + ) + return None + + +def addupdate(ref: Ref, x: jax.Array) -> None: + """Performs an atomic add to a ref. + + Args: + ref: The ref to store into. + x: The array to store. Must have the same shape as ``ref``. + """ + if not isinstance(ref, Ref): + raise TypeError(f"ref must be an AbstractRef or TransformedRef, got {ref}") + if not isinstance(ref, TransformedRef): + ref = ref.at[...] # type: ignore + assert isinstance(ref, TransformedRef) + flat_transforms, tree = jax.tree.flatten(ref.transforms) + _ = swap_p.bind( + ref.ref, x, *flat_transforms, has_mask=False, tree=tree, add=True + ) + return None + + +def addupdate_compressed(ref: Ref, x: jax.Array, *, mask: jax.Array) -> None: + """Performs a masked atomic add to a ref. + + See ``store_compressed`` for details on how the mask is used. + """ + if not isinstance(ref, Ref): + raise TypeError(f"ref must be an AbstractRef or TransformedRef, got {ref}") + if not isinstance(ref, TransformedRef): + ref = ref.at[...] # type: ignore + assert isinstance(ref, TransformedRef) + flat_transforms, tree = jax.tree.flatten(ref.transforms) + _ = swap_p.bind( + ref.ref, x, *flat_transforms, mask, has_mask=True, tree=tree, add=True + ) + return None + + +def _indexed_shape(ref: Ref, indices: Sequence[jax.Array]) -> tuple[int, ...]: + if len(indices) != ref.ndim: + raise ValueError(f"The number of indices does not match {ref.ndim=}") + prev_idx = None + for idx in indices: + if idx.ndim != 1: + raise ValueError( + f"Indices must be a 1-D array, got an index with shape {idx.shape}" + ) + if prev_idx is not None and idx.size != prev_idx.size: + raise ValueError( + "Indices must have the same size, got {prev_idx.size} and {idx.size}" + ) + prev_idx = idx + assert prev_idx is not None + return (prev_idx.size,) + + +gather_p = jax_core.Primitive("gather") +gather_p.is_effectful = lambda params: True # type: ignore + + +@gather_p.def_effectful_abstract_eval +def _gather_abstract_eval(*flat_args, tree): + ref, transforms, indices, mask = tree.unflatten(flat_args) + if transforms: + ref = state_types.TransformedRef(ref, transforms) + if ref.dtype not in (jnp.int32, jnp.float32): + raise TypeError(f"ref.dtype={ref.dtype} must be int32 or float32") + out_aval = jax_core.ShapedArray(_indexed_shape(ref, indices), ref.dtype) + sc_lowering._check_aval_is_supported("Gather", out_aval) + if mask is not None and mask.shape != out_aval.shape: + raise ValueError( + f"{mask.shape=} does not match the expected shape {out_aval.shape}" + ) + return out_aval, {state_types.ReadEffect(0)} + + +@sc_lowering.register_lowering_rule(gather_p) +def _gather_lowering_rule( + ctx: sc_lowering.LoweringRuleContext, *flat_args, tree +): + ref, transforms, indices, mask = tree.unflatten(flat_args) + ref_aval, *_ = tree.unflatten(ctx.avals_in) + if ref_aval.memory_space not in (tpu_core.MemorySpace.VMEM, None): + raise ValueError( + f"Gather only supports loading from VMEM, got {ref_aval.memory_space}" + ) + if transforms: + ref_block_shape, *_ = ctx.block_shapes + ref, _ = tc_lowering._transform_ref( + ref, ref_aval.dtype, ref_block_shape, transforms + ) + [out_aval] = ctx.avals_out + vec_type = ir.VectorType.get( + out_aval.shape, sc_lowering._dtype_to_ir_type(ref_aval.dtype) + ) + return tpu.vector_load_idx(vec_type, ref, indices, mask=mask) + + +def load_gather( + ref: Ref, indices: Sequence[jax.Array], *, mask: jax.Array | None = None +) -> jax.Array: + """Gathers an array from a ref. + + Args: + ref: The ref in ``VMEM`` to gather from. + indices: A sequence of 1D arrays, one for each dimension of ``ref``. Each + array specifies an index for that dimension. All arrays must have the same + size. + mask: An optional boolean array, which specifies which elements to load. If + ``None``, all elements are loaded. + + Returns: + The gathered array. + """ + ref, transforms = state_primitives.get_ref_and_transforms( + ref, None, "load_gather" + ) + flat_args, tree = jax.tree.flatten((ref, transforms, indices, mask)) + return gather_p.bind(*flat_args, tree=tree) + + +scatter_p = jax_core.Primitive("scatter") +scatter_p.is_effectful = lambda params: True # type: ignore +scatter_p.multiple_results = True + + +@scatter_p.def_effectful_abstract_eval +def _scatter_abstract_eval(*flat_args, tree, add): + ref, transforms, indices, x, mask = jax.tree.unflatten(tree, flat_args) + if transforms: + ref = state_types.TransformedRef(ref, transforms) + if ref.dtype not in (jnp.int32, jnp.float32): + raise TypeError(f"ref.dtype={ref.dtype} must be int32 or float32") + expected_shape = _indexed_shape(ref, indices) + if x.shape != expected_shape: + raise ValueError( + f"{x.shape=} does not match expected shape {expected_shape}" + ) + if x.dtype != ref.dtype: + raise TypeError(f"val.dtype={x.dtype} != ref.dtype={ref.dtype}") + if mask is not None and mask.shape != expected_shape: + raise ValueError( + f"{mask.shape=} does not match expected shape {expected_shape}" + ) + effects = {state_types.WriteEffect(0)} + if add: + effects.add(state_types.ReadEffect(0)) + return (), effects + + +@sc_lowering.register_lowering_rule(scatter_p) +def _scatter_lowering_rule( + ctx: sc_lowering.LoweringRuleContext, *flat_args, tree, add +): + ref, transforms, indices, x, mask = jax.tree.unflatten(tree, flat_args) + ref_aval, *_ = tree.unflatten(ctx.avals_in) + if ref_aval.memory_space not in (tpu_core.MemorySpace.VMEM, None): + raise ValueError( + f"Scatter only supports storing to VMEM, got {ref_aval.memory_space}" + ) + if transforms: + ref_block_shape, *_ = ctx.block_shapes + ref, _ = tc_lowering._transform_ref( + ref, ref_aval.dtype, ref_block_shape, transforms + ) + tpu.vector_store_idx(x, ref, indices, mask=mask, add=add) + return () + + +def store_scatter( + ref: Ref, + indices: Sequence[jax.Array], + x: jax.Array, + *, + mask: jax.Array | None = None, +) -> None: + """Scatters an array to a ref. + + Args: + ref: The ref in ``VMEM`` to scatter to. + indices: A sequence of 1D arrays, one for each dimension of ``ref``. Each + array specifies an index for that dimension. All arrays must have the same + size. + val: The array to store. + mask: An optional boolean array, which specifies which elements to store. If + ``None``, all elements are stored. + """ + if not indices: + raise ValueError("Indices must not be empty") + ref, transforms = state_primitives.get_ref_and_transforms( + ref, None, "store_scatter" + ) + flat_args, tree = jax.tree.flatten((ref, transforms, indices, x, mask)) + _ = scatter_p.bind(*flat_args, tree=tree, add=False) + return None + + +def addupdate_scatter( + ref: Ref, + indices: Sequence[jax.Array], + x: jax.Array, + *, + mask: jax.Array | None = None, +) -> None: + """Scatters an array to a ref atomically adding to existing values.""" + if not indices: + raise ValueError("Indices must not be empty") + ref, transforms = state_primitives.get_ref_and_transforms( + ref, None, "store_scatter" + ) + flat_args, tree = jax.tree.flatten((ref, transforms, indices, x, mask)) + _ = scatter_p.bind(*flat_args, tree=tree, add=True) + + +bitcast_p = jax_core.Primitive("bitcast") + + +@bitcast_p.def_abstract_eval +def _bitcast_abstract_eval(x, dtype): + old_bitwidth = dtypes.itemsize_bits(x.dtype) + new_bitwidth = dtypes.itemsize_bits(dtype) + if old_bitwidth == new_bitwidth: + return jax_core.ShapedArray(x.shape, dtype) + if x.ndim == 0: + raise ValueError( + "Cannot bitcast a ()-shaped array to a dtype with a different bitwidth:" + f" {old_bitwidth=} vs {new_bitwidth=}" + ) + new_last_dim, rem = divmod(x.shape[-1] * old_bitwidth, new_bitwidth) + if rem: + raise ValueError( + f"Cannot bitcast from {x.dtype} ({old_bitwidth} bits) to" + f" {dtype} ({new_bitwidth} bits), because {x.shape[-1]=} *" + f" {old_bitwidth} is not divisible by {new_bitwidth}" + ) + return jax_core.ShapedArray((*x.shape[:-1], new_last_dim), dtype) + + +@sc_lowering.register_lowering_rule(bitcast_p) +def _bitcast_lowering_rule(ctx: sc_lowering.LoweringRuleContext, x, *, dtype): + del dtype # Unused. + [out_aval] = ctx.avals_out + return vector.bitcast(aval_to_ir_type(out_aval), x) + + +def bitcast(x: jax.Array, dtype: jax.typing.DTypeLike) -> jax.Array: + """Bitcasts an array to a different dtype. + + Unlike ``lax.bitcast_convert_type``, this function returns an array of the + same rank as the input. The minormost dimension is expanded/shrunk to + account for the difference in the element bitwidth. + """ + if x.dtype == dtype: + return x + return bitcast_p.bind(x, dtype=jnp.dtype(dtype)) + + +class MemoryEffect(jax_core.Effect): + pass + + +effects.control_flow_allowed_effects.add_type(MemoryEffect) +effects.lowerable_effects.add_type(MemoryEffect) +_memory_effect = MemoryEffect() + +barrier_p = jax_core.Primitive("barrier") +barrier_p.multiple_results = True + +@barrier_p.def_effectful_abstract_eval +def _barrier_abstract_eval(): + return (), {_memory_effect} + + +@sc_lowering.register_lowering_rule(barrier_p) +def _barrier_lowering_rule(ctx: sc_lowering.LoweringRuleContext): + ix = ir.IndexType.get() + tpu.barrier(arith.constant(ix, ir.IntegerAttr.get(ix, 0))) + return () + + +def subcore_barrier(): + """Blocks until all subcores on the same core reach this instruction. + + The barrier must be used with the vector subcore, either via + :class:jax.experimental.pallas.tpu_sc.VectorSubcoreMesh or by specifying + ``` + pltpu.CompilerParams( + kernel_type=pltpu.KernelType.SC_VECTOR_SUBCORE, + dimension_semantics[..., "subcore_parallel", ...]) + ``` + to ``pallas_call``. + """ + barrier_p.bind() + + +scan_count_p = jax_core.Primitive("unique") +scan_count_p.multiple_results = True + + +@scan_count_p.def_abstract_eval +def _scan_count_abstract_eval(x, mask): + if x.dtype not in (jnp.uint32, jnp.int32, jnp.float32): + raise NotImplementedError( + f"x.dtype={x.dtype} must be uint32, int32 or float32") + if not jnp.issubdtype(mask.dtype, jnp.bool): + raise TypeError(f"mask.dtype={mask.dtype} is not a boolean dtype") + if x.shape != mask.shape: + raise ValueError(f"x.shape={x.shape} != mask.shape={mask.shape}") + return jax_core.ShapedArray(x.shape, jnp.int32), mask + + +@sc_lowering.register_lowering_rule(scan_count_p) +def _scan_count_lowering_rule(ctx: sc_lowering.LoweringRuleContext, x, mask): + del ctx # Unused. + # Reverse, because the MLIR op returns the mask first. + return tpu.scan_count(mask, x)[::-1] + + +def scan_count( + x: jax.Array, mask: jax.Array | None = None +) -> tuple[jax.Array, jax.Array]: + """Computes the running duplicate occurrence count of the array. + + Args: + x: An array of integers or floats. + mask: An optional array of booleans, which specifies which elements ``x`` + are eligible for counting. If ``None``, all elements are eligible. + + Returns: + A tuple of two arrays: + + * the running duplicate occurrence count of ``x``; + * the mask indicating the last occurrence of each duplicate that was + counted. + """ + return scan_count_p.bind(x, lax.full(x.shape, True) if mask is None else mask) + + +masked_cummax_p = jax_core.Primitive("masked_cummax") +masked_cummax_p.multiple_results = False + +@masked_cummax_p.def_abstract_eval +def _masked_cummax_abstract_eval(x, mask): + if x.dtype != jnp.int32 and x.dtype != jnp.float32: + raise NotImplementedError(f"x.dtype={x.dtype} must be int32 or float32") + if not jnp.issubdtype(mask.dtype, jnp.bool): + raise TypeError(f"mask.dtype={mask.dtype} is not a boolean dtype") + if x.shape != mask.shape: + raise ValueError(f"x.shape={x.shape} != mask.shape={mask.shape}") + return x + +@sc_lowering.register_lowering_rule(masked_cummax_p) +def _masked_cummax_lowering_rule(ctx: sc_lowering.LoweringRuleContext, x, mask): + del ctx # Unused. + return tpu.scan( + x.type, x, ir.Attribute.parse("#tpu.reduction_kind"), mask=mask) + +def cummax(x: jax.Array, *, mask: jax.Array | None = None) -> jax.Array: + """Returns the cumulative max of the array along its innermost axis. + + Elements from `x` will pass through directly to the result until the first + valid value is encountered (`mask[i] == True`). If you would like to specify + a default value for such elements instead, write + `x = jnp.where(mask, x, default_value)` before or after calling this function. + + Args: + x: An array of integers or floats. + mask: An optional array of booleans, which specifies which elements of `x` + are eligible for the max. If `None`, all elements are eligible. + """ + if x.ndim != 1: + raise NotImplementedError(f"masked_cummax: x={x.aval} must be rank 1") + if mask is None: + mask = lax.full(x.shape, True) + return masked_cummax_p.bind(x, mask) + +@sc_lowering.register_lowering_rule( + lax.reduce_max_p, kernel_types=[tpu_core.KernelType.SC_VECTOR_SUBCORE]) +def _reduce_max_lowering_rule(ctx: sc_lowering.LoweringRuleContext, x, axes): + if axes != (0,): + raise NotImplementedError( + f"reduce_max requires axes to be (0,) on SparseCore, but got {axes}.") + vec_dim = ctx.avals_in[0].shape[0] + i1t = ir.IntegerType.get_signless(1) + c1 = arith.constant(i1t, ir.IntegerAttr.get(i1t, 1)) + c1v = vector.broadcast(ir.VectorType.get(x.type.shape, c1.type), c1) + return vector.extract( + _masked_cummax_lowering_rule(ctx, x, c1v), [], [vec_dim - 1]) + + +masked_cumsum_p = jax_core.Primitive("masked_cumsum") +masked_cumsum_p.multiple_results = False + +@masked_cumsum_p.def_abstract_eval +def _masked_cumsum_abstract_eval(x, mask): + if x.dtype != jnp.int32 and x.dtype != jnp.float32: + raise NotImplementedError(f"x.dtype={x.dtype} must be int32 or float32") + if not jnp.issubdtype(mask.dtype, jnp.bool): + raise TypeError(f"mask.dtype={mask.dtype} is not a boolean dtype") + if x.shape != mask.shape: + raise ValueError(f"x.shape={x.shape} != mask.shape={mask.shape}") + return jax_core.ShapedArray(x.shape, x.dtype) + +@sc_lowering.register_lowering_rule(masked_cumsum_p) +def _masked_cumsum_lowering_rule(ctx: sc_lowering.LoweringRuleContext, x, mask): + del ctx # Unused. + return tpu.scan( + x.type, x, ir.Attribute.parse("#tpu.reduction_kind"), mask=mask) + +@sc_lowering.register_lowering_rule(lax.cumsum_p) +def _cumsum_lowering_rule(ctx: sc_lowering.LoweringRuleContext, x, axis, + reverse): + if axis != 0: + raise NotImplementedError(f"SC cumsum: axis={axis} must be 0.") + if len(ctx.avals_in[0].shape) != 1: + raise NotImplementedError(f"SC cumsum: x={ctx.avals_in[0]} must be rank 1") + if reverse: + raise NotImplementedError("SC cumsum: reverse=True is not yet supported") + i1t = ir.IntegerType.get_signless(1) + c1 = arith.constant(i1t, ir.IntegerAttr.get(i1t, 1)) + c1v = vector.broadcast(ir.VectorType.get(x.type.shape, c1.type), c1) + return tpu.scan( + x.type, x, ir.Attribute.parse("#tpu.reduction_kind"), mask=c1v) + +def cumsum(x: jax.Array, *, mask: jax.Array | None = None) -> jax.Array: + """Returns the cumulative sum of the array along its innermost axis. + + This differs from `jnp.cumsum` in that it takes an additional `mask` argument. + + Args: + x: An array of integers or floats. + mask: An optional array of booleans, which specifies which elements of `x` + are eligible for summing. If `None`, all elements are eligible. + """ + if x.ndim != 1: + raise NotImplementedError(f"cumsum: x={x.aval} must be rank 1") + if mask is None: + mask = lax.full(x.shape, True) + return masked_cumsum_p.bind(x, mask) + +@sc_lowering.register_lowering_rule( + lax.reduce_sum_p, kernel_types=[tpu_core.KernelType.SC_VECTOR_SUBCORE]) +def _reduce_sum_lowering_rule( + ctx: sc_lowering.LoweringRuleContext, x, axes, out_sharding): + del out_sharding # Unused. + vec_dim = ctx.avals_in[0].shape[0] + if axes != (0,): + raise NotImplementedError(f"SC reduce_sum: axes={axes} must be (0,).") + return vector.extract( + _cumsum_lowering_rule(ctx, x, 0, reverse=False), [], [vec_dim - 1]) + + +masked_sort_p = jax_core.Primitive("masked_sort") +masked_sort_p.multiple_results = True + +@masked_sort_p.def_abstract_eval +def _masked_sort_abstract_eval(keys, values, *maybe_mask, descending): + del descending # Unused. + supported_shape = (sc_core.get_sparse_core_info().num_lanes,) + if keys.dtype not in (jnp.int32, jnp.float32): + raise NotImplementedError( + f"sort_key_val: keys dtype {keys.dtype} should be int32 or float32") + if keys.shape != supported_shape: + raise ValueError(f"keys shape {keys.shape} must be {supported_shape}") + if jnp.dtype(values.dtype).itemsize != 4: + raise NotImplementedError( + f"sort_key_val: values dtype {values.dtype} should be 32 bits") + if values.shape != supported_shape: + raise ValueError(f"values shape {values.shape} must be {supported_shape}") + if maybe_mask: + [mask] = maybe_mask + if not jnp.issubdtype(mask.dtype, jnp.bool): + raise TypeError(f"mask dtype {mask.dtype} is not boolean") + if mask.shape != supported_shape: + raise ValueError(f"mask shape {mask.shape} must be {supported_shape}") + return keys, values, *maybe_mask + +@sc_lowering.register_lowering_rule(masked_sort_p) +def _masked_sort_lowering_rule( + ctx: sc_lowering.LoweringRuleContext, keys, values, *maybe_mask, descending): + del ctx # Unused. + if maybe_mask: + [mask] = maybe_mask + else: + mask_type = ir.VectorType.get( + [sc_core.get_sparse_core_info().num_lanes], + ir.IntegerType.get_signless(1)) + mask = arith.constant(mask_type, ir.DenseElementsAttr.get_splat( + mask_type, ir.BoolAttr.get(True))) + out_mask, sorted_keys, sorted_values = tpu.sort( + mask.type, keys.type, values.type, keys, values, mask=mask, + descending=descending + ) + if maybe_mask: + return sorted_keys, sorted_values, out_mask + return sorted_keys, sorted_values + +def sort_key_val( + keys: jax.Array, values: jax.Array, *, + mask: jax.Array | None = None, descending: bool = False +) -> jax.Array: + """Sorts keys and values, pushing invalid elements to the last positions. + + Args: + keys: An array of integers or floats. + values: An array of values corresponding to the keys. + mask: An optional array of booleans, which specifies which elements of + `keys` and `values` are valid. If `None`, all elements are valid. + descending: Whether to sort in descending order. + + Returns: + sorted_keys, sorted_values, [output_mask]: The sorted keys and values, and, + if a mask was given, the corresponding mask for output keys and values. + """ + maybe_mask = () if mask is None else (mask,) + return masked_sort_p.bind(keys, values, *maybe_mask, descending=descending) + + +parallel_loop_p = jax_core.Primitive("parallel_loop") +parallel_loop_p.is_effectful = lambda params: bool(params["jaxpr"].effects) # type: ignore +parallel_loop_p.multiple_results = True + + +@parallel_loop_p.def_effectful_abstract_eval +def _parallel_loop_abstract_eval(*args, jaxpr, tree, **params): + del params # Unused. + _, _, _, _, carries = tree.unflatten(args) + if any(isinstance(c, (Ref, TransformedRef)) for c in carries): + raise TypeError(f"Carried values may not be refs, but got: {carries}") + updated_effects = set() + for eff in jaxpr.effects: + if isinstance(eff, effects.JaxprInputEffect): + # Offset for the parallel_loop eqn to account for start, stop, and step + # args passed to parallel_loop_p.bind. + eff = eff.replace(input_index=eff.input_index + 3) + updated_effects.add(eff) + return carries, updated_effects + + +@sc_lowering.register_lowering_rule(parallel_loop_p) +def _parallel_loop_lowering_rule( + ctx: sc_lowering.LoweringRuleContext, + *flat_args, + tree, + unroll, + jaxpr, +): + lower, upper, step, consts, carry = tree.unflatten(flat_args) + for_op = scf.ForOp( + _ensure_ir_value(lower, pallas_core.index_map_grid_aval), + _ensure_ir_value(upper, pallas_core.index_map_grid_aval), + _ensure_ir_value(step, pallas_core.index_map_grid_aval), + carry, + ) + for_op.attributes["sc.parallel_access"] = ir.UnitAttr.get() + for_op.attributes["sc.loop_unroll_factor"] = ir.IntegerAttr.get( + ir.IntegerType.get_signless(64), unroll + ) + with ir.InsertionPoint(for_op.body): + _, _, _, consts_block_shapes, *_ = tree.unflatten(ctx.block_shapes) + lowering_ctx = ctx.lowering_context.replace( + block_shapes=[*consts_block_shapes, None] + [None] * len(carry), + ) + carry_out = tc_lowering.jaxpr_subcomp( + lowering_ctx, + pe.convert_constvars_jaxpr(jaxpr), + *consts, + for_op.induction_variable, + *for_op.inner_iter_args, + ) + scf.yield_(carry_out) + return for_op.results + + +@overload +def parallel_loop( + lower: jax.typing.ArrayLike, + upper: jax.typing.ArrayLike, + step: jax.typing.ArrayLike = ..., + *, + unroll: int = ..., + carry: None = None, +) -> Callable[[Callable[[jax.Array], None]], None]: + ... + + +@overload +def parallel_loop( + lower: jax.typing.ArrayLike, + upper: jax.typing.ArrayLike, + step: jax.typing.ArrayLike = ..., + *, + unroll: int = ..., + carry: _T, +) -> Callable[[Callable[[jax.Array, _T], _T]], _T]: + ... + + +def parallel_loop(lower, upper, step=1, *, unroll=1, carry=None): + """A parallel loop decorator. + + The decorated function forms the loop body. It is called with the current + loop index as the argument and optionally, a single additional carry argument. + + The loop iterations must be independent, meaning that operations in one + iteration cannot depend on the side effects, especially Ref writes, of any + other iteration. This allows the compiler to execute instructions from + different iterations concurrently, potentially reordering them for better + performance. + + Cross-iteration dependencies traceable via carried values are allowed. Refs + may not be carried. + + Safe usage of carried value:: + + @parallel_loop(0, 64, step=8, carry=jnp.int32(1)) + def body(i, j): + # Writes are independent across iterations. + x_ref[pl.ds(i, 8)] = j + jnp.arange(8) + return j + 1 + + Any pytree can be carried. The final value is returned by the decorator:: + + def body(i, my_tree: MyTree): + # Writes are independent across iterations. + x_ref[pl.ds(i, 8)] = my_tree.transform(jnp.arange(8)) + return my_tree.step(i) + final_value = parallel_loop(0, 64, step=8, carry=MyTree())(body) + + Undefined result:: + + @parallel_loop(0, 64, step=4, carry=jnp.int32(1)) + def body(i, j): + # Because the step size is 4, the array written is of size 8, and loop + # iterations may be reordered, the values in indices 4-59 of x_ref are + # unspecified after the loop. (The values in 0-3 and 60-63 are only + # written by the first and last iterations, so are well-defined.) + x_ref[pl.ds(i, 8)] = j + jnp.arange(8) + return j + 1 + + Unsafe read of "previous" iteration's write (don't do this):: + + @parallel_loop(0, 64, 8, carry=jnp.int32(1)) + def body(i, j): + # Unsafe because it depends on the side-effect of "previous" iterations, + # which may be executed in parallel or reordered. + mask = x_ref[pl.ds(0, 8)] < j + x_ref[pl.ds(0, 8)] += jnp.where(mask, j + jnp.arange(8), 0) + return j + 1 + + Args: + lower: The starting value of the loop index. + upper: The exclusive upper bound of the loop index. + step: The increment of the loop index. Default to 1. + unroll: The unroll factor of the loop. + carry: Optional carried state of the loop. + + Returns: + A decorator that executes the given function in a parallel loop. + """ + + def decorator(body): + flat_carries, carry_tree = jax.tree.flatten(carry) + def wrapped(idx, *carries): + if carry is None: + body(idx) + return [] + result = body(idx, carry_tree.unflatten(carries)) + result, result_tree = jax.tree.flatten(result) + if result_tree != carry_tree: + raise ValueError( + "parallel_loop: body result should have same structure as carry:" + f" {result_tree} != {carry_tree}" + ) + return result + flat_avals = [ + pallas_core.index_map_grid_aval, + *(c.aval for c in flat_carries), + ] + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic( + lu.wrap_init( + wrapped, + debug_info=api_util.debug_info( + "parallel_loop", body, flat_avals, {} + ), + ), + flat_avals, + ) + carry_tree.unflatten(jaxpr.outvars) # Verify same structure. + disallowed_effects = effects.control_flow_allowed_effects.filter_not_in( + jaxpr.effects + ) + if disallowed_effects: + raise NotImplementedError( + f"Effects not supported in parallel_loop: {disallowed_effects}" + ) + flat_args, tree = jax.tree.flatten( + (lower, upper, step, consts, flat_carries) + ) + flat_result = parallel_loop_p.bind( + *flat_args, tree=tree, unroll=unroll, jaxpr=jaxpr + ) + if carry is None: + return None + return carry_tree.unflatten(flat_result) + + return decorator + + +class PackFormat(enum.Enum): + #: [a0, a1], [b0, b1] -> [[a0, a1], [b0, b1]] + COMPRESSED = "compressed" + #: [a0, a1], [b0, b1] -> [a0, b0, a1, b1] + INTERLEAVED = "interleaved" + + +def _format_to_ir_attribute(format: PackFormat) -> ir.Attribute: + return ir.Attribute.parse(f"#tpu.pack_format<{format.value}>") + + +pack_p = jax_core.Primitive("pack") + + +@pack_p.def_abstract_eval +def _pack_abstract_eval(a, b, *, format, preferred_element_type): + if a.shape != b.shape: + raise ValueError( + f"Packed arrays must have the same shape, got {a.shape} and {b.shape}" + ) + if a.ndim != 1: + raise ValueError(f"Packed arrays must be 1-D, got {a.ndim}") + if a.dtype != b.dtype: + raise TypeError( + f"Packed arrays must have the same dtype, got {a.dtype} and {b.dtype}" + ) + if preferred_element_type is None: + match a.dtype: + case jnp.float32: + packed_dtype = jnp.bfloat16 + case jnp.int32: + packed_dtype = jnp.int16 + case _: + # TODO(slebedev): Support more types. + raise NotImplementedError( + f"Only packing of float32 and int32 is supported, got {a.dtype}" + ) + else: + packed_bw = dtypes.itemsize_bits(a.dtype) // 2 + if dtypes.itemsize_bits(preferred_element_type) != packed_bw: + raise ValueError( + f"preferred_element_type= must have bitwidth {packed_bw}, got" + f" {dtypes.itemsize_bits(preferred_element_type)}" + ) + packed_dtype = preferred_element_type + + match format: + case PackFormat.INTERLEAVED: + packed_shape = (2 * a.size,) + case PackFormat.COMPRESSED: + packed_shape = (a.size, 2) + return jax_core.ShapedArray(packed_shape, packed_dtype) + + +@sc_lowering.register_lowering_rule(pack_p) +def _pack_lowering_rule( + ctx: sc_lowering.LoweringRuleContext, + a, + b, + *, + format, + preferred_element_type, +): + del preferred_element_type # Unused. + [out_aval] = ctx.avals_out + return tpu.pack_subelements( + aval_to_ir_type(out_aval), + [a, b], + [0, 1], + _format_to_ir_attribute(format), + ) + + +def pack( + a: jax.Array, + b: jax.Array, + /, + *, + format: PackFormat, + preferred_element_type: jax.typing.DTypeLike | None = None, +) -> jax.Array: + """Packs two arrays according to the given format. + + .. warning:: This API is temporary and will be removed once the SparseCore + compiler is able to do packing/unpacking automatically. + + Args: + a: The first array to pack. + b: The second array to pack. + format: The packing format to use. + preferred_element_type: Optional. The preferred element type of the packed + array. If specified, must have half the bitwidth of the input array types. + + Returns: + The packed array. + """ + if preferred_element_type is not None: + preferred_element_type = jnp.dtype(preferred_element_type) + return pack_p.bind( + a, b, format=format, preferred_element_type=preferred_element_type + ) + + +unpack_p = jax_core.Primitive("unpack") +unpack_p.multiple_results = True + + +@unpack_p.def_abstract_eval +def _unpack_abstract_eval(ab, *, format, preferred_element_type): + match format: + case PackFormat.INTERLEAVED: + if ab.ndim != 1 or ab.size % 2 != 0: + raise ValueError( + "Interleaved unpack requires a 1-D array with an even size, got" + f" {ab.shape}" + ) + case PackFormat.COMPRESSED: + if ab.ndim != 2 or ab.shape[1] != 2: + raise ValueError( + "Compressed unpack requires an array with shape (N, 2), got" + f" {ab.shape}" + ) + if preferred_element_type is None: + match ab.dtype: + case jnp.bfloat16: + unpacked_dtype = jnp.float32 + case jnp.int16: + unpacked_dtype = jnp.int32 + case _: + # TODO(slebedev): Support more types. + raise NotImplementedError( + f"Only unpacking of bloat16 and int16 is supported, got {ab.dtype}" + ) + else: + unpacked_bw = dtypes.itemsize_bits(ab.dtype) * 2 + if dtypes.itemsize_bits(preferred_element_type) != unpacked_bw: + raise ValueError( + f"preferred_element_type= must have bitwidth {unpacked_bw}, got" + f" {dtypes.itemsize_bits(preferred_element_type)}" + ) + unpacked_dtype = preferred_element_type + return (jax_core.ShapedArray((ab.size // 2,), unpacked_dtype),) * 2 + + +@sc_lowering.register_lowering_rule(unpack_p) +def _unpack_lowering_rule( + ctx: sc_lowering.LoweringRuleContext, ab, *, format, preferred_element_type +): + del preferred_element_type # Unused. + out_aval, _ = ctx.avals_out + out_type = aval_to_ir_type(out_aval) + return ( + tpu.unpack_subelements(out_type, ab, 0, _format_to_ir_attribute(format)), + tpu.unpack_subelements(out_type, ab, 1, _format_to_ir_attribute(format)), + ) + + +def unpack( + ab: jax.Array, + /, + *, + format: PackFormat, + preferred_element_type: jax.typing.DTypeLike | None = None, +) -> tuple[jax.Array, jax.Array]: + """Unpacks two arrays according to the given format. + + .. warning:: This API is temporary and will be removed once the SparseCore + compiler is able to do packing/unpacking automatically. + + Args: + ab: The array to unpack. + format: The packing format to use. + preferred_element_type: Optional. The preferred element type of the unpacked + arrays. If specified, must have double the bitwidth of the input array + type. + + Returns: + The unpacked arrays. + """ + if preferred_element_type is not None: + preferred_element_type = jnp.dtype(preferred_element_type) + return unpack_p.bind( + ab, + format=format, + preferred_element_type=preferred_element_type, + ) + + +def _mask_all_reduce_abstract_eval(x, *, reduce): + if x.dtype != jnp.bool: + raise TypeError(f"Mask all-reduce only supports bool arrays, got {x.dtype}") + match x.shape: + case (minor_dim,): + return jax_core.ShapedArray((minor_dim // reduce,), jnp.int32) + case _: + raise ValueError("Mask all-reduce only supports 1D arrays") + + +def _mask_all_reduce_lowering_rule( + ctx: sc_lowering.LoweringRuleContext, x, *, reduce, kind: str +): + [out_aval] = ctx.avals_out + return tpu.all_reduce( + ir.VectorType.get( + out_aval.shape, + ir.IntegerType.get_signless(32), + ), + x, + 0, + ir.Attribute.parse(f"#tpu.reduction_kind<{kind}>"), + ) + + +all_reduce_population_count_p = jax_core.Primitive( + "all_reduce_population_count" +) +all_reduce_population_count_p.def_abstract_eval(_mask_all_reduce_abstract_eval) +sc_lowering.register_lowering_rule(all_reduce_population_count_p)( + functools.partial(_mask_all_reduce_lowering_rule, kind="sum") +) + + +def all_reduce_population_count(x: jax.Array, *, reduce: int = 1) -> jax.Array: + """Computes the number of nonzero elements in the array. + + Args: + x: A 1D array of bools. + reduce: The factor to reduce the output shape by. + + Returns: + An array with each element containing the number of true elements in ``x``. + """ + return all_reduce_population_count_p.bind(x, reduce=reduce) + + +all_reduce_ffs_p = jax_core.Primitive("all_reduce_ffs") +all_reduce_ffs_p.def_abstract_eval(_mask_all_reduce_abstract_eval) +sc_lowering.register_lowering_rule(all_reduce_ffs_p)( + functools.partial(_mask_all_reduce_lowering_rule, kind="find_first_set") +) + + +def all_reduce_ffs(x: jax.Array, *, reduce: int = 1) -> jax.Array: + """Computes the index of the first true element in the array. + + Args: + x: A 1D array of bools. + reduce: The factor to reduce the output shape by. + + Returns: + An array with each element containing the index of the first true element in + ``x`` or ``x.size`` if there are no true elements. + """ + return all_reduce_ffs_p.bind(x, reduce=reduce) diff --git a/jax/_src/pallas/mosaic/tpu_info.py b/jax/_src/pallas/mosaic/tpu_info.py new file mode 100644 index 000000000000..3159fe0d1a98 --- /dev/null +++ b/jax/_src/pallas/mosaic/tpu_info.py @@ -0,0 +1,366 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Exposes TPU hardware information.""" + +import dataclasses +import enum +from typing import Callable + +from jax import numpy as jnp +from jax._src import dtypes +from jax._src import util as jax_util +from jax._src.pallas.mosaic import core + + +class ChipVersionBase: + pass + + +class ChipVersion(ChipVersionBase, enum.Enum): + TPU_V2 = "v2" + TPU_V3 = "v3" + TPU_V4I = "v4i" + TPU_V4 = "v4" + TPU_V5E = "v5e" + TPU_V5P = "v5p" + TPU_V6E = "v6e" + TPU_7X = "7x" + + def __str__(self) -> str: + return self.value + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class SparseCoreInfo: + """SparseCore-specific information.""" + + num_cores: int + num_subcores: int + num_lanes: int + dma_granule_size_bytes: int + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class TpuInfo: + """TPU hardware information. + + Note that all information is per-TensorCore so you would need to multiply + by `num_cores` to obtain the total for the chip. + """ + + chip_version: ChipVersionBase + generation: int + num_cores: int + num_lanes: int + num_sublanes: int + mxu_column_size: int + vmem_capacity_bytes: int + cmem_capacity_bytes: int + smem_capacity_bytes: int + hbm_capacity_bytes: int + mem_bw_bytes_per_second: int + bf16_ops_per_second: int + int8_ops_per_second: int + fp8_ops_per_second: int + int4_ops_per_second: int + + sparse_core: SparseCoreInfo | None = None + + @property + def is_lite(self) -> bool: + return self.chip_version in { + ChipVersion.TPU_V4I, + ChipVersion.TPU_V5E, + ChipVersion.TPU_V6E, + } + + @property + def is_split_chip(self) -> bool: + # Is this a multi-core chip being used in single-core mode? + return self.num_cores == 1 and not self.is_lite + + def is_matmul_supported( + self, + lhs_dtype: jnp.dtype | str, + rhs_dtype: jnp.dtype | str, + ) -> bool: + """Returns whether the given matmul input dtypes are supported on the chip.""" + lhs_dt = jnp.dtype(lhs_dtype) if isinstance(lhs_dtype, str) else lhs_dtype + rhs_dt = jnp.dtype(rhs_dtype) if isinstance(rhs_dtype, str) else rhs_dtype + + F32 = jnp.float32 + BF16 = jnp.bfloat16 + S8 = jnp.int8 + U8 = jnp.uint8 + F8E4M3B11FNUZ = jnp.float8_e4m3b11fnuz + F8E4M3FN = jnp.float8_e4m3fn + F8E5M2 = jnp.float8_e5m2 + S4 = jnp.int4 + U4 = jnp.uint4 + + match self.generation: + case 2 | 3: + return lhs_dt == rhs_dt == F32 + case 4: + return lhs_dt in {F32, BF16} and rhs_dt in {F32, BF16, S8} + case 5 | 6: + return ( + ( + lhs_dt in {F32, BF16, F8E5M2, F8E4M3B11FNUZ} + and rhs_dt in {F32, BF16, F8E5M2, F8E4M3B11FNUZ} + ) + or (lhs_dt in {U8, S8} and rhs_dt in {U8, S8}) + or (lhs_dt in {U4, S4} and rhs_dt in {U4, S4}) + ) + case 7: + return (lhs_dt in {F32, BF16} and rhs_dt in {F32, BF16}) or ( + lhs_dt in {F32, BF16, F8E5M2, F8E4M3FN} + and rhs_dt in {F8E5M2, F8E4M3FN} + ) + case _: + return False + + def get_sublane_tiling(self, dtype: jnp.dtype) -> int: + """Returns the sublane tiling for the given itemsize. + + Note that this is a heurustic and depends on the settings of the XLA flags. + """ + bitwidth = dtypes.itemsize_bits(dtype) + if self.generation < 7: + # Caveat: before TPU7x, by default XLA does not use large 2nd minor tiling + # but it can be enabled by setting the flag + # xla_tpu_enable_large_2nd_minor_layout_for_x16. + if bitwidth == 16 or bitwidth == 32: + return self.num_sublanes + else: + # Large 2nd minor tiling is enabled for other types. + return self.num_sublanes * (32 // bitwidth) + # XLA allows large 2nd minor tiling by default starting with TPU7x. + if self.generation == 7: + return self.num_sublanes * (32 // bitwidth) + raise NotImplementedError("TPU generation is not supported") + + +def is_tpu_device() -> bool: + """Returns whether the current device is a TPU.""" + return core.get_device_kind() in { + "TPU v2", + "TPU v3", + "TPU v4", + "TPU v4 lite", + "TPU v5e", + "TPU v5 lite", + "TPU v5", + "TPU v5p", + "TPU v6 lite", + "TPU v6e", + "TPU7x", + } + + +registry: dict[str, Callable[[], TpuInfo]] = {} + + +@jax_util.cache(trace_context_in_key=True) +def get_tpu_info() -> TpuInfo: + """Returns the TPU hardware information for the current device. + + Note that all information is *per-TensorCore* so you would need to multiply by + `num_cores` to obtain the total for the chip. + + Returns: + A TpuInfo object containing the hardware information for the current device. + """ + device_kind = core.get_device_kind() + + # Common parameters for all TensorCores + NUM_LANES = 128 + NUM_SUBLANES = 8 + MXU_COLUMN_SIZE_GEN_LT_6 = 128 + MXU_COLUMN_SIZE_GEN_GE_6 = 256 + + match device_kind: + case "TPU v2": # 2 TensorCores per chip + num_chip_cores = 2 + return TpuInfo( + chip_version=ChipVersion.TPU_V2, + generation=2, + num_cores=core.get_num_device_cores(), + num_lanes=NUM_LANES, + num_sublanes=NUM_SUBLANES, + mxu_column_size=MXU_COLUMN_SIZE_GEN_LT_6, + vmem_capacity_bytes=16 * 1024 * 1024, # 16 MiB per core + cmem_capacity_bytes=0, + smem_capacity_bytes=16 * 1024, # 16 KiB per core + hbm_capacity_bytes=int(16_000_000_000 // num_chip_cores), + mem_bw_bytes_per_second=int(7.16e11 // num_chip_cores), + bf16_ops_per_second=int(4.6e13 // num_chip_cores), + int8_ops_per_second=0, # Not Available + fp8_ops_per_second=0, # Not Available + int4_ops_per_second=0, # Not Available + ) + case "TPU v3": # 2 TensorCores per chip + num_chip_cores = 2 + return TpuInfo( + chip_version=ChipVersion.TPU_V3, + generation=3, + num_cores=core.get_num_device_cores(), + num_lanes=NUM_LANES, + num_sublanes=NUM_SUBLANES, + mxu_column_size=MXU_COLUMN_SIZE_GEN_LT_6, + vmem_capacity_bytes=16 * 1024 * 1024, # 16 MiB per core + cmem_capacity_bytes=0, + smem_capacity_bytes=16 * 1024, # 16 KiB per core + hbm_capacity_bytes=34_400_000_000 // num_chip_cores, + mem_bw_bytes_per_second=int(8.25e11 // num_chip_cores), + bf16_ops_per_second=int(1.40e14 // num_chip_cores), + int8_ops_per_second=0, # Not Available + fp8_ops_per_second=0, # Not Available + int4_ops_per_second=0, # Not Available + ) + case "TPU v4 lite": # 1 TensorCore per chip + return TpuInfo( + chip_version=ChipVersion.TPU_V4I, + generation=4, + num_cores=core.get_num_device_cores(), + num_lanes=NUM_LANES, + num_sublanes=NUM_SUBLANES, + mxu_column_size=MXU_COLUMN_SIZE_GEN_LT_6, + vmem_capacity_bytes=16 * 1024 * 1024, # 16 MiB per core + cmem_capacity_bytes=134_000_000, + smem_capacity_bytes=1024 * 1024, # 1 MiB per core + hbm_capacity_bytes=8_590_000_000, + mem_bw_bytes_per_second=int(6.14e11), + bf16_ops_per_second=int(1.37e14), + int8_ops_per_second=0, # Not Available + fp8_ops_per_second=0, # Not Available + int4_ops_per_second=0, # Not Available + ) + case "TPU v4": # 2 TensorCores per chip + num_chip_cores = 2 + return TpuInfo( + chip_version=ChipVersion.TPU_V4, + generation=4, + num_cores=core.get_num_device_cores(), + num_lanes=NUM_LANES, + num_sublanes=NUM_SUBLANES, + mxu_column_size=MXU_COLUMN_SIZE_GEN_LT_6, + vmem_capacity_bytes=16 * 1024 * 1024, # 16 MiB per core + cmem_capacity_bytes=134_000_000 // num_chip_cores, + smem_capacity_bytes=1024 * 1024, # 1 MiB per core + hbm_capacity_bytes=34_400_000_000 // num_chip_cores, + mem_bw_bytes_per_second=int(1.23e12 // num_chip_cores), + bf16_ops_per_second=int(2.75e14 // num_chip_cores), + int8_ops_per_second=0, # Not Available + fp8_ops_per_second=0, # Not Available + int4_ops_per_second=0, # Not Available + ) + case "TPU v5 lite" | "TPU v5e": # 1 TensorCore per chip + return TpuInfo( + chip_version=ChipVersion.TPU_V5E, + generation=5, + num_cores=core.get_num_device_cores(), + num_lanes=NUM_LANES, + num_sublanes=NUM_SUBLANES, + mxu_column_size=MXU_COLUMN_SIZE_GEN_LT_6, + vmem_capacity_bytes=128 * 1024 * 1024, # 128 MiB per core + cmem_capacity_bytes=0, + smem_capacity_bytes=1024 * 1024, # 1 MiB per core + hbm_capacity_bytes=17_200_000_000, + mem_bw_bytes_per_second=int(8.20e11), + bf16_ops_per_second=int(1.97e14), + int8_ops_per_second=int(3.94e14), + fp8_ops_per_second=0, # Not Available + int4_ops_per_second=int(7.88e14), + ) + case "TPU v5" | "TPU v5p": # 2 TensorCores per chip + num_chip_cores = 2 + return TpuInfo( + chip_version=ChipVersion.TPU_V5P, + generation=5, + num_cores=core.get_num_device_cores(), + num_lanes=NUM_LANES, + num_sublanes=NUM_SUBLANES, + mxu_column_size=MXU_COLUMN_SIZE_GEN_LT_6, + vmem_capacity_bytes=64 * 1024 * 1024, # 64 MiB per core + cmem_capacity_bytes=0, + smem_capacity_bytes=1024 * 1024, # 1 MiB per core + hbm_capacity_bytes=103_000_000_000 // num_chip_cores, + mem_bw_bytes_per_second=int(2.46e12 // num_chip_cores), + bf16_ops_per_second=int(4.59e14 // num_chip_cores), + int8_ops_per_second=int(9.18e14 // num_chip_cores), + fp8_ops_per_second=0, # Not Available + int4_ops_per_second=int(1.84e15 // num_chip_cores), + sparse_core=SparseCoreInfo( + num_cores=4, + num_subcores=16, + num_lanes=8, + dma_granule_size_bytes=32, + ), + ) + case "TPU v6 lite" | "TPU v6e": # 1 TensorCore per chip + return TpuInfo( + chip_version=ChipVersion.TPU_V6E, + generation=6, + num_cores=core.get_num_device_cores(), + num_lanes=NUM_LANES, + num_sublanes=NUM_SUBLANES, + mxu_column_size=MXU_COLUMN_SIZE_GEN_GE_6, + vmem_capacity_bytes=128 * 1024 * 1024, # 128 MiB per core + cmem_capacity_bytes=0, + smem_capacity_bytes=1024 * 1024, # 1 MiB per core + hbm_capacity_bytes=34_400_000_000, + mem_bw_bytes_per_second=int(1.64e12), + bf16_ops_per_second=int(9.20e14), + int8_ops_per_second=int(1.84e15), + fp8_ops_per_second=int(9.20e14), + int4_ops_per_second=int(3.68e15), + sparse_core=SparseCoreInfo( + num_cores=2, + num_subcores=16, + num_lanes=8, + dma_granule_size_bytes=32, + ), + ) + case "TPU7x": + num_cores = core.get_num_device_cores() + num_chip_cores = 2 + return TpuInfo( + chip_version=ChipVersion.TPU_7X, + generation=7, + num_cores=num_cores, + num_lanes=128, + num_sublanes=8, + mxu_column_size=256, + vmem_capacity_bytes=64 * 1024 * 1024, # 64 MiB per core + cmem_capacity_bytes=0, + smem_capacity_bytes=1024 * 1024, # 1 MiB per core + hbm_capacity_bytes=206_000_000_000 // num_chip_cores, + mem_bw_bytes_per_second=int(7.40e12 // num_chip_cores), + bf16_ops_per_second=int(2.31e15 // num_chip_cores), + int8_ops_per_second=0, # Not Available + fp8_ops_per_second=int(4.60e15 // num_chip_cores), + int4_ops_per_second=0, # Not Available + sparse_core=SparseCoreInfo( + num_cores=4, + num_subcores=16, + num_lanes=16, + dma_granule_size_bytes=64, + ), + ) + case _ as d: + if d in registry: + return registry[d]() + raise ValueError(f"Unsupported TPU device kind: {device_kind}") diff --git a/jax/_src/pallas/mosaic/verification.py b/jax/_src/pallas/mosaic/verification.py deleted file mode 100644 index 08ff58770804..000000000000 --- a/jax/_src/pallas/mosaic/verification.py +++ /dev/null @@ -1,662 +0,0 @@ -# Copyright 2024 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import contextlib -import dataclasses -import io -import itertools -import math -import textwrap -from typing import Any, Sequence -from jax import lax -from jax._src import core as jax_core -from jax._src import tree_util -from jax._src.lib import tpu -from jax._src.pallas.mosaic import lowering -from jax._src.pallas.mosaic import primitives -from jax._src.util import split_list, unzip2 -from jaxlib.mlir import ir -from jaxlib.mlir.dialects import arith -from jaxlib.mlir.dialects import func -from jaxlib.mlir.passmanager import PassManager - -_UNSPECIFIED = object() -Var = str - -# TODO(apaszke): Add checks that semaphores are always left at 0. -# TODO(apaszke): Add checks that no remote resources are used while the remote -# device is not in the kernel (both before and after). -# TODO(apaszke): Model 0-sized DMAs faithfully. - - -PREAMBLE = """ -#define buf_readers(index, device, core) _buf_readers[(index)*(NDEVICE*NCORE) + (device)*NCORE + core] -#define buf_written(index, device, core) _buf_written[(index)*(NDEVICE*NCORE) + (device)*NCORE + core] -#define sems(index, device, core) _sems[(index)*(NDEVICE*NCORE) + (device)*NCORE + core] -#define barrier_sems(device, core) _barrier_sems[(device)*NCORE + core] - -#ifndef NDMA -#define NDMA 2 -#endif - -mtype = { DMA }; -chan dma_queue = [NDMA] of { mtype, int, int, int, int, int, int, int, int, int, int }; -""" - -DMA_PROCESS = """ -active [NDMA] proctype DmaEngine() { - int src_dev, src_core, src_sem, src_buf_base, src_buf_len; - int dst_dev, dst_core, dst_sem, dst_buf_base, dst_buf_len; - do - :: skip; - end: dma_queue?DMA(src_dev, src_core, src_sem, src_buf_base, src_buf_len, dst_dev, dst_core, dst_sem, dst_buf_base, dst_buf_len); - d_step { - printf("DMA read done: [%d, %d)@{%d, %d} (%d++)\\n", src_buf_base, src_buf_base + src_buf_len, src_dev, src_core, src_sem); - int i; - for (i : src_buf_base .. src_buf_base + src_buf_len - 1) { - buf_readers(i, src_dev, src_core)--; - } - sems(src_sem, src_dev, src_core)++; - } // Read complete - d_step { - printf("DMA write done: [%d, %d)@{%d, %d} (%d++)\\n", dst_buf_base, dst_buf_base + dst_buf_len, dst_dev, dst_core, dst_sem); - int i; - for (i : dst_buf_base .. dst_buf_base + dst_buf_len - 1) { - buf_written(i, dst_dev, dst_core)--; - } - sems(dst_sem, dst_dev, dst_core)++; - } // Write complete - od -} -""" - -class PrintCtx: - MAX_REF_UNROLL = 8 - - def __init__(self, iteration_bounds): - self.level = 1 - self.num_semaphores = 0 - self.num_buffers = 0 - self.locals = [] - self.counter = itertools.count() - self.env: dict[ir.Value, Var | int] = {} - self.program_ids = tuple(f"pid{i}" for i in range(len(iteration_bounds))) - self.device_id = "dev_id" - - # TODO(apaszke): Clean up core_id! This is not a visible detail in the Mosaic - # programming model. - self.emit(None, "int core_id = 0") - # Reconstruct device id and program ids from the pid. - self.emit(None, "int dev_id") - if iteration_bounds: - self.emit(None, f"int {', '.join(self.program_ids)}") - with self.block("d_step {", "}"): - idx = "_pid" - program_ids = [] - for i, b in reversed(list(enumerate(iteration_bounds))): - program_ids.append(self.emit(None, f"pid{i} = {idx} % {b}")) - idx = self.emit("int", f"{idx} / {b}") - self.emit(None, f"dev_id = {idx}") - - def emit_global_ref(self, shape: Sequence[int]): - slots = 1 - if shape and shape[0] <= self.MAX_REF_UNROLL: - slots = shape[0] - base = self.num_buffers - self.num_buffers += slots - return GlobalRefModel(base, slots) - - def emit_global_semaphore_ref(self, shape: Sequence[int]): - count = math.prod(shape) - base = self.num_semaphores - self.num_semaphores += count - return GlobalSemaphoreModel(base, count) - - def _indent(self, text: str) -> str: - return textwrap.indent(text, " " * self.level) - - def emit(self, ty, expr): - name = None - if ty is not None: - name = "l" + str(next(self.counter)) - expr = f"{ty} {name} = {expr}" - self.locals.append(self._indent(expr) + ";\n") - return name - - def comment(self, comment): - self.locals.append(self._indent(f"/* {comment} */\n")) - - @contextlib.contextmanager - def block(self, begin: str, end: str): - self.locals.append(self._indent(begin) + "\n") - self.level += 1 - yield - self.level -= 1 - self.locals.append(self._indent(end) + "\n") - - @contextlib.contextmanager - def comment_if_emitted(self, comment): - self.comment(comment) - yield - self.comment(comment) - if self.locals[-1] == self.locals[-2]: - self.locals.pop() - self.locals.pop() - - def get(self, value: ir.Value, default: Any = _UNSPECIFIED): - if default is _UNSPECIFIED: - return self.env[value] - else: - return self.env.get(value, default) - - def set(self, value: ir.Value, model_value: Any): - self.env[value] = model_value - - def get_model( - self, - has_barrier_sems: bool, - num_devices: int, - num_cores: int, - parallel_iteration_bounds: Sequence[int], - ) -> str: - result = io.StringIO() - result.write(f"#define NDEVICE {num_devices}\n") - result.write("#define NCORE 1\n") - result.write(f"byte _buf_readers[{self.num_buffers}*NDEVICE*NCORE] = 0;\n") - result.write(f"bool _buf_written[{self.num_buffers}*NDEVICE*NCORE] = 0;\n") - result.write(f"byte _sems[{self.num_semaphores}*NDEVICE*NCORE] = 0;\n") - if has_barrier_sems: - result.write("byte _barrier_sems[NDEVICE*NCORE] = 0;\n") - result.write(PREAMBLE) - result.write("\n") - parallel_threads = math.prod(parallel_iteration_bounds) - result.write(f"active [NDEVICE*{parallel_threads}] proctype Kernel() {{\n") - for l in self.locals: - result.write(l) - result.write("}\n") - result.write(DMA_PROCESS) - return result.getvalue() - - -def resolve_location(location): - if location is None: - location = [None, None] - else: - location = list(location) - if location[0] is None: - location[0] = "dev_id" - if location[1] is None: - location[1] = "core_id" - return tuple(location) - - -@dataclasses.dataclass(frozen=True) -class GlobalRefModel: - """A model of a memory reference. - - When a reference has a small leading dimension, it might be represented by - multiple slots in the reference array. Its region starts at base (that can be - dynamic) and has the given length (always static). - """ - base: Any - length: int - - def readers_at(self, location): - dev, core = resolve_location(location) - return [f"buf_readers({self.base} + {i}, {dev}, {core})" for i in range(self.length)] - - def written_at(self, location): - dev, core = resolve_location(location) - return [f"buf_written({self.base} + {i}, {dev}, {core})" for i in range(self.length)] - - -@dataclasses.dataclass(frozen=True) -class GlobalSemaphoreModel: - """A model of a semaphore reference. - - Semaphore arrays are always fully unrolled and are represented by a contiguous - subset of the global semaphore array. - """ - base: Any - length: int - - def at(self, location): - dev, core = resolve_location(location) - return f"sems({self.base}, {dev}, {core})" - - -@dataclasses.dataclass(frozen=True) -class GlobalBarrierSemaphoreModel: - def at(self, location): - dev, core = resolve_location(location) - return f"barrier_sems({dev}, {core})" - - -def _print_op(ctx, op): - match op.name: - case "tpu.region": - _print_block(ctx, op.body) - case "tpu.device_id": - return ctx.device_id - case "arith.constant": - if ir.IntegerType.isinstance(op.result.type): - return str(ir.IntegerAttr(op.value).value) - else: - return - case "tpu.sem_signal": - location = resolve_location((ctx.get(op.device_id, None), ctx.get(op.core_id, None))) - sem_model = ctx.get(op.semaphore) - sem = sem_model.at(location) - amount = ctx.get(op.amount) - if isinstance(sem_model, GlobalBarrierSemaphoreModel): - ctx.emit(None, f'printf("Signal: BARRIER@{{%d, %d}} += %d\\n", {location[0]}, {location[1]}, {amount})') - else: - ctx.emit(None, f'printf("Signal: %d@{{%d, %d}} += %d\\n", {sem_model.base}, {location[0]}, {location[1]}, {amount})') - ctx.emit(None, f"d_step {{ {sem} = {sem} + {amount} }}") - case "tpu.sem_wait": - sem_model = ctx.get(op.semaphore) - sem = sem_model.at(location=None) - amount = ctx.get(op.amount) - ctx.emit(None, f"atomic {{ {sem} >= {amount}; {sem} = {sem} - {amount} }}") - if isinstance(sem_model, GlobalBarrierSemaphoreModel): - ctx.emit(None, f'printf("Wait done: BARRIER -= %d\\n", {amount})') - else: - ctx.emit(None, f'printf("Wait done: %d -= %d\\n", {sem_model.base}, {amount})') - case "tpu.enqueue_dma": - dst_location = resolve_location((ctx.get(op.device_id, None), ctx.get(op.core_id, None))) - src = ctx.get(op.source) - src_sem = ctx.get(op.source_semaphore) - dst = ctx.get(op.target) - dst_sem = ctx.get(op.target_semaphore) - src_readonly = "\n && ".join(is_written + " == 0" for is_written in src.written_at(None)) - dst_unused = "\n && ".join( - is_written + " == 0" - for is_written in itertools.chain( - dst.written_at(dst_location), dst.readers_at(dst_location) - ) - ) - ctx.emit( - None, - 'printf("DMA: [%d, %d)@{%d, %d} -> [%d, %d)@{%d, %d}\\n",' - f" {src.base}, {src.base} + {src.length}, dev_id, core_id," - f" {dst.base}, {dst.base} + {dst.length}, {dst_location[0]}," - f" {dst_location[1]})", - ) - with ctx.block("d_step {", "}"): - ctx.emit(None, f"assert({src_readonly}); // Source is not written to.") - ctx.emit(None, f"assert({dst_unused}); // Destination is unused.") - for r in src.readers_at(None): - ctx.emit(None, f"{r}++") - for w in dst.written_at(dst_location): - ctx.emit(None, f"{w} = 1") - ctx.emit( - None, - f"dma_queue!DMA(dev_id, core_id, {src_sem.base}, {src.base}," - f" {src.length}, {dst_location[0]}, {dst_location[1]}," - f" {dst_sem.base}, {dst.base}, {dst.length})", - ) - case "tpu.wait_dma": - sem_model = ctx.get(op.semaphore) - sem = sem_model.at(location=None) - ctx.emit(None, f"atomic {{ {sem} >= 1; {sem} = {sem} - 1 }}") - ctx.emit(None, f'printf("Awaited DMA: %d\\n", {sem_model.base})') - case "tpu.sem_barrier": - return GlobalBarrierSemaphoreModel() - case "tpu.memref_slice": - result = ctx.get(op.mem_ref, None) - if result is None: - return NotImplemented - src_shape = ir.MemRefType(op.mem_ref.type).shape - dst_shape = ir.MemRefType(op.result.type).shape - dynamic = ir.ShapedType.get_dynamic_size() - # We always unroll semaphore references entirely, and we need to be - # faithful when slicing them. - if isinstance(result, GlobalSemaphoreModel): - # We only support contiguous slices of semaphore arrays at the moment. - seen_nontrivial_unequal = False - for s, d in zip(src_shape, dst_shape): - if d == 1: - continue - if s != d: - if seen_nontrivial_unequal: - raise NotImplementedError("Non-contiguous slices of semaphore arrays") - seen_nontrivial_unequal = True - strides = [] - stride = 1 - for s in src_shape[::-1]: - strides.append(stride) - stride *= s - strides = reversed(strides) - indices = [ctx.get(idx) for idx in op.base_idx] - linear_offset = " + ".join(f"{idx} * {s}" for idx, s in zip(indices, strides)) - return GlobalSemaphoreModel( - base=f"{result.base} + {linear_offset}", length=math.prod(dst_shape) - ) - else: - assert isinstance(result, GlobalRefModel) - major_idx = ctx.get(op.base_idx[0], None) - if (not src_shape or src_shape[0] == dynamic or dst_shape[0] == dynamic - or result.length == 1 or major_idx is None): - return result - return GlobalRefModel(f"{result.base} + {major_idx}", dst_shape[0]) - case "tpu.memref_squeeze": - result = ctx.get(op.input, None) - return NotImplemented if result is None else result - case "tpu.assume_multiple": - result = ctx.get(op.value, None) - return NotImplemented if result is None else result - case "arith.addi": - return bin_op(ctx, "int", "+", *op.operands) - case "arith.subi": - return bin_op(ctx, "int", "-", *op.operands) - case "arith.muli": - return bin_op(ctx, "int", "*", *op.operands) - case "arith.remsi": - # TODO(apaszke): Make sure this has right semantics for negative integers. - return bin_op(ctx, "int", "%", *op.operands) - case "arith.divsi": - return bin_op(ctx, "int", "/", *op.operands) - case "arith.andi": - return bin_op(ctx, _model_type(op.result.type), "&", *op.operands) - case "arith.select": - cond, if_true, if_false = map(lambda o: ctx.get(o, None), op.operands) - if cond is None or if_true is None or if_false is None: - return NotImplemented - result_ty = _model_type(op.result.type) - return ctx.emit(result_ty, f"({cond} -> {if_true} : {if_false})") - case "arith.index_cast": - model = ctx.get(op.operands[0], None) - return ctx.emit("int", model) if model is not None else NotImplemented - case "arith.cmpi": - match op.predicate.value: - case arith.CmpIPredicate.eq: - return bin_op(ctx, "bool", "==", *op.operands) - case arith.CmpIPredicate.ne: - return bin_op(ctx, "bool", "!=", *op.operands) - case arith.CmpIPredicate.slt: - return bin_op(ctx, "bool", "<", *op.operands) - case arith.CmpIPredicate.sle: - return bin_op(ctx, "bool", "<=", *op.operands) - case arith.CmpIPredicate.sgt: - return bin_op(ctx, "bool", ">", *op.operands) - case arith.CmpIPredicate.sge: - return bin_op(ctx, "bool", ">=", *op.operands) - return bin_op(ctx, "bool", "/", *op.operands) - case "tpu.trace_start": - ctx.comment(op.message.value) - case "tpu.assume_multiple": - # TODO(apaszke): Add an assertion - return ctx.get(op.value, NotImplemented) - case "verification.pretend": - read_refs = [] - for o in op.operands: - if (model := ctx.get(o, None)) is None: - raise ValueError(f"Could not model the read of {o}") - read_refs.append(model) - with ctx.block("d_step {", "}"): # Start reading - for r in read_refs: - for loc in r.written_at(None): - ctx.emit(None, f"assert(!{loc})") - for loc in r.readers_at(None): - ctx.emit(None, f"{loc}++") - with ctx.block("d_step {", "}"): # Stop reading - for r in read_refs: - for loc in r.readers_at(None): - ctx.emit(None, f"{loc}--") - case "vector.load": - ref = ctx.get(op.operands[0]) - assert isinstance(ref, GlobalRefModel) - if (first_idx := ctx.get(op.operands[1], None)) is not None: - leading_load_len = ir.VectorType(op.result.type).shape[0] - ref = GlobalRefModel(f"{ref.base} + {first_idx}", leading_load_len) - with ctx.block("d_step {", "}"): # Start reading - for loc in ref.written_at(None): - ctx.emit(None, f"assert(!{loc})") - for loc in ref.readers_at(None): - ctx.emit(None, f"{loc}++") - with ctx.block("d_step {", "}"): # Stop reading - for loc in ref.readers_at(None): - ctx.emit(None, f"{loc}--") - return NotImplemented # We don't model the result of the load. - case "vector.store": - ref = ctx.get(op.operands[1]) # Stored value goes first - assert isinstance(ref, GlobalRefModel) - if (first_idx := ctx.get(op.operands[2], None)) is not None: - leading_store_len = ir.VectorType(op.operands[0].type).shape[0] - ref = GlobalRefModel(f"{ref.base} + {first_idx}", leading_store_len) - with ctx.block("d_step {", "}"): # Start writing - for loc in ref.readers_at(None): - ctx.emit(None, f"assert(!{loc})") - for loc in ref.written_at(None): - ctx.emit(None, f"assert(!{loc})") - ctx.emit(None, f"{loc} = 1") - with ctx.block("d_step {", "}"): # Stop reading - for loc in ref.written_at(None): - ctx.emit(None, f"{loc} = 0") - case "scf.for": - carrys = [ - ctx.emit("int", ctx.get(arg)) - if ir.IntegerType.isinstance(arg.type) else None - for arg in op.initArgs - ] - bounds = (op.lowerBound, op.upperBound, op.step) - lower, upper, step = bound_models = map(ctx.get, bounds) - for model, v in zip(bound_models, bounds): - if model is None: - raise ValueError(f"Could not model loop bound or step: {v}") - induction_var = ctx.emit("int", lower) - with ctx.block("do", "od"): - ctx.emit(None, f":: {induction_var} < {upper}; ") - ctx.set(op.induction_variable, induction_var) - for c, arg in zip(carrys, op.inner_iter_args, strict=True): - if c is not None: - ctx.set(arg, c) - _print_block(ctx, op.body) - terminator = op.body.operations[len(op.body.operations) - 1] - new_carrys = terminator.operands - with ctx.block("d_step {", "}"): - for c, new in zip(carrys, new_carrys, strict=True): - if c is not None: - ctx.emit(None, f"{c} = {ctx.get(new)}") - ctx.emit(None, f"{induction_var} = {induction_var} + {step}") - ctx.emit(None, ":: else -> break") - ctx.emit(None, "skip") # To avoid "Jump into d_step sequence errors" - if len(carrys) == 1: - return carrys[0] - else: - return tuple(carrys) - case "scf.if": - if op.results: - raise NotImplementedError - if (condition := ctx.get(op.condition, None)) is None: - raise ValueError(f"Could not model branch condition: {op.condition}") - with ctx.block("if", "fi"): - ctx.emit(None, f":: ({condition})") - _print_block(ctx, op.then_block) - if op.regions[1].blocks: - ctx.emit(None, ":: else") - _print_block(ctx, op.else_block) - else: - ctx.emit(None, ":: else -> skip") - case _: - if not op.regions: - return NotImplemented - raise NotImplementedError("Must handle all ops with regions") - - -def bin_op(ctx, result_ty, op, lhs, rhs): - lhs = ctx.get(lhs, None) - rhs = ctx.get(rhs, None) - if lhs is None or rhs is None: - return NotImplemented - return ctx.emit(result_ty, f"{lhs} {op} {rhs}") - - -def _model_type(ty): - if ir.IntegerType.isinstance(ty): - if ir.IntegerType(ty).width == 1: - return "bool" - else: - return "int" - else: - raise NotImplementedError(ty) - - -def _print_block(ctx, block): - for op in block: - try: - with ctx.comment_if_emitted(op.OPERATION_NAME): - results = _print_op(ctx, op) - except Exception as e: - raise RuntimeError(f"Failed to print op: {op}") from e - if results is NotImplemented: - continue - if not op.results: - assert results is None or results == () - elif len(op.results) > 1: - raise NotImplementedError(op) - else: - ctx.set(op.result, results) - - -def export_promela_model( - module, num_devices: int, num_cores_per_device: int -) -> str: - with module.context: - _, uses_barrier_semaphores = tpu.private_has_communication(module.operation) - # Clone the module and simplify it to make the model smaller and simpler. - module = ir.Module.parse(module.operation.get_asm(binary=True)) - passes = ["canonicalize", "cse"] - pipeline = PassManager.parse(f"builtin.module({','.join(passes)})") - pipeline.run(module.operation) - main_str_attr = ir.StringAttr.get("main") - for f in module.body: - if getattr(f, "name", None) == main_str_attr: - break - else: - raise ValueError("No main function found") - assert isinstance(f, func.FuncOp) - - iteration_bounds: Sequence[int] = () - if "iteration_bounds" in f.attributes: - iteration_bounds = ir.DenseI64ArrayAttr(f.attributes["iteration_bounds"]) # type: ignore - dynamic = ir.ShapedType.get_dynamic_size() - if any(b == dynamic for b in iteration_bounds): - raise ValueError("Dynamic iteration bounds not supported") - - dimension_semantics = ir.ArrayAttr(f.attributes["dimension_semantics"]) - - parallel = ir.Attribute.parse("#tpu.dimension_semantics") - if any(s != parallel for s in dimension_semantics): - raise NotImplementedError("Non-parallel dimensions not supported") - - num_scalar_prefetch = 0 - if "scalar_prefetch" in f.attributes: - num_scalar_prefetch = ir.IntegerAttr(f.attributes["scalar_prefetch"]).value - - (entry_block,) = f.body - ctx = PrintCtx(iteration_bounds) - sem_ty = ir.Type.parse("!tpu.semaphore") - dma_sem_ty = ir.Type.parse("!tpu.dma_semaphore") - program_id_args, prefetch_args, other_args = split_list( - entry_block.arguments, [len(iteration_bounds), num_scalar_prefetch] - ) - for arg, model in zip(program_id_args, ctx.program_ids, strict=True): - ctx.set(arg, model) - del prefetch_args # We ignore prefetch_args - for arg in other_args: - if ir.MemRefType.isinstance(arg.type): - ty = ir.MemRefType(arg.type) - if ty.element_type == sem_ty or ty.element_type == dma_sem_ty: - ctx.set(arg, ctx.emit_global_semaphore_ref(ty.shape)) - else: - ctx.set(arg, ctx.emit_global_ref(ty.shape)) - _print_block(ctx, entry_block) - return ctx.get_model( - uses_barrier_semaphores, num_devices, num_cores_per_device, iteration_bounds - ) - - -assume_p = jax_core.Primitive("assume_for_verification") -assume_p.def_impl(lambda x, y: x) - -@assume_p.def_abstract_eval -def _assume_abstract_eval(x, y): - assert jax_core.typematch(x, y) - return x - -def _assume_lowering(ctx: lowering.LoweringRuleContext, x, y): - return y if ctx.lowering_context.for_verification else x - -lowering.lowering_rules[assume_p] = _assume_lowering # type: ignore - -def assume(normally, *, when_verifying): - return assume_p.bind(normally, when_verifying) - - -pretend_p = jax_core.Primitive("pretend_for_verification") -pretend_p.multiple_results = True - -@pretend_p.def_abstract_eval -def _pretend_abstract_eval(*_, **params): - del params # Unused. - return () - -def _pretend_lowering(ctx: lowering.LoweringRuleContext, *flat_args, tree): - if ctx.lowering_context.for_verification: - (base_read_refs, transforms) = tree_util.tree_unflatten(tree, flat_args) - read_ref_avals, _ = tree_util.tree_unflatten(tree, ctx.avals_in) - block_shapes, _ = tree_util.tree_unflatten(tree, ctx.block_shapes) - read_refs = [ - lowering._index_ref(ref, aval, block_shape, indexer)[0] - for ref, aval, block_shape, indexer in zip( - base_read_refs, - read_ref_avals, - block_shapes, - transforms, - strict=True, - ) - ] - ir.Operation.create("verification.pretend", operands=read_refs) - return () - -lowering.lowering_rules[pretend_p] = _pretend_lowering # type: ignore - -def pretend(read_refs): - refs, transforms = unzip2( - primitives._get_ref_and_transforms(r) for r in read_refs - ) - flat_args, tree = tree_util.tree_flatten((refs, transforms)) - return pretend_p.bind(*flat_args, tree=tree) - - -def skip(f): - """Skips the verification of the given function.""" - def wrapper(*args, **kwargs): - is_not_verifying = assume(normally=1, when_verifying=0) - lax.cond(is_not_verifying)(lambda: f(*args, **kwargs)) - return wrapper - - -def define_model(model): - """Replaces a function with its simplified model during verification.""" - def decorator(f): - def wrapper(*args, **kwargs): - lax.cond( - assume(normally=1, when_verifying=0), - lambda: f(*args, **kwargs), - lambda: model(*args, **kwargs), - ) - return wrapper - return decorator diff --git a/jax/_src/pallas/mosaic_gpu/BUILD b/jax/_src/pallas/mosaic_gpu/BUILD index e5b491aef330..74bf46056bd6 100644 --- a/jax/_src/pallas/mosaic_gpu/BUILD +++ b/jax/_src/pallas/mosaic_gpu/BUILD @@ -42,13 +42,17 @@ pytype_strict_library( name = "pallas_call_registration", srcs = ["pallas_call_registration.py"], deps = [ + ":core", ":lowering", "//jax", - "//jax:core", - "//jax:mlir", - "//jax:mosaic_gpu", + "//jax/_src:config", + "//jax/_src:core", + "//jax/_src:frozen_dict", + "//jax/_src:mlir", + "//jax/_src:sharding_impls", "//jax/_src/pallas", - ], + "//jax/experimental:mosaic_gpu", + ] + py_deps("numpy"), ) pytype_strict_library( @@ -57,30 +61,50 @@ pytype_strict_library( deps = [ ":core", "//jax", - "//jax:core", - "//jax:mlir", - "//jax:mosaic_gpu", - "//jax:pallas", - "//jax:partial_eval", - "//jax:source_info_util", - "//jax:util", + "//jax/_src:api", + "//jax/_src:checkify", + "//jax/_src:config", + "//jax/_src:core", + "//jax/_src:debugging", + "//jax/_src:dtypes", + "//jax/_src:lax", + "//jax/_src:literals", + "//jax/_src:mesh", + "//jax/_src:mlir", + "//jax/_src:partial_eval", + "//jax/_src:source_info_util", + "//jax/_src:state_types", + "//jax/_src:tree_util", + "//jax/_src:util", "//jax/_src/lib", "//jax/_src/pallas", + "//jax/experimental:mosaic_gpu", + "//jax/experimental:pallas", ] + py_deps("numpy"), ) pytype_strict_library( name = "core", srcs = ["core.py"], + tags = [ + # Avoid circular dependency. + "ignore_for_dep=third_party.py.jax._src.pallas.mosaic_gpu.primitives.wgmma_accumulator_deref", + ], deps = [ "//jax", - "//jax:core", - "//jax:dtypes", - "//jax:effects", - "//jax:mosaic_gpu", - "//jax:state_types", - "//jax:tree_util", + "//jax/_src:core", + "//jax/_src:custom_batching", + "//jax/_src:dtypes", + "//jax/_src:effects", + "//jax/_src:frozen_dict", + "//jax/_src:lax", + "//jax/_src:pretty_printer", + "//jax/_src:state_types", + "//jax/_src:tree_util", + "//jax/_src:util", + "//jax/_src/lib", "//jax/_src/pallas", + "//jax/experimental:mosaic_gpu", "//jaxlib/mlir:ir", ] + py_deps("numpy"), ) @@ -92,14 +116,19 @@ pytype_strict_library( ":core", ":lowering", "//jax", - "//jax:core", - "//jax:mlir", - "//jax:mosaic_gpu", - "//jax:tree_util", - "//jax:util", + "//jax/_src:core", + "//jax/_src:debugging", + "//jax/_src:dtypes", + "//jax/_src:lax", + "//jax/_src:literals", + "//jax/_src:pretty_printer", + "//jax/_src:state_types", + "//jax/_src:tree_util", + "//jax/_src:util", "//jax/_src/lib", "//jax/_src/pallas", - ], + "//jax/experimental:mosaic_gpu", + ] + py_deps("numpy"), ) pytype_strict_library( @@ -109,11 +138,39 @@ pytype_strict_library( ":core", ":primitives", "//jax", - "//jax:core", - "//jax:mosaic_gpu", - "//jax:pallas", - "//jax:partial_eval", - "//jax:util", + "//jax/_src:core", + "//jax/_src:partial_eval", + "//jax/_src:state_types", + "//jax/_src:util", "//jax/_src/pallas", + "//jax/experimental:mosaic_gpu", + "//jax/experimental:pallas", ], ) + +pytype_strict_library( + name = "helpers", + srcs = ["helpers.py"], + deps = [ + ":core", + ":primitives", + "//jax", + "//jax/_src:dtypes", + "//jax/_src:util", + "//jax/_src/lib", + "//jax/_src/pallas", + "//jax/experimental:mosaic_gpu", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "torch", + srcs = ["torch.py"], + deps = [ + "//jax", + "//jax/_src:dtypes", + "//jax/_src:util", + "//jax/_src/lib", + "//jax/experimental:mosaic_gpu", + ] + py_deps("numpy"), +) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 630c1b8f4bed..69b16945684c 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -18,30 +18,48 @@ import abc import collections -from collections.abc import Iterable, Sequence +from collections.abc import Callable, Hashable, Iterable, Mapping, Sequence import dataclasses import enum +import functools import itertools as it -from typing import Any, ClassVar, Literal +import math +from typing import Any, ClassVar, Literal, Union import jax from jax._src import core as jax_core +from jax._src import custom_batching from jax._src import dtypes from jax._src import effects +from jax._src import frozen_dict +from jax._src import lax +from jax._src import pretty_printer as pp +from jax._src import state from jax._src import tree_util +from jax._src import util +from jax._src.lib.mlir.dialects import arith as arith_dialect from jax._src.pallas import core as pallas_core +from jax._src.pallas import primitives as pallas_primitives +from jax._src.state import discharge as state_discharge from jax._src.state import indexing from jax._src.state import types as state_types -from jax._src.state import discharge as state_discharge import jax.experimental.mosaic.gpu as mgpu +from jax.experimental.mosaic.gpu import tcgen05 +from jax.experimental.mosaic.gpu import utils as mgpu_utils import jax.numpy as jnp from jaxlib.mlir import ir -AbstractMemoryRef = pallas_core.AbstractMemoryRef +_Ref = state.AbstractRef | state_types.TransformedRef DimensionSemantics = Literal["parallel", "sequential"] +# We align all our SMEM allocations to 1024 bytes. TMA and WGMMA are very +# sensitive to alignment and while this is quite conservative, it gets the job +# done. We should make this more refined in the future. +SMEM_ALIGNMENT = 1024 +TMEM_COL_ALIGNMENT = 4 + def is_trivial_index(idx, shape) -> bool: """Checks if the index selects the entire shape.""" @@ -58,7 +76,7 @@ def _slices(d): @dataclasses.dataclass(frozen=True, kw_only=True) -class GPUCompilerParams(pallas_core.CompilerParams): +class CompilerParams(pallas_core.CompilerParams): """Mosaic GPU compiler parameters. Attributes: @@ -74,32 +92,48 @@ class GPUCompilerParams(pallas_core.CompilerParams): references. Defaults to 0, and must be strictly smaller than max_concurrent_steps. Generally, you'll want to set it to 1 if you don't await the WGMMA in the body. + unsafe_no_auto_barriers: If True, Pallas will never automatically insert + barrier instructions that ensure synchronous semantics of loads and stores. + At the moment, the insertion is done conservatively and might regress + performance. There are (at least) two conditions that must be satisfied + for the use of this flag to be safe. First, no memory region is ever read + *and* written to by the same thread (async copies are performed by + background threads and do not count towards this rule). Secondly, no + thread ever calls commit_smem(), reads from the committed SMEM and then + issues an async copy overwriting that region (this is a very artificial + and highly unlikely scenario). profile_space: The number of profiler events that can be collected in a single invocation. It is undefined behavior if a thread collects more events than this. profile_dir: The directory to which profiling traces will be written to. """ - PLATFORM: ClassVar[str] = "mosaic_gpu" + BACKEND: ClassVar[pallas_core.Backend] = "mosaic_gpu" approx_math: bool = False dimension_semantics: Sequence[DimensionSemantics] | None = None max_concurrent_steps: int = 1 - delay_release: int = 0 + unsafe_no_auto_barriers: bool = False profile_space: int = 0 profile_dir: str = "" - thread_semantics: mgpu.core.ThreadSemantics = mgpu.core.ThreadSemantics.Lane + lowering_semantics: mgpu.core.LoweringSemantics = mgpu.core.LoweringSemantics.Lane def __post_init__(self): + if self.dimension_semantics is not None: + object.__setattr__( + self, "dimension_semantics", tuple(self.dimension_semantics) + ) if bool(self.profile_space) ^ bool(self.profile_dir): raise ValueError( "Either both profile_space and profile_dir must be set, or neither." ) -class GPUMemorySpace(enum.Enum): +class MemorySpace(enum.Enum): #: Global memory. GMEM = "gmem" #: Shared memory. SMEM = "smem" + #: Tensor memory. New addition to Blackwell. Not available on Hopper. + TMEM = "tmem" #: Registers. REGS = "regs" @@ -108,30 +142,199 @@ def __str__(self) -> str: def __call__( self, - shape: tuple[int, ...], + shape: Sequence[int], dtype: jnp.dtype, + *, transforms: Sequence[MemoryRefTransform] = (), - + packed: bool | None = None, + collective: bool | None = None, + layout: TMEMLayout | None = None, ) -> pallas_core.MemoryRef: - # A convenience function for constructing MemoryRef types. - return GPUMemoryRef(shape, dtype, memory_space=self, transforms=transforms) - - -def kernel(body, out_shape, compiler_params=None, **mesh_kwargs): + shape = tuple(shape) + # TODO(sharadmv): Add HiType constructor support. + if self == MemorySpace.TMEM: + if transforms: + raise ValueError("transforms are not supported for TMEM") + if collective is None: + collective = False + if layout is None: + if packed is None: + if dtypes.itemsize_bits(dtype) != 32: + raise ValueError( + "dtypes narrower than 32-bit require either the packed argument" + " or an explicit TMEM layout" + ) + packed = False + mgpu_layout = infer_tmem_layout( + shape, dtype, packed=packed, collective=collective + ) + else: + if packed is not None: + raise ValueError("packed cannot be specified if layout is specified.") + mgpu_layout = layout.to_mgpu() + else: + if packed is not None or collective is not None or layout is not None: + raise ValueError("packed, collective and layout arguments are only supported for TMEM.") + mgpu_layout = None + return GPUMemoryRef(jax_core.ShapedArray(shape, dtype), memory_space=self, + transforms=transforms, layout=mgpu_layout, + collective=collective) + + +class SemaphoreType(enum.Enum): + REGULAR = "regular" + BARRIER = "barrier" + + def __call__(self, shape: tuple[int, ...]): + dtype: Any + if self == SemaphoreType.BARRIER: + dtype = pallas_core.BarrierSemaphore() + else: + dtype = pallas_core.Semaphore() + return pallas_core.MemoryRef(jax_core.ShapedArray(shape, dtype), + MemorySpace.GMEM) + + def get_array_aval(self) -> jax_core.ShapedArray: + return self(()).get_array_aval() + + def get_ref_aval(self) -> _Ref: + return self(()).get_ref_aval() + + +class PrimitiveSemantics(enum.Enum): + """Thread semantics for a primitives at the Pallas user-level.""" + + Warp = enum.auto() + Warpgroup = enum.auto() + + +# Convenience constants for (lowering, primitive) thread semantics pairs. +LANExWG_SEMANTICS = ( + mgpu.LoweringSemantics.Lane, PrimitiveSemantics.Warpgroup) +LANExWARP_SEMANTICS = ( + mgpu.LoweringSemantics.Lane, PrimitiveSemantics.Warp) +WGxWG_SEMANTICS = ( + mgpu.LoweringSemantics.Warpgroup, PrimitiveSemantics.Warpgroup) + + +# TODO(justinfu): Reconcile with pl.kernel. +def kernel( + body: Callable[..., None], + out_shape: object, + *, + scratch_shapes: pallas_core.ScratchShapeTree = (), + compiler_params: pallas_core.CompilerParams | None = None, + # Mesh kwargs + grid: tuple[int, ...] = (), + grid_names: tuple[str, ...] = (), + cluster: tuple[int, ...] = (), + cluster_names: tuple[str, ...] = (), + num_threads: int | None = None, + thread_name: str | None = None, + interpret: Any = None, + **mesh_kwargs: object, +): + """Entry point for defining a Mosaic GPU kernel. + + Args: + body: The kernel body, which should take as arguments the input, output, + and scratch Refs. The number of input Refs is determined by the number + of arguments passed into kernel returned by this function. The number of + output and scratch Refs are determined by `out_shape` and `scratch_shapes` + respectively. + out_shape: a PyTree of :class:`jax.ShapeDtypeStruct` describing the shape + and dtypes of the outputs. + scratch_shapes: an iterable (may be nested) of GPUMemoryRef describing + scratch Refs to allocate for this kernel. + compiler_params: Additional compiler options. See the `CompilerParams` + dataclass for more details. + grid: A tuple of integers specifying the size of the kernel grid. + grid_names: The axis names of the grid. Must be the same length as `grid`. + cluster: A tuple of integers specifying the size of the kernel cluster. + cluster_names: The axis names of the grid. Must be the same length as + `cluster`. + num_threads: The number of threads to launch per block. Note that these + do not correspond to CUDA threads, but rather to warpgroups on Hopper + and Blackwell GPUs. + thread_name: The axis name used to query the thread index. + **mesh_kwargs: Additional mesh kwargs. See `Mesh` for more details. + + Returns: + A function that runs the kernel. It should take any number of input + operands and returns an output with the same PyTree structure as + `out_shape`. + """ if unwrap_out := not isinstance(out_shape, (tuple, list)): out_shape = (out_shape,) + + @custom_batching.custom_vmap def wrapper(*operands): def stateful(operand_and_out_refs): operand_refs, out_refs = operand_and_out_refs + mesh = Mesh( + grid=grid, + grid_names=grid_names, + cluster=cluster, + cluster_names=cluster_names, + num_threads=num_threads, + thread_name=thread_name, + **mesh_kwargs) + _thread_name = mesh.thread_name if mesh.thread_name is not None else () def cmap_body(): - body(*operand_refs, *out_refs) + pallas_primitives.run_scoped( + functools.partial(body, *operand_refs, *out_refs), + *(scratch_shapes if isinstance(scratch_shapes, Sequence) else ()), + collective_axes=_thread_name, + **(scratch_shapes if isinstance(scratch_shapes, Mapping) else {}), + ) + if mesh.kernel_name is not None: + cmap_body.__name__ = mesh.kernel_name + else: + # The body function name is used to set the name of the kernel as a + # fallback if the kernel name is not set explicitly. + cmap_body.__name__ = getattr(body, "__name__", "anonymous") pallas_core.core_map( - GPUMesh(**mesh_kwargs), compiler_params=compiler_params + mesh, compiler_params=compiler_params, interpret=interpret )(cmap_body) - _, outs = state_discharge.run_state(stateful)( - (operands, jax.tree.map(jnp.zeros_like, out_shape)) - ) + _, outs = state_discharge.run_state(stateful)(( + operands, + jax.tree.map(lambda s: jax.lax.empty(s.shape, s.dtype), out_shape), + )) return outs[0] if unwrap_out else outs + + @wrapper.def_vmap + def _vmap_rule(axis_size, in_batched, *args): + axis_name = object() + + def batched_body(*refs): + idx = lax.axis_index(axis_name) + lens = (len(args), len(out_shape)) + operand_refs, out_refs, scratch_refs = util.split_list(refs, lens) + slice_ref = lambda r, b=True: (r.at[idx] if b else r) + operand_refs = tree_util.tree_map(slice_ref, operand_refs, in_batched) + out_refs = tree_util.tree_map(slice_ref, out_refs) + return body(*operand_refs, *out_refs, *scratch_refs) + + out_shape_ = out_shape[0] if unwrap_out else out_shape + add_batch_dim = lambda x: x.update(shape=(axis_size, *x.shape)) + mesh_kwargs_ = dict(mesh_kwargs) + out = kernel( + batched_body, + out_shape=tree_util.tree_map(add_batch_dim, out_shape_), + scratch_shapes=scratch_shapes, + compiler_params=compiler_params, + grid=(axis_size,) + grid, + grid_names=(axis_name,) + grid_names, + cluster=cluster, + cluster_names=cluster_names, + num_threads=num_threads, + thread_name=thread_name, + interpret=interpret, + **mesh_kwargs_, + )(*args) + out_batched = tree_util.tree_map(lambda _: True, out_shape_) + return out, out_batched + return wrapper @@ -139,13 +342,26 @@ def cmap_body(): class GPUMemoryRef(pallas_core.MemoryRef): transforms: Sequence[MemoryRefTransform] = () - def get_ref_aval(self) -> pallas_core.TransformedRef | AbstractMemoryRef: - aval = jax_core.ShapedArray(self.shape, self.dtype) + layout: tcgen05.TMEMLayout | None = dataclasses.field(default=None, kw_only=True) + collective: bool | None = dataclasses.field(default=None, kw_only=True) + + def __post_init__(self): + is_tmem = self.memory_space == MemorySpace.TMEM + assert (self.layout is not None) == is_tmem + assert (self.collective is not None) == is_tmem + assert not (self.transforms and is_tmem) + + def get_ref_aval(self) -> _Ref: + aval: Any = jax_core.ShapedArray(self.shape, self.dtype) for t in self.transforms: aval = t(aval) - ref = pallas_core.TransformedRef( - AbstractMemoryRef(aval, memory_space=self.memory_space), () - ) + if self.memory_space == MemorySpace.TMEM: + aval = AbstractTMEMRef( + aval, self.memory_space, self.layout, self.collective + ) + else: + aval = state.AbstractRef(aval, memory_space=self.memory_space) + ref = pallas_core.TransformedRef(aval, ()) for t in reversed(self.transforms): ref = t.undo(ref) if not ref.transforms: @@ -153,11 +369,234 @@ def get_ref_aval(self) -> pallas_core.TransformedRef | AbstractMemoryRef: return ref +def align_to(x: int, alignment: int): + if rem := x % alignment: + return x + alignment - rem + return x + + +# A tree of `GPUMemoryRef`s. +_GPUMemoryRefTree = Any + + +def _ref_group_size(refs: _GPUMemoryRefTree) -> int: + size = 0 + for ref in jax.tree.leaves(refs): + # Make sure that the start of each ref is aligned with `SMEM_ALIGNMENT`. + size = align_to(size, SMEM_ALIGNMENT) + if jnp.issubdtype(ref.dtype, jnp.integer): + nbits = jnp.iinfo(ref.dtype).bits + elif jnp.issubdtype(ref.dtype, jnp.floating): + nbits = jnp.finfo(ref.dtype).bits + else: + raise NotImplementedError(f"Unsupported dtype: {ref.dtype}") + ref_bits = math.prod(ref.shape) * nbits + if ref_bits % 8: + raise ValueError( + "Only byte-aligned shapes are supported. Got shape:" + f" {ref.dtype}{ref.shape}" + ) + size += ref_bits // 8 + return size + + +def _ref_group_tmem_col_size(refs: _GPUMemoryRefTree) -> int: + """Returns the total number of TMEM columns used by a group of aliased Refs. + """ + ncols = 0 + for ref in jax.tree.leaves(refs): + ref_ncols = ref.layout.cols_in_shape(ref.shape, + dtypes.itemsize_bits(ref.dtype)) + ncols += align_to(ref_ncols, TMEM_COL_ALIGNMENT) + return ncols + + +def infer_tmem_layout( + shape: tuple[int, ...], + dtype: jnp.dtype, + *, + packed: bool, + collective: bool) -> tcgen05.TMEMLayout: + """Infers the number of columns used and layout for allocating TMEM Refs.""" + if packed: + packing = 32 // dtypes.itemsize_bits(dtype) + else: + packing = 1 + return tcgen05._infer_tmem_layout(shape, collective=collective, packing=packing) # type: ignore + + +def flatten_ref_union(ref_union: AbstractRefUnion) -> tuple[_Ref, ...]: + """Flattens a union of trees of references into a tuple of references. + + This is the moral equivalent of `jax.tree.leaves` for aliased references. + """ + flat_refs = [] + if ref_union.memory_space == SMEM: + union_bytes = 0 + for ref_group in ref_union.refs: + byte_offset = 0 + def unflatten(ref): + nonlocal byte_offset + byte_offset = align_to(byte_offset, SMEM_ALIGNMENT) + assert isinstance(ref, state.AbstractRef) or isinstance( + ref, pallas_core.TransformedRef + ) + if not isinstance(ref, pallas_core.TransformedRef): + ref = pallas_core.TransformedRef(ref, transforms=()) + transform = ExtractAliasedRef.from_transformed_ref(ref, byte_offset) + result = pallas_core.TransformedRef( + ref_union, transforms=(transform, *ref.transforms) + ) + if jnp.issubdtype(ref.dtype, jnp.integer): + nbits = jnp.iinfo(ref.dtype).bits + elif jnp.issubdtype(ref.dtype, jnp.floating): + nbits = jnp.finfo(ref.dtype).bits + else: + raise NotImplementedError(f"Unsupported dtype: {ref.dtype}") + ref_bits = math.prod(ref.shape) * nbits + if ref_bits % 8: + raise ValueError( + "Only byte-aligned shapes are supported. Got shape:" + f" {ref.dtype}{ref.shape}" + ) + byte_offset += ref_bits // 8 + return result + flat_refs.append(jax.tree.map(unflatten, ref_group)) + union_bytes = max(union_bytes, byte_offset) + assert union_bytes == ref_union.shape[0] + elif ref_union.memory_space == TMEM: + union_cols = 0 + for ref_group in ref_union.refs: + col_offset = 0 + def unflatten(ref): + nonlocal col_offset + col_offset = align_to(col_offset, TMEM_COL_ALIGNMENT) + if not isinstance(ref, pallas_core.TransformedRef): + ref = pallas_core.TransformedRef(ref, transforms=()) + ncols = ref.layout.cols_in_shape(ref.shape, + dtypes.itemsize_bits(ref.dtype)) + transform = ExtractAliasedRef.from_transformed_ref( + ref, col_offset, layout=ref.layout) + result = pallas_core.TransformedRef( + ref_union, transforms=(transform, *ref.transforms) + ) + col_offset += ncols + return result + flat_refs.append(jax.tree.map(unflatten, ref_group)) + union_cols = max(union_cols, col_offset) + assert union_cols == ref_union.shape[1], (union_cols, ref_union.shape[1]) + else: + raise NotImplementedError("Only SMEM and TMEM refs are supported.") + return tuple(flat_refs) + + +class AbstractRefUnion(state.AbstractRef): + refs: Sequence[_GPUMemoryRefTree] + + def __init__( + self, + aval, + refs: Sequence[_GPUMemoryRefTree], + memory_space, + ): + self.refs = refs + super().__init__(aval, memory_space=memory_space) + + def _iter(self, tracer): + return iter(flatten_ref_union(tracer)) + + def _getitem(self, tracer, index): + return list(iter(tracer))[index] + + def _setitem(self, tracer, index, value): + del tracer, index, value # Unused. + raise ValueError("Ref unions can't be assigned to.") + + def update(self, inner_aval=None, memory_space=None, kind=None): + ref = super().update(inner_aval, memory_space, kind) + return AbstractRefUnion(ref.inner_aval, self.refs, self.memory_space) + + @functools.cached_property + def layout(self) -> tcgen05.TMEMLayout: + if self.memory_space != TMEM: + raise ValueError("layout attribute is only defined for TMEM refs") + return tcgen05.tmem_default_layout(packing=1) + + @functools.cached_property + def collective(self) -> bool: + if self.memory_space != TMEM: + raise ValueError("collective attribute is only defined for TMEM refs") + ref_leaves = jax.tree.leaves(self.refs) + first_ref = ref_leaves[0] + assert all(ref.collective == first_ref.collective for ref in ref_leaves) + return first_ref.collective + + +@dataclasses.dataclass(init=False, frozen=True) +class RefUnion(GPUMemoryRef): + """A sequence of trees of refs that are allowed to reuse the same memory. + + One should not make assumptions as to how each ref will map to the underlying + memory region, since arbitrary padding may be applied in between different + refs. + + As such, ref unions are only safe to use when the groups of refs that we + intend to alias have disjoint lifetimes (i.e. one should never attempt to read + data using a different ref than the one that was used to write the data). + """ + refs: Sequence[_GPUMemoryRefTree] = () + + def __init__(self, *refs: _GPUMemoryRefTree): + ref_leaves = jax.tree.leaves(refs) + if all(ref.memory_space == SMEM for ref in ref_leaves): + object.__setattr__(self, "refs", refs) + num_bytes = max(map(_ref_group_size, self.refs)) + super().__init__( + inner_aval=jax_core.ShapedArray( + (num_bytes,), jnp.int8 + ), + memory_space=SMEM, + transforms=(), + ) + elif all(ref.memory_space == TMEM for ref in ref_leaves): + object.__setattr__(self, "refs", refs) + max_cols = max(map(_ref_group_tmem_col_size, self.refs)) + is_collective = ref_leaves[0].collective + if any(r.collective != is_collective for r in ref_leaves): + raise ValueError( + "Some aliased TMEM references are collective and some are not." + ) + super().__init__( + inner_aval=jax_core.ShapedArray( + shape=(128, max_cols,), + dtype=jnp.int32, + ), + memory_space=TMEM, + transforms=(), + layout=tcgen05.tmem_default_layout(packing=1), + collective=all(ref.collective for ref in ref_leaves), + ) + else: + raise NotImplementedError( + "All aliased Refs must have the same memory space (SMEM or TMEM). " + f"Got {(ref.memory_space for ref in ref_leaves)}.") + + def get_ref_aval(self) -> AbstractRefUnion: + inner_aval = jax.core.ShapedArray(self.shape, self.dtype) + refs_aval = jax.tree.map(lambda ref: ref.get_ref_aval(), self.refs) + return AbstractRefUnion(inner_aval, refs_aval, + memory_space=self.memory_space) + + class MemoryRefTransform(pallas_core.MemoryRefTransform, abc.ABC): @abc.abstractmethod def to_gpu_transform(self) -> mgpu.MemRefTransform: pass + @abc.abstractmethod + def to_gpu_transform_attr(self) -> ir.Attribute: + pass + def batch(self, leading_rank: int): """Returns a transform that accepts a ref with the extra `leading_rank` dims. @@ -171,7 +610,7 @@ def __call__(self, aval: jax_core.ShapedArray) -> jax_core.ShapedArray: shape=self.to_gpu_transform().transform_shape(aval.shape) ) -Index = slice | int | ir.Value +Index = Union[mgpu.DynamicSlice, slice, int, ir.Value] @dataclasses.dataclass(frozen=True) class TilingTransform(MemoryRefTransform): @@ -194,6 +633,9 @@ def batch(self, leading_rank: int): def to_gpu_transform(self) -> mgpu.MemRefTransform: return mgpu.TileTransform(self.tiling) + def to_gpu_transform_attr(self) -> ir.Attribute: + return mgpu.dialect.TileTransformAttr.get(self.tiling) + @tree_util.register_dataclass @dataclasses.dataclass(frozen=True) @@ -213,26 +655,82 @@ def transform_shape(self, shape): def transform_dtype(self, dtype): return dtype + def untransform_transpose( + self, perm: tuple[int, ...] + ) -> tuple[tuple[int, ...], state_types.Transform]: + # The transpose in question is applied to the utiled ref so we + # need to translate it by duplicating and offsetting the last part. + off = len(perm) + new_suffix = [i + off for i in perm[-len(self.tiling) :]] + if set(new_suffix) != set(range(off, off + len(self.tiling))): + raise ValueError( + "Transpose cannot be moved before a tiling transform when it changes" + f" the set of tiled dimensions. (permutation: {perm}, tiling:" + f" {self.tiling})" + ) + + new_tiling = tuple(self.tiling[i - off] for i in new_suffix) + return (*perm, *new_suffix), dataclasses.replace(self, tiling=new_tiling) + + def untransform_reshape( + self, dtype: jnp.dtype, shape: tuple[int, ...] + ) -> tuple[tuple[int, ...], state_types.Transform]: + del dtype + # TODO(slebedev): Support this. + raise NotImplementedError("Reshapes don't commute with tiling.") + def untransform_index( - self, idxs: tuple[Index, ...] + self, dtype: jnp.dtype | ir.Type, idxs: tuple[Index, ...] ) -> tuple[tuple[Index, ...], state_types.Transform]: + del dtype untiled_idxs = idxs[: -len(self.tiling)] tiled_idxs = idxs[-len(self.tiling) :] - idxs_after_tiling = [] + idxs_after_tiling: list[Index] = [] for idx, tile in zip(tiled_idxs, self.tiling): - if not isinstance(idx, slice): - raise NotImplementedError("Non-slice indices are not supported") - assert isinstance(idx, slice) - if idx.step is not None and idx.step != 1: - raise NotImplementedError("Strided slices unsupported") - if (idx.start is not None and idx.start % tile) or (idx.stop is not None and idx.stop % tile): - raise ValueError("Non-empty slices must be tile aligned") - idxs_after_tiling.append(slice(idx.start // tile, idx.stop // tile)) + if isinstance(idx, slice): + if idx.step is not None and idx.step != 1: + raise NotImplementedError( + f"Strided slices unsupported. Got stride: {idx.step}" + ) + if (idx.start is not None and idx.start % tile) or ( + idx.stop is not None and idx.stop % tile + ): + raise ValueError( + f"Expected slice start ({idx.start}) and slice stop ({idx.stop})" + f" to be divisible by the tile size ({tile})" + ) + idxs_after_tiling.append(slice(idx.start // tile, idx.stop // tile)) + elif isinstance(idx, mgpu.DynamicSlice): + if idx.length % tile: + raise ValueError( + f"Dynamic slice length ({idx.length}) is not divisible by the" + f" tiling ({tile})" + ) + if isinstance(idx.base, ir.Value): + if not mgpu_utils.is_known_divisible(idx.base, tile): + raise ValueError( + "Dynamic slice base index (which is a dynamic value) cannot be" + f" statically proven to be divisible by the tiling ({tile})" + ) + new_base = arith_dialect.divui(idx.base, mgpu.c(tile, idx.base.type)) + else: + if idx.base % tile: + raise ValueError( + f"Dynamic slice base ({idx.base}) is not divisible by the" + f" tiling ({tile})" + ) + new_base = idx.base // tile + idxs_after_tiling.append(mgpu.DynamicSlice(new_base, idx.length // tile)) + else: + raise TypeError(f"Unsupported index type: {type(idx)}") return (*untiled_idxs, *idxs_after_tiling, *(slice(None) for _ in self.tiling)), self def undo_to_gpu_transform(self) -> mgpu.MemRefTransform: return mgpu.TileTransform(self.tiling) + def pretty_print(self, context: jax_core.JaxprPpContext) -> pp.Doc: + return pp.text(f"{{untile({list(self.tiling)})}}") + def _perm_inverse(permutation: tuple[int, ...]) -> tuple[int, ...]: inverse = [-1] * len(permutation) @@ -267,25 +765,33 @@ def undo(self, ref: pallas_core.TransformedRef) -> pallas_core.TransformedRef: def to_gpu_transform(self) -> mgpu.MemRefTransform: return mgpu.TransposeTransform(self.permutation) + def to_gpu_transform_attr(self) -> ir.Attribute: + return mgpu.dialect.TransposeTransformAttr.get(self.permutation) + @tree_util.register_dataclass @dataclasses.dataclass(frozen=True) -class TransposeRef(state_types.Transform): - permutation: tuple[int, ...] +class TransposeRef(state_types.RefTransposer): - def transform_shape(self, shape): - if shape is None: - return None - return tuple(shape[i] for i in self.permutation) + def untransform_transpose( + self, perm + ) -> tuple[tuple[int, ...], state_types.Transform]: + raise NotImplementedError( + "Commuting of transpose over transpose is not supported." + ) - def transform_dtype(self, dtype): - return dtype + def untransform_reshape( + self, dtype: jnp.dtype | ir.Type, shape: tuple[int, ...] + ) -> tuple[tuple[int, ...], state_types.Transform]: + del shape, dtype + raise NotImplementedError("Can't reshape a transposed memref.") def untransform_index( - self, idxs: tuple[Index, ...] + self, dtype: jnp.dtype | ir.Type, idxs: tuple[Index, ...] ) -> tuple[tuple[Index, ...], state_types.Transform]: + del dtype removed_dims = [ - i for i, idx in enumerate(idxs) if not isinstance(idx, slice) + i for i, idx in enumerate(idxs) if not isinstance(idx, (slice, mgpu.ds)) ] new_perm = tuple( p - sum(d < p for d in removed_dims) @@ -299,19 +805,160 @@ def undo_to_gpu_transform(self) -> mgpu.MemRefTransform: return mgpu.TransposeTransform(_perm_inverse(self.permutation)) -def transpose_ref( - ref: pallas_core.TransformedRef | Any, - permutation: tuple[int, ...], +@tree_util.register_pytree_node_class +@dataclasses.dataclass +class PeerMemRef(state_types.Transform): + device_id: Any + device_id_type: pallas_primitives.DeviceIdType + + def transform_shape(self, shape): + return shape + + def transform_dtype(self, dtype): + return dtype + + def untransform_index( + self, idxs: tuple[Index, ...] + ) -> tuple[tuple[Index, ...], state_types.Transform]: + return idxs, self + + def tree_flatten(self): + return (self.device_id,), (self.device_id_type,) + + @classmethod + def tree_unflatten(cls, metadata, arrays): + return cls(arrays[0], metadata[0]) + + +@tree_util.register_pytree_node_class +@dataclasses.dataclass +class MulticastRef(state_types.Transform): + collective_axes: tuple[Hashable, ...] + + def transform_shape(self, shape): + return shape + + def transform_dtype(self, dtype): + return dtype + + def untransform_index( + self, idxs: tuple[Index, ...] + ) -> tuple[tuple[Index, ...], state_types.Transform]: + return idxs, self + + def tree_flatten(self): + return (), self.collective_axes + + @classmethod + def tree_unflatten(cls, metadata, arrays): + return cls(metadata[0]) + + +def remote_ref( + ref: _Ref, + device_id: jax.typing.ArrayLike, + device_id_type: pallas_primitives.DeviceIdType = pallas_primitives.DeviceIdType.MESH, +) -> pallas_core.TransformedRef: + """Translate memref to a symmetric memref on a peer device.""" + if not isinstance(ref, pallas_core.TransformedRef): + if not isinstance(jax_core.get_aval(ref), state_types.AbstractRef): + raise TypeError("ref must be a reference") + ref = pallas_core.TransformedRef(ref, transforms=()) + if any(isinstance(t, MulticastRef) for t in ref.transforms): + raise ValueError("Can't make a multicast reference into a peer reference.") + return pallas_core.TransformedRef( + ref.ref, (*ref.transforms, PeerMemRef(device_id, device_id_type)), + ) + + +def multicast_ref( + ref: _Ref, + collective_axes: Hashable | tuple[Hashable, ...], ) -> pallas_core.TransformedRef: + """Return a multicast reference for cross-device operations. + + Args: + ref: The reference to transform. + collective_axes: The JAX mesh axes indicating the devices to operate on. + """ + if not isinstance(collective_axes, tuple): + collective_axes = (collective_axes,) if not isinstance(ref, pallas_core.TransformedRef): - if not isinstance(jax_core.get_aval(ref), pallas_core.AbstractMemoryRef): + if not isinstance(jax_core.get_aval(ref), state_types.AbstractRef): raise TypeError("ref must be a reference") ref = pallas_core.TransformedRef(ref, transforms=()) + if any(isinstance(t, PeerMemRef) for t in ref.transforms): + raise ValueError("Can't make a peer reference into a multicast reference.") return pallas_core.TransformedRef( - ref.ref, (*ref.transforms, TransposeRef(permutation)), + ref.ref, (*ref.transforms, MulticastRef(collective_axes)), ) +def transform_ref( + ref: pallas_core.TransformedRef, + transform: state_types.Transform +) -> pallas_core.TransformedRef: + if not isinstance(ref, pallas_core.TransformedRef): + if not isinstance(jax_core.get_aval(ref), state_types.AbstractRef): + raise TypeError("ref must be a reference") + ref = pallas_core.TransformedRef(ref, transforms=()) + return pallas_core.TransformedRef( + ref.ref, (*ref.transforms, transform), + ) + +def transpose_ref( + ref: pallas_core.TransformedRef | Any, + permutation: tuple[int, ...], +) -> pallas_core.TransformedRef: + assert hasattr(ref, "memory_space") + if ref.memory_space == MemorySpace.TMEM: + raise ValueError("Can't transpose a TMEM reference.") + return ref.transpose(permutation) + +def untile_ref(ref, tiling: tuple[int, ...]) -> pallas_core.TransformedRef: + return transform_ref(ref, UntileRef(tiling)) + +def unswizzle_ref(ref, swizzle: int) -> pallas_core.TransformedRef: + return transform_ref(ref, UnswizzleRef(swizzle)) + + +@tree_util.register_pytree_node_class +@dataclasses.dataclass(frozen=True) +class ExtractAliasedRef(state_types.Transform): + """Bitcasts the underlying ref at the given offset to the given shape and dtype.""" + dtype: dtypes.DType + shape: tuple[int, ...] + offset: int + # TMEM-specific params + layout: tcgen05.TMEMLayout | None + + @classmethod + def from_transformed_ref( + cls, + ref: pallas_core.TransformedRef, + byte_offset: int, + layout: tcgen05.TMEMLayout | None = None, + ): + return cls(dtypes.dtype(ref.dtype), ref.ref.shape, byte_offset, layout) + + def transform_shape(self, shape): + if shape is None: + return None + return self.shape + + def transform_dtype(self, dtype): + del dtype # Unused. + return self.dtype + + def tree_flatten(self): + return (), (self.dtype, self.shape, self.offset, self.layout) + + @classmethod + def tree_unflatten(cls, metadata, arrays): + assert not arrays + return cls(*metadata) + + @dataclasses.dataclass(frozen=True) class SwizzleTransform(MemoryRefTransform): swizzle: int @@ -334,12 +981,15 @@ def undo(self, ref: pallas_core.TransformedRef) -> pallas_core.TransformedRef: def to_gpu_transform(self) -> mgpu.MemRefTransform: raise RuntimeError("SwizzleTransform does not have a GPU transform.") + def to_gpu_transform_attr(self) -> ir.Attribute: + return mgpu.dialect.SwizzleTransformAttr.get(self.swizzle) + def undo_to_gpu_transform(self) -> mgpu.MemRefTransform: # There's no swizzle transform in mgpu right now. It's a separate arg. raise NotImplementedError def __call__(self, aval: jax_core.ShapedArray) -> jax_core.ShapedArray: - swizzle_elems = self.swizzle // aval.dtype.itemsize + swizzle_elems = (self.swizzle * 8) // dtypes.itemsize_bits(aval.dtype) if swizzle_elems != aval.shape[-1]: raise ValueError( f"Swizzle {self.swizzle} requires the trailing dimension to be of" @@ -353,29 +1003,74 @@ def __call__(self, aval: jax_core.ShapedArray) -> jax_core.ShapedArray: class UnswizzleRef(state_types.Transform): swizzle: int = dataclasses.field(metadata=dict(static=True)) + def swizzle_elems(self, dtype: jnp.dtype | ir.Type) -> int: + if not isinstance(dtype, ir.Type): + dtype = mgpu_utils.dtype_to_ir_type(dtype) + return (self.swizzle * 8) // mgpu.bitwidth(dtype) + + def untransform_transpose(self, perm) -> tuple[tuple[int, ...], state_types.Transform]: + if perm[-1] != len(perm) - 1: + raise ValueError("Can't transpose the swizzled dimension.") + + return perm, self + + def untransform_reshape( + self, dtype: jnp.dtype | ir.Type, shape: tuple[int, ...] + ) -> tuple[tuple[int, ...], state_types.Transform]: + if shape[-1] != self.swizzle_elems(dtype): + raise ValueError( + f"Reshape shape {shape} is not divisible by swizzle elements" + f" {self.swizzle_elems(dtype)}" + ) + return shape, self + def untransform_index( - self, idxs: tuple[Index, ...] + self, dtype: jnp.dtype | ir.Type, idxs: tuple[Index, ...] ) -> tuple[tuple[Index, ...], state_types.Transform]: + swizzle_elems = self.swizzle_elems(dtype) if not idxs: return idxs, self - if not all(isinstance(idx, slice) for idx in idxs[-2:]): + if not all(isinstance(idx, (slice, mgpu.ds)) for idx in idxs[-2:]): raise NotImplementedError( "Non-slice indices are not supported in 2 minormost dims" ) last_idx = idxs[-1] - assert isinstance(last_idx, slice) - if last_idx.step is not None and last_idx.step != 1: - raise NotImplementedError("Swizzled dims cannot be sliced") - if (last_idx.start is not None and last_idx.start != 0) or ( - last_idx.stop is not None and last_idx.stop != self.swizzle - ): - raise ValueError("Swizzled dims cannot be sliced") + if isinstance(last_idx, mgpu.DynamicSlice): + if last_idx.base != 0 or last_idx.length != swizzle_elems: + raise ValueError("Swizzled dims cannot be sliced") + else: + assert isinstance(last_idx, slice) + if ( + (last_idx.step is not None and last_idx.step != 1) + or (last_idx.start is not None and last_idx.start != 0) + or (last_idx.stop is not None and last_idx.stop != swizzle_elems) + ): + raise ValueError("Swizzled dims cannot be sliced") return idxs, self + def pretty_print(self, context: jax_core.JaxprPpContext) -> pp.Doc: + return pp.text(f"{{unswizzle({self.swizzle})}}") + @dataclasses.dataclass -class GPUBlockSpec(pallas_core.BlockSpec): +class BlockSpec(pallas_core.BlockSpec): + r"""A GPU-specific ``BlockSpec``. + + Attributes: + transforms: A sequence of transforms that will be applied to the + reference. + delay_release: used during pipelining to delay the release of + resources of a slot after it is used in the computation. + collective_axes: When set, all blocks along the specified axes must execute + the same sequence of pipeline operations (with the only exception being + the index_map in non-collective ``BlockSpec``\ s), and all of them must + return the same block from the index_map for this operand. This enables + the pipelining helpers to use collective async copies, which can improve + performance. + """ transforms: Sequence[MemoryRefTransform] = () + delay_release: int = 0 + collective_axes: tuple[Hashable, ...] = () def to_block_mapping( self, @@ -385,19 +1080,26 @@ def to_block_mapping( index_map_avals: Sequence[jax_core.AbstractValue], index_map_tree: tree_util.PyTreeDef, grid: pallas_core.GridMappingGrid, - mapped_dims: tuple[int, ...], + vmapped_dims: tuple[int, ...], + debug: bool = False, ) -> pallas_core.BlockMapping: + if self.collective_axes: + raise ValueError( + "collective_axes is not supported in pallas_call. Use plgpu.kernel" + " with plgpu.emit_pipeline_warp_specialized instead." + ) bm = super().to_block_mapping( origin, array_aval, index_map_avals=index_map_avals, index_map_tree=index_map_tree, grid=grid, - mapped_dims=mapped_dims, + vmapped_dims=vmapped_dims, + debug=debug, ) block_inner_aval = bm.block_aval.inner_aval for t in self.transforms: - block_inner_aval = t(block_inner_aval) + block_inner_aval = t(block_inner_aval) # type: ignore[arg-type] return bm.replace( transformed_block_aval=bm.block_aval.update( inner_aval=block_inner_aval @@ -406,9 +1108,10 @@ def to_block_mapping( ) -GMEM = GPUMemorySpace.GMEM -SMEM = GPUMemorySpace.SMEM -REGS = GPUMemorySpace.REGS +GMEM = MemorySpace.GMEM +SMEM = MemorySpace.SMEM +TMEM = MemorySpace.TMEM +REGS = MemorySpace.REGS class barrier_dtype(dtypes.extended): @@ -421,21 +1124,78 @@ class BarrierType(dtypes.ExtendedDType): name: ClassVar[str] = "barrier" num_arrivals: int + orders_tensor_core: bool def __str__(self): return self.name @dataclasses.dataclass(frozen=True) -class Barrier: +class ClusterBarrierType(dtypes.ExtendedDType): + type: ClassVar[Any] = barrier_dtype + name: ClassVar[str] = "cluster_barrier" + + collective_axes: tuple[str | tuple[str, ...], ...] num_arrivals: int + orders_tensor_core: bool + + def __str__(self): + return self.name + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class Barrier: + """Describes a barrier reference. + + Attributes: + num_arrivals: The number of arrivals that will be recorded by this barrier. + num_barriers: The number of barriers that will be created. Individual + barriers can be accessed by indexing into the barrier Ref. + orders_tensor_core: If False, a successfull wait from one thread does not + guarantee that the TensorCore-related operations in other threads have + completed. Similarly, when False any TensorCore operation in the waiting + thread is allowed to begin before the wait succeeds. + """ + num_arrivals: int = 1 + num_barriers: int = 1 + orders_tensor_core: bool = False + + def get_array_aval(self) -> jax_core.ShapedArray: + raise ValueError("Barriers are not arrays") + + def get_ref_aval(self) -> state.AbstractRef: + aval = jax_core.ShapedArray( + [self.num_barriers], + BarrierType( + self.num_arrivals, orders_tensor_core=self.orders_tensor_core + ), + ) + return state.AbstractRef(aval, SMEM) + + def __post_init__(self): + if self.num_arrivals < 1: + raise ValueError( + f"Num arrivals must be at least 1, but got {self.num_arrivals}" + ) + +@dataclasses.dataclass(frozen=True, kw_only=True) +class ClusterBarrier: + collective_axes: tuple[str | tuple[str, ...], ...] num_barriers: int = 1 + num_arrivals: int = 1 + orders_tensor_core: bool = False + + def get_array_aval(self) -> jax_core.ShapedArray: + raise ValueError("Cluster barriers are not arrays") - def get_ref_aval(self) -> AbstractMemoryRef: + def get_ref_aval(self) -> state.AbstractRef: aval = jax_core.ShapedArray( - [self.num_barriers], BarrierType(self.num_arrivals) + [self.num_barriers], + ClusterBarrierType( + self.collective_axes, self.num_arrivals, self.orders_tensor_core + ), ) - return AbstractMemoryRef(aval, SMEM) + return state.AbstractRef(aval, SMEM) @dataclasses.dataclass(frozen=True) @@ -444,13 +1204,13 @@ class WGMMAAccumulatorRef: dtype: jnp.dtype = jnp.float32 _init: Any = state_types.uninitialized - def get_ref_aval(self) -> AbstractMemoryRef: + def get_ref_aval(self) -> state.AbstractRef: if self._init is not state_types.uninitialized: raise ValueError( "Preinitialized WGMMAAccumulatorRef only supported in pl.run_state." ) return WGMMAAbstractAccumulatorRef( - jax_core.ShapedArray(shape=self.shape, dtype=self.dtype), GPUMemorySpace.REGS + jax_core.ShapedArray(shape=self.shape, dtype=self.dtype), MemorySpace.REGS ) @staticmethod @@ -460,23 +1220,24 @@ def init(array): def _wgmma_ref_type_mapping(ref: WGMMAAccumulatorRef): aval = WGMMAAbstractAccumulatorRef( - jax_core.ShapedArray(shape=ref.shape, dtype=ref.dtype), GPUMemorySpace.REGS + jax_core.ShapedArray(shape=ref.shape, dtype=ref.dtype), MemorySpace.REGS ) return aval, ref._init state_types._ref_type_aval_mappings[WGMMAAccumulatorRef] = _wgmma_ref_type_mapping -class WGMMAAbstractAccumulatorRef(AbstractMemoryRef): +class WGMMAAbstractAccumulatorRef(state.AbstractRef): __slots__ = ["inner_aval", "memory_space"] def __repr__(self) -> str: return f'Accumulator{{{self.inner_aval.str_short()}}}' - def update_weak_type(self, weak_type): - return _as_accum(super().update_weak_type(weak_type)) - - def update(self, inner_aval=None, memory_space=None): - return _as_accum(super().update(inner_aval=None, memory_space=None)) + def update(self, inner_aval=None, memory_space=None, kind=None): + ref = super().update(inner_aval, memory_space, kind) + return WGMMAAbstractAccumulatorRef( + inner_aval=ref.inner_aval, + memory_space=ref.memory_space, + ) def _getitem(self, tracer, idx): from jax._src.pallas.mosaic_gpu.primitives import wgmma_accumulator_deref # pytype: disable=import-error @@ -488,35 +1249,64 @@ def _getitem(self, tracer, idx): return arr -def _as_accum(ref) -> WGMMAAbstractAccumulatorRef: - return WGMMAAbstractAccumulatorRef( - inner_aval=ref.inner_aval, - memory_space=ref.memory_space, # pytype: disable=attribute-error - ) +class AbstractTMEMRef(state.AbstractRef): + __slots__ = ["inner_aval", "memory_space", "layout", "collective"] + + def __init__(self, inner_aval, memory_space, layout, collective): + super().__init__(inner_aval, memory_space) + self.layout = layout + self.collective = collective + + def __repr__(self) -> str: + return f'TMEM({self.inner_aval.str_short()}, layout={self.layout}, collective={self.collective})' + + def update(self, inner_aval=None, memory_space=None, kind=None): + ref = super().update(inner_aval, memory_space, kind) + return AbstractTMEMRef( + ref.inner_aval, ref.memory_space, self.layout, self.collective + ) _WARPGROUP_AXIS_NAME = object() @dataclasses.dataclass(frozen=True, kw_only=True) -class GPUMesh: - grid: tuple[int, ...] = () - cluster: tuple[int, ...] = () +class Mesh: + grid: Sequence[int] = () + grid_names: Sequence[str] = () + cluster: Sequence[int] = () + cluster_names: Sequence[str] = () # Those are NOT CUDA threads. On Hopper they correspond to warpgroups. num_threads: int | None = None - axis_names: tuple[str, ...] = () + thread_name: str | None = None + kernel_name: str | None = None def __post_init__(self): - if len(self.axis_names) != len(self.grid) + (self.num_threads is not None): - raise ValueError("Need as many axis names as grid dimensions + warp groups") - if self.num_threads is not None and self.num_threads > 2048 // 128: + if len(self.cluster) > 3: + raise ValueError(f"cluster= must be at most 3D, got {self}.") + if len(self.grid_names) != len(self.grid): raise ValueError( - "Requested too many CUDA threads per block. Each Mosaic thread" - " corresponds to 128 CUDA threads." + f"grid_names must have the same length as grid, got {self}." ) - if self.cluster: - raise NotImplementedError( - "Pallas/MosaicGPU does not support clusters yet." + if len(self.cluster_names) != len(self.cluster): + raise ValueError( + f"cluster_names must have the same length as cluster, got {self}." + ) + if (self.thread_name is None) != (self.num_threads is None): + raise ValueError( + "num_threads and thread_name must be either both set or both None," + f" got {self}" + ) + max_mosaic_threads = 2048 // 128 + if self.num_threads is not None and self.num_threads > max_mosaic_threads: + raise ValueError( + "Requested too many CUDA threads per block. Each Mosaic thread" + f" corresponds to 128 CUDA threads. At most {max_mosaic_threads}" + f" are supported, got {self}" ) + object.__setattr__(self, "grid", tuple(self.grid)) + object.__setattr__(self, "grid_names", tuple(self.grid_names)) + object.__setattr__(self, "cluster", tuple(self.cluster)) + object.__setattr__(self, "cluster_names", tuple(self.cluster_names)) @property def backend(self) -> str: @@ -527,20 +1317,40 @@ def shape(self) -> collections.OrderedDict[object, int]: pairs: Iterable[tuple[object, int]] if self.num_threads is not None: pairs = zip( - self.axis_names, (*self.grid, *self.cluster, self.num_threads) + (*self.grid_names, *self.cluster_names, self.thread_name), + (*self.grid, *self.cluster, self.num_threads), ) else: - pairs = tuple( - zip( - (*self.axis_names, _WARPGROUP_AXIS_NAME), - (*self.grid, *self.cluster, 1), - ) + pairs = zip( + (*self.grid_names, *self.cluster_names), + (*self.grid, *self.cluster), ) return collections.OrderedDict(pairs) def discharges_effect(self, effect: jax_core.Effect): return effect is _wgmma_pipeline_effect or effect is _memory_effect +@dataclasses.dataclass(frozen=True, kw_only=True) +class WarpMesh: + """Represents a mesh over individual warps within a warpgroup. + + When used in conjunction with `core_map`, the warp ID will be visible + within the body of the wrapped scope by querying `lax.axis_index` with + the specified axis name. + """ + + _NUM_WARPS_PER_WARPGROUP: ClassVar[int] = 4 + axis_name: str + + @property + def shape(self): + return collections.OrderedDict([ + (self.axis_name, self._NUM_WARPS_PER_WARPGROUP), + ]) + + def discharges_effect(self, effect: jax_core.Effect): + del effect + return False def _gpu_mesh_discharge_rule( in_avals, @@ -553,18 +1363,22 @@ def _gpu_mesh_discharge_rule( debug, cost_estimate, name, + metadata, ): - if not isinstance(mesh, GPUMesh): - raise TypeError(f"Mesh must be a GPUMesh, got {type(mesh)}") - if mesh.cluster: - raise NotImplementedError - if compiler_params and not isinstance(compiler_params, GPUCompilerParams): + if not isinstance(mesh, Mesh): + raise TypeError(f"Mesh must be a `plgpu.Mesh`, got {type(mesh)}") + if compiler_params and not isinstance(compiler_params, CompilerParams): raise TypeError( - "Compiler params must be a GPUCompilerParams, got" + "Compiler params must be a `plgpu.CompilerParams`, got" f" {type(compiler_params)}" ) if not compiler_params: - compiler_params = GPUCompilerParams() + compiler_params = CompilerParams() + sa_avals = [a for a in in_avals if isinstance(a, jax_core.ShapedArray)] + if sa_avals: + raise NotImplementedError( + f"Cannot close over values in core_map: {sa_avals}" + ) return pallas_core.default_mesh_discharge_rule( in_avals, out_avals, @@ -576,10 +1390,13 @@ def _gpu_mesh_discharge_rule( interpret=interpret, cost_estimate=cost_estimate, name=name, + memory_space=GMEM, + metadata=metadata, + scratch_shapes=[], ) -pallas_core._core_map_mesh_rules[GPUMesh] = _gpu_mesh_discharge_rule +pallas_core._core_map_mesh_rules[Mesh] = _gpu_mesh_discharge_rule class MemoryEffect(jax_core.Effect): @@ -596,3 +1413,188 @@ class _WGMMAPipelineEffect(effects.Effect): effects.control_flow_allowed_effects.add_type(_WGMMAPipelineEffect) _wgmma_pipeline_effect = _WGMMAPipelineEffect() + + +# We define the layout_cast primitive here, because it needs to be available in +# the lowering code (to provide layout hints to the rules). +layout_cast_p = jax_core.Primitive("layout_cast") + + +@layout_cast_p.def_abstract_eval +def _layout_cast_abstract_eval(x, new_layout): + del new_layout # Unused. + return x + + +def layout_cast(x: Any, new_layout: SomeLayout): + """Casts the layout of the given array.""" + return layout_cast_p.bind(x, new_layout=new_layout) + + +class SomeLayout: + + def reduce(self, axes: int | Sequence[int]) -> "SomeLayout": + if isinstance(axes, int): + axes = (axes,) + return ReducedLayout(self, axes) + + def to_mgpu(self, *args, **kwargs) -> mgpu.FragmentedLayout: + raise NotImplementedError + + +@dataclasses.dataclass(frozen=True) +class ParameterizedLayout(SomeLayout): + layout_cls: Layout | TMEMLayout + args: Sequence[Any] + kwargs: Any + + def __post_init__(self): + object.__setattr__(self, "args", tuple(self.args)) + object.__setattr__(self, "kwargs", frozen_dict.FrozenDict(self.kwargs)) + + def to_mgpu(self) -> mgpu.FragmentedLayout: + return self.layout_cls.to_mgpu(*self.args, **self.kwargs) + + +@dataclasses.dataclass(frozen=True) +class ReducedLayout(SomeLayout): + layout: SomeLayout + axes: Sequence[int] + + def to_mgpu(self) -> mgpu.FragmentedLayout: + layout = self.layout.to_mgpu() + if not isinstance(layout, mgpu.TiledLayout): + raise ValueError("Only TiledLayout supports reductions.") + return layout.reduce(self.axes) + + +class Layout(SomeLayout, enum.Enum): + #: [m, n] matrix, where m % 64 == 0 == n % 8. + WGMMA = enum.auto() + WGMMA_8BIT = enum.auto() + WGMMA_UPCAST_2X = enum.auto() + WGMMA_UPCAST_4X = enum.auto() + WGMMA_TRANSPOSED = enum.auto() + + WG_SPLAT = enum.auto() + WG_STRIDED = enum.auto() + + TILED = enum.auto() + + TCGEN05 = enum.auto() + TCGEN05_TRANSPOSED = enum.auto() + TCGEN05_M64_COLLECTIVE = enum.auto() + TCGEN05_TMEM_NATIVE = enum.auto() + TCGEN05_M64_COLLECTIVE_NATIVE = enum.auto() + + SMEM_GMEM_COPY = enum.auto() + TMA_GATHER_INDICES = enum.auto() + + # TODO(b/435159109): Remove this once LLVM regression is addressed. + _WGMMA_ACC_32BIT = enum.auto() # Temporarily exposed to work around LLVM bugs + + def __call__(self, *args, **kwargs) -> ParameterizedLayout: + return ParameterizedLayout(self, args, kwargs) + + def to_mgpu(self, *args, **kwargs) -> mgpu.FragmentedLayout: + def check_no_args(): + if args or kwargs: + raise ValueError(f"Can't instantiate {self} with arguments.") + + match self: + case Layout.WGMMA_TRANSPOSED: + check_no_args() + return mgpu.WGMMA_TRANSPOSED_LAYOUT + case Layout.WGMMA: + check_no_args() + return mgpu.WGMMA_LAYOUT + case Layout.WGMMA_8BIT: + check_no_args() + return mgpu.WGMMA_LAYOUT_8BIT + case Layout.WGMMA_UPCAST_2X: + check_no_args() + return mgpu.WGMMA_LAYOUT_UPCAST_2X + case Layout.WGMMA_UPCAST_4X: + check_no_args() + return mgpu.WGMMA_LAYOUT_UPCAST_4X + case Layout._WGMMA_ACC_32BIT: + check_no_args() + return mgpu.fragmented_array.WGMMA_LAYOUT_ACC_32BIT + case Layout.WG_SPLAT: + return mgpu.WGSplatFragLayout(*args, **kwargs) # pytype: disable=missing-parameter + case Layout.WG_STRIDED: + return mgpu.WGStridedFragLayout(*args, **kwargs) # pytype: disable=missing-parameter + case Layout.TILED: + return mgpu.TiledLayout(*args, **kwargs) + case Layout.TCGEN05: + check_no_args() + return mgpu.TCGEN05_LAYOUT + case Layout.TCGEN05_TRANSPOSED: + check_no_args() + return mgpu.TCGEN05_TRANSPOSED_LAYOUT + case Layout.TCGEN05_TMEM_NATIVE: + if args or kwargs: + return mgpu.tmem_native_layout(*args, **kwargs) + return mgpu.TMEM_NATIVE_LAYOUT + case Layout.TCGEN05_M64_COLLECTIVE: + return tcgen05.fa_m64_collective_layout(*args, **kwargs) # pytype: disable=missing-parameter + case Layout.TCGEN05_M64_COLLECTIVE_NATIVE: + return tcgen05.tmem_m64_collective_layout(*args, **kwargs).as_tiled_layout() # pytype: disable=missing-parameter + case Layout.SMEM_GMEM_COPY: + normalize_args = lambda shape, dtype, swizzle: (shape, dtype, swizzle) + shape, dtype, swizzle = normalize_args(*args, **kwargs) + bitwidth = dtypes.itemsize_bits(dtype) + tiling = (8, 8 * swizzle // bitwidth) + row_tiles, col_tiles = mgpu.tile_shape(shape, tiling)[-4:-2] + return mgpu.fragmented_array.tiled_copy_smem_gmem_layout( + row_tiles, col_tiles, swizzle, bitwidth + ) + case Layout.TMA_GATHER_INDICES: + return mgpu.TMA_GATHER_INDICES_LAYOUT + + +# TODO(apaszke): Adjust the users and remove these backfills. +Layout.WGMMA_ROW = Layout.WGMMA.reduce(1) +Layout.WGMMA_COL = Layout.WGMMA.reduce(0) +Layout.TCGEN05_ROW = Layout.TCGEN05.reduce(1) +Layout.TCGEN05_COL = Layout.TCGEN05.reduce(0) +Layout.TCGEN05_TMEM_NATIVE_ROW = Layout.TCGEN05_TMEM_NATIVE.reduce(1) + + +class TMEMLayout(enum.Enum): + """Layout for TMEM references.""" + # TODO(apaszke): Remove the layout suffix. + SCALES_LAYOUT = enum.auto() + SPARSE_METADATA_LAYOUT = enum.auto() + M64_COLLECTIVE_LAYOUT = enum.auto() + + def __call__(self, *args, **kwargs) -> ParameterizedLayout: + return ParameterizedLayout(self, args, kwargs) + + def to_mgpu(self, *args, **kwargs) -> tcgen05.TMEMLayout: + match self: + case TMEMLayout.SCALES_LAYOUT: + return tcgen05.scales_layout(*args, **kwargs) + case TMEMLayout.SPARSE_METADATA_LAYOUT: + return tcgen05.sparse_meta_layout(*args, **kwargs) + case TMEMLayout.M64_COLLECTIVE_LAYOUT: + return tcgen05.tmem_m64_collective_layout(*args, **kwargs) # pytype: disable=missing-parameter + + +def TryClusterCancelResult( + num_buffers: int | None = None) -> pallas_core.MemoryRef: + """Helper function to create Refs for cluster launch control results. + + Args: + num_buffers: Optional argument for specifying the number of buffers + to allocate. If None, will return a single 16-byte buffer. If specified, + will return a (num_buffers, 16)-shaped buffer. + + Returns: + A MemoryRef with the correct shape for holding the opaque cluster launch + control result. + """ + if num_buffers is None: + return SMEM((16,), jnp.int8) + else: + return SMEM((num_buffers, 16), jnp.int8) diff --git a/jax/_src/pallas/mosaic_gpu/helpers.py b/jax/_src/pallas/mosaic_gpu/helpers.py new file mode 100644 index 000000000000..f118ec35f1f3 --- /dev/null +++ b/jax/_src/pallas/mosaic_gpu/helpers.py @@ -0,0 +1,385 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpers for Pallas Mosaic GPU kernels.""" + +from collections.abc import Callable, Hashable, Sequence +import dataclasses +import functools +import math +from typing import TypeVar, overload + +import jax +from jax import numpy as jnp +from jax import lax +from jax._src import dtypes +from jax._src.pallas.mosaic_gpu import primitives as gpu_primitives +from jax._src.pallas.mosaic_gpu import core as gpu_core +from jax._src.pallas import primitives as pallas_primitives +import numpy as np + +_T = TypeVar("_T") + + +@dataclasses.dataclass(frozen=True, eq=False) +class NDLoopInfo: + """Container dataclass for loop iteration information. + + Attributes: + index: The grid indices corresponding to the current loop iteration. + local_index: The local iteration index. + num_local_steps: The total number of local iterations to run. None + if unknown. + """ + index: tuple[jax.Array, ...] + local_index: jax.Array | int + num_local_steps: jax.Array | int | None + + +@overload +def nd_loop( + grid: Sequence[int], + *, + collective_axes: Sequence[Hashable] | Hashable, + tiling: Sequence[int] | None = None, + init_carry: None = None +) -> Callable[[Callable[[NDLoopInfo], None]], None]: + ... + + +@overload +def nd_loop( + grid: Sequence[int], + *, + collective_axes: Sequence[Hashable] | Hashable, + tiling: Sequence[int] | None = None, + init_carry: _T +) -> Callable[[Callable[[NDLoopInfo, _T], _T]], _T]: + ... + + +def nd_loop(grid, *, collective_axes, tiling=None, init_carry=None): + """A loop over a multi-dimensional grid partitioned along the given axes. + + The body of the loop a single argument ``loop_info`` which is an NDLoopInfo + object containing index and iteration information. However if a carry is + specified, the body will expect a second keyword argument `carry` containing + the loop carry. + + For example, if ``collective_axes`` is ``"x"`` with :func:`lax.axis_size` + equal to 4 and the grid is (2, 3), the implementation would produce the + following iteration order + + +-----------+--------+------------+ + | loop step | index | axis index | + +===========+========+============+ + | 0 | (0, 0) | 0 | + +-----------+--------+------------+ + | 1 | (0, 1) | 1 | + +-----------+--------+------------+ + | 2 | (0, 2) | 2 | + +-----------+--------+------------+ + | 3 | (1, 0) | 3 | + +-----------+--------+------------+ + | 4 | (1, 1) | 0 | + +-----------+--------+------------+ + | 5 | (1, 2) | 1 | + +-----------+--------+------------+ + + which comes from partitioning the flat iteration space into chunks in an + interleaved fashion wrt the ``"x"`` axis index. + + Note that in the example the total number of loop steps is not divisible + by the axis size of ``"x"``, and thus for some ``"x"`` axis indices the + loop will do one iteration less. + + +------------+------------------+ + | axis index | indices | + +============+==================+ + | 0 | (0, 0), (1, 1) | + +------------+------------------+ + | 1 | (0, 1), (1, 2) | + +------------+------------------+ + | 2 | (0, 2) | + +------------+------------------+ + | 3 | (1, 0) | + +------------+------------------+ + + If ``init_carry`` is passed then ``nd_loop()`` will expect the body to + take and return the carry. If it's ``None`` then no carry argument is + expected. + + See also: + - :func:`jax.experimental.pallas.loop`: A loop over a single dimension. + """ + + axis_index = lax.axis_index(collective_axes) + axis_size = lax.axis_size(collective_axes) + if tiling: + if len(grid) != len(tiling): + raise ValueError(f"{tiling=} and {grid=} must have same length.") + if any(dim % tile != 0 for dim, tile in zip(grid, tiling, strict=True)): + raise ValueError(f"Tiling {tiling} does not divide grid {grid}.") + tile_grid = tuple( + dim // tile for dim, tile in zip(grid, tiling, strict=True)) + grid = (*tile_grid, *tiling) + grid_size = math.prod(grid) + + def decorator(body): + def wrapper(wave_step, carry): + nonlocal body + step = wave_step * axis_size + axis_index + # The loop below is conceptually ``jnp.unravel_index``, but it uses + # ``lax`` APIs instead of ``jax.numpy`` to minimize the number of + # primitives used. + index = [] + for grid_dim in reversed(grid): + grid_dim = lax.convert_element_type(grid_dim, step.dtype) + index.append(lax.rem(step, grid_dim)) + step = lax.div(step, grid_dim) + index.reverse() + + if tiling: + # Recompute index as if the grid was not tiled. + tile_indices, subtile_indices = index[:len(tiling)], index[len(tiling):] + untiled_index = [] + for sub_idx, tile_idx, tile_dim in zip( + subtile_indices, tile_indices, tiling, strict=True): + untiled_index.append(sub_idx + tile_idx * tile_dim) + index = untiled_index + + loop_info = NDLoopInfo( + index=tuple(index), + local_index=wave_step, + num_local_steps=upper + ) + if init_carry is None: + body(loop_info) + else: + return body(loop_info, carry=carry) + + upper = lax.div(grid_size, axis_size) + lax.convert_element_type( + axis_index < grid_size % axis_size, axis_index.dtype + ) + return lax.fori_loop(0, upper, wrapper, init_carry) + return decorator + + +def format_tcgen05_sparse_metadata(meta): + """Formats the sparse metadata for tcgen05.mma into the expected format. + + See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-sparse-matrices-sparsity-selector-kind-f16-m128-256 + for the documentation of the required layouts. The array can be copied into + SMEM, from where ``plgpu.async_copy_sparse_metadata_to_tmem`` can be used to + copy it over to TMEM. + """ + if meta.dtype != dtypes.uint2: + raise ValueError(f"Expected metadata dtype to be uint2, got: {meta.dtype}") + if meta.ndim != 3: + raise ValueError( + "Expected metadata to be 3-dimensional (M, K // 4, 2), but it is" + f" {meta.ndim}D" + ) + m, k, _2 = meta.shape + if _2 != 2: + raise ValueError( + "Expected the trailing dimension of the metadata to be 2, got:" + f" {meta.shape[-1]}" + ) + k *= 2 + return ( + meta.reshape(m // 128, 8, 2, 8, k // 64, 4, 2, 8) + .transpose(0, 4, 1, 6, 3, 5, 2, 7) + .reshape(m // 128, k // 64, 128, 64) + ) + + +def find_swizzle(minor_dim_bits: int, what: str = ""): + """Returns the largest swizzle that can be applied to a memory region. + + Swizzling is usually necessary when dealing with 2D data in SMEM, especially + if the reference is used as an MMA operand. The returned swizzle is usually + applied as ``plgpu`` transform: + + transforms = ( + plgpu.TilingTransform((8, 8 * swizzle // elem_bits)), + plgpu.SwizzleTransform(swizzle)) + ) + + Args: + minor_dim_bits: The number of bits in the minor (last) dimension of the + memory region. Usually computed as ``dim_size * jnp.finfo(dtype).bits``. + what: A string describing the operand for which the swizzle is being + computed. Improves the error message if specified. + """ + for swizzle_bytes in (128, 64, 32, 16): + if minor_dim_bits % (swizzle_bytes * 8) == 0: + return swizzle_bytes + if what: + what = " for " + what + raise ValueError( + f"No valid out swizzle{what}: minor dimension has" + f" {minor_dim_bits} bits, which is not a multiple of 128 (16 bytes)" + ) + + +def planar_snake( + lin_idx: jax.Array, shape: tuple[int, int], minor_dim: int, tile_width: int +): + """Converts a linear index into an index into shape, trying to optimize locality. + + The "space filling curve" this function computes splits the minor dimension + into tiles of length ``tile_width``. Every other tile has its major dimension + inverted, so that the iteration order "snakes around" when going from one tile + to another. + + For a shape of (8, 8), ``minor_dim=0`` and ``tile_width=2``, the iteration + order is:: + + 0 2 4 6 8 10 12 14 + 1 3 5 7 9 11 13 15 + 30 28 26 24 22 20 18 16 + 31 29 27 25 23 21 19 17 + 32 34 36 38 40 42 44 46 + 33 35 37 39 41 43 45 47 + 62 60 58 56 54 52 50 48 + 63 61 59 57 55 53 51 49 + + Notice how each pair of rows forms a tile (``minor_dim=0``, ``tile_width=2``) + and when moving from one tile to another, the indices increase along columns + in one of them and decrease in the other. + """ + tile_width = np.int32(tile_width) + major_size = np.int32(shape[1 - minor_dim]) + minor_size = np.int32(shape[minor_dim]) + minor_tile_idx = lax.div(lin_idx, tile_width * major_size) + + def tile_coordinates(lin_idx, width): + # if minor_dim == 0 then tiles are (tile_width, major_size) else (major_size, tile_width) + minor_within_tile = lax.rem(lin_idx, width) + major_within_tile = lax.rem(lax.div(lin_idx, width), major_size) + minor = minor_tile_idx * tile_width + minor_within_tile + major = lax.select( + lax.rem(minor_tile_idx, np.int32(2)) == 0, + major_within_tile, + major_size - 1 - major_within_tile, + ) + return (minor, major) if minor_dim == 0 else (major, minor) + + num_full_tiles = shape[minor_dim] // tile_width + full_tiles_minor_size = num_full_tiles * tile_width + num_full_tiles_elements = num_full_tiles * tile_width * major_size + is_full_tile = lin_idx < num_full_tiles_elements + return jax.tree.map( + functools.partial(jax.lax.select, is_full_tile), + tile_coordinates(lin_idx, tile_width), + tile_coordinates(lin_idx - num_full_tiles_elements, minor_size - full_tiles_minor_size) + ) + + +@overload +def dynamic_scheduling_loop( + grid_names: Sequence[Hashable], + *, + thread_axis: Hashable | None = None, + init_carry: None = None +) -> Callable[[Callable[[NDLoopInfo], None]], None]: + ... + + +@overload +def dynamic_scheduling_loop( + grid_names: Sequence[Hashable], + *, + thread_axis: Hashable | None = None, + init_carry: _T +) -> Callable[[Callable[[NDLoopInfo, _T], _T]], _T]: + ... + + +def dynamic_scheduling_loop( + grid_names, + thread_axis = None, + init_carry = None): + """A loop over program instances using dynamic work scheduling. + + This loop will iterate through available program instances until all + work has been scheduled. The kernel should be instantiated with a grid + equal to the logical amount of work to be done (as opposed to a persistent + kernel where the grid is set to the number of cores). Each core running + this loop will continuously query the next available block of work and + the loop will terminate when the entire grid has been scheduled. + + Example usage:: + + @plgpu.dynamic_scheduling_loop(grid_names) + def body(loop_info): + work(loop_info.index) # do work... + + Args: + grid_names: The names of the axes in the grid. + thread_axis: The name of the thread axis. This must be passed in if + the kernel uses multiple threads. + init_carry: An optional initial carry for the loop. If passed in, the + body function should expect a ``carry`` keyword argument and return + the next carry value. + """ + if thread_axis is not None: + num_threads = lax.axis_size(thread_axis) + else: + num_threads = 1 + user_carry = init_carry + + def decorator(body): + grid_idx = tuple(lax.axis_index(axis_name) for axis_name in grid_names) + success = True + def _scoped(try_cancel_buffer, try_cancel_barrier): + def try_cancel_cond(carry): + _, success, _, _ = carry + return success + def try_cancel_body(carry): + grid_idx, _, wave_step, user_carry = carry + slot = lax.rem(wave_step, jnp.int32(2)) + gpu_primitives.try_cluster_cancel(try_cancel_buffer.at[slot], + try_cancel_barrier.at[slot]) + loop_info = NDLoopInfo( + index=grid_idx, + local_index=wave_step, + num_local_steps=None, + ) + if user_carry is None: + body(loop_info) + else: + user_carry = body(loop_info, carry=user_carry) + gpu_primitives.barrier_wait(try_cancel_barrier.at[slot]) + grid_idx, success = gpu_primitives.query_cluster_cancel( + try_cancel_buffer.at[slot], + grid_names=grid_names) + return (grid_idx, success, wave_step + jnp.int32(1), user_carry) + init_carry = (grid_idx, success, jnp.int32(0), user_carry) + final_carry = lax.while_loop( + try_cancel_cond, + try_cancel_body, + init_carry, + ) + if user_carry is not None: + return final_carry[-1] + return pallas_primitives.run_scoped( + _scoped, + try_cancel_buffer=gpu_core.TryClusterCancelResult(2), + try_cancel_barrier=gpu_core.Barrier(num_arrivals=num_threads, + num_barriers=2), + collective_axes=thread_axis, + ) + return decorator diff --git a/jax/_src/pallas/mosaic_gpu/interpret/BUILD b/jax/_src/pallas/mosaic_gpu/interpret/BUILD new file mode 100644 index 000000000000..a324487edbc4 --- /dev/null +++ b/jax/_src/pallas/mosaic_gpu/interpret/BUILD @@ -0,0 +1,85 @@ +# Copyright 2026 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Package for Pallas GPU Interpret Mode + +load("@rules_python//python:defs.bzl", "py_library") +load("//jaxlib:jax.bzl", "py_deps", "pytype_strict_library") + +package( + default_applicable_licenses = [], + default_visibility = [ + "//jax:internal", + ], +) + +py_library( + name = "interpret_pallas_call", + srcs = [ + "__init__.py", + "interpret_pallas_call.py", + ], + deps = [ + ":gpu_callbacks", + ":jaxpr_interpret", + "//jax", + "//jax/_src:callback", + "//jax/_src:core", + "//jax/_src:effects", + "//jax/_src:lax", + "//jax/_src:typing", + "//jax/_src:util", + "//jax/_src/pallas", + "//jax/_src/pallas/mosaic/interpret:thread_map", + "//jax/_src/pallas/mosaic/interpret:utils", + "//jax/_src/pallas/mosaic_gpu:core", + "//jax/experimental:pallas_mosaic_gpu", + ], +) + +pytype_strict_library( + name = "gpu_callbacks", + srcs = ["gpu_callbacks.py"], + deps = [ + "//jax", + "//jax/_src:callback", + "//jax/_src:lax", + "//jax/_src:source_info_util", + "//jax/_src/pallas/mosaic/interpret:race_detection_state", + "//jax/_src/pallas/mosaic/interpret:shared_memory", + "//jax/_src/pallas/mosaic/interpret:utils", + "//jax/_src/pallas/mosaic/interpret:vector_clock", + "//jax/_src/pallas/mosaic_gpu:core", + ] + py_deps([ + "numpy", + ]), +) + +pytype_strict_library( + name = "jaxpr_interpret", + srcs = ["jaxpr_interpret.py"], + deps = [ + ":gpu_callbacks", + "//jax", + "//jax/_src:callback", + "//jax/_src:core", + "//jax/_src:lax", + "//jax/_src:source_info_util", + "//jax/_src:util", + "//jax/_src/pallas", + "//jax/_src/pallas/mosaic/interpret:utils", + "//jax/_src/pallas/mosaic_gpu:core", + "//jax/experimental:pallas_mosaic_gpu", + ], +) diff --git a/jax/_src/pallas/mosaic_gpu/interpret/__init__.py b/jax/_src/pallas/mosaic_gpu/interpret/__init__.py new file mode 100644 index 000000000000..554f568218e5 --- /dev/null +++ b/jax/_src/pallas/mosaic_gpu/interpret/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2026 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/jax/_src/pallas/mosaic_gpu/interpret/gpu_callbacks.py b/jax/_src/pallas/mosaic_gpu/interpret/gpu_callbacks.py new file mode 100644 index 000000000000..2fa260f79810 --- /dev/null +++ b/jax/_src/pallas/mosaic_gpu/interpret/gpu_callbacks.py @@ -0,0 +1,685 @@ +# Copyright 2026 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Mapping, Sequence +import contextlib +import dataclasses +import functools +import itertools +import threading +import types +from typing import Self + +import jax +from jax import numpy as jnp +from jax._src import callback +from jax._src import source_info_util +from jax._src.pallas.mosaic.interpret import shared_memory as memory +from jax._src.pallas.mosaic.interpret import utils as interpret_utils +from jax._src.pallas.mosaic.interpret import vector_clock as vc +from jax._src.pallas.mosaic.interpret.race_detection_state import RaceDetectionState +from jax._src.pallas.mosaic_gpu import core as mosaic_gpu_core +from jax._src.state import indexing +import numpy as np + + +IDX_BY_GPU_MEMORY_SPACE: Mapping[mosaic_gpu_core.MemorySpace, int] = ( + types.MappingProxyType( + {v: i for i, v in enumerate(mosaic_gpu_core.MemorySpace)} + ) +) + + +GPU_MEMORY_SPACE_BY_IDX = types.MappingProxyType( + dict(enumerate(mosaic_gpu_core.MemorySpace)) +) + + +def get_memory_space_idx(space: mosaic_gpu_core.MemorySpace | None) -> int: + if space is None: + return IDX_BY_GPU_MEMORY_SPACE[mosaic_gpu_core.MemorySpace.SMEM] + return IDX_BY_GPU_MEMORY_SPACE[space] + + +def is_smem_memory_space(space: mosaic_gpu_core.MemorySpace | None) -> bool: + if space is None: + return True + return space == mosaic_gpu_core.MemorySpace.SMEM + + +def is_gmem_memory_space(space: mosaic_gpu_core.MemorySpace | None) -> bool: + return space == mosaic_gpu_core.MemorySpace.GMEM + + +_shared_memory: memory.SharedMemory | None = None +_shared_memory_init_lock = threading.Lock() +_races: RaceDetectionState | None = None + + +def _get_shared_memory() -> memory.SharedMemory: + assert _shared_memory is not None + return _shared_memory + + +def _clear_shared_memory(): + global _shared_memory + with _shared_memory_init_lock: + _shared_memory = None + + +def get_races() -> RaceDetectionState: + assert _races is not None + return _races + + +# Below we define pairs of _callback_ functions. Each pair consists of +# +# (1) a module-private function, e.g. `_initialize_shared_memory`, and +# (2) a thin wrapper around the this module-private function, e.g. +# `call_initialize_shared_memory`. +# +# The module-private function (1) runs in the Python ("host") process and +# manages interaction of the interpreted Pallas kernel with the memory system, +# represented by the module-global `SharedMemory` object `_shared_memory`. +# +# The wrapper function (2) is to be called from the interpreted Pallas kernel +# (that is simulating a "device", or thread). It serves as the interface between +# the "device" kernel and the "host" memory system and merely passes arguments +# on to the corresponding function (1). Importantly, when the wrapper receives +# an argument that is a Jax (device) array, this argument is received as a Numpy +# (host) array by the corresponding function (1), due to the +# `callback.io_callback` mechanism. + + +def _initialize_shared_memory( + num_devices: jnp.ndarray, + num_threads: jnp.ndarray, + *, + interpret_params: interpret_utils.InterpretGPUParams, +): + global _shared_memory, _races + + num_devices = int(num_devices) + num_threads = int(num_threads) + num_total_threads = num_devices * num_threads + + with _shared_memory_init_lock: + if _shared_memory is None: + vector_clock_size = interpret_params.get_vector_clock_size(num_devices) + _races = RaceDetectionState(num_cores=num_total_threads) + _shared_memory = memory.SharedMemory( + num_devices=num_devices, + # We re-use the `SharedMemory`'s capability to model multiple cores + # per (TPU) device for modeling the multiple threads on a single GPU + # device. + num_cores_per_device=num_threads, + out_of_bounds_reads=interpret_params.out_of_bounds_reads, + # TODO(nrink): Support different DMA execution modes on GPU. + dma_execution_mode="eager", + uninitialized_memory=interpret_params.uninitialized_memory, + detect_races=interpret_params.detect_races, + vector_clock_size=vector_clock_size, + clocks=[ + vc.make_vector_clock(vector_clock_size) + for _ in range(num_total_threads) + ], + barrier=threading.Barrier(num_devices, action=lambda: None), + clean_up_barrier=threading.Barrier( + num_devices, action=_clear_shared_memory + ), + ) + # The naming of the `num_cores` property of `SharedMemory` originates from the + # support for multipl cores in a (Megacore) TPU device. As commented above, on + # GPU we model multiple threads per device as _cores_ in the + # (TPU-/Megacore-)inspired terminology of`SharedMemory`. + assert _shared_memory.num_cores == num_total_threads + + +def call_initialize_shared_memory( + *, + num_devices: int, + num_threads: int, + interpret_params: interpret_utils.InterpretGPUParams, +): + callback.io_callback( + functools.partial( + _initialize_shared_memory, + interpret_params=interpret_params, + ), + (), + num_devices, + num_threads, + ordered=True, + ) + + +def _clean_up_shared_memory(): + shared_memory = _get_shared_memory() + shared_memory.clean_up_barrier.wait() + + +def call_clean_up_shared_memory(): + callback.io_callback(_clean_up_shared_memory, (), ordered=True) + + +def _update_clocks_for_device_barrier(device_id: int): + shared_memory = _get_shared_memory() + shared_memory.update_clocks_for_device_barrier(device_id) + + +def call_update_clocks_for_device_barrier(device_id: int): + callback.io_callback( + _update_clocks_for_device_barrier, (), device_id, ordered=True + ) + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class HostAllocationRequest: + """Request for an allocation on a device/thread and in a memory space.""" + + memory_space_id: int + device_id: int + # Defaults to zero for `AllocationRequest`s that do not specify a thread ID. + thread_id: int = 0 + # The reference count is needed only for allocations that are explicitly + # deallocated (with _deallocate_buffer below). This currently only applies to + # allocations made by a `run_scoped` primitive. + initial_ref_count: int = 1 + + def __iter__(self): + # We make `self` iterable to ease conversion into Numpy and Jax arrays (cf. + # methods `as_array` and `as_jax_array` below). Note that for this purpose + # it would suffice to have any method that return a suitable iterator, + # instead of implementing the special `__iter__` method. Not implementing + # `__iter__` would mean that objects of this class cannot (accidentally) be + # iterated over by clients of the class. + return iter(( + self.memory_space_id, + self.device_id, + self.thread_id, + self.initial_ref_count, + )) + + @classmethod + def shape_and_dtype(cls) -> jax.ShapeDtypeStruct: + num_fields = len(dataclasses.fields(cls)) + return jax.ShapeDtypeStruct((num_fields,), jnp.int32) + + @property + def as_array(self) -> np.ndarray: + return np.array(list(self), dtype=np.int32) + + @property + def as_jax_array(self) -> jnp.ndarray: + return jnp.array(list(self), dtype=jnp.int32) + + @classmethod + def from_array(cls, request: np.ndarray | jnp.ndarray) -> Self: + if request.shape != cls.shape_and_dtype().shape: + raise ValueError( + f"Expected shape {cls.shape_and_dtype().shape} but got" + f" {request.shape}" + ) + if not interpret_utils.is_int(request.dtype): + raise ValueError(f"Expected integer dtype but got {request.dtype}") + + arg_names = [f.name for f in dataclasses.fields(cls)] + values = map(int, request) + return cls(**dict(zip(arg_names, values))) + + +def make_allocation_request_array( + *, + memory_space_id: int, + device_id: int, + thread_id: int = 0, + initial_ref_count: int = 1, +) -> jnp.ndarray: + return HostAllocationRequest( + memory_space_id=memory_space_id, + device_id=device_id, + thread_id=thread_id, + initial_ref_count=initial_ref_count, + ).as_jax_array + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class HostAllocationKey(HostAllocationRequest): + """Key for an allocation in shared memory.""" + + buffer_id: int + + def __iter__(self): + # Note that implementing `__iter__` here affects the bahviour of the + # `as_array` and `as_jax_array` methods of the base class. This is intended. + yield from super().__iter__() + yield self.buffer_id + + +def _allocate_buffer_for_all_threads( + device_id: np.ndarray, + allocation_request: np.ndarray, + value: np.ndarray, +) -> np.ndarray: + """Allocates a buffer for the given `allocation_request`. + + While only a single buffer is allocated, we increment the next buffer ID on + `_shared_memory` for all threads. (This is analogous to the behavior when + interpreting TPU kernels with multiple cores per TPU device.) + + Args: + allocation_request: Array that converts into an `HostAllocationRequest` with + `thread_id` set to zero. This requirement can be thought of as associating + the allocated buffer (that is shared across all threads) with the zeroth + thread. + value: Array of values to initialize the allocated buffer with. + + Returns: + `AllocationKey` to refer to the allocated buffer. + + Raises: + ValueError: If the `thread_id` in `allocation_request` is not zero. + """ + device_id = int(device_id) + allocation_request = HostAllocationRequest.from_array(allocation_request) + if allocation_request.thread_id != 0: + raise ValueError( + "`thread_id` must be zero when allocating a buffer for all threads" + ) + value = np.array(value) + shared_memory = _get_shared_memory() + + key = None + buffer_id = None + for thread_id in range(shared_memory.num_cores_per_device): + buffer_id_for_thread_id = shared_memory.get_next_buffer_id( + device_id, thread_id + ) + if not buffer_id: + buffer_id = buffer_id_for_thread_id + else: + # We keep the buffer ids in sync across all threads. This implies, in + # particular, that every instance of the assignment to `key` below assigns + # an `AllocationKey` object with the same attributes. + assert buffer_id == buffer_id_for_thread_id + + key = HostAllocationKey( + memory_space_id=allocation_request.memory_space_id, + device_id=allocation_request.device_id, + thread_id=0, + initial_ref_count=allocation_request.initial_ref_count, + buffer_id=buffer_id, + ) + ref_count = allocation_request.initial_ref_count + # We rely on the fact that `allocate_buffer` will not allocate a new buffer + # if one with the same key already exists. + shared_memory.allocate_buffer(key, ref_count=ref_count, value=value) + + # We expect the `for`-loop above to have executed its body at least once. + assert key is not None + return key.as_array + + +def call_allocate_buffer_for_all_threads( + device_id: int, + allocation_request: jnp.ndarray, + value: jnp.ndarray, +) -> jnp.ndarray: + return callback.io_callback( + _allocate_buffer_for_all_threads, + HostAllocationKey.shape_and_dtype(), + device_id, + allocation_request, + value, + ordered=True, + ) + + +def _allocate_buffer( + device_id: np.ndarray, + thread_id: np.ndarray, + allocation_request: np.ndarray, + value: np.ndarray, +) -> np.ndarray: + """Allocates a buffer for the given `allocation_request`. + + Args: + allocation_request: Array that converts into a `HostAllocationRequest`. + value: Array of values to initialize the allocated buffer with. + + Returns: + `AllocationKey` to refer to the allocated buffer. + """ + device_id = int(device_id) + thread_id = int(thread_id) + allocation_request = HostAllocationRequest.from_array(allocation_request) + value = np.array(value) + shared_memory = _get_shared_memory() + + buffer_id = shared_memory.get_next_buffer_id(device_id, thread_id) + + key = HostAllocationKey( + memory_space_id=allocation_request.memory_space_id, + device_id=allocation_request.device_id, + thread_id=allocation_request.thread_id, + initial_ref_count=allocation_request.initial_ref_count, + buffer_id=buffer_id, + ) + ref_count = allocation_request.initial_ref_count + shared_memory.allocate_buffer(key, ref_count=ref_count, value=value) + return key.as_array + + +def call_allocate_buffer( + device_id: int, + thread_id: int, + allocation_request: jnp.ndarray, + value: jnp.ndarray, +) -> jnp.ndarray: + return callback.io_callback( + _allocate_buffer, + HostAllocationKey.shape_and_dtype(), + device_id, + thread_id, + allocation_request, + value, + ordered=True, + ) + + +def _deallocate_buffer(allocation_key: np.ndarray): + """Decreases the reference count of the buffer with `allocation_key` (Deallocates the buffer if its reference count becomes zero).""" + allocation_key = HostAllocationKey.from_array(allocation_key) + shared_memory = _get_shared_memory() + shared_memory.deallocate_buffer(allocation_key) + + +def call_deallocate_buffer(allocation_key: jnp.ndarray): + callback.io_callback( + _deallocate_buffer, + None, + allocation_key, + ordered=True, + ) + + +def _handle_out_of_bounds_read( + ret: np.ndarray | None, + full_read_shape: tuple[int, ...], + shape: Sequence[int], + dtype: np.dtype, + allocation_key: HostAllocationKey, + read_range: tuple[int | slice, ...], + shared_memory: memory.SharedMemory, + source_info, + input_name: str | None, + block_indices: tuple[int, ...] | None, + grid_loop_idx: tuple[int, ...] | None, +) -> np.ndarray: + """Handles out-of-bounds read based on shared_memory configuration.""" + if shared_memory.out_of_bounds_reads == "raise": + if source_info is None: + ctx = contextlib.nullcontext() + else: + ctx = source_info_util.user_context( + traceback=source_info.traceback, name_stack=source_info.name_stack + ) # type: ignore[assignment] + with ctx: + if input_name is None: + raise IndexError( + f"Out-of-bounds read of {allocation_key}:" + f" reading [{read_range}] but buffer has shape {shape}." + ) + else: + # Different error message when we are reading a block of an input, + # to copy it to a buffer before invoking the kernel body. + raise IndexError( + f"Out-of-bounds block index {block_indices} for {allocation_key}," + f' input "{input_name}" in iteration {grid_loop_idx}:' + f" reading [{read_range}] but input has shape {shape}." + ) + # out_of_bounds_reads == "uninitialized" + uninit_array = np.full( + full_read_shape, + interpret_utils.get_uninitialized_value( + dtype, shared_memory.uninitialized_memory + ), + dtype=dtype, + ) + if ret is None: + return uninit_array + else: + uninit_array[tuple(slice(s) for s in ret.shape)] = ret + return uninit_array + + +def _is_dynamic(indexer: indexing.NDIndexer) -> bool: + return any( + isinstance(idx, indexing.Slice) + and (idx.is_dynamic_start or idx.is_dynamic_size) + for idx in indexer.indices + ) + + +def _validate_transforms(transforms): + for transform in transforms: + match transform: + case indexing.NDIndexer(): + if _is_dynamic(transform): + raise ValueError( + "Dynamic indexing not supported in GPU interpret mode" + ) + case mosaic_gpu_core.MemoryRefTransform(): + raise ValueError(f"GPU transformation {transform} not supported yet") + case _: + raise ValueError(f"Unsupported transform: {transform}") + + +def _get( + device_id: np.ndarray, + thread_id: np.ndarray, + allocation_key: np.ndarray, + transforms, + block_indices=None, + grid_loop_idx=None, + clock=None, + source_info=None, + input_name=None, +) -> np.ndarray: + """Performs a read from the buffer for `allocation_key_as_array` from the given device and thread.""" + device_id = int(device_id) + thread_id = int(thread_id) + allocation_key = HostAllocationKey.from_array(allocation_key) + + _validate_transforms(transforms) + # TODO(nrink): Support tiling and swizzling transforms. + transforms = jax.tree.map(int, transforms) + + if input_name is not None: + # NOTE: input_name, block_indices, and grid_loop_idx are set only if this + # function is being called to read a block from a pallas_call input (at the + # start of one iteration of the kernel body). + assert block_indices is not None + block_indices = tuple(int(x) for x in block_indices) + assert grid_loop_idx is not None + grid_loop_idx = tuple(int(x) for x in grid_loop_idx) + + shared_memory = _get_shared_memory() + + global_core_id = shared_memory.get_global_core_id(device_id, thread_id) + + read_range = interpret_utils.to_range(transforms) + ret, (shape, dtype), clock_ = shared_memory.get_buffer_content( + allocation_key, read_range, global_core_id + ) + clock = clock if clock is not None else clock_ + + # Compute the shape of the read value, assuming the read is fully in-bounds. + # TODO(jburnim): We already know this shape in the Jaxpr where we insert a + # callback to `get`. Should we just pass the shape to `get`? + # TODO(jburnim): Move to a helper function? + full_read_shape = [] + assert len(read_range) <= len(shape) + for dim_size, idx_or_slice in itertools.zip_longest( + shape, read_range, fillvalue=None + ): + assert isinstance(dim_size, int) + if idx_or_slice is None: + full_read_shape.append(dim_size) + elif isinstance(idx_or_slice, int): + continue + else: + dim_size = (idx_or_slice.stop - idx_or_slice.start) // idx_or_slice.step + assert isinstance(dim_size, int) + full_read_shape.append(dim_size) + full_read_shape = tuple(full_read_shape) + + if (ret is None) or (full_read_shape != ret.shape): + ret = _handle_out_of_bounds_read( + ret, + full_read_shape, + shape, + dtype, + allocation_key, + read_range, + shared_memory, + source_info, + input_name, + block_indices, + grid_loop_idx, + ) + + if shared_memory.detect_races: + get_races().check_read( + device_id, + thread_id, + clock, + allocation_key, + read_range, + source_info=source_info, + ) + return ret + + +def call_get( + *, + result_shape_and_dtype, + device_id: int, + thread_id: int, + allocation_key: jnp.ndarray, + transforms, + block_indices=None, + grid_loop_idx=None, + clock=None, + source_info=None, + input_name=None, +) -> jnp.ndarray: + return callback.io_callback( + functools.partial(_get, source_info=source_info, input_name=input_name), + result_shape_and_dtype, + device_id, + thread_id, + allocation_key, + transforms, + block_indices, + grid_loop_idx, + clock, + ordered=True, + ) + + +def _swap( + device_id: np.ndarray, + thread_id: np.ndarray, + allocation_key_as_array: np.ndarray, + transforms, + val, + mask, + *, + source_info=None, +): + """Performs a swap into the buffer for `allocation_key_as_array` from the given device and thread.""" + device_id = int(device_id) + thread_id = int(thread_id) + allocation_key = HostAllocationKey.from_array(allocation_key_as_array) + + _validate_transforms(transforms) + # TODO(nrink): Support tiling and swizzling transforms. + transforms = jax.tree.map(int, transforms) + + val = np.array(val) + mask = np.array(mask) if mask is not None else None + if mask is not None: + assert mask.shape == val.shape + + shared_memory = _get_shared_memory() + + global_core_id = shared_memory.get_global_core_id(device_id, thread_id) + + read_write_range = interpret_utils.to_range(transforms) + ret, (shape, _), clock = shared_memory.swap_buffer_content( + allocation_key, read_write_range, val, mask, global_core_id + ) + + if ret is None: + if mask is None: + raise ValueError( + f"Out-of-bounds swap of {allocation_key}:" + f" swapping [{read_write_range}] but buffer has shape" + f" {shape} ." + ) + else: + # TODO(jburnim): Include indices of out-of-bounds locations where mask + # is True. + raise ValueError( + f"Out-of-bounds masked swap of {allocation_key}: swapping" + f" [{read_write_range}] but buffer has shape {shape} . " + ) + + if shared_memory.detect_races: + get_races().check_write( + device_id, + thread_id, + clock, + allocation_key, + read_write_range, + source_info=source_info, + ) + return ret + + +def call_swap( + *, + result_shape_and_dtype, + device_id: int, + thread_id: int, + allocation_key: jnp.ndarray, + transforms, + val, + mask, + source_info=None, +): + return callback.io_callback( + functools.partial(_swap, source_info=source_info), + result_shape_and_dtype, + device_id, + thread_id, + allocation_key, + transforms, + val, + mask, + ordered=True, + ) diff --git a/jax/_src/pallas/mosaic_gpu/interpret/interpret_pallas_call.py b/jax/_src/pallas/mosaic_gpu/interpret/interpret_pallas_call.py new file mode 100644 index 000000000000..304f13b906d1 --- /dev/null +++ b/jax/_src/pallas/mosaic_gpu/interpret/interpret_pallas_call.py @@ -0,0 +1,478 @@ +# Copyright 2026 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Mapping, Sequence, Set +import dataclasses +import math +from typing import Any + +import jax +from jax._src import callback +from jax._src import core as jax_core +from jax._src import effects +from jax._src.pallas import core as pallas_core +from jax._src.pallas.mosaic.interpret import thread_map +from jax._src.pallas.mosaic.interpret import utils as interpret_utils +from jax._src.pallas.mosaic_gpu import core as mosaic_gpu_core +from jax._src.pallas.mosaic_gpu.interpret import gpu_callbacks +from jax._src.pallas.mosaic_gpu.interpret import jaxpr_interpret +from jax._src.typing import Array +from jax._src.util import (safe_zip, split_list) +from jax.experimental.pallas import mosaic_gpu as plgpu + + +InterpretParams = interpret_utils.InterpretGPUParams + + +def get_interpret_effects() -> Set[effects.Effect]: + return {callback._OrderedIOEffect} # pylint: disable=protected-access + + +def get_races() -> gpu_callbacks.RaceDetectionState: + return gpu_callbacks.get_races() + + +def _get_grid_bounds(grid_mapping: pallas_core.GridMapping) -> tuple[int, ...]: + if grid_mapping.num_dynamic_grid_bounds > 0: + raise NotImplementedError( + "Dynamic grid bounds not (yet) supported in GPU interpret mode." + ) + result = [] + for x in grid_mapping.grid: + # We have already tested for the absence of dynamic grid bounds. So all + # entries in the grid should be ints. + assert isinstance(x, int) + result.append(x) + return tuple(result) + + +def _get_num_threads( + grid_mapping: pallas_core.GridMapping, mesh: plgpu.Mesh | None +) -> int: + if not mesh: + num_threads = 1 + elif isinstance(mesh, plgpu.Mesh): + if math.prod(mesh.grid) != 1: + raise NotImplementedError( + f"Invalid grid {mesh.grid} in mesh: GPU interpret mode does not" + " support non-trivial grids (i.e. grids with more than one point)." + ) + if mesh.cluster is not None and math.prod(mesh.cluster) != 1: + raise NotImplementedError( + f"Invalid cluster {mesh.cluster} in mesh: GPU interpret mode does not" + " support (non-trivial) clusters." + ) + num_threads = int(mesh.num_threads or 1) + else: + raise ValueError(f"Unsupported mesh type: {type(mesh)}") + + if math.prod(_get_grid_bounds(grid_mapping)) != num_threads: + raise NotImplementedError( + f"Invalid grid {grid_mapping.grid} in grid_mapping: GPU interpret mode" + " does not support grids with more points than threads" + f" ({num_threads}). " + ) + + return num_threads + + +def _allocate_buffers_for_inputs( + device_id: int, + invars: Sequence[Any], + inputs: Sequence[jax.Array], +) -> list[jax.Array]: + """Allocates `GMEM` buffers for the `inputs` of a `pallas_call`.""" + # TODO(nrink): This code is a simplified version to the corresponding TPU + # interpreter code. Eventually, we should merge the two. + input_buffer_keys = [] + for var, value in safe_zip(invars, inputs): + assert var.aval.dtype == value.dtype + allocation_request = gpu_callbacks.make_allocation_request_array( + device_id=device_id, + # All operands of a `pallas_call`/`core_map` that are arrays (i.e. that + # are not sempahores, barriers etc.) are placed in `GMEM`. These arrays + # (or slices thereof) may need to be copied into `SMEM` before executing + # the kernel. + memory_space_id=gpu_callbacks.get_memory_space_idx( + mosaic_gpu_core.MemorySpace.GMEM + ), + ) + input_buffer_keys.append( + gpu_callbacks.call_allocate_buffer_for_all_threads( + device_id, allocation_request, value + ) + ) + + return input_buffer_keys + + +@dataclasses.dataclass(frozen=True) +class AllocationKeyAndValue: + key: jax.Array + value: jax.Array + + @property + def shape(self) -> tuple[int, ...]: + return self.value.shape + + +def _allocate_buffers_for_outputs( + device_id: int, + num_threads: int, + input_output_aliases: tuple[tuple[int, int], ...], + grid_mapping: pallas_core.GridMapping, + input_buffer_keys: Sequence[jax.Array], + input_vals: Sequence[jax.Array], + interpret_params: InterpretParams, +) -> list[AllocationKeyAndValue]: + """Allocates `GMEM` buffers for `pallas_call` outputs, respecting aliased inputs.""" + # TODO(nrink): This code is a simplified version to the corresponding TPU + # interpreter code. Eventually, we should merge the two. + assert len(input_buffer_keys) == len(input_vals) + + oi_alias_map = {v: k for k, v in input_output_aliases} + output_buffer_keys_and_values = [] + + block_shapes = [ + pallas_core._get_block_shape(bm.block_shape) # pylint: disable=protected-access + for bm in grid_mapping.block_mappings + ] + num_inputs = grid_mapping.num_inputs + + num_outputs = grid_mapping.num_outputs + output_block_shapes = block_shapes[num_inputs : num_inputs + num_outputs] + for output_idx, bm in enumerate(grid_mapping.block_mappings_output): + if output_idx in oi_alias_map: + aliased_input_idx = oi_alias_map[output_idx] + # Reuse the `GMEM` buffer for the aliased `pallas_call`/`core_map` input. + output_buffer_keys_and_values.append( + AllocationKeyAndValue( + key=input_buffer_keys[aliased_input_idx], + value=input_vals[aliased_input_idx], + ) + ) + else: + out_val = interpret_params.get_uninitialized_array( + bm.array_aval.shape, bm.array_aval.dtype + ) + padded_val = interpret_params.pad_to_block_dimension( + out_val, output_block_shapes[output_idx] + ) + allocation_request = gpu_callbacks.make_allocation_request_array( + device_id=device_id, + # All outputs of a `pallas_call`/`core_map` that are arrays (i.e. that + # are not sempahores, barriers etc.) are placed in `GMEM`. Results + # from executing the kernel (or slices thereof) may need to be copied + # from `SMEM` into the `GMEM` output buffers that are allocated here. + memory_space_id=gpu_callbacks.get_memory_space_idx( + mosaic_gpu_core.MemorySpace.GMEM + ), + initial_ref_count=num_threads, + ) + output_buffer_key = gpu_callbacks.call_allocate_buffer_for_all_threads( + device_id, allocation_request, padded_val + ) + output_buffer_keys_and_values.append( + AllocationKeyAndValue(key=output_buffer_key, value=out_val) + ) + + return output_buffer_keys_and_values + + +def _get_kernel_buffers( + device_id: int, + num_threads: int, + grid_mapping: pallas_core.GridMapping, + invars: Sequence[Any], + input_buffer_keys: Sequence[jax.Array], + output_buffer_keys: Sequence[jax.Array], + interpret_params: InterpretParams, +) -> list[jax.Array]: + """Collects buffers to be passed to the kernel from `pallas_call` input/output buffers.""" + # TODO(nrink): This code is a simplified version to the corresponding TPU + # interpreter code. Eventually, we should merge the two. + kernel_buffer_keys = [] + for i, var in enumerate(invars): + output_idx = i - grid_mapping.num_inputs + is_input = i < grid_mapping.num_inputs + is_output = (output_idx >= 0) and (output_idx < grid_mapping.num_outputs) + aval = var.aval + # TODO(nrink): Support allocation of semaphores. + if gpu_callbacks.is_gmem_memory_space(aval.memory_space): + # Use the already-allocated GMEM input or output buffer. + # + # TODO(jburnim): For kernel args in GMEM, check that block shape equals + # the shape of the corresponding `pallas_call` input, and that the + # index_map is trivial. + assert is_input ^ is_output + if is_input: + kernel_buffer_keys.append(input_buffer_keys[i]) + if is_output: + kernel_buffer_keys.append(output_buffer_keys[output_idx]) + else: + allocation_request = gpu_callbacks.make_allocation_request_array( + device_id=device_id, + memory_space_id=gpu_callbacks.get_memory_space_idx(aval.memory_space), + initial_ref_count=num_threads, + ) + init_val = interpret_params.get_uninitialized_array( + aval.shape, aval.dtype + ) + kernel_buffer_keys.append( + gpu_callbacks.call_allocate_buffer_for_all_threads( + device_id, allocation_request, init_val + ) + ) + + return kernel_buffer_keys + + +def _get_outputs( + device_id: int, output_buffers: Sequence[AllocationKeyAndValue] +) -> Sequence[Array]: + """Reads and returns values from the allocated output buffers.""" + outputs = [] + for buffer in output_buffers: + outputs.append( + gpu_callbacks.call_get( + result_shape_and_dtype=buffer.value, + device_id=device_id, + thread_id=0, + allocation_key=buffer.key, + transforms=(), # Read the entire buffer. + ) + ) + + return outputs + + +def _load_and_store_between_allocation_keys( + *, + device_id: int, + thread_id: int, + share_and_dtype: Any, + load_allocation_key: jax.Array, + store_allocation_key: jax.Array, + transform, +): + loaded_value = gpu_callbacks.call_get( + result_shape_and_dtype=share_and_dtype, + device_id=device_id, + thread_id=thread_id, + allocation_key=load_allocation_key, + transforms=transform, + ) + gpu_callbacks.call_swap( + result_shape_and_dtype=share_and_dtype, + device_id=device_id, + thread_id=thread_id, + allocation_key=store_allocation_key, + transforms=transform, + val=loaded_value, + mask=None, + ) + + +def _copy_from_gmem_buffers( + device_id: int, + thread_id: int, + avals: Sequence[Any], + gmem_buffer_keys: Sequence[jax.Array], + target_buffer_keys: Sequence[jax.Array], + transforms): + for aval, gmem_buffer_key, target_buffer_key in zip( + avals, gmem_buffer_keys, target_buffer_keys, strict=True + ): + if gpu_callbacks.is_gmem_memory_space(aval.memory_space): + continue + _load_and_store_between_allocation_keys( + device_id=device_id, + thread_id=thread_id, + share_and_dtype=aval, + load_allocation_key=gmem_buffer_key, + store_allocation_key=target_buffer_key, + transform=transforms, + ) + + +def _copy_to_gmem_buffers( + device_id: int, + thread_id: int, + avals: Sequence[Any], + source_buffer_keys: Sequence[jax.Array], + gmem_buffer_keys: Sequence[jax.Array], + transforms): + for aval, source_buffer_key, gmem_buffer_key in zip( + avals, source_buffer_keys, gmem_buffer_keys, strict=True + ): + if gpu_callbacks.is_gmem_memory_space(aval.memory_space): + continue + _load_and_store_between_allocation_keys( + device_id=device_id, + thread_id=thread_id, + share_and_dtype=aval, + load_allocation_key=source_buffer_key, + store_allocation_key=gmem_buffer_key, + transform=transforms, + ) + + +def interpret_pallas_call( + *args, + jaxpr: jax_core.Jaxpr, + debug: bool, + input_output_aliases: tuple[tuple[int, int], ...], + grid_mapping: pallas_core.GridMapping, + mesh: plgpu.Mesh | None, + compiler_params: Mapping[str, Any], + cost_estimate: pallas_core.CostEstimate, + out_avals: tuple[jax_core.AbstractValue, ...], + interpret_params: InterpretParams, + metadata: Mapping[str, str] | None, + **kwargs, +) -> Sequence[Array]: + # TODO(nrink): A more fleshed out implementation of the GPU interpreter may + # need to use some of these `del`ed arguments. + del debug, cost_estimate, metadata, out_avals, kwargs + + # TODO(nrink): Support non-trivial `BlockSpec`s (i.e. with non-trivial + # `index_map`s). + assert all(bm.has_trivial_window() for bm in grid_mapping.block_mappings) + + num_threads = _get_num_threads(grid_mapping, mesh) + device_info = jaxpr_interpret.DeviceInfo() + + interpret_params = dataclasses.replace( + interpret_params, num_cores_or_threads=num_threads + ) + + gpu_callbacks.call_initialize_shared_memory( + num_devices=device_info.num_devices, + num_threads=num_threads, + interpret_params=interpret_params, + ) + + dynamic_grid_args, scalars, inputs = split_list( + args, + [grid_mapping.num_dynamic_grid_bounds, grid_mapping.num_index_operands], + ) + if dynamic_grid_args: + raise NotImplementedError("Dynamic grid bounds not (yet) supported on GPU") + if scalars: + raise NotImplementedError("Scalar arguments not (yet) supported on GPU") + + assert grid_mapping.num_index_operands == 0 + + input_buffer_keys = _allocate_buffers_for_inputs( + device_info.device_id, + jaxpr.invars[: grid_mapping.num_inputs], + inputs, + ) + + output_buffers = _allocate_buffers_for_outputs( + device_info.device_id, + num_threads, + input_output_aliases, + grid_mapping, + input_buffer_keys, + inputs, + interpret_params, + ) + + kernel_buffer_keys = _get_kernel_buffers( + device_info.device_id, + num_threads, + grid_mapping, + jaxpr.invars, + input_buffer_keys, + [buffer.key for buffer in output_buffers], + interpret_params, + ) + + # TODO(nrink): The two assignments below have been taken from the + # corresponding TPU interpreter code. Confirm that they make sense here (i.e. + # for GPU kernels). + kernel_input_buffer_keys, kernel_output_buffer_keys, _ = split_list( + kernel_buffer_keys, [grid_mapping.num_inputs, grid_mapping.num_outputs] + ) + input_vars, output_vars = split_list( + jaxpr.invars[grid_mapping.slice_block_ops], [grid_mapping.num_inputs] + ) + + def _kernel(thread_id): + # Note that the copying from `GMEM` buffers here could introduce races when + # multiple threads copy to the same kernel input buffer. For this to happen, + # (a) there must be multiple threads and (b) the targeted kernel input + # buffer must not be in `GMEM` (since we omit copies from `GMEM` to `GMEM`). + # Currently, the ways in which a Pallas GPU kernel can be invoked do not + # allow for (a) and (b) to be true at the same time: (a) requires that the + # kernel is *not* invoked through a `pallas_call` but (b) can only be caused + # if `BlockSpec`s are used when invoking the kernels, which requires that + # the kernel be invoked through a `pallas_call`. + # + # TODO(nrink): Support copying of slices/blocks only, based on the + # `BlockSpec`s. (Currently only trivial `BlockSpec`s are supported.) + _copy_from_gmem_buffers( + device_id=device_info.device_id, + thread_id=thread_id, + avals=[var.aval for var in input_vars], + gmem_buffer_keys=input_buffer_keys, + target_buffer_keys=kernel_input_buffer_keys, + transforms=(), + ) + + jaxpr_interpreter = jaxpr_interpret.JaxprInterpreter( + thread_id=thread_id, + mesh=mesh, + device_info=device_info, + compiler_params=compiler_params, + interpret_params=interpret_params, + ) + jaxpr_interpreter.interpret(jaxpr, *kernel_buffer_keys) + + # Note that a comment about potential races that is analogous to the comment + # before the call to `_copy_from_gmem_buffers` above applies here too. + # + # TODO(nrink): Support copying of slices/blocks only, based on the + # `BlockSpec`s. (Currently only trivial `BlockSpec`s are supported.) + _copy_to_gmem_buffers( + device_id=device_info.device_id, + thread_id=thread_id, + avals=[var.aval for var in output_vars], + source_buffer_keys=kernel_output_buffer_keys, + gmem_buffer_keys=[buffer.key for buffer in output_buffers], + transforms=(), + ) + + # TODO(nrink): Should we only create happens-before here from thread 0 to + # the other threads? Currently we update the vector clocks for all threads by + # looking at the vector clock of all (other) threads. It should suffice, but + # this needs to be confirmed, to update the vector clocks for all threads by + # looking only at the vector clock of thread 0 (and at the vector clock for + # the thread itself). + gpu_callbacks.call_update_clocks_for_device_barrier(device_info.device_id) + + thread_map.thread_map(_kernel, num_threads) + + # TODO(nrink): Should we only create happens-before here from the other + # threads to thread 0? Analogous to the comment above, it should suffice, but + # this needs to be confirmed, to update only the vector clock of thread 0 (and + # not the vector clocks for all other threads). + gpu_callbacks.call_update_clocks_for_device_barrier(device_info.device_id) + + outputs = _get_outputs(device_info.device_id, output_buffers) + + gpu_callbacks.call_clean_up_shared_memory() + + return outputs diff --git a/jax/_src/pallas/mosaic_gpu/interpret/jaxpr_interpret.py b/jax/_src/pallas/mosaic_gpu/interpret/jaxpr_interpret.py new file mode 100644 index 000000000000..7630ba079db6 --- /dev/null +++ b/jax/_src/pallas/mosaic_gpu/interpret/jaxpr_interpret.py @@ -0,0 +1,282 @@ +# Copyright 2026 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable, Mapping, Sequence +import dataclasses +import functools +import math +from typing import Any + +import jax +from jax import lax +from jax._src import core as jax_core +from jax._src import source_info_util +from jax._src.pallas import primitives +from jax._src.pallas.mosaic.interpret import utils as interpret_utils +from jax._src.pallas.mosaic_gpu import core as mosaic_gpu_core +from jax._src.pallas.mosaic_gpu.interpret import gpu_callbacks +from jax._src.state import primitives as state_primitives +from jax._src.util import safe_zip +from jax.experimental.pallas import mosaic_gpu as plgpu +import jax.numpy as jnp + + +@dataclasses.dataclass(init=False, frozen=True) +class DeviceInfo: + """Information about the device that is being interpreted.""" + + # The indices along each axis of the device being interpreted. + axis_indices: Mapping[jax_core.AxisName, int] + # The size of each axis in the mesh of all (SMPD) devices. + axis_sizes: Mapping[jax_core.AxisName, int] + + def __init__(self): + # Since this class is frozen, we must use `object.__setattr__` to set the + # attributes. + object.__setattr__(self, "axis_sizes", jax_core.get_axis_env().axis_sizes) + object.__setattr__( + self, + "axis_indices", + {k: lax.axis_index(k) for k in self.axis_sizes.keys()}, + ) + + @functools.cached_property + def device_id(self) -> int: + """Computes the logical ID of the device being interpreted.""" + return interpret_utils.device_coords_to_logical_id( + tuple(self.axis_indices.values()), self.axis_sizes, self.axis_indices + ) + + @functools.cached_property + def num_devices(self) -> int: + """Computes the number of (SPMD) devices.""" + return math.prod(self.axis_sizes.values()) + + +def _raise_if_unsupported_memory_space( + space: mosaic_gpu_core.MemorySpace | None, +): + # TODO(nrink): Support more memory spaces. + if space is not None and space not in [ + mosaic_gpu_core.MemorySpace.GMEM, + mosaic_gpu_core.MemorySpace.SMEM, + ]: + raise NotImplementedError(f"Unsupported memory space: {space}") + + +_SENTINEL = jnp.inf + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class JaxprInterpreter: + """Interprets a jaxpr by replacing memory operations with (GPU) callbacks.""" + + thread_id: int + mesh: plgpu.Mesh | None + device_info: DeviceInfo + compiler_params: Mapping[str, Any] + interpret_params: interpret_utils.InterpretParams + + @functools.cached_property + def num_threads(self) -> int: + if self.mesh is None or self.mesh.num_threads is None: + return 1 + else: + return int(self.mesh.num_threads) + + def _interpret_axis_index_p(self, eqn): + assert eqn.primitive is lax.axis_index_p + axis_name = eqn.params["axis_name"] + if (self.mesh is not None) and (axis_name == self.mesh.thread_name): + return jnp.int32(self.thread_id) + elif axis_name in self.device_info.axis_indices: + return self.device_info.axis_indices[axis_name] + else: + raise ValueError( + f"Unable to determine axis index for axis name {axis_name}" + ) + + def _interpret_get_p(self, eqn, get_invals: Callable[[], Sequence[Any]]): + assert eqn.primitive is state_primitives.get_p + assert isinstance(eqn.outvars[0].aval, jax_core.ShapedArray) + invals = get_invals() + return gpu_callbacks.call_get( + result_shape_and_dtype=eqn.outvars[0].aval, + device_id=self.device_info.device_id, + thread_id=self.thread_id, + allocation_key=invals[0], + transforms=jax.tree.unflatten(eqn.params["tree"], invals[1:]), + source_info=eqn.source_info, + ) + + def _interpret_swap_p(self, eqn, get_invals: Callable[[], Sequence[Any]]): + assert eqn.primitive is state_primitives.swap_p + assert isinstance(eqn.outvars[0].aval, jax_core.ShapedArray) + invals = get_invals() + return gpu_callbacks.call_swap( + result_shape_and_dtype=eqn.outvars[0].aval, + device_id=self.device_info.device_id, + thread_id=self.thread_id, + allocation_key=invals[0], + transforms=jax.tree.unflatten(eqn.params["tree"], invals[2:]), + val=invals[1], + mask=None, + ) + + def _interpret_run_scoped_p( + self, eqn, get_invals: Callable[[], Sequence[Any]] + ): + + def _allocate_for_aval(aval, same_allocations_for_all_threads: bool): + _raise_if_unsupported_memory_space(aval.memory_space) + memory_space_idx = gpu_callbacks.get_memory_space_idx(aval.memory_space) + allocation_request = gpu_callbacks.make_allocation_request_array( + device_id=self.device_info.device_id, + memory_space_id=memory_space_idx, + thread_id=0 if same_allocations_for_all_threads else self.thread_id, + initial_ref_count=self.num_threads + if same_allocations_for_all_threads + else 1, + ) + return gpu_callbacks.call_allocate_buffer( + self.device_info.device_id, + self.thread_id, + allocation_request, + self.interpret_params.get_uninitialized_array(aval.shape, aval.dtype), + ) + + def _deallocate_for_aval(allocation, aval): + # TODO(nrink): Check that sempahores have value zero at the end of their + # lifetimes. (If semaphores are never explicitly deallocated, this check + # could take place at the end of kernel interpretation.) + _raise_if_unsupported_memory_space(aval.memory_space) + return gpu_callbacks.call_deallocate_buffer( + allocation, + ) + + assert eqn.primitive is primitives.run_scoped_p + collective_axes = eqn.params["collective_axes"] + # Note that on GPU, `SMEM` buffers and barriers can only be allocated + # collectively (i.e. corresponding to `same_allocations=True`). In the + # interpreter we are a little more lenient and allow non-collective + # allocations for `SMEM` buffers. + same_allocations = False + if collective_axes: + if ( + self.mesh is None + or len(collective_axes) != 1 + or collective_axes[0] != self.mesh.thread_name + ): + raise NotImplementedError( + "When interpreting `run_scoped` in a GPU kernel, non-empty" + " `collective_axes` is currently only supported when it contains a" + " single axis that agrees with the thread axis (i.e. `thread_name`)" + " of the mesh." + ) + same_allocations = True + + # Allocate a buffer or semaphore (to do, see below) for each element of + # `eqn.params['jaxpr'].invars`. It is assumed that each thread runs the same + # sequence of `run_scoped`s. + vars = eqn.params["jaxpr"].invars + allocs = [] + for v in vars: + # TODO(nrink): Support semaphores. (Currently the call to + # `_allocate_for_aval` will fail when trying to allocate a semaphore.) + allocs.append(_allocate_for_aval(v.aval, same_allocations)) + + out = self.interpret(eqn.params["jaxpr"], *get_invals(), *allocs) + + for a, v in safe_zip(allocs, vars): + _deallocate_for_aval(a, v.aval) + + return out + + def _interpret_cond_p(self, eqn, get_invals: Callable[[], Sequence[Any]]): + invals = get_invals() + return lax.switch( + invals[0], + [ + functools.partial(self.interpret, branch_jaxpr.jaxpr) + for branch_jaxpr in eqn.params["branches"] + ], + *invals[1:], + ) + + def _interpret_arithmetic_primitive( + self, eqn, get_invals: Callable[[], Sequence[Any]] + ): + if self.interpret_params.skip_floating_point_ops and all( + interpret_utils.is_float(ovar.aval.dtype) for ovar in eqn.outvars + ): + # Skip `eqn.primitive.bind` since `eqn.primitive` only produces + # floating-point values. It is safe to populate `out` with avals + # since mapping `env.write_many` over `out` (in `self.interpret`) below + # only relies on the shape and dtype (for writing `Placeholder`s). + out = [ovar.aval for ovar in eqn.outvars] + if not eqn.primitive.multiple_results: + out = out[0] + return out + else: + subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params) + return eqn.primitive.bind(*subfuns, *get_invals(), **bind_params) + + def interpret(self, jaxpr, *args): + sentinel_for_floating_point_values = ( + _SENTINEL if self.interpret_params.skip_floating_point_ops else None + ) + env = interpret_utils.JaxprEnv( + vars=jaxpr.constvars + jaxpr.invars, + values=args, + sentinel_for_floating_point_values=sentinel_for_floating_point_values, + ) + + for eqn in jaxpr.eqns: + with source_info_util.user_context( + eqn.source_info.traceback, + name_stack=eqn.source_info.name_stack, + ): + # We defer reading the values for `eqn.invars` into each of the branches + # of the match statement below. This is because the case for arithmetic + # primitives may not need to do any reads + # (if `self.interpret_params.skip_floating_point_ops` is True). If this + # is the case, we want to avoid materializing the read array into the + # jaxpr when this function is traced. + deferred_invals = functools.partial(env.read_many, eqn.invars) + match eqn.primitive: + case lax.axis_index_p: + out = self._interpret_axis_index_p(eqn) + case primitives.program_id_p: + # Currently we only support grids and clusters with a single device. + # Hence, zero is the only valid program id. + out = jnp.int32(0) + case state_primitives.get_p: + out = self._interpret_get_p(eqn, deferred_invals) + case primitives.load_p: + raise NotImplementedError("load_p is not supported on GPU yet") + case state_primitives.swap_p: + out = self._interpret_swap_p(eqn, deferred_invals) + case primitives.swap_p: + raise NotImplementedError("swap_p is not supported on GPU yet") + case primitives.run_scoped_p: + out = self._interpret_run_scoped_p(eqn, deferred_invals) + case lax.cond_p: + out = self._interpret_cond_p(eqn, deferred_invals) + case _: + out = self._interpret_arithmetic_primitive(eqn, deferred_invals) + + out = out if eqn.primitive.multiple_results else [out] + env.write_many(eqn.outvars, out) + + return env.read_many(jaxpr.outvars) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 6b06e6b7dfc2..3028258ee46e 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -17,33 +17,45 @@ from __future__ import annotations import collections -from collections.abc import Callable, Hashable, MutableMapping, MutableSequence, Sequence +from collections.abc import Callable, Hashable, Iterator, MutableMapping, MutableSequence, Sequence import contextlib import dataclasses import functools +import inspect +import itertools import math -from typing import Any, Protocol, cast +import operator +from typing import Any, Protocol, Self, TypeVar, assert_never, cast import jax from jax import api_util from jax import lax +from jax._src import checkify +from jax._src import config from jax._src import core as jax_core +from jax._src import debugging +from jax._src import dtypes from jax._src import linear_util as lu +from jax._src import literals +from jax._src import mesh as mesh_lib from jax._src import pjit from jax._src import source_info_util +from jax._src import tree_util from jax._src import util from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith as arith_dialect +from jax._src.lib.mlir.dialects import cf as cf_dialect from jax._src.lib.mlir.dialects import gpu as gpu_dialect +from jax._src.lib.mlir.dialects import llvm as llvm_dialect from jax._src.lib.mlir.dialects import math as math_dialect from jax._src.lib.mlir.dialects import memref as memref_dialect from jax._src.lib.mlir.dialects import nvvm as nvvm_dialect from jax._src.lib.mlir.dialects import scf as scf_dialect from jax._src.lib.mlir.dialects import vector as vector_dialect from jax._src.pallas import core as pallas_core -from jax._src.pallas import pallas_call +from jax._src.pallas import helpers as pallas_helpers from jax._src.pallas import primitives from jax._src.pallas import utils as pallas_utils from jax._src.pallas.mosaic_gpu import core as gpu_core @@ -52,10 +64,12 @@ from jax._src.state import primitives as sp from jax._src.state import types as state_types from jax._src.state.types import RefReshaper +from jax._src.state.types import RefTransposer from jax._src.util import foreach import jax.experimental.mosaic.gpu as mgpu from jax.experimental.mosaic.gpu import core as mgpu_core from jax.experimental.mosaic.gpu import profiler as mgpu_profiler +from jax.experimental.mosaic.gpu import tcgen05 from jax.experimental.mosaic.gpu import utils as mgpu_utils import jax.numpy as jnp import numpy as np @@ -63,54 +77,68 @@ # TODO(slebedev): Enable type checking. # mypy: ignore-errors -# pytype: skip-file map, unsafe_map = util.safe_map, map zip, unsafe_zip = util.safe_zip, zip partial = functools.partial SMEM = gpu_core.SMEM -# We align all our SMEM allocations to 1024 bytes. TMA and WGMMA are very -# sensitive to alignment and while this is quite conservative, it gets the job -# done. We should make this more refined in the future. -_SMEM_ALIGNMENT = 1024 WARPGROUP_SIZE = 128 +RefOrTmemType = TypeVar("RefOrTmemType", ir.Value, tcgen05.TMEMRef) +CollectiveAxesType = Sequence[Hashable] -def _align_to(x: int, alignment: int): - if (rem := x % alignment): - return x + alignment - rem - return x - -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, kw_only=True) class ResourceEstimatorContext: - thread_semantics: mgpu.ThreadSemantics + axis_names: _AxisNames + lowering_semantics: mgpu.LoweringSemantics @property def arrival_multiplier(self) -> int: return ( WARPGROUP_SIZE - if self.thread_semantics == mgpu.ThreadSemantics.Lane + if self.lowering_semantics == mgpu.LoweringSemantics.Lane else 1 ) +AnyBarrier = mgpu.Barrier | mgpu.ClusterBarrier + + @dataclasses.dataclass(kw_only=True, frozen=True) class Resources: smem_scratch_bytes: int = 0 - barrier_counts: collections.Counter[mgpu.Barrier] = dataclasses.field( + tmem_scratch_cols: int = 0 + tmem_collective_scratch_cols: int = 0 + barrier_counts: collections.Counter[AnyBarrier] = dataclasses.field( default_factory=collections.Counter ) + # Maps from collective axes to number of semaphores. + scoped_gmem_semaphores: dict[CollectiveAxesType, int] = dataclasses.field( + default_factory=dict + ) def __post_init__(self): object.__setattr__( self, "smem_scratch_bytes", - _align_to(self.smem_scratch_bytes, _SMEM_ALIGNMENT), + gpu_core.align_to(self.smem_scratch_bytes, gpu_core.SMEM_ALIGNMENT), + ) + + # TMEM must be allocated in 128x8 chunks. + object.__setattr__( + self, + "tmem_scratch_cols", + gpu_core.align_to(self.tmem_scratch_cols, 8), + ) + object.__setattr__( + self, + "tmem_collective_scratch_cols", + gpu_core.align_to(self.tmem_collective_scratch_cols, 8), ) @property - def barriers(self) -> Sequence[mgpu.Barrier]: + def barriers(self) -> Sequence[AnyBarrier]: return list(self.barrier_counts.elements()) def __add__(self, other: Resources) -> Resources: @@ -118,17 +146,42 @@ def __add__(self, other: Resources) -> Resources: # # At the moment, if we have run_scoped(b1) followed by run_scoped(b2) # we will allocate two barriers, even though one would be enough. + sems = self.scoped_gmem_semaphores + other_sems = other.scoped_gmem_semaphores + scoped_gmem_semaphores = {key: sems.get(key, 0) + other_sems.get(key, 0) + for key in sems.keys() | other_sems.keys()} return Resources( smem_scratch_bytes=self.smem_scratch_bytes + other.smem_scratch_bytes, + tmem_scratch_cols=self.tmem_scratch_cols + other.tmem_scratch_cols, + tmem_collective_scratch_cols=self.tmem_collective_scratch_cols + + other.tmem_collective_scratch_cols, barrier_counts=self.barrier_counts + other.barrier_counts, + scoped_gmem_semaphores=scoped_gmem_semaphores, ) - def __or__(self, other: Resources) -> Resources: + def or_(self, other: Resources, axis_names: _AxisNames) -> Resources: + sems = self.scoped_gmem_semaphores + other_sems = other.scoped_gmem_semaphores + scoped_gmem_semaphores = {} + for sem_scope in sems.keys() | other_sems.keys(): + if _is_block_local_scope(sem_scope, axis_names): + value = max(sems.get(sem_scope, 0), other_sems.get(sem_scope, 0)) + elif _is_global_scope(sem_scope, axis_names): + value = sems.get(sem_scope, 0) + other_sems.get(sem_scope, 0) + else: + raise RuntimeError(f"Unrecognized semaphore scope: {sem_scope}") + scoped_gmem_semaphores[sem_scope] = value return Resources( smem_scratch_bytes=max( self.smem_scratch_bytes, other.smem_scratch_bytes ), + tmem_scratch_cols=max(self.tmem_scratch_cols, other.tmem_scratch_cols), + tmem_collective_scratch_cols=max( + self.tmem_collective_scratch_cols, + other.tmem_collective_scratch_cols, + ), barrier_counts=self.barrier_counts | other.barrier_counts, + scoped_gmem_semaphores=scoped_gmem_semaphores, ) @@ -158,11 +211,21 @@ def _estimate_resources( rs = Resources(smem_scratch_bytes=0) for eqn in jaxpr.eqns: # TODO(slebedev): Add support for other primitives, notably control flow. - rule = _resource_estimators.get(eqn.primitive) - if rule is None: - # Assume that unsupported primitives are neutral wrt resource usage. + if rule := _resource_estimators.get(eqn.primitive): + rs = rs.or_( + rule(ctx, *(invar.aval for invar in eqn.invars), **eqn.params), + ctx.axis_names, + ) continue - rs |= rule(ctx, *(invar.aval for invar in eqn.invars), **eqn.params) + # Assume that unsupported primitives are neutral wrt resource usage, + # unless they have a jaxpr in their params. + if any( + isinstance(v, (jax_core.Jaxpr, jax_core.ClosedJaxpr)) + for v in eqn.params.values() + ): + raise NotImplementedError( + f"Resource estimation does not support {eqn.primitive}" + ) return rs @@ -170,10 +233,10 @@ def _estimate_resources( @_register_resource_estimator(lax.cond_p) def _cond_resource_estimator( ctx: ResourceEstimatorContext, *args, branches -) -> int: +) -> Resources: del args # Unused. return functools.reduce( - lambda a, b: a | b, + lambda a, b: a.or_(b, ctx.axis_names), (_estimate_resources(ctx, branch.jaxpr) for branch in branches), ) @@ -181,9 +244,9 @@ def _cond_resource_estimator( @_register_resource_estimator(lax.scan_p) def _scan_resource_estimator( ctx: ResourceEstimatorContext, *args, jaxpr: jax_core.ClosedJaxpr, **params -) -> int: +) -> Resources: del args, params # Unused. - return _estimate_resources(ctx, jaxpr) + return _estimate_resources(ctx, jaxpr.jaxpr) @_register_resource_estimator(lax.while_p) @@ -193,64 +256,192 @@ def _while_resource_estimator( cond_jaxpr: jax_core.ClosedJaxpr, body_jaxpr: jax_core.ClosedJaxpr, **params, -) -> int: +) -> Resources: del args, params # Unused. - return _estimate_resources(ctx, cond_jaxpr) | _estimate_resources( - ctx, body_jaxpr + return _estimate_resources(ctx, cond_jaxpr.jaxpr).or_( + _estimate_resources(ctx, body_jaxpr.jaxpr), ctx.axis_names ) +@_register_resource_estimator(pjit.jit_p) +def _pjit_resource_estimator( + ctx: ResourceEstimatorContext, + *args, + jaxpr: jax_core.ClosedJaxpr, + **params, +) -> Resources: + del args, params # Unused. + return _estimate_resources(ctx, jaxpr.jaxpr) + + +@_register_resource_estimator(pallas_core.core_map_p) +def _core_map_resource_estimator( + ctx: ResourceEstimatorContext, *args, jaxpr: jax_core.Jaxpr, **params +) -> Resources: + del args, params # Unused. + return _estimate_resources(ctx, jaxpr) + + +@_register_resource_estimator(discharge.run_state_p) +def _run_state_resource_estimator( + ctx: ResourceEstimatorContext, *args, jaxpr: jax_core.Jaxpr, **params +) -> Resources: + del args, params # Unused. + return _estimate_resources(ctx, jaxpr) + + @_register_resource_estimator(primitives.run_scoped_p) def _run_scoped_resource_estimator( - ctx: ResourceEstimatorContext, *consts, jaxpr: jax_core.Jaxpr -) -> int: + ctx: ResourceEstimatorContext, + *consts, + jaxpr: jax_core.Jaxpr, + collective_axes, +) -> Resources: + # NOTE: This rule assumes that the allocation happens collectively, although + # it can't be checked here due to limited context. We check this in the actual + # lowering rule. del consts # Unused. rs = Resources() for v in jaxpr.invars: - aval = v.aval + aval = cast(ShapedAbstractValue, v.aval) if isinstance(aval.dtype, gpu_core.BarrierType): + multiplier = 1 if aval.dtype.orders_tensor_core else ctx.arrival_multiplier rs += Resources( barrier_counts=collections.Counter([ mgpu.Barrier( - aval.dtype.num_arrivals * ctx.arrival_multiplier, *aval.shape + aval.dtype.num_arrivals * multiplier, *aval.shape ) ]) ) - else: + continue + if isinstance(aval.dtype, gpu_core.ClusterBarrierType): + collective_dims = jax.tree.map( + lambda axis: _resolve_cluster_axis(ctx.axis_names, axis), + aval.dtype.collective_axes, + ) + rs += Resources( + barrier_counts=collections.Counter( + [mgpu.ClusterBarrier(collective_dims, aval.dtype.num_arrivals, *aval.shape)] + ) + ) + continue + assert isinstance(aval, state_types.AbstractRef) + if aval.memory_space == gpu_core.TMEM: + if len(aval.shape) != 2: + raise ValueError(f"TMEM allocations must be 2D. Got {aval.shape}") + # Estimate columns used. + if isinstance(aval, gpu_core.AbstractRefUnion): + assert aval.shape[0] == 128 + cols_used = aval.shape[1] + else: + cols_used = aval.layout.cols_in_shape( + aval.shape, dtypes.itemsize_bits(aval.dtype) + ) + if aval.collective: + rs += Resources(tmem_collective_scratch_cols=cols_used) + else: + rs += Resources(tmem_scratch_cols=cols_used) + elif aval.memory_space == gpu_core.SMEM: rs += Resources( - smem_scratch_bytes=math.prod(aval.shape) * aval.dtype.itemsize + smem_scratch_bytes=aval.size * dtypes.itemsize_bits(aval.dtype) // 8 ) + elif aval.memory_space == gpu_core.REGS: + # Don't need to allocate anything. + pass + elif aval.memory_space == gpu_core.GMEM and jnp.issubdtype(aval.dtype, pallas_core.semaphore): + if _is_block_local_scope(collective_axes, ctx.axis_names): + rs += Resources(scoped_gmem_semaphores={collective_axes: aval.size}) + else: + raise ValueError( + "Only thread-collective allocations are supported in run_scoped. To" + " allocate global semaphores, use pl.get_global." + ) + else: + raise NotImplementedError( + f"Unsupported memory space: {aval.memory_space}") return rs + _estimate_resources(ctx, jaxpr) +REDUCE_SCRATCH_ELEMS = 128 * 4 # vector of 4 elements per lane in each WG @_register_resource_estimator(lax.reduce_sum_p) -def _reduce_sum_resource_estimator( - ctx: ResourceEstimatorContext, x_aval: jax_core.ShapedArray, *, axes -) -> int: +@_register_resource_estimator(lax.reduce_max_p) +@_register_resource_estimator(lax.reduce_min_p) +def _reduce_resource_estimator( + ctx: ResourceEstimatorContext, x_aval: jax_core.ShapedArray, *, axes, + **kwargs +) -> Resources: del ctx, axes # Unused. - # We don't need shmem for some reductons, but it depends on the layout, so we - # conservatively request some scratch space. - return Resources(smem_scratch_bytes=4 * x_aval.dtype.itemsize) + # We don't need SMEM for some reductions, but it depends on the layout, so we + # conservatively request the maximum scratch space we might need. + return Resources(smem_scratch_bytes=REDUCE_SCRATCH_ELEMS * x_aval.dtype.itemsize) + + +@dataclasses.dataclass(frozen=True) +class _AxisNames: + grid: Sequence[Hashable] + cluster: Sequence[Hashable] = () + wg: Hashable | None = None + + def __iter__(self) -> Iterator[Hashable]: + return itertools.chain( + self.grid, self.cluster, [self.wg] if self.wg is not None else [] + ) + + def reverse(self) -> "_AxisNames": + return _AxisNames(self.grid[::-1], self.cluster[::-1], self.wg) + + +AnyBarrierRef = ( + mgpu.BarrierRef | mgpu.DialectBarrierRef | mgpu.CollectiveBarrierRef +) @dataclasses.dataclass class ModuleContext: name: str - grid_names: Sequence[Hashable] | None + axis_names: _AxisNames program_ids: Sequence[ir.Value] | None approx_math: bool - single_wg_lane_predicate: ir.Value + single_wg_lane_predicate: ir.Value | None + single_warp_lane_predicate: ir.Value | None smem_requested_bytes: int smem_used_bytes: int - runtime_barriers: MutableMapping[ - mgpu.Barrier, MutableSequence[mgpu.BarrierRef] - ] + tmem_requested_cols: int + tmem_used_cols: int + tmem_base: ir.Value | None + scoped_gmem_used_semaphores: dict[CollectiveAxesType, int] + scoped_gmem_semaphore_base_ptr: dict[CollectiveAxesType, ir.Value] + runtime_barriers: MutableMapping[AnyBarrier, MutableSequence[AnyBarrierRef]] name_stack: source_info_util.NameStack traceback_caches: mlir.TracebackCaches squashed_dims: tuple[int, ...] - thread_semantics: mgpu.ThreadSemantics + lowering_semantics: mgpu.LoweringSemantics + primitive_semantics: gpu_core.PrimitiveSemantics + mesh_info: pallas_utils.MeshInfo | None + # See the documentation of unsafe_no_auto_barriers in CompilerParams. + auto_barriers: bool + warp_axis_name: str | None = None + + @property + def single_lane_predicate(self) -> ir.Value: + """Returns a predicate that is True for a single lane within the current + thread semantics. + """ + assert self.lowering_semantics == mgpu.LoweringSemantics.Lane + match self.primitive_semantics: + case gpu_core.PrimitiveSemantics.Warpgroup: + return self.single_wg_lane_predicate + case gpu_core.PrimitiveSemantics.Warp: + return self.single_warp_lane_predicate + case _: + raise ValueError(f"Unknown semantics: {self.primitive_semantics}") - def reserve_barrier(self, barrier: mgpu.Barrier) -> mgpu.BarrierRef: + @contextlib.contextmanager + def reserve_barrier( + self, barrier: mgpu.Barrier + ) -> Iterator[ + mgpu.BarrierRef | mgpu.DialectBarrierRef | mgpu.CollectiveBarrierRef + ]: """Reserves a barrier. Raises: @@ -259,81 +450,155 @@ def reserve_barrier(self, barrier: mgpu.Barrier) -> mgpu.BarrierRef: available = self.runtime_barriers.get(barrier, []) if not available: raise RuntimeError(f"Barrier {barrier} is already reserved") - return available.pop() + barrier = available.pop() + yield barrier + available.append(barrier) + + @contextlib.contextmanager + def reserve_semaphores(self, + shape: tuple[int, ...], + collective_axes: CollectiveAxesType + ) -> Iterator[ir.Value]: + allocated_sems = math.prod(shape) + ref = mgpu.memref_slice( + self.scoped_gmem_semaphore_base_ptr[collective_axes], + mgpu.ds(self.scoped_gmem_used_semaphores[collective_axes], + allocated_sems), + ) + ref = mgpu.memref_reshape(ref, shape) + + self.scoped_gmem_used_semaphores[collective_axes] += allocated_sems + yield ref + # TODO: In debug mode verify the values of all semaphores are again 0 + self.scoped_gmem_used_semaphores[collective_axes] -= allocated_sems + + @contextlib.contextmanager + def alloc_tmem( + self, + struct: jax.ShapeDtypeStruct, + *, + layout: tcgen05.TMEMLayout, + ) -> Iterator[tcgen05.TMEMRef | ir.Value]: + if self.lowering_semantics == mgpu.LoweringSemantics.Lane: + off = arith_dialect.addi( + self.tmem_base, _i32_constant(self.tmem_used_cols) + ) + tmem_ref = tcgen05.TMEMRef( + address=off, + shape=struct.shape, + dtype=mgpu_utils.dtype_to_ir_type(struct.dtype), + layout=layout, + ) + else: + type = ir.MemRefType.get( + struct.shape, + mgpu_utils.dtype_to_ir_type(struct.dtype), + memory_space=mgpu_utils.tmem(), + ) + tmem_ref = mgpu.dialect.slice_tmem( + type, self.tmem_base, self.tmem_used_cols + ) + layout_attr = mgpu.to_layout_attr(layout) + tmem_ref = mgpu.dialect.tmem_layout_cast(tmem_ref, layout_attr) + cols_used = layout.cols_in_shape( + struct.shape, dtypes.itemsize_bits(struct.dtype) + ) + cols_used = gpu_core.align_to(cols_used, gpu_core.TMEM_COL_ALIGNMENT) + self.tmem_used_cols += cols_used + yield tmem_ref + self.tmem_used_cols -= cols_used - # TODO(cperivol): Only return the shapes and figure out the sizes when freeing. @contextlib.contextmanager - def scratch_view( - self, structs: Sequence[jax.ShapeDtypeStruct] - ) -> Sequence[ir.Value]: - """Creates a view into the runtime scratch buffer for each struct. + def scratch_view(self, struct: jax.ShapeDtypeStruct) -> Iterator[ir.Value]: + """Creates a view into the runtime scratch buffer for the given struct. This is a low-level API. Use it only if you know what you are doing. - The function allocates bytes at the top of a stack, which need to be - deallocated in a FIFO fashion with :meth:`ModuleContext.stack_free_smem`. - After deallocation, the view is invalid and cannot be used. + After the context manager exits, the view is invalid and cannot be used. Args: - structus: The shapes and dtypes of the views to create. + struct: The shape and dtype of the view to create. Returns: - A tuple, where the first element is the number of bytes allocated, - and the second element is a sequence of memref views into the - runtime scratch buffer. + A memref view into the runtime scratch buffer. """ smem_base = None - smem = ir.Attribute.parse("#gpu.address_space") i8 = ir.IntegerType.get_signless(8) i32 = ir.IntegerType.get_signless(32) - if self.thread_semantics == mgpu.ThreadSemantics.Lane: + if self.lowering_semantics == mgpu.LoweringSemantics.Lane: smem_base = gpu_dialect.dynamic_shared_memory( - ir.MemRefType.get((mgpu_utils.DYNAMIC,), i8, memory_space=smem) + ir.MemRefType.get( + (mgpu_utils.DYNAMIC,), i8, memory_space=mgpu_utils.smem() + ) ) - views = [] off = initial_used_bytes = self.smem_used_bytes - assert off % _SMEM_ALIGNMENT == 0 - for s in structs: - scratch_ty = ir.MemRefType.get( - s.shape, - mgpu_utils.dtype_to_ir_type(s.dtype), - memory_space=smem, - ) - # The below code emission relies on the assumption that the first scratch - # operand provided by Mosaic GPU always begins at the beginning of - # dynamic SMEM. Mosaic GPU is expected to uphold that invariant. - if self.thread_semantics == mgpu.ThreadSemantics.Lane: - view = memref_dialect.view( - scratch_ty, smem_base, _as_index(off), [] - ) - else: - view = mgpu.dialect.slice_smem(scratch_ty, mgpu_utils.c(off, i32)) - views.append(view) + assert off % gpu_core.SMEM_ALIGNMENT == 0 + scratch_ty = ir.MemRefType.get( + struct.shape, + mgpu_utils.dtype_to_ir_type(struct.dtype), + memory_space=mgpu_utils.smem(), + ) + # The below code emission relies on the assumption that the first scratch + # operand provided by Mosaic GPU always begins at the beginning of + # dynamic SMEM. Mosaic GPU is expected to uphold that invariant. + if self.lowering_semantics == mgpu.LoweringSemantics.Lane: + view = memref_dialect.view(scratch_ty, smem_base, _as_index(off), []) + else: + view = mgpu.dialect.slice_smem(scratch_ty, mgpu_utils.c(off, i32)) - off += _align_to( - math.prod(s.shape) * jnp.dtype(s.dtype).itemsize, _SMEM_ALIGNMENT - ) + off += gpu_core.align_to( + math.prod(struct.shape) + * dtypes.itemsize_bits(jnp.dtype(struct.dtype)) + // 8, + gpu_core.SMEM_ALIGNMENT, + ) assert off <= self.smem_requested_bytes, "Ran out of scoped SMEM" - assert off % _SMEM_ALIGNMENT == 0 + assert off % gpu_core.SMEM_ALIGNMENT == 0 self.smem_used_bytes = off - yield views + yield view self.smem_used_bytes = initial_used_bytes +# This is morally ``ShapedArray | state.AbstractRef``, but pytype does not +# allow calling methods on a union type, making ``update`` non-callable, so +# we use a protocol instead of a union. +class ShapedAbstractValue(Protocol): + shape: tuple[jax_core.DimSize, ...] + dtype: jnp.dtype + weak_type: bool + + @property + def ndim(self) -> int: + ... + + @property + def size(self) -> int: + ... + + def update(self, **kwargs: Any) -> Self: + raise NotImplementedError + + @dataclasses.dataclass(frozen=True) class LoweringRuleContext: module_ctx: ModuleContext launch_ctx: mgpu.LaunchContext prim: jax_core.Primitive - avals_in: Sequence[jax_core.ShapedArray] - avals_out: Sequence[jax_core.ShapedArray] + avals_in: Sequence[ShapedAbstractValue] + avals_out: Sequence[ShapedAbstractValue] + out_layout_hint: mgpu.FragmentedLayout | None - replace = dataclasses.replace + def replace(self, **changes: Any) -> LoweringRuleContext: + # The wrapper is necessary to convince pytype that this is a method. + return dataclasses.replace(self, **changes) @property def estimator_ctx(self) -> ResourceEstimatorContext: - return ResourceEstimatorContext(thread_semantics=self.module_ctx.thread_semantics) + return ResourceEstimatorContext( + axis_names=self.module_ctx.axis_names, + lowering_semantics=self.module_ctx.lowering_semantics, + ) @dataclasses.dataclass(frozen=True) @@ -341,14 +606,9 @@ class LoweringResult: module: ir.Module grid: tuple[int, ...] block: tuple[int, ...] - out_structs: tuple[jax.ShapeDtypeStruct, ...] - profiler_context: ProfilerContext | None - - -@dataclasses.dataclass(frozen=True) -class ProfilerContext: - dump_path: str - spec: mgpu_profiler.ProfilerSpec + new_out_shapes: tuple[jax.ShapeDtypeStruct, ...] # Does not include gmem scratch! + profiler_spec: mgpu_profiler.ProfilerSpec | None + gmem_scratch_shapes: tuple[jax.ShapeDtypeStruct, ...] class LoweringError(Exception): # pylint: disable=g-bad-exception-name @@ -366,11 +626,13 @@ def _eval_index_map( ) result = [] for i, b in zip(block_indices, block_mapping.block_shape): - if b is pallas_core.mapped: - result.append(i) - else: - # TODO(slebedev): Use a type-agnostic multiplication wrapper. - result.append(arith_dialect.muli(_as_index(i), _as_index(b))) + match b: + case pallas_core.Squeezed() | pallas_core.Element(): + result.append(i) + case pallas_core.Blocked(): + result.append(arith_dialect.muli(_as_index(i), _as_index(b))) + case _: + raise ValueError(f"Unsupported block dim type: {b}") return tuple(result) @@ -382,12 +644,12 @@ def err_details(bm: pallas_core.BlockMapping) -> str: return ( f"Block spec for {bm.origin} in pallas_call {debug_info.func_src_info}" f" has block shape {bm.block_shape}, array shape" - f" {bm.array_shape_dtype.shape}," + f" {bm.array_aval.shape}," # TODO(necula): add index_map source location info f" and index_map {bm.index_map_jaxpr.jaxpr} in" f" memory space {bm.transformed_block_aval.memory_space}." " See details at" - " https://jax.readthedocs.io/en/latest/pallas/grid_blockspec.html#pallas-blockspec." + " https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec." ) for bm in block_mappings: @@ -402,7 +664,7 @@ def err_details(bm: pallas_core.BlockMapping) -> str: + err_details(bm) ) - if not isinstance(bm.indexing_mode, pallas_core.Blocked): + if any(isinstance(b, pallas_core.Element) for b in bm.block_shape): raise NotImplementedError( "Only Blocked indexing mode is supported in Mosaic GPU lowering.\n\n" + err_details(bm) @@ -432,27 +694,29 @@ def index_map(*indices): which_parallel, indices, [ - primitives.program_id(axis) - for axis, is_parallel in enumerate(which_parallel) + primitives.program_id(axis - 1) + for axis, is_parallel in zip( + itertools.accumulate(which_parallel), which_parallel + ) if is_parallel ], ) return eval_index_map(*new_indices) - return gpu_core.GPUBlockSpec( + return gpu_core.BlockSpec( bm.block_shape, index_map, memory_space=bm.transformed_block_aval.memory_space, - indexing_mode=bm.indexing_mode, - transforms=bm.transforms, + transforms=cast(Sequence[gpu_core.MemoryRefTransform], bm.transforms), ) def lower_pipelined_jaxpr_to_module( grid_mapping: pallas_core.GridMapping, - mesh: pallas_core.Mesh | None, + gpu_mesh: pallas_core.Mesh | None, + jax_mesh: mesh_lib.Mesh | None, jaxpr: jax_core.Jaxpr, - compiler_params: dict[str, Any], + params: gpu_core.CompilerParams, cost_estimate: pallas_core.CostEstimate | None, ) -> LoweringResult: del cost_estimate # Unused. @@ -474,24 +738,23 @@ def lower_pipelined_jaxpr_to_module( block_mappings, [grid_mapping.num_inputs] ) - if mesh is not None: - assert isinstance(mesh, gpu_core.GPUMesh) - if mesh and mesh.num_threads is not None: - # Last dim corresponds to the warpgroup count. - block = (128 * grid_mapping.grid[-1], 1, 1) - grid = grid_mapping.grid[:-1] + if gpu_mesh: + assert isinstance(gpu_mesh, gpu_core.Mesh) + block = (128 * (gpu_mesh.num_threads or 1), 1, 1) + grid = gpu_mesh.grid + thread_axis = ( + gpu_mesh.thread_name if gpu_mesh.thread_name is not None else () + ) else: block = (128, 1, 1) grid = grid_mapping.grid + thread_axis = () - params = compiler_params.get("mosaic_gpu", {}) - dimension_semantics = params.get("dimension_semantics", None) - if dimension_semantics is None: + if params.dimension_semantics is None: which_parallel = [True] * len(grid) else: - assert len(dimension_semantics) == len(grid) - which_parallel = [ds == "parallel" for ds in dimension_semantics] - del dimension_semantics + assert len(params.dimension_semantics) == len(grid) + which_parallel = [ds == "parallel" for ds in params.dimension_semantics] sequential_grid = tuple( d for axis, d in enumerate(grid) if not which_parallel[axis] @@ -500,34 +763,38 @@ def lower_pipelined_jaxpr_to_module( d for axis, d in enumerate(grid) if which_parallel[axis] ) - from jax._src.pallas.mosaic_gpu import pipeline - from jax._src.pallas.mosaic_gpu import primitives as gpu_primitives + from jax._src.pallas.mosaic_gpu import pipeline # pytype: disable=import-error + from jax._src.pallas.mosaic_gpu import primitives as gpu_primitives # pytype: disable=import-error - def ref_for_aval(aval: jax_core.AbstractValue): + def ref_for_aval(aval: ShapedAbstractValue): if isinstance(aval, gpu_core.WGMMAAbstractAccumulatorRef): return gpu_core.WGMMAAccumulatorRef(aval.shape, aval.dtype) - elif isinstance(aval, pallas_core.AbstractMemoryRef): - return pallas_core.MemoryRef(aval.shape, aval.dtype, aval.memory_space) + elif isinstance(aval, gpu_core.AbstractTMEMRef): + return gpu_core.GPUMemoryRef( + jax_core.ShapedArray(aval.shape, aval.dtype), gpu_core.TMEM, + transforms=(), layout=aval.layout, collective=aval.collective, + ) + elif isinstance(aval, state_types.AbstractRef): + return pallas_core.MemoryRef(jax_core.ShapedArray(aval.shape, aval.dtype), + aval.memory_space) else: return gpu_core.SMEM(aval.shape, aval.dtype) def pipeline_fn(*refs): - return primitives.run_scoped( + primitives.run_scoped( functools.partial(scoped_pipeline_fn, *refs), scratch_refs=[ - ref_for_aval(v.aval) + ref_for_aval(cast(ShapedAbstractValue, v.aval)) for v in jaxpr.invars[grid_mapping.slice_scratch_ops] ], + collective_axes=thread_axis, # scratch_refs are shared across threads ) + return () # ``wrap_init`` does not support functions returning None. def scoped_pipeline_fn(*refs, scratch_refs): - def body_fn(*refs): - grid_env = pallas_core.current_grid_env() - assert grid_env is not None # Set by ``emit_pipeline``. + def body_fn(indices, *refs): program_ids_template = util.merge_lists( - which_parallel, - [grid_axis.index for grid_axis in grid_env], - [None] * sum(which_parallel), + which_parallel, indices, [None] * sum(which_parallel) ) assert len(refs) + len(scratch_refs) == len(jaxpr.invars) return gpu_primitives.jaxpr_call( @@ -545,146 +812,288 @@ def body_fn(*refs): _block_spec_from_block_mapping(bm, which_parallel) for bm in out_block_mappings ], - max_concurrent_steps=params.pop("max_concurrent_steps", 1), - delay_release=params.pop("delay_release", 0), + max_concurrent_steps=params.max_concurrent_steps, )(*refs) with grid_mapping.trace_env(): - new_jaxpr, _, new_consts, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init( - # ``wrap_init`` does not support functions returning None. - lambda *args: pipeline_fn(*args) or (), - debug_info=jaxpr.debug_info, - ), + new_jaxpr, _, new_consts = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(pipeline_fn, debug_info=jaxpr.debug_info.with_unknown_names()), [ gpu_core.GMEM( - bm.array_shape_dtype.shape, bm.array_shape_dtype.dtype + bm.array_aval.shape, bm.array_aval.dtype ).get_ref_aval() for bm in block_mappings ], ) assert not new_consts + axis_names = ( + _AxisNames(gpu_mesh.grid_names, gpu_mesh.cluster_names, gpu_mesh.thread_name) + if gpu_mesh is not None + else _AxisNames(grid_mapping.grid_names or ()) + ) with grid_mapping.trace_env(): return lower_jaxpr_to_module( + jax_mesh, + axis_names, parallel_grid, - grid_mapping.grid_names, block, - mesh.cluster if mesh is not None else (), - [bm.array_shape_dtype for bm in in_block_mappings], - [bm.array_shape_dtype for bm in out_block_mappings], + gpu_mesh.cluster if gpu_mesh is not None else (), + [bm.array_aval for bm in in_block_mappings], + [bm.array_aval for bm in out_block_mappings], new_jaxpr, - compiler_params, + params, new_consts, ) def lower_jaxpr_to_module( - grid: Sequence[int], - grid_names: Sequence[str], - block: Sequence[int], - cluster: Sequence[int], - in_shapes: Sequence[jax.ShapeDtypeStruct], - out_shapes: Sequence[jax.ShapeDtypeStruct], + jax_mesh: mesh_lib.Mesh | None, + axis_names: _AxisNames, + grid: tuple[int, ...], + block: tuple[int, ...], + cluster: tuple[int, ...], + in_shapes: Sequence[jax_core.ShapedArray], + out_shapes: Sequence[jax_core.ShapedArray], jaxpr: jax_core.Jaxpr, - compiler_params: dict[str, Any], + params: gpu_core.CompilerParams, consts=(), ) -> LoweringResult: debug_info = jaxpr.debug_info - params = compiler_params.get("mosaic_gpu", {}) - approx_math = params.get("approx_math", False) - thread_semantics = params.get( - "thread_semantics", mgpu_core.ThreadSemantics.Lane - ) + approx_math = params.approx_math + lowering_semantics = params.lowering_semantics + + if len(cluster) < 3: + cluster = (1,) * (3 - len(cluster)) + cluster + else: + assert len(cluster) == 3 if len(grid) <= 3: squashed_dims = () - parallel_grid = grid + (1,) * (3 - len(grid)) + parallel_grid = (1,) * (3 - len(grid)) + grid else: - # If we have >3 parallel dimensions, we merge all leading dimensions - # into the first (Dimension.x) CUDA grid dimension. + # If we have >3 parallel dimensions, we flatten all but the minormost 2 dims. + # Ex: (2, 3, 4, 5) -> (6, 4, 5) squashed_dims = grid[:-2] parallel_grid = (math.prod(grid[:-2]), *grid[-2:]) + # We reverse the order because Pallas prefers row-major iteration while the + # CUDA runtime prefers column-major iteration. + parallel_grid = parallel_grid[::-1] + cluster = cluster[::-1] + squashed_dims = squashed_dims[::-1] + axis_names = axis_names.reverse() + + rs = _estimate_resources( + ResourceEstimatorContext( + axis_names=axis_names, lowering_semantics=lowering_semantics + ), + jaxpr, + ) + def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): - *buffers_gmem, (runtime_smem, runtime_barriers) = buffers + *buffers_gmem, ( + runtime_smem, + runtime_barriers, + runtime_tmem, + ) = buffers + num_input_buffers = (len(in_shapes) + + len(rs.scoped_gmem_semaphores)) + input_buffers_gmem = buffers_gmem[:num_input_buffers] + output_buffers_gmem = buffers_gmem[num_input_buffers:] + + scoped_gmem_semaphores = {} + for collective_axes in sorted( + rs.scoped_gmem_semaphores.keys(), reverse=True): + num_sems = rs.scoped_gmem_semaphores[collective_axes] + # Extract the semaphores local to the current scope. + index = ir.IndexType.get() + # TODO(justinfu): Compute scope_idx for general collective_axes. + # scope_idx computes axis_index(all_axes - collective_axes) + if _is_block_local_scope(collective_axes, axis_names): + scope_idx = arith_dialect.index_castui(index, mgpu_utils.block_idx()) + elif _is_global_scope(collective_axes, axis_names): + scope_idx = _as_index(0) + else: + raise NotImplementedError( + f"Unimplemented scope for semaphores: {collective_axes=}") + scoped_gmem_semaphores[collective_axes] = mgpu.memref_slice( + output_buffers_gmem[-1], + mgpu.ds( + arith_dialect.muli( + scope_idx, arith_dialect.constant(index, num_sems) + ), + num_sems, + ), + ) + # The semaphore buffer is an aliased input/output, so we need to skip it + # in both the inputs and outputs. + input_buffers_gmem = input_buffers_gmem[:-1] + output_buffers_gmem = output_buffers_gmem[:-1] + buffers_gmem = [*input_buffers_gmem, *output_buffers_gmem] grouped_barriers = collections.defaultdict(list) for barrier, barrier_ref in zip(rs.barriers, runtime_barriers): grouped_barriers[barrier].append(barrier_ref) + if runtime_tmem is not None: + if lowering_semantics == mgpu.LoweringSemantics.Lane: + tmem_cols = math.prod(runtime_tmem.shape) // tcgen05.TMEM_ROWS + tmem_base = runtime_tmem.address + else: + tmem_cols = math.prod(runtime_tmem.type.shape) // tcgen05.TMEM_ROWS + tmem_base = runtime_tmem + else: + tmem_cols = 0 + tmem_base = None + + if lowering_semantics == mgpu.LoweringSemantics.Lane: + single_wg_lane_predicate = mgpu.single_thread_predicate( + scope=mgpu.ThreadSubset.WARPGROUP) + single_warp_lane_predicate = mgpu.single_thread_predicate( + scope=mgpu.ThreadSubset.WARP) + else: # Warpgroup semantics do not have a single lane predicate. + single_wg_lane_predicate = None + single_warp_lane_predicate = None + module_ctx = ModuleContext( mlir.sanitize_name(debug_info.func_name), - grid_names, - [_program_id(axis, squashed_dims) for axis in range(len(grid))], + axis_names, + [ + _program_id(axis, squashed_dims, len(grid)) + for axis in range(len(grid)) + ], approx_math, - mgpu.single_thread_predicate(per_block=False), + single_wg_lane_predicate, + single_warp_lane_predicate, smem_requested_bytes=math.prod(ir.MemRefType(runtime_smem.type).shape), smem_used_bytes=0, + tmem_requested_cols=tmem_cols, + tmem_used_cols=0, + tmem_base=tmem_base, + scoped_gmem_used_semaphores={k: 0 for k in scoped_gmem_semaphores}, + scoped_gmem_semaphore_base_ptr=scoped_gmem_semaphores, runtime_barriers=grouped_barriers, name_stack=source_info_util.NameStack(), traceback_caches=mlir.TracebackCaches(), squashed_dims=squashed_dims, - thread_semantics=thread_semantics, + lowering_semantics=lowering_semantics, + primitive_semantics=gpu_core.PrimitiveSemantics.Warpgroup, + mesh_info=pallas_utils.MeshInfo.from_mesh(jax_mesh) + if jax_mesh is not None + else None, + auto_barriers=not params.unsafe_no_auto_barriers, ) del runtime_smem, grouped_barriers, runtime_barriers - _ = lower_jaxpr_to_mosaic_gpu( module_ctx, launch_ctx, jaxpr, buffers_gmem, consts ) - rs = _estimate_resources(ResourceEstimatorContext(thread_semantics), jaxpr) - smem_scratch_bytes = params.get("smem_scratch_bytes") - if smem_scratch_bytes is None: - smem_scratch_bytes = rs.smem_scratch_bytes + scratch_buffers = [ + jax.ShapeDtypeStruct(shape=[rs.smem_scratch_bytes], dtype=np.int8), + rs.barriers, + ] + if rs.tmem_scratch_cols > 0 and rs.tmem_collective_scratch_cols > 0: + raise ValueError( + "Can't mix collective and non-collective TMEM allocations within the" + " same kernel." + ) + tmem_scratch_cols = rs.tmem_scratch_cols + rs.tmem_collective_scratch_cols + if tmem_scratch_cols > 0: + scratch_buffers.append( + mgpu.TMEM( + shape=(tcgen05.TMEM_ROWS, tmem_scratch_cols), + dtype=np.int32, + collective=rs.tmem_collective_scratch_cols > 0, + ), + ) + else: + scratch_buffers.append(None) - prof_ctx = prof_spec = None - if prof_space := params.get("profile_space", 0): + prof_spec = None + if params.profile_space: # Each range is 2 events, each event is 4 bytes. - prof_spec = mgpu_profiler.ProfilerSpec(prof_space * 2 * 4) - prof_ctx = ProfilerContext(params["profile_dir"], prof_spec) - module, out_structs_gmem, _, launch_ctx, scratch_arr = ( + prof_spec = mgpu_profiler.ProfilerSpec( + params.profile_space * 2 * 4, dump_path=params.profile_dir + ) + cuda_grid = tuple(map(operator.mul, parallel_grid, cluster)) + + scoped_semaphores_shape = [] + for collective_axes in sorted(rs.scoped_gmem_semaphores.keys()): + num_sems = rs.scoped_gmem_semaphores[collective_axes] + # TODO(justinfu): Compute axis_size for general collective_axes. + # axis_size computes axis_size(all_axes - collective_axes) + if _is_block_local_scope(collective_axes, axis_names): + axis_size = math.prod(cuda_grid) + elif _is_global_scope(collective_axes, axis_names): + axis_size = 1 + else: + raise NotImplementedError( + f"Unimplemented scope for semaphores: {collective_axes=}") + scoped_semaphores_shape.append( + jax.ShapeDtypeStruct( + shape=(axis_size * num_sems,), dtype=np.int32 + ), + ) + scoped_semaphores_shape = tuple(scoped_semaphores_shape) + + # NOTE: new_out_shapes has out_shapes, then semaphores_shape and + # optionally the profiler buffer. + module, new_out_shapes, _, launch_ctx = ( mgpu_core._lower_as_gpu_kernel( body, - grid=parallel_grid, + grid=cuda_grid, cluster=cluster, block=block, - in_shapes=in_shapes, - out_shape=out_shapes, - smem_scratch_shape=( - jax.ShapeDtypeStruct(shape=[smem_scratch_bytes], dtype=np.int8), - rs.barriers, - ), + in_shapes=(*in_shapes, *scoped_semaphores_shape), + out_shape=(*out_shapes, *scoped_semaphores_shape), + inout_shape=(), + smem_scratch_shape=scratch_buffers, + lowering_semantics=lowering_semantics, module_name=mlir.sanitize_name(debug_info.func_name), + kernel_name=mlir.sanitize_name(debug_info.func_name), prof_spec=prof_spec, ) ) - if thread_semantics == mgpu.ThreadSemantics.Warpgroup: + if lowering_semantics == mgpu.LoweringSemantics.Warpgroup: + # We need to run a pass that removes dead-code for which layout inference + # does not work. + pm = mlir.passmanager.PassManager.parse("builtin.module(canonicalize)", module.context) + pm.run(module.operation) + # Run Python lowering passes. The remaining passes will be run in C++ in # jax/jaxlib/mosaic/gpu/custom_call.cc mgpu.infer_layout(module) # pytype: disable=attribute-error - mgpu.lower_mgpu_dialect(module, launch_ctx) # pytype: disable=attribute-error + mgpu.lower_mgpu_dialect( + module, launch_ctx, auto_barriers=not params.unsafe_no_auto_barriers + ) - mgpu_core._initialize_scratch(launch_ctx, scratch_arr) + launch_ctx.scratch.finalize_size() return LoweringResult( - module, parallel_grid, block, out_structs_gmem, prof_ctx + module, cuda_grid, block, new_out_shapes, prof_spec, + scoped_semaphores_shape, ) mosaic_lowering_rules = { # Lowering rules when using Mosaic GPU lane semantics. - mgpu.ThreadSemantics.Lane: {} , + (mgpu.LoweringSemantics.Lane, gpu_core.PrimitiveSemantics.Warpgroup): {} , + gpu_core.LANExWARP_SEMANTICS: {} , # Lowering rules when using Mosaic GPU warpgroup semantics. - mgpu.ThreadSemantics.Warpgroup: {}, + (mgpu.LoweringSemantics.Warpgroup, + gpu_core.PrimitiveSemantics.Warpgroup): {}, } def register_lowering_rule( - primitive: jax_core.Primitive, thread_semantics: mgpu.ThreadSemantics + primitive: jax_core.Primitive, + lowering_semantics: mgpu.LoweringSemantics, + primitive_semantics: gpu_core.PrimitiveSemantics = gpu_core.PrimitiveSemantics.Warpgroup, ): def deco(fn): - mosaic_lowering_rules[thread_semantics][primitive] = fn + mosaic_lowering_rules[ + (lowering_semantics, primitive_semantics)][primitive] = fn return fn return deco @@ -720,7 +1129,7 @@ def write_env(var: jax_core.Var, val, require_value: bool = True): # TODO(apaszke): Handle other avals (refs, etc.). if isinstance(aval := var.aval, jax_core.ShapedArray): # TODO(apaszke): Clarify the type invariants for lane semantics? - if module_ctx.thread_semantics == mgpu.ThreadSemantics.Warpgroup: + if module_ctx.lowering_semantics == mgpu.LoweringSemantics.Warpgroup: # Shaped arrays must be vectors if and only if their shape is non-empty. # Those with empty shapes should be represented by their scalar type. mlir_dtype = mgpu_utils.dtype_to_ir_type(aval.dtype) @@ -728,11 +1137,11 @@ def write_env(var: jax_core.Var, val, require_value: bool = True): if require_value: raise AssertionError(f"Shaped arrays must be represented by ir.Values, got: {val}") else: - if var.aval.shape: + if aval.shape: raise AssertionError("Only scalars can be represented by non-ir.Values") return # Skip following checks. if aval.shape: - if not ir.VectorType.isinstance(val.type): + if not isinstance(val.type, ir.VectorType): raise AssertionError(f"Non-scalar arrays must be represented by vectors, got: {val.type}") vty = ir.VectorType(val.type) if vty.element_type != mlir_dtype: @@ -740,27 +1149,33 @@ def write_env(var: jax_core.Var, val, require_value: bool = True): if tuple(vty.shape) != aval.shape: raise AssertionError(f"Vector shape must match ShapedArray shape, got: {vty.shape} != {aval.shape}") else: - if ir.VectorType.isinstance(val.type): + if isinstance(val.type, ir.VectorType): raise AssertionError(f"Scalars must be represented by non-vector types, got: {val.type}") if val.type != mlir_dtype: raise AssertionError(f"Scalar type must match ShapedArray dtype, got: {val.type} != {mlir_dtype}") - foreach(write_env, jaxpr.constvars, consts) - foreach(lambda v, a: write_env(v, a, require_value=False), jaxpr.invars, args) + foreach( + functools.partial(write_env, require_value=False), jaxpr.constvars, consts + ) + foreach(functools.partial(write_env, require_value=False), jaxpr.invars, args) + # TODO(justinfu): Handle transform scopes. last_local_name_stack: list[str] = [] named_regions = [] - for eqn in jaxpr.eqns: + for i, eqn in enumerate(jaxpr.eqns): invals = map(read_env, eqn.invars) - source_info = eqn.source_info.replace( - name_stack=module_ctx.name_stack + eqn.source_info.name_stack + eqn_name_stack = module_ctx.name_stack + eqn.source_info.name_stack + loc = mlir.source_info_to_location( # pytype: disable=wrong-arg-types + module_ctx, eqn.primitive, eqn_name_stack, eqn.source_info.traceback ) - loc = mlir._source_info_to_location(module_ctx, eqn.primitive, source_info) with source_info_util.user_context(eqn.source_info.traceback), loc: - if eqn.primitive not in mosaic_lowering_rules[module_ctx.thread_semantics]: + if eqn.primitive not in mosaic_lowering_rules[ + (module_ctx.lowering_semantics, module_ctx.primitive_semantics)]: raise NotImplementedError( "Unimplemented primitive in Pallas Mosaic GPU lowering: " - f"{eqn.primitive.name}. " + f"{eqn.primitive.name} for lowering semantics " + f"{module_ctx.lowering_semantics} and user thread semantics " + f"{module_ctx.primitive_semantics}. " "Please file an issue on https://github.com/jax-ml/jax/issues." ) new_local_name_stack = [scope.name for scope in eqn.source_info.name_stack.stack] @@ -772,20 +1187,32 @@ def write_env(var: jax_core.Var, val, require_value: bool = True): wrapper_stack = contextlib.ExitStack() wrapper_stack.enter_context(launch_ctx.named_region(name)) named_regions.append(wrapper_stack) - rule = mosaic_lowering_rules[module_ctx.thread_semantics][eqn.primitive] + rule = mosaic_lowering_rules[ + (module_ctx.lowering_semantics, module_ctx.primitive_semantics) + ][eqn.primitive] + # If the equation is immediately followed by a layout cast on its output, + # we provide the layout as a hint to the rule. + out_layout_hint = None + if i + 1 < len(jaxpr.eqns): + lookahead_eqn = jaxpr.eqns[i + 1] + is_layout_cast = lookahead_eqn.primitive == gpu_core.layout_cast_p + uses_eqn_output = lookahead_eqn.invars == eqn.outvars + if is_layout_cast and uses_eqn_output: + out_layout_hint = lookahead_eqn.params["new_layout"].to_mgpu() rule_ctx = LoweringRuleContext( module_ctx, launch_ctx, avals_in=[cast(jax_core.ShapedArray, v.aval) for v in eqn.invars], avals_out=[cast(jax_core.ShapedArray, v.aval) for v in eqn.outvars], prim=eqn.primitive, + out_layout_hint=out_layout_hint, ) try: outvals = rule(rule_ctx, *invals, **eqn.params) except LoweringError: raise # We only add the extra info to the innermost exception. except Exception as e: - if not pallas_call._verbose_errors_enabled(): + if not config.jax_pallas_verbose_errors.value: raise inval_types = map(lambda t: getattr(t, "type", None), invals) raise LoweringError( @@ -801,8 +1228,9 @@ def write_env(var: jax_core.Var, val, require_value: bool = True): return map(read_env, jaxpr.outvars) -@register_lowering_rule(primitives.program_id_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(primitives.program_id_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(primitives.program_id_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule( + primitives.program_id_p, mgpu.LoweringSemantics.Warpgroup) def _program_id_lowering_rule(ctx: LoweringRuleContext, axis): if ctx.module_ctx.program_ids is None: raise NotImplementedError("pl.program_id() is not supported in this context") @@ -826,24 +1254,21 @@ def _unravel_program_id( return arith_dialect.index_cast(ir.IntegerType.get_signless(32), pid) -def _program_id(parallel_axis: int, squashed_dims: tuple[int, ...]) -> ir.Value: - if squashed_dims: - if parallel_axis < len(squashed_dims): - # All squashed dimensions are mapped to Dimension.x. - block_id = gpu_dialect.block_id(gpu_dialect.Dimension.x) - return _unravel_program_id(block_id, parallel_axis, squashed_dims) - else: - # Handle unsquashed axes. - return arith_dialect.index_cast( - ir.IntegerType.get_signless(32), - gpu_dialect.block_id(gpu_dialect.Dimension( - parallel_axis - len(squashed_dims) + 1)), - ) +def _program_id( + parallel_axis: int, squashed_dims: tuple[int, ...], grid_size: int +) -> ir.Value: + """Returns the id of the current kernel instance along the given axis in the original Pallas grid.""" + if parallel_axis < len(squashed_dims): + # All squashed dimensions are mapped to Dimension.z. + block_id = gpu_dialect.block_id(gpu_dialect.Dimension.z) + idx = len(squashed_dims) - 1 - parallel_axis + return _unravel_program_id(block_id, idx, squashed_dims) else: + idx = grid_size - 1 - parallel_axis + assert idx in (0, 1, 2) return arith_dialect.index_cast( ir.IntegerType.get_signless(32), - gpu_dialect.block_id(gpu_dialect.Dimension(parallel_axis)), - ) + gpu_dialect.block_id(gpu_dialect.Dimension(idx))) def _lower_fun( @@ -860,7 +1285,7 @@ def lowering_rule(ctx: LoweringRuleContext, *args, **params): "Pallas Mosaic GPU lower_fun", fun, args, params ), ) - jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in) + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in) out = lower_jaxpr_to_mosaic_gpu( ctx.module_ctx, ctx.launch_ctx, jaxpr, args, consts ) @@ -868,77 +1293,233 @@ def lowering_rule(ctx: LoweringRuleContext, *args, **params): return lowering_rule +def _handle_dtype_bitcast( + ref: ir.Value, src_dtype: ir.Type, dst_dtype: ir.Type +) -> ir.Value: + """Allows bitcasting a SMEM ref from one element type to another. -@register_lowering_rule(primitives.num_programs_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(primitives.num_programs_p, mgpu.ThreadSemantics.Warpgroup) -def _num_programs_lowering_rule(ctx: LoweringRuleContext, axis): - del ctx # Unused. - return arith_dialect.index_cast( - ir.IntegerType.get_signless(32), - gpu_dialect.block_dim(gpu_dialect.Dimension(axis)), - ) - + Args: + ref: the reference to bitcast. + src_dtype: the source element type. + dst_dtype: the destination element type. -def _handle_reshaping( - ref: ir.Value, transforms: Sequence[gpu_core.Transform] -) -> tuple[ir.Value, Sequence[gpu_core.Transform]]: - is_trivial_indexer = lambda t: isinstance( - t, indexing.NDIndexer - ) and gpu_core.is_trivial_index(t.indices, t.shape) + Returns: + A bitcasted version of `ref` with element type `dst_dtype`. - last_reshaper_idx = next( - reversed([i for i, t in enumerate(transforms) if isinstance(t, RefReshaper)]), - None, - ) - if last_reshaper_idx is None: - return ref, transforms - # Check that before the reshape are only trivial indexes and or - # other reshapes. - # TODO(cperivol): Reshapes should bubble up rather than being - # expected to effectively be the first ref transform. - if not all(isinstance(t, RefReshaper) or is_trivial_indexer(t) for t in transforms[:last_reshaper_idx]): + Raises: + ValueError: if the source ref is not in SMEM. + """ + if src_dtype == dst_dtype: + return ref + if src_dtype != ir.IntegerType.get_signless(8): + raise NotImplementedError( + "Data type bitcast is only supported from i8 to other types." + ) + ref_ty = ir.MemRefType(ref.type) + if not mgpu_utils.is_smem_ref(ref_ty): + raise ValueError(f"Only workgroup memory is supported but got {ref}.") + if len(ref_ty.shape) != 1: raise NotImplementedError( - "Reshapes do not compose with other transforms and indexers must be" - f" trivial (transforms: {transforms})" + "Data type bitcast is only supported for 1D arrays." + ) + [stride], _ = ref_ty.get_strides_and_offset() + if stride != 1: + raise ValueError( + "Data type bitcast is only supported for contiguous 1D arrays, but got " + f"stride={stride}." ) - reshaper = cast(RefReshaper, transforms[last_reshaper_idx]) - # Skip all the reshapes and trivial indexes. - return mgpu.memref_reshape(ref, reshaper.shape), transforms[last_reshaper_idx + 1:] + [shape_bytes] = ref_ty.shape + shape_bitwidth = shape_bytes * 8 + target_bitwidth = mgpu_utils.bitwidth(dst_dtype) + + if shape_bitwidth % target_bitwidth: + raise ValueError( + f"Can not bitcast memory region of size {shape_bitwidth} bits to dtype " + f"with {target_bitwidth} bits." + ) + + result_type = ir.MemRefType.get( + shape=(shape_bitwidth // target_bitwidth,), + element_type=dst_dtype, + memory_space=ref_ty.memory_space, + ) + # Do a memref_ptr/ptr_as_memref roundtrip instead of using `memref.view`, + # which refuses to take in our source ref. This is because `memref.view` only + # works on a super restricted set of `memref`s. E.g., it does not work if an + # offset is specified, which can be the case for our SMEM refs. + smem = mgpu_utils.WORKGROUP_NVPTX_ADDRESS_SPACE + ref = mgpu_utils.memref_ptr(ref, memory_space=smem) + return mgpu_utils.ptr_as_memref(ref, result_type, ptr_memory_space=smem) -def _handle_indexing( - ref: ir.Value, transforms: Sequence[gpu_core.Transform] -) -> tuple[ir.Value, Sequence[gpu_core.Transform]]: - if not transforms: - pass - indexer_idxs = [ - i for i, t in enumerate(transforms) if isinstance(t, indexing.NDIndexer) - ] - if not indexer_idxs: - return ref, transforms - sliced_ref = ref + +def _extract_aliased_ref( + ref: RefOrTmemType, transforms: Sequence[state_types.Transform] +) -> tuple[RefOrTmemType, Sequence[state_types.Transform]]: + match transforms: + case ( + gpu_core.ExtractAliasedRef( + dtype, transformed_shape, offset, layout + ), + *other_transforms, + ): + mlir_dtype = mgpu_utils.dtype_to_ir_type(dtype) + if isinstance(ref, tcgen05.TMEMRef): + assert layout is not None + if ref.shape[0] != transformed_shape[0]: + raise ValueError( + "TMEM aliasing only supported for Refs with the same first" + f" dimension, got {ref.shape[0]} != {transformed_shape[0]}." + ) + address = arith_dialect.addi(ref.address, _i32_constant(offset)) + ref = tcgen05.TMEMRef( + address=address, + shape=transformed_shape, + dtype=mgpu_utils.dtype_to_ir_type(dtype), + layout=layout) + else: + assert layout is None + ref_bits = math.prod(transformed_shape) * mgpu_utils.bitwidth(mlir_dtype) + if ref_bits % 8: + raise NotImplementedError("Only byte-aligned bitcasts are supported.") + assert offset % gpu_core.SMEM_ALIGNMENT == 0 + ref_bytes = ref_bits // 8 + ref = mgpu.memref_slice(ref, slice(offset, offset + ref_bytes)) + ref = _handle_dtype_bitcast( + ref, + ir.MemRefType(ref.type).element_type, + mgpu_utils.dtype_to_ir_type(dtype), + ) + ref = mgpu.memref_reshape(ref, transformed_shape) + return ref, tuple(other_transforms) + case _: + return ref, transforms + + +def _transform_dtype( + dtype: dtypes.DType, + transforms: Sequence[state_types.Transform], +) -> dtypes.DType: + """Applies `t.transform_dtype` for `t` in `transforms` sequentially on `dtype`.""" + for transform in transforms: + dtype = transform.transform_dtype(dtype) + assert dtype is not None + return dtype # pytype: disable=bad-return-type + + +def _handle_transforms( + ctx: LoweringRuleContext, + ref: RefOrTmemType, + transforms: Sequence[state_types.Transform], + *, + handle_transposes=True, + handle_reshapes=True, + allow_peer_refs=False, + allow_multicast_refs=False, +) -> tuple[RefOrTmemType, Sequence[state_types.Transform]]: + # Before we handle other transforms, we resolve any possible leading + # aliasing transform. + ref, transforms = _extract_aliased_ref(ref, transforms) + if isinstance(ref, tcgen05.TMEMRef): + mlir_dtype = ref.dtype + else: + mlir_dtype = ir.MemRefType(ref.type).element_type + transformed_ref = ref new_transforms = [] - for t in transforms: - if not isinstance(t, indexing.NDIndexer): - new_transforms.append(t) - continue - indexer = cast(indexing.NDIndexer, t) - if indexer.int_indexer_shape: - raise NotImplementedError("int_indexer_shape non-empty") - indices = _ndindexer_indices(indexer) + def _bubble_up(untransform_fn, data): + nonlocal new_transforms new_transforms_rev = [] for t in reversed(new_transforms): - indices, new_t = t.untransform_index(indices) + data, new_t = untransform_fn(t, data) new_transforms_rev.append(new_t) - sliced_ref = mgpu.memref_slice(sliced_ref, indices) + new_transforms = list(reversed(new_transforms_rev)) - return sliced_ref, new_transforms + return data + + peer_device_id = None + is_multicast = False + for t in transforms: + match t: + case indexing.NDIndexer(): + indexer = cast(indexing.NDIndexer, t) + if indexer.int_indexer_shape: + raise NotImplementedError("int_indexer_shape non-empty") + indices = _ndindexer_indices(indexer) + indices = _bubble_up( + lambda t, idxs: t.untransform_index(mlir_dtype, idxs), indices + ) + if ( + isinstance(transformed_ref, tcgen05.TMEMRef) + and ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane + ): + transformed_ref = transformed_ref.slice(*indices) + else: + transformed_ref = mgpu.memref_slice(transformed_ref, indices) + case RefTransposer(perm): + if handle_transposes: + perm = _bubble_up(lambda t, p: t.untransform_transpose(p), perm) + if isinstance(transformed_ref, tcgen05.TMEMRef): + raise ValueError("TMEM transpose not allowed.") + transformed_ref = mgpu.memref_transpose(transformed_ref, perm) + else: + if not isinstance(t, gpu_core.TransposeRef): + t = gpu_core.TransposeRef(perm) + new_transforms.append(t) + case RefReshaper(dtype=dtype, shape=shape) if handle_reshapes: + shape = _bubble_up( + lambda t, p: t.untransform_reshape(dtype, p), # pylint: disable=cell-var-from-loop + shape) + if isinstance(transformed_ref, tcgen05.TMEMRef): + raise ValueError("TMEM reshape not allowed.") + transformed_ref = mgpu.memref_reshape(transformed_ref, shape) + case gpu_core.PeerMemRef(device_id, device_id_type): + peer_device_id, other_axes = primitives.device_id_to_logical( + ctx.module_ctx.mesh_info, + _ensure_ir_value_device_id(device_id), + device_id_type, + lambda name: _axis_index_rule(ctx, axis_name=name), + ) + if other_axes: + raise ValueError( + "Only JAX mesh axes can be used to obtain peer references, but" + f" got {other_axes}" + ) + case gpu_core.MulticastRef(_, _, _): + if not allow_multicast_refs: + raise NotImplementedError( + "Multicast references are not allowed in the lowering of this" + " primitive." + ) + is_multicast = True + case _: + new_transforms.append(t) + if peer_device_id is not None: + assert not is_multicast + if not allow_peer_refs: + raise NotImplementedError( + "Peer device references are not allowed in the lowering of this" + " primitive." + ) + transformed_ref = ctx.launch_ctx.to_remote( + transformed_ref, _ensure_ir_value(peer_device_id, jnp.int32) + ) + if is_multicast: + transformed_ref = ctx.launch_ctx.to_remote_multicast(transformed_ref) + return transformed_ref, new_transforms -def _ndindexer_indices(indexer: indexing.NDIndexer) -> tuple[gpu_core.Index, ...]: +def _ndindexer_indices( + indexer: indexing.NDIndexer, allow_arrays: bool = False +) -> tuple[gpu_core.Index | mgpu.FragmentedArray | ir.Value, ...]: indices = [] for idx in indexer.indices: - if not isinstance(idx, indexing.Slice): + if (isinstance(idx, mgpu.FragmentedArray) and idx.shape) or ( + isinstance(idx, ir.Value) and isinstance(idx.type, ir.VectorType) # pytype: disable=attribute-error + ): + if not allow_arrays: + raise ValueError("Arrays are not supported as indices.") + indices.append(idx) + elif not isinstance(idx, indexing.Slice): indices.append(_as_index(idx)) elif not idx.is_dynamic_start and not idx.is_dynamic_size: indices.append(slice(idx.start, idx.start + idx.size, idx.stride)) @@ -954,130 +1535,290 @@ def _ndindexer_indices(indexer: indexing.NDIndexer) -> tuple[gpu_core.Index, ... return tuple(indices) -@register_lowering_rule(sp.get_p, mgpu.ThreadSemantics.Lane) -def _get_lowering_rule(ctx: LoweringRuleContext, x_smem, *leaves, tree): - if not isinstance(x_smem, ir.Value) and ir.MemRefType.isinstance(x_smem): - raise TypeError(f"Can only load from references (got {x_smem}).") - - x_aval = ctx.avals_in[0] +@register_lowering_rule(sp.get_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule( + sp.get_p, mgpu.LoweringSemantics.Lane, gpu_core.PrimitiveSemantics.Warp +) +def _get_lowering_rule( + ctx: LoweringRuleContext, x_ref, *leaves, tree, optimized=True +): + if isinstance(x_ref, tcgen05.TMEMRef): + raise RuntimeError( + "Loads from TMEM are asynchronous operations and cannot be performed" + " using the usual syntax. Please use plgpu.async_load_tmem instead." + ) + if ( + ctx.avals_out[0].shape + and ctx.module_ctx.primitive_semantics == gpu_core.PrimitiveSemantics.Warp + ): + raise ValueError("Can only load scalars in warp-level code.") + if not isinstance(x_ref, ir.Value) and isinstance(x_ref, ir.MemRefType): + raise TypeError(f"Can only load from references (got {x_ref}).") + dtype = ctx.avals_out[0].dtype transforms = jax.tree.unflatten(tree, leaves) - x_smem, transforms = _handle_reshaping(x_smem, transforms) - x_smem, transforms = _handle_indexing(x_smem, transforms) + transposed = ctx.out_layout_hint and ctx.out_layout_hint in ( + mgpu.WGMMA_TRANSPOSED_LAYOUT, + mgpu.TCGEN05_TRANSPOSED_LAYOUT, + ) + transposed = bool(transposed) + x_smem, transforms = _handle_transforms( + ctx, x_ref, transforms, handle_transposes=not transposed, + allow_peer_refs=True + ) + x_smem = cast(ir.Value, x_smem) + del x_ref # Don't use x_ref anymore. Use x_smem instead! + + is_signed = mgpu_utils.is_signed(dtype) + + if not ctx.avals_out[0].shape: # The scalar case is simple. + val = memref_dialect.load(x_smem, []) + return mgpu.FragmentedArray.splat(val, shape=(), is_signed=is_signed) match transforms: - case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)): - if tiling != (64, swizzle // x_aval.dtype.itemsize): - raise NotImplementedError("Tiling does not fit swizzle") + case ( + gpu_core.UnswizzleRef(swizzle), + gpu_core.UntileRef(tiling), + *maybe_transpose, + ): + if len(tiling) != 2: + raise NotImplementedError(f"Only 2D tiling is supported, got: {tiling}") + bw = dtypes.itemsize_bits(ctx.avals_out[0].dtype) + expected_minor_tiling = swizzle * 8 // bw + if tiling[-1] != expected_minor_tiling: + raise NotImplementedError( + "Minor tiling dimension does not fit swizzle: " + f" expected {expected_minor_tiling}, got {tiling[-1]}" + ) + + if transposed != bool(maybe_transpose): + raise ValueError( + "Either both the ref and the value are transposed or neither is." + ) + + if maybe_transpose: + if maybe_transpose != [gpu_core.TransposeRef((1, 0))]: + raise NotImplementedError( + f"Unsupported transforms: {transforms} ({maybe_transpose})" + ) + + x_smem = mgpu.memref_transpose(x_smem, (1, 0, 3, 2)) return mgpu.FragmentedArray.load_tiled( - x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype), swizzle=swizzle - ) - case (): - # Handle scalar indexing. - if not ctx.avals_out[0].shape: - is_signed = mgpu_utils.is_signed(x_aval.dtype) - val = memref_dialect.load(x_smem, []) - return mgpu.FragmentedArray.splat(val, shape=(), is_signed=is_signed) - - return mgpu.FragmentedArray.load_strided( - x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype) + x_smem, + is_signed=is_signed, + swizzle=swizzle, + layout=ctx.out_layout_hint or mgpu.WGMMA_LAYOUT, + optimized=optimized, ) + case (*maybe_transpose,): + if maybe_transpose: + if len(maybe_transpose) != 1 or not isinstance( + maybe_transpose[0], gpu_core.TransposeRef + ): + raise NotImplementedError( + f"Unsupported transforms: {transforms} ({maybe_transpose})" + ) + x_smem = mgpu.memref_transpose(x_smem, maybe_transpose[0].permutation) + match ctx.out_layout_hint: + case mgpu.WGStridedFragLayout(shape=shape, vec_size=vec_size): + ref_ty = ir.MemRefType(x_smem.type) + if shape != tuple(ref_ty.shape): + raise ValueError( + f"Unsupported shape {shape}, (expected {tuple(ref_ty.shape)})" + ) + return mgpu.FragmentedArray.load_strided( + x_smem, + is_signed=is_signed, + vec_size=vec_size, + ) + case None: + return mgpu.FragmentedArray.load_strided(x_smem, is_signed=is_signed) + case _: + return mgpu.FragmentedArray.load_untiled( + x_smem, + is_signed=is_signed, + layout=ctx.out_layout_hint, + swizzle=16, + optimized=optimized, + ) case _: raise NotImplementedError(f"Unsupported transforms: {transforms}") -@register_lowering_rule(sp.get_p, mgpu.ThreadSemantics.Warpgroup) -def _get_lowering_rule_wg(ctx: LoweringRuleContext, x_smem, *leaves, tree): - if not isinstance(x_smem, ir.Value) and ir.MemRefType.isinstance(x_smem): - raise TypeError(f"Can only load from references (got {x_smem}).") - - x_aval = ctx.avals_in[0] +@register_lowering_rule(sp.get_p, mgpu.LoweringSemantics.Warpgroup) +def _get_lowering_rule_wg( + ctx: LoweringRuleContext, x_ref, *leaves, tree, optimized=True +): + if not isinstance(x_ref, ir.Value) and isinstance(x_ref, ir.MemRefType): + raise TypeError(f"Can only load from references (got {x_ref}).") transforms = jax.tree.unflatten(tree, leaves) - x_smem, transforms = _handle_reshaping(x_smem, transforms) - x_smem, transforms = _handle_indexing(x_smem, transforms) + x_ref, transforms = _handle_transforms( + ctx, x_ref, transforms, allow_peer_refs=True + ) if transforms: raise NotImplementedError( "Transforms are not yet implemented for warpgroup semantics" ) + assert isinstance(x_ref, ir.Value) shape = ctx.avals_out[0].shape - ty = ir.VectorType.get(shape, mgpu_utils.dtype_to_ir_type(x_aval.dtype)) if shape: - zero_index = arith_dialect.constant(ir.IndexType.get(), 0) - indices = [zero_index for _ in range(len(shape))] - return vector_dialect.load(ty, x_smem, indices) + return mgpu.dialect.vector_load(x_ref, optimized=optimized) else: - return memref_dialect.load(x_smem, []) + return memref_dialect.load(x_ref, []) -@register_lowering_rule(sp.swap_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(sp.swap_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule( + sp.swap_p, mgpu.LoweringSemantics.Lane, gpu_core.PrimitiveSemantics.Warp +) def _swap_lowering_rule( - ctx: LoweringRuleContext, x_smem, value, *leaves, tree + ctx: LoweringRuleContext, x_ref, value, *leaves, tree ): - if not isinstance(value, mgpu.FragmentedArray): - raise TypeError(f"Can only store arrays (got {value}).") - if not isinstance(x_smem, ir.Value) and ir.MemRefType.isinstance(x_smem): - raise TypeError(f"Can only store to references (got {x_smem}).") - x_aval = ctx.avals_in[0] + if isinstance(x_ref, tcgen05.TMEMRef): + raise RuntimeError( + "Stores to TMEM are asynchronous operations and cannot be performed" + " using the usual syntax. Please use plgpu.async_store_tmem instead." + ) + barrier = mgpu.warpgroup_barrier + if ctx.module_ctx.primitive_semantics == gpu_core.PrimitiveSemantics.Warp: + if ctx.avals_out[0].shape: + raise NotImplementedError("Can only store scalars in warp-level lowering.") + i32 = ir.IntegerType.get_signless(32) + barrier = functools.partial( + nvvm_dialect.bar_warp_sync, arith_dialect.constant(i32, -1) + ) + value = _ensure_fa(value, ctx.avals_in[1].dtype) + + if not isinstance(x_ref, ir.Value) and isinstance(x_ref, ir.MemRefType): + raise TypeError(f"Can only store to references (got {x_ref}).") + v_aval = ctx.avals_in[1] transforms = jax.tree.unflatten(tree, leaves) - x_smem, transforms = _handle_reshaping(x_smem, transforms) - x_smem, transforms = _handle_indexing(x_smem, transforms) + transposed_value = value.layout in ( + mgpu.WGMMA_TRANSPOSED_LAYOUT, + mgpu.TCGEN05_TRANSPOSED_LAYOUT, + ) + x_smem, transforms = _handle_transforms( + ctx, x_ref, transforms, handle_transposes=not transposed_value, + allow_peer_refs=True + ) + del x_ref # Don't use x_ref anymore. Use x_smem instead! + + if ctx.module_ctx.auto_barriers: + barrier() # Make sure reads have completed before we write. + match transforms: - case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)): - if tiling != (64, swizzle // x_aval.dtype.itemsize): - raise NotImplementedError("Tiling does not fit swizzle") + case _ if math.prod(ctx.avals_out[0].shape) == 1: # Scalar case. + zero_idx = _ir_constant(0, ir.IndexType.get()) + indices = [zero_idx] * len(ctx.avals_out[0].shape) + old_value = mgpu.FragmentedArray.splat( + memref_dialect.load(x_smem, indices), + shape=(), + is_signed=mgpu_utils.is_signed(v_aval.dtype), + ) + value.store_untiled(x_smem) + case ( + gpu_core.UnswizzleRef(swizzle), + gpu_core.UntileRef(tiling), + *maybe_transpose, + ): + if len(tiling) != 2: + raise NotImplementedError(f"Only 2D tiling is supported, got: {tiling}") + bw = dtypes.itemsize_bits(v_aval.dtype) + expected_minor_tiling = swizzle * 8 // bw + if tiling[-1] != expected_minor_tiling: + raise NotImplementedError( + "Minor tiling dimension does not fit swizzle: " + f" expected {expected_minor_tiling}, got {tiling[-1]}" + ) + + if transposed_value != bool(maybe_transpose): + raise ValueError( + "Either both the ref and the value are transposed or neither is." + ) + + if maybe_transpose: + if maybe_transpose != [gpu_core.TransposeRef((1, 0))]: + raise NotImplementedError( + f"Unsupported transforms: {transforms} ({maybe_transpose})" + ) + + x_smem = mgpu.memref_transpose(x_smem, (1, 0, 3, 2)) + old_value = mgpu.FragmentedArray.load_tiled( - x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype), swizzle=swizzle + x_smem, + is_signed=mgpu_utils.is_signed(v_aval.dtype), + swizzle=swizzle, + layout=value.layout, ) value.store_tiled(x_smem, swizzle=swizzle) - return old_value - case (): - old_value = mgpu.FragmentedArray.load_strided( - x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype) - ) - value.store_untiled(x_smem) - return old_value + case () | (gpu_core.TransposeRef(),): + transposed = bool(transforms) + match value.layout: + case mgpu.TiledLayout(): + if transposed: + assert isinstance( + transforms[0], gpu_core.TransposeRef + ) # silence pytype + permutation = transforms[0].permutation + x_smem = mgpu.memref_transpose(x_smem, permutation) + old_value = mgpu.FragmentedArray.load_untiled( + x_smem, + layout=value.layout, + is_signed=mgpu_utils.is_signed(v_aval.dtype), + optimized=False, + ) + value.store_untiled(x_smem, optimized=False) + case _: + if transposed: + raise NotImplementedError(f"Unsupported transforms: {transforms}") + old_value = mgpu.FragmentedArray.load_strided( + x_smem, is_signed=mgpu_utils.is_signed(v_aval.dtype) + ) + value.store_untiled(x_smem) case _: raise NotImplementedError(f"Unsupported transforms: {transforms}") + if ctx.module_ctx.auto_barriers: + barrier() # Make sure the writes have completed. + return old_value -@register_lowering_rule(sp.swap_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(sp.swap_p, mgpu.LoweringSemantics.Warpgroup) def _swap_lowering_rule_wg( ctx: LoweringRuleContext, x_smem, value, *leaves, tree ): - if not ir.VectorType.isinstance(value.type): - raise TypeError(f"Can only store vectors (got {value}).") - if not ir.MemRefType.isinstance(x_smem.type): - raise TypeError(f"Can only store to references (got {x_smem}).") - - x_aval = ctx.avals_in[0] - + shape = ctx.avals_out[0].shape + if shape and not isinstance(value.type, ir.VectorType): + raise TypeError(f"Can only store scalars or vectors (got {value}).") + if not ( + isinstance(x_smem, ir.Value) and isinstance(x_smem.type, ir.MemRefType) + ): + raise TypeError(f"Can only store to references (got {x_smem}).") transforms = jax.tree.unflatten(tree, leaves) - x_smem, transforms = _handle_reshaping(x_smem, transforms) - x_smem, transforms = _handle_indexing(x_smem, transforms) - + x_smem, transforms = _handle_transforms( + ctx, x_smem, transforms, allow_peer_refs=True) if transforms: raise NotImplementedError( "Transforms are not yet implemented for warpgroup semantics" ) - - shape = ctx.avals_out[0].shape - ty = ir.VectorType.get(shape, mgpu_utils.dtype_to_ir_type(x_aval.dtype)) + assert isinstance(x_smem, ir.Value) + value = _ensure_ir_value(value, ctx.avals_in[1].dtype) if shape: - zero_index = arith_dialect.constant(ir.IndexType.get(), 0) - indices = [zero_index for _ in range(len(shape))] - old_value = vector_dialect.load(ty, x_smem, indices) - vector_dialect.store(value, x_smem, indices) + old_value = mgpu.dialect.vector_load(x_smem) + mgpu.dialect.vector_store(value, x_smem) else: old_value = memref_dialect.load(x_smem, []) memref_dialect.store(value, x_smem, []) return old_value -@register_lowering_rule(pjit.pjit_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(pjit.pjit_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(pjit.jit_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(pjit.jit_p, mgpu.LoweringSemantics.Warpgroup) +@register_lowering_rule( + pjit.jit_p, mgpu.LoweringSemantics.Lane, gpu_core.PrimitiveSemantics.Warp +) def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **kwargs): if jaxpr.consts: raise NotImplementedError @@ -1085,11 +1826,8 @@ def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **kwargs): ctx.module_ctx, ctx.launch_ctx, jaxpr.jaxpr, args, ) -@register_lowering_rule(pjit.mesh_cast_p, mgpu.ThreadSemantics.Lane) -def _mesh_cast_lowering_rule(ctx, x, dst_sharding): - return x -@register_lowering_rule(lax.slice_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.slice_p, mgpu.LoweringSemantics.Lane) def _slice_lowering_rule( ctx: LoweringRuleContext, x, limit_indices, start_indices, strides ): @@ -1099,8 +1837,28 @@ def _slice_lowering_rule( return x[tuple(slice(b, e) for b, e in zip(start_indices, limit_indices))] -@register_lowering_rule(lax.select_n_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.select_n_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.slice_p, mgpu.LoweringSemantics.Warpgroup) +def _slice_lowering_rule_wg( + ctx: LoweringRuleContext, x, limit_indices, start_indices, strides +): + del limit_indices + assert isinstance(x.type, ir.VectorType) + if strides is not None: + raise NotImplementedError("Strides are not supported.") + out_ty = ir.VectorType.get( + ctx.avals_out[0].shape, ir.VectorType(x.type).element_type + ) + sizes = ctx.avals_out[0].shape + strides = [1] * len(start_indices) + return vector_dialect.extract_strided_slice( + out_ty, x, start_indices, sizes, strides + ) + + +@register_lowering_rule(lax.select_n_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.select_n_p, mgpu.LoweringSemantics.Lane, + gpu_core.PrimitiveSemantics.Warp) +@register_lowering_rule(lax.select_n_p, mgpu.LoweringSemantics.Warpgroup) def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, *cases): if len(cases) != 2: raise NotImplementedError( @@ -1108,8 +1866,12 @@ def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, *cases): f" {len(cases)}" ) pred_aval, *cases_avals = ctx.avals_in + if ctx.module_ctx.primitive_semantics == gpu_core.PrimitiveSemantics.Warp: + if not all(aval.shape == () for aval in ctx.avals_in): + raise NotImplementedError( + "Can only select on scalars in warp-level lowering.") [out_aval] = ctx.avals_out - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: pred = _ensure_fa(pred, pred_aval.dtype) cases = _bcast(*cases, *cases_avals, out_aval) # ``select`` expects the first case to be the true branch, but ``select_n`` @@ -1127,7 +1889,7 @@ def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, *cases): return arith_dialect.select(pred, *reversed(cases)) -@register_lowering_rule(lax.broadcast_in_dim_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.broadcast_in_dim_p, mgpu.LoweringSemantics.Lane) def _broadcast_in_dim_lowering_rule( ctx: LoweringRuleContext, x: mgpu.FragmentedArray, @@ -1140,49 +1902,98 @@ def _broadcast_in_dim_lowering_rule( [x_aval] = ctx.avals_in [y_aval] = ctx.avals_out x = _ensure_fa(x, x_aval.dtype) + rank_diff = y_aval.ndim - x_aval.ndim + if (isinstance(x.layout, mgpu.WGSplatFragLayout) and + broadcast_dimensions == tuple(range(rank_diff, rank_diff + x_aval.ndim))): + return x.broadcast(shape) if ( - broadcast_dimensions == tuple(range(x_aval.ndim)) - and y_aval.ndim == x_aval.ndim + 1 - and x.layout == mgpu.WGMMA_ROW_LAYOUT + isinstance(x.layout, mgpu.WGStridedFragLayout) + and broadcast_dimensions == tuple(range(rank_diff, y_aval.ndim)) ): - return x.broadcast_minor(y_aval.shape[-1]) - if broadcast_dimensions: - raise NotImplementedError - return x.broadcast(shape) + new_layout = mgpu.WGStridedFragLayout( + shape=y_aval.shape, vec_size=x.layout.vec_size + ) + return x.broadcast_in_dim(y_aval.shape, broadcast_dimensions, new_layout) + if not isinstance(layout := x.layout, mgpu.TiledLayout): + raise NotImplementedError(f"Unsupported layout: {x.layout}") + if any(d1 >= d2 for d1, d2 in zip(broadcast_dimensions[:-1], broadcast_dimensions[1:])): + raise NotImplementedError("broadcast_dimensions must be strictly increasing") + new_dims = [d for d in range(y_aval.ndim) if d not in broadcast_dimensions] + if (new_layout := ctx.out_layout_hint) is None: + candidates = [ + mgpu.WGMMA_LAYOUT, + mgpu.WGMMA_TRANSPOSED_LAYOUT, + mgpu.TCGEN05_LAYOUT, + mgpu.TCGEN05_TRANSPOSED_LAYOUT, + tcgen05.TMEM_NATIVE_LAYOUT, + ] + if y_aval.shape[-1] % 16 == 0: + candidates.append(tcgen05.fa_m64_collective_layout(y_aval.shape[-1])) + for candidate in candidates: + if len(candidate.base_tile_shape) != len(shape): + continue + if candidate.reduce(new_dims) == layout: + if new_layout is None: + new_layout = candidate + elif candidate == mgpu.TCGEN05_LAYOUT and new_layout == mgpu.WGMMA_LAYOUT: + continue # Choosing WGMMA_LAYOUT for backwards compatibility. + else: + raise NotImplementedError( + "Multiple options for the layout of the broadcast result (found" + f" at least {new_layout} and {candidate}). Use plgpu.layout_cast" + " on the output to suggest the desired output layout." + ) + if new_layout is None: + raise NotImplementedError( + "No compatible layout found for the broadcast result. Use" + " plgpu.layout_cast on the output to suggest the desired output layout." + ) + return x.broadcast_in_dim(y_aval.shape, broadcast_dimensions, new_layout) -@register_lowering_rule(lax.broadcast_in_dim_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule( + lax.broadcast_in_dim_p, mgpu.LoweringSemantics.Warpgroup) def _broadcast_in_dim_lowering_rule_wg( ctx: LoweringRuleContext, - x: ir.Value, + x, *, broadcast_dimensions, shape, sharding, ): del sharding - if broadcast_dimensions: - raise NotImplementedError [x_aval] = ctx.avals_in - x = _ensure_ir_value(x, x_aval.dtype) - return vector_dialect.splat( - ir.VectorType.get(shape, mgpu_utils.dtype_to_ir_type(x_aval.dtype)), - x, - ) + mlir_type = mgpu_utils.dtype_to_ir_type(x_aval.dtype) + result_ty = ir.VectorType.get(shape, mlir_type) + if not broadcast_dimensions: + # Even though we could implement this case by passing a 0D vector as input + # to mgpu.dialect.BroadcastInDimOp we don't want that. 0D vectors are + # generally problematic and so we avoid them by specializing that case + # directly here. + x = _ensure_ir_value(x, x_aval.dtype) + return vector_dialect.broadcast(result_ty, x) + return mgpu.dialect.broadcast_in_dim(result_ty, x, broadcast_dimensions) -@register_lowering_rule(lax.convert_element_type_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.convert_element_type_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.convert_element_type_p, + mgpu.LoweringSemantics.Lane, gpu_core.PrimitiveSemantics.Warp) def _convert_element_type_lowering_rule( ctx: LoweringRuleContext, x, *, new_dtype, weak_type, sharding ): del weak_type, sharding [x_aval] = ctx.avals_in + if ctx.module_ctx.primitive_semantics == gpu_core.PrimitiveSemantics.Warp: + if x_aval.shape != (): + raise NotImplementedError( + "Non-scalar arithmetic is not supported in warp-level lowering.") return _ensure_fa(x, x_aval.dtype).astype( mgpu_utils.dtype_to_ir_type(new_dtype), is_signed=mgpu_utils.is_signed(new_dtype) ) -@register_lowering_rule(lax.convert_element_type_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule( + lax.convert_element_type_p, mgpu.LoweringSemantics.Warpgroup) def _convert_element_type_lowering_rule_wg( ctx: LoweringRuleContext, x, *, new_dtype, weak_type, sharding ): @@ -1200,10 +2011,10 @@ def _convert_element_type_lowering_rule_wg( if 1 < mgpu_utils.bitwidth(cur_dtype) < 8 or 1 < mgpu_utils.bitwidth(new_dtype) < 8: raise NotImplementedError("Conversion involving sub-byte types unsupported") - from_float = ir.FloatType.isinstance(cur_dtype) - to_float = ir.FloatType.isinstance(new_dtype) - from_integer = ir.IntegerType.isinstance(cur_dtype) - to_integer = ir.IntegerType.isinstance(new_dtype) + from_float = isinstance(cur_dtype, ir.FloatType) + to_float = isinstance(new_dtype, ir.FloatType) + from_integer = isinstance(cur_dtype, ir.IntegerType) + to_integer = isinstance(new_dtype, ir.IntegerType) if from_float and to_float: cur_ty_width = ir.FloatType(cur_dtype).width new_ty_width = ir.FloatType(new_dtype).width @@ -1263,8 +2074,8 @@ def convert(ty, x): maxint = _ir_constant(maxint, cur_dtype) minint = _ir_constant(minint, cur_dtype) if x_aval.shape: - maxint = vector_dialect.splat(x.type, maxint) - minint = vector_dialect.splat(x.type, minint) + maxint = vector_dialect.broadcast(x.type, maxint) + minint = vector_dialect.broadcast(x.type, minint) x = arith_dialect.minimumf(x, maxint) x = arith_dialect.maximumf(x, minint) else: @@ -1274,28 +2085,51 @@ def convert(ty, x): return convert(ty, x) -mosaic_lowering_rules[mgpu.ThreadSemantics.Lane].update({ +mosaic_lowering_rules[gpu_core.LANExWG_SEMANTICS].update({ lax.neg_p: lambda ctx, x: -x, lax.not_p: lambda ctx, x: ~x, }) -mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup].update({ +def _unary_warp_lowering_rule(impl): + def _lowering_rule(ctx: LoweringRuleContext, x): + if not all(aval_in.shape == () for aval_in in ctx.avals_in): + raise NotImplementedError( + "Non-scalar arithmetic is not supported in warp-level lowering.") + return impl(x) + return _lowering_rule + +mosaic_lowering_rules[gpu_core.LANExWARP_SEMANTICS].update({ + lax.neg_p: _unary_warp_lowering_rule(lambda x: -x), + lax.not_p: _unary_warp_lowering_rule(lambda x: ~x) +}) + +mosaic_lowering_rules[gpu_core.WGxWG_SEMANTICS].update({ lax.neg_p: _lower_fun(lambda x: jnp.subtract(0, x), multiple_results=False), lax.not_p: _lower_fun( - lambda x: jnp.bitwise_xor(x, -1), multiple_results=False + lambda x: jnp.astype(jnp.bitwise_xor(jnp.astype(x, int), -1), jnp.dtype(x)), multiple_results=False, ), }) def _binary_op_lowering_rule(ctx: LoweringRuleContext, x, y, *, impl): + if ctx.module_ctx.primitive_semantics == gpu_core.PrimitiveSemantics.Warp: + if not all(aval_in.shape == () for aval_in in ctx.avals_in): + raise NotImplementedError( + "Non-scalar arithmetic is not supported in warp-level lowering.") x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out) return impl(x, y) -mosaic_lowering_rules[mgpu.ThreadSemantics.Lane].update({ +def _div(x, y): + return x / y if isinstance(x.mlir_dtype, ir.FloatType) else x // y + + +for semantics in [gpu_core.LANExWG_SEMANTICS, gpu_core.LANExWARP_SEMANTICS]: + mosaic_lowering_rules[semantics].update({ lax.add_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x + y), lax.sub_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x - y), lax.mul_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x * y), + lax.div_p: partial(_binary_op_lowering_rule, impl=_div), lax.rem_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x % y), lax.and_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x & y), lax.or_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x | y), @@ -1308,8 +2142,7 @@ def _binary_op_lowering_rule(ctx: LoweringRuleContext, x, y, *, impl): lax.ne_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x != y), lax.max_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x.max(y)), lax.min_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x.min(y)), -}) - + }) def _binary_op_lowering_rule_wg( ctx: LoweringRuleContext, x, y, *, ui_impl, si_impl, f_impl=None @@ -1353,7 +2186,7 @@ def _binary_op_lowering_rule_wg( arith_dialect.minimumf, ), ]: - mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup][op] = partial( + mosaic_lowering_rules[gpu_core.WGxWG_SEMANTICS][op] = partial( _binary_op_lowering_rule_wg, si_impl=si_impl, ui_impl=ui_impl, @@ -1372,7 +2205,7 @@ def _binary_boolean_op_lowering_rule_wg( (lax.or_p, arith_dialect.ori), (lax.xor_p, arith_dialect.xori), ]: - mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup][op] = partial( + mosaic_lowering_rules[gpu_core.WGxWG_SEMANTICS][op] = partial( _binary_boolean_op_lowering_rule_wg, impl=impl, ) @@ -1387,7 +2220,7 @@ def _comparison_lowering_rule_wg( x, y = _bcast_wg(x, y, *ctx.avals_in, *ctx.avals_out) if jnp.issubdtype(x_aval, jnp.signedinteger): return arith_dialect.cmpi(si_pred, x, y) - elif jnp.issubdtype(x_aval, jnp.integer) or jnp.issubdtype(x_aval, jnp.bool): + elif jnp.issubdtype(x_aval, jnp.unsignedinteger) or jnp.issubdtype(x_aval, jnp.bool): return arith_dialect.cmpi(ui_pred, x, y) elif jnp.issubdtype(x_aval, jnp.floating): return arith_dialect.cmpf(f_pred, x, y) @@ -1405,35 +2238,52 @@ def _comparison_lowering_rule_wg( (lax.gt_p, CmpIPred.sgt, CmpIPred.ugt, CmpFPred.OGT), (lax.ge_p, CmpIPred.sge, CmpIPred.uge, CmpFPred.OGE), ]: - mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup][op] = partial( + mosaic_lowering_rules[gpu_core.WGxWG_SEMANTICS][op] = partial( _comparison_lowering_rule_wg, si_pred=si_pred, ui_pred=ui_pred, f_pred=f_pred, ) +@register_lowering_rule(lax.integer_pow_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.integer_pow_p, mgpu.LoweringSemantics.Warpgroup) +def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, y): + [x_aval] = ctx.avals_in + if y <= 1: + raise NotImplementedError + + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: + mul_op = operator.mul + elif jnp.issubdtype(x_aval.dtype, jnp.integer): + mul_op = arith_dialect.muli + elif jnp.issubdtype(x_aval.dtype, jnp.floating): + mul_op = arith_dialect.mulf + else: + raise NotImplementedError(f"Unsupported dtype {x_aval.dtype}") -@register_lowering_rule(lax.div_p, mgpu.ThreadSemantics.Lane) -def _div_lowering_rule(ctx: LoweringRuleContext, x, y): - x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out) - if ir.FloatType.isinstance(x.mlir_dtype): - return x / y - return x // y + # Y is an integer. Here we start with res = x so the range is y-1 + res = x + # Repeated doubling algorithm. + for i in reversed(range(y.bit_length() - 1)): + res = mul_op(res, res) + if (y >> i) & 1: + res = mul_op(res, x) + return res -@register_lowering_rule(lax.integer_pow_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.integer_pow_p, mgpu.ThreadSemantics.Warpgroup) -def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, y): - if y != 2: - raise NotImplementedError - return _square_lowering_rule(ctx, x) +@register_lowering_rule(lax.clamp_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.clamp_p, mgpu.LoweringSemantics.Warpgroup) +def _clamp_lowering_rule(ctx: LoweringRuleContext, l, x, u): + return _lower_fun( + lambda l, x, u: lax.min(lax.max(x, l), u), multiple_results=False + )(ctx, l, x, u) -@register_lowering_rule(lax.square_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.square_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.square_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.square_p, mgpu.LoweringSemantics.Warpgroup) def _square_lowering_rule(ctx: LoweringRuleContext, x): [x_aval] = ctx.avals_in - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: x = _ensure_fa(x, x_aval.dtype) return x * x if jnp.issubdtype(x_aval.dtype, jnp.integer): @@ -1443,11 +2293,13 @@ def _square_lowering_rule(ctx: LoweringRuleContext, x): raise NotImplementedError(f"Unsupported dtype {x_aval.dtype}") -@register_lowering_rule(lax.rsqrt_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.rsqrt_p, mgpu.ThreadSemantics.Warpgroup) -def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x): +@register_lowering_rule(lax.rsqrt_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.rsqrt_p, mgpu.LoweringSemantics.Warpgroup) +def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") [x_aval] = ctx.avals_in - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: return _ensure_fa(x, x_aval.dtype).rsqrt(approx=ctx.module_ctx.approx_math) fastmath = ( arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None @@ -1457,11 +2309,13 @@ def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x): ) -@register_lowering_rule(lax.tanh_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.tanh_p, mgpu.ThreadSemantics.Warpgroup) -def _tanh_lowering_rule(ctx: LoweringRuleContext, x): +@register_lowering_rule(lax.tanh_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.tanh_p, mgpu.LoweringSemantics.Warpgroup) +def _tanh_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") [x_aval] = ctx.avals_in - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: return _ensure_fa(x, x_aval.dtype).tanh(approx=ctx.module_ctx.approx_math) fastmath = ( arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None @@ -1469,23 +2323,27 @@ def _tanh_lowering_rule(ctx: LoweringRuleContext, x): return math_dialect.tanh(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath) -def _logistic(x): +def _logistic(x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return 1.0 / (1 + lax.exp(-x)) -mosaic_lowering_rules[mgpu.ThreadSemantics.Lane][lax.logistic_p] = _lower_fun( +mosaic_lowering_rules[gpu_core.LANExWG_SEMANTICS][lax.logistic_p] = _lower_fun( _logistic, multiple_results=False ) -mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup][lax.logistic_p] = ( +mosaic_lowering_rules[gpu_core.WGxWG_SEMANTICS][lax.logistic_p] = ( _lower_fun(_logistic, multiple_results=False) ) -@register_lowering_rule(lax.exp_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.exp_p, mgpu.ThreadSemantics.Warpgroup) -def _exp_lowering_rule(ctx: LoweringRuleContext, x): +@register_lowering_rule(lax.exp_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.exp_p, mgpu.LoweringSemantics.Warpgroup) +def _exp_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") [x_aval] = ctx.avals_in - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: return _ensure_fa(x, x_aval.dtype).exp(approx=ctx.module_ctx.approx_math) fastmath = ( arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None @@ -1493,22 +2351,52 @@ def _exp_lowering_rule(ctx: LoweringRuleContext, x): return math_dialect.exp(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath) -@register_lowering_rule(lax.exp2_p, mgpu.ThreadSemantics.Lane) -def _exp2_lowering_rule(ctx: LoweringRuleContext, x): +@register_lowering_rule(lax.exp2_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.exp2_p, mgpu.LoweringSemantics.Warpgroup) +def _exp2_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") [x_aval] = ctx.avals_in - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: return _ensure_fa(x, x_aval.dtype).exp2(approx=ctx.module_ctx.approx_math) fastmath = ( arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None ) return math_dialect.exp2(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath) +@register_lowering_rule(lax.sin_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.sin_p, mgpu.LoweringSemantics.Warpgroup) +def _sin_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") + [x_aval] = ctx.avals_in + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: + return _ensure_fa(x, x_aval.dtype).sin(approx=ctx.module_ctx.approx_math) + fastmath = ( + arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None + ) + return math_dialect.sin(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath) -@register_lowering_rule(lax.log_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.log_p, mgpu.ThreadSemantics.Warpgroup) -def _log_lowering_rule(ctx: LoweringRuleContext, x): +@register_lowering_rule(lax.cos_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.cos_p, mgpu.LoweringSemantics.Warpgroup) +def _cos_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") [x_aval] = ctx.avals_in - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: + return _ensure_fa(x, x_aval.dtype).cos(approx=ctx.module_ctx.approx_math) + fastmath = ( + arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None + ) + return math_dialect.cos(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath) + +@register_lowering_rule(lax.log_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.log_p, mgpu.LoweringSemantics.Warpgroup) +def _log_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") + [x_aval] = ctx.avals_in + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: return _ensure_fa(x, x_aval.dtype).log(approx=ctx.module_ctx.approx_math) fastmath = ( arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None @@ -1516,48 +2404,224 @@ def _log_lowering_rule(ctx: LoweringRuleContext, x): return math_dialect.log(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath) -@register_lowering_rule(lax.reduce_sum_p, mgpu.ThreadSemantics.Lane) -def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes): +@register_lowering_rule(lax.abs_p, mgpu.LoweringSemantics.Lane) +def _abs_lowering_rule(ctx: LoweringRuleContext, x): + [x_aval] = ctx.avals_in + return _ensure_fa(x, x_aval.dtype).abs() + + +@register_lowering_rule(lax.abs_p, mgpu.LoweringSemantics.Warpgroup) +def _abs_lowering_rule_wg(ctx: LoweringRuleContext, x): + [x_aval] = ctx.avals_in + x = _ensure_ir_value(x, x_aval.dtype) + if jnp.issubdtype(x_aval.dtype, jnp.floating): + return math_dialect.absf(x) + if jnp.issubdtype(x_aval.dtype, jnp.integer): + return math_dialect.absi(x) + raise NotImplementedError(f"Unsupported dtype for abs: {x_aval.dtype}") + + +@register_lowering_rule(lax.round_p, mgpu.LoweringSemantics.Lane) +def _round_lowering_rule(ctx: LoweringRuleContext, x, rounding_method): + [x_aval] = ctx.avals_in + x = _ensure_fa(x, x_aval.dtype) + match rounding_method: + case lax.RoundingMethod.AWAY_FROM_ZERO: + return x.round() + case lax.RoundingMethod.TO_NEAREST_EVEN: + return x.round_even() + case _: + assert_never(rounding_method) + + +@register_lowering_rule(lax.round_p, mgpu.LoweringSemantics.Warpgroup) +def _round_lowering_rule_wg(ctx: LoweringRuleContext, x, rounding_method): + [x_aval] = ctx.avals_in + x = _ensure_ir_value(x, x_aval.dtype) + if not jnp.issubdtype(x_aval.dtype, jnp.floating): + raise NotImplementedError(f"Unsupported dtype for round: {x_aval.dtype}") + match rounding_method: + case lax.RoundingMethod.AWAY_FROM_ZERO: + return math_dialect.round(x) + case lax.RoundingMethod.TO_NEAREST_EVEN: + return math_dialect.roundeven(x) + case _: + assert_never(rounding_method) + +_copysign_p = jax_core.Primitive("_copysign") + + +def _copysign(x1: jax.typing.ArrayLike, x2: jax.typing.ArrayLike) -> jax.Array: + return _copysign_p.bind(x1, x2) + + +@_copysign_p.def_abstract_eval +def _copysign_abstract_eval(x1, x2): + return jax_core.ShapedArray(x2.shape, x2.dtype) + + +@register_lowering_rule(_copysign_p, mgpu.LoweringSemantics.Lane) +def _copysign_lowering_rule(ctx: LoweringRuleContext, x1, x2): + [x1_aval, x2_aval] = ctx.avals_in + x1 = _ensure_fa(x1, x1_aval.dtype) + x2 = _ensure_fa(x2, x2_aval.dtype) + return x1.copysign(x2) + + +@register_lowering_rule(_copysign_p, mgpu.LoweringSemantics.Warpgroup) +def _copysign_lowering_rule(ctx: LoweringRuleContext, x1, x2): + [x1_aval, x2_aval] = ctx.avals_in + x1 = _ensure_ir_value(x1, x1_aval.dtype) + x2 = _ensure_ir_value(x2, x2_aval.dtype) + return math_dialect.copysign(x1, x2) + + +@register_lowering_rule(lax.sign_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.sign_p, mgpu.LoweringSemantics.Warpgroup) +def _sign_lowering_rule(ctx: LoweringRuleContext, x): + def sign(x): + if jnp.issubdtype(x.dtype, jnp.floating): + ones = lax.full(x.shape, 1.0, dtype=x.dtype) + zeros = lax.full(x.shape, 0.0, dtype=x.dtype) + return lax.select(x != 0, _copysign(ones, x), zeros) + if jnp.issubdtype(x.dtype, jnp.signedinteger): + return (x > 0).astype(x.dtype) - (x < 0).astype(x.dtype) + if jnp.issubdtype(x.dtype, jnp.unsignedinteger): + return (x != 0).astype(x.dtype) + raise ValueError(f"Unsupported dtype for sign: {x.dtype}") + + return _lower_fun(sign, multiple_results=False)(ctx, x) + + +@register_lowering_rule(lax.erf_p, mgpu.LoweringSemantics.Lane) +def _erf_lowering_rule(ctx: LoweringRuleContext, x): + [x_aval] = ctx.avals_in + return _ensure_fa(x, x_aval.dtype).erf() + + +@register_lowering_rule(lax.erf_p, mgpu.LoweringSemantics.Warpgroup) +def _erf_lowering_rule_wg(ctx: LoweringRuleContext, x): + [x_aval] = ctx.avals_in + return math_dialect.erf(_ensure_ir_value(x, x_aval.dtype)) + + +@register_lowering_rule(lax.atan2_p, mgpu.LoweringSemantics.Lane) +def _atan2_lowering_rule(ctx: LoweringRuleContext, y, x): + y, x = _bcast(y, x, *ctx.avals_in, *ctx.avals_out) + return y.atan2(x) + + +@register_lowering_rule(lax.atan2_p, mgpu.LoweringSemantics.Warpgroup) +def _atan2_lowering_rule_wg(ctx: LoweringRuleContext, y, x): + y, x = _bcast_wg(y, x, *ctx.avals_in, *ctx.avals_out) + return math_dialect.atan2(y, x) + + +@register_lowering_rule(lax.reshape_p, mgpu.LoweringSemantics.Lane) +def _reshape_lowering_rule( + ctx: LoweringRuleContext, x, new_sizes, dimensions, sharding +): + if dimensions is not None: + raise NotImplementedError("Not implemented: dimensions") + if sharding is not None: + raise NotImplementedError("Not implemented: sharding") + [x_aval] = ctx.avals_in + return _ensure_fa(x, x_aval.dtype).reshape(new_sizes) + + +@register_lowering_rule(lax.reshape_p, mgpu.LoweringSemantics.Warpgroup) +def _reshape_lowering_rule_wg( + ctx: LoweringRuleContext, x, new_sizes, dimensions, sharding +): + if dimensions is not None: + raise NotImplementedError("Not implemented: dimensions") + if sharding is not None: + raise NotImplementedError("Not implemented: sharding") + [x_aval] = ctx.avals_in + x = _ensure_ir_value(x, x_aval.dtype) + if x_aval.ndim == 0: # scalar + res_ty = ir.VectorType.get(new_sizes, x.type) + return vector_dialect.broadcast(res_ty, x) + else: + res_ty = ir.VectorType.get(new_sizes, ir.VectorType(x.type).element_type) + return vector_dialect.shape_cast(res_ty, x) + + +@register_lowering_rule(lax.squeeze_p, mgpu.LoweringSemantics.Lane) +def _squeeze_lowering_rule(ctx: LoweringRuleContext, x, dimensions): + [x_aval] = ctx.avals_in + [y_aval] = ctx.avals_out + return _ensure_fa(x, x_aval.dtype).reshape(y_aval.shape) + + +@register_lowering_rule(lax.squeeze_p, mgpu.LoweringSemantics.Warpgroup) +def _squeeze_lowering_rule_wg(ctx: LoweringRuleContext, x, dimensions): + [x_aval] = ctx.avals_in + [y_aval] = ctx.avals_out + x = _ensure_ir_value(x, x_aval.dtype) + if y_aval.ndim == 0: # scalar + return vector_dialect.extract( + x, dynamic_position=[], static_position=[0] * x_aval.ndim + ) + else: + res_ty = ir.VectorType.get(y_aval.shape, ir.VectorType(x.type).element_type) + return vector_dialect.shape_cast(res_ty, x) + + +def _reduce_lowering_rule(op, ctx: LoweringRuleContext, x, *, axes, **kwargs): [x_aval] = ctx.avals_in match x.layout: case mgpu.WGStridedFragLayout(): if set(axes) != set(range(x_aval.ndim)): raise NotImplementedError("No support for axes yet") + # To relax the restriction below, you need to ensure sufficient + # synchronization with other places that use `scratch_view` (which at the + # time of writing is only `run_scoped`). + if ctx.module_ctx.axis_names.wg is not None: + raise NotImplementedError( + "No support for reduce_sum over all axes and multiple Pallas" + " threads" + ) scratch_ty = jax.ShapeDtypeStruct(shape=(4,), dtype=x_aval.dtype) - with ctx.module_ctx.scratch_view([scratch_ty]) as [scratch]: - return x.reduce_sum(scratch) - case mgpu.WGMMA_LAYOUT: - if axes != (x_aval.ndim - 1,): - raise NotImplementedError - if not jnp.issubdtype(x_aval.dtype, jnp.floating): - raise NotImplementedError - return x.reduce("add", axes[0]) + with ctx.module_ctx.scratch_view(scratch_ty) as scratch: + return x.reduce(op, axes, scratch) + case mgpu.TiledLayout(): + if len(axes) != 1: + raise NotImplementedError("Multi-axis reductions not supported") + reduced_dim = x.layout.tiling.tile_dimension(axes[0]) + if any(reduced_dim[d] for d in x.layout.partitioned_warp_dims): + size = x.layout.vector_length * 128 # a vector per lane in each WG. + if size > REDUCE_SCRATCH_ELEMS: + raise NotImplementedError( + f"Reduce scratch {size=} exceeds max={REDUCE_SCRATCH_ELEMS}" + ) + scratch_ty = jax.ShapeDtypeStruct(shape=(size,), dtype=x_aval.dtype) + ctx = ctx.module_ctx.scratch_view(scratch_ty) + else: + ctx = contextlib.nullcontext(None) + with ctx as scratch: + return x.reduce(op, axes[0], scratch=scratch) case _: raise NotImplementedError(f"Unsupported layout {x.layout}") - -@register_lowering_rule(lax.reduce_max_p, mgpu.ThreadSemantics.Lane) -def _reduce_max_lowering_rule(ctx: LoweringRuleContext, x, *, axes): - [x_aval] = ctx.avals_in - match x.layout: - case mgpu.WGMMA_LAYOUT: - if axes != (x_aval.ndim - 1,): - raise NotImplementedError - if not jnp.issubdtype(x_aval.dtype, jnp.floating): - raise NotImplementedError - return x.reduce("max", axes[0]) - case _: - raise NotImplementedError(f"Unsupported layout {x.layout}") +register_lowering_rule(lax.reduce_sum_p, mgpu.LoweringSemantics.Lane)( + functools.partial(_reduce_lowering_rule, "add") +) +register_lowering_rule(lax.reduce_max_p, mgpu.LoweringSemantics.Lane)( + functools.partial(_reduce_lowering_rule, "max") +) +register_lowering_rule(lax.reduce_min_p, mgpu.LoweringSemantics.Lane)( + functools.partial(_reduce_lowering_rule, "min") +) def _reduce_lowering_rule_wg( - kind: vector_dialect.CombiningKind, - acc: object, ctx: LoweringRuleContext, + kind: vector_dialect.CombiningKind, + acc: int | float, x, - *, axes, -) -> ir.OpView: +) -> ir.Value: [x_aval] = ctx.avals_in [out_aval] = ctx.avals_out x = _ensure_ir_value(x, x_aval.dtype) @@ -1569,96 +2633,234 @@ def _reduce_lowering_rule_wg( x = vector_dialect.shape_cast( ir.VectorType.get([x_aval.size], out_type), x ) - return vector_dialect.ReductionOp(out_type, kind, x) - acc = vector_dialect.splat( + reduction = vector_dialect.ReductionOp(out_type, kind, x) + reduction.attributes["offset"] = ir.IntegerAttr.get( + ir.IntegerType.get_signless(32), ctx.module_ctx.smem_used_bytes + ) + return reduction.result + acc = vector_dialect.broadcast( ir.VectorType.get(out_aval.shape, out_type), _ensure_ir_value(acc, out_aval.dtype), ) - return vector_dialect.MultiDimReductionOp(kind, x, acc, axes) + return vector_dialect.multi_reduction(kind, x, acc, axes) -@register_lowering_rule(lax.reduce_sum_p, mgpu.ThreadSemantics.Warpgroup) -def _reduce_sum_lowering_rule_wg(ctx: LoweringRuleContext, x, *, axes): - op = _reduce_lowering_rule_wg( - vector_dialect.CombiningKind.ADD, 0, ctx, x, axes=axes - ) - op.attributes["offset"] = ir.IntegerAttr.get( - ir.IntegerType.get_signless(32), ctx.module_ctx.smem_used_bytes - ) - return op.result +@register_lowering_rule(lax.reduce_sum_p, mgpu.LoweringSemantics.Warpgroup) +def _reduce_sum_lowering_rule_wg(ctx: LoweringRuleContext, x, *, axes, + out_sharding): + kind = vector_dialect.CombiningKind.ADD + return _reduce_lowering_rule_wg(ctx, kind, 0, x, axes) -@register_lowering_rule(lax.reduce_max_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.reduce_max_p, mgpu.LoweringSemantics.Warpgroup) def _reduce_max_lowering_rule_wg(ctx: LoweringRuleContext, x, *, axes): [x_aval] = ctx.avals_in if jnp.issubdtype(x_aval.dtype, jnp.floating): kind = vector_dialect.CombiningKind.MAXIMUMF acc = float("-inf") - elif jnp.issubdtype(x_aval.dtype, jnp.signedinteger): - kind = vector_dialect.CombiningKind.MAXSI - acc = np.iinfo(x_aval.dtype).max - elif jnp.issubdtype(x_aval.dtype, jnp.unsignedinteger): - kind = vector_dialect.CombiningKind.MAXUI + elif jnp.issubdtype(x_aval.dtype, jnp.integer): + if jnp.issubdtype(x_aval.dtype, jnp.signedinteger): + kind = vector_dialect.CombiningKind.MAXSI + else: + kind = vector_dialect.CombiningKind.MAXUI + acc = np.iinfo(x_aval.dtype).min + else: + raise NotImplementedError(f"Unsupported dtype {x_aval.dtype}") + return _reduce_lowering_rule_wg(ctx, kind, acc, x, axes) + + +@register_lowering_rule(lax.reduce_min_p, mgpu.LoweringSemantics.Warpgroup) +def _reduce_min_lowering_rule_wg(ctx: LoweringRuleContext, x, *, axes): + [x_aval] = ctx.avals_in + if jnp.issubdtype(x_aval.dtype, jnp.floating): + kind = vector_dialect.CombiningKind.MINIMUMF + acc = float("inf") + elif jnp.issubdtype(x_aval.dtype, jnp.integer): + if jnp.issubdtype(x_aval.dtype, jnp.signedinteger): + kind = vector_dialect.CombiningKind.MINSI + else: + kind = vector_dialect.CombiningKind.MINUI acc = np.iinfo(x_aval.dtype).max else: raise NotImplementedError(f"Unsupported dtype {x_aval.dtype}") - return _reduce_lowering_rule_wg(kind, acc, ctx, x, axes=axes).result + return _reduce_lowering_rule_wg(ctx, kind, acc, x, axes) -@register_lowering_rule(lax.axis_index_p, mgpu.ThreadSemantics.Lane) -def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable): - i32 = ir.IntegerType.get_signless(32) - grid_names = ctx.module_ctx.grid_names +@register_lowering_rule(lax.reduce_prod_p, mgpu.LoweringSemantics.Warpgroup) +def _reduce_prod_lowering_rule_wg(ctx: LoweringRuleContext, x, *, axes): + [x_aval] = ctx.avals_in + if jnp.issubdtype(x_aval.dtype, jnp.floating): + acc = 1.0 + elif jnp.issubdtype(x_aval.dtype, jnp.integer): + acc = 1 + else: + raise NotImplementedError(f"Unsupported dtype {x_aval.dtype}") + kind = vector_dialect.CombiningKind.MUL + return _reduce_lowering_rule_wg(ctx, kind, acc, x, axes) + + +def _block_id(ctx: LoweringRuleContext, dim: gpu_dialect.Dimension) -> ir.Value: + result = gpu_dialect.block_id(dim) + cluster_size = ctx.launch_ctx.cluster_size + if math.prod(cluster_size) == 1 or cluster_size[dim.value] == 1: + return result + # We scale the grid in the presence of clusters, so we need to scale the + # block ID back here. + return arith_dialect.divui(result, _as_index(cluster_size[dim.value])) + + +def _resolve_cluster_axis(axis_names: _AxisNames | None, axis_name: str): + if not axis_names: + raise LookupError( + "No axis names are available. Make sure you are using `pl.core_map`" + " with a `plgpu.Mesh`." + ) + if not axis_names or axis_name not in axis_names.cluster: + raise LookupError( + f"Unknown cluster axis {axis_name}, available axes:" + f" {[*axis_names.cluster]}" + ) + return gpu_dialect.Dimension(axis_names.cluster.index(axis_name)) + + +def _is_block_local_scope(collective_axes: CollectiveAxesType, + axis_names: _AxisNames): + """Returns whether the collective axes represents a block scope.""" + if axis_names.wg is None: + return not collective_axes + else: + return collective_axes == (axis_names.wg,) + + +def _is_global_scope(collective_axes: CollectiveAxesType, + axis_names: _AxisNames): + """Returns whether the collective axes represents a GPU global scope.""" + return set(collective_axes) == set(axis_names) + +def block_id_to_grid_id(ctx: LoweringRuleContext, + block_ids: Sequence[ir.Value], + axis_name: Hashable): squashed_dims = ctx.module_ctx.squashed_dims + axis_names = ctx.module_ctx.axis_names if squashed_dims: - unsquashed_names = grid_names[-3:] - squashed_names = grid_names[:-3] + unsquashed_names = axis_names.grid[:2] + squashed_names = axis_names.grid[2:] else: # These are unused but initialized for type checkers. - unsquashed_names = () - squashed_names = () - if grid_names and axis_name in grid_names: - if axis_name == grid_names[-1]: - return mgpu.warpgroup_idx(sync=True) + unsquashed_names = squashed_names = () + + if squashed_dims: + if axis_name in unsquashed_names: + # We reversed the grid and cluster axes. + # e.g. for the grid (a, b, c, d, wg) + # squashed = (a, b) Mapped to Dimension.z (2) + # unsquashed = (c, d) Mapped to Dimension.y (1) and Dimension.x (0) + idx = unsquashed_names.index(axis_name) + return block_ids[gpu_dialect.Dimension(idx)] else: - if squashed_dims: - if axis_name in unsquashed_names: - # We add 1 to the index because the first dimension is the - # squashed dimension. - # e.g. for the grid (a, b, c, d, wg) - # squashed = (a, b) Mapped to Dimension.x (0) - # unsquashed = (c, d) Mapped to Dimension.y (1) and Dimension.z (2) - idx = unsquashed_names.index(axis_name) + 1 - return arith_dialect.index_cast( - i32, - gpu_dialect.block_id(gpu_dialect.Dimension(idx)), - ) - elif axis_name in squashed_names: - # All squashed dimensions are mapped to Dimension.x. - block_id = gpu_dialect.block_id(gpu_dialect.Dimension.x) - axis = squashed_names.index(axis_name) - return _unravel_program_id(block_id, axis, squashed_dims) - else: - if axis_name in grid_names: - idx = grid_names.index(axis_name) - return arith_dialect.index_cast( - i32, - gpu_dialect.block_id(gpu_dialect.Dimension(idx)), - ) - raise ValueError( - "Named axes can only refer to GPUMesh axes in Mosaic GPU kernels" - ) + assert axis_name in squashed_names + # All squashed dimensions are mapped to Dimension.z. + axis = squashed_names.index(axis_name) + return _unravel_program_id( + _as_index(block_ids[gpu_dialect.Dimension.z]), axis, squashed_dims + ) + else: + assert axis_name in axis_names.grid + idx = axis_names.grid.index(axis_name) + return block_ids[gpu_dialect.Dimension(idx)] -@register_lowering_rule(primitives.debug_print_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.axis_index_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.axis_index_p, mgpu.LoweringSemantics.Lane, gpu_core.PrimitiveSemantics.Warp) +@register_lowering_rule(lax.axis_index_p, mgpu.LoweringSemantics.Warpgroup) +def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable): + if ctx.module_ctx.primitive_semantics == gpu_core.PrimitiveSemantics.Warp: + if axis_name == ctx.module_ctx.warp_axis_name: + return mgpu.warp_idx(sync=True) + raise ValueError( + "Named axes can only refer to the warp axis name inside of core_map." + ) + gpu_axis_names = ctx.module_ctx.axis_names + jax_axis_names = getattr(ctx.module_ctx.mesh_info, "axis_names", ()) + if gpu_axis_names is None and not jax_axis_names: + raise LookupError( + "No axis names are available. Make sure you are using `pl.core_map`" + " with a `plgpu.Mesh` or an appropriate JAX device mesh." + ) + if axis_name not in itertools.chain(gpu_axis_names or (), jax_axis_names): + raise LookupError( + f"Axis {axis_name} does not refer to a GPU mesh axis (available axes:" + f" {[*gpu_axis_names]}) or a JAX mesh axis (available axes:" + f" {[*jax_axis_names]})" + ) + if axis_name in jax_axis_names: + jax_mesh = ctx.module_ctx.mesh_info + assert jax_mesh is not None + device_id = ctx.launch_ctx.device_id() + jax_mesh_shape = jax_mesh.mesh_shape + axis_index = jax_axis_names.index(axis_name) + i32 = ir.IntegerType.get_signless(32) + axis_size = _ir_constant(jax_mesh_shape[axis_index], i32) + minor_divisor = _ir_constant( + np.prod(jax_mesh_shape[axis_index + 1 :], dtype=np.int32), i32 + ) + return arith_dialect.remsi(arith_dialect.divsi(device_id, minor_divisor), axis_size) + + # We already checked that the axis is in scope and it wasn't a JAX mesh axis. + assert gpu_axis_names is not None + + # We only deal with GPU axes from now on. + axis_names = gpu_axis_names + if axis_names.wg is not None and axis_name == axis_names.wg: + return mgpu.warpgroup_idx(sync=True) + + if axis_name in axis_names.cluster: + return arith_dialect.index_cast( + ir.IntegerType.get_signless(32), + gpu_dialect.cluster_block_id( + gpu_dialect.Dimension(axis_names.cluster.index(axis_name)) + ), + ) + block_ids = tuple(arith_dialect.index_cast( + ir.IntegerType.get_signless(32), + _block_id(ctx, dimension), + ) for dimension in gpu_dialect.Dimension) + return block_id_to_grid_id(ctx, block_ids, axis_name) + + +@register_lowering_rule(debugging.debug_print_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule( + debugging.debug_print_p, + mgpu.LoweringSemantics.Lane, + gpu_core.PrimitiveSemantics.Warp, +) +@register_lowering_rule( + debugging.debug_print_p, mgpu.LoweringSemantics.Warpgroup +) def _debug_print_lowering_rule( ctx: LoweringRuleContext, *args, fmt, - has_placeholders: bool, + ordered, + partitioned, + in_tree, + static_args, + np_printoptions, + has_placeholders, + logging_record, ): - del has_placeholders # Unused. + del partitioned, np_printoptions, has_placeholders + if ordered: + raise NotImplementedError("Ordered debug_print is not supported on Pallas.") + args, kwargs = debugging.merge_callback_args(in_tree, args, static_args) + if kwargs: + raise ValueError( + "Only positional arguments are supported by debug_print on Pallas." + ) primitives.check_debug_print_format(fmt, *args) + scope = mgpu.ThreadSubset.WARPGROUP + if ctx.module_ctx.primitive_semantics == gpu_core.PrimitiveSemantics.Warp: + scope = mgpu.ThreadSubset.WARP if not any(aval.shape for aval in ctx.avals_in): mgpu.debug_print( fmt, @@ -1666,10 +2868,15 @@ def _debug_print_lowering_rule( _ensure_ir_value(arg, aval.dtype) for arg, aval in zip(args, ctx.avals_in) ), + scope=scope ) elif len(ctx.avals_in) == 1: [arg] = args - arg.debug_print(fmt) + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Warpgroup: + mgpu.dialect.debug_print(fmt, arg) + else: + arg.debug_print(fmt) + else: raise NotImplementedError( "debug_print only supports printing of scalar values, or a single array" @@ -1678,93 +2885,192 @@ def _debug_print_lowering_rule( return () -@register_lowering_rule(primitives.debug_print_p, mgpu.ThreadSemantics.Warpgroup) -def _debug_print_lowering_rule( - ctx: LoweringRuleContext, - *args, - fmt, - has_placeholders: bool, -): - del ctx, has_placeholders # Unused. - if args: - raise NotImplementedError("debug_print only supports string messages in warpgroup semantics") - mgpu.debug_print(fmt) - return () - -@register_lowering_rule(primitives.run_scoped_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(primitives.run_scoped_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(primitives.run_scoped_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(primitives.run_scoped_p, mgpu.LoweringSemantics.Warpgroup) def _run_scoped_lowering_rule( - ctx: LoweringRuleContext, *consts, jaxpr: jax_core.Jaxpr + ctx: LoweringRuleContext, *consts, jaxpr: jax_core.Jaxpr, collective_axes ): input_refs = [] should_discharge = [] - alloc_stack = contextlib.ExitStack() - for v in jaxpr.invars: - aval = v.aval - if isinstance(aval, gpu_core.WGMMAAbstractAccumulatorRef): - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Warpgroup: - # TODO(bchetioui): Fix this and remove the NotImplementedError. - raise NotImplementedError( - "WGMMA accumulators are not supported with Warpgroup semantics." - ) - mlir_dtype = mlir.dtype_to_ir_type(aval.dtype) - input_refs.append(mgpu.WGMMAAccumulator.zero(*aval.shape, mlir_dtype)) - should_discharge.append(True) - elif isinstance(aval.dtype, gpu_core.BarrierType): - input_refs.append( - ctx.module_ctx.reserve_barrier( - mgpu.Barrier( - aval.dtype.num_arrivals - * ctx.estimator_ctx.arrival_multiplier, - *aval.shape, - ) + wg_axis = ctx.module_ctx.axis_names.wg + is_multithreaded = wg_axis is not None + is_thread_collective = is_multithreaded and collective_axes == (wg_axis,) + # Make sure everyone has exited previous scoped allocations. Note that we + # don't synchronize when we exit the allocation, but only when we might want + # to reuse its memory again. + if collective_axes and collective_axes != (wg_axis,): + raise ValueError( + "Only thread-collective allocations are supported in run_scoped." + ) + if is_multithreaded and is_thread_collective: + gpu_dialect.barrier() + with contextlib.ExitStack() as alloc_stack: + for v in jaxpr.invars: + aval = cast(ShapedAbstractValue, v.aval) + if isinstance(aval, gpu_core.WGMMAAbstractAccumulatorRef): + if collective_axes: + raise ValueError( + "WGMMA accumulators can only be allocated non-collectively. Hint:" + " remove collective_axes from run_scoped. If other allocations" + " are performed as well, split the run_scoped into two." ) - ) - should_discharge.append(False) - elif aval.memory_space == gpu_core.SMEM: - [input_ref] = alloc_stack.enter_context( - ctx.module_ctx.scratch_view( - [jax.ShapeDtypeStruct(shape=aval.shape, dtype=aval.dtype)] + is_signed = mgpu_utils.is_signed(aval.dtype) + if is_signed is not None and not is_signed: + raise ValueError( + "Invalid WGMMA accumulator dtype for s8/i8 WGMMA. " + f"Expected signed integer, but got {aval.dtype}." + ) + + dtype = mlir.dtype_to_ir_type(aval.dtype) + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: + input_refs.append( + mgpu.WGMMAAccumulator.zero(*aval.shape, dtype, is_signed=is_signed) + ) + else: + if isinstance(dtype, ir.IntegerType): + zero = arith_dialect.constant(dtype, 0) + else: + zero = arith_dialect.constant(dtype, 0.0) + acc = vector_dialect.broadcast( + ir.VectorType.get(aval.shape, dtype), zero ) + acc = mgpu.dialect.optimization_barrier([acc]) + nvvm_dialect.wgmma_fence_aligned() + input_refs.append(acc) + should_discharge.append(True) + continue + if ( + isinstance(aval, state_types.AbstractRef) + and aval.memory_space == gpu_core.GMEM + and jnp.issubdtype(aval.dtype, pallas_core.semaphore) + ): + input_ref = alloc_stack.enter_context( + ctx.module_ctx.reserve_semaphores( + aval.shape, collective_axes=collective_axes + ) + ) + input_refs.append(input_ref) + should_discharge.append(False) + continue + + # All other allocations must be made collectively across all threads. + if is_multithreaded and not is_thread_collective: + raise NotImplementedError( + "Only thread-collective allocations are supported in multithreaded" + " kernels. Hint: add" + f" collective_axes={ctx.module_ctx.axis_names.wg} to your" + " run_scoped if you intend all threads to share the same" + f" allocation (currently collective_axes={collective_axes})." + ) + if isinstance(aval.dtype, gpu_core.BarrierType): + multiplier = (1 if aval.dtype.orders_tensor_core else + ctx.estimator_ctx.arrival_multiplier) + barrier_ref = alloc_stack.enter_context( + ctx.module_ctx.reserve_barrier( + mgpu.Barrier( + aval.dtype.num_arrivals * multiplier, + *aval.shape, + ) + ) + ) + input_refs.append(barrier_ref) + should_discharge.append(False) + continue + if isinstance(aval.dtype, gpu_core.ClusterBarrierType): + collective_dims = jax.tree.map( + lambda axis: _resolve_cluster_axis(ctx.module_ctx.axis_names, axis), + aval.dtype.collective_axes, + ) + barrier_ref = alloc_stack.enter_context( + ctx.module_ctx.reserve_barrier( + mgpu.ClusterBarrier(collective_dims, aval.dtype.num_arrivals, *aval.shape) + ) + ) + input_refs.append(barrier_ref) + should_discharge.append(False) + continue + + if not isinstance(aval, state_types.AbstractRef): + raise ValueError(f"Can't convert to ref: {aval}") + if aval.memory_space == gpu_core.SMEM: + input_ref = alloc_stack.enter_context( + ctx.module_ctx.scratch_view( + jax.ShapeDtypeStruct(shape=aval.shape, dtype=aval.dtype) + ) + ) + input_refs.append(input_ref) + should_discharge.append(False) + elif aval.memory_space == gpu_core.TMEM: + input_ref = alloc_stack.enter_context( + ctx.module_ctx.alloc_tmem( + jax.ShapeDtypeStruct(shape=aval.shape, dtype=aval.dtype), + layout=aval.layout, + ) + ) + input_refs.append(input_ref) + should_discharge.append(False) + + if any(should_discharge): + # We convert consts to args, because we only have ir.Values and + # not JAX values during lowering. discharge_state() produces JAX + # valiues for the arguments but expects them to be provided for the + # consts. We also don't want to wrap the values in refs. + no_const_jaxpr = pe.convert_constvars_jaxpr(jaxpr) + should_discharge = [False] * len(consts) + should_discharge + discharged_jaxpr, _ = discharge.discharge_state(no_const_jaxpr, (), should_discharge=should_discharge) + new_input_vals = (*consts, *input_refs) + outs = lower_jaxpr_to_mosaic_gpu( + ctx.module_ctx, + ctx.launch_ctx, + discharged_jaxpr, + new_input_vals, + (), ) - input_refs.append(input_ref) - should_discharge.append(False) + # Discharge appends to the output the refs that got discharged. + outs = outs[:-sum(should_discharge)] else: - raise ValueError(f"Can't convert to ref: {aval}") - - if any(should_discharge): - # We convert consts to args, because we only have ir.Values and - # not JAX values during lowering. discharge_state() produces JAX - # valiues for the aguments but expects them to be provided for the - # consts. We also don't want to wrap the values in refs. - no_const_jaxpr = pe.convert_constvars_jaxpr(jaxpr) - should_discharge = [False] * len(consts) + should_discharge - discharged_jaxpr, _ = discharge.discharge_state(no_const_jaxpr, (), should_discharge=should_discharge) - new_input_vals = consts + tuple(input_refs) - outs = lower_jaxpr_to_mosaic_gpu( - ctx.module_ctx, - ctx.launch_ctx, - discharged_jaxpr, - new_input_vals, - (), - ) - # Discharge appends to the output the refs that got discharged. - outs = outs[:-sum(should_discharge)] - else: - outs = lower_jaxpr_to_mosaic_gpu( - ctx.module_ctx, - ctx.launch_ctx, - jaxpr, - input_refs, - consts, - ) + outs = lower_jaxpr_to_mosaic_gpu( + ctx.module_ctx, + ctx.launch_ctx, + jaxpr, + input_refs, + consts, + ) assert len(outs) == len(jaxpr.outvars), (jaxpr, outs) return outs -@register_lowering_rule(discharge.run_state_p, mgpu.ThreadSemantics.Lane) +@_register_resource_estimator(primitives.get_global_p) +def _get_global_resource_estimator( + ctx: ResourceEstimatorContext, *, what +) -> Resources: + if what.memory_space == gpu_core.GMEM and jnp.issubdtype( + what.dtype, pallas_core.semaphore + ): + collective_axes = tuple(ctx.axis_names) + return Resources(scoped_gmem_semaphores={collective_axes: what.size}) + raise NotImplementedError(f"get_global only supports semaphores, got {what}") + + +@register_lowering_rule(primitives.get_global_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule( + primitives.get_global_p, mgpu.LoweringSemantics.Warpgroup +) +def _get_global_lowering_rule(ctx: LoweringRuleContext, *, what): + if what.memory_space == gpu_core.GMEM and jnp.issubdtype( + what.dtype, pallas_core.semaphore + ): + collective_axes = tuple(ctx.module_ctx.axis_names) + return ctx.module_ctx.reserve_semaphores( + what.shape, collective_axes=collective_axes + ).__enter__() + raise NotImplementedError(f"get_global only supports semaphores, got {what}") + + +@register_lowering_rule(discharge.run_state_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(discharge.run_state_p, mgpu.LoweringSemantics.Warpgroup) def _run_state_lowering_rule( ctx: LoweringRuleContext, *args, @@ -1782,7 +3088,12 @@ def _run_state_lowering_rule( for arg, v, out_aval in zip(args, jaxpr.invars, ctx.avals_out): aval = v.aval if isinstance(aval, gpu_core.WGMMAAbstractAccumulatorRef): - new_input_vals.append(mgpu.WGMMAAccumulator.from_registers(arg)) + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Warpgroup: + arg = mgpu.dialect.optimization_barrier([arg]) + nvvm_dialect.wgmma_fence_aligned() + new_input_vals.append(arg) + else: + new_input_vals.append(mgpu.WGMMAAccumulator.from_registers(arg)) should_discharge.append(True) assert isinstance(out_aval, jax_core.ShapedArray) else: @@ -1817,12 +3128,12 @@ def _lower_jaxpr_to_for_loop( ctx: LoweringRuleContext, jaxpr: jax_core.Jaxpr, start: ir.Value, - length: ir.Value, + length: int | ir.Value, consts, *args, has_loop_index: bool, + unroll: int | None = None, ): - _consts_avals, arg_avals = util.split_list(ctx.avals_in, [len(consts)]) arg_avals = arg_avals[has_loop_index:] out_avals = [] @@ -1836,28 +3147,61 @@ def as_values(vals, avals): _ensure = ( _ensure_fa - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane else _ensure_ir_value ) - return [v if a else _ensure(v, av) for a, v, av in zip(is_acc, vals, avals)] - - @mgpu.fori(length, as_values(args, arg_avals)) - def loop(loop_index, body_args): - if has_loop_index: - loop_index = arith_dialect.addi(loop_index, start) - jaxpr_args = [*consts, loop_index, *body_args] - else: - jaxpr_args = [*consts, *body_args] - outs = lower_jaxpr_to_mosaic_gpu( - ctx.module_ctx, ctx.launch_ctx, jaxpr, jaxpr_args - ) + return [ + v if a else _ensure(v, av.dtype) + for a, v, av in zip(is_acc, vals, avals) + ] + + def loop(base_loop_index, body_args): + outs = body_args + if unroll is not None: + base_loop_index = arith_dialect.muli( + base_loop_index, _ir_constant(unroll, start.type) + ) + base_loop_index = arith_dialect.addi(base_loop_index, start) + for step in range(unroll or 1): + if has_loop_index: + loop_index = arith_dialect.addi( + base_loop_index, _ir_constant(step, start.type) + ) + jaxpr_args = [*consts, loop_index, *outs] + else: + jaxpr_args = [*consts, *outs] + outs = lower_jaxpr_to_mosaic_gpu( + ctx.module_ctx, ctx.launch_ctx, jaxpr, jaxpr_args + ) return as_values(outs, out_avals) - return loop.results + if unroll is not None: + if not isinstance(length, int): + raise NotImplementedError( + "``length`` must be an integer when ``unroll` is specified, got" + f" {length}" + ) + if length % unroll: + # TODO(slebedev): Emit an epilogue taking care of the remaining steps. + raise NotImplementedError( + f"``unroll`` must divide ``length``, got {unroll=} and {length=}" + ) + if unroll == length: + # Special-case: the loop is fully unrolled. + return loop(_ir_constant(0, start.type), as_values(args, arg_avals)) + return mgpu.fori( + _ir_constant(length // unroll, start.type), as_values(args, arg_avals) + )(loop).results + else: + if not isinstance(length, ir.Value): + length = _ir_constant(length, start.type) + return mgpu.fori(length, as_values(args, arg_avals))(loop).results -@register_lowering_rule(lax.scan_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.scan_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.scan_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.scan_p, mgpu.LoweringSemantics.Warpgroup) +@register_lowering_rule(lax.scan_p, mgpu.LoweringSemantics.Lane, + gpu_core.PrimitiveSemantics.Warp) def _scan_lowering_rule( ctx: LoweringRuleContext, *args, @@ -1871,13 +3215,9 @@ def _scan_lowering_rule( _split_transpose: bool, ): # Can only handle fori_loop-like scans. - if ( - (num_extensive := len(args) - num_consts - num_carry) - or reverse - or unroll != 1 - ): + if (num_extensive := len(args) - num_consts - num_carry) or reverse: raise NotImplementedError - del linear, num_extensive, reverse, unroll + del linear, num_extensive, reverse jaxpr, jaxpr_consts = jaxpr.jaxpr, jaxpr.consts if jaxpr_consts: @@ -1893,17 +3233,24 @@ def _scan_lowering_rule( start, *args = args index_aval, *_ = arg_avals start: ir.Value = _ensure_ir_value(start, index_aval.dtype) - length = _ir_constant(length, start.type) else: start = _i32_constant(0) - length = _i32_constant(length) + for_out = _lower_jaxpr_to_for_loop( - ctx, jaxpr, start, length, consts, *args, has_loop_index=has_loop_index + ctx, + jaxpr, + start, + length, + consts, + *args, + has_loop_index=has_loop_index, + unroll=unroll, ) if has_loop_index: # Need to return the final loop index value if the outer scan expects # it as an output. - return [length, *for_out] + loop_index = arith_dialect.addi(start, _ir_constant(length, start.type)) + return [loop_index, *for_out] return for_out @@ -1945,8 +3292,9 @@ def _lower_while_via_fori( return ub, ub, *for_out -@register_lowering_rule(lax.while_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.while_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.while_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.while_p, *gpu_core.LANExWARP_SEMANTICS) +@register_lowering_rule(lax.while_p, mgpu.LoweringSemantics.Warpgroup) def _while_lowering_rule( ctx: LoweringRuleContext, *args, @@ -1970,7 +3318,7 @@ def _while_lowering_rule( _is_acc = lambda x: isinstance(x, mgpu.WGMMAAccumulator) _ensure = _ensure_ir_value - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: _ensure = lambda v, aval: v if _is_acc(v) else _ensure_fa(v, aval.dtype) # If we fail conversion to fori, fallback to an ordinary while loop. @@ -2004,56 +3352,71 @@ def _while_lowering_rule( ctx.module_ctx, ctx.launch_ctx, body_jaxpr.jaxpr, body_args ) loop_out = [*map(_ensure, loop_out, carry_avals)] - for idx, (carry_fa, out_fa) in enumerate(zip(carry, loop_out)): - if _is_acc(carry_fa) != _is_acc(out_fa): - raise ValueError( - f"The loop body output has unexpected accumulator type: output[{idx}]" - f" is {out_fa}, when it should be {carry_fa}." - ) + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: + for idx, (carry_fa, out_fa) in enumerate(zip(carry, loop_out)): + if _is_acc(carry_fa) != _is_acc(out_fa): + raise ValueError( + f"The loop body output has unexpected accumulator type:" + f" output[{idx}] is {out_fa}, when it should be {carry_fa}." + ) - if not _is_acc(out_fa) and carry_fa.layout != out_fa.layout: - raise ValueError( - f"The loop body output has unexpected layout: output[{idx}] has" - f" layout {out_fa.layout}, when it should be {carry_fa.layout}." - ) + if not _is_acc(out_fa) and carry_fa.layout != out_fa.layout: + raise ValueError( + f"The loop body output has unexpected layout: output[{idx}] has" + f" layout {out_fa.layout}, when it should be {carry_fa.layout}." + ) scf_dialect.yield_( carry_treedef.flatten_up_to(loop_out) if loop_out else [] ) return carry_treedef.unflatten(list(while_op.results)) -@register_lowering_rule(lax.cond_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.cond_p, mgpu.ThreadSemantics.Warpgroup) -def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches): +@register_lowering_rule(lax.cond_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.cond_p, + mgpu.LoweringSemantics.Lane, gpu_core.PrimitiveSemantics.Warp) +@register_lowering_rule(lax.cond_p, mgpu.LoweringSemantics.Warpgroup) +def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches, + **params): + if params: + raise NotImplementedError("platform_dependent cond") index_aval, *_arg_avals = ctx.avals_in def _yielded_values(outs, avals): ret = [] for out, aval in zip(outs, avals): - if isinstance(out, mgpu.FragmentedArray): + if isinstance(out, (mgpu.WGMMAAccumulator, mgpu.FragmentedArray)): ret.append(out) else: ret.append(_ensure_ir_value(out, aval.dtype)) return ret - # We need the branch return mlir types in order to construct the - # switch operation. To avoid leaking information about what kind of - # mlir types are internal to FragmentedArrays and other mgpu types, - # we run one of the branches in a dummy module that we throw away to - # extract the return types + # We need to know the result types ahead of time to construct the switch + # operation. Below we lower the first branch in a throw-away module to + # extract them. with ir.InsertionPoint(ir.Module.create().body): outs = lower_jaxpr_to_mosaic_gpu( ctx.module_ctx, ctx.launch_ctx, branches[0].jaxpr, args ) - yielded_types = [v.type for v in jax.tree.leaves(_yielded_values(outs, ctx.avals_out))] + yielded_types = [ + v.type for v in jax.tree.leaves(_yielded_values(outs, ctx.avals_out)) + ] del outs - switch_op = scf_dialect.IndexSwitchOp( - yielded_types, - _as_index(_ensure_ir_value(index, index_aval.dtype)), - ir.DenseI64ArrayAttr.get(range(len(branches) - 1)), - num_caseRegions=len(branches) - 1, - ) + # TODO(apaszke): Remove once minimal jaxlib is 0.8.2 + idx_switch_params = inspect.signature(scf_dialect.IndexSwitchOp).parameters + if (mlir_compat := "num_caseRegions" in idx_switch_params): + switch_op = scf_dialect.IndexSwitchOp( + yielded_types, + _as_index(_ensure_ir_value(index, index_aval.dtype)), + ir.DenseI64ArrayAttr.get(range(len(branches) - 1)), + num_caseRegions=len(branches) - 1, + ) + else: + switch_op = scf_dialect.IndexSwitchOp( + yielded_types, + _as_index(_ensure_ir_value(index, index_aval.dtype)), + range(len(branches) - 1), + ) # ``RegionSequence`` in MLIR does not support slicing, so the # auto-generated Python bindings for ``caseRegions`` fail at runtime! @@ -2063,7 +3426,8 @@ def _yielded_values(outs, avals): regions = regions[1:] + regions[:1] treedef = None for branch, region in zip(branches, regions): - with ir.InsertionPoint(region.blocks.append()): + block = region.blocks.append() if mlir_compat else region.blocks[0] + with ir.InsertionPoint(block): outs = lower_jaxpr_to_mosaic_gpu( ctx.module_ctx, ctx.launch_ctx, branch.jaxpr, args, consts=branch.consts ) @@ -2080,9 +3444,9 @@ def _yielded_values(outs, avals): return treedef.unflatten(list(switch_op.results)) -@register_lowering_rule(lax.bitcast_convert_type_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.bitcast_convert_type_p, mgpu.LoweringSemantics.Lane) @register_lowering_rule( - lax.bitcast_convert_type_p, mgpu.ThreadSemantics.Warpgroup + lax.bitcast_convert_type_p, mgpu.LoweringSemantics.Warpgroup ) def _bitcast_convert_type_lowering_rule( ctx: LoweringRuleContext, x, *, new_dtype @@ -2098,34 +3462,90 @@ def _bitcast_convert_type_lowering_rule( " have different widths" ) - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Warpgroup: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Warpgroup: x = _ensure_ir_value(x, x_aval.dtype) return arith_dialect.bitcast( ir.VectorType.get(x_aval.shape, dst_elem_type), x ) x = _ensure_fa(x, x_aval.dtype) - if ir.IntegerType.isinstance(dst_elem_type): - output_is_signed = mgpu_utils.is_signed(new_dtype) - else: - output_is_signed = None + output_is_signed = mgpu_utils.is_signed(new_dtype) return mgpu.FragmentedArray.bitcast( x, dst_elem_type, output_is_signed=output_is_signed ) -@register_lowering_rule(lax.optimization_barrier_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.optimization_barrier_p, mgpu.LoweringSemantics.Lane) def _optimization_barrier_lowering(ctx: LoweringRuleContext, *args): - args = (_ensure_fa(arg, aval.dtype) for arg, aval in zip(args, ctx.avals_in)) - return mgpu.optimization_barrier(*args) + result = mgpu.optimization_barrier( + *(_ensure_fa(arg, aval.dtype) for arg, aval in zip(args, ctx.avals_in)) + ) + return (result,) if len(ctx.avals_in) == 1 else result + + +@register_lowering_rule( + lax.optimization_barrier_p, mgpu.LoweringSemantics.Warpgroup +) +def _optimization_barrier_lowering_wg(ctx: LoweringRuleContext, *args): + result = mgpu.dialect.optimization_barrier([ + _ensure_ir_value(arg, aval.dtype) for arg, aval in zip(args, ctx.avals_in) + ]) + return (result,) if len(ctx.avals_in) == 1 else result + + +@register_lowering_rule(pallas_core.core_map_p, mgpu.LoweringSemantics.Lane) +def _core_map_lowering_rule( + ctx: LoweringRuleContext, + *args, + jaxpr, + mesh, + **_, +): + if isinstance(mesh, gpu_core.WarpMesh): + # A core_map over a WarpMesh represents a fork/join over individual + # warps in a warpgroup. + if (ctx.module_ctx.warp_axis_name or + ctx.module_ctx.primitive_semantics == gpu_core.PrimitiveSemantics.Warp): + raise LoweringError( + "Cannot nest core_maps. Already under core_map with warp_axis_name " + f"{ctx.module_ctx.warp_axis_name}.") + module_ctx = dataclasses.replace( + ctx.module_ctx, + warp_axis_name=mesh.axis_name, + primitive_semantics=gpu_core.PrimitiveSemantics.Warp, + ) + for aval_in in ctx.avals_in: + if isinstance(aval_in, jax_core.ShapedArray) and aval_in.shape: + raise LoweringError( + "Can only close over scalars and Refs when using core_map with " + f"WarpMesh. Found array of shape {aval_in}." + ) + # We allow the warps to schedule async copies without synchronizing with + # other warps, so we need to add a barrier here to make sure all reads and + # writes have completed. + if ctx.module_ctx.auto_barriers: + mgpu.warpgroup_barrier() + _ = lower_jaxpr_to_mosaic_gpu( + module_ctx, + ctx.launch_ctx, + jaxpr, + args=(), + consts=args, + ) + if ctx.module_ctx.auto_barriers: + # We need to ensure that any effects produced by one warp + # (e.g. async copies) are observable by all other warps. + mgpu.warpgroup_barrier() + return [] + raise ValueError(f"Unsupported mesh: {mesh}") def _bcast( - x: ir.Value, - y: ir.Value, - x_aval: jax_core.ShapedArray, - y_aval: jax_core.ShapedArray, - out_aval: jax_core.ShapedArray, + x: Any, + y: Any, + x_aval: ShapedAbstractValue, + y_aval: ShapedAbstractValue, + out_aval: ShapedAbstractValue, ) -> tuple[mgpu.FragmentedArray, mgpu.FragmentedArray]: if not isinstance(x, mgpu.FragmentedArray): x_dtype = x_aval.dtype @@ -2154,11 +3574,11 @@ def _ensure_fa(x: object, dtype: jnp.dtype) -> mgpu.FragmentedArray: def _bcast_wg( - x: object, - y: object, - x_aval: jax_core.ShapedArray, - y_aval: jax_core.ShapedArray, - out_aval: jax_core.ShapedArray, + x: Any, + y: Any, + x_aval: ShapedAbstractValue, + y_aval: ShapedAbstractValue, + out_aval: ShapedAbstractValue, ) -> tuple[ir.Value, ir.Value]: """Ensures that ``x`` and ``y`` have the expected shapes and dtypes. @@ -2178,17 +3598,17 @@ def _bcast_wg( if y_aval.weak_type: y_dtype = x_aval.dtype y = _ensure_ir_value(y, y_dtype) - if not ir.VectorType.isinstance(x.type): + if not isinstance(x.type, ir.VectorType): assert not x_aval.shape - x = vector_dialect.splat( + x = vector_dialect.broadcast( ir.VectorType.get(out_aval.shape, mgpu_utils.dtype_to_ir_type(x_dtype)), x, ) elif x_aval.shape != out_aval.shape: raise NotImplementedError("Unsupported broadcast") - if not ir.VectorType.isinstance(y.type): + if not isinstance(y.type, ir.VectorType): assert not y_aval.shape - y = vector_dialect.splat( + y = vector_dialect.broadcast( ir.VectorType.get(out_aval.shape, mgpu_utils.dtype_to_ir_type(y_dtype)), y, ) @@ -2197,10 +3617,10 @@ def _bcast_wg( return x, y -def _ensure_ir_value(x: object, dtype: jnp.dtype) -> ir.Value: +def _ensure_ir_value(x: Any, dtype: jnp.dtype) -> ir.Value: if isinstance(x, ir.Value): mlir_dtype = mgpu_utils.dtype_to_ir_type(dtype) - if ir.VectorType.isinstance(x.type): + if isinstance(x.type, ir.VectorType): assert ir.VectorType(x.type).element_type == mlir_dtype else: assert x.type == mlir_dtype, (x.type, mlir_dtype) @@ -2213,8 +3633,19 @@ def _ensure_ir_value(x: object, dtype: jnp.dtype) -> ir.Value: return _ir_constant(x, mgpu_utils.dtype_to_ir_type(dtype)) +def _ensure_ir_value_device_id(device_id: Any) -> ir.Value: + ensure_i32 = functools.partial(_ensure_ir_value, dtype=jnp.int32) + if isinstance(device_id, tuple): + return tuple(map(ensure_i32, device_id)) + if isinstance(device_id, dict): + return {k: ensure_i32(v) for k, v in device_id.items()} + return ensure_i32(device_id) + + def _ir_constant(v: object, t: ir.Type) -> ir.Value: - if isinstance(v, (np.number, np.ndarray, int, float)): + if isinstance( + v, (np.number, np.ndarray, int, float, literals.TypedNdArray) + ): if isinstance(t, (ir.IntegerType, ir.IndexType)): v = int(v) else: @@ -2240,12 +3671,16 @@ def _as_index(v: object) -> ir.Value: match v: case int(): return arith_dialect.constant(ir.IndexType.get(), v) - case ir.Value() if ir.IndexType.isinstance(v.type): + case ir.Value() if isinstance(v.type, ir.IndexType): return v - case ir.Value() if ir.IntegerType.isinstance(v.type): + case ir.Value() if isinstance(v.type, ir.IntegerType): return arith_dialect.index_cast(ir.IndexType.get(), v) case mgpu.FragmentedArray(layout=mgpu.WGSplatFragLayout()): return _as_index(v.registers.item()) + case literals.TypedNdArray() if ( + np.issubdtype(v.dtype, np.integer) and v.ndim == 0 + ): + return arith_dialect.constant(ir.IndexType.get(), int(v)) case _: raise ValueError(f"Unsupported index: {v} of type {type(v)}") @@ -2269,20 +3704,28 @@ def merge_indexers( if indexer.int_indexer_shape: raise NotImplementedError() - def _ensure_idx_fa(x): + def _ensure_idx_fa(x: Any) -> mgpu.FragmentedArray: i32 = ir.IntegerType.get_signless(32) if isinstance(x, ir.Value): # TODO(cperivol): We assume all indices are signed. We should # look at the JAX avals to see if the integers are signed or # not to figure out is_signed. - is_signed = False if ir.IntegerType.isinstance(x.type) else None - return mgpu.FragmentedArray.splat( - x, (), is_signed=is_signed - ).astype(i32, is_signed=False) + is_signed = False if isinstance(x.type, ir.IntegerType) else None + return mgpu.FragmentedArray.splat(x, (), is_signed=is_signed).astype( + i32, is_signed=False + ) if isinstance(x, mgpu.FragmentedArray): return x.astype(i32, is_signed=False) if isinstance(x, int): return mgpu.FragmentedArray.splat(mgpu.c(x, i32), (), is_signed=False) + if ( + isinstance(x, literals.TypedNdArray) + and x.ndim == 0 + and np.issubdtype(x.dtype, np.signedinteger) + ): + return mgpu.FragmentedArray.splat( + mgpu.c(int(x), i32), (), is_signed=False + ) raise NotImplementedError(x) num_skipped = 0 @@ -2313,3 +3756,157 @@ def _ensure_idx_fa(x): shape=root_shape, int_indexer_shape=(), ) + + +@register_lowering_rule(primitives.semaphore_read_p, mgpu.LoweringSemantics.Lane) +def _semaphore_read_lowering_rule(ctx: LoweringRuleContext, *args, args_tree): + sem, transforms = tree_util.tree_unflatten(args_tree, args) + sem, transforms = _handle_transforms(ctx, sem, transforms) + if transforms: + raise NotImplementedError(f"Unhandled transforms for semaphore_read: {transforms}") + sem_ptr = mgpu.utils.memref_ptr(sem) + i32_ty = ir.IntegerType.get_signless(32) + result = llvm_dialect.inline_asm( + i32_ty, + [sem_ptr], + "ld.acquire.sys.u32 $0,[$1];", + "=r,l", + has_side_effects=True, + ) + return _ensure_fa(result, jnp.int32) + + +@register_lowering_rule(primitives.semaphore_signal_p, mgpu.LoweringSemantics.Lane) +def _semaphore_signal_lowering_rule( + ctx: LoweringRuleContext, + *args, + args_tree, + device_id_type, +): + i32 = ir.IntegerType.get_signless(32) + sem, transforms, value, device_id, core_index = tree_util.tree_unflatten( + args_tree, args + ) + if core_index is not None: + raise NotImplementedError( + "Mosaic GPU backend does not support the concept of cores, but" + " core_index is specified" + ) + sem, transforms = _handle_transforms(ctx, sem, transforms) + if transforms: + raise NotImplementedError(f"Unhandled transforms for semaphore_signal: {transforms}") + sem_ptr = mgpu.utils.memref_ptr(sem) + if device_id is not None: + device_id, other_axes = primitives.device_id_to_logical( + ctx.module_ctx.mesh_info, + _ensure_ir_value_device_id(device_id), + device_id_type, + lambda name: _axis_index_rule(ctx, axis_name=name), + ) + if other_axes: + raise NotImplementedError( + f"Only JAX mesh axes can be used in device_id, but found {other_axes}" + ) + sem_ptr = ctx.launch_ctx.to_remote(sem_ptr, device_id) + # TODO(apaszke): Narrow the scope from .sys to .gpu when the semaphore is local. + val = _ir_constant(value, i32) + # We only signal the semaphore from a single lane, which does not guarantee + # anything about the state of the other three warps in the warpgroup (they + # might still be e.g. reading memory that someone will overwrite once they + # receive a signal). + if ctx.module_ctx.auto_barriers: + mgpu.utils.warpgroup_barrier() + mgpu_utils.SemaphoreRef(sem_ptr).signal( + val, predicate=ctx.module_ctx.single_wg_lane_predicate + ) + return () + + +@register_lowering_rule(primitives.semaphore_wait_p, mgpu.LoweringSemantics.Lane) +def _semaphore_wait_lowering_rule(ctx: LoweringRuleContext, *args, args_tree): + sem, transforms, value, decrement = tree_util.tree_unflatten(args_tree, args) + sem, transforms = _handle_transforms(ctx, sem, transforms) + if transforms: + raise NotImplementedError( + f"Unhandled transforms for semaphore_wait: {transforms}" + ) + mgpu_utils.SemaphoreRef(mgpu.utils.memref_ptr(sem)).wait( + _ensure_ir_value(value, jnp.int32), decrement=decrement + ) + return () + + +@register_lowering_rule(checkify.check_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(checkify.check_p, mgpu.LoweringSemantics.Warpgroup) +def _check_lowering_rule(ctx: LoweringRuleContext, *err_args, err_tree, debug): + if not debug: + raise NotImplementedError( + "Non-debug checks are not supported by the Mosaic GPU backend." + " Functionalize them via `jax.experimental.checkify`." + ) + if not pallas_helpers.debug_checks_enabled(): + return [] + + error = jax.tree.unflatten(err_tree, err_args) + [pred] = error._pred.values() + [exception_tree] = error._metadata.values() + [payload] = error._payload.values() + exception = jax.tree.unflatten(exception_tree, payload) + assert isinstance(exception, checkify.FailedCheckError) + + # check_p has an inverted predicate compared to assert, so we need to compute + # ``not pred`` here. + minus_one = _ir_constant(-1, mgpu_utils.dtype_to_ir_type(jnp.bool)) + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: + pred = pred.registers.item() + not_pred = arith_dialect.xori(pred, minus_one) + cf_dialect.assert_(not_pred, exception.fmt_string) + return [] + +@register_lowering_rule(gpu_core.layout_cast_p, mgpu.LoweringSemantics.Lane) +def _layout_cast_lowering(ctx: LoweringRuleContext, x, *, new_layout): + del ctx # Unused. + return x.to_layout(new_layout.to_mgpu()) + + +@register_lowering_rule(gpu_core.layout_cast_p, mgpu.LoweringSemantics.Warpgroup) +def _layout_cast_lowering_wg( + ctx: LoweringRuleContext, x, *, new_layout +): + del ctx # Unused. + return mgpu.dialect.layout_cast(x, mgpu.to_layout_attr(new_layout.to_mgpu())) + + +@register_lowering_rule(lax.iota_p, mgpu.LoweringSemantics.Lane) +def _iota_lowering( + ctx: LoweringRuleContext, dtype, shape, dimension, sharding +): + del sharding # Unused. + if ctx.out_layout_hint is None: + raise RuntimeError( + "Failed to infer the output layout of the iota. Please apply" + " plgpu.layout_cast to its output right after its creation." + ) + mlir_dtype = mgpu_utils.dtype_to_ir_type(dtype) + is_signed = mgpu_utils.is_signed(dtype) + return mgpu.FragmentedArray.broadcasted_iota( + mlir_dtype, shape, dimension, ctx.out_layout_hint, is_signed=is_signed + ) + + +@register_lowering_rule(lax.iota_p, mgpu.LoweringSemantics.Warpgroup) +def _iota_lowering_wg( + ctx: LoweringRuleContext, dtype, shape, dimension, sharding +): + del ctx, sharding + result_type = ir.VectorType.get(shape, mgpu_utils.dtype_to_ir_type(dtype)) + return mgpu.dialect.broadcasted_iota(result_type, dimension) + + +@register_lowering_rule(primitives.delay_p, mgpu.LoweringSemantics.Lane) +def _delay_lowering(ctx: LoweringRuleContext, nanos): + del ctx # Unused. + if not isinstance(nanos, ir.Value): + nanos = _i32_constant(nanos) + mgpu.nanosleep(nanos) + return [] diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py index d506349fe101..eba20f7b75e1 100644 --- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py @@ -19,15 +19,21 @@ import os import time -from typing import Any +from typing import cast import warnings import jax +from jax._src import config from jax._src import core as jax_core +from jax._src import frozen_dict +from jax._src import sharding_impls from jax._src.interpreters import mlir from jax._src.pallas import core as pallas_core +from jax._src.pallas.mosaic_gpu import core as gpu_core from jax._src.pallas.mosaic_gpu import lowering -import jax.experimental.mosaic.gpu.core as mosaic_core +from jax.experimental.mosaic import gpu as mgpu +import jax.numpy as jnp +import numpy as np def pallas_call_lowering( @@ -39,10 +45,13 @@ def pallas_call_lowering( input_output_aliases: tuple[tuple[int, int], ...], grid_mapping: pallas_core.GridMapping, mesh: pallas_core.Mesh | None, - compiler_params: dict[str, Any], + compiler_params: dict[str, pallas_core.CompilerParams], cost_estimate: pallas_core.CostEstimate | None, out_avals: tuple[jax_core.AbstractValue, ...], + metadata: frozen_dict.FrozenDict[str, str] | None, + name: str | None, ): + del metadata, name # TODO(sharadmv): Add metadata to HLO. debug_info = jaxpr.debug_info del interpret, out_avals if grid_mapping.num_dynamic_grid_bounds: @@ -56,45 +65,69 @@ def pallas_call_lowering( print(f"The grid mapping for pallas_call {debug_info.func_src_info}:") print(grid_mapping) - thread_semantics = compiler_params.get("mosaic_gpu", {}).get( - "thread_semantics", mosaic_core.ThreadSemantics.Lane - ) - if thread_semantics == mosaic_core.ThreadSemantics.Warpgroup: - mosaic_core.dialect.register_dialect(ctx.module_context.context) # pytype: disable=attribute-error + mgpu.dialect.register_dialect(ctx.module_context.context) # pytype: disable=attribute-error - lowering_result = lowering.lower_pipelined_jaxpr_to_module( - grid_mapping, - mesh, - jaxpr, - compiler_params, - cost_estimate, - ) + if "mosaic_gpu" in compiler_params: + params = cast(gpu_core.CompilerParams, compiler_params["mosaic_gpu"]) + else: + params = gpu_core.CompilerParams() + + jax_mesh = None + axis_context = ctx.module_context.axis_context + if axis_context is not None: + if isinstance(axis_context, sharding_impls.SPMDAxisContext): + jax_mesh = axis_context.mesh + + # TODO(slebedev): Remove this once the ensure-debug-info-scope-on-llvm-func + # pass correctly handles full tracebacks. + with config.include_full_tracebacks_in_locations(False): + lowering_result = lowering.lower_pipelined_jaxpr_to_module( + grid_mapping, mesh, jax_mesh, jaxpr, params, cost_estimate + ) if debug: print(f"\nThe Mosaic GPU module for pallas_call {debug_info.func_src_info}:") print(lowering_result.module.operation) module = lowering_result.module - new_avals_out = [ - jax_core.ShapedArray(t.shape, t.dtype) for t in lowering_result.out_structs - ] - outs = mosaic_core._mosaic_gpu_lowering_rule( - ctx.replace(avals_out=new_avals_out), - *args, + new_avals_in = list(ctx.avals_in) + new_avals_out = list(map(_as_shaped_array, lowering_result.new_out_shapes)) + scratch_args = () + if lowering_result.gmem_scratch_shapes: + # The new_out_shapes contain the original outputs first, followed by the + # GMEM scratch shapes, and optionally the profiler buffer. + input_output_aliases += tuple( + (len(ctx.avals_in) + i, len(ctx.avals_out) + i) + for i in range(len(lowering_result.gmem_scratch_shapes)) + ) + # The GMEM scratch is an aliased kernel input/output. + new_avals_in.extend(map(_as_shaped_array, lowering_result.gmem_scratch_shapes)) + # We guarantee zero-initialization of the GMEM scratch at the moment, which + # is important for semaphores. + def zero_init_gmem_scratch(): + return [jnp.zeros_like(s) for s in lowering_result.gmem_scratch_shapes] + scratch_args = mlir.lower_fun( + zero_init_gmem_scratch, multiple_results=True + )(ctx.replace(avals_in=())) + outs = mgpu.core._mosaic_gpu_lowering_rule( + ctx.replace(avals_in=new_avals_in, avals_out=new_avals_out), + *args, *scratch_args, module=module, - out_types=lowering_result.out_structs, + out_types=lowering_result.new_out_shapes, + inout_types=(), input_output_aliases=input_output_aliases, + # False until we add get_barrier_semaphore() feature. + use_custom_barrier=False, ) - if (prof_ctx := lowering_result.profiler_context) is not None: + if (prof_spec := lowering_result.profiler_spec) is not None: *outs, prof_buffer = outs - if (dump_path := prof_ctx.dump_path) == "sponge": - dump_path = os.getenv("TEST_UNDECLARED_OUTPUTS_DIR") # type: ignore out_file = os.path.join( - dump_path, f"{mlir.sanitize_name(debug_info.func_name)}-{time.time_ns()}-trace.json" + prof_spec.dump_path, + f"{mlir.sanitize_name(debug_info.func_name)}-{time.time_ns()}-trace.json", ) def dump_profile(prof_buffer): try: with open(out_file, "x") as f: - prof_ctx.spec.dump( + prof_spec.dump( prof_buffer, f, grid=lowering_result.grid, @@ -111,4 +144,10 @@ def do_callback(prof_buffer): mlir.lower_fun(do_callback, multiple_results=True)( ctx.replace(avals_in=(new_avals_out[-1],)), prof_buffer ) + if lowering_result.gmem_scratch_shapes: # Drop the GMEM scratch. + outs = outs[:-len(lowering_result.gmem_scratch_shapes)] return outs + + +def _as_shaped_array(t: jax.ShapeDtypeStruct) -> jax_core.ShapedArray: + return jax_core.ShapedArray(t.shape, np.dtype(t.dtype)) diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index a48fec61b7af..4c8b49cbce60 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -18,61 +18,103 @@ from collections.abc import Callable, Sequence import dataclasses +import enum import functools import itertools as it import math -from typing import Any +from typing import Any, Protocol, TypeVar, Union, cast import jax from jax import api_util from jax import lax from jax._src import core from jax._src import linear_util as lu +from jax._src import state from jax._src import util from jax._src.interpreters import partial_eval as pe from jax._src.pallas import core as pallas_core from jax._src.pallas.mosaic_gpu import core as gpu_core from jax._src.pallas.mosaic_gpu import primitives as gpu_primitives -from jax._src.util import foreach from jax.experimental import pallas as pl import jax.numpy as jnp map = util.safe_map zip = util.safe_zip +T = TypeVar('T') +BlockSpecPytree = Sequence[Union[pl.BlockSpec, "BlockSpecPytree"]] +AbstractRefPytree = Sequence[Union[state.AbstractRef, "AbstractRefPytree"]] + +def _get_block_size( + bd: pl.Blocked | pl.Element | pl.Squeezed | pl.BoundedSlice | int | None, +) -> int: + match bd: + case int(): + return bd + case pl.Blocked(block_size): + return block_size + case _: + raise NotImplementedError(f"Unsupported block size type: {type(bd)}") + +def _get_block_shape(spec: pallas_core.BlockSpec): + if spec.block_shape is None: + raise ValueError("Block shape must be specified.") + + block_shape = tuple( + _get_block_size(bd) + for bd in spec.block_shape + if not (bd is None or isinstance(bd, pl.Squeezed)) + ) + return block_shape + + +map_brefs = functools.partial( + jax.tree.map, is_leaf=lambda x: isinstance(x, BufferedRef) +) + @jax.tree_util.register_dataclass @dataclasses.dataclass(frozen=True) class BufferedRef: - spec: pallas_core.BlockSpec = dataclasses.field(metadata={"static": True}) + spec: gpu_core.BlockSpec = dataclasses.field(metadata={"static": True}) is_index_invariant: bool = dataclasses.field(metadata={"static": True}) - gmem_ref: pallas_core.AbstractMemoryRef + gmem_ref: state.AbstractRef # ``None`` if the ref is pinned to GMEM; otherwise, has shape # [num_slots, *spec.block_shape]. - smem_ref: pallas_core.AbstractMemoryRef | None + smem_ref: state.AbstractRef | None def get_ref_for_slot( self, slot: int | jax.Array - ) -> pallas_core.AbstractMemoryRef: + ) -> state.AbstractRef: if self.smem_ref is None: return self.gmem_ref return self.smem_ref.at[slot] - def compute_gmem_slice(self, grid_indices) -> tuple[pl.Slice, ...]: + def compute_gmem_slice(self, grid_indices) -> tuple[pl.Slice | jax.Array, ...]: index_map = self.spec.index_map assert index_map is not None + assert self.spec.block_shape is not None # We don't allow Python scalars here, because they are interpreted # differently depending on the x32/x64 mode. assert all(i.dtype == jnp.dtype(jnp.int32) for i in grid_indices) + + def _make_block_slice(block_index: jax.Array, bd: pl.BlockDim | int | None): + match bd: + case int(): + return pl.Slice(block_index * bd, bd) + case pl.Blocked(block_size): + return pl.Slice(block_index * block_size, block_size) + case None | pl.Squeezed(): + return block_index + case _: + raise ValueError(f"Unsupported block dimension type: {bd}") + return tuple( - pl.Slice(idx * size, size) # type: ignore[arg-type] - for idx, size in zip( - index_map(*grid_indices), self.spec.block_shape # type: ignore[arg-type] - ) + map(_make_block_slice, index_map(*grid_indices), self.spec.block_shape) ) - def copy_in(self, slot, grid_indices, barrier_ref): + def copy_in(self, slot, grid_indices, barrier_ref, barrier_slot=None): if not _in_smem(self.spec): return assert self.smem_ref is not None @@ -80,7 +122,8 @@ def copy_in(self, slot, grid_indices, barrier_ref): gpu_primitives.copy_gmem_to_smem( self.gmem_ref.at[gmem_slices], # pytype: disable=unsupported-operands self.smem_ref.at[slot], # pytype: disable=unsupported-operands - barrier_ref.at[slot], + barrier_ref.at[barrier_slot if barrier_slot is not None else slot], + collective_axes=getattr(self.spec, "collective_axes", ()), ) def copy_out(self, slot, grid_indices, predicate=None): @@ -102,7 +145,7 @@ def _uses_arguments( if not num_args: return () - jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( + jaxpr, _, _ = pe.trace_to_jaxpr_dynamic( lu.wrap_init( index_map, debug_info=api_util.debug_info("pallas index_map", @@ -115,7 +158,7 @@ def _uses_arguments( def _is_index_invariant( - spec: pallas_core.BlockSpec, grid: pallas_core.StaticGrid + spec: pallas_core.BlockSpec, grid: pallas_core.TupleGrid ) -> bool: if (index_map := spec.index_map) is None: return True @@ -123,7 +166,7 @@ def _is_index_invariant( def _inc_grid_by_1( - indices: tuple[jax.Array, ...], grid: Sequence[int] + indices: tuple[jax.Array, ...], grid: pallas_core.TupleGrid ) -> tuple[jax.Array, ...]: next_indices = [] carry: bool | jax.Array = True @@ -139,69 +182,92 @@ def _inc_grid_by_1( def _in_smem(spec: pallas_core.BlockSpec) -> bool: return spec.memory_space in (None, gpu_core.SMEM) - -# ``pl.Slice`` uses a different pytree encoding, depending on whether the -# start/size are static or dynamic. This leads to pytree structure mismatch -# in the pipeline body. So, we define a different ``Slice`` class below. - - -@dataclasses.dataclass(frozen=True) -class _Slice: - start: int | jax.Array - size: int | jax.Array - - def __eq__(self, other: _Slice) -> jax.Array: # type: ignore - return lax.bitwise_and(self.start == other.start, self.size == other.size) - - -jax.tree_util.register_dataclass( - _Slice, data_fields=["start", "size"], meta_fields=[] -) +def _downcast_spec( + spec: gpu_core.BlockSpec | pallas_core.BlockSpec, +) -> gpu_core.BlockSpec: + if isinstance(spec, gpu_core.BlockSpec): + return spec + + return gpu_core.BlockSpec( + block_shape=spec.block_shape, + index_map=spec.index_map, + memory_space=spec.memory_space, + pipeline_mode=spec.pipeline_mode, + ) def emit_pipeline( - body: Callable[..., None], + body: Callable[..., T], *, - grid: pallas_core.StaticGrid, + grid: pallas_core.TupleGrid, in_specs: Sequence[pallas_core.BlockSpec] = (), out_specs: Sequence[pallas_core.BlockSpec] = (), max_concurrent_steps: int = 1, - delay_release: int = 0, + init_carry: T | None = None, ): - """Creates a function to emit a manual pipeline within a Pallas kernel. + r"""Creates a function to emit a manual pipeline within a Pallas kernel. Args: - body: The pipeline body. - grid: The grid to use for the pipeline. - in_specs: The block specs for the inputs. - out_specs: The block specs for the outputs. - max_concurrent_steps: The maximum number of sequential stages that are - active concurrently. Defaults to 1. - delay_release: The number of steps to wait before reusing the input/output - references. Defaults to 0, and must be strictly smaller than - ``max_concurrent_steps``. Generally, you'll want to set it to 1 if you - don't await the WGMMA in the body. + body: The pipeline body function, which is called with + + - ``indices``: Tuple of current loop indices. + - ``*input_refs``: SMEM refs for inputs. + - ``*output_refs``: SMEM refs for outputs. + + If ``init_carry`` is provided, ``body`` receives an additional argument + ``carry`` -- the carry from the previous iteration. It must then return + the next carry value. + grid: The grid dimensions for the pipeline. + in_specs: A sequence of :class:`~jax.experimental.pallas.BlockSpec`\s + for inputs. + out_specs: A sequence of :class:`~jax.experimental.pallas.BlockSpec`\s + for outputs. + max_concurrent_steps: Maximum concurrently active pipeline stages. + init_carry: Optional initial carry. If provided, ``body`` handles + carry-over state between iterations, and the pipeline returns the + final carry. + + Returns: + A function that, when called with GMEM input and output refs, executes the + pipeline and returns the final carry value (if ``init_carry`` was used), + otherwise it returns None. """ - num_steps = math.prod(grid) - if max_concurrent_steps <= delay_release: + in_specs = tuple(map(_downcast_spec, in_specs)) + out_specs = tuple(map(_downcast_spec, out_specs)) + for spec in in_specs: + if spec.collective_axes: + raise NotImplementedError( + "BlockSpecs with collective_axes are not supported in emit_pipeline" + ) + for spec in out_specs: + if spec.collective_axes: + raise ValueError("Output BlockSpecs cannot have collective_axes") + # TODO(justinfu): Factor out common code between warp-specialized and + # normal pipelines. + delay_release_levels = sorted({s.delay_release for s in in_specs}) or [0] + if delay_release_levels and max_concurrent_steps <= delay_release_levels[0]: raise ValueError( - "max_concurrent_steps must be greater than delay_release, but" - f" {max_concurrent_steps=}, {delay_release=}" + "max_concurrent_steps must be greater than all delay_release values," + f" but {max_concurrent_steps=} and {delay_release_levels=}." ) + num_steps = math.prod(grid) + has_dynamic_grid = not isinstance(num_steps, int) + # Convert the grid to int32 explicitly to avoid dtype promotion errors. + grid = tuple(jnp.asarray(g, dtype=jnp.int32) for g in grid) + # Shrink ``max_concurrent_steps`` if the total number of steps is lower to # reduce the size of the refs allocated in SMEM. - if max_concurrent_steps > num_steps: - max_concurrent_steps = num_steps - delay_release = 0 # No need to delay anything. + if not has_dynamic_grid and max_concurrent_steps > num_steps: + max_concurrent_steps = cast(int, num_steps) - def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): + def pipeline(*gmem_refs: state.AbstractRef): in_gmem_refs, out_gmem_refs = util.split_list(gmem_refs, [len(in_specs)]) in_smem_refs, out_smem_refs = util.split_list( [ gpu_core.SMEM( - (max_concurrent_steps, *spec.block_shape), # type: ignore + (max_concurrent_steps, *_get_block_shape(spec)), # type: ignore ref.dtype, transforms=tuple( t.batch(1) for t in getattr(spec, "transforms", ()) @@ -213,6 +279,7 @@ def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): ], [len(in_specs)], ) + num_arrivals = sum(map(_in_smem, in_specs)) return pl.run_scoped( functools.partial( scoped_pipeline, @@ -221,9 +288,11 @@ def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): ), in_smem_refs=in_smem_refs, out_smem_refs=out_smem_refs, - barrier_ref=gpu_core.Barrier( + barrier_ref=None + if num_arrivals == 0 + else gpu_core.Barrier( # TODO(slebedev): Change this to arrive only once. - sum(map(_in_smem, in_specs)), + num_arrivals=num_arrivals, num_barriers=max_concurrent_steps, ), ) @@ -244,51 +313,65 @@ def scoped_pipeline( ) ] - for step, indices in enumerate( - it.islice(it.product(*map(range, grid)), max_concurrent_steps) - ): - indices = tuple(map(lambda i: jnp.asarray(i, dtype=jnp.int32), indices)) - foreach(lambda bref: bref.copy_in(step, indices, barrier_ref), in_brefs) + # Initialize the pipeline. + indices = (jnp.asarray(0, dtype=jnp.int32),) * len(grid) + if has_dynamic_grid: + prologue_steps = lax.min(max_concurrent_steps, num_steps) + else: + assert max_concurrent_steps <= num_steps + prologue_steps = max_concurrent_steps + + def prologue(step, fetch_indices): + for bref in in_brefs: + bref.copy_in(step, fetch_indices, barrier_ref) + return _inc_grid_by_1(fetch_indices, grid) + jax.lax.fori_loop(0, prologue_steps, prologue, indices, unroll=not has_dynamic_grid) # This is true if any of the outputs need to be transferred inside the loop. - copies_out_in_loop = not all(bref.is_index_invariant for bref in out_brefs) + smem_out_brefs = [bref for bref in out_brefs if _in_smem(bref.spec)] + copies_out_in_loop = not all(bref.is_index_invariant for bref in smem_out_brefs) + needs_epilogue = any(bref.is_index_invariant for bref in smem_out_brefs) + # In the loop body, `max_concurrent_steps` may be larger than `num_steps` in + # the dynamic grid case. This is fine, since in that case, we will never + # need to fetch more data anyway. def loop_body(step, carry): slot = lax.rem(step, max_concurrent_steps) - indices, fetch_indices, last_store_slices = carry + indices, fetch_index_levels, last_store_indices, prev_body_carry = carry - if in_specs: - # Wait for the current GMEM->SMEM copy to complete. + if barrier_ref is not None: + # Wait for the current GMEM->SMEM copy to complete, if any. gpu_primitives.barrier_wait(barrier_ref.at[slot]) # Wait for the previous output SMEM->GMEM copy to complete. if copies_out_in_loop: gpu_primitives.wait_smem_to_gmem( - max_concurrent_steps - (1 + delay_release), wait_read_only=True + max_concurrent_steps - 1, wait_read_only=True ) - with pallas_core.grid_env(map(pallas_core.GridAxis, indices, grid)): - body(*( - bref.get_ref_for_slot(slot) - for bref in it.chain(in_brefs, out_brefs) - )) + next_body_carry = body( + indices, + *( + bref.get_ref_for_slot(slot) + for bref in it.chain(in_brefs, out_brefs) + ), + *(prev_body_carry,) if init_carry is not None else (), + ) if copies_out_in_loop: gpu_primitives.commit_smem() # Copy the output from SMEM to GMEM. - new_store_slices = last_store_slices[:] + new_store_indices = last_store_indices[:] for idx, bref in enumerate(out_brefs): if bref.is_index_invariant: - assert last_store_slices[idx] is None + assert last_store_indices[idx] is None continue - assert last_store_slices[idx] is not None - new_store_slices[idx] = tuple( - _Slice(s.start, s.size) for s in bref.compute_gmem_slice(indices) - ) + assert last_store_indices[idx] is not None + new_store_indices[idx] = bref.spec.index_map(*indices) are_same_slices = map( lambda old, new: old == new, - last_store_slices[idx], - new_store_slices[idx], + last_store_indices[idx], + new_store_indices[idx], ) slices_changed = ~functools.reduce(lax.bitwise_and, are_same_slices) is_last_step = step == num_steps - 1 @@ -301,90 +384,158 @@ def loop_body(step, carry): predicate=lax.bitwise_or(slices_changed, is_last_step), ) - gpu_primitives.commit_smem_to_gmem_group() - - fetch_step = step + (max_concurrent_steps - delay_release) - fetch_slot = lax.rem(fetch_step, max_concurrent_steps) - - def do_fetch(): - for bref in in_brefs: - bref.copy_in(fetch_slot, fetch_indices, barrier_ref) + if copies_out_in_loop: + gpu_primitives.commit_smem_to_gmem_group() - jax.lax.cond( - lax.bitwise_and(step >= delay_release, fetch_step < num_steps), - do_fetch, - lambda: None, - ) + for delay_release, fetch_indices in zip( + delay_release_levels, fetch_index_levels + ): + fetch_step = step + (max_concurrent_steps - delay_release) + fetch_slot = lax.rem(fetch_step, max_concurrent_steps) + + # pylint: disable=cell-var-from-loop + def do_fetch(): + for bref in in_brefs: + if bref.spec.delay_release == delay_release: + bref.copy_in(fetch_slot, fetch_indices, barrier_ref) + # pylint: enable=cell-var-from-loop + + jax.lax.cond( + lax.bitwise_and(step >= delay_release, fetch_step < num_steps), + do_fetch, + lambda: None, + ) + next_fetch_indices_levels = [ + _inc_grid_by_1(fetch_indices, grid) + for fetch_indices in fetch_index_levels + ] return ( _inc_grid_by_1(indices, grid), - _inc_grid_by_1(fetch_indices, grid), - new_store_slices, + next_fetch_indices_levels, + new_store_indices, + next_body_carry if init_carry is not None else None, ) - # Invariant: ``indices`` and ``fetch_indices`` are always - # ``max_concurrent_steps-delay_release`` apart. - indices = (jnp.asarray(0, dtype=jnp.int32),) * len(grid) - fetch_indices = indices - for _ in range(max_concurrent_steps-delay_release): - fetch_indices = _inc_grid_by_1(fetch_indices, grid) + fetch_index_levels = [] + for delay_release in delay_release_levels: + fetch_indices = indices + for _ in range(max_concurrent_steps - delay_release): + fetch_indices = _inc_grid_by_1(fetch_indices, grid) + fetch_index_levels.append(fetch_indices) + # TODO(justinfu): Only store base pointer instead of all indices. - last_store_slices = [ + last_store_indices = [ None if bref.is_index_invariant - else (_Slice(-1, -1),) * len(bref.spec.block_shape) + else (jnp.array(-1),) * len(bref.spec.block_shape) for bref in out_brefs ] - last_indices, _, _ = lax.fori_loop( - 0, num_steps, loop_body, (indices, fetch_indices, last_store_slices) + last_indices, _, _, final_carry = lax.fori_loop( + 0, + num_steps, + loop_body, + (indices, fetch_index_levels, last_store_indices, init_carry), ) # Outputs invariant to the sequential axis are never written from inside the # loop. This is the only place where we store them. - if not copies_out_in_loop: + if not copies_out_in_loop and needs_epilogue: gpu_primitives.commit_smem() - last_slot = lax.rem(num_steps - 1, max_concurrent_steps) - for bref in out_brefs: - if bref.is_index_invariant: - bref.copy_out(last_slot, last_indices, predicate=None) - gpu_primitives.commit_smem_to_gmem_group() + if needs_epilogue: + last_slot = lax.rem(num_steps - 1, max_concurrent_steps) + for bref in out_brefs: + if bref.is_index_invariant: + bref.copy_out(last_slot, last_indices, predicate=None) - # Finalize the pipeline. - gpu_primitives.wait_smem_to_gmem(0) + gpu_primitives.commit_smem_to_gmem_group() + + if smem_out_brefs: + # Finalize the pipeline. + gpu_primitives.wait_smem_to_gmem(0) + return final_carry if init_carry is not None else None return pipeline + +class ComputeContext(Protocol): + """Protocol for a compute context for the warp specialized pipeline. + + The ComputeContext is run exclusively in the compute thread and allows + the user to set up a prologue to initialize a pipeline carry and an epilogue + to consume the final carry. + + All values allocated in the ComputeContext will only be allocated in the + compute thread and not the memory thread. This can potentially reduce + register pressure if certain values are only consumed by the compute threads. + + Usage will usually follow this structure: + + ``` + def compute_context(pipeline): + # Perform prologue work and compute the initial carry. + initial_carry = ... + # Run the pipeline. + final_carry = pipeline(*initial_carry) + # Perform epilogue work using the final carry. + do_work(final_carry) + ``` + + """ + def __call__(self, pipeline: Callable[[T], T]) -> None: + ... + + +class PipelinePipeline(enum.IntEnum): + START = 0 + STEADY = 1 + STOP = 2 + + +class WarpSpecializedPipeline(Protocol): + """Protocol for a warp specialized pipeline.""" + def __call__( + self, *gmem_refs: Any, allocations: Any | None = None, + ) -> None: + ... + + def get_allocations(self, *gmem_refs: Any) -> Any: + ... + + def emit_pipeline_warp_specialized( body: Callable[..., None], *, - grid: pallas_core.StaticGrid, + grid: pallas_core.TupleGrid, memory_registers: int, - in_specs: Sequence[gpu_core.GPUBlockSpec] = (), - out_specs: Sequence[gpu_core.GPUBlockSpec] = (), + in_specs: BlockSpecPytree = (), + out_specs: BlockSpecPytree = (), max_concurrent_steps: int = 2, wg_axis: str, num_compute_wgs: int, + pipeline_state: jax.Array | PipelinePipeline | None = None, manual_consumed_barriers: bool = False, - carry_coroutine: Any | None = None, + compute_context: ComputeContext | None = None, memory_thread_idx: int | None = None, -): +) -> WarpSpecializedPipeline: """Creates a function to emit a warp-specialized pipeline. The ``body`` function should have the following signature (without carry). ``consumed_barriers`` is an optional argument that is only passed if the - ``manual_consumed_barriers`` argument is True. + ``manual_consumed_barriers`` argument is True:: - ``` - def body(*input_refs, *output_refs, [consumed_barriers]) -> None: - ``` + def body(indices, *input_refs, *output_refs, *consumed_barriers) -> None: - or with a carries enabled (enabled via the ``carry_coroutine`` argument), - where the body returns the next carry: + or with a carries enabled (enabled via the ``compute_context`` argument), + where the body returns the next carry:: - ``` - def body(*input_refs, *output_refs, [consumed_barriers], carry) -> Carry: - ``` + def body( + indices, *input_refs, *output_refs, *consumed_barriers, carry + ) -> Carry: + + When ``manual_consumed_barriers`` is True, the user must arrive on all the + consumed barriers from all compute warpgroups at each pipeline step. Args: body: The pipeline body. @@ -400,16 +551,80 @@ def body(*input_refs, *output_refs, [consumed_barriers], carry) -> Carry: manual_consumed_barriers: If True, consumed barriers will be passed into the body function after the output refs. There will be one barrier per input and will be passed in the same order. - carry_coroutine: If specified, enables carries in the pipeline. - The signature of the body function will be modified such that the last - argument will be the current carry and it must return the next carry. - The coroutine itself should yield the initial carry, and the - yield statement will return the final value of the carry. + compute_context: If specified, enables carries in the pipeline and allows + a user-specified prologue/epilogue that is only executed in the compute + thread. The signature of the pipeline body function will be modified + such that the last argument will be the current carry and it must + return the next carry. + The compute_context itself should follow the signature of `ComputeContext` + and take a pipeline function as its sole argument. Calling the + pipeline with the initial carry will run the pipeline and return the + final carry. memory_thread_idx: The index of the memory thread. If not specified, defaults to the last thread. + pipeline_state: If multiple pipelines that have almost the same parameters + (only in/out_specs and body can differ) are going to be evaluated + in sequence, this argument can be used to avoid pipeline bubbles between + their invocations. The first pipeline in the sequence should use the + ``START`` state, followed by an arbitrary number of ``STEADY`` states, + followed by a single ``STOP`` state. Note that until the pipeline with + ``STOP`` is done, the memory thread will not wait for the compute threads + to complete and fully consume their work. Any modification of their + operands other than invoking another pipeline is disallowed. + + Important: To achieve bubble-free execution, it is important to also use + the manual allocation mode by calling ``get_allocations`` on the returned + function, passing the result to ``pl.run_scoped`` and the provided results + to the returned function as an ``allocations`` keyword argument. + Otherwise, the pipeline function will perform the scoped allocation itself + which can lead to synchronization that can still cause pipeline bubbles. """ + # TODO(justinfu): Factor out common code between warp-specialized and # normal pipelines. + if not isinstance(in_specs, (list, tuple)): + in_specs = (in_specs,) + if not isinstance(out_specs, (list, tuple)): + out_specs = (out_specs,) + if isinstance(in_specs, list): + in_specs = tuple(in_specs) + if isinstance(out_specs, list): + out_specs = tuple(out_specs) + + flat_in_specs, in_specs_treedef = jax.tree.flatten(in_specs) + flat_in_specs = tuple(map(_downcast_spec, flat_in_specs)) + for spec in flat_in_specs: + if len(spec.collective_axes) > 1: + raise ValueError( + "Only a single collective axis supported in input BlockSpecs, but" + f" got {spec.collective_axes}" + ) + collective_axes = tuple(frozenset( + a for spec in flat_in_specs for a in spec.collective_axes + )) + flat_out_specs, out_specs_treedef = jax.tree.flatten(out_specs) + flat_out_specs = tuple(map(_downcast_spec, flat_out_specs)) + for spec in flat_out_specs: + if spec.collective_axes: + raise ValueError("Output BlockSpecs cannot have collective_axes") + delay_release = None + for in_spec in in_specs: + if not isinstance(in_spec, gpu_core.BlockSpec): + delay_release = 0 + continue + delay_release = in_spec.delay_release + if in_spec.delay_release != delay_release: + raise NotImplementedError( + "All inputs must have the same delay_release, but" + f" {in_spec.delay_release=} != {delay_release=}" + ) + + delay_release = delay_release or 0 + if max_concurrent_steps <= delay_release: + raise ValueError( + "max_concurrent_steps must be greater than delay_release, but" + f" {max_concurrent_steps=}, {delay_release=}" + ) if memory_thread_idx is None: memory_thread_idx = num_compute_wgs @@ -418,17 +633,20 @@ def body(*input_refs, *output_refs, [consumed_barriers], carry) -> Carry: # thread is the last thread. raise NotImplementedError("Memory thread must be the last thread.") - has_carry = carry_coroutine is not None + has_carry = compute_context is not None # Trace the index maps to determine if they depend on the grid. # Grid-independent values will not be multiple-buffered. in_spec_has_seq_axis = [ - ~_is_index_invariant(spec, grid) for spec in in_specs] + not _is_index_invariant(spec, grid) for spec in flat_in_specs] out_spec_has_seq_axis = [ - ~_is_index_invariant(spec, grid) for spec in out_specs] + not _is_index_invariant(spec, grid) for spec in flat_out_specs] spec_has_seq_axis = [*in_spec_has_seq_axis, *out_spec_has_seq_axis] + if not all(in_spec_has_seq_axis): + raise NotImplementedError("Only inputs with a dependency on the grid are supported.") - num_pipeline_steps = math.prod(grid) + num_steps = math.prod(grid) + has_dynamic_grid = not isinstance(num_steps, int) def _get_slot(step, has_seq_dim): """Returns the buffer slot given the pipeline step.""" @@ -439,42 +657,53 @@ def _get_slot(step, has_seq_dim): # Shrink ``max_concurrent_steps`` if the total number of steps is lower to # reduce the size of the refs allocated in SMEM. - if max_concurrent_steps > num_pipeline_steps: - max_concurrent_steps = num_pipeline_steps - - def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): - in_gmem_refs, out_gmem_refs = util.split_list(gmem_refs, [len(in_specs)]) - if len(out_gmem_refs) != len(out_specs): + if not has_dynamic_grid and max_concurrent_steps > num_steps: + max_concurrent_steps = cast(int, num_steps) + + def _get_scoped_allocs(*gmem_refs: AbstractRefPytree): + in_gmem_refs = gmem_refs[:len(in_specs)] + out_gmem_refs = gmem_refs[len(in_specs):] + flat_in_gmem_refs, in_gmem_refs_treedef = jax.tree.flatten(in_gmem_refs) + flat_out_gmem_refs, out_gmem_refs_treedef = jax.tree.flatten(out_gmem_refs) + if in_specs_treedef != in_gmem_refs_treedef: + raise ValueError( + "Input specs and input gmem refs must have the same pytree structure." + f" {in_specs_treedef} != {in_gmem_refs_treedef}" + ) + if out_specs_treedef != out_gmem_refs_treedef: raise ValueError( - "Number of output refs does not match number of output specs." + "Output specs and output gmem refs must have the same pytree structure." + f" {out_specs_treedef} != {out_gmem_refs_treedef}" ) + flat_gmem_refs = [*flat_in_gmem_refs, *flat_out_gmem_refs] smem_allocs = [] for spec, has_seq_dim, gmem_ref in zip( - it.chain(in_specs, out_specs), + it.chain(flat_in_specs, flat_out_specs), spec_has_seq_axis, - gmem_refs): + flat_gmem_refs): slots = max_concurrent_steps if has_seq_dim else 1 smem_allocs.append( gpu_core.SMEM( - (slots, *spec.block_shape), # type: ignore + (slots, *_get_block_shape(spec)), # type: ignore gmem_ref.dtype, - transforms=spec.transforms, + transforms=getattr(spec, "transforms", ()), ) ) - in_smem_refs, out_smem_refs = util.split_list( - smem_allocs, [len(in_specs)]) - - in_smem_barriers = [] - consumed_barriers = [] - for has_seq_dim in in_spec_has_seq_axis: - num_barriers = max_concurrent_steps if has_seq_dim else 1 - in_smem_barriers.append( - gpu_core.Barrier( - num_arrivals=1, - num_barriers=num_barriers)) + flat_in_smem_refs, flat_out_smem_refs = util.split_list( + smem_allocs, [len(flat_in_specs)]) + in_smem_barrier = gpu_core.Barrier(num_arrivals=len(flat_in_specs), num_barriers=max_concurrent_steps) + flat_consumed_barriers = [] + consumed_barrier_type: Any + if collective_axes: + consumed_barrier_type = functools.partial( + gpu_core.ClusterBarrier, collective_axes=collective_axes # type: ignore + ) + else: + consumed_barrier_type = gpu_core.Barrier + for _ in flat_in_specs: if manual_consumed_barriers: - consumed_barriers.append( - gpu_core.Barrier( + flat_consumed_barriers.append( + consumed_barrier_type( num_arrivals=num_compute_wgs, num_barriers=max_concurrent_steps, ) @@ -482,43 +711,86 @@ def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): if not manual_consumed_barriers: # We only allocated one consumed barrier for all inputs when using # automatic consumed barriers. - consumed_barriers = [ - gpu_core.Barrier( + flat_consumed_barriers = [ + consumed_barrier_type( num_arrivals=num_compute_wgs, num_barriers=max_concurrent_steps, ) ] - return pl.run_scoped( - functools.partial( - scoped_pipeline, - in_gmem_refs=in_gmem_refs, - out_gmem_refs=out_gmem_refs, - ), - in_smem_refs=in_smem_refs, - out_smem_refs=out_smem_refs, - in_smem_barrier_refs=in_smem_barriers, - consumed_barrier_refs=consumed_barriers, + return dict( + flat_in_smem_refs=flat_in_smem_refs, + flat_out_smem_refs=flat_out_smem_refs, + in_smem_barrier_ref=in_smem_barrier, + flat_consumed_barrier_refs=flat_consumed_barriers, ) + def pipeline(*gmem_refs: AbstractRefPytree, allocations: Any | None = None): + """ + Run the pipeline. + + Args: + *gmem_refs: A list of pytrees of pallas refs + allocations: The allocation provided by ``pl.run_scoped`` when the result + of calling ``get_allocations(*gmem_refs)`` is passed to + ``pl.run_scoped``. + """ + in_gmem_refs = gmem_refs[:len(in_specs)] + out_gmem_refs = gmem_refs[len(in_specs):] + flat_in_gmem_refs, in_gmem_refs_treedef = jax.tree.flatten(in_gmem_refs) + flat_out_gmem_refs, out_gmem_refs_treedef = jax.tree.flatten(out_gmem_refs) + if in_specs_treedef != in_gmem_refs_treedef: + raise ValueError( + "Input specs and input gmem refs must have the same pytree structure." + f" {in_specs_treedef} != {in_gmem_refs_treedef}" + ) + if out_specs_treedef != out_gmem_refs_treedef: + raise ValueError( + "Output specs and output gmem refs must have the same pytree structure." + f" {out_specs_treedef} != {out_gmem_refs_treedef}" + ) + + if allocations is None: + if pipeline_state is not None: + raise ValueError( + "Pipeline state should not be set when using automatic allocation." + ) + return pl.run_scoped( + functools.partial( + scoped_pipeline, + flat_in_gmem_refs=flat_in_gmem_refs, + flat_out_gmem_refs=flat_out_gmem_refs, + ), + **_get_scoped_allocs(*gmem_refs), + collective_axes=wg_axis, + ) + else: + scoped_pipeline( + flat_in_gmem_refs=flat_in_gmem_refs, + flat_out_gmem_refs=flat_out_gmem_refs, + **allocations, + ) + + pipeline.get_allocations = _get_scoped_allocs + def scoped_pipeline( *, - in_gmem_refs, - out_gmem_refs, - in_smem_refs, - out_smem_refs, - in_smem_barrier_refs, - consumed_barrier_refs, + flat_in_gmem_refs, + flat_out_gmem_refs, + flat_in_smem_refs, + flat_out_smem_refs, + in_smem_barrier_ref, + flat_consumed_barrier_refs, ): - in_brefs: Sequence[BufferedRef] = [ - BufferedRef(spec, ~has_seq_axis, gmem_ref, smem_ref) + flat_in_brefs: Sequence[BufferedRef] = [ + BufferedRef(spec, not has_seq_axis, gmem_ref, smem_ref) for spec, has_seq_axis, gmem_ref, smem_ref in zip( - in_specs, in_spec_has_seq_axis, in_gmem_refs, in_smem_refs + flat_in_specs, in_spec_has_seq_axis, flat_in_gmem_refs, flat_in_smem_refs ) ] - out_brefs: Sequence[BufferedRef] = [ - BufferedRef(spec, ~has_seq_axis, gmem_ref, smem_ref) + flat_out_brefs: Sequence[BufferedRef] = [ + BufferedRef(spec, not has_seq_axis, gmem_ref, smem_ref) for spec, has_seq_axis, gmem_ref, smem_ref in zip( - out_specs, out_spec_has_seq_axis, out_gmem_refs, out_smem_refs + flat_out_specs, out_spec_has_seq_axis, flat_out_gmem_refs, flat_out_smem_refs ) ] @@ -528,120 +800,185 @@ def compute_block(): action="increase") # This is true if any of the outputs need to be transferred inside the loop. - copies_out_in_loop = not all(bref.is_index_invariant for bref in out_brefs) + smem_out_brefs = [bref for bref in flat_out_brefs if _in_smem(bref.spec)] + # The implementation below has races when we have multiple compute WGs. + # The problem is that we expect the compute WGs to deal with issuing the + # SMEM->GMEM copies, but (1) we never predicate them, so we repeat the + # same copy multiple times, and (2) we don't synchronize the compute WGs + # in any way. In the unlikely event that one of the compute WGs runs 2 + # steps ahead, it might start overwriting the output buffer before the + # other WG has issued its copy. + # + # The best fix here would be to move the SMEM->GMEM copies into the memory + # WG and use proper barriers (with arrival_count=2) to ensure all WGs have + # produced their outputs before it is sent out to GMEM. + if smem_out_brefs and num_compute_wgs > 1: + raise NotImplementedError( + "SMEM outputs are not supported with multiple compute warpgroups" + ) + copies_out_in_loop = not all(bref.is_index_invariant for bref in smem_out_brefs) + needs_epilogue = any(bref.is_index_invariant for bref in smem_out_brefs) def compute_loop_body(step, carry): - indices, last_store_slices, prev_body_carry = carry + indices, last_store_indices, prev_body_carry = carry slot = lax.rem(step, max_concurrent_steps) + consumed_slot = lax.rem(step - delay_release, max_concurrent_steps) # Wait for the current GMEM->SMEM copies to complete. - for in_barrier, has_seq_dim in zip( - in_smem_barrier_refs, in_spec_has_seq_axis): - # TODO(justinfu): Use a single barrier with - # num_arrivals=len(in_smem_barrier_refs) - gpu_primitives.barrier_wait( - in_barrier.at[_get_slot(slot, has_seq_dim)]) + gpu_primitives.barrier_wait(in_smem_barrier_ref.at[_get_slot(slot, True)]) # Wait for the previous output SMEM->GMEM copy to complete. if copies_out_in_loop: - gpu_primitives.wait_smem_to_gmem(max_concurrent_steps - 1) + gpu_primitives.wait_smem_to_gmem( + max_concurrent_steps - 1, wait_read_only=True + ) - with pallas_core.grid_env(map(pallas_core.GridAxis, indices, grid)): - body_refs = [] - for bref in it.chain(in_brefs, out_brefs): - buf_slot = _get_slot(slot, ~bref.is_index_invariant) - body_refs.append(bref.get_ref_for_slot(buf_slot)) + in_brefs = jax.tree.unflatten(in_specs_treedef, flat_in_brefs) + out_brefs = jax.tree.unflatten(out_specs_treedef, flat_out_brefs) + all_brefs = (*in_brefs, *out_brefs) + body_args = map_brefs( + lambda bref: bref.get_ref_for_slot( + _get_slot(slot, not bref.is_index_invariant) + ), + all_brefs, + ) - body_args = body_refs - if manual_consumed_barriers: - body_args += [consumed_barrier_ref.at[slot] for consumed_barrier_ref in consumed_barrier_refs] - if has_carry: - body_args += [prev_body_carry] - next_body_carry = body(*body_args) + if manual_consumed_barriers: + barriers = jax.tree.unflatten( + in_specs_treedef, + [barrier.at[consumed_slot] for barrier in flat_consumed_barrier_refs], + ) + body_args = (*body_args, *barriers) + if has_carry: + body_args = (*body_args, prev_body_carry) + next_body_carry = body(indices, *body_args) if not manual_consumed_barriers: - [consumed_barrier_ref] = consumed_barrier_refs - gpu_primitives.barrier_arrive(consumed_barrier_ref.at[slot]) + [consumed_barrier_ref] = flat_consumed_barrier_refs + if delay_release > 0: + lax.cond( + step < delay_release, + lambda: None, + lambda: gpu_primitives.barrier_arrive(consumed_barrier_ref.at[consumed_slot]), + ) + else: + gpu_primitives.barrier_arrive(consumed_barrier_ref.at[consumed_slot]) # TODO(justinfu,apaszke): This should probably be done by the memory WG. # Copy the output from SMEM to GMEM. if copies_out_in_loop: gpu_primitives.commit_smem() - new_store_slices = last_store_slices[:] - for idx, bref in enumerate(out_brefs): + new_store_indices = last_store_indices[:] + for idx, bref in enumerate(flat_out_brefs): if bref.is_index_invariant: - assert last_store_slices[idx] is None + assert last_store_indices[idx] is None continue - assert last_store_slices[idx] is not None - new_store_slices[idx] = tuple( - _Slice(s.start, s.size) for s in bref.compute_gmem_slice(indices) - ) + assert last_store_indices[idx] is not None + new_store_indices[idx] = bref.spec.index_map(*indices) are_same_slices = map( lambda old, new: old == new, - last_store_slices[idx], - new_store_slices[idx], + last_store_indices[idx], + new_store_indices[idx], ) slices_changed = ~functools.reduce(lax.bitwise_and, are_same_slices) - bref.copy_out(_get_slot(slot, ~bref.is_index_invariant), + bref.copy_out(_get_slot(slot, not bref.is_index_invariant), indices, predicate=slices_changed) gpu_primitives.commit_smem_to_gmem_group() next_indices = _inc_grid_by_1(indices, grid) - return (next_indices, new_store_slices, next_body_carry) + return (next_indices, new_store_indices, next_body_carry) init_indices = (jnp.asarray(0, dtype=jnp.int32),) * len(grid) + # TODO(justinfu): Only store base pointer instead of all indices. - last_store_slices = [ + last_store_indices = [ None if bref.is_index_invariant - else (_Slice(-1, -1),) * len(bref.spec.block_shape) - for bref in out_brefs + else (jnp.array(-1),) * len(bref.spec.block_shape) + for bref in flat_out_brefs ] if has_carry: - _carry = carry_coroutine() - try: - carry_init = next(_carry) - except StopIteration: - raise ValueError("carry_coroutine must yield the initial carry.") # pylint: disable=raise-missing-from + last_indices = None + def pipeline_callback(user_init_carry): + nonlocal last_indices + if last_indices is not None: + raise ValueError( + "Cannot call pipeline more than once in `compute_context`") + init_loop_carry = (init_indices, last_store_indices, user_init_carry) + last_indices, _, final_body_carry = lax.fori_loop(0, + num_steps, + compute_loop_body, + init_loop_carry) + return final_body_carry + compute_context(pipeline_callback) + if last_indices is None: + raise ValueError("Pipeline was not called in `compute_context`") else: - _carry = None - carry_init = None - init_loop_carry = (init_indices, last_store_slices, carry_init) - last_indices, _, final_body_carry = lax.fori_loop(0, - num_pipeline_steps, - compute_loop_body, - init_loop_carry) - if has_carry: - try: - _carry.send(final_body_carry) # pytype: disable=attribute-error - raise ValueError("carry_coroutine must only yield once.") - except StopIteration: - pass + assert compute_context is None + last_indices, _, _ = lax.fori_loop( + 0, num_steps, compute_loop_body, + (init_indices, last_store_indices, None) + ) # Handle index_invariant outputs after the loop. They are not # written in the main pipeline loop. - if not copies_out_in_loop: + if not copies_out_in_loop and needs_epilogue: gpu_primitives.commit_smem() - last_slot = lax.rem(num_pipeline_steps - 1, max_concurrent_steps) - for bref in out_brefs: - if bref.is_index_invariant: - bref.copy_out(last_slot, last_indices, predicate=None) - gpu_primitives.commit_smem_to_gmem_group() + if needs_epilogue: + last_slot = lax.rem(num_steps - 1, max_concurrent_steps) + for bref in flat_out_brefs: + if bref.is_index_invariant: + bref.copy_out(_get_slot(last_slot, has_seq_dim=False), + last_indices, predicate=None) - # Finalize the pipeline. - gpu_primitives.wait_smem_to_gmem(0) + gpu_primitives.commit_smem_to_gmem_group() + + if smem_out_brefs: + # Finalize the pipeline. + gpu_primitives.wait_smem_to_gmem(0) # The memory thread executes this block which issues all pipelined DMAs. + # TODO(apaszke,justinfu): Use a single arrive_expect_tx for all transfers. def memory_block(): gpu_primitives.set_max_registers(memory_registers, action="decrease") indices = (jnp.asarray(0, dtype=jnp.int32),) * len(grid) + if has_dynamic_grid: + prologue_steps = lax.min(max_concurrent_steps, num_steps) + else: + assert max_concurrent_steps <= num_steps + prologue_steps = max_concurrent_steps + pipeline_init_prologue_steps = prologue_steps + if pipeline_state is not None: + if has_dynamic_grid: + raise NotImplementedError( + "A pipeline of pipelines is not supported with dynamic grids" + ) + if num_steps % max_concurrent_steps: + raise NotImplementedError( + "A pipeline of pipelines is only allowed when the number of steps" + f" (product of grid, here {num_steps}) is divisible by" + f" {max_concurrent_steps=}" + ) + if delay_release: + raise NotImplementedError( + "A pipeline of pipelines is not supported with delay_release" + ) + if isinstance(pipeline_state, PipelinePipeline): + prologue_steps = prologue_steps if pipeline_state == PipelinePipeline.START else 0 + else: + prologue_steps = jnp.where(pipeline_state == PipelinePipeline.START, prologue_steps, 0) # Begin initial copies. - for step in range(max_concurrent_steps): - for bref, barrier in zip(in_brefs, in_smem_barrier_refs): - buf_slot = _get_slot(step, ~bref.is_index_invariant) - bref.copy_in(buf_slot, indices, barrier) - indices = _inc_grid_by_1(indices, grid) + def _init_step(step, indices): + for bref in flat_in_brefs: + buf_slot = _get_slot(step, not bref.is_index_invariant) + barrier_slot = _get_slot(step, True) + bref.copy_in(buf_slot, indices, in_smem_barrier_ref, barrier_slot) + return _inc_grid_by_1(indices, grid) + + indices = jax.lax.fori_loop( + 0, prologue_steps, _init_step, indices, unroll=not has_dynamic_grid + ) def memory_loop_body(step, carry): indices, = carry @@ -651,22 +988,35 @@ def memory_loop_body(step, carry): if not manual_consumed_barriers: # We only have one consumed barrier when using automatic consumed # barrier management. - [consumed_barrier_ref] = consumed_barrier_refs + [consumed_barrier_ref] = flat_consumed_barrier_refs gpu_primitives.barrier_wait(consumed_barrier_ref.at[slot]) - consumed_barrier_it = [None] * len(in_brefs) + consumed_barrier_it = [None] * len(flat_in_brefs) else: - consumed_barrier_it = consumed_barrier_refs + consumed_barrier_it = flat_consumed_barrier_refs - for bref, barrier, consumed_barrier in zip( - in_brefs, in_smem_barrier_refs, consumed_barrier_it): + for bref, consumed_barrier in zip(flat_in_brefs, consumed_barrier_it): if manual_consumed_barriers: gpu_primitives.barrier_wait(consumed_barrier.at[slot]) # pytype: disable=attribute-error - bref.copy_in( - _get_slot(fetch_slot, ~bref.is_index_invariant), indices, barrier) + buf_slot = _get_slot(fetch_slot, not bref.is_index_invariant) + barrier_slot = _get_slot(fetch_slot, True) + bref.copy_in(buf_slot, indices, in_smem_barrier_ref, barrier_slot) next_indices = _inc_grid_by_1(indices, grid) return (next_indices,) - lax.fori_loop(0, num_pipeline_steps - max_concurrent_steps, - memory_loop_body, (indices,)) + lax.fori_loop(0, num_steps - prologue_steps, memory_loop_body, (indices,)) + # Await all the arrivals to not leave barriers in a bad state. + # We only need to account for the prologue steps, only the first + # delay_release of them skip arrivals, so we subtract them. + @pl.when(pipeline_state is None or pipeline_state == PipelinePipeline.STOP) + def _quiesce(): + @pl.loop( + num_steps - pipeline_init_prologue_steps, + num_steps - delay_release, + unroll=not has_dynamic_grid, + ) + def _epi_step(step): + consumed_slot = lax.rem(step, max_concurrent_steps) + for barrier in flat_consumed_barrier_refs: + gpu_primitives.barrier_wait(barrier.at[consumed_slot]) wg_idx = lax.axis_index(wg_axis) lax.cond( @@ -674,14 +1024,23 @@ def memory_loop_body(step, carry): compute_block, memory_block ) - return pipeline + # Mypy doesn't notice the .get_allocations assignment above. + return pipeline # type: ignore def _compute_registers( memory_registers: int, num_compute_wgs: int, ) -> int: - """Returns the number of registers to use for the compute thread.""" - # TODO(justinfu): Configure this per-platform. - n_registers = (512 - memory_registers) / num_compute_wgs + """Returns the max number of registers to use in compute threads. + + We start with the theoretical max registers per thread if one wargroup + (128 threads) used the entire SM's 64k register file (64k / 128 = 512). + Then reserve `memory_registers` for the producer warpgroup and distribute + the remaining registers evenly among the compute warpgroups. + + Note: The maximum number of registers per thread is 255, so we clamp + the value. + """ + n_registers = min(256, (512 - memory_registers) / num_compute_wgs) # Round down to the nearest multiple of 8. return int((n_registers // 8) * 8) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 7f26f5d2b6a3..a5ff98b51adb 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -16,23 +16,30 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Hashable, Sequence +import contextlib import dataclasses -import enum +import functools import itertools import math from typing import Any, Literal import jax from jax._src import core as jax_core +from jax._src import debugging +from jax._src import dtypes +from jax._src import literals +from jax._src import pretty_printer as pp from jax._src import state from jax._src import tree_util from jax._src import util from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith as arith_dialect -from jax._src.lib.mlir.dialects import llvm as llvm_dialect +from jax._src.lib.mlir.dialects import builtin as builtin_dialect +from jax._src.lib.mlir.dialects import gpu as gpu_dialect from jax._src.lib.mlir.dialects import nvvm as nvvm_dialect from jax._src.pallas import core as pallas_core +from jax._src.pallas import primitives as pallas_primitives from jax._src.pallas.mosaic_gpu import core as gpu_core from jax._src.pallas.mosaic_gpu import lowering from jax._src.pallas.mosaic_gpu.core import state_types @@ -40,18 +47,23 @@ from jax._src.state import indexing from jax._src.state import primitives as state_primitives from jax.experimental.mosaic import gpu as mgpu +from jax.experimental.mosaic.gpu import layouts as mgpu_layouts +from jax.experimental.mosaic.gpu import tcgen05 from jax.experimental.mosaic.gpu import utils as mgpu_utils import jax.numpy as jnp +import numpy as np +AxisName = jax_core.AxisName +WARP_SIZE = 32 WARPGROUP_SIZE = 128 -_Ref = pallas_core.AbstractMemoryRef | state_types.TransformedRef - +_Ref = state.AbstractRef | state_types.TransformedRef +SomeLayout = gpu_core.SomeLayout def _check_ref( - aval: object, name: str, memory_space: gpu_core.GPUMemorySpace + aval: object, name: str, memory_space: gpu_core.MemorySpace ) -> None: if not isinstance(aval, state_types.AbstractRef): raise TypeError(f"{name} must be a reference, got {aval}") @@ -62,6 +74,66 @@ def _check_ref( ) +print_layout_p = jax_core.Primitive("print_layout") +print_layout_p.multiple_results = True + + +@print_layout_p.def_effectful_abstract_eval +def _print_layout_abstract_eval(aval_in, fmt, *_, **params): + del aval_in, fmt, params # Unused. + return (), {debugging.debug_effect} + + +@lowering.register_lowering_rule(print_layout_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule( + print_layout_p, mgpu.LoweringSemantics.Warpgroup +) +def _print_layout_lowering( + ctx: lowering.LoweringRuleContext, + x: mgpu.FragmentedArray | tcgen05.TMEMRef | ir.Value, + fmt: str, + *transforms_leaves, + transforms_tree +): + if transforms_leaves: + x, remaining_transforms = lowering._handle_transforms( + ctx, x, transforms_tree.unflatten(transforms_leaves), + ) + if remaining_transforms: + raise NotImplementedError( + f"Unsupported transforms {remaining_transforms}." + ) + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: + print(fmt.format(mgpu.dialect_lowering.pprint_layout(x))) + else: + assert isinstance(x, ir.Value) + mgpu.dialect.print_layout(fmt, x) + return () + + +def print_layout(fmt: str, x: jax.typing.ArrayLike | _Ref) -> None: + """Prints the layout chosen by Mosaic GPU for a given array or TMEM reference. + + This is evaluated at compile-time and has no incidence on the runtime behavior + of the program. + + Args: + fmt: The format string to use for printing the layout. + x: The array or TMEM reference to print the layout of. + """ + if isinstance(x, pallas_core.TransformedRef): + transforms_leaves, transforms_tree = jax.tree.flatten(x.transforms) + x = x.ref + else: + transforms_leaves, transforms_tree = [], None + print_layout_p.bind( + x, + fmt=fmt, + *transforms_leaves, + transforms_tree=transforms_tree, + ) + + copy_smem_to_gmem_p = jax_core.Primitive("copy_smem_to_gmem") copy_smem_to_gmem_p.multiple_results = True @@ -74,9 +146,48 @@ def _copy_smem_to_gmem_abstract_eval(src, dst, *args, **params): return (), {state.ReadEffect(0), state.WriteEffect(1)} -@lowering.register_lowering_rule(copy_smem_to_gmem_p, mgpu.ThreadSemantics.Lane) +def _copy_smem_to_gmem_pp_eqn( + eqn: jax_core.JaxprEqn, + context: jax_core.JaxprPpContext, + settings: jax_core.JaxprPpSettings, +): + src, dst, *flat_args = eqn.invars + src_transforms_treedef = eqn.params["src_transforms_treedef"] + dst_transforms_treedef = eqn.params["dst_transforms_treedef"] + pp_params = {} + if not (commit_group := eqn.params["commit_group"]): + pp_params["commit_group"] = commit_group + if eqn.params["has_user_predicate"]: + flat_args, user_predicate = flat_args[:-1], flat_args[-1] + pp_params["user_predicate"] = jax_core.pp_var(user_predicate, context) + if reduction_op := eqn.params["reduction_op"]: + pp_params["reduction_op"] = reduction_op + flat_src_transforms, flat_dst_transforms = util.split_list( + flat_args, + [src_transforms_treedef.num_leaves], + ) + src_transforms = src_transforms_treedef.unflatten(flat_src_transforms) + dst_transforms = dst_transforms_treedef.unflatten(flat_dst_transforms) + return pp.concat([ + pp.text("copy_smem_to_gmem"), + jax_core.pp_kv_pairs(pp_params.items(), context, settings), + pp.text(" "), + state_primitives.pp_ref_transforms(context, src, src_transforms), + pp.text(" -> "), + state_primitives.pp_ref_transforms(context, dst, dst_transforms), + ]) + + +jax_core.pp_eqn_rules[copy_smem_to_gmem_p] = _copy_smem_to_gmem_pp_eqn + + +@lowering.register_lowering_rule( + copy_smem_to_gmem_p, mgpu.LoweringSemantics.Lane) @lowering.register_lowering_rule( - copy_smem_to_gmem_p, mgpu.ThreadSemantics.Warpgroup + copy_smem_to_gmem_p, mgpu.LoweringSemantics.Lane, + primitive_semantics=gpu_core.PrimitiveSemantics.Warp) +@lowering.register_lowering_rule( + copy_smem_to_gmem_p, mgpu.LoweringSemantics.Warpgroup ) def _copy_smem_to_gmem_lowering( ctx: lowering.LoweringRuleContext, @@ -87,27 +198,45 @@ def _copy_smem_to_gmem_lowering( dst_transforms_treedef, has_user_predicate, commit_group, + reduction_op, ): - predicate = ctx.module_ctx.single_wg_lane_predicate if has_user_predicate: flat_args, user_predicate = flat_args[:-1], flat_args[-1] - predicate = arith_dialect.andi( - predicate, lowering._ensure_ir_value(user_predicate, jnp.bool) - ) + predicate = lowering._ensure_ir_value(user_predicate, jnp.bool) + else: + predicate = None + + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: + if predicate is not None: + assert ctx.module_ctx.single_lane_predicate is not None + predicate = arith_dialect.andi( + predicate, ctx.module_ctx.single_lane_predicate + ) + else: + predicate = ctx.module_ctx.single_lane_predicate + flat_src_transforms, flat_dst_transforms = util.split_list( flat_args, [src_transforms_treedef.num_leaves], ) src_transforms = src_transforms_treedef.unflatten(flat_src_transforms) dst_transforms = dst_transforms_treedef.unflatten(flat_dst_transforms) - src, src_transforms = lowering._handle_indexing(src, src_transforms) - copy_params = _extract_gmem_copy_params(dst_transforms) | _extract_smem_copy_params(src_transforms) - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + handle_transposes = ( + ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Warpgroup + ) + src, src_transforms = lowering._handle_transforms( + ctx, src, src_transforms, handle_transposes=handle_transposes + ) + copy_params = _extract_gmem_copy_params( + ctx, dst_transforms, supports_multicast=True + ) | _extract_smem_copy_params(src_transforms) + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: ctx.launch_ctx.async_copy( src_ref=src, dst_ref=dst, predicate=predicate, arrive=commit_group, + reduction_op=reduction_op, **copy_params, ) return () @@ -119,15 +248,37 @@ def _copy_smem_to_gmem_lowering( else: indices, slice_lengths = _split_gmem_slice(copy_params["gmem_slice"]) assert copy_params.get("swizzle") is None + if copy_params.get("gmem_peer_id", None) is not None: + raise NotImplementedError( + "GMEM refs with peer ids are not supported in warpgroup lowering." + ) assert not copy_params.get("gmem_transform") - mgpu.dialect.async_store( - src, - dst, - indices, - slice_lengths, - predicate=predicate, - commit_group=commit_group, # type: ignore[call-arg] - ) + if reduction_op is not None: + # TODO(b/415721295): Call mgpu.dialect.async_store after the if, after + # the minimal jaxlib version is 0.8.2. + if not hasattr(mgpu.dialect, "TMAReduction"): + raise NotImplementedError("Reduction op is not supported yet.") + reduction_op_attr = getattr( + mgpu.dialect.TMAReduction, reduction_op.capitalize() + ) + mgpu.dialect.async_store( + src, + dst, + indices, + slice_lengths, + predicate=predicate, + commit_group=commit_group, # type: ignore[call-arg] + reduction_op=reduction_op_attr, + ) + else: + mgpu.dialect.async_store( + src, + dst, + indices, + slice_lengths, + predicate=predicate, + commit_group=commit_group, # type: ignore[call-arg] + ) return () @@ -143,24 +294,71 @@ def _split_gmem_slice(gmem_slice): case mgpu.DynamicSlice(): indices.append(arith_dialect.index_cast(i32, idx.base)) slice_lengths.append(idx.length) - case ir.Value(): + case ir.Value() if isinstance(idx.type, ir.IndexType): indices.append(arith_dialect.index_cast(i32, idx)) slice_lengths.append(-1) + case ir.Value() if isinstance(idx.type, ir.IntegerType): + indices.append(idx) + slice_lengths.append(-1) + case ir.Value() if isinstance(idx.type, ir.VectorType): + indices.append(idx) + [length] = ir.VectorType(idx.type).shape + slice_lengths.append(length) case _: raise NotImplementedError(f"Unsupported GMEM slice: {idx}") return indices, slice_lengths -def _extract_gmem_copy_params(transforms): +def _extract_gmem_copy_params(ctx, transforms, supports_multicast=False): if not transforms: return {} + peer_id = None + indexers = [] for transform in transforms: - if not isinstance(transform, indexing.NDIndexer): + if isinstance(transform, gpu_core.PeerMemRef): + peer_id, other_axes = pallas_primitives.device_id_to_logical( + ctx.module_ctx.mesh_info, + lowering._ensure_ir_value_device_id(transform.device_id), + transform.device_id_type, + lambda name: lowering._axis_index_rule(ctx, axis_name=name), + ) + if other_axes: + raise ValueError( + "Only JAX mesh axes can be used to obtain peer references, but" + f" got {other_axes}" + ) + continue + elif isinstance(transform, gpu_core.MulticastRef): + if not supports_multicast: + raise ValueError( + "Multicast refs are not supported by this primitive." + ) + if (mesh_info := ctx.module_ctx.mesh_info) is None: + raise ValueError( + "JAX device mesh is required by multicast copies, but not defined." + " Use jax.set_mesh." + ) + if set(transform.collective_axes) != set(mesh_info.axis_names): + raise NotImplementedError( + "Only collective_axes that include all JAX device mesh axes are" + f" supported, but got {transform.collective_axes}. Make sure to" + f" pass collective_axes={mesh_info.axis_names}" + ) + peer_id = mgpu.GLOBAL_BROADCAST + continue + elif isinstance(transform, indexing.NDIndexer): + indexers.append(transform) + else: raise NotImplementedError( "Non-indexing transforms on GMEM refs are not implemented.") - indexer = lowering.merge_indexers(transforms) + if indexers: + indexer = lowering.merge_indexers(indexers) + gmem_slice = lowering._ndindexer_indices(indexer, allow_arrays=True) + else: + gmem_slice = () return dict( - gmem_slice=lowering._ndindexer_indices(indexer), + gmem_slice=gmem_slice, + gmem_peer_id=peer_id, ) @@ -186,6 +384,7 @@ def copy_smem_to_gmem( predicate: jax.Array | None = None, *, commit_group: bool = True, + reduction_op: mgpu.TMAReductionOp | None = None, ) -> None: """Asynchronously copies a SMEM reference to a GMEM reference. @@ -194,19 +393,22 @@ def copy_smem_to_gmem( dst: The GMEM reference to copy to. predicate: A boolean indicating whether the copy should be performed. If ``None``, the copy is always performed. - commit_group: If ``True``, this and any previously uncommitted copies - are committed to a group and can be awaited jointly via - :func:`jax.experimental.mosaic.gpu.wait_smem_to_gmem`. + commit_group: If ``True``, this and any previously uncommitted copies are + committed to a group and can be awaited jointly via + :func:`jax.experimental.pallas.mosaic_gpu.wait_smem_to_gmem`. + reduction_op: If set, perform the specified reduction operation when storing + to GMEM. For example, using ``"add"`` is conceptually equivalent to + doing ``src += dst``. See also: - :func:`jax.experimental.mosaic.gpu.wait_smem_to_gmem` - :func:`jax.experimental.mosaic.gpu.commit_smem` + :func:`jax.experimental.pallas.mosaic_gpu.wait_smem_to_gmem` + :func:`jax.experimental.pallas.mosaic_gpu.commit_smem` """ src, src_transforms = state_primitives.get_ref_and_transforms( - src, None, "copy_smem_to_gmem", force_trailing_indexer=False, + src, None, "copy_smem_to_gmem" ) dst, dst_transforms = state_primitives.get_ref_and_transforms( - dst, None, "copy_smem_to_gmem", force_trailing_indexer=False, + dst, None, "copy_smem_to_gmem" ) flat_src_transforms, src_transforms_treedef = tree_util.tree_flatten( src_transforms @@ -224,6 +426,7 @@ def copy_smem_to_gmem( dst_transforms_treedef=dst_transforms_treedef, has_user_predicate=predicate is not None, commit_group=commit_group, + reduction_op=reduction_op, ) return None @@ -241,9 +444,51 @@ def _copy_gmem_to_smem_abstract_eval(src, dst, barrier, *args, **params): return (), {state.ReadEffect(0), state.WriteEffect(1)} -@lowering.register_lowering_rule(copy_gmem_to_smem_p, mgpu.ThreadSemantics.Lane) +def _copy_gmem_to_smem_pp_eqn( + eqn: jax_core.JaxprEqn, + context: jax_core.JaxprPpContext, + settings: jax_core.JaxprPpSettings, +): + src, dst, barrier, *flat_args = eqn.invars + src_transforms_treedef = eqn.params["src_transforms_treedef"] + dst_transforms_treedef = eqn.params["dst_transforms_treedef"] + barrier_transforms_treedef = eqn.params["barrier_transforms_treedef"] + pp_params = {} + if collective_axes := eqn.params["collective_axes"]: + pp_params["collective_axes"] = collective_axes + flat_src_transforms, flat_dst_transforms, flat_barrier_transforms = ( + util.split_list( + flat_args, + [ + src_transforms_treedef.num_leaves, + dst_transforms_treedef.num_leaves, + ], + ) + ) + src_transforms = src_transforms_treedef.unflatten(flat_src_transforms) + dst_transforms = dst_transforms_treedef.unflatten(flat_dst_transforms) + barrier_transforms = barrier_transforms_treedef.unflatten( + flat_barrier_transforms + ) + return pp.concat([ + pp.text("copy_gmem_to_smem"), + jax_core.pp_kv_pairs(pp_params.items(), context, settings), + pp.text(" "), + state_primitives.pp_ref_transforms(context, src, src_transforms), + pp.text(" -> "), + state_primitives.pp_ref_transforms(context, dst, dst_transforms), + pp.text(" using "), + state_primitives.pp_ref_transforms(context, barrier, barrier_transforms), + ]) + + +jax_core.pp_eqn_rules[copy_gmem_to_smem_p] = _copy_gmem_to_smem_pp_eqn + + @lowering.register_lowering_rule( - copy_gmem_to_smem_p, mgpu.ThreadSemantics.Warpgroup + copy_gmem_to_smem_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule( + copy_gmem_to_smem_p, mgpu.LoweringSemantics.Warpgroup ) def _copy_gmem_to_smem_lowering( ctx: lowering.LoweringRuleContext, @@ -254,6 +499,9 @@ def _copy_gmem_to_smem_lowering( src_transforms_treedef, dst_transforms_treedef, barrier_transforms_treedef, + collective_axes, + partitioned_axis, + for_warpgroup: bool = True, ): flat_src_transforms, flat_dst_transforms, flat_barrier_transforms = ( util.split_list( @@ -266,46 +514,123 @@ def _copy_gmem_to_smem_lowering( ) src_transforms = src_transforms_treedef.unflatten(flat_src_transforms) dst_transforms = dst_transforms_treedef.unflatten(flat_dst_transforms) - dst, dst_transforms = lowering._handle_indexing(dst, dst_transforms) - copy_params = _extract_smem_copy_params(dst_transforms) | _extract_gmem_copy_params(src_transforms) - barrier_indexer = _extract_barrier_indexer( + handle_transposes = ( + ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Warpgroup + ) + dst, dst_transforms = lowering._handle_transforms( + ctx, dst, dst_transforms, handle_transposes=handle_transposes + ) + copy_params = _extract_smem_copy_params(dst_transforms) | _extract_gmem_copy_params(ctx, src_transforms) + base_index = _extract_barrier_slice_base( barrier_transforms_treedef.unflatten(flat_barrier_transforms) ) - if barrier_indexer is not None: - barrier = barrier.__getitem__( - *map(lowering._as_index, barrier_indexer.indices) + if base_index is not None: + barrier = barrier[base_index] + collective = None + if collective_axes is not None: + collective = tuple( + lowering._resolve_cluster_axis(ctx.module_ctx.axis_names, axis) + for axis in collective_axes ) + is_partitioned_copy = collective and partitioned_axis is not None dst_ty = ir.MemRefType(dst.type) - bytes = math.prod(dst_ty.shape) * mgpu.bytewidth(dst_ty.element_type) - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + bits = math.prod(dst_ty.shape) * mgpu.bitwidth(dst_ty.element_type) + if bits % 8: + raise ValueError( + f"Can only transfer integer bytes (shape={dst_ty.shape}," + f" dtype={dst_ty.element_type})" + ) + bytes = bits // 8 + + if is_partitioned_copy: + # Bytes is the destination size, which is only half of the total + # size of the partitioned transfer so we need to double it. + bytes *= 2 + if len(collective) != 1: # type: ignore + raise ValueError( + f"Expected exactly one collective axis, got {collective_axes=}" + ) + if math.prod(ctx.launch_ctx.cluster_size) != 2: + raise NotImplementedError( + "Partitioned loads only supported for clusters of size 2. Got" + f" cluster size {ctx.launch_ctx.cluster_size}." + ) + + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: if bytes % WARPGROUP_SIZE: - raise NotImplementedError("Only aligned copies are supported") - # We arrive uniformly from each thread in the WG, so we need to divide the - # number of bytes by the number of threads in the WG. - # TODO: apaszke - Relax this. We can just select the WG leader and have it - # arrive with the whole transfer size, while everyone else arrives with 0. - # But we should continue using this scheme as it's likely to be faster. - bytes //= WARPGROUP_SIZE - barrier.arrive_expect_tx(bytes) + raise NotImplementedError( + "Only copies transferring a number of bytes divisible by the" + f" warpgroup size are supported. Got {bytes=} but warpgroup size is" + f" {WARPGROUP_SIZE}" + ) + if for_warpgroup: + # We arrive uniformly from each thread in the WG, so we need to divide the + # number of bytes by the number of threads in the WG. + # TODO: apaszke - Relax this. We can just select the WG leader and have it + # arrive with the whole transfer size, while everyone else arrives with 0. + # But we should continue using this scheme as it's likely to be faster. + bytes //= WARPGROUP_SIZE + if ctx.module_ctx.auto_barriers: + mgpu.warpgroup_barrier() # Make sure all reads have completed. + if is_partitioned_copy: + first_block = arith_dialect.cmpi( + arith_dialect.CmpIPredicate.eq, + ctx.launch_ctx.cluster_idx(collective[0]), # type: ignore + mgpu.c(0, ir.IndexType.get()), + ) + barrier.arrive_expect_tx(bytes, predicate=first_block) + else: + barrier.arrive_expect_tx(bytes) + else: + # In Warp-level lowering, we arrive on each CUDA thread in a warp, but + # the barrier still expects a full 128 arrivals so we arrive 4 times + # on each CUDA thread instead. + # TODO(justinfu): The arrival counts are wrong if called outside of a + # single warp. Figure out how to guard against this in user code. + bytes = bytes // WARP_SIZE + if is_partitioned_copy: + first_block = arith_dialect.cmpi( + arith_dialect.CmpIPredicate.eq, + ctx.launch_ctx.cluster_idx(collective[0]), # type: ignore + mgpu.c(0, ir.IndexType.get()), + ) + with mgpu.when(first_block): + barrier.arrive(arrival_count=3, can_complete=False) + barrier.arrive_expect_tx(bytes) + else: + barrier.arrive(arrival_count=3, can_complete=False) + barrier.arrive_expect_tx(bytes) + + # Gathers are a warpgroup-level collective and can't take a predicate. + predicate_kwarg = dict(predicate=ctx.module_ctx.single_lane_predicate) + if gmem_slice := copy_params.get("gmem_slice", ()): + first_idx = gmem_slice[0] + if isinstance(first_idx, mgpu.FragmentedArray) and first_idx.shape: + predicate_kwarg = {} ctx.launch_ctx.async_copy( src_ref=src, dst_ref=dst, barrier=barrier, arrive=False, - predicate=ctx.module_ctx.single_wg_lane_predicate, + collective=collective, + partitioned=partitioned_axis, **copy_params, + **predicate_kwarg, ) return () - + i32 = ir.IntegerType.get_signless(32) if "gmem_slice" not in copy_params: - i32 = ir.IntegerType.get_signless(32) slice_lengths = ir.MemRefType(src.type).shape indices = [mgpu.utils.c(0, i32)] * len(slice_lengths) else: indices, slice_lengths = _split_gmem_slice(copy_params["gmem_slice"]) assert copy_params.get("swizzle") is None assert not copy_params.get("gmem_transform") - barrier_ref = barrier.as_dialect_barrier_memref() + if copy_params.get("gmem_peer_id", None) is not None: + raise NotImplementedError( + "GMEM refs with peer ids are not supported in warpgroup lowering." + ) + barrier_ref = barrier.as_barrier_memref() mgpu.dialect.arrive_expect_tx(barrier_ref, bytes) mgpu.dialect.async_load( src, @@ -313,23 +638,62 @@ def _copy_gmem_to_smem_lowering( barrier_ref, indices, slice_lengths, - collective=ir.ArrayAttr.get([]), + collective=ir.ArrayAttr.get( + [ir.IntegerAttr.get(i32, axis) for axis in collective or []] + ), ) return () -def copy_gmem_to_smem(src: _Ref, dst: _Ref, barrier: _Ref) -> None: +lowering.register_lowering_rule( + copy_gmem_to_smem_p, + mgpu.LoweringSemantics.Lane, + primitive_semantics=gpu_core.PrimitiveSemantics.Warp, +)(functools.partial(_copy_gmem_to_smem_lowering, for_warpgroup=False)) + + +def copy_gmem_to_smem( + src: _Ref, + dst: _Ref, + barrier: _Ref, + *, + collective_axes: str | tuple[str, ...] | None = None, + partitioned_axis: int | None = None, +) -> None: """Asynchronously copies a GMEM reference to a SMEM reference. + If collective_axes is specified, this performs a multicast copy where + all CUDA blocks that share the same index along the collective axis + receive a copy of the same block of data loaded from `dst` to `src`. + + If both collective_axes and partitioned_axis are specified, this will perform + a partitioned collective copy where each block in the cluster will receive + a tile of `transfer_size // cluster_size` data from the `src` Ref. + For example, if `src` has a shape of (256, 256) and a partitioned + copy is performed along axis 0 with cluster size 2, then the first block will + receive `src[0:128, :]` and the second will receive `src[128:256, :]`. + NOTE: Only the first block in the cluster will arrive on the barrier, + and an additional cluster barrier is necessary to ensure that all blocks in + the cluster have finished the copy. + + Args: + src: The source Ref. Must be in GMEM. + dst: The destination Ref. Must be in SMEM. + barrier: The barrier to use for tracking completion of the copy. + collective_axes: The collective axes to use for the copy. + partitioned_axis: Indicates which array axis along the src/dst Refs to + partition across during a partitioned collective copy. Requires + collective_axes to also be specified. + See also: - :func:`jax.experimental.mosaic.gpu.barrier_arrive` - :func:`jax.experimental.mosaic.gpu.barrier_wait` + :func:`jax.experimental.pallas.mosaic_gpu.barrier_arrive` + :func:`jax.experimental.pallas.mosaic_gpu.barrier_wait` """ src, src_transforms = state_primitives.get_ref_and_transforms( - src, None, "copy_gmem_to_smem", force_trailing_indexer=False, + src, None, "copy_gmem_to_smem" ) dst, dst_transforms = state_primitives.get_ref_and_transforms( - dst, None, "copy_gmem_to_smem", force_trailing_indexer=False, + dst, None, "copy_gmem_to_smem" ) flat_src_transforms, src_transforms_treedef = tree_util.tree_flatten( src_transforms @@ -338,11 +702,13 @@ def copy_gmem_to_smem(src: _Ref, dst: _Ref, barrier: _Ref) -> None: dst_transforms ) barrier, barrier_transforms = state_primitives.get_ref_and_transforms( - barrier, None, "copy_gmem_to_smem", force_trailing_indexer=False, + barrier, None, "copy_gmem_to_smem" ) flat_barrier_transforms, barrier_transforms_treedef = tree_util.tree_flatten( barrier_transforms ) + if isinstance(collective_axes, str): + collective_axes = (collective_axes,) copy_gmem_to_smem_p.bind( src, dst, @@ -353,30 +719,150 @@ def copy_gmem_to_smem(src: _Ref, dst: _Ref, barrier: _Ref) -> None: src_transforms_treedef=src_transforms_treedef, dst_transforms_treedef=dst_transforms_treedef, barrier_transforms_treedef=barrier_transforms_treedef, + collective_axes=collective_axes, + partitioned_axis=partitioned_axis, + ) + return None + +async_prefetch_p = jax_core.Primitive("async_prefetch") +async_prefetch_p.multiple_results = True + +@async_prefetch_p.def_effectful_abstract_eval +def _async_prefetch_abstract_eval(ref, *args, **params): + del args, params # Unused. + _check_ref(ref, "ref", gpu_core.GMEM) + return (), {state.ReadEffect(0)} + + +@lowering.register_lowering_rule(async_prefetch_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule( + async_prefetch_p, + mgpu.LoweringSemantics.Lane, + primitive_semantics=gpu_core.PrimitiveSemantics.Warp, +) +@lowering.register_lowering_rule( + async_prefetch_p, mgpu.LoweringSemantics.Warpgroup +) +def _async_prefetch_lowering( + ctx: lowering.LoweringRuleContext, + ref, + *flat_ref_transforms, + ref_transforms_treedef, + collective_axes, + partitioned_axis, +): + ref_transforms = ref_transforms_treedef.unflatten(flat_ref_transforms) + copy_params = _extract_gmem_copy_params(ctx, ref_transforms) + collective = None + if collective_axes is not None: + collective = tuple( + lowering._resolve_cluster_axis(ctx.module_ctx.axis_names, axis) + for axis in collective_axes + ) + + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: + predicate_kwarg = dict(predicate=ctx.module_ctx.single_lane_predicate) + if gmem_slice := copy_params.get("gmem_slice", ()): + first_idx = gmem_slice[0] + # Gathers are a warpgroup-level collective and can't take a predicate. + if isinstance(first_idx, mgpu.FragmentedArray) and first_idx.shape: + predicate_kwarg = {} + ctx.launch_ctx.async_prefetch( + gmem_ref=ref, + collective=collective, + partitioned=partitioned_axis, + **copy_params, + **predicate_kwarg, + ) + return () + + if "gmem_slice" not in copy_params: + i32 = ir.IntegerType.get_signless(32) + slice_lengths = ir.MemRefType(ref.type).shape + indices = [mgpu.utils.c(0, i32)] * len(slice_lengths) + else: + indices, slice_lengths = _split_gmem_slice(copy_params["gmem_slice"]) + assert copy_params.get("swizzle") is None + assert not copy_params.get("gmem_transform") + if copy_params.get("gmem_peer_id", None) is not None: + raise NotImplementedError( + "GMEM refs with peer ids are not supported in warpgroup lowering." + ) + mgpu.dialect.async_prefetch( + ref, indices, slice_lengths, collective=ir.ArrayAttr.get([]) + ) + return () + + +def async_prefetch( + ref: _Ref, + *, + collective_axes: str | tuple[str, ...] | None = None, + partitioned_axis: int | None = None, +) -> None: + """Asynchronously prefetches a GMEM reference to the L2 cache. + + If collective_axes is specified, each CUDA block only prefetches a part of + the ``ref``, with other parts covered by blocks that share the same index + along the collective axis. + + If both ``collective_axes`` and ``partitioned_axis`` are specified, the + ``partitioned_axis`` indicates the logical axis used to split the prefetch + across the collective axes. + + Args: + ref: The source Ref. Must be in GMEM. + collective_axes: The collective axes to use for the prefetch. + partitioned_axis: Indicates which axis of the ``ref`` to partition across + during a collective prefetch. Requires collective_axes to also be + specified. + """ + ref, ref_transforms = state_primitives.get_ref_and_transforms( + ref, None, "async_prefetch" + ) + flat_ref_transforms, ref_transforms_treedef = tree_util.tree_flatten( + ref_transforms + ) + if isinstance(collective_axes, str): + collective_axes = (collective_axes,) + async_prefetch_p.bind( + ref, + *flat_ref_transforms, + ref_transforms_treedef=ref_transforms_treedef, + collective_axes=collective_axes, + partitioned_axis=partitioned_axis, ) return None -def _extract_barrier_indexer(transforms) -> indexing.NDIndexer | None: +def _extract_barrier_slice_base(transforms) -> ir.Value | None: if not transforms: return None - match transforms: - case [indexing.NDIndexer(indices=[idx]) as indexer]: - if not isinstance(idx, indexing.Slice): - return indexer - if indexing.Slice.from_slice(slice(None), *indexer.shape) == idx: - # Special-case: the whole slice. - return None - else: - raise ValueError( - f"Barrier can only be indexed with an integer, got {idx}" - ) - case [indexing.NDIndexer()]: - raise NotImplementedError("Barrier does not support multiple indices") - case []: - return None - case _: - raise ValueError("Barrier does not support arbirary transforms") + base_index = None + while transforms: + match transforms: + case [indexing.NDIndexer(indices=[idx]) as indexer, *transforms]: + if isinstance(idx, indexing.Slice): + if indexing.Slice.from_slice(slice(None), *indexer.shape) == idx: + # Special-case: the whole slice. + continue + idx = idx.start + if isinstance( + idx, (int, ir.Value, mgpu.FragmentedArray, literals.TypedNdArray) + ): + if base_index is None: + base_index = lowering._as_index(idx) + else: + base_index = arith_dialect.addi(base_index, lowering._as_index(idx)) + else: + raise ValueError( + f"Barrier can only be indexed with integers or slices, got {idx}" + ) + case [indexing.NDIndexer(), *_]: + raise NotImplementedError("Barrier does not support multiple indices") + case _: + raise ValueError("Barrier does not support arbitrary transforms") + return base_index barrier_arrive_p = jax_core.Primitive("barrier_arrive") @@ -390,26 +876,59 @@ def _barrier_arrive_abstract_eval(barrier, *args, **params): return (), {gpu_core._memory_effect} -@lowering.register_lowering_rule(barrier_arrive_p, mgpu.ThreadSemantics.Lane) +def _barrier_arrive_pp_eqn( + eqn: jax_core.JaxprEqn, + context: jax_core.JaxprPpContext, + settings: jax_core.JaxprPpSettings, +): + del settings + barrier, *flat_transforms = eqn.invars + transforms_treedef = eqn.params["transforms_treedef"] + transforms = transforms_treedef.unflatten(flat_transforms) + return pp.concat([ + pp.text("barrier_arrive"), + pp.text(" "), + state_primitives.pp_ref_transforms(context, barrier, transforms), + ]) + + +jax_core.pp_eqn_rules[barrier_arrive_p] = _barrier_arrive_pp_eqn + + +@lowering.register_lowering_rule(barrier_arrive_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule(barrier_arrive_p, mgpu.LoweringSemantics.Warpgroup) def _barrier_arrive_lowering( ctx: lowering.LoweringRuleContext, barrier, *flat_transforms, transforms_treedef, ): - del ctx # Unused. transforms = transforms_treedef.unflatten(flat_transforms) - indexer = _extract_barrier_indexer(transforms) - if indexer is not None: - barrier = barrier.__getitem__(*map(lowering._as_index, indexer.indices)) - barrier.arrive() + base_index = _extract_barrier_slice_base(transforms) + if base_index is not None: + barrier = barrier[base_index] + sem_dtype = ctx.avals_in[0].inner_aval.dtype # type: ignore + orders_tensor_core = getattr(sem_dtype, "orders_tensor_core", False) + if orders_tensor_core: + # We arrive on only one lane for barriers with orders_tensor_core=True, + # so we need to perfom a separate warpgroup barrier. + mgpu_utils.warpgroup_barrier() + + if isinstance(barrier, mgpu.CollectiveBarrierRef): + barrier.arrive(orders_tensor_core) + else: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Warpgroup: + barrier.arrive(orders_tensor_core) + else: + pred = ctx.module_ctx.single_lane_predicate if orders_tensor_core else None + barrier.arrive(orders_tensor_core=orders_tensor_core, predicate=pred) return () -def barrier_arrive(barrier: pallas_core.AbstractMemoryRef) -> None: +def barrier_arrive(barrier: state.AbstractRef) -> None: """Arrives at the given barrier.""" barrier, transforms = state_primitives.get_ref_and_transforms( - barrier, None, "barrier_arrive", force_trailing_indexer=False, + barrier, None, "barrier_arrive" ) flat_transforms, transforms_treedef = tree_util.tree_flatten(transforms) barrier_arrive_p.bind( @@ -428,31 +947,60 @@ def _barrier_wait_abstract_eval(barrier, *args, **params): return (), {gpu_core._memory_effect} -@lowering.register_lowering_rule(barrier_wait_p, mgpu.ThreadSemantics.Lane) -@lowering.register_lowering_rule(barrier_wait_p, mgpu.ThreadSemantics.Warpgroup) +def _barrier_wait_pp_eqn( + eqn: jax_core.JaxprEqn, + context: jax_core.JaxprPpContext, + settings: jax_core.JaxprPpSettings, +): + del settings + barrier, *flat_transforms = eqn.invars + transforms_treedef = eqn.params["transforms_treedef"] + transforms = transforms_treedef.unflatten(flat_transforms) + return pp.concat([ + pp.text("barrier_wait"), + pp.text(" "), + state_primitives.pp_ref_transforms(context, barrier, transforms), + ]) + + +jax_core.pp_eqn_rules[barrier_wait_p] = _barrier_wait_pp_eqn + + +@lowering.register_lowering_rule(barrier_wait_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule( + barrier_wait_p, + mgpu.LoweringSemantics.Lane, + gpu_core.PrimitiveSemantics.Warp, +) +@lowering.register_lowering_rule( + barrier_wait_p, mgpu.LoweringSemantics.Warpgroup +) def _barrier_wait_lowering( ctx: lowering.LoweringRuleContext, barrier, *flat_transforms, transforms_treedef, ): - del ctx # Unused. + barrier_aval = ctx.avals_in[0] transforms = transforms_treedef.unflatten(flat_transforms) - indexer = _extract_barrier_indexer(transforms) - if indexer is not None: - barrier = barrier.__getitem__(*map(lowering._as_index, indexer.indices)) - barrier.wait() + orders_tensor_core = getattr( + barrier_aval.inner_aval.dtype, "orders_tensor_core", False # type: ignore + ) + base_index = _extract_barrier_slice_base(transforms) + if base_index is not None: + barrier = barrier[base_index] + barrier.wait(orders_tensor_core=orders_tensor_core) return () -def barrier_wait(barrier: pallas_core.AbstractMemoryRef) -> None: +def barrier_wait(barrier: state.AbstractRef) -> None: """Waits on the given barrier.""" barrier, transforms = state_primitives.get_ref_and_transforms( - barrier, None, "barrier_wait", force_trailing_indexer=False, + barrier, None, "barrier_wait" ) flat_transforms, transforms_treedef = tree_util.tree_flatten(transforms) barrier_wait_p.bind( - barrier, *flat_transforms, transforms_treedef=transforms_treedef + barrier, *flat_transforms, transforms_treedef=transforms_treedef, ) @@ -466,21 +1014,29 @@ def _wait_smem_to_gmem_abstract_eval(n, *, wait_read_only): return (), {gpu_core._memory_effect} -@lowering.register_lowering_rule(wait_smem_to_gmem_p, mgpu.ThreadSemantics.Lane) @lowering.register_lowering_rule( - wait_smem_to_gmem_p, mgpu.ThreadSemantics.Warpgroup + wait_smem_to_gmem_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule( + wait_smem_to_gmem_p, *gpu_core.LANExWARP_SEMANTICS) +@lowering.register_lowering_rule( + wait_smem_to_gmem_p, mgpu.LoweringSemantics.Warpgroup ) def _wait_smem_to_gmem_lowering( ctx: lowering.LoweringRuleContext, n, *, wait_read_only ): + if ctx.module_ctx.primitive_semantics == gpu_core.PrimitiveSemantics.Warp: + scope = mgpu_utils.ThreadSubset.WARP + else: + scope = mgpu_utils.ThreadSubset.WARPGROUP ctx.launch_ctx.await_async_copy( - allow_groups=n, await_read_only=wait_read_only + allow_groups=n, await_read_only=wait_read_only, + scope=scope ) return () def wait_smem_to_gmem(n: int, wait_read_only: bool = False) -> None: - """Waits until there are no more than ``n`` SMEM->GMEM copies in flight. + """Waits until no more than the most recent ``n`` SMEM->GMEM copies issued by the calling thread are in flight. Args: n: The maximum number of copies in flight to wait for. @@ -499,8 +1055,9 @@ def _commit_group_abstract_eval(): return (), {gpu_core._memory_effect} -@lowering.register_lowering_rule(commit_group_p, mgpu.ThreadSemantics.Lane) -@lowering.register_lowering_rule(commit_group_p, mgpu.ThreadSemantics.Warpgroup) +@lowering.register_lowering_rule(commit_group_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule( + commit_group_p, mgpu.LoweringSemantics.Warpgroup) def _commit_group_lowering(ctx: lowering.LoweringRuleContext): del ctx # Unused. nvvm_dialect.cp_async_bulk_commit_group() @@ -508,7 +1065,7 @@ def _commit_group_lowering(ctx: lowering.LoweringRuleContext): def commit_smem_to_gmem_group() -> None: - """Commits all issued but uncommited SMEM->GMEM copies to a group.""" + """Commits all issued but uncommitted SMEM->GMEM copies to a group.""" commit_group_p.bind() @@ -517,11 +1074,7 @@ def commit_smem_to_gmem_group() -> None: wgmma_ref_p.multiple_results = True -def wgmma( - acc: gpu_core.WGMMAAbstractAccumulatorRef, - a, - b: pallas_core.TransformedRef, -) -> None: +def wgmma(acc: gpu_core.WGMMAAbstractAccumulatorRef, a, b) -> None: """Performs an asynchronous warp group matmul-accumulate on the given references. Conceptually, this is equivalent to doing ``acc[...] += a[...] @ b[...]``, @@ -550,19 +1103,33 @@ def wgmma( if a.dtype != b.dtype: raise ValueError(f"Mixed input dtypes for matrix multiplication unsupported: lhs={a.dtype}, rhs={b.dtype}") + acc_transforms_leaves: list + if isinstance(acc, pallas_core.TransformedRef): + acc_transforms_leaves, acc_transforms_tree = jax.tree.flatten(acc.transforms) + acc = acc.ref + else: + acc_transforms_leaves, acc_transforms_tree = [], None + if isinstance(a, pallas_core.TransformedRef): a_transforms_leaves, a_transforms_tree = jax.tree.flatten(a.transforms) a = a.ref else: a_transforms_leaves, a_transforms_tree = [], None - b_transforms_leaves, b_transforms_tree = jax.tree.flatten(b.transforms) + + if isinstance(b, pallas_core.TransformedRef): + b_transforms_leaves, b_transforms_tree = jax.tree.flatten(b.transforms) + b = b.ref + else: + b_transforms_leaves, b_transforms_tree = [], None wgmma_ref_p.bind( acc, a, - b.ref, + b, + *acc_transforms_leaves, *a_transforms_leaves, *b_transforms_leaves, + acc_transforms_tree=acc_transforms_tree, a_transforms_tree=a_transforms_tree, b_transforms_tree=b_transforms_tree, ) @@ -582,6 +1149,39 @@ def _wgmma_ref_effectful_abstract_eval(acc_aval, a_aval, b_aval, *_, **params): } +def _wgmma_ref_pp_eqn( + eqn: jax_core.JaxprEqn, + context: jax_core.JaxprPpContext, + settings: jax_core.JaxprPpSettings, +): + del settings + acc, a, b, *leaves = eqn.invars + transform_treedefs = [ + eqn.params["acc_transforms_tree"], + eqn.params["a_transforms_tree"], + eqn.params["b_transforms_tree"], + ] + transform_leaves = util.split_list( + leaves, [getattr(tree, "num_leaves", 0) for tree in transform_treedefs] + ) + acc_transforms, a_transforms, b_transforms = ( + () if treedef is None else treedef.unflatten(leaves) + for treedef, leaves in zip(transform_treedefs, transform_leaves) + ) + return pp.concat([ + pp.text("wgmma_ref"), + pp.text(" "), + state_primitives.pp_ref_transforms(context, acc, acc_transforms), + pp.text(" <- "), + state_primitives.pp_ref_transforms(context, a, a_transforms), + pp.text(" @ "), + state_primitives.pp_ref_transforms(context, b, b_transforms), + ]) + + +jax_core.pp_eqn_rules[wgmma_ref_p] = _wgmma_ref_pp_eqn + + @discharge.register_discharge_rule(wgmma_ref_p) def _wgmma_ref_discharge(in_avals, out_avals, *args, **kwargs): del in_avals, out_avals @@ -592,50 +1192,79 @@ def _wgmma_ref_discharge(in_avals, out_avals, *args, **kwargs): wgmma_p = jax_core.Primitive("wgmma") -@lowering.register_lowering_rule(wgmma_p, mgpu.ThreadSemantics.Lane) +@lowering.register_lowering_rule(wgmma_p, mgpu.LoweringSemantics.Lane) def _wgmma_lowering( ctx: lowering.LoweringRuleContext, acc, a, b, *transforms_leaves, + acc_transforms_tree, a_transforms_tree, b_transforms_tree, ): - _, a_aval, *_ = ctx.avals_in lhs_swizzle: int | None = None - if a_transforms_tree is not None: - a_transforms_leaves, b_transforms_leaves = util.split_list( - transforms_leaves, [a_transforms_tree.num_leaves] + transform_treedefs = [ + acc_transforms_tree, a_transforms_tree, b_transforms_tree + ] + transform_leaves = util.split_list( + transforms_leaves, [getattr(tree, "num_leaves", 0) for tree in transform_treedefs] + ) + acc_transforms, a_transforms, b_transforms = ( + None if treedef is None else treedef.unflatten(leaves) + for treedef, leaves in zip(transform_treedefs, transform_leaves) + ) + + acc_indices = None + if acc_transforms is not None: + if not all(isinstance(t, indexing.NDIndexer) for t in acc_transforms): + raise ValueError("WGMMA accumulator only supports indexing transforms") + acc_indexer = lowering.merge_indexers(acc_transforms) + if acc_indexer.int_indexer_shape: + raise NotImplementedError("int_indexer_shape non-empty") + acc_indices = lowering._ndindexer_indices(acc_indexer) + + if a_transforms is not None: + a, a_transforms = lowering._handle_transforms( + ctx, a, a_transforms, handle_transposes=False, handle_reshapes=False ) - a_transforms = a_transforms_tree.unflatten(a_transforms_leaves) - a, a_transforms = lowering._handle_indexing(a, a_transforms) match a_transforms: case (gpu_core.UnswizzleRef(lhs_swizzle), gpu_core.UntileRef(tiling)): - swizzle_elems = lhs_swizzle // a_aval.dtype.itemsize - if tiling != (64, swizzle_elems): - raise NotImplementedError("WGMMA lhs tiling does not fit swizzle") + lhs_transpose = False + case ( + gpu_core.UnswizzleRef(lhs_swizzle), + gpu_core.UntileRef(tiling), + gpu_core.TransposeRef((1, 0)), + ): + lhs_transpose = True case _: raise ValueError(f"WGMMA lhs has unsupported transforms: {a_transforms}.") + a_mlir_dtype = ir.MemRefType(a.type).element_type + swizzle_elems = lhs_swizzle // mgpu_utils.bytewidth(a_mlir_dtype) + if tiling != (8, swizzle_elems): + raise NotImplementedError( + f"WGMMA lhs tiling does not fit swizzle. Got {tiling=}, expected (8, {swizzle_elems})" + ) else: - b_transforms_leaves = transforms_leaves # type: ignore + lhs_transpose = False if not isinstance(a, mgpu.FragmentedArray): raise ValueError( "When WGMMA lhs is passed in as a ref, it must be transformed by" " swizzling and tiling appropriately." ) - b_transforms = b_transforms_tree.unflatten(b_transforms_leaves) - b, b_transforms = lowering._handle_indexing(b, b_transforms) + assert b_transforms is not None + b, b_transforms = lowering._handle_transforms( + ctx, b, b_transforms, handle_transposes=False, handle_reshapes=False + ) match b_transforms: case (gpu_core.UnswizzleRef(rhs_swizzle), gpu_core.UntileRef(rhs_tiling)): rhs_transpose = False case ( gpu_core.UnswizzleRef(rhs_swizzle), - gpu_core.TransposeRef((1, 0, 2, 3)), # Only transpose between tiles gpu_core.UntileRef(rhs_tiling), - gpu_core.TransposeRef((1, 0)), # Transpose the two logical dims + gpu_core.TransposeRef((1, 0)), ): rhs_transpose = True case ( @@ -661,15 +1290,70 @@ def _wgmma_lowering( raise ValueError(f"WGMMA rhs has unsupported transforms: {b_transforms}.") if lhs_swizzle is not None: - swizzle_elems = rhs_swizzle // a_aval.dtype.itemsize + b_mlir_dtype = ir.MemRefType(b.type).element_type + swizzle_elems = rhs_swizzle // mgpu_utils.bytewidth(b_mlir_dtype) if rhs_swizzle != lhs_swizzle: raise NotImplementedError("WGMMA rhs swizzle must match lhs swizzle") - if rhs_tiling != (swizzle_elems, swizzle_elems): + if rhs_tiling != (8, swizzle_elems): raise NotImplementedError("WGMMA rhs tiling does not fit swizzle") + if lhs_transpose: + a = mgpu.memref_transpose(a, (1, 0, 3, 2)) if rhs_transpose: - b = mgpu.memref_transpose(b, (0, 1, 3, 2)) - new_acc = mgpu.wgmma(acc, a, b, swizzle=rhs_swizzle) + b = mgpu.memref_transpose(b, (1, 0, 3, 2)) + acc_in = acc + if acc_indices is not None: + acc_in = mgpu.WGMMAAccumulator( + _value=acc._value[acc_indices], + _original_layout=acc._original_layout, + _sync=False, + ) + acc_out = mgpu.wgmma(acc_in, a, b, swizzle=rhs_swizzle) + if acc_indices is not None: + acc_value = acc._value.copy() + acc_value[acc_indices] = acc_out._value + acc_out = mgpu.WGMMAAccumulator( + _value=acc_value, _original_layout=acc._original_layout, _sync=False + ) + nvvm_dialect.wgmma_commit_group_sync_aligned() + return acc_out + + +@lowering.register_lowering_rule(wgmma_p, mgpu.LoweringSemantics.Warpgroup) +def _wgmma_warpgroup_lowering( + ctx: lowering.LoweringRuleContext, + acc, + a, + b, + *transforms_leaves, + acc_transforms_tree, + a_transforms_tree, + b_transforms_tree, +): + if acc_transforms_tree is not None: + raise NotImplementedError + if a_transforms_tree is not None: + a_transforms_leaves, b_transforms_leaves = util.split_list( + transforms_leaves, [a_transforms_tree.num_leaves] + ) + a_transforms = a_transforms_tree.unflatten(a_transforms_leaves) + a, a_transforms = lowering._handle_transforms(ctx, a, a_transforms) + if a_transforms: + raise ValueError( + f"WGMMA lhs has unsupported transforms: {a_transforms}." + ) + else: + b_transforms_leaves = transforms_leaves # type: ignore + + if b_transforms_tree is not None: + b_transforms = b_transforms_tree.unflatten(b_transforms_leaves) + b, b_transforms = lowering._handle_transforms(ctx, b, b_transforms) + if b_transforms: + raise ValueError( + f"WGMMA rhs has unsupported transforms: {b_transforms}." + ) + + new_acc = mgpu.dialect.wgmma(acc, a, b) nvvm_dialect.wgmma_commit_group_sync_aligned() return new_acc @@ -697,7 +1381,8 @@ def wgmma_wait_effectful_abstract_eval(_): return [], {gpu_core._wgmma_pipeline_effect} -@lowering.register_lowering_rule(wgmma_wait_p, mgpu.ThreadSemantics.Lane) +@lowering.register_lowering_rule(wgmma_wait_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule(wgmma_wait_p, mgpu.LoweringSemantics.Warpgroup) def _wgmma_wait_lowering(ctx: lowering.LoweringRuleContext, allow_groups): del ctx nvvm_dialect.wgmma_wait_group_sync_aligned(allow_groups) @@ -728,70 +1413,708 @@ def _wgmma_accumulator_deref_discharge(in_avals, out_avals, acc): return (None,), wgmma_accumulator_deref_p.bind(acc) -@lowering.register_lowering_rule(wgmma_accumulator_deref_p, mgpu.ThreadSemantics.Lane) +@lowering.register_lowering_rule( + wgmma_accumulator_deref_p, mgpu.LoweringSemantics.Lane +) +@lowering.register_lowering_rule( + wgmma_accumulator_deref_p, mgpu.LoweringSemantics.Warpgroup +) def _wgmma_accumulator_deref_lowering(ctx: lowering.LoweringRuleContext, acc): - del ctx nvvm_dialect.wgmma_wait_group_sync_aligned(0) - return acc.value - - -class Layout(enum.Enum): - #: [m, n] matrix, where m % 64 == 0 == n % 8. - WGMMA = enum.auto() - #: [m] matrix, where m % 64 == 0. - WGMMA_ROW = enum.auto() - - WG_SPLAT = enum.auto() - WG_STRIDED = enum.auto() - - def __call__(self, *args, **kwargs) -> ParameterizedLayout: - return ParameterizedLayout(self, args, kwargs) - - def to_mgpu(self, *args, **kwargs) -> mgpu.FragmentedLayout: - def check_no_args(): - if args or kwargs: - raise ValueError(f"Can't instantiate {self} with arguments.") - - match self: - case Layout.WGMMA: - check_no_args() - return mgpu.WGMMA_LAYOUT - case Layout.WGMMA_ROW: - check_no_args() - return mgpu.WGMMA_ROW_LAYOUT - case Layout.WG_SPLAT: - return mgpu.WGSplatFragLayout(*args, **kwargs) # pytype: disable=missing-parameter - case Layout.WG_STRIDED: - return mgpu.WGStridedFragLayout(*args, **kwargs) - -@dataclasses.dataclass(frozen=True) -class ParameterizedLayout: - layout_cls: Layout - args: Sequence[Any] - kwargs: Any - - def to_mgpu(self) -> mgpu.FragmentedLayout: - return self.layout_cls.to_mgpu(*self.args, **self.kwargs) - + return ( + acc.value + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane + else acc + ) -layout_cast_p = jax_core.Primitive("layout_cast") +# MMA for TensorCore gen 5. +tcgen05_mma_p = jax_core.Primitive("tcgen05_mma") +tcgen05_mma_p.multiple_results = True + +def tcgen05_mma(acc: _Ref, + a: _Ref, + b: _Ref, + barrier: _Ref | None = None, + *, + a_scale: _Ref | None = None, + b_scale: _Ref | None = None, + a_sparse_metadata: _Ref | None = None, + accumulate: bool | jax.Array = True, + collective_axis: str | None = None): + """Asynchronous matrix-multiply accumulate for TensorCore gen 5 (Blackwell). + + If run in collective mode, ``acc``, ``a`` (LHS), and ``b`` (RHS) should + correspond to half of the total inputs to the MMA, where ``acc`` and ``a`` + (LHS) are split in half along the rows and ``b`` (RHS) is split along the + columns like so:: + + ----------- ----------- ----------- + | ACC1 | | LHS1 | | | | + ----------- += ----------- @ |RHS1|RHS2| + | ACC2 | | LHS2 | | | | + ----------- ----------- ----------- + + To use the block-scaled matrix-multiply, provide ``a_scale`` and ``b_scale`` + operands (they must be both present or both unspecified). -@layout_cast_p.def_abstract_eval -def _layout_cast_abstract_eval(x, new_layout): - del new_layout # Unused. - return x + Args: + acc: The accumulator. Must be a TMEM Ref. + a: The left-hand side. Must be a TMEM/SMEM Ref. + b: The right-hand side. Must be an SMEM Ref. + barrier: Optional barrier Ref for synchronizing with the tensor core. + Must have orders_tensor_core set to True. If not specified, the MMA + completion should be explicitly observed by calling + :func:`jax.experimental.pallas.mosaic_gpu.tcgen05_commit_arrive` + a_scale: An optional scale for the ``a`` operand. Must be a TMEM Ref if present. + b_scale: An optional scale for the ``b`` operand. Must be a TMEM Ref if present. + a_sparse_metadata: An optional sparse metadata for the ``a`` operand. + Must be a TMEM Ref if present. + accumulate: Whether to accumulate into acc or overwrite it. + collective_axis: The name of the cluster axis along which to perform + a collective MMA. The cluster axis should have a size of exactly 2, + and must be on the minormost cluster axis. + """ + acc_m, acc_n = acc.shape + lhs_m, lhs_k = a.shape + rhs_k, rhs_n = b.shape + if collective_axis is not None: + acc_n /= 2 + is_sparse = a_sparse_metadata is not None + if acc_m != lhs_m: + raise ValueError( + f"Accumulator and LHS have incompatible shapes. Accumulator: {acc.shape}. LHS: {a.shape}.") + if acc_n != rhs_n: + raise ValueError( + f"Accumulator and RHS have incompatible shapes. Accumulator: {acc.shape}. RHS: {b.shape}.") + if (lhs_k * (1 + is_sparse)) != rhs_k: + raise ValueError( + f"LHS and RHS have incompatible shapes. LHS: {a.shape}. RHS: {b.shape}.") + if isinstance(acc, pallas_core.TransformedRef): + acc_transforms_leaves, acc_transforms_tree = jax.tree.flatten( + acc.transforms) + acc = acc.ref + else: + acc_transforms_leaves, acc_transforms_tree = [], None -@lowering.register_lowering_rule(layout_cast_p, mgpu.ThreadSemantics.Lane) -def _layout_cast_lowering(ctx: lowering.LoweringRuleContext, x, *, new_layout): - del ctx # Unused. - return x.to_layout(new_layout.to_mgpu()) + if isinstance(a, pallas_core.TransformedRef): + a_transforms_leaves, a_transforms_tree = jax.tree.flatten(a.transforms) + a = a.ref + else: + a_transforms_leaves, a_transforms_tree = [], None + + if isinstance(b, pallas_core.TransformedRef): + b_transforms_leaves, b_transforms_tree = jax.tree.flatten(b.transforms) + b = b.ref + else: + b_transforms_leaves, b_transforms_tree = [], None + + if (is_scaled := a_scale is not None) != (b_scale is not None): + raise ValueError("a_scale and b_scale must both be present or absent.") + scales = [] + if isinstance(a_scale, pallas_core.TransformedRef): + a_scale_transforms_leaves, a_scale_transforms_tree = jax.tree.flatten( + a_scale.transforms + ) + scales.append(a_scale.ref) + else: + a_scale_transforms_leaves, a_scale_transforms_tree = [], None + scales.append(a_scale) + if isinstance(b_scale, pallas_core.TransformedRef): + b_scale_transforms_leaves, b_scale_transforms_tree = jax.tree.flatten( + b_scale.transforms + ) + scales.append(b_scale.ref) + else: + b_scale_transforms_leaves, b_scale_transforms_tree = [], None + scales.append(b_scale) + if not is_scaled: + scales = [] + + if isinstance(a_sparse_metadata, pallas_core.TransformedRef): + a_sparse_metadata_transforms_leaves, a_sparse_metadata_transforms_tree = jax.tree.flatten( + a_sparse_metadata.transforms + ) + sparse_metadata = [a_sparse_metadata.ref] + else: + a_sparse_metadata_transforms_leaves, a_sparse_metadata_transforms_tree = [], None + sparse_metadata = [a_sparse_metadata] if is_sparse else [] + + if isinstance(barrier, pallas_core.TransformedRef): + barrier_transforms_leaves, barrier_transforms_tree = jax.tree.flatten( + barrier.transforms + ) + barrier = barrier.ref + else: + barrier_transforms_leaves, barrier_transforms_tree = [], None + + if barrier is not None: + barrier_ref = [barrier] + arrive = True + else: + barrier_ref = [] + arrive = False + + tcgen05_mma_p.bind(acc, a, b, accumulate, *barrier_ref, *scales, *sparse_metadata, + *acc_transforms_leaves, *a_transforms_leaves, + *b_transforms_leaves, + *barrier_transforms_leaves, + *a_scale_transforms_leaves, *b_scale_transforms_leaves, + *a_sparse_metadata_transforms_leaves, + acc_transforms_tree=acc_transforms_tree, + a_transforms_tree=a_transforms_tree, + b_transforms_tree=b_transforms_tree, + barrier_transforms_tree=barrier_transforms_tree, + a_scale_transforms_tree=a_scale_transforms_tree, + b_scale_transforms_tree=b_scale_transforms_tree, + a_sparse_metadata_transforms_tree=a_sparse_metadata_transforms_tree, + collective_axis=collective_axis, + arrive=arrive, + scaled=bool(scales), + sparse=is_sparse) + + +@tcgen05_mma_p.def_abstract_eval +def _tcgen05_mma_abstract_eval(acc, a, b, accumulate, + *barrier_scales_and_transforms_leaves, + acc_transforms_tree, a_transforms_tree, + b_transforms_tree, + barrier_transforms_tree, + a_scale_transforms_tree, + b_scale_transforms_tree, + a_sparse_metadata_transforms_tree, + collective_axis, + arrive, + scaled, + sparse): + del (accumulate, acc_transforms_tree, + a_transforms_tree, b_transforms_tree, barrier_transforms_tree) + + if acc.memory_space != gpu_core.TMEM: + raise ValueError("Accumulator must be a TMEM Ref.") + if a.memory_space not in (gpu_core.SMEM, gpu_core.TMEM): + raise ValueError("LHS must be a TMEM/SMEM Ref.") + if b.memory_space != gpu_core.SMEM: + raise ValueError("RHS must be an SMEM Ref.") + + if collective_axis is not None: + # TODO(justinfu): If under a core_map, the avals for acc/a + # become normal MemRefs so we cannot check if they are collective. + # Figure out a way to fix this. + if isinstance(acc, gpu_core.AbstractTMEMRef) and not acc.collective: + raise ValueError( + "Accumulator Ref must be collective if collective_axis is set.") + if isinstance(a, gpu_core.AbstractTMEMRef) and not a.collective: + raise ValueError( + "LHS Ref must be collective if collective_axis is set.") + + scales_and_transforms_leaves = barrier_scales_and_transforms_leaves + if arrive: + barrier, *scales_and_transforms_leaves = barrier_scales_and_transforms_leaves + orders_tensor_core = getattr( + barrier.inner_aval.dtype, "orders_tensor_core", False) + if not orders_tensor_core: + raise ValueError("MMA barrier must have orders_tensor_core set to True.") + if scaled: + a_scale, b_scale = scales_and_transforms_leaves[:2] + if a_scale.memory_space != gpu_core.TMEM: + raise ValueError("a_scale must be a TMEM Ref") + if b_scale.memory_space != gpu_core.TMEM: + raise ValueError("b_scale must be a TMEM Ref") + + return [] + + +@lowering.register_lowering_rule(tcgen05_mma_p, *gpu_core.LANExWG_SEMANTICS) +@lowering.register_lowering_rule(tcgen05_mma_p, *gpu_core.LANExWARP_SEMANTICS) +def _tcgen05_mma_lowering( + ctx: lowering.LoweringRuleContext, + acc: tcgen05.TMEMRef, + a_ref, + b_ref, + accumulate: bool | ir.Value, + *barrier_scales_and_transforms_leaves, + acc_transforms_tree, + a_transforms_tree, + b_transforms_tree, + barrier_transforms_tree, + a_scale_transforms_tree, + b_scale_transforms_tree, + a_sparse_metadata_transforms_tree, + collective_axis, + arrive, + scaled: bool, + sparse: bool, +): + _, a_aval, b_aval, *_ = ctx.avals_in + lhs_swizzle: int | None = None + lhs_transpose: bool = False + if arrive: + barrier_ref, *scales_and_transforms_leaves = barrier_scales_and_transforms_leaves + else: + barrier_ref = None + scales_and_transforms_leaves = barrier_scales_and_transforms_leaves # type: ignore[assignment] + if scaled: + a_scale_ref, b_scale_ref, *transforms_leaves = scales_and_transforms_leaves + else: + a_scale_ref = b_scale_ref = None + transforms_leaves = scales_and_transforms_leaves # type: ignore[assignment] + if sparse: + a_sparse_metadata_ref, *transforms_leaves = transforms_leaves + else: + a_sparse_metadata_ref = None + + transforms_trees = ( + acc_transforms_tree, + a_transforms_tree, + b_transforms_tree, + barrier_transforms_tree, + a_scale_transforms_tree, + b_scale_transforms_tree, + a_sparse_metadata_transforms_tree, + ) + ( + acc_transforms_leaves, + a_transforms_leaves, + b_transforms_leaves, + barrier_transforms_leaves, + a_scale_transforms_leaves, + b_scale_transforms_leaves, + a_sparse_metadata_transforms_leaves, + leftovers, + ) = util.split_list( + transforms_leaves, + [getattr(tree, "num_leaves", 0) for tree in transforms_trees], + ) + assert not leftovers + + if acc_transforms_tree is not None: + acc_transforms = acc_transforms_tree.unflatten(acc_transforms_leaves) + acc, acc_transforms = lowering._handle_transforms( + ctx, acc, acc_transforms, handle_transposes=False + ) + if acc_transforms: + raise NotImplementedError( + f"Unsupported transforms for ACC: {acc_transforms}." + ) + + if a_transforms_tree is not None: + a_transforms = a_transforms_tree.unflatten(a_transforms_leaves) + a_dtype = lowering._transform_dtype(a_aval.dtype, a_transforms) + a_ref, a_transforms = lowering._handle_transforms( + ctx, a_ref, a_transforms, handle_transposes=False, handle_reshapes=True + ) + match a_transforms: + case (gpu_core.UnswizzleRef(lhs_swizzle), gpu_core.UntileRef(lhs_tiling)): + lhs_transpose = False + case ( + gpu_core.UnswizzleRef(lhs_swizzle), + gpu_core.UntileRef(lhs_tiling), + gpu_core.TransposeRef((1, 0)), + ): + lhs_transpose = True + case () if isinstance(a_ref, tcgen05.TMEMRef): + lhs_tiling = None # type: ignore + case _: + raise NotImplementedError( + f"Unsupported transforms for LHS: {a_transforms}." + ) + if not isinstance(a_ref, tcgen05.TMEMRef): + swizzle_elems = 8 * lhs_swizzle // dtypes.itemsize_bits(a_dtype) # type: ignore + if lhs_tiling != (8, swizzle_elems): + raise ValueError("MMA lhs tiling does not fit swizzle. " + f"{lhs_tiling=} expected={(8, swizzle_elems)}") + + assert b_transforms_tree is not None + b_transforms = b_transforms_tree.unflatten(b_transforms_leaves) + b_dtype = lowering._transform_dtype(b_aval.dtype, b_transforms) + b_ref, b_transforms = lowering._handle_transforms( + ctx, b_ref, b_transforms, handle_transposes=False, handle_reshapes=True + ) + match b_transforms: + case (gpu_core.UnswizzleRef(rhs_swizzle), gpu_core.UntileRef(rhs_tiling)): + rhs_transpose = False + case ( + gpu_core.UnswizzleRef(rhs_swizzle), + gpu_core.UntileRef(rhs_tiling), + gpu_core.TransposeRef((1, 0)), + ): + rhs_transpose = True + case _: + raise NotImplementedError( + f"Unsupported transforms for RHS: {b_transforms}." + ) + swizzle_elems = 8 * rhs_swizzle // dtypes.itemsize_bits(b_dtype) + if rhs_tiling != (8, swizzle_elems): + raise ValueError( + "MMA rhs tiling does not fit swizzle" + f" {rhs_tiling=} expected={(8, swizzle_elems)}" + ) + + if barrier_transforms_tree is not None and barrier_ref is not None: + barrier_transforms = barrier_transforms_tree.unflatten( + barrier_transforms_leaves + ) + base_index = _extract_barrier_slice_base(barrier_transforms) + if base_index is not None: + barrier_ref = barrier_ref[base_index] + + if lhs_swizzle is None: + lhs_swizzle = rhs_swizzle + elif rhs_swizzle != lhs_swizzle: + raise ValueError("MMA rhs swizzle must match lhs swizzle." + f" {lhs_swizzle=} {rhs_swizzle=}") + if lhs_transpose: + if isinstance(a_ref, tcgen05.TMEMRef): + raise ValueError("TMEM transpose not allowed.") + a_ref = mgpu.memref_transpose(a_ref, (1, 0, 3, 2)) + if rhs_transpose: + b_ref = mgpu.memref_transpose(b_ref, (1, 0, 3, 2)) + if isinstance(accumulate, bool): + accumulate = mgpu.c(accumulate, ir.IntegerType.get_signless(1)) + elif isinstance(accumulate, mgpu.FragmentedArray): + accumulate = accumulate.registers.item() + assert isinstance(accumulate, ir.Value) + + if a_scale_transforms_tree is not None: + a_scale_transforms = a_scale_transforms_tree.unflatten( + a_scale_transforms_leaves + ) + a_scale_ref, a_scale_transforms = lowering._handle_transforms( + ctx, a_scale_ref, a_scale_transforms + ) + if a_scale_transforms: + raise NotImplementedError(f"Unsupported transforms: {a_scale_transforms}") + if b_scale_transforms_tree is not None: + b_scale_transforms = b_scale_transforms_tree.unflatten( + b_scale_transforms_leaves + ) + b_scale_ref, b_scale_transforms = lowering._handle_transforms( + ctx, b_scale_ref, b_scale_transforms + ) + if b_scale_transforms: + raise NotImplementedError(f"Unsupported transforms: {b_scale_transforms}") + if a_sparse_metadata_transforms_tree is not None: + a_sparse_metadata_transforms = a_sparse_metadata_transforms_tree.unflatten( + a_sparse_metadata_transforms_leaves + ) + a_sparse_metadata_ref, a_sparse_metadata_transforms = ( + lowering._handle_transforms( + ctx, a_sparse_metadata_ref, a_sparse_metadata_transforms + ) + ) + if a_sparse_metadata_transforms: + raise NotImplementedError( + f"Unsupported transforms: {a_sparse_metadata_transforms}" + ) + + predicate = ctx.module_ctx.single_lane_predicate + if collective_axis is not None: + is_leader_block = _collective_mma_predicate(ctx, collective_axis) + predicate = arith_dialect.andi(predicate, is_leader_block) + collective = True + else: + collective = False + + with mgpu.when(predicate): + tcgen05.mma( + acc, + a_ref, + b_ref, + a_swizzle=int(lhs_swizzle), + b_swizzle=int(rhs_swizzle), + a_scale=a_scale_ref, + b_scale=b_scale_ref, + a_sparse_metadata=a_sparse_metadata_ref, + accumulate=accumulate, + collective=collective, + ) + if arrive: + tcgen05.commit_arrive(barrier_ref, + collective=collective, + ctx=ctx.launch_ctx) + return [] + + +@lowering.register_lowering_rule( + tcgen05_mma_p, mgpu.LoweringSemantics.Warpgroup +) +def _tcgen05_mma_lowering_wg( + ctx: lowering.LoweringRuleContext, + acc_ref, + a_ref, + b_ref, + accumulate: bool | ir.Value, + *barrier_scales_and_transforms_leaves, + acc_transforms_tree, + a_transforms_tree, + b_transforms_tree, + barrier_transforms_tree, + a_scale_transforms_tree, + b_scale_transforms_tree, + a_sparse_metadata_transforms_tree, + collective_axis, + arrive, + scaled: bool, + sparse: bool, +): + del ( + a_scale_transforms_tree, + b_scale_transforms_tree, + a_sparse_metadata_transforms_tree, + ) + if scaled or sparse: + raise NotImplementedError( + "Scaled and sparse MMAs not supported for WG semantics." + ) + + if arrive: + barrier_ref, *transforms_leaves = barrier_scales_and_transforms_leaves + else: + barrier_ref = None + transforms_leaves = barrier_scales_and_transforms_leaves # type: ignore[assignment] + + transforms_trees = ( + acc_transforms_tree, + a_transforms_tree, + b_transforms_tree, + barrier_transforms_tree, + ) + ( + acc_transforms_leaves, + a_transforms_leaves, + b_transforms_leaves, + barrier_transforms_leaves, + leftovers, + ) = util.split_list( + transforms_leaves, + [getattr(tree, "num_leaves", 0) for tree in transforms_trees], + ) + assert not leftovers + + if acc_transforms_tree is not None: + acc_transforms = acc_transforms_tree.unflatten(acc_transforms_leaves) + acc_ref, acc_transforms = lowering._handle_transforms( + ctx, acc_ref, acc_transforms, handle_transposes=False + ) + if acc_transforms: + raise NotImplementedError( + f"Unsupported transforms for ACC: {acc_transforms}." + ) + + if a_transforms_tree is not None: + a_transforms = a_transforms_tree.unflatten(a_transforms_leaves) + a_aval = ctx.avals_in[1] + assert isinstance(a_aval, state_types.AbstractRef) + handle_transposes = a_aval.memory_space == gpu_core.SMEM + a_ref, a_transforms = lowering._handle_transforms( + ctx, a_ref, a_transforms, handle_transposes=handle_transposes + ) + if a_transforms: + raise NotImplementedError( + f"Unsupported transforms for LHS: {a_transforms}." + ) + + if b_transforms_tree is not None: + b_transforms = b_transforms_tree.unflatten(b_transforms_leaves) + b_ref, b_transforms = lowering._handle_transforms(ctx, b_ref, b_transforms) + if b_transforms: + raise NotImplementedError( + f"Unsupported transforms for RHS: {b_transforms}." + ) + + if barrier_transforms_tree is not None and barrier_ref is not None: + barrier_transforms = barrier_transforms_tree.unflatten( + barrier_transforms_leaves + ) + base_index = _extract_barrier_slice_base(barrier_transforms) + if base_index is not None: + barrier_ref = barrier_ref[base_index] + + predicate_ctx: contextlib.AbstractContextManager[None] + if collective_axis is not None: + predicate_ctx = mgpu.when(_collective_mma_predicate(ctx, collective_axis)) + collective = True + else: + predicate_ctx = contextlib.nullcontext() + collective = False + + if isinstance(accumulate, bool): + i1 = ir.IntegerType.get_signless(1) + accumulate = arith_dialect.constant(i1, accumulate) + + with predicate_ctx: + mgpu.dialect.tcgen05_mma( + acc_ref, + a_ref, + b_ref, + accumulate=accumulate, + collective=collective, + ) + if arrive: + assert isinstance(barrier_ref, mgpu.DialectBarrierRef) + tcgen05.commit_arrive(barrier_ref.get_ptr(), collective, ctx.launch_ctx) + return [] + + +tcgen05_commit_arrive_p = jax_core.Primitive("tcgen05_commit_arrive") +tcgen05_commit_arrive_p.multiple_results = True + + +def tcgen05_commit_arrive(barrier: _Ref, + collective_axis: str | None = None): + """Tracks completion of a preceding ``tcgen05_mma`` call. + + Args: + barrier: Barrier Ref for synchronizing with the tensor core. Must have + orders_tensor_core set to True. + collective_axis: The name of the cluster axis along which the + MMA was performed if it was collective. The cluster axis should have a + size of exactly 2, and must be on the minormost cluster axis. + + See also: + :func:`jax.experimental.pallas.mosaic_gpu.tcgen05_mma` + """ + if isinstance(barrier, pallas_core.TransformedRef): + barrier_transforms_leaves, barrier_transforms_tree = jax.tree.flatten( + barrier.transforms + ) + barrier = barrier.ref + else: + barrier_transforms_leaves, barrier_transforms_tree = [], None + + tcgen05_commit_arrive_p.bind( + barrier, *barrier_transforms_leaves, + barrier_transforms_tree=barrier_transforms_tree, + collective_axis=collective_axis) + + +@tcgen05_commit_arrive_p.def_abstract_eval +def _tcgen05_commit_arrive_abstract_eval(barrier, + *barrier_transforms_leaves, + barrier_transforms_tree, + collective_axis): + del (barrier_transforms_leaves, barrier_transforms_tree, collective_axis) + orders_tensor_core = getattr( + barrier.inner_aval.dtype, "orders_tensor_core", False) + if not orders_tensor_core: + raise ValueError("MMA barrier must have orders_tensor_core set to True.") + return [] + + +@lowering.register_lowering_rule( + tcgen05_commit_arrive_p, *gpu_core.LANExWG_SEMANTICS) +@lowering.register_lowering_rule( + tcgen05_commit_arrive_p, *gpu_core.LANExWARP_SEMANTICS) +def _tcgen05_commit_arrive_lowering( + ctx: lowering.LoweringRuleContext, + barrier_ref: mgpu.BarrierRef, + *barrier_transforms_leaves, + barrier_transforms_tree, + collective_axis, +): + if barrier_transforms_tree is not None: + barrier_transforms = barrier_transforms_tree.unflatten( + barrier_transforms_leaves + ) + base_index = _extract_barrier_slice_base(barrier_transforms) + if base_index is not None: + barrier_ref = barrier_ref[base_index] + + predicate = ctx.module_ctx.single_lane_predicate + if collective_axis is not None: + is_leader_block = _collective_mma_predicate(ctx, collective_axis) + predicate = arith_dialect.andi(predicate, is_leader_block) + collective = True + else: + collective = False + + with mgpu.when(predicate): + tcgen05.commit_arrive(barrier_ref, + collective=collective, + ctx=ctx.launch_ctx) + return [] + + +@lowering.register_lowering_rule( + tcgen05_commit_arrive_p, mgpu.LoweringSemantics.Warpgroup +) +def _tcgen05_commit_arrive_lowering_wg( + ctx: lowering.LoweringRuleContext, + barrier_ref: mgpu.DialectBarrierRef, + *barrier_transforms_leaves, + barrier_transforms_tree, + collective_axis, +): + if barrier_transforms_tree is not None: + barrier_transforms = barrier_transforms_tree.unflatten( + barrier_transforms_leaves + ) + base_index = _extract_barrier_slice_base(barrier_transforms) + if base_index is not None: + barrier_ref = barrier_ref[base_index] + + predicate_ctx: contextlib.AbstractContextManager[None] + if collective_axis is not None: + predicate_ctx = mgpu.when(_collective_mma_predicate(ctx, collective_axis)) + collective = True + else: + predicate_ctx = contextlib.nullcontext() + collective = False + + with predicate_ctx: + tcgen05.commit_arrive(barrier_ref.get_ptr(), collective, ctx.launch_ctx) + return [] + + +def _collective_mma_predicate(ctx: lowering.LoweringRuleContext, + collective_axis: str) -> ir.Value: + """Computes a predicate to run only on the leader block.""" + cluster_axis = lowering._resolve_cluster_axis( + ctx.module_ctx.axis_names, collective_axis) + if cluster_axis != gpu_dialect.Dimension(0): + # Note: resolve_cluster_axis checks if axis_names exists. + assert ctx.module_ctx.axis_names is not None + if len(ctx.module_ctx.axis_names.cluster) <= 1: + raise ValueError("No cluster axes found.") + minormost_cluster_axis = ctx.module_ctx.axis_names.cluster[0] + raise ValueError( + "Can only perform collective MMA along minormost cluster axis. " + f"Got {collective_axis}, expected {minormost_cluster_axis}.") + index = ir.IndexType.get() + is_leader_block = arith_dialect.cmpi( + arith_dialect.CmpIPredicate.eq, + ctx.launch_ctx.cluster_idx(cluster_axis), mgpu.c(0, index)) + return is_leader_block + + +commit_tmem_p = jax_core.Primitive("commit_tmem") +commit_tmem_p.multiple_results = True + + +@commit_tmem_p.def_effectful_abstract_eval +def _commit_tmem_abstract_eval(): + return (), {gpu_core._memory_effect} + + +@lowering.register_lowering_rule(commit_tmem_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule( + commit_tmem_p, mgpu.LoweringSemantics.Warpgroup +) +def _commit_tmem_lowering(_): + tcgen05.commit_tmem() + return () -def layout_cast(x: Any, new_layout: Layout | ParameterizedLayout): - """Casts the layout of the given array.""" - return layout_cast_p.bind(x, new_layout=new_layout) +def commit_tmem(): + """Commits all writes to TMEM issued by the current thread. + + Once this function returns, the effects of calling ``async_store_tmem`` from + the current thread are visible to TMEM loads, MMA and barrier operations of + ``Barrier``s with ``orders_tensor_core=True``. + """ + commit_tmem_p.bind() set_max_registers_p = jax_core.Primitive("set_max_registers_p") @@ -804,7 +2127,10 @@ def _set_max_registers_abstract_eval(n, *, action): return (), {gpu_core._memory_effect} -@lowering.register_lowering_rule(set_max_registers_p, mgpu.ThreadSemantics.Lane) +@lowering.register_lowering_rule( + set_max_registers_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule( + set_max_registers_p, mgpu.LoweringSemantics.Warpgroup) def _set_max_registers_lowering( ctx: lowering.LoweringRuleContext, n, *, action ): @@ -832,62 +2158,46 @@ def _commit_smem_abstract_eval(): return (), {gpu_core._memory_effect} -@lowering.register_lowering_rule(commit_smem_p, mgpu.ThreadSemantics.Lane) -@lowering.register_lowering_rule(commit_smem_p, mgpu.ThreadSemantics.Warpgroup) +@lowering.register_lowering_rule(commit_smem_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule( + commit_smem_p, mgpu.LoweringSemantics.Warpgroup) def _commit_smem_lowering(ctx: lowering.LoweringRuleContext): + # TODO(bchetioui): add primitive for commit smem to mosaic_gpu dialect. mgpu.commit_shared() return () def commit_smem(): - """Commits all writes to SMEM, making them visible to loads, TMA and WGMMA.""" + """Commits all writes to SMEM, making them visible to TMA and MMA operations.""" commit_smem_p.bind() -broadcasted_iota_p = jax_core.Primitive("broadcasted_iota") - -@broadcasted_iota_p.def_abstract_eval -def _broadcasted_iota_abstract_eval(dtype, shape, dimension, layout): - del layout, dimension - return jax_core.ShapedArray(shape, dtype) - - -@lowering.register_lowering_rule(broadcasted_iota_p, mgpu.ThreadSemantics.Lane) -def _broadcasted_iota_lowering( - ctx: lowering.LoweringRuleContext, dtype, shape, dimension, layout -): - del ctx # Unused. - mlir_dtype = mgpu_utils.dtype_to_ir_type(dtype) - if ir.FloatType.isinstance(mlir_dtype): - i32 = ir.IntegerType.get_signless(32) - cast = lambda x: arith_dialect.uitofp( - mlir_dtype, arith_dialect.index_cast(i32, x) - ) - else: - cast = lambda x: arith_dialect.index_cast(mlir_dtype, x) - is_signed = mgpu_utils.is_signed(dtype) - return mgpu.FragmentedArray.splat( - llvm_dialect.mlir_undef(mlir_dtype), - shape, - layout.to_mgpu(), - is_signed=is_signed, - ).foreach( - lambda _, idx: cast(idx[dimension]), - create_array=True, - is_signed=is_signed, - ) - - def broadcasted_iota( dtype: jax.typing.DTypeLike, shape: Sequence[int], dimension: int, *, - layout: Layout | None = None, + layout: SomeLayout | None = None, ) -> jax.Array: - return broadcasted_iota_p.bind( - dtype=jnp.dtype(dtype), shape=shape, dimension=dimension, layout=layout - ) + result = jax.lax.broadcasted_iota(dtype, shape, dimension) + if layout is not None: + result = gpu_core.layout_cast(result, layout) + return result + + +@lowering.register_lowering_rule(jax_core.closed_call_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule(jax_core.closed_call_p, mgpu.LoweringSemantics.Warpgroup) +def _closed_call_lowering_rule(ctx, *args, call_jaxpr: jax_core.ClosedJaxpr): + if call_jaxpr.consts: raise NotImplementedError + return lowering.lower_jaxpr_to_mosaic_gpu( + ctx.module_ctx, ctx.launch_ctx, call_jaxpr.jaxpr, args) + + +@lowering._register_resource_estimator(jax_core.closed_call_p) +def _closed_call_resource_estimator(ctx, *args, call_jaxpr): + del args # Unused. + if call_jaxpr.consts: raise NotImplementedError + return lowering._estimate_resources(ctx, call_jaxpr.jaxpr) jaxpr_call_p = jax_core.Primitive("jaxpr_call") @@ -900,8 +2210,42 @@ def _jaxpr_call_abstract_eval(*args, jaxpr: jax_core.Jaxpr, **params): return [v.aval for v in jaxpr.outvars] -@lowering.register_lowering_rule(jaxpr_call_p, mgpu.ThreadSemantics.Lane) -@lowering.register_lowering_rule(jaxpr_call_p, mgpu.ThreadSemantics.Warpgroup) +def _jaxpr_call_pp_eqn( + eqn: jax_core.JaxprEqn, + context: jax_core.JaxprPpContext, + settings: jax_core.JaxprPpSettings, +): + flat_args = eqn.invars + ref_treedefs = eqn.params["ref_treedefs"] + flat_refs, _ = util.split_list( + flat_args, [sum(treedef.num_leaves for treedef in ref_treedefs)] + ) + flat_refs = util.split_list( + flat_refs, + [treedef.num_leaves for treedef in ref_treedefs[: len(ref_treedefs) - 1]], + ) + trailer = [] + for treedef, flat_ref in zip(ref_treedefs, flat_refs): + ref = treedef.unflatten(flat_ref) + transforms = [] + if isinstance(ref, tuple): + ref, transforms = ref + trailer.append(pp.text(" ")) + trailer.append(state_primitives.pp_ref_transforms(context, ref, transforms)) + return pp.concat([ + pp.text("jaxpr_call"), + pp.text("["), + jax_core.pp_kv_pair("jaxpr", eqn.params["jaxpr"], context, settings), + pp.text("]"), + pp.concat(trailer), + ]) + + +jax_core.pp_eqn_rules[jaxpr_call_p] = _jaxpr_call_pp_eqn + + +@lowering.register_lowering_rule(jaxpr_call_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule(jaxpr_call_p, mgpu.LoweringSemantics.Warpgroup) def _jaxpr_call_lowering_rule( ctx: lowering.LoweringRuleContext, *flat_args, @@ -920,15 +2264,20 @@ def _jaxpr_call_lowering_rule( for treedef, flat_ref in zip(ref_treedefs, flat_refs): ref = treedef.unflatten(flat_ref) if isinstance(ref, tuple): + ref, transforms = ref # We ignore other transforms here, because they are already embedded # in the jaxpr. - ref, _ = lowering._handle_indexing(*ref) + ref, _ = lowering._handle_transforms( + ctx, ref, transforms, handle_reshapes=False, handle_transposes=False + ) args.append(ref) program_ids = program_ids_treedef.unflatten(flat_program_ids) for axis, pid in enumerate(program_ids): if pid is not None: continue - program_ids[axis] = lowering._program_id(axis, ctx.module_ctx.squashed_dims) + program_ids[axis] = lowering._program_id( + axis, ctx.module_ctx.squashed_dims, len(program_ids) + ) new_module_ctx = dataclasses.replace(ctx.module_ctx, program_ids=program_ids) return lowering.lower_jaxpr_to_mosaic_gpu( new_module_ctx, ctx.launch_ctx, jaxpr, args @@ -969,7 +2318,7 @@ def _jaxpr_call_discharge( outs = jaxpr_call_p.bind( *flat_args, jaxpr=discharged_jaxpr, - ref_treedefs=ref_treedefs, + ref_treedefs=tuple(ref_treedefs), program_ids_treedef=program_ids_treedef, ) discharged_outs_it = iter(outs[len(jaxpr.outvars) :]) @@ -985,7 +2334,7 @@ def _jaxpr_call_discharge( def jaxpr_call( jaxpr: jax_core.Jaxpr, - *refs: pallas_core.AbstractMemoryRef | state_types.TransformedRef, + *refs: state.AbstractRef | state_types.TransformedRef, program_ids: Sequence[jax.Array | None], ) -> Sequence[jax.Array]: """Internal primitive for calling a kernel jaxpr inside ``emit_pipeline``. @@ -1022,6 +2371,1540 @@ def jaxpr_call( *flat_refs, *flat_program_ids, jaxpr=jaxpr, - ref_treedefs=ref_treedefs, + ref_treedefs=tuple(ref_treedefs), program_ids_treedef=program_ids_treedef, ) + + +@dataclasses.dataclass(frozen=True) +class ShapeDtypeStruct: + shape: tuple[int, ...] + dtype: jnp.dtype + layout: SomeLayout + + +inline_mgpu_p = jax_core.Primitive("inline_mgpu_p") +inline_mgpu_p.multiple_results = True + + +@dataclasses.dataclass(frozen=True) +class RefType: + transforms: tuple[gpu_core.MemoryRefTransform, ...] = () + + +def _undo_transforms( + raw_ref: state.AbstractRef, + memory_transforms: Sequence[gpu_core.MemoryRefTransform], +): + """Extract the `Transform`s that reverse the `MemoryRefTransform`s""" + tmp_ref = state_types.TransformedRef(raw_ref, transforms=()) + tmp_ref = functools.reduce(lambda r, t: t.undo(r), reversed(memory_transforms), tmp_ref) + return tmp_ref.transforms + + +def inline_mgpu(*, arg_types=(), return_type=None): + r"""Returns a decorator that inlines Mosaic GPU code. + + This allows using lower-level Mosaic GPU abstractions and operations, which + are otherwise not directly exposed in Pallas. + + Example:: + + layout = plgpu.Layout.WG_STRIDED(x_ref.shape, vec_size=4) + + @plgpu.inline_mgpu( + arg_types=(plgpu.RefType(),), + return_type=plgpu.ShapeDtypeStruct( + (128, 128), dtype, layout=layout + ), + ) + def add_one(ctx, smem_ref): + x = mgpu.FragmentedArray.load_tiled(smem_ref) + y = mgpu.FragmentedArray.splat( + mgpu.c(1, x.mlir_dtype), shape=x.shape, layout=x.layout + ) + return x + y + + Args: + arg_types: A sequence of pytrees where the leaves are + :class:`~jax.experimental.pallas.mosaic_gpu.RefType`\s or + :class:`~jax.experimental.pallas.mosaic_gpu.Layout`\s for reference or + array arguments respectively. + return_type: A pytree where the leaves are + :class:`~jax.experimental.pallas.mosaic_gpu.ShapeDtypeStruct`\s + representing the arrays returned by the decorated function. + """ + flat_arg_types, treedef_ty = jax.tree.flatten(tuple(arg_types)) + flat_ret_ty, pytree_ret_ty = jax.tree.flatten(return_type) + if return_type and not all(isinstance(r, ShapeDtypeStruct) for r in flat_ret_ty): + raise ValueError( + "inline_mgpu_p only supports plgpu.ShapeDtypeStruct return types." + ) + if not all(isinstance(r, (SomeLayout, RefType)) for r in flat_arg_types): + raise ValueError( + "inline_mgpu_p only supports only SomeLayout and RefType arg types." + ) + + def inner(f): + def wrapper(*args): + flat_args, treedef = jax.tree.flatten(tuple(args)) + if treedef != treedef_ty: + raise ValueError(f"Mismatched type shape: {treedef} != {treedef_ty}") + + # Strip the transforms from the refs since they will be recorded in + # the types. + ref_transforms = [] + raw_flat_args = [] + for a, t in zip(flat_args, flat_arg_types): + if isinstance(a, state_types.TransformedRef) and isinstance(t, RefType): + raw_flat_args.append(a.ref) + ref_transforms.append(a.transforms) + elif isinstance(aval := jax_core.get_aval(a), jax_core.ShapedArray) and isinstance(t, SomeLayout): + raw_flat_args.append(a) + ref_transforms.append(None) + elif isinstance(aval, state.AbstractRef) and isinstance(t, RefType): + raw_flat_args.append(a) + ref_transforms.append(()) + else: + raise ValueError(f"Mismatched type: {a, t}") + + flat_ref_transforms, pytree_ref_transforms = jax.tree.flatten(ref_transforms) + flat_ret = inline_mgpu_p.bind( + *raw_flat_args, + *flat_ref_transforms, + flat_arg_types=tuple(flat_arg_types), + flat_ret_ty=tuple(flat_ret_ty), + pytree_ret_ty=pytree_ret_ty, + pytree_args=treedef, + pytree_ref_transforms=pytree_ref_transforms, + mgpu_fn=f, + ) + return jax.tree.unflatten(pytree_ret_ty, flat_ret) + return wrapper + + return inner + + +@inline_mgpu_p.def_effectful_abstract_eval +def _inline_mgpu_abstract_eval( + *flat_args_and_transforms, + flat_arg_types, + flat_ret_ty, + pytree_args, + pytree_ref_transforms, + pytree_ret_ty, + mgpu_fn, +): + del flat_arg_types, pytree_ret_ty, pytree_ref_transforms, mgpu_fn # Unused. + aval_return = tuple( + jax_core.ShapedArray(x.shape, x.dtype) for x in flat_ret_ty + ) + # TODO(cperivol): Let the user set the effects. + flat_args = flat_args_and_transforms[:pytree_args.num_leaves] + return aval_return, { + gpu_core._wgmma_pipeline_effect, + gpu_core._memory_effect, + *itertools.chain.from_iterable( + (state.ReadEffect(i), state.WriteEffect(i)) + for i, r in enumerate(flat_args) + if isinstance(r, state.AbstractRef) + ), + } + + +@discharge.register_partial_discharge_rule(inline_mgpu_p) +def _inline_mgpu_discharge(*args, **kwargs): + del args, kwargs + raise NotImplementedError("inline_mgpu_p does not support discharge.") + + +def _type_check_mgpu_lane_semantics(v, ty): + match (ty, v): + case (RefType(), ir.Value()) if isinstance(v.type, ir.MemRefType): + pass + case (ShapeDtypeStruct(), mgpu.FragmentedArray()): + mlir_dtype = mgpu_utils.dtype_to_ir_type(ty.dtype) + if v.mlir_dtype != mlir_dtype: + raise ValueError( + f"Array dtype mismatch: expected {v.mlir_dtype} got {mlir_dtype}." + ) + if ty.shape != v.shape: + raise ValueError( + f"Array shape mismatch: expected {ty.shape} got {v.shape}." + ) + if v.layout != ty.layout.to_mgpu(): + raise ValueError( + f"Array layout mismatch: expected {v.layout} got {ty.layout.to_mgpu()}." + ) + case (SomeLayout(), mgpu.FragmentedArray()): + if ty.to_mgpu() != v.layout: + raise ValueError(f"Unexpected layout for {v} (expected: {ty})") + case _: + raise ValueError(f"Unexpected type {ty} for value {v}") + + +def _inline_mgpu_flat_transformed_args( + ctx: lowering.LoweringRuleContext, + flat_args_and_transforms, + flat_arg_types, + pytree_args, + pytree_ref_transforms, + ) -> Sequence[ir.Value]: + flat_args = flat_args_and_transforms[:pytree_args.num_leaves] + flat_arg_avals = ctx.avals_in[:pytree_args.num_leaves] + ref_transforms = pytree_ref_transforms.unflatten(flat_args_and_transforms[pytree_args.num_leaves:]) + is_wg_semantics = ( + ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Warpgroup + ) + + if not is_wg_semantics: + for a, t in zip(flat_args, flat_arg_types): + _type_check_mgpu_lane_semantics(a, t) + + flat_transformed : list[ir.Value] = [] + for a, aval, t, transforms in zip( + flat_args, flat_arg_avals, flat_arg_types, ref_transforms, strict=True + ): + if not isinstance(t, RefType): + flat_transformed.append(a) + assert transforms is None + continue + assert isinstance(aval, state.AbstractRef) + a, user_transforms = lowering._handle_transforms( + ctx, + a, + transforms, + handle_transposes=is_wg_semantics, + ) + + if is_wg_semantics: + if user_transforms: + raise NotImplementedError( + "Not all transforms could be handled. Remaining transforms:" + f" {user_transforms}." + ) + else: + # Transforms that do not originate from a MemoryRefTransform are + # applied implicitly (eg by emit-pipeline) and therefore we do not + # expect the user to pass them to the type. The transforms not + # passed by the user here will be discharged. + ty_transforms = _undo_transforms(aval, t.transforms) + if ty_transforms != tuple(user_transforms): + raise ValueError(f"Transform mismatch: got {user_transforms}, expected {ty_transforms}") + flat_transformed.append(a) + + return flat_transformed + + +@lowering.register_lowering_rule(inline_mgpu_p, mgpu.LoweringSemantics.Lane) +def _inline_mgpu_lowering_rule( + ctx: lowering.LoweringRuleContext, + *flat_args_and_transforms, + mgpu_fn: Callable[..., Any], + flat_arg_types, + flat_ret_ty, + pytree_args, + pytree_ref_transforms, + pytree_ret_ty, +): + flat_transformed = _inline_mgpu_flat_transformed_args( + ctx, + flat_args_and_transforms, + flat_arg_types, + pytree_args, + pytree_ref_transforms, + ) + args = jax.tree.unflatten(pytree_args, flat_transformed) + ret = mgpu_fn(ctx.launch_ctx, *args) + ret_leaves, ret_tree = jax.tree.flatten( + ret, lambda x: isinstance(x, mgpu.FragmentedArray) + ) + + if ret_tree != pytree_ret_ty: + return_type = jax.tree.unflatten(pytree_ret_ty, flat_ret_ty) + raise ValueError( + f"inline_mgpu_p return type tree mismatch: {ret} != {return_type}" + ) + + for ty, r in zip(flat_ret_ty, ret_leaves): + _type_check_mgpu_lane_semantics(r, ty) + + return ret_leaves + + +def _ref_type_to_transforms(ref_type: RefType) -> ir.ArrayAttribute: + """Returns the Mosaic GPU transforms for the given ref type.""" + transform_attrs = [t.to_gpu_transform_attr() for t in ref_type.transforms] + return ir.ArrayAttr.get(transform_attrs) + + +def _replace_uses_in_block(old: ir.Value, new: ir.Value, block: ir.Block): + """Replaces all uses of the `old` value with the `new` value in `block`.""" + + def is_contained_within_block(operand: ir.OpOperand, block: ir.Block) -> bool: + current_op = operand.owner.operation + while (parent := current_op.parent) is not None: + if current_op.block == block: + return True + current_op = parent + return False + + for use in old.uses: + if is_contained_within_block(use, block): + use.owner.operands[use.operand_number] = new + + +def _clone_custom_op_with_extra_args( + custom_op: mgpu.dialect.CustomPrimitiveOp, extra_args: Sequence[ir.Value] +) -> mgpu.dialect.CustomPrimitiveOp: + """Clones a CustomPrimitiveOp and its block adding the given extra_args. + + The new args are not allowed to contain SMEM refs or vector types. The extra + args are added in order at the end of the existing parameter list. + + The reason we need to do this is because the custom primitive op has the + "IsolatedFromAbove" trait, which requires that its block does not close + over any values defined outside of it. When lowering the provided mgpu_fn, + it's possible that it closed over values from the conext (such as the SMEM + descriptors if it calls async_copy). Post-processing the original block + with this function is therefore required to restore the isolation property. + """ + for arg in extra_args: + if isinstance(arg.type, ir.MemRefType) and mgpu_utils.is_smem_ref(arg.type): + raise ValueError(f"Extra arg {arg} must not be an SMEM ref.") + if isinstance(arg.type, ir.VectorType): + raise ValueError(f"Extra arg {arg} must not have a vector type.") + + new_operands = list(custom_op.operands) + list(extra_args) + old_block = custom_op.body.blocks[0] + new_in_types = [a.type for a in list(old_block.arguments) + list(extra_args)] + + # Below, we can reuse all layouts and transforms, because the extra args + # are not smem refs or vectors. + new_op = mgpu.dialect.CustomPrimitiveOp( + result=custom_op.results, + operands_=new_operands, + in_layouts=custom_op.in_layouts, + in_transforms=custom_op.in_transforms, + out_layouts=custom_op.out_layouts, + ) + new_block = new_op.body.blocks.append(*new_in_types) + for op in old_block.operations: + new_block.append(op) + for old_arg, new_arg in zip(old_block.arguments, new_block.arguments): + old_arg.replace_all_uses_with(new_arg) + num_old_args = len(old_block.arguments) + for extra_arg, new_arg in zip( + extra_args, new_block.arguments[num_old_args:], strict=True + ): + _replace_uses_in_block(extra_arg, new_arg, new_block) + + return new_op + + +def _custom_primitive_in_specs( + ctx: lowering.LoweringRuleContext, + flat_arg_types, + flat_transformed_args, + pytree_args, +) -> tuple[Sequence[ir.Type], Sequence[ir.Attribute], Sequence[ir.ArrayAttr]]: + """Returns a tuple containing the list of MLIR input types, layouts, and + transforms for the given JAX array and ref arguments.""" + in_types = [] + in_layouts = [] + in_transforms : list[ir.ArrayAttr] = [] + flat_arg_avals = ctx.avals_in[:pytree_args.num_leaves] + for aval, transformed, t in zip( + flat_arg_avals, flat_transformed_args, flat_arg_types + ): + match aval: + case state.AbstractRef(): + initial_ty = ir.MemRefType(transformed.type) + in_types.append(initial_ty) + if mgpu_utils.is_smem_ref(initial_ty): + in_transforms.append(_ref_type_to_transforms(t)) + case jax_core.ShapedArray() if isinstance(t, SomeLayout): + el_type = mgpu_utils.dtype_to_ir_type(aval.dtype) + if len(aval.shape) == 0: + in_types.append(el_type) + else: + vector_type = ir.VectorType.get(aval.shape, el_type) + in_types.append(vector_type) + in_layouts.append(mgpu_layouts.to_layout_attr(t.to_mgpu())) + case _: + raise NotImplementedError( + f"Unsupported aval type: {aval}, {type(aval)}, {t}" + ) + return in_types, in_layouts, in_transforms + + +def _custom_primitive_op_results(flat_ret_ty) -> tuple[ + Sequence[ir.Type], + Sequence[ir.Attribute | None], +]: + """Returns a tuple containing the list of output MLIR types, and layouts for + the given JAX return types.""" + results_ty: list[ir.Type] = [] + out_layouts: list[ir.Attribute | None] = [] + for r in flat_ret_ty: + if not isinstance(r, ShapeDtypeStruct): + raise NotImplementedError(f"Expected a ShapeDtypeStruct, but got: {r}") + el_type = mgpu_utils.dtype_to_ir_type(r.dtype) + if not r.shape: # scalar case. + results_ty.append(el_type) + out_layouts.append(None) + else: + results_ty.append(ir.VectorType.get(r.shape, el_type)) + layout = mgpu_layouts.to_layout_attr(r.layout.to_mgpu()) + out_layouts.append(layout) + return results_ty, out_layouts + + +def _populate_custom_primitive_op_block( + ctx: lowering.LoweringRuleContext, + block: ir.Block, + mgpu_fn: Callable[..., Any], + pytree_args, + in_layouts: Sequence[ir.Attribute], + in_transforms: ir.ArrayAttr, + results_ty: Sequence[ir.Type], + out_layouts: Sequence[ir.Attribute | None], +): + """Calls the given mgpu_fn to populate the block, handling inputs and outputs. + + Block arguments that are references to SMEM or vectors are unwrapped to + transformed references and fragmented arrays before they are passed to the + python function mgpu_fn. + + The resulting fragmented arrays, if any, are wrapped as vectors before they + are returned. + """ + with ir.InsertionPoint(block): + fn_inputs = [] + in_layouts_it = iter(in_layouts) + in_transforms_it = iter(in_transforms) + avals_in = ctx.avals_in[:pytree_args.num_leaves] + for arg, aval in zip(block.arguments, avals_in, strict=True): + if isinstance(arg.type, ir.MemRefType): + memref_ty = ir.MemRefType(arg.type) + if not mgpu_utils.is_smem_ref(memref_ty): + fn_inputs.append(arg) + continue + + _, transforms = ( + mgpu.dialect_lowering.swizzle_and_transforms_from_transforms_attr( + next(in_transforms_it) + ) + ) + # The block arguments in the Mosaic GPU dialect are logical refs that + # wrap the transfromed refs. Since the mgpu_fn works at the lowered + # "lane" level, we need to transform (lower) the inputs before passing + # them to the mgpu_fn. + transformed_type = mgpu.dialect_lowering.transformed_smem_ref_type( + memref_ty, transforms + ) + conversion_cast = builtin_dialect.UnrealizedConversionCastOp( + [transformed_type], [arg] + ) + fn_inputs.append(conversion_cast.result) + elif isinstance(arg.type, ir.VectorType): + layout_attr = next(in_layouts_it) + layout = mgpu.layouts.from_layout_attr(layout_attr) + + vector_ty = ir.VectorType(arg.type) + reg_shape = layout.registers_shape(vector_ty.shape) + reg_ty = layout.registers_element_type(vector_ty.element_type) + + # The vector block arguments in the Mosaic GPU dialect are wrapped + # Fragmented Arrays. Since the mgpu_fn works at the lowered + # "lane" level, we need to unwrap (lower) the input vectors before + # passing them to the mgpu_fn. + conversion_cast = builtin_dialect.UnrealizedConversionCastOp( + [reg_ty] * math.prod(reg_shape), [arg] + ) + conversion_cast.attributes["registers_shape"] = ir.ArrayAttr.get([ + ir.IntegerAttr.get(ir.IntegerType.get_signless(64), s) + for s in reg_shape + ]) + conversion_cast.attributes["layout"] = layout_attr + + registers = np.array(list(conversion_cast.results)).reshape(reg_shape) + is_signed = mgpu_utils.is_signed(aval.dtype) + fa = mgpu.FragmentedArray( + _registers=registers, _layout=layout, _is_signed=is_signed + ) + fn_inputs.append(fa) + else: + fn_inputs.append(arg) + + args = jax.tree.unflatten(pytree_args, fn_inputs) + inner_ret = mgpu_fn(ctx.launch_ctx, *args) + if inner_ret is None: + inner_ret = [] + elif not isinstance(inner_ret, tuple) and not isinstance(inner_ret, list): + inner_ret = [inner_ret] + ir_ret = [] + for fa, result_ty, out_layout in zip( + inner_ret, results_ty, out_layouts, strict=True + ): + if not isinstance(fa, mgpu.FragmentedArray): + raise ValueError(f"Expected a FragmentedArray, but got: {fa}") + if isinstance(result_ty, ir.VectorType): + result_shape = ir.VectorType(result_ty).shape + if fa.shape != tuple(result_shape): + raise ValueError(f"Expected {result_shape} but got {fa.shape}") + if out_layout != mgpu.layouts.to_layout_attr(fa.layout): + raise ValueError( + f"Output layout {out_layout} does not match the layout of the" + f" returned fragmented array {fa.layout}." + ) + ir_ret.append( + mgpu.dialect_lowering.fragmented_array_to_ir(fa, result_ty) + ) + else: # scalar case. + assert out_layout is None + if fa.shape: + raise ValueError(f"Expected 0D shape, but got {fa.shape}") + if not isinstance(fa.layout, mgpu.WGSplatFragLayout): + raise ValueError(f"Expected WGSplatFragLayout, but got {fa.layout}") + value = fa.registers.item() + ir_ret.append(value) + + mgpu.dialect.ReturnOp(operands_=ir_ret) + + +def _closed_over_values(block: ir.Block) -> list[ir.Value]: + """Returns the values closed over in the given block.""" + def _closed_over_values_inner( + block: ir.Block, vals_in_block: set[ir.Value] + ) -> list[ir.Value]: + closed_over_values = [] + for arg in block.arguments: + vals_in_block.add(arg) + for op in block.operations: + for o in op.operands: + if o not in vals_in_block: + closed_over_values.append(o) + for r in op.regions: + for b in r.blocks: + closed_over_values.extend(_closed_over_values_inner(b, vals_in_block)) + for r in op.results: + vals_in_block.add(r) + return closed_over_values + return _closed_over_values_inner(block, set()) + + +@lowering.register_lowering_rule(inline_mgpu_p, mgpu.LoweringSemantics.Warpgroup) +def _inline_mgpu_lowering_rule_wg_semantics( + ctx: lowering.LoweringRuleContext, + *flat_args_and_transforms, + mgpu_fn: Callable[..., Any], + flat_arg_types, + flat_ret_ty, + pytree_args, + pytree_ref_transforms, + pytree_ret_ty, +): + del pytree_ret_ty + flat_transformed_args = _inline_mgpu_flat_transformed_args( + ctx, + flat_args_and_transforms, + flat_arg_types, + pytree_args, + pytree_ref_transforms, + ) + + in_types, in_layouts, in_transforms = ( + _custom_primitive_in_specs( + ctx, flat_arg_types, flat_transformed_args, pytree_args + ) + ) + results_ty, out_layouts = _custom_primitive_op_results(flat_ret_ty) + + custom_op = mgpu.dialect.CustomPrimitiveOp( + result=results_ty, + operands_=flat_transformed_args, + in_layouts=in_layouts, + in_transforms=in_transforms, + out_layouts=[l for l in out_layouts if l is not None], + ) + block : ir.Block = custom_op.body.blocks.append(*in_types) + _populate_custom_primitive_op_block( + ctx, + block, + mgpu_fn, + pytree_args, + in_layouts, + in_transforms, + results_ty, + out_layouts, + ) + + # We need to ensure that the block doesn't capture any values from the context + # and uses args for everything instead. E.g. `LaunchContext.tma_descriptors` + # will be captured when calling `ctx.async_copy`. + captured = _closed_over_values(block) + if captured: + old_custom_op = custom_op + custom_op = _clone_custom_op_with_extra_args(custom_op, captured) + old_custom_op.erase() + + return custom_op.results + + +load_p = jax_core.Primitive("load") + + +@load_p.def_effectful_abstract_eval +def _load_abstract_eval(src, *avals_flat, tree, optimized): + del optimized # Unused. + transforms = tree.unflatten(avals_flat) + dtype = lowering._transform_dtype(src.dtype, transforms) + transforms = list(transforms) + if not transforms or not isinstance(transforms[-1], indexing.NDIndexer): + ref_shape = state.get_transforms_shape(transforms, src.shape) + transforms.append(indexing.NDIndexer.make_trivial_indexer(ref_shape)) + shape = transforms[-1].get_indexer_shape() + return jax_core.ShapedArray(shape, dtype), {state.ReadEffect(0)} + + +lowering.register_lowering_rule(load_p, mgpu.LoweringSemantics.Lane)( + lowering._get_lowering_rule +) +lowering.register_lowering_rule( + load_p, mgpu.LoweringSemantics.Lane, gpu_core.PrimitiveSemantics.Warp +)( + lowering._get_lowering_rule +) +lowering.register_lowering_rule(load_p, mgpu.LoweringSemantics.Warpgroup)( + lowering._get_lowering_rule_wg +) + + +def load( + src: _Ref, + idx, + *, + layout: SomeLayout | None = None, + optimized: bool = True, +) -> jax.Array: + """Loads from a reference into an array with the specified layout. + + Args: + src: The reference to load from. Can be either in SMEM or GMEM. + idx: The index to load from. + layout: The optional layout to use for the resulting array. + optimized: If True, a compilation error will be raised if no optimized + implementation for the load is available. + + Returns: + The loaded array. + """ + src, src_transforms = state_primitives.get_ref_and_transforms( + src, idx, "load" + ) + flat_src_transforms, src_transforms_treedef = tree_util.tree_flatten( + src_transforms + ) + result = load_p.bind( + src, + *flat_src_transforms, + tree=src_transforms_treedef, + optimized=optimized, + ) + if layout is not None: + result = gpu_core.layout_cast(result, layout) + return result + + +async_load_tmem_p = jax_core.Primitive("async_load") + +def async_load_tmem(src: _Ref, *, layout: SomeLayout | None = None) -> jax.Array: + """Performs an asynchronous load from the TMEM array. + + The load operation is only partly asynchronous. The returned array can be used + immediately, without any additional synchronization. However, it cannot be + assumed that the read from TMEM has completed when the function returns. If + you ever attempt to overwrite the read region, you should ensure that + ``wait_load_tmem`` has been called before that happens. Failure to do so + can result in nondeterministic data races. + + For example, the following sequence of operations at the end of the kernel is + valid, even though the TMEM load is never awaited:: + + smem_ref[...] = plgpu.async_load_tmem(tmem_ref) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(smem_ref, gmem_ref) + plgpu.wait_smem_to_gmem(0) + + However, if the kernel was persistent and might reuse the TMEM again, the + sequence should be extended with a call to ``wait_load_tmem``. + + Args: + src: The TMEM reference to load from. + layout: The optional layout hint to use for the resulting array. + """ + src, src_transforms = state_primitives.get_ref_and_transforms( + src, None, "async_load_tmem" + ) + flat_src_transforms, src_transforms_treedef = tree_util.tree_flatten( + src_transforms + ) + result = async_load_tmem_p.bind( + src, *flat_src_transforms, tree=src_transforms_treedef + ) + if layout is not None: + result = gpu_core.layout_cast(result, layout) + return result + +@async_load_tmem_p.def_effectful_abstract_eval +def _async_load_tmem_abstract_eval(src, *avals_flat, tree): + if src.memory_space != gpu_core.MemorySpace.TMEM: + raise ValueError("Async load only supports TMEM refs") + return state_primitives._get_abstract_eval(src, *avals_flat, tree=tree) + +@lowering.register_lowering_rule(async_load_tmem_p, mgpu.LoweringSemantics.Lane) +def _async_load_tmem_lowering_rule( + ctx: lowering.LoweringRuleContext, x_ref, *leaves, tree +): + assert isinstance(x_ref, tcgen05.TMEMRef) + transforms = jax.tree.unflatten(tree, leaves) + x_tmem, transforms = lowering._handle_transforms( + ctx, x_ref, transforms, handle_transposes=False, handle_reshapes=False, + ) + if transforms: + raise NotImplementedError( + f"Unimplemented transforms for TMEM refs. {transforms=}" + ) + layout_hint = None + if isinstance(ctx.out_layout_hint, mgpu.TiledLayout): + layout_hint = ctx.out_layout_hint + is_signed = mgpu_utils.is_signed(ctx.avals_out[0].dtype) + return x_tmem.load(layout=layout_hint, is_signed=is_signed) + + +@lowering.register_lowering_rule( + async_load_tmem_p, mgpu.LoweringSemantics.Warpgroup +) +def _async_load_tmem_lowering_rule_wg( + ctx: lowering.LoweringRuleContext, x_ref: ir.Value, *leaves, tree +): + assert isinstance(x_ref, ir.Value) + assert isinstance(x_ref.type, ir.MemRefType) + + transforms = jax.tree.unflatten(tree, leaves) + x_tmem, transforms = lowering._handle_transforms( + ctx, + x_ref, + transforms, + handle_transposes=False, + handle_reshapes=False, + ) + if transforms: + raise NotImplementedError( + f"Unimplemented transforms for TMEM refs. {transforms=}" + ) + return mgpu.dialect.async_load_tmem(x_tmem) + + +wait_load_tmem_p = jax_core.Primitive("wait_load_tmem") +wait_load_tmem_p.multiple_results = True + +def wait_load_tmem(): + """Awaits all previously asynchronous TMEM loads issued by the calling thread. + + Once this function returns, the TMEM loads issued by the calling thread are + guaranteed to have completed. The read TMEM regions can be safely overwritten + by the calling thread, or any threads signalled through ``Barrier``s with + ``orders_tensor_core=True``. + """ + wait_load_tmem_p.bind() + + +@wait_load_tmem_p.def_effectful_abstract_eval +def _wait_load_tmem_abstract_eval(): + return (), {gpu_core._memory_effect} + + +@lowering.register_lowering_rule(wait_load_tmem_p, mgpu.LoweringSemantics.Lane) +def _wait_load_tmem_lowering(_): + tcgen05.wait_load_tmem() + return () + + +async_store_tmem_p = jax_core.Primitive("async_store_tmem") +async_store_tmem_p.multiple_results = True + +def async_store_tmem(ref: _Ref, value): + """Stores the value to TMEM. + + The store is asynchronous and is not guaranteed to be visible (e.g. by reads + or MMA operations) until ``commit_tmem`` has been called. + + Args: + ref: The TMEM reference to store to. + value: The value to store. + """ + ref, ref_transforms = state_primitives.get_ref_and_transforms( + ref, None, "async_store_tmem" + ) + flat_ref_transforms, ref_transforms_treedef = tree_util.tree_flatten( + ref_transforms + ) + async_store_tmem_p.bind( + ref, value, *flat_ref_transforms, tree=ref_transforms_treedef + ) + +@async_store_tmem_p.def_effectful_abstract_eval +def _async_store_tmem_abstract_eval(ref, val, *avals_flat, tree): + if ref.memory_space != gpu_core.MemorySpace.TMEM: + raise ValueError("Async store only supports TMEM refs") + _, effects = state_primitives._swap_abstract_eval( + ref, val, *avals_flat, tree=tree + ) + return (), effects + +@lowering.register_lowering_rule(async_store_tmem_p, mgpu.LoweringSemantics.Lane) +def _async_store_tmem_lowering_rule( + ctx: lowering.LoweringRuleContext, x_ref, value, *leaves, tree +): + assert isinstance(x_ref, tcgen05.TMEMRef) + transforms = jax.tree.unflatten(tree, leaves) + x_tmem, transforms = lowering._handle_transforms( + ctx, x_ref, transforms, handle_transposes=False, handle_reshapes=False, + ) + if transforms: + raise NotImplementedError( + f"Unimplemented transforms for TMEM refs. {transforms=}" + ) + x_tmem.store(value) + return () + + +@lowering.register_lowering_rule( + async_store_tmem_p, mgpu.LoweringSemantics.Warpgroup +) +def _async_store_tmem_lowering_rule_wg( + ctx: lowering.LoweringRuleContext, + x_ref: ir.Value, + value: ir.Value, + *leaves, + tree, +): + assert isinstance(x_ref, ir.Value) + assert isinstance(x_ref.type, ir.MemRefType) + assert isinstance(value, ir.Value) + assert isinstance(value.type, ir.VectorType) + + transforms = jax.tree.unflatten(tree, leaves) + x_tmem, transforms = lowering._handle_transforms( + ctx, + x_ref, + transforms, + handle_transposes=False, + handle_reshapes=False, + ) + if transforms: + raise NotImplementedError( + f"Unimplemented transforms for TMEM refs. {transforms=}" + ) + mgpu.dialect.async_store_tmem(value, x_tmem) + return () + + +async_copy_scales_to_tmem_p = jax_core.Primitive("async_copy_scales_to_tmem") +async_copy_scales_to_tmem_p.multiple_results = True + + +def async_copy_scales_to_tmem( + smem_ref: _Ref, tmem_ref: _Ref, collective_axis: AxisName | None = None, +): + """Copies the MMA scales from SMEM to TMEM. + + The copy is performed asynchronously and can be awaited by calling + ``tcgen05_commit_arrive`` and waiting on the specified barrier. However, if + the copy is consumed by an MMA operation issued in the same thread, no + synchronization is necessary (except for eventually awaiting the MMA operation + itself). + """ + smem_ref, smem_transforms = state_primitives.get_ref_and_transforms( + smem_ref, None, "async_copy_scales_to_tmem" + ) + flat_smem_transforms, smem_transforms_treedef = tree_util.tree_flatten( + smem_transforms + ) + tmem_ref, tmem_transforms = state_primitives.get_ref_and_transforms( + tmem_ref, None, "async_copy_scales_to_tmem" + ) + flat_tmem_transforms, tmem_transforms_treedef = tree_util.tree_flatten( + tmem_transforms + ) + async_copy_scales_to_tmem_p.bind( + smem_ref, tmem_ref, *flat_smem_transforms, *flat_tmem_transforms, + smem_tree=smem_transforms_treedef, tmem_tree=tmem_transforms_treedef, + collective_axis=collective_axis, + ) + + +async_copy_sparse_metadata_to_tmem_p = jax_core.Primitive("async_copy_sparse_metadata_to_tmem") +async_copy_sparse_metadata_to_tmem_p.multiple_results = True + + +def async_copy_sparse_metadata_to_tmem( + smem_ref: _Ref, tmem_ref: _Ref, collective_axis: AxisName | None = None +): + """Copies the MMA sparse metadata from SMEM to TMEM. + + The copy is performed asynchronously and can be awaited by calling + ``tcgen05_commit_arrive`` and waiting on the specified barrier. However, if + the copy is consumed by an MMA operation issued in the same thread, no + synchronization is necessary (except for eventually awaiting the MMA operation + itself). + """ + smem_ref, smem_transforms = state_primitives.get_ref_and_transforms( + smem_ref, None, "async_copy_sparse_metadata_to_tmem" + ) + flat_smem_transforms, smem_transforms_treedef = tree_util.tree_flatten( + smem_transforms + ) + tmem_ref, tmem_transforms = state_primitives.get_ref_and_transforms( + tmem_ref, None, "async_copy_sparse_metadata_to_tmem" + ) + flat_tmem_transforms, tmem_transforms_treedef = tree_util.tree_flatten( + tmem_transforms + ) + async_copy_sparse_metadata_to_tmem_p.bind( + smem_ref, tmem_ref, *flat_smem_transforms, *flat_tmem_transforms, + smem_tree=smem_transforms_treedef, tmem_tree=tmem_transforms_treedef, + collective_axis=collective_axis, + ) + + +@async_copy_scales_to_tmem_p.def_effectful_abstract_eval +@async_copy_sparse_metadata_to_tmem_p.def_effectful_abstract_eval +def _async_copy_to_tmem_abstract_eval(smem_ref, tmem_ref, *_args, **_kwargs): + if smem_ref.memory_space != gpu_core.MemorySpace.SMEM: + raise ValueError("async_copy_scales_to_tmem source must be an SMEM ref") + if tmem_ref.memory_space != gpu_core.MemorySpace.TMEM: + raise ValueError("async_copy_scales_to_tmem target must be a TMEM ref") + return (), {gpu_core._memory_effect} + +def _async_copy_to_tmem_lowering_rule( + impl, ctx: lowering.LoweringRuleContext, smem_ref, tmem_ref, *leaves, smem_tree, tmem_tree, collective_axis +): + assert isinstance(tmem_ref, tcgen05.TMEMRef) + smem_leaves, tmem_leaves = util.split_list(leaves, [smem_tree.num_leaves]) + smem_transforms = jax.tree.unflatten(smem_tree, smem_leaves) + tmem_transforms = jax.tree.unflatten(tmem_tree, tmem_leaves) + smem_ref, smem_transforms = lowering._handle_transforms(ctx, smem_ref, smem_transforms) + tmem_ref, tmem_transforms = lowering._handle_transforms(ctx, tmem_ref, tmem_transforms) + if smem_transforms: + raise NotImplementedError(f"Unimplemented transforms for SMEM refs: {smem_transforms}") + if tmem_transforms: + raise NotImplementedError(f"Unimplemented transforms for TMEM refs: {tmem_transforms}") + + predicate = ctx.module_ctx.single_lane_predicate + if collective_axis is not None: + is_leader_block = _collective_mma_predicate(ctx, collective_axis) + predicate = arith_dialect.andi(predicate, is_leader_block) + collective = True + else: + collective = False + + with mgpu.when(predicate): + impl(smem_ref, tmem_ref, collective=collective) + return () + +@lowering.register_lowering_rule( + async_copy_scales_to_tmem_p, mgpu.LoweringSemantics.Lane +) +@lowering.register_lowering_rule( + async_copy_scales_to_tmem_p, + mgpu.LoweringSemantics.Lane, + gpu_core.PrimitiveSemantics.Warp, +) +def _async_copy_scales_to_tmem_lowering_rule(*args, **kwargs): + return _async_copy_to_tmem_lowering_rule( + tcgen05.async_copy_scales_smem_to_tmem, *args, **kwargs + ) + + +@lowering.register_lowering_rule( + async_copy_sparse_metadata_to_tmem_p, mgpu.LoweringSemantics.Lane +) +@lowering.register_lowering_rule( + async_copy_sparse_metadata_to_tmem_p, + mgpu.LoweringSemantics.Lane, + gpu_core.PrimitiveSemantics.Warp, +) +def _async_copy_sparse_metadata_to_tmem_lowering_rule(*args, **kwargs): + return _async_copy_to_tmem_lowering_rule( + tcgen05.async_copy_sparse_metadata_smem_to_tmem, *args, **kwargs + ) + + +semaphore_signal_parallel_p = jax_core.Primitive('semaphore_signal_parallel') +semaphore_signal_parallel_p.multiple_results = True + + +@dataclasses.dataclass(frozen=True) +class SemaphoreSignal: + ref: _Ref + _: dataclasses.KW_ONLY + device_id: pallas_primitives.DeviceId | None + inc: int | jax.Array = 1 + + +def semaphore_signal_parallel(*signals: SemaphoreSignal): + """Signals multiple semaphores without any guaranteed ordering of signal arrivals. + + This primitive is largely equivalent to:: + + for sem in semaphores: + pl.semaphore_signal(sem, inc, device_id=device_id) + + only unlike the loop above, it does not guarantee any ordering of signal + arrivals. In particular, the target device might observe a signal on + ``semaphores[1]`` before it observes a signal on ``semaphores[0]``. + This operation still guarantees that any side effects performed before the + signal will be fully performed and visible before any of the signals arrive. + + The relaxed requirements make the whole operation significantly cheaper on + GPUs, as a single expensive memory fence can be used for all signals (instead + of an expensive fence for each signal). + """ + semaphores = [s.ref for s in signals] + device_ids = [s.device_id for s in signals] + incs = [jnp.asarray(s.inc, dtype=jnp.int32) for s in signals] + refs, transforms = util.unzip2( + map(pallas_primitives._get_ref_and_transforms, semaphores) + ) + args = [refs, transforms, incs, device_ids] + flat_args, args_tree = tree_util.tree_flatten(args) + semaphore_signal_parallel_p.bind( + *flat_args, + args_tree=args_tree, + ) + + +@semaphore_signal_parallel_p.def_effectful_abstract_eval +def _semaphore_signal_parallel_abstract_eval(*avals, args_tree): + ( + sem_avals, + sem_transforms_avals, + value_avals, + device_id_avals, + ) = tree_util.tree_unflatten(args_tree, avals) + for sem_aval, sem_transform_avals in zip(sem_avals, sem_transforms_avals, strict=True): + pallas_primitives.check_sem_avals(sem_aval, sem_transform_avals, "signal") + if any(va.dtype != jnp.dtype("int32") for va in value_avals): + raise ValueError( + "Must signal int32 values, but got" + f" {[aval.dtype for aval in value_avals]}" + ) + effs = set() + for device_id in device_id_avals: + if device_id is not None: + device_id_flat_avals = tree_util.tree_leaves(device_id) + for aval in device_id_flat_avals: + if aval.dtype != jnp.dtype("int32"): + raise ValueError( + f"`device_id`s must be int32 values, but got {aval.dtype}" + ) + effs.add(pallas_core.comms_effect) + return [], effs + + +@lowering.register_lowering_rule(semaphore_signal_parallel_p, mgpu.LoweringSemantics.Lane) +def _semaphore_signal_lowering_rule( + ctx: lowering.LoweringRuleContext, *args, args_tree, +): + i32 = ir.IntegerType.get_signless(32) + sems, transforms, values, device_ids = tree_util.tree_unflatten( + args_tree, args + ) + transformed_sems = [] + for sem, sem_transforms in zip(sems, transforms, strict=True): + sem, sem_transforms = lowering._handle_transforms(ctx, sem, sem_transforms) + if sem_transforms: + raise NotImplementedError(f"Unhandled transforms for semaphore_signal_parallel: {sem_transforms}") + transformed_sems.append(sem) + del sems, transforms # Use transformed_sems instead. + for sem, value, device_id in zip(transformed_sems, values, device_ids, strict=True): + sem_ptr = mgpu.utils.memref_ptr(sem) + if device_id is not None: + device_id, other_axes = pallas_primitives.device_id_to_logical( + ctx.module_ctx.mesh_info, + device_id, + pallas_primitives.DeviceIdType.MESH, + lambda name: lowering._axis_index_rule(ctx, axis_name=name), + ) + if other_axes: + raise NotImplementedError( + f"Only JAX mesh axes can be used in device_id, but found {other_axes}" + ) + device_id = lowering._ensure_ir_value(device_id, jnp.int32) + sem_ptr = ctx.launch_ctx.to_remote(sem_ptr, device_id) + # TODO(apaszke): Narrow the scope from .sys to .gpu when the semaphore is local. + # We only signal the semaphore from a single lane, which does not guarantee + # anything about the state of the other three warps in the warpgroup (they + # might still be e.g. reading memory that someone will overwrite once they + # receive a signal). + if ctx.module_ctx.auto_barriers: + mgpu.utils.warpgroup_barrier() + val = lowering._ir_constant(value, i32) + mgpu_utils.SemaphoreRef(sem_ptr).signal( + val, predicate=ctx.module_ctx.single_wg_lane_predicate, relaxed=True, + ) + mgpu_utils.fence_release_sys() + return () + +try_cluster_cancel_p = jax_core.Primitive('try_cluster_cancel') +try_cluster_cancel_p.multiple_results = True + +@try_cluster_cancel_p.def_effectful_abstract_eval +def _try_cluster_cancel_abstract_eval(*args, **params): + del args, params + + return (), {gpu_core._memory_effect} + +@lowering.register_lowering_rule( + try_cluster_cancel_p, mgpu.LoweringSemantics.Lane +) + +def try_cluster_cancel_lowering( + ctx: lowering.LoweringRuleContext, + result_ref, + barrier: mgpu.BarrierRef, + *transforms_leaves, + result_transforms_tree, + barrier_transforms_tree, +): + i1 = ir.IntegerType.get_signless(1) + i32 = ir.IntegerType.get_signless(32) + + if result_transforms_tree is not None: + res_transforms_leaves, barrier_transforms_leaves = util.split_list( + transforms_leaves, [result_transforms_tree.num_leaves]) + res_transforms = result_transforms_tree.unflatten(res_transforms_leaves) + result_ref, res_transforms = lowering._handle_transforms( + ctx, result_ref, res_transforms) + if res_transforms: + raise NotImplementedError( + f"Unimplemented transforms for result ref: {res_transforms}" + ) + else: + barrier_transforms_leaves = transforms_leaves # type: ignore + + if barrier_transforms_tree is not None: + base_index = _extract_barrier_slice_base( + barrier_transforms_tree.unflatten(barrier_transforms_leaves) + ) + if base_index is not None: + barrier = barrier[base_index] + + result_ty = ir.MemRefType(result_ref.type) + bits = math.prod(result_ty.shape) * mgpu.bitwidth(result_ty.element_type) + if bits != 128: + raise TypeError( + f"Try cluster cancel response must be 128 bits, but is {bits} bits." + ) + + is_first_wg = arith_dialect.cmpi( + arith_dialect.CmpIPredicate.eq, mgpu.warpgroup_idx(), mgpu.c(0, i32) + ) + is_leader_thread = arith_dialect.andi( + ctx.module_ctx.single_lane_predicate, is_first_wg + ) + + bytes = arith_dialect.select(is_leader_thread, mgpu.c(16, i32), mgpu.c(0, i32)) + barrier.arrive_expect_tx(bytes) + + is_first_cta = mgpu.c(1, i1) + for dim in gpu_dialect.Dimension: + is_first_cta = arith_dialect.andi( + is_first_cta, + arith_dialect.cmpi( + arith_dialect.CmpIPredicate.eq, + ctx.launch_ctx.cluster_idx(dim), + mgpu.c(0, ir.IndexType.get()), + ), + ) + + mgpu.try_cluster_cancel( + result_ref, + barrier, + predicate=arith_dialect.andi(is_leader_thread, is_first_cta), + ) + + return [] + + +def try_cluster_cancel(result_ref: _Ref, barrier: _Ref) -> None: + """Initiates an async request to claim a new work unit from the grid. + + It allows an SM to dynamically acquire work by atomically canceling the launch + of a pending cluster from the grid and claiming its CTA ID as the next unit + of work. + + Args: + result_ref: An SMEM ref where the 16-byte result will be stored. + barrier: A barrier used to coordinate the completion of the query. + + See also: + :func:`jax.experimental.pallas.mosaic_gpu.query_cluster_cancel` + """ + if isinstance(result_ref, pallas_core.TransformedRef): + result_transforms_leaves, result_transforms_tree = jax.tree.flatten( + result_ref.transforms + ) + result_ref = result_ref.ref + else: + result_transforms_leaves, result_transforms_tree = [], None + + if isinstance(barrier, pallas_core.TransformedRef): + barrier_transforms_leaves, barrier_transforms_tree = jax.tree.flatten( + barrier.transforms + ) + barrier = barrier.ref + else: + barrier_transforms_leaves, barrier_transforms_tree = [], None + + try_cluster_cancel_p.bind( + result_ref, + barrier, + *result_transforms_leaves, + *barrier_transforms_leaves, + result_transforms_tree=result_transforms_tree, + barrier_transforms_tree=barrier_transforms_tree, + ) + + +query_cluster_cancel_p = jax_core.Primitive("query_cluster_cancel") +query_cluster_cancel_p.multiple_results = True + +@query_cluster_cancel_p.def_effectful_abstract_eval +def _query_cluster_cancel_abstract_eval(try_cancel_buffer, + *transforms_leaves, + grid_names, + transforms_tree): + del try_cancel_buffer, transforms_leaves, transforms_tree + grid_idxs = (jax_core.ShapedArray((), jnp.int32),) * len(grid_names) + return ( + ( + *grid_idxs, + jax_core.ShapedArray((), jnp.bool_), + ), + {gpu_core._memory_effect}, + ) + + +@lowering.register_lowering_rule( + query_cluster_cancel_p, mgpu.LoweringSemantics.Lane +) +def query_cluster_cancel_lowering(ctx: lowering.LoweringRuleContext, + result_ref, + *transforms_leaves, + grid_names, + transforms_tree): + if transforms_tree is not None: + res_transforms = transforms_tree.unflatten(transforms_leaves) + result_ref, res_transforms = lowering._handle_transforms( + ctx, result_ref, res_transforms) + if res_transforms: + raise NotImplementedError( + f"Unimplemented transforms for result ref: {res_transforms}" + ) + + result_ty = ir.MemRefType(result_ref.type) + bits = math.prod(result_ty.shape) * mgpu.bitwidth(result_ty.element_type) + if bits != 128: + raise TypeError(f"Response to decode must be 128 bits, but is {bits} bits.") + + x, y, z, success = mgpu.query_cluster_cancel(result_ref) + cta_grid = [x, y, z] + i32 = ir.IntegerType.get_signless(32) + # Divide out the cluster dimensions. + for axis in ctx.module_ctx.axis_names.cluster: + dim = lowering._resolve_cluster_axis(ctx.module_ctx.axis_names, axis) # type: ignore[arg-type] + cta_grid[dim] = arith_dialect.divui( + cta_grid[dim], + mgpu.c(ctx.launch_ctx.cluster_size[dim], i32)) + # Convert to grid indices. + requested_idxs = [] + for axis_name in grid_names: + requested_idxs.append(lowering.block_id_to_grid_id( + ctx, cta_grid, axis_name)) + return (*requested_idxs, success) + + +def query_cluster_cancel( + result_ref: _Ref, + grid_names: Sequence[Hashable]) -> tuple[tuple[jax.Array, ...], jax.Array]: + """Decodes the result of a ``try_cluster_cancel`` operation. + + It interprets the 16-byte opaque response written to shared memory by a + completed ``try_cluster_cancel`` call to determine if a new work unit was + successfully claimed. + + Args: + result_ref: The SMEM ref containing the query response. + grid_names: A tuple of grid axis names to query for. + + Returns: + A tuple containing the decoded response: + - the grid indices for the requested axis names. + - A boolean indicating if the cancellation was successful. + + See also: + :func:`jax.experimental.pallas.mosaic_gpu.try_cluster_cancel` + """ + if isinstance(result_ref, pallas_core.TransformedRef): + result_transforms_leaves, result_transforms_tree = jax.tree.flatten( + result_ref.transforms + ) + result_ref = result_ref.ref + else: + result_transforms_leaves, result_transforms_tree = [], None + result = query_cluster_cancel_p.bind( + result_ref, + *result_transforms_leaves, + grid_names=grid_names, + transforms_tree=result_transforms_tree) + return tuple(result[:-1]), result[-1] + + +multimem_store_p = jax_core.Primitive("multimem_store") +multimem_store_p.multiple_results = True + + +def multimem_store(source: jax.Array, ref: _Ref, collective_axes: Hashable | tuple[Hashable, ...]): + """Stores the value to ref on all devices present in collective_axes. + + The stores is done using the multimem instructions, meaning that the data is + only transferred to the switch once, and broadcasted to all other devices + there. + + Args: + source: The value to store. + ref: The GMEM reference to store the value to. + collective_axes: The JAX mesh axes indicating the devices to store to. + """ + if isinstance(ref, pallas_core.TransformedRef): + transforms_leaves, transforms_tree = jax.tree.flatten( + ref.transforms + ) + ref = ref.ref + else: + transforms_leaves, transforms_tree = [], None + multimem_store_p.bind( + source, + ref, + *transforms_leaves, + collective_axes=collective_axes, + transforms_tree=transforms_tree, + ) + + +@multimem_store_p.def_effectful_abstract_eval +def _multimem_store_abstract_eval(source, ref, *transforms_leaves, transforms_tree, **_): + _check_ref(ref, "ref", gpu_core.GMEM) + shape, dtype = ref.shape, ref.dtype + if transforms_tree is not None: + transforms = jax.tree.unflatten(transforms_tree, transforms_leaves) + for t in transforms: + shape = t.transform_shape(shape) + dtype = t.transform_dtype(dtype) + if source.dtype != dtype: + raise ValueError(f"Value dtype {source.dtype} does not match ref dtype {dtype}") + if source.shape != shape: + raise ValueError(f"Value shape {source.shape} does not match ref shape {shape}") + return [], {pallas_core.comms_effect, state.WriteEffect(1)} + + +@lowering.register_lowering_rule(multimem_store_p, mgpu.LoweringSemantics.Lane) +def _multimem_store_lowering_rule( + ctx: lowering.LoweringRuleContext, value, local_ref, *transforms_leaves, transforms_tree, collective_axes, +): + if (mesh_info := ctx.module_ctx.mesh_info) is None: + raise ValueError( + "JAX device mesh is required by multimem_store, but not defined." + ) + if set(collective_axes) != set(mesh_info.axis_names): + raise NotImplementedError( + "Only collective_axes that include all JAX device mesh" + f" ({mesh_info.axis_names}) axes are supported, but got" + f" {collective_axes}" + ) + if not isinstance(value, mgpu.FragmentedArray): + raise TypeError(f"Can only store arrays (got {value}).") + if transforms_tree is not None: + transforms = tree_util.tree_unflatten(transforms_tree, transforms_leaves) + local_ref, transforms = lowering._handle_transforms( + ctx, local_ref, transforms, allow_peer_refs=False + ) + if transforms: + raise NotImplementedError( + f"Unhandled transforms for multimem_store: {transforms}" + ) + multi_ref = ctx.launch_ctx.to_remote_multicast(local_ref) + if not ctx.avals_in[0].shape: + multi_ref.store(lowering._ensure_ir_value(value, ctx.avals_out[0].dtype), []) + else: + value.store_untiled(multi_ref, optimized=False) + if ctx.module_ctx.auto_barriers: + mgpu.warpgroup_barrier() # Make sure the writes have completed. + return () + + +multimem_load_reduce_p = jax_core.Primitive("multimem_load_reduce") + +@multimem_load_reduce_p.def_effectful_abstract_eval +def _multimem_load_reduce_abstract_eval(ref, *avals_flat, tree, collective_axes, reduction_op): + del collective_axes, reduction_op + _check_ref(ref, "ref", gpu_core.GMEM) + shape, dtype = ref.shape, ref.dtype + if tree is not None: + transforms = jax.tree.unflatten(tree, avals_flat) + for t in transforms: + shape = t.transform_shape(shape) + dtype = t.transform_dtype(dtype) + return jax_core.ShapedArray(shape, dtype), {pallas_core.comms_effect} + +@lowering.register_lowering_rule(multimem_load_reduce_p, mgpu.LoweringSemantics.Lane) +def _multimem_load_reduce_lowering_rule( + ctx: lowering.LoweringRuleContext, ref, *transforms_leaves, tree, collective_axes, reduction_op, +): + if (mesh_info := ctx.module_ctx.mesh_info) is None: + raise ValueError( + "JAX device mesh is required by multimem_load_reduce, but not defined." + ) + if set(collective_axes) != set(mesh_info.axis_names): + raise NotImplementedError( + "Only collective_axes that include all JAX device mesh" + f" ({mesh_info.axis_names}) axes are supported, but got" + f" {collective_axes}" + ) + if (layout := ctx.out_layout_hint) is None: + raise RuntimeError( + "Failed to infer the output layout of multimem_load_reduce. Please apply" + " plgpu.layout_cast to its output right after its creation." + ) + if not isinstance(layout, (mgpu.TiledLayout, mgpu.WGStridedFragLayout)): + raise ValueError( + "Only tiled and WG strided layouts are supported by" + f" multimem_load_reduce, but got {layout}" + ) + dtype = ctx.avals_out[0].dtype + transforms = tree.unflatten(transforms_leaves) + ref, transforms = lowering._handle_transforms(ctx, ref, transforms, allow_peer_refs=False) + if transforms: + raise NotImplementedError( + f"Unhandled transforms for multimem_load_reduce: {transforms}" + ) + multi_ref = ctx.launch_ctx.to_remote_multicast(ref) + is_signed = mgpu_utils.is_signed(dtype) + arr = mgpu.FragmentedArray.load_reduce_untiled( + multi_ref, layout=layout, is_signed=is_signed, reduction=reduction_op + ) + return arr + +def multimem_load_reduce( + ref: _Ref, + *, + collective_axes: Hashable | tuple[Hashable, ...], + reduction_op: mgpu.MultimemReductionOp, +) -> jax.Array: + """Loads from a GMEM reference on all devices present in collective_axes and reduces the loaded values. + + The supported dtypes are: ``jnp.float32``, ``jnp.float16``, ``jnp.bfloat16``, + ``jnp.float8_e5m2``, ``jnp.float8_e4m3fn``, ``jnp.int32`` and ``jnp.int64``. + + 8-bit floating point dtypes are only supported on Blackwell GPUs. + + Args: + ref: The GMEM reference to load from. + collective_axes: The JAX mesh axes indicating the devices to load from. + reduction_op: The reduction operation to perform on the loaded values. The + allowed values are add (all dtypes), min, max (all dtypes but f32), as + well as and, or and xor (integer types only). + """ + ref, ref_transforms = state_primitives.get_ref_and_transforms( + ref, None, "multimem_load_reduce" + ) + flat_ref_transforms, ref_transforms_treedef = tree_util.tree_flatten( + ref_transforms + ) + return multimem_load_reduce_p.bind( + ref, + *flat_ref_transforms, + tree=ref_transforms_treedef, + collective_axes=collective_axes, + reduction_op=reduction_op, + ) + +semaphore_signal_multicast_p = jax_core.Primitive("semaphore_signal_multicast") +semaphore_signal_multicast_p.multiple_results = True + +def semaphore_signal_multicast( + semaphore, + value: int | jax.Array = 1, + *, + collective_axes: Hashable | tuple[Hashable, ...], +): + """Signals a semaphore on all devices along collective_axes. + + At the moment only signals to all devices are supported. + + Args: + semaphore: The semaphore reference to signal. + value: The increment value for the semaphore. + collective_axes: The mesh axes to multicast the signal across. + Must contain all mesh axes. + """ + if not isinstance(collective_axes, tuple): + collective_axes = (collective_axes,) + ref, transforms = pallas_primitives._get_ref_and_transforms(semaphore) + value = jnp.asarray(value, dtype=jnp.int32) + args = [ref, transforms, value] + flat_args, args_tree = tree_util.tree_flatten(args) + return semaphore_signal_multicast_p.bind( + *flat_args, + args_tree=args_tree, + collective_axes=collective_axes, + ) + + +@semaphore_signal_multicast_p.def_effectful_abstract_eval +def _semaphore_signal_multicast_abstract_eval(*avals, args_tree, collective_axes): + del collective_axes # Unused. + sem, _, _ = tree_util.tree_unflatten(args_tree, avals) + pallas_primitives.check_sem_avals(sem, None, "semaphore_signal_multicast") + return (), {pallas_core.comms_effect} + + +@lowering.register_lowering_rule(semaphore_signal_multicast_p, mgpu.LoweringSemantics.Lane) +def _semaphore_signal_multicast_lowering( + ctx: lowering.LoweringRuleContext, *args, args_tree, collective_axes +): + i32 = ir.IntegerType.get_signless(32) + sem, transforms, value = tree_util.tree_unflatten(args_tree, args) + sem, sem_transforms = lowering._handle_transforms(ctx, sem, transforms) + if sem_transforms: + raise NotImplementedError( + f"Unhandled transforms for semaphore_signal_multicast: {sem_transforms}" + ) + if not isinstance(collective_axes, (tuple, list)): + collective_axes = (collective_axes,) + if (mesh_info := ctx.module_ctx.mesh_info) is None: + raise ValueError("collective_axes requires a mesh context") + if set(collective_axes) != set(mesh_info.axis_names): + raise ValueError( + f"collective_axes {collective_axes} must equal entire mesh axes {mesh_info.axis_names}" + ) + multi_ref = ctx.launch_ctx.to_remote_multicast(sem) + if ctx.module_ctx.auto_barriers: + mgpu_utils.warpgroup_barrier() + val = lowering._ir_constant(value, i32) + mgpu_utils.SemaphoreRef.signal_multimem(mgpu_utils.memref_ptr(multi_ref.ref), val) + return () diff --git a/jax/_src/pallas/mosaic_gpu/torch.py b/jax/_src/pallas/mosaic_gpu/torch.py new file mode 100644 index 000000000000..a132042bc8c2 --- /dev/null +++ b/jax/_src/pallas/mosaic_gpu/torch.py @@ -0,0 +1,294 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""PyTorch interop for Mosaic GPU.""" + +from __future__ import annotations + +import ctypes +from collections import defaultdict +import functools +import itertools +from typing import Callable, TypeGuard, Mapping +import weakref + +import jax +import jax.numpy as jnp +from jax._src import util +from jax._src.lib.mlir import ir +from jax._src.lib.mlir import passmanager +from jax._src.lib.mlir.dialects import func +from jax._src.lib.mlir.dialects import hlo +import jax.experimental.mosaic.gpu as mgpu +from jax.experimental.mosaic.gpu import core as mgpu_core + + +def as_torch_kernel(fn): + """Makes a Mosaic GPU kernel callable with PyTorch tensors. + + Args: + fn: A JAX function that invokes a Mosaic GPU kernel. Note that + the implementation currently only supports functions that contain a + single Mosaic GPU kernel invocation, without any other JAX API calls, + e.g. from :mod:`jax.numpy`. + + Returns: + A wrapper function that accepts PyTorch tensors as inputs and returns + PyTorch tensors as outputs. The output tensors are allocated on the + same device as the input tensors. + + Example:: + + @functools.partial( + pl.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.int32) + ) + def add_kernel(x_ref, y_ref, o_ref): + o_ref[...] = x_ref[...] + y_ref[...] + + x = torch.arange(128, dtype=torch.int32, device="cuda") + y = x * x + out = plgpu.as_torch_kernel(add_kernel)(x, y) + """ + @functools.wraps(fn) + def wrapper(*args): + in_structs = jax.tree.map( + lambda arg: jax.ShapeDtypeStruct( + # Drop the "torch." prefix from the dtype string, if present. + arg.shape, + str(arg.dtype).split(".")[-1], + ), + args, + ) + return _compile_fn(fn, in_structs)(*args) + + return wrapper + + +def _find_mgpu_call_in_module(module: ir.Module): + main_funcs = [ + op + for op in module.body.operations + if isinstance(op, func.FuncOp) and op.name.value == "main" + ] + # TODO(apaszke): Add support for jax.jit, which will call another function + # from main. + if len(main_funcs) != 1: + raise ValueError("Expected a single function in the kernel module") + [func_body] = main_funcs[0].body.blocks + return _find_mgpu_call(func_body, list(func_body.arguments)) + + +def _mlir_to_torch_dtype(torch, mlir_dtype: ir.Type): + if mlir_dtype == ir.F32Type.get(): + return torch.float32 + if mlir_dtype == ir.F16Type.get(): + return torch.float16 + if mlir_dtype == ir.BF16Type.get(): + return torch.bfloat16 + if isinstance(mlir_dtype, ir.IntegerType): + int_type = ir.IntegerType(mlir_dtype) + if int_type.is_signed or int_type.is_signless: + return getattr(torch, f"int{int_type.width}") + else: + return getattr(torch, f"uint{int_type.width}") + raise NotImplementedError(f"Unsupported MLIR type: {mlir_dtype}") + + +def _find_mgpu_call(block: ir.Block, args: list[ir.Value]): + import torch # type: ignore[import-not-found] # pytype: disable=import-error + mgpu_call: hlo.CustomCallOp | None = None + get_outputs = None + to_evaluate: list[Callable] = [] + init_env = {} + name_source = itertools.count() + value_names: Mapping[ir.Value, int] = defaultdict(lambda: next(name_source)) + for op in block.operations: + if _is_custom_call(op, "AllocateBuffer"): + def allocate_torch_buffer( + env, + device, + _shape=op.result.type.shape, + _dtype=_mlir_to_torch_dtype(torch, op.result.type.element_type), + _result_name=value_names[op.result], + ): + env[_result_name] = torch.empty(_shape, dtype=_dtype, device=device) + to_evaluate.append(allocate_torch_buffer) + elif _is_custom_call(op, "mosaic_gpu_v2"): + if mgpu_call is not None: + raise ValueError("Multiple Mosaic GPU kernels found in the module") + mgpu_call = op + elif op.name == "func.return" or op.name == "sdy.return": + if mgpu_call is None: + raise ValueError("No Mosaic GPU call found in the module") + if get_outputs is not None: + raise ValueError("Multiple return ops found in the module") + mgpu_results = list(mgpu_call.results) + try: + out_indices = [mgpu_results.index(o) for o in op.operands] + except ValueError: + raise ValueError("The function can only return kernel results") from None + def get_outputs(*results, _out_indices=out_indices): + return tuple(results[i] for i in _out_indices) + elif op.name == "stablehlo.constant": + result_type = ir.ShapedType(op.result.type) + if result_type.shape: + raise ValueError(f"Only scalar constants are supported, got {op}") + if not op.value.is_splat: + raise ValueError(f"Only splat constants are supported, got {op}") + if result_type.element_type == ir.IntegerType.get_signless(32): + init_env[value_names[op.result]] = ir.IntegerAttr( + op.value.get_splat_value() + ).value + else: + raise NotImplementedError(f"Only i32 constants are supported, got {op}") + elif op.name == "stablehlo.broadcast_in_dim": + if op.broadcast_dimensions: + raise ValueError("Only scalar broadcasts are supported") + target_shape = tuple(op.result.type.shape) + result_name = value_names[op.result] + operand_name = value_names[op.operand] + dtype = torch.int32 + def run_broadcast( + env, + device, + _target_shape=target_shape, + _dtype=dtype, + _operand_name=operand_name, + _result_name=result_name, + ): + env[_result_name] = torch.broadcast_to( + torch.as_tensor(env[_operand_name], dtype=_dtype, device=device), + _target_shape, + ) + + to_evaluate.append(run_broadcast) + else: + raise ValueError(f"Unsupported operation found in the kernel module: {op}") + if mgpu_call is None: + raise ValueError("No Mosaic GPU call found in the module") + if get_outputs is None: + raise ValueError("No return op found in the module") + + block_arg_names = [value_names[arg] for arg in block.arguments] + mgpu_arg_names = [value_names[arg] for arg in mgpu_call.operands] + def prepare_args(*user_args, device): + env = dict(init_env) + for name, arg in zip(block_arg_names, user_args, strict=True): + env[name] = arg + for thunk in to_evaluate: + thunk(env, device) + return tuple(env[name] for name in mgpu_arg_names) + output_input_aliases = [None] * len(mgpu_call.results) + for alias in mgpu_call.output_operand_aliases: + alias = hlo.OutputOperandAlias(alias) + if alias.operand_tuple_indices: + raise NotImplementedError("Tupled operand indices not supported") + if len(alias.output_tuple_indices) > 1: + raise NotImplementedError("Expected one element in output_tuple_indices") + [output_index] = alias.output_tuple_indices or (0,) + output_input_aliases[output_index] = alias.operand_index + + output_types = [ + (result.type.shape, _mlir_to_torch_dtype(torch, result.type.element_type)) + for result in mgpu_call.results + ] + def prepare_outputs(*all_args, device): + outputs = [] + for ty, alias in zip(output_types, output_input_aliases, strict=True): + if alias is not None: + outputs.append(all_args[alias]) + continue + outputs.append(torch.empty(ty[0], dtype=ty[1], device=device)) + return outputs + + return mgpu_call, prepare_args, prepare_outputs, get_outputs + + +def _is_custom_call(op: ir.Operation, name: str) -> TypeGuard[hlo.CustomCallOp]: + return isinstance(op, hlo.CustomCallOp) and op.call_target_name.value == name + + +@util.weakref_lru_cache +def _compile_fn(fn, in_structs): + try: + import torch # type: ignore[import-not-found] # pytype: disable=import-error + except ImportError: + raise RuntimeError("Can't compile for PyTorch: import torch failed") from None + + traced = jax.jit(fn).trace(*in_structs) + main_module = traced.lower().compiler_ir() + with main_module.context: + # jax.jit outlines its bodies which we undo for the interpreter. + mgpu.dialect.register_inliner_extensions(main_module.context) + inliner_pass = passmanager.PassManager.parse( + "builtin.module(inline{default-pipeline=})" + ) + inliner_pass.run(main_module.operation) + mgpu_call, prepare_args, prepare_outputs, get_outputs = _find_mgpu_call_in_module( + main_module + ) + + if not isinstance(in_structs, tuple): + in_structs = (in_structs,) + unwrap_output_tuple = False + if not isinstance(out_structs := traced.out_info, tuple): + out_structs = (out_structs,) + unwrap_output_tuple = True + flat_arg_types, expected_arg_treedef = jax.tree.flatten(in_structs) + _, out_treedef = jax.tree.flatten(out_structs) + + backend_config = mgpu_call.attributes["mhlo.backend_config"] + module_asm = backend_config["module"].value_bytes + launch, unload = mgpu_core._compile_as_torch_gpu_kernel(module_asm) + + def as_torch_dtype(dtype): + # torch contains NumPy-compatible dtypes in its top namespace + return getattr(torch, jnp.dtype(dtype).name) + + def apply(*user_args): + flat_user_args, arg_treedef = jax.tree.flatten(user_args) + if arg_treedef != expected_arg_treedef: + raise ValueError( + f"Invalid argument structure: expected {expected_arg_treedef}, got" + f" {arg_treedef}, ({user_args=})" + ) + for arg, expected_ty in zip(flat_user_args, flat_arg_types): + if arg.shape != expected_ty.shape: + raise ValueError( + f"Argument shape mismatch: expected {expected_ty.shape}, got" + f" {arg.shape}" + ) + if arg.dtype != as_torch_dtype(expected_ty.dtype): + raise ValueError( + "Argument dtype mismatch: expected" + f" {as_torch_dtype(expected_ty.dtype)}, got {arg.dtype}" + ) + + # We run all the ops that are necessary to prepare the arguments + device = torch.device("cuda") + flat_args = prepare_args(*flat_user_args, device=device) + flat_outs = prepare_outputs(*flat_args, device=device) + # Construct a device pointer list like in the XLA calling convention + buffers = (ctypes.c_void_p * (len(flat_args) + len(flat_outs)))() + for i, arg in enumerate(itertools.chain(flat_args, flat_outs)): + buffers[i] = arg.data_ptr() + launch(buffers, device) + user_outs = get_outputs(*flat_outs) + out = jax.tree.unflatten(out_treedef, user_outs) + return out[0] if unwrap_output_tuple else out + + # Unload the compiled code when the Python function is destroyed. + apply.destructor = weakref.ref(apply, lambda _weak_ref: unload) + + return apply diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index d0b74b2e5148..f478c30ab68a 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -15,71 +15,75 @@ """Module for calling pallas functions from JAX.""" from __future__ import annotations -from collections.abc import Callable, Sequence -import dataclasses +from collections.abc import Callable, Mapping, Sequence +import contextlib import enum +import math from functools import partial, reduce import types -from typing import Any, Literal, cast +from typing import Any + +from jax._src import api +import jax._src.lax as lax -import jax -from jax import lax from jax._src import ad_util from jax._src import api_util from jax._src import checkify from jax._src import config from jax._src import core as jax_core from jax._src import effects +from jax._src import hijax from jax._src import linear_util as lu from jax._src import state +from jax._src.traceback_util import api_boundary from jax._src import tree_util +from jax._src import typing as jax_typing +from jax._src.mesh import get_abstract_mesh +from jax._src.frozen_dict import FrozenDict from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.pallas import core as pallas_core -from jax._src.pallas import helpers as pallas_helpers from jax._src.pallas import hlo_interpreter from jax._src.pallas import primitives from jax._src.state import discharge as state_discharge +from jax._src.shard_map import shard_map, P, _as_manual_mesh from jax._src.state import types as state_types from jax._src.util import ( - foreach, safe_map, safe_zip, split_list, tuple_insert, unzip2, - weakref_lru_cache, ) -import jax.numpy as jnp +from jax._src import numpy as jnp + map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip -Grid = pallas_core.Grid -TupleGrid = pallas_core.TupleGrid -GridSpec = pallas_core.GridSpec BlockMapping = pallas_core.BlockMapping GridMapping = pallas_core.GridMapping -BlockSpec = pallas_core.BlockSpec -BlockSpecTree = pallas_core.BlockSpecTree -NoBlockSpec = pallas_core.NoBlockSpec no_block_spec = pallas_core.no_block_spec -ScratchShapeTree = pallas_core.ScratchShapeTree CostEstimate = pallas_core.CostEstimate +Backend = pallas_core.Backend +CompilerParams = pallas_core.CompilerParams # See the docstring for GridMapping for the calling convention -pallas_call_p = jax_core.Primitive('pallas_call') +pallas_call_p = hijax.HiPrimitive('pallas_call') pallas_call_p.multiple_results = True def _pallas_call_impl(*args, **params): # Call the lowering path - @partial(jax.jit, inline=True) + @partial(api.jit, inline=True) + def _jit_run(*args): return pallas_call_p.bind(*args, **params) - return _jit_run(*args) + + with config.disable_jit(False): + return _jit_run(*args) pallas_call_p.def_impl(_pallas_call_impl) @@ -89,28 +93,166 @@ def _pallas_call_abstract_eval( out_avals: tuple[jax_core.AbstractValue, ...], interpret, backend, + input_output_aliases, + grid_mapping, **params ): - del avals - - if isinstance(interpret, mosaic_tpu_interpret.TPUInterpretParams): + if isinstance(interpret, mosaic_tpu_interpret.InterpretParams): # Report effects that will be introduced when running/lowering - # mosaic_tpu_interpret.mosaic_tpu_interpret.interpret_pallas_call . + # mosaic_tpu_interpret.interpret_pallas_call . effs = mosaic_tpu_interpret.get_interpret_effects() + elif isinstance(interpret, mosaic_gpu_interpret.InterpretParams): + # Report effects that will be introduced when running/lowering + # mosaic_gpu_interpret.interpret_pallas_call . + effs = mosaic_gpu_interpret.get_interpret_effects() + elif getattr(params.get('compiler_params', None), 'has_side_effects', False): + effs = jax_core.GenericEffect(pallas_call_p) else: effs = jax_core.no_effects + # closed-over refs and dynamic grid bounds aren't reflected in + # input_output_aliases, though they are present in `avals`, so split them off + num_refs = sum(isinstance(a, state.AbstractRef) for a in avals) + _, _, avals = split_list(avals, [num_refs, grid_mapping.num_dynamic_grid_bounds]) + + inout_aliases = dict(input_output_aliases) + lin_avals = {i for i, a in enumerate(avals) + if isinstance(a, state_types.AbstractLinVal)} + if (missing := lin_avals - set(inout_aliases)): + raise ValueError(f"input pinned buffers without input_output_aliases:" + f"{missing}") + outin_aliases = {out_idx: in_idx for in_idx, out_idx in inout_aliases.items()} # Make sure we don't return ShapedArrayWithMemorySpace to the outside world. - return [ - jax_core.ShapedArray(a.shape, a.dtype, a.weak_type) - if isinstance(a, pallas_core.ShapedArrayWithMemorySpace) - else a - for a in out_avals - ], effs + out_avals = [jax_core.ShapedArray(a.shape, a.dtype, a.weak_type, + sharding=a.sharding) + if isinstance(a, pallas_core.ShapedArrayWithMemorySpace) else + avals[outin_aliases[out_idx]] if out_idx in outin_aliases + else a for out_idx, a in enumerate(out_avals)] + # TODO(mattjj,yashkatariya): if we hide vmapped away mesh axes, use this: + # if not (all(a.sharding.mesh.are_all_axes_manual for a in avals) and + # all(a.sharding.mesh.are_all_axes_manual for a in out_avals) and + # get_abstract_mesh().are_all_axes_manual): + # raise ValueError("pallas_call requires all mesh axes to be Manual, " + # f"got {get_abstract_mesh().axis_types}") + + # NOTE(mattjj,yashkatariya): this doesn't catch auto-mode non-manual axes + if not (all(p is None for a in avals if isinstance(a, jax_core.ShapedArray) + for p in a.sharding.spec) and + all(p is None for a in out_avals if isinstance(a, jax_core.ShapedArray) + for p in a.sharding.spec)): + raise ValueError("pallas_call requires all mesh axes to be Manual, " + f"got {get_abstract_mesh().axis_types}") + return out_avals, effs pallas_call_p.def_effectful_abstract_eval(_pallas_call_abstract_eval) +def _pallas_call_is_high(*_, jaxpr, **params): + del params + return jaxpr.is_high +pallas_call_p.is_high = _pallas_call_is_high # type: ignore + + +def _get_index_mapping(avals) -> dict[int, tuple[int, ...]]: + indices = {} + counter = 0 + for i, in_aval in enumerate(avals): + local_counter = [] + for _ in range(len(in_aval.lo_ty())): + local_counter.append(counter) + counter += 1 + indices[i] = tuple(local_counter) + return indices + + +def _pallas_call_to_lojax( + *hi_args, + jaxpr: jax_core.Jaxpr, + input_output_aliases: tuple[tuple[int, int], ...], + grid_mapping: GridMapping, + mesh: pallas_core.Mesh | None, + debug: bool, + interpret: Any, + compiler_params: Any, + cost_estimate: CostEstimate | None, + out_avals: tuple[jax_core.AbstractValue, ...], + backend: Backend | None, + metadata: FrozenDict[str, str] | None, + name: str | None, +): + if any(jax_core.get_aval(x).has_qdd for x in hi_args): + raise NotImplementedError("pallas_call does not support QDD for inputs") + if any(aval.has_qdd for aval in out_avals): + raise NotImplementedError("pallas_call does not support QDD for outputs") + closed_jaxpr = jax_core.ClosedJaxpr(jaxpr, ()) + with grid_mapping.trace_env(): + closed_lo_jaxpr = pe.lower_jaxpr(closed_jaxpr) + assert not closed_lo_jaxpr.consts + lo_jaxpr = closed_lo_jaxpr.jaxpr + for block_mapping in grid_mapping.block_mappings: + index_map_jaxpr = block_mapping.index_map_jaxpr + if index_map_jaxpr.jaxpr.is_high: + raise NotImplementedError( + "pallas_call does not support hijax for index_map" + ) + avals = [jax_core.get_aval(a) for a in hi_args] + lo_args = [lo_val for aval, x in zip(avals, hi_args) + for lo_val in (aval.read_loval(x) if aval.has_qdd + else aval.lower_val(x))] + lo_out_avals = [ + lo_aval + for aval in out_avals + for lo_aval in (aval.lo_ty() if aval.is_high else [aval]) + ] + lo_grid_mapping = grid_mapping.to_lojax() + in_avals = [v.aval for v in lo_jaxpr.invars] + scalar_prefetch_avals = in_avals[lo_grid_mapping.slice_index_ops] + operand_avals = in_avals[lo_grid_mapping.slice_block_ops] + scratch_avals = in_avals[lo_grid_mapping.slice_scratch_ops] + # Some basic checks + assert len(scalar_prefetch_avals) + len(operand_avals) + len( + scratch_avals + ) == len(in_avals) + assert len(lo_grid_mapping.block_mappings) + len( + lo_grid_mapping.scratch_avals + ) + lo_grid_mapping.num_index_operands == len(lo_jaxpr.invars), ( + len(lo_grid_mapping.block_mappings), + len(lo_grid_mapping.scratch_avals), + lo_grid_mapping.num_index_operands, + len(lo_jaxpr.invars), + ) + + # We need to update the input/output aliases to be in terms of the + # flattened lo inputs/outputs. + # Get mappings from hi input/outputs to the tuple of lo inputs/outputs + input_index_mapping = _get_index_mapping(avals) + output_index_mapping = _get_index_mapping(out_avals) + new_input_output_aliases = [] + # Alias lo inputs to lo outputs + for i, o in input_output_aliases: + assert i in input_index_mapping + assert o in output_index_mapping + for i_lo, o_lo in zip(input_index_mapping[i], output_index_mapping[o]): + new_input_output_aliases.append((i_lo, o_lo)) + + lo_outs = pallas_call_p.bind( + *lo_args, + jaxpr=lo_jaxpr, + grid_mapping=lo_grid_mapping, + mesh=mesh, + cost_estimate=cost_estimate, + backend=backend, + metadata=metadata, + compiler_params=compiler_params, + debug=debug, + interpret=interpret, + input_output_aliases=tuple(new_input_output_aliases), + out_avals=tuple(lo_out_avals), + name=name, + ) + return pe.raise_lo_outs(out_avals, lo_outs) +pallas_call_p.to_lojax = _pallas_call_to_lojax # type: ignore + def _pallas_call_jvp_rule( primals, @@ -121,11 +263,13 @@ def _pallas_call_jvp_rule( grid_mapping: GridMapping, mesh: pallas_core.Mesh | None, debug: bool, - interpret: bool, + interpret: Any, compiler_params: Any, cost_estimate: CostEstimate | None, out_avals: tuple[jax_core.AbstractValue, ...], - backend: _Backend | None, + backend: Backend | None, + metadata: FrozenDict[str, str] | None, + name: str | None, ): debug_info = jaxpr.debug_info if grid_mapping.num_dynamic_grid_bounds: @@ -192,6 +336,8 @@ def _pallas_call_jvp_rule( cost_estimate=jvp_cost_estimate, out_avals=(*out_avals, *out_avals), backend=backend, + metadata=metadata, + name=name, ) out_primals, out_tangents = split_list(out_flat, [len(out_flat) // 2]) return out_primals, out_tangents @@ -203,95 +349,71 @@ def _pallas_call_jvp_rule( def _batch_block_mapping( grid_mapping: GridMapping, axis_size: int, - for_ragged: bool, aval: jax_core.ShapedArray, dim: int | batching.NotMapped, block_mapping: BlockMapping, - ragged_axis_values, ) -> BlockMapping: def _block_map_function(new_idx, *args): - if for_ragged: - drop_last_args = args[:-1] - else: - drop_last_args = args + drop_last_args = args indices = jax_core.eval_jaxpr( block_mapping.index_map_jaxpr.jaxpr, block_mapping.index_map_jaxpr.consts, *drop_last_args, ) + unflat_indices = tree_util.tree_unflatten( + block_mapping.index_map_out_tree, indices) + if not isinstance(unflat_indices, tuple): + unflat_indices = (unflat_indices,) + unflat_indices = list(unflat_indices) if dim is not batching.not_mapped: - if isinstance(dim, batching.RaggedAxis): - assert for_ragged, "Ragged axis not supported for non-ragged batching." - stacked_axis = dim.stacked_axis - indices.insert(stacked_axis, new_idx) - else: - indices.insert(dim, new_idx) - return tuple(indices) + unflat_indices.insert(dim, new_idx) + return tuple(unflat_indices) idx_avals = [pallas_core.index_map_grid_aval, *block_mapping.index_map_jaxpr.in_avals] - if for_ragged: - if isinstance(dim, batching.RaggedAxis): - assert for_ragged, "Ragged axis not supported for non-ragged batching." - _, _, _, lengths_aval = ragged_axis_values - idx_avals = [*idx_avals, lengths_aval] - else: - i32_aval_memref = pallas_core.AbstractMemoryRef( - jax_core.ShapedArray(([axis_size]), jnp.int32), - pallas_core.MemorySpace.INDEX, - ) - idx_avals = [*idx_avals, i32_aval_memref] - + block_mapping_flat_fn, out_tree_thunk = api_util.flatten_fun_nokwargs( + lu.wrap_init(_block_map_function, + debug_info=block_mapping.index_map_jaxpr.jaxpr.debug_info.with_unknown_names()), + tree_util.tree_structure(idx_avals)) with grid_mapping.trace_env(): - block_mapping_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(_block_map_function, - debug_info=block_mapping.index_map_jaxpr.jaxpr.debug_info), + block_mapping_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic( + block_mapping_flat_fn, idx_avals) + new_index_map_out_tree = out_tree_thunk() shape = block_mapping.block_shape if dim is batching.not_mapped: new_block_shape = shape - new_array_shape_dtype = block_mapping.array_shape_dtype + new_array_aval = block_mapping.array_aval else: - if isinstance(dim, batching.RaggedAxis): - assert for_ragged, "Ragged axis not supported for non-ragged batching." - new_block_shape = shape - stacked_axis = dim.stacked_axis - new_block_shape = tuple_insert( - new_block_shape, stacked_axis, pallas_core.mapped - ) - else: - new_block_shape = tuple_insert(shape, dim, pallas_core.mapped) + new_block_shape = tuple_insert(shape, dim, pallas_core.squeezed) - array_shape = block_mapping.array_shape_dtype.shape - if isinstance(dim, batching.RaggedAxis): - assert for_ragged, "Ragged axis not supported for non-ragged batching." - stacked_axis = dim.stacked_axis - array_shape = tuple_insert(array_shape, stacked_axis, axis_size) - else: - array_shape = tuple_insert(array_shape, dim, axis_size) + array_shape = block_mapping.array_aval.shape + array_shape = tuple_insert(array_shape, dim, axis_size) - new_array_shape_dtype = jax.ShapeDtypeStruct( - array_shape, block_mapping.array_shape_dtype.dtype + new_array_aval = jax_core.ShapedArray( + array_shape, block_mapping.array_aval.dtype ) jaxpr = jax_core.ClosedJaxpr(block_mapping_jaxpr, consts) return block_mapping.replace(block_shape=new_block_shape, - array_shape_dtype=new_array_shape_dtype, - index_map_jaxpr=jaxpr) + array_aval=new_array_aval, + index_map_jaxpr=jaxpr, + index_map_out_tree=new_index_map_out_tree) def _broadcast_input_output_aliases( - args: Sequence[jax.Array], + args: Sequence[jax_typing.Array], + dims: Sequence[int | batching.NotMapped], *, input_output_aliases: tuple[tuple[int, int], ...], axis_size: int, -) -> tuple[tuple[jax.Array, ...], tuple[int | batching.NotMapped, ...]]: +) -> tuple[tuple[jax_typing.Array, ...], tuple[int | batching.NotMapped, ...]]: """Broadcast input/output operands. When we have input/output aliasing, since the output will be mapped, we need to make sure to broadcast the input across that dimension if it is not - mapped. If the input is mapped, but on a different axis, we tranpose the input + mapped. If the input is mapped, but on a different axis, we transpose the input to match the output. """ @@ -300,14 +422,9 @@ def _broadcast_input_output_aliases( for input_index, _ in input_output_aliases: dim = dims_[input_index] dims_[input_index] = 0 - if isinstance(dim, batching.RaggedAxis): - stacked_axis = dim.stacked_axis - if stacked_axis != 0: - raise NotImplementedError("Ragged aliasing on non 0 dim NYI") - return tuple(args_), tuple(dims_) - if dim is batching.not_mapped: - args_[input_index] = batching.broadcast(args_[input_index], axis_size, 0) + args_[input_index] = batching.broadcast( + args_[input_index], axis_size, 0, None) elif dim != 0: # TODO(cjfj): Change output batching axis instead? args_[input_index] = jnp.moveaxis(args[input_index], dim, 0) @@ -316,7 +433,7 @@ def _broadcast_input_output_aliases( def _batch_with_explicit_loop( - args: Sequence[jax.Array], + args: Sequence[jax_typing.Array], dims: Sequence[int | batching.NotMapped], *, jaxpr: jax_core.Jaxpr, @@ -324,11 +441,13 @@ def _batch_with_explicit_loop( mesh: pallas_core.Mesh | None, input_output_aliases: tuple[tuple[int, int], ...], debug: bool, - interpret: bool, + interpret: Any, compiler_params: Any, cost_estimate: CostEstimate | None, out_avals: tuple[jax_core.AbstractValue, ...], - backend: _Backend | None, + backend: Backend | None, + metadata: FrozenDict[str, str] | None, + name: str | None, ): """Batch the pallas_call by calling it in loop over the batch size. @@ -357,15 +476,16 @@ def _batch_with_explicit_loop( axis_size=axis_size, ) - # The output arrays are completelly overwritten, so we can just initialize + # The output arrays are completely overwritten, so we can just initialize # empty arrays. initial_state = [ - jnp.empty(tuple_insert(bm.array_shape_dtype.shape, 0, axis_size), - dtype=bm.array_shape_dtype.dtype) + jnp.empty(tuple_insert(bm.array_aval.shape, 0, axis_size), + dtype=bm.array_aval.dtype) for bm in grid_mapping.block_mappings_output ] - def body(batch_index: jax.Array, state: list[jax.Array]) -> list[jax.Array]: + def body(batch_index: jax_typing.Array, state: list[jax_typing.Array]) -> list[jax_typing.Array]: + batch_args = [] for arg, dim in zip(args, dims): @@ -376,7 +496,7 @@ def body(batch_index: jax.Array, state: list[jax.Array]) -> list[jax.Array]: else: batch_args.append( jnp.squeeze( - jax.lax.dynamic_slice_in_dim( + lax.dynamic_slice_in_dim( operand=arg, start_index=batch_index, slice_size=1, @@ -397,9 +517,11 @@ def body(batch_index: jax.Array, state: list[jax.Array]) -> list[jax.Array]: cost_estimate=cost_estimate, out_avals=out_avals, backend=backend, + metadata=metadata, + name=name, ) for i, batch_out_array in enumerate(batch_out): - state[i] = jax.lax.dynamic_update_index_in_dim( + state[i] = lax.dynamic_update_index_in_dim( state[i], batch_out_array, batch_index, @@ -408,12 +530,13 @@ def body(batch_index: jax.Array, state: list[jax.Array]) -> list[jax.Array]: return state - result = jax.lax.fori_loop(0, axis_size, body, initial_state, unroll=False) + result = lax.fori_loop(0, axis_size, body, initial_state, unroll=False) return result, (0,) * len(result) def _pallas_call_batching_rule( + axis_data, args, dims, *, @@ -422,51 +545,52 @@ def _pallas_call_batching_rule( mesh: pallas_core.Mesh | None, input_output_aliases: tuple[tuple[int, int], ...], debug: bool, - interpret: bool, + interpret: Any, compiler_params: Any, cost_estimate: CostEstimate | None, out_avals: tuple[jax_core.AbstractValue, ...], - backend: _Backend | None, + backend: Backend | None, + metadata: FrozenDict[str, str] | None = None, + name: str | None = None, ): if mesh is not None: raise NotImplementedError( "pallas_call with a mesh does not support batching" ) - def _maybe_squeeze_out_bdim( - x: jax.Array, bdim: int | batching.NotMapped - ) -> jax.Array: - if bdim is batching.not_mapped: - return x - return jnp.squeeze(x, axis=bdim) + def _maybe_squeeze_out_bdim(x: jax_typing.Array, bdim: int | batching.NotMapped + ) -> jax_typing.Array: + return x if bdim is batching.not_mapped else jnp.squeeze(x, axis=bdim) - def get_size(i, x, d): - if not isinstance(d, batching.RaggedAxis): - return x.shape[d] - return x.aval.shape[d.stacked_axis] + # this is the _global_ axis size if axis_data.explicit_mesh_axis is not None + # we want to convert it to the local axis size + axis_size = axis_data.size + ema = axis_data.explicit_mesh_axis + abs_mesh = get_abstract_mesh() + if ema: + mesh_size = math.prod(abs_mesh.shape[i] for i in ema) + axis_size, ragged = divmod(axis_size, mesh_size) + assert not ragged - (axis_size,) = { - get_size(i=i, x=x, d=d) - for i, (x, d) in enumerate(zip(args, dims)) - if d is not batching.not_mapped - } if axis_size == 1: # Why are we even vmapping? - args = map(_maybe_squeeze_out_bdim, args, dims) - out = pallas_call_p.bind( - *args, - jaxpr=jaxpr, - grid_mapping=grid_mapping, - mesh=mesh, - input_output_aliases=input_output_aliases, - debug=debug, - interpret=interpret, - compiler_params=compiler_params, - cost_estimate=cost_estimate, - out_avals=out_avals, - backend=backend, - ) - return [jnp.expand_dims(x, 0) for x in out], (0,) * len(out) + manual_out_avals = [ + o.update(sharding=o.sharding.update(mesh=_as_manual_mesh(o.sharding.mesh, ema))) + for o in out_avals] if ema else out_avals + def temp_f(*args): + args = map(_maybe_squeeze_out_bdim, args, dims) + out = pallas_call_p.bind( + *args, jaxpr=jaxpr, grid_mapping=grid_mapping, mesh=mesh, + input_output_aliases=input_output_aliases, debug=debug, + interpret=interpret, compiler_params=compiler_params, + cost_estimate=cost_estimate, out_avals=tuple(manual_out_avals), + backend=backend, metadata=metadata, name=name) + return [jnp.expand_dims(x, 0) for x in out] + if ema: + temp_f = remove_explicit(ema)(shard_map( + temp_f, out_specs=P(ema), axis_names=set(ema))) + out = temp_f(*args) + return out, (0,) * len(out) # The first num_dynamic_grid_bounds arguments are size-1 arrays that store # the size of the dynamic bounds. @@ -486,6 +610,8 @@ def get_size(i, x, d): elif any(bdim is not batching.not_mapped for bdim in dynamic_grid_dims): # TODO(amagni, sharadmv): Explore possibility of batching dynamic grid # bounds. + if ema: + raise NotImplementedError() return _batch_with_explicit_loop( args=dynamic_grid_args + args, dims=dynamic_grid_dims + dims, @@ -499,6 +625,8 @@ def get_size(i, x, d): cost_estimate=cost_estimate, out_avals=out_avals, backend=backend, + metadata=metadata, + name=name, ) else: pass # No dynamic grid dimensions @@ -521,6 +649,8 @@ def get_size(i, x, d): else: # TODO(amagni,sharadmv,apaszke): enable efficient batching over # prefetched scalar args. + if ema: + raise NotImplementedError() return _batch_with_explicit_loop( args=scalar_args + args, dims=scalar_bdims + bdims, @@ -534,6 +664,8 @@ def get_size(i, x, d): cost_estimate=cost_estimate, out_avals=out_avals, backend=backend, + metadata=metadata, + name=name, ) if not dims: @@ -551,30 +683,7 @@ def get_size(i, x, d): args, dims, input_output_aliases=input_output_aliases, axis_size=axis_size ) - # Each dim either has data about its ragged axis, or None - ragged_axis_values = [] - for d in dims: - if isinstance(d, batching.RaggedAxis): - stacked_axis, ragged_axis_dim, ragged_axis_length = ( - batching._ragged_axis_parts(d) - ) - aval = jax_core.get_aval(ragged_axis_length).update(dtype=jnp.int32) - if isinstance(aval, jax_core.DShapedArray): - aval = jax_core.ShapedArray(aval.shape, aval.dtype, aval.weak_type) - lengths_aval = pallas_core.AbstractMemoryRef( - aval, - pallas_core.MemorySpace.INDEX, - ) - # TODO(mvoz): Give this its own type - ragged_axis_values.append( - (stacked_axis, ragged_axis_dim, ragged_axis_length, lengths_aval) - ) - else: - ragged_axis_values.append(None) # type: ignore[arg-type] - all_dims = list(dims) + [0] * grid_mapping.num_outputs - ragged_axis_values = ragged_axis_values + [None] * grid_mapping.num_outputs - num_index_operands = grid_mapping.num_index_operands num_scratch_operands = grid_mapping.num_scratch_operands @@ -588,34 +697,16 @@ def get_size(i, x, d): _batch_block_mapping, grid_mapping, axis_size, - any(ragged_axis_values), ), avals_to_batch, all_dims[num_index_operands:], block_mappings, - ragged_axis_values[num_index_operands:], ) index_map_tree_args, index_map_tree_kwargs = grid_mapping.index_map_tree.unflatten( grid_mapping.index_map_avals) assert not index_map_tree_kwargs batched_index_map_args = (pallas_core.index_map_grid_aval,) + index_map_tree_args - - lengths_aval = None # type: ignore[assignment] - - # Check all the ragged axis values, ensure their raggedness pattern - # is identical (consider moving this check up!) - for rav in ragged_axis_values: - if rav is not None: - if lengths_aval is None: - lengths_aval = rav[3] - else: - assert lengths_aval == rav[3], "NYI - different lengths in ragged batch" - - if lengths_aval: - batched_index_map_args = batched_index_map_args + (lengths_aval,) - num_index_operands += 1 - batched_index_map_avals, batched_index_map_tree = tree_util.tree_flatten( (batched_index_map_args, {})) @@ -628,7 +719,10 @@ def get_size(i, x, d): vmapped_dims=(0,) + tuple(a + 1 for a in grid_mapping.vmapped_dims), ) - if cost_estimate is not None: + # Avoid scaling the cost estimate by the batch size if the batch size is a + # dynamic shape (DimExpr). + # https://docs.jax.dev/en/latest/export/shape_poly.html#computing-with-dimension-variables + if cost_estimate is not None and isinstance(axis_size, int): batched_cost_estimate = CostEstimate( flops=cost_estimate.flops * axis_size, bytes_accessed=cost_estimate.bytes_accessed * axis_size, @@ -637,287 +731,47 @@ def get_size(i, x, d): else: batched_cost_estimate = None - # Start the ragged handling code - # Here, we: - # - Rewrite the indexer to save memory (skip indices outside the ragged bounds) - # - Rewrite the kernel to save compute (skip elements outside the ragged bounds) - # - Update various internal structures/metadata to account for the new - # block spec. - # - Set the hacky flag of ragged_originating on the mapping, to signal to - # the lowering code to treat mapped dimensions as part of the user grid. - if lengths_aval: - batched_grid_mapping = batched_grid_mapping.replace( - get_grid_indices=lambda indices, maybe_include_mapped_dims: indices, - local_grid_env=lambda loop_idx, grid: tuple( - pallas_core.GridAxis(idx, b) for (idx, b) in zip(loop_idx, grid) - ), - ) - - # Note - on zero filling counterfactuals - # A debug util to produce a counterfactual version of the when - # gating, where for all values that don't pass the @when check, - # we write 0s. This is useful for debugging, as certain lowering paths - # like mosaic will write the last data as passthrough, leading to - # potentially confusing results. - block_mapped_dim_idxs = [] - for block_mapping in batched_grid_mapping.block_mappings: - mapped_dim_idxs = [] - for i, d in enumerate(block_mapping.block_shape): - if d is pallas_core.mapped: - mapped_dim_idxs.append(i) - else: - mapped_dim_idxs.append(None) # type: ignore[arg-type] - block_mapped_dim_idxs.append(mapped_dim_idxs) - - mapped_dim_idx = None - for rav, mapped_dim_idxs in zip(ragged_axis_values, block_mapped_dim_idxs): - if rav is not None: - stacked_axis = rav[0] - if mapped_dim_idx is None: - mapped_dim_idx = mapped_dim_idxs[stacked_axis] - if mapped_dim_idxs[stacked_axis] is None: - raise ValueError( - f"Expected mapped dim to be {stacked_axis}, but got" - f" {mapped_dim_idxs[stacked_axis]}" - ) - else: - assert mapped_dim_idx == mapped_dim_idxs[stacked_axis], ( - f"Different mapped dims - expected {mapped_dim_idx}, but got" - f" {mapped_dim_idxs[stacked_axis]}" - ) - - # This is the blockspec size of the dimension - block_shapes = [b.block_shape for b in batched_grid_mapping.block_mappings] - - # Parse out the operations from the jaxpr to determine how to mask the output - # NOTE! while this *could* be a default dict of None, and None is sound, as - # it denotes that there is no raggedness for the given var, we explicitly - # do not do this, so as to get better signal on implementation of rules - # A misimplemented rule that does not account for new vars being introduced - # will result in an error on the next op using the new var. The benefit of - # of forcing implementers to account for all outputs and intermediaries is - # a very nice one. - - var_to_raggedness = {} - for invar, rav in zip(jaxpr.invars, ragged_axis_values): - var_to_raggedness[invar] = rav - - for eqn in jaxpr.eqns: - prim = eqn.primitive - if prim not in batching.ragged_prop_rules: - raise NotImplementedError(f"Not implemented - ragged prop for {prim}") - rule = batching.ragged_prop_rules[prim] - - invar_raggedness = [ - ( - var_to_raggedness.get(invar, None) - if isinstance(invar, jax_core.Var) - else None - ) - for invar in eqn.invars - ] - try: - invar_raggedness, outvar_raggedness = rule( - eqn.params, invar_raggedness, eqn.outvars # type: ignore[arg-type] - ) - except Exception as e: - raise RuntimeError( - f"Failed to run rule for {prim}. invars: {eqn.invars}, outvars:" - f" {eqn.outvars}. Underlying reason: {e}" - ) from e - - for invar, rav in zip(eqn.invars, invar_raggedness): # type: ignore[assignment] - if isinstance(invar, jax_core.Var): - var_to_raggedness[invar] = rav - for outvar, rav in zip(eqn.outvars, outvar_raggedness): - if isinstance(outvar, jax_core.Var): - var_to_raggedness[outvar] = rav - - for pos, invar in enumerate(jaxpr.invars): - ragged_axis_values[pos] = var_to_raggedness[invar] - - per_input_ragged_axis_dim: list[int | None] = [] - for rav in ragged_axis_values: - if rav is not None: - per_input_ragged_axis_dim.append(rav[1]) - else: - per_input_ragged_axis_dim.append(None) - - def when_wrapped_kernel(lengths_ref, *args, **kwargs): - b_idx = primitives.program_id(mapped_dim_idx) - - b_len = lengths_ref[b_idx] - run_kernel = jnp.array(True) - for i, _ in enumerate(args): - ragged_axis_dim = per_input_ragged_axis_dim[i] - if ragged_axis_dim is None: - continue - arg_i_idx = ( - primitives.program_id(ragged_axis_dim) - * block_shapes[i][ragged_axis_dim] - ) - run_kernel = jnp.logical_and(run_kernel, arg_i_idx < b_len) - - # TODO(mvoz): Unimplemented primitive in pallas - # b_len_mod = jnp.equal(jnp.mod(b_len, val_at_ragged_dim), 0) - # checkify.check(b_len_mod, "b_len % val_at_ragged_dim != 0") - - @pallas_helpers.when(run_kernel) - def f(): - # Important! This allows us to trace the inner kernel with the correct - # grid to preserve user program_id semantics. Ex: program_id(0) will - # always be analogous to program_id(1) in the outer kernel. - with pallas_core.tracing_grid_env(grid_mapping.grid, ()): - jax_core.eval_jaxpr(jaxpr, (), *args, **kwargs) - - kernel_avals = [lengths_aval] + [v.aval for v in jaxpr.invars] - flat_kernel_avals, kernel_in_tree = tree_util.tree_flatten( - list(kernel_avals) - ) - - def _rewrite_index_jaxpr(enumerate_batched_block_mapping): - arg_pos, batched_block_mapping = enumerate_batched_block_mapping - indexer_avals = [ - v.aval for v in batched_block_mapping.index_map_jaxpr.jaxpr.invars - ] - flat_indexer_avals, indexer_in_tree = tree_util.tree_flatten( - list(indexer_avals) - ) - - def index_rewrite_kernel(*indexer_args): - ragged_axis_dim = per_input_ragged_axis_dim[arg_pos] - - # the problem here seems to be that we are rnning this for all inputs, per input, because they each have an indexer - which means - # that the indexer for output isnt getting written - before, it always was - - lengths_ref = indexer_args[-1] - rest_indexer_args = indexer_args[:-1] - # Lengths are always the last argument of the indexer. - # lengths_ref = args[-1] - # Invariant: Stacked axis is enforced to be the mapped axis above. - b_idx = indexer_args[mapped_dim_idx] - - nargs = list(rest_indexer_args) - - if ragged_axis_dim is not None: - val_at_ragged_dim = batched_block_mapping.block_shape[ragged_axis_dim] - - # The current index into the ragged dimension. - # Invariant: There is only one ragged dimension, enforced above. - i_idx = indexer_args[ragged_axis_dim] - - # grid space -> element space - i_len = i_idx * val_at_ragged_dim - - # The length of the current batch. - b_len = lengths_ref[b_idx] - - # Have we reached the end of the current batch? - not_done = i_len < b_len - - am_last_batch = b_idx == axis_size - 1 - last_good_block = lax.div(b_len, val_at_ragged_dim) - 1 - - # The logic below can be thought of as: - # if index_oob_ragged: - # if not last_batch: - # batch_idx += 1 - # ragged_idx = 0 - # else: - # ragged_idx = last_good_block - # - # wherein we find the next good block by incrementing the batch index - # and setting the ragged index to 0 if we are not in the last batch. - # Otherwise, we set the ragged index to the last good block. - b_next = jnp.where( - not_done, b_idx, jnp.where(am_last_batch, b_idx, b_idx + 1) - ) - i_next = jnp.where( - not_done, i_idx, jnp.where(am_last_batch, last_good_block, 0) - ) - nargs[ragged_axis_dim] = i_next - nargs[mapped_dim_idx] = b_next - - nargs = nargs + [lengths_ref] - return jax_core.eval_jaxpr( - batched_block_mapping.index_map_jaxpr.jaxpr, - batched_block_mapping.index_map_jaxpr.consts, - *nargs, - ) - index_jaxpr, _ = _trace_kernel_to_jaxpr( - index_rewrite_kernel, - batched_block_mapping.index_map_jaxpr.jaxpr.debug_info, - batched_grid_mapping, - tuple(flat_indexer_avals), - indexer_in_tree, - tuple(() for _ in flat_indexer_avals), - indexer=True, - ) - - batched_block_mapping = batched_block_mapping.replace( - index_map_jaxpr=pe.close_jaxpr(index_jaxpr) - ) - return batched_block_mapping - - # Important! This allows us to trace the outer kernel with the correct grid - # to enable accessing the batch program_id. - with pallas_core.tracing_grid_env(batched_grid_mapping.grid, ()): - batched_block_mappings = map( - _rewrite_index_jaxpr, enumerate(batched_block_mappings) - ) - - batched_grid_mapping = batched_grid_mapping.replace( - block_mappings=tuple(batched_block_mappings), - ) - - jaxpr, consts = _trace_kernel_to_jaxpr( - when_wrapped_kernel, - jaxpr.debug_info, - batched_grid_mapping, - tuple(flat_kernel_avals), - kernel_in_tree, - tuple(() for _ in flat_kernel_avals), - ) - if consts: - raise NotImplementedError("consts not supported in pallas_call") - - # We need to rewrite the input_output_aliases here, the initial call - # to broadcast is done, and we have inseted a new input (lengths), so - # there's an off-by-one here now. - new_input_output_aliases = [] - for k, v in input_output_aliases: - new_input_output_aliases.append((k + 1, v)) - input_output_aliases = tuple(new_input_output_aliases) - - # assert ragged_axis_length is not None - args = (ragged_axis_length, *args) assert all(isinstance(aval, jax_core.ShapedArray) for aval in out_avals) batched_out_avals = [] for aval in out_avals: - sharding = aval.sharding.with_spec(tuple_insert(aval.sharding.spec, 0, None)) + manual_mesh = (_as_manual_mesh(aval.sharding.mesh, ema) if ema else + aval.sharding.mesh) + sharding = aval.sharding.update( + mesh=manual_mesh, spec=tuple_insert(aval.sharding.spec, 0, None)) shape = tuple_insert(aval.shape, 0, axis_size) batched_out_avals.append(aval.update(shape=shape, sharding=sharding)) batched_out_avals = tuple(batched_out_avals) - out = pallas_call_p.bind( - *dynamic_grid_args, - *args, - jaxpr=jaxpr, - grid_mapping=batched_grid_mapping, - mesh=mesh, - input_output_aliases=input_output_aliases, - debug=debug, - interpret=interpret, - compiler_params=compiler_params, - cost_estimate=batched_cost_estimate, - out_avals=batched_out_avals, - backend=backend, - ) + bind = partial( + pallas_call_p.bind, jaxpr=jaxpr, grid_mapping=batched_grid_mapping, + mesh=mesh, input_output_aliases=input_output_aliases, debug=debug, + interpret=interpret, compiler_params=compiler_params, + cost_estimate=batched_cost_estimate, out_avals=batched_out_avals, + backend=backend, metadata=metadata, name=name) + + if ema: + # TODO all batching rules should probably be in outer mesh ctx + bind = remove_explicit(ema)(shard_map( + bind, out_specs=P(ema), axis_names=set(ema))) + + out = bind(*dynamic_grid_args, *args) return out, (0,) * len(out) +batching.fancy_primitive_batchers[pallas_call_p] = _pallas_call_batching_rule +batching.skippable_batchers[pallas_call_p] = lambda _: () + -batching.primitive_batchers[pallas_call_p] = _pallas_call_batching_rule +@contextlib.contextmanager +def remove_explicit(ema): + prev = jax_core.trace_ctx.axis_env + # assert set(prev.explicit_mesh_axis_names) == set(ema) + new = jax_core.AxisEnv(prev.axis_sizes, prev.spmd_axis_names, set()) + try: + jax_core.trace_ctx.set_axis_env(new) + yield + finally: + jax_core.trace_ctx.set_axis_env(prev) def checkify_pallas_kernel_body_jaxpr( @@ -965,22 +819,21 @@ def pallas_call_checkify_oob_grid(error: checkify.Error, num_iterations = 1 is_indexing_dim = [ - tuple(b is pallas_core.mapped for b in bm.block_shape) + tuple(isinstance(b, pallas_core.Squeezed) for b in bm.block_shape) for bm in grid_mapping.block_mappings ] block_shapes = [ - None if iid is None - else tuple(1 if i else b for i, b in zip(iid, bm.block_shape)) - for iid, bm in zip(is_indexing_dim, grid_mapping.block_mappings) + pallas_core._get_block_shape(bm.block_shape) + for bm in grid_mapping.block_mappings ] # The scan carry: (i, loop_idx, *consts, *ins, *outs, *scratch) - # i:int32 is the interation index + # i:int32 is the iteration index # loop_idx: tuple[int32] are the program ids for each grid axis def cond(carry): i, *_ = carry return i < num_iterations def body(carry): - i, loop_idx = carry + i, loop_idx, blocks = carry if grid_mapping.local_grid_env is not None: local_grid_env = grid_mapping.local_grid_env(loop_idx, grid) else: @@ -995,14 +848,14 @@ def body(carry): for bm in grid_mapping.block_mappings] # We perform a dynamic slice on the i/o blocks, which will be checked by # checkify for OOB accesses. - foreach(hlo_interpreter._dynamic_slice, start_indices, block_shapes, + blocks = map(hlo_interpreter._dynamic_slice, start_indices, block_shapes, [*input_args, *output_args], is_indexing_dim) - return (i + 1, hlo_interpreter._get_next_indices(grid, loop_idx)) + return (i + 1, hlo_interpreter._get_next_indices(grid, loop_idx), blocks) def f(_): - lax.while_loop( - cond, body, (jnp.int32(0), grid_start_indices) + return lax.while_loop( + cond, body, (jnp.int32(0), grid_start_indices, [jnp.zeros(shape) for shape in block_shapes]) ) - flat_args, jaxpr_in_tree = jax.tree_util.tree_flatten((jnp.int32(0),)) + flat_args, jaxpr_in_tree = tree_util.tree_flatten((jnp.int32(0),)) wrapped_loop, _ = api_util.flatten_fun_nokwargs( lu.wrap_init(f, debug_info=api_util.debug_info("checkify oob_grid_access", @@ -1010,7 +863,7 @@ def f(_): jaxpr_in_tree) with pallas_core.tracing_grid_env(grid_mapping.grid, ()): avals_in = map(jax_core.get_aval, flat_args) - traced_loop, _, consts, () = pe.trace_to_jaxpr_dynamic( + traced_loop, _, consts = pe.trace_to_jaxpr_dynamic( wrapped_loop, list(avals_in)) traced_loop = jax_core.ClosedJaxpr(traced_loop, consts) out_error, _ = checkify.checkify_jaxpr( @@ -1021,7 +874,7 @@ def pallas_call_checkify_rule(error: checkify.Error, enabled_errors, *args: jax_core.Value, jaxpr: jax_core.Jaxpr, - interpret: bool, + interpret: Any, input_output_aliases: tuple[tuple[int, int], ...], grid_mapping: GridMapping, out_avals: tuple[jax_core.AbstractValue, ...], @@ -1052,7 +905,7 @@ def pallas_call_checkify_rule(error: checkify.Error, _jaxpr, _, error_effects = checkify_pallas_kernel_body_jaxpr( closed_jaxpr, enabled_errors, error, grid_mapping) error = error._add_placeholder_effects(error_effects) - err_vals, err_in_tree = jax.tree.flatten(error) + err_vals, err_in_tree = tree_util.tree_flatten(error) shaped_err_avals = map(jax_core.get_aval, err_vals) # Trace the kernel jaxpr to get a checkified jaxpr. This jaxpr will have @@ -1104,42 +957,47 @@ def _ensure_2d_error_shape(arg): dtype = arg.dtype return jax_core.ShapedArray((1, 1) + arg.shape, dtype=dtype, weak_type=arg.weak_type) - elif isinstance(arg, jax.Array): + elif isinstance(arg, jax_typing.Array): return jnp.reshape(arg, (1, 1) + arg.shape) else: return jnp.array([[arg]]) shaped_err_avals = map(_ensure_2d_error_shape, shaped_err_avals) err_vals = map(_ensure_2d_error_shape, err_vals) - error_memref_aval = [pallas_core.AbstractMemoryRef( + error_memref_aval = [state.AbstractRef( err_val, pallas_core.MemorySpace.ERROR) for err_val in shaped_err_avals] shaped_scalar_avals, input_aval, output_aval, scratch_aval = split_list( shaped_input_avals, [num_scalars, num_kernel_inputs, num_kernel_outputs]) retrace_in_avals = [*shaped_scalar_avals, *error_memref_aval, *input_aval, *error_memref_aval, *output_aval, *scratch_aval] jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(retrace_in_avals) - debug = api_util.debug_info("checkify_pallas", checked_kernel_fn, + debug_info = api_util.debug_info("checkify_pallas", checked_kernel_fn, retrace_in_avals, {}) wrapped_kernel_with_err, out_tree_thunk = api_util.flatten_fun_nokwargs( - lu.wrap_init(checked_kernel_fn, debug_info=debug), jaxpr_in_tree) + lu.wrap_init(checked_kernel_fn, debug_info=debug_info), jaxpr_in_tree) with pallas_core.tracing_grid_env(grid_mapping.grid, ()): - final_jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( + final_jaxpr, _, _ = pe.trace_to_jaxpr_dynamic( wrapped_kernel_with_err, jaxpr_flat_avals) # Prepare pallas_call inputs. We need to create new block specs # for the new error inputs and outputs. error_block_specs = [pallas_core.BlockSpec(None, None)] * len(shaped_err_avals) error_paths, _ = unzip2(tree_util.tree_flatten_with_path(error_block_specs)[0]) - error_origins = tuple(f"errrors[{tree_util.keystr(p)}" for p in error_paths) + error_origins = tuple(f"errors[{tree_util.keystr(p)}" for p in error_paths) error_block_mappings = map( - partial( - pallas_core._convert_block_spec_to_block_mapping, - index_map_avals=grid_mapping.index_map_avals, - index_map_tree=grid_mapping.index_map_tree, - grid=grid_mapping.grid, - mapped_dims=grid_mapping.vmapped_dims), - error_block_specs, error_origins, shaped_err_avals) + partial( + pallas_core._convert_block_spec_to_block_mapping, + index_map_avals=grid_mapping.index_map_avals, + index_map_tree=grid_mapping.index_map_tree, + grid=grid_mapping.grid, + vmapped_dims=grid_mapping.vmapped_dims, + debug=True, + ), + error_block_specs, + error_origins, + shaped_err_avals, + ) input_block_mappings, output_block_mappings = split_list( grid_mapping.block_mappings, [num_kernel_inputs,]) grid_mapping_with_error = grid_mapping.replace( @@ -1167,12 +1025,11 @@ def _ensure_2d_error_shape(arg): errors, results = split_list(result, [num_err_vals]) # TODO(b/350593266): Remove line below once we support ()-shaped scalars. errors = [err_val[0, 0] for err_val in errors] - new_error, _ = jax.tree.unflatten(error_out_tree, errors) + new_error, _ = tree_util.tree_unflatten(error_out_tree, errors) return new_error, results checkify.error_checks[pallas_call_p] = pallas_call_checkify_rule -@weakref_lru_cache def _trace_kernel_to_jaxpr( fun: Callable, debug_info: jax_core.DebugInfo, @@ -1181,22 +1038,32 @@ def _trace_kernel_to_jaxpr( kernel_in_tree: tree_util.PyTreeDef, kernel_in_transforms: tuple[tuple[pallas_core.Transform, ...], ...], indexer: bool = False, -) -> tuple[jax_core.ClosedJaxpr, tuple[jax.Array, ...]]: +) -> tuple[jax_core.Jaxpr, tuple[jax_typing.Array, ...]]: wrapped_kernel_fun, out_tree_thunk = api_util.flatten_fun_nokwargs( lu.wrap_init(fun, debug_info=debug_info), kernel_in_tree) wrapped_kernel_fun = primitives.wrap_with_transforms( wrapped_kernel_fun, kernel_in_transforms ) - with grid_mapping.trace_env(): - jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun, - kernel_avals) + with grid_mapping.trace_env(), config._check_vma(False): + with config.mutable_array_checks(False): + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic( + wrapped_kernel_fun, kernel_avals) if consts: - consts_avals = [jax_core.get_aval(c) for c in consts] - if any(not isinstance(aval, state.AbstractRef) for aval in consts_avals): + consts_avals = [ + aval + for c in consts + if not isinstance(aval := jax_core.get_aval(c), state.AbstractRef) + ] + if consts_avals: + ctx = jax_core.JaxprPpContext() + pp_consts_avals = ", ".join( + jax_core.pp_aval(aval, ctx) for aval in consts_avals + ) raise ValueError( - f"The kernel function in the pallas_call {debug_info.func_src_info} " - f"captures constants {consts_avals}. " - "You should pass them as inputs") + "The kernel function in the pallas_call" + f" {debug_info.func_src_info} captures constants" + f" [{pp_consts_avals}]. You should pass them as inputs." + ) kernel_out_tree = out_tree_thunk() if not indexer and kernel_out_tree != tree_util.tree_structure(None): @@ -1206,9 +1073,9 @@ def _trace_kernel_to_jaxpr( return jaxpr, tuple(consts) -_PALLAS_USE_MOSAIC_GPU = config.bool_flag( +_PALLAS_USE_MOSAIC_GPU = config.bool_state( "jax_pallas_use_mosaic_gpu", - default=config.bool_env("JAX_PALLAS_USE_MOSAIC_GPU", False), + default=config.bool_env("JAX_PALLAS_USE_MOSAIC_GPU", True), help=( "If True, lower Pallas kernels to the experimental Mosaic GPU" " dialect, instead of Triton IR." @@ -1216,44 +1083,33 @@ def _trace_kernel_to_jaxpr( ) -_PALLAS_VERBOSE_ERRORS = config.bool_flag( - "jax_pallas_verbose_errors", - default=config.bool_env("JAX_PALLAS_VERBOSE_ERRORS", True), - help=( - "If True, print verbose error messages for Pallas kernels." - ), -) - - -def _verbose_errors_enabled() -> bool: - return _PALLAS_VERBOSE_ERRORS.value - - def _unsupported_lowering_error(platform: str) -> Exception: return ValueError( f"Cannot lower pallas_call on platform: {platform}. To use Pallas on GPU," " install jaxlib GPU 0.4.24 or newer. To use Pallas on TPU, install" " jaxlib TPU and libtpu. See" - " https://jax.readthedocs.io/en/latest/installation.html." + " https://docs.jax.dev/en/latest/installation.html." ) -_Backend = Literal["mosaic_tpu", "triton", "mosaic_gpu"] - def _pallas_call_lowering( ctx: mlir.LoweringRuleContext, *in_nodes, - interpret: bool, - backend: _Backend | None, + interpret: Any, + backend: Backend | None, **params, ): if params['jaxpr'].constvars: raise ValueError('Cannot lower a pallas_call with constants.') if interpret: - if isinstance(interpret, mosaic_tpu_interpret.TPUInterpretParams): + if isinstance(interpret, mosaic_tpu_interpret.InterpretParams): impl = partial(mosaic_tpu_interpret.interpret_pallas_call, interpret_params=interpret, **params) + elif isinstance(interpret, mosaic_gpu_interpret.InterpretParams): + impl = partial(mosaic_gpu_interpret.interpret_pallas_call, + interpret_params=interpret, + **params) else: impl = partial(hlo_interpreter.pallas_call_hlo_interpret, backend=backend, @@ -1279,14 +1135,21 @@ def tpu_lowering(ctx: mlir.LoweringRuleContext, def gpu_lowering(ctx: mlir.LoweringRuleContext, *in_nodes: mlir.ir.Value | Sequence[mlir.ir.Value], **params): + is_rocm = ctx.module_context.platforms == ("rocm",) try: match backend: case "mosaic_gpu": + if is_rocm: + raise ValueError( + "Mosaic GPU backend does not yet support AMD ROCm devices. " + "Use backend='triton' for ROCm." + ) from jax._src.pallas.mosaic_gpu import pallas_call_registration case "triton": from jax._src.pallas.triton import pallas_call_registration # type: ignore case None: - if _PALLAS_USE_MOSAIC_GPU.value: + # Mosaic GPU only supports NVIDIA CUDA, not AMD ROCm. + if _PALLAS_USE_MOSAIC_GPU.value and not is_rocm: from jax._src.pallas.mosaic_gpu import pallas_call_registration else: from jax._src.pallas.triton import pallas_call_registration # type: ignore @@ -1324,7 +1187,8 @@ def _pallas_custom_str_eqn_compact( _pallas_custom_str_eqn_compact ) -def _pallas_call_typecheck_rule(*in_avals, grid_mapping, **params): +def _pallas_call_typecheck_rule(ctx_factory, *in_atoms, grid_mapping, **params): + in_avals = [x.aval for x in in_atoms] with grid_mapping.trace_env(): return pallas_call_p.abstract_eval( *in_avals, grid_mapping=grid_mapping, **params @@ -1333,10 +1197,24 @@ def _pallas_call_typecheck_rule(*in_avals, grid_mapping, **params): def _convert_out_shape_to_aval(out_shape: Any) -> jax_core.AbstractValue: match out_shape: - case jax.ShapeDtypeStruct(): - return jax_core.ShapedArray(shape=out_shape.shape, dtype=out_shape.dtype) + case jax_core.ShapeDtypeStruct(): + if config._check_vma.value: + if out_shape.vma is None: + raise ValueError( + "When `check_vma=True` on `jax.shard_map`, `vma` on" + " `jax.ShapeDtypeStruct` must not be `None`. Please specify how the" + " output should be varying across mesh axes using the `vma`" + " argument of `jax.ShapeDtypeStruct` or set `check_vma=False` on" + " `jax.shard_map`.") + return jax_core.ShapedArray( + shape=out_shape.shape, dtype=out_shape.dtype, + sharding=jax_core.get_cur_mesh_sharding(), vma=out_shape.vma) + return jax_core.ShapedArray(shape=out_shape.shape, dtype=out_shape.dtype, + sharding=jax_core.get_cur_mesh_sharding()) case pallas_core.MemoryRef(): return out_shape.get_array_aval() + case hijax.HiType(): + return out_shape case _: if type(out_shape) in pallas_core._out_shape_to_aval_mapping: return pallas_core._out_shape_to_aval_mapping[type(out_shape)]( @@ -1357,11 +1235,13 @@ def _pallas_call_state_discharge_rule( grid_mapping: GridMapping, mesh: pallas_core.Mesh | None, debug: bool, - interpret: bool, + interpret: Any, compiler_params: Any, cost_estimate: CostEstimate | None, out_avals: tuple[jax_core.AbstractValue, ...], - backend: _Backend | None = None, + backend: Backend | None, + metadata: FrozenDict[str, str] | None, + name: str | None, ): del avals_out assert all(isinstance(v.aval, state.AbstractRef) for v in jaxpr.constvars) @@ -1369,7 +1249,7 @@ def _pallas_call_state_discharge_rule( ref_avals, rest_in_avals = split_list(avals_in, [num_refs]) assert all(isinstance(ref_aval, state.AbstractRef) for ref_aval in ref_avals) ref_avals = [ - pallas_core.AbstractMemoryRef( + state.AbstractRef( ref_aval.inner_aval, pallas_core.MemorySpace.ANY ) for ref_aval in ref_avals @@ -1380,12 +1260,14 @@ def _pallas_call_state_discharge_rule( ref_block_mappings = [ block_spec.to_block_mapping( origin="", # TODO(sharadmv): enable origins for refs - array_aval=ref_aval.inner_aval, + array_aval=ref_aval.inner_aval, # type: ignore[arg-type] index_map_avals=grid_mapping.index_map_avals, index_map_tree=grid_mapping.index_map_tree, grid=grid_mapping.grid, - mapped_dims=grid_mapping.mapped_dims, - ) for ref_aval, block_spec in zip(ref_avals, ref_block_specs) + vmapped_dims=grid_mapping.vmapped_dims, + debug=debug, + ) + for ref_aval, block_spec in zip(ref_avals, ref_block_specs) ] in_block_mappings, out_block_mappings = split_list( grid_mapping.block_mappings, [grid_mapping.num_inputs] @@ -1437,8 +1319,9 @@ def _rewritten_body(*args): ], ) ) - new_jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(_rewritten_body, debug_info=jaxpr.debug_info), + new_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(_rewritten_body, + debug_info=jaxpr.debug_info.with_unknown_names()), [ *index_map_avals, *ref_avals, @@ -1455,7 +1338,7 @@ def _rewritten_body(*args): *ref_args, *rest_args, jaxpr=new_jaxpr, - input_output_aliases=new_input_output_aliases, + input_output_aliases=tuple(new_input_output_aliases), grid_mapping=new_grid_mapping, mesh=mesh, debug=debug, @@ -1464,6 +1347,8 @@ def _rewritten_body(*args): cost_estimate=cost_estimate, out_avals=new_out_avals, backend=backend, + metadata=metadata, + name=name, ) refs_out, rest = split_list(out_flat, [num_refs]) updated_vals_in = refs_out + [None] * len(rest_in_avals) @@ -1474,22 +1359,30 @@ def pallas_call( kernel: Callable[..., None], out_shape: Any, *, - grid_spec: GridSpec | None = None, - grid: TupleGrid = (), - in_specs: BlockSpecTree = no_block_spec, - out_specs: BlockSpecTree = no_block_spec, - scratch_shapes: ScratchShapeTree = (), - input_output_aliases: dict[int, int] = {}, + grid_spec: pallas_core.GridSpec | None = None, + grid: pallas_core.TupleGrid = (), + in_specs: pallas_core.BlockSpecTree = no_block_spec, + out_specs: pallas_core.BlockSpecTree = no_block_spec, + scratch_shapes: pallas_core.ScratchShapeTree = (), + input_output_aliases: Mapping[int, int] = {}, debug: bool = False, - interpret: bool = False, + interpret: Any = False, name: str | None = None, - compiler_params: dict[str, Any] | pallas_core.CompilerParams | None = None, + compiler_params: ( + Mapping[Backend, pallas_core.CompilerParams] + | pallas_core.CompilerParams + | None + ) = None, cost_estimate: CostEstimate | None = None, - backend: _Backend | None = None, + backend: Backend | None = None, + metadata: dict[str, str] | None = None, ) -> Callable[..., Any]: - """Invokes a Pallas kernel on some inputs. + """Entry point for creating a Pallas kernel. + + In contrast to :func:`jax.experimental.pallas.kernel`, this entry point + assumes that the kernel will be executed over a ``grid``. - See `Pallas Quickstart `_. + See `Pallas Quickstart `_. Args: kernel: the kernel function, that receives a Ref for each input and output. @@ -1518,7 +1411,7 @@ def pallas_call( etc. input_output_aliases: a dictionary mapping the index of some inputs to the index of the output that aliases them. These indices are in the - flattened inputs and outputs. + flattened inputs and outputs (ignoring None values). debug: if True, Pallas prints various intermediate forms of the kernel as it is being processed. interpret: runs the ``pallas_call`` as a ``jax.jit`` of a scan over the @@ -1527,25 +1420,27 @@ def pallas_call( This is useful for debugging. name: if present, specifies the name to use for this kernel call in debugging and error messages. To this name we append the file and line - where the kernel function is defined, .e.g: - `{name} for kernel function {kernel_name} at {file}:{line}`. - If missing, then we use `{kernel_name} at {file}:{line}`. - compiler_params: Optional compiler parameters. If a dict is provided, it - should be of the form {platform: {param_name: param_value}}, where - platform is either 'mosaic' or 'triton'. It is also possible - to pass in `jax.experimental.pallas.tpu.TPUCompilerParams` for TPUs and - `jax.experimental.pallas.gpu.TritonCompilerParams` for Triton/GPUs. - backend: Optional string literal one of "mosaic_tpu", "triton" or "mosaic_gpu" - determining the backend to be used. None means let pallas decide. - + where the kernel function is defined, .e.g: `{name} for kernel function + {kernel_name} at {file}:{line}`. If missing, then we use `{kernel_name} at + {file}:{line}`. + compiler_params: Optional compiler parameters. The value should either be a + backend-specific dataclass + (:class:`jax.experimental.pallas.tpu.CompilerParams`, + :class:`jax.experimental.pallas.triton.CompilerParams`, + :class:`jax.experimental.pallas.mosaic_gpu.CompilerParams`) or a dict + mapping backend name to the corresponding platform-specific dataclass. + backend: Optional string literal one of ``"mosaic_tpu"``, ``"triton"`` or + ``"mosaic_gpu"`` determining the backend to be used. None means let Pallas + decide. + metadata: Optional dictionary of information about the kernel that will be + serialized as JSON in the HLO. Can be used for debugging and analysis. Returns: A function that can be called on a number of positional array arguments to invoke the Pallas kernel. - """ if grid_spec is None: - grid_spec = GridSpec(grid, in_specs, out_specs, scratch_shapes) + grid_spec = pallas_core.GridSpec(grid, in_specs, out_specs, scratch_shapes) else: if grid: raise ValueError( @@ -1564,6 +1459,9 @@ def pallas_call( "If `grid_spec` is specified, then `scratch_shapes` must " f"be `()`. It is {scratch_shapes}") del grid, in_specs, out_specs + # We can infer a backend from compiler_params if it is not specified. + if backend is None and isinstance(compiler_params, pallas_core.CompilerParams): + backend = compiler_params.BACKEND return _pallas_call( kernel, out_shape, @@ -1575,33 +1473,56 @@ def pallas_call( compiler_params=compiler_params, cost_estimate=cost_estimate, backend=backend, + metadata=metadata, ) +def _normalize_compiler_params( + compiler_params: Mapping[Backend, pallas_core.CompilerParams] | pallas_core.CompilerParams | None, +) -> Mapping[Backend, pallas_core.CompilerParams]: + if compiler_params is None: + return FrozenDict({}) + if isinstance(compiler_params, CompilerParams): + compiler_params = {compiler_params.BACKEND: compiler_params} + assert isinstance(compiler_params, Mapping) + for backend, params in compiler_params.items(): + if backend not in ["mosaic_tpu", "mosaic_gpu", "triton"]: + raise ValueError(f"Unknown backend in compiler_params: {backend}") + if not isinstance(params, CompilerParams): + raise ValueError( + f"Unexpected compiler_params for backend {backend}: {params}" + ) + if params.BACKEND != backend: + raise ValueError( + f"Inconsistent backend in compiler_params: {params.BACKEND} !=" + f" {backend}" + ) + if not isinstance(compiler_params, FrozenDict): + compiler_params = FrozenDict(compiler_params) + return compiler_params + + +@partial(api_boundary, repro_api_name="jax.experimental.pallas.pallas_call") def _pallas_call( kernel: Callable[..., None], out_shape: Any, *, - grid_spec: GridSpec, + grid_spec: pallas_core.GridSpec, mesh: pallas_core.Mesh | None = None, - input_output_aliases: dict[int, int] = {}, + input_output_aliases: Mapping[int, int] = {}, debug: bool = False, - interpret: bool = False, + interpret: Any = False, name: str | None = None, - compiler_params: dict[str, Any] | pallas_core.CompilerParams | None = None, + compiler_params: ( + Mapping[Backend, CompilerParams] | CompilerParams | None + ) = None, cost_estimate: CostEstimate | None = None, - backend: _Backend | None = None, + backend: Backend | None = None, + metadata: dict[str, str] | None = None, ): - if compiler_params is None: - compiler_params = {} - if isinstance(compiler_params, pallas_core.CompilerParams): - if compiler_params.PLATFORM not in ["mosaic", "mosaic_gpu", "triton"]: - raise ValueError( - f"Unknown platform in compiler params: {compiler_params.PLATFORM}" - ) - compiler_params = { - compiler_params.PLATFORM: dataclasses.asdict(compiler_params) - } + interpret = ( + config.pallas_tpu_interpret_mode_context_manager.value or interpret) + compiler_params = _normalize_compiler_params(compiler_params) if mesh is not None: if tuple(mesh.shape.values()) != grid_spec.grid: @@ -1611,7 +1532,7 @@ def _pallas_call( ) if backend is not None: raise ValueError("If `mesh` is specified, then `backend` must be `None`.") - backend = cast(_Backend, mesh.backend) + backend = mesh.backend grid_spec, dynamic_grid_bounds = pallas_core.unzip_dynamic_grid_bounds(grid_spec) # TODO(necula): this canonicalization may be convenient for some usage @@ -1622,7 +1543,7 @@ def _pallas_call( flat_out_shapes_with_paths, out_tree = tree_util.tree_flatten_with_path(out_shape) out_paths, flat_out_shapes = unzip2(flat_out_shapes_with_paths) - @partial(jax.jit, inline=True) + @partial(api.jit, inline=True) def wrapped(*args): flat_args_with_paths, in_tree = tree_util.tree_flatten_with_path(args) in_paths, flat_args = unzip2(flat_args_with_paths) @@ -1636,13 +1557,22 @@ def wrapped(*args): # TODO(necula): check that input_output_aliases is well-formed: no duplicates, etc. kernel_args, grid_mapping = pallas_core.get_grid_mapping( grid_spec, - flat_in_avals, in_tree, in_origins, - flat_out_avals, out_tree, out_origins) + flat_in_avals, + in_tree, + in_origins, + flat_out_avals, + out_tree, + out_origins, + debug, + ) flat_kernel_args, kernel_in_tree = tree_util.tree_flatten(kernel_args) flat_kernel_avals = tuple( x.ref if isinstance(x, state_types.TransformedRef) else x for x in flat_kernel_args ) + if config._check_vma.value: + flat_kernel_avals = tuple(a.update_vma(frozenset()) + for a in flat_kernel_avals) # Note that only a subset of all transforms can be found here, and they are # never expected to contain any arrays. kernel_arg_transforms = tuple( @@ -1654,7 +1584,7 @@ def wrapped(*args): if name is not None: kernel_dbg = kernel_dbg.replace_func_name(mlir.sanitize_name(name)) jaxpr, consts = _trace_kernel_to_jaxpr( - kernel, kernel_dbg, grid_mapping, tuple(flat_kernel_avals), + kernel, kernel_dbg, grid_mapping, flat_kernel_avals, kernel_in_tree, kernel_arg_transforms) for i_idx, o_idx in input_output_aliases.items(): if i_idx not in range(len(flat_in_avals)): @@ -1669,18 +1599,6 @@ def wrapped(*args): f"[0, {len(flat_out_avals)})") in_aval = flat_in_avals[i_idx] out_aval = flat_out_avals[o_idx] - if isinstance(in_aval, jax_core.DShapedArray): - new_shape = [] - for d in in_aval.shape: - if isinstance(d, int): - new_shape.append(d) - else: - new_shape.append(d.dtype.bound) - - in_aval = jax_core.ShapedArray( - tuple(new_shape), in_aval.dtype, in_aval.weak_type - ) - if in_aval.shape != out_aval.shape or in_aval.dtype != out_aval.dtype: raise ValueError( f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' " @@ -1690,22 +1608,28 @@ def wrapped(*args): f"a different abstract value {out_aval}.") index_args, rest_args = split_list(flat_args, [grid_mapping.num_index_operands]) - out_flat = pallas_call_p.bind( - *consts, - *dynamic_grid_bounds, - *index_args, - *rest_args, - out_avals=flat_out_avals, - jaxpr=jaxpr, - debug=debug, - interpret=interpret, - grid_mapping=grid_mapping, - mesh=mesh, - input_output_aliases=tuple(input_output_aliases.items()), - compiler_params=compiler_params, - cost_estimate=cost_estimate, - backend=backend, + ctx = ( + api.named_scope(name) if name is not None else contextlib.nullcontext() ) + with ctx: + out_flat = pallas_call_p.bind( + *consts, + *dynamic_grid_bounds, + *index_args, + *rest_args, + out_avals=flat_out_avals, + jaxpr=jaxpr, + debug=debug, + interpret=interpret, + grid_mapping=grid_mapping, + mesh=mesh, + input_output_aliases=tuple(input_output_aliases.items()), + compiler_params=compiler_params, + cost_estimate=cost_estimate, + backend=backend, + metadata=FrozenDict(metadata) if metadata is not None else None, + name=name, + ) out = tree_util.tree_unflatten(out_tree, out_flat) return out return wrapped @@ -1733,7 +1657,7 @@ def in_path_to_input_origin( # We import the TPU backend at the top level because it defines flags. Note that -# we can only do that at the bottom of this file, beacuse it also depends on +# we can only do that at the bottom of this file, because it also depends on # this module already being initialized. try: @@ -1742,8 +1666,15 @@ def in_path_to_input_origin( mosaic_tpu_backend = None # type: ignore try: - from jax._src.pallas.mosaic import interpret as mosaic_tpu_interpret + from jax._src.pallas.mosaic.interpret import interpret_pallas_call as mosaic_tpu_interpret except ImportError: mosaic_tpu_interpret = types.SimpleNamespace( # type: ignore - TPUInterpretParams=types.new_class('_NoInstances', (enum.Enum,)), + InterpretParams=types.new_class("_NoInstances", (enum.Enum,)), + ) + +try: + from jax._src.pallas.mosaic_gpu.interpret import interpret_pallas_call as mosaic_gpu_interpret +except ImportError: + mosaic_gpu_interpret = types.SimpleNamespace( # type: ignore + InterpretParams=types.new_class("_NoInstances", (enum.Enum,)), ) diff --git a/jax/_src/pallas/pallas_test_util.py b/jax/_src/pallas/pallas_test_util.py new file mode 100644 index 000000000000..621ca70b72bd --- /dev/null +++ b/jax/_src/pallas/pallas_test_util.py @@ -0,0 +1,55 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Pallas test utilities.""" +import sys + +from jax._src import test_util as jtu +from jax._src.pallas import pallas_call +from jax.experimental import pallas as pl + +use_mosaic_gpu = pallas_call._PALLAS_USE_MOSAIC_GPU.value + + +@jtu.with_config(jax_traceback_filtering="off") +class PallasTest(jtu.JaxTestCase): + INTERPRET: bool = False + + def setUp(self): + if not jtu.test_device_matches(['cpu']) and self.INTERPRET: + self.skipTest('Only run interpret tests on CPU.') + if not self.INTERPRET: + # Running on accelerator + if jtu.test_device_matches(["cpu"]): + self.skipTest("On CPU the test works only in interpret mode") + if (jtu.test_device_matches(["cuda"]) and + not jtu.is_cuda_compute_capability_at_least("8.0")): + self.skipTest("Only works on GPU with capability >= sm80") + if (jtu.test_device_matches(["cuda"]) and use_mosaic_gpu and + not jtu.is_cuda_compute_capability_at_least("9.0")): + self.skipTest("Mosaic GPU requires capability >= sm90") + if sys.platform == "win32": + self.skipTest("Only works on non-Windows platforms") + super().setUp() + + def pallas_call(self, *args, **kwargs): + return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET) + + +class PallasTPUTest(PallasTest): + """A test case that only runs on TPUs or in interpret mode on CPU.""" + + def setUp(self): + if not jtu.test_device_matches(['tpu']) and not self.INTERPRET: + self.skipTest('Test requires TPUs') + super().setUp() diff --git a/jax/_src/pallas/pipelining/BUILD b/jax/_src/pallas/pipelining/BUILD new file mode 100644 index 000000000000..e5ddf5d868ca --- /dev/null +++ b/jax/_src/pallas/pipelining/BUILD @@ -0,0 +1,74 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +load( + "//jaxlib:jax.bzl", + "py_deps", + "pytype_strict_library", +) + +package( + default_applicable_licenses = [], + default_visibility = [ + "//jax:internal", + ], +) + +pytype_strict_library( + name = "internal", + srcs = ["internal.py"], + deps = [ + "//jax", + "//jax/_src:core", + "//jax/_src:state_types", + ], +) + +pytype_strict_library( + name = "schedule_api", + srcs = ["schedule_api.py"], + deps = [ + ":internal", + ":schedulers", + "//jax", + "//jax/_src:api_util", + "//jax/_src:core", + "//jax/_src:partial_eval", + "//jax/_src:state_types", + "//third_party/py/numpy", + ], +) + +pytype_strict_library( + name = "schedulers", + srcs = ["schedulers.py"], + deps = [ + ":internal", + "//jax", + "//jax/_src:core", + "//third_party/py/numpy", + ], +) + +pytype_strict_library( + name = "pipeline_test_util", + testonly = True, + srcs = ["pipeline_test_util.py"], + deps = [ + "//jax/_src:debugging", + "//jax/_src/pallas/pipelining:internal", + "//jax/_src/pallas/pipelining:schedulers", + ] + py_deps([ + "absl/testing", + ]), +) diff --git a/jax/_src/pallas/pipelining/__init__.py b/jax/_src/pallas/pipelining/__init__.py new file mode 100644 index 000000000000..1337256a5074 --- /dev/null +++ b/jax/_src/pallas/pipelining/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/jax/_src/pallas/pipelining/internal.py b/jax/_src/pallas/pipelining/internal.py new file mode 100644 index 000000000000..136ba3d5f1df --- /dev/null +++ b/jax/_src/pallas/pipelining/internal.py @@ -0,0 +1,89 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Internal APIs and data structures for the custom pipelining API.""" +from collections.abc import Hashable, Sequence +import dataclasses + +from jax._src import core as jax_core +from jax._src.state import types as state_types + + +ReadEffect = state_types.ReadEffect +WriteEffect = state_types.WriteEffect +RefEffect = state_types.ReadEffect | state_types.WriteEffect +BufferIndex = int | str + + +def filter_write_effects(effects: set[RefEffect]) -> set[WriteEffect]: + return {effect for effect in effects if isinstance(effect, WriteEffect)} + + +def filter_read_effects(effects: set[RefEffect]) -> set[ReadEffect]: + return {effect for effect in effects if isinstance(effect, ReadEffect)} + + +def filter_tokens(effects: set[RefEffect]) -> set[RefEffect]: + return {effect for effect in effects if isinstance(effect.input_index, str)} + + +@dataclasses.dataclass(frozen=True) +class SchedulingProperties: + max_in_flight: int + is_async_start: bool + is_async_done: bool + + def __post_init__(self): + if self.is_async_start and self.is_async_done: + raise ValueError( + "Async start and async done are mutually exclusive.") + + +@dataclasses.dataclass(frozen=True) +class PipelineStage: + """An internal representation of a pipeline stage.""" + jaxpr: jax_core.ClosedJaxpr + effects: set[RefEffect] + properties: SchedulingProperties + name: str + + def get_read_idxs(self) -> set[BufferIndex]: + """Returns the buffer indices that this stage reads from.""" + return { + effect.input_index + for effect in filter_read_effects(self.effects) + } + + def get_write_idxs(self) -> set[BufferIndex]: + """Returns the buffer indices that this stage writes to.""" + return { + effect.input_index + for effect in filter_write_effects(self.effects) + } + + def __str__(self): + return self.name + + def __repr__(self): + return f"{self.name}[effs={self.effects}]" + + +@dataclasses.dataclass(frozen=True) +class NDLoopStruct: + stages: Sequence[PipelineStage] + grid: Sequence[int] + + +def make_token(obj: Hashable) -> str: + """Returns a fake input ID used to thread data dependencies.""" + return f"token_{hash(obj)}" diff --git a/jax/_src/pallas/pipelining/pipeline_test_util.py b/jax/_src/pallas/pipelining/pipeline_test_util.py new file mode 100644 index 000000000000..e9280d7daac9 --- /dev/null +++ b/jax/_src/pallas/pipelining/pipeline_test_util.py @@ -0,0 +1,64 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test utilities for the custom pipeline scheduling API.""" +import dataclasses +from typing import Any, Sequence + +from jax._src import debugging +from jax._src.pallas.pipelining import schedulers +from jax._src.pallas.pipelining import internal + + +def print_stage( + ctx: schedulers.PipelineContext, stage: internal.PipelineStage, *args +): + """Evaluation function that prints the stage name and iteration number.""" + del args + debugging.debug_print( + "[itr={}] %s" % stage, ctx.linearized_index, ordered=True) + + +@dataclasses.dataclass(frozen=True) +class AnyOrder: + """A helper class to mark the order of elements as unimportant.""" + elements: Sequence[Any] + + +def compare_lists(result, expected): + """Returns if two lists are equal while respecting ``AnyOrder`` elements.""" + result_ptr = 0 + expected_ptr = 0 + any_order_set = None + while result_ptr < len(result) and expected_ptr < len(expected): + cur_result = result[result_ptr] + cur_expected = expected[expected_ptr] + if isinstance(cur_expected, AnyOrder): + if any_order_set is None: + any_order_set = set(cur_expected.elements) + + if cur_result in any_order_set: + result_ptr += 1 + any_order_set.remove(cur_result) + else: + return False + if not any_order_set: + any_order_set = None + expected_ptr += 1 + else: + if cur_result == cur_expected: + result_ptr += 1 + expected_ptr += 1 + else: + return False + return True diff --git a/jax/_src/pallas/pipelining/schedule_api.py b/jax/_src/pallas/pipelining/schedule_api.py new file mode 100644 index 000000000000..e4f4817d1464 --- /dev/null +++ b/jax/_src/pallas/pipelining/schedule_api.py @@ -0,0 +1,325 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Internal API for the Pallas pipelining scheduler.""" +# mypy: ignore-errors +# pylint: disable=missing-function-docstring +# pylint: disable=g-doc-args +# pytype: disable=wrong-keyword-args +import dataclasses +from typing import Any, Optional, Sequence + +import jax +from jax._src import api_util +from jax._src import core as jax_core +from jax._src import linear_util as lu +from jax._src.interpreters import partial_eval as pe +from jax._src.state import types as state_types +from jax._src.pallas.pipelining import schedulers +from jax._src.pallas.pipelining import internal + + +PipelineContext = schedulers.PipelineContext + + +def stage(max_in_flight: int): + """Wrapper for creating a pipeline stage.""" + def wrapper(func) -> SyncStage: + return SyncStage(func, max_in_flight) + return wrapper + + +class SyncStage: + """Constructs a synchronous pipeline stage.""" + + def __init__(self, func, max_in_flight: int): + self.func = func + self.max_in_flight = max_in_flight + + def trace( + self, abstract_refs, state_avals, grid + ) -> internal.PipelineStage: + jaxpr, effs = trace_fun( + self.func, abstract_refs, state_avals, grid + ) + name = getattr(self.func, "__name__", str(self.func)) + return internal.PipelineStage( + jaxpr=jaxpr, + effects=set(effs), + properties=internal.SchedulingProperties( + max_in_flight=self.max_in_flight, + is_async_start=False, + is_async_done=False, + ), + name=name, + ) + + +class AsyncStage: + """Constructs an asynchronous pipeline stage.""" + + def __init__(self, max_in_flight: int): + self.start_func = None + self.end_func = None + self.max_in_flight = max_in_flight + + def def_start(self, func): + self.start_func = func + return self + + def def_end(self, func): + self.end_func = func + return self + + def trace( + self, abstract_refs, state_avals, grid + ) -> tuple[internal.PipelineStage, internal.PipelineStage]: + start_jaxpr, start_effs = trace_fun( + self.start_func, abstract_refs, state_avals, grid + ) + end_jaxpr, end_effs = trace_fun( + self.end_func, abstract_refs, state_avals, grid + ) + token = internal.make_token(self) + start_effs = {*start_effs, internal.WriteEffect(token)} + end_effs = {*end_effs, internal.ReadEffect(token)} + name = getattr(self.start_func, "__name__", str(self.start_func)) + start_stage = internal.PipelineStage( + jaxpr=start_jaxpr, + effects=start_effs, + properties=internal.SchedulingProperties( + max_in_flight=self.max_in_flight, + is_async_start=True, + is_async_done=False, + ), + name=name, + ) + name = getattr(self.end_func, "__name__", str(self.end_func)) + end_stage = internal.PipelineStage( + jaxpr=end_jaxpr, + effects=end_effs, + properties=internal.SchedulingProperties( + max_in_flight=self.max_in_flight, + is_async_start=False, + is_async_done=True, + ), + name=name, + ) + return start_stage, end_stage + + +Stage = SyncStage | AsyncStage + + +def trace_fun( + fun, ref_avals, state_avals, grid +) -> tuple[jax_core.ClosedJaxpr, Sequence[internal.RefEffect]]: + """Trace a stage body function to a Jaxpr.""" + ctx_aval = PipelineContext.aval_pytree(grid, state_avals) + num_ctx_avals = len(jax.tree.leaves(ctx_aval)) + flat_avals, in_tree = jax.tree.flatten((ctx_aval, *ref_avals)) + debug_info = api_util.debug_info("trace_fun", fun, flat_avals, {}) + flat_fn, out_tree_thunk = api_util.flatten_fun_nokwargs( + lu.wrap_init(fun, debug_info=debug_info), in_tree + ) + del out_tree_thunk + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fn, flat_avals) + ref_effects = [ + eff for eff in jaxpr.effects if isinstance(eff, state_types.RefEffect) + ] + # Subtract off the consts and state_avals, since this is variable per stage. + n_const = len(consts) + ref_effects = [ + type(eff)(input_index=eff.input_index - n_const - num_ctx_avals) + for eff in ref_effects + ] + return jax_core.ClosedJaxpr(jaxpr, consts), ref_effects + +def apply_ref_filter( + stages: Sequence[internal.PipelineStage], + ref_filter: Any, + grid, state_avals +) -> Sequence[internal.PipelineStage]: + """Removes any effects belonging to Refs that do not pass the filter.""" + if ref_filter is None: + return stages + ctx_aval = PipelineContext.aval_pytree(grid, state_avals) + num_ctx_avals = len(jax.tree.leaves(ctx_aval)) + new_stages = [] + for stage_ in stages: + jaxpr = stage_.jaxpr.jaxpr + ref_effects = stage_.effects + token_effects = list(internal.filter_tokens(ref_effects)) + refs_to_keep = { + i - num_ctx_avals + for i, aval in enumerate(jaxpr.in_avals) + if ref_filter(aval) + } + new_effects = [ + eff for eff in ref_effects if eff.input_index in refs_to_keep + ] + token_effects + new_stages.append(dataclasses.replace(stage_, effects=set(new_effects))) + return new_stages + +def convert_accum_effects_to_writes(stages: Sequence[internal.PipelineStage] + ) -> Sequence[internal.PipelineStage]: + """Replaces all accumulate effects with simple writes.""" + # After tracing, an accumulation such as ref[...] += y + # will result in both a ReadEffect and a WriteEffect into `ref`. + new_stages = [] + for stage_ in stages: + read_effs = internal.filter_read_effects(stage_.effects) + write_effs = internal.filter_write_effects(stage_.effects) + new_read_effs = ( + eff + for eff in read_effs + if state_types.WriteEffect(eff.input_index) not in write_effs + ) + effs = (*new_read_effs, *write_effs) + new_stages.append(dataclasses.replace(stage_, effects=set(effs))) + return new_stages + + +def remove_duplicate_writes_between_async_stages( + stages: Sequence[internal.PipelineStage], +) -> Sequence[internal.PipelineStage]: + """Removes duplicate writes between the async start and done stages. + + This is done because the scheduler doesn't support multiple writes to + the same Ref in different stages. We instead write to a token in the + async_start stage that's read by the async_done and all direct consumers. + """ + new_stages = [] + for stage_ in stages: + if stage_.properties.is_async_start: + start_read_effs = internal.filter_read_effects(stage_.effects) + start_write_effs = internal.filter_write_effects(stage_.effects) + write_token = internal.filter_tokens(start_write_effs) + assert len(write_token) == 1, stage_.effects + write_token = tuple(write_token)[0] + read_token = state_types.ReadEffect(write_token.input_index) + + done_stage = [ + x + for x in stages + if x.properties.is_async_done and read_token in x.effects + ] + assert len(done_stage) == 1 + done_stage = done_stage[0] + end_write_effs = internal.filter_write_effects(done_stage.effects) + start_write_effs = start_write_effs - end_write_effs + start_effs = (*start_read_effs, *start_write_effs) + new_stages.append(dataclasses.replace(stage_, effects=set(start_effs))) + else: + new_stages.append(stage_) + return new_stages + + +def thread_token_deps_to_consumers(stages: Sequence[internal.PipelineStage] + ) -> Sequence[internal.PipelineStage]: + """Threads the async token to consumers of async op. + + This ensures that the async_start op does not start too soon and potentially + clobber buffers that the consumers are reading from. + """ + effects = [stage_.effects for stage_ in stages] + for stage_ in stages: + if stage_.properties.is_async_done: + write_tokens = internal.filter_tokens( + internal.filter_write_effects(stage_.effects) + ) + read_tokens = internal.filter_tokens( + internal.filter_read_effects(stage_.effects) + ) + assert not write_tokens, stage_.effects + assert len(read_tokens) == 1, stage_.effects + read_token_effect = tuple(read_tokens)[0] + write_idxs = stage_.get_write_idxs() + for i, other_stage in enumerate(stages): + if any( + write_idx in other_stage.get_read_idxs() for write_idx in write_idxs + ): + effects[i].add(read_token_effect) + return [dataclasses.replace(stage_, effects=set(effects[i]) + ) for i, stage_ in enumerate(stages)] + + +def schedule_pipeline( + stages: Sequence[Stage], + grid: Sequence[int], + args: Sequence[Any], + ref_filter: Optional[Any] = None, + initial_state: schedulers.PipelineState | None = None, + scheduler: schedulers.PipelineScheduler = schedulers.static_nd_loop_scheduler, + **scheduler_kwargs, +): + """Schedules stages and emits the code for a pipeline. + + Args: + stages: A sequence of pipeline stages. + grid: The loop grid size. + args: A sequence of arguments to the pipeline. These will be passed + directly to each stage. + ref_filter: An optional function to filter out Refs during tracing so + that they do not affect the pipeline schedule. + initial_state: An optional pipeline state that will be passed as a + carry into each stage. + scheduler: Which scheduling function to use. + **scheduler_kwargs: Additional arguments to pass to the scheduler. + + Returns: + A function that can be called with ``args`` and runs the pipeline. + """ + _, ref_tree = jax.tree.flatten(args) + def _get_aval(x): + if hasattr(x, "get_ref_aval"): + return x.get_ref_aval() + return jax_core.get_aval(x) + avals = jax.tree.map(_get_aval, args) + + # Make state avals. + state_avals = jax.tree.map(_get_aval, initial_state) + + traced_stages = [] + for stage in stages: + if isinstance(stage, SyncStage): + traced_stages.append(stage.trace(avals, state_avals, grid)) + elif isinstance(stage, AsyncStage): + start_stage, end_stage = stage.trace(avals, state_avals, grid) + traced_stages.append(start_stage) + traced_stages.append(end_stage) + else: + raise ValueError(f"Unsupported stage type: {type(stage)}") + + # Run several "passes" to clean up effects before scheduling. + traced_stages = apply_ref_filter(traced_stages, ref_filter, grid, state_avals) + traced_stages = convert_accum_effects_to_writes(traced_stages) + traced_stages = remove_duplicate_writes_between_async_stages(traced_stages) + traced_stages = thread_token_deps_to_consumers(traced_stages) + + loop_struct = internal.NDLoopStruct(stages=traced_stages, grid=grid) + + def pipeline(*args): + flat_args, args_tree = jax.tree.flatten(args) + if args_tree != ref_tree: + raise ValueError( + f"Args tree and ref tree do not match.\n{args_tree=}\n{ref_tree=}" + ) + scheduler( + loop_struct, + args=flat_args, + initial_state=initial_state, + **scheduler_kwargs, + ) + + return pipeline diff --git a/jax/_src/pallas/pipelining/schedulers.py b/jax/_src/pallas/pipelining/schedulers.py new file mode 100644 index 000000000000..90aeb4603b92 --- /dev/null +++ b/jax/_src/pallas/pipelining/schedulers.py @@ -0,0 +1,552 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Pipeline scheduler implementations.""" +# mypy: ignore-errors +# pytype: disable=invalid-annotation +# pytype: disable=wrong-arg-types +# pytype: disable=bad-return-type +# pylint: disable=missing-function-docstring +# pylint: disable=g-doc-args + +import collections +from collections.abc import Callable, Mapping, Sequence +import copy +import dataclasses +import functools +import math +import operator +from typing import Any, cast, Protocol + +import jax +from jax import lax +from jax import numpy as jnp +from jax._src import core as jax_core +import numpy as np + +from jax._src.pallas.pipelining import internal + + +PipelineState = Any +PipelineScheduler = Callable[ + [internal.NDLoopStruct, Sequence[Any], Any, Any], None] + + +def compute_grid_indices(linear_index: jax.Array, grid_size: Sequence[int]): + """Computes the grid indices for a given linear index.""" + indices = [] + for i, _ in enumerate(grid_size): + rest_size = math.prod(grid_size[i+1:]) + axis_index = linear_index // rest_size + indices.append(axis_index) + linear_index = lax.rem(linear_index, rest_size) + return indices + + +def increment_grid(indices: Sequence[int | jax.Array], + grid: Sequence[int], + dynamic: bool = False): + """Increments the grid indices by 1.""" + next_indices = [] + carry: bool | jax.Array = True + for idx, size in reversed(list(zip(indices, grid, strict=True))): + if dynamic: + idx = cast(jax.Array, idx) + next_idx = lax.select(carry, idx + 1, idx) + carry = next_idx == size + next_indices.append( + lax.select(carry, jnp.asarray(0, dtype=idx.dtype), next_idx) + ) + else: + next_idx = idx + 1 if carry else idx + carry = next_idx == size + next_indices.append(0 if carry else next_idx) + return tuple(reversed(next_indices)), carry + + +@functools.partial(jax.tree_util.register_dataclass, + data_fields=["loop_index", + "linearized_index", + "pipeline_state"], + meta_fields=[]) +@dataclasses.dataclass(frozen=True) +class PipelineContext: + """Container class containing pipeline state information. + + Attributes: + loop_index: The current grid indices to run for the current stage. + linearized_index: The linearized ``loop_index``. + pipeline_state: The global pipeline carry state. + """ + loop_index: tuple[jax.Array, ...] + linearized_index: jax.Array + pipeline_state: PipelineState + + @classmethod + def aval_pytree(cls, grid, state_avals) -> "PipelineContext": + return PipelineContext( + loop_index=(jax_core.ShapedArray((), jnp.int32),) * len(grid), + linearized_index=jax_core.ShapedArray((), jnp.int32), + pipeline_state=state_avals) + + +def check_pipeline(stages: Sequence[internal.PipelineStage]): + """Runs sanity checks on the pipeline.""" + last_write = collections.defaultdict(lambda: None) + last_read = collections.defaultdict(lambda: None) + for i, stage in enumerate(stages): + for read_idx in stage.get_read_idxs(): + if last_write[read_idx] is None: + raise ValueError( + f"Read before write. {stage} attempted to read ref {read_idx}" + " without a prior stage writing to it.") + last_read[read_idx] = i + for write_idx in stage.get_write_idxs(): + if last_write[write_idx] is not None: + raise ValueError( + f"Write conflict. {stage} writes to ref {write_idx} but it was" + f" already written to by stage {stages[last_write[write_idx]]}." + " The current scheduler only allows one stage to write to each" + " buffer.") + last_write[write_idx] = i + all_idxs = last_write.keys() | last_read.keys() + for i in all_idxs: + if last_write[i] > last_read[i]: + raise ValueError(f"Ref {i} is written to after its final read.") + + +@functools.partial(jax.tree_util.register_dataclass, + data_fields=["stage_counters"], + meta_fields=["which_stage_writes", "which_stages_read"]) +@dataclasses.dataclass(frozen=True) +class Scoreboard: + """A scoreboard used to book-keep data dependencies. + + Attributes: + which_stage_writes: A mapping from buffer index to the stage index that + writes to it. + which_stages_read: A mapping from buffer index to the stages that read + from it. + stage_counters: A list of length num_stages that tracks the number of times + each stage has run. + """ + which_stage_writes: Mapping[internal.BufferIndex, int] + which_stages_read: Mapping[internal.BufferIndex, Sequence[int]] + stage_counters: list[jax.Array | int] + + @classmethod + def create(cls, stages: Sequence[internal.PipelineStage]): + which_stage_writes = collections.defaultdict(lambda: None) + which_stage_reads = collections.defaultdict(set) + stage_counters = [0] * len(stages) + for i, stage in enumerate(stages): + for write_idx in stage.get_write_idxs(): + which_stage_writes[write_idx] = i + for read_idx in stage.get_read_idxs(): + which_stage_reads[read_idx].add(i) + return cls(which_stage_writes, which_stage_reads, stage_counters) + + def get_stage_counter(self, stage_idx: int) -> jax.Array | int: + """Returns the current stage counter for the given stage index.""" + return self.stage_counters[stage_idx] + + def get_writing_stage(self, buffer_idx: internal.BufferIndex) -> int: + """Returns the stage index that writes to the given buffer index.""" + return self.which_stage_writes[buffer_idx] + + def increment_stage_counter(self, stage_idx: int) -> None: + """Increments the stage counter for the given stage index.""" + self.stage_counters[stage_idx] += 1 + + def copy(self) -> "Scoreboard": + """Returns a deep copy of the scoreboard.""" + new_stage_counters = copy.copy(self.stage_counters) + return Scoreboard(self.which_stage_writes, self.which_stages_read, + new_stage_counters) + + +@functools.partial(jax.tree_util.register_dataclass, + data_fields=["indices"], + meta_fields=["grid", "offsets", "dynamic"]) +@dataclasses.dataclass(frozen=True) +class GridCarry: + """Helper class for managing the pipeline grid indices. + + Attributes: + grid: The size of the grid. + offsets: A mapping from the stage index to the integer offset from the + slowest scheduled stage. + dynamic: Whether grid indices should be calculated dynamically. + indices: A mapping from offset to the grid indices. + """ + grid: Sequence[int] + offsets: Sequence[int] + dynamic: bool + indices: Sequence[Sequence[int | jax.Array]] + + @classmethod + def init(cls, grid, offsets, dynamic=False) -> 'GridCarry': + max_offset = max(offsets) + cur_indices = tuple([0] * len(grid)) + indices = [cur_indices] + for _ in range(1, max_offset + 1): + next_indices, _ = increment_grid(cur_indices, grid) + indices.append(next_indices) + cur_indices = next_indices + return cls(grid, offsets, dynamic, tuple(indices)) + + def next(self) -> "GridCarry": + next_indices, _ = increment_grid( + self.indices[-1], self.grid, dynamic=self.dynamic + ) + new_indices = (*self.indices[1:], next_indices) + return GridCarry(self.grid, self.offsets, self.dynamic, new_indices) + + def get_indices_for_stage(self, stage_idx: int) -> Sequence[int | jax.Array]: + return self.indices[self.offsets[stage_idx]] + + +def check_args_ready( + stage: internal.PipelineStage, + scoreboard: Scoreboard, + new_scoreboard: Scoreboard, + current_stage_counter: int | jax.Array, + dynamic=False, +) -> bool | jax.Array: + """Returns whether all arguments to the stage have already been computed.""" + all_read_stages = [] + for arg_idx in stage.get_read_idxs(): + if stage.properties.is_async_start: + # Async start stages can start immediately after the preceding + # stage, so we use new_scoreboard instead of scoreboard. + arg_stage_idx = new_scoreboard.get_writing_stage(arg_idx) + arg_stage_ctr = new_scoreboard.get_stage_counter(arg_stage_idx) + else: + arg_stage_idx = scoreboard.get_writing_stage(arg_idx) + arg_stage_ctr = scoreboard.get_stage_counter(arg_stage_idx) + all_read_stages.append(arg_stage_ctr > current_stage_counter) + op = jnp.logical_and if dynamic else operator.and_ + args_ready = functools.reduce(op, all_read_stages, True) + return args_ready + + +def check_async_done(stage: internal.PipelineStage, + scoreboard: Scoreboard, + num_itrs: int | jax.Array, + current_stage_counter: int | jax.Array, + dynamic=False) -> bool | jax.Array: + """Returns whether the async done stage can run.""" + and_op = jnp.logical_and if dynamic else operator.and_ + # For async done stages, we need to insert delays so that they + # happen as late as possible. + # First condition is that there are a full number of async starts + # in flight. + max_in_flight = stage.properties.max_in_flight + can_run = True + token_read_effs = internal.filter_tokens( + internal.filter_read_effects(stage.effects)) + read_tokens = {effect.input_index for effect in token_read_effs} + assert len(read_tokens) == 1, stage.effects + read_token = tuple(read_tokens)[0] + async_start_stage_idx = scoreboard.which_stage_writes[read_token] + async_start_counter = scoreboard.get_stage_counter( + async_start_stage_idx) + async_done_counter = current_stage_counter + min_op = jnp.minimum if dynamic else min + start_full = (async_start_counter >= + min_op(async_done_counter + max_in_flight, num_itrs)) + can_run = and_op(can_run, start_full) + # Second condition - the consumers of this stage's outputs will + # actually need the results on the next iteration. + for write_idx in stage.get_write_idxs(): + which_stages_read = scoreboard.which_stages_read[write_idx] + for read_stage_idx in which_stages_read: + read_itr = scoreboard.stage_counters[read_stage_idx] + can_run = and_op(can_run, (current_stage_counter <= read_itr)) + return can_run + + +def check_async_start( + stage: internal.PipelineStage, + scoreboard: Scoreboard, + current_stage_counter: int | jax.Array, + dynamic=False, +) -> bool | jax.Array: + """Returns whether the async start stage can run.""" + token_write_effs = internal.filter_tokens( + internal.filter_write_effects(stage.effects) + ) + assert len(token_write_effs) == 1, stage.effects + token_write_idx = tuple(token_write_effs)[0].input_index + dependent_stages = scoreboard.which_stages_read[token_write_idx] + + dependents_ready = [] + max_in_flight = stage.properties.max_in_flight + for dependent_stage_idx in dependent_stages: + check_itr = scoreboard.stage_counters[dependent_stage_idx] + # Do not issue more async_starts than max_in_flight. + dependents_ready.append( + current_stage_counter < check_itr + max_in_flight) + op = jnp.logical_and if dynamic else operator.and_ + dependents_ready = functools.reduce(op, dependents_ready, True) + return dependents_ready + + +class EvalStageFunc(Protocol): + def __call__( + self, + ctx: PipelineContext, + stage: internal.PipelineStage, + args: Sequence[Any], + ) -> PipelineState: + ... + + +def eval_stage(ctx: PipelineContext, stage: internal.PipelineStage, args + ) -> PipelineState: + """Evaluates a single stage.""" + flat_ctx = jax.tree.leaves(ctx) + state_tree = jax.tree.structure(ctx.pipeline_state) + next_state = jax_core.eval_jaxpr( + stage.jaxpr.jaxpr, stage.jaxpr.consts, *flat_ctx, *args + ) + if next_state: + return jax.tree.unflatten(state_tree, next_state) + return ctx.pipeline_state + + +def linearize_stages(stages: Sequence[internal.PipelineStage] + ) -> Sequence[internal.PipelineStage]: + """Computes a linearization of the pipeline stages.""" + linearized_stages = [] + outputs_written = set() + available_stages = stages + while available_stages: + stage_added = False + new_available_stages = list(available_stages) + for stage in available_stages: + if all(read_idx in outputs_written for read_idx in stage.get_read_idxs()): + linearized_stages.append(stage) + outputs_written.update(stage.get_write_idxs()) + stage_added = True + new_available_stages.remove(stage) + available_stages = new_available_stages + if not stage_added: + raise ValueError( + "Failed to linearize pipeline stages. Could not linearize" + f" {available_stages=}") + return linearized_stages + + +def make_ctx(stage: internal.PipelineStage, + stage_idx: int, + scoreboard: Scoreboard, + pipeline_state: PipelineState, + grid_carry: GridCarry | None = None, + grid: Sequence[int] | None = None, + offset: int | jax.Array = 0) -> PipelineContext: + del stage + step = scoreboard.stage_counters[stage_idx] + offset + if grid_carry is not None: + loop_index = grid_carry.get_indices_for_stage(stage_idx) + else: + loop_index = compute_grid_indices(step, grid) + return PipelineContext(loop_index=loop_index, + linearized_index=step, + pipeline_state=pipeline_state) + + +# TODO(justinfu): Implement a second version that rolls more of the pipeline +# into the loop body to reduce code size. +def static_nd_loop_scheduler( + nd_loop: internal.NDLoopStruct, + args: Sequence[Any], + initial_state: PipelineState | None = None, + eval_fn: EvalStageFunc | None = None, +): + """Schedules and emits the pipeline into a single instruction stream. + + This scheduler is static in the sense that most of the control logic is + implemented in Python and run at JAX tracing time. This reduce scalar + core pressure as the scoreboarding logic does not have to be computed + at runtime. + """ + if eval_fn is None: + eval_fn = eval_stage + + stages = linearize_stages(nd_loop.stages) + num_stages = len(stages) + num_itrs = np.prod(nd_loop.grid) + check_pipeline(stages) + scoreboard = Scoreboard.create(stages) + + def can_run_stage( + stage: internal.PipelineStage, + scoreboard: Scoreboard, + new_scoreboard: Scoreboard, + current_stage_counter: int | jax.Array, + ) -> bool | jax.Array: + can_run = True + # Check args ready. + can_run = can_run & check_args_ready( + stage, scoreboard, new_scoreboard, current_stage_counter) + # Check dependents + if stage.properties.is_async_start: + can_run = can_run & check_async_start( + stage, scoreboard, current_stage_counter, + ) + if stage.properties.is_async_done: + can_run = can_run & check_async_done( + stage, scoreboard, num_itrs, current_stage_counter) + return can_run + + def compute_offsets(scoreboard: Scoreboard) -> Sequence[int] | None: + while any(scoreboard.stage_counters[i] < 1 for i in range(num_stages)): + new_scoreboard = scoreboard.copy() + for stage_idx, stage in enumerate(stages): + current_stage_counter = scoreboard.stage_counters[stage_idx] + can_run = can_run_stage( + stage, scoreboard, new_scoreboard, current_stage_counter + ) + if can_run: + new_scoreboard.increment_stage_counter(stage_idx) + if scoreboard.stage_counters == new_scoreboard.stage_counters: + raise ValueError("Scheduling error. No stages ran.") + scoreboard = new_scoreboard + min_stage = min(scoreboard.stage_counters) + offsets = [ + scoreboard.stage_counters[i] - min_stage for i in range(num_stages) + ] + if max(offsets) > num_itrs: + # Bail out, since we won't be running the main loop. + return None + return offsets + + # Main loop stage iteration offsets. + # This is a list of integers containing the number of iterations each + # stage is ahead of the slowest stage. + offsets = compute_offsets(scoreboard) + + # Static prologue + # This runs the pipeline up until the steady state. + pipeline_state = initial_state + with jax.named_scope("pipeline_prologue"): + while any( + scoreboard.stage_counters[i] < (offsets[i] if offsets else 1) + for i in range(num_stages) + ): + new_scoreboard = scoreboard.copy() + for stage_idx, stage in enumerate(stages): + current_stage_counter = scoreboard.stage_counters[stage_idx] + if offsets: + can_run = current_stage_counter < offsets[stage_idx] + else: + can_run = current_stage_counter < num_itrs + can_run = can_run & can_run_stage( + stage, scoreboard, new_scoreboard, current_stage_counter + ) + if can_run: + pipeline_state = eval_fn( + make_ctx( + stage, stage_idx, scoreboard, pipeline_state, + grid=nd_loop.grid, + ), + stage, + args, + ) + new_scoreboard.increment_stage_counter(stage_idx) + if scoreboard.stage_counters == new_scoreboard.stage_counters: + raise ValueError("Scheduling error. No stages ran.") + scoreboard = new_scoreboard + + if offsets: + assert all( + scoreboard.stage_counters[i] == offsets[i] for i in range(num_stages) + ), ( + f"Scheduling error. Scoreboard {scoreboard.stage_counters} does not" + f" match computed offsets {offsets}" + ) + + # Dynamic loop body. + # This runs the steady state of the pipeline where all stages run with + # no control flow. + @jax.named_scope("pipeline_steady_state") + def loop_body(itr: jax.Array, carry: tuple[PipelineState, GridCarry]): + pipeline_state, grid_carry = carry + stages_left = list(stages) + old_scoreboard = scoreboard.copy() + while any(stages_left): + new_scoreboard = old_scoreboard.copy() + for stage_idx, stage in enumerate(stages_left): + if stage is None: + continue + current_stage_counter = old_scoreboard.stage_counters[stage_idx] + can_run = can_run_stage( + stage, old_scoreboard, new_scoreboard, current_stage_counter + ) + if can_run: + pipeline_state = eval_fn( + make_ctx( + stage, + stage_idx, + old_scoreboard, + pipeline_state, + grid_carry=grid_carry, + offset=itr, + ), + stage, + args, + ) + new_scoreboard.increment_stage_counter(stage_idx) + stages_left[stage_idx] = None + old_scoreboard = new_scoreboard + return (pipeline_state, grid_carry.next()) + + num_loop_itrs = int(max(num_itrs - max(scoreboard.stage_counters), 0)) + if offsets: + grid_carry = GridCarry.init( + offsets=offsets, grid=nd_loop.grid, dynamic=True) + init_carry = (pipeline_state, grid_carry) + final_carry = jax.lax.fori_loop(0, num_loop_itrs, loop_body, init_carry) + (pipeline_state, _) = final_carry + + # Update the static scoreboard to reflect the fact that each stage ran + # num_loop_itrs times. + for stage_idx in range(len(stages)): + scoreboard.stage_counters[stage_idx] += num_loop_itrs + + # Static epilogue + with jax.named_scope("pipeline_epilogue"): + while any( + scoreboard.stage_counters[i] < num_itrs for i in range(num_stages) + ): + new_scoreboard = scoreboard.copy() + for stage_idx, stage in enumerate(stages): + current_stage_counter = scoreboard.stage_counters[stage_idx] + can_run = current_stage_counter < num_itrs + can_run = can_run & can_run_stage( + stage, scoreboard, new_scoreboard, current_stage_counter + ) + if can_run: + pipeline_state = eval_fn( + make_ctx( + stage, stage_idx, scoreboard, pipeline_state, + grid=nd_loop.grid, + ), + stage, + args, + ) + new_scoreboard.increment_stage_counter(stage_idx) + if scoreboard.stage_counters == new_scoreboard.stage_counters: + raise ValueError("Scheduling error. No stages ran.") + scoreboard = new_scoreboard diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index 3306649f24f3..d800cc8ab32c 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -16,36 +16,41 @@ from __future__ import annotations +from collections.abc import Callable, Sequence +from collections.abc import Hashable import enum import functools +import math import string -from typing import Any, Callable +from typing import Any -import jax -from jax import lax -from jax import tree_util +import jax._src.lax as lax +from jax._src import tree_util from jax._src import ad_util from jax._src import api_util -from jax._src import callback from jax._src import core as jax_core +from jax._src import config +from jax._src import debugging from jax._src import dtypes +from jax._src import typing as jax_typing from jax._src import effects from jax._src import linear_util as lu from jax._src import pretty_printer as pp from jax._src import state from jax._src import util from jax._src.interpreters import ad -from jax._src.interpreters import batching from jax._src.interpreters import partial_eval as pe +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import arith from jax._src.pallas import core as pallas_core +from jax._src.pallas import utils as pallas_utils from jax._src.state import discharge as state_discharge from jax._src.state import indexing -from jax._src.state import types as state_types from jax._src.state import primitives as sp +from jax._src.state import types as state_types from jax.interpreters import mlir -import jax.numpy as jnp +from jax._src import numpy as jnp -partial = functools.partial Slice = indexing.Slice NDIndexer = indexing.NDIndexer @@ -53,16 +58,15 @@ zip, unsafe_zip = util.safe_zip, zip program_id_p = jax_core.Primitive("program_id") -batching.ragged_prop_rules[program_id_p] = batching.ragged_mask_no_op_rule -def program_id(axis: int) -> jax.Array: +def program_id(axis: int) -> jax_typing.Array: """Returns the kernel execution position along the given axis of the grid. - For example, with a 2D `grid` in the kernel execution corresponding to the - grid coordinates `(1, 2)`, - `program_id(axis=0)` returns `1` and `program_id(axis=1)` returns `2`. + For example, with a 2D ``grid`` in the kernel execution corresponding to the + grid coordinates ``(1, 2)``, + ``program_id(axis=0)`` returns ``1`` and ``program_id(axis=1)`` returns ``2``. - The returned value is an array of shape `()` and dtype `int32`. + The returned value is an array of shape ``()`` and dtype ``int32``. Args: axis: the axis of the grid along which to count the program. @@ -88,7 +92,7 @@ def _program_id_abstract_eval(**_): num_programs_p = jax_core.Primitive("num_programs") -def num_programs(axis: int) -> int | jax.Array: +def num_programs(axis: int) -> int | jax_typing.Array: """Returns the size of the grid along the given axis.""" return num_programs_p.bind(axis=axis) @@ -126,10 +130,9 @@ def _atomic_rmw_discharge_rule( in_avals, out_avals, *args_flat, args_tree, atomic_type: AtomicOpType ): del out_avals # Unused. - ref, indexers, val, mask = args_tree.unflatten(args_flat) - if len(indexers) > 1: - raise NotImplementedError("Only one indexer is supported.") - idx = indexers[0] + ref, transforms, val, mask = args_tree.unflatten(args_flat) + *prev_transforms, idx = transforms + ref = state_discharge.transform_array(ref, prev_transforms) if mask is not None: raise NotImplementedError @@ -145,7 +148,7 @@ def _atomic_rmw_discharge_rule( if all((isinstance(s, Slice) or not s.shape) for s in idx.indices): indices = idx.indices - scalar_dims = [not isinstance(s, Slice) and s.shape == () for s in indices] + scalar_dims = [not isinstance(s, Slice) and not s.shape for s in indices] slice_starts = [s.start if isinstance(s, Slice) else s for s in indices] slice_sizes = tuple(s.size if isinstance(s, Slice) else 1 for s in indices) out_ones = lax.dynamic_slice(ref, slice_starts, slice_sizes=slice_sizes) @@ -345,9 +348,11 @@ def _atomic_cas_discharge_rule(in_avals, out_avals, ref, cmp, val): mlir.register_lowering(max_contiguous_p, lambda _, x, **__: [x]) def max_contiguous(x, values): - if not isinstance(values, list): - values = [values] - return max_contiguous_p.bind(x, values=values) + """A compiler hint that asserts the ``values`` first values of ``x`` are contiguous. + """ + if not isinstance(values, (list, tuple)): + values = (values,) + return max_contiguous_p.bind(x, values=tuple(values)) @max_contiguous_p.def_abstract_eval def _max_contiguous_abstract_eval(aval, **_): @@ -358,9 +363,20 @@ def _max_contiguous_abstract_eval(aval, **_): multiple_of_p.def_impl(lambda x, **_: x) mlir.register_lowering(multiple_of_p, lambda _, x, **__: [x]) -def multiple_of(x: jax.Array, values: list[int] | int) -> jax.Array: - if not isinstance(values, list): - values = [values] +def multiple_of(x: jax_typing.Array, values: Sequence[int] | int) -> jax_typing.Array: + """A compiler hint that asserts a value is a static multiple of another. + + Note that misusing this function, such as asserting ``x`` is a multiple of + ``N`` when it is not, can result in undefined behavior. + + Args: + x: The input array. + values: A set of static divisors that ``x`` is a multiple of. + + Returns: + A copy of ``x``. + """ + values = (values,) if isinstance(values, int) else tuple(values) return multiple_of_p.bind(x, values=values) @multiple_of_p.def_abstract_eval @@ -372,9 +388,11 @@ def _multiple_of_abstract_eval(aval, **_): @load_p.def_effectful_abstract_eval def _load_abstract_eval(*avals_flat, args_tree, **_): - ref, indexers, _, _ = args_tree.unflatten(avals_flat) + ref, transforms, _, _ = args_tree.unflatten(avals_flat) + assert transforms is not None + transformed_ref = pallas_core.TransformedRef(ref, transforms) return ( - jax_core.ShapedArray(indexers[-1].get_indexer_shape(), ref.dtype), + jax_core.ShapedArray(transformed_ref.shape, transformed_ref.dtype), {state.ReadEffect(0)}, ) @@ -382,15 +400,12 @@ def _load_abstract_eval(*avals_flat, args_tree, **_): def _load_pp_rule(eqn, context, settings): # Pretty prints `a = load x i` as `x[i] <- a` y, = eqn.outvars - x, indexers, mask, other = tree_util.tree_unflatten(eqn.params["args_tree"], - eqn.invars) + x, transforms, mask, other = tree_util.tree_unflatten( + eqn.params["args_tree"], eqn.invars + ) # TODO(sharadmv): pretty print mask and other lhs = jax_core.pp_vars([y], context, print_shapes=settings.print_shapes) - result = [ - lhs, - pp.text(' <- '), - sp.pp_ref_transforms(context, x, indexers) - ] + result = [lhs, pp.text(" <- "), sp.pp_ref_transforms(context, x, transforms)] if mask is not None: result += [ pp.text(" "), @@ -408,18 +423,20 @@ def _load_pp_rule(eqn, context, settings): def _load_jvp(primals, tangents, args_tree, **params): - ref_primal, indexers, mask, other_primal = args_tree.unflatten(primals) + ref_primal, transforms, mask, other_primal = args_tree.unflatten(primals) ref_tangent, _, _, other_tangent = args_tree.unflatten(tangents) if other_tangent is not None: other_tangent = ad_util.instantiate(other_tangent) return ( load_p.bind( - *tree_util.tree_leaves((ref_primal, indexers, mask, other_primal)), + *tree_util.tree_leaves((ref_primal, transforms, mask, other_primal)), args_tree=args_tree, **params, ), load_p.bind( - *tree_util.tree_leaves((ref_tangent, indexers, mask, other_tangent)), + *tree_util.tree_leaves( + (ref_tangent, transforms, mask, other_tangent) + ), args_tree=args_tree, **params, ), @@ -468,18 +485,22 @@ def _pad_values_to_avoid_dynamic_slice_oob_shift(value, padding_value=padding_value) return value -_unpad_values_to_avoid_dynamic_slice_oob_shift = partial( - _pad_values_to_avoid_dynamic_slice_oob_shift, unpad=True) +_unpad_values_to_avoid_dynamic_slice_oob_shift = functools.partial( + _pad_values_to_avoid_dynamic_slice_oob_shift, unpad=True +) @state_discharge.register_discharge_rule(load_p) def _load_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_): del out_avals # Unused. - ref, indexers, mask, other = args_tree.unflatten(args_flat) - # TODO(sharadmv): add support for multiple indexers - if len(indexers) > 1: - raise NotImplementedError("Only one indexer supported in discharge rule.") - idx = indexers[0] + ref, transforms, mask, other = args_tree.unflatten(args_flat) + transforms = list(transforms) + if not transforms or not isinstance(transforms[-1], indexing.NDIndexer): + ref_shape = state.get_transforms_shape(transforms, in_avals[0].shape) + transforms.append(indexing.NDIndexer.make_trivial_indexer(ref_shape)) + *prev_transforms, idx = transforms + assert isinstance(idx, NDIndexer) + ref = state_discharge.transform_array(ref, prev_transforms) if all((isinstance(s, Slice) or not s.shape) for s in idx.indices): # TODO(ayx): support strided load/store in interpret mode. for s in idx.indices: @@ -489,15 +510,15 @@ def _load_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_): scalar_dims = [not isinstance(s, Slice) and not s.shape for s in indices] slice_starts = [s.start if isinstance(s, Slice) else s for s in indices] slice_sizes = tuple(s.size if isinstance(s, Slice) else 1 for s in indices) - # fixes an inconstency with lax.dynamic_slice where if the slice goes out + # fixes an inconsistency with lax.dynamic_slice where if the slice goes out # of bounds, it will instead move the start_index backwards so the slice # will fit in memory. ref = _pad_values_to_avoid_dynamic_slice_oob_shift(ref, slice_sizes) - idx_dtype = dtypes.canonicalize_dtype(jnp.int64) + idx_dtype = dtypes.default_int_dtype() out_ones = lax.dynamic_slice( - ref, - [jnp.astype(s, idx_dtype) for s in slice_starts], - slice_sizes=slice_sizes, + ref, + [jnp.astype(s, idx_dtype) for s in slice_starts], + slice_sizes=slice_sizes, ) out_indexer = tuple(0 if scalar else slice(None) for scalar in scalar_dims) out = out_ones[out_indexer] @@ -515,20 +536,23 @@ def _load_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_): @swap_p.def_effectful_abstract_eval def _swap_abstract_eval(*avals_flat, args_tree, **_): - ref, indexers, val, _ = args_tree.unflatten(avals_flat) - expected_output_shape = indexers[-1].get_indexer_shape() + ref, transforms, val, mask = args_tree.unflatten(avals_flat) + assert transforms is not None + transformed_ref = pallas_core.TransformedRef(ref, transforms) + expected_output_shape = transformed_ref.shape + expected_output_dtype = transformed_ref.dtype if expected_output_shape != val.shape: raise ValueError( f"Invalid shape for `swap`. Ref shape: {ref.shape}. " - f"Value shape: {val.shape}. Indices: {indexers}. " + f"Value shape: {val.shape}. Transforms: {transforms}. " ) - if ref.dtype != val.dtype: + if expected_output_dtype != val.dtype: raise ValueError( - f"Invalid dtype for `swap`. Ref dtype: {ref.dtype}. " + f"Invalid dtype for `swap`. Ref dtype: {expected_output_dtype}. " f"Value dtype: {val.dtype}. " ) return ( - jax_core.ShapedArray(expected_output_shape, ref.dtype), + jax_core.ShapedArray(expected_output_shape, expected_output_dtype), {state.WriteEffect(0)}, ) @@ -538,8 +562,8 @@ def _swap_pp_rule(eqn, context, settings): # or: # Pretty prints `_ = swap x v i` as `x[i] <- v` y, = eqn.outvars - x, indexers, val, mask = eqn.params["args_tree"].unflatten(eqn.invars) - x_i = sp.pp_ref_transforms(context, x, indexers) + x, transforms, val, mask = eqn.params["args_tree"].unflatten(eqn.invars) + x_i = sp.pp_ref_transforms(context, x, transforms) if isinstance(y, jax_core.DropVar): return pp.concat([ x_i, @@ -565,17 +589,17 @@ def _swap_pp_rule(eqn, context, settings): def _swap_jvp(primals, tangents, *, args_tree, **params): - ref_primal, indexers, val_primal, mask = args_tree.unflatten(primals) + ref_primal, transforms, val_primal, mask = args_tree.unflatten(primals) ref_tangent, _, val_tangent, _ = args_tree.unflatten(tangents) val_tangent = ad_util.instantiate(val_tangent) return ( swap_p.bind( - *tree_util.tree_leaves((ref_primal, indexers, val_primal, mask)), + *tree_util.tree_leaves((ref_primal, transforms, val_primal, mask)), args_tree=args_tree, **params, ), swap_p.bind( - *tree_util.tree_leaves((ref_tangent, indexers, val_tangent, mask)), + *tree_util.tree_leaves((ref_tangent, transforms, val_tangent, mask)), args_tree=args_tree, **params, ), @@ -588,10 +612,14 @@ def _swap_jvp(primals, tangents, *, args_tree, **params): @state_discharge.register_discharge_rule(swap_p) def _swap_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_): del out_avals # Unused. - ref, indexers, val, mask = args_tree.unflatten(args_flat) - if len(indexers) > 1: - raise NotImplementedError("Only one indexer supported in discharge rule.") - idx = indexers[0] + ref, transforms, val, mask = args_tree.unflatten(args_flat) + transforms = list(transforms) + if not transforms or not isinstance(transforms[-1], indexing.NDIndexer): + ref_shape = state.get_transforms_shape(transforms, in_avals[0].shape) + transforms.append(indexing.NDIndexer.make_trivial_indexer(ref_shape)) + *prev_transforms, idx = transforms + assert isinstance(idx, NDIndexer) + ref = state_discharge.transform_array(ref, prev_transforms) if all((isinstance(s, Slice) or not s.shape) for s in idx.indices): # TODO(ayx): support strided load/store in interpret mode. for s in idx.indices: @@ -631,7 +659,7 @@ def _swap_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_): def load(x_ref_or_view, idx, *, mask=None, other=None, cache_modifier=None, - eviction_policy=None, volatile=False) -> jax.Array: + eviction_policy=None, volatile=False) -> jax_typing.Array: """Returns an array loaded from the given index. If neither ``mask`` nor ``other`` is specified, this function has the same @@ -661,7 +689,7 @@ def load(x_ref_or_view, idx, *, mask=None, other=None, cache_modifier=None, ) def swap(x_ref_or_view, idx, val, *, mask=None, eviction_policy=None, - _function_name="swap") -> jax.Array: + _function_name="swap") -> jax_typing.Array: """Swaps the value at the given index and returns the old value. See :func:`~jax.experimental.pallas.load` for the meaning of the arguments. @@ -685,8 +713,36 @@ def store(x_ref_or_view, idx, val, *, mask=None, eviction_policy=None) -> None: _ = swap(x_ref_or_view, idx, val, mask=mask, eviction_policy=eviction_policy, _function_name="store") + +def _handle_small(dtype: jax_typing.DTypeLike): + """Ugly workaround to support types that don't allow automatic promotion.""" + if dtype == jnp.int4: + return jnp.int8 + if dtype == jnp.float8_e4m3b11fnuz: + return jnp.bfloat16 + return dtype + + def dot(a, b, trans_a: bool = False, trans_b: bool = False, allow_tf32: bool | None = None, precision=None): + """Computes the dot product of two arrays. + + The inputs can optionally be transposed before computing the + product. Depending on the hardware, this can be cheaper than + computing the transpose beforehand. + + Args: + a: The left-hand size of the dot product, of shape ``(..., N)``. + b: The right-hand size of the dot product, of shape ``(...N, M)``. + trans_a: Whether to transpose ``a`` before the product. + trans_b: Whether to transpose ``b`` before the product. + allow_tf32: Whether to use tf32 precision. + Mutually exclusive with ``precision``. + precision: Specifies the precision of the dot product. + + See Also: + :func:`jax.numpy.dot` + """ if (a.ndim != 2) or (b.ndim != 2): raise ValueError("`a` and `b` must be 2D arrays.") lhs_contract_dim = 0 if trans_a else 1 @@ -696,15 +752,9 @@ def dot(a, b, trans_a: bool = False, trans_b: bool = False, raise ValueError("Only one of allow_tf32 and precision can be specified") precision = lax.Precision.HIGH if allow_tf32 else lax.Precision.HIGHEST - def _handle_f8(dtype: jax.typing.DTypeLike): - """Ugly workaround to support float8_e4m3b11fnuz in dot.""" - if dtype == jnp.float8_e4m3b11fnuz: - return jnp.bfloat16 - return dtype - - dtype = jnp.promote_types(_handle_f8(a.dtype), _handle_f8(b.dtype)) + dtype = jnp.promote_types(_handle_small(a.dtype), _handle_small(b.dtype)) out_dtype = jnp.int32 if jnp.issubdtype(dtype, jnp.integer) else jnp.float32 - return jax.lax.dot_general( + return lax.dot_general( a, b, dimension_numbers=(((lhs_contract_dim,), (rhs_contract_dim,)), ((), ())), @@ -741,24 +791,7 @@ def _reciprocal(x, *, approx=False): mlir.register_lowering(reciprocal_p, _reciprocal_lowering_rule) -class PrintEffect(effects.Effect): - __str__ = lambda self: "Print" - - -debug_print_effect = PrintEffect() - -# TODO(slebedev): Consider making the effect ordered. -effects.lowerable_effects.add_type(PrintEffect) -effects.control_flow_allowed_effects.add_type(PrintEffect) -effects.remat_allowed_effects.add_type(PrintEffect) -effects.custom_derivatives_allowed_effects.add_type(PrintEffect) - - -debug_print_p = jax_core.Primitive("debug_print") -debug_print_p.multiple_results = True - - -def debug_print(fmt: str, *args: jax.typing.ArrayLike): +def debug_print(fmt: str, *args: jax_typing.ArrayLike): """Prints values from inside a Pallas kernel. Args: @@ -769,7 +802,8 @@ def debug_print(fmt: str, *args: jax.typing.ArrayLike): (``{...}``), since it is always printed before any of the values. * On GPU, when using the experimental Mosaic GPU backend, ``fmt`` must contain a placeholder for each value to be printed. Format specs and - conversions are not supported. All values must be scalars. + conversions are not supported. If a single value is provided, the value + may be an array. Otherwise, all values must be scalars. * On TPU, if all inputs are scalars: If ``fmt`` contains placeholders, all values must be 32-bit integers. If there are no placeholders, the values are printed after the format string. @@ -777,16 +811,12 @@ def debug_print(fmt: str, *args: jax.typing.ArrayLike): the format string. The format string must end with a single placeholder ``{}``. *args: The values to print. - """ # fmt: skip - has_placeholders = False - if fmt: - _, field_name, *_ = next(iter(string.Formatter().parse(fmt))) - has_placeholders = field_name is not None - return debug_print_p.bind(*args, fmt=fmt, has_placeholders=has_placeholders) + """ + return debugging.debug_print(fmt, *args, skip_format_check=True) def check_debug_print_format( - fmt: str, *args: jax.typing.ArrayLike + fmt: str, *args: jax_typing.ArrayLike ): n_placeholders = 0 for _, field, spec, conversion in string.Formatter().parse(fmt): @@ -808,59 +838,6 @@ def check_debug_print_format( ) -@debug_print_p.def_impl -def debug_print_impl(*args: Any, fmt: str, has_placeholders: bool): - if has_placeholders: - print(fmt.format(*args)) - else: - print(fmt, *args) - return () - - -@debug_print_p.def_effectful_abstract_eval -def debug_print_abstract_eval(*avals: Any, fmt: str, has_placeholders: bool): - del avals, fmt, has_placeholders # Unused. - return [], {debug_print_effect} - - -def debug_print_batching_rule(args, dims, **params): - """Unrolls the print primitive across the mapped axis.""" - axis_size = next(x.shape[i] for x, i in zip(args, dims) if i is not None) - - # TODO(sharadmv): implement in terms of rolled loop unstead of unrolled. - def get_arg_at_dim(i, dim, arg): - if dim is batching.not_mapped: - # Broadcast unmapped argument - return arg - return lax.index_in_dim(arg, i, axis=dim, keepdims=False) - - outs = [] - for i in range(axis_size): - args_idx = map(functools.partial(get_arg_at_dim, i), dims, args) - outs.append(debug_print_p.bind(*args_idx, **params)) - outs = [jnp.stack(xs) for xs in zip(*outs)] - return outs, (0,) * len(outs) - - -batching.primitive_batchers[debug_print_p] = functools.partial( - debug_print_batching_rule, debug_print_p -) - - -@functools.partial(mlir.register_lowering, debug_print_p) -def debug_print_lowering_rule(ctx, *args, **params): - result, _, _ = callback.emit_python_callback( - ctx, - functools.partial(debug_print_p.impl, **params), - None, - list(args), - ctx.avals_in, - ctx.avals_out, - has_side_effect=True, - ) - return result - - # All of those shenanigans are because we can't make TransformedRef a PyTree, # because they should appear as atomic JAX values to the users. # TODO(apaszke): This can be deleted once we make transforms in Mosaic GPU @@ -877,14 +854,37 @@ def wrap_with_transforms(f, transforms, *args): run_scoped_p = jax_core.Primitive("run_scoped") run_scoped_p.multiple_results = True - -def run_scoped(f: Callable[..., Any], *types: Any, **kw_types: Any) -> Any: +def _run_scoped_is_high(*avals, jaxpr, **params): + del avals, params + return jaxpr.is_high +run_scoped_p.is_high = _run_scoped_is_high # type: ignore[method-assign] + +def _run_scoped_to_lojax(*args, jaxpr, **params): + closed_hi_jaxpr = jax_core.ClosedJaxpr(jaxpr, args) + closed_lo_jaxpr = pe.lower_jaxpr(closed_hi_jaxpr) + consts = closed_lo_jaxpr.consts + return run_scoped_p.bind(*consts, jaxpr=closed_lo_jaxpr.jaxpr, **params) +run_scoped_p.to_lojax = _run_scoped_to_lojax + +def run_scoped( + f: Callable[..., Any], + *types: Any, + collective_axes: Hashable | tuple[Hashable, ...] = (), + **kw_types: Any, +) -> Any: """Calls the function with allocated references and returns the result. The positional and keyword arguments describe which reference types to allocate for each argument. Each backend has its own set of reference types in addition to :class:`jax.experimental.pallas.MemoryRef`. + + When ``collective_axes`` is specified, the same allocation will be returned for + all programs that only differ in their program ids along the collective axes. + It is an error not to call the same ``run_scoped`` in all programs along that + axis. """ + if not isinstance(collective_axes, tuple): + collective_axes = (collective_axes,) flat_types, in_tree = tree_util.tree_flatten((types, kw_types)) flat_fun, out_tree_thunk = api_util.flatten_fun( lu.wrap_init(f, @@ -907,14 +907,15 @@ def run_scoped(f: Callable[..., Any], *types: Any, **kw_types: Any) -> Any: # parent scope). Jax can't reason about effects to references that # are not in the invars of an operation so we just put them all # there. - jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, avals) - out = run_scoped_p.bind(*consts, jaxpr=jaxpr) + with config.mutable_array_checks(False): + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, avals) + out = run_scoped_p.bind(*consts, jaxpr=jaxpr, collective_axes=collective_axes) return tree_util.tree_unflatten(out_tree_thunk(), out) @run_scoped_p.def_effectful_abstract_eval -def _run_scoped_abstract_eval(*args, jaxpr): - del args +def _run_scoped_abstract_eval(*args, jaxpr, collective_axes): + del args, collective_axes # jaxpr will have effects for its inputs (Refs that are allocated) and for # constvars (closed over Refs). The effects for the allocated Refs are local # to the jaxpr and shouldn't propagate out. @@ -935,8 +936,12 @@ def _run_scoped_discharge_rule( out_avals, *args_flat, jaxpr, - **_): + collective_axes): del out_avals + if collective_axes: + raise NotImplementedError( + "run_scoped discharge does not support collective_axes yet." + ) num_consts = len(args_flat) # discharge_state only discharges invars, not consts, so in order to # discharge the requested refs we need to move them to the invar set. @@ -956,7 +961,9 @@ def _run_scoped_discharge_rule( # Run_scoped discharged the external variables but the scoped ones # are not discharged. - out = run_scoped_p.bind(*args_flat, jaxpr=discharged_body) + out = run_scoped_p.bind( + *args_flat, jaxpr=discharged_body, collective_axes=collective_axes + ) # Order of outputs: # (1) return values, (2) closed refs, (3) scoped refs. return_values = out[:num_return_values] @@ -964,7 +971,7 @@ def _run_scoped_discharge_rule( # We update all ref values with their updated values from the discharged # body. For other values we leave them in place. updates = [ - ref_outputs.pop(0) if should and isinstance(aval, pallas_core.AbstractMemoryRef) + ref_outputs.pop(0) if should and isinstance(aval, state.AbstractRef) else None for should, aval in zip(should_discharge, in_avals)] assert len(updates) == len(in_avals), f'{len(updates)} != {len(in_avals)}' return updates, return_values @@ -975,7 +982,12 @@ def _run_scoped_discharge_rule( @functools.partial(mlir.register_lowering, run_scoped_p) -def _run_scoped_lowering_rule(ctx, *args, jaxpr): +def _run_scoped_lowering_rule(ctx, *args, jaxpr, collective_axes): + if collective_axes: + raise ValueError( + "run_scoped lowering outside of Pallas does not support" + " collective_axes." + ) jaxpr_noconst = pe.convert_constvars_jaxpr(jaxpr) num_return_values = len(jaxpr_noconst.outvars) discharged_body, new_consts = state_discharge.discharge_state( @@ -993,3 +1005,489 @@ def _lower_fun(*lower_fun_args): return out[:num_return_values] return mlir.lower_fun(_lower_fun, multiple_results=True)(ctx, *args) + + +get_global_p = jax_core.Primitive("get_global") +get_global_p.multiple_results = False +get_global_p.ref_primitive = True +jax_core._ref_allocating_primitives.add(get_global_p) + +def get_global(what: pallas_core.ScratchShape) -> jax_typing.Array: + """Returns a global reference that persists across all kernel invocations. + + Each call to ``get_global`` returns a different and unique reference, but one that + is stable across invocations of the kernel body. + + Args: + what: The reference type to allocate. Each backend has its own set of + reference types (e.g., :class:`jax.experimental.pallas.mosaic_gpu.SemaphoreType` for GPU). + + Example:: + + sem_ref = pl.get_global(plgpu.SemaphoreType.REGULAR) + pl.semaphore_signal(sem_ref) + pl.semaphore_wait(sem_ref) + """ + ref_aval = what.get_ref_aval() + return get_global_p.bind(what=ref_aval) + + +@get_global_p.def_abstract_eval +def _get_global_abstract_eval(*, what): + return what + + +def _get_global_discharge_rule(in_avals, out_avals, *, what): + del in_avals, out_avals, what + raise NotImplementedError( + "get_global discharge is not supported in interpret mode." + ) + + +state_discharge.register_discharge_rule(get_global_p)( + _get_global_discharge_rule +) + + +def _get_ref_and_transforms(ref): + if isinstance(ref, state.TransformedRef): + return ref.ref, ref.transforms + return ref, () + + +class DeviceIdType(enum.Enum): + MESH = "mesh" + LOGICAL = "logical" + + +def check_sem_avals( + sem_aval, sem_transforms_avals, name, allowed_semaphore_types=None +): + if allowed_semaphore_types is None: + allowed_semaphore_types = { + pallas_core.semaphore, + pallas_core.barrier_semaphore, + # For interpret mode. + pallas_core.SEMAPHORE_INTERPRET_DTYPE, + } + if not isinstance(sem_aval, state.AbstractRef): + raise ValueError(f"Cannot {name} on a non-semaphore Ref: {sem_aval}") + sem_shape = sem_aval.shape + if sem_transforms_avals: + sem_shape = sem_transforms_avals[-1].get_indexer_shape() + if sem_shape: + raise ValueError(f"Cannot {name} on a non-()-shaped semaphore: {sem_shape}") + sem_dtype = sem_aval.dtype + if not any( + jnp.issubdtype(sem_dtype, sem_type) + for sem_type in allowed_semaphore_types + ): + raise ValueError( + f"Must {name} semaphores of the following types:" + f" {allowed_semaphore_types}. Got {sem_dtype}." + ) + + +def _transform_semaphore(ref_value, transforms, ref_aval): + """Helper function for indexing into a semaphore during state_discharge.""" + if ref_value.shape == ref_aval.shape: + return state_discharge.transform_array(ref_value, transforms) + elif len(ref_value.shape) == 0: + return ref_value + else: + raise ValueError( + f"Semaphore value shape {ref_value.shape} does not match aval shape" + f" {ref_aval.shape}" + ) + + +semaphore_read_p = jax_core.Primitive("semaphore_read") +semaphore_read_p.multiple_results = False + + +def semaphore_read(sem_or_view) -> jax_typing.Array: + """Reads the value of a semaphore. + + Args: + sem_or_view: A Ref (or view) representing a semaphore. + + Returns: + A scalar Array containing the value of the semaphore. + """ + ref, transforms = _get_ref_and_transforms(sem_or_view) + args = [ref, transforms] + flat_args, args_tree = tree_util.tree_flatten(args) + return semaphore_read_p.bind(*flat_args, args_tree=args_tree) + +@semaphore_read_p.def_abstract_eval +def _semaphore_read_abstract_eval( + *avals, + args_tree, +): + del avals, args_tree + return jax_core.ShapedArray((), jnp.dtype("int32")) + +def _semaphore_read_discharge_rule(in_avals, + out_avals, + *flat_args, + args_tree): + del out_avals + [ref, transforms] = args_tree.unflatten(flat_args) + sem_value = _transform_semaphore(ref, transforms, in_avals[0]) + sem_value = sem_value.astype(jnp.int32) + return (None,) * len(in_avals), sem_value +state_discharge.register_discharge_rule(semaphore_read_p)( + _semaphore_read_discharge_rule +) + + +DeviceId = ( + int + | jax_typing.Array + | None + | tuple[int | jax_typing.Array, ...] + | dict[Any, int | jax_typing.Array] +) + +class SemaphoreEffect(effects.Effect): + pass +sem_effect = SemaphoreEffect() +effects.control_flow_allowed_effects.add_type(SemaphoreEffect) +effects.custom_derivatives_allowed_effects.add_type(SemaphoreEffect) +pallas_core.kernel_local_effects.add_type(SemaphoreEffect) + + +semaphore_signal_p = jax_core.Primitive('semaphore_signal') +semaphore_signal_p.multiple_results = True + + +def semaphore_signal( + sem_or_view, + inc: int | jax_typing.Array = 1, + *, + device_id: DeviceId = None, + device_id_type: DeviceIdType = DeviceIdType.MESH, + core_index: int | jax_typing.Array | None = None, +): + """Increments the value of a semaphore. + + This operation can also be performed remotely if ``device_id`` is specified, + in which ``sem_or_view`` refers to a Ref located on another device. + Note that it is assumed that ``sem_or_view`` is already allocated + (e.g. through the proper use of barriers), or else this operation could + result in undefined behavior. + + Args: + sem_or_view: A Ref (or view) representing a semaphore. + inc: The value to increment by. + device_id (optional): Specifies which device to signal. + If not specified, ``sem_or_view`` is assumed to be local. + device_id_type (optional): The format in which + ``device_id`` should be specified. + core_index (optional): If on a multi-core device, + specifies which core to signal. + """ + ref, transforms = _get_ref_and_transforms(sem_or_view) + inc = jnp.asarray(inc, dtype=jnp.int32) + args = [ref, transforms, inc, device_id, core_index] + flat_args, args_tree = tree_util.tree_flatten(args) + semaphore_signal_p.bind( + *flat_args, + args_tree=args_tree, + device_id_type=device_id_type, + ) + + +@semaphore_signal_p.def_effectful_abstract_eval +def _semaphore_signal_abstract_eval( + *avals, + args_tree, + device_id_type: DeviceIdType, +): + ( + sem_aval, + sem_transforms_avals, + value_aval, + device_id_aval, + core_index_aval, + ) = tree_util.tree_unflatten(args_tree, avals) + check_sem_avals(sem_aval, sem_transforms_avals, "signal") + if value_aval.dtype != jnp.dtype("int32"): + raise ValueError(f"Must signal an int32 value, but got {value_aval.dtype}") + effs: set[effects.Effect] = {sem_effect} + if device_id_aval is not None: + device_id_flat_avals = tree_util.tree_leaves(device_id_aval) + for aval in device_id_flat_avals: + if aval.dtype != jnp.dtype("int32"): + raise ValueError( + f"`device_id`s must be an int32 value, but got {aval.dtype}" + ) + if device_id_type is DeviceIdType.MESH and isinstance(device_id_aval, dict): + for k in device_id_aval: + if not isinstance(k, tuple): + k = (k,) + for k_ in k: + effs.add(jax_core.NamedAxisEffect(k_)) + else: + effs.add(pallas_core.comms_effect) + return [], effs + +def _semaphore_signal_pp_eqn(eqn: jax_core.JaxprEqn, + context: jax_core.JaxprPpContext, + settings: jax_core.JaxprPpSettings): + del settings + invars = eqn.invars + tree = eqn.params["args_tree"] + ( + sem, + sem_transforms, + value, + device_ids, + _, + ) = tree_util.tree_unflatten(tree, invars) + out = pp.concat([ + pp.text("semaphore_signal"), + pp.text(" "), + sp.pp_ref_transforms(context, sem, sem_transforms), + pp.text(" "), + pp.text(jax_core.pp_var(value, context)), + ]) + if device_ids is not None: + flat_device_ids = tree_util.tree_leaves(device_ids) + if not flat_device_ids: + return out + device_ids_pp = [pp.text(jax_core.pp_var(flat_device_ids[0], context))] + for device_id in flat_device_ids[1:]: + device_ids_pp.append(pp.text(" ")) + device_ids_pp.append(pp.text(jax_core.pp_var(device_id, context))) + out = pp.concat([out, pp.concat(device_ids_pp)]) + return out +jax_core.pp_eqn_rules[semaphore_signal_p] = _semaphore_signal_pp_eqn + + +def _semaphore_signal_discharge_rule(in_avals, + out_avals, + *flat_args, + args_tree, + device_id_type): + del out_avals, device_id_type + [ref, transforms, inc, device_id, core_index] = args_tree.unflatten(flat_args) + if device_id is not None: + raise NotImplementedError("Remote signal not implemented.") + if core_index is not None: + raise NotImplementedError("Multiple core support not implemented.") + sem_value = _transform_semaphore(ref, transforms, in_avals[0]) + inc = inc.astype(pallas_core.SEMAPHORE_INTERPRET_DTYPE) + _, new_sem_value = state_discharge.transform_swap_array( + ref, transforms, sem_value + inc + ) + return (new_sem_value,) + (None,) * (len(in_avals) - 1), () +state_discharge.register_discharge_rule(semaphore_signal_p)( + _semaphore_signal_discharge_rule +) + + +semaphore_wait_p = jax_core.Primitive('semaphore_wait') +semaphore_wait_p.multiple_results = True + + +def semaphore_wait( + sem_or_view, value: int | jax_typing.Array = 1, *, decrement: bool = True +): + """Blocks execution of the current thread until a semaphore reaches a value. + + Args: + sem_or_view: A Ref (or view) representing a semaphore. + value: The target value that the semaphore should reach before unblocking. + decrement: Whether to decrement the value of the semaphore after + a successful wait. + """ + ref, transforms = _get_ref_and_transforms(sem_or_view) + value = jnp.asarray(value, dtype=jnp.int32) + args = [ref, transforms, value, decrement] + flat_args, args_tree = tree_util.tree_flatten(args) + semaphore_wait_p.bind(*flat_args, args_tree=args_tree) + +@semaphore_wait_p.def_effectful_abstract_eval +def _semaphore_wait_abstract_eval(*avals, args_tree): + sem_aval, sem_transforms_avals, value_aval, _ = tree_util.tree_unflatten( + args_tree, avals + ) + check_sem_avals(sem_aval, sem_transforms_avals, "wait") + if value_aval.dtype != jnp.dtype("int32"): + raise ValueError("Must wait an int32 value.") + return [], {sem_effect} + +def _semaphore_wait_pp_eqn(eqn: jax_core.JaxprEqn, + context: jax_core.JaxprPpContext, + settings: jax_core.JaxprPpSettings): + del settings + invars = eqn.invars + tree = eqn.params["args_tree"] + ( + sem, + sem_transforms, + value, + decrement, + ) = tree_util.tree_unflatten(tree, invars) + parts = [ + pp.text("semaphore_wait"), + ] + if decrement: + parts.append(pp.text("[dec]")) + parts += [ + pp.text(" "), + sp.pp_ref_transforms(context, sem, sem_transforms), + pp.text(" "), + pp.text(jax_core.pp_var(value, context)), + ] + return pp.concat(parts) +jax_core.pp_eqn_rules[semaphore_wait_p] = _semaphore_wait_pp_eqn + +def _semaphore_wait_discharge_rule(in_avals, + out_avals, + *flat_args, + args_tree): + del out_avals + [ref, transforms, value, decrement] = args_tree.unflatten(flat_args) + sem_value = _transform_semaphore(ref, transforms, in_avals[0]) + value = value.astype(pallas_core.SEMAPHORE_INTERPRET_DTYPE) + if decrement: + _, new_sem_value = state_discharge.transform_swap_array( + ref, transforms, sem_value - value + ) + else: + new_sem_value = sem_value + return (new_sem_value,) + (None,) * (len(in_avals) - 1), () +state_discharge.register_discharge_rule(semaphore_wait_p)( + _semaphore_wait_discharge_rule +) + + +def _device_id_dict_to_mesh(mesh_context: pallas_utils.MeshInfo | None, device_id_dict, get_axis_index): + i32 = ir.IntegerType.get_signless(32) + if mesh_context is None: + mesh_axis_sizes = {} + else: + mesh_axis_sizes = dict( + zip(mesh_context.axis_names, mesh_context.mesh_shape) + ) + physical_axis_dict = {} + # Handle joint axes (i.e., one logical axis over >1 physical axes) + for axis_name, idx in device_id_dict.items(): + if isinstance(axis_name, tuple) and any( + a in mesh_axis_sizes for a in axis_name + ): + if not all(a in mesh_axis_sizes for a in axis_name): + raise NotImplementedError( + f"{axis_name} mixes JAX mesh and Pallas mesh grid axes" + ) + axes_dimensions = [mesh_axis_sizes[name] for name in axis_name] + for axis_index, axis_name in enumerate(axis_name): + axis_size = mesh_axis_sizes[axis_name] + inner_mesh_size = math.prod(axes_dimensions[axis_index + 1 :]) + minor_divisor = arith.constant(i32, inner_mesh_size) + + # Fast path for power of 2s + if inner_mesh_size & (inner_mesh_size - 1) == 0: + shift_len = (inner_mesh_size & -inner_mesh_size).bit_length() - 1 + partial_device_idx = arith.shrui(idx, arith.constant(i32, shift_len)) + else: + partial_device_idx = arith.divsi(idx, minor_divisor) + + if axis_size & (axis_size - 1) == 0: + device_idx = arith.andi( + partial_device_idx, + arith.constant(i32, mesh_axis_sizes[axis_name] - 1), + ) + else: + device_idx = arith.remsi( + partial_device_idx, arith.constant(i32, axis_size) + ) + physical_axis_dict[axis_name] = device_idx + else: + physical_axis_dict[axis_name] = idx + device_id = [] + for axis_name in mesh_axis_sizes: + if axis_name in physical_axis_dict: + device_id.append(physical_axis_dict[axis_name]) + else: + device_id.append(get_axis_index(axis_name)) + non_mesh_axes = { + k: v + for k, v in physical_axis_dict.items() + if k not in mesh_axis_sizes + } + return tuple(device_id), non_mesh_axes + + +def device_id_to_logical( + mesh_context: pallas_utils.MeshInfo | None, + device_id: ir.Value | tuple[ir.Value, ...] | dict[Any, ir.Value], + device_id_type: DeviceIdType, + get_axis_index, +) -> tuple[ir.Value | None, dict[Any, ir.Value]]: + """Normalizes a device id into a logical device id and axes that don't correspond to JAX mesh axes. + + The indexing implied by the returned axis dict should be handled by the + caller. If there are no cross-device operations, then the returned logical + device id will be None. + """ + non_mesh_axes = {} + if isinstance(device_id, dict): + if device_id_type is not DeviceIdType.MESH: + raise ValueError( + "`device_id_type` must be MESH if `device_id` is a dict," + f" got: {device_id_type = }." + ) + device_id, non_mesh_axes = _device_id_dict_to_mesh(mesh_context, device_id, get_axis_index) + if device_id_type is DeviceIdType.MESH: + # Mesh means we are passed the mesh coordinates for the device + device_ids = tree_util.tree_leaves(device_id) + mesh_strides: tuple[int, ...] + if mesh_context is None: + mesh_strides = () + else: + mesh_strides = mesh_context.mesh_strides + if len(device_ids) != len(mesh_strides): + raise ValueError( + "Number of device ids must match the number of mesh axes, but got" + f" {len(device_ids)} ids for a {len(mesh_strides)}D mesh." + ) + + i32 = ir.IntegerType.get_signless(32) + if not device_ids: + # If there are no device ids, then it is purely local communication. + return None, non_mesh_axes + return functools.reduce( + arith.addi, + ( + arith.muli(a, arith.constant(i32, b)) + for a, b in zip(device_ids, mesh_strides) + ), + ), non_mesh_axes + elif device_id_type is DeviceIdType.LOGICAL: + return device_id, non_mesh_axes + raise NotImplementedError(f"Unsupported device id type: {device_id_type}") + + +delay_p = jax_core.Primitive("delay") +delay_p.multiple_results = True + + +class DelayEffect(effects.Effect): + pass +delay_effect = DelayEffect() +effects.control_flow_allowed_effects.add_type(DelayEffect) +pallas_core.kernel_local_effects.add_type(DelayEffect) + + +@delay_p.def_effectful_abstract_eval +def _delay_abstract_eval(nanos): + del nanos + return [], {delay_effect} + + +def delay(nanos: int | jax_typing.Array) -> None: + """Sleeps for the given number of nanoseconds.""" + delay_p.bind(nanos) diff --git a/jax/_src/pallas/triton/BUILD b/jax/_src/pallas/triton/BUILD index cde2aadd6013..94c898bdb9f9 100644 --- a/jax/_src/pallas/triton/BUILD +++ b/jax/_src/pallas/triton/BUILD @@ -42,13 +42,14 @@ pytype_strict_library( deps = [ ":lowering", "//jax", - "//jax:ad_util", - "//jax:api_util", - "//jax:core", - "//jax:mlir", - "//jax:partial_eval", - "//jax:source_info_util", - "//jax:util", + "//jax/_src:ad_util", + "//jax/_src:api_util", + "//jax/_src:core", + "//jax/_src:lax", + "//jax/_src:mlir", + "//jax/_src:partial_eval", + "//jax/_src:source_info_util", + "//jax/_src:util", "//jax/_src/lib", "//jax/_src/pallas", ] + py_deps("numpy"), @@ -59,14 +60,20 @@ pytype_strict_library( srcs = ["lowering.py"], deps = [ "//jax", - "//jax:ad_util", - "//jax:api_util", - "//jax:config", - "//jax:core", - "//jax:mlir", - "//jax:partial_eval", - "//jax:source_info_util", - "//jax:util", + "//jax/_src:ad_util", + "//jax/_src:api", + "//jax/_src:api_util", + "//jax/_src:config", + "//jax/_src:core", + "//jax/_src:custom_derivatives", + "//jax/_src:debugging", + "//jax/_src:lax", + "//jax/_src:literals", + "//jax/_src:mlir", + "//jax/_src:partial_eval", + "//jax/_src:source_info_util", + "//jax/_src:state_types", + "//jax/_src:util", "//jax/_src/lib", "//jax/_src/pallas", ] + py_deps("numpy"), @@ -76,12 +83,12 @@ pytype_strict_library( name = "pallas_call_registration", srcs = ["pallas_call_registration.py"], deps = [ + ":core", ":lowering", "//jax", - "//jax:config", - "//jax:core", - "//jax:mlir", - "//jax:util", + "//jax/_src:core", + "//jax/_src:frozen_dict", + "//jax/_src:mlir", "//jax/_src/lib", "//jax/_src/pallas", ], diff --git a/jax/_src/pallas/triton/core.py b/jax/_src/pallas/triton/core.py index 097f8497e8f7..45c456aad690 100644 --- a/jax/_src/pallas/triton/core.py +++ b/jax/_src/pallas/triton/core.py @@ -21,7 +21,7 @@ from jax._src.pallas import core as pallas_core @dataclasses.dataclass(frozen=True) -class TritonCompilerParams(pallas_core.CompilerParams): +class CompilerParams(pallas_core.CompilerParams): """Compiler parameters for Triton. Attributes: @@ -29,10 +29,7 @@ class TritonCompilerParams(pallas_core.CompilerParams): 32 threads. num_stages: The number of stages the compiler should use for software pipelining loops. - serialized_metadata: Additional compiler metadata. This field is unstable - and may be removed in the future. """ - PLATFORM: ClassVar[str] = "triton" + BACKEND: ClassVar[pallas_core.Backend] = "triton" num_warps: int | None = None num_stages: int | None = None - serialized_metadata: bytes | None = None diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index f3a8dd175ec1..23035ec8a056 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -17,11 +17,11 @@ from __future__ import annotations from collections.abc import Callable, Sequence +from collections.abc import Hashable import dataclasses import functools import math -import operator -from typing import Any, Hashable, TypeVar +from typing import Any, TypeVar import jax from jax import lax @@ -32,6 +32,8 @@ from jax._src import config from jax._src import core as jax_core from jax._src import custom_derivatives +from jax._src import debugging +from jax._src import literals from jax._src import linear_util as lu from jax._src import pjit from jax._src import source_info_util @@ -39,23 +41,17 @@ from jax._src import util from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe -from jax._src.lax.control_flow import for_loop -from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith as arith_dialect from jax._src.lib.mlir.dialects import math as math_dialect from jax._src.lib.mlir.dialects import scf as scf_dialect from jax._src.lib.triton import dialect as tt_dialect from jax._src.pallas import core as pallas_core -from jax._src.pallas import pallas_call from jax._src.pallas import primitives from jax._src.pallas import utils as pallas_utils -from jax._src.state import discharge from jax._src.state import indexing from jax._src.state import primitives as sp from jax._src.util import foreach -from jax._src.util import merge_lists -from jax._src.util import partition_list from jax._src.util import split_list import jax.numpy as jnp import numpy as np @@ -87,10 +83,10 @@ class ModuleContext: @dataclasses.dataclass class BlockInfo: - full_shape_dtype: jax.ShapeDtypeStruct + full_shape_dtype: jax_core.ShapedArray start_indices: Sequence[Any] start_indices_alignment: Sequence[int] - block_shape: tuple[int | pallas_core.Mapped, ...] + block_shape: tuple[int | pallas_core.Squeezed, ...] @dataclasses.dataclass @@ -121,36 +117,49 @@ def _eval_index_map( block_indices = lower_jaxpr_to_triton_ir( ctx, block_mapping.index_map_jaxpr.jaxpr, None, *idx ) - block_indices = ( + block_indices = tuple( _ensure_ir_value(i, jax_core.ShapedArray((), jnp.int32)) for i in block_indices ) - if isinstance(block_mapping.indexing_mode, pallas_core.Unblocked): - if block_mapping.indexing_mode.padding is not None: - raise NotImplementedError( - "Unblocked indexing with padding is not supported in Triton lowering." - ) - if block_mapping.pipeline_mode is not None: - raise NotImplementedError( - "Pipeline mode is not supported in Triton lowering." - ) - return tuple(block_indices) + block_indices = tree_util.tree_unflatten( + block_mapping.index_map_out_tree, block_indices) + if block_mapping.pipeline_mode is not None: + raise NotImplementedError( + "Pipeline mode is not supported in Triton lowering." + ) + if any( + isinstance(b, pallas_core.Element) and b.padding != (0, 0) + for b in block_mapping.block_shape + ): + raise NotImplementedError( + "Unblocked indexing with padding is not supported in Triton lowering." + ) + def _get_start_index(i, b): + match b: + case pallas_core.Squeezed() | pallas_core.Element(): + return i + case pallas_core.Blocked(): + return _mul(i, _ir_constant(b.block_size, i.type)) + case _: + raise ValueError(f"Unsupported block dim type: {type(b)}") return tuple( - i if b is pallas_core.mapped else _mul(i, _ir_constant(b, i.type)) - for i, b in zip(block_indices, block_mapping.block_shape) + _get_start_index(i, b) for i, b in + zip(block_indices, block_mapping.block_shape) ) def _get_index_alignment(block_mapping: BlockMapping) -> tuple[int, ...]: - if isinstance(block_mapping.indexing_mode, pallas_core.Unblocked): - return (1,) * len(block_mapping.block_shape) - return tuple( - 1 if b is pallas_core.mapped else b for b in block_mapping.block_shape - ) + def _get_bdim_alignment(b: pallas_core.BlockDim): + match b: + case pallas_core.Squeezed() | pallas_core.Element(): + return 1 + case pallas_core.Blocked(): + return b.block_size + return tuple(_get_bdim_alignment(b) for b in block_mapping.block_shape) def _bcast_to(a: ir.Value, shape: tuple[int, ...]) -> ir.Value: - if not ir.RankedTensorType.isinstance(a.type): + if not isinstance(a.type, ir.RankedTensorType): if not shape: return a return tt_dialect.splat(ir.RankedTensorType.get(shape, a.type), a) @@ -174,12 +183,16 @@ def _bcast( y_aval: jax_core.ShapedArray, out_aval: jax_core.ShapedArray, ) -> ir.Value: - if isinstance(x, (np.ndarray, np.number, int, float)): + if isinstance( + x, (np.ndarray, np.number, int, float, literals.TypedNdArray) + ): x_dtype = x_aval.dtype if x_aval.weak_type: x_dtype = y_aval.dtype x = _ir_constant(x, _dtype_to_ir_type(x_dtype)) - if isinstance(y, (np.ndarray, np.number, int, float)): + if isinstance( + y, (np.ndarray, np.number, int, float, literals.TypedNdArray) + ): y_dtype = y_aval.dtype if y_aval.weak_type: y_dtype = x_aval.dtype @@ -274,8 +287,9 @@ def _new_ir_context() -> ir.Context: # this). This check is only needed to obtain a nicer error message; the # Triton lowering will fail anyway but it will crash with a C++ exception. # We currently apply this check only to load/store operations. -def _check_tensor_size(shape: tuple[int | pallas_core.Mapped, ...]): - size = math.prod(1 if d is pallas_core.mapped else d for d in shape) +def _check_tensor_size(shape: tuple[int | pallas_core.Squeezed, ...]): + size = math.prod(1 if isinstance(d, pallas_core.Squeezed) else d + for d in shape) power_of_2 = (size & (size - 1)) == 0 if not power_of_2: raise ValueError( @@ -344,10 +358,12 @@ def lower_jaxpr_to_triton_module( ) block_infos = [ BlockInfo( - block_mapping.array_shape_dtype, + block_mapping.array_aval, _eval_index_map(ctx, program_ids, block_mapping), _get_index_alignment(block_mapping), - block_mapping.block_shape, + tuple(pallas_core.squeezed if isinstance(b, pallas_core.Squeezed) + else pallas_core._get_block_dim_size(b) + for b in block_mapping.block_shape), ) for block_mapping in grid_mapping.block_mappings ] @@ -394,7 +410,9 @@ def write_env(var: jax_core.Var, val): avals_in = [v.aval for v in eqn.invars] avals_out = [v.aval for v in eqn.outvars] eqn_block_infos = map(read_block_info_env, eqn.invars) - loc = mlir._source_info_to_location(ctx, eqn.primitive, eqn.source_info) + loc = mlir.source_info_to_location( + ctx, eqn.primitive, eqn.source_info.name_stack, + eqn.source_info.traceback) rule_ctx = LoweringRuleContext(ctx, avals_in, avals_out, eqn_block_infos) try: with source_info_util.user_context(eqn.source_info.traceback), loc: @@ -402,7 +420,7 @@ def write_env(var: jax_core.Var, val): except LoweringError: raise # We only add the extra info to the innermost exception. except Exception as e: - if not pallas_call._verbose_errors_enabled(): + if not config.jax_pallas_verbose_errors.value: raise inval_types = map(lambda t: getattr(t, "type", None), invals) raise LoweringError( @@ -428,7 +446,7 @@ def f_lowered(ctx: LoweringRuleContext, *args, **params): fn, params, debug_info=api_util.debug_info("pallas triton lower_fun", fun, args, params)) - jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in) + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in) jaxpr = jax_core.ClosedJaxpr(jaxpr, consts) out = _closed_call_lowering_rule(ctx, *args, call_jaxpr=jaxpr) return out if multiple_results else out[0] @@ -467,7 +485,7 @@ def _atomic_rmw( semantic: tt_dialect.MemSemantic = tt_dialect.MemSemantic.ACQUIRE_RELEASE, sync_scope: tt_dialect.MemSyncScope = tt_dialect.MemSyncScope.GPU, ) -> ir.Value: - if ir.RankedTensorType.isinstance(ptr.type): + if isinstance(ptr.type, ir.RankedTensorType): ptr_type = ir.RankedTensorType(ptr.type) element_type = tt_dialect.PointerType(ptr_type.element_type) result_type = ir.RankedTensorType.get( @@ -491,6 +509,10 @@ def _atomic_lowering_rule( assert block_info is not None ptr, indexers, val, mask = args_tree.unflatten(args_flat) *_, value_aval, mask_aval = args_tree.unflatten(ctx.avals_in) + indexers = list(indexers) + if not indexers or not isinstance(indexers[-1], indexing.NDIndexer): + ref_shape = state.get_transforms_shape(indexers, ctx.avals_in[0].shape) + indexers.append(NDIndexer.make_trivial_indexer(ref_shape)) if len(indexers) != 1: raise NotImplementedError("Only single indexer is supported.") idx = indexers[0] @@ -523,7 +545,7 @@ def _atomic_lowering_rule( @register_lowering(primitives.atomic_cas_p) def _atomic_cas_lowering_rule(ctx: LoweringRuleContext, ptr, cmp, val): _, cmp_aval, val_aval = ctx.avals_in - if ir.RankedTensorType.isinstance(ptr.type): + if isinstance(ptr.type, ir.RankedTensorType): ptr_type = ir.RankedTensorType(ptr.type) element_type = tt_dialect.PointerType(ptr_type.element_type) result_type = ir.RankedTensorType.get( @@ -557,7 +579,7 @@ def _associative_scan_lowering(body, ctx: LoweringRuleContext, args, axes): body, (args, args), {})), in_tree ) - combine_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( + combine_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic( flat_fun, in_avals ) out_tree = out_tree_thunk() @@ -654,7 +676,9 @@ def _make_dispatch_table( name: str, **tables: Sequence[_Extern | _Fallback] ) -> Callable[..., ir.Value]: - def inner(ctx: LoweringRuleContext, *args: ir.Value) -> ir.Value: + def inner( + ctx: LoweringRuleContext, *args: ir.Value, **_ + ) -> ir.Value: table = tables[ctx.context.platform] h = next((e for e in table if e.matches(ctx.avals_in)), None) if h is None: @@ -1117,20 +1141,26 @@ def inner(ctx: LoweringRuleContext, *args: ir.Value) -> ir.Value: }) +def _is_triton_pointer_type(t): + if hasattr(tt_dialect.PointerType, "isinstance"): + return tt_dialect.PointerType.isinstance(t) + return isinstance(t, tt_dialect.PointerType) + + def _minus(x: ir.Value) -> ir.Value: - if tt_dialect.PointerType.isinstance(_element_type(x.type)): + if _is_triton_pointer_type(x.type): raise NotImplementedError(f"unsupported type: {x.type}") - return _sub(_full(x.type, 0), x) + return _sub(_zeros_like(x), x) def _add(x: ir.Value, y: ir.Value): x_element_type = _element_type(x.type) y_element_type = _element_type(y.type) - if tt_dialect.PointerType.isinstance(x_element_type): - assert not tt_dialect.PointerType.isinstance(y_element_type) + if _is_triton_pointer_type(x_element_type): + assert not _is_triton_pointer_type(y_element_type) return tt_dialect.addptr(x.type, x, y) - if tt_dialect.PointerType.isinstance(y_element_type): + if _is_triton_pointer_type(y_element_type): return tt_dialect.addptr(y.type, y, x) assert x.type == y.type, (str(x.type), str(y.type)) @@ -1144,9 +1174,9 @@ def _add(x: ir.Value, y: ir.Value): def _sub(x: ir.Value, y: ir.Value) -> ir.Value: x_element_type = _element_type(x.type) y_element_type = _element_type(y.type) - if tt_dialect.PointerType.isinstance(x_element_type): + if _is_triton_pointer_type(x_element_type): return tt_dialect.addptr(x.type, x, _minus(y)) - elif not tt_dialect.PointerType.isinstance(y_element_type): + elif not _is_triton_pointer_type(y_element_type): assert x.type == y.type, (str(x.type), str(y.type)) if isinstance(x_element_type, ir.IntegerType): return arith_dialect.subi(x, y) @@ -1260,6 +1290,10 @@ def _cmp( ) +def _is_nan(x: ir.Value) -> ir.Value: + return arith_dialect.cmpf(arith_dialect.CmpFPredicate.UNO, x, x) + + _JAX_TO_TRITON_BINARY = { lax.add_p: _add, lax.sub_p: _sub, @@ -1302,17 +1336,31 @@ def signed_rule(ctx: LoweringRuleContext, x, y, fn=fn): triton_lowering_rules[prim] = signed_rule -@register_lowering(primitives.debug_print_p) +@register_lowering(debugging.debug_print_p) def debug_print_lowering_rule( ctx: LoweringRuleContext, *args: ir.Value, fmt: str, - has_placeholders: bool, + ordered, + partitioned, + in_tree, + static_args, + np_printoptions, + has_placeholders, + logging_record, ): + del partitioned, np_printoptions + if ordered: + raise NotImplementedError("Ordered debug_print is not supported on Pallas.") if has_placeholders: raise ValueError( "pl.debug_print() does not support placeholders when lowering to Triton" ) + args, kwargs = debugging.merge_callback_args(in_tree, args, static_args) + if kwargs: + raise ValueError( + "Only positional arguments are supported by debug_print on Pallas." + ) tt_dialect.print_( f" {fmt} ", @@ -1326,7 +1374,7 @@ def debug_print_lowering_rule( def _set_attr(v: ir.Value, name: str, attr: ir.Attribute) -> None: - if not ir.BlockArgument.isinstance(v): + if not isinstance(v, ir.BlockArgument): v.owner.attributes[name] = attr return @@ -1373,7 +1421,7 @@ def _broadcast_to_rule(ctx: LoweringRuleContext, x, shape: Sequence[int]): @register_lowering(lax.integer_pow_p) def _integer_pow_rule(ctx: LoweringRuleContext, x, *, y: int): if y == 0: - return _full(x.type, 1) + return _ones_like(x) is_reciprocal = y < 0 if is_reciprocal: @@ -1393,14 +1441,15 @@ def _integer_pow_rule(ctx: LoweringRuleContext, x, *, y: int): acc = _cast(acc, x_aval.dtype, out_aval.dtype) if is_reciprocal: signed = jnp.issubdtype(out_aval.dtype, jnp.signedinteger) - return _truediv(_full(acc.type, 1), acc, signed=signed) + return _truediv(_ones_like(acc), acc, signed=signed) else: return acc _JAX_FN_MAPPING = { lax.clamp_p: lambda min, a, max: jnp.minimum(jnp.maximum(min, a), max), - lax.logistic_p: lambda a: 1 / (1 + jnp.exp(-a)), + lax.logistic_p: lambda a, accuracy: 1 / (1 + jnp.exp(-a)), + lax.is_finite_p: lambda x: jnp.logical_and(~jnp.isnan(x), ~jnp.isinf(x)), } for prim, fn in _JAX_FN_MAPPING.items(): @@ -1479,7 +1528,7 @@ def _iota_lowering_rule(ctx: LoweringRuleContext, *, dtype, shape, dimension, def _element_type(t: ir.Type) -> ir.Type: - if ir.RankedTensorType.isinstance(t): + if isinstance(t, ir.RankedTensorType): return ir.RankedTensorType(t).element_type else: return t @@ -1508,14 +1557,30 @@ def _full(t: ir.Type, v: object) -> ir.Type: else: raise NotImplementedError - if ir.RankedTensorType.isinstance(t): + if isinstance(t, ir.RankedTensorType): return tt_dialect.splat(t, result) else: return result +def _zeros(t: ir.Type) -> ir.Value: + return _full(t, 0) + + +def _zeros_like(x: ir.Value) -> ir.Value: + return _full(x.type, 0) + + +def _ones(t: ir.Type) -> ir.Value: + return _full(t, 1) + + +def _ones_like(x: ir.Value) -> ir.Value: + return _full(x.type, 1) + + def _splat(x: ir.value, shape: Sequence[int]) -> ir.Value: - if ir.RankedTensorType.isinstance(x.type): + if isinstance(x.type, ir.RankedTensorType): raise TypeError("cannot splat a tensor") if not shape: return x @@ -1523,7 +1588,7 @@ def _splat(x: ir.value, shape: Sequence[int]) -> ir.Value: def _expand_dims(x: ir.Value, axis: int) -> ir.Value: - if not ir.RankedTensorType.isinstance(x.type): + if not isinstance(x.type, ir.RankedTensorType): shape = list(ir.RankedTensorType(x.type).shape) shape.insert(axis, 1) return _splat(x, shape) @@ -1534,11 +1599,10 @@ def _float_float_cast(src: ir.Value, dst_type: ir.Type) -> ir.Value: src_element_type = ir.FloatType(_element_type(src.type)) dst_element_type = ir.FloatType(_element_type(dst_type)) if src_element_type.width == 8 or dst_element_type.width == 8: - return tt_dialect.fp_to_fp( - dst_type, - src, - rounding=tt_dialect.RoundingMode.RTNE, + rounding = ( + tt_dialect.RoundingMode.RTNE if src_element_type.width > 8 else None ) + return tt_dialect.fp_to_fp(dst_type, src, rounding=rounding) if src_element_type.width > dst_element_type.width: return arith_dialect.truncf(dst_type, src) elif src_element_type.width < dst_element_type.width: @@ -1552,7 +1616,7 @@ def _int_int_cast(src: ir.Value, dst_type: ir.Type, signed: bool) -> ir.Value: dst_element_type = ir.IntegerType(_element_type(dst_type)) assert src_element_type != dst_element_type if dst_element_type.width == 1: - return _not_equal(src, _full(src.type, 0), signed=signed) + return _not_equal(src, _zeros_like(src), signed=signed) if src_element_type.width == dst_element_type.width: return arith_dialect.bitcast(dst_type, src) @@ -1572,7 +1636,7 @@ def _float_int_cast( raise NotImplementedError(f"cannot cast {src} tp {dst_type}") dst_element_type = ir.IntegerType(_element_type(dst_type)) if dst_element_type.width == 1: - return _not_equal(src, _full(src.type, 0), signed=signed) + return _not_equal(src, _zeros_like(src), signed=signed) else: # We clamp the float value to the min/max integer destination value # in order to match JAX/XLA casting behavior. Note that this differs @@ -1621,9 +1685,9 @@ def _cast( def _ir_cast(src: ir.Value, dst_type: ir.Type, *, signed: bool, dst_signed: bool = False) -> ir.Value: - if ir.RankedTensorType.isinstance( - src.type - ) and not ir.RankedTensorType.isinstance(dst_type): + if isinstance(src.type, ir.RankedTensorType) and not isinstance( + dst_type, ir.RankedTensorType + ): src_type = ir.RankedTensorType(src.type) dst_type = ir.RankedTensorType.get( src_type.shape, @@ -1668,22 +1732,22 @@ def _ir_cast(src: ir.Value, dst_type: ir.Type, *, ): return _int_float_cast(src, dst_type, signed=signed) - if tt_dialect.PointerType.isinstance(src_element_type) and isinstance( + if _is_triton_pointer_type(src_element_type) and isinstance( dst_element_type, ir.IntegerType ): if dst_element_type.width == 64: return tt_dialect.ptr_to_int(dst_type, src) elif dst_element_type.width == 1: x = _ir_cast(src, ir.IntegerType.get_signless(64), signed=signed) - zero = _full(x.type, 0) + zero = _zeros_like(x) return _ir_cast(_not_equal(x, zero, signed=signed), dst_type, signed=signed) - if isinstance( - src_element_type, ir.IntegerType - ) and tt_dialect.PointerType.isinstance(dst_element_type): + if isinstance(src_element_type, ir.IntegerType) and _is_triton_pointer_type( + dst_element_type + ): return tt_dialect.int_to_ptr(dst_type, src) - if tt_dialect.PointerType.isinstance( - src_element_type - ) and tt_dialect.PointerType.isinstance(dst_element_type): + if _is_triton_pointer_type(src_element_type) and _is_triton_pointer_type( + dst_element_type + ): return tt_dialect.bitcast(dst_type, src) raise NotImplementedError(f"cannot cast {src} to {dst_type}") @@ -1715,7 +1779,7 @@ def _broadcast_in_dim_lowering_rule( ): del sharding x = _ensure_ir_value(x, *ctx.avals_in) - if not ir.RankedTensorType.isinstance(x.type): + if not isinstance(x.type, ir.RankedTensorType): return _bcast_to(x, shape) expand_dims = [i for i in range(len(shape)) if i not in broadcast_dimensions] for dim in expand_dims: @@ -1747,7 +1811,7 @@ def _reshape_lowering_rule( def _reshape(a: ir.Value, shape: Sequence[int]) -> ir.Value: - if not ir.RankedTensorType.isinstance(a.type): + if not isinstance(a.type, ir.RankedTensorType): assert all(dim_size == 1 for dim_size in shape) return _splat(a, shape) @@ -1759,6 +1823,12 @@ def _reshape(a: ir.Value, shape: Sequence[int]) -> ir.Value: ) +def get_join_type(old_type: ir.RankedTensorType): + shape = old_type.shape + shape.append(2) + return ir.RankedTensorType.get(shape, old_type.element_type, old_type.encoding) + + @register_lowering(lax.concatenate_p) def _concatenate_lowering_rule(ctx: LoweringRuleContext, *args, dimension): if len(args) != 2: @@ -1773,16 +1843,40 @@ def _concatenate_lowering_rule(ctx: LoweringRuleContext, *args, dimension): raise NotImplementedError( "Only arguments with shape [..., 1] are supported." ) - return tt_dialect.join( - _reshape(x, x_aval.shape[:-1]), _reshape(y, y_aval.shape[:-1]) - ) + lhs = _reshape(x, x_aval.shape[:-1]) + rhs = _reshape(y, y_aval.shape[:-1]) + ret_type = get_join_type(ir.RankedTensorType(rhs.type)) + return tt_dialect.join(ret_type, lhs, rhs) + + +@register_lowering(lax.split_p) +def _split_lowering_rule(ctx: LoweringRuleContext, x, *, sizes, axis): + pass + # TODO(cjfj): Add support for larger powers of 2. + num_parts = len(sizes) + if num_parts != pallas_utils.next_power_of_2(num_parts): + raise NotImplementedError("Only power-of-2 num parts supported.") + if any(size != sizes[0] for size in sizes): + raise NotImplementedError("Only equal-sized splits are supported.") + + def split_into_2(x): + shape = ir.RankedTensorType(x.type).shape + x = _reshape(x, shape[:axis] + [2, shape[axis] // 2] + shape[axis + 1 :]) + permutation = tuple(d for d in range(len(shape) + 1) if d != axis) + (axis,) + return tuple(tt_dialect.split(tt_dialect.trans(x, permutation))) + + x_parts = (x,) + while len(x_parts) < num_parts: + x_parts = sum(map(split_into_2, x_parts), ()) + return x_parts def _compute_offsets_from_indices( block_info: BlockInfo, nd_indexer: NDIndexer ) -> ir.Value: full_shape = block_info.full_shape_dtype.shape - num_mapped_dims = sum(b is pallas_core.mapped for b in block_info.block_shape) + num_squeezed_dims = sum(isinstance(b, pallas_core.Squeezed) + for b in block_info.block_shape) strides = pallas_utils.strides_from_shape(full_shape) indexer_shape = nd_indexer.get_indexer_shape() int_indexer_shape = nd_indexer.int_indexer_shape @@ -1790,7 +1884,7 @@ def _compute_offsets_from_indices( indices = nd_indexer.indices other_shape = indexer_shape[len(int_indexer_shape) :] other_shape_idx = 0 - assert len(indices) + num_mapped_dims == len(full_shape) + assert len(indices) + num_squeezed_dims == len(full_shape) assert len(block_info.start_indices) == len(full_shape) array_dtype = jnp.dtype(block_info.full_shape_dtype.dtype) @@ -1798,7 +1892,7 @@ def _compute_offsets_from_indices( # Use 64-bit indexing when offset might be >= 2**32 bytes. offset_eltype = ir.IntegerType.get_signless(64 if full_size > 2**32 else 32) if indexer_shape: - offsets = _full(ir.RankedTensorType.get(indexer_shape, offset_eltype), 0) + offsets = _zeros(ir.RankedTensorType.get(indexer_shape, offset_eltype)) else: offsets = _ir_constant(0, offset_eltype) @@ -1806,10 +1900,11 @@ def _compute_offsets_from_indices( for dim_stride, dim_block_size, start_offset in zip( strides, block_info.block_shape, block_info.start_indices ): - if dim_block_size is pallas_core.mapped: - index = _ir_constant(0, offset_eltype) - else: - index = next(indexer_iter) + match dim_block_size: + case pallas_core.Squeezed(): + index = _ir_constant(0, offset_eltype) + case int(): + index = next(indexer_iter) if isinstance(index, slice): index = primitives.Slice.from_slice(index, dim_block_size) @@ -1840,12 +1935,12 @@ def _compute_offsets_from_indices( dim_offsets = _ir_constant(dim_offsets, offset_eltype) dim_offsets = _ir_cast(dim_offsets, offset_eltype, signed=False) - if ir.RankedTensorType.isinstance(dim_offsets.type): + if isinstance(dim_offsets.type, ir.RankedTensorType): for _ in other_shape: rank = ir.RankedTensorType(dim_offsets.type).rank dim_offsets = _expand_dims(dim_offsets, rank) - if ir.RankedTensorType.isinstance(dim_offsets.type): + if isinstance(dim_offsets.type, ir.RankedTensorType): rank = ir.RankedTensorType(dim_offsets.type).rank for _ in range(len(indexer_shape) - rank): dim_offsets = _expand_dims(dim_offsets, 0) @@ -1871,13 +1966,10 @@ def _compute_pointers_from_indices( @register_lowering(sp.get_p) def _get_lowering_rule(ctx: LoweringRuleContext, ptr, *idx, tree): indexers = tree_util.tree_unflatten(tree, idx) - if not tt_dialect.PointerType.isinstance(ptr.type): + if not _is_triton_pointer_type(ptr.type): assert len(indexers) == 0 return ptr - if len(indexers) > 1: - raise NotImplementedError("No support for multiple indexers yet.") - indexer = indexers[0] - args_flat, args_tree = tree_util.tree_flatten((ptr, (indexer,), None, None)) + args_flat, args_tree = tree_util.tree_flatten((ptr, indexers, None, None)) return _masked_load_lowering_rule( ctx, *args_flat, @@ -1917,21 +2009,21 @@ def _load( f"unsupported eviction policy: {eviction_policy}" ) from None - if tt_dialect.PointerType.isinstance(ptr.type): + if _is_triton_pointer_type(ptr.type): ptr_type = tt_dialect.PointerType(ptr.type) - if ir.RankedTensorType.isinstance(ptr_type.pointee_type): + if isinstance(ptr_type.pointee_type, ir.RankedTensorType): raise NotImplementedError("loading from a block pointer is not supported") ptr_type = _element_type(ptr.type) - if not tt_dialect.PointerType.isinstance(ptr_type): + if not _is_triton_pointer_type(ptr_type): raise ValueError(f"unsupported pointer type: {ptr_type}") ptr_type = tt_dialect.PointerType(ptr_type) if other is not None and mask is None: raise ValueError("other requires mask to be provided") - if not ir.RankedTensorType.isinstance(ptr.type): - if other is not None and ir.RankedTensorType.isinstance(other.type): + if not isinstance(ptr.type, ir.RankedTensorType): + if other is not None and isinstance(other.type, ir.RankedTensorType): raise ValueError("other cannot be a block if pointer is not a block") - if mask is not None and ir.RankedTensorType.isinstance(mask.type): + if mask is not None and isinstance(mask.type, ir.RankedTensorType): raise ValueError("mask cannot be a block if pointer is not a block") pointee_type = ptr_type.pointee_type @@ -2016,8 +2108,13 @@ def _masked_load_lowering_rule( *_, mask_aval, other_aval = args_tree.unflatten(ctx.avals_in) if len(indexers) > 1: raise NotImplementedError("No support for multiple indexers yet.") - idx = indexers[0] - if not tt_dialect.PointerType.isinstance(ptr.type): + indexers = list(indexers) + if not indexers: + ref_shape = state.get_transforms_shape(indexers, ctx.avals_in[0].shape) + idx = NDIndexer.make_trivial_indexer(ref_shape) + else: + idx = indexers[0] + if not _is_triton_pointer_type(ptr.type): assert len(ctx.avals_in) == 1 return ptr @@ -2055,22 +2152,15 @@ def _masked_load_lowering_rule( if not is_int4: return values - # After jaxlib 0.5.2, XLA packs pairs of `[u]int4` values into a `uint8` - # value with the first in the least significant bits and the second in the - # most significant. Before jaxlib 0.5.2, the order was reversed. if is_contiguous_int4: msb_values = arith_dialect.shrui(values, _full(values.type, 4)) - if jaxlib_version < (0, 5, 2): - values = tt_dialect.join(msb_values, values) - else: - values = tt_dialect.join(values, msb_values) + join_type = get_join_type(ir.RankedTensorType(values.type)) + values = tt_dialect.join(join_type, values, msb_values) shape = ir.RankedTensorType(values.type).shape values = _reshape(values, (*shape[:-2], shape[-2] * shape[-1])) else: offsets = _ir_cast(offsets, ir.IntegerType.get_signless(32), signed=False) in_msb = _mod(offsets, _full(offsets.type, 2), signed=False) - if jaxlib_version < (0, 5, 2): - in_msb = arith_dialect.xori(in_msb, _full(in_msb.type, 1)) shift = _mul(in_msb, _full(in_msb.type, 4)) shift = _ir_cast(shift, values.type, signed=False) values = arith_dialect.shrui(values, shift) @@ -2080,13 +2170,12 @@ def _masked_load_lowering_rule( @register_lowering(sp.swap_p) def _swap_lowering_rule(ctx: LoweringRuleContext, ptr, value, *idx, tree): indexers = tree_util.tree_unflatten(tree, idx) - if not tt_dialect.PointerType.isinstance(ptr.type): + if not _is_triton_pointer_type(ptr.type): assert len(indexers) == 0 return ptr if len(indexers) > 1: raise NotImplementedError("No support for multiple indexers yet.") - indexer = indexers[0] - args_flat, args_tree = tree_util.tree_flatten((ptr, (indexer,), value, None)) + args_flat, args_tree = tree_util.tree_flatten((ptr, indexers, value, None)) return _masked_swap_lowering_rule( ctx, *args_flat, args_tree=args_tree, eviction_policy=None ) @@ -2116,19 +2205,19 @@ def _store( f"unsupported eviction policy: {eviction_policy}" ) from None - if tt_dialect.PointerType.isinstance(ptr.type): + if _is_triton_pointer_type(ptr.type): ptr_type = tt_dialect.PointerType(ptr.type) - if ir.RankedTensorType.isinstance(ptr_type.pointee_type): + if isinstance(ptr_type.pointee_type, ir.RankedTensorType): raise NotImplementedError("loading from a block pointer is not supported") ptr_type = _element_type(ptr.type) - if not tt_dialect.PointerType.isinstance(ptr_type): + if not _is_triton_pointer_type(ptr_type): raise ValueError(f"unsupported pointer type: {ptr_type}") ptr_type = tt_dialect.PointerType(ptr_type) - if not ir.RankedTensorType.isinstance(ptr.type): - if ir.RankedTensorType.isinstance(value.type): + if not isinstance(ptr.type, ir.RankedTensorType): + if isinstance(value.type, ir.RankedTensorType): raise ValueError("value cannot be a block if pointer is not a block") - if mask is not None and ir.RankedTensorType.isinstance(mask.type): + if mask is not None and isinstance(mask.type, ir.RankedTensorType): raise ValueError("mask cannot be a block if pointer is not a block") pointee_type = ptr_type.pointee_type @@ -2156,7 +2245,11 @@ def _masked_swap_lowering_rule( *_, value_aval, mask_aval = args_tree.unflatten(ctx.avals_in) if len(indexers) > 1: raise NotImplementedError("No support for multiple indexers yet.") - idx = indexers[0] + if not indexers: + ref_shape = state.get_transforms_shape(indexers, ctx.avals_in[0].shape) + idx = NDIndexer.make_trivial_indexer(ref_shape) + else: + idx = indexers[0] ptr = _compute_pointers_from_indices(ptr, block_info, idx) other = None if value is not None: @@ -2176,7 +2269,7 @@ def _addupdate_lowering_rule(ctx: LoweringRuleContext, ptr, value, *idx, tree): block_info, *_ = ctx.block_infos assert block_info is not None indexers = tree_util.tree_unflatten(tree, idx) - if not tt_dialect.PointerType.isinstance(ptr.type): + if not _is_triton_pointer_type(ptr.type): assert len(indexers) == 0 return ptr if len(indexers) > 1: @@ -2198,6 +2291,14 @@ def _transpose_lowering(ctx: LoweringRuleContext, x, *, permutation): _TF32_PRECISIONS = (lax.Precision.HIGH, lax.Precision.DEFAULT) +def _as_bf16(x): + return _ir_cast(x, _dtype_to_ir_type(jnp.bfloat16), signed=False) + + +def _as_f32(x): + return _ir_cast(x, _dtype_to_ir_type(jnp.float32), signed=False) + + @register_lowering(lax.dot_general_p) def _dot_general_lowering( ctx: LoweringRuleContext, @@ -2237,6 +2338,9 @@ def _dot_general_lowering( | lax.DotAlgorithmPreset.F16_F16_F32 | lax.DotAlgorithmPreset.BF16_BF16_BF16 | lax.DotAlgorithmPreset.BF16_BF16_F32 + | lax.DotAlgorithmPreset.BF16_BF16_F32_X3 + | lax.DotAlgorithmPreset.BF16_BF16_F32_X6 + | lax.DotAlgorithmPreset.BF16_BF16_F32_X9 ): input_precision = None case _: @@ -2275,7 +2379,40 @@ def _dot_general_lowering( m, _ = a_type.shape _, n = b_type.shape - acc = _full(ir.RankedTensorType.get([m, n], _dtype_to_ir_type(acc_dtype)), 0) + acc = _zeros(ir.RankedTensorType.get([m, n], _dtype_to_ir_type(acc_dtype))) + + if precision in ( + lax.DotAlgorithmPreset.BF16_BF16_F32_X3, + lax.DotAlgorithmPreset.BF16_BF16_F32_X6, + lax.DotAlgorithmPreset.BF16_BF16_F32_X9, + ): + a_bf16 = _as_bf16(a) + b_bf16 = _as_bf16(b) + a_err0 = _sub(a, _as_f32(a_bf16)) + b_err0 = _sub(b, _as_f32(b_bf16)) + a_err0_bf16 = _as_bf16(a_err0) + b_err0_bf16 = _as_bf16(b_err0) + a_err1_bf16 = _as_bf16(_sub(a_err0, _as_f32(a_err0_bf16))) + b_err1_bf16 = _as_bf16(_sub(b_err0, _as_f32(b_err0_bf16))) + # Accumulate the smallest values first to reduce the numeric error. + if precision == lax.DotAlgorithmPreset.BF16_BF16_F32_X9: + acc = tt_dialect.dot(a_err1_bf16, b_err0_bf16, acc) + acc = tt_dialect.dot(a_err1_bf16, b_err1_bf16, acc) + acc = tt_dialect.dot(a_err0_bf16, b_err1_bf16, acc) + if precision in ( + lax.DotAlgorithmPreset.BF16_BF16_F32_X6, + lax.DotAlgorithmPreset.BF16_BF16_F32_X9, + ): + acc = tt_dialect.dot(a_err1_bf16, b_bf16, acc) + acc = tt_dialect.dot(a_bf16, b_err1_bf16, acc) + acc = tt_dialect.dot(a_err0_bf16, b_err0_bf16, acc) + acc = tt_dialect.dot(a_err0_bf16, b_bf16, acc) + acc = tt_dialect.dot(a_bf16, b_err0_bf16, acc) + # If `a` rounding error is zero and `b` is `inf` then `acc` may contain + # `NaN`s (as `0 * inf = NaN`), and vice versa. + acc = arith_dialect.select(_is_nan(acc), _zeros_like(acc), acc) + a, b = a_bf16, b_bf16 + acc = tt_dialect.dot(a, b, acc, input_precision=input_precision) return _cast(acc, acc_dtype, out_aval.dtype) @@ -2292,7 +2429,7 @@ def _reduction_lowering(body, ctx: LoweringRuleContext, a, axes): body, (a, a), {})), in_tree ) - combine_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( + combine_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic( flat_fun, [*mapped_avals, *mapped_avals] ) out_tree = out_tree_thunk() @@ -2312,7 +2449,7 @@ def _reduction_lowering(body, ctx: LoweringRuleContext, a, axes): return list(reduce_op.result) -def _reduce_lowering(body, ctx: LoweringRuleContext, a, *, axes): +def _reduce_lowering(body, ctx: LoweringRuleContext, a, *, axes, **kwargs): assert isinstance(axes, tuple) if not axes: return a @@ -2396,7 +2533,7 @@ def _reduce_argmin_combine(left, right): ) -@register_lowering(pjit.pjit_p) +@register_lowering(pjit.jit_p) def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_): if jaxpr.consts: raise NotImplementedError @@ -2404,8 +2541,9 @@ def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_): ctx.context, jaxpr.jaxpr, ctx.block_infos, *args ) -@register_lowering(pjit.mesh_cast_p) -def _mesh_cast_lowering_rule(ctx, x, dst_sharding): + +@register_lowering(pjit.reshard_p) +def _reshard_lowering_rule(ctx, x, *, dst_sharding, concrete_mesh): return x @@ -2446,61 +2584,6 @@ def _is_read_only(ref_effects) -> bool: return isinstance(eff, state.ReadEffect) -@register_lowering(for_loop.for_p) -def _for_lowering_rule( - ctx: LoweringRuleContext, - *args, - jaxpr, - which_linear, - nsteps, - reverse, - unroll, -): - del which_linear - if reverse or unroll != 1: - raise NotImplementedError - _i_constant = _i64_constant if config.enable_x64.value else _i32_constant - lower_bound = _i_constant(0) - upper_bound = _i_constant(nsteps) - step = _i_constant(1) - init_args = map(_ensure_ir_value, args, ctx.avals_in) - # Partially discharge state from jaxpr for non-pointers - should_discharge = [ - not isinstance(a, state.AbstractRef) for a in ctx.avals_in - ] - discharged_jaxpr, () = discharge.discharge_state( - jaxpr, (), should_discharge=[True, *should_discharge] - ) - in_avals = [v.aval for v in jaxpr.invars] - state_effects = state.get_ref_state_effects(in_avals, jaxpr.effects)[1:] - # Read-only `Ref`s don't need to be passed in explicitly as loop arguments so - # we can filter them out. - read_only = map(_is_read_only, state_effects) - is_loop_arg = map( - operator.and_, map(operator.not_, read_only), should_discharge - ) - ptrs, _ = partition_list(should_discharge, init_args) - non_loop_args, loop_args = partition_list(is_loop_arg, init_args) - for_op = scf_dialect.ForOp(lower_bound, upper_bound, step, loop_args) - with ir.InsertionPoint(for_op.body): - loop_index = for_op.induction_variable - for_body_args = [ - for_op.body.arguments[i + 1] for i, _ in enumerate(loop_args) - ] - loop_body_args = merge_lists(is_loop_arg, non_loop_args, for_body_args) - out_discharged = lower_jaxpr_to_triton_ir( - ctx.context, - discharged_jaxpr, - [None, *ctx.block_infos], - loop_index, - *loop_body_args, - ) - all_out = merge_lists(should_discharge, ptrs, out_discharged) - _, loop_out = partition_list(is_loop_arg, all_out) - scf_dialect.yield_(loop_out) - return merge_lists(is_loop_arg, non_loop_args, list(for_op.results_)) - - def _lower_jaxpr_to_for_loop( ctx: LoweringRuleContext, jaxpr: jax_core.Jaxpr, @@ -2642,7 +2725,8 @@ def _maybe_pattern_match_fori_loop( jaxpr = jaxpr.replace( eqns=jaxpr.eqns[:eqn_index] + jaxpr.eqns[eqn_index + 1:], invars=new_invars, - outvars=new_outvars) + outvars=new_outvars, + debug_info=jaxpr.debug_info.with_unknown_names()) _, body_consts, carry = split_list(args, [cond_nconsts, body_nconsts]) (lb, ub), args = carry[:2], carry[2:] const_block_infos, args_block_infos = split_list(ctx.block_infos, @@ -2753,7 +2837,7 @@ def to_type(out_aval): use_branch0 = _equal(index, _ir_constant(0, index.type), signed=False) # TODO(bjp): Switch to scf.index_switch once exposed in triton.cc - if_op = scf_dialect.IfOp(use_branch0, out_types, hasElse=True) + if_op = scf_dialect.IfOp(use_branch0, out_types, has_else=True) with ir.InsertionPoint.at_block_begin(if_op.then_block): outs0 = lower_jaxpr_to_triton_ir( ctx.context, @@ -2784,13 +2868,17 @@ def to_type(out_aval): def _ensure_ir_value(x: object, aval: jax_core.ShapedArray) -> ir.Value: if isinstance(x, ir.Value): return x - elif isinstance(x, (np.number, np.ndarray, int, float)): + elif isinstance( + x, (np.number, np.ndarray, int, float, literals.TypedNdArray) + ): return _ir_constant(x, _dtype_to_ir_type(aval.dtype)) raise NotImplementedError def _ir_constant(v: object, t: ir.Type) -> ir.Value: - if isinstance(v, (np.number, np.ndarray, int, float)): + if isinstance( + v, (np.number, np.ndarray, int, float, literals.TypedNdArray) + ): if isinstance(t, ir.IntegerType): v = int(v) else: @@ -2829,7 +2917,7 @@ def _bitcast_convert_type_lowering_rule( raise NotImplementedError( f"cannot cast {operand} to {new_dtype} because of different widths" ) - if ir.RankedTensorType.isinstance(operand.type): + if isinstance(operand.type, ir.RankedTensorType): shape = ir.RankedTensorType(operand.type).shape result_type = ir.RankedTensorType.get(shape, dst_elem_type) else: diff --git a/jax/_src/pallas/triton/pallas_call_registration.py b/jax/_src/pallas/triton/pallas_call_registration.py index 4e8775e514f0..d418cebc7f20 100644 --- a/jax/_src/pallas/triton/pallas_call_registration.py +++ b/jax/_src/pallas/triton/pallas_call_registration.py @@ -17,16 +17,19 @@ from __future__ import annotations import io -from typing import Any +import json +from typing import cast import zlib import jax +from jax._src import frozen_dict import jax._src.core as jax_core from jax._src.interpreters import mlir -from jax._src.lib import triton from jax._src.lib import gpu_triton as triton_kernel_call_lib +from jax._src.lib import triton from jax._src.lib.mlir import ir from jax._src.pallas import core as pallas_core +from jax._src.pallas.triton import core as triton_core from jax._src.pallas.triton import lowering @@ -39,7 +42,7 @@ def normalize_grid(grid: pallas_core.StaticGrid) -> tuple[int, int, int]: def avals_to_layouts(avals): - return [list(reversed(range(aval.ndim))) for aval in avals] + return [list(reversed(range(aval.ndim))) for aval in avals] # pytype: disable=attribute-error def pallas_call_lowering( @@ -51,11 +54,13 @@ def pallas_call_lowering( input_output_aliases: tuple[tuple[int, int], ...], grid_mapping: pallas_core.GridMapping, mesh: pallas_core.Mesh | None, - compiler_params: dict[str, Any], + compiler_params: dict[str, pallas_core.CompilerParams], cost_estimate: pallas_core.CostEstimate | None, out_avals: tuple[jax_core.AbstractValue, ...], + metadata: frozen_dict.FrozenDict[str, str] | None, + name: str | None, ): - del interpret, out_avals, cost_estimate + del interpret, out_avals, cost_estimate, name debug_info = jaxpr.debug_info if grid_mapping.num_dynamic_grid_bounds: raise NotImplementedError( @@ -67,16 +72,17 @@ def pallas_call_lowering( ) if mesh is not None: raise NotImplementedError("mesh is not supported in the Triton backend") - triton_params = compiler_params.get("triton", compiler_params) - num_warps = triton_params.get("num_warps", 4) - num_warps = 4 if num_warps is None else num_warps + [lowering_platform] = ctx.platforms or ctx.module_context.platforms - if lowering_platform == "rocm": - num_stages = triton_params.get("num_stages", 1) - num_stages = 1 if num_stages is None else num_stages + + if "triton" in compiler_params: + params = cast(triton_core.CompilerParams, compiler_params["triton"]) else: - num_stages = triton_params.get("num_stages", 3) - num_stages = 3 if num_stages is None else num_stages + params = triton_core.CompilerParams() + num_warps = 4 if params.num_warps is None else params.num_warps + num_stages = params.num_stages + if num_stages is None: + num_stages = 1 if lowering_platform == "rocm" else 3 if debug: print(f"\nThe kernel jaxpr for pallas_call {debug_info.func_src_info}:") @@ -99,12 +105,16 @@ def pallas_call_lowering( buf = io.BytesIO() module_op.write_bytecode(buf) + serialized_metadata = None + if metadata is not None: + serialized_metadata = json.dumps(dict(metadata)) + # TODO(b/394629193): Remove True once the bug is fixed. if True: # AOT Triton compilation is only available on jaxlib 0.5.1+. out_types = [ - ir.RankedTensorType.get(bm.array_shape_dtype.shape, - mlir.dtype_to_ir_type(bm.array_shape_dtype.dtype)) + ir.RankedTensorType.get(bm.array_aval.shape, + mlir.dtype_to_ir_type(bm.array_aval.dtype)) for bm in grid_mapping.block_mappings_output ] backend_config = dict( @@ -117,12 +127,11 @@ def pallas_call_lowering( grid_z=mlir.i32_attr(grid_z), debug=ir.BoolAttr.get(debug), ) - if "serialized_metadata" in (triton_params or {}): + if serialized_metadata is not None: # This field is unstable and may be removed in the future. - if triton_params["serialized_metadata"] is not None: - backend_config["serialized_metadata"] = ir.StringAttr.get( - triton_params["serialized_metadata"] - ) + backend_config["serialized_metadata"] = ir.StringAttr.get( + serialized_metadata + ) return mlir.custom_call( call_target_name="__gpu$xla.gpu.triton", result_types=out_types, @@ -178,10 +187,10 @@ def pallas_call_lowering( call_target_name="triton_kernel_call", result_types=[*map(mlir.aval_to_ir_type, ctx.avals_out)], operands=in_nodes, - backend_config=zlib.compress( + backend_config=zlib.compress( kernel_call.to_proto( debug_info.func_name, - triton_params.get("serialized_metadata") or b"", + (serialized_metadata or "").encode(), ) ), operand_layouts=avals_to_layouts(ctx.avals_in), diff --git a/jax/_src/pallas/triton/primitives.py b/jax/_src/pallas/triton/primitives.py index b845a4079ff4..25423ebba471 100644 --- a/jax/_src/pallas/triton/primitives.py +++ b/jax/_src/pallas/triton/primitives.py @@ -17,16 +17,22 @@ from __future__ import annotations from collections.abc import Sequence +from typing import TypeAlias import jax from jax._src import core as jax_core +from jax._src import state from jax._src.lib.mlir.dialects import gpu as gpu_dialect from jax._src.lib.triton import dialect as tt_dialect +from jax._src.pallas import primitives as pallas_primitives from jax._src.pallas.triton import lowering from jax.interpreters import mlir import jax.numpy as jnp +Ref: TypeAlias = state.AbstractRef | state.TransformedRef + + def approx_tanh(x: jax.Array) -> jax.Array: r"""Elementwise approximate hyperbolic tangent: :math:`\mathrm{tanh}(x)`. @@ -42,6 +48,11 @@ def approx_tanh(x: jax.Array) -> jax.Array: elif x.dtype == jnp.float32: asm = "tanh.approx.f32 $0, $1;" constraint = "f" + elif x.dtype == jnp.float64: + # f64 tanh.approx is only supported on ROCm (uses __ocml_tanh_f64) + # CUDA does not have a PTX instruction for f64 approximate tanh + asm = "tanh.approx.f64 $0, $1;" + constraint = "d" else: raise TypeError(f"approx_tanh does not accept {x.dtype} arrays") @@ -83,7 +94,7 @@ def elementwise_inline_asm( asm=asm, constraints=constraints, pack=pack, - result_shape_dtypes=result_shape_dtypes, + result_shape_dtypes=tuple(result_shape_dtypes), ) @@ -113,6 +124,13 @@ def _elementwise_inline_asm_lowering( result_shape_dtypes, ): del result_shape_dtypes # Unused. + + # For ROCm, PTX inline assembly is not supported. For tanh.approx, we use + # Triton's __triton_hip_fast_tanhf (fast exp-based formula) for f32, and + # OCML's __ocml_tanh_f64 for f64. See: https://github.com/triton-lang/triton/pull/7780 + if ctx.context.platform == "rocm" and "tanh.approx" in asm: + return _approx_tanh_rocm_lowering(ctx, *args) + return tt_dialect.ElementwiseInlineAsmOp( [*map(mlir.aval_to_ir_type, ctx.avals_out)], asm, @@ -123,6 +141,86 @@ def _elementwise_inline_asm_lowering( ).result +def _approx_tanh_rocm_lowering( + ctx: lowering.LoweringRuleContext, + *args, +): + """Lower approx_tanh for ROCm. + + AMD CDNA3 (MI300X/gfx942) does not have a hardware tanh instruction. + + For f32 (and f16/bf16 via casting): We use Triton's __triton_hip_fast_tanhf + which implements a fast exp-based formula: tanh(x) = (exp(2x) - 1) / (exp(2x) + 1) + See: https://github.com/triton-lang/triton/pull/7780 + + For f64: We use OCML's __ocml_tanh_f64 (AMD's Open Compute Math Library) + since fast_tanhf only supports f32. + """ + from jax._src.lib.mlir import ir + from jax._src.lib.mlir.dialects import arith as arith_dialect + + [arg] = args + [out_aval] = ctx.avals_out + in_dtype = ctx.avals_in[0].dtype + + # Helper to get IR type for a dtype + def dtype_to_ir_type(dtype): + dtype = jnp.dtype(dtype) + return mlir.dtype_to_ir_type(dtype) + + # f64: use __ocml_tanh_f64 (fast_tanhf only supports f32) + if in_dtype == jnp.float64: + result_type = mlir.aval_to_ir_type(out_aval) + result = tt_dialect.extern_elementwise( + result_type, + list(args), + libname="", + libpath="", + symbol="__ocml_tanh_f64", + pure=True, + ) + return [result] + + # fast_tanhf only supports f32. For f16/bf16, cast to f32, compute, cast back. + needs_cast = in_dtype in (jnp.float16, jnp.bfloat16) + + if needs_cast: + # Cast input to f32 (extend) + f32_type = dtype_to_ir_type(jnp.float32) + if out_aval.shape: + f32_result_type = ir.RankedTensorType.get(out_aval.shape, f32_type) + else: + f32_result_type = f32_type + arg_f32 = arith_dialect.extf(f32_result_type, arg) + + # Call __triton_hip_fast_tanhf (fast exp-based implementation) + tanh_result = tt_dialect.extern_elementwise( + f32_result_type, + [arg_f32], + libname="libdevice", + libpath="", + symbol="__triton_hip_fast_tanhf", + pure=True, + ) + + # Cast result back to original dtype (truncate) + out_type = mlir.aval_to_ir_type(out_aval) + result = arith_dialect.truncf(out_type, tanh_result) + else: + # f32: call __triton_hip_fast_tanhf directly + result_type = mlir.aval_to_ir_type(out_aval) + result = tt_dialect.extern_elementwise( + result_type, + list(args), + libname="libdevice", + libpath="", + symbol="__triton_hip_fast_tanhf", + pure=True, + ) + + return [result] + + def debug_barrier() -> None: """Synchronizes all kernel executions in the grid.""" return debug_barrier_p.bind() @@ -140,3 +238,58 @@ def _debug_barrier_lowering(ctx: lowering.LoweringRuleContext): del ctx # Unused. gpu_dialect.barrier() return [] + + +def load( + ref: Ref, + *, + mask: jax.Array | None = None, + other: jax.typing.ArrayLike | None = None, + cache_modifier: str | None = None, + eviction_policy: str | None = None, + volatile: bool = False, +) -> jax.Array: + """Loads an array from the given ref. + + If neither ``mask`` nor ``other`` is specified, this function has the same + semantics as ``ref[idx]`` in JAX. + + Args: + ref: The ref to load from. + mask: An optional boolean mask specifying which indices to load. If mask is + ``False`` and ``other`` is not given, no assumptions can be made about the + value in the resulting array. + other: An optional value to use for indices where mask is ``False``. + cache_modifier: TO BE DOCUMENTED. + eviction_policy: TO BE DOCUMENTED. + volatile: TO BE DOCUMENTED. + """ + return pallas_primitives.load( + ref, + None, + mask=mask, + other=other, + cache_modifier=cache_modifier, + eviction_policy=eviction_policy, + volatile=volatile, + ) + + +def store( + ref: Ref, + val: jax.Array, + *, + mask: jax.Array | None = None, + eviction_policy: str | None = None, +) -> None: + """Stores a value to the given ref. + + See :func:`~jax.experimental.pallas.load` for the meaning of the arguments. + """ + return pallas_primitives.store( + ref, + None, + val, + mask=mask, + eviction_policy=eviction_policy, + ) diff --git a/jax/_src/pallas/utils.py b/jax/_src/pallas/utils.py index a78c5487a4d6..90a61aeb4302 100644 --- a/jax/_src/pallas/utils.py +++ b/jax/_src/pallas/utils.py @@ -15,13 +15,16 @@ """Pallas utility functions.""" from __future__ import annotations -from typing import overload +import dataclasses +from typing import Any, overload -import jax -from jax import lax +from jax._src.lax import lax from jax._src import core as jax_core +from jax._src import dtypes +from jax._src import typing as jax_typing + from jax._src.util import split_list -import jax.numpy as jnp +from jax._src import numpy as jnp import numpy as np @@ -30,21 +33,29 @@ def cdiv(a: int, b: int) -> int: ... @overload -def cdiv(a: int, b: jax.Array) -> jax.Array: +def cdiv(a: int, b: jax_typing.Array) -> jax_typing.Array: ... @overload -def cdiv(a: jax.Array, b: int) -> jax.Array: +def cdiv(a: jax_typing.Array, b: int) -> jax_typing.Array: ... @overload -def cdiv(a: jax.Array, b: jax.Array) -> jax.Array: +def cdiv(a: jax_typing.Array, b: jax_typing.Array) -> jax_typing.Array: ... -def cdiv(a: int | jax.Array, b: int | jax.Array) -> int | jax.Array: +def cdiv(a: int | jax_typing.Array, b: int | jax_typing.Array) -> int | jax_typing.Array: + """Computes the ceiling division of a divided by b. + + Examples: + >>> cdiv(8, 2) + 4 + >>> cdiv(9, 2) # 9 / 2 = 4.5, which rounds up to 5 + 5 + """ if isinstance(a, int) and isinstance(b, int): return (a + b - 1) // b - return lax.div(a + b - 1, b) + return lax.div(a + (b - 1), b) def strides_from_shape(shape: tuple[int, ...]) -> tuple[int, ...]: @@ -62,10 +73,6 @@ def next_power_of_2(x: int) -> int: raise ValueError("`next_power_of_2` requires a non-negative integer.") return 1 if x == 0 else 2 ** (x - 1).bit_length() -def dtype_bitwidth(dtype: np.dtype | jnp.dtype) -> int: - if jnp.issubdtype(dtype, jnp.integer): - return jnp.iinfo(dtype).bits - return np.dtype(dtype).itemsize * 8 def pattern_match_scan_to_fori_loop( jaxpr: jax_core.Jaxpr, num_consts: int, num_carry: int @@ -110,9 +117,9 @@ def pattern_match_scan_to_fori_loop( def pattern_match_while_to_fori_loop( - cond_jaxpr: jax_core.Jaxpr, + cond_jaxpr: jax_core.ClosedJaxpr, cond_nconsts: int, - body_jaxpr: jax_core.Jaxpr, + body_jaxpr: jax_core.ClosedJaxpr, body_nconsts: int, ) -> tuple[jax_core.Jaxpr | None, str | None]: # Try to pattern match to fori loop. @@ -168,10 +175,23 @@ def pattern_match_while_to_fori_loop( *jaxpr.invars[body_nconsts + 2 :], ) new_outvars = tuple(jaxpr.outvars[2:]) + if jaxpr.debug_info.arg_names is not None: + new_arg_names = (*jaxpr.debug_info.arg_names[:body_nconsts], + "", + *jaxpr.debug_info.arg_names[body_nconsts + 2:]) + else: + new_arg_names = None + if jaxpr.debug_info.result_paths is not None: + new_result_paths = jaxpr.debug_info.result_paths[2:] + else: + new_result_paths = None + jaxpr = jaxpr.replace( eqns=jaxpr.eqns[:eqn_index] + jaxpr.eqns[eqn_index + 1 :], invars=new_invars, outvars=new_outvars, + debug_info=jaxpr.debug_info._replace(arg_names=new_arg_names, + result_paths=new_result_paths) ) return jaxpr, None @@ -313,7 +333,7 @@ def nextafter_lowering_helper(x, y): jnp.float64, jnp.uint64, np.float64, np.uint64, np.int64, ) - bitwidth = dtype_bitwidth(x.dtype) + bitwidth = dtypes.itemsize_bits(x.dtype) x_as_int = x.view(jnp_uint) y_as_int = y.view(jnp_uint) @@ -380,3 +400,20 @@ def nextafter_lowering_helper(x, y): # Cast back to the original type. return result.view(jnp_float) + + +@dataclasses.dataclass(frozen=True) +class MeshInfo: + mesh_shape: tuple[int, ...] + axis_names: tuple[str, ...] + mesh_strides: tuple[int, ...] + + @staticmethod + def from_mesh(mesh: Any) -> MeshInfo: + # We need mesh <-> logical translation tables. Since the logical IDs are + # just linearized versions of the mesh IDs, we create those tables. + mesh_strides = strides_from_shape(tuple( + mesh.shape[a] for a in mesh.axis_names + )) + mesh_shape = tuple(mesh.shape.values()) + return MeshInfo(mesh_shape, mesh.axis_names, mesh_strides) diff --git a/jax/_src/partition_spec.py b/jax/_src/partition_spec.py index bf6a90060bc8..d1ee8d6a40fc 100644 --- a/jax/_src/partition_spec.py +++ b/jax/_src/partition_spec.py @@ -13,37 +13,31 @@ # limitations under the License. from __future__ import annotations +from typing import Any -class UnconstrainedSingleton: - - def __repr__(self): - return "UNCONSTRAINED" - - def __reduce__(self): - return (_get_default_unconstrained, ()) +from jax._src.lib import _jax +from jax._src.util import use_cpp_class, use_cpp_method +_UNCONSTRAINED_PARTITION = _jax.UNCONSTRAINED_PARTITION +_canonicalize_partition = _jax.canonicalize_partition -# Unconstrained sentinel value for PartitionSpec, representing a dimension for -# which the user wants XLA to assign the best partitioning. -# TODO(yashkatariya): May rename to AUTO. -_UNCONSTRAINED_PARTITION = UnconstrainedSingleton() -def _get_default_unconstrained(): - return _UNCONSTRAINED_PARTITION +def unpickle_pspec(partitions, unreduced, reduced): + return PartitionSpec(*partitions, unreduced=unreduced, reduced=reduced) -def _canonicalize_partition(partition): - if not partition: - return None - if partition is _UNCONSTRAINED_PARTITION: - return _UNCONSTRAINED_PARTITION - if isinstance(partition, (tuple, list)): - if len(partition) == 1: - return partition[0] - return tuple(partition) - return partition +def _get_ur_str(unreduced, reduced): + if unreduced and reduced: + return f"unreduced={set(unreduced)!r}, reduced={set(reduced)!r}" + elif unreduced and not reduced: + return f"unreduced={set(unreduced)!r}" + elif not unreduced and reduced: + return f"reduced={set(reduced)!r}" + assert False # unreachable +AxisName = Any -class PartitionSpec(tuple): +@use_cpp_class(_jax.PartitionSpec) +class PartitionSpec: """Tuple describing how to partition an array across a mesh of devices. Each element is either ``None``, a string, or a tuple of strings. @@ -52,38 +46,132 @@ class PartitionSpec(tuple): This class exists so JAX's pytree utilities can distinguish a partition specifications from tuples that should be treated as pytrees. """ + __match_args__ = ("_partitions",) # A sentinel value representing a dim is unconstrained. UNCONSTRAINED = _UNCONSTRAINED_PARTITION - def __init__(self, *partitions): - pass - - def __new__(cls, *partitions): - partitions = tuple(_canonicalize_partition(p) for p in partitions) - return tuple.__new__(PartitionSpec, partitions) + @use_cpp_method() + def __init__(self, *partitions, unreduced=frozenset(), reduced=frozenset()): + self._partitions = tuple(_canonicalize_partition(p) for p in partitions) + if not isinstance(unreduced, (set, frozenset)): + raise TypeError( + "`unreduced` argument of PartitionSpec should be of type" + f" `frozenset` or `set`. Got type {type(unreduced)}") + if not isinstance(reduced, (set, frozenset)): + raise TypeError( + "`reduced` argument of PartitionSpec should be of type" + f" `frozenset` or `set`. Got type {type(reduced)}") + self.unreduced = frozenset(unreduced) + # See the description of https://github.com/jax-ml/jax/pull/29381 + self.reduced = frozenset(reduced) + # `__init__` is implemented in C++ so this check happens in C++ + # _check(self._partitions, self.unreduced, self.reduced) def __repr__(self): - return f"PartitionSpec{tuple.__repr__(self)}" + pr = repr(self._partitions)[1:-1] + if not self.unreduced and not self.reduced: + return f"PartitionSpec({pr})" + ur_str = _get_ur_str(self.unreduced, self.reduced) + pr = '' if not pr else f"{pr} " if pr.endswith(',') else f"{pr}, " + return (f"PartitionSpec({pr}{ur_str})") def __reduce__(self): - return (PartitionSpec, tuple(self)) + return (unpickle_pspec, (self._partitions, self.unreduced, self.reduced)) + + def __getitem__(self, i): + return self._partitions[i] + + def __iter__(self): + return iter(self._partitions) + + def __len__(self): + return len(self._partitions) + @use_cpp_method() def __eq__(self, other): - if not isinstance(other, tuple): + if isinstance(other, PartitionSpec): + return (self._partitions == other._partitions and + self.unreduced == other.unreduced and + self.reduced == other.reduced) + elif isinstance(other, tuple): + if self.unreduced: + raise TypeError( + f"other {other} cannot be of instance `tuple` when self {self} has" + " unreduced in `__eq__` of PartitionSpec.") + if self.reduced: + raise TypeError( + f"other {other} cannot be of instance `tuple` when self {self} has" + " reduced in `__eq__` of PartitionSpec.") + other_p = tuple(_canonicalize_partition(o) for o in other) + return self._partitions == other_p + else: return False - other = tuple(_canonicalize_partition(o) for o in other) - return super().__eq__(other) + @use_cpp_method() def __hash__(self): - return super().__hash__() + return hash((self._partitions, self.unreduced, self.reduced)) + + def __add__(self, other): + if isinstance(other, PartitionSpec): + return PartitionSpec( + *self, *other, + unreduced={*self.unreduced, *other.unreduced}, + reduced={*self.reduced, *other.reduced}) + elif isinstance(other, tuple): + if self.unreduced: + raise TypeError( + f"other {other} cannot be of instance `tuple` when self {self} has" + " unreduced in `__add__` of PartitionSpec.") + if self.reduced: + raise TypeError( + f"other {other} cannot be of instance `tuple` when self {self} has" + " reduced in `__add__` of PartitionSpec.") + return PartitionSpec(*self, *other) + else: + raise NotImplementedError + + def __radd__(self, other): + if not isinstance(other, tuple): + raise NotImplementedError + # other will always be a tuple. + if self.unreduced: + raise TypeError( + f"other {other} cannot be of instance `tuple` when self {self} has" + " unreduced in `__radd__` of PartitionSpec.") + if self.reduced: + raise TypeError( + f"other {other} cannot be of instance `tuple` when self {self} has" + " reduced in `__radd__` of PartitionSpec.") + return PartitionSpec(*other, *self) def index(self, value): - value = _canonicalize_partition(value) - return super().index(value) + return self._partitions.index(_canonicalize_partition(value)) + + def count(self, value): + return self._partitions.count(_canonicalize_partition(value)) + + def update(self, **kwargs): + return PartitionSpec(*kwargs.pop("partitions", self._partitions), + unreduced=kwargs.pop("unreduced", self.unreduced), + reduced=kwargs.pop("reduced", self.reduced)) def _normalized_spec_for_aval(self, ndim: int) -> PartitionSpec: - out = [None if p is _UNCONSTRAINED_PARTITION else p for p in self] + out = [None if p is _UNCONSTRAINED_PARTITION else p + for p in self._partitions] if len(out) < ndim: out.extend([None] * (ndim - len(out))) - return PartitionSpec(*out) + return self.update(partitions=out) + + def _check_compatible_wrt_shape(self, shape): + if len(shape) < len(self._partitions): + extra_msg = (' For scalars the PartitionSpec should be P()' + if len(shape) == 0 else '') + raise ValueError( + f"PartitionSpec {self} is only valid for values of rank at least " + f"{len(self._partitions)}, but was applied to a value of rank " + f"{len(shape)}.{extra_msg}") + +PartitionSpec.__module__ = 'jax.sharding' + +P = PartitionSpec diff --git a/jax/_src/path.py b/jax/_src/path.py index 03a15e42e33a..9e07b51a9a49 100644 --- a/jax/_src/path.py +++ b/jax/_src/path.py @@ -43,3 +43,18 @@ def __call__(self, *pathsegments: str | os.PathLike) -> pathlib.Path: # https://github.com/google/etils/blob/2083f3d932a88d8a135ef57112cd1f9aff5d559e/etils/epath/abstract_path.py#L47 Path = epath.Path epath_installed = True + +def make_jax_dump_dir(out_dir_path: str) -> pathlib.Path | None: + """Make a directory or return the undeclared outputs directory if `sponge`.""" + if not out_dir_path: + return None + if out_dir_path == "sponge": + out_dir_path = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR", "") + if not out_dir_path: + raise ValueError( + "Got output directory (e.g., via JAX_DUMP_IR_TO) 'sponge' but" + " TEST_UNDECLARED_OUTPUTS_DIR is not defined." + ) + out_dir = Path(out_dir_path) + out_dir.mkdir(parents=True, exist_ok=True) + return out_dir diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index f7a4361ffee2..bbdcebb64bd3 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -16,14 +16,12 @@ from collections import defaultdict from collections.abc import Callable, Sequence, Iterable -import contextlib -import dataclasses +from dataclasses import dataclass, replace from functools import partial import inspect import logging -import operator as op import weakref -from typing import NamedTuple, Any, Union, cast +from typing import NamedTuple, Any, Union import warnings import numpy as np @@ -34,6 +32,7 @@ from jax._src import core from jax._src import dispatch from jax._src import dtypes +from jax._src import effects from jax._src import linear_util as lu from jax._src import mesh as mesh_lib from jax._src import op_shardings @@ -45,15 +44,12 @@ from jax._src import tree_util from jax._src import util from jax._src import xla_bridge as xb +from jax._src.core import typeof, cur_qdd from jax._src.api_util import ( - argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs, - donation_vector, check_callable, resolve_argnums, - argnames_partial_except, debug_info, - hoist_obj_attrs, _check_no_aliased_ref_args, - _check_no_aliased_closed_over_refs) + flatten_axes, donation_vector, check_callable, resolve_argnums, debug_info, + check_no_aliased_ref_args, _check_no_aliased_closed_over_refs) from jax._src.interpreters import partial_eval as pe from jax._src.partition_spec import PartitionSpec -from jax._src.interpreters import xla from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -68,18 +64,19 @@ NamedSharding, GSPMDSharding, SingleDeviceSharding, PmapSharding, AUTO, UNSPECIFIED, UnspecifiedValue, prepare_axis_resources, parse_flatten_op_sharding, canonicalize_sharding, - flatten_spec) -from jax._src.layout import Layout, DeviceLocalLayout, AutoLayout -from jax._src.state import discharge as state_discharge, RefEffect, AbstractRef + _internal_use_concrete_mesh) +from jax._src.layout import Format, Layout, AutoLayout, get_layout_for_vmap +from jax._src.state.types import RefEffect from jax._src.traceback_util import api_boundary from jax._src.tree_util import ( - tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure, tree_leaves, - treedef_children, broadcast_prefix, all_leaves, prefix_errors, keystr, - PyTreeDef, none_leaf_registry as none_lr, tree_map) + tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure, + treedef_children, prefix_errors, PyTreeDef, none_leaf_registry as none_lr, + tree_map, FlatTree) +from jax._src.typing import ArrayLike from jax._src.util import ( - HashableFunction, safe_map, safe_zip, wraps, tuple_insert, - distributed_debug_log, split_list, weakref_lru_cache, - merge_lists, subs_list, fun_name, fun_qual_name) + HashableFunction, safe_map, safe_zip, wraps, distributed_debug_log, + split_list, weakref_lru_cache, merge_lists, subs_list, fun_name) +from jax._src.lib import jaxlib_extension_version map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip @@ -94,48 +91,6 @@ logger = logging.getLogger(__name__) -def _find_arg_mismatch(arg_list, fails, fun_name): - mismatched_args_msg = [] - def mismatch(err): - for name, inp_da, aval in arg_list: - if err.m_type == pxla.MismatchType.ARG_SHARDING and err.da == inp_da: - mismatched_args_msg.append( - f"argument {name} of {fun_name} with shape {aval.str_short()} and " - f"{err._dev_ids_plat_str}") - break - first_err, second_err = fails - mismatch(first_err) - mismatch(second_err) - return mismatched_args_msg - - -def _device_assignment_mismatch_error(fun_name, fails, args_flat, api_name, - arg_names): - arg_list = [] - if arg_names is None: - arg_names = [''] * len(args_flat) - for a, n in zip(args_flat, arg_names): - da = (a.sharding._device_assignment - if getattr(a, 'sharding', None) is not None else None) - arg_list.append((n, da, core.shaped_abstractify(a))) - - mismatched_args_msg = _find_arg_mismatch(arg_list, fails, fun_name) - - if len(mismatched_args_msg) == 2: - first, second = mismatched_args_msg # pytype: disable=bad-unpacking - extra_msg = f" Got {first} and {second}" - elif len(mismatched_args_msg) == 1: - first, second = fails - # Choose the failure left which is not already covered by ARG_SHARDING. - left = second if first.m_type == pxla.MismatchType.ARG_SHARDING else first - extra_msg = f" Got {mismatched_args_msg[0]} and{left._str(api_name)}" - else: - first, second = fails - extra_msg = f" Got{first._str(api_name)} and{second._str(api_name)}" - msg = (f"Received incompatible devices for {api_name}ted computation.{extra_msg}") - return msg - - class PjitInfo(NamedTuple): """Things that we know about a jit instance before it is called. @@ -163,7 +118,6 @@ class PjitInfo(NamedTuple): backend: str | None keep_unused: bool inline: bool - abstracted_axes: Any | None use_resource_env: bool # False for jit, True for pjit compiler_options_kvs: tuple[tuple[str, Any], ...] @@ -175,34 +129,31 @@ def __eq__(self, other): return self is other -def _python_pjit_helper(fun: Callable, jit_info: PjitInfo, *args, **kwargs): - p, args_flat = _infer_params(fun, jit_info, args, kwargs) +def _run_python_pjit(p, args_flat, fun: Callable, jit_info: PjitInfo, args, kwargs): for arg in args_flat: dispatch.check_arg(arg) - if p.attrs_tracked: - init_states = _get_states(p.attrs_tracked) - args_flat = [*init_states, *args_flat] - try: - if (core.trace_state_clean() and - not config.debug_key_reuse.value and - not config.data_dependent_tracing_fallback.value): + if (core.trace_state_clean() and not config.debug_key_reuse.value + and not p.params['jaxpr'].jaxpr.is_high): args_flat = map(core.full_lower, args_flat) core.check_eval_args(args_flat) - out_flat, compiled, profiler = _pjit_call_impl_python(*args_flat, **p.params) + out_flat, compiled, profiler, const_args = _pjit_call_impl_python( + *args_flat, **p.params) else: - out_flat = pjit_p.bind(*args_flat, **p.params) + out_flat = jit_p.bind(*args_flat, **p.params) compiled = None profiler = None - except pxla.DeviceAssignmentMismatchError as e: + const_args = [] + except stages.DeviceAssignmentMismatchError as e: fails, = e.args fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun))) - msg = _device_assignment_mismatch_error( - fun_name, fails, args_flat, 'jit', p.arg_names) + arg_types = map(convert_to_metaty, args_flat) + msg = stages._device_assignment_mismatch_error( + fun_name, fails, arg_types, 'jit', p.arg_names) raise ValueError(msg) from None - except xla.InvalidInputException as e: + except dtypes.InvalidInputException as e: arg_names = [''] * len(args_flat) if p.arg_names is None else p.arg_names # Run canonicalization again to figure out which arg failed. if p.params['jaxpr'].consts: @@ -210,100 +161,67 @@ def _python_pjit_helper(fun: Callable, jit_info: PjitInfo, *args, **kwargs): else: for arg, name, aval in zip(args_flat, arg_names, p.in_avals): try: - xla.canonicalize_dtype(arg) - except xla.InvalidInputException as _: + dtypes.canonicalize_value(arg) + except dtypes.InvalidInputException as _: # Reraise as TypeError with the new message. raise TypeError( f"Argument '{name}' of shape {aval.str_short()} of type" f' {type(arg)} is not a valid JAX type.') from e raise AssertionError("Unreachable") from e - except dispatch.InternalFloatingPointError as e: + except api_util.InternalFloatingPointError as e: if getattr(fun, '_apply_primitive', False): raise FloatingPointError(f"invalid value ({e.ty}) encountered in {fun.__qualname__}") from None - dispatch.maybe_recursive_nan_check(e, fun, args, kwargs) - - if p.attrs_tracked: - num_states_out = sum(end_tree.num_leaves for _, end_tree, _ in p.attrs_tracked) - final_states, out_flat = split_list(out_flat, [num_states_out]) - _set_states(p.attrs_tracked, final_states) + api_util.maybe_recursive_nan_check(e, fun, args, kwargs) outs = tree_unflatten(p.out_tree, out_flat) - return (outs, out_flat, p.out_tree, args_flat, p.params['jaxpr'], - p.attrs_tracked, compiled, profiler) - - -def _set_states(attrs_tracked, vals): - from jax.experimental.attrs import jax_setattr - valss = split_list(vals, [td.num_leaves for _, td, _ in attrs_tracked[:-1]]) - for ((_, treedef, (obj, attr)), leaves) in zip(attrs_tracked, valss): - val = tree_unflatten(treedef, leaves) - jax_setattr(obj, attr, val) - -def _get_states(attrs_tracked): - from jax.experimental.attrs import jax_getattr - vals = [] - for treedef, _, (obj, attr) in attrs_tracked: - tree = jax_getattr(obj, attr) - leaves, treedef_ = tree_flatten(tree) - assert treedef == treedef_ - vals.extend(leaves) - return vals + return (outs, out_flat, p.out_tree, args_flat, + p.params['jaxpr'], compiled, profiler, const_args) + def _need_to_rebuild_with_fdo(pgle_profiler): return (pgle_profiler is not None and pgle_profiler.is_enabled() and not pgle_profiler.is_fdo_consumed()) def _get_fastpath_data( - executable, out_tree, args_flat, out_flat, attrs_tracked, effects, - consts, abstracted_axes, pgle_profiler -) -> pxla.MeshExecutableFastpathData | None: - out_reflattened, out_tree = pxla.reflatten_outputs_for_dispatch(out_tree, out_flat) - - use_fastpath = ( - executable is not None - and isinstance(executable, pxla.MeshExecutable) - and isinstance(executable.unsafe_call, pxla.ExecuteReplicated) + executable, out_tree, args_flat, out_flat, effects, consts_for_constvars, + pgle_profiler, const_args: Sequence[ArrayLike] + ) -> pxla.MeshExecutableFastpathData | None: + if ( + executable is None + or not isinstance(executable, pxla.MeshExecutable) + or not isinstance(executable.unsafe_call, pxla.ExecuteReplicated) # No effects in computation - and not executable.unsafe_call.ordered_effects - and not executable.unsafe_call.has_unordered_effects - and not executable.unsafe_call.has_host_callbacks - and all(isinstance(x, xc.ArrayImpl) for x in out_reflattened) - and abstracted_axes is None - # no attr state effects - and not attrs_tracked + or executable.unsafe_call.ordered_effects + or executable.unsafe_call.has_unordered_effects # no ref state effects - and not any(isinstance(e, RefEffect) for e in effects) + or any(isinstance(e, RefEffect) for e in effects) # no prng reuse checking - and not (config.debug_key_reuse.value and any( + or (config.debug_key_reuse.value and any( hasattr(arg, 'dtype') and dtypes.issubdtype(arg.dtype, dtypes.prng_key) - for arg in (*args_flat, *out_flat, *consts))) - and not _need_to_rebuild_with_fdo(pgle_profiler) - ) + for arg in (*args_flat, *out_flat, *consts_for_constvars))) + or _need_to_rebuild_with_fdo(pgle_profiler) + or config.no_execution.value + ): + return None - if use_fastpath: - out_avals = [o.aval for o in out_reflattened] - out_committed = [o._committed for o in out_reflattened] - kept_var_bitvec = [i in executable._kept_var_idx - for i in range(len(args_flat))] - in_shardings = [ - sharding_impls.physical_sharding(a, s) - if a is not core.abstract_token and dtypes.issubdtype(a.dtype, dtypes.extended) - else s - for s, a in zip(executable._in_shardings, executable.in_avals) - ] - fastpath_data = pxla.MeshExecutableFastpathData( - executable.xla_executable, out_tree, in_shardings, - executable._out_shardings, out_avals, out_committed, kept_var_bitvec, - executable._dispatch_in_layouts) - else: - fastpath_data = None - return fastpath_data - - -def _cpp_pjit_evict_fn(self): - self._clear_cache() - _create_pjit_jaxpr.evict_function(self._fun) # pytype: disable=attribute-error - _infer_params_cached.cache_clear() + out_reflattened, out_tree = pxla.reflatten_outputs_for_dispatch(out_tree, out_flat) + if not all(isinstance(x, xc.ArrayImpl) for x in out_reflattened): + return None + + out_avals = [o.aval for o in out_reflattened] + out_committed = [o._committed for o in out_reflattened] + kept_var_bitvec = [i in executable._kept_var_idx + for i in range(len(const_args) + len(args_flat))] + in_shardings = [ + sharding_impls.physical_sharding(a, s) + if a is not core.abstract_token and dtypes.issubdtype(a.dtype, dtypes.extended) + else s + for s, a in zip(executable._in_shardings, executable.in_avals) + ] + return pxla.MeshExecutableFastpathData( + executable.xla_executable, out_tree, in_shardings, + executable._out_shardings, out_avals, out_committed, kept_var_bitvec, + executable._dispatch_in_layouts, const_args) # The entries are doubled here from the default 4096 because _pjit_call_impl @@ -331,17 +249,19 @@ def _cpp_pjit(fun: Callable, jit_info: PjitInfo): @api_boundary def cache_miss(*args, **kwargs): + # args do not include the const args + # See https://docs.jax.dev/en/latest/internals/constants.html. if config.no_tracing.value: raise RuntimeError(f"re-tracing function {jit_info.fun_sourceinfo} for " "`jit`, but 'no_tracing' is set") - - (outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked, executable, - pgle_profiler) = _python_pjit_helper(fun, jit_info, *args, **kwargs) + p, args_flat = _trace_for_jit(fun, jit_info, args, kwargs) + (outs, out_flat, out_tree, args_flat, jaxpr, + executable, pgle_profiler, const_args) = _run_python_pjit( + p, args_flat, fun, jit_info, args, kwargs) maybe_fastpath_data = _get_fastpath_data( - executable, out_tree, args_flat, out_flat, attrs_tracked, jaxpr.effects, - jaxpr.consts, jit_info.abstracted_axes, - pgle_profiler) + executable, out_tree, args_flat, out_flat, jaxpr.effects, jaxpr.consts, + pgle_profiler, const_args) return outs, maybe_fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler) @@ -366,22 +286,49 @@ def cache_miss(*args, **kwargs): cpp_pjitted_f = wraps(fun)(cpp_pjit_f) cpp_pjitted_f._fun = fun - type(cpp_pjitted_f).clear_cache = _cpp_pjit_evict_fn + cpp_pjitted_f._jit_info = jit_info + cpp_jitted_f_class = type(cpp_pjitted_f) + cpp_jitted_f_class.clear_cache = jit_evict_fn + cpp_jitted_f_class.lower = jit_lower + cpp_jitted_f_class.trace = jit_trace + cpp_jitted_f_class.eval_shape = jit_eval_shape return cpp_pjitted_f +@api_boundary +def jit_trace(jit_func, *args, **kwargs) -> stages.Traced: + p, args_flat = _trace_for_jit(jit_func._fun, jit_func._jit_info, args, kwargs) + arg_types = map(convert_to_metaty, args_flat) + return stages.Traced(arg_types, p.params, p.in_tree, p.out_tree, p.consts) + +@api_boundary +def jit_lower(jit_func, *args, **kwargs): + return jit_trace(jit_func, *args, **kwargs).lower() + +@api_boundary +def jit_eval_shape(jit_func, *args, **kwargs): + return jit_trace(jit_func, *args, **kwargs).out_info + +def jit_evict_fn(self): + self._clear_cache() + if jaxlib_extension_version >= 392: + pe.trace_to_jaxpr.evict_weakref(self._fun) # cl/846898750 + else: + # This clears *all* jaxpr tracing caches, not just for `self`. + pe.trace_to_jaxpr.cache_clear() + def _split_layout_and_sharding(entries): entries_flat, treedef = tree_flatten(entries, is_leaf=lambda x: x is None) layouts, shardings = [], [] for e in entries_flat: - if isinstance(e, Layout): - layouts.append(e.device_local_layout) + if isinstance(e, Format): + layouts.append(e.layout) shardings.append(e.sharding) - elif isinstance(e, (DeviceLocalLayout, AutoLayout)): + elif isinstance(e, (Layout, AutoLayout)): raise ValueError( '`jax.jit` does not accept device-local layouts directly. Create ' - 'a `Layout` instance wrapping this device-local layout and pass ' + 'a `Format` instance wrapping this device-local layout and pass ' f'that to `jit` instead. Got {e}') else: layouts.append(None) @@ -391,30 +338,30 @@ def _split_layout_and_sharding(entries): return tree_unflatten(treedef, layouts), tree_unflatten(treedef, shardings) -def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, - donate_argnums: int | Sequence[int] | None, - donate_argnames: str | Iterable[str] | None, +def _parse_jit_arguments(fun: Callable, *, in_shardings: Any, + out_shardings: Any, static_argnums: int | Sequence[int] | None, static_argnames: str | Iterable[str] | None, - device: xc.Device | None, backend: str | None, - abstracted_axes: Any | None, keep_unused: bool, - inline: bool, compiler_options: dict[str, Any] | None, + donate_argnums: int | Sequence[int] | None, + donate_argnames: str | Iterable[str] | None, + keep_unused: bool, device: xc.Device | None, + backend: str | None, inline: bool, + compiler_options: dict[str, Any] | None, use_resource_env: bool) -> PjitInfo: """Parses the arguments to jit/pjit. Performs any preprocessing and validation of the arguments that we can do ahead of time before the jit()-ed function is invoked. """ - if abstracted_axes and not config.dynamic_shapes.value: - raise ValueError("abstracted_axes must be used with --jax_dynamic_shapes") - check_callable(fun) if backend is not None or device is not None: warnings.warn( 'backend and device argument on jit is deprecated. You can use' - ' `jax.device_put(..., jax.local_devices("cpu")[0])` on the inputs to' - ' the jitted function to get the same behavior.', DeprecationWarning) + ' `jax.device_put(..., jax.local_devices(backend="cpu")[0])` on the' + ' inputs to the jitted function to get the same behavior.', + DeprecationWarning, + ) if device is not None and backend is not None: raise ValueError("can't specify both a device and a backend for jit, " f"got {device=} and {backend=}") @@ -437,7 +384,8 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, out_layouts, out_shardings = _split_layout_and_sharding(out_shardings) in_shardings = prepare_axis_resources(in_shardings, 'in_shardings') - out_shardings = prepare_axis_resources(out_shardings, 'out_shardings') + out_shardings = prepare_axis_resources(out_shardings, 'out_shardings', + allow_unconstrained_dims=True) user_specified_in_shardings = (in_shardings is not None and not isinstance(in_shardings, UnspecifiedValue)) @@ -472,110 +420,83 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, static_argnames=static_argnames, donate_argnums=donate_argnums, donate_argnames=donate_argnames, device=device, backend=backend, keep_unused=keep_unused, inline=inline, - abstracted_axes=abstracted_axes, use_resource_env=use_resource_env, compiler_options_kvs=compiler_options_kvs) - -def _make_jit_wrapper(fun: Callable, jit_info: PjitInfo): - - @api_boundary - def lower(*args, **kwargs): - return trace(*args, **kwargs).lower() - - @api_boundary - def eval_shape(*args, **kwargs): - p, _ = _infer_params(fun, jit_info, args, kwargs) - out_s = [None if isinstance(s, UnspecifiedValue) else s for s in p.params['out_shardings']] - # TODO(yashkatariya): Add `Layout` to SDS. - out = [api.ShapeDtypeStruct(x.shape, x.dtype, sharding=s, - weak_type=x.weak_type) - for x, s in zip(p.params['jaxpr'].out_avals, out_s)] - return tree_unflatten(p.out_tree, out) - - @api_boundary - def trace(*args, **kwargs) -> stages.Traced: - p, args_flat = _infer_params(fun, jit_info, args, kwargs) - donate_argnums = tuple(i for i, d in enumerate(p.donated_invars) if d) - args_info = stages.make_args_info(p.in_tree, p.in_avals, donate_argnums) - lower_callable = partial(_resolve_and_lower, args_flat, **p.params, - pgle_profiler=None) - return stages.Traced( - p.params['jaxpr'], args_info, p.params["name"], p.out_tree, - lower_callable, args_flat, p.arg_names, p.num_consts) - - wrapped = _cpp_pjit(fun, jit_info) - wrapped.lower = lower - wrapped.eval_shape = eval_shape - wrapped.trace = trace - return wrapped - - -def make_jit(fun: Callable, in_shardings: Any, out_shardings: Any, - donate_argnums: int | Sequence[int] | None, - donate_argnames: str | Iterable[str] | None, +def make_jit(fun: Callable, + *, + in_shardings: Any, + out_shardings: Any, static_argnums: int | Sequence[int] | None, static_argnames: str | Iterable[str] | None, - device: xc.Device | None, backend: str | None, - abstracted_axes: Any | None, keep_unused: bool, - inline: bool, compiler_options: dict[str, Any] | None, + donate_argnums: int | Sequence[int] | None, + donate_argnames: str | Iterable[str] | None, + keep_unused: bool, + device: xc.Device | None, + backend: str | None, + inline: bool, + compiler_options: dict[str, Any] | None, use_resource_env: bool) -> Any: """jit() and pjit() are thin wrappers around this function.""" jit_info = _parse_jit_arguments( - fun, in_shardings, out_shardings, donate_argnums, donate_argnames, - static_argnums, static_argnames, device, backend, abstracted_axes, - keep_unused, inline, compiler_options, use_resource_env) - return _make_jit_wrapper(fun, jit_info) + fun, in_shardings=in_shardings, out_shardings=out_shardings, + static_argnums=static_argnums, static_argnames=static_argnames, + donate_argnums=donate_argnums, donate_argnames=donate_argnames, + keep_unused=keep_unused, device=device, backend=backend, inline=inline, + compiler_options=compiler_options, + use_resource_env=use_resource_env) + return _cpp_pjit(fun, jit_info) class PjitParams(NamedTuple): - consts: list[Any] # Only jaxpr constants, we can't keep other arguments alive + # Only jaxpr constants, we can't keep other arguments alive. These go as + # first arguments for `params['jaxpr']`. + consts: list[ArrayLike] # Corresponding to jaxpr.constvars + # Everything we need to trace, lower, and compile the jit function; passed + # to `pjit_call_impl_python`, along with the `args_flat` params: dict[str, Any] - in_avals: tuple[core.AbstractValue, ...] - in_tree: PyTreeDef + in_avals: tuple[core.AbstractValue, ...] # Not including the const_args + in_tree: PyTreeDef # Not including the const_args out_tree: PyTreeDef - donated_invars: tuple[bool, ...] - arg_names: tuple[str, ...] - num_consts: int - attrs_tracked: list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]] + arg_names: tuple[str, ...] # Not including the const_args -def _infer_params_impl( - fun: Callable, - ji: PjitInfo, - ctx_mesh: mesh_lib.Mesh | None, - dbg: core.DebugInfo, - args: tuple[Any, ...], - kwargs: dict[str, Any], - in_avals: tuple[core.AbstractValue, ...] | None, -) -> tuple[PjitParams, list[Any]]: - util.test_event("pjit._infer_params_impl", fun) +def _trace_for_jit( + fun: Callable, ji: PjitInfo, args: tuple[Any, ...], kwargs: dict[str, Any] + ) -> tuple[PjitParams, list[core.Value]]: + if ji.use_resource_env: # pjit + ctx_mesh = mesh_lib.thread_resources.env.physical_mesh + else: + ctx_mesh = mesh_lib.get_concrete_mesh() + dbg = debug_info( + 'jit', fun, args, kwargs, static_argnums=ji.static_argnums, + static_argnames=ji.static_argnames, sourceinfo=ji.fun_sourceinfo, + signature=ji.fun_signature) + + signature, dynargs = jax_jit.parse_arguments( + args, tuple(kwargs.values()), tuple(kwargs.keys()), ji.static_argnums, + ji.static_argnames, tree_util.default_registry) + avals_list = _infer_input_type(fun, dbg, dynargs) + args_ft = FlatTree.flatten_static_argnums_argnames( + args, kwargs, ji.static_argnums, ji.static_argnames) + # TODO(dougalm): args_ft.vals and dynargs should be exactly the same here. + # Why did we need to flatten in C++ and again in Python? + avals = args_ft.update(avals_list) + have_kwargs = bool(kwargs) if have_kwargs and ji.user_specified_in_shardings: raise ValueError( "pjit does not support kwargs when in_shardings is specified.") - if ctx_mesh is not None: - if (ji.backend or ji.device) and not ctx_mesh.empty: - raise ValueError( - "Mesh context manager should not be used with jit when backend or " - "device is also specified as an argument to jit.") - - axes_specs = _flat_axes_specs(ji.abstracted_axes, *args, **kwargs) - - f = lu.wrap_init(fun, debug_info=dbg) - f, dyn_args = argnums_partial_except(f, ji.static_argnums, args, allow_invalid=True) - del args - - f, dyn_kwargs = argnames_partial_except(f, ji.static_argnames, kwargs) - explicit_args, in_tree = tree_flatten((dyn_args, dyn_kwargs)) - flat_fun, out_tree = flatten_fun(f, in_tree) - flat_fun, explicit_args = hoist_obj_attrs(flat_fun, explicit_args) + if not ctx_mesh.empty and (ji.backend or ji.device): + raise ValueError( + "Mesh context manager should not be used with jit when backend or " + "device is also specified as an argument to jit.") if (ji.donate_argnums or ji.donate_argnames) and not config.debug_nans.value: - donated_invars = donation_vector(ji.donate_argnums, ji.donate_argnames, in_tree) + donated_invars = donation_vector(ji.donate_argnums, ji.donate_argnames, avals.tree) else: - donated_invars = (False,) * len(explicit_args) + donated_invars = (False,) * len(avals) # If backend or device is set as an arg on jit, then resolve them to # in_shardings and out_shardings as if user passed in in_shardings @@ -587,63 +508,77 @@ def _infer_params_impl( in_shardings_leaves = out_shardings_leaves = tuple(leaves) in_shardings_treedef = out_shardings_treedef = treedef else: + api_name = 'pjit' if ji.use_resource_env else 'jit' in_shardings_leaves = tuple( - _create_sharding_for_array(ctx_mesh, x, 'in_shardings', 'jit') + _create_sharding_for_array(ctx_mesh, x, 'in_shardings', api_name) for x in ji.in_shardings_leaves) - in_shardings_treedef = ji.in_shardings_treedef out_shardings_leaves = tuple( - _create_sharding_for_array(ctx_mesh, x, 'out_shardings', 'jit') + _create_sharding_for_array(ctx_mesh, x, 'out_shardings', api_name) for x in ji.out_shardings_leaves) + in_shardings_treedef = ji.in_shardings_treedef out_shardings_treedef = ji.out_shardings_treedef assert None not in in_shardings_leaves assert None not in out_shardings_leaves - in_type: core.InputType | tuple[core.AbstractValue, ...] - if config.dynamic_shapes.value: - assert in_avals is None - in_type = pe.infer_lambda_input_type(axes_specs, explicit_args) - in_avals = tuple(a for a, e in in_type if e) - else: - in_type = in_avals # type: ignore - assert in_avals is not None + in_type = avals.map2( + lambda a, x: core.AvalQDD(a, cur_qdd(x)) if a.has_qdd else a, # type: ignore + args_ft) + assert avals is not None in_shardings_flat, in_layouts_flat = _process_in_axis_resources( in_shardings_treedef, in_shardings_leaves, ji.in_layouts_treedef, ji.in_layouts_leaves, - in_avals, in_tree, flat_fun.debug_info, device_or_backend_set, have_kwargs) + avals, dbg, device_or_backend_set, have_kwargs) - attr_token = _attr_token(flat_fun, in_type) + qdd_token = _qdd_cache_index(fun, in_type.vals) # represents qdd state context - jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr( - flat_fun, in_type, attr_token, IgnoreKey(ji.inline)) + with dispatch.log_elapsed_time( + "Finished tracing + transforming {fun_name} for pjit in {elapsed_time:.9f} sec", + fun_name(fun), event=dispatch.JAXPR_TRACE_EVENT): + if ji.use_resource_env: # pjit + with (_internal_use_concrete_mesh(ctx_mesh), + mesh_lib.use_abstract_mesh(ctx_mesh.abstract_mesh)): + jaxpr, out_avals = pe.trace_to_jaxpr(fun, in_type, dbg, qdd_token) + else: + jaxpr, out_avals = pe.trace_to_jaxpr(fun, in_type, dbg, qdd_token) + + if config.debug_key_reuse.value: + # Import here to avoid circular imports + from jax.experimental.key_reuse._core import check_key_reuse_jaxpr # pytype: disable=import-error + check_key_reuse_jaxpr(jaxpr.jaxpr) + + result_paths = tuple(f"result{lu._clean_keystr_arg_names(path)}" + for path in out_avals.paths) + jaxpr.jaxpr._debug_info = jaxpr.debug_info._replace(result_paths=result_paths) + + # TODO(mattjj,yashkatariya): if we take the 'true' path then we *must* fall + # off the C++ dispatch fast path for correctness. Ensure that happens. + if any(isinstance(c, core.Tracer) or core.typeof(c).has_qdd for c in jaxpr.consts): + jaxpr, consts = pe.separate_consts(jaxpr) + else: + consts = [] if config.mutable_array_checks.value: - _check_no_aliased_closed_over_refs(dbg, (*jaxpr.consts, *consts), explicit_args) - _attr_update(flat_fun, in_type, attr_token, attrs_tracked) + _check_no_aliased_closed_over_refs(dbg, (*jaxpr.consts, *consts), dynargs) + _qdd_cache_update(fun, in_type.vals, qdd_token, consts, + jaxpr.in_aval_qdds[:len(consts)]) out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings( out_shardings_treedef, out_shardings_leaves, ji.out_layouts_treedef, - ji.out_layouts_leaves, HashableFunction(out_tree, closure=()), + ji.out_layouts_leaves, out_avals.tree, tuple(out_avals), jaxpr.jaxpr._debug_info, device_or_backend_set) - assert len(explicit_args) == len(in_shardings_flat) == len(in_layouts_flat) - - if config.dynamic_shapes.value: - implicit_args = _extract_implicit_args( - cast(core.InputType, in_type), explicit_args) - else: - implicit_args = [] - args_flat = [*implicit_args, *explicit_args] + assert len(dynargs) == len(in_shardings_flat) == len(in_layouts_flat) - num_states_in = sum(init_tree.num_leaves for init_tree, _, _ in attrs_tracked) - num_extra_args = len(implicit_args) + num_states_in + len(consts) + num_extra_args = len(consts) in_shardings_flat = (UNSPECIFIED,) * num_extra_args + in_shardings_flat in_layouts_flat = (None,) * num_extra_args + in_layouts_flat donated_invars = (False,) * num_extra_args + donated_invars assert (len(in_shardings_flat) == len(in_layouts_flat) == - len(donated_invars) == num_states_in + len(consts) + len(args_flat)) + len(donated_invars) == len(consts) + len(avals)) + name = getattr(fun, '__name__', '') params = dict( jaxpr=jaxpr, in_shardings=in_shardings_flat, @@ -652,75 +587,14 @@ def _infer_params_impl( out_layouts=out_layouts_flat, donated_invars=donated_invars, ctx_mesh=ctx_mesh, - name=fun_qual_name(flat_fun), + name=name, keep_unused=ji.keep_unused, inline=ji.inline, compiler_options_kvs=ji.compiler_options_kvs, ) - return PjitParams(consts, params, in_avals, in_tree, out_tree(), - donated_invars, dbg.arg_names, len(consts), - attrs_tracked), args_flat - - -class InferParamsCacheEntry: - """Mutable value object for _infer_params_cached.""" - __slots__ = ['pjit_params'] - pjit_params: PjitParams | None - def __init__(self): - self.pjit_params = None - - -# We use an outer cache that is keyed on the signature of the arguments, but -# when populating a cache entry using _infer_params_impl, we need to provide -# actual arguments. In principle we could refactor _infer_params_impl to look -# only at an argument signature instead of args/kwargs in those cases that we -# cache, but this was a more minimal change. -@util.weakref_lru_cache -def _infer_params_cached( - fun: Callable, - jit_info: PjitInfo, - signature: jax_jit.ArgumentSignature, - in_avals: tuple[core.AbstractValue, ...], - ctx_mesh: mesh_lib.Mesh | None, -) -> InferParamsCacheEntry: - return InferParamsCacheEntry() - - -def _infer_params( - fun: Callable, ji: PjitInfo, args: tuple[Any, ...], kwargs: dict[str, Any] - ) -> tuple[PjitParams, list[Any]]: - if ji.use_resource_env: - with sharding_impls.use_mesh(mesh_lib.thread_resources.env.physical_mesh): - return _infer_params_internal(fun, ji, args, kwargs) - return _infer_params_internal(fun, ji, args, kwargs) - -def _infer_params_internal( - fun: Callable, ji: PjitInfo, args: tuple[Any, ...], kwargs: dict[str, Any] - ) -> tuple[PjitParams, list[Any]]: - ctx_mesh = mesh_lib.get_concrete_mesh() - dbg = debug_info( - 'jit', fun, args, kwargs, static_argnums=ji.static_argnums, - static_argnames=ji.static_argnames, sourceinfo=ji.fun_sourceinfo, - signature=ji.fun_signature) - - if config.dynamic_shapes.value: # if dynamic shapes, don't use the cache - p, args_flat = _infer_params_impl(fun, ji, ctx_mesh, dbg, - args, kwargs, in_avals=None) - return p, p.consts + args_flat - - signature, dynargs = jax_jit.parse_arguments( - args, tuple(kwargs.values()), tuple(kwargs.keys()), ji.static_argnums, - ji.static_argnames, tree_util.default_registry) - avals = _infer_input_type(fun, dbg, dynargs) - entry = _infer_params_cached(fun, ji, signature, avals, ctx_mesh) - - if entry.pjit_params is None: - p, args_flat = _infer_params_impl( - fun, ji, ctx_mesh, dbg, args, kwargs, in_avals=avals) - if p.attrs_tracked: # if attrs, don't popoulate the cache - return p, p.consts + args_flat - entry.pjit_params = p - return entry.pjit_params, entry.pjit_params.consts + dynargs + p = PjitParams(consts, params, avals.vals, avals.tree_without_statics, + out_avals.tree, dbg.safe_arg_names(len(avals))) + return p, p.consts + dynargs def _infer_input_type(fun: Callable, dbg: core.DebugInfo, explicit_args) -> tuple[core.AbstractValue, ...]: @@ -729,64 +603,25 @@ def _infer_input_type(fun: Callable, dbg: core.DebugInfo, for i, x in enumerate(explicit_args): avals.append(core.shaped_abstractify(x)) except OverflowError: - arg_path = f"argument path is {dbg.arg_names[i]}" + arg_path = f"argument path is {dbg.arg_names[i] if dbg.arg_names is not None else 'unknown'}" # pytype: disable=name-error raise OverflowError( "An overflow was encountered while parsing an argument to a jitted " f"computation, whose {arg_path}." ) from None except TypeError: - arg_description = f"path {dbg.arg_names[i]}" + arg_description = f"path {dbg.arg_names[i] if dbg.arg_names is not None else 'unknown'}" # pytype: disable=name-error raise TypeError( f"Error interpreting argument to {fun} as an abstract array." - f" The problematic value is of type {type(x)} and was passed to" + f" The problematic value is of type {type(x)} and was passed to" # pytype: disable=name-error f" the function at {arg_description}.\n" "This typically means that a jit-wrapped function was called with a non-array" " argument, and this argument was not marked as static using the" " static_argnums or static_argnames parameters of jax.jit." ) from None if config.mutable_array_checks.value: - _check_no_aliased_ref_args(dbg, avals, explicit_args) + check_no_aliased_ref_args(lambda: dbg, avals, explicit_args) return tuple(avals) -def _extract_implicit_args( - in_type: Sequence[tuple[core.AbstractValue, bool]], - explicit_args: Sequence[Any] -) -> Sequence[core.Tracer]: - """ - Given an input type and explicitly-passed arguments (per the user-facing API - calling convention), extract implicit axis size arguments from shapes of - explicit arguments (for the trace-time / jaxpr-level calling convention). - """ - # First, using `in_type` construct a list to represent the full argument list, - # leaving the implicit arguments as None placeholders for now. - explicit_args_ = iter(explicit_args) - args = [next(explicit_args_) if expl else None for _, expl in in_type] - assert next(explicit_args_, None) is None - del explicit_args, explicit_args_ - - # Next, populate the implicit arguments using the DBIdxs in `in_type`. - for i, (aval, explicit) in enumerate(in_type): - if not explicit or not isinstance(aval, core.DShapedArray): - continue # can't populate an implicit argument - arg = args[i] - assert arg is not None - for d1, d2 in zip(aval.shape, arg.aval.shape): - if isinstance(d1, core.DBIdx): - if args[d1.val] is None: - args[d1.val] = d2 - assert core.same_referent(args[d1.val], d2) - assert all(x is not None for x in args) - return [x for x, (_, e) in zip(args, in_type) if not e] # type: ignore - -def _flat_axes_specs(abstracted_axes, *args, **kwargs - ) -> list[pe.AbstractedAxesSpec] | None: - if abstracted_axes is None: return None - if kwargs: raise NotImplementedError - def ax_leaf(l): - return (isinstance(l, dict) and all_leaves(l.values()) or - isinstance(l, tuple) and all_leaves(l, lambda x: x is None)) - return broadcast_prefix(abstracted_axes, args, ax_leaf) - class JitWrapped(stages.Wrapped): @@ -800,6 +635,7 @@ def trace(self, *args, **kwargs) -> stages.Traced: # in_shardings and out_shardings can't be None as the default value # because `None` means that the input is fully replicated. +@partial(api_boundary, repro_api_name="pjit.pjit") def pjit( fun: Callable, in_shardings: Any = UNSPECIFIED, @@ -812,181 +648,15 @@ def pjit( device: xc.Device | None = None, backend: str | None = None, inline: bool = False, - abstracted_axes: Any | None = None, compiler_options: dict[str, Any] | None = None, ) -> JitWrapped: - """Makes ``fun`` compiled and automatically partitioned across multiple devices. - - NOTE: This function is now equivalent to jax.jit please use that instead. - The returned function has semantics equivalent to those of ``fun``, but is - compiled to an XLA computation that runs across multiple devices - (e.g. multiple GPUs or multiple TPU cores). This can be useful if the jitted - version of ``fun`` would not fit in a single device's memory, or to speed up - ``fun`` by running each operation in parallel across multiple devices. - - The partitioning over devices happens automatically based on the - propagation of the input partitioning specified in ``in_shardings`` and - the output partitioning specified in ``out_shardings``. The resources - specified in those two arguments must refer to mesh axes, as defined by - the :py:func:`jax.sharding.Mesh` context manager. Note that the mesh - definition at :func:`~pjit` application time is ignored, and the returned function - will use the mesh definition available at each call site. - - Inputs to a :func:`~pjit`'d function will be automatically partitioned across devices - if they're not already correctly partitioned based on ``in_shardings``. - In some scenarios, ensuring that the inputs are already correctly pre-partitioned - can increase performance. For example, if passing the output of one - :func:`~pjit`'d function to another :func:`~pjit`’d function (or the same - :func:`~pjit`’d function in a loop), make sure the relevant - ``out_shardings`` match the corresponding ``in_shardings``. - - .. note:: - **Multi-process platforms:** On multi-process platforms such as TPU pods, - :func:`~pjit` can be used to run computations across all available devices across - processes. To achieve this, :func:`~pjit` is designed to be used in SPMD Python - programs, where every process is running the same Python code such that all - processes run the same :func:`~pjit`'d function in the same order. - - When running in this configuration, the mesh should contain devices across - all processes. All inputs arguments must be globally shaped. - ``fun`` will still be executed across *all* devices in the mesh, - including those from other processes, and will be given a global view of the - data spread across multiple processes as a single array. - - The SPMD model also requires that the same multi-process :func:`~pjit`'d - functions must be run in the same order on all processes, but they can be - interspersed with arbitrary operations running in a single process. - - Args: - fun: Function to be compiled. Should be a pure function, as side-effects may - only be executed once. Its arguments and return value should be arrays, - scalars, or (nested) standard Python containers (tuple/list/dict) thereof. - Positional arguments indicated by ``static_argnums`` can be anything at - all, provided they are hashable and have an equality operation defined. - Static arguments are included as part of a compilation cache key, which is - why hash and equality operators must be defined. - in_shardings: Pytree of structure matching that of arguments to ``fun``, - with all actual arguments replaced by resource assignment specifications. - It is also valid to specify a pytree prefix (e.g. one value in place of a - whole subtree), in which case the leaves get broadcast to all values in - that subtree. - - The ``in_shardings`` argument is optional. JAX will infer the shardings - from the input :py:class:`jax.Array`'s, and defaults to replicating the input - if the sharding cannot be inferred. - - The valid resource assignment specifications are: - - - :py:class:`Sharding`, which will decide how the value - will be partitioned. With this, using a mesh context manager is not - required. - - :py:obj:`None` is a special case whose semantics are: - - if the mesh context manager is *not* provided, JAX has the freedom to - choose whatever sharding it wants. - For in_shardings, JAX will mark is as replicated but this behavior - can change in the future. - For out_shardings, we will rely on the XLA GSPMD partitioner to - determine the output shardings. - - If the mesh context manager is provided, None will imply that the - value will be replicated on all devices of the mesh. - - For backwards compatibility, in_shardings still supports ingesting - :py:class:`PartitionSpec`. This option can *only* be used with the - mesh context manager. - - - :py:class:`PartitionSpec`, a tuple of length at most equal to the rank - of the partitioned value. Each element can be a :py:obj:`None`, a mesh - axis or a tuple of mesh axes, and specifies the set of resources assigned - to partition the value's dimension matching its position in the spec. - - The size of every dimension has to be a multiple of the total number of - resources assigned to it. - out_shardings: Like ``in_shardings``, but specifies resource - assignment for function outputs. - The ``out_shardings`` argument is optional. If not specified, :py:func:`jax.jit` - will use GSPMD's sharding propagation to determine how to shard the outputs. - static_argnums: An optional int or collection of ints that specify which - positional arguments to treat as static (compile-time constant). - Operations that only depend on static arguments will be constant-folded in - Python (during tracing), and so the corresponding argument values can be - any Python object. - - Static arguments should be hashable, meaning both ``__hash__`` and - ``__eq__`` are implemented, and immutable. Calling the jitted function - with different values for these constants will trigger recompilation. - Arguments that are not arrays or containers thereof must be marked as - static. - - If ``static_argnums`` is not provided, no arguments are treated as static. - static_argnames: An optional string or collection of strings specifying - which named arguments to treat as static (compile-time constant). See the - comment on ``static_argnums`` for details. If not - provided but ``static_argnums`` is set, the default is based on calling - ``inspect.signature(fun)`` to find corresponding named arguments. - donate_argnums: Specify which positional argument buffers are "donated" to - the computation. It is safe to donate argument buffers if you no longer - need them once the computation has finished. In some cases XLA can make - use of donated buffers to reduce the amount of memory needed to perform a - computation, for example recycling one of your input buffers to store a - result. You should not reuse buffers that you donate to a computation, JAX - will raise an error if you try to. By default, no argument buffers are - donated. - - If neither ``donate_argnums`` nor ``donate_argnames`` is provided, no - arguments are donated. If ``donate_argnums`` is not provided but - ``donate_argnames`` is, or vice versa, JAX uses - :code:`inspect.signature(fun)` to find any positional arguments that - correspond to ``donate_argnames`` - (or vice versa). If both ``donate_argnums`` and ``donate_argnames`` are - provided, ``inspect.signature`` is not used, and only actual - parameters listed in either ``donate_argnums`` or ``donate_argnames`` will - be donated. - - For more details on buffer donation see the - `FAQ `_. - donate_argnames: An optional string or collection of strings specifying - which named arguments are donated to the computation. See the - comment on ``donate_argnums`` for details. If not - provided but ``donate_argnums`` is set, the default is based on calling - ``inspect.signature(fun)`` to find corresponding named arguments. - keep_unused: If `False` (the default), arguments that JAX determines to be - unused by `fun` *may* be dropped from resulting compiled XLA executables. - Such arguments will not be transferred to the device nor provided to the - underlying executable. If `True`, unused arguments will not be pruned. - device: This argument is deprecated. Please put your arguments on the - device you want before passing them to jit. - Optional, the Device the jitted function will run on. (Available devices - can be retrieved via :py:func:`jax.devices`.) The default is inherited - from XLA's DeviceAssignment logic and is usually to use - ``jax.devices()[0]``. - backend: This argument is deprecated. Please put your arguments on the - backend you want before passing them to jit. - Optional, a string representing the XLA backend: ``'cpu'``, ``'gpu'``, or - ``'tpu'``. - - Returns: - A wrapped version of ``fun``, set up for just-in-time compilation and - automatically partitioned by the mesh available at each call site. - - For example, a convolution operator can be automatically partitioned over - an arbitrary set of devices by a single :func:`~pjit` application: - - >>> import jax - >>> import jax.numpy as jnp - >>> import numpy as np - >>> from jax.sharding import Mesh, PartitionSpec - >>> from jax.experimental.pjit import pjit - >>> - >>> x = jnp.arange(8, dtype=jnp.float32) - >>> f = pjit(lambda x: jax.numpy.convolve(x, jnp.asarray([0.5, 1.0, 0.5]), 'same'), - ... in_shardings=None, out_shardings=PartitionSpec('devices')) - >>> with Mesh(np.array(jax.devices()), ('devices',)): - ... print(f(x)) # doctest: +SKIP - [ 0.5 2. 4. 6. 8. 10. 12. 10. ] - """ + """`jax.experimental.pjit.pjit` has been deprecated. Please use `jax.jit`.""" return make_jit( - fun, in_shardings, out_shardings, donate_argnums, donate_argnames, - static_argnums, static_argnames, device, backend, abstracted_axes, - keep_unused, inline, compiler_options, use_resource_env=True) + fun, in_shardings=in_shardings, out_shardings=out_shardings, + static_argnums=static_argnums, static_argnames=static_argnames, + donate_argnums=donate_argnums, donate_argnames=donate_argnames, + keep_unused=keep_unused, device=device, backend=backend, inline=inline, + compiler_options=compiler_options, use_resource_env=True) def hashable_pytree(pytree): @@ -997,32 +667,20 @@ def hashable_pytree(pytree): def _create_sharding_for_array(mesh, x, name, api_name): - if x is None and (mesh is None or mesh.empty): - return UNSPECIFIED + if x is None: + if api_name == 'jit' or mesh.empty: + return UNSPECIFIED + return sharding_impls.cached_named_sharding(mesh, PartitionSpec()) if isinstance(x, (AUTO, UnspecifiedValue, Sharding)): return x - if mesh is None: - msg = ('jax.jit only supports `Sharding`s being passed to' - f' {name}. Looks like you are passing either `PartitionSpec` or `None`' - f' which is not allowed in jax.jit.\n') - if name == 'in_shardings': - msg += (f'Note that {name} argument is optional. JAX will infer the shardings' - " from the input jax.Array's and will default to replicating the" - ' input if the sharding cannot be inferred.') - elif name == 'out_shardings': - msg += (f'Note that {name} is optional. If not specified, jax.jit will' - " use GSPMD's sharding propagation to figure out what the sharding" - ' of the output(s) should be.') - raise RuntimeError(msg) if mesh.empty: raise RuntimeError( - f'{api_name} requires a non-empty mesh if you are passing' - f' `PartitionSpec`s or `None` to {name}! Is a mesh defined at the call' - f' site? Alternatively, provide `Sharding`s to {name} and' - ' then the mesh context manager is not required.') - # A nice user error is raised in prepare_axis_resources. - assert x is None or isinstance(x, PartitionSpec), x - return sharding_impls.create_mesh_pspec_sharding(mesh, x) + f'{api_name} requires a non-empty mesh in context if you are passing' + f' `PartitionSpec`s to {name}. You can define a context mesh via' + ' `jax.set_mesh(mesh)`. Alternatively, provide `Sharding`s to' + f' {name} and then the mesh context manager is not required.') + assert isinstance(x, PartitionSpec), x + return sharding_impls.cached_named_sharding(mesh, x) def _create_sharding_with_device_backend(device, backend): @@ -1055,9 +713,9 @@ def flatten_axis_resources(what, tree, shardings, tupled_args): # tuple, but while it is a non-leaf pytree, either it wasn't a tuple or it # wasn't the right length. msg = (f"{what} specification must be a tree prefix of the positional " - f"arguments tuple passed to the `pjit`-decorated function. In " - f"particular, {what} must either be a None, a PartitionSpec, or " - f"a tuple of length equal to the number of positional arguments.") + f"arguments tuple. In particular, {what} must either be a Sharding, " + "a PartitionSpec, or a tuple of length equal to the number of " + "positional arguments.") # If `tree` represents an args tuple, then `axis_resources` must be a tuple. # TODO(mattjj,apaszke): disable implicit list casts, remove 'or list' below if type(shardings) is not tuple: @@ -1099,10 +757,12 @@ def __repr__(self): return "pytree leaf" @util.cache(max_size=4096, trace_context_in_key=False) def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves, in_layouts_treedef, in_layouts_leaves, - in_avals, in_tree, debug_info: core.DebugInfo, + in_avals, dbg: core.DebugInfo, device_or_backend_set, kws): - if not kws: - in_tree, _ = treedef_children(in_tree) + if kws: + in_tree = in_avals.tree_without_statics + else: + in_tree, _ = treedef_children(in_avals.tree_without_statics) orig_in_shardings = tree_unflatten(in_shardings_treedef, in_shardings_leaves) # Only do this if original in_shardings are unspecified. If it is AUTO, go @@ -1120,190 +780,14 @@ def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves, in_layouts_flat = flatten_axis_resources( "pjit in_layouts", in_tree, in_layouts, tupled_args=True) - # TODO(dougalm,mattjj): enable debug info with attrs_tracked - attrs_tracked = len(debug_info.arg_names) != len(in_avals) - if not config.dynamic_shapes.value and not attrs_tracked: - pjit_check_aval_sharding(in_shardings_flat, in_avals, - debug_info.safe_arg_names(len(in_avals)), - "pjit arguments", allow_uneven_sharding=False) - check_aval_layout_compatibility( - in_layouts_flat, in_avals, - debug_info.safe_arg_names(len(in_avals)), "jit arguments") # type: ignore[arg-type] + pjit_check_aval_sharding(in_shardings_flat, in_avals, + dbg.safe_arg_names(len(in_avals)), + "pjit arguments", allow_uneven_sharding=False) + check_aval_layout_compatibility( + in_layouts_flat, in_avals, + dbg.safe_arg_names(len(in_avals)), "jit arguments") # type: ignore[arg-type] return in_shardings_flat, in_layouts_flat -callsites: set[str] = set() - -def explain_tracing_cache_miss( - fun: lu.WrappedFun, unseen_f: bool, cache: dict, key: tuple): - if config.check_tracer_leaks.value: return - - def unpack(key): - transforms, (), _, (in_type, _, inline), *_, ctx = key - # TODO(dougalm,mattjj): enable cache miss explanation with attrs - _, (_, (in_tree,)), *_ = transforms - return in_tree, in_type, inline.val, ctx - in_tree, in_type, inline, ctx = unpack(key) - if inline: return - - debug_info = fun.debug_info - msg: list[str] = [] - p = msg.append - done = lambda: logger.log(logging.WARNING, '\n'.join(msg)) - - callsite = source_info_util.summarize(source_info_util.current()) - p(f"TRACING CACHE MISS at {callsite} because:") - - # have we seen this function before at all? - fun_name = getattr(fun.f, '__qualname__', fun.f) - if debug_info.func_src_info: - # TODO(necula): clean up the extraction of the source info - _, *rest = debug_info.func_src_info.split(' at ') - src_info = " defined at " + ' '.join(rest) - else: - src_info = '' - if unseen_f: - p(f" never seen function:\n {fun_name} id={id(fun.f)}{src_info}") - if callsite in callsites: - p(" but seen another function defined on the same line; maybe the function is\n" - " being re-defined repeatedly, preventing caching?") - callsites.add(callsite) - return done() - else: - p(f" for {fun_name}{src_info}") - - seen_keys = map(unpack, cache.keys()) - - # have we maybe switched some args to be kwargs or visa-versa? - args_tree, kwargs_tree = treedef_children(in_tree) - args_kwargs_trees = [treedef_children(k) for k, *_ in seen_keys] - args_kwargs_match = [t for t in args_kwargs_trees - if t == [args_tree, kwargs_tree]] - if not args_kwargs_match: - num_args = len(treedef_children(args_tree)) - _, kwarg_keys = kwargs_tree.node_data() # type: ignore - p(f" never seen passing {num_args} positional args and {len(kwarg_keys)} " - "keyword args with keys:\n" - f" {', '.join(map(repr, kwarg_keys))}") - dont_match = [set(t[1].node_data()[1]) for t in args_kwargs_trees # type: ignore - if t != [args_tree, kwargs_tree]] - close_kwargs = min( - dont_match, key=set(kwarg_keys).symmetric_difference, default=None - ) - if not close_kwargs: - p(" closest seen is passing no keyword args") - else: - p(f" closest seen passes {len(close_kwargs)} keyword args with keys:\n" - f" {', '.join(map(repr, close_kwargs))}") - return done() - - # have we never seen this tracing context before? - ctxs_match = [c for *_, c in seen_keys if c == ctx] - if not ctxs_match: - p(" tracing context doesn't match, e.g. due to config or context manager") - dont_match = [c for *_, c in seen_keys if c != ctx] - closest_ctx = min(dont_match, key=lambda c: sum(map(op.ne, c, ctx))) - idxs = [i for i, (c1, c2) in enumerate(zip(ctx, closest_ctx)) if c1 != c2] - p(" closest seen context tuple differs at positions:\n" - f" {', '.join(map(str, idxs))}\n" - " compare to tuple returned by config._trace_context() in jax/_src/config.py.") - return done() - - # have we never seen this input pytree before? - trees_match = [k for k in seen_keys if k[0] == in_tree] - if not trees_match: - in_tree_str = f':\n {in_tree}' if len(str(in_tree)) < 76 else '' - p(f" never seen input pytree{in_tree_str}") - dont_match = [t for t, *_ in seen_keys if t != in_tree] - closest_tree = min(dont_match, key=lambda t: abs(t.num_leaves - in_tree.num_leaves)) - errs = list(tree_util.equality_errors_pytreedef(in_tree, closest_tree)) # type: ignore[arg-type] - p(f" closest seen input pytree has {len(errs)} mismatches, including:") - for path, thing1, thing2, explanation in errs: - fst, *path = path # type: ignore - base = ['args', 'kwargs'][fst.idx] - p(f" * at {base}{keystr(tuple(path))}, seen {thing2} but now given {thing1}," - f" so {explanation}") - return done() - - # have we never seen these input types (eg shapes, dtypes) before? - types_match = [k for k in trees_match if k[1] == in_type] - if not types_match: - if len(in_type) < 5: - in_type_str = ':\n {}'.format(', '.join( - f'{n}: {ty.str_short(short_dtypes=True)}' - for n, ty in zip(debug_info.arg_names, in_type))) - else: - in_type_str = '' - p(f" never seen input type signature{in_type_str}") - dont_match = [t for _, t, *_ in trees_match if t != in_type] - closest_ty = min(dont_match, key=lambda t: sum(map(op.ne, t, in_type))) - num_mismatch = sum(map(op.ne, closest_ty, in_type)) - p(f" closest seen input type signature has {num_mismatch} mismatches, including:") - add_weak_type_hint = False - arg_names = debug_info.safe_arg_names(len(in_type)) - - for name, ty1, ty2 in zip(arg_names, closest_ty, in_type): - if ty1 != ty2: - if type(ty1) == type(ty2) == core.ShapedArray: - s1, s2 = ty1.str_short(True), ty2.str_short(True) - if ty1.weak_type != ty2.weak_type: - s1 += f'{{weak_type={ty1.weak_type}}}' - s2 += f'{{weak_type={ty2.weak_type}}}' - add_weak_type_hint = True - elif ty1.sharding != ty2.sharding: - s1 = ty1.str_short(short_dtypes=True, mesh_axis_types=True) - s2 = ty2.str_short(short_dtypes=True, mesh_axis_types=True) - else: - s1, s2 = str(ty1), str(ty2) - p(f" * at {name}, seen {s1}, but now given {s2}") - if add_weak_type_hint: - p('where weak_type=True often means a Python builtin numeric value, and ') - p('weak_type=False means a jax.Array.') - p('See https://jax.readthedocs.io/en/latest/type_promotion.html#weak-types') - return done() - - # we think this is unreachable... - p("explanation unavailable! please open an issue at https://github.com/jax-ml/jax") - return done() - -@partial(lu.cache, explain=explain_tracing_cache_miss) -def _create_pjit_jaxpr( - fun: lu.WrappedFun, - in_type: core.InputType | Sequence[core.AbstractValue], - attr_data: int, - ignored_inline: IgnoreKey -) -> tuple[core.ClosedJaxpr, list[Any], list[core.AbstractValue], - list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]: - util.test_event("create_pjit_jaxpr") - del ignored_inline # just for explain_cache_miss - if config.no_tracing.value: - raise RuntimeError(f"re-tracing function {fun.f} for `jit`, but " - "'no_tracing' is set") - with dispatch.log_elapsed_time( - "Finished tracing + transforming {fun_name} for pjit in {elapsed_time:.9f} sec", - fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT): - if config.dynamic_shapes.value: - jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic2( - lu.annotate(fun, cast(core.InputType, in_type))) - attrs_tracked = [] - else: - jaxpr, global_out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic( - fun, in_type) - # assert attr_data is sentinel or attr_data matches attrs_tracked - - if config.debug_key_reuse.value: - # Import here to avoid circular imports - from jax.experimental.key_reuse._core import check_key_reuse_jaxpr - check_key_reuse_jaxpr(jaxpr) - - if any(isinstance(c, core.Tracer) for c in consts): - closed_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr)) - final_consts = consts - else: - closed_jaxpr = core.ClosedJaxpr(jaxpr, consts) - final_consts = [] - return closed_jaxpr, final_consts, global_out_avals, attrs_tracked - - @util.cache(max_size=4096, trace_context_in_key=False) def _check_and_canonicalize_out_shardings( out_shardings_treedef, out_shardings_leaves, out_layouts_treedef, @@ -1315,7 +799,7 @@ def _check_and_canonicalize_out_shardings( out_shardings_flat = (orig_out_shardings,) * len(out_avals) else: out_shardings_flat = flatten_axis_resources( - "pjit out_shardings", out_tree(), orig_out_shardings, + "pjit out_shardings", out_tree, orig_out_shardings, tupled_args=False) out_layouts = tree_unflatten(out_layouts_treedef, out_layouts_leaves) @@ -1323,60 +807,41 @@ def _check_and_canonicalize_out_shardings( out_layouts_flat = (out_layouts,) * len(out_avals) else: out_layouts_flat = flatten_axis_resources( - "pjit out_layouts", out_tree(), out_layouts, tupled_args=False) - - if not config.dynamic_shapes.value: - pjit_check_aval_sharding( - out_shardings_flat, out_avals, - debug_info.safe_result_paths(len(out_avals)), - "pjit outputs", allow_uneven_sharding=False) - check_aval_layout_compatibility( - out_layouts_flat, out_avals, - debug_info.safe_result_paths(len(out_avals)), - "jit outputs") - return out_shardings_flat, out_layouts_flat + "pjit out_layouts", out_tree, out_layouts, tupled_args=False) + pjit_check_aval_sharding( + out_shardings_flat, out_avals, + debug_info.safe_result_paths(len(out_avals)), + "pjit outputs", allow_uneven_sharding=False) + check_aval_layout_compatibility( + out_layouts_flat, out_avals, + debug_info.safe_result_paths(len(out_avals)), + "jit outputs") + return out_shardings_flat, out_layouts_flat -AttrRecord = tuple[object, str, PyTreeDef, list[core.AbstractValue]] -_seen_attrs = weakref.WeakKeyDictionary() # type: ignore +_seen_qdds = weakref.WeakKeyDictionary() # type: ignore -def seen_attrs_get( - fun: lu.WrappedFun, - in_type: core.InputType | tuple[core.AbstractValue, ...] -) -> list: - cache = _seen_attrs.setdefault(fun.f, defaultdict(list)) - assert fun.in_type is None or fun.in_type == in_type - return cache[(fun.transforms, fun.params, in_type)] +def _seen_qdds_get(fun, in_type) -> list: + cache = _seen_qdds.setdefault(fun, defaultdict(list)) + return cache[in_type] -def _attr_token( - fun: lu.WrappedFun, - in_type: core.InputType | tuple[core.AbstractValue, ...] -) -> int: - from jax.experimental.attrs import jax_getattr - cases = seen_attrs_get(fun, in_type) +def _qdd_cache_index(fun, in_type) -> int: + cases = _seen_qdds_get(fun, in_type) for i, records in enumerate(cases): - for obj, attr, treedef, avals in records: - val = jax_getattr(obj, attr) - vals, treedef_ = tree_flatten(val) - avals_ = map(core.shaped_abstractify, vals) - if treedef != treedef_ or avals != avals_: break + for obj, qdd in records: + if core.cur_qdd(obj) != qdd: break else: return i return len(cases) -def _attr_update(fun, in_type, i, attrs_tracked): - from jax.experimental.attrs import jax_getattr - leaves = lambda obj, attr: tree_leaves(jax_getattr(obj, attr)) - records = [(obj, attr, init_tree, map(core.shaped_abstractify, leaves(obj, attr))) - for init_tree, _, (obj, attr) in attrs_tracked] - cases = seen_attrs_get(fun, in_type) +def _qdd_cache_update(fun, in_type, i, consts, aval_qdds): + cases = _seen_qdds_get(fun, in_type) if i == len(cases): - cases.append(records) - else: - assert i < len(cases) and cases[i] == records + cases.append([(c, aval_qdd.qdd) for c, aval_qdd in zip(consts, aval_qdds) + if aval_qdd.has_qdd]) -@dataclasses.dataclass(frozen=True) +@dataclass(frozen=True) class IgnoreKey: val: Any def __hash__(self): @@ -1387,36 +852,33 @@ def __eq__(self, other): def pjit_check_aval_sharding( shardings, flat_avals, names: Sequence[str], - what_aval: str, allow_uneven_sharding: bool): + what_aval: str, allow_uneven_sharding: bool, + allow_partial_manual: bool = False): for aval, s, name in zip(flat_avals, shardings, names): if isinstance(s, (UnspecifiedValue, AUTO)): continue name_str = f' with pytree key path {name}' if name else '' shape = aval.shape try: - # Sharding interfaces can implement `check_compatible_aval` as an optional - # method to raise a more meaningful error. - if hasattr(s, 'check_compatible_aval'): - s.check_compatible_aval(shape) - else: - s._to_xla_hlo_sharding(len(shape)) + s.check_compatible_aval(shape) except ValueError as e: raise ValueError( f'One of {what_aval}{name_str} is incompatible with its sharding ' f'annotation {s}: {e}') - # Use the `OpSharding` proto to find out how many ways each dimension of - # the aval is sharded. This approach will work across all - # Sharding. - hlo_sharding = s._to_xla_hlo_sharding(len(shape)) - assert hlo_sharding is not None - num_ways_dim_sharded, _ = op_shardings.get_num_ways_dim_sharded(hlo_sharding) - for i, size in enumerate(num_ways_dim_sharded): - if not allow_uneven_sharding and shape[i] % size != 0: - raise ValueError(f"One of {what_aval}{name_str} was given the sharding " - f"of {s}, which implies that " - f"the global size of its dimension {i} should be " - f"divisible by {size}, but it is equal to {shape[i]} " - f"(full shape: {shape})") + + if not allow_uneven_sharding: + hlo_sharding = s._to_xla_hlo_sharding(len(shape)) + assert hlo_sharding is not None + num_ways_dim_sharded, _ = op_shardings.get_num_ways_dim_sharded( + hlo_sharding, allow_partial_manual) + for i, size in enumerate(num_ways_dim_sharded): + if shape[i] % size != 0: + raise ValueError( + f'One of {what_aval}{name_str} was given the sharding ' + f'of {s}, which implies that ' + f'the global size of its dimension {i} should be ' + f'divisible by {size}, but it is equal to {shape[i]} ' + f'(full shape: {shape})') def check_aval_layout_compatibility( @@ -1425,9 +887,8 @@ def check_aval_layout_compatibility( if l is None or isinstance(l, AutoLayout): continue name_str = f' with pytree key path {name}' if name else '' - shape = aval.shape try: - l.check_compatible_aval(shape) + l.check_compatible_aval(aval.shape) except ValueError as e: raise ValueError( f'One of {what_aval}{name_str} is incompatible with its layout ' @@ -1436,11 +897,74 @@ def check_aval_layout_compatibility( # -------------------- pjit rules -------------------- -pjit_p = core.Primitive("pjit") -pjit_p.multiple_results = True -pjit_p.skip_canonicalization = True - -def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals): +jit_p = core.Primitive("jit") +jit_p.is_effectful = lambda params: bool(params['jaxpr'].effects) # type: ignore +jit_p.multiple_results = True +jit_p.skip_canonicalization = True + +def _is_high(*_, jaxpr, **__) -> bool: + return jaxpr.jaxpr.is_high +jit_p.is_high = _is_high # type: ignore + +def _to_lojax(*hi_args, jaxpr, **params): + # convert closed-over boxes to explicit args + jaxpr, closed_over_himutables = pe.convert_const_himutables(jaxpr) + hi_args = [*closed_over_himutables, *hi_args] + params = _converted_mutables_add_params(len(closed_over_himutables), **params) + + # expand pjit params that must match number of lo inputs/outputs + lo_nums_in = [len(aval.lo_ty()) for aval in jaxpr.in_aval_qdds] + lo_nums_out = [len(t.lo_ty()) for t in jaxpr.out_avals] + lo_muts_out = pe.num_himuts_out(jaxpr) + params = _lojax_expand_params(lo_nums_in, lo_nums_out, lo_muts_out, **params) + + # collect lo input values + lo_args = [lo_val for aval, x in zip(jaxpr.in_aval_qdds, hi_args) + for lo_val in (aval.read_loval(x) if aval.has_qdd + else aval.lower_val(x))] + + # lower the jaxpr and bind it using lo input values + lo_jaxpr = pe.lower_jaxpr(jaxpr) + all_outs = jit_p.bind(*lo_args, jaxpr=lo_jaxpr, **params) + out_mut, lo_outs = split_list(all_outs, [lo_muts_out]) + pe.apply_himut(jaxpr, hi_args, out_mut) + return pe.raise_lo_outs(jaxpr.out_avals, lo_outs) +jit_p.to_lojax = _to_lojax + +def _converted_mutables_add_params( + n, *, donated_invars, in_shardings, in_layouts, **params): + donated_invars = (False,) * n + donated_invars + in_shardings = (UNSPECIFIED,) * n + in_shardings + in_layouts = (None,) * n + in_layouts + return dict(params, donated_invars=donated_invars, in_shardings=in_shardings, + in_layouts=in_layouts) + + +def _lojax_expand_params( + nums_in, nums_out, muts_out, *, donated_invars, in_shardings, in_layouts, + out_shardings, out_layouts, **params): + # some pjit params match the length of hi_jaxpr.invars/outvars, so when + # lowering we must expand them to match their number of lojax types + def expand(ns, xs): + return tuple(y for n, x in zip(ns, xs) for y in (x,) * n) + donated_invars = expand(nums_in , donated_invars) + in_shardings = expand(nums_in , in_shardings ) + in_layouts = expand(nums_in , in_layouts ) + out_shardings = expand(nums_out, out_shardings ) + out_layouts = expand(nums_out, out_layouts ) + + # also, the lo_jaxpr has pure outputs corresponding to mutable hi_jaxpr types + out_shardings = (UNSPECIFIED,) * muts_out + out_shardings + out_layouts = (None,) * muts_out + out_layouts + + new_params = dict(params, donated_invars=donated_invars, + in_shardings=in_shardings, in_layouts=in_layouts, + out_shardings=out_shardings, out_layouts=out_layouts) + return new_params + + +def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, + in_avals) -> Sequence[Layout | AutoLayout | None]: # If device or backend is set, return the default layout. This is because you # can pass arrays on cpu (with untiled layouts) to jit with backend='tpu' # which causes error checks to fail. Returning the default layout allows @@ -1448,23 +972,23 @@ def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals): if pxla.check_device_backend_on_shardings(resolved_in_shardings): return (None,) * len(jit_in_layouts) - resolved_in_layouts = [] + resolved_in_layouts: list[Layout | AutoLayout | None] = [] for arg, jit_in_l, rs, aval in safe_zip( args, jit_in_layouts, resolved_in_shardings, in_avals): - committed = getattr(arg, '_committed', True) + committed = arg.committed # `arg_layout` is only used for checking purposes in the `else` branch # below. We cannot replace default layout with None to raise nicer errors. # `dispatch_arg_layout` replaces default layouts with `None` to simplify # dispatch and lowering logic downstream. - if hasattr(arg, 'layout'): - arg_layout = arg.layout.device_local_layout + if arg.format is not None: + arg_layout = arg.format.layout dispatch_arg_layout = (None if pxla.is_default_layout(arg_layout, rs, aval) else arg_layout) else: arg_layout, dispatch_arg_layout = None, None # Sharding can be unspecified when array is committed if it's a PmapSharding. is_pmap_sharding = (isinstance(rs, UnspecifiedValue) or - isinstance(getattr(arg, 'sharding', None), PmapSharding)) + isinstance(arg.sharding, PmapSharding)) if jit_in_l is None: if committed: if is_pmap_sharding: @@ -1475,8 +999,8 @@ def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals): resolved_in_layouts.append(None) else: # arg_layout can be None because some backends don't implement the - # required layout methods. Hence `arr.layout` can return - # `Layout(None, sharding)` + # required layout methods. Hence `arr.format` can return + # `Format(None, sharding)` if (committed and not is_pmap_sharding and arg_layout is not None @@ -1484,19 +1008,18 @@ def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals): extra_msg = '' if isinstance(jit_in_l, AutoLayout): extra_msg = ( - ' The layout given to `jax.jit` is `DeviceLocalLayout.AUTO` but' + ' The layout given to `jax.jit` is `Layout.AUTO` but' ' the corresponding argument passed is a `jax.Array` with a' ' concrete layout. Consider passing a `jax.ShapeDtypeStruct`' ' instead of `jax.Array` as an argument to the jitted function ' - ' when using `DeviceLocalLayout.AUTO`.' + ' when using `Layout.AUTO`.' ) raise ValueError('Layout passed to jit does not match the layout ' 'on the respective arg. ' - f'Got pjit layout: {jit_in_l},\n' - f'arg layout: {arg_layout} for ' - f'arg shape: {core.shaped_abstractify(arg).str_short()}.' + f'Got jit layout: {jit_in_l},\n' + f'arg layout: {arg_layout} for arg type: {arg.aval}.' f'{extra_msg}') - jit_in_l = (None if isinstance(jit_in_l, DeviceLocalLayout) and + jit_in_l = (None if isinstance(jit_in_l, Layout) and pxla.is_default_layout(jit_in_l, rs, aval) else jit_in_l) resolved_in_layouts.append(jit_in_l) return tuple(resolved_in_layouts) @@ -1506,13 +1029,27 @@ def _resolve_out_layouts(out_layouts, out_shardings, out_avals): for out_l, out_s, out_aval in safe_zip(out_layouts, out_shardings, out_avals): if out_l is None: new_out_layouts.append(None) - elif (isinstance(out_l, DeviceLocalLayout) and + elif (isinstance(out_l, Layout) and pxla.is_default_layout(out_l, out_s, out_aval)): new_out_layouts.append(None) else: new_out_layouts.append(out_l) return tuple(new_out_layouts) +def finalize_arg_sharding(arg_s, committed): + if isinstance(arg_s, UnspecifiedValue): + return arg_s + else: + if committed: + # If the arg has a PmapSharding, then reshard it unconditionally. + return UNSPECIFIED if isinstance(arg_s, PmapSharding) else arg_s + else: + assert isinstance(arg_s, Sharding) + if dispatch.is_single_device_sharding(arg_s): + return UNSPECIFIED + raise NotImplementedError('Having uncommitted Array sharded on ' + 'multiple devices is not supported.') + def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding] ) -> Sequence[PjitSharding]: # If True, means that device or backend is set by the user on pjit and it @@ -1523,63 +1060,29 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding] if pxla.check_device_backend_on_shardings(pjit_in_shardings): return pjit_in_shardings - committed_arg_shardings = [] - for a in args: - arg_s = getattr(a, 'sharding', None) - # arg sharding can be None in case of ShapeDtypeStruct. jax.Array does - # not allow None as the sharding. - if arg_s is None: - continue - # Don't consider PmapSharding inputs as committed. They will get resharded - # unconditionally. - if isinstance(arg_s, PmapSharding): - continue - if getattr(a, '_committed', True): - committed_arg_shardings.append((arg_s, pxla.MismatchType.ARG_SHARDING, None)) - resolved_in_shardings: list[PjitSharding] = [] for arg, pjit_in_s in zip(args, pjit_in_shardings): # arg sharding can be None in case of ShapeDtypeStruct. jax.Array does # not allow None as the sharding. - arg_s, committed = ((arg.sharding, getattr(arg, '_committed', True)) - if hasattr(arg, 'sharding') and arg.sharding is not None + arg_s, committed = ((arg.sharding, arg.committed) if arg.sharding is not None else (UNSPECIFIED, False)) if isinstance(arg_s, NamedSharding) and arg_s.mesh.empty: arg_s, committed = UNSPECIFIED, False if isinstance(pjit_in_s, UnspecifiedValue): - if isinstance(arg_s, UnspecifiedValue): - resolved_in_shardings.append(arg_s) - else: - if committed: - # If the arg has a PmapSharding, then reshard it unconditionally. - if isinstance(arg_s, PmapSharding): - resolved_in_shardings.append(UNSPECIFIED) - else: - resolved_in_shardings.append(arg_s) - else: - assert isinstance(arg_s, Sharding) - if dispatch.is_single_device_sharding(arg_s): - resolved_in_shardings.append(UNSPECIFIED) - else: - raise NotImplementedError('Having uncommitted Array sharded on ' - 'multiple devices is not supported.') + resolved_in_shardings.append(finalize_arg_sharding(arg_s, committed)) else: - if (isinstance(arg, np.ndarray) and - not pjit_in_s.is_fully_replicated and # type: ignore[union-attr] + if (arg.is_np_array and not pjit_in_s.is_fully_replicated and # type: ignore[union-attr] xb.process_count() > 1): raise ValueError( 'Passing non-trivial shardings for numpy ' 'inputs is not allowed. To fix this error, either specify a ' 'replicated sharding explicitly or use ' - '`jax.experimental.multihost_utils.host_local_array_to_global_array(...)` ' + '`jax.make_array_from_process_local_data(...)` ' 'to convert your host local numpy inputs to a jax.Array which you ' - 'can pass to pjit. ' + 'can pass to jit. ' 'If the numpy input is the same on each process, then you can use ' '`jax.make_array_from_callback(...) to create a `jax.Array` which ' - 'you can pass to pjit. ' - 'Please see the jax.Array migration guide for more information ' - 'https://jax.readthedocs.io/en/latest/jax_array_migration.html#handling-of-host-local-inputs-to-pjit-like-batch-etc. ' - f'Got arg shape: {arg.shape}, arg value: {arg}') + f'you can pass to jit. Got arg type: {arg.aval}') if not isinstance(arg_s, UnspecifiedValue) and arg_s._is_concrete: # jax.jit does not allow resharding across different memory kinds even # if the argument is uncommitted. Use jax.device_put for those cases, @@ -1587,29 +1090,27 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding] if pjit_in_s.memory_kind != arg_s.memory_kind: # type: ignore[union-attr] raise ValueError( 'Memory kinds passed to jax.jit does not match memory kind on the' - f' respective arg. Got pjit memory kind: {pjit_in_s.memory_kind}, ' # type: ignore[union-attr] - f'arg memory kind: {arg_s.memory_kind} for ' - f'arg shape: {core.shaped_abstractify(arg).str_short()}') + f' respective arg. Got jit memory kind: {pjit_in_s.memory_kind}, ' # type: ignore[union-attr] + f'arg memory kind: {arg_s.memory_kind} for arg type: {arg.aval}') if (committed and not isinstance(arg_s, PmapSharding) and - not op_shardings.are_op_shardings_equal( + not op_shardings.are_hlo_shardings_equal( pjit_in_s._to_xla_hlo_sharding(arg.ndim), # type: ignore[union-attr] arg_s._to_xla_hlo_sharding(arg.ndim))): - raise ValueError('Sharding passed to pjit does not match the sharding ' + raise ValueError('Sharding passed to jit does not match the sharding ' 'on the respective arg. ' - f'Got pjit sharding: {pjit_in_s},\n' - f'arg sharding: {arg_s} for ' - f'arg shape: {core.shaped_abstractify(arg).str_short()}') + f'Got jit sharding: {pjit_in_s},\n' + f'arg sharding: {arg_s} for arg type: {arg.aval}') resolved_in_shardings.append(pjit_in_s) return tuple(resolved_in_shardings) def _resolve_and_lower( - args, jaxpr, in_shardings, out_shardings, in_layouts, + args, jaxpr: core.ClosedJaxpr, in_shardings, out_shardings, in_layouts, out_layouts, donated_invars, ctx_mesh, name, keep_unused, inline, lowering_platforms, lowering_parameters, pgle_profiler, - compiler_options_kvs): + compiler_options_kvs) -> pxla.MeshComputation: in_shardings = _resolve_in_shardings(args, in_shardings) in_layouts = _resolve_in_layouts(args, in_layouts, in_shardings, jaxpr.in_avals) @@ -1623,10 +1124,50 @@ def _resolve_and_lower( _pgle_profiler_dict = weakref.WeakKeyDictionary() # type: ignore + +@dataclass(frozen=True) +class MetaTy: + aval: Any + sharding: Any + format: Any + committed: bool + is_np_array: bool + + replace = replace # type: ignore + + @property + def shape(self): + return self.aval.shape + + @property + def ndim(self): + return self.aval.ndim + +@util.cache(max_size=4096, trace_context_in_key=False) +def create_meta_ty(aval, arg_sharding, arg_format, arg_committed, is_np_array): + return MetaTy(aval, arg_sharding, arg_format, arg_committed, is_np_array) + +def convert_to_metaty(arg): + # TODO(yashkatariya): Remove this Tracer special case after + # getattr(Tracer, 'sharding') is fast. + if isinstance(arg, core.Tracer): + return create_meta_ty(arg.aval, None, None, True, False) + aval = core.shaped_abstractify(arg) + arg_sharding = getattr(arg, 'sharding', None) + arg_format = getattr(arg, 'format', None) + arg_committed = getattr(arg, '_committed', True) + is_np_array = isinstance(arg, np.ndarray) + return create_meta_ty(aval, arg_sharding, arg_format, arg_committed, + is_np_array) + + def _pjit_call_impl_python( - *args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, + *args, + jaxpr: core.ClosedJaxpr, + in_shardings, out_shardings, in_layouts, out_layouts, donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs): + util.test_event("jit_cpp_cache_miss") pgle_compile_options, pgle_profiler = {}, None if config.enable_pgle.value and config.pgle_profiling_runs.value > 0: compilation_target_key = jaxpr @@ -1647,8 +1188,9 @@ def _pjit_call_impl_python( compiler_options_kvs = compiler_options_kvs + tuple(pgle_compile_options.items()) # Passing mutable PGLE profile here since it should be extracted by JAXPR to # initialize the fdo_profile compile option. - compiled = _resolve_and_lower( - args, jaxpr=jaxpr, in_shardings=in_shardings, + arg_types = map(convert_to_metaty, args) + computation = _resolve_and_lower( + arg_types, jaxpr=jaxpr, in_shardings=in_shardings, out_shardings=out_shardings, in_layouts=in_layouts, out_layouts=out_layouts, donated_invars=donated_invars, ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused, @@ -1656,12 +1198,13 @@ def _pjit_call_impl_python( lowering_parameters=mlir.LoweringParameters(), pgle_profiler=pgle_profiler, compiler_options_kvs=compiler_options_kvs, - ).compile() + ) + compiled = computation.compile() # This check is expensive so only do it if enable_checks is on. if compiled._auto_spmd_lowering and config.enable_checks.value: pxla.check_array_xla_sharding_layout_match( - args, compiled._in_shardings, compiled._in_layouts, + args, compiled._in_shardings, compiled._in_layouts, # type: ignore jaxpr.jaxpr._debug_info, compiled._kept_var_idx) if config.distributed_debug.value: # Defensively only perform fingerprint logic if debug logging is enabled @@ -1678,7 +1221,8 @@ def _pjit_call_impl_python( ("out_layouts", out_layouts), ("abstract args", map(core.abstractify, args)), ("fingerprint", fingerprint)) - return compiled.unsafe_call(*args), compiled, pgle_profiler + return (compiled.unsafe_call(*computation.const_args, *args), + compiled, pgle_profiler, computation.const_args) @weakref_lru_cache def _get_jaxpr_as_fun(jaxpr, in_shardings, out_shardings, in_layouts, @@ -1694,20 +1238,24 @@ def _get_jaxpr_as_fun(jaxpr, in_shardings, out_shardings, in_layouts, return lambda *args: core.jaxpr_as_fun(jaxpr())(*args) # pylint: disable=unnecessary-lambda -def _pjit_call_impl(*args, jaxpr, +def _pjit_call_impl(*args, jaxpr: core.ClosedJaxpr, in_shardings, out_shardings, in_layouts, out_layouts, donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs): def call_impl_cache_miss(*args_, **kwargs_): - out_flat, compiled, pgle_profiler = _pjit_call_impl_python( + # args_ do not include the const args + # See https://docs.jax.dev/en/latest/internals/constants.html. + # TODO(necula): remove num_const_args when fixing the C++ path + out_flat, compiled, pgle_profiler, const_args = _pjit_call_impl_python( *args, jaxpr=jaxpr, in_shardings=in_shardings, out_shardings=out_shardings, in_layouts=in_layouts, out_layouts=out_layouts, donated_invars=donated_invars, ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused, inline=inline, compiler_options_kvs=compiler_options_kvs) fastpath_data = _get_fastpath_data( - compiled, tree_structure(out_flat), args, out_flat, [], jaxpr.effects, - jaxpr.consts, None, pgle_profiler) + compiled, tree_structure(out_flat), args, out_flat, + jaxpr.effects, jaxpr.consts, pgle_profiler, + const_args) return out_flat, fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler) f = _get_jaxpr_as_fun( @@ -1727,9 +1275,10 @@ def call_impl_cache_miss(*args_, **kwargs_): tree_util.dispatch_registry, pxla.cc_shard_arg, _get_cpp_global_cache(cache_key.contains_explicit_attributes))(*args) -pjit_p.def_impl(_pjit_call_impl) - +jit_p.def_impl(_pjit_call_impl) +# This cache is important for python dispatch performance. +@weakref_lru_cache def _pjit_lower( jaxpr: core.ClosedJaxpr, in_shardings, @@ -1745,8 +1294,7 @@ def _pjit_lower( *, lowering_platforms: tuple[str, ...] | None, lowering_parameters: mlir.LoweringParameters, - pgle_profiler: profiler.PGLEProfiler | None): - util.test_event("pjit_lower") + pgle_profiler: profiler.PGLEProfiler | None) -> pxla.MeshComputation: return pxla.lower_sharding_computation( jaxpr, 'jit', name, in_shardings, out_shardings, in_layouts, out_layouts, tuple(donated_invars), @@ -1757,7 +1305,13 @@ def _pjit_lower( pgle_profiler=pgle_profiler) -def pjit_staging_rule(trace, *args, **params): +def pjit_staging_rule(trace, source_info, *args, **params): + if params["compiler_options_kvs"]: + raise ValueError( + '`compiler_options` can only be passed to top-level `jax.jit`. Got' + f' compiler_options={dict(params["compiler_options_kvs"])} specified on' + f' a nested jit with name: {params["name"]} and source info:' + f' {source_info_util.summarize(source_info)}') # If we're inlining, no need to compute forwarding information; the inlined # computation will in effect forward things. if (params["inline"] and @@ -1766,112 +1320,55 @@ def pjit_staging_rule(trace, *args, **params): all(i is None for i in params["in_layouts"]) and all(o is None for o in params["out_layouts"])): jaxpr = params["jaxpr"] - if config.dynamic_shapes.value: - # Inline jaxpr doesn't handle dynamic shapes when inlining. If dynamic - # shapes are enabled, use eval_jaxpr, which uses the tracing machinery, - # but redundantly performs abstract evaluation again. - with core.set_current_trace(trace): - return core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args, - propagate_source_info=False) - else: - return pe.inline_jaxpr_into_trace( - trace, jaxpr.jaxpr, jaxpr.consts, *args) - - jaxpr, in_fwd, out_shardings, out_layouts = _pjit_forwarding( - params['jaxpr'], params['out_shardings'], params['out_layouts']) - params = dict(params, jaxpr=jaxpr, out_shardings=out_shardings, - out_layouts=out_layouts) - if config.dynamic_shapes.value: - source_info = source_info_util.current() - out_tracers = [] - for aval in _out_type(jaxpr): - if type(aval) is core.DShapedArray: - shape = [args[d.val] if type(d) is core.InDBIdx else - out_tracers[d.val] if type(d) is core.OutDBIdx else - d for d in aval.shape] - aval = aval.update(shape=tuple(core.get_referent(d) for d in shape)) - out_tracers.append(pe.DynamicJaxprTracer(trace, aval, source_info)) - eqn = core.new_jaxpr_eqn( - map(trace.getvar, args), map(trace.makevar, out_tracers), pjit_p, params, - jaxpr.effects, source_info) - trace.frame.add_eqn(eqn) - elif any(isinstance(c, core.MutableArray) for c in jaxpr.consts): + out = pe.inline_jaxpr_into_trace( + trace, source_info, jaxpr.jaxpr, jaxpr.consts, *args) + return [trace.to_jaxpr_tracer(x, source_info) for x in out] + + jaxpr = params['jaxpr'] + if any(isinstance(c, core.Ref) for c in jaxpr.consts): jaxpr, consts = pxla._move_mutable_consts(jaxpr) - consts = map(trace.new_const, consts) + consts = [trace.new_const(c, source_info) for c in consts] in_shardings = (*params['in_shardings'],) + (UNSPECIFIED,) * len(consts) in_layouts = (*params['in_layouts'],) + (None,) * len(consts) donated_invars = (*params['donated_invars'],) + (False,) * len(consts) new_params = dict(params, jaxpr=jaxpr, in_shardings=in_shardings, in_layouts=in_layouts, donated_invars=donated_invars) out_tracers = trace.default_process_primitive( - pjit_p, (*args, *consts), new_params) + jit_p, (*args, *consts), new_params, source_info=source_info) else: - out_tracers = trace.default_process_primitive(pjit_p, args, params) - - out_tracers_ = iter(out_tracers) - out_tracers = [args[f] if type(f) is int else next(out_tracers_) - for f in in_fwd] - assert next(out_tracers_, None) is None + out_tracers = trace.default_process_primitive( + jit_p, args, params, source_info=source_info) return out_tracers -pe.custom_staging_rules[pjit_p] = pjit_staging_rule +pe.custom_staging_rules[jit_p] = pjit_staging_rule -def _pjit_forwarding(jaxpr, out_shardings, out_layouts): - in_fwd: list[int | None] = pe._jaxpr_forwarding(jaxpr.jaxpr) - in_fwd = [fwd if isinstance(os, UnspecifiedValue) and ol is None else None for fwd, os, ol - in zip(in_fwd, out_shardings, out_layouts)] - keep = [f is None for f in in_fwd] - jaxpr = pe.prune_closed_jaxpr_outputs(jaxpr, keep) - out_shardings = [o for o, k in zip(out_shardings, keep) if k] - out_layouts = [o for o, k in zip(out_layouts , keep) if k] - return jaxpr, in_fwd, out_shardings, out_layouts - def pjit_forwarding_rule(eqn): - jaxpr, in_fwd, out_shardings, out_layouts = _pjit_forwarding( - eqn.params['jaxpr'], eqn.params['out_shardings'], eqn.params['out_layouts']) - new_outvars = [v for v, f in zip(eqn.outvars, in_fwd) if f is None] - new_params = dict(eqn.params, jaxpr=jaxpr, out_shardings=(*out_shardings,), - out_layouts=(*out_layouts,)) - new_eqn = eqn.replace(params=new_params, outvars=new_outvars) - fwd_vars = [eqn.invars[f] if f is not None else None for f in in_fwd] - return fwd_vars, new_eqn -pe.forwarding_rules[pjit_p] = pjit_forwarding_rule - - -# TODO(mattjj): remove/trivialize this when jaxprs have type annotation on them, -# since it's actually not possible in general to infer the type from the term -def _out_type(jaxpr: core.ClosedJaxpr) -> list[core.AbstractValue]: - out = [] - in_idx = {v: i for i, v in enumerate(jaxpr.jaxpr.invars)} - out_idx = {x: i for i, x in enumerate(jaxpr.jaxpr.invars) - if type(x) is core.Var} - for x in jaxpr.jaxpr.outvars: - aval = x.aval - if type(aval) is core.DShapedArray: - shape = [core.InDBIdx(in_idx[d]) if d in in_idx else - core.OutDBIdx(out_idx[d]) if d in out_idx else - d for d in x.aval.shape] - aval = aval.update(shape=tuple(shape)) - out.append(aval) - return out + return [None] * len(eqn.outvars), eqn +# TODO(mattjj): Remove pjit_forwarding_rule and also in staging rule. +pe.forwarding_rules[jit_p] = pjit_forwarding_rule def _pjit_typecheck(ctx_factory, *in_atoms, jaxpr, **params): - return core._check_call(ctx_factory, pjit_p, in_atoms, + return core._check_call(ctx_factory, jit_p, in_atoms, dict(params, call_jaxpr=jaxpr.jaxpr)) -core.custom_typechecks[pjit_p] = _pjit_typecheck +core.custom_typechecks[jit_p] = _pjit_typecheck def _pjit_abstract_eval(*args, jaxpr, out_shardings, **_): - return jaxpr.out_avals, jaxpr.effects -pjit_p.def_effectful_abstract_eval(_pjit_abstract_eval) + effs = core.eqn_effects(jaxpr) if jaxpr.constvars else jaxpr.effects + return jaxpr.out_avals, effs +jit_p.def_effectful_abstract_eval(_pjit_abstract_eval) def _pjit_cached_lower_jaxpr_to_fun(ctx: mlir.LoweringRuleContext, name: str, jaxpr: core.ClosedJaxpr, + num_const_args: int, in_avals, effects, in_shardings, out_shardings, in_layouts, out_layouts, api_name): + assert len(in_avals) == num_const_args + len(jaxpr.in_avals) + assert len(in_avals) == len(in_shardings) + assert len(in_avals) == len(in_layouts) mod_ctx = ctx.module_context axis_ctx = ctx.module_context.axis_context num_devices = None @@ -1879,24 +1376,33 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx: mlir.LoweringRuleContext, num_devices = axis_ctx.num_devices elif isinstance(axis_ctx, sharding_impls.SPMDAxisContext): num_devices = axis_ctx.mesh.size - key = (pjit_p, name, jaxpr, effects, num_devices, - pxla.SemanticallyEqualShardings(in_shardings, jaxpr.in_avals), - pxla.SemanticallyEqualShardings(out_shardings, jaxpr.out_avals), + key = (jit_p, name, jaxpr, effects, num_devices, + pxla.SemanticallyEqualShardings(in_shardings, in_avals), # pytype: disable=wrong-arg-types + pxla.SemanticallyEqualShardings(out_shardings, jaxpr.out_avals), # pytype: disable=wrong-arg-types in_layouts, out_layouts, api_name) func = mod_ctx.cached_primitive_lowerings.get(key, None) if func is None: arg_shardings = [None if isinstance(i, UnspecifiedValue) else i for i in in_shardings] result_shardings = [None if isinstance(o, UnspecifiedValue) else o for o in out_shardings] - # TODO(b/228598865): inlined calls cannot have shardings set directly on the - # inputs or outputs because they are lost during MLIR->HLO conversion. - # using_sharding_annotation=False means we add an identity operation instead. + # TODO(b/228598865): non-top-level functions cannot have shardings set + # directly on the inputs or outputs because they are lost during MLIR->HLO + # conversion. using_sharding_annotation=False means we add an identity + # operation instead. + num_callbacks = len(mod_ctx.host_callbacks) func = mlir.lower_jaxpr_to_fun( - mod_ctx, name, jaxpr, effects, ctx.name_stack, + mod_ctx, name, jaxpr, effects, + num_const_args=num_const_args, in_avals=in_avals, arg_shardings=arg_shardings, result_shardings=result_shardings, - use_sharding_annotations=False, api_name=api_name, + use_sharding_annotations=False, arg_layouts=in_layouts, result_layouts=out_layouts) - mod_ctx.cached_primitive_lowerings[key] = func + + # If this Jaxpr includes callbacks, we can't cache the lowering because + # on TPU every callback must have a globally unique channel under + # libtpu <= 0.0.34, but the channel gets assigned during lowering. + has_callbacks = len(mod_ctx.host_callbacks) > num_callbacks + if mlir.USE_NEW_TPU_CALLBACK_LOWERING or not has_callbacks or "tpu" not in mod_ctx.platforms: + mod_ctx.cached_primitive_lowerings[key] = func return func @@ -1909,16 +1415,32 @@ def _pjit_lowering(ctx: mlir.LoweringRuleContext, *args, name: str, output_types = [mlir.token_type()] * len(effects) + output_types flat_output_types = mlir.flatten_ir_types(output_types) + const_args_and_avals = core.jaxpr_const_args(jaxpr.jaxpr) + const_args, const_arg_avals = util.unzip2(const_args_and_avals) + in_avals = (*const_arg_avals, *jaxpr.in_avals) + ca_shardings = const_args_shardings(const_args) + in_shardings = ca_shardings + in_shardings # type: ignore + ca_layouts = const_args_layouts(const_args, const_arg_avals, ca_shardings) + in_layouts = ca_layouts + in_layouts # type: ignore + func = _pjit_cached_lower_jaxpr_to_fun( - ctx, name, jaxpr, tuple(effects), in_shardings, + ctx, name, jaxpr, len(const_args), in_avals, + tuple(effects), in_shardings, out_shardings, in_layouts, out_layouts, api_name='jit') tokens_in = [ctx.tokens_in.get(eff) for eff in effects] - args = (*ctx.dim_var_values, *tokens_in, *args) - call = func_dialect.CallOp(flat_output_types, - ir.FlatSymbolRefAttr.get(func.name.value), - mlir.flatten_ir_values(args)) + hoisted_const_values = [ + mlir.ir_constant(c, const_lowering=ctx.const_lowering, aval=aval) + for c, aval in const_args_and_avals + ] + args = (*ctx.dim_var_values, *tokens_in, *hoisted_const_values, *args) + with mlir.source_info_to_location( + ctx.module_context, None, + ctx.name_stack.extend(util.wrap_name('jit', name)), ctx.traceback): + call = func_dialect.CallOp( + flat_output_types, ir.FlatSymbolRefAttr.get(func.name.value), + mlir.flatten_ir_values(args)) mlir.wrap_compute_type_in_place(ctx, call) out_nodes = mlir.unflatten_ir_values_like_types(call.results, output_types) tokens, out_nodes = split_list(out_nodes, [len(effects)]) @@ -1926,8 +1448,22 @@ def _pjit_lowering(ctx: mlir.LoweringRuleContext, *args, name: str, ctx.set_tokens_out(tokens_out) return out_nodes -mlir.register_lowering(pjit_p, _pjit_lowering) +# TODO(phawkins): this is marked uncacheable because it has its own cache and +# because the cache breaks jaxpr metadata like source locations. We should fix +# the metadata problem and consolidate the caches. +mlir.register_lowering(jit_p, _pjit_lowering, cacheable=False) + +def const_args_shardings(const_args: Sequence[ArrayLike]) -> Sequence[PjitSharding]: + return _resolve_in_shardings( + const_args, (sharding_impls.UNSPECIFIED,) * len(const_args)) +def const_args_layouts( + const_args: Sequence[ArrayLike], + avals: Sequence[core.AbstractValue], + shardings: Sequence[PjitSharding] + ) -> Sequence[Layout | AutoLayout | None]: + return _resolve_in_layouts( + const_args, (None,) * len(const_args), shardings, avals) def _pjit_batcher(axis_data, vals_in, dims_in: tuple[int, ...], @@ -1935,10 +1471,7 @@ def _pjit_batcher(axis_data, vals_in, in_shardings, out_shardings, in_layouts, out_layouts, donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs): - segment_lens, dims_in = batching.indirectify_ragged_axes(dims_in) new_jaxpr, axes_out = batching.batch_jaxpr2(jaxpr, axis_data, dims_in) - - # TODO(axch): prepend with Nones (?) to account for new segment_lens inputs in_shardings = tuple( _pjit_batcher_for_sharding(i, axis_in, axis_data.spmd_name, ctx_mesh, aval.ndim) @@ -1955,7 +1488,7 @@ def _pjit_batcher(axis_data, vals_in, raise NotImplementedError( 'Concrete layouts are not supported for vmap(jit).') - vals_out = pjit_p.bind( + vals_out = jit_p.bind( *vals_in, jaxpr=new_jaxpr, in_shardings=in_shardings, @@ -1969,48 +1502,36 @@ def _pjit_batcher(axis_data, vals_in, inline=inline, compiler_options_kvs=compiler_options_kvs) - resolved_axes_out = batching.resolve_ragged_axes_against_inputs_outputs( - vals_in, vals_out, axes_out) - return vals_out, resolved_axes_out + return vals_out, axes_out -batching.fancy_primitive_batchers[pjit_p] = _pjit_batcher -batching.ragged_prop_rules[pjit_p] = batching.ragged_mask_no_op_rule +batching.fancy_primitive_batchers[jit_p] = _pjit_batcher -def _insert_axis_partitions(spec, dim, val): - too_short = dim - len(spec) - if too_short > 0: - spec += (None,) * too_short - new_partitions = tuple_insert(spec, dim, val) - return PartitionSpec(*new_partitions) def _pjit_batcher_for_sharding( - s: Sharding | UnspecifiedValue, - dim: int | batching.RaggedAxis, spmd_axis_name: tuple[str, ...] | None, + s, dim: int, spmd_axis_name: tuple[str, ...] | None, mesh, ndim: int): if isinstance(s, UnspecifiedValue): return s hlo_s = s._to_xla_hlo_sharding(ndim) if spmd_axis_name is None: - if sharding_impls.is_op_sharding_replicated(hlo_s): + if sharding_impls.is_hlo_sharding_replicated(hlo_s): return s if isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh): return NamedSharding( - s.mesh, _insert_axis_partitions(s.spec, dim, PartitionSpec.UNCONSTRAINED)) + s.mesh, pxla.batch_spec(s.spec, dim, PartitionSpec.UNCONSTRAINED)) new_op = hlo_s.to_proto().clone() tad = list(new_op.tile_assignment_dimensions) tad.insert(dim, 1) # type: ignore new_op.tile_assignment_dimensions = tad - new_gs = GSPMDSharding( - s._device_assignment, new_op, - _device_list=getattr(s, '_internal_device_list', None)) + new_gs = GSPMDSharding(s._internal_device_list, new_op) return pxla._get_out_sharding_from_orig_sharding([new_gs], [None], s, None)[0] else: if isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh): return NamedSharding( - s.mesh, _insert_axis_partitions(s.spec, dim, spmd_axis_name)) + s.mesh, pxla.batch_spec(s.spec, dim, spmd_axis_name)) if isinstance(s, NamedSharding): mesh = s.mesh - if mesh is None or mesh.empty: + if mesh.empty: raise ValueError( 'If you are using spmd_axis_name parameter of jax.vmap,' ' please make sure to run your jitted function inside the mesh' @@ -2020,7 +1541,7 @@ def _pjit_batcher_for_sharding( f' manager scope{s!r}') spec = parse_flatten_op_sharding(hlo_s, mesh)[0] return NamedSharding( - mesh, _insert_axis_partitions(spec, dim, spmd_axis_name)) + mesh, pxla.batch_spec(spec, dim, spmd_axis_name)) def _pjit_jvp(primals_in, tangents_in, @@ -2035,7 +1556,7 @@ def _filter_zeros(is_nz_l, l): return (x for nz, x in zip(is_nz_l, l) if nz) _filter_zeros_in = partial(_filter_zeros, is_nz_tangents_in) _filter_zeros_out = partial(_filter_zeros, is_nz_tangents_out) - outputs = pjit_p.bind( + outputs = jit_p.bind( *primals_in, *_filter_zeros_in(tangents_in), jaxpr=jaxpr_jvp, in_shardings=(*in_shardings, *_filter_zeros_in(in_shardings)), @@ -2054,34 +1575,73 @@ def _filter_zeros(is_nz_l, l): tangents_out_it = iter(tangents_out) return primals_out, [next(tangents_out_it) if nz else ad.Zero(aval) for nz, aval in zip(is_nz_tangents_out, jaxpr.out_avals)] -ad.primitive_jvps[pjit_p] = _pjit_jvp - - -def _pjit_linearization(nzs, *primals_in, jaxpr, - in_shardings, out_shardings, in_layouts, out_layouts, - donated_invars, ctx_mesh, name, keep_unused, inline, - compiler_options_kvs): - primal_jaxpr, num_residuals, nzs_out, tangent_jaxpr = ad.linearize_jaxpr(jaxpr, nzs) - # constvars will become residuals. Move them to the end of the ordinary args. - res_shardings = (UNSPECIFIED,) * num_residuals - res_layouts = (None,) * num_residuals - res_donated = (False,) * num_residuals - def tangent_fun(consts_, *tangents): +ad.primitive_jvps[jit_p] = _pjit_jvp + + +def _pjit_linearize(nzs, *primals_in, jaxpr, in_shardings, out_shardings, + in_layouts, out_layouts, donated_invars, ctx_mesh, name, + keep_unused, inline, compiler_options_kvs): + primal_jaxpr, num_residuals_out, nzs_out, in_fwd_res, tangent_jaxpr = \ + ad.linearize_jaxpr(jaxpr, nzs) + num_residuals_in = len(in_fwd_res) + num_primals_out = len(primal_jaxpr.out_avals) - num_residuals_out + + res_shardings_in = (UNSPECIFIED,) * num_residuals_in + res_layouts_in = (None,) * num_residuals_in + res_donated = (False,) * num_residuals_in + primal_out_shardings = tuple(out_shardings) + (UNSPECIFIED,) * num_residuals_out + primal_out_layouts = tuple(out_layouts) + (None,) * num_residuals_out + + config.enable_checks.value and core.check_jaxpr(primal_jaxpr.jaxpr) + config.enable_checks.value and core.check_jaxpr(tangent_jaxpr.jaxpr) + + def keep_where(l, should_keep): + return tuple(x for x, keep in zip(l, should_keep) if keep) + + # Input-to-output forwarding. + in_fwd = pe._jaxpr_forwarding(primal_jaxpr.jaxpr) + in_fwd_primal, in_fwd_res_ = split_list(in_fwd, [num_primals_out]) + assert all(f is None for f in in_fwd_res_) + in_fwd = [ + fwd if isinstance(os, UnspecifiedValue) and ol is None else None + for os, ol, fwd in zip(out_shardings, out_layouts, in_fwd_primal) + ] + in_fwd_res_ + del in_fwd_res_, in_fwd_primal + keep = [f is None for f in in_fwd] + primal_jaxpr = pe.prune_closed_jaxpr_outputs(primal_jaxpr, keep) + primal_out_shardings = keep_where(primal_out_shardings, keep) + primal_out_layouts = keep_where(primal_out_layouts, keep) + _, kept_res = split_list(keep, [num_primals_out]) + num_kept_residuals = sum(kept_res) + del keep, kept_res, num_primals_out + + # Output-to-output forwarding. + num_primals_out = len(primal_jaxpr.out_avals) - num_kept_residuals + out_vars, res_vars = split_list(primal_jaxpr.jaxpr.outvars, [num_primals_out]) + idx_map = {id(v): i for i, v in enumerate(out_vars)} + out_fwd = [None] * num_primals_out + [idx_map.get(id(v)) for v in res_vars] + keep = [f is None for f in out_fwd] + primal_jaxpr = pe.prune_closed_jaxpr_outputs(primal_jaxpr, keep) + primal_out_shardings = keep_where(primal_out_shardings, keep) + primal_out_layouts = keep_where(primal_out_layouts, keep) + del keep + + tangent_avals_out = [a.to_tangent_aval() for a in jaxpr.out_avals] + + def tangent_fun(residuals, *tangents): tangents_nz = _filter_zeros(nzs, tangents) - assert len(consts_) == num_residuals - nz_tangents_out = pjit_p.bind(*(*tangents_nz, *consts_), - jaxpr=tangent_jaxpr, - in_shardings=_filter_zeros(nzs, in_shardings) + res_shardings, + nz_tangents_out = jit_p.bind( + *residuals, *tangents_nz, jaxpr=tangent_jaxpr, + in_shardings=res_shardings_in + _filter_zeros(nzs, in_shardings), out_shardings=_filter_zeros(nzs_out, out_shardings), - in_layouts=_filter_zeros(nzs, in_layouts) + res_layouts, + in_layouts=res_layouts_in + _filter_zeros(nzs, in_layouts), out_layouts=_filter_zeros(nzs_out, out_layouts), - donated_invars=_filter_zeros(nzs, donated_invars) + res_donated, + donated_invars=res_donated + _filter_zeros(nzs, donated_invars), ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused, inline=inline, compiler_options_kvs=compiler_options_kvs) - tangent_avals_out = [v.aval.to_tangent_aval() for v in jaxpr.jaxpr.outvars] nz_tangents_out_ = iter(nz_tangents_out) tangents_out = [next(nz_tangents_out_) if nz else ad.Zero(aval) for (aval, nz) in zip(tangent_avals_out, nzs_out)] @@ -2090,22 +1650,26 @@ def tangent_fun(consts_, *tangents): def _filter_zeros(is_nz_l, l): return tuple(x for nz, x in zip(is_nz_l, l) if nz) - ans = pjit_p.bind(*primals_in, jaxpr=primal_jaxpr, - in_shardings=in_shardings, - out_shardings=(*res_shardings, *out_shardings), - in_layouts=in_layouts, - out_layouts=(*res_layouts, *out_layouts), - donated_invars=donated_invars, - ctx_mesh=ctx_mesh, - name=name, - keep_unused=keep_unused, - inline=inline, - compiler_options_kvs=compiler_options_kvs) - residuals_ans, primal_ans = split_list(ans, [num_residuals]) + assert len(in_shardings) == len(primal_jaxpr.in_avals) + ans = jit_p.bind(*primals_in, jaxpr=primal_jaxpr, + in_shardings=in_shardings, + out_shardings=primal_out_shardings, + in_layouts=in_layouts, + out_layouts=primal_out_layouts, + donated_invars=donated_invars, + ctx_mesh=ctx_mesh, + name=name, + keep_unused=keep_unused, + inline=inline, + compiler_options_kvs=compiler_options_kvs) + ans = subs_list(out_fwd, ans, ans) + ans = subs_list(in_fwd, primals_in, ans) + primal_ans, residuals_ans = split_list(ans, [len(ans) - num_residuals_out]) + residuals_ans = subs_list(in_fwd_res, [*jaxpr.consts, *primals_in], residuals_ans) return primal_ans, nzs_out, residuals_ans, tangent_fun -ad.primitive_linearizations[pjit_p] = _pjit_linearization +ad.primitive_linearizations[jit_p] = _pjit_linearize def _pjit_partial_eval(trace: pe.JaxprTrace, @@ -2117,42 +1681,32 @@ def _pjit_partial_eval(trace: pe.JaxprTrace, known_ins = tuple(pv.is_known() for pv in in_pvals) unknown_ins = tuple(not k for k in known_ins) - if any(isinstance(e, (RefEffect, core.InternalMutableArrayEffect)) - for e in jaxpr.effects): - known_jaxpr_, unknown_jaxpr_, unknown_outs, _, num_res_val, num_res_ref = \ - pe.partial_eval_jaxpr_stateful(jaxpr.jaxpr, unknown_ins, unknown_ins, - False, False, None) - if num_res_ref: raise NotImplementedError - known_jaxpr = pe.ClosedJaxpr(known_jaxpr_, jaxpr.consts) - unknown_jaxpr = pe.ClosedJaxpr(unknown_jaxpr_, jaxpr.consts) - res_avals = unknown_jaxpr.in_avals[:num_res_val] - else: - known_jaxpr, unknown_jaxpr, unknown_outs, res_avals = \ - pe.partial_eval_jaxpr_nounits(jaxpr, unknown_ins, instantiate=False) + known_jaxpr, unknown_jaxpr, unknown_outs, res_out_avals, in_fwd_res = \ + pe.partial_eval_jaxpr_nounits_fwd(jaxpr, unknown_ins, instantiate=False) unknown_outs = tuple(unknown_outs) # type: ignore[assignment] known_outs = tuple(not uk for uk in unknown_outs) - num_residuals = len(res_avals) - res_shardings = (UNSPECIFIED,) * num_residuals - res_layouts = (None,) * num_residuals + # out_shardings and out_layouts for residual values output by known_jaxpr def keep_where(l, should_keep): return tuple(x for x, keep in zip(l, should_keep) if keep) - known_out_shardings = keep_where(out_shardings, known_outs) + res_shardings - known_out_layouts = keep_where(out_layouts, known_outs) + res_layouts + known_out_shardings = (keep_where(out_shardings, known_outs) + + (UNSPECIFIED,) * len(res_out_avals)) + known_out_layouts = (keep_where(out_layouts, known_outs) + + (None,) * len(res_out_avals)) # Input-to-output forwarding: compute which outputs are just forwarded inputs. - num_out_primals = len(known_jaxpr.out_avals) - num_residuals + num_out_primals = len(known_jaxpr.out_avals) - len(res_out_avals) in_fwd: list[int | None] = pe._jaxpr_forwarding(known_jaxpr.jaxpr) - # Only forward primal outputs when corresponding out_sharding is UNSPECIFIED. - in_fwd_primal, in_fwd_res = split_list(in_fwd, [num_out_primals]) + in_fwd_primal, in_fwd_res_ = split_list(in_fwd, [num_out_primals]) + assert all(f is None for f in in_fwd_res_) in_fwd = [ fwd if isinstance(os, UnspecifiedValue) and ol is None else None for os, ol, fwd in zip( keep_where(out_shardings, known_outs), keep_where(out_layouts, known_outs), in_fwd_primal) - ] + in_fwd_res - del in_fwd_primal, in_fwd_res + ] + in_fwd_res_ + del in_fwd_primal, in_fwd_res_ # Prune jaxpr outputs and out_shardings by removing the input-forwards. keep = [f is None for f in in_fwd] known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep) @@ -2188,14 +1742,18 @@ def keep_where(l, should_keep): # Bind known things to pjit_p. known_inputs = [pv.get_known() for pv in in_pvals if pv.is_known()] - all_known_outs = pjit_p.bind(*known_inputs, **known_params) + all_known_outs = jit_p.bind(*known_inputs, **known_params) # Add back in the output fwds. all_known_outs = subs_list(out_fwd, all_known_outs, all_known_outs) # Add back in the input fwds. all_known_outs = subs_list(in_fwd, known_inputs, all_known_outs) known_out_vals, residual_vals = \ - split_list(all_known_outs, [len(all_known_outs) - num_residuals]) + split_list(all_known_outs, [len(all_known_outs) - len(res_out_avals)]) + residual_vals_ = iter(residual_vals) + residual_vals = [next(residual_vals_) if f is None + else [*jaxpr.consts, *known_inputs][f] for f in in_fwd_res] + assert next(residual_vals_, None) is None residual_tracers = map(trace.new_instantiated_const, residual_vals) # The convention of partial_eval_jaxpr_nounits is to place residual binders at @@ -2203,16 +1761,22 @@ def keep_where(l, should_keep): # jaxpr equation built below and the pjit transpose rule assume a # residual-inputs-last convention. unknown_jaxpr = pe.move_binders_to_back( - unknown_jaxpr, [True] * num_residuals + [False] * sum(unknown_ins)) - # Prepare unknown tracers + unknown_jaxpr, [True] * len(residual_vals) + [False] * sum(unknown_ins)) + + # Set up staged-out 'unknown' eqn + unknown_in_shardings = (keep_where(in_shardings, unknown_ins) + + (UNSPECIFIED,) * len(residual_tracers)) + unknown_in_layouts = (keep_where(in_layouts, unknown_ins) + + (None,) * len(residual_tracers)) + unknown_donated_invars = (keep_where(donated_invars, unknown_ins) + + (False,) * len(residual_tracers)) unknown_params = dict( jaxpr=unknown_jaxpr, - in_shardings=(keep_where(in_shardings, unknown_ins) + res_shardings), + in_shardings=unknown_in_shardings, + in_layouts=unknown_in_layouts, out_shardings=keep_where(out_shardings, unknown_outs), - in_layouts=(keep_where(in_layouts, unknown_ins) + res_layouts), out_layouts=keep_where(out_layouts, unknown_outs), - donated_invars=(keep_where(donated_invars, unknown_ins) + - (False,) * num_residuals), + donated_invars=unknown_donated_invars, ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused, @@ -2224,16 +1788,19 @@ def keep_where(l, should_keep): pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None) for aval in unknown_out_avals ] - eqn = pe.new_eqn_recipe((*unknown_tracers_in, *residual_tracers), + unknown_tracers_in = [*unknown_tracers_in, *residual_tracers] + eqn = pe.new_eqn_recipe(trace, unknown_tracers_in, unknown_tracers_out, - pjit_p, + jit_p, unknown_params, unknown_jaxpr.effects, source_info_util.current()) for t in unknown_tracers_out: t.recipe = eqn + if effects.partial_eval_kept_effects.filter_in(unknown_jaxpr.effects): + trace.effect_handles.append(pe.EffectHandle(unknown_tracers_in, eqn)) # type: ignore return merge_lists(unknown_outs, known_out_vals, unknown_tracers_out) -pe.custom_partial_eval_rules[pjit_p] = _pjit_partial_eval +pe.custom_partial_eval_rules[jit_p] = _pjit_partial_eval def _pjit_partial_eval_custom_params_updater( @@ -2282,97 +1849,70 @@ def _pjit_partial_eval_custom_params_updater( assert len(new_params_staged['out_layouts']) == len(params_staged['jaxpr'].out_avals) return new_params_known, new_params_staged -pe.partial_eval_jaxpr_custom_rules[pjit_p] = \ +pe.partial_eval_jaxpr_custom_rules[jit_p] = \ partial(pe.closed_call_partial_eval_custom_rule, 'jaxpr', _pjit_partial_eval_custom_params_updater) -@lu.cache -def _pjit_transpose_trace(fun: lu.WrappedFun, - in_avals: Sequence[core.AbstractValue]): - transpose_jaxpr, _, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic( - fun, in_avals) - transpose_jaxpr = core.ClosedJaxpr(transpose_jaxpr, consts) - return transpose_jaxpr, attrs_tracked - - -def _pjit_transpose(cts_in, *primals_in, - jaxpr: core.ClosedJaxpr, - in_shardings, out_shardings, in_layouts, out_layouts, - donated_invars, ctx_mesh, name, keep_unused, inline, - compiler_options_kvs): - def prune_type(ty, xs, maybe_zeros): - return tuple(x for x, mz in zip(xs, maybe_zeros) if type(mz) is not ty) - - body = lu.wrap_init(ad.closed_backward_pass, - debug_info=jaxpr.jaxpr._debug_info) - body = lu.hashable_partial(body, jaxpr, False) - primals_and_nz_cts_in, in_treedef = tree_flatten((primals_in, cts_in)) - body, cts_out_treedef_thunk = flatten_fun_nokwargs(body, in_treedef) - - transpose_in_shardings = ( - *prune_type(ad.UndefinedPrimal, in_shardings, primals_in), - *prune_type(ad.Zero, out_shardings, cts_in) - ) - transpose_in_layouts = ( - *prune_type(ad.UndefinedPrimal, in_layouts, primals_in), - *prune_type(ad.Zero, out_layouts, cts_in) - ) - global_cts_in_avals = tuple(core.get_aval(ct) for ct in primals_and_nz_cts_in) - - transpose_jaxpr, attrs_tracked = _pjit_transpose_trace( - body, global_cts_in_avals) - cts_out_treedef = cts_out_treedef_thunk() - transpose_out_shardings = prune_type( - ad.Zero, - in_shardings, - tree_unflatten(cts_out_treedef, [object()] * cts_out_treedef.num_leaves)) - transpose_out_layouts = prune_type( - ad.Zero, - in_layouts, - tree_unflatten(cts_out_treedef, [object()] * cts_out_treedef.num_leaves)) - - if attrs_tracked: - init_states = _get_states(attrs_tracked) - primals_and_nz_cts_in = [*init_states, *primals_and_nz_cts_in] - transpose_in_shardings = (UNSPECIFIED,) * len(attrs_tracked) + transpose_in_shardings - transpose_out_shardings = (UNSPECIFIED,) * len(attrs_tracked) + transpose_out_shardings - transpose_in_layouts = (None,) * len(attrs_tracked) + transpose_in_layouts - transpose_out_layouts = (None,) * len(attrs_tracked) + transpose_out_layouts +def _pjit_transpose_fancy( + cts_in, *args, jaxpr, in_shardings, out_shardings, in_layouts, + out_layouts, donated_invars, ctx_mesh, name, keep_unused, inline, + compiler_options_kvs): + primals_ctrefs, specs = ad.project_accums(args) + in_flat, in_tree = tree_flatten((primals_ctrefs, cts_in)) + in_avals = [core.AvalQDD(a, cur_qdd(x)) if (a := typeof(x)).has_qdd # type: ignore + else a for x in in_flat] + trans_jaxpr, out_tree = _transpose_jaxpr_fancy(jaxpr, in_tree, (*in_avals,), specs) + + trans_in_shardings = ( + [s for x, s in zip(args, in_shardings) if not isinstance(x,ad.ValAccum)] + + [s for x, s in zip(cts_in, out_shardings) if not isinstance(x, ad.Zero)]) + trans_in_layouts = ( + [l for x, l in zip(args, in_layouts) if not isinstance(x, ad.ValAccum)] + + [l for x, l in zip(cts_in, out_layouts) if not isinstance(x, ad.Zero)]) + cts_out_ = tree_unflatten(out_tree, trans_jaxpr.out_avals) + trans_out_shardings = tuple(s for x, s in zip(cts_out_, in_shardings) + if isinstance(x, core.AbstractValue)) + trans_out_layouts = tuple(l for x, l in zip(cts_out_, in_layouts ) + if isinstance(x, core.AbstractValue)) try: - nz_cts_out = pjit_p.bind( - *primals_and_nz_cts_in, - jaxpr=transpose_jaxpr, - in_shardings=transpose_in_shardings, - out_shardings=transpose_out_shardings, - in_layouts=transpose_in_layouts, - out_layouts=transpose_out_layouts, - donated_invars=(False,) * len(primals_and_nz_cts_in), - ctx_mesh=ctx_mesh, - name=name, - keep_unused=keep_unused, - inline=inline, + cts_out = jit_p.bind( + *in_flat, jaxpr=trans_jaxpr, in_shardings=tuple(trans_in_shardings), + in_layouts=tuple(trans_in_layouts), out_shardings=trans_out_shardings, + out_layouts=trans_out_layouts, donated_invars=(False,) * len(in_flat), + ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused, inline=inline, compiler_options_kvs=compiler_options_kvs) - except dispatch.InternalFloatingPointError as e: + except api_util.InternalFloatingPointError as e: print("Invalid nan value encountered in the backward pass of a jax.jit " "function. Calling the de-optimized backward pass.") try: - _ = ad.closed_backward_pass(jaxpr, None, primals_in, cts_in) + ad.backward_pass3(jaxpr.jaxpr, False, jaxpr.consts, args, cts_in) except (FloatingPointError, ZeroDivisionError) as e2: raise e2 from None # great else: # If control reaches this line, we got a NaN on the output of `compiled` # but not `fun.call_wrapped` on the same arguments. Let's tell the user. - dispatch._raise_no_nan_in_deoptimized(e) - - if attrs_tracked: - final_states, nz_cts_out = split_list(nz_cts_out, [len(init_states)]) - _set_states(attrs_tracked, final_states) + api_util._raise_no_nan_in_deoptimized(e) - return tree_unflatten(cts_out_treedef, nz_cts_out) -ad.primitive_transposes[pjit_p] = _pjit_transpose + for x, ct in zip(args, tree_unflatten(out_tree, cts_out)): + if isinstance(x, ad.ValAccum): x.accum(ct) +@weakref_lru_cache +def _transpose_jaxpr_fancy(jaxpr, in_tree, in_avals, specs): + cell = lambda: None + def transposed(*in_flat): + primals_ctrefs, cts_in = tree_unflatten(in_tree, in_flat) + args = ad.unproject_accums(specs, primals_ctrefs) + ad.backward_pass3(jaxpr.jaxpr, False, jaxpr.consts, args, cts_in) + cts_out = [x.freeze() if isinstance(x, ad.ValAccum) else None for x in args] + cts_out, cell.out_tree = tree_flatten(cts_out) # type: ignore + return cts_out + dbg = jaxpr.jaxpr.debug_info.with_unknown_names() + trans_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(transposed, debug_info=dbg), in_avals) + return core.ClosedJaxpr(trans_jaxpr, consts), cell.out_tree # type: ignore +ad.fancy_transposes[jit_p] = _pjit_transpose_fancy @weakref_lru_cache def _dce_jaxpr_pjit( @@ -2384,7 +1924,6 @@ def _dce_jaxpr_pjit( def dce_jaxpr_pjit_rule(used_outputs: list[bool], eqn: core.JaxprEqn ) -> tuple[list[bool], core.JaxprEqn | None]: - if not any(used_outputs) and not pe.has_effects(eqn): return [False] * len(eqn.invars), None @@ -2407,13 +1946,14 @@ def keep_where(xs, keeps): if not any(used_inputs) and not any(used_outputs) and not dced_jaxpr.effects: return used_inputs, None else: + new_effs = core.eqn_effects(dced_jaxpr) new_eqn = core.new_jaxpr_eqn( [v for v, used in zip(eqn.invars, used_inputs) if used], [v for v, used in zip(eqn.outvars, used_outputs) if used], - eqn.primitive, new_params, dced_jaxpr.effects, eqn.source_info, eqn.ctx) + eqn.primitive, new_params, new_effs, eqn.source_info, eqn.ctx) return used_inputs, new_eqn -pe.dce_rules[pjit_p] = dce_jaxpr_pjit_rule +pe.dce_rules[jit_p] = dce_jaxpr_pjit_rule def _pjit_pp_rule(eqn: core.JaxprEqn, @@ -2433,7 +1973,7 @@ def _pjit_pp_rule(eqn: core.JaxprEqn, del params['out_layouts'] if not params['keep_unused']: del params['keep_unused'] - if params['ctx_mesh'] is None or params['ctx_mesh'].empty: + if params['ctx_mesh'].empty: del params['ctx_mesh'] if not params['compiler_options_kvs']: del params['compiler_options_kvs'] @@ -2446,56 +1986,40 @@ def _pjit_pp_rule(eqn: core.JaxprEqn, del params["name"] return core._pp_eqn(eqn, context, settings, params=["name"] + sorted(params)) -core.pp_eqn_rules[pjit_p] = _pjit_pp_rule - - -def _pjit_state_discharge_rule( - in_avals, out_avals, *args, jaxpr, in_shardings, out_shardings, - in_layouts, out_layouts, **params): - if not all(isinstance(s, UnspecifiedValue) for s in (*in_shardings, *out_shardings)): - raise NotImplementedError - - if not (all(l is None for l in in_layouts) and - all(l is None for l in out_layouts)): - raise NotImplementedError - - jaxpr, consts = jaxpr.jaxpr, jaxpr.consts - num_outs = len(jaxpr.outvars) - discharged_jaxpr, discharged_consts = state_discharge.discharge_state(jaxpr, consts) - discharged_closed_jaxpr = core.ClosedJaxpr(discharged_jaxpr, discharged_consts) - new_in_shardings = (UnspecifiedValue(),) * len(discharged_jaxpr.invars) - new_out_shardings = (UnspecifiedValue(),) * len(discharged_jaxpr.outvars) - new_in_layouts = (None,) * len(discharged_jaxpr.invars) - new_out_layouts = (None,) * len(discharged_jaxpr.outvars) - out_and_ref_vals = pjit_p.bind( - *args, jaxpr=discharged_closed_jaxpr, in_shardings=new_in_shardings, - out_shardings=new_out_shardings, in_layouts=new_in_layouts, - out_layouts=new_out_layouts, **params) - out_vals, ref_vals = split_list(out_and_ref_vals, [num_outs]) - ref_vals_iter = iter(ref_vals) - new_invals = tuple(next(ref_vals_iter) if isinstance(aval, AbstractRef) - else None for aval in in_avals) - sentinel = object() - assert next(ref_vals_iter, sentinel) is sentinel - return new_invals, out_vals -state_discharge.register_discharge_rule(pjit_p)(_pjit_state_discharge_rule) +core.pp_eqn_rules[jit_p] = _pjit_pp_rule # -------------------- with_sharding_constraint -------------------- -def check_shardings_are_auto(shardings_flat): - for s in shardings_flat: - if not isinstance(s, NamedSharding): +def check_shardings_are_auto(s: Sharding) -> None: + if not isinstance(s, NamedSharding): + return + mesh = s.mesh.abstract_mesh + if not all(mesh._name_to_type[i] == mesh_lib.AxisType.Auto + for axes in s.spec + if axes is not PartitionSpec.UNCONSTRAINED and axes is not None + for i in (axes if isinstance(axes, tuple) else (axes,))): + raise ValueError( + 'The spec of NamedSharding passed to with_sharding_constraint can' + f' only refer to Auto axes of the mesh. Got spec={s.spec} and' + f' mesh={mesh}. You probably meant to use `reshard` API?') + +def assert_shardings_equal(x_aval, user_sharding: NamedSharding): + x_spec = x_aval.sharding.spec + user_spec = user_sharding.spec._normalized_spec_for_aval(x_aval.ndim) + if config.remove_size_one_mesh_axis_from_type.value: + user_spec = core.remove_size_one_mesh_axis(user_spec, user_sharding.mesh) + for x, s in zip(x_spec, user_spec): + if s is PartitionSpec.UNCONSTRAINED: continue - mesh = s.mesh.abstract_mesh - if not all(mesh._name_to_type[i] == mesh_lib.AxisType.Auto - for axes in s.spec - if axes is not PartitionSpec.UNCONSTRAINED and axes is not None - for i in (axes if isinstance(axes, tuple) else (axes,))): - raise ValueError( - 'The spec of NamedSharding passed to with_sharding_constraint can' - f' only refer to Auto axes of the mesh. Got spec={s.spec} and' - f' mesh={mesh}') + else: + if x != s: + raise AssertionError( + '`with_sharding_constraint` acts as an assert when all axes of' + f' mesh are of type `Explicit`. The array sharding: {x_spec} did' + f' not match the sharding provided: {user_spec}. Please use' + ' `jax.sharding.reshard` to shard your input to the sharding you' + ' want.') def with_sharding_constraint(x, shardings): @@ -2516,10 +2040,10 @@ def with_sharding_constraint(x, shardings): Returns: x_with_shardings: PyTree of jax.Arrays with specified sharding constraints. - .. _Distributed arrays and automatic parallelization: https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html + .. _Distributed arrays and automatic parallelization: https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html """ x_flat, tree = tree_flatten(x) - + x_avals_flat = [core.shaped_abstractify(x) for x in x_flat] layouts, shardings = _split_layout_and_sharding(shardings) user_shardings = prepare_axis_resources( @@ -2534,9 +2058,12 @@ def with_sharding_constraint(x, shardings): flatten_axes("with_sharding_constraint layouts", tree, layouts)) del layouts - context_mesh = ( - mesh_lib.get_abstract_mesh() if mesh_lib.get_concrete_mesh() is not None - else mesh_lib.thread_resources.env.physical_mesh) + if not mesh_lib.get_concrete_mesh().empty: + context_mesh = mesh_lib.get_abstract_mesh() + elif not mesh_lib.get_abstract_mesh().empty: + context_mesh = mesh_lib.get_abstract_mesh() + else: + context_mesh = mesh_lib.thread_resources.env.physical_mesh shardings_flat = [_create_sharding_for_array(context_mesh, a, 'shardings', 'with_sharding_constraint') @@ -2551,25 +2078,29 @@ def with_sharding_constraint(x, shardings): # TODO(bartchr): remove `unconstrained_dims` after migrating to Shardy. It's # already part of the shardings. unconstrained_dims = [get_unconstrained_dims(s) - if isinstance(s, NamedSharding) else {} + if isinstance(s, NamedSharding) else frozenset() for s in shardings_flat] pjit_check_aval_sharding( - shardings_flat, x_flat, ("",) * len(shardings_flat), + shardings_flat, x_avals_flat, ("",) * len(shardings_flat), "with_sharding_constraint arguments", - allow_uneven_sharding=True) - - check_shardings_are_auto(shardings_flat) - - check_aval_layout_compatibility(user_layouts_flat, x_flat, + allow_uneven_sharding=True, allow_partial_manual=True) + check_aval_layout_compatibility(user_layouts_flat, x_avals_flat, ("",) * len(user_layouts_flat), "with_sharding_constraint arguments") - outs = [sharding_constraint_p.bind(xf, sharding=s, layout=l, - context_mesh=context_mesh, - unconstrained_dims=ud) - for xf, s, l, ud in zip(x_flat, shardings_flat, user_layouts_flat, - unconstrained_dims)] + outs = [] + for xf, x_aval, s, l, ud in zip(x_flat, x_avals_flat, shardings_flat, + user_layouts_flat, unconstrained_dims): + if (mesh_lib.get_abstract_mesh().are_all_axes_explicit and l is None and + isinstance(s, NamedSharding)): + assert_shardings_equal(x_aval, s) + outs.append(xf) + else: + check_shardings_are_auto(s) + outs.append(sharding_constraint_p.bind( + xf, sharding=s, layout=l, context_mesh=context_mesh, + unconstrained_dims=ud)) return tree_unflatten(tree, outs) def _identity_fn(x): return x @@ -2581,7 +2112,7 @@ def _sharding_constraint_impl(x, sharding, layout, context_mesh, if (not context_mesh.empty and isinstance(context_mesh, AbstractMesh) and not hasattr(x, 'sharding')): concrete_mesh = mesh_lib.get_concrete_mesh() - assert concrete_mesh is not None + assert not concrete_mesh.empty sharding = NamedSharding(concrete_mesh, sharding.spec) else: aval = core.shaped_abstractify(x) @@ -2603,39 +2134,63 @@ def _sharding_constraint_impl(x, sharding, layout, context_mesh, sharding = NamedSharding(x.sharding.mesh, sharding.spec) if layout is None: - if hasattr(x, 'sharding') and x.sharding.is_equivalent_to(sharding, x.ndim): - return x # Run a jit here to raise good errors when device assignment don't match. return api.jit(_identity_fn, out_shardings=sharding)(x) else: - if (hasattr(x, 'layout') and x.layout.device_local_layout == layout and - x.sharding.is_equivalent_to(sharding, x.ndim)): - return x - return api.jit(_identity_fn, out_shardings=Layout(layout, sharding))(x) + return api.jit(_identity_fn, out_shardings=Format(layout, sharding))(x) sharding_constraint_p = core.Primitive("sharding_constraint") sharding_constraint_p.def_impl(_sharding_constraint_impl) -sharding_constraint_p.def_abstract_eval(lambda x, **_: x) ad.deflinear2(sharding_constraint_p, lambda ct, _, **params: (sharding_constraint_p.bind(ct, **params),)) +def _sharding_constraint_abstract_eval( + x_aval, *, sharding, layout, context_mesh, unconstrained_dims): + if isinstance(sharding, NamedSharding): + return x_aval.update( + sharding=x_aval.sharding.update(mesh=sharding.mesh.abstract_mesh)) + return x_aval.update(sharding=None) +sharding_constraint_p.def_abstract_eval(_sharding_constraint_abstract_eval) + def _sharding_constraint_hlo_lowering(ctx, x_node, *, sharding, layout, context_mesh, unconstrained_dims): - aval, = ctx.avals_in + in_aval, = ctx.avals_in out_aval, = ctx.avals_out axis_ctx = ctx.module_context.axis_context + + if (isinstance(sharding, NamedSharding) and + any(o is not None for o in out_aval.sharding.spec)): + spec = sharding.spec._normalized_spec_for_aval(in_aval.ndim) + new_spec = [] + for user_spec, aval_spec in zip(spec, out_aval.sharding.spec): + if aval_spec is None: + new_spec.append(user_spec) + else: + aval_spec = aval_spec if isinstance(aval_spec, tuple) else (aval_spec,) + if user_spec is PartitionSpec.UNCONSTRAINED: + raise NotImplementedError + if user_spec is None: + new_spec.append(aval_spec) + elif isinstance(user_spec, tuple): + new_spec.append(aval_spec + user_spec) + else: + new_spec.append(aval_spec + (user_spec,)) + sharding = sharding.update(spec=new_spec) + + if dtypes.issubdtype(in_aval.dtype, dtypes.extended): + in_aval = core.physical_aval(in_aval) if (isinstance(axis_ctx, sharding_impls.SPMDAxisContext) and axis_ctx.manual_axes): - sharding = mlir.add_manual_axes(axis_ctx, sharding, aval.ndim) + sharding = mlir.add_manual_axes(axis_ctx, sharding, in_aval.ndim) if config.use_shardy_partitioner.value: - sharding = sharding._to_sdy_sharding(aval.ndim) + sharding = sharding._to_sdy_sharding(in_aval.ndim) else: - sharding = sharding._to_xla_hlo_sharding(aval.ndim).to_proto() + sharding = sharding._to_xla_hlo_sharding(in_aval.ndim).to_proto() out = mlir.wrap_with_sharding_op( ctx, x_node, out_aval, sharding, unspecified_dims=unconstrained_dims) if layout is not None: - out = mlir.wrap_with_layout_op(ctx, out, out_aval, layout, aval) + out = mlir.wrap_with_layout_op(ctx, out, out_aval, layout, in_aval) return [out] mlir.register_lowering(sharding_constraint_p, _sharding_constraint_hlo_lowering) @@ -2666,154 +2221,99 @@ def _sharding_constraint_batcher( vmapped_sharding = NamedSharding( vmapped_sharding.mesh, PartitionSpec(*new_spec)) - # TODO(yashkatariya): Figure out layouts should change under vmap. - if layout is not None: - raise NotImplementedError( - 'Concrete layout is not supported for vmap(with_sharding_constraint). ' - f'Got layout {layout}') + vmapped_layout = (get_layout_for_vmap(d, layout) if layout is not None else + layout) y = sharding_constraint_p.bind( x, sharding=vmapped_sharding, - layout=layout, + layout=vmapped_layout, context_mesh=context_mesh, - unconstrained_dims=unconstrained_dims) + unconstrained_dims=frozenset(unconstrained_dims)) return y, d batching.fancy_primitive_batchers[sharding_constraint_p] = _sharding_constraint_batcher batching.skippable_batchers[sharding_constraint_p] = lambda _: () -# -------------------- mesh_cast --------------------------- - -# TODO(yashkatariya): Make shardings optional. -def mesh_cast(xs, out_shardings): - x_flat, treedef = tree_flatten(xs) - shardings_flat = flatten_axes("mesh_cast shardings", treedef, out_shardings) - out_flat = [ - mesh_cast_p.bind( - x, dst_sharding=canonicalize_sharding( - s, 'mesh_cast', check_mesh_consistency=False)) - for x, s in safe_zip(x_flat, shardings_flat) - ] - return tree_unflatten(treedef, out_flat) - -mesh_cast_p = core.Primitive('mesh_cast') -mesh_cast_p.skip_canonicalization = True -def _mesh_cast_abstract_eval(aval, dst_sharding): - src_sharding = aval.sharding - if src_sharding == dst_sharding: - return aval - if src_sharding.mesh.empty or dst_sharding.mesh.empty: - return aval.update(sharding=dst_sharding) - if src_sharding.mesh.shape_tuple != dst_sharding.mesh.shape_tuple: - raise ValueError( - f'Mesh shape of the input {src_sharding.mesh.shape_tuple} does not' - ' match the mesh shape of the target sharding' - f' {dst_sharding.mesh.shape_tuple} for shape {aval.str_short()}') - if (src_sharding.mesh._axis_types_dict == dst_sharding.mesh._axis_types_dict - and src_sharding.spec != dst_sharding.spec): - raise ValueError( - 'mesh_cast should only be used when AxisType changes between the' - ' input mesh and the target mesh. Got src' - f' axis_types={src_sharding.mesh._axis_types_dict} and dst' - f' axis_types={dst_sharding.mesh._axis_types_dict}. To reshard between' - ' the same mesh, use `jax.sharding.reshard` instead?') - if src_sharding.mesh._any_axis_explicit and dst_sharding.mesh._any_axis_explicit: - for s, d in safe_zip(flatten_spec(src_sharding.spec), - flatten_spec(dst_sharding.spec)): - if s is None and d is None: - continue - if s is None and d is not None: - assert (src_sharding.mesh._name_to_type[d] == mesh_lib.AxisType.Auto - and dst_sharding.mesh._name_to_type[d] == mesh_lib.AxisType.Explicit) - continue - if s is not None and d is None: - assert (src_sharding.mesh._name_to_type[s] == mesh_lib.AxisType.Explicit - and dst_sharding.mesh._name_to_type[s] == mesh_lib.AxisType.Auto) - continue - if d != s: - raise ValueError( - 'Explicit data movement in mesh_cast is not allowed. Got src spec:' - f' {s} and dst spec: {d}') - return aval.update(sharding=dst_sharding) -mesh_cast_p.def_abstract_eval(_mesh_cast_abstract_eval) - -def _mesh_cast_impl(x, dst_sharding): - return dispatch.apply_primitive(mesh_cast_p, x, dst_sharding=dst_sharding) -mesh_cast_p.def_impl(_mesh_cast_impl) - -def _mesh_cast_transpose_rule(ct, x, dst_sharding): - return [mesh_cast_p.bind(ct, dst_sharding=x.aval.sharding)] -ad.deflinear2(mesh_cast_p, _mesh_cast_transpose_rule) - -def _mesh_cast_hlo_lowering(ctx, x_node, *, dst_sharding): - aval, = ctx.avals_in - aval_out, = ctx.avals_out - proto = (dst_sharding._to_sdy_sharding(aval.ndim) - if config.use_shardy_partitioner.value else - dst_sharding._to_xla_hlo_sharding(aval.ndim).to_proto()) - return [mlir.lower_with_sharding_in_types(ctx, x_node, aval_out, proto)] -mlir.register_lowering(mesh_cast_p, _mesh_cast_hlo_lowering) - -def _mesh_cast_batcher(axis_data, vals_in, dims_in, dst_sharding): - x, = vals_in - d, = dims_in - vmapped_dst_sharding = batching.get_sharding_for_vmap( - axis_data, dst_sharding, d) - y = mesh_cast_p.bind(x, dst_sharding=vmapped_dst_sharding) - return y, d -batching.fancy_primitive_batchers[mesh_cast_p] = _mesh_cast_batcher -batching.skippable_batchers[mesh_cast_p] = lambda _: () - # -------------------- reshard ------------------------------------ def reshard(xs, out_shardings): x_flat, treedef = tree_flatten(xs) - shardings_flat = flatten_axes("reshard shardings", treedef, out_shardings) + shardings_flat = flatten_axis_resources( + "reshard out_shardings", treedef, out_shardings, tupled_args=True) x_avals_flat = [core.shaped_abstractify(x) for x in x_flat] out_flat = [] for x, x_aval, s in safe_zip(x_flat, x_avals_flat, shardings_flat): - ds = canonicalize_sharding(s, 'reshard') - ds = ds.with_spec(ds.spec._normalized_spec_for_aval(x_aval.ndim)) # pytype: disable=attribute-error - out_flat.append(reshard_p.bind(x, dst_sharding=ds)) + ds = canonicalize_sharding(s, 'reshard', check_mesh_consistency=False) + if ds is None: + raise ValueError( + 'Reshard should only be used with out_shardings which are non-None ' + f'and have a nonempty mesh. Got sharding {s}.' + ) + ds = ds.update(spec=ds.spec._normalized_spec_for_aval(x_aval.ndim)) # pytype: disable=attribute-error + cmesh = (s.mesh if (isinstance(s, NamedSharding) and + isinstance(s.mesh, mesh_lib.Mesh)) + else None) + out_flat.append(reshard_p.bind(x, dst_sharding=ds, concrete_mesh=cmesh)) return tree_unflatten(treedef, out_flat) reshard_p = core.Primitive('reshard') +reshard_p.skip_canonicalization = True -def _reshard_abstract_eval(aval, dst_sharding): - src_sharding = aval.sharding - if (not src_sharding.mesh.empty and - src_sharding.mesh.abstract_mesh != dst_sharding.mesh.abstract_mesh): - raise ValueError( - f'Mesh of the input {src_sharding.mesh.abstract_mesh} does not' - ' equal the mesh of the target sharding' - f' {dst_sharding.mesh.abstract_mesh} for shape {aval.str_short()}') +def _reshard_abstract_eval(aval, *, dst_sharding, concrete_mesh): + assert isinstance(aval, core.ShapedArray) + if aval.sharding == dst_sharding: + return aval return aval.update(sharding=dst_sharding) reshard_p.def_abstract_eval(_reshard_abstract_eval) -def _reshard_impl(x, dst_sharding): - return dispatch.apply_primitive(reshard_p, x, dst_sharding=dst_sharding) +def _reshard_impl(x, *, dst_sharding, concrete_mesh): + thunk = lambda: dispatch.apply_primitive( + reshard_p, x, dst_sharding=dst_sharding, concrete_mesh=concrete_mesh) + if concrete_mesh is None: + return thunk() + else: + with sharding_impls.set_mesh(concrete_mesh): + return thunk() reshard_p.def_impl(_reshard_impl) -def _reshard_transpose_rule(ct, x, dst_sharding): - return [reshard_p.bind(ct, dst_sharding=x.aval.sharding)] +def _reshard_transpose_rule(ct, x, *, dst_sharding, concrete_mesh): + assert ad.is_undefined_primal(x) + out_sharding = x.aval.to_cotangent_aval().sharding + with mesh_lib.use_abstract_mesh(out_sharding.mesh): + x_bar = reshard_p.bind(ct, dst_sharding=out_sharding, + concrete_mesh=concrete_mesh) + return [x_bar] ad.deflinear2(reshard_p, _reshard_transpose_rule) -def _reshard_hlo_lowering(ctx, x_node, *, dst_sharding): - aval, = ctx.avals_in +def _reshard_transpose_fancy(ct, x, *, dst_sharding, concrete_mesh): + assert isinstance(x, ad.GradAccum) + if type(ct) is ad.Zero: + return + out_sharding = x.aval.to_cotangent_aval().sharding + with mesh_lib.use_abstract_mesh(out_sharding.mesh): + x_bar = reshard_p.bind(ct, dst_sharding=out_sharding, + concrete_mesh=concrete_mesh) + x.accum(x_bar) +ad.fancy_transposes[reshard_p] = _reshard_transpose_fancy + +def _reshard_hlo_lowering(ctx, x_node, *, dst_sharding, concrete_mesh): + aval_in, = ctx.avals_in aval_out, = ctx.avals_out - proto = (dst_sharding._to_sdy_sharding(aval.ndim) + if dtypes.issubdtype(aval_in.dtype, dtypes.extended): + aval_in = core.physical_aval(aval_in) + proto = (dst_sharding._to_sdy_sharding(aval_in.ndim) if config.use_shardy_partitioner.value else - dst_sharding._to_xla_hlo_sharding(aval.ndim).to_proto()) + dst_sharding._to_xla_hlo_sharding(aval_in.ndim).to_proto()) return [mlir.lower_with_sharding_in_types(ctx, x_node, aval_out, proto)] mlir.register_lowering(reshard_p, _reshard_hlo_lowering) -def _reshard_batcher(axis_data, vals_in, dims_in, dst_sharding): - assert axis_data.spmd_name is None +def _reshard_batcher(axis_data, vals_in, dims_in, dst_sharding, concrete_mesh): x, = vals_in d, = dims_in vmapped_dst_sharding = batching.get_sharding_for_vmap( axis_data, dst_sharding, d) - y = reshard_p.bind(x, dst_sharding=vmapped_dst_sharding) + y = reshard_p.bind(x, dst_sharding=vmapped_dst_sharding, + concrete_mesh=concrete_mesh) return y, d batching.fancy_primitive_batchers[reshard_p] = _reshard_batcher batching.skippable_batchers[reshard_p] = lambda _: () @@ -2821,125 +2321,159 @@ def _reshard_batcher(axis_data, vals_in, dims_in, dst_sharding): # -------------------- auto and user mode ------------------------- def _get_new_mesh(axes: str | tuple[str, ...] | None, - axis_type: mesh_lib.AxisType, name: str, - error_on_manual_to_auto_explict=False): + axis_type: mesh_lib.AxisType, name: str, shardings=None): cur_mesh = mesh_lib.get_abstract_mesh() - # TODO(yashkatariya): Maybe allow fetching mesh from the args to enable - # computation follows data? - if cur_mesh.empty: + flat_shardings, _ = tree_flatten(shardings) + sharding_mesh = mesh_lib.empty_abstract_mesh + for i in flat_shardings: + if isinstance(i, NamedSharding): + if not sharding_mesh.empty and sharding_mesh != i.mesh.abstract_mesh: + raise ValueError( + f'Shardings passed to {name} should have the same mesh. Got one' + f' mesh {sharding_mesh} and another {i.mesh}') + sharding_mesh = i.mesh.abstract_mesh + + if sharding_mesh.empty and cur_mesh.empty: raise ValueError( f'Context mesh {cur_mesh} cannot be empty. Please use' - ' `jax.sharding.use_mesh` API to enter into a mesh context when using' + ' `jax.set_mesh` API to enter into a mesh context when using' f' `{name}` API.') + if not sharding_mesh.empty and not cur_mesh.empty: + if sharding_mesh != cur_mesh: + raise ValueError( + f'Context mesh {cur_mesh} must match the mesh passed to shardings' + f' {sharding_mesh}. Recommended approach is to use' + ' `jax.set_mesh` context manager.') + mesh_to_use = cur_mesh + elif sharding_mesh.empty and not cur_mesh.empty: + mesh_to_use = cur_mesh + else: + assert not sharding_mesh.empty and cur_mesh.empty + mesh_to_use = sharding_mesh + if axes is None: - axes = cur_mesh.axis_names + axes = mesh_to_use.axis_names if not isinstance(axes, tuple): axes = (axes,) for a in axes: - if (error_on_manual_to_auto_explict and - cur_mesh._name_to_type[a] == mesh_lib.AxisType.Manual and + if (mesh_to_use._name_to_type[a] == mesh_lib.AxisType.Manual and axis_type in {mesh_lib.AxisType.Auto, mesh_lib.AxisType.Explicit}): raise NotImplementedError( 'Going from `Manual` AxisType to `Auto` or `Explicit` AxisType is not' ' allowed. Please file a bug at https://github.com/jax-ml/jax/issues' ' with your use case') - return cur_mesh.update_axis_types({a: axis_type for a in axes}) - -def auto_axes(fun, *, axes: str | tuple[str, ...] | None = None, - out_shardings=None): + return (mesh_to_use.update_axis_types({a: axis_type for a in axes}), + mesh_to_use, axes) + +def auto_axes(f=None, /, *, axes: str | tuple[str, ...] | None = None, + out_sharding=None): + kwargs = dict(axes_=axes, out_sharding=out_sharding) + if f is None: + return lambda g: _auto_axes(g, **kwargs) + return _auto_axes(f, **kwargs) + +def _auto_axes(fun, *, axes_, out_sharding): + @wraps(fun) def decorator(*args, **kwargs): - if out_shardings is None: - if "out_shardings" in kwargs: - _out_shardings = kwargs.pop("out_shardings") + if out_sharding is None: + if "out_sharding" in kwargs: + _out_sharding = kwargs.pop("out_sharding") else: - raise TypeError("Missing required keyword argument: 'out_shardings'") + raise TypeError("Missing required keyword argument: 'out_sharding'") else: - _out_shardings = out_shardings - new_mesh = _get_new_mesh(axes, mesh_lib.AxisType.Auto, 'auto_axes', - error_on_manual_to_auto_explict=True) + _out_sharding = out_sharding + new_mesh, prev_mesh, axes = _get_new_mesh( + axes_, mesh_lib.AxisType.Auto, 'auto_axes', shardings=_out_sharding) + if set(prev_mesh.auto_axes) == set(axes): + return fun(*args, **kwargs) with mesh_lib.use_abstract_mesh(new_mesh): in_specs = tree_map(lambda a: core.modify_spec_for_auto_manual( core.get_aval(a).sharding.spec, new_mesh), args) - args = mesh_cast(args, in_specs) + args = reshard(args, in_specs) out = fun(*args, **kwargs) - return mesh_cast(out, _out_shardings) + return reshard(out, _out_sharding) return decorator -@contextlib.contextmanager -def use_auto_axes(*axes): - new_mesh = _get_new_mesh(axes, mesh_lib.AxisType.Auto, 'use_auto_axes') - with mesh_lib.use_abstract_mesh(new_mesh): - yield +def explicit_axes(f=None, /, *, axes: str | tuple[str, ...] | None = None, + in_sharding=None): + kwargs = dict(axes=axes, in_sharding=in_sharding) + if f is None: + return lambda g: _explicit_axes(g, **kwargs) + return _explicit_axes(f, **kwargs) -def explicit_axes(fun, *, axes: str | tuple[str, ...] | None = None, - in_shardings=None): +def _explicit_axes(fun, *, axes, in_sharding): + @wraps(fun) def decorator(*args, **kwargs): - if in_shardings is None: - if "in_shardings" in kwargs: - _in_shardings = kwargs.pop("in_shardings") + if in_sharding is None: + if "in_sharding" in kwargs: + _in_sharding = kwargs.pop("in_sharding") else: - raise TypeError("Missing required keyword argument: 'in_shardings'") + raise TypeError("Missing required keyword argument: 'in_sharding'") else: - _in_shardings = in_shardings - new_mesh = _get_new_mesh(axes, mesh_lib.AxisType.Explicit, 'explicit_axes', - error_on_manual_to_auto_explict=True) + _in_sharding = in_sharding + new_mesh, _, _ = _get_new_mesh(axes, mesh_lib.AxisType.Explicit, + 'explicit_axes') with mesh_lib.use_abstract_mesh(new_mesh): - args = mesh_cast(args, _in_shardings) + args = reshard(args, _in_sharding) out = fun(*args, **kwargs) out_specs = tree_map(lambda o: core.modify_spec_for_auto_manual( core.get_aval(o).sharding.spec, mesh_lib.get_abstract_mesh()), out) - return mesh_cast(out, out_specs) + return reshard(out, out_specs) return decorator -@contextlib.contextmanager -def use_explicit_axes(*axes): - new_mesh = _get_new_mesh(axes, mesh_lib.AxisType.Explicit, - 'use_explicit_axes') - with mesh_lib.use_abstract_mesh(new_mesh): - yield - -# -------------------- helpers -------------------- - -def get_unconstrained_dims(sharding: NamedSharding): - assert sharding.spec is not None - return {i for i, axes in enumerate(sharding.spec) - if axes is PartitionSpec.UNCONSTRAINED} - - -def get_op_sharding_from_executable( - executable) -> tuple[Sequence[xc.OpSharding], Sequence[xc.OpSharding]]: - in_op_shardings: list[xc.OpSharding] = [] - parameter_shardings_from_xla = executable.get_parameter_shardings() - if parameter_shardings_from_xla is not None: - in_op_shardings = parameter_shardings_from_xla +# -------------------- with_layout_constraint -------------------- - out_op_shardings: list[xc.OpSharding] = [] - output_shardings_from_xla = executable.get_output_shardings() - if output_shardings_from_xla is not None: - out_op_shardings = output_shardings_from_xla +def with_layout_constraint(x, layouts): + x_flat, tree = tree_flatten(x) + x_avals_flat = [core.shaped_abstractify(x) for x in x_flat] + layouts_flat = tuple(flatten_axes("with_layout_constraint layouts", tree, + layouts)) + if any(not isinstance(l, Layout) for l in layouts_flat): + raise ValueError( + 'layouts passed to `with_layout_constraint` must be of type' + f' `Layout`. Got {[type(l) for l in layouts_flat]}') + check_aval_layout_compatibility( + layouts_flat, x_avals_flat, ("",) * len(layouts_flat), + "with_layout_constraint arguments") + outs = [layout_constraint_p.bind(xf, layout=l) + for xf, l in zip(x_flat, layouts_flat)] + return tree_unflatten(tree, outs) - return in_op_shardings, out_op_shardings +layout_constraint_p = core.Primitive('layout_constraint') +layout_constraint_p.def_abstract_eval(lambda x, **_: x) +ad.deflinear2(layout_constraint_p, + lambda ct, _, **params: (layout_constraint_p.bind(ct, **params),)) +def _layout_constraint_impl(x, *, layout): + if not isinstance(x, xc.ArrayImpl): + raise ValueError( + 'with_layout_constraint in eager mode can only be applied to' + f' jax.Arrays. Got {type(x)}') + if x.format.layout == layout: # type: ignore + return x + return api.jit(_identity_fn, out_shardings=Format(layout, x.sharding))(x) +layout_constraint_p.def_impl(_layout_constraint_impl) -def _get_ppspec_from_executable( - executable, mesh - ) -> tuple[Sequence[PartitionSpec], Sequence[PartitionSpec]]: - input_op_shardings, output_op_sharding = get_op_sharding_from_executable( - executable - ) - in_pspec: list[PartitionSpec] = [] - for s in input_op_shardings: - in_pspec.extend(parse_flatten_op_sharding(s, mesh)) +def _layout_constraint_hlo_lowering(ctx, x_node, *, layout): + aval, = ctx.avals_in + out_aval, = ctx.avals_out + return [mlir.wrap_with_layout_op(ctx, x_node, out_aval, layout, aval)] +mlir.register_lowering(layout_constraint_p, + _layout_constraint_hlo_lowering) - out_pspec: list[PartitionSpec] = [] - for s in output_op_sharding: - out_pspec.extend(parse_flatten_op_sharding(s, mesh)) - return in_pspec, out_pspec +def _layout_constraint_batcher(axis_data, vals_in, dims_in, layout): + x, = vals_in + d, = dims_in + vmapped_layout = get_layout_for_vmap(d, layout) + y = layout_constraint_p.bind(x, layout=vmapped_layout) + return y, d +batching.fancy_primitive_batchers[layout_constraint_p] = _layout_constraint_batcher +batching.skippable_batchers[layout_constraint_p] = lambda _: () +# -------------------- helpers -------------------- -def get_pspec_from_executable( - executable, mesh: pxla.Mesh -) -> tuple[tuple[PartitionSpec, ...], tuple[PartitionSpec, ...]]: - in_pspec, out_pspec = _get_ppspec_from_executable(executable, mesh) - return tuple(in_pspec), tuple(out_pspec) +def get_unconstrained_dims(sharding: NamedSharding): + assert sharding.spec is not None + return frozenset(i for i, axes in enumerate(sharding.spec) + if axes is PartitionSpec.UNCONSTRAINED) diff --git a/jax/_src/pmap.py b/jax/_src/pmap.py new file mode 100644 index 000000000000..ea59eaa4e6cb --- /dev/null +++ b/jax/_src/pmap.py @@ -0,0 +1,123 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from functools import partial + +from jax._src import core +from jax._src import dispatch +from jax._src import linear_util as lu +from jax._src import stages +from jax._src import traceback_util +from jax._src import util +from jax._src import xla_bridge as xb +from jax._src.shard_map import _shard_map, _axes_to_pspec +from jax._src.api import _shared_code_pmap, _prepare_pmap, jit +from jax._src.mesh import Mesh +from jax._src.lax import lax +from jax._src.tree_util import tree_map, tree_unflatten + +map, unsafe_map = util.safe_map, map +zip, unsafe_zip = util.safe_zip, zip +traceback_util.register_exclusion(__file__) + +# Implementing pmap in terms of shard_map + +def pmap(f, axis_name=None, *, in_axes=0, out_axes=0, + static_broadcasted_argnums=(), devices=None, backend=None, + axis_size=None, donate_argnums=(), global_arg_shapes=None): + del global_arg_shapes + # TODO(vanderplas): move these definitions into jax._src and avoid local import. + import jax.experimental.multihost_utils as mhu # pytype: disable=import-error + devices = tuple(devices) if devices is not None else devices + axis_name, static_broadcasted_tuple, donate_tuple = _shared_code_pmap( + f, axis_name, static_broadcasted_argnums, donate_argnums, in_axes, out_axes) + if isinstance(axis_name, core._TempAxisName): + axis_name = repr(axis_name) + + def infer_params(*args, __check=True, **kwargs): + p = _prepare_pmap(f, in_axes, out_axes, static_broadcasted_tuple, + donate_tuple, devices, backend, axis_size, args, kwargs) + if __check: + for arg in p.flat_args: + dispatch.check_arg(arg) + mesh = Mesh(_get_devices(p, backend), (axis_name,)) + _pmapped, in_specs, out_specs = _cached_shard_map( + p.flat_fun, mesh, p.in_axes_flat, p.out_axes_thunk, axis_name) + jitted_f = jit( + _pmapped, + donate_argnums=[i for i, val in enumerate(p.donated_invars) if val]) + if __check and xb.process_count() > 1: + flat_global_args = mhu.host_local_array_to_global_array( + p.flat_args, mesh, list(in_specs)) + else: + flat_global_args = p.flat_args + return jitted_f, flat_global_args, p, mesh, out_specs, donate_tuple + + @util.wraps(f) + def wrapped(*args, **kwargs): + jitted_f, flat_global_args, p, mesh, out_specs, _ = infer_params( + *args, **kwargs) + outs = jitted_f(*flat_global_args) + if xb.process_count() > 1: + outs = mhu.global_array_to_host_local_array(outs, mesh, out_specs()) + return tree_unflatten(p.out_tree(), outs) + + def lower(*args, **kwargs): + jitted_f, flat_global_args, p, _, _, donate_tuple = infer_params( + *args, __check=False, **kwargs + ) + abstract_args = list(map(core.shaped_abstractify, flat_global_args)) + args_info = stages.make_args_info(p.in_tree, abstract_args, donate_tuple) + lowered = jitted_f.trace(*flat_global_args).lower() + lowered = stages.Lowered(lowered._lowering, args_info, p.out_tree(), + no_kwargs=lowered._no_kwargs) + return lowered + wrapped.lower = lower + return wrapped + + +@lu.cache +def _cached_shard_map(flat_fun, mesh, in_axes_flat, out_axes_thunk, axis_name): + f_transformed = flat_fun.f_transformed + def reset_stores_f_transformed(*args, **kwargs): + for store in flat_fun.stores: + if store is not None: + store.reset() + return f_transformed(*args, **kwargs) + flat_fun.f_transformed = reset_stores_f_transformed + in_specs = tuple(map(partial(_axes_to_pspec, axis_name), in_axes_flat)) + out_specs = lambda: map(partial(_axes_to_pspec, axis_name), out_axes_thunk()) + fun = _handle_reshapes(flat_fun, in_axes_flat, out_axes_thunk) + return (_shard_map(fun.call_wrapped, mesh=mesh, in_specs=in_specs, + out_specs=out_specs, check_vma=False, + axis_names=set(mesh.axis_names)), + in_specs, out_specs) + +@lu.transformation2 +def _handle_reshapes(f, in_axes, out_axes_thunk, *args, **kwargs): + args = tree_map(lambda x, ax: x if ax is None else lax.squeeze(x, [ax]), + list(args), list(in_axes)) + out = f(*args) + return tree_map(lambda x, ax: x if ax is None else lax.expand_dims(x, [ax]), + list(out), list(out_axes_thunk())) + +def _get_devices(p, backend): + if backend is not None and p.devices is None: + devs = xb.devices(backend=backend) + else: + devs = xb.devices() if p.devices is None else p.devices + if xb.process_count() > 1: + return devs[:p.global_axis_size] + return devs[:p.local_axis_size] diff --git a/jax/_src/pretty_printer.py b/jax/_src/pretty_printer.py index e8fdff497445..b29b5d478158 100644 --- a/jax/_src/pretty_printer.py +++ b/jax/_src/pretty_printer.py @@ -28,18 +28,12 @@ from __future__ import annotations from collections.abc import Sequence -import enum from functools import partial import sys -from typing import Any, NamedTuple +from typing import Any from jax._src import config -from jax._src import util - -try: - import colorama # pytype: disable=import-error -except ImportError: - colorama = None +from jax._src.lib import _pretty_printer as _pretty_printer _PPRINT_USE_COLOR = config.bool_state( @@ -66,409 +60,40 @@ def _can_use_color() -> bool: CAN_USE_COLOR = _can_use_color() -class Doc(util.StrictABC): - __slots__ = () - - def format( - self, width: int = 80, *, use_color: bool | None = None, - annotation_prefix: str = " # ", - source_map: list[list[tuple[int, int, Any]]] | None = None - ) -> str: - """ - Formats a pretty-printer document as a string. - - Args: - source_map: for each line in the output, contains a list of - (start column, end column, source) tuples. Each tuple associates a - region of output text with a source. - """ - if use_color is None: - use_color = CAN_USE_COLOR and _PPRINT_USE_COLOR.value - return _format(self, width, use_color=use_color, - annotation_prefix=annotation_prefix, source_map=source_map) - - def __str__(self): - return self.format() - - def __add__(self, other: Doc) -> Doc: - return concat([self, other]) - -class _NilDoc(Doc): - def __repr__(self): return "nil" - -_nil = _NilDoc() - -class _TextDoc(Doc): - __slots__ = ("text", "annotation") - text: str - annotation: str | None - - def __init__(self, text: str, annotation: str | None = None): - assert isinstance(text, str), text - assert annotation is None or isinstance(annotation, str), annotation - self.text = text - self.annotation = annotation - - def __repr__(self): - if self.annotation is not None: - return f"text(\"{self.text}\", annotation=\"{self.annotation}\")" - else: - return f"text(\"{self.text}\")" - -class _ConcatDoc(Doc): - __slots__ = ("children",) - children: list[Doc] - - def __init__(self, children: Sequence[Doc]): - self.children = list(children) - assert all(isinstance(doc, Doc) for doc in self.children), self.children - - def __repr__(self): return f"concat({self.children})" - -class _BreakDoc(Doc): - __slots__ = ("text",) - text: str - - def __init__(self, text: str): - assert isinstance(text, str), text - self.text = text - - def __repr__(self): return f"break({self.text})" - -class _GroupDoc(Doc): - __slots__ = ("child",) - child: Doc - - def __init__(self, child: Doc): - assert isinstance(child, Doc), child - self.child = child - - def __repr__(self): return f"group({self.child})" - -class _NestDoc(Doc): - __slots__ = ("n", "child",) - n: int - child: Doc - - def __init__(self, n: int, child: Doc): - assert isinstance(child, Doc), child - self.n = n - self.child = child - - def __repr__(self): return f"nest({self.n, self.child})" - - -_NO_SOURCE = object() - -class _SourceMapDoc(Doc): - __slots__ = ("child", "source") - child: Doc - source: Any - - def __init__(self, child: Doc, source: Any): - assert isinstance(child, Doc), child - self.child = child - self.source = source - - def __repr__(self): return f"source({self.child}, {self.source})" - - -Color = enum.Enum("Color", ["BLACK", "RED", "GREEN", "YELLOW", "BLUE", - "MAGENTA", "CYAN", "WHITE", "RESET"]) -Intensity = enum.Enum("Intensity", ["DIM", "NORMAL", "BRIGHT"]) - -class _ColorDoc(Doc): - __slots__ = ("foreground", "background", "intensity", "child") - foreground: Color | None - background: Color | None - intensity: Intensity | None - child: Doc - - def __init__(self, child: Doc, *, foreground: Color | None = None, - background: Color | None = None, - intensity: Intensity | None = None): - assert isinstance(child, Doc), child - self.child = child - self.foreground = foreground - self.background = background - self.intensity = intensity - - -_BreakMode = enum.Enum("_BreakMode", ["FLAT", "BREAK"]) - - -# In Lindig's paper fits() and format() are defined recursively. This is a -# non-recursive formulation using an explicit stack, necessary because Python -# doesn't have a tail recursion optimization. - -def _fits(doc: Doc, width: int, agenda: list[tuple[int, _BreakMode, Doc]] - ) -> bool: - while width >= 0 and len(agenda) > 0: - i, m, doc = agenda.pop() - if isinstance(doc, _NilDoc): - pass - elif isinstance(doc, _TextDoc): - width -= len(doc.text) - elif isinstance(doc, _ConcatDoc): - agenda.extend((i, m, d) for d in reversed(doc.children)) - elif isinstance(doc, _BreakDoc): - if m == _BreakMode.BREAK: - return True - width -= len(doc.text) - elif isinstance(doc, _NestDoc): - agenda.append((i + doc.n, m, doc.child)) - elif isinstance(doc, _GroupDoc): - agenda.append((i, _BreakMode.FLAT, doc.child)) - elif isinstance(doc, _ColorDoc) or isinstance(doc, _SourceMapDoc): - agenda.append((i, m, doc.child)) - else: - raise ValueError("Invalid document ", doc) - - return width >= 0 - - -# Annotation layout: A flat group is sparse if there are no breaks between -# annotations. -def _sparse(doc: Doc) -> bool: - agenda = [doc] - num_annotations = 0 - seen_break = False - while len(agenda) > 0: - doc = agenda.pop() - if isinstance(doc, _NilDoc): - pass - elif isinstance(doc, _TextDoc): - if doc.annotation is not None: - if num_annotations >= 1 and seen_break: - return False - num_annotations += 1 - elif isinstance(doc, _ConcatDoc): - agenda.extend(reversed(doc.children)) - elif isinstance(doc, _BreakDoc): - seen_break = True - elif isinstance(doc, _NestDoc): - agenda.append(doc.child) - elif isinstance(doc, _GroupDoc): - agenda.append(doc.child) - elif isinstance(doc, _ColorDoc) or isinstance(doc, _SourceMapDoc): - agenda.append(doc.child) - else: - raise ValueError("Invalid document ", doc) - - return True - -class _ColorState(NamedTuple): - foreground: Color - background: Color - intensity: Intensity - -class _State(NamedTuple): - indent: int - mode: _BreakMode - doc: Doc - color: _ColorState - source_map: Any - -class _Line(NamedTuple): - text: str - width: int - annotations: str | None | list[str] - - -def _update_color(use_color: bool, state: _ColorState, update: _ColorState - ) -> tuple[_ColorState, str]: - if not use_color or colorama is None: - return update, "" - color_str = "" - if state.foreground != update.foreground: - color_str += getattr(colorama.Fore, str(update.foreground.name)) - if state.background != update.background: - color_str += getattr(colorama.Back, str(update.background.name)) - if state.intensity != update.intensity: - color_str += colorama.Style.NORMAL # pytype: disable=unsupported-operands - color_str += getattr(colorama.Style, str(update.intensity.name)) - return update, color_str - - -def _align_annotations(lines): - # TODO: Hafiz also implements a local alignment mode, where groups of lines - # with annotations are aligned together. - maxlen = max(l.width for l in lines) - out = [] - for l in lines: - if len(l.annotations) == 0: - out.append(l._replace(annotations=None)) - elif len(l.annotations) == 1: - out.append(l._replace(text=l.text + " " * (maxlen - l.width), - annotations=l.annotations[0])) - else: - out.append(l._replace(text=l.text + " " * (maxlen - l.width), - annotations=l.annotations[0])) - for a in l.annotations[1:]: - out.append(_Line(text=" " * maxlen, width=l.width, annotations=a)) - return out - - +Color = _pretty_printer.Color +Intensity = _pretty_printer.Intensity +Doc = _pretty_printer.Doc def _format( - doc: Doc, width: int, *, use_color: bool, annotation_prefix: str, - source_map: list[list[tuple[int, int, Any]]] | None + self, width: int = 80, *, use_color: bool | None = None, + annotation_prefix: str = " # ", + source_map: list[list[tuple[int, int, Any]]] | None = None ) -> str: - lines = [] - default_colors = _ColorState(Color.RESET, Color.RESET, Intensity.NORMAL) - annotation_colors = _ColorState(Color.RESET, Color.RESET, Intensity.DIM) - color_state = default_colors - source_start = 0 # The column at which the current source region starts. - source = _NO_SOURCE # The currently active source region. - line_source_map = [] # Source maps for the current line of text. - agenda = [_State(0, _BreakMode.BREAK, doc, default_colors, source)] - k = 0 - line_text = "" - line_annotations = [] - while len(agenda) > 0: - i, m, doc, color, agenda_source = agenda.pop() - if source_map is not None and agenda_source != source: - pos = len(line_text) - if source_start != pos and source is not _NO_SOURCE: - line_source_map.append((source_start, pos, source)) - source = agenda_source - source_start = pos - if isinstance(doc, _NilDoc): - pass - elif isinstance(doc, _TextDoc): - color_state, color_str = _update_color(use_color, color_state, color) - line_text += color_str - line_text += doc.text - if doc.annotation is not None: - line_annotations.append(doc.annotation) - k += len(doc.text) - elif isinstance(doc, _ConcatDoc): - agenda.extend(_State(i, m, d, color, source) - for d in reversed(doc.children)) - elif isinstance(doc, _BreakDoc): - if m == _BreakMode.BREAK: - if len(line_annotations) > 0: - color_state, color_str = _update_color(use_color, color_state, - annotation_colors) - line_text += color_str - lines.append(_Line(line_text, k, line_annotations)) - if source_map is not None: - pos = len(line_text) - if source_start != pos and source is not _NO_SOURCE: - line_source_map.append((source_start, pos, source)) - source_map.append(line_source_map) - line_source_map = [] - source_start = i - line_text = " " * i - line_annotations = [] - k = i - else: - color_state, color_str = _update_color(use_color, color_state, color) - line_text += color_str - line_text += doc.text - k += len(doc.text) - elif isinstance(doc, _NestDoc): - agenda.append(_State(i + doc.n, m, doc.child, color, source)) - elif isinstance(doc, _GroupDoc): - # In Lindig's paper, _fits is passed the remainder of the document. - # I'm pretty sure that's a bug and we care only if the current group fits! - if (_sparse(doc) - and _fits(doc, width - k, [(i, _BreakMode.FLAT, doc.child)])): - agenda.append(_State(i, _BreakMode.FLAT, doc.child, color, source)) - else: - agenda.append(_State(i, _BreakMode.BREAK, doc.child, color, source)) - elif isinstance(doc, _ColorDoc): - color = _ColorState(doc.foreground or color.foreground, - doc.background or color.background, - doc.intensity or color.intensity) - agenda.append(_State(i, m, doc.child, color, source)) - elif isinstance(doc, _SourceMapDoc): - agenda.append(_State(i, m, doc.child, color, doc.source)) - else: - raise ValueError("Invalid document ", doc) - - if len(line_annotations) > 0: - color_state, color_str = _update_color(use_color, color_state, - annotation_colors) - line_text += color_str - if source_map is not None: - pos = len(line_text) - if source_start != pos and source is not _NO_SOURCE: - line_source_map.append((source_start, pos, source)) - source_map.append(line_source_map) - lines.append(_Line(line_text, k, line_annotations)) - lines = _align_annotations(lines) - out = "\n".join( - l.text if l.annotations is None - else f"{l.text}{annotation_prefix}{l.annotations}" for l in lines) - color_state, color_str = _update_color(use_color, color_state, - default_colors) - return out + color_str - - - - -# Public API. - -def nil() -> Doc: - """An empty document.""" - return _nil - -def text(s: str, annotation: str | None = None) -> Doc: - """Literal text.""" - return _TextDoc(s, annotation) - -def concat(docs: Sequence[Doc]) -> Doc: - """Concatenation of documents.""" - docs = list(docs) - if len(docs) == 1: - return docs[0] - return _ConcatDoc(docs) - -def brk(text: str = " ") -> Doc: - """A break. - - Prints either as a newline or as `text`, depending on the enclosing group. - """ - return _BreakDoc(text) - -def group(doc: Doc) -> Doc: - """Layout alternative groups. - - Prints the group with its breaks as their text (typically spaces) if the - entire group would fit on the line when printed that way. Otherwise, breaks - inside the group as printed as newlines. """ - return _GroupDoc(doc) - -def nest(n: int, doc: Doc) -> Doc: - """Increases the indentation level by `n`.""" - return _NestDoc(n, doc) - + Formats a pretty-printer document as a string. -def color(doc: Doc, *, foreground: Color | None = None, - background: Color | None = None, - intensity: Intensity | None = None): - """ANSI colors. - - Overrides the foreground/background/intensity of the text for the child doc. - Requires use_colors=True to be set when printing and the `colorama` package - to be installed; otherwise does nothing. + Args: + source_map: for each line in the output, contains a list of + (start column, end column, source) tuples. Each tuple associates a + region of output text with a source. """ - return _ColorDoc(doc, foreground=foreground, background=background, - intensity=intensity) - + if use_color is None: + use_color = CAN_USE_COLOR and _PPRINT_USE_COLOR.value + return self._format( + width, use_color=use_color, annotation_prefix=annotation_prefix, + source_map=source_map) +Doc.format = _format +Doc.__str__ = lambda self: self.format() # type: ignore[method-assign] + +nil = _pretty_printer.nil +text = _pretty_printer.text +concat = _pretty_printer.concat +brk = _pretty_printer.brk +group = _pretty_printer.group +nest = _pretty_printer.nest +color = _pretty_printer.color +source_map = _pretty_printer.source_map -def source_map(doc: Doc, source: Any): - """Source mapping. - - A source map associates a region of the pretty-printer's text output with a - source location that produced it. For the purposes of the pretty printer a - ``source`` may be any object: we require only that we can compare sources for - equality. A text region to source object mapping can be populated as a side - output of the ``format`` method. - """ - return _SourceMapDoc(doc, source) type_annotation = partial(color, intensity=Intensity.NORMAL, foreground=Color.MAGENTA) @@ -480,6 +105,8 @@ def join(sep: Doc, docs: Sequence[Doc]) -> Doc: docs = list(docs) if len(docs) == 0: return nil() + if len(docs) == 1: + return docs[0] xs = [docs[0]] for doc in docs[1:]: xs.append(sep) diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 2fa9b2b37aa4..15f648e89789 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -21,19 +21,17 @@ import numpy as np -import jax -from jax import lax -from jax import numpy as jnp -from jax import tree_util - from jax._src import api from jax._src import config as config from jax._src import core from jax._src import dispatch from jax._src import dtypes +from jax._src import ffi +from jax._src import literals +from jax._src import numpy as jnp from jax._src import pretty_printer as pp from jax._src import source_info_util -from jax._src import tree_util as tree_util_internal +from jax._src import tree_util from jax._src import typing from jax._src.api import jit, vmap from jax._src.dtypes import float0 @@ -41,8 +39,9 @@ from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.interpreters import pxla -from jax._src.interpreters import xla -from jax._src.lax import lax as lax_internal +from jax._src.lax import control_flow as lax_control_flow +from jax._src.lax import lax +from jax._src.lax import slicing as lax_slicing from jax._src.lib import gpu_prng from jax._src.lib import xla_client as xc from jax._src.lib.mlir import ir @@ -50,7 +49,8 @@ from jax._src.numpy.array_methods import ( _array_operators, _set_array_base_attributes, _IndexUpdateHelper) from jax._src.sharding_impls import ( - NamedSharding, PmapSharding, physical_sharding, logical_sharding) + NamedSharding, PmapSharding, SingleDeviceSharding, physical_sharding, + logical_sharding) from jax._src.typing import Array from jax._src.util import safe_map, safe_zip @@ -61,8 +61,19 @@ Shard = Any # TODO(jakevdp): fix circular imports and import Shard Shape = tuple[int, ...] -UINT_DTYPES = { - 8: jnp.uint8, 16: jnp.uint16, 32: jnp.uint32, 64: jnp.uint64} +UINT_DTYPES: dict[int, np.dtype] = { + 8: np.dtype('uint8'), + 16: np.dtype('uint16'), + 32: np.dtype('uint32'), + 64: np.dtype('uint64'), +} + +if hasattr(gpu_prng, "registrations"): + for platform, targets in gpu_prng.registrations().items(): + for name, value, api_version in targets: + ffi.register_ffi_target( + name, value, platform=platform, api_version=api_version + ) # -- PRNG implementation interface @@ -105,7 +116,7 @@ def pprint(self): ])))) -prngs = {} +prngs: dict[str, PRNGImpl] = {} def register_prng(impl: PRNGImpl): if impl.name in prngs: @@ -131,7 +142,7 @@ def _check_prng_key_data(impl, key_data: typing.Array): f"got dtype={key_data.dtype}") -class PRNGKeyArray(jax.Array): +class PRNGKeyArray(Array): """An array of PRNG keys backed by an RNG implementation. This class lifts the definition of a PRNG, provided in the form of a @@ -148,7 +159,7 @@ class behave like an array whose base elements are keys, hiding the # device_buffer, device_buffers, __cuda_interface__() _impl: PRNGImpl - _base_array: typing.Array + _base_array: Array _consumed: bool | np.ndarray # Used in jax.experimental.key_reuse. _source_info: None | source_info_util.SourceInfo = None @@ -156,8 +167,17 @@ def __init__(self, impl, key_data: Any): assert not isinstance(key_data, core.Tracer) _check_prng_key_data(impl, key_data) self._impl = impl - self._base_array = key_data self._consumed = False # TODO(jakevdp): default to True here? + if isinstance(key_data, (np.ndarray, literals.TypedNdArray)): + aval = core.get_aval(key_data) + device = pxla.get_default_device() + key_data = pxla.batched_device_put( + aval, SingleDeviceSharding(device), [np.asarray(key_data)], [device], + committed=False) + self._base_array = key_data + + def _replace_with(self, value: PRNGKeyArray): + self._base_array._replace_with(value._base_array) def block_until_ready(self): _ = self._base_array.block_until_ready() @@ -168,9 +188,8 @@ def copy_to_host_async(self): @property def aval(self): - logical_sharding = (self.sharding if hasattr(self._base_array, 'sharding') - else None) - return keys_shaped_array(self._impl, self.shape, logical_sharding) + vma = self._base_array.aval.vma + return keys_shaped_array(self._impl, self.shape, self.sharding, vma) @property def shape(self): @@ -188,6 +207,10 @@ def ndim(self): def dtype(self): return KeyTy(self._impl) + @property + def nbytes(self): + return self.itemsize * self.size + @property def itemsize(self): return self.dtype.itemsize @@ -312,7 +335,7 @@ def prngkeyarray_unflatten(impl, children): base_array, = children return PRNGKeyArray(impl, base_array) -tree_util_internal.dispatch_registry.register_node( +tree_util.dispatch_registry.register_node( PRNGKeyArray, prngkeyarray_flatten, prngkeyarray_unflatten) @@ -321,9 +344,9 @@ def seed_with_impl(impl: PRNGImpl, seed: int | typing.ArrayLike) -> PRNGKeyArray return random_seed(seed, impl=impl) -def keys_shaped_array(impl, shape, sharding): +def keys_shaped_array(impl, shape, sharding, vma): aval = core.ShapedArray(shape, KeyTy(impl)) - return core.update_aval_with_sharding(aval, sharding) + return core.update_aval_with_sharding(aval, sharding, vma=vma) def base_arr_shape_to_keys_shape(impl, base_arr_shape): base_ndim = len(impl.key_shape) @@ -336,7 +359,7 @@ class KeyTyRules: @staticmethod def full(shape, fill_value, dtype): physical_shape = (*shape, *dtype._impl.key_shape) - if hasattr(fill_value, 'dtype') and jnp.issubdtype(fill_value.dtype, dtypes.prng_key): + if hasattr(fill_value, 'dtype') and dtypes.issubdtype(fill_value.dtype, dtypes.prng_key): key_data = jnp.broadcast_to(random_unwrap(fill_value), physical_shape) else: key_data = lax.full(physical_shape, fill_value, dtype=np.dtype('uint32')) @@ -346,7 +369,7 @@ def full(shape, fill_value, dtype): @staticmethod def physical_element_aval(dtype) -> core.ShapedArray: - return core.ShapedArray(dtype._impl.key_shape, jnp.dtype('uint32')) + return core.ShapedArray(dtype._impl.key_shape, np.dtype('uint32')) @staticmethod def physical_const(val) -> Array: @@ -379,10 +402,10 @@ def local_sharded_result_handler(aval, sharding, indices): phys_handler = phys_handler_maker(phys_aval, phys_sharding, phys_indices) # set up a handler that calls the physical one and wraps back up - def handler(bufs): - return PRNGKeyArray(aval.dtype._impl, phys_handler(bufs)) + def handler(arr): + return PRNGKeyArray(aval.dtype._impl, arr) - return handler + return phys_handler.wrap(handler) @staticmethod def global_sharded_result_handler(aval, out_sharding, committed): @@ -392,8 +415,8 @@ def global_sharded_result_handler(aval, out_sharding, committed): phys_sharding = physical_sharding(aval, out_sharding) phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed) def handler(bufs): - return PRNGKeyArray(aval.dtype._impl, phys_handler(bufs)) - return handler + return PRNGKeyArray(aval.dtype._impl, bufs) + return phys_handler.wrap(handler) @staticmethod def make_sharded_array(aval, sharding, arrays, committed): @@ -415,7 +438,6 @@ def device_put_sharded(vals, aval, sharding, devices): @staticmethod def device_put_replicated(val, aval, sharding, devices): physical_aval = core.physical_aval(aval) - assert len(xla.aval_to_xla_shapes(physical_aval)) == 1 physical_buf = random_unwrap(val) phys_sharding = physical_sharding(aval, sharding) physical_result = pxla.batched_device_put( @@ -461,7 +483,7 @@ def __hash__(self) -> int: core.pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval -xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x +dtypes.canonicalize_value_handlers[PRNGKeyArray] = lambda x: x def key_array_shard_arg_handler(xs: Sequence[PRNGKeyArray], shardings, layouts, @@ -476,9 +498,9 @@ def key_array_shard_arg_handler(xs: Sequence[PRNGKeyArray], shardings, layouts, pxla.shard_arg_handlers[PRNGKeyArray] = key_array_shard_arg_handler -def key_array_constant_handler(x): +def key_array_constant_handler(x, aval): arr = x._base_array - return mlir.get_constant_handler(type(arr))(arr) + return mlir.get_constant_handler(type(arr))(arr, aval) mlir.register_constant_handler(PRNGKeyArray, key_array_constant_handler) @@ -542,7 +564,8 @@ def random_seed(seeds: int | typing.ArrayLike, impl: PRNGImpl) -> PRNGKeyArray: @random_seed_p.def_abstract_eval def random_seed_abstract_eval(seeds_aval, *, impl): - return keys_shaped_array(impl, seeds_aval.shape, seeds_aval.sharding) + return keys_shaped_array(impl, seeds_aval.shape, seeds_aval.sharding, + seeds_aval.vma) @random_seed_p.def_impl def random_seed_impl(seeds, *, impl): @@ -575,9 +598,13 @@ def random_split(keys, shape: Shape): def random_split_abstract_eval(keys_aval, *, shape): # TODO(yashkatariya): random_split should take sharding as an arg too so we # don't choose None here? - new_spec = (*keys_aval.sharding.spec, *[None] * len(shape)) + if keys_aval.sharding.mesh.empty: + out_sharding = core.get_cur_mesh_sharding() + else: + new_spec = (*keys_aval.sharding.spec, *[None] * len(shape)) + out_sharding = keys_aval.sharding.update(spec=new_spec) return keys_shaped_array(keys_aval.dtype._impl, (*keys_aval.shape, *shape), - keys_aval.sharding.with_spec(new_spec)) + out_sharding, keys_aval.vma) @random_split_p.def_impl def random_split_impl(keys, *, shape): @@ -603,7 +630,9 @@ def random_split_lowering(ctx, keys, *, shape): def random_fold_in(keys, msgs): - return random_fold_in_p.bind(keys, jnp.asarray(msgs)) + msgs = jnp.asarray(msgs) + keys, msgs = core.standard_insert_pvary(keys, msgs) + return random_fold_in_p.bind(keys, msgs) random_fold_in_p = core.Primitive('random_fold_in') ad.defjvp_zero(random_fold_in_p) @@ -611,9 +640,12 @@ def random_fold_in(keys, msgs): @random_fold_in_p.def_abstract_eval def random_fold_in_abstract_eval(keys_aval, msgs_aval): - shape = lax_internal.broadcasting_shape_rule( + shape = lax.broadcasting_shape_rule( + 'random_fold_in', keys_aval, msgs_aval) + sharding = lax.broadcasting_sharding_rule( 'random_fold_in', keys_aval, msgs_aval) - return core.ShapedArray(shape, keys_aval.dtype) + vma = core.standard_vma_rule('random_fold_in', keys_aval, msgs_aval) + return core.ShapedArray(shape, keys_aval.dtype, sharding=sharding, vma=vma) @random_fold_in_p.def_impl def random_fold_in_impl(keys, msgs): @@ -651,7 +683,14 @@ def random_bits(keys, bit_width, shape): def random_bits_abstract_eval(keys_aval, *, bit_width, shape): out_shape = (*keys_aval.shape, *shape) out_dtype = dtypes.dtype(f'uint{bit_width}') - return core.ShapedArray(out_shape, out_dtype) + # TODO(yashkatariya): random_bits should take an out_sharding argument. + if keys_aval.sharding.mesh.empty: + out_sharding = core.get_cur_mesh_sharding() + else: + new_spec = (*keys_aval.sharding.spec, *[None] * len(shape)) + out_sharding = keys_aval.sharding.update(spec=new_spec) + return core.ShapedArray(out_shape, out_dtype, sharding=out_sharding, + vma=keys_aval.vma) @random_bits_p.def_impl def random_bits_impl(keys, *, bit_width, shape): @@ -708,7 +747,7 @@ def random_wrap(base_arr, *, impl): def random_wrap_abstract_eval(base_arr_aval, *, impl): shape = base_arr_shape_to_keys_shape(impl, base_arr_aval.shape) sharding = logical_sharding(shape, KeyTy(impl), base_arr_aval.sharding) - return keys_shaped_array(impl, shape, sharding) + return keys_shaped_array(impl, shape, sharding, base_arr_aval.vma) @random_wrap_p.def_impl def random_wrap_impl(base_arr, *, impl): @@ -728,7 +767,7 @@ def random_wrap_batch_rule(batched_args, batch_dims, *, impl): def random_unwrap(keys): - if not jnp.issubdtype(keys.dtype, dtypes.prng_key): + if not dtypes.issubdtype(keys.dtype, dtypes.prng_key): raise TypeError(f'random_unwrap takes key array operand, got {keys.dtype=}') return random_unwrap_p.bind(keys) @@ -774,7 +813,7 @@ def threefry_seed(seed: typing.Array) -> typing.Array: """ return _threefry_seed(seed) -@partial(jit, inline=True) +@jit(inline=True) def _threefry_seed(seed: typing.Array) -> typing.Array: if seed.shape: raise TypeError(f"PRNG key seed must be a scalar; got {seed!r}.") @@ -782,7 +821,7 @@ def _threefry_seed(seed: typing.Array) -> typing.Array: raise TypeError(f"PRNG key seed must be an integer; got {seed!r}") convert = lambda k: lax.expand_dims(lax.convert_element_type(k, np.uint32), [0]) k1 = convert( - lax.shift_right_logical(seed, lax_internal._const(seed, 32))) + lax.shift_right_logical(seed, lax._const(seed, 32))) with config.numpy_dtype_promotion('standard'): # TODO(jakevdp): in X64 mode, this can generate 64-bit computations for 32-bit # inputs. We should avoid this. @@ -791,9 +830,9 @@ def _threefry_seed(seed: typing.Array) -> typing.Array: def _make_rotate_left(dtype): - if not jnp.issubdtype(dtype, np.integer): + if not dtypes.issubdtype(dtype, np.integer): raise TypeError("_rotate_left only accepts integer dtypes.") - nbits = np.array(jnp.iinfo(dtype).bits, dtype) + nbits = np.array(dtypes.iinfo(dtype).bits, dtype) def _rotate_left(x, d): if lax.dtype(d) != dtype: @@ -807,12 +846,12 @@ def _rotate_left(x, d): ### hash function and split def _threefry2x32_abstract_eval(*args): - if any(a.dtype != jnp.uint32 for a in args): + if any(a.dtype != np.uint32 for a in args): raise TypeError("Arguments to threefry2x32 must have uint32 type, got {}" .format(args)) if all(isinstance(arg, core.ShapedArray) for arg in args): - shape = lax_internal.broadcasting_shape_rule(*args) - aval = core.ShapedArray(shape, jnp.dtype(jnp.uint32)) + shape = lax.broadcasting_shape_rule(*args) + aval = core.ShapedArray(shape, np.dtype('uint32')) else: raise TypeError(f"Arguments to threefry2x32 must all be arrays, got {args}") return (aval,) * 2 @@ -861,7 +900,9 @@ def _threefry2x32_lowering(key1, key2, x1, x2, use_rolled_loops=True): x[1] = x[1] + ks[1] if use_rolled_loops: - x, _, _ = lax.fori_loop(0, 5, rolled_loop_step, (x, rotate_list(ks), rotations)) + x, _, _ = lax_control_flow.fori_loop( + 0, 5, rolled_loop_step, (x, rotate_list(ks), rotations) + ) else: for r in rotations[0]: @@ -893,16 +934,16 @@ def _threefry2x32_lowering(key1, key2, x1, x2, use_rolled_loops=True): # Since the unrolled lowering is large, emit it as an out-of-line function. -_threefry2x32_lowering_rule = mlir.cache_lowering(mlir.lower_fun( +_threefry2x32_lowering_rule = mlir.lower_fun( partial(_threefry2x32_lowering, use_rolled_loops=False), - multiple_results=True)) + multiple_results=True) _threefry2x32_cpu_lowering_rule = mlir.lower_fun( partial(_threefry2x32_lowering, use_rolled_loops=True), multiple_results=True) -def _threefry2x32_gpu_lowering_rule(lowering_func, ctx, k1, k2, x1, x2): +def _threefry2x32_gpu_lowering_rule(ctx, k1, k2, x1, x2, *, target_name_prefix): if not config.threefry_gpu_kernel_lowering.value: # back to default lowering return _threefry2x32_lowering_rule(ctx, k1, k2, x1, x2) @@ -917,23 +958,11 @@ def _broadcast(x, aval): return mlir.broadcast_in_dim(ctx, x, aval_out, broadcast_dimensions=range(rank - len(aval.shape), rank)) - out_len = reduce(op.mul, aval_out.shape, 1) - if not core.is_constant_dim(out_len): - length = mlir.eval_dynamic_shape_as_tensor(ctx, [out_len]) - length = mlir.hlo.convert( - ir.RankedTensorType.get((1,), ir.IntegerType.get_signless(64)), - length) - output_shape = mlir.eval_dynamic_shape_as_tensor(ctx, aval_out.shape) - else: - length = int(out_len) # will be passed statically - output_shape = None - - return lowering_func( - (_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)), - (_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)), length, - output_shape, - False, # forward_compatibility_mode - ) + sub_ctx = ctx.replace(avals_in=(aval_out,) * 4) + rule = ffi.ffi_lowering( + f"{target_name_prefix}_threefry2x32_ffi") + return rule(sub_ctx, _broadcast(k1, k1_aval), _broadcast(k2, k2_aval), + _broadcast(x1, x1_aval), _broadcast(x2, x2_aval)) threefry2x32_p = core.Primitive("threefry2x32") @@ -942,17 +971,19 @@ def _broadcast(x, aval): threefry2x32_p.def_abstract_eval(_threefry2x32_abstract_eval) batching.defbroadcasting(threefry2x32_p) mlir.register_lowering( - threefry2x32_p, _threefry2x32_lowering_rule) + threefry2x32_p, _threefry2x32_lowering_rule, inline=False) mlir.register_lowering( threefry2x32_p, _threefry2x32_cpu_lowering_rule, platform='cpu') mlir.register_lowering( threefry2x32_p, - partial(_threefry2x32_gpu_lowering_rule, gpu_prng.cuda_threefry2x32), - platform='cuda') + partial(_threefry2x32_gpu_lowering_rule, target_name_prefix='cu'), + platform='cuda', + inline=False) mlir.register_lowering( threefry2x32_p, - partial(_threefry2x32_gpu_lowering_rule, gpu_prng.rocm_threefry2x32), - platform='rocm') + partial(_threefry2x32_gpu_lowering_rule, target_name_prefix='hip'), + platform='rocm', + inline=False) def iota_2x32_shape(shape): @@ -1057,7 +1088,7 @@ def _mul(x: core.DimSize, y: ir.Value) -> ir.Value: mlir.register_lowering(iota_2x32_shape_p, iota_2x32_shape_lowering) -@partial(jit, inline=True) +@jit(inline=True) def threefry_2x32(keypair, count): """Apply the Threefry 2x32 hash. @@ -1086,11 +1117,11 @@ def threefry_2x32(keypair, count): flat_count_padded = jnp.concatenate([flat_count, np.uint32([0])]) flat_count_padded_half_size = flat_count_padded.shape[0] // 2 x = [ - lax.dynamic_slice(flat_count_padded, (0,), - (flat_count_padded_half_size,)), - lax.dynamic_slice(flat_count_padded, - (flat_count_padded_half_size,), - (flat_count_padded_half_size,)) + lax_slicing.dynamic_slice(flat_count_padded, (0,), + (flat_count_padded_half_size,)), + lax_slicing.dynamic_slice(flat_count_padded, + (flat_count_padded_half_size,), + (flat_count_padded_half_size,)) ] assert x[0].shape == x[1].shape, (x[0].shape, x[1].shape) @@ -1100,7 +1131,7 @@ def threefry_2x32(keypair, count): if core.is_constant_dim(odd_size): return lax.reshape(out[:-1] if odd_size else out, count.shape) else: - out_no_padding = lax.dynamic_slice(out, (0,), (flat_count.shape[0],)) + out_no_padding = lax_slicing.dynamic_slice(out, (0,), (flat_count.shape[0],)) return lax.reshape(out_no_padding, count.shape) @@ -1108,20 +1139,20 @@ def threefry_split(key: typing.Array, shape: Shape) -> typing.Array: shape = tuple(unsafe_map(core.concrete_dim_or_error, shape)) return _threefry_split(key, shape) -@partial(jit, static_argnums=(1,)) +@jit(static_argnums=(1,)) def _threefry_split(key, shape) -> typing.Array: if config.threefry_partitionable.value: return _threefry_split_foldlike(key, shape) else: return _threefry_split_original(key, shape) -@partial(jit, static_argnums=(1,), inline=True) +@jit(static_argnums=(1,), inline=True) def _threefry_split_original(key, shape) -> typing.Array: num = math.prod(shape) counts = lax.iota(np.uint32, num * 2) return lax.reshape(threefry_2x32(key, counts), (*shape, 2)) -@partial(jit, static_argnums=(1,), inline=True) +@jit(static_argnums=(1,), inline=True) def _threefry_split_foldlike(key, shape) -> typing.Array: k1, k2 = key counts1, counts2 = iota_2x32_shape(shape) @@ -1131,7 +1162,7 @@ def _threefry_split_foldlike(key, shape) -> typing.Array: def threefry_fold_in(key: typing.Array, data: typing.Array) -> typing.Array: assert not data.shape - return _threefry_fold_in(key, jnp.uint32(data)) + return _threefry_fold_in(key, jnp.asarray(data, dtype='uint32')) @jit def _threefry_fold_in(key, data): @@ -1162,13 +1193,13 @@ def _threefry_random_bits_partitionable(key: typing.Array, bit_width, shape): if bit_width == 64: bits_hi = lax.convert_element_type(bits1, dtype) bits_lo = lax.convert_element_type(bits2, dtype) - return lax.shift_left(bits_hi, dtype(32)) | bits_lo + return lax.shift_left(bits_hi, jnp.asarray(32, dtype=dtype)) | bits_lo elif bit_width == 32: return bits1 ^ bits2 else: return lax.convert_element_type(bits1 ^ bits2, dtype) -@partial(jit, static_argnums=(1, 2), inline=True) +@jit(static_argnums=(1, 2), inline=True) def _threefry_random_bits_original(key: typing.Array, bit_width, shape): size = math.prod(shape) # Compute ceil(bit_width * size / 32) in a way that is friendly to shape @@ -1178,7 +1209,7 @@ def _threefry_random_bits_original(key: typing.Array, bit_width, shape): max_count += 1 if core.is_constant_dim(max_count): - nblocks, rem = divmod(max_count, jnp.iinfo(np.uint32).max) + nblocks, rem = divmod(max_count, dtypes.iinfo(np.uint32).max) else: nblocks, rem = 0, max_count @@ -1187,18 +1218,18 @@ def _threefry_random_bits_original(key: typing.Array, bit_width, shape): else: keys = threefry_split(key, (nblocks + 1,)) subkeys, last_key = keys[:-1], keys[-1] - blocks = vmap(threefry_2x32, in_axes=(0, None))(subkeys, lax.iota(np.uint32, jnp.iinfo(np.uint32).max)) + blocks = vmap(threefry_2x32, in_axes=(0, None))(subkeys, lax.iota(np.uint32, dtypes.iinfo(np.uint32).max)) last = threefry_2x32(last_key, lax.iota(np.uint32, rem)) bits = lax.concatenate([blocks.ravel(), last], 0) dtype = UINT_DTYPES[bit_width] if bit_width == 64: bits = [lax.convert_element_type(x, dtype) for x in jnp.split(bits, 2)] - bits = lax.shift_left(bits[0], dtype(32)) | bits[1] + bits = lax.shift_left(bits[0], jnp.asarray(32, dtype=dtype)) | bits[1] elif bit_width in [8, 16]: # this is essentially bits.view(dtype)[:size] bits = lax.bitwise_and( - np.uint32(np.iinfo(dtype).max), + jnp.asarray(np.iinfo(dtype).max, dtype='uint32'), lax.shift_right_logical( lax.broadcast(bits, (1,)), lax.mul( @@ -1227,7 +1258,7 @@ def _threefry_random_bits_original(key: typing.Array, bit_width, shape): # -- RngBitGenerator PRNG implementation # This code is experimental! -# https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator +# https://www.openxla.org/xla/operation_semantics#rngbitgenerator # Notice that the RngBitGenerator operations are not guaranteed to be # stable/deterministic across backends or compiler versions. Correspondingly, we # reserve the right to change any of these implementations at any time! @@ -1253,7 +1284,7 @@ def _rbg_fold_in(key: typing.Array, data: typing.Array) -> typing.Array: def _rbg_random_bits(key: typing.Array, bit_width: int, shape: Sequence[int] ) -> typing.Array: - if not key.shape == (4,) and key.dtype == jnp.dtype('uint32'): + if not key.shape == (4,) and key.dtype == np.dtype('uint32'): raise TypeError("_rbg_random_bits got invalid prng key.") if bit_width not in (8, 16, 32, 64): raise TypeError("requires 8-, 16-, 32- or 64-bit field width.") @@ -1276,7 +1307,7 @@ def _unsafe_rbg_split(key: typing.Array, shape: Shape) -> typing.Array: # treat 10 iterations of random bits as a 'hash function' num = math.prod(shape) _, keys = lax.rng_bit_generator(key, (10 * num, 4), dtype='uint32') - return lax.slice_in_dim( + return lax_slicing.slice_in_dim( keys, start_index=None, limit_index=None, stride=10).reshape(*shape, 4) def _unsafe_rbg_fold_in(key: typing.Array, data: typing.Array) -> typing.Array: @@ -1294,3 +1325,20 @@ def _unsafe_rbg_fold_in(key: typing.Array, data: typing.Array) -> typing.Array: tag='urbg') register_prng(unsafe_rbg_prng_impl) + + +# Register export serialization for PRNG key types. +try: + from jax._src.export import serialization # pytype: disable=import-error + from jax._src.export import serialization_generated as ser_flatbuf # pytype: disable=import-error +except ImportError: + # This can happen if flatbuffers is not installed, in which case export + # serialization is not supported and it is safe to skip the registration. + pass +else: + serialization.register_dtype_kind( + KeyTy(prngs["threefry2x32"]), ser_flatbuf.DType.key_fry) + serialization.register_dtype_kind( + KeyTy(prngs["rbg"]), ser_flatbuf.DType.key_rbg) + serialization.register_dtype_kind( + KeyTy(prngs["unsafe_rbg"]), ser_flatbuf.DType.key_unsafe_rbg) diff --git a/jax/_src/profiler.py b/jax/_src/profiler.py index f06933f57e22..3fb357299298 100644 --- a/jax/_src/profiler.py +++ b/jax/_src/profiler.py @@ -32,14 +32,23 @@ traceback_util.register_exclusion(__file__) from jax._src import xla_bridge -from jax._src.lib import xla_client +from jax._src.lib import _profiler +from jax._src.lib import _profile_data -_profiler_server: xla_client.profiler.ProfilerServer | None = None +ProfileData = _profile_data.ProfileData +ProfileEvent = _profile_data.ProfileEvent +ProfilePlane = _profile_data.ProfilePlane + +_profiler_server: _profiler.ProfilerServer | None = None logger = logging.getLogger(__name__) -def start_server(port: int) -> xla_client.profiler.ProfilerServer: +class ProfileOptions(_profiler.ProfileOptions): + """Profiler Options to configure the collectors for the profiler.""" + + +def start_server(port: int) -> _profiler.ProfilerServer: """Starts the profiler server on port `port`. Using the "TensorFlow profiler" feature in `TensorBoard @@ -59,7 +68,7 @@ def start_server(port: int) -> xla_client.profiler.ProfilerServer: # is for start_trace), but I'm putting it here to be safe. xla_bridge.get_backend() - _profiler_server = xla_client.profiler.start_server(port) + _profiler_server = _profiler.start_server(port) return _profiler_server @@ -89,12 +98,17 @@ def reset(self): _profile_state = _ProfileState() -def start_trace(log_dir: os.PathLike | str, create_perfetto_link: bool = False, - create_perfetto_trace: bool = False) -> None: +def start_trace( + log_dir: os.PathLike | str, + create_perfetto_link: bool = False, + create_perfetto_trace: bool = False, + profiler_options: ProfileOptions | None = None, +) -> None: """Starts a profiler trace. The trace will capture CPU, GPU, and/or TPU activity, including Python - functions and JAX on-device operations. Use :func:`stop_trace` to end the trace + functions and JAX on-device operations. Use :func:`stop_trace` to end the + trace and save the results to ``log_dir``. The resulting trace can be viewed with TensorBoard. Note that TensorBoard @@ -113,8 +127,8 @@ def start_trace(log_dir: os.PathLike | str, create_perfetto_link: bool = False, ``perfetto_trace.json.gz`` file that is compatible for upload with the Perfetto trace viewer UI (https://ui.perfetto.dev). The file will also be generated if ``create_perfetto_link`` is true. This could be useful if you - want to generate a Perfetto-compatible trace without blocking the - process. + want to generate a Perfetto-compatible trace without blocking the process. + profiler_options: Profiler options to configure the profiler for collection. """ with _profile_state.lock: if _profile_state.profile_session is not None: @@ -126,7 +140,12 @@ def start_trace(log_dir: os.PathLike | str, create_perfetto_link: bool = False, # fail and no TPU operations will be included in the profile. xla_bridge.get_backend() - _profile_state.profile_session = xla_client.profiler.ProfilerSession() + if profiler_options is None: + _profile_state.profile_session = _profiler.ProfilerSession() + else: + _profile_state.profile_session = _profiler.ProfilerSession( + profiler_options + ) _profile_state.create_perfetto_link = create_perfetto_link _profile_state.create_perfetto_trace = ( create_perfetto_trace or create_perfetto_link) @@ -201,7 +220,7 @@ def stop_trace(): if _profile_state.profile_session is None: raise RuntimeError("No profile started") sess = _profile_state.profile_session - sess.export(sess.stop(), str(_profile_state.log_dir)) + sess.stop_and_export(str(_profile_state.log_dir)) # type: ignore if _profile_state.create_perfetto_trace: abs_filename = _write_perfetto_trace_file(_profile_state.log_dir) if _profile_state.create_perfetto_link: @@ -219,13 +238,18 @@ def stop_and_get_fdo_profile() -> bytes | str: if _profile_state.profile_session is None: raise RuntimeError("No profile started") xspace = _profile_state.profile_session.stop() - fdo_profile = xla_client.profiler.get_fdo_profile(xspace) + fdo_profile = _profiler.get_fdo_profile(xspace) _profile_state.reset() return fdo_profile @contextmanager -def trace(log_dir: os.PathLike | str, create_perfetto_link=False, create_perfetto_trace=False): +def trace( + log_dir: os.PathLike | str, + create_perfetto_link=False, + create_perfetto_trace=False, + profiler_options: ProfileOptions | None = None, +): """Context manager to take a profiler trace. The trace will capture CPU, GPU, and/or TPU activity, including Python @@ -247,17 +271,19 @@ def trace(log_dir: os.PathLike | str, create_perfetto_link=False, create_perfett ``perfetto_trace.json.gz`` file that is compatible for upload with the Perfetto trace viewer UI (https://ui.perfetto.dev). The file will also be generated if ``create_perfetto_link`` is true. This could be useful if you - want to generate a Perfetto-compatible trace without blocking the - process. + want to generate a Perfetto-compatible trace without blocking the process. + profiler_options: Profiler options to configure the profiler for collection. """ - start_trace(log_dir, create_perfetto_link, create_perfetto_trace) + start_trace( + log_dir, create_perfetto_link, create_perfetto_trace, profiler_options + ) try: yield finally: stop_trace() -class TraceAnnotation(xla_client.profiler.TraceMe): +class TraceAnnotation(_profiler.TraceMe): """Context manager that generates a trace event in the profiler. The trace event spans the duration of the code enclosed by the context. @@ -271,7 +297,6 @@ class TraceAnnotation(xla_client.profiler.TraceMe): This will cause a "my_label" event to show up on the trace timeline if the event occurs while the process is being traced. """ - pass class StepTraceAnnotation(TraceAnnotation): @@ -332,7 +357,6 @@ def annotate_function(func: Callable, name: str | None = None, def wrapper(*args, **kwargs): with TraceAnnotation(name, **decorator_kwargs): return func(*args, **kwargs) - return wrapper return wrapper @@ -361,7 +385,8 @@ def device_memory_profile(backend: str | None = None) -> bytes: Returns: A byte string containing a binary `pprof`-format protocol buffer. """ - return xla_client.heap_profile(xla_bridge.get_backend(backend)) + client = xla_bridge.get_backend(backend) + return gzip.compress(client.heap_profile()) def save_device_memory_profile(filename, backend: str | None = None) -> None: @@ -382,7 +407,7 @@ def save_device_memory_profile(filename, backend: str | None = None) -> None: # Allows to run model with profiler given amount of times. After required amount -# of retries achived client can collect FDO data. +# of retries achieved client can collect FDO data. class PGLEProfiler: def __init__(self, retries: int, percentile: int): @@ -391,7 +416,7 @@ def __init__(self, retries: int, percentile: int): self.collected_fdo: str | None = None self.called_times: int = 0 self.fdo_profiles: list[Any] = [] - self.current_session: xla_client.profiler.ProfilerSession | None = None + self.current_session: _profiler.ProfilerSession | None = None def consume_fdo_profile(self) -> str | None: if self.collected_fdo is not None: @@ -400,7 +425,7 @@ def consume_fdo_profile(self) -> str | None: if not self.is_enabled() or self.called_times != self.retries: return None - self.collected_fdo = xla_client.profiler.aggregate_profiled_instructions( + self.collected_fdo = _profiler.aggregate_profiled_instructions( self.fdo_profiles, self.percentile ) return self.collected_fdo @@ -424,16 +449,17 @@ def trace(cls, runner: PGLEProfiler | None): or not runner.is_enabled() or runner.is_fdo_consumed()): yield else: - options = xla_client.profiler.ProfileOptions() + options = _profiler.ProfileOptions() options.enable_hlo_proto = True - runner.current_session = xla_client.profiler.ProfilerSession(options) + options.raise_error_on_start_failure = True + runner.current_session = _profiler.ProfilerSession(options) try: yield finally: xspace = runner.current_session.stop() runner.fdo_profiles.append( - xla_client.profiler.get_fdo_profile(xspace) + _profiler.get_fdo_profile(xspace) ) runner.current_session = None diff --git a/jax/_src/public_test_util.py b/jax/_src/public_test_util.py index 455a3b98cce2..5060424765c6 100644 --- a/jax/_src/public_test_util.py +++ b/jax/_src/public_test_util.py @@ -14,6 +14,7 @@ from functools import partial import operator +from typing import Any, TypeAlias from jax._src import api from jax._src import config @@ -32,28 +33,35 @@ EPS = 1e-4 -def _dtype(x): +def _dtype(x: Any) -> np.dtype: if hasattr(x, 'dtype'): return x.dtype - elif type(x) in _dtypes.python_scalar_dtypes: - return np.dtype(_dtypes.python_scalar_dtypes[type(x)]) + elif (dt := _dtypes.python_scalar_types_to_dtypes.get(type(x))) is not None: + return dt else: return np.asarray(x).dtype +ToleranceDict: TypeAlias = dict[np.dtype, int | float] -_default_tolerance = { +_default_tolerance: ToleranceDict = { _dtypes.float0: 0, np.dtype(np.bool_): 0, + np.dtype(_dtypes.int2): 0, np.dtype(_dtypes.int4): 0, np.dtype(np.int8): 0, np.dtype(np.int16): 0, np.dtype(np.int32): 0, np.dtype(np.int64): 0, + np.dtype(_dtypes.uint2): 0, np.dtype(_dtypes.uint4): 0, np.dtype(np.uint8): 0, np.dtype(np.uint16): 0, np.dtype(np.uint32): 0, np.dtype(np.uint64): 0, + np.dtype(_dtypes.float4_e2m1fn): 1e0, + np.dtype(_dtypes.float8_e3m4): 1e-1, + np.dtype(_dtypes.float8_e4m3): 1e-1, + np.dtype(_dtypes.float8_e8m0fnu): 1e0, np.dtype(_dtypes.float8_e4m3b11fnuz): 1e-1, np.dtype(_dtypes.float8_e4m3fn): 1e-1, np.dtype(_dtypes.float8_e4m3fnuz): 1e-1, @@ -67,16 +75,15 @@ def _dtype(x): np.dtype(np.complex128): 1e-15, } -if _dtypes.int2 is not None: - assert _dtypes.uint2 is not None - _default_tolerance[np.dtype(_dtypes.int2)] = 0 - _default_tolerance[np.dtype(_dtypes.uint2)] = 0 - def default_tolerance(): return _default_tolerance -default_gradient_tolerance = { +default_gradient_tolerance: ToleranceDict = { + np.dtype(_dtypes.float4_e2m1fn): 1e0, + np.dtype(_dtypes.float8_e3m4): 1e-1, + np.dtype(_dtypes.float8_e4m3): 1e-1, + np.dtype(_dtypes.float8_e8m0fnu): 1e0, np.dtype(_dtypes.float8_e4m3b11fnuz): 1e-1, np.dtype(_dtypes.float8_e4m3fn): 1e-1, np.dtype(_dtypes.float8_e4m3fnuz): 1e-1, @@ -90,21 +97,8 @@ def default_tolerance(): np.dtype(np.complex128): 1e-5, } -# TODO: make this unconditional when ml_dtypes>=0.5.0 is required -if _dtypes.float8_e3m4 is not None: - _default_tolerance[np.dtype(_dtypes.float8_e3m4)] = 1e-1 - default_gradient_tolerance[np.dtype(_dtypes.float8_e3m4)] = 1e-1 -if _dtypes.float8_e4m3 is not None: - _default_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1 - default_gradient_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1 -if _dtypes.float8_e8m0fnu is not None: - _default_tolerance[np.dtype(_dtypes.float8_e8m0fnu)] = 1e0 - default_gradient_tolerance[np.dtype(_dtypes.float8_e8m0fnu)] = 1e0 -if _dtypes.float4_e2m1fn is not None: - _default_tolerance[np.dtype(_dtypes.float4_e2m1fn)] = 1e0 - default_gradient_tolerance[np.dtype(_dtypes.float4_e2m1fn)] = 1e0 - -def is_python_scalar(val): + +def is_python_scalar(val: Any) -> bool: return not isinstance(val, np.generic) and isinstance(val, (bool, int, float, complex)) def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''): @@ -113,6 +107,10 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''): return custom_float_dtypes = [ + _dtypes.float4_e2m1fn, + _dtypes.float8_e8m0fnu, + _dtypes.float8_e3m4, + _dtypes.float8_e4m3, _dtypes.float8_e4m3b11fnuz, _dtypes.float8_e4m3fn, _dtypes.float8_e4m3fnuz, @@ -121,15 +119,6 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''): _dtypes.bfloat16, ] - if _dtypes.float8_e4m3 is not None: - custom_float_dtypes.insert(0, _dtypes.float8_e4m3) - if _dtypes.float8_e3m4 is not None: - custom_float_dtypes.insert(0, _dtypes.float8_e3m4) - if _dtypes.float8_e8m0fnu is not None: - custom_float_dtypes.insert(0, _dtypes.float8_e8m0fnu) - if _dtypes.float4_e2m1fn is not None: - custom_float_dtypes.insert(0, _dtypes.float4_e2m1fn) - def maybe_upcast(x): if x.dtype in custom_float_dtypes: return x.astype(np.float32) @@ -151,7 +140,8 @@ def maybe_upcast(x): # value errors. It should not do that. np.testing.assert_allclose(a, b, **kw, err_msg=err_msg) -def tolerance(dtype, tol=None): + +def tolerance(dtype: np.dtype, tol: int | float | ToleranceDict | None = None) -> int | float: tol = {} if tol is None else tol if not isinstance(tol, dict): return tol @@ -369,4 +359,15 @@ def f_vjp(*args): return vjp_py(out_primal_py) _check_grads(f_vjp, args, order - 1, rev_msg) + if "lin" in modes: + lin_msg = f'LIN of {err_msg}' if err_msg else 'LIN' + _check_jvp(f, partial(_jvp_from_lin, f), args, err_msg=lin_msg) + if order > 1: + _check_grads(partial(_jvp_from_lin, f), (args, args), order - 1, lin_msg) + _check_grads(f, args, order) + +def _jvp_from_lin(f, primals, tangents): + primal_out, f_lin = api.linearize(f, *primals) + tangent_out = f_lin(*tangents) + return primal_out, tangent_out diff --git a/jax/_src/random.py b/jax/_src/random.py index 094268c65825..5da171c269af 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -14,7 +14,7 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Hashable, Sequence from functools import partial import math from operator import index @@ -24,24 +24,28 @@ import numpy as np -import jax.numpy as jnp -from jax import lax -from jax.numpy.linalg import cholesky, svd, eigh - from jax._src import config from jax._src import core from jax._src import dispatch from jax._src import dtypes +from jax._src import numpy as jnp from jax._src import prng from jax._src import xla_bridge +from jax._src.mesh import get_abstract_mesh +from jax._src.sharding_impls import NamedSharding, PartitionSpec as P from jax._src.api import jit, vmap from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir -from jax._src.lax import lax as lax_internal -from jax._src.numpy.lax_numpy import _convert_and_clip_integer +from jax._src.lax import control_flow as lax_control_flow +from jax._src.lax import lax +from jax._src.lax import special as lax_special +from jax._src.numpy import einsum as jnp_einsum +from jax._src.numpy import linalg as jnp_linalg from jax._src.numpy.util import _arraylike, check_arraylike, promote_dtypes_inexact -from jax._src.typing import Array, ArrayLike, DTypeLike +from jax._src.pjit import auto_axes +from jax._src.sharding_impls import canonicalize_sharding +from jax._src.typing import Array, ArrayLike, DType, DTypeLike from jax._src.util import canonicalize_axis @@ -61,8 +65,6 @@ ### utilities -_lax_const = lax_internal._const - def _isnan(x: ArrayLike) -> Array: return lax.ne(x, x) @@ -76,12 +78,12 @@ def _check_prng_key(name: str, key: ArrayLike, *, # Call random_wrap here to surface errors for invalid keys. wrapped_key = prng.random_wrap(key, impl=default_prng_impl()) wrapped = True - if config.legacy_prng_key.value == 'error': + if config.legacy_prng_key.value == config.LegacyPrngKeyState.ERROR: raise ValueError( 'Legacy uint32 key array passed as key to jax.random function. ' 'Please create keys using jax.random.key(). If use of a raw key array ' 'was intended, set jax_legacy_prng_key="allow".') - elif config.legacy_prng_key.value == 'warn': + elif config.legacy_prng_key.value == config.LegacyPrngKeyState.WARN: warnings.warn( 'Legacy uint32 key array passed as key to jax.random function. ' 'Please create keys using jax.random.key(). If use of a raw key array ' @@ -104,7 +106,7 @@ def _check_prng_key(name: str, key: ArrayLike, *, def _return_prng_keys(was_wrapped, key): # TODO(frostig): remove once we always enable_custom_prng - assert jnp.issubdtype(key.dtype, dtypes.prng_key) + assert dtypes.issubdtype(key.dtype, dtypes.prng_key) if config.enable_custom_prng.value: return key else: @@ -112,7 +114,7 @@ def _return_prng_keys(was_wrapped, key): def _random_bits(key: Array, bit_width: int, shape: Shape) -> Array: - assert jnp.issubdtype(key.dtype, dtypes.prng_key) + assert dtypes.issubdtype(key.dtype, dtypes.prng_key) return prng.random_bits(key, bit_width=bit_width, shape=shape) @@ -159,7 +161,7 @@ def __eq__(self, other) -> bool: # TODO(frostig,vanderplas): remove PRNGImpl from this union when it's # no longer in the public API because `default_prng_impl` is gone -PRNGSpecDesc = Union[str, PRNGSpec, PRNGImpl] +PRNGSpecDesc = Union[str, PRNGSpec, PRNGImpl, Hashable] def resolve_prng_impl(impl_spec: PRNGSpecDesc | None) -> PRNGImpl: @@ -188,7 +190,7 @@ def resolve_prng_impl(impl_spec: PRNGSpecDesc | None) -> PRNGImpl: def _key(ctor_name: str, seed: int | ArrayLike, impl_spec: PRNGSpecDesc | None) -> Array: impl = resolve_prng_impl(impl_spec) - if hasattr(seed, 'dtype') and jnp.issubdtype(seed.dtype, dtypes.prng_key): + if hasattr(seed, 'dtype') and dtypes.issubdtype(seed.dtype, dtypes.prng_key): raise TypeError( f"{ctor_name} accepts a scalar seed, but was given a PRNG key.") if np.ndim(seed): @@ -223,7 +225,7 @@ def PRNGKey(seed: int | ArrayLike, *, This function produces old-style legacy PRNG keys, which are arrays of dtype ``uint32``. For more, see the note in the `PRNG keys - `_ + `_ section. When possible, :func:`jax.random.key` is recommended for use instead. @@ -261,7 +263,7 @@ def fold_in(key: ArrayLike, data: IntegerArray) -> Array: if np.ndim(data): raise TypeError("fold_in accepts a scalar, but was given an array of" f"shape {np.shape(data)} != (). Use jax.vmap for batching.") - key_out = prng.random_fold_in(key, jnp.uint32(data)) + key_out = prng.random_fold_in(key, jnp.asarray(data, dtype='uint32')) return _return_prng_keys(wrapped, key_out) @@ -269,7 +271,7 @@ def _split(key: Array, num: int | tuple[int, ...] = 2) -> Array: # Alternative to split() to use within random samplers. # TODO(frostig): remove and use split(); we no longer need to wait # to always enable_custom_prng - assert jnp.issubdtype(key.dtype, dtypes.prng_key) + assert dtypes.issubdtype(key.dtype, dtypes.prng_key) if key.ndim: raise TypeError("split accepts a single key, but was given a key array of " f"shape {key.shape} != (). Use jax.vmap for batching.") @@ -292,7 +294,7 @@ def split(key: ArrayLike, num: int | tuple[int, ...] = 2) -> Array: def _key_impl(keys: Array) -> PRNGImpl: - assert jnp.issubdtype(keys.dtype, dtypes.prng_key) + assert dtypes.issubdtype(keys.dtype, dtypes.prng_key) keys_dtype = typing.cast(prng.KeyTy, keys.dtype) return keys_dtype._impl @@ -306,7 +308,7 @@ def key_impl(keys: ArrayLike) -> str | PRNGSpec: def _key_data(keys: Array) -> Array: - assert jnp.issubdtype(keys.dtype, dtypes.prng_key) + assert dtypes.issubdtype(keys.dtype, dtypes.prng_key) return prng.random_unwrap(keys) def key_data(keys: ArrayLike) -> Array: @@ -346,9 +348,20 @@ def _check_shape(name: str, shape: Shape, *param_shapes) -> None: raise ValueError(msg.format(name, shape_, shape)) +def maybe_auto_axes(f, out_sharding, **hoist_kwargs): + f_ = partial(f, **hoist_kwargs) + if out_sharding is None: + return f_ + else: + return auto_axes(f_, out_sharding=out_sharding, + axes=out_sharding.mesh.explicit_axes) + + def bits(key: ArrayLike, shape: Shape = (), - dtype: DTypeLikeUInt | None = None) -> Array: + dtype: DTypeLikeUInt | None = None, + *, + out_sharding=None) -> Array: """Sample uniform bits in the form of unsigned integers. Args: @@ -363,23 +376,35 @@ def bits(key: ArrayLike, """ key, _ = _check_prng_key("bits", key) if dtype is None: - dtype = dtypes.canonicalize_dtype(jnp.uint) + dtype = dtypes.default_uint_dtype() else: - dtypes.check_user_dtype_supported(dtype) + dtype = dtypes.check_and_canonicalize_user_dtype(dtype) if not dtypes.issubdtype(dtype, np.unsignedinteger): raise ValueError("dtype argument to `bits` must be an unsigned int dtype, " f"got {dtype}") - dtype = dtypes.canonicalize_dtype(dtype) shape = core.canonicalize_shape(shape) + out_sharding = canonicalize_sharding_for_samplers(out_sharding, "bits", shape) bit_width = dtype.itemsize * 8 - return _random_bits(key, bit_width, shape) + return maybe_auto_axes(_random_bits, out_sharding, + bit_width=bit_width, shape=shape)(key) + + +def canonicalize_sharding_for_samplers(out_sharding, name, shape): + out_sharding = canonicalize_sharding(out_sharding, name) + cur_mesh = get_abstract_mesh() + if cur_mesh.are_all_axes_explicit and out_sharding is None and not shape: + # when shape is empty i.e. scalar, we can choose a replicated sharding. + out_sharding = NamedSharding(cur_mesh, P()) + return out_sharding def uniform(key: ArrayLike, shape: Shape = (), - dtype: DTypeLikeFloat = float, + dtype: DTypeLikeFloat | None = None, minval: RealArray = 0., - maxval: RealArray = 1.) -> Array: + maxval: RealArray = 1., + *, + out_sharding=None) -> Array: """Sample uniform random values in [minval, maxval) with given shape/dtype. Args: @@ -395,19 +420,21 @@ def uniform(key: ArrayLike, A random array with the specified shape and dtype. """ key, _ = _check_prng_key("uniform", key) - dtypes.check_user_dtype_supported(dtype) + dtype = dtypes.check_and_canonicalize_user_dtype( + float if dtype is None else dtype) shape = core.canonicalize_shape(shape) + out_sharding = canonicalize_sharding_for_samplers(out_sharding, "uniform", shape) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `uniform` must be a float dtype, " f"got {dtype}") - dtype = dtypes.canonicalize_dtype(dtype) - return _uniform(key, shape, dtype, minval, maxval) + return maybe_auto_axes(_uniform, out_sharding, + shape=shape, dtype=dtype)(key, minval, maxval) -@partial(jit, static_argnums=(1, 2)) -def _uniform(key, shape, dtype, minval, maxval) -> Array: +@jit(static_argnums=(3, 4)) +def _uniform(key, minval, maxval, shape, dtype) -> Array: _check_shape("uniform", shape) - if not jnp.issubdtype(dtype, np.floating): + if not dtypes.issubdtype(dtype, np.floating): raise TypeError("uniform only accepts floating point dtypes.") minval = lax.convert_element_type(minval, dtype) @@ -415,7 +442,7 @@ def _uniform(key, shape, dtype, minval, maxval) -> Array: minval = lax.broadcast_to_rank(minval, len(shape)) maxval = lax.broadcast_to_rank(maxval, len(shape)) - finfo = jnp.finfo(dtype) + finfo = dtypes.finfo(dtype) nbits, nmant = finfo.bits, finfo.nmant if nbits not in (8, 16, 32, 64): @@ -435,21 +462,64 @@ def _uniform(key, shape, dtype, minval, maxval) -> Array: # 1 (after applying the bias), then shift and scale to the desired range. The # bit-level transformation we use relies on Numpy and XLA having bit-for-bit # equivalent float representations, which might not be true on all platforms. + float_bits = lax.shift_right_logical( + bits, jnp.array(rng_bits - nmant, uint_dtype)) float_bits = lax.bitwise_or( - lax.shift_right_logical(bits, np.array(rng_bits - nmant, uint_dtype)), - np.array(1.0, dtype).view(uint_dtype), - ) - floats = lax.bitcast_convert_type(float_bits, dtype) - np.array(1., dtype) + float_bits, + # The double cast is because the TPU backend does not implement `view` on + # float64 values => do the `view` in NumPy first, but then ensure that + # we have a JAX array that won't be canonicalized further. + jnp.asarray(np.array(1.0, dtype).view(float_bits.dtype), + dtype=float_bits.dtype)) + floats = lax.bitcast_convert_type(float_bits, dtype) - jnp.array(1., dtype) return lax.max( minval, lax.reshape(floats * (maxval - minval) + minval, shape)) +def _convert_and_clip_integer(val: Array, dtype: DType) -> Array: + """ + Convert integer-typed val to specified integer dtype, clipping to dtype + range rather than wrapping. + + Args: + val: value to be converted + dtype: dtype of output + + Returns: + equivalent of val in new dtype + + Examples + -------- + Normal integer type conversion will wrap: + + >>> val = jnp.uint32(0xFFFFFFFF) + >>> val.astype('int32') + Array(-1, dtype=int32) + + This function clips to the values representable in the new type: + + >>> _convert_and_clip_integer(val, 'int32') + Array(2147483647, dtype=int32) + """ + assert isinstance(val, Array) + if not (dtypes.issubdtype(dtype, np.integer) and dtypes.issubdtype(val.dtype, np.integer)): + raise TypeError("_convert_and_clip_integer only accepts integer dtypes.") + + min_val = lax._const(val, max(dtypes.iinfo(dtype).min, + dtypes.iinfo(val.dtype).min)) + max_val = lax._const(val, min(dtypes.iinfo(dtype).max, + dtypes.iinfo(val.dtype).max)) + return jnp.clip(val, min_val, max_val).astype(dtype) + + def randint(key: ArrayLike, shape: Shape, minval: IntegerArray, maxval: IntegerArray, - dtype: DTypeLikeInt = int) -> Array: + dtype: DTypeLikeInt | None = None, + *, + out_sharding=None) -> Array: """Sample uniform random values in [minval, maxval) with given shape/dtype. Args: @@ -464,37 +534,90 @@ def randint(key: ArrayLike, Returns: A random array with the specified shape and dtype. + + .. note:: + + :func:`randint` uses a modulus-based computation that is known to produce + slightly biased values in some cases. The magnitude of the bias scales as + ``(maxval - minval) * ((2 ** nbits ) % (maxval - minval)) / 2 ** nbits``: + in words, the bias goes to zero when ``(maxval - minval)`` is a power of 2, + and otherwise the bias will be small whenever ``(maxval - minval)`` is + small compared to the range of the sampled type. + + To reduce this bias, 8-bit and 16-bit values will always be sampled at 32-bit and + then cast to the requested type. If you find yourself sampling values for which + this bias may be problematic, a possible alternative is to sample via uniform:: + + def randint_via_uniform(key, shape, minval, maxval, dtype): + u = jax.random.uniform(key, shape, minval=minval - 0.5, maxval=maxval - 0.5) + return u.round().astype(dtype) + + But keep in mind this method has its own biases due to floating point rounding + errors, and in particular there may be some integers in the range + ``[minval, maxval)`` that are impossible to produce with this approach. """ key, _ = _check_prng_key("randint", key) - dtypes.check_user_dtype_supported(dtype) - dtype = dtypes.canonicalize_dtype(dtype) + dtype = dtypes.check_and_canonicalize_user_dtype( + int if dtype is None else dtype) shape = core.canonicalize_shape(shape) - return _randint(key, shape, minval, maxval, dtype) + out_sharding = canonicalize_sharding_for_samplers(out_sharding, "randint", shape) + + if not dtypes.issubdtype(dtype, np.integer): + raise TypeError(f"randint only accepts integer dtypes, got {dtype}") -@partial(jit, static_argnums=(1, 4)) -def _randint(key, shape, minval, maxval, dtype) -> Array: + info = dtypes.iinfo(dtype) + dtype_for_sampling = dtype + if info.bits < 32: + # Sample in 32 bits to avoid biased results. + dtype_for_sampling = np.dtype('int32') + minval = jnp.asarray(minval).astype('int32').clip(int(info.min), int(info.max)) + maxval = jnp.asarray(maxval).astype('int32').clip(int(info.min), int(info.max) + 1) + + return maybe_auto_axes(_randint, out_sharding, shape=shape, dtype=dtype_for_sampling)( + key, minval, maxval).astype(dtype) + + +@jit(static_argnums=(3, 4)) +def _randint(key, minval, maxval, shape, dtype) -> Array: + # We have three imperfect options for generating random integers in an arbitrary + # user-specified range: + # + # 1. Rejection sampling. This produces unbiased results, but involves a dynamic + # number of iterations, so it's not suitable for computation on accelerators. + # 2. Generate floating point values between minval and maxval, and cast to int. + # This introduces bias for large ranges due to floating point rounding error: + # many integers within a given range would never be sampled. + # 3. Generate numbers in a range that is a power of 2, and use an integer modulus + # to shift them into the desired range. This produces a biased distribution + # when the desired range is not a power of 2, which scales as + # O[bias] ~= (desired_range ** 2) / full_range. + # + # Given these three imperfect options, we opt for a modified version of (3), where we + # sample 2 * nbits bits per value, because it is efficient and works well in most cases + # of interest. To help users avoid inadvertently producing biased results, we always + # generate samples in at least 32 bits. _check_shape("randint", shape, np.shape(minval), np.shape(maxval)) - if not jnp.issubdtype(dtype, np.integer): + if not dtypes.issubdtype(dtype, np.integer): raise TypeError(f"randint only accepts integer dtypes, got {dtype}") check_arraylike("randint", minval, maxval) minval = jnp.asarray(minval) maxval = jnp.asarray(maxval) - if not jnp.issubdtype(minval.dtype, np.integer): + if not dtypes.issubdtype(minval.dtype, np.integer): minval = minval.astype(int) - if not jnp.issubdtype(maxval.dtype, np.integer): + if not dtypes.issubdtype(maxval.dtype, np.integer): maxval = maxval.astype(int) # Flag where maxval is greater than the maximum value of dtype # in order to handle cases like randint(key, shape, 0, 256, 'uint8') maxval_out_of_range = lax.gt( - maxval, _convert_and_clip_integer(jnp.array(jnp.iinfo(dtype).max, dtype), maxval.dtype)) + maxval, _convert_and_clip_integer(jnp.array(dtypes.iinfo(dtype).max, dtype), maxval.dtype)) minval = _convert_and_clip_integer(minval, dtype) maxval = _convert_and_clip_integer(maxval, dtype) minval = lax.broadcast_to_rank(minval, len(shape)) maxval = lax.broadcast_to_rank(maxval, len(shape)) - nbits = jnp.iinfo(dtype).bits + nbits = dtypes.iinfo(dtype).bits if nbits not in (8, 16, 32, 64): raise TypeError(f"randint only accepts 8-, 16-, 32-, or 64-bit dtypes, got {dtype}") @@ -518,14 +641,14 @@ def _randint(key, shape, minval, maxval, dtype) -> Array: # causing remainders below to have no effect, which is the correct semantics. span = lax.select( maxval_out_of_range & (maxval > minval), - lax.add(span, _lax_const(span, 1)), + lax.add(span, lax._const(span, 1)), span) # To compute a remainder operation on an integer that might have twice as many # bits as we can represent in the native unsigned dtype, we compute a # multiplier equal to 2**nbits % span. To avoid overflow, we use the identity: # (a * b) % N = [(a % N) * (b % N)] % N - multiplier = lax.rem(_lax_const(span, 2 ** (nbits // 2)), span) + multiplier = lax.rem(lax._const(span, 2 ** (nbits // 2)), span) multiplier = lax.rem(lax.mul(multiplier, multiplier), span) random_offset = lax.add(lax.mul(lax.rem(higher_bits, span), multiplier), @@ -537,7 +660,9 @@ def _randint(key, shape, minval, maxval, dtype) -> Array: def permutation(key: ArrayLike, x: int | ArrayLike, axis: int = 0, - independent: bool = False) -> Array: + independent: bool = False, + *, + out_sharding=None) -> Array: """Returns a randomly permuted array or range. Args: @@ -554,18 +679,24 @@ def permutation(key: ArrayLike, key, _ = _check_prng_key("permutation", key) check_arraylike("permutation", x) axis = canonicalize_axis(axis, np.ndim(x) or 1) + out_sharding = canonicalize_sharding(out_sharding, "permutation") if not np.ndim(x): if not np.issubdtype(lax.dtype(x), np.integer): raise TypeError("x must be an integer or at least 1-dimensional") - r = core.concrete_or_error(int, x, 'argument x of jax.random.permutation()') - return _shuffle(key, jnp.arange(r), axis) + r = core.concrete_or_error(int, x, "argument x of jax.random.permutation()") + return maybe_auto_axes(lambda key: _shuffle(key, jnp.arange(r), axis), + out_sharding)(key) + return maybe_auto_axes( + _permutation, out_sharding, axis=axis, independent=independent)(key, x) + +def _permutation(key, x, axis, independent): if independent or np.ndim(x) == 1: return _shuffle(key, x, axis) ind = _shuffle(key, jnp.arange(x.shape[axis]), 0) # type: ignore[union-attr] return jnp.take(x, ind, axis, unique_indices=True) -@partial(jit, static_argnums=(2,)) +@jit(static_argnums=(2,)) def _shuffle(key, x, axis) -> Array: # On parallel architectures, Fisher-Yates is more expensive than doing # multiple sorts. This algorithm is based on one developed and analyzed by @@ -582,7 +713,7 @@ def _shuffle(key, x, axis) -> Array: # Section 2 of http://people.csail.mit.edu/costis/6896sp11/lec5s.pdf for # another analysis (where the keys are generated one bit at a time). exponent = 3 # see tjablin@'s analysis for explanation of this parameter - uint32max = jnp.iinfo(np.uint32).max + uint32max = dtypes.iinfo(np.uint32).max if not core.is_constant_dim(x.size): raise NotImplementedError( "shape polymorphism for `permutation` or `shuffle`" @@ -602,7 +733,8 @@ def choice(key: ArrayLike, shape: Shape = (), replace: bool = True, p: RealArray | None = None, - axis: int = 0) -> Array: + axis: int = 0, + mode: str | None = None) -> Array: """Generates a random sample from a given array. .. warning:: @@ -625,6 +757,12 @@ def choice(key: ArrayLike, entries in a. axis: int, optional. The axis along which the selection is performed. The default, 0, selects by row. + mode: optional, "high" or "low" for how many bits to use in the gumbel sampler + when `p is None` and `replace = False`. The default is determined by the + ``use_high_dynamic_range_gumbel`` config, which defaults to "low". With mode="low", + in float32 sampling will be biased for choices with probability less than about + 1E-7; with mode="high" this limit is pushed down to about 1E-14. mode="high" + approximately doubles the cost of sampling. Returns: An array of shape `shape` containing samples from `a`. @@ -670,7 +808,7 @@ def choice(key: ArrayLike, ind = jnp.searchsorted(p_cuml, r).astype(int) else: # Gumbel top-k trick: https://timvieira.github.io/blog/post/2019/09/16/algorithms-for-sampling-without-replacement/ - g = gumbel(key, (n_inputs,), dtype=p_arr.dtype) + jnp.log(p_arr) + g = gumbel(key, (n_inputs,), dtype=p_arr.dtype, mode=mode) + jnp.log(p_arr) ind = lax.top_k(g, k=n_draws)[1].astype(int) result = ind if arr.ndim == 0 else jnp.take(arr, ind, axis) @@ -680,7 +818,9 @@ def choice(key: ArrayLike, def normal(key: ArrayLike, shape: Shape = (), - dtype: DTypeLikeFloat = float) -> Array: + dtype: DTypeLikeFloat | None = None, + *, + out_sharding=None) -> Array: r"""Sample standard normal random values with given shape and float dtype. The values are returned according to the probability density function: @@ -702,14 +842,15 @@ def normal(key: ArrayLike, """ key, _ = _check_prng_key("normal", key) shape = core.canonicalize_shape(shape) - dtypes.check_user_dtype_supported(dtype) + out_sharding = canonicalize_sharding_for_samplers(out_sharding, "normal", shape) + dtype = dtypes.check_and_canonicalize_user_dtype( + float if dtype is None else dtype) if not dtypes.issubdtype(dtype, np.inexact): raise ValueError(f"dtype argument to `normal` must be a float or complex dtype, " f"got {dtype}") - dtype = dtypes.canonicalize_dtype(dtype) - return _normal(key, shape, dtype) + return maybe_auto_axes(_normal, out_sharding, shape=shape, dtype=dtype)(key) -@partial(jit, static_argnums=(1, 2)) +@jit(static_argnums=(1, 2)) def _normal(key, shape, dtype) -> Array: if dtypes.issubdtype(dtype, np.complexfloating): sqrt2 = np.array(np.sqrt(2), dtype) @@ -722,13 +863,13 @@ def _normal(key, shape, dtype) -> Array: else: return _normal_real(key, shape, dtype) -@partial(jit, static_argnums=(1, 2)) +@jit(static_argnums=(1, 2)) def _normal_real(key, shape, dtype) -> Array: _check_shape("normal", shape) lo = np.nextafter(np.array(-1., dtype), np.array(0., dtype), dtype=dtype) hi = np.array(1., dtype) u = uniform(key, shape, dtype, lo, hi) - return lax.mul(np.array(np.sqrt(2), dtype), lax.erf_inv(u)) + return lax.mul(np.array(np.sqrt(2), dtype), lax_special.erf_inv(u)) def multivariate_normal(key: ArrayLike, @@ -768,12 +909,13 @@ def multivariate_normal(key: ArrayLike, ``broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:]``. """ key, _ = _check_prng_key("multivariate_normal", key) - dtypes.check_user_dtype_supported(dtype) mean, cov = promote_dtypes_inexact(mean, cov) if method not in {'svd', 'eigh', 'cholesky'}: raise ValueError("method must be one of {'svd', 'eigh', 'cholesky'}") if dtype is None: dtype = mean.dtype + else: + dtype = dtypes.check_and_canonicalize_user_dtype(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `multivariate_normal` must be a float " f"dtype, got {dtype}") @@ -781,7 +923,7 @@ def multivariate_normal(key: ArrayLike, shape = core.canonicalize_shape(shape) return _multivariate_normal(key, mean, cov, shape, dtype, method) -@partial(jit, static_argnums=(3, 4, 5)) +@jit(static_argnums=(3, 4, 5)) def _multivariate_normal(key, mean, cov, shape, dtype, method) -> Array: if not np.ndim(mean) >= 1: msg = "multivariate_normal requires mean.ndim >= 1, got mean.ndim == {}" @@ -801,16 +943,16 @@ def _multivariate_normal(key, mean, cov, shape, dtype, method) -> Array: _check_shape("normal", shape, mean.shape[:-1], cov.shape[:-2]) if method == 'svd': - (u, s, _) = svd(cov) + (u, s, _) = jnp_linalg.svd(cov) factor = u * jnp.sqrt(s[..., None, :]) elif method == 'eigh': - (w, v) = eigh(cov) + (w, v) = jnp_linalg.eigh(cov) factor = v * jnp.sqrt(w[..., None, :]) else: # 'cholesky' - factor = cholesky(cov) + factor = jnp_linalg.cholesky(cov) normal_samples = normal(key, shape + mean.shape[-1:], dtype) with config.numpy_rank_promotion('allow'): - result = mean + jnp.einsum('...ij,...j->...i', factor, normal_samples) + result = mean + jnp_einsum.einsum('...ij,...j->...i', factor, normal_samples) return result @@ -818,7 +960,8 @@ def truncated_normal(key: ArrayLike, lower: RealArray, upper: RealArray, shape: Shape | None = None, - dtype: DTypeLikeFloat = float) -> Array: + dtype: DTypeLikeFloat | None = None, + *, out_sharding=None) -> Array: r"""Sample truncated standard normal random values with given shape and dtype. The values are returned according to the probability density function: @@ -849,14 +992,16 @@ def truncated_normal(key: ArrayLike, if shape is not None: shape = core.canonicalize_shape(shape) key, _ = _check_prng_key("truncated_normal", key) - dtypes.check_user_dtype_supported(dtype) + out_sharding = canonicalize_sharding(out_sharding, "truncated_normal") + dtype = dtypes.check_and_canonicalize_user_dtype( + float if dtype is None else dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `truncated_normal` must be a float " f"dtype, got {dtype}") - dtype = dtypes.canonicalize_dtype(dtype) - return _truncated_normal(key, lower, upper, shape, dtype) + return maybe_auto_axes(_truncated_normal, out_sharding, + shape=shape, dtype=dtype)(key, lower, upper) -@partial(jit, static_argnums=(3, 4)) +@jit(static_argnums=(3, 4)) def _truncated_normal(key, lower, upper, shape, dtype) -> Array: if shape is None: shape = lax.broadcast_shapes(np.shape(lower), np.shape(upper)) @@ -866,12 +1011,12 @@ def _truncated_normal(key, lower, upper, shape, dtype) -> Array: sqrt2 = np.array(np.sqrt(2), dtype) lower = lax.convert_element_type(lower, dtype) upper = lax.convert_element_type(upper, dtype) - a = lax.erf(lower / sqrt2) - b = lax.erf(upper / sqrt2) - if not jnp.issubdtype(dtype, np.floating): + a = lax_special.erf(lower / sqrt2) + b = lax_special.erf(upper / sqrt2) + if not dtypes.issubdtype(dtype, np.floating): raise TypeError("truncated_normal only accepts floating point dtypes.") u = uniform(key, shape, dtype, minval=a, maxval=b) - out = sqrt2 * lax.erf_inv(u) + out = sqrt2 * lax_special.erf_inv(u) # Clamp the value to the open interval (lower, upper) to make sure that # rounding (or if we chose `a` for `u`) doesn't push us outside of the range. return jnp.clip( @@ -881,8 +1026,10 @@ def _truncated_normal(key, lower, upper, shape, dtype) -> Array: def bernoulli(key: ArrayLike, - p: RealArray = np.float32(0.5), - shape: Shape | None = None) -> Array: + p: RealArray = 0.5, + shape: Shape | None = None, + mode: str = 'low', + *, out_sharding=None) -> Array: r"""Sample Bernoulli random values with given shape and mean. The values are distributed according to the probability mass function: @@ -899,6 +1046,11 @@ def bernoulli(key: ArrayLike, shape: optional, a tuple of nonnegative integers representing the result shape. Must be broadcast-compatible with ``p.shape``. The default (None) produces a result shape equal to ``p.shape``. + mode: optional, "high" or "low" for how many bits to use when sampling. + default='low'. Set to "high" for correct sampling at small values of + `p`. When sampling in float32, bernoulli samples with mode='low' produce + incorrect results for p < ~1E-7. mode="high" approximately doubles the + cost of sampling. Returns: A random array with boolean dtype and shape given by ``shape`` if ``shape`` @@ -906,30 +1058,42 @@ def bernoulli(key: ArrayLike, """ if shape is not None: shape = core.canonicalize_shape(shape) + if mode not in ['high', 'low']: + raise ValueError(f"got {mode=}, expected 'high' or 'low'") key, _ = _check_prng_key("bernoulli", key) - dtype = dtypes.canonicalize_dtype(lax.dtype(p)) - if not jnp.issubdtype(dtype, np.floating): + out_sharding = canonicalize_sharding(out_sharding, "bernoulli") + dtype = lax.dtype(p) + if not dtypes.issubdtype(dtype, np.floating): msg = "bernoulli probability `p` must have a floating dtype, got {}." raise TypeError(msg.format(dtype)) p = lax.convert_element_type(p, dtype) - return _bernoulli(key, p, shape) + return maybe_auto_axes(_bernoulli, out_sharding, + shape=shape, mode=mode)(key, p) + -@partial(jit, static_argnums=(2,)) -def _bernoulli(key, p, shape) -> Array: +@jit(static_argnames=['shape', 'mode']) +def _bernoulli(key: Array, p: Array, shape: Shape | None, mode: str) -> Array: if shape is None: # TODO: Use the named part of `p` as well shape = np.shape(p) else: _check_shape("bernoulli", shape, np.shape(p)) + dtype = lax.dtype(p) - return uniform(key, shape, lax.dtype(p)) < p + if mode == 'high': + u1, u2 = uniform(key, (2, *shape), dtype) + # resolution of uniform samples is 2 ** -n_mantissa + u2 *= 2 ** -dtypes.finfo(dtype).nmant + return u2 < p - u1 + else: + return uniform(key, shape, lax.dtype(p)) < p def beta(key: ArrayLike, a: RealArray, b: RealArray, shape: Shape | None = None, - dtype: DTypeLikeFloat = float) -> Array: + dtype: DTypeLikeFloat | None = None) -> Array: r"""Sample Beta random values with given shape and float dtype. The values are distributed according to the probability density function: @@ -956,11 +1120,11 @@ def beta(key: ArrayLike, ``shape`` is not None, or else by broadcasting ``a`` and ``b``. """ key, _ = _check_prng_key("beta", key) - dtypes.check_user_dtype_supported(dtype) + dtype = dtypes.check_and_canonicalize_user_dtype( + float if dtype is None else dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `beta` must be a float " f"dtype, got {dtype}") - dtype = dtypes.canonicalize_dtype(dtype) if shape is not None: shape = core.canonicalize_shape(shape) return _beta(key, a, b, shape, dtype) @@ -972,6 +1136,8 @@ def _beta(key, a, b, shape, dtype) -> Array: else: _check_shape("beta", shape, np.shape(a), np.shape(b)) + key, (a, b) = random_insert_pvary("jax.random.beta", key, a, b) + a = lax.convert_element_type(a, dtype) b = lax.convert_element_type(b, dtype) key_a, key_b = _split(key) @@ -988,7 +1154,7 @@ def _beta(key, a, b, shape, dtype) -> Array: def cauchy(key: ArrayLike, shape: Shape = (), - dtype: DTypeLikeFloat = float) -> Array: + dtype: DTypeLikeFloat | None = None) -> Array: r"""Sample Cauchy random values with given shape and float dtype. The values are distributed according to the probability density function: @@ -1009,26 +1175,26 @@ def cauchy(key: ArrayLike, A random array with the specified shape and dtype. """ key, _ = _check_prng_key("cauchy", key) - dtypes.check_user_dtype_supported(dtype) + dtype = dtypes.check_and_canonicalize_user_dtype( + float if dtype is None else dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `cauchy` must be a float " f"dtype, got {dtype}") - dtype = dtypes.canonicalize_dtype(dtype) shape = core.canonicalize_shape(shape) return _cauchy(key, shape, dtype) -@partial(jit, static_argnums=(1, 2)) +@jit(static_argnums=(1, 2)) def _cauchy(key, shape, dtype) -> Array: _check_shape("cauchy", shape) - u = uniform(key, shape, dtype, minval=jnp.finfo(dtype).eps, maxval=1.) - pi = _lax_const(u, np.pi) - return lax.tan(lax.mul(pi, lax.sub(u, _lax_const(u, 0.5)))) + u = uniform(key, shape, dtype, minval=dtypes.finfo(dtype).eps, maxval=1.) + pi = lax._const(u, np.pi) + return lax.tan(lax.mul(pi, lax.sub(u, lax._const(u, 0.5)))) def dirichlet(key: ArrayLike, alpha: RealArray, shape: Shape | None = None, - dtype: DTypeLikeFloat = float) -> Array: + dtype: DTypeLikeFloat | None = None) -> Array: r"""Sample Dirichlet random values with given shape and float dtype. The values are distributed according to the probability density function: @@ -1061,17 +1227,19 @@ def dirichlet(key: ArrayLike, ``alpha.shape``. """ key, _ = _check_prng_key("dirichlet", key) - dtypes.check_user_dtype_supported(dtype) + dtype = dtypes.check_and_canonicalize_user_dtype( + float if dtype is None else dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `dirichlet` must be a float " f"dtype, got {dtype}") - dtype = dtypes.canonicalize_dtype(dtype) if shape is not None: shape = core.canonicalize_shape(shape) return _dirichlet(key, alpha, shape, dtype) -@partial(jit, static_argnums=(2, 3)) +@jit(static_argnums=(2, 3)) def _dirichlet(key, alpha, shape, dtype) -> Array: + from jax._src.nn.functions import softmax # pytype: disable=import-error + if not np.ndim(alpha) >= 1: msg = "dirichlet requires alpha.ndim >= 1, got alpha.ndim == {}" raise ValueError(msg.format(np.ndim(alpha))) @@ -1085,21 +1253,12 @@ def _dirichlet(key, alpha, shape, dtype) -> Array: # Compute gamma in log space, otherwise small alpha can lead to poor behavior. log_gamma_samples = loggamma(key, alpha, shape + np.shape(alpha)[-1:], dtype) - return _softmax(log_gamma_samples, -1) - - -def _softmax(x, axis) -> Array: - """Utility to compute the softmax of x along a given axis.""" - if not dtypes.issubdtype(x.dtype, np.floating): - raise TypeError(f"_softmax only accepts floating dtypes, got {x.dtype}") - x_max = jnp.max(x, axis, keepdims=True) - unnormalized = jnp.exp(x - lax.stop_gradient(x_max)) - return unnormalized / unnormalized.sum(axis, keepdims=True) + return softmax(log_gamma_samples, -1) def exponential(key: ArrayLike, shape: Shape = (), - dtype: DTypeLikeFloat = float) -> Array: + dtype: DTypeLikeFloat | None = None) -> Array: r"""Sample Exponential random values with given shape and float dtype. The values are distributed according to the probability density function: @@ -1120,15 +1279,15 @@ def exponential(key: ArrayLike, A random array with the specified shape and dtype. """ key, _ = _check_prng_key("exponential", key) - dtypes.check_user_dtype_supported(dtype) + dtype = dtypes.check_and_canonicalize_user_dtype( + float if dtype is None else dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `exponential` must be a float " f"dtype, got {dtype}") - dtype = dtypes.canonicalize_dtype(dtype) shape = core.canonicalize_shape(shape) return _exponential(key, shape, dtype) -@partial(jit, static_argnums=(1, 2)) +@jit(static_argnums=(1, 2)) def _exponential(key, shape, dtype) -> Array: _check_shape("exponential", shape) u = uniform(key, shape, dtype) @@ -1140,14 +1299,20 @@ def _gamma_one(key: Array, alpha, log_space) -> Array: # Ref: A simple method for generating gamma variables, George Marsaglia and Wai Wan Tsang # The algorithm can also be founded in: # https://en.wikipedia.org/wiki/Gamma_distribution#Generating_gamma-distributed_random_variables - zero = _lax_const(alpha, 0) - one = _lax_const(alpha, 1) - minus_one = _lax_const(alpha, -1) - one_over_two = _lax_const(alpha, 0.5) - one_over_three = _lax_const(alpha, 1. / 3.) - squeeze_const = _lax_const(alpha, 0.0331) + zero = lax._const(alpha, 0) + one = lax._const(alpha, 1) + two = lax._const(alpha, 2) + minus_one = lax._const(alpha, -1) + one_over_two = lax._const(alpha, 0.5) + one_over_three = lax._const(alpha, 1. / 3.) + squeeze_const = lax._const(alpha, 0.0331) dtype = lax.dtype(alpha) + zero = core.pvary(zero, tuple(core.typeof(alpha).vma)) + one = core.pvary(one, tuple(core.typeof(alpha).vma)) + minus_one = core.pvary(minus_one, tuple(core.typeof(alpha).vma)) + two = core.pvary(two, tuple(core.typeof(alpha).vma)) + # for alpha < 1, we boost alpha to alpha + 1 and get a sample according to # Gamma(alpha) ~ Gamma(alpha+1) * Uniform()^(1 / alpha) # When alpha is very small, this boost can be problematic because it may result @@ -1169,10 +1334,11 @@ def _cond_fn(kXVU): # TODO: use lax.cond when its batching rule is supported # The reason is to avoid evaluating second condition which involves log+log # if the first condition is satisfied - cond = lax.bitwise_and(lax.ge(U, lax.sub(one, lax.mul(squeeze_const, lax.mul(X, X)))), - lax.ge(lax.log(U), lax.add(lax.mul(X, one_over_two), - lax.mul(d, lax.add(lax.sub(one, V), - lax.log(V)))))) + cond = lax.bitwise_and( + lax.ge(U, lax.sub(one, lax.mul(squeeze_const, lax.mul(X, X)))), + lax.ge(lax.log(U), lax.add(lax.mul(X, one_over_two), + lax.mul(d, lax.add(lax.sub(one, V), + lax.log(V)))))) return cond def _body_fn(kXVU): @@ -1185,7 +1351,8 @@ def _next_kxv(kxv): key = kXVU[0] key, x_key, U_key = _split(key, 3) - _, x, v = lax.while_loop(lambda kxv: lax.le(kxv[2], zero), _next_kxv, (x_key, zero, minus_one)) + _, x, v = lax_control_flow.while_loop(lambda kxv: lax.le(kxv[2], zero), + _next_kxv, (x_key, zero, minus_one)) X = lax.mul(x, x) V = lax.mul(lax.mul(v, v), v) U = uniform(U_key, (), dtype=dtype) @@ -1193,14 +1360,17 @@ def _next_kxv(kxv): # initial state is chosen such that _cond_fn will return True key, subkey = _split(key) - _, _, V, _ = lax.while_loop(_cond_fn, _body_fn, (key, zero, one, _lax_const(alpha, 2))) + _, _, V, _ = lax_control_flow.while_loop( + _cond_fn, _body_fn, (key, zero, one, two)) if log_space: log_samples = lax.neg(exponential(subkey, (), dtype=dtype)) - log_boost = lax.select(boost_mask | (log_samples == 0), zero, lax.mul(log_samples, lax.div(one, alpha_orig))) + log_boost = lax.select(boost_mask | (log_samples == 0), zero, + lax.mul(log_samples, lax.div(one, alpha_orig))) return lax.add(lax.add(lax.log(d), lax.log(V)), log_boost) else: samples = 1 - uniform(subkey, (), dtype=dtype) - boost = lax.select(boost_mask, one, lax.pow(samples, lax.div(one, alpha_orig))) + boost = lax.select(boost_mask, one, + lax.pow(samples, lax.div(one, alpha_orig))) return lax.mul(lax.mul(d, V), boost) @@ -1212,21 +1382,22 @@ def _gamma_grad(sample, a, *, log_space): # This requires computing exp(log_sample), which may be zero due to float roundoff. # In this case, correct it to smallest representable float. samples = lax.exp(samples) - zero = lax_internal._const(sample, 0) - tiny = lax.full_like(samples, jnp.finfo(samples.dtype).tiny) + zero = lax._const(sample, 0) + tiny = lax.full_like(samples, dtypes.finfo(samples.dtype).tiny) samples = lax.select(lax.eq(samples, zero), tiny, samples) - gamma_grad = lambda alpha, sample: lax.random_gamma_grad(alpha, sample) / sample + gamma_grad = lambda alpha, sample: ( + lax_special.random_gamma_grad(alpha, sample) / sample) else: - gamma_grad = lax.random_gamma_grad + gamma_grad = lax_special.random_gamma_grad if xla_bridge.get_backend().platform == 'cpu': - grads = lax.map(lambda args: gamma_grad(*args), (alphas, samples)) + grads = lax_control_flow.map(lambda args: gamma_grad(*args), (alphas, samples)) else: grads = vmap(gamma_grad)(alphas, samples) return grads.reshape(np.shape(a)) def _gamma_impl(key, a, *, log_space, use_vmap=False): # split key to match the shape of a - a_shape = jnp.shape(a) + a_shape = np.shape(a) split_count = math.prod(a_shape[key.ndim:]) keys = key.flatten() keys = vmap(_split, in_axes=(0, None))(keys, split_count) @@ -1236,7 +1407,7 @@ def _gamma_impl(key, a, *, log_space, use_vmap=False): if use_vmap and _key_impl(key) is prng.threefry_prng_impl: samples = vmap(partial(_gamma_one, log_space=log_space))(keys, alphas) else: - samples = lax.map( + samples = lax_control_flow.map( lambda args: _gamma_one(*args, log_space=log_space), (keys, alphas)) return jnp.reshape(samples, a_shape) @@ -1252,7 +1423,12 @@ def _gamma_batching_rule(batched_args, batch_dims, *, log_space): random_gamma_p = core.Primitive('random_gamma') random_gamma_p.def_impl(_gamma_impl) -random_gamma_p.def_abstract_eval(lambda key, a, **_: a) + +def _random_gamma_abstract_eval(key, a, **_): + core.standard_vma_rule('random_gamma', key, a) + return a +random_gamma_p.def_abstract_eval(_random_gamma_abstract_eval) + ad.defjvp2( random_gamma_p, None, lambda tangent, ans, key, a, **kwds: tangent * _gamma_grad(ans, a, **kwds)) @@ -1267,7 +1443,7 @@ def _gamma_batching_rule(batched_args, batch_dims, *, log_space): def gamma(key: ArrayLike, a: RealArray, shape: Shape | None = None, - dtype: DTypeLikeFloat = float) -> Array: + dtype: DTypeLikeFloat | None = None) -> Array: r"""Sample Gamma random values with given shape and float dtype. The values are distributed according to the probability density function: @@ -1301,11 +1477,11 @@ def gamma(key: ArrayLike, accuracy for small values of ``a``. """ key, _ = _check_prng_key("gamma", key) - dtypes.check_user_dtype_supported(dtype) + dtype = dtypes.check_and_canonicalize_user_dtype( + float if dtype is None else dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `gamma` must be a float " f"dtype, got {dtype}") - dtype = dtypes.canonicalize_dtype(dtype) if shape is not None: shape = core.canonicalize_shape(shape) return _gamma(key, a, shape=shape, dtype=dtype) @@ -1314,7 +1490,7 @@ def gamma(key: ArrayLike, def loggamma(key: ArrayLike, a: RealArray, shape: Shape | None = None, - dtype: DTypeLikeFloat = float) -> Array: + dtype: DTypeLikeFloat | None = None) -> Array: """Sample log-gamma random values with given shape and float dtype. This function is implemented such that the following will hold for a @@ -1343,17 +1519,17 @@ def loggamma(key: ArrayLike, gamma : standard gamma sampler. """ key, _ = _check_prng_key("loggamma", key) - dtypes.check_user_dtype_supported(dtype) + dtype = dtypes.check_and_canonicalize_user_dtype( + float if dtype is None else dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `gamma` must be a float " f"dtype, got {dtype}") - dtype = dtypes.canonicalize_dtype(dtype) if shape is not None: shape = core.canonicalize_shape(shape) return _gamma(key, a, shape=shape, dtype=dtype, log_space=True) -@partial(jit, static_argnames=('shape', 'dtype', 'log_space')) +@jit(static_argnames=('shape', 'dtype', 'log_space')) def _gamma(key, a, shape, dtype, log_space=False) -> Array: if shape is None: shape = np.shape(a) @@ -1363,10 +1539,11 @@ def _gamma(key, a, shape, dtype, log_space=False) -> Array: a = lax.convert_element_type(a, dtype) if np.shape(a) != shape: a = jnp.broadcast_to(a, shape) + key, (a,) = random_insert_pvary('gamma', key, a) return random_gamma_p.bind(key, a, log_space=log_space) -@partial(jit, static_argnums=(2, 3, 4)) +@jit(static_argnums=(2, 3, 4)) def _poisson_knuth(key, lam, shape, dtype, max_iters) -> Array: # Knuth's algorithm for generating Poisson random variates. # Reference: @@ -1385,11 +1562,11 @@ def cond_fn(carry): k_init = lax.full_like(lam, 0, dtype, shape) log_rate_init = lax.full_like(lam, 0, np.float32, shape) - k = lax.while_loop(cond_fn, body_fn, (0, k_init, key, log_rate_init))[1] + k = lax_control_flow.while_loop(cond_fn, body_fn, (0, k_init, key, log_rate_init))[1] return (k - 1).astype(dtype) -@partial(jit, static_argnums=(2, 3, 4)) +@jit(static_argnums=(2, 3, 4)) def _poisson_rejection(key, lam, shape, dtype, max_iters) -> Array: # Transformed rejection due to Hormann. # Reference: @@ -1410,7 +1587,7 @@ def body_fn(carry): k = lax.floor((2 * a / u_shifted + b) * u + lam + 0.43) s = lax.log(v * inv_alpha / (a / (u_shifted * u_shifted) + b)) - t = -lam + k * log_lam - lax.lgamma(k + 1) + t = -lam + k * log_lam - lax_special.lgamma(k + 1) accept1 = (u_shifted >= 0.07) & (v <= v_r) reject = (k < 0) | ((u_shifted < 0.013) & (v > u_shifted)) @@ -1427,12 +1604,12 @@ def cond_fn(carry): return (~accepted).any() & (i < max_iters) k_init = lax.full_like(lam, -1, lam.dtype, shape) - accepted = lax.full_like(lam, False, jnp.bool_, shape) - k = lax.while_loop(cond_fn, body_fn, (0, k_init, accepted, key))[1] + accepted = lax.full_like(lam, False, np.dtype('bool'), shape) + k = lax_control_flow.while_loop(cond_fn, body_fn, (0, k_init, accepted, key))[1] return k.astype(dtype) -@partial(jit, static_argnums=(2, 3)) +@jit(static_argnums=(2, 3)) def _poisson(key, lam, shape, dtype) -> Array: # The implementation matches TensorFlow and NumPy: # https://github.com/tensorflow/tensorflow/blob/v2.2.0-rc3/tensorflow/core/kernels/random_poisson_op.cc @@ -1444,7 +1621,7 @@ def _poisson(key, lam, shape, dtype) -> Array: # The acceptance probability for rejection sampling maxes out at 89% as # λ -> ∞, so pick some arbitrary large value. lam_rejection = lax.select(use_knuth, lax.full_like(lam, 1e5), lam) - max_iters = dtype.type(jnp.iinfo(dtype).max) # insanely conservative + max_iters = dtype.type(dtypes.iinfo(dtype).max) # insanely conservative result = lax.select( use_knuth, _poisson_knuth(key, lam_knuth, shape, dtype, max_iters), @@ -1456,7 +1633,7 @@ def _poisson(key, lam, shape, dtype) -> Array: def poisson(key: ArrayLike, lam: RealArray, shape: Shape | None = None, - dtype: DTypeLikeInt = int) -> Array: + dtype: DTypeLikeInt | None = None) -> Array: r"""Sample Poisson random values with given shape and integer dtype. The values are distributed according to the probability mass function: @@ -1479,7 +1656,8 @@ def poisson(key: ArrayLike, ``shape is not None, or else by ``lam.shape``. """ key, _ = _check_prng_key("poisson", key) - dtypes.check_user_dtype_supported(dtype) + dtype = dtypes.check_and_canonicalize_user_dtype( + int if dtype is None else dtype) # TODO(frostig): generalize underlying poisson implementation and # remove this check keys_dtype = typing.cast(prng.KeyTy, key.dtype) @@ -1488,7 +1666,6 @@ def poisson(key: ArrayLike, raise NotImplementedError( '`poisson` is only implemented for the threefry2x32 RNG, ' f'not {key_impl}') - dtype = dtypes.canonicalize_dtype(dtype) if shape is not None: shape = core.canonicalize_shape(shape) else: @@ -1500,8 +1677,10 @@ def poisson(key: ArrayLike, def gumbel(key: ArrayLike, shape: Shape = (), - dtype: DTypeLikeFloat = float, - mode: str | None =None) -> Array: + dtype: DTypeLikeFloat | None = None, + mode: str | None = None, + *, + out_sharding=None) -> Array: """Sample Gumbel random values with given shape and float dtype. The values are distributed according to the probability density function: @@ -1516,36 +1695,45 @@ def gumbel(key: ArrayLike, dtype: optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32). mode: optional, "high" or "low" for how many bits to use when sampling. + The default is determined by the ``use_high_dynamic_range_gumbel`` config, + which defaults to "low". When drawing float32 samples, with mode="low" the + uniform resolution is such that the largest possible gumbel logit is ~16; + with mode="high" this is increased to ~32, at approximately double the + computational cost. Returns: A random array with the specified shape and dtype. """ key, _ = _check_prng_key("gumbel", key) - dtypes.check_user_dtype_supported(dtype) + dtype = dtypes.check_and_canonicalize_user_dtype( + float if dtype is None else dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `gumbel` must be a float " f"dtype, got {dtype}") - dtype = dtypes.canonicalize_dtype(dtype) shape = core.canonicalize_shape(shape) if mode is None: mode = "high" if config.use_high_dynamic_range_gumbel.value else "low" if mode not in ("high", "low"): raise ValueError("Must provide valid mode for gumbel got: %s" % mode) - return _gumbel(key, shape, dtype, mode) + out_sharding = canonicalize_sharding_for_samplers(out_sharding, "gumbel", shape) + return maybe_auto_axes(_gumbel, out_sharding, shape=shape, dtype=dtype, + mode=mode)(key) -@partial(jit, static_argnums=(1, 2, 3)) +@jit(static_argnums=(1, 2, 3)) def _gumbel(key, shape, dtype, mode) -> Array: _check_shape("gumbel", shape) + info = dtypes.finfo(dtype) if mode == "high": - high, low = _uniform(key, (2,) + shape, dtype, minval=0., maxval=1.) + high, low = _uniform(key, minval=0., maxval=1., + shape=(2,) + shape, dtype=dtype) # TODO(parkers): The condition is to protect against rounding up but # we should be able to add safely with the right addition operation. x = jnp.where(high >= 0.5, high, - high + 2 ** -(jnp.finfo(dtype).nmant) * low + jnp.finfo(dtype).tiny) + high + 2 ** -(info.nmant) * low + info.tiny) return -jnp.log(-jnp.log1p(-x)) else: return -jnp.log(-jnp.log( - _uniform(key, shape, dtype, minval=jnp.finfo(dtype).tiny, maxval=1.))) + _uniform(key, minval=info.tiny, maxval=1., shape=shape, dtype=dtype))) def categorical( @@ -1554,6 +1742,7 @@ def categorical( axis: int = -1, shape: Shape | None = None, replace: bool = True, + mode: str | None = None, ) -> Array: """Sample random values from categorical distributions. @@ -1568,8 +1757,14 @@ def categorical( shape: Optional, a tuple of nonnegative integers representing the result shape. Must be broadcast-compatible with ``np.delete(logits.shape, axis)``. The default (None) produces a result shape equal to ``np.delete(logits.shape, axis)``. - replace: If True, perform sampling without replacement. Default (False) is to - perform sampling with replacement. + replace: If True (default), perform sampling with replacement. If False, perform + sampling without replacement. + mode: optional, "high" or "low" for how many bits to use in the gumbel sampler. + The default is determined by the ``use_high_dynamic_range_gumbel`` config, + which defaults to "low". With mode="low", in float32 sampling will be biased + for events with probability less than about 1E-7; with mode="high" this limit + is pushed down to about 1E-14. mode="high" approximately doubles the cost of + sampling. Returns: A random array with int dtype and shape given by ``shape`` if ``shape`` @@ -1599,11 +1794,11 @@ def categorical( logits_shape = list(shape[len(shape) - len(batch_shape):]) logits_shape.insert(axis % len(logits_arr.shape), logits_arr.shape[axis]) return jnp.argmax( - gumbel(key, (*shape_prefix, *logits_shape), logits_arr.dtype) + + gumbel(key, (*shape_prefix, *logits_shape), logits_arr.dtype, mode=mode) + lax.expand_dims(logits_arr, tuple(range(len(shape_prefix)))), axis=axis) else: - logits_arr += gumbel(key, logits_arr.shape, logits_arr.dtype) + logits_arr += gumbel(key, logits_arr.shape, logits_arr.dtype, mode=mode) k = math.prod(shape_prefix) if k > logits_arr.shape[axis]: raise ValueError( @@ -1623,7 +1818,7 @@ def categorical( def laplace(key: ArrayLike, shape: Shape = (), - dtype: DTypeLikeFloat = float) -> Array: + dtype: DTypeLikeFloat | None = None) -> Array: r"""Sample Laplace random values with given shape and float dtype. The values are distributed according to the probability density function: @@ -1642,25 +1837,25 @@ def laplace(key: ArrayLike, A random array with the specified shape and dtype. """ key, _ = _check_prng_key("laplace", key) - dtypes.check_user_dtype_supported(dtype) + dtype = dtypes.check_and_canonicalize_user_dtype( + float if dtype is None else dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `laplace` must be a float " f"dtype, got {dtype}") - dtype = dtypes.canonicalize_dtype(dtype) shape = core.canonicalize_shape(shape) return _laplace(key, shape, dtype) -@partial(jit, static_argnums=(1, 2)) +@jit(static_argnums=(1, 2)) def _laplace(key, shape, dtype) -> Array: _check_shape("laplace", shape) u = uniform( - key, shape, dtype, minval=-1. + jnp.finfo(dtype).epsneg, maxval=1.) + key, shape, dtype, minval=-1. + dtypes.finfo(dtype).epsneg, maxval=1.) return lax.mul(lax.sign(u), lax.log1p(lax.neg(lax.abs(u)))) def logistic(key: ArrayLike, shape: Shape = (), - dtype: DTypeLikeFloat = float) -> Array: + dtype: DTypeLikeFloat | None = None) -> Array: r"""Sample logistic random values with given shape and float dtype. The values are distributed according to the probability density function: @@ -1679,25 +1874,25 @@ def logistic(key: ArrayLike, A random array with the specified shape and dtype. """ key, _ = _check_prng_key("logistic", key) - dtypes.check_user_dtype_supported(dtype) + dtype = dtypes.check_and_canonicalize_user_dtype( + float if dtype is None else dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `logistic` must be a float " f"dtype, got {dtype}") - dtype = dtypes.canonicalize_dtype(dtype) shape = core.canonicalize_shape(shape) return _logistic(key, shape, dtype) -@partial(jit, static_argnums=(1, 2)) +@jit(static_argnums=(1, 2)) def _logistic(key, shape, dtype): _check_shape("logistic", shape) - x = uniform(key, shape, dtype, minval=jnp.finfo(dtype).eps, maxval=1.) - return lax.log(lax.div(x, lax.sub(_lax_const(x, 1), x))) + x = uniform(key, shape, dtype, minval=dtypes.finfo(dtype).tiny, maxval=1.) + return lax.sub(lax.log(x), lax.log1p(lax.neg(x))) def pareto(key: ArrayLike, b: RealArray, shape: Shape | None = None, - dtype: DTypeLikeFloat = float) -> Array: + dtype: DTypeLikeFloat | None = None) -> Array: r"""Sample Pareto random values with given shape and float dtype. The values are distributed according to the probability density function: @@ -1722,16 +1917,16 @@ def pareto(key: ArrayLike, ``shape`` is not None, or else by ``b.shape``. """ key, _ = _check_prng_key("pareto", key) - dtypes.check_user_dtype_supported(dtype) + dtype = dtypes.check_and_canonicalize_user_dtype( + float if dtype is None else dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `pareto` must be a float " f"dtype, got {dtype}") - dtype = dtypes.canonicalize_dtype(dtype) if shape is not None: shape = core.canonicalize_shape(shape) return _pareto(key, b, shape, dtype) -@partial(jit, static_argnums=(2, 3)) +@jit(static_argnums=(2, 3)) def _pareto(key, b, shape, dtype) -> Array: if shape is None: shape = np.shape(b) @@ -1746,7 +1941,7 @@ def _pareto(key, b, shape, dtype) -> Array: def t(key: ArrayLike, df: RealArray, shape: Shape = (), - dtype: DTypeLikeFloat = float) -> Array: + dtype: DTypeLikeFloat | None = None) -> Array: r"""Sample Student's t random values with given shape and float dtype. The values are distributed according to the probability density function: @@ -1771,15 +1966,15 @@ def t(key: ArrayLike, ``shape`` is not None, or else by ``df.shape``. """ key, _ = _check_prng_key("t", key) - dtypes.check_user_dtype_supported(dtype) + dtype = dtypes.check_and_canonicalize_user_dtype( + float if dtype is None else dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `t` must be a float " f"dtype, got {dtype}") - dtype = dtypes.canonicalize_dtype(dtype) shape = core.canonicalize_shape(shape) return _t(key, df, shape, dtype) -@partial(jit, static_argnums=(2, 3)) +@jit(static_argnums=(2, 3)) def _t(key, df, shape, dtype) -> Array: if shape is None: shape = np.shape(df) @@ -1789,7 +1984,7 @@ def _t(key, df, shape, dtype) -> Array: df = lax.convert_element_type(df, dtype) key_n, key_g = _split(key) n = normal(key_n, shape, dtype) - two = _lax_const(n, 2) + two = lax._const(n, 2) half_df = lax.div(df, two) g = gamma(key_g, half_df, shape, dtype) return n * jnp.sqrt(half_df / g) @@ -1798,7 +1993,7 @@ def _t(key, df, shape, dtype) -> Array: def chisquare(key: ArrayLike, df: RealArray, shape: Shape | None = None, - dtype: DTypeLikeFloat = float) -> Array: + dtype: DTypeLikeFloat | None = None) -> Array: r"""Sample Chisquare random values with given shape and float dtype. The values are distributed according to the probability density function: @@ -1824,23 +2019,23 @@ def chisquare(key: ArrayLike, ``shape`` is not None, or else by ``df.shape``. """ key, _ = _check_prng_key("chisquare", key) - dtypes.check_user_dtype_supported(dtype) + dtype = dtypes.check_and_canonicalize_user_dtype( + float if dtype is None else dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError("dtype argument to `chisquare` must be a float " f"dtype, got {dtype}") - dtype = dtypes.canonicalize_dtype(dtype) if shape is not None: shape = core.canonicalize_shape(shape) return _chisquare(key, df, shape, dtype) -@partial(jit, static_argnums=(2, 3)) +@jit(static_argnums=(2, 3)) def _chisquare(key, df, shape, dtype) -> Array: if shape is None: shape = np.shape(df) else: _check_shape("chisquare", shape, np.shape(df)) df = lax.convert_element_type(df, dtype) - two = _lax_const(df, 2) + two = lax._const(df, 2) half_df = lax.div(df, two) log_g = loggamma(key, a=half_df, shape=shape, dtype=dtype) chi2 = lax.mul(jnp.exp(log_g), two) @@ -1851,7 +2046,7 @@ def f(key: ArrayLike, dfnum: RealArray, dfden: RealArray, shape: Shape | None = None, - dtype: DTypeLikeFloat = float) -> Array: + dtype: DTypeLikeFloat | None = None) -> Array: r"""Sample F-distribution random values with given shape and float dtype. The values are distributed according to the probability density function: @@ -1882,16 +2077,16 @@ def f(key: ArrayLike, ``shape`` is not None, or else by ``df.shape``. """ key, _ = _check_prng_key("f", key) - dtypes.check_user_dtype_supported(dtype) + dtype = dtypes.check_and_canonicalize_user_dtype( + float if dtype is None else dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError("dtype argument to `f` must be a float " f"dtype, got {dtype}") - dtype = dtypes.canonicalize_dtype(dtype) if shape is not None: shape = core.canonicalize_shape(shape) return _f(key, dfnum, dfden, shape, dtype) -@partial(jit, static_argnums=(3, 4)) +@jit(static_argnums=(3, 4)) def _f(key, dfnum, dfden, shape, dtype) -> Array: if shape is None: shape = lax.broadcast_shapes(np.shape(dfden), np.shape(dfnum)) @@ -1913,7 +2108,7 @@ def _f(key, dfnum, dfden, shape, dtype) -> Array: def rademacher(key: ArrayLike, shape: Shape = (), - dtype: DTypeLikeInt = int) -> Array: + dtype: DTypeLikeInt | None = None) -> Array: r"""Sample from a Rademacher distribution. The values are distributed according to the probability mass function: @@ -1934,13 +2129,13 @@ def rademacher(key: ArrayLike, """ key, _ = _check_prng_key("rademacher", key) - dtypes.check_user_dtype_supported(dtype) - dtype = dtypes.canonicalize_dtype(dtype) + dtype = dtypes.check_and_canonicalize_user_dtype( + int if dtype is None else dtype) shape = core.canonicalize_shape(shape) return _rademacher(key, shape, dtype) -@partial(jit, static_argnums=(1, 2)) +@jit(static_argnums=(1, 2)) def _rademacher(key, shape, dtype) -> Array: bernoulli_samples = bernoulli(key=key, p=0.5, shape=shape).astype(dtype) return (2 * bernoulli_samples - 1).astype(dtype) @@ -1948,7 +2143,7 @@ def _rademacher(key, shape, dtype) -> Array: def maxwell(key: ArrayLike, shape: Shape = (), - dtype: DTypeLikeFloat = float) -> Array: + dtype: DTypeLikeFloat | None = None) -> Array: r"""Sample from a one sided Maxwell distribution. The values are distributed according to the probability density function: @@ -1970,27 +2165,27 @@ def maxwell(key: ArrayLike, # Generate samples using: # sqrt(X^2 + Y^2 + Z^2), X,Y,Z ~N(0,1) key, _ = _check_prng_key("maxwell", key) - dtypes.check_user_dtype_supported(dtype) + dtype = dtypes.check_and_canonicalize_user_dtype( + float if dtype is None else dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `maxwell` must be a float " f"dtype, got {dtype}") - dtype = dtypes.canonicalize_dtype(dtype) shape = core.canonicalize_shape(shape) return _maxwell(key, shape, dtype) -@partial(jit, static_argnums=(1, 2)) +@jit(static_argnums=(1, 2)) def _maxwell(key, shape, dtype) -> Array: shape = shape + (3,) norm_rvs = normal(key=key, shape=shape, dtype=dtype) - return jnp.linalg.norm(norm_rvs, axis=-1) + return jnp_linalg.norm(norm_rvs, axis=-1) def double_sided_maxwell(key: ArrayLike, loc: RealArray, scale: RealArray, shape: Shape = (), - dtype: DTypeLikeFloat = float) -> Array: + dtype: DTypeLikeFloat | None = None) -> Array: r"""Sample from a double sided Maxwell distribution. The values are distributed according to the probability density function: @@ -2013,16 +2208,16 @@ def double_sided_maxwell(key: ArrayLike, """ key, _ = _check_prng_key("double_sided_maxwell", key) - dtypes.check_user_dtype_supported(dtype) + dtype = dtypes.check_and_canonicalize_user_dtype( + float if dtype is None else dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `double_sided_maxwell` must be a float" f" dtype, got {dtype}") - dtype = dtypes.canonicalize_dtype(dtype) shape = core.canonicalize_shape(shape) return _double_sided_maxwell(key, loc, scale, shape, dtype) -@partial(jit, static_argnums=(3, 4)) +@jit(static_argnums=(3, 4)) def _double_sided_maxwell(key, loc, scale, shape, dtype) -> Array: params_shapes = lax.broadcast_shapes(np.shape(loc), np.shape(scale)) if not shape: @@ -2042,7 +2237,7 @@ def weibull_min(key: ArrayLike, scale: RealArray, concentration: RealArray, shape: Shape = (), - dtype: DTypeLikeFloat = float) -> Array: + dtype: DTypeLikeFloat | None = None) -> Array: r"""Sample from a Weibull distribution. The values are distributed according to the probability density function: @@ -2065,16 +2260,16 @@ def weibull_min(key: ArrayLike, """ key, _ = _check_prng_key("weibull_min", key) - dtypes.check_user_dtype_supported(dtype) + dtype = dtypes.check_and_canonicalize_user_dtype( + float if dtype is None else dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `weibull_min` must be a float " f"dtype, got {dtype}") - dtype = dtypes.canonicalize_dtype(dtype) shape = core.canonicalize_shape(shape) return _weibull_min(key, scale, concentration, shape, dtype) -@partial(jit, static_argnums=(3, 4)) +@jit(static_argnums=(3, 4)) def _weibull_min(key, scale, concentration, shape, dtype) -> Array: random_uniform = uniform( key=key, shape=shape, minval=0, maxval=1, dtype=dtype) @@ -2087,7 +2282,7 @@ def orthogonal( key: ArrayLike, n: int, shape: Shape = (), - dtype: DTypeLikeFloat = float, + dtype: DTypeLikeFloat | None = None, m: int | None = None, ) -> Array: r"""Sample uniformly from the orthogonal group O(n). @@ -2110,7 +2305,7 @@ def orthogonal( m: an integer indicating the number of columns. Defaults to `n`. Returns: - A random array of shape `(*shape, n, n)` and specified dtype. + A random array of shape `(*shape, n, m)` and specified dtype. References: .. [1] Mezzadri, Francesco. (2007). "How to generate random matrices from @@ -2123,14 +2318,15 @@ def orthogonal( _m = m shape = core.canonicalize_shape(shape) key, _ = _check_prng_key("orthogonal", key) - dtypes.check_user_dtype_supported(dtype) + dtype = dtypes.check_and_canonicalize_user_dtype( + float if dtype is None else dtype) _check_shape("orthogonal", shape) n = core.concrete_or_error(index, n, "The error occurred in jax.random.orthogonal()") _m = core.concrete_or_error(index, _m, "The error occurred in jax.random.orthogonal()") z = normal(key, (*shape, max(n, _m), min(n, _m)), dtype) - q, r = jnp.linalg.qr(z) - d = jnp.linalg.diagonal(r) + q, r = jnp_linalg.qr(z) + d = jnp_linalg.diagonal(r) x = q * jnp.expand_dims(jnp.sign(d), -2) if n < _m: @@ -2142,7 +2338,7 @@ def generalized_normal( key: ArrayLike, p: float, shape: Shape = (), - dtype: DTypeLikeFloat = float + dtype: DTypeLikeFloat | None = None ) -> Array: r"""Sample from the generalized normal distribution. @@ -2166,7 +2362,8 @@ def generalized_normal( """ shape = core.canonicalize_shape(shape) key, _ = _check_prng_key("generalized_normal", key) - dtypes.check_user_dtype_supported(dtype) + dtype = dtypes.check_and_canonicalize_user_dtype( + float if dtype is None else dtype) _check_shape("generalized_normal", shape) keys = split(key) g = gamma(keys[0], 1/p, shape, dtype) @@ -2178,7 +2375,7 @@ def ball( d: int, p: float = 2, shape: Shape = (), - dtype: DTypeLikeFloat = float + dtype: DTypeLikeFloat | None = None ): """Sample uniformly from the unit Lp ball. @@ -2197,7 +2394,8 @@ def ball( """ shape = core.canonicalize_shape(shape) key, _ = _check_prng_key("ball", key) - dtypes.check_user_dtype_supported(dtype) + dtype = dtypes.check_and_canonicalize_user_dtype( + float if dtype is None else dtype) _check_shape("ball", shape) d = core.concrete_or_error(index, d, "The error occurred in jax.random.ball()") k1, k2 = split(key) @@ -2209,7 +2407,7 @@ def ball( def rayleigh(key: ArrayLike, scale: RealArray, shape: Shape | None = None, - dtype: DTypeLikeFloat = float) -> Array: + dtype: DTypeLikeFloat | None = None) -> Array: r"""Sample Rayleigh random values with given shape and float dtype. The values are returned according to the probability density function: @@ -2235,16 +2433,16 @@ def rayleigh(key: ArrayLike, ``shape`` is not None, or else by ``scale.shape``. """ key, _ = _check_prng_key("rayleigh", key) - dtypes.check_user_dtype_supported(dtype) + dtype = dtypes.check_and_canonicalize_user_dtype( + float if dtype is None else dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError("dtype argument to `rayleigh` must be a float " f"dtype, got {dtype}") - dtype = dtypes.canonicalize_dtype(dtype) if shape is not None: shape = core.canonicalize_shape(shape) return _rayleigh(key, scale, shape, dtype) -@partial(jit, static_argnums=(2, 3)) +@jit(static_argnums=(2, 3)) def _rayleigh(key, scale, shape, dtype) -> Array: if shape is None: shape = np.shape(scale) @@ -2254,7 +2452,7 @@ def _rayleigh(key, scale, shape, dtype) -> Array: scale = scale.astype(dtype) scale = jnp.broadcast_to(scale, shape) log_u = lax.log(u) - n_two = _lax_const(scale, -2) + n_two = lax._const(scale, -2) sqrt_u = lax.sqrt(lax.mul(log_u, n_two)) ray = lax.mul(scale, sqrt_u) return ray @@ -2262,7 +2460,7 @@ def _rayleigh(key, scale, shape, dtype) -> Array: def wald(key: ArrayLike, mean: RealArray, shape: Shape | None = None, - dtype: DTypeLikeFloat = float) -> Array: + dtype: DTypeLikeFloat | None = None) -> Array: r"""Sample Wald random values with given shape and float dtype. The values are returned according to the probability density function: @@ -2289,16 +2487,16 @@ def wald(key: ArrayLike, ``shape`` is not None, or else by ``mean.shape``. """ key, _ = _check_prng_key("wald", key) - dtypes.check_user_dtype_supported(dtype) + dtype = dtypes.check_and_canonicalize_user_dtype( + float if dtype is None else dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError("dtype argument to `wald` must be a float " f"dtype, got {dtype}") - dtype = dtypes.canonicalize_dtype(dtype) if shape is not None: shape = core.canonicalize_shape(shape) return _wald(key, mean, shape, dtype) -@partial(jit, static_argnums=(2, 3)) +@jit(static_argnums=(2, 3)) def _wald(key, mean, shape, dtype) -> Array: if shape is None: shape = np.shape(mean) @@ -2320,7 +2518,7 @@ def _wald(key, mean, shape, dtype) -> Array: def geometric(key: ArrayLike, p: RealArray, shape: Shape | None = None, - dtype: DTypeLikeInt = int) -> Array: + dtype: DTypeLikeInt | None = None) -> Array: r"""Sample Geometric random values with given shape and float dtype. The values are returned according to the probability mass function: @@ -2345,16 +2543,16 @@ def geometric(key: ArrayLike, ``shape`` is not None, or else by ``p.shape``. """ key, _ = _check_prng_key("geometric", key) - dtypes.check_user_dtype_supported(dtype) + dtype = dtypes.check_and_canonicalize_user_dtype( + int if dtype is None else dtype) if not dtypes.issubdtype(dtype, np.integer): raise ValueError("dtype argument to `geometric` must be an int " f"dtype, got {dtype}") - dtype = dtypes.canonicalize_dtype(dtype) if shape is not None: shape = core.canonicalize_shape(shape) return _geometric(key, p, shape, dtype) -@partial(jit, static_argnums=(2, 3)) +@jit(static_argnums=(2, 3)) def _geometric(key, p, shape, dtype) -> Array: if shape is None: shape = np.shape(p) @@ -2375,7 +2573,7 @@ def triangular(key: ArrayLike, mode: RealArray, right: RealArray, shape: Shape | None = None, - dtype: DTypeLikeFloat = float) -> Array: + dtype: DTypeLikeFloat | None = None) -> Array: r"""Sample Triangular random values with given shape and float dtype. The values are returned according to the probability density function: @@ -2407,16 +2605,16 @@ def triangular(key: ArrayLike, ``shape`` is not None, or else by ``left.shape``, ``mode.shape`` and ``right.shape``. """ key, _ = _check_prng_key("triangular", key) - dtypes.check_user_dtype_supported(dtype) + dtype = dtypes.check_and_canonicalize_user_dtype( + float if dtype is None else dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError("dtype argument to `triangular` must be a float " f"dtype, got {dtype}") - dtype = dtypes.canonicalize_dtype(dtype) if shape is not None: shape = core.canonicalize_shape(shape) return _triangular(key, left, mode, right, shape, dtype) -@partial(jit, static_argnums=(4, 5), inline=True) +@jit(static_argnums=(4, 5), inline=True) def _triangular(key, left, mode, right, shape, dtype) -> Array: # https://en.wikipedia.org/wiki/Triangular_distribution#Generating_triangular-distributed_random_variates if shape is None: @@ -2437,7 +2635,7 @@ def _triangular(key, left, mode, right, shape, dtype) -> Array: def lognormal(key: ArrayLike, sigma: RealArray = np.float32(1), shape: Shape | None = None, - dtype: DTypeLikeFloat = float) -> Array: + dtype: DTypeLikeFloat | None = None) -> Array: r""" Sample lognormal random values with given shape and float dtype. The values are distributed according to the probability density function: @@ -2460,16 +2658,16 @@ def lognormal(key: ArrayLike, A random array with the specified dtype and with shape given by ``shape``. """ key, _ = _check_prng_key("lognormal", key) - dtypes.check_user_dtype_supported(dtype) + dtype = dtypes.check_and_canonicalize_user_dtype( + float if dtype is None else dtype) if not dtypes.issubdtype(dtype, np.inexact): raise ValueError(f"dtype argument to `lognormal` must be a float or complex dtype, " f"got {dtype}") - dtype = dtypes.canonicalize_dtype(dtype) if shape is not None: shape = core.canonicalize_shape(shape) return _lognormal(key, sigma, shape, dtype) -@partial(jit, static_argnums=(2, 3), inline=True) +@jit(static_argnums=(2, 3), inline=True) def _lognormal(key, sigma, shape, dtype) -> Array: if shape is None: shape = np.shape(sigma) @@ -2497,17 +2695,18 @@ def _stirling_approx_tail(k): dtype=k.dtype, ) use_tail_values = k <= 9 - k = lax.clamp(_lax_const(k, 0.0), k, _lax_const(k, 9.0)) + k = lax.clamp(lax._const(k, 0.0), k, lax._const(k, 9.0)) kp1sq = (k + 1) * (k + 1) approx = (1.0 / 12 - (1.0 / 360 - 1.0 / 1260 / kp1sq) / kp1sq) / (k + 1) k = jnp.floor(k) - return lax.select(use_tail_values, stirling_tail_vals[jnp.int32(k)], approx) + return lax.select( + use_tail_values, stirling_tail_vals[jnp.asarray(k, dtype='int32')], approx) -@partial(jit, static_argnums=(3, 4, 5), inline=True) +@jit(static_argnums=(3, 4, 5), inline=True) def _binomial_inversion(key, count, prob, shape, dtype, max_iters): if config.enable_checks.value: - assert jnp.issubdtype(prob.dtype, jnp.floating) + assert dtypes.issubdtype(prob.dtype, np.floating) log1minusprob = jnp.log1p(-prob) @@ -2527,11 +2726,11 @@ def cond_fn(carry): num_geom_init = lax.full_like(prob, 0, prob.dtype, shape) geom_sum_init = lax.full_like(prob, 0, prob.dtype, shape) carry = (0, num_geom_init, geom_sum_init, key) - k = lax.while_loop(cond_fn, body_fn, carry)[1] + k = lax_control_flow.while_loop(cond_fn, body_fn, carry)[1] return (k - 1).astype(dtype) -@partial(jit, static_argnums=(3, 4, 5), inline=True) +@jit(static_argnums=(3, 4, 5), inline=True) def _btrs(key, count, prob, shape, dtype, max_iters): # transforman-rejection algorithm # https://www.tandfonline.com/doi/abs/10.1080/00949659308811496 @@ -2575,18 +2774,18 @@ def cond_fn(carry): return (~accepted).any() & (i < max_iters) k_init = lax.full_like(prob, -1, prob.dtype, shape) - carry = (0, k_init, jnp.full(shape, False, jnp.bool_), key) - return lax.while_loop(cond_fn, body_fn, carry)[1].astype(dtype) + carry = (0, k_init, jnp.full(shape, False, bool), key) + return lax_control_flow.while_loop(cond_fn, body_fn, carry)[1].astype(dtype) -@partial(jit, static_argnums=(3, 4), inline=True) +@jit(static_argnums=(3, 4), inline=True) def _binomial(key, count, prob, shape, dtype) -> Array: # The implementation matches TensorFlow and TensorFlow Probability: # https://github.com/tensorflow/tensorflow/blob/v2.2.0-rc3/tensorflow/core/kernels/random_binomial_op.cc # and tensorflow_probability.substrates.jax.distributions.Binomial # For n * p < 10, we use the binomial inverse algorithm; otherwise btrs. if shape is None: - shape = jnp.broadcast_shapes(jnp.shape(count), jnp.shape(prob)) + shape = jnp.broadcast_shapes(np.shape(count), np.shape(prob)) else: _check_shape("binomial", shape, np.shape(count), np.shape(prob)) (prob,) = promote_dtypes_inexact(prob) @@ -2608,7 +2807,7 @@ def _binomial(key, count, prob, shape, dtype) -> Array: count_inv = lax.select(use_inversion, count, lax.full_like(count, 0.0)) count_btrs = lax.select(use_inversion, lax.full_like(count, 1e4), count) q_btrs = lax.select(use_inversion, lax.full_like(q, 0.5), q) - max_iters = dtype.type(jnp.finfo(dtype).max) + max_iters = dtype.type(dtypes.finfo(dtype).max) samples = lax.select( use_inversion, _binomial_inversion(key, count_inv, q, shape, dtype, max_iters), @@ -2619,14 +2818,14 @@ def _binomial(key, count, prob, shape, dtype) -> Array: invalid = (q_l_0 | q_is_nan | count_nan_or_neg) samples = lax.select( invalid, - jnp.full_like(samples, jnp.nan, dtype), + jnp.full_like(samples, np.nan, dtype), samples, ) # +inf count leads to inf samples = lax.select( count_inf & (~invalid), - jnp.full_like(samples, jnp.inf, dtype), + jnp.full_like(samples, np.inf, dtype), samples, ) @@ -2643,7 +2842,7 @@ def binomial( n: RealArray, p: RealArray, shape: Shape | None = None, - dtype: DTypeLikeFloat = float, + dtype: DTypeLikeFloat | None = None, ) -> Array: r"""Sample Binomial random values with given shape and float dtype. @@ -2674,12 +2873,12 @@ def binomial( """ key, _ = _check_prng_key("binomial", key) check_arraylike("binomial", n, p) - dtypes.check_user_dtype_supported(dtype) + dtype = dtypes.check_and_canonicalize_user_dtype( + float if dtype is None else dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError( f"dtype argument to `binomial` must be a float dtype, got {dtype}" ) - dtype = dtypes.canonicalize_dtype(dtype) if shape is not None: shape = core.canonicalize_shape(shape) return _binomial(key, n, p, shape, dtype) @@ -2699,7 +2898,7 @@ def multinomial( p: RealArray, *, shape: Shape | None = None, - dtype: DTypeLikeFloat = float, + dtype: DTypeLikeFloat | None = None, unroll: int | bool = 1, ): r"""Sample from a multinomial distribution. @@ -2743,11 +2942,11 @@ def f(remainder, ratio_key): p = jnp.moveaxis(p, -1, 0) - remaining_probs = lax.cumsum(p, 0, reverse=True) + remaining_probs = lax_control_flow.cumsum(p, 0, reverse=True) ratios = p / jnp.where(remaining_probs == 0, 1, remaining_probs) keys = split(key, ratios.shape[0]) - remainder, counts = lax.scan(f, n, (ratios, keys), unroll=unroll) + remainder, counts = lax_control_flow.scan(f, n, (ratios, keys), unroll=unroll) # final remainder should be zero return jnp.moveaxis(counts, 0, -1).astype(dtype) @@ -2769,3 +2968,29 @@ def clone(key): >>> assert data == same_data """ return random_clone_p.bind(key) + + +def random_insert_pvary(name, key, *args): + if not config._check_vma.value: + return key, args + if not args: + return key, args + key_vma = core.typeof(key).vma + out = [] + for a in args: + arg_vma = (aval.vma if isinstance(aval := core.typeof(a), core.ShapedArray) + else frozenset()) + # If key is less varying than the args, then it's an error and user should + # pvary at their level because it has key-reuse implications. They can + # shard the keys passed to shard_map correctly so as to avoid key-reuse + # getting correctly varying keys. But JAX shouldn't auto-pvary the key. + if key_vma - arg_vma: + a = core.pvary(a, tuple(k for k in key_vma if k not in arg_vma)) + if key_vma != core.typeof(a).vma: + raise TypeError( + f"{name} requires all arguments to have matching type. Got key type:" + f" {core.typeof(key)} vs arg type: {core.typeof(a)}. Use" + " jax.lax.pcast(..., to='varying') to make them match. If your key is" + " less varying than arg, watch out for key-reuse problems.") + out.append(a) + return key, out diff --git a/jax/_src/ref.py b/jax/_src/ref.py new file mode 100644 index 000000000000..bbd3a57c8393 --- /dev/null +++ b/jax/_src/ref.py @@ -0,0 +1,34 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any +from jax._src import core + + +def new_ref(init_val: Any, *, memory_space: Any = None) -> core.Ref: + """Create a mutable array reference with initial value ``init_val``. + + For more discussion, see the `Ref guide`_. + + Args: + init_val: A :class:`jax.Array` representing the initial state + of the buffer. + memory_space: An optional memory space attribute for the Ref. + + Returns: + A :class:`jax.ref.Ref` containing a reference to a mutable buffer. + + .. _Ref guide: https://docs.jax.dev/en/latest/array_refs.html + """ + return core.new_ref(init_val, memory_space=memory_space) diff --git a/jax/_src/scipy/cluster/vq.py b/jax/_src/scipy/cluster/vq.py index a82c8928644d..3d93351adf12 100644 --- a/jax/_src/scipy/cluster/vq.py +++ b/jax/_src/scipy/cluster/vq.py @@ -1,74 +1,75 @@ -# Copyright 2022 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from __future__ import annotations - -import operator - -from jax import vmap -import jax.numpy as jnp -from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact -from jax._src.typing import Array, ArrayLike - - -def vq(obs: ArrayLike, code_book: ArrayLike, check_finite: bool = True) -> tuple[Array, Array]: - """Assign codes from a code book to a set of observations. - - JAX implementation of :func:`scipy.cluster.vq.vq`. - - Assigns each observation vector in ``obs`` to a code from ``code_book`` - based on the nearest Euclidean distance. - - Args: - obs: array of observation vectors of shape ``(M, N)``. Each row represents - a single observation. If ``obs`` is one-dimensional, then each entry is - treated as a length-1 observation. - code_book: array of codes with shape ``(K, N)``. Each row represents a single - code vector. If ``code_book`` is one-dimensional, then each entry is treated - as a length-1 code. - check_finite: unused in JAX - - Returns: - A tuple of arrays ``(code, dist)`` - - - ``code`` is an integer array of shape ``(M,)`` containing indices ``0 <= i < K`` - of the closest entry in ``code_book`` for the given entry in ``obs``. - - ``dist`` is a float array of shape ``(M,)`` containing the euclidean - distance between each observation and the nearest code. - - Examples: - >>> obs = jnp.array([[1.1, 2.1, 3.1], - ... [5.9, 4.8, 6.2]]) - >>> code_book = jnp.array([[1., 2., 3.], - ... [2., 3., 4.], - ... [3., 4., 5.], - ... [4., 5., 6.]]) - >>> codes, distances = jax.scipy.cluster.vq.vq(obs, code_book) - >>> print(codes) - [0 3] - >>> print(distances) - [0.17320499 1.9209373 ] - """ - del check_finite # unused - check_arraylike("scipy.cluster.vq.vq", obs, code_book) - obs_arr, cb_arr = promote_dtypes_inexact(obs, code_book) - if obs_arr.ndim != cb_arr.ndim: - raise ValueError("Observation and code_book should have the same rank") - if obs_arr.ndim == 1: - obs_arr, cb_arr = obs_arr[..., None], cb_arr[..., None] - if obs_arr.ndim != 2: - raise ValueError("ndim different than 1 or 2 are not supported") - dist = vmap(lambda ob: jnp.linalg.norm(ob[None] - cb_arr, axis=-1))(obs_arr) - code = jnp.argmin(dist, axis=-1) - dist_min = vmap(operator.getitem)(dist, code) - return code, dist_min +# Copyright 2022 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import operator + +from jax._src import api +from jax._src import numpy as jnp +from jax._src.numpy import linalg as jnp_linalg +from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact +from jax._src.typing import Array, ArrayLike + + +def vq(obs: ArrayLike, code_book: ArrayLike, check_finite: bool = True) -> tuple[Array, Array]: + """Assign codes from a code book to a set of observations. + + JAX implementation of :func:`scipy.cluster.vq.vq`. + + Assigns each observation vector in ``obs`` to a code from ``code_book`` + based on the nearest Euclidean distance. + + Args: + obs: array of observation vectors of shape ``(M, N)``. Each row represents + a single observation. If ``obs`` is one-dimensional, then each entry is + treated as a length-1 observation. + code_book: array of codes with shape ``(K, N)``. Each row represents a single + code vector. If ``code_book`` is one-dimensional, then each entry is treated + as a length-1 code. + check_finite: unused in JAX + + Returns: + A tuple of arrays ``(code, dist)`` + + - ``code`` is an integer array of shape ``(M,)`` containing indices ``0 <= i < K`` + of the closest entry in ``code_book`` for the given entry in ``obs``. + - ``dist`` is a float array of shape ``(M,)`` containing the euclidean + distance between each observation and the nearest code. + + Examples: + >>> obs = jnp.array([[1.1, 2.1, 3.1], + ... [5.9, 4.8, 6.2]]) + >>> code_book = jnp.array([[1., 2., 3.], + ... [2., 3., 4.], + ... [3., 4., 5.], + ... [4., 5., 6.]]) + >>> codes, distances = jax.scipy.cluster.vq.vq(obs, code_book) + >>> print(codes) + [0 3] + >>> print(distances) + [0.17320499 1.9209373 ] + """ + del check_finite # unused + check_arraylike("scipy.cluster.vq.vq", obs, code_book) + obs_arr, cb_arr = promote_dtypes_inexact(obs, code_book) + if obs_arr.ndim != cb_arr.ndim: + raise ValueError("Observation and code_book should have the same rank") + if obs_arr.ndim == 1: + obs_arr, cb_arr = obs_arr[..., None], cb_arr[..., None] + if obs_arr.ndim != 2: + raise ValueError("ndim different than 1 or 2 are not supported") + dist = api.vmap(lambda ob: jnp_linalg.norm(ob[None] - cb_arr, axis=-1))(obs_arr) + code = jnp.argmin(dist, axis=-1) + dist_min = api.vmap(operator.getitem)(dist, code) + return code, dist_min diff --git a/jax/_src/scipy/fft.py b/jax/_src/scipy/fft.py index bac5f776a0c7..7a99366fb50f 100644 --- a/jax/_src/scipy/fft.py +++ b/jax/_src/scipy/fft.py @@ -17,16 +17,21 @@ from collections.abc import Sequence from functools import partial import math +import operator -from jax import lax -import jax.numpy as jnp -from jax._src.util import canonicalize_axis -from jax._src.numpy.util import promote_dtypes_complex, promote_dtypes_inexact +import numpy as np + +from jax._src import lax +from jax._src import numpy as jnp +from jax._src.numpy import fft as jnp_fft +from jax._src.numpy.util import ( + promote_dtypes_complex, promote_dtypes_inexact, ensure_arraylike) +from jax._src.util import canonicalize_axis, canonicalize_axis_tuple from jax._src.typing import Array def _W4(N: int, k: Array) -> Array: N_arr, k = promote_dtypes_complex(N, k) - return jnp.exp(-.5j * jnp.pi * k / N_arr) + return jnp.exp(-.5j * np.pi * k / N_arr) def _dct_interleave(x: Array, axis: int) -> Array: v0 = lax.slice_in_dim(x, None, None, 2, axis) @@ -97,6 +102,8 @@ def dct(x: Array, type: int = 2, n: int | None = None, [-1.75 0.73 1.01 -2.18] [ 1.33 -1.05 -2.34 -0.07]] """ + x = ensure_arraylike("idctn", x) + if type != 2: raise NotImplementedError('Only DCT type 2 is implemented.') if norm is not None and norm not in ['backward', 'ortho']: @@ -110,7 +117,7 @@ def dct(x: Array, type: int = 2, n: int | None = None, N = x.shape[axis] v = _dct_interleave(x, axis) - V = jnp.fft.fft(v, axis=axis) + V = jnp_fft.fft(v, axis=axis) k = lax.expand_dims(jnp.arange(N, dtype=V.real.dtype), [a for a in range(x.ndim) if a != axis]) out = V * _W4(N, k) out = 2 * out.real @@ -123,7 +130,7 @@ def _dct2(x: Array, axes: Sequence[int], norm: str | None) -> Array: axis1, axis2 = map(partial(canonicalize_axis, num_dims=x.ndim), axes) N1, N2 = x.shape[axis1], x.shape[axis2] v = _dct_interleave(_dct_interleave(x, axis1), axis2) - V = jnp.fft.fftn(v, axes=axes) + V = jnp_fft.fftn(v, axes=axes) k1 = lax.expand_dims(jnp.arange(N1, dtype=V.dtype), [a for a in range(x.ndim) if a != axis1]) k2 = lax.expand_dims(jnp.arange(N2, dtype=V.dtype), @@ -197,14 +204,22 @@ def dctn(x: Array, type: int = 2, ... print(jax.scipy.fft.dctn(x, s=[2, 4])) [[ 9.36 11.23 2.12 -10.97] [ 11.57 5.86 -1.37 -1.58]] -""" + """ + x = ensure_arraylike("idctn", x) + if type != 2: raise NotImplementedError('Only DCT type 2 is implemented.') if norm is not None and norm not in ['backward', 'ortho']: raise ValueError(f"jax.scipy.fft.dctn: {norm=!r} is not implemented") - if axes is None: - axes = range(x.ndim) + if s is not None: + try: + s = list(s) + except TypeError: + assert not isinstance(s, Sequence) + s = [operator.index(s)] + + axes = canonicalize_axis_tuple(axes, x.ndim) if len(axes) == 1: return dct(x, n=s[0] if s is not None else None, axis=axes[0], norm=norm) @@ -224,7 +239,7 @@ def dctn(x: Array, type: int = 2, def idct(x: Array, type: int = 2, n: int | None = None, - axis: int = -1, norm: str | None = None) -> Array: + axis: int = -1, norm: str | None = None) -> Array: """Computes the inverse discrete cosine transform of the input JAX implementation of :func:`scipy.fft.idct`. @@ -287,6 +302,8 @@ def idct(x: Array, type: int = 2, n: int | None = None, >>> jnp.allclose(x, jax.scipy.fft.idct(x_dct)) Array(True, dtype=bool) """ + x = ensure_arraylike("idct", x) + if type != 2: raise NotImplementedError('Only DCT type 2 is implemented.') if norm is not None and norm not in ['backward', 'ortho']: @@ -310,7 +327,7 @@ def idct(x: Array, type: int = 2, n: int | None = None, x = x / (_W4(N, k)) x = x * 2 * N - x = jnp.fft.ifft(x, axis=axis) + x = jnp_fft.ifft(x, axis=axis) # convert back to reals.. out = _dct_deinterleave(x.real, axis) return out @@ -386,13 +403,21 @@ def idctn(x: Array, type: int = 2, >>> jnp.allclose(x, jax.scipy.fft.idctn(x_dctn)) Array(True, dtype=bool) """ + x = ensure_arraylike("idctn", x) + if type != 2: raise NotImplementedError('Only DCT type 2 is implemented.') if norm is not None and norm not in ['backward', 'ortho']: raise ValueError(f"jax.scipy.fft.idctn: {norm=!r} is not implemented") - if axes is None: - axes = range(x.ndim) + if s is not None: + try: + s = list(s) + except TypeError: + assert not isinstance(s, Sequence) + s = [operator.index(s)] + + axes = canonicalize_axis_tuple(axes, x.ndim) if len(axes) == 1: return idct(x, n=s[0] if s is not None else None, axis=axes[0], norm=norm) diff --git a/jax/_src/scipy/integrate.py b/jax/_src/scipy/integrate.py index b61cdb163b8d..295d3d6a276e 100644 --- a/jax/_src/scipy/integrate.py +++ b/jax/_src/scipy/integrate.py @@ -14,14 +14,12 @@ from __future__ import annotations -from functools import partial - -from jax import jit +from jax._src.api import jit +from jax._src.numpy import lax_numpy from jax._src.typing import Array, ArrayLike -import jax.numpy as jnp -@partial(jit, static_argnames=('axis',)) +@jit(static_argnames=('axis',)) def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0, axis: int = -1) -> Array: r""" @@ -66,4 +64,4 @@ def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0, >>> jnp.allclose(result, jnp.pi) Array(True, dtype=bool) """ - return jnp.trapezoid(y, x, dx, axis) + return lax_numpy.trapezoid(y, x, dx, axis) diff --git a/jax/_src/scipy/linalg.py b/jax/_src/scipy/linalg.py index 9917cbaa0b12..e945610430dc 100644 --- a/jax/_src/scipy/linalg.py +++ b/jax/_src/scipy/linalg.py @@ -15,21 +15,23 @@ from __future__ import annotations from functools import partial - -import numpy as np import textwrap from typing import overload, Any, Literal -import jax -import jax.numpy as jnp -from jax import jit, vmap, jvp -from jax import lax +import numpy as np + +from jax._src import config from jax._src import dtypes +from jax._src import lax +from jax._src import numpy as jnp +from jax._src.api import jit, vmap, jvp from jax._src.lax import linalg as lax_linalg -from jax._src.lax import qdwh +from jax._src.numpy import linalg as jnp_linalg +from jax._src.numpy import vectorize as jnp_vectorize from jax._src.numpy.util import ( check_arraylike, promote_dtypes, promote_dtypes_inexact, - promote_dtypes_complex) + promote_dtypes_complex, promote_args_inexact) +from jax._src.tpu.linalg import qdwh from jax._src.typing import Array, ArrayLike @@ -39,7 +41,7 @@ """) _no_overwrite_and_chkfinite_doc = _no_chkfinite_doc + "\nDoes not support the Scipy argument ``overwrite_*=True``." -@partial(jit, static_argnames=('lower',)) +@jit(static_argnames=('lower',)) def _cholesky(a: ArrayLike, lower: bool) -> Array: a, = promote_dtypes_inexact(jnp.asarray(a)) l = lax_linalg.cholesky(a if lower else jnp.conj(a.mT), symmetrize_input=False) @@ -153,7 +155,7 @@ def cho_factor(a: ArrayLike, lower: bool = False, overwrite_a: bool = False, del overwrite_a, check_finite # Unused return (cholesky(a, lower=lower), lower) -@partial(jit, static_argnames=('lower',)) +@jit(static_argnames=('lower',)) def _cho_solve(c: ArrayLike, b: ArrayLike, lower: bool) -> Array: c, b = promote_dtypes_inexact(jnp.asarray(c), jnp.asarray(b)) lax_linalg._check_solve_shapes(c, b) @@ -219,7 +221,7 @@ def _svd(x: ArrayLike, *, full_matrices: bool, compute_uv: Literal[False]) -> Ar @overload def _svd(x: ArrayLike, *, full_matrices: bool, compute_uv: bool) -> Array | tuple[Array, Array, Array]: ... -@partial(jit, static_argnames=('full_matrices', 'compute_uv')) +@jit(static_argnames=('full_matrices', 'compute_uv')) def _svd(a: ArrayLike, *, full_matrices: bool, compute_uv: bool) -> Array | tuple[Array, Array, Array]: a, = promote_dtypes_inexact(jnp.asarray(a)) return lax_linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv) @@ -352,7 +354,7 @@ def det(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True) -> A Array([-2., 37.], dtype=float32) """ del overwrite_a, check_finite # unused - return jnp.linalg.det(a) + return jnp_linalg.det(a) @overload @@ -367,7 +369,7 @@ def _eigh(a: ArrayLike, b: ArrayLike | None, lower: bool, eigvals_only: Literal[ def _eigh(a: ArrayLike, b: ArrayLike | None, lower: bool, eigvals_only: bool, eigvals: None, type: int) -> Array | tuple[Array, Array]: ... -@partial(jit, static_argnames=('lower', 'eigvals_only', 'eigvals', 'type')) +@jit(static_argnames=('lower', 'eigvals_only', 'eigvals', 'type')) def _eigh(a: ArrayLike, b: ArrayLike | None, lower: bool, eigvals_only: bool, eigvals: None, type: int) -> Array | tuple[Array, Array]: if b is not None: @@ -418,22 +420,20 @@ def eigh(a: ArrayLike, b: ArrayLike | None = None, lower: bool = True, JAX implementation of :func:`scipy.linalg.eigh`. + Only the standard eigenvalue problem is supported: ``a @ v = lambda * v``. + The parameter `b` must be None; the generalized problem (``a @ v = lambda * b @ v``) + is not implemented. + Args: a: Hermitian input array of shape ``(..., N, N)`` - b: optional Hermitian input of shape ``(..., N, N)``. If specified, compute - the generalized eigenvalue problem. + b: Must be None. The generalized eigenvalue problem is not supported. lower: if True (default) access only the lower portion of the input matrix. Otherwise access only the upper portion. eigvals_only: If True, compute only the eigenvalues. If False (default) compute both eigenvalues and eigenvectors. - type: if ``b`` is specified, ``type`` gives the type of generalized eigenvalue - problem to be computed. Denoting ``(λ, v)`` as an eigenvalue, eigenvector pair: + type: Not used. Only type=1 is supported. - - ``type = 1`` solves ``a @ v = λ * b @ v`` (default) - - ``type = 2`` solves ``a @ b @ v = λ * v`` - - ``type = 3`` solves ``b @ a @ v = λ * v`` - - eigvals: a ``(low, high)`` tuple specifying which eigenvalues to compute. + eigvals: Not used. Only eigvals=None is supported. overwrite_a: unused by JAX. overwrite_b: unused by JAX. turbo: unused by JAX. @@ -446,6 +446,9 @@ def eigh(a: ArrayLike, b: ArrayLike | None = None, lower: bool = True, - ``eigvals``: array of shape ``(..., N)`` containing the eigenvalues. - ``eigvecs``: array of shape ``(..., N, N)`` containing the eigenvectors. + Raise: + NotImplementedError: If `b` is not None. + See also: - :func:`jax.numpy.linalg.eigh`: NumPy-style eigh API. - :func:`jax.lax.linalg.eigh`: XLA-style eigh API. @@ -477,7 +480,7 @@ def eigh(a: ArrayLike, b: ArrayLike | None = None, lower: bool = True, del overwrite_a, overwrite_b, turbo, check_finite # unused return _eigh(a, b, lower, eigvals_only, eigvals, type) -@partial(jit, static_argnames=('output',)) +@jit(static_argnames=('output',)) def _schur(a: Array, output: str) -> tuple[Array, Array]: if output == "complex": a = a.astype(dtypes.to_complex_dtype(a.dtype)) @@ -486,6 +489,8 @@ def _schur(a: Array, output: str) -> tuple[Array, Array]: def schur(a: ArrayLike, output: str = 'real') -> tuple[Array, Array]: """Compute the Schur decomposition + Only implemented on CPU. + JAX implementation of :func:`scipy.linalg.schur`. The Schur form `T` of a matrix `A` satisfies: @@ -602,10 +607,10 @@ def inv(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True) -> A Array([ 0. , 1.25, -0.5 ], dtype=float32) """ del overwrite_a, check_finite # unused - return jnp.linalg.inv(a) + return jnp_linalg.inv(a) -@partial(jit, static_argnames=('overwrite_a', 'check_finite')) +@jit(static_argnames=('overwrite_a', 'check_finite')) def lu_factor(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True) -> tuple[Array, Array]: """Factorization for LU-based linear solves @@ -657,7 +662,7 @@ def lu_factor(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True return lu, pivots -@partial(jit, static_argnames=('trans', 'overwrite_b', 'check_finite')) +@jit(static_argnames=('trans', 'overwrite_b', 'check_finite')) def lu_solve(lu_and_piv: tuple[Array, ArrayLike], b: ArrayLike, trans: int = 0, overwrite_b: bool = False, check_finite: bool = True) -> Array: """Solve a linear system using an LU factorization @@ -722,12 +727,12 @@ def _lu(a: ArrayLike, permute_l: Literal[False]) -> tuple[Array, Array, Array]: @overload def _lu(a: ArrayLike, permute_l: bool) -> tuple[Array, Array] | tuple[Array, Array, Array]: ... -@partial(jit, static_argnums=(1,)) +@jit(static_argnums=(1,)) def _lu(a: ArrayLike, permute_l: bool) -> tuple[Array, Array] | tuple[Array, Array, Array]: a, = promote_dtypes_inexact(jnp.asarray(a)) lu, _, permutation = lax_linalg.lu(a) dtype = lax.dtype(a) - m, n = jnp.shape(a) + m, n = np.shape(a) p = jnp.real(jnp.array(permutation[None, :] == jnp.arange(m, dtype=permutation.dtype)[:, None], dtype=dtype)) k = min(m, n) l = jnp.tril(lu, -1)[:, :k] + jnp.eye(m, k, dtype=dtype) @@ -750,7 +755,7 @@ def lu(a: ArrayLike, permute_l: bool = False, overwrite_a: bool = False, check_finite: bool = True) -> tuple[Array, Array] | tuple[Array, Array, Array]: ... -@partial(jit, static_argnames=('permute_l', 'overwrite_a', 'check_finite')) +@jit(static_argnames=('permute_l', 'overwrite_a', 'check_finite')) def lu(a: ArrayLike, permute_l: bool = False, overwrite_a: bool = False, check_finite: bool = True) -> tuple[Array, Array] | tuple[Array, Array, Array]: """Compute the LU decomposition @@ -851,7 +856,7 @@ def _qr(a: ArrayLike, mode: str, pivoting: bool ) -> tuple[Array] | tuple[Array, Array] | tuple[Array, Array, Array]: ... -@partial(jit, static_argnames=('mode', 'pivoting')) +@jit(static_argnames=('mode', 'pivoting')) def _qr(a: ArrayLike, mode: str, pivoting: bool ) -> tuple[Array] | tuple[Array, Array] | tuple[Array, Array, Array]: if mode in ("full", "r"): @@ -992,10 +997,10 @@ def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: str = " return _qr(a, mode, pivoting) -@partial(jit, static_argnames=('assume_a', 'lower')) +@jit(static_argnames=('assume_a', 'lower')) def _solve(a: ArrayLike, b: ArrayLike, assume_a: str, lower: bool) -> Array: if assume_a != 'pos': - return jnp.linalg.solve(a, b) + return jnp_linalg.solve(a, b) a, b = promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b)) lax_linalg._check_solve_shapes(a, b) @@ -1079,7 +1084,7 @@ def solve(a: ArrayLike, b: ArrayLike, lower: bool = False, raise ValueError(f"Expected assume_a to be one of {valid_assume_a}; got {assume_a!r}") return _solve(a, b, assume_a, lower) -@partial(jit, static_argnames=('trans', 'lower', 'unit_diagonal')) +@jit(static_argnames=('trans', 'lower', 'unit_diagonal')) def _solve_triangular(a: ArrayLike, b: ArrayLike, trans: int | str, lower: bool, unit_diagonal: bool) -> Array: if trans == 0 or trans == "N": @@ -1094,7 +1099,7 @@ def _solve_triangular(a: ArrayLike, b: ArrayLike, trans: int | str, a, b = promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b)) # lax_linalg.triangular_solve only supports matrix 'b's at the moment. - b_is_vector = jnp.ndim(a) == jnp.ndim(b) + 1 + b_is_vector = np.ndim(a) == np.ndim(b) + 1 if b_is_vector: b = b[..., None] out = lax_linalg.triangular_solve(a, b, left_side=True, lower=lower, @@ -1172,7 +1177,7 @@ def solve_triangular(a: ArrayLike, b: ArrayLike, trans: int | str = 0, lower: bo return _solve_triangular(a, b, trans, lower, unit_diagonal) -@partial(jit, static_argnames=('upper_triangular', 'max_squarings')) +@jit(static_argnames=('upper_triangular', 'max_squarings')) def expm(A: ArrayLike, *, upper_triangular: bool = False, max_squarings: int = 16) -> Array: """Compute the matrix exponential @@ -1228,7 +1233,7 @@ def expm(A: ArrayLike, *, upper_triangular: bool = False, max_squarings: int = 1 raise ValueError(f"Expected A to be a (batched) square matrix, got {A.shape=}.") if A.ndim > 2: - return jnp.vectorize( + return jnp_vectorize.vectorize( partial(expm, upper_triangular=upper_triangular, max_squarings=max_squarings), signature="(n,n)->(n,n)")(A) @@ -1236,7 +1241,7 @@ def expm(A: ArrayLike, *, upper_triangular: bool = False, max_squarings: int = 1 def _nan(args): A, *_ = args - return jnp.full_like(A, jnp.nan) + return jnp.full_like(A, np.nan) def _compute(args): A, P, Q = args @@ -1251,7 +1256,7 @@ def _compute(args): def _calc_P_Q(A: Array) -> tuple[Array, Array, Array]: if A.ndim != 2 or A.shape[0] != A.shape[1]: raise ValueError('expected A to be a square matrix') - A_L1 = jnp.linalg.norm(A,1) + A_L1 = jnp_linalg.norm(A,1) n_squarings: Array U: Array V: Array @@ -1282,12 +1287,12 @@ def _solve_P_Q(P: ArrayLike, Q: ArrayLike, upper_triangular: bool = False) -> Ar if upper_triangular: return solve_triangular(Q, P) else: - return jnp.linalg.solve(Q, P) + return jnp_linalg.solve(Q, P) def _precise_dot(A: ArrayLike, B: ArrayLike) -> Array: return jnp.dot(A, B, precision=lax.Precision.HIGHEST) -@partial(jit, static_argnums=2) +@jit(static_argnums=2) def _squaring(R: Array, n_squarings: Array, max_squarings: int) -> Array: # squaring step to undo scaling def _squaring_precise(x): @@ -1372,7 +1377,7 @@ def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, compute_expm: bool = True) -> Array | tuple[Array, Array]: ... -@partial(jit, static_argnames=('method', 'compute_expm')) +@jit(static_argnames=('method', 'compute_expm')) def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, compute_expm: bool = True) -> Array | tuple[Array, Array]: """Compute the Frechet derivative of the matrix exponential. @@ -1455,25 +1460,28 @@ def block_diag(*arrs: ArrayLike) -> Array: [0., 0., 0., 1., 1., 1.]], dtype=float32) """ if len(arrs) == 0: - arrs = (jnp.zeros((1, 0)),) + arrs = (jnp.zeros((1, 0)),) arrs = tuple(promote_dtypes(*arrs)) - bad_shapes = [i for i, a in enumerate(arrs) if jnp.ndim(a) > 2] + bad_shapes = [i for i, a in enumerate(arrs) if np.ndim(a) > 2] if bad_shapes: raise ValueError("Arguments to jax.scipy.linalg.block_diag must have at " "most 2 dimensions, got {} at argument {}." .format(arrs[bad_shapes[0]], bad_shapes[0])) converted_arrs = [jnp.atleast_2d(a) for a in arrs] - acc = converted_arrs[0] - dtype = lax.dtype(acc) - for a in converted_arrs[1:]: - _, c = a.shape - a = lax.pad(a, dtype.type(0), ((0, 0, 0), (acc.shape[-1], 0, 0))) - acc = lax.pad(acc, dtype.type(0), ((0, 0, 0), (0, c, 0))) - acc = lax.concatenate([acc, a], dimension=0) - return acc + dtype = lax.dtype(converted_arrs[0]) + total_cols = sum(a.shape[1] for a in converted_arrs) + + padded_arrs = [] + current_col = 0 + for arr in converted_arrs: + cols = arr.shape[1] + padding_config = ((0, 0, 0), (current_col, total_cols - cols - current_col, 0)) + padded_arrs.append(lax.pad(arr, dtype.type(0), padding_config)) + current_col += cols + return jnp.concatenate(padded_arrs, axis=0) -@partial(jit, static_argnames=("eigvals_only", "select", "select_range")) +@jit(static_argnames=("eigvals_only", "select", "select_range")) def eigh_tridiagonal(d: ArrayLike, e: ArrayLike, *, eigvals_only: bool = False, select: str = 'a', select_range: tuple[float, float] | None = None, tol: float | None = None) -> Array: @@ -1527,8 +1535,8 @@ def eigh_tridiagonal(d: ArrayLike, e: ArrayLike, *, eigvals_only: bool = False, def _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, x): """Implements the Sturm sequence recurrence.""" n = alpha.shape[0] - zeros = jnp.zeros(x.shape, dtype=jnp.int32) - ones = jnp.ones(x.shape, dtype=jnp.int32) + zeros = jnp.zeros(x.shape, dtype=np.int32) + ones = jnp.ones(x.shape, dtype=np.int32) # The first step in the Sturm sequence recurrence # requires special care if x is equal to alpha[0]. @@ -1574,7 +1582,7 @@ def cond(iqc): alpha = jnp.asarray(d) beta = jnp.asarray(e) - supported_dtypes = (jnp.float32, jnp.float64, jnp.complex64, jnp.complex128) + supported_dtypes = (np.float32, np.float64, np.complex64, np.complex128) if alpha.dtype != beta.dtype: raise TypeError("diagonal and off-diagonal values must have same dtype, " f"got {alpha.dtype} and {beta.dtype}") @@ -1586,7 +1594,7 @@ def cond(iqc): if n <= 1: return jnp.real(alpha) - if jnp.issubdtype(alpha.dtype, np.complexfloating): + if dtypes.issubdtype(alpha.dtype, np.complexfloating): alpha = jnp.real(alpha) beta_sq = jnp.real(beta * jnp.conj(beta)) beta_abs = jnp.sqrt(beta_sq) @@ -1623,13 +1631,13 @@ def cond(iqc): # Determine the indices of the desired eigenvalues, based on select and # select_range. if select == 'a': - target_counts = jnp.arange(n, dtype=jnp.int32) + target_counts = jnp.arange(n, dtype=np.int32) elif select == 'i': if select_range is None: raise ValueError("for select='i', select_range must be specified.") if select_range[0] > select_range[1]: raise ValueError('Got empty index range in select_range.') - target_counts = jnp.arange(select_range[0], select_range[1] + 1, dtype=jnp.int32) + target_counts = jnp.arange(select_range[0], select_range[1] + 1, dtype=np.int32) elif select == 'v': # TODO(phawkins): requires dynamic shape support. raise NotImplementedError("eigh_tridiagonal(..., select='v') is not " @@ -1647,7 +1655,7 @@ def cond(iqc): # Pre-broadcast the scalars used in the Sturm sequence for improved # performance. - target_shape = jnp.shape(target_counts) + target_shape = np.shape(target_counts) lower = jnp.broadcast_to(lower, shape=target_shape) upper = jnp.broadcast_to(upper, shape=target_shape) mid = 0.5 * (upper + lower) @@ -1672,8 +1680,8 @@ def body(args): _, _, mid, _ = lax.while_loop(cond, body, (0, lower, mid, upper)) return mid -@partial(jit, static_argnames=('side', 'method')) -@jax.default_matmul_precision("float32") +@jit(static_argnames=('side', 'method')) +@config.default_matmul_precision("float32") def polar(a: ArrayLike, side: str = 'right', *, method: str = 'qdwh', eps: float | None = None, max_iterations: int | None = None) -> tuple[Array, Array]: r"""Computes the polar decomposition. @@ -1832,6 +1840,9 @@ def _sqrtm(A: ArrayLike) -> Array: def sqrtm(A: ArrayLike, blocksize: int = 1) -> Array: """Compute the matrix square root + This function is implemented using :func:`scipy.linalg.schur`, which is only + supported on CPU. + JAX implementation of :func:`scipy.linalg.sqrtm`. Args: @@ -1875,7 +1886,7 @@ def sqrtm(A: ArrayLike, blocksize: int = 1) -> Array: return _sqrtm(A) -@partial(jit, static_argnames=('check_finite',)) +@jit(static_argnames=('check_finite',)) def rsf2csf(T: ArrayLike, Z: ArrayLike, check_finite: bool = True) -> tuple[Array, Array]: """Convert real Schur form to complex Schur form. @@ -1938,15 +1949,15 @@ def rsf2csf(T: ArrayLike, Z: ArrayLike, check_finite: bool = True) -> tuple[Arra raise ValueError(f"Input array shapes must match: Z: {Z_arr.shape} vs. T: {T_arr.shape}") T_arr, Z_arr = promote_dtypes_complex(T_arr, Z_arr) - eps = jnp.finfo(T_arr.dtype).eps + eps = dtypes.finfo(T_arr.dtype).eps N = T_arr.shape[0] if N == 1: return T_arr, Z_arr def _update_T_Z(m, T, Z): - mu = jnp.linalg.eigvals(lax.dynamic_slice(T, (m-1, m-1), (2, 2))) - T[m, m] - r = jnp.linalg.norm(jnp.array([mu[0], T[m, m-1]])).astype(T.dtype) + mu = jnp_linalg.eigvals(lax.dynamic_slice(T, (m-1, m-1), (2, 2))) - T[m, m] + r = jnp_linalg.norm(jnp.array([mu[0], T[m, m-1]])).astype(T.dtype) c = mu[0] / r s = T[m, m-1] / r G = jnp.array([[c.conj(), s], [-s, c]], dtype=T.dtype) @@ -1960,7 +1971,7 @@ def _update_T_Z(m, T, Z): # T[:m+1, m-1:m+1] = T[:m+1, m-1:m+1] @ G.conj().T T_cols = lax.dynamic_slice_in_dim(T, m-1, 2, axis=1) - row_mask = jnp.arange(N)[:, jnp.newaxis] < m+1 + row_mask = jnp.arange(N)[:, np.newaxis] < m+1 T_zeroed_rows_dot_GH = jnp.where(row_mask, T_cols, 0) @ G.conj().T T_cols_new = jnp.where(~row_mask, T_cols, T_zeroed_rows_dot_GH) T = lax.dynamic_update_slice_in_dim(T, T_cols_new, m-1, axis=1) @@ -1992,7 +2003,7 @@ def hessenberg(a: ArrayLike, *, calc_q: Literal[True], overwrite_a: bool = False check_finite: bool = True) -> tuple[Array, Array]: ... -@partial(jit, static_argnames=('calc_q', 'check_finite', 'overwrite_a')) +@jit(static_argnames=('calc_q', 'check_finite', 'overwrite_a')) def hessenberg(a: ArrayLike, *, calc_q: bool = False, overwrite_a: bool = False, check_finite: bool = True) -> Array | tuple[Array, Array]: """Compute the Hessenberg form of the matrix @@ -2042,7 +2053,7 @@ def hessenberg(a: ArrayLike, *, calc_q: bool = False, overwrite_a: bool = False, Array(True, dtype=bool) """ del overwrite_a, check_finite # unused - n = jnp.shape(a)[-1] + n = np.shape(a)[-1] if n == 0: if calc_q: return jnp.zeros_like(a), jnp.zeros_like(a) @@ -2137,7 +2148,7 @@ def toeplitz(c: ArrayLike, r: ArrayLike | None = None) -> Array: check_arraylike("toeplitz", c, r) return _toeplitz(jnp.atleast_1d(jnp.asarray(c)), jnp.atleast_1d(jnp.asarray(r))) -@partial(jnp.vectorize, signature="(m),(n)->(m,n)") +@partial(jnp_vectorize.vectorize, signature="(m),(n)->(m,n)") def _toeplitz(c: Array, r: Array) -> Array: ncols, = c.shape nrows, = r.shape @@ -2151,7 +2162,7 @@ def _toeplitz(c: Array, r: Array) -> Array: precision=lax.Precision.HIGHEST)[0] return jnp.flip(patches, axis=0) -@partial(jit, static_argnames=("n",)) +@jit(static_argnames=("n",)) def hilbert(n: int) -> Array: r"""Create a Hilbert matrix of order n. @@ -2180,5 +2191,220 @@ def hilbert(n: int) -> Array: [0.5 , 0.33333334, 0.25 ], [0.33333334, 0.25 , 0.2 ]], dtype=float32) """ - a = lax.broadcasted_iota(jnp.float64, (n, 1), 0) + a = lax.broadcasted_iota(float, (n, 1), 0) return 1/(a + a.T + 1) + +@jit(static_argnames=("n", "kind",)) +def pascal(n: int, kind: str | None = None) -> Array: + r"""Create a Pascal matrix approximation of order n. + + JAX implementation of :func:`scipy.linalg.pascal`. + + The elements of the Pascal matrix approximate the binomial coefficients. This + implementation is not exact as JAX does not support exact factorials. + + Args: + n: the size of the matrix to create. + kind: (optional) must be one of ``lower``, ``upper``, or ``symmetric`` (default). + + Returns: + A Pascal matrix of shape ``(n, n)`` + + Examples: + >>> with jnp.printoptions(precision=3): + ... print(jax.scipy.linalg.pascal(3, kind="lower")) + ... print(jax.scipy.linalg.pascal(4, kind="upper")) + ... print(jax.scipy.linalg.pascal(5)) + [[1. 0. 0.] + [1. 1. 0.] + [1. 2. 1.]] + [[1. 1. 1. 1.] + [0. 1. 2. 3.] + [0. 0. 1. 3.] + [0. 0. 0. 1.]] + [[ 1. 1. 1. 1. 1.] + [ 1. 2. 3. 4. 5.] + [ 1. 3. 6. 10. 15.] + [ 1. 4. 10. 20. 35.] + [ 1. 5. 15. 35. 70.]] + """ + if kind is None: + kind = "symmetric" + + valid_kind = ["symmetric", "lower", "upper"] + + if kind not in valid_kind: + raise ValueError(f"Expected kind to be on of: {valid_kind}; got {kind}") + + a = jnp.arange(n, dtype=np.float32) + + L_n = _binom(a[:, None], a[None, :]) + + if kind == "lower": + return L_n + + if kind == "upper": + return L_n.T + + return jnp.dot(L_n, L_n.T) + +@jit +def _binom(n, k): + a = lax.lgamma(n + 1.0) + b = lax.lgamma(n - k + 1.0) + c = lax.lgamma(k + 1.0) + return lax.exp(a - b - c) + + +def _solve_sylvester_triangular_scan(R: Array, S: Array, F: Array) -> Array: + """ + Solves the Sylvester equation using Bartels-Stewart algorithm + .. math:: + + RY + YS^T = F + + where R and S are upper triangular matrices following a Schur decomposition. + + Args: + R: Matrix of shape m x m + S: Matrix of shape n x n + F: Matrix of shape m x n + + Returns: + Y: Matrix of shape m x n + """ + R, S, F = promote_args_inexact("_solve_sylvester_triangular_scan", R, S, F) + + m, n = F.shape + total = m * n + # scan the matrix from bottom-right to top-left + flat_indices = jnp.arange(total - 1, -1, -1) + Y0 = jnp.zeros((m * n,), dtype=F.dtype) + + def scan_fn(Y_flat, idx): + i = idx // n + j = idx % n + Y = Y_flat.reshape((m, n)) + rhs = F[i, j] + + # Row term: gets contributions from R and already filled in Y. mask ensures that we only get non-zero elements from R because it is upper triangular + k_row = jnp.arange(m) + row_mask = k_row > i + r_row = R[i, :] + y_col = Y[:, j] + row_term = jnp.sum(jnp.where(row_mask, r_row * y_col, 0.0)) + + # Col term: same as Row term but now uses S instead of R. + k_col = jnp.arange(n) + col_mask = k_col > j + y_row = Y[i, :] + s_col = S[:, j] + col_term = jnp.sum(jnp.where(col_mask, y_row * s_col, 0.0)) + + # Here we are solving for the current Y[i, j] + rhs -= row_term + col_term + val = rhs / (R[i, i] + S[j, j]) + + Y_flat = Y_flat.at[i * n + j].set(val) + return Y_flat, None + + Y_flat_final, _ = lax.scan(scan_fn, Y0, flat_indices) + return Y_flat_final.reshape((m, n)) + + +@jit(static_argnames=["method", "tol"]) +def solve_sylvester(A: ArrayLike, B: ArrayLike, C: ArrayLike, *, method: str = "schur", tol: float = 1e-8) -> Array: + """ + Solves the Sylvester equation + .. math:: + + AX + XB = C + + Using one of two methods. + + (1) Bartell-Stewart (schur) algorithm (default) [CPU ONLY]: + + Where A and B are first decomposed using Schur decomposition to construct and alternate sylvester equation: + .. math:: + + RY + YS^T = F + + Where R and S are in quasitriangular form when A and B are real valued and triangular when A and B are complex. + + (2) The Eigen decomposition algorithm [CPU and GPU] + + Args: + A: Matrix of shape m x m + B: Matrix of shape n x n + C: Matrix of shape m x n + method: "schur" is the default and is accurate but slow, and "eigen" is an alternative that is faster but less accurate for ill-conditioned matrices. + tol: How close the sum of the eigenvalues from A and B can be to zero before returning matrix of NaNs + + Returns: + X: Matrix of shape m x n + + Examples: + >>> A = jax.numpy.array([[1, 2], [3, 4]]) + >>> B = jax.numpy.array([[5, 6], [7, 8]]) + >>> C = jax.numpy.array([[6, 8], [10, 12]]) + >>> X = jax.scipy.linalg.solve_sylvester(A, B, C) + >>> print(X) # doctest: +SKIP + [[1. 0.] + [0. 1.]] + + Notes: + The Bartel-Stewart algorithm is robust because a Schur decomposition always exists even for defective matrices, + and it handles complex and ill-conditioned problems better than the eigen decomposition method. + However, there are a couple of drawbacks. First, It is computationally more expensive than + the eigen decomposition method because you need to perform a Schur decomposition and then scan the entire solution matrix. + Second, it requires more system memory compared to the eigen decomposition method. + + The eigen decomposition method is the fastest method to solve a sylvester equation. However, this speed brings with it a couple of drawbacks. + First, A and B must be diagonalizable otherwise the eigenvectors will be linearly dependent and ill-conditioned leading to accuracy issues. + Second, when the eigenvectors are not orthogonal roundoff errors are amplified. + + Additionally, for complex types as the size of the matrix increases the accuracy of the results degrades. Float64 types are most robust to degradation. + + The tol argument allows you to specify how ill-conditioned a matrix can be and still estimate a solution. + For matrices that are ill-conditioned we recommend using float64 instead of the default float32 dtype. The solver + can still return good estimates for ill-conditioned matrices depending on how close to zero the sums of the eigenvalues of A and B + are. + """ + A, B, C = promote_args_inexact("solve_sylvester", A, B, C) + + m, n = C.shape + + if A.shape != (m, m) or B.shape != (n, n) or C.shape != (m, n): + raise ValueError(f"Incompatible shapes for Sylvester equation:\nA: {A.shape}\nB: {B.shape}\nC: {C.shape}") + + if method == "schur": + # Schur decomposition + R, U = schur(A, output='complex') + S, V = schur(B.conj().T, output='complex') + + # Transform right-hand side + F = U.conj().T @ C.astype(R.dtype) @ V + + # Solve triangular Sylvester system + Y = _solve_sylvester_triangular_scan(R, S.conj().T, F) + + # Transform back + X = U @ Y @ V.conj().T + elif method == "eigen": + RA, UA = jnp.linalg.eig(A) + RB, UB = jnp.linalg.eig(B) + F = solve(UA, C.astype(RA.dtype) @ UB) + W = RA[:, None] + RB[None, :] + Y = F / W + X = UA[:m,:m] @ Y[:m,:n] @ inv(UB)[:n,:n] + else: + raise ValueError(f"Unrecognized method {method}. The two valid methods are either \"schur\" or \"eigen\".") + + if not dtypes.issubdtype(C.dtype, np.complexfloating): + X = X.real + + return lax.cond( + jnp.any(jnp.abs(jnp.linalg.eigvals(A)[:, None] + jnp.linalg.eigvals(B)[None, :]) < tol), + lambda: jnp.zeros_like(X) * np.nan, + lambda: X, + ) diff --git a/jax/_src/scipy/ndimage.py b/jax/_src/scipy/ndimage.py index ee144eaf990a..ce63d6ba8b53 100644 --- a/jax/_src/scipy/ndimage.py +++ b/jax/_src/scipy/ndimage.py @@ -17,10 +17,13 @@ import itertools import operator +import numpy as np + from jax._src import api +from jax._src import dtypes +from jax._src import numpy as jnp from jax._src import util -from jax import lax -import jax.numpy as jnp +from jax._src.lax import lax from jax._src.typing import ArrayLike, Array from jax._src.util import safe_zip as zip @@ -29,7 +32,7 @@ def _nonempty_prod(arrs: Sequence[Array]) -> Array: return functools.reduce(operator.mul, arrs) def _nonempty_sum(arrs: Sequence[Array]) -> Array: - return functools.reduce(operator.add, arrs) + return sum(arrs[1:], arrs[0]) def _mirror_index_fixer(index: Array, size: int) -> Array: s = size - 1 # Half-wavelength of triangular wave @@ -49,11 +52,11 @@ def _reflect_index_fixer(index: Array, size: int) -> Array: def _round_half_away_from_zero(a: Array) -> Array: - return a if jnp.issubdtype(a.dtype, jnp.integer) else lax.round(a) + return a if dtypes.issubdtype(a.dtype, np.integer) else lax.round(a) def _nearest_indices_and_weights(coordinate: Array) -> list[tuple[Array, ArrayLike]]: - index = _round_half_away_from_zero(coordinate).astype(jnp.int32) + index = _round_half_away_from_zero(coordinate).astype(np.int32) weight = coordinate.dtype.type(1) return [(index, weight)] @@ -62,7 +65,7 @@ def _linear_indices_and_weights(coordinate: Array) -> list[tuple[Array, ArrayLik lower = jnp.floor(coordinate) upper_weight = coordinate - lower lower_weight = 1 - upper_weight - index = lower.astype(jnp.int32) + index = lower.astype(np.int32) return [(index, lower_weight), (index + 1, upper_weight)] @@ -117,17 +120,11 @@ def _map_coordinates(input: ArrayLike, coordinates: Sequence[ArrayLike], contribution = jnp.where(all_valid, input_arr[indices], cval) outputs.append(_nonempty_prod(weights) * contribution) # type: ignore result = _nonempty_sum(outputs) - if jnp.issubdtype(input_arr.dtype, jnp.integer): + if dtypes.issubdtype(input_arr.dtype, np.integer): result = _round_half_away_from_zero(result) return result.astype(input_arr.dtype) -""" - Only nearest neighbor (``order=0``), linear interpolation (``order=1``) and - modes ``'constant'``, ``'nearest'``, ``'wrap'`` ``'mirror'`` and ``'reflect'`` are currently supported. - - """ - def map_coordinates( input: ArrayLike, coordinates: Sequence[ArrayLike], order: int, mode: str = 'constant', cval: ArrayLike = 0.0, diff --git a/jax/_src/scipy/optimize/_lbfgs.py b/jax/_src/scipy/optimize/_lbfgs.py index aa82ab4fd0c8..3f4767f101a7 100644 --- a/jax/_src/scipy/optimize/_lbfgs.py +++ b/jax/_src/scipy/optimize/_lbfgs.py @@ -19,9 +19,13 @@ from functools import partial from typing import NamedTuple -import jax -import jax.numpy as jnp -from jax import lax +import numpy as np + +from jax._src import api +from jax._src import dtypes +from jax._src import lax +from jax._src import numpy as jnp +from jax._src.numpy import linalg as jnp_linalg from jax._src.scipy.optimize.line_search import line_search from jax._src.typing import Array @@ -74,7 +78,7 @@ def _minimize_lbfgs( fun: Callable, x0: Array, maxiter: float | None = None, - norm=jnp.inf, + norm=np.inf, maxcor: int = 10, ftol: float = 2.220446049250313e-09, gtol: float = 1e-05, @@ -109,7 +113,7 @@ def _minimize_lbfgs( Optimization results. """ d = len(x0) - dtype = jnp.dtype(x0) + dtype = dtypes.dtype(x0) # ensure there is at least one termination condition if (maxiter is None) and (maxfun is None) and (maxgrad is None): @@ -117,14 +121,14 @@ def _minimize_lbfgs( # set others to inf, such that >= is supported if maxiter is None: - maxiter = jnp.inf + maxiter = np.inf if maxfun is None: - maxfun = jnp.inf + maxfun = np.inf if maxgrad is None: - maxgrad = jnp.inf + maxgrad = np.inf # initial evaluation - f_0, g_0 = jax.value_and_grad(fun)(x0) + f_0, g_0 = api.value_and_grad(fun)(x0) state_initial = LBFGSResults( converged=False, failed=False, @@ -160,7 +164,7 @@ def body_fun(state: LBFGSResults): ) # evaluate at next iterate - s_k = ls_results.a_k.astype(p_k.dtype) * p_k + s_k = jnp.asarray(ls_results.a_k).astype(p_k.dtype) * p_k x_kp1 = state.x_k + s_k f_kp1 = ls_results.f_k g_kp1 = ls_results.g_k @@ -177,7 +181,7 @@ def body_fun(state: LBFGSResults): status = jnp.where(state.k >= maxiter, 1, status) status = jnp.where(ls_results.failed, 5, status) - converged = jnp.linalg.norm(g_kp1, ord=norm) < gtol + converged = jnp_linalg.norm(g_kp1, ord=norm) < gtol # TODO(jakevdp): use a fixed-point procedure rather than type-casting? state = state._replace( diff --git a/jax/_src/scipy/optimize/bfgs.py b/jax/_src/scipy/optimize/bfgs.py index 657b7610e6e1..a52881c28fc2 100644 --- a/jax/_src/scipy/optimize/bfgs.py +++ b/jax/_src/scipy/optimize/bfgs.py @@ -19,10 +19,15 @@ from functools import partial from typing import NamedTuple -import jax -import jax.numpy as jnp -from jax import lax +import numpy as np + +from jax._src import api +from jax._src import lax +from jax._src import numpy as jnp +from jax._src.numpy import einsum as jnp_einsum +from jax._src.numpy import linalg as jnp_linalg from jax._src.scipy.optimize.line_search import line_search +from jax._src.typing import Array class _BFGSResults(NamedTuple): @@ -49,30 +54,30 @@ class _BFGSResults(NamedTuple): line_search_status: int describing line search end state (only means something if line search fails). """ - converged: bool | jax.Array - failed: bool | jax.Array - k: int | jax.Array - nfev: int | jax.Array - ngev: int | jax.Array - nhev: int | jax.Array - x_k: jax.Array - f_k: jax.Array - g_k: jax.Array - H_k: jax.Array - old_old_fval: jax.Array - status: int | jax.Array - line_search_status: int | jax.Array + converged: bool | Array + failed: bool | Array + k: int | Array + nfev: int | Array + ngev: int | Array + nhev: int | Array + x_k: Array + f_k: Array + g_k: Array + H_k: Array + old_old_fval: Array + status: int | Array + line_search_status: int | Array _dot = partial(jnp.dot, precision=lax.Precision.HIGHEST) -_einsum = partial(jnp.einsum, precision=lax.Precision.HIGHEST) +_einsum = partial(jnp_einsum.einsum, precision=lax.Precision.HIGHEST) def minimize_bfgs( fun: Callable, - x0: jax.Array, + x0: Array, maxiter: int | None = None, - norm=jnp.inf, + norm=np.inf, gtol: float = 1e-5, line_search_maxiter: int = 10, ) -> _BFGSResults: @@ -96,14 +101,14 @@ def minimize_bfgs( """ if maxiter is None: - maxiter = jnp.size(x0) * 200 + maxiter = np.size(x0) * 200 d = x0.shape[0] initial_H = jnp.eye(d, dtype=x0.dtype) - f_0, g_0 = jax.value_and_grad(fun)(x0) + f_0, g_0 = api.value_and_grad(fun)(x0) state = _BFGSResults( - converged=jnp.linalg.norm(g_0, ord=norm) < gtol, + converged=jnp_linalg.norm(g_0, ord=norm) < gtol, failed=False, k=0, nfev=1, @@ -113,7 +118,7 @@ def minimize_bfgs( f_k=f_0, g_k=g_0, H_k=initial_H, - old_old_fval=f_0 + jnp.linalg.norm(g_0) / 2, + old_old_fval=f_0 + jnp_linalg.norm(g_0) / 2, status=0, line_search_status=0, ) @@ -147,12 +152,12 @@ def body_fun(state): y_k = g_kp1 - state.g_k rho_k = jnp.reciprocal(_dot(y_k, s_k)) - sy_k = s_k[:, jnp.newaxis] * y_k[jnp.newaxis, :] + sy_k = s_k[:, np.newaxis] * y_k[np.newaxis, :] w = jnp.eye(d, dtype=rho_k.dtype) - rho_k * sy_k H_kp1 = (_einsum('ij,jk,lk', w, state.H_k, w) - + rho_k * s_k[:, jnp.newaxis] * s_k[jnp.newaxis, :]) + + rho_k * s_k[:, np.newaxis] * s_k[np.newaxis, :]) H_kp1 = jnp.where(jnp.isfinite(rho_k), H_kp1, state.H_k) - converged = jnp.linalg.norm(g_kp1, ord=norm) < gtol + converged = jnp_linalg.norm(g_kp1, ord=norm) < gtol state = state._replace( converged=converged, diff --git a/jax/_src/scipy/optimize/line_search.py b/jax/_src/scipy/optimize/line_search.py index 189009693cdd..6d16b67f1c66 100644 --- a/jax/_src/scipy/optimize/line_search.py +++ b/jax/_src/scipy/optimize/line_search.py @@ -17,16 +17,19 @@ from typing import NamedTuple from functools import partial +from jax._src import api +from jax._src import dtypes +from jax._src import lax +from jax._src import numpy as jnp from jax._src.numpy.util import promote_dtypes_inexact -import jax.numpy as jnp -import jax -from jax import lax +from jax._src.typing import Array + _dot = partial(jnp.dot, precision=lax.Precision.HIGHEST) def _cubicmin(a, fa, fpa, b, fb, c, fc): - dtype = jnp.result_type(a, fa, fpa, b, fb, c, fc) + dtype = dtypes.result_type(a, fa, fpa, b, fb, c, fc) C = fpa db = b - a dc = c - a @@ -59,23 +62,23 @@ def _binary_replace(replace_bit, original_dict, new_dict, keys=None): class _ZoomState(NamedTuple): - done: bool | jax.Array - failed: bool | jax.Array - j: int | jax.Array - a_lo: float | jax.Array - phi_lo: float | jax.Array - dphi_lo: float | jax.Array - a_hi: float | jax.Array - phi_hi: float | jax.Array - dphi_hi: float | jax.Array - a_rec: float | jax.Array - phi_rec: float | jax.Array - a_star: float | jax.Array - phi_star: float | jax.Array - dphi_star: float | jax.Array - g_star: float | jax.Array - nfev: int | jax.Array - ngev: int | jax.Array + done: bool | Array + failed: bool | Array + j: int | Array + a_lo: float | Array + phi_lo: float | Array + dphi_lo: float | Array + a_hi: float | Array + phi_hi: float | Array + dphi_hi: float | Array + a_rec: float | Array + phi_rec: float | Array + a_star: float | Array + phi_star: float | Array + dphi_star: float | Array + g_star: float | Array + nfev: int | Array + ngev: int | Array def _zoom(restricted_func_and_grad, wolfe_one, wolfe_two, a_lo, phi_lo, @@ -118,7 +121,7 @@ def body(state): # This will cause the line search to stop, and since the Wolfe conditions # are not satisfied the minimization should stop too. - threshold = jnp.where((jnp.finfo(dalpha.dtype).bits < 64), 1e-5, 1e-10) + threshold = jnp.where((dtypes.finfo(dalpha.dtype).bits < 64), 1e-5, 1e-10) state = state._replace(failed=state.failed | (dalpha <= threshold)) # Cubmin is sometimes nan, though in this case the bounds check will fail. @@ -188,6 +191,16 @@ def body(state): ), ), ) + state = state._replace( + **_binary_replace( + lo_to_j & ~hi_to_lo, + state._asdict(), + dict( + a_rec=state.a_lo, + phi_rec=state.phi_lo, + ), + ), + ) state = state._replace( **_binary_replace( lo_to_j, @@ -196,8 +209,6 @@ def body(state): a_lo=a_j, phi_lo=phi_j, dphi_lo=dphi_j, - a_rec=state.a_lo, - phi_rec=state.phi_lo, ), ), ) @@ -215,18 +226,18 @@ def body(state): class _LineSearchState(NamedTuple): - done: bool | jax.Array - failed: bool | jax.Array - i: int | jax.Array - a_i1: float | jax.Array - phi_i1: float | jax.Array - dphi_i1: float | jax.Array - nfev: int | jax.Array - ngev: int | jax.Array - a_star: float | jax.Array - phi_star: float | jax.Array - dphi_star: float | jax.Array - g_star: jax.Array + done: bool | Array + failed: bool | Array + i: int | Array + a_i1: float | Array + phi_i1: float | Array + dphi_i1: float | Array + nfev: int | Array + ngev: int | Array + a_star: float | Array + phi_star: float | Array + dphi_star: float | Array + g_star: Array class _LineSearchResults(NamedTuple): @@ -243,15 +254,15 @@ class _LineSearchResults(NamedTuple): g_k: final gradient value status: integer end status """ - failed: bool | jax.Array - nit: int | jax.Array - nfev: int | jax.Array - ngev: int | jax.Array - k: int | jax.Array - a_k: int | jax.Array - f_k: jax.Array - g_k: jax.Array - status: bool | jax.Array + failed: bool | Array + nit: int | Array + nfev: int | Array + ngev: int | Array + k: int | Array + a_k: int | Array + f_k: Array + g_k: Array + status: bool | Array def line_search(f, xk, pk, old_fval=None, old_old_fval=None, gfk=None, c1=1e-4, @@ -275,7 +286,7 @@ def line_search(f, xk, pk, old_fval=None, old_old_fval=None, gfk=None, c1=1e-4, xk, pk = promote_dtypes_inexact(xk, pk) def restricted_func_and_grad(t): t = jnp.array(t, dtype=pk.dtype) - phi, g = jax.value_and_grad(f)(xk + t * pk) + phi, g = api.value_and_grad(f)(xk + t * pk) dphi = jnp.real(_dot(g, pk)) return phi, dphi, g @@ -409,7 +420,7 @@ def body(state): # Step sizes which are too small causes the optimizer to get stuck with a # direction of zero in <64 bit mode - avoid with a floor on minimum step size. alpha_k = jnp.asarray(state.a_star) - alpha_k = jnp.where((jnp.finfo(alpha_k.dtype).bits != 64) + alpha_k = jnp.where((dtypes.finfo(alpha_k.dtype).bits != 64) & (jnp.abs(alpha_k) < 1e-8), jnp.sign(alpha_k) * 1e-8, alpha_k) diff --git a/jax/_src/scipy/optimize/minimize.py b/jax/_src/scipy/optimize/minimize.py index 4fc006be6df0..86c38350c960 100644 --- a/jax/_src/scipy/optimize/minimize.py +++ b/jax/_src/scipy/optimize/minimize.py @@ -15,13 +15,12 @@ from __future__ import annotations from collections.abc import Callable, Mapping -from typing import Any +from typing import Any, NamedTuple -import jax +from jax._src import numpy as jnp from jax._src.scipy.optimize.bfgs import minimize_bfgs from jax._src.scipy.optimize._lbfgs import _minimize_lbfgs -from typing import NamedTuple -import jax.numpy as jnp +from jax._src.typing import Array class OptimizeResults(NamedTuple): @@ -40,20 +39,20 @@ class OptimizeResults(NamedTuple): njev: integer number of gradient evaluations. nit: integer number of iterations of the optimization algorithm. """ - x: jax.Array - success: bool | jax.Array - status: int | jax.Array - fun: jax.Array - jac: jax.Array - hess_inv: jax.Array | None - nfev: int | jax.Array - njev: int | jax.Array - nit: int | jax.Array + x: Array + success: bool | Array + status: int | Array + fun: Array + jac: Array + hess_inv: Array | None + nfev: int | Array + njev: int | Array + nit: int | Array def minimize( fun: Callable, - x0: jax.Array, + x0: Array, args: tuple = (), *, method: str, diff --git a/jax/_src/scipy/signal.py b/jax/_src/scipy/signal.py index d950cd2ea395..d843194277d6 100644 --- a/jax/_src/scipy/signal.py +++ b/jax/_src/scipy/signal.py @@ -23,17 +23,21 @@ import numpy as np -import jax -import jax.numpy.fft -import jax.numpy as jnp -from jax import lax +from jax._src import api from jax._src import core from jax._src import dtypes +from jax._src import lax +from jax._src import numpy as jnp from jax._src.api_util import _ensure_index_tuple from jax._src.lax.lax import PrecisionLike +from jax._src.numpy import fft as jnp_fft from jax._src.numpy import linalg from jax._src.numpy.util import ( - check_arraylike, promote_dtypes_inexact, promote_dtypes_complex) + check_arraylike, + ensure_arraylike, + promote_dtypes_complex, + promote_dtypes_inexact, +) from jax._src.third_party.scipy import signal_helper from jax._src.typing import Array, ArrayLike from jax._src.util import canonicalize_axis, tuple_delete, tuple_insert @@ -108,7 +112,7 @@ def fftconvolve(in1: ArrayLike, in2: ArrayLike, mode: str = "full", if any(in1.shape[i] != in2.shape[i] for i in mapped_axes): raise ValueError(f"mapped axes must have same shape; got {in1.shape=} {in2.shape=} {axes=}") for ax in sorted(mapped_axes): - _fftconvolve = jax.vmap(_fftconvolve, in_axes=ax, out_axes=ax) + _fftconvolve = api.vmap(_fftconvolve, in_axes=ax, out_axes=ax) return _fftconvolve(in1, in2) def _fftconvolve_unbatched(in1: Array, in2: Array, mode: str) -> Array: @@ -126,13 +130,16 @@ def _fftconvolve_unbatched(in1: Array, in2: Array, mode: str) -> Array: if swap: in1, in2 = in2, in1 - if jnp.iscomplexobj(in1): - fft, ifft = jnp.fft.fftn, jnp.fft.ifftn + if (all(s1 == 1 or s2 == 1 for s1, s2 in zip(in1.shape, in2.shape))): + conv = in1 * in2 else: - fft, ifft = jnp.fft.rfftn, jnp.fft.irfftn - sp1 = fft(in1, fft_shape) - sp2 = fft(in2, fft_shape) - conv = ifft(sp1 * sp2, fft_shape) + if jnp.iscomplexobj(in1): + fft, ifft = jnp.fft.fftn, jnp.fft.ifftn + else: + fft, ifft = jnp.fft.rfftn, jnp.fft.irfftn + sp1 = fft(in1, fft_shape) + sp2 = fft(in2, fft_shape) + conv = ifft(sp1 * sp2, fft_shape) if mode == "full": out_shape = full_shape @@ -148,7 +155,7 @@ def _fftconvolve_unbatched(in1: Array, in2: Array, mode: str) -> Array: return lax.dynamic_slice(conv, start_indices, out_shape) -# Note: we do not re-use the code from jax.numpy.convolve here, because the handling +# Note: we do not reuse the code from jax.numpy.convolve here, because the handling # of padding differs slightly between the two implementations (particularly for # mode='same'). def _convolve_nd(in1: Array, in2: Array, mode: str, *, precision: PrecisionLike) -> Array: @@ -319,7 +326,7 @@ def convolve2d(in1: Array, in2: Array, mode: str = 'full', boundary: str = 'fill """ if boundary != 'fill' or fillvalue != 0: raise NotImplementedError("convolve2d() only supports boundary='fill', fillvalue=0") - if jnp.ndim(in1) != 2 or jnp.ndim(in2) != 2: + if np.ndim(in1) != 2 or np.ndim(in2) != 2: raise ValueError("convolve2d() only supports 2-dimensional inputs.") return _convolve_nd(in1, in2, mode, precision=precision) @@ -454,7 +461,7 @@ def correlate2d(in1: Array, in2: Array, mode: str = 'full', boundary: str = 'fil """ if boundary != 'fill' or fillvalue != 0: raise NotImplementedError("correlate2d() only supports boundary='fill', fillvalue=0") - if jnp.ndim(in1) != 2 or jnp.ndim(in2) != 2: + if np.ndim(in1) != 2 or np.ndim(in2) != 2: raise ValueError("correlate2d() only supports 2-dimensional inputs.") swap = all(s1 <= s2 for s1, s2 in zip(in1.shape, in2.shape)) @@ -566,13 +573,9 @@ def _fft_helper(x: Array, win: Array, detrend_func: Callable[[Array], Array], result = x[..., np.newaxis] else: step = nperseg - noverlap - batch_shape = list(batch_shape) - x = x.reshape((math.prod(batch_shape), signal_length, 1)) - result = jax.lax.conv_general_dilated_patches( - x, (nperseg,), (step,), - 'VALID', - dimension_numbers=('NTC', 'OIT', 'NTC')) - result = result.reshape(*batch_shape, *result.shape[-2:]) + starts = jnp.arange(signal_length - nperseg + 1, step=step) + slice_func = partial(lax.dynamic_slice_in_dim, operand=x, slice_size=nperseg, axis=-1) + result = api.vmap(slice_func, out_axes=-2)(start_index=starts) # Detrend each data segment individually result = detrend_func(result) @@ -584,9 +587,9 @@ def _fft_helper(x: Array, win: Array, detrend_func: Callable[[Array], Array], # Perform the fft on last axis. Zero-pads automatically if sides == 'twosided': - return jax.numpy.fft.fft(result, n=nfft) + return jnp_fft.fft(result, n=nfft) else: - return jax.numpy.fft.rfft(result.real, n=nfft) + return jnp_fft.rfft(result.real, n=nfft) def odd_ext(x: Array, n: int, axis: int = -1) -> Array: @@ -793,9 +796,9 @@ def detrend_func(d): sides = 'twosided' if sides == 'twosided': - freqs = jax.numpy.fft.fftfreq(nfft_int, 1/fs, dtype=freq_dtype) + freqs = jnp_fft.fftfreq(nfft_int, 1/fs, dtype=freq_dtype) elif sides == 'onesided': - freqs = jax.numpy.fft.rfftfreq(nfft_int, 1/fs, dtype=freq_dtype) + freqs = jnp_fft.rfftfreq(nfft_int, 1/fs, dtype=freq_dtype) # Perform the windowed FFTs result = _fft_helper(x, win, detrend_func, @@ -1034,16 +1037,16 @@ def _overlap_and_add(x: Array, step_size: int) -> Array: x = x.reshape((flat_batchsize, nframes, nstep_per_segment, step_size)) # For obtaining shifted signals, this routine reinterprets flattened array - # with a shrinked axis. With appropriate truncation/ padding, this operation + # with a shrunken axis. With appropriate truncation/ padding, this operation # pushes the last padded elements of the previous row to the head of the # current row. # See implementation of `overlap_and_add` in Tensorflow for details. x = x.transpose((0, 2, 1, 3)) # x: (B, S, N, T) x = jnp.pad(x, ((0, 0), (0, 0), (0, nframes), (0, 0))) # x: (B, S, N*2, T) - shrinked = x.shape[2] - 1 + shrunken = x.shape[2] - 1 x = x.reshape((flat_batchsize, -1)) - x = x[:, :(nstep_per_segment * shrinked * step_size)] - x = x.reshape((flat_batchsize, nstep_per_segment, shrinked * step_size)) + x = x[:, :(nstep_per_segment * shrunken * step_size)] + x = x.reshape((flat_batchsize, nstep_per_segment, shrunken * step_size)) # Finally, sum shifted segments, and truncate results to the output_size. x = x.sum(axis=1)[:, :output_size] @@ -1071,7 +1074,7 @@ def istft(Zxx: Array, fs: ArrayLike = 1.0, window: str = 'hann', noverlap: Number of points to overlap between segments (default: ``nperseg // 2``). nfft: Number of FFT points used in the STFT. If ``None`` (default), the value is determined from the size of ``Zxx``. - input_onesided: If Tru` (default), interpret the input as a one-sided STFT + input_onesided: If True (default), interpret the input as a one-sided STFT (positive frequencies only). If False, interpret the input as a two-sided STFT. boundary: If True (default), it is assumed that the input signal was extended at its boundaries by ``stft``. If `False`, the input signal is assumed to have been truncated at the boundaries by `stft`. @@ -1099,7 +1102,7 @@ def istft(Zxx: Array, fs: ArrayLike = 1.0, window: str = 'hann', [1. 2. 3. 2. 1. 0. 1. 2.] """ # Input validation - check_arraylike("istft", Zxx) + Zxx = ensure_arraylike("istft", Zxx) if Zxx.ndim < 2: raise ValueError('Input stft must be at least 2d!') freq_axis = canonicalize_axis(freq_axis, Zxx.ndim) @@ -1107,8 +1110,7 @@ def istft(Zxx: Array, fs: ArrayLike = 1.0, window: str = 'hann', if freq_axis == time_axis: raise ValueError('Must specify differing time and frequency axes!') - Zxx = jnp.asarray(Zxx, dtype=jax.dtypes.canonicalize_dtype( - np.result_type(Zxx, np.complex64))) + Zxx = jnp.asarray(Zxx, dtype=dtypes.to_complex_dtype(Zxx.dtype)) n_default = (2 * (Zxx.shape[freq_axis] - 1) if input_onesided else Zxx.shape[freq_axis]) @@ -1142,19 +1144,19 @@ def istft(Zxx: Array, fs: ArrayLike = 1.0, window: str = 'hann', Zxx = jnp.transpose(Zxx, outer_idxs + (freq_axis, time_axis)) # Perform IFFT - ifunc = jax.numpy.fft.irfft if input_onesided else jax.numpy.fft.ifft + ifunc = jnp_fft.irfft if input_onesided else jnp_fft.ifft # xsubs: [..., T, N], N is the number of frames, T is the frame length. xsubs = ifunc(Zxx, axis=-2, n=nfft)[..., :nperseg_int, :] # Get window as array - if window == 'hann': + if isinstance(window, str) and window == 'hann': # Implement the default case without scipy - win = jnp.array([1.0]) if nperseg_int == 1 else jnp.sin(jnp.linspace(0, jnp.pi, nperseg_int, endpoint=False)) ** 2 + win = jnp.array([1.0]) if nperseg_int == 1 else jnp.sin(jnp.linspace(0, np.pi, nperseg_int, endpoint=False)) ** 2 win = win.astype(xsubs.dtype) elif isinstance(window, (str, tuple)): # TODO(jakevdp): implement get_window() in JAX to remove optional scipy dependency try: - from scipy.signal import get_window + from scipy.signal import get_window # pytype: disable=import-error except ImportError as err: raise ImportError(f"scipy must be available to use {window=}") from err win = get_window(window, nperseg_int) diff --git a/jax/_src/scipy/sparse/linalg.py b/jax/_src/scipy/sparse/linalg.py index 560b3773dfe4..93ae9a88fb9e 100644 --- a/jax/_src/scipy/sparse/linalg.py +++ b/jax/_src/scipy/sparse/linalg.py @@ -18,22 +18,22 @@ import numpy as np -import jax -import jax.numpy as jnp -from jax import device_put -from jax import lax -from jax import scipy as jsp -from jax.tree_util import (tree_leaves, tree_map, tree_structure, - tree_reduce, Partial) - +from jax._src import api from jax._src import dtypes +from jax._src import lax +from jax._src import numpy as jnp from jax._src.lax import lax as lax_internal +from jax._src.numpy import einsum as jnp_einsum +from jax._src.scipy import linalg as jsp_linalg +from jax._src.tree_util import (tree_leaves, tree_map, tree_structure, + tree_reduce, Partial) +from jax._src.typing import Array from jax._src.util import safe_map as map _dot = partial(jnp.dot, precision=lax.Precision.HIGHEST) _vdot = partial(jnp.vdot, precision=lax.Precision.HIGHEST) -_einsum = partial(jnp.einsum, precision=lax.Precision.HIGHEST) +_einsum = partial(jnp_einsum.einsum, precision=lax.Precision.HIGHEST) # aliases for working with pytrees @@ -85,7 +85,7 @@ def _normalize_matvec(f): """Normalize an argument for computing matrix-vector products.""" if callable(f): return f - elif isinstance(f, (np.ndarray, jax.Array)): + elif isinstance(f, (np.ndarray, Array)): if f.ndim != 2 or f.shape[0] != f.shape[1]: raise ValueError( f'linear operator must be a square matrix, but has shape: {f.shape}') @@ -127,7 +127,7 @@ def body_fun(value): r0 = _sub(b, A(x0)) p0 = z0 = M(r0) - dtype = jnp.result_type(*tree_leaves(p0)) + dtype = dtypes.result_type(*tree_leaves(p0)) gamma0 = _vdot_real_tree(r0, z0).astype(dtype) initial_value = (x0, r0, gamma0, p0, 0) @@ -177,7 +177,7 @@ def body_fun(value): r0 = _sub(b, A(x0)) rho0 = alpha0 = omega0 = lax_internal._convert_element_type( - 1, *dtypes._lattice_result_type(*tree_leaves(b))) + 1, *dtypes.lattice_result_type(*tree_leaves(b))) initial_value = (x0, r0, r0, alpha0, omega0, rho0, r0, r0, 0) x_final, *_ = lax.while_loop(cond_fun, body_fun, initial_value) @@ -186,7 +186,7 @@ def body_fun(value): def _shapes(pytree): - return map(jnp.shape, tree_leaves(pytree)) + return map(np.shape, tree_leaves(pytree)) def _isolve(_isolve_solve, A, b, x0=None, *, tol=1e-5, atol=0.0, @@ -194,7 +194,7 @@ def _isolve(_isolve_solve, A, b, x0=None, *, tol=1e-5, atol=0.0, if x0 is None: x0 = tree_map(jnp.zeros_like, b) - b, x0 = device_put((b, x0)) + b, x0 = api.device_put((b, x0)) if maxiter is None: size = sum(bi.size for bi in tree_leaves(b)) @@ -296,10 +296,9 @@ def _safe_normalize(x, thresh=None): taken to be 0, and the normalized x to be the zero vector. """ norm = _norm(x) - dtype, weak_type = dtypes._lattice_result_type(*tree_leaves(x)) - dtype = dtypes.canonicalize_dtype(dtype) + dtype, weak_type = dtypes.lattice_result_type(*tree_leaves(x)) if thresh is None: - thresh = jnp.finfo(norm.dtype).eps + thresh = dtypes.finfo(norm.dtype).eps thresh = thresh.astype(dtype).real use_norm = norm > thresh @@ -398,9 +397,8 @@ def _kth_arnoldi_iteration(k, A, M, V, H): subspace is declared to have been found, in which case in which case the new vector is taken to be the zero vector. """ - dtype, _ = dtypes._lattice_result_type(*tree_leaves(V)) - dtype = dtypes.canonicalize_dtype(dtype) - eps = jnp.finfo(dtype).eps + dtype, _ = dtypes.lattice_result_type(*tree_leaves(V)) + eps = dtypes.finfo(dtype).eps v = tree_map(lambda x: x[..., k], V) # Gets V[:, k] v = M(A(v)) @@ -470,7 +468,7 @@ def _gmres_incremental(A, b, x0, unit_residual, residual_norm, ptol, restart, M) lambda x: jnp.pad(x[..., None], ((0, 0),) * x.ndim + ((0, restart),)), unit_residual, ) - dtype = jnp.result_type(*tree_leaves(b)) + dtype = dtypes.result_type(*tree_leaves(b)) # use eye() to avoid constructing a singular matrix in case of early # termination R = jnp.eye(restart, restart + 1, dtype=dtype) @@ -497,7 +495,7 @@ def arnoldi_qr_step(carry): k, residual_norm, V, R, beta_vec, _ = carry del k # Until we figure out how to pass this to the user. - y = jsp.linalg.solve_triangular(R[:, :-1].T, beta_vec[:-1]) + y = jsp_linalg.solve_triangular(R[:, :-1].T, beta_vec[:-1]) dx = tree_map(lambda X: _dot(X[..., :-1], y), V) x = _add(x0, dx) @@ -508,10 +506,10 @@ def arnoldi_qr_step(carry): def _lstsq(a, b): - # faster than jsp.linalg.lstsq + # faster than jsp_linalg.lstsq a2 = _dot(a.T.conj(), a) b2 = _dot(a.T.conj(), b) - return jsp.linalg.solve(a2, b2, assume_a='pos') + return jsp_linalg.solve(a2, b2, assume_a='pos') def _gmres_batched(A, b, x0, unit_residual, residual_norm, ptol, restart, M): @@ -530,8 +528,7 @@ def _gmres_batched(A, b, x0, unit_residual, residual_norm, ptol, restart, M): lambda x: jnp.pad(x[..., None], ((0, 0),) * x.ndim + ((0, restart),)), unit_residual, ) - dtype, weak_type = dtypes._lattice_result_type(*tree_leaves(b)) - dtype = dtypes.canonicalize_dtype(dtype) + dtype, weak_type = dtypes.lattice_result_type(*tree_leaves(b)) H = lax_internal._convert_element_type( jnp.eye(restart, restart + 1, dtype=dtype), weak_type=weak_type) @@ -668,7 +665,7 @@ def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None, A = _normalize_matvec(A) M = _normalize_matvec(M) - b, x0 = device_put((b, x0)) + b, x0 = api.device_put((b, x0)) size = sum(bi.size for bi in tree_leaves(b)) if maxiter is None: diff --git a/jax/_src/scipy/spatial/transform.py b/jax/_src/scipy/spatial/transform.py index e62ea4c3502a..40a15f92897d 100644 --- a/jax/_src/scipy/spatial/transform.py +++ b/jax/_src/scipy/spatial/transform.py @@ -18,9 +18,13 @@ import re import typing +import numpy as np -import jax -import jax.numpy as jnp +from jax._src import config +from jax._src import numpy as jnp +from jax._src.numpy import linalg as jnp_linalg +from jax._src.numpy import vectorize as jnp_vectorize +from jax._src.typing import Array class Rotation(typing.NamedTuple): @@ -58,7 +62,7 @@ class Rotation(typing.NamedTuple): See the scipy :class:`~scipy.spatial.transform.Rotation` documentation for further examples of manipulating Rotation objects. """ - quat: jax.Array + quat: Array @classmethod def concatenate(cls, rotations: typing.Sequence): @@ -66,7 +70,7 @@ def concatenate(cls, rotations: typing.Sequence): return cls(jnp.concatenate([rotation.quat for rotation in rotations])) @classmethod - def from_euler(cls, seq: str, angles: jax.Array, degrees: bool = False): + def from_euler(cls, seq: str, angles: Array, degrees: bool = False): """Initialize from Euler angles.""" num_axes = len(seq) if num_axes < 1 or num_axes > 3: @@ -85,22 +89,22 @@ def from_euler(cls, seq: str, angles: jax.Array, degrees: bool = False): return cls(_elementary_quat_compose(angles, axes, intrinsic, degrees)) @classmethod - def from_matrix(cls, matrix: jax.Array): + def from_matrix(cls, matrix: Array): """Initialize from rotation matrix.""" return cls(_from_matrix(matrix)) @classmethod - def from_mrp(cls, mrp: jax.Array): + def from_mrp(cls, mrp: Array): """Initialize from Modified Rodrigues Parameters (MRPs).""" return cls(_from_mrp(mrp)) @classmethod - def from_quat(cls, quat: jax.Array): + def from_quat(cls, quat: Array): """Initialize from quaternions.""" return cls(_normalize_quaternion(quat)) @classmethod - def from_rotvec(cls, rotvec: jax.Array, degrees: bool = False): + def from_rotvec(cls, rotvec: Array, degrees: bool = False): """Initialize from rotation vectors.""" return cls(_from_rotvec(rotvec, degrees)) @@ -112,7 +116,7 @@ def identity(cls, num: int | None = None, dtype=float): return cls(quat) @classmethod - def random(cls, random_key: jax.Array, num: int | None = None): + def random(cls, random_key: Array, num: int | None = None): """Generate uniformly distributed rotations.""" # Need to implement scipy.stats.special_ortho_group for this to work... raise NotImplementedError() @@ -134,7 +138,7 @@ def __mul__(self, other) -> Rotation: """Compose this rotation with the other.""" return Rotation.from_quat(_compose_quat(self.quat, other.quat)) - def apply(self, vectors: jax.Array, inverse: bool = False) -> jax.Array: + def apply(self, vectors: Array, inverse: bool = False) -> Array: """Apply this rotation to one or more vectors.""" return _apply(self.as_matrix(), vectors, inverse) @@ -152,22 +156,22 @@ def as_euler(self, seq: str, degrees: bool = False): raise ValueError("Expected consecutive axes to be different, " "got {}".format(seq)) axes = jnp.array([_elementary_basis_index(x) for x in seq.lower()]) - with jax.numpy_rank_promotion('allow'): + with config.numpy_rank_promotion('allow'): return _compute_euler_from_quat(self.quat, axes, extrinsic, degrees) - def as_matrix(self) -> jax.Array: + def as_matrix(self) -> Array: """Represent as rotation matrix.""" return _as_matrix(self.quat) - def as_mrp(self) -> jax.Array: + def as_mrp(self) -> Array: """Represent as Modified Rodrigues Parameters (MRPs).""" return _as_mrp(self.quat) - def as_rotvec(self, degrees: bool = False) -> jax.Array: + def as_rotvec(self, degrees: bool = False) -> Array: """Represent as rotation vectors.""" return _as_rotvec(self.quat, degrees) - def as_quat(self, canonical: bool=False, scalar_first: bool=False) -> jax.Array: + def as_quat(self, canonical: bool=False, scalar_first: bool=False) -> Array: """Represent as quaternions.""" quat = _make_canonical(self.quat) if canonical else self.quat if scalar_first: @@ -178,11 +182,11 @@ def inv(self): """Invert this rotation.""" return Rotation(_inv(self.quat)) - def magnitude(self) -> jax.Array: + def magnitude(self) -> Array: """Get the magnitude(s) of the rotation(s).""" return _magnitude(self.quat) - def mean(self, weights: jax.Array | None = None): + def mean(self, weights: Array | None = None): """Get the mean of the rotations.""" w = jnp.ones(self.quat.shape[0], dtype=self.quat.dtype) if weights is None else jnp.asarray(weights, dtype=self.quat.dtype) if w.ndim != 1: @@ -192,8 +196,8 @@ def mean(self, weights: jax.Array | None = None): raise ValueError("Expected `weights` to have number of values " "equal to number of rotations, got " "{} values and {} rotations.".format(w.shape[0], len(self))) - K = jnp.dot(w[jnp.newaxis, :] * self.quat.T, self.quat) - _, v = jnp.linalg.eigh(K) + K = jnp.dot(w[np.newaxis, :] * self.quat.T, self.quat) + _, v = jnp_linalg.eigh(K) return Rotation(v[:, -1]) @property @@ -228,13 +232,13 @@ class Slerp(typing.NamedTuple): [ 0.0000000e+00, 0.0000000e+00, -5.2359891e-01]], dtype=float32) """ - times: jnp.ndarray - timedelta: jnp.ndarray + times: Array + timedelta: Array rotations: Rotation - rotvecs: jnp.ndarray + rotvecs: Array @classmethod - def init(cls, times: jax.Array, rotations: Rotation): + def init(cls, times: Array, rotations: Rotation): if not isinstance(rotations, Rotation): raise TypeError("`rotations` must be a `Rotation` instance.") if rotations.single or len(rotations) == 1: @@ -258,7 +262,7 @@ def init(cls, times: jax.Array, rotations: Rotation): rotations=new_rotations, rotvecs=(new_rotations.inv() * Rotation(rotations.as_quat()[1:])).as_rotvec()) - def __call__(self, times: jax.Array): + def __call__(self, times: Array): """Interpolate rotations.""" compute_times = jnp.asarray(times, dtype=self.times.dtype) if compute_times.ndim > 1: @@ -273,13 +277,13 @@ def __call__(self, times: jax.Array): return result -@functools.partial(jnp.vectorize, signature='(m,m),(m),()->(m)') -def _apply(matrix: jax.Array, vector: jax.Array, inverse: bool) -> jax.Array: +@functools.partial(jnp_vectorize.vectorize, signature='(m,m),(m),()->(m)') +def _apply(matrix: Array, vector: Array, inverse: bool) -> Array: return jnp.where(inverse, matrix.T, matrix) @ vector -@functools.partial(jnp.vectorize, signature='(m)->(n,n)') -def _as_matrix(quat: jax.Array) -> jax.Array: +@functools.partial(jnp_vectorize.vectorize, signature='(m)->(n,n)') +def _as_matrix(quat: Array) -> Array: x = quat[0] y = quat[1] z = quat[2] @@ -299,15 +303,15 @@ def _as_matrix(quat: jax.Array) -> jax.Array: [2 * (xz - yw), 2 * (yz + xw), - x2 - y2 + z2 + w2]]) -@functools.partial(jnp.vectorize, signature='(m)->(n)') -def _as_mrp(quat: jax.Array) -> jax.Array: +@functools.partial(jnp_vectorize.vectorize, signature='(m)->(n)') +def _as_mrp(quat: Array) -> Array: sign = jnp.where(quat[3] < 0, -1., 1.) denominator = 1. + sign * quat[3] return sign * quat[:3] / denominator -@functools.partial(jnp.vectorize, signature='(m),()->(n)') -def _as_rotvec(quat: jax.Array, degrees: bool) -> jax.Array: +@functools.partial(jnp_vectorize.vectorize, signature='(m),()->(n)') +def _as_rotvec(quat: Array, degrees: bool) -> Array: quat = jnp.where(quat[3] < 0, -quat, quat) # w > 0 to ensure 0 <= angle <= pi angle = 2. * jnp.arctan2(_vector_norm(quat[:3]), quat[3]) angle2 = angle * angle @@ -318,16 +322,16 @@ def _as_rotvec(quat: jax.Array, degrees: bool) -> jax.Array: return scale * jnp.array(quat[:3]) -@functools.partial(jnp.vectorize, signature='(n),(n)->(n)') -def _compose_quat(p: jax.Array, q: jax.Array) -> jax.Array: +@functools.partial(jnp_vectorize.vectorize, signature='(n),(n)->(n)') +def _compose_quat(p: Array, q: Array) -> Array: cross = jnp.cross(p[:3], q[:3]) return jnp.array([p[3]*q[0] + q[3]*p[0] + cross[0], p[3]*q[1] + q[3]*p[1] + cross[1], p[3]*q[2] + q[3]*p[2] + cross[2], p[3]*q[3] - p[0]*q[0] - p[1]*q[1] - p[2]*q[2]]) -@functools.partial(jnp.vectorize, signature='(m),(l),(),()->(n)') -def _compute_euler_from_quat(quat: jax.Array, axes: jax.Array, extrinsic: bool, degrees: bool) -> jax.Array: +@functools.partial(jnp_vectorize.vectorize, signature='(m),(l),(),()->(n)') +def _compute_euler_from_quat(quat: Array, axes: Array, extrinsic: bool, degrees: bool) -> Array: angle_first = jnp.where(extrinsic, 0, 2) angle_third = jnp.where(extrinsic, 2, 0) axes = jnp.where(extrinsic, axes, axes[::-1]) @@ -344,7 +348,7 @@ def _compute_euler_from_quat(quat: jax.Array, axes: jax.Array, extrinsic: bool, d = jnp.where(symmetric, quat[k] * sign, quat[k] * sign - quat[i]) angles = jnp.empty(3, dtype=quat.dtype) angles = angles.at[1].set(2 * jnp.arctan2(jnp.hypot(c, d), jnp.hypot(a, b))) - case = jnp.where(jnp.abs(angles[1] - jnp.pi) <= eps, 2, 0) + case = jnp.where(jnp.abs(angles[1] - np.pi) <= eps, 2, 0) case = jnp.where(jnp.abs(angles[1]) <= eps, 1, case) half_sum = jnp.arctan2(b, a) half_diff = jnp.arctan2(d, c) @@ -352,8 +356,8 @@ def _compute_euler_from_quat(quat: jax.Array, axes: jax.Array, extrinsic: bool, angles = angles.at[angle_first].set(jnp.where(case == 0, half_sum - half_diff, angles[angle_first])) angles = angles.at[angle_third].set(jnp.where(case == 0, half_sum + half_diff, angles[angle_third])) angles = angles.at[angle_third].set(jnp.where(symmetric, angles[angle_third], angles[angle_third] * sign)) - angles = angles.at[1].set(jnp.where(symmetric, angles[1], angles[1] - jnp.pi / 2)) - angles = (angles + jnp.pi) % (2 * jnp.pi) - jnp.pi + angles = angles.at[1].set(jnp.where(symmetric, angles[1], angles[1] - np.pi / 2)) + angles = (angles + np.pi) % (2 * np.pi) - np.pi return jnp.where(degrees, jnp.rad2deg(angles), angles) @@ -367,8 +371,8 @@ def _elementary_basis_index(axis: str) -> int: raise ValueError(f"Expected axis to be from ['x', 'y', 'z'], got {axis}") -@functools.partial(jnp.vectorize, signature=('(m),(m),(),()->(n)')) -def _elementary_quat_compose(angles: jax.Array, axes: jax.Array, intrinsic: bool, degrees: bool) -> jax.Array: +@functools.partial(jnp_vectorize.vectorize, signature=('(m),(m),(),()->(n)')) +def _elementary_quat_compose(angles: Array, axes: Array, intrinsic: bool, degrees: bool) -> Array: angles = jnp.where(degrees, jnp.deg2rad(angles), angles) result = _make_elementary_quat(axes[0], angles[0]) for idx in range(1, len(axes)): @@ -377,8 +381,8 @@ def _elementary_quat_compose(angles: jax.Array, axes: jax.Array, intrinsic: bool return result -@functools.partial(jnp.vectorize, signature=('(m),()->(n)')) -def _from_rotvec(rotvec: jax.Array, degrees: bool) -> jax.Array: +@functools.partial(jnp_vectorize.vectorize, signature=('(m),()->(n)')) +def _from_rotvec(rotvec: Array, degrees: bool) -> Array: rotvec = jnp.where(degrees, jnp.deg2rad(rotvec), rotvec) angle = _vector_norm(rotvec) angle2 = angle * angle @@ -388,8 +392,8 @@ def _from_rotvec(rotvec: jax.Array, degrees: bool) -> jax.Array: return jnp.hstack([scale * rotvec, jnp.cos(angle / 2)]) -@functools.partial(jnp.vectorize, signature=('(m,m)->(n)')) -def _from_matrix(matrix: jax.Array) -> jax.Array: +@functools.partial(jnp_vectorize.vectorize, signature=('(m,m)->(n)')) +def _from_matrix(matrix: Array) -> Array: matrix_trace = matrix[0, 0] + matrix[1, 1] + matrix[2, 2] decision = jnp.array([matrix[0, 0], matrix[1, 1], matrix[2, 2], matrix_trace], dtype=matrix.dtype) choice = jnp.argmax(decision) @@ -410,42 +414,42 @@ def _from_matrix(matrix: jax.Array) -> jax.Array: return _normalize_quaternion(quat) -@functools.partial(jnp.vectorize, signature='(m)->(n)') -def _from_mrp(mrp: jax.Array) -> jax.Array: +@functools.partial(jnp_vectorize.vectorize, signature='(m)->(n)') +def _from_mrp(mrp: Array) -> Array: mrp_squared_plus_1 = jnp.dot(mrp, mrp) + 1 return jnp.hstack([2 * mrp[:3], (2 - mrp_squared_plus_1)]) / mrp_squared_plus_1 -@functools.partial(jnp.vectorize, signature='(n)->(n)') -def _inv(quat: jax.Array) -> jax.Array: +@functools.partial(jnp_vectorize.vectorize, signature='(n)->(n)') +def _inv(quat: Array) -> Array: return quat * jnp.array([-1, -1, -1, 1], dtype=quat.dtype) -@functools.partial(jnp.vectorize, signature='(n)->()') -def _magnitude(quat: jax.Array) -> jax.Array: +@functools.partial(jnp_vectorize.vectorize, signature='(n)->()') +def _magnitude(quat: Array) -> Array: return 2. * jnp.arctan2(_vector_norm(quat[:3]), jnp.abs(quat[3])) -@functools.partial(jnp.vectorize, signature='(),()->(n)') -def _make_elementary_quat(axis: int, angle: jax.Array) -> jax.Array: +@functools.partial(jnp_vectorize.vectorize, signature='(),()->(n)') +def _make_elementary_quat(axis: int, angle: Array) -> Array: quat = jnp.zeros(4, dtype=angle.dtype) quat = quat.at[3].set(jnp.cos(angle / 2.)) quat = quat.at[axis].set(jnp.sin(angle / 2.)) return quat -@functools.partial(jnp.vectorize, signature='(n)->(n)') -def _normalize_quaternion(quat: jax.Array) -> jax.Array: +@functools.partial(jnp_vectorize.vectorize, signature='(n)->(n)') +def _normalize_quaternion(quat: Array) -> Array: return quat / _vector_norm(quat) -@functools.partial(jnp.vectorize, signature='(n)->()') -def _vector_norm(vector: jax.Array) -> jax.Array: +@functools.partial(jnp_vectorize.vectorize, signature='(n)->()') +def _vector_norm(vector: Array) -> Array: return jnp.sqrt(jnp.dot(vector, vector)) -@functools.partial(jnp.vectorize, signature='(n)->(n)') -def _make_canonical(quat: jax.Array) -> jax.Array: +@functools.partial(jnp_vectorize.vectorize, signature='(n)->(n)') +def _make_canonical(quat: Array) -> Array: is_neg = quat < 0 is_zero = quat == 0 diff --git a/jax/_src/scipy/special.py b/jax/_src/scipy/special.py index a24736ccfec0..e8d581162881 100644 --- a/jax/_src/scipy/special.py +++ b/jax/_src/scipy/special.py @@ -20,17 +20,20 @@ import numpy as np -import jax.numpy as jnp -from jax import jit -from jax import jvp -from jax import vmap -from jax import lax - +from jax._src import api_util +from jax._src import config from jax._src import core from jax._src import custom_derivatives from jax._src import deprecations +from jax._src import dispatch from jax._src import dtypes +from jax._src import lax +from jax._src import numpy as jnp +from jax._src.numpy.ufuncs import isposinf, isneginf, sinc +from jax._src.api import jit, jvp, vmap from jax._src.lax.lax import _const as _lax_const +from jax._src.numpy import einsum as jnp_einsum +from jax._src.numpy import vectorize as jnp_vectorize from jax._src.numpy.util import promote_args_inexact, promote_dtypes_inexact from jax._src.ops import special as ops_special from jax._src.third_party.scipy.betaln import betaln as _betaln_impl @@ -38,7 +41,6 @@ from jax._src.nn.functions import softmax as nn_softmax from jax._src.nn.functions import log_softmax as nn_log_softmax - def gammaln(x: ArrayLike) -> Array: r"""Natural log of the absolute value of the gamma function. @@ -101,6 +103,8 @@ def gammasgn(x: ArrayLike) -> Array: - :func:`jax.scipy.special.gammaln`: the natural log of the gamma function """ x, = promote_args_inexact("gammasgn", x) + if dtypes.issubdtype(x.dtype, np.complexfloating): + raise ValueError("gammasgn does not support complex-valued inputs.") typ = x.dtype.type floor_x = lax.floor(x) x_negative = x < 0 @@ -235,6 +239,8 @@ def beta(a: ArrayLike, b: ArrayLike) -> Array: - :func:`jax.scipy.special.betaln` """ a, b = promote_args_inexact("beta", a, b) + if dtypes.issubdtype(a.dtype, np.complexfloating): + raise ValueError("beta does not support complex-valued inputs.") sign = gammasgn(a) * gammasgn(b) * gammasgn(a + b) return sign * lax.exp(betaln(a, b)) @@ -246,7 +252,7 @@ def betainc(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array: .. math:: - \mathrm{betainc}(a, b, x) = B(a, b)\int_0^x t^{a-1}(1-t^{b-1})\mathrm{d}t + \mathrm{betainc}(a, b, x) = \frac{1}{B(a, b)}\int_0^x t^{a-1}(1-t)^{b-1}\mathrm{d}t where :math:`B(a, b)` is the :func:`~jax.scipy.special.beta` function. @@ -263,6 +269,8 @@ def betainc(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array: - :func:`jax.scipy.special.betaln` """ a, b, x = promote_args_inexact("betainc", a, b, x) + if dtypes.issubdtype(x.dtype, np.complexfloating): + raise ValueError("betainc does not support complex-valued inputs.") return lax.betainc(a, b, x) @@ -568,6 +576,8 @@ def entr(x: ArrayLike) -> Array: - :func:`jax.scipy.special.rel_entr` """ x, = promote_args_inexact("entr", x) + if dtypes.issubdtype(x.dtype, np.complexfloating): + raise ValueError("entr does not support complex-valued inputs.") return lax.select(lax.lt(x, _lax_const(x, 0)), lax.full_like(x, -np.inf), lax.neg(_xlogx(x))) @@ -641,6 +651,8 @@ def kl_div( - :func:`jax.scipy.special.rel_entr` """ p, q = promote_args_inexact("kl_div", p, q) + if dtypes.issubdtype(p.dtype, np.complexfloating): + raise ValueError("kl_div does not support complex-valued inputs.") return rel_entr(p, q) - p + q @@ -672,6 +684,8 @@ def rel_entr( - :func:`jax.scipy.special.kl_div` """ p, q = promote_args_inexact("rel_entr", p, q) + if dtypes.issubdtype(p.dtype, np.complexfloating): + raise ValueError("rel_entr does not support complex-valued inputs.") zero = _lax_const(p, 0.0) both_gt_zero_mask = lax.bitwise_and(lax.gt(p, zero), lax.gt(q, zero)) one_zero_mask = lax.bitwise_and(lax.eq(p, zero), lax.ge(q, zero)) @@ -680,13 +694,13 @@ def rel_entr( safe_q = jnp.where(both_gt_zero_mask, q, 1) log_val = lax.sub(_xlogx(safe_p), xlogy(safe_p, safe_q)) result = jnp.where( - both_gt_zero_mask, log_val, jnp.where(one_zero_mask, zero, jnp.inf) + both_gt_zero_mask, log_val, jnp.where(one_zero_mask, zero, np.inf) ) return result # coefs of (2k)! / B_{2k} where B are bernoulli numbers # those numbers are obtained using https://www.wolframalpha.com -_BERNOULLI_COEFS = [ +_BERNOULLI_COEFS = np.array([ 12, -720, 30240, @@ -703,7 +717,7 @@ def rel_entr( -37893265687455865519472640000000 / 3392780147, 759790291646040068357842010112000000 / 1723168255201, -134196726836183700385281186201600000000 / 7709321041217, -] +]) @custom_derivatives.custom_jvp @@ -746,7 +760,7 @@ def _zeta_series_expansion(x: ArrayLike, q: ArrayLike | None = None) -> Array: dtype = lax.dtype(a).type s_, a_ = jnp.expand_dims(s, -1), jnp.expand_dims(a, -1) # precision ~ N, M - N = M = dtype(8) if lax.dtype(a) == jnp.float32 else dtype(16) + N = M = dtype(8) if lax.dtype(a) == np.float32 else dtype(16) assert M <= len(_BERNOULLI_COEFS) k = jnp.expand_dims(np.arange(N, dtype=N.dtype), tuple(range(a.ndim))) S = jnp.sum((a_ + k) ** -s_, -1) @@ -755,7 +769,7 @@ def _zeta_series_expansion(x: ArrayLike, q: ArrayLike | None = None) -> Array: m = jnp.expand_dims(np.arange(2 * M, dtype=M.dtype), tuple(range(s.ndim))) s_over_a = (s_ + m) / (a_ + N) T1 = jnp.cumprod(s_over_a, -1)[..., ::2] - T1 = jnp.clip(T1, max=jnp.finfo(dtype).max) + T1 = jnp.clip(T1, max=dtypes.finfo(dtype).max) coefs = np.expand_dims(np.array(_BERNOULLI_COEFS[:T1.shape[-1]], dtype=dtype), tuple(range(a.ndim))) T1 = T1 / coefs @@ -772,9 +786,10 @@ def polygamma(n: ArrayLike, x: ArrayLike) -> Array: .. math:: - \mathrm{polygamma}(n, x) = \psi^{(n)}(x) = \frac{\mathrm{d}^n}{\mathrm{d}x^n}\log \Gamma(x) + \mathrm{polygamma}(n, x) = \psi^{(n)}(x) = \frac{\mathrm{d}^{n+1}}{\mathrm{d}x^{n+1}} \log \Gamma(x) - where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function. + where :math:`\psi` is the :func:`~jax.scipy.special.digamma` function and + :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function. Args: n: arraylike, integer-valued. The order of the derivative. @@ -787,8 +802,13 @@ def polygamma(n: ArrayLike, x: ArrayLike) -> Array: - :func:`jax.scipy.special.gamma` - :func:`jax.scipy.special.digamma` """ - assert jnp.issubdtype(lax.dtype(n), jnp.integer) + if not dtypes.issubdtype(lax.dtype(n), np.integer): + raise ValueError( + f"Argument `n` to polygamma must be of integer type. Got dtype {lax.dtype(n)}." + ) n_arr, x_arr = promote_args_inexact("polygamma", n, x) + if dtypes.issubdtype(x_arr.dtype, np.complexfloating): + raise ValueError("polygamma does not support complex-valued inputs.") return lax.polygamma(n_arr, x_arr) @@ -896,7 +916,7 @@ def ndtr(x: ArrayLike) -> Array: """ x = jnp.asarray(x) dtype = lax.dtype(x) - if dtype not in (jnp.float32, jnp.float64): + if dtype not in (np.float32, np.float64): raise TypeError( "x.dtype={} is not supported, see docstring for supported types." .format(dtype)) @@ -938,7 +958,7 @@ def ndtri(p: ArrayLike) -> Array: TypeError: if `p` is not floating-type. """ dtype = lax.dtype(p) - if dtype not in (jnp.float32, jnp.float64): + if dtype not in (np.float32, np.float64): raise TypeError( "x.dtype={} is not supported, see docstring for supported types." .format(dtype)) @@ -947,71 +967,62 @@ def ndtri(p: ArrayLike) -> Array: def _ndtri(p: ArrayLike) -> Array: """Implements ndtri core logic.""" + dtype = lax.dtype(p).type + shape = np.shape(p) # Constants used in piece-wise rational approximations. Taken from the cephes # library: # https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html - p0 = list(reversed([-5.99633501014107895267E1, - 9.80010754185999661536E1, - -5.66762857469070293439E1, - 1.39312609387279679503E1, - -1.23916583867381258016E0])) - q0 = list(reversed([1.0, - 1.95448858338141759834E0, - 4.67627912898881538453E0, - 8.63602421390890590575E1, - -2.25462687854119370527E2, - 2.00260212380060660359E2, - -8.20372256168333339912E1, - 1.59056225126211695515E1, - -1.18331621121330003142E0])) - p1 = list(reversed([4.05544892305962419923E0, - 3.15251094599893866154E1, - 5.71628192246421288162E1, - 4.40805073893200834700E1, - 1.46849561928858024014E1, - 2.18663306850790267539E0, - -1.40256079171354495875E-1, - -3.50424626827848203418E-2, - -8.57456785154685413611E-4])) - q1 = list(reversed([1.0, - 1.57799883256466749731E1, - 4.53907635128879210584E1, - 4.13172038254672030440E1, - 1.50425385692907503408E1, - 2.50464946208309415979E0, - -1.42182922854787788574E-1, - -3.80806407691578277194E-2, - -9.33259480895457427372E-4])) - p2 = list(reversed([3.23774891776946035970E0, - 6.91522889068984211695E0, - 3.93881025292474443415E0, - 1.33303460815807542389E0, - 2.01485389549179081538E-1, - 1.23716634817820021358E-2, - 3.01581553508235416007E-4, - 2.65806974686737550832E-6, - 6.23974539184983293730E-9])) - q2 = list(reversed([1.0, - 6.02427039364742014255E0, - 3.67983563856160859403E0, - 1.37702099489081330271E0, - 2.16236993594496635890E-1, - 1.34204006088543189037E-2, - 3.28014464682127739104E-4, - 2.89247864745380683936E-6, - 6.79019408009981274425E-9])) - - dtype = lax.dtype(p).type - shape = jnp.shape(p) - - def _create_polynomial(var, coeffs): - """Compute n_th order polynomial via Horner's method.""" - coeffs = np.array(coeffs, dtype) - if not coeffs.size: - return jnp.zeros_like(var) - return coeffs[0] + _create_polynomial(var, coeffs[1:]) * var - + p0 = np.array([-5.99633501014107895267E1, + 9.80010754185999661536E1, + -5.66762857469070293439E1, + 1.39312609387279679503E1, + -1.23916583867381258016E0], dtype=dtype) + q0 = np.array([1.0, + 1.95448858338141759834E0, + 4.67627912898881538453E0, + 8.63602421390890590575E1, + -2.25462687854119370527E2, + 2.00260212380060660359E2, + -8.20372256168333339912E1, + 1.59056225126211695515E1, + -1.18331621121330003142E0], dtype=dtype) + p1 = np.array([4.05544892305962419923E0, + 3.15251094599893866154E1, + 5.71628192246421288162E1, + 4.40805073893200834700E1, + 1.46849561928858024014E1, + 2.18663306850790267539E0, + -1.40256079171354495875E-1, + -3.50424626827848203418E-2, + -8.57456785154685413611E-4], dtype=dtype) + q1 = np.array([1.0, + 1.57799883256466749731E1, + 4.53907635128879210584E1, + 4.13172038254672030440E1, + 1.50425385692907503408E1, + 2.50464946208309415979E0, + -1.42182922854787788574E-1, + -3.80806407691578277194E-2, + -9.33259480895457427372E-4], dtype=dtype) + p2 = np.array([3.23774891776946035970E0, + 6.91522889068984211695E0, + 3.93881025292474443415E0, + 1.33303460815807542389E0, + 2.01485389549179081538E-1, + 1.23716634817820021358E-2, + 3.01581553508235416007E-4, + 2.65806974686737550832E-6, + 6.23974539184983293730E-9], dtype=dtype) + q2 = np.array([1.0, + 6.02427039364742014255E0, + 3.67983563856160859403E0, + 1.37702099489081330271E0, + 2.16236993594496635890E-1, + 1.34204006088543189037E-2, + 3.28014464682127739104E-4, + 2.89247864745380683936E-6, + 6.79019408009981274425E-9], dtype=dtype) maybe_complement_p = jnp.where(p > dtype(-np.expm1(-2.)), dtype(1.) - p, p) # Write in an arbitrary value in place of 0 for p since 0 will cause NaNs @@ -1025,8 +1036,7 @@ def _create_polynomial(var, coeffs): # Compute x for p > exp(-2): x/sqrt(2pi) = w + w**3 P0(w**2)/Q0(w**2). w = sanitized_mcp - dtype(0.5) ww = lax.square(w) - x_for_big_p = w + w * ww * (_create_polynomial(ww, p0) - / _create_polynomial(ww, q0)) + x_for_big_p = w + w * ww * (jnp.polyval(p0, ww) / jnp.polyval(q0, ww)) x_for_big_p *= -dtype(np.sqrt(2. * np.pi)) # Compute x for p <= exp(-2): x = z - log(z)/z - (1/z) P(1/z) / Q(1/z), @@ -1034,12 +1044,8 @@ def _create_polynomial(var, coeffs): # arrays based on whether p < exp(-32). z = lax.sqrt(dtype(-2.) * lax.log(sanitized_mcp)) first_term = z - lax.log(z) / z - second_term_small_p = ( - _create_polynomial(dtype(1.) / z, p2) / - _create_polynomial(dtype(1.) / z, q2) / z) - second_term_otherwise = ( - _create_polynomial(dtype(1.) / z, p1) / - _create_polynomial(dtype(1.) / z, q1) / z) + second_term_small_p = jnp.polyval(p2, 1 / z) / jnp.polyval(q2, 1 / z) / z + second_term_otherwise = jnp.polyval(p1, 1 / z) / jnp.polyval(q1, 1 / z) / z x_for_small_p = first_term - second_term_small_p x_otherwise = first_term - second_term_otherwise @@ -1048,10 +1054,17 @@ def _create_polynomial(var, coeffs): jnp.where(z >= dtype(8.0), x_for_small_p, x_otherwise)) x = jnp.where(p > dtype(1. - np.exp(-2.)), x, -x) - infinity = jnp.full(shape, dtype(np.inf)) - x_fix_boundaries = jnp.where( - p == dtype(0.0), -infinity, jnp.where(p == dtype(1.0), infinity, x)) - return x_fix_boundaries + with config.debug_infs(False): + infinity = jnp.full(shape, dtype(np.inf)) + x = jnp.where( + p == dtype(0.0), -infinity, jnp.where(p == dtype(1.0), infinity, x)) + if not isinstance(x, core.Tracer): + try: + dispatch.check_special("ndtri", [x]) + except api_util.InternalFloatingPointError as e: + raise FloatingPointError( + f"invalid value ({e.ty}) encountered in ndtri.") from None + return x @partial(custom_derivatives.custom_jvp, nondiff_argnums=(1,)) @@ -1127,10 +1140,10 @@ def log_ndtr(x: ArrayLike, series_order: int = 3) -> Array: x_arr = jnp.asarray(x) dtype = lax.dtype(x_arr) - if dtype == jnp.float64: + if dtype == np.float64: lower_segment: np.ndarray = _LOGNDTR_FLOAT64_LOWER upper_segment: np.ndarray = _LOGNDTR_FLOAT64_UPPER - elif dtype == jnp.float32: + elif dtype == np.float32: lower_segment = _LOGNDTR_FLOAT32_LOWER upper_segment = _LOGNDTR_FLOAT32_UPPER else: @@ -1342,7 +1355,7 @@ def _bessel_jn(z: ArrayLike, *, v: int, n_iter: int=50) -> Array: return j_vals -@partial(jit, static_argnames=["v", "n_iter"]) +@jit(static_argnames=["v", "n_iter"]) def bessel_jn(z: ArrayLike, *, v: int, n_iter: int=50) -> Array: """Bessel function of the first kind of integer order and real argument. @@ -1429,13 +1442,13 @@ def _gen_recurrence_mask( i, j, k = jnp.ogrid[:l_max + 1, :l_max + 1, :l_max + 1] mask = (i + j - k == 0).astype(dtype) - d0_mask_3d = jnp.einsum('jk,ijk->ijk', d0_mask, mask) - d1_mask_3d = jnp.einsum('jk,ijk->ijk', d1_mask, mask) + d0_mask_3d = jnp_einsum.einsum('jk,ijk->ijk', d0_mask, mask) + d1_mask_3d = jnp_einsum.einsum('jk,ijk->ijk', d1_mask, mask) return (d0_mask_3d, d1_mask_3d) -@partial(jit, static_argnums=(2)) +@jit(static_argnums=(2)) def _gen_derivatives(p: Array, x: Array, is_normalized: bool) -> Array: @@ -1472,14 +1485,14 @@ def _gen_derivatives(p: Array, l_vec = jnp.arange(1, num_l - 1, dtype=x.dtype) p_p1 = p[1, 1:num_l - 1, :] coeff = -1.0 / ((l_vec + 1) * l_vec) - update_p_p1 = jnp.einsum('i,ij->ij', coeff, p_p1) + update_p_p1 = jnp_einsum.einsum('i,ij->ij', coeff, p_p1) p_mm2_lm1 = p_mm2_lm1.at[1, 2:num_l, :].set(update_p_p1) if num_l > 2: l_vec = jnp.arange(2, num_l - 1, dtype=x.dtype) p_p2 = p[2, 2:num_l - 1, :] coeff = 1.0 / ((l_vec + 2) * (l_vec + 1) * l_vec * (l_vec - 1)) - update_p_p2 = jnp.einsum('i,ij->ij', coeff, p_p2) + update_p_p2 = jnp_einsum.einsum('i,ij->ij', coeff, p_p2) p_mm2_lm1 = p_mm2_lm1.at[0, 3:num_l, :].set(update_p_p2) m_mat, l_mat = jnp.meshgrid( @@ -1501,8 +1514,8 @@ def _gen_derivatives(p: Array, c0_masked = c0_masked.at[1, :].set(zero_vec) # p_l^{m-1}. - p_mm1_l = (jnp.einsum('ij,ijk->ijk', a0_masked, p_m_lm1) + - jnp.einsum('ij,ijk->ijk', c0_masked, p_mm2_lm1)) + p_mm1_l = (jnp_einsum.einsum('ij,ijk->ijk', a0_masked, p_m_lm1) + + jnp_einsum.einsum('ij,ijk->ijk', c0_masked, p_mm2_lm1)) d0 = -0.5 / (m_mat + 1.0) d0_masked = coeff_zeros.at[upper_0_indices].set(d0[upper_0_indices]) @@ -1510,27 +1523,27 @@ def _gen_derivatives(p: Array, e0_masked = coeff_zeros.at[upper_0_indices].set(e0[upper_0_indices]) # p_l^{m+1}. - p_mp1_l = (jnp.einsum('ij,ijk->ijk', d0_masked, p_mp2_lm1) + - jnp.einsum('ij,ijk->ijk', e0_masked, p_m_lm1)) + p_mp1_l = (jnp_einsum.einsum('ij,ijk->ijk', d0_masked, p_mp2_lm1) + + jnp_einsum.einsum('ij,ijk->ijk', e0_masked, p_m_lm1)) f0 = b0 * (l_mat - m_mat + 1.0) / 2.0 f0_masked = coeff_zeros.at[upper_0_indices].set(f0[upper_0_indices]) - p_derivative = jnp.einsum('ij,ijk->ijk', f0_masked, p_mm1_l) - 0.5 * p_mp1_l + p_derivative = jnp_einsum.einsum('ij,ijk->ijk', f0_masked, p_mm1_l) - 0.5 * p_mp1_l # Special treatment of the singularity at m = 1. if num_m > 1: l_vec = jnp.arange(num_l, dtype=p.dtype) - g0 = jnp.einsum('i,ij->ij', (l_vec + 1) * l_vec, p[0, :, :]) + g0 = jnp_einsum.einsum('i,ij->ij', (l_vec + 1) * l_vec, p[0, :, :]) if num_l > 2: g0 = g0 - p[2, :, :] - p_derivative_m0 = jnp.einsum('j,ij->ij', 0.5 / jnp.sqrt(1 - x * x), g0) + p_derivative_m0 = jnp_einsum.einsum('j,ij->ij', 0.5 / jnp.sqrt(1 - x * x), g0) p_derivative = p_derivative.at[1, :, :].set(p_derivative_m0) p_derivative = p_derivative.at[1, 0, :].set(0) return p_derivative -@partial(jit, static_argnums=(0, 2)) +@jit(static_argnums=(0, 2)) def _gen_associated_legendre(l_max: int, x: Array, is_normalized: bool) -> Array: @@ -1584,7 +1597,7 @@ def _gen_associated_legendre(l_max: int, a_idx = jnp.arange(1, l_max + 1, dtype=x.dtype) b_idx = jnp.arange(l_max, dtype=x.dtype) if is_normalized: - initial_value: ArrayLike = 0.5 / jnp.sqrt(jnp.pi) # The initial value p(0,0). + initial_value: ArrayLike = 0.5 / jnp.sqrt(np.pi) # The initial value p(0,0). f_a = jnp.cumprod(-1 * jnp.sqrt(1.0 + 0.5 / a_idx)) f_b = jnp.sqrt(2.0 * b_idx + 3.0) else: @@ -1598,13 +1611,13 @@ def _gen_associated_legendre(l_max: int, y = jnp.cumprod( jnp.broadcast_to(jnp.sqrt(1.0 - x * x), (l_max, x.shape[0])), axis=0) - p_diag = initial_value * jnp.einsum('i,ij->ij', f_a, y) + p_diag = initial_value * jnp_einsum.einsum('i,ij->ij', f_a, y) diag_indices = jnp.diag_indices(l_max + 1) p = p.at[(diag_indices[0][1:], diag_indices[1][1:])].set(p_diag) # Compute the off-diagonal entries with recurrence. - p_offdiag = jnp.einsum('ij,ij->ij', - jnp.einsum('i,j->ij', f_b, x), + p_offdiag = jnp_einsum.einsum('ij,ij->ij', + jnp_einsum.einsum('i,j->ij', f_b, x), p[jnp.diag_indices(l_max)]) offdiag_indices = (diag_indices[0][:l_max], diag_indices[1][:l_max] + 1) p = p.at[offdiag_indices].set(p_offdiag) @@ -1616,16 +1629,16 @@ def _gen_associated_legendre(l_max: int, def body_fun(i, p_val): coeff_0 = d0_mask_3d[i] coeff_1 = d1_mask_3d[i] - h = (jnp.einsum('ij,ijk->ijk', + h = (jnp_einsum.einsum('ij,ijk->ijk', coeff_0, - jnp.einsum( + jnp_einsum.einsum( 'ijk,k->ijk', jnp.roll(p_val, shift=1, axis=1), x)) - - jnp.einsum('ij,ijk->ijk', coeff_1, jnp.roll(p_val, shift=2, axis=1))) + jnp_einsum.einsum('ij,ijk->ijk', coeff_1, jnp.roll(p_val, shift=2, axis=1))) p_val = p_val + h return p_val # TODO(jakevdp): use some sort of fixed-point procedure here instead? - p = p.astype(jnp.result_type(p, x, d0_mask_3d)) + p = p.astype(dtypes.result_type(p, x, d0_mask_3d)) if l_max > 1: p = lax.fori_loop(lower=2, upper=l_max+1, body_fun=body_fun, init_val=p) @@ -1654,7 +1667,7 @@ def lpmn(m: int, n: int, z: Array) -> tuple[Array, Array]: NotImplementedError if `m!=n`. """ dtype = lax.dtype(z) - if dtype not in (jnp.float32, jnp.float64): + if dtype not in (np.float32, np.float64): raise TypeError( 'z.dtype={} is not supported, see docstring for supported types.' .format(dtype)) @@ -1711,7 +1724,7 @@ def lpmn_values(m: int, n: int, z: Array, is_normalized: bool) -> Array: NotImplementedError if `m!=n`. """ dtype = lax.dtype(z) - if dtype not in (jnp.float32, jnp.float64): + if dtype not in (np.float32, np.float64): raise TypeError( 'z.dtype={} is not supported, see docstring for supported types.' .format(dtype)) @@ -1731,7 +1744,7 @@ def lpmn_values(m: int, n: int, z: Array, is_normalized: bool) -> Array: -@partial(jit, static_argnums=(4,)) +@jit(static_argnums=(4,)) def _sph_harm(n: Array, m: Array, theta: Array, @@ -1868,15 +1881,15 @@ def sph_harm(m: Array, def _expint1(x: Array) -> Array: # 0 < x <= 2 - A = [ + A = np.array([ -5.350447357812542947283e0, 2.185049168816613393830e2, -4.176572384826693777058e3, 5.541176756393557601232e4, -3.313381331178144034309e5, 1.592627163384945414220e6, - ] - B = [ + ], dtype=x.dtype) + B = np.array([ 1.0, -5.250547959112862969197e1, 1.259616186786790571525e3, @@ -1884,27 +1897,23 @@ def _expint1(x: Array) -> Array: 1.493062117002725991967e5, -7.294949239640527645655e5, 1.592627163384945429726e6, - ] - A_arr = jnp.array(A, dtype=x.dtype) - B_arr = jnp.array(B, dtype=x.dtype) - f = jnp.polyval(A_arr, x) / jnp.polyval(B_arr, x) - return x * f + jnp.euler_gamma + jnp.log(x) + ], dtype=x.dtype) + f = jnp.polyval(A, x) / jnp.polyval(B, x) + return x * f + np.euler_gamma + jnp.log(x) -def _eval_expint_k(A: list[float], B: list[float], x: Array) -> Array: +def _eval_expint_k(A: ArrayLike, B: ArrayLike, x: Array) -> Array: # helper function for all subsequent intervals - A_arr = jnp.array(A, dtype=x.dtype) - B_arr = jnp.array(B, dtype=x.dtype) one = _lax_const(x, 1.0) w = one / x - f = jnp.polyval(A_arr, w) / jnp.polyval(B_arr, w) + f = jnp.polyval(A, w) / jnp.polyval(B, w) f = w * f + one return jnp.exp(x) * w * f def _expint2(x: Array) -> Array: # 2 <= x < 4 - A = [ + A = np.array([ 1.981808503259689673238e-2, -1.271645625984917501326e0, -2.088160335681228318920e0, @@ -1913,8 +1922,8 @@ def _expint2(x: Array) -> Array: 4.665623805935891391017e-2, -1.545042679673485262580e-3, 7.059980605299617478514e-5, - ] - B = [ + ], dtype=x.dtype) + B = np.array([ 1.0, 1.476498670914921440652e0, 5.629177174822436244827e-1, @@ -1923,13 +1932,13 @@ def _expint2(x: Array) -> Array: 4.450150439728752875043e-3, 1.727439612206521482874e-4, 3.953167195549672482304e-5, - ] + ], dtype=x.dtype) return _eval_expint_k(A, B, x) def _expint3(x: Array) -> Array: # 4 <= x <= 8 - A = [ + A = np.array([ -1.373215375871208729803e0, -7.084559133740838761406e-1, 1.580806855547941010501e0, @@ -1938,8 +1947,8 @@ def _expint3(x: Array) -> Array: -1.038086040188744005513e-3, 4.371064420753005429514e-5, 2.141783679522602903795e-6, - ] - B = [ + ], dtype=x.dtype) + B = np.array([ 1.0, 8.585231423622028380768e-1, 4.483285822873995129957e-1, @@ -1949,13 +1958,13 @@ def _expint3(x: Array) -> Array: 4.590952299511353531215e-4, -4.729848351866523044863e-6, 2.665195537390710170105e-6, - ] + ], dtype=x.dtype) return _eval_expint_k(A, B, x) def _expint4(x: Array) -> Array: # 8 <= x <= 16 - A = [ + A = np.array([ -2.106934601691916512584e0, 1.732733869664688041885e0, -2.423619178935841904839e-1, @@ -1966,8 +1975,8 @@ def _expint4(x: Array) -> Array: -3.655412321999253963714e-7, 1.464941733975961318456e-8, 6.176407863710360207074e-10, - ] - B = [ + ], dtype=x.dtype) + B = np.array([ 1.0, -2.298062239901678075778e-1, 1.105077041474037862347e-1, @@ -1978,13 +1987,13 @@ def _expint4(x: Array) -> Array: -4.459311796356686423199e-7, 1.394634930353847498145e-8, 6.150865933977338354138e-10, - ] + ], dtype=x.dtype) return _eval_expint_k(A, B, x) def _expint5(x): # 16 <= x <= 32 - A = [ + A = np.array([ -2.458119367674020323359e-1, -1.483382253322077687183e-1, 7.248291795735551591813e-2, @@ -1993,8 +2002,8 @@ def _expint5(x): -7.942465637159712264564e-5, 2.644179518984235952241e-6, -4.239473659313765177195e-8, - ] - B = [ + ], dtype=x.dtype) + B = np.array([ 1.0, -1.044225908443871106315e-1, -2.676453128101402655055e-1, @@ -2004,34 +2013,34 @@ def _expint5(x): -8.462452563778485013756e-5, 2.728938403476726394024e-6, -4.239462431819542051337e-8, - ] + ], dtype=x.dtype) return _eval_expint_k(A, B, x) def _expint6(x): # 32 <= x <= 64 - A = [ + A = np.array([ 1.212561118105456670844e-1, -5.823133179043894485122e-1, 2.348887314557016779211e-1, -3.040034318113248237280e-2, 1.510082146865190661777e-3, -2.523137095499571377122e-5, - ] - B = [ + ], dtype=x.dtype) + B = np.array([ 1.0, -1.002252150365854016662e0, 2.928709694872224144953e-1, -3.337004338674007801307e-2, 1.560544881127388842819e-3, -2.523137093603234562648e-5, - ] + ], dtype=x.dtype) return _eval_expint_k(A, B, x) def _expint7(x): # x > 64 - A = [ + A = np.array([ -7.657847078286127362028e-1, 6.886192415566705051750e-1, -2.132598113545206124553e-1, @@ -2041,8 +2050,8 @@ def _expint7(x): -6.103711682274170530369e-6, 1.218032765428652199087e-7, -1.086076102793290233007e-9, - ] - B = [ + ], dtype=x.dtype) + B = np.array([ 1.0, -1.888802868662308731041e0, 1.066691687211408896850e0, @@ -2053,7 +2062,7 @@ def _expint7(x): -6.345146083130515357861e-6, 1.239754287483206878024e-7, -1.086076102793126632978e-9, - ] + ], dtype=x.dtype) return _eval_expint_k(A, B, x) @@ -2095,9 +2104,10 @@ def expi(x: ArrayLike) -> Array: - :func:`jax.scipy.special.exp1` """ x_arr, = promote_args_inexact("expi", x) + if dtypes.issubdtype(x_arr.dtype, np.complexfloating): + raise ValueError("expi does not support complex-valued inputs.") return jnp.piecewise(x_arr, [x_arr < 0], [_expi_neg, _expi_pos]) - @expi.defjvp @jit def expi_jvp(primals, tangents): @@ -2106,14 +2116,248 @@ def expi_jvp(primals, tangents): return expi(x), jnp.exp(x) / x * x_dot +@custom_derivatives.custom_jvp +@jit +def sici(x: ArrayLike) -> tuple[Array, Array]: + r"""Sine and cosine integrals. + + JAX implementation of :obj:`scipy.special.sici`. + + .. math:: + + \mathrm{Si}(x) = \int_0^x \frac{\sin t}{t} \, dt + + .. math:: + + \mathrm{Ci}(x) = \gamma + \ln(x) + \int_0^x \frac{\cos t - 1}{t} \, dt + + where :math:`\gamma` is the Euler–Mascheroni constant. + + Args: + x: array-like, real-valued input. + + Returns: + A tuple of two arrays, each with the same shape as `x`: + - The first array contains the sine integral values `Si(x)`. + - The second array contains the cosine integral values `Ci(x)`. + + See also: + - :func:`jax.numpy.sinc` + """ + + x, = promote_args_inexact("sici", x) + + if dtypes.issubdtype(x.dtype, np.complexfloating): + raise ValueError( + f"Argument `x` to sici must be real-valued. Got dtype {x.dtype}." + ) + + x_abs = jnp.abs(x) + + si_series, ci_series = _sici_series(x_abs) + si_asymp, ci_asymp = _sici_asympt(x_abs) + si_approx, ci_approx = _sici_approx(x_abs) + + cond1 = x_abs <= 4 + cond2 = (x_abs > 4) & (x_abs <= 1e9) + + si = jnp.select([cond1, cond2], [si_series, si_asymp], si_approx) + ci = jnp.select([cond1, cond2], [ci_series, ci_asymp], ci_approx) + + si = jnp.sign(x) * si + ci = jnp.where(isneginf(x), np.nan, ci) + + return si, ci + +def _sici_approx(x: Array): + # sici approximation valid for x >= 1E9 + si = (np.pi / 2) - jnp.cos(x) / x + ci = jnp.sin(x) / x + + si = jnp.where(isposinf(x), np.pi / 2, si) + ci = jnp.where(isposinf(x), 0.0, ci) + + return si, ci + +def _sici_series(x: Array): + # sici series valid for x >= 0 and x <= 4 + def si_series(x): + # Values come from Cephes Implementation used by Scipy https://github.com/jeremybarnes/cephes/blob/60f27df395b8322c2da22c83751a2366b82d50d1/misc/sici.c + SN = np.array([-8.39167827910303881427E-11, + 4.62591714427012837309E-8, + -9.75759303843632795789E-6, + 9.76945438170435310816E-4, + -4.13470316229406538752E-2, + 1.00000000000000000302E0], dtype=x.dtype) + SD = np.array([ 2.03269266195951942049E-12, + 1.27997891179943299903E-9, + 4.41827842801218905784E-7, + 9.96412122043875552487E-5, + 1.42085239326149893930E-2, + 9.99999999999999996984E-1], dtype=x.dtype) + t = x * x + return (x * jnp.polyval(SN, t)) / jnp.polyval(SD, t) + + def ci_series(x): + # Values come from Cephes Implementation used by Scipy https://github.com/jeremybarnes/cephes/blob/60f27df395b8322c2da22c83751a2366b82d50d1/misc/sici.c + CN = np.array([ 2.02524002389102268789E-11, + -1.35249504915790756375E-8, + 3.59325051419993077021E-6, + -4.74007206873407909465E-4, + 2.89159652607555242092E-2, + -1.00000000000000000080E0], dtype=x.dtype) + CD = np.array([ 4.07746040061880559506E-12, + 3.06780997581887812692E-9, + 1.23210355685883423679E-6, + 3.17442024775032769882E-4, + 5.10028056236446052392E-2, + 4.00000000000000000080E0], dtype=x.dtype) + t = x * x + return np.euler_gamma + jnp.log(x) + t * jnp.polyval(CN, t) / jnp.polyval(CD, t) + + si = jnp.where( + x == 0, + 0.0, + si_series(x) + ) + + ci = jnp.where( + x == 0, + -np.inf, + ci_series(x) + ) + + return si, ci + +def _sici_asympt(x: Array): + # sici asympt valid for x > 4 & x <= 1E9 + s = jnp.sin(x) + c = jnp.cos(x) + z = 1.0 / (x * x) + + # Values come from Cephes Implementation used by Scipy https://github.com/jeremybarnes/cephes/blob/60f27df395b8322c2da22c83751a2366b82d50d1/misc/sici.c + FN4 = np.array([ + 4.23612862892216586994E0, + 5.45937717161812843388E0, + 1.62083287701538329132E0, + 1.67006611831323023771E-1, + 6.81020132472518137426E-3, + 1.08936580650328664411E-4, + 5.48900223421373614008E-7, + ], dtype=x.dtype) + FD4 = np.array([ + 1, + 8.16496634205391016773E0, + 7.30828822505564552187E0, + 1.86792257950184183883E0, + 1.78792052963149907262E-1, + 7.01710668322789753610E-3, + 1.10034357153915731354E-4, + 5.48900252756255700982E-7, + ], dtype=x.dtype) + GN4 = np.array([ + 8.71001698973114191777E-2, + 6.11379109952219284151E-1, + 3.97180296392337498885E-1, + 7.48527737628469092119E-2, + 5.38868681462177273157E-3, + 1.61999794598934024525E-4, + 1.97963874140963632189E-6, + 7.82579040744090311069E-9, + ], dtype=x.dtype) + GD4 = np.array([ + 1, + 1.64402202413355338886E0, + 6.66296701268987968381E-1, + 9.88771761277688796203E-2, + 6.22396345441768420760E-3, + 1.73221081474177119497E-4, + 2.02659182086343991969E-6, + 7.82579218933534490868E-9, + ], dtype=x.dtype) + + FN8 = np.array([ + 4.55880873470465315206E-1, + 7.13715274100146711374E-1, + 1.60300158222319456320E-1, + 1.16064229408124407915E-2, + 3.49556442447859055605E-4, + 4.86215430826454749482E-6, + 3.20092790091004902806E-8, + 9.41779576128512936592E-11, + 9.70507110881952024631E-14, + ], dtype=x.dtype) + FD8 = np.array([ + 1.0, + 9.17463611873684053703E-1, + 1.78685545332074536321E-1, + 1.22253594771971293032E-2, + 3.58696481881851580297E-4, + 4.92435064317881464393E-6, + 3.21956939101046018377E-8, + 9.43720590350276732376E-11, + 9.70507110881952025725E-14, + ], dtype=x.dtype) + GN8 = np.array([ + 6.97359953443276214934E-1, + 3.30410979305632063225E-1, + 3.84878767649974295920E-2, + 1.71718239052347903558E-3, + 3.48941165502279436777E-5, + 3.47131167084116673800E-7, + 1.70404452782044526189E-9, + 3.85945925430276600453E-12, + 3.14040098946363334640E-15, + ], dtype=x.dtype) + GD8 = np.array([ + 1.0, + 1.68548898811011640017E0, + 4.87852258695304967486E-1, + 4.67913194259625806320E-2, + 1.90284426674399523638E-3, + 3.68475504442561108162E-5, + 3.57043223443740838771E-7, + 1.72693748966316146736E-9, + 3.87830166023954706752E-12, + 3.14040098946363335242E-15, + ], dtype=x.dtype) + + f4 = jnp.polyval(FN4, z) / (x * jnp.polyval(FD4, z)) + g4 = z * jnp.polyval(GN4, z) / jnp.polyval(GD4, z) + + f8 = jnp.polyval(FN8, z) / (x * jnp.polyval(FD8, z)) + g8 = z * jnp.polyval(GN8, z) / jnp.polyval(GD8, z) + + mask = x < 8.0 + f = jnp.where(mask, f4, f8) + g = jnp.where(mask, g4, g8) + + si = (np.pi / 2) - f * c - g * s + ci = f * s - g * c + + return si, ci + +@sici.defjvp +@jit +def sici_jvp(primals, tangents): + (p,), (t,) = primals, tangents + primal_out = sici(p) + + sin_term = sinc(p / np.pi) + cos_term = jnp.cos(p) / p + + tangent_out = (sin_term * t, cos_term * t) + return primal_out, tangent_out + + def _expn1(x: Array, n: Array) -> Array: # exponential integral En _c = _lax_const - MACHEP = jnp.finfo(x.dtype).eps + MACHEP = dtypes.finfo(x.dtype).eps zero = _c(x, 0.0) one = _c(x, 1.0) - psi = -jnp.euler_gamma - jnp.log(x) + psi = -np.euler_gamma - jnp.log(x) psi = lax.fori_loop(_c(n, 1), n, lambda i, psi: psi + one / i, psi) n1 = jnp.where(n == _c(n, 1), one + one, n) init = dict( @@ -2123,7 +2367,7 @@ def _expn1(x: Array, n: Array) -> Array: yk=one, pk=one - n, ans=jnp.where(n == _c(n, 1), zero, one / (one - n1)), - t=jnp.inf, + t=np.inf, ) def body(d): @@ -2147,7 +2391,7 @@ def _expn2(x: Array, n: Array) -> Array: # x > 1. _c = _lax_const BIG = _c(x, 1.44115188075855872e17) - MACHEP = jnp.finfo(BIG.dtype).eps # ? + MACHEP = dtypes.finfo(x.dtype).eps zero = _c(x, 0.0) one = _c(x, 1.0) @@ -2158,7 +2402,7 @@ def _expn2(x: Array, n: Array) -> Array: pkm1=one, qkm1=x + n, ans=one / (x + n), - t=_c(x, jnp.inf), + t=_c(x, np.inf), r=zero, x=x, ) @@ -2208,7 +2452,7 @@ def _expn3(x: Array, n: Array) -> Array: @partial(custom_derivatives.custom_jvp, nondiff_argnums=(0,)) -@jnp.vectorize +@jnp_vectorize.vectorize @jit def expn(n: ArrayLike, x: ArrayLike) -> Array: r"""Generalized exponential integral function. @@ -2231,6 +2475,8 @@ def expn(n: ArrayLike, x: ArrayLike) -> Array: - :func:`jax.scipy.special.exp1` """ n, x = promote_args_inexact("expn", n, x) + if dtypes.issubdtype(x.dtype, np.complexfloating): + raise ValueError("expn does not support complex-valued inputs.") _c = _lax_const zero = _c(x, 0) one = _c(x, 1) @@ -2244,8 +2490,8 @@ def expn(n: ArrayLike, x: ArrayLike) -> Array: ] n1 = jnp.where(n == _c(n, 1), n + n, n) vals = [ - jnp.nan, - jnp.inf, + np.nan, + np.inf, one / n1, # prevent div by zero jnp.exp(-x) / x, _expn3, @@ -2286,6 +2532,8 @@ def exp1(x: ArrayLike) -> Array: - :func:`jax.scipy.special.expn` """ x, = promote_args_inexact("exp1", x) + if dtypes.issubdtype(x.dtype, np.complexfloating): + raise ValueError("exp1 does not support complex-valued inputs.") # Casting because custom_jvp generic does not work correctly with mypy. return cast(Array, expn(1, x)) @@ -2330,7 +2578,7 @@ def _spence_calc(x: Array) -> Array: lambda x: x - 1.0]) y = _spence_poly(w) - y_flag_one = jnp.pi ** 2 / 6.0 - jnp.log(x) * jnp.log(1.0 - x) - y + y_flag_one = np.pi ** 2 / 6.0 - jnp.log(x) * jnp.log(1.0 - x) - y y = jnp.where(x_5_bool, y_flag_one, y) y_flag_two = -0.5 * jnp.log(x) ** 2 - y return jnp.where(x2_bool, y_flag_two, y) @@ -2339,7 +2587,7 @@ def _spence_calc(x: Array) -> Array: def _spence(x: Array) -> Array: return jnp.piecewise(x, [x < 0.0, x == 1.0, x == 0.0], - [jnp.nan, 0, jnp.pi ** 2 / 6, _spence_calc]) + [np.nan, 0, np.pi ** 2 / 6, _spence_calc]) def spence(x: Array) -> Array: @@ -2380,7 +2628,7 @@ def spence(x: Array) -> Array: """ x = jnp.asarray(x) dtype = lax.dtype(x) - if dtype not in (jnp.float32, jnp.float64): + if dtype not in (np.float32, np.float64): raise TypeError( f"x.dtype={dtype} is not supported, see docstring for supported types.") return _spence(x) @@ -2410,7 +2658,7 @@ def bernoulli(n: int) -> Array: return b3[:n + 1] bn = jnp.zeros(n + 1).at[:3].set(b3) m = jnp.arange(4, n + 1, 2, dtype=bn.dtype) - q1 = (1. / jnp.pi ** 2) * jnp.cumprod(-(m - 1) * m / 4 / jnp.pi ** 2) + q1 = (1. / np.pi ** 2) * jnp.cumprod(-(m - 1) * m / 4 / np.pi ** 2) k = jnp.arange(2, 50, dtype=bn.dtype) # Choose 50 because 2 ** -50 < 1E-15 q2 = jnp.sum(k[:, None] ** -m[None, :], axis=0) return bn.at[4::2].set(q1 * (1 + q2)) @@ -2439,6 +2687,8 @@ def poch(z: ArrayLike, m: ArrayLike) -> Array: The JAX version supports only real-valued inputs. """ z, m = promote_args_inexact("poch", z, m) + if dtypes.issubdtype(z.dtype, np.complexfloating): + raise ValueError("jnp.poch does not support complex-valued inputs.") return jnp.where(m == 0., jnp.array(1, dtype=z.dtype), gamma(z + m) / gamma(z)) @@ -2474,7 +2724,7 @@ def _hyp1f1_serie(a, b, x): https://doi.org/10.48550/arXiv.1407.7786 """ - precision = jnp.finfo(x.dtype).eps + precision = dtypes.finfo(x.dtype).eps def body(state): serie, k, term = state @@ -2501,7 +2751,7 @@ def _hyp1f1_asymptotic(a, b, x): https://doi.org/10.48550/arXiv.1407.7786 """ - precision = jnp.finfo(x.dtype).eps + precision = dtypes.finfo(x.dtype).eps def body(state): serie, k, term = state @@ -2523,14 +2773,14 @@ def cond(state): @jit -@jnp.vectorize +@jnp_vectorize.vectorize def _hyp1f1_a_derivative(a, b, x): """ Define it as a serie using : https://functions.wolfram.com/HypergeometricFunctions/Hypergeometric1F1/20/01/01/ """ - precision = jnp.finfo(x.dtype).eps + precision = dtypes.finfo(x.dtype).eps def body(state): serie, k, term = state @@ -2551,14 +2801,14 @@ def cond(state): @jit -@jnp.vectorize +@jnp_vectorize.vectorize def _hyp1f1_b_derivative(a, b, x): """ Define it as a serie using : https://functions.wolfram.com/HypergeometricFunctions/Hypergeometric1F1/20/01/02/ """ - precision = jnp.finfo(x.dtype).eps + precision = dtypes.finfo(x.dtype).eps def body(state): serie, k, term = state @@ -2590,7 +2840,7 @@ def _hyp1f1_x_derivative(a, b, x): @custom_derivatives.custom_jvp @jit -@jnp.vectorize +@jnp_vectorize.vectorize def hyp1f1(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array: r"""The 1F1 hypergeometric function. @@ -2620,6 +2870,9 @@ def hyp1f1(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array: # evaluate lower values of hyp1f1 when a or b or both are > 60-80 a, b, x = promote_args_inexact('hyp1f1', a, b, x) + if dtypes.issubdtype(x.dtype, np.complexfloating): + raise ValueError("hyp1f1 does not support complex-valued inputs.") + result = lax.cond(lax.abs(x) < 100, _hyp1f1_serie, _hyp1f1_asymptotic, a, b, x) index = (a == 0) * 1 + ((a == b) & (a != 0)) * 2 + ((b == 0) & (a != 0)) * 3 @@ -2627,7 +2880,7 @@ def hyp1f1(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array: result, jnp.array(1, dtype=x.dtype), jnp.exp(x), - jnp.array(jnp.inf, dtype=x.dtype)) + jnp.array(np.inf, dtype=x.dtype)) hyp1f1.defjvps( @@ -2637,6 +2890,363 @@ def hyp1f1(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array: ) +def _hyp2f1_terminal(a, b, c, x): + """ + The Taylor series representation of the 2F1 hypergeometric function + terminates when either a or b is a non-positive integer. See Eq. 4.1 and + Taylor Series Method (a) from PEARSON, OLVER & PORTER 2014 + https://doi.org/10.48550/arXiv.1407.7786 + """ + # Ensure that between a and b, the negative integer parameter with the greater + # absolute value - that still has a magnitude less than the absolute value of + # c if c is non-positive - is used for the upper limit in the loop. + eps = dtypes.finfo(x.dtype).eps * 50 + ib = jnp.round(b) + mask = jnp.logical_and( + b < a, + jnp.logical_and( + jnp.abs(b - ib) < eps, + jnp.logical_not( + jnp.logical_and( + c % 1 == 0, + jnp.logical_and( + c <= 0, + c > b + ) + ) + ) + ) + ) + orig_a = a + a = jnp.where(mask, b, a) + b = jnp.where(mask, orig_a, b) + + a = jnp.abs(a) + + def body(i, state): + serie, term = state + + term *= -(a - i + 1) / (c + i - 1) * (b + i - 1) / i * x + serie += term + + return serie, term + + init = (jnp.array(1, dtype=x.dtype), jnp.array(1, dtype=x.dtype)) + + return lax.fori_loop(jnp.array(1, dtype=a.dtype), + a + 1, + body, + init)[0] + + +def _hyp2f1_serie(a, b, c, x): + """ + Compute the 2F1 hypergeometric function using the Taylor expansion. + See Eq. 4.1 from PEARSON, OLVER & PORTER 2014 + https://doi.org/10.48550/arXiv.1407.7786 + """ + rtol = dtypes.finfo(x.dtype).eps + + def body(state): + serie, k, term = state + + serie += term + term *= (a + k - 1) * (b + k - 1) / (c + k - 1) / k * x + k += 1 + + return serie, k, term + + def cond(state): + serie, k, term = state + + return (k < 250) & (lax.abs(term) > rtol * lax.abs(serie)) + + init = (jnp.array(0, dtype=x.dtype), + jnp.array(1, dtype=x.dtype), + jnp.array(1, dtype=x.dtype)) + + return lax.while_loop(cond, body, init)[0] + + +def _hyp2f1_terminal_or_serie(a, b, c, x): + """ + Check for recurrence relations along with whether or not the series + terminates. True recursion is not possible; however, the recurrence + relation may still be approximated. + See 4.6.1. Recurrence Relations from PEARSON, OLVER & PORTER 2014 + https://doi.org/10.48550/arXiv.1407.7786 + """ + eps = dtypes.finfo(x.dtype).eps * 50 + + d = c - a - b + + ia = jnp.round(a) + ib = jnp.round(b) + id = jnp.round(d) + + neg_int_a = jnp.logical_and(a <= 0, jnp.abs(a - ia) < eps) + neg_int_b = jnp.logical_and(b <= 0, jnp.abs(b - ib) < eps) + neg_int_a_or_b = jnp.logical_or(neg_int_a, neg_int_b) + not_neg_int_a_or_b = jnp.logical_not(neg_int_a_or_b) + + index = jnp.where(jnp.logical_and(x > 0.9, not_neg_int_a_or_b), + jnp.where(jnp.abs(d - id) >= eps, 0, 1), + jnp.where(neg_int_a_or_b, 2, 0)) + + return lax.select_n(index, + _hyp2f1_serie(a, b, c, x), + _hyp2f1_digamma_transform(a, b, c, x), + _hyp2f1_terminal(a, b, c, x)) + + +def _hyp2f1_digamma_transform(a, b, c, x): + """ + Digamma transformation of the 2F1 hypergeometric function. + See AMS55 #15.3.10, #15.3.11, #15.3.12 + """ + rtol = dtypes.finfo(x.dtype).eps + + d = c - a - b + s = 1 - x + rd = jnp.round(d) + + e = jnp.where(rd >= 0, d, -d) + d1 = jnp.where(rd >= 0, d, jnp.array(0, dtype=d.dtype)) + d2 = jnp.where(rd >= 0, jnp.array(0, dtype=d.dtype), d) + ard = jnp.where(rd >= 0, rd, -rd).astype('int32') + + ax = jnp.log(s) + + y = digamma(1.0) + digamma(1.0 + e) - digamma(a + d1) - digamma(b + d1) - ax + y /= gamma(e + 1.0) + + p = (a + d1) * (b + d1) * s / gamma(e + 2.0) + + def cond(state): + _, _, _, _, _, _, q, _, _, t, y = state + + return jnp.logical_and( + t < 250, + jnp.abs(q) >= rtol * jnp.abs(y) + ) + + def body(state): + a, ax, b, d1, e, p, q, r, s, t, y = state + + r = digamma(1.0 + t) + digamma(1.0 + t + e) - digamma(a + t + d1) \ + - digamma(b + t + d1) - ax + q = p * r + y += q + p *= s * (a + t + d1) / (t + 1.0) + p *= (b + t + d1) / (t + 1.0 + e) + t += 1.0 + + return a, ax, b, d1, e, p, q, r, s, t, y + + init = (a, ax, b, d1, e, p, y, jnp.array(0, dtype=x.dtype), s, + jnp.array(1, dtype=x.dtype), y) + _, _, _, _, _, _, q, r, _, _, y = lax.while_loop(cond, body, init) + + def compute_sum(y): + y1 = jnp.array(1, dtype=x.dtype) + t = jnp.array(0, dtype=x.dtype) + p = jnp.array(1, dtype=x.dtype) + + def for_body(i, state): + a, b, d2, e, p, s, t, y1 = state + + r = 1.0 - e + t + p *= s * (a + t + d2) * (b + t + d2) / r + t += 1.0 + p /= t + y1 += p + + return a, b, d2, e, p, s, t, y1 + + init_val = a, b, d2, e, p, s, t, y1 + y1 = lax.fori_loop(1, ard, for_body, init_val)[-1] + + p = gamma(c) + y1 *= gamma(e) * p / (gamma(a + d1) * gamma(b + d1)) + y *= p / (gamma(a + d2) * gamma(b + d2)) + + y = jnp.where((ard & 1) != 0, -y, y) + q = s ** rd + + return jnp.where(rd > 0, y * q + y1, y + y1 * q) + + return jnp.where( + rd == 0, + y * gamma(c) / (gamma(a) * gamma(b)), + compute_sum(y) + ) + + +@jit +@jnp_vectorize.vectorize +def _hyp2f1_a_derivative(a, b, c, x): + """ + Define it as a serie using : + https://functions.wolfram.com/HypergeometricFunctions/Hypergeometric2F1/20/01/01/ + """ + + precision = dtypes.finfo(x.dtype).eps + + def body(state): + serie, k, term = state + serie += term * (digamma(a + k) - digamma(a)) + term *= (a + k) * (b + k) / (c + k) / (k + 1) * x + k += 1 + + return serie, k, term + + def cond(state): + serie, k, term = state + + return (k < 250) & (lax.abs(term) / lax.abs(serie) > precision) + + init = 0, 1, a * b / c * x + + return lax.while_loop(cond, body, init)[0] + + +@jit +@jnp_vectorize.vectorize +def _hyp2f1_b_derivative(a, b, c, x): + """ + Define it as a serie using : + https://functions.wolfram.com/HypergeometricFunctions/Hypergeometric2F1/20/01/02/ + """ + + precision = dtypes.finfo(x.dtype).eps + + def body(state): + serie, k, term = state + serie += term * (digamma(b + k) - digamma(b)) + term *= (a + k) * (b + k) / (c + k) / (k + 1) * x + k += 1 + + return serie, k, term + + def cond(state): + serie, k, term = state + + return (k < 250) & (lax.abs(term) / lax.abs(serie) > precision) + + init = 0, 1, a * b / c * x + + return lax.while_loop(cond, body, init)[0] + + +@jit +@jnp_vectorize.vectorize +def _hyp2f1_c_derivative(a, b, c, x): + """ + Define it as a serie using : + https://functions.wolfram.com/HypergeometricFunctions/Hypergeometric2F1/20/01/03/ + """ + + precision = dtypes.finfo(x.dtype).eps + + def body(state): + serie, k, term = state + serie += term * (digamma(c) - digamma(c + k)) + term *= (a + k) * (b + k) / (c + k) / (k + 1) * x + k += 1 + + return serie, k, term + + def cond(state): + serie, k, term = state + + return (k < 250) & (lax.abs(term) / lax.abs(serie) > precision) + + init = 0, 1, a * b / c * x + + return lax.while_loop(cond, body, init)[0] + + +@jit +def _hyp2f1_x_derivative(a, b, c, x): + """ + Define the derivative with regard to ``x`` : + https://functions.wolfram.com/HypergeometricFunctions/Hypergeometric2F1/20/01/05/ + """ + + return a * b / c * hyp2f1(a + 1, b + 1, c + 1, x) + + +@custom_derivatives.custom_jvp +@jit +@jnp_vectorize.vectorize +def hyp2f1(a: ArrayLike, b: ArrayLike, c: ArrayLike, x: ArrayLike) -> Array: + r"""The 2F1 hypergeometric function. + + JAX implementation of :obj:`scipy.special.hyp2f1`. + + .. math:: + + \mathrm{hyp2f1}(a, b, c, x) = {}_2F_1(a; b; c; x) = \sum_{k=0}^\infty \frac{(a)_k(b)_k}{(c)_k}\frac{x^k}{k!} + + where :math:`(\cdot)_k` is the Pochammer symbol. + + The JAX version only accepts positive and real inputs. Values of + ``a``, ``b``, ``c``, and ``x`` leading to high values of 2F1 may + lead to erroneous results; consider enabling double precision in this case. + + Args: + a: arraylike, real-valued + b: arraylike, real-valued + c: arraylike, real-valued + x: arraylike, real-valued + + Returns: + array of 2F1 values. + """ + # This is backed by https://doi.org/10.48550/arXiv.1407.7786 + a, b, c, x = promote_args_inexact('hyp2f1', a, b, c, x) + eps = dtypes.finfo(x.dtype).eps * 50 + + d = c - a - b + s = 1 - x + ca = c - a + cb = c - b + + id = jnp.round(d) + ica = jnp.round(ca) + icb = jnp.round(cb) + + neg_int_ca = jnp.logical_and(ca <= 0, jnp.abs(ca - ica) < eps) + neg_int_cb = jnp.logical_and(cb <= 0, jnp.abs(cb - icb) < eps) + neg_int_ca_or_cb = jnp.logical_or(neg_int_ca, neg_int_cb) + + index = jnp.where(jnp.logical_or(x == 0, jnp.logical_and(jnp.logical_or(a == 0, b == 0), c != 0)), 0, + jnp.where(jnp.logical_or(c == 0, jnp.logical_and(c < 0, c % 1 == 0)), 1, + jnp.where(jnp.logical_and(d <= -1, jnp.logical_not(jnp.logical_and(jnp.abs(d - id) >= eps, s < 0))), 2, + jnp.where(jnp.logical_and(d <= 0, x == 1), 1, + jnp.where(jnp.logical_and(x < 1, b == c), 3, + jnp.where(jnp.logical_and(x < 1, a == c), 4, + jnp.where(x > 1, 1, + jnp.where(x == 1, 5, 6)))))))) + + return lax.select_n(index, + jnp.array(1, dtype=x.dtype), + jnp.array(np.inf, dtype=x.dtype), + s ** d * _hyp2f1_terminal_or_serie(ca, cb, c, x), + s ** (-a), + s ** (-b), + gamma(c) * gamma(d) / (gamma(ca) * gamma(cb)), + _hyp2f1_terminal_or_serie(a, b, c, x)) + + +hyp2f1.defjvps( + lambda a_dot, primal_out, a, b, c, x: _hyp2f1_a_derivative(a, b, c, x) * a_dot, + lambda b_dot, primal_out, a, b, c, x: _hyp2f1_b_derivative(a, b, c, x) * b_dot, + lambda c_dot, primal_out, a, b, c, x: _hyp2f1_c_derivative(a, b, c, x) * c_dot, + lambda x_dot, primal_out, a, b, c, x: _hyp2f1_x_derivative(a, b, c, x) * x_dot +) + + def softmax(x: ArrayLike, /, *, @@ -2656,6 +3266,7 @@ def softmax(x: ArrayLike, x : input array axis: the axis or axes along which the softmax should be computed. The softmax output summed across these dimensions should sum to :math:`1`. + ``None`` means all axes. Returns: An array of the same shape as ``x``. @@ -2690,7 +3301,7 @@ def log_softmax(x: ArrayLike, Args: x : input array axis: the axis or axes along which the :code:`log_softmax` should be - computed. + computed. ``None`` means all axes. Returns: An array of the same shape as ``x`` diff --git a/jax/_src/scipy/stats/_core.py b/jax/_src/scipy/stats/_core.py index 65c457f79cc8..cda42e1c5bfd 100644 --- a/jax/_src/scipy/stats/_core.py +++ b/jax/_src/scipy/stats/_core.py @@ -15,14 +15,14 @@ from __future__ import annotations from collections import namedtuple -from functools import partial import math -import jax -import jax.numpy as jnp -from jax import jit +import numpy as np + +from jax._src import api from jax._src import dtypes -from jax._src.api import vmap +from jax._src import lax +from jax._src import numpy as jnp from jax._src.numpy.util import check_arraylike, promote_args_inexact from jax._src.typing import ArrayLike, Array from jax._src.util import canonicalize_axis @@ -30,7 +30,7 @@ ModeResult = namedtuple('ModeResult', ('mode', 'count')) -@partial(jit, static_argnames=['axis', 'nan_policy', 'keepdims']) +@api.jit(static_argnames=['axis', 'nan_policy', 'keepdims']) def mode(a: ArrayLike, axis: int | None = 0, nan_policy: str = "propagate", keepdims: bool = False) -> ModeResult: """Compute the mode (most common value) along an axis of an array. @@ -116,11 +116,11 @@ def mode(a: ArrayLike, axis: int | None = 0, nan_policy: str = "propagate", keep axis = 0 x = x.ravel() - def _mode_helper(x: jax.Array) -> tuple[jax.Array, jax.Array]: + def _mode_helper(x: Array) -> tuple[Array, Array]: """Helper function to return mode and count of a given array.""" if x.size == 0: - return (jnp.array(jnp.nan, dtype=dtypes.canonicalize_dtype(jnp.float_)), - jnp.array(0, dtype=dtypes.canonicalize_dtype(jnp.float_))) + return (jnp.array(np.nan, dtype=dtypes.default_float_dtype()), + jnp.array(0, dtype=dtypes.default_float_dtype())) else: vals, counts = jnp.unique(x, return_counts=True, size=x.size) return vals[jnp.argmax(counts)], counts.max() @@ -128,7 +128,7 @@ def _mode_helper(x: jax.Array) -> tuple[jax.Array, jax.Array]: axis = canonicalize_axis(axis, x.ndim) x = jnp.moveaxis(x, axis, 0) x = x.reshape(x.shape[0], math.prod(x.shape[1:])) - vals, counts = vmap(_mode_helper, in_axes=1)(x) + vals, counts = api.vmap(_mode_helper, in_axes=1)(x) return ModeResult(vals.reshape(output_shape), counts.reshape(output_shape)) def invert_permutation(i: Array) -> Array: @@ -136,7 +136,7 @@ def invert_permutation(i: Array) -> Array: return jnp.empty_like(i).at[i].set(jnp.arange(i.size, dtype=i.dtype)) -@partial(jit, static_argnames=["method", "axis", "nan_policy"]) +@api.jit(static_argnames=["method", "axis", "nan_policy"]) def rankdata( a: ArrayLike, method: str = "average", @@ -198,7 +198,7 @@ def rankdata( return jnp.apply_along_axis(rankdata, axis, a, method) arr = jnp.ravel(a) - arr, sorter = jax.lax.sort_key_val(arr, jnp.arange(arr.size)) + arr, sorter = lax.sort_key_val(arr, jnp.arange(arr.size)) inv = invert_permutation(sorter) if method == "ordinal": @@ -213,11 +213,11 @@ def rankdata( if method == "min": return count[dense - 1] + 1 if method == "average": - return .5 * (count[dense] + count[dense - 1] + 1).astype(dtypes.canonicalize_dtype(jnp.float_)) + return .5 * (count[dense] + count[dense - 1] + 1).astype(dtypes.default_float_dtype()) raise ValueError(f"unknown method '{method}'") -@partial(jit, static_argnames=['axis', 'nan_policy', 'keepdims']) +@api.jit(static_argnames=['axis', 'nan_policy', 'keepdims']) def sem(a: ArrayLike, axis: int | None = 0, ddof: int = 1, nan_policy: str = "propagate", *, keepdims: bool = False) -> Array: """Compute the standard error of the mean. @@ -276,7 +276,7 @@ def sem(a: ArrayLike, axis: int | None = 0, ddof: int = 1, nan_policy: str = "pr Since, by default, ``nan_policy='propagate'``, ``sem`` propagates the ``nan`` values in the result. - >>> nan = jnp.nan + >>> nan = np.nan >>> x2 = jnp.array([[1, 2, 3, nan, 4, 2], ... [4, 5, 4, 3, nan, 1], ... [7, nan, 8, 7, 9, nan]]) @@ -285,7 +285,7 @@ def sem(a: ArrayLike, axis: int | None = 0, ddof: int = 1, nan_policy: str = "pr Array([1.73, nan, 1.53, nan, nan, nan], dtype=float32) If ``nan_policy='omit```, ``sem`` omits the ``nan`` values and computes the error - for the remainging values along the specified axis. + for the remaining values along the specified axis. >>> with jnp.printoptions(precision=2, suppress=True): ... jax.scipy.stats.sem(x2, nan_policy='omit') diff --git a/jax/_src/scipy/stats/bernoulli.py b/jax/_src/scipy/stats/bernoulli.py index 96e4a68b7697..5c87ab5e2e47 100644 --- a/jax/_src/scipy/stats/bernoulli.py +++ b/jax/_src/scipy/stats/bernoulli.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jax import lax -import jax.numpy as jnp +import numpy as np + +from jax._src import lax +from jax._src import numpy as jnp from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import promote_args_inexact +from jax._src.scipy.special import xlogy, xlog1py from jax._src.typing import Array, ArrayLike -from jax.scipy.special import xlogy, xlog1py def logpmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array: @@ -54,7 +56,7 @@ def logpmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array: x = lax.sub(k, loc) log_probs = xlogy(x, p) + xlog1py(lax.sub(one, x), -p) return jnp.where(jnp.logical_or(lax.lt(x, zero), lax.gt(x, one)), - -jnp.inf, log_probs) + -np.inf, log_probs) def pmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array: @@ -123,7 +125,7 @@ def cdf(k: ArrayLike, p: ArrayLike) -> Array: jnp.logical_and(lax.ge(k, zero), lax.lt(k, one)), lax.ge(k, one) ] - vals = [jnp.nan, zero, one - p, one] + vals = [np.nan, zero, one - p, one] return jnp.select(conds, vals) @@ -152,6 +154,6 @@ def ppf(q: ArrayLike, p: ArrayLike) -> Array: zero, one = _lax_const(q, 0), _lax_const(q, 1) return jnp.where( jnp.isnan(q) | jnp.isnan(p) | (p < zero) | (p > one) | (q < zero) | (q > one), - jnp.nan, + np.nan, jnp.where(lax.le(q, one - p), zero, one) ) diff --git a/jax/_src/scipy/stats/beta.py b/jax/_src/scipy/stats/beta.py index 19b8400ee29d..869115cc9bd0 100644 --- a/jax/_src/scipy/stats/beta.py +++ b/jax/_src/scipy/stats/beta.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jax import lax -import jax.numpy as jnp +import numpy as np + +from jax._src import lax +from jax._src import numpy as jnp from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import promote_args_inexact +from jax._src.scipy.special import betaln, betainc, xlogy, xlog1py from jax._src.typing import Array, ArrayLike -from jax.scipy.special import betaln, betainc, xlogy, xlog1py def logpdf(x: ArrayLike, a: ArrayLike, b: ArrayLike, @@ -61,9 +63,9 @@ def logpdf(x: ArrayLike, a: ArrayLike, b: ArrayLike, xlog1py(lax.sub(b, one), lax.neg(y))) log_probs = lax.sub(lax.add(shape_term, log_linear_term), lax.log(scale)) result = jnp.where(jnp.logical_or(lax.gt(x, lax.add(loc, scale)), - lax.lt(x, loc)), -jnp.inf, log_probs) + lax.lt(x, loc)), -np.inf, log_probs) result_positive_constants = jnp.where(jnp.logical_or(jnp.logical_or(lax.le(a, zero), lax.le(b, zero)), - lax.le(scale, zero)), jnp.nan, result) + lax.le(scale, zero)), np.nan, result) return result_positive_constants diff --git a/jax/_src/scipy/stats/betabinom.py b/jax/_src/scipy/stats/betabinom.py index 7d4a4ed79cb9..aebdfbe364fd 100644 --- a/jax/_src/scipy/stats/betabinom.py +++ b/jax/_src/scipy/stats/betabinom.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License -from jax import lax -import jax.numpy as jnp +import numpy as np + +from jax._src import lax +from jax._src import numpy as jnp from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import promote_args_inexact from jax._src.scipy.special import betaln @@ -58,9 +60,9 @@ def logpmf(k: ArrayLike, n: ArrayLike, a: ArrayLike, b: ArrayLike, log_probs = jnp.where(jnp.logical_and(lax.eq(y, zero), lax.eq(n, zero)), 0., log_probs) y_cond = jnp.logical_or(jnp.logical_or(lax.lt(y, lax.neg(loc)), lax.gt(y, n)), lax.le(lax.add(y, a), zero)) - log_probs = jnp.where(y_cond, -jnp.inf, log_probs) + log_probs = jnp.where(y_cond, -np.inf, log_probs) n_a_b_cond = jnp.logical_or(jnp.logical_or(lax.lt(n, zero), lax.le(a, zero)), lax.le(b, zero)) - return jnp.where(n_a_b_cond, jnp.nan, log_probs) + return jnp.where(n_a_b_cond, np.nan, log_probs) def pmf(k: ArrayLike, n: ArrayLike, a: ArrayLike, b: ArrayLike, diff --git a/jax/_src/scipy/stats/binom.py b/jax/_src/scipy/stats/binom.py index 50d5aa6a8c99..3d9089b9af51 100644 --- a/jax/_src/scipy/stats/binom.py +++ b/jax/_src/scipy/stats/binom.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License -from jax import lax -import jax.numpy as jnp +import numpy as np + +from jax._src import lax +from jax._src import numpy as jnp from jax._src.numpy.util import promote_args_inexact from jax._src.lax.lax import _const as _lax_const from jax._src.scipy.special import gammaln, xlogy, xlog1py @@ -57,7 +59,7 @@ def logpmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Arra y_n_cond = jnp.logical_or(jnp.logical_and(lax.eq(y, zero), lax.eq(n, zero)), lax.eq(log_linear_term, zero)) log_probs = jnp.where(y_n_cond, 0., log_probs) - return jnp.where(lax.ge(k, loc) & lax.lt(k, loc + n + 1), log_probs, -jnp.inf) + return jnp.where(lax.ge(k, loc) & lax.lt(k, loc + n + 1), log_probs, -np.inf) def pmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array: diff --git a/jax/_src/scipy/stats/cauchy.py b/jax/_src/scipy/stats/cauchy.py index 922cdadb669a..ac0ec8574462 100644 --- a/jax/_src/scipy/stats/cauchy.py +++ b/jax/_src/scipy/stats/cauchy.py @@ -15,10 +15,10 @@ import numpy as np -from jax import lax +from jax._src import lax from jax._src.lax.lax import _const as _lax_const +from jax._src.numpy.ufuncs import arctan from jax._src.numpy.util import promote_args_inexact -from jax.numpy import arctan from jax._src.typing import Array, ArrayLike diff --git a/jax/_src/scipy/stats/chi2.py b/jax/_src/scipy/stats/chi2.py index 6637104e2123..05069b59f40f 100644 --- a/jax/_src/scipy/stats/chi2.py +++ b/jax/_src/scipy/stats/chi2.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License -from jax import lax -import jax.numpy as jnp +import numpy as np + +from jax._src import lax +from jax._src import numpy as jnp from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import promote_args_inexact +from jax._src.scipy.special import gammainc, gammaincc from jax._src.typing import Array, ArrayLike -from jax.scipy.special import gammainc, gammaincc def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: @@ -65,7 +67,7 @@ def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1 nrml_cnst = lax.neg(lax.add(lax.lgamma(df_on_two),lax.div(lax.mul(lax.log(two), df),two))) log_probs = lax.add(lax.sub(nrml_cnst, lax.log(scale)), kernel) - return jnp.where(lax.lt(x, loc), -jnp.inf, log_probs) + return jnp.where(lax.lt(x, loc), -np.inf, log_probs) def pdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: @@ -146,7 +148,7 @@ def cdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) - lax.sub(x, loc), lax.mul(scale, two), ), - _lax_const(x, jnp.inf), + _lax_const(x, np.inf), ), ) @@ -226,7 +228,7 @@ def sf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> lax.sub(x, loc), lax.mul(scale, two), ), - _lax_const(x, jnp.inf), + _lax_const(x, np.inf), ), ) diff --git a/jax/_src/scipy/stats/dirichlet.py b/jax/_src/scipy/stats/dirichlet.py index ee28c7e3ea59..15bd6c06798f 100644 --- a/jax/_src/scipy/stats/dirichlet.py +++ b/jax/_src/scipy/stats/dirichlet.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jax import lax -import jax.numpy as jnp +import numpy as np + +from jax._src import lax +from jax._src import numpy as jnp from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import promote_dtypes_inexact -from jax.scipy.special import gammaln, xlogy +from jax._src.scipy.special import gammaln, xlogy from jax._src.typing import Array, ArrayLike @@ -68,7 +70,7 @@ def _logpdf(x: Array, alpha: Array) -> Array: if x.ndim > 1: alpha = lax.broadcast_in_dim(alpha, alpha.shape + (1,) * (x.ndim - 1), (0,)) log_probs = lax.sub(jnp.sum(xlogy(lax.sub(alpha, one), x), axis=0), normalize_term) - return jnp.where(_is_simplex(x), log_probs, -jnp.inf) + return jnp.where(_is_simplex(x), log_probs, -np.inf) def pdf(x: ArrayLike, alpha: ArrayLike) -> Array: diff --git a/jax/_src/scipy/stats/expon.py b/jax/_src/scipy/stats/expon.py index ba80fa6fbcb1..3d5e3785837b 100644 --- a/jax/_src/scipy/stats/expon.py +++ b/jax/_src/scipy/stats/expon.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np -import jax.numpy as jnp -from jax import lax +from jax._src import lax +from jax._src import numpy as jnp from jax._src.numpy.util import promote_args_inexact from jax._src.typing import Array, ArrayLike @@ -54,7 +55,7 @@ def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: log_scale = lax.log(scale) linear_term = lax.div(lax.sub(x, loc), scale) log_probs = lax.neg(lax.add(linear_term, log_scale)) - return jnp.where(lax.lt(x, loc), -jnp.inf, log_probs) + return jnp.where(lax.lt(x, loc), -np.inf, log_probs) def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: @@ -264,6 +265,6 @@ def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: neg_scaled_q = lax.div(lax.sub(loc, q), scale) return jnp.where( jnp.isnan(q) | (q < 0) | (q > 1), - jnp.nan, + np.nan, lax.neg(lax.log1p(neg_scaled_q)), ) diff --git a/jax/_src/scipy/stats/gamma.py b/jax/_src/scipy/stats/gamma.py index 97d73a3ee443..d74ef391f53c 100644 --- a/jax/_src/scipy/stats/gamma.py +++ b/jax/_src/scipy/stats/gamma.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jax import lax -import jax.numpy as jnp +import numpy as np + +from jax._src import lax +from jax._src import numpy as jnp from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import promote_args_inexact +from jax._src.scipy.special import gammaln, xlogy, gammainc, gammaincc from jax._src.typing import Array, ArrayLike -from jax.scipy.special import gammaln, xlogy, gammainc, gammaincc def logpdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: @@ -57,7 +59,7 @@ def logpdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) log_linear_term = lax.sub(xlogy(lax.sub(a, one), y), y) shape_terms = lax.add(gammaln(a), lax.log(scale)) log_probs = lax.sub(log_linear_term, shape_terms) - return jnp.where(ok, log_probs, -jnp.inf) + return jnp.where(ok, log_probs, -np.inf) def pdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: @@ -129,7 +131,7 @@ def cdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> lax.clamp( _lax_const(x, 0), lax.div(lax.sub(x, loc), scale), - _lax_const(x, jnp.inf), + _lax_const(x, np.inf), ) ) diff --git a/jax/_src/scipy/stats/gennorm.py b/jax/_src/scipy/stats/gennorm.py index 9d24708c066a..ef4ce17482c2 100644 --- a/jax/_src/scipy/stats/gennorm.py +++ b/jax/_src/scipy/stats/gennorm.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jax import lax +from jax._src import lax from jax._src.numpy.util import promote_args_inexact from jax._src.typing import Array, ArrayLike diff --git a/jax/_src/scipy/stats/geom.py b/jax/_src/scipy/stats/geom.py index 1d5133f9c3ea..82bf66d23cf9 100644 --- a/jax/_src/scipy/stats/geom.py +++ b/jax/_src/scipy/stats/geom.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jax import lax -import jax.numpy as jnp +import numpy as np + +from jax._src import lax +from jax._src import numpy as jnp from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import promote_args_inexact -from jax.scipy.special import xlog1py +from jax._src.scipy.special import xlog1py from jax._src.typing import Array, ArrayLike @@ -49,7 +51,7 @@ def logpmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array: one = _lax_const(k, 1) x = lax.sub(k, loc) log_probs = xlog1py(lax.sub(x, one), -p) + lax.log(p) - return jnp.where(lax.le(x, zero), -jnp.inf, log_probs) + return jnp.where(lax.le(x, zero), -np.inf, log_probs) def pmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array: diff --git a/jax/_src/scipy/stats/gumbel_l.py b/jax/_src/scipy/stats/gumbel_l.py new file mode 100644 index 000000000000..6ba7af2f8165 --- /dev/null +++ b/jax/_src/scipy/stats/gumbel_l.py @@ -0,0 +1,256 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +from jax._src import lax +from jax._src import numpy as jnp +from jax._src.lax.lax import _const as _lax_const +from jax._src.numpy.util import promote_args_inexact +from jax._src.typing import Array, ArrayLike +from jax._src.scipy.special import xlogy, xlog1py + + +def logpdf(x: ArrayLike, + loc: ArrayLike = 0, + scale: ArrayLike = 1) -> Array: + r""" + Gumbel Distribution (Left Skewed) log probability distribution function. + + JAX implementation of :obj:`scipy.stats.gumbel_l` ``logpdf``. + + .. math:: + + f_{pdf}(x; \mu, \beta) = \frac{1}{\beta} \exp\left( \frac{x - \mu}{\beta} - \exp\left( \frac{x - \mu}{\beta} \right) \right) + + Args: + x: ArrayLike, value at which to evaluate log(pdf) + loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0) + scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1) + + Returns: + array of logpdf values + + See Also: + - :func:`jax.scipy.stats.gumbel_l.pdf` + - :func:`jax.scipy.stats.gumbel_l.logcdf` + - :func:`jax.scipy.stats.gumbel_l.cdf` + - :func:`jax.scipy.stats.gumbel_l.ppf` + - :func:`jax.scipy.stats.gumbel_l.logsf` + - :func:`jax.scipy.stats.gumbel_l.sf` + """ + + x, loc, scale = promote_args_inexact("gumbel_l.logpdf", x, loc, scale) + ok = lax.gt(scale, _lax_const(scale, 0)) + # logpdf = -log(scale) + z - exp(z) + z = lax.div(lax.sub(x, loc), scale) + neg_log_scale = xlogy(-1, scale) + t2 = lax.sub(z, lax.exp(z)) + log_pdf = lax.add(neg_log_scale, t2) + return jnp.where(ok, log_pdf, np.nan) + + +def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r""" + Gumbel Distribution (Left Skewed) probability distribution function. + + JAX implementation of :obj:`scipy.stats.gumbel_l` ``pdf``. + + .. math:: + + f_{pdf}(x; \mu, \beta) = \frac{1}{\beta} \exp\left( \frac{x - \mu}{\beta} - \exp\left( \frac{x - \mu}{\beta} \right) \right) + + Args: + x: ArrayLike, value at which to evaluate pdf + loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0) + scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1) + + Returns: + array of pdf values + + See Also: + - :func:`jax.scipy.stats.gumbel_l.logpdf` + - :func:`jax.scipy.stats.gumbel_l.logcdf` + - :func:`jax.scipy.stats.gumbel_l.cdf` + - :func:`jax.scipy.stats.gumbel_l.ppf` + - :func:`jax.scipy.stats.gumbel_l.logsf` + - :func:`jax.scipy.stats.gumbel_l.sf` + """ + return lax.exp(logpdf(x, loc, scale)) + + +def logcdf(x: ArrayLike, + loc: ArrayLike = 0, + scale: ArrayLike = 1) -> Array: + r""" + Gumbel Distribution (Left Skewed) log cumulative density function. + + JAX implementation of :obj:`scipy.stats.gumbel_l` ``logcdf``. + + .. math:: + + f_{cdf}(x; \mu, \beta) = 1 - \exp\left( -\exp\left( \frac{x - \mu}{\beta} \right) \right) + + Args: + x: ArrayLike, value at which to evaluate log(cdf) + loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0) + scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1) + + Returns: + array of logcdf values + + See Also: + - :func:`jax.scipy.stats.gumbel_l.logpdf` + - :func:`jax.scipy.stats.gumbel_l.pdf` + - :func:`jax.scipy.stats.gumbel_l.cdf` + - :func:`jax.scipy.stats.gumbel_l.ppf` + - :func:`jax.scipy.stats.gumbel_l.logsf` + - :func:`jax.scipy.stats.gumbel_l.sf` + """ + x, loc, scale = promote_args_inexact("gumbel_l.logcdf", x, loc, scale) + ok = lax.gt(scale, _lax_const(scale, 0)) + z = lax.div(lax.sub(x, loc), scale) + neg_exp_z = lax.neg(lax.exp(z)) + # xlog1p fails here, that's why log1p is used here + # even log1p fails for some cases when using float64 mode + # so we're using this formula which is stable + log_cdf = lax.log(-lax.expm1(neg_exp_z)) + return jnp.where(ok, log_cdf, np.nan) + + +def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r""" + Gumbel Distribution (Left Skewed) cumulative density function. + + JAX implementation of :obj:`scipy.stats.gumbel_l` ``cdf``. + + .. math:: + + f_{cdf}(x; \mu, \beta) = 1 - \exp\left( -\exp\left( \frac{x - \mu}{\beta} \right) \right) + + Args: + x: ArrayLike, value at which to evaluate cdf + loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0) + scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1) + + Returns: + array of cdf values + + See Also: + - :func:`jax.scipy.stats.gumbel_l.logpdf` + - :func:`jax.scipy.stats.gumbel_l.pdf` + - :func:`jax.scipy.stats.gumbel_l.logcdf` + - :func:`jax.scipy.stats.gumbel_l.ppf` + - :func:`jax.scipy.stats.gumbel_l.logsf` + - :func:`jax.scipy.stats.gumbel_l.sf` + """ + return lax.exp(logcdf(x, loc, scale)) + + +def ppf(p: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r""" + Gumbel Distribution (Left Skewed) percent point function (inverse of CDF) + + JAX implementation of :obj:`scipy.stats.gumbel_l` ``ppf``. + + .. math:: + + F_{ppf}}(p; \mu, \beta) = \mu + \beta \log\left( -\log(1 - p) \right) + + Args: + p: ArrayLike, probability value (quantile) at which to evaluate ppf + loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0) + scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1) + + Returns: + array of ppf values + + See Also: + - :func:`jax.scipy.stats.gumbel_l.logpdf` + - :func:`jax.scipy.stats.gumbel_l.pdf` + - :func:`jax.scipy.stats.gumbel_l.logcdf` + - :func:`jax.scipy.stats.gumbel_l.cdf` + - :func:`jax.scipy.stats.gumbel_l.logsf` + - :func:`jax.scipy.stats.gumbel_l.sf` + """ + p, loc, scale = promote_args_inexact("gumbel_l.ppf", p, loc, scale) + ok = lax.bitwise_and(lax.gt(p, _lax_const(p, 0)), + lax.lt(p, _lax_const(p, 1))) + # quantile = loc + (scale)*log(-log(1 - p)) + t1 = xlog1py(-1, lax.neg(p)) + # xlogp failed here too, that's why log is used + t = lax.mul(scale, lax.log(t1)) + quantile = lax.add(loc, t) + return jnp.where(ok, quantile, np.nan) + + +def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r""" + Gumbel Distribution (Left Skewed) survival function. + + JAX implementation of :obj:`scipy.stats.gumbel_l` ``sf``. + + .. math:: + + f_{sf}(x; \mu, \beta) = 1 - f_{cdf}(x, \mu, \beta) + + Args: + x: ArrayLike, value at which to evaluate survival function + loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0) + scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1) + + Returns: + array of sf values (1 - cdf) + + See Also: + - :func:`jax.scipy.stats.gumbel_l.logpdf` + - :func:`jax.scipy.stats.gumbel_l.pdf` + - :func:`jax.scipy.stats.gumbel_l.logcdf` + - :func:`jax.scipy.stats.gumbel_l.cdf` + - :func:`jax.scipy.stats.gumbel_l.logsf` + """ + return jnp.exp(logsf(x, loc, scale)) + + +def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r""" + Gumbel Distribution (Left Skewed) log survival function. + + JAX implementation of :obj:`scipy.stats.gumbel_l` ``logsf``. + + .. math:: + + f_{sf}(x; \mu, \beta) = 1 - f_{cdf}(x, \mu, \beta) + + Args: + x: ArrayLike, value at which to evaluate log survival function + loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0) + scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1) + + Returns: + array of logsf values + + See Also: + - :func:`jax.scipy.stats.gumbel_l.logpdf` + - :func:`jax.scipy.stats.gumbel_l.pdf` + - :func:`jax.scipy.stats.gumbel_l.logcdf` + - :func:`jax.scipy.stats.gumbel_l.cdf` + - :func:`jax.scipy.stats.gumbel_l.sf` + """ + x, loc, scale = promote_args_inexact("gumbel_l.logsf", x, loc, scale) + ok = lax.gt(scale, _lax_const(scale, 0)) + # logsf = -exp(z) + z = lax.div(lax.sub(x, loc), scale) + log_sf = lax.neg(lax.exp(z)) + return jnp.where(ok, log_sf, np.nan) diff --git a/jax/_src/scipy/stats/gumbel_r.py b/jax/_src/scipy/stats/gumbel_r.py new file mode 100644 index 000000000000..91bfafd9962b --- /dev/null +++ b/jax/_src/scipy/stats/gumbel_r.py @@ -0,0 +1,257 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +from jax._src import lax +from jax._src import numpy as jnp +from jax._src.lax.lax import _const as _lax_const +from jax._src.numpy.util import promote_args_inexact +from jax._src.typing import Array, ArrayLike +from jax._src.scipy.special import xlogy +from jax._src.nn.functions import log1mexp + + +def logpdf(x: ArrayLike, + loc: ArrayLike = 0, + scale: ArrayLike = 1) -> Array: + r""" + Gumbel Distribution (Right Skewed) log probability distribution function. + + JAX implementation of :obj:`scipy.stats.gumbel_l` ``logpdf``. + + .. math:: + + f_{pdf}(x; \mu, \beta) = \frac{1}{\beta} \exp\left( -\frac{x - \mu}{\beta} - \exp\left( -\frac{x - \mu}{\beta} \right) \right) + + Args: + x: ArrayLike, value at which to evaluate log(pdf) + loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0) + scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1) + + Returns: + array of logpdf values + + See Also: + - :func:`jax.scipy.stats.gumbel_r.pdf` + - :func:`jax.scipy.stats.gumbel_r.logcdf` + - :func:`jax.scipy.stats.gumbel_r.cdf` + - :func:`jax.scipy.stats.gumbel_r.ppf` + - :func:`jax.scipy.stats.gumbel_r.sf` + - :func:`jax.scipy.stats.gumbel_r.logsf` + """ + + x, loc, scale = promote_args_inexact("gumbel_r.logpdf", x, loc, scale) + ok = lax.gt(scale, _lax_const(scale, 0)) + z = lax.div(lax.sub(x, loc), scale) + # logpdf = -log(beta) - (z + exp(-z)) + neg_log_scale = xlogy(-1, scale) + t2 = lax.neg(lax.add(z, lax.exp(lax.neg(z)))) + log_pdf = lax.add(neg_log_scale, t2) + return jnp.where(ok, log_pdf, np.nan) + + +def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r""" + Gumbel Distribution (Right Skewed) probability distribution function. + + JAX implementation of :obj:`scipy.stats.gumbel_r` ``pdf``. + + .. math:: + + f_{pdf}(x; \mu, \beta) = \frac{1}{\beta} \exp\left( -\frac{x - \mu}{\beta} - \exp\left( -\frac{x - \mu}{\beta} \right) \right) + + Args: + x: ArrayLike, value at which to evaluate pdf + loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0) + scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1) + + Returns: + array of pdf values + + See Also: + - :func:`jax.scipy.stats.gumbel_r.logpdf` + - :func:`jax.scipy.stats.gumbel_r.logcdf` + - :func:`jax.scipy.stats.gumbel_r.cdf` + - :func:`jax.scipy.stats.gumbel_r.ppf` + - :func:`jax.scipy.stats.gumbel_r.sf` + - :func:`jax.scipy.stats.gumbel_r.logsf` + """ + return lax.exp(logpdf(x, loc, scale)) + + +def logcdf(x: ArrayLike, + loc: ArrayLike = 0, + scale: ArrayLike = 1) -> Array: + r""" + Gumbel Distribution (Right Skewed) log cumulative density function. + + JAX implementation of :obj:`scipy.stats.gumbel_r` ``logcdf``. + + .. math:: + + f_{cdf}(x; \mu, \beta) = \exp\left( -\exp\left( -\frac{x - \mu}{\beta} \right) \right) + + Args: + x: ArrayLike, value at which to evaluate log(cdf) + loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0) + scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1) + + Returns: + array of logcdf values + + See Also: + - :func:`jax.scipy.stats.gumbel_r.logpdf` + - :func:`jax.scipy.stats.gumbel_r.pdf` + - :func:`jax.scipy.stats.gumbel_r.cdf` + - :func:`jax.scipy.stats.gumbel_r.ppf` + - :func:`jax.scipy.stats.gumbel_r.sf` + - :func:`jax.scipy.stats.gumbel_r.logsf` + """ + x, loc, scale = promote_args_inexact("gumbel_r.logcdf", x, loc, scale) + ok = lax.gt(scale, _lax_const(scale, 0)) + z = lax.div(lax.sub(x, loc), scale) + # log cdf = -exp(-z) + log_cdf = lax.neg(lax.exp(lax.neg(z))) + return jnp.where(ok, log_cdf, np.nan) + + +def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r""" + Gumbel Distribution (Right Skewed) cumulative density function. + + JAX implementation of :obj:`scipy.stats.gumbel_r` ``cdf``. + + .. math:: + + f_{cdf}(x; \mu, \beta) = \exp\left( -\exp\left( -\frac{x - \mu}{\beta} \right) \right) + + Args: + x: ArrayLike, value at which to evaluate cdf + loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0) + scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1) + + Returns: + array of cdf values + + See Also: + - :func:`jax.scipy.stats.gumbel_r.logpdf` + - :func:`jax.scipy.stats.gumbel_r.pdf` + - :func:`jax.scipy.stats.gumbel_r.logcdf` + - :func:`jax.scipy.stats.gumbel_r.ppf` + - :func:`jax.scipy.stats.gumbel_r.sf` + - :func:`jax.scipy.stats.gumbel_r.logsf` + """ + return lax.exp(logcdf(x, loc, scale)) + + +def ppf(p: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r""" + Gumbel Distribution (Right Skewed) percent point function. + + JAX implementation of :obj:`scipy.stats.gumbel_r` ``ppf``. + + .. math:: + + F(p; \mu, \beta) = \mu - \beta \log\left( -\log(p) \right) + + Args: + p: ArrayLike, probability value (quantile) at which to evaluate ppf + loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0) + scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1) + + Returns: + array of ppf values + + See Also: + - :func:`jax.scipy.stats.gumbel_r.logpdf` + - :func:`jax.scipy.stats.gumbel_r.pdf` + - :func:`jax.scipy.stats.gumbel_r.logcdf` + - :func:`jax.scipy.stats.gumbel_r.cdf` + - :func:`jax.scipy.stats.gumbel_r.sf` + - :func:`jax.scipy.stats.gumbel_r.logsf` + """ + p, loc, scale = promote_args_inexact("gumbel_r.ppf", p, loc, scale) + # 0 < p < 1 + ok = lax.bitwise_and(lax.gt(p, _lax_const(p, 0)), + lax.lt(p, _lax_const(p, 1))) + + # quantile = loc - (scale)*log(-log(p)) + t1 = xlogy(-1, p) + t = lax.mul(scale, lax.log(t1)) + quantile = lax.sub(loc, t) + return jnp.where(ok, quantile, np.nan) + + +def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r""" + Gumbel Distribution (Right Skewed) survival function. + + JAX implementation of :obj:`scipy.stats.gumbel_r` ``sf``. + + .. math:: + + f_{sf}(x; \mu, \beta) = 1 - F_{cdf}(x; \mu, \beta) + + Args: + x: ArrayLike, value at which to evaluate survival function + loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0) + scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1) + + Returns: + array of sf values (1 - cdf) + + See Also: + - :func:`jax.scipy.stats.gumbel_r.logpdf` + - :func:`jax.scipy.stats.gumbel_r.pdf` + - :func:`jax.scipy.stats.gumbel_r.logcdf` + - :func:`jax.scipy.stats.gumbel_r.cdf` + - :func:`jax.scipy.stats.gumbel_r.logsf` + """ + x, loc, scale = promote_args_inexact("gumbel_r.sf", x, loc, scale) + ok = lax.gt(scale, _lax_const(scale, 0)) + # sf = 1 - exp(-exp(-z)) + neg_z = lax.div(lax.sub(loc, x), scale) + t1 = lax.exp(lax.neg(lax.exp(neg_z))) + _sf = lax.sub(_lax_const(x, 1), t1) + return jnp.where(ok, _sf, np.nan) + + +def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r""" + Gumbel Distribution (Right Skewed) log survival function. + + JAX implementation of :obj:`scipy.stats.gumbel_r` ``logsf``. + + Args: + x: ArrayLike, value at which to evaluate log survival function + loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0) + scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1) + + Returns: + array of logsf values + + See Also: + - :func:`jax.scipy.stats.gumbel_r.logpdf` + - :func:`jax.scipy.stats.gumbel_r.pdf` + - :func:`jax.scipy.stats.gumbel_r.logcdf` + - :func:`jax.scipy.stats.gumbel_r.cdf` + - :func:`jax.scipy.stats.gumbel_r.sf` + """ + x, loc, scale = promote_args_inexact("gumbel_r.logsf", x, loc, scale) + ok = lax.gt(scale, _lax_const(scale, 0)) + # logsf = log(1 - exp(-exp(-z))) + neg_z = lax.div(lax.sub(loc, x), scale) + log_sf = log1mexp(lax.exp(neg_z)) + return jnp.where(ok, log_sf, np.nan) diff --git a/jax/_src/scipy/stats/kde.py b/jax/_src/scipy/stats/kde.py index a52ffb48bd2b..5e9d94343c65 100644 --- a/jax/_src/scipy/stats/kde.py +++ b/jax/_src/scipy/stats/kde.py @@ -18,11 +18,14 @@ import numpy as np -import jax.numpy as jnp -from jax import jit, lax, random, vmap +from jax._src import api +from jax._src import dtypes +from jax._src import lax +from jax._src import numpy as jnp +from jax._src import random from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact +from jax._src.scipy import linalg, special from jax._src.tree_util import register_pytree_node_class -from jax.scipy import linalg, special @register_pytree_node_class @@ -48,7 +51,7 @@ class gaussian_kde: def __init__(self, dataset, bw_method=None, weights=None): check_arraylike("gaussian_kde", dataset) dataset = jnp.atleast_2d(dataset) - if jnp.issubdtype(lax.dtype(dataset), np.complexfloating): + if dtypes.issubdtype(lax.dtype(dataset), np.complexfloating): raise NotImplementedError("gaussian_kde does not support complex data") if not dataset.size > 1: raise ValueError("`dataset` input should have multiple elements.") @@ -153,7 +156,7 @@ def integrate_box_1d(self, low, high): """Integrate the distribution over the given limits.""" if self.d != 1: raise ValueError("integrate_box_1d() only handles 1D pdfs") - if jnp.ndim(low) != 0 or jnp.ndim(high) != 0: + if np.ndim(low) != 0 or np.ndim(high) != 0: raise ValueError( "the limits of integration in integrate_box_1d must be scalars") sigma = jnp.squeeze(jnp.sqrt(self.covariance)) @@ -171,9 +174,9 @@ def integrate_kde(self, other): norm = 1.0 / norm sm, lg = (self, other) if self.n < other.n else (other, self) - result = vmap(partial(_gaussian_kernel_convolve, chol, norm, lg.dataset, - lg.weights), - in_axes=1)(sm.dataset) + result = api.vmap(partial(_gaussian_kernel_convolve, chol, norm, lg.dataset, + lg.weights), + in_axes=1)(sm.dataset) return jnp.sum(result * sm.weights) def resample(self, key, shape=()): @@ -222,7 +225,7 @@ def set_bandwidth(self, bw_method=None): "dynamically changing the bandwidth method is not supported") def _reshape_points(self, points): - if jnp.issubdtype(lax.dtype(points), np.complexfloating): + if dtypes.issubdtype(lax.dtype(points), np.complexfloating): raise NotImplementedError( "gaussian_kde does not support complex coordinates") points = jnp.atleast_2d(points) @@ -244,7 +247,7 @@ def _gaussian_kernel_convolve(chol, norm, target, weights, mean): return norm * jnp.sum(jnp.exp(-arg) * weights) -@partial(jit, static_argnums=0) +@api.jit(static_argnums=0) def _gaussian_kernel_eval(in_log, points, values, xi, precision): points, values, xi, precision = promote_dtypes_inexact( points, values, xi, precision) @@ -269,9 +272,9 @@ def kernel(x_test, x_train, y_train): return y_train * jnp.exp(arg) reduce = special.logsumexp if in_log else jnp.sum - reduced_kernel = lambda x: reduce(vmap(kernel, in_axes=(None, 0, 0)) + reduced_kernel = lambda x: reduce(api.vmap(kernel, in_axes=(None, 0, 0)) (x, points, values), axis=0) - mapped_kernel = vmap(reduced_kernel) + mapped_kernel = api.vmap(reduced_kernel) return mapped_kernel(xi) diff --git a/jax/_src/scipy/stats/laplace.py b/jax/_src/scipy/stats/laplace.py index 8761a2cb864f..c3902efbbf57 100644 --- a/jax/_src/scipy/stats/laplace.py +++ b/jax/_src/scipy/stats/laplace.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jax import lax +from jax._src import lax from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import promote_args_inexact from jax._src.typing import Array, ArrayLike diff --git a/jax/_src/scipy/stats/logistic.py b/jax/_src/scipy/stats/logistic.py index 5e41e4f7c8e5..a12fa315476f 100644 --- a/jax/_src/scipy/stats/logistic.py +++ b/jax/_src/scipy/stats/logistic.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jax.scipy.special import expit, logit -from jax import lax -import jax.numpy as jnp +from jax._src import lax +from jax._src import numpy as jnp from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import promote_args_inexact +from jax._src.scipy.special import expit, logit from jax._src.typing import Array, ArrayLike diff --git a/jax/_src/scipy/stats/multinomial.py b/jax/_src/scipy/stats/multinomial.py index fe9fd6781423..4270537f03ee 100644 --- a/jax/_src/scipy/stats/multinomial.py +++ b/jax/_src/scipy/stats/multinomial.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np -from jax import lax -import jax.numpy as jnp +from jax._src import dtypes +from jax._src import lax +from jax._src import numpy as jnp from jax._src.numpy.util import promote_args_inexact, promote_args_numeric from jax._src.scipy.special import gammaln, xlogy from jax._src.typing import Array, ArrayLike @@ -46,12 +48,12 @@ def logpmf(x: ArrayLike, n: ArrayLike, p: ArrayLike) -> Array: """ p, = promote_args_inexact("multinomial.logpmf", p) x, n = promote_args_numeric("multinomial.logpmf", x, n) - if not jnp.issubdtype(x.dtype, jnp.integer): + if not dtypes.issubdtype(x.dtype, np.integer): raise ValueError(f"x and n must be of integer type; got x.dtype={x.dtype}, n.dtype={n.dtype}") x = x.astype(p.dtype) n = n.astype(p.dtype) logprobs = gammaln(n + 1) + jnp.sum(xlogy(x, p) - gammaln(x + 1), axis=-1) - return jnp.where(jnp.equal(jnp.sum(x), n), logprobs, -jnp.inf) + return jnp.where(jnp.equal(jnp.sum(x), n), logprobs, -np.inf) def pmf(x: ArrayLike, n: ArrayLike, p: ArrayLike) -> Array: diff --git a/jax/_src/scipy/stats/multivariate_normal.py b/jax/_src/scipy/stats/multivariate_normal.py index 8ba34703aada..2a6848de6588 100644 --- a/jax/_src/scipy/stats/multivariate_normal.py +++ b/jax/_src/scipy/stats/multivariate_normal.py @@ -16,8 +16,10 @@ import numpy as np -from jax import lax -from jax import numpy as jnp +from jax._src import lax +from jax._src import numpy as jnp +from jax._src.numpy import einsum as jnp_einsum +from jax._src.numpy import vectorize as jnp_vectorize from jax._src.numpy.util import promote_dtypes_inexact from jax._src.typing import Array, ArrayLike @@ -58,17 +60,17 @@ def logpdf(x: ArrayLike, mean: ArrayLike, cov: ArrayLike, allow_singular: None = n = mean.shape[-1] if not np.shape(cov): y = x - mean - return (-1/2 * jnp.einsum('...i,...i->...', y, y) / cov + return (-1/2 * jnp_einsum.einsum('...i,...i->...', y, y) / cov - n/2 * (jnp.log(2*np.pi) + jnp.log(cov))) else: if cov.ndim < 2 or cov.shape[-2:] != (n, n): raise ValueError("multivariate_normal.logpdf got incompatible shapes") L = lax.linalg.cholesky(cov) - y = jnp.vectorize( + y = jnp_vectorize.vectorize( partial(lax.linalg.triangular_solve, lower=True, transpose_a=True), signature="(n,n),(n)->(n)" )(L, x - mean) - return (-1/2 * jnp.einsum('...i,...i->...', y, y) - n/2 * jnp.log(2*np.pi) + return (-1/2 * jnp_einsum.einsum('...i,...i->...', y, y) - n/2 * jnp.log(2*np.pi) - jnp.log(L.diagonal(axis1=-1, axis2=-2)).sum(-1)) diff --git a/jax/_src/scipy/stats/nbinom.py b/jax/_src/scipy/stats/nbinom.py index a8d968526e70..4e94a5798a17 100644 --- a/jax/_src/scipy/stats/nbinom.py +++ b/jax/_src/scipy/stats/nbinom.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License -from jax import lax -import jax.numpy as jnp +import numpy as np + +from jax._src import lax +from jax._src import numpy as jnp from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import promote_args_inexact from jax._src.scipy.special import gammaln, xlogy @@ -53,7 +55,7 @@ def logpmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Arra ) log_linear_term = lax.add(xlogy(n, p), xlogy(y, lax.sub(one, p))) log_probs = lax.add(comb_term, log_linear_term) - return jnp.where(lax.lt(k, loc), -jnp.inf, log_probs) + return jnp.where(lax.lt(k, loc), -np.inf, log_probs) def pmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array: diff --git a/jax/_src/scipy/stats/norm.py b/jax/_src/scipy/stats/norm.py index 6dbebe2f6d11..c44e8b358ff4 100644 --- a/jax/_src/scipy/stats/norm.py +++ b/jax/_src/scipy/stats/norm.py @@ -14,12 +14,12 @@ import numpy as np -from jax import lax -import jax.numpy as jnp +from jax._src import lax +from jax._src import numpy as jnp from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import promote_args_inexact +from jax._src.scipy import special from jax._src.typing import Array, ArrayLike -from jax.scipy import special def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: diff --git a/jax/_src/scipy/stats/pareto.py b/jax/_src/scipy/stats/pareto.py index 0b0c9e1a4993..0f239dd7db95 100644 --- a/jax/_src/scipy/stats/pareto.py +++ b/jax/_src/scipy/stats/pareto.py @@ -12,14 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jax import lax -import jax.numpy as jnp + +import numpy as np + +from jax._src import lax +from jax._src import numpy as jnp from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import promote_args_inexact from jax._src.typing import Array, ArrayLike -def logpdf(x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: +def logpdf( + x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1 +) -> Array: r"""Pareto log probability distribution function. JAX implementation of :obj:`scipy.stats.pareto` ``logpdf``. @@ -45,17 +50,26 @@ def logpdf(x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) array of logpdf values. See Also: - :func:`jax.scipy.stats.pareto.pdf` + - :func:`jax.scipy.stats.pareto.logcdf` + - :func:`jax.scipy.stats.pareto.logsf` + - :func:`jax.scipy.stats.pareto.cdf` + - :func:`jax.scipy.stats.pareto.pdf` + - :func:`jax.scipy.stats.pareto.ppf` + - :func:`jax.scipy.stats.pareto.sf` """ x, b, loc, scale = promote_args_inexact("pareto.logpdf", x, b, loc, scale) one = _lax_const(x, 1) scaled_x = lax.div(lax.sub(x, loc), scale) normalize_term = lax.log(lax.div(scale, b)) - log_probs = lax.neg(lax.add(normalize_term, lax.mul(lax.add(b, one), lax.log(scaled_x)))) - return jnp.where(lax.lt(x, lax.add(loc, scale)), -jnp.inf, log_probs) + log_probs = lax.neg( + lax.add(normalize_term, lax.mul(lax.add(b, one), lax.log(scaled_x))) + ) + return jnp.where(lax.lt(x, lax.add(loc, scale)), -np.inf, log_probs) -def pdf(x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: +def pdf( + x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1 +) -> Array: r"""Pareto probability distribution function. JAX implementation of :obj:`scipy.stats.pareto` ``pdf``. @@ -81,6 +95,219 @@ def pdf(x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> array of pdf values. See Also: - :func:`jax.scipy.stats.pareto.logpdf` + - :func:`jax.scipy.stats.pareto.logcdf` + - :func:`jax.scipy.stats.pareto.logpdf` + - :func:`jax.scipy.stats.pareto.logsf` + - :func:`jax.scipy.stats.pareto.cdf` + - :func:`jax.scipy.stats.pareto.ppf` + - :func:`jax.scipy.stats.pareto.sf` """ return lax.exp(logpdf(x, b, loc, scale)) + + +def cdf( + x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1 +) -> Array: + r"""Pareto cumulative distribution function. + + JAX implementation of :obj:`scipy.stats.pareto` ``cdf``. + + The Pareto cumulative distribution function is given by + + .. math:: + + F(x, b) = \begin{cases} + 1 - x^{-b} & x \ge 1\\ + 0 & x < 1 + \end{cases} + + and is defined for :math:`b > 0`. + + Args: + x: arraylike, value at which to evaluate the CDF + b: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of CDF values. + + See Also: + - :func:`jax.scipy.stats.pareto.logcdf` + - :func:`jax.scipy.stats.pareto.logpdf` + - :func:`jax.scipy.stats.pareto.logsf` + - :func:`jax.scipy.stats.pareto.pdf` + - :func:`jax.scipy.stats.pareto.ppf` + - :func:`jax.scipy.stats.pareto.sf` + """ + x, b, loc, scale = promote_args_inexact("pareto.cdf", x, b, loc, scale) + one = _lax_const(x, 1) + zero = _lax_const(x, 0) + scaled_x = lax.div(lax.sub(x, loc), scale) + cdf = lax.sub(one, lax.pow(scaled_x, lax.neg(b))) + return jnp.where(lax.lt(x, lax.add(loc, scale)), zero, cdf) + + +def logcdf( + x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1 +) -> Array: + r"""Pareto log cumulative distribution function. + + JAX implementation of :obj:`scipy.stats.pareto` ``logcdf``. + + The Pareto cumulative distribution function is given by + + .. math:: + + F(x, b) = \begin{cases} + 1 - x^{-b} & x \ge 1\\ + 0 & x < 1 + \end{cases} + + and is defined for :math:`b > 0`. + + Args: + x: arraylike, value at which to evaluate the CDF + b: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of logCDF values. + + See Also: + - :func:`jax.scipy.stats.pareto.logpdf` + - :func:`jax.scipy.stats.pareto.logsf` + - :func:`jax.scipy.stats.pareto.cdf` + - :func:`jax.scipy.stats.pareto.pdf` + - :func:`jax.scipy.stats.pareto.ppf` + - :func:`jax.scipy.stats.pareto.sf` + """ + x, b, loc, scale = promote_args_inexact("pareto.logcdf", x, b, loc, scale) + scaled_x = lax.div(lax.sub(x, loc), scale) + logcdf_val = lax.log1p(lax.neg(lax.pow(scaled_x, lax.neg(b)))) + return jnp.where(lax.lt(x, lax.add(loc, scale)), -np.inf, logcdf_val) + + +def logsf( + x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1 +) -> Array: + r"""Pareto log survival function. + + JAX implementation of :obj:`scipy.stats.pareto` ``logsf``. + + The Pareto survival function is given by + + .. math:: + + S(x, b) = \begin{cases} + x^{-b} & x \ge 1\\ + 1 & x < 1 + \end{cases} + + and is defined for :math:`b > 0`. + + Args: + x: arraylike, value at which to evaluate the survival function + b: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of log survival function values. + + See Also: + - :func:`jax.scipy.stats.pareto.logcdf` + - :func:`jax.scipy.stats.pareto.logpdf` + - :func:`jax.scipy.stats.pareto.cdf` + - :func:`jax.scipy.stats.pareto.pdf` + - :func:`jax.scipy.stats.pareto.ppf` + - :func:`jax.scipy.stats.pareto.sf` + """ + x, b, loc, scale = promote_args_inexact("pareto.logsf", x, b, loc, scale) + zero = _lax_const(x, 0) + scaled_x = lax.div(lax.sub(x, loc), scale) + logsf_val = lax.neg(lax.mul(b, lax.log(scaled_x))) + return jnp.where(lax.lt(x, lax.add(loc, scale)), zero, logsf_val) + + +def sf( + x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1 +) -> Array: + r"""Pareto survival function. + + JAX implementation of :obj:`scipy.stats.pareto` ``sf``. + + The Pareto survival function is given by + + .. math:: + + S(x, b) = \begin{cases} + x^{-b} & x \ge 1\\ + 1 & x < 1 + \end{cases} + + and is defined for :math:`b > 0`. + + Args: + x: arraylike, value at which to evaluate the survival function + b: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of survival function values. + + See Also: + - :func:`jax.scipy.stats.pareto.logcdf` + - :func:`jax.scipy.stats.pareto.logpdf` + - :func:`jax.scipy.stats.pareto.logsf` + - :func:`jax.scipy.stats.pareto.cdf` + - :func:`jax.scipy.stats.pareto.pdf` + - :func:`jax.scipy.stats.pareto.ppf` + """ + return lax.exp(logsf(x, b, loc, scale)) + + +def ppf( + q: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1 +) -> Array: + r"""Pareto percent point function (inverse CDF). + + JAX implementation of :obj:`scipy.stats.pareto` ``ppf``. + + The Pareto percent point function is the inverse of the Pareto CDF, and is + given by + + .. math:: + + F^{-1}(q, b) = \begin{cases} + (1 - q)^{-1/b} & 0 \le q < 1\\ + \text{NaN} & \text{otherwise} + \end{cases} + + and is defined for :math:`b > 0`. + + Args: + q: arraylike, value at which to evaluate the inverse CDF + b: arraylike, distribution shape parameter + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of percent point function values. + + See Also: + - :func:`jax.scipy.stats.pareto.logcdf` + - :func:`jax.scipy.stats.pareto.logpdf` + - :func:`jax.scipy.stats.pareto.logsf` + - :func:`jax.scipy.stats.pareto.cdf` + - :func:`jax.scipy.stats.pareto.pdf` + - :func:`jax.scipy.stats.pareto.sf` + """ + q, b, loc, scale = promote_args_inexact("pareto.ppf", q, b, loc, scale) + one = _lax_const(q, 1) + ppf_val = lax.add( + loc, lax.mul(scale, lax.pow(lax.sub(one, q), lax.neg(lax.div(one, b)))) + ) + return jnp.where(jnp.isnan(q) | (q < 0) | (q > 1), np.nan, ppf_val) diff --git a/jax/_src/scipy/stats/poisson.py b/jax/_src/scipy/stats/poisson.py index 84f4cfe89208..bb2f9399dfdc 100644 --- a/jax/_src/scipy/stats/poisson.py +++ b/jax/_src/scipy/stats/poisson.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jax import lax -import jax.numpy as jnp +import numpy as np + +from jax._src import lax +from jax._src import numpy as jnp from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.util import promote_args_inexact +from jax._src.numpy.util import promote_args_inexact, promote_dtypes_inexact, ensure_arraylike +from jax._src.scipy.special import xlogy, entr, gammaln, gammaincc from jax._src.typing import Array, ArrayLike -from jax.scipy.special import xlogy, gammaln, gammaincc def logpmf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array: @@ -50,7 +52,7 @@ def logpmf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array: x = lax.sub(k, loc) log_probs = xlogy(x, mu) - gammaln(x + 1) - mu return jnp.where(jnp.logical_or(lax.lt(x, zero), - lax.ne(jnp.round(k), k)), -jnp.inf, log_probs) + lax.ne(jnp.round(k), k)), -np.inf, log_probs) def pmf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array: @@ -112,3 +114,126 @@ def cdf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array: x = lax.sub(k, loc) p = gammaincc(jnp.floor(1 + x), mu) return jnp.where(lax.lt(x, zero), zero, p) + +def entropy(mu: ArrayLike, loc: ArrayLike = 0) -> Array: + r"""Shannon entropy of the Poisson distribution. + + JAX implementation of :obj:`scipy.stats.poisson` ``entropy``. + + The entropy :math:`H(X)` of a Poisson random variable + :math:`X \sim \text{Poisson}(\mu)` is defined as: + + .. math:: + + H(X) = -\sum_{k=0}^\infty p(k) \log p(k) + + where :math:`p(k) = e^{-\mu} \mu^k / k!` for + :math:`k \geq \max(0, \lfloor \text{loc} \rfloor)`. + + This implementation uses **regime switching** for numerical stability + and performance: + + - **Small** :math:`\mu < 10`: Direct summation over PMF with adaptive + upper bound :math:`k \leq \mu + 20` + - **Medium** :math:`10 \leq \mu < 100`: Summation with bound + :math:`k \leq \mu + 10\sqrt{\mu} + 20` + - **Large** :math:`\mu \geq 100`: Asymptotic Stirling approximation: + :math:`H(\mu) \approx \frac{1}{2} \log(2\pi e \mu) - \frac{1}{12\mu}` + + Matches SciPy to relative error :math:`< 10^{-5}` across all regimes. + + Args: + mu: arraylike, mean parameter of the Poisson distribution. + Must be ``> 0``. + loc: arraylike, optional location parameter (default: 0). + Accepted for API compatibility with scipy but does not + affect the entropy + + Returns: + Array of entropy values with shape broadcast from ``mu`` and ``loc``. + Returns ``NaN`` for ``mu <= 0``. + + Examples: + >>> from jax.scipy.stats import poisson + >>> poisson.entropy(5.0) + Array(2.204394, dtype=float32) + >>> poisson.entropy(jax.numpy.array([1, 10, 100])) + Array([1.3048419, 2.5614073, 3.7206903], dtype=float32) + + See Also: + - :func:`jax.scipy.stats.poisson.pmf` + - :func:`jax.scipy.stats.poisson.logpmf` + - :obj:`scipy.stats.poisson` + """ + mu, loc = ensure_arraylike("poisson.entropy", mu, loc) + promoted_mu, promoted_loc = promote_dtypes_inexact(mu, loc) + + #Note: loc does not affect the entropy - translation invariant + #it has only been taken to maintain compatibility with scipy api + result_shape = jnp.broadcast_shapes( + promoted_mu.shape, + promoted_loc.shape + ) + + mu_flat = jnp.ravel(promoted_mu) + zero_result = jnp.zeros_like(mu_flat) + + + # Choose the computation regime based on mu value + result = jnp.where( + mu_flat == 0, + zero_result, + jnp.where( + mu_flat < 10, + _entropy_small_mu(mu_flat), + jnp.where( + mu_flat < 100, + _entropy_medium_mu(mu_flat), + _entropy_large_mu(mu_flat) + ) + ) + ) + + result_mu_shape = jnp.reshape(result, promoted_mu.shape) + + # Restore original shape + return jnp.broadcast_to(result_mu_shape, result_shape) + +def _entropy_small_mu(mu: Array) -> Array: + """Entropy via direct PMF summation for small μ (< 10). + Uses adaptive upper bound k ≤ μ + 20 to capture >99.999% of mass. + """ + max_k = 35 + + k = jnp.arange(max_k, dtype=mu.dtype)[:, None] + probs = pmf(k, mu, 0) + + # Mask: only compute up to mu + 20 for each value + upper_bounds = jnp.ceil(mu + 20).astype(k.dtype) + mask = k < upper_bounds[None, :] + probs_masked = jnp.where(mask, probs, 0.0) + + return jnp.sum(entr(probs_masked), axis=0) + +def _entropy_medium_mu(mu: Array) -> Array: + """Entropy for medium mu (10-100): Adaptive bounds based on std dev. + + Bounds: k ≤ μ + 10√μ + 20. Caps at k=250 for JIT compatibility. + """ + max_k = 250 # Static bound for JIT. For mu<100, upper bound < 220 + + k = jnp.arange(max_k, dtype=mu.dtype)[:, None] + probs = pmf(k, mu, 0) + + upper_bounds = jnp.ceil(mu + 10 * jnp.sqrt(mu) + 20).astype(k.dtype) + mask = k < upper_bounds[None, :] + probs_masked = jnp.where(mask, probs, 0.0) + + return jnp.sum(entr(probs_masked), axis=0) + +def _entropy_large_mu(mu: Array) -> Array: + """Entropy for large mu (>= 100): Asymptotic approximation. + + Formula: H(λ) ≈ 0.5*log(2πeλ) - 1/(12λ) + O(λ^-2) + """ + return 0.5 * jnp.log(2 * np.pi * np.e * mu) - 1.0 / (12 * mu) diff --git a/jax/_src/scipy/stats/t.py b/jax/_src/scipy/stats/t.py index 2e276c831e28..95a837705039 100644 --- a/jax/_src/scipy/stats/t.py +++ b/jax/_src/scipy/stats/t.py @@ -15,7 +15,7 @@ import numpy as np -from jax import lax +from jax._src import lax from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import promote_args_inexact from jax._src.typing import Array, ArrayLike diff --git a/jax/_src/scipy/stats/truncnorm.py b/jax/_src/scipy/stats/truncnorm.py index a02e07d10480..1c012389aed3 100644 --- a/jax/_src/scipy/stats/truncnorm.py +++ b/jax/_src/scipy/stats/truncnorm.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jax import lax -import jax.numpy as jnp +import numpy as np + +from jax._src import lax +from jax._src import numpy as jnp from jax._src.numpy.util import promote_args_inexact from jax._src.scipy.stats import norm from jax._src.scipy.special import logsumexp, log_ndtr, ndtr @@ -105,8 +107,8 @@ def logpdf(x, a, b, loc=0, scale=1): val = lax.sub(norm.logpdf(x, loc, scale), _log_gauss_mass(a, b)) x_scaled = lax.div(lax.sub(x, loc), scale) - val = jnp.where((x_scaled < a) | (x_scaled > b), -jnp.inf, val) - val = jnp.where(a >= b, jnp.nan, val) + val = jnp.where((x_scaled < a) | (x_scaled > b), -np.inf, val) + val = jnp.where(a >= b, np.nan, val) return val @@ -257,9 +259,9 @@ def logcdf(x, a, b, loc=0, scale=1): logcdf = jnp.select( # third condition: avoid catastrophic cancellation (from scipy) [x >= b, x <= a, logcdf > -0.1, x > a], - [0, -jnp.inf, jnp.log1p(-jnp.exp(logsf)), logcdf] + [0, -np.inf, jnp.log1p(-jnp.exp(logsf)), logcdf] ) - logcdf = jnp.where(a >= b, jnp.nan, logcdf) + logcdf = jnp.where(a >= b, np.nan, logcdf) return logcdf diff --git a/jax/_src/scipy/stats/uniform.py b/jax/_src/scipy/stats/uniform.py index 8d36e23c1b70..cb794eed0b24 100644 --- a/jax/_src/scipy/stats/uniform.py +++ b/jax/_src/scipy/stats/uniform.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jax import lax -from jax import numpy as jnp -from jax.numpy import where, inf, logical_or +import numpy as np + +from jax._src import lax +from jax._src import numpy as jnp from jax._src.typing import Array, ArrayLike from jax._src.numpy.util import promote_args_inexact @@ -48,9 +49,9 @@ def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: """ x, loc, scale = promote_args_inexact("uniform.logpdf", x, loc, scale) log_probs = lax.neg(lax.log(scale)) - return where(logical_or(lax.gt(x, lax.add(loc, scale)), - lax.lt(x, loc)), - -inf, log_probs) + return jnp.where(jnp.logical_or(lax.gt(x, lax.add(loc, scale)), + lax.lt(x, loc)), + -np.inf, log_probs) def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: @@ -140,8 +141,8 @@ def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: - :func:`jax.scipy.stats.uniform.logpdf` """ q, loc, scale = promote_args_inexact("uniform.ppf", q, loc, scale) - return where( + return jnp.where( jnp.isnan(q) | (q < 0) | (q > 1), - jnp.nan, + np.nan, lax.add(loc, lax.mul(scale, q)) ) diff --git a/jax/_src/scipy/stats/vonmises.py b/jax/_src/scipy/stats/vonmises.py index 631cc8ee2145..6b76744a0c69 100644 --- a/jax/_src/scipy/stats/vonmises.py +++ b/jax/_src/scipy/stats/vonmises.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jax import lax -import jax.numpy as jnp +import numpy as np + +from jax._src import lax +from jax._src import numpy as jnp from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import promote_args_inexact from jax._src.typing import Array, ArrayLike @@ -46,7 +48,7 @@ def logpdf(x: ArrayLike, kappa: ArrayLike) -> Array: """ x, kappa = promote_args_inexact('vonmises.logpdf', x, kappa) zero = _lax_const(kappa, 0) - return jnp.where(lax.gt(kappa, zero), kappa * (jnp.cos(x) - 1) - jnp.log(2 * jnp.pi * lax.bessel_i0e(kappa)), jnp.nan) + return jnp.where(lax.gt(kappa, zero), kappa * (jnp.cos(x) - 1) - jnp.log(2 * np.pi * lax.bessel_i0e(kappa)), np.nan) def pdf(x: ArrayLike, kappa: ArrayLike) -> Array: diff --git a/jax/_src/scipy/stats/wrapcauchy.py b/jax/_src/scipy/stats/wrapcauchy.py index 26b24d7da447..68c30fdaced5 100644 --- a/jax/_src/scipy/stats/wrapcauchy.py +++ b/jax/_src/scipy/stats/wrapcauchy.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np -from jax import lax -import jax.numpy as jnp +from jax._src import lax +from jax._src import numpy as jnp from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import promote_args_inexact from jax._src.typing import Array, ArrayLike @@ -47,11 +48,11 @@ def logpdf(x: ArrayLike, c: ArrayLike) -> Array: return jnp.where( lax.gt(c, _lax_const(c, 0)) & lax.lt(c, _lax_const(c, 1)), jnp.where( - lax.ge(x, _lax_const(x, 0)) & lax.le(x, _lax_const(x, jnp.pi * 2)), - jnp.log(1 - c * c) - jnp.log(2 * jnp.pi) - jnp.log(1 + c * c - 2 * c * jnp.cos(x)), - -jnp.inf, + lax.ge(x, _lax_const(x, 0)) & lax.le(x, _lax_const(x, np.pi * 2)), + jnp.log(1 - c * c) - jnp.log(2 * np.pi) - jnp.log(1 + c * c - 2 * c * jnp.cos(x)), + -np.inf, ), - jnp.nan, + np.nan, ) diff --git a/jax/_src/shard_alike.py b/jax/_src/shard_alike.py index 7f1bab1fc847..0ac06b0708ad 100644 --- a/jax/_src/shard_alike.py +++ b/jax/_src/shard_alike.py @@ -71,10 +71,10 @@ def _shard_alike_batcher(batched_args, batch_dims): if xd == yd: return shard_alike(x, y), (xd, yd) elif xd is batching.not_mapped: - x = batching.broadcast(x, y.shape[yd], yd) + x = batching.broadcast(x, y.shape[yd], yd, None) return shard_alike(x, y), (yd, yd) elif yd is batching.not_mapped: - y = batching.broadcast(y, x.shape[xd], xd) + y = batching.broadcast(y, x.shape[xd], xd, None) return shard_alike(x, y), (xd, xd) else: y = batching.moveaxis(y, yd, xd) diff --git a/jax/_src/shard_map.py b/jax/_src/shard_map.py new file mode 100644 index 000000000000..7e7a7c0f442f --- /dev/null +++ b/jax/_src/shard_map.py @@ -0,0 +1,1962 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from collections.abc import Callable, Hashable, Sequence, Set +import enum +from functools import partial +import inspect +from math import prod +import operator as op +from typing import Any, TypeVar, Union + +import numpy as np + +from jax._src import ad_util +from jax._src import api +from jax._src import api_util +from jax._src import config +from jax._src import core +from jax._src import dispatch +from jax._src import dtypes +from jax._src import linear_util as lu +from jax._src import sharding_impls +from jax._src import source_info_util +from jax._src import traceback_util +from jax._src import util +from jax._src.core import order_wrt_mesh +from jax._src.core import pvary, Tracer, typeof, shard_aval, unshard_aval +from jax._src.mesh import (AbstractMesh, Mesh, BaseMesh, AxisType, + use_abstract_mesh, get_abstract_mesh, + get_concrete_mesh) +from jax._src.pjit import reshard +from jax._src.lax import lax, parallel as lax_parallel +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import hlo, sdy +from jax._src.sharding_impls import NamedSharding, PartitionSpec +from jax._src.util import (HashableFunction, HashablePartial, unzip2, + as_hashable_function, partition_list, merge_lists, + split_list, subs_list2, fun_name as util_fun_name) +from jax._src.state import discharge +from jax._src.state.types import AbstractRef +from jax._src.interpreters import batching +from jax._src.interpreters import mlir +from jax._src.interpreters import partial_eval as pe +from jax._src.interpreters import pxla +from jax._src.interpreters import ad +from jax._src.tree_util import ( + broadcast_prefix, keystr, prefix_errors, generate_key_paths, tree_flatten, + tree_leaves, tree_map, tree_structure, tree_unflatten, KeyPath, PyTreeDef) + +P = PartitionSpec + +map, unsafe_map = util.safe_map, map +zip, unsafe_zip = util.safe_zip, zip +traceback_util.register_exclusion(__file__) + +# API + +Specs = Any # PyTree[PartitionSpec] +AxisName = Hashable + +class InferFromArgs: + def __repr__(self): + return "jax.sharding.Infer" + + def __reduce__(self): + return (_get_default_infer, ()) + +Infer = InferFromArgs() + +def _get_default_infer(): + return Infer + +# See https://github.com/jax-ml/jax/pull/30753 to understand why `in_specs` +# defaults to `Infer`. +def shard_map(f=None, /, *, out_specs: Specs, + in_specs: Specs | None | InferFromArgs = Infer, + mesh: Mesh | AbstractMesh | None = None, + axis_names: Set[AxisName] = frozenset(), + check_vma: bool = True): + """Map a function over shards of data using a mesh of devices. + + See the docs at https://docs.jax.dev/en/latest/notebooks/shard_map.html. + + Args: + f: callable to be mapped. Each application of ``f``, or "instance" of ``f``, + takes as input a shard of the mapped-over arguments and produces a shard + of the output. + mesh: (optional, default None) a ``jax.sharding.Mesh`` representing the + array of devices over which to shard the data and on which to execute + instances of ``f``. The names of the ``Mesh`` can be used in collective + communication operations in ``f``. If mesh is None, it will be inferred + from the context which can be set via `jax.set_mesh` context + manager. + in_specs: (optional, default `Infer`) a pytree with + ``jax.sharding.PartitionSpec`` instances as leaves, with a tree structure + that is a tree prefix of the args tuple to be mapped over. Similar to + ``jax.sharding.NamedSharding``, each ``PartitionSpec`` represents how the + corresponding argument (or subtree of arguments) should be sharded along + the named axes of ``mesh``. In each ``PartitionSpec``, mentioning a + ``mesh`` axis name at a position expresses sharding the corresponding + argument array axis along that positional axis; not mentioning an axis + name expresses replication. + If ``Infer``, all mesh axes must be of type + `Explicit`, in which case the in_specs are inferred from the argument types. + If ``None``, inputs will be treated as static. + out_specs: a pytree with ``PartitionSpec`` instances as leaves, with a tree + structure that is a tree prefix of the output of ``f``. Each + ``PartitionSpec`` represents how the corresponding output shards should be + concatenated. In each ``PartitionSpec``, mentioning a ``mesh`` axis name + at a position expresses concatenation of that mesh axis's shards along the + corresponding positional axis; not mentioning a ``mesh`` axis name + expresses a promise that the output values are equal along that mesh axis, + and that rather than concatenating only a single value should be produced. + axis_names: (optional, default set()) set of axis names from ``mesh`` over + which the function ``f`` is manual. If empty, ``f``, is manual + over all mesh axes. + check_vma: (optional) boolean (default True) representing whether to enable + additional validity checks and automatic differentiation optimizations. + The validity checks concern whether any mesh axis names not mentioned in + ``out_specs`` are consistent with how the outputs of ``f`` are replicated. + + Returns: + A callable representing a mapped version of ``f``, which accepts positional + arguments corresponding to those of ``f`` and produces output corresponding + to that of ``f``. + """ + kwargs = dict(mesh=mesh, in_specs=in_specs, out_specs=out_specs, + axis_names=axis_names, check_vma=check_vma) + if f is None: + return lambda g: _shard_map(g, **kwargs) + return _shard_map(f, **kwargs) + + +def smap(f=None, /, *, in_axes=Infer, out_axes, axis_name: AxisName): + """Single axis shard_map that maps a function `f` one axis at a time. + + Args: + f: Callable to be mapped. Each application of ``f``, or "instance" of ``f``, + takes as input a shard of the mapped-over arguments and produces a shard + of the output. + in_axes: (optional) An integer, None, or sequence of values specifying which + input array axes to map over. If not specified, `smap` will try to infer + the axes from the arguments only under `Explicit` mode. + An integer or ``None`` indicates which array axis to map over for all + arguments (with ``None`` indicating not to map any axis), and a tuple + indicates which axis to map for each corresponding positional argument. + Axis integers must be in the range ``[-ndim, ndim)`` for each array, + where ``ndim`` is the number of dimensions (axes) of the corresponding + input array. + out_axes: An integer, None, or (nested) standard Python container + (tuple/list/dict) thereof indicating where the mapped axis should appear + in the output. + axis_name: ``mesh`` axis name over which the function ``f`` is manual. + + Returns: + A callable representing a mapped version of ``f``, which accepts positional + arguments corresponding to those of ``f`` and produces output corresponding + to that of ``f``. + """ + kwargs = dict(in_axes=in_axes, out_axes=out_axes, axis_name=axis_name) + if f is None: + return lambda g: _smap(g, **kwargs) + return _smap(f, **kwargs) + +def _smap(f, *, in_axes, out_axes, axis_name: AxisName): + if isinstance(axis_name, (list, tuple)): + raise TypeError( + f"smap axis_name should be a `str` or a `Hashable`, but got {axis_name}") + if (in_axes is not None and in_axes is not Infer and + not isinstance(in_axes, (int, tuple))): + raise TypeError( + "smap in_axes must be an int, None, jax.sharding.Infer, or a tuple of" + " entries corresponding to the positional arguments passed to the" + f" function, but got {in_axes}.") + if (in_axes is not Infer and + not all(isinstance(l, int) for l in tree_leaves(in_axes))): + raise TypeError( + "smap in_axes must be an int, None, jax.sharding.Infer, or (nested)" + f" container with those types as leaves, but got {in_axes}.") + if not all(isinstance(l, int) for l in tree_leaves(out_axes)): + raise TypeError("smap out_axes must be an int, None, or (nested) container " + f"with those types as leaves, but got {out_axes}.") + + in_specs = (Infer if in_axes is Infer else + tree_map(partial(_axes_to_pspec, axis_name), in_axes, + is_leaf=lambda x: x is None)) + out_specs = tree_map(partial(_axes_to_pspec, axis_name), out_axes, + is_leaf=lambda x: x is None) + return _shard_map(f, mesh=None, in_specs=in_specs, out_specs=out_specs, + axis_names={axis_name}, check_vma=True, _smap=True) + + +@partial(traceback_util.api_boundary, repro_api_name="jax.shard_map") +def _shard_map(f: Callable, *, mesh: Mesh | AbstractMesh | None, + in_specs: Specs, out_specs: Specs | Callable[[], Specs], + axis_names: Set[AxisName], check_vma: bool, + _smap: bool = False) -> Callable: + if not callable(f): + raise TypeError("shard_map requires a callable for its first argument, " + f"but got {f} of type {type(f)}.") + + @util.wraps(f) + @traceback_util.api_boundary + def wrapped(*args): + nonlocal mesh, axis_names + mesh, axis_names = _shmap_checks( + mesh, axis_names, in_specs, out_specs, _smap) + fun = lu.wrap_init( + f, debug_info=api_util.debug_info("shard_map", f, args, {})) + args_flat, in_tree = tree_flatten(args) + fun, out_specs_thunk = _broadcast_out_specs(fun, out_specs, axis_names) + fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree) + + try: + in_specs_flat = broadcast_prefix( + in_specs, args, is_leaf=lambda x: x is None) + except ValueError: + e, *_ = prefix_errors(in_specs, args) + raise e('shard_map in_specs') from None + + if (in_specs is Infer and + all(mesh._name_to_type[a] == AxisType.Explicit for a in axis_names)): + arg_s = [typeof(a).sharding for a in args_flat] + assert all(i is Infer for i in in_specs_flat), in_specs_flat + in_specs_flat = [_manual_spec(axis_names, s.spec, mesh) for s in arg_s] + + dyn_argnums, in_specs_flat = unzip2((i, s) for i, s in enumerate(in_specs_flat) + if s is not None) + if (fun.debug_info.arg_names is not None and + len(dyn_argnums) != len(fun.debug_info.arg_names)): + fun = fun.with_unknown_names() + fun, args_flat = api_util.argnums_partial(fun, dyn_argnums, args_flat, False) + _check_specs_vs_args(f, mesh, in_tree, in_specs, dyn_argnums, in_specs_flat, + args_flat) + + if check_vma: + fun = _implicit_pvary_on_output(fun, out_specs_thunk) + fun = _implicit_unreduced_on_output(fun, out_specs_thunk) + + # TODO(yashkatariya): Add support for partial manual + mesh_axis_names_wo_vmap = ( + frozenset(mesh.axis_names) - core.get_axis_env().explicit_mesh_axis_names) + if (mesh_axis_names_wo_vmap == axis_names and + all(mesh._name_to_type[a] == AxisType.Explicit for a in axis_names)): + args_flat = [a if typeof(a).sharding.spec == s + else reshard(a, NamedSharding(mesh, s)) + for a, s in zip(args_flat, in_specs_flat)] + + try: + out_flat = shard_map_p.bind( + fun, *args_flat, mesh=mesh, in_specs=in_specs_flat, + out_specs_thunk=out_specs_thunk, check_vma=check_vma, + manual_axes=axis_names) + except _SpecError as e: + fails, = e.args + if not callable(out_specs): + msg = _spec_rank_error(SpecErrorType.out, f, out_tree(), out_specs, fails) + if any(fail is not no_fail and not fail.shape for fail in fails): + msg += (" In particular, for rank 0 outputs which are not constant " + "over the mesh, add at least one (singleton) axis to them so " + "that they can be concatenated using out_specs.") + raise ValueError(msg) from None + except _RepError as e: + fails, = e.args + if not callable(out_specs): + msg = _inout_vma_error(f, mesh, out_tree(), out_specs, fails) + raise ValueError(msg) from None + return tree_unflatten(out_tree(), out_flat) + return wrapped + + +@lu.transformation_with_aux2 +def _broadcast_out_specs(_fun, _store, out_specs, axis_names, *args, **kwargs): + ans = _fun(*args, **kwargs) + + if callable(out_specs): + out_specs_ = out_specs() + _check_specs(SpecErrorType.out, out_specs_, axis_names) + else: + out_specs_ = out_specs + + try: + out_specs_flat = broadcast_prefix(out_specs_, ans) + except ValueError: + e, *_ = prefix_errors(out_specs_, ans) + raise e('shard_map out_specs') from None + + _store.store(tuple(out_specs_flat)) + return ans + + +def _axes_to_pspec(axis_name, axis): + if axis is None: + return P() + return P(*[None] * axis + [axis_name]) + + +def _shmap_checks(mesh, axis_names, in_specs, out_specs, _smap): + if mesh is None: + mesh = get_abstract_mesh() + if mesh.empty: + raise ValueError( + "The context mesh cannot be empty. Use" + " `jax.set_mesh(mesh)` to enter into a mesh context") + else: + ctx_mesh = get_abstract_mesh() + if not ctx_mesh.empty and mesh.abstract_mesh != ctx_mesh: + raise ValueError( + f"The context mesh {ctx_mesh} should match the mesh passed to" + f" shard_map {mesh}") + + if not isinstance(mesh, (Mesh, AbstractMesh)): + raise TypeError("shard_map requires a `jax.sharding.Mesh` or a " + "`jax.sharding.AbstractMesh` instance for its " + f"second argument, but got {mesh} of type {type(mesh)}.") + + mesh_axis_names_wo_vmap = ( + frozenset(mesh.axis_names) - core.get_axis_env().explicit_mesh_axis_names + ) + + if not isinstance(axis_names, (frozenset, set)): + raise TypeError( + "`axis_names` argument of shard_map should be of type `frozenset` or" + f" `set`. Got type: {type(axis_names)}") + if isinstance(axis_names, set): + axis_names = frozenset(axis_names) + if not axis_names: + axis_names = mesh_axis_names_wo_vmap + if not axis_names.issubset(mesh_axis_names_wo_vmap): + raise ValueError( + f"jax.shard_map requires axis_names={axis_names} to be a subset of " + f"mesh.axis_names={mesh_axis_names_wo_vmap}") + + if (in_specs is Infer and + not all(mesh._name_to_type[a] == AxisType.Explicit for a in axis_names)): + axis_types = ', '.join(str(mesh._name_to_type[a]) for a in axis_names) + if _smap: + msg = (f"in_axes was not specified when axis_name={axis_names} was of" + f" type {axis_types}") + else: + msg = ("shard_map in_specs argument must be a pytree of" + " `jax.sharding.PartitionSpec` instances, but it was `None` when" + f" {axis_names=} are of type {axis_types}") + raise TypeError(msg) + + if in_specs is not Infer and in_specs is not None: + _check_specs(SpecErrorType.input, in_specs, axis_names) + _check_unreduced(SpecErrorType.input, mesh, axis_names, in_specs) + if not callable(out_specs): + _check_specs(SpecErrorType.out, out_specs, axis_names) + _check_unreduced(SpecErrorType.out, mesh, axis_names, out_specs) + return mesh, axis_names + + +def _manual_spec(manual_axes, spec: P, mesh) -> P: + out = [] # type: ignore + for s in spec: + if s is None: + out.append(s) + elif isinstance(s, tuple): + temp = [p if p in manual_axes else None for p in s] + while temp and temp[-1] is None: + temp.pop() + if None in temp: + raise ValueError(f"Invalid spec: {spec}") + out.append(None if len(temp) == 0 else tuple(temp)) + else: + out.append(s if s in manual_axes else None) + _check_unreduced(SpecErrorType.input, mesh, manual_axes, spec) + return P(*out, unreduced=spec.unreduced, reduced=spec.reduced) + + +# Error checking and messages + +SpecErrorType = enum.Enum('SpecErrorType', ['input', 'out']) + +def _check_unreduced(error_type, mesh, manual_axes, specs): + prefix = 'in' if error_type == SpecErrorType.input else 'out' + full_manual = frozenset(mesh.axis_names) == manual_axes + specs_flat, _ = tree_flatten(specs) + for s in specs_flat: + if not s.unreduced and not s.reduced: + continue + if not full_manual: + raise NotImplementedError( + f"unreduced/reduced can only be passed to {prefix}_specs when" + " shard_map is in full manual mode. Got mesh axis names" + f" {mesh.axis_names}, manual_axes: {manual_axes}, specs: {s}. Please" + " file a bug at https://github.com/jax-ml/jax/issues.") + if not all(mesh._name_to_type[u] == AxisType.Explicit for u in s.unreduced): + raise ValueError( + f"unreduced in {prefix}_specs {s} can only be used when the mesh" + " passed to shard_map contains axis names all of type `Explicit`." + f" Got mesh {mesh}") + if not all(mesh._name_to_type[u] == AxisType.Explicit for u in s.reduced): + raise ValueError( + f"reduced in {prefix}_specs {s} can only be used when the mesh" + " passed to shard_map contains axis names all of type `Explicit`." + f" Got mesh {mesh}") + + +def _check_specs(error_type: SpecErrorType, specs: Any, manual_axes) -> None: + if error_type == SpecErrorType.input and specs is None: + raise TypeError( + "shard_map in_specs argument must be a pytree of " + "`jax.sharding.PartitionSpec` instances, but it was None.\n" + "Instead of `in_specs=None`, did you mean `in_specs=P()`, " + "where `P = jax.sharding.PartitionSpec`?") + + def check_spec(p): + if not isinstance(p, PartitionSpec): + return False + for names in p: + names = (names,) if not isinstance(names, tuple) else names + for name in names: + if name is not None and name not in manual_axes: + return False + return True + + if all(check_spec(p) for p in tree_leaves(specs)): + return + prefix = 'in' if error_type == SpecErrorType.input else 'out' + msgs = [f" {prefix}_specs{keystr(key)} is {x} of type {type(x).__name__}, " + for key, x in generate_key_paths(specs) if not isinstance(x, P)] + if not msgs: + for key, p in generate_key_paths(specs): + for names in p: + names = (names,) if not isinstance(names, tuple) else names + for name in names: + if name is not None and name not in manual_axes: + msgs.append(f" {prefix}_specs{keystr(key)} refers to {repr(name)}") + raise ValueError( + f"shard_map {prefix}_specs argument must refer to an axis " + f"marked as manual ({manual_axes}), but:\n\n" + + '\n\n'.join(msgs) + '\n\n' + f"Check the {prefix}_specs values passed to shard_map.") + raise TypeError( + f"shard_map {prefix}_specs argument must be a pytree of " + f"`jax.sharding.PartitionSpec` instances, but:\n\n" + + '\n\n'.join(msgs) + '\n\n' + f"Check the {prefix}_specs values passed to shard_map.") + +class NoFail: + def __repr__(self): + return "NoFail()" + +no_fail = NoFail() + +def _check_specs_vs_args( + f: Callable, mesh: Mesh | AbstractMesh, in_tree: PyTreeDef, in_specs: Specs, + dyn_argnums: Sequence[int], in_specs_flat: Sequence[P], + xs: Sequence) -> None: + in_avals = map(core.shaped_abstractify, xs) + fail = [a if not len(p) <= a.ndim else no_fail + for p, a in zip(in_specs_flat, in_avals)] + if any(f is not no_fail for f in fail): + fail = _expand_fail(in_tree, dyn_argnums, fail) + msg = _spec_rank_error(SpecErrorType.input, f, in_tree, in_specs, fail) + raise ValueError(msg) + in_names_flat = tuple(map(_spec_to_names, in_specs_flat)) + fail = [a if any(a.shape[d] % prod(mesh.shape[n] for n in ns) + for d, ns in names.items()) else no_fail + for a, names in zip(in_avals, in_names_flat)] + if any(f is not no_fail for f in fail): + fail = _expand_fail(in_tree, dyn_argnums, fail) + msg = _spec_divisibility_error(f, mesh, in_tree, in_specs, fail) + raise ValueError(msg) + +def _expand_fail(in_tree: PyTreeDef, dyn_argnums: Sequence[int], + fail: Sequence[core.ShapedArray | NoFail] + ) -> list[core.ShapedArray | NoFail]: + fail_: list[core.ShapedArray | NoFail] = [no_fail] * in_tree.num_leaves + for i, f in zip(dyn_argnums, fail): + fail_[i] = f + return fail_ + +def _spec_rank_error( + error_type: SpecErrorType, f: Callable, tree: PyTreeDef, specs: Specs, + fails: list[core.ShapedArray | NoFail]) -> str: + fun_name = util_fun_name(f) + if error_type == SpecErrorType.input: + prefix, base = 'in', 'the passed args' + ba = _try_infer_args(f, tree) + else: + prefix, base = 'out', f'{fun_name}(*args)' + msgs = [] + for (spec_key, spec), (fail_key, aval) in _iter_paths(tree, specs, fails): + extra = "" + if error_type == SpecErrorType.input and ba is not None: + arg_key, *_ = fail_key + param_names, params = unzip2( + (name, param) for name, param in ba.signature.parameters.items() + if param.kind not in (inspect.Parameter.KEYWORD_ONLY, + inspect.Parameter.VAR_KEYWORD)) + if (arg_key.idx >= len(params) or + params[arg_key.idx].kind == inspect.Parameter.VAR_POSITIONAL): + extra = (f", where args{arg_key} is the index " + f"{arg_key.idx - len(params) + 1} component " + f"of {fun_name}'s varargs parameter '{param_names[-1]}',") + else: + param_name = params[arg_key.idx] + extra = (f", where args{arg_key} is bound to {fun_name}'s " + f"parameter '{param_name}',") + msgs.append( + f"* {prefix}_specs{keystr(spec_key)} is {spec} which has length " + f"{len(spec)}, but " + f"{base}{keystr(fail_key)}{extra} has shape {aval.str_short()}, " + f"which has rank {aval.ndim} (and {aval.ndim} < {len(spec)})") + assert msgs + if len(msgs) == 1: msgs = [msgs[0][2:]] # remove the bullet point + msg = (f"shard_map applied to the function '{fun_name}' was given an " + f"{prefix}_specs entry which is too long to be compatible with the " + f"corresponding {prefix}put value from the function:\n\n" + + '\n\n'.join(msgs) + '\n\n' + + f"Entries in {prefix}_specs must be of length no greater than the " + f"number of axes in the corresponding {prefix}put value.\n\n" + f"Either revise the spec to be shorter, or modify '{fun_name}' so " + f"that its {prefix}puts have sufficient rank.") + if any(not aval.ndim for _, (_, aval) in _iter_paths(tree, specs, fails)): + msg += (f"\n\nFor scalar values (rank 0), consider using an {prefix}_specs " + "entry of `P()`, where `P = jax.sharding.PartitionSpec`.") + return msg + +def _spec_divisibility_error( + f: Callable, mesh: Mesh | AbstractMesh, tree: PyTreeDef, specs: Specs, + fails: list[core.ShapedArray | NoFail]) -> str: + ba = _try_infer_args(f, tree) + fun_name = getattr(f, '__name__', str(f)) + msgs = [] + for (spec_key, spec), (fail_key, aval) in _iter_paths(tree, specs, fails): + extra = "" + if ba is not None: + arg_key, *_ = fail_key + param_names, params = unzip2( + (name, param) for name, param in ba.signature.parameters.items() + if param.kind not in (inspect.Parameter.KEYWORD_ONLY, + inspect.Parameter.VAR_KEYWORD)) + if (arg_key.idx >= len(params) or + params[arg_key.idx].kind == inspect.Parameter.VAR_POSITIONAL): + extra = (f", where args{arg_key} is the index " + f"{arg_key.idx - len(params) + 1} component " + f"of {fun_name}'s varargs parameter '{param_names[-1]}',") + else: + param_name = params[arg_key.idx] + extra = (f", where args{arg_key} is bound to {fun_name}'s " + f"parameter '{param_name}',") + names = _spec_to_names(spec) + for d, ns in names.items(): + if aval.shape[d] % prod(mesh.shape[n] for n in ns): + axis = f"axes {ns}" if len(ns) > 1 else f"axis '{ns[0]}'" + total = 'total ' if len(ns) > 1 else '' + sz = prod(mesh.shape[n] for n in ns) + msgs.append( + f"* the passed args{keystr(fail_key)} of shape {aval.str_short()}{extra} " + f"corresponds to in_specs{keystr(spec_key)} of value {spec}, " + f"which maps array axis {d} (of size {aval.shape[d]}) to mesh " + f"{axis} (of {total}size {sz}), but {sz} does not evenly divide " + f"{aval.shape[d]}") + assert msgs + if len(msgs) == 1: msgs = [msgs[0][2:]] # remove the bullet point + msg = (f"shard_map applied to the function '{fun_name}' was given argument " + f"arrays with axis sizes that are not evenly divisible by the " + f"corresponding mesh axis sizes:\n\n" + f"The mesh given has shape {tuple(mesh.shape.values())} with " + f"corresponding axis names {mesh.axis_names}.\n\n" + + '\n\n'.join(msgs) + '\n\n' + + f"Array arguments' axis sizes must be evenly divisible by the mesh " + f"axis or axes indicated by the corresponding elements of the " + f"argument's in_specs entry. Consider checking that in_specs are " + f"correct, and if so consider changing the mesh axis sizes or else " + f"padding the input and adapting '{fun_name}' appropriately.") + return msg + +def _inout_vma_error(f: Callable, mesh: Mesh | AbstractMesh, tree: PyTreeDef, + specs: Specs, fails: list[set | NoFail]) -> str: + fun_name = getattr(f, '__name__', str(f)) + msgs = [] + for (spec_key, spec), (fail_key, vma) in _iter_paths(tree, specs, fails): + unmentioned = _unmentioned(mesh, spec) + if len(unmentioned) > 1: + need_vma = ','.join(map(str, order_wrt_mesh(mesh, _spec_to_vma(spec)))) + got_vma = ','.join(map(str, order_wrt_mesh(mesh, vma))) + diff = ','.join(map(str, order_wrt_mesh( + mesh, [n for n in unmentioned if n in vma]))) + msgs.append( + f"* out_specs{keystr(spec_key)} is {spec} which implies that the " + f"corresponding output value is only varying across mesh axes " + f"{{{need_vma}}} and not {{{diff}}}, but it was inferred to be " + f"possibly varying over {{{got_vma}}}") + else: + need_rep_, = unmentioned + msgs.append( + f"* out_specs{keystr(spec_key)} is {spec} which implies that the " + f"corresponding output value is replicated across mesh axis " + f"'{need_rep_}', but could not infer replication over any axes") + assert msgs + if len(msgs) == 1: msgs = [msgs[0][2:]] # remove the bullet point + msg = (f"shard_map applied to the function '{fun_name}' was given " + f"out_specs which require replication which can't be statically " + f"inferred given the mesh:\n\n" + f"The mesh given has shape {tuple(mesh.shape.values())} with " + f"corresponding axis names {mesh.axis_names}.\n\n" + + '\n\n'.join(msgs) + '\n\n' + + "Check if these output values are meant to be replicated over those " + "mesh axes. If not, consider revising the corresponding out_specs " + "entries. If so, consider disabling the check by passing the " + "check_vma=False argument to `jax.shard_map`.") + return msg + +def _unmentioned(mesh: Mesh | AbstractMesh, spec) -> list[AxisName]: + vma_set = _spec_to_vma(spec) + return [n for n in mesh.axis_names if n not in vma_set] + + +def _try_infer_args(f, tree): + dummy_args = tree_unflatten(tree, [False] * tree.num_leaves) + try: + return inspect.signature(f).bind(*dummy_args) + except (TypeError, ValueError): + return None + +T = TypeVar('T') +def _iter_paths(tree: PyTreeDef, specs: Specs, fails: list[T | NoFail] + ) -> list[tuple[tuple[KeyPath, P], tuple[KeyPath, T]]]: + failures = tree_unflatten(tree, fails) + failures_aug = generate_key_paths(failures) + specs_ = tree_unflatten(tree_structure(specs), map(Tup, generate_key_paths(specs))) + specs_aug = broadcast_prefix(specs_, failures, is_leaf=lambda x: x is None) + return [(s, (fail_key, fail_data)) for s, (fail_key, fail_data) + in zip(specs_aug, failures_aug) + if s is not None and fail_data is not no_fail] + +class Tup: + def __init__(self, vals): self.vals = vals + def __iter__(self): return iter(self.vals) + +# Primitive + +@lu.transformation2 +def _implicit_pvary_on_output(f, out_specs_thunk, *args, **kwargs): + out_flat = f(*args, **kwargs) + return [pvary(o, tuple(_spec_to_vma(sp) - typeof(o).vma)) + for o, sp in zip(out_flat, out_specs_thunk())] + + +@lu.transformation2 +def _implicit_unreduced_on_output(f, out_specs_thunk, *args, **kwargs): + out_flat = f(*args, **kwargs) + new_out_flat = [] + for o, sp in zip(out_flat, out_specs_thunk()): + o_aval = typeof(o) + if unreduced := (sp.unreduced - o_aval.sharding.spec.unreduced): + axes = order_wrt_mesh(o_aval.sharding.mesh, unreduced) + new_out_flat.append(lax_parallel.vary_unreduced_cast(o, axes)) + else: + new_out_flat.append(o) + return new_out_flat + + +JaxType = Any +MaybeTracer = Union[JaxType, Tracer] + +class ShardMapPrimitive(core.Primitive): + multiple_results = True + + def bind(self, *args, **params): + return self._true_bind(*args, **params) + + def bind_with_trace(self, trace, fun_and_args, params): + fun: lu.WrappedFun + fun, *args = fun_and_args + return trace.process_shard_map(shard_map_p, fun, args, **params) + + def get_bind_params(self, params): + new_params = dict(params) + jaxpr = new_params.pop('jaxpr') + assert isinstance(jaxpr, core.Jaxpr) + subfun = lu.hashable_partial( + lu.wrap_init(core.eval_jaxpr, debug_info=jaxpr.debug_info), jaxpr, ()) + axes = new_params.pop('out_specs') + new_params['out_specs_thunk'] = HashableFunction(lambda: axes, closure=axes) + return [subfun], new_params + +shard_map_p = ShardMapPrimitive('shard_map') + +# Staging + +@util.cache(max_size=256, trace_context_in_key=False) +def _as_manual_mesh(mesh, manual_axes: frozenset) -> AbstractMesh: + return mesh.abstract_mesh.update_axis_types( + {n: AxisType.Manual for n in manual_axes}) + +def _extend_axis_env(mesh, manual_axes): + return core.extend_axis_env_nd([(k, v) for k, v in mesh.shape.items() + if k in manual_axes]) + +def _shard_map_staging( + trace: pe.DynamicJaxprTrace, prim: core.Primitive, f: lu.WrappedFun, + in_tracers: Sequence[Any], *, mesh: Mesh, + in_specs, out_specs_thunk, check_vma: bool, manual_axes: frozenset, + ) -> Sequence[pe.DynamicJaxprTracer]: + source_info = source_info_util.current() + to_jaxpr_tracer = partial(trace.to_jaxpr_tracer, source_info=source_info) + in_tracers = map(to_jaxpr_tracer, in_tracers) + inner_mesh = _as_manual_mesh(mesh, manual_axes) + in_avals = [t.aval for t in in_tracers] + in_avals_ = map(partial(shard_aval, mesh, manual_axes, check_vma), in_specs, + in_avals) + with (_extend_axis_env(mesh, manual_axes), use_abstract_mesh(inner_mesh), + config._check_vma(check_vma)): + jaxpr, out_avals_, consts = pe.trace_to_jaxpr_dynamic( + f, in_avals_, lower=trace.requires_low) + + _check_names(out_specs_thunk(), out_avals_) + if check_vma: + out_vma = [v.aval.vma for v in jaxpr.outvars] + _check_vmas(mesh, out_specs_thunk(), out_vma) + out_avals = map(_check_shapedarray, out_avals_) + out_avals = [_check_shapedarray(unshard_aval(mesh, check_vma, spec, aval)) + for spec, aval in zip(out_specs_thunk(), out_avals)] + in_specs_staged = (P(),) * len(consts) + tuple(in_specs) # type: ignore + with (_extend_axis_env(mesh, manual_axes), use_abstract_mesh(inner_mesh), + config._check_vma(check_vma)): + jaxpr = pe.convert_constvars_jaxpr(jaxpr) + params = dict(mesh=mesh, in_specs=in_specs_staged, + out_specs=tuple(out_specs_thunk()), jaxpr=jaxpr, + check_vma=check_vma, manual_axes=manual_axes) + effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names) + const_tracers = map(to_jaxpr_tracer, consts) + trace.frame.is_high |= jaxpr.is_high + return trace.emit_eqn([*const_tracers, *in_tracers], out_avals, prim, params, + effs, source_info) +pe.DynamicJaxprTrace.process_shard_map = _shard_map_staging + +# TODO add underscore version, for direct-linearize to consume + +def _spec_to_names(spec: PartitionSpec): + return {i: names if isinstance(names, tuple) else (names,) + for i, names in enumerate(spec) if names is not None} + +def _check_shapedarray(aval: core.AbstractValue) -> core.ShapedArray: + assert isinstance(aval, core.ShapedArray) + return aval + +def _shard_shaped_array(mesh: Mesh, manual_axes: frozenset, check_vma, + spec, aval: core.AbstractValue) -> core.AbstractValue: + assert isinstance(aval, core.ShapedArray) + if spec.unreduced != aval.sharding.spec.unreduced: + raise ValueError( + f"in_specs containing unreduced {spec} passed to shard_map should be" + " equal to the unreduced present on the in_aval" + f" {aval.str_short(True)}") + if spec.reduced != aval.sharding.spec.reduced: + raise ValueError( + f"in_specs containing reduced {spec} passed to shard_map should be" + f" equal to the reduced present on the in_aval {aval.str_short(True)}") + names = _spec_to_names(spec) + new_shape = tuple(sz // prod(mesh.shape[n] for n in names.get(i, ())) + for i, sz in enumerate(aval.shape)) + manual_mesh = _as_manual_mesh(mesh, manual_axes) + new_sharding = aval.sharding.update(mesh=manual_mesh) + vma = _spec_to_vma(spec) if check_vma else frozenset() + vma = vma | aval.vma + return aval.update(shape=new_shape, sharding=new_sharding, vma=vma) +core.shard_aval_handlers[core.ShapedArray] = _shard_shaped_array + +def _unshard_shaped_array(mesh: Mesh, check_vma, spec, aval: core.AbstractValue + ) -> core.AbstractValue: + assert isinstance(aval, core.ShapedArray) + if spec.unreduced != aval.sharding.spec.unreduced: + raise ValueError( + "out_specs passed to shard_map should be equal to the unreduced" + f" present on the out_aval. Got out_specs={spec} and" + f" out_aval={aval.str_short(True)}") + if spec.reduced != aval.sharding.spec.reduced: + raise ValueError( + "out_specs passed to shard_map should be equal to the reduced present" + f" on the out_aval. Got out_specs={spec} and" + f" out_aval={aval.str_short(True)}") + names = _spec_to_names(spec) + new_shape = tuple(sz * prod(mesh.shape[n] for n in names.get(i, ())) + for i, sz in enumerate(aval.shape)) + names_spec = spec._normalized_spec_for_aval(aval.ndim) + if aval.ndim == 0: + out_spec = P() + else: + out_spec = [] # type: ignore + for name_s, aval_s in zip(names_spec, aval.sharding.spec): + if name_s and not aval_s: + out_spec.append(name_s) + elif aval_s and not name_s: + out_spec.append(aval_s) + elif not name_s and not aval_s: + out_spec.append(None) + else: + assert name_s and aval_s + name_s = name_s if isinstance(name_s, tuple) else (name_s,) + aval_s = aval_s if isinstance(aval_s, tuple) else (aval_s,) + out_spec.append(name_s + aval_s) + out_spec = PartitionSpec(*out_spec, unreduced=spec.unreduced, + reduced=spec.reduced) + new_mesh = (mesh.abstract_mesh if get_abstract_mesh().empty else + get_abstract_mesh()) + new_sharding = NamedSharding(new_mesh, out_spec) + manual_axes = set(new_mesh.manual_axes) + vma = (frozenset(v for v in aval.vma if v in manual_axes) + if check_vma else frozenset()) + return aval.update(shape=new_shape, sharding=new_sharding, vma=vma) +core.unshard_aval_handlers[core.ShapedArray] = _unshard_shaped_array + +# Type-checking + +def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_specs, out_specs, + check_vma, manual_axes): + # TODO(mattjj,parkers): check auto + for v, x, in_spec in zip(jaxpr.invars, in_atoms, in_specs): + sharded_aval = shard_aval(mesh, manual_axes, check_vma, in_spec, x.aval) + if not core.typecompat(v.aval, sharded_aval): + raise core.JaxprTypeError("shard_map argument avals not compatible with " + "jaxpr binder avals and in_specs") + with _extend_axis_env(mesh, manual_axes), config._check_vma(check_vma): + core.check_jaxpr(jaxpr) + if check_vma: + out_vma = [v.aval.vma for v in jaxpr.outvars] + for vma, out_spec in zip(out_vma, out_specs): + if not _valid_repeats(mesh, vma, out_spec): + raise core.JaxprTypeError( + "shard_map can't prove output is sufficiently replicated") + out_avals_sharded = [x.aval for x in jaxpr.outvars] + out_avals = map(partial(unshard_aval, mesh, check_vma), out_specs, + out_avals_sharded) + effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names) + return out_avals, effs +core.custom_typechecks[shard_map_p] = _shard_map_typecheck + + +def _valid_repeats(mesh: Mesh, vma: Set[AxisName], spec) -> bool: + um = set(_unmentioned(mesh, spec)) - set(mesh.manual_axes) + if any(u in vma for u in um): + return False + return True + +# Lowering + +def _shardy_shard_map_sharding( + ctx: mlir.LoweringRuleContext, mesh, manual_axes, spec, aval_in +) -> sharding_impls.SdyArray: + ns = _make_scoped_manual_sharding(ctx, mesh, spec) + if dtypes.issubdtype(aval_in.dtype, dtypes.extended): + ns = sharding_impls.physical_sharding(aval_in, ns) + aval_in = core.physical_aval(aval_in) + sdy_sharding = ns._to_sdy_sharding(aval_in.ndim) + if len(manual_axes) < len(mesh.axis_names): + for dim_sharding in sdy_sharding.dim_shardings: + dim_sharding.is_open = True + return sdy_sharding + + +def _get_token_sharding( + ctx: mlir.LoweringRuleContext, mesh + ) -> ir.Attribute: + ns = _make_scoped_manual_sharding(ctx, mesh, P()) + return ns._to_sdy_sharding(0) + + +def _get_spmdaxis_ctx_mesh(mesh): + if isinstance(mesh, AbstractMesh): + concrete_mesh = get_concrete_mesh() + return concrete_mesh if not concrete_mesh.empty else mesh + return mesh + + +def _shard_map_lowering_shardy( + ctx: mlir.LoweringRuleContext, in_nodes, + jaxpr: core.Jaxpr, mesh, in_specs, out_specs, manual_axes, check_vma): + axis_ctx = ctx.module_context.axis_context + in_avals_ = [v.aval for v in jaxpr.invars] + if isinstance(axis_ctx, sharding_impls.SPMDAxisContext): + # Nested `ManualComputationOp`s must only refer to the new manual axes, not + # all existing ones. Grab the newly-added manual axes. + shardy_manual_axes = manual_axes - axis_ctx.manual_axes + else: + shardy_manual_axes = manual_axes + new_axis_context = sharding_impls.SPMDAxisContext( + _get_spmdaxis_ctx_mesh(mesh), manual_axes) + sub_ctx = ctx.module_context.replace(axis_context=new_axis_context) + + tokens = [ctx.tokens_in.get(eff) for eff in ctx.tokens_in.effects()] + num_tokens = len(tokens) + manual_axes = order_wrt_mesh(mesh, shardy_manual_axes) + if prod([mesh.shape[a] for a in manual_axes]) == 1: + # No need for a `ManualComputationOp` if all manual axes are size 1. + with _extend_axis_env(mesh, manual_axes), config._check_vma(check_vma): + out_nodes, tokens_out = mlir.jaxpr_subcomp( + sub_ctx, jaxpr, ctx.name_stack, + mlir.TokenSet(zip(ctx.tokens_in.effects(), tokens)), + (), *in_nodes, + dim_var_values=ctx.dim_var_values, + const_lowering=ctx.const_lowering) + ctx.set_tokens_out(tokens_out) + return out_nodes + + in_shardings = list( + map(partial(_shardy_shard_map_sharding, ctx, mesh, manual_axes), + in_specs, ctx.avals_in)) + const_args_and_avals = core.jaxpr_const_args(jaxpr) + const_args, const_avals = util.unzip2(const_args_and_avals) + num_const_args = len(const_args) + const_arg_values = tuple( + mlir.ir_constant(c, const_lowering=ctx.const_lowering, aval=aval) + for c, aval in const_args_and_avals) + # TODO(necula,yashkatariya): how to construct consts shardy shardings from + # consts that can be ndarray or jax.Array? + const_args_shardings = [ + _shardy_shard_map_sharding(ctx, mesh, manual_axes, P(), core.typeof(c)) + for c in const_args] + + num_dim_vars = len(ctx.dim_var_values) + in_shardings = ( + [_get_token_sharding(ctx, mesh)] * (num_tokens + num_dim_vars) + + const_args_shardings + in_shardings) + in_shardings = sharding_impls.SdyArrayList(in_shardings).build() + + out_shardings = list( + map(partial(_shardy_shard_map_sharding, ctx, mesh, manual_axes), + out_specs, ctx.avals_out)) + out_shardings = [ + _get_token_sharding(ctx, mesh)] * num_tokens + out_shardings + out_shardings = sharding_impls.SdyArrayList(out_shardings).build() + + output_types = ([hlo.TokenType.get()] * num_tokens + + list(map(mlir.aval_to_ir_type, ctx.avals_out))) + + args = (*ctx.dim_var_values, *tokens, *const_arg_values, *in_nodes) + manual_computation_op = sdy.ManualComputationOp( + output_types, mlir.flatten_ir_values(args), in_shardings, out_shardings, + sdy.ManualAxesAttr.get( + ir.ArrayAttr.get([ir.StringAttr.get(i) for i in manual_axes]))) + + dim_var_types = [mlir.aval_to_ir_type( + core.ShapedArray((), dtypes.default_int_dtype()))] * num_dim_vars + token_types = [hlo.TokenType.get()] * num_tokens + const_arg_types = map(mlir.aval_to_ir_type, const_avals) + in_types = map(mlir.aval_to_ir_type, in_avals_) + block = ir.Block.create_at_start( + manual_computation_op.body, + (*dim_var_types, *token_types, *const_arg_types, *in_types)) + + with (ir.InsertionPoint(block), _extend_axis_env(mesh, manual_axes), + config._check_vma(check_vma)): + dim_var_values, token_arg_values, const_arg_values, in_args = util.split_list( # type: ignore + block.arguments, [num_dim_vars, num_tokens, num_const_args]) + block_const_lowering = { + (id(c), aval): ca + for c, aval, ca in zip(const_args, const_avals, const_arg_values) + } + out_nodes_, tokens_out = mlir.jaxpr_subcomp( + sub_ctx, jaxpr, ctx.name_stack, + mlir.TokenSet(zip(ctx.tokens_in.effects(), token_arg_values)), + (), *in_args, + dim_var_values=dim_var_values, + const_lowering=block_const_lowering) + sdy.ReturnOp([ir.Value(x) for x in (*[v for _, v in tokens_out.items()], + *out_nodes_)]) + num_tokens = len(tokens_out.effects()) + tokens_out = tokens_out.update_tokens(mlir.TokenSet(zip( + ctx.tokens_in.effects(), manual_computation_op.results[:num_tokens]))) + ctx.set_tokens_out(tokens_out) + + return manual_computation_op.results[num_tokens:] + + +def _shard_map_lowering(ctx: mlir.LoweringRuleContext, *in_nodes, + jaxpr: core.Jaxpr, mesh, in_specs, out_specs, + check_vma, manual_axes): + if config.use_shardy_partitioner.value: + return _shard_map_lowering_shardy( + ctx, in_nodes, jaxpr, mesh, in_specs, out_specs, manual_axes, check_vma) + + in_avals_ = [v.aval for v in jaxpr.invars] + out_avals_ = [x.aval for x in jaxpr.outvars] + in_nodes_ = map(partial(_xla_shard, ctx, mesh, manual_axes), in_specs, + ctx.avals_in, in_avals_, in_nodes) + new_axis_context = sharding_impls.SPMDAxisContext( + _get_spmdaxis_ctx_mesh(mesh), manual_axes) + sub_ctx = ctx.module_context.replace(axis_context=new_axis_context) + with _extend_axis_env(mesh, manual_axes), config._check_vma(check_vma): + out_nodes_, tokens_out = mlir.call_lowering( + "shmap_body", pe.close_jaxpr(jaxpr), None, sub_ctx, in_avals_, + out_avals_, ctx.tokens_in, *in_nodes_, + dim_var_values=ctx.dim_var_values, + const_lowering=ctx.const_lowering, + arg_names=map(_pspec_mhlo_attrs, in_specs, in_avals_), + result_names=map(_pspec_mhlo_attrs, out_specs, out_avals_)) + ctx.set_tokens_out(tokens_out) + return map(partial(_xla_unshard, ctx, mesh, manual_axes), out_specs, + out_avals_, ctx.avals_out, out_nodes_) +mlir.register_lowering(shard_map_p, _shard_map_lowering) + +def _make_scoped_manual_sharding(ctx, mesh, spec): + axis_ctx = ctx.module_context.axis_context + mesh = mesh.abstract_mesh + if isinstance(axis_ctx, sharding_impls.SPMDAxisContext): + mesh = mesh.update_axis_types( + {a: AxisType.Manual for a in axis_ctx.manual_axes}) + return NamedSharding(mesh, spec) + +def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, manual_axes, spec, + aval_in, aval_out, x): + if prod([size for n, size in mesh.shape.items() if n in manual_axes]) == 1: + return x + ns = _make_scoped_manual_sharding(ctx, mesh, spec) + if dtypes.issubdtype(aval_in.dtype, dtypes.extended): + ns = sharding_impls.physical_sharding(aval_in, ns) + aval_in = core.physical_aval(aval_in) + shard_proto = ns._to_xla_hlo_sharding(aval_in.ndim).to_proto() + unspecified = (set(range(aval_in.ndim)) + if len(manual_axes) < len(mesh.axis_names) else set()) + sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, shard_proto, + unspecified_dims=unspecified) + manual_proto = pxla.manual_proto( + aval_in, manual_axes | set(mesh.manual_axes), mesh) + return mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, manual_proto, + unspecified) + +def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, manual_axes, spec, + aval_in, aval_out, x): + if prod([size for n, size in mesh.shape.items() if n in manual_axes]) == 1: + return x + ns = _make_scoped_manual_sharding(ctx, mesh, spec) + if dtypes.issubdtype(aval_out.dtype, dtypes.extended): + ns = sharding_impls.physical_sharding(aval_out, ns) + aval_out = core.physical_aval(aval_out) + unspecified = (set(range(aval_in.ndim)) + if len(manual_axes) < len(mesh.axis_names) else set()) + if dtypes.issubdtype(aval_in.dtype, dtypes.extended): + aval_in = core.physical_aval(aval_in) + manual_proto = pxla.manual_proto( + aval_in, manual_axes | set(mesh.manual_axes), mesh) + sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, manual_proto, + unspecified_dims=unspecified) + shard_proto = ns._to_xla_hlo_sharding(aval_out.ndim).to_proto() + return mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, shard_proto, + unspecified) + +def _pspec_mhlo_attrs(spec, aval: core.AbstractValue) -> str: + if isinstance(aval, core.ShapedArray): + names = _spec_to_names(spec) + return str(map(names.get, range(aval.ndim))) + return '' + +# Eager evaluation + +def get_mesh_from_args(args_flat, mesh): + for a in args_flat: + if hasattr(a, 'sharding') and isinstance(a.sharding, NamedSharding): + if a.sharding.mesh.shape_tuple != mesh.shape_tuple: + aval = core.shaped_abstractify(a) + raise ValueError( + f"Mesh shape of the input {a.sharding.mesh.shape_tuple} does not" + " match the mesh shape passed to shard_map " + f" {mesh.shape_tuple} for shape {aval.str_short()}") + mesh = a.sharding.mesh + if isinstance(mesh, AbstractMesh): + raise ValueError( + "Please pass `jax.Array`s with a `NamedSharding` as input to" + " `shard_map` when passing `AbstractMesh` to the mesh argument.") + assert isinstance(mesh, Mesh) + return mesh + +def _vma_to_spec(mesh, vma): + return P(order_wrt_mesh(mesh, vma)) + +def _spec_to_vma(spec): + return frozenset(p for s in spec if s is not None + for p in (s if isinstance(s, tuple) else (s,))) + +def _shard_map_impl(trace, prim, fun, args, *, mesh, in_specs, out_specs_thunk, + check_vma, manual_axes): + del prim + if isinstance(mesh, AbstractMesh): + concrete_mesh = get_concrete_mesh() + mesh = concrete_mesh if not concrete_mesh.empty else mesh + mesh = get_mesh_from_args(args, mesh) + cur_mesh = get_abstract_mesh() + args = map(partial(_unmatch_spec, mesh, check_vma, cur_mesh, manual_axes), + in_specs, args) + in_vma = map(_spec_to_vma, in_specs) + outs, out_vma = _run_shmap(fun, mesh, manual_axes, args, in_vma, check_vma) + out_avals = [core.mapped_aval(x.shape[0], 0, core.get_aval(x)) for x in outs] + _check_names(out_specs_thunk(), out_avals) # pytype: disable=wrong-arg-types + if check_vma: + _check_vmas(mesh, out_specs_thunk(), out_vma) + src_pspecs = tuple(_vma_to_spec(mesh, r) for r in out_vma) + else: + src_pspecs = tuple(P(order_wrt_mesh(mesh, manual_axes)) + for _ in range(len(out_vma))) + dst_pspecs = out_specs_thunk() + return map(partial(_match_spec, mesh, check_vma, manual_axes), + src_pspecs, dst_pspecs, outs) +core.EvalTrace.process_shard_map = _shard_map_impl + +def _run_shmap(f, mesh, manual_axes, args, vmas, check_vma): + assert not mesh.manual_axes + trace = ShardMapTrace(mesh, manual_axes, check_vma) + in_tracers = map(partial(ShardMapTracer, trace), vmas, args) + inner_mesh = _as_manual_mesh(mesh, manual_axes) + with (core.set_current_trace(trace), _extend_axis_env(mesh, manual_axes), + use_abstract_mesh(inner_mesh), config._check_vma(check_vma)): + ans = f.call_wrapped(*in_tracers) + outs, out_vma = unzip2(map(trace.to_val_vma_pair, ans)) + return outs, out_vma + +def _unmatch_spec2(mesh, prev_manual, spec, x) -> JaxType: + with (core.eval_context(), api.disable_jit(False), + use_abstract_mesh(mesh.abstract_mesh)): + return api.jit(HashablePartial(_unmatch2, mesh, prev_manual, spec))(x) + +def _unmatch2(mesh, prev_manual, spec, x): + src = P(order_wrt_mesh(mesh, prev_manual), *spec) + newly_manual = _spec_to_vma(spec) + dst = P(order_wrt_mesh(mesh, prev_manual | newly_manual)) + return shard_map(lambda x: x, in_specs=src, out_specs=dst)(x) + +def _match_spec2(mesh, prev_manual, spec, x) -> JaxType: + with (core.eval_context(), api.disable_jit(False), + use_abstract_mesh(mesh.abstract_mesh)): + return api.jit(HashablePartial(_match2, mesh, prev_manual, spec))(x) + +def _match2(mesh, prev_manual, spec, x): + newly_manual = _spec_to_vma(spec) + src = P(order_wrt_mesh(mesh, prev_manual | newly_manual)) + dst = P(order_wrt_mesh(mesh, prev_manual), *spec) + return shard_map(lambda x: x, in_specs=src, out_specs=dst)(x) + + +def _unmatch_spec(mesh: Mesh, check_vma, context_mesh, manual_axes, in_spec, + x: JaxType) -> JaxType: + with (core.eval_context(), api.disable_jit(False), + use_abstract_mesh(context_mesh)): + return api.jit(HashablePartial(_unmatch, mesh, check_vma, in_spec, + manual_axes))(x) + +def _unmatch(mesh, check_vma, in_spec, manual_axes, x): + if check_vma: + used_axes = _spec_to_vma(in_spec) + dst = P(order_wrt_mesh(mesh, used_axes)) + else: + dst = P(mesh.axis_names) + check_vma = False + return shard_map(_add_singleton, mesh=mesh, in_specs=(in_spec,), + out_specs=dst, check_vma=check_vma, axis_names=manual_axes)(x) + +def _check_names(specs, avals: Sequence[core.ShapedArray]) -> None: + fail = [a if sp and len(sp) > a.ndim else no_fail + for sp, a in zip(specs, avals)] + if any(f is not no_fail for f in fail): + raise _SpecError(fail) + +class _SpecError(Exception): + pass + +def _check_vmas(mesh, specs, vmas): + fail = [vma if not _valid_repeats(mesh, vma, sp) else no_fail + for sp, vma in zip(specs, vmas)] + if any(f is not no_fail for f in fail): + raise _RepError(fail) + +class _RepError(Exception): + pass + +def _match_spec(mesh: Mesh, check_vma, manual_axes, src_pspec: PartitionSpec, + dst_pspec: PartitionSpec, x: JaxType) -> JaxType: + fn = HashablePartial(_match, mesh, check_vma, manual_axes, src_pspec, + dst_pspec) + with core.eval_context(), api.disable_jit(False): + if set(mesh.axis_names) == manual_axes: + return api.jit(fn, out_shardings=NamedSharding(mesh, dst_pspec))(x) + return api.jit(fn)(x) + +def _match(mesh, check_vma, manual_axes, src_pspec, dst_pspec, x): + return shard_map(_rem_singleton, mesh=mesh, in_specs=src_pspec, + out_specs=dst_pspec, check_vma=check_vma, + axis_names=manual_axes)(x) + +def _rem_singleton(x): return lax.squeeze(x, [0]) +def _add_singleton(x): return lax.expand_dims(x, [0]) + +def _maybe_check_special(outs): + if not config.debug_nans.value and not config.debug_infs.value: return + bufs = [s.data for leaf in tree_leaves(outs) + for s in getattr(leaf, 'addressable_shards', [])] + try: + dispatch.check_special('shard_map', bufs) + except api_util.InternalFloatingPointError as e: + raise FloatingPointError(f'Invalid value ({e.ty}) encountered in sharded computation.') from None + +class ShardMapTrace(core.Trace): + __slots__ = ("mesh", "manual_axes", "check", "amesh") + + mesh: Mesh # outer concrete or abstract mesh + manual_axes: frozenset[AxisName] + check: bool + + def __init__(self, mesh, manual_axes, check): + super().__init__() + self.mesh = mesh + self.manual_axes = manual_axes + self.check = check + self.amesh = mesh.abstract_mesh + + def to_val_vma_pair(self, val): + if isinstance(val, ShardMapTracer): + return val.val, val.vma + elif isinstance(val, Tracer): + raise Exception(f"Shouldn't have any non-shard_map tracers: {val}") + else: + val_ = _unmatch_spec(self.mesh, self.check, self.amesh, self.manual_axes, + P(), val) + return val_, frozenset() + + def process_primitive(self, prim, tracers, params): + in_vals, in_vma = unzip2(map(self.to_val_vma_pair, tracers)) + if self.check: + out_avals, _ = prim.abstract_eval(*(typeof(t) for t in tracers), **params) + out_avals = tuple(out_avals) if type(out_avals) is list else out_avals + out_vma = tree_map(lambda a: a.vma, out_avals) + in_specs = tuple(map(partial(_vma_to_spec, self.mesh), in_vma)) + out_specs = tree_map(partial(_vma_to_spec, self.mesh), out_vma) + else: + out_vma = frozenset() + in_specs = out_specs = P(order_wrt_mesh(self.mesh, self.manual_axes)) + + eager_rule = eager_rules.get(prim) + if eager_rule: + out_vals = eager_rule(self.mesh, *in_vals, **params) + else: + f = HashablePartial( + _prim_applier, prim, self.check, tuple(params.items()), self.mesh, + self.manual_axes, in_specs, out_specs) + with (core.eval_context(), api.disable_jit(False), config.debug_nans(False), + config.debug_infs(False), use_abstract_mesh(self.amesh)): + out_vals = api.jit(f)(*in_vals) + _maybe_check_special(out_vals) + if prim.multiple_results: + out_vma = (out_vma if isinstance(out_vma, (list, tuple)) + else [out_vma] * len(out_vals)) + return map(partial(ShardMapTracer, self), out_vma, out_vals) + return ShardMapTracer(self, out_vma, out_vals) + + def process_shard_map(self, prim, fun, args, mesh, in_specs, + out_specs_thunk, check_vma, manual_axes): + # Check consistency between outer and inner shmaps on explicitly passed + # mesh and check_vma. + if isinstance(mesh, Mesh): + if mesh != self.mesh: raise Exception + del mesh + if check_vma != self.check: # TODO(mattjj): add check in jit path + raise Exception + del check_vma + + in_vals, in_vmas = unzip2(map(self.to_val_vma_pair, args)) + trace = ShardMapTrace(self.mesh, manual_axes | self.manual_axes, self.check) + in_vmas_ = [vma | _spec_to_vma(s) for vma, s in zip(in_vmas, in_specs)] + in_vals_ = [_unmatch_spec2(self.mesh, self.manual_axes, spec, x) + for x, spec in zip(in_vals, in_specs)] + in_tracers = map(partial(ShardMapTracer, trace), in_vmas_, in_vals_) + inner_mesh = _as_manual_mesh(self.mesh, manual_axes | self.manual_axes) + with (core.set_current_trace(trace), _extend_axis_env(self.mesh, manual_axes), + use_abstract_mesh(inner_mesh)): + ans = fun.call_wrapped(*in_tracers) + out_vals_, out_vmas_ = unzip2(map(trace.to_val_vma_pair, ans)) + out_specs = out_specs_thunk() + out_vals = [_match_spec2(self.mesh, self.manual_axes, spec, x) + for x, spec in zip(out_vals_, out_specs)] + out_vmas = [v - _spec_to_vma(spec) for v, spec in zip(out_vmas_, out_specs)] + return map(partial(ShardMapTracer, self), out_vmas, out_vals) + + def process_call(self, call_primitive, fun, tracers, params): + raise NotImplementedError( + f"Eager evaluation of `{call_primitive}` inside a `shard_map` isn't " + "yet supported. Put a `jax.jit` around the `shard_map`-decorated " + "function, and open a feature request at " + "https://github.com/jax-ml/jax/issues !") + + def process_map(self, map_primitive, fun, tracers, params): + raise NotImplementedError( + "Eager evaluation of `pmap` inside a `shard_map` isn't yet supported." + "Put a `jax.jit` around the `shard_map`-decorated function, and open " + "a feature request at https://github.com/jax-ml/jax/issues !") + + def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): + # Since ShardMapTrace is only used as a base main, we can drop the jvp. + del prim, jvp, symbolic_zeros + in_vals, in_vma = unzip2(map(self.to_val_vma_pair, tracers)) + out_vals, out_vma = _run_shmap(fun, self.mesh, self.manual_axes, in_vals, + in_vma, self.check) + return map(partial(ShardMapTracer, self), out_vma, out_vals) + + def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, + symbolic_zeros): + if symbolic_zeros: + msg = ("custom_vjp symbolic_zeros support with shard_map is not " + "implemented; please open an issue at " + "https://github.com/jax-ml/jax/issues") + raise NotImplementedError(msg) + del prim, fwd, bwd, out_trees, symbolic_zeros + in_vals, in_vma = unzip2(map(self.to_val_vma_pair, tracers)) + out_vals, out_vma = _run_shmap(fun, self.mesh, self.manual_axes, in_vals, + in_vma, self.check) + return map(partial(ShardMapTracer, self), out_vma, out_vals) + + +class ShardMapTracer(core.Tracer): + vma: frozenset[AxisName] + val: JaxType + + def __init__(self, trace, vma, val): + self._trace = trace + if isinstance(vma, set): + vma = frozenset(vma) + assert isinstance(vma, frozenset) + self.vma = vma + self.val = val + + @property + def aval(self): + aval = core.get_aval(self.val) + vma = self.vma if self._trace.check else self._trace.manual_axes + size = prod(self._trace.mesh.shape[n] for n in vma) + out = core.mapped_aval(size, 0, aval) + new_sharding = NamedSharding( + _as_manual_mesh(self._trace.amesh, self._trace.manual_axes), + out.sharding.spec) # pytype: disable=attribute-error + vma = self.vma if config._check_vma.value else frozenset() + return out.update(sharding=new_sharding, vma=vma) + + def to_concrete_value(self): + if self._trace.check and self.vma == frozenset(): + with core.eval_context(), use_abstract_mesh(self._trace.amesh): + return core.to_concrete_value(self.val[0]) + else: + return None + + def __str__(self) -> str: + pb_names = set(self._trace.mesh.axis_names) - self.vma + self = pvary(self, tuple(pb_names)) + with core.eval_context(), use_abstract_mesh(self._trace.amesh): + blocks = list(self.val) + mesh = self._trace.mesh + axis_names = f"({', '.join(map(str, mesh.axis_names))},)" + return '\n'.join( + f"On {device} at mesh coordinates {axis_names} = {idx}:\n{block}\n" + for (idx, device), block in zip(np.ndenumerate(mesh.devices), blocks)) + + __repr__ = __str__ # for debuggers, like `p x` + +def _prim_applier(prim, check_vma, params_tup, concrete_mesh, manual_axes, + in_specs, out_specs, *args): + def apply(*args): + outs = prim.bind(*map(_rem_singleton, args), **dict(params_tup)) + return tree_map(_add_singleton, outs) + out_specs = list(out_specs) if type(out_specs) is tuple else out_specs + return shard_map(apply, mesh=concrete_mesh, in_specs=in_specs, + out_specs=out_specs, check_vma=check_vma, + axis_names=manual_axes)(*args) + +eager_rules: dict[core.Primitive, Callable] = {} + +def _device_put_eager_rule(mesh, *xs, srcs, devices, copy_semantics): + del mesh, srcs, copy_semantics + for device in devices: + if device is not None: + raise ValueError("device_put with explicit device not allowed within " + f"shard_map-decorated functions, but got device {device}") + return xs +eager_rules[dispatch.device_put_p] = _device_put_eager_rule + + +# Batching + +def _shard_map_batch( + trace: batching.BatchTrace, prim: core.Primitive, fun: lu.WrappedFun, + in_tracers: Sequence[batching.BatchTracer], mesh: Mesh, + in_specs, out_specs_thunk, check_vma: bool, manual_axes: frozenset + ) -> Sequence[batching.BatchTracer]: + in_vals, in_dims = unzip2(map(trace.to_batch_info, in_tracers)) + spmd_axis_name = trace.axis_data.spmd_name + explicit_mesh_axis = trace.axis_data.explicit_mesh_axis + if spmd_axis_name is not None: + used = {n for spec in in_specs for n in _spec_to_vma(spec)} + if not config.disable_vmap_shmap_error.value and set(spmd_axis_name) & used: + raise ValueError("vmap spmd_axis_name cannot appear in shard_map in_specs") + new_in_specs = [ + sp if d is batching.not_mapped else pxla.batch_spec(sp, d, spmd_axis_name) + for sp, d in zip(in_specs, in_dims)] + new_size = trace.axis_data.size // prod(mesh.shape[n] for n in spmd_axis_name) + new_axis_data = batching.AxisData( + trace.axis_data.name, new_size, trace.axis_data.spmd_name, + trace.axis_data.explicit_mesh_axis) + elif explicit_mesh_axis is not None: + used = {n for spec in in_specs for n in _spec_to_vma(spec)} + if set(explicit_mesh_axis) & used: + raise ValueError("vmapped away explicit mesh axis cannot appear in " + "shard_map in_specs") + new_in_specs = [ + sp if d is batching.not_mapped else pxla.batch_spec(sp, d, None) + for sp, d in zip(in_specs, in_dims)] + new_axis_data = trace.axis_data + else: + new_in_specs = [sp if d is batching.not_mapped else pxla.batch_spec(sp, d, None) + for sp, d in zip(in_specs, in_dims)] + new_axis_data = trace.axis_data + + fun, out_dims = batching.batch_subtrace( + fun, trace.tag, new_axis_data, tuple(in_dims)) + + @as_hashable_function(closure=out_specs_thunk) + def new_out_specs_thunk(): + return _batch_out_specs(spmd_axis_name, explicit_mesh_axis, out_dims(), + out_specs_thunk()) + + new_params = dict(mesh=mesh, in_specs=new_in_specs, + out_specs_thunk=new_out_specs_thunk, check_vma=check_vma, + manual_axes=manual_axes) + with core.set_current_trace(trace.parent_trace): + out_vals = prim.bind(fun, *in_vals, **new_params) + make_tracer = partial(batching.BatchTracer, trace, + source_info=source_info_util.current()) + return map(make_tracer, out_vals, out_dims()) +batching.BatchTrace.process_shard_map = _shard_map_batch + +def _batch_out_specs(spmd_name, explicit_mesh_axis, dims, out_specs): + if spmd_name is not None: + used = {n for spec in out_specs for n in _spec_to_vma(spec)} + if not config.disable_vmap_shmap_error.value and set(spmd_name) & used: + raise ValueError("vmap spmd_axis_name cannot appear in shard_map out_specs") + return [sp if d is batching.not_mapped else pxla.batch_spec(sp, d, spmd_name) + for sp, d in zip(out_specs, dims)] + elif explicit_mesh_axis is not None: + used = {n for spec in out_specs for n in _spec_to_vma(spec)} + if set(explicit_mesh_axis) & used: + raise ValueError("vmapped away explicit mesh axis cannot appear in " + "shard_map out_specs") + return [sp if d is batching.not_mapped else pxla.batch_spec(sp, d, None) + for sp, d in zip(out_specs, dims)] + else: + return [sp if d is batching.not_mapped else pxla.batch_spec(sp, d, None) + for sp, d in zip(out_specs, dims)] + + +# Autodiff + +def _shard_map_jvp(trace, shard_map_p, f: lu.WrappedFun, tracers, mesh, in_specs, + out_specs_thunk, check_vma, manual_axes): + f = f.with_unknown_names() + primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers)) + which_nz = [ type(t) is not ad.Zero for t in tangents] + tangents = [t if type(t) is not ad.Zero else None for t in tangents] + args, in_tree = tree_flatten((primals, tangents)) + f_jvp = ad.jvp_subtrace(f, trace.tag) + f_jvp, which_nz_out = ad.nonzero_tangent_outputs(f_jvp) + tangent_in_specs = [sp for sp, nz in zip(in_specs, which_nz) if nz] + + @as_hashable_function(closure=out_specs_thunk) + def new_out_specs_thunk(): + out_ax = out_specs_thunk() + return (*out_ax, *(ax for ax, nz in zip(out_ax, which_nz_out()) if nz)) + params = dict(mesh=mesh, in_specs=(*in_specs, *tangent_in_specs), + out_specs_thunk=new_out_specs_thunk, check_vma=check_vma, + manual_axes=manual_axes) + f_jvp, out_tree = ad.traceable(f_jvp, in_tree) + result = shard_map_p.bind_with_trace(trace.parent_trace, (f_jvp,) + tuple(args), params) + primal_out, tangent_out = tree_unflatten(out_tree(), result) + tangent_out = [ad.Zero(core.get_aval(p).to_tangent_aval()) if t is None else t + for p, t in zip(primal_out, tangent_out)] + return [ad.JVPTracer(trace, p, t) for p, t in zip(primal_out, tangent_out)] +ad.JVPTrace.process_shard_map = _shard_map_jvp + +def _shard_map_partial_eval(trace: pe.JaxprTrace, shard_map_p, + f: lu.WrappedFun, tracers, mesh, in_specs, + out_specs_thunk, check_vma, manual_axes): + tracers = map(trace.to_jaxpr_tracer, tracers) + in_pvals = [t.pval for t in tracers] + in_knowns, in_avals, in_consts = pe.partition_pvals(in_pvals) + unk_in_specs, known_in_specs = pe.partition_list(in_knowns, in_specs) + in_avals_sharded = map(partial(shard_aval, mesh, manual_axes, check_vma), + unk_in_specs, in_avals) + f = pe.trace_to_subjaxpr_nounits_fwd2(f, trace.tag, f.debug_info, False) + f = _promote_scalar_residuals(f) + f_known, aux = pe.partial_eval_wrapper_nounits2( + f, (*in_knowns,), (*in_avals_sharded,)) + all_names = _all_newly_manual_mesh_names(mesh, manual_axes) + + @as_hashable_function(closure=out_specs_thunk) + def known_out_specs(): + _, _, out_knowns, res_avals, _, _ = aux() + _, out_known_specs = pe.partition_list(out_knowns, out_specs_thunk()) + if check_vma: + res_specs = [P(order_wrt_mesh(mesh, a.vma)) for a in res_avals] + else: + res_specs = [P(all_names)] * len(res_avals) + return (*out_known_specs, *res_specs) + + known_params = dict(mesh=mesh, in_specs=(*known_in_specs,), + out_specs_thunk=known_out_specs, check_vma=check_vma, + manual_axes=manual_axes) + out = shard_map_p.bind_with_trace(trace.parent_trace, + (f_known.with_unknown_names(), *in_consts), + known_params) + in_fwd, out_fwd, out_knowns, res_avals, jaxpr, env = aux() + num_res = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)) + out_consts, non_fwd_res = split_list(out, [len(out) - num_res]) + assert not jaxpr.constvars + unk_out_specs, _ = pe.partition_list(out_knowns, out_specs_thunk()) + known_out_specs_ = known_out_specs() + res = subs_list2(in_fwd, out_fwd, in_consts, out_consts, non_fwd_res) + # TODO make res_avals be the full set, not just the non-fwd ones + res_avals_iter = iter(res_avals) + res_specs = [] + for f1, f2 in zip(in_fwd, out_fwd): + if f1 is not None: + res_specs.append(known_in_specs[f1]) + elif f2 is not None: + res_specs.append(known_out_specs_[f2]) + else: + if check_vma: + res_vma = next(res_avals_iter).vma + res_specs.append(P(order_wrt_mesh(mesh, res_vma))) + else: + res_specs.append(P(all_names)) + unk_in_specs = (*res_specs,) + (P(),) * len(env) + (*unk_in_specs,) # type: ignore[assignment] + const_tracers = map(trace.new_instantiated_const, res) + env_tracers = map(trace.to_jaxpr_tracer, env) + unk_arg_tracers = [t for t in tracers if not t.is_known()] + out_avals_sharded = [v.aval for v in jaxpr.outvars] + unk_params = dict(mesh=mesh, in_specs=unk_in_specs, + out_specs=tuple(unk_out_specs), + jaxpr=jaxpr.replace(debug_info=jaxpr.debug_info.with_unknown_names()), + check_vma=check_vma, manual_axes=manual_axes) + out_avals = map(partial(unshard_aval, mesh, check_vma), unk_out_specs, + out_avals_sharded) + out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None) + for a in out_avals] + effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names) + eqn = pe.new_eqn_recipe(trace, (*const_tracers, *env_tracers, *unk_arg_tracers), + out_tracers, shard_map_p, unk_params, + effs, source_info_util.current()) + for t in out_tracers: t.recipe = eqn + return merge_lists(out_knowns, out_tracers, out_consts) +pe.JaxprTrace.process_shard_map = _shard_map_partial_eval + +def _shard_map_linearize(trace, shard_map_p, f: lu.WrappedFun, + tracers, mesh, in_specs, out_specs_thunk, check_vma, + manual_axes): + primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers)) + nzs_in = tuple(type(t) is not ad.Zero for t in tangents) + f = f.with_unknown_names() + f_primal, linearize_outs_thunk = ad.linearize_subtrace( + f, trace.tag, nzs_in, f.debug_info) + f_primal = _promote_scalar_residuals_lin(f_primal, linearize_outs_thunk) + all_names = _all_newly_manual_mesh_names(mesh, manual_axes) + + @as_hashable_function(closure=linearize_outs_thunk) + def fwd_out_specs_thunk(): + res_avals, _, _, _, in_fwd, out_fwd = linearize_outs_thunk() + res_avals = [r for r, f1, f2 in zip(res_avals, in_fwd, out_fwd) + if f1 is None and f2 is None] + out_specs = out_specs_thunk() + if check_vma: + res_specs = [P(order_wrt_mesh(mesh, a.vma)) for a in res_avals] + else: + res_specs = [P(all_names)] * len(res_avals) + return (*res_specs, *out_specs) + + fwd_params = dict( + mesh=mesh, in_specs=in_specs, + out_specs_thunk=fwd_out_specs_thunk, check_vma=check_vma, + manual_axes=manual_axes) + all_fwd_results = shard_map_p.bind_with_trace( + trace.parent_trace, (f_primal, *primals), fwd_params) + res_avals, nzs_out, lin_jaxpr, env, in_fwd, out_fwd = linearize_outs_thunk() + num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)) + non_fwd_res = all_fwd_results[:num_res_out] + primals_out = all_fwd_results[num_res_out:] + residuals = subs_list2(in_fwd, out_fwd, primals, primals_out, non_fwd_res) + args_to_promote = [getattr(aval, 'shape', ()) == () and f1 is None and f2 is None + for aval, f1, f2 in zip(res_avals, in_fwd, out_fwd)] + with (_extend_axis_env(mesh, manual_axes), + use_abstract_mesh(_as_manual_mesh(mesh, manual_axes)), + config._check_vma(check_vma)): + lin_jaxpr = _promote_scalar_residuals_jaxpr(lin_jaxpr, args_to_promote) + out_specs = out_specs_thunk() + res_avals2 = [r for r, f1, f2 in zip(res_avals, in_fwd, out_fwd) + if f1 is None and f2 is None] + res_avals_iter = iter(res_avals2) + res_specs = [] + for f1, f2 in zip(in_fwd, out_fwd): + if f1 is not None: + res_specs.append(in_specs[f1]) + elif f2 is not None: + res_specs.append(out_specs[f2]) + else: + if check_vma: + res_vma = next(res_avals_iter).vma + res_specs.append(P(order_wrt_mesh(mesh, res_vma))) + else: + res_specs.append(P(all_names)) + new_in_specs = (*res_specs, *(P(),) * len(env), + *(ax for ax, nz in zip(in_specs, nzs_in) if nz)) + tangent_out_specs = tuple(ax for ax, nz in zip(out_specs_thunk(), nzs_out) + if nz) + @as_hashable_function(closure=tangent_out_specs) + def tangent_out_specs_thunk(): + return tangent_out_specs + tangent_params = dict( + mesh=mesh, in_specs=new_in_specs, out_specs_thunk=tangent_out_specs_thunk, + check_vma=check_vma, manual_axes=manual_axes) + + # TODO(mattjj): avoid round-tripping the jaxpr through eval_jaxpr here + def f_tangent(*args): + return core.eval_jaxpr(lin_jaxpr, (), *args) + + nz_tangents_in = [t for (t, nz) in zip(tangents, nzs_in) if nz] + nz_tangents_out = shard_map_p.bind_with_trace( + trace.tangent_trace, + (lu.wrap_init(f_tangent, debug_info=lin_jaxpr.debug_info), + *residuals, *env, *nz_tangents_in), tangent_params) + nz_tangents_out_iter = iter(nz_tangents_out) + tangents_out = [next(nz_tangents_out_iter) if nz else ad.Zero.from_primal_value(primal) + for nz, primal in zip(nzs_out, primals_out)] + return map(partial(ad.maybe_linearize_tracer, trace), primals_out, nzs_out, tangents_out) +ad.LinearizeTrace.process_shard_map = _shard_map_linearize + +@lu.transformation2 +def _promote_scalar_residuals_lin(f, linearize_outs_thunk, *args, **kwargs): + ans = f(*args, **kwargs) + _, _, _, _, in_fwd, out_fwd = linearize_outs_thunk() + num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)) + residuals = ans[:num_res_out] + primals = ans[num_res_out:] + residuals = [lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x + for x in residuals] + return *residuals, *primals + +@lu.transformation2 +def _promote_scalar_residuals(f: Callable, *args, **kwargs): + jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) = f(*args, **kwargs) + which = [f1 is None and f2 is None and not v.aval.shape + for f1, f2, v in zip(in_fwds, out_fwds, jaxpr.constvars)] + jaxpr = _promote_scalar_residuals_jaxpr(jaxpr, which) + out_consts = [lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x + for x in out_consts] + return jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) + +def _promote_scalar_residuals_jaxpr(jaxpr: core.Jaxpr, which: Sequence[bool]): + def fun(*res_and_args): + res, args = split_list(res_and_args, [len(jaxpr.constvars)]) + res = [_rem_singleton(x) if w else x for x, w in zip(res, which)] + return core.eval_jaxpr(jaxpr, res, *args) + res_avals = [core.unmapped_aval(1, 0, v.aval) if w else v.aval + for v, w in zip(jaxpr.constvars, which)] + in_avals = [*res_avals, *[v.aval for v in jaxpr.invars]] + jaxpr, _, _ = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(fun, debug_info=jaxpr.debug_info), in_avals) + return jaxpr + + +def _unmentioned2(mesh: Mesh, spec, manual_axes: frozenset[AxisName] + ) -> list[AxisName]: + # We use a filtered-down version of unmentioned to avoid defensive-psum over + # more chips than required in the transpose-no-check-vma case. + name_set = _spec_to_vma(spec) + return [n for n in _all_mesh_names_except_spmd(mesh, manual_axes) + if n not in name_set] + + +def _shard_map_transpose(out_cts, *args, + jaxpr: core.Jaxpr, mesh, in_specs, out_specs, + check_vma, manual_axes): + mb_div = lambda x, y: x / y if y != 1 else x + out_cts = [ + ad.Zero(shard_aval(mesh, manual_axes, check_vma, sp, x.aval)) + if type(x) is ad.Zero else x if check_vma or dtypes.dtype(x) == dtypes.float0 + else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, sp, manual_axes)))) + for sp, x in zip(out_specs, out_cts) + ] + args = [x if type(x) is not ad.UndefinedPrimal else + ad.UndefinedPrimal(shard_aval(mesh, manual_axes, check_vma, sp, x.aval)) + for sp, x in zip(in_specs, args)] + all_args, in_tree = tree_flatten((out_cts, tuple(args))) + + def fun_trans_callable(out_cts, args): + # TODO(mattjj): when #26811 lands, delete this and just run backward_pass + in_undef = map(ad.is_undefined_primal, args) + res, undefs = partition_list(in_undef, args) + jaxpr_known, jaxpr_unknown, _, _ = pe.partial_eval_jaxpr_nounits( + pe.close_jaxpr(jaxpr), in_undef, False) + res_reshaped = core.jaxpr_as_fun(jaxpr_known)(*res) + in_cts = ad.backward_pass( + jaxpr_unknown.jaxpr, False, (), (*res_reshaped, *undefs), out_cts + )[len(res_reshaped):] + _, in_ct_specs = partition_list(in_undef, in_specs) + in_cts = [ad.Zero(x.aval) if type(x) is ad.Zero else x if check_vma + else lax_parallel.psum(x, tuple(_unmentioned2(mesh, sp, manual_axes))) + for sp, x in zip(in_ct_specs, in_cts)] + res_zeros = [ad_util.zero_from_primal(r) for r in res] + return merge_lists(in_undef, res_zeros, in_cts) + + fun_trans_callable.__name__ = f"transpose({jaxpr.debug_info.func_name})" + fun_trans = lu.wrap_init(fun_trans_callable, debug_info=jaxpr.debug_info) + fun_trans, nz_arg_cts = ad.nonzero_outputs(fun_trans) + fun_trans_flat, out_tree = api_util.flatten_fun_nokwargs(fun_trans, in_tree) + + new_in_specs = ( + [core.primal_spec_to_cotangent_spec(s) + for s, x in zip(out_specs, out_cts) if type(x) is not ad.Zero] + + [s for s, x in zip(in_specs, args) if type(x) is not ad.UndefinedPrimal]) + + def new_out_specs_thunk(): + return tuple(core.primal_spec_to_cotangent_spec(sp) + for sp, nz in zip(in_specs, nz_arg_cts()) if nz) + + try: + out_flat = shard_map_p.bind( + fun_trans_flat, *all_args, mesh=mesh, in_specs=tuple(new_in_specs), + out_specs_thunk=new_out_specs_thunk, check_vma=check_vma, + manual_axes=manual_axes) + except (FloatingPointError, ZeroDivisionError) as e: + print("Invalid nan value encountered in the backward pass of a shard_map " + "function. Calling the de-optimized backward pass.") + try: + # TODO(mattjj): Remove this and do `fun_trans.call_wrapped(out_cts, args)` + # in eager mode so that output of shmap are not manual. + with api.disable_jit(True): + _ = shard_map_p.bind( + fun_trans_flat, *all_args, mesh=mesh, in_specs=tuple(new_in_specs), + out_specs_thunk=new_out_specs_thunk, check_vma=check_vma, + manual_axes=manual_axes) + except (FloatingPointError, ZeroDivisionError) as e2: + raise e2 from None + else: + api_util._raise_no_nan_in_deoptimized(e) + except _RepError as e: + fails, = e.args + if not callable(out_specs): + msg = _inout_vma_error( + fun_trans, mesh, out_tree(), list(new_out_specs_thunk()), fails) + raise ValueError(msg) from None + in_cts = tree_unflatten(out_tree(), out_flat) + return [ad.Zero(unshard_aval(mesh, check_vma, sp, x.aval)) + if type(x) is ad.Zero else x for sp, x in zip(in_specs, in_cts)] +ad.primitive_transposes[shard_map_p] = _shard_map_transpose + +# Remat + +def _partial_eval_jaxpr_custom_rule( + saveable: Callable[..., pe.RematCases_], unks_in: Sequence[bool], + inst_in: Sequence[bool], eqn: core.JaxprEqn +) -> tuple[core.JaxprEqn, core.JaxprEqn, Sequence[bool], Sequence[bool], + list[core.Var]]: + jaxpr, mesh = eqn.params['jaxpr'], eqn.params['mesh'] + check_vma, manual_axes = eqn.params['check_vma'], eqn.params['manual_axes'] + with (_extend_axis_env(mesh, manual_axes), config._check_vma(check_vma), + use_abstract_mesh(_as_manual_mesh(mesh, manual_axes))): + jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res = \ + pe.partial_eval_jaxpr_custom(jaxpr, unks_in, inst_in, False, False, saveable) + num_out_primals = len(jaxpr_known.outvars) - num_res + in_fwd = pe._jaxpr_forwarding(jaxpr_known)[num_out_primals:] + out_vars, res_vars = split_list(jaxpr_known.outvars, [num_out_primals]) + idx_map = {id(v): i for i, v in enumerate(out_vars)} + out_fwd = [idx_map.get(id(v)) for v in res_vars] + which = [f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)] + mesh = eqn.params['mesh'] + with (_extend_axis_env(mesh, manual_axes), + use_abstract_mesh(_as_manual_mesh(mesh, manual_axes)), + config._check_vma(check_vma)): + jaxpr_known = pe.prune_jaxpr_outputs(jaxpr_known, [True] * num_out_primals + which) + jaxpr_known, jaxpr_staged = _add_reshapes(which, jaxpr_known, jaxpr_staged) + jaxpr_known = core.remove_named_axis_effects(jaxpr_known, mesh.axis_names) + jaxpr_staged = core.remove_named_axis_effects(jaxpr_staged, mesh.axis_names) + ins_known, _ = partition_list(unks_in, eqn.invars) + out_binders_known, _ = partition_list(unks_out, eqn.outvars) + _, ins_staged = partition_list(inst_in, eqn.invars) + _, out_binders_staged = partition_list(inst_out, eqn.outvars) + newvar = core.gensym() + residuals, staged_in_res_specs = [], [] + for var, w in zip(jaxpr_staged.invars[:num_res], which): + if w: + rn = (P(order_wrt_mesh(mesh, var.aval.vma)) # type: ignore + if check_vma else P(_all_newly_manual_mesh_names(mesh, manual_axes))) + residuals.append(newvar(unshard_aval(mesh, check_vma, rn, var.aval))) + staged_in_res_specs.append(rn) + if check_vma: + out_res_specs_known = [P(order_wrt_mesh(mesh, var.aval.vma)) # type: ignore + for var, w in zip(res_vars, which) if w] + else: + out_res_specs_known = [ + P(_all_newly_manual_mesh_names(mesh, manual_axes))] * sum(which) + params_known, params_staged = _pe_custom_params( + unks_in, inst_in, map(op.not_, unks_out), inst_out, in_fwd, out_fwd, + out_res_specs_known, staged_in_res_specs, + dict(eqn.params, jaxpr=jaxpr_known), dict(eqn.params, jaxpr=jaxpr_staged)) + eqn_known = pe.new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals], + eqn.primitive, params_known, jaxpr_known.effects, + eqn.source_info, eqn.ctx) + full_res = subs_list2(in_fwd, out_fwd, ins_known, out_binders_known, residuals) + eqn_staged = pe.new_jaxpr_eqn([*full_res, *ins_staged], out_binders_staged, + eqn.primitive, params_staged, + jaxpr_staged.effects, eqn.source_info, eqn.ctx) + assert len(eqn_staged.invars) == len(jaxpr_staged.invars) + new_inst = [x for x, inst in zip(eqn.invars, inst_in) + if type(x) is core.Var and not inst] + new_inst += [out_binders_known[f] for f in {i for i in out_fwd if i is not None}] + return eqn_known, eqn_staged, unks_out, inst_out, new_inst + residuals +pe.partial_eval_jaxpr_custom_rules[shard_map_p] = \ + _partial_eval_jaxpr_custom_rule + +def _add_reshapes(which: Sequence[bool], + jaxpr_known: core.Jaxpr, + jaxpr_staged: core.Jaxpr) -> tuple[core.Jaxpr, core.Jaxpr]: + # add singleton axes to residuals which are from jaxpr_known and are scalars + which_ = [w and not v.aval.shape # pytype: disable=attribute-error + for w, v in zip(which, jaxpr_staged.invars[:len(which)])] + if not any(which_): return jaxpr_known, jaxpr_staged + assert not jaxpr_known.constvars and not jaxpr_staged.constvars + + def known(*args): + out = core.eval_jaxpr(jaxpr_known, (), *args) + out_known, res = split_list(out, [len(out) - sum(which)]) + res = [_add_singleton(x) if not x.shape else x for x in res] + return [*out_known, *res] + avals_in = [v.aval for v in jaxpr_known.invars] + jaxpr_known, _, () = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(known, debug_info=jaxpr_known.debug_info), avals_in) + + def staged(*args): + res_, ins = split_list(args, [len(which)]) + res = [_rem_singleton(x) if w else x for x, w in zip(res_, which_)] + return core.eval_jaxpr(jaxpr_staged, (), *res, *ins) + res_avals = [core.unmapped_aval(1, 0, v.aval) if w else v.aval + for w, v in zip(which_, jaxpr_staged.invars[:len(which)])] + avals_in = [*res_avals, *[v.aval for v in jaxpr_staged.invars[len(which):]]] + jaxpr_staged, _, () = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(staged, debug_info=jaxpr_staged.debug_info), avals_in) + + return jaxpr_known, jaxpr_staged + +def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged, + in_fwd, out_fwd, out_res_specs_known, staged_in_res_specs, + params_known, params_staged): + # prune inputs to jaxpr_known according to unks_in + in_specs_known, _ = partition_list(unks_in, params_known['in_specs']) + _, out_specs_known = partition_list(kept_outs_known, params_known['out_specs']) + out_specs_known = out_specs_known + out_res_specs_known + assert len(out_specs_known) == len(params_known['jaxpr'].outvars) + new_params_known = dict(params_known, in_specs=tuple(in_specs_known), + out_specs=tuple(out_specs_known)) + + # added num_res new inputs to jaxpr_staged, pruning according to inst_in + _, in_specs_staged = partition_list(inst_in, params_staged['in_specs']) + iter_staged = iter(staged_in_res_specs) + res_specs = [in_specs_known[f1] if f1 is not None else + out_specs_known[f2] if f2 is not None else + next(iter_staged) for f1, f2 in zip(in_fwd, out_fwd)] + + in_specs_staged = res_specs + in_specs_staged + _, out_specs_staged = partition_list(kept_outs_staged, params_staged['out_specs']) + new_params_staged = dict(params_staged, in_specs=tuple(in_specs_staged), + out_specs=tuple(out_specs_staged)) + return new_params_known, new_params_staged + +# TODO(mattjj): remove this mechanism when we revise mesh scopes +def _all_mesh_names_except_spmd( + mesh: Mesh, manual_axes: frozenset[AxisName]) -> tuple[AxisName, ...]: + axis_env = core.get_axis_env() + spmd_names = axis_env.spmd_axis_names + return tuple(name for name in mesh.axis_names + if name not in spmd_names and name in manual_axes) + +def _all_newly_manual_mesh_names( + mesh: BaseMesh, manual_axes: frozenset[AxisName]) -> tuple[AxisName, ...]: + axis_env = core.get_axis_env() + vmap_spmd_names = set(axis_env.spmd_axis_names) + if not (ctx_mesh := get_abstract_mesh()).empty: + mesh = ctx_mesh + already_manual_names = set(ctx_mesh.manual_axes) + else: + # TODO(mattjj): remove this mechanism when we revise mesh scopes + already_manual_names = set(axis_env.axis_sizes) # may include vmap axis_names + return tuple(name for name in mesh.axis_names + if (name not in vmap_spmd_names | already_manual_names and + name in manual_axes)) + + +# DCE + +# TODO(mattjj): de-duplicate with pe.dce_jaxpr_call_rule, and/or _pmap_dce_rule? +def _shard_map_dce(used_outputs: list[bool], eqn: core.JaxprEqn + ) -> tuple[list[bool], core.JaxprEqn | None]: + if not any(used_outputs) and not pe.has_effects(eqn): + return [False] * len(eqn.invars), None + mesh = eqn.params["mesh"] + manual_axes = eqn.params["manual_axes"] + check_vma = eqn.params["check_vma"] + with (_extend_axis_env(mesh, manual_axes), config._check_vma(check_vma), + use_abstract_mesh(_as_manual_mesh(mesh, manual_axes))): + jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs) + if not any(used_inputs) and not any(used_outputs) and not jaxpr.effects: + return used_inputs, None + else: + _, in_specs = partition_list(used_inputs, eqn.params['in_specs']) + _, out_specs = partition_list(used_outputs, eqn.params['out_specs']) + new_params = dict(eqn.params, jaxpr=jaxpr, in_specs=tuple(in_specs), + out_specs=tuple(out_specs)) + effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names) + new_eqn = pe.new_jaxpr_eqn( + [v for v, used in zip(eqn.invars, used_inputs) if used], + [x for x, used in zip(eqn.outvars, used_outputs) if used], + eqn.primitive, new_params, effs, eqn.source_info, eqn.ctx) + return used_inputs, new_eqn +pe.dce_rules[shard_map_p] = _shard_map_dce + +# Mutable arrays / refs + +@discharge.register_discharge_rule(shard_map_p) +def _shard_map_discharge( + in_avals, out_avals, *args, jaxpr, mesh, in_specs, out_specs, check_vma, + manual_axes): + inner_mesh = _as_manual_mesh(mesh, manual_axes) + with (_extend_axis_env(mesh, manual_axes), use_abstract_mesh(inner_mesh), + config._check_vma(check_vma)): + discharged_jaxpr, discharged_consts = discharge.discharge_state(jaxpr, ()) + if discharged_consts: raise NotImplementedError + del discharged_consts + + ref_specs = [spec for spec, invar in zip(in_specs, jaxpr.invars) + if isinstance(invar.aval, AbstractRef)] + params = dict(jaxpr=discharged_jaxpr, out_specs=(*out_specs, *ref_specs)) + [f], params_ = shard_map_p.get_bind_params(params) + discharged_out_specs, = params_.values() + out_and_ref_vals = shard_map_p.bind( + f, *args, mesh=mesh, in_specs=in_specs, manual_axes=manual_axes, + out_specs_thunk=discharged_out_specs, check_vma=check_vma) + out_vals, ref_vals = split_list(out_and_ref_vals, [len(jaxpr.outvars)]) + ref_vals_ = iter(ref_vals) + new_invals = [next(ref_vals_) if isinstance(a, AbstractRef) else None + for a in in_avals] + assert next(ref_vals_, None) is None + return new_invals, out_vals diff --git a/jax/_src/sharding.py b/jax/_src/sharding.py index a9bf62b46473..22ca3c46c67a 100644 --- a/jax/_src/sharding.py +++ b/jax/_src/sharding.py @@ -21,8 +21,8 @@ from jax._src import xla_bridge as xb from jax._src.lib import xla_client as xc from jax._src.op_shardings import ( - are_op_shardings_equal, get_num_ways_dim_sharded, is_op_sharding_replicated, - op_sharding_to_indices) + are_hlo_shardings_equal, get_num_ways_dim_sharded, + is_hlo_sharding_replicated, op_sharding_to_indices) Shape = tuple[int, ...] Device = xc.Device @@ -36,17 +36,19 @@ def _addressable_devices_indices_map( global_map = sharding.devices_indices_map(global_shape) if sharding.is_fully_addressable: return global_map - if hasattr(sharding, '_internal_device_list'): - return {d: global_map[d] - for d in sharding._internal_device_list.addressable_device_list} - return {d: ind for d, ind in global_map.items() - if d.process_index == d.client.process_index()} + return {d: global_map[d] + for d in sharding._internal_device_list.addressable_device_list} # type: ignore @cache(max_size=4096, trace_context_in_key=False) def common_devices_indices_map( s: Sharding, global_shape: Shape) -> Mapping[Device, Index]: s.shard_shape(global_shape) # raises a good error message hlo_sharding = s._to_xla_hlo_sharding(len(global_shape)) + if (xc.OpSharding.Type.UNREDUCED in hlo_sharding.subgroup_types() or + hlo_sharding.is_unreduced()): + raise NotImplementedError( + "device_indices_map doesn't work with unreduced. Please file a bug at" + ' https://github.com/jax-ml/jax/issues') indices = op_sharding_to_indices(hlo_sharding, global_shape, len(s._device_assignment)) return dict(safe_zip(s._device_assignment, indices)) @@ -55,7 +57,9 @@ def common_devices_indices_map( @cache(max_size=4096, trace_context_in_key=False) def _common_shard_shape(self, global_shape: Shape) -> Shape: hlo_sharding = self._to_xla_hlo_sharding(len(global_shape)) - if is_op_sharding_replicated(hlo_sharding): + if is_hlo_sharding_replicated(hlo_sharding): + return global_shape + if hlo_sharding.is_unreduced(): return global_shape partitions, _ = get_num_ways_dim_sharded(hlo_sharding) assert len(partitions) == len(global_shape), (len(partitions), len(global_shape)) @@ -129,6 +133,10 @@ def with_memory_kind(self, kind: str) -> Sharding: def _device_assignment(self) -> XLADeviceAssignment: raise NotImplementedError('Subclasses should implement this method.') + @property + def _internal_device_list(self) -> xc.DeviceList: + raise NotImplementedError('Subclasses should implement this method.') + def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding: raise NotImplementedError('Subclasses should implement this method.') @@ -170,14 +178,15 @@ def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]: """ return common_devices_indices_map(self, global_shape) + @property + def has_addressable_devices(self) -> bool: + return len(self._internal_device_list.addressable_device_list) > 0 + @functools.cached_property def _addressable_device_assignment(self) -> XLADeviceAssignment: if self.is_fully_addressable: return self._device_assignment - if hasattr(self, '_internal_device_list'): - return tuple(self._internal_device_list.addressable_device_list) - return tuple(d for d in self._device_assignment - if d.process_index == d.client.process_index()) + return tuple(self._internal_device_list.addressable_device_list) # type: ignore def shard_shape(self, global_shape: Shape) -> Shape: """Returns the shape of the data on each device. @@ -192,13 +201,9 @@ def is_equivalent_to(self: Sharding, other: Sharding, ndim: int) -> bool: Two shardings are equivalent if they place the same logical array shards on the same devices. - - For example, a :class:`NamedSharding` may be equivalent - to a :class:`PositionalSharding` if both place the same shards of the array - on the same devices. """ try: - return (are_op_shardings_equal(self._to_xla_hlo_sharding(ndim), + return (are_hlo_shardings_equal(self._to_xla_hlo_sharding(ndim), other._to_xla_hlo_sharding(ndim)) and self._internal_device_list == other._internal_device_list and # type: ignore self.memory_kind == other.memory_kind) diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 2bbf913783e3..48247bcce753 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -33,20 +33,21 @@ from jax._src import xla_bridge as xb from jax._src import mesh_utils from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension_version from jax._src.lib.mlir.dialects import sdy from jax._src.named_sharding import ( # noqa: F401 - SdyArraySharding, SdyDimSharding, UnspecifiedValue, AUTO, - ParsedPartitionSpec, _check_unique_resources, NamedSharding, UNSPECIFIED, + SdyArray, SdyDim, UnspecifiedValue, AUTO, flatten_spec, NamedSharding, + _check_unique_resources, UNSPECIFIED, ArrayMapping, ArrayMappingOrAutoOrUnspecified, get_array_mapping, - array_mapping_to_axis_resources, get_single_pspec, preprocess, - named_sharding_to_xla_hlo_sharding) + array_mapping_to_axis_resources, named_sharding_to_xla_hlo_sharding, + modify_sdy_sharding_wrt_axis_types) from jax._src.op_shardings import ( - are_op_shardings_equal, get_num_ways_dim_sharded, is_op_sharding_replicated) + are_hlo_shardings_equal, get_num_ways_dim_sharded, + is_hlo_sharding_replicated) from jax._src.partition_spec import PartitionSpec -from jax._src.util import safe_map, safe_zip, use_cpp_class, use_cpp_method +from jax._src.util import safe_zip, use_cpp_class, use_cpp_method import numpy as np +config_ext = xc._xla.config Shape = tuple[int, ...] Device = xc.Device @@ -55,10 +56,6 @@ # TODO(yashkatariya): Remove this after 3 months of deprecation. XLACompatibleSharding = jsharding.Sharding -@dataclasses.dataclass(frozen=True) -class TransferToMemoryKind: - memory_kind: str - def hashed_index(x) -> int: # This works for both `pjit` indices and `pmap` indices (which might @@ -88,33 +85,19 @@ def device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int] @dataclasses.dataclass -class SdyArrayShardingList: - shardings: Sequence[SdyArraySharding] +class SdyArrayList: + shardings: Sequence[SdyArray] def build(self) -> sdy.TensorShardingPerValueAttr: return sdy.TensorShardingPerValueAttr.get( [sharding.build() for sharding in self.shardings]) -# TODO(yashkatariya): Upstream this into `_to_sdy_sharding` maybe with an extra -# parameter to it `_to_sdy_sharding(self, ndim, modify_wrt_axis_types=False)` -def modify_sdy_sharding_wrt_axis_types(sdy_sharding: SdyArraySharding, mesh): - if mesh._any_axis_auto: - dim_shardings, used_axes = [], [] # type: ignore - for d in sdy_sharding.dimension_shardings: - # TODO(yashkatariya): Maybe if any mesh axis is auto, mark all axes as open? - dim_shardings.append(SdyDimSharding(axes=[], is_closed=False) - if not d.axes and d.is_closed else d) - used_axes.extend(d.axes) - remaining_axes = set(mesh.axis_names) - set(used_axes) - replicated_axes = tuple(r for r in remaining_axes - if mesh._name_to_type[r] == mesh_lib.AxisType.Explicit) - return SdyArraySharding(sdy_sharding.mesh_shape, dim_shardings, - sdy_sharding.logical_device_ids, replicated_axes) - return sdy_sharding +replicated_hlo_sharding = xc.HloSharding.replicate() -replicated_hlo_sharding = xc.HloSharding.replicate() +def _unpickle_single_device_sharding(device, memory_kind): + return SingleDeviceSharding(device, memory_kind=memory_kind) @use_cpp_class(xc.SingleDeviceSharding) @@ -139,7 +122,7 @@ def __init__(self, device: Device, *, memory_kind: str | None = None): self._memory_kind = memory_kind def __reduce__(self): - return type(self), (self._device,), {'memory_kind': self._memory_kind} + return (_unpickle_single_device_sharding, (self._device, self._memory_kind)) def __repr__(self): mem = '' if self._memory_kind is None else f', memory_kind={self._memory_kind}' @@ -183,10 +166,10 @@ def _device_assignment(self) -> XLADeviceAssignment: def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding: return replicated_hlo_sharding - def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding: - sdy_dim_sharding = [SdyDimSharding(axes=[], is_closed=True) + def _to_sdy_sharding(self, num_dimensions: int) -> SdyArray: + sdy_dim_sharding = [SdyDim(axes=[], is_open=False) for _ in range(num_dimensions)] - return SdyArraySharding(None, sdy_dim_sharding) + return SdyArray(mesh_shape=None, dim_shardings=sdy_dim_sharding) @property def is_fully_replicated(self) -> bool: @@ -194,10 +177,12 @@ def is_fully_replicated(self) -> bool: @property def is_fully_addressable(self) -> bool: - if config.enable_empty_arrays.value: - return xb.process_index(self._device.client) == self._device.process_index - return True + return xb.process_index(self._device.client) == self._device.process_index + + def check_compatible_aval(self, aval_shape: Shape) -> None: + return +SingleDeviceSharding.__module__ = 'jax.sharding' @util.cache(max_size=4096, trace_context_in_key=False) def pmap_sharding_devices_indices_map( @@ -209,7 +194,6 @@ def pmap_sharding_devices_indices_map( @use_cpp_class(xc.PmapSharding) class PmapSharding(jsharding.Sharding): - """Describes a sharding used by :func:`jax.pmap`.""" devices: np.ndarray sharding_spec: sharding_specs.ShardingSpec _internal_device_list: xc.DeviceList @@ -222,8 +206,7 @@ def __init__(self, devices: Sequence[Device] | np.ndarray, self.sharding_spec = sharding_spec def __reduce__(self): - return (type(self), (self.devices, self.sharding_spec), - {'memory_kind': self.memory_kind}) + return (type(self), (self.devices, self.sharding_spec)) def __eq__(self, other): if not isinstance(other, PmapSharding): @@ -327,8 +310,8 @@ def with_memory_kind(self, kind: str): def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding: raise NotImplementedError("pmap doesn't use OpSharding.") - def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding: - raise NotImplementedError("pmap doesn't use SdyArraySharding.") + def _to_sdy_sharding(self, num_dimensions: int) -> SdyArray: + raise NotImplementedError("pmap doesn't use SdyArray.") @functools.cached_property def is_fully_replicated(self) -> bool: @@ -341,6 +324,9 @@ def is_fully_replicated(self) -> bool: def is_fully_addressable(self) -> bool: return self._internal_device_list.is_fully_addressable + def check_compatible_aval(self, aval_shape: Shape) -> None: + return + def shard_shape(self, global_shape: Shape) -> Shape: sharded_dim = None sharded_dim_size = None @@ -366,237 +352,33 @@ def shard_shape(self, global_shape: Shape) -> Shape: f'the number of devices={len(self._device_assignment)}') return sharded_shape +PmapSharding.__module__ = 'jax.sharding' -def _op_sharding_to_pos_sharding( - op_sharding: xc.OpSharding | xc.HloSharding, - device_assignment: Sequence[xc.Device], - memory_kind: str | None = None) -> PositionalSharding: - if isinstance(op_sharding, xc.OpSharding): - op_sharding = xc.HloSharding.from_proto(op_sharding) - - if op_sharding.is_replicated(): - return PositionalSharding( - device_assignment, memory_kind=memory_kind).replicate() - - if len(op_sharding.subgroup_types()) > 1: - raise NotImplementedError( - 'Unhandled HloSharding type. Please open a bug report!' - ) - - name = device_assignment[0].platform.upper() - ids = np.array( - [DeviceIdSet(name, i) for i in op_sharding.tile_assignment_devices()] - ) - p = PositionalSharding._remake(tuple(device_assignment), ids, - memory_kind=memory_kind) - p = p.reshape(op_sharding.tile_assignment_dimensions()) - if op_sharding.replicate_on_last_tile_dim(): - p = p.replicate(-1, keepdims=False) - return p - - -@util.cache(max_size=4096, trace_context_in_key=False) -def _positional_sharding_to_xla_hlo_sharding( - self, num_dimensions: int) -> xc.HloSharding: - if self.shape == (1,) * self.ndim: - return replicated_hlo_sharding - - pbuf = xc.OpSharding() - shape = self.shape[self.ndim - num_dimensions:] # 'rank promotion' of val - set_size, = {len(device_set) for device_set in self._ids.flat} - pbuf.type = xc.OpSharding.Type.OTHER - if set_size > 1: - pbuf.last_tile_dims = [xc.OpSharding.Type.REPLICATED] - pbuf.tile_assignment_dimensions = (*shape, set_size) - else: - pbuf.tile_assignment_dimensions = shape - pbuf.tile_assignment_devices = [i for ids in self._ids.flat for i in ids] - product_of_dims = math.prod(pbuf.tile_assignment_dimensions) - num_devices = len(pbuf.tile_assignment_devices) - assert product_of_dims == num_devices, (product_of_dims, num_devices) - return xc.HloSharding.from_proto(pbuf) - - -class PositionalSharding(jsharding.Sharding): - _devices: tuple[xc.Device, ...] - _memory_kind: str | None - _ids: np.ndarray # dtype DeviceIdSet - - def __init__(self, devices: Sequence[xc.Device] | np.ndarray, - *, memory_kind: str | None = None): - super().__init__() - if not isinstance(devices, np.ndarray): - devices = np.array(devices, dtype='object') - if not devices.size: - raise ValueError(f"{self.__class__.__name__}.__init__ requires at least " - f"one device, got {devices}") - self._devices = tuple(devices.flat) - self._memory_kind = memory_kind - name = self._devices[0].platform.upper() - self._ids = np.array([DeviceIdSet(name, i) for i in range(devices.size)], - dtype='object').reshape(devices.shape) - self._internal_device_list = xc.DeviceList(self._devices) - self._memory_kind = xc.check_and_canonicalize_memory_kind( - self._memory_kind, self._internal_device_list) - - @property - def shape(self): - return self._ids.shape - - @property - def ndim(self): - return self._ids.ndim - - def __repr__(self) -> str: - cls_name = self.__class__.__name__ - ids = self._ids.copy() - platform_name = self._devices[0].platform.upper() - for idx, x in np.ndenumerate(ids): - ids[idx] = DeviceIdSet(platform_name, *(self._devices[i].id for i in x)) - body = np.array2string(ids, prefix=cls_name + '(', suffix=')', - max_line_width=100) - mem = '' if self._memory_kind is None else f', memory_kind={self._memory_kind}' - return f'{cls_name}({body}{mem}, shape={self.shape})' - - def reshape(self, *shape) -> PositionalSharding: - return self._remake(self._devices, self._ids.reshape(*shape), - memory_kind=self.memory_kind) - - def transpose(self, *axes) -> PositionalSharding: - return self._remake(self._devices, self._ids.transpose(*axes), - memory_kind=self.memory_kind) - T = property(transpose) - - def replicate(self, axis=None, keepdims=True) -> PositionalSharding: - new_ids = self._ids.sum(axis=axis, keepdims=keepdims) # union - return self._remake(self._devices, new_ids, - memory_kind=self.memory_kind) - - def check_compatible_aval(self, aval_shape: Shape) -> None: - if len(aval_shape) != len(self.shape) and not self.is_fully_replicated: - raise ValueError( - f"Sharding {self} is only valid for values of rank " - f"{len(self.shape)}, but was applied to a value of rank " - f"{len(aval_shape)}") - - @classmethod - def _remake( - cls, devices: tuple[xc.Device, ...], ids: np.ndarray, - *, memory_kind: str | None = None) -> PositionalSharding: - sharding = cls(devices, memory_kind=memory_kind) - sharding._ids = ids - return sharding - - # Hashable - - def __hash__(self) -> int: - if not hasattr(self, '_hash'): - self._hash = hash((self._internal_device_list, self.memory_kind)) - return self._hash - - def __eq__(self, other) -> bool: - if not isinstance(other, PositionalSharding): - return False - if self is other: - return True - all_ids_equal = np.array_equal(self._ids,other._ids) - mem_kind_equal = self.memory_kind == other.memory_kind - if self._devices is other._devices and mem_kind_equal and all_ids_equal: - return True - return (mem_kind_equal and all_ids_equal and - self._internal_device_list == other._internal_device_list) - - # Sharding interface - - @property - def num_devices(self) -> int: - return len(self.device_set) - - @functools.cached_property - def device_set(self) -> set[xc.Device]: - return set(self._devices) - - @property - def memory_kind(self) -> str | None: - return self._memory_kind - - def with_memory_kind(self, kind: str) -> PositionalSharding: - return PositionalSharding(self._devices, memory_kind=kind) - - @functools.cached_property - def is_fully_replicated(self) -> bool: - return self.shape == (1,) * self.ndim - - # jsharding.Sharding interface - - @property - def _device_assignment(self) -> XLADeviceAssignment: - return self._devices - - def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding: - return _positional_sharding_to_xla_hlo_sharding(self, num_dimensions) - - def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding: - raise NotImplementedError( - "PositionalSharding can't be converted to an SdyArraySharding.") - - @functools.cached_property - def is_fully_addressable(self) -> bool: - return self._internal_device_list.is_fully_addressable - - -class DeviceIdSet: - _name: str - _ids: frozenset[int] - def __init__(self, name, *ids): - self._name = name - self._ids = frozenset(ids) - - def __iter__(self): - return iter(sorted(self._ids)) - - def __add__(self, other) -> DeviceIdSet: - assert isinstance(other, DeviceIdSet) - return DeviceIdSet(self._name, *(self._ids | other._ids)) - - def __len__(self) -> int: - return len(self._ids) - - def __repr__(self) -> str: - ids = ', '.join(safe_map(str, sorted(self._ids))) - return f'{{{self._name} {ids}}}' - - def __hash__(self) -> int: - return hash((self._name, self._ids)) - - def __eq__(self, other) -> bool: - return (isinstance(other, DeviceIdSet) and self._name == other._name and - self._ids == other._ids) +def _unpickle_gspmd_sharding(devices, op_sharding, memory_kind): + return GSPMDSharding(devices, op_sharding, memory_kind=memory_kind) @use_cpp_class(xc.GSPMDSharding) class GSPMDSharding(jsharding.Sharding): - _devices: tuple[Device, ...] + _devices: xc.DeviceList _hlo_sharding: xc.HloSharding _memory_kind: str | None - _device_list: xc.DeviceList | None _internal_device_list: xc.DeviceList @use_cpp_method() - def __init__(self, devices: Sequence[Device], + def __init__(self, devices: Sequence[Device] | xc.DeviceList, op_sharding: xc.OpSharding | xc.HloSharding, - *, memory_kind: str | None = None, - _device_list: xc.DeviceList | None = None): - self._devices = tuple(devices) - if isinstance(op_sharding, xc.OpSharding): - self._hlo_sharding = xc.HloSharding.from_proto(op_sharding) - else: - self._hlo_sharding = op_sharding + *, memory_kind: str | None = None): + self._devices = (devices if isinstance(devices, xc.DeviceList) else + xc.DeviceList(tuple(devices))) + self._hlo_sharding = (xc.HloSharding.from_proto(op_sharding) + if isinstance(op_sharding, xc.OpSharding) else + op_sharding) self._memory_kind = memory_kind def __reduce__(self): - return (type(self), (self._devices, self._hlo_sharding.to_proto()), - {'memory_kind': self._memory_kind}) + return (_unpickle_gspmd_sharding, + (self._devices, self._hlo_sharding.to_proto(), self._memory_kind)) @functools.cached_property def _hlo_sharding_hash(self): @@ -609,7 +391,7 @@ def __eq__(self, other): return False if self is other: return True - return (are_op_shardings_equal(self._hlo_sharding, other._hlo_sharding) + return (are_hlo_shardings_equal(self._hlo_sharding, other._hlo_sharding) and self.memory_kind == other.memory_kind and self._internal_device_list == other._internal_device_list) @@ -633,7 +415,7 @@ def check_compatible_aval(self, aval_shape: Shape) -> None: @property def num_devices(self) -> int: - return len(self.device_set) + return len(self._internal_device_list) @functools.cached_property def device_set(self) -> set[Device]: @@ -648,18 +430,35 @@ def with_memory_kind(self, kind: str) -> GSPMDSharding: @property def _device_assignment(self) -> XLADeviceAssignment: - return self._devices + return tuple(self._devices) def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding: return self._hlo_sharding - def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding: - raise NotImplementedError( - "GSPMDSharding can't be converted to SdyArraySharding.") + def _to_sdy_sharding(self, num_dimensions: int) -> SdyArray: + if self._hlo_sharding.tuple_elements(): + raise TypeError( + f'Cannot convert GSPMDSharding {self._hlo_sharding} into SdyArray.') + elif self._hlo_sharding.is_replicated(): + empty_mesh = mesh_lib.AbstractMesh((), ()) + return NamedSharding(empty_mesh, PartitionSpec())._to_sdy_sharding( + num_dimensions) + elif self._hlo_sharding.is_tiled(): + if not self._hlo_sharding.is_tile_assignment_iota(): + raise TypeError( + f'Cannot convert GSPMDSharding {self._hlo_sharding} into SdyArray.') + axis_sizes = tuple(self._hlo_sharding.get_axis_sizes()) + axis_names = tuple(f'_axis_{i}' for i in range(len(axis_sizes))) + mesh = mesh_lib.AbstractMesh(axis_sizes, axis_names) + return _gspmd_to_named_sharding_via_mesh(self, mesh)._to_sdy_sharding( + num_dimensions) + else: + raise TypeError( + f'Cannot convert GSPMDSharding {self._hlo_sharding} into SdyArray.') @functools.cached_property def is_fully_replicated(self) -> bool: - return is_op_sharding_replicated(self._hlo_sharding) + return is_hlo_sharding_replicated(self._hlo_sharding) @functools.cached_property def is_fully_addressable(self) -> bool: @@ -667,7 +466,7 @@ def is_fully_addressable(self) -> bool: @classmethod def get_replicated(cls, device_assignment, *, memory_kind: str | None = None): - return cls(tuple(device_assignment), replicated_hlo_sharding, + return cls(device_assignment, replicated_hlo_sharding, memory_kind=memory_kind) @@ -676,7 +475,6 @@ def get_replicated(cls, device_assignment, *, memory_kind: str | None = None): def prepare_axis_resources(axis_resources, arg_name, allow_unconstrained_dims=False): - # PyTrees don't treat None values as leaves, so we use an is_leaf function. entries, treedef = tree_util.tree_flatten( axis_resources, is_leaf=lambda x: x is None) what = f"{arg_name} leaf specifications" @@ -689,14 +487,23 @@ def prepare_axis_resources(axis_resources, arg_name, if isinstance(entry, PmapSharding): raise ValueError(f'One of {what} got sharding {entry} which is not ' 'allowed.') + if isinstance(entry, NamedSharding) and entry.mesh.empty: + raise ValueError(f'One of {what} got an empty NamedSharding: {entry} ' + 'which is not allowed.') + if (not allow_unconstrained_dims and isinstance(entry, NamedSharding) and + PartitionSpec.UNCONSTRAINED in entry.spec): + raise ValueError( + f'Unconstrained dims are not allowed when passed to {arg_name}:' + f' {entry}') new_entries.append(entry) else: if not isinstance(entry, PartitionSpec): raise TypeError(f"{what} are expected to be " f"PartitionSpec instances or None, but got {entry}") - for e in entry: - if e is PartitionSpec.UNCONSTRAINED and not allow_unconstrained_dims: - raise ValueError(f"Unconstrained dims are not allowed: {entry}") + if not allow_unconstrained_dims and PartitionSpec.UNCONSTRAINED in entry: + raise ValueError( + f'Unconstrained dims are not allowed when passed to {arg_name}:' + f' {entry}') _check_unique_resources(entry, arg_name) new_entries.append(entry) @@ -882,8 +689,7 @@ def parse_flatten_op_sharding( return out elif hlo_sharding.is_replicated(): return [PartitionSpec()] - elif (xla_extension_version >= 319 and hlo_sharding.is_maximal() - and mesh.size == 1): + elif hlo_sharding.is_maximal() and mesh.size == 1: return [PartitionSpec()] elif hlo_sharding.is_tiled(): mesh_shape = mesh.shape @@ -898,7 +704,11 @@ def parse_flatten_op_sharding( while dim_size > 1: axis = next(mesh_axis) axis_size = mesh_shape[axis] - assert dim_size % axis_size == 0 + if dim_size % axis_size != 0: + raise ValueError( + f'{shape=} is incompatible with {mesh_shape=}: ' + f'{dim_size=} is not divisible by {axis_size=}.' + ) dim_size //= axis_size dim_partitions.append(axis) partitions.append(tuple(dim_partitions)) @@ -924,6 +734,7 @@ class NonUniformShardingError(ValueError): """Raised when sharding is not uniform across processes.""" +@util.cache(max_size=4096, trace_context_in_key=False) def get_process_index_and_count( tensor_sharding: jsharding.Sharding, dim: int, ndims: int) -> tuple[int, int]: """Get current process index and number of unique processes for given dimension. @@ -1169,15 +980,14 @@ def make_key_array_phys_sharding(aval, sharding): elif isinstance(sharding, NamedSharding): elt_aval = core.physical_element_aval(aval.dtype) trailing_spec = [None] * elt_aval.ndim - return sharding.with_spec(PartitionSpec(*sharding.spec, *trailing_spec)) + return sharding.update(spec=PartitionSpec(*sharding.spec, *trailing_spec)) else: hlos = sharding._to_xla_hlo_sharding(aval.ndim) return GSPMDSharding( - sharding._device_assignment, physical_hlo_sharding(aval, hlos)) + sharding._internal_device_list, physical_hlo_sharding(aval, hlos)) -def physical_sharding( - aval, sharding: jsharding.Sharding) -> jsharding.Sharding: +def physical_sharding(aval, sharding: jsharding.Sharding) -> jsharding.Sharding: return make_key_array_phys_sharding(aval, sharding) @@ -1191,7 +1001,7 @@ def get_logical_gspmd_sharding(logical_shape, dtype, phys_sharding): logical_op_sharding = phys_hlo_sharding.to_proto().clone() tad = partitions[:-elt_aval.ndim] + suffix logical_op_sharding.tile_assignment_dimensions = tad - return GSPMDSharding(phys_sharding._device_assignment, + return GSPMDSharding(phys_sharding._internal_device_list, xc.HloSharding.from_proto(logical_op_sharding)) def check_replicated_trailing_dims(sharding: jsharding.Sharding, @@ -1231,43 +1041,39 @@ def logical_sharding(logical_shape, dtype, phys_sharding) -> jsharding.Sharding: phys_spec = (*phys_sharding.spec, *[None] * (len(phys_shape) - len(phys_sharding.spec))) else: - phys_spec = phys_sharding.spec - return phys_sharding.with_spec(phys_spec[:-elt_aval.ndim]) + phys_spec = phys_sharding.spec # type: ignore + return phys_sharding.update(spec=phys_spec[:-elt_aval.ndim]) else: return get_logical_gspmd_sharding(logical_shape, dtype, phys_sharding) @util.cache() -def create_mesh_pspec_sharding( - mesh: mesh_lib.Mesh, pspec: PartitionSpec | None, +def cached_named_sharding( + mesh: mesh_lib.Mesh | mesh_lib.AbstractMesh, pspec: PartitionSpec, memory_kind: str | None = None) -> NamedSharding: - if pspec is None: - pspec = PartitionSpec() return NamedSharding(mesh, pspec, memory_kind=memory_kind) def _gspmd_to_named_sharding_via_mesh( - out_s: GSPMDSharding, mesh: mesh_lib.Mesh) -> NamedSharding: + out_s: GSPMDSharding, mesh: mesh_lib.Mesh | mesh_lib.AbstractMesh +) -> NamedSharding: spec = parse_flatten_op_sharding(out_s._hlo_sharding, mesh)[0] - return create_mesh_pspec_sharding( - mesh, spec, memory_kind=out_s.memory_kind) - -def flatten_spec(spec): - out = [] - for s in spec: - if isinstance(s, tuple): - out.extend(s) - else: - out.append(s) - return out + return cached_named_sharding(mesh, spec, out_s.memory_kind) + +@util.cache() def canonicalize_sharding(sharding: NamedSharding | PartitionSpec | None, api_name: str, check_mesh_consistency: bool = True ) -> NamedSharding | None: if sharding is None: - return sharding + return None if isinstance(sharding, NamedSharding) and sharding.mesh.empty: return None + if not isinstance(sharding, (NamedSharding, PartitionSpec)): + raise TypeError( + f"`out_sharding` argument of {api_name} only supports instances of" + f" `NamedSharding` or `PartitionSpec`. Got {sharding} of type:" + f" {type(sharding)}") cur_mesh = mesh_lib.get_abstract_mesh() if isinstance(sharding, PartitionSpec): @@ -1275,15 +1081,15 @@ def canonicalize_sharding(sharding: NamedSharding | PartitionSpec | None, raise ValueError( 'Using PartitionSpec when you are not under a mesh context is not' ' allowed. Please pass a NamedSharding instance or enter into a mesh' - f' context via `jax.sharding.use_mesh`. Got {sharding}') + f' context via `jax.set_mesh`. Got {sharding}') sharding = NamedSharding(cur_mesh, sharding) else: # There are cases when you have multiple meshes set. Allow that for full # auto mode because of existing use cases. # TODO(yashkatariya): Remove this once we disallow different meshes and # fix the existing use cases. - if (sharding.mesh.abstract_mesh._are_all_axes_auto and - cur_mesh._are_all_axes_auto): + if (sharding.mesh.abstract_mesh.are_all_axes_auto and + cur_mesh.are_all_axes_auto): check_mesh_consistency = False if (check_mesh_consistency and not cur_mesh.empty and sharding.mesh.abstract_mesh != cur_mesh): @@ -1292,6 +1098,8 @@ def canonicalize_sharding(sharding: NamedSharding | PartitionSpec | None, f' {sharding.mesh.abstract_mesh} passed to {api_name}.' ' This error occurs at source: ' f' {source_info_util.summarize(source_info_util.current())}') + # TODO(yashkatariya): Maybe allow concrete mesh at the top level + # i.e `core.trace_state_clean()` for APIs like jnp.zeros, etc? if isinstance(sharding.mesh, mesh_lib.Mesh): sharding = NamedSharding(sharding.mesh.abstract_mesh, sharding.spec) @@ -1304,15 +1112,14 @@ def canonicalize_sharding(sharding: NamedSharding | PartitionSpec | None, f'PartitionSpec passed to {api_name} cannot contain axis' ' names that are of type Auto or Manual. Got PartitionSpec:' f' {sharding.spec} with axis name: {s} of type:' - f' {sharding.mesh._name_to_type[s]}, This error occurs at source: ' + f' {sharding.mesh._name_to_type[s]}. This error occurs at source: ' f' {source_info_util.summarize(source_info_util.current())}') return sharding def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str], - *, devices: Sequence[xc.Device] | None = None, - axis_types: tuple[mesh_lib.AxisType, ...] | None = None - ) -> mesh_lib.Mesh: + axis_types: tuple[mesh_lib.AxisType, ...] | None = None, + *, devices: Sequence[xc.Device] | None = None) -> mesh_lib.Mesh: """Creates an efficient mesh with the shape and axis names specified. This function attempts to automatically compute a good mapping from a set of @@ -1345,11 +1152,16 @@ def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str], Args: axis_shapes: Shape of the mesh. For example, axis_shape=(4, 2) axis_names: Names of the mesh axes. For example, axis_names=('x', 'y') + axis_types: Optional tuple of :class:`jax.sharding.AxisType` entries + corresponding to the ``axis_names``. See `Explicit Sharding`_ for more + information. devices: Optional keyword only argument, that allows you to specify the devices you want to create a mesh with. Returns: - A `jax.sharding.Mesh` object. + A :class:`jax.sharding.Mesh` object. + + .. _Explicit Sharding: https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html """ if devices is None: devices = xb.devices() @@ -1374,49 +1186,74 @@ def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str], mesh_devices = mesh_utils.create_device_mesh( new_axis_shapes, devices, allow_split_physical_axes=allow_split_physical_axes) + first_d = mesh_devices.flat[0] + if (first_d.platform == 'tpu' and hasattr(first_d, 'slice_index') and + len({d.slice_index for d in mesh_devices.flat}) > 1): + raise ValueError( + '`jax.make_mesh` does not support multi-slice topologies. Please use' + ' jax.experimental.mesh_utils.create_hybrid_device_mesh') + if axis_types is None: + axis_types = (mesh_lib.AxisType.Explicit,) * len(mesh_devices.shape) return mesh_lib.Mesh(mesh_devices, axis_names, axis_types=axis_types) +class set_mesh: + """Sets a concrete mesh in a thread-local context. -@contextlib.contextmanager -def use_mesh(mesh: mesh_lib.Mesh): - if not isinstance(mesh, mesh_lib.Mesh): - raise ValueError( - f"Expected mesh of type `jax.sharding.Mesh`. Got {type(mesh)}") + ``jax.set_mesh`` has dual behavior. You can use it as a global setter or as a + context manager. - # TODO(yashkatariya): Enable this. - # if not core.trace_state_clean(): - # raise ValueError('`use_mesh` can only be used outside of `jax.jit`') + When a mesh is in context via ``jax.set_mesh``, you can use pass + raw PartitionSpecs to all APIs that accept sharding as an argument. + Using ``jax.set_mesh`` is also required for enabling explicit sharding mode: + https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html - with mesh_lib.use_abstract_mesh(mesh.abstract_mesh), use_concrete_mesh(mesh): - yield + For example:: -def set_mesh(mesh: mesh_lib.Mesh | None) -> mesh_lib.Mesh | None: - """Sets the given concrete mesh globally and returns the previous concrete - mesh.""" - if mesh is not None and not isinstance(mesh, mesh_lib.Mesh): - raise ValueError( - f"Expected mesh of type `jax.sharding.Mesh`. Got {type(mesh)}") - if not core.trace_state_clean(): - raise ValueError('`set_mesh` can only be used outside of `jax.jit`.') + mesh = jax.make_mesh((2,), ('x',)) + jax.set_mesh(mesh) # use the API as a global setter - if mesh is None: - config.abstract_mesh_context_manager.set_global(mesh_lib.empty_abstract_mesh) # type: ignore - else: - config.abstract_mesh_context_manager.set_global(mesh.abstract_mesh) # type: ignore + with jax.set_mesh(mesh): # use the API as a context manager + ... - prev_mesh = config.device_context.get_global() - config.device_context.set_global(mesh) - return prev_mesh + Note: ``jax.set_mesh`` can only be used outside of ``jax.jit``. + """ + __slots__ = ["prev_abstract_mesh", "prev_mesh"] -@contextlib.contextmanager -def use_concrete_mesh(mesh: mesh_lib.Mesh | None): - if mesh is not None and not isinstance(mesh, mesh_lib.Mesh): + def __init__(self, mesh: mesh_lib.Mesh): + if not isinstance(mesh, mesh_lib.Mesh): + raise ValueError( + f"Expected mesh of type `jax.sharding.Mesh`. Got {type(mesh)}") + if not core.trace_state_clean(): + raise ValueError('`set_mesh` can only be used outside of `jax.jit`.') + if mesh._any_axis_manual: + raise ValueError( + f'mesh {mesh} contains manual axes which is not allowed when using' + ' `jax.set_mesh`. Please use `jax.shard_map` to enter into `Manual`' + ' mode instead.') + + self.prev_abstract_mesh = config.abstract_mesh_context_manager.swap_local( + mesh.abstract_mesh) + self.prev_mesh = config.device_context.swap_local(mesh) + + def __enter__(self): + pass + + def __exit__(self, exc_type, exc_value, traceback): + config.abstract_mesh_context_manager.set_local(self.prev_abstract_mesh) + config.device_context.set_local(self.prev_mesh) + + +def get_mesh() -> mesh_lib.Mesh: + if not core.trace_state_clean(): raise ValueError( - f"Expected mesh of type `jax.sharding.Mesh`. Got {type(mesh)}") - # TODO(yashkatariya): Enable this. - # if not core.trace_state_clean(): - # raise ValueError('`use_concrete_mesh` can only be used outside of `jax.jit`.') + '`get_mesh` can only be used outside of `jax.jit`. Maybe you want' + ' `jax.sharding.get_abstract_mesh()`?') + return mesh_lib.get_concrete_mesh() + +@contextlib.contextmanager +def _internal_use_concrete_mesh(mesh: mesh_lib.Mesh): + assert isinstance(mesh, mesh_lib.Mesh) prev_val = config.device_context.swap_local(mesh) try: yield diff --git a/jax/_src/sharding_specs.py b/jax/_src/sharding_specs.py index 7092b51ab894..e498923fa4b6 100644 --- a/jax/_src/sharding_specs.py +++ b/jax/_src/sharding_specs.py @@ -36,7 +36,6 @@ import numpy as np -from jax._src import config from jax._src import util from jax._src.lib import pmap_lib @@ -172,12 +171,8 @@ def shift_sharded_axis(a: MeshDimAssignment): return a # replication_factor represents the product of inner pmaps, so it goes # after the outer pmapped axis at index 0 - if config.pmap_no_rank_reduction.value: - sharding = util.tuple_update( - pspec.sharding, map_axis, Chunked([axis_size])) - else: - sharding = util.tuple_insert( - pspec.sharding, map_axis, Unstacked(axis_size)) + sharding = util.tuple_update( + pspec.sharding, map_axis, Chunked([axis_size])) return ShardingSpec( sharding=sharding, mesh_mapping=itertools.chain( @@ -192,10 +187,7 @@ def shift_sharded_axis(a: MeshDimAssignment): def create_pmap_sharding_spec(shape: tuple[int, ...], sharded_dim: int = 0, sharded_dim_size: int | None = None): if sharded_dim is not None: - if config.pmap_no_rank_reduction.value: - sharded_shape = util.tuple_update(shape, sharded_dim, 1) - else: - sharded_shape = util.tuple_delete(shape, sharded_dim) + sharded_shape = util.tuple_update(shape, sharded_dim, 1) if sharded_dim_size is None: sharded_dim_size = shape[sharded_dim] else: diff --git a/jax/_src/source_info_util.py b/jax/_src/source_info_util.py index b1901f44f022..3f077d82ac09 100644 --- a/jax/_src/source_info_util.py +++ b/jax/_src/source_info_util.py @@ -21,13 +21,11 @@ import itertools import os.path import re -import sys import sysconfig import threading import types from typing import NamedTuple -import jax.version from jax._src.lib import xla_client from jax._src import traceback_util @@ -48,11 +46,11 @@ class Frame(NamedTuple): _exclude_paths: list[str] = [ # Attach the separator to make sure that .../jax does not end up matching # .../jax_triton and other packages that might have a jax prefix. - os.path.dirname(jax.version.__file__) + os.sep, + os.path.dirname(os.path.dirname(__file__)) + os.sep, # Also exclude stdlib as user frames. In a non-standard Python runtime, - # the following two may be different. + # the following may be different. sysconfig.get_path('stdlib'), - os.path.dirname(sysconfig.__file__) + os.path.dirname(contextlib.__file__), ] @functools.cache @@ -95,6 +93,9 @@ class Transform(NamedTuple): def wrap(self, stack: list[str]): if stack: stack[-1] = f'{self.name}({stack[-1]})' + else: + stack.append(f'{self.name}()') + @dataclasses.dataclass(frozen=True) class NameStack: @@ -159,23 +160,17 @@ def is_user_filename(filename: str) -> bool: return (_include_path_regex().search(filename) is not None or _exclude_path_regex().search(filename) is None) -if sys.version_info >= (3, 11): - def raw_frame_to_frame(code: types.CodeType, lasti: int) -> Frame: - loc = xla_client.Traceback.code_addr2location(code, lasti) - start_line, start_column, end_line, end_column = loc - return Frame(file_name=code.co_filename, - function_name=code.co_qualname, - start_line=start_line, start_column=start_column, - end_line=end_line, end_column=end_column) -else: - def raw_frame_to_frame(code: types.CodeType, lasti: int) -> Frame: - # pre-3.11 co_qualname does not exist, use co_name - return Frame(file_name=code.co_filename, - function_name=code.co_name, - start_line=xla_client.Traceback.code_addr2line(code, lasti), - start_column=0, end_line=0, end_column=0) - -def user_frames(source_info: SourceInfo) -> Iterator[Frame]: + +def raw_frame_to_frame(code: types.CodeType, lasti: int) -> Frame: + loc = xla_client.Traceback.code_addr2location(code, lasti) + start_line, start_column, end_line, end_column = loc + return Frame(file_name=code.co_filename, + function_name=code.co_qualname, + start_line=start_line, start_column=start_column, + end_line=end_line, end_column=end_column) + + +def user_frames(traceback: Traceback | None) -> Iterator[Frame]: """Iterator over the user's frames, filtering jax-internal frames.""" # Guess the user's frame is the innermost frame not in the jax source tree or # Python stdlib. We don't use traceback_util.path_starts_with because that @@ -183,14 +178,13 @@ def user_frames(source_info: SourceInfo) -> Iterator[Frame]: # e.g. adding source provenance annotations to XLA lowerings, so we don't # want to incur the cost. We consider files that end with _test.py as user # frames, to allow testing this mechanism from tests. - traceback = source_info.traceback code, lasti = traceback.raw_frames() if traceback else ([], []) return (raw_frame_to_frame(code[i], lasti[i]) for i in range(len(code)) if is_user_filename(code[i].co_filename)) @functools.lru_cache(maxsize=64) -def user_frame(source_info: SourceInfo) -> Frame | None: - return next(user_frames(source_info), None) +def user_frame(traceback: Traceback | None) -> Frame | None: + return next(user_frames(traceback), None) def _summarize_frame(frame: Frame) -> str: if frame.start_column != 0: @@ -200,7 +194,7 @@ def _summarize_frame(frame: Frame) -> str: return f"{frame.file_name}:{frame.start_line} ({frame.function_name})" def summarize(source_info: SourceInfo, num_frames=1) -> str: - frames = itertools.islice(user_frames(source_info), num_frames) + frames = itertools.islice(user_frames(source_info.traceback), num_frames) frame_strs = [_summarize_frame(frame) if frame else "unknown" for frame in frames] return '\n'.join(reversed(frame_strs)) diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 19cd0822aa58..d64ffa6941ef 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -30,68 +30,84 @@ """ from __future__ import annotations -import functools +import dataclasses +import enum from collections.abc import Sequence from dataclasses import dataclass +import itertools as it from typing import Any, NamedTuple, Protocol, Union, runtime_checkable -import jax - from jax._src import core from jax._src import config +from jax._src import sharding as sharding_lib from jax._src import source_info_util from jax._src import traceback_util from jax._src import tree_util from jax._src import util -from jax._src.sharding_impls import UnspecifiedValue, AUTO -from jax._src.layout import Layout +from jax._src.typing import ArrayLike from jax._src.interpreters import mlir +from jax._src.layout import Format, Layout, AutoLayout +from jax._src.sharding_impls import UnspecifiedValue, AUTO from jax._src.lib.mlir import ir +from jax._src.lib import _jax from jax._src.lib import xla_client as xc +from jax._src.tree_util import tree_structure, tree_unflatten +from jax._src.core import typeof source_info_util.register_exclusion(__file__) traceback_util.register_exclusion(__file__) -xla_extension = xc._xla map, unsafe_map = util.safe_map, map zip, unsafe_zip = util.safe_zip, zip CompilerOptions = dict[str, Union[str, bool]] -# -- Internal protocols +# -- Internal types + -class Executable(Protocol): - """Protocol for executables, which a user-facing ``Compiled`` encapsulates.""" +class Executable: + + def xla_extension_executable(self) -> xc.LoadedExecutable: + raise NotImplementedError( + "compiled executable carries no loaded XLA executable. It may be " + f"that {type(self)} defines an incomplete implementation.") def call(self, *args_flat) -> Sequence[Any]: """Execute on the flat list of arguments, returning flat outputs.""" - # TODO(frostig): improve annotation (sequences of arrays/buffers) - raise NotImplementedError + raise NotImplementedError("compiled executable does not support invocation") - def input_shardings(self) -> Sequence[jax.sharding.Sharding]: + def create_cpp_call(self, params: CompiledCallParams) -> Any: + """Optionally constructs a fast c++ dispatcher.""" + return None + + def input_shardings(self) -> Sequence[sharding_lib.Sharding]: """Flat sequence of input shardings. May raise ``NotImplementedError`` if unavailable, e.g. based on backend, compiler, or runtime. """ - raise NotImplementedError + raise NotImplementedError( + "compiled executable carries no input sharding information") - def output_shardings(self) -> Sequence[jax.sharding.Sharding]: + def output_shardings(self) -> Sequence[sharding_lib.Sharding]: """Flat sequence of output shardings. May raise ``NotImplementedError`` if unavailable, e.g. based on backend, compiler, or runtime. """ - raise NotImplementedError + raise NotImplementedError( + "compiled executable carries no output sharding information") - def input_layouts(self): - raise NotImplementedError + def input_formats(self): + raise NotImplementedError( + "compiled executable carries no input layout information") - def output_layouts(self): - raise NotImplementedError + def output_formats(self): + raise NotImplementedError( + "compiled executable carries no output layout information") def as_text(self) -> str: """A human-readable text representation of this executable. @@ -102,89 +118,30 @@ def as_text(self) -> str: May raise ``NotImplementedError`` if unavailable, e.g. based on backend, compiler, or runtime. """ - raise NotImplementedError - - def cost_analysis(self) -> Any: - """A summary of execution cost estimates. - - Intended for visualization and debugging purposes. The object output by - this is some simple data structure that can easily be printed or serialized - (e.g. nested dicts, lists, and tuples with numeric leaves). However, its - structure can be arbitrary: it need not be consistent across versions of JAX - and jaxlib, or even across invocations. It is relayed directly to external - callers. - - May raise ``NotImplementedError`` if unavailable, e.g. based on backend, - compiler, or runtime. - """ - # TODO(frostig): improve annotation (arbitrary pytree) - raise NotImplementedError - - def memory_analysis(self) -> Any: - """A summary of estimated memory requirements. - - Intended for visualization and debugging purposes. The object output by - this is some simple data structure that can easily be printed or serialized - (e.g. nested dicts, lists, and tuples with numeric leaves). However, its - structure can be arbitrary: it need not be consistent across versions of JAX - and jaxlib, or even across invocations. It is relayed directly to external - callers. - - May raise ``NotImplementedError`` if unavailable, e.g. based on backend, - compiler, or runtime. - """ - # TODO(frostig): improve annotation (arbitrary pytree) - raise NotImplementedError - - def runtime_executable(self) -> Any: - """An arbitrary object representation of this executable. - - Intended for debugging purposes. This need not be a valid nor reliable - serialization. It is relayed directly to external callers, with no - guarantee on type, structure, or consistency across invocations. - - May raise ``NotImplementedError`` if unavailable, e.g. based on backend or - compiler. - """ - raise NotImplementedError - - def create_cpp_call(self, no_kwargs, in_tree, out_tree) -> Any: - """Optionally constructs a fast c++ dispatcher.""" - return None - - -class Lowering(Protocol): - """Protocol for lowerings, which a user-facing ``Lowered`` encapsulates.""" - - def compile( - self, compiler_options: CompilerOptions | None = None) -> Executable: - """Compile and return a corresponding ``Executable``.""" - raise NotImplementedError - - def as_text(self, dialect: str | None = None, *, - debug_info: bool = False) -> str: - """A human-readable text representation of this lowering. - - Intended for visualization and debugging purposes. This need not be a valid - nor reliable serialization. It is relayed directly to external callers. - """ - raise NotImplementedError - - def compiler_ir(self, dialect: str | None = None) -> Any: - """An arbitrary object representation of this lowering. - - Intended for debugging purposes. This need not be a valid nor reliable - serialization. It is relayed directly to external callers, with no - guarantee on type, structure, or consistency across invocations. - - May raise ``NotImplementedError`` if unavailable, e.g. based on backend or - compiler. + xla_ext_exe = self.xla_extension_executable() + err_msg = ("text view unsupported on current XLA backend: " + f"{type(xla_ext_exe)}") - Args: - dialect: Optional string specifying a representation dialect - (e.g. "stablehlo") - """ - raise NotImplementedError + if hasattr(xla_ext_exe, "get_hlo_text"): + try: + return xla_ext_exe.get_hlo_text() + except _jax.JaxRuntimeError as e: + msg, *_ = e.args + if type(msg) is str and msg.startswith("UNIMPLEMENTED"): + raise NotImplementedError(err_msg) from e + else: + raise + else: + if not hasattr(xla_ext_exe, "hlo_modules"): + raise NotImplementedError(err_msg) + try: + return "\n\n".join([m.to_string() for m in xla_ext_exe.hlo_modules()]) + except _jax.JaxRuntimeError as e: + msg, *_ = e.args + if type(msg) is str and msg.startswith("UNIMPLEMENTED"): + raise NotImplementedError(err_msg) from e + else: + raise def cost_analysis(self) -> Any: """A summary of execution cost estimates. @@ -196,66 +153,15 @@ def cost_analysis(self) -> Any: and jaxlib, or even across invocations. It is relayed directly to external callers. - This function estimates execution cost in the absence of compiler - optimizations, which may drastically affect the cost. For execution cost - estimates after optimizations, compile this lowering and see - ``Compiled.cost_analysis``. - May raise ``NotImplementedError`` if unavailable, e.g. based on backend, compiler, or runtime. """ - # TODO(frostig): improve annotation (arbitrary pytree) - raise NotImplementedError - - -# -- Internal adapters from XLA-related objects to the above protocols - -class XlaExecutable(Executable): - - def xla_extension_executable(self) -> xc.LoadedExecutable: - raise NotImplementedError("must override") - - def call(self, *args_flat) -> Sequence[Any]: - raise NotImplementedError("must override") - - def input_shardings(self) -> Sequence[jax.sharding.Sharding]: - raise NotImplementedError( - "compiled executable carries no input sharding information") - - def output_shardings(self) -> Sequence[jax.sharding.Sharding]: - raise NotImplementedError( - "compiled executable carries no output sharding information") - - def input_layouts(self): - raise NotImplementedError( - "compiled executable carries no input layout information") - - def output_layouts(self): - raise NotImplementedError( - "compiled executable carries no input layout information") - - def as_text(self) -> str: - xla_ext_exe = self.xla_extension_executable() - err_msg = ("text view unsupported on current XLA backend: " - f"{type(xla_ext_exe)}") - if not hasattr(xla_ext_exe, "hlo_modules"): - raise NotImplementedError(err_msg) - try: - return "\n\n".join([m.to_string() for m in xla_ext_exe.hlo_modules()]) - except xla_extension.XlaRuntimeError as e: - msg, *_ = e.args - if type(msg) is str and msg.startswith("UNIMPLEMENTED"): - raise NotImplementedError(err_msg) from e - else: - raise - - def cost_analysis(self) -> dict[str, float]: xla_ext_exe = self.xla_extension_executable() if hasattr(xla_ext_exe, "cost_analysis"): try: return xla_ext_exe.cost_analysis() - except xla_extension.XlaRuntimeError as e: + except _jax.JaxRuntimeError as e: msg, *_ = e.args if not (type(msg) is str and msg.startswith("UNIMPLEMENTED")): raise @@ -273,6 +179,18 @@ def cost_analysis(self) -> dict[str, float]: ) def memory_analysis(self) -> Any: + """A summary of estimated memory requirements. + + Intended for visualization and debugging purposes. The object output by + this is some simple data structure that can easily be printed or serialized + (e.g. nested dicts, lists, and tuples with numeric leaves). However, its + structure can be arbitrary: it need not be consistent across versions of JAX + and jaxlib, or even across invocations. It is relayed directly to external + callers. + + May raise ``NotImplementedError`` if unavailable, e.g. based on backend, + compiler, or runtime. + """ xla_ext_exe = self.xla_extension_executable() err_msg = ("memory analysis unsupported on current XLA backend: " f"{type(xla_ext_exe)}") @@ -280,7 +198,7 @@ def memory_analysis(self) -> Any: raise NotImplementedError(err_msg) try: return xla_ext_exe.get_compiled_memory_stats() - except xla_extension.XlaRuntimeError as e: + except _jax.JaxRuntimeError as e: msg, *_ = e.args if type(msg) is str and msg.startswith("UNIMPLEMENTED"): raise NotImplementedError(err_msg) from e @@ -288,33 +206,54 @@ def memory_analysis(self) -> Any: raise def runtime_executable(self) -> Any: + """An arbitrary object representation of this executable. + + Intended for debugging purposes. This need not be a valid nor reliable + serialization. It is relayed directly to external callers, with no + guarantee on type, structure, or consistency across invocations. + + May raise ``NotImplementedError`` if unavailable, e.g. based on backend or + compiler. + """ return self.xla_extension_executable() -class XlaLowering(Lowering): - """Adapts our various internal XLA-backed computations into a ``Lowering``.""" +class Lowering: compile_args: dict[str, Any] + # the constants that have been hoisted out and must be passed as first args, + # after the tokens. + # See https://docs.jax.dev/en/latest/internals/constants.html. + const_args: list[ArrayLike] def hlo(self) -> xc.XlaComputation: """Return an HLO representation of this computation.""" hlo = self.stablehlo() m: str | bytes m = mlir.module_to_bytecode(hlo) - return xla_extension.mlir.mlir_module_to_xla_computation( + return _jax.mlir.mlir_module_to_xla_computation( m, use_tuple_args=self.compile_args["tuple_args"]) def stablehlo(self) -> ir.Module: """Return a StableHLO representation of this computation.""" - raise NotImplementedError("must override") + raise NotImplementedError( + f"cost analysis unsupported on XLA computation: {type(self)}") def compile( - self, compiler_options: CompilerOptions | None = None) -> Executable: - raise NotImplementedError("must override") + self, compiler_options: CompilerOptions | None = None, *, + device_assignment: tuple[xc.Device, ...] | None = None) -> Executable: + """Compile and return a corresponding ``Executable``.""" + raise NotImplementedError( + f"cost analysis unsupported on XLA computation: {type(self)}") def as_text(self, dialect: str | None = None, *, debug_info: bool = False) -> str: + """A human-readable text representation of this lowering. + + Intended for visualization and debugging purposes. This need not be a valid + nor reliable serialization. It is relayed directly to external callers. + """ if dialect is None: dialect = "stablehlo" if dialect == "stablehlo": @@ -328,6 +267,19 @@ def as_text(self, dialect: str | None = None, raise ValueError(f"unknown dialect: {dialect}") def compiler_ir(self, dialect: str | None = None) -> Any: + """An arbitrary object representation of this lowering. + + Intended for debugging purposes. This need not be a valid nor reliable + serialization. It is relayed directly to external callers, with no + guarantee on type, structure, or consistency across invocations. + + May raise ``NotImplementedError`` if unavailable, e.g. based on backend or + compiler. + + Args: + dialect: Optional string specifying a representation dialect + (e.g. "stablehlo") + """ if dialect is None: dialect = "stablehlo" if dialect == "stablehlo": @@ -337,8 +289,26 @@ def compiler_ir(self, dialect: str | None = None) -> Any: else: raise ValueError(f"unknown dialect: {dialect}") - def cost_analysis(self) -> dict[str, float]: - raise NotImplementedError("must override") + def cost_analysis(self) -> Any: + """A summary of execution cost estimates. + + Intended for visualization and debugging purposes. The object output by + this is some simple data structure that can easily be printed or serialized + (e.g. nested dicts, lists, and tuples with numeric leaves). However, its + structure can be arbitrary: it need not be consistent across versions of JAX + and jaxlib, or even across invocations. It is relayed directly to external + callers. + + This function estimates execution cost in the absence of compiler + optimizations, which may drastically affect the cost. For execution cost + estimates after optimizations, compile this lowering and see + ``Compiled.cost_analysis``. + + May raise ``NotImplementedError`` if unavailable, e.g. based on backend, + compiler, or runtime. + """ + raise NotImplementedError( + f"cost analysis unsupported on XLA computation: {type(self)}") # -- Public-facing API, plus helpers @@ -357,20 +327,13 @@ def dtype(self): return self._aval.dtype # pytype: disable=attribute-error -@dataclass(frozen=True) -class OutInfo: - shape: tuple[int, ...] - dtype: jax.typing.DTypeLike - sharding: jax.sharding.Sharding | None = None - - class Stage: args_info: Any # PyTree of ArgInfo @property def in_tree(self) -> tree_util.PyTreeDef: """Tree structure of the pair (positional arguments, keyword arguments).""" - return tree_util.tree_structure(self.args_info) + return tree_structure(self.args_info) @property def in_avals(self): @@ -396,8 +359,305 @@ def make_args_info(in_tree, in_avals, donate_argnums): class CompiledCallParams(NamedTuple): executable: Executable no_kwargs: bool - in_tree: tree_util.PyTreeDef + in_tree: tree_util.PyTreeDef # lo tree + out_tree: tree_util.PyTreeDef # lo tree + const_args: list[ArrayLike] # https://docs.jax.dev/en/latest/internals/constants.html + in_types: tuple[tree_util.PyTreeDef, list[core.AbstractValue | core.AvalQDD]] | None + out_types: tuple[tree_util.PyTreeDef, list[core.AbstractValue]] | None + + @property + def is_high(self): + return self.in_types and self.out_types and any( + a.is_high for a in it.chain(self.in_types[1], self.out_types[1])) + + +def _traced_args_info(self): + don8_rgn = tuple(i for i, d in enumerate(self._params['donated_invars']) if d) + arg_avals = self.jaxpr.in_avals[self._num_consts:] + return make_args_info(self._in_tree, arg_avals, don8_rgn) + +def _traced_out_info(self): + out_shardings = [None if isinstance(s, UnspecifiedValue) else s + for s in self._params['out_shardings']] + out_layouts = [None if isinstance(l, AutoLayout) else l + for l in self._params['out_layouts']] + out = [] + for a, out_s, out_l in zip(self.jaxpr.out_avals, out_shardings, out_layouts): + if isinstance(a, core.ShapedArray): + s = ((a.sharding if a.sharding.mesh._are_all_axes_explicit_or_manual + else out_s) if out_s is None else out_s) + out.append( + core.ShapeDtypeStruct( + a.shape, a.dtype, sharding=Format(out_l, s), + weak_type=a.weak_type, + vma=(a.vma if config._check_vma.value else None))) + else: + out.append(a) + return tree_util.tree_unflatten(self.out_tree, out) + + +class Traced(Stage): + """Traced form of a function specialized to argument types and values. + + A traced computation is ready for lowering. This class carries the + traced representation with the remaining information needed to later + lower, compile, and execute it. + + Provides access to both the hijax (high-level) and lojax (low-level) + representations via `.jaxpr` and `.lojax` properties respectively. + """ + __slots__ = ['_meta_tys_flat', '_params', '_in_tree', 'out_tree', '_consts', + '_lojax'] + + def __init__(self, meta_tys_flat, params, in_tree, out_tree, consts): + self._meta_tys_flat = meta_tys_flat + self._params = params + self._in_tree = in_tree + self.out_tree = out_tree + self._consts = consts + self._lojax = None + + jaxpr = property(lambda self: self._params['jaxpr']) + fun_name = property(lambda self: self._params['name']) + args_info = property(_traced_args_info) + out_info = property(_traced_out_info) + _num_consts = property(lambda self: len(self._consts)) + + @property + def out_avals(self): + return tree_unflatten(self.out_tree, self.jaxpr.out_avals) + + def __call__(self, *args, **kwargs): + args_flat = tree_util.tree_leaves_checked(self.in_tree, (args, kwargs)) + out_flat = core.jaxpr_as_fun(self.jaxpr)(*args_flat) + return tree_unflatten(self.out_tree, out_flat) + + + @property + def lojax(self) -> LoJax: + if self._lojax is not None: + return self._lojax + + if not self.jaxpr.is_high: + self._lojax = LoJax( + self._meta_tys_flat, self._params, self._in_tree, self.out_tree, + (self._in_tree, self.jaxpr.in_avals), + (self.out_tree, self.jaxpr.out_avals), + self._consts) + return self._lojax + + # TODO(mattjj): when pmap is deleted, merge with pjit.py BUILD rule + from jax._src.interpreters import partial_eval as pe # type:ignore + hi_jaxpr = self.jaxpr + _, closed_over_himutables = pe.convert_const_himutables(hi_jaxpr) + if closed_over_himutables: raise NotImplementedError # TODO(mattjj) + lo_jaxpr = pe.lower_jaxpr(hi_jaxpr) + if any(a.is_high for a in hi_jaxpr.final_aval_qdds): + in_tree = lojax_pytree(hi_jaxpr.in_aval_qdds, self._in_tree) + else: + in_tree = self._in_tree + if any(a.is_high for a in hi_jaxpr.out_avals): + out_tree = lojax_pytree(hi_jaxpr.out_avals, self.out_tree) + else: + out_tree = self.out_tree + params = dict(lojax_expand_params(hi_jaxpr, self._params), jaxpr=lo_jaxpr) + lo_meta_tys = [mty.replace(aval=lo_ty) + for mty, aq in zip(self._meta_tys_flat, hi_jaxpr.in_aval_qdds) + for lo_ty in (mty.aval.lo_ty_qdd(aq.qdd) + if mty.aval.has_qdd else mty.aval.lo_ty())] + self._lojax = LoJax( + lo_meta_tys, params, in_tree, out_tree, + (self._in_tree, hi_jaxpr.final_aval_qdds), + (self.out_tree, hi_jaxpr.out_avals), + self._consts) + return self._lojax + + def lower(self, *, lowering_platforms: tuple[str, ...] | None = None, + _private_parameters: mlir.LoweringParameters | None = None): + """Lower to compiler input, returning a ``Lowered`` instance.""" + lo = self.lojax + if _private_parameters is None: + _private_parameters = mlir.LoweringParameters() + try: + from jax._src.pjit import _resolve_and_lower # type: ignore + lowering = _resolve_and_lower( + lo._meta_tys_flat, **lo._params, lowering_platforms=lowering_platforms, + lowering_parameters=_private_parameters, pgle_profiler=None) + except DeviceAssignmentMismatchError as e: + fails, = e.args + msg = _device_assignment_mismatch_error( + lo._params['name'], fails, lo._meta_tys_flat, 'jit', + lo.jaxpr.debug_info.safe_arg_names(len(lo.jaxpr.in_avals))) + raise ValueError(msg) from None + return Lowered(lowering, lo.args_info, lo.out_tree, + in_types=lo._in_types, out_types=lo._out_types) + + +def lojax_expand_params(jaxpr, params): + from jax._src.pjit import _lojax_expand_params # type: ignore + lo_nums_in = [len(aval.lo_ty()) for aval in jaxpr.in_aval_qdds] + lo_nums_out = [len(t.lo_ty()) for t in jaxpr.out_avals] + lo_muts_out = sum(len(aval.lo_ty()) for aval in jaxpr.final_aval_qdds + if aval.has_qdd) + return _lojax_expand_params(lo_nums_in, lo_nums_out, lo_muts_out, + **dict(params, jaxpr=jaxpr)) + +def lojax_pytree(hi_avals, tree): + lo_avals = [t.lo_ty() for t in hi_avals] + return tree_structure(tree_unflatten(tree, lo_avals)) + + +class LoJax: + __slots__ = ['_meta_tys_flat', '_params', '_in_tree', 'out_tree', + '_consts', '_in_types', '_out_types'] + + def __init__(self, meta_tys_flat, params, in_tree, out_tree, in_types, out_types, + consts): + self._meta_tys_flat = meta_tys_flat + self._params = params + self._in_tree = in_tree + self.out_tree = out_tree + self._consts = consts + self._in_types = in_types # hi types + self._out_types = out_types + + jaxpr = property(lambda self: self._params['jaxpr']) + fun_name = property(lambda self: self._params['name']) + args_info = property(_traced_args_info) + out_info = property(_traced_out_info) + _num_consts = property(lambda self: len(self._consts)) + + + +class Lowered(Stage): + """Lowering of a function specialized to argument types and values. + + A lowering is a computation ready for compilation. This class + carries a lowering together with the remaining information needed to + later compile and execute it. It also provides a common API for + querying properties of lowered computations across JAX's various + lowering paths (:func:`~jax.jit`, :func:`~jax.pmap`, etc.). + """ + __slots__ = ["_lowering", "args_info", "out_tree", "_no_kwargs", + "_in_types", "_out_types"] + + _lowering: Lowering + args_info: Any # PyTree of ArgInfo, not including the const_args out_tree: tree_util.PyTreeDef + _no_kwargs: bool + _in_types: list[tuple[core.AbstractValue, core.QuasiDynamicData]] | None + _out_types: list[core.AbstractValue] | None + + def __init__(self, lowering: Lowering, args_info, + out_tree: tree_util.PyTreeDef, no_kwargs: bool = False, + in_types=None, out_types=None): + + self._lowering = lowering + self.args_info = args_info + self.out_tree = out_tree + self._no_kwargs = no_kwargs + self._in_types = in_types # type: ignore + self._out_types = out_types # type: ignore + + @property + def in_avals(self): + in_avals_ = self._lowering.compile_args.get("global_in_avals", None) + if in_avals_ is None: # For old pmap code i.e. PmapComputation + return tree_util.tree_map(lambda x: x._aval, self.args_info) + kept_var_idx = self._lowering.compile_args["kept_var_idx"] + non_dce_avals = self._lowering.compile_args["all_args_info"].in_avals + if self.in_tree.num_leaves > len(in_avals_): + iter_in_avals = iter(in_avals_) + in_avals_ = [ + next(iter_in_avals) if i in kept_var_idx + else a for i, a in zip(range(self.in_tree.num_leaves), non_dce_avals)] + return self.in_tree.unflatten(in_avals_) + + @property + def out_info(self): # PyTree of OutInfo + out_avals = self._lowering.compile_args["global_out_avals"] + out_shardings = self._lowering.compile_args["out_shardings"] + out_layouts = self._lowering.compile_args["out_layouts"] + outs = [] + for o, l, s in zip(out_avals, out_layouts, out_shardings): + s = None if isinstance(s, (UnspecifiedValue, AUTO)) else s + l = None if isinstance(l, AutoLayout) else l + format = Format(l, s) + outs.append(core.ShapeDtypeStruct(o.shape, o.dtype, sharding=format)) + return self.out_tree.unflatten(outs) + + def compile( + self, compiler_options: CompilerOptions | None = None, *, + device_assignment: tuple[xc.Device, ...] | None = None) -> Compiled: + """Compile, returning a corresponding ``Compiled`` instance.""" + + kw: dict[str, Any] = { + "compiler_options": compiler_options, + "device_assignment": device_assignment + } + return Compiled( + self._lowering.compile(**kw), # pytype: disable=wrong-keyword-args + self._lowering.const_args, + self.args_info, + self.out_tree, + self._no_kwargs, + self._in_types, + self._out_types, + ) + + def as_text(self, dialect: str | None = None, *, + debug_info: bool = False) -> str: + """A human-readable text representation of this lowering. + + Intended for visualization and debugging purposes. This need not be a valid + nor reliable serialization. + Use `jax.export` if you want reliable and portable serialization. + + Args: + dialect: Optional string specifying a lowering dialect (e.g. "stablehlo", + or "hlo"). + debug_info: Whether to include debugging information, + e.g., source location. + """ + return self._lowering.as_text(dialect, debug_info=debug_info) + + def compiler_ir(self, dialect: str | None = None) -> Any | None: + """An arbitrary object representation of this lowering. + + Intended for debugging purposes. This is not a valid nor reliable + serialization. The output has no guarantee of consistency across + invocations. + Use `jax.export` if you want reliable and portable serialization. + + Returns ``None`` if unavailable, e.g. based on backend, compiler, or + runtime. + + Args: + dialect: Optional string specifying a lowering dialect (e.g. "stablehlo", + or "hlo"). + """ + try: + return self._lowering.compiler_ir(dialect) + except NotImplementedError: + return None + + def cost_analysis(self) -> Any | None: + """A summary of execution cost estimates. + + Intended for visualization and debugging purposes. The object output by + this is some simple data structure that can easily be printed or serialized + (e.g. nested dicts, lists, and tuples with numeric leaves). However, its + structure can be arbitrary: it may be inconsistent across versions of JAX + and jaxlib, or even across invocations. + + Returns ``None`` if unavailable, e.g. based on backend, compiler, or + runtime. + """ + # TODO(frostig): improve annotation (basic pytree of arbitrary structure) + try: + return self._lowering.cost_analysis() + except NotImplementedError: + return None class Compiled(Stage): @@ -408,20 +668,23 @@ class Compiled(Stage): common API for querying properties of compiled computations across JAX's various compilation paths and backends. """ - __slots__ = ["args_info", "out_tree", "_executable", "_no_kwargs"] + __slots__ = ["args_info", "out_tree", "_executable", "_no_kwargs", "_params"] - args_info: Any # PyTree of ArgInfo + args_info: Any # PyTree of ArgInfo, not including const_args out_tree: tree_util.PyTreeDef _executable: Executable _no_kwargs: bool + _params: CompiledCallParams - def __init__(self, executable, args_info, out_tree, no_kwargs=False): + def __init__(self, executable, const_args: list[ArrayLike], + args_info, out_tree, no_kwargs=False, in_types=None, out_types=None): self._executable = executable self._no_kwargs = no_kwargs self.args_info = args_info self.out_tree = out_tree - self._params = CompiledCallParams(self._executable, self._no_kwargs, - self.in_tree, self.out_tree) + self._params = CompiledCallParams( + self._executable, self._no_kwargs, self.in_tree, self.out_tree, + const_args, in_types, out_types) self._call = None def as_text(self) -> str | None: @@ -474,6 +737,25 @@ def memory_analysis(self) -> Any | None: except NotImplementedError: return None + @property + def in_avals(self): + in_avals_ = self._executable.in_avals + if self.in_tree.num_leaves > len(in_avals_): + iter_in_avals = iter(in_avals_) + non_dce_avals = self._executable._all_args_info.in_avals + in_avals_ = [ + next(iter_in_avals) if i in self._executable._kept_var_idx + else a for i, a in zip(range(self.in_tree.num_leaves), non_dce_avals)] + return self.in_tree.unflatten(in_avals_) + + @property + def out_info(self): # PyTree of jax.ShapeDtypeStruct + out_avals = self._executable.out_avals + out_formats_flat = self._output_formats_flat + return self.out_tree.unflatten( + [core.ShapeDtypeStruct(o.shape, o.dtype, sharding=f) + for o, f in zip(out_avals, out_formats_flat)]) + def runtime_executable(self) -> Any | None: """An arbitrary object representation of this executable. @@ -486,37 +768,52 @@ def runtime_executable(self) -> Any | None: """ return self._executable.runtime_executable() - @property - def input_shardings(self): # PyTree[sharding.Sharding] - shardings_flat = self._executable.input_shardings() + def _input_shardings_flat(self): + shardings_flat = self._executable._in_shardings # Some input shardings got DCE'd if self.in_tree.num_leaves > len(shardings_flat): iter_shardings_flat = iter(shardings_flat) shardings_flat = [next(iter_shardings_flat) if i in self._executable._kept_var_idx else None for i in range(self.in_tree.num_leaves)] + return shardings_flat + + @property + def input_shardings(self): # -> PyTree[sharding.Sharding] + shardings_flat = self._input_shardings_flat() return tree_util.tree_unflatten(self.in_tree, shardings_flat) # pytype: disable=attribute-error @property - def output_shardings(self): # PyTree[sharding.Sharding] - shardings_flat = self._executable.output_shardings() + def output_shardings(self): # -> PyTree[sharding.Sharding] + shardings_flat = self._executable._out_shardings return tree_util.tree_unflatten(self.out_tree, shardings_flat) # pytype: disable=attribute-error - @property - def input_layouts(self): - layouts_flat = self._executable.input_layouts() - assert all(isinstance(l, Layout) for l in layouts_flat) + def _input_layouts_flat(self): + layouts_flat = self._executable._xla_in_layouts # Some input layouts got DCE'd if self.in_tree.num_leaves > len(layouts_flat): iter_layouts_flat = iter(layouts_flat) layouts_flat = [next(iter_layouts_flat) if i in self._executable._kept_var_idx - else Layout() for i in range(self.in_tree.num_leaves)] - return tree_util.tree_unflatten(self.in_tree, layouts_flat) # pytype: disable=attribute-error + else None for i in range(self.in_tree.num_leaves)] + return layouts_flat + + @property + def input_formats(self): + layouts_flat = self._input_layouts_flat() + shardings_flat = self._input_shardings_flat() + formats_flat = [Format(l, s) for l, s in zip(layouts_flat, shardings_flat)] + return tree_util.tree_unflatten(self.in_tree, formats_flat) # pytype: disable=attribute-error @property - def output_layouts(self): - layouts_flat = self._executable.output_layouts() + def _output_formats_flat(self): + layouts_flat = self._executable._xla_out_layouts + shardings_flat = self._executable._out_shardings assert all(isinstance(l, Layout) for l in layouts_flat) - return tree_util.tree_unflatten(self.out_tree, layouts_flat) # pytype: disable=attribute-error + return [Format(l, s) for l, s in zip(layouts_flat, shardings_flat)] + + @property + def output_formats(self): + formats_flat = self._output_formats_flat + return tree_util.tree_unflatten(self.out_tree, formats_flat) # pytype: disable=attribute-error @staticmethod def call(*args, **kwargs): @@ -526,15 +823,24 @@ def call(*args, **kwargs): # extract it from args because `params` can be passed as a kwarg by users # which might conflict here. params = args[0] - args = args[1:] - if config.dynamic_shapes.value: - raise NotImplementedError + args = args[1:] # Not including const_args if params.no_kwargs and kwargs: kws = ', '.join(kwargs.keys()) raise NotImplementedError( "function was compiled by a transformation that does not support " f"keyword arguments, but called with keyword arguments: {kws}") - args_flat, in_tree = tree_util.tree_flatten((args, kwargs)) + + if params.is_high: + hi_args_flat, in_hi_tree = tree_util.tree_flatten((args, kwargs)) + in_hi_tree_, final_qdds = params.in_types + args_flat = [a.read_loval(core.cur_qdd(x), x) if (a := typeof(x)).has_qdd + else a.lower_val(x) for x in hi_args_flat] + args_flat, in_tree = \ + tree_util.tree_flatten(tree_util.tree_unflatten(in_hi_tree, args_flat)) + else: + args_flat, in_tree = tree_util.tree_flatten((args, kwargs)) + + # TODO(mattjj): improve wrong-number-of-args error if in_tree != params.in_tree: errs = list(tree_util.equality_errors_pytreedef(in_tree, params.in_tree)) msg = [] @@ -548,32 +854,35 @@ def call(*args, **kwargs): f" * at {base}{tree_util.keystr(tuple(rest))}, seen {thing2} but now" f" given {thing1}, so {explanation}") raise TypeError('\n'.join(msg)) - try: - out_flat = params.executable.call(*args_flat) - except TypeError as e: - # We can't transform ahead-of-time compiled calls, since we've - # lowered and compiled for a fixed function signature, and JAX - # transformations change signatures. We interpret a Tracer - # argument as an indication of a transformation attempt. We - # could check this before the executable call, but we'd rather - # avoid isinstance checks on the call path. Seeing a TypeError - # might mean that arguments have JAX-invalid types, which in - # turn might mean some are Tracers. + + if not core.trace_state_clean(): + # We check for tracers when we are under a transformation, and skip the + # check in the common path. We can't transform ahead-of-time compiled + # calls, since we've lowered and compiled for a fixed function signature, + # and JAX transformations change signatures. for arg in args_flat: if isinstance(arg, core.Tracer): raise TypeError( "Cannot apply JAX transformations to a function lowered and " "compiled for a particular signature. Detected argument of " - f"Tracer type {type(arg)}.") from e - else: - raise - outs = tree_util.tree_unflatten(params.out_tree, out_flat) + f"Tracer type {type(arg)}.") + lo_outs = params.executable.call(*params.const_args, *args_flat) + + if params.is_high: + out_mut, lo_outs = util.split_list(lo_outs, [_num_himuts_out(final_qdds)]) + _apply_himut(final_qdds, hi_args_flat, out_mut) + out_hi_tree, out_hi_types = params.out_types + out_flat = _raise_lo_outs(out_hi_types, lo_outs) + outs = tree_util.tree_unflatten(out_hi_tree, out_flat) + else: + out_flat = lo_outs + outs = tree_util.tree_unflatten(params.out_tree, out_flat) + return outs, out_flat, args_flat def __call__(self, *args, **kwargs): if self._call is None: - self._call = self._executable.create_cpp_call( - self._no_kwargs, self.in_tree, self.out_tree) + self._call = self._executable.create_cpp_call(self._params) if self._call is None: params = self._params def cpp_call_fallback(*args, **kwargs): @@ -582,177 +891,22 @@ def cpp_call_fallback(*args, **kwargs): self._call = cpp_call_fallback return self._call(*args, **kwargs) +def _raise_lo_outs(avals, lo_outs): + from jax._src.interpreters import partial_eval as pe # type: ignore + return pe.raise_lo_outs(avals, lo_outs) -class Lowered(Stage): - """Lowering of a function specialized to argument types and values. +# TODO(mattjj): de-dup with partial_eval.py +def _num_himuts_out(final_qdds): + return sum(len(a.lo_ty()) for a in final_qdds if a.has_qdd) - A lowering is a computation ready for compilation. This class - carries a lowering together with the remaining information needed to - later compile and execute it. It also provides a common API for - querying properties of lowered computations across JAX's various - lowering paths (:func:`~jax.jit`, :func:`~jax.pmap`, etc.). - """ - __slots__ = ["_lowering", "args_info", "out_tree", "_no_kwargs"] - _lowering: XlaLowering - args_info: Any # PyTree of ArgInfo - out_tree: tree_util.PyTreeDef - _no_kwargs: bool - - def __init__( - self, - lowering: XlaLowering, - args_info, # PyTree of ArgInfo - out_tree: tree_util.PyTreeDef, - no_kwargs: bool = False): - - self._lowering = lowering - self.args_info = args_info - self.out_tree = out_tree - self._no_kwargs = no_kwargs - - @classmethod - def from_flat_info(cls, - lowering: XlaLowering, - in_tree: tree_util.PyTreeDef, - in_avals, - donate_argnums: tuple[int, ...], - out_tree: tree_util.PyTreeDef, - no_kwargs: bool = False): - """Initialize from flat info (``in_avals`` etc.) and an input PyTreeDef. - - Args: - in_tree: The ``PyTreeDef`` of (args, kwargs). - out_tree: The ``PyTreeDef`` of the outputs. - no_kwargs: If ``True`` the transformation, and the - ``Compiled`` returned from this object will not support keyword - arguments (an error will be raised if some are provided). - """ - return cls( - lowering, - make_args_info(in_tree, in_avals, donate_argnums), - out_tree, - no_kwargs=no_kwargs) - - @property - def out_info(self): # PyTree of OutInfo - out_avals = self._lowering.compile_args["global_out_avals"] - out_shardings = self._lowering.compile_args["out_shardings"] - return self.out_tree.unflatten( - [OutInfo(o.shape, o.dtype, None if isinstance(s, (UnspecifiedValue, AUTO)) else s) - for o, s in zip(out_avals, out_shardings)]) - - def compile( - self, compiler_options: CompilerOptions | None = None) -> Compiled: - """Compile, returning a corresponding ``Compiled`` instance.""" - kw: dict[str, Any] = {"compiler_options": compiler_options} - return Compiled( - self._lowering.compile(**kw), # pytype: disable=wrong-keyword-args - self.args_info, - self.out_tree, - no_kwargs=self._no_kwargs, - ) - - def as_text(self, dialect: str | None = None, *, - debug_info: bool = False) -> str: - """A human-readable text representation of this lowering. - - Intended for visualization and debugging purposes. This need not be a valid - nor reliable serialization. - Use `jax.export` if you want reliable and portable serialization. - - Args: - dialect: Optional string specifying a lowering dialect (e.g. "stablehlo", - or "hlo"). - debug_info: Whether to include debugging information, - e.g., source location. - """ - return self._lowering.as_text(dialect, debug_info=debug_info) - - def compiler_ir(self, dialect: str | None = None) -> Any | None: - """An arbitrary object representation of this lowering. - - Intended for debugging purposes. This is not a valid nor reliable - serialization. The output has no guarantee of consistency across - invocations. - Use `jax.export` if you want reliable and portable serialization. - - Returns ``None`` if unavailable, e.g. based on backend, compiler, or - runtime. - - Args: - dialect: Optional string specifying a lowering dialect (e.g. "stablehlo", - or "hlo"). - """ - try: - return self._lowering.compiler_ir(dialect) - except NotImplementedError: - return None - - def cost_analysis(self) -> Any | None: - """A summary of execution cost estimates. - - Intended for visualization and debugging purposes. The object output by - this is some simple data structure that can easily be printed or serialized - (e.g. nested dicts, lists, and tuples with numeric leaves). However, its - structure can be arbitrary: it may be inconsistent across versions of JAX - and jaxlib, or even across invocations. - - Returns ``None`` if unavailable, e.g. based on backend, compiler, or - runtime. - """ - # TODO(frostig): improve annotation (basic pytree of arbitrary structure) - try: - return self._lowering.cost_analysis() - except NotImplementedError: - return None - - -class Traced(Stage): - """Traced form of a function specialized to argument types and values. - - A traced computation is ready for lowering. This class carries the - traced representation with the remaining information needed to later - lower, compile, and execute it. - """ - __slots__ = ["jaxpr", "args_info", "fun_name", "_out_tree", "_lower_callable", - "_args_flat", "_arg_names", "_num_consts"] - - def __init__(self, jaxpr: core.ClosedJaxpr, args_info, fun_name, out_tree, - lower_callable, args_flat=None, arg_names=None, - num_consts: int = 0): - self.jaxpr = jaxpr - self.args_info = args_info - self.fun_name = fun_name - self._out_tree = out_tree - self._lower_callable = lower_callable - self._args_flat = args_flat - self._arg_names = arg_names - self._num_consts = num_consts - - @property - def out_info(self): - return self._out_tree.unflatten( - [OutInfo(o.shape, o.dtype) for o in self.jaxpr.out_avals]) - - def lower(self, *, lowering_platforms: tuple[str, ...] | None = None, - _private_parameters: mlir.LoweringParameters | None = None): - """Lower to compiler input, returning a ``Lowered`` instance.""" - from jax._src.interpreters import pxla - from jax._src import pjit - - if _private_parameters is None: - _private_parameters = mlir.LoweringParameters() - new_callable = functools.partial( - self._lower_callable, lowering_platforms=lowering_platforms, - lowering_parameters=_private_parameters) - try: - lowering = new_callable() - except pxla.DeviceAssignmentMismatchError as e: - fails, = e.args - msg = pjit._device_assignment_mismatch_error( - self.fun_name, fails, self._args_flat, 'jit', self._arg_names) - raise ValueError(msg) from None - return Lowered(lowering, self.args_info, self._out_tree) +# TODO(mattjj): de-dup with partial_eval.py +def _apply_himut(final_qdds, hi_args, out_mut): + out_mut_ = iter(out_mut) + for i, a in enumerate(final_qdds): + if isinstance(a, core.AvalQDD): + lo_vals = it.islice(out_mut_, len(a.aval.lo_ty_qdd(a.qdd))) + a.aval.update_from_loval(a.qdd, hi_args[i], *lo_vals) # type: ignore + assert next(out_mut_, None) is None @runtime_checkable @@ -793,3 +947,110 @@ def lower(self, *args, **kwargs) -> Lowered: A ``Lowered`` instance representing the lowering. """ raise NotImplementedError + + +class MismatchType(enum.Enum): + ARG_SHARDING = enum.auto() + CONST_SHARDING = enum.auto() + OUT_SHARDING = enum.auto() + SHARDING_INSIDE_COMPUTATION = enum.auto() + CONTEXT_DEVICES = enum.auto() + IN_SHARDING = enum.auto() + + def __str__(self): + if self.name == 'IN_SHARDING': + return 'explicit input sharding' + if self.name == 'CONST_SHARDING': + return 'closed over constant sharding' + elif self.name == 'OUT_SHARDING': + return 'explicit output sharding' + elif self.name == 'CONTEXT_DEVICES': + return 'context mesh' + return f'{self.name}' + + +class SourceInfo(NamedTuple): + source_info: source_info_util.SourceInfo + eqn_name: str + + +@dataclasses.dataclass +class DeviceAssignmentMismatch: + da: Sequence[xc.Device] + m_type: MismatchType + source_info: SourceInfo | None + + @property + def device_ids(self) -> Sequence[int]: + return [d.id for d in self.da] + + @property + def platform(self) -> str: + return self.da[0].platform.upper() + + def _maybe_api_name(self, api_name) -> str: + return f" {api_name}'s" if self.m_type == MismatchType.CONTEXT_DEVICES else "" + + @property + def source_info_str(self): + return ( + "" if self.source_info is None + else f" at {source_info_util.summarize(self.source_info.source_info)}" + ) + + @property + def _dev_ids_plat_str(self): + return f"device ids {self.device_ids} on platform {self.platform}" + + def m_type_str(self, api_name): + return (f'{self.source_info and self.source_info.eqn_name} inside {api_name}' + if self.m_type == MismatchType.SHARDING_INSIDE_COMPUTATION else self.m_type) + + def _str(self, api_name): + return (f"{self._maybe_api_name(api_name)} {self.m_type_str(api_name)} with " + f"{self._dev_ids_plat_str}{self.source_info_str}") + + +class DeviceAssignmentMismatchError(Exception): + pass + + +def _find_arg_mismatch(arg_list, fails, fun_name): + mismatched_args_msg = [] + def mismatch(err): + for name, inp_da, aval in arg_list: + if err.m_type == MismatchType.ARG_SHARDING and err.da == inp_da: + mismatched_args_msg.append( + f"argument {name} of {fun_name} with shape {aval.str_short()} and " + f"{err._dev_ids_plat_str}") + break + first_err, second_err = fails + mismatch(first_err) + mismatch(second_err) + return mismatched_args_msg + + +def _device_assignment_mismatch_error(fun_name, fails, args_flat, api_name, + arg_names): + arg_list = [] + if arg_names is None: + arg_names = [''] * len(args_flat) + for a, n in zip(args_flat, arg_names): + da = a.sharding._device_assignment if a.sharding is not None else None + arg_list.append((n, da, a.aval)) + + mismatched_args_msg = _find_arg_mismatch(arg_list, fails, fun_name) + + if len(mismatched_args_msg) == 2: + first, second = mismatched_args_msg # pytype: disable=bad-unpacking + extra_msg = f" Got {first} and {second}" + elif len(mismatched_args_msg) == 1: + first, second = fails + # Choose the failure left which is not already covered by ARG_SHARDING. + left = second if first.m_type == MismatchType.ARG_SHARDING else first + extra_msg = f" Got {mismatched_args_msg[0]} and{left._str(api_name)}" + else: + first, second = fails + extra_msg = f" Got{first._str(api_name)} and{second._str(api_name)}" + msg = (f"Received incompatible devices for {api_name}ted computation.{extra_msg}") + return msg diff --git a/jax/_src/state/__init__.py b/jax/_src/state/__init__.py index 38710d9db874..adf7926d7dbd 100644 --- a/jax/_src/state/__init__.py +++ b/jax/_src/state/__init__.py @@ -22,5 +22,6 @@ TransformedRef as TransformedRef, WriteEffect as WriteEffect, get_ref_state_effects as get_ref_state_effects, + get_transforms_shape as get_transforms_shape, shaped_array_ref as shaped_array_ref, ) diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index 7ab77d5b1c37..bf6381dd30f2 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -23,38 +23,29 @@ from jax._src import ad_util from jax._src import api_util +from jax._src import config from jax._src import core +from jax._src import literals from jax._src import linear_util as lu +from jax._src import pjit +from jax._src import sharding_impls from jax._src import source_info_util from jax._src import tree_util +from jax._src import custom_derivatives from jax._src.interpreters import ad from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.lax import lax from jax._src.lax import slicing as lax_slicing from jax._src.state import indexing -from jax._src.state.primitives import addupdate_p, get_p, swap_p +from jax._src.state.primitives import addupdate_p, get_p, swap_p, pin, unpin from jax._src.state.types import ( - AbstractRef, - RefBitcaster, - RefEffect, - RefReshaper, - get_ref_aval_from_value, - uninitialized, -) + AbstractRef, RefBitcaster, RefEffect, RefReshaper, get_ref_aval_from_value, + uninitialized,) from jax._src.state.utils import bitcast, hoist_consts_to_refs from jax._src.typing import Array -from jax._src.util import ( - foreach, - merge_lists, - partition_list, - safe_map, - safe_zip, - split_dict, - split_list, - unzip2, - weakref_lru_cache, -) +from jax._src.util import (foreach, safe_map, safe_zip, split_list, unzip2, + weakref_lru_cache) import numpy as np ## JAX utilities @@ -65,15 +56,29 @@ ## Discharging state -# Let's say we have a jaxpr that takes in `Ref`s and outputs regular JAX values -# (`Ref`s should never be outputs from jaxprs). We'd like to convert that jaxpr -# into a "pure" jaxpr that takes in and outputs values and no longer has the -# `Read/Write/Accum` effects. -def discharge_state(jaxpr: core.Jaxpr, consts: Sequence[Any], * , - should_discharge: bool | Sequence[bool] = True - ) -> tuple[core.Jaxpr, list[Any]]: - """Converts a jaxpr that takes in `Ref`s into one that doesn't.""" +def discharge_state( + jaxpr: core.Jaxpr, + consts: Sequence[Any], + *, + should_discharge: bool | Sequence[bool] = True, +) -> tuple[core.Jaxpr, Sequence[Any]]: + """Converts a stateful jaxpr into a pure one. + + Discharging replaces ``Ref`` inputs with regular values, threads updates + through the computation, and returns updated ``Ref``s as additional outputs. + + Args: + jaxpr: A stateful jaxpr with ``Ref`` inputs. + consts: Constants for the jaxpr. + should_discharge: Whether to discharge each ``Ref`` input. If a single bool, + applies to all inputs. + + Returns: + A tuple of ``(new_jaxpr, new_consts)`` where ``new_jaxpr`` is a jaxpr with + no ``Read``/``Write``/``Accum`` effects. Discharged ``Ref`` inputs become + regular value inputs, and their updated values are appended to the outputs. + """ if isinstance(should_discharge, bool): should_discharge = [should_discharge] * len(jaxpr.invars) in_avals = [v.aval.inner_aval @@ -81,10 +86,26 @@ def discharge_state(jaxpr: core.Jaxpr, consts: Sequence[Any], * , else v.aval for v, d in zip(jaxpr.invars, should_discharge)] eval_jaxpr = lu.wrap_init(partial(_eval_jaxpr_discharge_state, jaxpr, should_discharge, consts), - debug_info=jaxpr.debug_info) - new_jaxpr, _ , new_consts, () = pe.trace_to_jaxpr_dynamic(eval_jaxpr, in_avals) + debug_info=jaxpr.debug_info.with_unknown_names()) + new_jaxpr, _ , new_consts = pe.trace_to_jaxpr_dynamic(eval_jaxpr, in_avals) return new_jaxpr, new_consts +# TODO(mattjj): migrate callers to discharge_state2 for caching +def discharge_state2(jaxpr: core.ClosedJaxpr, + should_discharge: bool | Sequence[bool] = True, + ) -> core.ClosedJaxpr: + if isinstance(should_discharge, bool): + should_discharge = (should_discharge,) * len(jaxpr.in_avals) + return _discharge_state2(jaxpr, tuple(should_discharge)) + +@weakref_lru_cache +def _discharge_state2(jaxpr: core.ClosedJaxpr, + should_discharge: tuple[bool, ...], + ) -> core.ClosedJaxpr: + jaxpr_, consts = discharge_state(jaxpr.jaxpr, jaxpr.consts, + should_discharge=should_discharge) + return core.ClosedJaxpr(jaxpr_, consts) + @dataclasses.dataclass class Environment: env: dict[core.Var, Any] @@ -98,37 +119,69 @@ def read(self, v: core.Atom) -> Any: def write(self, v: core.Var, val: Any) -> None: self.env[v] = val + class DischargeRule(Protocol): - def __call__(self, in_avals: Sequence[core.AbstractValue], - out_avals: Sequence[core.AbstractValue], *args: Any, - **params: Any) -> tuple[Sequence[Any | None], Sequence[Any]]: - ... + def __call__( + self, + in_avals: Sequence[core.AbstractValue], + out_avals: Sequence[core.AbstractValue], + *args: Any, + **params: Any, + ) -> tuple[Sequence[Any | None], Any | Sequence[Any]]: + """Discharge rule for a primitive. + + See :func:`discharge_state` for an explanation of what discharge means. + + Args: + in_avals: Input abstract values. + out_avals: Output abstract values. + *args: Input values. + **params: Primitive parameters. + + Returns: + A tuple of ``(new_invals, new_outvals)`` where: + + * ``new_invals`` contains updated values for discharged ``Ref`` inputs, + or ``None`` if the input is not a ``Ref`` or was not updated. + * ``new_outvals`` is the primitive's output. A sequence if the primitive + has multiple results, otherwise a single value. + """ + _discharge_rules: dict[core.Primitive, DischargeRule] = {} + +def register_discharge_rule(prim: core.Primitive): + def register(f: DischargeRule): + _discharge_rules[prim] = f + return f + + return register + + class PartialDischargeRule(Protocol): - """A partial discharge rule. + """Discharge rule that supports selective discharging of ``Ref`` inputs. - Exactly like a discharge rule only it accepts a `should_discharge` - argument that indicates which inputs should be discharged and the - return value returns a tuple of which the first element is the new - inputs or none but only the ones that correspond to `True` entries - in `should_charge`. + Generalizes :class:`DischargeRule` by accepting a ``should_discharge`` + argument that specifies which ``Ref`` inputs to discharge. The returned + ``new_invals`` must contain a non-``None`` value if and only if the + corresponding ``Ref`` was discharged. """ - def __call__(self, should_discharge: Sequence[bool], + def __call__( + self, + should_discharge: Sequence[bool], in_avals: Sequence[core.AbstractValue], - out_avals: Sequence[core.AbstractValue], *args: Any, - **params: Any) -> tuple[Sequence[Any | None], Sequence[Any]]: + out_avals: Sequence[core.AbstractValue], + *args: Any, + **params: Any, + ) -> tuple[Sequence[Any | None], Any | Sequence[Any]]: ... + _partial_discharge_rules: dict[core.Primitive, PartialDischargeRule] = {} -def register_discharge_rule(prim: core.Primitive): - def register(f: DischargeRule): - _discharge_rules[prim] = f - return register def register_partial_discharge_rule(prim: core.Primitive): def register(f: PartialDischargeRule): @@ -150,31 +203,31 @@ def _eval_jaxpr_discharge_state( if d and isinstance(v.aval, AbstractRef)} for eqn in jaxpr.eqns: - name_stack = ( - source_info_util.current_name_stack() + eqn.source_info.name_stack - ) + name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack traceback = eqn.source_info.traceback with source_info_util.user_context( traceback, name_stack=name_stack), eqn.ctx.manager: should_discharge = [id(v.aval) in refs_to_discharge for v in eqn.invars] - if eqn.primitive is core.mutable_array_p: + if eqn.primitive is core.ref_p: [invar], [outvar] = eqn.invars, eqn.outvars ans = env.read(invar) + if config.refs_to_pins.value: + ans = pin(ans) refs_to_discharge.add(id(outvar.aval)) elif eqn.primitive is core.freeze_p: [invar], [outvar] = eqn.invars, eqn.outvars ans = env.read(invar) + if config.refs_to_pins.value: + ans = unpin(ans) refs_to_discharge.remove(id(invar.aval)) - elif (any(should_discharge) - or core.internal_mutable_array_effect in eqn.effects - ): + elif any(should_discharge) or core.internal_mutable_array_effect in eqn.effects: if eqn.primitive in _partial_discharge_rules: rule: DischargeRule = partial(_partial_discharge_rules[eqn.primitive], should_discharge) elif eqn.primitive in _discharge_rules: rule = _discharge_rules[eqn.primitive] else: - raise NotImplementedError("No state discharge rule implemented for " - f"primitive: {eqn.primitive}") + raise NotImplementedError( + f"No state discharge rule implemented for primitive: {eqn.primitive}") invals = map(env.read, eqn.invars) in_avals = [v.aval for v in eqn.invars] out_avals = [v.aval for v in eqn.outvars] @@ -208,14 +261,13 @@ def _eval_jaxpr_discharge_state( return out_vals + ref_vals def _is_trivial_indexer(indexer: indexing.NDIndexer): + """Returns whether the indexer selects the entire shape.""" for s, idx in zip(indexer.shape, indexer.indices): if not isinstance(idx, indexing.Slice): return False - if not isinstance(idx.start, int): + if idx.is_dynamic_start or idx.is_dynamic_size: return False - if idx.start: - return False - if idx.size != s: + if idx.start != 0 or idx.size != s: return False return True @@ -275,33 +327,97 @@ def _maybe_convert_to_dynamic_slice( return starts, sizes, squeeze_dims -def _convert_to_array_indexer(indexer: indexing.NDIndexer - ) -> tuple[int | Array, ...]: - # This is the general gather case. We need to create the gather arrays. - is_integer_indexer, _, integer_indexer = ( - indexing.unpack_ndindexer(indexer) +# In this code, indexing is handled in three ways: `slice`, `dynamic_slice`, and +# gather. For the gather case, the goal is to create a gather array, which means +# that we need to convert all other types of indexers into integer array +# indexers. This is done by looping over all indexers and checking if they are +# not integer array indexers, and if not, performing the conversion. However, +# during this process, the indexing semantics may change. Specifically, +# according to the indexing rules of NumPy, when there are integer array +# indexers separated by other indexers, the axes corresponding to the integer +# array indexers need to be moved to the front. After we convert all other +# indexers to integer array indexers, the distinction between integer array +# indexers and other types of indexers is lost. As a result, it becomes +# impossible to determine which axes should be moved to the front. In this case, +# we need to transpose the target array before the gather operation. We also +# need to transpose the target array back after the gather operation, if it is +# used in subsequent computations. +def _maybe_transpose_before_gather( + indexer: indexing.NDIndexer +) -> tuple[int, ...] | None: + is_int_indexing, _, _ = indexing.unpack_ndindexer(indexer) + + int_indexers_contiguous = bool( + np.all(np.diff(np.where(is_int_indexing)[0]) == 1) ) - total_shape = indexer.get_indexer_shape() - int_indexer_shape = indexer.int_indexer_shape - slice_shape = total_shape[len(int_indexer_shape):] - slice_dims = tuple( - i + len(int_indexer_shape) for i in range(len(slice_shape)) + if int_indexers_contiguous: + return None # no transpose needed + + int_indexer_idxs: list[int] = [] + non_int_indexer_idxs: list[int] = [] + for i, is_int_index in enumerate(is_int_indexing): + (int_indexer_idxs if is_int_index else non_int_indexer_idxs).append(i) + transpose_order = (*int_indexer_idxs, *non_int_indexer_idxs) + return transpose_order + + +def _perform_transpose_before_gather( + target_arr: Array, + indexer: indexing.NDIndexer, + transpose_order: tuple[int, ...], +) -> tuple[Array, indexing.NDIndexer]: + new_target_arr = target_arr.transpose(transpose_order) + reordered_indices = tuple(indexer.indices[i] for i in transpose_order) + new_indexer = indexing.NDIndexer( + indices=reordered_indices, + shape=indexer.shape, + int_indexer_shape=indexer.int_indexer_shape, ) - slice_dim_iter = iter(slice_dims) - slice_indexer: list[Array] = [] - for idx, is_int_index in zip(indexer.indices, is_integer_indexer): - if not is_int_index: - assert isinstance(idx, indexing.Slice) - slice_indices = lax.broadcasted_iota( - np.dtype("int32"), total_shape, next(slice_dim_iter) - ) * idx.stride + idx.start - slice_indexer.append(slice_indices) - integer_indexer = tuple( - lax.expand_dims(idx, (-1,)) for idx in integer_indexer + return new_target_arr, new_indexer + + +def _convert_to_gather_arrays(indexer: indexing.NDIndexer) -> tuple[Array, ...]: + # This is the general gather case. We need to create the gather arrays. + total_shape = indexer.get_indexer_shape() + is_int_indexing, _, _ = indexing.unpack_ndindexer(indexer) + + if any(is_int_indexing): + n_idxers = len(indexer.indices) + int_indexer_shape = indexer.int_indexer_shape + n_int_indexers = sum(1 for p in is_int_indexing if p) + last_int_index_idx = n_idxers - 1 - is_int_indexing[::-1].index(True) + n_slice_index_dims_after_int = n_idxers - last_int_index_idx - 1 + + def get_idx_in_shape_after_indexing(i): + if not any(is_int_indexing): + return i + + if i < n_idxers - n_slice_index_dims_after_int - n_int_indexers: + return i + if i < n_idxers - n_slice_index_dims_after_int: + raise ValueError + return i - n_int_indexers + len(int_indexer_shape) + + arrs = [] + for i, idxer in enumerate(indexer.indices): + if isinstance(idxer, indexing.Slice): + idx_in_shape_after_indexing = get_idx_in_shape_after_indexing(i) + arr = ( + lax.iota(np.int32, total_shape[idx_in_shape_after_indexing]) + * idxer.stride + + idxer.start ) - continue - assert next(slice_dim_iter, None) is None - return tuple(merge_lists(is_integer_indexer, slice_indexer, integer_indexer)) + diff = len(total_shape) - idx_in_shape_after_indexing - 1 + arr = arr.reshape(arr.shape + (1,) * diff) + arrs.append(arr) + elif isinstance(idxer, (np.ndarray, Array, literals.TypedNdArray)): + diff = n_idxers - 1 - last_int_index_idx + arr = idxer.reshape(idxer.shape + (1,) * diff) + arrs.append(arr) + else: + raise ValueError(f"Invalid type of idxer: {type(idxer).__name__}") + + return tuple(arrs) @register_discharge_rule(get_p) @@ -313,20 +429,8 @@ def _get_discharge_rule( y = _get_discharge(x, idx, tree) return (None,) * (len(idx) + 1), y -def _prepend_gather(x, indexer): - # NumPy advanced int indexing won't prepend w/ only one dim, so add dummy. - return x[None][(np.array(0, 'int32'), *indexer)] - -def _prepend_scatter(x, indexer, val, *, add=False): - # NumPy advanced int indexing won't prepend w/ only one dim, so add dummy. - # However, since this is scatter, we need to remove the 1-sized dimension - # we added at the front. - if add: - return x[None].at[(0, *indexer)].add(val)[0] - return x[None].at[(0, *indexer)].set(val)[0] - -def _index_array(x, indexer): +def _index_array(x, indexer: indexing.NDIndexer): if _is_trivial_indexer(indexer): return x # Try the three APIs in the following order: `lax.slice`, @@ -336,13 +440,16 @@ def _index_array(x, indexer): # If everything in the indexer is a slice or ()-shaped, we can also # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices. # We need to squeeze out the 1-sized slices at the end. - elif maybe_slice := _maybe_convert_to_dynamic_slice(indexer): - starts, sizes, squeeze_dims = maybe_slice + elif maybe_dynamic_slice := _maybe_convert_to_dynamic_slice(indexer): + starts, sizes, squeeze_dims = maybe_dynamic_slice y = lax_slicing.dynamic_slice(x, starts, sizes) x = lax.squeeze(y, squeeze_dims) else: - indexer = _convert_to_array_indexer(indexer) - x = x[None][(np.array(0, "int32"), *indexer)] + transpose_order = _maybe_transpose_before_gather(indexer) + if transpose_order is not None: + x, indexer = _perform_transpose_before_gather(x, indexer, transpose_order) + arrays = _convert_to_gather_arrays(indexer) + x = x[arrays] return x @@ -367,53 +474,79 @@ def transform_array(x, transforms): def transform_swap_array(x, transforms, val): if transforms is None: transforms = [] - result = x - result_val = val - # Compute updated "val" (result). - _results = [x] + + # Will hold the value read from `x` before the swap, and will have the same + # shape as `val`. + new_val = x + # List of intermediate results by transforming `x`. + intermediates = [x] + + # Read phase (forward loop) for transform in transforms: match transform: case indexing.NDIndexer(): indexer = transform if _is_trivial_indexer(indexer): - _results.append(_results[-1]) + intermediates.append(intermediates[-1]) continue # If everything in the indexer is a slice or ()-shaped, we can also # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices. # We need to squeeze out the 1-sized slices at the end. if maybe_slice := _maybe_convert_to_dynamic_slice(indexer): starts, sizes, squeeze_dims = maybe_slice - result_old = lax_slicing.dynamic_slice(result, starts, sizes) - result = lax.squeeze(result_old, squeeze_dims) + new_val = lax.squeeze( + lax_slicing.dynamic_slice(new_val, starts, sizes), squeeze_dims + ) else: - indexer = _convert_to_array_indexer(indexer) - result = _prepend_gather(result, indexer) - _results.append(result) + transpose_order = _maybe_transpose_before_gather(indexer) + if transpose_order is not None: + new_val, indexer = _perform_transpose_before_gather( + new_val, indexer, transpose_order + ) + arrays = _convert_to_gather_arrays(indexer) + new_val = new_val[arrays] + # Here, we don't need to transpose `new_val` back because it now holds + # the result of the indexing, and is no longer the original array that + # was indexed into. + intermediates.append(new_val) case RefBitcaster(): - _results.append(bitcast(result, transform.dtype)) + intermediates.append(bitcast(new_val, transform.dtype)) case RefReshaper(): - _results.append(result.reshape(transform.shape)) + intermediates.append(new_val.reshape(transform.shape)) case _: raise NotImplementedError(f"Unsupported transform: {transform}") - # Compute updated "x" (result_val) - for i, transform in reversed(list(enumerate(transforms))): + # Will hold the final state of the `x` after `val` has been written to the + # transformed location, and will have the same shape as `x`. + new_x = val + + # Write phase (reversed loop) + for intermediate, transform in reversed(zip(intermediates[:-1], transforms)): if isinstance(transform, indexing.NDIndexer): indexer = transform if _is_trivial_indexer(indexer): continue if maybe_slice := _maybe_convert_to_dynamic_slice(indexer): starts, _, squeeze_dims = maybe_slice - result_val = lax.expand_dims(result_val, squeeze_dims) - result_val = lax_slicing.dynamic_update_slice( - _results[i], result_val, starts + new_x = lax_slicing.dynamic_update_slice( + intermediate, lax.expand_dims(new_x, squeeze_dims), starts ) else: - indexer = _convert_to_array_indexer(indexer) - result_val = _prepend_scatter(_results[i], indexer, result_val) + transpose_order = _maybe_transpose_before_gather(indexer) + if transpose_order is not None: + intermediate, indexer = _perform_transpose_before_gather( + intermediate, indexer, transpose_order + ) + arrays = _convert_to_gather_arrays(indexer) + new_x = intermediate.at[arrays].set(new_x) # pytype: disable=attribute-error + if transpose_order is not None: + transpose_order_inversed = np.argsort(transpose_order) + new_x = new_x.transpose(transpose_order_inversed) else: raise NotImplementedError(f"Unsupported transform: {transform}") - return result, result_val + + return new_val, new_x + def _get_discharge(x, idx, tree): transforms = tree_util.tree_unflatten(tree, idx) @@ -443,11 +576,15 @@ def _addupdate_discharge_rule( def _addupdate_discharge(x, val, idx, tree): transforms = tree_util.tree_unflatten(tree, idx) + if not transforms: + return x + val if len(transforms) > 1: raise NotImplementedError("Only single indexer is supported.") indexer = transforms[0] + if _is_trivial_indexer(indexer): return x + val + # If everything in the indexer is a slice or ()-shaped, we can also # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices. # We need to squeeze out the 1-sized slices at the end. @@ -457,8 +594,17 @@ def _addupdate_discharge(x, val, idx, tree): val = lax.expand_dims(val, squeeze_dims) y = lax_slicing.dynamic_update_slice(x, x_old + val, starts) return y - indexer = _convert_to_array_indexer(indexer) - return _prepend_scatter(x, indexer, val, add=True) + + transpose_order = _maybe_transpose_before_gather(indexer) + if transpose_order is not None: + x, indexer = _perform_transpose_before_gather(x, indexer, transpose_order) + arrays = _convert_to_gather_arrays(indexer) + x = x.at[arrays].add(val) + if transpose_order is not None: + transpose_order_inversed = np.argsort(transpose_order) + x = x.transpose(transpose_order_inversed) + return x + @weakref_lru_cache def _cached_closed_jaxpr_discharge(closed_jaxpr: core.ClosedJaxpr): @@ -485,11 +631,56 @@ def _closed_call_discharge_rule( assert next(ref_vals_iter, sentinel) is sentinel return new_invals, out_vals +def _call_primitive_discharge_rule( + prim: core.Primitive, + in_avals: Sequence[core.AbstractValue], _,*args, + call_jaxpr: core.Jaxpr, **kwargs): + closed_call_jaxpr = core.ClosedJaxpr(call_jaxpr, ()) + discharged_closed_jaxpr, num_outs, fun = _cached_closed_jaxpr_discharge( + closed_call_jaxpr) + discharged_call_jaxpr = discharged_closed_jaxpr.jaxpr + discharged_consts = discharged_closed_jaxpr.consts + discharged_call_jaxpr = pe.convert_constvars_jaxpr(discharged_call_jaxpr) + out_and_ref_vals = prim.bind(fun, *discharged_consts, *args, + call_jaxpr=discharged_call_jaxpr, + **kwargs) + out_vals, ref_vals = split_list(out_and_ref_vals, [num_outs]) + ref_vals_iter = iter(ref_vals) + new_invals = tuple(next(ref_vals_iter) if isinstance(aval, AbstractRef) + else None for aval in in_avals) + sentinel = object() + assert next(ref_vals_iter, sentinel) is sentinel + return new_invals, out_vals +register_discharge_rule(core.call_p)( + partial(_call_primitive_discharge_rule, core.call_p) +) + + # # `run_state` run_state_p = core.Primitive("run_state") run_state_p.multiple_results = True +def _run_state_is_high(*_, jaxpr, **__): + return jaxpr.is_high +run_state_p.is_high = _run_state_is_high # type: ignore + +def _run_state_to_lojax(*args, jaxpr, is_initialized, **params): + assert not jaxpr.constvars + closed_jaxpr = core.ClosedJaxpr(jaxpr, ()) + arg_avals = map(core.typeof, args) + args, is_initialized = unzip2( + (lo_val, is_init) for a, x, is_init in zip(arg_avals, args, is_initialized) + for lo_val in (a.read_loval(x) if a.has_qdd else a.lower_val(x))) + lo_jaxpr = pe.lower_jaxpr(closed_jaxpr) + all_outs = run_state_p.bind(*lo_jaxpr.consts, *args, jaxpr=lo_jaxpr.jaxpr, + is_initialized=is_initialized, **params) + out_mut, lo_outs = split_list(all_outs, [pe.num_himuts_out(jaxpr)]) + pe.apply_himut(jaxpr, args, out_mut) + return pe.raise_lo_outs(arg_avals, lo_outs) +run_state_p.to_lojax = _run_state_to_lojax + + def _default_initialization(x): assert hasattr(x, 'shape') assert hasattr(x, 'dtype') @@ -593,390 +784,6 @@ def _run_state_jvp(primals: Sequence[Any], tangents: Sequence[Any], *, return out_primals, out_tangents ad.primitive_jvps[run_state_p] = _run_state_jvp -_save_everything = lambda *_, **__: True - -def _convert_outputs_to_writes( - jaxpr: core.Jaxpr) -> tuple[core.Jaxpr, list[core.ShapedArray]]: - assert not jaxpr.constvars, "Jaxpr shouldn't have constvars." - - in_avals = [v.aval for v in jaxpr.invars] - def eval_jaxpr(*refs): - # We split the refs into the original input refs and the dummy residual - # refs. - orig_refs, residual_refs = split_list(refs, [len(in_avals)]) - residual_vals = core.eval_jaxpr(jaxpr, (), *orig_refs) - for res_ref, res_val in zip(residual_refs, residual_vals): - res_ref[...] = res_val - return [] - res_ref_avals = [AbstractRef(v.aval) if not isinstance(v.aval, AbstractRef) - else v.aval for v in jaxpr.outvars] - jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(eval_jaxpr, - debug_info=jaxpr.debug_info), - [*in_avals, *res_ref_avals]) - assert not consts - return jaxpr, [core.ShapedArray(a.shape, a.dtype) for a in res_ref_avals] - -def _convert_inputs_to_reads(num_res: int, jaxpr: core.Jaxpr) -> core.Jaxpr: - assert not jaxpr.constvars, "Jaxpr should not have constvars" - - def eval_jaxpr(*refs): - residual_refs, orig_refs = split_list(refs, [num_res]) - residual_vals = [r[...] for r in residual_refs] - () = core.eval_jaxpr(jaxpr, (), *residual_vals, *orig_refs) - return [] - - res_val_avals, orig_ref_avals = \ - split_list([v.aval for v in jaxpr.invars], [num_res]) - res_ref_avals = [AbstractRef(aval) if not isinstance(aval, AbstractRef) else - aval for aval in res_val_avals] - jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(eval_jaxpr, - debug_info=jaxpr.debug_info), - [*res_ref_avals, *orig_ref_avals]) - return jaxpr - -def _run_state_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer, - jaxpr: core.Jaxpr, which_linear: tuple[bool, ...], - is_initialized: tuple[bool, ...]): - if not all(is_initialized): - raise NotImplementedError( - "Uninitialized Refs are not supported in partial_eval." - ) - num_inputs = len(tracers) - assert num_inputs == len(jaxpr.invars) - in_unknowns = [not t.pval.is_known() for t in tracers] - # We first need to run a fixpoint to determine which of the `Ref`s are unknown - # after running the for loop. We want to use the jaxpr to determine which - # `Ref`s are unknown after executing the for loop body given which `Ref`s are - # unknown before. However, the jaxpr has no outputs. Instead, we discharge - # the body and run the fixpoint with the discharged jaxpr. We can do this - # because the outputs of the jaxpr are one-to-one with the inputs. - discharged_jaxpr_, discharged_consts = discharge_state(jaxpr, ()) - discharged_jaxpr = pe.convert_constvars_jaxpr(discharged_jaxpr_) - for _ in range(num_inputs): - jaxpr_in_unknowns = [False] * len(discharged_consts) + in_unknowns - _, _, out_unknowns, out_inst, _, _ = pe.partial_eval_jaxpr_stateful( - discharged_jaxpr, jaxpr_in_unknowns, jaxpr_in_unknowns, - in_unknowns, False, _save_everything) - # assert out_inst == out_unknowns - out_unknowns = list(out_unknowns) - if out_unknowns == in_unknowns: - break - in_unknowns = map(operator.or_, in_unknowns, out_unknowns) - else: - raise Exception("Invalid fixpoint") - del out_unknowns # redundant since it's the same as `in_unknowns` - tracers = tuple(trace.instantiate_const(t) if uk else t - for t, uk in zip(tracers, in_unknowns)) - - # We use `partial_eval_jaxpr_stateful` here because it won't remove effectful - # primitives like `get`/`set`. - jaxpr_known_resout, jaxpr_unknown_resin_, _, _, num_res_out, num_res_ref = \ - pe.partial_eval_jaxpr_stateful(jaxpr, in_unknowns, in_inst=in_unknowns, - ensure_out_unknowns=[], ensure_out_inst=[], - saveable=_save_everything) - # # `partial_eval_jaxpr_stateful` will give us jaxprs that have hybrid `Ref` - # and regular valued input/outputs. However, we'd like to bind these jaxprs to - # a `for`, which expects only `Ref` inputs and no output. We need to convert - # both of these jaxprs into ones that are compatible with `for`. - - # `jaxpr_known_resout` is a jaxpr that maps from all the input `Refs` - # to output residual values (none of them should be `Ref`s). We'll need to - # convert the output residual values into `Ref`s that are initially empty - # `Ref`s that are written to at the end of the jaxpr. - num_res = num_res_out + num_res_ref - - num_invars = len(jaxpr_known_resout.invars) - num_res_ref - _, res_ref_avals = split_list( - [v.aval for v in jaxpr_known_resout.invars], [num_invars]) - res_avals = [a.inner_aval for a in res_ref_avals] # pytype: disable=attribute-error - jaxpr_known, new_res_avals = _convert_outputs_to_writes(jaxpr_known_resout) - # We now run the known jaxpr to obtain our residual values. - known_tracers, _ = partition_list(in_unknowns, tracers) - known_which_linear, _ = partition_list(in_unknowns, which_linear) - known_vals = [t.pval.get_known() for t in known_tracers] - all_res_avals = [*res_avals, *new_res_avals] - empty_res = map(ad_util.zeros_like_aval, all_res_avals) - jaxpr_known_args = [*known_vals, *empty_res] - - jaxpr_known_which_linear = (*known_which_linear, *(False,) * num_res) - out_flat = run_state_p.bind(*jaxpr_known_args, jaxpr=jaxpr_known, - which_linear=jaxpr_known_which_linear, - # TODO(sharadmv): compute this in the general case - is_initialized=(True,) * len(jaxpr_known.invars)) - known_outputs, residuals = split_list(out_flat, [len(known_tracers)]) - residuals = map(trace.new_instantiated_const, residuals) - ref_res, nonref_res = split_list(residuals, [num_res_ref]) - - # Now we handle the `jaxpr_unknown` that expects residual values as inputs. - # This jaxpr is the output of `partial_eval_jaxpr_stateful` that marks which - # inputs are actually used. - # `partial_eval_jaxpr_stateful` doesn't remove extra inputs/outputs for you - # so we use `dce_jaxpr` here to do that. - # To make it compatible with `for`, we need to convert those residual values - # into `Ref`s. - jaxpr_unknown = _convert_inputs_to_reads(len(new_res_avals), - jaxpr_unknown_resin_) - _, unknown_tracers = partition_list(in_unknowns, tracers) - _, uk_which_linear = partition_list(in_unknowns, which_linear) - unknown_which_linear = (False,) * num_res + tuple(uk_which_linear) - unknown_inputs = [*nonref_res, *ref_res, *unknown_tracers] - # Outputs match inputs so we construct output tracers that look like the input - # tracers. - res_ref_unknown_outputs = [ - pe.JaxprTracer(trace, pe.PartialVal.unknown(t.aval), None) - for t in unknown_inputs] - name_stack = source_info_util.current_name_stack()[len(trace.name_stack):] - source = source_info_util.current().replace(name_stack=name_stack) - - assert len(unknown_inputs) == len(res_ref_unknown_outputs) - assert len(unknown_inputs) == len(jaxpr_unknown.invars) - uk_params = dict(jaxpr=jaxpr_unknown, which_linear=unknown_which_linear, - # TODO(sharadmv); compute this in the general case - is_initialized=(True,) * len(jaxpr_unknown.invars)) - _, eqn_effects = run_state_p.abstract_eval(*[v.aval for v in unknown_inputs], - **uk_params) - eqn = pe.new_eqn_recipe(unknown_inputs, res_ref_unknown_outputs, - run_state_p, uk_params, - eqn_effects, source) - for t in res_ref_unknown_outputs: t.recipe = eqn - _, unknown_outputs = split_list(res_ref_unknown_outputs, [num_res]) - return merge_lists(in_unknowns, known_outputs, unknown_outputs) -pe.custom_partial_eval_rules[run_state_p] = _run_state_partial_eval - -def _run_state_partial_eval_custom( - saveable: Callable[..., pe.RematCases_], - in_unknowns: Sequence[bool], - in_inst: Sequence[bool], - eqn: core.JaxprEqn): - if not any(in_unknowns): - return eqn, None, in_unknowns, [False] * len(in_unknowns), [] - jaxpr, which_linear, is_initialized = split_dict( - eqn.params, ["jaxpr", "which_linear", "is_initialized"] - ) - if not all(is_initialized): - raise NotImplementedError( - "Uninitialized Refs are not supported in partial_eval_custom." - ) - num_inputs = len(eqn.invars) - # We first need to run a fixpoint to determine which of the `Ref`s are unknown - # after running the for loop. However, the jaxpr has no outputs. Instead, we - # discharge the body and run the fixpoint with the discharged jaxpr. We can do - # this because the outputs of the discharged jaxpr are one-to-one with the - # inputs. - discharged_jaxpr, discharged_consts = discharge_state(jaxpr, ()) - discharged_jaxpr = discharged_jaxpr.replace( - invars=discharged_jaxpr.constvars + discharged_jaxpr.invars, - constvars=[]) - in_unknowns, in_inst = list(in_unknowns), list(in_inst) - out_unknowns, out_inst = in_unknowns, in_unknowns - for _ in range(num_inputs): - jaxpr_in_unknowns = [False] * len(discharged_consts) + in_unknowns - _, _, out_unknowns, out_inst, _, _ = pe.partial_eval_jaxpr_stateful( - discharged_jaxpr, - in_unknowns=jaxpr_in_unknowns, - in_inst=jaxpr_in_unknowns, - ensure_out_unknowns=in_unknowns, - ensure_out_inst=in_unknowns, - saveable=saveable) - out_unknowns = list(out_unknowns) - if out_unknowns == in_unknowns: - break - in_unknowns = map(operator.or_, in_unknowns, out_unknowns) - else: - if num_inputs > 0: - raise Exception("Invalid fixpoint") - del out_unknowns # Redundant since it's the same as `in_unknowns` - new_inst = [x for x, already, inst in zip(eqn.invars, in_inst, out_inst) - if type(x) is core.Var and inst and not already] - - # We use `partial_eval_jaxpr_stateful` here because it won't remove effectful - # primitives like `get`/`set`. - jaxpr_known_resout, jaxpr_staged_resin_, _, _, num_res_out, num_res_ref = \ - pe.partial_eval_jaxpr_stateful(jaxpr, in_unknowns, - in_unknowns, [], [], saveable) - num_res = num_res_ref + num_res_out - # `partial_eval_jaxpr_stateful` will give us jaxprs that have hybrid `Ref` and - # non-Ref input/outputs. However, we'd like to bind these jaxprs to a - # `for`, which expects only `Ref` inputs and no output. We need to convert - # both of these jaxprs into ones that are compatible with `for`. - # TODO(sharadmv,mattjj): implement "passthrough" optimization. - - # `jaxpr_known_resout` is a jaxpr that maps from all the input `Refs` - # to output residual values (none of them should be `Ref`s). We'll need to - # convert the output residual values into `Ref`s that are initially empty - # `Ref`s that are written to at the end of the jaxpr. - jaxpr_known, res_avals = _convert_outputs_to_writes(jaxpr_known_resout) - - # In a stateful partial_eval, the residuals should be `Ref`s. - res_avals = map(AbstractRef, res_avals) - - known_invars, staged_invars = partition_list(in_unknowns, eqn.invars) - known_outvars, staged_outvars = partition_list(in_unknowns, eqn.outvars) - newvar = core.gensym() - _, res_ref_avals = split_list([v.aval for v in jaxpr_known_resout.invars], - [len(known_invars)]) - nonref_resvars = map(newvar, res_avals) - ref_resvars = map(newvar, res_ref_avals) - known_out_resvars = map(newvar, [*res_ref_avals, *res_avals]) - - known_which_linear, _ = partition_list(in_unknowns, which_linear) - jaxpr_known_which_linear = (*known_which_linear, *(False,) * num_res) - known_and_res_invars = [*known_invars, *ref_resvars, *nonref_resvars] - - known_params = dict(jaxpr=jaxpr_known, which_linear=jaxpr_known_which_linear, - # TODO(sharadmv): compute this in the general case - is_initialized=(True,) * len(jaxpr_known.invars)) - _, known_effects = run_state_p.abstract_eval( - *[v.aval for v in known_and_res_invars], **known_params) - eqn_known = pe.new_jaxpr_eqn(known_and_res_invars, - [*known_outvars, *known_out_resvars], - run_state_p, known_params, - known_effects, eqn.source_info, eqn.ctx) - - jaxpr_staged = _convert_inputs_to_reads(len(res_avals), jaxpr_staged_resin_) - - _, staged_which_linear = partition_list(in_unknowns, which_linear) - which_linear_unknown = (*[False] * num_res, *staged_which_linear) - staged_params = dict(jaxpr=jaxpr_staged, which_linear=which_linear_unknown, - # TODO(sharadmv): compute this in the general case - is_initialized=(True,) * len(jaxpr_staged.invars)) - rejiggered_resvars = [*nonref_resvars, *ref_resvars] - _, staged_invars = partition_list(in_unknowns, eqn.invars) - res_staged_invars = [*rejiggered_resvars, *staged_invars] - _, staged_effects = run_state_p.abstract_eval( - *[v.aval for v in res_staged_invars], **staged_params) - _, staged_outvars = partition_list(in_unknowns, eqn.outvars) - if num_res: - - def staged(*args): - out = run_state_p.bind(*args, **staged_params) - return out[num_res:] - staged_call_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(staged, debug_info=jaxpr_staged.debug_info), - [v.aval for v in res_staged_invars]) - eqn_staged = pe.new_jaxpr_eqn(res_staged_invars, - staged_outvars, - core.closed_call_p, - dict(call_jaxpr=pe.close_jaxpr(staged_call_jaxpr)), - staged_effects, eqn.source_info, eqn.ctx) - assert len(res_staged_invars) == len(staged_call_jaxpr.invars) - assert len(staged_outvars) == len(staged_call_jaxpr.outvars) - else: - eqn_staged = pe.new_jaxpr_eqn(staged_invars, - staged_outvars, - run_state_p, - staged_params, - staged_effects, eqn.source_info, eqn.ctx) - new_vars = [*new_inst, *nonref_resvars, *ref_resvars] - return eqn_known, eqn_staged, in_unknowns, in_unknowns, new_vars -pe.partial_eval_jaxpr_custom_rules[run_state_p] = _run_state_partial_eval_custom - -def _transpose_jaxpr(jaxpr: core.Jaxpr, which_linear: Sequence[bool], - is_initialized: tuple[bool, ...]) -> tuple[core.Jaxpr, Any]: - if not all(is_initialized): - raise NotImplementedError( - "Uninitialized Refs are not supported in transpose." - ) - def trans(*args): - # First we want to run the computation to read all the residual refs. We can - # do that by using partial evaluation with all linear inputs unknown. - res_jaxpr_, tangent_jaxpr_, *_, num_res_out, num_res_ref = \ - pe.partial_eval_jaxpr_stateful(jaxpr, which_linear, in_inst=which_linear, - ensure_out_inst=[], - ensure_out_unknowns=[], - saveable=_save_everything) - - num_unknown = sum(which_linear) - num_known = len(jaxpr.invars) - num_unknown - res_args, _ = partition_list(which_linear, args) - res_jaxpr_avals = [v.aval for v in res_jaxpr_.invars] - _, res_avals = split_list(res_jaxpr_avals, [num_known]) - res_avals = [a.inner_aval for a in res_avals] # pytype: disable=attribute-error - all_avals = [*res_avals, *[v.aval for v in res_jaxpr_.outvars]] - empty_res = map(ad.zeros_like_aval, all_avals) - res_jaxpr, _ = _convert_outputs_to_writes(res_jaxpr_) - res = run_state_p.bind( - *res_args, - *empty_res, - jaxpr=res_jaxpr, - which_linear=(False,) * (len(res_args) + len(empty_res)), - # TODO(sharadmv): compute this in the general case - is_initialized=(True,) * len(res_jaxpr.invars), - ) - res = res[len(res_args):] - ref_res_, nonref_res_ = split_list(res, [num_res_ref]) - - # Now that we have residual values, we run the tangent jaxpr. It takes as - # input the residuals, the loop index, and all the refs (at least, the ones - # that are used in the body). Luckily, `tangent_jaxpr_` has all known and - # unknown inputs! - tangent_jaxpr, used_inputs = pe.dce_jaxpr(tangent_jaxpr_, []) - used_res, used_cts = split_list(used_inputs, [len(res)]) - used_nonref_res, used_ref_res = split_list(used_res, [num_res_out]) - _, nonref_res = partition_list(used_nonref_res, nonref_res_) - _, ref_res = partition_list(used_ref_res, ref_res_) - primals_args = [*nonref_res, *ref_res] - _, tangent_args = partition_list(which_linear, args) - _, ct_args = partition_list(used_cts, tangent_args) - ad.backward_pass(tangent_jaxpr, False, (), (*primals_args, *ct_args), ()) - return [] - jaxpr_trans, _, consts, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(trans, - debug_info=jaxpr.debug_info), - [v.aval for v in jaxpr.invars]) - return jaxpr_trans, consts - -def _run_state_transpose(in_cts, *args, jaxpr: core.Jaxpr, - which_linear: tuple[bool, ...], - is_initialized: tuple[bool, ...]): - if not all(is_initialized): - raise NotImplementedError( - "Uninitialized Refs are not supported in transpose." - ) - # if any in_ct is nonzero, we definitely want it in args_ (and the - # corresponding x in args could be an undefined primal, but doesn't have to be) - # for non-res stuff: - # getting and setting => (nonzero ct, UndefinedPrimal arg) - # just setting => (nonzero ct, not UndefinedPrimal, dummy value) - # just getting => (zero ct , UndefinedPrimal arg) - # for res stuff: - # (zero ct , not UndefinedPrimal) - assert any(which_linear) - transpose_args = [] - for x, ct in zip(args, in_cts): - if type(ct) is ad_util.Zero and not ad.is_undefined_primal(x): - # this is a residual, take x! - transpose_args.append(x) - elif type(ct) is ad_util.Zero and ad.is_undefined_primal(x): - # the loop was 'just getting', plug in a zero - transpose_args.append(ad_util.zeros_like_aval(x.aval)) - elif type(ct) is not ad_util.Zero and not ad.is_undefined_primal(x): - # the loop was 'just setting', grab that cotangent! x is dummy - transpose_args.append(ct) - elif type(ct) is not ad_util.Zero and ad.is_undefined_primal(x): - # the loop was 'getting and setting', grab that cotangent! - transpose_args.append(ct) - jaxpr_transpose_, consts = _transpose_jaxpr( - jaxpr, which_linear, is_initialized - ) - jaxpr_transpose = hoist_consts_to_refs(jaxpr_transpose_) - which_linear = (*[False] * len(consts), *which_linear) - const_all_outs = run_state_p.bind( - *consts, - *transpose_args, - jaxpr=jaxpr_transpose, - which_linear=which_linear, - # TODO(sharadmv): compute this in the general case - is_initialized=(True,) * len(jaxpr_transpose.invars), - ) - _, all_outs = split_list(const_all_outs, [len(consts)]) - ct_outs = [ct if ad.is_undefined_primal(x) else None - for x, ct in zip(args, all_outs)] - return ct_outs -ad.primitive_transposes[run_state_p] = _run_state_transpose - @register_discharge_rule(run_state_p) def _run_state_discharge_rule(in_avals: Sequence[core.AbstractValue], out_avals: Sequence[core.AbstractValue], @@ -1009,7 +816,7 @@ def _initial_style_jaxpr(fun: Callable, fun_, out_tree_thunk = api_util.flatten_fun_nokwargs( lu.wrap_init(fun, debug_info=debug), tree_util.treedef_tuple((in_tree,))) - jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, in_avals) + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun_, in_avals) return jaxpr, consts, out_tree_thunk() @@ -1054,3 +861,69 @@ def wrapped(args): _, out_flat = split_list(out_const_flat, [len(consts)]) return in_tree.unflatten(out_flat) return wrapped + +@register_discharge_rule(pjit.jit_p) +def _pjit_state_discharge_rule( + in_avals, out_avals, *args, jaxpr, in_shardings, out_shardings, + in_layouts, out_layouts, **params): + if not (any(isinstance(e, RefEffect) for e in jaxpr.effects) + or any(isinstance(a, AbstractRef) for a in jaxpr.in_avals)): + # Only internal ref effects + jaxpr_ = discharge_state2(jaxpr) + out = pjit.jit_p.bind( + *args, + jaxpr=jaxpr_, + in_shardings=in_shardings, + out_shardings=out_shardings, + in_layouts=in_layouts, + out_layouts=out_layouts, + **params, + ) + new_invals = [None] * len(in_avals) + return new_invals, out + if not all(isinstance(s, sharding_impls.UnspecifiedValue) for s in (*in_shardings, *out_shardings)): + raise NotImplementedError + + if not (all(l is None for l in in_layouts) and + all(l is None for l in out_layouts)): + raise NotImplementedError + + discharged_jaxpr = discharge_state2(jaxpr) + new_in_shardings = (sharding_impls.UNSPECIFIED,) * len(discharged_jaxpr.in_avals) + new_out_shardings = (sharding_impls.UNSPECIFIED,) * len(discharged_jaxpr.out_avals) + new_in_layouts = (None,) * len(discharged_jaxpr.in_avals) + new_out_layouts = (None,) * len(discharged_jaxpr.out_avals) + out_and_ref_vals = pjit.jit_p.bind( + *args, jaxpr=discharged_jaxpr, in_shardings=new_in_shardings, + out_shardings=new_out_shardings, in_layouts=new_in_layouts, + out_layouts=new_out_layouts, **params) + out_vals, ref_vals = split_list(out_and_ref_vals, [len(jaxpr.out_avals)]) + ref_vals_iter = iter(ref_vals) + new_invals = tuple(next(ref_vals_iter) if isinstance(aval, AbstractRef) + else None for aval in in_avals) + sentinel = object() + assert next(ref_vals_iter, sentinel) is sentinel + return new_invals, out_vals + + +@register_discharge_rule(custom_derivatives.custom_vjp_call_p) +def custom_vjp_call_discharge(in_avals, out_avals, *args, call_jaxpr, + fwd_jaxpr_thunk, bwd, out_trees, symbolic_zeros, + num_consts): + # Discharge happens after all AD is done, so we can discard the AD rules. + del fwd_jaxpr_thunk, bwd, out_trees, symbolic_zeros, num_consts + dis_jaxpr, dis_consts = discharge_state(call_jaxpr.jaxpr, call_jaxpr.consts) + outs = _eval_jaxpr_ad_error(dis_jaxpr, dis_consts, args) + out_vals, ref_vals = split_list(outs, [len(call_jaxpr.out_avals)]) + ref_vals_ = iter(ref_vals) + new_invals = [next(ref_vals_) if isinstance(aval, AbstractRef) else None + for aval in in_avals] + assert next(ref_vals_, None) is None + return new_invals, out_vals + +@partial(custom_derivatives.custom_jvp, nondiff_argnums=(0,)) +def _eval_jaxpr_ad_error(dis_jaxpr, consts, args): + return core.eval_jaxpr(dis_jaxpr, consts, *args) +@_eval_jaxpr_ad_error.defjvp +def _eval_jaxpr_ad_error_jvp(*_): + raise Exception("should be unreachable, AD after discharge") diff --git a/jax/_src/state/indexing.py b/jax/_src/state/indexing.py index 4b627c1cd581..1c484fb2fe96 100644 --- a/jax/_src/state/indexing.py +++ b/jax/_src/state/indexing.py @@ -17,9 +17,10 @@ from __future__ import annotations import dataclasses -from typing import Any, Sequence, Union +from typing import Any, Union from jax._src import core +from jax._src import pretty_printer as pp from jax._src import tree_util from jax._src.typing import Array from jax._src.util import merge_lists @@ -41,8 +42,8 @@ class Slice: stride: int = 1 def __post_init__(self): - if self.stride < 1: - raise ValueError("`stride` must be >= 1.") + if self.stride < 0: + raise ValueError("`stride` must be >= 0.") @property def is_dynamic_start(self): @@ -78,6 +79,30 @@ def from_slice(cls, slc: slice, size: int) -> Slice: return cls(start, size, step) +def _pp_slice(context: core.JaxprPpContext, dim, slc: Slice) -> str: + start, size = slc.start, slc.size + if isinstance(start, core.Var): + start_str = core.pp_var(start, context) + size_str = ( + core.pp_var(size, context) if isinstance(size, core.Var) else str(size) + ) + return f"{start_str}:{start_str}+{size_str}" + else: + start_str = str(start) + if start == 0: + start_str = "" + if isinstance(size, core.Var): + size_str = core.pp_var(size, context) + if start_str: + return f"{start_str}:{start_str}+{size_str}" + else: + return f":{size_str}" + else: + end = start + size + end_str = "" if end == dim else str(end) + return f"{start_str}:{end_str}" + + def dslice( start: int | Array | None, size: int | Array | None = None, @@ -114,6 +139,7 @@ def dslice( def unpack_ndindexer(indexer: NDIndexer) -> tuple[tuple[bool, ...], tuple[Slice, ...], tuple[IntIndexer, ...]]: + # TODO(slebedev): Flip this to be ``is_slice_indexing`` and update callers. is_int_indexing = [not isinstance(i, Slice) for i in indexer.indices] slice_indexers, int_indexers = partition_list( is_int_indexing, indexer.indices) @@ -130,17 +156,17 @@ def _maybe_concretize(x: Any): class NDIndexer: indices: tuple[DimIndexer, ...] shape: tuple[int, ...] - int_indexer_shape: tuple[int, ...] + int_indexer_shape: tuple[int | Array, ...] # Off by default to avoid doing validation during pytree operations. validate: bool = False def __post_init__(self): - if not self.validate: - return if len(self.indices) != len(self.shape): raise ValueError( f"`indices` must be the same length as `Ref` shape.: {self}." ) + if not self.validate: + return # We validate integer indexing shapes here for idx, s in zip(self.indices, self.shape): if isinstance(idx, Slice): @@ -157,21 +183,27 @@ def __post_init__(self): continue # The shape of indexer integers should be broadcastable up to the # int_indexer_shape of the whole NDIndexer - if not np.shape(idx): + from jax._src.state import types as state_types # pytype: disable=import-error + idx_shape = ( + idx.shape + if isinstance(idx, state_types.TransformedRef) + else core.get_aval(idx).shape + ) + if not idx_shape: if (value := _maybe_concretize(idx)) and value >= s: raise ValueError(f"Out of bound indexer: idx={value}, dim={s}.") # For ()-shaped indexers, we can broadcast no problm. continue # If we don't have a ()-shaped indexer, the rank must match # int_indexer_shape - if np.ndim(idx) != len(self.int_indexer_shape): + if len(idx_shape) != len(self.int_indexer_shape): raise ValueError( - f"Indexer must have rank {np.ndim(idx)}: {idx=} vs." + f"Indexer must have rank {len(idx_shape)}: {idx=} vs." f" {self.int_indexer_shape=}" ) # Here we check that the shapes broadcast. try: - np.broadcast_shapes(np.shape(idx), self.int_indexer_shape) + np.broadcast_shapes(idx_shape, self.int_indexer_shape) except ValueError as e: raise ValueError( f"Could not broadcast integer indexer: {idx=} vs." @@ -184,11 +216,19 @@ def is_dynamic_size(self): def tree_flatten(self): flat_idx, idx_tree = tree_util.tree_flatten(self.indices) - return flat_idx, (idx_tree, self.shape, self.int_indexer_shape) + if not all(isinstance(i, int) for i in self.int_indexer_shape): + return (*flat_idx, self.int_indexer_shape), (idx_tree, self.shape) + else: + return flat_idx, (idx_tree, self.shape, self.int_indexer_shape) @classmethod def tree_unflatten(cls, data, flat_idx): - idx_tree, shape, int_indexer_shape = data + if len(data) == 3: + idx_tree, shape, int_indexer_shape = data + else: + # The ``int_indexer_shape`` is dynamic. + idx_tree, shape = data + *flat_idx, int_indexer_shape = flat_idx indices = tree_util.tree_unflatten(idx_tree, flat_idx) return cls(tuple(indices), shape, int_indexer_shape) @@ -218,17 +258,45 @@ def from_indices_shape(cls, indices, shape) -> NDIndexer: Slice.from_slice(i, s) if isinstance(i, slice) else i for i, s in zip(indices, shape)) - is_int_indexing = [not isinstance(i, Slice) for i in indices] - if any(is_int_indexing): - int_indexers: Sequence[Any] - other_indexers, int_indexers = partition_list(is_int_indexing, indices) - indexer_shapes = tuple(core.get_aval(i).shape for i in int_indexers) + is_slice_indexing = [isinstance(i, Slice) for i in indices] + if all(is_slice_indexing): + return cls(indices, shape, (), validate=True) + + other_indexers, slice_indexers = partition_list(is_slice_indexing, indices) + validate = True + + # We treat refs differently from scalars and arrays, because refs can have + # a dynamic shape, making it impossible to statically determine the + # broadcasted shape in the presence of other non-slice indexers. + from jax._src.state import types as state_types # pytype: disable=import-error + if ref_indexers := [ + i + for i in other_indexers + if isinstance(i, state_types.TransformedRef) + or isinstance(core.get_aval(i), state_types.AbstractRef) + ]: + # TODO(slebedev): Consider pushing these checks to lowering time. + if len(ref_indexers) > 1: + raise NotImplementedError("Multiple Ref indexers are not supported") + if len(ref_indexers) != len(other_indexers): + raise NotImplementedError( + "Ref cannot be mixed with other non-slice indexers" + ) + [ref_indexer] = ref_indexers + indexer_shape = ref_indexer.shape # type: ignore try: - int_indexer_shape = np.broadcast_shapes(*indexer_shapes) + core.canonicalize_shape(indexer_shape) + except TypeError: + validate = False # The shape is dynamic. + else: + indexer_shapes = [core.get_aval(i).shape for i in other_indexers] + try: + indexer_shape = np.broadcast_shapes(*indexer_shapes) except ValueError as e: # Raise a nicer error than the NumPy one. raise ValueError( - f"Cannot broadcast shapes for indexing: {indexer_shapes}") from e + "Cannot broadcast shapes for indexing: {indexer_shapes}" + ) from e # Here we use the `broadcast_to` primitive instead of composing lax # primitives together because it is easier to lower in targets like @@ -237,21 +305,37 @@ def from_indices_shape(cls, indices, shape) -> NDIndexer: # The local import avoids a circular dependency between primitives # and this module. from jax._src.state import primitives as sp # pytype: disable=import-error - int_indexers = [ - sp.broadcast_to(i, int_indexer_shape) for i in int_indexers + other_indexers = [ + sp.broadcast_to(i, indexer_shape) for i in other_indexers # type: ignore[arg-type] ] - indices = tuple(merge_lists(is_int_indexing, other_indexers, int_indexers)) - else: - int_indexer_shape = () + indices = tuple( + merge_lists(is_slice_indexing, other_indexers, slice_indexers) + ) + return cls(indices, shape, indexer_shape, validate) - return cls(indices, shape, int_indexer_shape, validate=True) + @classmethod + def make_trivial_indexer(cls, shape: tuple[int, ...]) -> NDIndexer: + return NDIndexer.from_indices_shape( + tuple(slice(0, e) for e in shape), + shape, + ) def get_indexer_shape(self) -> tuple[int | Array, ...]: - _, slice_indexers, _ = unpack_ndindexer(self) - slice_shape = [s.size for s in slice_indexers] - # In NDIndexers, the int_indexer_shape is *always* at the front of the - # result. - return (*self.int_indexer_shape, *slice_shape) + is_int_indexing, slice_indexers, _ = unpack_ndindexer(self) + + slice_shape = tuple(s.size for s in slice_indexers) + int_indexers_contiguous = bool( + np.all(np.diff(np.where(is_int_indexing)[0]) == 1) + ) + if not int_indexers_contiguous: + return self.int_indexer_shape + slice_shape + + has_int_indexers = any(is_int_indexing) + if has_int_indexers: + pos = is_int_indexing.index(True) + return slice_shape[:pos] + self.int_indexer_shape + slice_shape[pos:] + + return slice_shape def transform_shape(self, shape: None | tuple[int | Array, ...]) -> None | tuple[int | Array, ...]: del shape # Unused @@ -282,3 +366,12 @@ def transform_sharding(self, sharding): f"along unsharded axes, but ref of shape {self.shape} " f"was sliced on axis {i}, which is sharded like {s}") return sharding + + def pretty_print(self, context: core.JaxprPpContext) -> pp.Doc: + indices = [] + for idx, dim in zip(self.indices, self.shape): + if isinstance(idx, Slice): + indices.append(_pp_slice(context, dim, idx)) + else: + indices.append(core.pp_var(idx, context, print_literal_dtype=False)) # type: ignore + return pp.concat([pp.text("["), pp.text(",".join(indices)), pp.text("]")]) diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index 6f7570a5f3cd..69d26237db54 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -18,9 +18,12 @@ import types from typing import Any, Union +import numpy as np + from jax._src import ad_util from jax._src import core from jax._src import dispatch +from jax._src import dtypes from jax._src import pretty_printer as pp from jax._src import traceback_util from jax._src import tree_util @@ -32,17 +35,19 @@ from jax._src.state import indexing from jax._src.state.types import ( AbstractRef, + AbstractLinVal, AccumEffect, ReadEffect, - RefBitcaster, - RefReshaper, Transform, TransformedRef, WriteEffect, ) -from jax._src.typing import Array +from jax._src.typing import Array, ArrayLike from jax._src.util import safe_map, safe_zip -import numpy as np + + +# Stand-in for hi-jax inputs to Ref. +HijaxType = Any ## General utilities @@ -63,8 +68,19 @@ # `Ref((3,), np.dtype('float32'))` leads to a jaxpr eqn printed like # a:f32[3] <- x[] get_p = core.Primitive("get") +get_p.is_effectful = lambda params: True # type: ignore get_p.def_impl(partial(dispatch.apply_primitive, get_p)) -batching.ragged_prop_rules[get_p] = batching.ragged_mask_transfer_identity + +get_p.is_high = lambda ref_aval, *_, tree: ref_aval.is_high # type: ignore +def _get_to_lojax(ref, *idx, tree): + val_ty = core.typeof(ref._refs) + transforms = tree_util.tree_unflatten(tree, idx) + if transforms: + ref = TransformedRef(ref, transforms[:-1]) + idx = transforms[-1] + return val_ty.ref_get_to_lojax(ref, idx) + return val_ty.raise_val(*map(ref_get, val_ty.lower_val(ref._refs))) +get_p.to_lojax = _get_to_lojax # type: ignore Indexer = Union[int, slice, Array, types.EllipsisType] @@ -73,7 +89,6 @@ def get_ref_and_transforms( ref_or_view: Any, idx: Indexer | tuple[Indexer, ...] | None, function_name: str, - force_trailing_indexer: bool = True, # TODO(apaszke): Clean this up. ) -> tuple[Any, tuple[Transform, ...]]: if isinstance(ref_or_view, TransformedRef): ref, transforms = ref_or_view.ref, ref_or_view.transforms @@ -82,7 +97,8 @@ def get_ref_and_transforms( ref_aval = core.get_aval(ref) if not isinstance(ref_aval, AbstractRef): raise ValueError(f"Can only call `{function_name}` on a `Ref`: {ref}.") - if not isinstance(ref_aval.inner_aval, core.ShapedArray): + if (not isinstance(ref_aval.inner_aval, core.ShapedArray) + and not ref_aval.inner_aval.is_high): return ref, () if idx is None or idx is Ellipsis: @@ -90,19 +106,50 @@ def get_ref_and_transforms( elif not isinstance(idx, tuple): idx = (idx,) - if not idx and not force_trailing_indexer: + if not idx: return ref, transforms if not idx and transforms and isinstance(transforms[-1], indexing.NDIndexer): return ref, transforms nd_indexer = indexing.NDIndexer.from_indices_shape(idx, ref_or_view.shape) return ref, (*transforms, nd_indexer) - +@partial(traceback_util.api_boundary, repro_api_name="jax.ref.get") def ref_get( - ref_or_view: Any, idx: Indexer | tuple[Indexer, ...] | None = None -) -> Array: - """Reads a value from a `Ref`, a.k.a. value <- ref[idx].""" - ref, transforms = get_ref_and_transforms(ref_or_view, idx, "ref_get") + ref: core.Ref | TransformedRef, + idx: Indexer | tuple[Indexer, ...] | None = None +) -> Array | HijaxType: + """Read a value from an Ref. + + This is equivalent to ``ref[idx]`` for a NumPy-style indexer ``idx``. + For more on mutable array refs, refer to the `Ref guide`_. + + Args: + ref: a :class:`jax.ref.Ref` object. + idx: a NumPy-style indexer + + Returns: + A :class:`jax.Array` object (note, not a :class:`jax.ref.Ref`) containing + the indexed elements of the mutable reference. + + Examples: + >>> import jax + >>> ref = jax.new_ref(jax.numpy.arange(5)) + >>> jax.ref.get(ref, slice(1, 3)) + Array([1, 2], dtype=int32) + + Equivalent operation via indexing syntax: + + >>> ref[1:3] + Array([1, 2], dtype=int32) + + Use ``...`` to extract the full buffer: + + >>> ref[...] + Array([0, 1, 2, 3, 4], dtype=int32) + + .. _Ref guide: https://docs.jax.dev/en/latest/array_refs.html + """ + ref, transforms = get_ref_and_transforms(ref, idx, "ref_get") flat_transforms, tree = tree_util.tree_flatten(transforms) return get_p.bind(ref, *flat_transforms, tree=tree) @@ -124,38 +171,143 @@ def ref_get( # are `ShapedArray((), np.dtype('int32'))` leads to a jaxpr eqn printed like # x:Ref{f32[3]}[i, j] <- a swap_p = core.Primitive("swap") +swap_p.is_effectful = lambda params: True # type: ignore swap_p.def_impl(partial(dispatch.apply_primitive, swap_p)) - -def swap_ragged_prop_rule(eqn_params, invar_raggedness, outvars): - assert len(invar_raggedness) == 2 - invar_raggedness_lhs = invar_raggedness[0] - invar_raggedness_rhs = invar_raggedness[1] - - return [invar_raggedness_rhs, invar_raggedness_lhs], [None] - - -batching.ragged_prop_rules[swap_p] = swap_ragged_prop_rule - +swap_p.is_high = lambda ref_aval, *_, tree: ref_aval.is_high # type: ignore +def _swap_to_lojax(ref, val, *idx, tree): + ref_val_ty = core.typeof(ref._refs) + val_ty = core.typeof(val) + transforms = tree_util.tree_unflatten(tree, idx) + if transforms: + ref = TransformedRef(ref, transforms[:-1]) + idx = transforms[-1] + return ref_val_ty.ref_swap_to_lojax(ref, val, idx) + lo_refs = ref_val_ty.lower_val(ref._refs) + lo_vals = val_ty.lower_val(val) + outs = [ref_swap(lo_ref, idx, lo_val) for lo_ref, lo_val + in zip(lo_refs, lo_vals)] + return val_ty.raise_val(*outs) +swap_p.to_lojax = _swap_to_lojax # type: ignore + + +@partial(traceback_util.api_boundary, repro_api_name="jax.ref.swap") def ref_swap( - ref_or_view: AbstractRef | TransformedRef, + ref: core.Ref | TransformedRef, idx: Indexer | tuple[Indexer, ...] | None, - value: Array, + value: ArrayLike | HijaxType, _function_name: str = "ref_swap", -) -> Array: - """Sets a `Ref`'s value and returns the original value.""" - ref, transforms = get_ref_and_transforms(ref_or_view, idx, _function_name) +) -> Array | HijaxType: + """Update an array value inplace while returning the previous value. + + This is equivalent to ``ref[idx], prev = value, ref[idx]`` while returning + ``prev``, for a NumPy-style indexer ``idx``. + For more on mutable array refs, refer to the `Ref guide`_. + + Args: + ref: a :class:`jax.ref.Ref` object. On return, the buffer will be + mutated by this operation. + idx: a NumPy-style indexer + value: a :class:`jax.Array` object (note, not a :class:`jax.ref.Ref`) + containing the values to set in the array. + + Returns: + A :class:`jax.Array` containing the previous value at `idx`. + + Examples: + >>> import jax + >>> ref = jax.new_ref(jax.numpy.arange(5)) + >>> jax.ref.swap(ref, 3, 10) + Array(3, dtype=int32) + >>> ref + Ref([ 0, 1, 2, 10, 4], dtype=int32) + + Equivalent operation via indexing syntax: + + >>> ref = jax.new_ref(jax.numpy.arange(5)) + >>> ref[3], prev = 10, ref[3] + >>> prev + Array(3, dtype=int32) + >>> ref + Ref([ 0, 1, 2, 10, 4], dtype=int32) + + Use ``...`` to swap the value of a scalar ref: + + >>> ref = jax.new_ref(jax.numpy.int32(5)) + >>> jax.ref.swap(ref, ..., 10) + Array(5, dtype=int32) + >>> ref + Ref(10, dtype=int32) + + .. _Ref guide: https://docs.jax.dev/en/latest/array_refs.html + """ + if hasattr(ref, 'dtype'): + value = _maybe_implicit_cast(ref.dtype, value) + ref, transforms = get_ref_and_transforms(ref, idx, _function_name) flat_transforms, tree = tree_util.tree_flatten(transforms) return swap_p.bind(ref, value, *flat_transforms, tree=tree) - +# TODO(slebedev,mattjj): replace with special handling of Python numeric types: +# if (isinstance(value, (int, float, complex)) and +# value == np.array(value, dtype).item()): return cast +def _maybe_implicit_cast(dtype, value): + aval = core.typeof(value) + if not isinstance(aval, core.ShapedArray): + return value + if (aval.weak_type and + (dtypes.issubdtype(dtype, np.floating) and + dtypes.issubdtype(aval.dtype, np.floating)) or + (dtypes.issubdtype(dtype, np.integer) and + dtypes.issubdtype(aval.dtype, np.integer))): + return lax.convert_element_type(value, dtype) + return value + + +@partial(traceback_util.api_boundary, repro_api_name="jax.ref.set") def ref_set( - ref_or_view: AbstractRef | TransformedRef, + ref: core.Ref | TransformedRef, idx: Indexer | tuple[Indexer, ...] | None, - value: Array, + value: ArrayLike | HijaxType, ) -> None: - """Sets a `Ref`'s value, a.k.a. ref[idx] <- value.""" - ref_swap(ref_or_view, idx, value, _function_name="ref_set") + """Set a value in an Ref in-place. + + This is equivalent to ``ref[idx] = value`` for a NumPy-style indexer + ``idx``. For more on mutable array refs, refer to the `Ref guide`_. + + Args: + ref: a :class:`jax.ref.Ref` object. On return, the buffer will be + mutated by this operation. + idx: a NumPy-style indexer + value: a :class:`jax.Array` object (note, not a :class:`jax.ref.Ref`) + containing the values to set in the array. + + Returns: + None + + Examples: + >>> import jax + >>> ref = jax.new_ref(jax.numpy.zeros(5)) + >>> jax.ref.set(ref, 1, 10.0) + >>> ref + Ref([ 0., 10., 0., 0., 0.], dtype=float32) + + Equivalent operation via indexing syntax: + + >>> ref = jax.new_ref(jax.numpy.zeros(5)) + >>> ref[1] = 10.0 + >>> ref + Ref([ 0., 10., 0., 0., 0.], dtype=float32) + + Use ``...`` to set the value of a scalar ref: + + >>> ref = jax.new_ref(jax.numpy.int32(0)) + >>> ref[...] = 4 + >>> ref + Ref(4, dtype=int32) + + .. _Ref guide: https://docs.jax.dev/en/latest/array_refs.html + """ + ref_swap(ref, idx, value, _function_name="ref_set") # `addupdate_p` mutates a `Ref`, adding a value to its existing value. @@ -170,19 +322,60 @@ def ref_set( # _ = swap ref c *idx # ``` addupdate_p = core.Primitive('addupdate') +addupdate_p.is_effectful = lambda params: True # type: ignore addupdate_p.multiple_results = True addupdate_p.def_impl(partial(dispatch.apply_primitive, addupdate_p)) def ref_addupdate( - ref_or_view: AbstractRef, + ref: core.Ref | TransformedRef, idx: Indexer | tuple[Indexer, ...] | None, - x: Array, + x: ArrayLike | HijaxType, ) -> None: - """Mutates a ref with an additive update i.e. `ref[idx] += x`.""" - ref, transforms = get_ref_and_transforms(ref_or_view, idx, "ref_addupdate") + """Add to an element in an Ref in-place. + + This is analogous to ``ref[idx] += value`` for a NumPy array ``ref`` and + NumPy-style indexer ``idx``. However, for an Ref ``ref``, executing + ``ref[idx] += value`` actually performs a ``ref_get``, add, and ``ref_set``, + so using this function can be more efficient under autodiff. For more on + mutable array refs, refer to the `Ref guide`_. + + Args: + ref: a :class:`jax.ref.Ref` object. On return, the buffer will be + mutated by this operation. + idx: a NumPy-style indexer + x: a :class:`jax.Array` object (note, not a :class:`jax.ref.Ref`) + containing the values to add at the specified indices. + + Returns: + None + + Examples: + >>> import jax + >>> ref = jax.new_ref(jax.numpy.arange(5)) + >>> jax.ref.addupdate(ref, 2, 10) + >>> ref + Ref([ 0, 1, 12, 3, 4], dtype=int32) + + Equivalent operation via indexing syntax: + + >>> ref = jax.new_ref(jax.numpy.arange(5)) + >>> ref[2] += 10 + >>> ref + Ref([ 0, 1, 12, 3, 4], dtype=int32) + + Use ``...`` to add to a scalar ref: + + >>> ref = jax.new_ref(jax.numpy.int32(2)) + >>> ref[...] += 10 + >>> ref + Ref(12, dtype=int32) + + .. _Ref guide: https://docs.jax.dev/en/latest/array_refs.html + """ + ref, transforms = get_ref_and_transforms(ref, idx, "ref_addupdate") flat_transforms, tree = tree_util.tree_flatten(transforms) - return addupdate_p.bind(ref, x, *flat_transforms, tree=tree) + addupdate_p.bind(ref, x, *flat_transforms, tree=tree) ## get/set/addupdate abstract evaluation rules @@ -216,6 +409,8 @@ def _sharding_after_transforming(sharding, transforms): def _get_abstract_eval(ref_aval: AbstractRef, *args, tree): transforms = tree_util.tree_unflatten(tree, args) + if transforms and ref_aval.inner_aval.is_high: + return ref_aval.inner_aval.ref_get_abstract_eval(ref_aval, *args, tree=tree) if not isinstance(ref_aval, AbstractRef): raise ValueError(f"`get` must be called on `Ref` types: {ref_aval}.") if isinstance(ref_aval.inner_aval, core.ShapedArray): @@ -235,9 +430,15 @@ def _swap_abstract_eval(ref_aval: AbstractRef, val_aval: core.AbstractValue, *args: Any, tree): transforms = tree_util.tree_unflatten(tree, args) + if transforms and ref_aval.inner_aval.is_high: + return ref_aval.inner_aval.ref_swap_abstract_eval( + ref_aval, val_aval, *args, tree=tree) out_aval: core.AbstractValue if not isinstance(ref_aval, AbstractRef): raise ValueError(f"`swap` must be called on `Ref` types: {ref_aval}.") + if isinstance(val_aval, AbstractRef): + raise ValueError("Cannot store a Ref into another Ref. " + "Did you forget to load from it using `[...]`?") if isinstance(ref_aval.inner_aval, core.ShapedArray): assert isinstance(val_aval, core.ShapedArray) expected_out_shape = _shape_after_transforming(ref_aval.shape, transforms) @@ -248,7 +449,7 @@ def _swap_abstract_eval(ref_aval: AbstractRef, f"Expected shape: {expected_out_shape}. " f"Value shape: {val_aval.shape}. " f"Transforms: {transforms}. ") - if expected_out_dtype != val_aval.dtype and not val_aval.weak_type: + if expected_out_dtype != val_aval.dtype: raise ValueError( "Invalid dtype for `swap`. " f"Ref dtype: {expected_out_dtype}. " @@ -272,6 +473,7 @@ def _addupdate_abstract_eval(ref_aval: AbstractRef, if isinstance(ref_aval.inner_aval, core.ShapedArray): out_shape = _shape_after_transforming(ref_aval.shape, transforms) out_dtype = _dtype_after_transforming(ref_aval.dtype, transforms) + out_sharding = _sharding_after_transforming(ref_aval.sharding, transforms) assert isinstance(val_aval, core.ShapedArray) if out_shape != val_aval.shape: raise ValueError( @@ -285,6 +487,12 @@ def _addupdate_abstract_eval(ref_aval: AbstractRef, raise ValueError("Invalid dtype for `addupdate`. " f"Ref dtype: {ref_aval.dtype}. " f"Value shape: {val_aval.dtype}. ") + if ((out_sharding.mesh._any_axis_explicit or + val_aval.sharding.mesh._any_axis_explicit) and + out_sharding != val_aval.sharding): + raise ValueError("Invalid sharding for `addupdate`. " + f"Ref sharding: {ref_aval.sharding}. " + f"Value sharding: {val_aval.sharding}. ") else: # Check that the transforms are valid if transforms: @@ -297,70 +505,6 @@ def _addupdate_abstract_eval(ref_aval: AbstractRef, pp_ref_var = partial(pp.color, intensity=pp.Intensity.NORMAL, foreground=pp.Color.GREEN) -def _pp_slice(context: core.JaxprPpContext, dim, slc: indexing.Slice - ) -> str: - start, size = slc.start, slc.size - if isinstance(start, core.Var): - start_str = core.pp_var(start, context) - size_str = ( - core.pp_var(size, context) - if isinstance(size, core.Var) - else str(size) - ) - return f'{start_str}:{start_str}+{size_str}' - else: - start_str = str(start) - if start == 0: - start_str = '' - if isinstance(size, core.Var): - size_str = core.pp_var(size, context) - if start_str: - return f'{start_str}:{start_str}+{size_str}' - else: - return f':{size_str}' - else: - end = start + size - end_str = '' if end == dim else str(end) - return f'{start_str}:{end_str}' - -def pp_indexer(context: core.JaxprPpContext,indexer: indexing.NDIndexer - ) -> pp.Doc: - indices = [] - for idx, dim in zip(indexer.indices, indexer.shape): - if isinstance(idx, indexing.Slice): - indices.append(_pp_slice(context, dim, idx)) - else: - indices.append(core.pp_var(idx, context)) # type: ignore - return pp.concat([pp.text("["), pp.text(','.join(indices)), pp.text("]")]) - - -def pp_bitcaster( - context: core.JaxprPpContext, bitcaster: RefBitcaster -) -> pp.Doc: - del context - return pp.text( - f"[bitcast({bitcaster.dtype}[{','.join(str(d) for d in bitcaster.shape)}])]" - ) - - -def pp_reshaper(context: core.JaxprPpContext, reshaper: RefReshaper) -> pp.Doc: - del context - return pp.text( - f"[reshape({reshaper.dtype}[{','.join(str(d) for d in reshaper.shape)}])]" - ) - - -def pp_transform(context: core.JaxprPpContext, transform: Transform) -> pp.Doc: - match transform: - case indexing.NDIndexer(): - return pp_indexer(context, transform) - case RefBitcaster(): - return pp_bitcaster(context, transform) - case RefReshaper(): - return pp_reshaper(context, transform) - case _: - return pp.text(f"[{transform}]") - def _pp_transforms( context: core.JaxprPpContext, @@ -369,7 +513,7 @@ def _pp_transforms( if not transforms: return pp.text("[...]") return pp.concat( - [pp_transform(context, transform) for transform in transforms] + [transform.pretty_print(context) for transform in transforms] ) @@ -432,31 +576,41 @@ def _addupdate_pp_rule(eqn, context, settings) -> pp.Doc: def _get_jvp(primals: list[Any], tangents: list[Any], **params: Any): ref_primal, *idx = primals - assert isinstance(ref_primal.aval, AbstractRef) ref_tangent, *_ = tangents - assert isinstance(ref_tangent.aval, AbstractRef) - return (get_p.bind(ref_primal, *idx, **params), - get_p.bind(ref_tangent, *idx, **params)) + out_primal = get_p.bind(ref_primal, *idx, **params) + if isinstance(ref_tangent, ad_util.Zero): + out_tangent = ad_util.Zero(core.typeof(out_primal).to_tangent_aval()) + else: + out_tangent = get_p.bind(ref_tangent, *idx, **params) + return out_primal, out_tangent ad.primitive_jvps[get_p] = _get_jvp def _swap_jvp(primals: list[Any], tangents: list[Any], **params: Any): ref_primal, x_primal, *idx = primals - assert isinstance(ref_primal.aval, AbstractRef) ref_tangent, x_tangent, *_ = tangents - # if type(ref_tangent) is ad_util.Zero: - # raise Exception("you're an idiot") - assert isinstance(ref_tangent.aval, AbstractRef) - x_tangent = ad_util.instantiate(x_tangent) - return (swap_p.bind(ref_primal, x_primal, *idx, **params), - swap_p.bind(ref_tangent, x_tangent, *idx, **params)) + out_primal = swap_p.bind(ref_primal, x_primal, *idx, **params) + if isinstance(ref_tangent, ad_util.Zero) and isinstance(x_tangent, ad_util.Zero): + out_tangent = ad_util.Zero(core.typeof(out_primal).to_tangent_aval()) + elif ref_tangent.aval.kind == "anselm_ref": + out_tangent = ad_util.Zero(core.typeof(out_primal).to_tangent_aval()) + else: + if isinstance(ref_tangent, ad_util.Zero): + raise Exception("performing a set/swap operation with a differentiated " + "value on a non-differentiated array reference of type " + f"{core.typeof(ref_primal)}. Move the array reference " + "to be an argument of the differentiated function?") + x_tangent = ad_util.instantiate(x_tangent) + out_tangent = swap_p.bind(ref_tangent, x_tangent, *idx, **params) + return out_primal, out_tangent ad.primitive_jvps[swap_p] = _swap_jvp def addupdate_jvp_rule(primals: list[Any], tangents: list[Any], **params: Any): ref_primal, x_primal, *idx = primals ref_tangent, x_tangent, *_ = tangents x_tangent = ad_util.instantiate(x_tangent) - addupdate_p.bind(ref_primal, x_primal, *idx, **params) - addupdate_p.bind(ref_tangent, x_tangent, *idx, **params) + if ref_tangent.aval.kind != "anselm_ref": + addupdate_p.bind(ref_primal, x_primal, *idx, **params) + addupdate_p.bind(ref_tangent, x_tangent, *idx, **params) return [], [] ad.primitive_jvps[addupdate_p] = addupdate_jvp_rule @@ -470,7 +624,6 @@ def _get_transpose(g, ref, *idx, **params): ad.primitive_transposes[get_p] = _get_transpose def _swap_transpose(g, ref, x, *idx, **params): - del x # old value doesn't matter anymore # swap transpose is swap x_bar = swap_p.bind(ref, ad_util.instantiate(g), *idx, **params) return [None, x_bar] + [None] * len(idx) @@ -483,31 +636,140 @@ def addupdate_transpose(cts_in, ref, x, *idx, **params): return [None, g] + [None] * len(idx) ad.primitive_transposes[addupdate_p] = addupdate_transpose + +def _get_transpose_fancy(g, ref_, *idx, **params): + if idx and type(g) is not ad_util.Zero: + addupdate_p.bind(ref_.inst().ref, g, *idx, **params) + else: + ref_.accum(g) +ad.fancy_transposes[get_p] = _get_transpose_fancy + +def _swap_transpose_fancy(g, ref_, x, *idx, **params): + if ref_.ref is None and type(g) is ad_util.Zero: + return + elif ref_.ref is None: + swap_p.bind(ref_.inst().ref, ad_util.instantiate(g), *idx, **params) + else: + x_bar = swap_p.bind(ref_.inst().ref, ad_util.instantiate(g), *idx, **params) + x.accum(x_bar) +ad.fancy_transposes[swap_p] = _swap_transpose_fancy + +def addupdate_transpose_fancy(cts_in, ref_, x, *idx, **params): + if ref_.ref is not None and isinstance(x, ad.GradAccum): + x_bar = get_p.bind(ref_.ref, *idx, **params) + x.accum(x_bar) +ad.fancy_transposes[addupdate_p] = addupdate_transpose_fancy + ## get/swap/addupdate partial_eval_custom rules -def _state_partial_eval_custom(prim, saveable, unks_in, inst_in, eqn): - if any(unks_in): - res = [v for v, inst in zip(eqn.invars, inst_in) if not inst] - return None, eqn, [True] * len(eqn.outvars), [True] * len(eqn.outvars), res - elif saveable(prim, *[var.aval for var in eqn.invars], **eqn.params): - return eqn, None, [False] * len(eqn.outvars), [False] * len(eqn.outvars), [] - res = [v for v, inst in zip(eqn.invars, inst_in) if not inst] - return eqn, eqn, [False] * len(eqn.outvars), [True] * len(eqn.outvars), res - -pe.partial_eval_jaxpr_custom_rules[get_p] = partial(_state_partial_eval_custom, - get_p) -pe.partial_eval_jaxpr_custom_rules[swap_p] = partial(_state_partial_eval_custom, - swap_p) -pe.partial_eval_jaxpr_custom_rules[addupdate_p] = partial( - _state_partial_eval_custom, addupdate_p) +def _array_ref_partial_eval_custom(saveable, unks_in, inst_in, eqn): + del saveable # ignored, always full remat array_ref on known input + unk, = unks_in + inst, = inst_in + invar, = eqn.invars + res = [invar] if not inst else [] + if unk: + return None, eqn, [True], [True], res # tangent operation + else: + return eqn, eqn, [False], [True], res # full remat +pe.partial_eval_jaxpr_custom_rules[core.ref_p] = _array_ref_partial_eval_custom + +def _array_ref_batched(axis_data, vals_in, dims_in, memory_space, kind): + val, = vals_in + dim, = dims_in + if dim is None: + # We defensively batch the ref, b/c it could later be hit with a batched val + val2 = batching.broadcast(val, axis_data.size, 0, + axis_data.explicit_mesh_axis) + return core.ref_p.bind(val2, memory_space=memory_space, kind=kind), 0 + else: + return core.ref_p.bind(val, memory_space=memory_space, kind=kind), dim +batching.fancy_primitive_batchers[core.ref_p] = _array_ref_batched + +def _freeze_batched(axis_data, vals_in, dims_in): + ref, = vals_in + dim, = dims_in + return core.freeze_p.bind(ref), dim +batching.fancy_primitive_batchers[core.freeze_p] = _freeze_batched + +def _state_partial_eval_custom(saveable, unks_in, inst_in, eqn): + del saveable # ignored, always full remat state ops on known inputs + # (except for anselm_ref) + ref_unk, *_ = unks_in + ref_inst, *inst_in = inst_in + _, *val_vars = eqn.invars + assert ref_inst + res = [v for v, inst in zip(val_vars, inst_in) if not inst] + if ref_unk: + return None, eqn, [True], [True], res # tangent operation + elif eqn.invars[0].aval.kind == "anselm_ref": + return eqn, None, [False], [False], res + else: + return eqn, eqn, [False], [True], res # full remat +pe.partial_eval_jaxpr_custom_rules[get_p] = _state_partial_eval_custom +pe.partial_eval_jaxpr_custom_rules[swap_p] = _state_partial_eval_custom + +def _addupdate_partial_eval_custom(saveable, unks_in, inst_in, eqn): + del saveable # ignored, always full remat state ops on known inputs + ref_unk, *_ = unks_in + ref_inst, *inst_in = inst_in + _, *val_vars = eqn.invars + assert ref_inst + res = [v for v, inst in zip(val_vars, inst_in) if not inst] + if ref_unk: + return None, eqn, [], [], res # tangent operation + else: + return eqn, eqn, [], [], res # full remat +pe.partial_eval_jaxpr_custom_rules[addupdate_p] = _addupdate_partial_eval_custom ## get/swap/addupdate batching rules -def _batch_indexer(indexer: indexing.NDIndexer, dims, - axis_size: int, - ref_shape: tuple[int, ...], - ref_dim: int | batching.NotMapped, - idx_is_batched: bool) -> indexing.NDIndexer: +def _batch_indexer( + indexer: indexing.NDIndexer, + dims, + axis_size: int, + ref_shape: tuple[int, ...], + ref_dim: int | batching.NotMapped, + idx_is_batched: bool, +) -> indexing.NDIndexer: + """Converts a batched indexer into an unbatched one. + + This function handles the complexity of `vmap`-style batching where either the + `ref` being indexed, the indexer, or both may have batched dimensions. The + goal is to produce a new indexer that acts as if applied in a batched context, + but without actual batching, enabling downstream code to process it as usual. + + If any index in `indexer` is batched, all array indexers are normalized. If + the array indexer contains a batched dimension, the dimension is moved to the + front (axis 0). If the array indexer not batched, it is broadcasted to include + a batch dimension at the front. This is to guarantee that all array indexers + are still of the same shape. + + Slices are passed through unchanged unless they contain dynamic elements and + are themselves batched, which is currently unsupported. + + If `ref` is batched (`ref_dim` is not `NotMapped`), we simulate per-example + indexing by inserting a new iota array at the position corresponding to + `ref_dim` in the indexer. + + It is worth noting that if the array indexers in the original indexer are + contiguous, but become non-contiguous in the new indexer due to the insertion + of the iota, the dimensions corresponding to the array indexers will be moved + to the front in the indexing result. The batched dimension will be at axis 0, + while the dimensions corresponding to the array indexers in the original + indexer will start from axis 1. This behavior would cause a mismatch between + the original indexer and the new indexer. Callers must take this behavior into + account and properly transpose the arrays involved to avoid this mismatch. + + Args: + indexer: An `NDIndexer` that indexes into `ref`. + dims: A pytree with the same structure as `indexer`, indicating which + dimension (if any) is batched for each array indexer. + axis_size: Size of the batch dimension. + ref_shape: Shape of `ref`. + ref_dim: The dimension of `ref` that is batched (if any). + idx_is_batched: Whether any index in the `indexer` is batched. + """ indices = indexer.indices indices_dims = dims.indices new_indices: list[Array | indexing.Slice | int] = [] @@ -545,7 +807,7 @@ def _batch_indexer(indexer: indexing.NDIndexer, dims, idx = lax.broadcast_in_dim(idx, new_integer_indexer_shape, bcast_dims) else: - idx = batching.moveaxis(idx, dim, 0) + idx = batching.moveaxis(idx, dim, 0) # type: ignore[arg-type] new_indices.append(idx) else: if ref_dim is not batching.not_mapped: @@ -557,11 +819,16 @@ def _batch_indexer(indexer: indexing.NDIndexer, dims, bcast_dims) new_indices.append(idx) if ref_dim is not batching.not_mapped: - iota = lax.broadcasted_iota(np.dtype('int32'), new_integer_indexer_shape, 0) - new_indices.insert(ref_dim, iota) - return indexing.NDIndexer(tuple(new_indices), ref_shape, - new_integer_indexer_shape, - validate=True) + if indexer.int_indexer_shape: + batch_idx = lax.broadcasted_iota( + np.dtype('int32'), new_integer_indexer_shape, 0) + else: + batch_idx = indexing.Slice(0, axis_size) # type: ignore + new_integer_indexer_shape = () + new_indices.insert(ref_dim, batch_idx) + return indexing.NDIndexer( + tuple(new_indices), ref_shape, new_integer_indexer_shape, validate=True + ) def _get_vmap(batched_args, batched_dims, *, tree): axis_size, = {x.shape[d] for x, d in zip(batched_args, batched_dims) @@ -569,23 +836,75 @@ def _get_vmap(batched_args, batched_dims, *, tree): ref, *flat_idxs = batched_args ref_dim, *flat_idx_dims = batched_dims indexers = tree_util.tree_unflatten(tree, flat_idxs) + if not indexers: + return get_p.bind(ref, *flat_idxs, tree=tree), ref_dim indexers_dims = tree_util.tree_unflatten(tree, flat_idx_dims) idx_is_batched = any(i_dim is not batching.not_mapped for i_dim in flat_idx_dims) if len(indexers) > 1: raise NotImplementedError("Batching with multiple indexers not supported.") + # TODO(sharadmv): handle vmap of multiple indexers - indexers = tuple(_batch_indexer(indexer, dims, axis_size, + new_indexers = tuple(_batch_indexer(indexer, dims, axis_size, ref.shape, ref_dim, idx_is_batched) for indexer, dims in zip(indexers, indexers_dims)) - flat_indexers, tree = tree_util.tree_flatten(indexers) - return get_p.bind(ref, *flat_indexers, tree=tree), 0 + flat_indexers, tree = tree_util.tree_flatten(new_indexers) + + is_int_indexing, _, _ = indexing.unpack_ndindexer(indexers[0]) + int_indexers_contiguous = bool( + np.all(np.diff(np.where(is_int_indexing)[0]) == 1) + ) + # Note: _batch_indexer will add a slice for the batch dim if the int_indexer + # shape is empty, else it will use advanced/int indexing. + will_add_int_batcher = bool(indexers[0].int_indexer_shape) + + is_new_int_indexing, _, _ = indexing.unpack_ndindexer(new_indexers[0]) + new_int_indexers_contiguous = bool( + np.all(np.diff(np.where(is_new_int_indexing)[0]) == 1) + ) + + out = get_p.bind(ref, *flat_indexers, tree=tree) + should_transpose = (int_indexers_contiguous and + not new_int_indexers_contiguous) + if will_add_int_batcher and should_transpose: + original_pos = is_int_indexing.index(True) + array_indexer_shape = new_indexers[0].int_indexer_shape + array_indexer_len = len(array_indexer_shape) + + transpose_order = list(range(len(out.shape))) + transpose_order = ( + transpose_order[0], + *transpose_order[array_indexer_len:array_indexer_len+original_pos], + *transpose_order[1:array_indexer_len], + *transpose_order[array_indexer_len+original_pos:], + ) + out = lax.transpose(out, transpose_order) + out_bdim = 0 + else: + if ref_dim is not batching.not_mapped: + if will_add_int_batcher: + if not int_indexers_contiguous: + # In this case the indexer is always moved to the front. + out_bdim = 0 + else: + # In this case the indexer is not moved to the front. + out_bdim = is_new_int_indexing.index(True) + else: + # We only trigger this case when the int_indexer shape is empty, + # so we don't need to account for int_indexer_shape. + int_indexers_before_ref_dim = int(np.sum(is_new_int_indexing[:ref_dim])) + out_bdim = ref_dim - int_indexers_before_ref_dim + else: + out_bdim = 0 + if any(is_int_indexing): + # The batch dim is the indexer's batch dim. + original_pos = is_int_indexing.index(True) + out_bdim = original_pos + return out, out_bdim batching.primitive_batchers[get_p] = _get_vmap -def _swap_vmap(batched_args, batched_dims, *, tree): - axis_size, = {x.shape[d] for x, d in zip(batched_args, batched_dims) - if d is not batching.not_mapped} +def _swap_vmap(axis_data, batched_args, batched_dims, *, tree): ref, val, *flat_idxs = batched_args ref_dim, val_dim, *flat_idx_dims = batched_dims indexers = tree_util.tree_unflatten(tree, flat_idxs) @@ -595,23 +914,80 @@ def _swap_vmap(batched_args, batched_dims, *, tree): val_is_batched = val_dim is not batching.not_mapped idx_is_batched = any(i_dim is not batching.not_mapped for i_dim in flat_idx_dims) + + if not ref_is_batched: + raise Exception("performing a set/swap operation with vmapped value on " + f"an unbatched array reference of type {core.typeof(ref)}. " + "Move the array reference to be an argument to the vmapped " + "function?") + if not indexers: + if ref_is_batched and not val_is_batched: + val = batching.broadcast(val, axis_data.size, ref_dim, + axis_data.explicit_mesh_axis) + return swap_p.bind(ref, val, *flat_idxs, tree=tree), ref_dim if len(indexers) > 1: raise NotImplementedError("Batching with multiple indexers not supported.") # TODO(sharadmv): handle vmap of multiple indexers - indexers = tuple(_batch_indexer(indexer, dims, axis_size, + new_indexers = tuple(_batch_indexer(indexer, dims, axis_data.size, ref.shape, ref_dim, idx_is_batched) for indexer, dims in zip(indexers, indexers_dims)) - flat_indexers, tree = tree_util.tree_flatten(indexers) - if (ref_is_batched or idx_is_batched) and not val_is_batched: - val = batching.broadcast(val, axis_size, 0) - if val_is_batched: - val = batching.moveaxis(val, val_dim, 0) - return swap_p.bind(ref, val, *flat_indexers, tree=tree), 0 -batching.primitive_batchers[swap_p] = _swap_vmap - -def _addupdate_vmap(batched_args, batched_dims, *, tree): - axis_size, = {x.shape[d] for x, d in zip(batched_args, batched_dims) - if d is not batching.not_mapped} + flat_indexers, tree = tree_util.tree_flatten(new_indexers) + + is_int_indexing, _, _ = indexing.unpack_ndindexer(indexers[0]) + int_indexers_contiguous = bool( + np.all(np.diff(np.where(is_int_indexing)[0]) == 1) + ) + is_new_int_indexing, _, _ = indexing.unpack_ndindexer(new_indexers[0]) + new_int_indexers_contiguous = bool( + np.all(np.diff(np.where(is_new_int_indexing)[0]) == 1) + ) + + if not new_int_indexers_contiguous: # will be moved to the front + batched_dim_in_result = 0 + else: + try: + batched_dim_in_result = is_new_int_indexing.index(True) + 0 + except ValueError: + batched_dim_in_result = ref_dim + + if not val_is_batched: + if ref_is_batched or idx_is_batched: + val = batching.broadcast(val, axis_data.size, batched_dim_in_result, + axis_data.explicit_mesh_axis) + else: + val = batching.moveaxis(val, val_dim, batched_dim_in_result) + + transpose_order_inversed = None + + # Originally not going to be moved to the front, but now going to be moved to + # the front. + if int_indexers_contiguous and not new_int_indexers_contiguous: + original_pos = is_int_indexing.index(True) + array_indexer_shape = new_indexers[0].int_indexer_shape + array_indexer_len = len(array_indexer_shape) + + transpose_order = list(range(len(val.shape))) + transpose_order = ( + transpose_order[0], + *transpose_order[1+original_pos:(1+original_pos)+(array_indexer_len-1)], + *transpose_order[1:1+original_pos], + *transpose_order[(1+original_pos)+(array_indexer_len-1):], + ) + val = val.transpose(transpose_order) + transpose_order_inversed = np.argsort(transpose_order) + + out = swap_p.bind(ref, val, *flat_indexers, tree=tree) + + # `val` should not be transposed, but we needed to transpose it to match + # `swap_p`. As a result, the output of `swap_p` is also transposed. Now we + # need to transpose it back. + if transpose_order_inversed is not None: + out = out.transpose(transpose_order_inversed) + + return out, batched_dim_in_result +batching.fancy_primitive_batchers[swap_p] = _swap_vmap + +def _addupdate_vmap(axis_data, batched_args, batched_dims, *, tree): ref, val, *flat_idxs = batched_args ref_dim, val_dim, *flat_idx_dims = batched_dims indexers = tree_util.tree_unflatten(tree, flat_idxs) @@ -621,19 +997,67 @@ def _addupdate_vmap(batched_args, batched_dims, *, tree): val_is_batched = val_dim is not batching.not_mapped idx_is_batched = any(i_dim is not batching.not_mapped for i_dim in flat_idx_dims) + + if not ref_is_batched: + raise Exception("performing an addupdate operation with vmapped value on " + f"an unbatched array reference of type {core.typeof(ref)}. " + "Move the array reference to be an argument to the vmapped " + "function?") + if not indexers: + if val_dim != ref_dim: + val = batching.matchaxis2(axis_data, val_dim, ref_dim, val) + return addupdate_p.bind(ref, val, *flat_idxs, tree=tree), [] if len(indexers) > 1: raise NotImplementedError("Batching with multiple indexers not supported.") + # TODO(sharadmv): handle vmap of multiple indexers - indexers = tuple(_batch_indexer(indexer, dims, axis_size, + new_indexers = tuple(_batch_indexer(indexer, dims, axis_data.size, ref.shape, ref_dim, idx_is_batched) for indexer, dims in zip(indexers, indexers_dims)) - flat_indexers, tree = tree_util.tree_flatten(indexers) - if (ref_is_batched or idx_is_batched) and not val_is_batched: - val = batching.broadcast(val, axis_size, 0) - if val_is_batched: - val = batching.moveaxis(val, val_dim, 0) + flat_indexers, tree = tree_util.tree_flatten(new_indexers) + + is_int_indexing, _, _ = indexing.unpack_ndindexer(indexers[0]) + int_indexers_contiguous = bool( + np.all(np.diff(np.where(is_int_indexing)[0]) == 1) + ) + is_new_int_indexing, _, _ = indexing.unpack_ndindexer(new_indexers[0]) + new_int_indexers_contiguous = bool( + np.all(np.diff(np.where(is_new_int_indexing)[0]) == 1) + ) + + if not new_int_indexers_contiguous: # will be moved to the front + batched_dim_in_result = 0 + else: + try: + batched_dim_in_result = is_new_int_indexing.index(True) + except ValueError: + batched_dim_in_result = ref_dim + + if not val_is_batched: + if ref_is_batched or idx_is_batched: + val = batching.broadcast(val, axis_data.size, batched_dim_in_result, + axis_data.explicit_mesh_axis) + else: + val = batching.moveaxis(val, val_dim, batched_dim_in_result) + + # Originally not going to be moved to the front, but now going to be moved to + # the front. + if int_indexers_contiguous and not new_int_indexers_contiguous: + original_pos = is_int_indexing.index(True) + array_indexer_shape = new_indexers[0].int_indexer_shape + array_indexer_len = len(array_indexer_shape) + + transpose_order = list(range(len(val.shape))) + transpose_order = ( + transpose_order[0], + *transpose_order[1+original_pos:(1+original_pos)+(array_indexer_len-1)], + *transpose_order[1:1+original_pos], + *transpose_order[(1+original_pos)+(array_indexer_len-1):], + ) + val = val.transpose(transpose_order) + return addupdate_p.bind(ref, val, *flat_indexers, tree=tree), [] -batching.primitive_batchers[addupdate_p] = _addupdate_vmap +batching.fancy_primitive_batchers[addupdate_p] = _addupdate_vmap # Currently, JAX doesn't have a primitive that does an equal-rank broadcast. # We could use `jnp.broadcast_to` but that lowers to squeezing, @@ -644,7 +1068,19 @@ def _addupdate_vmap(batched_args, batched_dims, *, tree): broadcast_to_p = core.Primitive('broadcast_to') def broadcast_to(a: Array, shape: tuple[int, ...]) -> Array: - import jax.numpy as jnp + """Broadcasts an array to a new shape. + + Args: + a: The array to broadcast. + shape: The desired shape to broadcast to. + + Returns: + An array of shape ``shape``. + + See Also: + :func:`jax.numpy.broadcast_to` + """ + import jax.numpy as jnp # pytype: disable=import-error a = jnp.asarray(a) if a.shape == shape: return a @@ -652,7 +1088,7 @@ def broadcast_to(a: Array, shape: tuple[int, ...]) -> Array: @broadcast_to_p.def_impl def _broadcast_to_impl(a, *, shape): - import jax.numpy as jnp + import jax.numpy as jnp # pytype: disable=import-error return jnp.broadcast_to(a, shape) @broadcast_to_p.def_abstract_eval @@ -665,14 +1101,93 @@ def _broadcast_to_abstract_eval(aval, *, shape): # === AD rules for mutable arrays === -def _mut_jvp(primals, tangents): - (init_val,), (init_val_dot,) = primals, tangents - primal_out = core.mutable_array_p.bind(init_val) - if type(init_val_dot) is ad_util.Zero: - tangent_out = core.mutable_array_p.bind(ad_util.zeros_like_aval(init_val_dot.aval)) +def _ref_jvp(primals, tangents, *, memory_space, kind): + (init_val,), (init_dot,) = primals, tangents + primal_out = core.ref_p.bind(init_val, memory_space=memory_space, kind=kind) + if type(init_dot) is ad_util.Zero: + zero = ad_util.zeros_like_aval(init_dot.aval) + tangent_out = core.ref_p.bind(zero, memory_space=memory_space, kind=kind) else: - tangent_out = core.mutable_array_p.bind(init_val_dot) + tangent_out = core.ref_p.bind(init_dot, memory_space=memory_space, kind=kind) return primal_out, tangent_out -ad.primitive_jvps[core.mutable_array_p] = _mut_jvp +def _ref_lin(nzs, x, *, memory_space, kind): + nz, = nzs + x_ref = core.ref_p.bind(x, memory_space=memory_space, kind=kind) + def mut_lin(_, x_dot): + if kind == 'anselm_ref': + aval = x_dot.aval if type(x_dot) is ad.Zero else core.typeof(x_dot) + return ad.Zero(AbstractRef(aval)) + zero = ad_util.instantiate(x_dot) + return core.ref_p.bind(zero, memory_space=memory_space, kind=kind) + return x_ref, kind != 'anselm_ref', None, mut_lin + +ad.primitive_jvps[core.ref_p] = _ref_jvp +ad.primitive_linearizations[core.ref_p] = _ref_lin +# TODO(mattjj): lin rule for freeze and accum_grad_in_ref? ad.defjvp(core.freeze_p, lambda g, _: core.freeze(g)) +ad.defjvp(core.accum_grad_in_ref_p, lambda g, _: core.accum_grad_in_ref_p.bind(g)) + +# === pinned, chained LinearVals === + +def create_linear(ty, memory_space=None): + return create_linear_p.bind(ty=ty, memory_space=memory_space) +create_linear_p = core.Primitive('create_linear') + +@create_linear_p.def_abstract_eval +def _create_linear_abstract_eval(*, ty, memory_space): + if not isinstance(ty, core.ShapedArray): raise NotImplementedError(ty) + return AbstractLinVal(ty, memory_space) + +def _lower_create_linear(ctx): + out_aval, = ctx.avals_out + return mlir.custom_call( + "CreateBuffer", + operands=[], + result_types=[mlir.aval_to_ir_type(out_aval)], + ).results +mlir.register_lowering(create_linear_p, _lower_create_linear) + + +def pin(x): + return pin_p.bind(x) +pin_p = core.Primitive('pin') + +@pin_p.def_abstract_eval +def _pin_abstract_eval(aval): + if not isinstance(aval, core.ShapedArray): raise NotImplementedError(aval) + return AbstractLinVal(aval) + +def _lower_pin(ctx, x_op): + out_aval, = ctx.avals_out + return mlir.custom_call( + "Pin", + operands=mlir.flatten_ir_values([x_op]), + result_types=[mlir.aval_to_ir_type(out_aval)], + ).results +mlir.register_lowering(pin_p, _lower_pin) + + +def unpin(x): + return unpin_p.bind(x) +unpin_p = core.Primitive('unpin') + +@unpin_p.def_abstract_eval +def _unpin_abstract_eval(aval): + if not isinstance(aval, AbstractLinVal): raise TypeError(aval) + return aval.inner_aval + +def _lower_unpin(ctx, x_op): + out_aval, = ctx.avals_out + return mlir.custom_call( + "Unpin", + operands=mlir.flatten_ir_values([x_op]), + result_types=[mlir.aval_to_ir_type(out_aval)], + ).results +mlir.register_lowering(unpin_p, _lower_unpin) + + +def _linval_to_mlir_type(a): + return mlir.ir.MemRefType.get(a.shape, mlir.dtype_to_ir_type(a.dtype), + memory_space=a.memory_space) +mlir.ir_type_handlers[AbstractLinVal] = _linval_to_mlir_type diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index 057242f4c1ac..3c7c09684f99 100644 --- a/jax/_src/state/types.py +++ b/jax/_src/state/types.py @@ -18,7 +18,8 @@ from collections.abc import Sequence import dataclasses import math -from typing import Any, Callable, Protocol, Union +from typing import Any, Protocol, Union +from collections.abc import Callable from jax._src import core from jax._src import dtypes @@ -75,6 +76,10 @@ class AccumEffect(RefEffect): name: str = "Accum" effects.control_flow_allowed_effects.add_type(RefEffect) +effects.custom_derivatives_allowed_effects.add_type(RefEffect) +effects.custom_derivatives_allowed_effects.add_type(core.InternalMutableArrayEffect) +effects.partial_eval_kept_effects.add_type(RefEffect) +effects.remat_allowed_effects.add_type(RefEffect) StateEffect = Union[ReadEffect, WriteEffect, AccumEffect] @@ -125,6 +130,10 @@ def transform_sharding(self, sharding): return sharding raise NotImplementedError + def pretty_print(self, context: core.JaxprPpContext) -> pp.Doc: + del context # Unused. + return pp.text(f"{{bitcast({self.dtype}{list(self.shape)}])}}") + @tree_util.register_pytree_node_class @dataclasses.dataclass(frozen=True) @@ -138,6 +147,18 @@ def from_ref_new_shape(cls, ref_or_view: Any, *shape: Any) -> RefReshaper: shape = shape[0] if not shape: raise ValueError("Cannot reshape ref to empty shape") + if any(s == -1 for s in shape): + num_elements = math.prod(ref_or_view.shape) + defined_dims = [d for d in shape if d != -1] + if len(defined_dims) != len(shape) - 1: + raise ValueError(f"At most one dimension can be -1, but got {shape}") + if num_elements % math.prod(defined_dims): + raise ValueError( + f"Specified dims {shape} do not evenly divide the size of the " + f"ref ({num_elements})." + ) + remaining_dim = num_elements // math.prod(defined_dims) + shape = tuple(d if d != -1 else remaining_dim for d in shape) if np.prod(shape) != np.prod(ref_or_view.shape): raise TypeError( f"cannot reshape ref of shape {ref_or_view.shape} into shape {shape}" @@ -168,7 +189,7 @@ def transform_shape( del shape # Unused return self.shape - def transform_dtype(self, dtype): + def transform_dtype(self, dtype: DTypeLike | None) -> DTypeLike | None: del dtype # Unused return self.dtype @@ -178,6 +199,49 @@ def transform_sharding(self, sharding): return sharding raise NotImplementedError + def pretty_print(self, context: core.JaxprPpContext) -> pp.Doc: + del context # Unused. + return pp.text(f"{{reshape({self.dtype}{list(self.shape)})}}") + + +@tree_util.register_dataclass +@dataclasses.dataclass(frozen=True) +class RefTransposer: + permutation: tuple[int, ...] = dataclasses.field(metadata=dict(static=True)) + + @classmethod + def from_ref_new_permutation( + cls, ref_or_view: Any, *perm: int + ) -> RefTransposer: + if len(perm) == 1 and isinstance(perm[0], tuple): + perm = perm[0] + if len(perm) != ref_or_view.ndim: + raise ValueError( + f"Permutation {perm} does not match the rank of the ref" + f" ({ref_or_view.ndim})" + ) + return cls(perm) + + def transform_shape( + self, shape: tuple[int | Array, ...] | None + ) -> tuple[int | Array, ...] | None: + if shape is None: + return None + return tuple(shape[i] for i in self.permutation) + + def transform_dtype(self, dtype): + return dtype + + def transform_sharding(self, sharding): + # If there are no explicit axes, do nothing. + if all(p is None for p in sharding.spec): + return sharding + raise NotImplementedError + + def pretty_print(self, context: core.JaxprPpContext) -> pp.Doc: + del context # Unused. + return pp.text(f"{{transpose({list(self.permutation)})}}") + class Transform(Protocol): @@ -191,9 +255,7 @@ def transform_shape( """ return shape - def transform_dtype( - self, dtype: DTypeLike | None - ) -> DTypeLike | None: + def transform_dtype(self, dtype: DTypeLike | None) -> DTypeLike | None: """Transform the dtype. Can return None if the input dtype is not known, but must return a concrete @@ -205,6 +267,9 @@ def transform_sharding(self, sharding): if all(p is None for p in sharding.spec): return sharding # no explicit axes raise NotImplementedError + def pretty_print(self, context: core.JaxprPpContext) -> pp.Doc: + return pp.text(f"{{{self}}}") + @dataclasses.dataclass class RefIndexer: @@ -243,7 +308,7 @@ def shape(self) -> tuple[int | Array, ...]: if not unprocessed: return shape # If there are any unprocessed transforms left, we apply them to the shape - # we've found previuously. + # we've found previously. for t in self.transforms[-unprocessed:]: shape = t.transform_shape(shape) assert shape is not None @@ -266,6 +331,10 @@ def dtype(self): assert dtype is not None return dtype + ndim = property(lambda self: len(self.shape)) + size = property(lambda self: math.prod(self.shape)) + T = property(lambda self: self.transpose(*reversed(range(self.ndim)))) + @property def at(self) -> RefIndexer: return RefIndexer(self) @@ -282,6 +351,10 @@ def reshape(self, *shape): (*self.transforms, RefReshaper.from_ref_new_shape(self, *shape)), ) + def transpose(self, *permutation): + transposer = RefTransposer.from_ref_new_permutation(self, *permutation) + return TransformedRef(self.ref, (*self.transforms, transposer)) + def set(self, value, idx=()): from jax._src.state.primitives import ref_set # pytype: disable=import-error return ref_set(self, idx, value) @@ -306,12 +379,51 @@ def __setitem__(self, slc, value): return ref_set(self, slc, value) +def get_transforms_shape( + ts: Sequence[Transform], shape: tuple[int | Array, ...] +) -> tuple[int | Array, ...]: + for t in ts: + shape = t.transform_shape(shape) # type: ignore + assert shape is not None + return shape + + # We need an aval for `Ref`s so we can represent `get` and `swap` in Jaxprs. class AbstractRef(core.AbstractValue): - __slots__ = ["inner_aval"] + """Abstract mutable array reference. + + Refer to the `Ref guide`_ for more information. + + .. _Ref guide: https://docs.jax.dev/en/latest/array_refs.html + """ + __slots__ = ["inner_aval", "memory_space", "kind"] - def __init__(self, inner_aval: core.AbstractValue): + def __init__(self, inner_aval: core.AbstractValue, memory_space: Any = None, + kind: Any = None): self.inner_aval = inner_aval + self.memory_space = memory_space + self.kind = kind + + @property + def is_high(self): + return self.inner_aval.is_high + + def lo_ty(self): + return [ + AbstractRef(x, memory_space=self.memory_space) + for x in self.inner_aval.lo_ty() + ] + + def lower_val(self, ref): + if not self.is_high: + return [ref] + return self.inner_aval.lower_val(ref._refs) # type: ignore + + def raise_val(self, *vals): + if not self.is_high: + ref, = vals + return ref + return core.Ref(self, self.inner_aval.raise_val(*vals)) # type: ignore @property def weak_type(self) -> bool: @@ -320,23 +432,30 @@ def weak_type(self) -> bool: return self.inner_aval.weak_type def update_weak_type(self, weak_type): - return AbstractRef(self.inner_aval.update_weak_type(weak_type)) + return self.update(inner_aval=self.inner_aval.update_weak_type(weak_type)) - def update(self, inner_aval=None): - if inner_aval is None: - return AbstractRef(self.inner_aval) - return AbstractRef(inner_aval) + def update(self, inner_aval=None, memory_space=None, kind=None): + inner_aval = self.inner_aval if inner_aval is None else inner_aval + memory_space = self.memory_space if memory_space is None else memory_space + kind = self.kind if kind is None else kind + return AbstractRef(inner_aval, memory_space, kind) ndim = property(lambda self: len(self.shape)) size = property(lambda self: math.prod(self.shape)) + def _len(self, ignored_tracer) -> int: + try: + return self.shape[0] + except IndexError as err: + raise TypeError("len() of unsized object") from err # same as numpy error + @property def shape(self): try: return self.inner_aval.shape # pytype: disable=attribute-error except AttributeError: raise AttributeError( - f"`Ref{{{self.inner_aval.str_short()}}} has no `shape`." + f"{self!r} has no `shape`." ) from None @property @@ -345,7 +464,7 @@ def dtype(self): return self.inner_aval.dtype # pytype: disable=attribute-error except AttributeError: raise AttributeError( - f"`Ref{{{self.inner_aval.str_short()}}} has no `dtype`." + f"{self!r} has no `dtype`." ) from None @property @@ -354,7 +473,16 @@ def sharding(self): return self.inner_aval.sharding # pytype: disable=attribute-error except AttributeError: raise AttributeError( - f"`Ref{{{self.inner_aval.str_short()}}} has no `sharding`." + f"{self!r} has no `sharding`." + ) from None + + @property + def vma(self): + try: + return self.inner_aval.vma # pytype: disable=attribute-error + except AttributeError: + raise AttributeError( + f"{self!r} has no `vma`." ) from None @core.aval_property @@ -363,11 +491,19 @@ def at(self): @core.aval_method def bitcast(self, dtype): - return TransformedRef(self, (RefBitcaster.from_ref_new_dtype(self, dtype),)) + return TransformedRef(self, ()).bitcast(dtype) @core.aval_method def reshape(self, *shape): - return TransformedRef(self, (RefReshaper.from_ref_new_shape(self, *shape),)) + return TransformedRef(self, ()).reshape(*shape) + + @core.aval_method + def transpose(self, *permutation): + return TransformedRef(self, ()).transpose(*permutation) + + @core.aval_property + def T(self): + return TransformedRef(self, ()).T @core.aval_method @staticmethod @@ -387,6 +523,12 @@ def set(tracer, value, idx=()): from jax._src.state.primitives import ref_set # pytype: disable=import-error return ref_set(tracer, idx, value) + @core.aval_method + @staticmethod + def addupdate(tracer, value, idx=()): + from jax._src.state.primitives import ref_addupdate # pytype: disable=import-error + ref_addupdate(tracer, idx, value) + def _getitem(self, tracer, idx) -> Array: from jax._src.state.primitives import ref_get # pytype: disable=import-error return ref_get(tracer, idx) @@ -395,24 +537,46 @@ def _setitem(self, tracer, idx, value) -> None: from jax._src.state.primitives import ref_set # pytype: disable=import-error return ref_set(tracer, idx, value) + def _addupdate(self, tracer, idx, value): + from jax._src.state.primitives import ref_addupdate # pytype: disable=import-error + ref_addupdate(tracer, idx, value) + + def str_short(self, short_dtypes=False, mesh_axis_types=False) -> str: + inner_aval_str = self.inner_aval.str_short( + short_dtypes=short_dtypes, + mesh_axis_types=mesh_axis_types, + ) + if self.memory_space is not None: + return f'Ref<{self.memory_space}>{{{inner_aval_str}}}' + return f'Ref{{{inner_aval_str}}}' + def __repr__(self) -> str: - return f'Ref{{{self.inner_aval.str_short()}}}' + return self.str_short() + __str__ = __repr__ def to_tangent_aval(self): - return AbstractRef(self.inner_aval.to_tangent_aval()) + return AbstractRef(self.inner_aval.to_tangent_aval(), self.memory_space, + kind=self.kind) + + def to_cotangent_aval(self): + return AbstractRef(self.inner_aval.to_cotangent_aval(), self.memory_space, + kind=self.kind) def __eq__(self, other): - return (type(self) is type(other) and self.inner_aval == other.inner_aval) + return (type(self) is type(other) and self.inner_aval == other.inner_aval + and self.memory_space == other.memory_space) def __hash__(self): - return hash((self.__class__, self.inner_aval)) + return hash((self.__class__, self.inner_aval, self.memory_space)) def _map_ref(size, axis, ref_aval): - return AbstractRef(core.mapped_aval(size, axis, ref_aval.inner_aval)) + return AbstractRef(core.mapped_aval(size, axis, ref_aval.inner_aval), + ref_aval.memory_space, ref_aval.kind) def _unmap_ref(size, axis, explicit_mesh_axis, ref_aval): return AbstractRef(core.unmapped_aval( - size, axis, ref_aval.inner_aval, explicit_mesh_axis)) + size, axis, ref_aval.inner_aval, explicit_mesh_axis), + ref_aval.memory_space, ref_aval.kind) core.aval_mapping_handlers[AbstractRef] = (_map_ref, _unmap_ref) @@ -427,20 +591,13 @@ def shaped_array_ref( shape: tuple[int, ...], dtype, weak_type: bool = False) -> AbstractRef: return AbstractRef(core.ShapedArray(shape, dtype, weak_type=weak_type)) -def _shard_ref(mesh, auto, names, ref_aval: AbstractRef): - del mesh - if names: - # Can't actually shard a ref, can only close over it. - raise NotImplementedError("Can't shard a Ref.") - return ref_aval +def _shard_ref(mesh, auto, check_rep, names, ref_aval: AbstractRef): + aval = core.shard_aval(mesh, auto, check_rep, names, ref_aval.inner_aval) + return AbstractRef(aval) core.shard_aval_handlers[AbstractRef] = _shard_ref -def _unshard_ref(mesh, names, ref_aval: AbstractRef): - del mesh - if names: - # Can't actually shard a ref, can only close over it. - raise NotImplementedError("Can't unshard a Ref") - return ref_aval +def _unshard_ref(mesh, check_rep, names, ref_aval: AbstractRef): + raise TypeError("can't unshard a ref") core.unshard_aval_handlers[AbstractRef] = _unshard_ref @@ -465,3 +622,14 @@ def get_ref_aval_from_value(x: Any): if type(x) in _ref_type_aval_mappings: return _ref_type_aval_mappings[type(x)](x) return _default_value_to_ref_aval(x) + +# === pinned, chained LinearVals === + +@dataclasses.dataclass(frozen=True) +class AbstractLinVal(core.AbstractValue): + inner_aval: core.AbstractValue + memory_space: Any = None + + shape = property(lambda self: self.inner_aval.shape) # type: ignore + dtype = property(lambda self: self.inner_aval.dtype) # type: ignore + ndim = property(lambda self: self.inner_aval.ndim) # type: ignore diff --git a/jax/_src/state/utils.py b/jax/_src/state/utils.py index 2dd57dcde0ca..2af8a800ae47 100644 --- a/jax/_src/state/utils.py +++ b/jax/_src/state/utils.py @@ -14,15 +14,16 @@ """Utilities for tracing stateful functions.""" from functools import partial -from typing import Callable +from collections.abc import Callable -import jax +from jax._src import api from jax._src import core from jax._src import dtypes from jax._src import linear_util as lu from jax._src.interpreters import partial_eval as pe -from jax._src.state import AbstractRef +from jax._src.lax import lax from jax._src.state.primitives import ref_get +from jax._src.state.types import AbstractRef from jax._src.typing import DTypeLike from jax._src.util import safe_map, safe_zip, split_list @@ -72,8 +73,9 @@ def _hoist(*consts_args): ] return core.eval_jaxpr(jaxpr, all_consts, *args0, *args1) - hoisted_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(_hoist, debug_info=jaxpr.debug_info), in_avals) + hoisted_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(_hoist, debug_info=jaxpr.debug_info.with_unknown_names()), + in_avals) assert not consts, "All consts should have been converted to refs" return hoisted_jaxpr @@ -85,15 +87,9 @@ def val_to_ref_aval(x) -> AbstractRef: return AbstractRef(aval) -def dtype_bitwidth(dtype: DTypeLike) -> int: - if dtypes.isdtype(dtype, "integral"): - return dtypes.iinfo(dtype).bits - return dtypes.dtype(dtype).itemsize * 8 - - def bitcast(x, dtype: DTypeLike): - x_bitwidth = dtype_bitwidth(x.dtype) - y_bitwidth = dtype_bitwidth(dtype) + x_bitwidth = dtypes.itemsize_bits(x.dtype) + y_bitwidth = dtypes.itemsize_bits(dtype) shape = list(x.shape) if x_bitwidth != y_bitwidth: if len(shape) < 2: @@ -112,7 +108,7 @@ def bitcast(x, dtype: DTypeLike): x = x.reshape(*x.shape[:-2], x.shape[-2] // ratio, ratio, -1).swapaxes( -1, -2 ) - y = jax.lax.bitcast_convert_type(x, dtype) + y = lax.bitcast_convert_type(x, dtype) if x_bitwidth > y_bitwidth: y = y.swapaxes(-1, -2).reshape(shape) return y @@ -120,4 +116,4 @@ def bitcast(x, dtype: DTypeLike): def eval_bitcast_shape(x, dtype: DTypeLike): f = partial(bitcast, dtype=dtype) - return jax.eval_shape(f, jax.ShapeDtypeStruct(x.shape, x.dtype)).shape + return api.eval_shape(f, api.ShapeDtypeStruct(x.shape, x.dtype)).shape diff --git a/jax/_src/stateful_rng.py b/jax/_src/stateful_rng.py new file mode 100644 index 000000000000..a09e3da8b9bb --- /dev/null +++ b/jax/_src/stateful_rng.py @@ -0,0 +1,302 @@ +# Copyright 2026 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Stateful, implicitly-updated PRNG implementation based on mutable refs. +""" +from __future__ import annotations + +import dataclasses +import operator +from collections.abc import Sequence + +from jax._src import api_util +from jax._src import core +from jax._src import dtypes +from jax._src import numpy as jnp +from jax._src import random +from jax._src import ref +from jax._src import tree_util +from jax._src import typing +from jax._src.state import primitives as ref_primitives +from jax._src.state import types as state_types +from jax._src.typing import Array, ArrayLike, DTypeLike + +import numpy as np + + +def _canonicalize_size(size: int | Sequence[int] | None, *args: ArrayLike) -> tuple[int, ...]: + if size is None: + return np.broadcast_shapes(*(np.shape(arg) for arg in args)) + elif isinstance(size, (int, np.number)): + return (operator.index(size),) + else: + return tuple(map(operator.index, size)) + + +@tree_util.register_dataclass +@dataclasses.dataclass(frozen=True) +class StatefulPRNG: + """Stateful JAX random generator. + + This should be instantiated using the :func:`jax.experimental.random.stateful_rng` function. + + Attributes: + _base_key: a typed JAX PRNG key object (see :func:`jax.random.key`). + _counter: a scalar integer wrapped in a :class:`jax.Ref`. + + Examples: + + >>> from jax.experimental import random + >>> rng = random.stateful_rng(42) + >>> rng + StatefulPRNG(_base_key=Array((), dtype=key) overlaying: + [ 0 42], _counter=Ref(0, dtype=int32, weak_type=True)) + """ + _base_key: Array + _counter: core.Ref + + def __post_init__(self): + if self._base_key is api_util.SENTINEL: + return + if not (isinstance(self._base_key, Array) + and dtypes.issubdtype(self._base_key.dtype, dtypes.prng_key)): + raise ValueError(f"Expected base_key to be a typed PRNG key; got {self._base_key}") + + # TODO(jakevdp): how to validate a traced mutable array? + if not (isinstance(self._counter, core.Ref) or + (isinstance(self._counter, core.Tracer) + and isinstance(self._counter.aval, state_types.AbstractRef))): + raise ValueError(f"Expected counter to be a scalar integer ref; got {self._counter}") + + def key(self, shape: int | Sequence[int] = ()) -> Array: + """Generate a new JAX PRNGKey, updating the internal state. + + Args: + shape: an optional shape if returning multiple keys. + + Returns: + A new, independent PRNG key with the same impl/dtype as + ``self._base_key``. + + Examples: + >>> from jax.experimental import random + >>> rng = random.stateful_rng(0) + >>> rng.key() + Array((), dtype=key) overlaying: + [1797259609 2579123966] + >>> rng.key() + Array((), dtype=key) overlaying: + [ 928981903 3453687069] + """ + if self._base_key.shape: + # TODO(jakevdp): better error message. + raise ValueError("cannot operate on split stateful generator") + + key = random.fold_in(self._base_key, ref_primitives.ref_get(self._counter)) + ref_primitives.ref_addupdate(self._counter, ..., 1) # pytype: disable=wrong-arg-types # pytype bug? + shape_tuple = _canonicalize_size(shape) + return random.split(key, shape_tuple) if shape_tuple else key + + def random( + self, + size: int | Sequence[int] | None = None, + dtype: DTypeLike = float, + ): + """Return random floats in the half-open interval [0.0, 1.0).""" + # TODO(jakevdp): write docstring + return random.uniform(self.key(), shape=_canonicalize_size(size), dtype=dtype) + + + def uniform( + self, + low: ArrayLike = 0, + high: ArrayLike = 1, + size: int | Sequence[int] | None = None, + *, + dtype: DTypeLike = float, + ) -> Array: + """Draw uniformly distributed pseudorandom values.""" + # TODO(jakevdp): write docstring + return random.uniform(self.key(), _canonicalize_size(size, low, high), + minval=low, maxval=high, dtype=dtype) + + def normal( + self, + loc: ArrayLike = 0, + scale: ArrayLike = 1, + size: int | Sequence[int] | None = None, + *, + dtype: DTypeLike = float, + ) -> Array: + """Draw normally-distributed pseudorandom values.""" + # TODO(jakevdp): write docstring + norm = random.normal(self.key(), _canonicalize_size(size, loc, scale), dtype) + return (jnp.asarray(loc) + jnp.asarray(scale) * norm).astype(dtype) + + def integers( + self, + low: ArrayLike, + high: ArrayLike | None = None, + size: int | Sequence[int] | None = None, + *, + dtype: DTypeLike = int, + ) -> Array: + """Draw pseudorandom integers.""" + # TODO(jakevdp): write docstring + if high is None: + low, high = 0, low + return random.randint(self.key(), _canonicalize_size(size, low, high), + minval=low, maxval=high, dtype=dtype) + + def split(self, num: int | Sequence[int]) -> StatefulPRNG: + """Create independent child generators suitable for use in :func:`jax.vmap`. + + Args: + num: integer or sequence of integers specifying the split shape + + Returns: + a single StatefulPRNG object with split contents, suitable for use + with :func:`jax.vmap` + + Examples: + >>> import jax + >>> from jax.experimental import random + >>> rng = random.stateful_rng(123) + >>> x = jax.numpy.zeros(3) + >>> def f(rng, x): + ... return x + rng.uniform() + >>> jax.vmap(f)(rng.split(3), x) + Array([0.35525954, 0.21937883, 0.5336956 ], dtype=float32) + + See also: + - :meth:`jax.experimental.random.StatefulPRNG.spawn`: This is similar to ``split``, but + returns a Python list of :class:`StatefulPRNG`` objects. + """ + return StatefulPRNG( + _base_key=self.key(num), + _counter=ref.new_ref(jnp.zeros(num, dtype=int)) + ) + + def spawn(self, n_children: int) -> list['StatefulPRNG']: + """Create a list of independent child generators. + + Args: + n_children: non-negative integer. + + Returns: + A list of length ``n_children`` containing new independent ``StatefulPRNG`` instances + spawned from the original instance. + + Examples: + >>> from jax.experimental import random + >>> rng = random.stateful_rng(123) + >>> child_rngs = rng.spawn(2) + >>> [r.integers(0, 10, 2) for r in child_rngs] + [Array([4, 5], dtype=int32), Array([2, 1], dtype=int32)] + + See also: + - :meth:`jax.experimental.random.StatefulPRNG.split`: this is similar to spawn, but returns + a single mapped :class:`jax.experimental.random.StatefulPRNG`` which can be passed to + :func:`jax.vmap`. + """ + return [self.__class__(key, ref.new_ref(0)) for key in self.key(n_children)] + + +def stateful_rng(seed: typing.ArrayLike | None = None, *, + impl: random.PRNGSpecDesc | None = None) -> StatefulPRNG: + """ + Experimental stateful RNG with implicitly-updated state. + + This implements a stateful PRNG API similar to :func:`numpy.random.default_rng`. + It is compatible with JAX transformations like :func:`~jax.jit` and others, + with a few exceptions mentioned in the Notes below. + + .. note:: + + This stateful PRNG API is a convenience wrapper around JAX's classic + stateless, explicitly updated PRNG, described in :mod:`jax.random`. + For performance-critical applications, it is recommended to use + :func:`jax.random.key` with explicit random state semantics. + + For a discussion of design considerations for this API, refer to + :ref:`stateful-randomness-jep`. + + Args: + seed: an optional 64- or 32-bit integer used as the value of the key. + This must be specified if the generator is instantiated within transformed + code; when used at the top level of the program, it may be omitted in + which case the RNG will be seeded using the default NumPy seeding. + impl: optional string specifying the PRNG implementation (e.g. + ``'threefry2x32'``) + + Returns: + A :class:`~jax.experimental.random.StatefulPRNG` object, with methods for generating + random values. + + Notes: + The :class:`~jax.experimental.random.StatefulPRNG` object created by this method uses + :func:`~jax.Ref` objects to allow implicit updates of state, and thus + inherits some of its limitiations. For example: + + - :class:`StatefulPRNG` objects cannot be among the return values of functions + wrapped in JIT or other JAX transformations. This means in particular + they cannot be used as `carry` values for :func:`jax.lax.scan`, + :func:`jax.lax.while_loop`, and other JAX control flow. + - :class:`StatefulPRNG` objects cannot be used together with + :func:`jax.checkpoint` or :func:`jax.remat`; in these cases it's best to + use the :meth:`StatefulPRNG.key` method to produce a standard JAX PRNG key. + + Examples: + >>> from jax.experimental import random + >>> rng = random.stateful_rng(42) + + Repeated draws implicitly update the key: + + >>> rng.uniform() + Array(0.5302608, dtype=float32) + >>> rng.uniform() + Array(0.72766423, dtype=float32) + + This also works under transformations like :func:`jax.jit`: + + >>> import jax + >>> jit_uniform = jax.jit(rng.uniform) + >>> jit_uniform() + Array(0.6672406, dtype=float32) + >>> jit_uniform() + Array(0.3890121, dtype=float32) + + Keys can be generated directly if desired: + + >>> rng.key() + Array((), dtype=key) overlaying: + [2954079971 3276725750] + >>> rng.key() + Array((), dtype=key) overlaying: + [2765691542 824333390] + """ + if seed is None: + if not core.trace_ctx.is_top_level(): + raise TypeError( + "When used within transformed code, jax.experimental.random.stateful_rng()" + " requires an explicit seed to be set.") + entropy = np.random.SeedSequence().entropy + assert isinstance(entropy, int) + seed = np.int64(entropy & np.iinfo(np.int64).max) + assert seed is not None + return StatefulPRNG( + _base_key=random.key(seed, impl=impl), + _counter=ref.new_ref(0) + ) diff --git a/jax/_src/test_loader.py b/jax/_src/test_loader.py new file mode 100644 index 000000000000..e29cec0d7481 --- /dev/null +++ b/jax/_src/test_loader.py @@ -0,0 +1,231 @@ +# Copyright 2018 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Contains a custom unittest loader and test suite. + +Implements: +- A test filter based on the JAX_TEST_TARGETS and JAX_EXCLUDE_TEST_TARGETS + environment variables. +- A test suite that runs tests in parallel using threads if JAX_TEST_NUM_THREADS + is >= 1. +- Test decorators that mark a test case or test class as thread-hostile. +""" + +from __future__ import annotations + +from collections.abc import Callable +from concurrent.futures import ThreadPoolExecutor +from contextlib import contextmanager +import logging +import os +import re +import threading +import time +import unittest + +from absl.testing import absltest +from jax._src import config +from jax._src import test_warning_util +from jax._src import util + +logger = logging.getLogger(__name__) + + +_TEST_TARGETS = config.string_flag( + 'test_targets', os.getenv('JAX_TEST_TARGETS', ''), + 'Regular expression specifying which tests to run, called via re.search on ' + 'the test name. If empty or unspecified, run all tests.' +) + +_EXCLUDE_TEST_TARGETS = config.string_flag( + 'exclude_test_targets', os.getenv('JAX_EXCLUDE_TEST_TARGETS', ''), + 'Regular expression specifying which tests NOT to run, called via re.search ' + 'on the test name. If empty or unspecified, run all tests.' +) + +TEST_NUM_THREADS = config.int_flag( + 'jax_test_num_threads', int(os.getenv('JAX_TEST_NUM_THREADS', '0')), + help='Number of threads to use for running tests. 0 means run everything ' + 'in the main thread. Using > 1 thread is experimental.' +) + +# We use a reader-writer lock to protect test execution. Tests that may run in +# parallel acquire a read lock; tests that are not thread-safe acquire a write +# lock. +_test_rwlock = util.Mutex() + +def _run_one_test(test: unittest.TestCase, result: ThreadSafeTestResult): + if getattr(test.__class__, "thread_hostile", False): + _test_rwlock.writer_lock() + try: + test(result) # type: ignore + finally: + _test_rwlock.writer_unlock() + else: + _test_rwlock.reader_lock() + try: + test(result) # type: ignore + finally: + _test_rwlock.reader_unlock() + + +@contextmanager +def thread_unsafe_test(condition: bool = True): + """Decorator for tests that are not thread-safe. + + Args: + condition: If True, mark the test as thread-unsafe. If False, the test + runs normally without acquiring the write lock. Defaults to True. + + Note: this decorator (naturally) only applies to what it wraps, not to, say, + code in separate setUp() or tearDown() methods. + """ + if TEST_NUM_THREADS.value <= 0 or not condition: + yield + return + + _test_rwlock.assert_reader_held() + _test_rwlock.reader_unlock() + _test_rwlock.writer_lock() + try: + yield + finally: + _test_rwlock.writer_unlock() + _test_rwlock.reader_lock() + + +def thread_unsafe_test_class(condition: bool = True): + """Decorator that marks a TestCase class as thread-hostile. + + Args: + condition: If True, mark the test class as thread-hostile. If False, the + test class runs normally. Defaults to True. + """ + def f(klass): + assert issubclass(klass, unittest.TestCase), type(klass) + klass.thread_hostile = condition + return klass + return f + + +class ThreadSafeTestResult: + """ + Wraps a TestResult to make it thread safe. + + We do this by accumulating API calls and applying them in a batch under a + lock at the conclusion of each test case. + + We duck type instead of inheriting from TestResult because we aren't actually + a perfect implementation of TestResult, and would rather get a loud error + for things we haven't implemented. + """ + def __init__(self, lock: threading.Lock, result: unittest.TestResult): + self.lock = lock + self.test_result = result + self.actions: list[Callable[[], None]] = [] + + def startTest(self, test: unittest.TestCase): + logger.info("Test start: %s", test.id()) + self.start_time = time.time() + + def stopTest(self, test: unittest.TestCase): + logger.info("Test stop: %s", test.id()) + stop_time = time.time() + with self.lock: + # If test_result is an ABSL _TextAndXMLTestResult we override how it gets + # the time. This affects the timing that shows up in the XML output + # consumed by CI. + time_getter = getattr(self.test_result, "time_getter", None) + try: + self.test_result.time_getter = lambda: self.start_time + self.test_result.startTest(test) + for callback in self.actions: + callback() + self.test_result.time_getter = lambda: stop_time + self.test_result.stopTest(test) + finally: + if time_getter is not None: + self.test_result.time_getter = time_getter + + def addSuccess(self, test: unittest.TestCase): + self.actions.append(lambda: self.test_result.addSuccess(test)) + + def addSkip(self, test: unittest.TestCase, reason: str): + self.actions.append(lambda: self.test_result.addSkip(test, reason)) + + def addError(self, test: unittest.TestCase, err): + self.actions.append(lambda: self.test_result.addError(test, err)) + + def addFailure(self, test: unittest.TestCase, err): + self.actions.append(lambda: self.test_result.addFailure(test, err)) + + def addExpectedFailure(self, test: unittest.TestCase, err): + self.actions.append(lambda: self.test_result.addExpectedFailure(test, err)) + + def addDuration(self, test: unittest.TestCase, elapsed): + self.actions.append(lambda: self.test_result.addDuration(test, elapsed)) + + +class JaxTestSuite(unittest.TestSuite): + """Runs tests in parallel using threads if TEST_NUM_THREADS is > 1. + + Caution: this test suite does not run setUpClass or setUpModule methods if + thread parallelism is enabled. + """ + + def __init__(self, suite: unittest.TestSuite): + super().__init__(list(suite)) + + def run(self, result: unittest.TestResult, debug: bool = False) -> unittest.TestResult: + if TEST_NUM_THREADS.value <= 0: + return super().run(result) + + test_warning_util.install_threadsafe_warning_handlers() + + executor = ThreadPoolExecutor(TEST_NUM_THREADS.value) + lock = threading.Lock() + futures = [] + + def run_test(test): + """Recursively runs tests in a test suite or test case.""" + if isinstance(test, unittest.TestSuite): + for subtest in test: + run_test(subtest) + else: + test_result = ThreadSafeTestResult(lock, result) + futures.append(executor.submit(_run_one_test, test, test_result)) + + with executor: + run_test(self) + for future in futures: + future.result() + + return result + + +class JaxTestLoader(absltest.TestLoader): + suiteClass = JaxTestSuite + + def getTestCaseNames(self, testCaseClass): + names = super().getTestCaseNames(testCaseClass) + if _TEST_TARGETS.value: + pattern = re.compile(_TEST_TARGETS.value) + names = [name for name in names + if pattern.search(f"{testCaseClass.__name__}.{name}")] + if _EXCLUDE_TEST_TARGETS.value: + pattern = re.compile(_EXCLUDE_TEST_TARGETS.value) + names = [name for name in names + if not pattern.search(f"{testCaseClass.__name__}.{name}")] + return names diff --git a/jax/_src/test_multiprocess.py b/jax/_src/test_multiprocess.py new file mode 100644 index 000000000000..f3bb28d4c443 --- /dev/null +++ b/jax/_src/test_multiprocess.py @@ -0,0 +1,454 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helper for running multi-process tests.""" + +import functools +import os +import pathlib +import re +import signal +import subprocess +import sys +import time + +from absl import app +import absl.flags +from absl.testing import absltest +from absl.testing import parameterized + +from jax._src import distributed +from jax._src import xla_bridge as xb +from jax._src import test_util as jtu +from jax._src.config import config +from jax._src.lib import cuda_versions +from jax._src.lib import _jax + +try: + import portpicker # pytype: disable=import-error +except ImportError: + portpicker = None + +NUM_PROCESSES = absl.flags.DEFINE_integer( + "num_processes", None, "Number of processes to use." +) + +_GPUS_PER_PROCESS = absl.flags.DEFINE_integer( + "gpus_per_process", + 0, + "Number of GPUs per worker process.", +) + +_TPU_CHIPS_PER_PROCESS = absl.flags.DEFINE_integer( + "tpu_chips_per_process", + 0, + "Number of TPU chips per worker process.", +) + +CPU_COLLECTIVES_IMPLEMENTATION = absl.flags.DEFINE_string( + "cpu_collectives_implementation", + "", + "CPU collectives implementation to use. Uses default if empty.", +) + +EXTRA_TEST_ARGS = absl.flags.DEFINE_multi_string( + "extra_test_args", [], "Extra flags to pass to worker process." +) + +# For internal use. +MULTIPROCESS_TEST_WORKER_ID = absl.flags.DEFINE_integer( + "multiprocess_test_worker_id", + -1, + "Worker id. Set by main test process; should not be set by users.", +) + +_MULTIPROCESS_TEST_CONTROLLER_ADDRESS = absl.flags.DEFINE_string( + "multiprocess_test_controller_address", + "", + "Address of the JAX controller. Set by the main test process; should not be" + " set by users.", +) + +_DEVICE_IDS = absl.flags.DEFINE_list( + "device_ids", + None, + "List of device ids to use. Set by main test process; should not be set by" + " users.", +) + +_ENABLE_MEGASCALE = absl.flags.DEFINE_bool( + "enable_megascale", False, "If true, enable Megascale runtime." +) + +_HEARTBEAT_TIMEOUT = absl.flags.DEFINE_integer( + "heartbeat_timeout", + 5, + "Timeout in seconds for heartbeat checks. Set to a higher number when" + " running under sanitizers.", +) + +_SHUTDOWN_TIMEOUT = absl.flags.DEFINE_integer( + "shutdown_timeout", + 15, + "JAX shutdown timeout duration in seconds for each subprocess worker. If " + "your test is timing out, try increasing this value.", +) + +_BARRIER_TIMEOUT = absl.flags.DEFINE_integer( + "barrier_timeout", + 10, + "Barrier timeout in seconds. Set to a higher number when running under" + " sanitizers.", +) + +_INITIALIZATION_TIMEOUT = absl.flags.DEFINE_integer( + "initialization_timeout", + 10, + "Coordination service initialization timeout in seconds. Set to a higher" + " number when running under sanitizers.", +) + +_DUMP_HLO = absl.flags.DEFINE_bool( + "dump_hlo", + False, + "If true, dump per-process HLO to undeclared outputs. They will show up in" + " sponge artifacts under the directory 'jax_%process_idx%_hlo_dump'.", +) + +expect_failures_with_regex = None + + +def main(shard_main=None): + config.config_with_absl() + app.run(functools.partial(_main, shard_main=shard_main)) + + +class GracefulKiller: + """Add a signal handler that sets a flag if SIGINT or SIGTERM are caught.""" + + # From https://stackoverflow.com/a/31464349 + kill_now = False + + def __init__(self): + signal.signal(signal.SIGINT, self.exit_gracefully) + signal.signal(signal.SIGTERM, self.exit_gracefully) + + def exit_gracefully(self, sig_num, unused_stack_frame): + print(f"Caught signal: {signal.Signals(sig_num).name} ({sig_num})") + self.kill_now = True + + +def _main(argv, shard_main): + # TODO(emilyaf): Enable multiprocess tests on Windows. + if sys.platform == "win32": + print("Multiprocess tests are not supported on Windows.") + return + num_processes = NUM_PROCESSES.value + if MULTIPROCESS_TEST_WORKER_ID.value >= 0: + local_device_ids = _DEVICE_IDS.value + if local_device_ids is not None: + local_device_ids = map(int, local_device_ids) + distributed.initialize( + _MULTIPROCESS_TEST_CONTROLLER_ADDRESS.value, + num_processes=num_processes, + process_id=MULTIPROCESS_TEST_WORKER_ID.value, + local_device_ids=local_device_ids, + heartbeat_timeout_seconds=_HEARTBEAT_TIMEOUT.value, + shutdown_timeout_seconds=_SHUTDOWN_TIMEOUT.value, + initialization_timeout=_INITIALIZATION_TIMEOUT.value, + ) + if shard_main is not None: + return shard_main() + return absltest.main(testLoader=jtu.JaxTestLoader()) + + if not argv[0].endswith(".py"): # Skip the interpreter path if present. + argv = argv[1:] + + if num_processes is None: + raise ValueError("num_processes must be set") + gpus_per_process = _GPUS_PER_PROCESS.value + tpu_chips_per_process = _TPU_CHIPS_PER_PROCESS.value + num_tpu_chips = num_processes * tpu_chips_per_process + if num_tpu_chips == 0: + pass + elif num_tpu_chips == 1: + assert tpu_chips_per_process == 1 + tpu_host_bounds = "1,1,1" + tpu_chips_per_host_bounds = "1,1,1" + elif num_tpu_chips == 4: + if tpu_chips_per_process == 1: + tpu_host_bounds = "2,2,1" + tpu_chips_per_host_bounds = "1,1,1" + elif tpu_chips_per_process == 2: + tpu_host_bounds = "2,1,1" + tpu_chips_per_host_bounds = "1,2,1" + elif tpu_chips_per_process == 4: + tpu_host_bounds = "1,1,1" + tpu_chips_per_host_bounds = "2,2,1" + else: + raise ValueError( + "Invalid number of TPU chips per worker {}".format( + tpu_chips_per_process + ) + ) + elif num_tpu_chips == 8: + if tpu_chips_per_process == 1: + tpu_host_bounds = "4,2,1" + tpu_chips_per_host_bounds = "1,1,1" + elif tpu_chips_per_process == 4: + # Note: this branch assumes we are using 2x4 v6e LitePod, and will not + # work with 4x2 v5e LitePod. + tpu_host_bounds = "1,2,1" + tpu_chips_per_host_bounds = "2,2,1" + elif tpu_chips_per_process == 8: + tpu_host_bounds = "1,1,1" + tpu_chips_per_host_bounds = "2,4,1" + else: + # TODO(phawkins): implement other cases. + raise ValueError( + "Invalid number of TPU chips per worker {}".format( + tpu_chips_per_process + ) + ) + else: + raise ValueError(f"Invalid number of TPU chips {num_tpu_chips}") + + if portpicker is None: + slicebuilder_ports = [10000 + i for i in range(num_processes)] + else: + slicebuilder_ports = [ + portpicker.pick_unused_port() for _ in range(num_processes) + ] + slicebuilder_addresses = ",".join( + f"localhost:{port}" for port in slicebuilder_ports + ) + megascale_coordinator_port = None + + if gpus_per_process > 0: + # Get the number of GPUs visible to this process without initializing the runtime + if cuda_versions is not None: + local_device_count = cuda_versions.cuda_device_count() + if num_processes * gpus_per_process > local_device_count: + print( + f"Cannot run {num_processes} processes with {gpus_per_process} GPU(s) " + f"each on a system with only {local_device_count} local GPU(s), " + f"starting {local_device_count // gpus_per_process} instead - test " + "cases will likely be skipped!" + ) + num_processes = local_device_count // gpus_per_process + + if portpicker is None: + jax_port = 9876 + else: + # TODO(emilyaf): Use a port server if there are flaky port collisions due + # to pick_unused_port() racing among tests. + jax_port = portpicker.pick_unused_port() + subprocesses = [] + output_filenames = [] + output_files = [] + for i in range(num_processes): + device_ids = None + env = os.environ.copy() + + args = [ + "/proc/self/exe", + *argv, + f"--num_processes={num_processes}", + f"--multiprocess_test_worker_id={i}", + f"--multiprocess_test_controller_address=localhost:{jax_port}", + f"--heartbeat_timeout={_HEARTBEAT_TIMEOUT.value}", + f"--shutdown_timeout={_SHUTDOWN_TIMEOUT.value}", + f"--barrier_timeout={_BARRIER_TIMEOUT.value}", + f"--initialization_timeout={_INITIALIZATION_TIMEOUT.value}", + "--logtostderr", + ] + + if num_tpu_chips > 0: + device_ids = range( + i * tpu_chips_per_process, (i + 1) * tpu_chips_per_process) + env["CLOUD_TPU_TASK_ID"] = str(i) + env["TPU_CHIPS_PER_PROCESS_BOUNDS"] = tpu_chips_per_host_bounds + env["TPU_PROCESS_BOUNDS"] = tpu_host_bounds + env["TPU_PROCESS_ADDRESSES"] = slicebuilder_addresses + env["TPU_PROCESS_PORT"] = str(slicebuilder_ports[i]) + env["TPU_VISIBLE_CHIPS"] = ",".join(map(str, device_ids)) + env["ALLOW_MULTIPLE_LIBTPU_LOAD"] = "1" + + if gpus_per_process > 0: + device_ids = range(i * gpus_per_process, (i + 1) * gpus_per_process) + args.append(f"--jax_cuda_visible_devices={','.join(map(str, device_ids))}") + + if device_ids is not None: + args.append(f"--device_ids={','.join(map(str, device_ids))}") + + cpu_collectives_impl = CPU_COLLECTIVES_IMPLEMENTATION.value + if cpu_collectives_impl: + args.append( + f"--jax_cpu_collectives_implementation={cpu_collectives_impl}" + ) + + if _ENABLE_MEGASCALE.value or cpu_collectives_impl == "megascale": + if portpicker is None: + megascale_port = 9877 + else: + megascale_port = portpicker.pick_unused_port() + if megascale_coordinator_port is None: + megascale_coordinator_port = megascale_port + args += [ + f"--megascale_coordinator_address=localhost:{megascale_coordinator_port}", + f"--megascale_port={megascale_port}", + ] + + args += EXTRA_TEST_ARGS.value + + undeclared_outputs = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR", "/tmp") + stdout_name = f"{undeclared_outputs}/jax_{i}_stdout.log" + stderr_name = f"{undeclared_outputs}/jax_{i}_stderr.log" + + if _DUMP_HLO.value: + hlo_dump_path = f"{undeclared_outputs}/jax_{i}_hlo_dump/" + os.makedirs(hlo_dump_path, exist_ok=True) + env["XLA_FLAGS"] = f"--xla_dump_to={hlo_dump_path}" + + stdout = open(stdout_name, "wb") + stderr = open(stderr_name, "wb") + print(f"Launching process {i}:") + print(f" stdout: {stdout_name}") + print(f" stderr: {stderr_name}") + proc = subprocess.Popen(args, env=env, stdout=stdout, stderr=stderr) + subprocesses.append(proc) + output_filenames.append((stdout_name, stderr_name)) + output_files.append((stdout, stderr)) + + print(" All launched, running ".center(80, "="), flush=True) + + # Wait for all the children to finish or for a SIGTERM from bazel. If we get + # SIGTERM, we still want to collect their logs, so kill them and continue. + killer = GracefulKiller() + running_procs = dict(enumerate(subprocesses)) + while not killer.kill_now and running_procs: + time.sleep(0.1) + for i, proc in list(running_procs.items()): + if proc.poll() is not None: + print(f"Process {i} finished.", flush=True) + running_procs.pop(i) + if killer.kill_now and running_procs: + print("Caught termination, terminating remaining children.", flush=True) + + # Send a SIGTERM to each child process, to let it know it should terminate. + for i, proc in running_procs.items(): + proc.terminate() + print(f"Process {i} terminated.", flush=True) + + # We give the child process(es) a few seconds for their own cleanup, and + # keep the rest (up to 15s) for copying the children logs into our own. + time.sleep(5) + + # Send a SIGKILL (a "hard" kill) to each child process. This is CRITICAL: + # without it, this process may end up waiting a long time on the proc.wait() + # below, and never get to saving the children logs, making test timeouts + # very hard to debug. + for i, proc in running_procs.items(): + proc.kill() + print(f"Process {i} killed.") + print("Killed all child processes.", flush=True) + + retvals = [] + stdouts = [] + stderrs = [] + for proc, fds, (stdout, stderr) in zip( + subprocesses, output_files, output_filenames + ): + retvals.append(proc.wait()) + for fd in fds: + fd.close() + stdouts.append(pathlib.Path(stdout).read_text(errors="replace")) + stderrs.append(pathlib.Path(stderr).read_text(errors="replace")) + + print(" All finished ".center(80, "="), flush=True) + + print(" Summary ".center(80, "=")) + for i, (retval, stdout, stderr) in enumerate(zip(retvals, stdouts, stderrs)): + m = re.search(r"Ran \d+ tests? in [\d.]+s\n\n.*", stderr, re.MULTILINE) + result = m.group().replace("\n\n", "; ") if m else "Test crashed?" + print( + f"Process {i}, ret: {retval}, len(stdout): {len(stdout)}, " + f"len(stderr): {len(stderr)}; {result}" + ) + + print(" Detailed logs ".center(80, "=")) + for i, (retval, stdout, stderr) in enumerate(zip(retvals, stdouts, stderrs)): + print(f" Process {i}: return code: {retval} ".center(80, "=")) + if stdout: + print(f" Process {i} stdout ".center(80, "-")) + print(stdout) + if stderr: + print(f" Process {i} stderr ".center(80, "-")) + print(stderr) + + print(" Done detailed logs ".center(80, "="), flush=True) + for i, (retval, stderr) in enumerate(zip(retvals, stderrs)): + if retval != 0: + if expect_failures_with_regex is not None: + assert re.search( + expect_failures_with_regex, stderr + ), f"process {i} failed, expected regex: {expect_failures_with_regex}" + else: + assert retval == 0, f"process {i} failed, return value: {retval}" + + +class MultiProcessTest(parameterized.TestCase): + + def setUp(self): + """Start tests together.""" + super().setUp() + if xb.process_count() == 1: + self.skipTest("Test requires multiple processes.") + assert xb.process_count() == NUM_PROCESSES.value, ( + xb.process_count(), + NUM_PROCESSES.value, + ) + # Make sure all processes are at the same test case. + client = distributed.global_state.client + try: + client.wait_at_barrier( + f"{self._testMethodName}_start", _BARRIER_TIMEOUT.value * 1000) + except _jax.JaxRuntimeError as e: + msg, *_ = e.args + if msg.startswith("DEADLINE_EXCEEDED"): + raise RuntimeError( + f"Init or some test executed earlier than {self._testMethodName} " + "failed. Check logs from earlier tests to debug further. We " + "recommend debugging that specific failed test with " + "`--test_filter` before running the full test suite again." + ) from e + + def tearDown(self): + """End tests together.""" + client = distributed.global_state.client + # Ensure a shared fate for tests where a subset of processes run different + # test assertions (i.e. some processes may pass and some processes fail - + # but the overall test should fail). + try: + client.wait_at_barrier( + f"{self._testMethodName}_end", _BARRIER_TIMEOUT.value * 1000) + except _jax.JaxRuntimeError as e: + msg, *_ = e.args + if msg.startswith("DEADLINE_EXCEEDED"): + raise RuntimeError( + f"Test {self._testMethodName} failed in another process. We " + "recommend debugging that specific failed test with " + "`--test_filter` before running the full test suite again." + ) from e + super().tearDown() diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index c55dc2a560e0..d16c813d83b6 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -16,8 +16,7 @@ from __future__ import annotations import collections -from collections.abc import Callable, Generator, Iterable, Sequence -from concurrent.futures import ThreadPoolExecutor +from collections.abc import Callable, Generator, Iterable, Iterator, Sequence from contextlib import ExitStack, contextmanager import datetime import functools @@ -26,21 +25,18 @@ import logging import math import os +from pathlib import Path import platform import re import sys import tempfile import textwrap import threading -import time from typing import Any, TextIO import unittest import zlib -from absl.testing import absltest from absl.testing import parameterized -import jax -from jax import lax from jax._src import api from jax._src import compilation_cache from jax._src import config @@ -49,22 +45,35 @@ from jax._src import dispatch from jax._src import dtypes as _dtypes from jax._src import lib as _jaxlib +from jax._src import mesh as mesh_lib from jax._src import monitoring +from jax._src import sharding_impls from jax._src import test_warning_util from jax._src import xla_bridge from jax._src import util -from jax._src import mesh as mesh_lib from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm from jax._src.interpreters import mlir +from jax._src.lax import lax +from jax._src.lib import cuda_versions +from jax._src.lib.mlir.dialects import hlo from jax._src.numpy.util import promote_dtypes, promote_dtypes_inexact from jax._src.public_test_util import ( # noqa: F401 _assert_numpy_allclose, _check_dtypes_match, _default_tolerance, _dtype, check_close, check_grads, - check_jvp, check_vjp, default_gradient_tolerance, default_tolerance, rand_like, tolerance) + check_jvp, check_vjp, default_gradient_tolerance, default_tolerance, rand_like, tolerance, ToleranceDict) +from jax._src.test_loader import thread_unsafe_test as thread_unsafe_test +from jax._src.test_loader import thread_unsafe_test_class as thread_unsafe_test_class +from jax._src.test_loader import JaxTestLoader as JaxTestLoader +from jax._src.test_loader import TEST_NUM_THREADS as TEST_NUM_THREADS +from jax._src.tree_util import tree_all, tree_flatten, tree_map, tree_unflatten +from jax._src.typing import ArrayLike, DTypeLike from jax._src.util import unzip2 -from jax.tree_util import tree_all, tree_flatten, tree_map, tree_unflatten import numpy as np import numpy.random as npr +# When running tests, install the ABSL failure signal handler. This dumps a +# C++ back trace on fatal signals, which is helpful for debugging. +util.install_failure_signal_handler() + # This submodule includes private test utilities that are not exported to # jax.test_util. Functionality appearing here is for internal use only, and @@ -89,22 +98,12 @@ 'sampling process is terminated.' ) -_SKIP_SLOW_TESTS = config.bool_flag( +SKIP_SLOW_TESTS = config.bool_flag( 'jax_skip_slow_tests', config.bool_env('JAX_SKIP_SLOW_TESTS', False), help='Skip tests marked as slow (> 5 sec).' ) -_TEST_TARGETS = config.string_flag( - 'test_targets', os.getenv('JAX_TEST_TARGETS', ''), - 'Regular expression specifying which tests to run, called via re.search on ' - 'the test name. If empty or unspecified, run all tests.' -) -_EXCLUDE_TEST_TARGETS = config.string_flag( - 'exclude_test_targets', os.getenv('JAX_EXCLUDE_TEST_TARGETS', ''), - 'Regular expression specifying which tests NOT to run, called via re.search ' - 'on the test name. If empty or unspecified, run all tests.' -) TEST_WITH_PERSISTENT_COMPILATION_CACHE = config.bool_flag( 'jax_test_with_persistent_compilation_cache', config.bool_env('JAX_TEST_WITH_PERSISTENT_COMPILATION_CACHE', False), @@ -118,11 +117,6 @@ 'deterministic, interactive'), ) -TEST_NUM_THREADS = config.int_flag( - 'jax_test_num_threads', int(os.getenv('JAX_TEST_NUM_THREADS', '0')), - help='Number of threads to use for running tests. 0 means run everything ' - 'in the main thread. Using > 1 thread is experimental.' -) # We sanitize test names to ensure they work with "unitttest -k" and # "pytest -k" test filtering. pytest accepts '[' and ']' but unittest -k @@ -131,27 +125,25 @@ def sanitize_test_name(s: str) -> str: return kSanitizeNameRE.sub("_", s) -def num_float_bits(dtype): +def num_float_bits(dtype: DTypeLike) -> int: return _dtypes.finfo(_dtypes.canonicalize_dtype(dtype)).bits -def to_default_dtype(arr): +def to_default_dtype(arr: ArrayLike) -> np.ndarray: """Convert a value to an array with JAX's default dtype. This is generally used for type conversions of values returned by numpy functions, - to make their dtypes take into account the state of the ``jax_enable_x64`` and - ``jax_default_dtype_bits`` flags. + to make their dtypes take into account the state of the ``jax_enable_x64`` flag. """ arr = np.asarray(arr) - dtype = _dtypes._default_types.get(arr.dtype.kind) - return arr.astype(_dtypes.canonicalize_dtype(dtype)) if dtype else arr + dtype_fn = _dtypes.default_types.get(arr.dtype.kind) + return arr.astype(dtype_fn()) if dtype_fn else arr -def with_jax_dtype_defaults(func, use_defaults=True): +def with_jax_dtype_defaults(func: Callable[..., Any], use_defaults: bool = True): """Return a version of a function with outputs that match JAX's default dtypes. This is generally used to wrap numpy functions within tests, in order to make their default output dtypes match those of corresponding JAX functions, taking - into account the state of the ``jax_enable_x64`` and ``jax_default_dtype_bits`` - flags. + into account the state of the ``jax_enable_x64`` flag. Args: use_defaults : whether to convert any given output to the default dtype. May be @@ -168,7 +160,7 @@ def wrapped(*args, **kwargs): return tree_map(f, result, use_defaults) return wrapped -def is_sequence(x): +def is_sequence(x: Any) -> bool: try: iter(x) except TypeError: @@ -176,14 +168,16 @@ def is_sequence(x): else: return True -def _normalize_tolerance(tol): +def _normalize_tolerance(tol: int | float | ToleranceDict | None) -> ToleranceDict: tol = tol or 0 if isinstance(tol, dict): return {np.dtype(k): v for k, v in tol.items()} else: return dict.fromkeys(_default_tolerance, tol) -def join_tolerance(tol1, tol2): +def join_tolerance( + tol1: int | float | ToleranceDict | None, + tol2: int | float | ToleranceDict | None) -> ToleranceDict: tol1 = _normalize_tolerance(tol1) tol2 = _normalize_tolerance(tol2) out = tol1 @@ -192,7 +186,7 @@ def join_tolerance(tol1, tol2): return out -def check_eq(xs, ys, err_msg=''): +def check_eq(xs: Any, ys: Any, err_msg: str = '') -> None: assert_close = partial(_assert_numpy_allclose, err_msg=err_msg) tree_all(tree_map(assert_close, xs, ys)) @@ -269,13 +263,6 @@ def event_listener(name, *args): elif name == "batched_device_put_end": thread_local_state.nested_device_put_count -= 1 - elif name == "pjit._infer_params_impl": - # For infer_params, we collect per-function data, but only while a context - # manager is active. - infer_counts = thread_local_state.infer_params_fun_counts - if infer_counts is not None: - (fun,) = args - infer_counts[fun] += 1 elif name == "lower_jaxpr_to_fun": # For infer_params, we collect per-function data, but only while a context # manager is active. @@ -302,11 +289,12 @@ def count_event(): count_device_put = count_events("batched_device_put") count_device_put_fast_path_hit = count_events("batched_copy_array") -count_pjit_cpp_cache_miss = count_events("pjit_lower") -count_jit_tracing_cache_miss = count_events("create_pjit_jaxpr") +count_pjit_cpp_cache_miss = count_events("jit_cpp_cache_miss") +count_jit_tracing_cache_miss = count_events("trace_to_jaxpr") count_aot_jit_cpp_cache_miss = count_events("stages_compiled_call") count_jit_and_pmap_lowerings = count_events("lower_jaxpr_to_module") count_jit_compilation_cache_miss = count_events("pxla_cached_compilation") +count_compilation_after_persistent_cache_miss = count_events("compile_after_persistent_compilation_miss") count_jax_array_shard_arg_calls = count_events("_array_shard_arg") @@ -335,14 +323,18 @@ def count_subjaxpr_to_hlo_conversion(fun_name): assert thread_local_state.lower_jaxpr_to_fun_counts is None counts = collections.Counter() thread_local_state.lower_jaxpr_to_fun_counts = counts + def get(): + key, *others = {k for k in counts if fun_name in k} # type: ignore + if others: raise Exception(f"ambiguous name: {fun_name}") + return counts[key] try: - yield lambda: counts[fun_name] + yield get finally: thread_local_state.lower_jaxpr_to_fun_counts = None @contextmanager -def collect_lowered_jaxprs() -> Generator[Sequence[tuple[core.ClosedJaxpr, +def collect_lowered_jaxprs() -> Iterator[Sequence[tuple[core.ClosedJaxpr, mlir.ir.Module]]]: """ Collects all the pairs of (jaxpr, mlir_module) that are lowered. @@ -363,6 +355,18 @@ def assert_num_jit_and_pmap_compilations(times): raise AssertionError(f"Expected exactly {times} XLA compilations, " f"but executed {count()}") +@contextmanager +def count_internal_device_puts(): + before = _jaxlib._jax.get_internal_device_put_info() + counts = {} + try: + yield lambda: counts + finally: + after = _jaxlib._jax.get_internal_device_put_info() + for k, v in after.items(): + diff = v - before.get(k, 0) + if diff != 0: + counts[k] = diff def jaxlib_version() -> tuple[int, ...]: return _jaxlib.version @@ -373,8 +377,9 @@ def device_under_test(): def supported_dtypes(): if device_under_test() == "tpu": - types = {np.bool_, np.int8, np.int16, np.int32, np.uint8, np.uint16, - np.uint32, _dtypes.bfloat16, np.float16, np.float32, np.complex64, + types = {np.bool_, _dtypes.int4, np.int8, np.int16, np.int32, + _dtypes.uint4, np.uint8, np.uint16, np.uint32, + _dtypes.bfloat16, np.float16, np.float32, np.complex64, _dtypes.float8_e4m3fn, _dtypes.float8_e4m3b11fnuz, _dtypes.float8_e5m2} elif device_under_test() == "gpu": @@ -386,8 +391,8 @@ def supported_dtypes(): elif device_under_test() == "METAL": types = {np.int32, np.uint32, np.float32} else: - types = {np.bool_, np.int8, np.int16, np.int32, np.int64, - np.uint8, np.uint16, np.uint32, np.uint64, + types = {np.bool_, _dtypes.int4, np.int8, np.int16, np.int32, np.int64, + _dtypes.uint4, np.uint8, np.uint16, np.uint32, np.uint64, _dtypes.bfloat16, np.float16, np.float32, np.float64, np.complex64, np.complex128} if not config.enable_x64.value: @@ -397,17 +402,50 @@ def supported_dtypes(): def is_device_rocm(): return 'rocm' in xla_bridge.get_backend().platform_version +def get_rocm_version(): + rocm_path = os.environ.get("ROCM_PATH", "/opt/rocm") + version_path = Path(rocm_path) / ".info" / "version" + if not version_path.exists(): + raise FileNotFoundError(f"Expected ROCm version file at {version_path}") + version_str = version_path.read_text().strip() + major, minor, *_ = version_str.split(".") + return int(major), int(minor) + def is_device_cuda(): return 'cuda' in xla_bridge.get_backend().platform_version def is_cloud_tpu(): return running_in_cloud_tpu_vm +def is_optimized_build(): + return _jaxlib._jax.is_optimized_build() + +def is_asan(): + return _jaxlib._jax.is_asan() + +def is_msan(): + return _jaxlib._jax.is_msan() + +def is_tsan(): + return _jaxlib._jax.is_tsan() + +def is_sanitized(): + return _jaxlib._jax.is_sanitized() + +def is_gil_disabled() -> bool: + return not sys._is_gil_enabled() if hasattr(sys, "_is_gil_enabled") else False + +def is_test_rbe(): + """Check for a variable set by the RBE toolchain under testing.""" + return ( + os.getenv("IS_JAX_RBE_TESTING", "").lower() in {"true", "1", "yes", "y"} + ) + # Returns True if it is not cloud TPU. If it is cloud TPU, returns True if it is # built at least `date``. # TODO(b/327203806): after libtpu adds a XLA version and the oldest support # libtpu contains the XLA version, remove using built time to skip tests. -def if_cloud_tpu_at_least(year: int, month: int, day: int): +def is_cloud_tpu_at_least(year: int, month: int, day: int): date = datetime.date(year, month, day) if not is_cloud_tpu(): return True @@ -428,14 +466,22 @@ def pjrt_c_api_version_at_least(major_version: int, minor_version: int): return True return pjrt_c_api_versions >= (major_version, minor_version) +def stablehlo_version_at_least(required_version: str): + plugin_version = xla_bridge.backend_stablehlo_version() + if plugin_version is None: + return True + return hlo.get_smaller_version( + ".".join(map(str, plugin_version)), required_version + ) == plugin_version + def get_tpu_version() -> int: if device_under_test() != "tpu": raise ValueError("Device is not TPU") - kind = jax.devices()[0].device_kind - if kind.endswith(' lite'): - kind = kind[:-len(' lite')] - assert kind[:-1] == "TPU v", kind - return int(kind[-1]) + kind = xla_bridge.devices()[0].device_kind + match = re.match(r"TPU[^\d]*(\d+)", kind) + if match is None: + raise ValueError(f"Device kind {kind} is not supported") + return int(match.group(1)) def is_device_tpu_at_least(version: int) -> bool: if device_under_test() != "tpu": @@ -447,19 +493,87 @@ def is_device_tpu(version: int | None = None, variant: str = "") -> bool: return False if version is None: return True - device_kind = jax.devices()[0].device_kind + device_kind = xla_bridge.devices()[0].device_kind expected_version = f"v{version}{variant}" # Special case v5e until the name is updated in device_kind if expected_version == "v5e": return "v5 lite" in device_kind elif expected_version == "v6e": return "v6 lite" in device_kind + elif expected_version == "v5p": + return device_kind.endswith("v5") + elif expected_version == "v7x": + return "TPU7x" in device_kind return expected_version in device_kind +def pattern_search(patterns: str | Sequence[str], string: str): + if not isinstance(patterns, tuple): + patterns = (patterns,) # type: ignore + + for pattern in patterns: + if re.search(pattern, string): + return pattern + return None + +def device_kind_match(device_patterns: str | Sequence[str]): + device_kind = xla_bridge.devices()[0].device_kind + matching_pattern = pattern_search(device_patterns, device_kind) + return matching_pattern + +def skip_if_errors( + *, + error_patterns: str | Sequence[str], + device_patterns: str | Sequence[str], + reason: str | Callable[[str, str], str], +): + """Skip if both error message and device kind match a corresponding pattern.""" + def skip(test_method): + @functools.wraps(test_method) + def test_method_wrapper(self, *args, **kwargs): + device_kind = xla_bridge.devices()[0].device_kind + try: + return test_method(self, *args, **kwargs) + except Exception as e: + matching_error_pattern = pattern_search(error_patterns, str(e)) + matching_device_pattern = pattern_search(device_patterns, device_kind) + if matching_error_pattern and matching_device_pattern: + if not isinstance(reason, str): + reason_str = reason(matching_error_pattern, matching_device_pattern) + else: + reason_str = reason + self.skipTest(reason_str) + raise + return test_method_wrapper + return skip + +skip_if_mosaic_gpu_exceeds_shared_memory = functools.partial( + skip_if_errors, + error_patterns="kernel exceeds available shared memory", + reason=lambda err, dev: f"Mosaic GPU kernel exceeds shared memory on {dev}", +) + +skip_if_triton_exceeds_shared_memory = functools.partial( + skip_if_errors, + error_patterns="Shared memory size limit exceeded", + reason=lambda err, dev: f"Triton kernel exceeds shared memory on {dev}", +) + +def get_cuda_nonportable_max_cluster_size(): + # Per-device nonportable maximum cluster sizes for Jetson Thor and DGX + # Spark (GB10) determined by querying cuOccupancyMaxPotentialClusterSize + if device_kind_match("Thor$"): + return 8 + elif device_kind_match("GB10$"): + return 12 + # 16 is the nonportable maximum cluster size on: + # - Hopper: https://docs.nvidia.com/cuda/hopper-tuning-guide/index.html#:~:text=cluster%20size%20of-,16,-by%20opting%20in + # - Blackwell: https://docs.nvidia.com/cuda/blackwell-tuning-guide/index.html#:~:text=cluster%20size%20of-,16,-by%20opting%20in + return 16 + def is_cuda_compute_capability_at_least(capability: str) -> bool: if not is_device_cuda(): return False - d, *_ = jax.local_devices(backend="gpu") + d, *_ = xla_bridge.local_devices(backend="gpu") target = tuple(int(x) for x in capability.split(".")) current = tuple(int(x) for x in d.compute_capability.split(".")) return current >= target @@ -467,35 +581,42 @@ def is_cuda_compute_capability_at_least(capability: str) -> bool: def is_cuda_compute_capability_equal(capability: str) -> bool: if not is_device_cuda(): return False - d, *_ = jax.local_devices(backend="gpu") + d, *_ = xla_bridge.local_devices(backend="gpu") target = tuple(int(x) for x in capability.split(".")) current = tuple(int(x) for x in d.compute_capability.split(".")) return current == target +def is_cuda_version_at_least(major: int, minor: int): + assert 0 <= major + assert 0 <= minor < 100 + return ( + cuda_versions is not None + and cuda_versions.cuda_runtime_get_version() >= major * 1000 + minor * 10 + ) + class CudaArchSpecificTest: """A mixin with methods allowing to skip arch specific tests.""" def skip_unless_sm90a(self): if not is_cuda_compute_capability_equal("9.0"): - self.skipTest("Only works on GPU with capability sm90a") + self.skipTest("Only works on GPU with capability sm90a") # pytype: disable=attribute-error def skip_unless_sm100a(self): if not is_cuda_compute_capability_equal("10.0"): - self.skipTest("Only works on GPU with capability sm100a") + self.skipTest("Only works on GPU with capability sm100a") # pytype: disable=attribute-error def _get_device_tags(): """returns a set of tags defined for the device under test""" if is_device_rocm(): - device_tags = {device_under_test(), "rocm"} + return {device_under_test(), "rocm"} elif is_device_cuda(): - device_tags = {device_under_test(), "cuda"} + return {device_under_test(), "cuda"} elif device_under_test() == "METAL": - device_tags = {device_under_test(), "gpu"} + return {device_under_test(), "gpu"} else: - device_tags = {device_under_test()} - return device_tags + return {device_under_test()} def test_device_matches(device_types: Iterable[str]) -> bool: assert not isinstance( @@ -510,26 +631,57 @@ def test_device_matches(device_types: Iterable[str]) -> bool: test_device_matches.__test__ = False # This isn't a test case, pytest. -def _device_filter(predicate): +def _device_filter(predicate, skip_reason=None): def skip(test_method): @functools.wraps(test_method) def test_method_wrapper(self, *args, **kwargs): device_tags = _get_device_tags() if not predicate(): - test_name = getattr(test_method, '__name__', '[unknown test]') - raise unittest.SkipTest( - f"{test_name} not supported on device with tags {device_tags}.") + if skip_reason: + raise unittest.SkipTest(skip_reason) + else: + test_name = getattr(test_method, '__name__', '[unknown test]') + raise unittest.SkipTest( + f"{test_name} not supported on device with tags {device_tags}.") return test_method(self, *args, **kwargs) return test_method_wrapper return skip -def skip_on_devices(*disabled_devices): - """A decorator for test methods to skip the test on certain devices.""" - return _device_filter(lambda: not test_device_matches(disabled_devices)) - -def run_on_devices(*enabled_devices): - """A decorator for test methods to run the test only on certain devices.""" - return _device_filter(lambda: test_device_matches(enabled_devices)) +def skip_on_devices(*disabled_devices, skip_reason=None): + """A decorator for test methods to skip the test on certain devices. + + Args: + *disabled_devices: Device names that the test should skip on. + skip_reason: Optional custom skip message when test is skipped. + """ + if skip_reason is None: + skip_messages = { + ("gpu",): "Skipped on all GPUs.", + ("cpu",): "Skipped on CPU.", + ("tpu",): "Skipped on TPU.", + ("cuda",): "Skipped on CUDA GPUs.", + ("rocm",): "Skipped on ROCm GPUs.", + } + skip_reason = skip_messages.get(disabled_devices) + return _device_filter(lambda: not test_device_matches(disabled_devices), skip_reason) + +def run_on_devices(*enabled_devices, skip_reason=None): + """A decorator for test methods to run the test only on certain devices. + + Args: + *enabled_devices: Device names that the test should run on. + skip_reason: Optional custom skip message when test is skipped. + """ + if skip_reason is None: + device_specific_skip_reasons = { + ("cpu",): "Skipped: CPU-only test.", + ("tpu",): "Skipped: TPU-only test.", + ("gpu",): "Skipped: GPU-only test.", + ("rocm",): "Skipped: ROCm-only test.", + ("cuda",): "Skipped: CUDA-only test.", + } + skip_reason = device_specific_skip_reasons.get(enabled_devices) + return _device_filter(lambda: test_device_matches(enabled_devices), skip_reason) def device_supports_buffer_donation(): """A decorator for test methods to run the test only on devices that support @@ -579,7 +731,7 @@ def pytest_mark_if_available(marker: str): """A decorator for test classes or methods to pytest.mark if installed.""" def wrap(func_or_class): try: - import pytest + import pytest # pytype: disable=import-error except ImportError: return func_or_class return getattr(pytest.mark, marker)(func_or_class) @@ -1044,165 +1196,6 @@ def sample_product(*args, **kw): """ return parameterized.parameters(*sample_product_testcases(*args, **kw)) -# We use a reader-writer lock to protect test execution. Tests that may run in -# parallel acquire a read lock; tests that are not thread-safe acquire a write -# lock. -_test_rwlock = util.Mutex() - -def _run_one_test(test: unittest.TestCase, result: ThreadSafeTestResult): - if getattr(test.__class__, "thread_hostile", False): - _test_rwlock.writer_lock() - try: - test(result) # type: ignore - finally: - _test_rwlock.writer_unlock() - else: - _test_rwlock.reader_lock() - try: - test(result) # type: ignore - finally: - _test_rwlock.reader_unlock() - - -@contextmanager -def thread_unsafe_test(): - """Decorator for tests that are not thread-safe. - - Note: this decorator (naturally) only applies to what it wraps, not to, say, - code in separate setUp() or tearDown() methods. - """ - if TEST_NUM_THREADS.value <= 0: - yield - return - - _test_rwlock.assert_reader_held() - _test_rwlock.reader_unlock() - _test_rwlock.writer_lock() - try: - yield - finally: - _test_rwlock.writer_unlock() - _test_rwlock.reader_lock() - - -def thread_unsafe_test_class(): - "Decorator that marks a TestCase class as thread-hostile." - def f(klass): - assert issubclass(klass, unittest.TestCase), type(klass) - klass.thread_hostile = True - return klass - return f - - -class ThreadSafeTestResult: - """ - Wraps a TestResult to make it thread safe. - - We do this by accumulating API calls and applying them in a batch under a - lock at the conclusion of each test case. - - We duck type instead of inheriting from TestResult because we aren't actually - a perfect implementation of TestResult, and would rather get a loud error - for things we haven't implemented. - """ - def __init__(self, lock: threading.Lock, result: unittest.TestResult): - self.lock = lock - self.test_result = result - self.actions: list[Callable] = [] - - def startTest(self, test: unittest.TestCase): - del test - self.start_time = time.time() - - def stopTest(self, test: unittest.TestCase): - stop_time = time.time() - with self.lock: - # If test_result is an ABSL _TextAndXMLTestResult we override how it gets - # the time. This affects the timing that shows up in the XML output - # consumed by CI. - time_getter = getattr(self.test_result, "time_getter", None) - try: - self.test_result.time_getter = lambda: self.start_time - self.test_result.startTest(test) - for callback in self.actions: - callback() - self.test_result.time_getter = lambda: stop_time - self.test_result.stopTest(test) - finally: - if time_getter is not None: - self.test_result.time_getter = time_getter - - def addSuccess(self, test: unittest.TestCase): - self.actions.append(lambda: self.test_result.addSuccess(test)) - - def addSkip(self, test: unittest.TestCase, reason: str): - self.actions.append(lambda: self.test_result.addSkip(test, reason)) - - def addError(self, test: unittest.TestCase, err): - self.actions.append(lambda: self.test_result.addError(test, err)) - - def addFailure(self, test: unittest.TestCase, err): - self.actions.append(lambda: self.test_result.addFailure(test, err)) - - def addExpectedFailure(self, test: unittest.TestCase, err): - self.actions.append(lambda: self.test_result.addExpectedFailure(test, err)) - - def addDuration(self, test: unittest.TestCase, elapsed): - self.actions.append(lambda: self.test_result.addDuration(test, elapsed)) - - -class JaxTestSuite(unittest.TestSuite): - """Runs tests in parallel using threads if TEST_NUM_THREADS is > 1. - - Caution: this test suite does not run setUpClass or setUpModule methods if - thread parallelism is enabled. - """ - - def __init__(self, suite: unittest.TestSuite): - super().__init__(list(suite)) - - def run(self, result: unittest.TestResult, debug: bool = False) -> unittest.TestResult: - if TEST_NUM_THREADS.value <= 0: - return super().run(result) - - test_warning_util.install_threadsafe_warning_handlers() - - executor = ThreadPoolExecutor(TEST_NUM_THREADS.value) - lock = threading.Lock() - futures = [] - - def run_test(test): - "Recursively runs tests in a test suite or test case." - if isinstance(test, unittest.TestSuite): - for subtest in test: - run_test(subtest) - else: - test_result = ThreadSafeTestResult(lock, result) - futures.append(executor.submit(_run_one_test, test, test_result)) - - with executor: - run_test(self) - for future in futures: - future.result() - - return result - - -class JaxTestLoader(absltest.TestLoader): - suiteClass = JaxTestSuite - - def getTestCaseNames(self, testCaseClass): - names = super().getTestCaseNames(testCaseClass) - if _TEST_TARGETS.value: - pattern = re.compile(_TEST_TARGETS.value) - names = [name for name in names - if pattern.search(f"{testCaseClass.__name__}.{name}")] - if _EXCLUDE_TEST_TARGETS.value: - pattern = re.compile(_EXCLUDE_TEST_TARGETS.value) - names = [name for name in names - if not pattern.search(f"{testCaseClass.__name__}.{name}")] - return names - def with_config(**kwds): """Test case decorator for subclasses of JaxTestCase""" @@ -1271,9 +1264,9 @@ def __repr__(self): @contextmanager def assert_global_configs_unchanged(): starting_cache = compilation_cache._cache - starting_config = jax.config.values.copy() + starting_config = config.config.values.copy() yield - ending_config = jax.config.values + ending_config = config.config.values ending_cache = compilation_cache._cache if starting_config != ending_config: @@ -1300,36 +1293,28 @@ class JaxTestCase(parameterized.TestCase): 'jax_legacy_prng_key': 'error', } - _context_stack: ExitStack | None = None - - def setUp(self): super().setUp() - self.enter_context(assert_global_configs_unchanged()) + self.enterContext(assert_global_configs_unchanged()) # We use the adler32 hash for two reasons. # a) it is deterministic run to run, unlike hash() which is randomized. # b) it returns values in int32 range, which RandomState requires. self._rng = npr.RandomState(zlib.adler32(self._testMethodName.encode())) - # TODO(phawkins): use TestCase.enterContext once Python 3.11 is the minimum - # version. - self._context_stack = ExitStack() - self.addCleanup(self._context_stack.close) - stack = self._context_stack - stack.enter_context(global_config_context(**self._default_global_config)) + self.enterContext(global_config_context(**self._default_global_config)) for config_name, value in self._default_thread_local_config.items(): - stack.enter_context(jax._src.config.config_states[config_name](value)) + self.enterContext(config.config_states[config_name](value)) if TEST_WITH_PERSISTENT_COMPILATION_CACHE.value: assert TEST_NUM_THREADS.value <= 1, "Persistent compilation cache is not thread-safe." - stack.enter_context(config.enable_compilation_cache(True)) - stack.enter_context(config.raise_persistent_cache_errors(True)) - stack.enter_context(config.persistent_cache_min_compile_time_secs(0)) - stack.enter_context(config.persistent_cache_min_entry_size_bytes(0)) - tmp_dir = stack.enter_context(tempfile.TemporaryDirectory()) - stack.enter_context(config.compilation_cache_dir(tmp_dir)) - stack.callback(compilation_cache.reset_cache) + self.enterContext(config.enable_compilation_cache(True)) + self.enterContext(config.raise_persistent_cache_errors(True)) + self.enterContext(config.persistent_cache_min_compile_time_secs(0)) + self.enterContext(config.persistent_cache_min_entry_size_bytes(0)) + tmp_dir = self.enterContext(tempfile.TemporaryDirectory()) + self.enterContext(config.compilation_cache_dir(tmp_dir)) + self.addCleanup(compilation_cache.reset_cache) def tearDown(self) -> None: assert core.reset_trace_state() @@ -1348,15 +1333,15 @@ def assertDeprecationWarnsOrRaises(self, deprecation_id: str, message: str): else: return self.assertWarnsRegex(DeprecationWarning, message) - def assertArraysEqual(self, x, y, *, check_dtypes=True, err_msg='', + def assertArraysEqual(self, actual, desired, *, check_dtypes=True, err_msg='', allow_object_dtype=False, verbose=True): """Assert that x and y arrays are exactly equal.""" if check_dtypes: - self.assertDtypesMatch(x, y) - x = np.asarray(x) - y = np.asarray(y) + self.assertDtypesMatch(actual, desired) + actual = np.asarray(actual) + desired = np.asarray(desired) - if (not allow_object_dtype) and (x.dtype == object or y.dtype == object): + if (not allow_object_dtype) and (actual.dtype == object or desired.dtype == object): # See https://github.com/jax-ml/jax/issues/17867 raise TypeError( "assertArraysEqual may be poorly behaved when np.asarray casts to dtype=object. " @@ -1366,57 +1351,59 @@ def assertArraysEqual(self, x, y, *, check_dtypes=True, err_msg='', # Work around https://github.com/numpy/numpy/issues/18992 with np.errstate(over='ignore'): - np.testing.assert_array_equal(x, y, err_msg=err_msg, + np.testing.assert_array_equal(actual, desired, err_msg=err_msg, verbose=verbose) - def assertArraysAllClose(self, x, y, *, check_dtypes=True, atol=None, + def assertArraysAllClose(self, actual, desired, *, check_dtypes=True, atol=None, rtol=None, err_msg=''): - """Assert that x and y are close (up to numerical tolerances).""" - self.assertEqual(x.shape, y.shape) - atol = max(tolerance(_dtype(x), atol), tolerance(_dtype(y), atol)) - rtol = max(tolerance(_dtype(x), rtol), tolerance(_dtype(y), rtol)) + """Assert that actual and desired are close (up to numerical tolerances).""" + self.assertEqual(actual.shape, desired.shape) + atol = max(tolerance(_dtype(actual), atol), tolerance(_dtype(desired), atol)) + rtol = max(tolerance(_dtype(actual), rtol), tolerance(_dtype(desired), rtol)) - _assert_numpy_allclose(x, y, atol=atol, rtol=rtol, err_msg=err_msg) + _assert_numpy_allclose(actual, desired, atol=atol, rtol=rtol, err_msg=err_msg) if check_dtypes: - self.assertDtypesMatch(x, y) + self.assertDtypesMatch(actual, desired) - def assertDtypesMatch(self, x, y, *, canonicalize_dtypes=True): + def assertDtypesMatch(self, actual, desired, *, canonicalize_dtypes=True): if not config.enable_x64.value and canonicalize_dtypes: - self.assertEqual(_dtypes.canonicalize_dtype(_dtype(x), allow_extended_dtype=True), - _dtypes.canonicalize_dtype(_dtype(y), allow_extended_dtype=True)) + self.assertEqual(_dtypes.canonicalize_dtype(_dtype(actual), allow_extended_dtype=True), + _dtypes.canonicalize_dtype(_dtype(desired), allow_extended_dtype=True)) else: - self.assertEqual(_dtype(x), _dtype(y)) + self.assertEqual(_dtype(actual), _dtype(desired)) - def assertAllClose(self, x, y, *, check_dtypes=True, atol=None, rtol=None, + def assertAllClose(self, actual, desired, *, check_dtypes=True, atol=None, rtol=None, canonicalize_dtypes=True, err_msg=''): - """Assert that x and y, either arrays or nested tuples/lists, are close.""" - if isinstance(x, dict): - self.assertIsInstance(y, dict) - self.assertEqual(set(x.keys()), set(y.keys())) - for k in x.keys(): - self.assertAllClose(x[k], y[k], check_dtypes=check_dtypes, atol=atol, + """Assert that actual and desired, either arrays or nested tuples/lists, are close.""" + if isinstance(actual, dict): + self.assertIsInstance(desired, dict) + self.assertEqual(set(actual.keys()), set(desired.keys())) + for k in actual.keys(): + self.assertAllClose(actual[k], desired[k], check_dtypes=check_dtypes, atol=atol, rtol=rtol, canonicalize_dtypes=canonicalize_dtypes, err_msg=err_msg) - elif is_sequence(x) and not hasattr(x, '__array__'): - self.assertTrue(is_sequence(y) and not hasattr(y, '__array__')) - self.assertEqual(len(x), len(y)) - for x_elt, y_elt in zip(x, y): - self.assertAllClose(x_elt, y_elt, check_dtypes=check_dtypes, atol=atol, + elif is_sequence(actual) and not hasattr(actual, '__array__'): + self.assertTrue(is_sequence(desired) and not hasattr(desired, '__array__'), + msg=f"Expected sequence, got {desired}") + self.assertEqual(len(actual), len(desired)) + for actual_elt, desired_elt in zip(actual, desired): + self.assertAllClose(actual_elt, desired_elt, check_dtypes=check_dtypes, atol=atol, rtol=rtol, canonicalize_dtypes=canonicalize_dtypes, err_msg=err_msg) - elif hasattr(x, '__array__') or np.isscalar(x): - self.assertTrue(hasattr(y, '__array__') or np.isscalar(y)) + elif hasattr(actual, '__array__') or np.isscalar(actual): + self.assertTrue(hasattr(desired, '__array__') or np.isscalar(desired), + msg=f"Expected array-like, got {desired}") if check_dtypes: - self.assertDtypesMatch(x, y, canonicalize_dtypes=canonicalize_dtypes) - x = np.asarray(x) - y = np.asarray(y) - self.assertArraysAllClose(x, y, check_dtypes=False, atol=atol, rtol=rtol, + self.assertDtypesMatch(actual, desired, canonicalize_dtypes=canonicalize_dtypes) + actual = np.asarray(actual) + desired = np.asarray(desired) + self.assertArraysAllClose(actual, desired, check_dtypes=False, atol=atol, rtol=rtol, err_msg=err_msg) - elif x == y: + elif actual == desired: return else: - raise TypeError((type(x), type(y))) + raise TypeError((type(actual), type(desired))) def assertMultiLineStrippedEqual(self, expected, what): """Asserts two strings are equal, after dedenting and stripping each line.""" @@ -1431,7 +1418,6 @@ def assertMultiLineStrippedEqual(self, expected, what): self.assertMultiLineEqual(expected_clean, what_clean, msg=f"Found\n{what}\nExpecting\n{expected}") - @contextmanager def assertNoWarnings(self): with test_warning_util.raise_on_warnings(): @@ -1455,7 +1441,7 @@ def assertWarns(self, warning, *, msg=None): @contextmanager def assertWarnsRegex(self, warning, regex): - if regex is not None: + if regex is not None and not isinstance(regex, re.Pattern): regex = re.compile(regex) with test_warning_util.record_warnings() as ws: @@ -1466,8 +1452,8 @@ def assertWarnsRegex(self, warning, regex): if regex is not None and not regex.search(str(w.message)): continue return - self.fail(f"Expected warning not found {warning}:'{regex}', got " - f"{ws}") + self.fail(f"Expected warning not found {warning}:'{regex}', " + f"got warnings: {[str(w.message) for w in ws]}") def _CompileAndCheck(self, fun, args_maker, *, check_dtypes=True, tol=None, @@ -1501,9 +1487,9 @@ def wrapped_fun(*args): python_should_be_executing = False compiled_ans = cfun(*args) - self.assertAllClose(python_ans, monitored_ans, check_dtypes=check_dtypes, + self.assertAllClose(monitored_ans, python_ans, check_dtypes=check_dtypes, atol=atol or tol, rtol=rtol or tol) - self.assertAllClose(python_ans, compiled_ans, check_dtypes=check_dtypes, + self.assertAllClose(compiled_ans, python_ans, check_dtypes=check_dtypes, atol=atol or tol, rtol=rtol or tol) args = args_maker() @@ -1514,7 +1500,7 @@ def wrapped_fun(*args): python_should_be_executing = False compiled_ans = cfun(*args) - self.assertAllClose(python_ans, compiled_ans, check_dtypes=check_dtypes, + self.assertAllClose(compiled_ans, python_ans, check_dtypes=check_dtypes, atol=atol or tol, rtol=rtol or tol) def _CheckAgainstNumpy(self, numpy_reference_op, lax_op, args_maker, @@ -1523,11 +1509,36 @@ def _CheckAgainstNumpy(self, numpy_reference_op, lax_op, args_maker, args = args_maker() lax_ans = lax_op(*args) numpy_ans = numpy_reference_op(*args) - self.assertAllClose(numpy_ans, lax_ans, check_dtypes=check_dtypes, + self.assertAllClose(lax_ans, numpy_ans, check_dtypes=check_dtypes, atol=atol or tol, rtol=rtol or tol, canonicalize_dtypes=canonicalize_dtypes) -_PJIT_IMPLEMENTATION = jax.jit + def assertCacheMisses(self, + func: Callable[[], Any], *, + cpp: int | None = None, + aot_call: int | None = None, + tracing: int | None = None, + lowering: int | None = None, + compilation_after_persistent_cache_miss: int | None = None): + with (count_pjit_cpp_cache_miss() as cpp_count, + count_aot_jit_cpp_cache_miss() as aot_call_count, + count_jit_tracing_cache_miss() as tracing_count, + count_jit_and_pmap_lowerings() as lowering_count, + count_compilation_after_persistent_cache_miss() as compilation_count): + func() + if cpp is not None: + self.assertEqual(cpp, cpp_count()) + if aot_call is not None: + self.assertEqual(aot_call, aot_call_count()) + if tracing is not None: + self.assertEqual(tracing, tracing_count()) + if lowering is not None: + self.assertEqual(lowering, lowering_count()) + if compilation_after_persistent_cache_miss is not None: + self.assertEqual(compilation_after_persistent_cache_miss, + compilation_count()) + +_PJIT_IMPLEMENTATION = api.jit _PJIT_IMPLEMENTATION._name = "jit" _NOOP_JIT_IMPLEMENTATION = lambda x, *args, **kwargs: x _NOOP_JIT_IMPLEMENTATION._name = "noop" @@ -1557,11 +1568,11 @@ def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]: # This is similar to the `with_mesh` function above, but isn't a decorator. axis_names, shape = unzip2(named_shape) size = math.prod(shape) - local_devices = list(jax.local_devices()) + local_devices = list(xla_bridge.local_devices()) if len(local_devices) < size: raise unittest.SkipTest(f"Test requires {size} local devices") mesh_devices = np.array(local_devices[:size]).reshape(shape) # type: ignore - with jax.sharding.Mesh(mesh_devices, axis_names): + with mesh_lib.Mesh(mesh_devices, axis_names): yield def with_mesh_from_kwargs(f): @@ -1575,13 +1586,13 @@ def with_and_without_mesh(f): ('Mesh', (('x', 2),), (('i', 'x'),)) ))(with_mesh_from_kwargs(f)) -def with_user_mesh(sizes, names, axis_types=None): +def with_explicit_mesh(sizes, names, axis_types=None, iota_order=False): axis_types = ((mesh_lib.AxisType.Explicit,) * len(names) if axis_types is None else axis_types) def decorator(fn): def mesh_fn(*args, **kwargs): - mesh = create_mesh(sizes, names, axis_types=axis_types) - with jax.sharding.use_mesh(mesh): + mesh = create_mesh(sizes, names, iota_order, axis_types=axis_types) + with sharding_impls.set_mesh(mesh): return fn(*args, **kwargs, mesh=mesh) return mesh_fn return decorator @@ -1589,14 +1600,16 @@ def mesh_fn(*args, **kwargs): def create_mesh(mesh_shape, axis_names, iota_order=False, axis_types=None): size = math.prod(mesh_shape) - if len(jax.devices()) < size: - raise unittest.SkipTest(f"Test requires {size} global devices.") + if len(xla_bridge.devices()) < size: + raise unittest.SkipTest(f"Test requires {size} global devices and found {len(xla_bridge.devices())}.") if iota_order: - devices = sorted(jax.devices(), key=lambda d: d.id) + devices = sorted(xla_bridge.devices(), key=lambda d: d.id) mesh_devices = np.array(devices[:size]).reshape(mesh_shape) - return jax.sharding.Mesh(mesh_devices, axis_names, axis_types=axis_types) + return mesh_lib.Mesh(mesh_devices, axis_names, axis_types=axis_types) else: - return jax.make_mesh(mesh_shape, axis_names, axis_types=axis_types) + if axis_types is None: + axis_types = (mesh_lib.AxisType.Auto,) * len(mesh_shape) + return sharding_impls.make_mesh(mesh_shape, axis_names, axis_types) class _cached_property: null = object() @@ -1630,15 +1643,11 @@ def custom_floats(self): _dtypes.float8_e4m3fnuz, _dtypes.float8_e5m2, _dtypes.float8_e5m2fnuz, + _dtypes.float8_e3m4, + _dtypes.float8_e4m3, + _dtypes.float8_e8m0fnu, + _dtypes.float4_e2m1fn, ] - if _dtypes.float8_e3m4 is not None: - float_dtypes += [_dtypes.float8_e3m4] - if _dtypes.float8_e4m3 is not None: - float_dtypes += [_dtypes.float8_e4m3] - if _dtypes.float8_e8m0fnu is not None: - float_dtypes += [_dtypes.float8_e8m0fnu] - if _dtypes.float4_e2m1fn is not None: - float_dtypes += [_dtypes.float4_e2m1fn] return self.supported(float_dtypes) @_cached_property @@ -1700,8 +1709,8 @@ def strict_promotion_if_dtypes_match(dtypes): and enable standard dtype promotion otherwise. """ if all(dtype == dtypes[0] for dtype in dtypes): - return jax.numpy_dtype_promotion('strict') - return jax.numpy_dtype_promotion('standard') + return config.numpy_dtype_promotion('strict') + return config.numpy_dtype_promotion('standard') _version_regex = re.compile(r"([0-9]+(?:\.[0-9]+)*)(?:(rc|dev).*)?") def parse_version(v: str) -> tuple[int, ...]: @@ -1779,7 +1788,7 @@ def register_event_duration_listener(callback): monitoring.register_event_duration_secs_listener(callback) yield finally: - monitoring._unregister_event_duration_listener_by_callback(callback) + monitoring.unregister_event_duration_listener(callback) @contextmanager @@ -1816,22 +1825,12 @@ def set_env(**kwargs): os.environ.update({k: v for k, v in original.items() if v is not None}) def fwd_bwd_jaxprs(f, *example_args): - fwd_jaxpr, (y_shape, res_shape) = jax.make_jaxpr( - lambda *args: jax.vjp(f, *args), return_shape=True)(*example_args) - bwd_jaxpr = jax.make_jaxpr(lambda res, outs: res(outs))(res_shape, y_shape) + fwd_jaxpr, (y_shape, res_shape) = api.make_jaxpr( + lambda *args: api.vjp(f, *args), return_shape=True)(*example_args) + bwd_jaxpr = api.make_jaxpr(lambda res, outs: res(outs))(res_shape, y_shape) return fwd_jaxpr, bwd_jaxpr -def numpy_vecdot(x, y, axis): - """Implementation of numpy.vecdot for testing on numpy < 2.0.0""" - if numpy_version() >= (2, 0, 0): - raise ValueError("should be calling vecdot directly on numpy 2.0.0") - x = np.moveaxis(x, axis, -1) - y = np.moveaxis(y, axis, -1) - x, y = np.broadcast_arrays(x, y) - return np.matmul(np.conj(x[..., None, :]), y[..., None])[..., 0, 0] - - def complex_plane_sample(dtype, size_re=10, size_im=None): """Return a 2-D array of complex numbers that covers the complex plane with a grid of samples. @@ -1933,8 +1932,8 @@ def __init__(self, *args, **kwargs): self.extra_prec_multiplier = kwargs.pop('extra_prec_multiplier', 0) self.extra_prec = kwargs.pop('extra_prec', 0) self.mpmath = mpmath - self.contexts = dict() - self.contexts_inv = dict() + self.contexts = {} + self.contexts_inv = {} for fp_format, prec in self.float_prec.items(): ctx = self.mpmath.mp.clone() ctx.prec = prec @@ -2024,7 +2023,7 @@ def mptonp(self, x): def __call__(self, *args, **kwargs): mp_args = [] - context = None + context: Any = None for a in args: if isinstance(a, (np.ndarray, np.floating, np.complexfloating)): mp_args.append(self.nptomp(a)) @@ -2387,6 +2386,17 @@ def worker(ctx, s, e, r, v): assert 0 # unreachable # Hypothesis testing support +def hypothesis_is_thread_safe() -> bool: + """Returns True if the installed hypothesis version is thread-safe. + + Hypothesis versions >= 6.136.9 are thread-safe. + """ + try: + import hypothesis as hp # pytype: disable=import-error + return tuple(int(x) for x in hp.__version__.split('.')) >= (6, 136, 9) + except (ModuleNotFoundError, ImportError): + return True + def setup_hypothesis(max_examples=30) -> None: """Sets up the hypothesis profiles. @@ -2399,7 +2409,7 @@ def setup_hypothesis(max_examples=30) -> None: the default "deterministic" profile. """ try: - import hypothesis as hp + import hypothesis as hp # pytype: disable=import-error except (ModuleNotFoundError, ImportError): return @@ -2446,3 +2456,13 @@ def setup_hypothesis(max_examples=30) -> None: profile = HYPOTHESIS_PROFILE.value logging.info("Using hypothesis profile: %s", profile) hp.settings.load_profile(profile) + + +def runtime_environment() -> str | None: + """Returns None, "bazel" or "pytest".""" + if sys.executable is None: + return None + elif 'bazel-out' in sys.executable: + return "bazel" + else: + return "pytest" diff --git a/jax/_src/third_party/scipy/betaln.py b/jax/_src/third_party/scipy/betaln.py index 10f12bdae037..bf4aaec6a10d 100644 --- a/jax/_src/third_party/scipy/betaln.py +++ b/jax/_src/third_party/scipy/betaln.py @@ -1,5 +1,5 @@ -from jax import lax -import jax.numpy as jnp +from jax._src import lax +from jax._src import numpy as jnp from jax._src.typing import Array, ArrayLike from jax._src.numpy.util import promote_args_inexact diff --git a/jax/_src/third_party/scipy/interpolate.py b/jax/_src/third_party/scipy/interpolate.py index 1eb726ea863c..5e2047bbe26a 100644 --- a/jax/_src/third_party/scipy/interpolate.py +++ b/jax/_src/third_party/scipy/interpolate.py @@ -1,7 +1,10 @@ from itertools import product -from jax.numpy import (asarray, broadcast_arrays, can_cast, - empty, nan, searchsorted, where, zeros) +import numpy as np + +from jax._src import dtypes +from jax._src.numpy import (asarray, broadcast_arrays, + empty, searchsorted, where, zeros) from jax._src.tree_util import register_pytree_node from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact @@ -62,7 +65,7 @@ def __init__(self, values, method="linear", bounds_error=False, - fill_value=nan): + fill_value=np.nan): if method not in ("linear", "nearest"): raise ValueError(f"method {method!r} is not defined") self.method = method @@ -80,7 +83,7 @@ def __init__(self, if fill_value is not None: check_arraylike("RegularGridInterpolator", fill_value) fill_value = asarray(fill_value) - if not can_cast(fill_value.dtype, values.dtype, casting='same_kind'): + if not dtypes.can_cast(fill_value.dtype, values.dtype, casting='same_kind'): ve = "fill_value must be either 'None' or of a type compatible with values" raise ValueError(ve) self.fill_value = fill_value diff --git a/jax/_src/third_party/scipy/linalg.py b/jax/_src/third_party/scipy/linalg.py index dce4df1fb817..c04ab3c155ff 100644 --- a/jax/_src/third_party/scipy/linalg.py +++ b/jax/_src/third_party/scipy/linalg.py @@ -2,14 +2,18 @@ from collections.abc import Callable -from jax import jit, lax -import jax.numpy as jnp +import numpy as np + +from jax._src import api +from jax._src import dtypes +from jax._src import lax +from jax._src import numpy as jnp from jax._src.numpy.linalg import norm from jax._src.scipy.linalg import rsf2csf, schur from jax._src.typing import ArrayLike, Array -@jit +@api.jit def _algorithm_11_1_1(F: Array, T: Array) -> tuple[Array, Array]: # Algorithm 11.1.1 from Golub and Van Loan "Matrix Computations" N = T.shape[0] @@ -99,14 +103,14 @@ def funm(A: ArrayLike, func: Callable[[Array], Array], return F if F.dtype.char.lower() == 'e': - tol = jnp.finfo(jnp.float16).eps + tol = dtypes.finfo(np.float16).eps if F.dtype.char.lower() == 'f': - tol = jnp.finfo(jnp.float32).eps + tol = dtypes.finfo(np.float32).eps else: - tol = jnp.finfo(jnp.float64).eps + tol = dtypes.finfo(np.float64).eps minden = jnp.where(minden == 0.0, tol, minden) - err = jnp.where(jnp.any(jnp.isinf(F)), jnp.inf, jnp.minimum(1, jnp.maximum( + err = jnp.where(jnp.any(jnp.isinf(F)), np.inf, jnp.minimum(1, jnp.maximum( tol, (tol / minden) * norm(jnp.triu(T, 1), 1)))) return F, err diff --git a/jax/_src/third_party/scipy/signal_helper.py b/jax/_src/third_party/scipy/signal_helper.py index 4a021675804d..4dcf96ad2897 100644 --- a/jax/_src/third_party/scipy/signal_helper.py +++ b/jax/_src/third_party/scipy/signal_helper.py @@ -5,7 +5,9 @@ from typing import Any import warnings -import jax.numpy as jnp +import numpy as np + +from jax._src import numpy as jnp from jax._src.typing import Array, ArrayLike, DTypeLike @@ -47,17 +49,17 @@ def _triage_segments(window: ArrayLike | str | tuple[Any, ...], nperseg: int | N nperseg_int = input_length if window == 'hann': # Implement the default case without scipy - win = jnp.array([1.0]) if nperseg_int == 1 else jnp.sin(jnp.linspace(0, jnp.pi, nperseg_int, endpoint=False)) ** 2 + win = jnp.array([1.0]) if nperseg_int == 1 else jnp.sin(jnp.linspace(0, np.pi, nperseg_int, endpoint=False)) ** 2 else: # TODO(jakevdp): implement get_window() in JAX to remove optional scipy dependency try: - from scipy.signal import get_window + from scipy.signal import get_window # pytype: disable=import-error except ImportError as err: raise ImportError(f"scipy must be available to use {window=}") from err win = get_window(window, nperseg_int) win = jnp.array(win, dtype=dtype) else: - win = jnp.asarray(window) + win = jnp.asarray(window, dtype=dtype) nperseg_int = win.size if nperseg is None else int(nperseg) if win.ndim != 1: raise ValueError('window must be 1-D') diff --git a/jax/_src/third_party/scipy/special.py b/jax/_src/third_party/scipy/special.py index bdb84df23280..28703d1a1532 100644 --- a/jax/_src/third_party/scipy/special.py +++ b/jax/_src/third_party/scipy/special.py @@ -1,15 +1,15 @@ from __future__ import annotations -import jax.numpy as jnp -from jax import jit +import numpy as np +from jax._src import api +from jax._src import numpy as jnp from jax._src import custom_derivatives, dtypes from jax._src.numpy.util import promote_args_inexact from jax._src.typing import Array, ArrayLike -from numpy import complexfloating -@jit +@api.jit def sincospisquaredhalf( x: Array, ) -> tuple[Array, Array]: @@ -32,17 +32,17 @@ def sincospisquaredhalf( sinpi = jnp.where( r < 0.5, - jnp.sin(jnp.pi * r), + jnp.sin(np.pi * r), jnp.where( r > 1.5, - jnp.sin(jnp.pi * (r - 2.0)), - -jnp.sin(jnp.pi * (r - 1.0)), + jnp.sin(np.pi * (r - 2.0)), + -jnp.sin(np.pi * (r - 1.0)), ), ) cospi = jnp.where( r == 0.5, 0.0, - jnp.where(r < 1.0, -jnp.sin(jnp.pi * (r - 0.5)), jnp.sin(jnp.pi * (r - 1.5))), + jnp.where(r < 1.0, -jnp.sin(np.pi * (r - 0.5)), jnp.sin(np.pi * (r - 1.5))), ) return sinpi, cospi @@ -92,16 +92,16 @@ def fresnel(x: ArrayLike) -> tuple[Array, Array]: # This part is mostly a direct translation of SciPy's C++ code, # and the original Cephes implementation for single precision. - if dtypes.issubdtype(xxa.dtype, complexfloating): + if dtypes.issubdtype(xxa.dtype, np.complexfloating): raise NotImplementedError( 'Support for complex-valued inputs is not implemented yet.') - elif xxa.dtype in (jnp.float32, jnp.float16, jnp.bfloat16): + elif xxa.dtype in (np.float32, np.float16, dtypes.bfloat16): # Single-precision Cephes coefficients # For half-precision, series expansions have either # produce overflow or poor accuracy. # Upcasting to single-precision is hence needed. - xxa = xxa.astype(jnp.float32) # No-op for float32 + xxa = xxa.astype(np.float32) # No-op for float32 fresnl_sn = jnp.array([ +1.647629463788700e-9, @@ -111,7 +111,7 @@ def fresnel(x: ArrayLike) -> tuple[Array, Array]: +7.244727626597022e-3, -9.228055941124598e-2, +5.235987735681432e-1, - ], dtype=jnp.float32) + ], dtype=np.float32) fresnl_cn = jnp.array([ +1.416802502367354e-8, @@ -121,7 +121,7 @@ def fresnel(x: ArrayLike) -> tuple[Array, Array]: +2.818489036795073e-2, -2.467398198317899e-1, +9.999999760004487e-1, - ], dtype=jnp.float32) + ], dtype=np.float32) fresnl_fn = jnp.array([ -1.903009855649792e12, @@ -132,7 +132,7 @@ def fresnel(x: ArrayLike) -> tuple[Array, Array]: +8.560515466275470e3, -1.032877601091159e2, +2.999401847870011e0, - ], dtype=jnp.float32) + ], dtype=np.float32) fresnl_gn = jnp.array([ -1.860843997624650e11, @@ -143,8 +143,8 @@ def fresnel(x: ArrayLike) -> tuple[Array, Array]: +8.602931494734327e2, -1.493439396592284e1, +9.999841934744914e-1, - ], dtype=jnp.float32) - elif xxa.dtype == jnp.float64: + ], dtype=np.float32) + elif xxa.dtype == np.float64: # Double-precision Cephes coefficients fresnl_sn = jnp.array([ @@ -154,7 +154,7 @@ def fresnel(x: ArrayLike) -> tuple[Array, Array]: +2.54890880573376359104e9, -4.42979518059697779103e10, +3.18016297876567817986e11, - ], dtype=jnp.float64) + ], dtype=np.float64) fresnl_sd = jnp.array([ +1.00000000000000000000e0, @@ -164,7 +164,7 @@ def fresnel(x: ArrayLike) -> tuple[Array, Array]: +4.19320245898111231129e8, +2.24411795645340920940e10, +6.07366389490084639049e11, - ], dtype=jnp.float64) + ], dtype=np.float64) fresnl_cn = jnp.array([ -4.98843114573573548651e-8, @@ -173,7 +173,7 @@ def fresnel(x: ArrayLike) -> tuple[Array, Array]: +1.88843319396703850064e-2, -2.05525900955013891793e-1, +9.99999999999999998822e-1, - ], dtype=jnp.float64) + ], dtype=np.float64) fresnl_cd = jnp.array([ +3.99982968972495980367e-12, @@ -183,7 +183,7 @@ def fresnel(x: ArrayLike) -> tuple[Array, Array]: +8.68029542941784300606e-4, +4.12142090722199792936e-2, +1.00000000000000000118e0, - ], dtype=jnp.float64) + ], dtype=np.float64) fresnl_fn = jnp.array([ +4.21543555043677546506e-1, @@ -196,7 +196,7 @@ def fresnel(x: ArrayLike) -> tuple[Array, Array]: +1.72010743268161828879e-13, +1.34283276233062758925e-16, +3.76329711269987889006e-20, - ], dtype=jnp.float64) + ], dtype=np.float64) fresnl_fd = jnp.array([ +1.00000000000000000000e0, @@ -210,7 +210,7 @@ def fresnel(x: ArrayLike) -> tuple[Array, Array]: +5.88754533621578410010e-14, +4.52001434074129701496e-17, +1.25443237090011264384e-20, - ], dtype=jnp.float64) + ], dtype=np.float64) fresnl_gn = jnp.array([ +5.04442073643383265887e-1, @@ -224,7 +224,7 @@ def fresnel(x: ArrayLike) -> tuple[Array, Array]: +1.37555460633261799868e-15, +8.36354435630677421531e-19, +1.86958710162783235106e-22, - ], dtype=jnp.float64) + ], dtype=np.float64) fresnl_gd = jnp.array([ +1.00000000000000000000e0, @@ -239,13 +239,13 @@ def fresnel(x: ArrayLike) -> tuple[Array, Array]: +1.38796531259578871258e-15, +8.39158816283118707363e-19, +1.86958710162783236342e-22, - ], dtype=jnp.float64) + ], dtype=np.float64) else: raise NotImplementedError( f'Support for {xxa.dtype} dtype is not implemented yet.') - assert xxa.dtype in (jnp.float32, jnp.float64) - single_precision = (xxa.dtype == jnp.float32) + assert xxa.dtype in (np.float32, np.float64) + single_precision = (xxa.dtype == np.float32) x = jnp.abs(xxa) @@ -272,11 +272,11 @@ def fresnel(x: ArrayLike) -> tuple[Array, Array]: c_large = c_inf s_large = s_inf else: - c_large = 0.5 + 1 / (jnp.pi * x) * sinpi - s_large = 0.5 - 1 / (jnp.pi * x) * cospi + c_large = 0.5 + 1 / (np.pi * x) * sinpi + s_large = 0.5 - 1 / (np.pi * x) * cospi # Other x values - t = jnp.pi * x2 + t = np.pi * x2 u = 1.0 / (t * t) t = 1.0 / t @@ -287,7 +287,7 @@ def fresnel(x: ArrayLike) -> tuple[Array, Array]: f = 1.0 - u * jnp.polyval(fresnl_fn, u) / jnp.polyval(fresnl_fd, u) g = t * jnp.polyval(fresnl_gn, u) / jnp.polyval(fresnl_gd, u) - t = jnp.pi * x + t = np.pi * x c_other = 0.5 + (f * sinpi - g * cospi) / t s_other = 0.5 - (f * cospi + g * sinpi) / t diff --git a/jax/_src/tpu/__init__.py b/jax/_src/tpu/__init__.py new file mode 100644 index 000000000000..1337256a5074 --- /dev/null +++ b/jax/_src/tpu/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/jax/experimental/pallas/gpu.py b/jax/_src/tpu/linalg/__init__.py similarity index 63% rename from jax/experimental/pallas/gpu.py rename to jax/_src/tpu/linalg/__init__.py index 0ee84c8453ec..8c09b25d1e08 100644 --- a/jax/experimental/pallas/gpu.py +++ b/jax/_src/tpu/linalg/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 The JAX Authors. +# Copyright 2025 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jax._src import deprecations +import os -deprecations.warn( - "pallas-gpu-triton", - "The ``jax.experimental.pallas.gpu`` submodule is deprecated. " - " Use ``jax.experimental.pallas.triton`` instead.", - stacklevel=1, +from jax._src.tpu.linalg import ( + eigh as eigh, + qdwh as qdwh, + svd as svd, ) -from jax.experimental.pallas.triton import * # noqa: F403 +from jax._src import traceback_util +traceback_util.register_exclusion(os.path.dirname(__file__)) diff --git a/jax/_src/lax/eigh.py b/jax/_src/tpu/linalg/eigh.py similarity index 81% rename from jax/_src/lax/eigh.py rename to jax/_src/tpu/linalg/eigh.py index 99711dc6bf0e..4b935293adf1 100644 --- a/jax/_src/lax/eigh.py +++ b/jax/_src/tpu/linalg/eigh.py @@ -32,16 +32,25 @@ import numpy as np -import jax -import jax._src.numpy.lax_numpy as jnp -import jax._src.numpy.linalg as jnp_linalg +from jax._src import api +from jax._src import config +from jax._src import core +from jax._src import dtypes +from jax._src import lax +from jax._src import numpy as jnp +from jax._src.interpreters import mlir +from jax._src.lax import control_flow +from jax._src.lax import lax as lax_internal +from jax._src.lax import linalg as lax_linalg +from jax._src.lax.linalg import is_constant_shape +from jax._src.lib.mlir.dialects import hlo +from jax._src.numpy import linalg as jnp_linalg from jax._src.numpy import tensor_contractions from jax._src.numpy import reductions from jax._src.numpy import ufuncs -from jax import lax -from jax._src.lax import qdwh -from jax._src.lax import linalg as lax_linalg -from jax._src.lax.stack import Stack +from jax._src.tpu.linalg import qdwh +from jax._src.tpu.linalg.stack import Stack +from jax._src.typing import Array # QDWH-eigh is a recursive algorithm where the structure of the recursion @@ -152,7 +161,7 @@ def _projector_subspace(P, H, n, rank, maxiter=2, swap=False): X = _mask(X, (n, rank)) H_norm = jnp_linalg.norm(H) - thresh = 10.0 * float(jnp.finfo(X.dtype).eps) * H_norm + thresh = 10.0 * float(dtypes.finfo(X.dtype).eps) * H_norm # First iteration skips the matmul. def body_f_after_matmul(X): @@ -282,13 +291,13 @@ class _Subproblem(NamedTuple): in the workspace. """ # The row offset of the block in the matrix of blocks. - offset: jax.Array + offset: Array # The size of the block. - size: jax.Array + size: Array -@partial(jax.jit, static_argnames=('termination_size', 'subset_by_index')) +@api.jit(static_argnames=('termination_size', 'subset_by_index')) def _eigh_work(H, n, termination_size, subset_by_index): """ The main work loop performing the symmetric eigendecomposition of H. Each step recursively computes a projector into the space of eigenvalues @@ -368,7 +377,7 @@ def base_case(B, offset, b, agenda, blocks, eigenvectors): # and GPU eigendecompositions, and for those platforms this algorithm will # only do the right thing if termination_size == 1. H = _mask(H, (b, b)) - eig_vecs, eig_vals = lax.linalg.eigh(H, sort_eigenvalues=False) + eig_vecs, eig_vals = lax_linalg.eigh(H, sort_eigenvalues=False) eig_vecs = _mask(eig_vecs, (b, b)) eig_vals = _mask(eig_vals, (b,)) eig_vecs = tensor_contractions.dot(V, eig_vecs) @@ -443,7 +452,7 @@ def default_case(agenda, blocks, eigenvectors): # handle matrices with clusters of eigenvalues, including rank deficient # matrices. See Nakatsukasa and Higham section 5.2. norm = jnp_linalg.norm(H) - eps = jnp.asarray(jnp.finfo(H.dtype).eps, dtype=norm.dtype) + eps = jnp.asarray(dtypes.finfo(H.dtype).eps, dtype=norm.dtype) off_diag_norm = jnp_linalg.norm( H - jnp.diag(jnp.diag(ufuncs.real(H)).astype(H.dtype))) nearly_diagonal = off_diag_norm <= 5 * eps * norm @@ -490,7 +499,7 @@ def loop_cond(state): def loop_body(state): agenda, blocks, eigenvectors = state (offset, b), agenda = agenda.pop() - which = jnp.where(buckets < b, jnp.iinfo(np.int32).max, buckets) + which = jnp.where(buckets < b, dtypes.iinfo(np.int32).max, buckets) choice = jnp.argmin(which) return lax.switch(choice, branches, offset, b, agenda, blocks, eigenvectors) @@ -551,8 +560,10 @@ def eigh( if N <= termination_size: if n is not None: H = _mask(H, (n, n)) - eig_vals, eig_vecs = lax_linalg.eigh_jacobi( - H, sort_eigenvalues=(sort_eigenvalues or compute_slice) + eig_vecs, eig_vals = lax_linalg.eigh( + H, lower=True, sort_eigenvalues=(sort_eigenvalues or compute_slice), + subset_by_index=None, symmetrize_input=False, + implementation=lax_linalg.EighImplementation.JACOBI, ) if compute_slice: eig_vals = eig_vals[subset_by_index[0] : subset_by_index[1]] @@ -560,7 +571,7 @@ def eigh( return eig_vals, eig_vecs n = N if n is None else n - with jax.default_matmul_precision(precision): + with config.default_matmul_precision(precision): eig_vals, eig_vecs = _eigh_work( H, n, termination_size=termination_size, subset_by_index=subset_by_index ) @@ -573,3 +584,116 @@ def eigh( eig_vecs = eig_vecs[:, sort_idxs] return eig_vals, eig_vecs + + +def _T(x: Array) -> Array: + return lax.transpose(x, (*range(x.ndim - 2), x.ndim - 1, x.ndim - 2)) + + +def _eigh_qdwh_impl(x, *, lower, sort_eigenvalues, subset_by_index): + """QDWH-based eigendecomposition for TPU.""" + *_, m, n = x.shape + assert m == n, (m, n) + + termination_size = 256 + if not core.is_constant_dim(m): + # TODO: maybe we can relax the check below for shape polymorphism? + raise NotImplementedError( + "Shape polymorphism for native lowering for eigh is implemented " + f"only for the batch dimensions: {x.shape}") + + if m <= termination_size and ( + subset_by_index is None or subset_by_index == (0, n) + ): + return lax_linalg.eigh( + x, lower=lower, sort_eigenvalues=sort_eigenvalues, + symmetrize_input=False, + implementation=lax_linalg.EighImplementation.JACOBI + ) + + def eigh_qdwh(x): + if len(x.shape) > 2: + return control_flow.map(eigh_qdwh, x) + + # We should only look at elements from the lower/upper triangle. Reflects + # that triangle into the other triangle to form a Hermitian matrix. + if lower: + mask = lax_internal._tri(bool, (n, n), 0) + else: + mask = lax.bitwise_not(lax_internal._tri(bool, (n, n), -1)) + if dtypes.issubdtype(x.dtype, np.complexfloating): + re = lax.select(mask, lax.real(x), _T(lax.real(x))) + if lower: + im_mask = lax_internal._tri(bool, (n, n), -1) + else: + im_mask = lax.bitwise_not(lax_internal._tri(bool, (n, n), 0)) + im = lax.imag(x) + im = lax.select(im_mask, im, lax.full_like(im, 0)) + im = lax.select(mask, im, -_T(im)) + x = lax.complex(re, im) + else: + x = lax.select(mask, x, _T(x)) + + return eigh( + x, + sort_eigenvalues=sort_eigenvalues, + termination_size=termination_size, + subset_by_index=subset_by_index, + ) + + eig_vals, eig_vecs = eigh_qdwh(x) + return eig_vecs, eig_vals + + +def _eigh_tpu_lowering( + ctx, operand, *, lower, sort_eigenvalues, subset_by_index, algorithm +): + if algorithm is None: + algorithm = lax_linalg.EighImplementation.QDWH + + if algorithm == lax_linalg.EighImplementation.QR: + raise NotImplementedError("QR algorithm is not supported on TPU") + + elif algorithm == lax_linalg.EighImplementation.JACOBI: + operand_aval, = ctx.avals_in + if operand_aval.shape[-1] == 0: + reshape_aval = operand_aval.update(shape=operand_aval.shape[:-1]) + return [ + operand, + hlo.real(mlir.reshape(ctx, operand, reshape_aval)), + ] + + v_aval, w_aval = ctx.avals_out + eigvecs_type = mlir.aval_to_ir_type(v_aval) + eigvals_type = mlir.aval_to_ir_type(w_aval) + result_types = [eigvecs_type, eigvals_type] + + backend_config = f"{int(lower)},{int(sort_eigenvalues)},100,1e-6" + + if any(not is_constant_shape(aval_out.shape) + for aval_out in ctx.avals_out): + result_shapes = [ + mlir.eval_dynamic_shape_as_tensor(ctx, aval_out.shape) + for aval_out in ctx.avals_out + ] + else: + result_shapes = None + op = mlir.custom_call( + "Eigh", + result_types=result_types, + operands=[operand], + backend_config=backend_config, + api_version=1, + result_shapes=result_shapes, + ) + return op.results + elif algorithm == lax_linalg.EighImplementation.QDWH: + return mlir.lower_fun(_eigh_qdwh_impl, multiple_results=True)( + ctx, operand, lower=lower, sort_eigenvalues=sort_eigenvalues, + subset_by_index=subset_by_index) + + else: + raise ValueError(f"Unknown algorithm: {algorithm}") + + +mlir.register_lowering(lax_linalg.eigh_p, _eigh_tpu_lowering, platform='tpu') diff --git a/jax/_src/lax/qdwh.py b/jax/_src/tpu/linalg/qdwh.py similarity index 92% rename from jax/_src/lax/qdwh.py rename to jax/_src/tpu/linalg/qdwh.py index bac3ea957955..35fbbc3b6202 100644 --- a/jax/_src/lax/qdwh.py +++ b/jax/_src/tpu/linalg/qdwh.py @@ -28,11 +28,16 @@ import functools -import jax -import jax.numpy as jnp -from jax import lax +import numpy as np + +from jax._src import api +from jax._src import config from jax._src import core +from jax._src import dtypes +from jax._src import lax +from jax._src import numpy as jnp from jax._src.lax import linalg as lax_linalg +from jax._src.numpy import linalg as jnp_linalg # Helpers for working with padded shapes @@ -42,11 +47,11 @@ def _mask(x, dims, alternative=0): Replaces values outside those dimensions with `alternative`. `alternative` is broadcast with `x`. """ - assert jnp.ndim(x) == len(dims) + assert np.ndim(x) == len(dims) mask = None for i, d in enumerate(dims): if d is not None: - mask_dim_i = lax.broadcasted_iota(jnp.int32, x.shape, i) < d + mask_dim_i = lax.broadcasted_iota(np.int32, x.shape, i) < d mask = mask_dim_i if mask is None else (mask & mask_dim_i) return x if mask is None else jnp.where(mask, x, alternative) @@ -74,7 +79,7 @@ def _use_qr(u, m, n, params): a_minus_e_by_sqrt_c, sqrt_c, e = params M, N = u.shape - y = _dynamic_concat(sqrt_c * u, jnp.eye(N, dtype=jnp.dtype(u)), m) + y = _dynamic_concat(sqrt_c * u, jnp.eye(N, dtype=dtypes.dtype(u)), m) q, _ = lax_linalg.qr(y, full_matrices=False) # q1 = q[:m, :] q1 = _mask(lax.slice(q, (0, 0), (M, N)), (m, n)) @@ -94,7 +99,7 @@ def _use_cholesky(u, m, n, params): """ a_minus_e, c, e = params _, N = u.shape - x = c * (u.T.conj() @ u) + jnp.eye(N, dtype=jnp.dtype(u)) + x = c * (u.T.conj() @ u) + jnp.eye(N, dtype=dtypes.dtype(u)) # Pads the lower-right corner with the identity matrix to prevent the Cholesky # decomposition from failing due to the matrix not being PSD if padded with # zeros. @@ -119,9 +124,9 @@ def _qdwh(x, m, n, max_iterations, eps): # norm(x, 2) such that `alpha >= norm(x, 2)` and `beta` is a lower bound for # the smallest singular value of x. if eps is None: - eps = float(jnp.finfo(x.dtype).eps) - one_norm = jnp.linalg.norm(x, ord=1) - inf_norm = jnp.linalg.norm(x, ord=jnp.inf) + eps = float(dtypes.finfo(x.dtype).eps) + one_norm = jnp_linalg.norm(x, ord=1) + inf_norm = jnp_linalg.norm(x, ord=np.inf) alpha_inverse = lax.rsqrt(one_norm) * lax.rsqrt(inf_norm) alpha_inverse = jnp.where(one_norm == 0, 1, alpha_inverse) u = x * alpha_inverse.astype(x.dtype) @@ -176,7 +181,7 @@ def iteration(k, state, update_fn, coefs, test_convergence): is_not_converged = True if test_convergence: - is_not_converged = jnp.linalg.norm(u - u_prev) > tol_norm + is_not_converged = jnp_linalg.norm(u - u_prev) > tol_norm return u, is_not_converged def iterate(u, coefs, **kwargs): @@ -229,7 +234,7 @@ def body_fun(state): # TODO: Add pivoting. @functools.partial( - jax.jit, static_argnames=('is_hermitian', 'max_iterations', 'eps') + api.jit, static_argnames=('is_hermitian', 'max_iterations', 'eps') ) def qdwh( x, @@ -279,7 +284,7 @@ def qdwh( else: m, n = M, N - with jax.default_matmul_precision('float32'): + with config.default_matmul_precision('float32'): u, h, num_iters, is_converged = _qdwh(x, m, n, max_iterations, eps) return u, h, num_iters, is_converged diff --git a/jax/_src/lax/stack.py b/jax/_src/tpu/linalg/stack.py similarity index 84% rename from jax/_src/lax/stack.py rename to jax/_src/tpu/linalg/stack.py index 882195f17d51..75199cafbd00 100644 --- a/jax/_src/lax/stack.py +++ b/jax/_src/tpu/linalg/stack.py @@ -22,9 +22,10 @@ from typing import Any -import jax -from jax import lax -import jax.numpy as jnp +from jax._src import lax +from jax._src import numpy as jnp +from jax._src import tree_util + class Stack: """A bounded functional stack implementation. Elements may be pytrees.""" @@ -44,8 +45,8 @@ def create(capacity: int, prototype: Any) -> Stack: structure; the specific values are ignored. """ return Stack( - jnp.array(0, jnp.int32), - jax.tree_util.tree_map( + jnp.array(0, 'int32'), + tree_util.tree_map( lambda x: jnp.zeros((capacity,) + tuple(x.shape), x.dtype), prototype)) def empty(self) -> Any: @@ -56,23 +57,23 @@ def push(self, elem: Any) -> Stack: """Pushes `elem` onto the stack, returning the updated stack.""" return Stack( self._size + 1, - jax.tree_util.tree_map( + tree_util.tree_map( lambda x, y: lax.dynamic_update_index_in_dim(x, y, self._size, 0), self._data, elem)) def pop(self) -> tuple[Any, Stack]: """Pops from the stack, returning an (elem, updated stack) pair.""" - elem = jax.tree_util.tree_map( + elem = tree_util.tree_map( lambda x: lax.dynamic_index_in_dim(x, self._size - 1, 0, keepdims=False), self._data) return elem, Stack(self._size - 1, self._data) def flatten(self): - leaves, treedef = jax.tree_util.tree_flatten(self._data) + leaves, treedef = tree_util.tree_flatten(self._data) return ([self._size] + leaves), treedef @staticmethod def unflatten(treedef, leaves): - return Stack(leaves[0], jax.tree_util.tree_unflatten(treedef, leaves[1:])) + return Stack(leaves[0], tree_util.tree_unflatten(treedef, leaves[1:])) -jax.tree_util.register_pytree_node(Stack, Stack.flatten, Stack.unflatten) +tree_util.register_pytree_node(Stack, Stack.flatten, Stack.unflatten) diff --git a/jax/_src/lax/svd.py b/jax/_src/tpu/linalg/svd.py similarity index 76% rename from jax/_src/lax/svd.py rename to jax/_src/tpu/linalg/svd.py index 9f22f130cbb2..091c1a5f0ff9 100644 --- a/jax/_src/lax/svd.py +++ b/jax/_src/tpu/linalg/svd.py @@ -40,13 +40,20 @@ import operator from typing import Any -import jax -from jax import lax +import numpy as np + +from jax._src import api +from jax._src import config from jax._src import core -import jax.numpy as jnp +from jax._src import dtypes +from jax._src import lax +from jax._src import numpy as jnp +from jax._src.interpreters import mlir +from jax._src.lax import linalg as lax_linalg +from jax._src.tpu.linalg import qdwh as tpu_qdwh -@functools.partial(jax.jit, static_argnums=(1, 2, 3, 4)) +@functools.partial(api.jit, static_argnums=(1, 2, 3, 4)) def _svd_tall_and_square_input( a: Any, hermitian: bool, @@ -69,12 +76,12 @@ def _svd_tall_and_square_input( `a = (u * s) @ v.T.conj()`. For `compute_uv=False`, only `s` is returned. """ - u_p, h, _, _ = lax.linalg.qdwh( + u_p, h, _, _ = tpu_qdwh.qdwh( a, is_hermitian=hermitian, max_iterations=max_iterations ) # TODO: Uses `eigvals_only=True` if `compute_uv=False`. - v, s = lax.linalg.eigh( + v, s = lax_linalg.eigh( h, subset_by_index=subset_by_index, sort_eigenvalues=False ) @@ -99,18 +106,19 @@ def _svd_tall_and_square_input( # eigenvalue decomposition and the SVD." SIAM Journal on Scientific Computing # 35, no. 3 (2013): A1325-A1349. def correct_rank_deficiency(u_out): - u_out, r = lax.linalg.qr(u_out, full_matrices=False) + u_out, r = lax_linalg.qr(u_out, full_matrices=False) u_out = u_out @ jnp.diag(jnp.where(jnp.diag(r) >= 0, 1, -1)) return u_out - eps = float(jnp.finfo(a.dtype).eps) + eps = float(dtypes.finfo(a.dtype).eps) do_correction = s_out[-1] <= a.shape[1] * eps * s_out[0] cond_f = lambda args: args[1] body_f = lambda args: (correct_rank_deficiency(args[0]), False) u_out, _ = lax.while_loop(cond_f, body_f, (u_out, do_correction)) return (u_out, s_out, v_out) -@functools.partial(jax.jit, static_argnums=(1, 2, 3, 4, 5)) + +@functools.partial(api.jit, static_argnums=(1, 2, 3, 4, 5)) def svd( a: Any, full_matrices: bool, @@ -196,8 +204,8 @@ def svd( is_flip = True reduce_to_square = False - if full_matrices: - q_full, a_full = lax.linalg.qr(a, pivoting=False, full_matrices=True) + if full_matrices and m > n: + q_full, a_full = lax_linalg.qr(a, pivoting=False, full_matrices=True) q = q_full[:, :n] u_out_null = q_full[:, n:] a = a_full[:n, :] @@ -206,32 +214,32 @@ def svd( # The constant `1.15` comes from Yuji Nakatsukasa's implementation # https://www.mathworks.com/matlabcentral/fileexchange/36830-symmetric-eigenvalue-decomposition-and-the-svd?s_tid=FX_rc3_behav if m > 1.15 * n: - q, a = lax.linalg.qr(a, pivoting=False, full_matrices=False) + q, a = lax_linalg.qr(a, pivoting=False, full_matrices=False) reduce_to_square = True if not compute_uv: - with jax.default_matmul_precision('float32'): + with config.default_matmul_precision('float32'): return _svd_tall_and_square_input( a, hermitian, compute_uv, max_iterations, subset_by_index ) - with jax.default_matmul_precision('float32'): + with config.default_matmul_precision('float32'): u_out, s_out, v_out = _svd_tall_and_square_input( a, hermitian, compute_uv, max_iterations, subset_by_index ) if reduce_to_square: u_out = q @ u_out - if full_matrices: + if full_matrices and m > n: u_out = jnp.hstack((u_out, u_out_null)) is_finite = jnp.all(jnp.isfinite(a)) cond_f = lambda args: jnp.logical_not(args[0]) body_f = lambda args: ( jnp.array(True), - jnp.full_like(u_out, jnp.nan), - jnp.full_like(s_out, jnp.nan), - jnp.full_like(v_out, jnp.nan), + jnp.full_like(u_out, np.nan), + jnp.full_like(s_out, np.nan), + jnp.full_like(v_out, np.nan), ) _, u_out, s_out, v_out = lax.while_loop( cond_f, body_f, (is_finite, u_out, s_out, v_out) @@ -241,3 +249,60 @@ def svd( return (v_out, s_out, u_out.T.conj()) return (u_out, s_out, v_out.T.conj()) + + +def _svd_tpu(a, *, full_matrices, compute_uv, subset_by_index, algorithm=None): + if algorithm is not None and algorithm != lax_linalg.SvdAlgorithm.DEFAULT: + raise NotImplementedError( + "The SVD algorithm parameter is not implemented on TPU.") + + batch_dims = a.shape[:-2] + fn = functools.partial( + svd, + full_matrices=full_matrices, + compute_uv=compute_uv, + subset_by_index=subset_by_index, + ) + for _ in range(len(batch_dims)): + fn = api.vmap(fn) + + if compute_uv: + u, s, vh = fn(a) + return [s, u, vh] + else: + s = fn(a) + return [s] + + +def _svd_tpu_lowering_rule( + ctx, operand, *, full_matrices, compute_uv, subset_by_index, algorithm=None +): + operand_aval, = ctx.avals_in + m, n = operand_aval.shape[-2:] + + if algorithm is not None and algorithm not in [ + lax_linalg.SvdAlgorithm.DEFAULT, + lax_linalg.SvdAlgorithm.POLAR, + ]: + raise NotImplementedError( + 'Only the POLAR (which is also DEFAULT on TPU) SVD algorithm is' + ' supported on TPU.' + ) + + if m == 0 or n == 0: + return mlir.lower_fun(lax_linalg._empty_svd, multiple_results=True)( + ctx, + operand, + full_matrices=full_matrices, + compute_uv=compute_uv, + ) + + return mlir.lower_fun(_svd_tpu, multiple_results=True)( + ctx, + operand, + full_matrices=full_matrices, + compute_uv=compute_uv, + subset_by_index=subset_by_index, + ) + +mlir.register_lowering(lax_linalg.svd_p, _svd_tpu_lowering_rule) diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index 4089e047f8b0..c7c2ceba1f0e 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -22,22 +22,21 @@ from collections.abc import Callable, Sequence import dataclasses import enum -import functools import io -import os -import time -from typing import Any +import json +from typing import Any, TypedDict -import jax +from jax._src import api from jax._src import config from jax._src import core +from jax._src import dispatch from jax._src import sharding_impls +from jax._src.cloud_tpu_init import is_cloud_tpu_older_than +from jax._src.frozen_dict import FrozenDict +from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.lib import tpu -from jax._src.lib import xla_client -from jax.interpreters import xla from jaxlib.mlir import ir -from jaxlib.mlir.dialects import stablehlo from jaxlib.mlir.passmanager import PassManager try: @@ -46,16 +45,6 @@ except ImportError: FLAGS = {} -_MOSAIC_USE_PYTHON_PIPELINE = config.bool_state( - name="mosaic_use_python_pipeline", - default=False, - help=( - "Run the initial Mosaic MLIR passes from Python, when as_tpu_kernel" - " is called (for Pallas, this happens at JAX lowering time), instead of" - " later within XLA." - ), -) - _MOSAIC_ALLOW_HLO = config.bool_state( name="jax_mosaic_allow_hlo", default=False, @@ -63,20 +52,47 @@ ) -# This tracks the latest Mosaic IR version with a monthly delay. -FWD_COMPAT_IR_VERSION = 3 +# Controls the IR serialization version. Upon incrementing the +# default version in jaxlib/mosaic/dialect/tpu/transforms/serde.cc we must +# continue to use the old serialization version when in forward compatibility +# mode: for 1 month when exporting, or when using old cloud TPU. +# +# This can be achieved by adding: +# if ctx.is_forward_compat() or backend is None or is_cloud_tpu_older_than(): +# return +# return None +# +# We should also add a TODO to remove the conditional one month later. +def get_ir_version(ctx: mlir.LoweringRuleContext) -> int | None: + backend = ctx.module_context.get_backend(optional=True) + # TODO(apaszke): remove the forward compatibility check after 2025-12-5. + if ( + ctx.is_forward_compat() + or backend is None + or is_cloud_tpu_older_than(2025, 11, 5, backend) + ): + return 8 + return None tpu_custom_call_p = core.Primitive("tpu_custom_call") -tpu_custom_call_p.def_impl( - functools.partial(xla.apply_primitive, tpu_custom_call_p)) tpu_custom_call_p.multiple_results = True +dispatch.simple_impl(tpu_custom_call_p) -def get_target_shape(hardware_generation: int) -> tuple[int, int]: - """Returns the target shape for the given hardware generation.""" - del hardware_generation - return (8, 128) +def tpu_custom_call_batcher(axis_data, args, dims, **kwargs): + if axis_data.size != 1: + raise NotImplementedError( + "tpu_custom_call does not support non-trivial batching." + ) + unbatched_args = tuple( + a if (d is batching.not_mapped or d is None) else a[d] + for a, d in zip(args, dims, strict=True) + ) + out_unbatched = tpu_custom_call_p.bind(*unbatched_args, **kwargs) + out = tuple(o[None] for o in out_unbatched) + return out, (0,) * len(out) +batching.fancy_primitive_batchers[tpu_custom_call_p] = tpu_custom_call_batcher class MemorySpace(enum.Enum): @@ -84,6 +100,8 @@ class MemorySpace(enum.Enum): VMEM = enum.auto() SEMAPHORE_MEM = enum.auto() SMEM = enum.auto() + HOST = enum.auto() + SC_SCALAR_SEMAPHORE_MEM = enum.auto() @property def color(self) -> int: @@ -93,23 +111,43 @@ def color(self) -> int: return 1 elif self == MemorySpace.SEMAPHORE_MEM: return 2 + elif self == MemorySpace.SC_SCALAR_SEMAPHORE_MEM: + return 8 elif self == MemorySpace.SMEM: return 4 + elif self == MemorySpace.HOST: + return 5 else: raise ValueError("invalid memory space: " + str(self)) -@dataclasses.dataclass(frozen=True) -class CostEstimate: +class CostEstimate(TypedDict): flops: int transcendentals: int bytes_accessed: int + remote_bytes_transferred: int = 0 def to_json(self) -> bytes: return ( - f'{{"flops": {self.flops}, "transcendentals": {self.transcendentals},' - f' "bytes_accessed": {self.bytes_accessed}}}' - ).encode('ascii') + f'{{"flops": {self["flops"]}, "transcendentals":' + f' {self["transcendentals"]}, "bytes_accessed":' + f' {self["bytes_accessed"]}, "remote_bytes_transferred":' + f' {self["remote_bytes_transferred"]}}}' + ).encode("ascii") + + +class TpuSideEffectType(enum.Enum): + # No side effects, can be deduplicated / removed if unused. + PURE = "pure" + # Cannot be deduplicated, but can be removed if unused. + DATAFLOW_SIDE_EFFECTING = "dataflow_side_effecting" + # Cannot be deduplicated or removed. + SIDE_EFFECTING = "side_effecting" + + +class Tiling(enum.Enum): + COMPACT = "TILING_COMPACT" + SPARSE_CORE = "TILING_SPARSE_CORE" @dataclasses.dataclass(frozen=True) @@ -124,10 +162,24 @@ class CustomCallBackendConfig: needs_layout_passes: bool vmem_limit_bytes: int | None flags: dict[str, bool | int | float] | None - allow_input_fusion: list[bool] | None + allow_input_fusion: Sequence[bool] | None serialization_format: int | None internal_scratch_in_bytes: int | None output_memory_spaces: tuple[MemorySpace | None, ...] | None + disable_bounds_checks: bool + active_core_count: int | None + input_memory_spaces: tuple[MemorySpace | None, ...] | None + skip_device_barrier: bool + shape_invariant_numerics: bool + tiling: Tiling | None = None # Only used for SparseCore. + + def __post_init__(self): + if self.allow_input_fusion is not None: + object.__setattr__(self, "allow_input_fusion", + tuple(self.allow_input_fusion)) + if self.cost_estimate is not None: + object.__setattr__(self, "cost_estimate", + FrozenDict(self.cost_estimate)) # We omit the body while printing, because primitive params get embedded # in HLO metadata, and the body blows up its size. @@ -149,7 +201,7 @@ def to_json(self) -> bytes: config.write(str(self.collective_id).encode("ascii")) if self.cost_estimate is not None: config.write(b', "cost_estimate": ') - config.write(self.cost_estimate.to_json()) + config.write(_compact_json_object(**self.cost_estimate)) if self.needs_hlo_passes: config.write(b', "needs_hlo_passes": ') config.write(str(self.needs_hlo_passes).lower().encode("ascii")) @@ -159,11 +211,13 @@ def to_json(self) -> bytes: if self.needs_layout_passes: config.write(b', "needs_layout_passes": ') config.write(str(self.needs_layout_passes).lower().encode("ascii")) + if not self.shape_invariant_numerics: + config.write(b', "shape_invariant_numerics": ') + config.write(str(self.shape_invariant_numerics).lower().encode("ascii")) if self.allow_input_fusion is not None: config.write(b', "allow_input_fusion": [') for i, value in enumerate(self.allow_input_fusion): config.write(b"true" if value else b"false") - # config.write(str(value).lower().encode("ascii")) if i + 1 != len(self.allow_input_fusion): config.write(b",") config.write(b"]") @@ -178,7 +232,44 @@ def to_json(self) -> bytes: color = memory_space.color if memory_space is not None else -1 config.write(str(color).encode("ascii")) config.write(b"]") + if self.input_memory_spaces is not None: + comma = False + for i, input_memory_space in enumerate(self.input_memory_spaces): + if input_memory_space is None: + continue + if input_memory_space is MemorySpace.SMEM: + # TODO(sharadmv): Add support for SMEM (though atm, XLA will not + # page out SMEM arrays). + continue + if input_memory_space not in ( + MemorySpace.HBM, + MemorySpace.VMEM, + MemorySpace.SMEM, + ): + raise NotImplementedError( + "input_memory_space_colors only supports HBM, VMEM and SMEM" + ) + if comma: + config.write(b",") + else: + config.write(b', "input_memory_space_colors": [') + config.write( + f'{{"operand_index":{i},"color":{input_memory_space.color}}}' + .encode("ascii") + ) + comma = True + if comma: + config.write(b"]") + if self.disable_bounds_checks: + config.write(b', "disable_bounds_checks": ') + config.write(str(self.disable_bounds_checks).lower().encode("ascii")) + if self.skip_device_barrier: + config.write(b', "skip_device_barrier": ') + config.write(str(self.skip_device_barrier).lower().encode("ascii")) config.write(b"}") # End of custom_call_config. + if self.tiling is not None: + config.write(b', "sparse_core_config": ') + config.write(_compact_json_object(tiling=self.tiling.value)) if self.device_type is not None: config.write(b', "device_type": ') config.write( @@ -196,7 +287,7 @@ def to_json(self) -> bytes: for i, (flag, value) in enumerate(self.flags.items()): config.write(b'{"flag_type": "') config.write(flag.encode("ascii")) - config.write(b'", value: {') + config.write(b'", "value": {') if isinstance(value, bool): config.write(b'"boolean_value": ') config.write(b"true" if value else b"false") @@ -212,32 +303,43 @@ def to_json(self) -> bytes: if i + 1 != len(self.flags): config.write(b",") config.write(b"]") + if self.device_type == "sparsecore" and self.active_core_count == 1: + config.write(b', "megachip_parallelism_config": {"cores": ["0"]}') config.write(b"}") return config.getvalue() +def _compact_json_object(**kwargs: Any) -> bytes: + return json.dumps( + kwargs, sort_keys=True, indent=0, separators=(",", ":") + ).encode("ascii") + + @tpu_custom_call_p.def_abstract_eval def _tpu_custom_call_abstract_eval(*_, out_avals, **__): return out_avals def _avals_to_layouts(avals) -> Sequence[Sequence[int]]: - return [tuple(range(a.ndim - 1, -1, -1)) for a in avals] + return [tuple(range(a.ndim - 1, -1, -1)) for a in avals] # pytype: disable=attribute-error def _tpu_custom_call_lowering( ctx: mlir.LoweringRuleContext, *in_nodes, # pylint: disable=missing-function-docstring config: CustomCallBackendConfig, - has_side_effects: bool, + has_side_effects: TpuSideEffectType, kernel_name: str | None, out_avals: Any, input_output_aliases: tuple[tuple[int, int], ...], -) -> ...: + metadata: Any | None, +) -> ir.OpResultList: result_types = [mlir.aval_to_ir_type(aval) for aval in out_avals] axis_context = ctx.module_context.axis_context if isinstance(axis_context, sharding_impls.SPMDAxisContext): - if axis_context.manual_axes != frozenset(axis_context.mesh.axis_names): + manual_axes = axis_context.manual_axes | set(axis_context.mesh.manual_axes) + if (axis_context.manual_axes and + manual_axes != frozenset(axis_context.mesh.axis_names)): raise NotImplementedError( "Mosaic kernels cannot be automatically partitioned. Please wrap the" " call in a shard_map." @@ -264,21 +366,29 @@ def _tpu_custom_call_lowering( # information. if kernel_name is not None: extra_attributes = dict(kernel_name=ir.StringAttr.get(kernel_name)) - has_side_effects = has_side_effects if has_side_effects is not None else False call = mlir.custom_call( "tpu_custom_call", result_types=result_types, operands=in_nodes, backend_config=config.to_json(), api_version=1, - has_side_effect=has_side_effects, + has_side_effect=has_side_effects != TpuSideEffectType.PURE, operand_output_aliases=dict(input_output_aliases), operand_layouts=_avals_to_layouts(ctx.avals_in), result_layouts=_avals_to_layouts(ctx.avals_out), result_shapes=result_shapes, extra_attributes=extra_attributes, ) - + metadata_dict = {} + if metadata is not None: + metadata_dict["kernel_metadata"] = ir.StringAttr.get( + _compact_json_object(**metadata) + ) + assert isinstance(has_side_effects, TpuSideEffectType) + if has_side_effects == TpuSideEffectType.DATAFLOW_SIDE_EFFECTING: + metadata_dict["xla_allow_dce_side_effecting_op"] = ir.StringAttr.get("true") + if metadata_dict: + call.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(metadata_dict) return call.results @@ -286,211 +396,22 @@ def _tpu_custom_call_lowering( platform="tpu") -def _lower_tpu_kernel( - module: ir.Module, - hardware_generation: int, - target_shape: tuple[int, int], - kernel_name: str | None = None, -) -> ir.Module: - """Runs MLIR passes lowering the given module to an MLIR module. - - Uses Python versions of canonicalize-mosaic,infer-memref-layout and - apply-vector-layout. - - Args: - module: The MLIR module to lower. - hardware_generation: The TPU hardware generation to target. - target_shape: The target shape of (sublane_count, lane_count). - - Returns: - An MLIR module implementing the kernel. - """ - try: - module.operation.verify() - except ir.MLIRError as e: - raise ValueError("The compiled module fails MLIR verification") from e - - timestamp = time.time_ns() - dump_cnt = [0] - - def get_dump_file_prefix() -> str: - s = f"{timestamp}-{dump_cnt[0]:04}" - dump_cnt[0] += 1 - return s - - with module.context as ctx, module.operation.location as _: - ctx.append_dialect_registry(mlir.upstream_dialects) - ctx.load_all_available_dialects() - tpu.register_dialect(ctx) - stablehlo.register_dialect(ctx) - dump_mlir(module, "original", get_dump_file_prefix(), kernel_name) - - if _MOSAIC_ALLOW_HLO.value: - # Run dialect conversion: StableHLO -> linalg -> vector. - pipeline = [ - "func.func(stablehlo-legalize-to-linalg)", - "func.func(linalg-vectorization)", - ] - pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})") - pipeline.run(module.operation) - dump_mlir(module, "post-hlo-conversion", get_dump_file_prefix(), kernel_name) - - sl_cnt, l_cnt = target_shape - # Note: we don't pass the TpuTilingFlags here, since we don't know the - # tiling decisions made by the compiler / what flags are enabled at this - # point, so we assume everything can be tiled up to default tiling. - pipeline = [ - "func.func(tpu-infer-memref-layout{" - f" hardware-generation={hardware_generation}" - f" sublane-count={sl_cnt}" - f" lane-count={l_cnt}" - "})" - ] - pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})") - pipeline.run(module.operation) - dump_mlir(module, "post-infer-memref-layout", get_dump_file_prefix(), kernel_name) - - pipeline = [ - "canonicalize", - "cse", - ] - pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})") - pipeline.run(module.operation) - dump_mlir( - module, - "post-infer-memref-layout-simplify", - get_dump_file_prefix(), - kernel_name, - ) - - try: - on_device_checks = FLAGS["xla_mosaic_on_device_checks"].value - except KeyError: - on_device_checks = False - - if checks := on_device_checks: - checks = set(checks.split(",")) - if checks == {"bounds"}: # We only support one kind of checks now. - pipeline = PassManager.parse( - "builtin.module(func.func(debug-assert-insertion))" - ) - pipeline.run(module.operation) - dump_mlir(module, "post-assert-insertion", get_dump_file_prefix(), kernel_name) - elif checks: - checks.discard("bounds") - raise ValueError( - f"Unrecognized on-device check categories: {', '.join(checks)}" - ) - - # Legacy pipeline always runs in compatibility mode. - compatibility_mode = True - pipeline = [ - ( - f"func.func(tpu-canonicalize-mosaic{{hardware-generation={hardware_generation} compatibility-mode={compatibility_mode}}})" - ), - ] - pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})") - pipeline.run(module.operation) - dump_mlir(module, "post-canonicalize-mosaic", get_dump_file_prefix(), kernel_name) - - pipeline = [ - ( - "func.func(tpu-infer-vector-layout{" - f" hardware-generation={hardware_generation}" - f" sublane-count={sl_cnt} lane-count={l_cnt}" - "})" - ), - ] - pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})") - pipeline.run(module.operation) - dump_mlir(module, "post-infer-vector-layout", get_dump_file_prefix(), kernel_name) - - pipeline = [ - ( - "func.func(tpu-relayout-insertion{" - f" sublane-count={sl_cnt} lane-count={l_cnt}" - f" hardware-generation={hardware_generation}" - "})" - ), - ] - pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})") - pipeline.run(module.operation) - dump_mlir(module, "post-relayout-insertion", get_dump_file_prefix(), kernel_name) - - mxu_size = 128 if hardware_generation < 6 else 256 - pipeline = [ - "func.func(tpu-apply-vector-layout{" - f" sublane-count={sl_cnt} lane-count={l_cnt}" - f" hardware-generation={hardware_generation}" - f" mxu-contracting-size={mxu_size} mxu-noncontracting-size={mxu_size}" - f" max-sublanes-in-scratch={sl_cnt * (sl_cnt + 1)}" - "})" - ] - pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})") - pipeline.run(module.operation) - dump_mlir(module, "post-apply-vector-layout", get_dump_file_prefix(), kernel_name) - - pipeline = [ - "canonicalize", - "cse", - ] - pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})") - pipeline.run(module.operation) - dump_mlir( - module, - "post-apply-vector-layout-simplify", - get_dump_file_prefix(), - kernel_name, - ) - - return module - - def _lower_mosaic_module_to_asm( module: ir.Module, *, - backend: str, - device_type: str | None, - kernel_name: str | None, ir_version: int | None = None, -) -> tuple[ir.Module, tuple[bool, bool, bool, bool]]: +) -> tuple[ir.Module, tuple[bool, bool]]: has_communication, has_custom_barrier = tpu.private_has_communication( module.operation ) - needs_hlo_passes = _MOSAIC_ALLOW_HLO.value - needs_layout_passes = not device_type # We'll mutate the module, so clone it with module.context as ctx, module.operation.location as _: - if needs_layout_passes and _MOSAIC_USE_PYTHON_PIPELINE.value: - module = ir.Module.parse( - module.operation.get_asm(binary=True, enable_debug_info=True) - ) - module_op = module.operation - some_tpu = jax.devices(backend)[0] - device_kind = some_tpu.device_kind - if not device_kind.startswith("TPU v"): - raise ValueError( - f"Unrecognized TPU device kind: {device_kind}. " - "tpu_custom_call cannot be lowered on a machine without TPUs " - "when mosaic_use_python_pipeline=True.") - hardware_generation = int(device_kind[len("TPU v")]) - target_shape = get_target_shape(hardware_generation) - module = _lower_tpu_kernel( - module, hardware_generation, target_shape=target_shape, kernel_name=kernel_name, - ) - needs_hlo_passes = False - needs_layout_passes = False - else: - module_op = module.operation.clone() + module_op = module.operation.clone() prev_allow_unregistered_dialects = ctx.allow_unregistered_dialects ctx.allow_unregistered_dialects = True - # TODO(apaszke): Remove once the minimum jaxlib version is at least 0.4.37. - if jax.version._version_as_tuple(jax.lib.__version__) < (0, 4, 37): - target_version = "" - else: - target_version = ( - f"target-version={ir_version}" if ir_version is not None else "" - ) + target_version = ( + f"target-version={ir_version}" if ir_version is not None else "" + ) try: pipeline = PassManager.parse( "builtin.module(mosaic-serde{serialize=true " + target_version + "})" @@ -504,8 +425,6 @@ def _lower_mosaic_module_to_asm( return asm, ( has_communication, has_custom_barrier, - needs_hlo_passes, - needs_layout_passes, ) @@ -539,42 +458,121 @@ def assign_device_type_based_on_core_type(op: ir.Operation) -> ir.WalkResult: ) if tensorcore_func_found and sparsecore_func_found: raise ValueError( - "A single Mosaic kernel cannot contain both " - "TensorCore and SparseCore functions." + "A single Mosaic kernel cannot contain both TensorCore and SparseCore" + " functions." ) if sparsecore_func_found: return "sparsecore" return None +def _get_active_core_count(module: ir.Module) -> int | None: + + def get_core_parallel_dim_size( + dim_semantics: ir.ArrayAttr, + iter_bounds: ir.DenseI64ArrayAttr, + other_subkernel_core_dim_size: int | None = None) -> int | None: + + if len(iter_bounds) != len(dim_semantics): + raise ValueError( + "The iteration bounds and dimension semantics attributes must have" + " the same number of elements." + ) + + subkernel_core_dim_size = None + + for dim_idx, (dim_size, dim_sem) in enumerate( + zip(iter_bounds, dim_semantics) + ): + if str(dim_sem) != "#tpu.dimension_semantics": + continue + + if ir.ShapedType.is_dynamic_size(dim_size): + raise ValueError( + "The iteration bound corresponding to the core-parallel dimension " + f"{dim_idx} must be statically known." + ) + if subkernel_core_dim_size is not None: + raise ValueError( + "A single Mosaic subkernel cannot contain multiple core sharding " + "dimensions." + ) + if ( + other_subkernel_core_dim_size is not None + and other_subkernel_core_dim_size != dim_size + ): + raise ValueError( + "The iteration bound corresponding to the core-parallel dimension " + "be the same across all subkernels." + ) + subkernel_core_dim_size = dim_size + + return subkernel_core_dim_size + + core_parallel_dim_size = None + + for op in module.body.operations: + if op.operation.name != "func.func": + continue + + if ( + "iteration_bounds" not in op.attributes + or "dimension_semantics" not in op.attributes + ): + continue + + try: + iter_bounds = ir.DenseI64ArrayAttr(op.attributes["iteration_bounds"]) + except ValueError as e: + e.add_note("The iteration bounds attribute must be an array.") + raise + try: + dim_semantics = ir.ArrayAttr(op.attributes["dimension_semantics"]) + except ValueError as e: + e.add_note("The dimension semantics attribute must be an array.") + raise + + core_parallel_dim_size = get_core_parallel_dim_size( + dim_semantics=dim_semantics, + iter_bounds=iter_bounds, + other_subkernel_core_dim_size=core_parallel_dim_size, + ) + + return core_parallel_dim_size + + def _lower_to_custom_call_config( module: ir.Module, *, - backend: str, - device_type: str | None, vmem_limit_bytes: int | None, cost_estimate: CostEstimate | None, flags: dict[str, bool | int | float] | None, - allow_input_fusion: list[bool] | None, + allow_input_fusion: Sequence[bool] | None, internal_scratch_in_bytes: int | None, collective_id: int | None, serialization_format: int | None, output_memory_spaces: tuple[MemorySpace | None, ...] | None = None, - kernel_name: str | None = None, ir_version: int | None = None, + disable_bounds_checks: bool = False, + input_memory_spaces: tuple[MemorySpace | None, ...] | None = None, + skip_device_barrier: bool = False, + allow_collective_id_without_custom_barrier: bool = False, + shape_invariant_numerics: bool = False, + needs_layout_passes: bool | None = None, + tiling: Tiling | None = None, ) -> CustomCallBackendConfig: + device_type = _get_device_type(module) + needs_hlo_passes = _MOSAIC_ALLOW_HLO.value + if needs_layout_passes is None: + needs_layout_passes = not device_type lowered_module_asm, ( has_communication, has_custom_barrier, - needs_hlo_passes, - needs_layout_passes, ) = _lower_mosaic_module_to_asm( module, - backend=backend, - device_type=device_type, - kernel_name=kernel_name, ir_version=ir_version, ) + active_core_count = _get_active_core_count(module) return _lowered_to_custom_call_config( lowered_module_asm, vmem_limit_bytes=vmem_limit_bytes, @@ -590,6 +588,13 @@ def _lower_to_custom_call_config( needs_hlo_passes=needs_hlo_passes, needs_layout_passes=needs_layout_passes, output_memory_spaces=output_memory_spaces, + disable_bounds_checks=disable_bounds_checks, + active_core_count=active_core_count, + input_memory_spaces=input_memory_spaces, + skip_device_barrier=skip_device_barrier, + allow_collective_id_without_custom_barrier=allow_collective_id_without_custom_barrier, + shape_invariant_numerics=shape_invariant_numerics, + tiling=tiling, ) @@ -599,7 +604,7 @@ def _lowered_to_custom_call_config( vmem_limit_bytes: int | None, cost_estimate: CostEstimate | None, flags: dict[str, bool | int | float] | None, - allow_input_fusion: list[bool] | None, + allow_input_fusion: Sequence[bool] | None, internal_scratch_in_bytes: int | None, collective_id: int | None, serialization_format: int | None, @@ -609,13 +614,20 @@ def _lowered_to_custom_call_config( needs_layout_passes: bool, device_type: str | None, output_memory_spaces: tuple[MemorySpace | None, ...] | None = None, + disable_bounds_checks: bool = False, + active_core_count: int | None = None, + input_memory_spaces: tuple[MemorySpace | None, ...] | None = None, + skip_device_barrier: bool = False, + allow_collective_id_without_custom_barrier: bool = False, + shape_invariant_numerics: bool = False, + tiling: Tiling | None = None, ): if has_custom_barrier: if collective_id is None: raise ValueError( "collective_id has to be specified when using a custom barrier" ) - elif collective_id is not None: + elif collective_id is not None and not allow_collective_id_without_custom_barrier: raise ValueError( "collective_id has to be unspecified or None when not using a custom" " barrier" @@ -625,7 +637,7 @@ def _lowered_to_custom_call_config( "vmem_limit_bytes must be an int: provided with a" f" {type(vmem_limit_bytes)}." ) - config = CustomCallBackendConfig( + return CustomCallBackendConfig( lowered_module_asm, has_communication, collective_id, @@ -639,8 +651,13 @@ def _lowered_to_custom_call_config( serialization_format, internal_scratch_in_bytes, output_memory_spaces, + disable_bounds_checks, + active_core_count=active_core_count, + input_memory_spaces=input_memory_spaces, + skip_device_barrier=skip_device_barrier, + shape_invariant_numerics=shape_invariant_numerics, + tiling=tiling, ) - return config def lower_module_to_custom_call( @@ -648,34 +665,50 @@ def lower_module_to_custom_call( *in_nodes: ir.Value, module: ir.Module, out_type: Any, - backend: str, kernel_name: str, cost_estimate: CostEstimate | None, vmem_limit_bytes: int | None, flags: dict[str, bool | int | float] | None, - allow_input_fusion: list[bool] | None, + allow_input_fusion: Sequence[bool] | None, input_output_aliases: tuple[tuple[int, int], ...], internal_scratch_in_bytes: int | None, collective_id: int | None, - has_side_effects: bool, + has_side_effects: bool | TpuSideEffectType, serialization_format: int | None, output_memory_spaces: tuple[MemorySpace | None, ...] | None, - device_type: str | None, + disable_bounds_checks: bool = False, + input_memory_spaces: tuple[MemorySpace | None, ...] | None, + metadata: Any | None = None, + skip_device_barrier: bool = False, + allow_collective_id_without_custom_barrier: bool = False, + shape_invariant_numerics: bool = False, + needs_layout_passes: bool | None = None, + tiling: Tiling | None = None, ) -> Sequence[ir.Value]: + if isinstance(has_side_effects, bool): + has_side_effects = ( + TpuSideEffectType.PURE + if not has_side_effects + else TpuSideEffectType.SIDE_EFFECTING + ) config = _lower_to_custom_call_config( module, - backend=backend, vmem_limit_bytes=vmem_limit_bytes, cost_estimate=cost_estimate, flags=flags, allow_input_fusion=allow_input_fusion, internal_scratch_in_bytes=internal_scratch_in_bytes, collective_id=collective_id, - device_type=device_type, serialization_format=serialization_format, output_memory_spaces=output_memory_spaces, - kernel_name=kernel_name, - ir_version=FWD_COMPAT_IR_VERSION if ctx.is_forward_compat() else None, + ir_version=get_ir_version(ctx), + disable_bounds_checks=disable_bounds_checks, + input_memory_spaces=input_memory_spaces, + skip_device_barrier=skip_device_barrier, + allow_collective_id_without_custom_barrier=allow_collective_id_without_custom_barrier, + shape_invariant_numerics=shape_invariant_numerics, + needs_layout_passes=needs_layout_passes, + tiling=tiling, ) return _tpu_custom_call_lowering( ctx, @@ -685,6 +718,7 @@ def lower_module_to_custom_call( kernel_name=kernel_name, out_avals=out_type, input_output_aliases=input_output_aliases, + metadata=metadata, ) @@ -693,24 +727,26 @@ def as_tpu_kernel( out_type: Any, *, cost_estimate: CostEstimate | None = None, - backend: str | xla_client.Client = "tpu", kernel_name: str | None = None, vmem_limit_bytes: int | None = None, flags: dict[str, bool | int | float] | None = None, - allow_input_fusion: list[bool] | None = None, + allow_input_fusion: Sequence[bool] | None = None, input_output_aliases: tuple[tuple[int, int], ...] = (), internal_scratch_in_bytes: int | None = None, collective_id: int | None = None, - has_side_effects: bool = False, + has_side_effects: TpuSideEffectType = TpuSideEffectType.PURE, serialization_format: int | None = 1, output_memory_spaces: tuple[MemorySpace | None, ...] | None = None, + disable_bounds_checks: bool = False, + input_memory_spaces: tuple[MemorySpace | None, ...] | None = None, + shape_invariant_numerics: bool = False, + needs_layout_passes: bool | None = None, + metadata: Any | None = None, + _ir_version: int | None = None, ) -> Callable[..., Any]: """Turns an MLIR Mosaic kernel into a JAX-compatible function.""" - device_type = _get_device_type(module) config = _lower_to_custom_call_config( module, - backend=backend, - device_type=device_type, vmem_limit_bytes=vmem_limit_bytes, cost_estimate=cost_estimate, flags=flags, @@ -719,7 +755,11 @@ def as_tpu_kernel( collective_id=collective_id, serialization_format=serialization_format, output_memory_spaces=output_memory_spaces, - kernel_name=kernel_name, + disable_bounds_checks=disable_bounds_checks, + input_memory_spaces=input_memory_spaces, + shape_invariant_numerics=shape_invariant_numerics, + needs_layout_passes=needs_layout_passes, + ir_version=_ir_version, ) return _as_jax_callable( config, @@ -727,6 +767,7 @@ def as_tpu_kernel( out_type, kernel_name=kernel_name, input_output_aliases=input_output_aliases, + metadata=metadata, ) @@ -738,21 +779,30 @@ def lowered_as_tpu_kernel( cost_estimate: CostEstimate | None = None, needs_hlo_passes: bool = False, needs_layout_passes: bool = False, - device_type: str | None = None, has_communication: bool = False, - has_side_effects: bool = False, + has_side_effects: bool | TpuSideEffectType = False, has_custom_barrier: bool = False, kernel_name: str | None = None, vmem_limit_bytes: int | None = None, flags: dict[str, bool | int | float] | None = None, - allow_input_fusion: list[bool] | None = None, + allow_input_fusion: Sequence[bool] | None = None, input_output_aliases: tuple[tuple[int, int], ...] = (), serialization_format: int | None = None, internal_scratch_in_bytes: int | None = None, + disable_bounds_checks: bool = False, + metadata: Any | None = None, + allow_collective_id_without_custom_barrier: bool = False, ) -> Callable[..., Any]: + device_type = _get_device_type(lowered_module) lowered_module_asm = lowered_module.operation.get_asm( binary=True, enable_debug_info=True ) + if isinstance(has_side_effects, bool): + has_side_effects = ( + TpuSideEffectType.PURE + if not has_side_effects + else TpuSideEffectType.DATAFLOW_SIDE_EFFECTING + ) config = _lowered_to_custom_call_config( lowered_module_asm, vmem_limit_bytes=vmem_limit_bytes, @@ -767,6 +817,8 @@ def lowered_as_tpu_kernel( has_communication=has_communication, needs_hlo_passes=needs_hlo_passes, needs_layout_passes=needs_layout_passes, + disable_bounds_checks=disable_bounds_checks, + allow_collective_id_without_custom_barrier=allow_collective_id_without_custom_barrier, ) return _as_jax_callable( config, @@ -774,16 +826,18 @@ def lowered_as_tpu_kernel( out_type, kernel_name=kernel_name, input_output_aliases=input_output_aliases, + metadata=metadata, ) def _as_jax_callable( config: CustomCallBackendConfig, - has_side_effects: bool, + has_side_effects: TpuSideEffectType, out_type: Any, *, kernel_name: str | None, input_output_aliases: tuple[tuple[int, int], ...], + metadata: Any | None, ) -> Callable[..., Any]: unpack = False if not isinstance(out_type, collections.abc.Iterable): @@ -800,25 +854,8 @@ def apply_kernel(*args): kernel_name=kernel_name, out_avals=out_avals, input_output_aliases=input_output_aliases, + metadata=metadata, ) return result[0] if unpack else result - return jax.jit(apply_kernel) - - -def dump_mlir( - module: ir.Module, name: str, prefix: str, kernel_name: str | None = None -): - """A helper function to dump mosaic mlir module""" - try: - should_dump = FLAGS["xla_mosaic_dump_to"].value - except KeyError: - return - if should_dump == "sponge": - outdir = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR", None) - if outdir: - if kernel_name: - name = f"{kernel_name}-{name}" - path = os.path.join(outdir, f"{prefix}-mosaic-dump-{name}-py.txt") - with open(path, "w") as f: - f.write(str(module)) + return api.jit(apply_kernel) diff --git a/jax/_src/traceback_util.py b/jax/_src/traceback_util.py index d66cbb912a99..b7641c209589 100644 --- a/jax/_src/traceback_util.py +++ b/jax/_src/traceback_util.py @@ -17,22 +17,27 @@ from collections.abc import Callable import functools import os -import sys import traceback import types from typing import Any, TypeVar, cast from jax._src import config from jax._src import util -from jax._src.lib import xla_extension +from jax._src.lib import _jax C = TypeVar("C", bound=Callable[..., Any]) -_exclude_paths: list[str] = [__file__, util.__file__] +_exclude_paths: list[str] = [] def register_exclusion(path: str): _exclude_paths.append(path) + # TODO(nbasile): Remove hasattr checks after jaxlib 0.8.1 release + if hasattr(_jax, "add_exclude_path"): + _jax.add_exclude_path(path) + +register_exclusion(__file__) +register_exclusion(util.__file__) _jax_message_append = ( 'The stack trace below excludes JAX-internal frames.\n' @@ -56,8 +61,10 @@ def _path_starts_with(path: str, path_prefix: str) -> bool: return False def include_frame(f: types.FrameType) -> bool: - return not any(_path_starts_with(f.f_code.co_filename, path) - for path in _exclude_paths) + return include_filename(f.f_code.co_filename) + +def include_filename(filename: str) -> bool: + return not any(_path_starts_with(filename, path) for path in _exclude_paths) # When scanning stack traces, we might encounter frames from cpython that are # removed from printed stack traces, such as frames from parts of importlib. We @@ -67,7 +74,7 @@ def _ignore_known_hidden_frame(f: types.FrameType) -> bool: def _add_tracebackhide_to_hidden_frames(tb: types.TracebackType): for f, _lineno in traceback.walk_tb(tb): - if not include_frame(f): + if not include_frame(f) and not _is_reraiser_frame(f): f.f_locals["__tracebackhide__"] = True def filter_traceback(tb: types.TracebackType) -> types.TracebackType | None: @@ -103,9 +110,12 @@ def _add_call_stack_frames(tb: types.TracebackType) -> types.TracebackType: reached_module_level = True return out -def _is_reraiser_frame(f: traceback.FrameSummary) -> bool: - return (f.filename == __file__ and - f.name == 'reraise_with_filtered_traceback') +def _is_reraiser_frame(f: traceback.FrameSummary | types.FrameType) -> bool: + if isinstance(f, traceback.FrameSummary): + filename, name = f.filename, f.name + else: + filename, name = f.f_code.co_filename, f.f_code.co_name + return filename == __file__ and name == 'reraise_with_filtered_traceback' def _is_under_reraiser(e: BaseException) -> bool: if e.__traceback__ is None: @@ -150,7 +160,10 @@ def _filtering_mode() -> str: mode = "quiet_remove_frames" return mode -def api_boundary(fun: C) -> C: +def api_boundary( + fun: C, *, + repro_api_name: str | None = None, + repro_user_func: bool = False) -> C: '''Wraps ``fun`` to form a boundary for filtering exception tracebacks. When an exception occurs below ``fun``, this appends to it a custom @@ -171,6 +184,8 @@ def api_boundary(fun: C) -> C: ``g``. Because the function returned by :func:`~jax.jit` is annotated as an :func:`~api_boundary`, such an exception is accompanied by an additional traceback that excludes the frames specific to JAX's implementation. + + For the "repro" kwargs, see the comments for `repro.boundary`. ''' @functools.wraps(fun) @@ -186,30 +201,13 @@ def reraise_with_filtered_traceback(*args, **kwargs): _add_tracebackhide_to_hidden_frames(e.__traceback__) raise - filtered_tb, unfiltered = None, None + tb = e.__traceback__ try: - tb = e.__traceback__ - filtered_tb = filter_traceback(tb) - e.with_traceback(filtered_tb) - # In Python < 3.11, there seems to be no way to alter the currently - # raised exception traceback, except via the C API. The interpreter - # keeps a copy of the traceback (exc_traceback) that is separate to the - # __traceback__ of exc_value. Python 3.11 removes exc_traceback and - # just setting __traceback__ is enough. Since it is no longer needed, - # the XLA extension no longer defines a traceback-replacing method at - # Python 3.11 and onward. - if hasattr(xla_extension, "replace_thread_exc_traceback"): - # TODO(kidger): remove this line once Python 3.11 is the minimum supported - # version. - xla_extension.replace_thread_exc_traceback(filtered_tb) - if sys.version_info >= (3, 11) and mode == "quiet_remove_frames": + e.with_traceback(filter_traceback(tb)) + if mode == "quiet_remove_frames": e.add_note("--------------------\n" + _simplified_tb_msg) else: - if mode == "quiet_remove_frames": - # TODO(kidger): remove `SimplifiedTraceback` once Python 3.11 is the minimum - # supported version. - jax_error = SimplifiedTraceback() - elif mode == "remove_frames": + if mode == "remove_frames": msg = format_exception_only(e) msg = f'{msg}\n\n{_jax_message_append}' jax_error = UnfilteredStackTrace(msg) @@ -221,9 +219,21 @@ def reraise_with_filtered_traceback(*args, **kwargs): jax_error.__suppress_context__ = e.__suppress_context__ e.__cause__ = jax_error e.__context__ = None + del jax_error raise finally: - del filtered_tb - del unfiltered - del mode + del mode, tb + if (repro_api_name or repro_user_func) and repro: + reraise_with_filtered_traceback = repro.boundary( + reraise_with_filtered_traceback, api_name=repro_api_name, + is_user=repro_user_func) return cast(C, reraise_with_filtered_traceback) + +try: + # TODO: import from the final location + from jax._src import repro # type: ignore + repro_is_enabled = repro.is_enabled + +except ImportError: + repro = None # type: ignore + def repro_is_enabled(): return False # type: ignore diff --git a/jax/_src/tree.py b/jax/_src/tree.py index 70d75a126804..83d5684ad8f4 100644 --- a/jax/_src/tree.py +++ b/jax/_src/tree.py @@ -14,7 +14,7 @@ from __future__ import annotations from collections.abc import Callable, Iterable -from typing import Any, TypeVar, overload +from typing import Any, TypeVar from jax._src import tree_util @@ -155,21 +155,9 @@ def map(f: Callable[..., Any], return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf) -@overload def reduce(function: Callable[[T, Any], T], tree: Any, - *, - is_leaf: Callable[[Any], bool] | None = None) -> T: - ... -@overload -def reduce(function: Callable[[T, Any], T], - tree: Any, - initializer: T, - is_leaf: Callable[[Any], bool] | None = None) -> T: - ... -def reduce(function: Callable[[T, Any], T], - tree: Any, - initializer: Any = tree_util.no_initializer, + initializer: T | tree_util.Unspecified = tree_util.Unspecified(), is_leaf: Callable[[Any], bool] | None = None) -> T: """Call reduce() over the leaves of a tree. @@ -191,13 +179,66 @@ def reduce(function: Callable[[T, Any], T], >>> jax.tree.reduce(operator.add, [1, (2, 3), [4, 5, 6]]) 21 + Notes: + **Tip**: You can exclude leaves from the reduction by first mapping them to + ``None`` using :func:`jax.tree.map`. This causes them to not be counted as + leaves after that. + See Also: + - :func:`jax.tree.reduce_associative` - :func:`jax.tree.leaves` - :func:`jax.tree.map` """ return tree_util.tree_reduce(function, tree, initializer, is_leaf=is_leaf) +def reduce_associative( + operation: Callable[[T, T], T], + tree: Any, + *, + identity: T | tree_util.Unspecified = tree_util.Unspecified(), + is_leaf: Callable[[Any], bool] | None = None, +) -> T: + """Perform a reduction over a pytree with an associative binary operation. + + This function exploits the fact that the operation is associative to perform + the reduction in parallel (logarithmic depth). + + Args: + operation: the associative binary operation + tree: the pytree to reduce + identity: the identity element of the associative binary operation. + This is used only when the tree is empty. It is optional otherwise. + is_leaf: an optionally specified function that will be called at each + flattening step. It should return a boolean, which indicates whether the + flattening should traverse the current object, or if it should be stopped + immediately, with the whole subtree being treated as a leaf. + + Returns: + result: the reduced value + + Examples: + >>> import jax + >>> import operator + >>> jax.tree.reduce_associative(operator.add, [1, (2, 3), [4, 5, 6]]) + 21 + + Notes: + **Tip**: You can exclude leaves from the reduction by first mapping them to + ``None`` using :func:`jax.tree.map`. This causes them to not be counted as + leaves after that. + + See Also: + - :func:`jax.tree.reduce` + """ + return tree_util.tree_reduce_associative( + operation, + tree, + identity=identity, + is_leaf=is_leaf, + ) + + def structure(tree: Any, is_leaf: None | (Callable[[Any], bool]) = None) -> tree_util.PyTreeDef: """Gets the treedef for a pytree. @@ -287,7 +328,8 @@ def unflatten(treedef: tree_util.PyTreeDef, def flatten_with_path( - tree: Any, is_leaf: Callable[[Any], bool] | None = None + tree: Any, is_leaf: Callable[..., bool] | None = None, + is_leaf_takes_path: bool = False, ) -> tuple[list[tuple[tree_util.KeyPath, Any]], tree_util.PyTreeDef]: """Flattens a pytree like ``tree_flatten``, but also returns each leaf's key path. @@ -313,11 +355,12 @@ def flatten_with_path( - :func:`jax.tree.map_with_path` - :func:`jax.tree_util.register_pytree_with_keys` """ - return tree_util.tree_flatten_with_path(tree, is_leaf) + return tree_util.tree_flatten_with_path(tree, is_leaf, is_leaf_takes_path) def leaves_with_path( - tree: Any, is_leaf: Callable[[Any], bool] | None = None + tree: Any, is_leaf: Callable[..., bool] | None = None, + is_leaf_takes_path: bool = False, ) -> list[tuple[tree_util.KeyPath, Any]]: """Gets the leaves of a pytree like ``tree_leaves`` and returns each leaf's key path. @@ -338,14 +381,15 @@ def leaves_with_path( - :func:`jax.tree.flatten_with_path` - :func:`jax.tree_util.register_pytree_with_keys` """ - return tree_util.tree_leaves_with_path(tree, is_leaf) + return tree_util.tree_leaves_with_path(tree, is_leaf, is_leaf_takes_path) def map_with_path( f: Callable[..., Any], tree: Any, *rest: Any, - is_leaf: Callable[[Any], bool] | None = None, + is_leaf: Callable[..., bool] | None = None, + is_leaf_takes_path: bool = False, ) -> Any: """Maps a multi-input function over pytree key path and args to produce a new pytree. @@ -377,4 +421,37 @@ def map_with_path( - :func:`jax.tree.leaves_with_path` - :func:`jax.tree_util.register_pytree_with_keys` """ - return tree_util.tree_map_with_path(f, tree, *rest, is_leaf=is_leaf) + return tree_util.tree_map_with_path( + f, tree, *rest, is_leaf=is_leaf, is_leaf_takes_path=is_leaf_takes_path + ) + + +def broadcast(prefix_tree: Any, full_tree: Any, + is_leaf: Callable[[Any], bool] | None = None + ) -> Any: + """Broadcasts a tree prefix into the full structure of a given tree. + + Args: + prefix_tree: a pytree that is a tree prefix of full_tree. + full_tree: a pytree with the structure to broadcast the prefix leaves into. + is_leaf: an optionally specified function that will be called at each + flattening step. It should return a boolean, with true stopping the + traversal and the whole subtree being treated as a leaf, and false + indicating the flattening should traverse the current object. + + Returns: + A pytree matching the structure of full_tree where the leaves of prefix_tree have been + broadcasted into the leaves of each corresponding subtree. + + Examples: + >>> import jax + >>> prefix = (1, 2, 3) + >>> full = (0, {'a': 0, 'b': 0}, (0, 0)) + >>> jax.tree.broadcast(prefix, full) + (1, {'a': 2, 'b': 2}, (3, 3)) + + See Also: + - :func:`jax.tree.leaves` + - :func:`jax.tree.structure` + """ + return tree_util.tree_broadcast(prefix_tree, full_tree, is_leaf=is_leaf) diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index 6c7e15a042e5..bd473dc8d5f8 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -18,10 +18,10 @@ import dataclasses import difflib import functools -from functools import partial +from functools import partial, cached_property import operator as op import textwrap -from typing import Any, NamedTuple, TypeVar, overload +from typing import Any, TypeVar from jax._src import traceback_util from jax._src.lib import pytree @@ -38,6 +38,7 @@ H = TypeVar("H", bound=Hashable) Leaf = Any +PyTree = Any PyTreeDef = pytree.PyTreeDef default_registry = pytree.default_registry() @@ -92,6 +93,13 @@ def tree_leaves(tree: Any, return default_registry.flatten(tree, is_leaf)[0] +@export +def tree_leaves_checked(treedef_expected: PyTreeDef, tree: Any) -> list[Leaf]: + flat_vals, treedef_actual = tree_flatten(tree) + assert treedef_actual == treedef_expected + return flat_vals + + @export def tree_structure(tree: Any, is_leaf: None | (Callable[[Any], @@ -123,7 +131,7 @@ def treedef_tuple(treedefs: Iterable[PyTreeDef]) -> PyTreeDef: See Also: - :func:`jax.tree_util.treedef_children` """ - return pytree.tuple(default_registry, list(treedefs)) + return pytree.treedef_tuple(default_registry, list(treedefs)) @export @@ -202,8 +210,11 @@ def all_leaves(iterable: Iterable[Any], if is_leaf is None: return pytree.all_leaves(default_registry, iterable) else: - lst = list(iterable) - return lst == tree_leaves(lst, is_leaf) + items = list(iterable) + leaves = tree_leaves(items, is_leaf) + return len(leaves) == len(items) and all( + item is leaf for item, leaf in zip(items, leaves, strict=True) + ) _Children = TypeVar("_Children", bound=Iterable[Any]) @@ -362,6 +373,8 @@ def tree_map(f: Callable[..., Any], def build_tree(treedef: PyTreeDef, xs: Any) -> Any: """Build a treedef from a nested iterable structure + DEPRECATED: Use :func:`jax.tree.unflatten` instead. + Args: treedef: the PyTreeDef structure to build. xs: nested iterables matching the arity as the treedef @@ -376,13 +389,6 @@ def build_tree(treedef: PyTreeDef, xs: Any) -> Any: >>> import jax >>> tree = [(1, 2), {'a': 3, 'b': 4}] >>> treedef = jax.tree.structure(tree) - - Both ``build_tree`` and :func:`jax.tree_util.tree_unflatten` can reconstruct - the tree from new values, but ``build_tree`` takes these values in terms of - a nested rather than flat structure: - - >>> jax.tree_util.build_tree(treedef, [[10, 11], [12, 13]]) - [(10, 11), {'a': 12, 'b': 13}] >>> jax.tree_util.tree_unflatten(treedef, [10, 11, 12, 13]) [(10, 11), {'a': 12, 'b': 13}] """ @@ -422,44 +428,55 @@ def tree_transpose(outer_treedef: PyTreeDef, inner_treedef: PyTreeDef | None, type(None): _RegistryEntry(lambda z: ((), None), lambda _, xs: None), } -def _replace_nones(sentinel, tree): - """Replaces ``None`` in ``tree`` with ``sentinel``.""" - leaves, treedef = none_leaf_registry.flatten(tree) - leaves = map(lambda x: sentinel if x is None else x, leaves) - return treedef.unflatten(leaves) - - -no_initializer = object() - -@overload -def tree_reduce(function: Callable[[T, Any], T], - tree: Any, - *, - is_leaf: Callable[[Any], bool] | None = None) -> T: - ... - - -@overload -def tree_reduce(function: Callable[[T, Any], T], - tree: Any, - initializer: T, - is_leaf: Callable[[Any], bool] | None = None) -> T: - ... +class Unspecified: + pass @export def tree_reduce(function: Callable[[T, Any], T], tree: Any, - initializer: Any = no_initializer, + initializer: T | Unspecified = Unspecified(), is_leaf: Callable[[Any], bool] | None = None) -> T: """Alias of :func:`jax.tree.reduce`.""" - if initializer is no_initializer: + if isinstance(initializer, Unspecified): return functools.reduce(function, tree_leaves(tree, is_leaf=is_leaf)) else: return functools.reduce(function, tree_leaves(tree, is_leaf=is_leaf), initializer) +def _parallel_reduce( + sequence: list[T], + operation: Callable[[T, T], T], + identity: T | Unspecified = Unspecified(), +) -> T: + length = len(sequence) + if length == 0: + if isinstance(identity, Unspecified): + raise TypeError("Must specify identity for parallel reduction of empty sequence.") + return identity + elif length == 1: + return sequence[0] + else: + index = length // 2 + a = _parallel_reduce(sequence[:index], operation, identity) + b = _parallel_reduce(sequence[index:], operation, identity) + return operation(a, b) + + +@export +def tree_reduce_associative( + operation: Callable[[T, T], T], + tree: Any, + *, + identity: T | Unspecified = Unspecified(), + is_leaf: Callable[[Any], bool] | None = None, +) -> T: + """Alias of :func:`jax.tree.reduce_associative`.""" + sequence = tree_leaves(tree, is_leaf=is_leaf) + return _parallel_reduce(sequence, operation, identity) + + @export def tree_all(tree: Any, *, is_leaf: Callable[[Any], bool] | None = None) -> bool: """Alias of :func:`jax.tree.all`.""" @@ -534,7 +551,7 @@ class Partial(functools.partial): >>> print_zero() 0 >>> call_func(print_zero) # doctest:+ELLIPSIS - Tracedwith + JitTracer(~int32[]) """ def __new__(klass, func, *args, **kw): @@ -562,20 +579,107 @@ def __new__(klass, func, *args, **kw): ) -# broadcast_prefix is not exported. +@export +def tree_broadcast(prefix_tree: Any, full_tree: Any, + is_leaf: Callable[[Any], bool] | None = None + ) -> Any: + """Alias of :func:`jax.tree.broadcast`.""" + broadcast_leaves = broadcast_prefix(prefix_tree, full_tree, is_leaf=is_leaf) + return tree_structure(full_tree).unflatten(broadcast_leaves) + + +# broadcast_prefix is not exported def broadcast_prefix(prefix_tree: Any, full_tree: Any, is_leaf: Callable[[Any], bool] | None = None ) -> list[Any]: - # If prefix_tree is not a tree prefix of full_tree, this code can raise a - # ValueError; use prefix_errors to find disagreements and raise more precise - # error messages. + """Broadcasts tree prefix leaves into the full set of leaves for a given full tree. + + Args: + prefix_tree: a pytree that is a tree prefix of full_tree. + full_tree: a pytree with the structure to broadcast the prefix leaves into. + is_leaf: an optionally specified function that will be called at each + flattening step for prefix_tree. It should return a boolean, with true + stopping the traversal and the whole subtree being treated as a leaf, + and false indicating the flattening should traverse the current object. + + Returns: + A list of leaves matching the expected count for the full tree, + with the leaf of each prefix tree being duplicated to match the count of + its corresponding subtree. + """ result = [] num_leaves = lambda t: tree_structure(t).num_leaves add_leaves = lambda x, subtree: result.extend([x] * num_leaves(subtree)) - tree_map(add_leaves, prefix_tree, full_tree, is_leaf=is_leaf) + try: + tree_map(add_leaves, prefix_tree, full_tree, is_leaf=is_leaf) + except ValueError: + e, *_ = prefix_errors(prefix_tree, full_tree) + raise e('broadcast_prefix prefix_tree') from None return result +# broadcast_flattened_prefix_with_treedef is not exported +def broadcast_flattened_prefix_with_treedef( + prefix_leaves: list[Any], + prefix_treedef: PyTreeDef, + full_treedef: PyTreeDef, +) -> list[Any]: + """Broadcasts tree prefix leaves into the full set of leaves for a given full treedef. + + Args: + prefix_leaves: the leaves of a pytree that is a tree prefix + of full_treedef. + prefix_treedef: the PyTreeDef of a pytree that is a tree prefix of + full_treedef. + full_treedef: a PyTreeDef with the structure to broadcast the prefix + leaves into. + + Returns: + A list of leaves matching the expected count for the full tree, + with each leaf of prefix tree being duplicated to match the count of + its corresponding subtree. + """ + # NOTE: At the moment, `broadcast_flattened_prefix_with_treedef` is only + # called from `api_util.flatten_axes`, which replaces any raised exception + # with its own exception and error message. The errors raised from this + # function should probably be improved before this function is used in + # more places. + # + # TODO(jburnim): Merge `broadcast_prefix` with this function? + # prefix_leaves, prefix_treedef = tree_flatten(prefix_tree, is_leaf) + ret = [] + + # TODO(jburnim): Should this traversal be done in C++? + def _broadcast(broadcast_fn, leaf_start, leaf_end, prefix_treedef, treedef): + if treedef_is_strict_leaf(prefix_treedef): + # We have encountered a leaf in the prefix, so we repeat the prefix leaf + # for each leaf in the corresponding part of the tree. + assert (leaf_end - leaf_start) == 1 + ret.extend(prefix_leaves[leaf_start:leaf_end] * treedef.num_leaves) + return + + if treedef_is_strict_leaf(treedef): + raise ValueError('`prefix_treedef` is not a prefix of `full_treedef`') + + prefix_node_data = prefix_treedef.node_data() + node_data = treedef.node_data() + if prefix_node_data != node_data: + raise ValueError(f'expected {node_data}, got {prefix_node_data}') + + prefix_i = leaf_start + for prefix_child, tree_child in zip( + prefix_treedef.children(), treedef.children(), strict=True): + broadcast_fn(broadcast_fn, prefix_i, prefix_i + prefix_child.num_leaves, + prefix_child, tree_child, + ) + prefix_i += prefix_child.num_leaves + + # Pass _broadcast as arg to avoid it being a free variable within its own + # closure, which creates a reference cycle. + _broadcast(_broadcast, 0, len(prefix_leaves), prefix_treedef, full_treedef) + return ret + + # flatten_one_level is not exported. def flatten_one_level(tree: Any) -> tuple[Iterable[Any], Hashable]: """Flatten the given pytree node by one level. @@ -767,42 +871,6 @@ def _simple_entrystr(key: KeyEntry) -> str: return str(key) -# TODO(ivyzheng): remove this after another jaxlib release. -class _RegistryWithKeypathsEntry(NamedTuple): - flatten_with_keys: Callable[..., Any] - unflatten_func: Callable[..., Any] - - -def _register_keypaths( - ty: type[T], handler: Callable[[T], tuple[KeyEntry, ...]] -) -> None: - def flatten_with_keys(xs): - children, treedef = _registry[ty].to_iter(xs) - return list(zip(handler(xs), children)), treedef - if ty in _registry: - _registry_with_keypaths[ty] = _RegistryWithKeypathsEntry( - flatten_with_keys, _registry[ty].from_iter - ) - -_registry_with_keypaths: dict[type[Any], _RegistryWithKeypathsEntry] = {} - -_register_keypaths( - tuple, lambda xs: tuple(SequenceKey(i) for i in range(len(xs))) -) -_register_keypaths( - list, lambda xs: tuple(SequenceKey(i) for i in range(len(xs))) -) -_register_keypaths(dict, lambda xs: tuple(DictKey(k) for k in sorted(xs))) - -_register_keypaths( - collections.defaultdict, lambda x: tuple(DictKey(k) for k in x.keys()) -) - -_register_keypaths( - collections.OrderedDict, lambda x: tuple(DictKey(k) for k in x.keys()) -) - - @export def register_pytree_with_keys( nodetype: type[T], @@ -872,9 +940,6 @@ def flatten_func_impl(tree): register_pytree_node( nodetype, flatten_func, unflatten_func, flatten_with_keys ) - _registry_with_keypaths[nodetype] = _RegistryWithKeypathsEntry( - flatten_with_keys, unflatten_func - ) @export @@ -938,7 +1003,7 @@ def register_dataclass( registries use the optimized C++ dataclass builtin instead of the argument functions. - See :ref:`extending-pytrees` for more information about registering pytrees. + See :ref:`pytrees-custom-pytree-nodes` for more information about registering pytrees. Args: nodetype: a Python type to treat as an internal pytree node. This is assumed @@ -947,7 +1012,7 @@ def register_dataclass( as keywords to the class constructor to create a copy of the object. All defined attributes should be listed among ``meta_fields`` or ``data_fields``. meta_fields: metadata field names: these are attributes which will be treated as - {term}`static` when this pytree is passed to :func:`jax.jit`. ``meta_fields`` is + :term:`static` when this pytree is passed to :func:`jax.jit`. ``meta_fields`` is optional only if ``nodetype`` is a dataclass, in which case individual fields can be marked static via :func:`dataclasses.field` (see examples below). Metadata fields *must* be static, hashable, immutable objects, as these objects @@ -1036,10 +1101,16 @@ def register_dataclass( if not dataclasses.is_dataclass(nodetype): raise TypeError("register_dataclass: data_fields and meta_fields are required when" f" nodetype is not a dataclass. Got {nodetype=}.") - data_fields = [f.name for f in dataclasses.fields(nodetype) - if not f.metadata.get('static', False)] - meta_fields = [f.name for f in dataclasses.fields(nodetype) - if f.metadata.get('static', False)] + data_fields = [ + f.name + for f in dataclasses.fields(nodetype) + if not f.metadata.get("static", False) + ] + meta_fields = [ + f.name + for f in dataclasses.fields(nodetype) + if f.metadata.get("static", False) + ] assert meta_fields is not None assert data_fields is not None @@ -1067,10 +1138,11 @@ def register_dataclass( msg += f" Unexpected fields: {unexpected}." raise ValueError(msg) - def flatten_with_keys(x): - meta = tuple(getattr(x, name) for name in meta_fields) - data = tuple((GetAttrKey(name), getattr(x, name)) for name in data_fields) - return data, meta + if overlap := set(data_fields) & set(meta_fields): + raise ValueError( + "data_fields and meta_fields must not overlap. Overlapping fields:" + f" {overlap}." + ) def unflatten_func(meta, data): meta_args = tuple(zip(meta_fields, meta)) @@ -1087,9 +1159,6 @@ def flatten_func(x): none_leaf_registry.register_dataclass_node(nodetype, list(data_fields), list(meta_fields)) dispatch_registry.register_dataclass_node(nodetype, list(data_fields), list(meta_fields)) _registry[nodetype] = _RegistryEntry(flatten_func, unflatten_func) - _registry_with_keypaths[nodetype] = _RegistryWithKeypathsEntry( - flatten_with_keys, unflatten_func - ) return nodetype @@ -1150,34 +1219,38 @@ def register_static(cls: type[H]) -> type[H]: @export def tree_flatten_with_path( - tree: Any, is_leaf: Callable[[Any], bool] | None = None + tree: Any, is_leaf: Callable[..., bool] | None = None, + is_leaf_takes_path: bool = False, ) -> tuple[list[tuple[KeyPath, Any]], PyTreeDef]: """Alias of :func:`jax.tree.flatten_with_path`.""" - return default_registry.flatten_with_path(tree, is_leaf) + is_leaf_with_kp: Callable[[Any, Any], bool] | None = is_leaf + if not is_leaf_takes_path and is_leaf is not None: + is_leaf_with_kp = lambda _, x: is_leaf(x) + return default_registry.flatten_with_path(tree, is_leaf_with_kp) @export def tree_leaves_with_path( - tree: Any, is_leaf: Callable[[Any], bool] | None = None + tree: Any, is_leaf: Callable[..., bool] | None = None, + is_leaf_takes_path: bool = False, ) -> list[tuple[KeyPath, Any]]: """Alias of :func:`jax.tree.leaves_with_path`.""" - return tree_flatten_with_path(tree, is_leaf)[0] - - -# generate_key_paths is not exported. -def generate_key_paths( - tree: Any, is_leaf: Callable[[Any], bool] | None = None -) -> list[tuple[KeyPath, Any]]: - return tree_leaves_with_path(tree, is_leaf) -_generate_key_paths = generate_key_paths # alias for backward compat + return tree_flatten_with_path(tree, is_leaf, is_leaf_takes_path)[0] +generate_key_paths = tree_leaves_with_path @export -def tree_map_with_path(f: Callable[..., Any], - tree: Any, *rest: Any, - is_leaf: Callable[[Any], bool] | None = None) -> Any: +def tree_map_with_path( + f: Callable[..., Any], + tree: Any, + *rest: Any, + is_leaf: Callable[..., bool] | None = None, + is_leaf_takes_path: bool = False, +) -> Any: """Alias of :func:`jax.tree.map_with_path`.""" - keypath_leaves, treedef = tree_flatten_with_path(tree, is_leaf) + keypath_leaves, treedef = tree_flatten_with_path( + tree, is_leaf, is_leaf_takes_path + ) keypath_leaves = list(zip(*keypath_leaves)) all_keypath_leaves = keypath_leaves + [treedef.flatten_up_to(r) for r in rest] return treedef.unflatten(f(*xs) for xs in zip(*all_keypath_leaves)) @@ -1283,3 +1356,225 @@ def _prefix_error( f"{prefix_tree_keys} and {full_tree_keys}") for k, t1, t2 in zip(prefix_tree_keys, prefix_tree_children, full_tree_children): yield from _prefix_error((*key_path, k), t1, t2) + +# === flat tree === + +class FlatTree: + """A FlatTree stores a treedef and a flat list of values. It's meant to be + isomorphic to the corresponding pytree but we can map over it more easily. + Compared to `tree_map`, FlatTree.map has these benefits: + 1. It doesn't touch user flatten/unflatten code (which shouldn't have side + effects but sometimes does in practice). + 2. It can be faster, because it skips the recursive traversal. + 3. It actually obeys the functor rules. For example, + `flat_tree.map(lambda x: (f(x), g(x))).unzip2()[0]` will give + the same result as `flat_tree.map(f)`, whereas in the `tree_map` version + the tuple-returning function would change the tree structure and `unzip` + wouldn't be able to recover it. + """ + # `FlatTree` constructor is private. Use `FlatTree.flatten` instead + def __init__(self, vals, treedef: PyTreeDef, statics): + assert isinstance(treedef, pytree.PyTreeDef) + if not isinstance(vals, tuple): + vals = tuple(vals) + self.vals = tuple(vals) + self.tree = treedef + self.statics = statics # tree-prefix tuple-dict-tree of bools + + def __eq__(self, other): + return (isinstance(other, FlatTree) and self.vals == other.vals + and self.tree == other.tree and self.statics == other.statics) + + def __hash__(self): + return hash((self.vals, self.tree)) + + def map(self, f: Callable) -> FlatTree: + return self.update(f(x) for x in self.vals) + + def map2(self: FlatTree, f: Callable, t2: FlatTree) -> FlatTree: + n = len(self) + assert len(t2) == n + return self.update(f(x1, x2) for x1, x2 in zip(self.vals, t2.vals)) + + def map3( + self: FlatTree, f: Callable, t2: FlatTree, t3: FlatTree) -> FlatTree: + n = len(self) + assert len(t2) == n and len(t3) == n + return self.update(f(x1, x2, x3) + for x1, x2, x3 in zip(self.vals, t2.vals, t3.vals)) + + def unzip2(self: FlatTree) -> tuple[FlatTree, FlatTree]: + ys = [] + zs = [] + for y, z in self.vals: + ys.append(y) + zs.append(z) + return self.update(ys), self.update(zs) + + # TODO: add other helpers like map3, zip, unzip3 etc. as needed + + @staticmethod + def pack(tree): + # We could generalize this to arbitrary pytrees of FlatTree but tuples/dicts + # are sufficient for now. + if isinstance(tree, FlatTree): + return tree + elif isinstance(tree, tuple): + vals = [] + trees = [] + staticss = [] + for child_tree in tree: + child = FlatTree.pack(child_tree) + vals.extend(child.vals) + trees.append(child.tree) + staticss.append(child.statics) + return FlatTree(vals, treedef_tuple(trees), tuple(staticss)) + elif isinstance(tree, dict): + # only empty case handled for now + if tree == {}: + return FlatTree.flatten({}) + else: + assert False + else: + assert False + + def unpack(self: FlatTree) -> tuple[FlatTree, ...]: + # TODO: this is O(N) not O(1) (with N as the number of leaves). If it + # becomes a problem we can fix it with a fancier data tree. + trees = treedef_children(self.tree) + children = [] + offset = 0 + for i, tree in enumerate(trees): + statics = False if isinstance(self.statics, bool) else self.statics[i] + new_offset = offset + tree.num_leaves + children.append(FlatTree(self.vals[offset:new_offset], tree, statics)) + offset = new_offset + return tuple(children) + + @staticmethod + def flatten(tree: PyTree) -> FlatTree: + vals, tree = tree_flatten(tree) + return FlatTree(vals, tree, False) + + @staticmethod + def flatten_static_argnums(args, static_argnums): + if not static_argnums: + return FlatTree.flatten(args) + else: + assert isinstance(args, tuple) + num_args = len(args) + static_argnums = [i % num_args if i < 0 else i for i in static_argnums] + statics = tuple(i in static_argnums for i, _ in enumerate(args)) + tree_with_statics = tuple( + Static(x) if static else x for static, x in zip(statics, args)) + vals, treedef = tree_flatten(tree_with_statics) + return FlatTree(vals, treedef, statics=statics) + + @staticmethod + def flatten_static_argnames(kwargs, static_argnames): + if not static_argnames: + return FlatTree.flatten(kwargs) + else: + assert isinstance(kwargs, dict) + statics = {k : k in static_argnames for k, _ in kwargs.items()} + tree_with_statics = {k : Static(v) if statics[k] else v + for k, v in kwargs.items()} + vals, treedef = tree_flatten(tree_with_statics) + return FlatTree(vals, treedef, statics=statics) + + @staticmethod + def flatten_static_argnums_argnames( + args, kwargs, static_argnums, static_argnames): + return FlatTree.pack(( + FlatTree.flatten_static_argnums(args, static_argnums), + FlatTree.flatten_static_argnames(kwargs, static_argnames))) + + def unflatten(self) -> PyTree: + pytree = tree_unflatten(self.tree, self.vals) + return unwrap_statics(pytree, self.statics) + + @property + def tree_without_statics(self): + # hardcodes default_registry because it's used implicitly in self.flatten + return filter_statics_from_treedef(default_registry, self.tree, self.statics) + + def update(self, new_vals) -> FlatTree: + # `new_vals` can be a generator because `FlatTree` forces it to a tuple + new = FlatTree(new_vals, self.tree, self.statics) + assert len(self.vals) == len(new.vals) + return new + + @cached_property + def paths(self) -> FlatTree: + # TODO(dougalm): find a way to do this without roundtripping + try: + paths, _ = unzip2(tree_leaves_with_path(self.unflatten())) + assert len(paths) == len(self.vals) + return self.update(paths) + except: + return self.update([()] * len(self.vals)) # not our fault + + def __len__(self): + return self.len + + @cached_property + def len(self): + return self.tree.num_leaves + + def __iter__(self): + return self.vals.__iter__() + +def unwrap_statics(pytree, statics): + if statics is False: + return pytree + elif statics is True: + return pytree.val # pytree should be a `Static` object + elif isinstance(pytree, tuple): + return tuple(unwrap_statics(p, s) for p, s in zip(pytree, statics)) + elif isinstance(pytree, dict): + return {k : unwrap_statics(p, statics[k]) for k, p in pytree.items()} + else: + assert False, "unreachable" + +def filter_statics_from_treedef(registry, treedef, statics): + if statics is False: + return treedef + elif statics is True: + assert False, "unreachable" + elif isinstance(statics, tuple): + filtered = tuple( + filter_statics_from_treedef(registry, td, s) + for td, s in zip(treedef.children(), statics) if s is not True) + return treedef.from_node_data_and_children(registry, treedef.node_data(), filtered) # type: ignore + elif isinstance(statics, dict): + ty, keys = treedef.node_data() # type: ignore + filtered_keys, filtered_subtrees = unzip2( + (k, filter_statics_from_treedef(registry, td, statics[k])) + for td, k in zip(treedef.children(), keys) if statics[k] is not True) + return treedef.from_node_data_and_children(registry, (ty, filtered_keys), filtered_subtrees) # type: ignore + else: + assert False, "unreachable" + +@register_static +@dataclasses.dataclass(frozen=True) +class Static: + val: Any + + def __eq__(self, other): + return (type(other) is Static and type(self.val) is type(other.val) and + self.val == other.val) + + +def _ensure_inbounds(allow_invalid: bool, num_args: int, argnums: Sequence[int] + ) -> tuple[int, ...]: + """Ensure argnum is within bounds. Also resolves negative argnums.""" + result = [] + for i in argnums: + if i >= num_args and allow_invalid: continue + if not -num_args <= i < num_args: + raise ValueError( + "Positional argument indices, e.g. for `static_argnums`, must have " + "value greater than or equal to -len(args) and less than len(args), " + f"but got value {i} for len(args) == {num_args}.") + result.append(i % num_args) # Resolve negative + return tuple(result) diff --git a/jax/_src/typing.py b/jax/_src/typing.py index 010841b45dd2..bb89dac9ab34 100644 --- a/jax/_src/typing.py +++ b/jax/_src/typing.py @@ -29,6 +29,7 @@ from collections.abc import Sequence import enum import typing +from types import EllipsisType from typing import Any, Protocol, Union from jax._src.basearray import ( @@ -47,7 +48,19 @@ @typing.runtime_checkable class SupportsDType(Protocol): @property - def dtype(self) -> DType: ... + def dtype(self, /) -> DType: ... + +class SupportsShape(Protocol): + @property + def shape(self, /) -> tuple[int, ...]: ... + +class SupportsSize(Protocol): + @property + def size(self, /) -> int: ... + +class SupportsNdim(Protocol): + @property + def ndim(self, /) -> int: ... # DTypeLike is meant to annotate inputs to np.dtype that return # a valid JAX dtype. It's different than numpy.typing.DTypeLike @@ -94,3 +107,7 @@ class DLDeviceType(enum.IntEnum): kDLCPU = 1 kDLCUDA = 2 kDLROCM = 10 + +AnyInt = int | np.integer +StaticIndex = AnyInt | slice | EllipsisType +Index = StaticIndex | None | Sequence[AnyInt] | Array | np.ndarray diff --git a/jax/_src/util.py b/jax/_src/util.py index 0e28aea04b5a..0610c6bbac01 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -16,24 +16,30 @@ import abc from collections.abc import Callable, Iterable, Iterator, Sequence +import dataclasses import functools from functools import partial import itertools as it import logging +import math import operator -from typing import (Any, Generic, TypeVar, overload, TYPE_CHECKING, cast) +from typing import (Any, Generic, SupportsIndex, Type, TypeVar, overload, TYPE_CHECKING, cast) import weakref import numpy as np from jax._src import config -from jax._src.lib import xla_client as xc +from jax._src.lib import weakref_lru_cache as _weakref_lru_cache from jax._src.lib import utils as jaxlib_utils +from jax._src.lib import jaxlib_extension_version logger = logging.getLogger(__name__) Seq = Sequence +# TODO(jakevdp): fix import cycles and import Array. +Array = Any + T = TypeVar("T") T1 = TypeVar("T1") T2 = TypeVar("T2") @@ -45,21 +51,28 @@ # to that used for builtins.zip in python/typeshed. This supports # return types matching input types for up to three arguments. @overload - def safe_zip(__arg1: Iterable[T1]) -> list[tuple[T1]]: ... + def safe_zip(__arg1: Iterable[T1], /) -> list[tuple[T1]]: ... @overload - def safe_zip(__arg1: Iterable[T1], __arg2: Iterable[T2]) -> list[tuple[T1, T2]]: ... + def safe_zip(__arg1: Iterable[T1], __arg2: Iterable[T2], /) -> list[tuple[T1, T2]]: ... @overload - def safe_zip(__arg1: Iterable[T1], __arg2: Iterable[T2], __arg3: Iterable[T3]) -> list[tuple[T1, T2, T3]]: ... + def safe_zip(__arg1: Iterable[T1], __arg2: Iterable[T2], __arg3: Iterable[T3], /) -> list[tuple[T1, T2, T3]]: ... @overload - def safe_zip(__arg1: Iterable[Any], __arg2: Iterable[Any], __arg3: Iterable[Any], __arg4: Iterable[Any], *args) -> list[tuple[Any, ...]]: ... + def safe_zip(__arg1: Iterable[Any], __arg2: Iterable[Any], __arg3: Iterable[Any], __arg4: Iterable[Any], /, *args) -> list[tuple[Any, ...]]: ... def safe_zip(*args): - args = list(map(list, args)) - n = len(args[0]) - for arg in args[1:]: - assert len(arg) == n, f'length mismatch: {list(map(len, args))}' - return list(zip(*args)) - + """ + Like builtin :func:`zip`, but with additional safety checks. + + The differences from :func:`zip` are: + + - :func:`safe_zip` checks that at least one argument is provided. + - :func:`safe_zip` checks that all arguments have the same length. + - :func:`safe_zip` returns an eagerly-evaluated list instead of a + lazily-evaluated iterator. + """ + if not args: + raise TypeError("safe_zip requires at least 1 argument.") + return list(zip(*args, strict=True)) else: safe_zip = jaxlib_utils.safe_zip @@ -69,16 +82,16 @@ def safe_zip(*args): # to that used for builtins.map in python/typeshed. This supports # checking input types for the callable with up to three arguments. @overload - def safe_map(f: Callable[[T1], T], __arg1: Iterable[T1]) -> list[T]: ... + def safe_map(f: Callable[[T1], T], __arg1: Iterable[T1], /) -> list[T]: ... @overload - def safe_map(f: Callable[[T1, T2], T], __arg1: Iterable[T1], __arg2: Iterable[T2]) -> list[T]: ... + def safe_map(f: Callable[[T1, T2], T], __arg1: Iterable[T1], __arg2: Iterable[T2], /) -> list[T]: ... @overload - def safe_map(f: Callable[[T1, T2, T3], T], __arg1: Iterable[T1], __arg2: Iterable[T2], __arg3: Iterable[T3]) -> list[T]: ... + def safe_map(f: Callable[[T1, T2, T3], T], __arg1: Iterable[T1], __arg2: Iterable[T2], __arg3: Iterable[T3], /) -> list[T]: ... @overload - def safe_map(f: Callable[..., T], __arg1: Iterable[Any], __arg2: Iterable[Any], __arg3: Iterable[Any], __arg4: Iterable[Any], *args) -> list[T]: ... + def safe_map(f: Callable[..., T], __arg1: Iterable[Any], __arg2: Iterable[Any], __arg3: Iterable[Any], __arg4: Iterable[Any], /, *args) -> list[T]: ... def safe_map(f, *args): args = list(map(list, args)) @@ -92,27 +105,23 @@ def safe_map(f, *args): if TYPE_CHECKING: @overload - def foreach(f: Callable[[T1], Any], __arg1: Iterable[T1]) -> None: ... + def foreach(f: Callable[[T1], Any], __arg1: Iterable[T1], /) -> None: ... @overload - def foreach(f: Callable[[T1, T2], Any], __arg1: Iterable[T1], __arg2: Iterable[T2]) -> None: ... + def foreach(f: Callable[[T1, T2], Any], __arg1: Iterable[T1], __arg2: Iterable[T2], /) -> None: ... @overload - def foreach(f: Callable[[T1, T2, T3], Any], __arg1: Iterable[T1], __arg2: Iterable[T2], __arg3: Iterable[T3]) -> None: ... + def foreach(f: Callable[[T1, T2, T3], Any], __arg1: Iterable[T1], __arg2: Iterable[T2], __arg3: Iterable[T3], /) -> None: ... @overload - def foreach(f: Callable[..., Any], __arg1: Iterable[Any], __arg2: Iterable[Any], __arg3: Iterable[Any], __arg4: Iterable[Any], *args) -> None: ... + def foreach(f: Callable[..., Any], __arg1: Iterable[Any], __arg2: Iterable[Any], __arg3: Iterable[Any], __arg4: Iterable[Any], /, *args) -> None: ... def foreach(f, *args): safe_map(f, *args) return None else: - # TODO(phawkins): remove after jaxlib 0.5.2 is the minimum. - if hasattr(jaxlib_utils, 'foreach'): - foreach = jaxlib_utils.foreach - else: - foreach = safe_map + foreach = jaxlib_utils.foreach def unzip2(xys: Iterable[tuple[T1, T2]] @@ -141,13 +150,15 @@ def unzip3(xyzs: Iterable[tuple[T1, T2, T3]] zs.append(z) return tuple(xs), tuple(ys), tuple(zs) -def subvals(lst, replace): +def subvals(lst: Sequence[T], replace: Iterable[tuple[int, T]]) -> tuple[T, ...]: + """Substitute values within a list.""" lst = list(lst) for i, v in replace: lst[i] = v return tuple(lst) def split_list(args: Sequence[T], ns: Sequence[int]) -> list[list[T]]: + """Split list into sublists of the specified sizes.""" args = list(args) lists = [] for n in ns: @@ -157,8 +168,9 @@ def split_list(args: Sequence[T], ns: Sequence[int]) -> list[list[T]]: return lists def split_list_checked(args: Sequence[T], ns: Sequence[int]) -> list[list[T]]: + """Split list into sublists of the specified sizes.""" args = list(args) - assert sum(ns) == len(args) + assert sum(ns) == len(args) and all(n >= 0 for n in ns) lists = [] for n in ns: lists.append(args[:n]) @@ -166,8 +178,9 @@ def split_list_checked(args: Sequence[T], ns: Sequence[int]) -> list[list[T]]: return lists def partition_list(bs: Sequence[bool], l: Sequence[T]) -> tuple[list[T], list[T]]: + """Partition a list into two based on a mask.""" assert len(bs) == len(l) - lists = [], [] # type: ignore + lists: tuple[list[T], list[T]] = ([], []) for b, x in zip(bs, l): lists[b].append(x) return lists @@ -176,6 +189,7 @@ def merge_lists(bs: Sequence[bool], l0: Sequence[T1], l1: Sequence[T2] ) -> list[T1 | T2]: + """Merge the elements of two lists based on a mask.""" assert sum(bs) == len(l1) and len(bs) - sum(bs) == len(l0) i0, i1 = iter(l0), iter(l1) out: list[T1 | T2] = [next(i1) if b else next(i0) for b in bs] @@ -204,7 +218,7 @@ def subs_list2( assert next(base_, sentinel) is sentinel return out -def split_dict(dct, names): +def split_dict(dct: dict[T1, T2], names: Sequence[T1]) -> list[T2]: dct = dict(dct) lst = [dct.pop(name) for name in names] assert not dct @@ -244,64 +258,14 @@ def curry(f): """ return wraps(f)(partial(partial, f)) -# TODO(phawkins): make this unconditional after jaxlib 0.5.3 is the minimum. toposort: Callable[[Iterable[Any]], list[Any]] -if hasattr(jaxlib_utils, "topological_sort"): - toposort = partial(jaxlib_utils.topological_sort, "parents") -else: +toposort = partial(jaxlib_utils.topological_sort, "parents") - def toposort(end_nodes): - if not end_nodes: - return [] - end_nodes = _remove_duplicates(end_nodes) - - child_counts = {} - stack = list(end_nodes) - while stack: - node = stack.pop() - if id(node) in child_counts: - child_counts[id(node)] += 1 - else: - child_counts[id(node)] = 1 - stack.extend(node.parents) - for node in end_nodes: - child_counts[id(node)] -= 1 - - sorted_nodes = [] - childless_nodes = [ - node for node in end_nodes if child_counts[id(node)] == 0 - ] - assert childless_nodes - while childless_nodes: - node = childless_nodes.pop() - sorted_nodes.append(node) - for parent in node.parents: - if child_counts[id(parent)] == 1: - childless_nodes.append(parent) - else: - child_counts[id(parent)] -= 1 - sorted_nodes = sorted_nodes[::-1] - - check_toposort(sorted_nodes) - return sorted_nodes - - def check_toposort(nodes): - visited = set() - for node in nodes: - assert all(id(parent) in visited for parent in node.parents) - visited.add(id(node)) - - def _remove_duplicates(node_list): - seen = set() - out = [] - for n in node_list: - if id(n) not in seen: - seen.add(id(n)) - out.append(n) - return out - -def split_merge(predicate, xs): +def split_merge( + predicate: Callable[[T], bool], + xs: Sequence[T] +) -> tuple[list[T], list[T], Callable[[Sequence[T], Sequence[T]], list[T]]]: sides = list(map(predicate, xs)) lhs = [x for x, s in zip(xs, sides) if s] rhs = [x for x, s in zip(xs, sides) if not s] @@ -321,57 +285,204 @@ def merge(new_lhs, new_rhs): return lhs, rhs, merge -def _ignore(): return None +def cache(max_size=4096, trace_context_in_key: bool | Callable = True): + if trace_context_in_key: + trace_context = (trace_context_in_key if callable(trace_context_in_key) + else config.trace_context) + def wrap(f): + @functools.lru_cache(max_size) + def cached(_, *args, **kwargs): + return f(*args, **kwargs) + @functools.wraps(f) + def wrapper(*args, **kwargs): + if config.check_tracer_leaks.value: + return f(*args, **kwargs) + return cached(trace_context(), *args, **kwargs) -def cache(max_size=4096, trace_context_in_key=True): - def wrap(f): - @functools.lru_cache(max_size) - def cached(_, *args, **kwargs): - return f(*args, **kwargs) + wrapper.cache_clear = cached.cache_clear + wrapper.cache_info = cached.cache_info + register_cache(wrapper, str(f)) + return wrapper + else: + def wrap(f): + wrapper = functools.lru_cache(max_size)(f) + register_cache(wrapper, str(f)) + return wrapper + return wrap - @functools.wraps(f) - def wrapper(*args, **kwargs): - if config.check_tracer_leaks.value: - return f(*args, **kwargs) - return cached(config.trace_context() if trace_context_in_key else _ignore(), - *args, **kwargs) +# Maps caches to the name of the callable they apply to. All caches in +# this dictionary support `cache_clear()`. +_caches: weakref.WeakKeyDictionary[Any, str] = weakref.WeakKeyDictionary() - wrapper.cache_clear = cached.cache_clear - wrapper.cache_info = cached.cache_info - cache_clearing_funs.add(wrapper.cache_clear) - return wrapper - return wrap +def register_cache(cache: Any, for_what: str): + """Registers a cache with JAX's cache management. -cache_clearing_funs = weakref.WeakSet() # type: ignore + Args: + cache: an object supporting `cache_clear()`, `cache_info()`, and + `cache_keys()`, like the result of `functools.lru_cache()`. + for_what: a string to identify what this cache is used for. This is + used for debugging. + """ + _caches[cache] = for_what def clear_all_caches(): - global cache_clearing_funs - for clear in cache_clearing_funs: - clear() + for cache in list(_caches.keys()): + cache.cache_clear() memoize = cache(max_size=None) -def weakref_lru_cache(call: Callable, maxsize=2048, - trace_context_in_key: bool = True): +def _ignore(): return None + +def weakref_lru_cache( + call: Callable, maxsize: int | None = 2048, + trace_context_in_key: bool = True, explain: Callable | None = None): """ Least recently used cache decorator with weakref support. The cache will take a weakref to the first argument of the wrapped function - and strong refs to all subsequent operations. In all other respects it should - behave similar to `functools.lru_cache`. + and strong refs to all other arguments. In all other respects it should + behave similar to `functools.lru_cache`. The cache is thread local. """ - global _weakref_lru_caches - cached_call = xc.weakref_lru_cache( - config.trace_context if trace_context_in_key else _ignore, call, maxsize) - _weakref_lru_caches.add(cached_call) + if jaxlib_extension_version >= 396: + cached_call = _weakref_lru_cache.weakref_lru_cache( # type: ignore + config.trace_context if trace_context_in_key else _ignore, call, maxsize, # type: ignore + explain = lambda: explain if config.explain_cache_misses.value else None) # type: ignore + else: + cached_call = _weakref_lru_cache.weakref_lru_cache( + config.trace_context if trace_context_in_key else _ignore, call, maxsize) + register_cache(cached_call, str(call)) return cached_call -_weakref_lru_caches = weakref.WeakSet() # type: ignore -def clear_all_weakref_lru_caches(): - for cached_call in _weakref_lru_caches: - cached_call.cache_clear() +@dataclasses.dataclass(frozen=True, slots=True, weakref_slot=True) +class MultiWeakRefCacheKey: + weakrefs: tuple[weakref.ref, ...] # Used only when len(weakrefs) >= 2 + + +class MultiWeakRefPlaceholder: + # Stands for an arg/kwarg that was replaced with a weakref + pass +_multi_weakref_placeholder = MultiWeakRefPlaceholder() + +# The types of arguments for which `multi_weakref_lru_cache` should keep +# weak references. +weakref_cache_key_types: set[Type] = set() +def is_weakref_cache_key_type(v): + return callable(v) or (type(v) in weakref_cache_key_types) + + +def multi_weakref_lru_cache( + call: Callable, *, + maxsize=2048, + trace_context_in_key: bool = True): + """ + Least recently used cache decorator with weakref support. + + Similar to `weakref_lru_cache`, except that it keeps weak references + to all positional and keyword arguments for which + `is_weakref_cache_key_type()` is true, and strong references to + other arguments. The cache entry is removed if any of the weakref + arguments dies. + """ + # Keep strong references to the MultiWeakRefCacheKeys that resulted in + # cache misses, and are cache keys. Indexed by id. Only keys with all + # included weakrefs live are present. + id_to_key: dict[int, MultiWeakRefCacheKey] = {} + # For each `wr: weakref.ref` present in `key: MultiWeakRefCacheKey` we have + # `id(key) in weakref_to_key_ids[wr]`. + weakref_to_key_ids: dict[weakref.ref, set[int]] = {} + + def remove_weakref(wr: weakref.ref): + key_ids = weakref_to_key_ids.get(wr, set()) + for key_id in key_ids: + try: + del id_to_key[key_id] + except KeyError: + pass + try: + del weakref_to_key_ids[wr] + except KeyError: + pass + + def weakrefs_to_sentinel(v, acc: list[Any]): + if type(v) is tuple: + return tuple(weakrefs_to_sentinel(v1, acc) for v1 in v) + elif type(v) is dict: + return {k: weakrefs_to_sentinel(v1, acc) for k, v1 in v.items()} + elif is_weakref_cache_key_type(v): + acc.append(v) + return _multi_weakref_placeholder + else: + return v + + def sentinel_to_referrents(v, + it: Iterator[weakref.ref], + key_id: int | None): + # key_id is not None iff we use a MultiWeakRefCacheKey (>= 2 weakrefs) + if type(v) is tuple: + return tuple(sentinel_to_referrents(v1, it, key_id) for v1 in v) + elif type(v) is dict: + return {k: sentinel_to_referrents(v1, it, key_id) + for k, v1 in v.items()} + elif v is _multi_weakref_placeholder: + wr = next(it) + if key_id is not None: + weakref_to_key_ids.setdefault(wr, set()).add(key_id) + return wr() + else: + return v + + def cache_miss(key: MultiWeakRefCacheKey | MultiWeakRefPlaceholder | Any, + *args, **kwargs): + if isinstance(key, MultiWeakRefCacheKey): # had at least 2 weakrefs + # We know `key` is in `cached_call` cache, so store strong references + key_id = id(key) + id_to_key[key_id] = key + orig_args, orig_kwargs = sentinel_to_referrents( + (args, kwargs), iter(key.weakrefs), key_id) + elif key is _multi_weakref_placeholder: # had 0 weakrefs + orig_args = args + orig_kwargs = kwargs + else: # had 1 weakref, we had put it first as the `key` + orig_args, orig_kwargs = sentinel_to_referrents( + (args, kwargs), iter([weakref.ref(key)]), None) + return call(*orig_args, **orig_kwargs) + + + cached_call = _weakref_lru_cache.weakref_lru_cache( + config.trace_context if trace_context_in_key else _ignore, + cache_miss, maxsize + ) + register_cache(cached_call, str(call)) + + @functools.wraps(call) + def wrapper(*orig_args, **orig_kwargs): + acc_weakrefs: list[Any] = [] + args, kwargs = weakrefs_to_sentinel((orig_args, orig_kwargs), + acc_weakrefs) + nr_weakrefs = len(acc_weakrefs) + if nr_weakrefs == 0: + return cached_call(_multi_weakref_placeholder, + *orig_args, **orig_kwargs) + elif nr_weakrefs == 1: + # Put the single weakref first, and skip the MultiWeakRefCacheKey + return cached_call(acc_weakrefs[0], + *args, **kwargs) + else: + value_to_weakref = {v: weakref.ref(v, remove_weakref) + for v in set(acc_weakrefs)} + key = MultiWeakRefCacheKey(weakrefs=tuple(value_to_weakref[v] + for v in acc_weakrefs)) + return cached_call(key, *args, **kwargs) + + wrapper.cache_info = cached_call.cache_info + wrapper.cache_clear = cached_call.cache_clear + wrapper.cache_keys = cached_call.cache_keys + wrapper._multi_weakref_id_to_key = id_to_key # stays alive as long as wrapper + wrapper._multi_weakref_to_key_ids = weakref_to_key_ids + return wrapper + class Unhashable: __slots__ = ["val"] @@ -406,19 +517,21 @@ def __hash__(self): def __eq__(self, other): return self.val == other.val -def wrap_name(name, transform_name): - return transform_name + '(' + name + ')' +def wrap_name(transform_name: str, name: str) -> str: + return f"{transform_name}({name})" -def fun_name(fun: Callable): + +def fun_name(fun: Callable, default_name: str = "") -> str: name = getattr(fun, "__name__", None) if name is not None: return name if isinstance(fun, partial): return fun_name(fun.func) else: - return "" + return default_name + -def fun_qual_name(fun: Callable): +def fun_qual_name(fun: Callable) -> str: qual_name = getattr(fun, "__qualname__", None) if qual_name is not None: return qual_name @@ -426,7 +539,7 @@ def fun_qual_name(fun: Callable): return fun_qual_name(fun.func) return fun_name(fun) -def canonicalize_axis(axis, num_dims) -> int: +def canonicalize_axis(axis: SupportsIndex, num_dims: int) -> int: """Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims).""" axis = operator.index(axis) if not -num_dims <= axis < num_dims: @@ -435,7 +548,18 @@ def canonicalize_axis(axis, num_dims) -> int: axis = axis + num_dims return axis -def moveaxis(x, src, dst): +def canonicalize_axis_tuple(axis: int | Sequence[int] | None, ndim: int, allow_duplicate: bool = False) -> tuple[int, ...]: + if axis is None: + return tuple(range(ndim)) + if isinstance(axis, Sequence): + axis = tuple(canonicalize_axis(i, ndim) for i in axis) + if not allow_duplicate and len(set(axis)) != len(axis): + raise ValueError(f"repeated axis: {axis}") + return axis + else: + return (canonicalize_axis(axis, ndim),) + +def moveaxis(x: Array, src: int | Sequence[int], dst: int | Sequence[int]) -> Array: if src == dst: return x if isinstance(src, int): @@ -449,7 +573,7 @@ def moveaxis(x, src, dst): perm.insert(d, s) return x.transpose(perm) -def ceil_of_ratio(x, y): +def ceil_of_ratio(x: int, y: int) -> int: return -(-x // y) @@ -475,8 +599,9 @@ def wrapper(fun: T) -> T: else docstr.format(fun=name, doc=doc, **kwargs)) fun.__qualname__ = getattr(wrapped, "__qualname__", fun.__name__) fun.__wrapped__ = wrapped - finally: - return fun + except Exception: + pass + return fun return wrapper @@ -485,22 +610,18 @@ def wrapper(fun: T) -> T: def assert_unreachable(x): raise AssertionError(f"Unhandled case: {type(x).__name__}") -def tuple_insert(t, idx, val): +def tuple_insert(t: tuple[T, ...], idx: int, val: T) -> tuple[T, ...]: assert 0 <= idx <= len(t), (idx, len(t)) return t[:idx] + (val,) + t[idx:] -def tuple_delete(t, idx): +def tuple_delete(t: tuple[T, ...], idx: int) -> tuple[T, ...]: assert 0 <= idx < len(t), (idx, len(t)) return t[:idx] + t[idx + 1:] -def tuple_update(t, idx, val): +def tuple_update(t: tuple[T, ...], idx: int, val: T) -> tuple[T, ...]: assert 0 <= idx < len(t), (idx, len(t)) return t[:idx] + (val,) + t[idx+1:] -def tuple_replace(tupl, index, item): - # unlike tuple_update, works with negative indices as well - return tupl[:index] + (item,) + tupl[index:][1:] - class HashableFunction: """Decouples function equality and hash from its identity. @@ -554,13 +675,8 @@ def __eq__(self, other): self.args == other.args and self.kwargs == other.kwargs) def __hash__(self): - return hash( - ( - self.f.__code__, - self.args, - tuple(sorted(self.kwargs.items(), key=lambda kv: kv[0])), - ), - ) + kwargs = tuple(sorted(self.kwargs.items(), key=lambda kv: kv[0])) + return hash((self.f.__code__, self.args, kwargs)) def __call__(self, *args, **kwargs): return self.f(*self.args, *args, **self.kwargs, **kwargs) @@ -643,7 +759,7 @@ def __eq__(self, other): return self.x == other.x if self.hash is not None else self.x is other.x -def _original_func(f): +def _original_func(f: Callable) -> Callable: if isinstance(f, property): return cast(property, f).fget elif isinstance(f, functools.cached_property): @@ -690,14 +806,6 @@ def decorator(f): return decorator -try: - # numpy 1.25.0 or newer - NumpyComplexWarning: type[Warning] = np.exceptions.ComplexWarning -except AttributeError: - # legacy numpy - NumpyComplexWarning = np.ComplexWarning - - class StrictABCMeta(abc.ABCMeta): """A variant of `abc.ABCMeta` which does not allow virtual subclasses. @@ -724,5 +832,19 @@ def test_event(name: str, *args) -> None: return test_event_listener(name, *args) -if hasattr(jaxlib_utils, "Mutex"): - Mutex = jaxlib_utils.Mutex +Mutex = jaxlib_utils.Mutex + + +def pprint_bytes(num_bytes: int | float) -> str: + prefixes = ("", "K", "M", "G", "T") + if num_bytes <= 0: + return "0.00B" + exponent = min(math.floor(math.log(num_bytes, 1000)), len(prefixes) - 1) + scaled_value = num_bytes / (1000**exponent) + return f"{scaled_value:.2f}{prefixes[exponent]}B" + +if hasattr(jaxlib_utils, "install_failure_signal_handler"): + install_failure_signal_handler = jaxlib_utils.install_failure_signal_handler +else: + def install_failure_signal_handler(call_previous_handler: bool = True): + pass diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index be96deab81d8..42d93e7744a7 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -23,7 +23,7 @@ import atexit from collections.abc import Callable, Mapping import dataclasses -from functools import lru_cache, partial +from functools import partial import importlib import json import logging @@ -31,8 +31,8 @@ import pkgutil import platform as py_platform import threading -import traceback from typing import Any, Union +from collections.abc import Sequence import warnings from jax._src import config @@ -41,9 +41,10 @@ from jax._src import traceback_util from jax._src import util from jax._src.cloud_tpu_init import get_tpu_library_path -from jax._src.lib import cuda_versions +from jax._src.lib import jaxlib_extension_version from jax._src.lib import xla_client -from jax._src.lib import xla_extension +from jax._src.lib import _jax +from jax._src.lib import _profiler logger = logging.getLogger(__name__) @@ -58,41 +59,42 @@ traceback_util.register_exclusion(__file__) -XlaBackend = xla_client.Client +# The platforms in this set will force forward compatibility for lowering. +FORCE_FORWARD_COMPAT_LOWERING_PLATFORMS: set[str] = set() MIN_COMPUTE_CAPABILITY = 52 -_DEFAULT_CPU_COLLECTIVES_IMPL = 'gloo' - # TODO(phawkins): Remove jax_xla_backend. _XLA_BACKEND = config.string_flag( 'jax_xla_backend', '', - 'Deprecated, please use --jax_platforms instead.') + help='Deprecated, please use --jax_platforms instead.') BACKEND_TARGET = config.string_flag( 'jax_backend_target', os.getenv('JAX_BACKEND_TARGET', '').lower(), - 'Either "local" or "rpc:address" to connect to a remote service target.') + help='Either "local" or "rpc:address" to connect to a remote service target.') # TODO(skye): warn when this is used once we test out --jax_platforms a bit _PLATFORM_NAME = config.string_flag( 'jax_platform_name', os.getenv('JAX_PLATFORM_NAME', '').lower(), - 'Deprecated, please use --jax_platforms instead.') + help='Deprecated, please use --jax_platforms instead.') CUDA_VISIBLE_DEVICES = config.string_flag( 'jax_cuda_visible_devices', 'all', - 'Restricts the set of CUDA devices that JAX will use. Either "all", or a ' - 'comma-separate list of integer device IDs.') + help=( + 'Restricts the set of CUDA devices that JAX will use. Either "all", or a ' + 'comma-separate list of integer device IDs.')) _ROCM_VISIBLE_DEVICES = config.string_flag( 'jax_rocm_visible_devices', 'all', - 'Restricts the set of ROCM devices that JAX will use. Either "all", or a ' - 'comma-separate list of integer device IDs.') + help=( + 'Restricts the set of ROCM devices that JAX will use. Either "all", or a ' + 'comma-separate list of integer device IDs.')) -_MOCK_NUM_GPU_PROCESSES = config.int_flag( +MOCK_NUM_GPU_PROCESSES = config.int_flag( name="mock_num_gpu_processes", default=0, help="Mock number of JAX processes in GPU client. Value zero turns " "off mocking.", ) -_MOCK_GPU_TOPOLOGY = config.string_flag( +MOCK_GPU_TOPOLOGY = config.string_flag( name="jax_mock_gpu_topology", default="", help='Mock multi-host GPU topology in GPU client. The value should ' @@ -100,12 +102,6 @@ '". Empty string turns off mocking.', ) -_CPU_ENABLE_GLOO_COLLECTIVES = config.bool_flag( - name="jax_cpu_enable_gloo_collectives", - default=False, - help="Deprecated, please use jax_cpu_collectives_implementation instead.", -) - _CPU_ENABLE_ASYNC_DISPATCH = config.bool_flag( name="jax_cpu_enable_async_dispatch", default=True, @@ -113,6 +109,43 @@ "inline without async dispatch.", ) +FORCE_DCN_CROSS_HOST_TRANSFERS = config.bool_flag( + name="jax_force_dcn_cross_host_transfers", + default=False, + help="Force cross host transfers to use the DCN socket transfer library " + "even when the plugin supports cross-host transfers." +) + +CROSS_HOST_TRANSFER_SOCKET_ADDRESS = config.string_flag( + name="jax_cross_host_transfer_socket_address", + default="", + help="Socket address to use for cross host device transfers via DCN. " + "Necessary only if the PjRt plugin does not support cross host transfers.", +) + +CROSS_HOST_TRANSPORT_ADDRESSES = config.string_flag( + name="jax_cross_host_transport_addresses", + default="", + help=( + "Comma-separated list of transport addresses to use for cross host " + "device transfers via DCN. If not set, defaults to [0.0.0.0:0] * 4." + ), +) + +CROSS_HOST_TRANSFER_TIMEOUT_SECONDS = config.int_flag( + "jax_cross_host_transfer_timeout_seconds", + None, + help=( + "Timeout for cross host transfer metadata exchange through KV store. " + "Default is one minute." + ), +) + +CROSS_HOST_TRANSFER_TRANSFER_SIZE = config.int_flag( + "jax_cross_host_transfer_transfer_size", + None, + help="Chunk size for chunked transfer requests." +) # Warn the user if they call fork(), because it's not going to go well for them. def _at_fork(): @@ -125,13 +158,67 @@ def _at_fork(): # Backends +_NameValueMapping = Mapping[str, Union[str, int, list[int], float, bool]] + +def _make_transfer_server_factory( +) -> xla_client._xla.TransferServerInterfaceFactory | None: + """Creates a transfer server interface factory.""" + if (not CROSS_HOST_TRANSFER_SOCKET_ADDRESS.value or not + hasattr(_jax, "make_transfer_server_interface_factory")): + return None + transport_addresses = [] + if CROSS_HOST_TRANSPORT_ADDRESSES.value: + transport_addresses = CROSS_HOST_TRANSPORT_ADDRESSES.value.split(",") + transfer_server_kwargs = { + "distributed_client": distributed.global_state.client, + "socket_address": CROSS_HOST_TRANSFER_SOCKET_ADDRESS.value, + "transport_addresses": transport_addresses, + } + if CROSS_HOST_TRANSFER_TIMEOUT_SECONDS.value is not None: + transfer_server_kwargs["cross_host_transfer_timeout_seconds"] = ( + CROSS_HOST_TRANSFER_TIMEOUT_SECONDS.value) + if CROSS_HOST_TRANSFER_TRANSFER_SIZE.value is not None: + transfer_server_kwargs["transfer_size"] = ( + CROSS_HOST_TRANSFER_TRANSFER_SIZE.value) + return _jax.make_transfer_server_interface_factory(**transfer_server_kwargs) # type: ignore + + +def make_tpu_client( + library_path: str | None = None, options: _NameValueMapping | None = None +): + """Returns a TPU client. Defaults to allowing 32 in-flight computations.""" + if not _jax.pjrt_plugin_loaded('tpu'): + c_api = xla_client.load_pjrt_plugin_dynamically( + "tpu", library_path or "libtpu.so" + ) + _profiler.register_plugin_profiler(c_api) + assert _jax.pjrt_plugin_loaded('tpu') + if not _jax.pjrt_plugin_initialized('tpu'): + _jax.initialize_pjrt_plugin('tpu') + if options is None: + options = {} + if jaxlib_extension_version < 397: + return _jax.get_c_api_client( + "tpu", + options, + distributed.global_state.client, + _make_transfer_server_factory(), + ) + return _jax.get_c_api_client( + "tpu", + options, + distributed.global_state.client, + _make_transfer_server_factory(), + FORCE_DCN_CROSS_HOST_TRANSFERS.value, + ) + def tpu_client_timer_callback(timer_secs: float) -> xla_client.Client | None: def _log_warning(): warnings.warn( f'TPU backend initialization is taking more than {timer_secs} seconds. ' 'Did you run your code on all TPU hosts? ' - 'See https://jax.readthedocs.io/en/latest/multi_process.html ' + 'See https://docs.jax.dev/en/latest/multi_process.html ' 'for more information.') # Will log a warning after `timer_secs`. @@ -139,7 +226,7 @@ def _log_warning(): t.start() try: - client = xla_client.make_tpu_client( + client = make_tpu_client( get_tpu_library_path(), _options_from_jax_configs("tpu")) finally: @@ -241,16 +328,6 @@ def make_cpu_client( # https://github.com/jax-ml/jax/pull/26172 goes in. if collectives is None and distributed.global_state.client is not None: collectives_impl = config.cpu_collectives_implementation.value - if _CPU_ENABLE_GLOO_COLLECTIVES.value: - collectives_impl = 'gloo' - warnings.warn('Setting `jax_cpu_enable_gloo_collectives` is ' - 'deprecated. Please use `jax.config.update(' - '"jax_cpu_collectives_implementation", "gloo")` instead.', - DeprecationWarning, - ) - if collectives_impl is None: - collectives_impl = _DEFAULT_CPU_COLLECTIVES_IMPL - if collectives_impl == 'gloo': collectives = xla_client._xla.make_gloo_tcp_collectives( distributed_client=distributed.global_state.client, @@ -265,12 +342,15 @@ def make_cpu_client( num_devices = num_cpu_devices.value if num_cpu_devices.value >= 0 else None return xla_client.make_cpu_client( - asynchronous=_CPU_ENABLE_ASYNC_DISPATCH.value, - distributed_client=distributed.global_state.client, - node_id=distributed.global_state.process_id, - num_nodes=distributed.global_state.num_processes, - collectives=collectives, - num_devices=num_devices, + asynchronous=_CPU_ENABLE_ASYNC_DISPATCH.value, + distributed_client=distributed.global_state.client, + node_id=distributed.global_state.process_id, + num_nodes=distributed.global_state.num_processes, + collectives=collectives, + num_devices=num_devices, + get_local_topology_timeout_minutes=cpu_get_local_topology_timeout_minutes.value, + get_global_topology_timeout_minutes=cpu_get_global_topology_timeout_minutes.value, + transfer_server_factory=_make_transfer_server_factory(), ) @@ -278,158 +358,7 @@ def make_cpu_client( "cpu", make_cpu_client, priority=0, fail_quietly=False ) - -def _check_cuda_compute_capability(devices_to_check): - for idx in devices_to_check: - compute_cap = cuda_versions.cuda_compute_capability(idx) - if compute_cap < MIN_COMPUTE_CAPABILITY: - warnings.warn( - f"Device {idx} has CUDA compute capability {compute_cap/10} which is " - "lower than the minimum supported compute capability " - f"{MIN_COMPUTE_CAPABILITY/10}. See " - "https://jax.readthedocs.io/en/latest/installation.html#nvidia-gpu for " - "more details", - RuntimeWarning - ) - - -def _check_cuda_versions(raise_on_first_error: bool = False, - debug: bool = False): - assert cuda_versions is not None - results: list[dict[str, Any]] = [] - - def _make_msg(name: str, - runtime_version: int, - build_version: int, - min_supported: int, - debug_msg: bool = False): - if debug_msg: - return (f"Package: {name}\n" - f"Version JAX was built against: {build_version}\n" - f"Minimum supported: {min_supported}\n" - f"Installed version: {runtime_version}") - if min_supported: - req_str = (f"The local installation version must be no lower than " - f"{min_supported}.") - else: - req_str = ("The local installation must be the same version as " - "the version against which JAX was built.") - msg = (f"Outdated {name} installation found.\n" - f"Version JAX was built against: {build_version}\n" - f"Minimum supported: {min_supported}\n" - f"Installed version: {runtime_version}\n" - f"{req_str}") - return msg - - def _version_check(name: str, - get_version, - get_build_version, - scale_for_comparison: int = 1, - min_supported_version: int = 0): - """Checks the runtime CUDA component version against the JAX one. - - Args: - name: Of the CUDA component. - get_version: A function to get the local runtime version of the component. - get_build_version: A function to get the build version of the component. - scale_for_comparison: For rounding down a version to ignore patch/minor. - min_supported_version: An absolute minimum version required. Must be - passed without rounding down. - - Raises: - RuntimeError: If the component is not found, or is of unsupported version, - and if raising the error is not deferred till later. - """ - - build_version = get_build_version() - try: - version = get_version() - except Exception as e: - err_msg = f"Unable to load {name}. Is it installed?" - if raise_on_first_error: - raise RuntimeError(err_msg) from e - err_msg += f"\n{traceback.format_exc()}" - results.append({"name": name, "installed": False, "msg": err_msg}) - return - - if not min_supported_version: - min_supported_version = build_version // scale_for_comparison - passed = min_supported_version <= version - - if not passed or debug: - msg = _make_msg(name=name, - runtime_version=version, - build_version=build_version, - min_supported=min_supported_version, - debug_msg=passed) - if not passed and raise_on_first_error: - raise RuntimeError(msg) - else: - record = {"name": name, - "installed": True, - "msg": msg, - "passed": passed, - "build_version": build_version, - "version": version, - "minimum_supported": min_supported_version} - results.append(record) - - _version_check("CUDA", cuda_versions.cuda_runtime_get_version, - cuda_versions.cuda_runtime_build_version, - scale_for_comparison=10, - min_supported_version=12010) - _version_check( - "cuDNN", - cuda_versions.cudnn_get_version, - cuda_versions.cudnn_build_version, - # NVIDIA promise both backwards and forwards compatibility for cuDNN patch - # versions: - # https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#api-compat - scale_for_comparison=100, - min_supported_version=9100 - ) - _version_check("cuFFT", cuda_versions.cufft_get_version, - cuda_versions.cufft_build_version, - # Ignore patch versions. - scale_for_comparison=100) - _version_check("cuSOLVER", cuda_versions.cusolver_get_version, - cuda_versions.cusolver_build_version, - # Ignore patch versions. - scale_for_comparison=100, - min_supported_version=11400) - _version_check("cuPTI", cuda_versions.cupti_get_version, - cuda_versions.cupti_build_version, - min_supported_version=18) - _version_check("cuBLAS", cuda_versions.cublas_get_version, - cuda_versions.cublas_build_version, - # Ignore patch versions. - scale_for_comparison=100, - min_supported_version=120100) - _version_check("cuSPARSE", cuda_versions.cusparse_get_version, - cuda_versions.cusparse_build_version, - # Ignore patch versions. - scale_for_comparison=100, - min_supported_version=12100) - - errors = [] - debug_results = [] - for result in results: - message: str = result['msg'] - if not result['installed'] or not result['passed']: - errors.append(message) - else: - debug_results.append(message) - - join_str = f'\n{"-" * 50}\n' - if debug_results: - print(f'CUDA components status (debug):\n' - f'{join_str.join(debug_results)}') - if errors: - raise RuntimeError(f'Unable to use CUDA because of the ' - f'following issues with CUDA components:\n' - f'{join_str.join(errors)}') - -def _get_num_nodes_from_gpu_topology(topology: str) -> int: +def get_num_nodes_from_gpu_topology(topology: str) -> int: try: slices_str, hosts_per_slice_str, _ = topology.split("x", 2) return int(slices_str) * int(hosts_per_slice_str) @@ -438,75 +367,11 @@ def _get_num_nodes_from_gpu_topology(topology: str) -> int: '" x x ' '".') -def make_gpu_client( - *, platform_name: str, visible_devices_flag: config.Flag[str] -) -> xla_client.Client: - visible_devices = visible_devices_flag.value - allowed_devices = None - if visible_devices != "all": - allowed_devices = {int(x) for x in visible_devices.split(",")} - - mock_gpu_topology = _MOCK_GPU_TOPOLOGY.value or None - mock_num_gpu_processes = (_get_num_nodes_from_gpu_topology(mock_gpu_topology) if - mock_gpu_topology else _MOCK_NUM_GPU_PROCESSES.value) - - use_mock_gpu_client = mock_num_gpu_processes > 0 - num_nodes = (mock_num_gpu_processes if use_mock_gpu_client - else distributed.global_state.num_processes) - - if platform_name == "cuda": - if not os.getenv("JAX_SKIP_CUDA_CONSTRAINTS_CHECK"): - _check_cuda_versions() - else: - print('Skipped CUDA versions constraints check due to the ' - 'JAX_SKIP_CUDA_CONSTRAINTS_CHECK env var being set.') - - devices_to_check = ( - allowed_devices - if allowed_devices - else range(cuda_versions.cuda_device_count()) - ) - _check_cuda_compute_capability(devices_to_check) - - return xla_client.make_gpu_client( - distributed_client=distributed.global_state.client, - node_id=distributed.global_state.process_id, - num_nodes=num_nodes, - platform_name=platform_name, - allowed_devices=allowed_devices, - mock=use_mock_gpu_client, - ) - - -if hasattr(xla_client, "make_gpu_client"): - register_backend_factory( - "cuda", - partial( - make_gpu_client, - platform_name="cuda", - visible_devices_flag=CUDA_VISIBLE_DEVICES, - ), - priority=200, - fail_quietly=True, - ) - register_backend_factory( - "rocm", - partial( - make_gpu_client, - platform_name="rocm", - visible_devices_flag=_ROCM_VISIBLE_DEVICES, - ), - priority=200, - fail_quietly=True, - ) - - -if hasattr(xla_client, "make_tpu_client"): - # TODO(phawkins,skyewm): switch TPU plugin to use the PJRT plugin mechanism, - # and then fail loudly on initialization failure. - register_backend_factory( - 'tpu', partial(tpu_client_timer_callback, timer_secs=60.0), priority=300, - fail_quietly=True) +# TODO(phawkins,skyewm): switch TPU plugin to use the PJRT plugin mechanism, +# and then fail loudly on initialization failure. +register_backend_factory( + 'tpu', partial(tpu_client_timer_callback, timer_secs=60.0), priority=300, + fail_quietly=True) def _get_pjrt_plugin_names_and_library_paths( @@ -563,7 +428,7 @@ def discover_pjrt_plugins() -> None: """Discovers plugins in the namespace package `jax_plugins` and import them. There are two methods used to discover plugin modules. They are intended - to be used together by implementors in order to cover all packaging and + to be used together by implementers in order to cover all packaging and development cases: 1. Define a globally unique module under the `jax_plugins` namespace @@ -631,27 +496,30 @@ def _options_from_jax_configs(plugin_name): options = {} pjrt_client_options = config.jax_pjrt_client_create_options.value - pjrt_client_option_list = [] - if pjrt_client_options: - pjrt_client_option_list = pjrt_client_options.split(";") - - for option in pjrt_client_option_list: - option_list = option.split(":") - if (len(option_list) != 2): - raise RuntimeError( - "Multiple ':' separators for option in " - f"jax_pjrt_client_create_options: '{option}'. " - "Should be in format 'key:value'") - options[option_list[0]] = option_list[1] + if isinstance(pjrt_client_options, str): + pjrt_client_option_list = [] + if pjrt_client_options: + pjrt_client_option_list = pjrt_client_options.split(";") + + for option in pjrt_client_option_list: + option_list = option.split(":") + if (len(option_list) != 2): + raise RuntimeError( + "Multiple ':' separators for option in " + f"jax_pjrt_client_create_options: '{option}'. " + "Should be in format 'key:value'") + options[option_list[0]] = option_list[1] + elif isinstance(pjrt_client_options, dict): + options.update(pjrt_client_options) if plugin_name in ("cuda", "rocm"): visible_devices = (CUDA_VISIBLE_DEVICES.value if plugin_name == "cuda" else _ROCM_VISIBLE_DEVICES.value) if visible_devices != 'all': options['visible_devices'] = [int(x) for x in visible_devices.split(',')] - mock_gpu_topology = _MOCK_GPU_TOPOLOGY.value or None - mock_num_processes = (_get_num_nodes_from_gpu_topology(mock_gpu_topology) if - mock_gpu_topology else _MOCK_NUM_GPU_PROCESSES.value) + mock_gpu_topology = MOCK_GPU_TOPOLOGY.value or None + mock_num_processes = (get_num_nodes_from_gpu_topology(mock_gpu_topology) if + mock_gpu_topology else MOCK_NUM_GPU_PROCESSES.value) options['enable_mock_nccl'] = mock_num_processes > 0 if mock_num_processes > 0: options['num_nodes'] = mock_num_processes @@ -660,16 +528,63 @@ def _options_from_jax_configs(plugin_name): return options +OptionsDict = Mapping[str, str | int | list[int] | float | bool] + + +def make_pjrt_c_api_client( + plugin_name: str, + options: OptionsDict | Callable[[], OptionsDict] | None = None, +) -> xla_client.Client: + """Creates a PjRt client for the given plugin. + + Args: + plugin_name: the name of the plugin. + options: Optional. It is used when creating a PJRT plugin client. Can be a + callable, in which case it will be invoked upon plugin initialization + time, and will be expected to return an option dictionary. + """ + if not xla_client.pjrt_plugin_initialized(plugin_name): + xla_client.initialize_pjrt_plugin(plugin_name) + updated_options: dict[str, Any] = {} + if options is not None: + updated_options.update(options() if callable(options) else options) + updated_options.update(_options_from_jax_configs(plugin_name)) + if distributed.global_state.client is None: + return xla_client.make_c_api_client(plugin_name, updated_options, None) + + distribute_options = { + 'node_id': distributed.global_state.process_id, + 'num_nodes': distributed.global_state.num_processes, + } + if (partition_index := distributed.global_state.partition_index) is not None: + distribute_options['partition_index'] = partition_index + if options is not None: + distribute_options.update(updated_options) + if jaxlib_extension_version < 397: + return xla_client.make_c_api_client( + plugin_name, + distribute_options, + distributed.global_state.client, + _make_transfer_server_factory(), + ) + return xla_client.make_c_api_client( + plugin_name, + distribute_options, + distributed.global_state.client, + _make_transfer_server_factory(), + FORCE_DCN_CROSS_HOST_TRANSFERS.value, + ) + -# TODO(b/261345120): decide on a public name and expose a public method which is -# an alias of this method. def register_plugin( plugin_name: str, *, priority: int = 400, library_path: str | None = None, - options: Mapping[str, str | int | list[int] | float | bool] | None = None, + options: OptionsDict | Callable[[], OptionsDict] | None = None, c_api: Any | None = None, + factory: BackendFactory | None = None, + make_topology: TopologyFactory | None = None, ) -> Any: """Registers a backend factory for the PJRT plugin. @@ -679,28 +594,13 @@ def register_plugin( Default to be 400. library_path: Optional. The full path to the .so file of the plugin. The plugin needs to provide either the library_path or the c_api. - options: Optional. It is used when creating a PJRT plugin client. + options: Optional. It is used when creating a PJRT plugin client. Can be a + callable, in which case it will be invoked upon plugin initialization + time, and will be expected to return an option dictionary. c_api: Optional. The plugin can provide a PJRT C API to be registered. + factory: Optional. A factory function that creates a PJRT client. If not + provided, a default factory will be used. """ - def factory(): - if not xla_client.pjrt_plugin_initialized(plugin_name): - xla_client.initialize_pjrt_plugin(plugin_name) - updated_options = {} - if options is not None: - updated_options.update(options) - updated_options.update(_options_from_jax_configs(plugin_name)) - if distributed.global_state.client is None: - return xla_client.make_c_api_client(plugin_name, updated_options, None) - - distribute_options = { - 'node_id': distributed.global_state.process_id, - 'num_nodes': distributed.global_state.num_processes, - } - if options is not None: - distribute_options.update(updated_options) - return xla_client.make_c_api_client( - plugin_name, distribute_options, distributed.global_state.client - ) if library_path and c_api: logger.error( @@ -717,17 +617,26 @@ def factory(): ) return + if factory is not None and options is not None: + raise ValueError( + "Cannot provide both 'factory' and 'options' when registering PJRT" + " plugin. When providing a custom factory, the factory's must handle" + " its own options." + ) + if factory is None: + factory = partial(make_pjrt_c_api_client, plugin_name, options=options) + logger.debug( 'registering PJRT plugin %s from %s', plugin_name, library_path ) if library_path is not None: c_api = xla_client.load_pjrt_plugin_dynamically(plugin_name, library_path) - xla_client.profiler.register_plugin_profiler(c_api) + _profiler.register_plugin_profiler(c_api) else: assert c_api is not None xla_client.load_pjrt_plugin_with_c_api(plugin_name, c_api) - make_topology = partial(xla_client.make_c_api_device_topology, c_api) + make_topology = make_topology or partial(xla_client.make_c_api_device_topology, c_api) experimental = plugin_name not in _nonexperimental_plugins register_backend_factory(plugin_name, factory, priority=priority, fail_quietly=False, experimental=experimental, @@ -950,14 +859,14 @@ def _suggest_missing_backends(): assert _default_backend is not None default_platform = _default_backend.platform if "cuda" not in _backends and hardware_utils.has_visible_nvidia_gpu(): - if hasattr(xla_extension, "GpuAllocatorConfig") and "cuda" in _backend_errors: + if hasattr(_jax, "GpuAllocatorConfig") and "cuda" in _backend_errors: err = _backend_errors["cuda"] warning_msg = f"CUDA backend failed to initialize: {err}." if "no supported devices found for platform CUDA." in err: warning_msg += ( "This may be due to JAX pre-allocating too much device " "memory, leaving too little for CUDA library initialization. See " - "https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html " + "https://docs.jax.dev/en/latest/gpu_memory_allocation.html " "for more details and potential workarounds." ) warning_msg += "(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)" @@ -984,8 +893,6 @@ def _clear_backends() -> None: _backend_errors = {} _default_backend = None - get_backend.cache_clear() - def _init_backend(platform: str) -> xla_client.Client: registration = _backend_factories.get(platform, None) @@ -1042,7 +949,7 @@ def _get_backend_uncached( return _default_backend -@lru_cache(maxsize=None) # don't use util.memoize because there is no X64 dependence. +@util.cache(max_size=None, trace_context_in_key=False) # don't use util.memoize because there is no X64 dependence. def get_backend( platform: None | str | xla_client.Client = None ) -> xla_client.Client: @@ -1092,7 +999,7 @@ def devices( ) -> list[xla_client.Device]: """Returns a list of all devices for a given backend. - .. currentmodule:: jaxlib.xla_extension + .. currentmodule:: jaxlib._jax Each device is represented by a subclass of :class:`Device` (e.g. :class:`CpuDevice`, :class:`GpuDevice`). The length of the returned list is @@ -1137,15 +1044,25 @@ def backend_xla_version(platform=None) -> int | None: """Returns the XLA version of the backend. Returns None if the backend does not use PJRT C API or does not have - xla_version in the plugin attributes. This methon can be used to skip features + xla_version in the plugin attributes. This method can be used to skip features that are not available before certain xla_version if the backend is a plugin and uses xla_version. """ backend = get_backend(platform) return getattr(backend, "xla_version", None) +def backend_stablehlo_version(platform=None) -> Sequence[int] | None: + """Returns the StableHLO version of the backend. -@lru_cache + Returns None if the backend does not use PJRT C API or does not have + stablehlo_current_version in the plugin attributes. This method can be used to + skip features that are not available before certain stablehlo_current_version + if the backend is a plugin and uses stablehlo_current_version. + """ + backend = get_backend(platform) + return getattr(backend, "stablehlo_current_version", None) + +@util.cache(max_size=None, trace_context_in_key=False) def local_devices(process_index: int | None = None, backend: str | xla_client.Client | None = None, host_id: int | None = None) -> list[xla_client.Device]: @@ -1203,12 +1120,13 @@ def host_id(backend: str | xla_client.Client | None = None) -> int: return process_index(backend) -@lru_cache +@util.cache(max_size=None, trace_context_in_key=False) def process_count( backend: str | xla_client.Client | None = None ) -> int: """Returns the number of JAX processes associated with the backend.""" - return max(d.process_index for d in devices(backend)) + 1 + gen = (d.process_index for d in devices(backend)) + return max(gen, default=0) + 1 # TODO: remove this sometime after jax 0.2.13 is released @@ -1266,7 +1184,7 @@ def make_pjrt_tpu_topology(topology_name='', **kwargs): "JAX TPU support not installed; cannot generate TPU topology. See" " https://github.com/jax-ml/jax#installation") c_api = xla_client.load_pjrt_plugin_dynamically("tpu", library_path) - xla_client.profiler.register_plugin_profiler(c_api) + _profiler.register_plugin_profiler(c_api) assert xla_client.pjrt_plugin_loaded("tpu") if not xla_client.pjrt_plugin_initialized("tpu"): xla_client.initialize_pjrt_plugin("tpu") @@ -1274,10 +1192,12 @@ def make_pjrt_tpu_topology(topology_name='', **kwargs): topology_name, **kwargs ) -def _validate_backend_not_initialized(new_val): +def _validate_backend_not_initialized(name, new_val): if backends_are_initialized(): + if getattr(config.config, name) == new_val: + return raise RuntimeError( - "jax_num_cpu_devices config should be updated before backends are" + f"{name} config should be updated before backends are" " initialized i.e. before any JAX operation is executed. You should" " initialize this config immediately after `import jax`.") @@ -1288,5 +1208,28 @@ def _validate_backend_not_initialized(new_val): "Number of CPU devices to use. If not provided, the value of " "the XLA flag --xla_force_host_platform_device_count is used." " Must be set before JAX is initialized."), - validator=_validate_backend_not_initialized, + validator=partial(_validate_backend_not_initialized, "jax_num_cpu_devices"), +) + +cpu_get_local_topology_timeout_minutes = config.int_state( + name="jax_cpu_get_local_topology_timeout_minutes", + default=2, + help=( + "Timeout in minutes for getting the local topology of each CPU device" + " when building the global topology." + ), + validator=partial(_validate_backend_not_initialized, + "jax_cpu_get_local_topology_timeout_minutes"), +) + +cpu_get_global_topology_timeout_minutes = config.int_state( + name="jax_cpu_get_global_topology_timeout_minutes", + default=5, + help=( + "Timeout in minutes for getting the global topology of CPU devices;" + " should be strictly greater than" + " `--jax_cpu_get_local_topology_timeout_minutes`." + ), + validator=partial(_validate_backend_not_initialized, + "jax_cpu_get_global_topology_timeout_minutes"), ) diff --git a/jax/_src/xla_metadata.py b/jax/_src/xla_metadata.py index 91895b4e7851..e7580dfb2d7f 100644 --- a/jax/_src/xla_metadata.py +++ b/jax/_src/xla_metadata.py @@ -12,46 +12,57 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial, wraps from typing import Any -from contextlib import contextmanager from jax._src import config +from jax._src import core +from jax._src import dispatch +from jax._src import tree_util +from jax._src import xla_metadata_lib +from jax._src.interpreters import ad, batching, mlir from jax._src.lib import xla_client +from jax._src.lib.mlir import ir config_ext = xla_client._xla.config -class XlaMetadata: - __slots__ = ['val', 'hash'] +class _XlaMetadataWrapper: + """A wrapper class to allow XlaMetadataContextManager to be used as a decorator. - def __init__(self, val): - self.val = val - self.hash = hash(tuple(sorted(self.val.items()))) + When XlaMetadataContextManager is used as a decorator on a function `f`, it + returns an instance of this class. This wrapper ensures that when `f` is + called, it runs within the metadata context. It also forwards attribute + access to `f` via `__getattr__`, and if an attribute of `f` is callable (e.g., + the `.lower()` method of a jitted function), it wraps that attribute so it + too runs within the metadata context when called. This allows decorated + functions to be used seamlessly with JAX transformations like `jax.jit`. + """ - def __hash__(self): - return self.hash + def __init__(self, f, ctx): + self._f = f + self._ctx = ctx + wraps(f)(self) - def __eq__(self, other): - return other is not None and self.val == other.val + def __call__(self, *args, **kwargs): + with self._ctx: + return self._f(*args, **kwargs) + def __getattr__(self, name): + attr = getattr(self._f, name) + if not callable(attr): + return attr -def update_metadata(a, b: dict[str, Any]): - if not b: - return a - if a is None or a is config_ext.unset: - return XlaMetadata(b) - val = a.val.copy() - val.update(b) - return XlaMetadata(val) + @wraps(attr) + def wrapper(*args, **kwargs): + with self._ctx: + return attr(*args, **kwargs) - -def current_xla_metadata(): - metadata = config.xla_metadata_context_manager.value - return None if metadata is None else metadata.val + return wrapper class XlaMetadataContextManager: - __slots__ = ['prev', 'updates'] + __slots__ = ["prev", "updates"] def __init__(self, updates): self.updates = updates @@ -62,7 +73,7 @@ def __enter__(self): self.prev = config.xla_metadata_context_manager.get_local() config.xla_metadata_context_manager.set_local( - update_metadata(self.prev, self.updates) + xla_metadata_lib.update_metadata(self.prev, self.updates) ) def __exit__(self, exc_type, exc_value, traceback): @@ -70,7 +81,71 @@ def __exit__(self, exc_type, exc_value, traceback): return config.xla_metadata_context_manager.set_local(self.prev) -@contextmanager -def set_xla_metadata(**kwargs): - with XlaMetadataContextManager(kwargs): - yield + def __call__(self, f): + return _XlaMetadataWrapper(f, self) + + +def set_xla_metadata(x=None, **kwargs): + if x is None: + return XlaMetadataContextManager(kwargs) + else: + hashable_metadata = tuple(sorted(kwargs.items())) + return tree_util.tree_map( + lambda v: xla_metadata_value_p.bind( + v, xla_metadata_kvs=hashable_metadata + ), + x, + ) + + +# `xla_metadata_value_p` is an identity primitive for attaching frontend_attributes +# to the primitive's producing (parent/owner) op. +xla_metadata_value_p = core.Primitive("xla_metadata_value") +xla_metadata_value_p.def_impl( + partial(dispatch.apply_primitive, xla_metadata_value_p) +) +xla_metadata_value_p.def_abstract_eval(lambda aval, *, xla_metadata_kvs: aval) +batching.defvectorized(xla_metadata_value_p) +# TODO(nbasile): Implement tagging gradient ops with metadata. +ad.deflinear2(xla_metadata_value_p, lambda ct, _, **kwargs: (ct,)) + + +def _xla_metadata_value_lowering_rule( + ctx: mlir.LoweringRuleContext, val: ir.Value, *, xla_metadata_kvs +): + xla_metadata = dict(xla_metadata_kvs) + op_to_attach_metadata = _target_op_to_attach_metadata(val) + if op_to_attach_metadata is not None: + _attach_xla_metadata_to_op(xla_metadata, op_to_attach_metadata) + return [val] + + +# If we leave `cacheable=True`, when we are in the lowering rule, the `val.owner` +# becomes a cached `FuncOp`. FuncOp.owners are Blocks, which we can't tag. +mlir.register_lowering( + xla_metadata_value_p, _xla_metadata_value_lowering_rule, cacheable=False +) + + +def _target_op_to_attach_metadata(value_mlir: ir.Value) -> ir.Operation | None: + op = value_mlir.owner + if op is None or isinstance(op, ir.Block): + return None + return op + + +def _attach_xla_metadata_to_op( + xla_metadata: dict[str, Any], op: ir.Operation +) -> None: + if xla_metadata: + ctx_attributes, existing_attributes = {}, {} + for k, v in xla_metadata.items(): + ctx_attributes[k] = ir.StringAttr.get(str(v).lower()) + # Combine with existing mhlo.frontend_attributes + for attr in op.attributes: + if attr == "mhlo.frontend_attributes": + for a in op.attributes[attr]: + existing_attributes[a.name] = a.attr + op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get( + ctx_attributes | existing_attributes + ) diff --git a/jax/_src/xla_metadata_lib.py b/jax/_src/xla_metadata_lib.py new file mode 100644 index 000000000000..65b0f4b548ba --- /dev/null +++ b/jax/_src/xla_metadata_lib.py @@ -0,0 +1,56 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +from jax._src import config +from jax._src.lib import xla_client + +config_ext = xla_client._xla.config + + +class XlaMetadata: + __slots__ = ['val', 'hash'] + + val: dict[str, Any] + + def __init__(self, val): + self.val = val + self.hash = hash(tuple(sorted(self.val.items()))) + + def __hash__(self): + return self.hash + + def __eq__(self, other): + return other is not None and self.val == other.val + + +def filter_nones(d: dict) -> dict: + return {k: v for k, v in d.items() if v is not None} + + +def update_metadata(a, b: dict[str, Any]): + if not b: + return a + if a is None or a is config_ext.unset: + val = {} + else: + val = a.val.copy() + val.update(b) + return XlaMetadata(filter_nones(val)) + + +def current_xla_metadata() -> dict[str, Any] | None: + metadata = config.xla_metadata_context_manager.value + return None if metadata is None else metadata.val diff --git a/jax/ad_checkpoint.py b/jax/ad_checkpoint.py index 44c13e379330..f780e0068d9b 100644 --- a/jax/ad_checkpoint.py +++ b/jax/ad_checkpoint.py @@ -13,14 +13,36 @@ # limitations under the License. from jax._src.ad_checkpoint import ( - checkpoint as checkpoint, + checkpoint as _deprecated_checkpoint, checkpoint_policies as checkpoint_policies, checkpoint_name as checkpoint_name, print_saved_residuals as print_saved_residuals, - remat as remat, ) from jax._src.interpreters.partial_eval import ( Recompute as Recompute, Saveable as Saveable, Offloadable as Offloadable, ) + +_deprecations = { + # Added for v0.8.2 + "checkpoint": ( + "jax.ad_checkpoint.checkpoint is deprecated; use jax.checkpoint instead.", + _deprecated_checkpoint + ), + "remat": ( + "jax.ad_checkpoint.remat is deprecated; use jax.remat instead.", + _deprecated_checkpoint + ), +} + +import typing as _typing +if _typing.TYPE_CHECKING: + checkpoint = _deprecated_checkpoint + remat = _deprecated_checkpoint +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del _typing +del _deprecated_checkpoint diff --git a/jax/cloud_tpu_init.py b/jax/cloud_tpu_init.py index 8b886eb6b450..a564975b60c1 100644 --- a/jax/cloud_tpu_init.py +++ b/jax/cloud_tpu_init.py @@ -12,4 +12,32 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jax._src.cloud_tpu_init import cloud_tpu_init as cloud_tpu_init +import warnings + +warnings.warn( + "jax.cloud_tpu_init was deprecated in JAX v0.8.1. You should remove imports" + " of this module.", + DeprecationWarning, stacklevel=1 +) + +del warnings + +from jax._src.cloud_tpu_init import cloud_tpu_init as _cloud_tpu_init + +_deprecations = { + # Added 2025-10-28, remove in JAX 0.10. + "cloud_tpu_init": ( + "jax.cloud_tpu_init was deprecated in JAX v0.8.1. You do not need to call " + "this function explicitly; JAX calls this function automatically.", + _cloud_tpu_init + ), +} + +import typing as _typing +if _typing.TYPE_CHECKING: + cloud_tpu_init = _cloud_tpu_init +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del _typing diff --git a/jax/collect_profile.py b/jax/collect_profile.py index d1309e0c5bca..3f78ea4095e5 100644 --- a/jax/collect_profile.py +++ b/jax/collect_profile.py @@ -14,6 +14,7 @@ from __future__ import annotations +from typing import Any import argparse import gzip import os @@ -23,15 +24,11 @@ # pytype: disable=import-error from jax._src import profiler as jax_profiler try: - from tensorflow.python.profiler import profiler_v2 as profiler - from tensorflow.python.profiler import profiler_client -except ImportError: - raise ImportError("This script requires `tensorflow` to be installed.") -try: - from tensorboard_plugin_profile.convert import raw_to_tool_data as convert + from xprof.convert import _pywrap_profiler_plugin + from xprof.convert import raw_to_tool_data as convert except ImportError: raise ImportError( - "This script requires `tensorboard_plugin_profile` to be installed.") + "This script requires `xprof` to be installed.") # pytype: enable=import-error @@ -43,8 +40,16 @@ for a provided duration. The trace file will be dumped into a directory (determined by `--log_dir`) and by default, a Perfetto UI link will be generated to view the resulting trace. + +Common tracer options (with defaults): + --host_tracer_level=2 Profiler host tracer level. + --device_tracer_level=1 Profiler device tracer level. + --python_tracer_level=1 Profiler Python tracer level. """ -parser = argparse.ArgumentParser(description=_DESCRIPTION) +_GRPC_PREFIX = 'grpc://' +DEFAULT_NUM_TRACING_ATTEMPTS = 3 +parser = argparse.ArgumentParser(description=_DESCRIPTION, + formatter_class=argparse.RawTextHelpFormatter) parser.add_argument("--log_dir", default=None, help=("Directory to store log files. " "Uses a temporary directory if none provided."), @@ -58,29 +63,41 @@ parser.add_argument("--host", default="127.0.0.1", help="Host to collect trace. Defaults to 127.0.0.1", type=str) -parser.add_argument("--host_tracer_level", default=2, - help="Profiler host tracer level", type=int) -parser.add_argument("--device_tracer_level", default=1, - help="Profiler device tracer level", type=int) -parser.add_argument("--python_tracer_level", default=1, - help="Profiler Python tracer level", type=int) - -def collect_profile(port: int, duration_in_ms: int, host: str, - log_dir: os.PathLike | str | None, host_tracer_level: int, - device_tracer_level: int, python_tracer_level: int, - no_perfetto_link: bool): - options = profiler.ProfilerOptions( - host_tracer_level=host_tracer_level, - device_tracer_level=device_tracer_level, - python_tracer_level=python_tracer_level, - ) + +def collect_profile( + port: int, + duration_in_ms: int, + host: str, + log_dir: os.PathLike | str | None, + no_perfetto_link: bool, + xprof_options: dict[str, Any] | None = None,): + options: dict[str, Any] = { + "host_tracer_level": 2, + "device_tracer_level": 1, + "python_tracer_level": 1, + } + if xprof_options: + options.update(xprof_options) + + IS_GCS_PATH = str(log_dir).startswith("gs://") log_dir_ = pathlib.Path(log_dir if log_dir is not None else tempfile.mkdtemp()) - profiler_client.trace( - f"{host}:{port}", - str(log_dir_), + str_log_dir = log_dir if IS_GCS_PATH else str(log_dir_) + _pywrap_profiler_plugin.trace( + _strip_addresses(f"{host}:{port}", _GRPC_PREFIX), + str_log_dir, + '', + True, duration_in_ms, - options=options) - print(f"Dumped profiling information in: {log_dir_}") + DEFAULT_NUM_TRACING_ATTEMPTS, + options, + ) + print(f"Dumped profiling information in: {str_log_dir}") + # Traces stored on GCS cannot be converted to a Perfetto trace, as JAX doesn't + # directly support GCS paths. + if IS_GCS_PATH: + if not no_perfetto_link: + print("Perfetto link is not supported for GCS paths, skipping creation.") + return # The profiler dumps `xplane.pb` to the logging directory. To upload it to # the Perfetto trace viewer, we need to convert it to a `trace.json` file. # We do this by first finding the `xplane.pb` file, then passing it into @@ -91,7 +108,7 @@ def collect_profile(port: int, duration_in_ms: int, host: str, in root_trace_folder.iterdir()] latest_folder = max(trace_folders, key=os.path.getmtime) xplane = next(latest_folder.glob("*.xplane.pb")) - result, _ = convert.xspace_to_tool_data([xplane], "trace_viewer^", {}) + result, _ = convert.xspace_to_tool_data([xplane], "trace_viewer", {}) with gzip.open(str(latest_folder / "remote.trace.json.gz"), "wb") as fp: fp.write(result.encode("utf-8")) @@ -100,10 +117,56 @@ def collect_profile(port: int, duration_in_ms: int, host: str, path = jax_profiler._write_perfetto_trace_file(log_dir_) jax_profiler._host_perfetto_trace_file(path) -def main(args): - collect_profile(args.port, args.duration_in_ms, args.host, args.log_dir, - args.host_tracer_level, args.device_tracer_level, - args.python_tracer_level, args.no_perfetto_link) +def _strip_prefix(s, prefix): + return s[len(prefix):] if s.startswith(prefix) else s + +def _strip_addresses(addresses, prefix): + return ','.join([_strip_prefix(s, prefix) for s in addresses.split(',')]) + +def _parse_xprof_flags(unknown_flags: list[str]) -> dict[str, Any]: + parsed: dict[str, Any] = {} + i = 0 + while i < len(unknown_flags): + arg = unknown_flags[i] + if not arg.startswith('--'): + raise ValueError(f"Unknown positional argument encountered: {arg}") + + key = arg[2:] + if "=" in key: + key, value_str = key.split("=", 1) + i += 1 + elif i + 1 < len(unknown_flags) and not unknown_flags[i + 1].startswith('--'): + value_str = unknown_flags[i + 1] + i += 2 + else: + parsed[key] = True + i += 1 + continue + + value_lower = value_str.lower() + if value_lower in {'true', 't', 'yes', 'y'}: + parsed[key] = True + elif value_lower in {'false', 'f', 'no', 'n'}: + parsed[key] = False + else: + try: + parsed[key] = int(value_str, 0) + except ValueError: + parsed[key] = value_str # Keep as string + return parsed + + +def main(known_args, unknown_flags): + xprof_options = _parse_xprof_flags(unknown_flags) + collect_profile( + known_args.port, + known_args.duration_in_ms, + known_args.host, + known_args.log_dir, + known_args.no_perfetto_link, + xprof_options, + ) if __name__ == "__main__": - main(parser.parse_args()) + known_args, unknown_flags = parser.parse_known_args() + main(known_args, unknown_flags) diff --git a/jax/core.py b/jax/core.py index 3fd7af440d4a..a42aa4a950a4 100644 --- a/jax/core.py +++ b/jax/core.py @@ -15,36 +15,28 @@ # Note: import as is required for names to be exported. # See PEP 484 & https://github.com/jax-ml/jax/issues/7570 +import jax._src.core as _src_core from jax._src.core import ( - AbstractToken as AbstractToken, AbstractValue as AbstractValue, Atom as Atom, CallPrimitive as CallPrimitive, DebugInfo as DebugInfo, - DShapedArray as DShapedArray, DropVar as DropVar, Effect as Effect, Effects as Effects, - get_opaque_trace_state as get_opaque_trace_state, InconclusiveDimensionOperation as InconclusiveDimensionOperation, JaxprPpContext as JaxprPpContext, JaxprPpSettings as JaxprPpSettings, JaxprTypeError as JaxprTypeError, - nonempty_axis_env as nonempty_axis_env_DO_NOT_USE, # noqa: F401 OutputType as OutputType, ParamDict as ParamDict, ShapedArray as ShapedArray, Trace as Trace, Tracer as Tracer, - unsafe_am_i_under_a_jit as unsafe_am_i_under_a_jit_DO_NOT_USE, # noqa: F401 - unsafe_am_i_under_a_vmap as unsafe_am_i_under_a_vmap_DO_NOT_USE, # noqa: F401 - unsafe_get_axis_names as unsafe_get_axis_names_DO_NOT_USE, # noqa: F401 - UnshapedArray as UnshapedArray, Value as Value, abstract_token as abstract_token, aval_mapping_handlers as aval_mapping_handlers, call as call, - call_impl as call_impl, check_jaxpr as check_jaxpr, concrete_or_error as concrete_or_error, concretization_function_error as concretization_function_error, @@ -54,184 +46,86 @@ eval_jaxpr as eval_jaxpr, find_top_trace as find_top_trace, gensym as gensym, - get_aval as get_aval, + get_opaque_trace_state as get_opaque_trace_state, is_concrete as is_concrete, is_constant_dim as is_constant_dim, is_constant_shape as is_constant_shape, jaxprs_in_params as jaxprs_in_params, literalable_types as literalable_types, - mapped_aval as mapped_aval, max_dim as max_dim, min_dim as min_dim, new_jaxpr_eqn as new_jaxpr_eqn, no_axis_name as no_axis_name, no_effects as no_effects, + nonempty_axis_env as nonempty_axis_env_DO_NOT_USE, # noqa: F401 primal_dtype_to_tangent_dtype as primal_dtype_to_tangent_dtype, pytype_aval_mappings as pytype_aval_mappings, - set_current_trace as set_current_trace, - subjaxprs as subjaxprs, - take_current_trace as take_current_trace, trace_ctx as trace_ctx, - TraceTag as TraceTag, - traverse_jaxpr_params as traverse_jaxpr_params, - unmapped_aval as unmapped_aval, + unsafe_am_i_under_a_jit as unsafe_am_i_under_a_jit_DO_NOT_USE, # noqa: F401 + unsafe_am_i_under_a_vmap as unsafe_am_i_under_a_vmap_DO_NOT_USE, # noqa: F401 + unsafe_get_axis_names as unsafe_get_axis_names_DO_NOT_USE, # noqa: F401 valid_jaxtype as valid_jaxtype, ) - -from jax._src import core as _src_core _deprecations = { - # Added 2024-12-16 - "ClosedJaxpr": ("jax.core.ClosedJaxpr is deprecated. Use jax.extend.core.ClosedJaxpr instead, " - "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", - _src_core.ClosedJaxpr), - "Jaxpr": ("jax.core.Jaxpr is deprecated. Use jax.extend.core.Jaxpr instead, " - "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", - _src_core.Jaxpr), - "JaxprEqn": ("jax.core.JaxprEqn is deprecated. Use jax.extend.core.JaxprEqn instead, " - "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", - _src_core.JaxprEqn), - "Literal": ("jax.core.Literal is deprecated. Use jax.extend.core.Literal instead, " - "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", - _src_core.Literal), - "Primitive": ("jax.core.Primitive is deprecated. Use jax.extend.core.Primitive instead, " - "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", - _src_core.Primitive), - "Token": ("jax.core.Token is deprecated. Use jax.extend.core.Token instead, " - "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", - _src_core.Token), - "Var": ("jax.core.Var is deprecated. Use jax.extend.core.Var instead, " - "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", - _src_core.Var), - # Added 2024-12-11 - "axis_frame": ("jax.core.axis_frame is deprecated.", _src_core.axis_frame), - "AxisName": ("jax.core.AxisName is deprecated.", _src_core.AxisName), - "AxisSize": ("jax.core.AxisSize is deprecated.", _src_core.AxisSize), - "ConcretizationTypeError": ("jax.core.ConcretizationTypeError is deprecated; " - "use jax.errors.ConcretizationTypeError.", - _src_core.ConcretizationTypeError), - "EvalTrace": ("jax.core.EvalTrace is deprecated.", _src_core.EvalTrace), - "InDBIdx": ("jax.core.InDBIdx is deprecated.", _src_core.InDBIdx), - "InputType": ("jax.core.InputType is deprecated.", _src_core.InputType), - "MapPrimitive": ("jax.core.MapPrimitive is deprecated.", _src_core.MapPrimitive), - "OpaqueTraceState": ("jax.core.OpaqueTraceState is deprecated.", _src_core.OpaqueTraceState), - "OutDBIdx": ("jax.core.OutDBIdx is deprecated.", _src_core.OutDBIdx), - "TRACER_LEAK_DEBUGGER_WARNING": ("jax.core.TRACER_LEAK_DEBUGGER_WARNING is deprecated.", - _src_core.TRACER_LEAK_DEBUGGER_WARNING), - "call_p": ("jax.core.call_p is deprecated. Use jax.extend.core.primitives.call_p", - _src_core.call_p), - "closed_call_p": ("jax.core.closed_call_p is deprecated. Use jax.extend.core.primitives.closed_call_p", - _src_core.closed_call_p), - "concrete_aval": ("jax.core.concrete_aval is deprecated.", _src_core.abstractify), - "dedup_referents": ("jax.core.dedup_referents is deprecated.", _src_core.dedup_referents), - "escaped_tracer_error": ("jax.core.escaped_tracer_error is deprecated.", - _src_core.escaped_tracer_error), - "extend_axis_env_nd": ("jax.core.extend_axis_env_nd is deprecated.", - _src_core.extend_axis_env_nd), - "get_type": ("jax.core.get_type is deprecated.", _src_core.get_aval), - "get_referent": ("jax.core.get_referent is deprecated.", _src_core.get_referent), - "join_effects": ("jax.core.join_effects is deprecated.", _src_core.join_effects), - "leaked_tracer_error": ("jax.core.leaked_tracer_error is deprecated.", - _src_core.leaked_tracer_error), - "maybe_find_leaked_tracers": ("jax.core.maybe_find_leaked_tracers is deprecated.", - _src_core.maybe_find_leaked_tracers), - "raise_to_shaped_mappings": ("jax.core.raise_to_shaped_mappings is deprecated." - " It is unused as of jax v0.4.36.", - _src_core.raise_to_shaped_mappings), - "reset_trace_state": ("jax.core.reset_trace_state is deprecated.", - _src_core.reset_trace_state), - "str_eqn_compact": ("jax.core.str_eqn_compact is deprecated.", _src_core.str_eqn_compact), - "substitute_vars_in_output_ty": ("jax.core.substitute_vars_in_output_ty is deprecated.", - _src_core.substitute_vars_in_output_ty), - "trace_state_clean": ("jax.core.trace_state_clean is deprecated.", - _src_core.trace_state_clean), - "typecheck": ("jax.core.typecheck is deprecated.", _src_core.typecheck), - "typecompat": ("jax.core.typecompat is deprecated.", _src_core.typecompat), - "typematch": ("jax.core.typematch is deprecated.", _src_core.typematch), - "used_axis_names_jaxpr": ("jax.core.used_axis_names_jaxpr is deprecated.", - _src_core.used_axis_names_jaxpr), - # Added 2024-12-10 - "full_lower": ("jax.core.full_lower is deprecated. It is a no-op as of JAX v0.4.36.", - _src_core.full_lower), - "jaxpr_as_fun": ("jax.core.jaxpr_as_fun is deprecated. Use jax.extend.core.jaxpr_as_fun instead, " - "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", - _src_core.jaxpr_as_fun), - "lattice_join": ("jax.core.lattice_join is deprecated. It is a no-op as of JAX v0.4.36.", - _src_core.lattice_join), - "raise_to_shaped": ("jax.core.raise_to_shaped is deprecated. It is a no-op as of JAX v0.4.36.", - _src_core.raise_to_shaped), - # Finalized 2024-12-11; remove after 2025-3-11 - "check_eqn": ("jax.core.check_eqn was removed in JAX v0.4.38.", None), - "check_type": ("jax.core.check_type was removed in JAX v0.4.38.", None), - "check_valid_jaxtype": ( - ("jax.core.check_valid_jaxtype was removed in JAX v0.4.38. Instead, you can manually" - " raise an error if core.valid_jaxtype() returns False."), - None), - "non_negative_dim": ( - "jax.core.non_negative_dim was removed in JAX v0.4.38. Use max_dim(..., 0).", None, - ), - # Finalized 2024-09-25; remove after 2024-12-25 - "pp_aval": ("jax.core.pp_aval was removed in JAX v0.4.34.", None), - "pp_eqn": ("jax.core.pp_eqn was removed in JAX v0.4.34.", None), - "pp_eqn_rules": ("jax.core.pp_eqn_rules was removed in JAX v0.4.34.", None), - "pp_eqns": ("jax.core.pp_eqns was removed in JAX v0.4.34.", None), - "pp_jaxpr": ("jax.core.pp_jaxpr was removed in JAX v0.4.34.", None), - "pp_jaxpr_eqn_range": ("jax.core.pp_jaxpr_eqn_range was removed in JAX v0.4.34.", None), - "pp_jaxpr_skeleton": ("jax.core.pp_jaxpr_skeleton was removed in JAX v0.4.34.", None), - "pp_jaxprs": ("jax.core.pp_jaxprs was removed in JAX v0.4.34.", None), - "pp_kv_pair": ("jax.core.pp_kv_pair was removed in JAX v0.4.34.", None), - "pp_kv_pairs": ("jax.core.pp_kv_pairs was removed in JAX v0.4.34.", None), - "pp_var": ("jax.core.pp_var was removed in JAX v0.4.34.", None), - "pp_vars": ("jax.core.pp_vars was removed in JAX v0.4.34.", None), + # Added for v0.8.2 + "call_impl": ( + "jax.core.call_impl is deprecated.", + _src_core.call_impl, + ), + "get_aval": ( + "jax.core.get_aval is deprecated; use jax.typeof instead.", + _src_core.get_aval, + ), + "mapped_aval": ( + "jax.core.mapped_aval is deprecated. Use jax.extend.core.mapped_aval.", + _src_core.mapped_aval, + ), + "set_current_trace": ( + "jax.core.set_current_trace is deprecated.", + _src_core.set_current_trace, + ), + "subjaxprs": ( + "jax.core.subjaxprs is deprecated.", + _src_core.subjaxprs, + ), + "take_current_trace": ( + "jax.core.take_current_trace is deprecated.", + _src_core.take_current_trace, + ), + "traverse_jaxpr_params": ( + "jax.core.traverse_jaxpr_params is deprecated.", + _src_core.traverse_jaxpr_params, + ), + "unmapped_aval": ( + "jax.core.unmapped_aval is deprecated. Use jax.extend.core.unmapped_aval.", + _src_core.unmapped_aval, + ), + "AbstractToken": ( + "jax.core.AbstractToken is deprecated.", + _src_core.AbstractToken, + ), + "TraceTag": ( + "jax.core.TraceTag is deprecated.", + _src_core.TraceTag, + ), } -import typing -if typing.TYPE_CHECKING: - AxisName = _src_core.AxisName - AxisSize = _src_core.AxisSize - ClosedJaxpr = _src_core.ClosedJaxpr - ConcretizationTypeError = _src_core.ConcretizationTypeError - EvalTrace = _src_core.EvalTrace - InDBIdx = _src_core.InDBIdx - InputType = _src_core.InputType - Jaxpr = _src_core.Jaxpr - JaxprEqn = _src_core.JaxprEqn - Literal = _src_core.Literal - MapPrimitive = _src_core.MapPrimitive - OpaqueTraceState = _src_core.OpaqueTraceState - OutDBIdx = _src_core.OutDBIdx - Primitive = _src_core.Primitive - Token = _src_core.Token - TRACER_LEAK_DEBUGGER_WARNING = _src_core.TRACER_LEAK_DEBUGGER_WARNING - Var = _src_core.Var - axis_frame = _src_core.axis_frame - call_p = _src_core.call_p - closed_call_p = _src_core.closed_call_p - concrete_aval = _src_core.abstractify - dedup_referents = _src_core.dedup_referents - escaped_tracer_error = _src_core.escaped_tracer_error - extend_axis_env_nd = _src_core.extend_axis_env_nd - full_lower = _src_core.full_lower - get_type = _src_core.get_aval - get_referent = _src_core.get_referent - jaxpr_as_fun = _src_core.jaxpr_as_fun - join_effects = _src_core.join_effects - lattice_join = _src_core.lattice_join - leaked_tracer_error = _src_core.leaked_tracer_error - maybe_find_leaked_tracers = _src_core.maybe_find_leaked_tracers - raise_to_shaped = _src_core.raise_to_shaped - raise_to_shaped_mappings = _src_core.raise_to_shaped_mappings - reset_trace_state = _src_core.reset_trace_state - str_eqn_compact = _src_core.str_eqn_compact - substitute_vars_in_output_ty = _src_core.substitute_vars_in_output_ty - trace_state_clean = _src_core.trace_state_clean - typecheck = _src_core.typecheck - typecompat = _src_core.typecompat - typematch = _src_core.typematch - used_axis_names_jaxpr = _src_core.used_axis_names_jaxpr +import typing as _typing +if _typing.TYPE_CHECKING: + call_impl = _src_core.call_impl + get_aval = _src_core.get_aval + mapped_aval = _src_core.mapped_aval + subjaxprs = _src_core.subjaxprs + set_current_trace = _src_core.set_current_trace + take_current_trace = _src_core.take_current_trace + traverse_jaxpr_params = _src_core.traverse_jaxpr_params + unmapped_aval = _src_core.unmapped_aval + AbstractToken = _src_core.AbstractToken + TraceTag = _src_core.TraceTag else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr __getattr__ = _deprecation_getattr(__name__, _deprecations) del _deprecation_getattr -del typing +del _typing del _src_core diff --git a/jax/custom_derivatives.py b/jax/custom_derivatives.py index 3628ae4aaa6e..edefdae40c44 100644 --- a/jax/custom_derivatives.py +++ b/jax/custom_derivatives.py @@ -23,10 +23,8 @@ custom_gradient as custom_gradient, custom_jvp as custom_jvp, custom_jvp_call_p as custom_jvp_call_p, - custom_jvp_call_jaxpr_p as custom_jvp_call_jaxpr_p, custom_vjp as custom_vjp, custom_vjp_call_p as custom_vjp_call_p, - custom_vjp_call_jaxpr_p as custom_vjp_call_jaxpr_p, custom_vjp_primal_tree_values as custom_vjp_primal_tree_values, CustomVJPPrimal as CustomVJPPrimal, linear_call as linear_call, diff --git a/jax/debug.py b/jax/debug.py index 64578524d096..05a7091e1964 100644 --- a/jax/debug.py +++ b/jax/debug.py @@ -11,10 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -__all__ = ["callback", "print", "DebugEffect", "visualize_array_sharding", +__all__ = ["callback", "print", "log", "DebugEffect", + "visualize_array_sharding", "inspect_array_sharding", "visualize_sharding", "breakpoint"] from jax._src.debugging import debug_callback as callback from jax._src.debugging import debug_print as print +from jax._src.debugging import debug_log as log from jax._src.debugging import DebugEffect from jax._src.debugging import visualize_array_sharding from jax._src.debugging import inspect_array_sharding diff --git a/jax/dlpack.py b/jax/dlpack.py index a65496ec0cbf..6fa73748ee8b 100644 --- a/jax/dlpack.py +++ b/jax/dlpack.py @@ -13,7 +13,6 @@ # limitations under the License. from jax._src.dlpack import ( - to_dlpack as to_dlpack, from_dlpack as from_dlpack, - SUPPORTED_DTYPES as SUPPORTED_DTYPES, + is_supported_dtype as is_supported_dtype, ) diff --git a/jax/dtypes.py b/jax/dtypes.py index 4c1136360687..894af115f16b 100644 --- a/jax/dtypes.py +++ b/jax/dtypes.py @@ -17,6 +17,7 @@ from jax._src.dtypes import ( bfloat16 as bfloat16, + itemsize_bits as itemsize_bits, canonicalize_dtype as canonicalize_dtype, finfo, # TODO(phawkins): switch callers to jnp.finfo? # noqa: F401 float0 as float0, diff --git a/jax/errors.py b/jax/errors.py index 6da7b717cb5f..a4a6c5388db2 100644 --- a/jax/errors.py +++ b/jax/errors.py @@ -27,8 +27,7 @@ KeyReuseError as KeyReuseError, ) -from jax._src.lib import xla_client as _xc -JaxRuntimeError = _xc.XlaRuntimeError -del _xc - -from jax._src.traceback_util import SimplifiedTraceback as SimplifiedTraceback +from jax._src.lib import _jax +JaxRuntimeError = _jax.JaxRuntimeError +JaxRuntimeError.__module__ = "jax.errors" +del _jax diff --git a/jax/example_libraries/BUILD b/jax/example_libraries/BUILD new file mode 100644 index 000000000000..e740757c32a1 --- /dev/null +++ b/jax/example_libraries/BUILD @@ -0,0 +1,49 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//jaxlib:jax.bzl", "py_deps", "pytype_strict_library") + +package( + default_applicable_licenses = [], + default_visibility = ["//jax:internal"], +) + +pytype_strict_library( + name = "example_libraries", + srcs = [ + "__init__.py", + ], + visibility = ["//jax:internal"], +) + +pytype_strict_library( + name = "stax", + srcs = [ + "stax.py", + ], + visibility = ["//visibility:public"], + deps = ["//jax"], +) + +pytype_strict_library( + name = "optimizers", + srcs = [ + "optimizers.py", + ], + visibility = ["//visibility:public"], + deps = [ + "//jax", + "//jax/_src:util", + ] + py_deps("numpy"), +) diff --git a/jax/example_libraries/stax.py b/jax/example_libraries/stax.py index 476252d92d5d..dcd93b63fd78 100644 --- a/jax/example_libraries/stax.py +++ b/jax/example_libraries/stax.py @@ -24,7 +24,7 @@ """ import functools -import operator as op +import math from jax import lax from jax import random @@ -215,7 +215,7 @@ def rescale(outputs, inputs, spec): def Flatten(): """Layer construction function for flattening all but the leading dim.""" def init_fun(rng, input_shape): - output_shape = input_shape[0], functools.reduce(op.mul, input_shape[1:], 1) + output_shape = input_shape[0], math.prod(input_shape[1:]) return output_shape, () def apply_fun(params, inputs, **kwargs): return jnp.reshape(inputs, (inputs.shape[0], -1)) diff --git a/jax/experimental/BUILD b/jax/experimental/BUILD new file mode 100644 index 000000000000..9e3c7b5008a5 --- /dev/null +++ b/jax/experimental/BUILD @@ -0,0 +1,765 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load( + "//jaxlib:jax.bzl", + "buffer_callback_internal_users", + "experimental_transfer_users", + "if_cuda_is_configured", + "jax_visibility", + "mosaic_gpu_internal_users", + "mosaic_internal_users", + "pallas_fuser_users", + "pallas_gpu_internal_users", + "pallas_sc_internal_users", + "py_deps", + "py_library_providing_imports_info", + "pytype_strict_library", + "serialize_executable_internal_users", +) + +package( + default_applicable_licenses = [], + default_visibility = ["//jax:internal"], +) + +# Package groups for controlling visibility of experimental APIs. + +package_group( + name = "buffer_callback_users", + includes = ["//jax:internal"], + packages = buffer_callback_internal_users, +) + +package_group( + name = "experimental_transfer_users", + includes = ["//jax:internal"], + packages = experimental_transfer_users, +) + +package_group( + name = "mosaic_users", + includes = ["//jax:internal"], + packages = mosaic_internal_users, +) + +package_group( + name = "mosaic_gpu_users", + includes = ["//jax:internal"], + packages = mosaic_gpu_internal_users, +) + +package_group( + name = "pallas_fuser_users", + includes = ["//jax:internal"], + packages = pallas_fuser_users, +) + +package_group( + name = "pallas_gpu_users", + includes = ["//jax:internal"], + packages = pallas_gpu_internal_users, +) + +package_group( + name = "pallas_sc_users", + includes = ["//jax:internal"], + packages = pallas_sc_internal_users, +) + +package_group( + name = "serialize_executable_users", + includes = ["//jax:internal"], + packages = serialize_executable_internal_users, +) + +pytype_strict_library( + name = "buffer_callback", + srcs = [ + "buffer_callback.py", + ], + visibility = [":buffer_callback_users"], + deps = [ + "//jax/_src:buffer_callback", + ], +) + +pytype_strict_library( + name = "checkify", + srcs = [ + "checkify.py", + ], + visibility = [ + "//jax:internal", + ] + jax_visibility("checkify"), + deps = [ + "//jax/_src:checkify", + ], +) + +pytype_strict_library( + name = "colocated_python", + srcs = [ + "colocated_python/__init__.py", + "colocated_python/api.py", + "colocated_python/func.py", + "colocated_python/func_backend.py", + "colocated_python/obj.py", + "colocated_python/obj_backend.py", + "colocated_python/serialization.py", + ], + visibility = ["//visibility:public"], + deps = [ + "//jax", + "//jax/_src:api", + "//jax/_src:api_util", + "//jax/_src:config", + "//jax/_src:traceback_util", + "//jax/_src:tree_util", + "//jax/_src:util", + "//jax/_src:xla_bridge", + "//jax/_src/lib", + "//jax/extend:backend", + "//jax/extend:ifrt_programs", + ] + py_deps("numpy") + py_deps("cloudpickle"), +) + +pytype_strict_library( + name = "compilation_cache", + srcs = [ + "compilation_cache/__init__.py", + "compilation_cache/compilation_cache.py", + ], + visibility = ["//visibility:public"], + deps = ["//jax/_src:compilation_cache_internal"], +) + +pytype_strict_library( + name = "compute_on", + srcs = ["compute_on.py"], + visibility = ["//visibility:public"], + deps = [ + "//jax/_src:compute_on", + ], +) + +pytype_strict_library( + name = "custom_dce", + srcs = ["custom_dce.py"], + visibility = ["//visibility:public"], + deps = [ + "//jax/_src:custom_dce", + ], +) + +pytype_strict_library( + name = "custom_partitioning", + srcs = ["custom_partitioning.py"], + visibility = ["//visibility:public"], + deps = [ + "//jax/_src:custom_partitioning", + "//jax/_src:custom_partitioning_sharding_rule", + ], +) + +pytype_strict_library( + name = "jet", + srcs = ["jet.py"], + visibility = ["//visibility:public"], + deps = [ + "//jax", + "//jax/_src:ad_util", + "//jax/_src:api", + "//jax/_src:core", + "//jax/_src:lax", + "//jax/_src:partial_eval", + "//jax/_src:sharding_impls", + "//jax/_src:util", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "key_reuse", + srcs = glob(["key_reuse/**/*.py"]), + visibility = ["//jax:internal"], + deps = [ + "//jax", + "//jax/_src:api", + "//jax/_src:api_util", + "//jax/_src:core", + "//jax/_src:debugging", + "//jax/_src:effects", + "//jax/_src:hashable_array", + "//jax/_src:lax", + "//jax/_src:partial_eval", + "//jax/_src:random", + "//jax/_src:shard_map", + "//jax/_src:source_info_util", + "//jax/_src:traceback_util", + "//jax/_src:util", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "layout", + srcs = ["layout.py"], + visibility = ["//visibility:public"], + deps = [ + "//jax/_src:api", + "//jax/_src:layout", + ], +) + +pytype_strict_library( + name = "mesh_utils", + srcs = ["mesh_utils.py"], + visibility = ["//visibility:public"], + deps = [ + "//jax/_src:internal_mesh_utils", + ], +) + +pytype_strict_library( + name = "mosaic", + srcs = [ + "mosaic/__init__.py", + "mosaic/dialects.py", + ], + visibility = [":mosaic_users"], + deps = [ + "//jax/_src:tpu_custom_call", + "//jax/_src/lib", + ], +) + +# This target only supports sm_90 GPUs. +py_library_providing_imports_info( + name = "mosaic_gpu", + srcs = glob( + include = ["mosaic/gpu/*.py"], + exclude = ["mosaic/gpu/test_util.py"], + ), + data = if_cuda_is_configured([ + "@local_config_cuda//cuda:runtime_nvdisasm", + "@nvidia_nvshmem//:libnvshmem_device", + # OSS-only dependency. + "@cuda_nvcc//:nvdisasm", + "@cuda_nvvm//:nvvm", + ]), + lib_rule = pytype_strict_library, + visibility = [ + ":mosaic_gpu_users", + ], + deps = [ + "//jax", + "//jax/_src:config", + "//jax/_src:core", + "//jax/_src:dtypes", + "//jax/_src:mesh", + "//jax/_src:mlir", + "//jax/_src:sharding_impls", + "//jax/_src:stages", + "//jax/_src:util", + "//jax/_src/lib", + "//jax/extend:backend", + "//jaxlib/mlir:arithmetic_dialect", + "//jaxlib/mlir:builtin_dialect", + "//jaxlib/mlir:control_flow_dialect", + "//jaxlib/mlir:func_dialect", + "//jaxlib/mlir:gpu_dialect", + "//jaxlib/mlir:ir", + "//jaxlib/mlir:llvm_dialect", + "//jaxlib/mlir:math_dialect", + "//jaxlib/mlir:memref_dialect", + "//jaxlib/mlir:nvgpu_dialect", + "//jaxlib/mlir:nvvm_dialect", + "//jaxlib/mlir:pass_manager", + "//jaxlib/mlir:scf_dialect", + "//jaxlib/mlir:vector_dialect", + "//jaxlib/mosaic/python:gpu_dialect", + ] + py_deps("absl-all") + py_deps("numpy"), +) + +pytype_strict_library( + name = "mosaic_gpu_test_util", + testonly = True, + srcs = ["mosaic/gpu/test_util.py"], + deps = [ + ":mosaic_gpu", + "//jax", + "//jax/_src/lib", + ], +) + +pytype_strict_library( + name = "multihost_utils", + srcs = ["multihost_utils.py"], + visibility = ["//visibility:public"], + deps = [ + "//jax", + "//jax/_src:ad", + "//jax/_src:api", + "//jax/_src:batching", + "//jax/_src:core", + "//jax/_src:dtypes", + "//jax/_src:mlir", + "//jax/_src:random", + "//jax/_src:sharding_impls", + "//jax/_src:util", + "//jax/_src:xla_bridge", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "ode", + srcs = ["ode.py"], + visibility = ["//visibility:public"], + deps = [ + "//jax", + "//jax/_src:core", + "//jax/_src:numpy", + "//jax/_src:util", + ], +) + +pytype_strict_library( + name = "pallas", + srcs = glob( + [ + "pallas/**/*.py", + ], + exclude = [ + "pallas/mosaic_gpu.py", + "pallas/ops/gpu/**/*.py", + "pallas/ops/tpu/**/*.py", + "pallas/tpu.py", + "pallas/fuser.py", + "pallas/triton.py", + ], + ), + visibility = ["//visibility:public"], + deps = [ + "//jax", + "//jax/_src:deprecations", + "//jax/_src:lax", + "//jax/_src:source_info_util", + "//jax/_src:state_types", + "//jax/_src/pallas", + "//jax/_src/pallas/mosaic:sc_core", + "//jax/_src/pallas/mosaic:sc_primitives", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "pallas_fuser", + srcs = ["pallas/fuser.py"], + visibility = [ + ":pallas_fuser_users", + ], + deps = [ + ":pallas", # build_cleaner: keep + "//jax/_src/pallas/fuser:block_spec", + "//jax/_src/pallas/fuser:custom_evaluate", + "//jax/_src/pallas/fuser:custom_fusion", + "//jax/_src/pallas/fuser:fusible", + "//jax/_src/pallas/fuser:fusion", + "//jax/_src/pallas/fuser:jaxpr_fusion", + ], +) + +pytype_strict_library( + name = "pallas_gpu", + visibility = [ + ":pallas_gpu_users", + ], + deps = [ + ":pallas_triton", + # TODO(slebedev): Add :pallas_mosaic_gpu once it is ready. + ], +) + +pytype_strict_library( + name = "pallas_gpu_ops", + srcs = ["//jax/experimental/pallas/ops/gpu:triton_ops"], + visibility = [ + ":pallas_gpu_users", + ], + deps = [ + ":pallas", + ":pallas_gpu", + "//jax", + "//jax/_src:lax", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "pallas_experimental_gpu_ops", + srcs = ["//jax/experimental/pallas/ops/gpu:mgpu_ops"], + visibility = [ + ":mosaic_gpu_users", + ], + deps = [ + ":mosaic_gpu", + ":pallas", + ":pallas_mosaic_gpu", + "//jax", + "//jax/_src:dtypes", + "//jax/_src:test_util", # This is only to make them runnable as jax_multiplatform_test... + "//jax/_src/lib", + "//jax/extend:backend", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "pallas_mosaic_gpu", + srcs = ["pallas/mosaic_gpu.py"], + visibility = [ + ":mosaic_gpu_users", + ], + deps = [ + ":mosaic_gpu", + "//jax/_src/pallas/mosaic_gpu:core", + "//jax/_src/pallas/mosaic_gpu:helpers", + "//jax/_src/pallas/mosaic_gpu:pallas_call_registration", # build_cleaner: keep + "//jax/_src/pallas/mosaic_gpu:pipeline", + "//jax/_src/pallas/mosaic_gpu:primitives", + "//jax/_src/pallas/mosaic_gpu:torch", + ], +) + +pytype_strict_library( + name = "pallas_tpu", + srcs = ["pallas/tpu.py"], + visibility = ["//visibility:public"], + deps = [ + ":pallas", # build_cleaner: keep + "//jax/_src:deprecations", + "//jax/_src/pallas", + "//jax/_src/pallas/mosaic:core", + "//jax/_src/pallas/mosaic:helpers", + "//jax/_src/pallas/mosaic:lowering", + "//jax/_src/pallas/mosaic:pallas_call_registration", # build_cleaner: keep + "//jax/_src/pallas/mosaic:pipeline", + "//jax/_src/pallas/mosaic:primitives", + "//jax/_src/pallas/mosaic:random", + "//jax/_src/pallas/mosaic:tpu_info", + "//jax/_src/pallas/mosaic/interpret:interpret_pallas_call", + ], +) + +pytype_strict_library( + name = "pallas_tpu_sc", + srcs = ["pallas/tpu_sc.py"], + visibility = [ + ":pallas_sc_users", + ], + deps = [ + ":pallas", # build_cleaner: keep + "//jax/_src/pallas/mosaic:sc_core", + "//jax/_src/pallas/mosaic:sc_primitives", + ], +) + +pytype_strict_library( + name = "pallas_tpu_ops", + srcs = glob(["pallas/ops/tpu/**/*.py"]), + visibility = ["//visibility:public"], + deps = [ + ":pallas", + ":pallas_tpu", + "//jax", + "//jax/_src:dtypes", + "//jax/_src:random", + "//jax/_src:shard_map", + "//jax/_src:util", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "pallas_triton", + srcs = [ + "pallas/triton.py", + ], + visibility = [ + ":pallas_gpu_users", + ], + deps = [ + "//jax/_src:deprecations", + "//jax/_src/pallas", + "//jax/_src/pallas/triton:core", + "//jax/_src/pallas/triton:pallas_call_registration", # build_cleaner: keep + "//jax/_src/pallas/triton:primitives", + ], +) + +pytype_strict_library( + name = "pjit", + srcs = ["pjit.py"], + visibility = ["//visibility:public"], + deps = [ + "//jax/_src:api", + "//jax/_src:sharding_impls", + ], +) + +pytype_strict_library( + name = "hijax", + srcs = ["hijax.py"], + visibility = ["//visibility:public"], + deps = [ + "//jax/_src:ad", + "//jax/_src:ad_util", + "//jax/_src:core", + "//jax/_src:effects", + "//jax/_src:hijax", + "//jax/_src:lax", + ], +) + +pytype_strict_library( + name = "profiler", + srcs = ["profiler.py"], + visibility = ["//visibility:public"], + deps = [ + "//jax/_src/lib", + ], +) + +pytype_strict_library( + name = "random", + srcs = ["random.py"], + visibility = ["//visibility:public"], + deps = [ + "//jax/_src:stateful_rng", + ], +) + +pytype_strict_library( + name = "rnn", + srcs = ["rnn.py"], + visibility = ["//visibility:public"], + deps = [ + "//jax", + "//jax/_src:api", + "//jax/_src:core", + "//jax/_src:custom_derivatives", + "//jax/_src:lax", + "//jax/_src:typing", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "serialize_executable", + srcs = ["serialize_executable.py"], + visibility = [":serialize_executable_users"], + deps = [ + "//jax", + "//jax/_src/lib", + ], +) + +pytype_strict_library( + name = "scheduling_groups", + srcs = ["scheduling_groups.py"], + visibility = ["//visibility:public"], + deps = [ + "//jax/_src:ad", + "//jax/_src:api", + "//jax/_src:api_util", + "//jax/_src:batching", + "//jax/_src:core", + "//jax/_src:mlir", + "//jax/_src:partial_eval", + "//jax/_src:tree_util", + "//jax/_src:util", + "//jax/_src/lib", + ], +) + +pytype_strict_library( + name = "shard_alike", + srcs = ["shard_alike.py"], + visibility = ["//visibility:public"], + deps = [ + "//jax/_src:shard_alike", + ], +) + +pytype_strict_library( + # jax.experimental.shard_map is a legacy API and should not + # be used in new code. Use jax.shard_map instead. + name = "shard_map", + srcs = ["shard_map.py"], + visibility = [ + "//jax:internal", + ] + jax_visibility("experimental/shard_map"), + deps = [ + "//jax", + "//jax/_src:mesh", + "//jax/_src:shard_map", + "//jax/_src:traceback_util", + ], +) + +pytype_strict_library( + name = "fused", + srcs = ["fused.py"], + visibility = ["//visibility:public"], + deps = [ + "//jax/_src:ad", + "//jax/_src:api", + "//jax/_src:api_util", + "//jax/_src:batching", + "//jax/_src:core", + "//jax/_src:mlir", + "//jax/_src:partial_eval", + "//jax/_src:tree_util", + "//jax/_src:util", + "//jax/_src/lib", + ], +) + +pytype_strict_library( + name = "source_mapper", + srcs = glob(include = ["source_mapper/**/*.py"]), + visibility = ["//visibility:public"], + deps = [ + "//jax", + "//jax/_src:config", + "//jax/_src:core", + "//jax/_src:source_info_util", + "//jax/_src:sourcemap", + ] + py_deps("absl/flags"), +) + +pytype_strict_library( + name = "sparse", + srcs = glob( + [ + "sparse/*.py", + ], + exclude = ["sparse/test_util.py"], + ), + visibility = ["//visibility:public"], + deps = [ + "//jax", + "//jax/_src:ad", + "//jax/_src:api", + "//jax/_src:api_util", + "//jax/_src:batching", + "//jax/_src:config", + "//jax/_src:core", + "//jax/_src:custom_derivatives", + "//jax/_src:dtypes", + "//jax/_src:ffi", + "//jax/_src:lax", + "//jax/_src:mlir", + "//jax/_src:numpy", + "//jax/_src:partial_eval", + "//jax/_src:sharding_impls", + "//jax/_src:traceback_util", + "//jax/_src:typing", + "//jax/_src:util", + "//jax/_src/lib", + ] + py_deps("numpy") + py_deps("scipy"), +) + +pytype_strict_library( + name = "sparse_test_util", + srcs = [ + "sparse/test_util.py", + ], + visibility = ["//jax:internal"], + deps = [ + ":sparse", + "//jax", + "//jax/_src:lax", + "//jax/_src:test_util", + "//jax/_src:typing", + "//jax/_src:util", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "topologies", + srcs = ["topologies.py"], + visibility = ["//visibility:public"], + deps = [ + ":mesh_utils", + "//jax", + "//jax/_src:xla_bridge", + "//jax/_src/lib", + ], +) + +pytype_strict_library( + name = "transfer", + srcs = ["transfer.py"], + visibility = [ + ":experimental_transfer_users", + "//jax:internal", + ], + deps = [ + "//jax", + "//jax/_src:util", + "//jax/_src/lib", + ], +) + +pytype_strict_library( + name = "xla_metadata", + srcs = ["xla_metadata.py"], + visibility = ["//visibility:public"], + deps = [ + "//jax/_src:xla_metadata", + ], +) + +# TODO(dsuo): Remove this once experimental aliases from jax/BUILD are removed. +py_library_providing_imports_info( + name = "experimental", + srcs = [ + "__init__.py", + "x64_context.py", + ], + visibility = ["//visibility:public"], + deps = [ + "//jax/_src:api", + "//jax/_src:callback", + "//jax/_src:config", + "//jax/_src:core", + "//jax/_src:dtypes", + "//jax/_src:earray", + ], +) + +# TODO(dsuo): Remove these filegroups once experimental aliases from jax/BUILD +# are removed. +filegroup( + name = "jax_public", + srcs = glob([ + "key_reuse/**/*.py", + "roofline/**/*.py", + "compilation_cache/**/*.py", + ]) + [ + "checkify.py", + "fused.py", + "multihost_utils.py", + "pjit.py", + "scheduling_groups.py", + "shard_map.py", + ], + visibility = ["//jax:internal"], +) diff --git a/jax/experimental/__init__.py b/jax/experimental/__init__.py index 375d058d0edc..5df46bfe8bba 100644 --- a/jax/experimental/__init__.py +++ b/jax/experimental/__init__.py @@ -15,10 +15,12 @@ # Note: import as is required for names to be exported. # See PEP 484 & https://github.com/jax-ml/jax/issues/7570 -from jax.experimental.x64_context import ( - enable_x64 as enable_x64, - disable_x64 as disable_x64, -) +# Note: we discourage adding any new APIs directly here. Instead please consider +# adding them to a relevant or new submodule in jax.experimental. This approach +# gives the JAX team more granularity to manage access / visibility to +# experimental features and as a result, more flexibility to manage their status +# and lifetimes. + from jax._src.callback import ( io_callback as io_callback ) @@ -28,3 +30,34 @@ from jax._src.earray import ( EArray as EArray ) +from jax._src.core import ( + cur_qdd as cur_qdd, +) + +_deprecations = { + # Remove in v0.10.0 + "disable_x64": ( + ("jax.experimental.disable_x64 was removed in JAX v0.9.0;" + " use jax.enable_x64(False) instead."), + None, + ), + "enable_x64": ( + ("jax.experimental.enable_x64 was removed in JAX v0.9.0;" + " use jax.enable_x64(True) instead."), + None + ), + "mutable_array": ( + ("jax.experimental.mutable_array was removed in JAX v0.9.0;" + " use jax.new_ref instead."), + None, + ), + "MutableArray": ( + ("jax.experimental.MutableArray was removed in JAX v0.9.0;" + " use jax.Ref instead."), + None, + ), +} + +from jax._src.deprecations import deprecation_getattr as _deprecation_getattr +__getattr__ = _deprecation_getattr(__name__, _deprecations) +del _deprecation_getattr diff --git a/jax/experimental/_private_mm/examples/example_overlap.py b/jax/experimental/_private_mm/examples/example_overlap.py index 022eb3293dcc..f3c3726ec347 100644 --- a/jax/experimental/_private_mm/examples/example_overlap.py +++ b/jax/experimental/_private_mm/examples/example_overlap.py @@ -14,7 +14,8 @@ """An example showcasing overlap on a (forward-only) PP-like workload.""" from dataclasses import dataclass -from typing import Any, Callable +from typing import Any +from collections.abc import Callable import time import numpy as np diff --git a/jax/experimental/_private_mm/examples/example_pp.py b/jax/experimental/_private_mm/examples/example_pp.py index b43d1c743c28..846d96cb34a9 100644 --- a/jax/experimental/_private_mm/examples/example_pp.py +++ b/jax/experimental/_private_mm/examples/example_pp.py @@ -15,7 +15,8 @@ from dataclasses import dataclass from functools import cached_property, partial -from typing import Any, Callable +from typing import Any +from collections.abc import Callable import numpy as np diff --git a/jax/experimental/_private_mm/mini_dime.py b/jax/experimental/_private_mm/mini_dime.py index 971d5a016817..0c21417f7b2d 100644 --- a/jax/experimental/_private_mm/mini_dime.py +++ b/jax/experimental/_private_mm/mini_dime.py @@ -49,9 +49,9 @@ import jax import jax.numpy as jnp -import jaxlib.xla_extension as xe from jax._src import array -from jax._src.op_shardings import are_op_shardings_equal +from jax._src.lib import _jax +from jax._src.op_shardings import are_hlo_shardings_equal def _get_nccl_dtype_and_count(arr, count=None): @@ -66,10 +66,10 @@ def _get_nccl_dtype_and_count(arr, count=None): return nccl_dtype, count -def get_distributed_client() -> xe.DistributedRuntimeClient: +def get_distributed_client() -> _jax.DistributedRuntimeClient: from jax._src.distributed import global_state - assert isinstance(global_state.client, xe.DistributedRuntimeClient) + assert isinstance(global_state.client, _jax.DistributedRuntimeClient) return global_state.client @@ -145,7 +145,7 @@ def shardings_are_compatible( ): # NOTE: Variant of `jax.sharding.Sharding.is_equivalent_to` that skips _internal_device_list check return ( - are_op_shardings_equal( + are_hlo_shardings_equal( self._to_xla_hlo_sharding(ndim), other._to_xla_hlo_sharding(ndim) ) # and self._internal_device_list == other._internal_device_list # type: ignore diff --git a/jax/experimental/_private_mm/mm.py b/jax/experimental/_private_mm/mm.py index f47724ce6ec4..b108fb3e2e35 100644 --- a/jax/experimental/_private_mm/mm.py +++ b/jax/experimental/_private_mm/mm.py @@ -16,7 +16,7 @@ from dataclasses import dataclass from functools import cached_property, lru_cache, partial, wraps -from typing import Callable +from collections.abc import Callable import jax import jax.numpy as jnp diff --git a/jax/experimental/array_serialization/BUILD b/jax/experimental/array_serialization/BUILD index ab1ee3fd393e..559d8eb16269 100644 --- a/jax/experimental/array_serialization/BUILD +++ b/jax/experimental/array_serialization/BUILD @@ -35,9 +35,37 @@ pytype_library( "serialization.py", ], visibility = ["//visibility:public"], - deps = ["//jax"] + py_deps([ + deps = [ + "//jax", + "//jax/experimental/array_serialization:tensorstore_impl", + ] + py_deps([ + "absl/logging", + "numpy", + ]), +) + +pytype_library( + name = "pytree_serialization", + srcs = ["pytree_serialization.py"], + visibility = ["//visibility:public"], + deps = [ + "//jax", + "//jax/experimental/array_serialization:pytree_serialization_utils", + "//jax/experimental/array_serialization:tensorstore_impl", + ] + py_deps([ + "absl/logging", "numpy", + ]), +) + +pytype_library( + name = "pytree_serialization_utils", + srcs = ["pytree_serialization_utils.py"], + deps = [ + "//jax", + ] + py_deps([ "absl/logging", + "numpy", ]), ) @@ -45,10 +73,19 @@ jax_multiplatform_test( name = "serialization_test", srcs = ["serialization_test.py"], enable_configs = [ - "tpu_v3_2x2", + "tpu_v3_x4", ], deps = [ - "//jax:experimental", + "//jax/experimental/array_serialization:pytree_serialization", "//jax/experimental/array_serialization:serialization", ], ) + +pytype_library( + name = "tensorstore_impl", + srcs = ["tensorstore_impl.py"], + visibility = ["//visibility:public"], + deps = ["//jax"] + py_deps([ + "numpy", + ]), +) diff --git a/jax/experimental/array_serialization/pytree_serialization.py b/jax/experimental/array_serialization/pytree_serialization.py new file mode 100644 index 000000000000..dcc017aa5453 --- /dev/null +++ b/jax/experimental/array_serialization/pytree_serialization.py @@ -0,0 +1,513 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Serializations routines for pytrees including array and non-array serialization. +""" + +from __future__ import annotations + +from os import PathLike +import os +import re +from typing import Any +from uuid import uuid4, UUID +import json +import asyncio +import threading +from concurrent.futures import ThreadPoolExecutor +import shutil +import logging + +import jax +from jax._src import distributed +from jax._src.api_util import flatten_axes +from jax._src.layout import Format + +from jax.experimental import multihost_utils +from jax.experimental.array_serialization import tensorstore_impl as ts_impl +import jax.experimental.array_serialization.pytree_serialization_utils as utils +from jax._src import path as pathlib +import numpy as np + +logger = logging.getLogger(__name__) + +_THREADING_SAVE_LOCK = threading.Lock() + +_REMOTE_URL_PREFIXES = ['gs://', 's3://'] +_PYTREEDEF_FILE = "pytreedef.json" +_ARCHIVE_NAME = "archive.zip" +_USE_OCDBT = True # a lot of the code relies on this being True +_MAX_PATH_LENGTH = 4096 +_ARRAY_STORE_DIRNAME = "array_store" +_ARRAY_TYPE_FORMAT = "Array({dtype}[{shape}])" +_ARRAY_TYPE_REGEX = r"Array\(([a-zA-Z0-9_]+)\[([0-9, ]*)\]\)" +_MAX_CONCURRENCY = 32 +_TIMEOUT_SEC = 30 + +PyTreeT = Any + +__all__ = ["save", "load", "load_pytreedef", + "nonblocking_load", "nonblocking_save"] + + +def _get_unique_sync_key() -> str | None: + """Generate a thread-local key for ensuring all host finish (de)serializing""" + if jax.process_count() == 1: + return None + # broadcast a thread-local unique barrier name + sync_key_unique = multihost_utils.broadcast_one_to_all( + np.frombuffer(uuid4().bytes, dtype=np.int32)) + sync_key_id = UUID(bytes=np.array(sync_key_unique).tobytes()) + return f"jax_sync_key_{str(sync_key_id)}" + + +def _is_str_same_on_all_hosts(path: str | PathLike[str]) -> bool: + """All-gather the location of the checkpoint and check if it's the same.""" + if jax.process_count() <= 1: + return False + path_b = str(path).encode("utf-8") + if len(path_b) > _MAX_PATH_LENGTH: + raise ValueError(f"Path exceeds maximum length of {_MAX_PATH_LENGTH} in" + " multiprocess case.") + path_array = np.concatenate([ + np.frombuffer(path_b, dtype=np.uint8), np.zeros( + _MAX_PATH_LENGTH - len(path_b), dtype=np.uint8)]) + path_array = multihost_utils.process_allgather(path_array) + return bool(np.all(path_array[0] == path_array[1:])) + + +def _sync_on_key(key: str | None, extra_tag: str = "") -> None: + if key is None: + return + full_key = f"{key}-{extra_tag}" if extra_tag else key + if (client := distributed.global_state.client) is not None: + client.wait_at_barrier(full_key, timeout_in_ms=_TIMEOUT_SEC * 1000) + + +def _is_array_like(x): + return isinstance(x, (jax.Array, np.ndarray)) + + +def _leaf_to_desc(leaf) -> str: + if leaf is None: + return "null" + elif _is_array_like(leaf): + return _ARRAY_TYPE_FORMAT.format( + dtype=leaf.dtype.name, shape=", ".join(map(str, leaf.shape))) + else: + return type(leaf).__name__ + + +def _desc_to_leaf(leaf_desc: str | None) -> str | None | jax.ShapeDtypeStruct: + if leaf_desc is None: + return None + if not re.match(_ARRAY_TYPE_REGEX, leaf_desc): + return leaf_desc + shape_dtype_match = re.match(_ARRAY_TYPE_REGEX, leaf_desc) + assert shape_dtype_match is not None + dtype_str, shape_str = shape_dtype_match.groups() + shape = [int(x.strip()) for x in shape_str.strip("]").strip().split(",") + if len(x.strip()) > 0] + return jax.ShapeDtypeStruct(shape, jax.numpy.dtype(dtype_str)) + + +def _is_remote_path(path: str | PathLike[str]): + """Check whether a path is remote by examining the prefix.""" + # we need to truncate e.g., gs:// to gs:/ because pathlib.Path collapses // + return any(str(path).startswith(prefix[:-1]) + for prefix in _REMOTE_URL_PREFIXES) + + +def _norm_path(path: str | PathLike[str]) -> Any: + if _is_remote_path(path): + return pathlib.Path(path) + return pathlib.Path(path).expanduser().resolve() + + +def _rm_dir(root: Any) -> None: + if _is_remote_path(root): + root.rmtree() # pytype: disable=attribute-error + else: + shutil.rmtree(root) + + +def _set_up_destination(root: str | PathLike[str], overwrite: bool, + pytree_repr: dict[str, Any], distinct_locations: bool, + sync_key: str | None) -> dict[str, Any]: + """Inspect the destination, set it up for writing, potentially read existing data.""" + root = _norm_path(root) + if overwrite: + if root.exists() and len(list(root.iterdir())) > 0: + # check that we're only deleting things that come from JAX + # refuse to rm directories containing additional entries + extra_member_paths = [ + path for path in list(root.iterdir()) if path.name not in + (_PYTREEDEF_FILE, _ARCHIVE_NAME, _ARRAY_STORE_DIRNAME)] + + if len(extra_member_paths) != 0: + raise RuntimeError( + "Refusing to work on a directory that is not a previous checkpoint." + f" Unrecognized paths: {extra_member_paths}. Remove them manually" + f" if you're sure you want to use {root} as the checkpoint" + " directory.") + + if (jax.process_index() == 0 or distinct_locations) and root.exists(): + _rm_dir(root) + _sync_on_key(sync_key, "overwrite") + return pytree_repr + else: + if (root.exists() and len(list(root.iterdir())) > 0): # not empty + raise ValueError(f"Files already exist at path: `{root}`, but you" + f" specified `{overwrite=}`") + return pytree_repr + + +def _prepare_directory(root: str | PathLike[str], overwrite: bool, + pytreedef_repr: dict[str, Any], distinct_locations: bool, + sync_key: str | None): + """Prepare the directory: check destination, potentially read existing data + and overwrite. + + Raises: + RuntimeError: If the destination directory cannot be created. + """ + root = _norm_path(root) + # prepare the destination directory, overwrite destination directory or error + pytreedef_repr = _set_up_destination( + root, overwrite, pytreedef_repr, distinct_locations, sync_key) + + if not _is_remote_path(root) and (distinct_locations + or jax.process_index() == 0): + root.mkdir(exist_ok=True) # do not make parents, that's too much + if not root.exists() or not root.is_dir(): + raise RuntimeError(f"Could not create destination directory at {root}") + _sync_on_key(sync_key, "mkdir") + return pytreedef_repr + + +def _write_arrays(array_store_path: Any, arrs: list[Any], + arr_leaf_ids: list[int], ts_specs: list[Any | None], + distinct_locations: bool): + paths = [array_store_path / str(leaf_id) for leaf_id in arr_leaf_ids] + process_idx = None + if not distinct_locations and jax.process_count() > 1: + process_idx = jax.process_index() + default_ts_specs = [ts_impl.get_tensorstore_spec(path, ocdbt=_USE_OCDBT, + process_idx=process_idx, + arr=arr) + for (path, arr) in zip(paths, arrs)] + ts_specs = [ts_impl.merge_nested_ts_specs(default_ts_spec, ts_spec) + for (default_ts_spec, ts_spec) in zip(default_ts_specs, ts_specs)] + + # sanity check the ts specs + if len(ts_specs) > 0: # verify the base path is shared for all arrays + expected_path = ts_specs[0]["kvstore"]["base"]["path"] # shared base path + for ts_spec, arr in zip(ts_specs, arrs): + ts_impl.verify_tensorstore_spec(ts_spec, arr, expected_path, + ocdbt=_USE_OCDBT, check_metadata=True) + + async def _serialize_arrays(): + await asyncio.gather(*[ + ts_impl.async_serialize(arr, ts_spec, primary_host=None) + for (arr, ts_spec) in zip(arrs, ts_specs)]) + + asyncio.run(_serialize_arrays()) + + +def _finalize_array_store(kvstore_path, distinct_locations: bool): + """When multiple processes are writing, they must write to a per-process + location followed by combining them via no-copy links to the final location. + """ + # only in multiprocess case and only process 0 + if distinct_locations or jax.process_count() == 1 or jax.process_index() != 0: + return + dummy_key_path = os.path.join(kvstore_path, "dummy_key") + combined_kvstore = ts_impl.get_tensorstore_spec( + dummy_key_path, ocdbt=True, process_idx=None)["kvstore"] + children_kvstores = [ts_impl.get_tensorstore_spec( + dummy_key_path, ocdbt=True, process_idx=i)["kvstore"] + for i in range(jax.process_count())] + _ = combined_kvstore.pop("path") + _ = [kvstore.pop("path") for kvstore in children_kvstores] + asyncio.run(ts_impl.combine_kvstores(combined_kvstore, children_kvstores)) + + +def _write_pytreedef(directory: Any, pytree_repr: dict[str, Any], + distinct_locations: bool): + """Write the pytreedef to the destination directory and aux data to the archive.""" + if not (jax.process_index() == 0 or distinct_locations): + return + root = _norm_path(directory) + (root / _PYTREEDEF_FILE).write_text(json.dumps(pytree_repr, indent=2)) + + +def _tree_broadcast(a, b, is_leaf=lambda x: x is None): + """Broadcast the prefix tree `a` to the full tree `b` + + Uses `flatten_axes` for better error messages on mismatched arity but allowing + for custom is_leaf in the `a` and `b` trees. + """ + a_leaves, a_struct = jax.tree.flatten(a, is_leaf=is_leaf) + a_idx2leaf_map = dict(enumerate(a_leaves)) + a_idx = jax.tree.unflatten(a_struct, a_idx2leaf_map.keys()) + a_idx_broadcast = flatten_axes("tree_broadcast", + jax.tree.structure(b, is_leaf=is_leaf), a_idx) + return jax.tree.map(lambda i: a_idx2leaf_map[i], a_idx_broadcast) + + +_serialization_executor = ThreadPoolExecutor(max_workers=_MAX_CONCURRENCY) + + +def save(data: PyTreeT, directory: str | PathLike[str], *, + overwrite: bool = True, ts_specs: PyTreeT | None = None) -> None: + """Saves the given data structure to the provided directory path. + + This function provides functionality to serialize and save a data structure + comprising JAX arrays, along with its structure to a given directory. It + leverages `PyTree` for flattening and reconstructing the data structure. + + This is a simple experimental array serialization API, for anything more + complex and for all checkpointing prefer: https://github.com/google/orbax + + Args: + data: The data structure to be saved. Arbitrary composition of JAX arrays, + including nested structures. + directory: The directory path where the data will be saved. A local path or + a remote URL (e.g., gs://, s3://). For remote URLs, `etils` is required. + overwrite: If True, any existing directory with the same name will be + overwritten. + ts_specs: Optional tensorstore specs to use for serialization. If None, + defaults to using the default tensorstore specs. + + Example: + >>> data = {"a": jnp.array([1, 2]), "b": None} + >>> save(data, directory) + """ + with _THREADING_SAVE_LOCK: + return _save(data, directory, overwrite=overwrite, ts_specs=ts_specs) + + +def _save(data: PyTreeT, directory: str | PathLike[str], *, + overwrite: bool = True, ts_specs: PyTreeT | None = None) -> None: + sync_key = _get_unique_sync_key() # get a synchronization key for multi-host + + if _is_remote_path(directory) and not pathlib.epath_installed: + raise RuntimeError("For saving to remote URLs (e.g., gs, s3) you need the" + " `etils` module installed. You can install it using" + " `pip install etils`.") + ts_specs = _tree_broadcast(ts_specs, data, + is_leaf=ts_impl.is_tensorstore_spec_leaf) + data_flat, pytreedef = jax.tree.flatten(data, is_leaf=lambda x: x is None) + if not all(x is None or _is_array_like(x) for x in data_flat): + raise ValueError("For serialization, all leaves must be either None or" + " jax.Array-like objects.") + distinct_locations = not _is_str_same_on_all_hosts(directory) + if jax.process_count() > 1 and distinct_locations: + raise ValueError( + "Saving to different locations on different hosts is not supported," + " because it is extremely fragile. Consider using a single location.") + root = _norm_path(directory) + + # 1. serialize the pytree ################################# + pytreedef_repr = utils.serialize_pytreedef(pytreedef) + pytreedef_repr[utils._LEAF_IDS_KEY] = jax.tree.map(_leaf_to_desc, data_flat) + + pytreedef_repr = _prepare_directory( + root, overwrite, pytreedef_repr, distinct_locations, sync_key) + futures = [] + futures.append(_serialization_executor.submit( + _write_pytreedef, root, pytreedef_repr, distinct_locations)) + + # 2. serialize arrays ##################################### + array_store_path = root / _ARRAY_STORE_DIRNAME + arrs = [data for data in data_flat if _is_array_like(data)] + arr_leaf_ids = [i for i, data in enumerate(data_flat) if _is_array_like(data)] + ts_specs_flat = jax.tree.leaves(ts_specs, + is_leaf=ts_impl.is_tensorstore_spec_leaf) + ts_specs_flat = [ts_specs_flat[i] for i in arr_leaf_ids] + futures.append(_serialization_executor.submit( + _write_arrays, array_store_path, arrs, arr_leaf_ids, ts_specs_flat, + distinct_locations)) + + # 3. wait for all futures to complete ##################### + _ = [fut.result() for fut in futures] + _sync_on_key(sync_key, "array_serialization") + + # 4. finalize the array writing ########################### + if len(arr_leaf_ids) > 0 and _USE_OCDBT: + _serialization_executor.submit( # call from a thread to not nest asyncio + _finalize_array_store, array_store_path, distinct_locations).result() + # we are done with all async ops here, we can block #### + _sync_on_key(sync_key, "end") + + +def _read_arrays(array_store_path: str | PathLike[str], arr_leaf_ids: list[int], + ts_specs: list[Any], shardings: list[Any]): + # array_store_path = root / _LEAF_DATA_DIR / _ARRAY_STORE_DIRNAME + arr_store_path = _norm_path(array_store_path) + arr_paths = [arr_store_path / str(leaf_id) for leaf_id in arr_leaf_ids] + + # byte limiter to limit number of parallel reads, resizes to largest read + byte_limiter = ts_impl._LimitInFlightBytes(10 * 1024 ** 3) # 10 GB + + default_ts_specs = [ts_impl.get_tensorstore_spec(path, ocdbt=_USE_OCDBT, + process_idx=None) + for path in arr_paths] + ts_specs = [ts_impl.merge_nested_ts_specs(default_ts_spec, ts_spec) + for (default_ts_spec, ts_spec) in zip(default_ts_specs, ts_specs)] + + if len(ts_specs) > 0: # verify the base path is shared for all arrays + expected_path = ts_specs[0]["kvstore"]["base"]["path"] # shared base path + for ts_spec in ts_specs: + ts_impl.verify_tensorstore_spec(ts_spec, arr=None, path=expected_path, + ocdbt=_USE_OCDBT, check_metadata=False) + + async def _deserialize_arrays(): + return await asyncio.gather(*[ + ts_impl.async_deserialize(sharding, ts_spec, byte_limiter=byte_limiter) + for (sharding, ts_spec) in zip(shardings, ts_specs)]) + + return dict(zip(arr_leaf_ids, asyncio.run(_deserialize_arrays()))) + + +def load_pytreedef(directory: str | PathLike[str]) -> PyTreeT: + """Loads a pytree from the given directory. + + This is a simple experimental array serialization API, for anything more + complex and for all checkpointing prefer: https://github.com/google/orbax + + Args: + directory: Directory path to load from. + Returns: + The loaded pytree with arrays represented as jax.ShapeDtypeStruct's. + """ + assert not _is_remote_path(directory) or pathlib.epath_installed, ( + "For checkpointing using remote URLs (e.g., gs, s3) you need `etils`" + " module installed. You can install it using `pip install etils`.") + json_content = (_norm_path(directory) / _PYTREEDEF_FILE).read_text() + raw_tree = json.loads(json_content) + leaves = map(_desc_to_leaf, raw_tree[utils._LEAF_IDS_KEY]) + return jax.tree.unflatten(utils.deserialize_pytreedef(raw_tree), leaves) + + +def load(directory: str | PathLike[str], shardings: PyTreeT, *, + mask: PyTreeT | None = None, ts_specs: PyTreeT | None = None + ) -> PyTreeT: + """Loads and reconstructs a data structure from a directory. + + This is a simple experimental array serialization API, for anything more + complex and for all checkpointing prefer: https://github.com/google/orbax + + Args: + directory: Directory path where the data is stored. + shardings: Sharding strategy for array objects, either a Sharding or a + ShapeDtypeStruct with a Sharding/Format. + mask: boolean prefix tree for partial loading, will return None for False + leaves. + ts_specs: Optional tensorstore specs to use for deserialization. If None, + defaults to using the default tensorstore specs. + + Returns: + Reconstructed data. + + Example: + >>> save(data, directory) + >>> restored_data = load(directory, SingleDeviceSharding(jax.devices()[0])) + """ + assert not _is_remote_path(directory) or pathlib.epath_installed, ( + "For checkpointing using remote URLs (e.g., gs, s3) you need `etils`" + " module installed. You can install it using `pip install etils`.") + + root = _norm_path(directory) + assert root.is_dir(), f"Checkpoint directory {root} does not exist" + is_leaf = lambda x: x is None + + # deserialize PyTreeDef + pytree = load_pytreedef(directory) + # broadcast the (prefix) shardings and tensorstore specs to the full pytree + shardings = _tree_broadcast(shardings, pytree) + ts_specs = _tree_broadcast(ts_specs, pytree, + is_leaf=ts_impl.is_tensorstore_spec_leaf) + if mask is not None: + _prefix_mask = lambda m, x: jax.tree.map(lambda _: None, x) if not m else x + pytree = jax.tree.map(_prefix_mask, mask, pytree) + pytreedef = jax.tree.structure(pytree, is_leaf=is_leaf) + leaf_ids_flat = jax.tree.leaves(pytree, is_leaf=is_leaf) + shardings_flat = jax.tree.leaves(shardings, is_leaf=is_leaf) + if any(isinstance(shardings, Format) for shardings in shardings_flat): + raise NotImplementedError( + "Deserialization with `Format` instead of `Sharding` is not currently" + " supported. Pass ShapeDtypeStruct(shape, dtype, sharding=format)" + " instead.") + ts_specs_flat = jax.tree.leaves(ts_specs, + is_leaf=ts_impl.is_tensorstore_spec_leaf) + + # deserialize array objects + arr_leaf_ids = [i for i, leaf_id in enumerate(leaf_ids_flat) + if leaf_id is not None] + shardings_flat = [shardings_flat[i] for i in arr_leaf_ids] + ts_specs_flat = [ts_specs_flat[i] for i in arr_leaf_ids] + + arrs_fut = _serialization_executor.submit( + _read_arrays, root / _ARRAY_STORE_DIRNAME, arr_leaf_ids, ts_specs_flat, + shardings_flat) + + arrs = arrs_fut.result() + filled_values = [arrs.get(i, None) for i, _ in enumerate(leaf_ids_flat)] + return jax.tree.unflatten(pytreedef, filled_values) + + +def nonblocking_save(data: PyTreeT, directory: str | PathLike[str], *, + overwrite: bool = True, ts_specs: PyTreeT | None = None + ) -> utils.PyTreeFuture: + """Nonblocking alias of save, return an awaitable future with a pytree stub. + + This is a simple experimental array serialization API, for anything more + complex and for all checkpointing prefer: https://github.com/google/orbax + + Examples: + >>> fut = nonblocking_save(data, directory) + >>> print(fut.pytree) # a pytree of jax.ShapeDtypeStruct's + >>> print(fut.result()) # None, blocking until the serialization is done + """ + # start serialization immediately + fut = utils.PyTreeFuture(_serialization_executor.submit( + save, data, directory, overwrite=overwrite, ts_specs=ts_specs)) + # construct a nice looking pytree representing the nodes being read + fut.pytree = jax.tree.map(lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype) + if _is_array_like(x) else x, data) + return fut + + +def nonblocking_load(directory: str | PathLike[str], shardings: PyTreeT, *, + mask: PyTreeT | None = None, + ts_specs: PyTreeT | None = None) -> utils.PyTreeFuture: + """Nonblocking alias of load, return an awaitable future with a pytree stub. + + This is a simple experimental array serialization API, for anything more + complex and for all checkpointing prefer: https://github.com/google/orbax + + Examples: + >>> fut = nonblocking_load(directory) + >>> print(fut.pytree) # a pytree of jax.ShapeDtypeStruct + >>> print(fut.result()) # the fully populated pytree + """ + # TODO(rdyro): the awaitable future output is a workaround + # it should return the fully populated pytree instead of just + # jax.ShapeDtypeStruct for arrays by constructing them asynchronously + fut = utils.PyTreeFuture(_serialization_executor.submit( + load, directory, shardings, mask=mask, ts_specs=ts_specs)) + fut.pytree = load_pytreedef(directory) + return fut diff --git a/jax/experimental/array_serialization/pytree_serialization_utils.py b/jax/experimental/array_serialization/pytree_serialization_utils.py new file mode 100644 index 000000000000..4a6a42243ae8 --- /dev/null +++ b/jax/experimental/array_serialization/pytree_serialization_utils.py @@ -0,0 +1,84 @@ +# Copyright 2021 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# + +# # Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Utilities for representing pytreedefs in a serializable format. +""" + +import base64 +import logging +from types import ModuleType +from concurrent.futures import Future +from typing import Any, TypeVar + +import jax +from jax._src.export.serialization import (flatbuffers, _serialize_pytreedef, + _deserialize_pytreedef_to_pytree, + ser_flatbuf) +from jax.export import register_pytree_node_serialization # pylint: disable=unused-import + +T = TypeVar("T") +PickleModule = ModuleType +logger = logging.getLogger(__name__) + +_READABLE_PYTREE_SERIALIZATION = True +_TREE_REPR_KEY = "__jax_pytreedef_repr" +_LEAF_IDS_KEY = "__jax_leaf_ids" + +_NOT_REGISTERED_MESSAGE = ( + " * If you want to register a custom leaf, register it via" + " `register_pytree_leaf_serialization` first.\n" + " * If you want to register a custom node, register is via" + " `register_pytree_node_serialization`") + +__all__ = ["serialize_pytreedef", "deserialize_pytreedef", + "register_pytree_node_serialization"] + +class PyTreeFuture(Future[Any]): + """A wrapper around a Future that makes it look like an async function.""" + def __init__(self, future: Future[Any]): + self._future, self.pytree = future, None + + def done(self): + return self._future.done() + + def result(self, *args, **kw): + return self._future.result(*args, **kw) + + def __await__(self): + while not self.done(): + yield + return self.result() + + def __repr__(self): + return f"PyTreeFuture(done={self.done()}, pytree={self.pytree})" + + +def serialize_pytreedef(node) -> dict[str, Any]: + builder = flatbuffers.Builder(65536) + exported = _serialize_pytreedef(builder, node) + builder.Finish(exported) + root_repr = base64.b64encode(builder.Output()).decode("utf-8") + leaf_count = node.num_leaves + pytree_repr = {_TREE_REPR_KEY: root_repr, + _LEAF_IDS_KEY: list(range(leaf_count))} + return pytree_repr + + +def deserialize_pytreedef(pytreedef_repr: dict[str, Any]): + buf = base64.b64decode(pytreedef_repr[_TREE_REPR_KEY]) + exp = ser_flatbuf.PyTreeDef.GetRootAs(buf) + treestruct = jax.tree.structure(_deserialize_pytreedef_to_pytree(exp)) + return treestruct diff --git a/jax/experimental/array_serialization/serialization.py b/jax/experimental/array_serialization/serialization.py index 8a082b6e912d..6f2e1a2c73b4 100644 --- a/jax/experimental/array_serialization/serialization.py +++ b/jax/experimental/array_serialization/serialization.py @@ -17,34 +17,44 @@ import abc import asyncio -from collections.abc import Awaitable, Callable, Sequence -from functools import partial +from collections.abc import Callable, Sequence +import functools import itertools import logging -import os import re import threading import time -from typing import Any, Optional +from typing import Any import jax from jax._src import array from jax._src import distributed from jax._src import sharding -from jax._src.layout import Layout from jax._src import typing from jax._src import util -from jax._src.lib import xla_extension as xe -import jax.numpy as jnp -import numpy as np +from jax._src.layout import Format +from jax._src.lib import _jax +from jax.experimental.array_serialization import tensorstore_impl as ts_impl +# ruff: noqa: F401 +# pylint: disable=unused-import +# import tensorstore-backed methods for backward compatibility. +from jax.experimental.array_serialization.tensorstore_impl import ( + _run_deserialization as run_deserialization, + _run_serialization as run_serialization, + async_serialize, async_deserialize, _TS_CONTEXT as TS_CONTEXT, + _DEFAULT_BASE_DRIVER as _DEFAULT_DRIVER, _LimitInFlightBytes) import tensorstore as ts +# for compatibility with older zarr format +_get_metadata = functools.partial(ts_impl._get_tensorstore_metadata, + driver='zarr') +get_tensorstore_spec = functools.partial(ts_impl.get_tensorstore_spec, + driver='zarr', ocdbt=False) +# pylint: enable=unused-import + -TS_CONTEXT = ts.Context({'file_io_concurrency': {'limit': 128}}) -_REMOVED_VALUE = 'Value removed' _CHECKPOINT_SUCCESS = 'checkpoint_write_success' _module_unique_count = itertools.count() -_DEFAULT_DRIVER = 'file' _DISTRIBUTED_SYSTEM_MSG = ( 'Please initialize the distributed system via ' '`jax.distributed.initialize()` at the start of your program.') @@ -54,7 +64,7 @@ {'driver': 's3', 'path_regex': None}, ] -class BarrierTimeoutException(Exception): +class BarrierTimeoutError(Exception): pass _BARRIER_TIMED_OUT_MSG = ( @@ -66,68 +76,6 @@ class BarrierTimeoutException(Exception): logger = logging.getLogger(__name__) -async def create_async_array_from_callback( - global_shape: array.Shape, - inp_sharding: jax.sharding.Sharding, - data_callback: Callable[[array.Index, jax.Device], Awaitable[jax.Array]], -): - device_to_index_map = inp_sharding.devices_indices_map(global_shape) - addressable_da = inp_sharding._addressable_device_assignment - future_arrays = [data_callback(device_to_index_map[d], d) - for d in addressable_da] - dbs = await asyncio.gather(*future_arrays) - return array.make_array_from_single_device_arrays( - global_shape, inp_sharding, dbs) - - -def _get_metadata(arr): - local_shape = arr.addressable_data(0).shape - return { - 'compressor': {'id': 'zstd'}, - 'shape': arr.shape, - 'chunks': np.array(np.maximum(1, local_shape)), - } - - -def _spec_has_metadata(tree): - if not isinstance(tree, dict): - return False - return 'metadata' in tree or any( - _spec_has_metadata(subtree) for _, subtree in tree.items()) - -def _get_kvstore_for_gcs(ckpt_path: str): - m = re.fullmatch('^gs://([^/]*)/(.*)$', ckpt_path, re.DOTALL) - if m is None: - raise ValueError('The ckpt_path should contain the bucket name and the ' - f'file path inside the bucket. Got: {ckpt_path}') - gcs_bucket = m.group(1) - path_without_bucket = m.group(2) - return {'driver': 'gcs', 'bucket': gcs_bucket, 'path': path_without_bucket} - -def get_tensorstore_spec(ckpt_path: str, ocdbt: bool = False): - # Normalize path to exclude trailing '/'. In GCS path case, we will need to - # fix the path prefix to add back the stripped '/'. - ckpt_path = os.path.normpath(ckpt_path).replace('gs:/', 'gs://') - is_gcs_path = ckpt_path.startswith('gs://') - spec = {'driver': 'zarr', 'kvstore': {}} - if ocdbt: - if not is_gcs_path and not os.path.isabs(ckpt_path): - raise ValueError(f'Checkpoint path should be absolute. Got {ckpt_path}') - base_path = os.path.dirname(ckpt_path) - spec['kvstore'] = { - 'driver': 'ocdbt', - 'base': base_path if is_gcs_path else f'{_DEFAULT_DRIVER}://{base_path}', - 'path': os.path.basename(ckpt_path), - } - else: - if is_gcs_path: - spec['kvstore'] = _get_kvstore_for_gcs(ckpt_path) - else: - spec['kvstore'] = {'driver': _DEFAULT_DRIVER, 'path': ckpt_path} - - return spec - - def is_remote_storage(tspec: dict[str, Any] | str) -> bool: """Detect if user is using cloud storages. @@ -157,278 +105,6 @@ def is_remote_storage(tspec: dict[str, Any] | str) -> bool: return False - -# Lifted from T5X. -class _LimitInFlightBytes: - """Limits in-flight bytes when reading/writing checkpoints per process.""" - - def __init__(self, num_bytes): - self._max_bytes = num_bytes - self._available_bytes = num_bytes - self._cv = asyncio.Condition(lock=asyncio.Lock()) - - async def wait_for_bytes(self, requested_bytes): - if requested_bytes > self._max_bytes: - raise ValueError('Requested more bytes than we reserved space for: ' - f'{requested_bytes} > {self._max_bytes}') - async with self._cv: - await self._cv.wait_for(lambda: self._available_bytes > requested_bytes) - self._available_bytes -= requested_bytes - assert self._available_bytes >= 0 - - async def release_bytes(self, requested_bytes): - async with self._cv: - self._available_bytes += requested_bytes - assert self._available_bytes <= self._max_bytes - self._cv.notify_all() - - -async def transfer_shard_to_host(shard: array.Shard) -> np.ndarray: - data = shard.data - has_pinned_host = any( - m.kind == "pinned_host" for m in shard.device.addressable_memories()) - if has_pinned_host: - # If available, transfer to pinned host memory - sharding = jax.sharding.SingleDeviceSharding(shard.device, - memory_kind="pinned_host") - data = jax.device_put(data, sharding) - else: - data.copy_to_host_async() - # Allow other transfers to be scheduled simultaneously - await asyncio.sleep(0) - # Ensure that jax.Array's internal numpy array can be zero-copied. Tensorstore - # implicitly converts the written data to a numpy array, and would otherwise - # silently copy host-to-host. - return np.array(data, copy=False) - - -async def async_serialize( - arr_inp, - tensorstore_spec, - commit_future=None, - context=TS_CONTEXT, - primary_host: int | None = 0, - replica_id: int = 0, - transaction: Optional[ts.Transaction] = None, -): - """Serialize an array using TensorStore. - - Args: - arr_inp: The array to serialize. - tensorstore_spec: The tensorstore spec to use. - commit_future: A list of futures that will be appended to. The futures can - be awaited asynchronously. If None, the futures will be awaited - synchronously by this method. - context: ts.Context instance. - primary_host: Primary host, which indicates the host that will be treated as - the "leader". If None, all hosts are treated as the primary. DO NOT USE - unless you are sure you know what you are doing. - replica_id: Allows overriding the shard replica id that will be saved. DO - NOT USE unless you are sure you know what you are doing. - transaction: TensorStore transaction to use for opening and writing the - array. If not specified, a non-transactional write will be used. - """ - if (isinstance(arr_inp, array.ArrayImpl) and jax.process_count() > 1 and - arr_inp.is_fully_addressable): - raise ValueError( - f'Passing fully addressable arrays to a multiprocess ' - f'serialization is not allowed, as this may lead to a race condition ' - f'between processes. Serialization have failed for the array with ' - f'the path "{tensorstore_spec["kvstore"]["path"]}".') - - # 'metadata' may not be present at the top level (for example, if we are using - # a 'cast' driver). - if not _spec_has_metadata(tensorstore_spec): - tensorstore_spec['metadata'] = _get_metadata(arr_inp) - - # Set dtype if it's not in spec - if 'dtype' not in tensorstore_spec: - tensorstore_spec['dtype'] = jnp.dtype(arr_inp.dtype).name - - # If primary_host is None, all hosts will checkpoint. This is used - # for checkpointing to local filesystem. - if primary_host is None or jax.process_index() == primary_host: - open_future = ts.open( - ts.Spec(tensorstore_spec), - create=True, - open=True, - context=context, - transaction=transaction, - ) - # Asynchronous case. - if commit_future is not None: - assert isinstance(commit_future, list) - commit_future.append(open_future) - else: - await open_future - - # `ts.open` runs twice for process `primary_host` because for the first time, - # we just get the future to be awaited upon in the background thread. The - # second one runs with `assume_metadata=True` which does no I/O operation and - # returns the tensorstore object. - # For every process other than `primary_host`, we open with - # `assume_metadata=True`. - t = await ts.open( - ts.Spec(tensorstore_spec), - open=True, - assume_metadata=True, - context=context, - transaction=transaction, - ) - - async def _write_array(shard): - if shard.replica_id == replica_id: - data = await transfer_shard_to_host(shard) - write_future = t[shard.index].write( - data, - # Avoid additional copy of input array into the TensorStore chunk - # cache. If `arr_inp` is a jax.Array, the result of converting - # it to a NumPy array, as is done internally by TensorStore, is - # guaranteed to be immutable and therefore it is safe to retain a - # reference indefinitely. - can_reference_source_data_indefinitely=isinstance( - arr_inp, array.ArrayImpl - ), - ) - if commit_future is not None: - assert isinstance(commit_future, list) - commit_future.append(write_future.commit) - await write_future.copy - else: - await write_future.commit - - local_shards = arr_inp.addressable_shards - future_write_state = jax.tree_util.tree_map(_write_array, local_shards) - return await asyncio.gather(*future_write_state) - - -def run_serialization(arrays, tensorstore_specs): - async def _run_serializer(): - future_writer = jax.tree_util.tree_map(async_serialize, arrays, tensorstore_specs) - return await asyncio.gather(*future_writer) - asyncio.run(_run_serializer()) - - -def estimate_read_memory_footprint(t: ts.TensorStore, - domain: ts.IndexDomain) -> int: - rank = t.rank - num_bytes = t.dtype.numpy_dtype.itemsize - chunk_template = t.chunk_layout.read_chunk_template - if domain is None: - domain = t.domain - origin = domain.origin - shape = domain.shape - chunk_origin = chunk_template.origin - chunk_shape = chunk_template.shape - - # Some TensorStore drivers are not chunked, e.g. the inline 'array' driver. - # For those, instead of returning a near-infinite memory footprint, estimate - # the footprint as the entire shape. - for i in range(rank): - if not chunk_template[i].finite: - return domain.size * num_bytes - - # Otherwise, if we have a chunked driver, estimate based on chunk size. - for i in range(rank): - origin_value = origin[i] - chunk_origin_value = chunk_origin[i] - chunk_size = chunk_shape[i] - lower = origin_value - chunk_origin_value - upper = origin_value + shape[i] - chunk_origin_value - lower_aligned = lower // chunk_size * chunk_size - upper_aligned = -(-upper // chunk_size) * chunk_size - num_bytes *= (upper_aligned - lower_aligned) - - return num_bytes - - -async def async_deserialize( - user_in_sharding: jax.sharding.Sharding | Layout, - tensorstore_spec: ts.Spec | dict[str, Any], - global_shape: Sequence[int] | None = None, - dtype=None, - byte_limiter: _LimitInFlightBytes | None = None, - context=TS_CONTEXT, - assume_metadata: bool = False, -): - in_sharding = (user_in_sharding.sharding - if isinstance(user_in_sharding, Layout) else user_in_sharding) - if not isinstance(in_sharding, jax.sharding.Sharding): - raise ValueError( - 'sharding passed to deserialization should be specified, concrete and' - f' an instance of `jax.sharding.Sharding`. Got {in_sharding}') - dll = (user_in_sharding.device_local_layout - if isinstance(user_in_sharding, Layout) else None) - t = await ts.open( - tensorstore_spec, - open=True, - assume_metadata=assume_metadata, - context=context, - ) - shape = t.shape if global_shape is None else global_shape - new_shard_shape = in_sharding.shard_shape(tuple(shape)) - - async def cb(index: array.Index, device: jax.Device): - requested_domain = ts.IndexTransform(input_shape=shape)[index].domain - restricted_domain = t.domain.intersect(requested_domain) - requested_bytes = estimate_read_memory_footprint(t, restricted_domain) - # Limit the bytes read for every shard. - if byte_limiter is not None: - await byte_limiter.wait_for_bytes(requested_bytes) - # This maybe needed because the shape the array was saved with is smaller - # than the requested shape of the array in which it will be reloaded. So - # the extra values will be filled with 0s. - out = np.zeros(new_shard_shape, dtype=t.dtype.numpy_dtype) - await ts.array(out)[ts.d[:].translate_to[requested_domain.origin]][ - restricted_domain].write(t[restricted_domain]) - if dtype is not None: - # Cast while reloading on process to avoid 2 copies on device if the - # casting is done on device. - out = out.astype(dtype) - # Convert to jnp array so that layouts are initialized properly for - # sub-byte dtypes. - # TODO(yashkatariya): This is a band-aid fix. Figure out a better way to - # make this work. - if out.dtype == jnp.int4: - out = jnp.asarray(out) # type: ignore - result = jax.device_put( - out, Layout(dll, jax.sharding.SingleDeviceSharding(device))) - if byte_limiter is not None: - # NB: `out` actually might not be ready for garbage collection by the - # time we call release_bytes . Thus peak memory usage still might grow - # beyond what byte_limiter limit suggests it should. The simplest option - # would be to call `result.block_until_ready()`` here. However it - # also comes with ~15-20% perf penalty as we would be waiting for CPU->GPU - # transfer instead of loading data. In the future, if memory pressure - # becomes a problem, we can instead instrument bytelimiter to - # keep track of all in-flight tensors and only block_until_ready, if byte - # limiter hits the limit to get reduced memory usage, without losing - # performance in common use cases. - await byte_limiter.release_bytes(requested_bytes) - return result - - return await create_async_array_from_callback(tuple(shape), in_sharding, cb) - - -def run_deserialization(shardings: Sequence[sharding.Sharding | Layout], - tensorstore_specs: Sequence[dict[str, Any]], - global_shapes: Sequence[array.Shape] | None = None, - dtypes: Sequence[typing.DTypeLike] | None = None, - concurrent_gb: int = 32): - concurrent_bytes = concurrent_gb * 10**9 - - async def _run_deserializer(): - # Object should be created once per process. - byte_limiter = _LimitInFlightBytes(concurrent_bytes) - future_arrays = jax.tree_util.tree_map( - partial(async_deserialize, byte_limiter=byte_limiter), - list(shardings), list(tensorstore_specs), - [None] * len(tensorstore_specs) if global_shapes is None else global_shapes, - [None] * len(tensorstore_specs) if dtypes is None else dtypes) - return await asyncio.gather(*future_arrays) - return asyncio.run(_run_deserializer()) - - def _get_key(key: int): return f'tensorstore_checkpoint_{key}' @@ -510,8 +186,7 @@ def __init__(self, timeout_secs=300): if jax.process_count() > 1 and distributed.global_state.client is None: raise ValueError(_DISTRIBUTED_SYSTEM_MSG) - if jax.process_count() > 1: - self._client = distributed.global_state.client + self._client = distributed.global_state.client self._count = None def __del__(self): @@ -533,7 +208,9 @@ def _thread_func(self): logger.info('Finished committing to storage layer by process: %s', current_process) + key_for_barrier = None if process_count > 1: + assert self._client is not None # All processes will wait at the barrier. When all processes are at the # barrier, the barrier will be satisfied. If not, then it will timeout. key_for_barrier = _get_key(self._count) @@ -544,9 +221,11 @@ def _thread_func(self): current_process) if current_process == 0: - self._on_commit_callback() - logger.info('on_commit_callback successfully ran!') + if self._on_commit_callback is not None: + self._on_commit_callback() + logger.info('on_commit_callback successfully ran!') if process_count > 1: + assert self._client is not None self._client.key_value_set(key_for_barrier, _CHECKPOINT_SUCCESS) logger.info('Process 0 successfully set key %s in the kv store', key_for_barrier) @@ -555,7 +234,7 @@ def _thread_func(self): '/jax/checkpoint/write/async/thread_duration_sec', time.time() - thread_start_time) - except Exception as e: + except Exception as e: # pylint: disable=broad-except self._exception = e def _start_async_commit(self, on_commit_callback): @@ -570,9 +249,9 @@ def check_for_errors(self): # Clears self._exception so it is only raised once. exception = self._exception self._exception = None - if (isinstance(exception, xe.XlaRuntimeError) and + if (isinstance(exception, _jax.JaxRuntimeError) and 'DEADLINE_EXCEEDED: Barrier timed out' in str(exception)): - raise BarrierTimeoutException( + raise BarrierTimeoutError( '\n'.join([str(exception), _BARRIER_TIMED_OUT_MSG])) raise exception # pylint: disable=raising-bad-type @@ -586,6 +265,7 @@ def wait_until_finished(self): logger.info('Error check finished successfully') if jax.process_count() > 1 and self._count is not None: + assert self._client is not None # Block until process 0 writes success value to the key value store. # If it fails to write it, then `blocking_key_value_get` will time out. get_key = _get_key(self._count) @@ -593,7 +273,7 @@ def wait_until_finished(self): logger.info('blocking_key_value_get on key %s was successfully ' 'completed.', get_key) - def _add_futures(self, futures: Sequence[asyncio.Future]): + def _add_futures(self, futures: Sequence[ts.Future]): self._commit_futures = futures @@ -605,8 +285,8 @@ def serialize( arrays, tensorstore_specs, *, - on_commit_callback, - transaction: Optional[ts.Transaction] = None, + on_commit_callback: Callable[[], None] | None = None, + transaction: ts_impl.Transaction | None = None, ): """Serializes Arrays or Arrays via TensorStore asynchronously. @@ -635,11 +315,11 @@ def serialize( logger.info('Waiting for previous serialization to finish.') self.wait_until_finished() - commit_futures: list[ts.Future] = [] + commit_futures: list[ts_impl.Future] = [] async def _run_serializer(): future_writer = jax.tree_util.tree_map( - lambda arr_inp, tensorstore_spec: async_serialize( + lambda arr_inp, tensorstore_spec: ts_impl.async_serialize( arr_inp, tensorstore_spec, commit_future=commit_futures, @@ -649,7 +329,6 @@ async def _run_serializer(): tensorstore_specs, ) return await asyncio.gather(*future_writer) - asyncio.run(_run_serializer()) self._add_futures(commit_futures) @@ -663,25 +342,25 @@ def serialize_with_paths( arrays: Sequence[jax.Array], paths: Sequence[str], *, - on_commit_callback, - transaction: Optional[ts.Transaction] = None, + on_commit_callback: Callable[[], None] | None = None, + transaction: ts_impl.Transaction | None = None, ): tspecs = jax.tree.map(get_tensorstore_spec, paths) - self.serialize( + return self.serialize( arrays, tspecs, on_commit_callback=on_commit_callback, transaction=transaction, ) - def deserialize(self, shardings: Sequence[sharding.Sharding | Layout], + def deserialize(self, shardings: Sequence[sharding.Sharding | Format], tensorstore_specs: Sequence[dict[str, Any]], global_shapes: Sequence[array.Shape] | None = None, dtypes: Sequence[typing.DTypeLike] | None = None, concurrent_gb: int = 32): self.wait_until_finished() - return run_deserialization(shardings, tensorstore_specs, - global_shapes, dtypes, concurrent_gb) + return ts_impl._run_deserialization( + shardings, tensorstore_specs, global_shapes, dtypes, concurrent_gb) def deserialize_with_paths( self, shardings: Sequence[sharding.Sharding], diff --git a/jax/experimental/array_serialization/serialization_test.py b/jax/experimental/array_serialization/serialization_test.py index 9f4539fc63c8..bacc1b299702 100644 --- a/jax/experimental/array_serialization/serialization_test.py +++ b/jax/experimental/array_serialization/serialization_test.py @@ -12,31 +12,69 @@ # See the License for the specific language governing permissions and # limitations under the License. +# pylint: disable=g-importing-member import asyncio -import math +from dataclasses import dataclass from functools import partial +import json +import logging +import math import os import pathlib -import tracemalloc as tm +import pickle +import tempfile +import threading +import time +import tracemalloc as tm +from typing import Any +from concurrent.futures import ThreadPoolExecutor from absl.testing import absltest from absl.testing import parameterized import jax -import jax.numpy as jnp +from jax._src import array from jax._src import config from jax._src import test_util as jtu -from jax._src import array -from jax.sharding import NamedSharding, GSPMDSharding, SingleDeviceSharding -from jax.sharding import PartitionSpec as P +from jax._src.export._export import ( + deserialization_registry as node_deserialization_registry) +from jax._src.export._export import ( + serialization_registry as node_serialization_registry) +from jax._src.layout import Layout +from jax._src.layout import Format +from jax.experimental.array_serialization import pytree_serialization from jax.experimental.array_serialization import serialization -from jax.experimental.layout import Layout, DeviceLocalLayout as DLL +from jax.experimental.array_serialization import tensorstore_impl as ts_impl + +from jax.experimental.array_serialization.pytree_serialization_utils import ( + register_pytree_node_serialization) + +import jax.numpy as jnp + +from jax.sharding import NamedSharding +from jax.sharding import PartitionSpec as P +from jax.sharding import SingleDeviceSharding import numpy as np import tensorstore as ts +# pylint: enable=g-importing-member jax.config.parse_flags_with_absl() jtu.request_cpu_devices(8) +_DEFAULT_SHARDING = None # to be overridden by tests with SingleDeviceSharding + + +def tree_load(*args, **kw): + return pytree_serialization.load(*args, shardings=_DEFAULT_SHARDING, **kw) + +tree_save = pytree_serialization.save +tree_load_pytreedef = pytree_serialization.load_pytreedef + + +def _get_replicated_sharding(devices): + return NamedSharding(jax.sharding.Mesh(devices, 'x'), P()) + + class CheckpointTest(jtu.JaxTestCase): def _on_commit_callback(self, temp_ckpt_dir, final_ckpt_dir): @@ -77,29 +115,32 @@ def test_deserialize_on_array_list(self): self.assertArraysEqual(deserialized_array, inp) @jtu.skip_on_devices('cpu') + @jtu.thread_unsafe_test() def test_memory_consumption(self): global_mesh = jtu.create_mesh((2, 4), ('x', 'y')) inp_shape = (2_048, 4_096) pspec = P('x', 'y') num = math.prod(inp_shape) sharding = NamedSharding(global_mesh, pspec) - src = jnp.arange(num, dtype=np.int32).reshape(inp_shape) # 8e9 + src = jnp.arange(num, dtype=np.int32).reshape(inp_shape) # 8e6 elements inp = array.make_array_from_callback( inp_shape, sharding, lambda idx: src[idx]) - ckpt_dir = pathlib.Path(self.create_tempdir('memprof').full_path) + ckpt_dir = pathlib.Path(self.create_tempdir( + 'memprof-deserialize').full_path) tspec = serialization.get_tensorstore_spec(str(ckpt_dir)) manager = serialization.GlobalAsyncCheckpointManager() manager.serialize( [inp], [tspec], - on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir)) + on_commit_callback=partial( + self._on_commit_callback, ckpt_dir, ckpt_dir)) manager.wait_until_finished() async def deserialize_with_byte_limit(): r = await serialization.async_deserialize( - sharding, tspec, inp_shape, - byte_limiter=serialization._LimitInFlightBytes(4_200_000)) + sharding, tspec, inp_shape, + byte_limiter=serialization._LimitInFlightBytes(4_200_000)) r.block_until_ready() tm.start() @@ -107,7 +148,7 @@ async def deserialize_with_byte_limit(): unused_current, peak = tm.get_traced_memory() # NB: some padding + tensorstore overhead. It should always be # less than array size (2048 * 4096 * 4 = 32M) - self.assertLess(peak, 10_000_000) + self.assertLess(peak, 13_000_000) deserialize_wo_limit = serialization.async_deserialize( sharding, tspec, inp_shape) tm.clear_traces() @@ -122,6 +163,7 @@ async def deserialize_with_byte_limit(): self.assertGreater(peak, 30_000_000) tm.stop() + @jtu.thread_unsafe_test() def test_memory_consumption_for_save(self): global_mesh = jtu.create_mesh((1, 1), ('x', 'y')) inp_shape = (16 * 1024, 16 * 1024) @@ -132,25 +174,24 @@ def test_memory_consumption_for_save(self): inp = array.make_array_from_callback( inp_shape, sharding, lambda idx: src[idx] ) - ckpt_dir = pathlib.Path(self.create_tempdir('memprofsave').full_path) - tspec = serialization.get_tensorstore_spec(str(ckpt_dir)) + ckpt_dir = pathlib.Path(self.create_tempdir( + 'memprofsave-serialize').full_path) + tspec = ts_impl.get_tensorstore_spec(str(ckpt_dir), ocdbt=False, + driver='zarr3') tspec['metadata'] = { 'shape': inp.shape, - 'compressor': None, - 'chunks': inp.shape, + 'data_type': jnp.dtype(inp.dtype).name, + 'chunk_grid': { + 'name': 'regular', + 'configuration': {'chunk_shape': np.array(np.maximum(1, inp.shape))} + } } - is_cpu = jtu.test_device_matches(['cpu']) tm.start() try: manager = serialization.GlobalAsyncCheckpointManager() - manager.serialize( - [inp], - [tspec], - on_commit_callback=partial( - self._on_commit_callback, ckpt_dir, ckpt_dir - ), - ) + manager.serialize([inp], [tspec], on_commit_callback=partial( + self._on_commit_callback, ckpt_dir, ckpt_dir)) manager.wait_until_finished() unused_current, peak = tm.get_traced_memory() self.assertLess(peak, src.nbytes * (1 * (not is_cpu) + 0.5)) @@ -176,7 +217,8 @@ def test_checkpointing_with_path_variant(self): manager = serialization.GlobalAsyncCheckpointManager() manager.serialize_with_paths( [a1], ckpt_paths, - on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir)) + on_commit_callback=partial( + self._on_commit_callback, ckpt_dir, ckpt_dir)) manager.wait_until_finished() m1, = manager.deserialize_with_paths( @@ -201,7 +243,8 @@ def test_checkpointing_jax_array(self): inp_shape, NamedSharding(global_mesh, pspec), lambda idx: global_input_data1[idx]) ckpt_dir = pathlib.Path(self.create_tempdir('ckpt').full_path) - ckpt_path1 = pathlib.Path(self.create_tempdir(f'{ckpt_dir}/first').full_path) + ckpt_path1 = pathlib.Path( + self.create_tempdir(f'{ckpt_dir}/first').full_path) # Second Array global_input_data2 = np.arange( @@ -209,7 +252,8 @@ def test_checkpointing_jax_array(self): a2 = array.make_array_from_callback( inp_shape, NamedSharding(global_mesh, pspec), lambda idx: global_input_data2[idx]) - ckpt_path2 = pathlib.Path(self.create_tempdir(f'{ckpt_dir}/second').full_path) + ckpt_path2 = pathlib.Path( + self.create_tempdir(f'{ckpt_dir}/second').full_path) # Third Array def cb3(_): @@ -217,15 +261,17 @@ def cb3(_): global_mesh1d = jtu.create_mesh((8,), ('x',)) a3 = array.make_array_from_callback( (0,), NamedSharding(global_mesh1d, P(None)), cb3) - ckpt_path3 = pathlib.Path(self.create_tempdir(f'{ckpt_dir}/third').full_path) + ckpt_path3 = pathlib.Path( + self.create_tempdir(f'{ckpt_dir}/third').full_path) ckpt_paths = [str(ckpt_path1), str(ckpt_path2), str(ckpt_path3)] - tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, ckpt_paths) + tspecs = jax.tree.map(serialization.get_tensorstore_spec, ckpt_paths) manager = serialization.GlobalAsyncCheckpointManager() manager.serialize( [a1, a2, a3], tspecs, - on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir)) + on_commit_callback=partial( + self._on_commit_callback, ckpt_dir, ckpt_dir)) manager.wait_until_finished() m1, m2, m3 = serialization.run_deserialization( @@ -257,6 +303,24 @@ def cb3(_): self.assertArraysEqual(np.asarray(s.data), np.array([], dtype=np.float32)) self.assertEqual(m3.dtype, np.float32) + @jtu.thread_unsafe_test() + def test_deserialization_does_not_hang_when_concurrent_gb_exceeded(self): + TIMEOUT_SEC = 5 + mngr = serialization.GlobalAsyncCheckpointManager(timeout_secs=TIMEOUT_SEC) + a = jnp.ones((1024, 1024)) + path = str(pathlib.Path(self.create_tempdir('small').full_path) / 'array') + mngr.serialize_with_paths([a], [path]) + mngr.wait_until_finished() + + executor = ThreadPoolExecutor(max_workers=1) + future = executor.submit(mngr.deserialize_with_paths, [a.sharding], + [path], concurrent_gb=1e-9) # 1 byte + try: + future.result(timeout=TIMEOUT_SEC) + except TimeoutError: + future.cancel() + self.fail('Deserialization times out if size exceeds concurrent_gb.') + def test_checkpointing_ocdbt_transaction(self): global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) inp_shape = (8, 2) @@ -295,9 +359,8 @@ def cb3(_): ckpt_path3 = ckpt_dir / 'third' ckpt_paths = [str(ckpt_path1), str(ckpt_path2), str(ckpt_path3)] - tspecs = jax.tree_util.tree_map( - lambda p: serialization.get_tensorstore_spec(p, ocdbt=True), ckpt_paths - ) + tspecs = jax.tree.map(partial(ts_impl.get_tensorstore_spec, ocdbt=True), + ckpt_paths) manager = serialization.GlobalAsyncCheckpointManager() with ts.Transaction(atomic=True) as transaction: @@ -312,13 +375,8 @@ def cb3(_): manager.wait_until_finished() m1, m2, m3 = serialization.run_deserialization( - [ - NamedSharding(global_mesh, pspec), - NamedSharding(global_mesh, P('x')), - NamedSharding(global_mesh1d, P(None)), - ], - tspecs, - ) + [NamedSharding(global_mesh, pspec), NamedSharding(global_mesh, P('x')), + NamedSharding(global_mesh1d, P(None))], tspecs) self.assertIsInstance(m1, array.ArrayImpl) self.assertArraysEqual( @@ -367,12 +425,13 @@ def cb1(index): ckpt_dir = pathlib.Path(self.create_tempdir('first').full_path) ckpt_paths = [str(ckpt_dir)] - tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, ckpt_paths) + tspecs = jax.tree.map(serialization.get_tensorstore_spec, ckpt_paths) manager = serialization.GlobalAsyncCheckpointManager() manager.serialize( [arr], tspecs, - on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir)) + on_commit_callback=partial( + self._on_commit_callback, ckpt_dir, ckpt_dir)) manager.wait_until_finished() ds = NamedSharding(jtu.create_mesh((4, 2), ('x', 'y'), iota_order=True), @@ -395,15 +454,16 @@ def cb1(index): for l in m1.addressable_shards: self.assertArraysEqual(np.asarray(l.data), expected_data[l.device.id]) - new_ds = GSPMDSharding.get_replicated(list(global_mesh.devices.flat)) - m2, = serialization.run_deserialization([new_ds], tspecs, [(8, 2)], [np.float32]) + new_ds = _get_replicated_sharding(list(global_mesh.devices.flat)) + m2, = serialization.run_deserialization([new_ds], tspecs, [(8, 2)], + [np.float32]) for l in m2.addressable_shards: self.assertArraysEqual(l.data, global_input_data1.astype('float32')) @parameterized.product(input_dtype=[jnp.int4, jnp.int8]) def test_checkpointing_with_int4(self, input_dtype): if config.use_shardy_partitioner.value: - self.skipTest("TODO(b/376077396): Fix XlaRuntimeError: INVALID_ARGUMENT") + self.skipTest('TODO(b/376077396): Fix JaxRuntimeError: INVALID_ARGUMENT') global_mesh = jtu.create_mesh((2, 2), ('x', 'y'), iota_order=True) global_input_shape = (8, 2) num = math.prod(global_input_shape) @@ -418,12 +478,13 @@ def cb(index): ckpt_dir = pathlib.Path(self.create_tempdir('first').full_path) ckpt_paths = [str(ckpt_dir)] - tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, ckpt_paths) + tspecs = jax.tree.map(serialization.get_tensorstore_spec, ckpt_paths) manager = serialization.GlobalAsyncCheckpointManager() manager.serialize( [arr], tspecs, - on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir)) + on_commit_callback=partial( + self._on_commit_callback, ckpt_dir, ckpt_dir)) manager.wait_until_finished() ds = NamedSharding(jtu.create_mesh((4, 2), ('x', 'y'), iota_order=True), @@ -448,8 +509,9 @@ def cb(index): for l in m1.addressable_shards: self.assertArraysEqual(np.asarray(l.data), expected_data[l.device.id]) - new_ds = GSPMDSharding.get_replicated(list(global_mesh.devices.flat)) - m2, = serialization.run_deserialization([new_ds], tspecs, [(8, 2)], [target_dtype]) + new_ds = _get_replicated_sharding(list(global_mesh.devices.flat)) + m2, = serialization.run_deserialization([new_ds], tspecs, [(8, 2)], + [target_dtype]) for l in m2.addressable_shards: self.assertArraysEqual(l.data, global_input_data.astype(target_dtype)) @@ -463,22 +525,17 @@ def test_checkpointing_scalar_jax_array(self): ckpt_dir = pathlib.Path(self.create_tempdir('first').full_path) ckpt_paths = [str(ckpt_dir)] - tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, ckpt_paths) - + tspecs = jax.tree.map(serialization.get_tensorstore_spec, ckpt_paths) manager = serialization.GlobalAsyncCheckpointManager() manager.serialize( [array1], tspecs, - on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir)) + on_commit_callback=partial( + self._on_commit_callback, ckpt_dir, ckpt_dir)) manager.wait_until_finished() ds = NamedSharding(jtu.create_mesh((2,), ('x')), P(None)) - m1, = serialization.run_deserialization( - [ds], - tspecs, - [()], - [np.float32] - ) + m1, = serialization.run_deserialization([ds], tspecs, [()], [np.float32]) for l in m1.addressable_shards: self.assertArraysEqual(np.asarray(l.data), data.astype(np.float32)) @@ -488,9 +545,7 @@ def test_deserialize_tensorstore_array_jax_array(self): data = np.arange(1024) tspec = ts.array(data).spec() m1, = serialization.run_deserialization( - [NamedSharding(global_mesh, P(None))], - [tspec] - ) + [NamedSharding(global_mesh, P(None))], [tspec]) for l in m1.addressable_shards: self.assertArraysEqual(np.asarray(l.data), data) @@ -507,9 +562,9 @@ def test_spec_has_metadata(self): }, 'f': 4 } - self.assertTrue(serialization._spec_has_metadata(spec)) + self.assertTrue(ts_impl._spec_has_metadata(spec)) self.assertTrue( - serialization._spec_has_metadata({ + ts_impl._spec_has_metadata({ 'driver': 'zarr', 'kvstore': 'gfile', 'metadata': { @@ -531,39 +586,40 @@ def test_spec_has_no_metadata(self): }, 'f': 4 } - self.assertFalse(serialization._spec_has_metadata(spec)) + self.assertFalse(ts_impl._spec_has_metadata(spec)) def test_empty_spec_has_no_metadata(self): spec = {} - self.assertFalse(serialization._spec_has_metadata(spec)) + self.assertFalse(ts_impl._spec_has_metadata(spec)) @parameterized.named_parameters( ('gcs', 'gs://my/ckpt/dir/path'), ('file', '/my/ckpt/dir/path') ) def test_get_tensorstore_spec_ocdbt(self, path): - spec = serialization.get_tensorstore_spec(path, ocdbt=True) + spec = ts_impl.get_tensorstore_spec(path, ocdbt=True) is_gcs_path = path.startswith('gs://') + # for OCDBT the last part of the path is the key in the kvstore + expected_path = os.path.split(path)[0] if is_gcs_path: - self.assertEqual(spec['kvstore']['base'], os.path.dirname(path)) + self.assertEqual(spec['kvstore']['base']['driver'], 'gcs') + self.assertTrue(expected_path.endswith(spec['kvstore']['base']['path'])) else: - self.assertEqual(spec['kvstore']['base'], - f'{serialization._DEFAULT_DRIVER}://{os.path.dirname(path)}') - self.assertEqual(spec['kvstore']['path'], 'path') + self.assertEqual(spec['kvstore']['base']['path'], expected_path) def test_get_tensorstore_spec_not_absolute_path(self): path = 'my/ckpt/path' with self.assertRaisesRegex(ValueError, - "Checkpoint path should be absolute"): - serialization.get_tensorstore_spec(path, ocdbt=True) + 'Checkpoint path should be absolute'): + ts_impl.get_tensorstore_spec(path, ocdbt=True) def test_maybe_cloud_storage(self): - gs_path = 'gs://some-buck/path' - gs_spec = serialization.get_tensorstore_spec(gs_path, ocdbt=True) + gs_path = 'gs://some-buck/path/array_name' + gs_spec = ts_impl.get_tensorstore_spec(gs_path, ocdbt=True) self.assertTrue(serialization.is_remote_storage(gs_spec)) - local_path = '/tmp/checkpoint' - local_spec = serialization.get_tensorstore_spec(local_path, ocdbt=True) + local_path = '/tmp/checkpoint/array_name' + local_spec = ts_impl.get_tensorstore_spec(local_path, ocdbt=True) self.assertFalse(serialization.is_remote_storage(local_spec)) nested_tspec = { @@ -571,46 +627,54 @@ def test_maybe_cloud_storage(self): 'dtype': 'int32', 'base': { 'driver': 'zarr', - 'kvstore': {'driver': 'ocdbt', 'base': 's3://some-bucket/path'}, + 'kvstore': {'driver': 'ocdbt', + 'base': 's3://some-bucket/path/array_name'}, }, } self.assertTrue(serialization.is_remote_storage(nested_tspec)) - def test_load_with_layout(self): + @parameterized.named_parameters(('4_devices', 4), ('1_device', 1)) + def test_load_with_layout(self, device_count): if not jtu.test_device_matches(['tpu']): self.skipTest('Layouts are only supported on TPUs') - mesh = jtu.create_mesh((4, 2), ('x', 'y')) - np_inp = np.arange(32).reshape(8, 4) + mesh = jtu.create_mesh((2, 2) if device_count == 4 else (1, 1), ('x', 'y')) + np_inp = np.arange(device_count * 128 * 128 * 2).reshape((-1, 2 * 128)) s = NamedSharding(mesh, P('x', 'y')) arr = jax.device_put(np_inp, s) - out_layout = jax.jit(lambda x: x.T, out_shardings=Layout(DLL.AUTO)).lower( - arr).compile().output_layouts - self.assertEqual(arr.layout.device_local_layout.major_to_minor, - out_layout.device_local_layout.major_to_minor[::-1]) + transpose_layout = Layout( + major_to_minor=arr.format.layout.major_to_minor[::-1], + tiling=arr.format.layout.tiling) + out_format = Format(transpose_layout, arr.sharding) + out_ref = jax.device_put(arr.T, out_format) ckpt_dir = pathlib.Path(self.create_tempdir('ckpt').full_path) ckpt_path = pathlib.Path(self.create_tempdir(f'{ckpt_dir}/first').full_path) - tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, [ckpt_path]) + tspecs = jax.tree.map(ts_impl.get_tensorstore_spec, [ckpt_path]) manager = serialization.GlobalAsyncCheckpointManager() manager.serialize( - [arr], tspecs, - on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir)) + [out_ref], tspecs, + on_commit_callback=partial( + self._on_commit_callback, ckpt_dir, ckpt_dir)) manager.wait_until_finished() - out, = serialization.run_deserialization([out_layout], tspecs) + out_specs = [out_format, jax.ShapeDtypeStruct( + out_ref.shape, out_ref.dtype, sharding=out_format)] - self.assertEqual(out.layout, out_layout) - self.assertIsInstance(out, array.ArrayImpl) - self.assertArraysEqual(out, np_inp) - for s in out.addressable_shards: - self.assertArraysEqual(s.data, np_inp[s.index]) + for out_spec, name in zip(out_specs, ['format', 'ShapeDtypeStruct']): + with self.subTest(f'deserialization_with_{name}'): + out, = serialization.run_deserialization([out_spec], tspecs) + self.assertEqual(out.format, out_format) + self.assertIsInstance(out, array.ArrayImpl) + self.assertArraysEqual(out, out_ref) + for s, s_ref in zip(out.addressable_shards, out_ref.addressable_shards): + self.assertArraysEqual(s.data, s_ref.data) def test_deserialization_with_int4(self): if config.use_shardy_partitioner.value: - self.skipTest("TODO(b/376077396): Fix XlaRuntimeError: INVALID_ARGUMENT") + self.skipTest('TODO(b/376077396): Fix JaxRuntimeError: INVALID_ARGUMENT') if jtu.test_device_matches(['gpu']): self.skipTest("Fails on GPU. Enable after it's fixed") dtype = jnp.int4 @@ -620,10 +684,8 @@ def test_deserialization_with_int4(self): ckpt_dir = pathlib.Path(self.create_tempdir('test_ckpt').full_path) # Run serialization. - sharding = jax.sharding.GSPMDSharding.get_replicated(jax.devices()) - tspecs = jax.tree_util.tree_map( - serialization.get_tensorstore_spec, [ckpt_dir] - ) + sharding = _get_replicated_sharding(list(jax.devices())) + tspecs = jax.tree.map(serialization.get_tensorstore_spec, [ckpt_dir]) manager = serialization.GlobalAsyncCheckpointManager() manager.serialize( [arr], @@ -634,11 +696,8 @@ def test_deserialization_with_int4(self): # Run deserialization. deserialized_arr, = serialization.run_deserialization( - shardings=[sharding], - tensorstore_specs=tspecs, - global_shapes=[shape], - dtypes=[dtype], - ) + shardings=[sharding], tensorstore_specs=tspecs, global_shapes=[shape], + dtypes=[dtype]) out = deserialized_arr.astype(jnp.int8) # doesn't crash self.assertEqual(out.dtype, jnp.int8) @@ -650,13 +709,434 @@ class TransferShardTest(jtu.JaxTestCase): @jtu.skip_on_devices('cpu') def test_transfer_shard_to_host(self): np_inp = np.arange(16).reshape((4, 4)) - sharding = SingleDeviceSharding(jax.devices()[0], memory_kind="device") + sharding = SingleDeviceSharding(jax.devices()[0], memory_kind='device') arr = jax.device_put(np_inp, sharding) shard = arr.addressable_shards[0] - np_out = asyncio.run(serialization.transfer_shard_to_host(shard)) + np_out = asyncio.run(ts_impl._transfer_shard_to_host(shard)) self.assertArraysEqual(np_out, np_inp) + +def _remove_from_serialization_registry(t: Any): + if t in node_serialization_registry: + serialized_name = node_serialization_registry[t][0] + del node_serialization_registry[t] + del node_deserialization_registry[serialized_name] + + +class UserAPITestCase(jtu.JaxTestCase): + name: str | None + path: pathlib.Path | None + + def setUp(self): + super().setUp() + tmpdir = tempfile.TemporaryDirectory() + self.enter_context(tmpdir) + self.name = tmpdir.name + self.path = pathlib.Path(self.name) + + def tearDown(self): + self.path = None + self.name = None + super().tearDown() + + def generate_random_fp32(self, shape, dtype=jnp.float32): + seed = round(time.time() * 1e6) % (2 ** 31) + key = jax.random.key(seed) + return jax.random.normal(key, shape=shape).astype(dtype) + + def generate_clean_tree(self, dtype=jnp.float32): + r1 = self.generate_random_fp32((), dtype=dtype) + r2 = self.generate_random_fp32((4,), dtype=dtype) + r3 = self.generate_random_fp32((2, 3), dtype=dtype) + return (r1, {'a': r2, 'rs': [r1, r2, r3], 'c': {'d': {'e': (r2,)}}}) + + def _is_equal(self, el1, el2): + if not isinstance(el1, type(el2)) or not isinstance(el2, type(el1)): + return False + if isinstance(el1, (np.ndarray, jax.Array)): + return (el1.dtype == el2.dtype and el1.shape == el2.shape + and jnp.allclose(el1, el2)) + else: + return el1 == el2 + + def assertPyTreeEqual(self, p1, p2, is_leaf=None): + leaves1, struct1 = jax.tree.flatten(p1, is_leaf=is_leaf) + leaves2, struct2 = jax.tree.flatten(p2, is_leaf=is_leaf) + self.assertEqual(struct1, struct2) + self.assertTrue(all(self._is_equal(el1, el2) + for (el1, el2) in zip(leaves1, leaves2))) + +_DTYPES_LIST = [ + jnp.uint8, + jnp.uint16, + jnp.uint32, + jnp.int8, + jnp.int16, + jnp.int32, + jnp.float8_e4m3fn, + jnp.float8_e4m3fnuz, + jnp.float8_e5m2, + jnp.float8_e5m2fnuz, + jnp.float8_e4m3b11fnuz, + jnp.bfloat16, + jnp.float16, + jnp.float32, + jnp.complex64, +] + +_X64_DTYPES_LIST = [ + jnp.uint64, + jnp.int64, + jnp.float64, + jnp.complex128, +] + +if jax.config.x64_enabled: + _DTYPES_LIST.extend(_X64_DTYPES_LIST) + + +@jax.tree_util.register_pytree_node_class +class CustomNode: + def __init__(self, a): + self.a = a + + def tree_flatten(self): + return (self.a,), None + + @classmethod + def tree_unflatten(cls, aux_data, children): + del aux_data + return cls(*children) + + +@partial(jax.tree_util.register_dataclass, data_fields=['a', 'd'], + meta_fields=['c']) +@dataclass +class CustomDataclass: + a: int + c: str + d: int + + +@jax.tree_util.register_static +class CustomStatic: + def __init__(self, a): + self.a = a + +# we're testing custom type registration which modifies the global registry +# so need to ensure we're not running multiple custom types tests in parallel +custom_types_threading_lock = threading.Lock() + + +class UserPytreeAPITest(UserAPITestCase): + def setUp(self): + super().setUp() + global _DEFAULT_SHARDING + _DEFAULT_SHARDING = SingleDeviceSharding(jax.devices()[0]) + self.tempdirs = [] + + def tearDown(self): + for tempdir in self.tempdirs: + tempdir.cleanup() + super().tearDown() + + def create_tempdir(self): + tempdir = tempfile.TemporaryDirectory() + self.tempdirs.append(tempdir) + return pathlib.Path(tempdir.name).resolve() + + @parameterized.product(tree=[{'a': 1}, [1, 2, 3], (1, 2, 3), 1, 2, 3]) + def test_save_then_load(self, tree): # pylint: disable=redefined-outer-name + path = self.create_tempdir() + tree = jax.tree.map(jnp.array, tree) + tree_save(tree, path) + tree2 = tree_load(path) + self.assertPyTreeEqual(tree, tree2) + + @parameterized.product(dtype=_DTYPES_LIST) + def test_saving_dtype(self, dtype): + if dtype in _X64_DTYPES_LIST and jtu.test_device_matches(['tpu']): + self.skipTest('Don\'t test x64 dtypes on TPUs') + path = self.create_tempdir() + test_tree = self.generate_clean_tree(dtype=dtype) + tree_save(test_tree, path) + new_tree = tree_load(path) + self.assertPyTreeEqual(test_tree, new_tree) + + def test_do_not_overwrite_noncheckpoint_directories(self): + path = self.create_tempdir() + path.mkdir(exist_ok=True) + (path / 'hello.txt').write_text('Hello World') + with self.assertRaisesRegex(RuntimeError, 'Refusing to work on a directory' + ' that is not a previous checkpoint.'): + tree_save({'a': jnp.ones(1)}, path) + + def test_checkpoint_exists(self): + path = self.create_tempdir() + tree_save({'a': jnp.ones(1)}, path) + with self.assertRaises(ValueError): + tree_save({'a': jnp.ones(1)}, path, overwrite=False) + + @parameterized.product(test_load_fail=[True, False]) + def test_custom_types(self, test_load_fail): + path = self.create_tempdir() + with custom_types_threading_lock: + magic_value = jnp.ones(()) * 37 + n = CustomNode(magic_value) + d = CustomDataclass(magic_value, 'hello', magic_value + 1) + s = CustomStatic(magic_value - 1) + tree_to_save = [n, (d, s)] + + register_pytree_node_serialization(CustomNode, + serialized_name='CustomNode', + serialize_auxdata=pickle.dumps, + deserialize_auxdata=pickle.loads) + register_pytree_node_serialization(CustomStatic, + serialized_name='CustomStatic', + serialize_auxdata=pickle.dumps, + deserialize_auxdata=pickle.loads) + register_pytree_node_serialization(CustomDataclass, + serialized_name='CustomDataclass', + serialize_auxdata=pickle.dumps, + deserialize_auxdata=pickle.loads) + tree_save(tree_to_save, path) + if test_load_fail: + _ = [_remove_from_serialization_registry(cls) + for cls in [CustomStatic, CustomNode, CustomDataclass]] + with self.assertRaises(ValueError): + _ = tree_load(path) + else: + tree2 = tree_load(path) + self.assertEqual(tree2[0].a, magic_value) + self.assertEqual(tree2[1][0].a, magic_value) + self.assertEqual(tree2[1][0].c, 'hello') + self.assertEqual(tree2[1][0].d, magic_value + 1) + self.assertEqual(tree2[1][1].a, magic_value - 1) + _ = [_remove_from_serialization_registry(cls) + for cls in [CustomStatic, CustomNode, CustomDataclass]] + + def test_flax_frozen_dict(self): + path = self.create_tempdir() + try: + # pylint: disable=g-import-not-at-top + # pylint: disable=g-importing-member + from flax.core.frozen_dict import FrozenDict + # pylint: enable=g-importing-member + # pylint: enable=g-import-not-at-top + except ImportError: + logging.warning('Skipping Flax FrozenDict tests as flax is not installed') + return + + try: + register_pytree_node_serialization(FrozenDict, + serialized_name='FrozenDict', + serialize_auxdata=pickle.dumps, + deserialize_auxdata=pickle.loads) + tree_save(FrozenDict(a=1, b=self.generate_clean_tree()), path) + tree_load(path) + finally: + _remove_from_serialization_registry(FrozenDict) + + def test_register_as_decorator(self): + @partial(register_pytree_node_serialization, + serialized_name='CustomDNode', + serialize_auxdata=json.dumps, + deserialize_auxdata=json.loads) + @partial(jax.tree_util.register_dataclass, data_fields=['a', 'b'], + meta_fields=[]) + @dataclass + class CustomDNode: + a: int + b: int + + # test whether the object can be created (is visible in this scope) + _ = CustomDNode(1, 2) + + def test_custom_node_registration(self): + path = self.create_tempdir() + + @jax.tree_util.register_static + @dataclass + class P: + a: int = 2 + + @partial(jax.tree_util.register_dataclass, data_fields=['a', 'b'], + meta_fields=['op']) + @dataclass + class D: + a: Any + b: Any + op: str + + def serialize_D(data): + return json.dumps(jax.tree.map(lambda x: np.array(x).tolist(), data) + ).encode('utf-8') + + def deserialize_D(data): + return jnp.array(json.loads(data)) + + data = [jnp.ones(1), {'world': [jnp.zeros(3), (jnp.ones(1), jnp.ones(2))]}, + 7 * jnp.ones(()), P()] + + serialize_fn = lambda p: json.dumps(int(p.a)).encode('utf-8') + deserialize_fn = lambda data: P(json.loads(data)) + + with self.assertRaises(ValueError): + tree_save(data, path) + + register_pytree_node_serialization(P, + serialized_name='P', + serialize_auxdata=serialize_fn, + deserialize_auxdata=deserialize_fn) + magic_value = -171 + data[-1].a = jnp.array(magic_value) + tree_save(data, path) + ret = tree_load(path) + self.assertLen(ret, len(data)) + self.assertEqual(ret[-1].a, magic_value) + + magic_val = 17 * jnp.ones(2) + data.append(D(jnp.ones(1), jax.numpy.zeros(2), magic_val)) + with self.assertRaises(ValueError): + tree_save(data, path) + + register_pytree_node_serialization(D, + serialized_name='D', + serialize_auxdata=serialize_D, + deserialize_auxdata=deserialize_D) + tree_save(data, path) + ret = tree_load(path) + self.assertLen(ret, len(data)) + self.assertLess(jnp.linalg.norm(ret[-1].op - magic_val), 1e-5) + + jax.tree.flatten(data) + + def test_masked_reading(self): + path = self.create_tempdir() + data = [jnp.ones(1), {'world': [jnp.zeros(3), (jnp.ones(1), jnp.ones(2))]}, + 7 * jnp.ones(())] + tree_save(data, path) + for mask in [False, True]: + ret = tree_load(path, mask=mask) + expected = jax.tree.map(lambda x: None if not mask else x, data) + self.assertPyTreeEqual(ret, expected, is_leaf=lambda x: x is None) + + mask = [True, False, False] + expected = data[:1] + jax.tree.map(lambda x: None, data[1:]) + ret = tree_load(path, mask=mask) + self.assertPyTreeEqual(ret, expected, is_leaf=lambda x: x is None) + + mask = [True, True, False] + expected = data[:2] + jax.tree.map(lambda x: None, data[2:]) + ret = tree_load(path, mask=mask) + self.assertPyTreeEqual(ret, expected, is_leaf=lambda x: x is None) + + mask = [True, {'world': [True, (False, True)]}, False] + data[1]['world'][1] = (None, data[1]['world'][1][1]) + ret = tree_load(path, mask=mask) + self.assertPyTreeEqual(ret, expected, is_leaf=lambda x: x is None) + + # TODO(rdyro): Remove when serialization supports non-arrays + @parameterized.product(obj=[b'hello', 'hello', 1, 1.0, 1j]) + def test_serialization_works_for_arrays_only(self, obj): + path = self.create_tempdir() + data = [{'world': [jnp.zeros(3), (jnp.ones(1), jnp.ones(2))]}, obj] + msg = ('For serialization, all leaves must be either None or' + ' jax.Array-like objects.') + with self.assertRaisesRegex(ValueError, msg): + tree_save(data, path) + + def test_load_pytreedef(self): + path = self.create_tempdir() + data = [jnp.ones(1), {'world': [jnp.zeros(3), (jnp.ones(1), jnp.ones(2))]}, + 7 * jnp.ones(())] + tree_save(data, path) + pytreedef = tree_load_pytreedef(path) + expected_pytreedef = jax.tree.map( + lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), data) + self.assertPyTreeEqual(pytreedef, expected_pytreedef) + + @parameterized.product(data=[ + None, [None], [None, np.ones(())], + [None, {'world': [None, (np.ones(1), np.ones(2))]}, np.ones(())], + [None, {'world': [np.zeros(3), (None, np.ones(2))]}, None]]) + def test_save_and_load_null_leaves(self, data): + path = self.create_tempdir() + # TPUs might not have X64 enabled, so we need to convert to float32 + data = jax.tree.map(lambda x: jnp.array(x, dtype=jnp.float32), data) + tree_save(data, path) + pytreedef = tree_load_pytreedef(path) + is_leaf = lambda x: x is None + expected_pytreedef = jax.tree.map(lambda x: jax.ShapeDtypeStruct( + x.shape, x.dtype) if x is not None else x, data, is_leaf=is_leaf) + self.assertPyTreeEqual(pytreedef, expected_pytreedef) + load_data = tree_load(path) + load_leaves, load_struct = jax.tree.flatten(load_data, is_leaf=is_leaf) + expected_leaves, expected_struct = jax.tree.flatten(data, is_leaf=is_leaf) + self.assertEqual(load_struct, expected_struct) + self.assertLen(load_leaves, len(expected_leaves)) + for (l1, l2) in zip(load_leaves, expected_leaves): + if l1 is None: + self.assertIsNone(l2) + else: + self.assertArraysEqual(l1, l2) + + @parameterized.product(manually_broadcast_ts_specs=[True, False]) + def test_custom_ts_specs(self, manually_broadcast_ts_specs): + if ts_impl._TS_ARRAY_DRIVER == 'zarr': + self.skipTest('Skipping since this test assumes zarr is NOT the default') + path = self.create_tempdir() + data = [jnp.ones(()), (jnp.zeros(()), jnp.ones(())), None] + ts_spec = {'driver': 'zarr', 'metadata': {'shape': ()}} + if manually_broadcast_ts_specs: + ts_specs = [ts_spec, (ts_spec, None), None] # None ts_spec allowed + else: + ts_specs = ts_spec + tree_save(data, path, ts_specs=ts_specs) + load_data = tree_load(path, ts_specs=ts_specs) + self.assertPyTreeEqual(data, load_data) + with self.assertRaisesRegex(ValueError, + 'NOT_FOUND: Error opening "zarr3" driver:'): + _ = tree_load(path) # default attempts to open with zarr3 and fails + + def test_save_load_future_printable(self): + path = self.create_tempdir() + data = [jnp.ones(())] + save_fut = pytree_serialization.nonblocking_save(data, path) + str(save_fut) + save_fut.result() + load_fut = pytree_serialization.nonblocking_load( + path, shardings=_DEFAULT_SHARDING) + str(load_fut) + load_fut.result() + + def test_format_alone_not_supported(self): + # passing a format for a dtype not matching the dtype on disk will cause an + # XLA error (since formats can be dtype/bit-width specific), hence allow + # format only if dtype is also specified + path = self.create_tempdir() + data = jnp.arange(16 * 16, dtype=jnp.bfloat16).reshape((16, 16)) + sharding = NamedSharding(jtu.create_mesh((1, 1), ('x', 'y')), P('x', None)) + data: jax.Array = jax.device_put(data, sharding) + tree_save(data, path) + with self.assertRaisesRegex(NotImplementedError, + 'Deserialization with `Format` instead of' + ' `Sharding` is not currently supported.'): + pytree_serialization.load(path, shardings=data.format) + + def test_formats_support(self): + path = self.create_tempdir() + data = jnp.arange(16 * 16, dtype=jnp.float32).reshape((16, 16)) + data_bf16_format = jnp.arange(16 * 16, dtype=jnp.bfloat16).reshape( + (16, 16)).format + sharding = NamedSharding(jtu.create_mesh((1, 1), ('x', 'y')), P('x', None)) + data: jax.Array = jax.device_put(data, sharding) + tree_save(data, path) + pytree_serialization.load(path, shardings=jax.ShapeDtypeStruct( + data.shape, jnp.bfloat16, sharding=data_bf16_format)) + + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/jax/experimental/array_serialization/tensorstore_impl.py b/jax/experimental/array_serialization/tensorstore_impl.py new file mode 100644 index 000000000000..b2aa69ad1955 --- /dev/null +++ b/jax/experimental/array_serialization/tensorstore_impl.py @@ -0,0 +1,599 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from functools import partial +import functools +import os +from os import PathLike +import re +from typing import Any +from collections.abc import Awaitable, Callable, Sequence +import math +import logging + +import jax +from jax import numpy as jnp +from jax._src import array +from jax._src.layout import Format +from jax._src import typing +import numpy as np +import tensorstore as ts + +_TS_ARRAY_DRIVER = "zarr3" + +_TS_CONTEXT = ts.Context({ + 'file_io_concurrency': {'limit': 128}, + 'cache_pool': {'total_bytes_limit': 10_000_000_000}, # 10 GB RAM limit + 'cache_pool#remote': {'total_bytes_limit': 10_000_000_000}, + 'data_copy_concurrency': {'limit': 128} +}) +_TS_CHUNK_LAYOUT = ts.ChunkLayout({ + "chunk": {"elements": 100_000_000}, # 100M (800MB for float64) file size +}) + +_DEFAULT_BASE_DRIVER = 'file' +_PROCESS_DIR_FORMAT = "process_{}" +_FILE_SIZE_TARGET = 2 * 1024 ** 3 # 2 GB + +Future, Transaction = ts.Future, ts.Transaction + +logger = logging.getLogger(__name__) + +# Lifted from T5X. +class _LimitInFlightBytes: + """Limits host scratch memory usage when reading/writing checkpoints per process.""" + + def __init__(self, host_memory_bytes_limit: int): + self._max_bytes = host_memory_bytes_limit + self._available_bytes = host_memory_bytes_limit + self._cv = asyncio.Condition(lock=asyncio.Lock()) + + async def wait_for_bytes(self, requested_bytes): + if requested_bytes > self._max_bytes: + logger.debug("A single array item requests more bytes than we reserved" + " space for in the parallel pool: %d > %d. Increasing the" + " limit to %d.", requested_bytes, self._max_bytes, + requested_bytes) + bytes_currently_used = self._max_bytes - self._available_bytes + self._max_bytes = requested_bytes + self._available_bytes = self._max_bytes - bytes_currently_used + async with self._cv: + await self._cv.wait_for(lambda: self._available_bytes >= requested_bytes) + self._available_bytes -= requested_bytes + assert self._available_bytes >= 0 + + async def release_bytes(self, requested_bytes): + async with self._cv: + self._available_bytes += requested_bytes + assert self._available_bytes <= self._max_bytes + self._cv.notify_all() + +def is_tensorstore_spec_leaf(leaf: Any): + # TODO(rdyro): think of a better way to detect which leaf is a ts config + return leaf is None or (isinstance(leaf, dict) + and ("driver" in leaf or "kvstore" in leaf)) + +def _prime_factors(x: int) -> list[int]: + # find prime factors of axis sizes to help efficiently find divisor chunks + factors = [] + while x % 2 == 0: + factors.append(2) + x //= 2 + for i in range(3, int(math.sqrt(x)) + 1, 2): + while x % i == 0: + factors.append(i) + x //= i + if x > 1: + factors.append(x) + return sorted(factors) + +@functools.lru_cache(maxsize=1024) +def _compute_chunk_shape( + local_shape: Sequence[int], dtype: str | jnp.dtype, + file_size_target: int = _FILE_SIZE_TARGET) -> list[int]: + """Compute a chunk such that it divides the local shape and is less than + target file size. This helps the tensorstore kvstore driver limit the largest + file size on disk to below the ``file_size_target``. We compute a chunk with a + byte size at most 110% of the ``file_size_target``. + """ + local_shape = list(local_shape) + if len(local_shape) == 0 or math.prod(local_shape) == 0: + # a zero size array needs a non-zero chunk passed to tensorstore for compat. + return [max(z, 1) for z in local_shape] + total_size = math.prod(local_shape) * jnp.dtype(dtype).itemsize + axis_prime_factors = [_prime_factors(z) for z in local_shape] + chunk_shape, chunk_size = list(local_shape), total_size + # while chunk_size exceeds target size, reduce chunk_shape + while chunk_size > 1.1 * file_size_target: # 10% buffer + # 1. find the smallest axis divisor across all axes + chosen_axis_idx, chosen_divisor = None, 1 + for axis_idx in range(len(chunk_shape)): + if len(axis_prime_factors[axis_idx]) == 1: # ignore axes sizes == 1 + continue + if (chosen_axis_idx is None + or chosen_divisor > axis_prime_factors[axis_idx][0]): + chosen_axis_idx = axis_idx + chosen_divisor = axis_prime_factors[axis_idx][0] + # 2. if no divisor found, give up, return current chunk shape + if chosen_axis_idx is None: + return chunk_shape + # 3. remove the applied divisor from prime factors + prime_factors = axis_prime_factors[chosen_axis_idx] + prime_factors.pop(0) + # 4. apply the found divisor to reduce the chunk size + chunk_shape[chosen_axis_idx] //= chosen_divisor + chunk_size //= chosen_divisor + return chunk_shape + +def _get_tensorstore_metadata(arr, is_remote: bool = False, + file_size_target: int = _FILE_SIZE_TARGET, + driver: str = _TS_ARRAY_DRIVER) -> dict[str, Any]: + global_shape, dtype = arr.shape, arr.dtype + if isinstance(arr, jax.Array): + local_shape = arr.sharding.shard_shape(global_shape) + else: # np.ndarray + local_shape = global_shape + return _get_tensorstore_metadata_cached(global_shape, dtype, local_shape, + is_remote, file_size_target, driver) + +@functools.lru_cache(maxsize=1024) +def _get_tensorstore_metadata_cached( + global_shape: Sequence[int], dtype: jnp.dtype, local_shape: Sequence[int], + is_remote: bool = False, file_size_target: int = _FILE_SIZE_TARGET, + driver: str = _TS_ARRAY_DRIVER) -> dict[str, Any]: + if driver == "zarr3": + codecs = ([{"name": "zstd"}] if is_remote else []) + return { + 'codecs': codecs, + 'shape': global_shape, + 'data_type': jnp.dtype(dtype).name, + 'chunk_grid': { + 'name': 'regular', + 'configuration': {'chunk_shape': _compute_chunk_shape( + local_shape, dtype, file_size_target=file_size_target)} + } + } + elif driver == "zarr": # in zarr dtype goes in the base spec + return {'compressor': {'id': 'zstd'}, 'shape': global_shape, + 'chunks': np.array(np.maximum(1, local_shape)).tolist()} + else: + raise ValueError(f"Unsupported driver: {driver}") + +_divides = lambda x, y: np.all((np.array(x) % np.array(y)) == 0) + +def merge_nested_ts_specs(dict1: dict[Any, Any], dict2: dict[Any, Any] | None): + """Merge two ts specs, dict2 takes precedence.""" + if dict2 is None: # nothing to do + return dict1 + # TODO(rdyro): this is an opinionated merge, we should get user feedback + # merge kvstore explicitly + kvstore = dict1.get("kvstore", {}) | dict2.get("kvstore", {}) + return dict1 | dict(dict2, kvstore=kvstore) # merge with dict2 preferred + +def verify_tensorstore_spec(spec: dict[str, Any], arr: jax.Array | None, + path: str | os.PathLike[str], ocdbt: bool, + check_metadata: bool = True) -> None: + """Verify the minimum requirements for a tensorstore spec.""" + if ocdbt: + if spec.get("kvstore", {}).get("driver", "") != "ocdbt": + raise ValueError(f"Expected ocdbt driver, got {spec=}") + if check_metadata: + if arr is None: + raise ValueError("Array is required for metadata verification.") + metadata = spec['metadata'] + if spec.get("driver", "") == "zarr3": + if metadata['data_type'] != jnp.dtype(arr.dtype).name: + raise ValueError(f"Provided dtype ({metadata['data_type']=}) doesn't" + f" match ({arr.dtype=})") + if 'shape' in metadata: + if metadata['shape'] != arr.shape: + raise ValueError(f"Provided shape ({metadata['shape']=}) doesn't match" + f" ({arr.shape=})") + if isinstance(arr, jax.Array): + local_shape = arr.sharding.shard_shape(arr.shape) + else: # np.ndarray + local_shape = arr.shape # pytype: disable=attribute-error + if spec.get("driver", "") == "zarr3": + chunk_shape = metadata['chunk_grid']['configuration']['chunk_shape'] + if not _divides(local_shape, chunk_shape): + raise ValueError(f"Provided chunk shape {chunk_shape} does not divide" + f" the local shape of the array {local_shape}") + # check path is still the same one we expect + if ocdbt: + found_path = spec["kvstore"]['base']['path'] + else: + found_path = spec["kvstore"]['path'] + if str(found_path) != str(path): + raise ValueError(f"Provided {path=} does not match the spec path:" + f" {spec['kvstore']}") + +def _spec_has_metadata(tree): + if not isinstance(tree, dict): + return False + return 'metadata' in tree or any( + _spec_has_metadata(subtree) for _, subtree in tree.items()) + +def _get_kvstore_for_gcs(ckpt_path: str): + m = re.fullmatch('^gs://([^/]*)/(.*)$', ckpt_path) + if m is None: + raise ValueError('The ckpt_path should contain the bucket name and the ' + f'file path inside the bucket. Got: {ckpt_path}') + bucket = m.group(1) + path_without_bucket = m.group(2) + return {'driver': 'gcs', 'bucket': bucket, 'path': path_without_bucket} + +def _get_kvstore_for_s3(ckpt_path: str): + m = re.fullmatch('^s3://([^/]*)/(.*)$', ckpt_path, re.DOTALL) + if m is None: + raise ValueError('The ckpt_path should contain the bucket name and the ' + f'file path inside the bucket. Got: {ckpt_path}') + bucket = m.group(1) + path_without_bucket = m.group(2) + return {'driver': 's3', 'bucket': bucket, 'path': path_without_bucket} + +def get_tensorstore_spec( + ckpt_path: str | PathLike[str], ocdbt: bool = True, + process_idx: int | None = None, arr: jax.Array | None = None, + driver: str = _TS_ARRAY_DRIVER) -> dict[str, Any]: + + # Normalize path to exclude trailing '/'. In GCS path case, normpath will + # replace a the double '//' with a single '/' and we need to restore the + # filesystem type:// prefix for GCS (gs://) and S3 paths (s3://) + ckpt_path = os.path.normpath(str(ckpt_path)) + ckpt_path = re.sub(r"^([a-z]+):/", r"\1://", ckpt_path) + + # in cases of multi-process writes, we need to write to a different location + # for each process and finally created a combined symlink to the final + # location, tensorstore can do this via ts.KvStore.experimental_copy_range_to + if process_idx is not None: + _parent, _name = os.path.split(ckpt_path) + ckpt_path = os.path.join(_parent, _PROCESS_DIR_FORMAT.format(process_idx), + _name) + + is_gcs_path = ckpt_path.startswith('gs://') + is_s3_path = ckpt_path.startswith('s3://') + spec = {'driver': driver, 'kvstore': {}} + + # use a combined OCDBT store, the actual path is the parent path + # the name (filename/last part of the path) is the key in the ocdbt kvstore + entry_key = None + if ocdbt: + (ckpt_path, entry_key), org_ckpt_path = os.path.split(ckpt_path), ckpt_path + if is_gcs_path: + m = re.fullmatch('^gs://([^/]*)/(.*)$', ckpt_path) + elif is_s3_path: + m = re.fullmatch('^s3://([^/]*)/(.*)$', ckpt_path) + else: + m = re.match("a", "a") # make it True + if m is None: + raise ValueError('Using OCDBT requires the bucket name, the directory' + ' name and the array name, your path is: ' + f'{org_ckpt_path}') + + if is_gcs_path: + base_kvstore = _get_kvstore_for_gcs(ckpt_path) + elif is_s3_path: + base_kvstore = _get_kvstore_for_s3(ckpt_path) + else: + base_kvstore = {'driver': _DEFAULT_BASE_DRIVER, 'path': ckpt_path} + + if ocdbt: + if not is_gcs_path and not is_s3_path and not os.path.isabs(ckpt_path): + raise ValueError(f'Checkpoint path should be absolute. Got {ckpt_path}') + spec['kvstore'] = {'driver': 'ocdbt', 'base': base_kvstore, + 'path': entry_key} + else: + spec['kvstore'] = base_kvstore + # done writing tensorstore spec based on destination path + # optionally, if array is provided, we can add metadata to the spec + if arr is not None: + spec["metadata"] = _get_tensorstore_metadata( + arr, driver=str(spec["driver"])) + return spec + +async def _create_async_array_from_callback( + global_shape: array.Shape, + dtype: str | jnp.dtype | None, + inp_sharding: jax.sharding.Sharding, + data_callback: Callable[[array.Index, jax.Device], Awaitable[jax.Array]], +): + device_to_index_map = inp_sharding.devices_indices_map(global_shape) + addressable_da = inp_sharding._addressable_device_assignment + future_arrays = [data_callback(device_to_index_map[d], d) + for d in addressable_da] + dbs = await asyncio.gather(*future_arrays) + return array.make_array_from_single_device_arrays( + global_shape, inp_sharding, dbs, dtype=dtype) + +async def _transfer_shard_to_host(shard: array.Shard) -> np.ndarray: + data = shard.data + has_pinned_host = any( + m.kind == "pinned_host" for m in shard.device.addressable_memories()) + if has_pinned_host: + # If available, transfer to pinned host memory + sharding = jax.sharding.SingleDeviceSharding(shard.device, + memory_kind="pinned_host") + data = jax.device_put(data, sharding) + else: + data.copy_to_host_async() + # Allow other transfers to be scheduled simultaneously + await asyncio.sleep(0) + # Ensure that jax.Array's internal numpy array can be zero-copied. Tensorstore + # implicitly converts the written data to a numpy array, and would otherwise + # silently copy host-to-host. + return np.array(data, copy=False) + +async def combine_kvstores(combined_kvstore: dict[str, Any], + kvstores: list[dict[str, Any]], + context: ts.Context | dict[str, Any] = _TS_CONTEXT + ) -> None: + """Merge a list of kvstores into a single kvstore. NOT multi-process safe.""" + combined_fut = ts.KvStore.open(combined_kvstore, context=context) + kvstores_futs = [ts.KvStore.open(kvstore, context=context) + for kvstore in kvstores] + combined, kvstores = await asyncio.gather(combined_fut, + asyncio.gather(*kvstores_futs)) + tx = ts.Transaction() + await asyncio.gather(*[kvstore.experimental_copy_range_to( + combined.with_transaction(tx)) for kvstore in kvstores]) + await tx.commit_async() + +async def async_serialize( + arr_inp, + tensorstore_spec, + commit_future=None, + context=_TS_CONTEXT, + chunk_layout=_TS_CHUNK_LAYOUT, + primary_host: int | None = None, + replica_id: int = 0, + transaction: ts.Transaction | None = None, +): + """Serialize an array using TensorStore. + + Args: + arr_inp: The array to serialize. + tensorstore_spec: The tensorstore spec to use. + commit_future: A list of futures that will be appended to. The futures can + be awaited asynchronously. If None, the futures will be awaited + synchronously by this method. + context: ts.Context instance. + primary_host: Primary host, which indicates the host that will be treated as + the "leader". If None, all hosts are treated as the primary. DO NOT USE + unless you are sure you know what you are doing. + replica_id: Allows overriding the shard replica id that will be saved. DO + NOT USE unless you are sure you know what you are doing. + transaction: TensorStore transaction to use for opening and writing the + array. If not specified, a non-transactional write will be used. + """ + if (isinstance(arr_inp, array.ArrayImpl) and jax.process_count() > 1 and + arr_inp.is_fully_addressable): + raise ValueError( + f'Passing fully addressable arrays to a multiprocess ' + f'serialization is not allowed, as this may lead to a race condition ' + f'between processes. Serialization have failed for the array with ' + f'the path from kvstore: "{tensorstore_spec["kvstore"]}".') + + # 'metadata' may not be present at the top level (for example, if we are using + # a 'cast' driver). + if not _spec_has_metadata(tensorstore_spec): + tensorstore_spec['metadata'] = _get_tensorstore_metadata( + arr_inp, driver=tensorstore_spec['driver']) + ## zarr driver requires specifying the dtype in the spec base + if tensorstore_spec['driver'] == 'zarr' and 'dtype' not in tensorstore_spec: + tensorstore_spec['dtype'] = jnp.dtype(arr_inp.dtype).name + + # If primary_host is None, all hosts will checkpoint. This is used + # for checkpointing to local filesystem. + if primary_host is None or jax.process_index() == primary_host: + open_future = ts.open( + ts.Spec(tensorstore_spec), + create=True, + open=True, + context=context, + chunk_layout=chunk_layout, + transaction=transaction, + ) + # Asynchronous case. + if commit_future is not None: + assert isinstance(commit_future, list) + commit_future.append(open_future) + else: + await open_future + + # `ts.open` runs twice for process `primary_host` because for the first time, + # we just get the future to be awaited upon in the background thread. The + # second one runs with `assume_metadata=True` which does no I/O operation and + # returns the tensorstore object. + # For every process other than `primary_host`, we open with + # `assume_metadata=True`. + t = await ts.open( + ts.Spec(tensorstore_spec), + open=True, + assume_metadata=True, + context=context, + chunk_layout=chunk_layout, + transaction=transaction, + ) + + async def _write_array(shard): + if shard.replica_id == replica_id: + data = await _transfer_shard_to_host(shard) + write_future = t[shard.index].write( + data, + # Avoid additional copy of input array into the TensorStore chunk + # cache. If `arr_inp` is a jax.Array, the result of converting + # it to a NumPy array, as is done internally by TensorStore, is + # guaranteed to be immutable and therefore it is safe to retain a + # reference indefinitely. + can_reference_source_data_indefinitely=isinstance( + arr_inp, array.ArrayImpl + ), + ) + if commit_future is not None: + assert isinstance(commit_future, list) + commit_future.append(write_future.commit) + await write_future.copy + else: + await write_future.commit + + local_shards = arr_inp.addressable_shards + future_write_state = jax.tree_util.tree_map(_write_array, local_shards) + return await asyncio.gather(*future_write_state) + + +# TODO(rdyro): Remove this function. +def _run_serialization(arrays, tensorstore_specs): + """Legacy serialization of a list of arrays.""" + async def _run_serializer(): + future_writer = jax.tree_util.tree_map(async_serialize, arrays, tensorstore_specs) + return await asyncio.gather(*future_writer) + asyncio.run(_run_serializer()) + + +def estimate_read_memory_footprint(t: ts.TensorStore, + domain: ts.IndexDomain) -> int: + rank = t.rank + num_bytes = t.dtype.numpy_dtype.itemsize + chunk_template = t.chunk_layout.read_chunk_template + if domain is None: + domain = t.domain + origin = domain.origin + shape = domain.shape + chunk_origin = chunk_template.origin + chunk_shape = chunk_template.shape + + # Some TensorStore drivers are not chunked, e.g. the inline 'array' driver. + # For those, instead of returning a near-infinite memory footprint, estimate + # the footprint as the entire shape. + for i in range(rank): + if not chunk_template[i].finite: + return domain.size * num_bytes + + # Otherwise, if we have a chunked driver, estimate based on chunk size. + for i in range(rank): + origin_value = origin[i] + chunk_origin_value = chunk_origin[i] + chunk_size = chunk_shape[i] + lower = origin_value - chunk_origin_value + upper = origin_value + shape[i] - chunk_origin_value + lower_aligned = lower // chunk_size * chunk_size + upper_aligned = -(-upper // chunk_size) * chunk_size + num_bytes *= (upper_aligned - lower_aligned) + + return num_bytes + + +async def async_deserialize( + in_type: jax.sharding.Sharding | Format | jax.ShapeDtypeStruct, + tensorstore_spec: ts.Spec | dict[str, Any], + global_shape: Sequence[int] | None = None, + dtype=None, + byte_limiter: _LimitInFlightBytes | None = None, + context=_TS_CONTEXT, + chunk_layout=_TS_CHUNK_LAYOUT, + assume_metadata: bool = False, +): + """Main performant deserialization routine for arrays using tensorstore.""" + if isinstance(in_type, Format): + in_sharding, layout = in_type.sharding, in_type.layout + elif isinstance(in_type, jax.ShapeDtypeStruct): + dtype = in_type.dtype if dtype is None else dtype + in_sharding = in_type.sharding + layout = in_type.format.layout + else: + if not isinstance(in_type, jax.sharding.Sharding): + raise TypeError( + 'sharding passed to deserialization should be specified, concrete and' + f' an instance of `jax.sharding.Sharding`. Got {in_type}') + in_sharding = in_type + layout = None + assert isinstance(in_sharding, jax.sharding.Sharding) + t = await ts.open( + tensorstore_spec, + open=True, + assume_metadata=assume_metadata, + context=context, + chunk_layout=chunk_layout, + ) + shape = t.shape if global_shape is None else global_shape + dtype = dtype if dtype is not None else t.dtype.numpy_dtype + new_shard_shape = in_sharding.shard_shape(tuple(shape)) + + async def cb(index: array.Index, device: jax.Device): + requested_domain = ts.IndexTransform(input_shape=shape)[index].domain + restricted_domain = t.domain.intersect(requested_domain) + requested_bytes = estimate_read_memory_footprint(t, restricted_domain) + # Limit the bytes read for every shard. + if byte_limiter is not None: + await byte_limiter.wait_for_bytes(requested_bytes) + # This maybe needed because the shape the array was saved with is smaller + # than the requested shape of the array in which it will be reloaded. So + # the extra values will be filled with 0s. + out = np.zeros(new_shard_shape, dtype=t.dtype.numpy_dtype) + await ts.array(out)[ts.d[:].translate_to[requested_domain.origin]][ + restricted_domain].write(t[restricted_domain]) + if dtype is not None: + # Cast while reloading on process to avoid 2 copies on device if the + # casting is done on device. + out = out.astype(dtype) + # Convert to jnp array so that layouts are initialized properly for + # sub-byte dtypes. + # TODO(yashkatariya): This is a band-aid fix. Figure out a better way to + # make this work. + if out.dtype == jnp.int4: + out = jnp.asarray(out) # type: ignore + result = jax.device_put( + out, Format(layout, jax.sharding.SingleDeviceSharding(device))) + if byte_limiter is not None: + # NB: `out` actually might not be ready for garbage collection by the + # time we call release_bytes . Thus peak memory usage still might grow + # beyond what byte_limiter limit suggests it should. The simplest option + # would be to call `result.block_until_ready()`` here. However it + # also comes with ~15-20% perf penalty as we would be waiting for CPU->GPU + # transfer instead of loading data. In the future, if memory pressure + # becomes a problem, we can instead instrument bytelimiter to + # keep track of all in-flight tensors and only block_until_ready, if byte + # limiter hits the limit to get reduced memory usage, without losing + # performance in common use cases. + await byte_limiter.release_bytes(requested_bytes) + return result + + # for deserialization canonicalize dtype to a dtype representable in jax + return await _create_async_array_from_callback( + tuple(shape), jax.dtypes.canonicalize_dtype(dtype), in_sharding, cb) + + +# TODO(rdyro): Remove this function. +def _run_deserialization(shardings: Sequence[jax.sharding.Sharding | Format], + tensorstore_specs: Sequence[dict[str, Any] | ts.Spec], + global_shapes: Sequence[array.Shape] | None = None, + dtypes: Sequence[typing.DTypeLike] | None = None, + concurrent_gb: int = 32): + """Legacy deserialization of a list of arrays. Optionally pass global_shapes + and dtypes for type-checking. + """ + concurrent_bytes = concurrent_gb * 10**9 + + async def _run_deserializer(): + # Object should be created once per process. + byte_limiter = _LimitInFlightBytes(concurrent_bytes) + + future_arrays = jax.tree_util.tree_map( + partial(async_deserialize, byte_limiter=byte_limiter), + list(shardings), list(tensorstore_specs), + [None] * len(tensorstore_specs) if global_shapes is None else global_shapes, + [None] * len(tensorstore_specs) if dtypes is None else dtypes) + return await asyncio.gather(*future_arrays) + return asyncio.run(_run_deserializer()) diff --git a/jax/experimental/attrs.py b/jax/experimental/attrs.py deleted file mode 100644 index 4e1dc4b8f493..000000000000 --- a/jax/experimental/attrs.py +++ /dev/null @@ -1,236 +0,0 @@ -# Copyright 2024 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from typing import Any, Callable - -from jax._src import core -from jax._src import source_info_util -from jax._src import api_util -from jax._src import linear_util as lu -from jax._src.ad_util import (Zero) -from jax._src.api_util import flatten_fun_nokwargs -from jax._src.interpreters import ad -from jax._src.interpreters import partial_eval as pe -from jax._src.tree_util import (tree_flatten, tree_unflatten, tree_structure, - treedef_tuple) -from jax._src.util import unzip2, safe_map, safe_zip, split_list -from jax._src.dtypes import dtype, float0 - -map, unsafe_map = safe_map, map -zip, unsafe_zip = safe_zip, zip - -JaxVal = Any -Pytree = Any - -register = api_util.register_class_with_attrs - -def jax_getattr(obj: Any, attr: str): - with core.take_current_trace() as t: - return t.process_getattr(obj, attr) - -def jax_setattr(obj: Any, attr: str, val: Pytree): - with core.take_current_trace() as t: - return t.process_setattr(obj, attr, val) - -def _getattr_impl(_, obj, attr): - return getattr(obj, attr) -core.EvalTrace.process_getattr = _getattr_impl - -def _setattr_impl(_, obj, attr, val): - setattr(obj, attr, val) -core.EvalTrace.process_setattr = _setattr_impl - -def _ensure_tracked(trace: pe.DynamicJaxprTrace, obj: Any, attr: str): - frame = trace.frame - - def new_tracer(x): - aval = core.get_aval(x) - tracer = pe.DynamicJaxprTracer(trace, aval, pe.source_info_util.current()) - var = frame.tracer_to_var[id(tracer)] = frame.newvar(aval) - frame.attrs_vars.append(var) - frame.tracers.append(tracer) - return tracer - - if (obj, attr) not in frame.attrs_tracked: - init_val = getattr(obj, attr) - frame.attrs_inits.append(init_val) - init_vals, init_tree = tree_flatten(init_val) - tracers = map(new_tracer, init_vals) - setattr(obj, attr, tree_unflatten(init_tree, tracers)) - frame.attrs_tracked.append((obj, attr)) -pe.DynamicJaxprTrace._ensure_tracked = _ensure_tracked - -def _getattr_staging(trace, obj, attr): - trace._ensure_tracked(obj, attr) - return getattr(obj, attr) -pe.DynamicJaxprTrace.process_getattr = _getattr_staging - -def _setattr_staging(trace, obj, attr, val): - trace._ensure_tracked(obj, attr) - setattr(obj, attr, val) -pe.DynamicJaxprTrace.process_setattr = _setattr_staging - - -def jvp(f, primals, tangents, attr_tangents): - attrs, attr_tangents = unzip2(((o, a), t) for o, a, t in attr_tangents) - attr_primals = tuple(jax_getattr(o, a) for o, a in attrs) - primals_flat, in_tree = tree_flatten((attr_primals, *primals)) - tangents_flat, in_tree_ = tree_flatten((attr_tangents, *tangents)) - if in_tree != in_tree_: raise Exception - dbg = api_util.debug_info("attrs_jvp", f, primals, {}) - f_, out_tree = flatten_fun_nokwargs( - _set_attrs(lu.wrap_init(f, debug_info=dbg), attrs), in_tree) - out_primals_flat, out_tangents_flat, tangent_attrs_out = _jvp(f_).call_wrapped( - primals_flat, tangents_flat) - out_primals = tree_unflatten(out_tree(), out_primals_flat) - out_tangents = tree_unflatten(out_tree(), out_tangents_flat) - return out_primals, out_tangents, tangent_attrs_out - -@lu.transformation2 -def _set_attrs(f, attrs, attr_vals, *args): - for (o, a), x in zip(attrs, attr_vals): - jax_setattr(o, a, x) - return f(*args) - -def _jvp(fun: lu.WrappedFun): - return jvpfun2(jvp_subtrace2(fun)) - -@lu.transformation2 -def jvpfun2(f, primals, tangents): - tag = core.TraceTag() - tangents = [Zero.from_primal_value(t) if not isinstance(t, Zero) - and dtype(t) == float0 else t for t in tangents] - ctx = source_info_util.transform_name_stack('jvp') - with ctx: - out_primals, out_tangents, tangent_attrs_out = f(tag, primals, tangents) - return out_primals, out_tangents, tangent_attrs_out - -@lu.transformation2 -def jvp_subtrace2(f, tag, primals, tangents): - with core.take_current_trace() as parent_trace: - trace = ad.JVPTrace(parent_trace, tag) - tag.attrs_tracked = [] # attrs written to - in_tracers = [ad.JVPTracer(trace, x, t) if type(t) is not ad.Zero else x - for x, t in zip(primals, tangents)] - with core.set_current_trace(trace): - ans = f(*in_tracers) - out_primals, out_tangents = unzip2(map(trace.to_primal_tangent_pair, ans)) - tangent_attrs_out = [] - for (obj, name) in tag.attrs_tracked: - primal, tangent = trace.to_primal_tangent_pair(jax_getattr(obj, name)) - jax_setattr(obj, name, primal) - if type(tangent) is not ad.Zero: - tangent_attrs_out.append((obj, name, tangent)) - del tag.attrs_tracked - return out_primals, out_tangents, tangent_attrs_out - -def _setattr_jvp(trace, obj, attr, maybe_tracer): - primal, tangent = trace.to_primal_tangent_pair(maybe_tracer) - if isinstance(tangent, ad.Zero): - return setattr(obj, attr, primal) - if (obj, attr) not in trace.tag.attrs_tracked: - trace.tag.attrs_tracked.append((obj, attr)) - return setattr(obj, attr, ad.JVPTracer(trace, primal, tangent)) -ad.JVPTrace.process_setattr = _setattr_jvp - -def _getattr_jvp(trace, obj, attr): - return getattr(obj, attr) -ad.JVPTrace.process_getattr = _getattr_jvp - -ad.LinearizeTrace.process_setattr = _setattr_jvp -ad.LinearizeTrace.process_getattr = _getattr_jvp - -def linearize(f: Callable, *primals, attrs: list[tuple[Any, str]] = []): - attr_primals = [jax_getattr(o, a) for o, a in attrs] - attr_avals = [core.get_aval(p) for p in attr_primals] - primals_flat, in_tree = tree_flatten(primals) - tree = treedef_tuple((tree_structure(attr_primals), *in_tree.children())) - dbg = api_util.debug_info("attrs linearize", f, primals, {}) - f_, out_tree = flatten_fun_nokwargs( - _set_attrs(lu.wrap_init(f, debug_info=dbg), attrs), tree) - primal_out, out_pvals, jaxpr, consts, attrs_out = _linearize( - f_, *attr_primals, *primals_flat) - f_lin = _lin_wrap(jaxpr, consts, out_pvals, attr_avals, (in_tree, out_tree()), - attrs, attrs_out) - return tree_unflatten(out_tree(), primal_out), f_lin - -def _linearize(traceable: lu.WrappedFun, *primals): - jvpfun, attrs = _split_attrs(_jvp(traceable)) - in_pvals = (tuple(pe.PartialVal.known(p) for p in primals) - + tuple(pe.PartialVal.unknown(core.get_aval(p).to_tangent_aval()) - for p in primals)) - _, in_tree = tree_flatten((primals, primals)) - jvpfun_flat, out_tree = flatten_fun_nokwargs(jvpfun, in_tree) - jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals) - out_primals_pvals, out_tangents_pvals, out_tangent_attr_pvals = \ - tree_unflatten(out_tree(), out_pvals) - out_primals_consts = [pval.get_known() for pval in out_primals_pvals] - return (out_primals_consts, [*out_tangents_pvals, *out_tangent_attr_pvals], - jaxpr, consts, attrs()) - -@lu.transformation_with_aux2 -def _split_attrs(f, store, *args, **kwargs): - primals, tangents, tangent_attrs = f(*args, **kwargs) - attrs, tangent_attr_vals = unzip2(((o, a), t) for o, a, t in tangent_attrs) - store.store(attrs) - return primals, tangents, tangent_attr_vals - -def _lin_wrap(jaxpr, consts, out_pvals, attr_avals, io_tree, in_attrs, out_attrs): - in_tree, out_tree = io_tree - def f_lin(*tangents, attr_tangents): - if set(attr_tangents) - set(in_attrs): raise Exception - tangents_, in_tree_ = tree_flatten(tangents) - assert in_tree == in_tree_ - attr_tangents_ = [attr_tangents.get(a, ad.Zero(aval)) - for a, aval in zip(in_attrs, attr_avals)] - out = core.eval_jaxpr(jaxpr, consts, *attr_tangents_, *tangents_) - out_ = iter(out) - out = [p.get_known() if p.is_known() else next(out_) for p in out_pvals] - assert next(out_, None) is None - tangents_out, attr_tangents_out = split_list(out, [len(out)-len(out_attrs)]) - out_ct = tree_unflatten(out_tree, tangents_out) - return out_ct, dict(zip(out_attrs, attr_tangents_out)) - return f_lin - - -def vjp(f, *primals, attrs: list[tuple[Any, str]] = []): - attr_primals = [jax_getattr(o, a) for o, a in attrs] - primals_flat, in_tree = tree_flatten(primals) - tree = treedef_tuple((tree_structure(attr_primals), *in_tree.children())) - dbg = api_util.debug_info("attrs vjp", f, primals, {}) - f_, out_tree = flatten_fun_nokwargs( - _set_attrs(lu.wrap_init(f, debug_info=dbg), attrs), tree) - primal_out, out_pvals, jaxpr, consts, attrs_out = _linearize( - f_, *attr_primals, *primals_flat) - attr_avals = [core.get_aval(jax_getattr(o, a)).to_tangent_aval() - for o, a in attrs_out] - f_vjp = _vjp_wrap(jaxpr, consts, out_pvals, attr_avals, (in_tree, out_tree()), - attrs, attrs_out) - return tree_unflatten(out_tree(), primal_out), f_vjp - -def _vjp_wrap(jaxpr, consts, out_pvals, attr_avals, io_tree, in_attrs, out_attrs): - in_tree, out_tree = io_tree - dummies = [ad.UndefinedPrimal(v.aval) for v in jaxpr.invars] - def f_vjp(out_ct, *, attr_cotangents: dict[tuple[Any, str], JaxVal] = {}): - out_cts, out_tree_ = tree_flatten(out_ct) - assert out_tree == out_tree_ - attr_cts = [attr_cotangents.get(a, ad.Zero(aval)) - for a, aval in zip(out_attrs, attr_avals)] - out = ad.backward_pass(jaxpr, (), consts, dummies, (*out_cts, *attr_cts)) - in_attr_bars, arg_cts = split_list(out, [len(in_attrs)]) - args_ct = tree_unflatten(in_tree, map(ad.instantiate_zeros, arg_cts)) - return args_ct, dict(zip(in_attrs, in_attr_bars)) - return f_vjp diff --git a/jax/experimental/buffer_callback.py b/jax/experimental/buffer_callback.py new file mode 100644 index 000000000000..f919cfa10208 --- /dev/null +++ b/jax/experimental/buffer_callback.py @@ -0,0 +1,20 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from jax._src.buffer_callback import ( + Buffer as Buffer, + ExecutionContext as ExecutionContext, + ExecutionStage as ExecutionStage, + buffer_callback as buffer_callback, +) diff --git a/jax/experimental/colocated_python/api.py b/jax/experimental/colocated_python/api.py index b855bba48abb..5cfa942f5cbf 100644 --- a/jax/experimental/colocated_python/api.py +++ b/jax/experimental/colocated_python/api.py @@ -16,18 +16,70 @@ from __future__ import annotations import collections -from typing import Any, Callable, Sequence, Type +from typing import Any, overload +from collections.abc import Callable, Sequence import jax from jax._src import api_util +from jax._src import util from jax.experimental.colocated_python.func import make_callable from jax.experimental.colocated_python.obj import wrap_class +import numpy as np +@overload def colocated_cpu_devices( - devices: Sequence[jax.Device], + devices_or_mesh: Sequence[jax.Device], +) -> Sequence[jax.Device]: + ... + + +@overload +def colocated_cpu_devices( + devices_or_mesh: jax.sharding.Mesh, +) -> jax.sharding.Mesh: + ... + + +def colocated_cpu_devices(devices_or_mesh): + """Finds devices or a mesh that has CPU devices colocated with the given devices or mesh. + + An accelerator device often accompanies a CPU device that is on the same host. + Furthermore, when a single host has multiple accelerator devices, there can be + multiple CPU devices, each of which is associated with one of the accelerator + devices with a 1:1 correspondence. + + This function finds the colocated CPU devices for the given devices or mesh. + When the input is a mesh, the returned value is another mesh that has the same + shape as the input mesh but has colocated CPU devices. If an input device is + already a CPU device, it is returned as-is. + + It preserves ordering. The output CPU device at index i is associated with the + input accelerator device at index i. + + Args: + devices_or_mesh: A tuple of devices or a mesh. + + Returns: + A tuple of devices or a mesh that has the colocated CPU devices. + """ + if isinstance(devices_or_mesh, jax.sharding.Mesh): + return _colocated_cpu_mesh_cached(devices_or_mesh) + + if not isinstance(devices_or_mesh, tuple): + devices_or_mesh = tuple(devices_or_mesh) + try: + return _colocated_cpu_devices_cached(devices_or_mesh) + except (ValueError, AttributeError): + return _colocated_cpu_devices_cached_fallback_to_cpu_backend( + devices_or_mesh + ) + + +@util.cache(max_size=1024, trace_context_in_key=False) +def _colocated_cpu_devices_cached( + devices: tuple[jax.Device, ...], ) -> Sequence[jax.Device]: - """Finds CPU devices colocated with the given devices.""" cpu_devices_by_colocation_id = collections.defaultdict(list) for device in devices[0].client._get_all_devices(): # pylint: disable=protected-access if device.device_kind == "cpu": @@ -49,13 +101,81 @@ def colocated_cpu_devices( return colocated_cpu_devices -def colocated_python(fun: Callable[..., Any]) -> Callable[..., Any]: - """Executes the given Python function on the same devices as the arguments.""" +@util.cache(max_size=1024, trace_context_in_key=False) +def _colocated_cpu_devices_cached_fallback_to_cpu_backend( + devices: tuple[jax.Device, ...], +) -> Sequence[jax.Device]: + # TODO(hyeontaek): Remove this fallback path once a PjRt-IFRT backend defines + # CPU devices by its own instead of using a separate CPU backend. + if devices[0].device_kind == "cpu": + # Use the devices from the backend of an original device if it defines CPU + # devices. + cpu_backend_devices = [d for d in devices[0].client._get_all_devices() + if d.device_kind == "cpu"] + else: + # PjRt-IFRT on a non-CPU platform currently defines CPU devices on a separae + # CPU backend. + cpu_backend_devices = jax.devices(backend="cpu") + device_index_map = {device.id: i for i, device in enumerate(jax.devices())} + + available_devices = devices[: min(len(cpu_backend_devices), len(devices))] + return [ + cpu_backend_devices[device_index_map[d.id]] for d in available_devices + ] + + +@util.cache(max_size=1024, trace_context_in_key=False) +def _colocated_cpu_mesh_cached(mesh: jax.sharding.Mesh) -> jax.sharding.Mesh: + """Returns a CPU mesh that is similar to the given mesh but has colocated CPU devices.""" + # Finding colocated CPU devices reuses the cache of `colocated_cpu_devices` + # called with devices. `_colocated_cpu_mesh` itself is also cached to avoid + # creating a new `Mesh` object repeatedly. + flat_cpu_devices = colocated_cpu_devices(tuple(mesh.devices.flat)) + return jax.sharding.Mesh( + np.array(flat_cpu_devices).reshape(mesh.axis_sizes), + mesh.axis_names, + axis_types=mesh.axis_types, + ) + + +def colocated_python(fun: Callable[..., Any]): + """Executes the given Python function on the same devices as the arguments. + + The returned colocated Python callable lets the user run a serializable Python + function on the same devices as the arguments, potentially on remote hosts. + + Python callable implements `specialize` and `__call__` methods. See their + docstrings for details and https://docs.jax.dev/en/latest/notebooks/colocated-python.html + for examples. + + Args: + fun: An original function to wrap as an I/O callable. + + Returns: + Colocated Python callable with no initial specialization. + """ return make_callable( fun, api_util.fun_sourceinfo(fun), api_util.fun_signature(fun) ) -def colocated_python_class(cls: Type[object]) -> Type[object]: - """Executes the given Python class methods on the same devices as the arguments.""" +def colocated_python_class(cls: type[object]) -> type[object]: + """Creates a wrapper class that executes the given Python class methods on the same devices as the arguments. + + The wrapper class exposes the returned type's methods, and can be instantiated + on JAX. An actual object will be instantiated on the host of the devices of + the arguments' when a method of the wrapper instance is called for the first + time. + + The actual object will persist while the wrapper object is alive, and will be + destroyed asynchronously when the wrapper object is destroyed. Note that if + the wrapper object is destroyed immediately without any method call, actual + objects will not be created. + + Args: + cls: The class to wrap as a colocated Python object. + + Returns: + Wrapper class. + """ return wrap_class(cls, api_util.fun_sourceinfo(cls)) diff --git a/jax/experimental/colocated_python/func.py b/jax/experimental/colocated_python/func.py index effca1fe77b7..220f0cfdf540 100644 --- a/jax/experimental/colocated_python/func.py +++ b/jax/experimental/colocated_python/func.py @@ -15,11 +15,14 @@ from __future__ import annotations +from collections.abc import Callable, Sequence import dataclasses import inspect import random import threading -from typing import Any, Callable, Sequence +from typing import Any +import uuid +import weakref import jax from jax._src import api @@ -30,7 +33,8 @@ from jax._src.traceback_util import api_boundary from jax._src.util import wraps from jax.experimental.colocated_python import func_backend -from jax.experimental.colocated_python.serialization import _deserialize_specs, _make_specs_for_serialized_specs, _serialize, _serialize_specs +from jax.experimental.colocated_python.serialization import _deserialize, _deserialize_specs, _make_specs_for_serialized_specs, _serialize, _serialize_specs +from jax.extend.backend import register_backend_cache as jax_register_backend_cache from jax.extend.ifrt_programs import ifrt_programs ShapeDtypeStructTree = Any # PyTree[api.ShapeDtypeStruct] @@ -65,7 +69,7 @@ def update( out_specs_treedef: tree_util.PyTreeDef | None = None, out_specs_leaves: tuple[api.ShapeDtypeStruct, ...] | None = None, devices: Sequence[jax.Device] | xc.DeviceList | None = None, - ) -> Any: + ): """Creates a new specialization with overrides.""" if in_specs_treedef is None: in_specs_treedef = self.in_specs_treedef @@ -169,7 +173,7 @@ def _compile_to_executable( program, compile_options ) out_handlers = pxla.global_avals_to_results_handler( - out_sdss, out_shardings, committed=True + out_sdss, out_shardings, committed=True # type: ignore ).handlers def call(*args, **kwargs): @@ -185,12 +189,14 @@ def call(*args, **kwargs): # TODO(hyeontaek): Implement colocated Python support in McJAX and remove # this fallback path. if "PjRtCompiler requires an HloProgram" in str(e): - return fun + return _deserialize(pickled_function)[0] raise def _make_output_specs_and_push_result_fun( - info: FunctionInfo, specialization: Specialization, uid: int + info: FunctionInfo, + specialization: Specialization, + uid: int, ) -> Callable[..., Any]: """Creates a function that computes output specs and pushes the result to the result store.""" assert specialization.in_specs_treedef is not None @@ -225,7 +231,9 @@ def lowered_fun(*args, **kwargs) -> jax.Array: def _make_pop_result_fun( - info: FunctionInfo, specialization: Specialization, uid: int + info: FunctionInfo, + specialization: Specialization, + uid: int, ) -> Callable[..., Any]: """Makes a function that pops results from the result store.""" assert specialization.out_specs_treedef is not None @@ -234,7 +242,7 @@ def _make_pop_result_fun( out_specs_treedef = specialization.out_specs_treedef - def lowered_fun() -> Any: + def lowered_fun(): result_leaves = func_backend.SINGLETON_RESULT_STORE.pop(uid) return tree_util.tree_unflatten(out_specs_treedef, result_leaves) @@ -258,7 +266,8 @@ def lowered_fun() -> Any: def _make_async_execution_fun( - info: FunctionInfo, specialization: Specialization + info: FunctionInfo, + specialization: Specialization, ) -> Callable[..., Any]: """Makes a function that asynchronously executes the function.""" assert specialization.in_specs_treedef is not None @@ -279,9 +288,9 @@ def _make_async_execution_fun( ) -@jax.util.cache(max_size=None) -def _get_specialized_func( - info: FunctionInfo, specialization: Specialization +def _uncached_get_specialized_func( + info: FunctionInfo, + specialization: Specialization, ) -> Callable[..., Any]: """Returns a specialized function for the given specialization.""" util.test_event("colocated_python_func._get_specialized_func") @@ -294,16 +303,21 @@ def _get_specialized_func( # Asynchronous execution function that has known output_specs. async_execution_func = None - def specialized_func(*args, **kwargs) -> Any: + def specialized_func(*args, **kwargs): """Specialized function to be executed with given args and kwargs.""" nonlocal specialization, async_execution_func with mutex: if async_execution_func is None: if specialization.out_specs_treedef is None: if specialization.out_specs_fn is None: - serialized_out_specs = _make_output_specs_and_push_result_fun( - info, specialization, uid - )(*args, **kwargs) + output_specs_and_push_result_fun = ( + _make_output_specs_and_push_result_fun( + info, specialization, uid + ) + ) + serialized_out_specs = output_specs_and_push_result_fun( + *args, **kwargs + ) # Waits for the output_specs. This may block. out_specs_treedef, out_specs_leaves = _deserialize_specs( @@ -320,6 +334,13 @@ def specialized_func(*args, **kwargs) -> Any: info, specialization ) + # Hold the PyExecutable until async_execution_fun is called at + # least once, so the number of _OBJECT_STORE references at the + # backend does not drop to 0. + async_execution_func.output_specs_and_push_result_fun = ( + output_specs_and_push_result_fun + ) + return _make_pop_result_fun(info, specialization, uid)() else: # Compute out_specs using out_specs_fn and inputs. @@ -347,122 +368,345 @@ def specialized_func(*args, **kwargs) -> Any: # Asynchronous execution runs outside of the mutex to allow concurrent # execution for inline executors. - return async_execution_func(*args, **kwargs) + result = async_execution_func(*args, **kwargs) + with mutex: + async_execution_func.output_specs_and_push_result_fun = None + return result return specialized_func -def make_callable( - fun: Callable[..., Any], - fun_sourceinfo: str | None, - fun_signature: inspect.Signature | None, -) -> Callable[..., Any]: - """Makes a colocated Python callable.""" - return _make_callable( - FunctionInfo(fun, fun_sourceinfo, fun_signature), Specialization() - ) +class _SpecializedCollection: + """Collection of specialized functions for a single unspecialized function. + + The `get()` method retrieves the specialized function for the provided input + spec, either by looking up a cache or by compiling the specialized function. + + Looking up a cache with an input spec as a key can be slow, because + `Sharding`'s equivalence comparison is slow. Instead, we maintain two caches + for the same value: we use the ID of the sharding object (via `WeakSpec`) as + the key in one cache, and the corresponding strong references to the sharding + object (via `StrongSpec`) as the key in another cache. Looking up the + `WeakSpec`-keyed cache is fast. Note that the ID integer in the `WeakSpec` + cache will remain valid as long as a strong-ref exists in the `StrongSpec` + cache. + + The `StrongSpec`-keyed cache is unbounded, while the `WeakSpec`-keyed cache + is LRU(1): if there is a miss in the `WeakSpec` cache but a hit in the + `StrongSpec` cache, the strong-ref is the `StrongSpec` cache and the ID + integer in the `WeakSpec` cache are both updated. + """ + + @dataclasses.dataclass(slots=True, unsafe_hash=True) + class WeakSpec: + """WeakSpec stores just the `id()` of the input spec sharding.""" + + dtypes: tuple[jax.numpy.dtype, ...] + shapes: tuple[tuple[int, ...], ...] + sharding_ids: tuple[int, ...] + treedef: tree_util.PyTreeDef + + def __init__( + self, args_leaves: Sequence[jax.Array], treedef: tree_util.PyTreeDef + ): + self.dtypes = tuple(x.dtype for x in args_leaves) + self.shapes = tuple(x.shape for x in args_leaves) + self.sharding_ids = tuple(id(x.sharding) for x in args_leaves) + self.treedef = treedef + + @dataclasses.dataclass(slots=True, unsafe_hash=True) + class StrongSpec: + """StrongSpec stores the full input spec sharding.""" + + in_specs_treedef: tree_util.PyTreeDef | None = None + in_specs_leaves: tuple[api.ShapeDtypeStruct, ...] | None = None + + def __init__( + self, args_leaves: Sequence[jax.Array], pytreedef: tree_util.PyTreeDef + ): + self.in_specs_leaves = tuple(_get_spec(x) for x in args_leaves) + self.in_specs_treedef = pytreedef + + def __init__(self): + CompiledId = int + + self._weak_to_id: dict[_SpecializedCollection.WeakSpec, CompiledId] = {} + self._id_to_weak: dict[CompiledId, _SpecializedCollection.WeakSpec] = {} + self._strong_to_id: dict[_SpecializedCollection.StrongSpec, CompiledId] = {} + self._id_to_compiled: dict[CompiledId, Callable[..., Any]] = {} + + self._counter = 0 + self._mu = threading.Lock() + + def get( + self, + args_leaves: Sequence[jax.Array], + pytreedef: tree_util.PyTreeDef, + func_info: FunctionInfo, + specialization: Specialization, + ) -> Callable[..., Any]: + # TODO(hyeontaek): Allow Python values in args_leaves, similar to the todo + # in _get_spec(). + + # Attempt fast-path cache hit. + weak_spec = _SpecializedCollection.WeakSpec(args_leaves, pytreedef) + compiled_id = self._weak_to_id.get(weak_spec) + if compiled_id is not None: + return self._id_to_compiled[compiled_id] + + with self._mu: + # Attempt slow-path cache hit. + strong_spec = _SpecializedCollection.StrongSpec(args_leaves, pytreedef) + compiled_id = self._strong_to_id.pop(strong_spec, None) + if compiled_id is not None: + # Update the caches so that the fast-path cache stores the `id()` of the + # shardings presented by the current invocation. + old_weak = self._id_to_weak.pop(compiled_id) + del self._weak_to_id[old_weak] + + self._strong_to_id[strong_spec] = compiled_id + self._weak_to_id[weak_spec] = compiled_id + self._id_to_weak[compiled_id] = weak_spec + + return self._id_to_compiled[compiled_id] + + # Cache-miss: compile. + if specialization.devices is None: + result = _uncached_get_specialized_func( + func_info, + specialization.update( + in_specs_treedef=strong_spec.in_specs_treedef, + in_specs_leaves=strong_spec.in_specs_leaves, + devices=_infer_devices_from_args(args_leaves), + ), + ) + else: + result = _uncached_get_specialized_func( + func_info, + specialization.update( + in_specs_treedef=strong_spec.in_specs_treedef, + in_specs_leaves=strong_spec.in_specs_leaves, + ), + ) + compiled_id = self._counter + self._counter += 1 -def _make_callable( - info: FunctionInfo, - specialization: Specialization, -) -> Callable[..., Any]: - """Internal implementation of make_callable.""" + self._weak_to_id[weak_spec] = compiled_id + self._strong_to_id[strong_spec] = compiled_id + self._id_to_weak[compiled_id] = weak_spec + self._id_to_compiled[compiled_id] = result + return result - def specialize( - in_specs: ShapeDtypeStructTree | None = None, - out_specs_fn: Callable[..., ShapeDtypeStructTree] | None = None, - devices: Sequence[jax.Device] | None = None, - ) -> Callable[..., Any]: - """Returns a colocated Python callable with extra specialization. - - Args: - in_specs: Optionally specifies the expected input specs. Input specs are - expressed as a `PyTree[ShapeDtypeStruct]` for `(args, kwargs)` of a - function call. - out_specs_fn: Optionally specifies a function that computes the output - specs from input specs. If unspecified, colocated_python will compute - the output specs during the very first execution, and this execution - will be synchronous. - devices: Optionally specifies the devices to execute the function on. Must - be provided if in_specs has no leaves because devices cannot be inferred - from input specs or arguments. - - Returns: - A colocated Python callable with extra specialization. - """ - # TODO(hyeontaek): Allow unspecified devices for zero-leaf `in_specs` if - # `out_specs_fn(in_specs)` returns at least one leaf that we can use for - # inferring `devices`. - if in_specs is None: - in_specs_leaves, in_specs_treedef = None, None - else: - in_specs_leaves_list, in_specs_treedef = tree_util.tree_flatten(in_specs) - in_specs_leaves = tuple(in_specs_leaves_list) - return _make_callable( - info, - specialization.update( - in_specs_treedef=in_specs_treedef, - in_specs_leaves=in_specs_leaves, - out_specs_fn=out_specs_fn, - devices=devices, - ), - ) - @api_boundary - def __call__(*args, **kwargs) -> Any: - """Executes the function. +class _JaxSecondLevelCaches: + """Manages second-level caches registered as a single cache with JAX.""" - If the output specs are not known, the very first execution will be - synchronous. - """ - args_leaves, in_specs_treedef = tree_util.tree_flatten((args, kwargs)) + def __init__(self, name: str): + self._lock = threading.Lock() + self._callbacks: dict[int, Callable[..., Any]] = {} + jax_register_backend_cache(self, name) - in_specs_leaves = tuple(_get_spec(x) for x in args_leaves) - if specialization.in_specs_treedef is None: - # Allow input polymorphism by applying input_specs specialization - # temporarily for this call. - return _make_callable( + def cache_clear(self): + """Meant to be invoked by JAX internals.""" + for callback in self._callbacks.values(): + callback() + self._callbacks.clear() + + def register_second_level( + self, uid: int, cache_clear_callback: Callable[..., Any] + ): + self._callbacks[uid] = cache_clear_callback + + def remove_second_level(self, uid: int): + try: + self._callbacks.pop(uid) + except KeyError: + pass + + +class _CachedColocatedFunctionMaker: + """Function maker for colocated Python functions. + + Generated functions are stored (cached) indefinitely so that they can be + reused, until the cache is dropped. + """ + + JAX_CACHE = _JaxSecondLevelCaches("colocated_python_specialized_func_cache") + + def __init__(self, held_by: int | None): + self.held_by = held_by if held_by is not None else uuid.uuid4().int + specialized_collections: list[_SpecializedCollection] = [] + specialized_functions: list[Callable[..., Any]] = [] + + def clear_caches(): + specialized_collections.clear() + specialized_functions.clear() + + _CachedColocatedFunctionMaker.JAX_CACHE.register_second_level( + self.held_by, + clear_caches, + ) + self.specialized_collections = specialized_collections + self.specialized_functions = specialized_functions + + def __del__(self): + self.specialized_collections.clear() + self.specialized_functions.clear() + try: + _CachedColocatedFunctionMaker.JAX_CACHE.remove_second_level(self.held_by) + except AttributeError: + # Ignore error during python finalization. + pass + + def _make_callable( + self, + info: FunctionInfo, + specialization: Specialization, + ): + """Internal implementation of make_callable.""" + + def specialize( + in_specs: ShapeDtypeStructTree | None = None, + out_specs_fn: Callable[..., ShapeDtypeStructTree] | None = None, + devices: Sequence[jax.Device] | None = None, + ): + """Returns a colocated Python callable with extra specialization. + + Args: + in_specs: Optionally specifies the expected input specs. Input specs are + expressed as a `PyTree[ShapeDtypeStruct]` for `(args, kwargs)` of a + function call. + out_specs_fn: Optionally specifies a function that computes the output + specs from input specs. If unspecified, colocated Python will compute + the output specs during the very first execution, and this execution + will be synchronous. + devices: Optionally specifies the devices to execute the function on. + Must be provided if `in_specs` has no leaves because devices cannot be + inferred from input specs or arguments. + + Returns: + A colocated Python callable with extra specialization. + """ + # TODO(hyeontaek): Allow unspecified devices for zero-leaf `in_specs` if + # `out_specs_fn(in_specs)` returns at least one leaf that we can use for + # inferring `devices`. + if in_specs is None: + in_specs_leaves, in_specs_treedef = None, None + else: + in_specs_leaves_list, in_specs_treedef = tree_util.tree_flatten( + in_specs + ) + in_specs_leaves = tuple(in_specs_leaves_list) + return self._make_callable( info, specialization.update( in_specs_treedef=in_specs_treedef, in_specs_leaves=in_specs_leaves, + out_specs_fn=out_specs_fn, + devices=devices, ), - )(*args, **kwargs) + ) - if specialization.devices is None: - devices = _infer_devices_from_args(args_leaves) - if devices is None: + # Caches for a collection of specialized functions or a specialized function + # itself. The latter is used as a performance optimization when the input + # spec is explicitly specified and can skip a collection lookup. The caches + # use weakrefs so that we avoid creating cyclic references. + specialized_collections_wref = lambda: None + specialized_functions_wref = lambda: None + wref_mu = threading.Lock() + + @api_boundary + def __call__(*args, **kwargs): + """Executes the given Python function on the same devices as the arguments or as specialized. + + If the callable has not been specialized with output shapes and shardings + (see `specialize` above), the very first call will run synchronously to + discover output shapes and shardings, and will run asynchronously after. + If specialized with output shapes and shardings, every execution of the + callable will be asynchronous. + """ + args_leaves, in_specs_treedef = tree_util.tree_flatten((args, kwargs)) + + no_input = len(args_leaves) == 0 + if no_input and specialization.devices is None: raise ValueError( "No devices found. colocated_python function without input" " arguments must be first specialized with devices." ) - # Allow device polymorphism by applying devices specialization temporarily - # for this call. - return _make_callable(info, specialization.update(devices=devices))( - *args, **kwargs + + fully_specified_in_spec = ( + specialization.in_specs_treedef is not None + and specialization.in_specs_leaves is not None ) - # Assertion is added to silence mypy error: Unsupported operand types for != - # ("PyTreeDef" and "None") [operator] - assert isinstance(specialization.in_specs_treedef, tree_util.PyTreeDef) - - # If input_specs is known, verify that it matches actual inputs. - if (specialization.in_specs_treedef != in_specs_treedef - or specialization.in_specs_leaves != in_specs_leaves): - raise ValueError( - "Input specs in specialization and input specs of arguments must have" - " the same pytree structure, but they have the following structural" - " differences:\n" - + ("\n".join( - f" - {tree_util.keystr(path)} is a {thing1} in value 1 and" - f" a {thing2} in value 2, so {explanation}.\n" - for path, thing1, thing2, explanation in tree_util.equality_errors_pytreedef( - specialization.in_specs_treedef, in_specs_treedef - )))) - - return _get_specialized_func(info, specialization)(*args, **kwargs) - - __call__ = wraps(info.fun)(__call__) - __call__.specialize = specialize - return __call__ + if not fully_specified_in_spec and not no_input: + # We need to handle input polymorphism + nonlocal specialized_collections_wref + with wref_mu: + collection: _SpecializedCollection = specialized_collections_wref() + if collection is None: + collection = _SpecializedCollection() + self.specialized_collections.append(collection) + specialized_collections_wref = weakref.ref(collection) + result = collection.get( + args_leaves, in_specs_treedef, info, specialization + )(*args, **kwargs) + del collection + return result + + # No input polymorphism -- exactly one compiled function is possible. + with wref_mu: + nonlocal specialized_functions_wref + func: Callable[..., Any] = specialized_functions_wref() + if func is None: + if fully_specified_in_spec and specialization.devices is not None: + func = _uncached_get_specialized_func(info, specialization) + elif fully_specified_in_spec: + func = _uncached_get_specialized_func( + info, + specialization.update( + devices=_infer_devices_from_args(args_leaves) + ), + ) + elif no_input: + func = _uncached_get_specialized_func( + info, + specialization.update( + in_specs_leaves=tuple(), + in_specs_treedef=in_specs_treedef, + ), + ) + self.specialized_functions.append(func) + specialized_functions_wref = weakref.ref(func) + result = func(*args, **kwargs) + del func + return result + + __call__ = wraps(info.fun)(__call__) + __call__.specialize = specialize + return __call__ + + def make_callable( + self, + fun: Callable[..., Any], + fun_sourceinfo: str | None, + fun_signature: inspect.Signature | None, + ): + """Makes a colocated Python callable.""" + return self._make_callable( + FunctionInfo(fun, fun_sourceinfo, fun_signature), Specialization() + ) + + +_DEFAULT_FUNCTION_MAKER = _CachedColocatedFunctionMaker(None) + + +def make_callable( + fun: Callable[..., Any], + fun_sourceinfo: str | None, + fun_signature: inspect.Signature | None, +): + return _DEFAULT_FUNCTION_MAKER.make_callable( + fun, fun_sourceinfo, fun_signature + ) diff --git a/jax/experimental/colocated_python/func_backend.py b/jax/experimental/colocated_python/func_backend.py index aa514015004d..4f1443da4b17 100644 --- a/jax/experimental/colocated_python/func_backend.py +++ b/jax/experimental/colocated_python/func_backend.py @@ -16,7 +16,7 @@ from __future__ import annotations import threading -from typing import Sequence +from collections.abc import Sequence import jax diff --git a/jax/experimental/colocated_python/obj.py b/jax/experimental/colocated_python/obj.py index b1e7a0b1eade..b7204f765937 100644 --- a/jax/experimental/colocated_python/obj.py +++ b/jax/experimental/colocated_python/obj.py @@ -15,19 +15,35 @@ from __future__ import annotations +from collections.abc import Callable import inspect import random import threading -from typing import Any, Callable, Type +from typing import Any +import weakref import jax from jax._src import api_util +from jax._src import config from jax._src import tree_util from jax._src.traceback_util import api_boundary from jax._src.util import wraps from jax.experimental.colocated_python import func from jax.experimental.colocated_python import obj_backend +# TODO(madthanu): Remove the following config option and make its behavior the +# default, once the behavior has been declared stable. +_USE_WEAKREFS = config.bool_state( + 'jax_experimental_colocated_python_object_use_weakrefs_at_backend', + False, + help=( + 'Unstable in-development feature that switches the colocated-python' + ' implementation to internally use reference counting for destructing' + ' objects at the colocated backend, instead of invoking an explicit' + ' delete-object function from the frontend.' + ), +) + class _InstanceRegistry: """Registry of object instances.""" @@ -58,7 +74,7 @@ def pop_instance(self, uid: int) -> set[jax.Device]: SINGLETON_INSTANCE_REGISTRY = _InstanceRegistry() -@jax.util.cache(max_size=4096) +@jax._src.util.cache(max_size=4096) def _update_instance_devices( uid: int, shardings: tuple[jax.sharding.Sharding, ...] ) -> None: @@ -70,53 +86,111 @@ def _update_instance_devices( def _make_method( - cls: Type[object], + cls: type[object], cls_sourceinfo: str | None, uid: int, init_args: tuple[Any, ...], init_kwargs: dict[str, Any], method_name: str, original_method: Callable[..., Any], + func_maker: func._CachedColocatedFunctionMaker, + use_weakrefs: bool, ): - # Initializer to use when the object is not present in the backend. - def initializer() -> object: - return cls(*init_args, **init_kwargs) - # Method to call on the backend. - def method(*args, **kwargs): - obj = obj_backend.SINGLETON_OBJECT_STORE.get_or_create(uid, initializer) - return getattr(obj, method_name)(*args, **kwargs) + class MethodCallerAtBackend: + + def __init__(self): + self._lock = threading.Lock() + + def __reduce__(self): + return type(self), () + + def _first_call(self): + # Temporarily hold a strong reference to a new object if it is created + # using initializer. + temp_strong_ref = None + + def initializer(): + if not use_weakrefs: + return obj_backend._ClassWrapperForGarbageCollection( # pylint: disable=protected-access + cls(*init_args, **init_kwargs) + ) + nonlocal temp_strong_ref + temp_strong_ref = cls(*init_args, **init_kwargs) + return weakref.ref(temp_strong_ref) + + retrieved = obj_backend.SINGLETON_OBJECT_STORE.get_or_create( + uid, initializer + ) + + if use_weakrefs: + self.obj = temp_strong_ref + else: + self.obj = retrieved + + def __call__(self, *args, **kwargs): + with self._lock: + if not hasattr(self, 'obj'): + self._first_call() + + if use_weakrefs: + return getattr(self.obj, method_name)(*args, **kwargs) + else: + assert isinstance( + self.obj, obj_backend._ClassWrapperForGarbageCollection + ) + return getattr(self.obj.obj, method_name)(*args, **kwargs) # Colocated Python callable for the controller. - callable = func.make_callable( - method, + callable = func_maker.make_callable( + MethodCallerAtBackend(), cls_sourceinfo, api_util.fun_signature(original_method), ) - # Outer wrapper of the method for the controller. It tracks - @api_boundary - def method_wrapper(*args, **kwargs): - if not args: - raise NotImplementedError( - 'Method calls with no arguments are not yet supported.' + # Outer wrapper of the method for the controller. It tracks devices that have + # been used with any method call. + def make_method_wrapper(callable): + @api_boundary + def method_wrapper(*args, **kwargs): + # TODO(hyeontaek): Instead of inspecting argument/result shardings, get + # shardings from final specialization of the function. This may require + # lowering `_update_instance_devices` into the function API. + + args_leaves = tree_util.tree_leaves((args, kwargs)) + args_shardings_leaves = tuple( + func._get_spec(x).sharding for x in args_leaves ) - # TODO(hyeontaek): Instead of inspecting argument shardings, get shardings - # from final specialization of the function. This may require lowering - # `_update_instance_devices` into the function API. - args_leaves = tree_util.tree_leaves((args, kwargs)) - shardings_leaves = tuple(func._get_spec(x).sharding for x in args_leaves) - _update_instance_devices(uid, shardings_leaves) - return callable(*args, **kwargs) - - method_wrapper = wraps(original_method)(method_wrapper) + if args_shardings_leaves: + _update_instance_devices(uid, args_shardings_leaves) + + result = callable(*args, **kwargs) + + # If args had any array, we can skip incorporating devices from the result + # because results will not use any new devices. + if not args_shardings_leaves: + result_leaves = tree_util.tree_leaves(result) + result_shardings_leaves = tuple( + func._get_spec(x).sharding for x in result_leaves + ) + _update_instance_devices(uid, result_shardings_leaves) + return result + + def specialize(*args, **kwargs): + return make_method_wrapper(callable.specialize(*args, **kwargs)) + + method_wrapper = wraps(original_method)(method_wrapper) + method_wrapper.specialize = specialize + return method_wrapper + + method_wrapper = make_method_wrapper(callable) return method_wrapper def wrap_class( - cls: Type[object], + cls: type[object], cls_sourceinfo: str | None, -) -> Type[object]: +) -> type[object]: class WrappedClass: @wraps(cls.__init__) @@ -124,6 +198,8 @@ def __init__(self, *init_args, **init_kwargs) -> None: uid = self._colocated_python_uid = ( SINGLETON_INSTANCE_REGISTRY.new_instance() ) + self.func_maker = func._CachedColocatedFunctionMaker(uid) + self.use_weakrefs = _USE_WEAKREFS.value for attr_name in dir(cls): original_member = getattr(cls, attr_name) if not inspect.isfunction(original_member): @@ -143,12 +219,17 @@ def __init__(self, *init_args, **init_kwargs) -> None: init_kwargs, attr_name, original_member, + self.func_maker, + self.use_weakrefs, ) # TODO(hyeontaek): Support method specialization similar to function # specialization. setattr(self, attr_name, method) - def __del__(self) -> None: + def __del__(self): + del self.func_maker + if self.use_weakrefs: + return uid = self._colocated_python_uid devices = SINGLETON_INSTANCE_REGISTRY.pop_instance(uid) if devices: @@ -156,16 +237,13 @@ def __del__(self) -> None: def remove_object() -> None: obj_backend.SINGLETON_OBJECT_STORE.remove(uid) - # TODO(hyeontaek): Request "best-effort" non-SPMD execution that tries - # to run this function on any healthy processes instead of failing when - # any process of the execution is unhealthy. destructor = func.make_callable( remove_object, cls_sourceinfo, None, ) destructor = destructor.specialize( # type: ignore[attribute-error] - devices=devices + devices=sorted(devices, key=lambda device: device.id) ) destructor() diff --git a/jax/experimental/colocated_python/obj_backend.py b/jax/experimental/colocated_python/obj_backend.py index ffa04a007818..8408ab7789f4 100644 --- a/jax/experimental/colocated_python/obj_backend.py +++ b/jax/experimental/colocated_python/obj_backend.py @@ -15,9 +15,15 @@ from __future__ import annotations +from collections.abc import Callable import dataclasses import threading -from typing import Any, Callable +from typing import Any + + +@dataclasses.dataclass +class _ClassWrapperForGarbageCollection: + obj: Any @dataclasses.dataclass(frozen=True) @@ -71,6 +77,8 @@ def remove(self, uid: int) -> None: state = self._storage.pop(uid) # The object will be deleted without holding the lock. + if isinstance(state.obj, _ClassWrapperForGarbageCollection): + del state.obj.obj del state diff --git a/jax/experimental/colocated_python/serialization.py b/jax/experimental/colocated_python/serialization.py index 1ca29ab12660..3c87906876cb 100644 --- a/jax/experimental/colocated_python/serialization.py +++ b/jax/experimental/colocated_python/serialization.py @@ -17,9 +17,11 @@ import base64 import collections +from collections.abc import Callable, Sequence import functools import io -from typing import Any, Callable, Sequence +import threading +from typing import Any try: import cloudpickle # type: ignore[import-not-found] @@ -35,7 +37,74 @@ DeviceList = xc.DeviceList -@jax.util.cache(max_size=None) + +class _CommonObjectState(threading.local): + """Tracks repeated objects within a single `_serialize()` or `_deserialize()`. + + It is common for `_serialize(x)` to be called with `x` being a nested + container or capturing other objects in a closure, with many references + pointing to only a few unique objects. The logic below + (`_make_reduce_func_with_common_obj`) avoids duplicating object serialization + by reducing a reference handle instead of the full object when an equal object + is repeatedly seen. + """ + + def __init__(self): + # Map from a common object key to its ID. Any objects with a matching key + # will use the common object ID instead of the full object during + # serialization. + self.common_obj_index: dict[Any, int] | None = None + + # Common object that has been reconstructed when their key was seen for the + # first time during deserialization. + self.common_obj: list[Any] | None = None + + +_common_obj_state = _CommonObjectState() + + +def _wrapped_unreduce_func_with_new_common_obj( + common_obj_id, unreduce_func, unreduce_args): + """Unreduces a new common object.""" + assert _common_obj_state.common_obj is not None + obj = unreduce_func(*unreduce_args) + assert len(_common_obj_state.common_obj) == common_obj_id, ( + f"Expected {common_obj_id} common objects, but got" + f" {len(_common_obj_state.common_obj)}. This can happen if serialization" + " and deserialization of objects happened in different orders." + ) + _common_obj_state.common_obj.append(obj) + return obj + + +def _wrapped_unreduce_func_with_existing_common_obj(common_obj_id): + """Unreduces a common object that has already appeared.""" + assert _common_obj_state.common_obj is not None + return _common_obj_state.common_obj[common_obj_id] + + +def _make_reduce_func_with_common_obj( + reduce_func: Callable[[Any], tuple[Any, Any]], +) -> Callable[[Any], tuple[Any, Any]]: + """Wraps a reduce function to serialize a common object once.""" + + @functools.wraps(reduce_func) + def wrapped_reduce_func(obj): + assert _common_obj_state.common_obj_index is not None + common_obj_id = _common_obj_state.common_obj_index.get(obj) + if common_obj_id is None: + unreduced_func, unreduced_args = reduce_func(obj) + common_obj_id = len(_common_obj_state.common_obj_index) + _common_obj_state.common_obj_index[obj] = common_obj_id + return _wrapped_unreduce_func_with_new_common_obj, ( + common_obj_id, unreduced_func, unreduced_args) + else: + return _wrapped_unreduce_func_with_existing_common_obj, (common_obj_id,) + + return wrapped_reduce_func + + +@jax._src.util.cache(max_size=None) def _get_cpu_device_map() -> dict[int, jax.Device]: """Returns a map from a device id to a matching device.""" cpu_device_map: dict[int, jax.Device] = {} @@ -83,46 +152,69 @@ def _lookup_cpu_device( return d +@_make_reduce_func_with_common_obj def _reduce_mesh( mesh: jax.sharding.Mesh, ) -> tuple[Callable[..., jax.sharding.Mesh], Any]: - def make_mesh( - mesh_device_ids: np.ndarray, axis_names: Any - ) -> jax.sharding.Mesh: - cpu_device_map = _get_cpu_device_map() - mesh_devices = np.vectorize( - functools.partial(_lookup_cpu_device, cpu_device_map) - )(mesh_device_ids) - return jax.sharding.Mesh(mesh_devices, axis_names) - mesh_device_ids = np.vectorize(lambda d: d.id, otypes=[int])(mesh.devices) - return make_mesh, (mesh_device_ids, mesh.axis_names) + return _unreduce_mesh, (mesh_device_ids, mesh.axis_names, mesh.axis_types) + + +def _unreduce_mesh( + mesh_device_ids: np.ndarray, axis_names: Any, axis_types: Any +) -> jax.sharding.Mesh: + cpu_device_map = _get_cpu_device_map() + mesh_devices = np.vectorize( + functools.partial(_lookup_cpu_device, cpu_device_map) + )(mesh_device_ids) + return jax.sharding.Mesh(mesh_devices, axis_names, axis_types) +@_make_reduce_func_with_common_obj +def _reduce_named_sharding( + sharding: jax.sharding.NamedSharding, +) -> tuple[Callable[..., jax.sharding.NamedSharding], Any]: + assert isinstance(sharding.mesh, jax.sharding.Mesh), "Only Mesh is supported" + reduced_mesh = _reduce_mesh(sharding.mesh) + return _unreduce_named_sharding, ( + reduced_mesh, sharding.spec, sharding.memory_kind) + + +def _unreduce_named_sharding(reduced_mesh, spec, memory_kind): + mesh = reduced_mesh[0](*reduced_mesh[1]) + return jax.NamedSharding(mesh, spec, memory_kind=memory_kind) + + +@_make_reduce_func_with_common_obj def _reduce_device_list( device_list: DeviceList, ) -> tuple[Callable[..., DeviceList], Any]: - def make_device_list(device_ids: Sequence[int]) -> DeviceList: - cpu_device_map = _get_cpu_device_map() - devices = np.vectorize( - functools.partial(_lookup_cpu_device, cpu_device_map) - )(device_ids) - return DeviceList(tuple(devices)) - device_ids = [d.id for d in device_list] - return make_device_list, (device_ids,) + return _unreduce_device_list, (device_ids,) + + +def _unreduce_device_list(device_ids: Sequence[int]) -> DeviceList: + cpu_device_map = _get_cpu_device_map() + devices = np.vectorize(functools.partial(_lookup_cpu_device, cpu_device_map))( + device_ids) + return DeviceList(tuple(devices)) +@_make_reduce_func_with_common_obj def _reduce_single_device_sharding( sharding: jax.sharding.SingleDeviceSharding, ) -> tuple[Callable[..., jax.sharding.SingleDeviceSharding], Any]: + return _unreduce_single_device_sharding, ( + sharding.device_set.pop().id, + sharding.memory_kind) - def make_single_device_sharding(device_id: int): - cpu_device_map = _get_cpu_device_map() - device = _lookup_cpu_device(cpu_device_map, device_id) - return jax.sharding.SingleDeviceSharding(device) - return make_single_device_sharding, (sharding.device_set.pop().id,) +def _unreduce_single_device_sharding( + device_id: int, memory_kind: str | None +) -> jax.sharding.SingleDeviceSharding: + cpu_device_map = _get_cpu_device_map() + device = _lookup_cpu_device(cpu_device_map, device_id) + return jax.sharding.SingleDeviceSharding(device, memory_kind=memory_kind) def _serialize(obj: Any) -> bytes: @@ -149,15 +241,22 @@ def _serialize(obj: Any) -> bytes: class _CustomPickler(cloudpickle.Pickler): dispatch_table = collections.ChainMap( {jax.sharding.Mesh: _reduce_mesh}, + {jax.sharding.NamedSharding: _reduce_named_sharding}, {DeviceList: _reduce_device_list}, {jax.sharding.SingleDeviceSharding: _reduce_single_device_sharding}, cloudpickle.CloudPickler.dispatch_table, # pylint: disable=attribute-error ) dispatch = dispatch_table - with io.BytesIO() as file: - _CustomPickler(file).dump(obj) - return file.getvalue() + assert _common_obj_state.common_obj_index is None, ( + "_serialize() expects no recursive calls") + _common_obj_state.common_obj_index = {} + try: + with io.BytesIO() as file: + _CustomPickler(file).dump(obj) + return file.getvalue() + finally: + _common_obj_state.common_obj_index = None def _deserialize(serialized: bytes) -> Any: @@ -172,7 +271,13 @@ def _deserialize(serialized: bytes) -> Any: if cloudpickle is None: raise ModuleNotFoundError('No module named "cloudpickle"') - return cloudpickle.loads(serialized) + assert _common_obj_state.common_obj is None, ( + "_deserialize() expects no recursive calls") + _common_obj_state.common_obj = [] + try: + return cloudpickle.loads(serialized) + finally: + _common_obj_state.common_obj = None def _make_specs_for_serialized_specs( @@ -201,7 +306,7 @@ def _serialize_specs( if not hasattr(np.dtypes, "StringDType"): raise TypeError( "Serializing Colocated Python requires StringDType. Please use" - " numpy to 2.0.0 or later, or explicityly provide an output spec" + " numpy to 2.0.0 or later, or explicitly provide an output spec" " function." ) @@ -223,7 +328,9 @@ def _serialize_specs( jax.device_put(s_np_array, device) for device in addressable_devices ] return jax.make_array_from_single_device_arrays( - arrays=out_arrays, sharding=replicated_sharding, shape=(), + arrays=out_arrays, + sharding=replicated_sharding, + shape=(), ) diff --git a/jax/experimental/compilation_cache/compilation_cache.py b/jax/experimental/compilation_cache/compilation_cache.py index 990dd1742051..8c820c5434fe 100644 --- a/jax/experimental/compilation_cache/compilation_cache.py +++ b/jax/experimental/compilation_cache/compilation_cache.py @@ -13,8 +13,6 @@ # limitations under the License. from jax._src.compilation_cache import ( - is_initialized as is_initialized, # deprecated - initialize_cache as initialize_cache, # deprecated; use set_cache_dir instead set_cache_dir as set_cache_dir, reset_cache as reset_cache, ) diff --git a/jax/experimental/fused.py b/jax/experimental/fused.py new file mode 100644 index 000000000000..871b29b8ebde --- /dev/null +++ b/jax/experimental/fused.py @@ -0,0 +1,152 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from jax._src import core +from jax._src import linear_util as lu +from jax._src import dispatch +from jax._src.core import typeof +from jax._src.tree_util import tree_flatten, tree_unflatten +from jax._src.util import safe_map, safe_zip, weakref_lru_cache, unzip2 +from jax._src.api_util import debug_info, flatten_fun_nokwargs +from jax._src.interpreters import ad +from jax._src.interpreters import batching +from jax._src.interpreters import mlir +from jax._src.interpreters import partial_eval as pe +from jax._src.lib.mlir import ir + +map, unsafe_map = safe_map, map +zip, unsafe_zip = safe_zip, zip + +def fused(*, out_spaces): + def wrap(f): + def wrapped(*args): + dbg = debug_info('fused', f, args, {}) + args_flat, in_tree = tree_flatten(args) + in_avals = [typeof(x).update(memory_space=core.MemorySpace.Any) + for x in args_flat] + jaxpr, out_tree = _trace_to_jaxpr(f, in_tree, tuple(in_avals), dbg) + outs_flat = fused_p.bind(*args_flat, jaxpr=jaxpr, out_spaces=out_spaces) + return tree_unflatten(out_tree, outs_flat) + return wrapped + return wrap + +@weakref_lru_cache +def _trace_to_jaxpr(fun, in_tree, in_avals, dbg): + f = lu.wrap_init(fun, debug_info=dbg) + f, out_tree = flatten_fun_nokwargs(f, in_tree) + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(f, in_avals) + return core.ClosedJaxpr(jaxpr, consts), out_tree() + +fused_p = core.Primitive('fused_call') +fused_p.multiple_results = True + +@fused_p.def_abstract_eval +def _fused_abstract_eval(*in_avals, out_spaces, jaxpr): + return [a.update(memory_space=s) + for a, s in zip(jaxpr.out_avals, out_spaces)] + +dispatch.simple_impl(fused_p) + +def _fused_lowering(ctx, *args, out_spaces, jaxpr): + const_args_and_avals = core.jaxpr_const_args(jaxpr.jaxpr) + const_args, const_arg_avals = unzip2(const_args_and_avals) + const_arg_values = [ + mlir.ir_constant(c, const_lowering=ctx.const_lowering, aval=aval) + for c, aval in const_args_and_avals] + in_avals = [*const_arg_avals, *ctx.avals_in] + func_op, _, _ = mlir.lower_called_computation( + "fused", jaxpr, ctx.module_context, len(const_args), in_avals, + ctx.avals_out, ctx.tokens_in) + out_spaces_ = [ir.StringAttr.get(str(s)) for s in out_spaces] + fused = mlir.custom_call( + "fused", + result_types=func_op.type.results, + operands=mlir.flatten_ir_values([*const_arg_values, *args]), + called_computations=[func_op.name.value], + backend_config=dict(out_spaces=ir.ArrayAttr.get(out_spaces_), + inlineable=ir.BoolAttr.get(False), + MUST_FUSE=ir.BoolAttr.get(True)), + ) + return fused.results +mlir.register_lowering(fused_p, _fused_lowering, platform="cuda") + +def _fused_batcher(axis_data, vals_in, dims_in, *, jaxpr, out_spaces): + batched_jaxpr, dims_out = batching.batch_jaxpr2(jaxpr, axis_data, dims_in) + outs = fused_p.bind(*vals_in, jaxpr=batched_jaxpr, out_spaces=out_spaces) + return outs, dims_out +batching.fancy_primitive_batchers[fused_p] = _fused_batcher + +def _fused_jvp(primals, tangents, *, jaxpr, out_spaces): + nzs = [not isinstance(t, ad.Zero) for t in tangents] + jaxpr_jvp, out_nzs = ad.jvp_jaxpr(jaxpr, nzs, False) + nz_tangents = [t for t in tangents if not isinstance(t, ad.Zero)] + spaces_jvp = (*out_spaces, *[s for s, nz in zip(out_spaces, out_nzs) if nz]) + outs = fused_p.bind(*primals, *nz_tangents, jaxpr=jaxpr_jvp, + out_spaces=spaces_jvp) + primals_out, nz_tangents_out = outs[:len(out_nzs)], outs[len(out_nzs):] + nz_outs = iter(nz_tangents_out) + tangents_out = [next(nz_outs) if nz else ad.Zero(aval.to_tangent_aval()) + for aval, nz in zip(jaxpr.out_avals, out_nzs)] + assert next(nz_outs, None) is None + return primals_out, tangents_out +ad.primitive_jvps[fused_p] = _fused_jvp + +def _fused_lin(nzs, *primals, jaxpr, out_spaces): + jaxpr_jvp, out_nzs = ad.jvp_jaxpr(jaxpr, nzs, False) + lin_outs = [False] * len(out_nzs) + [True] * sum(out_nzs) + jaxpr_lin_, used_inputs = pe.dce_jaxpr(jaxpr_jvp.jaxpr, lin_outs, False) + jaxpr_lin = pe.close_jaxpr(jaxpr_lin_) + spaces_lin = tuple(s for s, nz in zip(out_spaces, out_nzs) if nz) + primals_out = fused_p.bind(*primals, jaxpr=jaxpr, out_spaces=out_spaces) + tangent_avals_out = [a.to_tangent_aval() for a in jaxpr.out_avals] + + def fused_lin(primals, *tangents): + nz_tangents = [t for t in tangents if not isinstance(t, ad.Zero)] + inputs = [x for x, u in zip([*primals, *nz_tangents], used_inputs) if u] + nz_outs = fused_p.bind(*inputs, jaxpr=jaxpr_lin, out_spaces=spaces_lin) + nz_outs_ = iter(nz_outs) + outs = [next(nz_outs_) if nz else ad.Zero(a) + for nz, a in zip(out_nzs, tangent_avals_out)] + assert next(nz_outs_, None) is None + return outs + + return primals_out, out_nzs, primals, fused_lin +ad.primitive_linearizations[fused_p] = _fused_lin + +def _fused_transpose(cts_in, *primals_in, jaxpr, out_spaces): + in_flat, in_tree = tree_flatten((primals_in, cts_in)) + in_avals = [typeof(x).update(memory_space=core.MemorySpace.Any) + for x in in_flat] + trans_jaxpr, out_tree = _transpose_jaxpr(jaxpr, in_tree, (*in_avals,)) + in_spaces = [x.aval.memory_space if isinstance(x, ad.UndefinedPrimal) + else typeof(x).memory_space for x in primals_in] + cts_out_ = tree_unflatten(out_tree, trans_jaxpr.out_avals) + trans_spaces = tuple(s for x, s in zip(cts_out_, in_spaces) if x) + cts_out = fused_p.bind(*in_flat, jaxpr=trans_jaxpr, out_spaces=trans_spaces) + return tree_unflatten(out_tree, cts_out) + +@weakref_lru_cache +def _transpose_jaxpr(jaxpr, in_tree, in_avals): + cell = lambda: None + def transposed(*in_flat): + primals_in, cts_in = tree_unflatten(in_tree, in_flat) + out = ad.backward_pass(jaxpr.jaxpr, False, jaxpr.consts, primals_in, cts_in) + out = [ct if not isinstance(ct, ad.Zero) else None for ct in out] + cts_out, cell.out_tree = tree_flatten(out) # type: ignore + return cts_out + dbg = jaxpr.jaxpr.debug_info.with_unknown_names() + trans_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(transposed, debug_info=dbg), in_avals) + return core.ClosedJaxpr(trans_jaxpr, consts), cell.out_tree # type: ignore +ad.primitive_transposes[fused_p] = _fused_transpose diff --git a/jax/experimental/hijax.py b/jax/experimental/hijax.py new file mode 100644 index 000000000000..1201c58120f1 --- /dev/null +++ b/jax/experimental/hijax.py @@ -0,0 +1,46 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa + +from jax._src.ad_util import ( + Zero as Zero, +) +from jax._src.core import ( + AbstractValue as AbstractValue, + AvalQDD as AvalQDD, + ShapedArray as ShapedArray, + aval_method as aval_method, + aval_property as aval_property, + AvalMutableQDD as AvalMutableQDD, +) +from jax._src.interpreters.ad import ( + instantiate_zeros as instantiate_zeros, + is_undefined_primal as is_undefined_primal, +) +from jax._src.effects import ( + control_flow_allowed_effects as control_flow_allowed_effects, +) +from jax._src.hijax import ( + HiPrimitive as HiPrimitive, + HiType as HiType, + MutableHiType as MutableHiType, + VJPHiPrimitive as VJPHiPrimitive, + register_hitype as register_hitype, + VJPHiPrimitive as VJPHiPrimitive, +) +from jax._src.state import ( + AbstractRef as AbstractRef, + TransformedRef as TransformedRef +) diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py deleted file mode 100644 index 7d60f62e230f..000000000000 --- a/jax/experimental/host_callback.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright 2020 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Backwards compatibility shims for the deprecated host_callback APIs. - -.. warning:: - The host_callback APIs are deprecated as of March 20, 2024. - The functionality is subsumed by the - `new JAX external callbacks `_ - See https://github.com/jax-ml/jax/issues/20385. - -""" - -from __future__ import annotations - -def call(*_, **__): - raise NotImplementedError( - "jax.experimental.host_callback has been deprecated since March 2024 and " - "is now no longer supported. " - "See https://github.com/jax-ml/jax/issues/20385" - ) - - -id_tap = call diff --git a/jax/experimental/jax2tf/BUILD b/jax/experimental/jax2tf/BUILD index 85ad90326859..c57d0e597292 100644 --- a/jax/experimental/jax2tf/BUILD +++ b/jax/experimental/jax2tf/BUILD @@ -38,7 +38,6 @@ py_library( name = "jax2tf_internal", srcs = [ "call_tf.py", - "impl_no_xla.py", "jax2tf.py", ], # TODO: b/255503696: enable pytype diff --git a/jax/experimental/jax2tf/README.md b/jax/experimental/jax2tf/README.md index 0d827fbcc7a5..b2cb7657aee2 100644 --- a/jax/experimental/jax2tf/README.md +++ b/jax/experimental/jax2tf/README.md @@ -138,7 +138,7 @@ f_tf_graph = tf.function(f_tf, autograph=False) ``` Note that when using the default native serialization, the target JAX function -must be jittable (see [JAX - The Sharp Bits](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html)). +must be jittable (see [JAX - The Sharp Bits](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html)). In the native serialization mode, under TensorFlow eager the whole JAX function executes as one op. @@ -461,7 +461,7 @@ presence of shape polymorphism, some dimensions may be dimension variables. The `polymorphic_shapes` parameter must be either `None`, or a pytree of shape specifiers corresponding to the pytree of arguments. (A value `None` for `polymorphic_shapes` is equivalent to a list of `None`. -See [how optional parameters are matched to arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees).) +See [how optional parameters are matched to arguments](https://docs.jax.dev/en/latest/pytrees.html#applying-optional-parameters-to-pytrees).) A shape specifier is combined with a `TensorSpec` as follows: * A shape specifier of `None` means that the shape is given @@ -568,6 +568,7 @@ because the shape abstraction that JAX tracing uses is given by the actual arguments are more specific and would actually work. Also, + ```python jax2tf.convert(lambda x: jnp.matmul(x, x), polymorphic_shapes=["(v, 4)"])(np.ones((4, 4))) @@ -808,6 +809,7 @@ TypeError: add got incompatible shapes for broadcasting: (a,), (floordiv(b, 2),) ``` You can fix this by adding a constraint: + ```python jax2tf.convert(lambda x, y: x + y[:y.shape[0] // 2], polymorphic_shapes=("a", "b"), @@ -826,19 +828,19 @@ For example, the following code will fail because `a1` and `a2` use different scopes (created by `export.symbolic_shape`): -````python +```python a1, = export.symbolic_shape("a,") a2, = export.symbolic_shape("a,", constraints=("a >= 8",)) a1 + a2 -```` +``` The symbolic expressions that originate from a single call to `export.symbolic_shape` share a scope and can be mixed up in arithmetic operations. The result would also share the same scope. -You can re-use scopes: +You can reuse scopes: ```python a, = export.symbolic_shape("a,", constraints=("a >= 8",)) @@ -1005,6 +1007,8 @@ We list here a history of the serialization version numbers: available in JAX since October 20th, 2023 (JAX 0.4.20), and the default since February 1st, 2024 (JAX 0.4.24). This is the only supported version as of 27th of March, 2024. + * Version 10 propagate the `jax.config.use_shardy_partitioner` value to + XlaCallModule. ## Known issues @@ -1024,7 +1028,7 @@ always behaves like the JAX function. JAX interprets the type of Python scalars differently based on `JAX_ENABLE_X64` flag. (See -[JAX - The Sharp Bits: Double (64bit) precision](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision).) +[JAX - The Sharp Bits: Double (64bit) precision](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision).) In the default configuration, the flag is unset, and JAX interprets Python constants as 32-bit, e.g., the type of `3.14` is `float32`. This is also what @@ -1086,7 +1090,7 @@ Applies to both native and non-native serialization. `jax2tf` can lower functions with arguments and results that are nested collections (tuples, lists, dictionaries) of numeric values or JAX arrays -([pytrees](https://jax.readthedocs.io/en/latest/pytrees.html)). The +([pytrees](https://docs.jax.dev/en/latest/pytrees.html)). The resulting TensorFlow function will take the same kind of arguments except the leaves can be numeric values or TensorFlow tensors (`tf.Tensor`, `tf.TensorSpec`, `tf.Variable`). @@ -1285,7 +1289,7 @@ per PRNG operation. The "unsafe" part is that it doesn't guarantee determinism across JAX/XLA versions, and the quality of random streams it generates from different keys is less well understood. Nevertheless, this should be fine for most inference/serving cases. -See more details in the [JAX PRNG documentation](https://jax.readthedocs.io/en/latest/jax.random.html?highlight=unsafe_rbg#advanced-rng-configuration). +See more details in the [JAX PRNG documentation](https://docs.jax.dev/en/latest/jax.random.html?highlight=unsafe_rbg#advanced-rng-configuration). ### SavedModel supports only first-order gradients @@ -1437,7 +1441,7 @@ may be slightly different for small matrices. Applies to non-native serialization only. Operations like ``jax.numpy.cumsum`` are lowered by JAX differently based -on the platform. For TPU, the lowering uses the [HLO ReduceWindow](https://www.tensorflow.org/xla/operation_semantics#reducewindow) +on the platform. For TPU, the lowering uses the [HLO ReduceWindow](https://www.openxla.org/xla/operation_semantics#reducewindow) operation, which has an efficient implementation for the cases when the reduction function is associative. For CPU and GPU, JAX uses an alternative lowering using [associative scans](https://github.com/jax-ml/jax/blob/f08bb50bfa9f6cf2de1f3f78f76e1aee4a78735d/jax/_src/lax/control_flow.py#L2801). @@ -1463,7 +1467,7 @@ For most JAX primitives there is a natural TensorFlow op that fits the needed se There are a few (listed in [no_xla_limitations.md](g3doc/no_xla_limitations.md)) JAX primitives for which there is no single TensorFlow op with matching semantics. This is not so surprising, because JAX primitives have been designed -to be compiled to [HLO ops](https://www.tensorflow.org/xla/operation_semantics), +to be compiled to [HLO ops](https://www.openxla.org/xla/operation_semantics), while the corresponding TensorFlow ops are sometimes higher-level. For the cases when there is no matching canonical TensorFlow op, we use a set of special TensorFlow ops that are thin wrappers over HLO ops @@ -1505,7 +1509,7 @@ deterministic PRNG](https://github.com/jax-ml/jax/blob/main/docs/design_notes/pr and it has an internal JAX primitive for it. This primitive is at the moment lowered to a soup of tf.bitwise operations, which has a clear performance penalty. We plan to look into using the -HLO [RNGBitGenerator](https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator) +HLO [RNGBitGenerator](https://www.openxla.org/xla/operation_semantics#rngbitgenerator) (exposed as a TFXLA op), which does implement the same basic Threefry algorithm as JAX’s PRNG, although that would result in different results than JAX’s PRNG. diff --git a/jax/experimental/jax2tf/call_tf.py b/jax/experimental/jax2tf/call_tf.py index 98c1c20cd6e5..cca98da6c2c8 100644 --- a/jax/experimental/jax2tf/call_tf.py +++ b/jax/experimental/jax2tf/call_tf.py @@ -40,12 +40,14 @@ from jax._src import core from jax._src import effects from jax._src import util -from jax._src.lib import xla_client +from jax._src.lib import _jax from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import func as func_dialect from jax._src.lib.mlir.dialects import hlo +from jax.experimental import roofline from jax.experimental.jax2tf import jax2tf as jax2tf_internal from jax._src.interpreters import mlir +import ml_dtypes import numpy as np import tensorflow as tf @@ -133,7 +135,7 @@ def canonical_arg(v): args_flat_jax = tuple(map(canonical_arg, args_flat_jax)) def make_tensorspec(a_jax): a_tf_dtype = jax2tf_internal._to_tf_dtype(a_jax.dtype) - a_tf_shape = [d if core.is_constant_dim(d) else None for d in a_jax.shape] + a_tf_shape = [d if core.is_constant_dim(d) else None for d in getattr(a_jax, "shape", ())] return tf.TensorSpec(a_tf_shape, a_tf_dtype) args_flat_sig_tf = tuple(map(make_tensorspec, args_flat_jax)) @@ -302,10 +304,7 @@ def check_tf_result(idx: int, r_tf: TfVal, r_aval: core.ShapedArray | None) -> T # that tf.ensure_shape did this, but it can only take shapes that contain None # not computed shapes. However, in eager mode we should be able to resolve # the declared shapes to constants and we get better checking. - if tf.executing_eagerly(): - r_aval_shape_tf = jax2tf_internal._eval_shape(r_aval.shape) - else: - r_aval_shape_tf = jax2tf_internal._aval_to_tf_shape(r_aval) + r_aval_shape_tf = jax2tf_internal._aval_to_tf_shape(r_aval) # We do as much checking as we can here, instead of relying on tf.ensure_shape # because the latter gives different errors in eager vs. compiled mode. # TODO(b/279454591): This strange error is from TF. Eager function suppose @@ -344,9 +343,8 @@ def _call_tf_impl(*args_jax_flat, callable_flat_tf, **_): def _arg_jax_to_tf(arg_jax): if (isinstance(arg_jax, jax.Array) and list(arg_jax.devices())[0].platform in _DLPACK_PLATFORMS and - arg_jax.dtype.type in dlpack.SUPPORTED_DTYPES): - arg_dlpack = jax.dlpack.to_dlpack(arg_jax) - return tf.experimental.dlpack.from_dlpack(arg_dlpack) + dlpack.is_supported_dtype(arg_jax.dtype)): + return tf.experimental.dlpack.from_dlpack(arg_jax.__dlpack__()) # The following avoids copies to the host on CPU, always for Array # and even for ndarray if they are sufficiently aligned. # TODO(necula): on TPU this copies to the host! @@ -362,12 +360,11 @@ def _arg_jax_to_tf(arg_jax): def _res_tf_to_jax(res_tf: TfVal): res_tf, jax_dtype = jax2tf_internal._tfval_to_tensor_jax_dtype(res_tf) - if isinstance(res_tf, tf.Tensor) and jax_dtype.type in dlpack.SUPPORTED_DTYPES: + if isinstance(res_tf, tf.Tensor) and dlpack.is_supported_dtype(jax_dtype): res_tf_platform = tf.DeviceSpec.from_string(res_tf.backing_device).device_type res_jax_platform = res_tf_platform.lower() if res_jax_platform in _DLPACK_PLATFORMS: - res_dlpack = tf.experimental.dlpack.to_dlpack(res_tf) - return jax.dlpack.from_dlpack(res_dlpack) + return jax.dlpack.from_dlpack(res_tf) # When working with a bfloat16 scalar tf.Tensor,np.asarray() can fail. # To handle this special case, we create a numpy copy. @@ -468,6 +465,47 @@ def is_fully_known_shape(s): call_tf_p.def_effectful_abstract_eval(_call_tf_abstract_eval) +def _mlir_type_to_numpy_dtype(type: ir.Type) -> np.dtype: + """Converts an MLIR scalar type to a NumPy dtype.""" + + if isinstance(type, ir.IntegerType): + type = ir.IntegerType(type) + width = type.width + if width == 1: + return np.dtype(np.bool_) + elif width == 8: + return np.dtype(np.uint8 if type.is_unsigned else np.int8) + elif width == 16: + return np.dtype(np.uint16 if type.is_unsigned else np.int16) + elif width == 32: + return np.dtype(np.uint32 if type.is_unsigned else np.int32) + elif width == 64: + return np.dtype(np.uint64 if type.is_unsigned else np.int64) + else: + raise ValueError(f"Unsupported integer width: {width}") + + elif isinstance(type, ir.F16Type): + return np.dtype(np.float16) + elif isinstance(type, ir.F32Type): + return np.dtype(np.float32) + elif isinstance(type, ir.F64Type): + return np.dtype(np.float64) + elif isinstance(type, ir.BF16Type): + return np.dtype(ml_dtypes.bfloat16) + + elif isinstance(type, ir.ComplexType): + element_type = ir.ComplexType(type).element_type + if isinstance(element_type, ir.F32Type): + return np.dtype(np.complex64) + elif isinstance(element_type, ir.F64Type): + return np.dtype(np.complex128) + else: + raise ValueError(f"Unsupported complex element type: {element_type}") + + else: + raise TypeError(f"Unsupported MLIR type for NumPy conversion: {type}") + + def _call_tf_lowering( ctx: mlir.LoweringRuleContext, *args_op, @@ -555,33 +593,8 @@ def convert_to_spec(x): "\n\nCaught TensorFlow exception: " + str(e)) raise ValueError(msg) from e - xla_comp = xla_client.XlaComputation(func_tf_hlo) - - # Canonicalize the results; e.g., makes them x32 if JAX is in 32-bit mode - def canonical_res_aval(res_shape: xla_client.Shape) -> core.ShapedArray: - if not res_shape.is_static(): - msg = ("Compiled TensorFlow function has dynamic output shape " + - f"{res_shape}. call_tf can used " + - "in a staged context (under jax.jit, lax.scan, etc.) only with " + - "compilable functions with static output shapes. " + - "See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion.") - raise ValueError(msg) - - res_dtype = res_shape.numpy_dtype() - jax_res_dtype = dtypes.canonicalize_dtype(res_dtype) - return core.ShapedArray(res_shape.dimensions(), jax_res_dtype) - - result_shape = xla_comp.program_shape().result_shape() - if not result_shape.is_tuple(): - # TF does not wrap singletons as tuples, but JAX expects tuples because - # call_tf is a multiple_results primitive. - result_shapes = (result_shape,) - else: - result_shapes = result_shape.tuple_shapes() # type: ignore - - result_avals = tuple(map(canonical_res_aval, result_shapes)) - - submodule = mlir.xla_computation_to_mlir_module(xla_comp) + stablehlo = _jax.mlir.hlo_to_stablehlo(func_tf_hlo) + submodule = ir.Module.parse(stablehlo) symtab = ir.SymbolTable(submodule.operation) callee_result_types = symtab["main"].type.results fn = mlir.merge_mlir_modules(ctx.module_context.module, @@ -600,10 +613,26 @@ def canonical_res_aval(res_shape: xla_client.Shape) -> core.ShapedArray: ) outputs = [] - for op, res_aval, res_shape in zip(flat_results, result_avals, - result_shapes): - if res_aval.dtype != res_shape.numpy_dtype(): - op = hlo.ConvertOp(mlir.aval_to_ir_type(res_aval), op).result + for op, res_type in zip(flat_results, callee_result_types): + if not res_type.has_static_shape: + msg = ( + "Compiled TensorFlow function has dynamic output shape " + + f"{res_type}. call_tf can used in a staged context (under jax.jit," + " lax.scan, etc.) only with compilable functions with static" + " output shapes. See" + " https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf" + " for a discussion." + ) + raise ValueError(msg) + + res_dtype = _mlir_type_to_numpy_dtype(res_type.element_type) + # Canonicalize the results; e.g., makes them x32 if JAX is in 32-bit mode + jax_res_dtype = dtypes.canonicalize_dtype(res_dtype) + if res_dtype != jax_res_dtype: + op = hlo.ConvertOp( + mlir.aval_to_ir_type(core.ShapedArray(res_type.shape, jax_res_dtype)), + op, + ).result outputs.append(op) return outputs @@ -615,16 +644,6 @@ def _register_call_lowering(platform): for platform in ("cpu", "cuda", "tpu"): _register_call_lowering(platform) -# Support the call_tf under jax2tf.convert in eager mode -def _jax2tf_call_tf(*args: TfVal, - callable_flat_tf: Callable, - **_) -> TfVal: - with jax2tf_internal.inside_call_tf(): - res_tf_flat = callable_flat_tf(*args) - return res_tf_flat - -jax2tf_internal.tf_impl[call_tf_p] = _jax2tf_call_tf - def emit_tf_embedded_graph_custom_call( ctx: mlir.LoweringRuleContext, @@ -695,3 +714,8 @@ def add_to_call_tf_concrete_function_list(concrete_tf_fn: Any, call_tf_concrete_ called_index = len(call_tf_concrete_function_list) call_tf_concrete_function_list.append(concrete_tf_fn) return called_index + +# Register a roofline call so that users can use roofline on functions that +# contain call_tf. We register roofline in this file (instead of within the +# roofline module) to avoid having to import jax2tf in roofline. +roofline.register_standard_roofline(call_tf_p) diff --git a/jax/experimental/jax2tf/examples/keras_reuse_main_test.py b/jax/experimental/jax2tf/examples/keras_reuse_main_test.py index e34282a76ff4..a3c513e8c422 100644 --- a/jax/experimental/jax2tf/examples/keras_reuse_main_test.py +++ b/jax/experimental/jax2tf/examples/keras_reuse_main_test.py @@ -13,6 +13,16 @@ # limitations under the License. import os +import warnings + +# Must be set before import jax, as jax_google.py sets the flag during import. +warnings.filterwarnings( + 'ignore', + message='Setting `jax_pmap_shmap_merge` is deprecated', + category=DeprecationWarning, +) + +# pylint: disable=g-import-not-at-top from absl import flags from absl.testing import absltest from absl.testing import parameterized @@ -21,11 +31,13 @@ from jax.experimental.jax2tf.examples import keras_reuse_main from jax.experimental.jax2tf.tests import tf_test_util +# pylint: enable=g-import-not-at-top jax.config.parse_flags_with_absl() FLAGS = flags.FLAGS +@jtu.thread_unsafe_test_class() class KerasReuseMainTest(tf_test_util.JaxToTfTestCase): def setUp(self): diff --git a/jax/experimental/jax2tf/examples/saved_model_lib.py b/jax/experimental/jax2tf/examples/saved_model_lib.py index 8f2f0982fd3d..313ac521c800 100644 --- a/jax/experimental/jax2tf/examples/saved_model_lib.py +++ b/jax/experimental/jax2tf/examples/saved_model_lib.py @@ -41,7 +41,6 @@ def convert_and_save_model( input_signatures: Sequence[tf.TensorSpec], polymorphic_shapes: str | None = None, with_gradient: bool = False, - enable_xla: bool = True, compile_model: bool = True, saved_model_options: tf.saved_model.SaveOptions | None = None): """Convert a JAX function and saves a SavedModel. @@ -80,8 +79,6 @@ def convert_and_save_model( corresponding input shapes. with_gradient: the value to use for the `with_gradient` parameter for `jax2tf.convert`. - enable_xla: the value to use for the `enable_xla` parameter for - `jax2tf.convert`. compile_model: use TensorFlow jit_compiler on the SavedModel. This is needed if the SavedModel will be used for TensorFlow serving. polymorphic_shapes: if given then it will be used as the @@ -99,8 +96,7 @@ def convert_and_save_model( tf_fn = jax2tf.convert( jax_fn, with_gradient=with_gradient, - polymorphic_shapes=[None, polymorphic_shapes], - enable_xla=enable_xla) + polymorphic_shapes=[None, polymorphic_shapes]) # Create tf.Variables for the parameters. If you want more useful variable # names, you can use `tree.map_structure_with_path` from the `dm-tree` package diff --git a/jax/experimental/jax2tf/examples/saved_model_main_test.py b/jax/experimental/jax2tf/examples/saved_model_main_test.py index 5d698217968d..222ca7e88fb4 100644 --- a/jax/experimental/jax2tf/examples/saved_model_main_test.py +++ b/jax/experimental/jax2tf/examples/saved_model_main_test.py @@ -13,6 +13,16 @@ # limitations under the License. import os +import warnings + +# Must be set before import jax, as jax_google.py sets the flag during import. +warnings.filterwarnings( + 'ignore', + message='Setting `jax_pmap_shmap_merge` is deprecated', + category=DeprecationWarning, +) + +# pylint: disable=g-import-not-at-top from absl import flags from absl.testing import absltest from absl.testing import parameterized @@ -22,11 +32,13 @@ from jax.experimental.jax2tf.examples import saved_model_main from jax.experimental.jax2tf.tests import tf_test_util +# pylint: enable=g-import-not-at-top config.parse_flags_with_absl() FLAGS = flags.FLAGS +@jtu.thread_unsafe_test_class() class SavedModelMainTest(tf_test_util.JaxToTfTestCase): def setUp(self): @@ -47,9 +59,11 @@ def setUp(self): def test_train_and_save_full(self, model="mnist_flax", serving_batch_size=-1): + self.skipTest("no more dynamic shapes") if (serving_batch_size == -1 and - config.jax2tf_default_native_serialization.value and - not config.dynamic_shapes.value): + config.jax2tf_default_native_serialization.value + # and not config.dynamic_shapes.value + ): self.skipTest("shape polymorphism but --jax_dynamic_shapes is not set.") FLAGS.model = model FLAGS.model_classifier_layer = True diff --git a/jax/experimental/jax2tf/g3doc/no_xla_limitations.md b/jax/experimental/jax2tf/g3doc/no_xla_limitations.md deleted file mode 100644 index 24a1d62ee67e..000000000000 --- a/jax/experimental/jax2tf/g3doc/no_xla_limitations.md +++ /dev/null @@ -1,205 +0,0 @@ -# jax2tf Limitations for `enable_xla=False` - -*Note: the list below is only for running jax2tf with `enable_xla=False`. For general jax2tf known issues please see [here](https://github.com/jax-ml/jax/tree/main/jax/experimental/jax2tf#known-issues)* - -For most JAX primitives there is a natural TF op that fits the needed semantics -(e.g., `jax.lax.abs` is equivalent to `tf.abs`). However, there are a number of -JAX primitives for which there is no single TF op with matching semantics -(e.g., `jax.lax.conv_general_dilated` does not have a matching `tf` op). For -these cases, the `jax2tf` emitter uses a set of special TF ops that are thin -wrappers over HLO ops. - -However, these ops are only be executable by a consumer that has XLA linked in, -and this is not the case for the TF.js and TFLite converter. Therefore we -provide limited support for these ops by implementing them in terms of ops that -are supported in [impl_no_xla.py](../impl_no_xla.py). - -## Summary Table - -The table below shows for each XLA ops by which JAX primitives it is used, and -whether the ops is fully, partially, or not supported by the `jax2tf` emitter. -In the next section we provide more details on the ops for which we provide -partial support. - -For a detailed description of these XLA ops, please see the -[XLA Operation Semantics documentation](https://www.tensorflow.org/xla/operation_semantics). - -| XLA ops ([documentation](https://www.tensorflow.org/xla/operation_semantics)) | JAX primitive(s) ([documentation](https://jax.readthedocs.io/en/latest/jax.lax.html)) | Supported | -| ------- | ---------------- | ------- | -| XlaDot | `lax.dot_general` | Full | -| XlaDynamicSlice | `lax.dynamic_slice` | Full | -| XlaDynamicUpdateSlice | `lax.dynamic_update_slice` | Full | -| XlaPad | `lax.pad` | Full | -| XlaConv | `lax.conv_general_dilated` | [Partial](#xlaconv) | -| XlaGather | `lax.gather` | [Partial](#xlagather) | -| XlaReduceWindow | `lax.reduce_window` | [Partial](#xlareducewindow) | -| XlaScatter | `lax.scatter`, `lax.scatter_min`, `lax.scatter_max`, `lax.scatter_mul`, `lax.scatter_add` | [Partial](#xlascatter) | -| XlaSelectAndScatter | `lax._select_and_scatter_add` | Unsupported | -| XlaReduce | `lax.reduce`, `lax.argmin`, `lax.argmax` | Unsupported | -| XlaVariadicSort | `lax.sort` | Unsupported | - - -## Partially Supported JAX Primitives - -Below we describe for all partially supported JAX primitives which cases we -support and which not. - -### XlaConv - -JAX convolutions are done using -[`lax.conv_general_dilated`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv_general_dilated.html). - -``` -lax.conv_general_dilated( - lhs, rhs, window_strides, padding, lhs_dilation, - rhs_dilation, dimension_numbers, feature_group_count, - batch_group_count, precision, preferred_element_type -) -``` - -We provide support for convolutions as follows: - -* Only 1D and 2D convolutions, i.e. `lhs.ndim == 3 or 4`. -* Regular convolutions and atrous (aka, dilated) convolutions - (i.e., `rhs_dilation != (1, 1, ...)`) are supported through the TF op - [`tf.nn.conv2d`](https://www.tensorflow.org/api_docs/python/tf/nn/conv2d). -* Transposed convolutions (i.e., `lhs_dilation != (1, 1, ...)`) are supported - through - [`tf.nn.conv2d_transpose`](https://www.tensorflow.org/api_docs/python/tf/nn/conv2d_transpose), - with either 'SAME' or 'VALID' padding. -* Depthwise convolutions (i.e. - `in_channels == feature_group_count and feature_group_count > 1`) are - supported through - [`tf.nn.depthwise_conv2d`](https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d). - Note that atrous depthwise convolutions are supported. -* No support for batch groups, i.e. `batch_group_count == 1`. -* No support for feature groups, except for depth-wise convolutions. -* Input may be provided in any order (specified using `dimension_numbers`). -* Only one of depthwise, atrous and transposed convolutions may be used at the - same time, though depthwise atrous convolutions are supported. -* Convolutions are known to have a somewhat higher numeric inaccuracy, so if you - are using many large convolutions, this may lead to large deviations. - -### XlaGather - -XLA's gather op is complex and covers may use cases. It is called from JAX using -`lax.gather`, but many other primitives and operations use it as well, for -instance, parallelization primitives `vmap` and `pmap` use gather to specify a -batch dimension, and it is used for slices or multidimensional indexing as well, -e.g. `x[0, 1]`, `x[:, :1]`, or `x[[0], [1]]`. - -The signature of [`lax.gather`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.gather.html#jax.lax.gather) -is as follows: - -``` -lax.gather( - operand, start_indices, dimension_numbers, slice_sizes, - unique_indices=False, indices_are_sorted=False, mode=None, - fill_value=None -) -``` - -We provide support for the following cases: - -* *Scalar indexing*. This means we are indexing into a single dimension, - retrieving either a partial slice or a single value from that dimension. For - all other dimensions we retrieve the full slice. Examples include `op[2]`, - `op[:, :5, :]`, and `jnp.take(op, 0, axis=0)`. This means that - `len(start_indices.shape) == 1`. We provide support for this path through the - TF op - [`tf.strided_slice`](https://www.tensorflow.org/api_docs/python/tf/strided_slice). - -* *Multi-dimensional indexing*. This means we index into multiple dimensions, - e.g., `jnp.take(op, [[0], [1]], axis=0)` or `op[[0], [4], [1]]`. We currently - only support multi-dimensional indexing if the last dimension is 1, which - means we can only retrieve a single value per dimension, and we can't retrieve - slices. We provide support for this path through the TF op - [`tf.gather`](https://www.tensorflow.org/api_docs/python/tf/gather). - -* *Gather with a batch dimension*. E.g., when doing - `jax.vmap(lax.dynamic_slice)`, which will result in a call to `lax.gather` - where the first dimension of the input is the batch dimension. This means that - `len(batch_dims) == 1`. We currently only support a single batch dimension - (i.e., `vmap(vmap))` does not work). We provide support for this path through - the TF op [`tf.slice`](https://www.tensorflow.org/api_docs/python/tf/slice). - -All other cases of `lax.gather` are currently not supported. - - -### XlaReduceWindow - -The signature of [`lax.reduce_window`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.reduce_window.html) -is as follows: - -``` -lax.reduce_window(operand, init_value, computation: Callable, - window_dimensions: core.Shape, window_strides: Sequence[int], - padding: Union[str, Sequence[Tuple[int, int]]], - base_dilation: Optional[Sequence[int]] = None, - window_dilation: Optional[Sequence[int]] = None) -) -``` - -This function with either call a monoid reducer `lax.reduce_window_min_p`, -`lax.reduce_window_max_p`, or `lax.reduce_window_sum_p`, or the full reducer -function `lax.reduce_window_p` with the following conditions: - -* If `computation` is one of `lax.min`, `lax.max`, or `lax.add` and `init_value` - is the identity element for `computation` (for instance: 0 for `lax.add`), - then it will call one of the monoid reducers. - -* Otherwise, it will call the full reduction function `lax.reduce_window_p`. - -We provide partial support for all these ops, with the following limitations: - -* `computation` should be one of `lax.min`, `lax.max`, or `lax.add`. -* For `lax.min` and `lax.max`, dtypes `np.bool`, `np.uint32`, `np.uint64`, - `np.complex64`, and `np.complex128` are not supported. -* Additionally, for `lax.min`, dtypes `np.uint8` and `np.uint16` are not - supported. -* For `lax.add`, only dtypes `np.float16`, `np.float32`, and `np.float64` are - supported. -* We support at most 2 spatial dimension. -* Base dilations other than `(1,) * len(operand)` are not supported. -* `padding` should either be `VALID` or `SAME`. -* We compute `lax.reduce_window_sum_p` by calling `tf.nn.avg_pool` (through - `tf.nn.pool`), and then multiplying the result by - `np.prod(window_dimensions)`. If you are using an NN library that implements - `avg_pool` using `lax.reduce_window` (such as Flax's - [pooling.py](https://github.com/google/flax/blob/main/flax/linen/pooling.py)), - this is usually implemented by dividing the result with - `np.prod(window_dimensions)`. So when converting this function, the - resulting computation for `avg_pool` is `(tf.nn.avg_pool(xs) * - np.prod(window)) / np.prod(window)`. This is redundant and can be optimized. -* Using `lax.add` on TPU may give very large deviations. This is due to the - way the conversion is implemented (first take the average over the window - and then multiply by window size). This gives large deviations on TPU due to - the fact that it uses `bfloat16` for computations. - -We implement all reductions using the Tensorflow function -[tf.nn.pool](https://www.tensorflow.org/api_docs/python/tf/nn/pool). - -### XlaScatter - -This op is called by `lax.scatter`, `lax.scatter_min`, `lax.scatter_max`, -`lax.scatter_mul` and `lax.scatter_add`. - -We support all these ops for unique indices. For non-unique indices we -support (min,max,mul,add) for single depth scatters. - -We implement support for this op through -[tf.tensor_scatter_nd_update](https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_update). - -There are a few more limitations: - -* Dtypes `np.bool` and `jnp.complex*` are not supported. -* We disallow scatter mode `lax.GatherScatterMode.CLIP` because it may lead to - incorrect behavior for out-of-bounds indices (see next point). -* The behavior for out-of-bounds scatter indices is as follows: - - When running in eager or graph mode, it throws an error. This is because - `tf.scatter` throws an error as well. If this is problematic for your use - case, please let us know and we can add more support for this. - - When running in compile mode, the out-of-bounds indices are dropped, which - is the behavior of both `lax.GatherScatterMode.FILL_OR_DROP` and - `lax.GatherScatterMode.PROMISE_IN_BOUNDS`. This is why `CLIP` is not - allowed. diff --git a/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md b/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md deleted file mode 100644 index b36b004a9d31..000000000000 --- a/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md +++ /dev/null @@ -1,168 +0,0 @@ -# Primitives with limited support for jax2tf - -*Last generated on (YYYY-MM-DD): 2023-07-31* - -This document summarizes known limitations of the jax2tf conversion. -There are several kinds of limitations. - - * There are some JAX primitives that are converted to TF ops that have incomplete coverage - for data types on different kinds of devices, - see [below](#generated-summary-of-primitives-with-unimplemented-support-in-tensorflow). - - * There are some cases when the converted program computes different results than - the JAX program, see [below](#generated-summary-of-primitives-with-known-numerical-discrepancies-in-tensorflow). - -Note that automated tests will fail if new limitations appear, but -they won't when limitations are fixed. If you see a limitation that -you think it does not exist anymore, please ask for this file to -be updated. - -## Generated summary of primitives with unimplemented support in Tensorflow - -The following JAX primitives are converted to Tensorflow but the result of the -conversion may trigger runtime errors when run on certain devices and with -certain data types. - -This table is organized by JAX primitive, but the actual errors described -in the table are for the Tensorflow ops to which the primitive is converted to. -In general, each JAX primitive is mapped -to one Tensorflow op, e.g., `sin` is mapped to `tf.math.sin`. - -The errors apply only for certain devices and compilation modes ("eager", -"graph", and "compiled"). In general, "eager" and "graph" mode share the same errors. -On TPU only the "compiled" mode is relevant. - -Our priority is to ensure same coverage and numerical behavior with JAX -in the "compiled" mode, i.e., **when using XLA to compile the converted program**. -We are pretty close to that goal. - -The converter has a mode in which it attempts to avoid special XLA TF ops -(`enable_xla=False`). In this mode, some primitives have additional limitations. - -This table only shows errors for cases that are working in JAX (see [separate -list of unsupported or partially-supported primitives](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md) ) - -We do not yet have support for `pmap` (with its collective primitives), -nor for `sharded_jit` (SPMD partitioning). - -We use the following abbreviations for sets of dtypes: - - * `signed` = `int8`, `int16`, `int32`, `int64` - * `unsigned` = `uint8`, `uint16`, `uint32`, `uint64` - * `integer` = `signed`, `unsigned` - * `floating` = `float16`, `bfloat16`, `float32`, `float64` - * `complex` = `complex64`, `complex128` - * `inexact` = `floating`, `complex` - * `all` = `integer`, `inexact`, `bool` - -More detailed information can be found in the -[source code for the limitation specification](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/tests/primitives_test.py). - - -| Affected primitive | Description of limitation | Affected dtypes | Affected devices | Affected compilation modes | -| --- | --- | --- | --- | --- | -| approx_top_k | TF error: compilation not supported for float64. | float64 | cpu, gpu | compiled | -| approx_top_k | TF error: op not defined for dtype | floating | cpu, gpu | eager, graph | -| bessel_i0e | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph | -| bessel_i1e | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph | -| cholesky | TF test skipped: Not implemented in JAX: unimplemented | float16 | cpu, gpu | compiled, eager, graph | -| clamp | TF test skipped: Not implemented in JAX: unimplemented | bool, complex | cpu, gpu, tpu | compiled, eager, graph | -| conv_general_dilated | TF error: Numeric comparison disabled: Non-deterministic NaN for conv_general_dilated with preferred_element_type | int16, int32, int64 | cpu, gpu, tpu | compiled, eager, graph | -| conv_general_dilated | TF test skipped: Not implemented in JAX: preferred_element_type not implemented for integers | signed | gpu | compiled, eager, graph | -| digamma | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph | -| div | TF error: TF integer division fails if divisor contains 0; JAX returns NaN | integer | cpu, gpu, tpu | compiled, eager, graph | -| dot_general | TF test skipped: TF error: Numeric comparison disabled: Crash when lhs_dtype != rhs_dtype for non-native serialization on TPU | all | tpu | compiled, eager, graph | -| dot_general | TF error: Numeric comparison disabled: Errors when lhs_dtype != rhs_dtype for non-native serialization on CPU and GPU | all | cpu, gpu, tpu | compiled, eager, graph | -| dot_general | TF error: Numeric comparison disabled: Large tolerances when upcasting with preferred_element_type on CPU (b/241740367) | all | cpu, gpu, tpu | compiled, eager, graph | -| dot_general | TF error: Numeric comparison disabled: Non-deterministic NaN for dot_general with preferred_element_type on GPU (b/189287598) | bfloat16, complex64, float16, float32 | gpu | compiled, eager, graph | -| dot_general | TF test skipped: Not implemented in JAX: preferred_element_type must be floating for integer dtype | integer | gpu | compiled, eager, graph | -| dot_general | TF test skipped: Not implemented in JAX: preferred_element_type must match dtype for floating point | inexact | gpu | compiled, eager, graph | -| dot_general | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph | -| eig | TF test skipped: Not implemented in JAX: only supported on CPU in JAX | all | gpu, tpu | compiled, eager, graph | -| eig | TF test skipped: Not implemented in JAX: unimplemented | bfloat16, float16 | cpu | compiled, eager, graph | -| eig | TF error: TF Conversion of eig is not implemented when both compute_left_eigenvectors and compute_right_eigenvectors are set to True | all | cpu, gpu, tpu | compiled, eager, graph | -| eig | TF error: function not compilable | all | cpu | compiled | -| eigh | TF test skipped: Not implemented in JAX: unimplemented | bfloat16, float16 | cpu, gpu | compiled, eager, graph | -| eigh | TF error: op not defined for dtype | bfloat16 | tpu | compiled, eager, graph | -| erf_inv | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu | eager, graph | -| fft | TF error: TF function not compilableble | float64 | cpu, gpu | compiled | -| fft | TF error: TF function not compilableble for IFFT and IRFFT | complex128 | cpu, gpu | compiled | -| igamma | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu | eager, graph | -| igammac | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu | eager, graph | -| lgamma | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph | -| lu | TF test skipped: Not implemented in JAX: unimplemented | bfloat16, float16 | cpu, gpu, tpu | compiled, eager, graph | -| nextafter | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu, tpu | compiled, eager, graph | -| qr | TF test skipped: Not implemented in JAX: unimplemented | bfloat16, float16 | cpu, gpu | compiled, eager, graph | -| qr | TF error: op not defined for dtype | bfloat16 | tpu | compiled, eager, graph | -| reduce_max | TF error: op not defined for dtype | complex | cpu, gpu, tpu | compiled, eager, graph | -| reduce_min | TF error: op not defined for dtype | complex | cpu, gpu, tpu | compiled, eager, graph | -| regularized_incomplete_beta | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu, tpu | compiled, eager, graph | -| rem | TF error: Numeric comparison disabled: TF division of inf by inf returns inf while in JAX returns nan | float32 | gpu | compiled, eager, graph | -| rem | TF error: Numeric comparison disabled: TF integer division fails if divisor contains 0; JAX returns NaN | integer | cpu, gpu, tpu | compiled, eager, graph | -| round | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph | -| scatter | TF error: Numeric comparison disabled: out-of-bounds scatters are not supported in graph and eager mode | inexact | cpu, gpu, tpu | eager, graph | -| scatter_add | TF test skipped: Not implemented in JAX: unimplemented | bool | cpu, gpu, tpu | compiled, eager, graph | -| scatter_add | TF error: Numeric comparison disabled: out-of-bounds scatters are not supported in graph and eager mode | inexact | cpu, gpu, tpu | eager, graph | -| scatter_max | TF error: Numeric comparison disabled: out-of-bounds scatters are not supported in graph and eager mode | inexact | cpu, gpu, tpu | eager, graph | -| scatter_min | TF error: Numeric comparison disabled: out-of-bounds scatters are not supported in graph and eager mode | inexact | cpu, gpu, tpu | eager, graph | -| scatter_mul | TF test skipped: Not implemented in JAX: unimplemented | bool | cpu, gpu, tpu | compiled, eager, graph | -| scatter_mul | TF error: Numeric comparison disabled: out-of-bounds scatters are not supported in graph and eager mode | inexact | cpu, gpu, tpu | eager, graph | -| select_and_gather_add | TF error: jax2tf unimplemented for 64-bit inputs because the current implementation relies on packing two values into a single value. This can be fixed by using a variadic XlaReduceWindow, when available | float64 | cpu, gpu | compiled, eager, graph | -| select_and_scatter_add | TF test skipped: Not implemented in JAX: works only for 2 or more inactive dimensions | all | tpu | compiled, eager, graph | -| svd | TF error: Numeric comparison disabled: Large numerical discrepancy | float16 | tpu | compiled, eager, graph | -| svd | TF test skipped: Not implemented in JAX: unimplemented | bfloat16, float16 | cpu, gpu | compiled, eager, graph | -| svd | TF error: function not compilable. Implemented using `tf.linalg.svd` and `tf.linalg.adjoint` | complex | cpu, gpu | compiled | -| svd | TF error: op not defined for dtype | bfloat16 | tpu | compiled, eager, graph | -| svd | TF error: op not defined for dtype | complex | tpu | compiled, graph | -| triangular_solve | TF test skipped: Not implemented in JAX: unimplemented | float16 | gpu | compiled, eager, graph | -| triangular_solve | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph | -| triangular_solve | TF error: op not defined for dtype | float16 | cpu, gpu | eager, graph | - -## Generated summary of primitives with known numerical discrepancies in Tensorflow - -In general, we expect a JAX program to produce the same exact answer as its conversion -with jax2tf. The following table lists that cases when this does not quite hold: - - -| Affected primitive | Description of limitation | Affected dtypes | Affected devices | Affected compilation modes | -| --- | --- | --- | --- | --- | -| acosh | May return different but still correct results | complex | cpu, gpu, tpu | eager, graph | -| approx_top_k | custom numeric comparison | floating | cpu, gpu | eager, graph | -| argmax | Numeric comparison disabled: different results when the input contains NaN and enable_xla=False | inexact | cpu, gpu, tpu | compiled, eager, graph | -| argmin | Numeric comparison disabled: different results when the input contains NaN and enable_xla=False | inexact | cpu, gpu, tpu | compiled, eager, graph | -| asin | May return different but still correct results | complex | cpu, gpu, tpu | eager, graph | -| asinh | May return different but still correct results | complex | cpu, gpu, tpu | eager, graph | -| atan | May return different but still correct results | complex | cpu, gpu, tpu | eager, graph | -| atanh | May return different but still correct results | complex | cpu, gpu, tpu | eager, graph | -| cholesky | May return different values in the strictly upper triangular part of the result. This does not matter for correctness, because this part of the matrix is not considered in the result. | all | cpu, gpu, tpu | compiled, eager, graph | -| custom_linear_solve | Numeric comparison disabled: TODO: large numerical discrepancy | float32 | tpu | compiled, eager, graph | -| digamma | May return different results at singularity points 0 and -1.JAX returns nan and TF returns inf | bfloat16 | cpu, gpu, tpu | eager, graph | -| eig | May return the eigenvalues and eigenvectors in a potentially different order. The eigenvectors may also be different, but equally valid. | all | cpu, gpu, tpu | eager, graph | -| eigh | May return the eigenvalues and eigenvectors in a potentially different order. The eigenvectors may also be different, but equally valid. | all | cpu, gpu, tpu | compiled, eager, graph | -| eigh | Numeric comparison disabled: TODO: numeric discrepancies | float16 | tpu | compiled, eager, graph | -| erf_inv | May return different results at undefined points (< -1 or > 1): JAX returns `NaN` and TF returns `+inf` or `-inf`. | float32, float64 | cpu, gpu, tpu | eager, graph | -| igamma | May return different results at undefined points (both arguments 0). JAX returns `NaN` and TF returns 0 or JAX returns 1 and TF returns `NaN` | all | cpu, gpu, tpu | eager, graph | -| igammac | May return different results at undefined points (both arguments less or equal 0). JAX returns `NaN` and TF returns 0 or JAX returns 1 and TF returns `NaN` | all | cpu, gpu | eager, graph | -| integer_pow | Numeric comparison disabled: Different overflow behavior for large exponents. | bfloat16, complex, float16, float32, signed | cpu, gpu, tpu | eager, graph | -| integer_pow | Numeric comparison disabled: Different overflow behavior. | bfloat16, float16 | tpu | eager, graph | -| integer_pow | custom numeric comparison | complex | cpu, gpu, tpu | eager, graph | -| lu | May return different, but also correct, results when the decomposition is not unique | all | cpu, gpu | compiled, eager, graph | -| max | May return different values when one of the values is NaN. JAX always returns NaN, while TF returns the value NaN is compared with. | all | cpu, gpu, tpu | compiled, eager, graph | -| max | TF and JAX use different values of the compiler flag xla_cpu_enable_fast_min_max compiler flag and therefore have different behavior of NaN propagation through min/max. | all | cpu | compiled, eager, graph | -| min | May return different values when one of the values is NaN. JAX always returns NaN, while TF returns the value NaN is compared with. | all | cpu, gpu, tpu | compiled, eager, graph | -| min | TF and JAX use different values of the compiler flag xla_cpu_enable_fast_min_max compiler flag and therefore have different behavior of NaN propagation through min/max. | all | cpu | compiled, eager, graph | -| pow | custom numeric comparison | complex | cpu, gpu, tpu | eager, graph | -| random_split | Returns JAX key arrays, so compare underlying base array | all | cpu, gpu, tpu | compiled, eager, graph | -| reduce_window_add | Numeric comparison disabled: Large deviations on TPU for enable_xla=False | float16, float32 | tpu | compiled, eager, graph | -| sort | Numeric comparison disabled: TODO: TF non-stable multiple-array sort | all | gpu | compiled, eager, graph | -| svd | custom numeric comparison when compute_uv on CPU/GPU | all | cpu, gpu | compiled, eager, graph | -| svd | custom numeric comparison when compute_uv on TPU | complex, float32, float64 | tpu | compiled, eager, graph | -| top_k | Produces different results when the array contains `inf` and `NaN` (they are sorted differently in TF vs. XLA). | floating | cpu, gpu, tpu | eager, graph | - -## Updating the documentation - -To update this documentation, run the following command: - -``` - JAX_ENABLE_X64=1 JAX_OUTPUT_LIMITATIONS_DOC=1 python jax/experimental/jax2tf/tests/primitives_test.py JaxPrimitiveTest.test_generate_limitations_doc -``` diff --git a/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md.template b/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md.template deleted file mode 100644 index 219802f5363a..000000000000 --- a/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md.template +++ /dev/null @@ -1,76 +0,0 @@ -# Primitives with limited support for jax2tf - -*Last generated on (YYYY-MM-DD): {{generation_date}}* - -This document summarizes known limitations of the jax2tf conversion. -There are several kinds of limitations. - - * There are some JAX primitives that are converted to TF ops that have incomplete coverage - for data types on different kinds of devices, - see [below](#generated-summary-of-primitives-with-unimplemented-support-in-tensorflow). - - * There are some cases when the converted program computes different results than - the JAX program, see [below](#generated-summary-of-primitives-with-known-numerical-discrepancies-in-tensorflow). - -Note that automated tests will fail if new limitations appear, but -they won't when limitations are fixed. If you see a limitation that -you think it does not exist anymore, please ask for this file to -be updated. - -## Generated summary of primitives with unimplemented support in Tensorflow - -The following JAX primitives are converted to Tensorflow but the result of the -conversion may trigger runtime errors when run on certain devices and with -certain data types. - -This table is organized by JAX primitive, but the actual errors described -in the table are for the Tensorflow ops to which the primitive is converted to. -In general, each JAX primitive is mapped -to one Tensorflow op, e.g., `sin` is mapped to `tf.math.sin`. - -The errors apply only for certain devices and compilation modes ("eager", -"graph", and "compiled"). In general, "eager" and "graph" mode share the same errors. -On TPU only the "compiled" mode is relevant. - -Our priority is to ensure same coverage and numerical behavior with JAX -in the "compiled" mode, i.e., **when using XLA to compile the converted program**. -We are pretty close to that goal. - -The converter has a mode in which it attempts to avoid special XLA TF ops -(`enable_xla=False`). In this mode, some primitives have additional limitations. - -This table only shows errors for cases that are working in JAX (see [separate -list of unsupported or partially-supported primitives](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md) ) - -We do not yet have support for `pmap` (with its collective primitives), -nor for `sharded_jit` (SPMD partitioning). - -We use the following abbreviations for sets of dtypes: - - * `signed` = `int8`, `int16`, `int32`, `int64` - * `unsigned` = `uint8`, `uint16`, `uint32`, `uint64` - * `integer` = `signed`, `unsigned` - * `floating` = `float16`, `bfloat16`, `float32`, `float64` - * `complex` = `complex64`, `complex128` - * `inexact` = `floating`, `complex` - * `all` = `integer`, `inexact`, `bool` - -More detailed information can be found in the -[source code for the limitation specification](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/tests/primitives_test.py). - -{{tf_error_table}} - -## Generated summary of primitives with known numerical discrepancies in Tensorflow - -In general, we expect a JAX program to produce the same exact answer as its conversion -with jax2tf. The following table lists that cases when this does not quite hold: - -{{tf_numerical_discrepancies_table}} - -## Updating the documentation - -To update this documentation, run the following command: - -``` - JAX_ENABLE_X64=1 JAX_OUTPUT_LIMITATIONS_DOC=1 python jax/experimental/jax2tf/tests/primitives_test.py JaxPrimitiveTest.test_generate_limitations_doc -``` diff --git a/jax/experimental/jax2tf/impl_no_xla.py b/jax/experimental/jax2tf/impl_no_xla.py deleted file mode 100644 index 644c3324b4e2..000000000000 --- a/jax/experimental/jax2tf/impl_no_xla.py +++ /dev/null @@ -1,1288 +0,0 @@ -# Copyright 2020 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Workarounds for jax2tf transforms when XLA is not linked in.""" - -from __future__ import annotations - -import builtins -from collections.abc import Callable, Sequence -import dataclasses -from functools import partial, wraps -import math -import string -from typing import Any - -from jax._src import core -from jax import lax -from jax._src.lax import slicing as lax_slicing -from jax._src import dtypes -from jax._src import util - -from jax.experimental.jax2tf import jax2tf - -import numpy as np -import tensorflow as tf - - -# Implementation rules for primitives when XLA is not linked in. These -# implementations are workarounds, making use of TF ops that do work when XLA is -# not linked in. They are only used when the argument `enable_xla=False` when -# calling jax2tf.convert(). -tf_impl_no_xla: dict[core.Primitive, Callable[..., Any]] = {} - - -TfVal = Any -DType = Any -PrecisionType = Any - - -def _error(primitive_name: str, suffix_msg: str = "") -> Exception: - msg = f"Call to {primitive_name} cannot be converted with enable_xla=False." - if suffix_msg: - msg += (f" {suffix_msg} - See source code for the precise conditions under " - "which it can be converted without XLA.") - return NotImplementedError(msg) - -_conv_error = lambda msg: _error("conv_general_dilated", msg) -_reduce_error = lambda msg: _error("reduce_window", msg) -_scatter_error = lambda msg: _error("scatter_(update/add/multiply/min/max)", msg - ) - -def _unimplemented(name): - - def op(*arg, **kwargs): - raise _error(name) - - return op - - -# TODO(marcvanzee): Remove this function and use `tf.math.invert_permutation` -# once it is implemented by TFjs: -# https://github.com/tensorflow/tfjs/issues/6395. -def _invert_permutation(perm): - return tuple(perm.index(i) for i in range(len(perm))) - - -def _transpose_with_shape(x: TfVal, x_shape: core.Shape, permutation) -> tuple[TfVal, core.Shape]: - """Computes transposition of x and its shape. - - x_shape matches x.shape in the known dimensions, and it has dimension - polynomials elsewhere, while x.shape has None. - """ - return tf.transpose(x, perm=permutation), tuple(x_shape[i] for i in permutation) - - -def _transpose_for_tf_conv(lhs, lhs_shape: core.Shape, - rhs, rhs_shape: core.Shape, dimension_numbers): - """Transposes lhs and rhs to respectively NHWC and HWIO so they can be passed to TF functions. - - The shapes passed in and returned may contain polynomials, and thus may - be different than lhs.shape and rhs.shape. - """ - # TODO(marcvanzee): Add tests for this ops for shape polymorphism. - lhs_perm, rhs_perm, _ = dimension_numbers - - # TODO(marcvanzee): Consider merging transposes if we want to optimize. - # For `lhs_perm` / `output_perm`, perm (0, 1, 2, 3) corresponds to "NCHW". - lhs, lhs_shape = _transpose_with_shape(lhs, lhs_shape, lhs_perm) # lhs --> "NCHW" - if len(lhs_perm) == 3: - # For 1D convolution, we add a trivial "W" dimension, so that 2D Convolution - # logic can be applied downstream. - lhs = lhs[:, :, :, np.newaxis] - lhs_shape = tuple(lhs_shape) + (1,) - # However, the TF ops only support "NHWC" on CPU, so we transpose again. - lhs, lhs_shape = _transpose_with_shape(lhs, lhs_shape, (0, 2, 3, 1)) # "NCHW" --> "NHWC" - - # For `rhs_perm`, perm (0, 1, 2, 3) corresponds to "OIHW". - rhs, rhs_shape = _transpose_with_shape(rhs, rhs_shape, rhs_perm) # rhs --> "OIHW" - # Handle conv1d case. - if len(rhs_perm) == 3: - rhs = rhs[:, :, :, np.newaxis] - rhs_shape = tuple(rhs_shape) + (1,) - # For the tf ops, rhs is expected to be "OIHW". - rhs, rhs_shape = _transpose_with_shape(rhs, rhs_shape, (2, 3, 1, 0)) # "OIHW" --> "HWIO" - jax2tf._assert_matching_abstract_shape(lhs, lhs_shape) - jax2tf._assert_matching_abstract_shape(rhs, rhs_shape) - return lhs, lhs_shape, rhs, rhs_shape - - -def pads_to_padtype(in_shape, window_shape, window_strides, padding) -> str: - for pad_str in ["VALID", "SAME"]: - pads = lax.padtype_to_pads(in_shape, window_shape, window_strides, pad_str) - if list(pads) == list(padding): - return pad_str - return "EXPLICIT" - - -def _pad_spatial_dims(x, x_shape, padding): - """Pads `x` using `padding`, which specifies padding for the spatial dimensions.""" - padding = tuple(padding) - if len(padding) == len(x_shape) - 2: - # If necessary, add empty padding for batch and feature dimensions. - no_pad = ((0, 0),) - padding = no_pad + padding + no_pad - x = tf.pad(x, padding) - assert len(x.shape) == len(padding) - x_shape = tuple(p0 + xs + p1 for xs, (p0, p1) in zip(x_shape, padding)) - jax2tf._assert_matching_abstract_shape(x, x_shape) - return x, x_shape - -def _check_pad_spatial_dims(x, x_shape, padding): - """Pads `x` using `padding`, which specifies padding for the spatial dimensions.""" - padding = tuple(padding) - if len(padding) == len(x_shape) - 2: - # If necessary, add empty padding for batch and feature dimensions. - no_pad = ((0, 0),) - padding = no_pad + padding + no_pad - assert len(x.shape) == len(padding) - x_shape = tuple(p0 + xs + p1 for xs, (p0, p1) in zip(x_shape, padding)) - return x, x_shape, padding - -def _conv_transpose_pads_to_padtype(kernel_sdims, lhs_dilation, padding): - """Finds the padding type for a transpose convolution.""" - # This is simply checking agreement with lax._conv_transpose_padding. - is_valid = True - is_same = True - if not len(kernel_sdims) == len(lhs_dilation) == len(padding): - raise ValueError(f'Found different lengths for ' - f'kernel_sdims ({kernel_sdims}), ' - f'lhs_dilation ({lhs_dilation}), ' - f'and padding ({padding}).') - for k, s, (begin, end) in zip(kernel_sdims, lhs_dilation, padding): - # Check for VALID padding. - pad_len_valid = k + s - 2 + builtins.max(k - s, 0) - pad_a = k - 1 - pad_b = pad_len_valid - pad_a - if begin != pad_a or end != pad_b: - is_valid = False - - # Check for SAME padding. - pad_len_same = k + s - 2 - if s > k - 1: - pad_a = k - 1 - else: - pad_a = int(np.ceil(pad_len_same / 2)) - pad_b = pad_len_same - pad_a - if begin != pad_a or end != pad_b: - is_same = False - - if is_valid: - return 'VALID' - elif is_same: - return 'SAME' - raise ValueError('Transpose convolution padding mode must be ' - '`SAME` or `VALID`.') - -def _validate_spatial_dimensions(lhs: TfVal, lhs_shape: core.Shape, - rhs: TfVal, rhs_shape: core.Shape): - """Check spatial dimension support.""" - jax2tf._assert_matching_abstract_shape(lhs, lhs_shape) - jax2tf._assert_matching_abstract_shape(rhs, rhs_shape) - - nr_spatial_dimensions = len(lhs_shape) - 2 - # Currently we only support 1D+2D convolutions because it keeps the code - # relatively simple and covers most cases. - if nr_spatial_dimensions > 2: - raise _conv_error( - "We only support 1D or 2D convolutions, but found " - f"{nr_spatial_dimensions}.") - - -def _normalize_padding_and_dilations( - padding, lhs_dilation, rhs_dilation, is_conv1d): - if is_conv1d: - lhs_dilation = list(lhs_dilation) + [1] - rhs_dilation = list(rhs_dilation) + [1] - # Empty padding in the dummy dimension. - # Note that when kernel_size=stride=1, padding of (0, 0) is both 'VALID' and - # 'SAME'. So the inferred padding type will still register according to the - # first dimension padding. - padding = list(padding) + [(0, 0)] - return padding, lhs_dilation, rhs_dilation - - -def _normalize_window_strides(window_strides): - """Ensure window_strides has length 4.""" - # Some TF ops require len(window_strides) == 4 while others do not. We simply - # ensure it always has len(4). - if len(window_strides) == 1: - # This is the Conv1D case. We add a dummy dimension to allow using 2D ops, - # and use stride=1 on the dummy dimension. - window_strides = list(window_strides) + [1] - if len(window_strides) == 2: - window_strides = [1] + list(window_strides) + [1] - return window_strides - - -def _validate_conv_features( - is_transpose, is_atrous, is_depthwise, feature_group_count, - batch_group_count, preferred_element_type, lhs_dtype): - if feature_group_count > 1 and not is_depthwise: - raise _conv_error("Grouped convolutions are unsupported") - if (is_depthwise and is_atrous) and not is_transpose: - # We allow dilated depthwise convolutions. - pass - elif [is_depthwise, is_atrous, is_transpose].count(True) > 1: - raise _conv_error( - f"Can only do one of depthwise ({is_depthwise}), atrous ({is_atrous}) " - f"and transposed convolutions ({is_transpose})") - - # We can implement batch grouping when there is a need for it. - if batch_group_count != 1: - raise _conv_error("Unimplemented support for batch_group_count != 1 " - f"(found {batch_group_count})") - - if (preferred_element_type is not None and - preferred_element_type != lhs_dtype): - raise _conv_error("Unimplemented support for preferred_element_type") - - -def _conv_general_dilated( - lhs, rhs, *, window_strides, padding, lhs_dilation, rhs_dilation, - dimension_numbers: lax.ConvDimensionNumbers, feature_group_count: int, - batch_group_count: int, - precision: tuple[PrecisionType, PrecisionType] | None, - preferred_element_type: DType | None, - _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray): - """Implementation of lax.conv_general_dilated_p using XlaConv.""" - # In presence of shape polymorphism, lhs.shape and rhs.shape may contain - # None. The actual dimension polynomial shapes are in _in_avals. - del precision # Unused arguments. - lhs_shape, rhs_shape = _in_avals[0].shape, _in_avals[1].shape - out_shape = _out_aval.shape - _validate_spatial_dimensions(lhs, lhs_shape, rhs, rhs_shape) - is_conv1d = len(lhs_shape) - 2 == 1 - - tf_window_strides = _normalize_window_strides(window_strides) - padding, lhs_dilation, rhs_dilation = _normalize_padding_and_dilations( - padding, lhs_dilation, rhs_dilation, is_conv1d) - - lhs, lhs_shape, rhs, rhs_shape = _transpose_for_tf_conv(lhs, lhs_shape, - rhs, rhs_shape, - dimension_numbers) - in_channels = lhs_shape[-1] - *rhs_spatial_shapes, _, rhs_out_channel = rhs_shape - - is_transpose = any(d != 1 for d in lhs_dilation) - is_atrous = any(d != 1 for d in rhs_dilation) - is_depthwise = in_channels == feature_group_count and feature_group_count > 1 - _validate_conv_features(is_transpose, is_atrous, is_depthwise, - feature_group_count, batch_group_count, - preferred_element_type, lhs.dtype.as_numpy_dtype) - - rhs_dilated_shape = [ - (k - 1) * r + 1 for k, r in zip(rhs_spatial_shapes, rhs_dilation) - ] - output_perm = dimension_numbers[2] - - if is_transpose: - padding_type = _conv_transpose_pads_to_padtype( - rhs_spatial_shapes, lhs_dilation, padding) - else: - padding_type = pads_to_padtype( - lhs_shape[1:3], rhs_dilated_shape, window_strides, padding) - # We only manually pad if we aren't using a transposed convolutions. - if padding_type == "EXPLICIT": - lhs, lhs_shape, padding = _check_pad_spatial_dims(lhs, lhs_shape, padding) - padding_type = padding - - if padding_type != "SAME" and any(l < r for l, r in zip(lhs_shape[1:3], rhs_dilated_shape)): - # If the input shape is smaller than the filter shape in a spatial dimension, - # lax returns only zeros while tf.conv2d returns an error. - # We thus return zeros to make sure the behavior is consistent. - return tf.broadcast_to(tf.constant(0, dtype=tf.float32), - jax2tf._eval_shape(out_shape)) - - if is_depthwise: - # Reshape filter from - # [filter_height, filter_width, 1, in_channels * channel_multiplier] to - # [filter_height, filter_width, in_channels, channel_multiplier]. - new_rhs_shape = tuple(rhs_spatial_shapes) + (in_channels, - rhs_out_channel // in_channels) - output = tf.nn.depthwise_conv2d( - input=lhs, - filter=tf.reshape(rhs, jax2tf._eval_shape(new_rhs_shape)), - strides=tf_window_strides, - padding=padding_type, - dilations=rhs_dilation) - - elif is_transpose: - # tf.nn.conv2d_transpose requires a transposed filter. - rhs_t = tf.reverse(rhs, [0, 1]) - rhs_t = tf.transpose(rhs_t, (0, 1, 3, 2)) - - # We should transpose `out_shape` to "NHWC", which is what TF expects. - # First transpose to "NCHW". - if is_conv1d: - tf_out_shape = tuple(out_shape[i] for i in output_perm) + (1,) - else: - tf_out_shape = tuple(out_shape[i] for i in output_perm) - # Then transpose "NCHW" to "NHWC". - tf_out_shape = tuple(tf_out_shape[i] for i in (0, 2, 3, 1)) - output = tf.nn.conv2d_transpose( - input=lhs, - filters=rhs_t, - output_shape=jax2tf._eval_shape(tf_out_shape), - strides=lhs_dilation, - padding=padding_type) - - else: - output = tf.nn.conv2d( - input=lhs, - filters=rhs, - strides=tf_window_strides, - padding=padding_type, - dilations=rhs_dilation) - - # TF outputs in format "NHWC", so convert to "NCHW", which is lax's default - # format. - output = tf.transpose(output, (0, 3, 1, 2)) # "NHWC" --> "NCHW" - if is_conv1d: - output = output[:, :, :, 0] - # To determine the right permutation, we compute the inverse permutation of - # `output_perm`, so that when `output_perm` is applied to `output`, we obtain - # the outpt in NCHW format. - inverse_perm = _invert_permutation(output_perm) - output = tf.transpose(output, inverse_perm) # "NCHW" -> desired output shape. - return output - - -tf_impl_no_xla[lax.conv_general_dilated_p] = _conv_general_dilated - - -def _dot_general(lhs, rhs, *, dimension_numbers, - precision: tuple[PrecisionType, PrecisionType] | None, - preferred_element_type: DType | None, - out_sharding=None, - _in_avals: Sequence[core.ShapedArray], - _out_aval: core.ShapedArray): - """Implementation of lax.dot_general_p in terms of tf.linalg.einsum.""" - # Unused arguments. - del precision - del preferred_element_type - - lhs, rhs, convert_result = jax2tf._dot_general_convert_to_common_dtype( - lhs, _in_avals[0], rhs, _in_avals[1], _out_aval) - - (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers - lhs_ndim, rhs_ndim = len(lhs.shape), len(rhs.shape) - - # This condition ensures that: - # 1) the batch dimensions are ordered in the same way in lhs and rhs (this is - # not strictly necessary, but we would have to reshape the array if that - # were not the case; - # 2) lhs and rhs have the same number of dimensions +/- 1 - # 3) the number of non-batch dimensions in both tensors is either 1 or 2 - # 4) the contracting dimensions are consistent with those of a classic - # matrix/matrix, vector/matrix or matrix/vector multiplication. - if (lhs_batch == rhs_batch == tuple(range(len(lhs_batch))) and - lhs_ndim - rhs_ndim in [-1, 0, 1] and - 1 <= lhs_ndim - len(lhs_batch) <= 2 and - 1 <= rhs_ndim - len(rhs_batch) <= 2 and - lhs_contracting == (len(lhs.shape) - 1,) and - rhs_contracting == (len(lhs_batch),)): - # All the inputs to tf.linalg.matmul must have 2 inner dimensions, - # after their batch dimensions, so we need to expand the dimensions - # appropriately. We can get to this branch with three combinations of - # inner shapes: - # - lhs.inner_shape == [a, b], rhs.inner_shape == [b, c] - # - in this case, the resulting inner shape is [a, c]; - # - lhs.inner_shape == [b] , rhs.inner_shape == [b, c] - # - in this case, we need to expand lhs to [1, b], and the resulting - # shape is [c]. We need to squeeze the result of tf.linalg.matmul - # as it will have shape [1, c]; - # - lhs.shape == [batch] + [a, b], rhs.shape == [batch] + [b] - # - in this case, we need to expand rhs to [b, 1], and the resulting - # shape is [a]. We need to squeeze the result of tf.linalg.matmul - # as it will have shape [a, 1]; - # - lhs.shape == [batch] + [b] , rhs.shape == [batch] + [b] - # - in this case, we need to expand lhs to [1, b] and rhs to [b, 1], - # and the resulting shape is (). We need to squeeze the result of - # tf.linalg.matmul as it will have shape [1, 1]. - squeeze_idxs = [] - if lhs_ndim - len(lhs_batch) == 1: - lhs = tf.expand_dims(lhs, lhs_ndim - 1) - squeeze_idxs.append(len(lhs.shape) - 2) - if rhs_ndim - len(rhs_batch) == 1: - rhs = tf.expand_dims(rhs, rhs_ndim) - squeeze_idxs.append(len(rhs.shape) - 1) - result = tf.linalg.matmul(lhs, rhs) - if len(squeeze_idxs) != 0: - assert all(result.shape[i] == 1 for i in squeeze_idxs) - result = tf.squeeze(result, squeeze_idxs) - return convert_result(result) - - new_id = iter(string.ascii_letters) - lhs_axis_ids = [next(new_id) for _ in lhs.shape] - rhs_axis_ids = [next(new_id) for _ in rhs.shape] - lhs_out_axis_ids = lhs_axis_ids[:] - rhs_out_axis_ids = rhs_axis_ids[:] - - for lhs_axis, rhs_axis in zip(lhs_contracting, rhs_contracting): - shared_id = next(new_id) - lhs_axis_ids[lhs_axis] = shared_id - rhs_axis_ids[rhs_axis] = shared_id - lhs_out_axis_ids[lhs_axis] = None # type: ignore[call-overload] - rhs_out_axis_ids[rhs_axis] = None # type: ignore[call-overload] - - batch_ids = [] - for lhs_axis, rhs_axis in zip(lhs_batch, rhs_batch): - shared_id = next(new_id) - lhs_axis_ids[lhs_axis] = shared_id - rhs_axis_ids[rhs_axis] = shared_id - lhs_out_axis_ids[lhs_axis] = None # type: ignore[call-overload] - rhs_out_axis_ids[rhs_axis] = None # type: ignore[call-overload] - batch_ids.append(shared_id) - - not_none = lambda x: x is not None - out_axis_ids = list( - filter(not_none, batch_ids + lhs_out_axis_ids + rhs_out_axis_ids)) - assert lhs.dtype == rhs.dtype - spec = "{},{}->{}".format("".join(lhs_axis_ids), "".join(rhs_axis_ids), - "".join(out_axis_ids)) - return convert_result(tf.linalg.einsum(spec, lhs, rhs)) - - -tf_impl_no_xla[lax.dot_general_p] = _dot_general - - -def _interior_padding(operand, padding_value, padding_config, operand_shape): - # Used only when enable_xla=False - # Applies only the interior padding from the padding_config. - # We do this somewhat inefficiently, as a scatter. - # For each dimension we compute the indices_by_dim as [0, f, 2f, 3f, ...] where - # f is the dilation factor for the dimension, i.e., 1 + interior_padding. - # Then we compute the cartesian production of the indices (using broadcast - # and concat). - - # We could make this code more complex and do all the padding at once, but - # we prefer to keep it simple. - indices_by_dim = [] - indices_shape = operand_shape + (1,) - output_shape = [] # considering only interior padding - for d, (dsz, (_, _, i)) in enumerate(zip(operand_shape, padding_config)): - dilation_factor = i + 1 - output_shape.append(dsz * dilation_factor - i) - indices = tf.range(dsz) * dilation_factor - expansion = [None] * (1 + len(operand_shape)) - expansion[d] = slice(None, None, None) - indices_by_dim.append(tf.broadcast_to(indices[expansion], indices_shape)) - - indices_cartesian = tf.concat(indices_by_dim, axis=len(operand_shape)) - scattered = tf.scatter_nd(indices_cartesian, operand, output_shape) - # What elements from the output array we use from - mask = tf.scatter_nd(indices_cartesian, tf.ones_like(operand, dtype=np.bool_), - output_shape) - return tf.where(mask, scattered, padding_value) - - -def _pad(operand, padding_value, *, padding_config, - _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray): - # Do only the interior padding first. This is rarely needed. - if any(i != 0 for _, _, i in padding_config): - operand = _interior_padding(operand, padding_value, padding_config, - jax2tf._eval_shape(_in_avals[0].shape)) - - # Now do the non-negative edge padding. This is the common case, use tf.pad. - non_negative_padding = [((lo if lo >= 0 else 0), (hi if hi >= 0 else 0)) - for lo, hi, _ in padding_config] - operand = tf.pad( - operand, - non_negative_padding, - mode="CONSTANT", - constant_values=padding_value) - # Now the negative edge padding (this is also rare) - if any(lo < 0 or hi < 0 for lo, hi, _ in padding_config): - output_shape = jax2tf._eval_shape(_out_aval.shape) - begins = [(-lo if lo < 0 else 0) for lo, _, _ in padding_config] - operand = tf.slice(operand, begins, output_shape) - - return operand - - -tf_impl_no_xla[lax.pad_p] = _pad - - -def _argminmax(is_min: bool, operand: TfVal, axes: Sequence[int], - index_dtype: DType, _in_avals: Sequence[core.ShapedArray], - _out_aval: core.ShapedArray): - # The following is known to diverge from JAX behavior for NaN. - axis, = axes - output_type = tf.int32 - if dtypes.iinfo(index_dtype).bits > 32: - output_type = tf.int64 - # TODO(phawkins): handle axes larger than 2^31. - fn = tf.math.argmin if is_min else tf.math.argmax - result = fn(operand, axis=axis, output_type=output_type) - return tf.cast(result, jax2tf._to_tf_dtype(index_dtype)) - - -tf_impl_no_xla[lax.argmin_p] = partial(_argminmax, True) -tf_impl_no_xla[lax.argmax_p] = partial(_argminmax, False) - - -def _validate_reduce_window_inputs(operand_shape, computation_name, dtype, - window_dimensions, window_strides, - base_dilation, window_dilation): - if computation_name not in ["min", "max", "add"]: - raise _reduce_error("Reduction function should be either min, max, or add.") - if computation_name in ["min", "max"] and dtype in [ - tf.bool, tf.uint32, tf.uint64, tf.complex64, tf.complex128 - ]: - raise _reduce_error("Min/max pool does not support operands of type " - f"{dtype}") - if computation_name == "min" and dtype in [tf.uint8, tf.uint16]: - # TODO(marcvanzee): We currently implement min pooling by negating the - # input, but this doesn't work for uint. We could work around it using - # tf.math.reduce_min. - raise _reduce_error(f"Min pool does not support operands of type {dtype}") - if computation_name == "add" and dtype not in [ - tf.bfloat16, - tf.float16, - tf.float32, - tf.float64, - tf.int16, - tf.int32, - ]: - raise _reduce_error("Add pooling does not support operands of type " - f"{dtype}") - - if (len(operand_shape) != len(window_dimensions) != len(window_strides) != - len(window_dilation)): - raise _reduce_error("Input shapes, window dimensions, window stride " - "dimensions, and window dilation dimensions should " - "match.") - - has_only_spatial_dims = True - if len(operand_shape) > 4: - raise _reduce_error("Only 1D or 2D input are supported.") - if len(operand_shape) > 2: - # operand_shape = (batch, spatial_dims, ..., channel). - has_only_spatial_dims = False - - for name, value in [("window_dimensions", window_dimensions), - ("window_strides", window_strides), - ("window_dilation", window_dilation)]: - if value[0] != value[-1] != 1: - raise _reduce_error("Only 1D or 2D input are supported, expected " - f"{name}=(1, spatial_dims, ..., 1), but got " - f"{value}") - - if list(base_dilation) != [1] * len(operand_shape): - # TODO(marcvanzee): Add support for base dilations. We can do this using - # a scatter on operand. - raise _reduce_error("Unimplemented support for base dilation.") - - return has_only_spatial_dims - - -def _padding_reduce_window(operand, operand_shape, computation_name, - window_dimensions, window_strides, padding): - padding_type = pads_to_padtype(operand_shape, window_dimensions, - window_strides, padding) - - # https://github.com/jax-ml/jax/issues/11874. - needs_manual_padding = ( - padding_type == "SAME" and computation_name == "add" and - window_dimensions != [1] * len(operand_shape)) - - if needs_manual_padding or padding_type == "EXPLICIT": - operand, operand_shape = _pad_spatial_dims(operand, operand_shape, padding) - padding_type = "VALID" - - return operand, operand_shape, padding_type - - -def _reshape_reduce_window(operand, operand_shape, window_dimensions, - window_strides, window_dilation, *, - has_only_spatial_dims): - # Reshape inputs so they are accepted by tf.nn.pool, which expects batch and - # channel dimensions for operand but not for any of the other inputs. - if has_only_spatial_dims: # len(operand_shape) <= 2 - # Call eval_shape on a shape that may contain polynomials, otherwise TF does - # not know what to do with polynomials in the shape. - operand_shape = jax2tf._eval_shape(operand_shape) - # Add batch and channel dimensions to operand. - operand = tf.reshape(operand, (1,) + operand_shape + (1,)) - else: - # This branch assumes operand.shape = (batch, spatial_dims, ..., channel), - # and dimensions, strides, dilation are all (1, spatial_values, ..., 1). - # Input validation for this is done in _validate_reduce_window_inputs. - window_dimensions = window_dimensions[1:-1] - window_strides = window_strides[1:-1] - window_dilation = window_dilation[1:-1] - - return operand, window_dimensions, window_strides, window_dilation - - -def _reduce_monoid(operand, window_dimensions, window_strides, padding, - base_dilation, window_dilation, computation_name, - _in_avals: Sequence[core.ShapedArray], - _out_aval: core.ShapedArray): - dtype = operand.dtype - # In presence of shape polymorphism, operand.shape may contain None. The - # actual dimension polynomial shapes are in _in_avals. - operand_shape = _in_avals[0].shape - - # TODO(marcvanzee): Put reduce_window arguments into dataclass, similar to - # Gather, to simplify function calls. - has_only_spatial_dims = _validate_reduce_window_inputs( - operand_shape, computation_name, dtype, window_dimensions, window_strides, - base_dilation, window_dilation) - - operand, operand_shape, padding_type = _padding_reduce_window( - operand, operand_shape, computation_name, window_dimensions, - window_strides, padding) - - operand, window_dimensions, window_strides, dilations = _reshape_reduce_window( - operand, - operand_shape, - window_dimensions, - window_strides, - window_dilation, - has_only_spatial_dims=has_only_spatial_dims) - - def tf_pool(inputs, pooling_type): - if any(not core.is_constant_shape(s) for s in - (window_dimensions, window_strides, dilations)): - raise NotImplementedError( - f"TODO: use tf.nn.pool with dynamic shapes¨{window_dimensions=} " - f" {window_strides=} {dilations=}") - # tf.nn.pool() currently does not suport tf.int32 and so we cast back and - # forth in order to be able to convert. - if (inputs.dtype in [tf.int16, tf.int32]) and computation_name == "add": - original_dtype = inputs.dtype - inputs = tf.cast(inputs, dtype=tf.float32) - else: - original_dtype = None - result = tf.nn.pool( - inputs, - window_shape=window_dimensions, - pooling_type=pooling_type, - padding=padding_type, - strides=window_strides, - dilations=dilations) - if original_dtype: - result = tf.cast(result, dtype=original_dtype) - - if has_only_spatial_dims: - # If the input only had spatial dimensions we need to contract the batch - # and channel dimensions before returning the output. - result = tf.squeeze(result, [0, -1]) - - jax2tf._assert_matching_abstract_shape(result, _out_aval.shape) - return result - - negate = lambda x: tf.multiply(x, tf.constant(-1, dtype)) - if computation_name == "max": - return tf_pool(operand, "MAX") - elif computation_name == "min": - return negate(tf_pool(negate(operand), "MAX")) - elif computation_name == "add": - # TODO(marcvanzee): This may give very large deviations on TPU when using - # floats as inputs. Alternatively, we could implement this using a - # convolution with an all-1's kernel. - return tf.multiply(tf_pool(operand, "AVG"), math.prod(window_dimensions)) - - -def _reduce_window(*args, jaxpr, consts, window_dimensions, - window_strides, padding, base_dilation, window_dilation, - _in_avals: Sequence[core.ShapedArray], - _out_aval: tuple[core.ShapedArray, ...] - ) -> tuple[TfVal, ...]: - assert len(consts) == 0, "Reduction computation cannot have constants" - operands, init_values = util.split_list(args, [len(args) // 2]) - - if len(operands) != 1 or len(init_values) != 1: - raise _reduce_error("jax2tf does not support variadic reduce_window") - - operand, init_value = operands[0], init_values[0] - # Infer operation type from jaxpr. - if (len(jaxpr.eqns) != 1 or - len(jaxpr.eqns[0].invars) != 2 or - len(jaxpr.eqns[0].outvars) != 1 or - jaxpr.eqns[0].primitive.name not in ["min", "max", "add"]): - raise _reduce_error("Reduction function should be either min, max, or add.") - - computation_name = jaxpr.eqns[0].primitive.name - result = _reduce_monoid(operand, - window_dimensions=window_dimensions, - window_strides=window_strides, - padding=padding, - base_dilation=base_dilation, - window_dilation=window_dilation, - computation_name=computation_name, - _in_avals=(_in_avals[0],), # Don't pass init_value. - _out_aval=_out_aval[0]) # Returns single value. - - reduce_fn = { - "min": tf.minimum, - "max": tf.maximum, - "add": tf.add, - }[computation_name] - result = reduce_fn(result, init_value) - - # The output is expected to be wrapped in a tuple, and since we don't use - # variadic reductions, this tuple always contains a single element. - return (result,) - - -tf_impl_no_xla[lax.reduce_window_min_p] = ( - partial(_reduce_monoid, computation_name="min")) -tf_impl_no_xla[lax.reduce_window_max_p] = ( - partial(_reduce_monoid, computation_name="max")) -tf_impl_no_xla[lax.reduce_window_sum_p] = ( - partial(_reduce_monoid, computation_name="add")) - -tf_impl_no_xla[lax.reduce_window_p] = _reduce_window - -tf_impl_no_xla[lax.reduce_p] = _unimplemented("reduce") - -tf_impl_no_xla[lax.select_and_scatter_add_p] = _unimplemented( - "select_and_scatter_add") - -tf_impl_no_xla[lax.rng_bit_generator_p] = _unimplemented("rng_bit_generator") - - -def _clip(max_indices: Sequence[TfVal], start_indices: Sequence[TfVal], - slice_sizes: Sequence[TfVal]): - """Simulates XLA clipping behavior with TF ops. - - Various TF ops have different clipping behavior than XLA: - * If `start_indices` is out-of-bounds, then TF fails but XLA clips the indices - to - [0, max_len]. - * If `start_indices + slice_size` is out-of-bounds, then TF fails, but XLA - adjust - `start_indices` so that a full slice is returned. - This function clips the start indices correctly. - """ - # We cast both arguments to `tf.clip_by_value` to int32. Otherwise, this - # function may return uint32 which is not always compatible with TF ops, so - # this may result in type errors. - max_start = tf.cast(tf.subtract(max_indices, slice_sizes), dtype=tf.int32) - return tf.clip_by_value(tf.cast(start_indices, dtype=tf.int32), 0, max_start) - - -@dataclasses.dataclass -class GatherArgs: - operand: TfVal - start_indices: TfVal - dnums: lax.GatherDimensionNumbers - slice_sizes: TfVal - op_shape: core.Shape - start_indices_shape: core.Shape - out_aval: core.ShapedArray - - def __post_init__(self): - assert len(self.op_shape) == len(self.slice_sizes) - - def __repr__(self): - return (f"operand shape={self.op_shape}, " - f"start_indices={self.start_indices}, " - f"dimension_numbes={self.dnums}, " - f"slice_sizes={self.slice_sizes}") - @property - def batch_dims(self): - return tuple(x for x in range(len(self.out_aval.shape)) - if x not in self.dnums.offset_dims) - -def gather_precondition(precondition_fn: Callable[[GatherArgs], None]): - """Decorator for specifying a precondition function. - - This decorator should be put on a function with argument `arg` of type - `GatherArgs`. It will first call `precondition_fn` with `arg` (which may throw - an exception), and then call the function it is decorating with `arg` as well. - """ - - def decorator(gather_fn: Callable[[GatherArgs], Any]): - - @wraps(gather_fn) - def wrapper(args: GatherArgs): - # Call `precondition_fn`; we assume it may throw an exception. - precondition_fn(args) - return gather_fn(args) - - return wrapper - - return decorator - - -def _pre_gather_for_scalar_indexing(args: GatherArgs): - """Returns True if this call to gather represents scalar indexing into arrays. - - E.g., op[2], op[:, :5, :], jnp.take(op, 0, axis=0). - """ - # TODO(marcvanzee): Add more assumptions here, because this is currently too - # permissive. - if len(args.start_indices_shape) != 1: - raise ValueError("start_indices shape should be 1") - - -@gather_precondition(_pre_gather_for_scalar_indexing) -def _gather_for_scalar_indexing(args: GatherArgs): - """Implements 'scalar indexing into arrays' cases of lax.gather using tf.slice. - - E.g., op[2], op[:, :5, :], jnp.take(op, 0, axis=0). - """ - indices = tf.expand_dims(args.dnums.start_index_map, 1) - # lax.gather uses an "index map" which maps `start_indices` to the right axes - # in `operand`. Since tf.strided_slice uses a single array for specifying the - # start indices, we use a scatter to map the start indices to the right axes. - op_shape = jax2tf._eval_shape(args.op_shape) - slice_sizes_tf = jax2tf._eval_shape(args.slice_sizes) - # TODO(marcvanzee): Consider transposing `operand`, which is probably more - # optimization friendly. - begin = tf.scatter_nd(indices, args.start_indices, [len(op_shape)]) - begin = _clip(op_shape, begin, slice_sizes_tf) - end = slice_sizes_tf + begin - - # `collapsed_slice_dims` is a tuple of dimensions to collapse, e.g. (0, 2). - # `tf.strided_slice` expects a binary mask to specify the shrink axes, i.e., - # if we want to shrink axis 0 and 2, this corresponds to binary mask 101, - # which is 5 in decimals. The following line converts the lax representation - # to the one used by `tf.strided_slice`. - shrink_mask = sum(2**x for x in args.dnums.collapsed_slice_dims) - res = tf.strided_slice(args.operand, begin, end, shrink_axis_mask=shrink_mask) - # Shape inference doesn't work for tf.strided_slice. - res = jax2tf._ensure_tf_shape_if_dynamic( - res, jax2tf._aval_to_tf_shape(args.out_aval) - ) - return res - - -def _pre_gather_for_multidim_indexing(args: GatherArgs): - """Returns True if this call to gather represents multi-dimensional indexing. - - E.g., jnp.take(op, [[0], [1]], axis=0). - Note we currently only support multi-dimensional indexing if the last - dimension is 1. - """ - # Handle only the case when tf.gather argument batch_dims=0. - # Find axis to match the tf.gather semantics - # Let I = len(start_indices_shape) - # let O = len(op_shape) - # slice_sizes == op_shape[:axis] + (1,) + op_shape[axis+1:] - # collapsed_slice_dims == (axis,) - # start_index_map == (axis,) - # offset_dims == (0, 1, ..., axis - 1, axis + I, ..., O + I - 1) - # We added a trailing dimension of size 1 - op_shape = args.op_shape - start_index_map = args.dnums.start_index_map - collapsed_slice_dims = args.dnums.collapsed_slice_dims - offset_dims = args.dnums.offset_dims - if not (len(op_shape) >= 1 and len(start_index_map) == 1 and - len(collapsed_slice_dims) == 1 and collapsed_slice_dims[0] - == start_index_map[0] and len(offset_dims) == len(op_shape) - 1): - raise ValueError("unsupported dimension numbers") - # We added a trailing dimension of size 1 - if not core.definitely_equal(args.start_indices_shape[-1], 1): - raise ValueError("start_indices shape[-1] should be 1") - # Guess the axis - axis = collapsed_slice_dims[0] - index_dims = len(args.start_indices_shape) - 1 - expected_offset_dims = tuple( - list(range(axis)) + - list(range(axis + index_dims, - len(op_shape) + index_dims - 1))) - if offset_dims != expected_offset_dims: - raise ValueError("unsupported offset_dims") - expected_slice_sizes = op_shape[:axis] + (1,) + op_shape[axis + 1:] # type: ignore - if not core.definitely_equal_shape(args.slice_sizes, expected_slice_sizes): - raise ValueError("unsupported slice_sizes") - - -@gather_precondition(_pre_gather_for_multidim_indexing) -def _gather_for_multidim_indexing(args: GatherArgs): - """Implements 'multi-dimensional indexing into arrays' cases of lax.gather using tf.gather. - - E.g., jnp.take(op, [[0], [1]], axis=0). - """ - # Guess the axis. - axis = args.dnums.collapsed_slice_dims[0] - squeezed_indices = tf.squeeze(args.start_indices, -1) - op_shape = jax2tf._eval_shape(args.op_shape) - start_indices = _clip((op_shape[axis],), squeezed_indices, (1,)) - return tf.gather(args.operand, start_indices, axis=axis, batch_dims=0) - - -def _pre_gather_with_batch_dim(args: GatherArgs): - """Returns True if this call to gather has non-empty batch dimensions. - - This is for instance triggered when doing jax.vmap(lax.dynamic_slice). - """ - # We assume exactly one batch (and one or more non-batch dimensions). - if len(args.batch_dims) != 1: - raise ValueError(f"batch_dims is {len(args.batch_dims)} but should be 1") - - # `start_index_map` maps indices in `start_indices` to indices in `operand`. - # For simplicity, we currently only consider the case where this mapping is - # the identity function, i.e., [2, 3] in `start_indices` maps to - # `operand[2, 3]`. - if args.dnums.start_index_map != tuple(range(args.start_indices_shape[-1])): - raise ValueError("unsupported start_index_map") - - # The batch dims in `start_indices` and `operand` should agree. - if not core.definitely_equal(args.op_shape[0], args.start_indices_shape[0]): - raise ValueError("Batch dimensions in operand and start_indices don't " - "agree") - - -def _pre_gather_with_batch_dims(args: GatherArgs): - """Returns True if this call to gather has non-empty 2D batch dimensions. - - This is for instance triggered when doing - jax.vmap(jax.vmap(lax.dynamic_slice)). - """ - if len(args.dnums.collapsed_slice_dims) != 0: - # NOTE: this can be relaxed in _gather_with_batch_dims but we might - # also need to re-work the output reshaping - raise ValueError("only len(collapsed_slice_dims) == 0 is supported") - - # NOTE: This supports higher dimensions than listed (the highest dimension - # in the tests is 3D so it is limited to that, but the implementation is - # designed to handle higher dimensions (N-Dimensional)). - if len(args.batch_dims) not in [1, 2, 3]: - raise ValueError( - f"Size of batch_dims is {len(args.batch_dims)} but should be up to 3" - ) - -@gather_precondition(_pre_gather_with_batch_dim) -def _gather_with_batch_dim(args: GatherArgs): - """Implements call to gather with non-empty batch dimensions. - - E.g., when doing `jax.vmap(lax.dynamic_slice). - """ - op_shape = jax2tf._eval_shape(args.op_shape) - start_indices = _clip(op_shape, args.start_indices, args.slice_sizes) - result = tf.map_fn( - lambda idxs: tf.slice(args.operand, begin=idxs, size=args.slice_sizes), - start_indices, - fn_output_signature=jax2tf._to_tf_dtype(args.operand.dtype) - ) - result = tf.reshape(result, jax2tf._eval_shape(args.out_aval.shape)) - return result - - -def _gather_generate_indices(shape: tuple[int, ...]): - """ - Returns the indices of the according to `shape`: - each element in the output is the index of an element of an array - of the provided shape. The result's shape is (math.prod(shape), len(shape)) - - For example, given shape (2,2) it returns (0,0),(0,1),(1,0),(1,1) - """ - return tf.reshape( - tf.stack( - tf.meshgrid( - *[tf.range(start=0, limit=x) for x in shape], indexing="ij" - ), - axis=-1, - ), - (-1, len(shape)), - ) - - -@gather_precondition(_pre_gather_with_batch_dims) -def _gather_with_batch_dims(args: GatherArgs): - """Implements call to gather with non-empty 2D batch dimensions.""" - op_shape = jax2tf._eval_shape(args.op_shape) - output_shape = jax2tf._eval_shape(args.out_aval.shape) - # Used to map the start_indices w.r.t start_index_map - indices = tf.expand_dims(args.dnums.start_index_map, 1) - - # batch_indices is shaped (N,d) where N is the number of slices and d is - # the number of batch_dims; batch_indices_size equals to N - batch_indices = _gather_generate_indices( - tuple(output_shape[i] for i in args.batch_dims) - ) - batch_indices_size = jax2tf._eval_shape(batch_indices.shape)[0] - # offset_indices is shaped (K,d) where K is the number of elements in each - # slice and d is the number of offset_dims; offset_indices_size equals to K - offset_indices = _gather_generate_indices( - tuple(output_shape[i] for i in args.dnums.offset_dims) - ) - offset_indices_size = jax2tf._eval_shape(offset_indices.shape)[0] - - # After we compute the result we need to reshape the axes with respect to - # the output batch_dims and offset_dims. - dim_mask = args.batch_dims + args.dnums.offset_dims - mask_output_shape = tuple(output_shape[x] for x in dim_mask) - - def get_scatter_indices(indices, batch_indices_size, size_of_index_map): - """Generate the start indices of each slice, which index into the operand.""" - # Tile indices batch_indices_size times - tiled_indices = tf.tile( - tf.expand_dims(indices, 0), [batch_indices_size, 1, 1] - ) - # The above tiles need to index the proper element of batch_indices - # To do this generate a repeated sequence of numbers - temp_batch_indices = tf.repeat( - tf.range(start=0, limit=batch_indices_size), size_of_index_map - ) - # Reshape the above sequence so it follows the same shape of tiled_indices - temp_batch_indices = tf.reshape( - temp_batch_indices, (batch_indices_size, size_of_index_map, 1) - ) - # Now we concatenate to create indices offset by the temp_batch_indices - return tf.concat([temp_batch_indices, tiled_indices], axis=-1) - - slice_start_indices = tf.gather_nd(args.start_indices, batch_indices) - # TODO: In the case where start_index_map is the identity we can skip this. - scatter_indices = get_scatter_indices( - indices, batch_indices_size, len(args.dnums.start_index_map) - ) - # We map the scatter_indices w.r.t start_index_map - indices_in_operand = tf.scatter_nd( - scatter_indices, slice_start_indices, [batch_indices_size, len(op_shape)] - ) - - # We clip the indices as OOB cases are possible when offsetting past - # the operand boundaries - clipped_start_indices = _clip(op_shape, indices_in_operand, args.slice_sizes) - # Here we need to broadcast clipped_start_indices and add each of the offsets - # which will generate a large index tensor of shape (T,d) where T is the - # number of slices times the size of each slice (i.e total number of items - # across all sices); d is rank(operand) - slice_element_indices = tf.add( - tf.repeat(clipped_start_indices, offset_indices_size, axis=0), - tf.tile(offset_indices, (batch_indices_size, 1)), - ) - results = tf.gather_nd(args.operand, slice_element_indices) - - # Here results comes shaped as (N,1). Because collapsed_slice_dims is 0, - # offset_dims is effectviely slice_sizes. - # We reshape to mask_output_shape because if we directly reshape to the - # output shape and our batch_dims are non-contiguous we will produce the - # wrong shape. Reshaping to mask_output_shape gives (...,*slice_sizes), - # which we then transpose to permute the axes in the proper way. - # Note that if the batch_dims are contiguous this won't change the output. - temp = tf.reshape(results, shape=mask_output_shape) - return tf.transpose(temp, perm=tf.math.invert_permutation(dim_mask)) - -def _gather(operand, start_indices, *, dimension_numbers, - slice_sizes: core.Shape, indices_are_sorted, unique_indices, mode, - fill_value, _in_avals: Sequence[core.ShapedArray], - _out_aval: core.ShapedArray): - """Tensorflow implementation of gather.""" - if mode == lax.GatherScatterMode.FILL_OR_DROP: - gather_fill_fn = jax2tf._convert_jax_impl(lax_slicing._gather_fill, - multiple_results=False) - return gather_fill_fn( - operand, start_indices, dimension_numbers=dimension_numbers, - slice_sizes=slice_sizes, unique_indices=unique_indices, - indices_are_sorted=indices_are_sorted, fill_value=fill_value, - output_shape=_out_aval.shape, _in_avals=_in_avals, _out_aval=_out_aval) - - # TODO(marcvanzee): Check if we need more tests in shape_poly for gather with - # enable_xla=False. - gather_args = GatherArgs( - operand=operand, - start_indices=start_indices, - dnums=dimension_numbers, - slice_sizes=slice_sizes, - op_shape=_in_avals[0].shape, - start_indices_shape=_in_avals[1].shape, - out_aval=_out_aval) - - errors = [] - - for gather_fn in [ - _gather_for_scalar_indexing, - _gather_for_multidim_indexing, - _gather_with_batch_dim, - _gather_with_batch_dims, - ]: - try: - return gather_fn(gather_args) - except ValueError as e: - errors.append(f"{gather_fn}: {e!r}") - - error_msg = (f"Unsupported arguments for gather: {gather_args}, errors:\n" + - "\n".join(errors)) - - raise _error("gather", error_msg) - - -tf_impl_no_xla[lax.gather_p] = _gather - - -def _dynamic_slice(operand, *start_indices, slice_sizes: core.Shape, - _in_avals: Sequence[core.ShapedArray], - _out_aval: core.ShapedArray): - start_indices = tf.stack(start_indices) - slice_sizes_tf = jax2tf._eval_shape(slice_sizes) - - operand_shape = jax2tf._eval_shape(_in_avals[0].shape) - start_indices = _clip(operand_shape, start_indices, slice_sizes_tf) - return tf.slice(operand, start_indices, size=slice_sizes_tf) - - -tf_impl_no_xla[lax.dynamic_slice_p] = _dynamic_slice - - -def _dynamic_update_slice(operand, update, *start_indices, - _in_avals: Sequence[core.ShapedArray], - _out_aval: core.ShapedArray): - start_indices = tf.stack(start_indices) - - op_shape = jax2tf._eval_shape(_in_avals[0].shape) - op_size = tf.size(operand) - update_shape_tf = jax2tf._eval_shape(_in_avals[1].shape) - - start_indices = _clip(op_shape, start_indices, update_shape_tf) - end_indices = tf.add(start_indices, update_shape_tf) - - # Get the cells to update in `operand` as an array of ids. - id_tensor = tf.reshape(tf.range(op_size), op_shape) - scattered_indices = tf.strided_slice(id_tensor, start_indices, end_indices) - - # Create an array containing updates at scattered_indices and zeros otherwise. - flat_indices = tf.expand_dims(tf.nest.flatten(scattered_indices), -1) - flat_update = tf.nest.flatten(update) - update = tf.scatter_nd(flat_indices, flat_update, (op_size,)) - update = tf.reshape(update, op_shape) - - # Create a bool mask that is True only where `operand` should be updated. - update_mask = tf.ones_like(flat_update, dtype=tf.bool) - update_mask = tf.scatter_nd(flat_indices, update_mask, (op_size,)) - update_mask = tf.reshape(update_mask, op_shape) - - # Use the mask to only update `operand` with `update`. - return tf.where(update_mask, update, operand) - - -tf_impl_no_xla[lax.dynamic_update_slice_p] = _dynamic_update_slice - - -def shift_axes_forward(operand, - axes: tuple[int, ...], - inverse: bool = False, - forward: bool = True): - """Shifts the tuple of axes to the front of an array""" - other_axes = tuple(i for i in range(len(operand.shape)) if i not in axes) - fwd_order = axes + other_axes if forward else other_axes + axes - order = fwd_order if not inverse else _invert_permutation(fwd_order) - return tf.transpose(operand, order) - -def convert_scatter_jax_to_tf(update_op, unsorted_segment_op=None): - - def _sparse_scatter(operand, scatter_indices, updates, unique_indices, mode, - _in_avals: Sequence[core.ShapedArray], - _out_aval: core.ShapedArray): - """Implementation of scatter specialised to indexing from the front axes. - - This covers unique indices and non-unique indices of single depth. - Note on unique indices: `tf.tensor_scatter_nd_update` interprets indices - thusly: every axis except the final one encodes a batch dimension, the final - axis encoding the actual indices to scatter in to. It enforces, at least - one, batch dimension so we add an empty dimension to indices and updates if - lacking. - - Note on non-unique indices: There is no tf op for non-single depth indexing, - but if indexing is single depth, this can be viewed as a segment op. - """ - # Infer unique indices from lack of batch dimension - unique_indices = unique_indices or (len(scatter_indices.shape) == 1) - if unique_indices: - suboperand = tf.gather_nd(operand, scatter_indices) - updated_suboperand = update_op(suboperand, updates) - # add a batch dim if none exist - if len(scatter_indices.shape) == 1: - scatter_indices = scatter_indices[None] - updated_suboperand = updated_suboperand[None] - y = tf.tensor_scatter_nd_update(operand, scatter_indices, updated_suboperand) - else: - if (scatter_indices.shape[-1] == 1) and unsorted_segment_op: - # If only indexing into the first dimension, it's a segment op - operand_update = unsorted_segment_op(updates, - tf.squeeze(scatter_indices, -1), - operand.shape[0]) - y = update_op(operand, operand_update) - else: - raise _scatter_error( - "Scatter only supports non-unique " - "indices with indexing into only one dimension for (add, mul, min, " - "max)") - return y - - def sparse_scatter(operand, scatter_indices, updates, update_jaxpr, - update_consts, dimension_numbers, indices_are_sorted: bool, - unique_indices: bool, mode, - _in_avals: Sequence[core.ShapedArray], - _out_aval: core.ShapedArray): - """ - Wrapper around the scatter function. - The underlying tf ops `tf.tensor_scatter_nd_update` and - `tf.math.unsorted_segment_*` index from the front dimensions. - `tf.math.unsorted_segment_*` indexes to a depth 1 from the front. - `tf.tensor_scatter_nd_update` indexes from the front dimensions onwards, - with no ability to skip a dimension. This function shifts the axes to be - indexed to the front then calls a front-specific implementation, then - inverse-shifts the output. - - scatter_dims_to_operand_dims: dimensions which the scatter indexes in to. - We shift these to the front to match tf syntax. All other dims are batch - update_window_dims: dimensions which are not batch dimensions. We shift - these to the back as the remaining dimensions are batch dimensions. - """ - del update_jaxpr, update_consts, indices_are_sorted # Unused arguments - - update_window_dims = dimension_numbers.update_window_dims - inserted_window_dims = dimension_numbers.inserted_window_dims - scatter_to_operand_dims = dimension_numbers.scatter_dims_to_operand_dims - - dtype = operand.dtype # assume updates has same dtype as operand - if dtype in [tf.bool, tf.complex64]: - raise _scatter_error(f"Scatter does not support operands of type {dtype}") - - if inserted_window_dims != scatter_to_operand_dims: - raise _scatter_error("Complex scatters are not supported") - - if (mode != lax.GatherScatterMode.FILL_OR_DROP and - mode != lax.GatherScatterMode.PROMISE_IN_BOUNDS): - # The OOB behavior for tf.scatter is as follows: - # - When running in eager or graph mode, it throws an error. - # TODO(marcvanzee): Fix this case by removing the OOB indices. - # - When running in compile mode, the OOB indices are dropped, which is - # the same behavior as FILL_OR_DROP and PROMISE_IN_BOUNDS. - # To ensure correctness, we disallow CLIP mode for now. - raise _scatter_error("Only scatter modes `FILL_OR_DROP` and " - "`PROMISE_IN_BOUNDS` are supported.") - - # Shift axes to the front to match tf syntax, inverse afterwards - fwd = partial(shift_axes_forward, axes=scatter_to_operand_dims) - inv = partial(fwd, inverse=True) - - # Shift update value axes to the back, so batch are at the front - updates_shifted = shift_axes_forward( - updates, axes=update_window_dims, forward=False) - - return inv( - _sparse_scatter( - fwd(operand), scatter_indices, updates_shifted, unique_indices, - mode, _in_avals, _out_aval)) - return sparse_scatter - - -tf_impl_no_xla[lax.scatter_p] = convert_scatter_jax_to_tf( - lambda x, y: y) # just replace with the update -tf_impl_no_xla[lax.scatter_add_p] = convert_scatter_jax_to_tf(tf.add, tf.math.unsorted_segment_sum) -tf_impl_no_xla[lax.scatter_mul_p] = convert_scatter_jax_to_tf(tf.multiply, tf.math.unsorted_segment_prod) -tf_impl_no_xla[lax.scatter_min_p] = convert_scatter_jax_to_tf(tf.minimum, tf.math.unsorted_segment_min) -tf_impl_no_xla[lax.scatter_max_p] = convert_scatter_jax_to_tf(tf.maximum, tf.math.unsorted_segment_max) - -tf_impl_no_xla[lax.sort_p] = _unimplemented("sort") - -tf_impl_no_xla[lax.reduce_precision_p] = _unimplemented("reduce_precision") diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 7f98ce433815..c5757532c4ba 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -15,13 +15,11 @@ from __future__ import annotations -from collections.abc import Callable, Iterable, Sequence +from collections.abc import Callable, Sequence from functools import partial import contextlib import math -import operator import os -import re import threading from typing import Any, Union import warnings @@ -30,44 +28,20 @@ import numpy as np import jax -from jax import lax -from jax import custom_derivatives -from jax import random -from jax import numpy as jnp from jax import tree_util -from jax import sharding from jax import export -from jax.experimental.jax2tf import impl_no_xla -from jax._src import ad_checkpoint -from jax._src import ad_util from jax._src import api from jax._src import api_util from jax._src import config from jax._src import core -from jax._src import dispatch from jax._src import dtypes -from jax._src import linear_util as lu from jax._src import op_shardings -from jax._src import sharding_impls -from jax._src import mesh -from jax._src import pjit -from jax._src import prng -from jax._src import random as random_internal from jax._src import source_info_util from jax._src import util -from jax._src import shard_alike from jax._src.export import _export from jax._src.export import shape_poly -from jax._src.interpreters import ad -from jax._src.interpreters import mlir -from jax._src.lax import control_flow as lax_control_flow -from jax._src.lax import lax as lax_internal -from jax._src.lax import linalg as lax_linalg -from jax._src.lax import slicing as lax_slicing -from jax._src.lax import windowed_reductions as lax_windowed_reductions from jax._src.lib import xla_client -from jax._src.numpy.ufuncs import logaddexp import tensorflow as tf @@ -75,13 +49,11 @@ # pylint: disable=g-direct-tensorflow-import from tensorflow.compiler.tf2xla.python import xla as tfxla from tensorflow.compiler.xla import xla_data_pb2 -from tensorflow.core.framework import attr_value_pb2 try: from tensorflow.python.compiler.xla.experimental import xla_sharding except ModuleNotFoundError: # This can be removed when TF 2.10 support is no longer needed. from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding -from tensorflow.python.framework import ops as tf_ops from tensorflow.python.eager import context as tf_context # pylint: enable=g-direct-tensorflow-import @@ -91,37 +63,10 @@ DisabledSafetyCheck = export.DisabledSafetyCheck -# A temporary internal flag, to enable the wrapping of jax.jit functions -# with tf.function(jit_compile=True). See #7389. This change has triggered a -# number of failures in TF. We keep this until we are confident that it does -# not create problems. -# TODO(b/207464757): figure out why this change breaks test -_WRAP_JAX_JIT_WITH_TF_FUNCTION = False - -# The scope name need to be a valid TensorFlow name. See -# https://github.com/tensorflow/tensorflow/blob/r2.3/tensorflow/core/framework/node_def_util.cc#L731 -_VALID_SCOPE_REGEX = re.compile("^[A-Za-z0-9.][A-Za-z0-9_.\\/>-]*$") -_INVALID_SCOPE_CHAR = re.compile("[^A-Za-z0-9_.\\/-]") map = util.safe_map zip = util.safe_zip - -def _sanitize_scope_name(name): - scope_name = _INVALID_SCOPE_CHAR.sub("_", name) - if not _VALID_SCOPE_REGEX.match(scope_name): - scope_name = f".{scope_name}" - return scope_name - - -# TODO(b/353437394): Deprecate support for `enable_xla=False`. -# Line below is different externally and internally. -allow_enable_xla_false = lambda: True - -# TODO(b/353437398): Deprecate support for `native_serialization=False`. -# Line below is different externally and internally. -allow_native_serialization_false = lambda: True - # A value suitable in a TF tracing context: tf.Tensor, tf.Variable, # or Python scalar or numpy.ndarray. (A tf.EagerTensor is a tf.Tensor.) TfVal = Any @@ -142,27 +87,6 @@ class _DefaultNativeSerialization: pass DEFAULT_NATIVE_SERIALIZATION = _DefaultNativeSerialization() -# The implementation rules for primitives. The rule will be called with the -# arguments (TfVal) and must return TfVal (or a sequence thereof, -# if primitive.multiple_results). The exception are primarily the -# control-flow primitives. -tf_impl: dict[core.Primitive, Callable[..., Any]] = {} - -# Some primitive implementation rules need the abstract values of arguments -# and the results. This is the case for the primitives implemented using -# _convert_jax_impl and those that need to adjust the shape of the outputs -# due to missing TF shape inference rules for TFXLA ops. The rules for these -# primitives should be added to `tf_impl_with_avals`. -# The abstract value are passed to the implementation as two special kwargs -# `_in_avals` (a tuple of core.ShapedArray) and `_out_aval` (a -# core.ShapedArray, or a tuple thereof when primitive.multiple_results). -tf_impl_with_avals: dict[core.Primitive, Callable[..., Any]] = {} - -# XLA is not linked in all environments when converting a primitive. If this is -# the case, we first search for implementation rules for primitives in the -# following map. These implementations are workarounds, making use of TF ops -# that do work when XLA is not linked in. -tf_impl_no_xla = impl_no_xla.tf_impl_no_xla # In order to ensure that JAX picks up the proper user-frame for source # locations we will register the TensorFlow source path as an internal @@ -182,12 +106,6 @@ class _DefaultNativeSerialization: class _ThreadLocalState(threading.local): def __init__(self): - # XLA is not linked in all environments; when converting a primitive, if this - # variable is disabled, we try harder to use only standard TF ops if they are - # applicable to the concrete use case; if the resulting conversion path ends up - # requiring a TFXLA operation, an exception is thrown instead. - self.enable_xla = True - # Keep track if we are inside a call_tf. In that context we disable the # safety check that we are not inside JAX transformations. self.inside_call_tf = False @@ -195,22 +113,6 @@ def __init__(self): # Maps dimension variables to TF expressions, for non-native lowering self.shape_env: Sequence[tuple[str, TfVal]] = () - # Whether to actually include XLA op metadata in the generated TF ops - # TODO(b/189306134): implement support for XLA metadata - self.include_xla_op_metadata = False - - # A cache for the tf.convert_to_tensor for constants. We try to preserve - # sharing for constants, to enable tf.Graph to take advantage of it. - # See https://github.com/jax-ml/jax/issues/7992. - self.constant_cache = None # None means that we don't use a cache. We - # may be outside a conversion scope. - - # A cache for the outside tf name_scope when the converted - # function is running. We will add this as the prefix to the generated tf op - # name. For example, the tf op name will be like - # "{tf_outer_name_scope}/JAX_NAME_STACKS" - self.tf_outer_name_scope = "" - # A dict collecting all tf concrete_functions called by stablehlo.custom_call # This is used only by native serialization (unlike all the other # thread-local state). @@ -218,9 +120,6 @@ def __init__(self): _thread_local_state = _ThreadLocalState() -def _get_current_name_stack() -> NameStack | str: - return source_info_util.current_name_stack() - @contextlib.contextmanager def inside_call_tf(): # Set the inside_call_tf flag for a context. @@ -241,11 +140,11 @@ def get_thread_local_state_call_tf_concrete_function_list() -> ( @partial(api_util.api_hook, tag="jax2tf_convert") def convert(fun_jax: Callable, *, - polymorphic_shapes: str | None = None, + polymorphic_shapes: str | PolyShape | None | Sequence[str | PolyShape | None] = None, polymorphic_constraints: Sequence[str] = (), with_gradient: bool = True, - enable_xla: bool = True, - native_serialization: bool | _DefaultNativeSerialization = DEFAULT_NATIVE_SERIALIZATION, + enable_xla: bool = DEFAULT_NATIVE_SERIALIZATION, # type: ignore + native_serialization: bool | _DefaultNativeSerialization = DEFAULT_NATIVE_SERIALIZATION, # type: ignore native_serialization_platforms: Sequence[str] | None = None, native_serialization_disabled_checks: Sequence[DisabledSafetyCheck] = (), ) -> Callable: @@ -272,7 +171,7 @@ def convert(fun_jax: Callable, should be `None` (monomorphic argument), or a Python object with the same pytree structure as the argument. See [how optional parameters are matched to - arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees). + arguments](https://docs.jax.dev/en/latest/pytrees.html#applying-optional-parameters-to-pytrees). A shape specification for an array argument should be an object `PolyShape(dim0, dim1, ..., dimn)` @@ -304,30 +203,12 @@ def convert(fun_jax: Callable, function, by converting the ``jax.vjp(fun)``. This means that reverse-mode TensorFlow AD is supported for the output TensorFlow function, and the value of the gradient will be JAX-accurate. - enable_xla: if set (default), use the simplest conversion - and use XLA TF ops when necessary. These ops are known to create issues - for the TFLite and TFjs converters. For those cases, unset this parameter - so the lowering tries harder to use non-XLA TF ops to lower the - function and aborts if this is not possible. Cannot be set to `False` - when using `native_serialization`. - Starting with JAX 0.4.31 support for `enable_xla=False` is deprecated. - native_serialization: serialize the JAX function natively to - StableHLO with compatibility guarantees. This makes it easier to have - confidence that the code executed when calling this function from - TensorFlow is exactly the same as JAX would run natively. - The DEFAULT_NATIVE_SERIALIZATION value defers to `False` if `enable_xla` - is set to `False` or to the configuration flag - `--jax2tf_default_native_serialization` otherwise. - Native serialization cannot be used with `enable_xla=False`. - Starting with JAX 0.4.31 support for non-native serialization is deprecated. - native_serialization_platforms: In conjunction with - `native_serialization`, specify the platform(s) + native_serialization_platforms: Specifies the platform(s) for which to lower the code. Must be a tuple of strings, including a subset of: 'cpu', 'cuda', 'rocm', 'tpu'. The default (`None``), specifies the JAX default backend on the machine where the lowering is done. - native_serialization_disabled_checks: In conjunction with - `native_serialization`, disable the specified safety checks. + native_serialization_disabled_checks: Disables the specified safety checks. See docstring of `DisabledSafetyCheck`. Returns: @@ -335,51 +216,20 @@ def convert(fun_jax: Callable, tuple/lists/dicts thereof), and returns TfVals as outputs, and uses only TensorFlow ops and thus can be called from a TensorFlow program. """ - if native_serialization is DEFAULT_NATIVE_SERIALIZATION: - if not enable_xla: - native_serialization = False - else: - native_serialization = config.jax2tf_default_native_serialization.value - - if not enable_xla: - if allow_enable_xla_false(): - warnings.warn( - "jax2tf.convert with enable_xla=False has been deprecated " - "since July 2024.", - DeprecationWarning, - stacklevel=2) - if native_serialization: - raise ValueError( - "native_serialization is not supported with enable_xla=False") - else: - raise ValueError( - "jax2tf.convert with enable_xla=False has been deprecated " - "since July 2024 and it is not supported anymore.") - - elif not native_serialization: - if allow_native_serialization_false(): - warnings.warn( - "jax2tf.convert with native_serialization=False has been deprecated " - "since July 2024.", - DeprecationWarning, - stacklevel=2) - else: - raise ValueError( - "jax2tf.convert with native_serialization=False has been deprecated " - "since July 2024 and it is not supported anymore.") - - if not native_serialization and polymorphic_constraints: - raise ValueError( - "polymorphic_constraints are supported only with native serialization" - ) + if native_serialization is not DEFAULT_NATIVE_SERIALIZATION: + warnings.warn( + "The `native_serialization` parameter is deprecated and " + "will be removed in a future version of JAX.", + DeprecationWarning, stacklevel=2) + del native_serialization + if enable_xla is not DEFAULT_NATIVE_SERIALIZATION: + warnings.warn( + "The `enable_xla` parameter is deprecated and " + "will be removed in a future version of JAX.", + DeprecationWarning, stacklevel=2) + del enable_xla if native_serialization_platforms: - if not native_serialization: - warnings.warn( - "using native_serialization_platforms without native_serialization. " - "The parameter will have no effect, since the same code is serialized " - "for all platforms without native_serialization.") - if (not isinstance(native_serialization_platforms, (list, tuple)) or not all(p in ["cpu", "cuda", "rocm", "tpu"] for p in native_serialization_platforms)): @@ -433,19 +283,11 @@ def jax_arg_spec_from_tf(a: TfVal) -> jax.ShapeDtypeStruct: args_flat_tf = tuple( map(preprocess_arg_tf, range(len(args_flat_tf)), args_flat_tf)) - impl: SerializationImpl - if native_serialization: - impl = NativeSerializationImpl( - fun_jax, - args_specs=args_specs, kwargs_specs=kwargs_specs, - native_serialization_platforms=native_serialization_platforms, - native_serialization_disabled_checks=native_serialization_disabled_checks) - else: - impl = GraphSerializationImpl( - fun_jax, - args_specs=args_specs, kwargs_specs=kwargs_specs, - args_flat_tf=args_flat_tf, - enable_xla=enable_xla) + impl = NativeSerializationImpl( + fun_jax, + args_specs=args_specs, kwargs_specs=kwargs_specs, + native_serialization_platforms=native_serialization_platforms, + native_serialization_disabled_checks=native_serialization_disabled_checks) try: impl.before_conversion() @@ -484,52 +326,13 @@ def converted_fun_flat_with_custom_gradient_tf(*args_flat_tf: TfVal) -> TfVal: return converted_fun_tf -class SerializationImpl: - """Implementation details for jax2tf serialization. - - Abstract superclass for subclassing. - """ - def before_conversion(self): - """Called in the resulting TF function, before any other method. - Useful to set any global context.""" - raise NotImplementedError - - def after_conversion(self): - """Called in the resulting TF function, after conversion is done. - - Useful to restore any global context set up by `before_conversion`.""" - raise NotImplementedError - - def run_fun_tf(self, - args_flat_tf: Sequence[TfVal] - ) -> tuple[Sequence[TfVal], Sequence[core.ShapedArray], tree_util.PyTreeDef]: - """Runs the resulting TF function. - - Args: - args_flat_tf: a flat tuple of tf.Tensor arguments - - Returns: a tuple with: - outs_tfs: a flat tuple of tf.Tensor results - outs_avals: a flat tuple of JAX abstract values for the underlying JAX - function. - outs_tree: the PyTreeDef for the outputs - """ - raise NotImplementedError - - def get_vjp_fun(self) -> tuple[Callable, - Sequence[core.AbstractValue]]: - """Returns the VJP function, and the VJP in_avals.""" - raise NotImplementedError - - -class NativeSerializationImpl(SerializationImpl): +class NativeSerializationImpl: def __init__(self, fun_jax, *, args_specs, kwargs_specs, native_serialization_platforms: Sequence[str] | None, native_serialization_disabled_checks: Sequence[DisabledSafetyCheck]): - self.convert_kwargs = dict(native_serialization=True, - native_serialization_platforms=native_serialization_platforms, + self.convert_kwargs = dict(native_serialization_platforms=native_serialization_platforms, native_serialization_disabled_checks=native_serialization_disabled_checks) if hasattr(fun_jax, "trace"): # If we have a pjit or pmap already we do not wrap with another, and we @@ -578,90 +381,15 @@ def get_vjp_fun(self) -> tuple[Callable, return _export._get_vjp_fun(self.fun_jax, in_tree=self.exported.in_tree, in_avals=self.exported.in_avals, + has_named_shardings=self.exported._has_named_shardings, + in_named_shardings=self.exported._in_named_shardings, + out_named_shardings=self.exported._out_named_shardings, in_shardings_hlo=self.exported.in_shardings_hlo, out_avals=self.exported.out_avals, out_shardings_hlo=self.exported.out_shardings_hlo, device_assignment=self.device_assignment, apply_jit=True) -class GraphSerializationImpl(SerializationImpl): - def __init__(self, fun_jax, *, - args_specs, kwargs_specs, - args_flat_tf: Sequence[TfVal], - enable_xla: bool): - self.convert_kwargs = dict(native_serialization=False) - self.fun_jax = fun_jax - self.args_specs = args_specs - self.kwargs_specs = kwargs_specs - self.enable_xla = enable_xla - - fun_name = getattr(fun_jax, "__name__", "unknown") - name_stack = util.wrap_name(fun_name, "jax2tf") - self.name_stack = name_stack - self.args_flat_tf = args_flat_tf - self.debug = api_util.debug_info("jax2tf", fun_jax, - args_specs, kwargs_specs) - - def before_conversion(self): - prev_enable_xla = _thread_local_state.enable_xla - prev_include_xla_op_metadata = _thread_local_state.include_xla_op_metadata - prev_tf_outer_name_scope = _thread_local_state.tf_outer_name_scope - def _restore_context(): - _thread_local_state.enable_xla = prev_enable_xla - _thread_local_state.include_xla_op_metadata = prev_include_xla_op_metadata - _thread_local_state.tf_outer_name_scope = prev_tf_outer_name_scope - _thread_local_state.shape_env = () - self._restore_context = _restore_context - _thread_local_state.enable_xla = self.enable_xla - # TODO(b/189306134): implement support for XLA metadata - _thread_local_state.include_xla_op_metadata = False - _thread_local_state.tf_outer_name_scope = tf.get_current_name_scope() - assert not _thread_local_state.shape_env, f"Unexpected shape environment {_thread_local_state.shape_env}" - args_specs_flat, self.in_tree = tree_util.tree_flatten( - (self.args_specs, self.kwargs_specs)) - self.args_avals_flat = tuple( - map(core.get_aval, args_specs_flat)) - dim_vars = shape_poly.all_dim_vars(self.args_avals_flat) - dim_values, _ = _interpret_fun_jax( - partial(shape_poly.compute_dim_vars_from_arg_shapes, - self.args_avals_flat, args_kwargs_tree=self.in_tree), - self.args_flat_tf, self.args_avals_flat, self.name_stack, - debug_info=api_util.debug_info("jax2tf dim_vars", - shape_poly.compute_dim_vars_from_arg_shapes, - self.args_specs, self.kwargs_specs)) - - _thread_local_state.shape_env = zip(dim_vars, dim_values) - - def after_conversion(self): - self._restore_context() - - def run_fun_tf(self, - args_flat_tf: Sequence[TfVal] - ) -> tuple[Sequence[TfVal], Sequence[core.ShapedArray], tree_util.PyTreeDef]: - fun_flat_jax, out_tree_thunk = flatten_fun_jax(self.fun_jax, self.in_tree) - # out_tree_thunk will be ready after we _interpret_fun_jax below - outs_tf, self.outs_avals = _interpret_fun_jax( - fun_flat_jax, - args_flat_tf, self.args_avals_flat, - self.name_stack, - fresh_constant_cache=True, - debug_info=self.debug) - return outs_tf, self.outs_avals, out_tree_thunk() - - def get_vjp_fun(self) -> tuple[Callable, - Sequence[core.AbstractValue]]: - # We reuse the code for native serialization to get the VJP functions, - # except we use unspecified shardings, and we do not apply a jit on the - # VJP. This matches the older behavior of jax2tf for graph serialization. - return _export._get_vjp_fun(self.fun_jax, - in_tree=self.in_tree, - in_avals=self.args_avals_flat, - in_shardings_hlo=(None,) * len(self.args_avals_flat), - out_avals=self.outs_avals, - out_shardings_hlo=(None,) * len(self.outs_avals), - device_assignment=None, # Not used when apply_jit = False - apply_jit=False) - def dtype_of_val(val: TfVal) -> DType: """Computes the TensorFlow dtype using JAX's typing rules. @@ -783,7 +511,7 @@ def preprocess_arg_tf(arg_idx: int, def _make_custom_gradient_fn_tf(fun_jax, *, - impl: SerializationImpl, + impl: NativeSerializationImpl, with_gradient: bool, args_specs, kwargs_specs, args_tf: Sequence[TfVal], @@ -841,32 +569,6 @@ def fix_out_ct(out_ct_tf, out_ct_aval: core.ShapedArray, out_tf: TfVal): return grad_fn_tf -@contextlib.contextmanager -def _extended_name_stack(extra_name_stack: str | None): - name_ctx = (source_info_util.extend_name_stack(extra_name_stack) - if extra_name_stack - else contextlib.nullcontext()) - with name_ctx: - yield - return - - -def _interpret_fun_jax( - fun_jax: Callable, - args_tf: Sequence[TfVal], - args_avals: Sequence[core.ShapedArray], - extra_name_stack: str | None, *, - fresh_constant_cache: bool = False, - debug_info: core.DebugInfo, -) -> tuple[tuple[TfVal, ...], tuple[core.ShapedArray, ...]]: - subtrace_fun = _interpret_subtrace( - lu.wrap_init(fun_jax, debug_info=debug_info), args_avals) - with _extended_name_stack(extra_name_stack): - out_vals: Sequence[tuple[TfVal, core.ShapedArray]] = \ - _call_wrapped_with_new_constant_cache(subtrace_fun, args_tf, - fresh_constant_cache=fresh_constant_cache) - return util.unzip2(out_vals) - def _run_exported_as_tf(args_flat_tf: Sequence[TfVal], exported: export.Exported, @@ -898,7 +600,6 @@ def _convert_value(val, aval): out_types = tuple(_to_tf_dtype(out_aval.dtype) for out_aval in exported.out_avals) - kept_args_avals = [aval for i, aval in enumerate(exported.in_avals) if i in exported.module_kept_var_idx] kept_args_flat_tf = [atf for i, atf in enumerate(args_flat_tf) if i in exported.module_kept_var_idx] version = exported.calling_convention_version @@ -944,6 +645,11 @@ def _convert_value(val, aval): if DisabledSafetyCheck.platform() in exported.disabled_safety_checks: call_module_attrs["platforms"] = () # No platform checking + if version >= 10: + call_module_attrs["use_shardy_partitioner"] = ( + config.use_shardy_partitioner.value + ) + if logging.vlog_is_on(3): # We already logged the MLIR module when we exported it. logging.vlog(3, "XlaCallModule %s", str(call_module_attrs)) @@ -982,105 +688,6 @@ def _convert_value(val, aval): return res -def _call_wrapped_with_new_constant_cache(fun: lu.WrappedFun, - in_vals: Sequence[TfVal], - fresh_constant_cache: bool = False - ) -> Sequence[tuple[TfVal, core.ShapedArray]]: - try: - prev_constant_cache = _thread_local_state.constant_cache - # Start a new cache, so that we don't share constants across tf.function - # boundaries. - if fresh_constant_cache: - _thread_local_state.constant_cache = {} - else: - prev_constant_cache_keys = set(prev_constant_cache.keys()) if prev_constant_cache is not None else set() - out_vals: Sequence[tuple[TfVal, core.ShapedArray]] = \ - fun.call_wrapped(*in_vals) - finally: - if (not fresh_constant_cache and - prev_constant_cache is not None and - _WRAP_JAX_JIT_WITH_TF_FUNCTION): - newly_added_keys = set(prev_constant_cache.keys()) - prev_constant_cache_keys - # Delete the newly added keys - for k in newly_added_keys: - del prev_constant_cache[k] - _thread_local_state.constant_cache = prev_constant_cache - return out_vals - -def _convert_jax_impl(impl_jax: Callable, *, - multiple_results=True, - with_physical_avals=False, - extra_name_stack: str | None = None) -> Callable: - """Convert the JAX implementation of a primitive. - - Args: - impl_jax: typically the impl-rule for a primitive, with signature - `(*args_jax: JaxVal, **kwargs) -> Sequence[JaxVal]`. This function implements - a primitive in terms of other primitives. - multiple_results: whether `impl_jax` returns a sequence of results. - extra_name_stack: additional element to add to the name stack for the - converted ops. - - Returns: - a function with signature `(*args_tf: TfVal, _in_avals, _out_aval, **kwargs) - -> Sequence[TfVal]`. - """ - - def wrapped_tf(*args_tf: TfVal, _in_avals: Sequence[core.ShapedArray], - _out_aval: core.ShapedArray, - **kwargs) -> Sequence[TfVal]: - - if with_physical_avals: - _in_avals = map(_jax_physical_aval, _in_avals) - _out_aval = _jax_physical_aval(_out_aval) - - # We wrap the impl_jax to always return a tuple of results. - def impl_multiple_results_jax(*args_jax): - results_jax = impl_jax(*args_jax, **kwargs) - return results_jax if multiple_results else [results_jax] - - results_tf, _ = _interpret_fun_jax( - impl_multiple_results_jax, args_tf, _in_avals, - extra_name_stack, - debug_info=api_util.debug_info("jax2tf", impl_jax, - args_tf, kwargs)) - return results_tf if multiple_results else results_tf[0] - - return wrapped_tf - - -@lu.transformation2 -def _interpret_subtrace(f, in_avals: Sequence[core.ShapedArray], - *in_vals: TfVal): - trace = TensorFlowTrace() - in_tracers = tuple( - TensorFlowTracer(trace, val, aval) - for val, aval in zip(in_vals, in_avals)) - with core.set_current_trace(trace): - outs = f(*in_tracers) - out_tracers: Iterable[TensorFlowTracer] = ( - map(trace.to_tf_tracer, outs)) - out_vals_with_avals: Sequence[tuple[TfVal, core.ShapedArray]] = ( - tuple((t.val, t.aval) for t in out_tracers)) - return out_vals_with_avals - - -def _interpret_jaxpr(jaxpr: core.ClosedJaxpr, *args_tf: TfVal, - extra_name_stack: str | None, - fresh_constant_cache: bool = True) -> Sequence[TfVal]: - """Evaluates a Jaxpr with tf.Tensor arguments. - - This is most often used as the body of a tf.function, or tf.switch_case, - in which case it should use a fresh constant cache. - The output is a sequence of TfVal, suitable for use with TF. - """ - outs_tf, _ = _interpret_fun_jax(core.jaxpr_as_fun(jaxpr), - args_tf, jaxpr.in_avals, extra_name_stack, - fresh_constant_cache=fresh_constant_cache, - debug_info=jaxpr.jaxpr.debug_info) - return outs_tf - - def _jax_physical_aval(aval: core.ShapedArray) -> core.ShapedArray: """Converts JAX avals from logical to physical, if relevant. @@ -1195,2297 +802,6 @@ def _tfval_to_tensor_jax_dtype(val: TfVal, return tf_val, jax_dtype -def _eval_shape(shape: Sequence[shape_poly.DimSize], dtype=None) -> Sequence[TfVal]: - # Returns a tuple of shape_poly.dim_as_value_dtype - # Used only for non-native lowering - assert all(map(lambda x: x is not None, shape)), ( - f"Argument shape should be a valid JAX shape but got {shape}") - if dtype is not None: - shape = _jax_physical_aval(core.ShapedArray(shape, dtype)).shape - if core.is_constant_shape(shape): - return tuple(int(d) for d in shape) - - dim_vars, dim_values = util.unzip2(_thread_local_state.shape_env) - shape_values_tf, _ = _interpret_fun_jax( - partial(core.evaluate_shape, shape, dim_vars), - dim_values, [core.dim_value_aval()] * len(dim_values), "", # type: ignore - debug_info=api_util.debug_info("jax2tf evaluate_shape", core.evaluate_shape, - (0, 0, *dim_values), {})) - # Keep only the non-constant dimensions - return tuple(operator.index(d) if core.is_constant_dim(d) else d_tf # type: ignore - for d, d_tf in zip(shape, shape_values_tf)) - - -def _ensure_tf_shape_if_dynamic(x: TfVal, shape): - # Update TF tensor `x` with shape `shape` if the shape of `x`` is dynamic. - if x.shape.is_fully_defined(): - return x - return tf.ensure_shape(x, shape) - - -def _assert_matching_abstract_shape(x: TfVal, shape: Sequence[shape_poly.DimSize]): - """Asserts that shape matches x.shape in the known dimensions and has - dimension polynomials elsewhere.""" - # Ensures that the shape does not contain None; it should contain symbolic expressions. - def check_one(xd: int | None, sd: Any): - if core.is_constant_dim(sd): - return xd == sd - else: - assert export.is_symbolic_dim(sd) - return True - assert (len(x.shape) == len(shape) and - all(check_one(xd, sd) - for xd, sd in zip(x.shape, shape))), \ - f"Shape {shape} does not match x.shape {x.shape}" - -# TODO(b/26854495): pylint doesn't understand slots and inheritance. -# pylint: disable=assigning-non-slot - - -class TensorFlowTracer(core.Tracer): - """Tracer class that boxes a TF value and a JAX abstract value. - - In addition to the TF value we carry the JAX abstract value because - there are some cases when it cannot be recovered from the value: - when we are converting with polymorphic shapes or when the JAX aval - has a custom element type. In these cases the shape of the value may - have dimensions set to `None`, or it may only correspond to the JAX - "physical" (TF/lowering-compatible) shape, so the JAX abstract value - may contain more precise information. - - When the value has a partially-known shape, the dimensions marked as `None` - must correspond to non-constant dimensions in the abstract value. - - See README.md for details. - """ - # val: TfVal - # _aval: core.ShapedArray - __slots__ = ["val", "_aval"] - - def __init__(self, trace: TensorFlowTrace, val: TfVal, - aval: core.AbstractValue): - self._trace = trace - self._aval = aval - phys_aval = _jax_physical_aval(self._aval) # type: ignore[arg-type] - - if isinstance(val, (tf.Tensor, tf.Variable)): - val_shape = val.shape - - if config.enable_checks.value: - assert len(phys_aval.shape) == len(val_shape), f"_aval.shape={phys_aval.shape} different rank than {val_shape=}" - # To compare types, we must handle float0 in JAX and x64 in TF - if phys_aval.dtype == dtypes.float0: - assert _to_tf_dtype(phys_aval.dtype) == val.dtype, f"expected {phys_aval.dtype} == {val.dtype}" - else: - assert phys_aval.dtype == _to_jax_dtype(val.dtype), f"expected {phys_aval.dtype} == {val.dtype}" - - for aval_dim, val_dim in zip(phys_aval.shape, val_shape): - if val_dim is None: - assert export.is_symbolic_dim(aval_dim), f"expected {phys_aval.shape} == {val_shape}" - elif not export.is_symbolic_dim(aval_dim): - assert aval_dim == val_dim, f"expected {phys_aval.shape} == {val_shape}" - else: - # We have a TF value with known shape, and the abstract shape is a shape variable. - try: - aval_int = int(_eval_shape([aval_dim])) # type: ignore - except (TypeError, KeyError, shape_poly.UnexpectedDimVar): - continue - assert aval_int == val_dim, f"expected {phys_aval.shape} == {val_shape}. Found {aval_int} != {val_dim}." - - self.val = _tfval_to_tensor_jax_dtype(val, - phys_aval.dtype, - memoize_constants=True)[0] - - @property - def aval(self): - return self._aval - - def full_lower(self): - return self - -def _make_op_metadata(primitive: core.Primitive, - params: dict, *, - source_info: source_info_util.SourceInfo, - ) -> xla_data_pb2.OpMetadata: - eqn_str = (str(source_info.name_stack) + '/' - + core.str_eqn_compact(primitive, params)) - frame = source_info_util.user_frame(source_info) - return xla_data_pb2.OpMetadata( - op_type=primitive.name, - op_name=eqn_str, - source_file=mlir.get_canonical_source_file( - frame.file_name if frame else "", mlir.TracebackCaches()), - source_line=frame.start_line if frame else None) - - -class TensorFlowTrace(core.Trace): - """Trace class that underlies the jax2tf transformation. - - We are going to ensure that jax2tf.convert is never nested inside other - transformations. This is sufficient for intended use cases (converting - fully-transformed JAX code). It also simplifies our job because we do not have - to handle situations where we apply primitives on a mix of TF values and - JAX tracers from an outer transformation. E.g., for addition both the TF - values - and the JAX tracers have an override and they get confused if they see values - from the other world. - - Hence a TFT trace does not interact with non-TFT traces at lower-level. For - higher-order control-flow primitives we invoke recursively - _interpret_fun on the body of the conditional, which will create a nested TFT. - - We do want to allow transformations nested inside a TensorFlowTrace (TFT), but - those will introduce their own MainTrace, and any operations involving those - will be done on those traces, i.e., not a concern for TFT. - """ - - __slots__ = () - - def to_tf_tracer(self, val: TfVal) -> TensorFlowTracer: - """Lifts a non-Tracer into the TensorFlowTracer. - """ - if isinstance(val, TensorFlowTracer): - return val - if hasattr(val, "__jax_array__"): - with core.set_current_trace(self): - val = val.__jax_array__() - if isinstance(val, TensorFlowTracer): - return val - tf_val, jax_dtype = _tfval_to_tensor_jax_dtype(val, memoize_constants=True) - return TensorFlowTracer( - self, tf_val, core.ShapedArray(np.shape(val), jax_dtype, - weak_type=dtypes.is_weakly_typed(val))) - - def process_primitive(self, primitive: core.Primitive, - tracers: Sequence[TensorFlowTracer], - params) -> TensorFlowTracer: - tracers = map(self.to_tf_tracer, tracers) - impl, impl_needs_avals = self.get_primitive_impl(primitive) - args_avals: Sequence[core.ShapedArray] = tuple(t.aval for t in tracers) - # This is a bit conservative, doing abstract_eval even in op-by-op execution - # but we needed it for, e.g., shape_polymorphism where only JAX's - # abstract evaluation rules can properly track polymorphic shapes. - # Unfortunately under op-by-op execution this is a rare occasion where we - # need abstract evaluation. - out_aval, _ = primitive.abstract_eval(*args_avals, **params) - args_tf: Sequence[TfVal] = [t.val for t in tracers] - def invoke_impl() -> TfVal: - if impl_needs_avals: - return impl( - *args_tf, - _in_avals=args_avals, - _out_aval=out_aval, - **params) - else: - return impl(*args_tf, **params) - - current_name_stack = _get_current_name_stack() - # We don't use `str(name_stack)` because it uses parentheses for - # transformations, which aren't allowed in `name_scope`. - scope = '/'.join([s.name for s in current_name_stack.stack]) # type: ignore[union-attr] - - # Here we reset the name scope to the memorized TF name scope - # + JAX name stack by using absolute scope. - # We need to add a '/' to the name stack string to force `tf.name_scope` - # to interpret it as an absolute scope, not a relative scope. - if _thread_local_state.tf_outer_name_scope: - scope = f"{_thread_local_state.tf_outer_name_scope}/{scope}" - - if not scope.endswith("/"): - scope = scope + "/" - - with tf.name_scope(_sanitize_scope_name(scope)): - if _thread_local_state.include_xla_op_metadata: - op_metadata_proto = _make_op_metadata( - primitive, params, source_info=source_info_util.current()) - with tf_ops.get_default_graph()._attr_scope( - {"_XlaOpMetadata": attr_value_pb2.AttrValue( - s=op_metadata_proto.SerializeToString())}): - val_out = invoke_impl() - else: - val_out = invoke_impl() - - if primitive.multiple_results: - out = [ - TensorFlowTracer(self, v, a) - for v, a in zip(val_out, out_aval) - ] - else: - out = TensorFlowTracer(self, val_out, out_aval) # type: ignore - - # Check that the impl rule returned a value of expected shape and dtype - # TODO: adapt this to match polymorphic shapes - if config.enable_checks.value: - if primitive.multiple_results: - for o, expected_aval in zip(out, out_aval): - assert o.aval.strip_weak_type() == expected_aval.strip_weak_type(), ( - f"{primitive}: out.aval = {o.aval}; expected {expected_aval}") - else: - assert out.aval == out_aval, ( # type: ignore - f"{primitive}: out.aval = {out.aval}; expected {out_aval}" - ) - return out # type: ignore - - def process_call(self, call_primitive: core.Primitive, fun: lu.WrappedFun, - tracers: Sequence[TensorFlowTracer], params): - assert call_primitive.multiple_results - tracers = map(self.to_tf_tracer, tracers) - vals: Sequence[TfVal] = [t.val for t in tracers] - avals: Sequence[core.ShapedArray] = tuple(t.aval for t in tracers) - interpreted_fun = _interpret_subtrace(fun, avals) - extra_name_stack = None - with _extended_name_stack(extra_name_stack): - vals_out = interpreted_fun.call_wrapped(*vals) - return [TensorFlowTracer(self, v, a) for v, a in vals_out] - - def process_map(self, map_primitive, f, tracers, params): - raise NotImplementedError("process_map") - - def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): - # Drop the custom differentiation rule and act like a call primitive. This - # behavior is desirable because jax2tf stages code out of the JAX system, so - # there are no more JAX differentiation transformations to be applied. - del jvp, symbolic_zeros # Unused. - return self.process_call(core.call_p, fun, tracers, {}) - - def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, - symbolic_zeros): - # Drop the custom differentiation rule and act like a call primitive. This - # behavior is desirable because jax2tf stages code out of the JAX system, so - # there are no more JAX differentiation transformations to be applied. - del fwd, bwd, out_trees, symbolic_zeros # Unused. - return self.process_call(core.call_p, fun, tracers, {}) - - def get_primitive_impl(self, p: core.Primitive) -> tuple[Callable, bool]: - # Returns the primitive implementation and whether the implementation - # takes abstract values (see definition of tf_impl_with_avals) - if not _thread_local_state.enable_xla: - try: - return tf_impl_no_xla[p], True # Always require avals. - except KeyError: - pass - try: - return tf_impl[p], False - except KeyError: - try: - return tf_impl_with_avals[p], True - except KeyError as err: - msg = "TensorFlow interpretation rule for '{}' not implemented" - raise NotImplementedError(msg.format(p)) from err - -def _unexpected_primitive(p: core.Primitive, *args, **kwargs): - assert False, f"Encountered unexpected primitive {p}" - -for unexpected in [core.call_p]: - tf_impl[unexpected] = partial(_unexpected_primitive, unexpected) - -tf_impl[lax_control_flow.loops.eval_jaxpr_p] = \ - lambda *args, jaxpr: _interpret_jaxpr( - jaxpr, *args, fresh_constant_cache=False, extra_name_stack=None) - -# Primitives that are not yet implemented must be explicitly declared here. -tf_not_yet_impl = [ - "clz", - "igamma_grad_a", - "random_gamma_grad", - "polygamma", - "reduce_xor", - "schur", - "closed_call", - "unreachable", - "bint", - "getslice", - "full_to_shard", - "shard_to_full", - "pure_callback", - "run_state", - "for", - "inspect_sharding", - "io_callback", - "broadcast_to", - "shard_map", - "global_array_to_host_local_array", - "host_local_array_to_global_array", - "call_exported", - "zeta", - # Not high priority? - "after_all", - "all_to_all", - "check", - "create_token", - "custom_transpose_call", - "custom_vmap_call", - "infeed", - "linear_call", - "outfeed", - "pmax_p", - "pmin", - "ppermute", - "psum", - "psum2", - "pbroadcast", - "pmax", - "pgather", - "reduce_scatter", - "axis_index", - "all_gather", - "lu_pivots_to_permutation", - "xla_pmap", - "geqrf", - "geqp3", - "householder_product", - "hessenberg", - "tridiagonal", - "eigh_jacobi", - "platform_index", - "assert_consumed_value", - "consume", - "ragged_dot", - "ragged_dot_general", - "cholesky_update", - "symmetric_product", - "from_edtype", - "to_edtype", - "reciprocal", - # Pallas TPU primitives - "bitcast", - "repeat", - "roll", - # temporary pending cudnn fix, see https://github.com/jax-ml/jax/pull/23740 - "bias_fwd", - "bias_bwd", -] - -tf_impl[random_internal.random_clone_p] = lambda x: x - -tf_impl[ad_util.stop_gradient_p] = tf.stop_gradient - - -def _add(x: TfVal, y: TfVal) -> TfVal: - return tf.raw_ops.AddV2(x=x, y=y) - - -tf_impl[ad_util.add_jaxvals_p] = _add -tf_impl[dispatch.device_put_p] = lambda *xs, devices=None, srcs=None, copy_semantics=None: xs -tf_impl[lax_internal.copy_p] = lambda x: x - -def _shard_alike(*args: TfVal, **_): - return tuple(args) -tf_impl[shard_alike.shard_alike_p] = _shard_alike - -def _neg(x: TfVal) -> TfVal: - if x.dtype.is_unsigned: - signed_dtype = _UNSIGNED_TO_SIGNED_TABLE[x.dtype] - x_signed = tf.cast(x, signed_dtype) - res_signed = tf.math.negative(x_signed) - return tf.cast(res_signed, x.dtype) - else: - return tf.math.negative(x) - -tf_impl[lax.neg_p] = _neg - - -def _sign(x: TfVal) -> TfVal: - if x.dtype.is_unsigned: - # TF and XLA do not support tf.math.sign for unsigned types. - return tf.where( - tf.math.equal(x, 0), tf.constant(0, dtype=x.dtype), - tf.constant(1, dtype=x.dtype)) - else: - return tf.math.sign(x) - - -tf_impl[lax.sign_p] = _sign -tf_impl[lax.floor_p] = tf.math.floor -tf_impl[lax.ceil_p] = tf.math.ceil - - -def _round(operand, *, rounding_method, - _in_avals: Sequence[core.ShapedArray], - _out_aval: core.ShapedArray): - if rounding_method is lax.RoundingMethod.AWAY_FROM_ZERO: - # JAX uses a single HLO op Round here - sign = _sign(operand) - operand *= sign - floor = tf.math.floor(operand) - operand -= floor - cond = tf.math.equal(operand, tf.constant(np.array(0.5), operand.dtype)) - return sign * ( - tf.where(cond, tf.constant(np.array(1), operand.dtype), - tf.math.round(operand)) + floor) - else: # rounding_method is RoundingMethod.TO_NEAREST_EVEN - return tf.math.round(operand) - -tf_impl_with_avals[lax.round_p] = _round -tf_impl[lax.nextafter_p] = tf.math.nextafter - - -def _population_count(x): - orig_dtype = x.dtype - return tf.cast(tf.raw_ops.PopulationCount(x=x), orig_dtype) - - -tf_impl[lax.population_count_p] = _population_count -tf_impl[lax.is_finite_p] = tf.math.is_finite - - -def _abs(x: TfVal) -> TfVal: - # TF and XLA do not support tf.math.abs for unsigned types. - return tf.math.abs(x) if not x.dtype.is_unsigned else x - - -tf_impl[lax.abs_p] = _abs - - -def _pow(x: TfVal, y: TfVal, *, _in_avals, _out_aval) -> TfVal: - x = tf.dtypes.cast(x, _to_tf_dtype(_out_aval.dtype)) - y = tf.dtypes.cast(y, _to_tf_dtype(_out_aval.dtype)) - return tf.math.pow(x, y) - - -tf_impl_with_avals[lax.pow_p] = _pow - - -def _integer_pow(x, *, y: int, _in_avals: Sequence[core.ShapedArray], - _out_aval: core.ShapedArray): - # Follows the implementation in lax._integer_pow_translation_rule - if y == 0: - return tf.broadcast_to( - tf.constant(1, dtype=x.dtype, shape=()), _eval_shape(_out_aval.shape)) - is_reciprocal = y < 0 - if is_reciprocal: - y = -y - acc = None - while y > 0: - if y & 1: - acc = x if acc is None else tf.math.multiply(acc, x) - y >>= 1 - if y > 0: - x = tf.math.multiply(x, x) - return tf.math.reciprocal(acc) if is_reciprocal else acc - - -tf_impl_with_avals[lax.integer_pow_p] = _integer_pow -tf_impl[lax.exp_p] = tf.math.exp -tf_impl[lax_internal.exp2_p] = lambda x: \ - tf.math.exp(tf.math.multiply(tf.math.log(tf.constant(2, x.dtype)), x)) -tf_impl[lax.expm1_p] = tf.math.expm1 -tf_impl[lax.log_p] = tf.math.log -tf_impl[lax.log1p_p] = tf.math.log1p -tf_impl[lax.tan_p] = tf.math.tan -tf_impl[lax.tanh_p] = tf.math.tanh -tf_impl[lax.sin_p] = tf.math.sin -tf_impl[lax.sinh_p] = tf.math.sinh -tf_impl[lax.cos_p] = tf.math.cos -tf_impl[lax.cosh_p] = tf.math.cosh -tf_impl_with_avals[lax.atan_p] = _convert_jax_impl( - lax_internal.atan_impl, multiple_results=False) - -# TODO(phawkins): use tf.math.sigmoid here instead. -tf_impl_with_avals[lax.logistic_p] = _convert_jax_impl( - lax_internal.logistic_impl, multiple_results=False) - -def _atan2(y, x, **kwargs): - if x.dtype.is_complex or y.dtype.is_complex: - complex_component_dtype = { - tf.complex64: tf.float32, - tf.complex128: tf.float64 - }.get(y.dtype) - zero = tf.constant(0, complex_component_dtype) - one = tf.constant(1, complex_component_dtype) - i = tf.complex(zero, one) - return -i * tf.math.log((x + i * y)/tf.math.sqrt(x * x + y * y)) - else: - return tf.math.atan2(y, x) - - -tf_impl[lax.atan2_p] = _atan2 -tf_impl[lax.acosh_p] = tf.math.acosh -tf_impl[lax.atanh_p] = tf.math.atanh -tf_impl[lax.asinh_p] = tf.math.asinh -tf_impl[lax.asin_p] = tf.math.asin -tf_impl[lax.acos_p] = tf.math.acos - -tf_impl[lax.sqrt_p] = tf.math.sqrt -tf_impl[lax.square_p] = tf.math.square -tf_impl[lax.rsqrt_p] = tf.math.rsqrt - -def _cbrt(x): - return tf.math.sign(x) * tf.math.pow(tf.math.abs(x), 1/3) - -tf_impl[lax.cbrt_p] = _cbrt - -tf_impl[lax.lgamma_p] = tf.math.lgamma -tf_impl[lax.digamma_p] = tf.math.digamma -tf_impl[lax.igamma_p] = tf.math.igamma -tf_impl[lax.igammac_p] = tf.math.igammac -tf_impl[lax.regularized_incomplete_beta_p] = tf.math.betainc -tf_impl[lax.erf_p] = tf.math.erf -tf_impl[lax.erfc_p] = tf.math.erfc -tf_impl[lax.erf_inv_p] = tf.math.erfinv -tf_impl[lax.bessel_i0e_p] = tf.math.bessel_i0e -tf_impl[lax.bessel_i1e_p] = tf.math.bessel_i1e - -tf_impl[lax.complex_p] = tf.complex - - -def _conj(x, **kwargs): - # The only dtypes that are allowed are: float32, float64, complex64, and - # complex128. - if x.dtype == tf.float32: - return tf.cast(x, tf.complex64) - elif x.dtype == tf.float64: - return tf.cast(x, tf.complex128) - else: - return tf.math.conj(x) - - -tf_impl[lax.conj_p] = _conj -tf_impl[lax.real_p] = tf.math.real -tf_impl[lax.imag_p] = tf.math.imag - -tf_impl[lax.add_p] = _add -tf_impl[lax.sub_p] = tf.math.subtract -tf_impl[lax.mul_p] = tf.math.multiply - - -def _iota(*, dtype, shape, dimension, sharding=None): - dtype = _to_tf_dtype(dtype) - # Some dtypes are unsupported, like uint32, so we just fall back to int32. - # TODO(mattjj, necula): improve tf.range dtype handling - shape_tf = _eval_shape(shape) - vec = tf.range(tf.cast(shape_tf[dimension], tf.int32), dtype=tf.int32) - vec_shape = [-1 if i == dimension else 1 for i in range(len(shape))] - return tf.cast(tf.broadcast_to(tf.reshape(vec, vec_shape), shape_tf), dtype) - - -tf_impl[lax.iota_p] = _iota - - -def _div(lhs, rhs): - if lhs.dtype.is_integer: - quotient = tf.math.floordiv(lhs, rhs) - select = tf.math.logical_and( - tf.not_equal(_sign(lhs), _sign(rhs)), - tf.not_equal(tf.math.floormod(lhs, rhs), 0)) - return tf.where(select, quotient + 1, quotient) - else: - return tf.math.truediv(lhs, rhs) - - -def _rem(lhs, rhs): - return _sign(lhs) * tf.math.floormod(_abs(lhs), _abs(rhs)) - - -tf_impl[lax.div_p] = _div -tf_impl[lax.rem_p] = _rem - - -def _minmax(x: TfVal, y: TfVal, *, is_min: bool, - _in_avals: Sequence[core.ShapedArray], - _out_aval: core.ShapedArray,) -> TfVal: - # For complex numbers use lexicographic ordering, like JAX - if dtypes.issubdtype(x.dtype.as_numpy_dtype, np.complexfloating): - return _convert_jax_impl( - partial(lax_internal._minmax_complex_lowering, - lax_cmp_pick_x=lax.lt if is_min else lax.gt), - multiple_results=False)(x, y, _in_avals=_in_avals, _out_aval=_out_aval) - elif x.dtype.as_numpy_dtype == np.bool_: - return (tf.math.logical_and if is_min else tf.math.logical_or)(x, y) - else: - return (tf.math.minimum if is_min else tf.math.maximum)(x, y) - -def _minmax_scalar(x: TfVal, y: TfVal, *, is_min: bool) -> TfVal: - # For reducers we will need min/max for scalars only. In that case we - # can construct the AbstractValues ourselves, even in the presence of - # shape polymorphism. - assert len(x.shape) == 0 and len(y.shape) == 0, f"x: {x.shape}, y: {y.shape}" - aval = core.ShapedArray((), _to_jax_dtype(x.dtype)) - return _minmax(x, y, is_min=is_min, - _in_avals=[aval, aval], _out_aval=aval) - -tf_impl_with_avals[lax.max_p] = partial(_minmax, is_min=False) -tf_impl_with_avals[lax.min_p] = partial(_minmax, is_min=True) - -# Map from TF signed types to TF unsigned types. -_SIGNED_TO_UNSIGNED_TABLE = { - tf.int8: tf.uint8, - tf.int16: tf.uint16, - tf.int32: tf.uint32, - tf.int64: tf.uint64, -} - -# Map from TF unsigned types to TF signed types. -_UNSIGNED_TO_SIGNED_TABLE = {u: s for s, u in _SIGNED_TO_UNSIGNED_TABLE.items()} - - -# Note: Bitwise operations only yield identical results on unsigned integers! -# pylint: disable=protected-access -def _shift_right_arithmetic_raw(x, y): - if x.dtype.is_unsigned: - assert x.dtype == y.dtype - orig_dtype = x.dtype - signed_dtype = _UNSIGNED_TO_SIGNED_TABLE[orig_dtype] - x = tf.cast(x, signed_dtype) - y = tf.cast(y, signed_dtype) - res = tf.bitwise.right_shift(x, y) - return tf.cast(res, orig_dtype) - else: - return tf.bitwise.right_shift(x, y) - - -def _shift_right_arithmetic(x, y): - # TF shift is "implementation defined" if the shift amount is negative - # or larger or equal to the size of the value. We implement the XLA - # semantics to return the shift by the max value (x_bits - 1). - # TODO: it is likely better to add XlaOps for shifts - x_bits = 8 * x.dtype.size - clamp_y = tf.where(_shift_in_bounds(x, y), y, x_bits - 1) - return _shift_right_arithmetic_raw(x, clamp_y) - - -tf_impl[lax.shift_right_arithmetic_p] = _shift_right_arithmetic - - -def _shift_right_logical_raw(x, y): - if x.dtype.is_unsigned: - return tf.bitwise.right_shift(x, y) - else: - assert x.dtype == y.dtype - orig_dtype = x.dtype - unsigned_dtype = _SIGNED_TO_UNSIGNED_TABLE[orig_dtype] - x = tf.cast(x, unsigned_dtype) - y = tf.cast(y, unsigned_dtype) - res = tf.bitwise.right_shift(x, y) - return tf.cast(res, orig_dtype) - - -def _shift_right_logical(x, y): - # TF shift is "implementation defined" if the shift amount is negative - # or larger or equal to the size of the value. We implement the XLA semantics - # to return 0. - # TODO: it is likely better to add XlaOps for shifts - return tf.where( - _shift_in_bounds(x, y), _shift_right_logical_raw(x, y), tf.zeros_like(x)) - - -tf_impl[lax.shift_right_logical_p] = _shift_right_logical - - -def _shift_left(x, y): - # TF shift is "implementation defined" if the shift amount is negative - # or larger or equal to the size of the value. We implement the XLA semantics - # to return 0. - # TODO: it is likely better to add XlaOps for shifts - return tf.where( - _shift_in_bounds(x, y), tf.bitwise.left_shift(x, y), tf.zeros_like(x)) - - -tf_impl[lax.shift_left_p] = _shift_left - - -def _shift_in_bounds(x: TfVal, y: TfVal) -> TfVal: - # Return the TF expression for when y is within bounds (0 <= y < |x|) - x_bits = 8 * x.dtype.size - # TF does not have comparisons for uint16 and uint32 (despite what the - # documentation says) - y_comp = tf.cast( - y, _UNSIGNED_TO_SIGNED_TABLE[y.dtype]) if y.dtype.is_unsigned else y - y_lt_x_bits = tf.math.less(y_comp, x_bits) - y_ge_0 = tf.math.greater_equal(y_comp, 0) - return tf.logical_and(y_lt_x_bits, y_ge_0) - - -def _not(x): - """Computes bitwise not with support for booleans. - - Numpy and JAX support bitwise not for booleans by applying a logical not! - This means that applying bitwise_not yields an unexpected result: - jnp.bitwise_not(jnp.array([True, False])) - >> Array([False, True], dtype=bool) - - if you assume that booleans are simply casted to integers. - jnp.bitwise_not(jnp.array([True, False]).astype(np.int32)).astype(bool) - >> Array([True, True], dtype=bool) - """ - if x.dtype == tf.bool: - return tf.logical_not(x) - else: - return tf.bitwise.invert(x) - - -tf_impl[lax.not_p] = _not - - -def handle_boolean_args(f, argnums: Sequence[int], boolean_f=None): - """Computes functions with some bool args and bool results using int8. - - This is needed because some TF ops do not work for bool args, e.g., - inequalities, min/max. - - Args: - f: a TF callable to wrap. It will be called with non-boolean arguments. - argnums: the positional arguments that may be booleans. - boolean_f: [Optional] a TF callable compatible with boolean - arguments. - - Returns: a TF callable that can take a mix of boolean positional arguments - (in the positions specified by `argnums`) and some non-boolean positional - arguments. If there are no boolean arguments, just calls `f`. Otherwise, - it calls `boolean_f` if defined. Otherwise, casts the boolean - arguments to `int8`, calls `f`, then casts the result to `bool`. - """ - argnums = tf.nest.flatten(argnums) - - def wrapper(*args: TfVal, **kwargs): - argnum_types = {args[i].dtype for i in argnums} - if tf.bool not in argnum_types: - return f(*args, **kwargs) - else: - # All argnums should be boolean - assert len(argnum_types) == 1, argnum_types - if boolean_f != None: - return boolean_f(*args, **kwargs) - else: - args_cast = [(tf.cast(a, tf.int8) if i in argnums else a) - for i, a in enumerate(args)] - if "_in_avals" in kwargs: - - def cast_aval(aval): - assert aval.dtype == np.bool_ - return core.ShapedArray(aval.shape, np.int8) - - _in_avals_cast = [ - cast_aval(aval) if i in argnums else aval - for i, aval in enumerate(kwargs["_in_avals"]) - ] - _out_aval_cast = tf.nest.map_structure(cast_aval, kwargs["_out_aval"]) - kwargs = dict( - kwargs, _in_avals=_in_avals_cast, _out_aval=_out_aval_cast) - out = f(*args_cast, **kwargs) - return tf.nest.map_structure(lambda o: tf.cast(o, tf.bool), out) - - return wrapper - - -tf_impl[lax.or_p] = handle_boolean_args(tf.bitwise.bitwise_or, argnums=(0, 1), boolean_f=tf.logical_or) -tf_impl[lax.and_p] = handle_boolean_args(tf.bitwise.bitwise_and, argnums=(0, 1), boolean_f=tf.logical_and) -tf_impl[lax.xor_p] = handle_boolean_args(tf.bitwise.bitwise_xor, argnums=(0, 1), boolean_f=tf.math.logical_xor) - -tf_impl[lax.eq_p] = tf.math.equal -tf_impl[lax.ne_p] = tf.math.not_equal - - -def _total_order_adjustment(x): - if not dtypes.issubdtype(x.dtype.as_numpy_dtype, np.inexact): - return x - assert dtypes.issubdtype(x.dtype.as_numpy_dtype, np.floating) - # Switch from a floating point value to a integer value in such a way that - # when using the integer value to compare, we get the same result for normal - # values, and -nan is treated as the smallest value, and nan is treated as - # the largest value. - # If f is a float, and - # x = bit_cast(f); - # y = x < 0 ? int32_max - x : x; - # then y is ordered as an int32 such that finite values have the obvious - # order. In this scheme, -0 would be before 0, and -NaN and NaN appear at - # the beginning and end of the ordering. - nbits = dtypes.finfo(x.dtype.as_numpy_dtype).bits - signed_dtype = lax_internal._INT_DTYPES[nbits] - unsigned_dtype = lax_internal._UINT_DTYPES[nbits] - - signed = tf.bitcast(x, signed_dtype) - sign_mask = tf.bitcast(tf.bitwise.right_shift(signed, nbits - 1), unsigned_dtype) - sign_magnitude_mask = tf.bitcast(tf.bitwise.right_shift(sign_mask, 1), signed_dtype) - return tf.bitwise.bitwise_xor(signed, sign_magnitude_mask) - -def _total_order_equal(x, y): - if dtypes.issubdtype(x.dtype.as_numpy_dtype, np.complexfloating): - return _total_order_equal(tf.math.real(x), tf.math.real(y)) and _total_order_equal(tf.math.imag(x), tf.math.imag(y)) - return tf.math.equal(_total_order_adjustment(x), _total_order_adjustment(y)) - -tf_impl[lax.eq_to_p] = _total_order_equal - -boolean_greater = lambda x,y: tf.logical_and(x, tf.logical_not(y)) # Only one combo: T,F -> T -boolean_less = lambda x,y: tf.logical_and(tf.logical_not(x), y) # Only one combo: F,T -> T -boolean_greater_or_equal = lambda x, y: tf.logical_not(boolean_less(x,y)) # All cases except F,T -boolean_less_or_equal = lambda x, y: tf.logical_not(boolean_greater(x,y)) # All cases except T,F - -tf_impl[lax.gt_p] = handle_boolean_args(tf.math.greater, argnums=(0, 1), boolean_f=boolean_greater) -tf_impl[lax.lt_p] = handle_boolean_args(tf.math.less, argnums=(0, 1), boolean_f=boolean_less) -tf_impl[lax.ge_p] = handle_boolean_args(tf.math.greater_equal, argnums=(0, 1), boolean_f=boolean_greater_or_equal) -tf_impl[lax.le_p] = handle_boolean_args(tf.math.less_equal, argnums=(0, 1), boolean_f=boolean_less_or_equal) - -def _total_order_cond(cond, x, y): - return cond(_total_order_adjustment(x), _total_order_adjustment(y)) - -tf_impl[lax.lt_to_p] = handle_boolean_args(partial(_total_order_cond, tf.math.less), argnums=(0, 1), boolean_f=boolean_less) -tf_impl[lax.le_to_p] = handle_boolean_args(partial(_total_order_cond, tf.math.less_equal), argnums=(0, 1), boolean_f=boolean_less_or_equal) - -tf_impl[lax.linalg.cholesky_p] = tf.linalg.cholesky - - -def _convert_element_type(operand, *, new_dtype, weak_type=False, sharding=None): - old_dtype = operand.dtype.as_numpy_dtype - if (dtypes.issubdtype(old_dtype, np.complexfloating) and - not dtypes.issubdtype(new_dtype, np.complexfloating)): - operand = tf.math.real(operand) - if (dtypes.issubdtype(old_dtype, np.floating) and - not (dtypes.issubdtype(new_dtype, np.floating) or dtypes.issubdtype( - new_dtype, np.complexfloating) or new_dtype == np.bool_)): - sign = _sign(operand) - operand = sign * tf.math.floor(sign * operand) - return tf.dtypes.cast(operand, _to_tf_dtype(new_dtype)) - - -tf_impl[lax.convert_element_type_p] = _convert_element_type - - -def _bitcast_convert_type(operand, new_dtype): - if operand.dtype == new_dtype: - return operand - return tf.bitcast(operand, _to_tf_dtype(new_dtype)) - - -tf_impl[lax.bitcast_convert_type_p] = _bitcast_convert_type - - -def _clamp(minval, operand, maxval, *, _in_avals, _out_aval): - # The below permits mirroring the behavior of JAX when maxval < minval - op_shape_tf_val = _eval_shape(_in_avals[1].shape, _in_avals[1].dtype) - maxval = tf.broadcast_to(maxval, op_shape_tf_val) - minval = tf.math.minimum(tf.broadcast_to(minval, op_shape_tf_val), maxval) - return tf.clip_by_value(operand, minval, maxval) - - -tf_impl_with_avals[lax.clamp_p] = _clamp - - -def _concatenate(*operands, dimension): - return tf.concat(operands, axis=tf.cast(dimension, tf.int32)) - - -tf_impl[lax.concatenate_p] = _concatenate - - -def _split(operand, *, sizes, axis): - return tf.split(operand, _eval_shape(sizes), axis=axis) - -tf_impl[lax.split_p] = _split - - -def _conv_general_dimension_numbers_proto(dimension_numbers): - """Converts a ConvDimensionNumbers to an XLA ConvolutionDimensionNumbers.""" - assert isinstance(dimension_numbers, lax.ConvDimensionNumbers) - lhs_spec, rhs_spec, out_spec = dimension_numbers - proto = xla_data_pb2.ConvolutionDimensionNumbers() - proto.input_batch_dimension = lhs_spec[0] - proto.input_feature_dimension = lhs_spec[1] - proto.output_batch_dimension = out_spec[0] - proto.output_feature_dimension = out_spec[1] - proto.kernel_output_feature_dimension = rhs_spec[0] - proto.kernel_input_feature_dimension = rhs_spec[1] - proto.input_spatial_dimensions.extend(lhs_spec[2:]) - proto.kernel_spatial_dimensions.extend(rhs_spec[2:]) - proto.output_spatial_dimensions.extend(out_spec[2:]) - return proto - - -def _precision_config_proto(precision: None | (tuple[PrecisionType, - PrecisionType])): - """Convert an integer to an XLA.PrecisionConfig.""" - if precision is None: - return None - - proto = xla_data_pb2.PrecisionConfig() - proto.operand_precision.append(precision[0].value) - proto.operand_precision.append(precision[1].value) - return proto - - -def _conv_general_dilated(lhs, rhs, *, - window_strides, padding, lhs_dilation, - rhs_dilation, - dimension_numbers: lax.ConvDimensionNumbers, - feature_group_count: int, - batch_group_count: int, - precision: tuple[PrecisionType, PrecisionType] | None, - preferred_element_type: DType | None, - _in_avals: Sequence[core.ShapedArray], - _out_aval: core.ShapedArray): - """Implementation of lax.conv_general_dilated_p using XlaConv.""" - out_tf_shape = _aval_to_tf_shape(_out_aval) - dnums_proto = _conv_general_dimension_numbers_proto(dimension_numbers) - precision_config_proto = _precision_config_proto(precision) - - def gen_conv(lhs, rhs, preferred_element_type: DType | None): - tf_version = tuple(int(v) for v in tf.__version__.split(".")[:2]) - if tf_version >= (2, 8): - # TODO(necula): remove when 2.8.0 is the stable TF version (and supports - # batch_group_count. - padding_tf = [_eval_shape(p) for p in padding] - out = tfxla.conv( - lhs, rhs, window_strides, padding_tf, lhs_dilation, rhs_dilation, - dnums_proto, - feature_group_count=feature_group_count, - batch_group_count=batch_group_count, - precision_config=precision_config_proto, - preferred_element_type=preferred_element_type, - use_v2=True) - else: - if batch_group_count != 1: - raise ValueError( - "The batch_group_count parameter for conv requires TF version " - "at least 2.8.0. You may want to use tf-nightly.") - padding_tf = [_eval_shape(p) for p in padding] - out = tfxla.conv( - lhs, rhs, window_strides, padding_tf, lhs_dilation, rhs_dilation, - dnums_proto, - feature_group_count=feature_group_count, - precision_config=precision_config_proto, - preferred_element_type=preferred_element_type, - use_v2=True) - # TODO: implement shape inference for XlaConv - out = _ensure_tf_shape_if_dynamic(out, out_tf_shape) - if _WRAP_JAX_JIT_WITH_TF_FUNCTION: - out = tf.stop_gradient(out) # See #7839 - return out - - # Follow the lowering for complex convolutions from - # lax._conv_general_dilated_translation. We can use the same conversion on all - # platforms because on XLA:TPU the compiler does the same as a rewrite. - preferred_float_et: Any | None - if np.issubdtype(_in_avals[0].dtype, np.complexfloating): - if preferred_element_type is not None: - # Convert complex dtype to types used for real and imaginary parts - assert np.issubdtype(preferred_element_type, np.complexfloating) - preferred_float_et = ( - np.float64 if preferred_element_type == np.complex128 else np.float32) - else: - preferred_float_et = None - lhs_real, lhs_imag = tf.math.real(lhs), tf.math.imag(lhs) - rhs_real, rhs_imag = tf.math.real(rhs), tf.math.imag(rhs) - k1 = gen_conv(_add(lhs_real, lhs_imag), rhs_real, preferred_float_et) - k2 = gen_conv(lhs_real, tf.math.subtract(rhs_imag, rhs_real), - preferred_float_et) - k3 = gen_conv(lhs_imag, _add(rhs_real, rhs_imag), preferred_float_et) - return tf.complex(tf.math.subtract(k1, k3), _add(k1, k2)) - else: - return gen_conv(lhs, rhs, preferred_element_type) - - -tf_impl_with_avals[lax.conv_general_dilated_p] = _conv_general_dilated - - -def _dot_general(lhs, rhs, *, dimension_numbers, - precision: lax_internal.CanonicalPrecision, - preferred_element_type: DType | None, - out_sharding=None, - _in_avals: Sequence[core.ShapedArray], - _out_aval: core.ShapedArray): - """Implementation of lax.dot_general_p in terms of tf.linalg.einsum.""" - # TODO(b/293247337): we ought to turn on this safety check, but this leads to - # failures. Since we are going to turn on native serialization soon, wait - # until then to turn on this check. - # lhs_aval, rhs_aval = _in_avals - # if lhs_aval.dtype != rhs_aval.dtype: - # # There are multiple kinds of errors: handling jnp.bfloat16 in xla.py and - # # returning different result dtype than JAX expects for various combinations - # # of types. We ought to implement the same workarounds as in the - # # native dot_general lowering rules, but this is not a high priority now - # # that we deprecate non-native serialization. - # raise NotImplementedError( - # "dot_general with different lhs_dtype and rhs_dtype is not supported " - # "in non-native serialization") - - if precision == lax.DotAlgorithmPreset.DEFAULT: - precision = None - if precision is not None and not (isinstance(precision, tuple) and - len(precision) == 2): - raise NotImplementedError( - f"Unsupported precision in dot_general: {precision}") - - lhs, rhs, convert_result = _dot_general_convert_to_common_dtype( - lhs, _in_avals[0], rhs, _in_avals[1], _out_aval) - - (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers - dnums_proto = xla_data_pb2.DotDimensionNumbers() - dnums_proto.lhs_contracting_dimensions.extend(lhs_contracting) - dnums_proto.rhs_contracting_dimensions.extend(rhs_contracting) - dnums_proto.lhs_batch_dimensions.extend(lhs_batch) - dnums_proto.rhs_batch_dimensions.extend(rhs_batch) - precision_config_proto = _precision_config_proto(precision) # type: ignore - res = tfxla.dot_general( - lhs, - rhs, - dnums_proto, - precision_config_proto, - preferred_element_type=preferred_element_type, - use_v2=True) - res = convert_result(res) - if _WRAP_JAX_JIT_WITH_TF_FUNCTION: - res = tf.stop_gradient(res) # See #7839 - return res - - -tf_impl_with_avals[lax.dot_general_p] = _dot_general - -def _dot_general_convert_to_common_dtype( - lhs: TfVal, lhs_aval: core.ShapedArray, - rhs: TfVal, rhs_aval: core.ShapedArray, - out_aval: core.ShapedArray) -> tuple[TfVal, TfVal, Callable[[TfVal], TfVal]]: - # Returns the converted lhs, rhs, and the converter for the result. - # tfxla.dot_general does not handle arguments of different types. - # We convert the arguments and the result. - # Use native serialization for a more JAX-native behavior. - if lhs_aval.dtype != rhs_aval.dtype: - - common_dtype = dtypes.result_type(lhs_aval, rhs_aval) - if common_dtype != lhs_aval.dtype: - lhs = _convert_element_type(lhs, new_dtype=common_dtype) - if common_dtype != rhs_aval.dtype: - rhs = _convert_element_type(rhs, new_dtype=common_dtype) - convert_result = lambda res: _convert_element_type(res, new_dtype=out_aval.dtype) - else: - convert_result = lambda res: res - return (lhs, rhs, convert_result) - -def _broadcast_in_dim(operand, *, shape, broadcast_dimensions, sharding=None, - _in_avals: Sequence[core.ShapedArray], - _out_aval: core.ShapedArray): - # for i in range(len(operand.shape)): - # result.shape[bcast_dims[i]] <- operand.shape[i] - # bcast_dims must be strictly increasing. - # len(bcast_dims) == len(operand.shape) - op_shape = _in_avals[0].shape - dtype = _in_avals[0].dtype - add_1s_shape = [1] * len(shape) - for i, broadcast_dim_i in enumerate(broadcast_dimensions): - add_1s_shape[broadcast_dim_i] = op_shape[i] - with_1s = tf.reshape(operand, _eval_shape(add_1s_shape, dtype=dtype)) - return tf.broadcast_to(with_1s, _eval_shape(shape, dtype=dtype)) - - -tf_impl_with_avals[lax.broadcast_in_dim_p] = _broadcast_in_dim - - -def _empty(*, dtype): - if dtypes.issubdtype(dtype, dtypes.extended): - raise NotImplementedError # TODO(frostig,mattjj): jax2tf handlers - return tf.constant(np.array(0, dtype=dtype)) - - -tf_impl[lax_internal.empty_p] = _empty - - -def _reshape(operand, *, new_sizes, dimensions, sharding, _in_avals, _out_aval): - if dimensions is None: - dimensions = tf.range(tf.rank(operand)) - new_sizes_tf = _eval_shape(new_sizes, _in_avals[0].dtype) - return tf.reshape(tf.transpose(operand, dimensions), new_sizes_tf) - - -tf_impl_with_avals[lax.reshape_p] = _reshape - - -def _squeeze(operand, *, dimensions, _in_avals, _out_aval): - op_aval = _jax_physical_aval(_in_avals[0]) - op_shape = op_aval.shape - new_shape = tuple(d for i, d in enumerate(op_shape) if i not in dimensions) - new_shape_tf = _eval_shape(new_shape, op_aval.dtype) - return tf.reshape(operand, new_shape_tf) - - -tf_impl_with_avals[lax.squeeze_p] = _squeeze - - -def _pad(operand, padding_value, *, padding_config, - _in_avals: Sequence[core.ShapedArray], - _out_aval: core.ShapedArray): - low, high, interior = util.unzip3(map(_eval_shape, padding_config)) # type: ignore - out = tfxla.pad(operand, padding_value, low, high, interior) - # TODO: implement shape inference for XlaPad (when some padding_config is constant) - out = _ensure_tf_shape_if_dynamic(out, _aval_to_tf_shape(_out_aval)) - if _WRAP_JAX_JIT_WITH_TF_FUNCTION: - out = tf.stop_gradient(out) # See #7839 - return out - - -tf_impl_with_avals[lax.pad_p] = _pad - - -def _rev(operand, *, dimensions): - return tf.reverse(operand, dimensions) - - -tf_impl[lax.rev_p] = _rev - - -def _where(which, *cases): - if which.dtype == tf.bool: - assert len(cases) <= 2 - return cases if len(cases) == 1 else tf.where(which, cases[1], cases[0]) - - def _select(offset, cases): - assert len(cases) > 0 - if len(cases) == 1: - return cases[0] - mid = len(cases) // 2 - return tf.where(tf.less(which, offset + mid), - _select(offset, cases[:mid]), - _select(mid, cases[mid:])) - - return _select(0, cases) - - -tf_impl[lax.select_n_p] = _where - - -def _transpose(operand, *, permutation): - return tf.transpose(operand, perm=permutation) - - -tf_impl[lax.transpose_p] = _transpose - -axes_to_axis = lambda func: lambda operand, axes: func(operand, axis=axes) - -# reduce_sum and reduce_prod are not supported for bool -tf_impl[lax.reduce_sum_p] = axes_to_axis(tf.reduce_sum) -tf_impl[lax.reduce_prod_p] = axes_to_axis(tf.reduce_prod) -tf_impl[lax.reduce_max_p] = handle_boolean_args( - axes_to_axis(tf.reduce_max), argnums=[0], - boolean_f=axes_to_axis(tf.reduce_any)) # Max is T if any one is T -tf_impl[lax.reduce_min_p] = handle_boolean_args( - axes_to_axis(tf.reduce_min), argnums=[0], - boolean_f=axes_to_axis(tf.reduce_all)) # Min is F if not all are T -tf_impl[lax.reduce_or_p] = axes_to_axis(tf.reduce_any) -tf_impl[lax.reduce_and_p] = axes_to_axis(tf.reduce_all) - - -def _argminmax(is_min: bool, operand: TfVal, axes: Sequence[int], - index_dtype: DType, - _in_avals: Sequence[core.ShapedArray], - _out_aval: core.ShapedArray): - # Follow the JAX implementation, using a XlaReduce with a custom comparator - if is_min: - extra_name_stack = "argmin" - value_comparator = lax.lt - get_identity = lax_internal._get_min_identity - else: - extra_name_stack = "argmax" - value_comparator = lax.gt - get_identity = lax_internal._get_max_identity - - res = _convert_jax_impl( - partial(lax_internal._compute_argminmax, value_comparator, get_identity), - multiple_results=False, - extra_name_stack=extra_name_stack)( - operand, - index_dtype=index_dtype, - axes=axes, - _in_avals=_in_avals, - _out_aval=_out_aval) - return res - - -tf_impl_with_avals[lax.argmin_p] = partial(_argminmax, True) -tf_impl_with_avals[lax.argmax_p] = partial(_argminmax, False) - - -_add_fn = tf.function(_add, autograph=False) -_ge_fn = tf.function(tf.math.greater_equal, autograph=False) - - -def _select_and_gather_add( - tangents: TfVal, operand: TfVal, select_prim: core.Primitive, - window_dimensions: Sequence[int], window_strides: Sequence[int], - base_dilation: Sequence[int], window_dilation: Sequence[int], - padding: Sequence[tuple[int, int]], _in_avals: Sequence[core.ShapedArray], - _out_aval: core.ShapedArray): - # Note: this function follows the pattern in - # jax.lax._select_and_gather_add_translation. - dtype = operand.dtype - nbits = dtypes.finfo(dtype.as_numpy_dtype).bits - - # Specializing the function for 64 bits. Only up to 32 bits are supported on TPU, - # we thus intend to let the code throw a different exception on this platform. - max_bits = 64 - - assert nbits <= max_bits - double_word_reduction = nbits * 2 <= max_bits - - const = lambda dtype, x: tf.constant(np.array(x), dtype) - - if double_word_reduction: - word_dtype = lax_internal._UINT_DTYPES[nbits] - double_word_dtype = lax_internal._UINT_DTYPES[nbits * 2] - - # Packs two values into a tuple. - def pack(a, b): - a = _bitcast_convert_type(a, word_dtype) - b = _bitcast_convert_type(b, word_dtype) - a = _convert_element_type(a, new_dtype=double_word_dtype) - b = _convert_element_type(b, new_dtype=double_word_dtype) - a = tf.bitwise.left_shift(a, const(double_word_dtype, nbits)) - return tf.bitwise.bitwise_or(a, b) - - # Unpacks the first element of a tuple. - def fst(t): - assert t.dtype == double_word_dtype - st = _shift_right_logical(t, const(double_word_dtype, nbits)) - return _bitcast_convert_type( - _convert_element_type(st, new_dtype=word_dtype), dtype) - - # Unpacks the second element of a tuple. - def snd(t): - return _bitcast_convert_type( - _convert_element_type(t, new_dtype=word_dtype), dtype) - - else: - raise NotImplementedError( - f"TODO: need to pack {nbits * 2} bits but this platform can only go up to {max_bits} bits." - ) - - assert select_prim is lax.ge_p or select_prim is lax.le_p, select_prim - - def reducer(x, y): - which = tf_impl[select_prim] - return tf_impl[lax.select_n_p](which(fst(x), fst(y)), y, x) - - init = -np.inf if select_prim is lax.ge_p else np.inf - init_identity = lambda x: pack(const(dtype, init), const(dtype, 0)) - - out = _specialized_reduce_window( - reducer, - init_identity, - pack(operand, tangents), - window_dimensions=window_dimensions, - window_strides=window_strides, - padding=padding, - base_dilation=base_dilation, - window_dilation=window_dilation, - _in_avals=_in_avals, - _out_aval=_out_aval) - - return snd(out) - - -tf_impl_with_avals[lax.select_and_gather_add_p] = _select_and_gather_add - - -def _common_reduce_window(operand, init_val, reducer, window_dimensions, - window_strides, padding, base_dilation, - window_dilation, _in_avals, _out_aval): - o_spec = tf.TensorSpec((), dtype=operand.dtype) - reducer_fn = tf.function( - reducer, autograph=False).get_concrete_function(o_spec, o_spec) - - if not isinstance(init_val, (tf.Tensor, tf.Variable)): - init_val = tf.constant(init_val, operand.dtype) - window_dimensions_tf = _eval_shape(window_dimensions) - window_strides_tf = _eval_shape(window_strides) - window_dilation_tf = _eval_shape(window_dilation) - base_dilation_tf = _eval_shape(base_dilation) - padding_tf = [_eval_shape(p) for p in padding] - out = tfxla.reduce_window( - operand, - init_val, - reducer_fn, - window_dimensions_tf, - window_strides_tf, - base_dilations=base_dilation_tf, - window_dilations=window_dilation_tf, - padding=padding_tf) - # TODO: implement shape inference for XlaReduceWindow - out = _ensure_tf_shape_if_dynamic(out, _aval_to_tf_shape(_out_aval)) - if _WRAP_JAX_JIT_WITH_TF_FUNCTION: - out = tf.stop_gradient(out) # See #7839 - return out - - -def _reduce_window(*args, jaxpr, consts, window_dimensions, - window_strides, padding, base_dilation, window_dilation, - _in_avals, _out_aval): - """TensorFlow implementation of reduce_window. - - Args: - operands: N dimensional arrays containing elements of type T - init_values: starting values of the reduction - jaxpr: the jaxpr corresponding to the reduction function - consts: the constants associated with jaxpr. - window_dimensions: array of integers for window dimension values - window_strides: array of integers for window stride values - padding: array of pairs of integers for padding values - base_dilation: array of integers for base dilation values - window_dilation: array of integers for window dilation values - - Returns: - The reduced operand. - """ - assert len(consts) == 0, "Reduction computation cannot have constants" - operands, init_values = util.split_list(args, [len(args) // 2]) - - if len(operands) != 1: - raise NotImplementedError("jax2tf does not support variadic reduce_window") - - def reducer(arg1: TfVal, arg2: TfVal) -> TfVal: - closed_jaxpr = core.ClosedJaxpr(jaxpr, consts) - res, = _interpret_jaxpr(closed_jaxpr, arg1, arg2, extra_name_stack=None) - return res - - return (_common_reduce_window(operands[0], init_values[0], reducer, - window_dimensions, window_strides, padding, - base_dilation, window_dilation, _in_avals, - _out_aval[0]),) - - -def _specialized_reduce_window(reducer, - identity, - operand, - *, - window_dimensions, - window_strides, - padding, - base_dilation, - window_dilation, - _in_avals, - _out_aval, - name=None): - """Wraps the TensorFlow reduce window operation based on a reducer and an - - identity function defining the initial value of the reduction depending on - the dtype of the operand. - - Args: - reducer: reduction function of type TfVal -> TfVal -> TfVal - identity: function that takes a TensorFlow dtype as a parameter and returns - the starting value of the reduction. - operand: N dimensional array containing elements of type T - window_dimensions: array of integers for window dimension values - window_strides: array of integers for window stride values - padding: array of pairs of integers for padding values - base_dilation: array of integers for base dilation values - window_dilation: array of integers for window dilation values - name: the name of the specialized reduce window primitive for which this - conversion function is called. This information may help to choose a - different conversion path (optional) - - Returns: - The reduced operand. - """ - return _common_reduce_window(operand, identity(operand.dtype), reducer, - window_dimensions, window_strides, padding, - base_dilation, window_dilation, _in_avals, - _out_aval) - - -def _get_max_identity(tf_dtype): - numpy_tf_dtype = tf_dtype.as_numpy_dtype - if tf_dtype == tf.bfloat16 or dtypes.issubdtype(numpy_tf_dtype, np.inexact): - return numpy_tf_dtype(-np.inf) - elif dtypes.issubdtype(numpy_tf_dtype, np.integer): - return dtypes.iinfo(numpy_tf_dtype).min - else: - assert dtypes.issubdtype( - numpy_tf_dtype, np.bool_), (f"{tf_dtype} has no defined max identity") - return False - - -def _get_min_identity(tf_dtype): - numpy_tf_dtype = tf_dtype.as_numpy_dtype - if tf_dtype == tf.bfloat16 or dtypes.issubdtype(numpy_tf_dtype, np.inexact): - return numpy_tf_dtype(np.inf) - elif dtypes.issubdtype(numpy_tf_dtype, np.integer): - return dtypes.iinfo(numpy_tf_dtype).max - else: - assert dtypes.issubdtype( - numpy_tf_dtype, np.bool_), (f"{tf_dtype} has no defined min identity") - return True - - -# pylint: disable=protected-access -tf_impl_with_avals[lax.reduce_window_sum_p] = ( - partial(_specialized_reduce_window, _add, lambda x: 0, - name="reduce_window_sum")) -tf_impl_with_avals[lax.reduce_window_min_p] = ( - partial(_specialized_reduce_window, - partial(_minmax_scalar, is_min=True), - _get_min_identity, - name="reduce_window_min")) -tf_impl_with_avals[lax.reduce_window_max_p] = ( - partial(_specialized_reduce_window, - partial(_minmax_scalar, is_min=False), - _get_max_identity, - name="reduce_window_max")) -tf_impl_with_avals[lax.reduce_window_p] = _reduce_window -# pylint: enable=protected-access - -def _reduce(*operands: TfVal, - computation: Callable, - jaxpr: core.ClosedJaxpr, - dimensions: Sequence[int], - _in_avals: Sequence[core.ShapedArray], - _out_aval: core.ShapedArray) -> Sequence[TfVal]: - del computation - assert not jaxpr.consts - assert len(operands) % 2 == 0 - # operands: op1, op2, ..., init_val1, init_val2, ... - # reducer takes op1[i], op2[i], ..., init_val1, init_val2, ... - nr_operands = len(operands) // 2 - init_vals = operands[nr_operands:] - operands = operands[0:nr_operands] - - reducer_arg_spec = tuple([tf.TensorSpec((), op.dtype) for op in init_vals] * 2) - - def reducer_computation(*args: TfVal) -> TfVal: - res = _interpret_jaxpr(jaxpr, *args, extra_name_stack=None) - return res - - xla_reducer_computation = ( - tf.function(reducer_computation, - autograph=False).get_concrete_function(*reducer_arg_spec)) - - outs = tfxla.variadic_reduce(operands, init_vals, - dimensions_to_reduce=dimensions, - reducer=xla_reducer_computation) - if _WRAP_JAX_JIT_WITH_TF_FUNCTION: - outs = tuple(tf.stop_gradient(out) for out in outs) # See #7839 - return outs - -tf_impl_with_avals[lax.reduce_p] = _reduce - - -# We use lax.cumred_reduce_window_impl to convert cummax, -# cummin, cumsum and cumprod. This is efficient on TPU, but the complexity is -# O(n^2) on other backends. This may be implemented using associative_scan -# instead to favor different backends. -def _cumred(lax_reduce_fn: Callable, - lax_reduce_window_fn: Callable, - extra_name_stack: str): - associative_scan = partial(lax_control_flow.associative_scan, lax_reduce_fn) - reduce_window = partial( - lax_control_flow.cumred_reduce_window_impl, lax_reduce_window_fn - ) - - def _call_impl(*args, **kwargs): - # Vary which implementation to use when cumulation is called. This cannot be - # done during import time because the caller may later use a python context - # to switch the implementation to use. - associative = config.jax2tf_associative_scan_reductions.value - return (associative_scan if associative else reduce_window)(*args, **kwargs) - - return _convert_jax_impl( - _call_impl, multiple_results=False, extra_name_stack=extra_name_stack - ) - - -tf_impl_with_avals[lax.cummax_p] = _cumred( - lax_reduce_window_fn=lax_windowed_reductions._reduce_window_max, - lax_reduce_fn=lax.max, - extra_name_stack="cummax") -tf_impl_with_avals[lax.cummin_p] = _cumred( - lax_reduce_window_fn=lax_windowed_reductions._reduce_window_min, - lax_reduce_fn=lax.min, - extra_name_stack="cummin") -tf_impl_with_avals[lax.cumlogsumexp_p] = _cumred( - lax_reduce_window_fn=lax_windowed_reductions._reduce_window_logaddexp, - lax_reduce_fn=logaddexp, - extra_name_stack="cumlogsumexp") -tf_impl_with_avals[lax.cumsum_p] = _cumred( - lax_reduce_window_fn=lax_windowed_reductions._reduce_window_sum, - lax_reduce_fn=lax.add, - extra_name_stack="cumsum") -tf_impl_with_avals[lax.cumprod_p] = _cumred( - lax_reduce_window_fn=lax_windowed_reductions._reduce_window_prod, - lax_reduce_fn=lax.mul, - extra_name_stack="cumprod") - - -def _select_and_scatter(operand, source, init_value, select_jaxpr, - select_consts, scatter_jaxpr, scatter_consts, - window_dimensions, window_strides, padding): - raise NotImplementedError("TODO: jax2tf can not convert _select_and_scatter") - - -tf_impl[lax.select_and_scatter_p] = _select_and_scatter - - -@partial(handle_boolean_args, argnums=(0, 1)) -def _select_and_scatter_add(source, operand, *, select_prim, window_dimensions, - window_strides, padding, _in_avals, _out_aval): - init_value = tf.zeros((), operand.dtype) - select_fn = ( - tf.function(tf_impl[select_prim], autograph=False).get_concrete_function( - init_value, init_value)) - scatter_fn = _add_fn.get_concrete_function(init_value, init_value) - out = tfxla.select_and_scatter(operand, window_dimensions, window_strides, - padding, source, init_value, select_fn, - scatter_fn) - out = _ensure_tf_shape_if_dynamic(out, _aval_to_tf_shape(_out_aval)) - if _WRAP_JAX_JIT_WITH_TF_FUNCTION: - out = tf.stop_gradient(out) # See #7839 - return out - - -tf_impl_with_avals[lax.select_and_scatter_add_p] = _select_and_scatter_add - - -def _random_seed_impl(seeds: TfVal, *, impl, _in_avals, _out_aval): - - def impl_wrapper(seeds: TfVal, *, impl): - return prng.random_seed_impl_base(seeds, impl=impl) - - converted_impl = _convert_jax_impl( - impl_wrapper, multiple_results=False, with_physical_avals=True, - extra_name_stack="random_seed") - return converted_impl( - seeds, impl=impl, _in_avals=_in_avals, _out_aval=_out_aval) - -tf_impl_with_avals[prng.random_seed_p] = _random_seed_impl - - -def _random_split_impl(keys: TfVal, *, shape, _in_avals, _out_aval): - keys_aval, = _in_avals - - def impl_wrapper(keys: TfVal, *, shape): - return prng.random_split_impl_base( - keys_aval.dtype._impl, keys, keys_aval.ndim, shape=shape) - - converted_impl = _convert_jax_impl( - impl_wrapper, multiple_results=False, with_physical_avals=True, - extra_name_stack="random_split") - return converted_impl( - keys, shape=shape, _in_avals=_in_avals, _out_aval=_out_aval) - -tf_impl_with_avals[prng.random_split_p] = _random_split_impl - - -def _random_fold_in_impl(keys: TfVal, msgs: TfVal, *, _in_avals, _out_aval): - keys_aval, _ = _in_avals - - def impl_wrapper(keys: TfVal, msgs: TfVal): - return prng.random_fold_in_impl_base( - keys_aval.dtype._impl, keys, msgs, keys_aval.shape) - - converted_impl = _convert_jax_impl( - impl_wrapper, multiple_results=False, with_physical_avals=True, - extra_name_stack="random_fold_in") - return converted_impl( - keys, msgs, _in_avals=_in_avals, _out_aval=_out_aval) - -tf_impl_with_avals[prng.random_fold_in_p] = _random_fold_in_impl - - -def _random_bits_impl(keys: TfVal, *, bit_width, shape, _in_avals, _out_aval): - keys_aval, = _in_avals - - def impl_wrapper(keys: TfVal, **kwargs): - return prng.random_bits_impl_base( - keys_aval.dtype._impl, keys, keys_aval.ndim, - bit_width=bit_width, shape=shape) - - converted_impl = _convert_jax_impl( - impl_wrapper, multiple_results=False, with_physical_avals=True, - extra_name_stack="random_bits") - return converted_impl(keys, bit_width=bit_width, shape=shape, - _in_avals=_in_avals, _out_aval=_out_aval) - -tf_impl_with_avals[prng.random_bits_p] = _random_bits_impl - - -def _random_wrap_impl(base_arr: TfVal, *, impl, _in_avals, _out_aval): - return base_arr - -tf_impl_with_avals[prng.random_wrap_p] = _random_wrap_impl - - -def _random_unwrap_impl(keys: TfVal, *, _in_avals, _out_aval): - return keys - -tf_impl_with_avals[prng.random_unwrap_p] = _random_unwrap_impl - - -def _threefry2x32_jax_impl(*args: TfVal, _in_avals, _out_aval): - res = _convert_jax_impl( - partial(prng._threefry2x32_lowering, use_rolled_loops=False), - multiple_results=True, extra_name_stack="threefry")( - *args, _in_avals=_in_avals, _out_aval=_out_aval) - return res - - -tf_impl_with_avals[prng.threefry2x32_p] = _threefry2x32_jax_impl - -# Use the vmap implementation, otherwise on TPU the performance is really bad -# With use_vmap=True on, we get about the same performance for JAX and jax2tf. -tf_impl_with_avals[random.random_gamma_p] = _convert_jax_impl( - partial(random_internal._gamma_impl, use_vmap=True), - multiple_results=False, extra_name_stack="random_gamma") - - -def _rng_bit_generator(key: TfVal, *, shape, dtype, algorithm) -> Sequence[TfVal]: - is_uint32_key = key.dtype == _to_tf_dtype(jnp.uint32) - if is_uint32_key: - key = tf.reshape(key, (2, 2)) - key = tfxla.bitcast_convert_type(key, _to_tf_dtype(jnp.uint64)) - shape_tf = _eval_shape(shape) - # JAX uses XLA algorithm enums; tfxla uses tf.random.Algorithm - if algorithm == lax.RandomAlgorithm.RNG_THREE_FRY: - algorithm_tf = tf.random.Algorithm.THREEFRY - elif algorithm == lax.RandomAlgorithm.RNG_PHILOX: - algorithm_tf = tf.random.Algorithm.PHILOX - elif algorithm == lax.RandomAlgorithm.RNG_DEFAULT: - algorithm_tf = tf.random.Algorithm.AUTO_SELECT - else: - assert False - (new_key, res) = tfxla.rng_bit_generator(algorithm_tf.value, key, shape_tf, - dtype=_to_tf_dtype(dtype)) - if is_uint32_key: - new_key = tfxla.bitcast_convert_type(new_key, _to_tf_dtype(jnp.uint32)) - new_key = tf.reshape(new_key, (4,)) - if _WRAP_JAX_JIT_WITH_TF_FUNCTION: - # See #7839 - new_key = tf.stop_gradient(new_key) - res = tf.stop_gradient(res) - return new_key, res - - -tf_impl[lax.rng_bit_generator_p] = _rng_bit_generator - - -def _rng_uniform(minval: TfVal, maxval: TfVal, *, shape) -> TfVal: - shape_tf = _eval_shape(shape) - return tf.random.uniform(shape_tf, minval=minval, maxval=maxval, dtype=minval.dtype) - -tf_impl[lax.rng_uniform_p] = _rng_uniform - - -def _iota_2x32_shape(*, shape): - def _add(x, y): - return x + y - def _mul(x, y): - if not core.is_constant_dim(x): - x = tf.cast(_eval_shape((x,))[0], y.dtype) - x = tf.broadcast_to(x, tf.shape(y)) - return x * y - def _cast32(xs): return tf.cast(xs, _to_tf_dtype(jnp.uint32)) - iotas = [_iota(dtype=jnp.uint64, shape=shape, dimension=dimension) - for dimension in range(len(shape))] - counts = prng.bcast_iotas_to_reshaped_iota(_add, _mul, shape, iotas) - counts_lo = _cast32(counts) - counts_hi = _cast32(tf.bitwise.right_shift(counts, 32)) - return counts_hi, counts_lo - -tf_impl[prng.iota_2x32_shape_p] = _iota_2x32_shape - - -def _gather_dimensions_proto(indices_shape, dimension_numbers): - proto = xla_data_pb2.GatherDimensionNumbers() - proto.offset_dims.extend(dimension_numbers.offset_dims) - proto.collapsed_slice_dims.extend(dimension_numbers.collapsed_slice_dims) - proto.start_index_map.extend(dimension_numbers.start_index_map) - proto.operand_batching_dims.extend(dimension_numbers.operand_batching_dims) - proto.start_indices_batching_dims.extend( - dimension_numbers.start_indices_batching_dims) - assert indices_shape - proto.index_vector_dim = len(indices_shape) - 1 - return proto - - -def _maybe_cast_to_int64(x: TfVal) -> TfVal: - if x.dtype != tf.int32 and x.dtype != tf.int64: - return tf.cast(x, tf.int64) - return x - - -@partial(handle_boolean_args, argnums=[0]) -def _gather(operand, start_indices, *, dimension_numbers, slice_sizes: core.Shape, - indices_are_sorted, unique_indices, mode, fill_value, - _in_avals: Sequence[core.ShapedArray], - _out_aval: core.ShapedArray): - """Tensorflow implementation of gather.""" - if mode == lax.GatherScatterMode.FILL_OR_DROP: - gather_fill_fn = _convert_jax_impl(lax_slicing._gather_fill, - multiple_results=False) - return gather_fill_fn( - operand, start_indices, dimension_numbers=dimension_numbers, - slice_sizes=slice_sizes, unique_indices=unique_indices, - indices_are_sorted=indices_are_sorted, fill_value=fill_value, - output_shape=_out_aval.shape, _in_avals=_in_avals, _out_aval=_out_aval) - - operand_aval = _in_avals[0] - start_indices = _maybe_cast_to_int64(start_indices) - if dtypes.issubdtype(operand_aval.dtype, dtypes.extended): - opaque_shape = _jax_physical_aval(operand_aval).shape[len(operand_aval.shape):] - trailing_offset_dims = [len(_out_aval.shape) + i for i in range(len(opaque_shape))] - dimension_numbers = dimension_numbers._replace( - offset_dims=(*dimension_numbers.offset_dims, *trailing_offset_dims)) - slice_sizes = (*slice_sizes, *opaque_shape) - proto = _gather_dimensions_proto(start_indices.shape, dimension_numbers) - slice_sizes_tf = _eval_shape(slice_sizes) - out = tfxla.gather(operand, start_indices, proto, slice_sizes_tf, - indices_are_sorted) - out = _ensure_tf_shape_if_dynamic(out, _aval_to_tf_shape(_out_aval)) - if _WRAP_JAX_JIT_WITH_TF_FUNCTION: - out = tf.stop_gradient(out) # See #7839 - return out - - -tf_impl_with_avals[lax.gather_p] = _gather - - -def _slice(operand, start_indices, limit_indices, strides, _in_avals, - _out_aval): - if strides is None: - strides = [1] * len(start_indices) - slices = tuple( - map(slice, _eval_shape(start_indices), _eval_shape(limit_indices), - _eval_shape(strides))) - out = operand[slices] - # TODO(b/184503314): improve shape inference for __getitem__ - # E.g., operand.shape=(b, 5, 3), start_indices=(0, 1, 1), limit_indices=(b, 5, 3), strides=(1, 2, 1) - out = _ensure_tf_shape_if_dynamic(out, _aval_to_tf_shape(_out_aval)) - return out - - -tf_impl_with_avals[lax.slice_p] = _slice - - -def _dynamic_slice(operand, *start_indices, slice_sizes: core.Shape, - _in_avals: Sequence[core.ShapedArray], - _out_aval: core.ShapedArray): - start_indices = _maybe_cast_to_int64(tf.stack(start_indices)) - operand_aval = _in_avals[0] - if dtypes.issubdtype(operand_aval.dtype, dtypes.extended): - opaque_shape = _jax_physical_aval(operand_aval).shape[len(operand_aval.shape):] - slice_sizes = (*slice_sizes, *opaque_shape) - start_indices = tf.concat([start_indices, tf.zeros((len(opaque_shape),), - dtype=start_indices.dtype)], - axis=0) - - slice_sizes_tf = _eval_shape(slice_sizes) - res = tfxla.dynamic_slice(operand, start_indices, size_indices=slice_sizes_tf) - if _WRAP_JAX_JIT_WITH_TF_FUNCTION: - res = tf.stop_gradient(res) # See #7839 - return res - - -tf_impl_with_avals[lax.dynamic_slice_p] = _dynamic_slice - - -def _dynamic_update_slice(operand, update, *start_indices, - _in_avals: Sequence[core.ShapedArray], - _out_aval: core.ShapedArray): - start_indices = _maybe_cast_to_int64(tf.stack(start_indices)) - operand_aval = _in_avals[0] - if dtypes.issubdtype(operand_aval.dtype, dtypes.extended): - opaque_shape = _jax_physical_aval(operand_aval).shape[len(operand_aval.shape):] - start_indices = tf.concat([start_indices, tf.zeros((len(opaque_shape),), - dtype=start_indices.dtype)], - axis=0) - out = tfxla.dynamic_update_slice(operand, update, start_indices) - if _WRAP_JAX_JIT_WITH_TF_FUNCTION: - out = tf.stop_gradient(out) # See #7839 - return out - - -tf_impl_with_avals[lax.dynamic_update_slice_p] = _dynamic_update_slice - - -def _scatter_dimensions_proto(indices_shape, dimension_numbers): - proto = xla_data_pb2.ScatterDimensionNumbers() - proto.update_window_dims.extend(dimension_numbers.update_window_dims) - proto.inserted_window_dims.extend(dimension_numbers.inserted_window_dims) - proto.scatter_dims_to_operand_dims.extend( - dimension_numbers.scatter_dims_to_operand_dims) - proto.input_batching_dims.extend(dimension_numbers.operand_batching_dims) - proto.scatter_indices_batching_dims.extend( - dimension_numbers.scatter_indices_batching_dims) - assert indices_shape - proto.index_vector_dim = len(indices_shape) - 1 - return proto - - -_scatter_reduction_computation = lambda x, y: y - - -def _scatter(operand, scatter_indices, updates, *, update_jaxpr, update_consts, - dimension_numbers, indices_are_sorted, unique_indices, mode, - _in_avals: Sequence[core.ShapedArray], - _out_aval: core.ShapedArray): - del unique_indices - if update_jaxpr is None: - assert not update_consts - update_jaxpr, update_consts = lax_internal._reduction_jaxpr( - _scatter_reduction_computation, - core.ShapedArray((), operand.dtype.as_numpy_dtype)) - - if mode == lax.GatherScatterMode.CLIP: - clip_fn = _convert_jax_impl(lax_slicing._clamp_scatter_indices, - multiple_results=False) - scatter_indices = clip_fn( - operand, scatter_indices, updates, dnums=dimension_numbers, - _in_avals=_in_avals, _out_aval=_in_avals[1]) - - assert len(update_consts) == 0, "Update computation cannot have constants" - - proto = _scatter_dimensions_proto(scatter_indices.shape, dimension_numbers) - - def update_computation(arg1: TfVal, arg2: TfVal) -> TfVal: - closed_jaxpr = core.ClosedJaxpr(update_jaxpr, update_consts) - res, = _interpret_jaxpr(closed_jaxpr, arg1, arg2, extra_name_stack=None) - return res - - o_spec = tf.TensorSpec((), dtype=operand.dtype) - xla_update_computation = ( - tf.function(update_computation, - autograph=False).get_concrete_function(o_spec, o_spec)) - out = tfxla.scatter( - operand, - scatter_indices, - updates, - xla_update_computation, - proto, - indices_are_sorted=indices_are_sorted) - if _WRAP_JAX_JIT_WITH_TF_FUNCTION: - out = tf.stop_gradient(out) # See #7839 - return out - - -tf_impl_with_avals[lax.scatter_p] = _scatter -tf_impl_with_avals[lax.scatter_min_p] = _scatter -tf_impl_with_avals[lax.scatter_max_p] = _scatter -tf_impl_with_avals[lax.scatter_mul_p] = _scatter -tf_impl_with_avals[lax.scatter_add_p] = _scatter -tf_impl_with_avals[lax.scatter_sub_p] = _scatter - - -def _cond( - index: TfVal, *operands: TfVal, branches: Sequence[core.ClosedJaxpr] -) -> Sequence[TfVal]: - # tf.cond needs lambdas with no arguments. - branches_tf = [ - partial(_interpret_jaxpr, jaxpr, *operands, - # Same name stack as the XLA translation of cond_p - extra_name_stack=f"branch_{i}_fun") - for i, jaxpr in enumerate(branches) - ] - # Same name stack as XLA translation of cond_p - # Note: extend_name_stack is a contextmanager, which is callable as a decorator. - branches_tf = list(map(source_info_util.extend_name_stack("cond"), - branches_tf)) - if len(branches) == 2: - # `index` comes with tf.int32 type of casted boolean parameter. - return tf.cond(tf.cast(index, tf.bool), branches_tf[1], branches_tf[0]) - else: - return tf.switch_case(index, branches_tf) - - -tf_impl[lax.cond_p] = _cond - - -def _while(*args: TfVal, cond_nconsts: int, cond_jaxpr: core.ClosedJaxpr, - body_nconsts: int, body_jaxpr: core.ClosedJaxpr) -> Sequence[TfVal]: - cond_consts, body_consts, init_carry = util.split_list( - args, [cond_nconsts, body_nconsts]) - if cond_jaxpr.out_avals[0].shape: - # The conditional is not a scalar, this must be a batched while - return _batched_cond_while( - *args, - cond_nconsts=cond_nconsts, - cond_jaxpr=cond_jaxpr, - body_nconsts=body_nconsts, - body_jaxpr=body_jaxpr) - - # The conditional must return a single value to TF - def cond_tf_func(*args: TfVal) -> TfVal: - pred, = _interpret_jaxpr(cond_jaxpr, *cond_consts, *args, - # Same name stack as the XLA translation of while_p - extra_name_stack="while/cond") - return pred - - body_tf_func = partial(_interpret_jaxpr, body_jaxpr, *body_consts, - extra_name_stack="while/body") - # Sometimes TF infers more specific shapes for the init_carry, and this has - # led to errors: "enters the loop with shape (1,), but has shape (None,) after one iteration" - shape_invariants = [tf.TensorShape(_aval_to_tf_shape(_out_aval)) - for _out_aval in body_jaxpr.out_avals] - return tf.while_loop(cond_tf_func, body_tf_func, init_carry, - shape_invariants=shape_invariants) - - -def _batched_cond_while(*args: TfVal, cond_nconsts: int, - cond_jaxpr: core.ClosedJaxpr, body_nconsts: int, - body_jaxpr: core.ClosedJaxpr) -> Sequence[TfVal]: - """Interprets a while_loop with a batched condition. - - A batched while has a conditional that returns a tensor of booleans, and - a body that returns a list of tensors whose leading dimensions match those - of the conditional tensor. - - We need to turn it into a while with scalar boolean conditional. We will - expand the loop carry to include a prefix with the current tensor boolean - condition. We prepend to the loop the first calculation of the tensor boolean - condition. The loop condition will use a "reduce_any" to calculate a scalar - boolean from the tensor boolean condition. The end of the loop body will - compute the new carry using a "tf.where", and we compute the new tensor - boolean condition. - """ - cond_consts, body_consts, init_carry = util.split_list( - args, [cond_nconsts, body_nconsts]) - # Initial computation of batched condition - init_pred_b, = _interpret_jaxpr(cond_jaxpr, *cond_consts, *init_carry, - extra_name_stack="while/body_pred") - - def new_cond_tf_func(pred_b: TfVal, *carry: TfVal) -> TfVal: - pred = tf.reduce_any(pred_b, axis=list(range(len(pred_b.shape)))) - return pred - - def new_body_tf_func(pred_b: TfVal, *carry: TfVal) -> Sequence[TfVal]: - new_carry: Sequence[TfVal] = _interpret_jaxpr(body_jaxpr, *body_consts, - *carry, - extra_name_stack="while/body") - # We repeat those carries for which the loop termination condition is false - def select_one_carry(new_c: TfVal, c: TfVal, c_aval: core.ShapedArray) -> TfVal: - pred_b_bcast = _broadcast_in_dim( - pred_b, - shape=_jax_physical_aval(c_aval).shape, # a JAX shape - broadcast_dimensions=list(range(len(pred_b.shape))), - _in_avals=cond_jaxpr.out_avals, - _out_aval=core.ShapedArray(c_aval.shape, np.bool_)) - return tf.where(pred_b_bcast, new_c, c) - - selected_carry: Sequence[TfVal] = list(map(select_one_carry, new_carry, carry, body_jaxpr.out_avals)) - next_pred_b, = _interpret_jaxpr(cond_jaxpr, *cond_consts, *selected_carry, - extra_name_stack="body_pred") - return (next_pred_b, *selected_carry) - - _, *res_carry = tf.while_loop(new_cond_tf_func, new_body_tf_func, - (init_pred_b, *init_carry)) - return res_carry - - -tf_impl[lax.while_p] = _while - -# We use the scan impl rule to rewrite in terms of while. -tf_impl_with_avals[lax.scan_p] = _convert_jax_impl( - lax_control_flow._scan_impl, - extra_name_stack="scan") - -tf_impl_with_avals[ad_checkpoint.remat_p] = \ - _convert_jax_impl(partial(ad_checkpoint.remat_expansion, - # TODO: jax2tf cannot discriminate by platform - is_gpu_platform=False), - multiple_results=True, - extra_name_stack="checkpoint") - -tf_impl[ad_checkpoint.name_p] = lambda x, *, name: x - -# TODO: Remove once tensorflow is 2.10.0 everywhere. -if hasattr(tfxla, 'optimization_barrier'): - tf_impl[lax_control_flow.optimization_barrier_p] = tfxla.optimization_barrier - -def _top_k(operand: TfVal, k: int) -> tuple[TfVal, TfVal]: - # Some types originally incompatible with tf.math.top_k can be promoted - # to a compatible type without loss of precision. - def promote_tf_dtype(tf_dtype): - if tf_dtype in [tf.bool, tf.uint8, tf.uint16]: - return tf.uint32 - if tf_dtype in [tf.int8, tf.int16]: - return tf.int32 - if tf_dtype is tf.float16: - return tf.float32 - return None - - conversion_dtype = promote_tf_dtype(operand.dtype) - if not core.is_constant_dim(k): - k_tf = _eval_shape((k,))[0] - k_tf = tf.cast(k_tf, tf.int32) # TopK works only for int32 - else: - k_tf = k - if conversion_dtype: - values, indices = tf.math.top_k( - tf.dtypes.cast(operand, conversion_dtype), k=k_tf, sorted=True) - return tf.dtypes.cast(values, operand.dtype), indices - else: - return tf.math.top_k(operand, k=k_tf, sorted=True) - - -tf_impl[lax.top_k_p] = _top_k - - -def _approx_top_k(operand: TfVal, k: int, reduction_dimension: int, - recall_target: float, is_max_k: bool, - reduction_input_size_override: int, - aggregate_to_topk: bool) -> tuple[TfVal, TfVal]: - k_tf = _eval_shape((k,))[0] - if is_max_k: - return tf.math.approx_max_k(operand, k_tf, reduction_dimension, recall_target, - reduction_input_size_override, - aggregate_to_topk) - else: - return tf.math.approx_min_k(operand, k_tf, reduction_dimension, recall_target, - reduction_input_size_override, - aggregate_to_topk) - - -tf_impl[lax.approx_top_k_p] = _approx_top_k - - -def _sort(*operands: TfVal, dimension: int, is_stable: bool, - num_keys: int) -> tuple[TfVal, ...]: - assert 1 <= num_keys <= len(operands) - assert 0 <= dimension < len( - operands[0].shape - ), f"Invalid {dimension} for ndim {len(operands[0].shape)}" - - comparator_spec: list[tf.TensorSpec] = [] - comparator_jax_in_avals: list[core.ShapedArray] = [] - for op in operands: - o_spec = tf.TensorSpec((), dtype=op.dtype) - comparator_spec.extend([o_spec, o_spec]) - o_aval = core.ShapedArray((), _to_jax_dtype(op.dtype)) - comparator_jax_in_avals.extend([o_aval, o_aval]) - - # Use the same comparator that JAX uses when compiling to XLA, to get the - # proper NaN/Inf total order, and the lexicographic ordering. - # The comparator is a 2N-argument TF function, with arguments [2k] and [2k +1] - # corresponding to two scalars from operand[k]. - def lexicographic_comparator(*tf_args: TfVal) -> TfVal: - return _convert_jax_impl( - lax_internal._sort_lt_comparator, multiple_results=False)( - *tf_args, - _in_avals=comparator_jax_in_avals, - _out_aval=core.ShapedArray((), np.bool_), - num_keys=num_keys) - - xla_comparator_computation = ( - tf.function(lexicographic_comparator, - autograph=False).get_concrete_function(*comparator_spec)) - results = tfxla.variadic_sort( - operands, - dimension=dimension, - is_stable=is_stable, - comparator=xla_comparator_computation) - if _WRAP_JAX_JIT_WITH_TF_FUNCTION: - results = tuple(tf.stop_gradient(out) for out in results) # See #7839 - return results - - -tf_impl[lax.sort_p] = _sort - - -def _fft(x, *, fft_type, fft_lengths, - _in_avals: Sequence[core.ShapedArray], - _out_aval: core.ShapedArray): - FFT, IFFT, RFFT, IRFFT = list(map(lax.FftType, [0, 1, 2, 3])) - tf_funcs = { - FFT: [tf.signal.fft, tf.signal.fft2d, tf.signal.fft3d], - IFFT: [tf.signal.ifft, tf.signal.ifft2d, tf.signal.ifft3d], - RFFT: [tf.signal.rfft, tf.signal.rfft2d, tf.signal.rfft3d], - IRFFT: [tf.signal.irfft, tf.signal.irfft2d, tf.signal.irfft3d] - } - tf_func = tf_funcs[fft_type][len(fft_lengths) - 1] - if fft_type in (RFFT, IRFFT): - # https://www.tensorflow.org/api_docs/python/tf/signal/irfft - # Here we only set `fft_lengths` argument for non-default value. - (x_aval,) = _in_avals - x_shape = x_aval.shape - expected_lengths = x_shape[-len(fft_lengths) : -1] + ( - (x_shape[-1] - 1) * 2, - ) - if fft_lengths != expected_lengths: - tf_func = partial(tf_func, fft_length=_eval_shape(fft_lengths)) - res = tf_func(x) - return _ensure_tf_shape_if_dynamic(res, _aval_to_tf_shape(_out_aval)) -tf_impl_with_avals[lax.fft_p] = _fft - - -def _qr(operand, full_matrices): - return tf.linalg.qr(operand, full_matrices=full_matrices) - - -tf_impl[lax.linalg.qr_p] = _qr - - -def _svd( - operand: TfVal, - full_matrices: bool, - compute_uv: bool, - subset_by_index: tuple[int, int] | None = None, - algorithm: lax.linalg.SvdAlgorithm | None = None, -): - if not ( - subset_by_index is None - or subset_by_index == (0, min(operand.shape[-1], operand.shape[-2])) - ): - raise NotImplementedError("subset_by_index is not implemented") - - if algorithm is not None and algorithm != lax.linalg.SvdAlgorithm.DEFAULT: - raise NotImplementedError("SVD algorithm is not implemented") - - result = tf.linalg.svd(operand, full_matrices, compute_uv) - if not compute_uv: - return result, - s, u, v = result - return s, u, tf.linalg.adjoint(v) - - -tf_impl[lax.linalg.svd_p] = _svd - - -def _eig(operand: TfVal, compute_left_eigenvectors: bool, - compute_right_eigenvectors: bool): - if compute_left_eigenvectors and compute_right_eigenvectors: - # TODO(bchetioui): didn't find a 100% reliable, easy and satisfying way to - # sort the left eigenvectors in the right order. The jax.numpy.linalg API - # suggests to me that left eigenvectors are anyway seldom used, so I - # think it is acceptable to leave as unimplemented for now. - msg = ("Conversion of eig is not implemented when both " - "compute_left_eigenvectors and compute_right_eigenvectors are set " - "to True.") - raise NotImplementedError(msg) - elif not (compute_left_eigenvectors or compute_right_eigenvectors): - return (tf.linalg.eigvals(operand),) - elif compute_right_eigenvectors: - return tuple(tf.linalg.eig(operand)) - else: # compute_left_eigenvectors == True - wH, vl = tf.linalg.eig(tf.linalg.adjoint(operand)) - wHH = tf.math.conj(wH) - return (wHH, vl) - - -tf_impl[lax.linalg.eig_p] = _eig - - -def _eigh( - operand: TfVal, - lower: bool, - sort_eigenvalues: bool, - subset_by_index: tuple, - _in_avals, - _out_aval, -): - del sort_eigenvalues - if operand.shape[-1] == 0: - v, w = operand, tf.reshape(operand, _eval_shape(_in_avals[0].shape[:-1])) - else: - if not lower: - operand = tf.linalg.adjoint(operand) - w, v = tf.linalg.eigh(operand) - cast_type = { - tf.complex64: tf.float32, - tf.complex128: tf.float64 - }.get(operand.dtype) - if not (subset_by_index is None or subset_by_index == (0, operand.shape[-1])): - raise NotImplementedError("subset_by_index is not implemented") - if cast_type is not None: - w = tf.cast(w, cast_type) - return v, w - - -tf_impl_with_avals[lax.linalg.eigh_p] = _eigh - - -def _lu(operand: TfVal, _in_avals, _out_aval): - return _convert_jax_impl(lax_linalg._lu_python, extra_name_stack="lu")( - operand, _in_avals=_in_avals, _out_aval=_out_aval) - - -tf_impl_with_avals[lax.linalg.lu_p] = _lu - - -def _triangular_solve(a: TfVal, b: TfVal, *, left_side: bool, lower: bool, - transpose_a: bool, conjugate_a: bool, unit_diagonal: bool, - _in_avals: Sequence[core.ShapedArray], - _out_aval: core.ShapedArray): - if unit_diagonal: - a_aval, _ = _in_avals - a_shape = _eval_shape(a_aval.shape) - a = tf.linalg.set_diag(a, tf.ones(a_shape[:-1], dtype=a.dtype)) - if not left_side: - rank = len(a.shape) - transpose_dimensions = list(range(rank - 2)) + [rank - 1, rank - 2] - a = tf.transpose(a, transpose_dimensions) - b = tf.transpose(b, transpose_dimensions) - lower = not lower - # adjoint == transpose for real dtypes, so special care need only be taken - # for complex types. - if a.dtype in [tf.complex64, tf.complex128]: - if (transpose_a and not conjugate_a) or (not transpose_a and conjugate_a): - a = tf.math.conj(a) - result = tf.linalg.triangular_solve(a, b, lower=lower, adjoint=transpose_a) - if not left_side: - result = tf.transpose(result, transpose_dimensions) - return result - - -tf_impl_with_avals[lax.linalg.triangular_solve_p] = _triangular_solve - - -def _linear_solve(*args: TfVal, const_lengths, jaxprs, _in_avals, _out_aval): - return _convert_jax_impl(lax_control_flow._custom_linear_solve_impl, - extra_name_stack="linear_solve")( - *args, - const_lengths=const_lengths, - jaxprs=jaxprs, - _in_avals=_in_avals, - _out_aval=_out_aval) - - -tf_impl_with_avals[lax.linear_solve_p] = _linear_solve - -def _tridiagonal_solve(*args: TfVal, _in_avals, _out_aval, **params): - return _convert_jax_impl(lax_linalg._tridiagonal_solve_jax, - multiple_results=False, - extra_name_stack="tridiagonal_solve")( - *args, - _in_avals=_in_avals, - _out_aval=_out_aval) - - -tf_impl_with_avals[lax.linalg.tridiagonal_solve_p] = _tridiagonal_solve - -def _custom_jvp_call(*args: TfVal, call_jaxpr: core.ClosedJaxpr, - jvp_jaxpr_fun: Callable, - num_consts: int) -> Sequence[TfVal]: - # TODO(necula): ensure that there is no AD transformation in scope - del jvp_jaxpr_fun, num_consts - return _interpret_jaxpr(call_jaxpr, *args, extra_name_stack="custom_jvp", - fresh_constant_cache=False) - - -tf_impl[custom_derivatives.custom_jvp_call_p] = _custom_jvp_call - - -def _custom_vjp_call_jaxpr(*args: TfVal, fun_jaxpr: core.ClosedJaxpr, - **_) -> Sequence[TfVal]: - # TODO(necula): ensure that there is no AD transformation in scope - return _interpret_jaxpr(fun_jaxpr, *args, extra_name_stack="custom_vjp", - fresh_constant_cache=False) - - -tf_impl[custom_derivatives.custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr - - -def _custom_lin(*args: TfVal, **_) -> Sequence[TfVal]: - raise TypeError("can't apply forward-mode autodiff (jvp) to a custom_vjp " - "function.") - - -tf_impl[ad.custom_lin_p] = _custom_lin - - -def _remat_opt(*args: TfVal, num_consts: int, num_res: int, - fwd_jaxpr: core.ClosedJaxpr, - fun_jaxpr_thunk: Callable) -> Sequence[TfVal]: - del num_consts, num_res, fun_jaxpr_thunk - return _interpret_jaxpr(fwd_jaxpr, *args, extra_name_stack="remat_opt", - fresh_constant_cache=False) - - -tf_impl[custom_derivatives.remat_opt_p] = _remat_opt - - PartitionsOrReplicated = Union[tuple[int, ...], None] def split_to_logical_devices(tensor: TfVal, @@ -3517,13 +833,6 @@ def split_to_logical_devices(tensor: TfVal, return xla_sharding.tile(tensor, tile_assignment, use_sharding_op=True) -def _xla_compatible_sharding_to_hlo_sharding( - s: sharding.Sharding, - aval: core.ShapedArray) -> xla_client.HloSharding | None: - if isinstance(s, sharding_impls.UnspecifiedValue): - return None - return s._to_xla_hlo_sharding(aval.ndim) - def _shard_value(val: TfVal, sd: xla_client.HloSharding | None, *, skip_replicated_sharding: bool) -> TfVal: @@ -3533,7 +842,7 @@ def _shard_value(val: TfVal, sharding_proto = sd.to_proto() if (skip_replicated_sharding and - op_shardings.is_op_sharding_replicated(sharding_proto)): + op_shardings.is_hlo_sharding_replicated(sd)): return val # Tensorflow heavily relies on tile_assignment_devices proto fields specific @@ -3552,104 +861,41 @@ def _shard_value(val: TfVal, tad = sharding_proto.tile_assignment_devices # type: ignore # To use xla_sharding.py, we must have a xla_data_pb2.OpSharding. - xla_sharding_proto: xla_data_pb2.OpSharding = xla_data_pb2.OpSharding( + xla_sharding_v1_proto: xla_data_pb2.OpSharding = xla_data_pb2.OpSharding( type=int(sharding_proto.type), tile_assignment_dimensions=sharding_proto.tile_assignment_dimensions, tile_assignment_devices=tad, replicate_on_last_tile_dim=sharding_proto.replicate_on_last_tile_dim, last_tile_dims=sharding_proto.last_tile_dims, ) + # Shardy requires V2 sharding format. + if config.use_shardy_partitioner.value: + xla_sharding_v2_proto: xla_data_pb2.OpSharding = xla_data_pb2.OpSharding( + type=int(sharding_proto.type), + tile_assignment_dimensions=sharding_proto.tile_assignment_dimensions, + tile_assignment_devices=sharding_proto.tile_assignment_devices, + iota_reshape_dims=sharding_proto.iota_reshape_dims, + iota_transpose_perm=sharding_proto.iota_transpose_perm, + replicate_on_last_tile_dim=sharding_proto.replicate_on_last_tile_dim, + last_tile_dims=sharding_proto.last_tile_dims, + ) + else: + xla_sharding_v2_proto = None if tf_context.executing_eagerly(): raise ValueError( "A jit function with sharded arguments or results must be used under a `tf.function` context. " "See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#support-for-partitioning for a discussion") - return xla_sharding.Sharding(proto=xla_sharding_proto).apply_to_tensor( - val, use_sharding_op=True) - - -def _pjit(*args: TfVal, - jaxpr: core.ClosedJaxpr, - in_shardings: Sequence[sharding.Sharding], - out_shardings: Sequence[sharding.Sharding], - in_layouts, out_layouts, - donated_invars, - ctx_mesh, - name: str, - keep_unused: bool, - inline: bool, - compiler_options_kvs, - _in_avals: Sequence[core.ShapedArray], - _out_aval: Sequence[core.ShapedArray]) -> TfVal: - del donated_invars - # Apply sharding annotation to the arguments - in_hlo_shardings: Sequence[xla_client.HloSharding | None] = map( - _xla_compatible_sharding_to_hlo_sharding, in_shardings, _in_avals) - sharded_args: Sequence[TfVal] = tuple( - map(partial(_shard_value, - skip_replicated_sharding=not _thread_local_state.enable_xla), - args, in_hlo_shardings)) - results = _interpret_jaxpr(jaxpr, *sharded_args, - extra_name_stack=util.wrap_name(name, "pjit"), - fresh_constant_cache=False) - out_hlo_shardings: Sequence[xla_client.HloSharding | None] = map( - _xla_compatible_sharding_to_hlo_sharding, out_shardings, _out_aval) - sharded_results: Sequence[TfVal] = tuple( - map(partial(_shard_value, - skip_replicated_sharding=not _thread_local_state.enable_xla), - results, out_hlo_shardings)) - return tuple(sharded_results) - - -tf_impl_with_avals[pjit.pjit_p] = _pjit - - -def _pjit_sharding_constraint(arg: TfVal, *, - sharding: sharding.Sharding, - context_mesh: mesh.Mesh, - _in_avals: Sequence[core.ShapedArray], - _out_aval: core.ShapedArray, - **kwargs) -> TfVal: - hlo_sharding = _xla_compatible_sharding_to_hlo_sharding(sharding, _in_avals[0]) - return _shard_value(arg, hlo_sharding, - skip_replicated_sharding=False) - - -tf_impl_with_avals[pjit.sharding_constraint_p] = _pjit_sharding_constraint - -def _dimension_size_jax2tf(op: TfVal, *, dimension, _in_avals, _out_aval): - dim_tf = tf.shape(op)[dimension] - if dim_tf.dtype != _to_tf_dtype(_out_aval.dtype): - return _convert_element_type(dim_tf, new_dtype=_out_aval.dtype, - weak_type=_out_aval.weak_type) - else: - return dim_tf - -tf_impl_with_avals[shape_poly.dimension_size_p] = _dimension_size_jax2tf - -def _dim_as_value_jax2tf(dim: shape_poly.DimSize): - dim_tf, = _eval_shape((dim,)) - return dim_tf - -tf_impl[shape_poly.dim_as_value_p] = _dim_as_value_jax2tf - -def _shape_assertion_jax2tf(assert_what, *error_message_inputs, - error_message: str): - - tf.debugging.assert_equal( - assert_what, True, - message=error_message.format(*error_message_inputs)) - return [] - -tf_impl[shape_poly.shape_assertion_p] = _shape_assertion_jax2tf - -def _reduce_precision(x, *, exponent_bits, mantissa_bits): - return tfxla.reduce_precision(x, exponent_bits=exponent_bits, - mantissa_bits=mantissa_bits) - -tf_impl[lax.reduce_precision_p] = _reduce_precision - -tf_impl[lax_internal.tie_p] = lambda x, y: y + tf_version = tuple(int(v) for v in tf.__version__.split(".")[:2]) + # apply_to_tensor comes from a tensorflow package, check the tensorflow + # version to make sure that it has the sharding_v2_proto parameter. + if tf_version < (2, 20): + return xla_sharding.Sharding(proto=xla_sharding_v1_proto).apply_to_tensor( + val, use_sharding_op=True + ) + return xla_sharding.Sharding(proto=xla_sharding_v1_proto).apply_to_tensor( + val, use_sharding_op=True, sharding_v2_proto=xla_sharding_v2_proto + ) def _register_checkpoint_pytrees(): diff --git a/jax/experimental/jax2tf/tests/BUILD b/jax/experimental/jax2tf/tests/BUILD new file mode 100644 index 000000000000..c743f3c23ccb --- /dev/null +++ b/jax/experimental/jax2tf/tests/BUILD @@ -0,0 +1,300 @@ +# Copyright 2022 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_python//python:defs.bzl", "py_binary", "py_library") +load("//jaxlib:jax.bzl", "if_oss", "jax_generate_backend_suites", "jax_multiplatform_test", "jax_visibility", "py_deps") + +licenses(["notice"]) + +package(default_applicable_licenses = []) + +jax_generate_backend_suites() + +py_library( + name = "jax2tf_tests", + srcs = [ + "__init__.py", + ], + visibility = ["//jax:internal"], +) + +py_library( + name = "jax2tf_limitations", + srcs = [ + "jax2tf_limitations.py", + ], + visibility = jax_visibility("jax2tf_limitations"), + deps = [ + "//jax/_src:internal_test_harnesses", + "//jax/_src:test_util", + ], +) + +py_library( + name = "tf_test_util", + srcs = [ + "tf_test_util.py", + ], + # //:wheel_additives needs visibility to this target + visibility = jax_visibility("tf_test_util") + if_oss(["//:__pkg__"]), + deps = [ + ":jax2tf_limitations", + "//jax/_src:internal_test_harnesses", + "//jax/_src:test_util", + "//jax/experimental/jax2tf", + ], +) + +py_library( + name = "models_util", + srcs = [ + "converters.py", + "model_harness.py", + ], + deps = [ + "//jax", + "//jax/experimental/jax2tf/tests/flax_models", + "//third_party/py/jraph", + "//third_party/py/numpy", + "//third_party/py/tensorflowjs", + ], +) + +jax_multiplatform_test( + name = "jax2tf_test", + srcs = ["jax2tf_test.py"], + enable_configs = [ + "tpu_v3_x4", + ], + # On GPU, this test uses both TF and JAX and might OOM, see: + # https://docs.jax.dev/en/latest/gpu_memory_allocation.html + env = { + "XLA_PYTHON_CLIENT_ALLOCATOR": "platform", + }, + tags = ["jax2tf"], + deps = [ + ":tf_test_util", + "//jax/experimental/jax2tf", + ] + py_deps("tensorflow") + py_deps([ + "absl/testing", + "absl/logging", + ]), +) + +jax_multiplatform_test( + name = "control_flow_ops_test", + srcs = ["control_flow_ops_test.py"], + config_tags_overrides = { + "tpu_v3": {"ondemand": False}, + }, + disable_configs = [ + # TODO(b/172667839): Compilation times out with -c dbg. + "tpu_v4i", + ], + # On GPU, this test uses both TF and JAX and might OOM, see: + # https://docs.jax.dev/en/latest/gpu_memory_allocation.html + env = { + "XLA_PYTHON_CLIENT_ALLOCATOR": "platform", + }, + tags = ["jax2tf"], + deps = [ + ":tf_test_util", + "//jax/experimental/jax2tf", + ], +) + +jax_multiplatform_test( + name = "shape_poly_test", + size = "large", + srcs = ["shape_poly_test.py"], + config_tags_overrides = { + "tpu_v3": {"ondemand": False}, + }, + disable_configs = [ + "tpu_v4i", # TODO(b/172667839): Compilation times out with -c dbg + ], + enable_configs = [ + "cpu", + "cpu_x32", + ], + # On GPU, this test uses both TF and JAX and might OOM, see: + # https://docs.jax.dev/en/latest/gpu_memory_allocation.html + env = { + "XLA_PYTHON_CLIENT_ALLOCATOR": "platform", + }, + shard_count = { + "cpu": 4, + "gpu": 4, + "tpu": 4, + }, + tags = [ + "jax2tf", + ], + deps = [ + ":tf_test_util", + "//jax/_src:internal_test_harnesses", + "//jax/experimental/jax2tf", + ], +) + +jax_multiplatform_test( + name = "primitives_test", + size = "large", + srcs = ["primitives_test.py"], + backend_variant_args = { + "tpu_v4i": ["--xla_tpu_allow_in_cmem_copy=true --xla_jf_convolution_performance_target=0.1"], # TODO(b/172584898): slow compilation. + }, + config_tags_overrides = { + "tpu_v3": {"ondemand": False}, + }, + # On GPU, this test uses both TF and JAX and might OOM, see: + # https://docs.jax.dev/en/latest/gpu_memory_allocation.html + env = { + "XLA_PYTHON_CLIENT_ALLOCATOR": "platform", + }, + shard_count = { + "cpu": 40, + "gpu": 20, + "tpu": 20, + }, + tags = [ + "jax2tf", + "nodebug", # Times out. + ], + deps = [ + ":tf_test_util", + "//jax/_src:internal_test_harnesses", + "//jax/experimental/jax2tf", + ], +) + +jax_multiplatform_test( + name = "jax_primitives_coverage_test", + srcs = ["jax_primitives_coverage_test.py"], + backend_variant_args = { + "tpu_v4i": ["--xla_jf_convolution_performance_target=0.1"], # TODO(b/172584898): slow compilation. + }, + shard_count = { + "cpu": 8, + "gpu": 8, + "tpu": 8, + }, + tags = [ + "jax2tf", + "nodebug", # Times out. + ], + deps = [ + ":tf_test_util", + "//jax/_src:internal_test_harnesses", + "//jax/experimental/jax2tf", + ], +) + +jax_multiplatform_test( + name = "savedmodel_test", + srcs = ["savedmodel_test.py"], + # On GPU, this test uses both TF and JAX and might OOM, see: + # https://docs.jax.dev/en/latest/gpu_memory_allocation.html + env = { + "XLA_PYTHON_CLIENT_ALLOCATOR": "platform", + }, + tags = ["jax2tf"], + deps = [ + ":tf_test_util", + "//jax/experimental/jax2tf", + ], +) + +jax_multiplatform_test( + name = "sharding_test", + srcs = ["sharding_test.py"], + disable_configs = [ + "gpu_h100", + "tpu_v4i", + ], + enable_configs = [ + "tpu_v3_x4", + ], + tags = [ + "jax2tf", + "requires-net:external", # For running with xprof + ], + deps = [ + ":tf_test_util", + "//jax/_src:compiler", + "//jax/experimental/jax2tf", + "//third_party/py/absl:app", + ], +) + +jax_multiplatform_test( + name = "call_tf_test", + srcs = ["call_tf_test.py"], + # On GPU, this test uses both TF and JAX and might OOM, see: + # https://docs.jax.dev/en/latest/gpu_memory_allocation.html + env = { + "XLA_PYTHON_CLIENT_ALLOCATOR": "platform", + }, + tags = ["jax2tf"], + deps = [ + ":tf_test_util", + ] + py_deps("tpu_ops"), +) + +jax_multiplatform_test( + name = "back_compat_tf_test", + srcs = ["back_compat_tf_test.py"], + # On GPU, this test uses both TF and JAX and might OOM, see: + # https://docs.jax.dev/en/latest/gpu_memory_allocation.html + env = { + "XLA_PYTHON_CLIENT_ALLOCATOR": "platform", + }, + tags = ["jax2tf"], + deps = [ + "//jax/_src:internal_export_back_compat_test_util", + "//jax/experimental/jax2tf", + "//jax/experimental/jax2tf/tests/back_compat_testdata", + ], +) + +# This filegroup specifies the set of tests which should be run on forge, for +# the purposes of the verify_tests_in_build. +# If a test is external only, add the filename to the `exclude` list. +filegroup( + name = "forge_tests", + srcs = glob( + include = [ + "*_test.py", + ], + exclude = [], + ) + ["BUILD"], + visibility = jax_visibility("forge_tests"), +) + +py_binary( + name = "models_test_main", + srcs = ["models_test_main.py"], + data = [ + "//jax/experimental/jax2tf/g3doc:convert_models_results", + ], + deps = [ + ":models_util", + "//jax", + "//pyglib:resources", + "//third_party/py/absl:app", + "//third_party/py/absl/flags", + "//third_party/py/numpy", + "//third_party/py/tensorflow", + ], +) diff --git a/jax/experimental/jax2tf/tests/back_compat_tf_test.py b/jax/experimental/jax2tf/tests/back_compat_tf_test.py index 2cf363b0cfb2..1041e0eb793f 100644 --- a/jax/experimental/jax2tf/tests/back_compat_tf_test.py +++ b/jax/experimental/jax2tf/tests/back_compat_tf_test.py @@ -30,7 +30,7 @@ import jax from jax._src import test_util as jtu from jax._src.internal_test_util import export_back_compat_test_util as bctu -from jax._src.lib import xla_extension +from jax._src.lib import _jax from jax.experimental import jax2tf from jax.experimental.jax2tf.tests.back_compat_testdata import tf_call_tf_function import jax.numpy as jnp @@ -77,7 +77,7 @@ def run_current(self, func: Callable, data: bctu.CompatTestData): # for the whole directory. @tf.function(autograph=False, jit_compile=True) def tf_func(the_input): # Use recognizable names for input and result - res = jax2tf.convert(func, native_serialization=True)(the_input) + res = jax2tf.convert(func)(the_input) return tf.identity(res, name="the_result") self.tf_func = tf_func @@ -96,7 +96,7 @@ def serialize( for op in tf_graph.get_operations(): if op.type == "XlaCallModule": serialized_module = op.get_attr("module") - module_str = xla_extension.mlir.deserialize_portable_artifact( + module_str = _jax.mlir.deserialize_portable_artifact( serialized_module ) module_version = op.get_attr("version") diff --git a/jax/experimental/jax2tf/tests/call_tf_test.py b/jax/experimental/jax2tf/tests/call_tf_test.py index 4647a16d79c6..cd17bcf0a2e9 100644 --- a/jax/experimental/jax2tf/tests/call_tf_test.py +++ b/jax/experimental/jax2tf/tests/call_tf_test.py @@ -22,12 +22,12 @@ from absl.testing import absltest from absl.testing import parameterized import jax -from jax import dlpack from jax import dtypes from jax import export from jax import lax from jax import numpy as jnp from jax._src import config +from jax._src import dlpack from jax._src import test_util as jtu from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo @@ -66,6 +66,8 @@ def _named_test(**kwargs): _call_tf_non_compilable_error = "Error compiling TensorFlow function" _call_tf_dynamic_shape_error = "call_tf cannot call functions whose output has dynamic shape" + +@jtu.thread_unsafe_test_class() class CallTfTest(tf_test_util.JaxToTfTestCase): def setUp(self): @@ -82,17 +84,6 @@ def setUp(self): continue # A virtual device if all(tf_device.device_type != d.device_type for d in self.tf_devices): self.tf_devices.append(tf_device) - self.warning_ctx = jtu.ignore_warning( - message=( - "(jax2tf.convert with native_serialization=False has been deprecated" - "|Calling from_dlpack with a DLPack tensor is deprecated)" - ) - ) - self.warning_ctx.__enter__() - - def tearDown(self): - self.warning_ctx.__exit__(None, None, None) - super().tearDown() @_parameterized_jit def test_eval_scalar_arg(self, with_jit=True): @@ -630,7 +621,9 @@ def fun_tf(x): return x * tf.broadcast_to(outer_var, x.shape) + 1. hlo = tf.function(fun_tf, jit_compile=True, autograph=False).experimental_get_compiler_ir(x)() - self.assertIn("(arg0.1: f32[3], arg1.2: f32[1]) -> f32[3]", hlo) + self.assertRegex( + hlo, r"\(arg0.[0-9]+: f32\[3\], arg1.[0-9]+: f32\[1\]\) -> f32\[3\]" + ) # Capture a constant outer_ct = np.array([3.], dtype=np.float32) @@ -825,7 +818,6 @@ def f_jax(x): # lowering will have the proper side effects for the function_list. f_tf = tf.function(jax2tf.convert( f_jax, - native_serialization=True, native_serialization_platforms=lowering_platforms)) for tf_device in self.tf_devices: with self.subTest(tf_device.device_type): @@ -836,8 +828,8 @@ def f_jax(x): self.assertAllClose(res, f_jax(x)) @parameterized.named_parameters( - {"testcase_name": f"_type={type_.__name__}", "type_": type_} - for type_ in dlpack.SUPPORTED_DTYPES + {"testcase_name": f"_type={type_.name}", "type_": type_} + for type_ in dlpack.SUPPORTED_DTYPES_SET ) def test_avoid_copy_between_gpu_and_cpu(self, type_): try: @@ -848,7 +840,7 @@ def test_avoid_copy_between_gpu_and_cpu(self, type_): raise unittest.SkipTest("Test requires a GPU device.") def tf_fun(x): - if type_ == jnp.bool_: + if type_ == np.dtype('bool'): return tf.math.logical_or(x, True) else: return x + 1 @@ -879,8 +871,9 @@ def _transfer_guard(guard_level): jax2tf.call_tf(tf_fun)(jax_array_on_gpu) +@jtu.thread_unsafe_test_class() class RoundTripToJaxTest(tf_test_util.JaxToTfTestCase): - "Reloading output of jax2tf into JAX with call_tf" + """Reloading output of jax2tf into JAX with call_tf.""" def setUp(self): if tf is None: @@ -889,17 +882,6 @@ def setUp(self): # bug in TensorFlow. _ = tf.add(1, 1) super().setUp() - self.warning_ctx = jtu.ignore_warning( - message=( - "(jax2tf.convert with native_serialization=False has been deprecated" - "|Calling from_dlpack with a DLPack tensor is deprecated)" - ) - ) - self.warning_ctx.__enter__() - - def tearDown(self): - self.warning_ctx.__exit__(None, None, None) - super().tearDown() def test_simple(self): f_jax = jnp.sin @@ -1067,9 +1049,7 @@ def fun_jax(x, y): x = np.array([-1.0, 0.0, 1.0], dtype=np.float32) y = np.array([-0.5, 0.0, 0.5], dtype=np.float32) - converted_fun = tf.function( - jax2tf.convert(fun_jax, native_serialization=True) - ) + converted_fun = tf.function(jax2tf.convert(fun_jax)) expected = np.sin(x) + np.cos(y) res = tf.function(converted_fun, jit_compile=True, autograph=False)(x, y) self.assertAllClose(expected, res.numpy(), atol=1e-5, rtol=1e-5) @@ -1185,8 +1165,9 @@ def tf_f(x, params): self.assertDictEqual(actual[0], {"y": x, "other": None}) +@jtu.thread_unsafe_test_class() class RoundTripToTfTest(tf_test_util.JaxToTfTestCase): - "Reloading output of call_tf into TF with jax2tf." + """Reloading output of call_tf into TF with jax2tf.""" def setUp(self): if tf is None: @@ -1195,17 +1176,6 @@ def setUp(self): # bug in TensorFlow. _ = tf.add(1, 1) super().setUp() - self.warning_ctx = jtu.ignore_warning( - message=( - "(jax2tf.convert with native_serialization=False has been deprecated" - "|Calling from_dlpack with a DLPack tensor is deprecated)" - ) - ) - self.warning_ctx.__enter__() - - def tearDown(self): - self.warning_ctx.__exit__(None, None, None) - super().tearDown() def test_alternate(self): # Alternate sin/cos with sin in TF and cos in JAX @@ -1319,58 +1289,6 @@ def fun_tf(x): # x:i32[3] fun_tf_rt = jax2tf.convert(jax2tf.call_tf(fun_tf)) fun_tf_rt(x) - @_parameterized_jit - def test_shape_poly_static_output_shape(self, with_jit=True): - if jax.config.jax2tf_default_native_serialization: - raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native serialization.") - x = np.array([0.7, 0.8], dtype=np.float32) - - def fun_tf(x): - return tf.math.reduce_sum(tf.math.sin(x)) - - fun_jax = jax2tf.call_tf(fun_tf) - fun_tf_rt = _maybe_tf_jit(with_jit, - jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])) - self.assertAllClose(fun_tf(x), fun_tf_rt(x)) - - @_parameterized_jit - def test_shape_poly(self, with_jit=False): - if jax.config.jax2tf_default_native_serialization: - raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native serialization.") - x = np.array([7, 8, 9, 10], dtype=np.float32) - def fun_jax(x): - y = jax2tf.call_tf(tf.math.sin, - output_shape_dtype=jax.ShapeDtypeStruct(x.shape, x.dtype))(x) - z = jnp.cos(y) - w = jax2tf.call_tf(lambda z: tf.concat([z, z], axis=0), - output_shape_dtype=jax.ShapeDtypeStruct((2 * z.shape[0],), z.dtype))(z) - assert w.shape[0] == 2 * x.shape[0] - return w - - fun_tf_rt = _maybe_tf_jit(with_jit, - jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])) - res_tf = fun_tf_rt(x) - self.assertAllClose(fun_jax(x), res_tf) - - @_parameterized_jit - def test_shape_poly_pytree_result(self, with_jit=True): - if jax.config.jax2tf_default_native_serialization: - raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native serialization.") - x = np.array([7, 8, 9, 10], dtype=np.float32) - def fun_jax(x): - # Returns a tuple - y = jax2tf.call_tf(lambda x: (x, tf.concat([x, x], axis=0)), - output_shape_dtype=(jax.ShapeDtypeStruct(x.shape, x.dtype), - jax.ShapeDtypeStruct((2 * x.shape[0],), x.dtype)))(x) - assert y[0].shape[0] == x.shape[0] - assert y[1].shape[0] == 2 * x.shape[0] - return y - - fun_tf_rt = _maybe_tf_jit(with_jit, - jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])) - res_tf = fun_tf_rt(x) - self.assertAllClose(fun_jax(x), res_tf) - @_parameterized_jit def test_shape_poly_error_no_output_shape_dtype(self, with_jit=True): x = np.array([7, 8, 9, 10], dtype=np.float32) @@ -1440,7 +1358,7 @@ def fun_jax(x): if kind == "bad_dim" and with_jit: # TODO: in jit more the error pops up later, at AddV2 expect_error = "Dimensions must be equal, but are 4 and 9 for .* AddV2" - if kind == "bad_dim" and jax.config.jax2tf_default_native_serialization: + if kind == "bad_dim": # TODO(b/268386622): call_tf with shape polymorphism and native serialization. expect_error = "Error compiling TensorFlow function" fun_tf_rt = _maybe_tf_jit(with_jit, @@ -1448,24 +1366,6 @@ def fun_jax(x): with self.assertRaisesRegex(expect_ex, expect_error): fun_tf_rt(x) - def test_inner_native_serialization(self): - # Two nested jax2tf, the inner one being with native serialization - x = np.ones((3,), dtype=np.float32) - def f_inner_jax(x): - return jnp.sin(x) - def f_outer_jax(x): - f_inner_tf = jax2tf.convert(f_inner_jax, native_serialization=True) - return jnp.cos(jax2tf.call_tf(f_inner_tf)(x)) - - f_outer_tf = tf.function( - jax2tf.convert(f_outer_jax, native_serialization=False), - autograph=False) - f_outer_graph = str(f_outer_tf.get_concrete_function(tf.convert_to_tensor(x)).graph.as_graph_def()) - # Quick way to check that there is an XlaCallModule op, and a Cos op, but no Sin op - self.assertIn('op: "Cos"', f_outer_graph) - self.assertIn('op: "XlaCallModule"', f_outer_graph) - self.assertNotIn('op: "Sin"', f_outer_graph) - @parameterized.named_parameters( _named_test(f2_function=f2_function, f2_saved_model=f2_saved_model, f4_function=f4_function, f4_saved_model=f4_saved_model) @@ -1476,12 +1376,6 @@ def f_outer_jax(x): def test_several_round_trips(self, f2_function=False, f2_saved_model=False, f4_function=False, f4_saved_model=False): - if (f2_saved_model and - f4_saved_model and - not jax.config.jax2tf_default_native_serialization): - # TODO: Getting error Found invalid capture Tensor("jax2tf_vjp/jax2tf_arg_0:0", shape=(), dtype=float32) when saving custom gradients - # when saving f4, but only with non-native serialization. - raise unittest.SkipTest("TODO: error invalid capture when saving custom gradients") x = np.array(.7, dtype=np.float32) # f(n)(x) = 2. * x^n def f(n): @@ -1629,7 +1523,6 @@ def _extract_info(op): # There is no runtime support yet so it can not run. tf_f_rt = jax2tf.convert( jax_f, - native_serialization=True, with_gradient=False, ) _, restored_model = tf_test_util.SaveAndLoadFunction( @@ -1687,7 +1580,6 @@ def tf_f(x): ) tf_f_rt = jax2tf.convert( jax_f, - native_serialization=True, with_gradient=False, ) _, _ = tf_test_util.SaveAndLoadFunction(tf_f_rt, input_args=[inputs]) @@ -1701,7 +1593,6 @@ def tf_f_2(): jax_f_2 = jax2tf.call_tf(tf.function(tf_f_2), call_tf_graph=True) tf_f_rt_2 = jax2tf.convert( jax_f_2, - native_serialization=True, with_gradient=False, ) _, _ = tf_test_util.SaveAndLoadFunction(tf_f_rt_2, input_args=[]) @@ -1772,7 +1663,6 @@ def _check_mlir_ops(op): f_tf = jax2tf.convert( f_jax, - native_serialization=True, with_gradient=False, ) _, restored_model = tf_test_util.SaveAndLoadFunction(f_tf, input_args=[x]) @@ -1816,7 +1706,6 @@ def test_call_tf_graph_polymorphic(self, ordered: bool, version: int): @tf.function(jit_compile=True, autograph=False) @partial(jax2tf.convert, with_gradient=False, - native_serialization=True, polymorphic_shapes=["(b)"]) @jax.jit def tf_f_2(x): @@ -1849,7 +1738,7 @@ def func_tf(x): data_inputs = (np.array([0.5, 0.7], dtype=np.float32),) def tf_func(the_input): - res = jax2tf.convert(jax_func, native_serialization=True)(the_input) + res = jax2tf.convert(jax_func)(the_input) return tf.identity(res, name="the_result") jit_tf_func = tf.function( diff --git a/jax/experimental/jax2tf/tests/control_flow_ops_test.py b/jax/experimental/jax2tf/tests/control_flow_ops_test.py index 3b39c8752ee7..568cf2bd04af 100644 --- a/jax/experimental/jax2tf/tests/control_flow_ops_test.py +++ b/jax/experimental/jax2tf/tests/control_flow_ops_test.py @@ -25,6 +25,7 @@ jax.config.parse_flags_with_absl() +@jtu.thread_unsafe_test_class() class ControlFlowOpsTest(tf_test_util.JaxToTfTestCase): @jtu.ignore_warning(category=UserWarning, diff --git a/jax/experimental/jax2tf/tests/jax2tf_limitations.py b/jax/experimental/jax2tf/tests/jax2tf_limitations.py index 63f019b31157..878439726b9e 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_limitations.py +++ b/jax/experimental/jax2tf/tests/jax2tf_limitations.py @@ -16,16 +16,10 @@ from __future__ import annotations from collections.abc import Callable, Sequence -import itertools from typing import Any -import jax -from jax import lax -from jax import numpy as jnp -from jax._src import config -from jax._src import dtypes -from jax._src import test_util as jtu from jax._src.internal_test_util import test_harnesses +from jax._src import dtypes import numpy as np DType = Any @@ -37,10 +31,6 @@ class Jax2TfLimitation(test_harnesses.Limitation): See the primitive_test module docstring for details. """ - # Bitmask values for encoding limitations specific to native lowering - FOR_NATIVE = 1 - FOR_NON_NATIVE = 2 - def __init__( self, description: str, @@ -50,7 +40,6 @@ def __init__( enabled: bool = True, # jax2tf specific modes=("eager", "graph", "compiled"), - native_serialization=FOR_NON_NATIVE, skip_tf_run=False, expect_tf_error: bool = True, skip_comparison=False, @@ -86,7 +75,6 @@ def __init__( modes = (modes,) assert all(m in ["eager", "graph", "compiled"] for m in modes), "Invalid modes: {modes}" self.modes = modes - self.native_serialization = native_serialization self.expect_tf_error = expect_tf_error self.skip_tf_run = skip_tf_run self.custom_assert = custom_assert @@ -108,898 +96,54 @@ def get_max_tolerance_limitation( max_tol_lim = l return max_tol_lim - def filter( # type: ignore[override] - self, - dtype: DType | None = None, - device: str | None = None, - mode: str | None = None) -> bool: - """Checks if this limitation is enabled for dtype and device and mode.""" - native_serialization_mask = ( - Jax2TfLimitation.FOR_NATIVE - if config.jax2tf_default_native_serialization.value - else Jax2TfLimitation.FOR_NON_NATIVE) - return ((mode is None or mode in self.modes) and - (self.native_serialization & native_serialization_mask) and - super().filter(device=device, dtype=dtype)) - @classmethod def limitations_for_harness( cls, harness: test_harnesses.Harness) -> Sequence[Jax2TfLimitation]: group_method = getattr(cls, harness.group_name, None) - if harness.group_name in cls.harness_groups_no_limitations: - assert group_method is None, ( - f"Harness group '{harness.group_name}' is both in " - f"'harness_groups_no_limitations' and has a custom " - f"Jax2TfLimitation.classmethod defined (see module docstring)") - return [] - else: - assert group_method is not None, ( - f"Harness group '{harness.group_name}' must be either part of " - f"'harness_groups_no_limitations' or must have a custom " - f"Jax2TfLimitation.classmethod defined (see module docstring)") + if group_method is not None: limitations = group_method(harness) assert isinstance(limitations, (list, tuple)) return limitations - - # We keep here the explicit set of groups for which we don't have limitations - harness_groups_no_limitations = { - "abs", "add", "add_any", "and", "atan2", "bitcast_convert_type", - "broadcast", "broadcast_in_dim", "ceil", "clamp", "concatenate", - "cos", "cosh", "complex", "conj", "convert_element_type", "cummax", - "cummin", "device_put", "dynamic_slice", "dynamic_update_slice", "exp", - "eq", "floor", "gather", "ge", "gt", "imag", "iota", "iota_2x32_shape", - "is_finite", "le", "logistic", "lt", "log", "mul", "ne", "neg", "not", - "or", "pad", "population_count", "random_categorical", "random_uniform", - "random_randint", "reduce", "reduce_and", "reduce_precision", - "reduce_prod", "reduce_or", - "reduce_sum", "reduce_window_mul", "reduce_window_min", - "reduce_window_max", "real", "reshape", "rev", "rsqrt", "select_n", - "select_and_scatter_add", "shift_left", "shift_right_logical", - "shift_right_arithmetic", "sign", "sin", "sinh", "slice", "sqrt", - "squeeze", "stop_gradient", "sub", "tie_in", "transpose", "xor", - "zeros_like" - } - - @classmethod - def helper_get_trig_custom_limitation(cls, np_inverse): - - def custom_assert(tst, result_jax, result_tf, *, args, tol, err_msg): - operand, = args - tst.assertAllClose( - operand, np_inverse(result_tf), atol=tol, rtol=tol, err_msg=err_msg) - - return custom_numeric( - description="May return different but still correct results", - dtypes=[np.complex64, np.complex128], - custom_assert=custom_assert) - - @classmethod - def random_seed(cls, handess: test_harnesses.Harness): - return [custom_random_keys_output()] - - @classmethod - def random_split(cls, handess: test_harnesses.Harness): - return [custom_random_keys_output()] - - @classmethod - def random_fold_in(cls, handess: test_harnesses.Harness): - return [custom_random_keys_output()] - - @classmethod - def acos(cls, harness: test_harnesses.Harness): - return [ - custom_numeric( - dtypes=[np.complex64], - devices=("cpu", "gpu"), - tol=1e-4, - modes=("eager", "graph", "compiled")), - custom_numeric( - dtypes=[np.complex128], - devices=("cpu", "gpu"), - tol=1e-13, - modes=("eager", "graph", "compiled")), - custom_numeric( - dtypes=[np.complex64], - devices=("tpu",), - tol=1e-3, - modes=("eager", "graph", "compiled"), - native_serialization=Jax2TfLimitation.FOR_NON_NATIVE), - ] - - @classmethod - def acosh(cls, harness: test_harnesses.Harness): - return [ - custom_numeric(dtypes=[np.complex64], devices=("cpu", "gpu", "tpu"), - tol=1e-3), - custom_numeric(dtypes=[np.complex128], devices=("cpu", "gpu"), tol=1e-12), - Jax2TfLimitation( - "TF2XLA impl for Acosh doesn't properly handle large complex types," - " native serialization more closely matches numpy numerics.", - dtypes=[np.complex64, np.complex128], - devices=("cpu", "gpu", "tpu"), - modes="compiled", - expect_tf_error=False, - skip_comparison=True, - native_serialization=Jax2TfLimitation.FOR_NON_NATIVE, - ), - cls.helper_get_trig_custom_limitation(np.cosh), - ] - - @classmethod - def approx_top_k(cls, harness: test_harnesses.Harness): - supported_dtypes = jtu.supported_dtypes() - def custom_assert(tst, result_jax, result_tf, *, args, tol, err_msg): - del tol, err_msg - # Tests only that the indices correspond to the returned values - jax_values, jax_indices = result_jax - tf_values, tf_indices = result_tf - operand, = args - def operand_values(indices): - if operand.ndim == 1: - return operand[indices] - elif operand.ndim == 2: - return operand[np.arange(operand.shape[0]).reshape((-1, 1)), indices] - else: - assert False - tst.assertAllClose(operand_values(jax_indices), jax_values) - tst.assertAllClose(operand_values(tf_indices), tf_values) - - return [ - missing_tf_kernel( - dtypes=[t for t in [jnp.bfloat16, np.float16, np.float32, np.float64] - if t in supported_dtypes], - devices=("cpu", "gpu"), - modes=("graph", "eager")), - Jax2TfLimitation( - "compilation not supported for float64.", - dtypes=[np.float64], - devices=("cpu", "gpu"), - modes=("compiled",)), - custom_numeric( - dtypes=[t for t in [jnp.bfloat16, np.float16, np.float32, np.float64] - if t in supported_dtypes], - devices=("cpu", "gpu"), - modes=("eager", "graph"), - custom_assert=custom_assert)] - - @classmethod - def argmax(cls, harness: test_harnesses.Harness): - return [ - Jax2TfLimitation( - "different results when the input contains NaN and enable_xla=False", - dtypes=jtu.dtypes.all_inexact, - devices=("cpu", "gpu", "tpu"), - modes=("eager", "graph", "compiled"), - expect_tf_error=False, - skip_comparison=True, - enabled=("nan_" in harness.name and not harness.params["enable_xla"])), - ] - - @classmethod - def argmin(cls, harness: test_harnesses.Harness): - return cls.argmax(harness) - - @classmethod - def asin(cls, harness: test_harnesses.Harness): - return [ - custom_numeric(dtypes=[np.complex64], devices=("cpu", "gpu"), tol=1e-4, - modes=("eager", "graph", "compiled")), - custom_numeric(dtypes=[np.complex64], devices=("tpu", "gpu"), tol=2e-4, - modes=("eager", "graph", "compiled")), - custom_numeric(dtypes=[np.complex128], devices=("cpu", "gpu"), tol=1e-12, - modes=("eager", "graph", "compiled")), - cls.helper_get_trig_custom_limitation(np.sin) - ] + else: + return [] @classmethod def asinh(cls, harness: test_harnesses.Harness): return [ - custom_numeric(dtypes=[np.complex64], devices=("cpu", "gpu", "tpu"), - tol=1e-3), - custom_numeric(dtypes=[np.complex128], devices=("cpu", "gpu"), tol=1e-12), - custom_numeric(dtypes=[np.complex64, np.complex128], - devices=("cpu", "gpu", "tpu"), - modes=("compiled",), - tol=1e-3, - native_serialization=Jax2TfLimitation.FOR_NON_NATIVE), custom_numeric(dtypes=[np.complex128], devices=("cpu",), modes=("eager", "compiled", "graph"), - tol=1e-13, - native_serialization=Jax2TfLimitation.FOR_NATIVE | Jax2TfLimitation.FOR_NON_NATIVE), - cls.helper_get_trig_custom_limitation(np.sinh) - ] - - @classmethod - def atan(cls, harness: test_harnesses.Harness): - return [ - custom_numeric(dtypes=[np.complex64], devices=("cpu", "gpu"), tol=1e-5), - custom_numeric(dtypes=[np.complex64], devices=("tpu"), tol=1e-3), - custom_numeric(dtypes=[np.complex128], devices=("cpu", "gpu"), tol=1e-12), - cls.helper_get_trig_custom_limitation(np.tan) - ] - - @classmethod - def atanh(cls, harness: test_harnesses.Harness): - return [ - custom_numeric(dtypes=[np.float64], tol=1e-14), - custom_numeric(dtypes=[np.complex64], tol=1e-3), - custom_numeric(dtypes=[np.complex128], devices=("cpu", "gpu"), tol=1e-12), - cls.helper_get_trig_custom_limitation(np.tanh) - ] - - @classmethod - def bessel_i0e(cls, harness: test_harnesses.Harness): - return [ - missing_tf_kernel( - dtypes=[dtypes.bfloat16], - devices=("cpu", "gpu"), - modes=("eager", "graph")) - ] - - @classmethod - def bessel_i1e(cls, harness: test_harnesses.Harness): - return cls.bessel_i0e(harness) - - @classmethod - def cbrt(cls, harness: test_harnesses.Harness): - return [ - custom_numeric(dtypes=[np.float32], devices=("tpu"), tol=1e-5), + tol=1e-13), ] @classmethod def cholesky(cls, harness: test_harnesses.Harness): - - def custom_assert(tst, result_jax, result_tf, *, tol, err_msg, **_): - # cholesky_p returns garbage in the strictly upper triangular part of the - # result, so we can safely ignore that part. - tst.assertAllClose( - jnp.tril(result_jax), result_tf, atol=tol, err_msg=err_msg) - return [ - # TODO: very high tolerance - custom_numeric( - dtypes=[np.float32, np.complex64], - tol=1e-2, - devices=("cpu", "gpu"), - modes=("eager", "graph", "compiled")), - custom_numeric( - dtypes=[np.float64, np.complex128], - tol=1e-6, - devices=("cpu", "gpu"), - modes=("eager", "graph", "compiled")), - custom_numeric( - dtypes=[dtypes.bfloat16, np.float16], - tol=5e-2, - devices=("cpu", "gpu"), - modes=("eager", "graph", "compiled")), custom_numeric( dtypes=[dtypes.bfloat16], tol=5e-5, # Error for GL devices=("tpu",), - modes=("eager", "graph", "compiled"), - native_serialization=Jax2TfLimitation.FOR_NATIVE), - custom_numeric( - custom_assert=custom_assert, - description=( - "May return different values in the strictly upper triangular " - "part of the result. This does not matter for correctness, " - "because this part of the matrix is not considered in the result." - ), - modes=("eager", "graph", "compiled")) + modes=("eager", "graph", "compiled")), ] @classmethod def conv_general_dilated(cls, harness: test_harnesses.Harness): - prefer_elem = harness.params["preferred_element_type"] return [ - Jax2TfLimitation( - "Non-deterministic NaN for conv_general_dilated with preferred_element_type", - dtypes=[ - jnp.int32, np.int16, np.int64 - ], - devices=["cpu", "gpu", "tpu"], - modes=("eager", "graph", "compiled"), - enabled=(prefer_elem is not None - and prefer_elem in [jnp.bfloat16, np.float16, np.float32, np.float64]), - skip_comparison=True), # Even in compiled mode, for GPU we see a bit of discrepancy but # very minor. - custom_numeric(dtypes=[np.float32], devices="gpu", - modes=("eager", "graph", "compiled"), - tol=1e-5), custom_numeric(dtypes=[np.float32], devices="cpu", modes=("eager", "graph", "compiled"), - tol=1e-4, - native_serialization=Jax2TfLimitation.FOR_NATIVE | Jax2TfLimitation.FOR_NON_NATIVE), - custom_numeric(description="higher numeric inaccuracy when `enable_xla=False`", - modes=("eager", "graph", "compiled"), - enabled=(not harness.params["enable_xla"]), - tol=5e-3) - ] - - @classmethod - def cumlogsumexp(cls, harness): - return [ - custom_numeric( - dtypes=(np.float16, jnp.bfloat16, np.float32), - devices=("cpu", "gpu", "tpu"), - modes=("eager", "graph", "compiled"), - tol=5e-1, - ) - ] - - @classmethod - def cumprod(cls, harness): - return [ - custom_numeric( - dtypes=(np.float16, jnp.bfloat16), - devices=("cpu", "gpu", "tpu"), - modes=("eager", "graph", "compiled"), - tol=5e-1, - ) - ] - - @classmethod - def cumsum(cls, harness): - return [ - custom_numeric( - dtypes=(np.float16, jnp.bfloat16), - devices=("cpu", "gpu", "tpu"), - modes=("eager", "graph", "compiled"), - tol=5e-1, - ) - ] - - @classmethod - def custom_linear_solve(cls, harness: test_harnesses.Harness): - return [ - Jax2TfLimitation( - "TODO: large numerical discrepancy", - dtypes=[np.float32], - devices="tpu", - expect_tf_error=False, - skip_comparison=True), - custom_numeric(dtypes=[np.float32], devices="tpu", tol=0.01), - custom_numeric(tol=1e-3), - ] - - @classmethod - def digamma(cls, harness: test_harnesses.Harness): - dtype = harness.dtype - - # In the bfloat16 case, TF and lax both return NaN in undefined cases. - # digamma is not defined at 0 and -1 - def custom_assert(tst, result_jax, result_tf, *, args, tol, err_msg): - # lax.digamma returns NaN and tf.math.digamma returns inf - arg, = args - special_cases = (arg == 0.) | (arg == -1.) - nr_special_cases = np.count_nonzero(special_cases) - tst.assertAllClose( - np.full((nr_special_cases,), dtype(np.nan)), - result_jax[special_cases], - err_msg=err_msg) - tst.assertAllClose( - np.full((nr_special_cases,), dtype(np.inf)), - result_tf[special_cases], - err_msg=err_msg) - # non-special cases are equal - tst.assertAllClose( - result_jax[~special_cases], - result_tf[~special_cases], - atol=tol, - rtol=tol, - err_msg=err_msg) - - return [ - missing_tf_kernel( - dtypes=[dtypes.bfloat16], - devices=("cpu", "gpu"), - modes=("eager", "graph")), - custom_numeric(dtypes=[np.float64], tol=1e-13), - custom_numeric(dtypes=[np.float32], devices=["cpu", "gpu"], tol=1e-3), - custom_numeric( - dtypes=[dtypes.bfloat16], - custom_assert=custom_assert, - description=( - "May return different results at singularity points 0 and -1." - "JAX returns nan and TF returns inf")) - ] - - @classmethod - def div(cls, harness: test_harnesses.Harness): - return [ - Jax2TfLimitation( - "TF integer division fails if divisor contains 0; JAX returns NaN", - dtypes=[ - np.uint8, np.int8, np.uint16, np.uint32, np.uint64, np.int8, - np.int16, np.int32, np.int64 - ], - # Only the harnesses with "singularity" will have divide by 0 - enabled=("singularity" in harness.name)) - ] - - @classmethod - def dot_general(cls, harness: test_harnesses.Harness): - prefer_elem = harness.params["preferred_element_type"] - return [ - missing_tf_kernel(dtypes=[np.bool_],), - # TODO(b/189287598) - Jax2TfLimitation( - "Non-deterministic NaN for dot_general with preferred_element_type on GPU (b/189287598)", - dtypes=[ - jnp.bfloat16, np.float16, np.float32, np.complex64 - ], - devices="gpu", - modes=("eager", "graph", "compiled"), - enabled=(prefer_elem is not None), - skip_comparison=True), - # TODO(b/241740367) - note this only occurs when X64 is enabled. - Jax2TfLimitation( - "Large tolerances when upcasting with preferred_element_type on CPU (b/241740367)", - devices=["cpu", "gpu", "tpu"], - enabled=prefer_elem and np.dtype(harness.dtype) < np.dtype(prefer_elem), - skip_comparison=True), - # TODO(necula): look into this, but this is only for non-native serialization - Jax2TfLimitation( - "Errors when lhs_dtype != rhs_dtype for non-native serialization with 64-bit types", - devices=["cpu", "gpu", "tpu"], - enabled=(harness.dtype != harness.params["rhs_dtype"] and - (harness.dtype in [np.int64, np.uint64, np.float64] or - harness.params["rhs_dtype"] in [np.int64, np.uint64, np.float64])), - skip_comparison=True), - # TODO(necula): look into this, but this is only for non-native serialization and enable_xla=False - Jax2TfLimitation( - "Errors for non-native serialization with enable_xla=False for certain input dtype combinations", - devices=["cpu", "gpu", "tpu"], - enabled=(not harness.params["enable_xla"] and - (harness.dtype in [np.int16, np.uint32, np.uint16] or - harness.params["rhs_dtype"] in [np.int16, np.uint32, np.uint16] or - # Some combinations end up being widened to a larger type that is not - # supported - (harness.dtype, harness.params["rhs_dtype"]) in [ - (np.float16, jnp.bfloat16), - (np.int32, np.float16), - (np.int8, np.float16), - (np.int8, np.uint8), - ])), - skip_comparison=True, - skip_tf_run=True), - # TODO(necula): look into this, but this is only for non-native serialization - Jax2TfLimitation( - "Crash when lhs_dtype != rhs_dtype for non-native serialization on TPU for complex numbers", - devices=["tpu"], - enabled=(harness.dtype != harness.params["rhs_dtype"] and - (harness.dtype in [np.complex64, np.complex128] or - harness.params["rhs_dtype"] in [np.complex64, np.complex128])), - skip_comparison=True, - skip_tf_run=True), - # JAX performs float16 matmuls in float32 on CPU, so the JAX result - # may be more precise. - custom_numeric(dtypes=[np.float16], devices=["cpu"], tol=1e-2, - modes=("eager", "graph", "compiled")), - # Flakiness on different_dtypes_lhs_int16_4_3_rhs_float16_3_6_dimensionnumbers_1_0_enable_xla_True - # Strangely, we only see the flakiness in primitives_graph_serialization_test_gpu_pjrt_c_api - custom_numeric(dtypes=[np.int16], devices=["gpu"], tol=1e-2, - modes=("eager", "graph", "compiled"), - enabled=(harness.params["enable_xla"] and - harness.dtype != harness.params["rhs_dtype"])), - ] - - @classmethod - def eig(cls, harness: test_harnesses.Harness): - compute_left_eigenvectors = harness.params["compute_left_eigenvectors"] - compute_right_eigenvectors = harness.params["compute_right_eigenvectors"] - dtype = harness.dtype - - def custom_assert(tst, result_jax, result_tf, *, args, tol, err_msg): - operand, = args - inner_dimension = operand.shape[-1] - - # Test ported from tests.linlag_test.testEig - # Norm, adjusted for dimension and type. - def norm(x): - norm = np.linalg.norm(x, axis=(-2, -1)) - return norm / ((inner_dimension + 1) * jnp.finfo(dtype).eps) - - def check_right_eigenvectors(a, w, vr): - tst.assertTrue( - np.all(norm(np.matmul(a, vr) - w[..., None, :] * vr) < 100)) - - def check_left_eigenvectors(a, w, vl): - rank = len(a.shape) - aH = jnp.conj(a.transpose(list(range(rank - 2)) + [rank - 1, rank - 2])) - wC = jnp.conj(w) - check_right_eigenvectors(aH, wC, vl) - - def check_eigenvalue_is_in_array(eigenvalue, eigenvalues_array): - tol = None - # TODO(bchetioui): numerical discrepancies - if dtype in [np.float32, np.complex64]: - tol = 1e-4 - elif dtype in [np.float64, np.complex128]: - tol = 1e-13 - closest_diff = min(abs(eigenvalues_array - eigenvalue)) - tst.assertAllClose( - closest_diff, - np.array(0., closest_diff.dtype), - atol=tol, - err_msg=err_msg) - - all_w_jax, all_w_tf = result_jax[0], result_tf[0] - for idx in itertools.product(*map(range, operand.shape[:-2])): - w_jax, w_tf = all_w_jax[idx], all_w_tf[idx] - for i in range(inner_dimension): - check_eigenvalue_is_in_array(w_jax[i], w_tf) - check_eigenvalue_is_in_array(w_tf[i], w_jax) - - if compute_left_eigenvectors: - check_left_eigenvectors(operand, all_w_tf, result_tf[1]) - if compute_right_eigenvectors: - check_right_eigenvectors(operand, all_w_tf, - result_tf[1 + compute_left_eigenvectors]) - - return [ - # Eig does not work in JAX on gpu or tpu - Jax2TfLimitation( - "function not compilable", modes="compiled", devices="cpu"), - Jax2TfLimitation( - "TF Conversion of eig is not implemented when both compute_left_eigenvectors and compute_right_eigenvectors are set to True", - enabled=(compute_left_eigenvectors and compute_right_eigenvectors)), - custom_numeric( - custom_assert=custom_assert, - description=("May return the eigenvalues and eigenvectors in a " - "potentially different order. The eigenvectors may " - "also be different, but equally valid.")) - ] - - @classmethod - def eigh(cls, harness: test_harnesses.Harness): - dtype = harness.dtype - - def custom_assert(tst, result_jax, result_tf, *, args, tol, err_msg): - operand, = args - inner_dimension = operand.shape[-1] - - def check_right_eigenvectors(a, w, vr): - tol = 1e-16 - # TODO(bchetioui): tolerance needs to be very high in compiled mode, - # specifically for eigenvectors. - if dtype == np.float64: - tol = 2e-5 - elif dtype == np.float32: - tol = 1e-2 - elif dtype in [dtypes.bfloat16, np.complex64]: - tol = 1e-3 - elif dtype == np.complex128: - tol = 2e-5 - tst.assertAllClose( - np.matmul(a, vr) - w[..., None, :] * vr, - np.zeros(a.shape, dtype=vr.dtype), - atol=tol, - # For bfloat16 the np.matmul returns float32 result. - check_dtypes=False, - err_msg=err_msg) - - def check_eigenvalue_is_in_array(eigenvalue, eigenvalues_array): - tol = None - if dtype in [dtypes.bfloat16, np.float32, np.complex64]: - tol = 1e-3 - elif dtype in [np.float64, np.complex128]: - tol = 1e-5 - closest_diff = min(abs(eigenvalues_array - eigenvalue)) - tst.assertAllClose( - closest_diff, - np.array(0., closest_diff.dtype), - atol=tol, - err_msg=err_msg) - - _, all_w_jax = result_jax - all_vr_tf, all_w_tf = result_tf - - for idx in itertools.product(*map(range, operand.shape[:-2])): - w_jax, w_tf = all_w_jax[idx], all_w_tf[idx] - for i in range(inner_dimension): - check_eigenvalue_is_in_array(w_jax[i], w_tf) - check_eigenvalue_is_in_array(w_tf[i], w_jax) - - check_right_eigenvectors(operand, all_w_tf, all_vr_tf) - - return [ - missing_tf_kernel( - dtypes=[dtypes.bfloat16], - devices="tpu", - enabled=(harness.params["shape"] != (0, 0)), # This actually works! - ), - Jax2TfLimitation( - "TODO: numeric discrepancies", - dtypes=[np.float16], - devices="tpu", - expect_tf_error=False, - skip_comparison=True), - custom_numeric( - custom_assert=custom_assert, - description=("May return the eigenvalues and eigenvectors in a " - "potentially different order. The eigenvectors may " - "also be different, but equally valid."), - modes=("eager", "graph", "compiled")) + tol=1e-4), ] - @classmethod - def erf(cls, harness: test_harnesses.Harness): - return [] - - @classmethod - def erfc(cls, harness: test_harnesses.Harness): - return [] - - @classmethod - def erf_inv(cls, harness: test_harnesses.Harness): - # erf_inv is not defined for arg <= -1 or arg >= 1 - def custom_assert(tst, result_jax, result_tf, *, args, tol, - err_msg): # noqa: F811 - arg, = args - # for arg < -1 or arg > 1 - # lax.erf_inv returns NaN; tf.math.erf_inv return +/- inf - special_cases = (arg < -1.) | (arg > 1.) - # non-special cases are equal - tst.assertAllClose( - result_jax[~special_cases], - result_tf[~special_cases], - atol=tol, - rtol=tol, - err_msg=err_msg) - - return [ - missing_tf_kernel( - dtypes=[dtypes.bfloat16, np.float16], - devices=("cpu", "gpu"), - modes=("eager", "graph")), - custom_numeric(dtypes=[np.float32, np.float64], tol=1e-4), - custom_numeric( - dtypes=[np.float32, np.float64], - custom_assert=custom_assert, - description=( - "May return different results at undefined points (< -1 or > 1):" - " JAX returns `NaN` and TF returns `+inf` or `-inf`.")), - ] - - @classmethod - def expm1(cls, harness: test_harnesses.Harness): - return [custom_numeric(dtypes=[np.float64], tol=1e-5)] @classmethod def fft(cls, harness): return [ - Jax2TfLimitation( - "TF function not compilableble", - devices=("cpu", "gpu"), - dtypes=[np.float64], - modes="compiled"), - Jax2TfLimitation( - "TF function not compilableble for IFFT and IRFFT", - devices=("cpu", "gpu"), - dtypes=[np.complex128], - modes="compiled", - enabled=(str(harness.params["fft_type"]) in ["FftType.IFFT", - "FftType.IRFFT"])), - # TODO: very high tolerance - custom_numeric(tol=1e-3, modes=("eager", "graph", "compiled"), - native_serialization=Jax2TfLimitation.FOR_NON_NATIVE), custom_numeric(tol=1e-5, modes=("eager", "graph", "compiled"), - native_serialization=Jax2TfLimitation.FOR_NATIVE, devices=("cpu",)), ] - @classmethod - def _pow_test_util(cls, harness: test_harnesses.Harness): - - def custom_assert(tst, result_jax, result_tf, *, args, tol, err_msg): - # NaNs are mismatched, but assertAllClose will also behave weirdly for - # complex numbers containing np.inf as one of their components. See - # https://github.com/numpy/numpy/issues/15959 for more details. - mask = ( - np.isnan(result_jax) + np.isnan(result_tf) + np.isinf(result_jax) + - np.isinf(result_tf)) - tst.assertAllClose( - result_jax[~mask], result_tf[~mask], rtol=tol, err_msg=err_msg) - - return [ - custom_numeric( - dtypes=[np.float32, np.complex64], devices=("cpu", "gpu"), - tol=1e-3), - custom_numeric( - dtypes=[np.float64, np.complex128], - devices=("cpu", "gpu"), - tol=5e-5), - custom_numeric( - dtypes=[np.complex64, np.complex128], - custom_assert=custom_assert, - ) - ] - - @classmethod - def igamma(cls, harness: test_harnesses.Harness): - dtype = harness.dtype - - # igamma is not defined when the first argument is <=0 - def custom_assert(tst, result_jax, result_tf, *, args, tol, err_msg): - arg1, arg2 = args - # lax.igamma returns NaN when arg1 == arg2 == 0; tf.math.igamma returns 0 - special_cases = (arg1 == 0.) & (arg2 == 0.) - nr_special_cases = np.count_nonzero(special_cases) - tst.assertAllClose( - np.full((nr_special_cases,), np.nan, dtype=dtype), - result_jax[special_cases]) - tst.assertAllClose( - np.full((nr_special_cases,), 0., dtype=dtype), - result_tf[special_cases]) - if harness.dtype == np.float32: - tol = 1e-5 - # non-special cases are equal - tst.assertAllClose( - result_jax[~special_cases], - result_tf[~special_cases], - atol=tol, - rtol=tol, - err_msg=err_msg) - - return [ - missing_tf_kernel( - dtypes=[dtypes.bfloat16, np.float16], - devices=("cpu", "gpu"), - modes=("eager", "graph")), - custom_numeric( - custom_assert=custom_assert, - description=( - "May return different results at undefined points " - "(both arguments 0). JAX returns `NaN` and TF returns 0 or " - "JAX returns 1 and TF returns `NaN`")) - ] - - @classmethod - def igammac(cls, harness: test_harnesses.Harness): - dtype = harness.dtype - - # igammac is not defined when the first argument is <=0 - def custom_assert(tst, result_jax, result_tf, *, args, tol, - err_msg): # noqa: F811 - arg1, arg2 = args - # lax.igammac returns nan. when arg1 <= 0; tf.math.igammac returns 1 - special_cases = (arg1 <= 0.) | (arg2 <= 0) - nr_special_cases = np.count_nonzero(special_cases) - tst.assertAllClose( - np.full((nr_special_cases,), np.nan, dtype=dtype), - result_jax[special_cases], - err_msg=err_msg) - tst.assertAllClose( - np.full((nr_special_cases,), 1, dtype=dtype), - result_tf[special_cases], - err_msg=err_msg) - # non-special cases are equal - tst.assertAllClose( - result_jax[~special_cases], - result_tf[~special_cases], - atol=tol, - rtol=tol, - err_msg=err_msg) - - return [ - missing_tf_kernel( - dtypes=[dtypes.bfloat16, np.float16], - devices=("cpu", "gpu"), - modes=("eager", "graph")), - custom_numeric(dtypes=[np.float64], tol=1e-9), - custom_numeric(devices="gpu", tol=1e-3), - custom_numeric( - modes=("compiled",), - custom_assert=custom_assert, - devices=("cpu", "gpu", "tpu"), - description=( - "May return different results at undefined points " - "(both arguments less or equal 0). JAX returns `NaN` and TF returns 1")), - ] - - @classmethod - def integer_pow(cls, harness: test_harnesses.Harness): - y = harness.params["y"] - return [ - # TODO: on TPU, for f16, we get different results with eager mode - # than with compiled mode. - Jax2TfLimitation( - "Different overflow behavior. ", - dtypes=[np.float16, jnp.bfloat16], - devices="tpu", - expect_tf_error=False, - modes=("eager", "graph"), - skip_comparison=True), - Jax2TfLimitation( - "Different overflow behavior for large exponents. ", - dtypes=[ - np.int8, np.int16, np.int32, np.int64, np.float16, jnp.bfloat16, - np.float32, np.complex64, np.complex128 - ], - enabled=(abs(y) > 10), - expect_tf_error=False, - modes=("eager", "graph"), - skip_comparison=True), - custom_numeric(dtypes=[dtypes.bfloat16], tol=2e-2) - ] + list(cls._pow_test_util(harness)) - - @classmethod - def pow(cls, harness: test_harnesses.Harness): - return cls._pow_test_util(harness) - - @classmethod - def lgamma(cls, harness: test_harnesses.Harness): - return [ - missing_tf_kernel( - dtypes=[dtypes.bfloat16], - devices=("cpu", "gpu"), - modes=("eager", "graph")), - custom_numeric(dtypes=[np.float64], tol=1e-11), - custom_numeric(dtypes=[np.float32], tol=1e-3) - ] - - @classmethod - def log1p(cls, harness: test_harnesses.Harness): - return [ - custom_numeric(dtypes=[np.complex128], tol=3e-14), - custom_numeric(dtypes=[np.float64], tol=1e-10), - custom_numeric(dtypes=[np.float32], tol=1e-3) - ] - - @classmethod - def lu(cls, harness: test_harnesses.Harness): - dtype = harness.dtype - - def custom_assert(tst, result_jax, result_tf, *, args, tol, err_msg): - operand, = args - lu, pivots, perm = result_tf - batch_dims = operand.shape[:-2] - m, n = operand.shape[-2], operand.shape[-1] - - def _make_permutation_matrix(perm): - result = [] - for idx in itertools.product(*map(range, operand.shape[:-1])): - result += [0 if c != perm[idx] else 1 for c in range(m)] - result = np.reshape(np.array(result, dtype=dtype), [*batch_dims, m, m]) - return result - - k = min(m, n) - l = jnp.tril(lu, -1)[..., :, :k] + jnp.eye(m, k, dtype=dtype) - u = jnp.triu(lu)[..., :k, :] - p_mat = _make_permutation_matrix(perm) - - tst.assertArraysEqual( - lax.linalg.lu_pivots_to_permutation(pivots, m), perm) - tst.assertAllClose( - jnp.matmul(p_mat, operand), - jnp.matmul(l, u), - atol=tol, - rtol=tol, - err_msg=err_msg) - - return [ - custom_numeric( - dtypes=[np.float32, np.complex64], devices="tpu", tol=0.1), - custom_numeric( - dtypes=[np.float32, np.complex64], devices=("cpu", "gpu"), - tol=1e-5), - custom_numeric( - dtypes=[np.float64, np.complex128], - modes=("eager", "graph"), - tol=1e-13), - custom_numeric( - dtypes=[np.float64, np.complex128], modes=("compiled"), tol=1e-14), - custom_numeric( - custom_assert=custom_assert, - description=("May return different, but also correct, results when " - "the decomposition is not unique"), - devices=("cpu", "gpu"), - modes=("eager", "graph", "compiled")), - ] - @classmethod def max(cls, harness: test_harnesses.Harness): # TODO(bchetioui): discrepancies between TF & JAX when comparing with NaN; @@ -1009,14 +153,6 @@ def custom_assert(tst, result_jax, result_tf, err_msg, **_): tst.assertAllClose(result_jax[~mask], result_tf[~mask], err_msg=err_msg) return [ - custom_numeric( - custom_assert=custom_assert, - description=( - "May return different values when one of the values is NaN. " - "JAX always returns NaN, while TF returns the value NaN is compared with." - ), - modes=("eager", "graph", "compiled"), - native_serialization=Jax2TfLimitation.FOR_NON_NATIVE), # TODO(b/269996580) custom_numeric( custom_assert=custom_assert, @@ -1026,8 +162,7 @@ def custom_assert(tst, result_jax, result_tf, err_msg, **_): "xla_cpu_enable_fast_min_max compiler flag and therefore have " "different behavior of NaN propagation through min/max." ), - modes=("eager", "graph", "compiled"), - native_serialization=Jax2TfLimitation.FOR_NATIVE) + modes=("eager", "graph", "compiled")) ] @classmethod @@ -1039,14 +174,6 @@ def custom_assert(tst, result_jax, result_tf, *, err_msg, **_): tst.assertAllClose(result_jax[~mask], result_tf[~mask], err_msg=err_msg) return [ - custom_numeric( - custom_assert=custom_assert, - description=( - "May return different values when one of the values is NaN. " - "JAX always returns NaN, while TF returns the value NaN is compared with." - ), - modes=("eager", "graph", "compiled"), - native_serialization=Jax2TfLimitation.FOR_NON_NATIVE), # TODO(b/269996580) custom_numeric( custom_assert=custom_assert, @@ -1057,440 +184,9 @@ def custom_assert(tst, result_jax, result_tf, *, err_msg, **_): "different behavior of NaN propagation through min/max." ), modes=("eager", "graph", "compiled"), - native_serialization=Jax2TfLimitation.FOR_NATIVE) - ] - - @classmethod - def nextafter(cls, harness: test_harnesses.Harness): - return [missing_tf_kernel(dtypes=[np.float16, dtypes.bfloat16])] - - @classmethod - def qr(cls, harness: test_harnesses.Harness): - # See https://github.com/jax-ml/jax/pull/3775#issuecomment-659407824; - # # jit_compile=True breaks for complex types. - # TODO: see https://github.com/jax-ml/jax/pull/3775#issuecomment-659407824. - # - for now, the performance of the HLO QR implementation called when - # compiling with TF is expected to have worse performance than the - # custom calls made in JAX. - return [ - custom_numeric( - dtypes=[np.float64, np.complex128], - devices=("cpu", "gpu"), - modes=("eager", "graph", "compiled"), - tol=1e-13), - custom_numeric( - dtypes=[np.float32, np.complex64], - devices=("cpu", "gpu"), - modes=("eager", "graph", "compiled"), - tol=1e-4), - missing_tf_kernel( - dtypes=[dtypes.bfloat16], - devices="tpu", ) ] - @classmethod - def random_gamma(cls, harness: test_harnesses.Harness): - return [custom_numeric(devices="tpu", tol=1e-3)] - - @classmethod - def reduce_max(cls, harness: test_harnesses.Harness): - # Unlike reduce_window_max, we use a native TF op: tf.reduce_max, which - # does not work for complex - return [missing_tf_kernel(dtypes=[np.complex64, np.complex128])] - - @classmethod - def reduce_min(cls, harness: test_harnesses.Harness): - return cls.reduce_max(harness) - - @classmethod - def reduce_window_add(cls, harness: test_harnesses.Harness): - return [ - Jax2TfLimitation( - "Small deviations on GPU for large inputs and enable_xla=False", - dtypes=[np.float32], - devices="gpu", - modes=("eager", "graph", "compiled"), - expect_tf_error=False, - skip_comparison=False, - enabled=not harness.params["enable_xla"], - tol=3e-5), - Jax2TfLimitation( - "Large deviations on TPU for enable_xla=False", - dtypes=[dtypes.bfloat16, np.float16, np.float32], - devices="tpu", - modes=("eager", "graph", "compiled"), - expect_tf_error=False, - skip_comparison=True, - enabled=not harness.params["enable_xla"]), - custom_numeric(devices="cpu", dtypes=[np.float32], - modes=("eager", "graph", "compiled",), tol=1e-5), - custom_numeric(devices=("cpu", "gpu"), dtypes=[np.float16], - modes=("eager", "graph", "compiled",), tol=5e-3), - custom_numeric(devices=("cpu", "gpu"), dtypes=[dtypes.bfloat16], - modes=("eager", "graph", "compiled",), tol=5e-1), - ] - - @classmethod - def regularized_incomplete_beta(cls, harness: test_harnesses.Harness): - return [ - custom_numeric(dtypes=[np.float64], tol=1e-14), - missing_tf_kernel(dtypes=[np.float16, dtypes.bfloat16]) - ] - - @classmethod - def rem(cls, harness: test_harnesses.Harness): - return [ - Jax2TfLimitation( - "TF integer division fails if divisor contains 0; JAX returns NaN", - dtypes=[ - np.uint8, np.int8, np.uint16, np.uint32, np.uint64, np.int8, - np.int16, np.int32, np.int64 - ], - skip_comparison=True, - # Only the harnesses with "singularity" will have divide by 0 - enabled=("singularity" in harness.name)), - Jax2TfLimitation( - "TF division of inf by inf returns inf while in JAX returns nan", - dtypes=[ - np.float32, - ], - devices="gpu", - skip_comparison=True, - enabled=("singularity_inf_by_inf" in harness.name)), - ] - - @classmethod - def rng_bit_generator(cls, harness: test_harnesses.Harness): - return [] - - @classmethod - def round(cls, harness: test_harnesses.Harness): - return [ - missing_tf_kernel( - dtypes=[dtypes.bfloat16], - devices=("cpu", "gpu"), - modes=("eager", "graph")) - ] - - @classmethod - def scatter(cls, harness): - return [ - Jax2TfLimitation( - "out-of-bounds scatters are not supported in graph and eager mode", - dtypes=jtu.dtypes.all_inexact, - devices=("cpu", "gpu", "tpu"), - modes=("eager", "graph"), - expect_tf_error=True, - skip_comparison=True, - enabled=("modes_out_of_bounds" in harness.name and not harness.params["enable_xla"])), - custom_numeric(modes=("eager", "graph", "compiled"), - dtypes=[np.float16], tol=5e-3, - enabled=(not harness.params["enable_xla"])), - ] - - @classmethod - def scatter_add(cls, harness): - return cls.scatter(harness) - - @classmethod - def scatter_mul(cls, harness): - return cls.scatter(harness) - - @classmethod - def scatter_max(cls, harness): - return cls.scatter(harness) - - @classmethod - def scatter_min(cls, harness): - return cls.scatter(harness) - - @classmethod - def select_and_gather_add(cls, harness): - return [ - # This JAX primitives is not exposed directly in the JAX API - # but arises from JVP of `lax.reduce_window` for reducers - # `lax.max` or `lax.min`. It also arises from second-order - # VJP of the same. Implemented using XlaReduceWindow. - Jax2TfLimitation(( - "jax2tf unimplemented for 64-bit inputs because the current implementation " - "relies on packing two values into a single value. This can be " - "fixed by using a variadic XlaReduceWindow, when available"), - dtypes=[np.float64], - devices=("cpu", "gpu")) - ] - - @classmethod - def sort(cls, harness: test_harnesses.Harness): - return [ - Jax2TfLimitation( - # I think that this is because TF is running on CPU even for GPU tests? - "TODO: TF non-stable multiple-array sort", - devices="gpu", - enabled=(harness.params["num_arrays"] > 1 and - not harness.params["is_stable"]), - expect_tf_error=False, - skip_comparison=True), - ] - - @classmethod - def svd(cls, harness: test_harnesses.Harness): - # TODO: slow test - compute_uv = harness.params["compute_uv"] - - # Both `r_jax` and `r_tf` are 3-Tuples containing the SVD results: - # `S` (singular values), `U` (left singular vectors), and `Vh` (the - # adjoint of the right singular vectors). Note that the TF results are - # obtained through `_svd` in jax/experimental/jax2tf/jax2tf.py. - def custom_assert(tst, r_jax, r_tf, *, args, tol, err_msg): - - def reconstruct_operand(result): - # Reconstructing operand as documented in numpy.linalg.svd (see - # https://numpy.org/doc/stable/reference/generated/numpy.linalg.svd.html) - s, u, v = result - U = u[..., :s.shape[-1]] - V = v[..., :s.shape[-1], :] - S = s[..., None, :] - return jnp.matmul(U * S, V, precision=lax.Precision.HIGHEST) - - # Compares the shapes. - def compare_shapes(r_jax, r_tf): - shapes_jax = [result.shape for result in r_jax] - shapes_tf = [result.shape for result in r_tf] - tst.assertEqual(shapes_jax, shapes_tf) - - # Compares reconstructed operand. - # Computes backward error https://www.netlib.org/lapack/lug/node97.html - # and uses the maximum backward error if there are batch dimensions. - # The backward error is bounded by some constant multiplying the machine - # precision. - # TODO: Compares the operand instead of the reconstructed operand. - def compare_reconstructed_operand(r_jax, r_tf, tol): - operand_jax = reconstruct_operand(r_jax) - operand_tf = reconstruct_operand(r_tf) - error_norm = jnp.linalg.norm(operand_jax - operand_tf, - axis=(-2, -1)) - backward_error = (error_norm / - jnp.linalg.norm(operand_jax, axis=(-2, -1))) - max_backward_error = jnp.amax(backward_error) - tst.assertLess(max_backward_error, tol) - - # Computes the absolute gap between singular value `\sigma_i` and the - # nearest other singular value and for all singular values. The absolute - # gap is used to approximate the upper bound of angular difference - # between the computed and the true singular vectors. If the matrix is - # rectangular `m != n`, the gap for the smallest nonzero singular value - # should also consider the gap between it and zero. Note that this code - # relies on the singular values being in descending order. - def compute_absolute_gap(s, m, n): - forward_appendant = np.inf if m == n else 0 - forward_diff = jnp.diff(s, axis=-1, append=forward_appendant) - backward_diff = jnp.diff( - s[..., ::-1], axis=-1, append=np.inf)[..., ::-1] - absolute_gap = jnp.minimum(jnp.abs(forward_diff), - jnp.abs(backward_diff)) - return absolute_gap - - # See `CompareSingularVectors` in - # tensorflow/python/kernel_tests/linalg/svd_op_test.py - def compare_singular_vectors(x, y, *, error_bound): - # Singular vectors are only unique up to sign (complex phase factor for - # complex matrices), so we normalize the sign first. - sum_of_ratios = jnp.sum(jnp.divide(y, x), -2, keepdims=True) - phases = jnp.divide(sum_of_ratios, jnp.abs(sum_of_ratios)) - x *= phases - - # Note that in general `sqrt(sum(squares))` is not a stable way to - # compute l2 vector norms, but it should be OK for normalization - # factors of vectors with norm ~= 1 as here. - def dot_column_wise(a, b): - output = jnp.sum(jnp.einsum('...ij,...ij->...ij', a.conj(), b, - precision=lax.Precision.HIGHEST), - axis=-2) - return jnp.real(output) - - cos_angular_diff = ( - dot_column_wise(x, y) / - jnp.sqrt(dot_column_wise(x, x) * dot_column_wise(y, y))) - - # Values of `\cos(angular_diff)` outside the interval [0, 1] are clipped - # to the interval edges. For example, `\cos(angular_diff)` could contain - # values like 1.0000001 on float32, which are clipped to 1.0. It is - # possible that anything other than `cos_angular_diff` can be outside - # the interval [0, 1] due to roundoff. - cos_angular_diff = jnp.clip(cos_angular_diff, min=0.0, max=1.0) - - angular_diff = jnp.arccos(cos_angular_diff) - - # TODO: removes the slack factor on the angular difference. - # It is possible that the singular vectors are not accurate to much more - # than O(\sqrt(eps)), which is likely a property of the SVD algorithms - # in question; revisit with better understanding of the SVD algorithms. - if x.dtype in [np.float32, np.complex64]: - slack_factor = 2E4 - elif x.dtype in [np.float64, np.complex128]: - slack_factor = 2E9 - - np.testing.assert_array_less(angular_diff, - slack_factor * error_bound) - - if compute_uv: - # Compares the shapes. - compare_shapes(r_jax, r_tf) - - # Compares the singular values. Each computed singular value `\sigma_i` - # differs from the true `\sigma_i`* by at most - # `|\sigma_i - \sigma_i*| <= \epsilon \sigma_1`, where `\sigma_1` is the - # largest singular value and `\epsilon` denotes the machine precision. - s_jax, s_tf = r_jax[0], r_tf[0] - tst.assertAllClose(s_jax, s_tf, atol=tol, rtol=tol, err_msg=err_msg) - - # Compares the reconstructed operand. - compare_reconstructed_operand(r_jax, r_tf, tol) - - # Compares the singular vectors. - # We only compare the first `rank` singular vectors since the remainder - # forms an arbitrary orthonormal basis for the (row- or column-) null - # space, whose exact value depends on implementation details. - # TODO: A better estimation on the rank? - rank = r_jax[0].shape[-1] - - # Computes the upper bound for angular difference of singular vectors. - # The upper bound has the shape of `[..., k]`, where `...` denotes the - # batch dimensions and `k` is the number of nonzero singular values. - m = r_jax[1].shape[-2] - n = r_jax[2].shape[-2] - absolute_gap = compute_absolute_gap(r_jax[0], m, n) - epsilon = jnp.finfo(r_jax[0].dtype).eps - sigma_largest = (r_jax[0][..., 0])[..., None] - upperbound_singular_vectors = epsilon * sigma_largest / absolute_gap - upperbound_singular_vectors = upperbound_singular_vectors[..., :rank] - - # Left singular vectors. - u_jax = r_jax[1][..., :rank] - u_tf = r_tf[1][..., :rank] - compare_singular_vectors(u_jax, u_tf, - error_bound=upperbound_singular_vectors) - - # Right singular vectors. - v_jax = jnp.swapaxes(r_jax[2][..., :rank, :], -2, -1).conj() - v_tf = jnp.swapaxes(r_tf[2][..., :rank, :], -2, -1).conj() - compare_singular_vectors(v_jax, v_tf, - error_bound=upperbound_singular_vectors) - else: - tst.assertAllClose(r_jax, r_tf, atol=tol, rtol=tol, err_msg=err_msg) - - return [ - # Works in JAX for complex due to custom calls on cpu and gpu - Jax2TfLimitation( - "function not compilable. Implemented using `tf.linalg.svd` and `tf.linalg.adjoint`", - dtypes=[np.complex64, np.complex128], - devices=("cpu", "gpu"), - modes=("compiled",)), - Jax2TfLimitation( - "Large numerical discrepancy", - dtypes=[np.float16], - devices=("tpu"), - modes=("eager", "graph", "compiled"), - skip_comparison=True), - missing_tf_kernel(dtypes=[dtypes.bfloat16], devices="tpu"), - missing_tf_kernel(dtypes=[np.complex64, np.complex128], - modes=("compiled", "graph"), - devices="tpu"), - custom_numeric( - tol=1e-4, - dtypes=[np.float32, np.complex64], - devices=("cpu", "gpu"), - modes=("eager", "graph", "compiled")), - # TODO: this is very low tolerance for f64 - custom_numeric( - tol=1e-4, - dtypes=[np.float64, np.complex128], - devices=("cpu", "gpu"), - modes=("eager", "graph", "compiled")), - custom_numeric( - tol=1e-4, - description="custom numeric comparison when compute_uv on CPU/GPU", - custom_assert=custom_assert, - devices=("cpu", "gpu"), - modes=("eager", "graph", "compiled"), - enabled=(compute_uv == True)), - custom_numeric( - tol=1e-5, - description="custom numeric comparison when !compute_uv on TPU", - dtypes=[np.float32, np.complex64], - custom_assert=custom_assert, - devices=("tpu"), - modes=("eager", "graph", "compiled"), - enabled=not compute_uv), - custom_numeric( - tol=1e-2, - description="custom numeric comparison when compute_uv on TPU", - dtypes=[np.float32, np.float64, np.complex64, np.complex128], - custom_assert=custom_assert, - devices=("tpu"), - modes=("eager", "graph", "compiled"), - enabled=(compute_uv == True)), - ] - - @classmethod - def tan(cls, harness): - return [ - custom_numeric(dtypes=[np.complex64], devices="tpu", tol=1e-4), - custom_numeric(dtypes=[np.complex64], devices=("cpu", "gpu"), tol=1e-3), - custom_numeric(dtypes=[np.complex128], devices=("cpu", "gpu"), tol=1e-12) - ] - - @classmethod - def tanh(cls, harness): - return [ - custom_numeric(dtypes=[np.complex128], tol=1e-7), - custom_numeric(dtypes=[np.complex64], tol=1e-4) - ] - - @classmethod - def top_k(cls, harness): - - def custom_assert(tst, result_jax, result_tf, *, err_msg, **_): - assert len(result_jax) == len(result_tf) - # TODO: TF and JAX sort [inf, nan] differently. - first_arr_jax, first_arr_tf = result_jax[0], result_tf[0] - if np.all(first_arr_jax == first_arr_tf): - for arr_jax, arr_tf in zip(result_jax, result_tf): - tst.assertArraysEqual(arr_jax, arr_tf, err_msg=err_msg) - else: - mask_jax = np.isnan(first_arr_jax) | np.isinf(first_arr_jax) - mask_tf = np.isnan(first_arr_tf) | np.isinf(first_arr_tf) - tst.assertArraysEqual( - first_arr_jax[~mask_jax], first_arr_tf[~mask_tf], err_msg=err_msg) - - return [ - custom_numeric( - dtypes=[np.float16, dtypes.bfloat16, np.float32, np.float64], - custom_assert=custom_assert, - description=( - "Produces different results when the array contains `inf` and `NaN`" - " (they are sorted differently in TF vs. XLA).")) - ] - - @classmethod - def triangular_solve(cls, harness: test_harnesses.Harness): - return [ - missing_tf_kernel( - dtypes=[dtypes.bfloat16], - devices=("gpu", "cpu"), - modes=("eager", "graph")), - missing_tf_kernel( - dtypes=[np.float16], - devices=("gpu", "cpu"), - modes=("eager", "graph")), - custom_numeric(dtypes=[np.float32], tol=5e-3, - modes=("eager", "graph", "compiled")) - ] - - @classmethod - def tridiagonal_solve(cls, harness: test_harnesses.Harness): - return [] def custom_numeric( *, @@ -1504,7 +200,6 @@ def custom_numeric( devices=("cpu", "gpu", "tpu"), custom_assert=None, enabled=True, - native_serialization=Jax2TfLimitation.FOR_NON_NATIVE, tol=None) -> Jax2TfLimitation: return Jax2TfLimitation( @@ -1515,35 +210,4 @@ def custom_numeric( modes=modes, custom_assert=custom_assert, enabled=enabled, - native_serialization=native_serialization, tol=tol) - -def custom_random_keys_output(): - def custom_assert(tst, result_jax, result_tf, *, args, tol, err_msg): - # Here we handle both new-style and old-style keys; see JEP 9263 - def unwrap_keys(keys): - if jax.dtypes.issubdtype(keys.dtype, jax.dtypes.prng_key): - return jax._src.prng.random_unwrap(keys) - else: - return keys - - tst.assertAllClose(unwrap_keys(result_jax), result_tf, - atol=tol, rtol=tol, err_msg=err_msg) - - return custom_numeric( - description="Returns JAX key arrays, so compare underlying base array", - modes=("eager", "graph", "compiled"), - custom_assert=custom_assert) - - -def missing_tf_kernel(*, - description="op not defined for dtype", - dtypes, - modes=("eager", "graph", "compiled"), - devices=("cpu", "gpu", "tpu"), - native_serialization = Jax2TfLimitation.FOR_NON_NATIVE, - enabled=True) -> Jax2TfLimitation: - - return Jax2TfLimitation( - description, dtypes=dtypes, devices=devices, modes=modes, enabled=enabled, - native_serialization=native_serialization) diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index bea2b76cb7cf..80e44f8938c1 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -15,11 +15,11 @@ Specific JAX primitive conversion tests are in primitives_test.""" import collections -import contextlib import math import os import re import unittest +import warnings from absl import logging from absl.testing import absltest, parameterized @@ -36,22 +36,42 @@ from jax._src import source_info_util from jax._src import test_util as jtu from jax._src import xla_bridge as xb -from jax.experimental import jax2tf -from jax.experimental.jax2tf.tests import tf_test_util -from jax.experimental.shard_map import shard_map -from jax.experimental import pjit +from jax._src.shard_map import shard_map from jax.sharding import PartitionSpec as P import numpy as np -import tensorflow as tf +try: + # TODO(b/470156950): Remove this once a proper fix is in place + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", + category=FutureWarning, + message=".*np.object.*") + import tensorflow as tf + from jax.experimental import jax2tf + from jax.experimental.jax2tf.tests import tf_test_util + JaxToTfTestCase = tf_test_util.JaxToTfTestCase +except ImportError: + tf = None + jax2tf = None # type: ignore[assignment] + tf_test_util = None # type: ignore[assignment] + JaxToTfTestCase = jtu.JaxTestCase # type: ignore[misc] config.parse_flags_with_absl() -class Jax2TfTest(tf_test_util.JaxToTfTestCase): +@unittest.skipIf(tf is None, "Test requires tensorflow") +@jtu.thread_unsafe_test_class() +class Jax2TfTest(JaxToTfTestCase): def setUp(self): super().setUp() + versions = tf.version.VERSION.split(".") + if versions < ["2", "19", "1"]: + # StableHLO changed on March 18th, 2025 ,to version 1.10.0, and this + # introduces ops like vhlo_sine_v2. These ops require a TF version + # released after this date. + self.skipTest("Need version of TensorFlow at least 2.19.1") + # One TF device of each device_type self.tf_devices = [] for tf_device in (tf.config.list_logical_devices("TPU") + @@ -61,14 +81,6 @@ def setUp(self): continue # A virtual device if all(tf_device.device_type != d.device_type for d in self.tf_devices): self.tf_devices.append(tf_device) - self.warning_ctx = jtu.ignore_warning( - message="jax2tf.convert with native_serialization=False has been deprecated" - ) - self.warning_ctx.__enter__() - - def tearDown(self): - self.warning_ctx.__exit__(None, None, None) - super().tearDown() def test_empty(self): f_jax = lambda x, y: x @@ -825,24 +837,20 @@ def f(x1): x3 = jnp.sin(x2) x4 = jnp.sin(x3) return x4 - remat_f = ad_checkpoint.checkpoint(f) + remat_f = jax.checkpoint(f) # The computation of grad_f computes "sin" 5 times, 3 for the forward pass # and then to rematerialize "x2" and "x3" in the backward pass. arg = np.array(3.) f_tf = jax2tf.convert(jax.grad(remat_f)) f_tf_hlo = self.TfToHlo(f_tf, arg) - if config.remat_opt_barrier.value: - self.assertRegex(f_tf_hlo, r"opt-barrier") - else: - self.assertRegex(f_tf_hlo, - r'transpose/jax2tf_f_/jvp/checkpoint/cond/branch_1_fun/Sin') + self.assertRegex(f_tf_hlo, r"opt-barrier") def test_remat_free_var(self): def f(x): y = 2 * x - @ad_checkpoint.checkpoint + @jax.checkpoint def g(): return y @@ -958,12 +966,7 @@ def caller_jax(x): out = jax2tf.convert(caller_jax, with_gradient=False)(2.) return out - if config.jax2tf_default_native_serialization.value: - self.assertIn("my_test_function_jax/mul", self.TfToHlo(run_tf)) - else: - graph_def = str(tf.function(run_tf, autograph=False).get_concrete_function().graph.as_graph_def()) - if "my_test_function_jax/pjit_multiply_/Mul" not in graph_def: - self.assertIn("my_test_function_jax/jit_multiply_/Mul", graph_def) + self.assertIn("my_test_function_jax/mul", self.TfToHlo(run_tf)) def test_bfloat16_constant(self): # Re: https://github.com/jax-ml/jax/issues/3942 @@ -985,100 +988,6 @@ def jax_fn_array(x): tf_fn_array(np.array([3, 4, 5])), np.array([4.5, 10, 17.5], jnp.bfloat16)) - def test_shared_constants(self): - # Check that the constants are shared properly in converted functions - # See https://github.com/jax-ml/jax/issues/7992. - if config.jax2tf_default_native_serialization.value: - raise unittest.SkipTest("shared constants tests not interesting for native serialization") - const = np.random.uniform(size=256).astype(np.float32) # A shared constant - def f(x): - return x + const + const + const + const - - f_tf_consts = self.FindLargeTfConstants(jax2tf.convert(f), const) - self.assertLen(f_tf_consts, 1) - - def test_shared_constants_under_cond(self): - # Check that the constants are shared properly in converted functions - # See https://github.com/jax-ml/jax/issues/7992. - if config.jax2tf_default_native_serialization.value: - raise unittest.SkipTest("shared constants tests not interesting for native serialization") - const_size = 512 - const = np.random.uniform(size=const_size).astype(np.float32) # A shared constant - x = np.ones((const_size,), dtype=np.float32) - def f1(x): - # Ensure that we first see the constants in the inside jaxpr - return lax.cond(x[0] >= 0., lambda x: x + const, lambda x: x * const, x) + const - def f2(x): - return f1(x) + const # The extra const should not cost anything - f1_consts = self.FindLargeTfConstants(jax2tf.convert(f1), x, at_least=const_size) - f2_consts = self.FindLargeTfConstants(jax2tf.convert(f2), x, at_least=const_size) - self.assertLen(f2_consts, len(f1_consts)) - - def test_shared_constants_under_scan(self): - # See https://github.com/jax-ml/jax/issues/7992. - if config.jax2tf_default_native_serialization.value: - raise unittest.SkipTest("shared constants tests not interesting for native serialization") - const_size = 512 - const = np.random.uniform(size=const_size).astype(np.float32) # A shared constant - xs = np.ones((8, const_size), dtype=np.float32) - def f1(xs): - res, _ = lax.scan(lambda carry, x: (carry + x + const, None), - jnp.zeros((const_size,), dtype=np.float32), xs) - return res - - def f2(xs): - return f1(xs) + const # The extra const should not be saved - - f1_consts = self.FindLargeTfConstants(jax2tf.convert(f1), xs, at_least=const_size) - f2_consts = self.FindLargeTfConstants(jax2tf.convert(f2), xs, at_least=const_size) - self.assertLen(f2_consts, len(f1_consts)) - - def test_shared_constants_under_jit(self): - # We do not share constants under jit. - if config.jax2tf_default_native_serialization.value: - raise unittest.SkipTest("shared constants tests not interesting for native serialization") - const = np.random.uniform(size=(16, 16)).astype(np.float32) # A shared constant - @jax.jit - def g_jit(x): - return x * const - def f(x): - return g_jit(x) + const + const - - f_tf_graph_consts = self.FindLargeTfConstants(jax2tf.convert(f), const) - self.assertLen(f_tf_graph_consts, 1) - - def test_shared_constants_randint(self): - # randint has the property that the TF lowering of the randbits_p - # primitive generates constants that did not exist in the Jaxpr. As such - # it has created new errors related to the sharing of the constants. - if config.jax2tf_default_native_serialization.value: - raise unittest.SkipTest("shared constants tests not interesting for native serialization") - - key = jax.random.PRNGKey(42) - - def f_nested_jax(x): - # Lowering this will generate a tf.constant(shape=(1,), dtype=np.int32) - # that was not already in the Jaxpr, and hence JAX did not get a chance - # to share. - return x + jax.random.randint(key, shape=x.shape, - minval=0, maxval=100, dtype=np.int32) - def f_jax(x): - res = lax.cond(x[0] >= 2, lambda: f_nested_jax(x), lambda: f_nested_jax(x)) - res += lax.while_loop(lambda x: f_nested_jax(x)[0] <= 0, f_nested_jax, x) - # We also generate tf.while in the batching rule for cond - res += jax.vmap(lambda x: lax.cond(x[0] >= 2, - lambda: f_nested_jax(x), - lambda: f_nested_jax(x)))(jnp.stack([x, x])) - res += f_nested_jax(x) - return res - - # Must be odd to trigger the failure - x = np.array([123, 456, 789], dtype=np.int32) - - f_tf = tf.function(jax2tf.convert(f_jax), autograph=False) - res_tf = f_tf(x) - self.assertAllClose(res_tf, f_jax(x)) - def test_weak_types(self): mul = jax.jit(jnp.multiply) # The value `2` here should be weakly typed, and should not lead to @@ -1131,7 +1040,8 @@ def test_op_metadata_simple(self): self.skipTest("include_xla_op_metadata not yet enabled") # A simple example # The user_frame is used to compute line numbers for ops in the test. - user_frame = source_info_util.user_frame(source_info_util.current()) + user_frame = source_info_util.user_frame( + source_info_util.current().traceback) def f_simple(x): return jnp.sin(x) @@ -1150,7 +1060,8 @@ def test_op_metadata_sub_jit(self): self.skipTest("include_xla_op_metadata not yet enabled") # Calling a jitted-function # The user_frame is used to compute line numbers for ops in the test. - user_frame = source_info_util.user_frame(source_info_util.current()) + user_frame = source_info_util.user_frame( + source_info_util.current().traceback) def f_callee(x): return jnp.cos(x) def f_caller(x): @@ -1184,7 +1095,8 @@ def test_op_metadata_named(self): self.skipTest("include_xla_op_metadata not yet enabled") # Calling a jax.named_call # The user_frame is used to compute line numbers for ops in the test. - user_frame = source_info_util.user_frame(source_info_util.current()) + user_frame = source_info_util.user_frame( + source_info_util.current().traceback) def f_callee(x): return jnp.cos(x) def f_caller(x): @@ -1218,7 +1130,8 @@ def test_op_metadata_while_and_cond(self): self.skipTest("include_xla_op_metadata not yet enabled") # An example with while and cond # The user_frame is used to compute line numbers for ops in the test. - user_frame = source_info_util.user_frame(source_info_util.current()) + user_frame = source_info_util.user_frame( + source_info_util.current().traceback) def f_while_cond(x): def body_fun(i_acc): i, acc = i_acc @@ -1259,7 +1172,8 @@ def test_op_metadata_batched_while(self): self.skipTest("include_xla_op_metadata not yet enabled") # An example with while and cond # The user_frame is used to compute line numbers for ops in the test. - user_frame = source_info_util.user_frame(source_info_util.current()) + user_frame = source_info_util.user_frame( + source_info_util.current().traceback) @jax.vmap def f_while(x): def body_fun(carry): @@ -1307,7 +1221,7 @@ def f_simple(x): include_xla_op_metadata=False ) - def assertAllOperationStartWith(self, g: tf.Graph, scope_name: str): + def assertAllOperationStartWith(self, g: "tf.Graph", scope_name: str): """Assert all operations name start with ```scope_name```. Also the scope_name only occur one time. @@ -1321,9 +1235,9 @@ def assertAllOperationStartWith(self, g: tf.Graph, scope_name: str): self.fail(f"{op.name} does not start with {scope_name}.") def test_name_scope_polymorphic(self): - if (config.jax2tf_default_native_serialization.value and - not config.dynamic_shapes.value): - self.skipTest("shape polymorphism but --jax_dynamic_shapes is not set.") + self.skipTest("no more dynamic shapes") + # if not config.dynamic_shapes.value: + # self.skipTest("shape polymorphism but --jax_dynamic_shapes is not set.") def func_jax(x, y): return jnp.sin(x) + jnp.cos(y) @@ -1412,42 +1326,33 @@ def body(x): @parameterized.named_parameters( dict(testcase_name=( - f"{'with_mesh_' if with_mesh else ''}" f"2={transform2 if transform2 != 'none' else ''}" f"_1={transform1 if transform1 != 'none' else ''}" f"{'_nullary' if nullary else ''}"), - with_mesh=with_mesh, transform1=transform1, - transform2=transform2, nullary=nullary) + transform1=transform1, transform2=transform2, nullary=nullary) # Test transform2(transform1(func) for transform1 in [ "none", - "jit", - "pjit", "pjit_in_shardings_None", "pjit_in_shardings_P", - "pjit_in_shardings_Sharding", "shard_map", "pmap"] + "jit", "jit_in_shardings_None", + "jit_in_shardings_Sharding", "shard_map", "pmap"] for transform2 in ( - ["none", "pjit_in_shardings_None", "pjit_in_shardings_P", - "pjit_in_shardings_Sharding"] + ["none", "jit_in_shardings_None", + "jit_in_shardings_Sharding"] ) # Whether the function can be nullary for nullary in ( # To reduce the number of tests [True, False] if transform2 == "none" else [False]) - # Whether we use a "with mesh" - for with_mesh in ( - [True] if (transform1 not in ["base", "jit", "pjit"] or - transform2 != "none") else - [False, True]) ) - def test_cross_platform(self, with_mesh=True, transform1="pjit_in_shardings_P", - transform2="pjit_in_shardings_P", nullary=False): - # Tests cross-lowering for - # with mesh: - # transform2(transform1(func)) + def test_cross_platform(self, + transform1="jit_in_shardings_P", + transform2="jit_in_shardings_P", nullary=False): + # Tests cross-lowering for transform2(transform1(func)) if transform2 == "none" and ( transform1 == "shard_map" or - transform1 in ["pjit_in_shardings_P", "pjit_in_shardings_Sharding"] and nullary): - raise unittest.SkipTest("Skip because must have pjit at top level") + transform1 in ["jit_in_shardings_P", "jit_in_shardings_Sharding"] and nullary): + raise unittest.SkipTest("Skip because must have jit at top level") x = np.ones((4, 6), dtype=np.float32) mesh = sharding.Mesh(jax.devices()[:1], ("a",)) @@ -1456,27 +1361,19 @@ def test_cross_platform(self, with_mesh=True, transform1="pjit_in_shardings_P", # For shard_map we cannot use cummax :-( because it does not have a # replication rule. But we use lax.all_gather which on TPU is lowered with # an all-gather op - func_shard_map = lambda x: lax.all_gather(x, 'a', axis=1, tiled=True) + func_shard_map = lambda x: lax.all_gather(x, "a", axis=1, tiled=True) def apply_transform(func, transform: str): transformed_func = dict( none=func, jit=jax.jit(func), jit_in_shardings_None=jax.jit(func, in_shardings=None), - jit_in_shardings_P=jax.jit(func, in_shardings=(P("a"),)), jit_in_shardings_Sharding=jax.jit( - func, in_shardings=(sharding.NamedSharding(mesh, P("a")),)), - pjit=pjit.pjit(func), - pjit_in_shardings_None=pjit.pjit(func, in_shardings=None, - out_shardings=None), - pjit_in_shardings_P=pjit.pjit(func, in_shardings=(P("a"),), - out_shardings=P("a")), - pjit_in_shardings_Sharding=pjit.pjit( func, in_shardings=(sharding.NamedSharding(mesh, P("a")),), out_shardings=sharding.NamedSharding(mesh, P("a"))), shard_map=( - shard_map(func, mesh, in_specs=(P("a", None),), + shard_map(func, mesh=mesh, in_specs=(P("a", None),), out_specs=P("a", None))), pmap=jax.pmap(func, in_axes=0, out_axes=0), )[transform] @@ -1503,19 +1400,12 @@ def apply_transform(func, transform: str): raise unittest.SkipTest("Cannot lower nested pmap: jit-of-pmap warning") raise unittest.SkipTest("TODO: figure out how to invoke pmap from TF") - f_tf = jax2tf.convert(func_to_convert, - native_serialization=True, - native_serialization_platforms=('tpu',)) - f_tf = tf.function(f_tf, jit_compile=True, autograph=False) - with contextlib.ExitStack() as stack: - if with_mesh: - stack.enter_context(mesh) - # Run the JAX native version, to check it works, and to fill caches. - _ = func_to_convert(*args) - exported = export.export( - (jax.jit(func_to_convert) if not hasattr(func_to_convert, "trace") else func_to_convert), - platforms=("tpu",) - )(*(core.ShapedArray(a.shape, a.dtype) for a in args)) + # Run the JAX native version, to check it works, and to fill caches. + _ = func_to_convert(*args) + exported = export.export( + (jax.jit(func_to_convert) if not hasattr(func_to_convert, "trace") else func_to_convert), + platforms=("tpu",) + )(*(core.ShapedArray(a.shape, a.dtype) for a in args)) if transform1 == "shard_map": self.assertIn("stablehlo.all_gather", str(exported.mlir_module())) @@ -1523,7 +1413,7 @@ def apply_transform(func, transform: str): self.assertIn("stablehlo.reduce_window", str(exported.mlir_module())) def test_cross_platform_error(self): - f_tf = jax2tf.convert(jnp.sin, native_serialization=True, + f_tf = jax2tf.convert(jnp.sin, native_serialization_platforms=('tpu',)) x = np.float32(.5) if jtu.test_device_matches(["tpu"]): @@ -1537,7 +1427,6 @@ def test_cross_platform_error(self): "The current platform .* is not among the platforms required by the module"): f_tf(x) - @jtu.ignore_warning(message="using native_serialization_platforms without native_serialization") def test_native_parameters_for_non_native(self): # We can use the native_serialization_platforms even for non-native # serialization. @@ -1558,7 +1447,7 @@ def test_native_parameters_for_non_native(self): def test_native_serialization_grad(self): # Check that the grad function uses the same native serialization parameters # as the primal function. - f_tf = jax2tf.convert(jnp.sin, native_serialization=True, + f_tf = jax2tf.convert(jnp.sin, native_serialization_platforms=('tpu',)) x = np.arange(4, dtype=np.float32) x_v = tf.Variable(x) @@ -1590,7 +1479,7 @@ def f_jax(x): with self.assertRaisesRegex(NotImplementedError, "serialization of host_callbacks is not yet implemented"): - jax2tf.convert(f_jax, native_serialization=True)(np.float32(42.)) + jax2tf.convert(f_jax)(np.float32(42.)) def f_ordered_jax(x): jax.debug.print("{}", x, ordered=True) @@ -1598,7 +1487,7 @@ def f_ordered_jax(x): with self.assertRaisesRegex(NotImplementedError, "serialization of host_callbacks is not yet implemented"): - jax2tf.convert(f_ordered_jax, native_serialization=True)(np.float32(42.)) + jax2tf.convert(f_ordered_jax)(np.float32(42.)) def test_tuple_args(self): # On TPU if we have more than 2000 arguments, we pass them as a tuple. @@ -1615,11 +1504,9 @@ def f_jax(*many_args): # Test that we do set lowered.compile_args[tuple_args] lowered = jax.jit(f_jax).lower(*many_args) self.assertTrue(lowered._lowering.compile_args["tuple_args"]) - res = jax2tf.convert(f_jax, native_serialization=True)(*many_args) + res = jax2tf.convert(f_jax)(*many_args) self.assertAllClose(f_jax(*many_args), res) - @jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor", - category=DeprecationWarning) def test_nested_convert(self): # Test call sequence: convert -> call_tf -> convert. @@ -1631,13 +1518,13 @@ def f_jax(x): res = f_jax(inputs) - f_tf = jax2tf.convert(f_jax, native_serialization=True) + f_tf = jax2tf.convert(f_jax) self.assertAllClose(res, f_tf(inputs)) f_jax_nested = jax2tf.call_tf(f_tf) self.assertAllClose(res, f_jax_nested(inputs)) - f_tf_nested = jax2tf.convert(f_jax_nested, native_serialization=True) + f_tf_nested = jax2tf.convert(f_jax_nested) self.assertAllClose(res, f_tf_nested(inputs)) def test_multi_platform(self): @@ -1659,7 +1546,6 @@ def f_jax(x): x = np.float32(.42) f_tf = jax2tf.convert( f_jax, - native_serialization=True, native_serialization_platforms=("cpu", "cuda", "tpu")) for tf_device in self.tf_devices: logging.info( @@ -1686,31 +1572,62 @@ def test_dot_algorithm(self): def f_jax(x): return jax.lax.dot(x, x, precision=algorithm) - f_tf = jax2tf.convert(f_jax, native_serialization=True) + f_tf = jax2tf.convert(f_jax) f_tf(np.ones((128, 128), dtype=np.float32)) # no crash - def test_dot_algorithm_non_native_unsupported(self): - def f_jax(x): - return jax.lax.dot(x, x, precision="F32_F32_F32") - - x = np.ones((128, 128), dtype=np.float32) - with self.assertRaisesRegex(NotImplementedError, - "Unsupported precision in dot_general"): - jax2tf.convert(f_jax, native_serialization=False)(x) + def test_jvp_through_loop(self): + # Context: b/388929258 + + num_actions = 512 + + def tf_preprocessor(features): + features["num_c_actions"] = tf.constant(256, tf.int32) + return features + + def postprocessor(prob, features): + actions = jnp.arange(num_actions, dtype=jnp.int32) + r = actions // features["num_c_actions"] + c = actions - r * features["num_c_actions"] + rr = jnp.array([0.12, 0.3])[r] * prob + rc = (jnp.arange(256) * 0.7)[c] * prob + return rr, rc + + def loop_step(features, params): + features = jax2tf.call_tf(tf_preprocessor)(features) + odds = features["f1"] @ params["w1"] + features["f2"] @ params["w2"] + prob = jax.nn.sigmoid(odds) + rr, rc = postprocessor(prob, features) + new_f1 = jnp.mean(rr, keepdims=True) + new_f2 = jnp.mean(rc, keepdims=True) + return new_f1, new_f2 + + def loop(init_features, params): + def body(carry, unused_x): + f1, f2 = carry + return loop_step({"f1": f1, "f2": f2}, params), None + + (rr, rc), _ = jax.lax.scan( + body, (init_features["f1"], init_features["f2"]), length=10 + ) + return rr, rc + + def loss(features, params): + rr, rc = loop(features, params) + return jnp.mean((rr - rc) ** 2) + + jax.grad(loss, argnums=(1,))( + {"f1": jnp.array([0.5]), "f2": jnp.array([0.7])}, + { + "w1": jnp.ones((1, num_actions)) * 0.01, + "w2": jnp.ones((1, num_actions)) * 0.01, + }, + ) +@unittest.skipIf(tf is None, "Test requires tensorflow") @jtu.with_config(jax_enable_custom_prng=True) -class Jax2tfWithCustomPRNGTest(tf_test_util.JaxToTfTestCase): - def setUp(self): - super().setUp() - self.warning_ctx = jtu.ignore_warning( - message="jax2tf.convert with native_serialization=False has been deprecated" - ) - self.warning_ctx.__enter__() - - def tearDown(self): - self.warning_ctx.__exit__(None, None, None) - super().tearDown() +@jtu.thread_unsafe_test_class() +class Jax2tfWithCustomPRNGTest(JaxToTfTestCase): def test_key_argument(self): func = lambda key: jax.random.uniform(key, ()) @@ -1738,15 +1655,21 @@ def func(): jax_result = func() self.assertEqual(tf_result, jax_result) -class Jax2TfVersioningTest(tf_test_util.JaxToTfTestCase): + +@unittest.skipIf(tf is None, "Test requires tensorflow") +@jtu.thread_unsafe_test_class() +class Jax2TfVersioningTest(JaxToTfTestCase): # Use a separate test case with the default jax_serialization_version def setUp(self): self.use_max_serialization_version = False + versions = tf.version.VERSION.split(".") + if versions < ["2", "19", "1"]: + # StableHLO changed on March 18th, 2025 ,to version 1.10.0, and this + # introduces ops like vhlo_sine_v2. These ops require a TF version + # released after this date. + self.skipTest("Need version of TensorFlow at least 2.19.1") super().setUp() - @jtu.ignore_warning( - message="jax2tf.convert with native_serialization=False has been deprecated" - ) def test_simple(self): self.ConvertAndCompare(jnp.sin, 0.7) diff --git a/jax/experimental/jax2tf/tests/multiprocess/BUILD b/jax/experimental/jax2tf/tests/multiprocess/BUILD new file mode 100644 index 000000000000..58b185b803c7 --- /dev/null +++ b/jax/experimental/jax2tf/tests/multiprocess/BUILD @@ -0,0 +1,33 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//jaxlib:jax.bzl", "jax_multiprocess_generate_backend_suites", "jax_multiprocess_test", "py_deps") + +licenses(["notice"]) + +package(default_applicable_licenses = []) + +jax_multiprocess_generate_backend_suites() + +jax_multiprocess_test( + name = "jax2tf_multiprocess_test", + srcs = ["jax2tf_multiprocess_test.py"], + main = "jax2tf_multiprocess_test.py", + tags = ["jax2tf"], + deps = [ + "//jax/_src:test_multiprocess", + "//jax/experimental/jax2tf", + "//jax/experimental/jax2tf/tests:tf_test_util", + ] + py_deps("tensorflow"), +) diff --git a/jax/experimental/jax2tf/tests/multiprocess/jax2tf_multiprocess_test.py b/jax/experimental/jax2tf/tests/multiprocess/jax2tf_multiprocess_test.py new file mode 100644 index 000000000000..fa4861f55d92 --- /dev/null +++ b/jax/experimental/jax2tf/tests/multiprocess/jax2tf_multiprocess_test.py @@ -0,0 +1,81 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multihost test for JAX2TF.""" + +import jax +from jax import numpy as jnp +from jax._src import pjit +from jax._src import test_multiprocess as jt_multiprocess +from jax._src import test_util as jtu +from jax.experimental import multihost_utils +from jax.sharding import PartitionSpec as P +import unittest +import warnings + +try: + # TODO(b/470156950): Remove this once a proper fix is in place + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", + category=FutureWarning, + message=".*np.object.*") + import tensorflow as tf + from jax.experimental import jax2tf + from jax.experimental.jax2tf.tests import tf_test_util + JaxToTfTestCase = tf_test_util.JaxToTfTestCase +except ImportError: + tf = None + jax2tf = None # type: ignore[assignment] + tf_test_util = None # type: ignore[assignment] + JaxToTfTestCase = jtu.JaxTestCase # type: ignore[misc] + + +@unittest.skipIf(tf is None, "Test requires tensorflow.") +class Jax2TfMultiProcessTest(JaxToTfTestCase, jt_multiprocess.MultiProcessTest): + + def test_multi_process_pjit_export(self): + """Pjitted function can be exported.""" + key_w, key_x = jax.random.split(jax.random.PRNGKey(1234), 2) + w = jax.random.uniform(key_w, [16, 16], dtype=jnp.float32) + x = jax.random.uniform(key_x, [16, 1], dtype=jnp.float32) + + with jtu.create_mesh((4, 2), ("x", "y")): + pjit_matmul = pjit.pjit(jnp.matmul, in_shardings=(P("x", "y"), None)) + jax_result = multihost_utils.process_allgather( + pjit_matmul(w, x), tiled=True) + + tf_model = tf.Module() + tf_model.w = tf.Variable(w) + tf_closure = tf.function( + lambda x: {"y": jax2tf.convert(pjit_matmul)(tf_model.w, x)}, + autograph=False, + ).get_concrete_function( + tf.TensorSpec.from_tensor(tf.constant(x), name="x") + ) + + if jax.process_index() == 0: + export_dir = self.create_tempdir().full_path + tf.saved_model.save( + tf_model, + export_dir, + signatures={"serving_default": tf_closure}, + ) + loaded = tf.saved_model.load(export_dir) + tf_result = loaded.signatures["serving_default"](x=x)["y"] + + self.assertAllClose(tf_result.numpy(), jax_result) + + +if __name__ == "__main__": + jt_multiprocess.main() diff --git a/jax/experimental/jax2tf/tests/primitives_test.py b/jax/experimental/jax2tf/tests/primitives_test.py index 1ccd009f157c..235b7da6bc11 100644 --- a/jax/experimental/jax2tf/tests/primitives_test.py +++ b/jax/experimental/jax2tf/tests/primitives_test.py @@ -29,10 +29,6 @@ in Tensorflow errors (for some devices and compilation modes). These limitations are captured as jax2tf_limitations.Jax2TfLimitation objects. -From the limitations objects, we generate a -[report](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md). -The report has instructions for how to re-generate it. - If a harness run fails with error, and a limitation that matches the device and data types is found, the error is logged but does not abort the test. If a harness run succeeds @@ -51,8 +47,6 @@ """ -import datetime -import os from typing import Any import unittest @@ -65,11 +59,8 @@ from jax import numpy as jnp from jax._src import config from jax._src import test_util as jtu -from jax.experimental import jax2tf -from jax.interpreters import mlir import numpy as np -import tensorflow as tf config.parse_flags_with_absl() @@ -90,6 +81,7 @@ ) +@jtu.thread_unsafe_test_class() class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase): # This test runs for all primitive harnesses. For each primitive "xxx" the @@ -118,27 +110,23 @@ def test_prim(self, harness: test_harnesses.Harness): device == "tpu"): raise unittest.SkipTest("b/264716764: error on tf.cast from c64 to f32") - if ("eigh" == harness.group_name and - device == "cpu"): + if "eigh" == harness.group_name and device == "cpu": raise unittest.SkipTest( "Equality comparisons on eigendecompositions are not stable.") - if (config.jax2tf_default_native_serialization.value and - device == "gpu" and - "lu" in harness.fullname): + if device == "gpu" and "lu" in harness.fullname: raise unittest.SkipTest("b/269388847: lu failures on GPU") def skipCustomCallTest(target: str): raise unittest.SkipTest( f"TODO(b/272239584): custom call target not guaranteed stable: {target}") - if config.jax2tf_default_native_serialization.value: - if device == "gpu": - if "custom_linear_solve_" in harness.fullname: - skipCustomCallTest("cusolver_geqrf, cublas_geqrf_batched") - if "svd_shape" in harness.fullname: - skipCustomCallTest("cusolver_gesvdj") - if "tridiagonal_solve_shape" in harness.fullname: - skipCustomCallTest("cusparse_gtsv2_f32, cusparse_gtsv2_f64") + if device == "gpu": + if "custom_linear_solve_" in harness.fullname: + skipCustomCallTest("cusolver_geqrf, cublas_geqrf_batched") + if "svd_shape" in harness.fullname: + skipCustomCallTest("cusolver_gesvdj") + if "tridiagonal_solve_shape" in harness.fullname: + skipCustomCallTest("cusparse_gtsv2_f32, cusparse_gtsv2_f64") associative_scan_reductions = harness.params.get("associative_scan_reductions", False) try: @@ -146,161 +134,12 @@ def skipCustomCallTest(target: str): self.ConvertAndCompare(func_jax, *args, limitations=limitations) except Exception as e: # TODO(b/264596006): custom calls are not registered properly with TF in OSS - if (config.jax2tf_default_native_serialization.value and - "does not work with custom calls" in str(e)): + if "does not work with custom calls" in str(e): logging.warning("Suppressing error %s", e) raise unittest.SkipTest("b/264596006: custom calls in native serialization fail in TF") else: raise e - def test_primitive_coverage(self): - """Fail if there are JAX primitives that are not implemented.""" - # Harvest primitives from XLA translation tables - all_primitives = ( - set(mlir._lowerings) - | set(mlir._platform_specific_lowerings["cpu"]) - | set(mlir._platform_specific_lowerings["gpu"]) - | set(mlir._platform_specific_lowerings["tpu"])) - - tf_impl = set(jax.experimental.jax2tf.jax2tf.tf_impl) | set( - jax.experimental.jax2tf.jax2tf.tf_impl_with_avals) - tf_not_yet_impl = set(jax.experimental.jax2tf.jax2tf.tf_not_yet_impl) - - all_primitives = tuple(sorted(all_primitives, key=str)) - for p in all_primitives: - if p.name == "axis_index": - continue - if p.name == "composite": - continue - if p.name == "sharding_constraint": - continue - if p.name == "mesh_cast": - continue - if p.name == "reshard": - continue - # TODO: Remove once tensorflow is 2.10.0 everywhere. - if p.name == "optimization_barrier": - continue - if p.name == "debug_callback" or p.name == "debug_print": - # TODO(sharadmv,necula): enable debug callbacks in TF - continue - if p.name in ("max_contiguous", "multiple_of", "run_scoped"): - # Pallas-specific primitives are not supported. - continue - if p.name == "pallas_call": - continue - if p.name == "ragged_all_to_all": - continue - if p.name == "ffi_call": - continue - if p.name == "tpu_custom_call": - continue - if p.name == "custom_partitioning": - continue - if p.name in ( - "dot_product_attention_fwd", - "dot_product_attention_bwd", - "dot_product_attention_fwd_wrapper", - "dot_product_attention_bwd_wrapper", - "dot_product_attention_fp8_fwd_wrapper", - "dot_product_attention_fp8_bwd_wrapper", - ): - continue - if p.name == "scaled_matmul_wrapper": - continue - if p.name in tf_not_yet_impl: - self.assertNotIn( - p, tf_impl) # Should not be in both tf_impl and tf_not_yet_impl - else: - self.assertIn(p, tf_impl) - - def test_generate_limitations_doc(self): - """Generates primitives_with_limited_support.md. - - See the doc for instructions. - """ - - harnesses = [ - h for h in test_harnesses.all_harnesses - if h.filter(h, include_jax_unimpl=True) - ] - print(f"Found {len(harnesses)} test harnesses that work in JAX") - - def unique_hash(h: test_harnesses.Harness, l: Jax2TfLimitation): - return (h.group_name, l.description, l.devices, - tuple(np.dtype(d).name for d in l.dtypes), l.modes) - - unique_limitations: dict[Any, tuple[test_harnesses.Harness, Jax2TfLimitation]] = {} - for h in harnesses: - for l in h.jax_unimplemented: - if l.enabled: - # Fake a Jax2TFLimitation from the Limitation - tfl = Jax2TfLimitation(description="Not implemented in JAX: " + l.description, - devices = l.devices, - dtypes = l.dtypes, - expect_tf_error = False, - skip_tf_run = True) - unique_limitations[hash(unique_hash(h, tfl))] = (h, tfl) - for h in harnesses: - for l in Jax2TfLimitation.limitations_for_harness(h): - unique_limitations[hash(unique_hash(h, l))] = (h, l) - - print(f"Found {len(unique_limitations)} unique limitations") - tf_error_table = [ - """ -| Affected primitive | Description of limitation | Affected dtypes | Affected devices | Affected compilation modes | -| --- | --- | --- | --- | --- |""" - ] - tf_numerical_discrepancies_table = list(tf_error_table) # a copy - for h, l in sorted( - unique_limitations.values(), key=lambda pair: unique_hash(*pair)): - devices = ", ".join(sorted(l.devices)) - modes = ", ".join(sorted(l.modes)) - description = l.description - if l.skip_comparison: - description = "Numeric comparison disabled: " + description - if l.expect_tf_error: - description = "TF error: " + description - if l.skip_tf_run: - description = "TF test skipped: " + description - - if l.skip_tf_run or l.expect_tf_error: - to_table = tf_error_table - elif l.skip_comparison or l.custom_assert: - to_table = tf_numerical_discrepancies_table - else: - continue - - to_table.append( - f"| {h.group_name} | {description} | " - f"{test_harnesses.dtypes_to_str(l.dtypes, empty_means_all=True)} | {devices} | {modes} |" - ) - - if not os.environ.get("JAX_OUTPUT_LIMITATIONS_DOC"): - raise unittest.SkipTest( - "Set JAX_OUTPUT_LIMITATIONS_DOC=1 to enable the generation of the documentation" - ) - # The CPU has more supported types, and harnesses - self.assertEqual("cpu", jtu.device_under_test()) - self.assertTrue( - config.enable_x64.value, - "Documentation generation must be run with JAX_ENABLE_X64=1") - - with open( - os.path.join( - os.path.dirname(__file__), - "../g3doc/primitives_with_limited_support.md.template")) as f: - template = f.read() - output_file = os.path.join( - os.path.dirname(__file__), - "../g3doc/primitives_with_limited_support.md") - - with open(output_file, "w") as f: - f.write(template.replace("{{generation_date}}", str(datetime.date.today())) \ - .replace("{{tf_error_table}}", "\n".join(tf_error_table)) \ - .replace("{{tf_numerical_discrepancies_table}}", "\n".join(tf_numerical_discrepancies_table)) \ - ) - # The rest of the test are checking special cases @parameterized.named_parameters( @@ -320,19 +159,6 @@ def test_type_promotion(self, f_jax=jnp.add): y = np.array([3, 4], dtype=y_dtype) self.ConvertAndCompare(f_jax, x, y) - def test_integer_div(self): - x = jnp.array([-4, -3, -1, 0, 1, 3, 6]) - y = np.int32(3) - self.ConvertAndCompare(jnp.floor_divide, x, y) - expected = jnp.floor_divide(x, y) - if not config.jax2tf_default_native_serialization.value: - # With native serialization TF1 seems to want to run the converted code - # on the CPU even when the default backend is the TPU. - # Try it with TF 1 as well (#5831) - with tf.compat.v1.Session() as sess: - tf1_res = sess.run(jax2tf.convert(jnp.floor_divide)(x, y)) - self.assertAllClose(expected, tf1_res) - def test_boolean_gather(self): values = np.array([[True, True], [False, True], [False, False]], dtype=np.bool_) diff --git a/jax/experimental/jax2tf/tests/savedmodel_test.py b/jax/experimental/jax2tf/tests/savedmodel_test.py index aee15883332a..16d83879c483 100644 --- a/jax/experimental/jax2tf/tests/savedmodel_test.py +++ b/jax/experimental/jax2tf/tests/savedmodel_test.py @@ -28,6 +28,7 @@ jax.config.parse_flags_with_absl() +@jtu.thread_unsafe_test_class() class SavedModelTest(tf_test_util.JaxToTfTestCase): def setUp(self): diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index 09da97e8420a..08c506da946f 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -17,34 +17,22 @@ from collections.abc import Callable, Sequence import contextlib import math +import re from typing import Any -import unittest from absl import logging from absl.testing import absltest -import collections -import functools -from functools import partial -import operator as op -import re - import jax from jax.experimental import jax2tf -from jax.experimental import pjit from jax import export from jax import lax import jax.numpy as jnp from jax import random -from jax import tree_util -from jax._src import api_util from jax._src import config from jax._src import core from jax._src import test_util as jtu from jax._src import util -from jax._src.export import shape_poly -from jax._src.lax import lax as lax_internal -from jax._src.lax import control_flow as lax_control_flow import numpy as np from jax.experimental.jax2tf.tests import tf_test_util @@ -55,7 +43,7 @@ # Import after parsing flags from jax._src.internal_test_util import test_harnesses -from jax._src.internal_test_util.test_harnesses import Harness, CustomArg, RandArg, StaticArg +from jax._src.internal_test_util.test_harnesses import Harness, RandArg from jax.experimental.jax2tf.tests.jax2tf_limitations import Jax2TfLimitation _f32 = np.float32 @@ -244,6 +232,7 @@ def check_shape_poly(tst, f_jax: Callable, *, return h.run_test(tst) +@jtu.thread_unsafe_test_class() class ShapePolyTest(tf_test_util.JaxToTfTestCase): def test_simple_unary(self): @@ -296,8 +285,7 @@ def f_jax(x, y): expected_output_signature=( # for native serialization we cannot refine the inferred shape of the # output if the input is more specific than polymorphic_shapes. - tf.TensorSpec([2, 3]) if not config.jax2tf_default_native_serialization.value - else tf.TensorSpec([2, None]))) + tf.TensorSpec([2, None]))) check_shape_poly(self, f_jax, @@ -305,72 +293,6 @@ def f_jax(x, y): polymorphic_shapes=["h, h", "h, h"], expected_output_signature=tf.TensorSpec([None, None])) - @jtu.parameterized_filterable( - # make_args invoked with op.shape[0]: start, stop, step, dtype - kwargs=[ - dict(testcase_name=name, make_args=make_args, expect_error=expect_error, expect_msg=expect_msg) - for name, make_args, expect_error, expect_msg in [ - # make_args invoked with op.shape[0]: start, stop, step, dtype - ("float_start", lambda b: (0., b, None), - ValueError, "must be either dimension expressions or integers"), - ("float_step", lambda b: (0, b, 0.5), - ValueError, "must be either dimension expressions or integers"), - ("step_0", lambda b: (0, b, 0), - ValueError, "has step == 0"), - ("inconclusive_step_sign", lambda b: (0, b, b - 2), - core.InconclusiveDimensionOperation, - "must be resolved statically if it is > 0 or < 0"), - ] - ] - ) - def test_arange_error(self, make_args=lambda b: (0., b, 2), - expect_error=ValueError, - expect_msg="must be either dimension expressions or integers"): - def f_jax(x): # x: i32[b] - return x[0] + jnp.arange(*(make_args(x.shape[0]))) - x = np.ones((3,), dtype=np.int32) - with self.assertRaisesRegex(expect_error, expect_msg): - check_shape_poly(self, f_jax, arg_descriptors=[x], - polymorphic_shapes=["b"]) - - @jtu.parameterized_filterable( - kwargs=[ - dict(testcase_name=f"expr={name}", expr=expr) - for name, expr in [ - ("d + 2", lambda d: d + 2), - ("2 - d", lambda d: 2 - d), - ("d * 2", lambda d: d * 2), - ("d * d", lambda d: d * d), - ("(- d) * d", lambda d: (- d) * d), - ("d * d - d", lambda d: d * d - d), - # Division - ("d // 2", lambda d: d // 2), - ("(d + 1) // 2", lambda d: (d + 1) // 2), - ("d // -2", lambda d: d // -2), - ("(d + 1) // -2", lambda d: (d + 1) // -2), - ("(-d) // 2", lambda d: (-d) // 2), - ("(-d - 1) // 2", lambda d: (-d - 1) // 2), - ("(-d) // -2", lambda d: (-d) // -2), - ("(-d - 1) // -2", lambda d: (-d - 1) // -2), - # Remainder - ("d % 2", lambda d: d % 2), - ("(d + 1) % 2", lambda d: (d + 1) % 2), - ("d % -2", lambda d: d % -2), - ("(d + 1) % -2", lambda d: (d + 1) % -2), - ("(-d) % 2", lambda d: (-d) % 2), - ("(-d - 1) % 2", lambda d: (-d - 1) % 2), - ("(-d) % -2", lambda d: (-d) % -2), - ("(-d - 1) % -2", lambda d: (-d - 1) % -2), - ] - ]) - def test_non_trivial_dim_expr(self, expr=lambda d: d % -2): - # Check the lowering for shape expressions - check_shape_poly( - self, - lambda x: x[0] * 0 + expr(x.shape[0]), - arg_descriptors=[RandArg((3,), np.int64)], - polymorphic_shapes=["b"]) - def test_static_shape_result(self): """The result has static shape.""" @@ -399,8 +321,6 @@ def test_forgot_polymorphic_shapes_error(self): polymorphic_shapes=[None]) def test_with_constraints(self): - if not config.jax2tf_default_native_serialization.value: - self.skipTest("not supported") def f_jax(x): # x: i32[a], with a >= 8 return lax.dynamic_slice_in_dim(x, 0, 8, 0) check_shape_poly(self, f_jax, @@ -419,123 +339,6 @@ def f_jax(x, *, y): f_tf: Callable[..., Any] = jax2tf.convert(f_jax, polymorphic_shapes=["b, ..."]) self.assertAllClose(f_jax(x, y=y), f_tf(x, y=y)) - def test_arg_avals_non_native(self): - """Test conversion of actual arguments to abstract values.""" - - def check_avals(*, arg_shapes: Sequence[Sequence[int | None]], - polymorphic_shapes: Sequence[str | None], - expected_shapes: Sequence[str] | None = None, - expected_shapeenv: dict[str, int] | None = None, - eager_mode: bool = False): - # Use eager mode only for when all arg_shapes are known, in order to - # check expected_shapeenv. - arg_dtypes = (_f32,) * len(arg_shapes) - symbolic_scope = shape_poly.SymbolicScope() - def f_tf(*args_tf): - avals = tuple(map( - lambda s, dt, spec: core.ShapedArray( - export.symbolic_shape(spec, like=s, scope=symbolic_scope), - dt), - arg_shapes, arg_dtypes, polymorphic_shapes)) - dim_vars = shape_poly.all_dim_vars(avals) - dim_values, _ = jax2tf.jax2tf._interpret_fun_jax( - partial(shape_poly.compute_dim_vars_from_arg_shapes, - avals, - args_kwargs_tree=tree_util.tree_flatten((avals, {}))[1]), - args_tf, avals, "", - debug_info=api_util.debug_info("jax2tf dim_vars", - shape_poly.compute_dim_vars_from_arg_shapes, - avals, {})) - if expected_shapes is not None: - expected_avals = tree_util.tree_map( - lambda shape_str: core.ShapedArray( - shape_poly.symbolic_shape(shape_str, scope=symbolic_scope), - np.float32), - expected_shapes) - self.assertEqual(expected_avals, avals) - return dict(zip(dim_vars, dim_values)) - if eager_mode: - # If we want to check the shape_env then all arg_shapes must be known - assert all(all(d is not None for d in a_s) - for a_s in arg_shapes) - shape_env = f_tf(*[tf.ones(a_s, dtype=_f32) for a_s in arg_shapes]) - if expected_shapeenv is not None: - for v, val in expected_shapeenv.items(): - self.assertEqual(val, shape_env.get(v)) - else: - f_tf = tf.function(autograph=False)(f_tf) - f_tf.get_concrete_function(*[tf.TensorSpec(a_s, _f32) - for a_s in arg_shapes]) - assert not expected_shapeenv, "Should use eager_mode=True" - - # Known shapes for the arguments - check_avals( - arg_shapes=[(2, 3)], - polymorphic_shapes=[None], - expected_shapes=("2, 3",)) - - check_avals( - arg_shapes=[(2, 3)], - polymorphic_shapes=["(2, 3)"], - expected_shapes=("2, 3",)) - - check_avals( - arg_shapes=[(2, 3)], - polymorphic_shapes=["(_, 3)"], - expected_shapes=("2, 3",)) - - check_avals( - arg_shapes=[(2, 3)], - polymorphic_shapes=["..."], - expected_shapes=("2, 3",)) - - # Partially known shapes for the arguments - check_avals( - arg_shapes=[(None, 3)], - polymorphic_shapes=["b, ..."], - expected_shapes=("(b, 3)",)) - - check_avals( - arg_shapes=[(None, None)], - polymorphic_shapes=["h, h"], - expected_shapes=("(h, h)",)) - - check_avals( - arg_shapes=[(2, None)], - polymorphic_shapes=["h, h"], - expected_shapes=("(h, h)",)) - - check_avals( - arg_shapes=[(None, 3, 4)], - polymorphic_shapes=["(c, b, a)"], - expected_shapes=("(c, b, a)",), - ) - - # Check cases when the specifications are polynomials - check_avals( - arg_shapes=[(2, 3)], - polymorphic_shapes=["a + 1, b + 2"], - eager_mode=True, - expected_shapeenv=dict(a=1, b=1)) - - check_avals( - arg_shapes=[(7, 5)], - polymorphic_shapes=["2 * a + b, b + 2"], - eager_mode=True, - expected_shapeenv=dict(a=2, b=3)) - - check_avals( - arg_shapes=[(7, 11, 4)], - polymorphic_shapes=["2 * a + b, b * b + 2, b + 1"], - eager_mode=True, - expected_shapeenv=dict(a=2, b=3)) - - check_avals( - arg_shapes=[(7, 11, 19, 7)], - polymorphic_shapes=["2 * a + b, b * b + 2, b + c * c, 2 * c + -1"], - eager_mode=True, - expected_shapeenv=dict(a=2, b=3, c=4)) - def test_arg_avals_errors(self): """Test error reporting for shape polymorphism.""" def conv_and_run(*, arg_shape: core.Shape, @@ -595,7 +398,7 @@ def conv_and_run(*, arg_shape: core.Shape, "Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, c + b + a). " "Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), " "'b' = 0 from specification '2*b + a' for dimension args[0].shape[0] (= 2), . " - "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details." + "Please see https://docs.jax.dev/en/latest/export/shape_poly.html#shape-assertion-errors for more details." )), dict(shape=(3, 2, 6), # a = 2, b = 0.5, c = 4 - b is not integer poly_spec="(a + 2*b, a, a + b + c)", @@ -604,7 +407,7 @@ def conv_and_run(*, arg_shape: core.Shape, "Division had remainder 1 when computing the value of 'b'. " "Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, c + b + a). " "Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), . " - "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details." + "Please see https://docs.jax.dev/en/latest/export/shape_poly.html#shape-assertion-errors for more details." )), dict(shape=(8, 2, 6), # a = 2, b = 3 - inconsistency poly_spec="(a + 2*b, a, a + b)", @@ -614,7 +417,7 @@ def conv_and_run(*, arg_shape: core.Shape, "Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, b + a). " "Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), " "'b' = 4 from specification 'b + a' for dimension args[0].shape[2] (= 6), . " - "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details." + "Please see https://docs.jax.dev/en/latest/export/shape_poly.html#shape-assertion-errors for more details." )), dict(shape=(7, 2, 36), # a = 2, b = 3, c = 6 - cannot solve c poly_spec="(2 * a + b, a, c * c)", @@ -623,7 +426,7 @@ def conv_and_run(*, arg_shape: core.Shape, "We can only solve linear uni-variate constraints. " "Using the following polymorphic shapes specifications: args[0].shape = (b + 2*a, a, c^2). " "Unprocessed specifications: 'c^2' for dimension size args[0].shape[2]. " - "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details." + "Please see https://docs.jax.dev/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details." )), ]) def test_shape_constraints_errors(self, *, @@ -645,8 +448,7 @@ def test_pytree(self): # Arguments are of the form [([x00, x01], [x10]), dict(a=ya, b=yb)] def add_all_jax(x_pair_of_list, y_dict): x_list_0, x_list_1 = x_pair_of_list - return functools.reduce(op.add, - x_list_0 + x_list_1 + [y_dict["a"], y_dict["b"]]) + return sum(x_list_0 + x_list_1 + [y_dict["a"], y_dict["b"]]) input_signature = [([tf.TensorSpec([None]), tf.TensorSpec([None])], [tf.TensorSpec([None])]), @@ -710,8 +512,7 @@ def test_pytree_errors(self, polymorphic_shapes=("b", "b", "b")): args = (([x, x], [x]), dict(a=x, b=x)) def add_all_jax(x_pair_of_list, y_dict): x_list_0, x_list_1 = x_pair_of_list - return functools.reduce(op.add, - x_list_0 + x_list_1 + [y_dict["a"], y_dict["b"]]) + return sum(x_list_0 + x_list_1 + [y_dict["a"], y_dict["b"]]) with self.assertRaisesRegex(ValueError, "pytree structure error"): jax2tf.convert(add_all_jax, @@ -841,12 +642,8 @@ def tf_value_and_grad(xv): # for native serialization we cannot refine the inferred shape of the # output if the input is more specific than polymorphic_shapes. - if config.jax2tf_default_native_serialization.value: - self.assertEqual((None, None, None, None), tuple(tf_grad.output_shapes[0])) - self.assertEqual((None, None, None, None), tuple(tf_grad.output_shapes[1])) - else: - self.assertEqual((3, 4, 8, 8), tuple(tf_grad.output_shapes[0])) - self.assertEqual((3, 4, 8, 9), tuple(tf_grad.output_shapes[1])) + self.assertEqual((None, None, None, None), tuple(tf_grad.output_shapes[0])) + self.assertEqual((None, None, None, None), tuple(tf_grad.output_shapes[1])) def test_gradients_pytree(self): """Shape polymorphism with gradients and pytrees for inputs and outputs.""" @@ -1033,9 +830,6 @@ def f_jax(x): # A function whose gradient is a constant f_tf, input_signature=[tf.TensorSpec([None], x.dtype)]) self.assertAllClose(f_jax(x), restored_f(x)) - @jtu.ignore_warning( - message="jax2tf.convert with native_serialization=False has been deprecated" - ) def test_readme_examples(self): """Some of the examples from the README.""" @@ -1179,104 +973,6 @@ def f(x): polymorphic_shapes=["b1, b2, ..."]) self.assertAllClose(res_iter, res_vmap_tf) - def test_with_hash_collision_vmap(self): - # Batching caches based on Jaxpr, and Jaxpr include _DimExpr. If we have - # a collision for the hashing of a _DimExpr, then Python will call the - # equality, which will raise InconclusiveDimensionOperation. - - def f_jax(x): - return jnp.reshape(x, (2, -1,)) - try: - # Override the hashing to create collisions - orig_hash = getattr(shape_poly._DimExpr, "__hash__") - def collision_hash(obj): - return hash(5) - - setattr(shape_poly._DimExpr, "__hash__", collision_hash) - xs = np.ones((3, 5, 6), dtype=np.float32) - f_toconvert = jax.vmap(pjit.pjit(f_jax)) - res_1 = jax2tf.convert(f_toconvert)(xs) - res_2 = jax2tf.convert(f_toconvert, - polymorphic_shapes = "b1, b2, ...")(xs) - self.assertAllClose(res_1, res_2) - finally: - setattr(shape_poly._DimExpr, "__hash__", orig_hash) - - @jtu.parameterized_filterable( - kwargs=[ - dict(testcase_name=op_name, op=op) - for op, op_name in [ - (jnp.array, "array"), - (jnp.sin, "sin"), - (lambda x: x, "id"), - (core.dimension_as_value, "dimension_as_value"), - ]]) - def test_poly_unary_op(self, *, op=jnp.array): - def f_jax(x): # x: f32[b] - poly = 2 * x.shape[0] - return (op(poly), x) # Make sure we are using x - - check_shape_poly(self, - f_jax, - arg_descriptors=[RandArg((3,), _f32)], - polymorphic_shapes=["b"], - expected_output_signature=(tf.TensorSpec([]), tf.TensorSpec((None,), _f32))) - - @jtu.parameterized_filterable( - kwargs=[ - dict(testcase_name=f"_{op.__name__}_other={other}:{type(other)}{'_other_jnp_array' if other_jnp_array else ''}{'_swap' if swap else ''}", - op=op, other=other, - other_jnp_array=other_jnp_array, swap=swap) - for op in [op.add, op.mul, op.sub, - op.mod, op.floordiv, op.truediv] - for other in [ - 2, np.int32(2), 2., np.float32(2), - np.array(2, dtype=np.int32), np.arange(1, 5, dtype=np.int32), - np.array(2., dtype=np.float32), np.arange(1., 7., dtype=np.float32) - ] - for other_jnp_array in ( - [True, False] if np.shape(other) == (7,) else [False]) # type: ignore - for swap in [False, True] # The poly is the left op by default - ]) - def test_poly_binary_op(self, *, op=op.add, - other=np.arange(2, dtype=np.int32), - other_jnp_array=False, - swap=True): - # Test arithmetic operations with poly and a variety of other operand types - def f_jax(x): # x: f32[b] - poly = 2 * x.shape[0] # This will allow divisions with 2 - other_wrapped = jnp.array(other) if other_jnp_array else other - ops = (poly, other_wrapped) if not swap else (other_wrapped, poly) - res = op(*ops) - - # If the other op is an integer then the result is a symbolic dim - try: - op.index(other) - other_isint = True - except Exception: - other_isint = False - - if (hasattr(poly, "dimension_as_value") and - other_isint and - op.__name__ != "truediv"): - # If we running under jax2tf and "other" is an integer the result - # should be a symbolic dimension - self.assertTrue(isinstance(res, int) or hasattr(res, "dimension_as_value")) - - if config.enable_x64.value: - # Outside jax2tf, x.shape[0] is a Python (64-bit) integer and for most - # operations here JAX is not involved at all because the other operand - # is a Python or NumPy constant. So the result will be 64-bits. But under - # jax2tf, x.shape[0] is rewritten to jnp.array(x.shape[0]) which when - # used with int32 or float32 values will produce 32-bit values. - return (lax.convert_element_type(res, np.float32), x) - return (res, x) # Make sure we are using x - - check_shape_poly(self, - f_jax, - arg_descriptors=[RandArg((3,), np.int32)], - polymorphic_shapes=["b"]) - def test_mean0(self): def f_jax(x): # x: f32[b, 4] return jnp.sum(x, axis=0) / x.shape[0] @@ -1430,1252 +1126,5 @@ def f2(z, w): # z: f32[a, 5] w: f32[a + b, 5] -> f32[2*a + b, 10] self.assertAllClose(f2(* f1(x, y)), res) -# List containing either harnesses, or lists of harnesses -_POLY_SHAPE_TEST_HARNESSES = [ - PolyHarness("add", "", - jnp.add, - arg_descriptors=[RandArg((3, 4), _f32), RandArg((2, 3, 4), _f32)], - polymorphic_shapes=["b, ...", "_, b, _"]), - PolyHarness("add_transpose", "", - jax.grad(lambda x: jnp.sum(jnp.sum(x, axis=0, keepdims=False) + jnp.sin(x))), - arg_descriptors=[RandArg((3, 4), _f32)], - polymorphic_shapes=["b, ..."]), - [ - # make_args invoked with op.shape[0] and produces the arange args: - # start, stop, step, dtype - PolyHarness("arange", kwargs["testcase_name"], # type: ignore - lambda x: jnp.arange(*(kwargs["make_args"](x.shape[0]))), # type: ignore - arg_descriptors=[RandArg((6,), np.float32)], - polymorphic_shapes=["b"]) - for kwargs in [ - # Positive step - dict(testcase_name="b", make_args=lambda b: (b, None, None, None)), - dict(testcase_name="0_b+1", make_args=lambda b: (0, b + 1, None, None)), - dict(testcase_name="0_5b_2", make_args=lambda b: (0, 5 * b, 2, None)), - dict(testcase_name="0_5b+1_2", make_args=lambda b: (0, 5 * b + 1, 2, None)), - dict(testcase_name="b_5b+2_2", make_args=lambda b: (b, 5 * b + 2, 2, None)), - dict(testcase_name="0_b-1_2", make_args=lambda b: (0, b - 1, 2, None)), - dict(testcase_name="0_b-2_2", make_args=lambda b: (0, b - 2, 2, None)), - dict(testcase_name="0_-b_2", make_args=lambda b: (0, -b, 2, None)), - dict(testcase_name="0_1-b_2", make_args=lambda b: (0, 1 - b, 2, None)), - dict(testcase_name="0_b-3_2", make_args=lambda b: (0, b - 3, 2, None)), - # Cannot tell if size >= 0 - # Negative step - dict(testcase_name="b_0_-1", make_args=lambda b: (b, 0, -1, None)), - dict(testcase_name="b_1_-2", make_args=lambda b: (b, 1, -2, None)), - dict(testcase_name="b_-1_-1", make_args=lambda b: (b, -1, -1, None)), - dict(testcase_name="5b+1_0_-2", - make_args=lambda b: (5 * b + 1, 0, -2, None)), - dict(testcase_name="5b+2_0_-2", - make_args=lambda b: (5 * b + 2, 0, -2, None)), - dict(testcase_name="b-3_0_-2", make_args=lambda b: (b - 3, 0, -2, None)), - # Cannot tell if size >= 0 - # Symbolic step - dict(testcase_name="0_10_b", make_args=lambda b: (0, 10, b)), - dict(testcase_name="0_0_b", make_args=lambda b: (0, 0, b)), - dict(testcase_name="10_0_-b", make_args=lambda b: (10, 0, -b)), - dict(testcase_name="b_1_-b", make_args=lambda b: (b, 1, -b)), - # Float return type - dict(testcase_name="0_b_1_f32", make_args=lambda b: (0, b, 1, np.float32)) - ] - ], - # Reduce the poly dimension - PolyHarness("argmax", "0", - lambda op: lax.argmax(op, axis=0, index_dtype=np.int32), - arg_descriptors=[RandArg((3, 4, 5), _f32)], - polymorphic_shapes=["b, ..."]), - # Reduce the non-poly dimension - PolyHarness("argmax", "1", - lambda op: lax.argmax(op, axis=1, index_dtype=np.int32), - arg_descriptors=[RandArg((3, 4, 5), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("jnp.argsort", "", - lambda op: jnp.argsort(op), - arg_descriptors=[RandArg((3, 4, 5), _f32)], - polymorphic_shapes=["b, ..."]), - [ - PolyHarness("average", - f"{axis=}_weights=None", - lambda x, axis: jnp.average(x, axis=axis, returned=False, weights=None), - arg_descriptors=[RandArg((7, 8, 4), _f32), StaticArg(axis)], - polymorphic_shapes=["b, ..."]) - for axis in [None, 0, 1] - ], - [ - PolyHarness("average", - f"{axis=}_weights=Some", - lambda x, weights, axis: jnp.average(x, axis=axis, returned=False, weights=weights), - arg_descriptors=[RandArg((7, 8, 4), _f32), RandArg((7, 8, 4), _f32), StaticArg(axis)], - polymorphic_shapes=["b, ...", "b, ..."]) - for axis in [None, 0, 1] - ], - PolyHarness("jnp.bincount", "length=constant", - lambda x: jnp.bincount(x % 2, length=4), - arg_descriptors=[RandArg((12,), np.int32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("jnp.bincount", "length=poly", - lambda x: jnp.bincount(x % 4, length=x.shape[0]), - arg_descriptors=[RandArg((12,), np.int32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("broadcast_to", "", - lambda x: jnp.broadcast_to(x, [x.shape[0], x.shape[0], 4]), - arg_descriptors=[RandArg((3, 4), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("broadcast_in_dim", "0", - lambda x: lax.broadcast_in_dim(x, [x.shape[0], 4, 5, 6], - broadcast_dimensions=(0, 2, 3)), - arg_descriptors=[RandArg((3, 1, 6), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("broadcast_in_dim", "poly", - lambda x: lax.broadcast_in_dim(x, [x.shape[0], x.shape[0] + x.shape[0], 4], - broadcast_dimensions=(0, 1, 2)), - arg_descriptors=[RandArg((3, 1, 4), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("broadcast_in_dim", "poly2", - lambda x: lax.broadcast_in_dim(x, [x.shape[0], 5, 6, x.shape[2], 4], - broadcast_dimensions=(0, 2, 3)), - arg_descriptors=[RandArg((3, 1, 4), _f32)], - polymorphic_shapes=["b1, _, b2"]), - PolyHarness("broadcast_in_dim", "transpose", - jax.grad(lambda x: jnp.sum( - lax.broadcast_in_dim(jnp.sin(x), [2, x.shape[0], 5, x.shape[2], 4], - broadcast_dimensions=(1, 2, 3)))), - arg_descriptors=[RandArg((3, 1, 4), _f32)], - polymorphic_shapes=["b1, _, b2"]), - PolyHarness("clamp", "", - lax.clamp, - arg_descriptors=[RandArg((3, 4, 5), _f32), RandArg((3, 4, 5), _f32), - RandArg((3, 4, 5), _f32)], - polymorphic_shapes=["b, ...", "b, ...", "b, ..."]), - PolyHarness("collapse", "", - lambda x: lax.collapse(x, 1, 4), - arg_descriptors=[RandArg((3, 4, 5, 6, 7), _f32)], - polymorphic_shapes=["b0, b1, _, b3, ..."]), - PolyHarness("concatenate", "", - lambda x: jnp.concatenate([x, x], axis=0), - arg_descriptors=[RandArg((3, 4, 5), _f32)], - polymorphic_shapes=["b0, b1, _"]), - PolyHarness("concatenate", "grad", - jax.grad(lambda x: jnp.sum(jnp.concatenate([x, jnp.sin(x)], axis=0))), - arg_descriptors=[RandArg((3, 4, 5), _f32)], - polymorphic_shapes=["b0, b1, _"]), - - PolyHarness("conv_general_dilated", "1d_stride=1", - lambda lhs, rhs: lax.conv_general_dilated( - lhs, rhs, - window_strides=(1,), - padding="SAME", - rhs_dilation=None, - dimension_numbers=lax.ConvDimensionNumbers(lhs_spec=(0, 2, 1), - rhs_spec=(2, 1, 0), - out_spec=(0, 2, 1))), - arg_descriptors=[RandArg((1, 12, 16), _f32), RandArg((4, 16, 16), _f32)], - polymorphic_shapes=["_, b, _", None]), - # The same example from above, but with stride=2. - PolyHarness("conv_general_dilated", "1d_stride=2_even", - lambda lhs, rhs: lax.conv_general_dilated( - lhs, rhs, - window_strides=(2,), - padding="SAME", - rhs_dilation=None, - dimension_numbers=lax.ConvDimensionNumbers(lhs_spec=(0, 2, 1), - rhs_spec=(2, 1, 0), - out_spec=(0, 2, 1))), - arg_descriptors=[RandArg((1, 12, 16), _f32), RandArg((4, 16, 16), _f32)], - polymorphic_shapes=["_, b, _", None]), - # The same example from above, but with stride=2 and odd input size. - PolyHarness("conv_general_dilated", "1d_stride=2_odd", - lambda lhs, rhs: lax.conv_general_dilated( - lhs, rhs, - window_strides=(2,), - padding="SAME", - rhs_dilation=None, - dimension_numbers=lax.ConvDimensionNumbers(lhs_spec=(0, 2, 1), - rhs_spec=(2, 1, 0), - out_spec=(0, 2, 1))), - arg_descriptors=[RandArg((1, 13, 16), _f32), RandArg((4, 16, 16), _f32)], - polymorphic_shapes=["_, b, _", None]), - PolyHarness("conv_general_dilated", "1d_stride=2_zero_output", - lambda lhs, rhs: lax.conv_general_dilated( - lhs, rhs, - window_strides=(2,), - padding="VALID", - rhs_dilation=None, - dimension_numbers=lax.ConvDimensionNumbers(lhs_spec=(0, 2, 1), - rhs_spec=(2, 1, 0), - out_spec=(0, 2, 1)) - ).shape[1], # should be 0 in JAX native - arg_descriptors=[RandArg((1, 4, 16), _f32), - RandArg((8, 16, 16), _f32)], - polymorphic_shapes=["_, b, _", - None]), - # Issue #11402 - PolyHarness("conv_general_dilated", "1d_2", - lambda lhs, rhs: lax.conv_transpose(lhs, rhs, - strides=(2,), - padding="SAME", - rhs_dilation=None, - transpose_kernel=False), - arg_descriptors=[RandArg((5, 12, 16), _f32), RandArg((4, 16, 16), _f32)], - polymorphic_shapes=["b, _, _", None], - tol=5e-5), - # Issue #11402 - PolyHarness("conv_general_dilated", "1d_3", - lambda lhs, rhs: lax.conv_transpose(lhs, rhs, - strides=(2,), - padding="SAME", - rhs_dilation=None, - transpose_kernel=False), - arg_descriptors=[RandArg((5, 12, 16), _f32), RandArg((4, 16, 16), _f32)], - polymorphic_shapes=["_, b, _", None], - tol=5e-5), - PolyHarness("conv_general_dilated", "", - lambda lhs, rhs: lax.conv_general_dilated( - lhs, rhs, - window_strides=(2, 3), - padding=((0, 0), (0, 0)), - lhs_dilation=(1, 1), - rhs_dilation=(1, 2), - dimension_numbers=("NCHW", "OIHW", "NCHW"), - feature_group_count=1, - batch_group_count=1, - precision=None), - arg_descriptors=[RandArg((7, 3, 9, 10), _f32), RandArg((3, 3, 4, 5), _f32)], - polymorphic_shapes=["b, ...", None]), - [ - [ - PolyHarness(cum_name, "reduce_axis_poly", - lambda x: cum_func(x, axis=0), - arg_descriptors=[RandArg((3, 5), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness(cum_name, "reduce_axis_static", - lambda x: cum_func(x, axis=1), - arg_descriptors=[RandArg((3, 5), _f32)], - polymorphic_shapes=["b, ..."]) - ] - for cum_name, cum_func in [ - ("cumlogsumexp", lax_control_flow.cumlogsumexp), - ("cummax", lax_control_flow.cummax), - ("cummin", lax_control_flow.cummin), - ("cumsum", lax_control_flow.cumsum), - ("cumprod", lax_control_flow.cumprod) - ] - ], - PolyHarness("delta", "0", - lambda x: lax_internal._delta(_f32, x.shape, axes=(0, 1)) + x, - arg_descriptors=[RandArg((3, 1), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("dot_general", "", - lambda lhs, rhs: lax.dot_general(lhs, rhs, - dimension_numbers=(((2,), (1,)), ((0,), (0,)))), - arg_descriptors=[RandArg((3, 4, 4), _f32), RandArg((3, 4), _f32)], - polymorphic_shapes=["b, ...", "b, ..."]), - PolyHarness("dynamic_slice", "idx=tuple_int", - # x:shape: (b, 4) - lambda x: lax.dynamic_slice(x, (0, 1), (x.shape[0], 2)), - arg_descriptors=[RandArg((3, 4), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("dynamic_slice", "idx=tuple_arg", - # x:shape: (b, 4) - lambda x, i0: lax.dynamic_slice(x, (i0, np.int32(1)), (x.shape[0], 2)), - arg_descriptors=[RandArg((3, 4), _f32), np.array(-2, dtype=np.int32)], - polymorphic_shapes=["b, ...", None]), - PolyHarness("dynamic_slice", "idx=array", - # x:shape: (b, 4) - lambda x, idx: lax.dynamic_slice(x, idx, (x.shape[0], 2)), - arg_descriptors=[RandArg((3, 4), _f32), np.array([-2, -1], dtype=np.int32)], - polymorphic_shapes=["b, ...", None]), - PolyHarness("dynamic_slice", "idx=tuple_int_start_oob_large", - # x:shape: (b, 4) - lambda x: lax.dynamic_slice(x, (1, 1), (x.shape[0], 2)), - arg_descriptors=[RandArg((3, 4), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("dynamic_slice", "idx=tuple_int_start_oob_small", - # x:shape: (b, 4) - lambda x: lax.dynamic_slice(x, (-1, 1), (x.shape[0] - 1, 2)), - arg_descriptors=[RandArg((3, 4), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("dynamic_slice_in_dim", "idx=0", - # x:shape: (b, 4) - lambda x: lax.dynamic_slice_in_dim(x, 0, x.shape[0], axis=0), - arg_descriptors=[RandArg((3, 4), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("dynamic_update_slice", "idx=tuple_int", - # x:shape: (b, 4) - lambda x: lax.dynamic_update_slice(x, x, (0, 0)), - arg_descriptors=[RandArg((3, 4), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("dynamic_update_slice", "idx=tuple_arg", - # x:shape: (b, 4) - lambda x, i0: lax.dynamic_update_slice(x, x, (i0, np.int32(0))), - arg_descriptors=[RandArg((3, 4), _f32), np.array(-2, dtype=np.int32)], - polymorphic_shapes=["b, ...", None]), - PolyHarness("dynamic_update_slice", "idx=array", - # x:shape: (b, 4) - lambda x, idx: lax.dynamic_update_slice(x, x, idx), - arg_descriptors=[RandArg((3, 4), _f32), np.array([-2, -1], dtype=np.int32)], - polymorphic_shapes=["b, _", None]), - [ - PolyHarness("eig", f"shape={jtu.format_shape_dtype_string((3, 5, 5), dtype)}_poly={poly}_{left=}_{right=}", - lambda x, left, right: lax.linalg.eig(x, compute_left_eigenvectors=left, compute_right_eigenvectors=right), - arg_descriptors=[RandArg((3, 5, 5), dtype), - StaticArg(left), StaticArg(right)], - polymorphic_shapes=[poly], - # In non-native serialization, we cannot check exact match, - # we ought to check the invariants of the result. - check_result=config.jax2tf_default_native_serialization.value) - for dtype in {np.float32, np.float64, np.complex64, np.complex128} & jtu.supported_dtypes() - for poly in ["b, ...", "b, w, w"] - for left in ([True, False] if dtype == np.float32 else [True]) - for right in ([True, False] if dtype == np.float32 else [False]) - ], - PolyHarness("einsum", "0", - lambda x: jnp.einsum("...i->...", x), - arg_descriptors=[RandArg((3, 4), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("einsum", "0_alt", - lambda x: jnp.einsum(x, (..., 1), [...]), - arg_descriptors=[RandArg((3, 4), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("einsum", "1", - lambda x, y: jnp.einsum("...ij,...jk->...ik", x, y), - arg_descriptors=[RandArg((3, 4, 5), _f32), RandArg((3, 5, 6), _f32)], - polymorphic_shapes=["b, ...", "b, ..."]), - PolyHarness("einsum", "1_alt", - lambda x, y: jnp.einsum(x, [..., 0, 1], y, (..., 1, 2), [..., 0, 2]), - arg_descriptors=[RandArg((3, 4, 5), _f32), RandArg((3, 5, 6), _f32)], - polymorphic_shapes=["b, ...", "b, ..."]), - PolyHarness("einsum", "2", - lambda x, y: jnp.einsum("...ij,jk->...ik", x, y), - arg_descriptors=[RandArg((3, 4, 5), _f32), RandArg((5, 6), _f32)], - polymorphic_shapes=["b, ...", None]), - PolyHarness("einsum", "2_alt", - lambda x, y: jnp.einsum(x, [..., 0, 1], y, [1, 2], [..., 0, 2]), - arg_descriptors=[RandArg((3, 4, 5), _f32), RandArg((5, 6), _f32)], - polymorphic_shapes=["b, ...", None]), - PolyHarness("einsum", "3", - # Reduced dimension is polymorphic - lambda x, y: jnp.einsum("ij,jk->ik", x, y), - arg_descriptors=[RandArg((3, 4), _f32), RandArg((4, 5), _f32)], - polymorphic_shapes=["_, b", "b, ..."]), - PolyHarness("einsum", "3_alt", - # Reduced dimension is polymorphic - lambda x, y: jnp.einsum(x, [0, 1], y, [1, 2], [0, 2]), - arg_descriptors=[RandArg((3, 4), _f32), RandArg((4, 5), _f32)], - polymorphic_shapes=["_, b", "b, ..."]), - PolyHarness("einsum", "4", - # Reduced dimension is polymorphic, and is 2*b - lambda x, y: jnp.einsum("ij,jk->ik", - jnp.concatenate([x, x], axis=1), - jnp.concatenate([y, y], axis=0)), - arg_descriptors=[RandArg((3, 4), _f32), RandArg((4, 5), _f32)], - polymorphic_shapes=["_, b", "b, ..."]), - PolyHarness("einsum", "4_alt", - # Reduced dimension is polymorphic, and is 2*b - lambda x, y: jnp.einsum(jnp.concatenate([x, x], axis=1), [0, 1], - jnp.concatenate([y, y], axis=0), [1, 2], - [0, 2]), - arg_descriptors=[RandArg((3, 4), _f32), RandArg((4, 5), _f32)], - polymorphic_shapes=["_, b", "b, ..."]), - PolyHarness("einsum", "multiple_contractions", - lambda x, y, z: jnp.einsum("ab,bc,cd->ad", x, y, z), - arg_descriptors=[RandArg((3, 2), _f32), RandArg((2, 3), _f32), RandArg((3, 4), _f32)], - polymorphic_shapes=["b, ...", None, None]), - PolyHarness("einsum", "incompatible_contractions_error", - lambda x, y: jnp.einsum("ab,cb->ac", x, y), - arg_descriptors=[RandArg((2, 3), _f32), RandArg((2, 3), _f32)], - polymorphic_shapes=["(2, b0)", "(2, b1)"], - input_signature=[tf.TensorSpec((2, None)), tf.TensorSpec((2, None))], - expect_error=(AssertionError, - "Incompatible reduction dimensions")), - PolyHarness("eye", "N=poly_M=None", - lambda x: jnp.eye(x.shape[0]) + x, - arg_descriptors=[RandArg((3, 1), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("eye", "N=poly_M=poly", - lambda x: jnp.eye(x.shape[0], M=x.shape[0] + 2) + x, - arg_descriptors=[RandArg((3, 1), _f32)], - polymorphic_shapes=["b, ..."]), - [ - PolyHarness("fft", f"{fft_type=}_{nr_fft_lengths=}", - lambda x, fft_type, nr_fft_lengths: lax.fft_p.bind( - x, fft_type=fft_type, - fft_lengths=tuple( - x.shape[-nr_fft_lengths:] if fft_type != lax.FftType.IRFFT else - [(x.shape[-1] - 1) * 2])), - arg_descriptors=[ - RandArg((3, 4, 5, 6), - np.float32 if fft_type == lax.FftType.RFFT else np.complex64), - StaticArg(fft_type), - StaticArg(nr_fft_lengths)], - # All axes but the last one are dynamic. This means that the test - # with nr_fft_lengths==1 will not have dynamic fft_lengths. - polymorphic_shapes=["b0, b1, b2, ..."], - tol=1e-4) - - for fft_type in (lax.FftType.FFT, lax.FftType.IFFT, - lax.FftType.RFFT, lax.FftType.IRFFT) - for nr_fft_lengths in (1, 2) - ], - PolyHarness("full", "", - lambda x: lax.full((x.shape[0], 2), 3.) + x, - arg_descriptors=[RandArg((3, 1), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("gather", "1d", - lambda operand, start_indices, x: lax.gather( - operand, - start_indices, - dimension_numbers=lax.GatherDimensionNumbers( - offset_dims=(1,), - collapsed_slice_dims=(), - start_index_map=(0,)), - slice_sizes=x.shape, - mode="promise_in_bounds"), - arg_descriptors=[ - RandArg((10,), np.float32), - np.random.randint(0, high=10, size=(3, 1), - dtype=np.int32), - np.zeros((10,), dtype=jnp.int32), - ], - polymorphic_shapes=["(t, )", "(3, 1)", "(t)"]), - # operand is non-poly, index is poly - PolyHarness("getitem", "op=static_idx=poly", - lambda a, i: a[i], - arg_descriptors=[RandArg((3, 4), _f32), np.array([2, 2], np.int32)], - polymorphic_shapes=[None, "b0, ..."]), - # operand is poly, index is integer - PolyHarness("getitem", "op=poly_idx=const", - lambda a: a[1], - arg_descriptors=[RandArg((3, 4), _f32)], - polymorphic_shapes=["b, ..."]), - # operand is poly, index is dim poly - PolyHarness("getitem", "op=poly_idx=dim", - lambda a: a[jnp.array(a.shape[0] - 2)], - arg_descriptors=[RandArg((3, 4), _f32)], - polymorphic_shapes=["b, ..."]), - # Both the operand and the index are poly - PolyHarness("getitem", "op=poly_idx=poly", - lambda a, i: a[i], - arg_descriptors=[RandArg((3, 4), _f32), np.array([1, 2, 0], np.int32)], - polymorphic_shapes=["b, ...", "b, ..."]), - # op is poly and index is an entire slice - PolyHarness("getitem", "op=poly_idx=slice-all", - lambda a: a[:], - arg_descriptors=[RandArg((3, 4), _f32)], - polymorphic_shapes=["b, ..."]), - # op is poly and index is a partial slice - PolyHarness("getitem", "op=poly_idx=slice-ct-1", - lambda a: a[:2], - arg_descriptors=[RandArg((3, 4), _f32)], - polymorphic_shapes=["b + 2, ..."]), - PolyHarness("getitem", "op=poly_idx=slice-ct-2", - lambda a: a[:, :2], - arg_descriptors=[RandArg((3, 4), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("getitem", "op=poly_idx=slice-None-1", - lambda a: a[:a.shape[0]], - arg_descriptors=[RandArg((3, 4), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("getitem", "op=poly_idx=slice-poly", - lambda a: a[:a.shape[0] - 1], - arg_descriptors=[RandArg((3, 4), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("image_resize", "linear_0", - lambda x: jax.image.resize(x, (x.shape[0], 2 * x.shape[1], 2 * x.shape[2], x.shape[3]), - method="linear"), - arg_descriptors=[RandArg((3, 16, 32, 3), _f32)], - polymorphic_shapes=["_, b1, b2, ..."]), - PolyHarness("image_resize", "linear_to_fixed_dim", - lambda x: jax.image.resize(x, (x.shape[0], 64, 64, x.shape[3]), - method="linear"), - arg_descriptors=[RandArg((3, 16, 32, 3), _f32)], - polymorphic_shapes=["_, b1, b2, ..."]), - PolyHarness("image_resize", "nearest_0", - lambda x: jax.image.resize(x, (x.shape[0], 2 * x.shape[1], 2 * x.shape[2], x.shape[3]), - method="nearest"), - arg_descriptors=[RandArg((3, 5, 7, 3), _f32)], - polymorphic_shapes=["_, b1, b2, ..."]), - PolyHarness("index_in_dim", "0", - lambda x: lax.index_in_dim(x, -1, axis=0, keepdims=False), - arg_descriptors=[RandArg((3, 4), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("index_in_dim", "idx=neg", - lambda x: lax.index_in_dim(x, -1, axis=0, keepdims=False), - arg_descriptors=[RandArg((3, 4), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("index_in_dim", "idx=last", - lambda x: lax.index_in_dim(x, x.shape[0] - 1, axis=0, keepdims=False), - arg_descriptors=[RandArg((3, 4), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("jnp.insert", "insert=constant", - lambda x: jnp.insert(x, jnp.arange(3, dtype=_i32), np.array([3, 4, 5], dtype=_i32)), - arg_descriptors=[RandArg((12,), _i32)], - polymorphic_shapes=["b, ..."], - expect_error=expect_error_associative_scan), - PolyHarness("jnp.insert", "insert=poly", - lambda x: jnp.insert(x, jnp.arange(x.shape[0], dtype=_i32), x, axis=0), - arg_descriptors=[RandArg((12, 3), _i32)], - polymorphic_shapes=["b0, b1, ..."], - expect_error=expect_error_associative_scan), - PolyHarness("iota", "", - lambda x: x + lax.iota(_f32, x.shape[0]), - arg_descriptors=[RandArg((3,), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("matmul", "0", - jnp.matmul, - arg_descriptors=[RandArg((7, 8, 4), _f32), RandArg((7, 4, 5), _f32)], - polymorphic_shapes=["b, ...", "b, ..."], - tol=1e-5), - PolyHarness("matmul", "1", - jnp.matmul, - arg_descriptors=[RandArg((7, 8, 4), _f32), RandArg((4, 5), _f32)], - polymorphic_shapes=["b, ...", None], - tol=1e-5), - [ - PolyHarness("mean", - f"{axis=}_{keepdims=}_where=None", - lambda x, axis, keepdims: jnp.mean(x, axis=axis, keepdims=keepdims, where=None), - arg_descriptors=[RandArg((7, 8, 4), _f32), StaticArg(axis), StaticArg(keepdims)], - polymorphic_shapes=["b, ..."]) - for keepdims in [False, True] - for axis in [None, (0,), (0, 1), (1,)] - ], - [ - PolyHarness("mean", - f"{axis=}_{keepdims=}_where=Some", - lambda x, where, axis, keepdims: jnp.mean(x, axis=axis, keepdims=keepdims, where=where), - arg_descriptors=[RandArg((7, 8, 4), _f32), RandArg((7, 8, 4), np.bool_), - StaticArg(axis), StaticArg(keepdims)], - polymorphic_shapes=["b, ...", "b, ..."]) - for keepdims in [False, True] - for axis in [None, (0,), (0, 1), (1,)] - ], - PolyHarness("jnp.nonzero", "size=constant", - lambda x: jnp.nonzero(x % 3, size=10, fill_value=100), - arg_descriptors=[RandArg((3, 2, 4), _i32)], - polymorphic_shapes=["b, ..."], - expect_error=expect_error_associative_scan), - PolyHarness("jnp.nonzero", "size=poly", - lambda x: jnp.nonzero(x % 3, size=x.shape[0] * 2, fill_value=100), - arg_descriptors=[RandArg((3, 2, 4), _i32)], - polymorphic_shapes=["b, ..."], - expect_error=expect_error_associative_scan), - PolyHarness("one_hot", "poly_num_classes", - lambda x, y: jax.nn.one_hot(x, y.shape[0]), - arg_descriptors=[np.arange(16, dtype=_i32), RandArg((16,), _f32)], - polymorphic_shapes=[None, "b0, ..."]), - PolyHarness("one_hot", "all_poly", - lambda x, y: jax.nn.one_hot(x, y.shape[0]), - arg_descriptors=[np.arange(16, dtype=_i32), RandArg((16,), _f32)], - polymorphic_shapes=["b, ...", "b, ..."]), - PolyHarness("ones", "", - lambda x: jnp.ones(x.shape, dtype=_f32) + x, - arg_descriptors=[RandArg((3, 2, 4), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("pad", "", - lax.pad, - arg_descriptors=[RandArg((3, 2, 5), _f32), np.float32(5.), - StaticArg(((0, 0, 0), (0, 0, 0), (1, 1, 1)))], - polymorphic_shapes=["b, ...", None]), - PolyHarness("pad", "poly_padding_config", - lambda x: lax.pad(x, _f32(0.), - ((x.shape[0], x.shape[1], x.shape[0]), - (0, 0, 0))), - arg_descriptors=[RandArg((3, 2), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("jnp.pad", "mode=constant", - lambda x: jnp.pad(x, [[x.shape[0], 0], [x.shape[1], 1]], - mode="constant"), - arg_descriptors=[RandArg((3, 5), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("jnp.pad", "mode=constant_bminus1", - # We slice first the unknown dimension to make it of size b - 1 - # which may be 0. - lambda x: jnp.pad(lax.dynamic_slice_in_dim(x, 1, x.shape[0] - 1, - axis=0), - [[x.shape[0], 0], [x.shape[1], 1]], - mode="constant"), - arg_descriptors=[RandArg((3, 5), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("jnp.pad", "mode=edge", - lambda x: jnp.pad(x, [[x.shape[0], 0], [x.shape[1], 1]], - mode="edge"), - arg_descriptors=[RandArg((3, 5), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("percentile", "axis=None", - lambda x: jnp.percentile(x, 50, axis=None), - arg_descriptors=[RandArg((3, 5), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("nanquantile", "axis=None", - lambda x: jnp.nanquantile(x, .5, axis=None), - arg_descriptors=[RandArg((3, 5), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("percentile", "axis=0", - lambda x: jnp.percentile(x, 50, axis=0), - arg_descriptors=[RandArg((3, 5), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("nanquantile", "axis=0", - lambda x: jnp.nanquantile(x, .5, axis=0), - arg_descriptors=[RandArg((3, 5), _f32)], - polymorphic_shapes=["b, ..."]), - [ - PolyHarness( - "qr", f"shape={jtu.format_shape_dtype_string(shape, dtype)}_poly={poly}_{full_matrices=}", - lambda x, full_matrices: lax.linalg.qr(x, full_matrices=full_matrices), - arg_descriptors=[RandArg(shape, dtype), StaticArg(full_matrices)], - polymorphic_shapes=[poly], - tol=(None if config.jax2tf_default_native_serialization.value else 1e-5)) - for dtype in {np.float32, np.float64, np.complex64, np.complex128} & jtu.supported_dtypes() - # m and n must be static for now - for shape, poly, full_matrices in [ - ((2, 0, 4), "b, ...", False), # m = 0 - ((2, 4, 0), "b, ...", False), # n = 0 - ((2, 3, 4, 4), "b1, b2, ...", False), # m == n - ((2, 3, 4, 4), "b1, b2, ...", True), - ((2, 3, 4, 5), "b1, b2, ...", False), # m < n - ((2, 3, 4, 5), "b1, b2, ...", True), - ((2, 3, 8, 4), "b1, b2, ...", False), # m > n - ((2, 3, 8, 4), "b1, b2, ...", True), - ] - ], - [ - # The random primitive tests, with threefry (both partitionable and - # non-partitionable), and unsafe_rbg. - [ - PolyHarness("random_gamma", f"{flags_name}", - lambda key, a: jax.vmap(jax.random.gamma)(key, a), - arg_descriptors=[RandArg((3, key_size), np.uint32), RandArg((3, 4, 5), _f32)], - polymorphic_shapes=["b, ...", "b, w, ..."], tol=1E-5, - override_jax_config_flags=override_jax_config_flags), # type: ignore - # The known dimensions product must be even. - PolyHarness("random_categorical", f"axis=0_{flags_name}", - lambda key, a: jax.random.categorical(key, a, axis=0), - arg_descriptors=[RandArg((key_size,), np.uint32), RandArg((3, 8), _f32)], - polymorphic_shapes=[None, "b0, ..."], - override_jax_config_flags=override_jax_config_flags), # type: ignore - PolyHarness("random_categorical", f"axis=1_{flags_name}", - lambda key, a: jax.random.categorical(key, a, axis=1), - arg_descriptors=[RandArg((key_size,), np.uint32), RandArg((3, 5, 8), _f32)], - polymorphic_shapes=[None, "b0, b1, ..."], - override_jax_config_flags=override_jax_config_flags), # type: ignore - PolyHarness("random_categorical", f"axis=1_then_reshape_{flags_name}", - lambda key, a: jax.random.categorical(key, a, axis=1).reshape(-1), - arg_descriptors=[RandArg((key_size,), np.uint32), RandArg((3, 5, 8), _f32)], - polymorphic_shapes=[None, "b0, b1, ..."], - override_jax_config_flags=override_jax_config_flags), # type: ignore - PolyHarness("random_categorical", f"0_dim_{flags_name}", # One axis has 0 size - lambda key, a: jax.random.categorical(key, a, axis=1), - arg_descriptors=[RandArg((key_size,), np.uint32), RandArg((3, 5, 0), _f32)], - polymorphic_shapes=[None, "b0, b1, ..."], - override_jax_config_flags=override_jax_config_flags), # type: ignore - PolyHarness("random_split", f"{flags_name}", - lambda key, a: jax.random.key_data(jax.random.split(key, 2 * a.shape[0])), - arg_descriptors=[RandArg((key_size,), np.uint32), - RandArg((3, 4), _f32)], - polymorphic_shapes=[None, "b0, ..."], - override_jax_config_flags=override_jax_config_flags), # type: ignore - # Works when the known dimensions are known to be even or odd. - PolyHarness("random_uniform", f"even_1_{flags_name}", - lambda key, a: jax.random.uniform(key, a.shape, dtype=_f32), - arg_descriptors=[RandArg((key_size,), np.uint32), RandArg((3, 4, 5), _f32)], - polymorphic_shapes=[None, "b0, ..."], - override_jax_config_flags=override_jax_config_flags), # type: ignore - PolyHarness("random_uniform", f"even_2_{flags_name}", - lambda key, a: jax.random.uniform(key, (2 * a.shape[0], a.shape[1]), - dtype=_f32), - arg_descriptors=[RandArg((key_size,), np.uint32), RandArg((3, 4), _f32)], - polymorphic_shapes=[None, "b0, b1, ..."], - override_jax_config_flags=override_jax_config_flags), # type: ignore - PolyHarness("random_uniform", f"error_not_even_{flags_name}", - lambda key, a: jax.random.uniform(key, a.shape, dtype=_f32), - arg_descriptors=[RandArg((key_size,), np.uint32), RandArg((3, 5), _f32)], - polymorphic_shapes=[None, "b0, b1"], - override_jax_config_flags=override_jax_config_flags) # type: ignore - ] - for key_size, flags_name, override_jax_config_flags in [ - (2, "threefry_non_partitionable", - dict(jax_default_prng_impl="threefry2x32", jax_threefry_partitionable=False)), - (2, "threefry_partitionable", - dict(jax_default_prng_impl="threefry2x32", jax_threefry_partitionable=True)), - (4, "unsafe_rbg", - dict(jax_default_prng_impl="unsafe_rbg")) - ] - ], - # For reduce_window we have a variant with one reduction axis of - # non-static shape, and one with additionally the dimension window - # non-static. - PolyHarness("reduce_window", "min_window_size=static", - # x: f32[b, 8] - lambda x: lax.reduce_window(x, np.array(1., _f32), lax.min, - (2, 2), (1, 1), "VALID"), - arg_descriptors=[RandArg((3, 8), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("reduce_window", "min_window_size=dynamic", - # x: f32[b, 8] - lambda x: lax.reduce_window(x, np.array(1., _f32), lax.min, - (2, x.shape[0]), (1, 1), "VALID"), - arg_descriptors=[RandArg((3, 8), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("reduce_window", "min_plus_max_window_size=static", - # x: f32[b, 8] - lambda x: ( - # Test that we don't get confusion for the reducer name. - lax.reduce_window(x, np.array(1., _f32), lax.min, - (2, 2), (1, 1), "VALID") + - lax.reduce_window(x, np.array(1., _f32), lax.max, - (2, 2), (1, 1), "VALID")), - arg_descriptors=[RandArg((3, 8), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("reduce_window", "min_plus_max_window_size=dynamic", - # x: f32[b, 8] - lambda x: ( - # Test that we don't get confusion for the reducer name. - lax.reduce_window(x, np.array(1., _f32), lax.min, - (2, x.shape[0]), (1, 1), "VALID") + - lax.reduce_window(x, np.array(1., _f32), lax.max, - (2, x.shape[0]), (1, 1), "VALID")), - arg_descriptors=[RandArg((3, 8), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("reduce_window", "add_monoid_base_window_size=static", - # x: f32[b, 8] - lambda x: lax.reduce_window(x, np.array(0., _f32), lax.add, - (2, 2), (1, 1), "VALID"), - arg_descriptors=[RandArg((3, 8), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("reduce_window", "add_monoid_base_window_size=dynamic", - # x: f32[b, 8] - lambda x: lax.reduce_window(x, np.array(0., _f32), lax.add, - (2, x.shape[0]), (1, 1), "VALID"), - arg_descriptors=[RandArg((3, 8), _f32)], - polymorphic_shapes=["b, ..."]), - # https://github.com/jax-ml/jax/issues/11804 - # Use the reshape trick to simulate a polymorphic dimension of 16*b. - # (See test "conv_general_dilated.1d_1" above for more details.) - PolyHarness("reduce_window", "add_monoid_strides_window_size=static", - # x: f32[1, 16*b, 1] - lambda x: lax.reduce_window( - jnp.reshape(x, (1, -1, 1)), - np.array(0., _f32), lax.add, (1, 4, 1), (1, 2, 1), "SAME"), - arg_descriptors=[RandArg((1, 128, 16), _f32)], - polymorphic_shapes=["_, b1, ..."]), - PolyHarness("reduce_window", "add_generic_window_size=static", - # x: f32[1, 16*b, 1] - # Use an initial value of 1. to trigger the generic reduction path - lambda x: lax.reduce_window( - jnp.reshape(x, (1, -1, 1)), - np.array(1., _f32), lax.add, (1, 4, 1), (1, 2, 1), "SAME"), - arg_descriptors=[RandArg((1, 128, 16), _f32)], - polymorphic_shapes=["_, b1, ..."]), - PolyHarness("reduce_window", "variadic_generic_window_size=static", - # x: f32[b, 8] y: f32[b, 8] - lambda x, y: lax.reduce_window( - (x, y), (np.array(1., _f32), np.array(2, _i32)), - lambda xy0, xy1: (lax.add(xy0[0], xy1[0]), - lax.sub(xy0[1], xy1[1])), - (2, 2), (1, 1), "VALID"), - arg_descriptors=[RandArg((3, 8), _f32), RandArg((3, 8), _i32)], - polymorphic_shapes=["b, ...", "b, ..."]), - PolyHarness("reduce_window", "variadic_generic_window_size=dynamic", - # x: f32[b, 8] y: f32[b, 8] - lambda x, y: lax.reduce_window( - (x, y), (np.array(1., _f32), np.array(2, _i32)), - lambda xy0, xy1: (lax.add(xy0[0], xy1[0]), - lax.sub(xy0[1], xy1[1])), - (2, x.shape[0]), (1, 1), "VALID"), - arg_descriptors=[RandArg((3, 8), _f32), RandArg((3, 8), _i32)], - polymorphic_shapes=["b, ...", "b, ..."]), - # TODO(necula): not yet supported, but also unlikely to come up. - # PolyHarness("random_uniform", "odd", - # lambda key, a: jax.random.uniform(key, (2 * a.shape[0] + 1, a.shape[1]), - # dtype=_f32), - # [RandArg((2,), np.uint32), RandArg((3, 5), _f32)], - # polymorphic_shapes=[None, "b0, ..."]), - [ - PolyHarness("reduce", reduce_op.__name__, - lambda x: reduce_op(x, axis=-1, keepdims=True), # type: ignore - arg_descriptors=[RandArg((3, 5), _f32)], - polymorphic_shapes=["b, ..."]) - for reduce_op in [jnp.all, jnp.any, jnp.max, jnp.min, jnp.prod, jnp.sum] - ], - # Repeat f32[b, 2] * 3 - PolyHarness("repeat", "repeats=int_axis=0", - lambda x: jnp.repeat(x, repeats=3, axis=0), - arg_descriptors=[RandArg((3, 2), _f32)], - polymorphic_shapes=["b, ..."]), - # Repeat f32[b, 2] * b - PolyHarness("repeat", "repeats=poly_axis=0", - lambda x: jnp.repeat(x, repeats=x.shape[0], axis=0), - arg_descriptors=[RandArg((3, 2), _f32)], - polymorphic_shapes=["b, ..."]), - # Repeat f32[b, 2] * b - PolyHarness("repeat", "repeats=poly_axis=None", - lambda x: jnp.repeat(x, repeats=x.shape[0], axis=None), - arg_descriptors=[RandArg((3, 2), _f32)], - polymorphic_shapes=["b, ..."]), - # Repeat f32 * b - PolyHarness("repeat", "repeats=poly_axis=None_scalar", - lambda x, y: jnp.repeat(x, repeats=y.shape[0], axis=None) + y, - arg_descriptors=[RandArg((), _f32), RandArg((3, 1), _f32)], - polymorphic_shapes=[None, "b0, ..."]), - PolyHarness("repeat", "repeats=poly_axis=None_total_repeat_length1", - lambda x: jnp.repeat(x, repeats=x.shape[0], axis=None, total_repeat_length=8), - arg_descriptors=[RandArg((3, 2), _f32)], - polymorphic_shapes=["b, ..."], - expect_error=(ValueError, "jnp.repeat with a non-constant `repeats` is supported only .*")), - PolyHarness("reshape", "0", - lambda x: x.reshape([x.shape[0], -1]), - arg_descriptors=[RandArg((3, 2, 3), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("reshape", "1", - lambda x: x.reshape([x.shape[0], -1]), - arg_descriptors=[RandArg((3, 2, 3), _f32)], - polymorphic_shapes=["b0, b1, ..."]), - PolyHarness("reshape", "2", - lambda x: x.reshape([x.shape[0], -1, x.shape[3], x.shape[2]]), - arg_descriptors=[RandArg((3, 4, 5, 6, 7), _f32)], - polymorphic_shapes=["b0, _, b2, b3, ..."]), - PolyHarness("reshape", "3", - lambda x: jnp.reshape(x, [2, -1]), - arg_descriptors=[RandArg((3, 4, 5, 6, 7), _f32)], - polymorphic_shapes=["b0, _, b2, ..."]), - PolyHarness("reshape", "_issue_9975", - # The newshape is a scalar - lambda x: jnp.reshape(x, x.shape[0] * x.shape[1]), - arg_descriptors=[RandArg((3, 4), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("reshape", "error", - lambda x: x.reshape([x.shape[0], -1, 3]), - arg_descriptors=[RandArg((3, 2, 4), _f32)], - polymorphic_shapes=["b, ..."], - input_signature=[tf.TensorSpec([None, 2, 4], _f32)], - skip_jax_run=True, - expect_error=(core.InconclusiveDimensionOperation, - re.escape( - "Cannot divide evenly the sizes of shapes (b, 2, 4) and (b, -1, 3)"))), - PolyHarness("roll", "axis=0", - lambda x: jnp.roll(x, 2, axis=0), - arg_descriptors=[RandArg((3, 4), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("roll", "axis=None", - lambda x: jnp.roll(x, 2), - arg_descriptors=[RandArg((3, 4), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("scatter_add", "", - partial(lax.scatter_add, indices_are_sorted=False, unique_indices=True), - arg_descriptors=[RandArg((7, 4), _f32), # op: [b, 4] - np.array([[1], [2]], np.int32), # indices: [2, 1] - RandArg((7, 2), _f32), # updates: [b, 2] - StaticArg(lax.ScatterDimensionNumbers((0,), (1,), (1,)))], - polymorphic_shapes=["b, ...", None, "b, ..."]), - PolyHarness("scatter_add", "clip0", - partial(lax.scatter_add, indices_are_sorted=False, unique_indices=True, mode=lax.GatherScatterMode.CLIP), - arg_descriptors=[RandArg((7, 4), _f32), # op: [b, 4] - np.array([[1], [2]], np.int32), # indices: [2, 1] - RandArg((7, 2), _f32), # updates: [b, 2] - StaticArg(lax.ScatterDimensionNumbers((0,), (1,), (1,)))], - polymorphic_shapes=["b, ...", None, "b, ..."]), - PolyHarness("scatter_add", "clip1", - partial(lax.scatter_add, indices_are_sorted=False, unique_indices=True, mode=lax.GatherScatterMode.CLIP), - arg_descriptors=[RandArg((7, 4), _f32), # op: [b, 4] - # indices: [b, 2] - np.array([[1, 2], [-2, 0], [6, 4], [7, -1], [1, 0], [3, 0], [0, 5]], np.int32), - RandArg((7, 1), _f32), # updates: [b, 1] - StaticArg(lax.ScatterDimensionNumbers((1,), (0,), (0, 1,)))], - polymorphic_shapes=["b, ...", "b, ...", "b, ..."]), - PolyHarness("scatter_grad", "", - lambda *args: jax.grad( - lambda *args: - jnp.sum(lax.scatter( - *args, - indices_are_sorted=False, - unique_indices=False, - )) - )(*args), - arg_descriptors=[RandArg((7, 4), _f32), # : [b, 4] - np.array([[1], [2]], np.int32), # indices: [2, 1] - RandArg((7, 2), _f32), # updates: [b, 2] - StaticArg(lax.ScatterDimensionNumbers((0,), (1,), (1,))), - ], - polymorphic_shapes=["b, ...", None, "b, ..."]), - PolyHarness("scatter_grad", "poly_indices", - lambda *args: jax.grad( - lambda *args: - jnp.sum(lax.scatter( - *args, - indices_are_sorted=False, - unique_indices=False)) - )(*args), - arg_descriptors=[RandArg((7, 4), _f32), # op: [b, 4] - # indices: [b, 2] - np.array( - [[1, 2], [-2, 0], [6, 4], [7, -1], [1, 0], - [3, 0], [0, 5]], np.int32), - RandArg((7, 1), _f32), # updates: [b, 1] - StaticArg(lax.ScatterDimensionNumbers((1,), (0,), (0, 1))), - ], - polymorphic_shapes=["b, ...", "b, ...", "b, ..."]), - [ - PolyHarness("schur", - f"shape={jtu.format_shape_dtype_string(shape, dtype)}_{poly=}_{compute_schur_vectors=}", - lambda a, compute_schur_vectors: lax.linalg.schur( - a, compute_schur_vectors=compute_schur_vectors), - arg_descriptors=[RandArg(shape, dtype), - StaticArg(compute_schur_vectors)], - polymorphic_shapes=[poly], - # In non-native serialization, we cannot check exact match, - # we ought to check the invariants of the result. - check_result=config.jax2tf_default_native_serialization.value) - for dtype in {np.float32, np.float64, np.complex64, np.complex128} & jtu.supported_dtypes() - for compute_schur_vectors in [True, False] - for (shape, poly) in [ - ((3, 3), "w, w"), - ((3, 4, 4), "b, w, w"), - ] - ], - PolyHarness("select", "0", - # x.shape = (b, 3) - lambda x: lax.select(x > 5., x, x), - arg_descriptors=[RandArg((7, 3), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("select", "1", - # x.shape = (b, 3); y.shape = (3,) - jax.vmap(lambda x, y: lax.select(x > 5., x, y), in_axes=[0, None]), - arg_descriptors=[RandArg((7, 3), _f32), RandArg((3,), _f32)], - polymorphic_shapes=["b, ...", None]), - PolyHarness("slice", "entire_axis", - lambda x: lax.slice(x, start_indices=(0, 1), limit_indices=(x.shape[0], 3)), - arg_descriptors=[RandArg((7, 3), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("slice_in_dim", "entire_axis", - lambda x: lax.slice_in_dim(x, 0, x.shape[0], stride=1, axis=0), - arg_descriptors=[RandArg((3, 4), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("slice_in_dim", "start=neg", - lambda x: lax.slice_in_dim(x, -1, x.shape[0], stride=1, axis=0), - arg_descriptors=[RandArg((3, 4), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("slice_in_dim", "limit=neg", - lambda x: lax.slice_in_dim(x, 0, -1, stride=1, axis=0), - arg_descriptors=[RandArg((3, 4), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("slice_in_dim", "stride=2_even", - lambda x: lax.slice_in_dim(x, 0, x.shape[0], stride=2, axis=0), - arg_descriptors=[RandArg((12, 4), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("slice_in_dim", "stride=2_odd", - lambda x: lax.slice_in_dim(x, 0, x.shape[0], stride=2, axis=0), - arg_descriptors=[RandArg((13, 4), _f32)], - polymorphic_shapes=["b, ..."]), - # Not yet, the slice_in_dim does int(stride) - # PolyHarness("slice_in_dim", "stride=sym", - # lambda x: lax.slice_in_dim(x, 0, x.shape[0], stride=x.shape[0] // 4, axis=0), - # arg_descriptors=[RandArg((13, 4), _f32)], - # polymorphic_shapes=["b, ..."]), - PolyHarness("squeeze", "axis=empty", - jnp.squeeze, - arg_descriptors=[RandArg((5,), _f32), StaticArg(())], - polymorphic_shapes=["b, ..."]), - PolyHarness("squeeze", "axis=None", - jnp.squeeze, - arg_descriptors=[RandArg((5,), _f32), StaticArg(None)], - polymorphic_shapes=["b, ..."], - expect_error=(ValueError, "jnp.squeeze with axis=None is not supported with shape polymorphism")), - PolyHarness("squeeze", "axis=1", - jnp.squeeze, - arg_descriptors=[RandArg((4, 1), _f32), StaticArg((1,))], - polymorphic_shapes=["b, ..."]), - PolyHarness("squeeze", "axis=1_2", - jnp.squeeze, - arg_descriptors=[RandArg((4, 1, 1), _f32), StaticArg((1, 2))], - polymorphic_shapes=["b, ..."]), - PolyHarness("squeeze", "error", - jnp.squeeze, - arg_descriptors=[RandArg((3, 33), _f32), StaticArg(-1)], - polymorphic_shapes=["b0, b1"], - input_signature=[tf.TensorSpec([None, None], _f32)], - skip_jax_run=True, - expect_error=(ValueError, - re.escape( - "cannot select an axis to squeeze out which has size not equal to one, got shape=(b0, b1) and dimensions=(1,)")) - ), - PolyHarness("take", "", - lambda a, i: jnp.take(a, i, axis=1), - arg_descriptors=[RandArg((3, 4, 5), _f32), np.array([1, 2], np.int32)], - polymorphic_shapes=["b, ...", None]), - PolyHarness("take_along_axis", "0", - lambda x, y: jnp.take_along_axis(x, y, axis=0), - arg_descriptors=[RandArg((5, 2), _f32), RandArg((5, 1), np.int32)], - polymorphic_shapes=["b, ...", "b, ..."]), - PolyHarness("take_along_axis", "1", - lambda x, y: jnp.take_along_axis(x, y, axis=1), - arg_descriptors=[RandArg((5, 2), _f32), RandArg((5, 1), np.int32)], - polymorphic_shapes=["b, ...", "b, ..."]), - PolyHarness("tile", "0", - lambda x: jnp.tile(x, (1, 2)), - arg_descriptors=[RandArg((4, 3), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("tile", "1", - # The repetitions are polys - lambda x: jnp.tile(x, (1, x.shape[0])), - arg_descriptors=[RandArg((4, 2), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("lax_top_k", "", - lambda x: jax.lax.top_k(x, x.shape[-1] - 1), - arg_descriptors=[RandArg((16,), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("tri", "N=poly_M=None", - lambda x: jnp.tri(x.shape[0]) + x, - arg_descriptors=[RandArg((3, 1), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("tri", "N=poly_M=poly", - lambda x: jnp.tri(x.shape[0], M=x.shape[0] + 2) + x, - arg_descriptors=[RandArg((3, 1), _f32)], - polymorphic_shapes=["b, ..."]), - PolyHarness("tril", "", - lambda x: jnp.tril(jnp.ones((x.shape[0], x.shape[0] + x.shape[1]), - dtype=_f32), - k=x.shape[1]), - arg_descriptors=[RandArg((3, 4), _f32)], - polymorphic_shapes=["m, n"]), - [ - PolyHarness("triangular_solve", - f"shape={jtu.format_shape_dtype_string(a_shape, dtype)}_{left_side=}_{a_poly=}_{b_poly=}", - lambda a, b, left_side: lax.linalg.triangular_solve( - jnp.tril(a) + 5 * jnp.eye(a.shape[-1], dtype=a.dtype), - b, left_side=left_side, - lower=True, transpose_a=False, conjugate_a=False, - unit_diagonal=False), - arg_descriptors=[RandArg(a_shape, dtype), - RandArg(b_shape, dtype), - StaticArg(left_side)], - polymorphic_shapes=[a_poly, b_poly], - # In non-native serialization, we cannot check exact match, - # we ought to check the invariants of the result. - check_result=config.jax2tf_default_native_serialization.value) - for dtype in {np.float32, np.float64, np.complex64, np.complex128} & jtu.supported_dtypes() - for (left_side, a_shape, b_shape, a_poly, b_poly) in [ - (True, (3, 4, 4), (3, 4, 5), "b, ...", "b, ..."), - (True, (3, 4, 4), (3, 4, 5), "b, k, k", "b, k, m"), - (False, (3, 4, 4), (3, 5, 4), "b, ...", "b, ..."), - (False, (3, 4, 4), (3, 5, 4), "b, k, k", "b, m, k"), - # We use custom calls on CPU if not batched - (True, (4, 4), (4, 5), "k, k", "k, m"), - (False, (4, 4), (5, 4), "k, k", "m, k"), - ] - ], - [ - PolyHarness("var", - f"{axis=}_{keepdims=}_where=None", - lambda x, axis, keepdims: jnp.var(x, axis=axis, keepdims=keepdims, where=None), - arg_descriptors=[RandArg((7, 8, 4), _f32), StaticArg(axis), StaticArg(keepdims)], - polymorphic_shapes=["b, ..."]) - for keepdims in [False, True] - for axis in [None, (0,), (0, 1), (1,)] - ], - [ - PolyHarness("var", - f"{axis=}_{keepdims=}_where=Some", - lambda x, where, axis, keepdims: jnp.var(x, axis=axis, keepdims=keepdims, where=where), - arg_descriptors=[RandArg((7, 8, 4), _f32), RandArg((7, 8, 4), np.bool_), StaticArg(axis), StaticArg(keepdims)], - polymorphic_shapes=["b, ...", "b, ..."]) - for keepdims in [False, True] - for axis in [None, (0,), (0, 1), (1,)] - ], - PolyHarness("where", "", - jnp.where, - arg_descriptors=[RandArg((2,), np.bool_), RandArg((), _f32), RandArg((2,), _f32)], - polymorphic_shapes=["b, ...", None, "b, ..."]), -] - -def _get_jax2tf_limitations( - device, h: test_harnesses.Harness) -> Sequence[Jax2TfLimitation]: - # And the jax2tf limitations - def applicable_jax2tf_limitation(l: Jax2TfLimitation) -> bool: - # The CheckShapePolymorphism uses tf.function, so we care about "graph" - return l.filter(device=device, dtype=h.dtype, mode="graph") - - limitations = Jax2TfLimitation.limitations_for_harness(h) - return tuple(filter(applicable_jax2tf_limitation, limitations)) - -### We add to the test harnesses some that are obtained from the -### primitive harnesses by applying vmap to the function and then asserting -### that we can convert shape polymorphically the result. - -def _make_vmap_primitive_harnesses() -> Sequence[PolyHarness]: - """For each harness group, pick a single dtype. - - See PolyHarness for documentation. - - Ignore harnesses that fail in graph mode in jax2tf. - """ - all_h = test_harnesses.all_harnesses - res = [] - - # Index by group - harness_groups: dict[ - str, Sequence[test_harnesses.Harness]] = collections.defaultdict(list) - device = jtu.device_under_test() - - for h in all_h: - # Drop the JAX limitations - if not h.filter(device_under_test=device, include_jax_unimpl=False): - continue - # And the jax2tf limitations that are known to result in TF error. - if any(l.expect_tf_error for l in _get_jax2tf_limitations(device, h)): - continue - harness_groups[h.group_name].append(h) - - selected_harnesses = [] - for _, hlist in harness_groups.items(): - # Pick the dtype with the most harnesses in this group. Some harness - # groups only test different use cases at a few dtypes. - c = collections.Counter([h.dtype for h in hlist]) - (_, max_count), = c.most_common(1) - # Pick the first alphabetically among those with max_count, to ensure - # that we generate deterministic tests. - dtypes_with_max_count = (dtype for dtype, count in c.items() - if count == max_count) - dtype, *_ = sorted(dtypes_with_max_count, key=str) - selected_harnesses.extend([h for h in hlist if h.dtype == dtype]) - - batch_size = 3 - for h in selected_harnesses: - if h.group_name in [ - "tridiagonal_solve", # batching not implemented in JAX - ]: - continue - - def make_batched_arg_descriptor( - ad: test_harnesses.ArgDescriptor) -> test_harnesses.ArgDescriptor | None: - if isinstance(ad, RandArg): - return RandArg((batch_size,) + ad.shape, ad.dtype) - elif isinstance(ad, CustomArg): - def wrap_custom(rng): - arg = ad.make(rng) - return np.stack([arg] * batch_size) - - return CustomArg(wrap_custom) - else: - assert isinstance(ad, np.ndarray), ad - return np.stack([ad] * batch_size) - - new_args = [make_batched_arg_descriptor(ad) - for ad in h.arg_descriptors - if not isinstance(ad, StaticArg)] - - # This test does not make sense for nullary functions - if not new_args: - continue - - limitations = [ - l for l in _get_jax2tf_limitations(device, h) - if not l.skip_comparison and (l.custom_assert or l.tol is not None)] - - vmap_harness = PolyHarness("vmap_" + h.group_name, h.name, - jax.vmap(h.dyn_fun, in_axes=0, out_axes=0), - arg_descriptors=new_args, - polymorphic_shapes=["b, ..."] * len(new_args), - limitations=limitations) - vmap_harness.original_harness = h - res.append(vmap_harness) - return res - -_POLY_SHAPE_TEST_HARNESSES.append(_make_vmap_primitive_harnesses()) - -def _flatten_harnesses(harnesses): - res = [] - for h in harnesses: - if isinstance(h, Sequence): - res.extend(_flatten_harnesses(h)) - else: - res.append(h) - return res - - -class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase): - """Tests for primitives that take shape values as parameters.""" - - # This test runs for all _POLY_SHAPE_PRIMITIVE_HARNESSES. - - # For each primitive "xxx" the test will be called "test_harness_xxx_...". - # If you want to run this test for only one harness that includes "foo" - # in the name (after test_harness), add parameter `one_containing="foo"` - # to parameterized below. - @test_harnesses.parameterized( - _flatten_harnesses(_POLY_SHAPE_TEST_HARNESSES), - #one_containing="", - ) - def test_harness(self, harness: PolyHarness): - if harness.expect_error == expect_error_associative_scan and ( - not config.jax2tf_default_native_serialization.value - or jtu.test_device_matches(["tpu"]) - ): - harness.expect_error = (None, None) - - # Exclude some harnesses that are known to fail for native serialization - # FOR NATIVE SERIALIZATION - # Set of harness.group_name:platform that are implemented with custom call - custom_call_harnesses = { - "householder_product:gpu", - "vmap_geqrf:gpu", # used for linalg.qr - "vmap_lu:gpu", - # custom_linear_solve works as long as lu works. - "vmap_custom_linear_solve:gpu", - "vmap_qr:gpu", "qr:gpu", - "vmap_svd:gpu", - } - if f"{harness.group_name}:{jtu.device_under_test()}" in custom_call_harnesses: - raise unittest.SkipTest("native serialization with shape polymorphism not implemented for custom calls; b/261671778") - - if harness.group_name == "schur" and not jtu.test_device_matches(["cpu"]): - raise unittest.SkipTest("schur decomposition is only implemented on CPU.") - - if "fft_fft_type" in harness.fullname: - if "nr_fft_lengths_2" in harness.fullname: - raise unittest.SkipTest("native serialization with shape polymorphism not implemented for fft with non-constant fft_lengths on GPU and TPU") - - if harness.group_name == "vmap_eigh" and jtu.test_device_matches(["gpu"]): - # For eigh on GPU with shape polymorphism under native serialization, - # we use a different lowering for small matrices. See README.md. - shape = harness.original_harness.params["shape"] - if 0 < shape[-1] <= 32: - harness.check_result = False - - if harness.group_name == "vmap_eigh": - raise unittest.SkipTest( - "Should not compare eigendecompositions for equality directly" - "because eigenvalues are sorted.") - - if harness.group_name == "vmap_tan": - # Tan (b/274462307) require support for custom call stablehlo.tan. - raise unittest.SkipTest( - "native lowering with shape polymorphism requires additional StableHLO feature support") - - if (jtu.test_device_matches(["cpu", "gpu"]) and - harness.fullname in [ - "cumsum_reduce_axis_poly", "cumprod_reduce_axis_poly", - "cummin_reduce_axis_poly", "cummax_reduce_axis_poly", - "cumlogsumexp_reduce_axis_poly", - "jnp_insert_insert_constant", "jnp_insert_insert_poly", - "jnp_nonzero_size_constant", "jnp_nonzero_size_poly"]): - # Need associative scan reductions on CPU and GPU. On TPU we use the - # reduce_window HLO, but on CPU and GPU (with axis size >= 32) we use - # a recursive associative scan that we cannot express with shape - # polymorphism. - raise unittest.SkipTest( - "native serialization with shape polymorphism not implemented for window_reductions on CPU and GPU") - - # FOR BOTH NATIVE AND GRAPH SERIALIZATION - if harness.group_name == "vmap_conv_general_dilated": - # https://github.com/openxla/stablehlo/issues/1268 - raise unittest.SkipTest("Need more dynamism for DynamicConvOp") - - if harness.group_name == "eig" and not jtu.test_device_matches(["cpu"]): - raise unittest.SkipTest("JAX implements eig only on CPU.") - - with jtu.thread_local_config_context(**harness.override_jax_config_flags): - harness.run_test(self) - - if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/jax/experimental/jax2tf/tests/sharding_test.py b/jax/experimental/jax2tf/tests/sharding_test.py index 653ddce7dca4..f0bc0ffa78d5 100644 --- a/jax/experimental/jax2tf/tests/sharding_test.py +++ b/jax/experimental/jax2tf/tests/sharding_test.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for the jax2tf conversion of pjit. +"""Tests for handling of sharding in the jax2tf conversion of jit. To verify that the tests do run indeed on multiple devices you can run @@ -24,6 +24,7 @@ import re from typing import Any import unittest +import warnings from absl import app from absl.testing import absltest @@ -33,18 +34,25 @@ from jax._src import config from jax._src import test_util as jtu from jax._src import xla_bridge +from jax._src.lib import xla_client as xc from jax import lax from jax.experimental import jax2tf from jax.experimental import pjit -from jax.experimental.shard_map import shard_map +from jax._src.shard_map import shard_map from jax.sharding import NamedSharding from jax.sharding import Mesh from jax.sharding import PartitionSpec as P +from jax.sharding import AxisType import jax.numpy as jnp import numpy as np -import tensorflow as tf +# TODO(b/470156950): Remove this once a proper fix is in place +with warnings.catch_warnings(): + warnings.filterwarnings("ignore", + category=FutureWarning, + message=".*np.object.*") + import tensorflow as tf config.parse_flags_with_absl() jtu.request_cpu_devices(8) @@ -69,6 +77,7 @@ def initialize_tf_tpu(): app.call_after_init(initialize_tf_tpu) +@jtu.thread_unsafe_test_class() class ShardingTest(tf_test_util.JaxToTfTestCase): """Tests that inspect the HLO for the sharding annotations. """ @@ -81,14 +90,10 @@ def setUp(self): raise unittest.SkipTest("Test requires at least 2 local devices") self.devices = np.array(jax.devices()[:2]) # use 2 devices - self.warning_ctx = jtu.ignore_warning( - message="jax2tf.convert with native_serialization=False is deprecated" + def get_xla_options(self): + return tf.tpu.XLAOptions( + use_shardy_partitioner=jax.config.jax_use_shardy_partitioner ) - self.warning_ctx.__enter__() - - def tearDown(self): - self.warning_ctx.__exit__(None, None, None) - super().tearDown() def log_jax_hlo(self, f_jax, args: Sequence[Any], *, num_replicas=1, num_partitions=2): @@ -109,8 +114,9 @@ def log_jax_hlo(self, f_jax, args: Sequence[Any], *, device_assignment=device_assignment, use_spmd_partitioning=use_spmd_partitioning, ) - jax_optimized_hlo = backend.compile( - jax_hlo, compile_options).hlo_modules()[0].to_string() + executable = backend.compile_and_load( + jax_hlo, xc.DeviceList(tuple(self.devices.flat)), compile_options) # type: ignore + jax_optimized_hlo = executable.hlo_modules()[0].to_string() logging.info("[%s] got JAX optimized HLO for platform %s %s", self._testMethodName, backend.platform, jax_optimized_hlo) @@ -186,18 +192,18 @@ def check_sharding(self, f_tf, args_tf: Sequence[Any], *, for in_shardings in ("missing", None, "P") for out_shardings in ("missing", None, "P") ]) - @jtu.with_mesh([("x", 2)]) - def test_pjit_basic(self, in_shardings="P", out_shardings="P"): + @jtu.with_explicit_mesh((2,), ("x",), axis_types=(AxisType.Auto,)) + def test_jit_basic(self, *, mesh, in_shardings="P", out_shardings="P"): # Ensure that we can distinguish the inputs and outputs by shape def f_jax(x): # f32[10,20] -> f32[20,10] return jnp.sin(x.T) - pjit_kwargs = {} + jit_kwargs = {} if in_shardings != "missing": - pjit_kwargs["in_shardings"] = (P(None, "x") if in_shardings == "P" else None) + jit_kwargs["in_shardings"] = (P(None, "x") if in_shardings == "P" else None) if out_shardings != "missing": - pjit_kwargs["out_shardings"] = (P("x", None) if out_shardings == "P" else None) - f_jax = pjit.pjit(f_jax, **pjit_kwargs) + jit_kwargs["out_shardings"] = (P("x", None) if out_shardings == "P" else None) + f_jax = jax.jit(f_jax, **jit_kwargs) x_shape = (10, 20) x = np.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape) @@ -209,20 +215,21 @@ def f_tf(x): f_converted = jax2tf.convert(f_jax) if jtu.test_device_matches(["tpu"]): return tf.compat.v1.tpu.rewrite( - f_converted, [tf.convert_to_tensor(x)], + f_converted, + [tf.convert_to_tensor(x)], device_assignment=self.device_assignment( computation_shape=[1, 1, 1, 2], - ))[0] + ), + xla_options=self.get_xla_options(), + )[0] else: return f_converted(x) # Annotation count for the input count_in_P = 1 if in_shardings == "P" else 0 - if config.jax2tf_default_native_serialization.value: - # With native serialization even unspecified in_shardings turn into replicated - count_in_replicated = 1 if in_shardings in [None, "missing"] else 0 - else: - count_in_replicated = 1 if in_shardings is None else 0 + # With native serialization even unspecified in_shardings turn into replicated + count_in_replicated = 1 if in_shardings in [None, "missing"] else 0 + # Annotation count for the output count_out_P = 1 if out_shardings == "P" else 0 count_out_replicated = 1 if out_shardings is None else 0 @@ -231,10 +238,10 @@ def f_tf(x): jax2tf.convert(f_jax), [x], checks=[ # The argument - (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2\]", + (r"f32\[10,20\].*custom_call_target.*\"Sharding.*sharding.*devices=\[1,2\]", count_in_P), # The result - (r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", + (r"f32\[20,10\].*custom_call_target.*\"Sharding.*sharding.*devices=\[2,1\]", count_out_P), ]) # TODO(b/326476605): Change the condition below if required. @@ -242,11 +249,11 @@ def f_tf(x): self.check_sharding( jax2tf.convert(f_jax), [x], checks=[ - (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*replicated", + (r"f32\[10,20\].*custom_call_target.*\"Sharding.*sharding.*replicated", count_in_replicated), - (r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*replicated", + (r"f32\[20,10\].*custom_call_target.*\"Sharding.*sharding.*replicated", count_out_replicated), - (r"custom_call_target.*Sharding", + (r"custom_call_target.*\"Sharding", count_in_P + count_in_replicated + count_out_P + count_out_replicated), ]) @@ -254,10 +261,10 @@ def f_tf(x): res_tf = f_tf(x) self.assertAllClose(res_tf.numpy(), res_jax) - @jtu.with_mesh([("x", 2)]) - def test_pjit_variable_arg(self): + @jtu.with_explicit_mesh((2,), ("x",), axis_types=(AxisType.Auto,)) + def test_jit_variable_arg(self, mesh): # The first argument is a tf.Variable - @partial(pjit.pjit, in_shardings=(P(None, "x"), P("x", None)), + @jax.jit(in_shardings=(P(None, "x"), P("x", None)), out_shardings=None) def f_jax(x, y): # f32[10,20] , f32[20,30] -> f32[10,30] return x @ y @@ -276,21 +283,19 @@ def f_jax(x, y): # f32[10,20] , f32[20,30] -> f32[10,30] f_tf, [y], checks=[ # The variable argument - (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2\]", 1), + (r"f32\[10,20\].*custom_call_target.*\"Sharding.*sharding.*devices=\[1,2\]", 1), # The y argument - (r"f32\[20,30\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", 1), - # The output sharding - (r"f32\[10,30\].*custom_call_target.*Sharding.*sharding.*replicated", 1), + (r"f32\[20,30\].*custom_call_target.*\"Sharding.*sharding.*devices=\[2,1\]", 1), # No other annotations - (r"custom_call_target.*Sharding", 3) + (r"custom_call_target.*\"Sharding", 2) ]) - @jtu.with_mesh([("x", 2)]) - def test_pjit_closed_over_const(self): + @jtu.with_explicit_mesh((2,), ("x",), axis_types=(AxisType.Auto,)) + def test_jit_closed_over_const(self, mesh): x = np.ones((10, 20), dtype=np.float32) const = jnp.full((10, 20), 7, dtype=np.float32) - @partial(pjit.pjit, in_shardings=(P("x"),), out_shardings=None) + @jax.jit(in_shardings=(P("x"),), out_shardings=None) def f_jax(x): # f32[10,20] -> f32[20,10] return (x * const).T @@ -299,9 +304,12 @@ def f_tf(x): f_converted = jax2tf.convert(f_jax) if jtu.test_device_matches(["tpu"]): return tf.compat.v1.tpu.rewrite( - f_converted, [tf.convert_to_tensor(x)], + f_converted, + [tf.convert_to_tensor(x)], device_assignment=self.device_assignment( - computation_shape=[1, 1, 1, 2]) + computation_shape=[1, 1, 1, 2] + ), + xla_options=self.get_xla_options(), )[0] else: return f_converted(x) @@ -310,11 +318,8 @@ def f_tf(x): jax2tf.convert(f_jax), [x], checks=[ # x - (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", + (r"f32\[10,20\].*custom_call_target.*\"Sharding.*sharding.*devices=\[2,1\]", 1), - # The result - (r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*replicated", - self.GEQ(1)), ]) res_jax = f_jax(x) @@ -323,21 +328,29 @@ def f_tf(x): @jtu.parameterized_filterable( kwargs=[ - dict(testcase_name=f"_nested_pjit={nested_pjit}_constraint={constraint}_poly={poly}", - nested_pjit=nested_pjit, constraint=constraint, poly=poly) - # We add a constraint either with a nested pjit or with a sharding_constraint - for nested_pjit in (True, False) + dict(testcase_name=f"_nested_jit={nested_jit}_constraint={constraint}_poly={poly}", + nested_jit=nested_jit, constraint=constraint, poly=poly) + # We add a constraint either with a nested jit or with a sharding_constraint + for nested_jit in (True, False) for constraint in (None, "P") for poly in (None, "2*b1,_", "_,b2", "2*b1,b2") ]) + @jtu.ignore_warning(message='.*Please use `jax.jit` instead.*', + category=DeprecationWarning) @jtu.with_mesh([("x", 2)]) - def test_pjit_sharding_constraint(self, nested_pjit=True, constraint="P", poly="2*b1,b2"): + #@jtu.with_explicit_mesh((2,), ("x",), axis_types=(AxisType.Auto,)) + def test_jit_sharding_constraint(self, *, nested_jit=True, constraint="P", poly="2*b1,b2"): + # TODO(necula): move this test also to use jit. Currently, if we replace + # `with mesh` with `with set_mesh` (jtu.with_explicit_mesh above), and + # we keep using pjit, we get an error that the sharding constraint cannot + # be None. But if we also replace pjit with jit, there is no such error, + # and instead we see that the replicated shardings are silently dropped. constraint_sharding = P("x", None) if constraint == "P" else None @partial(pjit.pjit, in_shardings=None, out_shardings=None) - def f_jax(x): # x: f32[10, 20], optionally some axes as polymorphic + def f_jax(x): # x: f32[10, 20], optionally some axes are polymorphic y = jnp.concatenate([x, x], axis=1) # y: f32[10, 40] - if nested_pjit: + if nested_jit: y = pjit.pjit(lambda y: y, in_shardings=constraint_sharding, out_shardings=constraint_sharding)(y) else: @@ -351,22 +364,22 @@ def f_jax(x): # x: f32[10, 20], optionally some axes as polymorphic f_tf = jax2tf.convert(f_jax, polymorphic_shapes=poly) # If we use a pjit then we see two constraints, otherwise only 1 - count_inner_sharding = (2 if nested_pjit else 1) if constraint == "P" else 0 - count_inner_replicated = (2 if nested_pjit else 1) if constraint != "P" else 0 + count_inner_sharding = (2 if nested_jit else 1) if constraint == "P" else 0 + count_inner_replicated = (2 if nested_jit else 1) if constraint != "P" else 0 self.check_sharding( f_tf, [x], checks=[ # The input argument - (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*replicated", 1), + (r"f32\[10,20\].*custom_call_target.*\"Sharding.*sharding.*replicated", 1), # The y argument - (r"f32\[10,40\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", + (r"f32\[10,40\].*custom_call_target.*\"Sharding.*sharding.*devices=\[2,1\]", count_inner_sharding), - (r"f32\[10,40\].*custom_call_target.*Sharding.*sharding.*replicated", + (r"f32\[10,40\].*custom_call_target.*\"Sharding.*sharding.*replicated", count_inner_replicated), # The output sharding - (r"f32\[10,80\].*custom_call_target.*Sharding.*sharding.*replicated", 1), + (r"f32\[10,80\].*custom_call_target.*\"Sharding.*sharding.*replicated", 1), # No other annotations - (r"custom_call_target.*Sharding", 2 + count_inner_sharding + count_inner_replicated) + (r"custom_call_target.*\"Sharding", 2 + count_inner_sharding + count_inner_replicated) ]) @jtu.parameterized_filterable( @@ -376,9 +389,7 @@ def f_jax(x): # x: f32[10, 20], optionally some axes as polymorphic for in_shardings in ("missing", None, "P") for out_shardings in ("missing", None, "P") ]) - def test_grad_pjit(self, in_shardings="P", out_shardings=None): - if not config.jax2tf_default_native_serialization.value: - self.skipTest("TODO: failure in non-native serialization") + def test_grad_jit(self, in_shardings="P", out_shardings=None): local_devices = list(jax.local_devices()) size = 2 if len(local_devices) < size: @@ -388,14 +399,14 @@ def test_grad_pjit(self, in_shardings="P", out_shardings=None): def f_jax(x): # x: f32[10,20] -> f32[20,10] return jnp.sin(x.T) - pjit_kwargs = {} + jit_kwargs = {} if in_shardings != "missing": - pjit_kwargs["in_shardings"] = ( + jit_kwargs["in_shardings"] = ( NamedSharding(mesh, P(None, "x")) if in_shardings == "P" else None) if out_shardings != "missing": - pjit_kwargs["out_shardings"] = ( + jit_kwargs["out_shardings"] = ( NamedSharding(mesh, P("x", None)) if out_shardings == "P" else None) - f_jax = pjit.pjit(f_jax, **pjit_kwargs) + f_jax = jax.jit(f_jax, **jit_kwargs) x_shape = (10, 20) x = np.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape) @@ -411,33 +422,27 @@ def f_grad_tf(x_v, res_ct): # Annotation count for the primal input and the grad output count_in_P = self.GEQ(2) if in_shardings == "P" else 0 - if config.jax2tf_default_native_serialization.value: - # With native serialization even unspecified shardings turn into replicated - count_in_replicated = self.GEQ(2) if in_shardings in [None, "missing"] else 0 - else: - count_in_replicated = self.GEQ(2) if in_shardings is None else 0 + # With native serialization even unspecified shardings turn into replicated + count_in_replicated = self.GEQ(2) if in_shardings in [None, "missing"] else 0 # Annotation count for the contangent input count_out_P = self.GEQ(1) if out_shardings == "P" else 0 - if config.jax2tf_default_native_serialization.value: - # With native serialization even unspecified shardings turn into replicated - count_out_replicated = self.GEQ(1) if out_shardings in [None, "missing"] else 0 - else: - count_out_replicated = self.GEQ(1) if out_shardings is None else 0 + # With native serialization even unspecified shardings turn into replicated + count_out_replicated = self.GEQ(1) if out_shardings in [None, "missing"] else 0 self.check_sharding(f_grad_tf, [x, x.T], checks=[ # The input primal argument, and the output grad - (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2\]", count_in_P), + (r"f32\[10,20\].*custom_call_target.*\"Sharding.*sharding.*devices=\[1,2\]", count_in_P), # The primal result, and the input cotangent - (r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", count_out_P), + (r"f32\[20,10\].*custom_call_target.*\"Sharding.*sharding.*devices=\[2,1\]", count_out_P), ]) # TODO(b/326476605): Change the condition below if required. if out_shardings not in [None, "missing"] and in_shardings not in [None, "missing"]: self.check_sharding(f_grad_tf, [x, x.T], checks=[ - (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*replicated", count_in_replicated), + (r"f32\[10,20\].*custom_call_target.*\"Sharding.*sharding.*replicated", count_in_replicated), # The primal result, and the input cotangent - (r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", count_out_P), + (r"f32\[20,10\].*custom_call_target.*\"Sharding.*sharding.*devices=\[2,1\]", count_out_P), ]) def test_grad_sharding_different_mesh(self): @@ -456,9 +461,9 @@ def f_jax(x): shardings = NamedSharding(mesh, jax.sharding.PartitionSpec(("i",))) shardings_rev = NamedSharding(mesh_rev, jax.sharding.PartitionSpec(("i",))) - f_tf = tf.function(jax2tf.convert(pjit.pjit(f_jax, in_shardings=shardings)), + f_tf = tf.function(jax2tf.convert(jax.jit(f_jax, in_shardings=shardings)), autograph=False) - f_tf_rev = tf.function(jax2tf.convert(pjit.pjit(f_jax, in_shardings=shardings_rev)), + f_tf_rev = tf.function(jax2tf.convert(jax.jit(f_jax, in_shardings=shardings_rev)), autograph=False) inp = np.ones((2, 4), dtype=np.float32) @@ -474,54 +479,6 @@ def f_jax(x): g_rev = tape.gradient(res_tf_rev, input_v) self.assertAllClose(g, g_rev) - @jtu.parameterized_filterable( - kwargs=[ - dict(testcase_name=f"_func={func}", func=func) - for func in ("pjit_sharded", "pjit_replicated", - "nested_pjit_sharded", "nested_pjit_replicated") - ]) - def test_pjit_eager_error(self, func="pjit_sharded"): - if config.jax2tf_default_native_serialization.value: - raise unittest.SkipTest("There is no error in eager mode for native serialization") - - # Define some test functions - @partial(pjit.pjit, in_shardings=(P("x"),), - out_shardings=None) - def f_pjit_sharded(a): - return a + a - - @partial(pjit.pjit, in_shardings=None, - out_shardings=None) - def f_pjit_replicated(a): - return a + a - - def f_nested_pjit_sharded(a): - return a + pjit.pjit(jnp.sin, in_shardings=(P("x"),), out_shardings=None)(a) - - def f_nested_pjit_replicated(a): - return a + pjit.pjit(jnp.sin, in_shardings=None, out_shardings=None)(a) - - shape = (8, 10) - a = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) - - if func == "pjit_sharded": - f_jax = f_pjit_sharded - elif func == "pjit_replicated": - f_jax = f_pjit_replicated - elif func == "nested_pjit_sharded": - f_jax = f_nested_pjit_sharded - elif func == "nested_pjit_replicated": - f_jax = f_nested_pjit_replicated - else: - assert False - - with Mesh(self.devices, axis_names=("x",)): - _ = f_jax(a) - with self.assertRaisesRegex( - ValueError, - "function with sharded arguments or results must be used under a `tf.function` context"): - jax2tf.convert(f_jax)(a) - @jtu.ignore_warning(category=UserWarning, message="all_to_all .* are only implemented properly for TPUs and GPUs .*") def test_shmap_all_to_all(self): @@ -531,21 +488,25 @@ def test_shmap_all_to_all(self): mesh = Mesh(self.devices, axis_names=('x')) a = np.arange(4 * 4, dtype=np.float32).reshape((4, 4)) - @partial(pjit.pjit, - in_shardings=(P('x', None),), out_shardings=P(None, 'x')) + @partial(jax.jit, + in_shardings=(NamedSharding(mesh, P("x", None)),), + out_shardings=NamedSharding(mesh, P(None, "x"))) @partial(shard_map, mesh=mesh, - in_specs=(P('x', None),), out_specs=P(None, 'x')) + in_specs=(P("x", None),), out_specs=P(None, "x")) def f_jax(b): # b: f32[2, 4] - return lax.all_to_all(b, 'x', split_axis=1, concat_axis=1, tiled=True) + return lax.all_to_all(b, "x", split_axis=1, concat_axis=1, tiled=True) @tf.function(autograph=False, jit_compile=True) def f_tf(a): - f_converted = jax2tf.convert(f_jax, native_serialization=True) + f_converted = jax2tf.convert(f_jax) if jtu.test_device_matches(["tpu"]): return tf.compat.v1.tpu.rewrite( - f_converted, [tf.convert_to_tensor(a)], + f_converted, + [tf.convert_to_tensor(a)], device_assignment=self.device_assignment( - computation_shape=[1, 1, 1, 2]) + computation_shape=[1, 1, 1, 2] + ), + xla_options=self.get_xla_options(), )[0] else: return f_converted(a) @@ -562,21 +523,16 @@ def f_tf(a): res_tf = f_tf(a) self.assertAllClose(res_tf, res_jax) - # TODO(b/274648842): Failed to GetCompilerIr - # self.check_sharding( - # jax2tf.convert(f_jax, native_serialization=True), [a], - # checks=[]) - @unittest.skip("TODO(b/268295912): ShardingRemover crash,on all platforms!!!") def test_repro_xla_bug_shmap_collective_permute(self): mesh = Mesh(self.devices, axis_names=('x')) - @partial(pjit.pjit, + @partial(jax.jit, in_shardings=(P('x', None),), out_shardings=P('x', None)) @partial(shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x', None)) def f_jax(b): # b: f32[2, 4] - axis_size = lax.psum(1, 'x') + axis_size = lax.axis_size('x') perm = [(j, (j + 1) % axis_size) for j in range(axis_size)] return lax.ppermute(b, 'x', perm=perm) @@ -590,7 +546,7 @@ def f_jax(b): # b: f32[2, 4] # XLA bug: invoke the f_tf without tpu.replicate f_tf = tf.function( - jax2tf.convert(f_jax, native_serialization=True), + jax2tf.convert(f_jax), autograph=False, jit_compile=True) res_tf = f_tf(a) @@ -604,27 +560,30 @@ def f_jax(b): # b: f32[2, 4] def test_shmap_collective_permute(self, poly=None): if jtu.test_device_matches(["cpu"]): raise unittest.SkipTest("TODO(b/268295912): ShardingRemover crash") - mesh = Mesh(self.devices, axis_names=('x')) + mesh = Mesh(self.devices, axis_names=("x")) a = np.arange(4 * 4, dtype=np.float32).reshape((4, 4)) - @partial(pjit.pjit, - in_shardings=(P('x', None),), out_shardings=P('x', None)) + @partial(jax.jit, + in_shardings=(NamedSharding(mesh, P("x", None)),), + out_shardings=NamedSharding(mesh, P("x", None))) @partial(shard_map, mesh=mesh, - in_specs=(P('x', None),), out_specs=P('x', None)) + in_specs=(P("x", None),), out_specs=P("x", None)) def f_jax(b): # b: f32[2, 4] - axis_size = lax.psum(1, 'x') + axis_size = lax.axis_size("x") perm = [(j, (j + 1) % axis_size) for j in range(axis_size)] - return lax.ppermute(b, 'x', perm=perm) + return lax.ppermute(b, "x", perm=perm) @tf.function(autograph=False, jit_compile=True) def f_tf(a): - f_converted = jax2tf.convert(f_jax, native_serialization=True, - polymorphic_shapes=poly) + f_converted = jax2tf.convert(f_jax, polymorphic_shapes=poly) if jtu.test_device_matches(["tpu"]): res = tf.compat.v1.tpu.rewrite( - f_converted, [tf.convert_to_tensor(a)], + f_converted, + [tf.convert_to_tensor(a)], device_assignment=self.device_assignment( - computation_shape=[1, 1, 1, 2]) + computation_shape=[1, 1, 1, 2] + ), + xla_options=self.get_xla_options(), )[0] else: res = f_converted(a) @@ -638,10 +597,7 @@ def f_tf(a): self.assertAllClose(res_jax, expected) res_tf = f_tf(a) self.assertAllClose(res_tf, expected) - # TODO(b/274648842): Failed to GetCompilerIr - # self.check_sharding( - # jax2tf.convert(f_jax, native_serialization=True), [a], - # checks=[]) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/jax/experimental/jax2tf/tests/tf_test_util.py b/jax/experimental/jax2tf/tests/tf_test_util.py index 32f89e533daf..df7e59a0d8ce 100644 --- a/jax/experimental/jax2tf/tests/tf_test_util.py +++ b/jax/experimental/jax2tf/tests/tf_test_util.py @@ -15,7 +15,6 @@ from __future__ import annotations from collections.abc import Callable, Sequence -import contextlib import dataclasses import re import os @@ -34,6 +33,7 @@ from jax import export from jax._src import config from jax._src import xla_bridge +from jax._src.lib import xla_client as xc import numpy as np import tensorflow as tf from tensorflow.compiler.xla import xla_data_pb2 @@ -194,9 +194,7 @@ def setUp(self): export.maximum_supported_calling_convention_version, tfxla.call_module_maximum_supported_version()) - with contextlib.ExitStack() as stack: - stack.enter_context(tf.device(self.tf_default_device)) - self.addCleanup(stack.pop_all().close) + self.enter_context(tf.device(self.tf_default_device)) def assertDtypesMatch(self, x, y, *, canonicalize_dtypes=True): """Compares dtypes across JAX and TF dtypes. Overrides super method.""" @@ -215,7 +213,6 @@ def to_numpy_dtype(dt): def ConvertAndCompare(self, func_jax: Callable, *args, - enable_xla: bool = True, limitations: Sequence = ()): """Compares jax_func(*args) with convert(jax_func)(*args). @@ -228,8 +225,6 @@ def ConvertAndCompare(self, Args: func_jax: the function to invoke (``func_jax(*args)``) args: the arguments. - enable_xla: if True, allows the use of XLA ops in jax2tf.convert - (default: True). limitations: the set of limitations for this harness (not yet filtered by mode). """ @@ -238,7 +233,7 @@ def ConvertAndCompare(self, result_jax = func_jax(*args) # JAX result_tf = None - func_tf = jax2tf.convert(func_jax, enable_xla=enable_xla) + func_tf = jax2tf.convert(func_jax) unexpected_successes: list[str] = [] # Run the "compiled" mode first, it is most important @@ -248,7 +243,7 @@ def ConvertAndCompare(self, def log_message(extra): return f"[{self._testMethodName}] {mode=}: {extra}" - jax2tf_limits = tuple(filter(lambda l: l.filter(mode=mode), limitations)) + jax2tf_limits = tuple(filter(lambda l: l.filter(), limitations)) skip_tf_run = [l for l in jax2tf_limits if l.skip_tf_run] if skip_tf_run: @@ -344,7 +339,9 @@ def log_message(extra): tf_hlo) backend = xla_bridge.get_backend() - modules = backend.compile(str(jax_lowered.compiler_ir())).hlo_modules() + device_list = xc.DeviceList(tuple(backend.local_devices())) + modules = backend.compile_and_load( + str(jax_lowered.compiler_ir()), device_list).hlo_modules() jax_opt_hlo = modules[0].to_string() logging.info("[%s] JAX OPT HLO\n%s", self._testMethodName, jax_opt_hlo) @@ -411,14 +408,9 @@ def FindLargeTfConstants(self, tf_fun: Callable, *args, # graph. We count the number of characters in the textual representation # of the constant. f_tf_graph = tf.function(tf_fun, autograph=False).get_concrete_function(*args).graph.as_graph_def() - if config.jax2tf_default_native_serialization.value: - # This way of finding constants may be brittle, if the constant representation - # contains >. It seems tobe hex-encoded, so this may be safe. - large_consts = [m for m in re.findall(r"dense<([^>]+)>", str(f_tf_graph)) if len(m) >= at_least] - else: - # We cannot find the constants just with string matching because their - # representation may contain escaped " - large_consts = [str(n) for n in f_tf_graph.node if n.op == "Const" and len(str(n)) >= at_least] + # This way of finding constants may be brittle, if the constant representation + # contains >. It seems tobe hex-encoded, so this may be safe. + large_consts = [m for m in re.findall(r"dense<([^>]+)>", str(f_tf_graph)) if len(m) >= at_least] return large_consts def CheckOpMetadata(self, jax_fun, x, diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py index 15273f0fd02a..c5ffd89b0def 100644 --- a/jax/experimental/jet.py +++ b/jax/experimental/jet.py @@ -62,7 +62,6 @@ from jax import lax from jax import api_util import jax.numpy as jnp -from jax.experimental import pjit from jax.tree_util import (register_pytree_node, tree_structure, treedef_is_leaf, tree_flatten, tree_unflatten,) @@ -70,13 +69,14 @@ from jax._src import core from jax._src import dispatch from jax._src import linear_util as lu +from jax._src import pjit from jax._src import sharding_impls from jax._src.interpreters import partial_eval as pe from jax._src.lax import lax as lax_internal from jax._src.util import unzip2, weakref_lru_cache, safe_zip -def jet(fun, primals, series): +def jet(fun, primals, series, factorial_scaled=True, **_): r"""Taylor-mode higher-order automatic differentiation. Args: @@ -91,6 +91,13 @@ def jet(fun, primals, series): Together, `primals` and `series` make up a truncated Taylor polynomial. Should be either a tuple or a list of tuples or lists, and its length dictates the degree of the truncated Taylor polynomial. + factorial_scaled: If True, each term in both the input and output series is scaled + by the factorial of its order, so that the input and output series is a + Taylor series. This is the default behavior so that the n-th order term + in the input and output series is the n-th order derivative of the function. + If False, the input and output series are the non-factorial scaled Taylor + coefficients (i.e., the constant coefficients for each term in the Taylor + series). Returns: A ``(primals_out, series_out)`` pair, where ``primals_out`` is ``fun(*primals)``, @@ -120,10 +127,10 @@ def jet(fun, primals, series): 0.12467473 0.12467473 >>> print(f1, df(h0) * h1) - 0.7441479 0.74414825 + 0.74414825 0.74414825 >>> print(f2, ddf(h0) * h1 ** 2 + df(h0) * h2) - 2.9064622 2.9064634 + 2.9064636 2.9064634 """ try: order, = set(map(len, series)) @@ -151,9 +158,20 @@ def flatten_fun_output(f, store, *args): f, out_tree = flatten_fun_output( lu.wrap_init(fun, debug_info=api_util.debug_info("jet", fun, primals, {}))) + if factorial_scaled: + series = [[(term / fact(order + 1)) for order, term in enumerate(terms)] + for terms in series] out_primals, out_terms = jet_fun(jet_subtrace(f), order).call_wrapped(primals, series) + if factorial_scaled: + out_terms = [[term * fact(order + 1) for order, term in enumerate(terms)] + for terms in out_terms] return tree_unflatten(out_tree(), out_primals), tree_unflatten(out_tree(), out_terms) +jet2 = partial(jet, factorial_scaled=False) + +def fact(n): + return lax.exp(lax.lgamma(n+1.)) + @lu.transformation2 def jet_fun(f, order, primals, series): tag = core.TraceTag() @@ -365,7 +383,7 @@ def _cumulative_jet_rule(primals_in, series_in, *, axis: int, reverse: bool, # Irrespective of backend, we always use the parallel prefix scan # implementation when differentiating because reduce_window is not # arbitrarily differentiable. - return jet(partial(lax.associative_scan, combine_fn, axis=axis, + return jet2(partial(lax.associative_scan, combine_fn, axis=axis, reverse=reverse), primals_in, series_in) @@ -389,12 +407,12 @@ def deriv_prop(prim, deriv, primals_in, series_in): x, = primals_in series, = series_in primal_out = prim.bind(x) - c0, cs = jet(deriv, primals_in, series_in) + c0, cs = jet2(deriv, primals_in, series_in) c = [c0] + cs u = [x] + series v = [primal_out] + [None] * len(series) for k in range(1, len(v)): - v[k] = fact(k-1) * sum(_scale(k, j) * c[k-j] * u[j] for j in range(1, k + 1)) + v[k] = sum(j * c[k-j] * u[j] for j in range(1, k + 1)) / k primal_out, *series_out = v return primal_out, series_out @@ -405,11 +423,11 @@ def deriv_prop(prim, deriv, primals_in, series_in): lax.exp(lax.neg(lax.square(x))))) -def def_comp(prim, comp): +def def_comp(prim, comp, **kwargs): """ Define the jet rule for a primitive in terms of a composition of simpler primitives. """ - jet_rules[prim] = partial(jet, comp) + jet_rules[prim] = partial(jet2, comp, **kwargs) def_comp(lax.expm1_p, lambda x: lax.exp(x) - 1) @@ -446,22 +464,22 @@ def _erf_inv_rule(primals_in, series_in): # we know c[:k], we compute c[k] # propagate c to get v - v[k] = fact(k-1) * sum(_scale(k, j) * c[k-j] * u[j] for j in range(1, k + 1)) + v[k] = sum(j * c[k-j] * u[j] for j in range(1, k + 1)) / k # propagate v to get next c # square - tmp_sq[k] = fact(k) * sum(_scale2(k, j) * v[k-j] * v[j] for j in range(k + 1)) + tmp_sq[k] = sum( v[k-j] * v[j] for j in range(k + 1)) # exp - tmp_exp[k] = fact(k-1) * sum(_scale(k, j) * tmp_exp[k-j] * tmp_sq[j] for j in range(1, k + 1)) + tmp_exp[k] = sum(j * tmp_exp[k-j] * tmp_sq[j] for j in range(1, k + 1)) / k # const c[k] = deriv_const * tmp_exp[k] # we can't, and don't need, to compute c[k+1], just need to get the last v[k] k = len(series) - v[k] = fact(k-1) * sum(_scale(k, j) * c[k-j] * u[j] for j in range(1, k + 1)) + v[k] = sum(j * c[k-j] * u[j] for j in range(1, k + 1)) / k primal_out, *series_out = v return primal_out, series_out @@ -469,22 +487,13 @@ def _erf_inv_rule(primals_in, series_in): ### More complicated rules -def fact(n): - return lax.exp(lax.lgamma(n+1.)) - -def _scale(k, j): - return 1. / (fact(k - j) * fact(j - 1)) - -def _scale2(k, j): - return 1. / (fact(k - j) * fact(j)) - -def _exp_taylor(primals_in, series_in): +def _exp_taylor(primals_in, series_in, **_): x, = primals_in series, = series_in u = [x] + series v = [lax.exp(x)] + [None] * len(series) for k in range(1,len(v)): - v[k] = fact(k-1) * sum(_scale(k, j) * v[k-j] * u[j] for j in range(1, k+1)) + v[k] = sum(j * v[k-j] * u[j] for j in range(1, k+1)) / k primal_out, *series_out = v return primal_out, series_out jet_rules[lax.exp_p] = _exp_taylor @@ -492,12 +501,11 @@ def _exp_taylor(primals_in, series_in): def _pow_taylor(primals_in, series_in): u_, r_ = primals_in - x, series = jet(lambda x, y: lax.mul(y, lax.log(x)), primals_in, series_in) - + x, series = jet2(lambda x, y: lax.mul(y, lax.log(x)), primals_in, series_in) u = [x] + series v = [u_ ** r_] + [None] * len(series) for k in range(1, len(v)): - v[k] = fact(k-1) * sum(_scale(k, j) * v[k-j] * u[j] for j in range(1, k+1)) + v[k] = sum(j * v[k-j] * u[j] for j in range(1, k+1)) / k primal_out, *series_out = v return primal_out, series_out @@ -515,22 +523,22 @@ def _pow_by_squaring(x, n): def _integer_pow_taylor(primals_in, series_in, *, y): if y == 0: - return jet(jnp.ones_like, primals_in, series_in) + return jet2(jnp.ones_like, primals_in, series_in) else: - return jet(lambda x: _pow_by_squaring(x, y), primals_in, series_in) + return jet2(lambda x: _pow_by_squaring(x, y), primals_in, series_in) jet_rules[lax.integer_pow_p] = _integer_pow_taylor -def _logistic_taylor(primals_in, series_in): +def _logistic_taylor(primals_in, series_in, **_): x, = primals_in series, = series_in u = [x] + series v = [lax.logistic(x)] + [None] * len(series) e = [v[0] * (1 - v[0])] + [None] * len(series) # terms for sigmoid' = sigmoid * (1 - sigmoid) for k in range(1, len(v)): - v[k] = fact(k-1) * sum(_scale(k, j) * e[k-j] * u[j] for j in range(1, k+1)) - e[k] = (1 - v[0]) * v[k] - fact(k) * sum(_scale2(k, j) * v[j] * v[k-j] for j in range(1, k+1)) + v[k] = sum(j * e[k-j] * u[j] for j in range(1, k+1)) / k + e[k] = (1 - v[0]) * v[k] - sum(v[j] * v[k-j] for j in range(1, k+1)) primal_out, *series_out = v return primal_out, series_out @@ -538,7 +546,7 @@ def _logistic_taylor(primals_in, series_in): jet_rules[lax.logistic_p] = _logistic_taylor -def _tanh_taylor(primals_in, series_in): +def _tanh_taylor(primals_in, series_in, **_): x, = primals_in series, = series_in u = [2*x] + [2 * series_ for series_ in series] @@ -548,14 +556,14 @@ def _tanh_taylor(primals_in, series_in): return 2 * primal_out - 1, series_out jet_rules[lax.tanh_p] = _tanh_taylor -def _log_taylor(primals_in, series_in): +def _log_taylor(primals_in, series_in, **_): x, = primals_in series, = series_in u = [x] + series v = [lax.log(x)] + [None] * len(series) for k in range(1, len(v)): - conv = sum(_scale(k, j) * v[j] * u[k-j] for j in range(1, k)) - v[k] = (u[k] - fact(k - 1) * conv) / u[0] + conv = sum(j * v[j] * u[k-j] for j in range(1, k)) + v[k] = (u[k] - conv / k) / u[0] primal_out, *series_out = v return primal_out, series_out jet_rules[lax.log_p] = _log_taylor @@ -564,14 +572,14 @@ def _atan2_taylor(primals_in, series_in): x, y = primals_in primal_out = lax.atan2(x, y) - x, series = jet(lax.div, primals_in, series_in) + x, series = jet2(lax.div, primals_in, series_in) one = lax_internal._const(x, 1) - c0, cs = jet(lambda x: lax.div(one, 1 + lax.square(x)), (x, ), (series, )) + c0, cs = jet2(lambda x: lax.div(one, 1 + lax.square(x)), (x, ), (series, )) c = [c0] + cs u = [x] + series v = [primal_out] + [None] * len(series) for k in range(1, len(v)): - v[k] = fact(k-1) * sum(_scale(k, j) * c[k-j] * u[j] for j in range(1, k + 1)) + v[k] = sum(j * c[k-j] * u[j] for j in range(1, k + 1)) / k primal_out, *series_out = v return primal_out, series_out jet_rules[lax.atan2_p] = _atan2_taylor @@ -582,15 +590,15 @@ def _div_taylor_rule(primals_in, series_in): u = [x] + x_terms w = [y] + y_terms v = [None] * len(u) - def scale(k, j): return 1. / (fact(k - j) * fact(j)) + for k in range(0, len(v)): - conv = sum(scale(k, j) * v[j] * w[k-j] for j in range(0, k)) - v[k] = (u[k] - fact(k) * conv) / w[0] + conv = sum(v[j] * w[k-j] for j in range(0, k)) + v[k] = (u[k] - conv) / w[0] primal_out, *series_out = v return primal_out, series_out jet_rules[lax.div_p] = _div_taylor_rule -def _sinusoidal_rule(sign, prims, primals_in, series_in): +def _sinusoidal_rule(sign, prims, primals_in, series_in, **_): x, = primals_in series, = series_in u = [x] + series @@ -598,12 +606,12 @@ def _sinusoidal_rule(sign, prims, primals_in, series_in): s = [s(x)] + [None] * len(series) c = [c(x)] + [None] * len(series) for k in range(1, len(s)): - s[k] = fact(k-1) * sum(_scale(k, j) * u[j] * c[k-j] for j in range(1, k + 1)) - c[k] = fact(k-1) * sum(_scale(k, j) * u[j] * s[k-j] for j in range(1, k + 1)) * sign + s[k] = sum(j * u[j] * c[k-j] for j in range(1, k + 1)) / k + c[k] = sum(j * u[j] * s[k-j] for j in range(1, k + 1)) / k * sign return (s[0], s[1:]), (c[0], c[1:]) def _get_ind(f, ind): - return lambda *args: f(*args)[ind] + return lambda *args, **kwargs: f(*args, **kwargs)[ind] jet_rules[lax.sin_p] = _get_ind(partial(_sinusoidal_rule, -1, (lax.sin, lax.cos)), 0) jet_rules[lax.cos_p] = _get_ind(partial(_sinusoidal_rule, -1, (lax.sin, lax.cos)), 1) @@ -617,9 +625,8 @@ def _bilinear_taylor_rule(prim, primals_in, series_in, **params): w = [y] + y_terms v = [None] * len(u) op = partial(prim.bind, **params) - def scale(k, j): return 1. / (fact(k - j) * fact(j)) for k in range(0, len(v)): - v[k] = fact(k) * sum(scale(k, j) * op(u[j], w[k-j]) for j in range(0, k+1)) + v[k] = sum(op(u[j], w[k-j]) for j in range(0, k+1)) primal_out, *series_out = v return primal_out, series_out jet_rules[lax.dot_general_p] = partial(_bilinear_taylor_rule, lax.dot_general_p) @@ -727,9 +734,10 @@ def _scatter_add_rule(primals_in, series_in, *, update_jaxpr, update_consts, def _jet_jaxpr( jaxpr: core.ClosedJaxpr, order: int, primals_and_series_avals, in_tree_def ) -> tuple[core.ClosedJaxpr, Any]: - f = lu.wrap_init(core.jaxpr_as_fun(jaxpr), debug_info=jaxpr.jaxpr.debug_info) + f = lu.wrap_init(core.jaxpr_as_fun(jaxpr), + debug_info=jaxpr.jaxpr.debug_info.with_unknown_names()) f_jet, out_tree_def = traceable(jet_fun(jet_subtrace(f), order), in_tree_def) - jaxpr_jet, _, consts, () = pe.trace_to_jaxpr_dynamic( + jaxpr_jet, _, consts = pe.trace_to_jaxpr_dynamic( f_jet, primals_and_series_avals) return core.ClosedJaxpr(jaxpr_jet, consts), out_tree_def @@ -756,7 +764,7 @@ def _pjit_jet_rule(primals_in, series_in, **params): 'out_layouts': params['out_layouts'] + (None,) * num_series_out, 'donated_invars': params['donated_invars'] + (False,) * num_series_in, } - result = pjit.pjit_p.bind(*primals_and_series, **new_params) + result = pjit.jit_p.bind(*primals_and_series, **new_params) return tree_unflatten(out_tree_def(), result) -jet_rules[pjit.pjit_p] = _pjit_jet_rule +jet_rules[pjit.jit_p] = _pjit_jet_rule diff --git a/jax/experimental/key_reuse/_core.py b/jax/experimental/key_reuse/_core.py index 7275046f556d..cc18ff13d183 100644 --- a/jax/experimental/key_reuse/_core.py +++ b/jax/experimental/key_reuse/_core.py @@ -35,10 +35,12 @@ from jax._src import util from jax._src.ad_checkpoint import remat_p from jax._src.debugging import debug_callback_p +from jax._src.effects import Effect +from jax._src.hashable_array import HashableArray from jax._src.interpreters import partial_eval as pe from jax._src.util import weakref_lru_cache -from jax.experimental.shard_map import shard_map_p +from jax._src.shard_map import shard_map_p import numpy as np @@ -212,13 +214,13 @@ def key_reuse_signature_from_eqn(eqn: core.JaxprEqn) -> KeyReuseSignature: return sig.signature(eqn) else: raise TypeError( - f"Unrecognized key reuse sigature of type {type(sig)}: {sig}") + f"Unrecognized key reuse signature of type {type(sig)}: {sig}") else: return unknown_signature(eqn) def key_reuse_signature_from_primitive(prim, *args, **params): - if prim == pjit.pjit_p: + if prim == pjit.jit_p: return jaxpr_type_signature(params['jaxpr'].jaxpr) if prim not in key_reuse_signatures: # TODO(jakevdp) should we generate an unknown signature here? @@ -231,12 +233,13 @@ def key_reuse_signature_from_primitive(prim, *args, **params): return jaxpr_type_signature(jaxpr) else: raise TypeError( - f"Unrecognized key reuse sigature of type {type(sig)}: {sig}") + f"Unrecognized key reuse signature of type {type(sig)}: {sig}") +consume_effect = Effect() consume_p = core.Primitive("consume") consume_p.def_impl(lambda x: x) -consume_p.def_abstract_eval(lambda x: x) +consume_p.def_effectful_abstract_eval(lambda x: (x, {consume_effect})) batching.defvectorized(consume_p) mlir.register_lowering( consume_p, @@ -246,10 +249,11 @@ def consume(key): """Consume the key and return a consumed copy.""" return consume_p.bind(key) +assert_effect = Effect() assert_consumed_value_p = core.Primitive("assert_consumed_value") assert_consumed_value_p.def_impl(lambda x, *, value: x) -assert_consumed_value_p.def_abstract_eval(lambda x, *, value: x) +assert_consumed_value_p.def_effectful_abstract_eval(lambda x, *, value: (x, {assert_effect})) batching.defvectorized(assert_consumed_value_p) mlir.register_lowering( assert_consumed_value_p, @@ -257,16 +261,16 @@ def consume(key): def assert_unconsumed(key): """Assert that a key is unconsumed""" - assert_consumed_value_p.bind(key, value=False) + assert_consumed_value_p.bind(key, value=HashableArray(False)) def assert_consumed(key, value=True): """Assert that a key is consumed""" - assert_consumed_value_p.bind(key, value=value) + assert_consumed_value_p.bind(key, value=HashableArray(value)) def _check_consumed_value(eqn, consumed): """Extra check for use with assert_consumed_value_p""" - expected = eqn.params['value'] + expected = eqn.params['value'].val if not np.all(consumed == expected): if np.all(expected): raise AssertionError(f"Expected key to be consumed in {eqn}") @@ -401,7 +405,7 @@ def function_type_signature(fun: Callable[..., Any], *args: Any) -> KeyReuseSign lu.wrap_init(fun, debug_info=api_util.debug_info("key_reuse", fun, args, {})), in_tree) - jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat) + jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat) return jaxpr_type_signature(jaxpr) @@ -415,7 +419,7 @@ def check_key_reuse(fun: Callable[..., Any], /, *args: Any) -> None: function_type_signature(fun, *args) -#---------------------------------------------------------------------------------- +# ---------------------------------------------------------------------------------- # key reuse rules for particular primitives: @dynamic_key_reuse_signature @@ -450,7 +454,7 @@ def _concatenate_signature(eqn): def _pjit_key_type_signature(eqn): return jaxpr_type_signature(eqn.params['jaxpr'].jaxpr) -key_reuse_signatures[pjit.pjit_p] = _pjit_key_type_signature +key_reuse_signatures[pjit.jit_p] = _pjit_key_type_signature @dynamic_key_reuse_signature def _shard_map_type_signature(eqn): @@ -576,8 +580,8 @@ def call_impl_with_key_reuse_checks(prim: core.Primitive, raw_impl: Callable[... # TODO(jakevdp): should we use an unknown signature here? return raw_impl(*args, **kwargs) signature = key_reuse_signature_from_primitive(prim, *args, **kwargs) - funcname = "jit-compiled function" if prim == pjit.pjit_p else str(prim) - consts = kwargs['jaxpr'].consts if prim == pjit.pjit_p else [] + funcname = "jit-compiled function" if prim == pjit.jit_p else str(prim) + consts = kwargs['jaxpr'].consts if prim == pjit.jit_p else [] signature.check_signature(*args, *consts, funcname=funcname) result = raw_impl(*args, **kwargs) signature.update_consumption([*args, *consts], result if prim.multiple_results else [result]) diff --git a/jax/experimental/layout.py b/jax/experimental/layout.py index ed9f8931938e..8392926ff1b7 100644 --- a/jax/experimental/layout.py +++ b/jax/experimental/layout.py @@ -13,6 +13,9 @@ # limitations under the License. from jax._src.layout import ( - DeviceLocalLayout as DeviceLocalLayout, - Layout as Layout + Layout as Layout, + Format as Format, +) +from jax._src.pjit import ( + with_layout_constraint as with_layout_constraint, ) diff --git a/jax/experimental/mesh_utils.py b/jax/experimental/mesh_utils.py index 075e4e6eed48..58d20c331d5f 100644 --- a/jax/experimental/mesh_utils.py +++ b/jax/experimental/mesh_utils.py @@ -1,4 +1,4 @@ -# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# Copyright 2021 The JAX Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index d004c7deb3df..c8ddf41cbaf3 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -23,15 +23,20 @@ Barrier as Barrier, ClusterBarrier as ClusterBarrier, TMABarrier as TMABarrier, - ThreadSemantics as ThreadSemantics, + LoweringSemantics as LoweringSemantics, TMEM as TMEM, Union as Union, as_gpu_kernel as as_gpu_kernel, + as_torch_gpu_kernel as as_torch_gpu_kernel, + supports_cross_device_collectives as supports_cross_device_collectives, ) from .launch_context import ( + AsyncCopyImplementation as AsyncCopyImplementation, + GLOBAL_BROADCAST as GLOBAL_BROADCAST, LaunchContext as LaunchContext, MemRefTransform as MemRefTransform, + TMAReductionOp as TMAReductionOp, Rounding as Rounding, TileTransform as TileTransform, TransposeTransform as TransposeTransform, @@ -45,26 +50,43 @@ infer_layout as infer_layout, ) -from .transform_inference import ( - infer_transforms as infer_transforms, +from .layouts import ( + to_layout_attr as to_layout_attr, ) from .fragmented_array import ( FragmentedArray as FragmentedArray, FragmentedLayout as FragmentedLayout, + TCGEN05_LAYOUT as TCGEN05_LAYOUT, + TCGEN05_TRANSPOSED_LAYOUT as TCGEN05_TRANSPOSED_LAYOUT, + TCGEN05_ROW_LAYOUT as TCGEN05_ROW_LAYOUT, + TCGEN05_COL_LAYOUT as TCGEN05_COL_LAYOUT, + TiledLayout as TiledLayout, WGMMA_LAYOUT as WGMMA_LAYOUT, + WGMMA_LAYOUT_8BIT as WGMMA_LAYOUT_8BIT, WGMMA_ROW_LAYOUT as WGMMA_ROW_LAYOUT, - WGMMARowFragLayout as WGMMARowFragLayout, + WGMMA_COL_LAYOUT as WGMMA_COL_LAYOUT, + WGMMA_TRANSPOSED_LAYOUT as WGMMA_TRANSPOSED_LAYOUT, + WGMMA_LAYOUT_UPCAST_2X as WGMMA_LAYOUT_UPCAST_2X, + WGMMA_LAYOUT_UPCAST_4X as WGMMA_LAYOUT_UPCAST_4X, + TMEM_NATIVE_LAYOUT as TMEM_NATIVE_LAYOUT, + TMA_GATHER_INDICES_LAYOUT as TMA_GATHER_INDICES_LAYOUT, + tmem_native_layout as tmem_native_layout, WGSplatFragLayout as WGSplatFragLayout, WGStridedFragLayout as WGStridedFragLayout, + copy_tiled as copy_tiled, optimization_barrier as optimization_barrier, ) from .utils import ( BarrierRef as BarrierRef, + DialectBarrierRef as DialectBarrierRef, CollectiveBarrierRef as CollectiveBarrierRef, DynamicSlice as DynamicSlice, Partition as Partition, Partition1D as Partition1D, + SemaphoreRef as SemaphoreRef, + ThreadSubset as ThreadSubset, + MultimemReductionOp as MultimemReductionOp, bitwidth as bitwidth, bytewidth as bytewidth, c as c, @@ -72,22 +94,32 @@ debug_print as debug_print, ds as ds, fori as fori, + is_known_divisible as is_known_divisible, memref_fold as memref_fold, memref_slice as memref_slice, memref_reshape as memref_reshape, memref_transpose as memref_transpose, memref_unfold as memref_unfold, memref_unsqueeze as memref_unsqueeze, + nanosleep as nanosleep, + query_cluster_cancel as query_cluster_cancel, single_thread as single_thread, single_thread_predicate as single_thread_predicate, thread_idx as thread_idx, tile_shape as tile_shape, + try_cluster_cancel as try_cluster_cancel, warp_idx as warp_idx, warpgroup_barrier as warpgroup_barrier, warpgroup_idx as warpgroup_idx, when as when, ) +from .mma import ( + MMALayouts as MMALayouts, + mma as mma, +) from .wgmma import ( WGMMAAccumulator as WGMMAAccumulator, wgmma as wgmma, ) + +from . import tcgen05 as tcgen05 diff --git a/jax/experimental/mosaic/gpu/constraints.py b/jax/experimental/mosaic/gpu/constraints.py new file mode 100644 index 000000000000..f62e7151a50f --- /dev/null +++ b/jax/experimental/mosaic/gpu/constraints.py @@ -0,0 +1,894 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Defines expressions and constraints over layouts.""" + +# mypy has been causing more problems than it solves here. Disable it for these +# files. We have pytype checks anyway. +# mypy: ignore-errors + +from __future__ import annotations + +import abc +from collections.abc import Sequence +import dataclasses +import math +from typing import Any, assert_never, final + +from . import fragmented_array as fa +from . import launch_context as lc +from . import layouts as layouts_lib +from . import inference_utils +from . import tcgen05 + + +VariableKey = Any + + +@dataclasses.dataclass(frozen=True) +class Variable: + """A variable is an abstract identifier. + + `key` is supposed to be hashable. + """ + key: VariableKey + + def __str__(self): + return f"V({self.key})" + + +class Constant(abc.ABC): + """A constant is a known layout.""" + + +@dataclasses.dataclass(frozen=True) +class RegisterLayout(Constant): + """Wraps a known register layout.""" + + value: fa.FragmentedLayout + + def __str__(self): + return f"C({self.value})" + + +@dataclasses.dataclass(frozen=True) +class TMEMLayout(Constant): + """Wraps a known TMEM layout.""" + + value: tcgen05.TMEMLayout + + def __str__(self): + return f"C({self.value})" + + +@dataclasses.dataclass(frozen=True) +class SMEMTiling(Constant): + """Wraps a known SMEM Tile Transform. + + If an SMEM reference may, in principle, have transforms but should not be + tiled, then `value` is `None`. + """ + + value: lc.TileTransform | None + + def __str__(self): + return f"C({self.value})" + + +@dataclasses.dataclass(frozen=True) +class Reduce: + expression: Expression + axes: tuple[int, ...] + + def __str__(self): + return f"Reduce([{self.axes}], {self.expression})" + + +@dataclasses.dataclass(frozen=True) +class BroadcastInDim: + expression: Expression + axes: tuple[int, ...] + shape: tuple[int, ...] + + +@dataclasses.dataclass(frozen=True) +class Reshape: + expression: Expression + source_shape: tuple[int, ...] + target_shape: tuple[int, ...] + + +@dataclasses.dataclass(frozen=True) +class Transpose: + expression: Expression + + def __str__(self): + return f"T({self.expression})" + + +Expression = ( + Variable + | Constant + | Reduce + | BroadcastInDim + | Reshape + | Transpose +) + + +def reduce_broadcast_expression( + broadcast: BroadcastInDim, assignments: dict[Variable, Constant] +) -> Expression | Unsatisfiable: + def _check_shape_broadcast(shape: tuple[int, ...]) -> bool: + for axis, s in zip(broadcast.axes, shape, strict=True): + if broadcast.shape[axis] != s: + return False + return True + + reduced_expr = reduce_expression(broadcast.expression, assignments) + match reduced_expr: + case Unsatisfiable(): + return Unsatisfiable() + case RegisterLayout(value=layout): + match layout: + case fa.WGSplatFragLayout(shape=shape): + if not _check_shape_broadcast(shape): + return Unsatisfiable() + return RegisterLayout(fa.WGSplatFragLayout(shape=broadcast.shape)) + case _: + return BroadcastInDim( + expression=reduced_expr, + axes=broadcast.axes, + shape=broadcast.shape, + ) + case _: + return BroadcastInDim( + expression=reduced_expr, axes=broadcast.axes, shape=broadcast.shape + ) + + +def reduce_reshape_expression( + reshape: Reshape, assignments: dict[Variable, Constant] +) -> Expression | Unsatisfiable: + reduced_expr = reduce_expression(reshape.expression, assignments) + match reduced_expr: + case Unsatisfiable(): + return Unsatisfiable() + case RegisterLayout(value=layout): + match layout: + case fa.WGSplatFragLayout(shape=shape): + assert math.prod(shape) == math.prod(reshape.target_shape) + return RegisterLayout( + fa.WGSplatFragLayout(shape=reshape.target_shape) + ) + case fa.WGStridedFragLayout(shape=shape, vec_size=vec_size): + assert math.prod(shape) == math.prod(reshape.target_shape) + return RegisterLayout( + fa.WGStridedFragLayout( + shape=reshape.target_shape, vec_size=vec_size + ) + ) + case fa.TiledLayout() as tiled_layout: + tile_shape = tiled_layout.base_tile_shape + if len(reshape.target_shape) < len(tile_shape): + return dataclasses.replace(reshape, expression=reduced_expr) + # Even if the new shape is not perfectly tilable, it is possible that + # we may be able to reshape the tiling itself in a way that is + # compatible with the new shape. We do not handle this case at the + # moment. + for ts, s in zip(tile_shape, reshape.source_shape[-len(tile_shape):], strict=True): + if s % ts != 0: + return dataclasses.replace(reshape, expression=reduced_expr) + + # If minor tiled dimensions are modified, then reshaping is likely to + # not be a no-op since the strides between tiles will change, + # potentially mapping different elements to lanes and warps. We don't + # attempt to handle this case at the moment. + num_minor_tiled_dims = len(tile_shape) - 1 + source_minor_tiled_dims = reshape.source_shape[-num_minor_tiled_dims:] + target_minor_tiled_dims = reshape.target_shape[-num_minor_tiled_dims:] + major_tiled_dim = tile_shape[0] + if (source_minor_tiled_dims != target_minor_tiled_dims or + reshape.target_shape[-len(tile_shape)] % major_tiled_dim != 0): + return dataclasses.replace(reshape, expression=reduced_expr) + # At this point, we now that only non-tiled dimensions and/or the + # majormost tiled dimensions may have changed. We also know that the + # majormost tiled dimension is still tilable in the new shape. + # Therefore, we can return the tiled layout as is. + return RegisterLayout(tiled_layout) + case _: + return dataclasses.replace(reshape, expression=reduced_expr) # pytype: disable=bad-return-type + + +def reduce_transpose_expression( + transpose: Transpose, assignments: dict[Variable, Constant] +) -> Expression | Unsatisfiable: + reduced_expr = reduce_expression(transpose.expression, assignments) + match reduced_expr: + case Unsatisfiable(): + return Unsatisfiable() + case SMEMTiling(value=tile_transform): + if tile_transform is None: + return SMEMTiling(None) + tiling = tile_transform.tiling + if len(tiling) != 2: + raise NotImplementedError( + f"Only 2D tilings are supported, got {len(tiling)}" + ) + return SMEMTiling(lc.TileTransform(tiling[::-1])) + case _: + return Transpose(expression=reduced_expr) + + +def reduce_expression( + expr: Expression, assignments: dict[Variable, Constant] +) -> Expression | Unsatisfiable: + """Reduces an expression as much as is possible given a set of known variable assignments.""" + match expr: + case Constant(): + return expr + case Variable(): + return assignments.get(expr, expr) + case Reduce(expression=expr, axes=axes): + reduced_expr = reduce_expression(expr, assignments) + match reduced_expr: + case Unsatisfiable(): + return Unsatisfiable() + case RegisterLayout(value=layout) if isinstance(layout, fa.TiledLayout): + return RegisterLayout(layout.reduce(axes)) + case _: + return Reduce(expression=reduced_expr, axes=axes) + case BroadcastInDim(): + return reduce_broadcast_expression(expr, assignments) + case Reshape(): + return reduce_reshape_expression(expr, assignments) + case Transpose(): + return reduce_transpose_expression(expr, assignments) + case _: + assert_never(expr) + + +@dataclasses.dataclass(frozen=True) +class Equals: + """States that `lhs` and `rhs` are equal.""" + lhs: Expression + rhs: Expression + + def holds(self) -> bool | None: + if self.lhs == self.rhs: + return True + if isinstance(self.lhs, Constant) and isinstance(self.rhs, Constant): + return False + return None + + def __str__(self): + return f"Equals({self.lhs} == {self.rhs})" + + +_always_supported = lambda *args: True + + +# Maps a tuple of layouts (source, target) to a function that takes in a +# bitwidth and returns whether the source->target relayout is supported for +# values of types with the given bitwidth. +_SUPPORTED_TILED_RELAYOUTS = { + # Transposed layouts. + (fa.WGMMA_LAYOUT, fa.WGMMA_TRANSPOSED_LAYOUT): _always_supported, + (fa.WGMMA_TRANSPOSED_LAYOUT, fa.WGMMA_LAYOUT): _always_supported, + (fa.TCGEN05_LAYOUT, fa.TCGEN05_TRANSPOSED_LAYOUT): _always_supported, + (fa.TCGEN05_TRANSPOSED_LAYOUT, fa.TCGEN05_LAYOUT): _always_supported, + # "Conversion-optimized" layouts. + (fa.WGMMA_LAYOUT_UPCAST_2X, fa.WGMMA_LAYOUT): + lambda bitwidth: fa.can_relayout_wgmma_2x_to_wgmma(bitwidth), + (fa.WGMMA_LAYOUT_UPCAST_4X, fa.WGMMA_LAYOUT_UPCAST_2X): + lambda bitwidth: fa.can_relayout_wgmma_4x_to_wgmma_2x(bitwidth), + (fa.WGMMA_LAYOUT_UPCAST_4X, fa.WGMMA_LAYOUT): + lambda bitwidth: fa.can_relayout_wgmma_4x_to_wgmma_2x(bitwidth) and fa.can_relayout_wgmma_2x_to_wgmma(bitwidth), +} + + +@dataclasses.dataclass(frozen=True) +class Relayout: + """States that `source` must be relayout-able to `target`. + + Relayout-ability here is not defined as a fundamental property of layouts, but + rather a reflection of our implementation. For instance, when evaluating this + constraint, we will return `False` systematically if a relayout exists but we + do not ever plan to support it. + + Modeling this constraint this way is helpful, in order to allow pruning + inefficient solutions when attempting to solve a constraint system. + + We include here the bitwidth of the element type we want to associate with + this constraint, as certain relayouts are only supported for specific + bitwidths. + """ + + source: Expression + target: Expression + bitwidth: int + + def holds(self) -> bool | None: + """Returns whether the relayout constraint holds. + + Returns `None` if the constraint can't be checked. + """ + source = self.source + target = self.target + + # Fast path for syntactically identical expressions. + if source == target: + return True + + if not isinstance(source, RegisterLayout) or not isinstance( + target, RegisterLayout + ): + return None + + source_layout, target_layout = source.value, target.value + match source_layout, target_layout: + case fa.WGSplatFragLayout() as splat, fa.WGStridedFragLayout() as strided: + return splat.shape == strided.shape + case fa.WGSplatFragLayout(), fa.TiledLayout(): + return layouts_lib.splat_is_compatible_with_tiled( + source_layout, target_layout + ) + case fa.TiledLayout(), fa.TiledLayout(): + is_supported = _SUPPORTED_TILED_RELAYOUTS.get( + (source_layout, target_layout), lambda *_: False + ) + return is_supported(self.bitwidth) + case _: + return False + + def __str__(self): + return f"Relayout({self.source} ⟶ {self.target})" + + +@dataclasses.dataclass(frozen=True) +class IsTransferable: + """States that `source` layout must be transferable across memory spaces to `target` layout.""" + + source: Expression + target: Expression + # TODO(allanrenucci): Can this be derived from the layouts? + shape: tuple[int, ...] + + def supported_tmem_transfers( + self, packing: int + ) -> list[tuple[tcgen05.TMEMLayout, fa.FragmentedLayout]]: + """Returns the list of supported TMEM <-> Register transfers.""" + assert len(self.shape) == 2 + columns = self.shape[1] + tmem_default_layout = tcgen05.tmem_default_layout(packing) + return [ + (tmem_default_layout, fa.TCGEN05_LAYOUT), + (tmem_default_layout, fa.TMEM_NATIVE_LAYOUT), + (tcgen05.tmem_half_lane_layout(columns, packing), fa.WGMMA_LAYOUT), + ( + tcgen05.tmem_m64_collective_layout(columns, packing), + tcgen05.fa_m64_collective_layout(columns), + ), + ] + + def _is_valid_tmem_transfer( + self, tmem_layout: tcgen05.TMEMLayout, reg_layout: fa.FragmentedLayout + ) -> bool: + packing = tmem_layout.vector_length + return (tmem_layout, reg_layout) in self.supported_tmem_transfers(packing) + + def _is_valid_smem_transfer( + self, + smem_layout: lc.TileTransform | None, + reg_layout: fa.FragmentedLayout, + ) -> bool: + # TODO(b/447079781): This is way too restrictive. We need to make it more + # precise by: + # - Consider whether the op is annotated with optimized copies or not. + # - If copies do not have to be optimized, always return True. + # - If copies have to be optimized, determine if the transfer is optimal by + # calling fragmented_array.plan_tiled_transfer. + if inference_utils.is_mma_layout(reg_layout): + return smem_layout is not None and len(smem_layout.tiling) == 2 + return smem_layout is None + + def holds(self) -> bool | None: + """Returns whether the constraint holds. + + Returns `None` if the constraint can't be checked. + """ + + assert self.source != self.target, ( + "IsTransferable constraints within the same memory space are not" + " supported." + ) + + match self.source, self.target: + case TMEMLayout(value=src), RegisterLayout(value=dst): + return self._is_valid_tmem_transfer(src, dst) + case RegisterLayout(value=src), TMEMLayout(value=dst): + return self._is_valid_tmem_transfer(dst, src) + case SMEMTiling(value=src), RegisterLayout(value=dst): + return self._is_valid_smem_transfer(src, dst) + case RegisterLayout(value=src), SMEMTiling(value=dst): + return self._is_valid_smem_transfer(dst, src) + case Constant(), Constant(): + source_type = type(self.source).__name__ + target_type = type(self.target).__name__ + raise NotImplementedError( + f"Unsupported transfer: {source_type} -> {target_type}" + ) + case _: + return None + + def __str__(self): + return f"IsTransferable({self.source} ⟶ {self.target})" + + +@dataclasses.dataclass(frozen=True) +class NotOfType: + """States that `expr` is not an instance of `type`.""" + + expr: Expression + type: type[fa.FragmentedLayout] + + def holds(self) -> bool | None: + """Whether the distinctiveness constraint holds. + + Returns `None` if the constraint can't be checked. + """ + if not isinstance(self.expr, Constant): + return None + if not isinstance(self.expr, RegisterLayout): + return True + return not isinstance(self.expr.value, self.type) + + def __str__(self): + return f"type({self.expr}) ≠ {self.type.__name__}" + + +@dataclasses.dataclass(frozen=True) +class Divides: + """States that the `expr` tiling is a divisor of `tiling_multiple`. + + That is to say that, for each tiled dimension in `expr`, the dimension must + divide its corresponding dimension in `tiling_multiple` starting from the + tail. + + If `tiling_multiple` contains more dimensions than `expr`, then + the extra dimensions in `tiling_multiple` are ignored for the purposes of the + check. + + `expr` is not allowed to contain more dimensions than `tiling_multiple`, and + this constraint therefore also constrains the rank of `expr`. + """ + expr: Expression + tiling_multiple: tuple[int, ...] + + def holds(self) -> bool | None: + match self.expr: + case SMEMTiling(value=None): + # If there is no tiling, then this holds trivially. + return True + case SMEMTiling(value=lc.TileTransform(tiling=t)): + tiling = t + case RegisterLayout(value=fa.TiledLayout() as layout): + tiling = layout.base_tile_shape + case TMEMLayout(value): + tiling = value.base_tile_shape + case _: + return None + + if len(tiling) > len(self.tiling_multiple): + # The rank of the tiling is larger than the rank of the constraint. This + # is not allowed. + return False + + for size, multiple in zip(reversed(tiling), reversed(self.tiling_multiple)): + if multiple % size: + return False + return True + + def __str__(self): + return f"{self.tiling_multiple} % {self.expr} == 0" + + +Constraint = Equals | Relayout | NotOfType | IsTransferable | Divides + + +def reduce_constraint( + constraint: Constraint, assignments: dict[Variable, Constant] +) -> Constraint | Unsatisfiable: + """Reduces a constraint.""" + + match constraint: + case Equals(lhs=lhs, rhs=rhs): + lhs_red = reduce_expression(lhs, assignments) + if isinstance(lhs_red, Unsatisfiable): + return Unsatisfiable() + rhs_red = reduce_expression(rhs, assignments) + if isinstance(rhs_red, Unsatisfiable): + return Unsatisfiable() + return Equals(lhs_red, rhs_red) + case Relayout(source=source, target=target, bitwidth=bitwidth): + source_red = reduce_expression(source, assignments) + target_red = reduce_expression(target, assignments) + if isinstance(source_red, Unsatisfiable) or isinstance( + target_red, Unsatisfiable + ): + return Unsatisfiable() + return Relayout(source_red, target_red, bitwidth) + case NotOfType(expr=expr, type=type): + expr_red = reduce_expression(expr, assignments) + if isinstance(expr_red, Unsatisfiable): + return Unsatisfiable() + return NotOfType(expr_red, type) + case IsTransferable(source=source, target=target, shape=shape): + source_red = reduce_expression(source, assignments) + target_red = reduce_expression(target, assignments) + if isinstance(source_red, Unsatisfiable) or isinstance(target_red, Unsatisfiable): + return Unsatisfiable() + return IsTransferable(source_red, target_red, shape) + case Divides(expr=expr, tiling_multiple=tiling_multiple): + expr_red = reduce_expression(expr, assignments) + if isinstance(expr_red, Unsatisfiable): + return Unsatisfiable() + return Divides(expr_red, tiling_multiple) + case _ as never: + assert_never(never) + + +@dataclasses.dataclass +class ConstraintSystem: + """A constraint system contains a set of constraints and assignments. + + Assignments assign constant values to variables in the system (bound + variables). Constraints describe relationships between variables that must be + upheld, and can be used to determine assignments for unknown (free) variables. + """ + + assignments: dict[Variable, Constant] = dataclasses.field( + default_factory=dict + ) + constraints: Sequence[Constraint] = dataclasses.field(default_factory=list) + + def unknowns(self) -> list[Variable]: + """Returns the list of free variables in the system.""" + seen_variables: set[Variable] = set() + free_variables: list[Variable] = [] + def extract_variables(expr: Expression) -> None: + match expr: + case Variable(): + if expr not in seen_variables and expr not in self.assignments: + seen_variables.add(expr) + free_variables.append(expr) + case Constant(): + ... + case Reduce(expression=e): + extract_variables(e) + case BroadcastInDim(expression=e): + extract_variables(e) + case Reshape(expression=e): + extract_variables(e) + case Transpose(expression=e): + extract_variables(e) + case _: + assert_never(expr) + for constraint in self.constraints: + match constraint: + case Equals(lhs=lhs, rhs=rhs): + extract_variables(lhs) + extract_variables(rhs) + case Relayout(source=source, target=target): + extract_variables(source) + extract_variables(target) + case NotOfType(expr=expr): + extract_variables(expr) + case IsTransferable(source=source, target=target, shape=_): + extract_variables(source) + extract_variables(target) + case Divides(expr=expr): + extract_variables(expr) + case _ as never: + assert_never(never) + return free_variables + + def __and__( + self, other: ConstraintSystem | Unsatisfiable + ) -> ConstraintSystem | Unsatisfiable: + if isinstance(other, Unsatisfiable): + return Unsatisfiable() + for variable, assignment in self.assignments.items(): + if variable in other.assignments and assignment != other.assignments[variable]: + return Unsatisfiable() + return ConstraintSystem( + assignments=self.assignments | other.assignments, + constraints=[*self.constraints, *other.constraints], + ) + + def __str__(self): + r = "ConstraintSystem\n" + r += " assignments:\n" + for assignment, constant in self.assignments.items(): + r += f" {assignment} ⟵ {constant}\n" + r += " constraints:\n" + for constraint in self.constraints: + r += f" {constraint}\n" + return r + + +@final +class Unsatisfiable: + + def __and__(self, other: ConstraintSystem | Unsatisfiable) -> Unsatisfiable: + return self + + +def non_splat_variables( + constraints: Sequence[Constraint], +) -> set[Variable]: + """Returns a all vars distinct from a splat.""" + vars: set[Variable] = set() + for constraint in constraints: + match constraint: + case NotOfType(expr=Variable() as var, type=fa.WGSplatFragLayout): + assert isinstance(var, Variable) # make pytype happy + vars.add(var) + return vars + + +def _has_relayout_of_non_splat_to_splat(constraints: Sequence[Constraint]) -> bool: + """Returns whether the constraints imply a non-splat to splat relayout. + + Such relayouts are impossible and this helps shortcut the search. + + If this function returns False, this doesn't necessarily mean that there are + no non-splat to splat relayouts, just that this is not known yet. + """ + non_splat = non_splat_variables(constraints) + if not non_splat: + return False + + def is_constant_splat(e) -> bool: + return isinstance(e, RegisterLayout) and isinstance( + e.value, fa.WGSplatFragLayout + ) + + for constraint in constraints: + match constraint: + case Relayout(source=source, target=target): + if source in non_splat and is_constant_splat(target): + return True + case _: + pass + return False + + +def saturate_distinct_from_splat( + constraint_system: ConstraintSystem, +) -> ConstraintSystem | Unsatisfiable: + """Adds transitive NotOfType constraints for all non-splat variables. + + Given `n` variables `l0`, ... `l{n-1}`, and a set of relayouts + `{ Relayout(l{i}, l{i+1}) : 0 <= i < n }`, if we also know that + `l{0}` is not splat, then we can automatically deduce that none of + `l0`, ..., `l{n-1}` are splat either. + + This helps us quickly conclude that a system is unsatisfiable in cases where + a non-splat variable is transitively relaid out into a splat layout. + """ + non_splat = non_splat_variables(constraint_system.constraints) + new_constraints: list[Constraint] = [] + new_non_splat_found = len(non_splat) > 0 + + while new_non_splat_found: + new_non_splat_found = False + for constraint in constraint_system.constraints: + match constraint: + case Relayout(source=source, target=target): + if ( + isinstance(target, Variable) + and source in non_splat + and target not in non_splat + ): + new_non_splat_found = True + non_splat.add(target) + new_constraints.append(NotOfType(target, fa.WGSplatFragLayout)) + case _: + pass + return constraint_system & ConstraintSystem(constraints=new_constraints) + + +def compute_transitively_equal_vars( + system: ConstraintSystem, +) -> dict[Variable, list[Variable]]: + """Computes all transitively equal variables in a constraint system. + + The output dictionary maps each variable that appears in constraints in the + constraint system to all the variables it is transitively equal to. + """ + # The equality relations between variables form a graph where variables are + # nodes and a constraint `v1 == v2` forms an edge. All variables in a + # connected component are transitively equal. We use a Union-Find data + # structure with path compression to efficiently find these connected + # components (i.e., equivalence classes). + parent: dict[Variable, Variable] = {} + def find(v: Variable) -> Variable: + if v not in parent: + parent[v] = v + if parent[v] != v: + parent[v] = find(parent[v]) + return parent[v] + + def union(v1: Variable, v2: Variable): + root1 = find(v1) + root2 = find(v2) + if root1 != root2: + parent[root2] = root1 + + all_vars: set[Variable] = set() + for constraint in system.constraints: + match constraint: + case Equals(lhs=Variable() as lhs, rhs=Variable() as rhs): + assert isinstance(lhs, Variable) # make pytype happy + assert isinstance(rhs, Variable) # make pytype happy + all_vars.add(lhs) + all_vars.add(rhs) + union(lhs, rhs) + + # Group variables by their component representative. + components: dict[Variable, list[Variable]] = {} + for v in sorted(all_vars, key=str): + root = find(v) + components.setdefault(root, []).append(v) + + equal_vars: dict[Variable, list[Variable]] = {} + for component_vars in components.values(): + for v in component_vars: + equal_vars[v] = [other for other in component_vars if other != v] + + return equal_vars + + +def saturate_divides_constraints_for_equal_vars( + system: ConstraintSystem, +) -> ConstraintSystem: + """Saturates Divides constraints between all transitively equal vars. + """ + equal_vars = compute_transitively_equal_vars(system) + new_constraints: list[Constraint] = [] + for constraint in system.constraints: + new_constraints.append(constraint) + match constraint: + case Divides(expr=expr, tiling_multiple=tiling_multiple): + if isinstance(expr, Variable): + for equal_var in equal_vars.get(expr, []): + new_constraints.append(Divides(equal_var, tiling_multiple)) + case _: + pass + new_constraints = merge_divides_constraints(new_constraints) + return dataclasses.replace(system, constraints=new_constraints) + + +# TODO(bchetioui): clean up API. +def merge_divides_constraints(constraints: Sequence[Constraint]) -> list[Constraint]: + """Merges Divides constraints that can be merged.""" + result: list[Constraint] = [] + var_to_tiling_multiples : dict[Variable, tuple[int, ...]] = {} + for constraint in constraints: + match constraint: + case Divides(expr=Variable() as v, tiling_multiple=tiling_multiple): + assert isinstance(v, Variable) # make pytype happy + if (previous_tiling_multiple := var_to_tiling_multiples.get(v)) is None: + var_to_tiling_multiples[v] = tiling_multiple + continue + # If the two tuples are of different lengths, the larger tuple will + # be truncated (removing initial multiples) to the length of the + # smaller tuple. This preserves the semantics of the Divides constraints + # where a tiling's rank cannot exceed the size of tiling_multiple. + min_len = min(len(tiling_multiple), len(previous_tiling_multiple)) + new_tiling_multiple = [] + if min_len > 0: + for x, y in zip(tiling_multiple[-min_len:], previous_tiling_multiple[-min_len:], strict=True): + new_tiling_multiple.append(math.gcd(x, y)) + var_to_tiling_multiples[v] = tuple(new_tiling_multiple) + case _: + result.append(constraint) + for expr, tiling_multiple in var_to_tiling_multiples.items(): + result.append(Divides(expr, tiling_multiple)) + return result + + +def _reduce_system_once( + constraint_system: ConstraintSystem, +) -> ConstraintSystem | Unsatisfiable | None: + """Performs one reduction step over each constraint in a constraint system. + + Returns: + - Unsatisfiable(): if the constraint system is unsatisfiable. + - A new constraint system if any constraint was reduced. + - None: if the constraint system is not known unsatisfiable, but hasn't been + reduced. + """ + assignments = constraint_system.assignments + constraints: list[Constraint] = [] + changed = False + + def try_assign(var: Variable, cst: Constant) -> bool: + if var in assignments and assignments[var] != cst: + return False + assignments[var] = cst + return True + + for constraint in constraint_system.constraints: + match reduce_constraint(constraint, assignments): + case Unsatisfiable(): + return Unsatisfiable() + case Equals(lhs=Variable() as var, rhs=Constant() as cst): + if not try_assign(var, cst): + return Unsatisfiable() + changed = True + case Equals(lhs=Constant() as cst, rhs=Variable() as var): + if not try_assign(var, cst): + return Unsatisfiable() + changed = True + case _ as new_constraint: + assert isinstance(new_constraint, Constraint) # make pytype happy + match new_constraint.holds(): + case None: + constraints.append(new_constraint) + changed |= new_constraint != constraint + case False: + return Unsatisfiable() + case True: + changed = True + + new_constraints = merge_divides_constraints(constraints) + changed |= len(new_constraints) != len(constraints) + constraints = new_constraints + + # Shortcut for a specific case of unsatisfiability. This shortcut + # drastically reduces the size of the search space. + if _has_relayout_of_non_splat_to_splat(constraints): + return Unsatisfiable() + + if changed: + return ConstraintSystem( + assignments=assignments | constraint_system.assignments, + constraints=constraints, + ) + return None + + +def reduce( + constraint_system: ConstraintSystem, +) -> ConstraintSystem | Unsatisfiable: + """Reduces a constraint system until it can no longer be reduced. + + Returns: + - Unsatisfiable(): if the constraint system is unsatisfiable. + - The maximally reduced constraint system otherwise. + """ + while True: + match _reduce_system_once(constraint_system): + case None: + break + case Unsatisfiable(): + return Unsatisfiable() + case ConstraintSystem() as new_system: + constraint_system = new_system + case _ as never: + assert_never(never) + + return constraint_system diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index b255893e2e2e..1c907a90d045 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -13,25 +13,35 @@ # limitations under the License. # ============================================================================== -from collections.abc import Sequence +from collections.abc import Callable, Iterable, Sequence import contextlib import ctypes import dataclasses import enum import functools import hashlib +import io +import itertools import math import os import pathlib import time -from typing import Any, Callable, Generic, TypeVar +from typing import Any, Generic, TypeVar import weakref import jax +from jax._src import core as jax_core +from jax._src import dtypes +from jax._src import lib +from jax._src import mesh as mesh_lib +from jax._src import sharding_impls +from jax._src import util as jax_util from jax._src.interpreters import mlir from jax._src.lib import mosaic_gpu_dialect as dialect +from jax.extend import backend as jex_backend from jaxlib.mlir import ir from jaxlib.mlir import passmanager +from jaxlib.mlir.dialects import _gpu_ops_gen from jaxlib.mlir.dialects import arith from jaxlib.mlir.dialects import builtin from jaxlib.mlir.dialects import func @@ -41,29 +51,21 @@ from jaxlib.mlir.dialects import nvvm import numpy as np -# mypy: ignore-errors - from . import dialect_lowering from . import launch_context from . import layout_inference +from . import layouts from . import profiler from . import tcgen05 -from . import transform_inference from . import utils # MLIR can't find libdevice unless we point it to the CUDA path -# TODO(apaszke): Unify with jax._src.lib.cuda_path -CUDA_ROOT = "/usr/local/cuda" -if os.environ.get("CUDA_ROOT") is None: - os.environ["CUDA_ROOT"] = CUDA_ROOT -else: - CUDA_ROOT = os.environ["CUDA_ROOT"] - -PTXAS_PATH = os.path.join(CUDA_ROOT, "bin/ptxas") -NVDISASM_PATH = os.path.join(CUDA_ROOT, "bin/nvdisasm") +cuda_root = lib.cuda_path or "/usr/local/cuda" +os.environ["CUDA_ROOT"] = cuda_root +PYTHON_RUNFILES = os.environ.get("PYTHON_RUNFILES") # This tracks the latest Mosaic GPU IR version with a monthly delay. -FWD_COMPAT_IR_VERSION = 1 +FWD_COMPAT_IR_VERSION = 2 c = utils.c # This is too common to fully qualify. @@ -84,17 +86,78 @@ os.environ["MOSAIC_GPU_RUNTIME_LIB_PATH"] = str(RUNTIME_PATH) -mosaic_gpu_p = jax._src.core.Primitive("mosaic_gpu_p") +try: + from nvidia import nvshmem # pytype: disable=import-error +except ImportError: + # Try to find the nvshmem library in Bazel runfiles. + if PYTHON_RUNFILES: + libdevice_path = os.path.join( + PYTHON_RUNFILES, "nvidia_nvshmem", "lib", "libnvshmem_device.bc" + ) + if os.path.exists(libdevice_path): + os.environ["MOSAIC_GPU_NVSHMEM_BC_PATH"] = libdevice_path + for root, _, files in os.walk(os.getcwd()): + if "/_solib" in root and "libnvshmem_host.so.3" in files: + os.environ["MOSAIC_GPU_NVSHMEM_SO_PATH"] = os.path.join( + root, "libnvshmem_host.so.3" + ) + break + else: + pass +else: + if os.environ.get("MOSAIC_GPU_NVSHMEM_BC_PATH") is None: + os.environ["MOSAIC_GPU_NVSHMEM_BC_PATH"] = os.path.join( + nvshmem.__path__[0], "lib/libnvshmem_device.bc" + ) + if os.environ.get("MOSAIC_GPU_NVSHMEM_SO_PATH") is None: + os.environ["MOSAIC_GPU_NVSHMEM_SO_PATH"] = os.path.join( + nvshmem.__path__[0], "lib/libnvshmem_host.so.3" + ) + + +def supports_cross_device_collectives(): + try: + nvshmem_bc_path = os.environ["MOSAIC_GPU_NVSHMEM_BC_PATH"] + except KeyError: + return False + if nvshmem_so_path := os.environ.get("MOSAIC_GPU_NVSHMEM_SO_PATH", ""): + try: + # This both ensures that the file exists, and it populates the dlopen + # cache, helping XLA find the library even if the RPATH is not right... + ctypes.CDLL(nvshmem_so_path) + except OSError: + return False + xla_flags = os.environ.get("XLA_FLAGS", "") + return ( + os.path.exists(nvshmem_bc_path) + and "--xla_gpu_experimental_enable_nvshmem" in xla_flags + ) + + +mosaic_gpu_p = jax_core.Primitive("mosaic_gpu_p") mosaic_gpu_p.multiple_results = True @mosaic_gpu_p.def_abstract_eval -def _mosaic_gpu_abstract_eval(*_, module, out_types): - del module # Unused. - return [jax._src.core.ShapedArray(t.shape, t.dtype) for t in out_types] +def _mosaic_gpu_abstract_eval(*_, module, out_types, inout_types): + del module # Unused. + return [ + jax_core.ShapedArray(t.shape, t.dtype) + for t in itertools.chain(out_types, inout_types) + ] + + +def _has_communication(module, **_): + empty_str_attr = ir.StringAttr.get("") + for op in module.body: + if "nvshmem" in getattr(op, "sym_name", empty_str_attr).value: + return True + return False + # TODO(apaszke): Implement a proper system for managing kernel lifetimes -KNOWN_KERNELS = {} +# Maps kernel ID to the compiled kernel ASM. +KNOWN_KERNELS: dict[bytes, bytes] = {} def _mosaic_gpu_lowering_rule( @@ -102,32 +165,78 @@ def _mosaic_gpu_lowering_rule( *args, module, out_types, + inout_types, input_output_aliases: tuple[tuple[int, int], ...] = (), + use_custom_barrier: bool = False, ): - assert len(out_types) == len(ctx.avals_out) + axis_context = ctx.module_context.axis_context + if _has_communication(module): + # Those checks are trying to ensure that the logical device ids are + # consistent with the NVSHMEM PE ids that Mosaic will be using for + # communication. Any divergence here would require us to implement a logical + # to physical translation, which is currently not implemented. + if isinstance(axis_context, sharding_impls.SPMDAxisContext): + mesh = axis_context.mesh + # Skip the check for AbstractMesh + if (isinstance(mesh, mesh_lib.Mesh) and + not np.array_equal(mesh.device_ids.ravel(), np.arange(mesh.size))): + raise NotImplementedError( + "Mosaic GPU only supports meshes with device ordering that follows" + f" row-major device ids. Got: {mesh.device_ids.ravel()} device ids." + ) + elif isinstance(axis_context, sharding_impls.ShardingContext): + if axis_context.num_devices != 1: + raise NotImplementedError( + "Mosaic GPU only supports single-device meshes in ShardingContext." + f" Got: {axis_context.num_devices} devices." + ) + else: + raise NotImplementedError(f"Unsupported sharding context: {axis_context}") + + if inout_types: + if input_output_aliases: + raise ValueError( + "input_output_aliases and inout_types are mutually exclusive" + ) + num_inputs = len(ctx.avals_in) + num_outputs = len(ctx.avals_out) + input_output_aliases = tuple( + (num_inputs - 1 - i, num_outputs - 1 - i) + for i in range(len(inout_types)) + ) + assert len(ctx.avals_in) == len(args) + assert len(ctx.avals_out) == len(out_types) + len(inout_types) module = _run_serde_pass( module, serialize=True, ir_version=FWD_COMPAT_IR_VERSION if ctx.is_forward_compat() else None, ) - module_asm = module.operation.get_asm(binary=True, enable_debug_info=True) + bytecode_buffer = io.BytesIO() + module.operation.write_bytecode(bytecode_buffer, desired_version=0) + module_asm = bytecode_buffer.getvalue() kernel_id = hashlib.sha256(module_asm).digest() # Note that this is technically only a half measure. Someone might load a # compiled module with a hash collision from disk. But that's so unlikely with # SHA256 that it shouldn't be a problem. if (kernel_text := KNOWN_KERNELS.get(kernel_id, None)) is not None: if kernel_text != module_asm: - raise RuntimeError("Hash collision!") + raise RuntimeError("Kernel hash collision!") else: KNOWN_KERNELS[kernel_id] = module_asm + op = mlir.custom_call( - "mosaic_gpu", + "mosaic_gpu_v2", result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], operands=args, operand_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_in], result_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_out], - backend_config=kernel_id + module_asm, + backend_config=dict( + kernel_hash=ir.StringAttr.get(kernel_id), + module=ir.StringAttr.get(module_asm), + use_custom_barrier=ir.BoolAttr.get(use_custom_barrier), + ), operand_output_aliases=dict(input_output_aliases), + api_version=4, ) return op.results @@ -157,64 +266,147 @@ class Barrier: arrival_count: int num_barriers: int = 1 + def __post_init__(self): + if self.arrival_count < 1: + raise ValueError( + f"Arrival count must be at least 1, but got {self.arrival_count}" + ) + @dataclasses.dataclass(frozen=True) class ClusterBarrier: collective_dims: Sequence[gpu.Dimension] + arrival_count: int = 1 num_barriers: int = 1 @dataclasses.dataclass(frozen=True) class TMEM: shape: tuple[int, int] dtype: Any + _: dataclasses.KW_ONLY layout: tcgen05.TMEMLayout | None = None collective: bool = False + packing: int | None = None def __post_init__(self): if self.layout is not None: - self.layout.check_shape(self.shape) + self.layout.check_type( + self.shape, utils.bitwidth(utils.dtype_to_ir_type(self.dtype)) + ) + if self.packing is not None: + raise ValueError("Cannot specify both layout and packing") def _count_buffer_bytes(shape_dtype: jax.ShapeDtypeStruct) -> int: - return math.prod(shape_dtype.shape) * np.dtype(shape_dtype.dtype).itemsize + return math.prod(shape_dtype.shape) * dtypes.itemsize_bits(dtypes.dtype(shape_dtype.dtype)) // 8 -class ThreadSemantics(enum.Enum): +class LoweringSemantics(enum.Enum): """Semantics for the kernel's instruction stream.""" Lane = enum.auto() Warpgroup = enum.auto() +@dataclasses.dataclass(frozen=True) +class _TMEMAlloc: + addr_ref: ir.Value + num_cols: int + collective: bool + + def alloc(self) -> int: + """Allocates TMEM and returns the number of columns allocated.""" + _, cols = tcgen05.tmem_alloc( + self.addr_ref, self.num_cols, collective=self.collective, exact=False + ) + return cols + + def dealloc(self): + addr = memref.load(self.addr_ref, []) + tcgen05.tmem_dealloc( + addr, self.num_cols, collective=self.collective, exact=False + ) + + +@dataclasses.dataclass() +class _TMEMDialectAlloc: + addr_ref: ir.Value + shape: tuple[int, int] + dtype: ir.Type + packing: int + collective: bool + tmem_ref: ir.Value | None = dataclasses.field(init=False, default=None) + + def alloc(self) -> int: + """Allocates TMEM and returns the number of columns allocated.""" + result_type = ir.MemRefType.get( + self.shape, + self.dtype, + memory_space=utils.tmem(), + ) + self.tmem_ref = dialect.tmem_alloc( + result_type, + self.addr_ref, + collective=self.collective, + packing=self.packing, + ) + ncols = self.shape[1] // self.packing + return tcgen05.tmem_alloc_exact_ncols(ncols, exact=False) + + def dealloc(self): + assert self.tmem_ref is not None + dialect.tmem_dealloc(self.tmem_ref) + + +def _slice_smem( + result: ir.Type, + smem_base: ir.Value, + offset: ir.Value, # This should be an ir.IndexType. + lowering_semantics: LoweringSemantics, +) -> ir.Value: + if lowering_semantics == LoweringSemantics.Warpgroup: + offset = arith.index_cast(ir.IntegerType.get_signless(32), offset) + return dialect.slice_smem(result, offset) + else: + return memref.view(result, smem_base, offset, []) + + def _construct_smem_reftree( cluster_shape: tuple[int, int, int], dynamic_smem: ir.Value, smem_buffers: ShapeTree, - delayed_warp_init: list[Callable[[], None]], # Mutated by this function! + tmem_allocs: list[ + _TMEMAlloc | _TMEMDialectAlloc + ], # Mutated by this function! + lowering_semantics: LoweringSemantics, dynamic_smem_offset: int = 0, ) -> Callable[[], RefTree]: index = ir.IndexType.get() - i8 = ir.IntegerType.get_signless(8) i32 = ir.IntegerType.get_signless(32) - smem = ir.Attribute.parse("#gpu.address_space") + i64 = ir.IntegerType.get_signless(64) flat_ref_tys, smem_buffer_tree = jax.tree.flatten( smem_buffers, is_leaf=lambda x: isinstance(x, Union) ) smem_refs = [] + for ref_ty in flat_ref_tys: - def get_barrier_ptr(num_barriers: int) -> ir.Value: + def barrier_memref(num_barriers: int) -> ir.Value: nonlocal dynamic_smem_offset - workgroup_nvptx_address_space = ( - utils.gpu_address_space_to_nvptx(gpu.AddressSpace.Workgroup) - ) - smem_base_ptr = utils.memref_ptr( - dynamic_smem, memory_space=workgroup_nvptx_address_space - ) - smem_ptr_ty = ir.Type.parse(f"!llvm.ptr<{workgroup_nvptx_address_space}>") - barrier_base_ptr = llvm.getelementptr( - smem_ptr_ty, smem_base_ptr, [], [dynamic_smem_offset], i8 + barrier_ty = ir.MemRefType.get( + (num_barriers,), + ir.Type.parse("!mosaic_gpu.barrier") + if lowering_semantics == LoweringSemantics.Warpgroup + else i64, + memory_space=utils.smem(), ) + barrier_memref = _slice_smem( + barrier_ty, + dynamic_smem, + c(dynamic_smem_offset, index), + lowering_semantics, + ) dynamic_smem_offset += num_barriers * utils.MBARRIER_BYTES - return barrier_base_ptr + return barrier_memref + ref: Any match ref_ty: case Union(members): member_thunks = [ @@ -222,7 +414,8 @@ def get_barrier_ptr(num_barriers: int) -> ir.Value: cluster_shape, dynamic_smem, m, - delayed_warp_init, + tmem_allocs, + lowering_semantics, dynamic_smem_offset, ) for m in members @@ -234,47 +427,65 @@ def ref(member_thunks=member_thunks): return Union([t() for t in member_thunks]) case TMABarrier(num_barriers): - ref = utils.BarrierRef.initialize( - get_barrier_ptr(num_barriers), num_barriers, arrival_count=1 + init_fn: Callable[..., Any] = ( + utils.DialectBarrierRef.initialize + if lowering_semantics == LoweringSemantics.Warpgroup + else utils.BarrierRef.initialize ) + ref = init_fn(barrier_memref(num_barriers), arrival_count=1) case Barrier(arrival_count, num_barriers): - ref = utils.BarrierRef.initialize( - get_barrier_ptr(num_barriers), - num_barriers, - arrival_count=arrival_count, + init_fn = ( + utils.DialectBarrierRef.initialize + if lowering_semantics == LoweringSemantics.Warpgroup + else utils.BarrierRef.initialize ) - case ClusterBarrier(collective_dims, num_barriers): + ref = init_fn(barrier_memref(num_barriers), arrival_count=arrival_count) + case ClusterBarrier(collective_dims, arrival_count, num_barriers): ref = utils.CollectiveBarrierRef.initialize( - get_barrier_ptr(num_barriers), - num_barriers, - collective_dims, - cluster_shape, + barrier_memref(num_barriers), arrival_count, collective_dims, cluster_shape ) - case TMEM(shape, dtype, layout, collective): - addr_ref = memref.view( - ir.MemRefType.get([], i32, memory_space=smem), - dynamic_smem, c(dynamic_smem_offset, index), [], + case TMEM(shape, dtype, layout=layout, collective=collective, packing=packing): + addr_ref = _slice_smem( + ir.MemRefType.get([], i32, memory_space=utils.smem()), + dynamic_smem, + c(dynamic_smem_offset, index), + lowering_semantics, ) - if layout is None: - layout = tcgen05._infer_tmem_layout(shape, collective) - num_cols = layout.cols_in_shape(shape) - delayed_warp_init.append( - functools.partial( - tcgen05.tmem_alloc, - addr_ref, num_cols, collective=collective, exact=False, - ) - ) - def ref(addr_ref=addr_ref, shape=shape, dtype=dtype, layout=layout): - addr = memref.load(addr_ref, []) - return tcgen05.TMEMRef( - addr, shape, utils.dtype_to_ir_type(dtype), layout + packing = 1 if packing is None else packing + ir_dtype = utils.dtype_to_ir_type(dtype) + if lowering_semantics == LoweringSemantics.Warpgroup: + if layout is not None: + packing = layout.vector_length + + alloc = _TMEMDialectAlloc( + addr_ref, shape, ir_dtype, packing, collective ) + tmem_allocs.append(alloc) + def ref(alloc=alloc, layout=layout): + assert alloc.tmem_ref is not None + if layout is not None: + layout_attr = layouts.to_layout_attr(layout) + return dialect.tmem_layout_cast(alloc.tmem_ref, layout_attr) + else: + return alloc.tmem_ref + + else: + if layout is None: + layout = tcgen05._infer_tmem_layout(shape, collective, packing) + num_cols = layout.cols_in_shape(shape, utils.bitwidth(ir_dtype)) + tmem_allocs.append(_TMEMAlloc(addr_ref, num_cols, collective)) + def ref(addr_ref=addr_ref, shape=shape, ir_dtype=ir_dtype, layout=layout): + addr = memref.load(addr_ref, []) + return tcgen05.TMEMRef(addr, shape, ir_dtype, layout) + dynamic_smem_offset += 4 # i32 takes up 4 bytes case _: mlir_dtype = utils.dtype_to_ir_type(ref_ty.dtype) - tile_smem = memref.view( - ir.MemRefType.get(ref_ty.shape, mlir_dtype, memory_space=smem), - dynamic_smem, c(dynamic_smem_offset, index), [], + tile_smem = _slice_smem( + ir.MemRefType.get(ref_ty.shape, mlir_dtype, memory_space=utils.smem()), + dynamic_smem, + c(dynamic_smem_offset, index), + lowering_semantics, ) dynamic_smem_offset += _count_buffer_bytes(ref_ty) ref = tile_smem @@ -300,13 +511,19 @@ def _smem_tree_size(smem_buffers: ShapeTree) -> int: size += max(_smem_tree_size(s) for s in members) case ( TMABarrier(num_barriers) - | ClusterBarrier(_, num_barriers=num_barriers) + | ClusterBarrier(_, _, num_barriers=num_barriers) | Barrier(_, num_barriers=num_barriers) ): if size % utils.MBARRIER_BYTES: - raise NotImplementedError("Misaligned barrier allocation") + raise NotImplementedError( + "Misaligned barrier allocation. Expected smem size" + f" ({size} bytes) to be divisible by the size of the barrier:" + f" {utils.MBARRIER_BYTES} bytes." + ) size += num_barriers * utils.MBARRIER_BYTES case TMEM(_): + # TODO(justinfu): This can trigger misaligned barrier allocations + # if TMEM is requested before barriers b/c it's not divisible by 8. size += 4 # i32 takes up 4 bytes case _: size += _count_buffer_bytes(l) @@ -320,13 +537,17 @@ def _launch( grid: tuple[int, int, int], cluster: tuple[int, int, int], block: tuple[int, int, int], - scratch_arr, smem_buffers: ShapeTree | Union[ShapeTree], + lowering_semantics: LoweringSemantics, + module: ir.Module, profiler_spec: profiler.ProfilerSpec | None = None, maybe_prof_buffer: ir.Value | None = None, ): if (profiler_spec is None) != (maybe_prof_buffer is None): - raise ValueError + raise ValueError( + "Both profiler_spec and maybe_prof_buffer must be specified or" + " left unspecified." + ) index = ir.IndexType.get() i32 = ir.IntegerType.get_signless(32) i8 = ir.IntegerType.get_signless(8) @@ -337,14 +558,29 @@ def _launch( smem_bytes = user_smem_bytes if profiler_spec is not None: - smem_bytes += profiler_spec.smem_bytes(block=block) - - # TODO(cperivol): Query the shared memory size programmatically. - if smem_bytes > 228 * 1024: - raise ValueError(f"Mosaic GPU kernel exceeds available shared memory {smem_bytes=} > 228000") + # Profiler array stores values in 64 bit chunks (vectors of size 2 + # of 32-bit elements), and so the starting address needs to be 64 + # bit = 8 byte aligned. + # https://docs.nvidia.com/cuda/parallel-thread-execution/#addresses-as-operands:~:text=The%20address%20must%20be%20naturally%20aligned%20to%20a%20multiple%20of%20the%20access%20size. + align = 8 + profiler_start = (smem_bytes + align - 1) & ~(align - 1) + smem_bytes = profiler_start + profiler_spec.smem_bytes(block=block) + + device = jax.local_devices()[0] + # For ahead-of-time compilation purposes, that is when a CUDA device + # isn't available to query directly, we default to 227 KB, the + # maximum amount of shared memory per thread block available in + # compute capabilities 9.0 and 10.x: + # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#features-and-technical-specifications-technical-specifications-per-compute-capability + # Note in either case we assume all devices have the same amount of + # shared memory. + max_smem_bytes = getattr(device, "shared_memory_per_block_optin", 227 * 1024) + if smem_bytes > max_smem_bytes: + raise ValueError("Mosaic GPU kernel exceeds available shared memory: " + f"{smem_bytes=} > {max_smem_bytes=}") if math.prod(cluster) != 1: if len(cluster) != 3: - raise ValueError("Clusters must be 3D") + raise ValueError(f"Clusters must be 3D. Got: {cluster}") cluster_kwargs = { "clusterSize" + d: c(s, index) for s, d in zip(cluster, "XYZ") } @@ -356,39 +592,53 @@ def _launch( ) else: cluster_kwargs = {} - launch_op = gpu.LaunchOp( - token.type, [token], *grid_vals, *block_vals, - dynamicSharedMemorySize=c(smem_bytes, i32), **cluster_kwargs) + # `gpu.LaunchOp` is missing the clusterSize{X,Y,Z} arguments. + launch_op = _gpu_ops_gen.LaunchOp( + token.type, + [token], + *grid_vals, + *block_vals, + dynamicSharedMemorySize=c(smem_bytes, i32), + **cluster_kwargs, + ) launch_op.body.blocks.append(*([index] * (12 + 2 * len(cluster_kwargs)))) # Append an empty block - smem = ir.Attribute.parse("#gpu.address_space") with ir.InsertionPoint(launch_op.body.blocks[0]): dynamic_smem = gpu.dynamic_shared_memory( - ir.MemRefType.get( - (ir.ShapedType.get_dynamic_size(),), i8, memory_space=smem - ) + ir.MemRefType.get((utils.DYNAMIC,), i8, memory_space=utils.smem()) ) if profiler_spec: - prof_smem = memref.view( + prof_smem = _slice_smem( ir.MemRefType.get( (profiler_spec.smem_i32_elements(block=block),), - i32, memory_space=smem, + i32, + memory_space=utils.smem(), ), - dynamic_smem, c(user_smem_bytes, index), [], + dynamic_smem, + c(profiler_start, index), + lowering_semantics, ) + if lowering_semantics == LoweringSemantics.Warpgroup: + prof_smem = dialect.with_transforms(prof_smem, ir.ArrayAttr.get([])) + wrap_in_custom_primitive = True + else: + wrap_in_custom_primitive = False prof = profiler.OnDeviceProfiler( - profiler_spec, prof_smem, maybe_prof_buffer + profiler_spec, + prof_smem, + maybe_prof_buffer, + wrap_in_custom_primitive, ) else: prof = None - ptr_ty = ir.Type.parse("!llvm.ptr") - scratch_ptr = builtin.unrealized_conversion_cast([ptr_ty], [scratch_arr]) - ctx = launch_context.LaunchContext(launch_op, scratch_ptr, cluster, prof) + ctx = launch_context.LaunchContext( + module, launch_context.Scratch(launch_op), cluster, prof + ) with ctx.named_region("Init"): - delayed_warp_init = [] + tmem_allocs: list[_TMEMAlloc | _TMEMDialectAlloc] = [] smem_ref_tree_thunk = _construct_smem_reftree( - cluster, dynamic_smem, smem_buffers, delayed_warp_init + cluster, dynamic_smem, smem_buffers, tmem_allocs, lowering_semantics ) # TODO(apaszke): Skip fences if no barriers or TMEM is initialized. # TODO(apaszke): Only initialize cluster barriers before the cluster wait. @@ -396,22 +646,79 @@ def _launch( if math.prod(cluster) != 1: nvvm.cluster_arrive_relaxed(aligned=ir.UnitAttr.get()) nvvm.cluster_wait(aligned=ir.UnitAttr.get()) - if delayed_warp_init: - eq = arith.CmpIPredicate.eq - is_init_warp = arith.cmpi(eq, utils.warp_idx(sync=False), c(0, i32)) - with utils.when(is_init_warp): - for init in delayed_warp_init: - init() - tcgen05.tmem_relinquish_alloc_permit() + if tmem_allocs: + init_warp_ctx: contextlib.AbstractContextManager + if lowering_semantics == LoweringSemantics.Warpgroup: + init_warp_ctx = contextlib.nullcontext() + else: + eq = arith.CmpIPredicate.eq + is_init_warp = arith.cmpi(eq, utils.warp_idx(sync=False), c(0, i32)) + init_warp_ctx = utils.when(is_init_warp) + with init_warp_ctx: + cols_used = 0 + for alloc in tmem_allocs: + cols_used += alloc.alloc() + if cols_used > tcgen05.TMEM_MAX_COLS: + raise ValueError( + "Total TMEM allocation exceeds memory limit. " + f"Requested {cols_used} columns which exceeds limit of " + f"{tcgen05.TMEM_MAX_COLS}." + ) + collective_types = {alloc.collective for alloc in tmem_allocs} + if len(collective_types) > 1: + raise ValueError( + "Can't mix collective and non-collective TMEM allocations" + " within the same kernel." + ) + collective = True in collective_types + if collective and math.prod(cluster) % 2: + raise ValueError( + "Collective TMEM allocations are only supported for clusters" + " with an even number of blocks in them. Got cluster:" + f" {cluster}" + ) + if lowering_semantics == LoweringSemantics.Warpgroup: + dialect.tmem_relinquish_alloc_permit(collective=collective) + else: + tcgen05.tmem_relinquish_alloc_permit(collective=collective) gpu.barrier() # Make sure the init is visible to all threads. smem_ref_tree = smem_ref_tree_thunk() yield ctx, smem_ref_tree + + if tmem_allocs: + gpu.barrier() # Make sure everyone is done before we release TMEM. + if any(alloc.collective for alloc in tmem_allocs): + nvvm.cluster_arrive_relaxed(aligned=ir.UnitAttr.get()) + nvvm.cluster_wait(aligned=ir.UnitAttr.get()) + if lowering_semantics == LoweringSemantics.Warpgroup: + init_warp_ctx = contextlib.nullcontext() + else: + init_warp_ctx = utils.when(is_init_warp) + with init_warp_ctx: + for alloc in tmem_allocs: + alloc.dealloc() if prof is not None: prof.finalize(grid=grid, block=block) gpu.terminator() +def _infer_arch() -> tuple[int, int]: + device: Any = jax.sharding.get_abstract_mesh().abstract_device + if device is None: + device = jex_backend.get_default_device() + if not hasattr(device, "compute_capability"): + return (9, 0) # TODO(apaszke): Remove this once we figure out the export story. + arch_name = device.compute_capability + # Handle ROCm devices that return architecture strings like "gfxXXX". + if arch_name.startswith("gfx"): + raise ValueError( + f"Mosaic GPU does not yet support AMD ROCm devices. " + f"Got compute_capability: {arch_name}" + ) + return tuple(map(int, arch_name.split("."))) # type: ignore + + def _lower_as_gpu_kernel( body, grid: tuple[int, int, int], @@ -419,40 +726,44 @@ def _lower_as_gpu_kernel( block: tuple[int, int, int], in_shapes: tuple[Any, ...], out_shape, + inout_shape, smem_scratch_shape: ShapeTree | Union[ShapeTree], + lowering_semantics: LoweringSemantics, module_name: str, - kernel_name: str | None = None, + kernel_name: str, prof_spec: profiler.ProfilerSpec | None = None, ): ptr_ty = ir.Type.parse("!llvm.ptr") token_ty = ir.Type.parse("!gpu.async.token") i32 = ir.IntegerType.get_signless(32) - i64 = ir.IntegerType.get_signless(64) def _shape_to_ref_ty(shape: jax.ShapeDtypeStruct) -> ir.MemRefType: return ir.MemRefType.get(shape.shape, utils.dtype_to_ir_type(shape.dtype)) in_ref_tys = [_shape_to_ref_ty(t) for t in in_shapes] + inout_ref_tys = [_shape_to_ref_ty(t) for t in inout_shape] unwrap_output_tuple = False if isinstance(out_shape, list): out_shape = tuple(out_shape) elif not isinstance(out_shape, tuple): out_shape = (out_shape,) - unwrap_output_tuple = True + unwrap_output_tuple = not inout_shape out_ref_tys = [_shape_to_ref_ty(t) for t in out_shape] if prof_spec is not None: out_shape = (*out_shape, prof_spec.jax_buffer_type(grid, block)) out_ref_tys.append(prof_spec.mlir_buffer_type(grid, block)) module = ir.Module.create() + dialect.register_dialect(module.context) attrs = module.operation.attributes attrs["sym_name"] = ir.StringAttr.get(module_name) - if kernel_name is None: - kernel_name = getattr(body, "__name__", "anonymous") + arch_major, arch_minor = _infer_arch() + attrs["mosaic_gpu.arch_major"] = ir.IntegerAttr.get(i32, arch_major) + attrs["mosaic_gpu.arch_minor"] = ir.IntegerAttr.get(i32, arch_minor) # These are needed as nonlocal below. - launch_ctx, scratch_arr = None, None + launch_ctx = None with ir.InsertionPoint(module.body): _declare_runtime_functions() global_scratch = llvm.GlobalOp( @@ -461,37 +772,31 @@ def _shape_to_ref_ty(shape: jax.ShapeDtypeStruct) -> ir.MemRefType: ir.Attribute.parse("#llvm.linkage"), addr_space=ir.IntegerAttr.get(i32, 4), # GPU constant memory. ) - @func.FuncOp.from_py_func(ptr_ty, ptr_ty, name=f"mosaic_gpu_{kernel_name}") + @func.FuncOp.from_py_func(ptr_ty, ptr_ty, name=f"{kernel_name}_mosaic_gpu") def main(token_ptr, buffers): - nonlocal launch_ctx, scratch_arr + nonlocal launch_ctx token = builtin.unrealized_conversion_cast([token_ty], [token_ptr]) arg_refs = [] - for i, ref_ty in enumerate([*in_ref_tys, *out_ref_tys]): - ptr = llvm.LoadOp(ptr_ty, llvm.GEPOp(ptr_ty, buffers, [], [i], ptr_ty)) + # XLA will pass in inout refs again as outputs, but we ignore them. + for i, ref_ty in enumerate([*in_ref_tys, *inout_ref_tys, *out_ref_tys]): + ptr = llvm.LoadOp(ptr_ty, llvm.GEPOp(ptr_ty, buffers, [], [i], ptr_ty, llvm.GEPNoWrapFlags.none)) arg_refs.append(utils.ptr_as_memref(ptr, ir.MemRefType(ref_ty))) - in_refs = arg_refs[:len(in_ref_tys)] - out_refs = arg_refs[len(in_ref_tys):] - prof_buffer = out_refs.pop() if prof_spec is not None else None - empty_arr_ty = ir.Type.parse("!llvm.array<0 x i8>") - scratch_alloc = llvm.AllocaOp( - ptr_ty, c(1, i64), empty_arr_ty, - alignment=launch_context.TMA_DESCRIPTOR_ALIGNMENT - ) - scratch_arr = llvm.load(empty_arr_ty, scratch_alloc.result) + prof_buffer = arg_refs.pop() if prof_spec is not None else None with _launch( - token, grid, cluster, block, scratch_arr, smem_scratch_shape, - prof_spec, prof_buffer + token, grid, cluster, block, smem_scratch_shape, + lowering_semantics, module, prof_spec, prof_buffer ) as (_launch_ctx, smem_refs): nonlocal launch_ctx launch_ctx = _launch_ctx - body(launch_ctx, *in_refs, *out_refs, smem_refs) + body(launch_ctx, *arg_refs, smem_refs) main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() sym_tab = ir.SymbolTable(module.operation) sym_tab.insert(main.func_op) sym_tab.insert(global_scratch) module.operation.verify() - return module, out_shape, unwrap_output_tuple, launch_ctx, scratch_arr + assert launch_ctx is not None + return module, out_shape, unwrap_output_tuple, launch_ctx def _run_serde_pass( @@ -518,27 +823,6 @@ def _run_serde_pass( return module -def _initialize_scratch( - launch_ctx : launch_context.LaunchContext, - scratch_arr: ir.Value, - ): - """ - Allocates and initializes the host buffer right before the launch. This needs - to be done after all TMA descriptors have been recorded by the launch context. - Only then we know what the scratch contains. - - When using the Mosaic GPU dialect, the necessary information is known only - after the lowering passes have run. - """ - with ir.InsertionPoint(scratch_arr.owner): - gmem_scratch_bytes = launch_ctx.next_scratch_offset - scratch_alloc_op = scratch_arr.owner.opview.addr.owner.opview - scratch_arr_ty = ir.Type.parse(f"!llvm.array<{gmem_scratch_bytes} x i8>") - scratch_alloc_op.elem_type = ir.TypeAttr.get(scratch_arr_ty) - scratch_arr.set_type(scratch_arr_ty) - for init_callback in launch_ctx.host_scratch_init: - init_callback(scratch_alloc_op.result) - def _declare_runtime_functions(): """Declares the runtime functions that can be used by the generated code.""" ptr_ty = ir.Type.parse("!llvm.ptr") @@ -550,7 +834,7 @@ def _declare_runtime_functions(): ) -def as_gpu_kernel( +def _kernel_to_module( body, grid: tuple[int, int, int], block: tuple[int, int, int], @@ -561,32 +845,80 @@ def as_gpu_kernel( cluster: tuple[int, int, int] = (1, 1, 1), module_name: str = "unknown", kernel_name: str | None = None, - ir_version: int | None = None, - thread_semantics: ThreadSemantics = ThreadSemantics.Lane, + thread_semantics: LoweringSemantics = LoweringSemantics.Lane, + inout_shape = (), ): if isinstance(in_shape, list): in_shape = tuple(in_shape) elif not isinstance(in_shape, tuple): in_shape = (in_shape,) + if isinstance(inout_shape, list): + inout_shape = tuple(inout_shape) + elif not isinstance(inout_shape, tuple): + inout_shape = (inout_shape,) + if kernel_name is None: + kernel_name = jax_util.fun_name(body, "anonymous") - module, out_shape, unwrap_output_tuple, launch_ctx, scratch_arr = ( + inout_shape = jax.tree.map(lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), + inout_shape) + out_shape = jax.tree.map(lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), + out_shape) + module, out_shape, unwrap_output_tuple, launch_ctx = ( _lower_as_gpu_kernel( - body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape, - module_name, kernel_name, prof_spec + body, grid, cluster, block, in_shape, out_shape, inout_shape, + smem_scratch_shape, thread_semantics, module_name, kernel_name, + prof_spec ) ) - if thread_semantics == ThreadSemantics.Warpgroup and dialect is not None: + if thread_semantics == LoweringSemantics.Warpgroup and dialect is not None: + # We need to run a pass that removes dead-code for which layout inference + # does not work. + pm = mlir.passmanager.PassManager.parse("builtin.module(canonicalize)", module.context) + pm.run(module.operation) + # Run Python lowering passes. The remaining passes will be run in C++ in # jax/jaxlib/mosaic/gpu/custom_call.cc layout_inference.infer_layout(module) # pytype: disable=attribute-error - transform_inference.infer_transforms(module) # pytype: disable=attribute-error dialect_lowering.lower_mgpu_dialect(module, launch_ctx) # pytype: disable=attribute-error - _initialize_scratch(launch_ctx, scratch_arr) + launch_ctx.scratch.finalize_size() module.operation.verify() - expected_arg_treedef = jax.tree.structure(in_shape) + return ( + module, + in_shape, + inout_shape, + out_shape, + unwrap_output_tuple, + launch_ctx.is_device_collective, + ) + + +def as_gpu_kernel( + body, + grid: tuple[int, int, int], + block: tuple[int, int, int], + in_shape, + out_shape, + smem_scratch_shape: ShapeTree | Union[ShapeTree], + prof_spec: profiler.ProfilerSpec | None = None, + cluster: tuple[int, int, int] = (1, 1, 1), + module_name: str = "unknown", + kernel_name: str | None = None, + ir_version: int | None = None, + thread_semantics: LoweringSemantics = LoweringSemantics.Lane, + inout_shape = (), +): + module, in_shape, inout_shape, out_shape, unwrap_output_tuple, is_device_collective = _kernel_to_module( + body, grid, block, in_shape, out_shape, smem_scratch_shape, prof_spec, + cluster, module_name, kernel_name, thread_semantics, inout_shape + ) + + if is_device_collective and not supports_cross_device_collectives(): + raise RuntimeError("Kernel is a cross-device collective but no support is available.") + + expected_arg_tys, expected_arg_treedef = jax.tree.flatten((*in_shape, *inout_shape)) def _check_args(*args): arg_treedef = jax.tree.structure(args) if arg_treedef != expected_arg_treedef: @@ -594,9 +926,28 @@ def _check_args(*args): f"Invalid argument structure: expected {expected_arg_treedef}, got" f" {arg_treedef}, ({args=})" ) + for arg, expected_ty in zip(args, expected_arg_tys): + if arg.shape != expected_ty.shape: + raise ValueError( + f"Argument shape mismatch: expected {expected_ty.shape}, got" + f" {arg.shape}" + ) + if arg.dtype != expected_ty.dtype: + hint = "" + if not arg.shape: + hint = f". Hint: cast the scalar to {expected_ty.dtype} explicitly." + raise ValueError( + f"Argument dtype mismatch: expected {expected_ty.dtype}, got" + f" {arg.dtype}{hint}" + ) def bind(*args) -> Any: - return mosaic_gpu_p.bind(*args, module=module, out_types=out_shape) + return mosaic_gpu_p.bind( + *args, + module=module, + out_types=out_shape, + inout_types=inout_shape, + ) if prof_spec is not None: @jax.jit @@ -604,10 +955,7 @@ def prof_kernel(*args): _check_args(*args) *results, prof_buffer = bind(*args) def dump_profile(prof_buffer): - out_file = os.path.join( - os.getenv("TEST_UNDECLARED_OUTPUTS_DIR"), - f"{time.time_ns()}-trace.json", - ) + out_file = os.path.join(prof_spec.dump_path, f"{time.time_ns()}-trace.json") try: with open(out_file, "x") as f: prof_spec.dump(prof_buffer, f, grid=grid, block=block) @@ -636,46 +984,58 @@ def as_torch_gpu_kernel( cluster: tuple[int, int, int] = (1, 1, 1), module_name: str = "unknown", kernel_name: str | None = None, - thread_semantics: ThreadSemantics = ThreadSemantics.Lane, + thread_semantics: LoweringSemantics = LoweringSemantics.Lane, + inout_shape=(), ): - try: - import torch - except ImportError: - raise RuntimeError("as_torch_gpu_kernel requires PyTorch") - torch.cuda.init() # Make sure CUDA context is set up. - - if isinstance(in_shape, list): - in_shape = tuple(in_shape) - elif not isinstance(in_shape, tuple): - in_shape = (in_shape,) - - flat_out_types, out_treedef = jax.tree.flatten(out_shape) - expected_arg_treedef = jax.tree.structure(in_shape) - - module, out_shape, unwrap_output_tuple, launch_ctx, scratch_arr = ( - _lower_as_gpu_kernel( - body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape, - module_name, kernel_name, prof_spec - ) + ( + module, + in_shape, + inout_shape, + out_shape, + unwrap_output_tuple, + is_device_collective, + ) = _kernel_to_module( + body, + grid, + block, + in_shape, + out_shape, + smem_scratch_shape, + prof_spec, + cluster, + module_name, + kernel_name, + thread_semantics, + inout_shape, + ) + module = _run_serde_pass(module, serialize=True, ir_version=None) + return _as_torch_gpu_kernel( + module.operation.get_asm(binary=True, enable_debug_info=True), + in_shape, + out_shape, + inout_shape, + unwrap_output_tuple=unwrap_output_tuple, ) - if thread_semantics == ThreadSemantics.Warpgroup and dialect is not None: - # Run Python lowering passes. The remaining passes will be run in C++ in - # jax/jaxlib/mosaic/gpu/custom_call.cc - layout_inference.infer_layout(module) # pytype: disable=attribute-error - transform_inference.infer_transforms(module) # pytype: disable=attribute-error - dialect_lowering.lower_mgpu_dialect(module, launch_ctx) # pytype: disable=attribute-error - _initialize_scratch(launch_ctx, scratch_arr) - module.operation.verify() +def _compile_as_torch_gpu_kernel(module_asm: bytes): + try: + import torch # type: ignore[import-not-found] # pytype: disable=import-error + except ImportError: + raise RuntimeError("Can't compile for PyTorch: import torch failed") from None + + torch.cuda.init() # Make sure CUDA context is set up. # Get our hands on the compilation and unload functions try: - import jax_plugins.xla_cuda12 as cuda_plugin + try: + import jax_plugins.xla_cuda13 as cuda_plugin # type: ignore[import-not-found] # pytype: disable=import-error + except ImportError: + import jax_plugins.xla_cuda12 as cuda_plugin # type: ignore[import-not-found] # pytype: disable=import-error except ImportError: - raise RuntimeError("as_torch_gpu_kernel only works with recent jaxlib builds " - "that use backend plugins") - dll = ctypes.CDLL(cuda_plugin._get_library_path()) + dll = ctypes.CDLL(None) + else: + dll = ctypes.CDLL(cuda_plugin._get_library_path()) compile_func = dll.MosaicGpuCompile compile_func.argtypes = [ctypes.c_void_p] compile_func.restype = ctypes.POINTER(ctypes.c_void_p) @@ -683,13 +1043,47 @@ def as_torch_gpu_kernel( unload_func.argtypes = [compile_func.restype] unload_func.restype = None - module_asm = module.operation.get_asm(binary=True, enable_debug_info=True) - compiled = compile_func(ctypes.c_char_p(module_asm)) - if compiled is None: + compiled = compile_func(ctypes.c_char_p(module_asm), ctypes.c_int(len(module_asm))) + if not compiled: raise RuntimeError("Failed to compile the module") ctx, launch_ptr = compiled[0], compiled[1] ctx_ptr_ptr = ctypes.pointer(ctypes.c_void_p(ctx)) - launch = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(launch_ptr) + launch_c = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(launch_ptr) + + def launch(arg_ptrs, device): + # Allocate another buffer for args of the host-side program. This is sadly + # the default MLIR calling convention. + launch_args_ptr = (ctypes.POINTER(ctypes.c_void_p) * 3)() + launch_args_ptr[0] = ctx_ptr_ptr + launch_args_ptr[1] = ctypes.pointer( + torch.cuda.default_stream(device)._as_parameter_ + ) + launch_args_ptr[2] = ctypes.cast( + ctypes.pointer(ctypes.pointer(arg_ptrs)), + ctypes.POINTER(ctypes.c_void_p), + ) + launch_c(launch_args_ptr) + + return launch, functools.partial(unload_func, compiled) + + +def _as_torch_gpu_kernel( + module_asm: bytes, + in_shape: Iterable[object], + out_shape: Iterable[object], + inout_shape: Iterable[object] = (), + *, + unwrap_output_tuple: bool = False, + _prepare_args = None, + _prepare_results = None, +): + flat_arg_types, expected_arg_treedef = jax.tree.flatten((*in_shape, *inout_shape)) + flat_out_types, _ = jax.tree.flatten(out_shape) + out_treedef = jax.tree.structure((*out_shape, *inout_shape)) + + launch, unload = _compile_as_torch_gpu_kernel(module_asm) + # _compile_as_torch_gpu_kernel checks that this succeeds + import torch # type: ignore[import-not-found] # pytype: disable=import-error def as_torch_dtype(dtype): # torch contains NumPy-compatible dtypes in its top namespace @@ -702,6 +1096,17 @@ def apply(*args): f"Invalid argument structure: expected {expected_arg_treedef}, got" f" {arg_treedef}, ({args=})" ) + for arg, expected_ty in zip(flat_args, flat_arg_types): + if arg.shape != expected_ty.shape: + raise ValueError( + f"Argument shape mismatch: expected {expected_ty.shape}, got" + f" {arg.shape}" + ) + if arg.dtype != as_torch_dtype(expected_ty.dtype): + raise ValueError( + "Argument dtype mismatch: expected" + f" {as_torch_dtype(expected_ty.dtype)}, got {arg.dtype}" + ) # Construct a device pointer list like in the XLA calling convention buffers = (ctypes.c_void_p * (arg_treedef.num_leaves + out_treedef.num_leaves))() @@ -715,19 +1120,13 @@ def apply(*args): out = torch.empty(t.shape, dtype=as_torch_dtype(t.dtype), device=device) flat_outs.append(out) buffers[i] = out.data_ptr() - # Allocate another buffer for args of the host-side program. This is sadly - # the default MLIR calling convention. - args_ptr = (ctypes.POINTER(ctypes.c_void_p) * 3)() - args_ptr[0] = ctx_ptr_ptr - args_ptr[1] = ctypes.pointer(torch.cuda.default_stream(device)._as_parameter_) - args_ptr[2] = ctypes.cast(ctypes.pointer(ctypes.pointer(buffers)), - ctypes.POINTER(ctypes.c_void_p)) - launch(args_ptr) - return jax.tree.unflatten(out_treedef, flat_outs) + if num_inout_args := jax.tree.structure(inout_shape).num_leaves: + flat_outs += flat_args[-num_inout_args:] + launch(buffers, device) + out = jax.tree.unflatten(out_treedef, flat_outs) + return out[0] if unwrap_output_tuple else out # Unload the compiled code when the Python function is destroyed. - def unload(_): - unload_func(compiled) - apply.destructor = weakref.ref(apply, unload) + apply.destructor = weakref.ref(apply, lambda _weak_ref: unload) return apply diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index fedde5a00887..ea94238eb82d 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -14,12 +14,17 @@ """Lowering rules and pass for the MLIR Mosaic GPU dialect.""" -from collections.abc import Callable +# mypy has been causing more problems than it solves here. Disable it for these +# files. We have pytype checks anyway. +# mypy: ignore-errors + +from collections.abc import Callable, Iterable, Sequence import dataclasses import functools import itertools +import math import operator -from typing import Any, Sequence, Type, cast +from typing import Any, Protocol, cast from jax._src.interpreters import mlir as mlir_interpreter from jax._src.lib import mosaic_gpu_dialect as mgpu @@ -28,32 +33,54 @@ from jax._src.lib.mlir.dialects import builtin from jax._src.lib.mlir.dialects import func from jax._src.lib.mlir.dialects import gpu -from jax._src.lib.mlir.dialects import llvm from jax._src.lib.mlir.dialects import math as mlir_math from jax._src.lib.mlir.dialects import memref from jax._src.lib.mlir.dialects import nvvm from jax._src.lib.mlir.dialects import scf from jax._src.lib.mlir.dialects import vector +from jax._src.util import safe_zip +from jax.experimental.mosaic.gpu import layouts as layouts_lib +from jax.experimental.mosaic.gpu import utils as mgpu_utils import numpy as np from . import fragmented_array as fa from . import inference_utils from . import launch_context from . import layouts +from . import tcgen05 from . import utils from . import wgmma -# mypy: ignore-errors - @dataclasses.dataclass() class LoweringContext: launch_context: launch_context.LaunchContext | None single_thread_per_block_predicate: ir.Value | None single_thread_per_warpgroup_predicate: ir.Value | None + single_warp_per_block_predicate: ir.Value | None + auto_barriers: bool lowered_operations: set[ir.Operation | ir.OpView] = dataclasses.field( default_factory=set ) + is_collective_kernel: bool | None = dataclasses.field( + init=False, default=None + ) + + def check_collective(self, op: ir.OpView) -> None: + """Checks that the collective attribute is consistent across operations. + + It is an error to mix collective and non-collective operations in the same + kernel. + """ + if "collective" not in op.attributes: + return + if self.is_collective_kernel is None: + self.is_collective_kernel = op.attributes["collective"] + elif self.is_collective_kernel != op.attributes["collective"]: + raise ValueError( + "Collective attributes are inconsistent across operations in the" + " kernel." + ) def lower_op(self, op: ir.OpView): if not _should_lower(op): @@ -71,7 +98,7 @@ def lower_op(self, op: ir.OpView): raise ValueError(f"{op} is missing a layout and can not be lowered.") new_results = lowering_rule(self, op) - if new_results is not RECURSED: + if not isinstance(new_results, Recursed): for old, new in zip(op.results, new_results): old.replace_all_uses_with(new) self.lowered_operations.add(op) @@ -90,7 +117,46 @@ class Recursed: _lowerings: dict[str, MlirLoweringRule] = {} -def _fragmented_array_to_ir( +def _undo_conversion_cast( + ir_value: ir.Value, + expected_types: Sequence[ir.Type], +) -> tuple[builtin.UnrealizedConversionCastOp, Sequence[ir.Value]]: + """Undoes the provided unrealized conversion cast. + + The `ir_value` must be an unrealized conversion cast. This function will + create a new conversion cast that undoes the original one. The returned tuple + contains: + - The original unrealzied conversion cast (useful for extract attributes). + - The list of operands of the original conversion cast (which are the result + values of the undone conversion cast). + + The function will verify that the returned values have types that match + `expected_types`. + """ + conversion_cast = cast( + builtin.UnrealizedConversionCastOp, ir_value.owner.opview # pytype: disable=attribute-error + ) + + if not isinstance(conversion_cast, builtin.UnrealizedConversionCastOp): + raise ValueError(f"{conversion_cast} is not a conversion_cast") + + converted_outputs = builtin.unrealized_conversion_cast( + [operand.type for operand in conversion_cast.operands], + conversion_cast.results, + ) + if isinstance(converted_outputs, ir.OpResultList): + converted_outputs = list(converted_outputs) + elif not isinstance(converted_outputs, list): + converted_outputs = [converted_outputs] + + for v, t in zip(converted_outputs, expected_types, strict=True): + if v.type != t: + raise ValueError(f"Expected type {t} for value {v}") + + return conversion_cast, converted_outputs + + +def fragmented_array_to_ir( fragmented_array: fa.FragmentedArray, ty: ir.Type ) -> ir.Value: """Converts a FragmentedArray to an IR value. @@ -113,37 +179,39 @@ def _fragmented_array_to_ir( return conversion_cast.result +def _default_is_signed(dtype: ir.Type) -> bool | None: + """Returns `False` for Integer types, `None` otherwise. + + When converting from Pallas dtype to IR type, we lose the `is_signed` + information. We can default to `False` for most use cases. + """ + return False if isinstance(dtype, ir.IntegerType) else None + + def _fragmented_array_from_ir( fragmented_array_as_ir: ir.Value, layout: ir.Attribute, is_signed: bool | None = None, ) -> fa.FragmentedArray: - - conversion_cast = cast( - builtin.UnrealizedConversionCastOp, fragmented_array_as_ir.owner.opview # pytype: disable=attribute-error - ) - - if not isinstance(conversion_cast, builtin.UnrealizedConversionCastOp): - raise ValueError(f"{conversion_cast} is not a conversion_cast") - - converted_outputs = builtin.unrealized_conversion_cast( - [operand.type for operand in conversion_cast.operands], - conversion_cast.results, + producer_layout_attr = fragmented_array_as_ir.owner.attributes["layout"] + producer_layout = layouts.from_layout_attr(producer_layout_attr) + vector_ty = ir.VectorType(fragmented_array_as_ir.type) + reg_shape = producer_layout.registers_shape(tuple(vector_ty.shape)) + reg_ty = producer_layout.registers_element_type(vector_ty.element_type) + + conversion_cast, converted_outputs = _undo_conversion_cast( + fragmented_array_as_ir, [reg_ty] * math.prod(reg_shape) ) - if not isinstance(converted_outputs, list): - converted_outputs = [converted_outputs] reverse_conversion_cast = converted_outputs[0].owner.opview for attribute in conversion_cast.attributes: - attribute = cast(ir.NamedAttribute, attribute) - reverse_conversion_cast.attributes[attribute.name] = attribute.attr + reverse_conversion_cast.attributes[attribute] = conversion_cast.attributes[attribute] registers = np.array(list(converted_outputs)).reshape( [attr.value for attr in conversion_cast.attributes["registers_shape"]] ) - producer_layout = layouts.from_layout_attr(conversion_cast.attributes["layout"]) - if ir.IntegerType.isinstance(conversion_cast.outputs[0].type.element_type): + if isinstance(conversion_cast.outputs[0].type.element_type, ir.IntegerType): is_signed = False if is_signed is None else is_signed return fa.FragmentedArray( @@ -151,8 +219,43 @@ def _fragmented_array_from_ir( ).to_layout(layouts.from_layout_attr(layout)) +def wrap_transformed_memref( + transformed_memref: ir.Value, + logical_type: ir.Type, + transforms: ir.ArrayAttr, +) -> ir.Value: + """Wraps a transformed memref to an unrealized cast with transforms. + + The return type of the cast is the untransformed logical type. + """ + conversion_cast = builtin.UnrealizedConversionCastOp( + [logical_type], [transformed_memref] + ) + conversion_cast.attributes["transforms"] = transforms + return conversion_cast.result + + +def unwrap_transformed_memref( + ref: ir.Value, expected_transforms: ir.ArrayAttr +) -> ir.Value: + """Uwraps a memref from an unrealized cast and verifies its transforms.""" + + _, transforms = swizzle_and_transforms_from_transforms_attr(expected_transforms) + transformed_type = transformed_smem_ref_type(ref.type, transforms) + conversion_cast, [result] = _undo_conversion_cast(ref, [transformed_type]) + + # Check that the actual transforms match the expected ones. + if expected_transforms != conversion_cast.attributes["transforms"]: + raise ValueError( + f"Expected transforms {expected_transforms} do not match actual" + f" transforms {conversion_cast.attributes['transforms']}" + ) + + return result + + def _register_lowering( - op: str | Type[ir.OpView] | None + op: str | type[ir.OpView] | None ) -> Callable[[MlirLoweringRule], MlirLoweringRule]: def wrapper(f): if op is not None: @@ -170,42 +273,56 @@ def _lowered_barrier_type() -> ir.Type: @_register_lowering(mgpu.InitializeBarrierOp) def _initialize_barrier_op_lowering_rule( ctx: LoweringContext, - initialize_barrier_op: mgpu.InitializeBarrierOp, + op: mgpu.InitializeBarrierOp, ) -> Sequence[ir.Value]: - - shape = initialize_barrier_op.barriers_ref.type.shape - num_barriers = functools.reduce(operator.mul, shape, 1) - i32 = ir.IntegerType.get_signless(32) - workgroup_nvptx_address_space = utils.gpu_address_space_to_nvptx( - gpu.AddressSpace.Workgroup) - ptr_ty = ir.Type.parse(f"!llvm.ptr<{workgroup_nvptx_address_space}>") - lowered_barrier_type = _lowered_barrier_type() - for i in range(num_barriers): - nvvm.mbarrier_init_shared( - llvm.getelementptr(ptr_ty, initialize_barrier_op.base_pointer, [], [i], - lowered_barrier_type), - utils.c(initialize_barrier_op.arrival_count.value, i32), - predicate=ctx.single_thread_per_block_predicate + for i in range(op.num_barriers.value): + nvvm.mbarrier_init( + utils.getelementptr(op.base_pointer, [i], lowered_barrier_type), + utils.c( + op.arrival_count.value * utils.WARPGROUP_SIZE, + i32, + ), + predicate=ctx.single_thread_per_block_predicate, ) gpu.barrier() + return [] - barrier_base_ptr = llvm.getelementptr( - ir.Type.parse("!llvm.ptr"), - initialize_barrier_op.base_pointer, [], [0], lowered_barrier_type) - return utils.ptr_as_memref( - barrier_base_ptr, initialize_barrier_op.barriers_ref.type), +@_register_lowering(mgpu.OptimizationBarrierOp) +def _optimization_barrier_op_lowering_rule( + _: LoweringContext, + op: mgpu.OptimizationBarrierOp, +) -> Sequence[ir.Value]: + if not all( + isinstance(operand.type, ir.VectorType) for operand in op.operands + ): + raise NotImplementedError( + f"Optimization barrier op {op} has non-vector operands." + ) + + fragmented_arrays = [] + for operand, layout in safe_zip(op.operands, inference_utils.in_layouts(op)): + fragmented_arrays.append(_fragmented_array_from_ir(operand, layout)) + + lowered_fragmented_arrays = fa.optimization_barrier(*fragmented_arrays) + if isinstance(lowered_fragmented_arrays, fa.FragmentedArray): + lowered_fragmented_arrays = [lowered_fragmented_arrays] + + return [ + fragmented_array_to_ir(arr, result.type) + for arr, result in safe_zip(lowered_fragmented_arrays, op.results) + ] @_register_lowering(arith.ConstantOp) def _arith_constant_op_lowering_rule( _: LoweringContext, op: arith.ConstantOp ) -> Sequence[ir.Value]: - if not ir.DenseElementsAttr.isinstance(op.value): + if not isinstance(op.value, ir.DenseElementsAttr): raise NotImplementedError(f"Unsupported constant op: {op}") value = ir.DenseElementsAttr(op.value) @@ -213,10 +330,10 @@ def _arith_constant_op_lowering_rule( raise NotImplementedError(f"Unsupported constant op: {op}") ty = ir.VectorType(op.result.type) - is_signed = False if ir.IntegerType.isinstance(ty.element_type) else None + is_signed = _default_is_signed(ty.element_type) return [ - _fragmented_array_to_ir( + fragmented_array_to_ir( fa.FragmentedArray.splat( arith.constant(ty.element_type, value.get_splat_value()), tuple(ty.shape), @@ -254,7 +371,10 @@ def _check_transforms_and_swizzle_are_supported( ) } - tile_transforms = partitioned_transforms.get(True, []) + tile_transforms = cast( + list[launch_context.TileTransform], + partitioned_transforms.get(True, []), + ) other_transforms = partitioned_transforms.get(False, []) if len(tile_transforms) > 1: @@ -282,114 +402,229 @@ def _check_transforms_and_swizzle_are_supported( ) -@_register_lowering(vector.LoadOp) +class _Transfer(Protocol): + def __call__(self, optimized: bool) -> Any: + ... + +def _retry_on_failure(transfer: _Transfer, optimized: bool | None) -> Any: + """If `optimized` is `None`, retry `transfer` with `optimized=False` on failure.""" + if optimized is not None: + return transfer(optimized) + + # TODO(allanrenucci): Ideally we would have a way to know if we can emit an + # optimzed transfer. This relies on DCE to delete instructions generated by + # a failed call to `transfer`. + try: + return transfer(optimized=True) + except ValueError: + return transfer(optimized=False) + + +@_register_lowering(mgpu.VectorLoadOp) def _vector_load_op_lowering_rule( - _: LoweringContext, vector_load_op: vector.LoadOp + _: LoweringContext, op: mgpu.VectorLoadOp ) -> Sequence[ir.Value]: - (out_layout_attr,) = cast( - ir.ArrayAttr, vector_load_op.attributes["out_layouts"] - ) + (out_layout_attr,) = inference_utils.out_layouts(op) - for i in vector_load_op.indices: - index_defining_op = i.owner.opview - if ( - not isinstance(index_defining_op, arith.ConstantOp) - or index_defining_op.literal_value != 0 - ): - # TODO(bchetioui,dasenov): support non-zero indices. - raise NotImplementedError( - "Only constants with value 0 are supported as indices " - f"for {vector_load_op}" - ) + element_type = ir.VectorType(op.result.type).element_type + is_signed = _default_is_signed(element_type) - element_type = vector_load_op.result.type.element_type - is_signed = False if ir.IntegerType.isinstance(element_type) else None + def _fragmented_array_to_ir( + fragmented_array: fa.FragmentedArray, + ) -> ir.Value: + return fragmented_array_to_ir(fragmented_array, op.result.type) if layouts.is_strided_fragmented_layout(out_layout_attr): strided_layout = layouts.from_strided_fragmented_layout_attr( out_layout_attr ) + # TODO(bchetioui): Process transforms. fragmented_array = fa.FragmentedArray.load_strided( - vector_load_op.base, + op.source, is_signed=is_signed, vec_size=strided_layout.vec_size, ) - elif layouts.from_layout_attr(out_layout_attr) == fa.WGMMA_LAYOUT: - swizzle, transforms = swizzle_and_transforms_from_transforms_attr( - inference_utils.in_transforms(vector_load_op)[0] - ) - ref_ty = ir.MemRefType(vector_load_op.base.type) - _check_transforms_and_swizzle_are_supported(ref_ty, transforms, swizzle) - transformed_ref = transform_memref(vector_load_op.base, transforms) - fragmented_array = fa.FragmentedArray.load_tiled( - transformed_ref, - swizzle=swizzle, + return [_fragmented_array_to_ir(fragmented_array)] + + if not layouts.is_tiled_layout(out_layout_attr): + raise ValueError(f"{op} has an unsupported layout: {out_layout_attr}") + + optimized = op.optimized.value if op.optimized is not None else None + layout = layouts.from_tiled_layout_attr(out_layout_attr) + ref_ty = ir.MemRefType(op.source.type) + if ref_ty.memory_space is None: # GMEM + fragmented_array = fa.FragmentedArray.load_untiled( + op.source, + layout=layout, is_signed=is_signed, - layout=fa.WGMMA_LAYOUT, + optimized=bool(optimized), ) + return [_fragmented_array_to_ir(fragmented_array)] + + if ref_ty.memory_space != utils.smem(): + raise ValueError(f"Unsupported memory space: {ref_ty.memory_space}") + + transforms_attr = inference_utils.in_transforms(op)[0] + swizzle, transforms = swizzle_and_transforms_from_transforms_attr( + transforms_attr + ) + has_transforms = swizzle != mgpu.SwizzlingMode.kNoSwizzle or transforms + if has_transforms: + _check_transforms_and_swizzle_are_supported(ref_ty, transforms, swizzle) + transformed_ref = unwrap_transformed_memref(op.source, transforms_attr) + + def load_tiled(optimized: bool) -> fa.FragmentedArray: + return fa.FragmentedArray.load_tiled( + transformed_ref, + swizzle, + is_signed=is_signed, + layout=layout, + optimized=optimized, + ) + + fragmented_array = _retry_on_failure(load_tiled, optimized) else: - raise ValueError( - f"{vector_load_op} has an unsupported layout: {out_layout_attr}" - ) - return [_fragmented_array_to_ir(fragmented_array, vector_load_op.result.type)] + def load_untiled(optimized: bool) -> fa.FragmentedArray: + return fa.FragmentedArray.load_untiled( + op.source, + layout=layout, + is_signed=is_signed, + optimized=optimized, + ) + + fragmented_array = _retry_on_failure(load_untiled, optimized) -@_register_lowering(vector.StoreOp) + return [_fragmented_array_to_ir(fragmented_array)] + + +@_register_lowering(mgpu.VectorStoreOp) def _vector_store_op_lowering_rule( - _: LoweringContext, vector_store_op: vector.StoreOp + ctx: LoweringContext, op: mgpu.VectorStoreOp ) -> Sequence[ir.Value]: - for i in vector_store_op.indices: - index_defining_op = i.owner.opview - if ( - not isinstance(index_defining_op, arith.ConstantOp) - or index_defining_op.literal_value != 0 - ): - # TODO(bchetioui,dasenov): support non-zero indices. - raise NotImplementedError( - "Only constants with value 0 are supported as indices " - f"for {vector_store_op}" - ) + [to_store_layout] = inference_utils.in_layouts(op) + fragmented_array = _fragmented_array_from_ir(op.valueToStore, to_store_layout) - [to_store_layout] = inference_utils.in_layouts(vector_store_op) - fragmented_array = _fragmented_array_from_ir( - vector_store_op.valueToStore, to_store_layout - ) + if ctx.auto_barriers: + mgpu_utils.warpgroup_barrier() # Make sure the reads have completed. - if fragmented_array.layout == fa.WGMMA_LAYOUT: + ref = op.destination + ref_type = ir.MemRefType(ref.type) + optimized = op.optimized.value if op.optimized is not None else None + + if ref_type.memory_space is None: # GMEM + fragmented_array.store_untiled(ref, optimized=bool(optimized)) + elif ref_type.memory_space == utils.smem(): + transforms_attr = inference_utils.in_transforms(op)[0] swizzle, transforms = swizzle_and_transforms_from_transforms_attr( - inference_utils.in_transforms(vector_store_op)[0] - ) - ref_ty = ir.MemRefType(vector_store_op.base.type) - _check_transforms_and_swizzle_are_supported(ref_ty, transforms, swizzle) - fragmented_array.store_tiled( - transform_memref(vector_store_op.base, transforms), swizzle + transforms_attr ) - elif (isinstance(fragmented_array.layout, fa.WGStridedFragLayout) or - isinstance(fragmented_array.layout, fa.WGSplatFragLayout)): - fragmented_array.store_untiled(vector_store_op.base) + has_transforms = swizzle != mgpu.SwizzlingMode.kNoSwizzle or transforms + if has_transforms: + _check_transforms_and_swizzle_are_supported(ref_type, transforms, swizzle) + unwrapped_ref = unwrap_transformed_memref(ref, transforms_attr) + + def store_tiled(optimized: bool): + fragmented_array.store_tiled(unwrapped_ref, swizzle, optimized) + + _retry_on_failure(store_tiled, optimized) + else: + + def store_untiled(optimized: bool): + fragmented_array.store_untiled(ref, optimized=optimized) + + _retry_on_failure(store_untiled, optimized) else: - raise ValueError( - f"{vector_store_op} has an unsupported layout: {to_store_layout}" - ) + raise ValueError(f"Unsupported memory space: {ref_type.memory_space}") + if ctx.auto_barriers: + mgpu_utils.warpgroup_barrier() # Make sure the writes have completed. + + return [] + + +@_register_lowering(mgpu.DebugPrintOp) +def _debug_print_op_lowering_rule( + ctx: LoweringContext, op: mgpu.DebugPrintOp +) -> Sequence[ir.Value]: + del ctx + [layout] = inference_utils.in_layouts(op) + a = _fragmented_array_from_ir(op.value, layout) + a.debug_print(op.format.value) return [] -@_register_lowering(vector.SplatOp) -def _vector_splat_op_lowering_rule( - _: LoweringContext, vector_splat_op: vector.SplatOp + +def pprint_layout(v: fa.FragmentedArray | tcgen05.TMEMRef) -> str: + if isinstance(v, fa.FragmentedArray): + match v.layout: + case fa.WGMMA_LAYOUT: + return "WGMMA" + case fa.WGMMA_ROW_LAYOUT: + return "WGMMA_ROW" + case fa.WGMMA_TRANSPOSED_LAYOUT: + return "WGMMA_TRANSPOSED" + case fa.TCGEN05_LAYOUT: + return "TCGEN05" + case fa.TCGEN05_TRANSPOSED_LAYOUT: + return "TCGEN05_TRANSPOSED" + case fa.TMEM_NATIVE_LAYOUT: + return "TCGEN05_TMEM_NATIVE" + case _: + return str(v.layout) + else: + assert isinstance(v, tcgen05.TMEMRef), v + if v.layout == tcgen05.tmem_default_layout(packing=v.packing): + return f"TMEM_DEFAULT(packing={v.packing})" + return str(v.layout) + + +@_register_lowering(mgpu.PrintLayoutOp) +def _print_layout_op_lowering_rule( + ctx: LoweringContext, op: mgpu.PrintLayoutOp ) -> Sequence[ir.Value]: + del ctx + if isinstance(op.value.type, ir.VectorType): + (layout,) = inference_utils.in_layouts(op) + a = _fragmented_array_from_ir(op.value, layout) + print(op.format.value.format(pprint_layout(a))) + else: + (layout,) = inference_utils.in_tmem_layouts(op) + ref = _tmem_ref_from_ir(op.value, layout) + print(op.format.value.format(pprint_layout(ref))) + return [] - out_vec_ty = ir.VectorType(vector_splat_op.aggregate.type) - is_signed = ( - False if ir.IntegerType.isinstance(out_vec_ty.element_type) else None + +@_register_lowering(mgpu.BroadcastedIotaOp) +def _broadcasted_iota_op_lowering_rule( + ctx: LoweringContext, op: mgpu.BroadcastedIotaOp +) -> Sequence[ir.Value]: + del ctx + [layout] = inference_utils.out_layouts(op) + result_type = ir.VectorType(op.result.type) + a = fa.FragmentedArray.broadcasted_iota( + result_type.element_type, + tuple(result_type.shape), + op.dimension.value, + layouts.from_layout_attr(layout), + is_signed=_default_is_signed(result_type.element_type), ) + return [fragmented_array_to_ir(a, result_type)] + + +@_register_lowering(vector.BroadcastOp) +def _vector_broadcast_op_lowering_rule( + _: LoweringContext, op: vector.BroadcastOp +) -> Sequence[ir.Value]: + out_vec_ty = ir.VectorType(op.vector.type) fragmented_array = fa.FragmentedArray.splat( - vector_splat_op.input, + op.source, tuple(out_vec_ty.shape), - layouts.from_layout_attr(vector_splat_op.attributes["out_layouts"][0]), - is_signed=is_signed, + layouts.from_layout_attr( + op.attributes["out_layouts"][0] + ), + is_signed=_default_is_signed(out_vec_ty.element_type), ) - return [_fragmented_array_to_ir(fragmented_array, out_vec_ty)] + return [fragmented_array_to_ir(fragmented_array, out_vec_ty)] @_register_lowering(vector.ShapeCastOp) @@ -399,11 +634,67 @@ def _vector_shape_cast_op_lowering_rule( [layout] = inference_utils.in_layouts(op) out_vec_ty = ir.VectorType(op.result.type) assert out_vec_ty.has_static_shape - is_signed = ( - False if ir.IntegerType.isinstance(out_vec_ty.element_type) else None + a = _fragmented_array_from_ir(op.source, layout) + return [ + fragmented_array_to_ir(a.reshape(tuple(out_vec_ty.shape)), out_vec_ty) + ] + + +@_register_lowering(vector.ExtractStridedSliceOp) +def _vector_extract_strided_slice_op_lowering_rule( + ctx: LoweringContext, op: vector.ExtractStridedSliceOp +) -> Sequence[ir.Value]: + del ctx + if any(ir.IntegerAttr(s).value != 1 for s in op.strides): + raise NotImplementedError("`strides` must contain only 1s.") + [in_layout] = inference_utils.in_layouts(op) + [out_layout] = inference_utils.out_layouts(op) + assert in_layout == out_layout + out_vec_ty = ir.VectorType(op.result.type) + assert out_vec_ty.has_static_shape + a = _fragmented_array_from_ir(op.source, in_layout) + indices = tuple( + utils.DynamicSlice( + ir.IntegerAttr(offset).value, ir.IntegerAttr(length).value + ) + for offset, length in zip(op.offsets, op.sizes, strict=True) ) - a = _fragmented_array_from_ir(op.source, layout, is_signed) - return [_fragmented_array_to_ir(a.reshape(out_vec_ty.shape), out_vec_ty)] + result = a[indices] + assert result.layout == layouts.from_layout_attr(out_layout) + return [fragmented_array_to_ir(result, out_vec_ty)] + + +@_register_lowering(vector.ExtractOp) +def _vector_extract_op_lowering_rule( + ctx: LoweringContext, op: vector.ExtractOp +) -> Sequence[ir.Value]: + del ctx + if op.dynamic_position: + raise NotImplementedError("Only slicing with static indices allowed.") + + [in_layout] = inference_utils.in_layouts(op) + a = _fragmented_array_from_ir(op.source, in_layout) + + if not isinstance(op.result.type, ir.VectorType): # scalar result + result = a[tuple(op.static_position)] + assert isinstance(result.layout, fa.WGSplatFragLayout) + return [result.registers.item()] + + [out_layout] = inference_utils.out_layouts(op) + assert in_layout == out_layout + a = _fragmented_array_from_ir(op.source, in_layout) + result_type = ir.VectorType(op.result.type) + slices = tuple(slice(i, i + 1) for i in op.static_position) + # TODO(allanrenucci): Add direct support for indexing to FragmentedArray. + result = a[slices].reshape(tuple(result_type.shape)) + assert result.layout == layouts.from_layout_attr(out_layout) + return [fragmented_array_to_ir(result, result_type)] + + +def _combining_kind(attr: ir.Attribute) -> vector.CombiningKind: + return vector.CombiningKind[ + str(attr).removeprefix("#vector.kind<").removesuffix(">").upper() + ] @_register_lowering(vector.ReductionOp) @@ -412,26 +703,117 @@ def _vector_reduction_op_lowering_rule( ) -> Sequence[ir.Value]: del ctx # Unused. [layout] = inference_utils.in_layouts(op) - () = inference_utils.out_layouts(op) - element_type = ir.VectorType(op.vector.type).element_type - is_signed = False if ir.IntegerType.isinstance(element_type) else None - a = _fragmented_array_from_ir(op.vector, layout, is_signed) - match str(op.kind): - case "#vector.kind": - smem = ir.Attribute.parse("#gpu.address_space") - scratch = _slice_smem( - ir.MemRefType.get([4], element_type, memory_space=smem), - arith.constant(None, op.attributes["offset"]), - ) - result = a.reduce_sum(scratch) - case ( - "#vector.kind" | "#vector.kind" | "#vector.kind" - ): - # TODO(slebedev): Implement this and remove the raise below. + element_type = op.vector.type.element_type + scratch = _slice_smem( + ir.MemRefType.get([4], element_type, memory_space=utils.smem()), + arith.constant(None, op.attributes["offset"]), + ) + axes = range(op.vector.type.rank) + op_kind = _combining_kind(op.kind) + match op_kind: + case vector.CombiningKind.ADD: + a = _fragmented_array_from_ir(op.vector, layout) + result = a.reduce("add", axes, scratch) + case vector.CombiningKind.MAXSI | vector.CombiningKind.MAXUI: + is_signed = op_kind == vector.CombiningKind.MAXSI + a = _fragmented_array_from_ir(op.vector, layout, is_signed) + result = a.reduce("max", axes, scratch) + case vector.CombiningKind.MAXIMUMF: + a = _fragmented_array_from_ir(op.vector, layout) + result = a.reduce("max", axes, scratch) + case vector.CombiningKind.MINUI | vector.CombiningKind.MINSI: + is_signed = op_kind == vector.CombiningKind.MINSI + a = _fragmented_array_from_ir(op.vector, layout, is_signed) + result = a.reduce("min", axes, scratch) + case vector.CombiningKind.MINIMUMF: + a = _fragmented_array_from_ir(op.vector, layout) + result = a.reduce("min", axes, scratch) + case _: raise NotImplementedError(f"Unsupported reduction kind: {op.kind}") + assert isinstance(result.layout, fa.WGSplatFragLayout) + return [result.registers.item()] + +@_register_lowering(vector.MultiDimReductionOp) +def _vector_multi_dim_reduction_op_lowering_rule( + ctx: LoweringContext, op: vector.MultiDimReductionOp +) -> Sequence[ir.Value]: + del ctx + + [in_layout, acc_layout] = inference_utils.in_layouts(op) + [out_layout] = inference_utils.out_layouts(op) + if out_layout != acc_layout: + raise ValueError( + f"Output layout {out_layout} must match the accumulator layout" + f" {acc_layout}" + ) + + if len(op.reduction_dims) != 1: + raise NotImplementedError("Only 1 reduction dimension is supported.") + + op_kind = _combining_kind(op.kind) + match op_kind: + case vector.CombiningKind.ADD: + src = _fragmented_array_from_ir(op.source, in_layout) + acc = _fragmented_array_from_ir(op.acc, acc_layout) + result = src.reduce("add", op.reduction_dims[0]) + result += acc + case vector.CombiningKind.MAXSI | vector.CombiningKind.MAXUI: + is_signed = op_kind == vector.CombiningKind.MAXSI + src = _fragmented_array_from_ir(op.source, in_layout, is_signed) + acc = _fragmented_array_from_ir(op.acc, acc_layout, is_signed) + result = src.reduce("max", op.reduction_dims[0]) + result = result.max(acc) + case vector.CombiningKind.MAXIMUMF: + src = _fragmented_array_from_ir(op.source, in_layout) + acc = _fragmented_array_from_ir(op.acc, acc_layout) + result = src.reduce("max", op.reduction_dims[0]) + result = result.max(acc) + case vector.CombiningKind.MINUI | vector.CombiningKind.MINSI: + is_signed = op_kind == vector.CombiningKind.MINSI + src = _fragmented_array_from_ir(op.source, in_layout, is_signed) + acc = _fragmented_array_from_ir(op.acc, acc_layout, is_signed) + result = src.reduce("min", op.reduction_dims[0]) + result = result.min(acc) + case vector.CombiningKind.MINIMUMF: + src = _fragmented_array_from_ir(op.source, in_layout) + acc = _fragmented_array_from_ir(op.acc, acc_layout) + result = src.reduce("min", op.reduction_dims[0]) + result = result.min(acc) case _: raise NotImplementedError(f"Unsupported reduction kind: {op.kind}") - return [_fragmented_array_to_ir(result, op.result.type)] + assert result.layout == layouts.from_layout_attr(out_layout) # pytype: disable=attribute-error + return [fragmented_array_to_ir(result, op.result.type)] + + +@_register_lowering(mgpu.LayoutCastOp) +def _mgpu_layout_cast_op_lowering_rule( + _: LoweringContext, op: mgpu.LayoutCastOp +) -> Sequence[ir.Value]: + [in_layout] = inference_utils.in_layouts(op) + [out_layout] = inference_utils.out_layouts(op) + in_array = _fragmented_array_from_ir(op.x, in_layout) + out_array = in_array.to_layout(layouts.from_layout_attr(out_layout)) + return [fragmented_array_to_ir(out_array, op.result.type)] + + +@_register_lowering(mgpu.BroadcastInDimOp) +def _mgpu_broadcast_in_dim_op_lowering_rule( + _: LoweringContext, op: mgpu.BroadcastInDimOp +) -> Sequence[ir.Value]: + in_ty = ir.VectorType(op.operand.type) + out_ty = ir.VectorType(op.result.type) + if len(in_ty.shape) != 1 or len(out_ty.shape) != 2: + raise NotImplementedError( + "Broadcast in dim with non-trivial broadcast dimensions is not" + f" supported: {op}" + ) + + broadcast_dims = tuple(op.broadcast_dimensions) + in_layout_attr = inference_utils.in_layouts(op)[0] + operand_fa = _fragmented_array_from_ir(op.operand, in_layout_attr) + out_layout = layouts.from_layout_attr(inference_utils.out_layouts(op)[0]) + out = operand_fa.broadcast_in_dim(out_ty.shape, broadcast_dims, out_layout) + return [fragmented_array_to_ir(out, out_ty)] def swizzle_and_transforms_from_transforms_attr( @@ -475,32 +857,109 @@ def swizzle_and_transforms_from_transforms_attr( return swizzle or mgpu.SwizzlingMode.kNoSwizzle, tuple(gmem_transforms) -def transform_memref( - mem_ref: ir.Value, transforms: tuple[launch_context.MemRefTransform, ...] -) -> ir.Value: - """Reinterprets the memref to one where the shape is transformed as given.""" +def transformed_smem_ref_type( + ref_ty: ir.MemRefType, + transforms: tuple[launch_context.MemRefTransform, ...], +) -> ir.MemRefType: + """Returns the transformed ref type for the given logical ref and transforms. + """ if not transforms: - return mem_ref + return ref_ty - mem_ref_type = ir.MemRefType(mem_ref.type) - if mem_ref_type.memory_space != ir.Attribute.parse( - "#gpu.address_space" - ): - raise ValueError(f"Only workgroup memory is supported but got {mem_ref}.") + if not utils.is_smem_ref(ref_ty): + raise ValueError(f"Only workgroup memory is supported but got {ref_ty}.") + + shape = ref_ty.shape + strides, offset = ref_ty.get_strides_and_offset() + transposed = utils.is_memref_transposed(ref_ty) + if transposed: + if len(shape) != 2: + raise NotImplementedError( + f"Only 2D shapes can be transposed, but got {shape}" + ) + if strides[0] != 1 or strides[1] != shape[0]: + raise NotImplementedError( + f"Only contiguous 2D memrefs can be transposed, but got {ref_ty}" + ) - shape = mem_ref_type.shape for t in transforms: - shape = t.transform_shape(shape) + shape = list(t.transform_shape(shape)) + + minor_to_major_stride_order: tuple[int, ...] + if transposed: + # The expected output is a transposed ref and `shape` is already transposed. + # We need to compute the correct strides to match the shape. + if len(shape) == 2: + minor_to_major_stride_order = (1, 0) + elif len(shape) == 4: + minor_to_major_stride_order = (2, 3, 0, 1) + else: + raise NotImplementedError( + f"Expected a 2D or 4D shape after transforms, but got {shape}" + ) + else: + minor_to_major_stride_order = tuple(reversed(range(len(shape)))) - memref_new_type = ir.MemRefType.get( + new_strides = [1] * len(shape) + for i in range(1, len(shape)): + dim = minor_to_major_stride_order[i] + prev_dim = minor_to_major_stride_order[i-1] + new_strides[dim] = new_strides[prev_dim] * shape[prev_dim] + + new_ref_ty = ir.MemRefType.get( shape, - mem_ref_type.element_type, - memory_space=mem_ref_type.memory_space, + ref_ty.element_type, + memory_space=ref_ty.memory_space, + layout=ir.StridedLayoutAttr.get(offset, new_strides), ) + return new_ref_ty + + +def reinterpret_smem_ref( + ref: ir.Value, + transforms: tuple[launch_context.MemRefTransform, ...], +) -> ir.Value: + """Applies transforms on the ref, and makes sure that their effect is + propagated appropriately on the strides. + This function is used any time we lower from a dialect SMEM ref (2D for wgmma) + with given transforms to a "physical" SMEM ref (4D for wgmma) that is fully + transformed and transposed as needed. + """ + ref_ty = ir.MemRefType(ref.type) + new_ref_ty = transformed_smem_ref_type(ref_ty, transforms) + if ref_ty == new_ref_ty: + return ref ms = utils.WORKGROUP_NVPTX_ADDRESS_SPACE - ptr = utils.memref_ptr(mem_ref, memory_space=ms) - return utils.ptr_as_memref(ptr, memref_new_type, ptr_memory_space=ms) + ptr = utils.memref_ptr(ref, memory_space=ms) + new_ref = utils.ptr_as_memref(ptr, new_ref_ty, ptr_memory_space=ms) + return new_ref + + +def _gmem_slice_and_predicate( + ctx: LoweringContext, + op: mgpu.AsyncLoadOp | mgpu.AsyncPrefetchOp | mgpu.AsyncStoreOp, +) -> tuple[ + tuple[ir.Value | fa.FragmentedArray | utils.DynamicSlice, ...], + dict[str, ir.Value], +]: + """Returns the GMEM slice and predicate for the given async op.""" + gmem_slice = [] + predicate = dict(predicate=ctx.single_thread_per_warpgroup_predicate) + for idx, size in zip(op.indices, op.slice_lengths, strict=True): + if isinstance(idx.type, ir.IntegerType): + idx_int = arith.index_cast(ir.IndexType.get(), idx) + v = idx_int if size < 0 else utils.DynamicSlice(idx_int, size) + gmem_slice.append(v) + elif isinstance(idx.type, ir.VectorType): + layout = inference_utils.in_layouts(op)[0] + assert layouts.from_layout_attr(layout) == fa.TMA_GATHER_INDICES_LAYOUT + idx_fa = _fragmented_array_from_ir(idx, layout) + gmem_slice.append(idx_fa) + predicate = dict() + else: + raise TypeError(f"Unsupported index type: {idx.type}") + return tuple(gmem_slice), predicate @_register_lowering(mgpu.AsyncLoadOp) @@ -508,34 +967,77 @@ def _mgpu_async_load_op_lowering_rule( ctx: LoweringContext, load_op: mgpu.AsyncLoadOp ) -> Sequence[ir.Value]: assert ctx.launch_context is not None - barrier = utils.BarrierRef.from_dialect_barrier_memref(load_op.barrier) + barrier = utils.DialectBarrierRef.from_barrier_memref(load_op.barrier) - if inference_utils.has_in_transforms_set(load_op): - [transforms] = inference_utils.in_transforms(load_op) - swizzle, transforms = swizzle_and_transforms_from_transforms_attr( - transforms + [transforms_attr] = inference_utils.in_transforms(load_op) + swizzle, transforms = swizzle_and_transforms_from_transforms_attr( + transforms_attr + ) + + unwrapped_dst = unwrap_transformed_memref( + load_op.destination, transforms_attr + ) + + if utils.is_memref_transposed(unwrapped_dst.type): + strides, _ = ir.MemRefType(unwrapped_dst.type).get_strides_and_offset() + permutation = tuple( + sorted(range(len(strides)), key=lambda i: strides[i], reverse=True) ) - else: - swizzle = mgpu.SwizzlingMode.kNoSwizzle - transforms = () + # We undo the tranpose and apply it as a transform. + unwrapped_dst = utils.memref_transpose( + unwrapped_dst, permutation + ) + if transforms: + raise NotImplementedError("Can't transpose transformed refs.") + transforms = (launch_context.TransposeTransform(permutation),) - gmem_slice = [] - for idx_i32, size in zip(load_op.indices, load_op.slice_lengths): - idx = arith.index_cast(ir.IndexType.get(), idx_i32) - v = idx if size < 0 else utils.DynamicSlice(idx, size) - gmem_slice.append(v) + gmem_slice, predicate = _gmem_slice_and_predicate(ctx, load_op) + + collective = [ + gpu.Dimension(ir.IntegerAttr(axis).value) + for axis in load_op.collective or [] + ] + + # TODO(dasenov): async_copy requires all GMEM strides except the last one + # to be a multiple of 16 bytes. This restriction could be loosned with + # strided layouts when they are contiguous in GMEM. In that case, we could do: + # flatten -> async_copy -> unflatted here, as long as flattened size is a + # multiple of 16. # TODO(dasenov): Add support for the remaining op properties. + if ctx.auto_barriers: + mgpu_utils.warpgroup_barrier() # Make sure the writes have completed. ctx.launch_context.async_copy( src_ref=load_op.source, - dst_ref=transform_memref(load_op.destination, transforms), - gmem_slice=tuple(gmem_slice), - barrier=barrier, + dst_ref=unwrapped_dst, + gmem_slice=gmem_slice, + barrier=barrier.barrier_ref, + collective=collective, arrive=False, - uniform=True, swizzle=swizzle, gmem_transform=transforms, - predicate=ctx.single_thread_per_warpgroup_predicate, + **predicate, + ) + return [] + + +@_register_lowering(mgpu.AsyncPrefetchOp) +def _mgpu_async_prefetch_op_lowering_rule( + ctx: LoweringContext, load_op: mgpu.AsyncPrefetchOp +) -> Sequence[ir.Value]: + assert ctx.launch_context is not None + + gmem_slice, predicate = _gmem_slice_and_predicate(ctx, load_op) + + if load_op.collective: + raise NotImplementedError("Collective prefetches are not supported yet.") + + ctx.launch_context.async_prefetch( + gmem_ref=load_op.source, + gmem_slice=gmem_slice, + swizzle=None, + gmem_transform=(), + **predicate, ) return [] @@ -546,35 +1048,81 @@ def _mgpu_async_store_op_lowering_rule( ) -> Sequence[ir.Value]: assert ctx.launch_context is not None - if inference_utils.has_in_transforms_set(store_op): - [transforms] = inference_utils.in_transforms(store_op) - swizzle, transforms = swizzle_and_transforms_from_transforms_attr( - transforms + [transforms_attr] = inference_utils.in_transforms(store_op) + swizzle, transforms = swizzle_and_transforms_from_transforms_attr( + transforms_attr + ) + unwrapped_source = unwrap_transformed_memref(store_op.source, transforms_attr) + if utils.is_memref_transposed(unwrapped_source.type): + strides, _ = ir.MemRefType(unwrapped_source.type).get_strides_and_offset() + permutation = tuple( + sorted(range(len(strides)), key=lambda i: strides[i], reverse=True) ) - else: - swizzle = mgpu.SwizzlingMode.kNoSwizzle - transforms = () + # We undo the tranpose and apply it as a transform. + unwrapped_source = utils.memref_transpose( + unwrapped_source, permutation + ) + if transforms: + raise NotImplementedError("Can't transpose transformed refs.") + transforms = (launch_context.TransposeTransform(permutation),) - gmem_slice = [] - for idx_i32, size in zip(store_op.indices, store_op.slice_lengths): - idx = arith.index_cast(ir.IndexType.get(), idx_i32) - v = idx if size < 0 else utils.DynamicSlice(idx, size) - gmem_slice.append(v) + gmem_slice, predicate = _gmem_slice_and_predicate(ctx, store_op) + + # TODO(dasenov): async_copy requires all GMEM strides except the last one + # to be a multiple of 16 bytes. This restriction could be loosned with + # strided layouts when they are contiguous in GMEM. In that case, we could do: + # flatten -> async_copy -> unflatted here, as long as flattened size is a + # multiple of 16. + + # TODO(b/415721295):Simplify, after the minimal jaxlib version is 0.8.2. + if hasattr(mgpu, "TMAReduction") and store_op.reduction_op is not None: + reduction_op = mgpu.TMAReduction(store_op.reduction_op.value).name.lower() + else: + reduction_op = None # TODO(dasenov): Add support for the remaining op properties. ctx.launch_context.async_copy( - src_ref=transform_memref(store_op.source, transforms), + src_ref=unwrapped_source, dst_ref=store_op.destination, - gmem_slice=tuple(gmem_slice), + gmem_slice=gmem_slice, swizzle=swizzle, gmem_transform=transforms, - uniform=True, - predicate=ctx.single_thread_per_warpgroup_predicate, + **predicate, arrive=store_op.commit_group, + reduction_op=reduction_op, ) return [] +@_register_lowering(mgpu.TmemLayoutCastOp) +def _tmem_layout_cast_lowering_rule( + ctx: LoweringContext, + op: mgpu.TmemLayoutCastOp, +) -> Sequence[ir.Value]: + del ctx + in_layout = inference_utils.in_tmem_layouts(op)[0] + tmem_ref = _tmem_ref_from_ir(op.ref, in_layout) + # We can't relayout TMEM. + assert layouts.to_layout_attr(tmem_ref.layout) == op.new_layout + return [op.ref] + + +@_register_lowering(mgpu.SliceTmemOp) +def _slice_tmem_lowering_rule( + ctx: LoweringContext, op: mgpu.SliceTmemOp +) -> Sequence[ir.Value]: + del ctx + in_layout_attr = inference_utils.in_tmem_layouts(op)[0] + out_layout_attr = inference_utils.out_tmem_layouts(op)[0] + source = _tmem_ref_from_ir(op.source, in_layout_attr) + i32 = ir.IntegerType.get_signless(32) + offset = arith.constant(i32, op.offset) + dest_addr = arith.addi(source.address, offset) + cast = builtin.UnrealizedConversionCastOp([op.result.type], [dest_addr]) + cast.attributes["layout"] = out_layout_attr + return [cast.result] + + def _conversion_op_lowering_rule( _: LoweringContext, op: ir.OpView, @@ -589,7 +1137,7 @@ def _conversion_op_lowering_rule( target_ty = op.result.type.element_type # pytype: disable=attribute-error operand = _fragmented_array_from_ir(op.operands[0], layout, source_is_signed) converted = operand.astype(target_ty, is_signed=target_is_signed) - return [_fragmented_array_to_ir(converted, op.result.type)] + return [fragmented_array_to_ir(converted, op.result.type)] for op, source_is_signed, target_is_signed in [ @@ -613,31 +1161,41 @@ def _conversion_op_lowering_rule( def _unary_op_lowering_rule( _: LoweringContext, op: Any, - impl: Callable[[fa.FragmentedArray], fa.FragmentedArray], + impl: Callable[..., fa.FragmentedArray], is_signed: bool | None = None, ) -> Sequence[ir.Value]: in_layouts = inference_utils.in_layouts(op) [layout] = inference_utils.out_layouts(op) if any(in_layout != layout for in_layout in in_layouts): raise ValueError("Layout mismatch") - kwargs = {} - if hasattr(op, "fastmath"): - kwargs = dict( - approx=op.fastmath == ir.Attribute.parse("#arith.fastmath") - ) a = _fragmented_array_from_ir(op.operand, layout, is_signed) - return [_fragmented_array_to_ir(impl(a, **kwargs), op.result.type)] + if hasattr(op, "fastmath"): + if op.fastmath == ir.Attribute.parse("#arith.fastmath"): + result_fa = impl(a, approx=True) + else: + result_fa = impl(a) + else: + result_fa = impl(a) + return [fragmented_array_to_ir(result_fa, op.result.type)] -for op, impl, is_signed in [ + +for op, unary_impl, is_signed in [ (mlir_math.RsqrtOp, fa.FragmentedArray.rsqrt, None), (mlir_math.ExpOp, fa.FragmentedArray.exp, None), (mlir_math.Exp2Op, fa.FragmentedArray.exp2, None), + (mlir_math.SinOp, fa.FragmentedArray.sin, None), + (mlir_math.CosOp, fa.FragmentedArray.cos, None), (mlir_math.LogOp, fa.FragmentedArray.log, None), (mlir_math.TanhOp, fa.FragmentedArray.tanh, None), + (mlir_math.AbsFOp, fa.FragmentedArray.abs, None), + (mlir_math.AbsIOp, fa.FragmentedArray.abs, True), + (mlir_math.RoundOp, fa.FragmentedArray.round, None), + (mlir_math.RoundEvenOp, fa.FragmentedArray.round_even, None), + (mlir_math.ErfOp, fa.FragmentedArray.erf, None), ]: _lowerings[op.OPERATION_NAME] = functools.partial( - _unary_op_lowering_rule, impl=impl, is_signed=is_signed + _unary_op_lowering_rule, impl=unary_impl, is_signed=is_signed ) @@ -655,10 +1213,10 @@ def _binary_op_lowering_rule( raise ValueError("Layout mismatch") lhs = _fragmented_array_from_ir(op.lhs, layout, is_signed) rhs = _fragmented_array_from_ir(op.rhs, layout, is_signed) - return [_fragmented_array_to_ir(impl(lhs, rhs), op.result.type)] + return [fragmented_array_to_ir(impl(lhs, rhs), op.result.type)] -for op, impl, is_signed in [ +for op, binary_impl, is_signed in [ (arith.AddIOp, operator.add, False), (arith.AddFOp, operator.add, None), (arith.SubIOp, operator.sub, False), @@ -680,9 +1238,11 @@ def _binary_op_lowering_rule( (arith.MinSIOp, fa.FragmentedArray.min, True), (arith.MinUIOp, fa.FragmentedArray.min, False), (arith.MinimumFOp, fa.FragmentedArray.min, None), + (mlir_math.Atan2Op, fa.FragmentedArray.atan2, None), + (mlir_math.CopySignOp, fa.FragmentedArray.copysign, None), ]: _lowerings[op.OPERATION_NAME] = functools.partial( - _binary_op_lowering_rule, impl=impl, is_signed=is_signed + _binary_op_lowering_rule, impl=binary_impl, is_signed=is_signed ) @@ -708,10 +1268,10 @@ def _cmpi_op_lowering_rule( [layout] = inference_utils.out_layouts(op) if any(in_layout != layout for in_layout in in_layouts): raise ValueError("Layout mismatch") - impl, is_signed = CMPI_IMPLS[op.predicate.value] + impl, is_signed = CMPI_IMPLS[op.predicate.value] # pytype: disable=attribute-error lhs = _fragmented_array_from_ir(op.lhs, layout, is_signed) rhs = _fragmented_array_from_ir(op.rhs, layout, is_signed) - return [_fragmented_array_to_ir(impl(lhs, rhs), op.result.type)] + return [fragmented_array_to_ir(impl(lhs, rhs), op.result.type)] CMPF_IMPLS = { @@ -732,10 +1292,10 @@ def _cmpf_op_lowering_rule( [layout] = inference_utils.out_layouts(op) if any(in_layout != layout for in_layout in in_layouts): raise ValueError("Layout mismatch") - impl = CMPF_IMPLS[op.predicate.value] + impl = CMPF_IMPLS[op.predicate.value] # pytype: disable=attribute-error lhs = _fragmented_array_from_ir(op.lhs, layout) rhs = _fragmented_array_from_ir(op.rhs, layout) - return [_fragmented_array_to_ir(impl(lhs, rhs), op.result.type)] + return [fragmented_array_to_ir(impl(lhs, rhs), op.result.type)] @_register_lowering(arith.BitcastOp) @@ -750,91 +1310,146 @@ def _bitcast_op_lowering_rule( out_element_type = ir.VectorType(op.result.type).element_type out = in_.bitcast( out_element_type, - output_is_signed=False - if ir.IntegerType.isinstance(out_element_type) - else None, + output_is_signed=_default_is_signed(out_element_type), ) - return [_fragmented_array_to_ir(out, op.result.type)] + return [fragmented_array_to_ir(out, op.result.type)] + + +@_register_lowering(arith.SelectOp) +def _select_op_lowering_rule( + ctx: LoweringContext, op: arith.SelectOp +) -> Sequence[ir.Value]: + del ctx + in_layouts = inference_utils.in_layouts(op) + [layout] = inference_utils.out_layouts(op) + if any(in_layout != layout for in_layout in in_layouts): + raise ValueError("Layout mismatch") + pred = _fragmented_array_from_ir(op.condition, layout) + true_value = _fragmented_array_from_ir(op.true_value, layout) + false_value = _fragmented_array_from_ir(op.false_value, layout) + result = pred.select(true_value, false_value) + return [fragmented_array_to_ir(result, op.result.type)] @_register_lowering(mgpu.WGMMAOp) def _mgpu_wgmma_op_lowering_rule( _: LoweringContext, wgmma_op: mgpu.WGMMAOp ) -> Sequence[ir.Value]: - if wgmma_op.transpose_a or wgmma_op.transpose_b: - raise ValueError("Transpose arguments are to be deleted.") - - fa_layouts = ( - *inference_utils.in_layouts(wgmma_op), - *inference_utils.out_layouts(wgmma_op), - ) - is_supported_layout = ( - lambda l: layouts.from_tiled_layout_attr(l) == fa.WGMMA_LAYOUT - ) - if not all(map(is_supported_layout, fa_layouts)): - raise ValueError("Layout mismatch") - wgmma_layout = fa_layouts[0] - - # TODO(dasenov): Move the value -> accumulator conversion outisde of wgmma. + in_layouts = inference_utils.in_layouts(wgmma_op) + assert in_layouts[0] == layouts.to_layout_attr(fa.WGMMA_LAYOUT) + [out_layout] = inference_utils.out_layouts(wgmma_op) + assert out_layout == layouts.to_layout_attr(fa.WGMMA_LAYOUT) + + # s8/i8 WGMMA expects signed integer accumulator. + element_type = wgmma_op.a.type.element_type + is_signed = True if isinstance(element_type, ir.IntegerType) else None + # TODO(dasenov): Move the value -> accumulator conversion outside of wgmma. # The associated fence could be a little expensive and is not needed if the # result a wgmma feeds into another wgmma (even in another loop step). - acc_in = _fragmented_array_from_ir(wgmma_op.accumulator, wgmma_layout) - regs = acc_in.to_layout(fa.WGMMA_LAYOUT) + regs = _fragmented_array_from_ir( + wgmma_op.accumulator, in_layouts[0], is_signed + ) acc = wgmma.WGMMAAccumulator.from_registers(regs) - if ir.VectorType.isinstance(wgmma_op.a.type): + if isinstance(wgmma_op.a.type, ir.VectorType): a_transforms = None b_transforms = inference_utils.in_transforms(wgmma_op)[0] + unwrapped_a_ref = None + unwrapped_b_ref = unwrap_transformed_memref(wgmma_op.b, b_transforms) else: a_transforms, b_transforms = inference_utils.in_transforms(wgmma_op) + unwrapped_a_ref = unwrap_transformed_memref(wgmma_op.a, a_transforms) + unwrapped_b_ref = unwrap_transformed_memref(wgmma_op.b, b_transforms) b_swizzle, b_transforms = swizzle_and_transforms_from_transforms_attr( b_transforms ) minimum_swizzle = mgpu.SwizzlingMode.k32ByteSwizzle - ref_ty = ir.MemRefType(wgmma_op.b.type) _check_transforms_and_swizzle_are_supported( - ref_ty, b_transforms, b_swizzle, minimum_swizzle + ir.MemRefType(wgmma_op.b.type), b_transforms, b_swizzle, minimum_swizzle ) - b_operand = transform_memref(wgmma_op.b, b_transforms) - if ir.VectorType.isinstance(wgmma_op.a.type): - a_operand = _fragmented_array_from_ir(wgmma_op.a, wgmma_layout) + if isinstance(wgmma_op.a.type, ir.VectorType): + expected_a_layout = ( + fa.WGMMA_LAYOUT_8BIT + if utils.bitwidth(element_type) == 8 + else fa.WGMMA_LAYOUT + ) + assert in_layouts[1] == layouts.to_layout_attr(expected_a_layout) + a_operand = _fragmented_array_from_ir(wgmma_op.a, in_layouts[1], is_signed) else: a_swizzle, a_transforms = swizzle_and_transforms_from_transforms_attr( a_transforms ) - ref_ty = ir.MemRefType(wgmma_op.a.type) _check_transforms_and_swizzle_are_supported( - ref_ty, a_transforms, a_swizzle, minimum_swizzle + ir.MemRefType(wgmma_op.a.type), a_transforms, a_swizzle, minimum_swizzle ) if a_swizzle != b_swizzle: raise ValueError( f"Non-matching swizzles of operands a and b in WGMMA: {a_swizzle} !=" f" {b_swizzle}" ) - a_operand = transform_memref(wgmma_op.a, a_transforms) - - new_acc = wgmma.wgmma(acc, a_operand, b_operand, swizzle=b_swizzle) + assert unwrapped_a_ref is not None + a_operand = unwrapped_a_ref + new_acc = wgmma.wgmma(acc, a_operand, unwrapped_b_ref, swizzle=b_swizzle) return [ - _fragmented_array_to_ir( + fragmented_array_to_ir( new_acc.value.to_layout(fa.WGMMA_LAYOUT), wgmma_op.accumulator.type, ) ] -@_register_lowering(mgpu.ArriveExpectTxOp) -def _mgpu_arrive_expect_tx_op_lowering_rule( - ctx: LoweringContext, arrive_expect_tx_op: mgpu.ArriveExpectTxOp +@_register_lowering(mgpu.ArriveOp) +def _mgpu_arrive_op_lowering_rule( + ctx: LoweringContext, arrive_op: mgpu.ArriveOp ) -> Sequence[ir.Value]: + barrier = utils.DialectBarrierRef.from_barrier_memref(arrive_op.barrier) + orders_tc = arrive_op.orders_tensor_core.value + if orders_tc: + # Only one thread arrives, so make sure it ups the arrival count for the + # whole warpgroup. + # + # TODO(b/415721295): At the moment we assume that there is a single arrival + # per warpgroup. If we need to support also Warp-level semantics we will + # need to use a warp-level predicate. + predicate = ctx.single_thread_per_warpgroup_predicate + arrival_count = utils.WARPGROUP_SIZE + else: + # Each thread arrives once. + arrival_count = 1 + predicate = None + + barrier.barrier_ref.arrive( + arrival_count=arrival_count, + orders_tensor_core=orders_tc, + predicate=predicate, + ) + return [] + - barrier = utils.BarrierRef.from_dialect_barrier_memref(arrive_expect_tx_op.barrier) - barrier.arrive_expect_tx( - arrive_expect_tx_op.expect_tx.value, - ctx.single_thread_per_warpgroup_predicate, +@_register_lowering(mgpu.ArriveExpectTxOp) +def _mgpu_arrive_expect_tx_op_lowering_rule( + _: LoweringContext, arrive_expect_tx_op: mgpu.ArriveExpectTxOp +) -> Sequence[ir.Value]: + bytes = arrive_expect_tx_op.expect_tx.value + if bytes % utils.WARPGROUP_SIZE: + raise NotImplementedError( + "Only copies of a multiple of 128 bytes are supported" + ) + # We arrive uniformly from each thread in the WG, so we need to divide the + # number of bytes by the number of threads in the WG. + # TODO: dasenov - Relax this. We can just select the WG leader and have it + # arrive with the whole transfer size, while everyone else arrives with 0. + # But we should continue using this scheme as it's likely to be faster. + bytes //= utils.WARPGROUP_SIZE + bytes = utils.c(bytes, ir.IntegerType.get_signless(32)) + + barrier = utils.DialectBarrierRef.from_barrier_memref( + arrive_expect_tx_op.barrier ) + utils.nvvm_mbarrier_arrive_expect_tx(barrier.get_ptr(), bytes) return [] @@ -844,33 +1459,720 @@ def _mgpu_wait_op_lowering_rule( _: LoweringContext, wait_op: mgpu.WaitOp ) -> Sequence[ir.Value]: - barrier = utils.BarrierRef.from_dialect_barrier_memref(wait_op.barrier) + barrier = utils.DialectBarrierRef.from_barrier_memref(wait_op.barrier) barrier.wait_parity(wait_op.parity) return [] -# TODO(bchetioui): remove this once jaxlib minimum version >= 0.5.2. -SliceSMEMOp = getattr(mgpu, "SliceSMEMOp", None) - - -@_register_lowering(SliceSMEMOp) +@_register_lowering(mgpu.SliceSMEMOp) def _mgpu_slice_smem_op_lowering_rule( - ctx: LoweringContext, op: SliceSMEMOp + ctx: LoweringContext, op: mgpu.SliceSMEMOp ) -> Sequence[ir.Value]: del ctx - return [_slice_smem(op.result.type, op.offset)] + sliced_ref = _slice_smem(op.result.type, op.offset) + + memref_ty = ir.MemRefType(sliced_ref.type) + if ( + memref_ty.element_type == ir.Type.parse("!mosaic_gpu.barrier") + ): + # Barrier memrefs are not transformed and must not be wrapped. + assert not inference_utils.has_out_transforms_set(op) + return [sliced_ref] + + out_transforms = inference_utils.out_transforms(op)[0] + _, transforms = swizzle_and_transforms_from_transforms_attr(out_transforms) + transformed_ref = reinterpret_smem_ref(sliced_ref, transforms) + wrapped_ref = wrap_transformed_memref(transformed_ref, op.result.type, out_transforms) + return [wrapped_ref] def _slice_smem(result: ir.Type, offset: ir.Value): i8 = ir.IntegerType.get_signless(8) - smem = ir.Attribute.parse("#gpu.address_space") smem_base = gpu.dynamic_shared_memory( - ir.MemRefType.get((utils.DYNAMIC,), i8, memory_space=smem) + ir.MemRefType.get((utils.DYNAMIC,), i8, memory_space=utils.smem()) ) offset = arith.index_cast(ir.IndexType.get(), offset) - return memref.view(result, smem_base, offset, []) + lowered_result_type = result + if isinstance(result, ir.MemRefType): + memref_ty = ir.MemRefType(result) + if memref_ty.element_type == ir.Type.parse("!mosaic_gpu.barrier"): + lowered_result_type = ir.MemRefType.get( + memref_ty.shape, _lowered_barrier_type(), memory_space=utils.smem() + ) + view = memref.view(lowered_result_type, smem_base, offset, []) + if result == lowered_result_type: + return view + return builtin.unrealized_conversion_cast([result], [view]) + + +@_register_lowering(mgpu.WithTransformsOp) +def _mgpu_with_transforms_op_lowering_rule( + ctx: LoweringContext, op: mgpu.WithTransformsOp +) -> Sequence[ir.Value]: + """Lowering rule for mgpu.WithTransformsOp. + This is a noop that simply returns its input. + """ + del ctx + + [in_transforms] = inference_utils.in_transforms(op) + unwrapped_source_ref = unwrap_transformed_memref(op.ref, in_transforms) + out_transforms = inference_utils.out_transforms(op)[0] + wrapped_ref = wrap_transformed_memref( + unwrapped_source_ref, op.result.type, out_transforms + ) + return [wrapped_ref] + + +def _tile_transform_offsets( + tiling: Sequence[int], + static_offsets: Sequence[int], + dynamic_offsets: Sequence[ir.Value], +) -> tuple[Sequence[int], Sequence[ir.Value]]: + """Computes the static and dynamic offsets after the given tiling is applied. + + Conceptually, this function is analogous to + tile.transform_shape(static_offsets), except that it also handles dynamic offsets. + """ + dynamic_offset_index = 0 + new_static_offsets = [] + new_dynamic_offsets = [] + + # Preserve all offsets in non-tiled dimensions. + for offset in static_offsets[: -len(tiling)]: + new_static_offsets.append(offset) + if offset == ir.ShapedType.get_dynamic_stride_or_offset(): + new_dynamic_offsets.append(dynamic_offsets[dynamic_offset_index]) + dynamic_offset_index += 1 + + # Compute static and dynamic offsets of tiled dimensions. + for tile_size, offset in zip( + tiling, static_offsets[-len(tiling) :], strict=True + ): + if offset == ir.ShapedType.get_dynamic_stride_or_offset(): + # Here we assume that the offset is divisble by the tile size, but we + # don't check it. This has been established at the time the tiling was + # inferred. + dyn_offset = arith.divui( + dynamic_offsets[dynamic_offset_index], + utils.c(tile_size, ir.IndexType.get()), + ) + new_dynamic_offsets.append(dyn_offset) + new_static_offsets.append(ir.ShapedType.get_dynamic_stride_or_offset()) + dynamic_offset_index += 1 + else: + assert offset % tile_size == 0 + new_static_offsets.append(offset // tile_size) + + # Add 0 offsets for the newly created dimension of the tile. + new_static_offsets += [0] * len(tiling) + + return new_static_offsets, new_dynamic_offsets + + +@_register_lowering(memref.SubViewOp) +def _memref_subview_op_lowering_rule( + ctx: LoweringContext, op: memref.SubViewOp +) -> Sequence[ir.Value]: + del ctx + + if any(s != 1 for s in op.static_strides): + raise NotImplementedError("SubViewOp only supports static strides of 1.") + if op.sizes: + raise NotImplementedError("SubViewOp only supports static sizes.") + src_ty = ir.MemRefType(op.source.type) + + if utils.is_memref_transposed(src_ty): + raise NotImplementedError("SubViewOp does not support transposed memrefs.") + + if utils.is_tmem_ref(src_ty): + [in_tmem_layout] = inference_utils.in_tmem_layouts(op) + [out_tmem_layout] = inference_utils.out_tmem_layouts(op) + assert in_tmem_layout == out_tmem_layout + ref = _tmem_ref_from_ir(op.source, in_tmem_layout) + indices = [] + dynamic_offset_index = 0 + for offset, size in zip(op.static_offsets, op.static_sizes, strict=True): + if ir.ShapedType.is_dynamic_size(offset): + offset = op.offsets[dynamic_offset_index] + dynamic_offset_index += 1 + indices.append(utils.DynamicSlice(offset, size)) + return [_tmem_ref_to_ir(ref.slice(*indices))] + + in_transforms = inference_utils.in_transforms(op)[0] + out_transforms = inference_utils.out_transforms(op)[0] + + if in_transforms != out_transforms: + raise NotImplementedError( + "SubViewOp transforms for the input and output refs must be identical." + ) + + unwrapped_source_ref = unwrap_transformed_memref(op.source, in_transforms) + swizzle, transforms = swizzle_and_transforms_from_transforms_attr(out_transforms) + if swizzle != mgpu.SwizzlingMode.kNoSwizzle: + swizzle_elems = swizzle * 8 // utils.bitwidth(src_ty.element_type) + source_strides, _ = src_ty.get_strides_and_offset() + for stride, offset, size in zip( + source_strides, op.static_offsets, op.static_sizes, strict=True + ): + if stride != 1: + continue + # A dimension with stride 1 is a minor dimension and is swizzled. + if size % swizzle_elems != 0: + raise ValueError( + f"Swizzled dimension of {size=} is not a multiple of" + f" {swizzle_elems=}." + ) + # TODO(allanrenucci): Support dynamic offsets that are divisible by + # `swizzle_elems`. E.g. using `utils.is_known_divisible`. + if ir.ShapedType.is_dynamic_size(offset): + raise NotImplementedError( + "Slicing a swizzled dynamic dimension is not supported." + ) + if offset % swizzle_elems != 0: + raise ValueError( + f"subview {offset=} is not a multiple of {swizzle_elems=}." + ) + + match transforms: + case (): + new_subview_op = memref.SubViewOp( + op.result.type, + unwrapped_source_ref, + op.offsets, + None, + None, + static_offsets=op.static_offsets, + static_sizes=op.static_sizes, + static_strides=op.static_strides, + ) + case (tile_transform, ) if isinstance(tile_transform, launch_context.TileTransform): + in_transformed_ty = ir.MemRefType(unwrapped_source_ref.type) + tiling = tile_transform.tiling + if any( + ir.ShapedType.is_dynamic_size(s) + for s in list(op.static_sizes)[-len(tiling) :] + ): + raise NotImplementedError( + "SubViewOp only supports static sizes for the tiled dimensions." + ) + new_sizes = tile_transform.transform_shape(list(op.static_sizes)) + new_static_offsets, new_dynamic_offsets = _tile_transform_offsets( + tiling, list(op.static_offsets), list(op.offsets) + ) + + new_subview_op = memref.SubViewOp( + transformed_smem_ref_type(op.result.type, transforms), + unwrapped_source_ref, + new_dynamic_offsets, + None, + None, + static_offsets=new_static_offsets, + static_sizes=new_sizes, + static_strides=[1] * len(in_transformed_ty.shape), + ) + case _: + raise NotImplementedError( + "SubViewOp only supports a single tile transform." + ) + + wrapped_ref = wrap_transformed_memref( + new_subview_op.result, op.result.type, out_transforms + ) + return [wrapped_ref] + + +@_register_lowering(memref.CastOp) +def _memref_cast_op_lowering_rule( + ctx: LoweringContext, op: memref.CastOp +) -> Sequence[ir.Value]: + """Lowering rule for memref.CastOp. + Only casts that add a dynamic offset are supported. + """ + del ctx + + in_transforms = inference_utils.in_transforms(op)[0] + out_transforms = inference_utils.out_transforms(op)[0] + if in_transforms != out_transforms: + raise NotImplementedError( + "CastOp transforms for the input and output refs must be identical." + ) + + in_ty = ir.MemRefType(op.source.type) + out_ty = ir.MemRefType(op.result.type) + if in_ty.element_type != out_ty.element_type: + raise NotImplementedError( + "CastOp only supports casts between memrefs with the same element type." + ) + if in_ty.shape != out_ty.shape: + raise NotImplementedError( + "CastOp only supports casts between memrefs with the same shape." + ) + in_strides, _ = in_ty.get_strides_and_offset() + out_strides, out_offset = out_ty.get_strides_and_offset() + if in_strides != out_strides: + raise NotImplementedError( + "CastOp only supports casts between memrefs with the same strides." + ) + + unwrapped_source_ref = unwrap_transformed_memref(op.source, in_transforms) + in_transformed_ty = ir.MemRefType(unwrapped_source_ref.type) + transformed_strides, _ = in_transformed_ty.get_strides_and_offset() + out_layout = ir.StridedLayoutAttr.get(out_offset, transformed_strides) + out_transformed_ty = ir.MemRefType.get( + in_transformed_ty.shape, + in_transformed_ty.element_type, + memory_space=in_transformed_ty.memory_space, + layout=out_layout, + ) + new_cast_op = memref.CastOp(out_transformed_ty, unwrapped_source_ref) + wrapped_ref = wrap_transformed_memref( + new_cast_op.result, op.result.type, out_transforms + ) + return [wrapped_ref] + + +def _permutation_to_affine_map_attr( + permutation: Sequence[int], +) -> ir.AffineMapAttr: + return ir.AffineMapAttr.get(ir.AffineMap.get_permutation(permutation)) + + +@_register_lowering(memref.TransposeOp) +def _memref_transpose_op_lowering_rule( + ctx: LoweringContext, op: memref.TransposeOp +) -> Sequence[ir.Value]: + del ctx + + in_transforms = inference_utils.in_transforms(op)[0] + unwrapped_in_ref = unwrap_transformed_memref(op.in_, in_transforms) + in_transformed_ty = ir.MemRefType(unwrapped_in_ref.type) + if in_transformed_ty.rank == op.in_.type.rank: + new_permutation = op.permutation + elif in_transformed_ty.rank == 4: + if op.permutation == _permutation_to_affine_map_attr([0, 1]): + new_permutation = _permutation_to_affine_map_attr([0, 1, 2, 3]) + elif op.permutation == _permutation_to_affine_map_attr([1, 0]): + new_permutation = _permutation_to_affine_map_attr([1, 0, 3, 2]) + else: + raise NotImplementedError(f"Unsupported permutation={op.permutation}.") + else: + raise NotImplementedError( + "TransposeOp only supports transposing 4D tiled memrefs and untiled" + " memrefs." + ) + + out_transforms = inference_utils.out_transforms(op)[0] + _, transforms = swizzle_and_transforms_from_transforms_attr(out_transforms) + new_transpose_op = memref.TransposeOp( + transformed_smem_ref_type(op.result.type, transforms), + unwrapped_in_ref, + new_permutation, + ) + + wrapped_ref = wrap_transformed_memref( + new_transpose_op.result, op.result.type, out_transforms + ) + return [wrapped_ref] + + +@_register_lowering(memref.ExpandShapeOp) +def _memref_expand_shape_op_lowering_rule( + ctx: LoweringContext, op: memref.ExpandShapeOp +) -> Sequence[ir.Value]: + del ctx + + in_transforms = inference_utils.in_transforms(op)[0] + unwrapped_in_ref = unwrap_transformed_memref(op.src, in_transforms) + in_transformed_ty = ir.MemRefType(unwrapped_in_ref.type) + + out_transforms = inference_utils.out_transforms(op)[0] + _, transforms = swizzle_and_transforms_from_transforms_attr(out_transforms) + out_transformed_ty = transformed_smem_ref_type(op.result.type, transforms) + + reassociation = list(op.reassociation) + num_tiling_dims = len(in_transformed_ty.shape) - len(op.src.type.shape) + + # We don't currently allow expanding tiled dimensions. So to compute the + # reassociation on the lowered types, we just need to backfill the original + # one with the number of missing dimensions. + if num_tiling_dims > 0 and any( + len(x) > 1 for x in reassociation[-num_tiling_dims:] + ): + raise NotImplementedError("Expanding tiled dimensions is not supported.") + + start_index = len(op.static_output_shape) + for i in range(start_index, start_index + num_tiling_dims): + reassociation.append([i]) + + new_expand_shape_op = memref.ExpandShapeOp( + out_transformed_ty, + unwrapped_in_ref, + reassociation, + output_shape=op.output_shape, + static_output_shape=out_transformed_ty.shape, + ) + + wrapped_ref = wrap_transformed_memref( + new_expand_shape_op.result, op.result.type, out_transforms + ) + return [wrapped_ref] + + +@_register_lowering(memref.LoadOp) +def _memref_load_op_lowering_rule( + ctx: LoweringContext, op: memref.LoadOp +) -> Sequence[ir.Value]: + """Lowering rule for memref.LoadOp. + + Loads are never transformed so this rule is mostly just a pass-through. + """ + del ctx + + in_transforms = inference_utils.in_transforms(op)[0] + if in_transforms: + raise NotImplementedError(f"memref.LoadOp does not support transforms: {op}") + + new_load_op = memref.LoadOp( + memref=unwrap_transformed_memref(op.memref, in_transforms), + indices=op.indices, + nontemporal=op.nontemporal, + ) + return [new_load_op.result] + + +@_register_lowering(memref.StoreOp) +def _memref_store_op_lowering_rule( + ctx: LoweringContext, op: memref.StoreOp +) -> Sequence[ir.Value]: + """Lowering rule for memref.StoreOp. + + Stores are never transformed so this rule is mostly just a pass-through. + """ + del ctx + + in_transforms = inference_utils.in_transforms(op)[0] + if in_transforms: + raise NotImplementedError(f"memref.StoreOp does not support transforms: {op}") + + memref.StoreOp( + value=op.value, + memref=unwrap_transformed_memref(op.memref, in_transforms), + indices=op.indices, + nontemporal=op.nontemporal, + ) + return [] + + +@_register_lowering(mgpu.TmemAllocOp) +def _tmem_alloc_op_lowering_rule( + ctx: LoweringContext, op: mgpu.TmemAllocOp +) -> Sequence[ir.Value]: + """Lowering rule for mgpu.TmemAllocOp.""" + ctx.check_collective(op) + + output_shape = ir.MemRefType(op.result.type).shape + ncols = output_shape[1] // op.packing.value + + with mgpu_utils.when(ctx.single_warp_per_block_predicate): + tcgen05.tmem_alloc(op.smem_ptr, ncols, op.collective, exact=False) + gpu.barrier() + tmem_addr = memref.load(op.smem_ptr, []) + + cast_op = builtin.UnrealizedConversionCastOp( + [op.result.type], [tmem_addr] + ) + cast_op.attributes["collective"] = op.collective + cast_op.attributes["packing"] = op.packing + cast_op.attributes["layout"] = inference_utils.out_tmem_layouts(op)[0] + + return [cast_op.result] + + +@_register_lowering(mgpu.TmemRelinquishAllocPermitOp) +def _tmem_relinquish_alloc_permit_op_lowering_rule( + ctx: LoweringContext, op: mgpu.TmemRelinquishAllocPermitOp +) -> Sequence[ir.Value]: + """Lowering rule for mgpu.TmemRelinquishAllocPermitOp.""" + ctx.check_collective(op) + with mgpu_utils.when(ctx.single_warp_per_block_predicate): + tcgen05.tmem_relinquish_alloc_permit(op.collective) + return [] + + +@_register_lowering(mgpu.TmemDeallocOp) +def _tmem_dealloc_op_lowering_rule( + ctx: LoweringContext, op: mgpu.TmemDeallocOp +) -> Sequence[ir.Value]: + """Lowering rule for mgpu.TmemDeallocOp.""" + i32 = ir.IntegerType.get_signless(32) + conversion_cast, [tmem_addr] = _undo_conversion_cast(op.tmem_ref, [i32]) + collective = ir.BoolAttr(conversion_cast.attributes["collective"]).value + packing = ir.IntegerAttr(conversion_cast.attributes["packing"]).value + + output_shape = ir.MemRefType(op.tmem_ref.type).shape + ncols = output_shape[1] // packing + + with mgpu_utils.when(ctx.single_warp_per_block_predicate): + tcgen05.tmem_dealloc(tmem_addr, ncols, collective, exact=False) + + return [] + + +def _swizzle(attrs: Sequence[ir.Attribute]) -> mgpu.SwizzlingMode: + """Returns the swizzle transform from the given attributes.""" + swizzle = None + for attr in attrs: + if mgpu.SwizzleTransformAttr.isinstance(attr): + if swizzle is not None: + raise ValueError("Multiple swizzle transforms are not supported.") + swizzle = mgpu.SwizzleTransformAttr(attr).swizzle + return swizzle if swizzle is not None else mgpu.SwizzlingMode.kNoSwizzle + +def _tmem_ref_from_ir( + ref: ir.Value, expected_layout: ir.Attribute +) -> tcgen05.TMEMRef: + """Returns a TMEMRef from an IR value. + + Throws an error if the annotated layout does not match the expected layout. + """ + if not isinstance(ref.type, ir.MemRefType): + raise ValueError(f"{ref} is not a memref.") + mem_ref_ty = ir.MemRefType(ref.type) + + if mem_ref_ty.memory_space != mgpu_utils.tmem(): + raise ValueError( + f"{ref} has a memory space {mem_ref_ty.memory_space} that is not TMEM." + ) + + i32 = ir.IntegerType.get_signless(32) + cast, [tmem_addr] = _undo_conversion_cast(ref, [i32]) + + shape = tuple(mem_ref_ty.shape) + el_ty = mem_ref_ty.element_type + layout_attr = cast.attributes["layout"] + if layout_attr != expected_layout: + raise ValueError( + f"{ref} has a layout {layout_attr} that does not match the expected" + f" layout {expected_layout}." + ) + layout = layouts_lib.from_layout_attr(layout_attr) + assert isinstance(layout, fa.TiledLayout) + tmem_layout = tcgen05.TMEMLayout( + layout.tiling, layout.warp_dims, layout.lane_dims, layout.vector_dim + ) + return tcgen05.TMEMRef(tmem_addr, shape, el_ty, tmem_layout) + + +def _tmem_ref_to_ir(ref: tcgen05.TMEMRef) -> ir.Value: + """Returns an IR value from a TMEMRef.""" + type = ir.MemRefType.get(ref.shape, ref.dtype, memory_space=mgpu_utils.tmem()) + cast = builtin.UnrealizedConversionCastOp([type], [ref.address]) + cast.attributes["layout"] = layouts_lib.to_layout_attr(ref.layout) + return cast.result + + +@_register_lowering(mgpu.TcGen05MMAOp) +def _tcgen05_mma_op_lowering_rule( + ctx: LoweringContext, op: mgpu.TcGen05MMAOp +) -> Sequence[ir.Value]: + ctx.check_collective(op) + + in_tmem_layouts = inference_utils.in_tmem_layouts(op) + acc_layout = in_tmem_layouts[0] + acc_ref = _tmem_ref_from_ir(op.accumulator, acc_layout) + + if utils.is_smem_ref(op.a): + a_transforms, b_transforms = inference_utils.in_transforms(op) + a_swizzle = _swizzle(a_transforms) + b_swizzle = _swizzle(b_transforms) + a_ref = unwrap_transformed_memref(op.a, a_transforms) + b_ref = unwrap_transformed_memref(op.b, b_transforms) + else: + a_ref = _tmem_ref_from_ir(op.a, in_tmem_layouts[1]) + [b_transforms] = inference_utils.in_transforms(op) + b_swizzle = _swizzle(b_transforms) + a_swizzle = b_swizzle + b_ref = unwrap_transformed_memref(op.b, b_transforms) + + with mgpu_utils.when(ctx.single_thread_per_block_predicate): + tcgen05.mma( + acc_ref, + a_ref, + b_ref, + a_swizzle=a_swizzle, + b_swizzle=b_swizzle, + a_scale=op.a_scale, + b_scale=op.b_scale, + accumulate=op.accumulate, + collective=op.collective.value, + ) + + return [] + + +@_register_lowering(mgpu.AsyncLoadTmemOp) +def _async_load_tmem_op_lowering_rule( + ctx: LoweringContext, op: mgpu.AsyncLoadTmemOp +) -> Sequence[ir.Value]: + """Lowering rule for mgpu.AsyncLoadTmemOp.""" + del ctx + in_layout_attr = inference_utils.in_tmem_layouts(op)[0] + tmem_ref = _tmem_ref_from_ir(op.source, in_layout_attr) + out_layout_attr = inference_utils.out_layouts(op)[0] + out_layout = layouts_lib.from_tiled_layout_attr(out_layout_attr) + is_signed = _default_is_signed(ir.MemRefType(op.source.type).element_type) + fa = tmem_ref.load(out_layout, is_signed) + return [fragmented_array_to_ir(fa, op.result.type)] + + +@_register_lowering(mgpu.AsyncStoreTmemOp) +def _async_store_tmem_op_lowering_rule( + ctx: LoweringContext, op: mgpu.AsyncStoreTmemOp +) -> Sequence[ir.Value]: + """Lowering rule for mgpu.AsyncStoreTmemOp.""" + del ctx + in_layout_attr = inference_utils.in_tmem_layouts(op)[0] + tmem_ref = _tmem_ref_from_ir(op.destination, in_layout_attr) + in_layout_attr = inference_utils.in_layouts(op)[0] + fa = _fragmented_array_from_ir(op.source, in_layout_attr) + tmem_ref.store(fa) + + return [] + + +@_register_lowering(mgpu.CustomPrimitiveOp) +def _mgpu_custom_primitive_op_lowering_rule( + ctx: LoweringContext, op: mgpu.CustomPrimitiveOp +) -> Sequence[ir.Value]: + """Lowering rule for mgpu.CustomPrimitiveOp.""" + del ctx + block = op.body.blocks[0] + for arg, op in zip(block.arguments, op.operands, strict=True): + arg.replace_all_uses_with(op) + + return_op = None + ip = ir.InsertionPoint.current + for op in block.operations: + if isinstance(op.opview, mgpu.ReturnOp): + assert return_op is None + return_op = op.opview + continue + op.detach_from_parent() + ip.insert(op) + + if return_op is None: + raise ValueError("A custom return op must terminate the block.") + + return return_op.operands + + +# The metadata needed to recostruct a vector from its flattened representation. +_VectorTemplate = tuple[Sequence[int], fa.FragmentedLayout, ir.VectorType] + + +def _flatten_ir_values( + values: Sequence[ir.Value], fa_layouts: Iterable[ir.Attribute] +) -> tuple[Sequence[ir.Value], Sequence[_VectorTemplate | None]]: + """Flattens a sequence of values. + + Non-vector values are preserved as is. Vectors are mapped to fragmented + arrays and then flattened into per-register values. + + Args: + values: The sequence of values to flatten. + fa_layouts: The layouts of vectors in ``values``. + + Returns: + A tuple of (flattened values, templates). The templates are used to + reconstruct the vectors from the per-register values. + """ + fa_layouts_it = iter(fa_layouts) + result: list[ir.Value] = [] + templates: list[_VectorTemplate | None] = [] + for v in values: + if isinstance(v.type, ir.VectorType): + fa = _fragmented_array_from_ir(v, next(fa_layouts_it)) + result.extend(fa.registers.flat) + templates.append((fa.registers.shape, fa.layout, ir.VectorType(v.type))) + else: + result.append(v) + templates.append(None) + return result, templates + + +def _unflatten_ir_values( + flat_values: Sequence[ir.Value], templates: Sequence[_VectorTemplate | None] +) -> Sequence[ir.Value]: + """The inverse of ``_flatten_ir_values``.""" + result = [] + flat_values_it = iter(flat_values) + for template in templates: + if template is None: + result.append(next(flat_values_it)) + continue + registers_shape, layout, vec_type = template + value_registers = np.asarray( + [next(flat_values_it) for _ in range(math.prod(registers_shape))], + dtype=object, + ) + value = fa.FragmentedArray( + _registers=value_registers.reshape(registers_shape), + _layout=layout, + _is_signed=_default_is_signed(vec_type.element_type), + ) + result.append(fragmented_array_to_ir(value, vec_type)) + return result + + +def _move_scf_block_to_block_with_flattened_arguments( + ctx: LoweringContext, + old_block: ir.Block, + new_block: ir.Block, + last_op_type: type[ir.OpView], + args_template: Sequence[_VectorTemplate | None], + *new_leading_args: Sequence[ir.Value], +) -> Sequence[_VectorTemplate | None]: + """Moves the operations from `old_block` to `new_block`. + + The input arguments to the block, if any, are flattened using the provided + `args_template`, except for any new_leading_args which are simply prepended + to the flattened arguments and must be part of the template. + + The last operation of the old block must be of type `last_op_type` which + is expected to be either a `scf.YieldOp` or a `scf.ConditionOp`. This + operation is recreated with flattened output arguments. + """ + out_template = None + with ir.InsertionPoint(new_block): + new_carry = _unflatten_ir_values(new_block.arguments[len(new_leading_args):], args_template) + new_args = new_leading_args + tuple(new_carry) + for old_arg, new_arg in zip(old_block.arguments, new_args, strict=True): + old_arg.replace_all_uses_with(new_arg) + for op in [*old_block]: + if not isinstance(op, last_op_type): + # `append` moves the operation. + new_block.append(op) + ctx.lower_op(op) + else: + assert out_template is None + layouts = ( + inference_utils.in_layouts(op) + if inference_utils.has_in_layouts_set(op) + else [] + ) + if isinstance(op, scf.YieldOp): + flat_operands, out_template = _flatten_ir_values(op.operands, layouts) + scf.yield_(flat_operands) + elif isinstance(op, scf.ConditionOp): + flat_carry, out_template = _flatten_ir_values(op.args, layouts) + scf.condition(op.condition, flat_carry) + else: + raise NotImplementedError(f"Unsupported op type: {op}") + op.erase() + assert out_template is not None + return out_template @_register_lowering(scf.ForOp) def _for_op_lowering_rule( @@ -884,84 +2186,152 @@ def _for_op_lowering_rule( yield_layouts = inference_utils.in_layouts(yield_op) if in_layouts != out_layouts or in_layouts != yield_layouts: raise ValueError("Layout mismatch") - fa_layouts = in_layouts - - fa_layouts_it = iter(fa_layouts) - arg_template = [ - (_fragmented_array_from_ir(arg, next(fa_layouts_it)), arg.type) - if ir.VectorType.isinstance(arg.type) - else (arg, arg.type) - for arg in for_op.initArgs - ] - def lower_carry(carry): - fa_layouts_it = iter(fa_layouts) - carry_with_fas = [ - _fragmented_array_from_ir(arg, next(fa_layouts_it)) - if ir.VectorType.isinstance(arg.type) - else arg - for arg in carry - ] - lowered_carry = [] - for c in carry_with_fas: - if isinstance(c, fa.FragmentedArray): - lowered_carry.extend(c.registers.flat) - else: - lowered_carry.append(c) - return lowered_carry - - def recreate_carry(lowered_carry): - recreated_carry = [] - arg_it = iter(lowered_carry) - for arg_value, arg_type in arg_template: - if isinstance(arg_value, fa.FragmentedArray): - carry_registers = np.asarray( - [next(arg_it) for _ in arg_value.registers.flat], dtype=object - ) - carry_registers = carry_registers.reshape(arg_value.registers.shape) - carry = fa.FragmentedArray( - _registers=carry_registers, - _layout=arg_value.layout, - _is_signed=arg_value.is_signed, - ) - recreated_carry.append(_fragmented_array_to_ir(carry, arg_type)) - else: - recreated_carry.append(next(arg_it)) - return recreated_carry + flat_init_args, args_template = _flatten_ir_values( + for_op.initArgs, in_layouts + ) new_for_op = scf.ForOp( for_op.lowerBound, for_op.upperBound, for_op.step, - lower_carry(for_op.initArgs), + flat_init_args, + ) + + _move_scf_block_to_block_with_flattened_arguments( + ctx, + for_op.body, + new_for_op.body, + scf.YieldOp, + args_template, + new_for_op.induction_variable, + ) + + return _unflatten_ir_values(new_for_op.results, args_template) + + +@_register_lowering(scf.WhileOp) +def _while_op_lowering_rule( + ctx: LoweringContext, while_op: scf.WhileOp +) -> MlirLoweringRuleResult: + if not inference_utils.should_have_layout(while_op): + return _traverse_op_lowering_rule(ctx, while_op) + + before_block = while_op.before.blocks[0] + after_block = while_op.after.blocks[0] + condition_op = before_block.operations[len(before_block.operations) - 1] + yield_op = after_block.operations[len(after_block.operations) - 1] + + in_layouts = ( + inference_utils.in_layouts(while_op) + if inference_utils.should_have_in_layout(while_op) + else [] + ) + out_layouts = ( + inference_utils.out_layouts(while_op) + if inference_utils.should_have_out_layout(while_op) + else [] + ) + + if in_layouts: + yield_layouts = inference_utils.in_layouts(yield_op) + if in_layouts != yield_layouts: + raise ValueError( + f"Input layouts {in_layouts} do not match yield layouts" + f" {yield_layouts}" + ) + + if out_layouts: + condition_layouts = inference_utils.in_layouts(condition_op) + if out_layouts != condition_layouts: + raise ValueError( + f"Output layouts {out_layouts} do not match condition layouts" + f" {condition_layouts}" + ) + + flat_inits, inits_template = _flatten_ir_values(while_op.inits, in_layouts) + result_types = _infer_flat_result_types(while_op, out_layouts) + new_while_op = scf.WhileOp(result_types, flat_inits) + + # Before block + init_types = [v.type for v in flat_inits] + new_before_block = new_while_op.before.blocks.append(*init_types) + results_template = _move_scf_block_to_block_with_flattened_arguments( + ctx, + before_block, + new_before_block, + scf.ConditionOp, + inits_template, + ) + + # After block + new_after_block = new_while_op.after.blocks.append(*result_types) + _move_scf_block_to_block_with_flattened_arguments( + ctx, + after_block, + new_after_block, + scf.YieldOp, + results_template, + ) + + return _unflatten_ir_values(new_while_op.results, results_template) + + +def _infer_flat_result_types( + op: ir.OpView, out_layouts: Sequence[ir.Attribute] +) -> Sequence[ir.Type]: + result_types: list[ir.Type] = [] + out_layouts_it = iter(out_layouts) + for r in op.results: + if not isinstance(r.type, ir.VectorType): + result_types.append(r.type) + continue + vec_type = ir.VectorType(r.type) + layout = layouts_lib.from_layout_attr(next(out_layouts_it)) + result_types.extend( + [layout.registers_element_type(vec_type.element_type)] + * math.prod(layout.registers_shape(tuple(vec_type.shape))) + ) + return result_types + + +@_register_lowering(scf.IfOp) +def _if_op_lowering_rule( + ctx: LoweringContext, if_op: scf.IfOp +) -> MlirLoweringRuleResult: + if not inference_utils.should_have_layout(if_op): + return _traverse_op_lowering_rule(ctx, if_op) + + raise NotImplementedError + + +@_register_lowering(scf.IndexSwitchOp) +def _index_switch_op_lowering_rule( + ctx: LoweringContext, switch_op: scf.IndexSwitchOp +) -> MlirLoweringRuleResult: + if not inference_utils.should_have_layout(switch_op): + return _traverse_op_lowering_rule(ctx, switch_op) + + out_layouts = inference_utils.out_layouts(switch_op) + new_switch_op = scf.IndexSwitchOp( + _infer_flat_result_types(switch_op, out_layouts), + switch_op.arg, + switch_op.cases, ) - with ir.InsertionPoint(new_for_op.body): - recreated_carry = recreate_carry(new_for_op.body.arguments[1:]) - ops_to_lower = [] - for op in for_op.body: - if op == yield_op: - continue - mgpu.private_operation_remove_from_parent(op) - mgpu.private_block_append_owned_operation(new_for_op.body, op) - ops_to_lower.append(op) - new_args = (new_for_op.induction_variable, *recreated_carry) - for old_carry, new_carry in zip(for_op.body.arguments, new_args, strict=True): - old_carry.replace_all_uses_with(new_carry) - - for op in ops_to_lower: - with ir.InsertionPoint(op): - ctx.lower_op(op) - with ir.InsertionPoint(new_for_op.body): - new_yield_operands = lower_carry(yield_op.operands) - yield_op.erase() - scf.yield_(new_yield_operands) - return recreate_carry(new_for_op.results) + results_template: Sequence[_VectorTemplate | None] = [] + for region, new_region in zip( + switch_op.regions, new_switch_op.regions, strict=True + ): + [block] = region.blocks + new_block = new_region.blocks[0] + results_template = _move_scf_block_to_block_with_flattened_arguments( + ctx, block, new_block, scf.YieldOp, [] + ) + return _unflatten_ir_values(new_switch_op.results, results_template) @_register_lowering(func.FuncOp) @_register_lowering(gpu.LaunchOp) -@_register_lowering(scf.IfOp) # TODO(apaszke,bchetioui): Add a proper rule. -@_register_lowering(scf.IndexSwitchOp) # TODO(apaszke,bchetioui): Add a proper rule. def _traverse_op_lowering_rule( ctx: LoweringContext, op: ir.OpView ) -> MlirLoweringRuleResult: @@ -977,44 +2347,61 @@ def _traverse_op_lowering_rule( return RECURSED -def single_thread_predicates(module: ir.Module) -> tuple[ir.Value, ir.Value]: - """Returns a single thread predicate per block and one per warpgroup.""" - block_predicate = warpgroup_predicate = None +def _should_lower(op: ir.OpView) -> bool: + """Returns 'true' if the operation should be lowered.""" + return ( + op.OPERATION_NAME.startswith("mosaic_gpu.") # pytype: disable=attribute-error + or inference_utils.should_have_layout(op) + or inference_utils.should_have_transforms(op) + or inference_utils.should_have_tmem_layout(op) + or any(bool(b) for r in op.regions for b in r) # Does it have subblocks? + ) + + +def _gpu_launch_op(module: ir.Module) -> ir.Operation: for op in module.body.operations: for region in op.operation.regions: for block in region.blocks: for sub_op in block.operations: if sub_op.operation.name == "gpu.launch": - with ir.InsertionPoint.at_block_begin( - sub_op.operation.regions[0].blocks[0] - ): - assert block_predicate is None - block_predicate = utils.single_thread_predicate(per_block=True) - warpgroup_predicate = utils.single_thread_predicate( - per_block=False - ) - - if block_predicate is None: - raise ValueError( - "No suitable function found to instantiate the single thread" - " predicates." - ) + return sub_op.operation + raise ValueError("gpu.launch op not found.") - return block_predicate, warpgroup_predicate +def _lowering_context( + module: ir.Module, + launch_context: launch_context.LaunchContext | None, + auto_barriers: bool, +) -> LoweringContext: + """Returns a `LoweringContext` for the given `LaunchContext`.""" + # TODO(bchetioui): fix tests to not have a test-only path polluting the API. + if launch_context is None: # this case is used in some tests + return LoweringContext(None, None, None, None, auto_barriers) -def _should_lower(op: ir.OpView) -> bool: - """Returns 'true' if the operation should be lowered.""" - return ( - op.OPERATION_NAME.startswith("mosaic_gpu.") # pytype: disable=attribute-error - or inference_utils.should_have_layout(op) - or any(bool(b) for r in op.regions for b in r) # Does it have subblocks? - ) + gpu_launch_op = _gpu_launch_op(module) + with ir.InsertionPoint.at_block_begin(gpu_launch_op.regions[0].blocks[0]): + block_predicate = utils.single_thread_predicate( + scope=utils.ThreadSubset.BLOCK + ) + warpgroup_predicate = utils.single_thread_predicate( + scope=utils.ThreadSubset.WARPGROUP + ) + eq = arith.CmpIPredicate.eq + i32 = ir.IntegerType.get_signless(32) + warp_predicate = arith.cmpi(eq, utils.warp_idx(sync=False), utils.c(0, i32)) + return LoweringContext( + launch_context, + block_predicate, + warpgroup_predicate, + warp_predicate, + auto_barriers, + ) def lower_mgpu_dialect( module: ir.Module, launch_context: launch_context.LaunchContext | None, + auto_barriers: bool = True, ): # TODO(apaszke,bchetioui): Make sure the layouts match. # TODO(bchetioui): rethink this API. It doesn't make sense to pass in a full @@ -1025,14 +2412,7 @@ def lower_mgpu_dialect( # kernel. module.context.append_dialect_registry(mlir_interpreter.upstream_dialects) module.context.load_all_available_dialects() - - # TODO(bchetioui): fix tests to not have a test-only path polluting the API. - if launch_context is None: # this case is used in some tests - block_predicate = warpgroup_predicate = None - else: - block_predicate, warpgroup_predicate = single_thread_predicates(module) - - ctx = LoweringContext(launch_context, block_predicate, warpgroup_predicate) + ctx = _lowering_context(module, launch_context, auto_barriers) with ir.InsertionPoint(module.body): for op in list(module.body): ctx.lower_op(op) diff --git a/jax/experimental/mosaic/gpu/examples/BUILD b/jax/experimental/mosaic/gpu/examples/BUILD index fe1a7e9180ac..f33c62cdbf2e 100644 --- a/jax/experimental/mosaic/gpu/examples/BUILD +++ b/jax/experimental/mosaic/gpu/examples/BUILD @@ -19,7 +19,7 @@ licenses(["notice"]) package( default_applicable_licenses = [], - default_visibility = ["//jax:mosaic_gpu_users"], + default_visibility = ["//jax/experimental:mosaic_gpu_users"], ) exports_files( @@ -30,12 +30,29 @@ exports_files( visibility = ["//jax:internal"], ) +py_library( + name = "gpu_examples", + srcs = [ + "__init__.py", + ], + visibility = ["//jax:internal"], +) + py_library( name = "matmul", srcs = ["matmul.py"], deps = [ "//jax", - "//jax:mosaic_gpu", + "//jax/experimental:mosaic_gpu", + ], +) + +py_library( + name = "matmul_blackwell", + srcs = ["matmul_blackwell.py"], + deps = [ + "//jax", + "//jax/experimental:mosaic_gpu", ], ) @@ -44,7 +61,7 @@ py_library( srcs = ["flash_attention.py"], deps = [ "//jax", - "//jax:mosaic_gpu", + "//jax/experimental:mosaic_gpu", ], ) @@ -59,6 +76,6 @@ jax_multiplatform_test( "notap", ], deps = [ - "//jax:mosaic_gpu", + "//jax/experimental:mosaic_gpu", ] + py_deps("numpy"), ) diff --git a/jax/experimental/mosaic/gpu/examples/__init__.py b/jax/experimental/mosaic/gpu/examples/__init__.py new file mode 100644 index 000000000000..3da0dd1fa3ca --- /dev/null +++ b/jax/experimental/mosaic/gpu/examples/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== diff --git a/jax/experimental/mosaic/gpu/examples/flash_attention.py b/jax/experimental/mosaic/gpu/examples/flash_attention.py index dc59dda3a6e5..3af86b5e4923 100644 --- a/jax/experimental/mosaic/gpu/examples/flash_attention.py +++ b/jax/experimental/mosaic/gpu/examples/flash_attention.py @@ -17,7 +17,6 @@ import dataclasses import enum import itertools -import warnings import jax from jax import random @@ -179,7 +178,7 @@ def only_wg(idx): loop_partition = Partition1D(kv_seq_len, chunk_size=blocks.kv) if_compute = scf.IfOp( - arith.cmpi(arith.CmpIPredicate.ne, wg_idx, c(2, i32)), hasElse=True + arith.cmpi(arith.CmpIPredicate.ne, wg_idx, c(2, i32)), has_else=True ) with ir.InsertionPoint(if_compute.then_block): nvvm.setmaxregister(232, nvvm.SetMaxRegisterAction.increase) @@ -244,8 +243,8 @@ def kv_loop(kv_step, carry): perform_schedule_barrier() - # This is quite suprising, but it seems like warp shuffles cannot - # run simutaneously with the WGMMA. For that reason we include it as + # This is quite surprising, but it seems like warp shuffles cannot + # run simultaneously with the WGMMA. For that reason we include it as # part of the TensorCore critical section and not the ALU section. with ctx.named_region("Softmax reduction"): l_i += p.reduce(arith.addf, axis=1) @@ -299,7 +298,7 @@ def kv_loop(kv_step, carry): scf.yield_([]) with ir.InsertionPoint(if_compute.else_block): nvvm.setmaxregister(40, nvvm.SetMaxRegisterAction.decrease) - with single_thread(per_block=False): + with single_thread(scope=ThreadSubset.WARPGROUP): k_tr = (TileTransform(tiling), TransposeTransform((1, 0, 2, 3))) v_tr = TileTransform(tiling) kv_head_idx = arith.divui(q_head_idx, c(q_heads_per_kv_head)) @@ -310,7 +309,7 @@ def start_kv_copy(slot, kv_seq_base, smem, gmem, barrier, transform): gmem_slice=(kv_head_idx, ds(kv_seq_base, blocks.kv)), gmem_transform=transform, barrier=barrier, - uniform=False, + predicate=None, swizzle=128, ) def start_k_copy(slot, kv_seq_base): @@ -391,7 +390,7 @@ def only_wg(idx): kv_head_idx = arith.divui(q_head_idx, c(q_heads_per_kv_head)) def kv_copy_init(slot, kv_seq_base): - with single_thread(per_block=False): + with single_thread(ThreadSubset.WARPGROUP): txcount = 2 * blocks.kv * head_dim * bytewidth(f16) barriers[slot].arrive_expect_tx(txcount) k_tr = (TileTransform(tiling), TransposeTransform((1, 0, 2, 3))) @@ -404,7 +403,7 @@ def kv_copy_init(slot, kv_seq_base): gmem_transform=t, barrier=barriers[slot], arrive=False, - uniform=False, + predicate=None, swizzle=128, ) @@ -601,7 +600,7 @@ def ref(q, k, v): if __name__ == "__main__": if (not jtu.test_device_matches(["cuda"]) or not jtu.is_cuda_compute_capability_equal("9.0")): - warnings.warn( + print( "Mosaic GPU Flash Attention requires compute capability 9.0a to run, " "skipping.") exit(0) diff --git a/jax/experimental/mosaic/gpu/examples/matmul.py b/jax/experimental/mosaic/gpu/examples/matmul.py index a5dd29e0dc4d..5c8363fa8b27 100644 --- a/jax/experimental/mosaic/gpu/examples/matmul.py +++ b/jax/experimental/mosaic/gpu/examples/matmul.py @@ -206,7 +206,7 @@ def fetch(slot, ki): rhs_tma_tile_bytes = int(np.prod(block_tiling.kn) * rhs_elem_bytes) txcount = lhs_tma_tile_bytes + rhs_tma_tile_bytes common_copy_args = dict( - swizzle=swizzle, barrier=barrier, arrive=False, uniform=False, + swizzle=swizzle, barrier=barrier, arrive=False, predicate=None, ) with single_thread(): barrier.arrive_expect_tx(txcount) diff --git a/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py b/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py index 6af394d00138..7c65e8761aa9 100644 --- a/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py +++ b/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py @@ -15,6 +15,7 @@ """Matmul kernel for Blackwell.""" import itertools +import math import jax from jax._src.interpreters import mlir @@ -41,7 +42,8 @@ def bytecount(shape, dtype): def build_kernel( - m, n, k, + m, k, n, + dtype: jnp.dtype, tile_m: int = 128, tile_n: int = 128, grid_tile_m: int = 1, @@ -51,12 +53,15 @@ def build_kernel( i1 = ir.IntegerType.get_signless(1) i32 = ir.IntegerType.get_signless(32) index = ir.IndexType.get() + if jnp.dtype(dtype).itemsize != 2: + raise NotImplementedError(f"Only tested with 16-bit dtypes, but got {dtype}") + if tile_m != 128: + raise NotImplementedError(f"Only tile_m=128 supported, but got {tile_m}") swizzle = 128 - swizzle_elems = tile_k = swizzle // 2 + swizzle_elems = tile_k = 8 * swizzle // jnp.finfo(dtype).bits tiling = (8, swizzle_elems) - in_dtype = jnp.float16 k_loop_iter = k // tile_k max_concurrent_steps = min(max_concurrent_steps, k_loop_iter) @@ -74,132 +79,187 @@ def build_kernel( raise ValueError(f"{n=} must be divisible by {tile_n=}") if k % tile_k != 0: raise ValueError(f"{k=} must be divisible by {tile_k=}") - if (m // tile_m) % grid_tile_m: + if (m // block_tile_m) % grid_tile_m: raise ValueError(f"{m=} // {tile_m=} must be divisible by {grid_tile_m=}") + # We intend this to be iterated in column-major order. + logical_grid = (grid_tile_m, n // tile_n, m // (block_tile_m * grid_tile_m)) + def kernel(ctx, a, b, d, smem): - ((a_smem, b_smem), d_smem), barriers, mma_done_barrier, acc = smem + ((a_smem, b_smem), d_smem), barriers, mma_done_barrier, tmem_done_barrier, acc = smem (ab_full_barriers, ab_empty_barriers) = barriers warp_idx = mgpu.warp_idx(sync=True) is_warp_leader = nvvm.elect_sync(i1) - is_leader_of = lambda i: arith.andi(arith.cmpi(arith.CmpIPredicate.eq, warp_idx, c(i, i32)), is_warp_leader) - is_leader_block = arith.cmpi(arith.CmpIPredicate.eq, ctx.cluster_idx(gpu.Dimension.x), c(0, index)) - - m_idx = arith.addi( - gpu.block_id(gpu.Dimension.x), - arith.muli(gpu.block_id(gpu.Dimension.z), c(grid_tile_m, index)), + is_leader_of = lambda i: arith.andi( + arith.cmpi(arith.CmpIPredicate.eq, warp_idx, c(i, i32)), is_warp_leader + ) + is_leader_block = arith.cmpi( + arith.CmpIPredicate.eq, ctx.cluster_idx(gpu.Dimension.x), c(0, index) + ) + is_store_warpgroup = arith.cmpi( + arith.CmpIPredicate.eq, mgpu.warpgroup_idx(sync=True), c(1, i32) ) - n_idx = gpu.block_id(gpu.Dimension.y) - block_m_start = arith.muli(m_idx, c(block_tile_m, index)) - # All blocks in the cluster share the same m_start -- align it! - m_start = arith.muli(arith.divui(block_m_start, c(tile_m, index)), c(tile_m, index)) - n_start = arith.muli(n_idx, c(tile_n,index)) + def compute_output(block_m_start, n_start, call_counter): + """Compute and store a single output tile. - with mgpu.when(is_leader_of(TMA_WARP)): - @mgpu.fori(c(k_loop_iter, index), None) - def _tma_body(ki, _): - slot = arith.remui(ki, c(max_concurrent_steps, index)) - # TODO(apaszke): Use a predicate instead of a conditional. - with mgpu.when(arith.cmpi(arith.CmpIPredicate.uge, ki, c(max_concurrent_steps, index))): - ab_empty_barriers[slot].wait() - full_barrier = ab_full_barriers[slot] - with mgpu.when(is_leader_block): - full_barrier.arrive_expect_tx( - bytecount((tile_m, tile_k), in_dtype) + bytecount((tile_n, tile_k), in_dtype) + call_counter should be 0 the first time this function is called and + incremented by 1 before each subsequent call. + """ + acc_slot = arith.remui(call_counter, c(2, index)) + acc_slice = acc.slice(slice(None), mgpu.ds(arith.muli(acc_slot, c(tile_n, index)), tile_n)) + # All blocks in the cluster share the same m_start -- align it! + m_start = arith.muli(arith.divui(block_m_start, c(tile_m, index)), c(tile_m, index)) + with mgpu.when(is_leader_of(TMA_WARP)): + @mgpu.fori(c(k_loop_iter, index), None) + def _tma_body(ki, _): + slot = arith.remui(ki, c(max_concurrent_steps, index)) + isnt_warmup = arith.cmpi( + arith.CmpIPredicate.uge, ki, c(max_concurrent_steps, index) + ) + isnt_first_call = arith.cmpi( + arith.CmpIPredicate.ne, call_counter, c(0, index) + ) + with mgpu.when(arith.ori(isnt_first_call, isnt_warmup)): + ab_empty_barriers[slot].wait() + full_barrier = ab_full_barriers[slot] + with mgpu.when(is_leader_block): + full_barrier.arrive_expect_tx( + bytecount((tile_m, tile_k), dtype) + bytecount((tile_n, tile_k), dtype) + ) + k_start = arith.muli(ki, c(tile_k, index)) + common_args = dict( + swizzle=swizzle, + barrier=full_barrier, + arrive=False, + predicate=None, + collective=gpu.Dimension.x, + partitioned=0, # Non-contracting dim is always 0. + ) + ctx.async_copy( + src_ref=a, + dst_ref=mgpu.memref_slice(a_smem, slot), + gmem_slice=(ds(m_start, tile_m), ds(k_start, tile_k)), + gmem_transform=mgpu.TileTransform(tiling), + **common_args, + ) + ctx.async_copy( + src_ref=b, + dst_ref=mgpu.memref_slice(b_smem, slot), + gmem_slice=(ds(n_start, tile_n), ds(k_start, tile_k)), + gmem_transform=mgpu.TileTransform(tiling), + **common_args, ) - k_start = arith.muli(ki, c(tile_k, index)) - common_args = dict( - swizzle=swizzle, - barrier=full_barrier, - arrive=False, - uniform=False, - collective=gpu.Dimension.x, - partitioned=0, # Non-contracting dim is always 0. - ) - ctx.async_copy( - src_ref=a, - dst_ref=mgpu.memref_slice(a_smem, slot), - gmem_slice=(ds(m_start, tile_m), ds(k_start, tile_k)), - gmem_transform=mgpu.TileTransform(tiling), - **common_args, - ) - ctx.async_copy( - src_ref=b, - dst_ref=mgpu.memref_slice(b_smem, slot), - gmem_slice=(ds(n_start, tile_n), ds(k_start, tile_k)), - gmem_transform=mgpu.TileTransform(tiling), - **common_args, - ) - with mgpu.when(arith.andi(is_leader_of(MMA_WARP), is_leader_block)): - @mgpu.fori(c(k_loop_iter, index), arith.constant(i1, 0)) - def _mma_body(ki, accumulate): - slot = arith.remui(ki, c(max_concurrent_steps, index)) - ab_full_barriers[slot].wait() - tcgen05.mma( - acc, - mgpu.memref_slice(a_smem, slot), - mgpu.memref_transpose(mgpu.memref_slice(b_smem, slot), (1, 0, 3, 2)), - a_swizzle=swizzle, - b_swizzle=swizzle, - accumulate=accumulate, - collective=collective, - ) - accumulate = arith.constant(i1, 1) - is_last_iter = arith.cmpi( - arith.CmpIPredicate.eq, ki, c(k_loop_iter - 1, index) - ) - barrier_ptr = arith.select( - is_last_iter, - mma_done_barrier.get_ptr(), - ab_empty_barriers[slot].get_ptr(), - ) - tcgen05.commit_arrive(barrier_ptr, collective=collective, ctx=ctx) - return accumulate + # We wait in all blocks in the cluster to avoid double arrival errors. + reuses_tmem = arith.cmpi(arith.CmpIPredicate.uge, call_counter, c(2, index)) + with mgpu.when(arith.andi(is_leader_of(MMA_WARP), reuses_tmem)): + tmem_done_barrier[acc_slot].wait(orders_tensor_core=True) + with mgpu.when(arith.andi(is_leader_of(MMA_WARP), is_leader_block)): + @mgpu.fori(c(k_loop_iter, index), arith.constant(i1, 0)) + def _mma_body(ki, accumulate): + slot = arith.remui(ki, c(max_concurrent_steps, index)) + ab_full_barriers[slot].wait() + tcgen05.mma( + acc_slice, + mgpu.memref_slice(a_smem, slot), + mgpu.memref_transpose(mgpu.memref_slice(b_smem, slot), (1, 0, 3, 2)), + a_swizzle=swizzle, + b_swizzle=swizzle, + accumulate=accumulate, + collective=collective, + ) + accumulate = arith.constant(i1, 1) + tcgen05.commit_arrive(ab_empty_barriers[slot], collective=collective, ctx=ctx) + is_last_iter = arith.cmpi( + arith.CmpIPredicate.eq, ki, c(k_loop_iter - 1, index) + ) + with mgpu.when(is_last_iter): + tcgen05.commit_arrive(mma_done_barrier[acc_slot], collective=collective, ctx=ctx) + return accumulate - gpu.barrier() - mma_done_barrier.wait(for_tensor_core=True) + with mgpu.when(is_store_warpgroup): + mma_done_barrier[acc_slot].wait(orders_tensor_core=True) + final_acc = acc_slice.load().astype(mlir.dtype_to_ir_type(jnp.dtype(dtype))) + assert tile_n % epilogue_tile_n == 0 + for ni in range(tile_n // epilogue_tile_n): + n_slice = ds(ni * epilogue_tile_n, epilogue_tile_n) + final_acc[:, n_slice].store_tiled(d_smem, swizzle=128) + # We store the first tile before arriving to reduce register pressure. + mgpu.commit_shared() + store_n_start = arith.addi(n_start, c(ni * epilogue_tile_n, index)) + ctx.async_copy( + src_ref=d_smem, + dst_ref=d, + gmem_slice=( + ds(block_m_start, block_tile_m), + ds(store_n_start, epilogue_tile_n), + ), + gmem_transform=mgpu.TileTransform((128, swizzle_elems)), + swizzle=128, + ) + ctx.await_async_copy(0, await_read_only=True) + tmem_done_barrier[acc_slot].arrive(orders_tensor_core=True) - acc[:].astype(ir.F16Type.get()).store_tiled(d_smem, swizzle=128) - mgpu.commit_shared() - ctx.async_copy( - src_ref=d_smem, - dst_ref=d, - gmem_slice=(ds(block_m_start, block_tile_m), ds(n_start, tile_n)), - gmem_transform=mgpu.TileTransform((128, swizzle_elems)), - swizzle=swizzle, + # We statically assign the tiles to SMs. + logical_grid_size = math.prod(logical_grid) + sm_id = gpu.block_id(gpu.Dimension.x) + extra_step = arith.cmpi( + arith.CmpIPredicate.slt, sm_id, c(logical_grid_size % num_sms, index) + ) # Some SMs do an extra step when grid size isn't divisible by SM count. + mn_steps = arith.addi( + mgpu.c(logical_grid_size // num_sms, index), + arith.index_castui(index, extra_step), ) - ctx.await_async_copy(0) + + @mgpu.fori(mn_steps, None) + def _mn_loop(local_mn_step, _): + global_mn_step = arith.addi( + sm_id, arith.muli(local_mn_step, mgpu.c(num_sms, index)) + ) + logical_idxs = [] + for dim_size in logical_grid: + logical_idxs.append(arith.remui(global_mn_step, mgpu.c(dim_size, index))) + global_mn_step = arith.divui(global_mn_step, mgpu.c(dim_size, index)) + lx, ly, lz = logical_idxs + m_idx = arith.addi(lx, arith.muli(lz, c(grid_tile_m, index))) + n_idx = ly + + block_m_start = arith.muli(m_idx, c(block_tile_m, index)) + n_start = arith.muli(n_idx, c(tile_n,index)) + compute_output(block_m_start, n_start, local_mn_step) compute_buffers = ( jax.ShapeDtypeStruct( mgpu.tile_shape((max_concurrent_steps, block_tile_m, tile_k), tiling), - jnp.float16), + dtype), jax.ShapeDtypeStruct( - mgpu.tile_shape((max_concurrent_steps, block_tile_n, tile_k), tiling), - jnp.float16), + mgpu.tile_shape((max_concurrent_steps, block_tile_n, tile_k), tiling), + dtype), ) + epilogue_tile_n = 64 epilogue_buffer = jax.ShapeDtypeStruct( - mgpu.tile_shape((block_tile_m, tile_n), (128, swizzle_elems)), - jnp.float16) - smem_buffers = mgpu.Union([compute_buffers, epilogue_buffer]) + mgpu.tile_shape((block_tile_m, epilogue_tile_n), (128, swizzle_elems)), + dtype) + smem_buffers = [compute_buffers, epilogue_buffer] smem = ( smem_buffers, [mgpu.Barrier(arrival_count=1, num_barriers=max_concurrent_steps)] * 2, - mgpu.Barrier(arrival_count=1), - mgpu.TMEM((128, tile_n), jnp.float32, collective=collective), + mgpu.Barrier(arrival_count=1, num_barriers=2), + mgpu.ClusterBarrier(collective_dims=(gpu.Dimension.x,), num_barriers=2), + mgpu.TMEM((128, 2 * tile_n), jnp.float32, collective=collective), ) + num_sms = 148 return mgpu.as_gpu_kernel( kernel, - (grid_tile_m, n // tile_n, m // (block_tile_m * grid_tile_m)), - (128, 1, 1), + (num_sms, 1, 1), # This is a persistent kernel. + (2 * 128, 1, 1), ( - jax.ShapeDtypeStruct((m, k), jnp.float16), - jax.ShapeDtypeStruct((n, k), jnp.float16), + jax.ShapeDtypeStruct((m, k), dtype), + jax.ShapeDtypeStruct((n, k), dtype), ), - jax.ShapeDtypeStruct((m, n), jnp.float16), + jax.ShapeDtypeStruct((m, n), dtype), smem, cluster=(2 if collective else 1, 1, 1), ) @@ -213,7 +273,7 @@ def main(unused_argv): b = jr.normal(key=kb, shape=(n, k), dtype=jnp.float16) tile_m = (128,) - tile_n = (128, 256, 512) + tile_n = (128, 256) max_concurrent_steps = (2, 4, 5, 6) grid_tile_m = (1, 2, 4, 8, 16) collective = (False, True) @@ -230,13 +290,13 @@ def main(unused_argv): tile_n *= 2 if m < tile_m or n < tile_n: continue - if tile_n > 512: + if 2 * tile_n > 512: continue if (m // tile_m) % kwargs["grid_tile_m"]: continue try: with mlir.make_ir_context(), ir.Location.unknown(): - f = build_kernel(m, n, k, **kwargs) + f = build_kernel(m, k, n, jnp.float16, **kwargs) _, runtime = profiler.measure(f)(a, b) except ValueError as e: if "Mosaic GPU kernel exceeds available shared memory" not in str(e): @@ -251,7 +311,7 @@ def main(unused_argv): raise ValueError("No valid configuration found") with mlir.make_ir_context(), ir.Location.unknown(): - d, runtime = profiler.measure(build_kernel(m, n, k, **best_kwargs))(a, b) + d, runtime = profiler.measure(build_kernel(m, k, n, jnp.float16, **best_kwargs))(a, b) d_ref, ref_runtime = profiler.measure(jax.jit(lambda a, b: a @ b.T))(a, b) tflops = float(2 * k * m * n) / (runtime / 1e3) / 1e12 diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 5daed8416589..4410f0fe5044 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -16,28 +16,27 @@ from __future__ import annotations +from collections.abc import Callable, Iterable, Sequence import dataclasses import functools +import itertools import math -from collections.abc import Callable -from typing import Iterable, Protocol, Sequence, TypeVar +from typing import Any, Protocol, TypeAlias, TypeVar, cast, overload -import itertools import jax +import jax.experimental.mosaic.gpu as mgpu from jaxlib.mlir import ir from jaxlib.mlir.dialects import arith from jaxlib.mlir.dialects import gpu from jaxlib.mlir.dialects import llvm +from jaxlib.mlir.dialects import nvvm from jaxlib.mlir.dialects import math as mlir_math from jaxlib.mlir.dialects import memref -from jaxlib.mlir.dialects import nvvm from jaxlib.mlir.dialects import vector import numpy as np -import jax.experimental.mosaic.gpu as mgpu from . import utils -# mypy: ignore-errors T = TypeVar("T") WARPGROUP_SIZE = utils.WARPGROUP_SIZE @@ -49,7 +48,7 @@ @dataclasses.dataclass(frozen=True) -class Tiling: +class TilingImpl: """A tiling expression describing a permutation of elements of an nd-array. To apply one level of tiling to an array, each of the trailing dimensions (up @@ -68,23 +67,24 @@ class Tiling: def __post_init__(self): if not self.tiles: return - tiled_rank = len(self.tiles[0]) + last_tile_rank = len(self.tiles[0]) for tile in self.tiles: - if len(tile) > tiled_rank: - raise ValueError("Only the first tile can refer to value dimensions") + if len(tile) > last_tile_rank: + raise ValueError("Tiles must have a decreasing rank") if not tile: raise ValueError("Tiles must not be empty") if any(d <= 0 for d in tile): raise ValueError(f"Tile shape must only have positive sizes, got: {self.tiles}") - tiled_rank += len(tile) + last_tile_rank = len(tile) def __str__(self): return f"Tiling({''.join(map(str, self.tiles))})" def tile_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: """Computes the shape of an array after tiling.""" + orig_shape = shape def fail(): - raise ValueError(f"Tiling {self.tiles} does not apply to shape {shape}") + raise ValueError(f"Tiling {self.tiles} does not apply to shape {orig_shape}") for tile in self.tiles: if len(tile) > len(shape): fail() @@ -96,9 +96,10 @@ def fail(): def untile_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: """Computes the shape of an array before tiling from its tiled shape.""" + orig_shape = shape def fail(): raise ValueError( - f"shape {shape} is not a valid result of applying tiling {self}." + f"shape {orig_shape} is not a valid result of applying tiling {self}." ) for tile in reversed(self.tiles): if len(tile) > len(shape): @@ -111,6 +112,43 @@ def fail(): shape = (*untiled_dims, *(d * t for d, t in zip(tiled_dims, tile))) return shape + def canonicalize(self) -> TilingImpl: + """Returns a canonicalized version of the tiling. + + We define a tiling to be canonical if, at each step (except the first one, + which defines the base tile shape): + + 1. The tiling partitions at least one dimension in more than 1 tile. For + example, the tiling `(8, 8)(8, 8)` is not canonical, as applying it + yields a shape `(1, 1, 8, 8)`. We canonicalize it to `(8, 8)`, which + allows getting rid of the unnecessary `1` dimensions. + 2. The leading dimensions of each tile are not `1`. If canonicalizing a + tile in this way leads to an empty tile, then the tile is given shape + `(1,)`---which is still a meaningful (final) tile. For example, the + tiling `(8, 8)(1, 4)` is not canonical, as applying it yields a shape + `(8, 2, 1, 4)`. We canonicalize it to `(8, 8)(4,)`, which allows + getting rid of the unnecessary `1` dimension, and yields a shape + `(8, 2, 4)`. + """ + if len(self.tiles) <= 1: + return self + + shape = self.tiles[0] + new_tiling = [self.tiles[0]] + for tile in self.tiles[1:]: + for i, d in enumerate(tile): + if d != 1: + canonical_tile = tile[i:] + break + else: + canonical_tile = (1,) + tiled_dims = shape[-len(canonical_tile):] + if tiled_dims == canonical_tile: + continue + shape = canonical_tile + new_tiling.append(canonical_tile) + return TilingImpl(tuple(new_tiling)) + def tile_strides(self, strides: tuple[int, ...]) -> tuple[int, ...]: """Computes the strides of an array after tiling.""" for tile in self.tiles: @@ -118,6 +156,34 @@ def tile_strides(self, strides: tuple[int, ...]) -> tuple[int, ...]: strides = (*untiled, *(s * t for s, t in zip(tiled, tile)), *tiled) return strides + def tile_dimension(self, dim: int) -> tuple[bool, ...]: + """Result is True whenever the tiled dim originated from the given input dim.""" + tiling_rank = len(self.tiles[0]) + if dim < 0 or dim >= tiling_rank: + raise ValueError(f"Invalid dimension {dim} for tiling {self}") + strides = [1] * tiling_rank + strides[dim] = 0 + return tuple(s == 0 for s in self.tile_strides(tuple(strides))) + + def remove_dimension(self, dim: int) -> TilingImpl: + """Returns a tiling with the given dimension removed.""" + tiling_rank = len(self.tiles[0]) + if dim < 0 or dim >= tiling_rank: + raise ValueError(f"Invalid dimension {dim} for tiling {self}") + dim_in_tile = dim + tiles = [] + last_tile_rank = len(self.tiles[0]) + for t in self.tiles: + assert last_tile_rank >= len(t) + dim_in_tile -= last_tile_rank - len(t) + last_tile_rank = len(t) + if dim_in_tile >= 0: + t = t[:dim_in_tile] + t[dim_in_tile + 1:] + if not t: # If this tile is empty, all other tiles will be empty too. + break + tiles.append(t) + return TilingImpl(tuple(tiles)) + def tile_nested_shape_strides( self, shape: tuple[tuple[int, ...], ...], @@ -170,8 +236,8 @@ def fail_if(cond, shape=shape): # Capture shape now. minor_dim_shapes.append(minor_dim_shape_rev[::-1]) major_dim_strides.append(major_dim_stride_rev[::-1]) minor_dim_strides.append(minor_dim_stride_rev[::-1]) - shape = (*untiled_shape, *major_dim_shapes, *minor_dim_shapes) - strides = (*untiled_strides, *major_dim_strides, *minor_dim_strides) + shape = (*untiled_shape, *major_dim_shapes, *minor_dim_shapes) # type: ignore[arg-type] + strides = (*untiled_strides, *major_dim_strides, *minor_dim_strides) # type: ignore[arg-type] return ( tuple(tuple(d) if d else (1,) for d in shape), tuple(tuple(d) if d else (1,) for d in strides), @@ -195,6 +261,14 @@ def untile_indices(self, indices: tuple[int, ...]) -> tuple[int, ...]: indices = (*untiled, *(o * t + i for o, i, t in zip(outer, inner, tile))) return indices +# TODO(olechwierowicz): Clean up this once C++ Tiling is always available in JAX build. +Tiling: Any +if hasattr(mgpu.dialect, "Tiling"): + Tiling = mgpu.dialect.Tiling +else: + Tiling = TilingImpl + + def enumerate_negative(elems: Sequence[T]) -> Iterable[tuple[int, T]]: """Like built-in enumerate, but returns negative indices into the sequence.""" offset = len(elems) @@ -202,6 +276,35 @@ def enumerate_negative(elems: Sequence[T]) -> Iterable[tuple[int, T]]: yield i - offset, e +@dataclasses.dataclass(frozen=True) +class ReplicatedImpl: + times: int + + +# TODO(olechwierowicz): Clean up this once C++ Replicated is always available in JAX build. +Replicated: Any +if hasattr(mgpu.dialect, "Replicated"): + Replicated = mgpu.dialect.Replicated +else: + Replicated = ReplicatedImpl + + +def cc_method_exists(self, method_name: str): + return hasattr(mgpu.dialect, self.__class__.__name__) and hasattr( + getattr(mgpu.dialect, self.__class__.__name__), method_name + ) + + +def dispatch_to_cc_method(self, method_name: str, extract_args_fun, *args, **kwargs): + """Dispatches a method call to the corresponding C++ method.""" + cls = getattr(mgpu.dialect, self.__class__.__name__) + instance = cls(*extract_args_fun(self)) + attr = getattr(instance, method_name) + if not callable(attr): + return attr + return attr(*args, **kwargs) + + @dataclasses.dataclass(frozen=True) class TiledLayout: """A FragmentedArray layout derived from a tiling expression. @@ -211,16 +314,13 @@ class TiledLayout: the dimension indices. All dimension indices must be negative and should refer to the dimensions after tiling is applied. - Note that warp_dim and vector_dim could be sets as well, but we don't have a - usecase for that yet. - To better understand this layout, consider the example of WGMMA-related tiling from https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-d as applied to a 128x128 array. The corresponding TiledLayout has a tiling of: (64, 8)(16, 8)(8, 8)(1, 2) - and warp_dim=-8, lane_dims=(-4, -3), vector_dim=-1. + and warp_dims=(-8,), lane_dims=(-4, -3), vector_dim=-1. We begin by applying the tiling (note that it always applies to a suffix): @@ -232,8 +332,8 @@ class TiledLayout: 2 16 4 1 2 1 8 8 (1, 2) 2 16 4 1 2 1 8 4 1 2 - The last expression is our final shape. At this stage, we're ready to - interpret the dimensions: warp_dim=-8 means that the 8-th dimension from the + The last expression is our final shape. At this stage, we're ready to partition + the dimensions: warp_dims=(-8,) means that the 8-th dimension from the end is partitioned over 4 warps in a warpgroup (and so it must be of size 4). lane_dims=(-4, -3) indicate that those two dimensions are partitioned over the lanes within a warp (their product must be equal to 32, i.e. warp size). @@ -247,34 +347,72 @@ class TiledLayout: by a single (logical) register. """ tiling: Tiling - warp_dim: int - lane_dims: tuple[int, ...] # major-to-minor + warp_dims: tuple[int | Replicated, ...] # major-to-minor + lane_dims: tuple[int | Replicated, ...] # major-to-minor vector_dim: int + # Whether to enforce that the layout is canonical. Users of `TiledLayout` + # should not set this to `False`, but it is helpful to be able to construct + # non-canonical layouts as an intermediate state when implementing layout + # transformations. + _check_canonical: dataclasses.InitVar[bool] = True - def __post_init__(self): + def __post_init__(self, _check_canonical: bool): if not self.tiling.tiles: raise ValueError("Tiling must have at least one tile") min_shape = self.tiling.tiles[0] min_tiled_shape = self.tiling.tile_shape(min_shape) - dims_set = {self.warp_dim, *self.lane_dims, self.vector_dim} - if len(dims_set) != len(self.lane_dims) + 2: - raise ValueError + dims_set = { + *self.partitioned_warp_dims, *self.partitioned_lane_dims, self.vector_dim, + } + if len(dims_set) != len(self.partitioned_warp_dims) + len(self.partitioned_lane_dims) + 1: + raise ValueError("Duplicate partitioning dimensions") for d in dims_set: if d >= 0: raise ValueError("All dimensions must be negative") if d < -(len(min_tiled_shape) - len(min_shape)): raise ValueError("Dimension out of range") - if min_tiled_shape[self.warp_dim] != WARPS_IN_WARPGROUP: - raise ValueError - if math.prod(min_tiled_shape[d] for d in self.lane_dims) != WARP_SIZE: - raise ValueError + warp_dims_prod = math.prod( + d.times if isinstance(d, Replicated) else min_tiled_shape[d] + for d in self.warp_dims + ) + if warp_dims_prod != WARPS_IN_WARPGROUP: + raise ValueError( + "The product of warp dims does not equal the number of warps in a" + " warpgroup" + ) + lane_dims_prod = math.prod( + d.times if isinstance(d, Replicated) else min_tiled_shape[d] + for d in self.lane_dims + ) + if lane_dims_prod != WARP_SIZE: + raise ValueError("The product of lane dims does not equal the warp size") + if _check_canonical: + canonical_layout = self.canonicalize() + if self != canonical_layout: + raise ValueError(f"{self} is not canonical.") + + @functools.cached_property + def partitioned_warp_dims(self) -> tuple[int, ...]: + if cc_method_exists(self, "partitioned_warp_dims"): + return self.dispatch_to_cc("partitioned_warp_dims", check_canonical=False) + return tuple( + d for d in self.warp_dims if not isinstance(d, Replicated) + ) + + @functools.cached_property + def partitioned_lane_dims(self) -> tuple[int, ...]: + if cc_method_exists(self, "partitioned_lane_dims"): + return self.dispatch_to_cc("partitioned_lane_dims", check_canonical=False) + return tuple( + d for d in self.lane_dims if not isinstance(d, Replicated) + ) def thread_idxs(self, shape: tuple[int, ...]) -> Iterable[tuple[ir.Value, ...]]: # We first find the linear index and then divide by the shape to # get the index. i32 = ir.IntegerType.get_signless(32) index = ir.IndexType.get() - contig_strides = utils.get_contiguous_strides(shape) + contig_strides = tuple(utils.get_contiguous_strides(shape)) tile_strides = self.tiling.tile_strides(contig_strides) dyn_tile_strides = [c(s, i32) for s in tile_strides[-self.tiled_tiling_rank:]] warp_offset = utils.dyn_dot(self.warp_indices(), dyn_tile_strides) @@ -291,7 +429,7 @@ def thread_idxs(self, shape: tuple[int, ...]) -> Iterable[tuple[ir.Value, ...]]: yield tuple(idx) @property - def base_tile_shape(self) -> int: + def base_tile_shape(self) -> tuple[int, ...]: """The shape of the first tile in the tiling expression. This tile acts as the divisibility constraint for a suffix of arrays to @@ -308,6 +446,8 @@ def tiled_tiling_shape(self) -> tuple[int, ...]: so the tiled shape always ends with this suffix, no matter what array shape it's applied to. """ + if cc_method_exists(self, "tiled_tiling_shape"): + return self.dispatch_to_cc("tiled_tiling_shape", check_canonical=False) base_tile_shape = self.base_tile_shape return self.tiling.tile_shape(base_tile_shape)[len(base_tile_shape):] @@ -317,13 +457,19 @@ def tiled_tiling_rank(self) -> int: @property def vector_length(self) -> int: + if cc_method_exists(self, "vector_length"): + return self.dispatch_to_cc("vector_length", check_canonical=False) return self.tiled_tiling_shape[self.vector_dim] + def registers_element_type(self, t: ir.Type) -> ir.Type: + return ir.VectorType.get((self.vector_length,), t) + def registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: """Returns the shape of the register array needed to represent an array of the given logical shape.""" tiled_shape = list(self.tiling.tile_shape(shape)) - tiled_shape[self.warp_dim] = 1 - for d in self.lane_dims: + for d in self.partitioned_warp_dims: + tiled_shape[d] = 1 + for d in self.partitioned_lane_dims: tiled_shape[d] = 1 tiled_shape[self.vector_dim] = 1 return tuple(tiled_shape) @@ -335,38 +481,167 @@ def shape_from_registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: """ tiled_tiling = self.tiled_tiling_shape shape = list(shape) - shape[self.warp_dim] = WARPS_IN_WARPGROUP - for d in self.lane_dims: + for d in self.partitioned_warp_dims: + shape[d] = tiled_tiling[d] + for d in self.partitioned_lane_dims: shape[d] = tiled_tiling[d] shape[self.vector_dim] = tiled_tiling[self.vector_dim] return self.tiling.untile_shape(tuple(shape)) - def lane_indices(self) -> tuple[ir.Value, ...]: + def _delinearize_index( + self, idx: ir.Value, dims: tuple[int | Replicated, ...] + ) -> tuple[ir.Value, ...]: i32 = ir.IntegerType.get_signless(32) tiled_shape = self.tiled_tiling_shape - lanes_shape = tuple(tiled_shape[d] for d in self.lane_dims) - assert math.prod(lanes_shape) == WARP_SIZE - lane_strides = utils.get_contiguous_strides(lanes_shape) - lane_idx = arith.remui(utils.thread_idx(), c(WARP_SIZE, i32)) - lane_indices = tuple( - arith.remui(arith.divui(lane_idx, c(stride, i32)), c(size, i32)) - for stride, size in zip(lane_strides, lanes_shape) + dims_shape = tuple( + d.times if isinstance(d, Replicated) else tiled_shape[d] + for d in dims + ) + dims_strides = utils.get_contiguous_strides(dims_shape) + dims_indices = tuple( + arith.remui(arith.divui(idx, c(stride, i32)), c(size, i32)) + for stride, size in zip(dims_strides, dims_shape) ) full_indices = [arith.constant(i32, 0)] * len(tiled_shape) - for d, i in zip(self.lane_dims, lane_indices): + for d, i in zip(dims, dims_indices): + if isinstance(d, Replicated): + continue full_indices[d] = i return tuple(full_indices) + def lane_indices(self) -> tuple[ir.Value, ...]: + i32 = ir.IntegerType.get_signless(32) + lane_idx = arith.remui(utils.thread_idx(), c(WARP_SIZE, i32)) + return self._delinearize_index(lane_idx, self.lane_dims) + def warp_indices(self) -> tuple[ir.Value, ...]: i32 = ir.IntegerType.get_signless(32) - tiled_shape_rank = len(self.tiled_tiling_shape) warp_idx = arith.remui( arith.divui(utils.thread_idx(), c(WARP_SIZE, i32)), c(WARPS_IN_WARPGROUP, i32), ) - indices = [arith.constant(i32, 0)] * tiled_shape_rank - indices[self.warp_dim] = warp_idx - return tuple(indices) + return self._delinearize_index(warp_idx, self.warp_dims) + + def remove_dimension(self, dim: int) -> TiledLayout: + if dim < 0 or dim >= len(self.tiling.tiles[0]): + raise ValueError(f"Dimension {dim} is out of range for {self.tiling}") + new_tiling = self.tiling.remove_dimension(dim) + tiled_shape = self.tiled_tiling_shape + removed_dim = self.tiling.tile_dimension(dim) + dim_offsets = np.cumsum(removed_dim[::-1])[::-1].tolist() + if removed_dim[self.vector_dim]: + new_tiling = Tiling((*new_tiling.tiles, (1,))) + new_vector_dim = -1 + dim_offsets = [o - 1 for o in dim_offsets] # We inserted an extra dim. + else: + new_vector_dim = self.vector_dim + dim_offsets[self.vector_dim] + def replace_tiled_dim(d: int | Replicated, size: int): + if isinstance(d, Replicated): + return d + elif removed_dim[d]: + return Replicated(size) + else: + return d + dim_offsets[d] + return TiledLayout( + new_tiling, + tuple( + d if isinstance(d, Replicated) else replace_tiled_dim(d, tiled_shape[d]) + for d in self.warp_dims + ), + tuple( + d if isinstance(d, Replicated) else replace_tiled_dim(d, tiled_shape[d]) + for d in self.lane_dims + ), + new_vector_dim, + _check_canonical=False, + ).canonicalize() + + def reduce(self, axes: Sequence[int]) -> TiledLayout: + reduced_layout = self + for a in sorted(axes, reverse=True): + reduced_layout = reduced_layout.remove_dimension(a) + return reduced_layout + + def canonicalize(self) -> TiledLayout: + """Returns a version of this layout where tiling is canonical.""" + if cc_method_exists(self, "canonicalize"): + c_layout = self.dispatch_to_cc("canonicalize", check_canonical=False) + return TiledLayout( + tiling=c_layout.tiling, + warp_dims=c_layout.warp_dims, + lane_dims=c_layout.lane_dims, + vector_dim=c_layout.vector_dim, + _check_canonical=False + ) + + canonical_tiling = self.tiling.canonicalize() + + s = self.base_tile_shape + tiled_tiling_shape = self.tiled_tiling_shape + canonical_tiled_tiling_shape = canonical_tiling.tile_shape(s)[len(s):] + offset = len(canonical_tiled_tiling_shape) - 1 + + rev_removed_dims = [] + # Iterate starting from the end in order to eliminate leading dimensions, + # whenever possible. For instance, say we have + # + # shape=(4, 32, 1, 1, 1, 1, 1) + # warp_dims=(-7,), + # lane_dims=(-6,) + # vector_dim=-1 + # + # and we want to canonicalize this to + # + # shape=(4, 32, 1) + # warp_dims=(-3,), + # lane_dims=(-2,) + # vector_dim=-1. + # + # After the loop below, we end up with + # + # rev_removed_dims=[False, True, True, True, True, False, False] + # + # which will yield offsets `4` for `warp_dims[0]`, `4` for `lane_dims[0]`, + # and `0` for `vector_dim`. + for d in reversed(tiled_tiling_shape): + if offset >= 0 and d == canonical_tiled_tiling_shape[offset]: + rev_removed_dims.append(False) + offset -= 1 + else: + rev_removed_dims.append(True) + assert offset == -1 + + dim_offsets = np.cumsum(rev_removed_dims)[::-1].tolist() + + def replace_tiled_dim(d: int | Replicated): + return d if isinstance(d, Replicated) else d + dim_offsets[d] + + def is_nontrivial(d: int | Replicated): + return isinstance(d, Replicated) or tiled_tiling_shape[d] != 1 + + return TiledLayout( + canonical_tiling, + tuple(replace_tiled_dim(d) for d in self.warp_dims if is_nontrivial(d)), + tuple(replace_tiled_dim(d) for d in self.lane_dims if is_nontrivial(d)), + replace_tiled_dim(self.vector_dim), + _check_canonical=False, + ) + + def dispatch_to_cc(self, method_name: str, *args, **kwargs): + check_canonical = kwargs.pop("check_canonical", True) + return dispatch_to_cc_method( + self, + method_name, + lambda inst: [ + inst.tiling, + inst.warp_dims, + inst.lane_dims, + inst.vector_dim, + check_canonical + ], + *args, + **kwargs, + ) def _tiled_wgmma_layout(shape: tuple[int, ...]): @@ -382,28 +657,6 @@ def _tiled_wgmma_layout(shape: tuple[int, ...]): return WGMMA_LAYOUT -@dataclasses.dataclass(frozen=True) -class WGMMARowFragLayout: - """[m] matrix, where m % 64 == 0.""" - - def thread_idxs(self, shape): - index = ir.IndexType.get() - assert len(shape) == 1 - assert shape[0] % 64 == 0 - tid = arith.index_cast(ir.IndexType.get(), mgpu.thread_idx()) - tid_wg = arith.remui(tid, c(WARPGROUP_SIZE, index)) - warp_idx = arith.divui(tid_wg, c(32, index)) - lane_id = arith.remui(tid_wg, c(32, index)) - row_base = arith.addi( - arith.divui(lane_id, c(4, index)), arith.muli(warp_idx, c(16, index)) - ) - - for row_group in range(0, shape[0], 64): - for row_subgroup in (0, 8): - row = arith.addi(row_base, c(row_group + row_subgroup, index)) - yield (row,) - - @dataclasses.dataclass(frozen=True) class WGSplatFragLayout: """A fragmented array where all the values are equal represented as a register per thread. @@ -430,10 +683,28 @@ class WGSplatFragLayout: def can_broadcast_to(self, shape) -> bool: """Check that the shape can be broadcast. - Only dimensions of size 1 can be broadcast. All other dimensions - must be the same as the argument shape. + All source dimensions must match the target's trailing dimensions by + equality or being set to 1 (i.e. we can broadcast 1-sized dimensions or + create new leading dimensions). """ - return all(dim1 == dim2 or dim1 == 1 for dim1, dim2 in zip(self.shape[::-1], shape[::-1])) + return len(self.shape) <= len(shape) and all( + dim1 == dim2 or dim1 == 1 + for dim1, dim2 in zip(self.shape[::-1], shape[::-1]) + ) + + def registers_element_type(self, t: ir.Type) -> ir.Type: + return t + + def registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: + """Returns the shape of the register array needed to represent an array of the given logical shape.""" + del shape # Unused. + return () + + def shape_from_registers_shape( + self, shape: tuple[int, ...] + ) -> tuple[int, ...]: + del shape # Unused. + return self.shape def thread_idxs(self, shape): assert shape == self.shape @@ -452,23 +723,41 @@ def __post_init__(self): raise ValueError((self, WARPGROUP_SIZE)) @classmethod - def from_shaped_type(cls, shaped_ty: ir.Type): - if not ir.ShapedType.isinstance(shaped_ty): + def from_shaped_type(cls, shaped_ty: ir.Type) -> WGStridedFragLayout | None: + """Returns a WGStridedFragLayout for the given shaped type. + + Return None if the shaped type cannot have a strided layout. + """ + if not isinstance(shaped_ty, ir.ShapedType): raise TypeError(shaped_ty) shaped_ty = ir.ShapedType(shaped_ty) - bw = mgpu.bytewidth(shaped_ty.element_type) + if (bitwidth := mgpu.bitwidth(shaped_ty.element_type)) % 8: + return None + bw = bitwidth // 8 assert 8 % bw == 0 and 8 // bw != 0, bw if math.prod(shaped_ty.shape) % WARPGROUP_SIZE != 0: - raise ValueError( - f"{shaped_ty} must have a number of elements that is a multiple of" - f" {WARPGROUP_SIZE} (got {math.prod(shaped_ty.shape)})" - ) + return None max_vec_size = np.prod(shaped_ty.shape) // WARPGROUP_SIZE return cls( shape=tuple(shaped_ty.shape), vec_size=min(8 // bw, max_vec_size) ) + def registers_element_type(self, t: ir.Type) -> ir.Type: + return ir.VectorType.get((self.vec_size,), t) + + def registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: + """Returns the shape of the register array needed to represent an array of the given logical shape.""" + if shape != self.shape: + raise ValueError(f"Shape {shape} is not compatible with {self}") + return (math.prod(self.shape) // (WARPGROUP_SIZE * self.vec_size),) + + def shape_from_registers_shape( + self, shape: tuple[int, ...] + ) -> tuple[int, ...]: + del shape # Unused. + return self.shape + def thread_idxs(self, shape): assert shape == self.shape index = ir.IndexType.get() @@ -496,15 +785,25 @@ def linear_thread_idxs(self): for i in range(reg_num): yield arith.addi(off, c(i * WARPGROUP_SIZE * self.vec_size, tidx.type)) - -FragmentedLayout = WGSplatFragLayout | WGStridedFragLayout | WGMMARowFragLayout | TiledLayout +FragmentedLayout = WGSplatFragLayout | WGStridedFragLayout | TiledLayout -WGMMA_ROW_LAYOUT = WGMMARowFragLayout() +WGMMA_COL_LAYOUT = TiledLayout( + Tiling(((8,), (2,))), + warp_dims=(Replicated(4),), + lane_dims=(Replicated(8), -2), + vector_dim=-1, +) +WGMMA_ROW_LAYOUT = TiledLayout( + Tiling(((64,), (16,), (8,), (1,))), + warp_dims=(-4,), + lane_dims=(-2, Replicated(4)), + vector_dim=-1, +) # The tiled layout is equivalent to one described here in PTX documentation: # https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-d -# In this layout, we partition the 64x8 tiles over 4 warpgroups into 16x8 tiles. +# In this layout, we partition the 64x8 tiles over 4 warps into 16x8 tiles. # Then, we further split the 16x8 tiles into 8x8 submatrices which are the unit # of data that is split across a warp. Since 8*8 = 64, but a warp has only 32 # threads, we vectorize pairs of elements along columns. @@ -516,11 +815,44 @@ def linear_thread_idxs(self): # 12 12 13 13 14 14 15 15 # ... WGMMA_LAYOUT = TiledLayout( - Tiling(((64, 8), (16, 8), (8, 8), (1, 2))), - warp_dim=-8, + Tiling(((64, 8), (16, 8), (8, 8), (2,))), + warp_dims=(-7,), + lane_dims=(-3, -2), + vector_dim=-1, +) +# This is the same as WGMMA_LAYOUT, only with a vector length of 1. LLVM now +# treats <2 x float> as a native PTX type and uses 64-bit registers to store +# them. This, in turn, means that we have to explode them into 32-bit registers +# right before WGMMA, which makes ptxas very unhappy and causes it to insert +# lots of WGMMA waits that absolutely tank the performance. As a workaround, +# we use this layout when 32-bit data with WGMMA_LAYOUT is used to initialize +# a WGMMAAccumulator, to ensure that the LLVM accumulator registers will always +# be represented as 32-bit PTX registers. +WGMMA_LAYOUT_ACC_32BIT = TiledLayout( + Tiling(((64, 8), (16, 8), (8, 8), (2,), (1,))), + warp_dims=(-8,), lane_dims=(-4, -3), vector_dim=-1, ) +# The tiled layout is equivalent to one described here in PTX documentation: +# https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n32-a +# In this layout, we partition the 64x16 tiles over 4 warps into 16x16 tiles. +# Then, we further split the 16x16 tiles into 8x16 submatrices which are the unit +# of data that is split across a warp. Since 8*16 = 128, but a warp has only 32 +# threads, we vectorize quadruplets of elements along columns. +# The assignment of elements to warp lanes is as follows: +# +# 0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3 +# 4 4 4 4 5 5 5 5 6 6 6 6 7 7 7 7 +# 8 8 8 8 9 9 9 9 10 10 10 10 11 11 11 11 +# 12 12 12 12 13 13 13 13 14 14 14 14 15 15 15 15 +# ... +WGMMA_LAYOUT_8BIT = TiledLayout( + Tiling(((64, 16), (16, 16), (8, 16), (4,))), + warp_dims=(-7,), + lane_dims=(-3, -2), + vector_dim=-1, +) # This tiled layout is similar to the WGMMA layout, only the unit at which we # assign submatrices to warps grows from 8x8 to 8x16. The elements within each # submatrix are assigned to threads in the following way: @@ -536,7 +868,7 @@ def linear_thread_idxs(self): # only requires a single warp shuffle (plus permutes local to each thread). WGMMA_LAYOUT_UPCAST_2X = TiledLayout( Tiling(((64, 16), (16, 16), (8, 16), (8,), (4,))), - warp_dim=-8, + warp_dims=(-8,), lane_dims=(-4, -2, -3), vector_dim=-1, ) @@ -548,7 +880,7 @@ def linear_thread_idxs(self): # 5 and 6, etc. that WGMMA_LAYOUT_UPCAST_2X does). WGMMA_LAYOUT_UPCAST_4X = TiledLayout( Tiling(((64, 32), (16, 32), (8, 32), (8,))), - warp_dim=-7, + warp_dims=(-7,), lane_dims=(-3, -2), vector_dim=-1, ) @@ -570,18 +902,89 @@ def linear_thread_idxs(self): # ... # # You can see that we have taken 2x2 submatrices from the above layout and -# transposed them. The assigment of lanes to elements is such that in both +# transposed them. The assignment of lanes to elements is such that in both # layouts the same two lanes map to a single 2x2 submatrix, making the transpose # very cheap (one shuffle and permute suffices to change between those layouts). WGMMA_TRANSPOSED_LAYOUT = TiledLayout( Tiling(((64, 8), (16, 8), (8, 8), (2, 2), (2, 1))), - warp_dim=-10, + warp_dims=(-10,), + lane_dims=(-6, -3, -5), + vector_dim=-2, +) + +# Like WGMMA_LAYOUT, only each warp holds a 32xN strip instead of 16xN. +TCGEN05_LAYOUT = TiledLayout( + Tiling(((128, 8), (32, 8), (8, 8), (2,))), + warp_dims=(-7,), + lane_dims=(-3, -2), + vector_dim=-1, +) +# Like WGMMA_TRANSPOSED_LAYOUT, only each warp holds a 32xN strip instead of 16xN. +TCGEN05_TRANSPOSED_LAYOUT = TiledLayout( + Tiling(((128, 8), (32, 8), (8, 8), (2, 2), (2, 1))), + warp_dims=(-10,), lane_dims=(-6, -3, -5), vector_dim=-2, ) +# TCGEN05_ROW_LAYOUT is to TCGEN05_LAYOUT as WGMMA_ROW_LAYOUT is to +# WGMMA_LAYOUT. +TCGEN05_ROW_LAYOUT = TiledLayout( + Tiling(tiles=((128,), (32,), (8,), (1,))), + warp_dims=(-4,), + lane_dims=(-2, Replicated(times=4)), + vector_dim=-1, +) +# TCGEN05_COL_LAYOUT is to TCGEN05_LAYOUT as WGMMA_COL_LAYOUT is to +# WGMMA_LAYOUT. +TCGEN05_COL_LAYOUT = TiledLayout( + Tiling(tiles=((8,), (2,))), + warp_dims=(Replicated(times=4),), + lane_dims=(Replicated(times=8), -2), + vector_dim=-1, +) + + +def tmem_native_layout(vector_length: int): + """A layout resembling the logical organization of TMEM. + + The 128 rows in a tile are assigned to 128 lanes in the warpgroup. Useful when + the result needs to be processed in registers and then stored back into TMEM. + Usually shouldn't be used if the result is to be written back to SMEM, as + there is no good way to store it without bank conflicts, but it still + sometimes pays off. + """ + return TiledLayout( + Tiling(((128, vector_length), (32, vector_length))), + warp_dims=(-4,), + lane_dims=(-2,), + vector_dim=-1, + ) + +# We use a vector_dim of 2, to be able to make sure that the vectors are always +# a multiple of 32-bits, even when the data is 16-bits. +TMEM_NATIVE_LAYOUT = tmem_native_layout(2) + +# A layout for the row indices used by TMA gather4/scatter4 instructions. +# Index 00 01 02 03 04 05 06 07 08 09 10 11 12 13 14 15 16 17 ... +# Warp <--- 0 ---> <--- 1 ---> <--- 2 ---> <--- 3 ---> <--- 0 -- +TMA_GATHER_INDICES_LAYOUT = TiledLayout( + Tiling(((16,), (4,))), + warp_dims=(-2,), + lane_dims=(Replicated(32),), + vector_dim=-1, +) + + +def can_relayout_wgmma_4x_to_wgmma_2x(bitwidth: int) -> bool: + return bitwidth == 4 + + +def can_relayout_wgmma_2x_to_wgmma(bitwidth: int) -> bool: + return bitwidth <= 16 + @jax.tree_util.register_pytree_node_class -@dataclasses.dataclass(init=False, eq=False, frozen=True, slots=True) +@dataclasses.dataclass(init=False, frozen=True, slots=True) class FragmentedArray: # An array of ir.Value, see checks in init for shapes. registers: np.ndarray = dataclasses.field(repr=False) @@ -605,19 +1008,13 @@ def __init__( object.__setattr__(self, "layout", _layout) object.__setattr__(self, "is_signed", _is_signed) - if (_is_signed is not None) != ir.IntegerType.isinstance(self.mlir_dtype): + if (_is_signed is not None) != isinstance(self.mlir_dtype, ir.IntegerType): raise TypeError( "is_signed must be non-None if and only if the MLIR type is an" f" integer type, got {_is_signed=} for {self.mlir_dtype}" ) match self.layout: - # Registers are [m_tiles, 2 rows] in WGMMA_ROW layout - # Each element is a dtype scalar - case WGMMARowFragLayout(): - if _registers.ndim != 2 or _registers.shape[-1] != 2: - raise ValueError(f"Invalid register array shape: {_registers.shape}") - # Registers are flat case WGStridedFragLayout(shape): [reg_size] = ir.VectorType(_registers.flat[0].type).shape @@ -626,8 +1023,8 @@ def __init__( != math.prod(_registers.shape) * WARPGROUP_SIZE * reg_size ): raise ValueError( - "Invalid register array shape: math.prod({_registers.shape}) *" - " {WARPGROUP_SIZE} * {reg_size}, want: math.prod({shape})" + f"Invalid register array shape: math.prod({_registers.shape}) *" + f" {WARPGROUP_SIZE} * {reg_size}, want: math.prod({shape})" ) # Just a single register @@ -653,90 +1050,83 @@ def load_strided( *, is_signed: bool | None = None, vec_size: int | None = None, - ): - if not ir.MemRefType.isinstance(ref.type): + ) -> FragmentedArray: + if not isinstance(ref.type, ir.MemRefType): raise TypeError(ref.type) ref_ty = ir.MemRefType(ref.type) shape = tuple(ref_ty.shape) if vec_size is None: layout = WGStridedFragLayout.from_shaped_type(ref_ty) + if layout is None: + raise ValueError( + f"{ref_ty} must have a number of elements that is a multiple of" + f" {WARPGROUP_SIZE} (got {math.prod(shape)})" + ) else: layout = WGStridedFragLayout(shape=shape, vec_size=vec_size) + registers = np.empty(layout.registers_shape(shape), dtype=object) vec_ty = ir.VectorType.get((layout.vec_size,), ref_ty.element_type) - try: - # Flattening the reference potentially produces simpler PTX but - # if the ref is not already 1D and has strided dimensions - # flattening won't work. - ref_ = mgpu.memref_fold(ref, 0, len(ref_ty.shape)) - vecs = [vector.load(vec_ty, ref_, [vec_idx]) for vec_idx in layout.linear_thread_idxs()] - except NotImplementedError: - vecs = [vector.load(vec_ty, ref, vec_idx) for vec_idx in layout.thread_idxs(shape)] - return cls(_registers=np.array(vecs), _layout=layout, _is_signed=is_signed) - - @classmethod - def load_wgmma_row( - cls, - ref: ir.Value, - *, - is_signed: bool | None = None, - ): - if not ir.MemRefType.isinstance(ref.type): - raise TypeError(ref.type) - - ref_ty = ir.MemRefType(ref.type) - shape = tuple(ref_ty.shape) - if len(shape) != 1: - raise ValueError("WGMMARowFragLayout requires a 1D shape") - if shape[0] % 64: - raise ValueError( - "WGMMARowFragLayout requires shape[0] to be a multiple of 64" - ) - - layout = WGMMARowFragLayout() - registers = [memref.load(ref, [idx]) for (idx,) in layout.thread_idxs(shape)] - registers = np.array(registers).reshape(-1, 2) + for _get, update, ref, idx in cls.transfer_strided(ref, layout.vec_size): + update(registers, vector.load(vec_ty, ref, idx)) return cls(_registers=registers, _layout=layout, _is_signed=is_signed) - @classmethod - def splat(cls, value, shape, layout=None, *, is_signed: bool | None = None): + def splat( + cls, value, shape, layout=None, *, is_signed: bool | None = None + ) -> FragmentedArray: layout = layout or WGSplatFragLayout(shape) match layout: - case WGMMARowFragLayout(): - if len(shape) != 1: - raise ValueError("WGMMARowFragLayout requires a 1D shape") - if shape[0] % 64: - raise ValueError( - "WGMMARowFragLayout requires shape[0] to be a multiple of 64" - ) - reg_shape = (shape[0] // 64, 2) - case WGStridedFragLayout(vec_size=vec_size): - assert shape == layout.shape - elems = np.prod(shape) - reg_shape = (elems // (WARPGROUP_SIZE * vec_size),) - value = vector.splat(ir.VectorType.get((vec_size,), value.type), value) case WGSplatFragLayout(): - assert shape == layout.shape - reg_shape = () - case TiledLayout(): - value = vector.splat(ir.VectorType.get((layout.vector_length,), value.type), value) - reg_shape = layout.registers_shape(shape) + pass + case WGStridedFragLayout() | TiledLayout(): + value = vector.broadcast( + layout.registers_element_type(value.type), value + ) case _: raise NotImplementedError(layout) return cls( - _registers=np.full(reg_shape, value, dtype=object), + _registers=np.full(layout.registers_shape(shape), value, dtype=object), _layout=layout, _is_signed=is_signed, ) + @staticmethod + def broadcasted_iota( + dtype: ir.Type, + shape: tuple[int, ...], + dimension: int, + layout: FragmentedLayout | None = None, + *, + is_signed: bool | None = None, + ) -> FragmentedArray: + """Creates a broadcasted iota array along the specified dimension.""" + if dimension >= len(shape): + raise ValueError( + "`dimension` must be smaller than the rank of the array." + ) + + def cast(idx: ir.Value) -> ir.Value: + if isinstance(dtype, ir.FloatType): + i32 = ir.IntegerType.get_signless(32) + return arith.uitofp(dtype, arith.index_cast(i32, idx)) + return arith.index_cast(dtype, idx) + + return mgpu.FragmentedArray.splat( + llvm.mlir_undef(dtype), + shape, + layout, + is_signed=is_signed, + ).foreach( + lambda _, idx: cast(idx[dimension]), + create_array=True, + is_signed=is_signed, + ) + @property - def shape(self): + def shape(self) -> tuple[int, ...]: match self.layout: - case WGMMARowFragLayout(): - row_tiles = self.registers.shape[0] - return (row_tiles * 64,) case WGStridedFragLayout(shape): return shape case WGSplatFragLayout(shape=shape): @@ -747,30 +1137,31 @@ def shape(self): raise NotImplementedError @property - def mlir_dtype(self): + def mlir_dtype(self) -> ir.Type: reg_ty = self.registers.flat[0].type match self.layout: case WGStridedFragLayout() | TiledLayout(): return ir.VectorType(reg_ty).element_type - case WGMMARowFragLayout() | WGSplatFragLayout(): + case WGSplatFragLayout(): return reg_ty case _: raise NotImplementedError - def to_layout(self, new_layout: FragmentedLayout): - """Converts the fragmented array to the given layout. - - At the moment, only conversions from ``WGSplatFragLayout`` are supported. - """ + def to_layout(self, new_layout: FragmentedLayout) -> FragmentedArray: + """Converts the fragmented array to the given layout.""" i32 = ir.IntegerType.get_signless(32) c = lambda x: arith.constant(i32, x) if self.layout == new_layout: return self shape = self.shape - if ( - self.layout == WGMMA_LAYOUT - and new_layout == WGMMA_TRANSPOSED_LAYOUT - and utils.bitwidth(self.mlir_dtype) == 16 + bitwidth = utils.bitwidth(self.mlir_dtype) + transpose_pairs = ( + (WGMMA_LAYOUT, WGMMA_TRANSPOSED_LAYOUT), + (TCGEN05_LAYOUT, TCGEN05_TRANSPOSED_LAYOUT), + ) + if bitwidth in {16, 32} and ( + (self.layout, new_layout) in transpose_pairs + or (new_layout, self.layout) in transpose_pairs ): is_even_row = arith.cmpi( arith.CmpIPredicate.eq, @@ -778,22 +1169,92 @@ def to_layout(self, new_layout: FragmentedLayout): c(0), ) perm = arith.select(is_even_row, c(0x5410), c(0x3276)) - new_regs = [] + tmp_new_regs = [] for reg in self.registers.flat: reg_ty = reg.type - reg = utils.bitcast(reg, i32) - reg_shfl = utils.shfl_bfly(reg, 4) - new_reg = utils.prmt(reg, reg_shfl, perm) - new_regs.append(utils.bitcast(new_reg, reg_ty)) + if bitwidth == 16: + reg = utils.bitcast(reg, i32) + reg_shfl = utils.shfl_bfly(reg, 4) + new_reg = utils.prmt(reg, reg_shfl, perm) + elif bitwidth == 32: + i32_vec = ir.VectorType.get((1,), i32) + regs = [ + utils.bitcast(utils.vector_slice(reg, slice(i, i + 1)), i32) + for i in range(2) + ] + reg_to_shfl = arith.select(is_even_row, regs[1], regs[0]) + reg_shfl = utils.shfl_bfly(reg_to_shfl, 4) + new_reg_low = arith.select(is_even_row, regs[0], reg_shfl) + new_reg_high = arith.select(is_even_row, reg_shfl, regs[1]) + new_reg_i32 = utils.vector_concat([ + utils.bitcast(new_reg_low, i32_vec), + utils.bitcast(new_reg_high, i32_vec), + ]) + new_reg = utils.bitcast(new_reg_i32, reg_ty) + else: + raise ValueError(f"Unsupported bitwidth: {bitwidth}") + tmp_new_regs.append(utils.bitcast(new_reg, reg_ty)) + new_regs = np.asarray( + tmp_new_regs, dtype=object + ).reshape(new_layout.registers_shape(shape)) + return FragmentedArray( + _registers=new_regs, _layout=new_layout, _is_signed=self.is_signed + ) + if ( + isinstance(self.layout, TiledLayout) + and isinstance(new_layout, TiledLayout) + and self.layout == tmem_native_layout(self.layout.vector_length) + and new_layout == tmem_native_layout(new_layout.vector_length) + ): + new_registers = np.empty(new_layout.registers_shape(shape), dtype=object) + if self.layout.vector_length > new_layout.vector_length: + ratio = self.layout.vector_length // new_layout.vector_length + new_length = new_layout.vector_length + for idx, reg in np.ndenumerate(self.registers): + for i in range(ratio): + new_reg = utils.vector_slice( + reg, slice(i * new_length, (i + 1) * new_length) + ) + new_registers[(idx[0], idx[1] * ratio + i, *idx[2:])] = new_reg + elif self.layout.vector_length < new_layout.vector_length: + ratio = new_layout.vector_length // self.layout.vector_length + for idx in np.ndindex(new_registers.shape): + new_reg = utils.vector_concat([ + self.registers[idx[0], idx[1] * ratio + i, *idx[2:]] + for i in range(ratio) + ]) + new_registers[idx] = new_reg + return FragmentedArray( + _registers=new_registers, _layout=new_layout, _is_signed=self.is_signed, + ) + if self.layout == WGMMA_LAYOUT_ACC_32BIT and new_layout == WGMMA_LAYOUT: + new_regs_shape = new_layout.registers_shape(shape) + assert new_regs_shape[-1] == 1 + assert self.registers.shape == (*new_regs_shape[:-1], 2, 1) + new_regs = np.empty(new_regs_shape, dtype=object) + for idx in np.ndindex(new_regs_shape[:-1]): + new_regs[(*idx, 0)] = utils.vector_concat([ + self.registers[*idx, i, 0] for i in range(2) + ]) + return FragmentedArray( + _registers=new_regs, _layout=new_layout, _is_signed=self.is_signed, + ) + if self.layout == WGMMA_LAYOUT and new_layout == WGMMA_LAYOUT_ACC_32BIT: + new_regs_shape = new_layout.registers_shape(shape) + assert self.registers.shape[-1] == 1 + assert new_regs_shape == (*self.registers.shape[:-1], 2, 1) + new_regs = np.empty(new_regs_shape, dtype=object) + for idx, reg in np.ndenumerate(self.registers): + for i in range(2): + new_regs[(*idx[:-1], i, 0)] = utils.vector_slice(reg, slice(i, i + 1)) return FragmentedArray( - _registers=np.asarray(new_regs, dtype=object).reshape(new_layout.registers_shape(shape)), - _layout=new_layout, - _is_signed=self.is_signed, + _registers=new_regs, _layout=new_layout, _is_signed=self.is_signed, ) + dtype_bitwidth = utils.bitwidth(self.mlir_dtype) if ( self.layout == WGMMA_LAYOUT_UPCAST_2X and new_layout == WGMMA_LAYOUT - and (dtype_bitwidth := utils.bitwidth(self.mlir_dtype)) <= 16 + and can_relayout_wgmma_2x_to_wgmma(dtype_bitwidth) ): assert shape[1] % 16 == 0 # Should be implied by the layout new_registers = np.empty(new_layout.registers_shape(shape), dtype=object) @@ -882,7 +1343,7 @@ def to_layout(self, new_layout: FragmentedLayout): if ( self.layout == WGMMA_LAYOUT_UPCAST_4X and new_layout == WGMMA_LAYOUT_UPCAST_2X - and utils.bitwidth(self.mlir_dtype) == 4 + and can_relayout_wgmma_4x_to_wgmma_2x(dtype_bitwidth) ): assert shape[0] % 64 == 0 # Should be implied by the layout assert shape[1] % 32 == 0 # Should be implied by the layout @@ -920,12 +1381,13 @@ def to_layout(self, new_layout: FragmentedLayout): raise NotImplementedError( f"Cannot convert from {self.layout} to {new_layout}" ) - [reg] = self.registers.flat return type(self).splat( - reg, self.shape, new_layout, is_signed=self.is_signed + self.registers.item(), self.shape, new_layout, is_signed=self.is_signed ) - def _pointwise(self, op, *other, output_is_signed: bool | None = None): + def _pointwise( + self, op, *other, output_is_signed: bool | None = None + ) -> FragmentedArray: # If our layout is a splat, then we should either dispatch to a non-splat # layout, or broadcast ourselves to the output shape first. if isinstance(self.layout, WGSplatFragLayout): @@ -983,9 +1445,9 @@ def _pointwise(self, op, *other, output_is_signed: bool | None = None): for idx, reg in np.ndenumerate(self.registers): new_regs[idx] = op(reg, *(o.registers[idx] for o in other_arrs)) reg_ty = new_regs.flat[0].type - if ir.VectorType.isinstance(reg_ty): + if isinstance(reg_ty, ir.VectorType): reg_ty = ir.VectorType(reg_ty).element_type - if output_is_signed is None and ir.IntegerType.isinstance(reg_ty): + if output_is_signed is None and isinstance(reg_ty, ir.IntegerType): output_is_signed = self.is_signed return FragmentedArray( _registers=new_regs, _layout=self.layout, _is_signed=output_is_signed @@ -995,17 +1457,17 @@ def __pos__(self): return self def __neg__(self): - if ir.FloatType.isinstance(self.mlir_dtype): + if isinstance(self.mlir_dtype, ir.FloatType): return self._pointwise(arith.negf) - elif ir.IntegerType.isinstance(self.mlir_dtype): + elif isinstance(self.mlir_dtype, ir.IntegerType): return 0 - self else: return NotImplemented def __add__(self, other): - if ir.FloatType.isinstance(self.mlir_dtype): + if isinstance(self.mlir_dtype, ir.FloatType): return self._pointwise(addf, other) - elif ir.IntegerType.isinstance(self.mlir_dtype): + elif isinstance(self.mlir_dtype, ir.IntegerType): return self._pointwise(arith.addi, other) else: return NotImplemented @@ -1014,9 +1476,9 @@ def __radd__(self, other): return self + other def __mul__(self, other): - if ir.FloatType.isinstance(self.mlir_dtype): + if isinstance(self.mlir_dtype, ir.FloatType): return self._pointwise(mulf, other) - elif ir.IntegerType.isinstance(self.mlir_dtype): + elif isinstance(self.mlir_dtype, ir.IntegerType): return self._pointwise(arith.muli, other) else: return NotImplemented @@ -1025,37 +1487,37 @@ def __rmul__(self, other): return self * other def __sub__(self, other): - if ir.FloatType.isinstance(self.mlir_dtype): + if isinstance(self.mlir_dtype, ir.FloatType): return self._pointwise(subf, other) - elif ir.IntegerType.isinstance(self.mlir_dtype): + elif isinstance(self.mlir_dtype, ir.IntegerType): return self._pointwise(arith.subi, other) else: return NotImplemented def __rsub__(self, other): - if ir.FloatType.isinstance(self.mlir_dtype): + if isinstance(self.mlir_dtype, ir.FloatType): return self._pointwise(lambda s, o: subf(o, s), other) - elif ir.IntegerType.isinstance(self.mlir_dtype): + elif isinstance(self.mlir_dtype, ir.IntegerType): return self._pointwise(lambda s, o: arith.subi(o, s), other) else: return NotImplemented def __truediv__(self, other): - if not ir.FloatType.isinstance(self.mlir_dtype): + if not isinstance(self.mlir_dtype, ir.FloatType): return NotImplemented return self._pointwise(arith.divf, other) def __rtruediv__(self, other): - if not ir.FloatType.isinstance(self.mlir_dtype): + if not isinstance(self.mlir_dtype, ir.FloatType): return NotImplemented return self._pointwise(lambda s, o: arith.divf(o, s), other) def __floordiv__(self, other): - if ir.FloatType.isinstance(self.mlir_dtype): + if isinstance(self.mlir_dtype, ir.FloatType): return self._pointwise( lambda s, o: mlir_math.floor(arith.divf(s, o)), other ) - elif ir.IntegerType.isinstance(self.mlir_dtype): + elif isinstance(self.mlir_dtype, ir.IntegerType): if self.is_signed: return self._pointwise(arith.floordivsi, other) else: @@ -1064,11 +1526,11 @@ def __floordiv__(self, other): return NotImplemented def __rfloordiv__(self, other): - if ir.FloatType.isinstance(self.mlir_dtype): + if isinstance(self.mlir_dtype, ir.FloatType): return self._pointwise( lambda s, o: mlir_math.floor(arith.divf(o, s)), other ) - elif ir.IntegerType.isinstance(self.mlir_dtype): + elif isinstance(self.mlir_dtype, ir.IntegerType): if self.is_signed: return self._pointwise(lambda s, o: arith.floordivsi(o, s), other) else: @@ -1077,7 +1539,7 @@ def __rfloordiv__(self, other): return NotImplemented def __mod__(self, other): - if not ir.IntegerType.isinstance(self.mlir_dtype): + if not isinstance(self.mlir_dtype, ir.IntegerType): return NotImplemented if self.is_signed: return self._pointwise(arith.remsi, other) @@ -1085,7 +1547,7 @@ def __mod__(self, other): return self._pointwise(arith.remui, other) def __rmod__(self, other): - if not ir.IntegerType.isinstance(self.mlir_dtype): + if not isinstance(self.mlir_dtype, ir.IntegerType): return NotImplemented if self.is_signed: return self._pointwise(lambda s, o: arith.remsi(o, s), other) @@ -1093,12 +1555,12 @@ def __rmod__(self, other): return self._pointwise(lambda s, o: arith.remui(o, s), other) def __invert__(self): - if not ir.IntegerType.isinstance(self.mlir_dtype): + if not isinstance(self.mlir_dtype, ir.IntegerType): return NotImplemented return self ^ ~0 def __or__(self, other): - if not ir.IntegerType.isinstance(self.mlir_dtype): + if not isinstance(self.mlir_dtype, ir.IntegerType): return NotImplemented return self._pointwise(arith.ori, other) @@ -1106,7 +1568,7 @@ def __ror__(self, other): return self | other def __and__(self, other): - if not ir.IntegerType.isinstance(self.mlir_dtype): + if not isinstance(self.mlir_dtype, ir.IntegerType): return NotImplemented return self._pointwise(arith.andi, other) @@ -1114,7 +1576,7 @@ def __rand__(self, other): return self & other def __xor__(self, other): - if not ir.IntegerType.isinstance(self.mlir_dtype): + if not isinstance(self.mlir_dtype, ir.IntegerType): return NotImplemented return self._pointwise(arith.xori, other) @@ -1170,60 +1632,86 @@ def __ge__(self, other): ) def _compare(self, other, *, f_pred, si_pred, ui_pred): - if ir.FloatType.isinstance(self.mlir_dtype): + if isinstance(self.mlir_dtype, ir.FloatType): pred = functools.partial(arith.cmpf, f_pred) - elif ir.IntegerType.isinstance(self.mlir_dtype): - if ir.IntegerType(self.mlir_dtype).is_signed: + elif isinstance(self.mlir_dtype, ir.IntegerType): + if self.is_signed: pred = functools.partial(arith.cmpi, si_pred) else: pred = functools.partial(arith.cmpi, ui_pred) else: - raise NotImplementedError + return NotImplemented return self._pointwise(pred, other, output_is_signed=False) - def max(self, other): - if ir.FloatType.isinstance(self.mlir_dtype): + def max(self, other) -> FragmentedArray: + if isinstance(self.mlir_dtype, ir.FloatType): maximumf = arith.maximumf - if ir.F32Type.isinstance(self.mlir_dtype): + if isinstance(self.mlir_dtype, ir.F32Type): maximumf = self._lift_fast_instr("max.NaN.f32") + elif isinstance(self.mlir_dtype, ir.F16Type): + maximumf = self._lift_fast_packed_instr("max.NaN.f16x2", "max.NaN.f16") + elif isinstance(self.mlir_dtype, ir.BF16Type): + maximumf = self._lift_fast_packed_instr("max.NaN.bf16x2", "max.NaN.bf16") return self._pointwise(maximumf, other) - elif ir.IntegerType.isinstance(self.mlir_dtype): + elif isinstance(self.mlir_dtype, ir.IntegerType): + width = utils.bitwidth(self.mlir_dtype) + if width == 16: + sign = "s" if self.is_signed else "u" + instr = self._lift_fast_packed_instr(f"max.{sign}16x2", f"max.{sign}16") + return self._pointwise(instr, other) return self._pointwise( arith.maxsi if self.is_signed else arith.maxui, other ) else: - return NotImplementedError + raise NotImplementedError - def min(self, other): - if ir.FloatType.isinstance(self.mlir_dtype): - return self._pointwise(arith.minimumf, other) - elif ir.IntegerType.isinstance(self.mlir_dtype): + def min(self, other) -> FragmentedArray: + if isinstance(self.mlir_dtype, ir.FloatType): + minimumf = arith.minimumf + if isinstance(self.mlir_dtype, ir.F32Type): + minimumf = self._lift_fast_instr("min.NaN.f32") + elif isinstance(self.mlir_dtype, ir.F16Type): + minimumf = self._lift_fast_packed_instr("min.NaN.f16x2", "min.NaN.f16") + elif isinstance(self.mlir_dtype, ir.BF16Type): + minimumf = self._lift_fast_packed_instr("min.NaN.bf16x2", "min.NaN.bf16") + return self._pointwise(minimumf, other) + elif isinstance(self.mlir_dtype, ir.IntegerType): + width = utils.bitwidth(self.mlir_dtype) + if width == 16: + sign = "s" if self.is_signed else "u" + instr = self._lift_fast_packed_instr(f"min.{sign}16x2", f"min.{sign}16") + return self._pointwise(instr, other) return self._pointwise( arith.minsi if self.is_signed else arith.minui, other ) else: - return NotImplementedError + raise NotImplementedError + + def copysign(self, other: FragmentedArray) -> FragmentedArray: + if not isinstance(self.mlir_dtype, ir.FloatType): + raise NotImplementedError + return self._pointwise(mlir_math.copysign, other) - def exp(self, *, approx: bool = False): - if not ir.FloatType.isinstance(self.mlir_dtype): + def exp(self, *, approx: bool = False) -> FragmentedArray: + if not isinstance(self.mlir_dtype, ir.FloatType): raise NotImplementedError if approx: dtype = self.mlir_dtype log2e = arith.constant(dtype, ir.FloatAttr.get(dtype, 1.4426950408889634)) - return (self * log2e).exp2() + return cast(FragmentedArray, self * log2e).exp2() return self._pointwise(mlir_math.exp) - def exp2(self, *, approx: bool = False): - if not ir.FloatType.isinstance(self.mlir_dtype): + def exp2(self, *, approx: bool = False) -> FragmentedArray: + if not isinstance(self.mlir_dtype, ir.FloatType): raise NotImplementedError if approx: - if not ir.F32Type.isinstance(self.mlir_dtype): + if not isinstance(self.mlir_dtype, ir.F32Type): raise NotImplementedError(self.mlir_dtype) return self._pointwise(self._lift_fast_instr("ex2.approx.ftz.f32")) return self._pointwise(mlir_math.exp2) - def log(self, *, approx: bool = False): - if not ir.FloatType.isinstance(self.mlir_dtype): + def log(self, *, approx: bool = False) -> FragmentedArray: + if not isinstance(self.mlir_dtype, ir.FloatType): raise NotImplementedError if approx: dtype = self.mlir_dtype @@ -1231,17 +1719,17 @@ def log(self, *, approx: bool = False): return self.log2(approx=True) * ln2 return self._pointwise(mlir_math.log) - def log2(self, *, approx: bool = False): - if not ir.FloatType.isinstance(self.mlir_dtype): + def log2(self, *, approx: bool = False) -> FragmentedArray: + if not isinstance(self.mlir_dtype, ir.FloatType): raise NotImplementedError(self.mlir_dtype) if approx: - if not ir.F32Type.isinstance(self.mlir_dtype): + if not isinstance(self.mlir_dtype, ir.F32Type): raise NotImplementedError(self.mlir_dtype) return self._pointwise(self._lift_fast_instr("lg2.approx.ftz.f32")) return self._pointwise(mlir_math.log2) - def sin(self, *, approx: bool = False): - if not ir.FloatType.isinstance(self.mlir_dtype): + def sin(self, *, approx: bool = False) -> FragmentedArray: + if not isinstance(self.mlir_dtype, ir.FloatType): raise NotImplementedError if approx and self.mlir_dtype != ir.F32Type.get(): raise NotImplementedError @@ -1249,8 +1737,8 @@ def sin(self, *, approx: bool = False): self._lift_fast_instr("sin.approx.f32") if approx else mlir_math.sin ) - def cos(self, *, approx: bool = False): - if not ir.FloatType.isinstance(self.mlir_dtype): + def cos(self, *, approx: bool = False) -> FragmentedArray: + if not isinstance(self.mlir_dtype, ir.FloatType): raise NotImplementedError if approx and self.mlir_dtype != ir.F32Type.get(): raise NotImplementedError @@ -1258,8 +1746,8 @@ def cos(self, *, approx: bool = False): self._lift_fast_instr("cos.approx.f32") if approx else mlir_math.cos ) - def tanh(self, *, approx: bool = False): - if not ir.FloatType.isinstance(self.mlir_dtype): + def tanh(self, *, approx: bool = False) -> FragmentedArray: + if not isinstance(self.mlir_dtype, ir.FloatType): raise NotImplementedError if approx and self.mlir_dtype != ir.F32Type.get(): raise NotImplementedError @@ -1267,8 +1755,8 @@ def tanh(self, *, approx: bool = False): self._lift_fast_instr("tanh.approx.f32") if approx else mlir_math.tanh ) - def rsqrt(self, *, approx: bool = False): - if not ir.FloatType.isinstance(self.mlir_dtype): + def rsqrt(self, *, approx: bool = False) -> FragmentedArray: + if not isinstance(self.mlir_dtype, ir.FloatType): raise NotImplementedError if approx and self.mlir_dtype != ir.F32Type.get(): raise NotImplementedError @@ -1276,10 +1764,39 @@ def rsqrt(self, *, approx: bool = False): self._lift_fast_instr("rsqrt.approx.f32") if approx else mlir_math.rsqrt ) + def abs(self) -> FragmentedArray: + if isinstance(self.mlir_dtype, ir.FloatType): + return self._pointwise(mlir_math.absf) + if isinstance(self.mlir_dtype, ir.IntegerType): + return self._pointwise(mlir_math.absi) + raise NotImplementedError + + def round(self) -> FragmentedArray: + """Same as `lax.round(..., AWAY_FROM_ZERO)`.""" + if not isinstance(self.mlir_dtype, ir.FloatType): + raise NotImplementedError + return self._pointwise(mlir_math.round) + + def round_even(self) -> FragmentedArray: + """Same as `lax.round(..., TO_NEAREST_EVEN)`.""" + if not isinstance(self.mlir_dtype, ir.FloatType): + raise NotImplementedError + return self._pointwise(mlir_math.roundeven) + + def erf(self) -> FragmentedArray: + if not isinstance(self.mlir_dtype, ir.FloatType): + raise NotImplementedError(self.mlir_dtype) + return self._pointwise(mlir_math.erf) + + def atan2(self, other: FragmentedArray) -> FragmentedArray: + if not isinstance(self.mlir_dtype, ir.FloatType): + raise NotImplementedError(self.mlir_dtype) + return self._pointwise(mlir_math.atan2, other) + @staticmethod def _lift_fast_instr( instr: str | Callable[[ir.Value], ir.Value], - ) -> Callable[[ir.Value], ir.Value]: + ) -> Callable[[ir.Value, ir.Value], ir.Value]: def fast_instr(*args): f32 = ir.F32Type.get() arg_ty = args[0].type @@ -1292,21 +1809,77 @@ def fast_instr(*args): ) else: return instr(*args) - elif ir.VectorType.isinstance(arg_ty): - index = ir.IndexType.get() + elif isinstance(arg_ty, ir.VectorType): result = llvm.mlir_undef(arg_ty) [vec_len] = ir.VectorType(arg_ty).shape for i in range(vec_len): - vs = [vector.extractelement(a, position=c(i, index)) for a in args] + vs = [ + vector.extract( + a, + dynamic_position=[], + static_position=ir.DenseI64ArrayAttr.get([i]), + ) + for a in args + ] vr = fast_instr(*vs) - result = vector.insertelement(vr, result, position=c(i, index)) + result = vector.insert( + vr, + result, + dynamic_position=[], + static_position=ir.DenseI64ArrayAttr.get([i]), + ) return result else: raise NotImplementedError(arg_ty) return fast_instr - def bitcast(self, elt: ir.Type, *, output_is_signed: bool | None = None): - if (output_is_signed is not None) != ir.IntegerType.isinstance(elt): + @staticmethod + def _lift_fast_packed_instr( + packed_instr: str, single_instr: str, + ) -> Callable[[ir.Value, ir.Value], ir.Value]: + def fast_instr(*args): + arg_ty = original_arg_ty = args[0].type + assert all(a.type == arg_ty for a in args) + if not isinstance(arg_ty, ir.VectorType): + args = [vector.broadcast(ir.VectorType.get((1,), arg_ty), a) for a in args] + arg_ty = ir.VectorType(args[0].type) + [vec_len] = arg_ty.shape + vec_bitwidth = vec_len * utils.bitwidth(arg_ty.element_type) + if vec_len == 1 or vec_bitwidth == 32: + assert vec_bitwidth.bit_count() == 1 + if vec_bitwidth == 32: + cstr = "r" + elif vec_bitwidth == 16: + cstr = "h" + else: + raise NotImplementedError(vec_bitwidth) + int_ty = ir.IntegerType.get_signless(vec_bitwidth) + args_ptx = ", ".join(f"${i}" for i in range(len(args) + 1)) + args_int = [utils.bitcast(a, int_ty) for a in args] + result_int = llvm.inline_asm( + int_ty, + args_int, + f"{single_instr if vec_len == 1 else packed_instr} {args_ptx};", + f"={cstr}" + f",{cstr}" * len(args) + ) + return utils.bitcast(result_int, original_arg_ty) + else: + assert vec_bitwidth > 32 + slice_len = 32 // utils.bitwidth(arg_ty.element_type) + offset = 0 + slices = [] + while offset < vec_len: + slice_end = min(offset + slice_len, vec_len) + args_slice = [utils.vector_slice(a, slice(offset, slice_end)) for a in args] + slices.append(fast_instr(*args_slice)) + offset = slice_end + return utils.vector_concat(slices) + return fast_instr + + def bitcast( + self, elt: ir.Type, *, output_is_signed: bool | None = None + ) -> FragmentedArray: + if (output_is_signed is not None) != isinstance(elt, ir.IntegerType): raise TypeError( "output_is_signed must be non-None if and only if the MLIR type is an" f" integer type, got {output_is_signed=} for {elt}" @@ -1315,7 +1888,7 @@ def bitcast(self, elt: ir.Type, *, output_is_signed: bool | None = None): if elt == self.mlir_dtype: return self reg_type = self.registers.flat[0].type - if ir.VectorType.isinstance(reg_type): + if isinstance(reg_type, ir.VectorType): reg_shape = ir.VectorType(reg_type).shape ty = ir.VectorType.get(reg_shape, elt) else: @@ -1325,38 +1898,99 @@ def bitcast(self, elt: ir.Type, *, output_is_signed: bool | None = None): lambda x: arith.bitcast(ty, x), output_is_signed=output_is_signed ) - def __getitem__(self, idx): - if self.layout != WGMMA_LAYOUT: - raise NotImplementedError("Only WGMMA layouts support slicing") + def __getitem__(self, idx) -> FragmentedArray: + base_idx, slice_shape, is_squeezed = utils.parse_indices(idx, self.shape) + if isinstance(self.layout, WGSplatFragLayout): + shape = tuple(d for d, s in zip(slice_shape, is_squeezed) if not s) + return self.splat(self.registers.item(), shape, is_signed=self.is_signed) + if not isinstance(self.layout, TiledLayout): + raise NotImplementedError("Only arrays with tiled layouts can be sliced") + if any(isinstance(idx, ir.Value) for idx in base_idx): + raise ValueError("Only slicing with static indices allowed") + if any(is_squeezed): + raise NotImplementedError("Integer indexing not implemented (only slicing allowed)") + base_tile_shape = self.layout.base_tile_shape + if untiled_rank := len(self.shape) - len(base_tile_shape): + base_tile_shape = (1,) * untiled_rank + base_tile_shape + if any(b % t for b, t in zip(base_idx, base_tile_shape, strict=True)): + raise ValueError( + "Base indices of array slices must be aligned to the beginning of a" + f" tile. The array uses a tiling of {base_tile_shape}, but your base" + f" indices are {base_idx}. Consider using a different array layout." + ) + if any(l % t for l, t in zip(slice_shape, base_tile_shape, strict=True)): + raise ValueError( + "The slice shape must be a multiple of the tile shape. The array" + f" uses a tiling of {base_tile_shape}, but your slice shape is" + f" {slice_shape}. Consider using a different array layout." + ) + register_slices = tuple( + slice(b // t, (b + l) // t) + for b, l, t in zip(base_idx, slice_shape, base_tile_shape, strict=True) + ) + new_regs = self.registers[register_slices] + return FragmentedArray( + _registers=new_regs, _layout=self.layout, _is_signed=self.is_signed + ) + + def __setitem__(self, idx: object, value: FragmentedArray) -> None: + if not isinstance(value, FragmentedArray): + raise ValueError(f"Expected a FragmentedArray, got: {value}") + if not isinstance(self.layout, TiledLayout): + raise NotImplementedError("Only arrays with tiled layouts can be sliced") base_idx, slice_shape, is_squeezed = utils.parse_indices(idx, self.shape) + if any(isinstance(idx, ir.Value) for idx in base_idx): + raise ValueError("Only slicing with static indices allowed") if any(is_squeezed): - raise NotImplementedError("Only slicing implemented") - if ( - base_idx[0] % 64 - or slice_shape[0] % 64 - or base_idx[1] % 8 - or slice_shape[1] % 8 + raise NotImplementedError("Integer indexing not implemented (only slicing allowed)") + if value.shape != tuple(slice_shape): + raise ValueError( + f"Slice has shape {tuple(slice_shape)}, but assigned array has shape" + f" {value.shape}" + ) + if value.mlir_dtype != self.mlir_dtype: + raise ValueError( + f"Array has dtype {value.mlir_dtype}, but assigned array has dtype" + f" {self.mlir_dtype}" + ) + if value.layout != self.layout: + raise ValueError( + f"Array has layout {value.layout}, but assigned array has layout" + f" {self.layout}" + ) + base_tile_shape = self.layout.base_tile_shape + if len(base_tile_shape) != len(self.shape): + raise NotImplementedError("Tiling has different rank than array") + if any( + b % t or l % t + for b, l, t in zip(base_idx, slice_shape, base_tile_shape, strict=True) ): raise NotImplementedError("Only tile aligned slicing supported") - base_idx[0] //= 64 - slice_shape[0] //= 64 - base_idx[1] //= 8 - slice_shape[1] //= 8 - new_regs = self.registers[ - base_idx[0] : base_idx[0] + slice_shape[0], - base_idx[1] : base_idx[1] + slice_shape[1], - ] + register_slices = tuple( + slice(b // t, (b + l) // t) + for b, l, t in zip(base_idx, slice_shape, base_tile_shape, strict=True) + ) + assert self.registers[register_slices].shape == value.registers.shape + self.registers[register_slices] = value.registers + + def copy(self) -> FragmentedArray: return FragmentedArray( - _registers=new_regs, _layout=self.layout, _is_signed=self.is_signed + _registers=np.copy(self.registers), + _layout=self.layout, + _is_signed=self.is_signed, ) # TODO(apaszke): Support JAX dtypes here as well? - def astype(self, new_dtype: ir.Type, *, is_signed: bool | None = None): + def astype( + self, new_dtype: ir.Type, *, is_signed: bool | None = None + ) -> FragmentedArray: i4 = ir.IntegerType.get_signless(4) i8 = ir.IntegerType.get_signless(8) i16 = ir.IntegerType.get_signless(16) i32 = ir.IntegerType.get_signless(32) bf16 = ir.BF16Type.get() + f32 = ir.F32Type.get() + f8e4m3fn = ir.Float8E4M3FNType.get() cur_dtype = self.mlir_dtype if cur_dtype == new_dtype: @@ -1365,8 +1999,12 @@ def astype(self, new_dtype: ir.Type, *, is_signed: bool | None = None): return FragmentedArray( _registers=self.registers, _layout=self.layout, _is_signed=is_signed ) - reg_type = self.registers.flat[0].type - is_vector_reg = ir.VectorType.isinstance(reg_type) + # Otherwise, mypy is unhappy with using ``idx`` for both range and + # np.ndenumerate. + idx: Any + any_reg = self.registers.flat[0] + reg_type = any_reg.type + is_vector_reg = isinstance(reg_type, ir.VectorType) reg_shape = tuple(ir.VectorType(reg_type).shape) if is_vector_reg else (1,) [vector_len] = reg_shape # This is meant to be a 1D assertion. if (new_reg_bitwidth := utils.bitwidth(new_dtype) * vector_len) % 8: @@ -1374,16 +2012,132 @@ def astype(self, new_dtype: ir.Type, *, is_signed: bool | None = None): "Register bitwidth in target type must be divisible by 8, got" f" {new_reg_bitwidth}" ) - if cur_dtype == i4 and self.is_signed and new_dtype == bf16: + # If the vector originates from a slice (common after relayouts), we + # can fuse the slicing into the conversion and reuse many + # preprocessing ops (shifts, prmts) accross different vectors. + regs_from_32bit_slice = ( + isinstance( + _slice_op := getattr(any_reg.owner, "opview", None), + vector.ExtractStridedSliceOp, + ) + and utils.bitwidth(_slice_op.source.type) == 32 + and _slice_op.strides[0].value == 1 + ) + def packed_registers( + dst_vector_len: int, *, if_not_sliced: bool + ) -> Iterable[tuple[Sequence[tuple[int, ...]], ir.Value]]: + """Tries to pack registers up to destination vector length.""" + if regs_from_32bit_slice and if_not_sliced: + for idx, reg in np.ndenumerate(self.registers): + yield [idx], reg + return + generator = np.ndenumerate(self.registers) + indices = [] + regs = [] + while True: + try: + for _ in range(max(dst_vector_len // vector_len, 1)): + idx, reg = next(generator) + indices.append(idx) + regs.append(reg) + yield indices, utils.vector_concat(regs) + regs.clear() + indices.clear() + except StopIteration: + break + if regs: + yield indices, utils.vector_concat(regs) + + if cur_dtype == i4 and new_dtype == f8e4m3fn: + # The algorithm here is taken from CUTLASS's `NumericArrayConverter` + # specialization for int4 -> f8e4m3, available at + # https://github.com/NVIDIA/cutlass/blob/5c6bca04414e06ce74458ab0a2018e2b8272701c/include/cutlass/numeric_conversion.h#L4982. + # Each call to the function below will upcast 4 contiguous nibbles of + # the input 32-bit register, and whether to select the 4 low nibbles or + # the 4 high nibbles is determined by the `part` argument. + def upcast_to_f8e4m3fn(reg: ir.Value, part: int): + lut = [ + 0x44403800, # [0, 1, 2, 3] encoded as f8e4m3fn + 0x4E4C4A48, # [4, 5, 6, 7] encoded as f8e4m3fn + 0xCACCCED0, # [-8, -7, -6, -5] encoded as f8e4m3fn + 0xB8C0C4C8, # [-4, -3, -2, -1] encoded as f8e4m3fn + ] + + sign = arith.shrui(arith.andi(reg, c(0x88888888, i32)), c(1, i32)) + # Ignore the sign when indexing into the LUT. + lut_idx = arith.andi(reg, c(0x77777777, i32)) + + assert 0 <= part < 2 + if part == 1: + lut_idx = arith.shrui(lut_idx, c(16, i32)) + sign = arith.shrui(sign, c(16, i32)) + + prmt_sign_pattern = arith.ori(sign, c(0x32103210, i32)) + return llvm.inline_asm( + i32, + [lut_idx, prmt_sign_pattern], + f""" + {{ + .reg .b32 pos_f8s, neg_f8s; + prmt.b32 pos_f8s, {lut[0]}, {lut[1]}, $1; + prmt.b32 neg_f8s, {lut[2]}, {lut[3]}, $1; + prmt.b32 $0, pos_f8s, neg_f8s, $2; + }} + """, + "=r,r,r", + ) new_registers = np.empty_like(self.registers) - out_vec_ty = ir.VectorType.get((vector_len,), new_dtype) - for idx, reg in np.ndenumerate(self.registers): - # The algorithm here is largely the same as CUTLASS's - # NumericArrayConverter specialization for int4 -> bf16 casts. - # We modify it slightly, because we only extract 2 values. - # We first shift the value by 4 bits, to put the high int4 in low bits. - # The prmt then blends the two values together, by putting them into the - # low bits of each 16-bit subword of our register. Then, we use the lop3 + + # TODO(apaszke,bchetioui): Using 8 helps some (but not all) cases. + # TODO(apaszke,bchetioui): Add the slice optimization here. + packing_width = 8 if vector_len == 2 else 4 + for indices, reg in packed_registers(packing_width, if_not_sliced=False): + [group_size] = ir.VectorType(reg.type).shape + assert group_size % vector_len == 0 + int_ty = ir.IntegerType.get_signless(group_size * 4) + reg_as_i32 = utils.bitcast(reg, int_ty) + if int_ty != i32: + reg_as_i32 = arith.extsi(i32, reg_as_i32) + out_i32_regs = [ + upcast_to_f8e4m3fn(reg_as_i32, part=part) + for part in range(max(group_size // 4, 1)) + ] + out_vec_int = utils.vector_concat([ + vector.broadcast(ir.VectorType.get((1,), i32), out_i32_reg) + for out_i32_reg in out_i32_regs + ]) + out_vector_len = len(out_i32_regs) * 4 + # Bitcast to i8 first to allow slicing as necessary, since LLVM chokes + # on f8 types. + out_vec = utils.bitcast( + out_vec_int, ir.VectorType.get((out_vector_len,), i8) + ) + offset = 0 + for idx in indices: + sliced_out_vec = utils.vector_slice( + out_vec, slice(offset, offset + vector_len) + ) + new_registers[idx] = utils.bitcast( + sliced_out_vec, ir.VectorType.get((vector_len,), f8e4m3fn) + ) + offset += vector_len + return FragmentedArray( + _registers=new_registers, _layout=self.layout, _is_signed=None + ) + if cur_dtype == i4 and self.is_signed and new_dtype == bf16 and vector_len % 2 == 0: + new_registers = np.empty_like(self.registers) + out_vec_ty = ir.VectorType.get((vector_len,), new_dtype) + # We use packed_registers for consistency, even though the packing is not + # really profitable here: the PTX below begins by an op dependent on the + # extracted part and so there are no ops that can be shared across packed + # parts. + for indices, reg in packed_registers(2, if_not_sliced=True): + # The algorithm here is largely the same as CUTLASS's + # NumericArrayConverter specialization for int4 -> bf16 casts. + # We modify it slightly, because we only extract 2 values. + # We first shift the value by 4 bits, to put the high int4 in low bits. + # The prmt then blends the two values together, by putting them into the + # low bits of each 16-bit subword of our register. Then, we use the lop3 # to zero any bits that don't belong to our int4s, and finally use the # XOR to: (1) set the exponent bits to 0x43 (at which point the mantissa # represents integer increments) and (2) flip the sign bit. If we @@ -1391,9 +2145,9 @@ def astype(self, new_dtype: ir.Type, *, is_signed: bool | None = None): # positive int4s will end up larger than negative int4s, with a bias of # 8. Use use the sub to subtract the base (our initial exponent) and the # bias coming from flipping the sign bit which is 136 (0x4308 as bits). - def upcast_to_bf16(reg: ir.Value, reg_shr: ir.Value, part: int): + def upcast_i4_to_bf16(reg: ir.Value, reg_shr: ir.Value, part: int): assert 0 <= part < 4 - return llvm.inline_asm( + int_reg = llvm.inline_asm( i32, [reg, reg_shr], f""" @@ -1407,48 +2161,115 @@ def upcast_to_bf16(reg: ir.Value, reg_shr: ir.Value, part: int): """, "=r,r,r", ) + return utils.bitcast(int_reg, ir.VectorType.get((2,), bf16)) + [group_size] = ir.VectorType(reg.type).shape + assert group_size % vector_len == 0 + assert group_size * 4 <= 32 + int_ty = ir.IntegerType.get_signless(group_size * 4) + # If the vector originates from a slice (common after relayouts), we + # can fuse the slicing into the conversion and prevent LLVM from + # generating a bunch of shifts to align the vector data to the LSB. + # This also lets us share the right shift among more vectors. + out_int_regs: list[ir.Value] = [] + if regs_from_32bit_slice: + slice_op = reg.owner.opview + slice_offset = slice_op.offsets[0].value + reg_int = utils.bitcast(slice_op.source, i32) + reg_int_shr = arith.shrui(reg_int, c(4, i32)) + assert slice_offset % 2 == 0 + out_int_regs.extend( + upcast_i4_to_bf16(reg_int, reg_int_shr, part=slice_offset // 2 + part) + for part in range(group_size // 2) + ) + else: + reg_slice_int = utils.bitcast(reg, int_ty) + if int_ty != i32: + reg_slice_int = arith.extsi(i32, reg_slice_int) + reg_slice_int_shr = arith.shrui(reg_slice_int, c(4, i32)) + out_int_regs.extend( + upcast_i4_to_bf16(reg_slice_int, reg_slice_int_shr, part=part) + for part in range(group_size // 2) + ) + out_reg = utils.vector_concat(out_int_regs) offset = 0 - out_int_regs = [] - for group_size in (8, 4, 2): - int_ty = ir.IntegerType.get_signless(group_size * 4) - while vector_len - offset >= group_size: - # If the vector originates from a slice (common after relayouts), we - # can fuse the slicing into the conversion and prevent LLVM from - # generating a bunch of shifts to align the vector data to the LSB. - # This also lets us share the right shift among more vectors. - if (isinstance(slice_op := reg.owner.opview, vector.ExtractStridedSliceOp) - and utils.bitwidth(slice_op.vector.type) == 32 - and slice_op.strides[0].value == 1): - slice_offset = slice_op.offsets[0].value + offset - reg_int = utils.bitcast(slice_op.vector, i32) - reg_int_shr = arith.shrui(reg_int, c(4, i32)) - out_int_regs.extend( - upcast_to_bf16(reg_int, reg_int_shr, part=(slice_offset // 2 + part)) - for part in range(group_size // 2) - ) - else: - reg_slice = utils.vector_slice(reg, slice(offset, offset + group_size)) - reg_slice_int = utils.bitcast(reg_slice, int_ty) - if int_ty != i32: - reg_slice_int = arith.extsi(i32, reg_slice_int) - reg_slice_int_shr = arith.shrui(reg_slice_int, c(4, i32)) - out_int_regs.extend( - upcast_to_bf16(reg_slice_int, reg_slice_int_shr, part=part) - for part in range(group_size // 2) - ) - offset += group_size - assert offset == vector_len - out_vec_int = utils.vector_concat([ - vector.splat(ir.VectorType.get((1,), i32), reg) - for reg in out_int_regs - ]) - new_registers[idx] = utils.bitcast(out_vec_int, out_vec_ty) + for idx in indices: + new_registers[idx] = new_reg = utils.vector_slice( + out_reg, slice(offset, offset + vector_len) + ) + offset += vector_len + assert new_reg.type == out_vec_ty return FragmentedArray( _registers=new_registers, _layout=self.layout, _is_signed=None ) + if cur_dtype == i4 and self.is_signed and new_dtype == i8 and is_signed: + new_registers = np.empty_like(self.registers) + out_vec_ty = ir.VectorType.get((vector_len,), new_dtype) + for indices, reg in packed_registers(8, if_not_sliced=True): + def upcast_i4_to_i8(reg: ir.Value, first_valid_nibble: int = 0): + # When first_valid_nibble is >0, then only the nibbles in the range + # [first_valid_nibble, 8) will be upcast and placed in the low + # elements of the output vector. All high entries are undefined. + assert first_valid_nibble % 2 == 0 + low_prmt = "".join(str(min(first_valid_nibble // 2 + i, 7)) for i in [5, 1, 4, 0]) + high_prmt = "".join(str(min(first_valid_nibble // 2 + i, 7)) for i in [7, 3, 6, 2]) + # Note: (0xf0 & 0xaa) | (0xcc & ~0xaa) = 0xe4. lop3 acts as a blend. + # Below xN means the value of nibble N, sN means that all 4 bits are + # equal to the sign bit of nibble N, and 00 means an all 0 nibble. + out_struct = llvm.inline_asm( + ir.Type.parse("!llvm.struct<(i32, i32)>"), + [reg], + f""" + {{ + .reg .b32 high_even; // $2 is high_odd + .reg .b32 low_odd; // $2 is low_even + .reg .b32 sign_even, sign_odd; + .reg .b32 i8_odd, i8_even; + shl.b32 high_even, $2, 4; // x6x5x4x3x2x1x000 + prmt.b32 sign_even, high_even, high_even, 0xba98; // s6s6s4s4s2s2s0s0 + prmt.b32 sign_odd, $2, $2, 0xba98; // s7s7s5s5s3s3s1s1 + shr.u32 low_odd, $2, 4; // 00x7x6x5x4x3x2x1 + lop3.b32 i8_odd, sign_odd, low_odd, 0xf0f0f0f0, 0xe4; // s7x7s5x5s3x3s1x1 + lop3.b32 i8_even, sign_even, $2, 0xf0f0f0f0, 0xe4; // s6x6s4x4s2x2s0x0 + prmt.b32 $0, i8_even, i8_odd, 0x{low_prmt}; // s3x3s2x2s1x2s0x0 + prmt.b32 $1, i8_even, i8_odd, 0x{high_prmt}; // s7x7s6x5s4x4s3x3 + }} + """, + "=r,=r,r", + ) + i8_vec = ir.VectorType.get((4,), i8) + return utils.vector_concat([ + utils.bitcast(llvm.extractvalue(i32, out_struct, (i,)), i8_vec) + for i in range(2) + ]) + [group_size] = ir.VectorType(reg.type).shape + assert group_size % vector_len == 0 + assert group_size * 4 <= 32 + int_ty = ir.IntegerType.get_signless(group_size * 4) + if regs_from_32bit_slice: + slice_op = reg.owner.opview + slice_offset = slice_op.offsets[0].value + reg_int = utils.bitcast(slice_op.source, i32) + reg_i8 = upcast_i4_to_i8(reg_int, first_valid_nibble=slice_offset) + else: + reg_slice_int = utils.bitcast(reg, int_ty) + if int_ty != i32: + reg_slice_int = arith.extsi(i32, reg_slice_int) + reg_i8 = upcast_i4_to_i8(reg_slice_int) + + # distribute packed registers to original indices + offset = 0 + for idx in indices: + new_registers[idx] = new_reg = utils.vector_slice( + reg_i8, slice(offset, offset + vector_len) + ) + offset += vector_len + assert new_reg.type == out_vec_ty + return FragmentedArray( + _registers=new_registers, _layout=self.layout, _is_signed=is_signed + ) if cur_dtype == i8 and self.is_signed and new_dtype == bf16 and vector_len in {2, 4}: new_registers = np.empty_like(self.registers) - def upcast_to_bf16(reg, high): + def upcast_i8_to_bf16(reg, high): # We first embed the s8 into a bf16 with the exponent equal to # bias + mantissa bits. Then, we zero the msb that didn't fit into the # mantissa, zero out all bits other than msb, and subtract the last @@ -1471,15 +2292,17 @@ def upcast_to_bf16(reg, high): "=r,r", ) empty_vec_32 = llvm.mlir_undef(ir.VectorType.get((vector_len // 2,), i32)) + pad_vec_16 = llvm.mlir_undef(ir.VectorType.get((1,), i16)) for idx, reg in np.ndenumerate(self.registers): if vector_len == 2: reg_16 = vector.bitcast(ir.VectorType.get((1,), i16), reg) - new_reg_32 = upcast_to_bf16(reg_16, high=False) + reg_32 = utils.vector_concat([reg_16, pad_vec_16]) + new_reg_32 = upcast_i8_to_bf16(reg_32, high=False) new_vec_32 = llvm.insertelement(empty_vec_32, new_reg_32, c(0, i32)) elif vector_len == 4: reg_32 = vector.bitcast(ir.VectorType.get((1,), i32), reg) - low = upcast_to_bf16(reg_32, high=False) - high = upcast_to_bf16(reg_32, high=True) + low = upcast_i8_to_bf16(reg_32, high=False) + high = upcast_i8_to_bf16(reg_32, high=True) new_vec_32 = llvm.insertelement(empty_vec_32, low, c(0, i32)) new_vec_32 = llvm.insertelement(new_vec_32, high, c(1, i32)) else: @@ -1490,11 +2313,44 @@ def upcast_to_bf16(reg, high): return FragmentedArray( _registers=new_registers, _layout=self.layout, _is_signed=is_signed ) + # TODO(bchetioui): handle conversions to/from other float8 types. + if cur_dtype in {bf16, f32} and new_dtype == f8e4m3fn: + if vector_len != 2: + raise NotImplementedError(vector_len) + new_registers = np.empty_like(self.registers) + empty_vec_16 = llvm.mlir_undef(ir.VectorType.get((1,), i16)) + for idx, reg in np.ndenumerate(self.registers): + e0 = vector.extract( + reg, + dynamic_position=[], + static_position=ir.DenseI64ArrayAttr.get([0]), + ) + e1 = vector.extract( + reg, + dynamic_position=[], + static_position=ir.DenseI64ArrayAttr.get([1]), + ) + # TODO(bchetioui): can we do faster than this? + if cur_dtype == bf16: + e0 = arith.extf(f32, e0) + e1 = arith.extf(f32, e1) + new_reg_16 = llvm.inline_asm( + i16, + [e1, e0], + "cvt.rn.satfinite.e4m3x2.f32 $0, $1, $2;", + "=h,f,f", + ) + new_registers[idx] = vector.bitcast( + ir.VectorType.get((2,), f8e4m3fn), + llvm.insertelement(empty_vec_16, new_reg_16, c(0, i32))) + return FragmentedArray( + _registers=new_registers, _layout=self.layout, _is_signed=is_signed + ) # Generic path. - from_float = ir.FloatType.isinstance(cur_dtype) - to_float = ir.FloatType.isinstance(new_dtype) - from_integer = ir.IntegerType.isinstance(cur_dtype) - to_integer = ir.IntegerType.isinstance(new_dtype) + from_float = isinstance(cur_dtype, ir.FloatType) + to_float = isinstance(new_dtype, ir.FloatType) + from_integer = isinstance(cur_dtype, ir.IntegerType) + to_integer = isinstance(new_dtype, ir.IntegerType) if from_float and to_float: cur_ty_width = ir.FloatType(cur_dtype).width new_ty_width = ir.FloatType(new_dtype).width @@ -1514,7 +2370,7 @@ def upcast_to_bf16(reg, high): case WGStridedFragLayout() | TiledLayout(): shape = ir.VectorType(self.registers.flat[0].type).shape upcast_ty = ir.VectorType.get(shape, larger_ty) - case WGMMARowFragLayout() | WGSplatFragLayout(): + case WGSplatFragLayout(): upcast_ty = larger_ty case _: raise NotImplementedError(f"Unsupported layout {self.layout}") @@ -1539,7 +2395,7 @@ def upcast_to_bf16(reg, high): case WGStridedFragLayout() | TiledLayout(): shape = ir.VectorType(self.registers.flat[0].type).shape new_reg_ty = ir.VectorType.get(shape, new_dtype) - case WGMMARowFragLayout() | WGSplatFragLayout(): + case WGSplatFragLayout(): new_reg_ty = new_dtype case _: raise NotImplementedError(f"Unsupported layout {self.layout}") @@ -1549,134 +2405,354 @@ def upcast_to_bf16(reg, high): _registers=new_registers, _layout=self.layout, _is_signed=is_signed ) - # NOTE: scratch can be reused immediately once this function returns. - def reduce_sum(self, scratch: ir.Value | None = None): - if isinstance(self.layout, WGSplatFragLayout): - [reg] = self.registers.flat - if ir.FloatType.isinstance(self.mlir_dtype): - op = mulf - elif ir.IntegerType.isinstance(self.mlir_dtype): - op = arith.muli - else: - raise NotImplementedError(self.mlir_dtype) - return FragmentedArray.splat( - op(reg, utils.c(math.prod(self.shape), self.mlir_dtype)), - (), - is_signed=self.is_signed, - ) - - if not isinstance(self.layout, WGStridedFragLayout): - raise NotImplementedError(f"Unsupported layout {self.layout}") - - if scratch is None: - raise ValueError("scratch must be provided") - - if ir.FloatType.isinstance(self.mlir_dtype): - op = addf - elif ir.IntegerType.isinstance(self.mlir_dtype): - op = arith.addi - else: - raise NotImplementedError(self.mlir_dtype) - - result = c(0, self.mlir_dtype) - for reg in self.registers: - result = op( - result, - vector.reduction(self.mlir_dtype, vector.CombiningKind.ADD, reg), - ) - scratch_ty = ir.MemRefType(scratch.type) - if scratch_ty.element_type != self.mlir_dtype or scratch_ty.shape != [4]: - raise ValueError(f"Expected shape={(4,)}, {self.mlir_dtype} (got {scratch_ty})") - - index = ir.IndexType.get() - warp_result = utils.warp_tree_reduce(result, op, 32) - warp_id = arith.divui(gpu.thread_id(gpu.Dimension.x), c(32, index)) - memref.store(warp_result, scratch, [warp_id]) - utils.warpgroup_barrier() - zero_index = c(0, index) - with mgpu.single_thread(per_block=False): - scratch_vec = vector.load( - ir.VectorType.get((4,), self.mlir_dtype), - scratch, - [zero_index], - ) - scratch_sum = vector.reduction( - self.mlir_dtype, vector.CombiningKind.ADD, scratch_vec - ) - memref.store(scratch_sum, scratch, [zero_index]) - utils.warpgroup_barrier() - result = memref.load(scratch, [zero_index]) - utils.warpgroup_barrier() # Make sure everyone is done using scratch. - return FragmentedArray.splat(result, (), is_signed=self.is_signed) - - def reduce(self, op: str | Callable[[ir.Value, ir.Value], ir.Value], axis): + def reduce( + self, + op: str | Callable[[ir.Value, ir.Value], ir.Value], + axis: int | Sequence[int], + scratch: ir.Value | None = None, + ) -> FragmentedArray: + i32 = ir.IntegerType.get_signless(32) + if isinstance(axis, int): + axis = (axis,) + splat_op = None + redux_op = None + # TODO(apaszke): For associative reductions that reduce both inside and + # across warps, we could just have everyone use SMEM atomics instead of + # performing an explicit warp reduction in registers. if isinstance(op, str): match op: case "add": - if ir.FloatType.isinstance(self.mlir_dtype): + reduced_elems = math.prod(self.shape[a] for a in axis) + if isinstance(self.mlir_dtype, ir.FloatType): op = addf - elif ir.IntegerType.isinstance(self.mlir_dtype): + splat_op = lambda x: arith.mulf(x, c(reduced_elems, x.type)) + # TODO(apaszke): Use redux.sync on Blackwell for f32. + elif isinstance(self.mlir_dtype, ir.IntegerType): op = arith.addi + splat_op = lambda x: arith.muli(x, c(reduced_elems, x.type)) + if utils.bitwidth(self.mlir_dtype) == 32: + redux_op = functools.partial(utils.redux, kind=nvvm.ReduxKind.ADD) else: raise NotImplementedError(self.mlir_dtype) case "max": - if ir.F32Type.isinstance(self.mlir_dtype): + if isinstance(self.mlir_dtype, ir.F32Type): op = self._lift_fast_instr("max.NaN.f32") - elif ir.FloatType.isinstance(self.mlir_dtype): + if utils.get_arch().major == 10: + redux_op = functools.partial(utils.redux, kind=nvvm.ReduxKind.FMAX) + elif isinstance(self.mlir_dtype, ir.F16Type): + op = self._lift_fast_packed_instr("max.NaN.f16x2", "max.NaN.f16") + elif isinstance(self.mlir_dtype, ir.BF16Type): + op = self._lift_fast_packed_instr("max.NaN.bf16x2", "max.NaN.bf16") + elif isinstance(self.mlir_dtype, ir.FloatType): op = arith.maximumf - elif ir.IntegerType.isinstance(self.mlir_dtype): + elif isinstance(self.mlir_dtype, ir.IntegerType): op = arith.maxsi if self.is_signed else arith.maxui + if utils.bitwidth(self.mlir_dtype) == 32: + kind = nvvm.ReduxKind.MAX if self.is_signed else nvvm.ReduxKind.UMAX + redux_op = functools.partial(utils.redux, kind=kind) + else: + raise NotImplementedError(self.mlir_dtype) + splat_op = lambda x: x + case "min": + if isinstance(self.mlir_dtype, ir.F32Type): + op = self._lift_fast_instr("min.NaN.f32") + if utils.get_arch().major == 10: + redux_op = functools.partial(utils.redux, kind=nvvm.ReduxKind.FMIN) + elif isinstance(self.mlir_dtype, ir.F16Type): + op = self._lift_fast_packed_instr("min.NaN.f16x2", "min.NaN.f16") + elif isinstance(self.mlir_dtype, ir.BF16Type): + op = self._lift_fast_packed_instr("min.NaN.bf16x2", "min.NaN.bf16") + elif isinstance(self.mlir_dtype, ir.FloatType): + op = arith.minimumf + elif isinstance(self.mlir_dtype, ir.IntegerType): + op = arith.minsi if self.is_signed else arith.minui + else: + raise NotImplementedError(self.mlir_dtype) + splat_op = lambda x: x + case "prod": + reduced_elems = math.prod(self.shape[a] for a in axis) + if isinstance(self.mlir_dtype, ir.FloatType): + op = arith.mulf + # For splat, prod(x, x, ..., x) = x^n + splat_op = lambda x: mlir_math.powf( + x, c(float(reduced_elems), x.type) + ) + elif isinstance(self.mlir_dtype, ir.IntegerType): + op = arith.muli + # For splat, use repeated squaring to compute x^n + def int_pow(x, n=reduced_elems): + result = c(1, x.type) + base = x + while n > 0: + if n % 2 == 1: + result = arith.muli(result, base) + base = arith.muli(base, base) + n //= 2 + return result + splat_op = int_pow else: raise NotImplementedError(self.mlir_dtype) case _: raise ValueError(f"Unrecognized reduction operator: {op}") - if self.layout != WGMMA_LAYOUT: - raise NotImplementedError(self.layout) - if axis != 1: + assert not isinstance(op, str) + match self.layout: + case WGStridedFragLayout(shape=_, vec_size=vec_size): + if set(axis) != set(range(len(self.shape))): + raise NotImplementedError( + "Warpgroup strided layout only support reductions along all axes" + ) + # We reinterpret the data as a tiled layout. We're reducing it all anyway. + layout = TiledLayout( + tiling=Tiling(((128 * vec_size,), (32 * vec_size,), (vec_size,))), + warp_dims=(-3,), + lane_dims=(-2,), + vector_dim=-1, + ) + return FragmentedArray( + _registers=self.registers.reshape( + layout.registers_shape((math.prod(self.shape),)) + ), + _layout=layout, + _is_signed=self.is_signed, + ).reduce(op, 0, scratch) + case WGSplatFragLayout(): + if splat_op is None: + raise NotImplementedError( + "Splat reductions only supported when the operator is a string" + ) + assert not self.registers.shape + return FragmentedArray( + _registers=np.asarray( + splat_op(self.registers.item()), dtype=object + ), + _layout=WGSplatFragLayout( + tuple(d for a, d in enumerate(self.shape) if a not in axis) + ), + _is_signed=self.is_signed, + ) + case TiledLayout(): + pass + case _: + raise NotImplementedError(self.layout) + if len(self.layout.base_tile_shape) != len(self.shape): raise NotImplementedError + if isinstance(axis, int): + axis = (axis,) + layout = self.layout + tiled_tiling_shape = layout.tiled_tiling_shape + reduced_dims = layout.tiling.tile_dimension(axis[0]) + for a in axis[1:]: + reduced_dims = tuple( + r or d for r, d in zip(reduced_dims, layout.tiling.tile_dimension(a), strict=True) + ) + regs_shape = self.registers.shape + reduced_shape = tuple( + d if r else 1 for r, d in zip(reduced_dims, regs_shape, strict=True) + ) + remaining_shape = tuple( + 1 if r else d for r, d in zip(reduced_dims, regs_shape) + ) + out_regs = np.empty(remaining_shape, dtype=object) index = ir.IndexType.get() - i32 = ir.IntegerType.get_signless(32) - row_tile_dim = self.registers.shape[0] - row_subtile_dim = self.registers.shape[4] - new_regs = np.empty((row_tile_dim, row_subtile_dim), dtype=object) - assert self.registers.shape[-1] == 1 - for row_tile, row_subtile in np.ndindex(new_regs.shape): - # Reduce the registers owned by the current thread over n tiles - reg_index = [0] * self.registers.ndim - reg_index[0] = row_tile - reg_index[4] = row_subtile - thread_result_vec = self.registers[tuple(reg_index)] - for n_tile in range(1, self.registers.shape[1]): - reg_index[1] = n_tile - thread_result_vec = op( - thread_result_vec, self.registers[tuple(reg_index)] + for out_idx in np.ndindex(remaining_shape): + out_reg = None + for red_idx in np.ndindex(reduced_shape): + src_idx = tuple(o + r for o, r in zip(out_idx, red_idx)) + if out_reg is None: + out_reg = self.registers[src_idx] + else: + out_reg = op(out_reg, self.registers[src_idx]) + assert out_reg is not None + # Reduce within the vector dimension, if necessary. + if reduced_dims[layout.vector_dim]: + [vec_len] = ir.VectorType(out_reg.type).shape + scalar_out_reg = None + for i in range(vec_len): + scalar = vector.extract( + out_reg, + dynamic_position=[], + static_position=ir.DenseI64ArrayAttr.get([i]), + ) + scalar_out_reg = ( + scalar if scalar_out_reg is None else op(scalar_out_reg, scalar) + ) + out_reg = vector.broadcast( + ir.VectorType.get((1,), out_reg.type.element_type), scalar_out_reg ) - - thread_result = vector.extractelement(thread_result_vec, position=c(0, index)) - for i in range(1, self.layout.vector_length): - thread_result = op( - thread_result, - vector.extractelement(thread_result_vec, position=c(i, index)), + # Reduce across warp lanes, if necessary (using warp shuffles). + if any(reduced_dims[d] for d in layout.partitioned_lane_dims): + # TODO(apaszke): Reenable Redux after targeted optimization and benchmarking. + redux_op = None + if redux_op is not None: + mask = [True] # The bit significance grows together with the index. + mask_shift_bits = 0 + lane_stride = 1 + for d in layout.lane_dims[::-1]: + if isinstance(d, Replicated): + size = d.times + reduced = False + else: + size = tiled_tiling_shape[d] + reduced = reduced_dims[d] + if reduced: + mask = mask * size + else: + mask += [False] * (len(mask) * (size - 1)) + # This could really be computed as: + # d_idx = (lane_index // lane_stride) % size + # mask_shift += d_idx * lane_stride + # but if you look closely enough and realize that strides/sizes + # are powers of 2, the div/mod/mul is just an AND, and + is an OR: + # mask_shift |= lane_index & ((size - 1) * lane_stride) + # What's more, instead of repeatedly doing the AND/OR, we can just + # compute which bits of the lane_index we want to use statically, + # and use a single AND operation to extract them after the loop. + assert lane_stride.bit_count() == 1 and size.bit_count() == 1 + mask_shift_bits |= (size - 1) * lane_stride + lane_stride *= size + mask = sum(1 << i for i, m in enumerate(mask) if m) + lane_index = arith.remui(utils.thread_idx(), c(utils.WARP_SIZE, i32)) + mask_shift = arith.andi(lane_index, arith.constant(i32, mask_shift_bits)) + dyn_mask = arith.shli(arith.constant(i32, mask), mask_shift) + out_reg = redux_op(out_reg, dyn_mask) + else: + lane_stride = 1 + for d in layout.lane_dims[::-1]: # Iterate minor-to-major + if isinstance(d, Replicated): + lane_stride *= d.times + elif not reduced_dims[d]: + lane_stride *= tiled_tiling_shape[d] + else: + assert lane_stride.bit_count() == 1 + reduction_size = tiled_tiling_shape[d] + while reduction_size > 1: + other_out_reg = utils.shfl_bfly(out_reg, lane_stride) + out_reg = op(out_reg, other_out_reg) + lane_stride *= 2 + reduction_size //= 2 + assert lane_stride == WARP_SIZE, lane_stride + # TODO(apaszke): At the moment we do a barrier for every output register, + # which is very expensive. If we have enough scratch, we should just try + # using a single barrier for multiple reductions. + # Reduce across warps in the warpgroup, if necessary. + if any(reduced_dims[d] for d in layout.partitioned_warp_dims): + if scratch is None: + raise ValueError( + "scratch must be provided when cross-warp reduction is required" + ) + [vec_len] = ir.VectorType(out_reg.type).shape + scratch_ty = ir.MemRefType(scratch.type) + if scratch_ty.rank != 1: + raise ValueError(f"Expected rank 1 for scratch, got {scratch_ty.rank}") + if scratch_ty.element_type != self.mlir_dtype: + raise ValueError( + f"Expected element type {self.mlir_dtype} for scratch, got" + f" {scratch_ty.element_type}" + ) + # TODO(apaszke): All lanes that replicate data can share the same scratch. + # For now we treat the complete reduction as a special case. + reduces_all_dims = set(axis) == set(range(len(self.shape))) + unique_lanes = 1 if reduces_all_dims else 32 + if scratch_ty.shape[0] < WARPS_IN_WARPGROUP * unique_lanes * vec_len: + raise ValueError("Insufficient scratch space for cross-warp reduction") + if scratch_ty.get_strides_and_offset()[0] != [1]: + raise ValueError("Expected scratch to be contiguous") + thread_idx = utils.thread_idx() + if reduces_all_dims: + lane_idx = c(0, i32) + else: + lane_idx = arith.remui(thread_idx, c(WARP_SIZE, i32)) + warp_idx = arith.divui( + arith.remui(thread_idx, c(WARPGROUP_SIZE, i32)), c(WARP_SIZE, i32) ) - - # Do a shuffle to reduce in groups of 4 consecutive threads. - result = thread_result - for i in (1, 2): - other_result = nvvm.shfl_sync( - result.type, - c(0xFFFFFFFF, i32), - result, - c(i, i32), - c(0x1F, i32), - nvvm.ShflKind.bfly, + spill_base = arith.muli(lane_idx, c(WARPS_IN_WARPGROUP, i32)) + store_idx = arith.index_cast(index, arith.addi(spill_base, warp_idx)) + vector.store( + out_reg, scratch, [arith.muli(store_idx, c(vec_len, index))] ) - result = op(result, other_result) - new_regs[row_tile, row_subtile] = result + utils.warpgroup_barrier() + # warp_idx & warp_group_mask gives you the reduction group of the current warp. + if all(isinstance(d, int) and reduced_dims[d] for d in layout.warp_dims): + warp_offsets, warp_group_mask = [*range(WARPS_IN_WARPGROUP)], 0 + else: + # 4 has only two non-trivial prime factors: 2 and 2. + assert len(layout.warp_dims) == 2 + wd0, wd1 = layout.warp_dims + if isinstance(wd0, int) and reduced_dims[wd0]: + warp_offsets, warp_group_mask = [0, 2], 1 + else: + assert isinstance(wd1, int) and reduced_dims[wd1] + warp_offsets, warp_group_mask = [0, 1], 2 + reg_ty = out_reg.type + out_reg = None + warp_reduction_group = arith.andi(warp_idx, arith.constant(i32, warp_group_mask)) + for warp_offset in warp_offsets: + reduced_warp = arith.addi(warp_reduction_group, c(warp_offset, i32)) + load_idx = arith.index_cast( + index, + arith.muli(arith.addi(spill_base, reduced_warp), c(vec_len, i32)), + ) + part = vector.load(reg_ty, scratch, [load_idx]) + out_reg = part if out_reg is None else op(out_reg, part) + utils.warpgroup_barrier() # Make sure everyone is done using scratch. + out_regs[out_idx] = out_reg + # Infer the output layout and reshape the registers accordingly. + reduced_logical_shape = list(self.shape) + for a in sorted(axis, reverse=True): + del reduced_logical_shape[a] + if not reduced_logical_shape: # Complete reduction results in a splat. + reduced_layout: FragmentedLayout = WGSplatFragLayout(()) + assert out_regs.size == 1 + out_reg = out_regs.flat[0] + assert ir.VectorType(out_reg.type).shape == [1] + out_reg = vector.extract( + out_reg, + dynamic_position=[], + static_position=ir.DenseI64ArrayAttr.get([0]), + ) + out_regs = np.asarray(out_reg, dtype=object) + else: + reduced_layout = layout.reduce(axis) + out_regs = out_regs.reshape( + reduced_layout.registers_shape(tuple(reduced_logical_shape)) + ) return FragmentedArray( - _registers=new_regs, _layout=WGMMA_ROW_LAYOUT, _is_signed=self.is_signed + _registers=out_regs, _layout=reduced_layout, _is_signed=self.is_signed ) - def broadcast(self, shape): + def broadcast(self, shape) -> FragmentedArray: + if isinstance(self.layout, WGStridedFragLayout): + src_shape, dst_shape = self.layout.shape, shape + if len(src_shape) > len(dst_shape): + raise ValueError( + f"Shape length mismatch. Expected len({src_shape}) <= len({dst_shape})" + ) + if not all(s == 1 or s == d for s, d in zip(src_shape[::-1], dst_shape[::-1])): + raise ValueError( + "Can broadcast if all source dimensions match trailing target" + " dimensions by being equal or set to 1. Broadcasting from" + f" {src_shape} to {dst_shape}" + ) + rank_diff = len(dst_shape) - len(src_shape) + src_shape = tuple([1] * rank_diff + list(src_shape)) + + assert len(src_shape) == len(dst_shape), (src_shape, dst_shape) + len_suffix = next( + (i for i in range(len(src_shape)) if src_shape[~i] != dst_shape[~i]), + len(src_shape) + ) + if len_suffix > 0 and all(x == 1 for x in src_shape[:-len_suffix]): + return FragmentedArray( + _registers=np.tile(self.registers, np.prod(dst_shape[:-len_suffix])), + _layout=WGStridedFragLayout(shape, self.layout.vec_size), + _is_signed=self.is_signed, + ) + + raise NotImplementedError( + "Only major-most broadcast for WGStridedFragLayout is implemented." + f" Broadcasting from: {src_shape}, to: {dst_shape}." + ) + if not isinstance(self.layout, WGSplatFragLayout): raise NotImplementedError(self.layout) @@ -1692,7 +2768,7 @@ def broadcast(self, shape): _is_signed=self.is_signed, ) - def reshape(self, shape): + def reshape(self, shape: tuple[int, ...]) -> FragmentedArray: if self.shape == shape: return self if math.prod(shape) != math.prod(self.shape): @@ -1701,35 +2777,113 @@ def reshape(self, shape): match self.layout: case WGSplatFragLayout() | WGStridedFragLayout(): new_layout = dataclasses.replace(self.layout, shape=shape) + return FragmentedArray( + _registers=self.registers, + _layout=new_layout, + _is_signed=self.is_signed, + ) + case TiledLayout(): + base_tile_shape = self.layout.base_tile_shape + assert base_tile_shape + old_shape_suffix = self.shape[-len(base_tile_shape):] + new_shape_suffix = shape[-len(base_tile_shape):] + # We already know that old_shape_suffix[0] is divisible by + # base_tile_shape[0]. + if ( + old_shape_suffix[1:] != new_shape_suffix[1:] + or new_shape_suffix[0] % base_tile_shape[0] + ): + raise ValueError( + f"Can't reshape {self.shape} to {shape} with a tiled layout with" + f" base tile of {base_tile_shape}" + ) + new_registers_shape = self.layout.registers_shape(shape) + return FragmentedArray( + _registers=self.registers.reshape(new_registers_shape), + _layout=self.layout, + _is_signed=self.is_signed, + ) case _: raise NotImplementedError(self.layout) - return FragmentedArray( - _registers=self.registers, _layout=new_layout, _is_signed=self.is_signed - ) - - def broadcast_minor(self, n): - if self.layout != WGMMA_ROW_LAYOUT: - raise NotImplementedError + def broadcast_minor(self, n) -> FragmentedArray: + if len(self.shape) != 1: + raise ValueError("Broadcast minor is only supported for 1D arrays") if n % 8: - raise ValueError("Number of columns must be divisible by 8") - reg_shape = WGMMA_LAYOUT.registers_shape((self.shape[0], n)) - new_regs = np.empty(reg_shape, dtype=object) - dtype = self.mlir_dtype - for (row_tile, row_subtile), reg in np.ndenumerate(self.registers): - tile = [slice(None)] * len(new_regs.shape) - tile[0] = row_tile - tile[4] = row_subtile - new_regs[tuple(tile)] = vector.splat( - ir.VectorType.get((WGMMA_LAYOUT.vector_length,), dtype), reg + raise ValueError(f"The broadcast dimension must be a multiple of 8, got {n}") + if self.layout == WGMMA_ROW_LAYOUT: + new_layout = WGMMA_LAYOUT + elif self.layout == TCGEN05_ROW_LAYOUT: + new_layout = TCGEN05_LAYOUT + else: + raise NotImplementedError(self.layout) + return self.broadcast_in_dim((self.shape[0], n), (0,), new_layout) + + def broadcast_in_dim( + self, shape, source_dimensions, layout: FragmentedLayout + ) -> FragmentedArray: + for i, target_dim in enumerate(source_dimensions): + if self.shape[i] != shape[target_dim]: + raise ValueError( + f"Dimension {i} has size {self.shape[i]} in source shape and" + f" {shape[target_dim]} in shape after broadcast" + ) + if isinstance(self.layout, WGSplatFragLayout): + return type(self).splat( + self.registers.item(), shape, layout, is_signed=self.is_signed + ) + if isinstance(self.layout, WGStridedFragLayout) and isinstance(layout, WGStridedFragLayout): + new_dims = set(range(len(shape))) - set(source_dimensions) + vec_match = self.layout.vec_size == layout.vec_size + broadcast_dim_match = new_dims == set(range(len(new_dims))) + assert layout.shape == shape, (layout.shape, shape) + if vec_match and broadcast_dim_match: + return FragmentedArray( + _registers=np.tile( + self.registers, + np.prod(shape[:len(new_dims)]), + ), + _layout=layout, + _is_signed=self.is_signed, + ) + if not isinstance(self.layout, TiledLayout) or not isinstance(layout, TiledLayout): + raise NotImplementedError(self.layout, layout) + if any(d1 >= d2 for d1, d2 in zip(source_dimensions, source_dimensions[1:])): + raise NotImplementedError("source_dimensions must be strictly increasing") + if len(layout.base_tile_shape) != len(shape): + raise NotImplementedError("Tiling rank different than broadcast result rank") + new_dimensions = sorted(set(range(len(shape))) - set(source_dimensions)) + expected_layout = layout.reduce(new_dimensions) + if expected_layout != self.layout: + raise ValueError( + "Source and destination layouts aren't compatible for a broadcast" ) + new_registers_shape = layout.registers_shape(shape) + pre_broadcast_registers_shape = list(new_registers_shape) + for new_dim in new_dimensions: + for i, is_new in enumerate(layout.tiling.tile_dimension(new_dim)): + if is_new: + pre_broadcast_registers_shape[i] = 1 + # The broadcast for all dims but the vector_dim amounts to repeating the + # registers along the new dimensions. Along the vector_dim, we actually need + # to extend the vector length to change the type of the registers. + if layout.vector_length != self.layout.vector_length: + assert self.layout.vector_length == 1 + registers = np.empty_like(self.registers) + for idx, reg in np.ndenumerate(self.registers): + registers[idx] = utils.vector_concat([reg] * layout.vector_length) + else: + registers = self.registers + new_registers = np.broadcast_to( + registers.reshape(pre_broadcast_registers_shape), new_registers_shape, + ) return FragmentedArray( - _registers=new_regs, _layout=WGMMA_LAYOUT, _is_signed=self.is_signed + _registers=new_registers, _layout=layout, _is_signed=self.is_signed, ) def select(self, on_true, on_false): if ( - not ir.IntegerType.isinstance(self.mlir_dtype) + not isinstance(self.mlir_dtype, ir.IntegerType) or ir.IntegerType(self.mlir_dtype).width != 1 ): raise NotImplementedError @@ -1739,6 +2893,21 @@ def select(self, on_true, on_false): lambda t, p, f: arith.select(p, t, f), self, on_false, ) + @classmethod + def build( + cls, + shape: tuple[int, ...], + layout: FragmentedLayout, + fn: Callable[..., ir.Value], # ir.Value varargs, one for each dim + *, + is_signed: bool | None = None, + ) -> FragmentedArray: + undef = llvm.mlir_undef(ir.IntegerType.get_signless(32)) + dummy = cls.splat(undef, shape, layout, is_signed=False) + return dummy.foreach( + lambda _, idx: fn(*idx), create_array=True, is_signed=is_signed + ) + def foreach( self, fn: Callable[[ir.Value, tuple[ir.Value, ...]], ir.Value | None], @@ -1749,59 +2918,156 @@ def foreach( """Call a function for each value and index.""" index = ir.IndexType.get() new_regs = None - if create_array: - new_regs = np.full_like(self.registers, llvm.mlir_undef(self.registers.flat[0].type)) + orig_fn = fn + del fn + def wrapped_fn(*args): + nonlocal new_regs + result = orig_fn(*args) + old_reg_type = self.registers.flat[0].type + # Lazily create new_regs once we know the desired output type. + if create_array and new_regs is None: + assert result is not None + if isinstance(old_reg_type, ir.VectorType): + new_reg_type = ir.VectorType.get(old_reg_type.shape, result.type) + else: + new_reg_type = result.type + new_regs = np.full_like(self.registers, llvm.mlir_undef(new_reg_type)) + return result for mlir_idx, reg_idx in zip(self.layout.thread_idxs(self.shape), np.ndindex(self.registers.shape), strict=True): reg = self.registers[reg_idx] assert len(mlir_idx) == len(self.shape), (mlir_idx, self.shape) - [elems] = ir.VectorType(reg.type).shape - for i in range(elems): - i = c(i, index) - val = fn(vector.extractelement(reg, position=i), (*mlir_idx[:-1], arith.addi(mlir_idx[-1], i))) + if isinstance(reg.type, ir.VectorType): + [elems] = ir.VectorType(reg.type).shape + for i in range(elems): + c_i = c(i, index) + val = wrapped_fn( + vector.extract( + reg, + dynamic_position=[], + static_position=ir.DenseI64ArrayAttr.get([i]), + ), + (*mlir_idx[:-1], arith.addi(mlir_idx[-1], c_i)), + ) + if create_array: + assert new_regs is not None + new_regs[reg_idx] = vector.insert( + val, + new_regs[reg_idx], + dynamic_position=[], + static_position=ir.DenseI64ArrayAttr.get([i]), + ) + else: + val = wrapped_fn(reg, mlir_idx) if create_array: - new_regs[reg_idx] = vector.insertelement(val, new_regs[reg_idx], position=i) + assert new_regs is not None + new_regs[reg_idx] = val if create_array: + assert new_regs is not None return FragmentedArray(_registers=new_regs, _layout=self.layout, _is_signed=is_signed) - def debug_print(self, fmt: str): + def debug_print(self, fmt: str) -> None: idx_fmt = ", ".join(["{}"] * len(self.shape)) @self.foreach def _(val, idx): fmt_str = fmt.format(f"[{idx_fmt}]: {{}}") utils.debug_print(fmt_str, *idx, val, uniform=False) - def store_untiled(self, ref: ir.Value, *, vector_store: bool = True): - if not ir.MemRefType.isinstance(ref.type): + def store_untiled( + self, ref: ir.Value | utils.MultimemRef, *, swizzle: int = 16, optimized: bool = True + ) -> None: + if not isinstance(ref.type, ir.MemRefType): raise ValueError(ref) - - def vs_unsupported(): - if not vector_store: - raise NotImplementedError( - f"Can't use non-vector stores with layout {self.layout}" - ) - match self.layout: - case WGMMARowFragLayout(): - self._store_untiled_wgmma_row(ref) case WGSplatFragLayout(): - vs_unsupported() + if isinstance(ref, utils.MultimemRef): + raise NotImplementedError("Splat layout does not support multimem") + # All values are the same so swizzle does not affect anything here. self._store_untiled_splat(ref) case WGStridedFragLayout(): - vs_unsupported() - self._store_untiled_wg_strided(ref) + if swizzle != 16: + raise ValueError("Only TiledLayouts support swizzling") + assert isinstance(self.layout, WGStridedFragLayout) + for get, _update, ref, idx in self.transfer_strided(ref, self.layout.vec_size): + if isinstance(ref, utils.MultimemRef): + ref.store(get(self.registers), idx) + else: + vector.store(get(self.registers), ref, idx) case TiledLayout(): - self._store_untiled_tiled(ref, vector_store=vector_store) + ref_shape = ir.MemRefType(ref.type).shape + ref = utils.memref_reshape(ref, (*(1 for _ in ref_shape), *ref_shape)) + self.store_tiled(ref, swizzle=swizzle, optimized=optimized) case _: raise NotImplementedError(self.layout) + @classmethod + def load_reduce_untiled( + cls, + ref: utils.MultimemRef, + layout: TiledLayout | WGStridedFragLayout, + reduction: utils.MultimemReductionOp, + swizzle: int = 16, + is_signed: bool | None = None, + ): + ref_ty = ir.MemRefType(ref.type) + shape = tuple(ref_ty.shape) + if isinstance(layout, WGStridedFragLayout): + if swizzle != 16: + raise ValueError("Only TiledLayouts support swizzling") + registers = np.empty(layout.registers_shape(shape), dtype=object) + vec_ty = ir.VectorType.get((layout.vec_size,), ref_ty.element_type) + for _get, update, ref, idx in cls.transfer_strided(ref, layout.vec_size): + ptr = utils.memref_ptr(utils.memref_slice(ref.ref, tuple(idx))) + update(registers, utils.multimem_load_reduce(vec_ty, ptr, reduction, is_signed)) + return cls(_registers=registers, _layout=layout, _is_signed=is_signed) + ref = utils.memref_reshape(ref, (*(1 for _ in shape), *shape)) + return cls.load_tiled( + ref.ref, + swizzle=swizzle, + is_signed=is_signed, + layout=layout, + optimized=False, # multimem refs are always GMEM refs. + _load_fun=functools.partial( + utils.multimem_load_reduce, reduction=reduction, is_signed=is_signed + ), + # multimem_load_reduce supports vectors of narrow floats, so we don't + # need to do any casting. + _narrow_float_as_int=False, + ) + + @classmethod + def load_untiled( + cls, + ref: ir.Value, + *, + layout: TiledLayout, + swizzle: int = 16, + is_signed: bool | None = None, + optimized: bool = True, + ) -> FragmentedArray: + ref_ty = ir.MemRefType(ref.type) + ref = utils.memref_reshape(ref, (*(1 for _ in ref_ty.shape), *ref_ty.shape)) + return cls.load_tiled( + ref, swizzle=swizzle, is_signed=is_signed, layout=layout, optimized=optimized + ) + def _store_untiled_splat(self, ref: ir.Value): + if math.prod(self.shape) == 1: + c0 = c(0, ir.IndexType.get()) + memref.store( + self.registers.flat[0], ref, [c0] * len(ir.MemRefType(ref.type).shape) + ) + return + vec_size = 64 // mgpu.bitwidth(self.mlir_dtype) if np.prod(self.shape) < vec_size * WARPGROUP_SIZE: vec_size = 1 if np.prod(self.shape) % WARPGROUP_SIZE * vec_size: - raise ValueError(self.shape, WARPGROUP_SIZE, vec_size) + raise NotImplementedError( + "Arrays with the splat layout can only be stored when they have a" + f" single element or a multiple of {WARPGROUP_SIZE} elements" + ) fa = FragmentedArray.splat( self.registers.flat[0], @@ -1811,95 +3077,30 @@ def _store_untiled_splat(self, ref: ir.Value): ) fa.store_untiled(ref) - def _store_untiled_wg_strided(self, ref: ir.Value): - ref_ty = ir.MemRefType(ref.type) - try: - # Flattening the reference potentially produces simpler PTX but - # if the ref is not already 1D and has strided dimensions - # flattening won't work. We use a different variable for ref in - # case `NotImplementedError` is thrown by - # .linear_thread_idxs(). - ref_ = mgpu.memref_fold(ref, 0, len(ref_ty.shape)) - idxs = ([i] for i in self.layout.linear_thread_idxs()) - except NotImplementedError: - ref_ = ref - idxs = self.layout.thread_idxs() - ref_shape = tuple(ref_ty.shape) - if ref_shape != self.shape: - raise ValueError((ref_shape, self.shape)) - for idx, reg in zip(idxs, self.registers.flat): - vector.store(reg, ref_, idx) - - def _store_untiled_wgmma_row(self, ref: ir.Value): - """Stores an array with a WGMMA row layout.""" - assert self.layout == WGMMA_ROW_LAYOUT - index = ir.IndexType.get() - tid = arith.index_cast(ir.IndexType.get(), mgpu.thread_idx()) - - is_first = arith.cmpi( - arith.CmpIPredicate.eq, arith.remui(tid, c(4, index)), c(0, index) - ) - # Consecutive groups of 4 threads hold the same value in this layout, - # therefore we only need to transfer data from one of them. - with utils.when(is_first): - for (idx,), value in zip( - self.layout.thread_idxs(self.shape), self.registers.flatten() - ): - memref.store(value, ref, [idx]) - - def _store_untiled_tiled(self, ref: ir.Value, *, vector_store: bool = True): - """Stores an array with a tiled layout. Not optimized at the moment.""" - if utils.bitwidth(self.mlir_dtype) < 8: - raise NotImplementedError(f"Can't store sub-byte types ({self.mlir_dtype=})") - i32 = ir.IntegerType.get_signless(32) - layout = self.layout - assert isinstance(layout, TiledLayout) - ref_strides, _ = ir.MemRefType(ref.type).get_strides_and_offset() - if vector_store and ref_strides[layout.vector_dim] != 1: - raise NotImplementedError( - "Can't use vector stores with non-unit minormost stride" - ) - strides = layout.tiling.tile_strides(ref_strides) - smem_space = ir.Attribute.parse("#gpu.address_space") - ref_space = ir.MemRefType(ref.type).memory_space - memory_space = None - if str(ref_space) == str(smem_space): - memory_space = 3 - elif ref_space: - raise NotImplementedError(f"Unexpected ref space {ref_space}") - ptr = utils.memref_ptr(ref, memory_space=memory_space) - # Fold warp and lane offsets into the pointer once, since they are dynamic. - dyn_strides = [ - arith.constant(i32, s) for s in strides[-layout.tiled_tiling_rank :] - ] - warp_offset = utils.dyn_dot(layout.warp_indices(), dyn_strides) - lane_offset = utils.dyn_dot(layout.lane_indices(), dyn_strides) - dyn_offset = arith.addi(warp_offset, lane_offset) - ptr = utils.getelementptr(ptr, [dyn_offset], self.mlir_dtype) - # All warp tile offsets are static and can be fused into the store. - for tile_idx, reg in np.ndenumerate(self.registers): - if vector_store: - elems = [reg] - else: - index = ir.IndexType.get() - elems = [ - vector.extractelement(reg, position=c(i, index)) - for i in range(ir.VectorType(reg.type).shape[0]) - ] - for i, e in enumerate(elems): - tile_idx_local = list(tile_idx) - tile_idx_local[layout.vector_dim] += i - tile_idx_local = list(tile_idx_local) - lin_idx = sum(i * s for i, s in zip(tile_idx_local, strides, strict=True)) - reg_ptr = utils.getelementptr(ptr, [lin_idx], self.mlir_dtype) - llvm.store(e, reg_ptr) - - def store_tiled(self, ref, swizzle: int | None): + def store_tiled(self, ref: ir.Value | utils.MultimemRef, swizzle: int | None, optimized: bool = True): if not isinstance(self.layout, TiledLayout): raise NotImplementedError(self.layout) layout, shape = self.layout, self.shape - for get, _, ptr in self.transfer_tiled2(ref, swizzle, layout, shape): - llvm.store(get(self.registers), ptr) + # Note that the loop below will "race" for layouts that replicate data. + # However, in that case all of the racing writes store the same data, which + # is ok in the CUDA memory model. + if isinstance(ref, utils.MultimemRef): + stores = self.transfer_tiled(ref.ref, swizzle, layout, shape, optimized) + for get, _update, _idx, ptr in stores: + utils.multimem_store(ptr, get(self.registers)) + else: + stores = self.transfer_tiled(ref, swizzle, layout, shape, optimized) + for get, _update, _idx, ptr in stores: + reg = get(self.registers) + reg_ty = ir.VectorType(reg.type) + element_bitwidth = utils.bitwidth(reg_ty.element_type) + if ( + isinstance(reg_ty.element_type, ir.FloatType) + and element_bitwidth <= 8 + ): + narrow_int = ir.IntegerType.get_signless(element_bitwidth) + reg = vector.bitcast(ir.VectorType.get(reg_ty.shape, narrow_int), reg) + llvm.store(reg, ptr) @classmethod def load_tiled( @@ -1909,120 +3110,102 @@ def load_tiled( *, is_signed: bool | None = None, layout: FragmentedLayout = WGMMA_LAYOUT, - ): + optimized: bool = True, + _load_fun: Callable[[ir.VectorType, ir.Value], ir.Value] = llvm.load, + _narrow_float_as_int: bool = True, + ) -> FragmentedArray: + if not isinstance(layout, TiledLayout): + raise NotImplementedError(layout) ref_ty = ir.MemRefType(ref.type) dtype = ref_ty.element_type - match layout: - case TiledLayout(): - ref_ty = ir.MemRefType(ref.type) - tiled_shape = ref_ty.shape - if len(tiled_shape) % 2: - raise ValueError("Tiled reference must have even rank") - tiling = Tiling((tiled_shape[len(tiled_shape) // 2 :],)) - shape = tiling.untile_shape(tiled_shape) - zero = ( - vector.splat( - ir.VectorType.get((layout.vector_length,), dtype), c(0, dtype) - ), - ) - registers = np.full(layout.registers_shape(shape), zero, dtype=object) - reg_ty = ir.VectorType.get((layout.vector_length,), ref_ty.element_type) - for _, update, ptr in cls.transfer_tiled2(ref, swizzle, layout, shape): - update(registers, llvm.load(reg_ty, ptr)) - case _: - raise NotImplementedError(layout) + tiled_shape = ref_ty.shape + if len(tiled_shape) % 2: + raise ValueError("Tiled reference must have even rank") + if len(tiled_shape) < 2: + raise ValueError("Tiled reference must have at least two dimensions") + tiling = Tiling((tiled_shape[len(tiled_shape) // 2 :],)) + shape = tiling.untile_shape(tiled_shape) + reg_ty = ir.VectorType.get((layout.vector_length,), dtype) + zero = vector.broadcast(reg_ty, c(0, dtype)) + registers = np.full(layout.registers_shape(shape), zero, dtype=object) + is_narrow_float = ( + isinstance(dtype, ir.FloatType) and utils.bitwidth(dtype) <= 8 + ) + narrow_int = ir.IntegerType.get_signless(utils.bitwidth(dtype)) + # Narrow floats are not supported by LLVM, so we need to transfer them as + # narrow ints and bitcast back to the desired type. + transfer_ty = ir.VectorType.get( + (layout.vector_length,), + narrow_int if is_narrow_float and _narrow_float_as_int else dtype + ) + loads = cls.transfer_tiled(ref, swizzle, layout, shape, optimized) + for _get, update, _idx, ptr in loads: + loaded_reg = _load_fun(transfer_ty, ptr) + if is_narrow_float and _narrow_float_as_int: + loaded_reg = vector.bitcast(reg_ty, loaded_reg) + update(registers, loaded_reg) return cls(_registers=registers, _layout=layout, _is_signed=is_signed) - @staticmethod - def transfer_tiled(shape, dtype, swizzle: int | None): - # TODO(apaszke): We could use ldmatrix/stmatrix for 16-bit types. - bw = mgpu.bitwidth(dtype) - m, n = shape - assert m % 64 == 0 and n % 8 == 0 # Implied by the layout. - cols_per_tile = swizzle_elems = (swizzle * 8) // bw - if n < swizzle_elems: - cols_per_tile = n - else: - assert n % swizzle_elems == 0, (n, swizzle_elems) - if swizzle not in {32, 64, 128}: - raise NotImplementedError("Only swizzled stores supported") - - c = arith.ConstantOp.create_index - tidx = arith.remui(gpu.thread_id(gpu.Dimension.x), c(WARPGROUP_SIZE)) - lane_id = arith.remui(tidx, c(32)) # {0, 1, ..., 31} - warp_id = arith.divui(tidx, c(32)) # {0, 1, 2, 3} - sub_row_base = arith.divui(lane_id, c(4)) # {0, 1, ..., 7} - if bw > 16: # Stagger is only necessary for values larger than 16bit. - # We split the rows into two groups (left/right) and change the order in - # which they perform accesses to avoid bank conflicts. - # It seems that the STS.64 is 2x faster (and the hardware reports no - # conflicts) when the conflicts are split between half-warps, as - # opposed to having them within the half-warp. This requires a - # little more work for the selects, but is ultimately worth it. - match swizzle: - case 128: - is_stagger_left = arith.cmpi( - arith.CmpIPredicate.eq, arith.remui(sub_row_base, c(2)), c(0) - ) - case 64: - is_stagger_left = arith.cmpi( - arith.CmpIPredicate.eq, - arith.remui(arith.divui(sub_row_base, c(2)), c(2)), - c(0), + @classmethod + def transfer_strided(self, ref: ir.Value, vec_size: int): + ref_ty = ir.MemRefType(ref.type) + layout = WGStridedFragLayout(shape=tuple(ref_ty.shape), vec_size=vec_size) + try: + # Flattening the reference potentially produces simpler PTX but + # if the ref is not already 1D and has strided dimensions + # flattening won't work. + ref = mgpu.memref_fold(ref, 0, len(ref_ty.shape)) + except ValueError: + if vec_size > 1: + ref_ty = ir.MemRefType(ref.type) + shape = ref_ty.shape + strides, _ = ref_ty.get_strides_and_offset() + # Try to fold contiguous dimension pairs. + for i in reversed(range(len(shape) - 1)): + if strides[i] == shape[i+1] * strides[i+1]: + ref = mgpu.memref_fold(ref, i, 2) + ref_ty = ir.MemRefType(ref.type) + shape = ref_ty.shape + strides, _ = ref_ty.get_strides_and_offset() + has_contiguous_dim = False + for size, stride in zip(shape, strides): + if stride == 1: + has_contiguous_dim = True + if size % vec_size != 0: + raise ValueError( + "The contiguous dimension of the reference must be a" + f" multiple of the layout's vector size (got {size} and" + f" vector size {vec_size})" + ) from None + elif size > 1: + if stride % vec_size != 0: + raise ValueError( + "Non-contiguous dimension of the reference must have strides" + " that are multiples of the layout's vector size (got" + f" {stride} and vector size {vec_size})" + ) from None + if not has_contiguous_dim: + raise ValueError( + "The reference must have a contiguous dimension when vec_size > 1" ) - case 32: - # 32-byte tiles of 4-byte types have only 8 columns so there is no way - # to stagger the memory accesses within a single tile. We could do it - # across tiles, but that would be a completely different scheme. - raise NotImplementedError - case _: - raise AssertionError(swizzle) - stagger_amount = swizzle // 64 - if (cols_per_tile // 8) % (stagger_amount * 2): - raise NotImplementedError + layout = WGStridedFragLayout(shape=tuple(ref_ty.shape), vec_size=vec_size) + idx_gen = layout.thread_idxs(tuple(ref_ty.shape)) else: - # We rely on canonicalization to clean up the selects. - i1 = ir.IntegerType.get_signless(1) - is_stagger_left = arith.constant(i1, ir.BoolAttr.get(True)) - stagger_amount = 0 - row_base = arith.addi(sub_row_base, arith.muli(warp_id, c(16))) - col_base = arith.muli(arith.remui(lane_id, c(4)), c(2)) # {0, 2, 4, 6} - # The swizzle pattern is constant for a given thread. - col_swizzle_bits = arith.muli( - arith.divui(sub_row_base, c(128 // swizzle)), c(128 // bw), - ) - for row_group in range(m // 64): - for col_group in range(n // cols_per_tile): - for row_subidx in range(2): - row = arith.addi(row_base, c(row_subidx * 8)) - for col_subidx in range(cols_per_tile // 8): - col_subidx_left = col_subidx - col_subidx_right = col_subidx ^ stagger_amount - col_off = arith.select( - is_stagger_left, c(col_subidx_left * 8), c(col_subidx_right * 8) - ) - col = arith.addi(col_base, col_off) - col = arith.xori(col, col_swizzle_bits) - reg_idx_left = col_subidx_left + col_group * (cols_per_tile // 8) - reg_idx_right = col_subidx_right + col_group * (cols_per_tile // 8) - left_idx = row_group, reg_idx_left, row_subidx, 0 - right_idx = row_group, reg_idx_right, row_subidx, 0 - idx = c(row_group), c(col_group), row, col - def get_register(regs, left_idx=left_idx, right_idx=right_idx): - value_left = regs[left_idx] - value_right = regs[right_idx] - return arith.select(is_stagger_left, value_left, value_right) - def update_registers(regs, new, left_idx=left_idx, right_idx=right_idx): - regs[left_idx] = arith.select(is_stagger_left, new, regs[left_idx]) - regs[right_idx] = arith.select(is_stagger_left, regs[right_idx], new) - yield get_register, update_registers, idx + idx_gen = map(lambda x: [x], layout.linear_thread_idxs()) + for i, vec_idx in enumerate(idx_gen): + def update(registers, reg, _i=i): + registers[_i] = reg + def get(registers, _i=i): + return registers[_i] + yield get, update, ref, vec_idx @staticmethod - def transfer_tiled2( + def transfer_tiled( ref: ir.Value, swizzle: int | None, layout: TiledLayout, shape: tuple[int, ...], + optimized: bool = True, ): """Generate a transfer schedule for a tiled layout. @@ -2049,57 +3232,71 @@ def transfer_tiled2( ref_tiling_shape = tuple(ref_ty.shape[ref_logical_rank:]) ref_tiling = Tiling((ref_tiling_shape,)) ref_strides, _ = ref_ty.get_strides_and_offset() - if ref_tiling.untile_shape(tuple(ref_ty.shape)) != shape: - raise ValueError() + if (ref_logical_shape := ref_tiling.untile_shape(tuple(ref_ty.shape))) != shape: + raise ValueError( + f"The reference has untiled shape of {ref_logical_shape} while the" + f" register array has shape {shape}" + ) nested_ref_shape = tuple( (ref_ty.shape[i], ref_ty.shape[i + ref_logical_rank]) + if ref_ty.shape[i + ref_logical_rank] != 1 else (ref_ty.shape[i],) for i in range(ref_logical_rank) ) nested_ref_strides = tuple( (ref_strides[i], ref_strides[i + ref_logical_rank]) + if ref_ty.shape[i + ref_logical_rank] != 1 else (ref_strides[i],) for i in range(ref_logical_rank) ) tiled_nested_shape, tiled_nested_strides = tiling.tile_nested_shape_strides( nested_ref_shape, nested_ref_strides ) - - # We could technically handle this case, but it would be quite complicated. - # If tiling dimensions would have to be expanded into multiple, we'd have to - # adjust the dimension indices in layouts, including expanding some of them - # into multiple indices. Note that for non-tiling dims, we allow the shape - # to be arbitrary, which is why we fix it up below in mem_idx_to_reg_idx. - if any( - len(dim_shape) != 1 for dim_shape in tiled_nested_shape[-layout.tiled_tiling_rank :] - ): - raise NotImplementedError("Memory and register tiling incompatible") - tiled_shape = list(itertools.chain.from_iterable(tiled_nested_shape)) - elem_tiled_strides = list(itertools.chain.from_iterable(tiled_nested_strides)) - elem_lane_strides = [elem_tiled_strides[d] for d in layout.lane_dims] - lane_shape = [tiled_shape[d] for d in layout.lane_dims] - if elem_tiled_strides[layout.vector_dim] != 1: - raise ValueError("Stride of the vectorized dimension should be 1") - for d in (layout.warp_dim, *layout.lane_dims, layout.vector_dim): - tiled_shape[d] = 1 - - element_bits = mgpu.bitwidth(dtype) - if (layout.vector_length * element_bits) % 8 != 0: - raise ValueError( - f"Vector length ({layout.vector_length}) must be a multiple of bytes," - f" but has {layout.vector_length * element_bits} bits" - ) - transfer_bytes = (layout.vector_length * element_bits) // 8 # Not sure if this is strictly required for all data types, but it certainly # is for sub-byte types (else we might not increment the pointer by whole bytes). if any( - s % layout.vector_length and i != layout.vector_dim and d != 1 - for i, (s, d) in enumerate_negative( - list(zip(elem_tiled_strides, tiled_shape)) - ) + any(s % layout.vector_length and d != 1 for s, d in zip(ss, ds)) + for i, (ss, ds) in enumerate_negative(list(zip(tiled_nested_strides, tiled_nested_shape))) + if i != layout.vector_dim ): raise ValueError( "Tiled strides must be a multiple of the vector length, except for the" " vector dimension" ) + if tiled_nested_strides[layout.vector_dim] != (1,): + raise ValueError( + "Vectorized dimension should not require further tiling and have a" + " stride of 1" + ) + + tiles_shape = list(tiled_nested_shape) + tiles_strides = list(tiled_nested_strides) + for d in (*layout.partitioned_warp_dims, *layout.partitioned_lane_dims, layout.vector_dim): + # We could avoid repeating the singleton dimensions, but it simplifies the + # code below that computes the register index for a given tile. + tiles_shape[d] = (1,) * len(tiles_shape[d]) + tiles_strides[d] = (0,) * len(tiles_strides[d]) + tiles_shape = list(itertools.chain.from_iterable(tiles_shape)) + tiles_strides = list(itertools.chain.from_iterable(tiles_strides)) + warp_shape = list(itertools.chain.from_iterable( + (d.times,) if isinstance(d, Replicated) else tiled_nested_shape[d] for d in layout.warp_dims + )) + warp_strides = list(itertools.chain.from_iterable( + (0,) if isinstance(d, Replicated) else tiled_nested_strides[d] for d in layout.warp_dims + )) + lane_shape = list(itertools.chain.from_iterable( + (d.times,) if isinstance(d, Replicated) else tiled_nested_shape[d] for d in layout.lane_dims + )) + lane_strides = list(itertools.chain.from_iterable( + (0,) if isinstance(d, Replicated) else tiled_nested_strides[d] for d in layout.lane_dims + )) + vector_length = layout.vector_length + + element_bits = mgpu.bitwidth(dtype) + if (vector_length * element_bits) % 8 != 0: + raise ValueError( + f"Vector length ({vector_length}) must be a multiple of bytes," + f" but has {vector_length * element_bits} bits" + ) + transfer_bytes = (vector_length * element_bits) // 8 if swizzle not in {16, 32, 64, 128}: raise ValueError("Only swizzled transfers supported") @@ -2109,32 +3306,68 @@ def transfer_tiled2( swizzle_group_transfers = 128 // transfer_bytes swizzle_groups_per_block = swizzle // 16 swizzle_block_transfers = swizzle_groups_per_block * swizzle_group_transfers - # Technically we should keep the vector_dim set to 1, but its shape is 1 - # so it does not matter. - transfer_tiled_strides = [s // layout.vector_length for s in elem_tiled_strides] - transfer_dtype = ir.VectorType.get((layout.vector_length,), dtype) - - plan = plan_tiled_transfer( - tiled_shape, elem_tiled_strides, lane_shape, elem_lane_strides, layout, - element_bits, swizzle - ) + if isinstance(dtype, ir.FloatType) and element_bits <= 8: + narrow_int = ir.IntegerType.get_signless(element_bits) + transfer_dtype = ir.VectorType.get((vector_length,), narrow_int) + else: + transfer_dtype = ir.VectorType.get((vector_length,), dtype) - # All offsets are in units of transfer_dtype. + if ref_ty.memory_space is None: + llvm_memory_space = None + elif utils.is_smem_ref(ref_ty): + llvm_memory_space = 3 + else: + raise ValueError(f"Unsupported memory space: {ref_ty.memory_space}") + + if optimized: + if llvm_memory_space != 3: + raise NotImplementedError("Only optimized transfers to SMEM supported") + plan = plan_tiled_transfer( + tiles_shape, tiles_strides, + warp_shape, warp_strides, + lane_shape, lane_strides, + vector_length, element_bits, swizzle + ) + else: + plan = TrivialTransferPlan() + + tiles_strides_transfer = [s // vector_length for s in tiles_strides] + # Technically we should keep the vector_dim stride set to 1, but its shape + # is 1 so it does not matter. dyn_tiled_strides = [ - c(s) for s in transfer_tiled_strides[-layout.tiled_tiling_rank :] + c(s // vector_length) + for s in itertools.chain.from_iterable( + tiled_nested_strides[-layout.tiled_tiling_rank :] + ) ] - lane_offset = utils.dyn_dot(layout.lane_indices(), dyn_tiled_strides) - warp_offset = utils.dyn_dot(layout.warp_indices(), dyn_tiled_strides) + # This expands a tiled index into a finer-grained index that accounts for + # the fact that some tiled dims are tiled further in the nested shape. + def expand_nested_dims(idxs: Sequence[ir.Value]) -> list[ir.Value]: + assert len(idxs) == layout.tiled_tiling_rank + new_idxs = [] + for idx, dim_shape in zip(idxs, tiled_nested_shape[-layout.tiled_tiling_rank :]): + if dim_shape == (1,): + new_idxs.append(idx) + continue + dim_strides = utils.get_contiguous_strides(dim_shape) + for i, (size, stride) in enumerate(zip(dim_shape, dim_strides)): + new_idx = arith.divui(idx, c(stride)) + if i != 0: # No need to apply rem to the first dim. + new_idx = arith.remui(new_idx, c(size)) + new_idxs.append(new_idx) + assert len(new_idxs) == sum(map(len, tiled_nested_shape[-layout.tiled_tiling_rank :])) + return new_idxs + # All offsets are in units of transfer_dtype. + lane_offset = utils.dyn_dot(expand_nested_dims(layout.lane_indices()), dyn_tiled_strides) + warp_offset = utils.dyn_dot(expand_nested_dims(layout.warp_indices()), dyn_tiled_strides) dyn_offset = arith.addi(lane_offset, warp_offset) - if ref_ty.memory_space != ir.Attribute.parse("#gpu.address_space"): - raise ValueError("Tiled stores can be performed into SMEM") - ptr = utils.memref_ptr(ref, memory_space=3) + ptr = utils.memref_ptr(ref, memory_space=llvm_memory_space) _as_consts = lambda consts: [c(const) for const in consts.tolist()] # This has bits set only for the offset bits that influence swizzling. swizzle_mask = swizzle_block_transfers - swizzle_tile_transfers - for tile_idx in np.ndindex(*tiled_shape): + for tile_idx in np.ndindex(*tiles_shape): indices = np.asarray([f(tile_idx) for f in plan.tile_index_transforms]) - const_offset = np.dot(indices, transfer_tiled_strides) + const_offset = np.dot(indices, tiles_strides_transfer) # We split the offset into a part that interacts with swizzling and a # part that doesn't. This lets us generate better code because constant # offsets can be fused into load and store instructions. @@ -2160,16 +3393,16 @@ def transfer_tiled2( def mem_idx_to_reg_idx(idx): reg_tiled_idx = [] base_idx = 0 - for dim_shape in tiled_nested_shape[:ref_logical_rank]: + for dim_shape in tiled_nested_shape: dim_strides = utils.get_contiguous_strides(dim_shape) dim_idxs = idx[base_idx:base_idx + len(dim_shape)] base_idx += len(dim_shape) reg_tiled_idx.append(sum(i * s for i, s in zip(dim_idxs, dim_strides))) - # We should have fixed up all but the tiling dims. - assert base_idx == len(idx) - layout.tiled_tiling_rank - return (*reg_tiled_idx, *idx[base_idx:]) + return tuple(reg_tiled_idx) reg_idxs = [mem_idx_to_reg_idx(idx) for idx in indices.tolist()] def get_register(regs, reg_idxs=reg_idxs): + # f8 data types are not handled by the LLVM dialect, so we need to + # transfer them as i8 and bitcast them back to f8. return plan.select([regs[reg_idx] for reg_idx in reg_idxs]) def update_registers(regs, new, reg_idxs=reg_idxs): # TODO(apaszke): If the staggering forms a permutation with a small @@ -2181,7 +3414,15 @@ def update_registers(regs, new, reg_idxs=reg_idxs): # we could save half the selects). for i, reg_idx in enumerate(reg_idxs): regs[reg_idx] = plan.select_if_group(i, regs[reg_idx], new) - yield get_register, update_registers, reg_ptr + def get_base_index(): + if not isinstance(plan, TrivialTransferPlan): + raise NotImplementedError( + "Base index computation only supported for trivial transfer plans" + ) + if any(len(t) != 1 for t in tiled_nested_shape): + raise NotImplementedError("Tiling too complicated") + return tiling.untile_indices(indices.tolist()[0]) + yield get_register, update_registers, get_base_index, reg_ptr def tree_flatten(self): aux = self.layout, self.registers.shape, self.is_signed @@ -2194,8 +3435,10 @@ def tree_unflatten(cls, aux, flat_registers): return cls(_registers=registers, _layout=layout, _is_signed=is_signed) +IndexTransform: TypeAlias = Callable[[tuple[int, ...]], tuple[int, ...]] + + class TransferPlan(Protocol): - IndexTransform = Callable[[tuple[int, ...]], tuple[int, ...]] tile_index_transforms: tuple[IndexTransform, ...] def select(self, group_elems: Sequence[ir.Value]) -> ir.Value: @@ -2255,14 +3498,33 @@ def select_if_group(self, group_idx: int, old: ir.Value, new: ir.Value) -> ir.Va def plan_tiled_transfer( - tiled_shape: Sequence[int], - tiled_strides: Sequence[int], + tiles_shape: Sequence[int], + tiles_strides: Sequence[int], + warp_shape: Sequence[int], + warp_strides: Sequence[int], lane_shape: Sequence[int], lane_strides: Sequence[int], - layout: TiledLayout, + vector_length: int, element_bits: int, swizzle: int, ) -> TransferPlan: + """Plans the tiled transfer in a way that avoids SMEM bank conflicts. + + Note that while xyz_shape length should always match the length of + xyz_strides, we do not require the iteration spaces of tiles/warps/lanes to + have the same rank. + + Arguments: + tiles_shape: The nd-iteration space over tiles. + tiles_strides: The memory strides (in elements) for each tile dimension. + warp_shape: The nd-iteration space over warps in warpgroup. + warp_strides: The memory strides (in elements) for each warp dimension. + lane_shape: The nd-iteration space over lanes in a warp. + lane_strides: The memory strides (in elements) for each lane dimension. + vector_length: The length of a single transfer. + element_bits: Element bitwidth. + swizzle: The swizzle pattern length. + """ i32 = ir.IntegerType.get_signless(32) c = lambda x: arith.constant(i32, x) # TODO(apaszke): Rewrite this function in terms of transfer_bytes (that we get @@ -2270,26 +3532,30 @@ def plan_tiled_transfer( swizzle_tile_elems = (16 * 8) // element_bits swizzle_group_elems = (128 * 8) // element_bits # Should be checked at the call site. - assert layout.vector_length * element_bits % 8 == 0 - transfer_bytes = (layout.vector_length * element_bits) // 8 + assert vector_length * element_bits % 8 == 0 + transfer_bytes = (vector_length * element_bits) // 8 # Below, all calculations are in elements, not in bytes, since it should # generalize better to sub-byte types. # Here, we verify two conditions: # 1. Each vector transfer only accesses addresses that fall within a single # swizzle tile (if not we'd need to split it and swizzle parts differently). + chain = itertools.chain transfer_alignment = math.gcd(*( s - for i, (s, d) in enumerate_negative(list(zip(tiled_strides, tiled_shape))) - if d > 1 or i in {layout.warp_dim, *layout.lane_dims} + for (s, d) in zip( + chain(tiles_strides, warp_strides, lane_strides), + chain(tiles_shape, warp_shape, lane_shape), + ) + if d > 1 )) if ( swizzle_tile_elems % transfer_alignment - and layout.vector_length <= transfer_alignment + and vector_length <= transfer_alignment ): raise ValueError( "Failed to prove that vector transfers don't cross swizzle tile" " boundaries. This check is incomplete, and does not guarantee that" - " this is a user error, but it might be." + str(transfer_alignment) + f" this is a user error, but it might be. {transfer_alignment=}" ) # 2. The transfer pattern does not cause bank conflicts. @@ -2307,23 +3573,35 @@ def plan_tiled_transfer( num_wavefronts = max(transfer_bytes // smem_bank_bytes, 1) wavefront_lanes = WARP_SIZE // num_wavefronts + lane_mask = np.full(lane_shape, False) + lane_mask[tuple(slice(0, 1) if s == 0 else slice(None) for s in lane_strides)] = True + wavefront_mask = lane_mask.reshape(num_wavefronts, wavefront_lanes) + lane_offsets_in_tile = np.dot(list(np.ndindex(*lane_shape)), lane_strides) def has_bank_conflicts(tile_idx_transform): - tile_idxs = np.unravel_index(np.arange(math.prod(tiled_shape)), tiled_shape) + num_tiles = math.prod(tiles_shape) + tile_idxs = np.unravel_index(np.arange(num_tiles), tiles_shape) tile_idxs = np.expand_dims(np.stack(tile_idxs, 1), 1) # [#tiles, 1, #dims] lane_tile_idx = tile_idx_transform(tile_idxs) # [#tiles, #lanes/1, #dims] assert lane_tile_idx.shape[1] in {1, WARP_SIZE} - lane_tile_offsets = np.dot(lane_tile_idx, tiled_strides) + lane_tile_offsets = np.dot(lane_tile_idx, tiles_strides) offsets = lane_tile_offsets + lane_offsets_in_tile # [#tiles, #lanes] assert offsets.shape[-1] == WARP_SIZE swizzle_groups = (offsets // swizzle_group_elems) % (swizzle // 16) swizzle_bits = swizzle_groups * swizzle_tile_elems lane_banks = ((offsets ^ swizzle_bits) // elems_per_bank) % num_banks wavefront_banks = lane_banks.reshape(-1, num_wavefronts, wavefront_lanes) - # Order of threads within the wavefront is unimportant. - wavefront_banks = np.sort(wavefront_banks, axis=-1) - # There are no conflicts if each wavefront only contains unique banks. - return np.any(wavefront_banks[..., 1:] == wavefront_banks[..., :-1]) + # We step over wavefronts since they might have a different number of lanes. + wavefront_banks = wavefront_banks.swapaxes(0, 1) + for banks, mask in zip(wavefront_banks, wavefront_mask): + banks = banks[:, mask] + # Order of threads within the wavefront is unimportant. + banks = np.sort(banks, axis=-1) + # There are no conflicts if each wavefront only contains unique banks. + repeats = np.any(banks[..., 1:] == banks[..., :-1]) + if repeats: + return True + return False # We don't need any special treatment if there are no conflicts when each lane # transfers the same tile at a time. @@ -2339,7 +3617,7 @@ def has_bank_conflicts(tile_idx_transform): # the lanes into more groups, but the selects will become more expensive if # we do that. It's a possibility we have if we need it. candidate_dims = ( - i for i, (s, d) in enumerate(zip(tiled_strides, tiled_shape)) + i for i, (s, d) in enumerate(zip(tiles_strides, tiles_shape)) if d > 1 and s % (SMEM_BANKS * elems_per_bank) ) for dim in candidate_dims: @@ -2349,17 +3627,17 @@ def has_bank_conflicts(tile_idx_transform): lane_group = (lane_id // group_stride) % 2 # We only consider a transformation where the second group stores to a # tile that's a constant offset (modulo dim size) from the first one. - for stagger in range(1, tiled_shape[dim]): - offset = np.zeros(len(tiled_shape), np.int64) + for stagger in range(1, tiles_shape[dim]): + offset = np.zeros(len(tiles_shape), np.int64) offset[dim] = stagger - transform = lambda idx: (idx + offset * lane_group) % tiled_shape + transform = lambda idx: (idx + offset * lane_group) % tiles_shape if not has_bank_conflicts(transform): # We've found a strategy that avoids bank conflicts! lane_idx = arith.remui(utils.thread_idx(), c(WARP_SIZE)) group_idx = arith.remui(arith.divui(lane_idx, c(group_stride)), c(2)) group_pred = arith.cmpi(arith.CmpIPredicate.ne, group_idx, c(0)) return StaggeredTransferPlan( - stagger, dim, tiled_shape[dim], group_pred + stagger, dim, tiles_shape[dim], group_pred ) raise ValueError( "Failed to synthesize a transfer pattern that avoids bank conflicts" @@ -2377,97 +3655,127 @@ def mulf(a: ir.Value, b: ir.Value): return arith.mulf(a, b, fastmath=arith.FastMathFlags.contract) -def optimization_barrier(*arrays: mgpu.FragmentedArray): +@overload +def optimization_barrier( + a: mgpu.FragmentedArray, + b: mgpu.FragmentedArray, + /, + *arrays: mgpu.FragmentedArray, +) -> Sequence[mgpu.FragmentedArray]: + ... + + +@overload +def optimization_barrier(a: mgpu.FragmentedArray) -> mgpu.FragmentedArray: + ... + + +def optimization_barrier(*arrays): """Acts as an optimization barrier for LLVM. Passing arrays through this function will make sure that they are computed before any side-effecting operations that follow this barrier. """ - index = ir.IndexType.get() i32 = ir.IntegerType.get_signless(32) + def _repack(regs_it, reg_ty): + if not isinstance(reg_ty, ir.VectorType): + result_reg = next(regs_it) + assert result_reg.type == reg_ty + return result_reg + + num_i32_regs = utils.bitwidth(reg_ty) // 32 + i32_reg_ty = ir.VectorType.get((num_i32_regs,), i32) + reg = llvm.mlir_undef(i32_reg_ty) + for i_elem in range(num_i32_regs): + val = llvm.bitcast(i32, next(regs_it)) + reg = llvm.insertelement(reg, val, arith.constant(i32, i_elem)) + return vector.bitcast(reg_ty, reg) + regs = [] reg_dtypes = [] reg_constraints = [] - repack_fns = [] # We unpack each array into a flat list of registers, and prepare the # functions that invert the transform in repack_fns. for array in arrays: reg_ty = array.registers.flat[0].type dtype = array.mlir_dtype - if ir.F32Type.isinstance(dtype): - if ir.VectorType.isinstance(reg_ty): + if isinstance(dtype, ir.F32Type) or dtype == i32: + if isinstance(reg_ty, ir.VectorType): [vec_len] = ir.VectorType(reg_ty).shape array_regs = [ # pylint: disable=g-complex-comprehension - vector.extractelement(reg, position=c(pos, index)) + vector.extract( + reg, + dynamic_position=[], + static_position=ir.DenseI64ArrayAttr.get([pos]), + ) for reg in array.registers.flat for pos in range(vec_len) ] - def _repack(regs, reg_ty=reg_ty): - reg = llvm.mlir_undef(reg_ty) - [vec_len] = ir.VectorType(reg_ty).shape - for i_elem in range(vec_len): - reg = llvm.insertelement( - reg, next(regs), arith.constant(i32, i_elem) - ) - return reg - repack_fns.append(_repack) else: array_regs = list(array.registers.flat) - repack_fns.append(lambda regs: next(regs)) - reg_constraint = "f" - elif ir.BF16Type.isinstance(dtype) or ir.F16Type.isinstance(dtype): - if not ir.VectorType.isinstance(reg_ty): + reg_constraint = "r" if dtype == i32 else "f" + elif utils.bitwidth(dtype) < 32: + reg_packing = 4 // utils.bytewidth(dtype) + if not isinstance(reg_ty, ir.VectorType): raise NotImplementedError(array.mlir_dtype) [vec_len] = ir.VectorType(reg_ty).shape - if vec_len != 2: + if vec_len % reg_packing: raise NotImplementedError(vec_len) - i32_reg_ty = ir.VectorType.get((1,), i32) + num_i32_regs = vec_len // reg_packing + i32_reg_ty = ir.VectorType.get((num_i32_regs,), i32) array_regs = [ - vector.extractelement( - vector.bitcast(i32_reg_ty, reg), position=c(0, index) + vector.extract( + vector.bitcast(i32_reg_ty, reg), + dynamic_position=[], + static_position=ir.DenseI64ArrayAttr.get([i]), ) + for i in range(num_i32_regs) for reg in array.registers.flat ] reg_constraint = "r" - def _repack(regs, reg_ty=reg_ty, i32_reg_ty=i32_reg_ty): - return vector.bitcast(reg_ty, vector.splat(i32_reg_ty, next(regs))) - repack_fns.append(_repack) else: raise NotImplementedError(array.mlir_dtype) regs += array_regs reg_dtypes += [array_regs[0].type] * len(array_regs) reg_constraints += [reg_constraint] * len(array_regs) - ptx_lines = [ - f"mov.b32 ${i}, ${len(reg_constraints)+i}" - for i in range(len(reg_constraints)) - ] - ptx = ";\n\t".join(ptx_lines) + ";" + ptx = "" all_reg_constraints = ",".join( - [*("=" + c for c in reg_constraints), *reg_constraints] + [*("=" + c for c in reg_constraints), *map(str, range(len(reg_constraints)))] ) - struct_ty = ir.Type.parse( - f"!llvm.struct<({','.join(map(str, reg_dtypes))})>" - ) - result_struct = llvm.inline_asm( - struct_ty, regs, ptx, all_reg_constraints, - asm_dialect=0, has_side_effects=True, - ) - regs = [ - llvm.extractvalue(dtype, result_struct, [i]) - for i, dtype in enumerate(reg_dtypes) - ] + + if len(reg_dtypes) == 1: + # The InlineAsm::verify() function doesn't allow a struct output when there + # is only one element (even though that seems to work for the case below). + result_elem = llvm.inline_asm( + reg_dtypes[0], regs, ptx, all_reg_constraints, + asm_dialect=0, has_side_effects=True, + ) + regs = [result_elem] + else: + struct_ty = ir.Type.parse( + f"!llvm.struct<({','.join(map(str, reg_dtypes))})>" + ) + result_struct = llvm.inline_asm( + struct_ty, regs, ptx, all_reg_constraints, + asm_dialect=0, has_side_effects=True, + ) + regs = [ + llvm.extractvalue(dtype, result_struct, [i]) + for i, dtype in enumerate(reg_dtypes) + ] + i32 = ir.IntegerType.get_signless(32) results = [] regs_it = iter(regs) - for array, repack_fn in zip(arrays, repack_fns, strict=True): + for array in arrays: num_regs = array.registers.size reg_ty = array.registers.flat[0].type - if ir.VectorType.isinstance(reg_ty): + if isinstance(reg_ty, ir.VectorType): reg_ty = ir.VectorType(reg_ty) new_registers = np.empty((num_regs,), dtype=object) for i_vreg in range(num_regs): - reg = repack_fn(regs_it) + reg = _repack(regs_it, reg_ty) assert reg.type == reg_ty, (reg.type, reg_ty) new_registers[i_vreg] = reg results.append( @@ -2477,4 +3785,127 @@ def _repack(regs, reg_ty=reg_ty, i32_reg_ty=i32_reg_ty): _is_signed=array.is_signed, ) ) - return results[0] if len(arrays) == 1 else results + # pytype cannot type check the return type of an overloaded function. + return results[0] if len(arrays) == 1 else results # pytype: disable=bad-return-type + + +def tiled_copy_smem_gmem_layout( + row_tiles: int, col_tiles: int, swizzle: int, bitwidth: int +) -> TiledLayout: + swizzle_elems = 8 * swizzle // bitwidth + if row_tiles % 4 == 0: + warp_row_tiles, warp_col_tiles = 4, 1 + elif row_tiles % 2 == 0: + if col_tiles % 2: + raise NotImplementedError("Number of tiles is not a multiple of 4") + warp_row_tiles, warp_col_tiles = 2, 2 + else: + if col_tiles % 4: + raise NotImplementedError("Number of tiles is not a multiple of 4") + warp_row_tiles, warp_col_tiles = 1, 4 + row_tiles //= warp_row_tiles + col_tiles //= warp_col_tiles + bytes_per_thread = min(16, 8 * swizzle // WARP_SIZE) + lane_row_tiles = lane_col_tiles = 1 + if bytes_per_thread < 16: # Try to splread multiple tiles over a warp. + max_scale_up = 16 // bytes_per_thread + while max_scale_up > 1 and col_tiles % 2 == 0: + max_scale_up //= 2 + lane_col_tiles *= 2 + col_tiles //= 2 + while max_scale_up > 1 and row_tiles % 2 == 0: + max_scale_up //= 2 + lane_row_tiles *= 2 + row_tiles //= 2 + bytes_per_thread *= lane_row_tiles * lane_col_tiles + if 8 * bytes_per_thread < bitwidth: + raise NotImplementedError("Element types with bitwidth so large aren't supported") + vector_length = bytes_per_thread * 8 // bitwidth + assert swizzle_elems % vector_length == 0 + # How many steps of vector transfers are needed to transfer a single tile? + if vector_length * WARP_SIZE > 8 * swizzle_elems: + steps_per_tile = 1 + else: + steps_per_tile = 8 * swizzle_elems // (vector_length * WARP_SIZE) + tile_rows_per_step = 8 // steps_per_tile + # There are two cases to consider here: either a single transfer fits within + # a single tile (lane_row_tiles == lane_col_tiles == 1), which is the case + # for large swizzles, or it spans multiple tiles. The layout below ensures + # that consecutive lanes first traverse the columns within a tile, followed + # by rows within a tile, columns across tiles, and then rows across tiles. + # This ensures we never end up with bank conflicts, and yields well + # coalesced GMEM accesses. + return TiledLayout( + Tiling( + ( + (warp_row_tiles * lane_row_tiles * 8, warp_col_tiles * lane_col_tiles * swizzle_elems), + (lane_row_tiles * 8, lane_col_tiles * swizzle_elems), + (8, swizzle_elems), + (tile_rows_per_step, swizzle_elems), + (vector_length,) + ) + ), + warp_dims=(-9, -8), + lane_dims=(-7, -6, -3, -2), + vector_dim=-1, + _check_canonical=False, + ).canonicalize() + + +def copy_tiled(src: ir.Value, dst: ir.Value, swizzle: int = 16): + """Copy the data from the src reference to the dst reference. + + Exactly one of src/dst should be in SMEM, while the other should be in GMEM. + The SMEM reference is expected to be tiled into (8, swizzle_elems) (as it + would for MMA), and so should have a rank larger by 2 than the GMEM ref. + """ + src_ty = ir.MemRefType(src.type) + dst_ty = ir.MemRefType(dst.type) + if math.prod(src_ty.shape) != math.prod(dst_ty.shape): + raise ValueError( + "Source and destination must have the same number of elements, but got" + f" source shape {src_ty.shape} and destination shape {dst_ty.shape}" + ) + if src_ty.element_type != dst_ty.element_type: + raise ValueError( + "Source and destination must have the same element type, but got" + f" source type {src_ty.element_type} and destination type" + f" {dst_ty.element_type}" + ) + bitwidth = utils.bitwidth(src_ty.element_type) + # Signedness doesn't matter, but we need to specify something for the + # intermediate arrays. + is_signed = False if isinstance(src_ty.element_type, ir.IntegerType) else None + if utils.is_smem_ref(src_ty) != utils.is_smem_ref(dst_ty): + if utils.is_smem_ref(src_ty): + smem_ty, gmem_ty = src_ty, dst_ty + else: + smem_ty, gmem_ty = dst_ty, src_ty + if smem_ty.rank != gmem_ty.rank + 2: + raise ValueError( + "SMEM reference must have a rank larger by 2 than the destination" + f" reference (due to 2D tiling), but got SMEM rank {smem_ty.rank} and" + f" destination rank {gmem_ty.rank}." + ) + swizzle_elems = 8 * swizzle // bitwidth + if smem_ty.shape[-2:] != [8, swizzle_elems]: + raise NotImplementedError( + f"For {swizzle=}, expected SMEM tiling to be (8, {swizzle_elems})" + ) + expected_src_shape = utils.tile_shape(gmem_ty.shape, (8, swizzle_elems)) + if tuple(smem_ty.shape) != expected_src_shape: + raise ValueError( + f"Expected SMEM reference to have shape {expected_src_shape} (tiling" + f" {gmem_ty.shape} by (8, {swizzle_elems})), but got {smem_ty.shape}" + ) + layout = tiled_copy_smem_gmem_layout( + *smem_ty.shape[-4:-2], swizzle, bitwidth # type: ignore[call-arg] + ) + if utils.is_smem_ref(src_ty): + regs = FragmentedArray.load_tiled(src, swizzle, is_signed=is_signed, layout=layout) + regs.store_untiled(dst, optimized=False) + else: + regs = FragmentedArray.load_untiled(src, is_signed=is_signed, layout=layout, optimized=False) + regs.store_tiled(dst, swizzle) + return + raise NotImplementedError(f"Unsupported copy: {src.type} -> {dst.type}") diff --git a/jax/experimental/mosaic/gpu/inference_utils.py b/jax/experimental/mosaic/gpu/inference_utils.py index 6362626404c5..13bab74d77f2 100644 --- a/jax/experimental/mosaic/gpu/inference_utils.py +++ b/jax/experimental/mosaic/gpu/inference_utils.py @@ -14,15 +14,17 @@ """Layout & transform inference convenience utils.""" -from collections.abc import Callable, Sequence -import enum +from collections.abc import Sequence from functools import partial -import itertools -from typing import cast +from typing import cast, Union from jax._src.lib.mlir import ir -MlirOperation = ir.Operation | ir.OpView +from . import fragmented_array as fa +from . import tcgen05 +from . import utils + +MlirOperation = Union[ir.Operation, ir.OpView] def in_layouts(op: MlirOperation) -> Sequence[ir.Attribute]: """Returns the in_layouts attribute of the given operation. @@ -68,11 +70,70 @@ def out_transforms(op: MlirOperation) -> Sequence[ir.Attribute]: return op.attributes["out_transforms"] # type: ignore +def in_tmem_layouts(op: MlirOperation) -> Sequence[ir.Attribute]: + """Returns the in_tmem_layouts attribute of the given operation. + + Raises: + ValueError: If the operation does not have an in_tmem_layouts attribute. + """ + if "in_tmem_layouts" not in op.attributes: + raise ValueError(f"{op} does not have an in_tmem_layouts attribute.") + return op.attributes["in_tmem_layouts"] # type: ignore + + +def out_tmem_layouts(op: MlirOperation) -> Sequence[ir.Attribute]: + """Returns the out_tmem_layouts attribute of the given operation. + + Raises: + ValueError: If the operation does not have an out_tmem_layouts attribute. + """ + if "out_tmem_layouts" not in op.attributes: + raise ValueError(f"{op} does not have an out_tmem_layouts attribute.") + return op.attributes["out_tmem_layouts"] # type: ignore + + +def should_have_in_tmem_layout(op: MlirOperation) -> bool: + """Returns 'true' if the operation operands should be assigned a TMEM layout.""" + return any( + isinstance(v.type, ir.MemRefType) and utils.is_tmem_ref(v) + for v in op.operands + ) + + +def should_have_out_tmem_layout(op: MlirOperation) -> bool: + """Returns 'true' if the operation results should be assigned a TMEM layout.""" + return any( + isinstance(v.type, ir.MemRefType) and utils.is_tmem_ref(v) + for v in op.results + ) + + +def should_have_tmem_layout(op: MlirOperation) -> bool: + """Returns 'true' if the operation should be assigned a TMEM layout.""" + return should_have_in_tmem_layout(op) or should_have_out_tmem_layout(op) + + +def has_in_tmem_layouts_set(op: MlirOperation) -> bool: + return "in_tmem_layouts" in op.attributes + + +def has_out_tmem_layouts_set(op: MlirOperation) -> bool: + return "out_tmem_layouts" in op.attributes + + +def should_have_in_layout(op: MlirOperation) -> bool: + """Returns 'true' if the operation operands should be assigned a layout.""" + return any(isinstance(v.type, ir.VectorType) for v in op.operands) + + +def should_have_out_layout(op: MlirOperation) -> bool: + """Returns 'true' if the operation results should be assigned a layout.""" + return any(isinstance(v.type, ir.VectorType) for v in op.results) + + def should_have_layout(op: MlirOperation) -> bool: """Returns 'true' if the operation should be assigned a layout.""" - - is_array = lambda v: ir.VectorType.isinstance(v.type) - return any(map(is_array, itertools.chain(op.operands, op.results))) # type: ignore + return should_have_in_layout(op) or should_have_out_layout(op) def has_in_layouts_set(op: MlirOperation) -> bool: @@ -95,13 +156,29 @@ def has_out_transforms_set(op: MlirOperation) -> bool: return "out_transforms" in op.attributes +def attr_element( + attr_name: str, op: MlirOperation, index: int +) -> ir.Attribute | None: + """Returns `op.attributes[attr_name][index]` if it exists, otherwise None. + + If `op.attributes[attr_name]` exists, then `index` must be a valid index into + the attribute array. + """ + if attr_name not in op.attributes: + return None + attr = op.attributes[attr_name] + if not attr: + return None + return op.attributes[attr_name][index] # type: ignore + + def _in_attr_for_operand( op: MlirOperation, operand: ir.Value, attr_name: str, ) -> ir.Attribute | None: if attr_name == "in_layouts": - predicate = lambda v: ir.VectorType.isinstance(v.type) + predicate = lambda v: isinstance(v.type, ir.VectorType) elif attr_name == "in_transforms": predicate = is_transformable_smem_memref else: @@ -109,9 +186,7 @@ def _in_attr_for_operand( operand_number = [o for o in op.operands if predicate(o)].index(operand) - if attr_name not in op.attributes: - return None - return op.attributes[attr_name][operand_number] # type: ignore + return attr_element(attr_name, op, operand_number) in_layout_for_operand = partial( @@ -121,22 +196,36 @@ def _in_attr_for_operand( _in_attr_for_operand, attr_name="in_transforms" ) + +def should_have_in_transforms(op: ir.OpView) -> bool: + """Returns 'True' if the operation should be assigned in transforms.""" + return any(map(is_transformable_smem_memref, op.operands)) + + +def should_have_out_transforms(op: ir.OpView) -> bool: + """Returns 'True' if the operation should be assigned out transforms.""" + return any(map(is_transformable_smem_memref, op.results)) + + +def should_have_transforms(op: ir.OpView) -> bool: + """Returns 'True' if the operation should be assigned in/out transforms.""" + return should_have_in_transforms(op) or should_have_out_transforms(op) + + def is_transformable_smem_memref(v: ir.Value) -> bool: """Whether the value is a memref in SMEM on which transforms should be applied.""" barrier_ty = ir.Type.parse("!mosaic_gpu.barrier") - smem = ir.Attribute.parse("#gpu.address_space") return ( - ir.MemRefType.isinstance(v.type) + isinstance(v.type, ir.MemRefType) # barriers have no business being transformed and v.type.element_type != barrier_ty # pylint: disable=attribute-error - and v.type.memory_space is not None # pylint: disable=attribute-error - and v.type.memory_space == smem # pylint: disable=attribute-error + and utils.is_smem_ref(v) ) def _value_attr(value: ir.Value, attr_type: str) -> ir.Attribute | None: if attr_type == "layouts": - predicate = lambda v: ir.VectorType.isinstance(v.type) + predicate = lambda v: isinstance(v.type, ir.VectorType) elif attr_type == "transforms": predicate = is_transformable_smem_memref else: @@ -175,7 +264,7 @@ def value_layout(value: ir.Value) -> ir.Attribute | None: Raises: ValueError: If `result` is not a Vector. """ - if not ir.VectorType.isinstance(value.type): + if not isinstance(value.type, ir.VectorType): raise ValueError(f"{value} is not a vector.") return _value_attr(value, "layouts") @@ -187,31 +276,29 @@ def value_transforms(value: ir.Value) -> ir.Attribute | None: Raises: ValueError: If `result` is not a memref. """ - if not ir.MemRefType.isinstance(value.type): + if not isinstance(value.type, ir.MemRefType): raise ValueError(f"{value} is not a memref.") return _value_attr(value, "transforms") -class TraversalOrder(enum.Enum): - """Traversal orders with respect to the data flow for IR.""" - - FORWARD = 1 - BACKWARDS = 2 - - -def traverse_op( - op: ir.OpView, - callback: Callable[[ir.OpView], None], - traversal_order: TraversalOrder = TraversalOrder.FORWARD, -): - """Traverses the operation and applies the callback in the given order.""" - for region in op.operation.regions: - for block in region: - if traversal_order == TraversalOrder.FORWARD: - ops_to_traverse = list(block) - else: - ops_to_traverse = reversed(list(block)) # type: ignore - for block_op in ops_to_traverse: - traverse_op(block_op, callback, traversal_order) - callback(op) +def is_mma_layout(layout: fa.FragmentedLayout) -> bool: + if not isinstance(layout, fa.TiledLayout): + return False + if layout in { + fa.WGMMA_LAYOUT, + fa.WGMMA_LAYOUT_ACC_32BIT, + fa.WGMMA_LAYOUT_UPCAST_2X, + fa.WGMMA_LAYOUT_UPCAST_4X, + fa.WGMMA_TRANSPOSED_LAYOUT, + fa.WGMMA_LAYOUT_8BIT, + fa.TCGEN05_LAYOUT, + fa.TCGEN05_TRANSPOSED_LAYOUT, + }: + return True + if len(layout.tiling.tiles[0]) != 2: + return False + columns = layout.tiling.tiles[0][1] + return columns % 16 == 0 and ( + layout == tcgen05.fa_m64_collective_layout(columns) + ) diff --git a/jax/experimental/mosaic/gpu/launch_context.py b/jax/experimental/mosaic/gpu/launch_context.py index ce432f26dac2..407407f00fed 100644 --- a/jax/experimental/mosaic/gpu/launch_context.py +++ b/jax/experimental/mosaic/gpu/launch_context.py @@ -19,11 +19,12 @@ import enum import functools import math -from typing import Any +from typing import Any, Literal from jax._src.lib import mosaic_gpu_dialect as mgpu_dialect from jaxlib.mlir import ir from jaxlib.mlir.dialects import arith +from jaxlib.mlir.dialects import builtin from jaxlib.mlir.dialects import func from jaxlib.mlir.dialects import gpu from jaxlib.mlir.dialects import llvm @@ -31,15 +32,39 @@ from jaxlib.mlir.dialects import nvvm import numpy as np +from . import fragmented_array as fa from . import profiler from . import utils -# mypy: ignore-errors TMA_DESCRIPTOR_BYTES = 128 TMA_DESCRIPTOR_ALIGNMENT = 64 +TMAReductionOp = Literal[ + "add", + "min", + "max", + "inc", + "dec", + "and", + "or", + "xor", + "umin", + "umax", + "smin", + "smax", +] + +def _reduction_op_to_ptx(reduction_op: TMAReductionOp) -> str: + # convert [s|u]min|max to min|max + return reduction_op[-3:] c = utils.c # This is too common to fully qualify. +class GlobalBroadcast: + pass + +GLOBAL_BROADCAST = GlobalBroadcast() + + @dataclasses.dataclass(frozen=True) class MemRefTransform: def apply(self, ref: ir.Value) -> ir.Value: @@ -51,6 +76,9 @@ def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]: def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]: raise NotImplementedError("Subclasses should override this method") + def transform_strides(self, shape: Sequence[int]) -> tuple[int, ...]: + raise NotImplementedError("Subclasses should override this method") + def batch(self, leading_rank: int) -> 'MemRefTransform': """Returns a transform that accepts a ref with the extra `leading_rank` dims. @@ -147,6 +175,14 @@ def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]: *self.tiling, ) + def transform_strides(self, strides: Sequence[int]) -> tuple[int, ...]: + tiling_rank = len(self.tiling) + return ( + *strides[:-tiling_rank], + *(s * t for s, t in zip(strides[-tiling_rank:], self.tiling)), + *strides[-tiling_rank:], + ) + def batch(self, leading_rank: int) -> MemRefTransform: return self @@ -158,7 +194,7 @@ class TransposeTransform(MemRefTransform): def __post_init__(self): if len(self.permutation) != len(set(self.permutation)): - raise ValueError("Permutation must be a permutation") + raise ValueError("All elements of `permutation` must be unique") def apply(self, ref: ir.Value) -> ir.Value: return utils.memref_transpose(ref, self.permutation) @@ -169,6 +205,9 @@ def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]: def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]: return tuple(shape[p] for p in self.permutation) + def transform_strides(self, strides: Sequence[int]) -> tuple[int, ...]: + return tuple(strides[p] for p in self.permutation) + def batch(self, leading_rank: int) -> MemRefTransform: return TransposeTransform( (*range(leading_rank), *(d + leading_rank for d in self.permutation)) @@ -228,21 +267,272 @@ def batch(self, leading_rank: int) -> MemRefTransform: OnDeviceProfiler = profiler.OnDeviceProfiler +MOSAIC_GPU_SMEM_ALLOC_ATTR = "mosaic_gpu_smem_alloc" + +class Scratch: + """Manages ops handling the GMEM scratch that contains the TMA descriptors. + + TMA descriptors are created on the host and then copied to GMEM. So there + needs to be some code on the host to allocate and initialize the TMA + descriptors. However, we only know what descriptors we need after we have + lowered the entire kernel. This class helps manage everything needed to + correctly allocate and initialize the scratch. + + To help reconcile the needs of kernels that use the dialect lowering with + those that use MGPU APIs directly, this class only creates the relevant ops + lazily. Eager creation would make them appear dead before dialect lowering + and MLIR's DCE would remove them. + + During the lowering, we collect information about how many bytes are needed + and also how each descriptor should be initialized on the host. At the end + of the lowering, the finalize_size() method should be called to add the + necessary code on the host to allocate and initialize all descriptors. + + Here's how the IR looks after the initial ops are created for the first time: + + + %1 = llvm.alloc_op {elem_type = !llvm.array<0 x i8>} -> !llvm.ptr + %2 = llvm.load_op (%1) : (!llvm.ptr) -> !llvm.array<0 x i8> + ... + %3 = gpu.launch async + ^bb0: + %4 = builtin.unrealized_conversion_cast_op(%2) + : (!llvm.array<256 x i8>) -> !llvm.ptr + + + And here is an example of how the IR could look like after finalize_size() is + called: + + + %11 = llvm.alloc_op {elem_type = !llvm.array<256 x i8>} -> !llvm.ptr + %22 = llvm.load_op (%11) : (!llvm.ptr) -> !llvm.array<256 x i8> + ... + # Ops inserted to initialize the tma descriptors on the host: + ... + %33 = llvm.getelementptr %11[0] : (!llvm.ptr) -> !llvm.ptr, i8 + call @mosaic_gpu_init_tma_desc (%33, ...) + ... + %44 = llvm.getelementptr %11[128] : (!llvm.ptr) -> !llvm.ptr, i8 + call @mosaic_gpu_init_tma_desc (%44, ...) + ... + %55 = gpu.launch async + ^bb0: + %66 = builtin.unrealized_conversion_cast_op(%22) + : (!llvm.array<256 x i8>) -> !llvm.ptr + + """ + def __init__(self, gpu_launch_op: gpu.LaunchOp): + self.next_offset: int = 0 + self.host_init: list[Callable[[ir.Value], None]] = [] + self._ops_created = False + + # Ideally, we would store the gpu.launch op directly. However, it gets + # invalidated by passes like "canonicalize". Thus we store the module and + # find the gpu.launch op from there when needed. + op = gpu_launch_op + while op.name != "builtin.module": + op = op.parent.opview + assert op is not None + self._module_op = op + + def _find_first_op( + self, op_name: str, block: ir.Block, tag_attribute_name: str | None = None + ) -> ir.OpView | None: + for op in block: + if op.name == op_name and ( + tag_attribute_name is None or tag_attribute_name in op.attributes + ): + return op + for region in op.regions: + for block in region: + child_op = self._find_first_op(op_name, block, tag_attribute_name) + if child_op is not None: + return child_op + return None + + def _create_ops(self): + if self._ops_created: + return + self._ops_created = True + + gpu_launch_op = self._find_first_op("gpu.launch", self._module_op.body) + assert gpu_launch_op is not None + + ptr_ty = ir.Type.parse("!llvm.ptr") + empty_arr_ty = ir.Type.parse("!llvm.array<0 x i8>") + i64 = ir.IntegerType.get_signless(64) + + with ir.InsertionPoint(gpu_launch_op): + alloc_op = llvm.AllocaOp( + ptr_ty, c(1, i64), empty_arr_ty, + alignment=TMA_DESCRIPTOR_ALIGNMENT + ) + # Tag the alloc op with an attribute so that we can find it later. + alloc_op.attributes[MOSAIC_GPU_SMEM_ALLOC_ATTR] = ir.UnitAttr.get() + load_op = llvm.LoadOp(empty_arr_ty, alloc_op) + + with ir.InsertionPoint.at_block_begin(gpu_launch_op.body.blocks[0]): + builtin.unrealized_conversion_cast([ptr_ty], [load_op]) + + def _find_alloc_load_and_device_ptr( + self, + ) -> tuple[llvm.AllocaOp, llvm.LoadOp, ir.Value]: + if not self._ops_created: + self._create_ops() + + alloc_op = self._find_first_op( + "llvm.alloca", self._module_op.body, MOSAIC_GPU_SMEM_ALLOC_ATTR + ) + assert alloc_op is not None + [alloc_user] = alloc_op.result.uses + load_op = alloc_user.owner + assert load_op.operation.name == "llvm.load" + [load_op_user] = load_op.result.uses + device_ptr = load_op_user.owner + assert device_ptr.operation.name == "builtin.unrealized_conversion_cast" + return alloc_op, load_op, device_ptr.result + + def device_ptr(self) -> ir.Value: + _, _, device_ptr = self._find_alloc_load_and_device_ptr() + return device_ptr + + def finalize_size(self): + """ + Allocates and initializes the host buffer. This needs to be done after + lowering, i.e. after all TMA descriptors have been recorded. Only then we + know what the scratch contains. + """ + if self.next_offset == 0: + return + alloc_op, load_op, _ = self._find_alloc_load_and_device_ptr() + + with ir.InsertionPoint(load_op): + gmem_scratch_bytes = self.next_offset + scratch_arr_ty = ir.Type.parse(f"!llvm.array<{gmem_scratch_bytes} x i8>") + alloc_op.elem_type = ir.TypeAttr.get(scratch_arr_ty) + load_op.result.set_type(scratch_arr_ty) + for init_callback in self.host_init: + init_callback(alloc_op.result) + + +class _DefaultPredicate: + pass + + +def _find_kernel_argument_for_gmem_ref( + gmem_ref: ir.Value, +) -> builtin.UnrealizedConversionCastOp: + """Returns the kernel argument value for a given gmem_ref. + + The kernel argument is expected to be an unrealized conversion cast. This + function will recursively go up block arguments in case of nested blocks. + """ + if not isinstance(gmem_ref.type, ir.MemRefType): + raise ValueError(f"Expected {gmem_ref} to have a memref type.") + + while isinstance(gmem_ref, ir.BlockArgument): + gmem_ref = gmem_ref.owner.owner.operands[gmem_ref.arg_number] + + # TODO(apaszke): This is a very approximate check. Improve it! + if not isinstance(gmem_ref.owner.opview, builtin.UnrealizedConversionCastOp): + raise NotImplementedError( + f"Expected {gmem_ref.owner} to be an unrealized conversion cast" + " corresponding to a GMEM kernel argument." + ) + return gmem_ref + + +def _is_tma_reduction_op_supported( + reduction_op: TMAReductionOp | None, dtype: ir.Type, +) -> bool: + """Returns whether the given TMA reduction op supports the given dtype. + + This function essentially implements the table at: + https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-reduce-async-bulk-tensor + with the following differences: + - For `add` reductions, we also support int64, treating it as uint64. + - For `and`, `or`, and `xor` reductions, we support signed integer types. + - For `inc` and `dec` reductions, we support both signed and unsigned i32 + treating both as unsigned. + """ + i32 = ir.IntegerType.get_signless(32) + i64 = ir.IntegerType.get_signless(64) + f16 = ir.F16Type.get() + f32 = ir.F32Type.get() + bf16 = ir.BF16Type.get() + + match reduction_op: + case None: + return True + case "add": + return dtype in (f16, f32, bf16, i32, i64) + case "max" | "min": + return dtype in (f16, bf16) + case "umax" | "umin" | "smax" | "smin": + return dtype in (i32, i64) + case "inc" | "dec": + return dtype == i32 + case "and" | "or" | "xor": + return dtype in (i32, i64) + + +def _tma_dma_type( + element_type: ir.Type, + reduction_op: TMAReductionOp | None, +) -> int: + """Returns the TMA DMA type for the given element type and signedness.""" + if isinstance(element_type, ir.IntegerType): + bitwidth = utils.bitwidth_impl(element_type) + if bitwidth == 2: + tma_dtype = 8 + elif bitwidth == 4: + tma_dtype = 0 + elif bitwidth == 8: + tma_dtype = 1 + elif bitwidth == 16: + tma_dtype = 2 + elif bitwidth == 32: + tma_dtype = 9 if reduction_op in ("smin", "smax") else 3 + elif bitwidth == 64: + tma_dtype = 10 if reduction_op in ("smin", "smax") else 4 + else: + raise ValueError(f"Unsupported integer bitwidth: {bitwidth}") + elif isinstance(element_type, ir.F16Type): + tma_dtype = 5 + elif isinstance(element_type, ir.F32Type): + tma_dtype = 6 + elif isinstance(element_type, ir.BF16Type): + tma_dtype = 7 + # We treat narrow floats as integers + elif isinstance(element_type, ir.Float8E5M2Type): + tma_dtype = 1 + elif isinstance(element_type, ir.Float8E4M3FNType): + tma_dtype = 1 + elif isinstance(element_type, ir.Float8E8M0FNUType): + tma_dtype = 1 + elif isinstance(element_type, ir.Float4E2M1FNType): + tma_dtype = 0 + else: + raise ValueError(f"unsupported TMA dtype {element_type}") + return tma_dtype + + +class AsyncCopyImplementation(enum.Enum): + TMA = enum.auto() + CP_ASYNC = enum.auto() + @dataclasses.dataclass() class LaunchContext: - launch_op: gpu.LaunchOp - gmem_scratch_ptr: ir.Value + module: ir.Module + scratch: Scratch cluster_size: tuple[int, int, int] profiler: OnDeviceProfiler | None = None - next_scratch_offset: int = 0 - host_scratch_init: list[Callable[[ir.Value], None]] = dataclasses.field( - default_factory=list, init=False - ) tma_descriptors: dict[ - tuple[ir.Value, tuple[int, ...], int | None, tuple[MemRefTransform, ...]], + tuple[ir.Value, tuple[int, ...], int | None, tuple[MemRefTransform, ...], Any, int], ir.Value, ] = dataclasses.field(default_factory=dict, init=False) + is_device_collective: bool = False @contextlib.contextmanager def named_region(self, *args, **kwargs): @@ -253,22 +543,60 @@ def named_region(self, *args, **kwargs): yield def cluster_idx( - self, dim: gpu.Dimension | Sequence[gpu.Dimension] | None = None + self, + dim: gpu.Dimension | Sequence[gpu.Dimension] | None = None, + dim_idx: ir.Value | Sequence[ir.Value] | None = None, ) -> ir.Value: - """Returns the index of a block within a subset of the cluster spanned by the given dimensions.""" + """Returns the linear index of a block within a subset of the cluster spanned by the given dimensions. + + dim_idx can be used to specify the index of another block along the selected + dimensions. If not provided, the current block's index is used. + """ if dim is None: dim = gpu.Dimension elif isinstance(dim, gpu.Dimension): dim = (dim,) + if dim_idx is None: + dim_idx = [gpu.cluster_block_id(d) for d in dim] + elif isinstance(dim_idx, ir.Value): + if len(dim) != 1: + raise ValueError( + "Expected a single dimension when passing a single index" + ) + dim_idx = [dim_idx] index = ir.IndexType.get() stride = 1 - idx = c(0, index) - for d in sorted(dim): + lin_idx = c(0, index) + for d, idx in sorted(zip(dim, dim_idx, strict=True), key=lambda x: x[0]): if self.cluster_size[d] == 1: # Optimize a multiply by 0. continue - idx = arith.addi(idx, arith.muli(gpu.cluster_block_id(d), c(stride, index))) + lin_idx = arith.addi(lin_idx, arith.muli(idx, c(stride, index))) stride *= self.cluster_size[d] - return idx + return lin_idx + + def get_cluster_ref(self, ref: ir.Value, dim: gpu.Dimension, idx: ir.Value): + i32 = ir.IntegerType.get_signless(32) + # We replace the offset in the ref type by 0, because memref_ptr always + # folds the offset into the pointer. + ref_ty = ir.MemRefType(ref.type) + strides, _ = ref_ty.get_strides_and_offset() + result_type = ir.MemRefType.get( + ref_ty.shape, + ref_ty.element_type, + ir.StridedLayoutAttr.get(0, strides), + None, + ) + if ref_ty.memory_space != ir.Attribute.parse("#gpu.address_space"): + raise ValueError(f"Expected SMEM but got: {ref.memory_space}") + idxs = [gpu.cluster_block_id(d) for d in gpu.Dimension] + idxs[dim] = idx + flat_block = arith.index_cast(i32, self.cluster_idx(gpu.Dimension, idxs)) # type: ignore + return utils.ptr_as_memref( + utils.get_cluster_ptr( + utils.memref_ptr(ref, memory_space=3), flat_block + ), + result_type, + ) def _alloc_scratch( self, @@ -286,32 +614,40 @@ def _alloc_scratch( ptr_ty = ir.Type.parse("!llvm.ptr") if alignment is None: alignment = size - if self.next_scratch_offset % alignment: + if self.scratch.next_offset % alignment: raise NotImplementedError # TODO(apaszke): Pad to match alignment - alloc_base = self.next_scratch_offset - self.next_scratch_offset += size + alloc_base = self.scratch.next_offset + self.scratch.next_offset += size def host_init_wrapped(host_ptr): host_init( - llvm.getelementptr(ptr_ty, host_ptr, [], [alloc_base], i8) + llvm.getelementptr(ptr_ty, host_ptr, [], [alloc_base], i8, llvm.GEPNoWrapFlags.none) ) - self.host_scratch_init.append(host_init_wrapped) + self.scratch.host_init.append(host_init_wrapped) # with ir.InsertionPoint(self.gmem_scratch_ptr.owner): # There is no way to create an insertion point after an operation... gep = llvm.GEPOp( - ptr_ty, self.gmem_scratch_ptr, [], [alloc_base], i8 + ptr_ty, self.scratch.device_ptr(), [], [alloc_base], i8, llvm.GEPNoWrapFlags.none ) - gep.move_after(self.gmem_scratch_ptr.owner) + gep.move_after(self.scratch.device_ptr().owner) return device_init(gep.result) def _get_tma_desc( self, - gmem_ref, + gmem_ref: ir.Value, gmem_transform: tuple[MemRefTransform, ...], + gmem_peer_id: int | ir.Value | GlobalBroadcast | None, transformed_slice_shape: tuple[int, ...], swizzle: int | None, + reduction_op: TMAReductionOp | None, ): - tma_desc_key = (gmem_ref, transformed_slice_shape, swizzle, gmem_transform) + gmem_ref = _find_kernel_argument_for_gmem_ref(gmem_ref) + tma_dtype = _tma_dma_type(ir.MemRefType(gmem_ref.type).element_type, reduction_op) + # Using ir.Values in cache keys is a little sketchy, but I think it should + # be fine. Having it in the key will keep it alive, and if comparison and + # hashing is by identity then it should work out. + tma_desc_key = (gmem_ref, transformed_slice_shape, swizzle, gmem_transform, gmem_peer_id, tma_dtype) if (tma_desc := self.tma_descriptors.get(tma_desc_key, None)) is None: + i32 = ir.IntegerType.get_signless(32) i64 = ir.IntegerType.get_signless(64) ptr_ty = ir.Type.parse("!llvm.ptr") def init_tma_desc(host_ptr): @@ -320,14 +656,51 @@ def init_tma_desc(host_ptr): ref = t.apply(ref) ref_ty = ir.MemRefType(ref.type) # TODO(apaszke): Use utils.memref_ptr to compute base_ptr + strides, _ = ref_ty.get_strides_and_offset() + if strides[-1] != 1: + raise ValueError( + "TMA requires the stride of the last dimension after" + " transforming the GMEM reference to be 1, but it is" + f" {strides[-1]}." + ) + _, offset, *sizes_and_strides = memref.extract_strided_metadata(ref) aligned_ptr_idx = memref.extract_aligned_pointer_as_index(ref) as_i64 = lambda i: arith.index_cast(i64, i) alloc_ptr = llvm.inttoptr(ptr_ty, as_i64(aligned_ptr_idx)) llvm_dyn = -2147483648 # TODO(apaszke): Improve the MLIR bindings... base_ptr = llvm.getelementptr( - ptr_ty, alloc_ptr, [as_i64(offset)], [llvm_dyn], ref_ty.element_type, + ptr_ty, alloc_ptr, [as_i64(offset)], [llvm_dyn], ref_ty.element_type, llvm.GEPNoWrapFlags.none, ) + if isinstance(gmem_peer_id, GlobalBroadcast): + self._ensure_nvshmem_decls() + world_team = arith.constant(i32, 0) + base_ptr = llvm.call( + base_ptr.type, + [world_team, base_ptr], + [], + [], + callee="nvshmemx_mc_ptr", + ) + elif gmem_peer_id is not None: + if not isinstance(gmem_peer_id, ir.Value): + peer_id = c(gmem_peer_id, i32) + else: + try: + # We try to reproduce the gmem_peer_id computation on the host. + peer_id = _recompute_peer_id(gmem_peer_id, fuel=16) + except ReplicationError as e: + raise ValueError( + "Failed to recompute the async_copy peer id on the host" + ) from e + self._ensure_nvshmem_decls() + base_ptr = llvm.call( + base_ptr.type, + [base_ptr, peer_id], + [], + [], + callee="nvshmem_ptr", + ) rank = ref_ty.rank assert rank * 2 == len(sizes_and_strides) swizzle_arg = ( @@ -337,10 +710,11 @@ def init_tma_desc(host_ptr): ) # TODO(apaszke): Better verification (e.g. slice is non-zero) # TODO(apaszke): We always know strides statically. + dtype_or_bitwidth = c(tma_dtype, i64) args = [ host_ptr, base_ptr, - c(utils.bitwidth(ref_ty.element_type), i64), + dtype_or_bitwidth, c(rank, i64), utils.pack_array([as_i64(i) for i in sizes_and_strides[:rank]]), utils.pack_array([as_i64(i) for i in sizes_and_strides[rank:]]), @@ -361,118 +735,80 @@ def cast_tma_desc(device_ptr): self.tma_descriptors[tma_desc_key] = tma_desc return tma_desc - def async_copy( + def _prepare_async_copy( self, - *, - src_ref, - dst_ref, - gmem_slice: Any = (), - gmem_transform: MemRefTransform | tuple[MemRefTransform, ...] = (), - barrier: utils.BarrierRef | None = None, - swizzle: int | None = None, - arrive: bool | None = None, - uniform: bool = True, - collective: Sequence[gpu.Dimension] | gpu.Dimension | None = None, - partitioned: int | None = None, - predicate: ir.Value | None = None, # Should select 0 or 1 threads from the WG. + gmem_ref: ir.Value, + gmem_slice: Any, + gmem_transform: tuple[MemRefTransform, ...], + collective: Sequence[gpu.Dimension] | None, + partitioned: int | None, + implementation: AsyncCopyImplementation, ): - """Initiates an async copy between GMEM and SMEM. - - Exactly one of `src_ref` and `dst_ref` must be in GMEM and in SMEM, and the - SMEM reference must be contiguous. The GMEM window that is read or written - to is specified by the `gmem_slice`. The copy can change the order in which - the data appears in the window by applying a sequence of transforms to the - GMEM reference (as specified by `gmem_transform`). - - When `collective` is specified (only allowed for GMEM -> SMEM copies), the - identical async_copy must be scheduled by all blocks that share the same - coordinates along collective dimensions within a cluster. The behavior is - undefined otherwise. The semantics of collective loads depend further on the - `partitioned` argument: - - - If `partitioned` is not specified, all blocks load the same data into - their shared memory and all receive the update in their barriers, unless - `arrive` is False. If `arrive` is False, you should expect the barrier to - have expect_tx incremented by the same amount of bytes as if `collective` - was not specified. - - If `partitioned` is specified, each block only loads a separate slice of - the data into SMEM, partitioned into equal tiles along the `partitioned` - dimension. In this case only the barrier of the first block in the - collective will have its expect_tx incremented by the total size of the - transfer across all blocks involved in the collective. Barriers supplied - by other blocks will be ignored (even if `arrive` is True). - """ + """Performs setup common to TMA and CP_ASYNC implementations.""" index = ir.IndexType.get() - i16 = ir.IntegerType.get_signless(16) - i32 = ir.IntegerType.get_signless(32) - smem = ir.Attribute.parse("#gpu.address_space") - src_ref_ty = ir.MemRefType(src_ref.type) - dst_ref_ty = ir.MemRefType(dst_ref.type) - element_type = src_ref_ty.element_type - element_bitwidth = utils.bitwidth(element_type) - if element_type != dst_ref_ty.element_type: - raise ValueError( - f"Expected same element type, got {element_type} and" - f" {dst_ref_ty.element_type}" - ) - if predicate is not None and not uniform: - raise ValueError("Predicate can only be defined when uniform is True") - if not isinstance(gmem_transform, tuple): - gmem_transform = (gmem_transform,) - if src_ref_ty.memory_space is None and dst_ref_ty.memory_space == smem: - gmem_ref, smem_ref = src_ref, dst_ref - if barrier is None: - raise ValueError("Barriers are required for GMEM -> SMEM copies") - if arrive is None: - arrive = True # Arrive by default - elif src_ref_ty.memory_space == smem and dst_ref_ty.memory_space is None: - gmem_ref, smem_ref = dst_ref, src_ref - if barrier is not None: - raise ValueError("Barriers are unsupported for SMEM -> GMEM copies") - if arrive is None: - arrive = True # Commit this copy to the async group by default - else: - raise ValueError("Only SMEM <-> GMEM copies supported") - # TODO(apaszke): This is a very approximate check. Improve it! - expected_name = "builtin.unrealized_conversion_cast" - if ( - gmem_ref.owner is None - or gmem_ref.owner.opview.OPERATION_NAME != expected_name - ): - raise ValueError("GMEM reference in async_copy must be a kernel argument") gmem_ref_ty = ir.MemRefType(gmem_ref.type) gmem_strides, _ = gmem_ref_ty.get_strides_and_offset() if gmem_strides != utils.get_contiguous_strides(gmem_ref_ty.shape): raise NotImplementedError( "async_copy assumes the GMEM reference is contiguous" ) - if any(s * element_bitwidth % 128 != 0 for s in gmem_strides[:-1]): - raise ValueError( - "async_copy requires all GMEM strides except the last one to be a" - " multiple of 16 bytes" - ) - # NOTE: TMA supports OOB indices, so we skip the check. + # Look for and verify gather indices in gmem_slice. + is_gathered_dim = [isinstance(s, fa.FragmentedArray) for s in gmem_slice] + gather_indices: fa.FragmentedArray | None = None + if any(is_gathered_dim): + if is_gathered_dim != [True, False]: + raise NotImplementedError( + "Gathers/scatters only supported along the first dimension of 2D" + " arrays" + ) + gather_indices = gmem_slice[0] + if not isinstance(gather_indices, fa.FragmentedArray): + raise ValueError("Gather/scatter indices must be a FragmentedArray") + if len(gather_indices.shape) != 1: + raise ValueError("Gather/scatter indices must be 1D") + idx_dtype = gather_indices.mlir_dtype + if ( + not isinstance(idx_dtype, ir.IntegerType) + or utils.bitwidth(idx_dtype) > 32 + ): + raise ValueError("Gather/scatter indices must be integers that are at most 32-bit wide") + if gather_indices.is_signed: + raise ValueError("Gather/scatter indices must be unsigned") + gmem_slice = (slice(None), *gmem_slice[1:]) + + # Analyze the slice (taking gathers into account). base_indices, slice_shape, is_squeezed = utils.parse_indices( - gmem_slice, ir.MemRefType(gmem_ref.type).shape, check_oob=False + gmem_slice, + ir.MemRefType(gmem_ref.type).shape, + # NOTE: TMA supports OOB indices, so we skip the check. + check_oob=implementation != AsyncCopyImplementation.TMA, ) + if gather_indices is not None: + slice_shape = [gather_indices.shape[0], *slice_shape[1:]] + del gmem_slice # Use slice_shape, base_indices and is_squeezed from now on! dyn_base_indices = tuple( c(i, index) if not isinstance(i, ir.Value) else i for i in base_indices ) del base_indices # Use the dynamic indices from now on! - collective_size = 1 - if collective is not None: - if isinstance(collective, gpu.Dimension): - collective = (collective,) - collective_size = math.prod(self.cluster_size[d] for d in collective) - if gmem_ref is dst_ref: - raise ValueError("Only GMEM -> SMEM copies can be collective") + # Deal with collective and partitioned loads. + if collective: + if implementation != AsyncCopyImplementation.TMA: + raise ValueError("Only the TMA implementation supports collective copies") + if gather_indices is not None: + raise NotImplementedError("Collective copies with gather/scatter unsupported") if partitioned is not None: - if collective is None: + # Increment partitioned by the number of preceding squeezed dimensions. + partitioned = np.where( + np.cumsum(~np.array(is_squeezed)) == partitioned+1)[0][0] + # Partitioning happens on the logical slice we extract from GMEM, so we do + # it before we apply transforms. + if not collective: # This implies non-gather TMA already. raise ValueError("Only collective loads can be partitioned") - if collective_size > 1 and partitioned is not None: + collective_size = math.prod(self.cluster_size[d] for d in collective) + if collective_size > 1: if math.prod(self.cluster_size) != 2: raise NotImplementedError( "Partitioned loads only supported for clusters of size 2" @@ -484,8 +820,8 @@ def async_copy( f" {slice_shape[partitioned]}" ) slice_shape[partitioned] //= collective_size - dyn_base_indices = list(dyn_base_indices) - dyn_base_indices[partitioned] = arith.addi( + dyn_base_indices = list(dyn_base_indices) # type: ignore[assignment] + dyn_base_indices[partitioned] = arith.addi( # type: ignore[index] dyn_base_indices[partitioned], arith.muli( self.cluster_idx(collective), c(slice_shape[partitioned], index) @@ -493,15 +829,17 @@ def async_copy( ) dyn_base_indices = tuple(dyn_base_indices) - squeezed_dims = [i for i, squeezed in enumerate(is_squeezed) if squeezed] - sliced_dims = [i for i, squeezed in enumerate(is_squeezed) if not squeezed] + squeezed_dims = tuple( + i for i, squeezed in enumerate(is_squeezed) if squeezed + ) # Indexing is really slicing + squeezing, and user transforms are meant to # apply after that. However, we actually have to apply the indexing last # (it's fused into the TMA) and so we need to commute it with all the user # transforms. For slicing this is done using transform_index and # transform_shape. For squeezing we actually move all the squeezed dims to # the front, and then batch each transform, making it ignore the extra dims. - if squeezed_dims: + if squeezed_dims and implementation != AsyncCopyImplementation.CP_ASYNC: + sliced_dims = [i for i, squeezed in enumerate(is_squeezed) if not squeezed] gmem_transform = (TransposeTransform((*squeezed_dims, *sliced_dims)), *(t.batch(len(squeezed_dims)) for t in gmem_transform)) @@ -510,64 +848,92 @@ def async_copy( dyn_base_indices = t.transform_index(dyn_base_indices) slice_shape = t.transform_shape(slice_shape) + return ( + list(slice_shape), + dyn_base_indices, + squeezed_dims, + gather_indices, + gmem_transform, + ) + + def _prepare_tma( + self, + gmem_ref: ir.Value, + smem_ref: ir.Value | None, + swizzle: int | None, + slice_shape: list[int], + dyn_base_indices: tuple[ir.Value, ...], + gather_indices, + squeezed_dims: tuple[int, ...], + gmem_transform: tuple[MemRefTransform, ...], + collective: Sequence[gpu.Dimension], + partitioned: int | None, + ): + """Finalizes setup specific to the TMA implementation of async_copy.""" + index = ir.IndexType.get() + # The function below is called only to verify the GMEM ref. The output + # is meant to be ignored. + _find_kernel_argument_for_gmem_ref(gmem_ref) + gmem_ref_ty = ir.MemRefType(gmem_ref.type) + element_bitwidth = utils.bitwidth(gmem_ref_ty.element_type) + gmem_strides, _ = gmem_ref_ty.get_strides_and_offset() + if any(s * element_bitwidth % 128 != 0 for s in gmem_strides[:-1]): + raise ValueError( + "async_copy requires all GMEM strides except the last one to be a" + " multiple of 16 bytes" + ) + # We don't need to do this for gather TMAs, because we'll unroll the + # transfers ourselves anyway. num_squeezed_dims = len(squeezed_dims) - if len(slice_shape) > 5: + if len(slice_shape) > 5 and gather_indices is None: # We can try to collapse all squeezed dims into one. if len(slice_shape) - num_squeezed_dims + 1 > 5: raise ValueError( "Async copies only support striding up to 5 dimensions" ) - collapse = CollapseLeadingIndicesTransform( - tuple(gmem_strides[d] for d in squeezed_dims) - ) + squeezed_dim_strides = tuple(gmem_strides[d] for d in squeezed_dims) + collapse = CollapseLeadingIndicesTransform(squeezed_dim_strides) gmem_transform = (*gmem_transform, collapse) dyn_base_indices = collapse.transform_index(dyn_base_indices) - slice_shape = collapse.transform_shape(slice_shape) + slice_shape = list(collapse.transform_shape(tuple(slice_shape))) num_squeezed_dims = 1 - del squeezed_dims, sliced_dims # Those no longer make sense. - - smem_ref_ty = ir.MemRefType(smem_ref.type) - # We moved all squeezed dims to the front. - if slice_shape[num_squeezed_dims:] != tuple(smem_ref_ty.shape): - raise ValueError( - "Expected the SMEM reference to have the same shape as the" - f" transformed slice: {tuple(smem_ref_ty.shape)} != {slice_shape}" - ) - smem_strides, _ = smem_ref_ty.get_strides_and_offset() - if any( - s != cs and d != 1 # Strides don't matter for dims of size 1. - for s, cs, d in zip( - smem_strides, - utils.get_contiguous_strides(smem_ref_ty.shape), - smem_ref_ty.shape, - ) - ): - raise ValueError( - "async_copy needs the SMEM reference to be contiguous, but got" - f" strides {smem_strides} for shape {smem_ref_ty.shape}" - ) dyn_base_indices = list(dyn_base_indices) slice_shape = list(slice_shape) assert all(d == 1 for d in slice_shape[:num_squeezed_dims]) # Partitioned loads have already been processed (before transforms). + # We process non-partitioned collective loads here, because only here are we + # able to know in what order the data will be written to SMEM. Transposes + # and tiling change that order and if we picked a partition based on the + # untransformed slice shape, we might have ended up with a non-contiguous + # SMEM window, which would no longer be realizable in a single TMA transfer. + collective_size = math.prod(self.cluster_size[d] for d in collective) # type: ignore if collective_size > 1 and partitioned is None: + assert gather_indices is None # Checked above. def partition_dim(dim: int, idx: ir.Value, num_chunks: int): # No need to partition squeezed dims. They don't even exist in smem_ref. assert dim >= num_squeezed_dims nonlocal smem_ref slice_shape[dim] //= num_chunks block_offset = arith.muli(idx, c(slice_shape[dim], index)) - dyn_base_indices[dim] = arith.addi(dyn_base_indices[dim], block_offset) - smem_ref = utils.memref_slice( - smem_ref, - (slice(None),) * (dim - num_squeezed_dims) - + (utils.ds(block_offset, slice_shape[dim]),), - ) + dyn_base_indices[dim] = arith.addi(dyn_base_indices[dim], block_offset) # type: ignore[index] + if smem_ref is not None: + smem_ref = utils.memref_slice( + smem_ref, + (slice(None),) * (dim - num_squeezed_dims) + + (utils.ds(block_offset, slice_shape[dim]),), + ) idx = self.cluster_idx(collective) rem_collective_size = collective_size - for dim, slice_size in enumerate(slice_shape[:-1]): + has_swizzle = ( + swizzle is not None + and swizzle != mgpu_dialect.SwizzlingMode.kNoSwizzle + ) + # We can partition the minormost dim if there's no swizzling. + for dim, slice_size in enumerate( + slice_shape[:-1] if has_swizzle else slice_shape + ): if slice_size % rem_collective_size == 0: partition_dim(dim, idx, rem_collective_size) rem_collective_size = 1 @@ -588,38 +954,16 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int): f" {slice_shape} is divisible by the collective size" f" {collective_size}" ) - # Make each block load a smaller slice, adjust the GMEM indices and slice - # the SMEM reference accordingly. - multicast_mask = arith.trunci( - i16, utils.cluster_collective_mask(self.cluster_size, collective) - ) - else: - multicast_mask = None - - tma_desc = self._get_tma_desc( - gmem_ref, gmem_transform, tuple(slice_shape), swizzle, - ) - - # We constuct TMA descriptors in column-major order. - rev_dyn_base_indices = [ - arith.index_cast(i32, idx) for idx in reversed(dyn_base_indices) - ] - - uniform_ctx = ( - functools.partial(utils.single_thread, per_block=False) - if uniform and predicate is None - else contextlib.nullcontext - ) if max(slice_shape) > 256: raise ValueError( "Async copies only support copying <=256 elements along each" - " dimension" + f" dimension, got {tuple(slice_shape)}" ) if (zeroth_bw := slice_shape[-1] * element_bitwidth) % 128 != 0: raise ValueError( - "Async copies require the number of bytes copied along the last" - f" dimension to be divisible by 16, but got {zeroth_bw}" + "Async copies require the number of bits copied along the last" + f" dimension to be divisible by 128, but got {zeroth_bw}" ) if ( swizzle is not None @@ -632,62 +976,652 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int): f" {(swizzle * 8) // element_bitwidth} elements, but got" f" {slice_shape[-1]} elements." ) - smem_ptr = utils.memref_ptr(smem_ref, memory_space=3) - if gmem_ref is src_ref: - assert barrier is not None # for pytype - assert np.prod(slice_shape) * element_bitwidth * collective_size % 8 == 0 - transfer_bytes = c( - np.prod(slice_shape) * element_bitwidth * collective_size // 8, i32 + return (smem_ref, slice_shape, dyn_base_indices, gmem_transform) + + def async_copy( + self, + *, + src_ref: ir.Value, + dst_ref: ir.Value, + gmem_slice: Any = (), + gmem_transform: MemRefTransform | tuple[MemRefTransform, ...] = (), + gmem_peer_id: int | ir.Value | GlobalBroadcast | None = None, + barrier: utils.BarrierRef | None = None, + swizzle: int | None = None, + arrive: bool | None = None, + collective: Sequence[gpu.Dimension] | gpu.Dimension | None = None, + partitioned: int | None = None, + # Should select 0 or 1 threads from the WG. + predicate: ir.Value | None | _DefaultPredicate = _DefaultPredicate(), + reduction_op: TMAReductionOp | None = None, + implementation: AsyncCopyImplementation = AsyncCopyImplementation.TMA, + ): + """Initiates an async copy between GMEM and SMEM. + + Exactly one of `src_ref` and `dst_ref` must be in GMEM and in SMEM, and the + SMEM reference must be contiguous. The GMEM window that is read or written + to is specified by the `gmem_slice`. The copy can change the order in which + the data appears in the window by applying a sequence of transforms to the + GMEM reference (as specified by `gmem_transform`). + + When `collective` is specified (only allowed for GMEM -> SMEM copies), the + identical async_copy must be scheduled by all blocks that share the same + coordinates along collective dimensions within a cluster. The behavior is + undefined otherwise. The semantics of collective loads depend further on the + `partitioned` argument: + + - If `partitioned` is not specified, all blocks load the same data into + their shared memory and all receive the update in their barriers, unless + `arrive` is False. If `arrive` is False, you should expect the barrier to + have expect_tx incremented by the same amount of bytes as if `collective` + was not specified. + - If `partitioned` is specified, each block only loads a separate slice of + the data into SMEM, partitioned into equal tiles along the `partitioned` + dimension. In this case only the barrier of the first block in the + collective will have its expect_tx incremented by the total size of the + transfer across all blocks involved in the collective. Barriers supplied + by other blocks will be ignored (even if `arrive` is True). + """ + index = ir.IndexType.get() + i8 = ir.IntegerType.get_signless(8) + i16 = ir.IntegerType.get_signless(16) + i32 = ir.IntegerType.get_signless(32) + + src_ref_ty = ir.MemRefType(src_ref.type) + dst_ref_ty = ir.MemRefType(dst_ref.type) + element_type = src_ref_ty.element_type + element_bitwidth = utils.bitwidth(element_type) + if element_type != dst_ref_ty.element_type: + raise ValueError( + f"Expected same element type, got {element_type} and" + f" {dst_ref_ty.element_type}" + ) + + if isinstance(collective, gpu.Dimension): + collective = (collective,) + elif collective is None: + collective = () + if not isinstance(gmem_transform, tuple): + gmem_transform = (gmem_transform,) + if not isinstance(gmem_slice, tuple): + gmem_slice = (gmem_slice,) + + if reduction_op is not None: + if implementation != AsyncCopyImplementation.TMA: + raise ValueError("Only the TMA implementation supports reductions") + if not _is_tma_reduction_op_supported(reduction_op, element_type): + raise ValueError( + f"Reduction op {reduction_op} not supported by the TMA" + f" implementation for element type {element_type}" + ) + + if src_ref_ty.memory_space is None and utils.is_smem_ref(dst_ref_ty): + gmem_ref, smem_ref = src_ref, dst_ref + if implementation == AsyncCopyImplementation.TMA: + if barrier is None: + raise ValueError("Barriers are required for TMA GMEM -> SMEM copies") + else: + assert implementation == AsyncCopyImplementation.CP_ASYNC + if barrier is not None: + raise NotImplementedError( + "Barriers are unsupported for CP_ASYNC GMEM -> SMEM copies" + ) + if arrive is None: + arrive = True # Arrive by default + elif utils.is_smem_ref(src_ref_ty) and dst_ref_ty.memory_space is None: + gmem_ref, smem_ref = dst_ref, src_ref + if barrier is not None: + raise ValueError("Barriers are unsupported for SMEM -> GMEM copies") + if arrive is None: + arrive = True # Commit this copy to the async group by default + else: + raise ValueError("Only SMEM <-> GMEM copies supported") + + if collective and gmem_ref is dst_ref: + raise ValueError("Only GMEM -> SMEM copies can be collective") + + ( + slice_shape, + dyn_base_indices, + squeezed_dims, + gather_indices, + gmem_transform, + ) = self._prepare_async_copy( + gmem_ref, + gmem_slice, + gmem_transform, + collective, + partitioned, + implementation, + ) + del gmem_slice # Use slice_shape, dyn_base_indices and squeezed_dims instead. + + gmem_ref_ty = ir.MemRefType(gmem_ref.type) + smem_ref_ty = ir.MemRefType(smem_ref.type) + # TODO(apaszke): Support squeezed dims for CP_ASYNC. + if implementation == AsyncCopyImplementation.CP_ASYNC and squeezed_dims: + raise NotImplementedError( + "Integer indexing in gmem_slice not supported for CP_ASYNC" + ) + # We moved all squeezed dims to the front in _prepare_async_copy. + assert all(d == 1 for d in slice_shape[:len(squeezed_dims)]) + if slice_shape[len(squeezed_dims):] != smem_ref_ty.shape: + raise ValueError( + "Expected the SMEM reference to have the same shape as the" + f" transformed slice: {tuple(smem_ref_ty.shape)} !=" + f" {tuple(slice_shape[len(squeezed_dims):])}" + ) + + if implementation == AsyncCopyImplementation.CP_ASYNC: + assert not collective + assert partitioned is None + if not isinstance(predicate, _DefaultPredicate): + raise NotImplementedError( + "CP_ASYNC needs to be performed by the whole warpgroup and does not" + " support the predicate argument" + ) + # TODO(apaszke): This should be quite easy? The only complication is that + # the indices array needs to have a layout compatible with the way we + # assign lanes to rows/cols. + if gather_indices is not None: + raise NotImplementedError("Gather/scatter unsupported for the CP_ASYNC implementation") + if smem_ref is src_ref: + raise ValueError("CP_ASYNC implementation only supports GMEM -> SMEM copies") + assert swizzle is not None + swizzle_elems = 8 * swizzle // element_bitwidth + if gmem_transform != (TileTransform((8, swizzle_elems)),): + raise NotImplementedError(gmem_transform) + layout = fa.tiled_copy_smem_gmem_layout( + *smem_ref_ty.shape[-4:-2], swizzle, element_bitwidth # type: ignore[call-arg] + ) + gmem_strides = gmem_ref_ty.get_strides_and_offset()[0] + dst_tiled_strides = [ + arith.constant(i32, s) + for s in layout.tiling.tile_strides(gmem_strides)[gmem_ref_ty.rank :] + ] + lane_offset = utils.dyn_dot(layout.lane_indices(), dst_tiled_strides) + warp_offset = utils.dyn_dot(layout.warp_indices(), dst_tiled_strides) + dyn_offset = arith.addi(lane_offset, warp_offset) + offset_scale = 1 if element_bitwidth >= 8 else 8 // element_bitwidth + if element_bitwidth < 8: + gep_type = i8 + elif ( + isinstance(element_type, ir.FloatType) + and ir.FloatType(element_type).width == 8 + ): + gep_type = i8 # LLVM has no support for f8. + else: + gep_type = element_type + dyn_offset = arith.divui(dyn_offset, c(offset_scale, i32)) + if gmem_ref_ty.rank != 2: + raise NotImplementedError("Only 2D copies implemented") + transfers = fa.FragmentedArray.transfer_tiled( + smem_ref, swizzle, layout, tuple(gmem_ref_ty.shape), optimized=False + ) + gmem_base_ptr = utils.getelementptr(utils.memref_ptr(gmem_ref), [dyn_offset], gep_type) + gmem_base_ptr = llvm.addrspacecast(ir.Type.parse("!llvm.ptr<1>"), gmem_base_ptr) + bytes_per_transfer = layout.vector_length * element_bitwidth // 8 + # Only 16-byte transfers can skip the L1 cache (this is what CG means). + cache_modifier = ( + nvvm.LoadCacheModifierKind.CG + if bytes_per_transfer == 16 + else nvvm.LoadCacheModifierKind.CA + ) + for _get, _update, get_base_idx, smem_ptr in transfers: + constant_offset = sum(i * s for i, s in zip(get_base_idx(), gmem_strides, strict=True)) + gmem_ptr = utils.getelementptr(gmem_base_ptr, [constant_offset // offset_scale], gep_type) + nvvm.cp_async_shared_global(smem_ptr, gmem_ptr, bytes_per_transfer, cache_modifier) + if barrier is None: + nvvm.cp_async_commit_group() + else: + raise NotImplementedError + return + + assert implementation == AsyncCopyImplementation.TMA + + (smem_ref, slice_shape, dyn_base_indices, gmem_transform) = ( + self._prepare_tma( + gmem_ref, + smem_ref, + swizzle, + slice_shape, + dyn_base_indices, + gather_indices, + squeezed_dims, + gmem_transform, + collective, + partitioned, + ) + ) + assert smem_ref is not None # For type checkers. + + smem_strides, _ = ir.MemRefType(smem_ref.type).get_strides_and_offset() + if any( + s != cs and d != 1 # Strides don't matter for dims of size 1. + for s, cs, d in zip( + smem_strides, + utils.get_contiguous_strides(smem_ref_ty.shape), + smem_ref_ty.shape, + ) + ): + raise ValueError( + "async_copy needs the SMEM reference to be contiguous, but got" + f" strides {smem_strides} for shape {smem_ref_ty.shape}" ) + + collective_size = math.prod(self.cluster_size[d] for d in collective) + assert math.prod(slice_shape) * element_bitwidth * collective_size % 8 == 0 + transfer_bytes = c( + math.prod(slice_shape) * element_bitwidth * collective_size // 8, i32 + ) + + if gather_indices is not None: + import builtins + zips = functools.partial(builtins.zip, strict=True) + # The gather TMA instruction is limited to 2D GMEM references. That means + # that we can't apply the transforms to the GMEM reference and have the + # TMA engine deal with permuting the data, like we do for non-gather TMA. + # Instead, we have to break up the transfer into multiple 2D gathers + # ourselves, which requires us to do more complicated stride math etc. + # + # The minor transformed dim should be a contiguous transfer dim. + # The second minor should be a gather dim of size divisible by 4. + # The rest can be anything, and we will unroll the transfers over them. + if smem_ref is src_ref: + raise NotImplementedError("Scatter unsupported for the TMA implementation") + assert barrier is not None # for pytype barrier_ptr = barrier.get_ptr() - with uniform_ctx(): - if collective_size > 1 and partitioned is not None: - if predicate is None: - predicate = c(1, ir.IntegerType.get_signless(1)) - if arrive: - first_block = arith.cmpi( - arith.CmpIPredicate.eq, self.cluster_idx(collective), c(0, index), - ) - arrive_predicate = arith.andi(predicate, first_block) - nvvm.mbarrier_arrive_expect_tx_shared( - barrier_ptr, transfer_bytes, predicate=arrive_predicate - ) - rank = len(slice_shape) - idx_operands = ",".join(f"${i}" for i in range(4, 4 + rank)) + if squeezed_dims: + raise NotImplementedError("Gather/scatter unsupported when using integer indexing") + if reduction_op is not None: + raise ValueError("Gather/scatter TMA can't perform reductions") + if not isinstance(predicate, _DefaultPredicate): + raise ValueError("Gather/scatter TMA can't use a predicate") + if gather_indices.layout != fa.TMA_GATHER_INDICES_LAYOUT: + raise ValueError(f"Unsupported gather indices layout: {gather_indices.layout}") + ROWS_PER_INSTR = 4 + # Make sure we'll always be accessing SMEM with sufficient alignment. + single_tma_bits = ROWS_PER_INSTR * slice_shape[-1] * element_bitwidth + if single_tma_bits % 1024: + raise ValueError( + "Gather/scatter TMA would require breaking it up into transfers of" + f" {single_tma_bits // 8} bytes, but need a multiple of 128 bytes" + ) + + if arrive: + arrive_predicate = utils.single_thread_predicate(utils.ThreadSubset.WARPGROUP) + utils.nvvm_mbarrier_arrive_expect_tx( + barrier_ptr, + transfer_bytes, + predicate=arrive_predicate, + ) + + gmem_strides, _ = gmem_ref_ty.get_strides_and_offset() + assert len(gmem_strides) == 2 + _, gmem_cols = gmem_ref_ty.shape + slice_gather_strides: tuple[int, ...] = (1, 0) # Each row gets a new index, column has no effect. + for t in gmem_transform: + gmem_strides = t.transform_strides(gmem_strides) + slice_gather_strides = t.transform_strides(slice_gather_strides) + is_gather_dim = [bool(s) for s in slice_gather_strides] + + tma_desc = self._get_tma_desc( + gmem_ref, (), gmem_peer_id, (1, slice_shape[-1]), swizzle, reduction_op, + ) + + # Indices are split over 4 warps, and replicated within each warp. + assert fa.TMA_GATHER_INDICES_LAYOUT.vector_length == ROWS_PER_INSTR + # Index 00 01 02 03 04 05 06 07 08 09 10 11 12 13 14 15 16 17 ... + # Warp <--- 0 ---> <--- 1 ---> <--- 2 ---> <--- 3 ---> <--- 0 -- + warp_idx = arith.remui( + utils.warp_idx(sync=True), + arith.constant(i32, utils.WARPS_IN_WARPGROUP), + ) + gather_linear_idx_warp = arith.muli(warp_idx, c(ROWS_PER_INSTR, i32)) + + # Since the TMA instruction is limited to 2D gathers, we flatten all + # non-gather dims into the column index. + max_non_gather_linear_index = sum( + (d - 1) * s + for g, d, s in zip(is_gather_dim[:-1], slice_shape[:-1], gmem_strides[:-1]) + if not g + ) + # If we ever exceed this then we need to change the size of the GMEM ref, + # to prevent the TMA engine from clipping our indices. + if max_non_gather_linear_index > gmem_cols: + raise NotImplementedError("Non-gather dims don't fit into the columns") + col_base_offset = functools.reduce( + arith.addi, + ( + arith.muli(idx, arith.constant(index, stride)) + for g, idx, stride in zips( + is_gather_dim, dyn_base_indices, gmem_strides + ) + if not g + ), + arith.constant(index, 0), + ) + col_base_offset = arith.index_cast(i32, col_base_offset) + # TMA instructions are uniform, so we can't use multiple lanes. + predicate = utils.single_thread_predicate(utils.ThreadSubset.WARP) + # We need to unroll over all non-gather dimensions other than the last one + non_gather_slice_shape = tuple( + 1 if g else d for d, g in zips(slice_shape[:-1], is_gather_dim[:-1]) + ) + # First, iterate over gather index registers we have available. + for i, reg in enumerate(gather_indices.registers.flat): + if utils.bitwidth(gather_indices.mlir_dtype) != 32: + reg = arith.extui(ir.VectorType.get((4,), i32), reg) + # Compute which rows within the 2D slice we'll be gathering. + gather_linear_idx_reg = i * ROWS_PER_INSTR * utils.WARPS_IN_WARPGROUP + gather_linear_idx = arith.addi( + gather_linear_idx_warp, arith.constant(i32, gather_linear_idx_reg) + ) + # Transform row indices to align with the transformed SMEM shape. + gather_slice_idx = [ + arith.remui(arith.divui(gather_linear_idx, c(s, i32)), c(d, i32)) + for g, d, s in zip(is_gather_dim, slice_shape, slice_gather_strides) + if g + ] + gather_slice_idx = [arith.index_cast(index, i) for i in gather_slice_idx] + gather_rows = [ + llvm.extractelement(reg, c(i, i32)) for i in range(ROWS_PER_INSTR) + ] + # Second, step over non-gather slice indices. + for non_gather_idxs in np.ndindex(non_gather_slice_shape): + gather_slice_idx_it = iter(gather_slice_idx) + smem_indices = tuple( + next(gather_slice_idx_it) if g else i + for g, i in zip(is_gather_dim[:-1], non_gather_idxs) + ) + # We should really take a slice here, but it doesn't matter. We're + # just going to take the base pointer anyway. + transfer_smem_ref = utils.memref_slice(smem_ref, smem_indices) + smem_ptr = utils.memref_ptr(transfer_smem_ref, memory_space=3) + # The slice index needs to be folded into the gather col index. + col_slice_offset = sum( + idx * stride + for g, idx, stride in zips( + is_gather_dim[:-1], non_gather_idxs, gmem_strides[:-1] + ) + if not g + ) + col_offset = arith.addi(col_base_offset, arith.constant(i32, col_slice_offset)) llvm.inline_asm( ir.Type.parse("!llvm.void"), - [predicate, smem_ptr, tma_desc, barrier_ptr, *rev_dyn_base_indices], - f""" - {{ - .reg .b32 mapped_addr; - @$0 mapa.shared::cluster.u32 mapped_addr, $3, 0; - @$0 cp.async.bulk.tensor.{rank}d.shared::cta.global.tile.mbarrier::complete_tx::bytes.cta_group::2 - [$1], [$2, {{{idx_operands}}}], [mapped_addr]; - }} - """, - "b,r,l,r" + ",r" * rank, + [predicate, smem_ptr, tma_desc, barrier_ptr, col_offset, *gather_rows], + "@$0 cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes [$1], [$2, {$4, $5, $6, $7, $8}], [$3];", + "b,r,l,r" + ",r" * (ROWS_PER_INSTR + 1), has_side_effects=True, ) - else: - if arrive: - nvvm.mbarrier_arrive_expect_tx_shared( - barrier_ptr, transfer_bytes, predicate=predicate - ) - nvvm.cp_async_bulk_tensor_shared_cluster_global( - smem_ptr, tma_desc, rev_dyn_base_indices, barrier_ptr, [], - multicast_mask=multicast_mask, predicate=predicate + return + + assert gather_indices is None # Only tiled TMA handled below. + tma_desc = self._get_tma_desc( + gmem_ref, gmem_transform, gmem_peer_id, + tuple(slice_shape), swizzle, reduction_op, + ) + # We construct TMA descriptors in column-major order. + rev_dyn_base_indices = [ + arith.index_cast(i32, idx) for idx in reversed(dyn_base_indices) + ] + if isinstance(predicate, _DefaultPredicate): + predicate = utils.single_thread_predicate(utils.ThreadSubset.WARPGROUP) + if predicate is None: + predicate = c(1, ir.IntegerType.get_signless(1)) + smem_ptr = utils.memref_ptr(smem_ref, memory_space=3) + if gmem_ref is src_ref: + assert barrier is not None # for pytype + barrier_ptr = barrier.get_ptr() + assert reduction_op is None + if collective_size > 1 and partitioned is not None: + assert collective_size == 2 + if arrive: + first_block = arith.cmpi( + arith.CmpIPredicate.eq, self.cluster_idx(collective), c(0, index), + ) + arrive_predicate = arith.andi(predicate, first_block) + utils.nvvm_mbarrier_arrive_expect_tx( + barrier_ptr, transfer_bytes, predicate=arrive_predicate + ) + rank = len(slice_shape) + idx_operands = ",".join(f"${i}" for i in range(4, 4 + rank)) + llvm.inline_asm( + ir.Type.parse("!llvm.void"), + [predicate, smem_ptr, tma_desc, barrier_ptr, *rev_dyn_base_indices], + f""" + {{ + .reg .b32 mapped_addr; + @$0 mapa.shared::cluster.u32 mapped_addr, $3, 0; + @$0 cp.async.bulk.tensor.{rank}d.shared::cta.global.tile.mbarrier::complete_tx::bytes.cta_group::2 + [$1], [$2, {{{idx_operands}}}], [mapped_addr]; + }} + """, + "b,r,l,r" + ",r" * rank, + has_side_effects=True, + ) + else: + if arrive: + utils.nvvm_mbarrier_arrive_expect_tx( + barrier_ptr, transfer_bytes, predicate=predicate ) + if collective_size > 1: + multicast_mask = arith.trunci( + i16, utils.cluster_collective_mask(self.cluster_size, collective) + ) + else: + multicast_mask = None + nvvm.cp_async_bulk_tensor_shared_cluster_global( + smem_ptr, tma_desc, rev_dyn_base_indices, barrier_ptr, [], + multicast_mask=multicast_mask, predicate=predicate + ) else: - assert multicast_mask is None - with uniform_ctx(): + if reduction_op is not None: + rank = len(slice_shape) + idx_operands = ",".join(f"${i}" for i in range(3, 3 + rank)) + llvm.inline_asm( + ir.Type.parse("!llvm.void"), + [predicate,smem_ptr,tma_desc,*rev_dyn_base_indices], + f"@$0 cp.reduce.async.bulk.tensor.{rank}d.global.shared::cta.{_reduction_op_to_ptx(reduction_op)}.tile.bulk_group [$2,{{{idx_operands}}}], [$1];", + "b,r,l" + ",r" * rank, + has_side_effects=True, + ) + if arrive: + nvvm.cp_async_bulk_commit_group() + else: nvvm.cp_async_bulk_tensor_global_shared_cta( tma_desc, smem_ptr, rev_dyn_base_indices, predicate=predicate ) if arrive: nvvm.cp_async_bulk_commit_group() + def async_prefetch( + self, + *, + gmem_ref: ir.Value, + gmem_slice: Any = (), + gmem_transform: MemRefTransform | tuple[MemRefTransform, ...] = (), + gmem_peer_id: int | ir.Value | None = None, + swizzle: int | None = None, + collective: Sequence[gpu.Dimension] | gpu.Dimension | None = None, + partitioned: int | None = None, + # Should select 0 or 1 threads from the WG. + predicate: ir.Value | None | _DefaultPredicate = _DefaultPredicate(), + ): + i32 = ir.IntegerType.get_signless(32) + + if isinstance(collective, gpu.Dimension): + collective = (collective,) + elif collective is None: + collective = () + if not isinstance(gmem_transform, tuple): + gmem_transform = (gmem_transform,) + if not isinstance(gmem_slice, tuple): + gmem_slice = (gmem_slice,) + + impl = AsyncCopyImplementation.TMA + ( + slice_shape, + dyn_base_indices, + squeezed_dims, + gather_indices, + gmem_transform, + ) = self._prepare_async_copy( + gmem_ref, gmem_slice, gmem_transform, collective, partitioned, impl + ) + del gmem_slice # Use slice_shape, dyn_base_indices and squeezed_dims instead. + + (_, slice_shape, dyn_base_indices, gmem_transform) = ( + self._prepare_tma( + gmem_ref, + None, + swizzle, + slice_shape, + dyn_base_indices, + gather_indices, + squeezed_dims, + gmem_transform, + collective, + partitioned, + ) + ) + + if gather_indices is not None: + raise NotImplementedError("Gather/scatter prefetch not implemented yet") + + tma_desc = self._get_tma_desc( + gmem_ref, gmem_transform, gmem_peer_id, + tuple(slice_shape), swizzle, reduction_op=None, + ) + # We construct TMA descriptors in column-major order. + rev_dyn_base_indices = [ + arith.index_cast(i32, idx) for idx in reversed(dyn_base_indices) + ] + if isinstance(predicate, _DefaultPredicate): + predicate = utils.single_thread_predicate(utils.ThreadSubset.WARPGROUP) + if predicate is None: + predicate = c(1, ir.IntegerType.get_signless(1)) + rank = len(slice_shape) + idx_operands = ",".join(f"${i}" for i in range(2, 2 + rank)) + llvm.inline_asm( + ir.Type.parse("!llvm.void"), + [predicate, tma_desc, *rev_dyn_base_indices], + f"@$0 cp.async.bulk.prefetch.tensor.{rank}d.L2.global.tile [$1, {{{idx_operands}}}];", + "b,l" + ",r" * rank, + has_side_effects=True, + ) + def await_async_copy( - self, allow_groups: int, await_read_only: bool = False + self, allow_groups: int, await_read_only: bool = False, + scope: utils.ThreadSubset = utils.ThreadSubset.WARPGROUP, ): nvvm.cp_async_bulk_wait_group(allow_groups, read=await_read_only) + if scope == utils.ThreadSubset.WARPGROUP: + utils.warpgroup_barrier() + elif scope == utils.ThreadSubset.WARP: + utils.warp_barrier() + else: + raise ValueError(f"Unsupported scope: {scope}") + + def await_cp_async_copy(self, allow_groups: int): + nvvm.cp_async_wait_group(allow_groups) utils.warpgroup_barrier() + + def _ensure_nvshmem_decls(self): + if self.is_device_collective: + return + self.is_device_collective = True + with ir.InsertionPoint(self.module.body): + nvshmem_my_pe_type = ir.TypeAttr.get(ir.Type.parse("!llvm.func")) + llvm.LLVMFuncOp( + "nvshmem_my_pe", nvshmem_my_pe_type, sym_visibility="private" + ) + nvshmem_ptr_type = ir.TypeAttr.get( + ir.Type.parse("!llvm.func") + ) + llvm.LLVMFuncOp("nvshmem_ptr", nvshmem_ptr_type, sym_visibility="private") + nvshmemx_mc_ptr_type = ir.TypeAttr.get( + ir.Type.parse("!llvm.func") + ) + llvm.LLVMFuncOp( + "nvshmemx_mc_ptr", nvshmemx_mc_ptr_type, sym_visibility="private" + ) + + def to_remote(self, ref: ir.Value, peer: ir.Value): + self._ensure_nvshmem_decls() + if isinstance(ref.type, ir.MemRefType): + # We replace the offset in the ref type by 0, because memref_ptr always + # folds the offset into the pointer. + ref_ty = ir.MemRefType(ref.type) + strides, _ = ref_ty.get_strides_and_offset() + result_type = ir.MemRefType.get( + ref_ty.shape, + ref_ty.element_type, + ir.StridedLayoutAttr.get(0, strides), + ref_ty.memory_space, + ) + return utils.ptr_as_memref( + self.to_remote(utils.memref_ptr(ref), peer), result_type + ) + if ref.type != ir.Type.parse("!llvm.ptr"): + raise ValueError(f"Unsupported type for to_remote: {ref.type}") + if peer.type != ir.IntegerType.get_signless(32): + raise ValueError(f"peer index must be an i32, got {peer.type}") + return llvm.call(ref.type, [ref, peer], [], [], callee="nvshmem_ptr") + + def to_remote_multicast(self, ref: ir.Value): + i32 = ir.IntegerType.get_signless(32) + self._ensure_nvshmem_decls() + if not isinstance(ref.type, ir.MemRefType): + raise ValueError(f"Unsupported type for to_remote_multicast: {ref.type}") + # We replace the offset in the ref type by 0, because memref_ptr always + # folds the offset into the pointer. + ref_ty = ir.MemRefType(ref.type) + strides, _ = ref_ty.get_strides_and_offset() + result_type = ir.MemRefType.get( + ref_ty.shape, + ref_ty.element_type, + ir.StridedLayoutAttr.get(0, strides), + ref_ty.memory_space, + ) + world_team = arith.constant(i32, 0) + ptr = utils.memref_ptr(ref) + mc_ptr = llvm.call( + ptr.type, [world_team, ptr], [], [], callee="nvshmemx_mc_ptr", + ) + return utils.MultimemRef(utils.ptr_as_memref(mc_ptr, result_type)) + + def device_id(self) -> ir.Value: + self._ensure_nvshmem_decls() + i32 = ir.IntegerType.get_signless(32) + return llvm.call(i32, [], [], [], callee="nvshmem_my_pe") + + +class ReplicationError(Exception): + pass + +def _recompute_peer_id(peer_id: ir.Value, fuel=8) -> ir.Value: + if fuel == 0: + raise ReplicationError( + "gmem_peer_id computation is too complicated to recompute on the host" + ) + if isinstance(peer_id, ir.BlockArgument): + raise ReplicationError("Can't recompute a value that's a block argument") + op = peer_id.owner.opview + # We accept all arith ops + if op.OPERATION_NAME.startswith("arith."): + new_operands = [_recompute_peer_id(x, fuel - 1) for x in op.operands] + result_types = [r.type for r in op.results] + new_attributes = {na: op.attributes[na] for na in op.attributes} + new_op = ir.Operation.create( + op.OPERATION_NAME, result_types, new_operands, new_attributes + ) + return new_op.results if len(new_op.results) > 1 else new_op.result + # nvshmem_my_pe queries the device id of the current process and works on both + # the host and the device. + if isinstance(op, llvm.CallOp) and op.callee.value == "nvshmem_my_pe": + i32 = ir.IntegerType.get_signless(32) + return llvm.call(i32, [], [], [], callee="nvshmem_my_pe") + raise ReplicationError( + f"Unrecognized op can't be recomputed on the host: {op}" + ) diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py index 0d2811bb5610..b881d8bb4aa7 100644 --- a/jax/experimental/mosaic/gpu/layout_inference.py +++ b/jax/experimental/mosaic/gpu/layout_inference.py @@ -1,4 +1,4 @@ -# Copyright 2024 The JAX Authors. +# Copyright 2025 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,16 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Layout inference pass for the MLIR Mosaic GPU dialect.""" +"""Layout and transform inference pass for the MLIR Mosaic GPU dialect.""" -from collections.abc import Callable, Sequence +# mypy has been causing more problems than it solves here. Disable it for these +# files. We have pytype checks anyway. +# mypy: ignore-errors + +from __future__ import annotations + +from collections.abc import Callable, Iterator, Sequence import dataclasses import enum -from functools import partial +import itertools import math -from typing import cast +import re +from typing import assert_never, cast -from jax._src.lib import mosaic_gpu_dialect as mgpu +from absl import logging +from jax._src.lib import mosaic_gpu_dialect as mgpu # noqa: F401 from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith from jax._src.lib.mlir.dialects import math as mlir_math @@ -30,176 +38,474 @@ from jax._src.lib.mlir.dialects import vector import numpy as np +from . import constraints as cs from . import fragmented_array as fa from . import inference_utils +from . import launch_context as lc from . import layouts as layouts_lib +from . import tcgen05 from . import utils -# mypy: ignore-errors +# This value was arrived at by looking at an existing kernel where layout +# inference would never be able to complete successfully, and kernels where it +# would, as well as existing tests as of 2025-11-03. We observed the following: +# +# 1. all tests would pass with a fuel that is at least ~15_000; +# 2. the kernel for which layout inference fails would fail in less than 12 +# seconds when using a fuel of 100_000. +# +# All in all, this seems like a reasonable compromise: the value is high +# enough that we can comfortably find a solution to even the most complicated +# layout inference problems that we have seen so far, but the runtime is fast +# enough that users will not waste much time waiting for a never-ending pass to +# complete when the system is unable to find a solution. +_DEFAULT_LAYOUT_INFERENCE_FUEL = 100_000 + -OptionalLayouts = tuple[list[ir.Attribute], list[ir.Attribute]] | None -LayoutInferenceRule = Callable[[ir.OpView], OptionalLayouts] -_layout_inference_rules: dict[str, LayoutInferenceRule] = {} +class VariableType(enum.IntEnum): + """The type of a variable. + Variables are operands, results, or arguments of MLIR operations. + """ + OPERAND = 0 + RESULT = 1 + ARGUMENT = 2 -def _add_layout_inference_rule(op: type[ir.OpView], rule: LayoutInferenceRule): - _layout_inference_rules[op.OPERATION_NAME] = rule # pytype: disable=attribute-error +class MemorySpace(enum.Enum): + """The memory space of a variable.""" + REG = enum.auto() + SMEM = enum.auto() + TMEM = enum.auto() -def _set_layout_attributes( - op: ir.OpView, - in_layouts: list[ir.Attribute], - out_layouts: list[ir.Attribute], -): - op.attributes["in_layouts"] = ir.ArrayAttr.get(in_layouts) - op.attributes["out_layouts"] = ir.ArrayAttr.get(out_layouts) +_op_name_regex = re.compile(r"^(%\d+ = )?\S+") -def _choose_representative_layout( - layouts: set[ir.Attribute], -) -> ir.Attribute | None: - """Chooses an appropriate layout from a given set of possible layouts. - Given the input set of possible layouts, this function extracts a single - representative layout. Currently, this function only works with strided, - splat, and tiled layouts. +@dataclasses.dataclass(frozen=True) +class ValueSite: + """A unique identifier for a variable. - Returns: - A single layout that can be used to annotate the operation, or None if the - input set is empty. + This class describes a particular role of a Value, either as a result of an + operation, an operand of an operation, or a block argument. """ + # A MLIR operation. If the type is `ARGUMENT`, this is the owner of the block + # and region_index is the region that contains the block with the argument. + # The block is always the first block of the region. + operation: ir.OpView + # Whether this represents an operand, a result, or an argument. + type: VariableType + # The index of the operand/result/argument within the op's + # operands/results/arguments. + index: int + # The index of the region that contains the block with the argument. + region_index: int | None = None + + def __post_init__(self): + assert (self.type != VariableType.ARGUMENT) == (self.region_index is None) + + @property + def value(self) -> ir.Value: + """Returns the IR value corresponding to this value site.""" + if self.type == VariableType.OPERAND: + return self.operation.operands[self.index] + elif self.type == VariableType.RESULT: + return self.operation.results[self.index] + else: + return self.operation.regions[self.region_index].blocks[0].arguments[self.index] + + @property + def shape(self) -> tuple[int, ...]: + """Returns the shape of the underlying value.""" + return tuple(self.value.type.shape) # pytype: disable=attribute-error + + @property + def memory_space(self) -> MemorySpace: + """Returns the memory space associated with this value.""" + type = self.value.type + if isinstance(type, ir.VectorType): + return MemorySpace.REG + assert isinstance(type, ir.MemRefType) + if utils.is_tmem_ref(type): + return MemorySpace.TMEM + elif utils.is_smem_ref(type): + return MemorySpace.SMEM + raise ValueError(f"Unsupported memory space for: {type}") + + def __str__(self): + match = _op_name_regex.match(str(self.operation)) + assert match is not None + if self.type == VariableType.OPERAND: + return f"{match.group(0)}:o-{self.index}" + elif self.type == VariableType.RESULT: + return f"{match.group(0)}:r-{self.index}" + else: + return f"{match.group(0)}:a-{self.index}" + + +def extract_assignment_candidates_from_reduce_equation( + small: cs.RegisterLayout, + large: cs.Variable, + reduction_dims: tuple[int, ...] +) -> Iterator[cs.RegisterLayout]: + """Yields layout candidates for the reduce equation `small = reduce(large, reduction_dims).""" + large_shape = large.key.value.type.shape # pytype: disable=attribute-error + candidates = [ + fa.WGMMA_LAYOUT, + fa.WGMMA_TRANSPOSED_LAYOUT, + fa.TCGEN05_LAYOUT, + fa.TCGEN05_TRANSPOSED_LAYOUT, + tcgen05.TMEM_NATIVE_LAYOUT, + ] + if large_shape[-1] % 16 == 0: + candidates.append(tcgen05.fa_m64_collective_layout(large_shape[-1])) - if not layouts: - return None + for candidate in candidates: + if len(candidate.base_tile_shape) > len(large_shape): + continue + if candidate.reduce(reduction_dims) == small.value: + yield cs.RegisterLayout(candidate) - strided_layouts: list[fa.WGStridedFragLayout] = [ - layouts_lib.from_layout_attr(layout) - for layout in layouts - if layouts_lib.is_strided_fragmented_layout(layout) - ] - splat_layouts: list[fa.WGSplatFragLayout] = list( - map( - layouts_lib.from_layout_attr, - filter(layouts_lib.is_splat_fragmented_layout, layouts), - ) - ) +def _strided_layout_for_variable( + variable: cs.Variable, +) -> fa.WGStridedFragLayout | None: + """Returns a strided layout for the given variable. - tiled_layouts: list[fa.TiledLayout] = list( - map( - layouts_lib.from_layout_attr, - filter(layouts_lib.is_tiled_layout, layouts), - ) - ) + If the given variable cannot have a strided layout, returns `None`. + """ + # TODO(bchetioui): should we make variables carry a shape as well, to make + # things easier? + type = variable.key.value.type + assert isinstance(type, ir.VectorType) + return fa.WGStridedFragLayout.from_shaped_type(type) + + +def _default_tmem_layout_for_variable( + variable: cs.Variable, +) -> tcgen05.TMEMLayout | None: + """Returns a default TMEM layout for the given variable, if one is defined.""" + value = variable.key.value + parent = value.owner + if isinstance(parent, mgpu.TmemAllocOp): + return tcgen05._infer_tmem_layout( + tuple(value.type.shape), parent.collective, packing=1 + ) + return None + + +def _extract_tiling_candidate( + divide_constraint: cs.Divides, num_tiled_dims: int +) -> Iterator[tuple[cs.Variable, cs.Constant]]: + if not isinstance(divide_constraint.expr, cs.Variable): + return + if num_tiled_dims > len(divide_constraint.tiling_multiple): + # The tiling's rank cannot be larger than the size of `tiling_multiple`. + return + tiling = divide_constraint.tiling_multiple[-num_tiled_dims:] + yield divide_constraint.expr, cs.SMEMTiling(lc.TileTransform(tiling)) + + +def _extract_layout_candidates_from_memory_space_transfer( + constraint: cs.IsTransferable, + division_constraint_per_var: dict[cs.Variable, cs.Divides], +) -> Iterator[tuple[cs.Variable, cs.Constant]]: + """Attempts to extract variable assignments from a `Constraint`.""" + # This code assumes that the `IsTransferable` constraint is bidirectional. + # This is currently true for TMEM <-> REG transfers and SMEM <-> REG + # transfers. + src, tgt = constraint.source, constraint.target + match src, tgt: + case cs.Variable(), cs.Constant(): + variable, constant = src, tgt + case cs.Constant(), cs.Variable(): + variable, constant = tgt, src + case _: + return - if len(splat_layouts) + len(strided_layouts) + len(tiled_layouts) != len( - layouts + assert isinstance(variable, cs.Variable) # Satisfy type checkers. + if isinstance(constant, cs.RegisterLayout): + layout = constant.value + if variable.key.memory_space == MemorySpace.TMEM: + dtype = ir.MemRefType(variable.key.value.type).element_type + for packing in (1, 32 // utils.bitwidth(dtype)): + for tmem_layout, reg_layout in constraint.supported_tmem_transfers( + packing + ): + if layout == reg_layout: + yield variable, cs.TMEMLayout(tmem_layout) + elif variable.key.memory_space == MemorySpace.SMEM: + if inference_utils.is_mma_layout(layout): + tiling = _infer_tiling_for_mma_ref( + variable.key.value.type, + max_swizzle=mgpu.SwizzlingMode.k128ByteSwizzle + ) + divide = cs.Divides(variable, tiling) + if (divide2 := division_constraint_per_var.get(variable)) is not None: + # This is done on two lines to satisfy type checkers. + # TODO(b/447079781): clean up the `merge_divides_constraints` to + # avoid the need for this. + [merged] = cs.merge_divides_constraints([divide, divide2]) + divide = cast(cs.Divides, merged) + yield from _extract_tiling_candidate(divide, len(tiling)) + else: + # An empty tiling is valid here but we don't yield it in order to + # avoid duplicating the empty tiling yielded by the caller. + return + + if isinstance(constant, cs.TMEMLayout): + layout = constant.value + packing = layout.vector_length + for tmem_layout, reg_layout in constraint.supported_tmem_transfers(packing): + if layout == tmem_layout: + yield variable, cs.RegisterLayout(reg_layout) + + +def _divides_per_var( + constraints: Sequence[cs.Constraint], +) -> dict[cs.Variable, cs.Divides]: + result: dict[cs.Variable, cs.Divides] = {} + for constraint in constraints: + if isinstance(constraint, cs.Divides) and isinstance( + constraint.expr, cs.Variable + ): + assert constraint.expr not in result + result[constraint.expr] = constraint + return result + + +# TODO(bchetioui): flatten this call hierarchy. +def _extract_variable_assignments_from_constraints( + constraints: Sequence[cs.Constraint], +) -> Iterator[tuple[cs.Variable, cs.Constant]]: + """Attempts to extract variable assignments from all constraints.""" + dpv = _divides_per_var(constraints) + for c in constraints: + match c: + case cs.IsTransferable(): + yield from _extract_layout_candidates_from_memory_space_transfer(c, dpv) + case cs.Equals(cs.Reduce(cs.Variable() as large, axes=axes), cs.RegisterLayout() as small): + for layout in extract_assignment_candidates_from_reduce_equation(small, large, axes): + yield large, layout + case cs.Equals(cs.RegisterLayout() as small, cs.Reduce(cs.Variable() as large, axes=axes)): + for layout in extract_assignment_candidates_from_reduce_equation(small, large, axes): + yield large, layout + case cs.Relayout(cs.Variable() as var, cs.RegisterLayout() as layout): + yield var, layout + case cs.Relayout(cs.RegisterLayout() as layout, cs.Variable() as var): + yield var, layout + + +def conjure_assignment( + unknowns: Sequence[cs.Variable], + constraint_system: cs.ConstraintSystem, +) -> Iterator[tuple[cs.Variable, cs.Constant]]: + """Attempts to conjure an assignment for an unknown variable.""" + # TODO(allanrenucci): We should be able to short-circuit the search here if + # the constraint is not satisfiable. + + # As we extract assignment candidates from constraints, we prioritize + # candidates that are more "interesting"; e.g., in the case of registers, + # introducing splat layout candidate assignments often leads to a dead end in + # practice---as opposed to tiled layouts, which are more likely to yield + # solutions to the constraint system. + low_priority_assignments: list[tuple[cs.Variable, cs.Constant]] = [] + for variable, constant in _extract_variable_assignments_from_constraints( + constraint_system.constraints ): - raise ValueError( - f"Expected only strided, splat, and tiled layouts, got {layouts}" - ) + match constant: + case cs.RegisterLayout(value=value) if not isinstance(value, fa.TiledLayout): + low_priority_assignments.append((variable, constant)) + case _: + yield variable, constant + + # After all high-priority assignments have been attempted, switch to using + # low-priority assignments. + for variable, constant in low_priority_assignments: + yield variable, constant + + # Here, we have not managed to find an assignment for all the unknown + # variables. We now try to introduce new arbitrary (valid) assignments into + # the system, and hope that they turn out to be compatible with the constraint + # system. + for variable in unknowns: + if variable in constraint_system.assignments: + continue + # Try to instantiate a single variable to a default layout and see if it + # reduces the system. + match variable.key.memory_space: + case MemorySpace.REG: + layout = _strided_layout_for_variable(variable) + if layout is not None: + yield variable, cs.RegisterLayout(layout) + case MemorySpace.SMEM: + yield variable, cs.SMEMTiling(None) + case MemorySpace.TMEM: + layout = _default_tmem_layout_for_variable(variable) + if layout is not None: + yield variable, cs.TMEMLayout(layout) + case _: + raise ValueError(f"Unsupported memory space: {variable.key.memory_space}") + + +def find_assignments_for( + unknowns: Sequence[cs.Variable], + constraint_system: cs.ConstraintSystem, + *, + fuel: int, +) -> tuple[dict[cs.Variable, cs.Constant] | cs.Unsatisfiable, int]: + """Attempts to find assignments that satisfy `constraint_system` for `unknowns`. + + Args: + unknowns: the set of variables that are unknown. Represented as a sequence + of `Variable`s for determinism purposes. + constraint_system: the constraint system to satisfy. + fuel: the fuel to use for the search. Once the fuel is exhausted, we raise + an error. - if len(splat_layouts) > 1: - raise NotImplementedError( - "Finding a representative layout for several distinct splat layouts " - "is not supported." - ) + Returns: + A tuple where the first element is the solution, and the second element is + the fuel remaining after the search. The solution is either: + - Unsatisfiable() if the constraint system has unsatisfiable constraints. + - A dictionary assigning all the unknown variables to + `ConstantExpression`s such that the assignment satisfies the constraint + system otherwise. + """ + constraint_system = cs.reduce(constraint_system) + if isinstance(constraint_system, cs.Unsatisfiable): + return cs.Unsatisfiable(), fuel - if len(strided_layouts) > 1: - raise NotImplementedError( - "Finding a representative layout for several distinct strided layouts " - "is not supported." - ) + remaining_unknowns = [ + u for u in unknowns if u not in constraint_system.assignments.keys() + ] - if len(tiled_layouts) > 1: - raise NotImplementedError( - "Finding a representative layout for several distinct tiled layouts " - "is not supported." + # In this case, we have determined an assignment for all the unknown + # variables. Return their respective assignment. + if not remaining_unknowns: + assert not constraint_system.constraints, ( + "A satisfiable system should not have remaining unsatisfied" + " constraints. This is a bug." ) - - if tiled_layouts and strided_layouts: - raise NotImplementedError( - "Mixing strided and tiled layouts is not supported." + return { + v: k for v, k in constraint_system.assignments.items() if v in unknowns + }, fuel + + # If unknowns remain and we have fully reduced the system, we may still + # be able to make progress by trying out potential assignments. These + # new assignments could make the system unsatisfiable, so we use a recursive + # call to be able to backtrack if necessary. + for assignment in conjure_assignment( + remaining_unknowns, constraint_system + ): + if fuel <= 0: + raise ValueError( + "Layout inference failed to find a solution. Consider adding layout " + "annotations to your program to guide the search." + ) + # Trying one assignment consumes fuel. + fuel -= 1 + variable, expr = assignment + new_constraint_system = ( + cs.ConstraintSystem(assignments={variable: expr}) & constraint_system ) + if isinstance(new_constraint_system, cs.Unsatisfiable): + # This assignment is not compatible with the constraint system. + continue + solution, fuel = find_assignments_for( + unknowns, new_constraint_system, fuel=fuel + ) + if not isinstance(solution, cs.Unsatisfiable): + return solution, fuel - if tiled_layouts: - return layouts_lib.to_layout_attr(tiled_layouts[0]) + # TODO(bchetioui): should we have a way to give a useful dump to the user + # here, perhaps indicating what to layout cast. + return cs.Unsatisfiable(), fuel - if strided_layouts: - [strided_layout] = strided_layouts - return layouts_lib.to_layout_attr(strided_layout) - [splat_layout] = splat_layouts - return layouts_lib.to_layout_attr(splat_layout) +@dataclasses.dataclass() +class DerivationContext: + """Holds context information used for deriving an constraint system.""" + # A map of `ValueSite` to the variable that it is associated with. + variable_for_value_site: dict[ValueSite, cs.Variable] = dataclasses.field( + default_factory=dict, init=False + ) + # A map of `cs.Variable` to all the `ValueSite`s that it is associated with. + value_sites_for_variable: ValueSitesForVariable = ( + dataclasses.field(default_factory=dict, init=False) + ) + def update(self, mapping: ValueSitesForVariable) -> None: + for variable, value_sites in mapping.items(): + if variable in self.value_sites_for_variable: + self.value_sites_for_variable[variable].extend(value_sites) + else: + self.value_sites_for_variable[variable] = value_sites + for value_site in value_sites: + assert value_site not in self.variable_for_value_site + self.variable_for_value_site[value_site] = variable -def _infer_pointwise_op_layouts(op: ir.OpView) -> OptionalLayouts: + def producer_ref(self, operand: ValueSite) -> cs.Variable: + """Returns the producer reference variable for the given operand.""" + return self.variable_for_value_site[producer_result(operand)] - def is_array(v: ir.Value) -> bool: - return ir.VectorType.isinstance(v.type) - num_vector_operands = len([o for o in op.operands if is_array(o)]) - num_vector_results = len([r for r in op.results if is_array(r)]) +ValueSitesForVariable = dict[cs.Variable, list[ValueSite]] - if inference_utils.has_in_layouts_set(op): - op_in_layouts = inference_utils.in_layouts(op) - if op_in_layouts: - layout = op_in_layouts[0] - return (num_vector_operands * [layout], num_vector_results * [layout]) +# A constraint system derivation rule is a function that takes an MLIR operation +# and returns a constraint system, and a mapping from variables to value site +# identifiers. +# +# The intended meaning of the mapping is that, for each identifier in the list +# keyed by a given variable, the MLIR operand/result/argument corresponding to +# that identifier has the same layout as the variable. +# +# A `ConstraintSystemDerivationRule` must return a mapping such that the +# identifier corresponding to each value site must appear in the mapping, +# and each identifier in the mapping must be keyed by exactly one variable. +# Lastly, the mapping must only refer to variables and +# operands/results/arguments that correspond to the given operation. +ConstraintSystemDerivationRuleResult = cs.Unsatisfiable | tuple[ + cs.ConstraintSystem, ValueSitesForVariable +] +ConstraintSystemDerivationRule = Callable[ + [DerivationContext, ir.OpView], + ConstraintSystemDerivationRuleResult, +] +_constraint_system_derivation_rules: dict[ + str, ConstraintSystemDerivationRule +] = {} - if inference_utils.has_out_layouts_set(op): - op_out_layouts = inference_utils.out_layouts(op) - if op_out_layouts: - layout = op_out_layouts[0] - return (num_vector_operands * [layout], num_vector_results * [layout]) - layouts = set() +def _add_constraint_system_derivation_rule(op: type[ir.OpView]): + def wrapper(rule: ConstraintSystemDerivationRule): + if op is not None: + _constraint_system_derivation_rules[op.OPERATION_NAME] = rule # pytype: disable=attribute-error + return rule - # We can also try to infer layouts from the layout of producer and - # consumer operations. - # - # We first look at producers; this enables e.g. propagating splat layouts as - # far down as possible, until since we may be able to propagate splat layouts - # further down before requiring a relayout in that way. - all_inputs_have_layout = True - for operand in op.operands: - if not ir.VectorType.isinstance(operand.type): - continue - if (layout := inference_utils.value_layout(operand)) is not None: - layouts.add(layout) - else: - all_inputs_have_layout = False - - # We only look at consumers if we haven't found a possible layout yet. This is - # to avoid propagating more complicated layouts up, to e.g. preserve splat - # layouts as far down as possible. - if not layouts: - for op_result in op.results: - if not ir.VectorType.isinstance(op_result.type): - continue - for op_operand_use in cast(ir.OpResult, op_result).uses: - consumer = op_operand_use.owner - op_user = consumer.operands[op_operand_use.operand_number] - layout = inference_utils.in_layout_for_operand(consumer, op_user) - if layout is not None: - layouts.add(layout) + return wrapper + + +def is_vector(v: ir.Value) -> bool: + return isinstance(v.type, ir.VectorType) - # TODO(bchetioui): when propagating up, the representative layout should be - # chosen in the opposite way as when propagating down. E.g., when propagating - # down, we should pick a strided layout over a splat layout; when propagating - # up, we should pick a splat layout over a strided layout. - # This is left for a future change, and currently we only do "down - # propagation". - layout = _choose_representative_layout(layouts) - # It is unsafe to t conclude that this op produces a splat if not all inputs - # have been inferred: some of them might turn out not to be splats! - if layouts_lib.is_splat_fragmented_layout(layout) and not all_inputs_have_layout: - return None - if layout is None: - return None - return (num_vector_operands * [layout], num_vector_results * [layout]) +def _is_smem_ref(v: ir.Value) -> bool: + return isinstance(v.type, ir.MemRefType) and utils.is_smem_ref(v) + + +def _is_tmem_ref(v: ir.Value) -> bool: + return isinstance(v.type, ir.MemRefType) and utils.is_tmem_ref(v) + + +def _pointwise_op_constraint_system( + ctx: DerivationContext, + op: ir.OpView, +) -> ConstraintSystemDerivationRuleResult: + del ctx + all_value_sites = vector_value_sites(op) + variable = cs.Variable(all_value_sites[-1]) + return cs.ConstraintSystem(), {variable: all_value_sites} for op in [ @@ -236,291 +542,1592 @@ def is_array(v: ir.Value) -> bool: arith.TruncFOp, arith.TruncIOp, arith.XOrIOp, + arith.SelectOp, mlir_math.ExpOp, mlir_math.Exp2Op, + mlir_math.SinOp, + mlir_math.CosOp, mlir_math.LogOp, mlir_math.RsqrtOp, mlir_math.TanhOp, - vector.LoadOp, - vector.StoreOp, + mlir_math.AbsFOp, + mlir_math.AbsIOp, + mlir_math.RoundOp, + mlir_math.RoundEvenOp, + mlir_math.CopySignOp, ]: - _add_layout_inference_rule(op, _infer_pointwise_op_layouts) + _add_constraint_system_derivation_rule(op)(_pointwise_op_constraint_system) + + +@_add_constraint_system_derivation_rule(mgpu.VectorLoadOp) +def _vector_load_constraint_system( + ctx: DerivationContext, + op: mgpu.VectorLoadOp, +) -> ConstraintSystemDerivationRuleResult: + # TODO(b/447079781): Investigate whether we should check for contiguous + # strides here. An initial implementation of this failed the + # test_gmem_to_smem_with_multiple_smem_indexers_and_transforms test, but + # we should confirm that this is properly supported. + + # Registers + dest = ValueSite(op, VariableType.RESULT, 0) + dest_var = cs.Variable(dest) + value_sites_for_variable = {dest_var: [dest]} + constraints = [cs.NotOfType(dest_var, fa.WGSplatFragLayout)] + + # SMEM + if utils.is_smem_ref(op.source): + source = ValueSite(op, VariableType.OPERAND, 0) + source_var = ctx.producer_ref(source) + value_sites_for_variable[source_var] = [source] + shape = tuple(ir.MemRefType(op.source.type).shape) + constraints.append(cs.IsTransferable(source_var, dest_var, shape)) + + system = cs.ConstraintSystem(constraints=constraints) + return system, value_sites_for_variable + + +@_add_constraint_system_derivation_rule(mgpu.VectorStoreOp) +def _vector_store_constraint_system( + ctx: DerivationContext, + op: mgpu.VectorStoreOp, +) -> ConstraintSystemDerivationRuleResult: + # TODO(b/447079781): Investigate whether we should check for contiguous + # strides here. An initial implementaiton of this failed the + # test_gmem_to_smem_with_multiple_smem_indexers_and_transforms test, but + # we should confirm that this is properly supported. + + # Registers + value = ValueSite(op, VariableType.OPERAND, 0) + value_var = cs.Variable(value) + value_sites_for_variable = {value_var: [value]} + + # SMEM + constraints = [] + if utils.is_smem_ref(op.destination): + dest = ValueSite(op, VariableType.OPERAND, 1) + dest_var = ctx.producer_ref(dest) + value_sites_for_variable[dest_var] = [dest] + shape = tuple(ir.MemRefType(op.destination.type).shape) + constraints.append(cs.IsTransferable(value_var, dest_var, shape)) + + system = cs.ConstraintSystem(constraints=constraints) + return system, value_sites_for_variable + + +@_add_constraint_system_derivation_rule(mgpu.DebugPrintOp) +def _debug_print_constraint_system( + ctx: DerivationContext, + op: mgpu.DebugPrintOp, +) -> ConstraintSystemDerivationRuleResult: + del ctx + value = ValueSite(op, VariableType.OPERAND, 0) + return cs.ConstraintSystem(), {cs.Variable(value): [value]} + + +@_add_constraint_system_derivation_rule(mgpu.PrintLayoutOp) +def _print_layout_constraint_system( + ctx: DerivationContext, + op: mgpu.PrintLayoutOp, +) -> ConstraintSystemDerivationRuleResult: + value = ValueSite(op, VariableType.OPERAND, 0) + var = cs.Variable(value) if is_vector(op.value) else ctx.producer_ref(value) + return cs.ConstraintSystem(), {var: [value]} + + +@_add_constraint_system_derivation_rule(mgpu.BroadcastedIotaOp) +def _broadcasted_iota_constraint_system( + ctx: DerivationContext, + op: mgpu.BroadcastedIotaOp, +) -> ConstraintSystemDerivationRuleResult: + del ctx + value = ValueSite(op, VariableType.RESULT, 0) + var = cs.Variable(value) + constraints = [cs.NotOfType(var, fa.WGSplatFragLayout)] + return cs.ConstraintSystem(constraints=constraints), {var: [value]} + + +@_add_constraint_system_derivation_rule(mgpu.OptimizationBarrierOp) +def _optimization_barrier_constraint_system( + ctx: DerivationContext, + op: ir.OpView, +) -> ConstraintSystemDerivationRuleResult: + del ctx + value_sites_for_variable: ValueSitesForVariable = {} + for i, operand in enumerate(op.operands): + if not is_vector(operand): + continue + variable = cs.Variable(ValueSite(op, VariableType.OPERAND, i)) + value_sites_for_variable[variable] = [ + ValueSite(op, VariableType.OPERAND, i), + ValueSite(op, VariableType.RESULT, i) + ] -@partial(_add_layout_inference_rule, arith.ConstantOp) -def _infer_constant_op_layout(constant_op: arith.ConstantOp) -> OptionalLayouts: - if not ir.VectorType.isinstance(constant_op.result.type): - return None + return cs.ConstraintSystem(), value_sites_for_variable - shaped_ty = cast(ir.ShapedType, constant_op.result.type) + +@_add_constraint_system_derivation_rule(vector.BroadcastOp) +def _vector_splat_constraint_system( + ctx: DerivationContext, + op: ir.OpView, +) -> ConstraintSystemDerivationRuleResult: + del ctx + result = ValueSite(op, VariableType.RESULT, 0) + variable = cs.Variable(result) + layout = fa.WGSplatFragLayout(tuple(cast(ir.ShapedType, op.result.type).shape)) + system = cs.ConstraintSystem( + assignments={variable: cs.RegisterLayout(layout)} + ) + return system, {variable: [result]} + + +@_add_constraint_system_derivation_rule(arith.ConstantOp) +def _constant_constraint_system( + ctx: DerivationContext, + constant_op: arith.ConstantOp, +) -> ConstraintSystemDerivationRuleResult: + del ctx value = constant_op.value - layout = None + result = ValueSite(constant_op, VariableType.RESULT, 0) + variable = cs.Variable(result) + shape = tuple(ir.ShapedType(constant_op.result.type).shape) if ( - ir.DenseElementsAttr.isinstance(value) + isinstance(value, ir.DenseElementsAttr) and ir.DenseElementsAttr(value).is_splat ): - layout = layouts_lib.to_splat_fragmented_layout_attr( - fa.WGSplatFragLayout(shape=shaped_ty.shape) + layout = fa.WGSplatFragLayout(shape=shape) + system = cs.ConstraintSystem( + assignments={variable: cs.RegisterLayout(layout)} ) - # If the constant is not a splat, there is no obvious good choice of layout. - # We need to look at the consumers of the constant to find a layout that works - # for them. If there are several users with N different layouts, we can - # arbitrarily choose any one of them for the constant, since we expect - # whichever choice we make to lead to N-1 relayouts, which all have the same - # cost. - # - # We assign a strided layout if the constant has no user, for completeness. - elif constant_op.result.uses: - for use in cast(ir.OpResult, constant_op.result).uses: - consumer = use.owner - operand = consumer.operands[use.operand_number] - layout = inference_utils.in_layout_for_operand(consumer, operand) - if layout is not None: - break - - # If the constant is not a splat, has no user, or a layout could not be - # determined from looking at the users, we assign a strided layout for - # completeness. - if layout is None: - layout = layouts_lib.to_strided_fragmented_layout_attr( - fa.WGStridedFragLayout.from_shaped_type(shaped_ty) + else: + constant_is_not_splat = cs.NotOfType(variable, fa.WGSplatFragLayout) + system = cs.ConstraintSystem(constraints=[constant_is_not_splat]) + + return system, {variable: [result]} + + +def _terminator( + block: ir.Block, expected_terminator: type[ir.OpView] +) -> ir.OpView: + """Returns the terminator of the given block. + + Checks that the terminator is of the expected type. + """ + terminator = block.operations[len(block.operations) - 1] + assert isinstance(terminator, expected_terminator) + return terminator.opview + + +@_add_constraint_system_derivation_rule(scf.ForOp) +def _for_constraint_system( + ctx: DerivationContext, + op: scf.ForOp, +) -> ConstraintSystemDerivationRuleResult: + [block] = op.region.blocks + yield_op = _terminator(block, scf.YieldOp) + value_sites_for_variable: ValueSitesForVariable = {} + + # Account for the lower bound, upper bound, and step of the loop, which appear + # in the operands but not in the results. + num_leading_args = 3 + for index, o in enumerate(op.operands): + if not is_vector(o) and not _is_smem_ref(o): + continue + result_index = index - num_leading_args + arg_index = index - num_leading_args + 1 # Account for the induction var. + operand = ValueSite(op, VariableType.OPERAND, index) + arg = ValueSite(op, VariableType.ARGUMENT, arg_index, region_index=0) + result = ValueSite(op, VariableType.RESULT, result_index) + yield_operand = ValueSite( + yield_op, VariableType.OPERAND, result_index ) + var = cs.Variable(operand) if is_vector(o) else ctx.producer_ref(operand) + value_sites_for_variable[var] = [operand, arg, result, yield_operand] + + return cs.ConstraintSystem(), value_sites_for_variable - return [], [layout] +def prime_decomposition(n: int) -> list[int]: + """Returns the prime decomposition of the given number `n` as a list of ints. -@partial(_add_layout_inference_rule, scf.YieldOp) -def _infer_yield_op_layout(op: scf.YieldOp) -> OptionalLayouts: - layouts = [] - for result in op.results_: - if not ir.VectorType.isinstance(result.type): + A factor appears as many times in the list as the power up to which it divides + `n`. + """ + # This implementation should be sufficiently efficient for small `n`, which + # should always be the case for us. + prime_factors = [] + divisor = 2 + while divisor * divisor <= n: + while n % divisor == 0: + n //= divisor + prime_factors.append(divisor) + divisor += 1 + if n != 1: + prime_factors.append(n) + return prime_factors + + +# TODO(bchetioui): let's see if we need to parametrize this by depth. +def dynamic_gcd(a: int, b: ir.Value) -> int: + if a <= 0: + raise ValueError("a must be strictly positive") + if isinstance(b.type, ir.VectorType): + # We don't actually know the values of the vector elements, so we pick 1 + # as the only safe value. + return 1 + if not isinstance(b.type, ir.IntegerType) and not isinstance( + b.type, ir.IndexType + ): + raise ValueError(f"Expected an integer dynamic value, got a {b.type}") + if isinstance(b.owner, arith.ConstantOp): + return math.gcd(a, b.owner.literal_value) + running_gcd = 1 + for factor in prime_decomposition(a): + if utils.is_known_divisible(b, running_gcd * factor): + running_gcd *= factor + return running_gcd + + +@_add_constraint_system_derivation_rule(scf.WhileOp) +def _while_constraint_system( + ctx: DerivationContext, + op: scf.WhileOp, +) -> ConstraintSystemDerivationRuleResult: + del ctx + [before_block] = op.before.blocks + [after_block] = op.after.blocks + cond_op = _terminator(before_block, scf.ConditionOp) + yield_op = _terminator(after_block, scf.YieldOp) + + value_sites_for_variable: ValueSitesForVariable = {} + + for value_site in vector_value_sites(op): + idx = value_site.index + match value_site.type: + case VariableType.OPERAND: + arg = ValueSite(op, VariableType.ARGUMENT, idx, region_index=0) + yield_operand = ValueSite(yield_op, VariableType.OPERAND, idx) + value_sites_for_variable[cs.Variable(value_site)] = [ + value_site, + arg, + yield_operand, + ] + case VariableType.RESULT: + # Increment by 1 to account for the conditional. + cond_operand = ValueSite(cond_op, VariableType.OPERAND, idx + 1) + arg = ValueSite(op, VariableType.ARGUMENT, idx, region_index=1) + value_sites_for_variable[cs.Variable(value_site)] = [ + value_site, + arg, + cond_operand, + ] + case _ as never: + assert_never(never) # pytype: disable=wrong-arg-types + + return cs.ConstraintSystem(), value_sites_for_variable + + +@_add_constraint_system_derivation_rule(scf.IndexSwitchOp) +def _index_switch_constraint_system( + ctx: DerivationContext, + op: scf.IndexSwitchOp, +) -> ConstraintSystemDerivationRuleResult: + del ctx + value_sites_for_variable: ValueSitesForVariable = { + cs.Variable(o): [o] for o in vector_value_sites(op) + } + for region in op.regions: + [block] = region.blocks + yield_op = _terminator(block, scf.YieldOp) + for value_site in value_sites_for_variable.keys(): + assert value_site.key.type == VariableType.RESULT + yield_operand = ValueSite( + yield_op, VariableType.OPERAND, value_site.key.index + ) + value_sites_for_variable[value_site].append(yield_operand) + + return cs.ConstraintSystem(), value_sites_for_variable + + +@_add_constraint_system_derivation_rule(mgpu.LayoutCastOp) +def _layout_cast_constraint_system( + ctx: DerivationContext, + op: mgpu.LayoutCastOp, +) -> ConstraintSystemDerivationRuleResult: + del ctx + operand = ValueSite(op, VariableType.OPERAND, 0) + result = ValueSite(op, VariableType.RESULT, 0) + variable = cs.Variable(operand) + out_layout = layouts_lib.from_layout_attr(op.new_layout) + # TODO(bchetioui): think about raising a better error here. + if not is_valid_register_layout_assignment(operand.shape, out_layout): + return cs.Unsatisfiable() + return ( + cs.ConstraintSystem( + assignments={variable: cs.RegisterLayout(out_layout)} + ), + {variable: [operand, result]}, + ) + + +def _infer_tiling_for_mma_ref( + ref_ty: ir.MemRefType, max_swizzle: mgpu.SwizzlingMode +) -> tuple[int, int]: + element_bytewidth = utils.bytewidth(ref_ty.element_type) + strides, _ = ref_ty.get_strides_and_offset() + min_dim_index = np.argmin(strides) + minor_dim = ref_ty.shape[min_dim_index] + + # Try tiling with all swizzling modes starting from the largest one. + for swizzle in [ + mgpu.SwizzlingMode.k128ByteSwizzle, + mgpu.SwizzlingMode.k64ByteSwizzle, + mgpu.SwizzlingMode.k32ByteSwizzle, + mgpu.SwizzlingMode.kNoSwizzle, + ]: + if swizzle > max_swizzle: continue - if (layout := inference_utils.value_layout(result)) is not None: - if layouts_lib.is_splat_fragmented_layout(layout): - return None - layouts.append(layout) + swizzle_elems = swizzle // element_bytewidth + if minor_dim % swizzle_elems == 0: + minor_tiling = swizzle_elems + break + else: + # No valid tile transform can be inferred. + raise ValueError(f"{ref_ty.shape} is not a valid WGMMA shape") + + major_tiling = 8 + transposed = min_dim_index != len(strides) - 1 + if transposed: + tiling = (minor_tiling, major_tiling) + else: + tiling = (major_tiling, minor_tiling) + return tiling + + +def _infer_wgmma_tiling( + a_type: ir.Type, b_type: ir.MemRefType +) -> tuple[tuple[int, int] | None, tuple[int, int]]: + """Infers the tiling for a (if in SMEM) and b of a WGMMAOp. + + If both a and b are in SMEM, this function infers tilings that have matching + swizzle values. + """ + b_tiling = _infer_tiling_for_mma_ref( + b_type, max_swizzle=mgpu.SwizzlingMode.k128ByteSwizzle + ) + b_swizzle = _compute_swizzle(b_type, lc.TileTransform(b_tiling)) + if not isinstance(a_type, ir.MemRefType): + return None, b_tiling + + a_tiling = _infer_tiling_for_mma_ref( + cast(ir.MemRefType, a_type), max_swizzle=b_swizzle + ) + a_swizzle = _compute_swizzle(a_type, lc.TileTransform(a_tiling)) + if a_swizzle != b_swizzle: + # The swizzle for a and b has to match. This is not a fundamental + # limitation, rather the lowering doesn't currently support it. + b_tiling = _infer_tiling_for_mma_ref(b_type, max_swizzle=a_swizzle) + b_swizzle = _compute_swizzle(b_type, lc.TileTransform(b_tiling)) + assert a_swizzle == b_swizzle + return a_tiling, b_tiling + + +@_add_constraint_system_derivation_rule(mgpu.WGMMAOp) +def _wgmma_constraint_system( + ctx: DerivationContext, + op: mgpu.WGMMAOp, +) -> ConstraintSystemDerivationRuleResult: + assignments: dict[cs.Variable, cs.Constant] = {} + value_sites_for_variable: ValueSitesForVariable = {} + + acc_out = ValueSite(op, VariableType.RESULT, 0) + acc_in = ValueSite(op, VariableType.OPERAND, 0) + acc_var = cs.Variable(acc_out) + acc_layout = fa.WGMMA_LAYOUT + assignments[acc_var] = cs.RegisterLayout(acc_layout) + acc_is_valid = is_valid_register_layout_assignment(acc_out.shape, acc_layout) + value_sites_for_variable[acc_var] = [acc_in, acc_out] + + a_tiling, b_tiling = _infer_wgmma_tiling(op.a.type, op.b.type) + b = ValueSite(op, VariableType.OPERAND, 2) + b_var = ctx.producer_ref(b) + b_tile_transform = lc.TileTransform(b_tiling) + b_is_valid = is_valid_smem_layout_assignment(b.shape, b_tile_transform) + assignments[b_var] = cs.SMEMTiling(b_tile_transform) + value_sites_for_variable[b_var] = [b] + + a = ValueSite(op, VariableType.OPERAND, 1) + if _is_smem_ref(op.a): + a_var = ctx.producer_ref(a) + a_tile_transform = lc.TileTransform(a_tiling) + assignments[a_var] = cs.SMEMTiling(a_tile_transform) + a_is_valid = is_valid_smem_layout_assignment(a.shape, a_tile_transform) + else: + assert a_tiling is None + a_var = cs.Variable(a) + if utils.bitwidth(op.a.type.element_type) == 8: + layout = fa.WGMMA_LAYOUT_8BIT else: - # Not all layouts could be inferred for vector ops. Return for now. - return None + layout = fa.WGMMA_LAYOUT + assignments[a_var] = cs.RegisterLayout(layout) + a_is_valid = is_valid_register_layout_assignment(a.shape, layout) + + value_sites_for_variable[a_var] = [a] + + # TODO(bchetioui): think about raising a better error here. + if not a_is_valid or not b_is_valid or not acc_is_valid: + return cs.Unsatisfiable() + return cs.ConstraintSystem(assignments), value_sites_for_variable + + +@_add_constraint_system_derivation_rule(vector.BroadcastOp) +def _vector_broadcast_constraint_system( + ctx: DerivationContext, + op: vector.BroadcastOp, +) -> ConstraintSystemDerivationRuleResult: + del ctx + # This is not expected to be necessary at the moment. We should be using + # mgpu.BroadcastInDimOp instead when dealing with broadcasting vectors. + if isinstance(op.source.type, ir.ShapedType): + raise NotImplementedError("Only vector broadcasts from scalars are supported.") + out_variable = cs.Variable(ValueSite(op, VariableType.RESULT, 0)) + layout = cs.RegisterLayout(fa.WGSplatFragLayout(tuple(op.result.type.shape))) + return ( + cs.ConstraintSystem(assignments={out_variable: layout}), + {out_variable: [out_variable.key]}, + ) - return (layouts, []) +@_add_constraint_system_derivation_rule(vector.ReductionOp) +def _vector_reduction_constraint_system( + ctx: DerivationContext, + op: vector.ReductionOp, +) -> ConstraintSystemDerivationRuleResult: + del ctx + in_variable = cs.Variable(ValueSite(op, VariableType.OPERAND, 0)) + return cs.ConstraintSystem(), {in_variable: [in_variable.key]} + + +def _reduction_constraints( + larger: cs.Variable, + smaller: cs.Variable, + reduction_dims: tuple[int, ...], +) -> list[cs.Constraint]: + return [ + cs.Equals(lhs=smaller, rhs=cs.Reduce(larger, reduction_dims)), + # TODO(allanrenucci): Remove once we support reduction of strided layouts. + cs.NotOfType(larger, fa.WGStridedFragLayout), + ] -@partial(_add_layout_inference_rule, scf.ForOp) -def _infer_for_op_layout(op: scf.ForOp) -> OptionalLayouts: - yield_op = op.body.operations[len(op.body.operations) - 1] - assert isinstance(yield_op, scf.YieldOp) - if inference_utils.has_in_layouts_set(yield_op): - yield_layouts = list(inference_utils.in_layouts(yield_op)) - if any( - layouts_lib.is_splat_fragmented_layout(layout) - for layout in yield_layouts - ): - return None - return (yield_layouts, yield_layouts) +@_add_constraint_system_derivation_rule(vector.MultiDimReductionOp) +def _multi_dim_reduction_constraint_system( + ctx: DerivationContext, + op: vector.MultiDimReductionOp, +) -> ConstraintSystemDerivationRuleResult: + del ctx + source = ValueSite(op, VariableType.OPERAND, 0) + acc = ValueSite(op, VariableType.OPERAND, 1) + out = ValueSite(op, VariableType.RESULT, 0) + source_variable = cs.Variable(source) + out_variable = cs.Variable(out) + + reduction_constraints = _reduction_constraints( + source_variable, + out_variable, + tuple(op.reduction_dims), + ) + # TODO(bchetioui): in the future, we may need to add rules that prevent + # strided layouts from being chosen---since trying to reduce a strided layout + # may cause us to raise an Exception at the moment. + return ( + cs.ConstraintSystem(constraints=reduction_constraints), + {source_variable: [source], out_variable: [acc, out]}, + ) - # TODO(bchetioui): we don't attempt to propagate from outside for the moment. - # For the existing kernels, propagating from the YieldOp should be enough. - return None +@_add_constraint_system_derivation_rule(mgpu.BroadcastInDimOp) +def _broadcast_in_dim_constraint_system( + ctx: DerivationContext, + op: mgpu.BroadcastInDimOp, +) -> ConstraintSystemDerivationRuleResult: + del ctx + out_variable = cs.Variable(ValueSite(op, VariableType.RESULT, 0)) + source_variable = cs.Variable(ValueSite(op, VariableType.OPERAND, 0)) + out_shape = tuple(cast(ir.ShapedType, op.result.type).shape) + reduction_dims = tuple( + i for i in range(len(out_shape)) if i not in op.broadcast_dimensions + ) + reduction_constraints = _reduction_constraints( + out_variable, source_variable, reduction_dims + ) + + return ( + cs.ConstraintSystem(constraints=reduction_constraints), + { + source_variable: [source_variable.key], + out_variable: [out_variable.key], + }, + ) + +@_add_constraint_system_derivation_rule(vector.ShapeCastOp) +def _shape_cast_constraint_system( + ctx: DerivationContext, op: vector.ShapeCastOp +) -> ConstraintSystemDerivationRuleResult: + del ctx + in_shape = tuple(cast(ir.ShapedType, op.source.type).shape) + out_shape = tuple(cast(ir.ShapedType, op.result.type).shape) -@partial(_add_layout_inference_rule, vector.SplatOp) -def _infer_splat_op_layout(splat_op: vector.SplatOp) -> OptionalLayouts: - layout = layouts_lib.to_splat_fragmented_layout_attr( - fa.WGSplatFragLayout( - shape=cast(ir.ShapedType, splat_op.result.type).shape + in_variable = cs.Variable(ValueSite(op, VariableType.OPERAND, 0)) + out_variable = cs.Variable(ValueSite(op, VariableType.RESULT, 0)) + + # Here, we are in a case where we are stating + # + # out_variable = reshape(in_variable, in_shape, out_shape). + # + # Thanks to the symmetric property of reshape, we can also issue a constraint + # in the other direction, i.e. + # + # in_variable = reshape(out_variable, out_shape, in_shape) + # + # in order to be able to figure out an assignment for `in_variable`. if we + # happen to know `out_variable`. If we only issue the first constraint, then + # we will not be able to figure out an assignment for `in_variable` if we + # only know `out_variable`, even though their relationship is fully + # determined. + in_to_out = cs.Reshape( + in_variable, source_shape=in_shape, target_shape=out_shape + ) + out_to_in = cs.Reshape( + out_variable, source_shape=out_shape, target_shape=in_shape + ) + + return ( + cs.ConstraintSystem( + constraints=[ + cs.Equals(lhs=out_variable, rhs=in_to_out), + cs.Equals(lhs=in_variable, rhs=out_to_in), + ], + ), + {in_variable: [in_variable.key], out_variable: [out_variable.key]}, + ) + + +@_add_constraint_system_derivation_rule(vector.ExtractStridedSliceOp) +def _extract_strided_slice_constraint_system( + ctx: DerivationContext, op: vector.ExtractStridedSliceOp +) -> ConstraintSystemDerivationRuleResult: + del ctx + if any(ir.IntegerAttr(s).value != 1 for s in op.strides): + raise NotImplementedError("`strides` must contain only 1s.") + operand = ValueSite(op, VariableType.OPERAND, 0) + result = ValueSite(op, VariableType.RESULT, 0) + variable = cs.Variable(operand) + offsets = tuple(ir.IntegerAttr(o).value for o in op.offsets) + constraints = [ + cs.Divides(variable, offsets), + # TODO(allanrenucci): Remove once vectors with splat and strided layouts + # can be sliced. + cs.NotOfType(variable, fa.WGSplatFragLayout), + cs.NotOfType(variable, fa.WGStridedFragLayout), + ] + return ( + cs.ConstraintSystem(constraints=constraints), + # We use a single variable because lowering does not support two different + # layouts for `source` and `result`. + {variable: [operand, result]}, + ) + + +@_add_constraint_system_derivation_rule(vector.ExtractOp) +def _vector_extract_constraint_system( + ctx: DerivationContext, op: vector.ExtractOp +) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: + del ctx + if not isinstance(op.result.type, ir.VectorType): # scalar result + operand = ValueSite(op, VariableType.OPERAND, 0) + variable = cs.Variable(operand) + layout = fa.WGSplatFragLayout(tuple(op.source.type.shape)) + # We only support indexing for splat layout. + assignments = {variable: cs.RegisterLayout(layout)} + return cs.ConstraintSystem(assignments), {variable: [operand]} + + if op.dynamic_position: + raise NotImplementedError("Only slicing with static indices allowed.") + operand = ValueSite(op, VariableType.OPERAND, 0) + result = ValueSite(op, VariableType.RESULT, 0) + variable = cs.Variable(operand) + constraints = [ + cs.Divides(variable, tuple(op.result.type.shape)), + # TODO(allanrenucci): Remove once vectors with splat and strided layouts + # can be sliced. + cs.NotOfType(variable, fa.WGSplatFragLayout), + cs.NotOfType(variable, fa.WGStridedFragLayout), + ] + return ( + cs.ConstraintSystem(constraints=constraints), + {variable: [operand, result]}, + ) + + +@_add_constraint_system_derivation_rule(mgpu.CustomPrimitiveOp) +def _custom_primitive_constraint_system( + ctx: DerivationContext, + op: mgpu.CustomPrimitiveOp, +) -> ConstraintSystemDerivationRuleResult: + assignments: dict[cs.Variable, cs.Constant] = {} + constraints: list[cs.Constraint] = [] + in_layouts = iter(op.in_layouts) + in_transforms = iter(op.in_transforms) + variables: list[cs.Variable] = [] + for i, operand in enumerate(op.operands): + if is_vector(operand): + v = cs.Variable(ValueSite(op, VariableType.OPERAND, i)) + variables.append(v) + assignments[v] = cs.RegisterLayout( + layouts_lib.from_layout_attr(next(in_layouts)) + ) + elif _is_smem_ref(operand): + # Here we need to create a new variable, even though it is equal to the + # source operand. This is because we directly assign the new variable and + # if we did that to the source there could be conflicting assignments. + # For example, the same ref could be passed into the custom op twice with + # different transforms, which needs to yield an unsatisfiable system. + # + # TODO(b/447079781): Consider creating the final constraint system using + # __and__ and potentially returning Unsatisfiable() directly if there is + # a conflict between the assignments. + value_site = ValueSite(op, VariableType.OPERAND, i) + source_var = ctx.producer_ref(value_site) + v = cs.Variable(value_site) + constraints.append(cs.Equals(lhs=source_var, rhs=v)) + variables.append(v) + transforms = next(in_transforms) + ref_ty = value_site.value.type + tiling = _extract_smem_tiling_from_custom_transform_attrs(ref_ty, transforms) + assignments[v] = tiling + + out_layouts = iter(op.out_layouts) + for i, result in enumerate(op.results): + if isinstance(result.type, ir.VectorType): + v = cs.Variable(ValueSite(op, VariableType.RESULT, i)) + variables.append(v) + assignments[v] = cs.RegisterLayout( + layouts_lib.from_layout_attr(next(out_layouts)) ) + return ( + cs.ConstraintSystem(assignments, constraints), + {v: [v.key] for v in variables}, ) - return [], [layout] +def _tmem_layout_from_layout_attr( + layout_attr: mgpu.TiledLayout, +) -> tcgen05.TMEMLayout: + layout = layouts_lib.from_layout_attr(layout_attr) + assert isinstance(layout, fa.TiledLayout) + return tcgen05.TMEMLayout( + layout.tiling, layout.warp_dims, layout.lane_dims, layout.vector_dim + ) + + +@_add_constraint_system_derivation_rule(mgpu.TmemLayoutCastOp) +def _tmem_layout_cast_constraint_system( + ctx: DerivationContext, + op: mgpu.TmemLayoutCastOp, +) -> ConstraintSystemDerivationRuleResult: + operand = ValueSite(op, VariableType.OPERAND, 0) + variable = ctx.producer_ref(operand) + result = ValueSite(op, VariableType.RESULT, 0) + tmem_layout = _tmem_layout_from_layout_attr(op.new_layout) + if not is_valid_tmem_layout_assignment(operand.shape, tmem_layout): + return cs.Unsatisfiable() + out_layout = cs.TMEMLayout(tmem_layout) + return ( + cs.ConstraintSystem(assignments={variable: out_layout}), + {variable: [operand, result]}, + ) -def _update_layout_shape( - layout: ir.Attribute, shape: Sequence[int], origin: str -) -> ir.Attribute: - if layouts_lib.is_splat_fragmented_layout( - layout - ) or layouts_lib.is_strided_fragmented_layout(layout): - return layouts_lib.to_layout_attr( - dataclasses.replace(layouts_lib.from_layout_attr(layout), shape=shape) + +@_add_constraint_system_derivation_rule(mgpu.TmemAllocOp) +def _tmem_alloc_constraint_system( + ctx: DerivationContext, + op: mgpu.TmemAllocOp, +) -> ConstraintSystemDerivationRuleResult: + del ctx + result = ValueSite(op, VariableType.RESULT, 0) + result_var = cs.Variable(result) + in_smem = ValueSite(op, VariableType.OPERAND, 0) + in_smem_var = cs.Variable(in_smem) + assignments: dict[cs.Variable, cs.Constant] = { + in_smem_var: cs.SMEMTiling(None) + } + operands_for_variable = {result_var: [result], in_smem_var: [in_smem]} + return cs.ConstraintSystem(assignments=assignments), operands_for_variable + + +@_add_constraint_system_derivation_rule(mgpu.TmemDeallocOp) +def _tmem_dealloc_constraint_system( + ctx: DerivationContext, + op: mgpu.TmemDeallocOp, +) -> ConstraintSystemDerivationRuleResult: + operand = ValueSite(op, VariableType.OPERAND, 0) + variable = ctx.producer_ref(operand) + return cs.ConstraintSystem(), {variable: [operand]} + + +@_add_constraint_system_derivation_rule(mgpu.TcGen05MMAOp) +def _tcgen05_mma_constraint_system( + ctx: DerivationContext, + op: mgpu.TcGen05MMAOp, +) -> ConstraintSystemDerivationRuleResult: + assignments: dict[cs.Variable, cs.Constant] = {} + operands_for_variable: ValueSitesForVariable = {} + + # TMEM + acc = ValueSite(op, VariableType.OPERAND, 0) + acc_variable = ctx.producer_ref(acc) + acc_type = ir.ShapedType(op.accumulator.type) + acc_layout = tcgen05._infer_tmem_layout( + tuple(acc_type.shape), op.collective, packing=1 + ) + assignments[acc_variable] = cs.TMEMLayout(acc_layout) + acc_is_valid = is_valid_tmem_layout_assignment(acc.shape, acc_layout) + operands_for_variable[acc_variable] = [acc] + + if _is_tmem_ref(op.a): + a = ValueSite(op, VariableType.OPERAND, 1) + a_type = ir.ShapedType(op.a.type) + a_var = ctx.producer_ref(a) + packing = 32 // utils.bitwidth(a_type.element_type) + a_layout = tcgen05._infer_tmem_layout( + tuple(a_type.shape), op.collective, packing + ) + assignments[a_var] = cs.TMEMLayout(a_layout) + operands_for_variable[a_var] = [a] + a_is_valid = is_valid_tmem_layout_assignment(a.shape, a_layout) + else: + assert _is_smem_ref(op.a) + a_tiling = _infer_tiling_for_mma_ref( + ir.MemRefType(op.a.type), + max_swizzle=mgpu.SwizzlingMode.k128ByteSwizzle, ) - raise NotImplementedError(f"Unsupported {origin} layout: {layout}.") - - -@partial(_add_layout_inference_rule, vector.ShapeCastOp) -def _infer_shape_cast_op_layout(op: vector.ShapeCastOp) -> OptionalLayouts: - in_layout = inference_utils.value_layout(op.source) - if in_layout is None: - out_layout = inference_utils.value_layout(op.result) - if out_layout is None: - return None - in_layout = _update_layout_shape( - out_layout, ir.VectorType(op.source.type).shape, "source" + a = ValueSite(op, VariableType.OPERAND, 1) + a_var = ctx.producer_ref(a) + a_tile_transform = lc.TileTransform(a_tiling) + assignments[a_var] = cs.SMEMTiling(a_tile_transform) + operands_for_variable[a_var] = [a] + a_is_valid = is_valid_smem_layout_assignment(a.shape, a_tile_transform) + + # SMEM + M = op.accumulator.type.shape[0] + if M == 64 and not op.collective.value: + # We can't split N into groups if we would partition it below the tile size. + N = op.b.type.shape[1] + element_type_bitwidth = utils.bitwidth(op.b.type.element_type) + n_lane_groups = 2 + max_b_swizzle = next( + s + for s in reversed(mgpu.SwizzlingMode) + if 8 * s // element_type_bitwidth <= N // n_lane_groups ) - return [in_layout], [out_layout] + else: + max_b_swizzle = mgpu.SwizzlingMode.k128ByteSwizzle + + b_tiling = _infer_tiling_for_mma_ref(ir.MemRefType(op.b.type), max_b_swizzle) + b = ValueSite(op, VariableType.OPERAND, 2) + b_var = ctx.producer_ref(b) + b_tile_transform = lc.TileTransform(b_tiling) + assignments[b_var] = cs.SMEMTiling(b_tile_transform) + operands_for_variable[b_var] = [b] + b_is_valid = is_valid_smem_layout_assignment(b.shape, b_tile_transform) + + # TODO(bchetioui): think about raising a better error here. + if not a_is_valid or not b_is_valid or not acc_is_valid: + return cs.Unsatisfiable() + + return cs.ConstraintSystem(assignments=assignments), operands_for_variable + + +@_add_constraint_system_derivation_rule(mgpu.AsyncLoadTmemOp) +def _async_load_tmem_constraint_system( + ctx: DerivationContext, + op: mgpu.AsyncLoadTmemOp, +) -> ConstraintSystemDerivationRuleResult: + source = ValueSite(op, VariableType.OPERAND, 0) + source_variable = ctx.producer_ref(source) + destination = ValueSite(op, VariableType.RESULT, 0) + destination_variable = cs.Variable(destination) + constraint = cs.IsTransferable( + source_variable, + destination_variable, + tuple(ir.ShapedType(op.source.type).shape), + ) + return ( + cs.ConstraintSystem(constraints=[constraint]), + {source_variable: [source], destination_variable: [destination]}, + ) + - out_layout = _update_layout_shape( - in_layout, ir.VectorType(op.result.type).shape, "result" +@_add_constraint_system_derivation_rule(mgpu.SliceTmemOp) +def _slice_tmem_constraint_system( + ctx: DerivationContext, + op: mgpu.SliceTmemOp, +) -> ConstraintSystemDerivationRuleResult: + operand = ValueSite(op, VariableType.OPERAND, 0) + operand_variable = ctx.producer_ref(operand) + result = ValueSite(op, VariableType.RESULT, 0) + result_variable = cs.Variable(result) + return ( + cs.ConstraintSystem(), + {operand_variable: [operand], result_variable: [result]}, ) - return [in_layout], [out_layout] -@partial(_add_layout_inference_rule, vector.ReductionOp) -def _infer_reduction_op_layout(op: vector.ReductionOp) -> OptionalLayouts: - if layout := inference_utils.value_layout(op.vector): - return [layout], [] - return None +@_add_constraint_system_derivation_rule(mgpu.AsyncStoreTmemOp) +def _async_store_tmem_constraint_system( + ctx: DerivationContext, + op: mgpu.AsyncStoreTmemOp, +) -> ConstraintSystemDerivationRuleResult: + source = ValueSite(op, VariableType.OPERAND, 0) + source_variable = cs.Variable(source) + destination = ValueSite(op, VariableType.OPERAND, 1) + destination_variable = ctx.producer_ref(destination) + constraint = cs.IsTransferable( + source_variable, + destination_variable, + tuple(ir.ShapedType(op.source.type).shape), + ) + return ( + cs.ConstraintSystem(constraints=[constraint]), + {source_variable: [source], destination_variable: [destination]}, + ) + + +@_add_constraint_system_derivation_rule(mgpu.SliceSMEMOp) +def _slice_smem_constraint_system( + ctx: DerivationContext, + op: mgpu.SliceSMEMOp, +) -> ConstraintSystemDerivationRuleResult: + del ctx + res = ValueSite(op, VariableType.RESULT, 0) + res_var = cs.Variable(res) + return cs.ConstraintSystem(), {res_var: [res]} + + +@_add_constraint_system_derivation_rule(memref.SubViewOp) +def _memref_subview_constraint_system( + ctx: DerivationContext, + op: memref.SubViewOp, +) -> ConstraintSystemDerivationRuleResult: + source = ValueSite(op, VariableType.OPERAND, 0) + dest = ValueSite(op, VariableType.RESULT, 0) + source_dest_var = ctx.producer_ref(source) + + if any(s != 1 for s in op.static_strides): + raise NotImplementedError( + f"Only unit strides are supported but got {op.static_strides}." + ) + + # Collect all the constraints from all dimensions. + tiling_multiple = [] + dynamic_offset_index = 0 + for i, size in enumerate(op.static_sizes): + offset = op.static_offsets[i] + if offset == ir.ShapedType.get_dynamic_size(): + offset = op.offsets[dynamic_offset_index] + dynamic_offset_index += 1 + + # Drop all dimensions up to and including the last dynamic size. Dynamic + # sizes are not supported yet. + # + # Supporting dynamic sizes here can be done analogously to how dynamic + # offsets are supported. The reason we don't support dynamic sizes now is + # because the lowering does not yet support them. + if ir.ShapedType.is_dynamic_size(size): + tiling_multiple = [] + else: + src_type = ir.MemRefType(op.source.type) + divisibility_constraint = math.gcd(size, src_type.shape[i]) + if isinstance(offset, int): + divisibility_constraint = math.gcd(divisibility_constraint, offset) + else: + divisibility_constraint = dynamic_gcd(divisibility_constraint, offset) + tiling_multiple.append(divisibility_constraint) + + constraints = [cs.Divides(source_dest_var, tuple(tiling_multiple))] + system = cs.ConstraintSystem(constraints=constraints) + return system, {source_dest_var: [source, dest]} + + +@_add_constraint_system_derivation_rule(memref.CastOp) +def _memref_cast_op_constraint_system( + ctx: DerivationContext, + op: memref.CastOp, +) -> ConstraintSystemDerivationRuleResult: + source = ValueSite(op, VariableType.OPERAND, 0) + var_source_dest = ctx.producer_ref(source) + dest = ValueSite(op, VariableType.RESULT, 0) + return cs.ConstraintSystem(), {var_source_dest: [source, dest]} + + +@_add_constraint_system_derivation_rule(memref.TransposeOp) +def _memref_transpose_op_constraint_system( + ctx: DerivationContext, + op: memref.TransposeOp, +) -> ConstraintSystemDerivationRuleResult: + in_ty = ir.MemRefType(op.in_.type) + in_strides, _ = in_ty.get_strides_and_offset() + out_strides, _ = ir.MemRefType(op.result.type).get_strides_and_offset() + transpose = in_strides != out_strides + + source = ValueSite(op, VariableType.OPERAND, 0) + dest = ValueSite(op, VariableType.RESULT, 0) + source_var = ctx.producer_ref(source) + + if not transpose: + return cs.ConstraintSystem(), {source_var: [source, dest]} + + dest_var = cs.Variable(dest) + constraints = [ + cs.Equals(cs.Transpose(source_var), dest_var), + cs.Equals(source_var, cs.Transpose(dest_var)), + ] + system = cs.ConstraintSystem(constraints=constraints) + return system, {source_var: [source], dest_var: [dest]} + + +@_add_constraint_system_derivation_rule(memref.ExpandShapeOp) +def _memref_expand_shape_op_equation_system( + ctx: DerivationContext, + op: memref.ExpandShapeOp, +) -> ConstraintSystemDerivationRuleResult: + if utils.is_memref_transposed(ir.MemRefType(op.src.type)): + raise NotImplementedError( + "Transposed memrefs are not supported in ExpandShapeOp." + ) + + source = ValueSite(op, VariableType.OPERAND, 0) + dest = ValueSite(op, VariableType.RESULT, 0) + var = ctx.producer_ref(source) + + reverse_tiling_multiple = [] + for dim, idx in zip( + reversed(op.static_output_shape), reversed(op.reassociation) + ): + if ir.ShapedType.is_dynamic_size(dim) or len(idx) > 1: + # For simplicity, we only support tiling non-expanded static dimensions. + # These limitations could be lifted later if needed. + break + reverse_tiling_multiple.append(dim) + + constraints = [cs.Divides(var, tuple(reversed(reverse_tiling_multiple)))] + return cs.ConstraintSystem(constraints=constraints), {var: [source, dest]} + + +# `memref.load` and `memref.store` are used to load barrier phases which are +# scalars---the rule needn't do anything interesting, but we need to have it. +@_add_constraint_system_derivation_rule(memref.LoadOp) +@_add_constraint_system_derivation_rule(memref.StoreOp) +def _memref_load_store_op_constraint_system( + ctx: DerivationContext, + op: memref.LoadOp | memref.StoreOp, +) -> ConstraintSystemDerivationRuleResult: + del ctx + + ref_shape = ir.MemRefType(op.memref.type).shape + if ref_shape and ref_shape != [1]: + raise NotImplementedError( + f"Only scalar memrefs are supported, got {ref_shape}" + ) + + ref_op_index = 0 if isinstance(op, memref.LoadOp) else 1 + ref = ValueSite(op, VariableType.OPERAND, ref_op_index) + var = cs.Variable(ref) + assignments: dict[cs.Variable, cs.Constant] = {var: cs.SMEMTiling(None)} + return cs.ConstraintSystem(assignments=assignments), {var: [ref]} + + +def _extract_smem_tiling_from_custom_transform_attrs( + ref_type: ir.MemRefType, + transform_attrs: ir.ArrayAttr, +) -> cs.SMEMTiling: + transforms = [layouts_lib.from_transform_attr(x) for x in transform_attrs] + match transforms: + case []: + tile_transform = None + swizzle = None + case [lc.TileTransform() as t]: + tile_transform = t + swizzle = None + case [lc.TileTransform() as t, mgpu.SwizzlingMode() as s]: + tile_transform = t + swizzle = s + case _: + raise NotImplementedError(f"Unsupported transforms {transforms}") + + if swizzle is not None: + computed_swizzle = _compute_swizzle(ref_type, tile_transform) + if computed_swizzle != swizzle: + raise NotImplementedError( + f"Cannot honor caller-provided swizzle {swizzle} that is different " + f"from the computed swizle {computed_swizzle} for type {ref_type}." + ) + + return cs.SMEMTiling(tile_transform) + + +@_add_constraint_system_derivation_rule(mgpu.WithTransformsOp) +def _with_transforms_constraint_system( + ctx: DerivationContext, + op: mgpu.WithTransformsOp, +) -> ConstraintSystemDerivationRuleResult: + source = ValueSite(op, VariableType.OPERAND, 0) + dest = ValueSite(op, VariableType.RESULT, 0) + var = ctx.producer_ref(source) + tiling = _extract_smem_tiling_from_custom_transform_attrs(op.ref.type, op.transforms) + if tiling.value is not None: + # TODO(bchetioui): think about raising a better error here. + if not is_valid_smem_layout_assignment(source.shape, tiling.value): + return cs.Unsatisfiable() + assignments: dict[cs.Variable, cs.Constant] = {var: tiling} + return cs.ConstraintSystem(assignments=assignments), {var: [source, dest]} + + +def _vector_value_sites_and_assignments_for_async_ops( + op: mgpu.AsyncLoadOp | mgpu.AsyncStoreOp | mgpu.AsyncPrefetchOp, +) -> tuple[ValueSitesForVariable, dict[cs.Variable, cs.Constant]]: + values_sites: ValueSitesForVariable = dict() + assignments: dict[cs.Variable, cs.Constant] = dict() + + match op: + case mgpu.AsyncLoadOp(): + base_operand_index = 3 + case mgpu.AsyncStoreOp(): + base_operand_index = 2 + case mgpu.AsyncPrefetchOp(): + base_operand_index = 1 + case _: + raise ValueError(f"Unsupported op type: {op}") # make pytype happy + + for i, idx in enumerate(op.indices): + if isinstance(idx.type, ir.VectorType): + value_site = ValueSite(op, VariableType.OPERAND, base_operand_index + i) + value_site_var = cs.Variable(value_site) + layout = cs.RegisterLayout(value=fa.TMA_GATHER_INDICES_LAYOUT) + values_sites[value_site_var] = [value_site] + assignments[value_site_var] = layout + return values_sites, assignments + + +@_add_constraint_system_derivation_rule(mgpu.AsyncLoadOp) +@_add_constraint_system_derivation_rule(mgpu.AsyncStoreOp) +def _async_load_store_constraint_system( + ctx: DerivationContext, + op: mgpu.AsyncLoadOp | mgpu.AsyncStoreOp, +) -> ConstraintSystemDerivationRuleResult: + tiling_multiple = [] + for size, index in zip(op.slice_lengths, op.indices, strict=True): + if size == -1: + # This dimension does not appear in the final smem memref shape. + continue + tiling_multiple.append(dynamic_gcd(size, index)) + + operand_index = 1 if isinstance(op, mgpu.AsyncLoadOp) else 0 + operand = ValueSite(op, VariableType.OPERAND, operand_index) + var = ctx.producer_ref(operand) + constraints = [cs.Divides(expr=var, tiling_multiple=tuple(tiling_multiple))] + value_sites_for_variable = {var: [operand]} + value_sites, assignments = _vector_value_sites_and_assignments_for_async_ops(op) + value_sites_for_variable.update(value_sites) + return cs.ConstraintSystem(assignments, constraints), value_sites_for_variable + + +@_add_constraint_system_derivation_rule(mgpu.AsyncPrefetchOp) +def _async_prefetch_constraint_system( + ctx: DerivationContext, + op: mgpu.AsyncPrefetchOp, +) -> ConstraintSystemDerivationRuleResult: + del ctx + value_sites, assignments = _vector_value_sites_and_assignments_for_async_ops(op) + return cs.ConstraintSystem(assignments), value_sites + + +def _ensure_all_layouts_are_set(op: ir.OpView) -> None: + if inference_utils.should_have_layout(op): + _ensure_right_number_of_layouts(is_vector, "layouts", "vector", op) + if inference_utils.should_have_tmem_layout(op): + _ensure_right_number_of_layouts(_is_tmem_ref, "tmem_layouts", "TMEM ref", op) + if inference_utils.should_have_transforms(op): + _ensure_right_number_of_layouts( + inference_utils.is_transformable_smem_memref, "transforms", "SMEM ref", op, + ) + + +def _ensure_right_number_of_layouts( + filter_fn: Callable[[ir.Value], bool], + attr_suffix: str, + value_type: str, + op: ir.OpView, +) -> None: + """Ensures that the right number of in/out layouts are provided for an op. + + Layouts here are can be vector layouts, TMEM layouts, or SMEM transforms. + """ + layouts = lambda attr: op.attributes[attr] if attr in op.attributes else [] + in_layouts = layouts(f"in_{attr_suffix}") + out_layouts = layouts(f"out_{attr_suffix}") + + num_matching_operands = sum(map(filter_fn, op.operands)) + if len(in_layouts) != num_matching_operands: + raise ValueError( + f"Expected the same number of in_{attr_suffix} ({len(in_layouts)}) as " + f"{value_type} operands ({num_matching_operands}). op=\n {op}" + ) + num_matching_results = sum(map(filter_fn, op.results)) + if len(out_layouts) != num_matching_results: + raise ValueError( + f"Expected the same number of out_{attr_suffix} ({len(out_layouts)}) " + f"as {value_type} results ({num_matching_results}). op=\n {op}" + ) -@partial(_add_layout_inference_rule, mgpu.WGMMAOp) -def _infer_wgmma_op_layout(wgmma_op: mgpu.WGMMAOp) -> OptionalLayouts: - layout = layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT) +def _compute_swizzle( + type: ir.Type, tile_transform: lc.TileTransform | None +) -> mgpu.SwizzlingMode: + """Computes the swizzle mode given a tiling transform and a data type.""" + if tile_transform is None: + # TODO(b/447079781): Revisit if this is the behavior we want. + return mgpu.SwizzlingMode.kNoSwizzle - if ir.VectorType.isinstance(wgmma_op.a.type): - return [layout, layout], [layout] + if not isinstance(type, ir.MemRefType): + raise ValueError(f"Expected a MemRefType, got {type}.") + ref_ty = ir.MemRefType(type) + strides, _ = ref_ty.get_strides_and_offset() + tiling = tile_transform.tiling - return [layout], [layout] + if len(tiling) > len(strides): + raise ValueError( + f"The tile rank ({len(tiling)}) cannot be greater than the ref's rank" + f" ({len(strides)})." + ) + minor_tiling = tiling[np.argmin(strides[-len(tiling):])] + swizzle = minor_tiling * utils.bytewidth(ref_ty.element_type) + assert swizzle in ( + mgpu.SwizzlingMode.k128ByteSwizzle, + mgpu.SwizzlingMode.k64ByteSwizzle, + mgpu.SwizzlingMode.k32ByteSwizzle, + mgpu.SwizzlingMode.kNoSwizzle, + ) + return mgpu.SwizzlingMode(swizzle) -def _earliest_use(regions: list[ir.Region], uses: Sequence[ir.OpOperand]) -> ir.OpView: - owners = [use.owner for use in uses] - for region in regions: - for block in region: - for op in block: - if op in owners: - return op - raise ValueError("None of uses are in the given block") +@dataclasses.dataclass(frozen=True) +class _TypeAndLayout: + type: ir.Type + layout: cs.Constant -def _insert_memref_layout_cast(layout: ir.Attribute, view_op: memref.ViewOp): - mem_ref_type = ir.MemRefType(view_op.result.type) - memref_new_type = ir.MemRefType.get( - mem_ref_type.shape, - mem_ref_type.element_type, - layout, - mem_ref_type.memory_space, + +def assign_layouts(solution: dict[ValueSite, cs.Constant]) -> None: + """Assigns the layouts in `solution` to the MLIR ops they belong to. + + This function requires that, for each MLIR op that appears in `solution`, + `solution` contains a layout assignment for all of its `vector`, TMEM, and + SMEM operands and results. Block arguments are ignored. + """ + solution_sorted_by_op = sorted( + solution.items(), key=lambda kv: id(kv[0].operation) + ) + solution_per_op = itertools.groupby( + solution_sorted_by_op, key=lambda kv: kv[0].operation ) - uses = list(view_op.result.uses) - with ir.InsertionPoint(_earliest_use(view_op.parent.regions, uses)): - cast_op = memref.cast(memref_new_type, view_op.result) - for use in uses: - use.owner.operands[use.operand_number] = cast_op + for op, assignments in solution_per_op: + assignments_sorted_by_type = sorted(assignments, key=lambda kv: kv[0].type) + assignments_by_type = { + ty: list(group) + for ty, group in itertools.groupby( + assignments_sorted_by_type, key=lambda kv: kv[0].type + ) + } + + in_assignments = assignments_by_type.get(VariableType.OPERAND, []) + out_assignments = assignments_by_type.get(VariableType.RESULT, []) + + index = lambda kv: kv[0].index + in_tls = [ + _TypeAndLayout(v.value.type, ce) + for v, ce in sorted(in_assignments, key=index) + ] + out_tls = [ + _TypeAndLayout(v.value.type, ce) + for v, ce in sorted(out_assignments, key=index) + ] + + in_layouts = [ + tl.layout.value + for tl in in_tls + if isinstance(tl.layout, cs.RegisterLayout) + ] + out_layouts = [ + tl.layout.value + for tl in out_tls + if isinstance(tl.layout, cs.RegisterLayout) + ] + in_tmem_layouts = [ + tl.layout.value for tl in in_tls if isinstance(tl.layout, cs.TMEMLayout) + ] + out_tmem_layouts = [ + tl.layout.value + for tl in out_tls + if isinstance(tl.layout, cs.TMEMLayout) + ] + in_transforms = [ + tl for tl in in_tls if isinstance(tl.layout, cs.SMEMTiling) + ] + out_transforms = [ + tl for tl in out_tls if isinstance(tl.layout, cs.SMEMTiling) + ] + + if inference_utils.should_have_in_layout(op): + attrs = [layouts_lib.to_layout_attr(l) for l in in_layouts] + op.attributes["in_layouts"] = ir.ArrayAttr.get(attrs) + if inference_utils.should_have_out_layout(op): + attrs = [layouts_lib.to_layout_attr(l) for l in out_layouts] + op.attributes["out_layouts"] = ir.ArrayAttr.get(attrs) + if inference_utils.should_have_in_tmem_layout(op): + attrs = [layouts_lib.to_layout_attr(l) for l in in_tmem_layouts] + op.attributes["in_tmem_layouts"] = ir.ArrayAttr.get(attrs) + if inference_utils.should_have_out_tmem_layout(op): + attrs = [layouts_lib.to_layout_attr(l) for l in out_tmem_layouts] + op.attributes["out_tmem_layouts"] = ir.ArrayAttr.get(attrs) + + def _to_transform_attrs( + transforms: list[_TypeAndLayout], + ) -> list[ir.ArrayAttr]: + all_attrs: list[ir.ArrayAttr] = [] + for tl in transforms: + assert isinstance(tl.layout, cs.SMEMTiling) # make pytype happy + attrs = [] + if tl.layout.value is not None: + attrs.append(layouts_lib.to_transform_attr(tl.layout.value)) + swizzle = _compute_swizzle(tl.type, tl.layout.value) + attrs.append(layouts_lib.to_transform_attr(swizzle)) + all_attrs.append(ir.ArrayAttr.get(attrs)) + return all_attrs + + if inference_utils.should_have_in_transforms(op): + attrs = _to_transform_attrs(in_transforms) + op.attributes["in_transforms"] = ir.ArrayAttr.get(attrs) + if inference_utils.should_have_out_transforms(op): + attrs = _to_transform_attrs(out_transforms) + op.attributes["out_transforms"] = ir.ArrayAttr.get(attrs) + + _ensure_all_layouts_are_set(op) + + +def vector_value_sites(op: ir.OpView) -> list[ValueSite]: + """Returns all the vector operands and results for the given op.""" + value_sites = [ + ValueSite(op, VariableType.OPERAND, i) + for i, o in enumerate(op.operands) + if is_vector(o) + ] + value_sites.extend([ + ValueSite(op, VariableType.RESULT, i) + for i, o in enumerate(op.results) + if is_vector(o) + ]) + return value_sites -class TraversalOrder(enum.Enum): - """Traversal orders with respect to the data flow for IR.""" - FORWARD = 1 - BACKWARDS = 2 +def producer_result(operand: ValueSite) -> ValueSite: + """Given an operand, returns the corresponding result in its producer. + + When the producer is a block, we return the corresponding operand in the + operation that owns the block. + """ + assert operand.type == VariableType.OPERAND + value = operand.value + producer = value.owner + if isinstance(producer, ir.OpView): + index = list(producer.results).index(value) + return ValueSite(producer, VariableType.RESULT, index) + + if isinstance(producer, ir.Block): + index = list(producer.arguments).index(value) + region_index = list(producer.owner.regions).index(producer.region) + return ValueSite(producer.owner, VariableType.ARGUMENT, index, region_index) + + raise TypeError( + f"Producer {producer} is not an operation nor a block: {type(producer)}." + ) + + +def consumer_operands(result: ValueSite) -> Sequence[ValueSite]: + """Given a result or an argument, returns the corresponding operands in its consumers.""" + assert result.type in (VariableType.RESULT, VariableType.ARGUMENT) + consumer_operands: list[ValueSite] = [] + # The layout can also be chosen from the layout of the consumers of the + # results. + for use in result.value.uses: + consumer = use.owner + index = use.operand_number + consumer_operands.append(ValueSite(consumer, VariableType.OPERAND, index)) + return consumer_operands + + +def derive_relayout_constraints( + value_sites_for_variable: ValueSitesForVariable, +) -> list[cs.Relayout]: + """Derives relayout constraints from the given variable mapping.""" + constraints: list[cs.Relayout] = [] + variable_for_value_site: dict[ValueSite, cs.Variable] = {} + for variable, value_sites in value_sites_for_variable.items(): + for value_site in value_sites: + if value_site in variable_for_value_site: + raise ValueError( + f"{value_site} is mapped to both {variable} and " + f"{variable_for_value_site[value_site]}" + ) + variable_for_value_site |= {k: variable for k in value_sites} + + visited: set[cs.Variable] = set() + for variable, value_sites in value_sites_for_variable.items(): + producers: list[cs.Variable] = [] + consumers: list[cs.Variable] = [] + for value_site in value_sites: + # We can only relayout variables that are in registers. + if value_site.memory_space != MemorySpace.REG: + continue + + elt_bitwidth = utils.bitwidth(value_site.value.type.element_type) # pytype: disable=attribute-error + if value_site.type == VariableType.OPERAND: + pr = producer_result(value_site) + producer_variable = variable_for_value_site[pr] + producers.append(producer_variable) + # Only add the constraint if we haven't already created that constraint + # when processing this variable as one of the producer's consumers. + if producer_variable not in visited: + # The producer of a variable must be relayout-able to the variable. + constraints.append( + cs.Relayout(producer_variable, variable, elt_bitwidth) + ) + elif value_site.type in (VariableType.RESULT, VariableType.ARGUMENT): + for co in consumer_operands(value_site): + consumer_variable = variable_for_value_site[co] + consumers.append(consumer_variable) + # Only add the constraint if we haven't already created that + # constraint when processing this variable as the consumer's producer. + if consumer_variable not in visited: + # A variable must be relayout-able to its consumers. + constraints.append( + cs.Relayout(variable, consumer_variable, elt_bitwidth) + ) + visited.add(variable) + return constraints + + +def is_terminator(op: ir.OpView) -> bool: + return isinstance(op, (scf.YieldOp, scf.ConditionOp)) def traverse_op( op: ir.OpView, callback: Callable[[ir.OpView], None], - traversal_order: TraversalOrder = TraversalOrder.FORWARD, ): - """Traverses the operation and applies the callback in the given order.""" - for region in op.operation.regions: - for block in region: - if traversal_order == TraversalOrder.FORWARD: - ops_to_traverse = block - else: - ops_to_traverse = reversed(list(block)) - for block_op in ops_to_traverse: - traverse_op(block_op, callback, traversal_order) + """Traverses the operation and applies the callback in pre-order fashion. + + Skips recursing into `mgpu.CustomPrimitiveOp`s, and assumes that the values + iterated on are not being modified. + """ callback(op) + # The block of a mosaic_gpu.custom_primitive op is already lowered so it + # should not be traversed. + if not isinstance(op, mgpu.CustomPrimitiveOp): + for region in op.operation.regions: + for block in region: + for block_op in block.operations: + traverse_op(block_op, callback) + + +def is_valid_register_layout_assignment( + shape: tuple[int, ...], layout: fa.FragmentedLayout +) -> bool: + match layout: + case fa.WGStridedFragLayout() as strided_layout: + return strided_layout.shape == shape + case fa.WGSplatFragLayout() as splat_layout: + return splat_layout.shape == shape + case fa.TiledLayout(tiling=tiling): + try: + # `tiling.tile_shape` will raise if the shape is not tileable. + _ = tiling.tile_shape(shape) + except ValueError: + return False + return True + case _: + assert False, f"Unreachable {shape}, {layout}" + + +def is_valid_smem_layout_assignment( + shape: tuple[int, ...], tiling: lc.TileTransform +) -> bool: + try: + # `tiling.transform_shape` will raise if the shape is not tileable. + _ = tiling.transform_shape(shape) + except ValueError: + return False + return True + + +def is_valid_tmem_layout_assignment( + shape: tuple[int, ...], layout: tcgen05.TMEMLayout +) -> bool: + try: + # `layout.tiling.tile_shape` will raise if the shape is not tileable. + _ = layout.tiling.tile_shape(shape) + except ValueError: + return False + return True + + +def check_layout_assignment(v: ValueSite, layout: cs.Constant) -> None: + """Raises if the given layout can not be assigned to the given `ValueSite`.""" + match v.memory_space, layout: + case MemorySpace.REG, cs.RegisterLayout(value=reg_layout): + if not is_valid_register_layout_assignment(v.shape, reg_layout): + raise ValueError( + f"Layout {reg_layout} is not compatible with register variable " + f"{v.value}. This is a bug." + ) + case MemorySpace.TMEM, cs.TMEMLayout(value=tmem_layout): + if not is_valid_tmem_layout_assignment(v.shape, tmem_layout): + raise ValueError( + f"Layout {tmem_layout} is not compatible with TMEM variable " + f"{v.value}. This is a bug." + ) + case MemorySpace.SMEM, cs.SMEMTiling(value=tiling_or_none): + if tiling_or_none is None: + return + if not is_valid_smem_layout_assignment(v.shape, tiling_or_none): + raise ValueError( + f"Layout {tiling_or_none} is not compatible with SMEM variable " + f"{v.value}. This is a bug." + ) + case _: + raise ValueError( + f"Variable {v.value} in memory space {v.memory_space} should not be " + f"assigned a layout of type {type(layout)}. This is a bug." + ) -def infer_layout(module: ir.Module): - def inference_step(op: ir.Operation): - if not inference_utils.should_have_layout(op): +def infer_layout( + module: ir.Module, *, fuel: int = _DEFAULT_LAYOUT_INFERENCE_FUEL +): + """Infers layouts for the given module. + + * If there are vector (respectively SMEM refs, TMEM refs) operands, + `in_layouts` (respectively `in_transforms`, `in_tmem_layouts`) will be set and + contain one element per relevant argument in the memory space. + * If there are vector (respectively SMEM refs, TMEM refs) outputs, + `out_layouts` (respectively `out_transforms`, `out_tmem_layouts`) will be set + and contain one element per relevant argument in the memory space. + * Any of these attributes is guaranteed to not be set if there is no relevant + input/output in the corresponding memory space. + + The fuel is provided in order to limit the number of attempts made by the + solver. + """ + global_constraint_system: cs.ConstraintSystem | cs.Unsatisfiable + global_constraint_system = cs.ConstraintSystem() + ctx = DerivationContext() + + def gather_constraints(op: ir.Operation): + # Terminator ops are handled directly by the op whose region they belong to. + # This is because they need to be in sync with their parent op's inputs and + # outputs---and the parent op's constraints therefore need to take them into + # account. + if is_terminator(op): return - elif inference_rule := _layout_inference_rules.get(op.OPERATION_NAME, None): # pytype: disable=attribute-error - pass - else: - raise NotImplementedError(f"Can not infer layout for {op}") - - maybe_layouts = inference_rule(op) - if maybe_layouts is None: + should_have_layout = ( + inference_utils.should_have_layout(op) + or inference_utils.should_have_tmem_layout(op) + or inference_utils.should_have_transforms(op) + ) + if not should_have_layout: return + rule = _constraint_system_derivation_rules.get(op.OPERATION_NAME, None) # pytype: disable=attribute-error + if rule is None: + raise NotImplementedError(f"No layout inference rule defined for {op}") + rule_result = rule(ctx, op) + nonlocal global_constraint_system + if isinstance(rule_result, cs.Unsatisfiable): + global_constraint_system = cs.Unsatisfiable() + return + constraint_system, mapping = rule_result + global_constraint_system &= constraint_system + ctx.update(mapping) - _set_layout_attributes(op, *maybe_layouts) - - # TODO(bchetioui): consider switching the order of the passes. This would - # allow propagating "simpler" layouts further down in the computation, which - # is more efficient when possible. - # - # We run two passes over the module, in order to make sure that layouts - # defined in the middle of the computation are propagated wherever they need - # to be propagated. We start with a backwards (root-to-parameters) pass to - # propagate the information as far up as possible, and then a forward pass - # (parameters-to-root). - # - # Backwards pass for op in module.body: - inference_utils.traverse_op( - op, inference_step, inference_utils.TraversalOrder.BACKWARDS - ) + traverse_op(op, gather_constraints) + # Short-circuit if we have an unsatisfiable constraint system, we won't + # construct anything useful anymore. + if isinstance(global_constraint_system, cs.Unsatisfiable): + break - # Forward pass - for op in module.body: - inference_utils.traverse_op( - op, inference_step, inference_utils.TraversalOrder.FORWARD + if isinstance(global_constraint_system, cs.Unsatisfiable): + raise ValueError( + "Failed to infer a possible set of layouts. This should only happen if " + "user-provided layout casts are unsatisfiable." ) - # At this point, layouts have been propagated as far as they could be - # propagated. However, it is possible for some operations to remain - # unannotated---for example, if there were no annotations on any operation in - # the module at the start of this function. We annotate all the remaining ops - # that should be annotated with a strided fragmented layout, whose vector size - # is derived from the narrowest type and vector size used in the program. We - # make sure to derive a single vector size in order to avoid relayouts at - # lowering time. - default_vector_size = math.inf - - def update_default_vector_size(op: ir.OpView): - nonlocal default_vector_size - for v in list(op.operands) + list(op.results): - if ir.VectorType.isinstance(v.type): - max_vec_size_for_v = ( - np.prod(cast(ir.ShapedType, v.type).shape) // fa.WARPGROUP_SIZE - ) - desired_vec_size = 8 // utils.bytewidth(v.type.element_type) - default_vector_size = min( - default_vector_size, max_vec_size_for_v, desired_vec_size - ) + constraints = derive_relayout_constraints(ctx.value_sites_for_variable) + global_constraint_system &= cs.ConstraintSystem(constraints=constraints) + assert not isinstance(global_constraint_system, cs.Unsatisfiable) - for op in module.body: - traverse_op(op, update_default_vector_size) + # Add additional (redundant) constraints which helps the search converge + # faster. + global_constraint_system = cs.saturate_distinct_from_splat( + global_constraint_system + ) + assert not isinstance(global_constraint_system, cs.Unsatisfiable) + global_constraint_system = cs.saturate_divides_constraints_for_equal_vars( + global_constraint_system + ) - if default_vector_size is None: # Nothing to annotate. - return + # Attempt to find assignments that satisfy the constraint system. + solution, remaining_fuel = find_assignments_for( + list(ctx.value_sites_for_variable.keys()), + global_constraint_system, + fuel=fuel, + ) - def to_default_layout(ty: ir.Type) -> ir.Attribute | None: - if not ir.VectorType.isinstance(ty): - return None - layout = fa.WGStridedFragLayout( - shape=cast(ir.ShapedType, ty).shape, vec_size=default_vector_size - ) - return layouts_lib.to_strided_fragmented_layout_attr(layout) + if logging.vlog_is_on(1): + print("Finding a solution (or exhausting the entire search space) " + f"consumed {fuel - remaining_fuel}/{fuel} fuel.") - def set_default_layout(op: ir.OpView): - if inference_utils.should_have_layout( - op - ) and not inference_utils.has_any_layout_set(op): - in_layouts = [] - for operand in op.operands: - if (layout := to_default_layout(operand.type)) is not None: - in_layouts.append(layout) + if isinstance(solution, cs.Unsatisfiable): + raise ValueError( + "Failed to infer a possible set of layouts. This should only happen if " + "user-provided layout casts are unsatisfiable." + ) - out_layouts = [] - for result in op.results: - if (layout := to_default_layout(result.type)) is not None: - out_layouts.append(layout) + layout_for_value_site: dict[ValueSite, cs.Constant] = {} + for variable, value_sites in ctx.value_sites_for_variable.items(): + for value_site in value_sites: + layout = solution[variable] + # Ensure that the layout assignment is valid for the value site. This + # should only ever fail if our implementation is buggy. + check_layout_assignment(value_site, layout) + layout_for_value_site[value_site] = layout - _set_layout_attributes(op, in_layouts, out_layouts) + # Assigns the layouts that we found to the ops. + assign_layouts(layout_for_value_site) + # Sanity check: ensure that all ops have the right number of in/out layouts. for op in module.body: - traverse_op(op, set_default_layout) + traverse_op(op, _ensure_all_layouts_are_set) diff --git a/jax/experimental/mosaic/gpu/layouts.py b/jax/experimental/mosaic/gpu/layouts.py index 5c3b23119779..1b75e4d48b47 100644 --- a/jax/experimental/mosaic/gpu/layouts.py +++ b/jax/experimental/mosaic/gpu/layouts.py @@ -16,8 +16,11 @@ import re +from jax._src.lib import mosaic_gpu_dialect as mgpu from jax._src.lib.mlir import ir + from . import fragmented_array as fa +from . import launch_context _splat_fragmented_layout_attr_pattern = re.compile( @@ -96,7 +99,7 @@ def is_strided_fragmented_layout(attr: ir.Attribute) -> bool: _tiled_layout_attr_pattern = re.compile( r"^#mosaic_gpu.TiledLayout<\[(?P.*)\]," - r" warp_dim\s*=\s*(?P[-\d]+)," + r" warp_dims\s*=\s*\[(?P.*)\]," r" lane_dims\s*=\s*\[(?P.*)\]," r" vector_dim\s*=\s*(?P[-\d]+)>$" ) @@ -107,15 +110,31 @@ def to_tiled_layout_attr( ) -> ir.Attribute: """Constructs a #mosaic_gpu.TiledLayout attribute from a TiledLayout.""" + def _int_or_replicated(d: int | fa.Replicated) -> str: + if isinstance(d, fa.Replicated): + return f"#mosaic_gpu.Replicated" + return str(d) + tile_str = lambda tile: "[" + ", ".join(str(d) for d in tile) + "]" tiling = "[" + ", ".join(tile_str(tile) for tile in layout.tiling.tiles) + "]" + warp_dims = ( + "[" + ",".join(_int_or_replicated(d) for d in layout.warp_dims) + "]" + ) + lane_dims = ( + "[" + ",".join(_int_or_replicated(d) for d in layout.lane_dims) + "]" + ) + return ir.Attribute.parse( - f"#mosaic_gpu.TiledLayout<{tiling}, warp_dim={layout.warp_dim}," - f" lane_dims={list(layout.lane_dims)}, vector_dim={layout.vector_dim}>" + f"#mosaic_gpu.TiledLayout<{tiling}, warp_dims={warp_dims}," + f" lane_dims={lane_dims}, vector_dim={layout.vector_dim}>" ) _list_of_lists_delimiter = re.compile(r"\]\s*,\s*\[") +_int_pattern = re.compile(r"^(?P[-\d]+)(\s*:\s*\w+)?$") +_replicated_pattern = re.compile( + r"^#mosaic_gpu.Replicated<\s*times\s*=\s*(?P\d+)\s*>\s*$" +) def from_tiled_layout_attr( @@ -133,6 +152,15 @@ def from_tiled_layout_attr( f"Expected a #mosaic_gpu.TiledLayout attribute, got {attr}" ) + def _int_or_replicated(replicated_dim: str) -> int | fa.Replicated: + match = _replicated_pattern.fullmatch(replicated_dim) + if match: + return fa.Replicated(int(match.group("times"))) + match = _int_pattern.fullmatch(replicated_dim) + if match: + return int(match.group("num")) + raise ValueError(f"Unexpected format for replicated dim {replicated_dim}") + tiling_str = match.group("tiling") tile_strings = [] if len(tiling_str) > 2: @@ -140,9 +168,15 @@ def from_tiled_layout_attr( tiles = tuple(tuple(map(int, ts.split(","))) for ts in tile_strings) return fa.TiledLayout( tiling=fa.Tiling(tiles), - warp_dim=int(match.group("warp_dim")), - lane_dims=tuple(int(s) for s in match.group("lane_dims").split(",")), - vector_dim=int(match.group("vector_dim")) + warp_dims=tuple( + _int_or_replicated(s.strip()) + for s in match.group("warp_dims").split(",") + ), + lane_dims=tuple( + _int_or_replicated(s.strip()) + for s in match.group("lane_dims").split(",") + ), + vector_dim=int(match.group("vector_dim")), ) @@ -150,14 +184,7 @@ def is_tiled_layout(attr: ir.Attribute) -> bool: return bool(_tiled_layout_attr_pattern.search(str(attr))) -def to_layout_attr( - layout: ( - fa.WGSplatFragLayout - | fa.WGStridedFragLayout - | fa.TiledLayout - | fa.WGMMARowFragLayout - ), -) -> ir.Attribute: +def to_layout_attr(layout: fa.FragmentedLayout) -> ir.Attribute: """Constructs an MLIR attribute that corresponds to the given layout.""" match layout: case fa.WGSplatFragLayout(): @@ -166,31 +193,13 @@ def to_layout_attr( return to_strided_fragmented_layout_attr(layout) case fa.TiledLayout(): return to_tiled_layout_attr(layout) - case fa.WGMMARowFragLayout(): - return ir.Attribute.parse("#mosaic_gpu.WGMMARowFragLayout") case _: raise NotImplementedError( f"Unsupported layout for conversion to MLIR attribute: {layout}" ) -_wgmma_row_fragmented_layout_attr_pattern = re.compile( - r"^#mosaic_gpu.WGMMARowFragLayout$" -) - - -def is_wgmma_row_fragmented_layout(attr: ir.Attribute) -> bool: - return bool(_wgmma_row_fragmented_layout_attr_pattern.search(str(attr))) - - -def from_layout_attr( - attr: ir.Attribute, -) -> ( - fa.WGSplatFragLayout - | fa.WGStridedFragLayout - | fa.TiledLayout - | fa.WGMMARowFragLayout -): +def from_layout_attr(attr: ir.Attribute) -> fa.FragmentedLayout: """Constructs a layout from an MLIR attribute.""" if is_splat_fragmented_layout(attr): return from_splat_fragmented_layout_attr(attr) @@ -198,9 +207,73 @@ def from_layout_attr( return from_strided_fragmented_layout_attr(attr) elif is_tiled_layout(attr): return from_tiled_layout_attr(attr) - elif is_wgmma_row_fragmented_layout(attr): - return fa.WGMMARowFragLayout() else: raise NotImplementedError( f"Unsupported layout for conversion from MLIR attribute: {attr}" ) + + +def splat_is_compatible_with_tiled( + l1: fa.WGSplatFragLayout, l2: fa.TiledLayout +) -> bool: + # A splat layout is compatible with a tiled layout up to replication if each + # dimension in the shape of the splat layout is divisible by the corresponding + # dimension in the base tile shape. + s1, s2 = l1.shape, l2.base_tile_shape + return all(d1 % d2 == 0 for d1, d2 in zip(s1, s2)) + + +_tile_transform_attr_pattern = re.compile( + r"^#mosaic_gpu.tile<[^>]+>$" +) + + +def is_tile_transform(attr: ir.Attribute) -> bool: + return bool(_tile_transform_attr_pattern.search(str(attr))) + + +_transpose_transform_attr_pattern = re.compile( + r"^#mosaic_gpu.transpose<[^>]+>$" +) + + +def is_transpose_transform(attr: ir.Attribute) -> bool: + return bool(_transpose_transform_attr_pattern.search(str(attr))) + + +_swizzle_transform_attr_pattern = re.compile( + r"^#mosaic_gpu.swizzle<[^>]+>$" +) + +def is_swizzle_transform(attr: ir.Attribute) -> bool: + return bool(_swizzle_transform_attr_pattern.search(str(attr))) + + +def to_transform_attr( + transform: launch_context.MemRefTransform | mgpu.SwizzlingMode, +) -> ir.Attribute: + if isinstance(transform, launch_context.TileTransform): + return mgpu.TileTransformAttr.get(transform.tiling) + elif isinstance(transform, launch_context.TransposeTransform): + return mgpu.TransposeTransformAttr.get(transform.permutation) + elif isinstance(transform, mgpu.SwizzlingMode): + return mgpu.SwizzleTransformAttr.get(transform) + else: + raise NotImplementedError(f"Unsupported transform {transform}") + + +def from_transform_attr( + transform: ir.Attribute, +) -> launch_context.MemRefTransform | mgpu.SwizzlingMode: + if is_tile_transform(transform): + return launch_context.TileTransform( + mgpu.TileTransformAttr(transform).tiling + ) + elif is_transpose_transform(transform): + return launch_context.TransposeTransform( + mgpu.TransposeTransformAttr(transform).permutation + ) + elif is_swizzle_transform(transform): + return mgpu.SwizzlingMode(mgpu.SwizzleTransformAttr(transform).swizzle) + else: + raise NotImplementedError(f"Unsupported transform {transform}") diff --git a/jax/experimental/mosaic/gpu/mma.py b/jax/experimental/mosaic/gpu/mma.py new file mode 100644 index 000000000000..6366c54c1b27 --- /dev/null +++ b/jax/experimental/mosaic/gpu/mma.py @@ -0,0 +1,220 @@ +# Copyright 2025 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import itertools +from jax.experimental.mosaic.gpu import fragmented_array as fa +from jaxlib.mlir import ir +from jaxlib.mlir.dialects import llvm +from jaxlib.mlir.dialects import vector +import numpy as np +from . import utils + + +class MMALayouts: + """Container for MMA layouts, providing a convenient way to create + layouts for MMA operands based on warp configuration. + """ + + lhs = fa.TiledLayout( + fa.Tiling(((64, 16), (16, 8), (8, 8), (2,))), + warp_dims=(-7,), + lane_dims=(-3, -2), + vector_dim=-1, + ) + rhs = fa.TiledLayout( + fa.Tiling(((8, 16), (8, 8), (2,))), + warp_dims=(fa.Replicated(4),), + lane_dims=(-3, -2), + vector_dim=-1, + ) + acc = fa.TiledLayout( + fa.Tiling(((64, 8), (16, 8), (8, 8), (2,))), + warp_dims=(-7,), + lane_dims=(-3, -2), + vector_dim=-1, + ) + + +def _mma_single_tile( + acc: fa.FragmentedArray, a: fa.FragmentedArray, b: fa.FragmentedArray +) -> fa.FragmentedArray: + """Performs `acc + a @ b.T` using warp level MMA instructions.""" + + # Muliply by 4 because the fragmtned array has a tile per warp. + assert a.shape == (64, 16) + assert b.shape == (8, 16) + assert acc.shape == (64, 8) + assert a.mlir_dtype == b.mlir_dtype + assert a.mlir_dtype in (ir.F16Type.get(), ir.BF16Type.get()) + assert acc.mlir_dtype == ir.F32Type.get() + assert ( + isinstance(acc.layout, fa.TiledLayout) + and isinstance(a.layout, fa.TiledLayout) + and isinstance(b.layout, fa.TiledLayout) + ) + num_acc_regs, num_a_regs, num_b_regs = 4, 4, 2 + + acc_regs = [ # pylint: disable=g-complex-comprehension + vector.extract( + reg, + dynamic_position=[], + static_position=ir.DenseI64ArrayAttr.get([pos]), + ) + for reg in acc.registers.flatten() + for pos in range(acc.layout.vector_length) + ] + i32 = ir.IntegerType.get_signless(32) + a_regs = [utils.bitcast(r, i32) for r in a.registers.flatten()] + b_regs = [utils.bitcast(r, i32) for r in b.registers.flatten()] + + # Make sure we have the right number of registers for the instruction. + assert len(a_regs) == 4 + assert len(acc_regs) == 4 + assert len(b_regs) == 2 + + instr = f"mma.sync.aligned.m16n8k16.row.col.f32.{a.mlir_dtype}.{b.mlir_dtype}.f32" + counter = itertools.count() + n_regs_str = lambda n: ( + "{" + ",".join([f"${next(counter)}" for _ in range(n)]) + "}" + ) + out_regs_str = n_regs_str(num_acc_regs) + a_regs_str = n_regs_str(num_a_regs) + b_regs_str = n_regs_str(num_b_regs) + c_regs_str = n_regs_str(num_acc_regs) + ptx = f"{instr} {out_regs_str}, {a_regs_str}, {b_regs_str}, {c_regs_str};" + # See: https://llvm.org/docs/LangRef.html#inline-assembler-expressions + constraints = ( + f"{','.join(['=f']*num_acc_regs)}," # Output accumulator regs + f"{','.join(['r']*num_a_regs)}," # Input A regs + f"{','.join(['r']*num_b_regs)}," + f"{','.join(['f']*num_acc_regs)}" # Input accumulator regs + ) + + in_operands = [*a_regs, *b_regs, *acc_regs] + acc_struct_type = ir.Type.parse( + f"!llvm.struct<({','.join(str(acc.mlir_dtype) for _ in acc_regs)})>" + ) + out_regs_struct = llvm.inline_asm( + acc_struct_type, + in_operands, + ptx, + constraints, + has_side_effects=False, + ) + out_regs = [ + llvm.extractvalue(acc.mlir_dtype, out_regs_struct, [i]) + for i in range(len(acc_regs)) + ] + vec_regs = [] + vec_undef = llvm.mlir_undef(ir.VectorType.get((2,), acc.mlir_dtype)) + for first, second in zip(out_regs[::2], out_regs[1::2]): + vec = llvm.insertelement(vec_undef, first, position=utils.c(0, i32)) + vec = llvm.insertelement(vec, second, position=utils.c(1, i32)) + vec_regs.append(vec) + out_regs = np.asarray(vec_regs, dtype=object).reshape(acc.registers.shape) + return fa.FragmentedArray( + _registers=out_regs, _layout=acc.layout, _is_signed=None + ) + + +# TODO(cperivol): More datatypes other than (b)f16. +def mma( + acc: fa.FragmentedArray, + a: fa.FragmentedArray, + b: fa.FragmentedArray, +) -> fa.FragmentedArray: + """Computes `acc + a @ b.T` using synchronouse MMA instructions. + + All operands must have `TiledLayout`s. The layouts must be generated + by the `MMALayouts` class, which ensures that the tiles are mapped + to the warps correctly. + + Args: + acc: A `FragmentedArray` with a `TiledLayout` generated from + `MMALayouts.acc`. + a: A `FragmentedArray` with a `TiledLayout` generated from + `MMALayouts.lhs`. + b: A `FragmentedArray` with a `TiledLayout` generated from `MMALayouts.rhs`. + + Returns: + A new `FragmentedArray` with the result of the computation with + the same type as `acc`. + """ + + (m, k) = a.shape + (n, k2) = b.shape + (m2, n2) = acc.shape + + if m != m2: + raise ValueError(f"M mismatch: {m} != {m2}") + if n != n2: + raise ValueError(f"N mismatch: {n} != {n2}") + if k != k2: + raise ValueError(f"K mismatch: {k} != {k2}") + + # todo(cperivol): A tile shape can have dimensions that are higher + # multiples of the mma op size as long as those dimensions are not + # sharded across warps. + bf16 = ir.BF16Type.get() + f16 = ir.F16Type.get() + if a.mlir_dtype != b.mlir_dtype: + raise ValueError(f"Dtype mismatch: {a.mlir_dtype} != {b.mlir_dtype}") + if a.mlir_dtype not in (bf16, f16): + raise NotImplementedError("Only bf16 and f16 supported for the operands.") + if acc.mlir_dtype != ir.F32Type.get(): + raise NotImplementedError("Only f32 accumulator supported.") + + if MMALayouts.lhs != a.layout: + raise ValueError("Expected MMALayouts.lhs layout for A") + if MMALayouts.rhs != b.layout: + raise ValueError("Expected MMALayouts.rhs layout for B") + if MMALayouts.acc != acc.layout: + raise ValueError("Expected MMALayouts.acc layout for acc") + + assert isinstance(a.layout, fa.TiledLayout) + assert isinstance(b.layout, fa.TiledLayout) + assert isinstance(acc.layout, fa.TiledLayout) + m_tile, k_tile = a.layout.base_tile_shape + n_tile, k_tile2 = b.layout.base_tile_shape + m_tile2, n_tile2 = acc.layout.base_tile_shape + + assert k_tile == k_tile2 + assert m_tile2 == m_tile + assert n_tile2 == n_tile + + num_m_tiles, num_n_tiles, num_k_tiles = m // m_tile, n // n_tile, k // k_tile + if m != m2: + raise ValueError(f"M mismatch: {m} != {m2}") + if n != n2: + raise ValueError(f"N mismatch: {n} != {n2}") + if k != k2: + raise ValueError(f"K mismatch: {k} != {k2}") + + assert m_tile == 64 and n_tile == 8 and k_tile == 16, ( + f"Tile shape {m_tile}, {n_tile}, {k_tile} not supported." + ) + + # Do not modify the accumualtor itself. + acc = acc.copy() + s = lambda idx, length: slice(idx * length, (idx + 1) * length) + for k_idx in range(num_k_tiles): + for m_idx in range(num_m_tiles): + for n_idx in range(num_n_tiles): + ms = s(m_idx, m_tile) + ns = s(n_idx, n_tile) + ks = s(k_idx, k_tile) + acc[ms, ns] = _mma_single_tile(acc[ms, ns], a[ms, ks], b[ns, ks]) + + return acc diff --git a/jax/experimental/mosaic/gpu/mma_utils.py b/jax/experimental/mosaic/gpu/mma_utils.py index 81f6af1a925d..ebd789ca348c 100644 --- a/jax/experimental/mosaic/gpu/mma_utils.py +++ b/jax/experimental/mosaic/gpu/mma_utils.py @@ -23,7 +23,6 @@ from . import utils -# mypy: ignore-errors def tiled_memref_shape(ref: ir.Value): """Returns the 2D untiled shape and element type of a tiled 4D memref.""" @@ -48,17 +47,26 @@ def create_descriptor( logical_k_major: bool, # False for LHS, True for RHS. # Soft deprecated. Use small tiling instead. large_tile: tuple[int, int] | None = None, + mma_bytewidth_k: int = 32, + split_const: bool = False, ): ref_ty = ir.MemRefType(ref.type) - element_bytewidth = utils.bytewidth(ref_ty.element_type) - swizzle_elems = swizzle // element_bytewidth + element_bitwidth = utils.bitwidth(ref_ty.element_type) + swizzle_elems = 8 * swizzle // element_bitwidth ref_strides, _ = ref_ty.get_strides_and_offset() - ref_byte_strides = [s * element_bytewidth for s in ref_strides] + def to_byte_stride(stride: int): + if element_bitwidth >= 8: + assert element_bitwidth % 8 == 0 + return stride * element_bitwidth // 8 + else: + packing = 8 // element_bitwidth + assert stride % packing == 0 + return stride // packing mn_large_tile = k_large_tile = None if logical_k_major: _, mn_tiles, k_tiling, mn_tiling = ref_ty.shape k_tile_stride, mn_tile_stride, k_tiling_stride, mn_tiling_stride = ( - ref_byte_strides + ref_strides ) k_group_size, mn_group_size = group_size if large_tile is not None: @@ -66,7 +74,7 @@ def create_descriptor( else: mn_tiles, _, mn_tiling, k_tiling = ref_ty.shape mn_tile_stride, k_tile_stride, mn_tiling_stride, k_tiling_stride = ( - ref_byte_strides + ref_strides ) mn_group_size, k_group_size = group_size if large_tile is not None: @@ -74,8 +82,9 @@ def create_descriptor( IGNORED = 0 MMA_ATOM_ROWS = 8 - MMA_BYTEWIDTH_K = 32 - mma_width_k = MMA_BYTEWIDTH_K // element_bytewidth + mma_width_k = 8 * mma_bytewidth_k // element_bitwidth + desc_k_tiling: tuple[int, ...] = () + desc_k_strides: tuple[int, ...] # As far as I can tell (which does not seem to fully align with the way MMA is # documented in PTX docs), MMA expects the data to be tiled into matrices # of shape 8 x swizzle_elems, with swizzle_elems dim being the fastest @@ -93,13 +102,12 @@ def create_descriptor( # There are configurations where large tiles are same size as small ones. # We use the small path since it has fewer restrictions. and set(large_tile) != {MMA_ATOM_ROWS, swizzle_elems} + and mma_bytewidth_k == 32 ): # Large tiles. - if ( - k_tiling_stride == element_bytewidth - and mn_tiling_stride == k_tiling * element_bytewidth - ): + if k_tiling_stride == 1 and mn_tiling_stride == k_tiling: fastest_dim = Dim.K leading_byte_offset = IGNORED # TC assumes K to be contiguous here. + assert k_tiling == k_group_size # Else we need multi-level striding. # MMA atoms in a group are contiguous, so we increment by the MMA atom # size. However, we only have one level of striding, and so if the group # size exceeds a single large tile (and there is more than one tile) then @@ -108,20 +116,17 @@ def create_descriptor( if ( mn_tiles > 1 and mn_group_size > mn_tiling - and mn_tile_stride != math.prod(large_tile) * element_bytewidth + and mn_tile_stride != math.prod(large_tile) ): raise ValueError( "MMA layout with large tiles that is K-fastest only supports" " multiple MN tiles when the tiled MN dimension is a contiguous" " stack of tiles " - f"({mn_tiles}, {mn_tile_stride} != {math.prod(large_tile)} * {element_bytewidth})" + f"({mn_tiles}, {mn_tile_stride} != {math.prod(large_tile)})" ) stride_byte_offset = MMA_ATOM_ROWS * swizzle - desc_k_stride = MMA_BYTEWIDTH_K # K is contiguous. - elif ( - k_tiling_stride == k_tiling * element_bytewidth - and mn_tiling_stride == element_bytewidth - ): + desc_k_strides = (mma_bytewidth_k,) # K is contiguous. + elif k_tiling_stride == k_tiling and mn_tiling_stride == 1: if k_large_tile != mn_large_tile: raise ValueError( "MMA layout with large tiles that is MN-fastest is only supported" @@ -129,13 +134,13 @@ def create_descriptor( ) fastest_dim = Dim.MN # Next swizzle atom with the same K coordinate is in the next MN tile. - leading_byte_offset = mn_tile_stride + leading_byte_offset = to_byte_stride(mn_tile_stride) # MMA atoms in a group are contiguous and a group does not exceed a tile. assert k_large_tile == k_group_size stride_byte_offset = MMA_ATOM_ROWS * swizzle # Each row is swizzle bytes wide, and we read mma_width_k rows at a time. - assert mn_large_tile == swizzle // element_bytewidth - desc_k_stride = mma_width_k * swizzle + assert mn_large_tile == 8 * swizzle // element_bitwidth + desc_k_strides = (mma_width_k * swizzle,) else: raise ValueError("MMA tiles must be contiguous") else: # Small tiles. @@ -146,21 +151,32 @@ def create_descriptor( if slower_tiling != MMA_ATOM_ROWS or faster_tiling != swizzle_elems: raise ValueError( f"Tiling should be ({MMA_ATOM_ROWS}, swizzle_elems) where" - f" swizzle_elems = swizzle // bytewidth(dtype) (= {swizzle} //" - f" {element_bytewidth} = {swizzle_elems}), but got ({slower_tiling}," + f" swizzle_elems = 8 * swizzle // bitwidth(dtype) (= 8 * {swizzle} //" + f" {element_bitwidth} = {swizzle_elems}), but got ({slower_tiling}," f" {faster_tiling})" ) - if k_tiling_stride == element_bytewidth and mn_tiling_stride == swizzle: + if k_tiling_stride == 1 and mn_tiling_stride * element_bitwidth == MMA_ATOM_ROWS * swizzle: fastest_dim = Dim.K leading_byte_offset = IGNORED # TC assumes K to be contiguous here. - stride_byte_offset = mn_tile_stride - desc_k_stride = MMA_BYTEWIDTH_K # K is contiguous. - elif k_tiling_stride == swizzle and mn_tiling_stride == element_bytewidth: + stride_byte_offset = to_byte_stride(mn_tile_stride) + if k_tiling == k_group_size: + desc_k_strides = (mma_bytewidth_k,) # K is contiguous. + elif k_group_size % k_tiling == 0: + desc_k_tiling = (k_tiling // mma_width_k,) + desc_k_strides = (MMA_ATOM_ROWS * swizzle, mma_bytewidth_k) + else: + if k_tiling < mma_width_k: + raise ValueError( + "K dimension tiling is smaller than the width of a single MMA" + " instruction. Increase swizzle." + ) + raise NotImplementedError(f"{k_group_size=} must be larger than {k_tiling=}") + elif k_tiling_stride * element_bitwidth == MMA_ATOM_ROWS * swizzle and mn_tiling_stride == 1: fastest_dim = Dim.MN - leading_byte_offset = mn_tile_stride - stride_byte_offset = k_tile_stride + leading_byte_offset = to_byte_stride(mn_tile_stride) + stride_byte_offset = to_byte_stride(k_tile_stride) k_tiles_per_mma = mma_width_k // MMA_ATOM_ROWS - desc_k_stride = k_tile_stride * k_tiles_per_mma + desc_k_strides = (to_byte_stride(k_tile_stride) * k_tiles_per_mma,) else: raise ValueError("MMA tiles must be contiguous") desc_base = encode_descriptor( @@ -168,17 +184,27 @@ def create_descriptor( leading_byte_offset=leading_byte_offset, stride_byte_offset=stride_byte_offset, swizzle=swizzle, + split_const=split_const, ) mn_tiles_per_group, rem = divmod(mn_group_size, mn_tiling) - assert not rem - mn_group_stride = mn_tile_stride * mn_tiles_per_group + if rem: + raise ValueError( + f"The M or N MMA instruction size was chosen to be {mn_group_size}," + " which is not a multiple of the tiling of the non-contracting" + f" dimension {mn_tiling}" + ) + mn_group_stride = to_byte_stride(mn_tile_stride) * mn_tiles_per_group k_tiles_per_group, rem = divmod(k_group_size, k_tiling) - assert not rem - k_group_stride = k_tile_stride * k_tiles_per_group + if rem: + raise ValueError( + f"The K MMA instruction size was chosen to be {k_group_size}, which is" + f" not a multiple of the tiling of the contracting dimension {k_tiling}" + ) + k_group_stride = to_byte_stride(k_tile_stride) * k_tiles_per_group return ( - (desc_base, desc_k_stride), + (desc_base, (desc_k_tiling, desc_k_strides)), (mn_group_stride, k_group_stride), fastest_dim, ) @@ -192,14 +218,21 @@ def encode_addr(x: int): def encode_descriptor( - memref_arg, + ref_arg, leading_byte_offset: int, stride_byte_offset: int, swizzle: int | mgpu_dialect.SwizzlingMode | None, const_init: int = 0, + split_const: bool = False, ): + i32 = ir.IntegerType.get_signless(32) i64 = ir.IntegerType.get_signless(64) - ptr_val = llvm.ptrtoint(i64, utils.memref_ptr(memref_arg, 3)) + if isinstance(ref_arg.type, ir.MemRefType): + ptr = utils.memref_ptr(ref_arg, 3) + else: + ptr = ref_arg + assert ptr.type == ir.Type.parse("!llvm.ptr<3>"), ptr.type + ptr_val = llvm.ptrtoint(i64, ptr) c = lambda x: arith.constant(i64, x) if swizzle is None or swizzle == mgpu_dialect.SwizzlingMode.kNoSwizzle: swizzle_encoding = 0 @@ -217,7 +250,18 @@ def encode_descriptor( const_init | (encode_addr(leading_byte_offset) << 16) | (encode_addr(stride_byte_offset) << 32) + | (swizzle_encoding << 62) ) - desc = llvm.or_(arith.shli(c(swizzle_encoding), c(62)), c(desc_const)) - desc = llvm.or_(encoded_base_addr, desc) - return desc + if split_const: + # The encoded base addr fits within a single 32-bit register. + return arith.trunci(i32, encoded_base_addr), desc_const + else: + # The desc_const frequently has the MSB set, leading to errors when trying + # to create ir.IntegerAttr through the MLIR python bindings... This should + # be easy enough for LLVM to constant fold away. + if desc_const >> 63: + desc_val = c(desc_const & 0xFFFFFFFF) + desc_val = llvm.or_(desc_val, arith.shli(c(desc_const >> 32), c(32))) + else: + desc_val = c(desc_const) + return llvm.or_(encoded_base_addr, desc_val) diff --git a/jax/experimental/mosaic/gpu/profiler.py b/jax/experimental/mosaic/gpu/profiler.py index 0c128f88d169..84c33946a15b 100644 --- a/jax/experimental/mosaic/gpu/profiler.py +++ b/jax/experimental/mosaic/gpu/profiler.py @@ -13,21 +13,24 @@ # limitations under the License. # ============================================================================== +from collections.abc import Callable import contextlib import itertools import json import math -from typing import Callable, ParamSpec, TypeVar +import os +import tempfile +from typing import Literal, ParamSpec, TypeVar, overload import warnings import jax -from jax._src.lib import xla_client +from jax._src import stages +from jax._src import util import jax.numpy as jnp from jaxlib.mlir import ir from jaxlib.mlir.dialects import arith from jaxlib.mlir.dialects import gpu from jaxlib.mlir.dialects import memref -from jaxlib.mlir.dialects import scf import numpy as np from .utils import * # noqa: F403 @@ -35,165 +38,153 @@ try: from jax._src.lib import mosaic_gpu as mosaic_gpu_lib except ImportError: - has_registrations = False -else: - # TODO(slebedev): Remove the if once the minimum jaxlib is 0.4.36. - has_registrations = hasattr(mosaic_gpu_lib._mosaic_gpu_ext, "registrations") - if has_registrations: - for name, handler in mosaic_gpu_lib._mosaic_gpu_ext.registrations(): - xla_client.register_custom_call_target( - name, handler, platform="CUDA", api_version=1 - ) + mosaic_gpu_lib = None # type: ignore[assignment] # ruff: noqa: F405 -# mypy: ignore-errors T = TypeVar("T") P = ParamSpec("P") -def _event_record(args, *, copy_before): - flat_args, treedef = jax.tree.flatten(args) - event, *flat_outs = jax.ffi.ffi_call( - "mgpu_event_record", - result_shape_dtypes=(jax.core.ShapedArray((), jnp.uint64), *flat_args), - input_output_aliases={i: i + 1 for i in range(len(flat_args))}, - )(*flat_args, copy_before=copy_before) - return event, treedef.unflatten(flat_outs) - - -def _event_elapsed(start_event, end_event): - return jax.ffi.ffi_call( - "mgpu_event_elapsed", - result_shape_dtypes=jax.core.ShapedArray((), jnp.float32), - )(start_event, end_event) - - -def _measure_events( - f: Callable[P, T], *args: P.args, **kwargs: P.kwargs -) -> tuple[T, float]: - if not has_registrations: - raise RuntimeError( - "This function requires jaxlib >=0.4.36 with CUDA support." - ) - - if not (args or kwargs): - # We require at least one argument and at least one output to ensure - # that there is a data dependency between `_event_record` calls in - # the resulting HLO program. - raise ValueError("Can only measure functions with arguments") - - @jax.jit - def run(*args, **kwargs): - start_event, (args, kwargs) = _event_record( - (args, kwargs), copy_before=True - ) - end_event, outs = _event_record(f(*args, **kwargs), copy_before=False) - if jax.tree.structure(outs).num_leaves == 0: - raise ValueError("Can only measure functions with at least one output") - return outs, _event_elapsed(start_event, end_event) - - jax.block_until_ready(run(*args, **kwargs)) # Warmup. - outs, elapsed = run(*args, **kwargs) - return outs, float(elapsed) - - -def _measure_cupti(f, aggregate): - def run(*args, **kwargs): - mosaic_gpu_lib._mosaic_gpu_ext._cupti_init() - try: - results = jax.block_until_ready(jax.jit(f)(*args, **kwargs)) - finally: - timings = mosaic_gpu_lib._mosaic_gpu_ext._cupti_get_timings() - return results, timings - - def wrapper(*args, **kwargs): - run(*args, **kwargs) # Warmup. - results, timings = run(*args, **kwargs) - if not timings: - return results, None - elif aggregate: - return results, sum(item[1] for item in timings) - else: - return results, timings - return wrapper - -def measure(f: Callable, *, mode: str = "events", aggregate: bool = True -) -> Callable: - """Sets up a function ``f`` for profiling on GPU. - - ``measure`` is a higher-order function that augments the argument ``f`` to - return GPU runtime in milliseconds, in addition to its proper outputs. +@dataclasses.dataclass(frozen=True, kw_only=True) +class Cupti: + """CUPTI-based profiler.""" + + # If `True`, detach CUPTI from the process after measurement. + finalize: bool = True + + def measure( + self, f, *, aggregate: bool = True, iterations: int = 1, + ): + if not isinstance(f, (stages.Wrapped, stages.Compiled)): + f = jax.jit(f) + + def wrapper(*args, **kwargs): + if mosaic_gpu_lib is None: + raise RuntimeError("CUPTI profiling is not supported on this platform") + + jax.block_until_ready(f(*args, **kwargs)) # Warmup. + ext = mosaic_gpu_lib._mosaic_gpu_ext + ext._cupti_init() + try: + all_results = [f(*args, **kwargs) for _ in range(iterations)] + for r in all_results: + jax.block_until_ready(r) + results = all_results[0] + finally: + timings = ext._cupti_get_timings(self.finalize) + if not timings: + return results, None + + if len(timings) % iterations != 0: + raise RuntimeError( + "The number of kernel launches is not divisible by the number of" + " iterations" + ) + kernels_per_iter = len(timings) // iterations + iter_timings = util.split_list( + timings, [kernels_per_iter] * (iterations - 1) + ) + for kernel_idx, (kernel_name, _) in enumerate(iter_timings[0]): + for i in range(1, iterations): + if iter_timings[i][kernel_idx][0] != kernel_name: + raise RuntimeError("Kernel names are not consistent across iterations") + + if aggregate: + iter_timings = [ + sum(item[1] for item in timings) for timings in iter_timings + ] + + return results, iter_timings[0] if len(iter_timings) == 1 else iter_timings + + return wrapper + +@overload +def measure( + f: Callable[P, T], + *, + aggregate: Literal[True] = ..., + iterations: Literal[1] = ..., +) -> Callable[P, tuple[T, float | None]]: + ... + +@overload +def measure( + f: Callable[P, T], + *, + aggregate: Literal[False] = ..., + iterations: Literal[1] = ..., +) -> Callable[P, tuple[T, list[tuple[str, float]] | None]]: + ... + +@overload +def measure( + f: Callable[P, T], + *, + aggregate: Literal[True] = ..., + iterations: int = ..., +) -> Callable[P, tuple[T, list[float] | None]]: + ... + +@overload +def measure( + f: Callable[P, T], + *, + aggregate: Literal[False] = ..., + iterations: int = ..., +) -> Callable[P, tuple[T, list[list[tuple[str, float]]] | None]]: + ... + + +def measure( + f, *, aggregate: bool = True, iterations: int = 1, +): + """Measures the GPU runtime of a function using CUPTI. + + ``measure`` is a higher-order function that wraps a function ``f`` to + return GPU runtime in milliseconds, in addition to its regular outputs. Args: - f: The function to measure. It must accept at least one argument and return - at least one output to be measurable. - mode: The mode of operation. Possible values are: - - - "cupti", for CUPTI-based profiling. - - "events", for CUDA events-based profiling. - - The two modes use different measurement methodologies and should not be - treated as interchangeable backends. See the Notes section for important - discussion. + f: The function to measure. aggregate: Whether to report an aggregate runtime. When ``False`` (only supported by ``mode="cupti"``), the per-kernel timings are returned as a list of tuples ``(, )``. + iterations: How many times to run the function. Only supported by + ``mode="cupti"``. When greater than 1, the return type will become a list + of measurements. Returns: - A new function ``g`` that returns the measured GPU runtime as its last - additional output. Otherwise ``g`` accepts the same inputs and returns the - same outputs as ``f``. + A function that accepts the same inputs as ``f`` and returns + ``(f_outputs, timings)``, where ``f_outputs`` are the outputs of ``f``, + and ``timings`` is either a float or a list of tuples, depending on + ``aggregate``. If no kernels are launched, ``timings`` is ``None``. Notes: `CUPTI (CUDA Profiling Tools Interface) - `_ is a high-accuracy, - high-precision profiling and tracing API, used in particular by Nsight - Systems and Nsight Compute. When using ``measure`` with ``mode="cupti"``, - device (GPU) execution runtimes are recorded for each kernel launched - during the execution of the function. In that mode, setting - ``aggregate=True`` will sum the individual kernel runtimes to arrive at an - aggregate measurement. The "gaps" between the kernels when the device is - idle are not included in the aggregate. - - The CUPTI API only allows a single "subscriber". This means that the - CUPTI-based profiler will fail when the program is run using tools that - make use of CUPTI, such as CUDA-GDB, Compute Sanitizer, Nsight Systems, or - Nsight Compute. - - ``mode="events"`` uses a different approach: a CUDA event is recorded - before and after the function ``f`` is executed. The reported runtime is - the time elapsed between the two events. In particular, included in the - measurement are: - - - any potential "gaps" between the kernels when the device is idle - - any potential "gaps" between the "before" event and the start of the - first kernel, or between the end of the last kernel and the "after" event - - In an attempt to minimize the second effect, internally the events-based - implementation may execute ``f`` more than once to "warm up" and exclude - compilation time from the measurement. - """ - match mode: - case "cupti": - return _measure_cupti(f, aggregate) - case "events": - if not aggregate: - raise ValueError(f"{aggregate=} is not supported with {mode=}") - def measure_events_wrapper(*args, **kwargs): - return _measure_events(f, *args, **kwargs) - return measure_events_wrapper - case _: - raise ValueError(f"Unrecognized profiler mode {mode}") + `_ is a high-accuracy profiling + API used by Nsight Systems and Nsight Compute. The CUPTI API only allows a + single subscriber, so ``measure`` cannot be used with other CUPTI-based + tools like CUDA-GDB, Compute Sanitizer, Nsight Systems, or Nsight + Compute. + """ # fmt: skip + if iterations < 1: + raise ValueError(f"{iterations=} must be positive") + return Cupti().measure(f, aggregate=aggregate, iterations=iterations) class ProfilerSpec: ENTER = 0 EXIT = 1 << 31 - def __init__(self, entries_per_warpgroup: int): + def __init__(self, entries_per_warpgroup: int, dump_path: str = "sponge"): self.entries_per_warpgroup = entries_per_warpgroup - self.interned_names = {} + self.interned_names: dict[str, int] = {} + if dump_path == "sponge": + self.dump_path = os.getenv( + "TEST_UNDECLARED_OUTPUTS_DIR", tempfile.gettempdir() + ) + else: + self.dump_path = dump_path def _num_warpgroups( self, grid: tuple[int, ...], block: tuple[int, ...] @@ -243,14 +234,25 @@ def dump(self, buffer, f, grid: tuple[int, ...], block: tuple[int, ...]): ) start_times = entries[..., 0] sm_ids = entries[..., 1] - entries_used = entries[..., 2] - if np.any(entries_used > self.entries_per_warpgroup - 2): + traces_used = entries[..., 2] + entries_used = traces_used + 3 + if np.any(entries_used > self.entries_per_warpgroup): raise RuntimeError("Insufficient space to capture a full trace") traces = entries[..., 3:] + + # Estimate the overhead of profiling. + time_events = traces[:, :, 1::2] + valid_times_mask = np.arange(traces.shape[-1])[1::2] < traces_used[..., None] + # 12 cycles is a ballpark estimate for H100 + profiling_overhead = (time_events[:, :, 1:] - time_events[:, :, :-1]).min( + where=valid_times_mask[:, :, 1:], initial=12 + ) + profiling_overhead = max(0, profiling_overhead - 1) + unintern = {v: k for k, v in self.interned_names.items()} events = [] for block_idx, wg_idx in np.ndindex(num_blocks, warpgroups_per_block): - valid_entries = entries_used[block_idx, wg_idx] - 3 + valid_entries = traces_used[block_idx, wg_idx] local_clock_offset = None assert valid_entries % 2 == 0, valid_entries start_time = start_times[block_idx, wg_idx] @@ -262,7 +264,7 @@ def dump(self, buffer, f, grid: tuple[int, ...], block: tuple[int, ...]): if local_clock_offset is None: local_clock_offset = time time -= local_clock_offset - time -= i * 6 # Account for the overhead of profiling. + time -= (i // 2) * profiling_overhead # Account for the overhead of profiling. if time < 0: break # Detect a timer wraparound name_id = tag @@ -295,60 +297,105 @@ def dump(self, buffer, f, grid: tuple[int, ...], block: tuple[int, ...]): return json.dump({"displayTimeUnit": "ns", "traceEvents": flat_events}, f) +@dataclasses.dataclass(frozen=True) +class _ProfilerCtx: + """Set of IR values referenced by the profiler logic. + + The profiler logic is implemented using `CustomPrimitiveOp` which requires + that all IR values referenced in its body be passed as operands to the op. + """ + + start: ir.Value + is_profiling_thread: ir.Value + smem_buffer: ir.Value + gmem_buffer: ir.Value + offset: ir.Value + + class OnDeviceProfiler: - def __init__(self, spec: ProfilerSpec, smem_buffer: ir.Value, gmem_buffer: ir.Value): - self.spec = spec - self.start = globaltimer("low") + def __init__( + self, + spec: ProfilerSpec, + smem_buffer: ir.Value, + gmem_buffer: ir.Value, + wrap_in_custom_primitive: bool, + ): i32 = ir.IntegerType.get_signless(32) index = ir.IndexType.get() + self.spec = spec self.entries_per_wg = spec.entries_per_warpgroup + self.wrap_in_custom_primitive = wrap_in_custom_primitive wg_idx = warpgroup_idx(sync=False) - self.smem_buffer = memref_slice( - smem_buffer, - ds( - arith.index_cast( - index, arith.muli(wg_idx, c(self.entries_per_wg, i32)) - ), - self.entries_per_wg, - ), + wg_offset = arith.index_cast( + index, arith.muli(wg_idx, c(self.entries_per_wg, i32)) ) - self.smem_buffer_ptr = memref_ptr(self.smem_buffer, memory_space=3) - self.gmem_buffer = gmem_buffer - self.is_profiling_thread = arith.cmpi( + smem_buffer = memref_slice(smem_buffer, ds(wg_offset, self.entries_per_wg)) + is_profiling_thread = arith.cmpi( arith.CmpIPredicate.eq, arith.remui(thread_idx(), c(WARPGROUP_SIZE, i32)), c(0, i32), ) # Hopefully mem2reg will remove the allocation. - self.offset = memref.alloca(ir.MemRefType.get((), i32), [], []) - memref.store(c(0, i32), self.offset, []) + offset = memref.alloca(ir.MemRefType.get((), index), [], []) + memref.store(c(0, index), offset, []) + self.ctx = _ProfilerCtx( + start=globaltimer("low"), + is_profiling_thread=is_profiling_thread, + smem_buffer=smem_buffer, + gmem_buffer=gmem_buffer, + offset=offset, + ) + + @contextlib.contextmanager + def _profiler_ctx(self): + if not self.wrap_in_custom_primitive: + yield self.ctx + return + + def fields(obj) -> list[ir.Value]: + return [getattr(obj, field.name) for field in dataclasses.fields(obj)] + + op = dialect.CustomPrimitiveOp( + result=[], + operands_=fields(self.ctx), + in_layouts=[], + in_transforms=[ir.ArrayAttr.get([])], + out_layouts=[], + ) + args_ty = [arg.type for arg in op.operands_] + block = op.body.blocks.append(*args_ty) + with ir.InsertionPoint(block): + yield _ProfilerCtx(*block.arguments) + dialect.return_([]) @contextlib.contextmanager def record(self, name: str): i32 = ir.IntegerType.get_signless(32) + index = ir.IndexType.get() name_id = self.spec.intern_name(name) def store(modifier): - cur = memref.load(self.offset, []) - i64 = ir.IntegerType.get_signless(64) - base_addr = arith.addi( - llvm.ptrtoint(i64, self.smem_buffer_ptr), - arith.extui(i64, arith.muli(cur, c(4, i32))), - ) - llvm.inline_asm( - ir.Type.parse("!llvm.void"), - [self.is_profiling_thread, base_addr, c(modifier | name_id, i32)], - """ - @$0 st.shared.v2.u32 [$1], {$2, %clock}; - """, - "b,l,r", - has_side_effects=True, - ) - memref.store( - arith.addi(cur, c(2, cur.type)), - self.offset, - [], - ) + with self._profiler_ctx() as ctx: + # smem_buffer[offset] = modifier | name_id + # smem_buffer[offset + 1] = %clock + # offset += 2 + offset = memref.load(ctx.offset, []) + base_ref = memref_slice(ctx.smem_buffer, offset) + base_ptr = memref_ptr(base_ref, memory_space=3) + i64 = ir.IntegerType.get_signless(64) + base_addr = llvm.ptrtoint(i64, base_ptr) + llvm.inline_asm( + ir.Type.parse("!llvm.void"), + [ctx.is_profiling_thread, base_addr, c(modifier | name_id, i32)], + """ + @$0 st.shared.v2.u32 [$1], {$2, %clock}; + """, + "b,l,r", + has_side_effects=True, + ) + new_offset = arith.addi(offset, c(2, index)) + memref.store(new_offset, ctx.offset, []) + store(ProfilerSpec.ENTER) yield store(ProfilerSpec.EXIT) @@ -357,50 +404,32 @@ def finalize(self, grid: tuple[int, ...], block: tuple[int, ...]): index = ir.IndexType.get() i32 = ir.IntegerType.get_signless(32) - gpu.barrier() # Make sure all warpgroups are done. + with self._profiler_ctx() as ctx: + gpu.barrier() # Make sure all warpgroups are done. - block_idx = c(0, index) - for dim in gpu.Dimension: # pytype: disable=wrong-arg-types - block_idx = arith.addi( - arith.muli(block_idx, gpu.grid_dim(dim)), gpu.block_id(dim) - ) - wg_idx = warpgroup_idx(sync=False) - wg_per_block = math.prod(block) // WARPGROUP_SIZE - global_wg_idx = arith.addi( - arith.muli(block_idx, c(wg_per_block, index)), - arith.index_cast(index, wg_idx), - ) - start_offset = arith.muli(global_wg_idx, c(self.entries_per_wg, index)) - wg_gmem_buffer = memref.subview( - self.gmem_buffer, [start_offset], [self.entries_per_wg], [1], - result_type=ir.Type.parse( - f"memref<{self.entries_per_wg}xi32, strided<[1], offset: ?>>" - ), - ) - thread_in_wg = arith.remui(thread_idx(), c(128, i32)) - if_first = scf.IfOp( - arith.cmpi(arith.CmpIPredicate.eq, thread_in_wg, c(0, i32)) - ) - with ir.InsertionPoint(if_first.then_block): - memref.store(self.start, wg_gmem_buffer, [c(0, index)]) - memref.store(smid(), wg_gmem_buffer, [c(1, index)]) - memref.store( - arith.addi(memref.load(self.offset, []), c(3, i32)), - wg_gmem_buffer, - [c(2, index)], + block_idx = c(0, index) + for dim in gpu.Dimension: # pytype: disable=wrong-arg-types + block_idx = arith.addi( + arith.muli(block_idx, gpu.grid_dim(dim)), gpu.block_id(dim) + ) + wg_idx = warpgroup_idx(sync=False) + wg_per_block = math.prod(block) // WARPGROUP_SIZE + global_wg_idx = arith.addi( + arith.muli(block_idx, c(wg_per_block, index)), + arith.index_cast(index, wg_idx), ) - - for_op = scf.ForOp( - c(0, index), - c(self.entries_per_wg - 3, index), - c(1, index), + start_offset = arith.muli(global_wg_idx, c(self.entries_per_wg, index)) + wg_gmem_buffer = memref_slice( + ctx.gmem_buffer, ds(start_offset, self.entries_per_wg) ) - with ir.InsertionPoint(for_op.body): - x = memref.load(self.smem_buffer, [for_op.induction_variable]) - memref.store( - x, - wg_gmem_buffer, - [arith.addi(for_op.induction_variable, c(3, index))], + with when(ctx.is_profiling_thread): + memref.store(ctx.start, wg_gmem_buffer, [c(0, index)]) + memref.store(smid(), wg_gmem_buffer, [c(1, index)]) + num_traces = arith.index_cast(i32, memref.load(ctx.offset, [])) + memref.store(num_traces, wg_gmem_buffer, [c(2, index)]) + traces = vector.load( + ir.VectorType.get((self.entries_per_wg - 3,), i32), + ctx.smem_buffer, + [c(0, index)], ) - scf.yield_([]) - scf.yield_([]) + vector.store(traces, wg_gmem_buffer, [c(3, index)]) diff --git a/jax/experimental/mosaic/gpu/tcgen05.py b/jax/experimental/mosaic/gpu/tcgen05.py index 3330500cd6dc..7703799675df 100644 --- a/jax/experimental/mosaic/gpu/tcgen05.py +++ b/jax/experimental/mosaic/gpu/tcgen05.py @@ -16,25 +16,33 @@ from __future__ import annotations import dataclasses +import functools +import itertools import math +from typing import Any, Callable, Iterator, cast from jaxlib.mlir import ir from jaxlib.mlir.dialects import arith from jaxlib.mlir.dialects import llvm from jaxlib.mlir.dialects import memref +from jaxlib.mlir.dialects import nvvm import numpy as np -from . import utils from . import fragmented_array as fa from . import mma_utils +from . import utils from .launch_context import LaunchContext -# MyPy does a terrible job with the MLIR API. -# mypy: ignore-errors - TMEM_ROWS = 128 +TMEM_MAX_COLS = 512 TCGEN05_SMEM_DESCRIPTOR_BIT = 1 << 46 +LAYOUT = fa.TCGEN05_LAYOUT +TRANSPOSED_LAYOUT = fa.TCGEN05_TRANSPOSED_LAYOUT +ROW_LAYOUT = fa.TCGEN05_ROW_LAYOUT +COL_LAYOUT = fa.TCGEN05_COL_LAYOUT +TMEM_NATIVE_LAYOUT = fa.TMEM_NATIVE_LAYOUT + def create_instr_descriptor( m: int, @@ -43,21 +51,46 @@ def create_instr_descriptor( input_dtype, transpose_a: bool = False, transpose_b: bool = False, -): - f32 = ir.F32Type.get() - bf16 = ir.BF16Type.get() + sparsity_selector: int | None = None, +) -> ir.Value: f16 = ir.F16Type.get() - if input_dtype not in {f16, bf16}: - raise NotImplementedError("Only float16 and bfloat16 inputs supported") - if acc_dtype not in {f32, f16}: - raise NotImplementedError("Only float32 and float16 accumulators supported") + f32 = ir.F32Type.get() + i32 = ir.IntegerType.get_signless(32) desc = 0 - # We ignore sparsity in bits 0-3 - desc |= (acc_dtype == f32) << 4 # D dtype, bits 4-5 + if sparsity_selector is not None: + assert 0 <= sparsity_selector < 3 + desc |= sparsity_selector + desc |= 1 << 2 # Enable sparsity + if acc_dtype == f16: + d_type_val = 0 + elif acc_dtype == f32: + d_type_val = 1 + elif acc_dtype == i32: + d_type_val = 2 + else: + raise NotImplementedError(f"Unsupported accumulator dtype: {acc_dtype}") + desc |= (d_type_val << 4) # D type, bits 4-5 # Bit 6 is reserved - desc |= (input_dtype == bf16) << 7 # A dtype, bits 7-9 - desc |= (input_dtype == bf16) << 10 # B dtype, bits 10-12 + if input_dtype == f16: + assert acc_dtype in {f16, f32} + ab_type_val = 0 + elif input_dtype == ir.BF16Type.get(): + assert acc_dtype == f32 + ab_type_val = 1 + elif input_dtype == ir.Float8E4M3FNType.get(): + assert acc_dtype in {f16, f32} + ab_type_val = 0 + elif input_dtype == ir.Float8E5M2Type.get(): + assert acc_dtype in {f16, f32} + ab_type_val = 1 + elif input_dtype == ir.IntegerType.get_signless(8): # Only s8 for now. + assert acc_dtype == i32 + ab_type_val = 1 + else: + raise NotImplementedError(f"Unsupported input dtype: {input_dtype}") + desc |= (ab_type_val << 7) # A dtype, bits 7-9 + desc |= (ab_type_val << 10) # B dtype, bits 10-12 # We ignore negate bits 13-14 desc |= transpose_a << 15 # Transpose A desc |= transpose_b << 16 # Transpose B @@ -73,34 +106,123 @@ def create_instr_descriptor( return arith.constant(ir.IntegerType.get_signless(32), desc) +def _create_scaled_instr_descriptor( + get_input_encoding: Callable[[ir.Type], int], + m: int, + n: int, + a_type: ir.Type, + b_type: ir.Type, + a_scale_idx: int, + b_scale_idx: int, + transpose_a: bool, + transpose_b: bool, + scale_type: ir.Type, +) -> ir.Value: + desc = 0 + # Bits 0, 1 are reserved + # We ignore sparsity (bit 2) + # Bit 3 is reserved + assert 0 <= b_scale_idx < 4 + desc |= b_scale_idx << 4 # B scale factor data ID, bits 4-5 + # Bit 6 is reserved + desc |= get_input_encoding(a_type) << 7 # A dtype, bits 7-9 + desc |= get_input_encoding(b_type) << 10 # B dtype, bits 10-12 + # We ignore negate bits 13-14 + desc |= transpose_a << 15 # Transpose A + desc |= transpose_b << 16 # Transpose B + if n % 8 or n > 256: + raise ValueError(f"N must be a multiple of 8 and <= 256, got: {n}") + desc |= (n >> 3) << 17 # N, bits 17-22 + if scale_type == ir.Float8E8M0FNUType.get(): + scale_encoding = 1 + elif scale_type == ir.Float8E4M3FNType.get(): + scale_encoding = 0 + else: + raise NotImplementedError(f"Unsupported scale type: {scale_type}") + desc |= scale_encoding << 23 # Scale matrix type + # Bits 24-26 are reserved + if m % 128 or m > 256: + raise ValueError(f"M must be a multiple of 16 and <= 256, got: {m}") + desc |= (m >> 7) << 27 # M >> 7, bits 27-28 + desc |= a_scale_idx << 29 # A scale factor data ID, bits 29-30 + # Bit 31 is reserved + return arith.constant(ir.IntegerType.get_signless(32), desc) + + +def create_scaled_f8f6f4_instr_descriptor(*args, **kwargs) -> ir.Value: + def get_input_encoding(ty): + if ty == ir.Float8E4M3FNType.get(): + return 0 + elif ty == ir.Float8E5M2Type.get(): + return 1 + else: + raise NotImplementedError(f"Unsupported input dtype: {ty}") + return _create_scaled_instr_descriptor(get_input_encoding, *args, **kwargs) + + +def create_scaled_f4_instr_descriptor(*args, **kwargs) -> ir.Value: + def get_input_encoding(ty): + if ty == ir.Float4E2M1FNType.get(): + return 1 + else: + raise NotImplementedError(f"Unsupported input dtype: {ty}") + return _create_scaled_instr_descriptor(get_input_encoding, *args, **kwargs) + + def mma( d: TMEMRef, - a: ir.Value, + a: ir.Value | TMEMRef, b: ir.Value, *, a_swizzle: int = 128, b_swizzle: int = 128, + a_scale: TMEMRef | None = None, + b_scale: TMEMRef | None = None, + a_sparse_metadata: TMEMRef | None = None, accumulate: ir.Value | bool = True, collective: bool = False, -): +) -> None: if a_swizzle == 16 or b_swizzle == 16: raise NotImplementedError("No swizzle is not supported") i32 = ir.IntegerType.get_signless(32) i64 = ir.IntegerType.get_signless(64) if isinstance(accumulate, bool): accumulate = arith.constant(ir.IntegerType.get_signless(1), accumulate) - if a_swizzle != b_swizzle: - raise NotImplementedError(f"{a_swizzle=} != {b_swizzle=}") - swizzle = a_swizzle num_cta = 2 if collective else 1 + if (is_scaled := a_scale is not None) != (b_scale is not None): + raise ValueError("Either none or both scales should be provided") + is_sparse = a_sparse_metadata is not None + if is_scaled and is_sparse: + raise NotImplementedError("Block-scaled sparse matmuls unsupported") # Step 1. Establish the shape and element type of the operation. - if not ir.MemRefType.isinstance(a.type): - raise ValueError(f"A must be a memref, got {a.type}") - if not ir.MemRefType.isinstance(b.type): + if not isinstance(b.type, ir.MemRefType): raise ValueError(f"B must be a memref, got: {b.type}") (k, n), element_type = mma_utils.tiled_memref_shape(b) - (m, k2), element_type2 = mma_utils.tiled_memref_shape(a) + if isinstance(a, TMEMRef): + m, k2 = a.shape + element_type2 = a.dtype + if is_scaled: + raise NotImplementedError( + "A in TMEM unsupported for block-scaled matmuls" + ) + if m != 128: + raise NotImplementedError(f"Only M=128 is supported for MMA with A in TMEM, but got M={m}") + # Watch out: this layout must be consistent with D's layout (up to packing). + expected_packing = 32 // utils.bitwidth(element_type) + expected_layout = _infer_tmem_layout( + a.shape, collective, packing=expected_packing + ) + if a.layout != expected_layout: + raise ValueError( + f"A layout mismatch: expected {expected_layout}, got {a.layout}" + ) + else: + if not isinstance(a.type, ir.MemRefType): + raise ValueError(f"A must be a memref, got {a.type}") + (m, k2), element_type2 = mma_utils.tiled_memref_shape(a) + if is_sparse: + k2 *= 2 if k != k2: raise ValueError( "MMA requires A and B to have the same contraction dimension (K)," @@ -115,35 +237,139 @@ def mma( raise ValueError( f"Accumulator shape mismatch: expected {(m, n * num_cta)}, got {d.shape}" ) - if d.layout != (expected_layout := _infer_tmem_layout(d.shape, collective)): - raise ValueError( - f"Accumulator layout mismatch: expected {expected_layout}, got {d.layout}" - ) + if m == 128: + if d.layout != (expected_d_layout := tmem_default_layout(packing=1)): + raise ValueError( + f"Accumulator layout mismatch: expected {expected_d_layout}, got {d.layout}" + ) + n_lane_groups = 1 + elif m == 64: + if is_scaled: + raise NotImplementedError("MMA with block scaling is not supported for M=64") + if is_sparse: + raise NotImplementedError("Sparse MMA not supported for M=64") + # Watch out: this layout must be consistent with A's layout (up to packing). + # 2CTA M=128 instruction uses a different TMEM layout than 1CTA M=64. + expected_d_layout = _infer_tmem_layout(d.shape, collective, packing=1) + if d.layout != expected_d_layout: + raise ValueError( + f"Accumulator layout mismatch: expected {expected_d_layout}, got {d.layout}" + ) + if collective: + n_lane_groups = 1 + else: + n_lane_groups = 2 + # We can't split N into groups if we would partition it below the tile size. + # TODO: We only need to check this if N is the minormost dim in B. + if 8 * b_swizzle // utils.bitwidth(element_type) > n // n_lane_groups: + raise ValueError( + f"Swizzle={b_swizzle} is too big for MMA with M=64. Try" + " lowering it." + ) + else: + raise ValueError(f"Only M=128 and M=64 are supported for MMA, but got M={m}") f32 = ir.F32Type.get() + f16 = ir.F16Type.get() + s32 = ir.IntegerType.get_signless(32) if element_type == f32 or element_type == ir.BF16Type.get(): + if element_type == f32 and is_sparse: + raise NotImplementedError("Sparse MMA unsupported for f32") + if is_scaled: + raise ValueError( + f"MMA with element type {element_type} does not support block scaling" + ) if d.dtype != f32: raise ValueError( f"MMA with element type {element_type} only supports accumulators" f" of type f32, but got: {d.dtype}" ) - elif element_type == ir.F16Type.get(): - if d.dtype != element_type and d.dtype != f32: + elif element_type == f16: + if is_scaled: + raise ValueError( + f"MMA with element type {element_type} does not support block scaling" + ) + if d.dtype != f16 and d.dtype != f32: + raise ValueError( + f"MMA with element type {element_type} only supports accumulators of" + f" type f32 or f16, but got: {d.dtype}" + ) + elif any( + isinstance(element_type, t) + for t in {ir.Float8E5M2Type, ir.Float8E4M3FNType} + ): + if d.dtype != f16 and d.dtype != f32: + raise ValueError( + f"MMA with element type {element_type} only supports accumulators of" + f" type f32 or f16, but got: {d.dtype}" + ) + if is_scaled and d.dtype != f32: + raise ValueError( + f"Block-scaled MMA with element type {element_type} only supports f32" + f" accumulators, but got: {d.dtype}" + ) + elif any(isinstance(element_type, t) for t in {ir.Float4E2M1FNType}): + if is_sparse: + raise NotImplementedError("Sparse MMA unsupported for f4e2m1fn") + if not is_scaled: raise ValueError( - "MMA with element type f16 only supports accumulators of type f32" - f" or f16, but got: {d.dtype}" + f"MMA with element type {element_type} only supports block scaling" ) + if d.dtype != f32: + raise ValueError( + f"Block-scaled MMA with element type {element_type} only supports f32" + f" accumulators, but got: {d.dtype}" + ) + elif element_type == ir.IntegerType.get_signless(8): + if is_scaled: + raise ValueError( + f"MMA with element type {element_type} does not support block scaling" + ) + if d.dtype != s32: + raise ValueError( + "MMA with element type s8 only supports s32 accumulators, but got:" + f" {d.dtype}" + ) + else: + raise NotImplementedError(f"Unsupported element type: {element_type}") # Step 2. Decide on the instruction shapes we'll use. Note that with swizzles, - # instructions must be issued in groups of the same width as the swizzle. - m_group_elems = d.layout.elements_in_tile[0] - if m_group_elems != 128: - raise NotImplementedError("Only 128-row accumulators supported for now") - k_group_elems = swizzle // utils.bytewidth(element_type) - if n % 8: - raise ValueError(f"N must be a multiple of 8, got: {n}") - elif n > 256 and n != 512: - raise ValueError("Only N below 256 or N=512 are supported") - n_group_elems = min(n, 256 // num_cta) + # instructions must be issued in groups that are a multiple of swizzle. + m_group_elems = m # We have already verified M is supported above. + k_group_elems = 8 * max(a_swizzle * (1 + is_sparse), b_swizzle) // utils.bitwidth(element_type) + if is_sparse and k_group_elems < 64: + # This is a limitation of the implementation below. We could relax it if we + # ever need to support k=32. + k_group_elems = 64 + scale_block: int | None = None + if is_scaled: + scale_block = 32 if a_scale.dtype == ir.Float8E8M0FNUType.get() else 16 # type: ignore + k_group_elems = max(k_group_elems, 4 * scale_block) + required_multiple = 16 if collective else 8 + mode_name = "2 CTA" if collective else "1 CTA" + if d.dtype == s32: + required_multiple *= 2 + mode_name += " integer" + if n_lane_groups > 1: + mode_name += f" with {n_lane_groups} lane groups" + if (n // n_lane_groups) % required_multiple != 0: + raise ValueError( + f"In {mode_name} MMA, N must be a multiple of {required_multiple}," + f" got N={n}" + ) + if (is_sparse or is_scaled) and n.bit_count() != 1: + raise NotImplementedError( + "Only N that is power of 2 supported for sparse and block-scaled MMA," + f" but got N={n}" + ) + if n > 256 and n.bit_count() != 1: + raise NotImplementedError(f"The only supported N > 256, is 512, but got N={n}") + # TODO: We could relax those constraints if we have multiple n_lane_groups, + # since we will be unrolling the instructions anyway. + if collective and n > 128: + raise ValueError("Only N <= 128 are supported for collective MMA") + elif n > 512: + raise ValueError("Only N <= 512 are supported for MMA") + n_group_elems = min(n // n_lane_groups, 256 // num_cta) if m % m_group_elems: raise ValueError(f"M must be a multiple of {m_group_elems}, got: {m}") if k % k_group_elems: @@ -153,114 +379,366 @@ def mma( m_groups = m // m_group_elems k_groups = k // k_group_elems n_groups = n // n_group_elems - # TODO(apaszke): Require users to bitcast input refs to tf32 before WGMMA. - wgmma_element_type = ( + # TODO(apaszke): Require users to bitcast input refs to tf32 before MMA. + mma_element_type = ( ir.FloatTF32Type.get() if element_type == ir.F32Type.get() else element_type ) + # Check that the shapes and element types are correct for block scaling. + scale_element_type = None + if is_scaled: + assert m == 128 # Checked above. + if n % 32: + raise ValueError( + f"MMA with block scaling requires N to be divisible by 32, got: {n}" + ) + assert a_scale is not None and b_scale is not None + scale_element_type = a_scale.dtype + if ( + a_scale.dtype != ir.Float8E8M0FNUType.get() + and a_scale.dtype != ir.Float8E4M3FNType.get() + ): + raise ValueError( + f"A scale dtype mismatch: expected f8e8m0fnu or f8e4m3fn, got {a_scale.dtype}" + ) + if b_scale.dtype != a_scale.dtype: + raise ValueError( + f"B scale dtype mismatch: expected {a_scale.dtype} (same as A), got" + f" {b_scale.dtype}" + ) + if a_scale.shape != (m, k // scale_block): + raise ValueError( + f"A scale shape mismatch: expected ({m}, {k // scale_block}), got" + f" {a_scale.shape}" + ) + if b_scale.shape != (n * num_cta, k // scale_block): + raise ValueError( + f"B scale shape mismatch: expected ({n}, {k // scale_block}), got" + f" {b_scale.shape}" + ) + if is_sparse: + a_sparse_metadata = cast(TMEMRef, a_sparse_metadata) + if n % 32: + raise ValueError(f"Sparse MMA requires N to be divisible by 32, got: {n}") + if a_sparse_metadata.shape != (m, k // 2): + raise ValueError( + f"A sparse metadata shape mismatch: expected {(m, k // 2)}, got" + f" {a_sparse_metadata.shape}" + ) + if a_sparse_metadata.dtype != ir.IntegerType.get_signless(2): + raise ValueError( + "A sparse metadata dtype mismatch: expected i2, got" + f" {a_sparse_metadata.dtype}" + ) + # Step 3. Compute the operand descriptors. + if not isinstance(a, TMEMRef): + # Both dense and sparse matmul consume A with a K bytewidth of 32, only + # the group size is halved when it's sparse. + ( + (a_desc_base, a_k_instr_strides), + (a_m_group_stride, a_k_group_stride), + a_fastest, + ) = mma_utils.create_descriptor( + a, + swizzle=a_swizzle, + group_size=(m_group_elems, k_group_elems // (1 + is_sparse)), + logical_k_major=False, + mma_bytewidth_k=32, + split_const=True, + ) + else: + a_fastest = mma_utils.Dim.K + a_k_instr_strides = None + a_m_group_stride = a_k_group_stride = a_desc_base = None ( - (a_desc_base, a_k_instr_stride), - (a_m_group_stride, a_k_group_stride), - a_fastest, - ) = mma_utils.create_descriptor( - a, - swizzle=swizzle, - group_size=(m_group_elems, k_group_elems), - logical_k_major=False, - ) - ( - (b_desc_base, b_k_instr_stride), + (b_desc_base, b_k_instr_strides), (b_n_group_stride, b_k_group_stride), b_fastest, ) = mma_utils.create_descriptor( b, - swizzle=swizzle, + swizzle=b_swizzle, group_size=(k_group_elems, n_group_elems), logical_k_major=True, + mma_bytewidth_k=64 if is_sparse else 32, + split_const=True, ) + if is_scaled and utils.bitwidth(mma_element_type) == 4: + if a_fastest != mma_utils.Dim.K: + raise ValueError( + "4-bit block scaled MMA only supports K-fastest operands, but A is M-fastest" + ) + if b_fastest != mma_utils.Dim.K: + raise ValueError( + "4-bit block scaled MMA only supports K-fastest operands, but B is N-fastest" + ) + if is_sparse: + if b_swizzle == 32 and b_fastest == mma_utils.Dim.K: + raise NotImplementedError( + "B tiling too small. Increase swizzle or transpose the input." + ) + # Step 4. Issue the instructions. true = arith.constant(ir.IntegerType.get_signless(1), 1) n_collective_group_elems = n_group_elems * num_cta + n_col_groups = n_groups // n_lane_groups + assert d.layout.base_tile_shape[0] % 4 == 0 + lanes_per_n_group = d.layout.base_tile_shape[0] // 4 + a_sparse_addr_base = a_sparse_metadata.address if is_sparse else None # type: ignore + a_scale_addr_base = a_scale.address if is_scaled else None # type: ignore + b_scale_addr_base = b_scale.address if is_scaled else None # type: ignore for mi, ni, ki in np.ndindex(m_groups, n_groups, k_groups): - a_offset = mi * a_m_group_stride + ki * a_k_group_stride - a_mk = arith.addi(a_desc_base, utils.c(mma_utils.encode_addr(a_offset), i64)) + if isinstance(a, TMEMRef): + if m_groups != 1: + raise NotImplementedError("A address calculation for multiple M tiles") + a_k_group_elems = k_group_elems // (1 + is_sparse) + a_mk = a.slice(slice(None), utils.ds(ki * a_k_group_elems, a_k_group_elems)).address + else: + a_offset = mi * a_m_group_stride + ki * a_k_group_stride + a_mk = (a_desc_base[0], a_desc_base[1] + mma_utils.encode_addr(a_offset)) b_offset = ni * b_n_group_stride + ki * b_k_group_stride - b_nk = arith.addi(b_desc_base, utils.c(mma_utils.encode_addr(b_offset), i64)) - if m_groups != 1: - raise NotImplementedError("D needs to be sliced") + b_nk = (b_desc_base[0], b_desc_base[1] + mma_utils.encode_addr(b_offset)) + if a_sparse_addr_base is not None: + if n_groups != 1 or m_groups != 1: + raise NotImplementedError("A sparse metadata address calculation for multiple tiles") + assert k_group_elems % 32 == 0 + cols_per_k_group = k_group_elems // 32 + a_sparse_addr = arith.addi(a_sparse_addr_base, utils.c(ki * cols_per_k_group, i32)) + else: + a_sparse_addr = None + if a_scale_addr_base is not None and b_scale_addr_base is not None: + if m_groups != 1: + raise NotImplementedError("A scale address calculation for multiple M tiles") + if n_groups != 1: + raise NotImplementedError("B scale address calculation for multiple N tiles") + assert scale_block is not None # For type checkers. + assert k_group_elems % (scale_block * 4) == 0 + assert m_group_elems % 32 == 0 and n_group_elems % 32 == 0 + k_scales_per_group = k_group_elems // (scale_block * 4) + # A scales are sharded, B scales are replicated across CTAs. + a_scale_addr = arith.addi( + a_scale_addr_base, + utils.c(ki * k_scales_per_group * m_group_elems // 32, i32), + ) + b_scale_addr = arith.addi( + b_scale_addr_base, + utils.c(ki * k_scales_per_group * n_collective_group_elems // 32, i32) + ) + else: + a_scale_addr = b_scale_addr = None acc = accumulate if ki == 0 else true + ni_lane_group, ni_col = ni // n_col_groups, ni % n_col_groups + d_offset = ( + ((ni_lane_group * lanes_per_n_group) << 16) + + ni_col * n_collective_group_elems + ) + if m_groups != 1: + raise NotImplementedError("D address calculation for multiple M tiles") _do_mma( - arith.addi( - d.address, arith.constant(i32, ni * n_collective_group_elems) - ), + arith.addi(d.address, arith.constant(i32, d_offset)), a_mk, b_nk, - d_type=ir.F32Type.get(), + d_type=d.dtype, m=m_group_elems, n=n_group_elems, + k=k_group_elems, collective=collective, a_transpose=a_fastest != mma_utils.Dim.K, b_transpose=b_fastest != mma_utils.Dim.K, - a_k_stride=a_k_instr_stride, - b_k_stride=b_k_instr_stride, + a_k_strides=a_k_instr_strides, + b_k_strides=b_k_instr_strides, + a_scale_addr=a_scale_addr, + b_scale_addr=b_scale_addr, + a_sparse_addr=a_sparse_addr, accumulate=acc, - swizzle=swizzle, - element_type=wgmma_element_type, + element_type=mma_element_type, + scale_element_type=scale_element_type, ) def _do_mma( d_addr: ir.Value, - a_desc: ir.Value, - b_desc: ir.Value, + a_desc_or_addr: tuple[ir.Value, int] | ir.Value, # TMEM address if a_k_stride is None + b_desc: tuple[ir.Value, int], a_transpose: bool, b_transpose: bool, - a_k_stride: int, - b_k_stride: int, + a_k_strides: tuple[tuple[int, ...], tuple[int, ...]] | None, + b_k_strides: tuple[tuple[int, ...], tuple[int, ...]], + a_scale_addr: ir.Value | None, + b_scale_addr: ir.Value | None, + a_sparse_addr: ir.Value | None, m: int, n: int, - swizzle: int, + k: int, element_type: ir.Type, + scale_element_type: ir.Type | None, d_type: ir.Type, accumulate: ir.Value, collective: bool, -): +) -> None: i1 = ir.IntegerType.get_signless(1) + i32 = ir.IntegerType.get_signless(32) i64 = ir.IntegerType.get_signless(64) - kn_tiling = swizzle // utils.bytewidth(element_type) - instr_k = 32 // utils.bytewidth(element_type) - if a_k_stride % 16 or b_k_stride % 16: - raise ValueError + a_k_idx_tiling, a_k_strides = a_k_strides or (None, None) + b_k_idx_tiling, b_k_strides = b_k_strides + assert all(s % 16 == 0 for s in itertools.chain(a_k_strides or (), b_k_strides)) + assert (a_scale_addr is None) == (b_scale_addr is None) + is_scaled = a_scale_addr is not None + is_sparse = a_sparse_addr is not None + elem_bitwidth = utils.bitwidth(element_type) + instr_k = (1 + is_sparse) * 8 * 32 // elem_bitwidth + packing = 8 * 4 // elem_bitwidth - if ir.F16Type.isinstance(element_type) or ir.BF16Type.isinstance(element_type): - kind = "f16" + scale_steps = None + if is_scaled: + assert not is_sparse + if isinstance(element_type, ir.Float8E5M2Type) or isinstance( + element_type, ir.Float8E4M3FNType + ): + if scale_element_type != ir.Float8E8M0FNUType.get(): + raise ValueError( + f"Scale element type mismatch: expected f8e8m0fnu, got {scale_element_type}" + ) + kind = "mxf8f6f4.block_scale.scale_vec::1X" + scale_steps = 4 + create_scaled_instr_descriptor = functools.partial( + create_scaled_f8f6f4_instr_descriptor, scale_type=scale_element_type + ) + elif isinstance(element_type, ir.Float4E2M1FNType): + assert not a_transpose and not b_transpose + create_scaled_instr_descriptor = functools.partial( + create_scaled_f4_instr_descriptor, + scale_type=scale_element_type, + ) + if scale_element_type == ir.Float8E8M0FNUType.get(): + kind = "mxf4.block_scale.scale_vec::2X" + scale_steps = 2 + elif scale_element_type == ir.Float8E4M3FNType.get(): + kind = "mxf4nvf4.block_scale.scale_vec::4X" + scale_steps = 1 + else: + raise NotImplementedError(f"Unsupported element type for block scaling: {element_type}") + extra_ptx = "[$5], [$6], " + extra_constraints = ",r,r" else: - raise NotImplementedError(f"Unsupported input element type: {element_type}") + if isinstance(element_type, ir.F16Type) or isinstance( + element_type, ir.BF16Type + ): + kind = "f16" + elif isinstance(element_type, ir.Float8E5M2Type): + kind = "f8f6f4" + elif isinstance(element_type, ir.Float8E4M3FNType): + kind = "f8f6f4" + elif ( + isinstance(element_type, ir.IntegerType) + and element_type.width == 8 + and element_type.is_signless + ): + kind = "i8" + else: + raise NotImplementedError( + f"Unsupported input element type: {element_type}" + ) + extra_constraints = extra_ptx = "" + + def create_scaled_instr_descriptor(*args): # type: ignore + raise NotImplementedError num_cta = 2 if collective else 1 - i_desc = create_instr_descriptor( - m * num_cta, n * num_cta, d_type, element_type, a_transpose, b_transpose - ) - for _ in range(kn_tiling // instr_k): + a_in_tmem = a_k_strides is None + a_ptx = "[a_desc]" if a_in_tmem else "a_desc" + sparse_mod = ".sp" if is_sparse else "" + sparse_meta_ptx = "[$5], " if is_sparse else "" + extra_constraints += ",r" if is_sparse else "" + sparse_addr: tuple[Any, ...] = () + scales_addrs: tuple[Any, ...] = () + def _get_offset(idx: int, idx_tiling: tuple[int, ...], strides: tuple[int, ...]): + assert len(idx_tiling) + 1 == len(strides) + idxs = [] + for t in idx_tiling: + idxs.append(idx // t) + idx = idx % t + idxs.append(idx) + offset = sum(i * s for i, s in zip(idxs, strides, strict=True)) + return offset >> 4 + for k_step in range(k // instr_k): + if is_scaled: + assert scale_steps is not None + assert not is_sparse + scale_vec_width = 4 // scale_steps + scale_id = (k_step % scale_steps) * scale_vec_width + i_desc = create_scaled_instr_descriptor( + m * num_cta, n * num_cta, element_type, element_type, + scale_id, scale_id, a_transpose, b_transpose + ) + assert m == 128 + assert (n * num_cta) % 128 == 0 + # A scales are sharded, B scales are replicated across CTAs. + a_scale_addr_offset = arith.constant(i32, k_step // scale_steps * 4) + b_scale_addr_offset = arith.constant(i32, k_step // scale_steps * n // 32 * num_cta) + scales_addrs = ( + arith.addi(a_scale_addr, a_scale_addr_offset), + arith.addi(b_scale_addr, b_scale_addr_offset), + ) + else: + sp_selector = None + if is_sparse: + assert 32 <= instr_k <= 64 + selector_width = instr_k + k_steps_for_col_inc = 64 // selector_width + assert (k // instr_k) % k_steps_for_col_inc == 0 + sp_selector = k_step % k_steps_for_col_inc + # If the K group is large, we need to increment the sparse metadata. + # TODO(apaszke): At this point the purpose of this function is becoming + # less clear, since we end up replicating address arithmetic that's + # already there in the caller. We should unify them into a single loop. + sparse_addr = ( + arith.addi( + a_sparse_addr, utils.c(k_step // k_steps_for_col_inc * 2, i32) + ), + ) + i_desc = create_instr_descriptor( + m * num_cta, n * num_cta, d_type, element_type, a_transpose, b_transpose, sparsity_selector=sp_selector + ) + if a_in_tmem: + cols_per_k_group = instr_k // packing // (1 + is_sparse) + a_offset = k_step * cols_per_k_group + assert isinstance(a_desc_or_addr, ir.Value) + assert a_desc_or_addr.type == ir.IntegerType.get_signless(32) + a_enc_addr_base = a_desc_or_addr + else: + assert a_k_idx_tiling is not None and a_k_strides is not None + a_enc_addr_base, a_offset = a_desc_or_addr + a_offset += _get_offset(k_step, a_k_idx_tiling, a_k_strides) + b_enc_addr_base, b_offset = b_desc + b_offset += _get_offset(k_step, b_k_idx_tiling, b_k_strides) + a_offset_low, a_offset_high = a_offset & 0xFFFFFFFF, a_offset >> 32 + b_offset_low, b_offset_high = b_offset & 0xFFFFFFFF, b_offset >> 32 llvm.inline_asm( ir.Type.parse("!llvm.void"), - [d_addr, a_desc, b_desc, i_desc, accumulate], - f"tcgen05.mma.cta_group::{num_cta}.kind::{kind} [$0], $1, $2, $3, $4;", - "r,l,l,r,b", + [d_addr, a_enc_addr_base, b_enc_addr_base, i_desc, accumulate, *scales_addrs, *sparse_addr], + f"""{{ + .reg .b32 a_desc_low, a_desc_high, b_desc_low, b_desc_high; + .reg {".b32" if a_in_tmem else ".b64"} a_desc; + .reg .b64 b_desc; + add.s32 a_desc_low, $1, {a_offset_low}; + add.s32 b_desc_low, $2, {b_offset_low}; + mov.b64 b_desc, {{b_desc_low, {b_offset_high}}}; + {"mov.b32 a_desc, a_desc_low;" if a_in_tmem else f"mov.b64 a_desc, {{a_desc_low, {a_offset_high}}};"} + tcgen05.mma{sparse_mod}.cta_group::{num_cta}.kind::{kind} [$0], {a_ptx}, b_desc, {sparse_meta_ptx}$3, {extra_ptx}$4; + }}""", + "r,r,r,r,b" + extra_constraints, has_side_effects=True, ) accumulate = arith.constant(i1, 1) - a_desc = arith.addi(a_desc, arith.constant(i64, a_k_stride >> 4)) - b_desc = arith.addi(b_desc, arith.constant(i64, b_k_stride >> 4)) def commit_arrive( barrier: utils.BarrierRef | ir.Value, collective: bool = False, ctx: LaunchContext | None = None, -): +) -> None: if isinstance(barrier, utils.BarrierRef): barrier = barrier.get_ptr() elif barrier.type != ir.Type.parse("!llvm.ptr<3>"): @@ -274,32 +752,25 @@ def commit_arrive( # TODO(apaszke): This is just 0b11 shifted by the even CTA index. if ctx.cluster_size != (2, 1, 1): raise NotImplementedError("Collective arrivals only support (2, 1, 1)-shaped clusters") - ptx = """ - { - .reg .b16 msk; - mov.b16 msk, 3; - tcgen05.commit.cta_group::2.mbarrier::arrive::one.multicast::cluster.b64 [$0], msk; - } - """ + i16 = ir.IntegerType.get_signless(16) + mask = arith.constant(i16, 3) + nvvm.tcgen05_commit( + barrier, group=nvvm.CTAGroupKind.CTA_2, multicast_mask=mask + ) else: - ptx = "tcgen05.commit.cta_group::1.mbarrier::arrive::one.b64 [$0];" - return llvm.inline_asm( - ir.Type.parse("!llvm.void"), [barrier], ptx, "l", has_side_effects=True - ) + nvvm.tcgen05_commit(barrier) -def tmem_alloc(tmem_addr: ir.Value, ncols: int, collective: bool = False, exact: bool = True): - if ir.MemRefType.isinstance(tmem_addr.type): - ref_ty = ir.MemRefType(tmem_addr.type) - if ref_ty.element_type != ir.IntegerType.get_signless(32): - raise ValueError(f"tmem_addr must be an i32 memref, got: {ref_ty}") - if ref_ty.memory_space != ir.Attribute.parse("#gpu.address_space"): - raise ValueError(f"tmem_addr must be in shared memory, got: {ref_ty}") - if math.prod(ref_ty.shape) != 1: - raise ValueError(f"tmem_addr must contain a single element, got: {ref_ty}") - tmem_addr = utils.memref_ptr(tmem_addr, memory_space=3) - elif tmem_addr.type != ir.Type.parse("!llvm.ptr<3>"): - raise ValueError(f"tmem_addr must be an SMEM pointer or a memref, got: {tmem_addr.type}") +def tmem_alloc_exact_ncols(ncols: int, exact: bool) -> int: + """Returns the exact number of columns to allocate in TMEM. + + The number of columns is rounded up to the nearest power of 2. + + Args: + ncols: The number of columns to allocate. + exact: If true, throws an error if the number of columns is not a power of 2 + and within [32, 512]. + """ if exact: if ncols.bit_count() != 1 or not 32 <= ncols <= 512: raise ValueError(f"ncols must be a power of 2 and within [32, 512], got: {ncols}") @@ -309,138 +780,274 @@ def tmem_alloc(tmem_addr: ir.Value, ncols: int, collective: bool = False, exact: raise ValueError( f"After rounding up, got {ncols} columns, exceeding the limit of 512" ) - num_cta = 2 if collective else 1 - return llvm.inline_asm( - ir.Type.parse("!llvm.void"), - [tmem_addr], - f"tcgen05.alloc.cta_group::{num_cta}.sync.aligned.shared::cta.b32 [$0], {ncols};", - "r", - has_side_effects=True, - ) + return ncols -def tmem_relinquish_alloc_permit(): - return llvm.inline_asm( - ir.Type.parse("!llvm.void"), - [], - "tcgen05.relinquish_alloc_permit.cta_group::1.sync.aligned;", - "", - has_side_effects=True, + +def tmem_alloc(tmem_addr: ir.Value, ncols: int, collective: bool = False, exact: bool = True) -> tuple[ir.Value, int]: + if isinstance(tmem_addr.type, ir.MemRefType): + ref_ty = ir.MemRefType(tmem_addr.type) + if ref_ty.element_type != ir.IntegerType.get_signless(32): + raise ValueError(f"tmem_addr must be an i32 memref, got: {ref_ty}") + if not utils.is_smem_ref(ref_ty): + raise ValueError(f"tmem_addr must be in shared memory, got: {ref_ty}") + if math.prod(ref_ty.shape) != 1: + raise ValueError(f"tmem_addr must contain a single element, got: {ref_ty}") + tmem_addr = utils.memref_ptr(tmem_addr, memory_space=3) + elif tmem_addr.type != ir.Type.parse("!llvm.ptr<3>"): + raise ValueError(f"tmem_addr must be an SMEM pointer or a memref, got: {tmem_addr.type}") + ncols = tmem_alloc_exact_ncols(ncols, exact) + group = nvvm.CTAGroupKind.CTA_2 if collective else nvvm.CTAGroupKind.CTA_1 + i32 = ir.IntegerType.get_signless(32) + return nvvm.tcgen05_alloc(tmem_addr, utils.c(ncols, i32), group=group), ncols + + +def _tmem_addr_to_ptr(tmem_addr: ir.Value) -> ir.Value: + assert tmem_addr.type == ir.IntegerType.get_signless(32) + ptr_ty = ir.Type.parse("!llvm.ptr<6>") + return llvm.inttoptr(ptr_ty, tmem_addr) + + +def tmem_dealloc(tmem_addr: ir.Value, ncols: int, collective: bool = False, exact: bool = True) -> None: + if tmem_addr.type != ir.IntegerType.get_signless(32): + raise ValueError(f"tmem_addr must be an i32, got: {tmem_addr.type}") + ncols = tmem_alloc_exact_ncols(ncols, exact) + group = nvvm.CTAGroupKind.CTA_2 if collective else nvvm.CTAGroupKind.CTA_1 + i32 = ir.IntegerType.get_signless(32) + nvvm.tcgen05_dealloc( + _tmem_addr_to_ptr(tmem_addr), utils.c(ncols, i32), group=group ) -def tmem_load(tmem_addr, shape, num): + +def tmem_relinquish_alloc_permit(collective: bool) -> None: + group = nvvm.CTAGroupKind.CTA_2 if collective else nvvm.CTAGroupKind.CTA_1 + nvvm.tcgen05_relinquish_alloc_permit(group=group) + +def _tmem_access_helper(shape, num) -> tuple[int, str]: if num.bit_count() != 1 or num > 128: raise ValueError(f"num must be a power of 2 and <= 128, got: {num}") match shape: + case "32x32b": + num_regs = 1 case "16x128b": - num_out_regs = 2 + num_regs = 2 case "16x256b": - num_out_regs = 4 + num_regs = 4 case _: raise NotImplementedError(f"{shape=} is unsupported") - if num * num_out_regs >= 256: + num_regs *= num + if num_regs > 255: raise ValueError( - f"Loading too much TMEM at once: {num=} and each load requires" - f" {num_out_regs} registers, which exceeds the limit of 256" + f"TMEM translation too big : {shape=} and {num=} involve" + f" {num_regs} registers per-thread, which exceeds the limit of 255" ) - num_out_regs *= num + regs_vector = ",".join(f"${i}" for i in range(num_regs)) + regs_vector = "{" + regs_vector + "}" + return num_regs, regs_vector + + +def _tmem_load(tmem_addr, shape, num, pack: bool): i32 = ir.IntegerType.get_signless(32) - out_regs = ",".join("$" + str(i) for i in range(num_out_regs)) + num_out_regs, regs_vector = _tmem_access_helper(shape, num) + pack_mod = ".pack::16b" if pack else "" regs = llvm.inline_asm( ir.Type.parse( "!llvm.struct<(" + ",".join("i32" for _ in range(num_out_regs)) + ")>" ), [tmem_addr], - f"tcgen05.ld.sync.aligned.{shape}.x{num}.b32 {{{out_regs}}}, [${num_out_regs}];", + f"tcgen05.ld.sync.aligned.{shape}.x{num}{pack_mod}.b32 {regs_vector}, [${num_out_regs}];", "=r," * num_out_regs + "r", has_side_effects=True, ) return [llvm.extractvalue(i32, regs, [i]) for i in range(num_out_regs)] -@dataclasses.dataclass(frozen=True) -class TMEMLayout: - """Represents the way a shape is laid out in TMEM. - - Only 2D shapes are supported. Row tiling must be between 32 and 128, and be - a power of 2. If the row tiling is smaller than 128 (the row count in TMEM), - the tiles are linearized in row-major order, but laid out in TMEM in a - column-major order. - - Consider an array that is (128, 128) and we apply tiling of (64, 64): - - +------------------+------------------+ - | [0:64, 0:64] | [0:64, 64:128] | - +------------------+------------------+ - | [64:128, 0:64] | [64:128, 64:128] | - +------------------+------------------+ +def _tmem_store(tmem_addr, shape, num, regs, unpack: bool) -> None: + num_out_regs, regs_vector = _tmem_access_helper(shape, num) + pack_mod = ".unpack::16b" if unpack else "" + llvm.inline_asm( + ir.Type.parse("!llvm.void"), + [*regs, tmem_addr], + f"tcgen05.st.sync.aligned.{shape}.x{num}{pack_mod}.b32 [${num_out_regs}], {regs_vector};", + "r," * num_out_regs + "r", + has_side_effects=True, + ) - In TMEM it will be laid out as follows: - +------------------+------------------+ - | [0:64, 0:64] | [64:128, 0:64] | - +------------------+------------------+ - | [0:64, 64:128] | [64:128, 64:128] | - +------------------+------------------+ +@dataclasses.dataclass(frozen=True) +class TMEMLayout(fa.TiledLayout): + """Represents the way a shape is laid out in TMEM. - The above is further complicated by column_tile_stride, which is used to - swizzle the ordering of column tiles. That is, if column_tile_stride is 2, - we will first lay out all tiles that have the column index 0, 2, 4, and so on - until we run out of tiles. Only then we lay out the tiles with column index - 1, 3, etc. + The layout describes how the shape is split across the 128 rows (lanes) of + TMEM. We reinterpret warp_dims as the partitioning of TMEM into 4 banks, each + accessible from a single warp. The 32 lanes inside each bank are assigned + consecutive elements from lane_dims. The data within each lane is linearized + in row-major order, with each vector padded up to 32 bits (wider vectors are + unsupported). """ - elements_in_tile: tuple[int, int] - column_tile_stride: int = 1 - - def __post_init__(self): - row_tiling = self.elements_in_tile[0] - if not 32 <= row_tiling <= 128: - raise ValueError( - f"Row tiling must be between 32 and 128, got: {row_tiling}" - ) - if row_tiling.bit_count() != 1: - raise ValueError(f"Row tiling must be a power of 2, got: {row_tiling}") - def check_shape(self, shape: tuple[int, ...]): + def check_type(self, shape: tuple[int, ...], bitwidth: int) -> None: if len(shape) != 2: raise ValueError(f"TMEM can only represent 2D shapes, got {shape}") - if any(s % t for s, t in zip(shape, self.elements_in_tile)): + if any(s % t for s, t in zip(shape, self.base_tile_shape)): raise ValueError( - f"{shape} is divisible into tiles of shape {self.elements_in_tile}" + f"{shape} is not divisible into tiles of shape {self.base_tile_shape}" + ) + if self.vector_length not in {1, fully_packed := 32 // bitwidth}: + raise ValueError( + f"For {bitwidth}-bit types, the vector length must be 1 or" + f" {fully_packed} , but got: {self.vector_length}" ) - def cols_in_shape(self, shape: tuple[int, int]): - cols_in_tile = self.elements_in_tile[1] - tiles_in_row = TMEM_ROWS // self.elements_in_tile[0] - num_tiles = math.prod(utils.tile_shape(shape, self.elements_in_tile)[:-2]) - assert num_tiles % tiles_in_row == 0 - return num_tiles // tiles_in_row * cols_in_tile - + def cols_in_shape(self, shape: tuple[int, int], bitwidth: int) -> int: + self.check_type(shape, bitwidth) + replication_factor = 1 + for dim in self.warp_dims: + if isinstance(dim, fa.Replicated): + replication_factor *= dim.times + for dim in self.lane_dims: + if isinstance(dim, fa.Replicated): + replication_factor *= dim.times + return math.prod(shape) // TMEM_ROWS // self.vector_length * replication_factor -def _infer_tmem_layout(shape: tuple[int, int], collective: bool) -> TMEMLayout: - if shape[0] > TMEM_ROWS: - raise ValueError( - "Can only infer TMEM layout for shapes with at most 128 rows, got:" - f" {shape[0]}" - ) - if shape[0] < 32: - raise ValueError( - "Can only infer TMEM layout for shapes with at least 32 rows, got:" - f" {shape[0]}" + def canonicalize(self) -> TMEMLayout: + layout = super().canonicalize() + return TMEMLayout( + layout.tiling, + layout.warp_dims, + layout.lane_dims, + layout.vector_dim, + _check_canonical=False, ) - if shape[0].bit_count() != 1: - raise ValueError( - "Can only infer TMEM layout for shapes with row count that's a power of" - f" 2, got: {shape[0]}" + + def as_tiled_layout(self) -> fa.TiledLayout: + return fa.TiledLayout( + self.tiling, self.warp_dims, self.lane_dims, self.vector_dim ) - if shape[1] % 8: + + +def _infer_tmem_load_registers_layout( + tmem_layout: TMEMLayout, columns: int, packing: int +) -> fa.TiledLayout: + if tmem_layout == tmem_default_layout(packing=packing): + return LAYOUT + if tmem_layout == tmem_half_lane_layout(columns, packing=packing): + return fa.WGMMA_LAYOUT + if tmem_layout == tmem_m64_collective_layout(columns, packing=packing): + return fa_m64_collective_layout(columns) + raise ValueError(f"TMEM layout {tmem_layout} is not supported") + + +def _infer_tmem_layout(shape: tuple[int, int], collective: bool, packing: int) -> TMEMLayout: + if len(shape) != 2: + raise ValueError(f"TMEM can only represent 2D shapes, got {shape}") + if packing > 8 or packing.bit_count() != 1: + raise ValueError(f"Packing must be <= 8 and a power of 2, got: {packing}") + if shape[1] % packing: + raise ValueError(f"Minor dimension of shape must be divisible by packing, got: {shape}") + if shape[0] == TMEM_ROWS: + return tmem_default_layout(packing) + elif shape[0] == TMEM_ROWS // 2: + if collective: + return tmem_m64_collective_layout(shape[1], packing) + else: + return tmem_half_lane_layout(shape[1], packing) + else: raise ValueError( - "Can only infer TMEM layout for shapes with column count that's a" - f" multiple of 8, got: {shape[1]}" + f"Unsupported shape: {shape}. TMEM references must have either" + f" {TMEM_ROWS} or {TMEM_ROWS // 2} rows, but got {shape[0]}." ) - if collective and shape[1] == 512: - return TMEMLayout(elements_in_tile=(shape[0], 128), column_tile_stride=2) - else: - return TMEMLayout(elements_in_tile=(shape[0], 8)) + + +def tmem_default_layout(packing: int = 1) -> TMEMLayout: + """A TMEM layout used for 1CTA MMA with M=128 and 2CTA MMA with M=256.""" + if packing.bit_count() != 1: + raise ValueError(f"Packing must be a power of 2, got: {packing}") + return TMEMLayout( + fa.Tiling(((TMEM_ROWS, packing), (fa.WARP_SIZE, packing))), + warp_dims=(-4,), + lane_dims=(-2,), + vector_dim=-1, + ) + + +def tmem_half_lane_layout(columns, packing: int = 1) -> TMEMLayout: + """A TMEM layout used for 1CTA MMA with M=64.""" + if packing > columns or packing.bit_count() != 1: + raise ValueError(f"Packing must be <= 8 and a power of 2, got: {packing}") + if columns % 16: + raise ValueError(f"Columns must be a multiple of 16, got: {columns}") + return TMEMLayout( + fa.Tiling(( + (TMEM_ROWS // 2, columns), + (fa.WARP_SIZE // 2, columns // 2), + (packing,), + )), + warp_dims=(-5,), + lane_dims=(-4, -3), + vector_dim=-1, + ) + + +def tmem_m64_collective_layout(columns: int, packing: int = 1) -> TMEMLayout: + """A TMEM layout used for 2CTA MMA with M=128.""" + if packing > 8 or packing.bit_count() != 1: + raise ValueError(f"Packing must be <= 8 and a power of 2, got: {packing}") + if columns % 16: + raise ValueError(f"Columns must be a multiple of 16, got: {columns}") + return TMEMLayout( + fa.Tiling(( + (TMEM_ROWS // 2, columns), + (fa.WARP_SIZE, columns // 2), + (packing,), + )), + warp_dims=(-4, -5,), + lane_dims=(-3,), + vector_dim=-1, + ) + + +def fa_m64_collective_layout(columns: int) -> fa.TiledLayout: + """The register layout for transfers to/from tmem_m64_collective_layout.""" + if columns % 16: + raise ValueError(f"Columns must be a multiple of 16, got: {columns}") + return fa.TiledLayout( + fa.Tiling(( + (TMEM_ROWS // 2, columns), (fa.WARP_SIZE, columns // 2), (8, 8), (2,) + )), + warp_dims=(-6, -7), + lane_dims=(-3, -2), + vector_dim=-1, + ) + + +def scales_layout() -> TMEMLayout: + """A TMEM layout for A and B scales in .scale_vec::1X configuration. + + See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-a-layout-1x + """ + return TMEMLayout( + fa.Tiling(((TMEM_ROWS, 4), (TMEM_ROWS // 4, 1))), + warp_dims=(fa.Replicated(times=4),), + lane_dims=(-2,), + vector_dim=-3, + ) + + +def sparse_meta_layout() -> TMEMLayout: + """A TMEM layout for A sparsity metadata. + + See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-sparse-matrices-sparsity-selector-kind-tf32-m128-256 + """ + # TODO(apaszke): This does not really describe this layout and we can't do it + # until we add support for multiple vector dims. Still, it's ok to do for now, + # because we don't use TMEM layouts for any automatic transformations at the + # moment and only ever compare it for equality. + return TMEMLayout( + fa.Tiling(((TMEM_ROWS, 16), (TMEM_ROWS // 4, 1), (16, 1), (8, 1))), + warp_dims=(-8,), + lane_dims=(-2, -4, -6), + vector_dim=-7, + ) @dataclasses.dataclass(frozen=True) @@ -450,6 +1057,16 @@ class TMEMRef: dtype: ir.Type layout: TMEMLayout + @property + def packing(self) -> int: + return self.layout.vector_length + + def __post_init__(self): + packed_bitwidth = utils.bitwidth(self.dtype) * self.packing + if not packed_bitwidth <= 32: + raise ValueError("Expected packed packed bitwidth to be <= 32, but got: " + f"{packed_bitwidth=}") + @classmethod def from_alloc( cls, @@ -458,14 +1075,13 @@ def from_alloc( dtype, collective: bool | None = None, layout: TMEMLayout | None = None, - ): + ) -> TMEMRef: i32 = ir.IntegerType.get_signless(32) - if not ir.MemRefType.isinstance(tmem_addr_ref.type): + if not isinstance(tmem_addr_ref.type, ir.MemRefType): raise ValueError(f"tmem_addr_ref must be a memref or a pointer, got: {tmem_addr_ref.type}") addr_ref_ty = ir.MemRefType(tmem_addr_ref.type) - smem = ir.Attribute.parse("#gpu.address_space") - if addr_ref_ty.memory_space != smem: - raise ValueError(f"tmem_addr_ref must be in workgroup memory, got: {addr_ref_ty}") + if not utils.is_smem_ref(addr_ref_ty): + raise ValueError(f"tmem_addr_ref must be in shared memory, got: {addr_ref_ty}") if addr_ref_ty.element_type != i32: raise ValueError(f"tmem_addr_ref must be an i32 memref, got: {addr_ref_ty}") if math.prod(addr_ref_ty.shape) != 1: @@ -479,19 +1095,22 @@ def from_alloc( raise ValueError( "collective argument must be provided when TMEM layout is inferred" ) - layout = _infer_tmem_layout(shape, collective) + layout = _infer_tmem_layout(shape, collective, packing=1) else: - layout.check_shape(shape) + layout.check_type(shape, utils.bitwidth(dtype)) # TODO: Do we have to do this?? # warp_idx = utils.warp_idx(sync=False) # tmem_addr = arith.ori(tmem_addr, arith.shli(warp_idx, utils.c(21, i32))) return cls(tmem_addr, shape, dtype, layout) - def slice(self, *idxs): + def slice(self, *idxs) -> TMEMRef: + i32 = ir.IntegerType.get_signless(32) base_idx, slice_shape, is_squeezed = utils.parse_indices(idxs, self.shape) if any(is_squeezed): raise ValueError("TMEM can only be sliced, not indexed") - if self.layout != TMEMLayout(elements_in_tile=(TMEM_ROWS, 8)): + if base_idx == [0] * len(base_idx) and slice_shape == list(self.shape): + return self # Trival slice + if self.layout != tmem_default_layout(packing=self.packing): raise NotImplementedError( "Slicing only implemented for refs with standard layout, got:" f" {self.layout}" @@ -500,103 +1119,523 @@ def slice(self, *idxs): raise NotImplementedError("TMEM cannot be sliced along rows") if slice_shape[1] % 8: raise NotImplementedError( - "TMEM column slice length must be a multiple of 8" + "TMEM column slice length must be a multiple of 8. " + f"Got {slice_shape[1]}." ) col_idx = base_idx[1] if not isinstance(col_idx, ir.Value): - col_idx = arith.constant(ir.IntegerType.get_signless(32), col_idx) + col_idx = arith.constant(i32, col_idx) + if col_idx.type == ir.IndexType.get(): + col_idx = arith.index_cast(i32, col_idx) + if self.packing != 1: + col_idx = arith.divui(col_idx, arith.constant(i32, self.packing)) return TMEMRef( address=arith.addi(self.address, col_idx), - shape=tuple(slice_shape), + shape=cast(tuple[int, int], tuple(slice_shape)), layout=self.layout, dtype=self.dtype, ) - def __getitem__(self, *idxs): - i32 = ir.IntegerType.get_signless(32) - base_idxs, slice_shape, is_squeezed = utils.parse_indices(idxs, self.shape) - if any(is_squeezed): - raise ValueError("TMEM loads only support slicing") - if any(idx != 0 for idx in base_idxs) or tuple(slice_shape) != self.shape: - raise NotImplementedError("Slicing of TMEM not impelmented yet") - if self.shape[1] % 8: - raise NotImplementedError - if self.dtype != ir.F32Type.get(): - raise NotImplementedError(self.dtype) - layout = _m128_256bit_32bit_layout(self.shape) + def load(self, layout: fa.TiledLayout | None = None, is_signed: bool | None = None) -> fa.FragmentedArray: + packing = self.packing + if layout is None: + layout = _infer_tmem_load_registers_layout( + self.layout, self.shape[1], packing + ) + bitwidth = utils.bitwidth(self.dtype) + has_default_layout = self.layout == tmem_default_layout(packing=packing) regs_shape = layout.registers_shape(self.shape) - if self.layout == TMEMLayout(elements_in_tile=(TMEM_ROWS, 8)): - # load_32xcols returns a 4xN array, but the FA tiling we use here tiles - # columns before rows, and so it is Nx4 (after ignoring all 1 dims). + if regs_shape[0] != 1: # We'll need to issue multiple loads below. + raise NotImplementedError("Loading multiple row tiles") + if layout == LAYOUT and self.layout == tmem_default_layout(packing=packing): registers = _load_32xcols( - self.address, self.shape[1], self.dtype + self.address, self.shape[1], self.dtype, packing ).T.reshape(regs_shape) - elif self.layout == TMEMLayout(elements_in_tile=(TMEM_ROWS, 128), column_tile_stride=2): - if self.shape[1] % 128 != 0: - raise ValueError( - f"TMEM layout {self.layout} is not compatible with shape {self.shape}" - ) - num_column_tiles = self.shape[1] // 128 - column_tile_stride = self.layout.column_tile_stride - num_strided_col_groups = utils.ceil_div(num_column_tiles, column_tile_stride) - tiles = [] - for col_tile_base in range(num_strided_col_groups): - for col_tile in range(col_tile_base, num_column_tiles, column_tile_stride): - tiles.append( - _load_32xcols( - arith.addi(self.address, arith.constant(i32, col_tile * 128)), - cols=128, - dtype=self.dtype, - ) - ) - registers = np.concatenate(tiles, axis=1).T.reshape(regs_shape) + elif layout == self.layout.as_tiled_layout() and packing * bitwidth == 32: + assert len(layout.base_tile_shape) == 2 + # We could allow replicated dims in the input, but we'd need to divide the + # split factor computed below by the replication factor of the input. + assert not any(isinstance(d, fa.Replicated) for d in layout.warp_dims) + assert not any(isinstance(d, fa.Replicated) for d in layout.lane_dims) + warp_split_factor = math.prod( + d.times if isinstance(d, fa.Replicated) else 1 + for d in layout.remove_dimension(1).warp_dims + ) + lane_split_factor = math.prod( + d.times if isinstance(d, fa.Replicated) else 1 + for d in layout.remove_dimension(1).lane_dims + ) + split_factor = warp_split_factor * lane_split_factor + registers = _load_32xcols_native( + self.address, self.shape[1] // split_factor, self.dtype, packing, packing + ).reshape(regs_shape) + # TODO(apaszke): Support the case where we have a long vector length in the + # FA more generally, not just for 2x32b. + # 16-bit types are special, because the store instruction can unpack them. + elif layout == TMEM_NATIVE_LAYOUT and has_default_layout and ( + (bitwidth == 16 and packing == 1) + or (bitwidth == 32 and layout.vector_length == 2) + ): + registers = _load_32xcols_native( + self.address, self.shape[1], self.dtype, packing, TMEM_NATIVE_LAYOUT.vector_length + ).reshape(regs_shape) + elif layout == fa.WGMMA_LAYOUT and self.layout == tmem_half_lane_layout(self.shape[1], packing=packing): + # Load half the columns, since they are folded over lanes. + raw_registers = _load_32xcols( + self.address, self.shape[1] // 2, self.dtype, packing + ) + assert raw_registers.shape[0] == 4 + registers = np.concatenate([raw_registers[:2], raw_registers[2:]], axis=1) + registers = registers.T.reshape(regs_shape) + elif layout == fa_m64_collective_layout(self.shape[1]) and self.layout == tmem_m64_collective_layout(self.shape[1], packing=packing): + regs_shape = layout.registers_shape(self.shape) + # We take half the columns, because they are split over halves of TMEM. + registers = _load_32xcols( + self.address, self.shape[1] // 2, self.dtype, packing + ).reshape(regs_shape) + else: + raise ValueError( + f"Loads from TMEM layout {self.layout} to register layout" + f" {layout} are not supported" + ) + return fa.FragmentedArray( + _registers=registers, _layout=layout, _is_signed=is_signed + ) + + def store(self, value: fa.FragmentedArray): + if not isinstance(value, fa.FragmentedArray): + raise TypeError(f"TMEM stores expect a FragmentedArray, got: {value}") + if value.shape != self.shape: + raise ValueError( + f"Stored array has shape {value.shape}, but TMEM has shape" + f" {self.shape}" + ) + if value.mlir_dtype != self.dtype: + raise ValueError( + f"Stored array has dtype {value.mlir_dtype}, but TMEM has dtype" + f" {self.dtype}" + ) + if not isinstance(value.layout, fa.TiledLayout): + raise TypeError(f"Stored array has layout {value.layout}, but TMEM stores expect a TiledLayout") + packing = self.packing + has_default_layout = self.layout == tmem_default_layout(packing=packing) + bitwidth = utils.bitwidth(self.dtype) + if value.layout == LAYOUT and has_default_layout: + _store_32xcols( + self.address, value.registers.T.reshape((4, -1)), packing + ) + elif value.layout == self.layout.as_tiled_layout() and packing * bitwidth == 32: + _store_32xcols_native(self.address, value.registers.reshape(-1), packing) + # TODO(apaszke): Support the case where we have a long vector length in the + # FA more generally, not just for 2x32b. + # TODO(apaszke): Support a wider range of layouts when dealing with unpacking. + # 16-bit types are special, because the store instruction can unpack them. + elif value.layout == TMEM_NATIVE_LAYOUT and has_default_layout and ( + (bitwidth == 16 and packing == 1) + or (bitwidth == 32 and value.layout.vector_length == 2) + ): + _store_32xcols_native(self.address, value.registers.reshape(-1), packing) + elif ( + value.layout == fa.WGMMA_LAYOUT + and self.layout == tmem_half_lane_layout(self.shape[1], packing=packing) + ): + registers = value.registers.T.reshape(2, -1) + registers = np.concatenate(np.split(registers, 2, axis=1), axis=0) + _store_32xcols(self.address, registers, packing) + elif value.layout == fa_m64_collective_layout( + self.shape[1] + ) and self.layout == tmem_m64_collective_layout( + self.shape[1], packing=packing + ): + _store_32xcols(self.address, value.registers.reshape(4, -1), packing) else: + raise ValueError( + f"Storing from register layout {value.layout} to TMEM layout" + f" {self.layout} is not supported" + ) + + def _debug_print(self) -> None: + i32 = ir.IntegerType.get_signless(32) + num_cols = self.layout.cols_in_shape(self.shape, utils.bitwidth(self.dtype)) + lane = arith.remui(utils.thread_idx(), arith.constant(i32, utils.WARPGROUP_SIZE)) + for c in range(num_cols): + ptr = _tmem_addr_to_ptr(arith.addi(self.address, arith.constant(i32, c))) + val = nvvm.tcgen05_ld(i32, nvvm.Tcgen05LdStShape.SHAPE_32X32B, ptr) + dtype_bitwidth = utils.bitwidth(self.dtype) + full_packing = 32 // dtype_bitwidth + if self.packing == 1: + if dtype_bitwidth < 32: + val = arith.trunci(ir.IntegerType.get_signless(dtype_bitwidth), val) + val = utils.bitcast(val, self.dtype) + elif self.packing == full_packing: + val = utils.bitcast(val, ir.VectorType.get((full_packing,), self.dtype)) + else: + raise NotImplementedError(f"Unsupported packing: {self.packing}") + # TODO(apaszke): Make this print logical, not physical location. + utils.debug_print(f"[{{}}, {c}]: {{}}", lane, val, uniform=False) + + +def _transfer_32xcols( + base_addr: ir.Value, + cols: int, + atom_shape: tuple[int, int], + tmem_packing: int, + reg_packing: int, +) -> Iterator[tuple[ir.Value, int, int, slice]]: + """Generates a sequence of parameters for a given TMEM read or write. + + Arguments: + base_addr: The base address of the TMEM region. + cols: The number of logical columns to transfer. + atom_shape: The logical shape of the tile written by the warp in a single + TMEM transfer. + tmem_packing: Packing degree in TMEM. When packing is 1, but the data is + 16-bit, we expect that each transfer actually involves double the number + of physical columns. + reg_packing: The number of elements that fit in a single 32-bit register. + """ + i32 = ir.IntegerType.get_signless(32) + atom_rows, atom_cols = atom_shape + assert cols % atom_cols == 0 + total_num = cols // atom_cols + regs_per_instr = atom_shape[0] * atom_shape[1] // (utils.WARP_SIZE * reg_packing) + assert 32 % atom_rows == 0 + num_row_steps = 32 // atom_rows + # We artificially lower the instr_num compared to its limits, because higher + # values can lead to register spills.. + max_num = 1 << (total_num.bit_length() - 1) # power of 2 <= than total_num + max_num = min(max_num, 32 // regs_per_instr) + for lane_step in range(num_row_steps): + addr_row = arith.addi(base_addr, utils.c((lane_step * atom_rows) << 16, i32)) + num_processed = 0 + instr_num = max_num + while (remaining := total_num - num_processed) > 0: + while instr_num > remaining: + instr_num //= 2 + num_slice = slice(num_processed, num_processed + instr_num) + addr_row_col = arith.addi( + addr_row, utils.c(num_processed * atom_cols // tmem_packing, i32) + ) + yield addr_row_col, instr_num, lane_step, num_slice + num_processed += instr_num + assert num_processed == total_num + + +def _store_32xcols(base_addr, vector_regs, tmem_packing) -> None: + i32 = ir.IntegerType.get_signless(32) + assert vector_regs.ndim == 2 and vector_regs.shape[0] == 4 + cols = vector_regs.shape[1] * 8 + + reg_packing = 64 // utils.bitwidth(vector_regs.flat[0].type) + if reg_packing == 1: + store_shape = "16x256b" # 4 threads * 64 bits per vreg = 256 bits + regs = np.empty((4, vector_regs.shape[1], 2), dtype=object) + c0 = arith.constant(i32, 0) + c1 = arith.constant(i32, 1) + for idx, vreg in np.ndenumerate(vector_regs): + regs[(*idx, 0)] = llvm.extractelement(vreg, c0) + regs[(*idx, 1)] = llvm.extractelement(vreg, c1) + regs = regs.reshape(2, 2, vector_regs.shape[1], 2).swapaxes(1, 2) + # From a single lane perspective a num tile consists of a 2x2, with the + # minor dim traversing columns and major being 8 rows apart. + # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16256b + assert regs.shape[-2:] == (2, 2) + assert tmem_packing == 1 + unpack = False + elif reg_packing == 2: + store_shape = "16x128b" # 4 threads * 32 bits per vreg = 128 bits + # From a single lane perspective a num tile has 2 registers, 8 rows apart. + # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16128b + regs = vector_regs.reshape(2, 2, vector_regs.shape[1]).swapaxes(1, 2) + assert 1 <= tmem_packing <= 2 + unpack = tmem_packing == 1 + else: + raise NotImplementedError(reg_packing) + + it = _transfer_32xcols(base_addr, cols, (16, 8), tmem_packing, reg_packing) + for addr_row_col, instr_num, lane_step, num_slice in it: + regs_slice = regs[lane_step, num_slice].flat + _tmem_store(addr_row_col, store_shape, instr_num, regs_slice, unpack) + + +def _store_32xcols_native(base_addr, vector_regs, tmem_packing) -> None: + i32 = ir.IntegerType.get_signless(32) + assert vector_regs.ndim == 1 + vec_ty = ir.VectorType(vector_regs.flat[0].type) + [vector_length] = vec_ty.shape + elt_bitwidth = utils.bitwidth(vec_ty.element_type) + reg_packing = 32 // elt_bitwidth + store_atom_shape = (32, reg_packing) + # TODO(apaszke): More general register splitting code, not just 2x32b. + if reg_packing == 1: + if vector_length == 2: + # Transform data such that each reg is 32 bits wide. + regs = [None] * (len(vector_regs) * 2) + c0 = arith.constant(i32, 0) + c1 = arith.constant(i32, 1) + for idx, vreg in enumerate(vector_regs): + regs[2 * idx] = llvm.extractelement(vreg, c0) + regs[2 * idx + 1] = llvm.extractelement(vreg, c1) + else: + regs = [utils.bitcast(r, i32) for r in vector_regs] + assert tmem_packing == 1 + unpack = False + elif reg_packing == 2: + assert vector_length == 2 + # In this case, registers are already packed into 32-bit registers. + regs = [utils.bitcast(r, i32) for r in vector_regs] + if elt_bitwidth == 16: + assert 1 <= tmem_packing <= 2 + unpack = tmem_packing == 1 + else: + if tmem_packing == 1 and elt_bitwidth != 32: + raise NotImplementedError( + f"Unsupported packing: {tmem_packing} for element type {elt_bitwidth}" + ) + assert tmem_packing == 32 // elt_bitwidth + unpack = False + else: + if tmem_packing != reg_packing: raise NotImplementedError( - f"Loads only implemented for refs with standard layout, got: {self.layout}" + f"Only {reg_packing} packing supported for bitwidth {elt_bitwidth}," + f" but got TMEM packing of {tmem_packing}" ) - return fa.FragmentedArray(_registers=registers, _layout=layout, _is_signed=None) + assert utils.bitwidth(vec_ty) == 32 + regs = [utils.bitcast(r, i32) for r in vector_regs] + unpack = False + cols = len(regs) * reg_packing + it = _transfer_32xcols(base_addr, cols, store_atom_shape, tmem_packing, reg_packing) + for addr_row_col, instr_num, lane_step, num_slice in it: + assert lane_step == 0 + regs_slice = regs[num_slice] + _tmem_store(addr_row_col, "32x32b", instr_num, regs_slice, unpack) -def _load_32xcols(base_addr, cols, dtype): - # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16256b + +def _load_32xcols(base_addr, cols, dtype, tmem_packing) -> np.ndarray: i32 = ir.IntegerType.get_signless(32) - assert cols % 8 == 0 - cols_per_num_tile = 8 - load_shape = "16x256b" - num = cols // 8 - if num <= 32: - num_tiling = num - elif num == 64: - num_tiling = 32 + vec_ty = ir.VectorType.get((2,), dtype) + reg_packing = 32 // utils.bitwidth(dtype) + if reg_packing == 1: + load_shape = "16x256b" # 4 threads * 64 bits per vreg = 256 bits + assert tmem_packing == 1 + pack = False + elif reg_packing == 2: + load_shape = "16x128b" # 4 threads * 32 bits per vreg = 128 bits + assert 1 <= tmem_packing <= 2 + pack = tmem_packing == 1 else: - raise NotImplementedError(num) - vector_regs = np.ndarray((4, num), dtype=object) - # We load 16 lanes at a time, but need 32 in total. - for row_group in range(2): - addr_row = arith.addi(base_addr, arith.constant(i32, (row_group * 16) << 16)) - regs = [] - for num_group in range(num // num_tiling): - addr_row_col = arith.addi( - addr_row, - arith.constant(i32, num_tiling * num_group * cols_per_num_tile), - ) - regs += tmem_load(addr_row_col, load_shape, num_tiling) - regs = [llvm.bitcast(dtype, r) for r in regs] - undef = llvm.mlir_undef(ir.VectorType.get((2,), dtype)) - for r_low, r_high, idx in zip(regs[::2], regs[1::2], np.ndindex(num, 2)): - high_undef = llvm.insertelement(undef, r_low, utils.c(0, i32)) - vreg = llvm.insertelement(high_undef, r_high, utils.c(1, i32)) - vector_regs[idx[1] + 2 * row_group, idx[0]] = vreg + raise NotImplementedError(reg_packing) + + vector_regs = np.ndarray((4, cols // 8), dtype=object) + + it = _transfer_32xcols(base_addr, cols, (16, 8), tmem_packing, reg_packing) + c0 = arith.constant(i32, 0) + c1 = arith.constant(i32, 1) + for addr_row_col, instr_num, lane_step, num_slice in it: + regs = _tmem_load(addr_row_col, load_shape, instr_num, pack) + row_slice = slice(lane_step * 2, (lane_step + 1) * 2) + # This aliases the original array, so updates will be reflected there. + vector_regs_update = vector_regs[row_slice, num_slice] + assert vector_regs_update.shape == (2, instr_num), (vector_regs_update.shape, instr_num) + if reg_packing == 1: + regs = [llvm.bitcast(dtype, r) for r in regs] + # From a single lane perspective a num tile consists of a 2x2, with the + # minor dim traversing columns and major being 8 rows apart. + # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16256b + regs = np.asarray(regs, dtype=object).reshape(instr_num, 2, 2).swapaxes(0, 1) + undef = llvm.mlir_undef(vec_ty) + assert regs.shape == (*vector_regs_update.shape, 2) + for idx in np.ndindex(vector_regs_update.shape): + high_undef = llvm.insertelement(undef, regs[(*idx, 0)], c0) + vreg = llvm.insertelement(high_undef, regs[(*idx, 1)], c1) + vector_regs_update[idx] = vreg + else: + assert reg_packing == 2 + regs = [llvm.bitcast(vec_ty, r) for r in regs] + # From a single lane perspective a num tile has 2 registers, 8 rows apart. + # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16128b + regs = np.asarray(regs, dtype=object).reshape(instr_num, 2).swapaxes(0, 1) + vector_regs_update[...] = regs + return vector_regs -def _m128_256bit_32bit_layout(shape: tuple[int, ...]): - if len(shape) != 2: - raise ValueError(f"Shape {shape} is not 2D") - if shape[0] % 128 != 0 or shape[1] % 8 != 0: - raise ValueError(f"Shape {shape} is not a multiple of 64x8") - return fa.TiledLayout( - fa.Tiling(((128, 8), (32, 8), (8, 8), (1, 2))), - warp_dim=-8, - lane_dims=(-4, -3), - vector_dim=-1, - ) +def _load_32xcols_native(base_addr, cols, dtype, tmem_packing, vector_length) -> np.ndarray: + i32 = ir.IntegerType.get_signless(32) + vec_ty = ir.VectorType.get((vector_length,), dtype) + reg_packing = 32 // utils.bitwidth(dtype) + assert vector_length % reg_packing == 0 + load_shape = "32x32b" + load_atom_shape = (32, reg_packing) + if reg_packing == 2: + assert 1 <= tmem_packing <= 2 + pack = tmem_packing == 1 + else: + if tmem_packing != reg_packing: + raise NotImplementedError( + f"Only {reg_packing} supported for element type {dtype}, but got" + f" TMEM packing of {tmem_packing}" + ) + pack = False + + it = _transfer_32xcols(base_addr, cols, load_atom_shape, tmem_packing, reg_packing) + c0 = arith.constant(i32, 0) + c1 = arith.constant(i32, 1) + regs = [None] * (cols // reg_packing) + for addr_row_col, instr_num, lane_step, num_slice in it: + assert lane_step == 0, lane_step + instr_regs = _tmem_load(addr_row_col, load_shape, instr_num, pack) + if reg_packing == 1 and vector_length == 2: + regs[num_slice] = [llvm.bitcast(dtype, r) for r in instr_regs] + else: + regs[num_slice] = [utils.bitcast(r, vec_ty) for r in instr_regs] + + if reg_packing == 1 and vector_length == 2: + vector_regs = np.ndarray((cols // 2,), dtype=object) + undef = llvm.mlir_undef(vec_ty) + for idx in range(vector_regs.size): + high_undef = llvm.insertelement(undef, regs[2 * idx], c0) + vreg = llvm.insertelement(high_undef, regs[2 * idx + 1], c1) + vector_regs[idx] = vreg + else: + assert vector_length == reg_packing + vector_regs = np.asarray(regs, dtype=object) + + return vector_regs + + +def commit_tmem() -> None: + nvvm.tcgen05_wait(nvvm.Tcgen05WaitKind.STORE) + utils.warpgroup_barrier() + + +def wait_load_tmem() -> None: + nvvm.tcgen05_wait(nvvm.Tcgen05WaitKind.LOAD) + utils.warpgroup_barrier() + + +def async_copy_scales_smem_to_tmem( + smem_ref: ir.Value, tmem_ref: TMEMRef, collective: bool = False +) -> None: + """Asynchronously copies the scale data from SMEM to TMEM. + + The result of the copy can be awaited by calling ``commit_arrive`` and waiting + on the chosen ``Barrier``. However, if TMEM reference is to be consumed by a + MMA issued in the same thread, no additional synchronization is needed. + + At the moment the function requires ``smem_ref`` to be contiguous and have a + shape of ``(MN // 128, K // 128, 32, 16)`` for 8-bit scales (here MN stands + for the size of the non-contracting dimension which is M or N), matching the + scale layout for .scale_vec::1X. See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-a-layout-1x + for more details. Note that we always put the non-contracting dimension first. + If you have a (MN, K // 32) array of scales in JAX (where MN and K are + divisible by 128), you can prepare it for use in the kernel this way:: + + scales.reshape(mn // 128, 4, 32, k // 4, 4) + .transpose(0, 3, 2, 1, 4) + .reshape(mn // 128, k // 4, 32, 16) + + The TMEM ref is expected to have the logical shape of the scales + ``(MN, K // 32)``, and the layout created by ``scales_layout()``. + """ + i32 = ir.IntegerType.get_signless(32) + smem_ty = ir.MemRefType(smem_ref.type) + if (dtype := smem_ty.element_type) != tmem_ref.dtype: + raise ValueError(f"Incompatible dtypes: SMEM has {dtype}, TMEM has {tmem_ref.dtype}") + if dtype not in {ir.Float8E8M0FNUType.get(), ir.Float8E4M3FNType.get()}: + raise NotImplementedError(f"Unsupported dtype: {dtype}, only f8e8m0fnu and f8e4m3fn are supported") + if tmem_ref.shape[0] % TMEM_ROWS: + raise ValueError(f"TMEM reference must have a multiple of {TMEM_ROWS} rows, but got {tmem_ref.shape[0]}") + if tmem_ref.shape[1] % 4: + raise ValueError(f"TMEM reference must have a multiple of 4 columns, but got {tmem_ref.shape[1]}") + if tmem_ref.layout != scales_layout(): + raise ValueError(f"TMEM layout {tmem_ref.layout} is not supported") + smem_shape = tuple(smem_ty.shape) + expected_smem_shape = (tmem_ref.shape[0] // TMEM_ROWS, tmem_ref.shape[1] // 4, 32, 16) + if smem_shape != expected_smem_shape: + raise NotImplementedError( + f"SMEM has {smem_shape}, but expected {expected_smem_shape} for TMEM" + f" ref shape {tmem_ref.shape}" + ) + strides, _ = smem_ty.get_strides_and_offset() + # TODO(apaszke): This should only matter for the two minor dims. + if strides != utils.get_contiguous_strides(smem_shape): + raise ValueError("Only copies from contiguous SMEM references are supported") + mn_tile_stride, k_tile_stride = strides[:2] + # One tile of scales has 128 bytes. + if mn_tile_stride % 128 or k_tile_stride % 128: + raise ValueError("Scale tile strides must be a multiple of 128") + mn_tile_stride_i32 = mn_tile_stride // 4 + k_tile_stride_i32 = k_tile_stride // 4 + smem_base_ptr = utils.memref_ptr(smem_ref, 3) + # TODO(apaszke): Need to figure out the TMEM layout otherwise and MMA doesn't + # support it anyway. + if smem_shape[0] > 2: + raise NotImplementedError("Only M/N up to 256 supported") + for mn_tile, k_tile in np.ndindex(smem_shape[:2]): + load_ptr = utils.getelementptr( + smem_base_ptr, + [mn_tile * mn_tile_stride_i32 + k_tile * k_tile_stride_i32], + i32, + ) + # NOTE: The tiles are MN-minor in TMEM, but MN-major (logically) in SMEM. + store_addr = arith.addi( + tmem_ref.address, + arith.constant(i32, 4 * smem_shape[0] * k_tile + 4 * mn_tile), + ) + # The "core matrix" here is the same as in MMA: 8x(16 bytes). + desc = mma_utils.encode_descriptor(load_ptr, 0, 8 * 16, swizzle=None) + nvvm.tcgen05_cp( + nvvm.Tcgen05CpShape.SHAPE_32x128b, + _tmem_addr_to_ptr(store_addr), + desc, + multicast=nvvm.Tcgen05CpMulticast.WARPX4, + group=nvvm.CTAGroupKind.CTA_2 if collective else nvvm.CTAGroupKind.CTA_1 + ) + + +def async_copy_sparse_metadata_smem_to_tmem( + smem_ref: ir.Value, tmem_ref: TMEMRef, collective: bool = False +) -> None: + i8 = ir.IntegerType.get_signless(8) + i32 = ir.IntegerType.get_signless(32) + smem_ty = ir.MemRefType(smem_ref.type) + if (dtype := smem_ty.element_type) != tmem_ref.dtype: + raise ValueError(f"Incompatible dtypes: SMEM has {dtype}, TMEM has {tmem_ref.dtype}") + if dtype != ir.IntegerType.get_signless(2): + raise NotImplementedError(f"Unsupported dtype: {dtype}, only i2 supported") + if tmem_ref.shape[0] % 128: + raise ValueError(f"TMEM reference must have a multiple of 128 rows, but got {tmem_ref.shape[0]}") + if tmem_ref.shape[1] % 64: + raise ValueError(f"TMEM reference must have a multiple of 64 colums, but got {tmem_ref.shape[1]}") + if tmem_ref.layout != sparse_meta_layout(): + raise ValueError(f"TMEM layout {tmem_ref.layout} is not supported") + smem_shape = tuple(smem_ty.shape) + expected_smem_shape = (tmem_ref.shape[0] // 128, tmem_ref.shape[1] // 64, 128, 64) + if smem_shape != expected_smem_shape: + raise NotImplementedError( + f"SMEM has {smem_shape}, but expected {expected_smem_shape} for TMEM" + f" ref shape {tmem_ref.shape}" + ) + strides, _ = smem_ty.get_strides_and_offset() + if strides != utils.get_contiguous_strides(smem_shape): + raise ValueError("Only copies from contiguous SMEM references are supported") + if expected_smem_shape[0] != 1: + raise NotImplementedError("Only M=128 supported") + k_tile_stride = strides[1] + if k_tile_stride % 16: + raise ValueError("K tile stride must be a multiple of 16") + k_tile_byte_stride = k_tile_stride // 4 + smem_base_ptr = utils.memref_ptr(smem_ref, 3) + for k_tile in range(expected_smem_shape[1]): + load_ptr = utils.getelementptr( + smem_base_ptr, [k_tile * k_tile_byte_stride], i8 + ) + store_ptr = arith.addi(tmem_ref.address, arith.constant(i32, 4 * k_tile)) + # The "core matrix" here is the same as in MMA: 8x(16 bytes). + desc = mma_utils.encode_descriptor(load_ptr, 0, 8 * 16, swizzle=None) + ptr = _tmem_addr_to_ptr(store_ptr) + nvvm.tcgen05_cp( + nvvm.Tcgen05CpShape.SHAPE_128x128b, ptr, desc, + group=nvvm.CTAGroupKind.CTA_2 if collective else nvvm.CTAGroupKind.CTA_1 + ) diff --git a/jax/experimental/mosaic/gpu/test_util.py b/jax/experimental/mosaic/gpu/test_util.py new file mode 100644 index 000000000000..d36833b6cf12 --- /dev/null +++ b/jax/experimental/mosaic/gpu/test_util.py @@ -0,0 +1,72 @@ +# Copyright 2025 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import enum +import jax +from jax._src.lib.mlir import ir +from jax.experimental.mosaic.gpu import fragmented_array as fa +from jax.experimental.mosaic.gpu import layouts +from jax.experimental.mosaic.gpu import tcgen05 +from jax.experimental.mosaic.gpu import utils + + +class RegisterLayout(enum.Enum): + """The list of supported register layouts.""" + + WGMMA = enum.auto() + WG_SPLAT = enum.auto() + WG_STRIDED = enum.auto() + TCGEN05 = enum.auto() + TCGEN05_M64_COLLECTIVE = enum.auto() + TCGEN05_TMEM_NATIVE = enum.auto() + SMEM_GMEM_COPY = enum.auto() + TMA_GATHER_INDICES = enum.auto() + + def to_mgpu( + self, shape: tuple[int, int], dtype: jax.typing.DTypeLike | ir.Type + ) -> fa.FragmentedLayout: + if not isinstance(dtype, ir.Type): + dtype = utils.dtype_to_ir_type(dtype) + match self: + case RegisterLayout.WGMMA: + return fa.WGMMA_LAYOUT + case RegisterLayout.WG_SPLAT: + return fa.WGSplatFragLayout(shape) + case RegisterLayout.WG_STRIDED: + ty = ir.VectorType.get(shape, dtype) + layout = fa.WGStridedFragLayout.from_shaped_type(ty) + assert layout is not None + return layout + case RegisterLayout.TCGEN05: + return fa.TCGEN05_LAYOUT + case RegisterLayout.TCGEN05_M64_COLLECTIVE: + return tcgen05.fa_m64_collective_layout(shape[1]) + case RegisterLayout.TCGEN05_TMEM_NATIVE: + return fa.TMEM_NATIVE_LAYOUT + case RegisterLayout.SMEM_GMEM_COPY: + swizzle = 128 + bitwidth = utils.bitwidth(dtype) + tiling = (8, 8 * swizzle // bitwidth) + row_tiles, col_tiles = utils.tile_shape(shape, tiling)[-4:-2] + return fa.tiled_copy_smem_gmem_layout( + row_tiles, col_tiles, swizzle, bitwidth + ) + case RegisterLayout.TMA_GATHER_INDICES: + return fa.TMA_GATHER_INDICES_LAYOUT + + def to_layout_attr( + self, shape: tuple[int, int], dtype: jax.typing.DTypeLike | ir.Type + ) -> ir.Attribute: + return layouts.to_layout_attr(self.to_mgpu(shape, dtype)) diff --git a/jax/experimental/mosaic/gpu/transform_inference.py b/jax/experimental/mosaic/gpu/transform_inference.py deleted file mode 100644 index ef2d3661674c..000000000000 --- a/jax/experimental/mosaic/gpu/transform_inference.py +++ /dev/null @@ -1,301 +0,0 @@ -# Copyright 2025 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Transform inference pass for the MLIR Mosaic GPU dialect. - -The transform inference pass is meant to run on IR that has already been -annotated with layouts (see `layout_inference.py` for the relevant pass). -""" - -from collections.abc import Callable -from functools import partial -import itertools -from typing import cast - -from jax._src.lib import mosaic_gpu_dialect as mgpu -from jax._src.lib.mlir import ir -from jax._src.lib.mlir.dialects import arith -from jax._src.lib.mlir.dialects import builtin -from jax._src.lib.mlir.dialects import gpu -from jax._src.lib.mlir.dialects import memref -from jax._src.lib.mlir.dialects import vector - -from . import fragmented_array as fa -from . import inference_utils -from . import layouts as layouts_lib -from . import utils - -# mypy: ignore-errors - -OptionalTransforms = tuple[list[ir.Attribute], list[ir.Attribute]] | None -TransformInferenceRule = Callable[[ir.OpView], OptionalTransforms] -_transform_inference_rules: dict[str, TransformInferenceRule] = {} - - -def _add_transform_inference_rule( - op: type[ir.OpView], rule: TransformInferenceRule -): - if op is not None: - _transform_inference_rules[op.OPERATION_NAME] = rule # pytype: disable=attribute-error - return rule - - -def _set_transform_attributes( - op: ir.OpView, - in_transforms: list[ir.Attribute], - out_transforms: list[ir.Attribute], -): - op.attributes["in_transforms"] = ir.ArrayAttr.get(in_transforms) - op.attributes["out_transforms"] = ir.ArrayAttr.get(out_transforms) - - -def infer_transforms_for_wgmma_ref(ref_ty: ir.MemRefType) -> ir.ArrayAttr: - if len(ref_ty.shape) != 2: - raise ValueError(f"Expected a 2D memref, got {ref_ty}") - - element_bytewidth = utils.bytewidth(ref_ty.element_type) - strides, _ = ref_ty.get_strides_and_offset() - - if strides[0] < strides[1]: - raise NotImplementedError("Transpositions aren't handled yet.") - - minor_dim = ref_ty.shape[1] - major_tiling = 8 - - # Try tiling with all swizzling modes starting from the largest one. - for swizzle in [ - mgpu.SwizzlingMode.k128ByteSwizzle, - mgpu.SwizzlingMode.k64ByteSwizzle, - mgpu.SwizzlingMode.k32ByteSwizzle, - mgpu.SwizzlingMode.kNoSwizzle, - ]: - swizzle_elems = swizzle // element_bytewidth - if minor_dim % swizzle_elems == 0: - minor_tiling = swizzle_elems - break - else: - # No valid tile transform can be inferred. - raise ValueError( - f"{ref_ty.shape} is not a valid WGMMA shape" - ) - - return ir.ArrayAttr.get([ - mgpu.TileTransformAttr.get((major_tiling, minor_tiling)), - mgpu.SwizzleTransformAttr.get(minor_tiling * element_bytewidth), - ]) - - -@partial(_add_transform_inference_rule, mgpu.WGMMAOp) -def infer_wgmma_transforms(op: mgpu.WGMMAOp) -> OptionalTransforms: - b_transforms = infer_transforms_for_wgmma_ref(ir.MemRefType(op.b.type)) - if ir.MemRefType.isinstance(op.a.type): - a_transforms = infer_transforms_for_wgmma_ref( - cast(ir.MemRefType, op.a.type) - ) - return [a_transforms, b_transforms], [] - return [b_transforms], [] - - -@partial(_add_transform_inference_rule, mgpu.AsyncStoreOp) -def _infer_async_store_transforms(op: mgpu.AsyncStoreOp) -> OptionalTransforms: - in_transforms = inference_utils.value_transforms(op.source) - return None if in_transforms is None else ([in_transforms], []) - - -@partial(_add_transform_inference_rule, mgpu.AsyncLoadOp) -def _infer_async_load_transforms(op: mgpu.AsyncLoadOp) -> OptionalTransforms: - in_transforms = inference_utils.value_transforms(op.destination) - return None if in_transforms is None else ([in_transforms], []) - - -@partial(_add_transform_inference_rule, vector.LoadOp) -@partial(_add_transform_inference_rule, vector.StoreOp) -def _infer_vector_load_store_transforms( - op: vector.LoadOp | vector.StoreOp, -) -> OptionalTransforms: - for i in op.indices: - index_defining_op = i.owner.opview - if ( - not isinstance(index_defining_op, arith.ConstantOp) - or index_defining_op.literal_value != 0 - ): - # TODO(bchetioui): handle slicing. - raise NotImplementedError( - f"Only constants with value 0 are supported as indices for {op}" - ) - - if isinstance(op, vector.LoadOp): - [layout_attr] = inference_utils.out_layouts(op) - else: - assert isinstance(op, vector.StoreOp) - [layout_attr] = inference_utils.in_layouts(op) - - layout = layouts_lib.from_layout_attr(layout_attr) - transforms = inference_utils.value_transforms(op.base) - - if layout == fa.WGMMA_LAYOUT: - layout_transforms = infer_transforms_for_wgmma_ref( - ir.MemRefType(op.base.type) - ) - elif (isinstance(layout, fa.WGStridedFragLayout) or - isinstance(layout, fa.WGSplatFragLayout)): - layout_transforms = None - else: - raise NotImplementedError( - f"Got layout {layout} which is not yet supported" - ) - - if transforms is not None and layout_transforms is not None: - if transforms != layout_transforms: - raise NotImplementedError( - f"Conflicting transforms for {op.base} in {op}: " - f"{transforms} != {layout_transforms}." - ) - return [transforms], [] - - if transforms is not None: - return [transforms], [] - - if layout_transforms is not None: - return [layout_transforms], [] - - return None - -# TODO(bchetioui): remove this once jaxlib minimum version >= 0.5.2. -SliceSMEMOp = getattr(mgpu, "SliceSMEMOp", None) - -@partial(_add_transform_inference_rule, SliceSMEMOp) -def _infer_slice_smem_transforms(op: SliceSMEMOp) -> OptionalTransforms: - transforms = None - uses = cast(ir.OpResult, op.result).uses - - for op_operand_use in uses: - consumer = op_operand_use.owner - op_user = consumer.operands[op_operand_use.operand_number] - out_transforms = inference_utils.in_transforms_for_operand( - consumer, op_user - ) - if transforms is not None and out_transforms is not None: - if transforms != out_transforms: - raise NotImplementedError( - f"Conflicting transforms for {op_user} in {op}: " - f"{transforms} != {out_transforms}." - ) - elif out_transforms is not None: - transforms = out_transforms - - return None if transforms is None else ([], [transforms]) - - -# TODO(bchetioui,apaszke): this empty rule is necessary while Mosaic doesn't use -# the dialect in all cases. -# The rule is necessary in order to handle the lowering of `utils.memref_ptr` -# which is used in `_construct_smem_reftree`. -@partial(_add_transform_inference_rule, builtin.UnrealizedConversionCastOp) -def _infer_unrealized_conversion_cast_transforms( - _: builtin.UnrealizedConversionCastOp, -) -> OptionalTransforms: - return None - - -@partial(_add_transform_inference_rule, memref.ViewOp) -def _infer_memref_view_transforms(op: memref.ViewOp) -> OptionalTransforms: - if not isinstance(op.source.owner.opview, gpu.DynamicSharedMemoryOp): - raise NotImplementedError( - "Memref view transforms are only inferred when the op is a direct user " - f"of a DynamicSharedMemoryOp but got {op}." - ) - transforms = inference_utils.value_transforms(op.source) - if transforms is not None: - raise NotImplementedError( - "memref view with in_transforms aren't yet supported" - ) - uses = cast(ir.OpResult, op.result).uses - - for op_operand_use in uses: - consumer = op_operand_use.owner - op_user = consumer.operands[op_operand_use.operand_number] - out_transforms = inference_utils.in_transforms_for_operand( - consumer, op_user - ) - if transforms is not None and out_transforms is not None: - if transforms != out_transforms: - raise ValueError( - f"Conflicting transforms for {op_user} in {op}: " - f"{transforms} != {out_transforms}." - ) - elif out_transforms is not None: - transforms = out_transforms - - # TODO(bchetioui): do we actually need to assign a transform to the input of - # the view op? Presumably, it'll only be used to access scratch memory. - return None if transforms is None else ([], [transforms]) - - -# TODO(bchetioui,apaszke): this empty rule is necessary while Mosaic doesn't use -# the dialect in all cases. -@partial(_add_transform_inference_rule, gpu.DynamicSharedMemoryOp) -def _infer_dynamic_smem_transforms( - _: gpu.DynamicSharedMemoryOp, -) -> OptionalTransforms: - return None - - -def _should_have_transforms(op: ir.OpView) -> bool: - """Returns 'True' if the operation should be assigned in/out transforms.""" - return any( - map( - inference_utils.is_transformable_smem_memref, - itertools.chain(op.operands, op.results), - ) - ) - - -def infer_transforms(module: ir.Module): - """Infers transforms for the given module. - - Transforms are to memrefs what layouts are to vectors. More specifically, - transforms describe mappings between SMEM refs and GMEM refs, and are - determined based on how SMEM refs are used. For that reason, we always - annotate and apply memrefs on SMEM refs. - - The pass is meant to be called on a module where layouts have been fully - specified. We error out if two distinct sets of transforms are competing to - annotate the same memref. - """ - def inference_step(op: ir.Operation): - if not _should_have_transforms(op): - return - elif inference_rule := _transform_inference_rules.get(op.OPERATION_NAME, None): # pytype: disable=attribute-error - pass - else: - raise NotImplementedError(f"Can not infer transforms for {op}") - - maybe_transforms = inference_rule(op) - if maybe_transforms is None: - return - - _set_transform_attributes(op, *maybe_transforms) - - # It's enough to do a single backwards propagation (starting from vector - # users), and then a single forward propagation (to feed into the async loads - # and stores). - for op in module.body: - inference_utils.traverse_op( - op, inference_step, inference_utils.TraversalOrder.BACKWARDS - ) - for op in module.body: - inference_utils.traverse_op( - op, inference_step, inference_utils.TraversalOrder.FORWARD - ) diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index 28534cf4025b..990c7d60b7a9 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -24,6 +24,7 @@ import jax from jax import numpy as jnp +from jax._src.lib import mosaic_gpu_dialect as dialect # noqa: F401 from jax.interpreters import mlir from jaxlib.mlir import ir from jaxlib.mlir.dialects import arith @@ -36,11 +37,10 @@ from jaxlib.mlir.dialects import vector import numpy as np -from jax._src.lib import mosaic_gpu_dialect as dialect # noqa: F401 - -# mypy: ignore-errors +WARP_SIZE: int = 32 WARPGROUP_SIZE: int = 128 +WARPS_IN_WARPGROUP: int = WARPGROUP_SIZE // WARP_SIZE DYNAMIC = -9223372036854775808 DYNAMIC32 = -2147483648 MBARRIER_BYTES = 8 @@ -63,13 +63,19 @@ def gpu_address_space_to_nvptx(address_space: gpu.AddressSpace) -> int: ) -def ptr_as_memref(ptr, memref_ty: ir.MemRefType, ptr_memory_space: int | None = None): +def ptr_as_memref( + ptr, memref_ty: ir.MemRefType, ptr_memory_space: int | None = None +): + strides, offset = memref_ty.get_strides_and_offset() + if offset != 0: + raise ValueError("Non-zero offset is not supported for ptr_as_memref") i64 = ir.IntegerType.get_signless(64) rank = len(memref_ty.shape) ptr_ty = "ptr" if ptr_memory_space is None else f"ptr<{ptr_memory_space}>" if rank > 0: desc_ty = ir.Type.parse( - f"!llvm.struct<({ptr_ty}, {ptr_ty}, i64, array<{rank} x i64>, array<{rank} x i64>)>" + f"!llvm.struct<({ptr_ty}, {ptr_ty}, i64, array<{rank} x i64>," + f" array<{rank} x i64>)>" ) else: desc_ty = ir.Type.parse(f"!llvm.struct<({ptr_ty}, {ptr_ty}, i64)>") @@ -84,7 +90,7 @@ def ptr_as_memref(ptr, memref_ty: ir.MemRefType, ptr_memory_space: int | None = desc = llvm.InsertValueOp( desc, llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, s)), [3, i] ) - for i, s in enumerate(get_contiguous_strides(memref_ty.shape)): + for i, s in enumerate(strides): desc = llvm.InsertValueOp( desc, llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, s)), [4, i] ) @@ -99,7 +105,7 @@ def pack_array(values): ptr_ty = ir.Type.parse("!llvm.ptr") arr_ptr = llvm.alloca(ptr_ty, c(len(values), i64), elem_ty) for i, v in enumerate(values): - elem_ptr = llvm.getelementptr(ptr_ty, arr_ptr, [], [i], elem_ty) + elem_ptr = getelementptr(arr_ptr, [i], elem_ty) llvm.store(v, elem_ptr) return arr_ptr @@ -114,46 +120,59 @@ def get_contiguous_strides(xs): def c(val: int | float, ty): - if ir.IntegerType.isinstance(ty) or ir.IndexType.isinstance(ty): + if isinstance(ty, ir.IntegerType) or isinstance(ty, ir.IndexType): if not isinstance(val, (int, np.integer)): raise TypeError(type(val)) attr = ir.IntegerAttr.get(ty, val) - elif ir.FloatType.isinstance(ty): + elif isinstance(ty, ir.FloatType): attr = ir.FloatAttr.get(ty, val) - elif ir.VectorType.isinstance(ty): - return vector.splat(ty, c(val, ir.VectorType(ty).element_type)) + elif isinstance(ty, ir.VectorType): + return vector.broadcast(ty, c(val, ir.VectorType(ty).element_type)) else: raise NotImplementedError(ty) return arith.constant(ty, attr) + def _debug_scalar_ty_format(arg): - if ir.IndexType.isinstance(arg.type): + if isinstance(arg.type, ir.IndexType): return "%llu", arg - if ir.IntegerType.isinstance(arg.type): + if isinstance(arg.type, ir.IntegerType): if ir.IntegerType(arg.type).width < 64: arg = arith.extui(ir.IntegerType.get_signless(64), arg) return "%llu", arg - if ir.F32Type.isinstance(arg.type): + if isinstance(arg.type, ir.F32Type): return "%f", arg - if ir.F16Type.isinstance(arg.type): + if isinstance(arg.type, ir.BF16Type) or isinstance(arg.type, ir.F16Type): arg = arith.extf(ir.F32Type.get(), arg) return "%f", arg raise NotImplementedError(f"Can't print the type {arg.type}") -def debug_print(fmt, *args, uniform=True): + +def debug_print(fmt, *args, uniform=True, scope=None): + if not uniform and scope is not None: + raise ValueError("Cannot specify scope to a non-uniform debug_print.") + if scope is None: + scope = ThreadSubset.WARPGROUP type_formats = [] new_args = [] for arg in args: - if ir.VectorType.isinstance(arg.type): + if isinstance(arg.type, ir.VectorType): index = ir.IndexType.get() vec_ty = ir.VectorType(arg.type) if len(vec_ty.shape) > 1: - raise NotImplementedError(vec_ty) + raise NotImplementedError( + "2D+ vectors are not supported in debug_print:" + f" {vec_ty}" + ) vec_args = [ - vector.extractelement(arg, position=c(i, index)) + vector.extract( + arg, + dynamic_position=[], + static_position=ir.DenseI64ArrayAttr.get([i]), + ) for i in range(vec_ty.shape[0]) ] - ty_formats, args = zip(*map(_debug_scalar_ty_format,vec_args)) + ty_formats, args = zip(*map(_debug_scalar_ty_format, vec_args)) ty_format = f"[{','.join(ty_formats)}]" new_args += args else: @@ -164,12 +183,149 @@ def debug_print(fmt, *args, uniform=True): raise NotImplementedError(arg.type) type_formats.append(ty_format) ctx = ( - functools.partial(single_thread, per_block=False) + functools.partial(single_thread, scope=scope) if uniform else contextlib.nullcontext ) with ctx(): - gpu.printf(fmt.format(*type_formats) + "\n", new_args) + gpu.printf(fmt.format(*type_formats) + "\n", *new_args) + + +@dataclasses.dataclass(frozen=True) +class MultimemRef: + ref: ir.Value + + @property + def type(self) -> ir.Type: + return ir.MemRefType(self.ref.type) + + def store(self, value: ir.Value, indices: Sequence[ir.Value]): + ptr = memref_ptr(memref_slice(self.ref, tuple(indices))) + multimem_store(ptr, value) + + +def multimem_store(ptr: ir.Value, value: ir.Value): + i32 = ir.IntegerType.get_signless(32) + if (bw := bitwidth(value.type)) not in {32, 64, 128}: + raise ValueError("Only 32-, 64- and 128-bit stores are supported") + vector_length = bw // 32 + value = bitcast(value, ir.VectorType.get((vector_length,), i32)) + regs = [ + llvm.extractelement(value, arith.constant(i32, i)) + for i in range(vector_length) + ] + if vector_length == 1: + vec_ptx = "$1" + vec_mod = "" + else: + vec_ptx = f"{{{','.join(f'${i}' for i in range(1, vector_length + 1))}}}" + vec_mod = ".v" + str(vector_length) + # It's unclear to me why, but at least according to PTX docs, we have to use + # the floating-point instructions here to be able to store vectors. + llvm.inline_asm( + ir.Type.parse("!llvm.void"), + [ptr, *regs], + f"multimem.st.relaxed.sys.global{vec_mod}.f32 [$0], {vec_ptx};", + "l" + ",r" * len(regs), + has_side_effects=True, + ) + + +MultimemReductionOp = Literal["add", "min", "max", "and", "or", "xor"] + + +def multimem_load_reduce( + ty: ir.Type, + ptr: ir.Value, + reduction: MultimemReductionOp, + is_signed: bool | None = None, +): + i32 = ir.IntegerType.get_signless(32) + if bitwidth(ty) not in {32, 64, 128}: + raise ValueError("Only 32-, 64- and 128-bit loads are supported") + if isinstance(ty, ir.VectorType): + vty = ir.VectorType(ty) + if len(vty.shape) > 1: + raise ValueError("Only 1D vectors are supported") + vector_length = vty.shape[0] + vector_i32_length = vector_length * bitwidth(vty.element_type) // 32 + if isinstance(vty.element_type, ir.IntegerType): + # TODO(apaszke): Emulate this by unrolling. + if vector_length != 1: + raise NotImplementedError( + "Only single-element integer operations are supported" + ) + if bitwidth(vty.element_type) not in {32, 64}: + raise NotImplementedError( + "Only 32-bit and 64-bit integer operations are supported" + ) + if reduction in {"and", "or", "xor"}: + ptx_ty = f"b{bitwidth(vty.element_type)}" + elif reduction in {"min", "max", "add"}: + if is_signed is None: + raise ValueError( + "Signedness must be specified for integer min, max and add" + " reductions" + ) + ptx_ty = f"{'s' if is_signed else 'u'}{bitwidth(vty.element_type)}" + else: + raise ValueError(f"Unsupported reduction operation: {reduction}") + elif isinstance(vty.element_type, ir.FloatType): + if reduction not in {"add", "min", "max"}: + raise ValueError("Only add, min and max are supported for floats") + if isinstance(vty.element_type, ir.F32Type): + if reduction != "add": + raise ValueError("Only add is supported for f32") + ptx_ty = "f32" + elif isinstance(vty.element_type, ir.BF16Type): + ptx_ty = "bf16x2" + elif isinstance(vty.element_type, ir.F16Type): + ptx_ty = "f16x2" + elif isinstance(vty.element_type, ir.Float8E5M2Type): + ptx_ty = "e5m2x4" + elif isinstance(vty.element_type, ir.Float8E4M3FNType): + ptx_ty = "e4m3x4" + else: + raise NotImplementedError(vty.element_type) + else: + raise NotImplementedError(vty.element_type) + else: + raise NotImplementedError(ty) + if vector_i32_length == 1: + vec_ptx = "$0" + vec_mod = "" + else: + vec_ptx = f"{{{','.join(f'${i}' for i in range(vector_i32_length))}}}" + vec_mod = ".v" + str(vector_i32_length) + # It's unclear to me why, but at least according to PTX docs, we have to use + # the floating-point instructions here to be able to store vectors. + acc_prec = "" + if vector_i32_length == 1: + asm_out_ty = i32 + else: + asm_out_ty = ir.Type.parse( + f"!llvm.struct<({','.join(['i32'] * vector_i32_length)})>" + ) + out_reg_struct = llvm.inline_asm( + asm_out_ty, + [ptr], + f"multimem.ld_reduce.relaxed.sys.global.{reduction}{acc_prec}{vec_mod}.{ptx_ty} {vec_ptx}," + f" [${vector_i32_length}];", + "=r," * vector_i32_length + "l", + has_side_effects=True, + ) + if vector_i32_length == 1: + return bitcast(out_reg_struct, ty) + else: + out_regs = [ + llvm.extractvalue(i32, out_reg_struct, [i]) + for i in range(vector_i32_length) + ] + vec_i32_ty = ir.VectorType.get((1,), i32) + return bitcast( + vector_concat([bitcast(out_reg, vec_i32_ty) for out_reg in out_regs]), + ty, + ) @dataclasses.dataclass(frozen=True) @@ -222,15 +378,19 @@ def when(cond): scf.yield_([]) -def thread_idx(): +def _3d_to_1d_idx(dim_idx_fn, dim_size_fn): i32 = ir.IntegerType.get_signless(32) as_i32 = lambda x: arith.index_cast(i32, x) - tidx = as_i32(gpu.thread_id(gpu.Dimension.x)) - stride = as_i32(gpu.block_dim(gpu.Dimension.x)) + idx = as_i32(dim_idx_fn(gpu.Dimension.x)) + stride = as_i32(dim_size_fn(gpu.Dimension.x)) for dim in (gpu.Dimension.y, gpu.Dimension.z): - tidx = arith.addi(tidx, arith.muli(as_i32(gpu.thread_id(dim)), stride)) - stride = arith.muli(stride, as_i32(gpu.block_dim(dim))) - return tidx + idx = arith.addi(idx, arith.muli(as_i32(dim_idx_fn(dim)), stride)) + stride = arith.muli(stride, as_i32(dim_size_fn(dim))) + return idx + + +thread_idx = functools.partial(_3d_to_1d_idx, gpu.thread_id, gpu.block_dim) +block_idx = functools.partial(_3d_to_1d_idx, gpu.block_id, gpu.grid_dim) def _warp_bcast(val, lane_idx=0): @@ -258,33 +418,43 @@ def warpgroup_idx(sync=True): class ThreadSubset(enum.IntEnum): + WARP = enum.auto() WARPGROUP = enum.auto() BLOCK = enum.auto() -# True withon `once()` contexts. +# True within `once()` contexts. _ONCE_PER: ThreadSubset | None = None -def single_thread_predicate(per_block=True): +def single_thread_predicate(scope: ThreadSubset = ThreadSubset.BLOCK): + """Returns a predicate that selects a single thread. + + Args: + scope: What level of the thread hierarchy to select a thread from. For + example, if the scope is BLOCK, only one thread per block will be + selected. + """ + elected = nvvm.elect_sync(ir.IntegerType.get_signless(1)) + if scope == ThreadSubset.WARP: + return elected warp = warp_idx() - if not per_block: + if scope is not ThreadSubset.BLOCK: warp = arith.remui(warp, c(4, warp.type)) first_warp = arith.cmpi(arith.CmpIPredicate.eq, warp, c(0, warp.type)) - elected = nvvm.elect_sync(ir.IntegerType.get_signless(1)) return arith.andi(first_warp, elected) @contextlib.contextmanager -def single_thread(per_block=True): +def single_thread(scope: ThreadSubset = ThreadSubset.BLOCK): """Runs the context only from a single thread. Args: - per_block: If True, only one thread per block will run the context. - Otherwise, only one thread per warp group will run the context. + scope: What level of the thread hierarchy to select a thread from. For + example, if the scope is BLOCK, only one thread per block will be + selected. """ global _ONCE_PER - scope = ThreadSubset.BLOCK if per_block else ThreadSubset.WARPGROUP # If we're already in a single-thread context, we don't have to do anything. if _ONCE_PER is not None and _ONCE_PER >= scope: yield @@ -293,7 +463,7 @@ def single_thread(per_block=True): prev_scope = _ONCE_PER _ONCE_PER = scope try: - if_op = scf.IfOp(single_thread_predicate(per_block)) + if_op = scf.IfOp(single_thread_predicate(scope)) with ir.InsertionPoint(if_op.then_block): yield scf.YieldOp([]) @@ -310,22 +480,28 @@ def clock(): def smid(): i32 = ir.IntegerType.get_signless(32) - return llvm.inline_asm( - i32, [], "mov.u32 $0,%smid;", "=r", asm_dialect=0 - ) + return llvm.inline_asm(i32, [], "mov.u32 $0,%smid;", "=r", asm_dialect=0) def globaltimer(kind: Literal["low", "high"] | None = None): if kind is None: i64 = ir.IntegerType.get_signless(64) return llvm.inline_asm( - i64, [], "mov.u32 $0,%globaltimer;", - "=l", asm_dialect=0, has_side_effects=True, + i64, + [], + "mov.u64 $0,%globaltimer;", + "=l", + asm_dialect=0, + has_side_effects=True, ) i32 = ir.IntegerType.get_signless(32) return llvm.inline_asm( - i32, [], f"mov.u32 $0,%globaltimer_{kind[:2]};", - "=r", asm_dialect=0, has_side_effects=True, + i32, + [], + f"mov.u32 $0,%globaltimer_{kind[:2]};", + "=r", + asm_dialect=0, + has_side_effects=True, ) @@ -340,15 +516,15 @@ def bitwidth_impl(ty: ir.Type): # 32 bits for compatibility reasons. TF32 used to be 32 bits wide in upstream # MLIR, but it changed in # https://github.com/llvm/llvm-project/commit/67a1fdb014790a38a205d28e1748634de34471dd. - if ir.FloatTF32Type.isinstance(ty): + if isinstance(ty, ir.FloatTF32Type): return 32 - if ir.IntegerType.isinstance(ty): + if isinstance(ty, ir.IntegerType): return ir.IntegerType(ty).width - if ir.FloatType.isinstance(ty): + if isinstance(ty, ir.FloatType): return ir.FloatType(ty).width if dialect is not None and ty == ir.Type.parse("!mosaic_gpu.barrier"): return MBARRIER_BYTES * 8 - if ir.VectorType.isinstance(ty): + if isinstance(ty, ir.VectorType): vty = ir.VectorType(ty) return math.prod(vty.shape) * bitwidth(vty.element_type) raise NotImplementedError(ty) @@ -399,7 +575,10 @@ def memref_slice(ref: ir.Value, index) -> ir.Value: new_layout = ir.StridedLayoutAttr.get(new_offset, new_strides) ref_slice = memref.subview( - ref, base_indices, slice_shape, [1] * len(ref_ty.shape), + ref, + base_indices, + slice_shape, + [1] * len(ref_ty.shape), result_type=ir.MemRefType.get( new_shape, ref_ty.element_type, new_layout, ref_ty.memory_space ), @@ -411,7 +590,7 @@ def _is_contiguous_shape_slice( ref_ty: ir.MemRefType, dim_slice: slice | None = slice(None) ): # If it's not a strided layout then we are definitely contiguous. - if not ir.StridedLayoutAttr.isinstance(ref_ty.layout): + if not isinstance(ref_ty.layout, ir.StridedLayoutAttr): return True strides = ir.StridedLayoutAttr(ref_ty.layout).strides[dim_slice] @@ -435,7 +614,8 @@ def _reshape(ref: ir.Value, sh0: list[int], sh1: list[int]): """ i0, i1 = 0, 0 - def fold_until(shape, off , target) -> tuple[int, int]: + + def fold_until(shape, off, target) -> tuple[int, int]: assert shape[off] < target dim = 1 for to in range(off, len(shape)): @@ -446,16 +626,22 @@ def fold_until(shape, off , target) -> tuple[int, int]: # TODO(cperivol): Implement dependent fold-unfolds for subsections # of the shape eg (..., 4,5,5, ...) -> (..., 10,10, ...) could be # supported without touching any other dimensions. - raise NotImplementedError(f"Can't reshape {sh0} to {sh1} bu composing independent folds/unfolds.") + raise NotImplementedError( + f"Can't reshape {sh0} to {sh1} by composing independent" + " folds/unfolds." + ) - raise AssertionError(f"Unreachable: number of elements don't match in each shape ({sh0} ans {sh1})") + raise AssertionError( + f"Unreachable: number of elements don't match in each shape ({sh0} ans" + f" {sh1})" + ) while i0 < len(sh0) and i1 < len(sh1): if sh0[i0] > sh1[i1]: # How many dimensions following i1 should we unfold i0 into. idx, _ = fold_until(sh1, i1, sh0[i0]) ref = memref_unfold(ref, i0, sh1[i1:idx]) - sh0[i0:i0+1] = sh1[i1:idx] + sh0[i0 : i0 + 1] = sh1[i1:idx] i0 += idx - i1 i1 = idx elif sh0[i0] < sh1[i1]: @@ -481,28 +667,90 @@ def fold_until(shape, off , target) -> tuple[int, int]: return ref -def memref_reshape(ref: ir.Value, shape: tuple[int, ...]) -> ir.Value: +def memref_reshape( + ref: ir.Value | MultimemRef, shape: tuple[int, ...] +) -> ir.Value | MultimemRef: """Reshape by means of folding and unfolding. The use of memref fold/unfold may avoid some possible issues with strided memrefs. """ + if isinstance(ref, MultimemRef): + return MultimemRef(memref_reshape(ref.ref, shape)) + ref_ty = ir.MemRefType(ref.type) if math.prod(ref_ty.shape) != math.prod(shape): - raise ValueError("Cannot reshape to a different size") + raise ValueError( + f"Cannot reshape to a different size. Ref shape: {ref_ty.shape} (size:" + f" {math.prod(ref_ty.shape)}), new shape: {shape} (size:" + f" {math.prod(shape)})" + ) if not all(dim > 0 for dim in shape): raise ValueError( "Shapes must havbe only positive dimensions (no -1 or 0 dimensions" f" allowed) {shape}" ) - return _reshape(ref, list(ref_ty.shape), list(shape)) + src_shape = list(ref_ty.shape) + dst_shape = list(shape) + if src_shape == dst_shape: + return ref + if not src_shape: + _, offset = ref_ty.get_strides_and_offset() + identity = ir.AffineMapAttr.get(ir.AffineMap.get_identity(0)) + if ref_ty.layout == identity: + new_layout = ir.AffineMapAttr.get( + ir.AffineMap.get_identity(len(dst_shape)) + ) + else: + new_layout = ir.StridedLayoutAttr.get(offset, [1] * len(dst_shape)) + result_ty = ir.MemRefType.get( + dst_shape, ref_ty.element_type, new_layout, ref_ty.memory_space + ) + return memref.expand_shape(result_ty, ref, [], [], dst_shape) + if not dst_shape: + _, offset = ref_ty.get_strides_and_offset() + identity = ir.AffineMapAttr.get(ir.AffineMap.get_identity(ref_ty.rank)) + contig_strided_1d = ir.Attribute.parse("strided<[1]>") + if ref_ty.layout == identity or ref_ty.layout == contig_strided_1d: + new_layout = ir.AffineMapAttr.get(ir.AffineMap.get_identity(0)) + else: + new_layout = ir.StridedLayoutAttr.get(offset, []) + result_ty = ir.MemRefType.get( + (), ref_ty.element_type, new_layout, ref_ty.memory_space + ) + return memref.collapse_shape(result_ty, ref, []) + # For contiguous refs we can do arbitrary reshapes easily. + strides, _ = ref_ty.get_strides_and_offset() + if all( + d == 1 or s1 == s2 + for d, s1, s2 in zip( + ref_ty.shape, + get_contiguous_strides(ref_ty.shape), + strides, + strict=True, + ) + ): + return memref_unfold(memref_fold(ref, 0, ref_ty.rank), 0, shape) + return _reshape(ref, src_shape, dst_shape) + +def memref_fold( + ref: ir.Value | MultimemRef, dim, fold_rank +) -> ir.Value | MultimemRef: + if isinstance(ref, MultimemRef): + return MultimemRef(memref_fold(ref.ref, dim, fold_rank)) -def memref_fold(ref: ir.Value, dim, fold_rank) -> ir.Value: ref_ty = ir.MemRefType(ref.type) new_shape = list(ref_ty.shape) + if dim < 0: + raise ValueError(f"Dimension {dim} is negative") + if dim + fold_rank > len(new_shape): + raise ValueError( + f"Folding {fold_rank} dimensions starting from {dim} is out of bounds" + f" for shape {new_shape}" + ) new_shape[dim : dim + fold_rank] = [np.prod(new_shape[dim : dim + fold_rank])] identity = ir.AffineMapAttr.get(ir.AffineMap.get_identity(ref_ty.rank)) contig_strided_1d = ir.Attribute.parse("strided<[1]>") @@ -516,7 +764,7 @@ def memref_fold(ref: ir.Value, dim, fold_rank) -> ir.Value: new_strides[dim : dim + fold_rank] = [new_strides[dim + fold_rank - 1]] new_layout = ir.StridedLayoutAttr.get(offset, new_strides) else: - raise NotImplementedError( + raise ValueError( f"strides={ref_ty.get_strides_and_offset()[0]}, {ref_ty.shape=}," f" {dim=}, {fold_rank=}" ) @@ -545,7 +793,8 @@ def memref_unfold(ref: ir.Value, dim, factors) -> ir.Value: ) new_shape[dim : dim + 1] = factors identity = ir.AffineMapAttr.get(ir.AffineMap.get_identity(ref_ty.rank)) - if ref_ty.layout == identity: + contig_strided_1d = ir.Attribute.parse("strided<[1]>") + if ref_ty.layout == identity or ref_ty.layout == contig_strided_1d: new_layout = ir.AffineMapAttr.get( ir.AffineMap.get_identity(ref_ty.rank + len(factors) - 1) ) @@ -597,6 +846,16 @@ def memref_unsqueeze(ref: ir.Value, dim) -> ir.Value: return memref_unfold(ref, dim, (1, None)) +def is_memref_transposed(ref: ir.MemRefType) -> bool: + strides, _ = ref.get_strides_and_offset() + prev_stride = math.inf + for stride in strides: + if stride > prev_stride: + return True + prev_stride = stride + return False + + def memref_transpose(ref: ir.Value, permutation: Sequence[int]) -> ir.Value: ref_ty = ir.MemRefType(ref.type) strides, offset = ref_ty.get_strides_and_offset() @@ -661,7 +920,7 @@ def parse_indices( slice_shape.append(idx.length) is_squeezed.append(False) elif isinstance(idx, ir.Value): - if not ir.IndexType.isinstance(idx.type): + if not isinstance(idx.type, ir.IndexType): raise ValueError("Expected an index-typed index") base_indices.append(idx) slice_shape.append(1) @@ -692,6 +951,10 @@ def warpgroup_barrier(): ) +def warp_barrier(): + nvvm.bar_warp_sync(c(0xFFFFFFFF, ir.IntegerType.get_signless(32))) + + @dataclasses.dataclass(frozen=True) class BarrierRef: base_address: ir.Value @@ -700,18 +963,24 @@ class BarrierRef: num_barriers: int @staticmethod - def initialize(address: ir.Value, num_barriers: int, arrival_count: int = 1) -> "BarrierRef": + def initialize( + barrier_memref: ir.Value, arrival_count: int = 1 + ) -> "BarrierRef": + barrier_ty = ir.MemRefType(barrier_memref.type) + [num_barriers] = barrier_ty.shape if num_barriers > 32: raise NotImplementedError("Only up to 32 barriers per group supported") i32 = ir.IntegerType.get_signless(32) i64 = ir.IntegerType.get_signless(64) - ptr = ir.Type.parse(f"!llvm.ptr<{WORKGROUP_NVPTX_ADDRESS_SPACE}>") + address = memref_ptr( + barrier_memref, memory_space=WORKGROUP_NVPTX_ADDRESS_SPACE + ) phases = memref.alloca(ir.MemRefType.get((), i32), [], []) memref.store(c(0, i32), phases, []) - with single_thread(per_block=True): + with single_thread(scope=ThreadSubset.BLOCK): for i in range(num_barriers): - nvvm.mbarrier_init_shared( - llvm.getelementptr(ptr, address, [], [i], i64), + nvvm.mbarrier_init( + getelementptr(address, [i], i64), c(arrival_count, i32), ) return BarrierRef(address, c(0, i32), phases, num_barriers) @@ -726,8 +995,10 @@ def __iter__(self) -> Iterator["BarrierRef"]: def __getitem__(self, offset: ir.Value | int) -> "BarrierRef": i32 = ir.IntegerType.get_signless(32) if isinstance(offset, int): + if offset >= self.num_barriers: + raise IndexError(f"Barrier offset {offset} is out of bounds") offset = c(offset, i32) - elif ir.IndexType.isinstance(offset.type): + elif isinstance(offset.type, ir.IndexType): offset = arith.index_castui(i32, offset) elif offset.type != i32: raise ValueError(f"Expected a dynamic index or an integer, got {offset}") @@ -738,23 +1009,25 @@ def __getitem__(self, offset: ir.Value | int) -> "BarrierRef": 1, ) - def wait_parity(self, parity, for_tensor_core=False): + def wait_parity(self, parity, orders_tensor_core=False): i32 = ir.IntegerType.get_signless(32) ticks = arith.constant(i32, 10000000) parity = arith.extui(i32, parity) - nvvm.mbarrier_try_wait_parity_shared(self.get_ptr(), parity, ticks) - if for_tensor_core: + nvvm.mbarrier_try_wait_parity(self.get_ptr(), parity, ticks) + if orders_tensor_core: llvm.inline_asm( ir.Type.parse("!llvm.void"), - [], "tcgen05.fence::after_thread_sync;", "", + [], + "tcgen05.fence::after_thread_sync;", + "", has_side_effects=True, ) - def wait(self, for_tensor_core: bool = False): + def wait(self, orders_tensor_core: bool = False): parities = memref.load(self.phases, []) parity, new_parities = self.update_parities(parities) memref.store(new_parities, self.phases, []) - self.wait_parity(parity, for_tensor_core) + self.wait_parity(parity, orders_tensor_core) def update_parities(self, parities: ir.Value) -> tuple[ir.Value, ir.Value]: i32 = ir.IntegerType.get_signless(32) @@ -764,39 +1037,145 @@ def update_parities(self, parities: ir.Value) -> tuple[ir.Value, ir.Value]: ) return parity, arith.xori(parities, bitmask) - def arrive(self): + def arrive( + self, + arrival_count: int = 1, + can_complete: bool = True, + orders_tensor_core: bool = False, + predicate: ir.Value | None = None, + ): i64 = ir.IntegerType.get_signless(64) - nvvm.mbarrier_arrive_shared(i64, self.get_ptr()) + if orders_tensor_core: + llvm.inline_asm( + ir.Type.parse("!llvm.void"), + [], + "tcgen05.fence::before_thread_sync;", + "", + has_side_effects=True, + ) + if can_complete: + pred_ptx = pred_constraint = "" + if predicate is not None: + pred_ptx = "@$2" + pred_constraint = ",b" + llvm.inline_asm( + ir.IntegerType.get_signless(64), + [self.get_ptr()] + ([predicate] if predicate is not None else []), + f"{pred_ptx} mbarrier.arrive.release.cta.shared::cta.b64 $0, [$1]," + f" {arrival_count};", + "=l,r" + pred_constraint, + has_side_effects=True, + ) + else: + if predicate is not None: + raise NotImplementedError( + "Predicate not supported for no-complete arrive" + ) + count = c(arrival_count, ir.IntegerType.get_signless(32)) + nvvm.mbarrier_arrive_nocomplete(i64, self.get_ptr(), count) def arrive_expect_tx( self, bytes: int | ir.Value, predicate: ir.Value | None = None ): if isinstance(bytes, int): bytes = c(bytes, ir.IntegerType.get_signless(32)) - elif ir.IndexType.isinstance(bytes.type): + elif isinstance(bytes.type, ir.IndexType): i32 = ir.IntegerType.get_signless(32) bytes = arith.index_cast(i32, bytes) - nvvm.mbarrier_arrive_expect_tx_shared(self.get_ptr(), bytes, predicate=predicate) + nvvm_mbarrier_arrive_expect_tx( + self.get_ptr(), bytes, predicate=predicate + ) + + def complete_tx( + self, bytes: int | ir.Value, predicate: ir.Value | None = None + ): + if isinstance(bytes, int): + bytes = c(bytes, ir.IntegerType.get_signless(32)) + elif isinstance(bytes.type, ir.IndexType): + i32 = ir.IntegerType.get_signless(32) + bytes = arith.index_cast(i32, bytes) + + pred_ptx = pred_constraint = "" + if predicate is not None: + pred_ptx = "@$2" + pred_constraint = ",b" + + llvm.inline_asm( + ir.Type.parse("!llvm.void"), + [self.get_ptr(), bytes] + ([predicate] if predicate is not None else []), + f"{pred_ptx} mbarrier.complete_tx.shared::cta.b64 [$0], $1;", + "l,r" + pred_constraint, + has_side_effects=True, + ) def get_ptr(self): - ptr = ir.Type.parse(f"!llvm.ptr<{WORKGROUP_NVPTX_ADDRESS_SPACE}>") i64 = ir.IntegerType.get_signless(64) - DYNAMIC32 = -2147483648 - return llvm.getelementptr( - ptr, self.base_address, [self.offset], [DYNAMIC32], i64 - ) + return getelementptr(self.base_address, [self.offset], i64) + + +@dataclasses.dataclass(frozen=True) +class DialectBarrierRef: + barrier_ref: BarrierRef - def as_dialect_barrier_memref(self) -> ir.Value: - shape = () if self.num_barriers == 1 else (self.num_barriers,) - return ptr_as_memref( - self.get_ptr(), - ir.MemRefType.get(shape, ir.Type.parse("!mosaic_gpu.barrier")), - ptr_memory_space=WORKGROUP_NVPTX_ADDRESS_SPACE, + @staticmethod + def initialize( + barrier_memref: ir.Value, + arrival_count: int = 1, + ) -> "DialectBarrierRef": + barrier_ty = ir.MemRefType(barrier_memref.type) + [num_barriers] = barrier_ty.shape + if num_barriers > 32: + raise NotImplementedError("Only up to 32 barriers per group supported") + + address = memref_ptr( + barrier_memref, memory_space=WORKGROUP_NVPTX_ADDRESS_SPACE + ) + dialect.initialize_barrier(address, arrival_count, num_barriers) + i32 = ir.IntegerType.get_signless(32) + phases = memref.alloca(ir.MemRefType.get((), i32), [], []) + memref.store(c(0, i32), phases, []) + return DialectBarrierRef( + barrier_ref=BarrierRef(address, c(0, i32), phases, num_barriers) ) + def __iter__(self) -> Iterator["DialectBarrierRef"]: + if self.barrier_ref.num_barriers == 1: + yield self + else: + for offset in range(self.barrier_ref.num_barriers): + yield self[offset] + + def __getitem__(self, offset: ir.Value | int) -> "DialectBarrierRef": + return DialectBarrierRef(self.barrier_ref[offset]) + + def wait_parity(self, parity, orders_tensor_core=False): + self.barrier_ref.wait_parity(parity, orders_tensor_core) + + def wait(self, orders_tensor_core: bool = False): + assert self.barrier_ref.phases is not None + self.barrier_ref.wait(orders_tensor_core) + + def update_parities(self, parities: ir.Value) -> tuple[ir.Value, ir.Value]: + return self.barrier_ref.update_parities(parities) + + def arrive(self, orders_tensor_core: bool = False): + dialect.ArriveOp(self.as_barrier_memref(), orders_tensor_core) + + def arrive_expect_tx(self, bytes: int | ir.Value): + dialect.ArriveExpectTxOp(barrier=self.as_barrier_memref(), expect_tx=bytes) + + def get_ptr(self): + return self.barrier_ref.get_ptr() + + def as_barrier_memref(self) -> ir.Value: + num_barriers = self.barrier_ref.num_barriers + shape = () if num_barriers == 1 else (num_barriers,) + memref_type = ir.MemRefType.get(shape, ir.Type.parse("!mosaic_gpu.barrier")) + return builtin.unrealized_conversion_cast([memref_type], [self.get_ptr()]) + @classmethod - def from_dialect_barrier_memref(cls, barrier: ir.Value): - """Creates a BarrierRef from a memref of a dialect barrier.""" + def from_barrier_memref(cls, barrier: ir.Value): + """Creates a DialectBarrierRef from a memref of a dialect barrier.""" memref_type = ir.MemRefType(barrier.type) if memref_type.rank > 1 or memref_type.element_type != ir.Type.parse( "!mosaic_gpu.barrier" @@ -806,13 +1185,15 @@ def from_dialect_barrier_memref(cls, barrier: ir.Value): f"!mosaic_gpu.barrier, but got {barrier.type}" ) + ptr_type = ir.Type.parse(f"!llvm.ptr<{WORKGROUP_NVPTX_ADDRESS_SPACE}>") + addr = builtin.unrealized_conversion_cast([ptr_type], [barrier]) return cls( - base_address=memref_ptr( - barrier, memory_space=WORKGROUP_NVPTX_ADDRESS_SPACE - ), - offset=c(0, ir.IntegerType.get_signless(64)), - phases=None, - num_barriers=(1 if memref_type.rank == 0 else memref_type.shape[0]), + barrier_ref=BarrierRef( + base_address=addr, + offset=c(0, ir.IntegerType.get_signless(64)), + phases=None, + num_barriers=(1 if memref_type.rank == 0 else memref_type.shape[0]), + ) ) @@ -823,8 +1204,8 @@ class CollectiveBarrierRef: @staticmethod def initialize( - address: ir.Value, - num_barriers: int, + barrier_memref: ir.Value, + arrival_count: int, dims: Sequence[gpu.Dimension | Sequence[gpu.Dimension]], cluster_shape: tuple[int, int, int], ) -> "CollectiveBarrierRef": @@ -838,8 +1219,8 @@ def initialize( else math.prod(cluster_shape[dd] for dd in d) for d in dims ] - arrival_count = sum(dims_shape) - len(dims) + 1 - if arrival_count == 1: + cluster_arrival_count = sum(dims_shape) - len(dims) + 1 + if cluster_arrival_count == 1: assert all(s == 1 for s in dims_shape) cluster_mask = None else: @@ -852,7 +1233,9 @@ def initialize( cluster_mask = arith.ori( cluster_mask, cluster_collective_mask(cluster_shape, d) ) - barrier = BarrierRef.initialize(address, num_barriers, arrival_count=arrival_count) + barrier = BarrierRef.initialize( + barrier_memref, arrival_count=arrival_count * cluster_arrival_count + ) return CollectiveBarrierRef(barrier, cluster_mask) def __iter__(self): @@ -862,15 +1245,23 @@ def __iter__(self): def __getitem__(self, offset): return CollectiveBarrierRef(self.barrier[offset], self.cluster_mask) - def arrive(self): + def arrive(self, orders_tensor_core: bool = False): """Arrives on a barrier in all blocks that share at least one of the coordinates along the collective dimensions. Note that unlike in arrive, each warpgroup arrives once. """ + if orders_tensor_core: + llvm.inline_asm( + ir.Type.parse("!llvm.void"), + [], + "tcgen05.fence::before_thread_sync;", + "", + has_side_effects=True, + ) if self.barrier.num_barriers != 1: raise ValueError("Can only arrive on a single barrier") if self.cluster_mask is None: - with single_thread(per_block=False): + with single_thread(scope=ThreadSubset.WARPGROUP): self.barrier.arrive() return i32 = ir.IntegerType.get_signless(32) @@ -909,6 +1300,124 @@ def wait_parity(self, *args, **kwargs): self.barrier.wait_parity(*args, **kwargs) +@dataclasses.dataclass(frozen=True) +class SemaphoreRef: + ptr: ir.Value + + def signal( + self, + value: ir.Value | int, + predicate: ir.Value | None = None, + relaxed: bool = False, + ): + i32 = ir.IntegerType.get_signless(32) + if not isinstance(value, ir.Value): + value = c(value, i32) + elif value.type != i32: + raise ValueError(f"Expected a i32 value, got {value.type}") + if predicate is None: + predicate = single_thread_predicate(ThreadSubset.WARPGROUP) + semantics = "relaxed" if relaxed else "release" + llvm.inline_asm( + i32, + [self.ptr, value, predicate], + f"@$3 atom.add.{semantics}.sys.global.u32 $0, [$1], $2;", + "=r,l,r,b", + has_side_effects=True, + ) + + @staticmethod + def signal_multimem(ptr, value, predicate: ir.Value | None = None): + i32 = ir.IntegerType.get_signless(32) + if not isinstance(value, ir.Value): + value = c(value, i32) + elif value.type != i32: + raise ValueError(f"Expected a i32 value, got {value.type}") + if predicate is None: + predicate = single_thread_predicate(ThreadSubset.WARPGROUP) + llvm.inline_asm( + ir.Type.parse("!llvm.void"), + [ptr, value, predicate], + """{ + @$2 multimem.red.release.sys.global.add.u32 [$0], $1; + fence.proxy.alias; + } + """, + "l,r,b", + has_side_effects=True, + ) + + def wait( + self, + value: ir.Value | int = 1, + *, + decrement: bool = True, + scope: ThreadSubset = ThreadSubset.WARPGROUP, + ): + i32 = ir.IntegerType.get_signless(32) + if not isinstance(value, ir.Value): + value = c(value, i32) + elif value.type != i32: + raise ValueError(f"Expected a i32 value, got {value.type}") + + with single_thread(scope=scope): + # Create the while loop for busy waiting + while_op = scf.WhileOp([i32], [value]) + before_block = while_op.before.blocks.append(i32) + with ir.InsertionPoint.at_block_begin(before_block): + [expected_in_memory] = before_block.arguments + if decrement: + new_val = arith.subi(expected_in_memory, value) + in_memory = llvm.inline_asm( + i32, + [self.ptr, expected_in_memory, new_val], + "atom.acquire.sys.global.cas.b32 $0, [$1], $2, $3;", + "=r,l,r,r", + has_side_effects=True, + ) + ne_pred = arith.CmpIPredicate.ne + comparison = arith.cmpi(ne_pred, in_memory, expected_in_memory) + new_expected_in_memory = arith.maxui(in_memory, value) + else: + in_memory = llvm.inline_asm( + i32, + [self.ptr], + "ld.relaxed.sys.global.b32 $0, [$1];", + "=r,l", + has_side_effects=True, + ) + lt_pred = arith.CmpIPredicate.ult + comparison = arith.cmpi(lt_pred, in_memory, value) + new_expected_in_memory = expected_in_memory + scf.condition(comparison, [new_expected_in_memory]) + after_block = while_op.after.blocks.append(i32) + with ir.InsertionPoint.at_block_begin(after_block): + scf.yield_(after_block.arguments) + llvm.inline_asm( + ir.Type.parse("!llvm.void"), + [], + "fence.acquire.sys;", + "", + has_side_effects=True, + ) + if scope == ThreadSubset.WARPGROUP: + warpgroup_barrier() + elif scope == ThreadSubset.WARP: + warp_barrier() + else: + raise ValueError(f"Unsupported scope: {scope}") + + +def fence_release_sys(): + llvm.inline_asm( + ir.Type.parse("!llvm.void"), + [], + "fence.release.sys;", + "", + has_side_effects=True, + ) + + class Partition: source_bounds: tuple[int, ...] target_bounds: tuple[int, ...] @@ -936,6 +1445,7 @@ def __init__( if num_chunks is not None: self.source_bounds = num_chunks else: + assert chunk_size is not None if len(chunk_size) != len(self.target_bounds): raise ValueError source_bounds = [] @@ -964,8 +1474,10 @@ def num_chunks(self) -> tuple[int, ...]: @property def target_block_shape(self): - return tuple(tb if p is None else tb // self.source_bounds[p] - for tb, p in zip(self.target_bounds, self.partition)) + return tuple( + tb if p is None else tb // self.source_bounds[p] + for tb, p in zip(self.target_bounds, self.partition) + ) def get_base(self, *source_coords: ir.Value | int) -> list[ir.Value]: coords = [] @@ -1003,6 +1515,7 @@ def __init__( if num_chunks is not None: self.partition = Partition(num_chunks=(num_chunks,), **common_kwargs) else: + assert chunk_size is not None self.partition = Partition(chunk_size=(chunk_size,), **common_kwargs) @property @@ -1029,7 +1542,11 @@ def refine( def tile_shape(shape, tiling): if len(tiling) > len(shape): - raise ValueError + raise ValueError( + "Expected tiling to be at most rank of shape. Got tiling:" + f" {tiling} (rank: {len(tiling)}) and shape {shape} (rank:" + f" {len(shape)})." + ) if not tiling: return shape tiling_rank = len(tiling) @@ -1051,7 +1568,9 @@ def warp_tree_reduce(value, op, group_size): result = value iters = np.log2(group_size) if not iters.is_integer(): - raise ValueError(f"Warp reduction group size should be a power of 2 (got {group_size})") + raise ValueError( + f"Warp reduction group size should be a power of 2 (got {group_size})" + ) iters = int(iters) for i in range(iters): other_result = nvvm.shfl_sync( @@ -1092,7 +1611,10 @@ def memref_ptr(memref_arg, memory_space=None): assert elem_bitwidth.bit_count() == 1 packing = 8 // elem_bitwidth if static_offset % packing != 0: - raise ValueError + raise ValueError( + f"{memref_ty} {static_offset=} is not divisible by" + f" {packing=}`" + ) offset_bytes = c(static_offset // packing, i64) else: offset_bits = llvm.mul( @@ -1140,7 +1662,8 @@ def cluster_collective_mask( if cluster_shape[cluster_dim] != 1: # Constant-fold multiply by 0. dim_idx = arith.index_castui(i32, gpu.cluster_block_id(cluster_dim)) mask_shift = arith.addi( - mask_shift, arith.muli(dim_idx, c(stride, i32)), + mask_shift, + arith.muli(dim_idx, c(stride, i32)), ) mask_unshifted = 0 collective_strides = [cluster_strides[d] for d in collective] @@ -1171,7 +1694,14 @@ def getelementptr( ) -> ir.Value: static_indices = [i if isinstance(i, int) else DYNAMIC32 for i in indices] dyn_indices = [i for i in indices if not isinstance(i, int)] - return llvm.getelementptr(ptr.type, ptr, dyn_indices, static_indices, dtype) + return llvm.getelementptr( + ptr.type, + ptr, + dyn_indices, + static_indices, + dtype, + llvm.GEPNoWrapFlags.none, + ) def dyn_dot(x, y): @@ -1181,16 +1711,87 @@ def dyn_dot(x, y): def shfl_bfly(x: ir.Value, distance: int | ir.Value): i32 = ir.IntegerType.get_signless(32) + index = ir.IndexType.get() if isinstance(distance, int): distance = c(distance, i32) if (result_type := x.type) != i32: + if (x_bitwidth := bitwidth(x.type)) < 32: # Pad to 32-bits if necessary. + assert 32 % x_bitwidth == 0 + x = bitcast(x, ir.IntegerType.get_signless(x_bitwidth)) + empty32 = llvm.mlir_undef(ir.VectorType.get((32 // x_bitwidth,), x.type)) + x = vector.insert( + x, + empty32, + dynamic_position=[], + static_position=ir.DenseI64ArrayAttr.get([0]), + ) + elif x_bitwidth > 32: + assert x_bitwidth % 32 == 0 + num_words = x_bitwidth // 32 + xs_vec = bitcast(x, ir.VectorType.get((num_words,), i32)) + y = llvm.mlir_undef(xs_vec.type) + for i in range(num_words): + x_elem = vector.extract( + xs_vec, + dynamic_position=[], + static_position=ir.DenseI64ArrayAttr.get([i]), + ) + y_elem = shfl_bfly(x_elem, distance) + y = vector.insert( + y_elem, + y, + dynamic_position=[], + static_position=ir.DenseI64ArrayAttr.get([i]), + ) + return bitcast(y, result_type) x = bitcast(x, i32) y = nvvm.shfl_sync( - i32, c(0xFFFFFFFF, i32), x, distance, c(0x1F, i32), nvvm.ShflKind.bfly, + i32, + c(0xFFFFFFFF, i32), + x, + distance, + c(0x1F, i32), + nvvm.ShflKind.bfly, ) + if (x_bitwidth := bitwidth(result_type)) < 32: + bits_ty = ir.IntegerType.get_signless(x_bitwidth) + y_vec = bitcast(y, ir.VectorType.get((32 // x_bitwidth,), bits_ty)) + y = vector.extract( + y_vec, + dynamic_position=[], + static_position=ir.DenseI64ArrayAttr.get([0]), + ) return bitcast(y, result_type) +def redux(x: ir.Value, mask: ir.Value, kind: nvvm.ReduxKind): + i32 = ir.IntegerType.get_signless(32) + if isinstance(vec_ty := x.type, ir.VectorType): + if bitwidth(vec_ty.element_type) != 32: + raise ValueError("Only 32-bit types supported") + [vec_len] = vec_ty.shape + result = llvm.mlir_undef(x.type) + for i in range(vec_len): + xi = llvm.extractelement(x, arith.constant(i32, i)) + yi = redux(xi, mask, kind) + result = llvm.insertelement(result, yi, arith.constant(i32, i)) + return result + if bitwidth(x.type) != 32: + raise ValueError("Only 32-bit scalar types supported") + if isinstance(x.type, ir.IntegerType): + pass + elif isinstance(x.type, ir.F32Type): + if get_arch().major != 10: + raise ValueError("F32 redux only supported on Blackwell GPUs") + else: + raise NotImplementedError(x.type) + assert mask.type == i32 + extra_kwargs = {} + if kind == nvvm.ReduxKind.FMAX or kind == nvvm.ReduxKind.FMIN: + extra_kwargs = dict(nan=True) + return nvvm.redux_sync(x.type, x, kind, mask, **extra_kwargs) + + def prmt(high: ir.Value, low: ir.Value, permutation: ir.Value): i32 = ir.IntegerType.get_signless(32) if (result_type := high.type) != low.type: @@ -1210,25 +1811,41 @@ def prmt(high: ir.Value, low: ir.Value, permutation: ir.Value): def bitcast(x: ir.Value, new_type: ir.Type): if x.type == new_type: return x - if ir.VectorType.isinstance(x.type) and ir.IntegerType.isinstance(new_type): + if (x_bw := bitwidth(x.type)) != (new_bw := bitwidth(new_type)): + raise ValueError( + f"Can't bitcast {x.type} (of bitwidth {x_bw}) to {new_type} (of" + f" bitwidth {new_bw})" + ) + if isinstance(x.type, ir.VectorType) and isinstance(new_type, ir.IntegerType): new_type = ir.IntegerType(new_type) x_ty = ir.VectorType(x.type) assert new_type.width == bitwidth(x_ty.element_type) * math.prod(x_ty.shape) - i0 = arith.ConstantOp.create_index(0) - return vector.extractelement( - vector.bitcast(ir.VectorType.get((1,), new_type), x), position=i0 + return vector.extract( + vector.bitcast(ir.VectorType.get((1,), new_type), x), + dynamic_position=[], + static_position=ir.DenseI64ArrayAttr.get([0]), ) - if ir.IntegerType.isinstance(x.type) and ir.VectorType.isinstance(new_type): + if isinstance(x.type, ir.IntegerType) and isinstance(new_type, ir.VectorType): new_type = ir.VectorType(new_type) x_ty = ir.IntegerType(x.type) - assert x_ty.width == bitwidth(new_type.element_type) * math.prod(new_type.shape) - return vector.bitcast(new_type, vector.splat(ir.VectorType.get((1,), x_ty), x)) - if ir.VectorType.isinstance(x.type) and ir.VectorType.isinstance(new_type): + assert x_ty.width == bitwidth(new_type.element_type) * math.prod( + new_type.shape + ) + return vector.bitcast( + new_type, vector.broadcast(ir.VectorType.get((1,), x_ty), x) + ) + if isinstance(x.type, ir.VectorType) and isinstance(new_type, ir.VectorType): x_ty = ir.VectorType(x.type) new_ty = ir.VectorType(new_type) if bitwidth(x_ty) != bitwidth(new_ty): raise ValueError(f"Can't bitcast {x.type} to {new_type}") return vector.bitcast(new_type, x) + if isinstance(x.type, ir.IntegerType) and isinstance(new_type, ir.FloatType): + return arith.bitcast(new_type, x) + if isinstance(x.type, ir.FloatType) and isinstance(new_type, ir.IntegerType): + return arith.bitcast(new_type, x) + if isinstance(x.type, ir.FloatType) and isinstance(new_type, ir.FloatType): + return arith.bitcast(new_type, x) raise ValueError(f"Can't bitcast {x.type} to {new_type}") @@ -1239,34 +1856,247 @@ def ceil_div(x: int, y: int): def vector_slice(v: ir.Value, s: slice): v_ty = ir.VectorType(v.type) if len(v_ty.shape) != 1: - raise NotImplementedError(v_ty) + raise NotImplementedError(f"Only 1D vectors are supported {v_ty}") [v_len] = v_ty.shape slice_length = len(range(v_len)[s]) return vector.extract_strided_slice( ir.VectorType.get((slice_length,), v_ty.element_type), - v, [s.start or 0], [slice_length], [1], + v, + [s.start or 0], + [slice_length], + [1], ) def vector_concat(vectors: Sequence[ir.Value]) -> ir.Value: - index = ir.IndexType.get() if not vectors: raise ValueError("Cannot concatenate an empty list of vectors") vty = vectors[0].type - if not ir.VectorType.isinstance(vty): + if not isinstance(vty, ir.VectorType): raise ValueError("Cannot concatenate non-vector values") + vty = ir.VectorType(vty) if vty.rank != 1: raise NotImplementedError("Only 1D vectors are supported") for v in vectors: - if v.type != vty: - raise ValueError("Cannot concatenate vectors of different types") - result = llvm.mlir_undef( - ir.VectorType.get((vty.shape[0] * len(vectors),), vty.element_type) + if v.type.element_type != vty.element_type: + raise ValueError("Cannot concatenate vectors of different element types") + if v.type.rank != 1: + raise ValueError("Can only concatenate 1D vectors") + return _vector_concat_rec(vectors) + + +def _vector_concat_rec(vectors: Sequence[ir.Value]) -> ir.Value: + match vectors: + case [v]: + return v + case [v, w]: + [v_len] = ir.VectorType(v.type).shape + [w_len] = ir.VectorType(w.type).shape + mask = ir.DenseI64ArrayAttr.get(list(range(v_len + w_len))) + return vector.shuffle(*vectors, mask=mask) + case _: + assert vectors + l = _vector_concat_rec(vectors[: len(vectors) // 2]) + r = _vector_concat_rec(vectors[len(vectors) // 2 :]) + return _vector_concat_rec([l, r]) + + +def is_known_divisible(value, divisor, max_depth=10) -> bool: + """Returns True if the value is statically known to be divisible by the divisor.""" + if divisor == 1: + return True + if max_depth < 0 or not isinstance(value.owner, ir.OpView): + return False + + new_depth = max_depth - 1 + def_op = value.owner.opview + + match def_op: + case arith.IndexCastOp(): + return is_known_divisible(value.owner.operands[0], divisor, max_depth - 1) + case arith.ConstantOp(): + return ir.IntegerAttr(def_op.value).value % divisor == 0 + case arith.MulIOp(): + # Only cover the case where one operand is divisible. It's still possible + # that the final product is divisible, but we don't check that here. + return is_known_divisible( + value.owner.operands[0], divisor, new_depth + ) or is_known_divisible(value.owner.operands[1], divisor, new_depth) + case arith.SelectOp(): + return is_known_divisible( + value.owner.operands[1], divisor, new_depth + ) and is_known_divisible(value.owner.operands[2], divisor, new_depth) + case arith.MaxSIOp() | arith.MinSIOp() | arith.MaxUIOp() | arith.MinUIOp(): + return is_known_divisible( + value.owner.operands[0], divisor, new_depth + ) and is_known_divisible(value.owner.operands[1], divisor, new_depth) + case arith.AddIOp() | arith.SubIOp(): + # Only cover the common case where both operads are divisible. + return is_known_divisible( + value.owner.operands[0], divisor, new_depth + ) and is_known_divisible(value.owner.operands[1], divisor, new_depth) + case arith.AndIOp(): + # Only cover the specific case where the divisor is a power of two. + return divisor.bit_count() == 1 and ( + is_known_divisible(value.owner.operands[0], divisor, new_depth) + or is_known_divisible(value.owner.operands[1], divisor, new_depth) + ) + + return False + + +def smem() -> ir.Attribute: + """Returns the attribute for the SMEM memory space.""" + return ir.Attribute.parse("#gpu.address_space") + + +def tmem() -> ir.Attribute: + """Returns the attribute for the TMEM memory space.""" + return ir.Attribute.parse("#mosaic_gpu.tmem") + + +def is_smem_ref(ref: ir.Value | ir.Type) -> bool: + """Returns true if the input mem ref or memref type points to SMEM. + + If the input is not at all of a memref type, raises a ValueError. + """ + if isinstance(ref, ir.Value): + ref = ref.type + if not isinstance(ref, ir.MemRefType): + raise ValueError(f"Expected a memref type but got {ref}") + ref = ir.MemRefType(ref) + return ref.memory_space is not None and ref.memory_space == smem() + + +def is_tmem_ref(ref: ir.Value | ir.Type) -> bool: + """Returns true if the input mem ref or memref type points to TMEM. + + If the input is not at all of a memref type, raises a ValueError. + """ + if isinstance(ref, ir.Value): + ref = ref.type + if not isinstance(ref, ir.MemRefType): + raise ValueError(f"Expected a memref type but got {ref}") + ref = ir.MemRefType(ref) + return ref.memory_space is not None and ref.memory_space == tmem() + + +def try_cluster_cancel( + result_ref, + barrier: BarrierRef, + predicate: ir.Value | None = None, +): + """Atomically cancels a pending cluster launch. + + The response is stored in a opaque 128-bit value containing the CTA id of the + first CTA in the canceled cluster. + """ + if predicate is None: + predicate = single_thread_predicate(ThreadSubset.BLOCK) + + pred_ptx = "@$2" + pred_constraint = ",b" + + addr = memref_ptr(result_ref, memory_space=3) + llvm.inline_asm( + ir.Type.parse("!llvm.void"), + [addr, barrier.get_ptr()] + + ([predicate] if predicate is not None else []), + f"{pred_ptx} clusterlaunchcontrol.try_cancel.async.shared::cta.mbarrier::complete_tx::bytes.multicast::cluster::all.b128" + " [$0], [$1];", + "r,r" + pred_constraint, + has_side_effects=True, ) - offset = 0 - for v in vectors: - for i in range(vty.shape[0]): - elem = vector.extractelement(v, position=c(i, index)) - result = vector.insertelement(elem, result, position=c(offset + i, index)) - offset += vty.shape[0] - return result + + +def query_cluster_cancel( + result_ref, +) -> tuple[ir.Value, ir.Value, ir.Value, ir.Value]: + """Decodes the response of `try_cluster_cancel`. + + It checks if the cancellation was successful, and if yes, it also extracts + the CTA ID of the first CTA in the canceled cluster. + """ + + i32 = ir.IntegerType.get_signless(32) + i1 = ir.IntegerType.get_signless(1) + struct_ty = llvm.StructType.get_literal([i32, i32, i32, i1]) + + addr = memref_ptr(result_ref, memory_space=3) + desc = llvm.inline_asm( + struct_ty, + [addr], + """ + { + .reg .b128 handle; + ld.shared.b128 handle, [$4]; + clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 $3, handle; + @$3 clusterlaunchcontrol.query_cancel.get_first_ctaid.v4.b32.b128 {$0, $1, $2, _}, handle; + }""", + "=r,=r,=r,=b,r", + ) + + cta_ids = [llvm.extractvalue(i32, desc, [idx]) for idx in [0, 1, 2]] + cancelled_launch = llvm.extractvalue(i1, desc, [3]) + + return (*cta_ids, cancelled_launch) + + +def nanosleep(nanos: ir.Value): + """Sleeps the current thread for the given number of nanoseconds.""" + llvm.inline_asm( + ir.Type.parse("!llvm.void"), + [nanos], + "nanosleep.u32 $0;", + "r", + has_side_effects=True, + ) + + +def nvvm_mbarrier_arrive_expect_tx(barrier: ir.Value, expect_tx: ir.Value, predicate: ir.Value | None = None): + try: + return nvvm.mbarrier_arrive_expect_tx(None, barrier, expect_tx, predicate=predicate) # type: ignore + except TypeError: + return nvvm.mbarrier_arrive_expect_tx(barrier, expect_tx, predicate=predicate) # pytype: disable=missing-parameter + + +def elements_to_bytes(offset: ir.Value, element_bitwidth: int) -> ir.Value: + """Convert an element-based linear offset to a byte-based offset.""" + index_ty = offset.type + + if element_bitwidth > 8: + return arith.muli(offset, c(element_bitwidth // 8, index_ty)) + elif element_bitwidth < 8: + return arith.divsi(offset, c(8 // element_bitwidth, index_ty)) + else: + return offset + + +def get_cluster_ptr(ptr: ir.Value, cluster_block: ir.Value): + i32 = ir.IntegerType.get_signless(32) + assert cluster_block.type == i32, cluster_block.type + assert ptr.type == ir.Type.parse("!llvm.ptr<3>"), ptr.type + mapped_smem_ptr = nvvm.mapa(ir.Type.parse("!llvm.ptr<7>"), ptr, cluster_block) + return llvm.addrspacecast(ir.Type.parse("!llvm.ptr"), mapped_smem_ptr) + + +@dataclasses.dataclass(frozen=True) +class Arch: + major: int + minor: int + + +def get_arch() -> Arch: + ip = ir.InsertionPoint.current + if ip is None: + raise ValueError("Cannot retrieve the architecture without an insertion point") + block = ip.block + op = block.owner + while op is not None: + if op.name == "builtin.module": + return Arch( + op.attributes["mosaic_gpu.arch_major"].value, + op.attributes["mosaic_gpu.arch_minor"].value, + ) + op = op.parent + raise ValueError("Cannot retrieve the architecture: no module found") diff --git a/jax/experimental/mosaic/gpu/wgmma.py b/jax/experimental/mosaic/gpu/wgmma.py index 8baa16d8a7e9..33e8b63e3b0d 100644 --- a/jax/experimental/mosaic/gpu/wgmma.py +++ b/jax/experimental/mosaic/gpu/wgmma.py @@ -14,7 +14,6 @@ # ============================================================================== import dataclasses -import functools import itertools import math @@ -30,7 +29,6 @@ from . import mma_utils from . import utils -# mypy: ignore-errors c = utils.c bytewidth = utils.bytewidth @@ -45,50 +43,75 @@ class WGMMAAccumulator: as a WGMMA accumulator. In particular, when created from a FragmentedArray, the necessary synchronization is inserted at construction. """ - value: fa.FragmentedArray - - def __init__(self, *, _value: fa.FragmentedArray, _sync: bool = True): - if _value.layout != fa.WGMMA_LAYOUT: - raise ValueError("Only WGMMA layouts supported in WGMMAAccumulator") - self.value = _value + _original_layout: fa.FragmentedLayout + _value: fa.FragmentedArray + + def __init__( + self, + *, + _value: fa.FragmentedArray, + _original_layout: fa.FragmentedLayout, + _sync: bool = True, + ): + self._original_layout = _original_layout + self._value = _value if _sync: - self.value = wgmma_fence(_value) + self._value = wgmma_fence(_value) + + @property + def value(self) -> fa.FragmentedArray: + return self._value.to_layout(self._original_layout) @classmethod def zero(cls, m, n, dtype=None, *, is_signed: bool | None = None): if m % 64 or n % 8: - raise ValueError - if is_signed is False: + raise ValueError("WGMMA requires m and n to be multiples of 64 and 8, " + f"got {m} and {n}") + if is_signed is False: # pylint: disable=g-bool-id-comparison raise TypeError("PTX does not support unsigned WGMMA accumulators") f32 = ir.F32Type.get() if dtype is None: dtype = f32 - zero = arith.constant(dtype, ir.FloatAttr.get(dtype, 0.0)) - return cls( - _value=fa.FragmentedArray.splat( + if isinstance(dtype, ir.IntegerType): + zero = arith.constant(dtype, ir.IntegerAttr.get(dtype, 0)) + else: + zero = arith.constant(dtype, ir.FloatAttr.get(dtype, 0.0)) + return cls.from_registers( + fa.FragmentedArray.splat( zero, (m, n), fa.WGMMA_LAYOUT, is_signed=is_signed ) ) @classmethod def from_registers(cls, registers): - return cls(_value=registers) + original_layout = registers.layout + if registers.layout != fa.WGMMA_LAYOUT and registers.layout != fa.WGMMA_LAYOUT_ACC_32BIT: + raise ValueError("Only WGMMA layouts supported in WGMMAAccumulator") + if utils.bitwidth(registers.mlir_dtype) == 32: + registers = registers.to_layout(fa.WGMMA_LAYOUT_ACC_32BIT) + return cls(_value=registers, _original_layout=original_layout) def tree_flatten(self): - return (self.value,), () + return (self._value,), (self._original_layout,) @classmethod def tree_unflatten(cls, aux, value): - del aux - return cls(_value=value[0], _sync=False) + return cls(_value=value[0], _original_layout=aux[0], _sync=False) def _supported_wgmma_types(dtype, abtype) -> bool: - input_types_are = lambda ty: ty.isinstance(abtype) - if ir.F32Type.isinstance(dtype): - return any(input_types_are(ty) for ty in (ir.FloatTF32Type, ir.BF16Type, ir.F16Type)) - elif ir.F16Type.isinstance(dtype): - return input_types_are(ir.F16Type) + input_types_are = lambda ty: isinstance(abtype, ty) + f16_acc_types = (ir.F16Type, ir.Float8E5M2Type, ir.Float8E4M3FNType) + if isinstance(dtype, ir.F32Type): + return any(input_types_are(ty) for ty in (ir.FloatTF32Type, ir.BF16Type, *f16_acc_types)) + elif isinstance(dtype, ir.F16Type): + return any(input_types_are(ty) for ty in f16_acc_types) + elif ( + isinstance(dtype, ir.IntegerType) + and dtype.width == 32 + and dtype.is_signless + ): + return input_types_are(ir.IntegerType) else: return False @@ -107,13 +130,17 @@ def wgmma_m64( ): out_ty = ir.VectorType(acc.flat[0].type).element_type if not _supported_wgmma_types(out_ty, element_type): - raise ValueError(f"Usupported wgmma types {(out_ty, element_type)=}") + raise ValueError(f"Unsupported wgmma types {(out_ty, element_type)=}") if n % 8: raise ValueError + bf16 = ir.BF16Type.get() + f16 = ir.F16Type.get() + i8 = ir.IntegerType.get_signless(8) i32 = ir.IntegerType.get_signless(32) i64 = ir.IntegerType.get_signless(64) - index = ir.IndexType.get() + f8e5m2 = ir.Float8E5M2Type.get() + f8e4m3fn = ir.Float8E4M3FNType.get() if b_k_stride % 16: raise ValueError # Only 16-bit types support transposes @@ -121,10 +148,14 @@ def wgmma_m64( if not supports_transpose and (a_transpose or b_transpose): raise ValueError("Only f16 WGMMA supports transposes") if a_in_regs := isinstance(a, fa.FragmentedArray): - if a.mlir_dtype != ir.F16Type.get() and a.mlir_dtype != ir.BF16Type.get(): + if a.mlir_dtype not in {bf16, f16, i8, f8e5m2, f8e4m3fn}: raise ValueError(f"Unsupported A register array dtype: {a.mlir_dtype}") # Column count must be equal to swizzle // bytewidth. - if a.layout != fa.WGMMA_LAYOUT or a.shape != (64, swizzle // 2): + elt_bytewidth = utils.bytewidth(element_type) + swizzle_elems = swizzle // elt_bytewidth + if a.shape != (64, swizzle_elems): + raise ValueError("Unsupported A register array shape") + if a.layout not in {fa.WGMMA_LAYOUT, fa.WGMMA_LAYOUT_8BIT}: raise ValueError("Unsupported A register array layout") if a_k_stride is not None or a_transpose is not None: raise ValueError("Unsupported WGMMA features with A in registers") @@ -134,31 +165,35 @@ def wgmma_m64( if a_transpose is None: raise ValueError - if ir.F32Type.isinstance(out_ty): + if isinstance(out_ty, ir.F32Type) or out_ty == i32: num_acc_regs = n // 2 - out_ty_field = out_ty - acc_regs = [ # pylint: disable=g-complex-comprehension - vector.extractelement(reg, position=c(pos, index)) - for reg in acc.flat - for pos in range(2) - ] - to_acc_vec_regs = functools.partial(_as_fragmented_reg_ndarray, dtype=out_ty, shape=acc.shape) - acc_constraint = "f" - elif ir.F16Type.isinstance(out_ty): + out_ty_field = ir.VectorType.get((1,), out_ty) + acc_regs = list(acc.flat) + assert acc_regs[0].type == ir.VectorType.get((1,), out_ty) + to_acc_vec_regs = lambda regs: np.array(regs).reshape(acc.shape) + acc_constraint = "r" if isinstance(out_ty, ir.IntegerType) else "f" + elif isinstance(out_ty, ir.F16Type): num_acc_regs = n // 4 out_ty_field = i32 acc_regs = [_as_i32_reg(reg) for reg in acc.flat] vec_ty = ir.VectorType(acc.flat[0].type) - to_acc_vec_regs = lambda regs : np.array([_unpack_i32(vec_ty, reg) for reg in regs]).reshape(acc.shape) + to_acc_vec_regs = lambda regs: np.array([_unpack_i32(vec_ty, reg) for reg in regs]).reshape(acc.shape) acc_constraint = "r" else: - raise ValueError(f"WGMMA instruciton only supports f32 and f16 out (got {out_ty})") + raise ValueError( + f"WGMMA instruction only supports f32, f16 and s32 out (got {out_ty})") - num_imm_regs = 4 if supports_transpose else 2 + if supports_transpose: + num_imm_regs = 4 + elif out_ty == i32: + num_imm_regs = 0 + else: + num_imm_regs = 2 if a_in_regs: - a_reg_constraints = ["r"] * 4 # 4x f16x2 registers - num_imm_regs -= 1 # transpose not supported for a in registers + a_reg_constraints = ["r"] * 4 # 4x (b)f16x2 or s8x4 registers + if supports_transpose: + num_imm_regs -= 1 # transpose not supported for a in registers else: a_reg_constraints = ["l"] # descriptor # Reference for i/o aliasing: https://gcc.gnu.org/onlinedocs/gcc/Extended-Asm.html @@ -171,7 +206,6 @@ def wgmma_m64( + ["n"] * (1 + num_imm_regs) # literal constants ) reg_constraints = ",".join(reg_constraints_list) - reg_count = itertools.count() def take_regs(n): @@ -185,13 +219,28 @@ def take_regs(n): else: a_regs, = take_regs(1) b_desc_reg, use_out_reg = take_regs(2) - imm_regs = ", ".join(take_regs(num_imm_regs)) # Immediate regs (scale, ...). + # Immediate regs (scale, ...). + imm_regs = "".join(f", {r}" for r in take_regs(num_imm_regs)) assert next(reg_count) == len(reg_constraints_list) - el_ty = element_type k_instr = 32 // bytewidth(element_type) + el_ty = str(element_type) + if isinstance(element_type, ir.Float8E5M2Type): + el_ty = "e5m2" + elif isinstance(element_type, ir.Float8E4M3FNType): + el_ty = "e4m3" + elif isinstance(element_type, ir.IntegerType): + # TODO(bchetioui): add u8 support in the future. Currently we always assume + # that 8-bit integers are s8, and we would need to change the signature of + # `wgmma` to indicate whether the input should be treated as signed or not. + el_ty = "s8" + + out_ty_str = str(out_ty) + if out_ty == i32: + out_ty_str = "s32" + wgmma_instr = ( - f"wgmma.mma_async.sync.aligned.m64n{n}k{k_instr}.{out_ty}.{el_ty}.{el_ty} " - f"{acc_reg_vector}, {a_regs}, {b_desc_reg}, p, {imm_regs};" + f"wgmma.mma_async.sync.aligned.m64n{n}k{k_instr}.{out_ty_str}.{el_ty}.{el_ty} " + f"{acc_reg_vector}, {a_regs}, {b_desc_reg}, p{imm_regs};" ) ptx = f"{{ .reg .pred p; setp.ne.b32 p, {use_out_reg}, 0; {wgmma_instr} }}\n" @@ -199,12 +248,21 @@ def lc(x): return llvm.ConstantOp(i32, ir.IntegerAttr.get(i32, x)).result use_out = scale_a = scale_b = lc(1) - imms = [use_out, scale_a, scale_b] + if out_ty == i32: + imms = [use_out] + else: + imms = [use_out, scale_a, scale_b] + if supports_transpose and a_transpose is not None: imms += [lc(int(a_transpose)), lc(int(b_transpose))] elif supports_transpose: imms += [lc(int(b_transpose))] - if acc.ndim != 10 or acc.shape[0] != 1 or math.prod(acc.shape[2:]) != 2: + + assert len(imms) == num_imm_regs + 1 # +1 for the use_out_reg in setp.ne.b32 + + expected_dim = 10 if utils.bitwidth(out_ty) == 32 else 9 + expected_regs_per_tile = 4 if utils.bitwidth(out_ty) == 32 else 2 + if acc.ndim != expected_dim or acc.shape[0] != 1 or math.prod(acc.shape[2:]) != expected_regs_per_tile: raise ValueError(acc.shape) acc_struct_type = ir.Type.parse( f"!llvm.struct<({','.join(str(out_ty_field) for _ in acc_regs)})>" @@ -212,10 +270,11 @@ def lc(x): for i in range((swizzle // bytewidth(element_type)) // k_instr): # Slice out the relevant part of A or advance the A descriptor. if a_in_regs: - a_slice = a[:, (i * 16) : ((i + 1) * 16)] + a_slice = a[:, (i * k_instr) : ((i + 1) * k_instr)] a_args = [_as_i32_reg(v) for v in a_slice.registers.flat] else: if i > 0: + assert a_k_stride is not None a = _llvm_add( a, llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, a_k_stride >> 4)), @@ -251,28 +310,43 @@ def wgmma( ): """Perform acc += a @ b using the WGMMA instruction. - The expected memref shapes are: - a: (m, k, 64, S) - b: (k, n, S, S) - where S = swizzle // bytewidth(element_type). + `a` may be passed in registers, or as a memref. `b` must be a memref. - The refs must be contiguous or be contiguous except for having their two minor - dimensions swapped. + The expected (logical) memref shapes are: + a: (m // tile_m, k // tile_k, tile_m, tile_k) + b: (k // tile_k, n // tile_n, tile_k, tile_n). + + While the shapes may be physically transposed, when considering the row-major + physical shape, the tile dimensions must be the two minor dimensions and must + have the shape (8, S) where S = swizzle // bytewidth(element_type). """ if swizzle == 16: raise NotImplementedError("No swizzle is not supported") # Step 1. Establish the shape and element type of the operation. - if not ir.MemRefType.isinstance(b.type): + if not isinstance(b.type, ir.MemRefType): raise ValueError(f"B must be a memref, got: {b.type}") + bf16 = ir.BF16Type.get() + f32 = ir.F32Type.get() + f16 = ir.F16Type.get() + i32 = ir.IntegerType.get_signless(32) + i8 = ir.IntegerType.get_signless(8) + f8e5m2 = ir.Float8E5M2Type.get() + f8e4m3fn = ir.Float8E4M3FNType.get() (k, n), element_type = mma_utils.tiled_memref_shape(b) if a_in_regs := isinstance(a, fa.FragmentedArray): m, k2 = a.shape element_type2 = a.mlir_dtype - if a.mlir_dtype != ir.F16Type.get() and a.mlir_dtype != ir.BF16Type.get(): + if element_type2 not in {f16, bf16, i8, f8e5m2, f8e4m3fn}: raise ValueError( - f"Only 16-bit dtypes supported for A in registers, got {a.mlir_dtype}" + "Only f16, bf16, i8, f8e5m2, f8e4m3fn are supported for A " + f"in registers, got {element_type2}" ) - elif ir.MemRefType.isinstance(a.type): + if element_type2 == i8 and swizzle == 32: + # TODO(bchetioui): relax this when ptxas is fixed. As of ptxas 12.8, + # optimizations eliminate MMA instructions, leading to only the first tile + # of the result being computed correctly. + raise NotImplementedError("swizzle=32 not supported for s8 lhs in registers") + elif isinstance(a.type, ir.MemRefType): (m, k2), element_type2 = mma_utils.tiled_memref_shape(a) else: raise ValueError(f"Unsupported A type: {type(a)}") @@ -286,23 +360,35 @@ def wgmma( "WGMMA requires A and B to have the same element type, got:" f" {element_type2} and {element_type}" ) - if acc.value.shape != (m, n): + if acc._value.shape != (m, n): raise ValueError( - f"Accumulator shape mismatch: expected {(m, n)}, got {acc.value.shape}" + f"Accumulator shape mismatch: expected {(m, n)}, got {acc._value.shape}" ) - f32 = ir.F32Type.get() if element_type == f32 or element_type == ir.BF16Type.get(): - if acc.value.mlir_dtype != f32: + if acc._value.mlir_dtype != f32: raise ValueError( f"WGMMA with element type {element_type} only supports accumulators" - f" of type f32, but got: {acc.value.mlir_dtype}" + f" of type f32, but got: {acc._value.mlir_dtype}" + ) + elif any( + isinstance(element_type, t) + for t in {ir.F16Type, ir.Float8E5M2Type, ir.Float8E4M3FNType} + ): + if acc._value.mlir_dtype != f16 and acc._value.mlir_dtype != f32: + raise ValueError( + f"WGMMA with element type {element_type} only supports accumulators " + f"of type f32 or f16, but got: {acc._value.mlir_dtype}" ) - elif element_type == ir.F16Type.get(): - if acc.value.mlir_dtype != element_type and acc.value.mlir_dtype != f32: + elif element_type == i8: + if a_in_regs and not a.is_signed: + raise NotImplementedError("WGMMA with lhs of type u8") + if acc._value.mlir_dtype != i32 or not acc._value.is_signed: raise ValueError( - "WGMMA with element type f16 only supports accumulators of type f32" - f" or f16, but got: {acc.value.mlir_dtype}" + f"WGMMA with element type {element_type} only supports accumulators " + f"of type s32, but got: {acc._value.mlir_dtype}" ) + else: + raise NotImplementedError(f"Unsupported element type: {element_type}") # Step 2. Decide on the instruction shapes we'll use. Note that with swizzles, # instructions must be issued in groups of the same width as the swizzle. @@ -338,6 +424,8 @@ def wgmma( group_size=(m_group_elems, k_group_elems), logical_k_major=False, ) + assert not a_k_instr_stride[0] # We'd need separate a/b swizzles. + a_k_instr_stride = a_k_instr_stride[1][0] a_instr_params = dict(a_transpose=a_fastest != mma_utils.Dim.K, a_k_stride=a_k_instr_stride) ( @@ -351,6 +439,8 @@ def wgmma( group_size=(k_group_elems, n_group_elems), logical_k_major=True, ) + assert not b_k_instr_stride[0] # We'd need separate a/b swizzles. + b_k_instr_stride = b_k_instr_stride[1][0] del b_n_group_stride # We only support one N group. # Step 4. Issue the instructions. @@ -358,7 +448,7 @@ def wgmma( a = wgmma_fence(a) # Make sure the registers are ready. i64 = ir.IntegerType.get_signless(64) - new_acc_regs = acc.value.registers.copy() + new_acc_regs = acc._value.registers.copy() for mi in range(m_groups): for ki in range(k_groups): if a_in_regs: @@ -367,6 +457,7 @@ def wgmma( ki * k_group_elems : (ki + 1) * k_group_elems, ] else: + assert a_m_group_stride is not None and a_k_group_stride is not None a_group_offset = mi * a_m_group_stride + ki * a_k_group_stride a_mk = _llvm_add( a_desc_base, c(mma_utils.encode_addr(a_group_offset), i64), @@ -388,14 +479,15 @@ def wgmma( return WGMMAAccumulator( _value=fa.FragmentedArray( _registers=new_acc_regs, - _layout=fa.WGMMA_LAYOUT, - _is_signed=acc.value.is_signed, + _layout=acc._value.layout, + _is_signed=acc._value.is_signed, ), + _original_layout=acc._original_layout, _sync=False, ) -def wgmma_fence(array: fa.FragmentedArray): +def wgmma_fence(array: fa.FragmentedArray) -> fa.FragmentedArray: """Fences the array construction from WGMMA instructions. LLVM treats in-register computation as pure and can move it after the fence, @@ -407,16 +499,6 @@ def wgmma_fence(array: fa.FragmentedArray): return array -def _as_fragmented_reg_ndarray(flat_regs, dtype: ir.Type, shape: tuple[int, ...]): - vec_regs = [] - for first, second in zip(flat_regs[::2], flat_regs[1::2]): - vec = llvm.mlir_undef(ir.VectorType.get((2,), dtype)) - vec = llvm.insertelement(vec, first, position=_lc(0)) - vec = llvm.insertelement(vec, second, position=_lc(1)) - vec_regs.append(vec) - return np.asarray(vec_regs, dtype=object).reshape(shape) - - def _as_i32_reg(v): i32 = ir.IntegerType.get_signless(32) return llvm.extractelement( @@ -436,5 +518,5 @@ def _llvm_add(x, y): def _unpack_i32(vec_ty, r): i32 = ir.IntegerType.get_signless(32) return vector.bitcast( - vec_ty, vector.splat(ir.VectorType.get((1,), i32), r) + vec_ty, vector.broadcast(ir.VectorType.get((1,), i32), r) ) diff --git a/jax/experimental/multihost_utils.py b/jax/experimental/multihost_utils.py index 2bde1fbeadc4..f3026502abc6 100644 --- a/jax/experimental/multihost_utils.py +++ b/jax/experimental/multihost_utils.py @@ -18,19 +18,21 @@ from functools import partial, lru_cache import zlib +import contextlib from typing import Any import jax import jax.numpy as jnp from jax.tree_util import tree_flatten, tree_unflatten from jax._src import core +from jax._src import dtypes from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src import array from jax._src import sharding_impls from jax._src.interpreters import pxla -from jax.interpreters import xla from jax._src import pjit as pjit_lib +from jax._src import prng from jax.sharding import PartitionSpec as P from jax._src import distributed from jax._src.util import safe_zip @@ -39,8 +41,8 @@ import numpy as np -def _psum(x: Any) -> Any: - return jax.tree.map(partial(jnp.sum, axis=0), x) +def _psum(xs: Any) -> Any: + return jax.tree.map(lambda x: jnp.sum(x, dtype=x.dtype, axis=0), xs) def broadcast_one_to_all(in_tree: Any, is_source: bool | None = None) -> Any: @@ -80,15 +82,10 @@ def post_jit(x): return jax.device_get(x.addressable_data(0)) in_tree = jax.tree.map(pre_jit, in_tree) - out_tree = jax.jit(_psum, out_shardings=jax.sharding.NamedSharding( - global_mesh, P()))(in_tree) - return jax.tree.map(post_jit, out_tree) - + with jax.set_mesh(global_mesh): + out_tree = jax.jit(_psum, out_shardings=P())(in_tree) -def sync_global_devices(name: str): - """Creates a barrier across all hosts/devices.""" - h = np.uint32(zlib.crc32(name.encode())) - assert_equal(h, f"sync_global_devices name mismatch ('{name}')") + return jax.tree.map(post_jit, out_tree) # Identity function is at the top level so that `process_allgather` doesn't @@ -99,8 +96,15 @@ def _identity_fn(x): def _handle_array_process_allgather(inp, tiled): if isinstance(inp, array.ArrayImpl) and not inp.is_fully_addressable: - reps = sharding_impls.GSPMDSharding.get_replicated( - inp.sharding._device_assignment) + if not tiled: + raise ValueError( + 'Gathering global non-fully-addressable arrays only supports' + ' tiled=True') + if isinstance(inp.sharding, sharding_impls.NamedSharding): + reps = inp.sharding.update(spec=P()) + else: + reps = sharding_impls.GSPMDSharding.get_replicated( + inp.sharding._device_assignment, memory_kind=inp.sharding.memory_kind) out = jax.jit(_identity_fn, out_shardings=reps)(inp) else: # All inputs here will be fully addressable. @@ -119,15 +123,15 @@ def _handle_array_process_allgather(inp, tiled): host_np_arr = np.expand_dims(host_np_arr, axis=0) aval = core.ShapedArray(host_np_arr.shape, host_np_arr.dtype) + pspec = sharding_impls.prepare_axis_resources(pspec, "pspec to array_mapping") global_aval = pxla.mesh_local_to_global( - global_mesh, pxla.get_array_mapping(pspec), aval) + global_mesh, sharding_impls.get_array_mapping(pspec), aval) bufs = [jax.device_put(host_np_arr, d) for d in jax.local_devices()] global_arr = array.make_array_from_single_device_arrays( global_aval.shape, s, bufs) - out = jax.jit(_identity_fn, - out_shardings=jax.NamedSharding(global_mesh, P()))(global_arr) - + with jax.set_mesh(global_mesh): + out = jax.jit(_identity_fn, out_shardings=P())(global_arr) return np.asarray(out.addressable_data(0)) @@ -155,13 +159,29 @@ def _pjit(inp): return jax.tree.map(_pjit, in_tree) +def sync_global_devices(name: str): + """Creates a barrier across all hosts/devices.""" + h = np.uint32(zlib.crc32(name.encode())) + assert_equal(h, f"sync_global_devices name mismatch ('{name}')") + + def assert_equal(in_tree, fail_message: str = ''): """Verifies that all the hosts have the same tree of values.""" - expected = broadcast_one_to_all(in_tree) - if not jax.tree_util.tree_all( - jax.tree_util.tree_map(lambda *x: np.all(np.equal(*x)), in_tree, expected)): + def concat_in_tree(x): + if isinstance(x, array.ArrayImpl) and not x.is_fully_addressable: + return np.asarray(x.addressable_data(0)) + else: + x = np.asarray(x) + if x.ndim == 0: + x = np.expand_dims(x, axis=0) + return np.concat([x] * jax.process_count()) + + out = process_allgather(in_tree, tiled=True) + expected_in_tree = jax.tree.map(concat_in_tree, in_tree) + if not jax.tree.all( + jax.tree.map(lambda *x: np.all(np.equal(*x)), expected_in_tree, out)): raise AssertionError( - f'{fail_message} Expected: {expected}; got: {in_tree}.') + f'{fail_message}. Expected: {out}; got: {in_tree}.') def reached_preemption_sync_point(step_id: int) -> bool: @@ -200,13 +220,16 @@ def should_save(step_id: int) -> bool: after some hosts are preempted. Raises: - RuntimeError: if preemption sync manager has not been inititialized. + RuntimeError: if preemption sync manager has not been initialized. """ if distributed.global_state.client is None: return False sync_manager = distributed.global_state.preemption_sync_manager if sync_manager is None: - raise RuntimeError("Preemption sync manager has not been initialized.") + raise RuntimeError( + "Preemption sync manager has not been initialized. Make sure the" + " 'jax_enable_preemption_service' config is enabled." + ) return sync_manager.reached_sync_point(step_id) @@ -217,13 +240,15 @@ def _flatten_pspecs(name, in_tree, pspecs_thunk): @lru_cache def _local_to_global_aval(local_aval, mesh, pspec): - return pxla.mesh_local_to_global(mesh, pxla.get_array_mapping(pspec), - local_aval) + pspec = sharding_impls.prepare_axis_resources(pspec, "pspec to array_mapping") + return pxla.mesh_local_to_global( + mesh, sharding_impls.get_array_mapping(pspec), local_aval) @lru_cache def _global_to_local_aval(global_aval, mesh, pspec): - return pxla.mesh_global_to_local(mesh, pxla.get_array_mapping(pspec), - global_aval) + pspec = sharding_impls.prepare_axis_resources(pspec, "pspec to array_mapping") + return pxla.mesh_global_to_local( + mesh, sharding_impls.get_array_mapping(pspec), global_aval) def host_local_array_to_global_array_impl( @@ -235,9 +260,14 @@ def host_local_array_to_global_array_impl( # If the Array is not fully addressable i.e. not host local, return it. if isinstance(arr, array.ArrayImpl) and not arr.is_fully_addressable: return arr - if isinstance(arr, array.ArrayImpl) and isinstance( - arr.sharding, jax.sharding.PmapSharding): + if (isinstance(arr, array.ArrayImpl) and isinstance( + arr.sharding, jax.sharding.PmapSharding)) or not hasattr(arr, 'shape'): arr = np.array(arr) + if arr.dtype == dtypes.float0: + arr = np.zeros(arr.shape, dtype=np.dtype(bool)) + dtype = arr.dtype + if is_prng_key_array := isinstance(arr, prng.PRNGKeyArray): + arr = arr._base_array local_sharding = jax.sharding.NamedSharding(global_mesh.local_mesh, pspec) @@ -248,17 +278,20 @@ def host_local_array_to_global_array_impl( arr.sharding.is_equivalent_to(local_sharding, arr.ndim)): arrays = [x.data for x in arr.addressable_shards] else: - arr = xla.canonicalize_dtype(arr) + arr = dtypes.canonicalize_value(arr) arrays = [ - arr[index] - for d, index in local_sharding.devices_indices_map(arr.shape).items()] + arr[i] for i in local_sharding.devices_indices_map(arr.shape).values() + ] global_aval = _local_to_global_aval( core.ShapedArray(arr.shape, arr.dtype), global_mesh, pspec) - return pxla.batched_device_put( + out = pxla.batched_device_put( global_aval, jax.sharding.NamedSharding(global_mesh, pspec), arrays, list(global_mesh.local_mesh.devices.flat)) + if is_prng_key_array: + return prng.PRNGKeyArray(dtype._impl, out) + return out def host_local_array_to_global_array( @@ -325,7 +358,7 @@ def host_local_array_to_global_array( >>> >>> host_local_output = multihost_utils.global_array_to_host_local_array(global_out, mesh, out_pspecs) # doctest: +SKIP - Please note ths function requires global mesh to be a continuous mesh, meaning + Please note this function requires global mesh to be a continuous mesh, meaning that devices that belong to each host should form a subcube in this mesh. To move local data to global array with non-continuous mesh use jax.make_array_from_callback or jax.make_array_from_single_device_arrays @@ -363,11 +396,13 @@ def ltg_abstract_eval(arr, *, global_mesh, pspec): host_local_array_to_global_array_p.bind(ct, **params),)) def ltg_batcher(insert_axis, axis_data, vals_in, dims_in, global_mesh, pspec): + del insert_axis x, = vals_in d, = dims_in new_parts = None if axis_data.spmd_name is None else axis_data.spmd_name new_pspec = list(pspec) - new_pspec.insert(d, new_parts) + if d is not None: + new_pspec.insert(d, new_parts) new_pspec = P(*new_pspec) y = host_local_array_to_global_array_p.bind( x, global_mesh=global_mesh, pspec=new_pspec) @@ -389,6 +424,13 @@ def global_array_to_host_local_array_impl( # If the Array is already fully addressable i.e. host local, return it. if isinstance(arr, array.ArrayImpl) and arr.is_fully_addressable: return arr + if not hasattr(arr, 'shape'): + arr = np.array(arr) + if arr.dtype == dtypes.float0: + arr = np.zeros(arr.shape, dtype=np.dtype(bool)) + dtype = arr.dtype + if is_prng_key_array := isinstance(arr, prng.PRNGKeyArray): + arr = arr._base_array global_sharding = jax.sharding.NamedSharding(global_mesh, pspec) local_sharding = jax.sharding.NamedSharding(global_mesh.local_mesh, pspec) @@ -401,16 +443,19 @@ def global_array_to_host_local_array_impl( else: resharded_array = jax.device_put(arr, global_sharding) arrays = resharded_array._arrays - return array.ArrayImpl(local_aval, local_sharding, arrays, committed=True) + out = array.ArrayImpl(local_aval, local_sharding, arrays, committed=True) + if is_prng_key_array: + return prng.PRNGKeyArray(dtype._impl, out) + return out else: # numpy array can show up here during AD. - arr = xla.canonicalize_dtype(arr) + arr = dtypes.canonicalize_value(arr) arrays = [ - arr[index] - for d, index in local_sharding.devices_indices_map(arr.shape).items()] - return pxla.batched_device_put( - local_aval, local_sharding, arrays, - list(global_mesh.local_mesh.devices.flat)) + arr[i] for i in local_sharding.devices_indices_map(arr.shape).values() + ] + return pxla.batched_device_put( + local_aval, local_sharding, arrays, + list(global_mesh.local_mesh.devices.flat)) def global_array_to_host_local_array( @@ -474,15 +519,55 @@ def _gtl_lowering(ctx, x, *, global_mesh, pspec): mlir.register_lowering(global_array_to_host_local_array_p, _gtl_lowering) -def live_devices(devices: list[xla_client.Device]) -> list[xla_client.Device]: - """Returns the subset of the provided devices that are live and healthy. +def _live_devices(client, devices: list[xla_client.Device]) -> dict[xla_client.Device, int]: + """Returns the subset of the provided devices that are live and healthy.""" + process_ids = {d.process_index for d in devices} + if xla_bridge.process_index() not in process_ids: + # A process can only participate in an live_devices call if it hosts some of + # the provided devices. + raise ValueError('Provided devices do not have any local devices.') + + live_process_ids = client.get_live_nodes(list(process_ids)) + return { + d: live_process_ids[d.process_index] + for d in devices + if d.process_index in live_process_ids + } + + +class _LiveDevices: + """A context manager for atomically running code on the set of live devices. + + THIS API IS UNDER ACTIVE DEVELOPMENT AND IS NOT STABLE. + + # Overview - This API is under active development and is not stable. + `live_devices` is a low-level primitive that can be used to make + multi-controller JAX programs fault tolerant. A multi-controller JAX program + runs across many devices, and the machines that host these devices might fail. + `live_devices` is a context manager that yields the current set of healthy + devices, allowing you to run JAX code on the healthy devices while ignoring + the failed ones. - `live_devices` is a low-level fault tolerance primitive that can be used to - implement fault tolerant multi-process JAX programs. + Concretely, `live_devices` is a context manager. You provide it the set of + devices you are interested in, and it yields the subset of these devices that + are live. In the body of the `with` statement, you can execute arbitrary JAX + code using the set of live devices. - Barrier Semantics + # Example Usage + + try: + with jax.live_devices(jax.devices()) as devices: + # Run JAX code here with devices. + pass + except: + # A device died while executing the with statement above. + pass + else: + # The with statement executed successfully. + pass + + # Barrier Semantics It's important that every process agrees on which devices are live to avoid the processes' behavior from diverging. For example, imagine a set of @@ -490,19 +575,19 @@ def live_devices(devices: list[xla_client.Device]) -> list[xla_client.Device]: should be participating in the AllGather. This is buggy. To ensure that every process agrees on the set of live devices, the - `live_devices` function has barrier-like semantics. Consider an invocation - `live_devices(devices)` where `devices` includes devices across a set of - processes P. The invocation acts as a barrier, waiting for every process in P - to call `live_devices(devices)`. Afterwards, `live_devices` returns the same - set of live devices `A` to all the processes in P. This ensures that every - process agrees on the set of live devices. + `live_devices` context manager has barrier-like semantics. Consider an + invocation `with live_devices(devices)` where `devices` includes devices + across a set of processes P. The invocation acts as a barrier, waiting for + every process in P to call `with live_devices(devices)`. Afterwards, + `live_devices` returns the same set of live devices `A` to all the processes + in P. This ensures that every process agrees on the set of live devices. `live_devices` does not actually act as a barrier for *every* process in P because some processes in P might have failed. Instead, the `live_devices` function waits only for the processes with a device in the returned set of live devices A. - An Example + # An Example Imagine we have four processes, each with two devices: @@ -511,12 +596,40 @@ def live_devices(devices: list[xla_client.Device]) -> list[xla_client.Device]: Process C: Devices 5 and 6 Process D: Devices 7 and 8 - Further imagine that process D fails and that every process calls - `live_devices(jax.devices())`. The invocation returns devices 1, 2, 3, 4, 5, + Further imagine that process D fails and that every process calls `with + live_devices(jax.devices())`. The invocation returns devices 1, 2, 3, 4, 5, and 6. Because these devices are hosted by processes A, B, and C, the call to `live_devices` acts as a barrier across processes A, B, and C. Process D, which failed, is ignored. + # Atomicity + + `live_devices` also provides the following transaction-like atomicity + property. When a process exits the body of a `with jax.live_devices(...) as + devices:` block, there are two possibilities. + + 1. All processes in `devices` successfully executed all code in the block + without any exceptions being raised. + 2. All processes in `devices` did not successfully execute the code in the + block, and all the processes will raise an exception. + + Consider the following code. + + try: + with jax.live_devices(...) as devices: + pass + except: + pass # A + else: + pass # B + + The atomicity property says that either every process with devices in + `devices` will enter the except branch (A) or every process with devices in + `devices` will enter the else branch (B). It is impossible for some processes + to enter A and others to enter B. + + TODO: mwhittaker - Link to formal live devices semantics. + Args: devices: A list of devices. The provided devices must include at least one local device. @@ -528,26 +641,37 @@ def live_devices(devices: list[xla_client.Device]) -> list[xla_client.Device]: RuntimeError: If the distributed runtime was not initialized. ValueError: If no local devices are provided. """ - client = distributed.global_state.client - if client is None: - raise RuntimeError('Distributed JAX not initialized.') - - if not devices: - # TODO(mwhittaker): Make devices optional. If it's not provided, use - # jax.devices() as a default. - raise ValueError('No devices provided.') - process_ids = {d.process_index for d in devices} - if xla_bridge.process_index() not in process_ids: - # A process can only participate in an live_devices call if it hosts some - # of the provided devices. - raise ValueError('Provided devices do not have any local devices.') - - if len(process_ids) == 1: - # If the provided devices are hosted by a single process (this one), then we - # don't have to perform any distributed computation. We know our local - # devices are all live. - return devices - - live_process_ids = client.get_live_nodes(list(process_ids)) - return [d for d in devices if d.process_index in live_process_ids] + def __init__(self): + self.devices = None + + @contextlib.contextmanager + def __call__(self, devices): + client = distributed.global_state.client + if client is None: + raise RuntimeError('Distributed JAX not initialized.') + + if not devices: + # TODO(mwhittaker): Make devices optional. If it's not provided, use + # jax.devices() as a default. + raise ValueError('No devices provided.') + + if self.devices is None: + self.devices = _live_devices(client, devices) + exception = None + try: + alive = list(self.devices.keys()) + alive.sort(key=lambda d: d.id) + yield alive + except Exception as e: + exception = e + finally: + old_devices = self.devices + new_devices = _live_devices(client, devices) + self.devices = new_devices + if exception: + raise exception + if not old_devices.items() <= new_devices.items(): + raise ValueError(f'{old_devices} is not a subset of {new_devices}') + +live_devices = _LiveDevices() diff --git a/jax/experimental/ode.py b/jax/experimental/ode.py index db7865124687..bdbd52000b15 100644 --- a/jax/experimental/ode.py +++ b/jax/experimental/ode.py @@ -28,7 +28,7 @@ from functools import partial import operator as op -from typing import Callable +from collections.abc import Callable import jax from jax import api_util @@ -214,7 +214,7 @@ def body_fun(state): _, *carry = lax.while_loop(cond_fun, body_fun, [0] + carry) _, _, t, _, last_t, interp_coeff = carry relative_output_time = (target_t - last_t) / (t - last_t) - y_target = jnp.polyval(interp_coeff, relative_output_time.astype(interp_coeff.dtype)) + y_target = jnp.polyval(interp_coeff, relative_output_time.astype(interp_coeff.dtype)) # pytype: disable=attribute-error return carry, y_target f0 = func_(y0, ts[0]) diff --git a/jax/experimental/pallas/__init__.py b/jax/experimental/pallas/__init__.py index 1e0abacfc25f..6872dfd0941a 100644 --- a/jax/experimental/pallas/__init__.py +++ b/jax/experimental/pallas/__init__.py @@ -15,48 +15,51 @@ """Module for Pallas, a JAX extension for custom kernels. See the Pallas documentation at -https://jax.readthedocs.io/en/latest/pallas.html. +https://docs.jax.dev/en/latest/pallas/index.html. """ +from jax._src.pallas.core import BlockDim as BlockDim from jax._src.pallas.core import Blocked as Blocked from jax._src.pallas.core import BlockSpec as BlockSpec +from jax._src.pallas.core import BoundedSlice as BoundedSlice +from jax._src.pallas.core import Buffered as Buffered from jax._src.pallas.core import CompilerParams as CompilerParams from jax._src.pallas.core import core_map as core_map from jax._src.pallas.core import CostEstimate as CostEstimate +from jax._src.pallas.core import Element as Element from jax._src.pallas.core import GridSpec as GridSpec -from jax._src.pallas.core import IndexingMode as IndexingMode from jax._src.pallas.core import lower_as_mlir as lower_as_mlir from jax._src.pallas.core import MemoryRef as MemoryRef from jax._src.pallas.core import MemorySpace as MemorySpace -from jax._src.pallas.core import Buffered as Buffered from jax._src.pallas.core import no_block_spec as no_block_spec -from jax._src.pallas.core import Unblocked as Unblocked -from jax._src.pallas.core import unblocked as unblocked +from jax._src.pallas.core import semaphore as semaphore +from jax._src.pallas.core import Squeezed as Squeezed +from jax._src.pallas.core import squeezed as squeezed from jax._src.pallas.cost_estimate import estimate_cost as estimate_cost +from jax._src.pallas.helpers import debug_check as debug_check +from jax._src.pallas.helpers import debug_checks_enabled as debug_checks_enabled from jax._src.pallas.helpers import empty as empty from jax._src.pallas.helpers import empty_like as empty_like +from jax._src.pallas.helpers import empty_ref_like as empty_ref_like +from jax._src.pallas.helpers import enable_debug_checks as enable_debug_checks +from jax._src.pallas.helpers import kernel as kernel +from jax._src.pallas.helpers import loop as loop from jax._src.pallas.helpers import when as when from jax._src.pallas.pallas_call import pallas_call as pallas_call from jax._src.pallas.pallas_call import pallas_call_p as pallas_call_p -from jax._src.pallas.primitives import atomic_add as atomic_add -from jax._src.pallas.primitives import atomic_and as atomic_and -from jax._src.pallas.primitives import atomic_cas as atomic_cas -from jax._src.pallas.primitives import atomic_max as atomic_max -from jax._src.pallas.primitives import atomic_min as atomic_min -from jax._src.pallas.primitives import atomic_or as atomic_or -from jax._src.pallas.primitives import atomic_xchg as atomic_xchg -from jax._src.pallas.primitives import atomic_xor as atomic_xor from jax._src.pallas.primitives import debug_print as debug_print +from jax._src.pallas.primitives import delay as delay +from jax._src.pallas.primitives import DeviceIdType as DeviceIdType from jax._src.pallas.primitives import dot as dot -from jax._src.pallas.primitives import load as load -from jax._src.pallas.primitives import max_contiguous as max_contiguous +from jax._src.pallas.primitives import get_global as get_global from jax._src.pallas.primitives import multiple_of as multiple_of from jax._src.pallas.primitives import num_programs as num_programs from jax._src.pallas.primitives import program_id as program_id from jax._src.pallas.primitives import reciprocal as reciprocal from jax._src.pallas.primitives import run_scoped as run_scoped -from jax._src.pallas.primitives import store as store -from jax._src.pallas.primitives import swap as swap +from jax._src.pallas.primitives import semaphore_read as semaphore_read +from jax._src.pallas.primitives import semaphore_signal as semaphore_signal +from jax._src.pallas.primitives import semaphore_wait as semaphore_wait from jax._src.pallas.utils import cdiv as cdiv from jax._src.pallas.utils import next_power_of_2 as next_power_of_2 from jax._src.pallas.utils import strides_from_shape as strides_from_shape @@ -68,3 +71,4 @@ ANY = MemorySpace.ANY +HOST = MemorySpace.HOST diff --git a/jax/experimental/pallas/fuser.py b/jax/experimental/pallas/fuser.py index 729a447b7408..c69b6f153db6 100644 --- a/jax/experimental/pallas/fuser.py +++ b/jax/experimental/pallas/fuser.py @@ -18,7 +18,8 @@ from jax._src.pallas.fuser.block_spec import make_scalar_prefetch_handler as make_scalar_prefetch_handler from jax._src.pallas.fuser.block_spec import pull_block_spec as pull_block_spec from jax._src.pallas.fuser.block_spec import push_block_spec as push_block_spec +from jax._src.pallas.fuser.custom_fusion_lib import custom_fusion as custom_fusion from jax._src.pallas.fuser.custom_evaluate import evaluate as evaluate -from jax._src.pallas.fuser.fusable import fusable as fusable +from jax._src.pallas.fuser.fusible import fusible as fusible from jax._src.pallas.fuser.fusion import Fusion as Fusion from jax._src.pallas.fuser.jaxpr_fusion import fuse as fuse diff --git a/jax/experimental/pallas/g3doc/debugging.md b/jax/experimental/pallas/g3doc/debugging.md index 6dfa95eb16fa..72ee0154f101 100644 --- a/jax/experimental/pallas/g3doc/debugging.md +++ b/jax/experimental/pallas/g3doc/debugging.md @@ -3,7 +3,7 @@ [TOC] @@ -16,10 +16,39 @@ a ticket on https://github.com/jax-ml/jax/issues. ### Interpret (HLO) Mode -Passing in `interpret=True` into `pl.pallas_call` will run the kernel in HLO instead of lowering to Mosaic/Triton. This is useful for checking correctness of your program and prototyping on smaller block sizes (as TPUs kernels require block sizes of at least 8x128). HLO is also more feature-complete so sometimes kernels will run in interpret mode but fail otherwise - this will make sure the bug is not in your kernel but in Pallas. +Passing in `interpret=True` into `pl.pallas_call` or `pl.core_map` will run the kernel in HLO instead of lowering to Mosaic/Triton. This is useful for checking correctness of your program and prototyping on smaller block sizes (as TPUs kernels require block sizes of at least 8x128). HLO is also more feature-complete so sometimes kernels will run in interpret mode but fail otherwise - this will make sure the bug is not in your kernel but in Pallas. Note that interpret mode will not be able to fully replicate the behavior or programs that use communication (DMAs) between devices. This is because low-level communication APIs are more general than the interface that XLA provides via SPMD collective operations. +### TPU Interpret Mode + +TPU interpret mode is similar to [interpret (HLO) mode](#interpret-hlo-mode), +but TPU interpret mode explicitly simulates accesses to TPU memory (HBM, VMEM, +SMEM, etc.), communication via remote DMAs, TPU synchronization operations +(e.g., barriers and semaphores), and parallel execution of kernels distributed +across +[multiple TPUs](https://docs.jax.dev/en/latest/pallas/tpu/distributed.html) and +[Megacore cores](https://docs.jax.dev/en/latest/pallas/tpu/distributed.html#megacore). + +TPU interpret mode is slower than interpret (HLO) mode, but it can be useful for +developing and debugging distributed TPU kernels with explicit communication and +synchronization. With this mode, kernels can be run on CPU -- enabling local +development (with no TPU), using a debugger and inspecting the state of +simulated TPU buffers and semaphores, etc. + +To use TPU interpret mode, pass `interpret=pltpu.InterpretParams()` into +`pl.pallas_call` or `pl.core_map`. For examples, see +`test_matmul_example` in +[tpu_pallas_interpret_test.py](https://github.com/jax-ml/jax/blob/main/tests/pallas/tpu_pallas_interpret_test.py#:~:text=test_matmul_example) +and +`test_right_permute_example` and the other tests in +[tpu_pallas_interpret_distributed_test.py](https://github.com/jax-ml/jax/blob/main/tests/pallas/tpu_pallas_interpret_distributed_test.py#:~:text=test_right_permute_example). + +The behavior of TPU interpret mode can be configured via arguments to +[`pltpu.InterpretParams`](https://github.com/jax-ml/jax/blob/main/jax/_src/pallas/mosaic/interpret.py#:~:text=class%20InterpretParams). For example, use `num_cores_per_device=2` +to simulate Megacore or `uninitialized_memory='zero'` to initialize simuluated +TPU buffers with zeros instead of NaNs. + ### debug_print The `pl.debug_print` function can be used to print runtime values inside of a kernel. @@ -45,16 +74,14 @@ as a Python error after the kernel has successfully executed. #### Hard assertion -Hard assertions can be inserted with `checkify.check` -and running your program with the `--jax_pallas_enable_runtime_assert` flag. +Hard assertions can be inserted with `pl.debug_check` +and running your program with the `--jax_pallas_enable_debug_checks` flag. Your code will look like the following: ```python -from jax.experimental import checkify - def kernel(...): - checkify.check(x > y, "Check x > y failed") # Will halt if x <= y + pl.debug_check(x > y, "Check x > y failed") # Will halt if x <= y ``` This will print a relatively lengthy dump which resembles the following: @@ -76,11 +103,10 @@ Functionalized asserts can be performed by checkify-ing the `pl.pallas_call` op from jax.experimental import checkify def kernel(...): - checkify.check(x > y, "Check x > y failed") # Will throw an error if x <= y + pl.debug_check(x > y, "Check x > y failed") # Will throw an error if x <= y kernel = pl.pallas_call(...) -checkified_kernel = checkify.checkify(kernel, - errors=checkify.all_checks) +checkified_kernel = checkify.checkify(kernel, errors=checkify.all_checks) error, result = checkified_kernel(x) error.throw() ``` @@ -148,26 +174,39 @@ Mosaic is the underlying TPU compiler for Pallas. It can be useful to dump Mosai Passing the `--xla_mosaic_dump_to=` argument will dump the output of all intermediate Mosaic passes. The names of the files contain either the parameter `name` passed to the `pallas_call`, or the name of the kernel function. A useful option is to dump to Sponge with `--test_arg=--xla_mosaic_dump_to=sponge` after which you will see all passes under the “Artifacts” tab in sponge. -### Static Verification +### Dynamic Race Detection -The static verification tool can be used to automatically detect race conditions in distributed kernels. -Because this tool uses formal verification, it is best used for small kernels (<=2 devices). +[TPU Interpret Mode](#tpu-interpret-mode) includes a dynamic race detector. +While running a kernel, it can detect and log data races -- pairs of accesses +to shared memory (HBM, VMEM, SMEM, etc.) that are not properly synchronized. -Verification can be performed by running your kernel with the `--jax_pallas_dump_promela_to=`, -which will output a Promela dump file. Afterwards, the dump file can be -analyzed using the [`spin`](https://spinroot.com) tool. For example, with a dump named `dump.pml`, run: +To enable the dynamic race detector, use the option `detect_races=True` in the +`pltpu.InterpretParams` passed to `pl.pallas_call`: +```python +pl.pallas_call( + kernel, + ..., + intepret=pltpu.InterpretParams(..., detect_races=True), +) ``` -spin -a dump.pml && gcc -o pan -O3 pan.c -Wno-format-overflow && time ./pan + +If any data races are detected while running the kernel, a message will be +printed -- for example: + +``` +RACE DETECTED + write ... from ...jax/tests/pallas/tpu_pallas_interpret_distributed_test.py:1038:10 (InterpretDistributedTest.test_race_detection..kernel.._) + write ... from .../jax/tests/pallas/tpu_pallas_interpret_distributed_test.py:1038:10 (InterpretDistributedTest.test_race_detection..kernel.._) ``` - + ## Useful Command line flags * OOB Checks: `--xla_mosaic_on_device_checks=bounds` * Poison VMEM allocations: `--xla_jf_poison_vmem_allocations=true` - + * Dump Mosaic: `--xla_mosaic_dump_to=` * Enable trace markers in XProf: `--xla_enable_transpose_trace` @@ -203,5 +242,3 @@ In most cases the error message should hint at what is wrong. For specific errors: * `Mixed dtype operands in cmp` when using `jnp.mod`: Use lax.rem instead of jnp.mod - - diff --git a/jax/experimental/pallas/mosaic_gpu.py b/jax/experimental/pallas/mosaic_gpu.py index 631b4f720984..eccc06881936 100644 --- a/jax/experimental/pallas/mosaic_gpu.py +++ b/jax/experimental/pallas/mosaic_gpu.py @@ -18,36 +18,84 @@ """ from jax._src.pallas.mosaic_gpu.core import Barrier as Barrier -from jax._src.pallas.mosaic_gpu.core import GPUBlockSpec as GPUBlockSpec -from jax._src.pallas.mosaic_gpu.core import GPUCompilerParams as GPUCompilerParams -from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace as GPUMemorySpace -from jax._src.pallas.mosaic_gpu.core import GPUMesh as GPUMesh +from jax._src.pallas.mosaic_gpu.core import BlockSpec as BlockSpec +from jax._src.pallas.mosaic_gpu.core import ClusterBarrier as ClusterBarrier +from jax._src.pallas.mosaic_gpu.core import CompilerParams as CompilerParams from jax._src.pallas.mosaic_gpu.core import kernel as kernel +from jax._src.pallas.mosaic_gpu.core import Layout as Layout +from jax._src.pallas.mosaic_gpu.core import layout_cast as layout_cast +from jax._src.pallas.mosaic_gpu.core import MemoryRefTransform as MemoryRefTransform +from jax._src.pallas.mosaic_gpu.core import MemorySpace as MemorySpace +from jax._src.pallas.mosaic_gpu.core import Mesh as Mesh +from jax._src.pallas.mosaic_gpu.core import multicast_ref as multicast_ref +from jax._src.pallas.mosaic_gpu.core import PeerMemRef as PeerMemRef +from jax._src.pallas.mosaic_gpu.core import RefUnion as RefUnion +from jax._src.pallas.mosaic_gpu.core import remote_ref as remote_ref +from jax._src.pallas.mosaic_gpu.core import SemaphoreType as SemaphoreType from jax._src.pallas.mosaic_gpu.core import SwizzleTransform as SwizzleTransform from jax._src.pallas.mosaic_gpu.core import TilingTransform as TilingTransform +from jax._src.pallas.mosaic_gpu.core import TMEMLayout as TMEMLayout +from jax._src.pallas.mosaic_gpu.core import transform_ref as transform_ref from jax._src.pallas.mosaic_gpu.core import transpose_ref as transpose_ref from jax._src.pallas.mosaic_gpu.core import TransposeTransform as TransposeTransform +from jax._src.pallas.mosaic_gpu.core import TryClusterCancelResult as TryClusterCancelResult +from jax._src.pallas.mosaic_gpu.core import unswizzle_ref as unswizzle_ref +from jax._src.pallas.mosaic_gpu.core import untile_ref as untile_ref +from jax._src.pallas.mosaic_gpu.core import WarpMesh as WarpMesh from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC # noqa: F401 from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as WGMMAAccumulatorRef +from jax._src.pallas.mosaic_gpu.helpers import find_swizzle as find_swizzle +from jax._src.pallas.mosaic_gpu.helpers import format_tcgen05_sparse_metadata as format_tcgen05_sparse_metadata +from jax._src.pallas.mosaic_gpu.helpers import nd_loop as nd_loop +from jax._src.pallas.mosaic_gpu.helpers import NDLoopInfo as NDLoopInfo +from jax._src.pallas.mosaic_gpu.helpers import planar_snake as planar_snake +from jax._src.pallas.mosaic_gpu.helpers import dynamic_scheduling_loop as dynamic_scheduling_loop from jax._src.pallas.mosaic_gpu.pipeline import emit_pipeline as emit_pipeline from jax._src.pallas.mosaic_gpu.pipeline import emit_pipeline_warp_specialized as emit_pipeline_warp_specialized +from jax._src.pallas.mosaic_gpu.pipeline import PipelinePipeline as PipelinePipeline +from jax._src.pallas.mosaic_gpu.primitives import async_copy_scales_to_tmem as async_copy_scales_to_tmem +from jax._src.pallas.mosaic_gpu.primitives import async_copy_sparse_metadata_to_tmem as async_copy_sparse_metadata_to_tmem +from jax._src.pallas.mosaic_gpu.primitives import async_load_tmem as async_load_tmem +from jax._src.pallas.mosaic_gpu.primitives import async_prefetch as async_prefetch +from jax._src.pallas.mosaic_gpu.primitives import async_store_tmem as async_store_tmem from jax._src.pallas.mosaic_gpu.primitives import barrier_arrive as barrier_arrive from jax._src.pallas.mosaic_gpu.primitives import barrier_wait as barrier_wait from jax._src.pallas.mosaic_gpu.primitives import broadcasted_iota as broadcasted_iota from jax._src.pallas.mosaic_gpu.primitives import commit_smem as commit_smem from jax._src.pallas.mosaic_gpu.primitives import commit_smem_to_gmem_group as commit_smem_to_gmem_group +from jax._src.pallas.mosaic_gpu.primitives import commit_tmem as commit_tmem from jax._src.pallas.mosaic_gpu.primitives import copy_gmem_to_smem as copy_gmem_to_smem from jax._src.pallas.mosaic_gpu.primitives import copy_smem_to_gmem as copy_smem_to_gmem -from jax._src.pallas.mosaic_gpu.primitives import Layout as Layout -from jax._src.pallas.mosaic_gpu.primitives import layout_cast as layout_cast +from jax._src.pallas.mosaic_gpu.primitives import inline_mgpu as inline_mgpu +from jax._src.pallas.mosaic_gpu.primitives import load as load +from jax._src.pallas.mosaic_gpu.primitives import multimem_store as multimem_store +from jax._src.pallas.mosaic_gpu.primitives import multimem_load_reduce as multimem_load_reduce +from jax._src.pallas.mosaic_gpu.primitives import print_layout as print_layout +from jax._src.pallas.mosaic_gpu.primitives import query_cluster_cancel as query_cluster_cancel +from jax._src.pallas.mosaic_gpu.primitives import RefType as RefType +from jax._src.pallas.mosaic_gpu.primitives import semaphore_signal_multicast as semaphore_signal_multicast +from jax._src.pallas.mosaic_gpu.primitives import semaphore_signal_parallel as semaphore_signal_parallel +from jax._src.pallas.mosaic_gpu.primitives import SemaphoreSignal as SemaphoreSignal from jax._src.pallas.mosaic_gpu.primitives import set_max_registers as set_max_registers +from jax._src.pallas.mosaic_gpu.primitives import ShapeDtypeStruct as ShapeDtypeStruct +from jax._src.pallas.mosaic_gpu.primitives import tcgen05_commit_arrive as tcgen05_commit_arrive +from jax._src.pallas.mosaic_gpu.primitives import tcgen05_mma as tcgen05_mma +from jax._src.pallas.mosaic_gpu.primitives import try_cluster_cancel as try_cluster_cancel +from jax._src.pallas.mosaic_gpu.primitives import wait_load_tmem as wait_load_tmem from jax._src.pallas.mosaic_gpu.primitives import wait_smem_to_gmem as wait_smem_to_gmem from jax._src.pallas.mosaic_gpu.primitives import wgmma as wgmma from jax._src.pallas.mosaic_gpu.primitives import wgmma_wait as wgmma_wait -from jax.experimental.mosaic.gpu.core import ThreadSemantics as ThreadSemantics +from jax._src.pallas.mosaic_gpu.torch import as_torch_kernel as as_torch_kernel +from jax.experimental.mosaic.gpu.core import LoweringSemantics as LoweringSemantics +from jax.experimental.mosaic.gpu.fragmented_array import Replicated as Replicated +from jax.experimental.mosaic.gpu.fragmented_array import Tiling as Tiling -#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.GPUMemorySpace.GMEM`. -GMEM = GPUMemorySpace.GMEM -#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.GPUMemorySpace.SMEM`. -SMEM = GPUMemorySpace.SMEM +#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.MemorySpace.GMEM`. +GMEM = MemorySpace.GMEM +#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.MemorySpace.SMEM`. +SMEM = MemorySpace.SMEM +#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.MemorySpace.TMEM`. +TMEM = MemorySpace.TMEM +#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.MemorySpace.REGS`. +REGS = MemorySpace.REGS diff --git a/jax/experimental/pallas/ops/gpu/all_gather_mgpu.py b/jax/experimental/pallas/ops/gpu/all_gather_mgpu.py new file mode 100644 index 000000000000..9404a7a1f27c --- /dev/null +++ b/jax/experimental/pallas/ops/gpu/all_gather_mgpu.py @@ -0,0 +1,226 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""All-gather kernel implemented using Mosaic GPU.""" + +from collections.abc import Hashable +import functools +import itertools +import math + +import jax +from jax import lax +from jax.experimental import multihost_utils +from jax.experimental import pallas as pl +from jax.experimental.mosaic.gpu import profiler +from jax.experimental.pallas import mosaic_gpu as plgpu +from jax.extend import backend +import jax.numpy as jnp + + +def all_gather( + x: jax.Array, + *, + axis_name: Hashable, + gather_dimension: int = 0, + num_blocks: int | None = None, + tile_size: int | None = None, + vec_size: int | None = None, +) -> jax.Array: + """Performs an all-gather operation using multimem instructions. + + Args: + x: Input array. Should be sharded across the specified axis. + axis_name: Name of the mesh axis to all-gather across. + gather_dimension: Axis along which to gather. + num_blocks: Number of blocks to use. Defaults to the device core count. + tile_size: Total tile size to split across major, gather, and minor dimensions. + vec_size: Vector size for the layout. If None, automatically inferred from dtype. + """ + num_devices = lax.axis_size(axis_name) + input_shape = x.shape + dtype = x.dtype + ndim = len(input_shape) + + if num_blocks is None: + num_blocks = backend.get_default_device().core_count + + if gather_dimension < -ndim or gather_dimension >= ndim: + raise ValueError( + f"gather_dimension {gather_dimension} out of bounds for array of rank" + f" {ndim}" + ) + if gather_dimension < 0: + gather_dimension += ndim + + input_gather_dim = input_shape[gather_dimension] + major_dims = math.prod(input_shape[:gather_dimension]) + minor_dims = math.prod(input_shape[gather_dimension+1:]) + output_gather_dim = input_gather_dim * num_devices + output_shape = ( + *input_shape[:gather_dimension], output_gather_dim, *input_shape[gather_dimension + 1 :], + ) + + if (output_size := math.prod(output_shape)) % 128: + raise ValueError("Output size must be divisible by 128") + if jnp.issubdtype(dtype, jnp.integer): + if vec_size is None: + vec_size = 1 # Integer types only support unvectorized operations + elif vec_size != 1: + raise ValueError("Integer types only support vec_size=1") + elif vec_size is None: # vec_size inference for floating point types + dtype_bits = jnp.finfo(dtype).bits + max_vec_size = min(128 // dtype_bits, output_size // 128) + if tile_size is not None: + max_vec_size_for_tile = tile_size // 128 + max_vec_size = min(max_vec_size, max_vec_size_for_tile) + vec_size = 32 // dtype_bits # We don't support multimem below 32-bit + while vec_size * 2 <= max_vec_size: + vec_size *= 2 + if math.prod(output_shape) % vec_size: + raise ValueError( + "The total number of elements in the output" + f" ({math.prod(output_shape)}) must be divisible by the vec_size" + f" ({vec_size})" + ) + + min_transfer_elems = 128 * vec_size + if tile_size is None: + # TODO(apaszke): 8 is just an arbitrary unrolling factor. Tune it! + unroll_factor = min(math.prod(input_shape) // min_transfer_elems, 8) + tile_size = unroll_factor * min_transfer_elems + if tile_size < min_transfer_elems: + raise ValueError( + f"{tile_size=} is smaller than minimum required" + f" {min_transfer_elems} for {vec_size=}" + ) + + minor_tile = math.gcd(tile_size, minor_dims) + remaining_tile = tile_size // minor_tile + gather_tile = math.gcd(remaining_tile, input_gather_dim) + major_tile = remaining_tile // gather_tile + + if major_dims % major_tile != 0: + raise NotImplementedError( + f"Major dimension size ({major_dims}) must be divisible by the" + f" inferred major tile size ({major_tile}). Consider adjusting tile_size." + ) + + def kernel(x_ref, y_ref, done_barrier): + dev_idx = lax.axis_index(axis_name) + x_ref_3d = x_ref.reshape((major_dims, input_gather_dim, minor_dims)) + y_ref_3d = y_ref.reshape((major_dims, output_gather_dim, minor_dims)) + y_ref_3d = y_ref_3d.at[:, pl.ds(dev_idx * input_gather_dim, input_gather_dim), :] + + major_tiles = major_dims // major_tile + gather_tiles = input_gather_dim // gather_tile + minor_tiles = minor_dims // minor_tile + # TODO(apaszke): Use a TMA pipeline + @plgpu.nd_loop((major_tiles, gather_tiles, minor_tiles), collective_axes="blocks") + def _transfer_loop(loop_info: plgpu.NDLoopInfo): + major_tile_idx, gather_tile_idx, minor_tile_idx = loop_info.index + idxs = ( + pl.ds(major_tile_idx * major_tile, major_tile), + pl.ds(gather_tile_idx * gather_tile, gather_tile), + pl.ds(minor_tile_idx * minor_tile, minor_tile) + ) + output_data = plgpu.layout_cast( + x_ref_3d[idxs], + plgpu.Layout.WG_STRIDED((major_tile, gather_tile, minor_tile), vec_size=vec_size) + ) + plgpu.multimem_store(output_data, y_ref_3d.at[idxs], axis_name) + + # Wait for everyone to finish storing into our memory before returning. + plgpu.semaphore_signal_multicast(done_barrier, collective_axes=axis_name) + pl.semaphore_wait(done_barrier, num_devices, decrement=False) + + # TODO(b/448323639): We fake modify the input to ensure that XLA:GPU copies + # the operand into symmetric memory. + @pl.when(dev_idx == -1) + def _never(): + x_ref[(0,) * len(x_ref.shape)] = jnp.asarray(0, x_ref.dtype) + + return plgpu.kernel( + kernel, + out_shape=jax.ShapeDtypeStruct(output_shape, dtype), + grid=(num_blocks,), + grid_names=("blocks",), + scratch_shapes=[plgpu.SemaphoreType.REGULAR], + )(x) + + +def _run_example(): + P = jax.sharding.PartitionSpec + shape = (4 * 4096, 4 * 4096) # This shape is global! + dtype = jnp.bfloat16 + shards = jax.device_count() + mesh = jax.make_mesh( + (shards,), ("x",), axis_types=(jax.sharding.AxisType.Explicit,) + ) + jax.set_mesh(mesh) + + # We measure time per-shard and so we only need bytes per shard. + local_out_bytes = math.prod(shape) * jnp.dtype(dtype).itemsize + total_bytes = local_out_bytes + + a = jax.random.normal(jax.random.key(1), shape, dtype) + a = jax.sharding.reshard(a, P("x", None)) + + @jax.jit + @functools.partial(jax.shard_map, mesh=mesh, in_specs=P("x", None), out_specs=P("x", None)) + def ref_fn(x): + return lax.all_gather(x, "x", axis=0, tiled=True) + ref_fn(a).block_until_ready() # Warmup. + _, ref_kernels_ms = profiler.measure(ref_fn, aggregate=False)(a) + ref_time_us = sum(t * 1e3 for _, t in ref_kernels_ms) + # We choose the minimum across processes to choose the runtime that didn't + # include devices waiting for other devices. + ref_time_us = min(multihost_utils.process_allgather(ref_time_us).tolist()) + ref_bw = total_bytes / (ref_time_us * 1e-6) / 1e9 # GB/s + + tuning_it = itertools.product( + (4, 8, 16, 32, 64, 132), # num_blocks + (1024, 2048, 4096, 8192), # tile_size + ) + best_bw = 0.0 + best_runtime = float("inf") + for num_blocks, tile_size in tuning_it: + @jax.jit + @functools.partial( + jax.shard_map, mesh=mesh, in_specs=P("x", None), out_specs=P("x", None), check_vma=False + ) + def kernel_fn(x): + return all_gather(x, axis_name="x", gather_dimension=0, num_blocks=num_blocks, tile_size=tile_size) + try: + _, kernels_ms = profiler.measure(kernel_fn, aggregate=False)(a) + except ValueError as e: + if "exceeds available shared memory" in e.args[0]: # Ignore SMEM OOMs. + continue + raise + runtime_us = sum(t * 1e3 for _, t in kernels_ms) + runtime_us = min(multihost_utils.process_allgather(runtime_us).tolist()) + achieved_bw = total_bytes / (runtime_us * 1e-6) / 1e9 # GB/s + if achieved_bw > best_bw: + best_runtime = runtime_us + best_bw = achieved_bw + print(f"{num_blocks=}, {tile_size=}: {runtime_us:<7.1f}us = {achieved_bw:4.1f} GB/s") + + print(f"Total bytes transferred: {total_bytes / 1e9:.2f} GB") + print(f"\tBest: {best_runtime:<7.1f}us = {best_bw:4.1f} GB/s") + print(f"\tRef: {ref_time_us:<7.1f}us = {ref_bw:4.1f} GB/s") + + +if __name__ == "__main__": + from jax._src import test_multiprocess as jt_multiprocess # pytype: disable=import-error + jt_multiprocess.main(shard_main=_run_example) diff --git a/jax/experimental/pallas/ops/gpu/attention.py b/jax/experimental/pallas/ops/gpu/attention.py index 8b83d24ea199..dea524e42ac3 100644 --- a/jax/experimental/pallas/ops/gpu/attention.py +++ b/jax/experimental/pallas/ops/gpu/attention.py @@ -57,10 +57,10 @@ def get_default(cls): return BlockSizes( block_q=128, block_k=128, - block_q_dkv=128, - block_kv_dkv=128, - block_q_dq=128, - block_kv_dq=128, + block_q_dkv=32, + block_kv_dkv=32, + block_q_dq=32, + block_kv_dq=32, ) @property @@ -86,32 +86,31 @@ def mha_forward_kernel( segment_ids_ref: jax.Array | None, # segment_id arrays o_ref: Any, # Output *residual_refs: Any, # Residual outputs - num_heads: int, sm_scale: float, causal: bool, block_q: int, - block_d: int, block_k: int, + head_dim: int, ): seq_len = k_ref.shape[0] start_q = pl.program_id(0) + head_dim_padded = q_ref.shape[-1] # o is the buffer where we accumulate the output on sram. # m_i and l_i (see FlashAttention paper) are updated during the k,v loop. m_i = jnp.zeros(block_q, dtype=jnp.float32) - float('inf') l_i = jnp.zeros(block_q, dtype=jnp.float32) # acc is the buffer where we accumulate the output on sram. - o = jnp.zeros((block_q, block_d), dtype=jnp.float32) + o = jnp.zeros((block_q, head_dim_padded), dtype=jnp.float32) # Load q: it will stay in L1 throughout. Indices form a matrix because we # read, compute, and write all in 2d chunks. 1 element ~= 1 CUDA thread index. - # q tile has shape [block_q, block_d], block_d == head_dim. + # q tile has shape [block_q, head_dim_padded], head_dim_padded >= head_dim. curr_q_slice = pl.dslice(start_q * block_q, block_q) - q = q_ref[...] + head_mask = (jnp.arange(head_dim_padded) < head_dim)[None, :] + q = plgpu.load(q_ref, mask=head_mask, other=0.0) q_segment_ids = ( - None - if segment_ids_ref is None - else pl.load(segment_ids_ref, (curr_q_slice,)) + None if segment_ids_ref is None else segment_ids_ref[curr_q_slice] ) # In FlashAttention algorithm 1 there are 2 loops: slow over tiles of kv (size # (Bc == block_k here), and fast over blocks of q (size Br == block_q here). @@ -121,7 +120,7 @@ def body(start_k, carry): o_prev, m_prev, l_prev = carry curr_k_slice = pl.dslice(start_k * block_k, block_k) - k = pl.load(k_ref, (curr_k_slice, slice(None))) + k = plgpu.load(k_ref.at[curr_k_slice, :], mask=head_mask, other=0.0) qk = pl.dot(q, k.T) # [block_q, block_k] # Scale logits to convert from base-2 to the natural log domain. @@ -139,7 +138,7 @@ def body(start_k, carry): if causal or segment_ids_ref is not None: mask = None if segment_ids_ref is not None: - kv_segment_ids = pl.load(segment_ids_ref, (curr_k_slice,)) + kv_segment_ids = segment_ids_ref[curr_k_slice] mask = segment_mask(q_segment_ids, kv_segment_ids) if causal: span_q = start_q * block_q + jnp.arange(block_q) @@ -151,7 +150,7 @@ def body(start_k, carry): # Apply mask to qk. qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE) - m_curr = qk.max(axis=-1) + m_curr = jnp.max(qk, axis=-1) m_next = jnp.maximum(m_prev, m_curr) correction = jnp.exp2(m_prev - m_next) l_prev_corr = correction * l_prev @@ -161,7 +160,7 @@ def body(start_k, carry): l_curr = s_curr.sum(axis=-1) l_next = l_prev_corr + l_curr o_prev_corr = correction[:, None] * o_prev - v = pl.load(v_ref, (curr_k_slice, pl.dslice(block_d))) + v = plgpu.load(v_ref.at[curr_k_slice, :], mask=head_mask) o_curr = pl.dot(s_curr.astype(v.dtype), v) o_next = o_prev_corr + o_curr @@ -182,7 +181,7 @@ def body(start_k, carry): lse_ref = residual_refs[0] lse_ref[...] = m_i + jnp.log2(l_i) # Write output to dram. - o_ref[...] = o.astype(o_ref.dtype) + plgpu.store(o_ref.at[:, : o.shape[-1]], o.astype(o_ref.dtype), mask=head_mask) def segment_mask( q_segment_ids: jax.Array, @@ -199,7 +198,7 @@ def segment_mask( @functools.partial( - jax.custom_vjp, nondiff_argnums=[4, 5, 6, 7, 8, 9, 10, 11, 12] + jax.custom_vjp, nondiff_argnums=[4, 5, 6, 7, 8, 9, 10, 11, 12, 13] ) @functools.partial( jax.jit, @@ -213,6 +212,7 @@ def segment_mask( "grid", "interpret", "debug", + "return_residuals", ], ) def mha( @@ -229,12 +229,24 @@ def mha( grid: tuple[int, ...] | None = None, interpret: bool = False, debug: bool = False, + return_residuals: bool = False, ): del backward_pass_impl batch_size, q_seq_len, num_heads, head_dim = q.shape kv_seq_len = k.shape[1] block_q = min(block_sizes.block_q, q_seq_len) block_k = min(block_sizes.block_k, kv_seq_len) + head_dim_padded = pl.next_power_of_2(head_dim) + if (q.shape[-1] != k.shape[-1]) or (q.shape[-1] != v.shape[-1]): + raise ValueError( + f"This kernel expects q, k, and v to have the same head dimension, but" + f" found {q.shape=}, {k.shape=}, {v.shape=}." + ) + if q_seq_len % block_q != 0: + raise ValueError(f"{q_seq_len=} must be a multiple of {block_q=}") + if kv_seq_len % block_k != 0: + raise ValueError(f"{kv_seq_len=} must be a multiple of {block_k=}") + # Heuristics. grid_ = grid if grid_ is None: @@ -243,42 +255,44 @@ def mha( num_warps_ = num_warps if num_warps_ is None: num_warps_ = 4 if head_dim <= 64 else 8 - kernel = functools.partial(mha_forward_kernel, num_heads=num_heads, - sm_scale=sm_scale, block_q=block_q, - block_k=block_k, block_d=head_dim, - causal=causal) + kernel = functools.partial(mha_forward_kernel, sm_scale=sm_scale, + block_q=block_q, block_k=block_k, + head_dim=head_dim, causal=causal) in_specs = [ - pl.BlockSpec( - (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0) - ), - pl.BlockSpec( - (None, kv_seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) - ), - pl.BlockSpec( - (None, kv_seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) - ), + pl.BlockSpec((None, block_q, None, head_dim_padded), + lambda i, j, k: (j, i, k, 0)), + pl.BlockSpec((None, kv_seq_len, None, head_dim_padded), + lambda _, j, k: (j, 0, k, 0)), + pl.BlockSpec((None, kv_seq_len, None, head_dim_padded), + lambda _, j, k: (j, 0, k, 0)), ] in_specs.append( None # type: ignore[arg-type] if segment_ids is None else pl.BlockSpec((None, kv_seq_len), lambda _, j, k: (j, 0)) ) - out_shape = jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype) - return pl.pallas_call( + out_shape = [q] + out_specs = [pl.BlockSpec((None, block_q, None, head_dim_padded), + lambda i, j, k: (j, i, k, 0))] + if return_residuals: + out_shape.append(jax.ShapeDtypeStruct( + shape=(batch_size, num_heads, q_seq_len), dtype=jnp.float32)) # lse + out_specs.append( + pl.BlockSpec((None, None, block_q), lambda i, j, k: (j, k, i))) # lse + out = pl.pallas_call( kernel, grid=grid_, in_specs=in_specs, - out_specs=pl.BlockSpec( - (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0) - ), - compiler_params=plgpu.TritonCompilerParams( + out_specs=out_specs, + compiler_params=plgpu.CompilerParams( num_warps=num_warps_, num_stages=num_stages), out_shape=out_shape, debug=debug, interpret=interpret, name="mha_forward", )(q, k, v, segment_ids) + return out if return_residuals else out[0] def _mha_forward( @@ -295,70 +309,24 @@ def _mha_forward( grid: Any, interpret: bool, debug: bool, + return_residuals: bool, ): - del backward_pass_impl - batch_size, q_seq_len, num_heads, head_dim = q.shape - kv_seq_len = k.shape[1] - block_q = min(block_sizes.block_q, q_seq_len) - block_k = min(block_sizes.block_k, kv_seq_len) - # Heuristics. - grid_ = grid - if grid_ is None: - grid_ = (pl.cdiv(q_seq_len, block_q), batch_size, num_heads) - - num_warps_ = num_warps - if num_warps_ is None: - num_warps_ = 4 if head_dim <= 64 else 8 - kernel = functools.partial(mha_forward_kernel, num_heads=num_heads, - sm_scale=sm_scale, causal=causal, block_q=block_q, - block_k=block_k, block_d=head_dim) - out_shape = [ - jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype), # out - jax.ShapeDtypeStruct( - shape=(batch_size, num_heads, q_seq_len), dtype=jnp.float32 # lse - ), - ] - in_specs = [ - pl.BlockSpec( - (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0) - ), - pl.BlockSpec( - (None, kv_seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) - ), - pl.BlockSpec( - (None, kv_seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) - ), - ] - in_specs.append( - None # type: ignore[arg-type] - if segment_ids is None - else pl.BlockSpec((None, kv_seq_len), lambda _, j, k: (j, 0)) - ) - out, lse = pl.pallas_call( - kernel, - grid=grid_, - in_specs=in_specs, - out_specs=[ - pl.BlockSpec( - (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0) - ), - pl.BlockSpec((None, None, block_q), lambda i, j, k: (j, k, i)), - ], - compiler_params=plgpu.TritonCompilerParams( - num_warps=num_warps_, num_stages=num_stages - ), - out_shape=out_shape, - debug=debug, - interpret=interpret, - name="mha_forward", - )(q, k, v, segment_ids) - return out, (q, k, v, segment_ids, out, lse) - - -def _preprocess_backward_kernel(out_ref, dout_ref, delta_ref): + out, lse = mha(q, k, v, segment_ids=segment_ids, sm_scale=sm_scale, + causal=causal, block_sizes=block_sizes, + backward_pass_impl=backward_pass_impl, + num_warps=num_warps, num_stages=num_stages, + grid=grid, interpret=interpret, debug=debug, + return_residuals=True) + residuals = (q, k, v, segment_ids, out, lse) + ret = (out, lse) if return_residuals else out + return ret, residuals + + +def _preprocess_backward_kernel(out_ref, dout_ref, delta_ref, head_dim: int): # load - o = out_ref[...].astype(jnp.float32) - do = dout_ref[...].astype(jnp.float32) + head_mask = (jnp.arange(out_ref.shape[-1]) < head_dim)[None, :] + o = plgpu.load(out_ref, mask=head_mask, other=0.0) + do = plgpu.load(dout_ref, mask=head_mask, other=0.0) # compute delta = jnp.sum(o * do, axis=1) # write-back @@ -368,20 +336,19 @@ def _preprocess_backward_kernel(out_ref, dout_ref, delta_ref): def _preprocess_backward(out, do, lse, block_q: int, debug: bool, interpret: bool): batch_size, seq_len, num_heads, head_dim = out.shape + head_dim_padded = pl.next_power_of_2(head_dim) out_shape = jax.ShapeDtypeStruct(lse.shape, lse.dtype) delta = pl.pallas_call( - _preprocess_backward_kernel, + functools.partial(_preprocess_backward_kernel, head_dim=head_dim), grid=(pl.cdiv(seq_len, block_q), batch_size, num_heads), in_specs=[ - pl.BlockSpec( - (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0) - ), - pl.BlockSpec( - (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0) - ), + pl.BlockSpec((None, block_q, None, head_dim_padded), + lambda i, j, k: (j, i, k, 0)), + pl.BlockSpec((None, block_q, None, head_dim_padded), + lambda i, j, k: (j, i, k, 0)), ], out_specs=pl.BlockSpec((None, None, block_q), lambda i, j, k: (j, k, i)), - compiler_params=plgpu.TritonCompilerParams(num_warps=4, num_stages=3), + compiler_params=plgpu.CompilerParams(num_warps=4, num_stages=3), out_shape=out_shape, debug=debug, interpret=interpret, @@ -414,7 +381,7 @@ def mha_backward_kernel( block_kv_dkv: int, block_q_dq: int, block_kv_dq: int, - block_d: int, + head_dim: int, ): del out_ref # Not needed q_seq_len = q_ref.shape[0] @@ -427,23 +394,23 @@ def mha_backward_kernel( start_k = pl.program_id(2) curr_k_slice = pl.dslice(start_k * block_kv_dkv, block_kv_dkv) - dv = jnp.zeros([block_kv_dkv, block_d], dtype=jnp.float32) - dk = jnp.zeros([block_kv_dkv, block_d], dtype=jnp.float32) + head_dim_padded = q_ref.shape[-1] + dv = jnp.zeros([block_kv_dkv, head_dim_padded], dtype=jnp.float32) + dk = jnp.zeros([block_kv_dkv, head_dim_padded], dtype=jnp.float32) - v = pl.load(v_ref, (curr_k_slice, slice(None))) - k = pl.load(k_ref, (curr_k_slice, slice(None))) + head_mask = (jnp.arange(head_dim_padded) < head_dim)[None, :] + v = plgpu.load(v_ref.at[curr_k_slice, :], mask=head_mask, other=0.0) + k = plgpu.load(k_ref.at[curr_k_slice, :], mask=head_mask, other=0.0) span_k = start_k * block_kv_dkv + jnp.arange(block_kv_dkv) kv_segment_ids = ( - None - if segment_ids_ref is None - else pl.load(segment_ids_ref, (curr_k_slice,)) + None if segment_ids_ref is None else segment_ids_ref[curr_k_slice] ) def inner_loop_dkdv(start_q, carry): dv, dk = carry curr_q_slice = pl.dslice(start_q * block_q_dkv, block_q_dkv) - q = pl.load(q_ref, (curr_q_slice, slice(None))) + q = plgpu.load(q_ref.at[curr_q_slice, :], mask=head_mask, other=0.0) qk = pl.dot(q, k.T) qk_scale = math.log2(math.e) if sm_scale != 1.: @@ -453,7 +420,7 @@ def inner_loop_dkdv(start_q, carry): if causal or segment_ids_ref is not None: mask = None if segment_ids_ref is not None: - q_segment_ids = pl.load(segment_ids_ref, (curr_q_slice,)) + q_segment_ids = segment_ids_ref[curr_q_slice] mask = segment_mask(q_segment_ids, kv_segment_ids) if causal: @@ -464,9 +431,11 @@ def inner_loop_dkdv(start_q, carry): ) qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE) - lse = pl.load(lse_ref, (curr_q_slice,)) - di = pl.load(delta_ref, (curr_q_slice,)) - do = pl.load(do_scaled_ref, (curr_q_slice, slice(None))) + lse = lse_ref[curr_q_slice] + di = delta_ref[curr_q_slice] + do = plgpu.load( + do_scaled_ref.at[curr_q_slice, :], mask=head_mask, other=0.0 + ) p = jnp.exp2(qk - lse[:, None]) dv = dv + pl.dot(p.astype(do.dtype).T, do) @@ -483,8 +452,12 @@ def inner_loop_dkdv(start_q, carry): dv, dk = lax.fori_loop( lower_bound, pl.cdiv(q_seq_len, block_q_dkv), inner_loop_dkdv, (dv, dk) ) - dv_ref[...] = dv.astype(dv_ref.dtype) - dk_ref[...] = dk.astype(dk_ref.dtype) + plgpu.store( + dv_ref.at[:, : dv.shape[-1]], dv.astype(dv_ref.dtype), mask=head_mask + ) + plgpu.store( + dk_ref.at[:, : dk.shape[-1]], dk.astype(dk_ref.dtype), mask=head_mask + ) # Scan #2: dQ # 1. Load a block of Q of size (block_q_dq, head_dim) in SMEM. @@ -493,22 +466,20 @@ def inner_loop_dkdv(start_q, carry): start_q = pl.program_id(2) curr_q_slice = pl.ds(start_q * block_q_dq, block_q_dq) span_q = start_q * block_q_dq + jnp.arange(block_q_dq) - dq = jnp.zeros([block_q_dq, block_d], dtype=jnp.float32) + dq = jnp.zeros([block_q_dq, head_dim_padded], dtype=jnp.float32) - q = pl.load(q_ref, (curr_q_slice, slice(None))) + q = plgpu.load(q_ref.at[curr_q_slice, :], mask=head_mask, other=0.0) q_segment_ids = ( - None - if segment_ids_ref is None - else pl.load(segment_ids_ref, (curr_q_slice,)) + None if segment_ids_ref is None else segment_ids_ref[curr_q_slice] ) - lse = pl.load(lse_ref, (curr_q_slice,)) - do = pl.load(do_scaled_ref, (curr_q_slice, slice(None))) - di = pl.load(delta_ref, (curr_q_slice,)) + lse = lse_ref[curr_q_slice] + do = plgpu.load(do_scaled_ref.at[curr_q_slice, :], mask=head_mask, other=0.0) + di = delta_ref[curr_q_slice] def inner_loop_dq(start_k, dq): curr_k_slice = pl.dslice(start_k * block_kv_dq, block_kv_dq) - k = pl.load(k_ref, (curr_k_slice, slice(None))) - v = pl.load(v_ref, (curr_k_slice, slice(None))) + k = plgpu.load(k_ref.at[curr_k_slice, :], mask=head_mask, other=0.0) + v = plgpu.load(v_ref.at[curr_k_slice, :], mask=head_mask, other=0.0) qk = pl.dot(q, k.T) qk_scale = math.log2(math.e) @@ -519,7 +490,7 @@ def inner_loop_dq(start_k, dq): if causal or segment_ids_ref is not None: mask = None if segment_ids_ref is not None: - kv_segment_ids = pl.load(segment_ids_ref, (curr_k_slice,)) + kv_segment_ids = segment_ids_ref[curr_k_slice] mask = segment_mask(q_segment_ids, kv_segment_ids) if causal: @@ -547,15 +518,20 @@ def inner_loop_dq(start_k, dq): upper_bound = pl.cdiv(kv_seq_len, block_kv_dq) dq = lax.fori_loop(0, upper_bound, inner_loop_dq, (dq)) - dq_ref[...] = dq.astype(dq_ref.dtype) + plgpu.store( + dq_ref.at[:, : dq.shape[-1]], dq.astype(dq_ref.dtype), mask=head_mask + ) def _mha_backward(sm_scale: float, causal: bool, block_sizes: BlockSizes, backward_pass_impl: str, num_warps: int | None, num_stages: int, grid: Any, interpret: bool, - debug: bool, res, do): - del num_stages, grid + debug: bool, return_residuals: bool, res, do): + if return_residuals: + raise ValueError( + "Kernel differentiation is not supported if return_residuals is True.") q, k, v, segment_ids, out, lse = res + del num_stages, grid, return_residuals if backward_pass_impl == "xla": return jax.vjp( @@ -576,6 +552,7 @@ def _mha_backward(sm_scale: float, causal: bool, block_sizes: BlockSizes, block_kv_dkv = min(block_sizes.block_kv_dkv, kv_seq_len) block_q_dq = min(block_sizes.block_q_dq, q_seq_len) block_kv_dq = min(block_sizes.block_kv_dq, kv_seq_len) + head_dim_padded = pl.next_power_of_2(head_dim) if q_seq_len // block_q_dq != kv_seq_len // block_kv_dkv: raise ValueError( @@ -591,28 +568,24 @@ def _mha_backward(sm_scale: float, causal: bool, block_sizes: BlockSizes, ] in_specs = [ - pl.BlockSpec( - (None, q_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) - ), - pl.BlockSpec( - (None, kv_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) - ), - pl.BlockSpec( - (None, kv_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) - ), - pl.BlockSpec( - (None, q_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) - ), - pl.BlockSpec( - (None, q_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) - ), + pl.BlockSpec((None, q_seq_len, None, head_dim_padded), + lambda i, j, _: (i, 0, j, 0)), + pl.BlockSpec((None, kv_seq_len, None, head_dim_padded), + lambda i, j, _: (i, 0, j, 0)), + pl.BlockSpec((None, kv_seq_len, None, head_dim_padded), + lambda i, j, _: (i, 0, j, 0)), + pl.BlockSpec((None, q_seq_len, None, head_dim_padded), + lambda i, j, _: (i, 0, j, 0)), + pl.BlockSpec((None, q_seq_len, None, head_dim_padded), + lambda i, j, _: (i, 0, j, 0)), pl.BlockSpec((None, None, q_seq_len), lambda i, j, _: (i, j, 0)), pl.BlockSpec((None, None, q_seq_len), lambda i, j, _: (i, j, 0)), ] if segment_ids is None: in_specs.insert(3, None) # type: ignore[arg-type] else: - in_specs.insert(3, pl.BlockSpec((None, kv_seq_len), lambda i, j, _: (i, 0))) + in_specs.insert(3, pl.BlockSpec((None, kv_seq_len), + lambda i, j, _: (i, 0))) grid = (batch_size, num_heads, pl.cdiv(kv_seq_len, block_kv_dkv)) num_warps_ = num_warps @@ -635,29 +608,29 @@ def _mha_backward(sm_scale: float, causal: bool, block_sizes: BlockSizes, block_kv_dkv=block_kv_dkv, block_q_dq=block_q_dq, block_kv_dq=block_kv_dq, - block_d=head_dim, + head_dim=head_dim, ), out_shape=out_shapes, in_specs=in_specs, grid=grid, out_specs=[ pl.BlockSpec( - (None, block_q_dq, None, head_dim), + (None, block_q_dq, None, head_dim_padded), lambda i, j, k: (i, k, j, 0), # dq ), pl.BlockSpec( - (None, block_kv_dkv, None, head_dim), + (None, block_kv_dkv, None, head_dim_padded), lambda i, j, k: (i, k, j, 0), # dk ), pl.BlockSpec( - (None, block_kv_dkv, None, head_dim), + (None, block_kv_dkv, None, head_dim_padded), lambda i, j, k: (i, k, j, 0), # dv ), ], name="mha_backward", debug=debug, interpret=interpret, - compiler_params=plgpu.TritonCompilerParams( + compiler_params=plgpu.CompilerParams( num_warps=num_warps_, num_stages=2 ), )(q, k, v, segment_ids, out, do, lse, delta) diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py index 8883878f5f0e..dd9bf85c91b2 100644 --- a/jax/experimental/pallas/ops/gpu/attention_mgpu.py +++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py @@ -20,12 +20,13 @@ import jax from jax import lax from jax._src import test_util as jtu # noqa: F401 +from jax._src.lib import cuda_versions # noqa: F401 from jax.experimental.mosaic.gpu import profiler import jax.experimental.pallas as pl import jax.experimental.pallas.mosaic_gpu as plgpu import jax.numpy as jnp import numpy as np - +from functools import partial @dataclasses.dataclass(frozen=True) class TuningConfig: @@ -33,6 +34,13 @@ class TuningConfig: block_kv: int max_concurrent_steps: int use_schedule_barrier: bool = True + causal: bool = False + compute_wgs_bwd: int = 1 + + block_q_dkv: int | None = None + block_kv_dkv: int | None = None + block_q_dq: int | None = None + block_kv_dq: int | None = None def __post_init__(self): if self.block_q % 64: @@ -42,9 +50,27 @@ def __post_init__(self): if self.max_concurrent_steps < 2: raise ValueError(f"{self.max_concurrent_steps=} must be at least 2") + backward_blocks = [self.block_q_dkv, self.block_kv_dkv, self.block_q_dq, self.block_kv_dq] + block_is_set = [blk is not None for blk in backward_blocks] + if any(block_is_set) and not all(block_is_set): + raise ValueError( + "Backward block sizes (block_q_dkv, block_kv_dkv, block_q_dq, " + "block_kv_dq) must either all be specified or all be None." + ) -@functools.partial(jax.jit, static_argnames=["config"]) -def attention(q, k, v, config: TuningConfig): + @property + def has_backward_blocks(self) -> bool: + return self.block_q_dkv is not None + +def _attention_forward(q, k, v, config: TuningConfig, save_residuals: bool = False): + assert cuda_versions is not None + cuda_runtime_version = cuda_versions.cuda_runtime_get_version() + # TODO(pobudzey): Undo when we upgrade to cuda 12.9.1. + if config.causal and cuda_runtime_version >= 12080 and cuda_runtime_version < 12091: + raise ValueError( + "Causal masking not supported with cuda versions between 12.8.0 and" + " 12.9.1 due to a ptxas miscompilation." + ) if q.ndim != 4 or k.ndim != 4 or v.ndim != 4: raise ValueError(f"q, k, and v should all be 4D, got: {q.ndim=}, {k.ndim=}, {v.ndim=}") batch_size, q_seq_len, num_q_heads, head_dim = q.shape @@ -68,25 +94,39 @@ def attention(q, k, v, config: TuningConfig): config.max_concurrent_steps, kv_seq_len // config.block_kv ) block_q, block_kv = config.block_q, config.block_kv + if kv_seq_len % block_kv: + raise ValueError(f"{kv_seq_len=} must be a multiple of {block_kv=}") - def kernel(q_ref, k_ref, v_ref, out_ref, scoped): + def kernel(q_ref, k_ref, v_ref, out_ref, lse_ref, scoped): batch = lax.axis_index("batch") q_head = lax.axis_index("heads") smem_buffers, buffer_barriers, consumed_barriers, schedule_barrier = scoped wg_idx = lax.axis_index("wg") - qo_smem2, k_smem, v_smem = smem_buffers + qo_smem2, k_smem, v_smem, lse_smem2 = smem_buffers k_barriers, v_barriers, q_barriers = buffer_barriers k_consumed_barriers, v_consumed_barriers = consumed_barriers def perform_schedule_barrier(): plgpu.barrier_arrive(schedule_barrier) plgpu.barrier_wait(schedule_barrier) + if config.causal: + block_q_end = (lax.axis_index("q_seq") + 1) * (2 * block_q) + block_max_kv_steps = pl.cdiv(block_q_end, jnp.array(block_kv, jnp.int32)) + else: + block_max_kv_steps = kv_seq_len // block_kv + @pl.when(wg_idx < 2) def _compute_wg(): plgpu.set_max_registers(232, action="increase") qo_smem = qo_smem2.at[wg_idx] + lse_smem = lse_smem2.at[wg_idx] if lse_smem2 is not None else None q_seq_base = lax.axis_index("q_seq") * (2 * block_q) + wg_idx * block_q + if config.causal: + kv_steps = pl.cdiv(q_seq_base + block_q, jnp.array(block_kv, jnp.int32)) + else: + kv_steps = block_max_kv_steps + plgpu.copy_gmem_to_smem( q_ref.at[batch, pl.ds(q_seq_base, block_q), q_head], qo_smem, @@ -104,12 +144,14 @@ def _compute_wg(): jnp.full((block_q, head_dim), 0, dtype=jnp.float32), plgpu.Layout.WGMMA, ) - plgpu.barrier_wait(k_barriers.at[0]) + @pl.when(kv_steps > 0) + def _(): + plgpu.barrier_wait(k_barriers.at[0]) pl.when(wg_idx == 1)(perform_schedule_barrier) - def kv_loop(kv_step, carry): + def kv_loop(kv_step, carry, causal: bool = False): acc, m_i, l_i = carry - slot = lax.rem(kv_step, max_concurrent_steps) + slot = lax.rem(kv_step, jnp.array(max_concurrent_steps, kv_step.dtype)) # QK def compute_qk(acc_ref): @@ -119,6 +161,12 @@ def compute_qk(acc_ref): qk = pl.run_scoped(compute_qk, plgpu.ACC((block_q, block_kv), jnp.float32)) plgpu.barrier_arrive(k_consumed_barriers.at[slot]) + if causal: + q_ids = plgpu.broadcasted_iota(jnp.int32, (block_q, block_kv), 0, layout=plgpu.Layout.WGMMA) + kv_ids = plgpu.broadcasted_iota(jnp.int32, (block_q, block_kv), 1, layout=plgpu.Layout.WGMMA) + mask = (q_ids + q_seq_base) >= (kv_ids + kv_step * block_kv) + qk = jnp.where(mask, qk, -jnp.inf) + # Softmax # We keep m scaled by log2e to use FMA instructions when computing p. log2e = math.log2(math.e) @@ -149,28 +197,53 @@ def compute_pv(acc_ref): plgpu.wgmma(acc_ref, p16, v_smem.at[slot]) wait_step = kv_step + 1 - wait_slot = lax.rem(wait_step, max_concurrent_steps) - @pl.when(wait_step < kv_seq_len // block_kv) + wait_slot = lax.rem(wait_step, jnp.array(max_concurrent_steps, kv_step.dtype)) + @pl.when(wait_step < kv_steps) def _wait(): plgpu.barrier_wait(k_barriers.at[wait_slot]) acc = pl.run_state(compute_pv)(plgpu.ACC.init(acc)) plgpu.barrier_arrive(v_consumed_barriers.at[slot]) return acc, m_i, l_i - if kv_seq_len % block_kv: - raise ValueError(f"{kv_seq_len=} must be a multiple of {block_kv=}") - acc, m_i, l_i = lax.fori_loop( - 0, kv_seq_len // block_kv, kv_loop, (acc, m_i, l_i) - ) + + if not config.causal: + acc, m_i, l_i = lax.fori_loop(0, block_max_kv_steps, kv_loop, (acc, m_i, l_i)) + else: + def epilogue_kv_loop(kv_step, _): + # This loop makes sure that all the pipelined KV data is processed, even + # if one compute wg finishes early like with causal masking. + slot = lax.rem(kv_step, jnp.array(max_concurrent_steps, kv_step.dtype)) + plgpu.barrier_arrive(k_consumed_barriers.at[slot]) + plgpu.barrier_arrive(v_consumed_barriers.at[slot]) + perform_schedule_barrier() + perform_schedule_barrier() + + causal_kv_loop = functools.partial(kv_loop, causal=True) + full_kv_steps = lax.div(q_seq_base, jnp.array(block_kv, jnp.int32)) + # With causal masking, the KV loop unrolling is split in 3 sections: + # 1. A fast path where no causal mask is needed. + acc, m_i, l_i = lax.fori_loop(0, full_kv_steps, kv_loop, (acc, m_i, l_i)) + # 2. Causal masking. + acc, m_i, l_i = lax.fori_loop(full_kv_steps, kv_steps, causal_kv_loop, (acc, m_i, l_i)) + # 3. Epilogue to flush the data pipeline. + lax.fori_loop(kv_steps, block_max_kv_steps, epilogue_kv_loop, None) pl.when(wg_idx == 0)(perform_schedule_barrier) - del m_i # Not needed anymore # TODO(apaszke): Invert and multiply to avoid expensive divisions. acc /= lax.broadcast_in_dim(l_i, (block_q, head_dim), [0]) qo_smem[...] = acc.astype(dtype) + if lse_smem is not None: + RCP_LN2 = 1.4426950408889634 + log2 = lambda x: jnp.log(x) * RCP_LN2 + lse_smem[...] = m_i + log2(l_i) plgpu.commit_smem() plgpu.copy_smem_to_gmem( qo_smem, out_ref.at[batch, pl.ds(q_seq_base, block_q), q_head], ) + if lse_smem is not None: + plgpu.copy_smem_to_gmem( + lse_smem, + lse_ref.at[batch, q_head, pl.ds(q_seq_base, block_q)], + ) plgpu.wait_smem_to_gmem(0) @pl.when(wg_idx == 2) def _memory_wg(): @@ -181,19 +254,19 @@ def _memory_wg(): plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[i], k_barriers.at[i]) plgpu.copy_gmem_to_smem(v_ref.at[s], v_smem.at[i], v_barriers.at[i]) - def kv_loop(kv_step, _): + @pl.loop(0, block_max_kv_steps - max_concurrent_steps) + def _kv_loop(kv_step): tma_step = kv_step + max_concurrent_steps - tma_slot = lax.rem(kv_step, max_concurrent_steps) + tma_slot = lax.rem(kv_step, jnp.array(max_concurrent_steps, kv_step.dtype)) s = (batch, pl.ds(tma_step * block_kv, block_kv), kv_head) plgpu.barrier_wait(k_consumed_barriers.at[tma_slot]) plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[tma_slot], k_barriers.at[tma_slot]) plgpu.barrier_wait(v_consumed_barriers.at[tma_slot]) plgpu.copy_gmem_to_smem(v_ref.at[s], v_smem.at[tma_slot], v_barriers.at[tma_slot]) - lax.fori_loop(0, kv_seq_len // block_kv - max_concurrent_steps, kv_loop, None) - def entry(q_ref, k_ref, v_ref, out_ref): + def entry(q_ref, k_ref, v_ref, out_ref, lse_ref): compute_wgs = 2 - tiling = plgpu.TilingTransform((64, 64)) + tiling = plgpu.TilingTransform((8, 64)) swizzle = plgpu.SwizzleTransform(128) qo_scratch = plgpu.SMEM( (compute_wgs, block_q, head_dim), jnp.float16, @@ -201,39 +274,371 @@ def entry(q_ref, k_ref, v_ref, out_ref): ) k_scratch = plgpu.SMEM( (max_concurrent_steps, block_kv, head_dim), jnp.float16, - transforms=(tiling, plgpu.TransposeTransform((0, 2, 1, 3, 4)), swizzle), + transforms=(tiling, swizzle), ) v_scratch = plgpu.SMEM( (max_concurrent_steps, block_kv, head_dim), jnp.float16, transforms=(tiling, swizzle), ) + scratch = [qo_scratch, k_scratch, v_scratch, None] + if save_residuals: + scratch[3] = plgpu.SMEM((compute_wgs, block_q), jnp.float32) pl.run_scoped( - lambda *args: kernel(q_ref, k_ref, v_ref, out_ref, args), - (qo_scratch, k_scratch, v_scratch), + lambda *args: kernel(q_ref, k_ref, v_ref, out_ref, lse_ref, args), + scratch, ( - plgpu.Barrier(1, num_barriers=max_concurrent_steps), - plgpu.Barrier(1, num_barriers=max_concurrent_steps), - plgpu.Barrier(1, num_barriers=compute_wgs), + plgpu.Barrier(num_barriers=max_concurrent_steps), + plgpu.Barrier(num_barriers=max_concurrent_steps), + plgpu.Barrier(num_barriers=compute_wgs), ), (plgpu.Barrier(num_arrivals=compute_wgs, num_barriers=max_concurrent_steps),) * 2, plgpu.Barrier(num_arrivals=compute_wgs), + collective_axes="wg", ) num_q_tiles, rem = divmod(q_seq_len, block_q * 2) if rem: raise NotImplementedError(f"{q_seq_len=} must be a multiple of {block_q * 2=}") - return plgpu.kernel( + out_shape = [q, None] + if save_residuals: + # Note that we keep seq_len in the minor-most dimension so that we can do + # 1D TMAs on chunks of `block_q`. + out_shape[1] = jax.ShapeDtypeStruct( + (batch_size, num_q_heads, q_seq_len), jnp.float32 + ) + + out, lse = plgpu.kernel( entry, - out_shape=q, - grid=(batch_size, num_q_tiles, num_q_heads), + out_shape=out_shape, + grid=(num_q_heads, num_q_tiles, batch_size), + grid_names=("heads", "q_seq", "batch"), num_threads=3, - axis_names=("batch", "q_seq", "heads", "wg"), - compiler_params=plgpu.GPUCompilerParams(approx_math=True), + thread_name="wg", + compiler_params=plgpu.CompilerParams(approx_math=True), )(q, k, v) -@functools.partial(jax.jit, static_argnames=["config"]) -def attention_with_pipeline_emitter(q, k, v, config: TuningConfig): + if save_residuals: + assert lse is not None + return out, (lse,) + + return out + +@partial(jax.custom_vjp, nondiff_argnums=(3, 4)) +@partial(jax.jit, static_argnames=["config", "save_residuals"]) +def attention(q, k, v, config: TuningConfig, save_residuals: bool = False): + return _attention_forward(q, k, v, config, save_residuals) + +def _attention_fwd(q, k, v, config: TuningConfig, save_residuals: bool): + del save_residuals + + out, (lse,) = _attention_forward(q, k, v, config, save_residuals=True) + return out, (q, k, v, out, lse) + +def _attention_bwd(config: TuningConfig, save_residuals: bool, res, do): + del save_residuals + q, k, v, out, lse = res + + if config.causal: + raise NotImplementedError("Causal attention not supported in the backwards pass yet.") + + if not config.has_backward_blocks: + raise ValueError("Need to specify backward blocks.") + + assert config.block_q_dq is not None + assert config.block_kv_dq is not None + assert config.block_q_dkv is not None + assert config.block_kv_dkv is not None + + batch_size, q_seq_len, num_q_heads, head_dim = q.shape + _, kv_seq_len, num_kv_heads, _ = k.shape + q_heads_per_kv_head = num_q_heads // num_kv_heads + dtype = q.dtype + compute_wgs = config.compute_wgs_bwd + + num_q_tiles, rem = divmod(q_seq_len, config.block_q_dq * compute_wgs) + if rem: + raise NotImplementedError( + f"{q_seq_len=} must be a multiple of {config.block_q_dq=} * {compute_wgs=}") + + num_kv_tiles, rem = divmod(kv_seq_len, config.block_kv_dkv * compute_wgs) + if rem: + raise NotImplementedError( + f"{kv_seq_len=} must be a multiple of {config.block_kv_dkv=} * {compute_wgs=}") + + num_q_tiles_in_dkv, rem = divmod(q_seq_len, config.block_q_dkv) + if rem: + raise NotImplementedError(f"{q_seq_len=} must be a multiple of {config.block_q_dkv=}") + + num_kv_tiles_in_dq, rem = divmod(kv_seq_len, config.block_kv_dq) + if rem: + raise NotImplementedError(f"{kv_seq_len=} must be a multiple of {config.block_kv_dq=}") + + tiling = plgpu.TilingTransform((8, 64)) + swizzle = plgpu.SwizzleTransform(128) + + delta = jnp.einsum('bqhd,bqhd->bhq', out.astype(jnp.float32), do.astype(jnp.float32)) + del out # Not needed anymore. + + def kernel_dq(q_ref, k_ref, v_ref, do_ref, lse_ref, delta_ref, dq_ref, + smem_buffers, buffer_barriers, block_q, block_kv): + batch = lax.axis_index("batch") + q_head = lax.axis_index("heads") + wg_idx = lax.axis_index("wg") + kv_head = lax.div(q_head, jnp.array(q_heads_per_kv_head, q_head.dtype)) + q_smem2, do_smem2, lse_smem2, delta_smem2 = smem_buffers + q_barriers, do_barriers, lse_barriers, delta_barriers = buffer_barriers + def _compute_thread(pipeline_callback): + q_smem, do_smem, lse_smem, delta_smem = q_smem2.at[wg_idx], do_smem2.at[wg_idx], lse_smem2.at[wg_idx], delta_smem2.at[wg_idx] + q_seq_base = lax.axis_index("q_seq") * (compute_wgs * block_q) + wg_idx * block_q + q_slice = (batch, pl.ds(q_seq_base, block_q), q_head) + plgpu.copy_gmem_to_smem(q_ref.at[q_slice], q_smem, q_barriers.at[wg_idx]) + plgpu.copy_gmem_to_smem(do_ref.at[q_slice], do_smem, do_barriers.at[wg_idx]) + plgpu.copy_gmem_to_smem( + delta_ref.at[batch, q_head, pl.ds(q_seq_base, block_q)], + delta_smem, + delta_barriers.at[wg_idx], + ) + plgpu.copy_gmem_to_smem( + lse_ref.at[batch, q_head, pl.ds(q_seq_base, block_q)], + lse_smem, + lse_barriers.at[wg_idx], + ) + for buffer in buffer_barriers: + plgpu.barrier_wait(buffer.at[wg_idx]) + + delta = plgpu.load(delta_smem, (), layout=plgpu.Layout.WGMMA_ROW) + lse = plgpu.load(lse_smem, (), layout=plgpu.Layout.WGMMA_ROW) + dq_acc = plgpu.layout_cast( + jnp.full((block_q, head_dim), 0, dtype=jnp.float32), plgpu.Layout.WGMMA, + ) + dq, _, _ = pipeline_callback((dq_acc, lse, delta)) + q_smem[...] = dq.astype(dtype) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(q_smem, dq_ref.at[q_slice]) + plgpu.wait_smem_to_gmem(0) + + def kv_pipeline(_, k_smem, v_smem, k_consumed_barrier, v_consumed_barrier, carry): + q_smem, do_smem = q_smem2.at[wg_idx], do_smem2.at[wg_idx] + (dq_acc, lse, delta) = carry + + def compute_s(acc_ref): + plgpu.wgmma(acc_ref, q_smem, plgpu.transpose_ref(k_smem, (1, 0))) + return acc_ref[...] + + s = pl.run_scoped(compute_s, plgpu.ACC((block_q, block_kv), jnp.float32)) + s *= math.log2(math.e) + p = jnp.exp2(s - lax.broadcast_in_dim(lse, (block_q, block_kv), [0])) + + # dP + def compute_dp(acc_ref): + plgpu.wgmma(acc_ref, do_smem, plgpu.transpose_ref(v_smem, (1, 0))) + return acc_ref[...] + + dp = pl.run_scoped(compute_dp, plgpu.ACC((block_q, block_kv), jnp.float32)) + plgpu.barrier_arrive(v_consumed_barrier) + + # dS + ds = p * (dp - lax.broadcast_in_dim(delta, (block_q, block_kv), [0])) + + # dQ + def compute_dq(acc_ref): + plgpu.wgmma(acc_ref, ds.astype(k_ref.dtype), k_smem) + + dq_acc = pl.run_state(compute_dq)(plgpu.ACC.init(dq_acc)) + plgpu.barrier_arrive(k_consumed_barrier) + + return (dq_acc, lse, delta) + + pipeline = plgpu.emit_pipeline_warp_specialized( + kv_pipeline, + grid=(num_kv_tiles_in_dq,), + max_concurrent_steps=min([config.max_concurrent_steps, num_q_tiles]), + num_compute_wgs=compute_wgs, + memory_registers=40, + wg_axis="wg", + manual_consumed_barriers=True, + compute_context=_compute_thread, + in_specs=[ + plgpu.BlockSpec( # k + block_shape=(block_kv, head_dim), + index_map=lambda i: (i, 0), + transforms=[tiling, swizzle]), + plgpu.BlockSpec( # v + block_shape=(block_kv, head_dim), + index_map=lambda i: (i, 0), + transforms=[tiling, swizzle]), + ]) + k_ref = k_ref.at[batch, :, kv_head, :] + v_ref = v_ref.at[batch, :, kv_head, :] + pipeline(k_ref, v_ref) + + def kernel_dkv(q_ref, k_ref, v_ref, do_ref, lse_ref, delta_ref, + dk_ref, dv_ref, smem_buffers, buffer_barriers, block_q: int, block_kv: int): + batch = lax.axis_index("batch") + q_head = lax.axis_index("heads") + wg_idx = lax.axis_index("wg") + (k_smem2, v_smem2) = smem_buffers + (k_barriers, v_barriers) = buffer_barriers + + def _compute_thread(pipeline_callback): + k_smem, v_smem = k_smem2.at[wg_idx], v_smem2.at[wg_idx] + kv_seq_base = lax.axis_index("kv_seq") * (compute_wgs * block_kv) + wg_idx * block_kv + kv_head = lax.div(q_head, jnp.array(q_heads_per_kv_head, q_head.dtype)) + plgpu.copy_gmem_to_smem( + k_ref.at[(batch, pl.ds(kv_seq_base, block_kv), kv_head)], + k_smem, + k_barriers.at[wg_idx]) + plgpu.copy_gmem_to_smem( + v_ref.at[(batch, pl.ds(kv_seq_base, block_kv), kv_head)], + v_smem, + v_barriers.at[wg_idx]) + plgpu.barrier_wait(k_barriers.at[wg_idx]) + plgpu.barrier_wait(v_barriers.at[wg_idx]) + dk_acc = plgpu.layout_cast( + jnp.full((block_kv, head_dim), 0, dtype=jnp.float32), plgpu.Layout.WGMMA, + ) + dv_acc = plgpu.layout_cast( + jnp.full((block_kv, head_dim), 0, dtype=jnp.float32), plgpu.Layout.WGMMA, + ) + (dk, dv) = pipeline_callback((dv_acc, dk_acc)) + k_smem[...] = dk.astype(dtype) + v_smem[...] = dv.astype(dtype) + + plgpu.commit_smem() + plgpu.copy_smem_to_gmem( + k_smem, + dk_ref.at[(batch, pl.ds(kv_seq_base, block_kv), q_head)], + commit_group=False) + plgpu.copy_smem_to_gmem( + v_smem, + dv_ref.at[(batch, pl.ds(kv_seq_base, block_kv), q_head)], + commit_group=False) + plgpu.commit_smem_to_gmem_group() + plgpu.wait_smem_to_gmem(0) + + def q_pipeline(_, q_smem, do_smem, lse_smem, delta_smem, q_consumed_barrier, do_consumed_barrier, lse_consumed_barrier, delta_consumed_barrier, carry): + k_smem, v_smem = k_smem2.at[wg_idx], v_smem2.at[wg_idx] + dk_acc, dv_acc = carry + + def _compute_sT(acc_ref): + plgpu.wgmma(acc_ref, k_smem, plgpu.transpose_ref(q_smem, (1, 0))) + return acc_ref[...] + sT = pl.run_scoped(_compute_sT, plgpu.ACC((block_kv, block_q), jnp.float32)) + sT *= math.log2(math.e) + + lse = plgpu.load(lse_smem, (), layout=plgpu.Layout.WGMMA_COL) + plgpu.barrier_arrive(lse_consumed_barrier) + pT = jnp.exp2(sT - lax.broadcast_in_dim(lse, (block_kv, block_q), [1])) + + def _compute(refs): + # Combining two WGMMA calls in one block to avoid the unnecessary + # synchronization from two `wgmma.wait_group` calls. + dv_acc_ref, dpT_acc_ref = refs + plgpu.wgmma(dv_acc_ref, pT.astype(dtype), do_smem) # dV + plgpu.wgmma(dpT_acc_ref, v_smem, plgpu.transpose_ref(do_smem, (1, 0))) # dpT + + zeros = plgpu.layout_cast( + jnp.full((block_kv, block_q), 0, dtype=jnp.float32), plgpu.Layout.WGMMA, + ) + dv_acc, dpT = pl.run_state(_compute)((plgpu.ACC.init(dv_acc), plgpu.ACC.init(zeros))) + plgpu.barrier_arrive(do_consumed_barrier) + + delta = plgpu.load(delta_smem, (), layout=plgpu.Layout.WGMMA_COL) + plgpu.barrier_arrive(delta_consumed_barrier) + + dsT = pT * (dpT - lax.broadcast_in_dim(delta, (block_kv, block_q), [1])) # pytype: disable=wrong-arg-types # jax-operator-types + + def compute_dk(acc_ref): + plgpu.wgmma(acc_ref, dsT.astype(dtype), q_smem) + + dk_acc = pl.run_state(compute_dk)(plgpu.ACC.init(dk_acc)) + plgpu.barrier_arrive(q_consumed_barrier) + + return (dk_acc, dv_acc) + + pipeline = plgpu.emit_pipeline_warp_specialized( + q_pipeline, + grid=(num_q_tiles_in_dkv,), + max_concurrent_steps=min([config.max_concurrent_steps, num_kv_tiles]), + num_compute_wgs=compute_wgs, + memory_registers=40, + wg_axis="wg", + manual_consumed_barriers=True, + compute_context=_compute_thread, + in_specs=[ + plgpu.BlockSpec( # q + block_shape=(block_q, head_dim), + index_map=lambda i: (i, 0), + transforms=[tiling, swizzle]), + plgpu.BlockSpec( # do + block_shape=(block_q, head_dim), + index_map=lambda i: (i, 0), + transforms=[tiling, swizzle]), + plgpu.BlockSpec(block_shape=(block_q,), index_map=lambda i: (i,)), + plgpu.BlockSpec(block_shape=(block_q,), index_map=lambda i: (i,)) + ]) + q_ref = q_ref.at[batch, :, q_head, :] + do_ref = do_ref.at[batch, :, q_head, :] + lse_ref = lse_ref.at[batch, q_head, :] + delta_ref = delta_ref.at[batch, q_head, :] + pipeline(q_ref, do_ref, lse_ref, delta_ref) + + q_scratch = plgpu.SMEM( + (compute_wgs, config.block_q_dq, head_dim), jnp.float16, + transforms=(tiling, swizzle), + ) + do_scratch = q_scratch + lse_scratch = plgpu.SMEM((compute_wgs, config.block_q_dq), jnp.float32) + delta_scratch = plgpu.SMEM((compute_wgs, config.block_q_dq), jnp.float32) + dq = plgpu.kernel( + partial(kernel_dq, block_q=config.block_q_dq, block_kv=config.block_kv_dq), + out_shape=q, + scratch_shapes=[ + (q_scratch, do_scratch, lse_scratch, delta_scratch), # type: ignore + (plgpu.Barrier(num_barriers=compute_wgs),) * 4 # type: ignore + ], + compiler_params=plgpu.CompilerParams(approx_math=True), + grid=(num_q_heads, num_q_tiles, batch_size), + grid_names=("heads", "q_seq", "batch"), + num_threads=compute_wgs + 1, + thread_name="wg", + )(q, k, v, do, lse, delta) + + k_scratch = plgpu.SMEM( + (compute_wgs, config.block_kv_dkv, head_dim), jnp.float16, + transforms=(tiling, swizzle), + ) + v_scratch = k_scratch + out_shape_kv = jax.ShapeDtypeStruct( + (batch_size, kv_seq_len, num_q_heads, head_dim), dtype=jnp.float16) + dk, dv = plgpu.kernel( + partial(kernel_dkv, block_q=config.block_q_dkv, block_kv=config.block_kv_dkv), + out_shape=[out_shape_kv, out_shape_kv], + scratch_shapes=[ + (k_scratch, v_scratch), # type: ignore + (plgpu.Barrier(num_barriers=compute_wgs),) * 2 # type: ignore + ], + compiler_params=plgpu.CompilerParams(approx_math=True), + grid=(num_q_heads, num_kv_tiles, batch_size), + grid_names=("heads", "kv_seq", "batch"), + num_threads=compute_wgs + 1, + thread_name="wg" + )(q, k, v, do, lse, delta) + + if q_heads_per_kv_head > 1: + sum_shape = (*k.shape[:-1], q_heads_per_kv_head, head_dim) + dk = dk.reshape(sum_shape).astype(jnp.float32).sum(axis=-2).astype(dk.dtype) + dv = dv.reshape(sum_shape).astype(jnp.float32).sum(axis=-2).astype(dv.dtype) + + return dq, dk, dv + +attention.defvjp(_attention_fwd, _attention_bwd) + +@functools.partial(jax.jit, static_argnames=["config", "save_residuals"]) +def attention_with_pipeline_emitter(q, k, v, config: TuningConfig, save_residuals=False): + if config.causal: + raise NotImplementedError("Causal attention is not supported with the pipeline emitter yet.") if q.ndim != 4 or k.ndim != 4 or v.ndim != 4: raise ValueError(f"q, k, and v should all be 4D, got: {q.ndim=}, {k.ndim=}, {v.ndim=}") batch_size, q_seq_len, num_q_heads, head_dim = q.shape @@ -262,14 +667,10 @@ def attention_with_pipeline_emitter(q, k, v, config: TuningConfig): if rem: raise NotImplementedError(f"{q_seq_len=} must be a multiple of {block_q * 2=}") - tiling = plgpu.TilingTransform((64, 64)) - swizzle = plgpu.SwizzleTransform(128) - transpose = plgpu.TransposeTransform((0, 2, 1, 3, 4)) - - def fa3_kernel(q_ref, k_ref, v_ref, out_ref, scoped): + def fa3_kernel(q_ref, k_ref, v_ref, out_ref, lse_ref, smem_buffers, q_barriers, schedule_barrier): batch = lax.axis_index("batch") wg_idx = lax.axis_index("wg") - qo_smem2, q_barriers, schedule_barrier = scoped + qo_smem2, lse_smem2 = smem_buffers q_seq_base = lax.axis_index("q_seq") * (2 * block_q) + wg_idx * block_q q_head = lax.axis_index("heads") kv_head = lax.div(q_head, jnp.array(q_heads_per_kv_head, q_head.dtype)) @@ -279,17 +680,12 @@ def perform_schedule_barrier(): plgpu.barrier_arrive(schedule_barrier) plgpu.barrier_wait(schedule_barrier) - def _compute_thread(): + def _compute_thread(pipeline_callback): qo_smem = qo_smem2.at[wg_idx] - m_i = plgpu.layout_cast( - jnp.full((block_q,), -jnp.inf, dtype=jnp.float32), plgpu.Layout.WGMMA_ROW, - ) - l_i = plgpu.layout_cast( - jnp.full((block_q,), 0, dtype=jnp.float32), plgpu.Layout.WGMMA_ROW, - ) - acc = plgpu.layout_cast( - jnp.full((block_q, head_dim), 0, dtype=jnp.float32), plgpu.Layout.WGMMA, - ) + lse_smem = lse_smem2.at[wg_idx] if lse_smem2 is not None else None + m_i = jnp.full((block_q,), -jnp.inf, dtype=jnp.float32) + l_i = jnp.full((block_q,), 0, dtype=jnp.float32) + acc = jnp.full((block_q, head_dim), 0, dtype=jnp.float32) # Q is not pipelined, so we load in with a manual DMA. plgpu.copy_gmem_to_smem( q_ref.at[batch, pl.ds(q_seq_base, block_q), q_head], @@ -298,19 +694,27 @@ def _compute_thread(): ) plgpu.barrier_wait(q_barriers.at[wg_idx]) pl.when(wg_idx == 1)(perform_schedule_barrier) - final_carry = (yield (acc, m_i, l_i)) - del m_i # Unused + final_carry = pipeline_callback((acc, m_i, l_i)) pl.when(wg_idx == 0)(perform_schedule_barrier) - acc, _, l_i = final_carry + acc, m_i, l_i = final_carry acc /= lax.broadcast_in_dim(l_i, (block_q, head_dim), [0]) qo_smem[...] = acc.astype(dtype) + if lse_smem is not None: + RCP_LN2 = 1.4426950408889634 + log2 = lambda x: jnp.log(x) * RCP_LN2 + lse_smem[...] = m_i + log2(l_i) plgpu.commit_smem() plgpu.copy_smem_to_gmem( qo_smem, out_ref.at[batch, pl.ds(q_seq_base, block_q), q_head], ) + if lse_smem is not None: + plgpu.copy_smem_to_gmem( + lse_smem, + lse_ref.at[batch, q_head, pl.ds(q_seq_base, block_q)], + ) plgpu.wait_smem_to_gmem(0) - def kv_pipeline(k_smem, v_smem, + def kv_pipeline(_, k_smem, v_smem, k_consumed_barrier, v_consumed_barrier, carry): acc, m_i, l_i = carry @@ -348,66 +752,82 @@ def compute_pv(acc_ref): memory_registers=40, wg_axis="wg", manual_consumed_barriers=True, - carry_coroutine=_compute_thread, + compute_context=_compute_thread, in_specs=[ - plgpu.GPUBlockSpec( # k + plgpu.BlockSpec( # k block_shape=(block_kv, head_dim), - index_map=lambda i: (i, 0), - transforms=[tiling, transpose, swizzle]), - plgpu.GPUBlockSpec( # v + index_map=lambda i: (i, 0)), + plgpu.BlockSpec( # v block_shape=(block_kv, head_dim), - index_map=lambda i: (i, 0), - transforms=[tiling, swizzle]), + index_map=lambda i: (i, 0)), ], out_specs=[], ) k_ref = k_ref.at[batch, :, kv_head, :] v_ref = v_ref.at[batch, :, kv_head, :] pipeline(k_ref, v_ref) - mesh = plgpu.GPUMesh( - grid=(batch_size, num_q_tiles, num_q_heads), + + out_shape = [q, None] + if save_residuals: + out_shape[1] = jax.ShapeDtypeStruct((batch_size, num_q_heads, q_seq_len), jnp.float32) + + qo_scratch = plgpu.SMEM((compute_wgs, block_q, head_dim), jnp.float16) + smem_scratch = [qo_scratch, None] + if save_residuals: + smem_scratch[1] = plgpu.SMEM((compute_wgs, block_q), jnp.float32) + + out, lse = plgpu.kernel( + fa3_kernel, + grid=(num_q_heads, num_q_tiles, batch_size), + grid_names=("heads", "q_seq", "batch"), num_threads=3, - axis_names=("batch", "q_seq", "heads", "wg"), - ) - def run(refs): - q_ref, k_ref, v_ref, out_ref = refs - @pl.core_map(mesh, - compiler_params=plgpu.GPUCompilerParams(approx_math=True), - ) - def _kernel_entry(): - qo_scratch = plgpu.SMEM( - (compute_wgs, block_q, head_dim), jnp.float16, - transforms=(tiling, swizzle), - ) - pl.run_scoped( - lambda *args: fa3_kernel(q_ref, k_ref, v_ref, out_ref, args), - qo_scratch, - plgpu.Barrier(1, num_barriers=compute_wgs), - plgpu.Barrier(num_arrivals=compute_wgs), - ) - @jax.jit - def run_function(q, k, v, o): - _, _, _, out = pl.run_state(run)((q, k, v, o)) - return out - out = run_function(q, k, v, jnp.full_like(q, jnp.inf)) + thread_name="wg", + out_shape=out_shape, + scratch_shapes=( + tuple(smem_scratch), # type: ignore + plgpu.Barrier(num_barriers=compute_wgs), # type: ignore + plgpu.Barrier(num_arrivals=compute_wgs),), # type: ignore + compiler_params=plgpu.CompilerParams( + approx_math=True, lowering_semantics=plgpu.LoweringSemantics.Warpgroup, + ), + )(q, k, v) + + if save_residuals: + assert lse is not None + return out, (lse,) + return out -@jax.jit -def attention_reference(q, k, v): +@functools.partial(jax.jit, static_argnames=["causal", "save_residuals"]) +def attention_reference(q, k, v, causal=False, save_residuals=False): batch_size, q_seq_len, num_q_heads, head_dim = q.shape - num_kv_heads = k.shape[2] + kv_seq_len, num_kv_heads = k.shape[1], k.shape[2] q, k, v = map(lambda x: x.astype(jnp.float32), (q, k, v)) q_reshaped = q.reshape( batch_size, q_seq_len, num_kv_heads, num_q_heads // num_kv_heads, head_dim ) logits = jnp.einsum("bqHhc,bkHc->bqHhk", q_reshaped, k) + + if causal: + mask = jnp.arange(q_seq_len)[:, None] >= jnp.arange(kv_seq_len)[None, :] + mask = jnp.broadcast_to(mask[:, None, None, :], logits.shape) + logits = jnp.where(mask, logits, -jnp.inf) + m = logits.max(axis=-1, keepdims=True) unnormalized = jnp.exp(logits - m) l = unnormalized.sum(axis=-1, keepdims=True) weights = unnormalized / l - return jnp.einsum("bqHhk,bkHc->bqHhc", weights, v).reshape(*q.shape) - + out = jnp.einsum("bqHhk,bkHc->bqHhc", weights, v).reshape(*q.shape) + + if save_residuals: + log2e = math.log2(math.e) + l = l.reshape(*q.shape[:-1]) + m = m.reshape(*q.shape[:-1]) + lse = m * log2e + jnp.log2(l) + return out, (lse.swapaxes(-1, -2),) + else: + return out def main(unused_argv): num_q_heads = 16 @@ -421,11 +841,19 @@ def main(unused_argv): schedule_barrier_opts = (True,) problem_it = itertools.product( - (1,), (4096, 32768,), (64, 128, 256,), schedule_barrier_opts) - for batch_size, seq_len, head_dim, use_schedule_barrier in problem_it: + (1,), (4096, 32768,), (64, 128, 256,), schedule_barrier_opts, (False, True)) + for batch_size, seq_len, head_dim, use_schedule_barrier, causal in problem_it: + assert cuda_versions is not None + cuda_runtime_version = cuda_versions.cuda_runtime_get_version() + # TODO(pobudzey): Undo when we upgrade to cuda 12.9.1. + if causal and cuda_runtime_version >= 12080 and cuda_runtime_version < 12091: + continue + + if causal and use_pipeline_emitter: + continue q_seq_len = kv_seq_len = seq_len print(f"==== {batch_size=:<6} {kv_seq_len=:<6} {q_seq_len=:<6}" - f"{num_q_heads=:<4} {head_dim=:<6} {use_schedule_barrier=:} ====") + f"{num_q_heads=:<4} {head_dim=:<6} {use_schedule_barrier=:} {causal=:} ====") k1, k2, k3 = jax.random.split(jax.random.key(42), 3) q = jax.random.normal(k1, (batch_size, q_seq_len, num_q_heads, head_dim), jnp.float16) k = jax.random.normal(k2, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16) @@ -433,11 +861,11 @@ def main(unused_argv): block_q = 64 best = None for block_kv in (256, 128, 64): - config = TuningConfig(block_q=block_q, block_kv=block_kv, max_concurrent_steps=2, use_schedule_barrier=use_schedule_barrier) + config = TuningConfig(block_q=block_q, block_kv=block_kv, max_concurrent_steps=2, use_schedule_barrier=use_schedule_barrier, causal=causal) try: out, runtime_ms = profiler.measure(functools.partial(attention_impl, config=config))(q, k, v) if seq_len < 32768: - out_ref = attention_reference(q, k, v) + out_ref = attention_reference(q, k, v, causal=causal) np.testing.assert_allclose(out, out_ref, atol=2e-3, rtol=1e-3) except ValueError as e: if "exceeds available shared memory" in e.args[0]: @@ -447,6 +875,8 @@ def main(unused_argv): matmul_flops = ( 4 * q_seq_len * kv_seq_len * head_dim * num_q_heads * batch_size ) + if causal: + matmul_flops //= 2 peak_flops = 1e15 # f16 TensorCore peak = 1000TFLOPS optimal_time = matmul_flops / peak_flops * 1e6 # us achieved_tc_util = optimal_time / runtime_us * 100 diff --git a/jax/experimental/pallas/ops/gpu/blackwell_matmul_mgpu.py b/jax/experimental/pallas/ops/gpu/blackwell_matmul_mgpu.py new file mode 100644 index 000000000000..80dd2e95e2b0 --- /dev/null +++ b/jax/experimental/pallas/ops/gpu/blackwell_matmul_mgpu.py @@ -0,0 +1,339 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Matrix Multiplication kernel for Blackwell GPUs.""" +import dataclasses +import enum +import functools +import itertools +import statistics + +import jax +from jax import lax +from jax._src import test_util as jtu # noqa: F401 +from jax.experimental.mosaic.gpu import profiler +import jax.experimental.pallas as pl +import jax.experimental.pallas.mosaic_gpu as plgpu +import jax.numpy as jnp +import numpy as np + +class MatmulDimension(enum.IntEnum): + M = 0 + N = 1 + + +@dataclasses.dataclass(frozen=True) +class TuningConfig: + tile_m: int + tile_n: int + tile_k: int + max_concurrent_steps: int + collective: bool + epilogue_tile_n: int = 64 + grid_minor_dim: MatmulDimension = MatmulDimension.N + grid_tile_width: int = 1 + + +def matmul_kernel(a, b, config: TuningConfig): + dtype = a.dtype + if a.dtype != b.dtype: + raise ValueError( + f"Matmul LHS and RHS have incompatible dtypes {a.dtype} vs {b.dtype}" + ) + m, k = a.shape + k2, n = b.shape + if k != k2: + raise ValueError( + f"Matmul LHS and RHS have incompatible shapes {a.shape} vs {b.shape}" + ) + collective = config.collective + tile_m, tile_n, tile_k = (config.tile_m, config.tile_n, config.tile_k) + epilogue_tile_n = config.epilogue_tile_n + if tile_n % epilogue_tile_n != 0: + raise ValueError( + f"{tile_n=} must be divisible by {epilogue_tile_n=}" + ) + block_tile_m = tile_m + block_tile_n = tile_n + if collective: + tile_m *= 2 + tile_n *= 2 + swizzle = plgpu.find_swizzle(tile_k * jnp.dtype(dtype).itemsize * 8) + swizzle_elems = swizzle // jnp.dtype(dtype).itemsize + transforms = ( + plgpu.TilingTransform((8, swizzle_elems)), + plgpu.SwizzleTransform(swizzle), + ) + out_swizzle = plgpu.find_swizzle(epilogue_tile_n * jnp.dtype(dtype).itemsize * 8) + out_swizzle_elems = out_swizzle // jnp.dtype(dtype).itemsize + out_transforms = ( + plgpu.TilingTransform((8, out_swizzle_elems)), + plgpu.SwizzleTransform(out_swizzle), + ) + if m % tile_m != 0: + raise ValueError(f"{m=} must be divisible by {tile_m=}") + if n % tile_n != 0: + raise ValueError(f"{n=} must be divisible by {tile_n=}") + if k % tile_k != 0: + raise ValueError(f"{k=} must be divisible by {tile_k=}") + m_iters = m // tile_m + n_iters = n // tile_n + k_iters = k // tile_k + max_concurrent_steps = config.max_concurrent_steps + + TMA_WARP = 0 + MMA_WARP = 1 + COMPUTE_WG = 0 + STORE_WG = 1 + + def kernel(a_gmem, b_gmem, out_gmem, + a_smem, b_smem, acc_tmem, acc_smem, + ab_tma_barrier, store_done_barrier, mma_done_barrier, + consumed_barrier): + wg_idx = lax.axis_index("wg") + cluster_idx = lax.axis_index("x") + is_lead_block = cluster_idx == 0 + + @plgpu.dynamic_scheduling_loop(grid_names=("mn_linear",), thread_axis="wg") + def mn_loop(loop_info: plgpu.NDLoopInfo): # pylint: disable=unused-variable + (lin_idx,) = loop_info.index + local_index = loop_info.local_index + m_index, n_index = plgpu.planar_snake( + lin_idx, + (m_iters, n_iters), + config.grid_minor_dim, + config.grid_tile_width, + ) + block_m_index = m_index * 2 + cluster_idx if collective else m_index + + block_slice_m = pl.ds(block_m_index * block_tile_m, block_tile_m) + slice_m = pl.ds(m_index * tile_m, tile_m) + slice_n = pl.ds(n_index * tile_n, tile_n) + acc_slot = lax.rem(local_index, jnp.int32(2)) + + @pl.when(wg_idx == COMPUTE_WG) + def _(): + @pl.core_map(plgpu.WarpMesh(axis_name="warp")) + def _per_warp(): + warp_id = lax.axis_index("warp") + @pl.when(warp_id == TMA_WARP) + def _memory(): + def _loop_body(ki, _): + slice_k = pl.ds(ki * tile_k, tile_k) + slot = lax.rem(ki, max_concurrent_steps) + @pl.when(jnp.logical_or(ki >= max_concurrent_steps, + local_index > 0)) + def _(): + plgpu.barrier_wait(consumed_barrier.at[slot]) + plgpu.copy_gmem_to_smem( + a_gmem.at[slice_m, slice_k], + a_smem.at[slot], + ab_tma_barrier.at[slot], + partitioned_axis=0 if collective else None, + collective_axes="x" if collective else None, + ) + plgpu.copy_gmem_to_smem( + b_gmem.at[slice_k, slice_n], + b_smem.at[slot], + ab_tma_barrier.at[slot], + partitioned_axis=1 if collective else None, + collective_axes="x" if collective else None, + ) + lax.fori_loop(0, k_iters, _loop_body, None) + + @pl.when(jnp.logical_and(warp_id == MMA_WARP, local_index > 1)) + def _wait_store(): + plgpu.barrier_wait(store_done_barrier.at[acc_slot]) + @pl.when(jnp.logical_and(warp_id == MMA_WARP, is_lead_block)) + def _compute(): + def _loop_body(ki, _): + slot = lax.rem(ki, max_concurrent_steps) + plgpu.barrier_wait(ab_tma_barrier.at[slot]) + + is_last_iter = ki >= k_iters - 1 + acc_tmem_slice = acc_tmem.at[:, pl.ds(acc_slot * tile_n, tile_n)] + plgpu.tcgen05_mma( + acc_tmem_slice, + a_smem.at[slot], + b_smem.at[slot], + consumed_barrier.at[slot], + accumulate=(ki > 0), + collective_axis="x" if collective else None, + ) + @pl.when(is_last_iter) + def _(): + plgpu.tcgen05_commit_arrive( + mma_done_barrier.at[acc_slot], + collective_axis="x" if collective else None, + ) + + lax.fori_loop(0, k_iters, _loop_body, None) + + @pl.when(wg_idx == STORE_WG) + def _(): + plgpu.wait_smem_to_gmem(0, wait_read_only=True) + plgpu.barrier_wait(mma_done_barrier.at[acc_slot]) + acc_tmem_slot = acc_tmem.at[:, pl.ds(acc_slot * tile_n, tile_n)] + step_out_gmem = out_gmem.at[block_slice_m, slice_n] + for ni in range(tile_n // epilogue_tile_n): + acc_smem_ni = acc_smem.at[ni % 2] + ni_col_slice = pl.ds(ni * epilogue_tile_n, epilogue_tile_n) + acc_smem_ni[...] = plgpu.async_load_tmem( + acc_tmem_slot.at[:, ni_col_slice] + ).astype(dtype) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(acc_smem_ni, step_out_gmem.at[:, ni_col_slice]) + plgpu.wait_smem_to_gmem(1, wait_read_only=True) + plgpu.wait_load_tmem() # Load must complete before we continue. + plgpu.barrier_arrive(store_done_barrier.at[acc_slot]) + + if collective: + store_done_barrier = plgpu.ClusterBarrier( + collective_axes=("x",), + num_arrivals=1, + num_barriers=2, + orders_tensor_core=True, + ) + else: + store_done_barrier = plgpu.Barrier( # type: ignore + num_arrivals=1, num_barriers=2, orders_tensor_core=True + ) + f = plgpu.kernel( + kernel, + out_shape=jax.ShapeDtypeStruct((m, n), dtype), + grid=(m_iters * n_iters,), + grid_names=("mn_linear",), + num_threads=2, + thread_name="wg", + cluster_names=("x",), + cluster=(1 + collective,), + scratch_shapes=dict( + a_smem=plgpu.SMEM( + (max_concurrent_steps, block_tile_m, tile_k), + dtype, + transforms=transforms, + ), + b_smem=plgpu.SMEM( + (max_concurrent_steps, tile_k, block_tile_n), + dtype, + transforms=transforms, + ), + acc_tmem=plgpu.TMEM( + (block_tile_m, tile_n * 2), jnp.float32, collective=collective + ), + acc_smem=plgpu.SMEM( + (2, block_tile_m, epilogue_tile_n), + dtype, + transforms=out_transforms, + ), + ab_tma_barrier=plgpu.Barrier( + num_arrivals=2, num_barriers=max_concurrent_steps + ), + store_done_barrier=store_done_barrier, + mma_done_barrier=plgpu.Barrier( + num_arrivals=1, num_barriers=2, orders_tensor_core=True + ), + consumed_barrier=plgpu.Barrier( + num_arrivals=1, + num_barriers=max_concurrent_steps, + orders_tensor_core=True, + ), + ), + ) + return f(a, b) + + +def main(_) -> None: + problem_it = [(4096, 8192, 4096)] + for M, N, K in problem_it: + print(f"==== {M=} {N=} {K=} ====") + matmul_flops = 2 * M * N * K + peak_flops = 2.25e15 # f16 TensorCore peak = 2250 TFLOPS + a = jax.random.uniform(jax.random.key(1), (M, K), jnp.float16, -1, 1) + b = jax.random.uniform(jax.random.key(2), (K, N), jnp.float16, -1, 1) + tuning_it = itertools.product( + (128,), # tile_m + (128, 256), # tile_n + (64,), # tile_k + MatmulDimension, # grid_minor_dim + (1, 4, 8, 12, 16), # grid_tile_width + (2, 4, 6), # max_concurrent_steps + (False, True), # collective + (32,), # epilogue_tile_n + ) + best_util = -float("inf") + expected = jnp.dot(a, b, precision=jax.lax.DotAlgorithmPreset.F16_F16_F32) + for (tile_m, tile_n, tile_k, grid_minor_dim, grid_tile_width, + max_concurrent_steps, collective, epilogue_tile_n) in tuning_it: + # Only N <= 128 are supported for collective MMAs + if collective and tile_n > 128: + continue + config = TuningConfig( + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + max_concurrent_steps=max_concurrent_steps, + collective=collective, + epilogue_tile_n=epilogue_tile_n, + grid_minor_dim=grid_minor_dim, + grid_tile_width=grid_tile_width, + ) + if collective: + tile_m *= 2 + tile_n *= 2 + try: + out, runtimes_ms = profiler.measure( + functools.partial(matmul_kernel, config=config), iterations=10 + )(a, b) + assert runtimes_ms is not None + runtime_ms = statistics.median(runtimes_ms) + except ValueError as e: + if ("exceeds available shared memory" in e.args[0] or + "Accumulator layout mismatch:" in e.args[0]): + # Accumulator layout mismatch triggers for tile_n=256 on some configs. + continue + raise + runtime_us = runtime_ms * 1e3 # type: ignore + optimal_time = matmul_flops / peak_flops * 1e6 # us + achieved_tc_util = optimal_time / runtime_us * 100 + if achieved_tc_util > best_util: + np.testing.assert_allclose(out, expected) + best_util = achieved_tc_util + print( + f"{tile_m=} {tile_n=} {tile_k=} {max_concurrent_steps=} " + f"{grid_minor_dim=} {grid_tile_width=} " + f"{epilogue_tile_n=} " + f"{collective=} : " + f"{runtime_us:<7.1f}us" + f" = {achieved_tc_util:4.1f}% TC utilization" + ) + print(f"\tBest utilization: {best_util:4.1f}%") + _, runtimes_ms = profiler.measure( + functools.partial( + jnp.dot, precision=jax.lax.DotAlgorithmPreset.F16_F16_F32 + ), + iterations=10, + )(a, b) + assert runtimes_ms is not None + runtime_ms = statistics.median(runtimes_ms) + runtime_us = runtime_ms * 1e3 # type: ignore + optimal_time = matmul_flops / peak_flops * 1e6 # us + achieved_tc_util = optimal_time / runtime_us * 100 + print(f"\tReference: {achieved_tc_util:4.1f}%") + + +if __name__ == "__main__": + from absl import app + + jax.config.config_with_absl() + app.run(main) diff --git a/jax/experimental/pallas/ops/gpu/blackwell_ragged_dot_mgpu.py b/jax/experimental/pallas/ops/gpu/blackwell_ragged_dot_mgpu.py new file mode 100644 index 000000000000..843ac50554cc --- /dev/null +++ b/jax/experimental/pallas/ops/gpu/blackwell_ragged_dot_mgpu.py @@ -0,0 +1,437 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Ragged/Grouped Matrix Multiplication kernel for Blackwell GPUs.""" +import dataclasses +import functools +import itertools +import math +import jax +from jax import lax +from jax._src import test_util as jtu # noqa: F401 +from jax.experimental.mosaic.gpu import profiler +import jax.experimental.pallas as pl +import jax.experimental.pallas.mosaic_gpu as plgpu +from jax.experimental.pallas.ops.gpu import blackwell_matmul_mgpu +from jax.experimental.pallas.ops.gpu import ragged_dot_mgpu +import jax.numpy as jnp +import numpy as np +from typing import Sequence + + +@dataclasses.dataclass(frozen=True) +class TuningConfig: + tile_m: int + tile_n: int + tile_k: int + max_concurrent_steps: int + collective: bool + grid_tile_width: int + grid_minor_dim: blackwell_matmul_mgpu.MatmulDimension + epilogue_tile_n: int = 64 + + def __str__(self): + return "_".join(f"{k}={v}" for k, v in dataclasses.asdict(self).items()) + + +# TODO(justinfu): Merge with blackwell_matmul_mgpu.py +def do_matmul(a_gmem, + b_gmem, + out_gmem, + grid_indices: Sequence[jax.Array], + wg_axis: str, + collective_axes: tuple[str, ...], + local_index: jax.Array, + config: TuningConfig, + group_info: ragged_dot_mgpu.GroupInfo, + a_smem, b_smem, acc_tmem, acc_smem, + a_tma_barrier, b_tma_barrier, store_done_barrier, mma_done_barrier, + consumed_barrier + ): + """Compute a non-ragged matmul for a single output block.""" + dtype = out_gmem.dtype + m, k = a_gmem.shape + collective = config.collective + tile_m, tile_n, tile_k = (config.tile_m, config.tile_n, config.tile_k) + epilogue_tile_n = config.epilogue_tile_n + max_concurrent_steps = config.max_concurrent_steps + block_tile_m = tile_m + if collective: + tile_m *= 2 + tile_n *= 2 + k_iters = k // tile_k + + if collective: + m_index, n_index, cluster_idx = grid_indices + block_m_index = m_index * 2 + cluster_idx + is_lead_block = cluster_idx == 0 + else: + m_index, n_index = grid_indices + cluster_idx = 0 # type: ignore + block_m_index = m_index + is_lead_block = True # type: ignore + wg_idx = lax.axis_index(wg_axis) + collective_axis = collective_axes[0] if collective else None + + TMA_WARP = 0 + MMA_WARP = 1 + COMPUTE_WG = 0 + STORE_WG = 1 + + block_slice_m = pl.ds(block_m_index * block_tile_m, block_tile_m) + slice_m = pl.ds(m_index * tile_m, tile_m) + slice_n = pl.ds(n_index * tile_n, tile_n) + acc_slot = lax.rem(local_index, jnp.int32(2)) + regs_layout = plgpu.Layout.TCGEN05 + + @pl.when(wg_idx == COMPUTE_WG) + @jax.named_scope("compute_wg") + def _(): + @pl.core_map(plgpu.WarpMesh(axis_name="warp")) + def _per_warp(): + warp_id = lax.axis_index("warp") + @pl.when(warp_id == TMA_WARP) + def _memory(): + def _loop_body(ki, _): + slice_k = pl.ds(ki * tile_k, tile_k) + slot = lax.rem(ki, max_concurrent_steps) + @pl.when(jnp.logical_or(ki >= max_concurrent_steps, + local_index > 0)) + def _(): + plgpu.barrier_wait(consumed_barrier.at[slot]) + plgpu.copy_gmem_to_smem( + a_gmem.at[slice_m, slice_k], + a_smem.at[slot], + a_tma_barrier.at[slot], + partitioned_axis=0 if collective else None, + collective_axes=collective_axis, + ) + plgpu.copy_gmem_to_smem( + b_gmem.at[slice_k, slice_n], + b_smem.at[slot], + b_tma_barrier.at[slot], + partitioned_axis=1 if collective else None, + collective_axes=collective_axis, + ) + lax.fori_loop(0, k_iters, _loop_body, None) + + @pl.when(jnp.logical_and(warp_id == MMA_WARP, local_index > 1)) + def _wait_store(): + plgpu.barrier_wait(store_done_barrier.at[acc_slot]) + @pl.when(jnp.logical_and(warp_id == MMA_WARP, is_lead_block)) + def _compute(): + def _loop_body(ki, _): + slot = lax.rem(ki, max_concurrent_steps) + plgpu.barrier_wait(a_tma_barrier.at[slot]) + plgpu.barrier_wait(b_tma_barrier.at[slot]) + + is_last_iter = ki >= k_iters - 1 + acc_tmem_slice = acc_tmem.at[:, pl.ds(acc_slot * tile_n, tile_n)] + plgpu.tcgen05_mma( + acc_tmem_slice, + a_smem.at[slot], + b_smem.at[slot], + consumed_barrier.at[slot], + accumulate=(ki > 0), + collective_axis=collective_axis, + ) + @pl.when(is_last_iter) + def _(): + plgpu.tcgen05_commit_arrive( + mma_done_barrier.at[acc_slot], + collective_axis=collective_axis, + ) + + lax.fori_loop(0, k_iters, _loop_body, None) + + @pl.when(wg_idx == STORE_WG) + @jax.named_scope("store_wg") + def _(): + plgpu.barrier_wait(mma_done_barrier.at[acc_slot]) + acc_tmem_slot = acc_tmem.at[:, pl.ds(acc_slot * tile_n, tile_n)] + step_out_gmem = out_gmem.at[block_slice_m, slice_n] + # group_info contains start/size info relative to the logical + # tiling (tile_m) but because for collective matmuls we use 2 CTAs per + # logical block, but we need to compute the start/size relative to the + # current block. + # For example, for the following parameters: + # block_tile_m=64 (tile_m=128) + # group_info.start_within_block=60 + # group_info.actual_size=37 + # The requested copy will be split across both blocks + # Memory: | Block 0 | Block 1 | + # |--- 64 ---|--- 64 ---| + # Copy: |-- 37 --| + # Where block 0 copies rows 60-64 (4 rows total) and block 1 copies + # the remaining rows 64-97 (33 rows total). + smem_start = group_info.start_within_block - cluster_idx * block_tile_m + smem_start = lax.max(smem_start, jnp.int32(0)) + def _clamp(min, x, max): + return lax.max(lax.min(x, max), min) + block0_copy_size = _clamp( + jnp.int32(0), + block_tile_m - group_info.start_within_block, + group_info.actual_size) + block_local_size = lax.select(is_lead_block, + # block 0 copies up to end of the first block or actual_size, + # whichever comes first. + block0_copy_size, + # block 1 copies the remaining rows that block 0 did not copy. + group_info.actual_size - block0_copy_size + ) + for ni in range(tile_n // epilogue_tile_n): + acc_smem[...] = plgpu.async_load_tmem( + acc_tmem_slot.at[:, pl.ds(ni * epilogue_tile_n, epilogue_tile_n)], + layout=regs_layout).astype(dtype) + plgpu.commit_smem() + cur_smem_idx = smem_start + remaining_rows = min(block_tile_m, m) + while remaining_rows > 0: + const_rows_len = 1 << int(math.log2(remaining_rows)) + remaining_rows //= 2 + @pl.when(block_local_size & const_rows_len != 0) + def _(): + o_smem_slice = acc_smem.at[pl.ds(cur_smem_idx, const_rows_len)] + o_gref_slice = step_out_gmem.at[ + pl.ds(cur_smem_idx, const_rows_len), + pl.ds(ni * epilogue_tile_n, epilogue_tile_n), + ] + plgpu.copy_smem_to_gmem(o_smem_slice, o_gref_slice) + cur_smem_idx += block_local_size & const_rows_len + plgpu.wait_smem_to_gmem(0, wait_read_only=True) + plgpu.wait_load_tmem() # Load must complete before we continue. + plgpu.barrier_arrive(store_done_barrier.at[acc_slot]) + + +def ragged_dot_kernel(a, b, group_sizes, config: TuningConfig): + dtype = a.dtype + if a.dtype != b.dtype: + raise ValueError( + f"Matmul LHS and RHS have incompatible dtypes {a.dtype} vs {b.dtype}" + ) + m, k = a.shape + num_groups, k2, n = b.shape + if num_groups != group_sizes.shape[0]: + raise ValueError("RHS and group_sizes have incompatible shapes.") + if k != k2: + raise ValueError( + "Matmul LHS and RHS have incompatible shapes " + f"{a.shape} vs {b.shape[1:]}" + ) + collective = config.collective + tile_m, tile_n, tile_k = (config.tile_m, config.tile_n, config.tile_k) + block_tile_m = tile_m + block_tile_n = tile_n + if collective: + tile_m *= 2 + tile_n *= 2 + m_iters = m // tile_m + n_iters = n // tile_n + + max_concurrent_steps = config.max_concurrent_steps + epilogue_tile_n = config.epilogue_tile_n + if tile_n % epilogue_tile_n != 0: + raise ValueError( + f"{tile_n=} must be divisible by {epilogue_tile_n=}" + ) + + if m % tile_m != 0: + raise ValueError(f"{m=} must be divisible by {tile_m=}") + if n % tile_n != 0: + raise ValueError(f"{n=} must be divisible by {tile_n=}") + if k % tile_k != 0: + raise ValueError(f"{k=} must be divisible by {tile_k=}") + swizzle = plgpu.find_swizzle(tile_k * jnp.dtype(dtype).itemsize * 8) + swizzle_elems = swizzle // jnp.dtype(dtype).itemsize + transforms = ( + plgpu.TilingTransform((8, swizzle_elems)), + plgpu.SwizzleTransform(swizzle), + ) + + def kernel(a_gmem, b_gmem, group_sizes_gmem, out_gmem): + linear_grid = (m_iters + num_groups - 1) * n_iters + group_sizes_regs = [group_sizes_gmem[i] for i in range(num_groups)] + cluster_idx = lax.axis_index("x") + + @functools.partial(pl.run_scoped, + a_smem=plgpu.SMEM( + (max_concurrent_steps, block_tile_m, tile_k), + dtype, transforms=transforms + ), + b_smem=plgpu.SMEM( + (max_concurrent_steps, tile_k, block_tile_n), + dtype, transforms=transforms + ), + # Temporary SMEM used for storing accumulator output to GMEM. + acc_smem=plgpu.SMEM( + (block_tile_m, epilogue_tile_n), dtype), + # a/b_tma_barrier + a_tma_barrier=plgpu.Barrier(num_arrivals=1, num_barriers=max_concurrent_steps), + b_tma_barrier=plgpu.Barrier(num_arrivals=1, num_barriers=max_concurrent_steps), + # store_done_barrier, double-buffered + store_done_barrier=plgpu.Barrier(num_arrivals=1, num_barriers=2, + orders_tensor_core=True), + # mma_done_barrier, double-buffered + mma_done_barrier=plgpu.Barrier(num_arrivals=1, num_barriers=2, + orders_tensor_core=True), + # consumed_barrier + consumed_barrier=plgpu.Barrier( + num_arrivals=1, + num_barriers=max_concurrent_steps, + orders_tensor_core=True, + ), + # Accumulator TMEM (double-buffered) + acc_tmem=plgpu.TMEM( + (block_tile_m, tile_n * 2), jnp.float32, collective=collective), + collective_axes=("wg",) + ) + def _scoped(**ref_kwargs): + @plgpu.nd_loop(grid=(linear_grid,), + collective_axes="sm") + def mn_loop(loop_info: plgpu.NDLoopInfo): # pylint: disable=unused-variable + linear_idx, = loop_info.index + local_index = loop_info.local_index # type: ignore + m_index, n_index = plgpu.planar_snake( + linear_idx, + (m_iters + num_groups - 1, n_iters), + config.grid_minor_dim, + config.grid_tile_width, + ) + with jax.named_scope("create_group_info"): + group_info = ragged_dot_mgpu.GroupInfo.create( + group_sizes_regs, tile_m, m_index + ) + do_matmul( + a_gmem, + b_gmem.at[group_info.group_id], + out_gmem, + grid_indices=(group_info.block, n_index, cluster_idx), + wg_axis="wg", + collective_axes=("x",) if collective else (), + local_index=local_index, # type: ignore + config=config, + group_info=group_info, + **ref_kwargs + ) + + num_sms = jax.local_devices()[0].core_count + compiler_params = None + f = plgpu.kernel( + kernel, + compiler_params=compiler_params, + kernel_name=f"ragged_dot_kernel_{str(config)}", + out_shape=jax.ShapeDtypeStruct((m, n), dtype), + grid=(num_sms//2,) if collective else (num_sms,), + grid_names=("sm",), + num_threads=2, + thread_name="wg", + cluster_names=("x",) if collective else (), + cluster=(2,) if collective else (), + ) + return f(a, b, group_sizes) + + +def ragged_dot_reference(a, b, g): + return lax.ragged_dot(a, b, g, preferred_element_type=jnp.float16) + + +def sample_group_sizes(key: jax.Array, + num_groups: int, + num_elements: int, + alpha: float = 10.0, + ): + """Sample group sizes. + + Args: + key: PRNG key. + num_groups: Number of groups to sample. + num_elements: Total number of elements to sample. + alpha: Shape parameter. The lower the alpha, the more imbalanced the + group sizes will be. As alpha approaches infinity, the group sizes + approach a uniform distribution. + + Returns: + A jax.Array of shape (num_groups,) that sums to num_elements. + """ + probs_key, sample_key = jax.random.split(key) + probs = jax.random.dirichlet(probs_key, jnp.ones((num_groups,)) * alpha) + return jax.random.multinomial( + sample_key, num_elements, probs).astype(jnp.int32) + + +def main(_) -> None: + M = 16 * 1024 + K = 2048 + N = 16 * 1024 + num_groups = 16 + group_sizes = sample_group_sizes(jax.random.key(0), num_groups, M, alpha=10.0) + + print(f"==== {M=} {N=} {K=} {num_groups=}====") + matmul_flops = 2 * M * N * K + peak_flops = 2.25e15 # f16 TensorCore peak = 2250 TFLOPS + a = jax.random.uniform(jax.random.key(1), (M, K), jnp.float16) + b = jax.random.uniform(jax.random.key(2), (num_groups, K, N), jnp.float16) + + tuning_it = itertools.product( + (128,), # tile_m + (128,), # tile_n + (64,), # tile_k + (1, 8, 12, 16), # grid_tile_width + blackwell_matmul_mgpu.MatmulDimension, # grid_minor_dim + (4, 6) # max_concurrent_steps + ) + best_util = -float("inf") + for (tile_m, tile_n, tile_k, grid_tile_width, grid_minor_dim, + max_concurrent_steps,) in tuning_it: + config = TuningConfig( + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + grid_tile_width=grid_tile_width, + grid_minor_dim=grid_minor_dim, + max_concurrent_steps=max_concurrent_steps, + collective=True, + ) + try: + out, runtime_ms = profiler.measure( + functools.partial(ragged_dot_kernel, config=config), + iterations=10 + )(a, b, group_sizes) + runtime_ms = np.median(runtime_ms if runtime_ms else []) # type: ignore + except ValueError as e: + if ("exceeds available shared memory" in e.args[0] or + "Accumulator layout mismatch:" in e.args[0]): + print(e.args[0]) + continue + raise + expected = ragged_dot_reference(a, b, group_sizes) + np.testing.assert_allclose(out, expected) + + runtime_us = runtime_ms * 1e3 # type: ignore + optimal_time = matmul_flops / peak_flops * 1e6 # us + achieved_tc_util = optimal_time / runtime_us * 100 + if achieved_tc_util > best_util: + best_util = achieved_tc_util + print( + f"{tile_m=} {tile_n=} {tile_k=} {grid_tile_width=} {grid_minor_dim=} {max_concurrent_steps=} " + f"{runtime_us:<7.1f}us" + f" = {achieved_tc_util:4.1f}% TC utilization" + ) + print(f"\tBest utilization: {best_util:4.1f}%") + + +if __name__ == "__main__": + from absl import app + + jax.config.config_with_absl() + app.run(main) diff --git a/jax/experimental/pallas/ops/gpu/collective_matmul_mgpu.py b/jax/experimental/pallas/ops/gpu/collective_matmul_mgpu.py new file mode 100644 index 000000000000..01b92e66476b --- /dev/null +++ b/jax/experimental/pallas/ops/gpu/collective_matmul_mgpu.py @@ -0,0 +1,253 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A collective matmul kernel implemented using Mosaic GPU.""" + +import functools +import itertools + +import jax +from jax import lax +from jax.experimental import multihost_utils +from jax.experimental import pallas as pl +from jax.experimental.mosaic.gpu import profiler +from jax.experimental.pallas import mosaic_gpu as plgpu +from jax.experimental.pallas.ops.gpu import hopper_matmul_mgpu +import jax.numpy as jnp + + +MatmulDimension = hopper_matmul_mgpu.MatmulDimension +TuningConfig = hopper_matmul_mgpu.TuningConfig + + +def all_gather_lhs_matmul( + lhs: jax.Array, + rhs: jax.Array, + axis_name, + *, + config: hopper_matmul_mgpu.TuningConfig, + dtype: jnp.dtype = jnp.float16, +) -> jax.Array: + if (num_devices := jax.device_count()) != jax.process_count(): + raise ValueError("The kernel only supports one device per process") + if (axis_size := lax.axis_size(axis_name)) != num_devices: + raise ValueError("The kernel can only work over all devices in a Mesh.") + if jnp.dtype(dtype) not in map(jnp.dtype, [jnp.float16, jnp.bfloat16]): + raise NotImplementedError(f"Only f16 and bf16 are supported, got {dtype=}") + if config.cluster_dimension is not None: + raise NotImplementedError("Cluster dimension must be None for all-gather matmuls.") + + m_shard, k = lhs.shape + k2, n_shard = rhs.shape + if k != k2: + raise ValueError( + f"lhs and rhs must have the same contraction size, got {k} and {k2}." + ) + if (element_type := lhs.dtype) != rhs.dtype: + raise ValueError( + f"lhs and rhs must have the same element type, got {element_type} and" + f" {rhs.dtype}." + ) + tile_m, tile_n, tile_k = config.tile_m, config.tile_n, config.tile_k + max_concurrent_steps = config.max_concurrent_steps + if max_concurrent_steps < 2: + raise ValueError("max_concurrent_steps must be >= 2") + cta_tile_m = tile_m * (1 + (config.wg_dimension == MatmulDimension.M)) + + epi_tile_n = config.epi_tile_n or tile_n + epi_tile_m = config.epi_tile_m or tile_m + if tile_n % epi_tile_n != 0: + raise ValueError(f"{tile_n=} must be divisible by {epi_tile_n=}") + if tile_m % epi_tile_m != 0: + raise ValueError(f"{tile_m=} must be divisible by {epi_tile_m=}") + + num_sms = jax.devices()[0].core_count # 132 for H100 SXM GPUs. + + def kernel_body(lhs_local_ref, rhs_ref, out_ref, scratch_ref): + received_sem = pl.get_global(plgpu.SemaphoreType.REGULAR) + wg_idx = lax.axis_index("wg") + dev_id = lax.axis_index(axis_name) + send_dev_id = lax.rem(dev_id + axis_size - 1, axis_size) + send_scratch_ref = plgpu.remote_ref(scratch_ref, send_dev_id) + + def send_lhs(m_idx, n_idx, k_idx, a_smem, b_smem, send_ref, should_send): + del b_smem # Unused. + # We only send when n_idx == 0 to avoid sending the same data + # multiple times when revisiting lhs. + @pl.when(should_send & jnp.bool(n_idx == 0)) + def _(): + k_slice = pl.ds(k_idx * tile_k, tile_k) + m_slice = pl.ds(m_idx * cta_tile_m, cta_tile_m) + plgpu.copy_smem_to_gmem(a_smem, send_ref.at[m_slice, k_slice]) + # We only delay release by 1 step, so we need to wait for the + # previous copies. + plgpu.wait_smem_to_gmem(1, wait_read_only=True) + + def device_step(lhs_source_ref, device_offset): + # Invariant: lhs_source_ref is ready to be used + next_scratch_slot = device_offset + out_device_m_slice = pl.ds( + lax.rem(device_offset + dev_id, num_devices) * m_shard, m_shard + ) + is_send_wg = wg_idx == 0 + has_send_space = next_scratch_slot < num_devices - 1 + should_send = is_send_wg & has_send_space + + # This reuses the regular matmul kernel, only with the exception of + # inserting send_lhs into the pipeline. + # TODO(apaszke): This contains run_scoped inside, meaning that it will + # synchronize all threads at each device step. If we optimize the barrier + # below, then it might be better to move it out to make bubbles smaller. + hopper_matmul_mgpu.kernel( + lhs_source_ref, # Use the lhs from previous step. + rhs_ref, # Use the same rhs for all steps. + out_ref.at[out_device_m_slice], # Use a slice of the output. + config=config, + pipeline_callback=functools.partial( + send_lhs, + send_ref=send_scratch_ref.at[next_scratch_slot], + should_send=should_send, + ), + delay_release=1, + ) + + # Wait for the next scratch to arrive --- see the device loop invariant. + @pl.when(should_send) + def _signal(): + # TODO(apaszke): We could do this signal a lot earlier if we better + # control the order of sends. If we tile the grid along N, then we can + # signal as soon as everyone moves on from the first column tile. + # Make sure the copy is done and signal the receiving device. + plgpu.wait_smem_to_gmem(0, wait_read_only=False) + pl.semaphore_signal(received_sem, device_id=send_dev_id) + @pl.when(next_scratch_slot < num_devices - 1) + def _wait(): + pl.semaphore_wait(received_sem, value=(device_offset + 1) * num_sms, decrement=False) + + # We peel the first step to copy data directly form lhs_local_ref. + device_step(lhs_local_ref, 0) + @pl.loop(1, num_devices) + def _device_loop(device_offset): + device_step(scratch_ref.at[device_offset - 1], device_offset) + # Make sure all copies are fully done. + plgpu.wait_smem_to_gmem(0, wait_read_only=True) + + result, _ = plgpu.kernel( + kernel_body, + out_shape=[ + # The output, with its M dimension all-gathered. + jax.ShapeDtypeStruct((axis_size * m_shard, n_shard), dtype), + # The scratch buffer used for the all-gather. + jax.ShapeDtypeStruct((num_devices - 1, m_shard, k), dtype), + ], + grid=(num_sms,), + grid_names=("cluster_grid",), + num_threads=3, + thread_name="wg", + cluster=(1,), + cluster_names=("cluster",), + )(lhs, rhs) + return result + + +def _run_example(): + P = jax.sharding.PartitionSpec + m_shard = 1024 + n_shard = 4096 + k = 4096 + dtype = jnp.bfloat16 + shards = jax.device_count() + mesh = jax.make_mesh( + (shards,), ("x",), axis_types=(jax.sharding.AxisType.Explicit,) + ) + jax.set_mesh(mesh) + + # We measure time per-shard and so we only need FLOPs per shard. + matmul_flops = 2 * (shards * m_shard) * n_shard * k + peak_flops = 990e12 # f16 TensorCore peak = 990 TFLOPS + optimal_time = matmul_flops / peak_flops * 1e6 # us + a = jax.random.normal(jax.random.key(1), (shards * m_shard, k), dtype) + b = jax.random.normal(jax.random.key(2), (k, shards * n_shard), dtype) + a = jax.sharding.reshard(a, P("x", None)) + b = jax.sharding.reshard(b, P(None, "x")) + _, ref_kernels_ms = profiler.measure(jax.jit( + jax.shard_map( + lambda x, y: lax.all_gather(x, "x", axis=0, tiled=True) @ y, + out_specs=P(None, "x"), + check_vma=False, + ) + ), aggregate=False)(a, b) + ref_time_us = sum(t * 1e3 for _, t in ref_kernels_ms) + # We choose the minimum across processes to choose the runtime that didn't + # include devices waiting for other devices. + ref_time_us = min(multihost_utils.process_allgather(ref_time_us).tolist()) + ref_util = optimal_time / ref_time_us * 100 + + tuning_it = itertools.product( + (128, 256,), # tile_m + (64, 128), # tile_n + (64,), # tile_k + (4,), # max_concurrent_steps + (MatmulDimension.M, MatmulDimension.N), # grid_minor_dim + (4, 8, 16), # grid_tile_width + MatmulDimension, # wg_dimension + ) + best_util = 0.0 + best_runtime = float("inf") + def build_kernel(**kwargs): + return jax.jit( + jax.shard_map( + functools.partial(all_gather_lhs_matmul, **kwargs), + out_specs=P(None, "x"), + check_vma=False, + ) + ) + + for tile_m, tile_n, tile_k, max_concurrent_steps, grid_minor_dim, grid_tile_width, wg_dimension in tuning_it: + try: + config = TuningConfig( + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + max_concurrent_steps=max_concurrent_steps, + grid_minor_dim=grid_minor_dim, + grid_tile_width=grid_tile_width, + wg_dimension=wg_dimension, + ) + _, kernels_ms = profiler.measure( + build_kernel(axis_name="x", config=config, dtype=dtype), + aggregate=False, + )(a, b) + except ValueError as e: + if "exceeds available shared memory" in e.args[0]: # Ignore SMEM OOMs. + continue + raise + runtime_us = sum(t * 1e3 for _, t in kernels_ms) + runtime_us = min(multihost_utils.process_allgather(runtime_us).tolist()) + achieved_tc_util = optimal_time / runtime_us * 100 + if achieved_tc_util > best_util: + best_runtime = runtime_us + best_util = achieved_tc_util + print( + f"{tile_m=} {tile_n=} {tile_k=} {max_concurrent_steps=} {grid_minor_dim=} {grid_tile_width=} {wg_dimension=}: " + f"{runtime_us:<7.1f}us" + f" = {achieved_tc_util:4.1f}% TC utilization" + ) + print(f"\tBest: {best_runtime:<7.1f}us = {best_util:4.1f}% TC utilization") + print(f"\tRef: {ref_time_us:<7.1f}us = {ref_util:4.1f}% TC utilization") + + +if __name__ == "__main__": + from jax._src import test_multiprocess as jt_multiprocess # pytype: disable=import-error + jt_multiprocess.main(shard_main=_run_example) diff --git a/jax/experimental/pallas/ops/gpu/decode_attention.py b/jax/experimental/pallas/ops/gpu/decode_attention.py index e2c19b3eaf2d..cf815e640141 100644 --- a/jax/experimental/pallas/ops/gpu/decode_attention.py +++ b/jax/experimental/pallas/ops/gpu/decode_attention.py @@ -50,7 +50,7 @@ def _compute(start_idx, kv_seq_len, o, m_i, l_i): # Load q: it will stay in L1 throughout. Indices form a matrix because we # read, compute, and write all in 2d chunks. 1 element ~= 1 CUDA thread index. # q tile has shape [block_h, head_dim]. - q = pl.load(q_ref, (q_slice, pl.ds(None)), mask=q_mask) + q = plgpu.load(q_ref.at[q_slice, :], mask=q_mask) def _dot(a, b): # if a.shape[0] == 1: @@ -66,7 +66,7 @@ def body(start_k, carry): o_prev, m_prev, l_prev = carry curr_k_slice = pl.ds(start_k * block_k, block_k) - k = pl.load(k_ref, (curr_k_slice, slice(None))) + k = k_ref[curr_k_slice, :] qk = _dot(q, k.T) # [block_h, block_k] if sm_scale != 1.0: qk *= sm_scale # [block_h, block_k] @@ -86,7 +86,7 @@ def body(start_k, carry): ) # Use m_next instead of m_curr to avoid a correction on l_curr l_curr = s_curr.sum(axis=-1) l_next = l_prev_corr + l_curr - v = pl.load(v_ref, (curr_k_slice, slice(None))) + v = v_ref[curr_k_slice, :] o_curr = _dot(s_curr.astype(v.dtype), v) # flash2 unscaled_o @@ -106,10 +106,10 @@ def body(start_k, carry): start_idx = split_k_seq_len * prog_j if start_idx_ref is not None: - start_idx = jnp.maximum(start_idx, pl.load(start_idx_ref, ())) + start_idx = jnp.maximum(start_idx, start_idx_ref[()]) kv_seq_len = (prog_j + 1) * split_k_seq_len # lower bound on actual k_seq_len if kv_seq_len_ref is not None: - kv_seq_len = jnp.minimum(kv_seq_len, pl.load(kv_seq_len_ref, ())) + kv_seq_len = jnp.minimum(kv_seq_len, kv_seq_len_ref[()]) if start_idx_ref is None and kv_seq_len is None: o, m_i, l_i = _compute(start_idx, kv_seq_len, o, m_i, l_i) @@ -122,10 +122,10 @@ def body(start_k, carry): if residual_refs: l_ref, m_ref = residual_refs vec_q_mask = q_mask.reshape(-1) if q_mask is not None else None - pl.store(l_ref, q_slice, l_i, mask=vec_q_mask) - pl.store(m_ref, q_slice, m_i, mask=vec_q_mask) + plgpu.store(l_ref.at[q_slice], l_i, mask=vec_q_mask) + plgpu.store(m_ref.at[q_slice], m_i, mask=vec_q_mask) o = o.astype(o_ref.dtype) - pl.store(o_ref, (q_slice, pl.ds(None)), o, mask=q_mask) + plgpu.store(o_ref.at[q_slice, :], o, mask=q_mask) def decode_attn_unbatched( @@ -193,7 +193,7 @@ def decode_attn_unbatched( pl.BlockSpec((None, block_h), lambda i, j: (j, i)), # l pl.BlockSpec((None, block_h), lambda i, j: (j, i)), # m ], - compiler_params=plgpu.TritonCompilerParams( + compiler_params=plgpu.CompilerParams( num_warps=num_warps_, num_stages=num_stages ), out_shape=[ @@ -359,16 +359,18 @@ def gqa( normalize_output=normalize_output, ) with_kv_heads = jax.vmap(inner) - o, *res = jax.vmap(with_kv_heads)( + outputs = jax.vmap(with_kv_heads)( q_reshaped, k_transposed, v_transposed, start_idx, kv_seq_len ) - o = o.reshape(batch_size, q_heads, head_dim) if return_residuals: - l, m = res[0] + o, (l, m) = outputs + o = o.reshape(batch_size, q_heads, head_dim) l = l.reshape(batch_size, q_heads) m = m.reshape(batch_size, q_heads) return o, (l, m) else: + o = outputs + o = o.reshape(batch_size, q_heads, head_dim) # pytype: disable=attribute-error return o diff --git a/jax/experimental/pallas/ops/gpu/hopper_matmul_mgpu.py b/jax/experimental/pallas/ops/gpu/hopper_matmul_mgpu.py new file mode 100644 index 000000000000..25917b5ec093 --- /dev/null +++ b/jax/experimental/pallas/ops/gpu/hopper_matmul_mgpu.py @@ -0,0 +1,307 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Matrix Multiplication kernel for Hopper GPUs.""" +import statistics +import dataclasses +import enum +import functools +import itertools +import jax +from jax import lax +from jax._src import test_util as jtu # noqa: F401 +from jax.experimental.mosaic.gpu import profiler +import jax.experimental.pallas as pl +import jax.experimental.pallas.mosaic_gpu as plgpu +from jax.extend import backend +import jax.numpy as jnp +import numpy as np + +# mypy: ignore-errors + +class MatmulDimension(enum.IntEnum): + M = 0 + N = 1 + + def __str__(self): + return self.name + + def __repr__(self): + return self.name + + +@dataclasses.dataclass(frozen=True) +class TuningConfig: + tile_m: int + tile_n: int + tile_k: int + max_concurrent_steps: int + epi_tile_n: int | None = 64 # This needs to be lowered for for small N. + epi_tile_m: int | None = 64 + grid_minor_dim: MatmulDimension = MatmulDimension.N + grid_tile_width: int = 1 + wg_dimension: MatmulDimension = MatmulDimension.N + cluster_dimension: None | MatmulDimension = None + + +# pipeline_callback and delay_release are only used for collective matmuls. +def kernel(a_gmem, b_gmem, out_gmem, config: TuningConfig, + pipeline_callback=None, delay_release=0): + dtype = a_gmem.dtype + assert b_gmem.dtype == dtype + m, k = a_gmem.shape + k2, n = b_gmem.shape + assert k == k2 + tile_m, tile_n, tile_k = config.tile_m, config.tile_n, config.tile_k + max_concurrent_steps = config.max_concurrent_steps + swizzle = plgpu.find_swizzle(tile_k * jnp.dtype(dtype).itemsize * 8) + swizzle_elems = swizzle // jnp.dtype(dtype).itemsize + transforms = ( + plgpu.TilingTransform((8, swizzle_elems)), plgpu.SwizzleTransform(swizzle) + ) + + cta_tile_m = tile_m * (1 + (config.wg_dimension == MatmulDimension.M)) + cta_tile_n = tile_n * (1 + (config.wg_dimension == MatmulDimension.N)) + cluster_tile_m = cta_tile_m * (1 + (config.cluster_dimension == MatmulDimension.M)) + cluster_tile_n = cta_tile_n * (1 + (config.cluster_dimension == MatmulDimension.N)) + if m % cluster_tile_m != 0: + raise ValueError(f"{m=} must be divisible by {cluster_tile_m} for the given config") + if n % cluster_tile_n != 0: + raise ValueError(f"{n=} must be divisible by {cluster_tile_n} for the given config") + if k % tile_k != 0: + raise ValueError(f"{k=} must be divisible by {tile_k=}") + m_iters = m // cluster_tile_m + n_iters = n // cluster_tile_n + k_iters = k // tile_k + + epi_tile_m = config.epi_tile_m or tile_m + epi_tile_n = config.epi_tile_n or tile_n + # We don't need multiple slots if there's only one epilogue tile. + num_out_slots = min(2, (tile_m * tile_n) // (epi_tile_m * epi_tile_n)) + out_swizzle = plgpu.find_swizzle(epi_tile_n * jnp.dtype(dtype).itemsize * 8) + out_swizzle_elems = out_swizzle // jnp.dtype(dtype).itemsize + out_transforms = ( + plgpu.TilingTransform((8, out_swizzle_elems)), + plgpu.SwizzleTransform(out_swizzle), + ) + + def get_pipeline(pipeline_body, compute_context): + return plgpu.emit_pipeline_warp_specialized( + pipeline_body, + grid=(k_iters,), + memory_registers=40, + in_specs=[ + plgpu.BlockSpec( + (cta_tile_m, tile_k), + lambda k: (0, k), + transforms=transforms, + memory_space=plgpu.SMEM, + delay_release=delay_release, + collective_axes=("cluster",) + if config.cluster_dimension == MatmulDimension.N + else (), + ), + plgpu.BlockSpec( + (tile_k, cta_tile_n), + lambda k: (k, 0), + transforms=transforms, + memory_space=plgpu.SMEM, + delay_release=delay_release, + collective_axes=("cluster",) + if config.cluster_dimension == MatmulDimension.M + else (), + ), + ], + wg_axis="wg", + num_compute_wgs=2, + max_concurrent_steps=max_concurrent_steps, + compute_context=compute_context, + ) + + # Functions don't influence the allocations necessary to run the pipeline. + ignore = lambda *_, **__: None + @functools.partial( + pl.run_scoped, + pipeline_allocs=get_pipeline(ignore, ignore).get_allocations(a_gmem, b_gmem), + out_smem=plgpu.SMEM( + (2, num_out_slots, epi_tile_m, epi_tile_n), + dtype, + transforms=out_transforms, + ), + collective_axes="wg", + ) + def _pipeline_scope(pipeline_allocs, out_smem): + wg_idx = lax.axis_index("wg") + cta_idx = lax.axis_index("cluster") + @plgpu.nd_loop((m_iters * n_iters,), collective_axes="cluster_grid") + def _mn_loop(loop_info: plgpu.NDLoopInfo): + (lin_idx,) = loop_info.index + m_cluster_idx, n_cluster_idx = plgpu.planar_snake( + lin_idx, + (m_iters, n_iters), + config.grid_minor_dim, + config.grid_tile_width, + ) + m_idx = m_cluster_idx + n_idx = n_cluster_idx + if config.cluster_dimension == MatmulDimension.M: + m_idx = m_cluster_idx * 2 + cta_idx + elif config.cluster_dimension == MatmulDimension.N: + n_idx = n_cluster_idx * 2 + cta_idx + cta_m_slice = pl.ds(m_idx * cta_tile_m, cta_tile_m) + cta_n_slice = pl.ds(n_idx * cta_tile_n, cta_tile_n) + if config.wg_dimension == MatmulDimension.M: + wg_m_slice = pl.ds(wg_idx * tile_m, tile_m) + wg_n_slice = slice(None) + else: + wg_m_slice = slice(None) + wg_n_slice = pl.ds(wg_idx * tile_n, tile_n) + + def compute_context(eval_pipeline): + @functools.partial( + pl.run_scoped, acc_ref=plgpu.ACC((tile_m, tile_n), jnp.float32) + ) + def _acc_scope(acc_ref): + eval_pipeline(acc_ref) + acc = acc_ref[...].astype(dtype) + plgpu.wait_smem_to_gmem(0, wait_read_only=True) + for epi_mi in range(tile_m // epi_tile_m): + for epi_ni in range(tile_n // epi_tile_n): + epi_m_slice = slice(epi_mi * epi_tile_m, (epi_mi + 1) * epi_tile_m) + epi_n_slice = slice(epi_ni * epi_tile_n, (epi_ni + 1) * epi_tile_n) + slot = (epi_mi * (tile_n // epi_tile_n) + epi_ni) % 2 + plgpu.wait_smem_to_gmem(1, wait_read_only=True) + out_smem[wg_idx, slot] = acc[epi_m_slice, epi_n_slice] + plgpu.commit_smem() + plgpu.copy_smem_to_gmem( + out_smem.at[wg_idx, slot], + out_gmem.at[cta_m_slice, cta_n_slice] + .at[wg_m_slice, wg_n_slice] + .at[epi_m_slice, epi_n_slice], + ) + + def mma_body(idxs, a_smem, b_smem, acc_ref): + plgpu.wgmma(acc_ref, a_smem.at[wg_m_slice], b_smem.at[:, wg_n_slice]) + if pipeline_callback is not None: + (k_idx,) = idxs + pipeline_callback(m_idx, n_idx, k_idx, a_smem, b_smem) + plgpu.wgmma_wait(delay_release) + return acc_ref + + get_pipeline(mma_body, compute_context)( + a_gmem.at[cta_m_slice, :], + b_gmem.at[:, cta_n_slice], + allocations=pipeline_allocs, + ) + # Await all transfers before we exit. + plgpu.wait_smem_to_gmem(0, wait_read_only=True) + + +def matmul(a, b, config: TuningConfig): + dtype = a.dtype + if a.dtype != b.dtype: + raise ValueError( + f"Matmul LHS and RHS have incompatible dtypes {a.dtype} vs {b.dtype}" + ) + m, k = a.shape + k2, n = b.shape + assert k == k2 + if k != k2: + raise ValueError( + f"Matmul LHS and RHS have incompatible shapes {a.shape} vs {b.shape}" + ) + tile_m, tile_n = config.tile_m, config.tile_n + epi_tile_n = config.epi_tile_n or tile_n + epi_tile_m = config.epi_tile_m or tile_m + config = dataclasses.replace(config, epi_tile_n=epi_tile_n, epi_tile_m=epi_tile_m) + + num_sms = backend.get_default_device().core_count + cluster_size = 1 + (config.cluster_dimension is not None) + f = plgpu.kernel( + functools.partial(kernel, config=config), + out_shape=jax.ShapeDtypeStruct((m, n), dtype), + grid=(num_sms // cluster_size,), + grid_names=("cluster_grid",), + cluster=(cluster_size,), + cluster_names=("cluster",), + num_threads=3, + thread_name="wg", + ) + return f(a, b) + + +def main(_) -> None: + problem_it = [(4096, 8192, 4096)] + for M, N, K in problem_it: + print(f"==== {M=} {N=} {K=} ====") + matmul_flops = 2 * M * N * K + peak_flops = 990e12 # f16 TensorCore peak = 990 TFLOPS + a = jax.random.uniform(jax.random.key(0), (M, K), jnp.float16) + b = jax.random.uniform(jax.random.key(1), (K, N), jnp.float16) + ref = a @ b + tuning_it = itertools.product( + (128, 256,), # tile_m + (64, 128), # tile_n + (64,), # tile_k + (4,), # max_concurrent_steps + (True,), # Tiled epilogue + (MatmulDimension.M, MatmulDimension.N), # grid_minor_dim + (4, 8, 16), # grid_tile_width + MatmulDimension, # wg_dimension + # Consider adding MatmulDimension here to try out collective TMA kernels + (None,) # cluster_dimension + ) + best_util = 0.0 + best_runtime = float("inf") + for tile_m, tile_n, tile_k, max_concurrent_steps, tiled_epilogue, grid_minor_dim, grid_tile_width, wg_dimension, cluster_dimension in tuning_it: + config = TuningConfig( + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + max_concurrent_steps=max_concurrent_steps, + epi_tile_n=64 if tiled_epilogue else None, + epi_tile_m=64 if tiled_epilogue else None, + grid_minor_dim=grid_minor_dim, + grid_tile_width=grid_tile_width, + wg_dimension=wg_dimension, + cluster_dimension=cluster_dimension, + ) + try: + out, runtimes_ms = profiler.measure( + functools.partial(matmul, config=config), iterations=10, + )(a, b) + assert runtimes_ms is not None + runtime_ms = statistics.median(runtimes_ms) + except ValueError as e: + if "exceeds available shared memory" in e.args[0]: # Ignore SMEM OOMs. + continue + raise + np.testing.assert_allclose(out, ref) + runtime_us = runtime_ms * 1e3 # type: ignore + optimal_time = matmul_flops / peak_flops * 1e6 # us + achieved_tc_util = optimal_time / runtime_us * 100 + if achieved_tc_util > best_util: + best_runtime = runtime_us + best_util = achieved_tc_util + print( + f"{tile_m=} {tile_n=} {tile_k=} {max_concurrent_steps=} {tiled_epilogue=} {grid_minor_dim=} {grid_tile_width=} {wg_dimension=} {cluster_dimension=}:" + f" {runtime_us:<7.1f}us = {achieved_tc_util:4.1f}% TC utilization" + ) + print(f"\tBest: {best_runtime:<7.1f}us = {best_util:4.1f}% TC utilization") + + +if __name__ == "__main__": + from absl import app + + jax.config.config_with_absl() + app.run(main) diff --git a/jax/experimental/pallas/ops/gpu/hopper_mixed_type_matmul_mgpu.py b/jax/experimental/pallas/ops/gpu/hopper_mixed_type_matmul_mgpu.py new file mode 100644 index 000000000000..a032a9b49d72 --- /dev/null +++ b/jax/experimental/pallas/ops/gpu/hopper_mixed_type_matmul_mgpu.py @@ -0,0 +1,343 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Matrix Multiplication kernel for Hopper GPUs.""" +import statistics +import dataclasses +import enum +import functools +import itertools +import jax +from jax._src import dtypes +from jax import lax +from jax._src import test_util as jtu # noqa: F401 +from jax.experimental.mosaic.gpu import profiler +import jax.experimental.pallas as pl +import jax.experimental.pallas.mosaic_gpu as plgpu +from jax.extend import backend +import jax.numpy as jnp +import numpy as np + + +class MatmulDimension(enum.IntEnum): + M = 0 + N = 1 + + def __str__(self): + return self.name + + def __repr__(self): + return self.name + + +@dataclasses.dataclass(frozen=True) +class TuningConfig: + tile_m: int + tile_n: int + tile_k: int + max_concurrent_steps: int + epi_tile_n: int | None = 64 # This needs to be lowered for for small N. + epi_tile_m: int | None = 64 + grid_minor_dim: MatmulDimension = MatmulDimension.N + grid_tile_width: int = 1 + wg_dimension: MatmulDimension = MatmulDimension.N + cluster_dimension: None | MatmulDimension = None + + +def mixed_matmul_kernel( + a: jax.Array, b: jax.Array, *, out_dtype: jnp.dtype, config: TuningConfig +) -> jax.Array: + """Mixed-type matrix multiplication kernel for Hopper GPUs. + + Specifically, this kernel implements the function + (a.as_dtype(b.dtype) @ b).astype(out_dtype). + """ + if a.dtype == b.dtype: + raise ValueError( + f"Mixed matmul LHS and RHS have the same dtype {a.dtype}. For such " + "matrix multiplications, use the `hopper_matmul_mgpu` kernel instead." + ) + match (a.dtype, b.dtype): + case (jnp.int8, jnp.bfloat16): + pass + case (jnp.int8, jnp.float16): + pass + case _, _: + # We do support more combinations, but we haven't benchmarked them + # yet---so we raise for the time being. + raise NotImplementedError( + f"Unbenchmarked dtype combination: {a.dtype=} and {b.dtype=}" + ) + m, k = a.shape + k2, n = b.shape + if k != k2: + raise ValueError( + f"Matmul LHS and RHS have incompatible shapes {a.shape} vs {b.shape}" + ) + + tile_m, tile_n, tile_k = config.tile_m, config.tile_n, config.tile_k + epi_tile_n = config.epi_tile_n or tile_n + epi_tile_m = config.epi_tile_m or tile_m + if tile_n % epi_tile_n != 0: + raise ValueError(f"{tile_n=} must be divisible by {epi_tile_n=}") + if tile_m % epi_tile_m != 0: + raise ValueError(f"{tile_m=} must be divisible by {epi_tile_m=}") + + a_bits = dtypes.itemsize_bits(a.dtype) + b_bits = dtypes.itemsize_bits(b.dtype) + out_bits = dtypes.itemsize_bits(out_dtype) + + a_swizzle = plgpu.find_swizzle(tile_k * a_bits, "lhs") + b_swizzle = plgpu.find_swizzle(tile_n * b_bits, "rhs") + out_swizzle = plgpu.find_swizzle(epi_tile_n * out_bits, "out") + + a_transforms = ( + plgpu.TilingTransform((8, a_swizzle * 8 // a_bits)), + plgpu.SwizzleTransform(a_swizzle), + ) + b_transforms = ( + plgpu.TilingTransform((8, b_swizzle * 8 // b_bits)), + plgpu.SwizzleTransform(b_swizzle), + ) + out_transforms = ( + plgpu.TilingTransform((8, out_swizzle * 8 // out_bits)), + plgpu.SwizzleTransform(out_swizzle), + ) + + max_concurrent_steps = config.max_concurrent_steps + cta_tile_m = tile_m * (1 + (config.wg_dimension == MatmulDimension.M)) + cta_tile_n = tile_n * (1 + (config.wg_dimension == MatmulDimension.N)) + cluster_tile_m = cta_tile_m * (1 + (config.cluster_dimension == MatmulDimension.M)) + cluster_tile_n = cta_tile_n * (1 + (config.cluster_dimension == MatmulDimension.N)) + if m % cluster_tile_m != 0: + raise ValueError(f"{m=} must be divisible by {cluster_tile_m} for the given config") + if n % cluster_tile_n != 0: + raise ValueError(f"{n=} must be divisible by {cluster_tile_n} for the given config") + if k % tile_k != 0: + raise ValueError(f"{k=} must be divisible by {tile_k=}") + m_iters = m // cluster_tile_m + n_iters = n // cluster_tile_n + k_iters = k // tile_k + + def kernel(a_gmem, b_gmem, out_gmem, out_smem): + + def get_pipeline(pipeline_body, compute_context): + return plgpu.emit_pipeline_warp_specialized( + pipeline_body, + grid=(k_iters,), + memory_registers=40, + in_specs=[ + plgpu.BlockSpec( + (cta_tile_m, tile_k), + lambda k: (0, k), + transforms=a_transforms, + memory_space=plgpu.SMEM, + collective_axes=("cluster",) + if config.cluster_dimension == MatmulDimension.N + else (), + ), + plgpu.BlockSpec( + (tile_k, cta_tile_n), + lambda k: (k, 0), + transforms=b_transforms, + memory_space=plgpu.SMEM, + collective_axes=("cluster",) + if config.cluster_dimension == MatmulDimension.M + else (), + ), + ], + wg_axis="wg", + num_compute_wgs=2, + max_concurrent_steps=max_concurrent_steps, + compute_context=compute_context, + ) + + # Functions don't influence the allocations necessary to run the pipeline. + ignore = lambda *_, **__: None + @functools.partial( + pl.run_scoped, + pipeline_allocs=get_pipeline(ignore, ignore).get_allocations(a_gmem, b_gmem), + collective_axes="wg", + ) + def _pipeline_scope(pipeline_allocs): + wg_idx = lax.axis_index("wg") + cta_idx = lax.axis_index("cluster") + @plgpu.nd_loop((m_iters * n_iters,), collective_axes="cluster_grid") + def _mn_loop(loop_info: plgpu.NDLoopInfo): + (lin_idx,) = loop_info.index + m_cluster_idx, n_cluster_idx = plgpu.planar_snake( + lin_idx, + (m_iters, n_iters), + config.grid_minor_dim, + config.grid_tile_width, + ) + m_idx = m_cluster_idx + n_idx = n_cluster_idx + if config.cluster_dimension == MatmulDimension.M: + m_idx = m_cluster_idx * 2 + cta_idx + elif config.cluster_dimension == MatmulDimension.N: + n_idx = n_cluster_idx * 2 + cta_idx + cta_m_slice = pl.ds(m_idx * cta_tile_m, cta_tile_m) + cta_n_slice = pl.ds(n_idx * cta_tile_n, cta_tile_n) + if config.wg_dimension == MatmulDimension.M: + wg_m_slice = pl.ds(wg_idx * tile_m, tile_m) + wg_n_slice = slice(None) + else: + wg_m_slice = slice(None) + wg_n_slice = pl.ds(wg_idx * tile_n, tile_n) # type: ignore + + def compute_context(eval_pipeline): + @functools.partial( + pl.run_scoped, acc_ref=plgpu.ACC((tile_m, tile_n), jnp.float32) + ) + def _acc_scope(acc_ref): + eval_pipeline(acc_ref) + acc = acc_ref[...].astype(out_dtype) + plgpu.wait_smem_to_gmem(0, wait_read_only=True) + for epi_mi in range(tile_m // epi_tile_m): + for epi_ni in range(tile_n // epi_tile_n): + epi_m_slice = slice(epi_mi * epi_tile_m, (epi_mi + 1) * epi_tile_m) + epi_n_slice = slice(epi_ni * epi_tile_n, (epi_ni + 1) * epi_tile_n) + slot = (epi_mi * (tile_n // epi_tile_n) + epi_ni) % 2 + plgpu.wait_smem_to_gmem(1, wait_read_only=True) + out_smem[wg_idx, slot] = acc[epi_m_slice, epi_n_slice] + plgpu.commit_smem() + plgpu.copy_smem_to_gmem( + out_smem.at[wg_idx, slot], + out_gmem.at[cta_m_slice, cta_n_slice] + .at[wg_m_slice, wg_n_slice] + .at[epi_m_slice, epi_n_slice], + ) + + def mma_body(_, a_smem, b_smem, acc_ref): + with jax.named_scope("smem_load"): + a_reg = a_smem[wg_m_slice] + with jax.named_scope("dequant"): + a_reg = a_reg.astype(b.dtype) + with jax.named_scope("wgmma"): + plgpu.wgmma(acc_ref, a_reg, b_smem.at[:, wg_n_slice]) + with jax.named_scope("wgmma_wait"): + plgpu.wgmma_wait(0) + return acc_ref + + get_pipeline(mma_body, compute_context)( + a_gmem.at[cta_m_slice, :], + b_gmem.at[:, cta_n_slice], + allocations=pipeline_allocs, + ) + # Await all transfers before we exit. + plgpu.wait_smem_to_gmem(0, wait_read_only=True) + + # We don't need multiple slots if there's only one epilogue tile. + num_out_slots = min(2, (tile_m * tile_n) // (epi_tile_m * epi_tile_n)) + num_sms = backend.get_default_device().core_count + cluster_size = 1 + (config.cluster_dimension is not None) + f = plgpu.kernel( + kernel, + out_shape=jax.ShapeDtypeStruct((m, n), out_dtype), + grid=(num_sms // cluster_size,), + grid_names=("cluster_grid",), + cluster=(cluster_size,), + cluster_names=("cluster",), + num_threads=3, + thread_name="wg", + scratch_shapes=dict( + out_smem=plgpu.SMEM( + (2, num_out_slots, epi_tile_m, epi_tile_n), + out_dtype, + transforms=out_transforms, + ) + ), + ) + return f(a, b) + + +def reference( + a: jax.Array, b: jax.Array, *, out_dtype: jnp.dtype +) -> jax.Array: + """Reference implementation of a mixed-type matrix multiplication.""" + return jax.numpy.dot(a, b, preferred_element_type=jnp.float32).astype( + out_dtype + ) + + +def main(_) -> None: + problem_it = [(4096, 8192, 4096)] + for M, N, K in problem_it: + print(f"==== {M=} {N=} {K=} ====") + matmul_flops = 2 * M * N * K + peak_flops = 990e12 # f16 TensorCore peak = 990 TFLOPS + a = jax.random.randint( + jax.random.key(0), minval=-128, maxval=127, shape=(M, K), dtype=jnp.int8 + ) + b = jax.random.uniform(jax.random.key(1), (K, N), jnp.bfloat16) + ref = reference(a, b, out_dtype=jnp.bfloat16) + tuning_it = itertools.product( + (64, 128, 256,), # tile_m + (64, 128), # tile_n + (64, 128), # tile_k + (4,), # max_concurrent_steps + (True,), # Tiled epilogue + (MatmulDimension.M, MatmulDimension.N), # grid_minor_dim + (4, 8, 16), # grid_tile_width + MatmulDimension, # wg_dimension + # Consider adding MatmulDimension here to try out collective TMA kernels + (None,) # cluster_dimension + ) + best_util = 0.0 + best_runtime = float("inf") + for tile_m, tile_n, tile_k, max_concurrent_steps, tiled_epilogue, grid_minor_dim, grid_tile_width, wg_dimension, cluster_dimension in tuning_it: + config = TuningConfig( + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + max_concurrent_steps=max_concurrent_steps, + epi_tile_n=64 if tiled_epilogue else None, + epi_tile_m=64 if tiled_epilogue else None, + grid_minor_dim=grid_minor_dim, + grid_tile_width=grid_tile_width, + wg_dimension=wg_dimension, + cluster_dimension=cluster_dimension, + ) + try: + out, runtimes_ms = profiler.measure( + functools.partial( + mixed_matmul_kernel, out_dtype=jnp.bfloat16, config=config + ), + iterations=10, + )(a, b) + assert runtimes_ms is not None + runtime_ms = statistics.median(runtimes_ms) + except ValueError as e: + if "exceeds available shared memory" in e.args[0]: # Ignore SMEM OOMs. + continue + raise + np.testing.assert_allclose(out, ref) + runtime_us = runtime_ms * 1e3 # type: ignore + optimal_time = matmul_flops / peak_flops * 1e6 # us + achieved_tc_util = optimal_time / runtime_us * 100 + if achieved_tc_util > best_util: + best_runtime = runtime_us + best_util = achieved_tc_util + print( + f"{tile_m=} {tile_n=} {tile_k=} {max_concurrent_steps=} {tiled_epilogue=} {grid_minor_dim=} {grid_tile_width=} {wg_dimension=} {cluster_dimension=}:" + f" {runtime_us:<7.1f}us = {achieved_tc_util:4.1f}% TC utilization" + ) + print(f"\tBest: {best_runtime:<7.1f}us = {best_util:4.1f}% TC utilization") + + +if __name__ == "__main__": + from absl import app + + jax.config.config_with_absl() + app.run(main) diff --git a/jax/experimental/pallas/ops/gpu/layer_norm.py b/jax/experimental/pallas/ops/gpu/layer_norm.py index d37afaf4d9e0..bb0d5757b751 100644 --- a/jax/experimental/pallas/ops/gpu/layer_norm.py +++ b/jax/experimental/pallas/ops/gpu/layer_norm.py @@ -21,7 +21,6 @@ import jax from jax import lax import jax.numpy as jnp -from jax._src.lax.control_flow.for_loop import for_loop from jax.experimental import pallas as pl from jax.experimental.pallas import triton as plgpu @@ -32,40 +31,55 @@ def layer_norm_forward_kernel( *, eps: float, block_size: int): n_col = x_ref.shape[0] - def mean_body(i, acc_ref): + def mean_body(i, acc): col_idx = i * block_size + jnp.arange(block_size) mask = col_idx < n_col - a = pl.load(x_ref, (col_idx,), mask=mask, other=0., - eviction_policy="evict_last").astype(jnp.float32) - acc_ref[:] += a - mean = for_loop(pl.cdiv(n_col, block_size), mean_body, - jnp.zeros(block_size)).sum() / n_col + a = plgpu.load( + x_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_last" + ).astype(jnp.float32) + return acc + a + + mean = lax.fori_loop( + 0, + pl.cdiv(n_col, block_size), + mean_body, + init_val=jnp.zeros(block_size), + ).sum() + mean /= n_col - def var_body(i, acc_ref): + def var_body(i, acc): col_idx = i * block_size + jnp.arange(block_size) mask = col_idx < n_col - a = pl.load(x_ref, (col_idx,), mask=mask, other=0., - eviction_policy="evict_last").astype(jnp.float32) + a = plgpu.load( + x_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_last" + ).astype(jnp.float32) a = jnp.where(mask, a - mean, 0.) - acc_ref[:] += a * a - var = for_loop(pl.cdiv(n_col, block_size), var_body, - jnp.zeros(block_size)).sum() / n_col + return acc + a * a + + var = lax.fori_loop( + 0, + pl.cdiv(n_col, block_size), + var_body, + init_val=jnp.zeros(block_size), + ).sum() + var /= n_col rstd = 1 / jnp.sqrt(var + eps) if mean_ref is not None: mean_ref[...] = mean.astype(mean_ref.dtype) if rstd_ref is not None: rstd_ref[...] = rstd.astype(rstd_ref.dtype) - def body(i, _): + @pl.loop(0, pl.cdiv(n_col, block_size)) + def body(i): col_idx = i * block_size + jnp.arange(block_size) mask = col_idx < n_col - weight = pl.load(weight_ref, (col_idx,), mask=mask) - bias = pl.load(bias_ref, (col_idx,), mask=mask) - x = pl.load(x_ref, (col_idx,), mask=mask, other=0., - eviction_policy="evict_first").astype(jnp.float32) + weight = plgpu.load(weight_ref.at[col_idx], mask=mask) + bias = plgpu.load(bias_ref.at[col_idx], mask=mask) + x = plgpu.load( + x_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_first" + ).astype(jnp.float32) out = (x - mean) * rstd * weight + bias - pl.store(o_ref, (col_idx,), out.astype(o_ref.dtype), mask=mask) - for_loop(pl.cdiv(n_col, block_size), body, ()) + plgpu.store(o_ref.at[col_idx], out.astype(o_ref.dtype), mask=mask) def layer_norm_forward( @@ -94,7 +108,7 @@ def layer_norm_forward( ] method = pl.pallas_call( kernel, - compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps), + compiler_params=plgpu.CompilerParams(num_warps=num_warps), grid=(), out_shape=out_shape, debug=False, @@ -116,40 +130,54 @@ def layer_norm_backward_kernel_dx( *, eps: float, block_size: int): n_col = x_ref.shape[0] - def mean_body(i, acc_ref): + def mean_body(i, acc): col_idx = i * block_size + jnp.arange(block_size) mask = col_idx < n_col - a = pl.load(x_ref, (col_idx,), mask=mask, other=0., - eviction_policy="evict_last").astype(jnp.float32) - dout = pl.load(do_ref, (col_idx,), mask=mask, other=0., - eviction_policy="evict_last").astype(jnp.float32) - weight = pl.load(weight_ref, (col_idx,), mask=mask, other=0., - eviction_policy="evict_last").astype(jnp.float32) + a = plgpu.load( + x_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_last" + ).astype(jnp.float32) + dout = plgpu.load( + do_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_last" + ).astype(jnp.float32) + weight = plgpu.load( + weight_ref.at[col_idx], + mask=mask, + other=0.0, + eviction_policy="evict_last", + ).astype(jnp.float32) a_hat = (a - mean_ref[...]) * rstd_ref[...] wdout = weight * dout - mean1_acc_ref, mean2_acc_ref = acc_ref - mean1_acc_ref[:] += a_hat * wdout - mean2_acc_ref[:] += wdout - mean = for_loop(pl.cdiv(n_col, block_size), mean_body, - (jnp.zeros(block_size), jnp.zeros(block_size))) - mean1, mean2 = mean + mean1_acc, mean2_acc = acc + return mean1_acc + a_hat * wdout, mean2_acc + wdout + mean1, mean2 = lax.fori_loop( + 0, + pl.cdiv(n_col, block_size), + mean_body, + init_val=(jnp.zeros(block_size), jnp.zeros(block_size)), + ) mean1 = mean1.sum() / n_col mean2 = mean2.sum() / n_col - def dx_body(i, acc_ref): + @pl.loop(0, pl.cdiv(n_col, block_size)) + def dx_body(i): col_idx = i * block_size + jnp.arange(block_size) mask = col_idx < n_col - a = pl.load(x_ref, (col_idx,), mask=mask, other=0., - eviction_policy="evict_last").astype(jnp.float32) - dout = pl.load(do_ref, (col_idx,), mask=mask, other=0., - eviction_policy="evict_last").astype(jnp.float32) - weight = pl.load(weight_ref, (col_idx,), mask=mask, other=0., - eviction_policy="evict_last").astype(jnp.float32) + a = plgpu.load( + x_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_last" + ).astype(jnp.float32) + dout = plgpu.load( + do_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_last" + ).astype(jnp.float32) + weight = plgpu.load( + weight_ref.at[col_idx], + mask=mask, + other=0.0, + eviction_policy="evict_last", + ).astype(jnp.float32) a_hat = (a - mean_ref[...]) * rstd_ref[...] wdout = weight * dout da = (wdout - (a_hat * mean1 + mean2)) * rstd_ref[...] - pl.store(dx_ref, (col_idx,), da.astype(dx_ref.dtype), mask=mask) - for_loop(pl.cdiv(n_col, block_size), dx_body, ()) + plgpu.store(dx_ref.at[col_idx], da.astype(dx_ref.dtype), mask=mask) def layer_norm_backward_kernel_dw_db( @@ -164,25 +192,36 @@ def layer_norm_backward_kernel_dw_db( col_idx = j * block_n + jnp.arange(block_n) col_mask = col_idx < n_col - def body(i, acc_ref): + def body(i, acc): row_idx = i * block_m + jnp.arange(block_m) row_mask = row_idx < m mask = row_mask[:, None] & col_mask[None, :] - a = pl.load( - x_ref, (row_idx[:, None], col_idx[None]), mask=mask, other=0.0 + a = plgpu.load( + x_ref.at[row_idx[:, None], col_idx[None]], mask=mask, other=0.0 ).astype(jnp.float32) - dout = pl.load( - do_ref, (row_idx[:, None], col_idx[None]), mask=mask, other=0.0 + dout = plgpu.load( + do_ref.at[row_idx[:, None], col_idx[None]], mask=mask, other=0.0 ).astype(jnp.float32) - mean = pl.load(mean_ref, (row_idx,), mask=row_mask, other=0.).astype(jnp.float32) - rstd = pl.load(rstd_ref, (row_idx,), mask=row_mask, other=0.).astype(jnp.float32) + mean = plgpu.load(mean_ref.at[row_idx], mask=row_mask, other=0.0).astype( + jnp.float32 + ) + rstd = plgpu.load(rstd_ref.at[row_idx], mask=row_mask, other=0.0).astype( + jnp.float32 + ) a_hat = (a - mean[:, None]) * rstd[:, None] - dw_acc_ref, db_acc_ref = acc_ref - dw_acc_ref[:] += (dout * a_hat).sum(axis=0) - db_acc_ref[:] += dout.sum(axis=0) - dw_acc, db_acc = for_loop(pl.cdiv(m, block_m), body, (jnp.zeros(block_n), jnp.zeros(block_n))) - pl.store(dw_ref, (col_idx,), dw_acc.astype(dw_ref.dtype), mask=col_mask) - pl.store(db_ref, (col_idx,), db_acc.astype(db_ref.dtype), mask=col_mask) + dw_acc_ref, db_acc_ref = acc + return dw_acc_ref + (dout * a_hat).sum(axis=0), db_acc_ref + dout.sum( + axis=0 + ) + + dw_acc, db_acc = lax.fori_loop( + 0, + pl.cdiv(m, block_m), + body, + init_val=(jnp.zeros(block_n), jnp.zeros(block_n)), + ) + plgpu.store(dw_ref.at[col_idx], dw_acc.astype(dw_ref.dtype), mask=col_mask) + plgpu.store(db_ref.at[col_idx], db_acc.astype(db_ref.dtype), mask=col_mask) def layer_norm_backward( @@ -215,7 +254,7 @@ def layer_norm_backward( out_shape_dx = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype) method = pl.pallas_call( kernel, - compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps), + compiler_params=plgpu.CompilerParams(num_warps=num_warps), grid=(), out_shape=out_shape_dx, debug=False, @@ -247,7 +286,7 @@ def layer_norm_backward( grid_ = (pl.cdiv(reshaped_x.shape[1], block_n),) method = pl.pallas_call( kernel, - compiler_params=dict(triton=dict(num_warps=num_warps)), + compiler_params=plgpu.CompilerParams(num_warps=num_warps), grid=grid_, out_shape=out_shape_dwbias, debug=False, @@ -283,7 +322,7 @@ def layer_norm( out_shape = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype) method = pl.pallas_call( kernel, - compiler_params=plgpu.TritonCompilerParams( + compiler_params=plgpu.CompilerParams( num_warps=num_warps, num_stages=num_stages), grid=(), out_shape=out_shape, diff --git a/jax/experimental/pallas/ops/gpu/paged_attention.py b/jax/experimental/pallas/ops/gpu/paged_attention.py index b30ef554fe12..1e6e19719253 100644 --- a/jax/experimental/pallas/ops/gpu/paged_attention.py +++ b/jax/experimental/pallas/ops/gpu/paged_attention.py @@ -33,7 +33,9 @@ def paged_attention_kernel( # inputs q_ref, # [block_h, head_dim] k_pages_ref, # [total_num_pages, page_size, head_dim] + k_scales_pages_ref, # [total_num_pages, page_size] v_pages_ref, # [total_num_pages, page_size, head_dim] + v_scales_pages_ref, # [total_num_pages, page_size] block_tables_ref, # [pages_per_partition] lengths_ref, # [1] # outputs @@ -52,7 +54,7 @@ def paged_attention_kernel( def _compute(start_page_idx, end_page_idx, o, m_i, l_i): q_slice = pl.ds(0, block_h) - q = pl.load(q_ref, (q_slice, slice(None))) + q = q_ref[q_slice, :] # Loop over blocks of pages to process a entire page sequence partition. # Grid loops over q blocks over num_heads. @@ -62,10 +64,19 @@ def body(start_k, carry): block_tables_slice = pl.ds( start_k * pages_per_compute_block, pages_per_compute_block ) - block_tables = pl.load(block_tables_ref, block_tables_slice) + block_tables = block_tables_ref[block_tables_slice] k = k_pages_ref[block_tables].reshape(block_k, head_dim) v = v_pages_ref[block_tables].reshape(block_k, head_dim) + if k_scales_pages_ref is not None: + # dynamic lhs quantized dot is not currently implemented + # so we cast rhs to the lhs dtype + k = k.astype(q.dtype) uncapped_logits = pl.dot(q, k.T) # [block_h, block_k] + if k_scales_pages_ref is not None: + # k_scales_pages_ref are one per head + # they're laid out across the output dimension, so scale output + k_scale = k_scales_pages_ref[block_tables].reshape((1, block_k)) + uncapped_logits *= k_scale.astype(uncapped_logits.dtype) if attn_logits_soft_cap is not None: logits = jnp.tanh(uncapped_logits / attn_logits_soft_cap) logits = logits * attn_logits_soft_cap @@ -92,6 +103,14 @@ def body(start_k, carry): l_curr = s_curr.sum(axis=-1) l_next = l_prev_corr + l_curr o_prev_corr = correction[:, None] * o_prev + if v_scales_pages_ref is not None: + # v_scales are 1 per head + # they're laid out across the reduction dimension, so scale lhs + v_scale = v_scales_pages_ref[block_tables].reshape((1, block_k)) + s_curr *= v_scale.astype(s_curr.dtype) + # dynamic lhs quantized dot is not currently implemented + # so we cast rhs to the lhs dtype + v = v.astype(s_curr.dtype) o_curr = pl.dot(s_curr.astype(v.dtype), v) o_next = o_prev_corr + o_curr @@ -134,6 +153,8 @@ def paged_attention_unbatched( v_pages: jax.Array, # [num_kv_heads, total_num_pages, page_size, head_dim] block_tables: jax.Array, # [pages_per_sequence] lengths: jax.Array | None, # [1] + k_scales_pages: jax.Array | None = None, # [num_kv_heads, total_num_pages, page_size] + v_scales_pages: jax.Array | None = None, # [num_kv_heads, total_num_pages, page_size] *, block_h: int, pages_per_compute_block: int, @@ -179,6 +200,19 @@ def paged_attention_unbatched( mask_value=mask_value, attn_logits_soft_cap=attn_logits_soft_cap, ) + # set up quantization scales + if k_scales_pages is not None: + assert k_scales_pages.shape == (num_kv_heads, total_num_pages, page_size) + k_scales_spec = pl.BlockSpec((None, total_num_pages, page_size), + lambda h, i, k: (h, 0, 0)) + else: + k_scales_spec = None + if v_scales_pages is not None: + assert v_scales_pages.shape == (num_kv_heads, total_num_pages, page_size) + v_scales_spec = pl.BlockSpec((None, total_num_pages, page_size), + lambda h, i, k: (h, 0, 0)) + else: + v_scales_spec = None o, l, m = pl.pallas_call( kernel, @@ -191,10 +225,12 @@ def paged_attention_unbatched( (None, total_num_pages, page_size, head_dim), lambda h, i, k: (h, 0, 0, 0), ), # k_pages + k_scales_spec, # k_pages_scale pl.BlockSpec( (None, total_num_pages, page_size, head_dim), lambda h, i, k: (h, 0, 0, 0), ), # v_pages + v_scales_spec, # v_pages_scale pl.BlockSpec( (None, pages_per_partition), lambda h, i, k: (k, 0) ), # block_tables @@ -222,11 +258,11 @@ def paged_attention_unbatched( ], debug=debug, interpret=interpret, - compiler_params=plgpu.TritonCompilerParams( + compiler_params=plgpu.CompilerParams( num_warps=num_warps, num_stages=num_stages ), name=f"paged_attention_{block_h=}_{pages_per_compute_block=}", - )(q_reshaped, k_pages, v_pages, block_tables, lengths) + )(q_reshaped, k_pages, k_scales_pages, v_pages, v_scales_pages, block_tables, lengths) if q_heads_per_kv_head % block_h: o = o[..., :q_heads_per_kv_head, :] @@ -265,6 +301,8 @@ def paged_attention( v_pages: jax.Array, block_tables: jax.Array, lengths: jax.Array | None, + k_scales_pages: jax.Array | None = None, + v_scales_pages: jax.Array | None = None, *, block_h: int = 16, pages_per_compute_block: int = 8, @@ -286,6 +324,8 @@ def paged_attention( should be in the range of [0, total_num_pages), indicating where to locate the page in `k_pages` or `v_pages`. lengths: A i32[batch_size] jax.Array the length of each example. + k_scales_pages: A [num_kv_heads, total_num_pages, page_size] jax.Array. + v_scales_pages: A [num_kv_heads, total_num_pages, page_size] jax.Array. block_h: int The block size that partitions the number of head groups. pages_per_compute_block: int The maximum number of blocks per compute block. k_splits: int Number of partitions used to parallelize key-value sequence @@ -342,12 +382,14 @@ def paged_attention( attn_logits_soft_cap=attn_logits_soft_cap, ) - o = jax.vmap(impl, (0, None, None, 0, 0), 0)( + o = jax.vmap(impl, (0, None, None, 0, 0, None, None), 0)( q, k_pages, v_pages, block_tables, lengths[..., None] if lengths is not None else None, + k_scales_pages, + v_scales_pages, ) return o diff --git a/jax/experimental/pallas/ops/gpu/ragged_dot_mgpu.py b/jax/experimental/pallas/ops/gpu/ragged_dot_mgpu.py new file mode 100644 index 000000000000..bbce3c32ee5d --- /dev/null +++ b/jax/experimental/pallas/ops/gpu/ragged_dot_mgpu.py @@ -0,0 +1,333 @@ +# Copyright 2025 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Ragged dot Pallas-Mosaic-GPU implementation.""" + +import dataclasses +import functools +import itertools +import math +import jax +from jax import lax +from jax import numpy as jnp +from jax import random +from jax._src import test_util as jtu # noqa: F401 +from jax.experimental import pallas as pl +from jax.experimental.mosaic.gpu import profiler +from jax.experimental.pallas import mosaic_gpu as plgpu +import numpy as np + + +@dataclasses.dataclass(frozen=True) +class GroupInfo: + """Information regarding the group being processed in a block.""" + + group_id: jax.Array + block: jax.Array + block_start: jax.Array + actual_start: jax.Array + actual_end: jax.Array + start_within_block: jax.Array + actual_size: jax.Array + + @classmethod + def create(cls, group_lengths, tile, tid): + """Get the group info for the current block.""" + + tile = jnp.int32(tile) + group_boundaries = [group_lengths[i] for i in range(len(group_lengths))] + + # We usually only have very few groups, so we unroll the loop processing + # them. Normally we'd break out of the loop early, once we'd have found our + # boundary, but we can't do that when unrolling, so we rely on many selects + # to mask out the epilogue of the loop. + group_end = group_start = block = group = end = jnp.array( + 0, dtype=jnp.int32 + ) + + for i, b in enumerate(group_boundaries): + # Start/end are inclusive + start = end + end = start + b + final = end - 1 + start_block = lax.div(start, tile) + final_block = lax.div(final, tile) + block_end = final_block + 1 + tid_begin = start_block + i + tid_end = block_end + i + # How many blocks after is our block? + this_is_group = (tid_begin <= tid) & (tid < tid_end) + block = lax.select(this_is_group, tid - tid_begin + start_block, block) + group = lax.select(this_is_group, jnp.int32(i), group) + group_start = lax.select(this_is_group, start, group_start) + group_end = lax.select(this_is_group, end, group_end) + + block_start = block * tile + actual_start = jnp.maximum(group_start, block_start) + actual_end = jnp.minimum(group_end, block_start + tile) + start_within_block = actual_start - block_start + actual_size = actual_end - actual_start + return cls( + group_id=group, + block=block, + block_start=block_start, + actual_start=actual_start, + actual_end=actual_end, + start_within_block=start_within_block, + actual_size=actual_size, + ) + + +def ragged_dot( + lhs, # (M, K) + rhs, # (G, K, N) + *, + group_sizes, # (G,) + block_m: int, + block_n: int, + block_k: int, + max_concurrent_steps: int, + grid_block_n: int, + transpose_rhs: bool = False, + load_group_sizes_to_register: bool = True, +) -> jax.Array: + if lhs.dtype != rhs.dtype: + raise NotImplementedError( + f"lhs and rhs must have the same dtype, got {lhs.dtype} and {rhs.dtype}" + ) + m, k = lhs.shape + g, k2, n = rhs.shape + + if transpose_rhs: + k2, n = n, k2 + + if group_sizes.shape[0] != g: + raise ValueError( + f"Expected group_sizes to have shape {g} but got {group_sizes.shape}" + ) + + if k != k2: + raise ValueError(f"lhs.shape={k} must match rhs.shape={k2}") + + if k % block_k != 0: + raise ValueError(f"k={k} must be a multiple of block_k={block_k}") + + def body(rows_per_expert_gmem, lhs_gmem, rhs_gmem, o_gmem): + grid_m = pl.cdiv(m, block_m) + g - 1 + grid_n = pl.cdiv(n, block_n) + grid = (grid_m * grid_n,) + if load_group_sizes_to_register: + rows_per_expert = [rows_per_expert_gmem[i] for i in range(len(rows_per_expert_gmem))] + else: + rows_per_expert = rows_per_expert_gmem + + @plgpu.nd_loop(grid, collective_axes="sm") + def mn_loop(loop_info: plgpu.NDLoopInfo): # pylint: disable=unused-variable + mi, ni = plgpu.planar_snake( + loop_info.index[0], + (grid_m, grid_n), + 1, + grid_block_n, + ) + group_info = GroupInfo.create(rows_per_expert_gmem, block_m, mi) + + def acc_scope(acc_ref): + plgpu.emit_pipeline( + lambda _, lhs_smem, rhs_smem: plgpu.wgmma( + acc_ref, + lhs_smem, + plgpu.transpose_ref(rhs_smem, (1, 0)) if transpose_rhs else rhs_smem, + ), + grid=(k // block_k,), + in_specs=[ + plgpu.BlockSpec( + (block_m, block_k), + lambda k: (group_info.block, k), + delay_release=1, + ), + plgpu.BlockSpec( + (block_n, block_k) if transpose_rhs else (block_k, block_n), + lambda k: (ni, k) if transpose_rhs else (k, ni), + delay_release=1, + ), + ], + max_concurrent_steps=max_concurrent_steps, + )(lhs_gmem, rhs_gmem.at[group_info.group_id]) + return acc_ref[...] + + acc = pl.run_scoped(acc_scope, plgpu.ACC((block_m, block_n))) + + @functools.partial( + pl.run_scoped, + o_smem=plgpu.SMEM((block_m, block_n), dtype=o_gmem.dtype) + ) + def store_scope(o_smem): # pylint: disable=unused-variable + o_smem[...] = acc.astype(o_smem.dtype) + plgpu.commit_smem() + + smem_start = group_info.start_within_block + remaining_rows = min(block_m, m) + # TMA descriptors need to be generated with static tile sizes along each + # axis, but we do not know at compile time how many rows we will need to + # store. We only know that the number of rows to store is bounded by + # min(block_m, m). + # + # In order to work around that, we construct a logarithmic ladder of + # TMA descriptors, where each descriptor can store 2**i rows for some + # i between 0 and log2(min(block_m, m)). This allows storing any + # number of rows we will need to store, so long as this number of rows + # is between `1` and `min(block_m, m)`. + # + # E.g., imagine we have block_m = 8, m = 16. The loop below will be + # unrolled into 4 iterations, where the first one will generate a TMA + # descriptor that can store 8 rows, the second one will generate a TMA + # descriptor that can store 4 rows, etc. all the way to 1 row. + # + # At run time, we finally know the actual number of rows we need to + # store as we go through the unrolled loop iterations. Let's imagine + # that we need to store 5 rows. + # + # The first unrolled iteration will check whether we can store 8 rows. + # Since we only need to store 5 rows, we won't store anything then. + # + # The second unrolled iteration will check whether we can store 4 rows. + # We're able to store 4 rows, and are left with a single remaining row. + # + # The fourth unrolled iteration will store the single remaining row, and + # we end up with a storing scheme as follows for our 5 rows: + # + # ----------------------------------------------------------- + # 0 | | + # 1 | | + # 2 | Store 4 rows | + # 3 | | + # ----------------------------------------------------------- + # 4 | Store 1 row | + # ----------------------------------------------------------- + while remaining_rows > 0: + const_rows_len = 1 << int(math.log2(remaining_rows)) + remaining_rows //= 2 + + @pl.when(group_info.actual_size & const_rows_len != 0) + def _(): + o_smem_slice = o_smem.at[pl.ds(smem_start, const_rows_len)] + o_gref_slice = o_gmem.at[ + pl.ds(group_info.block_start + smem_start, const_rows_len), + pl.ds(ni * block_n, block_n), + ] + plgpu.copy_smem_to_gmem(o_smem_slice, o_gref_slice) + + smem_start += group_info.actual_size & const_rows_len + plgpu.wait_smem_to_gmem(0, wait_read_only=True) + + # There are 132 SMs on a H100 SXM GPU. + num_sms = 132 + kernel = plgpu.kernel( + body, + out_shape=jax.ShapeDtypeStruct((m, n), lhs.dtype), + grid=(num_sms,), + grid_names=("sm",), + compiler_params=plgpu.CompilerParams( + lowering_semantics=plgpu.LoweringSemantics.Warpgroup, + ), + ) + return kernel(group_sizes, lhs, rhs) + + +def main(unused_argv): + for transpose_rhs in [False, True]: + m, k, n, num_groups = 16 * 1024, 2048, 16 * 1024, 16 + kx, ky, kz = random.split(random.key(1234), num=3) + + lhs = jax.random.normal(kx, (m, k), jnp.float16) + if transpose_rhs: + rhs = jax.random.normal(ky, (num_groups, n, k), jnp.float16) + else: + rhs = jax.random.normal(ky, (num_groups, k, n), jnp.float16) + group_boundaries = jax.lax.sort( + jax.random.randint(kz, (num_groups - 1,), 0, m, jnp.int32) + ) + group_starts = lax.concatenate( + [jnp.array([0], dtype=jnp.int32), group_boundaries], 0 + ) + group_ends = lax.concatenate( + [group_boundaries, jnp.array([m], dtype=jnp.int32)], 0 + ) + group_sizes = group_ends - group_starts + assert group_sizes.shape == (num_groups,) + + block_m = block_n = (64, 128, 192) + block_k = (64,) + max_concurrent_steps = (2, 4, 5, 6) + grid_block_n = (1, 2, 4, 8, 16) + configs = itertools.product( + block_m, block_n, block_k, max_concurrent_steps, grid_block_n + ) + names = ( + "block_m", "block_n", "block_k", "max_concurrent_steps", "grid_block_n" + ) + best_runtime = float("inf") + best_kwargs = {} + for config in configs: + kwargs = dict(zip(names, config)) + if n % (kwargs["grid_block_n"] * kwargs["block_n"]): + continue + try: + f = functools.partial( + ragged_dot, group_sizes=group_sizes, transpose_rhs=transpose_rhs, + **kwargs + ) + _, runtime = profiler.measure(f)(lhs, rhs) + except ValueError as e: + if "Mosaic GPU kernel exceeds available shared memory" not in str(e): + raise + runtime = float("inf") + # Enable this to get more detailed information. + else: + print(" ".join(f"{k}={v}" for k, v in kwargs.items()), int(runtime * 1000)) + if runtime < best_runtime: # pytype: disable=unsupported-operands + best_runtime = runtime + best_kwargs = kwargs + if not best_kwargs: + raise ValueError("No valid configuration found") + + def ref_ragged_dot(lhs, rhs, group_sizes): + if transpose_rhs: + rhs = jnp.transpose(rhs, (0, 2, 1)) + return jax.lax.ragged_dot(lhs, rhs, group_sizes=group_sizes) + + ref, ref_runtime = profiler.measure(ref_ragged_dot)( + lhs, rhs, group_sizes=group_sizes + ) + result = ragged_dot( + lhs, rhs, group_sizes=group_sizes, transpose_rhs=transpose_rhs, + **best_kwargs + ) + np.testing.assert_allclose(result, ref, atol=1e-3, rtol=1e-3) + + tflops = float(2 * k * m * n) / (best_runtime / 1e3) / 1e12 + ref_tflops = float(2 * k * m * n) / (ref_runtime / 1e3) / 1e12 + print(f"Transpose RHS: {transpose_rhs}") + print( + "Best parameters: ", " ".join(f"{k}={v}" for k, v in best_kwargs.items()) + ) + print(f"Kernel: {best_runtime * 1000:.1f} us = {tflops:.1f} TFLOPS") + print(f"Reference: {ref_runtime * 1000:.1f} us = {ref_tflops:.1f} TFLOPS") + + +if __name__ == "__main__": + from absl import app + + jax.config.config_with_absl() + app.run(main) diff --git a/jax/experimental/pallas/ops/gpu/reduce_scatter_mgpu.py b/jax/experimental/pallas/ops/gpu/reduce_scatter_mgpu.py new file mode 100644 index 000000000000..28f6eed91324 --- /dev/null +++ b/jax/experimental/pallas/ops/gpu/reduce_scatter_mgpu.py @@ -0,0 +1,248 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Reduce scatter kernel implemented using Mosaic GPU.""" + +import functools +import itertools +import math +from typing import Literal + +import jax +from jax import lax +from jax.experimental import multihost_utils +from jax.experimental import pallas as pl +from jax.experimental.mosaic.gpu import profiler +from jax.experimental.pallas import mosaic_gpu as plgpu +from jax.extend import backend +import jax.numpy as jnp + + +def reduce_scatter( + x: jax.Array, + *, + axis_name, + scatter_dimension: int | None = 0, + reduction: Literal["add", "min", "max", "and", "or", "xor"] = "add", + num_blocks: int | None = None, + tile_size: int | None = None, + vec_size: int | None = None, +) -> jax.Array: + """Performs a reduce-scatter or all-reduce operation across devices using multimem instructions. + + Args: + x: Input array. Should be sharded across the specified axis. + axis_name: Name of the mesh axis to reduce-scatter across. + scatter_dimension: Axis along which to reduce-scatter. If None, performs + all-reduce instead. Defaults to 0. + reduction: Reduction operation to perform. Supported: "add", "min", "max", + "and", "or", "xor". + vec_size: Vector size for the layout. If None, automatically inferred from dtype. + num_blocks: Number of blocks to use. Defaults to the device core count. + tile_size: Total tile size to split across major, scatter, and minor dimensions. + """ + num_devices = lax.axis_size(axis_name) + input_shape = x.shape + dtype = x.dtype + ndim = len(input_shape) + + if num_blocks is None: + num_blocks = backend.get_default_device().core_count + + if scatter_dimension is None: + major_dims, scatter_dim, minor_dims = 1, math.prod(input_shape), 1 + output_scatter_dim = scatter_dim + output_shape = input_shape + else: + if scatter_dimension < -ndim or scatter_dimension >= ndim: + raise ValueError( + f"scatter_dimension {scatter_dimension} out of bounds for array of" + f" dimension {ndim}" + ) + if scatter_dimension < 0: + scatter_dimension += ndim + + scatter_dim = input_shape[scatter_dimension] + if scatter_dim % num_devices != 0: + raise ValueError( + f"Scattered dimension {scatter_dimension} of input ({scatter_dim})" + f" must be divisible by number of devices ({num_devices})" + ) + + major_dims = math.prod(input_shape[:scatter_dimension]) + minor_dims = math.prod(input_shape[scatter_dimension+1:]) + output_scatter_dim = scatter_dim // num_devices + output_shape = ( + *input_shape[:scatter_dimension], output_scatter_dim, *input_shape[scatter_dimension + 1 :], + ) + + if (output_size := math.prod(output_shape)) % 128: + raise ValueError("Output size must be divisible by 128") + if jnp.issubdtype(dtype, jnp.integer): + if vec_size is None: + vec_size = 1 # Integer types only support unvectorized reductions + elif vec_size != 1: + raise ValueError("Integer types only support vec_size=1") + elif vec_size is None: # vec_size inference for floating point types + dtype_bits = jnp.finfo(dtype).bits + max_vec_size = min(128 // dtype_bits, output_size // 128) + if tile_size is not None: + max_vec_size_for_tile = tile_size // 128 + max_vec_size = min(max_vec_size, max_vec_size_for_tile) + vec_size = 32 // dtype_bits # We don't support ld_reduce below 32-bit + while vec_size * 2 <= max_vec_size: + vec_size *= 2 + if math.prod(output_shape) % vec_size: + raise ValueError( + "The total number of elements in the output" + f" ({math.prod(output_shape)}) must be divisible by the vec_size" + f" ({vec_size})" + ) + + min_transfer_elems = 128 * vec_size + if tile_size is None: + # TODO(apaszke): 8 is just an arbitrary unrolling factor. Tune it! + unroll_factor = min(math.prod(output_shape) // min_transfer_elems, 8) + tile_size = unroll_factor * min_transfer_elems + if tile_size < min_transfer_elems: + raise ValueError( + f"{tile_size=} is smaller than minimum required" + f" {min_transfer_elems} for {vec_size=}" + ) + + minor_tile = math.gcd(tile_size, minor_dims) + remaining_tile = tile_size // minor_tile + scatter_tile = math.gcd(remaining_tile, output_scatter_dim) + major_tile = remaining_tile // scatter_tile + + if major_dims % major_tile != 0: + raise NotImplementedError( + f"Major dimension size ({major_dims}) must be divisible by the" + f" inferred major tile size ({major_tile}). Consider adjusting tile_size." + ) + + def kernel(x_ref, y_ref, done_barrier): + dev_idx = lax.axis_index(axis_name) + x_ref_3d = x_ref.reshape((major_dims, scatter_dim, minor_dims)) + y_ref_3d = y_ref.reshape((major_dims, output_scatter_dim, minor_dims)) + + if scatter_dimension is not None: + dev_slice = pl.ds(dev_idx * output_scatter_dim, output_scatter_dim) + x_ref_3d = x_ref_3d.at[:, dev_slice, :] + + major_tiles = major_dims // major_tile + scatter_tiles = output_scatter_dim // scatter_tile + minor_tiles = minor_dims // minor_tile + @plgpu.nd_loop((major_tiles, scatter_tiles, minor_tiles), collective_axes="blocks") + def _transfer_loop(loop_info: plgpu.NDLoopInfo): + major_tile_idx, scatter_tile_idx, minor_tile_idx = loop_info.index + idxs = ( + pl.ds(major_tile_idx * major_tile, major_tile), + pl.ds(scatter_tile_idx * scatter_tile, scatter_tile), + pl.ds(minor_tile_idx * minor_tile, minor_tile) + ) + + y_ref_3d[idxs] = plgpu.layout_cast( + plgpu.multimem_load_reduce( + x_ref_3d.at[idxs], collective_axes=axis_name, reduction_op=reduction + ), + plgpu.Layout.WG_STRIDED((major_tile, scatter_tile, minor_tile), vec_size=vec_size) + ) + + # Wait for everyone to finish reading the operands before we exit and potentially free them + plgpu.semaphore_signal_multicast(done_barrier, collective_axes=axis_name) + pl.semaphore_wait(done_barrier, num_devices, decrement=False) + + # TODO(b/448323639): We fake modify the input to ensure that XLA:GPU copies + # the operand into symmetric memory. + @pl.when(dev_idx == -1) + def _never(): + x_ref[(0,) * len(x_ref.shape)] = jnp.asarray(0, x_ref.dtype) + + return plgpu.kernel( + kernel, + out_shape=jax.ShapeDtypeStruct(output_shape, dtype), + grid=(num_blocks,), + grid_names=("blocks",), + scratch_shapes=[plgpu.SemaphoreType.REGULAR], + )(x) + + +def _run_example(): + P = jax.sharding.PartitionSpec + shape = (4 * 4096, 4 * 4096) # This shape is global! + dtype = jnp.bfloat16 + shards = jax.device_count() + mesh = jax.make_mesh( + (shards,), ("x",), axis_types=(jax.sharding.AxisType.Explicit,) + ) + jax.set_mesh(mesh) + + # We measure time per-shard and so we only need bytes per shard. + local_in_bytes = math.prod(shape) / shards * jnp.dtype(dtype).itemsize + # In reduce-scatter, we send (shards - 1) / shards worth of input data to the + # switch and receive as much data as in the whole output, which is 1 / shards. + total_bytes = local_in_bytes + + a = jax.random.normal(jax.random.key(1), shape, dtype) + a = jax.sharding.reshard(a, P(None, "x")) + + @jax.jit + @functools.partial(jax.shard_map, mesh=mesh, in_specs=P(None, "x"), out_specs=P(None, "x")) + def ref_fn(x): + return lax.psum_scatter(x, "x", scatter_dimension=1, tiled=True) + ref_fn(a).block_until_ready() # Warmup. + _, ref_kernels_ms = profiler.measure(ref_fn, aggregate=False)(a) + ref_time_us = sum(t * 1e3 for _, t in ref_kernels_ms) + # We choose the minimum across processes to choose the runtime that didn't + # include devices waiting for other devices. + ref_time_us = min(multihost_utils.process_allgather(ref_time_us).tolist()) + ref_bw = total_bytes / (ref_time_us * 1e-6) / 1e9 # GB/s + + tuning_it = itertools.product( + (4, 8, 16, 32, 64, 132), # num_blocks + (1024, 2048, 4096, 8192), # tile_size + ) + best_bw = 0.0 + best_runtime = float("inf") + for num_blocks, tile_size in tuning_it: + try: + @jax.jit + @functools.partial( + jax.shard_map, mesh=mesh, in_specs=P(None, "x"), out_specs=P(None, "x"), check_vma=False + ) + def kernel_fn(x): + return reduce_scatter(x, axis_name="x", scatter_dimension=1, num_blocks=num_blocks, tile_size=tile_size) + kernel_fn(a).block_until_ready() # Warmup. + _, kernels_ms = profiler.measure(kernel_fn, aggregate=False)(a) + except ValueError as e: + if "exceeds available shared memory" in e.args[0]: # Ignore SMEM OOMs. + continue + raise + runtime_us = sum(t * 1e3 for _, t in kernels_ms) + runtime_us = min(multihost_utils.process_allgather(runtime_us).tolist()) + achieved_bw = total_bytes / (runtime_us * 1e-6) / 1e9 # GB/s + if achieved_bw > best_bw: + best_runtime = runtime_us + best_bw = achieved_bw + print(f"{num_blocks=}, {tile_size=}: {runtime_us:<7.1f}us = {achieved_bw:4.1f} GB/s") + + print(f"Total bytes transferred: {total_bytes / 1e9:.2f} GB") + print(f"\tBest: {best_runtime:<7.1f}us = {best_bw:4.1f} GB/s") + print(f"\tRef: {ref_time_us:<7.1f}us = {ref_bw:4.1f} GB/s") + + +if __name__ == "__main__": + from jax._src import test_multiprocess as jt_multiprocess # pytype: disable=import-error + jt_multiprocess.main(shard_main=_run_example) diff --git a/jax/experimental/pallas/ops/gpu/rms_norm.py b/jax/experimental/pallas/ops/gpu/rms_norm.py index ff224c6dfde7..4993f9e63b77 100644 --- a/jax/experimental/pallas/ops/gpu/rms_norm.py +++ b/jax/experimental/pallas/ops/gpu/rms_norm.py @@ -20,11 +20,9 @@ import jax from jax import lax -import jax.numpy as jnp -from jax._src.lax.control_flow.for_loop import for_loop - from jax.experimental import pallas as pl from jax.experimental.pallas import triton as plgpu +import jax.numpy as jnp def rms_norm_forward_kernel( x_ref, weight_ref, bias_ref, # Input arrays @@ -32,29 +30,34 @@ def rms_norm_forward_kernel( *, eps: float, block_size: int): n_col = x_ref.shape[0] - def var_body(i, acc_ref): + def var_body(i, acc): col_idx = i * block_size + jnp.arange(block_size) mask = col_idx < n_col - a = pl.load(x_ref, (col_idx,), mask=mask, other=0., - eviction_policy="evict_last").astype(jnp.float32) + a = plgpu.load( + x_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_last" + ).astype(jnp.float32) a = jnp.where(mask, a, 0.) - acc_ref[:] += a * a - var = for_loop(pl.cdiv(n_col, block_size), var_body, - jnp.zeros(block_size)).sum() / n_col + return acc + a * a + + var = lax.fori_loop( + 0, pl.cdiv(n_col, block_size), var_body, init_val=jnp.zeros(block_size) + ).sum() + var /= n_col rstd = 1 / jnp.sqrt(var + eps) if rstd_ref is not None: rstd_ref[...] = rstd.astype(rstd_ref.dtype) - def body(i, _): + @pl.loop(0, pl.cdiv(n_col, block_size)) + def body(i): col_idx = i * block_size + jnp.arange(block_size) mask = col_idx < n_col - weight = pl.load(weight_ref, (col_idx,), mask=mask) - bias = pl.load(bias_ref, (col_idx,), mask=mask) - x = pl.load(x_ref, (col_idx,), mask=mask, other=0., - eviction_policy="evict_first").astype(jnp.float32) + weight = plgpu.load(weight_ref.at[col_idx], mask=mask) + bias = plgpu.load(bias_ref.at[col_idx], mask=mask) + x = plgpu.load( + x_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_first" + ).astype(jnp.float32) out = x * rstd * weight + bias - pl.store(o_ref, (col_idx,), out.astype(o_ref.dtype), mask=mask) - for_loop(pl.cdiv(n_col, block_size), body, ()) + plgpu.store(o_ref.at[col_idx], out.astype(o_ref.dtype), mask=mask) def rms_norm_forward( @@ -82,7 +85,7 @@ def rms_norm_forward( ] method = pl.pallas_call( kernel, - compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps), + compiler_params=plgpu.CompilerParams(num_warps=num_warps), grid=(), out_shape=out_shape, debug=False, @@ -104,35 +107,50 @@ def rms_norm_backward_kernel_dx( *, eps: float, block_size: int): n_col = x_ref.shape[0] - def mean_body(i, c1_acc_ref): + def mean_body(i, c1_acc): col_idx = i * block_size + jnp.arange(block_size) mask = col_idx < n_col - a = pl.load(x_ref, (col_idx,), mask=mask, other=0., - eviction_policy="evict_last").astype(jnp.float32) - dout = pl.load(do_ref, (col_idx,), mask=mask, other=0., - eviction_policy="evict_last").astype(jnp.float32) - weight = pl.load(weight_ref, (col_idx,), mask=mask, other=0., - eviction_policy="evict_last").astype(jnp.float32) + a = plgpu.load( + x_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_last" + ).astype(jnp.float32) + dout = plgpu.load( + do_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_last" + ).astype(jnp.float32) + weight = plgpu.load( + weight_ref.at[col_idx], + mask=mask, + other=0.0, + eviction_policy="evict_last", + ).astype(jnp.float32) a_hat = a * rstd_ref[...] wdout = weight * dout - c1_acc_ref[:] += a_hat * wdout - c1 = for_loop(pl.cdiv(n_col, block_size), mean_body, jnp.zeros(block_size)) + return c1_acc + a_hat * wdout + + c1 = lax.fori_loop( + 0, pl.cdiv(n_col, block_size), mean_body, jnp.zeros(block_size) + ) c1 = c1.sum() / n_col - def dx_body(i, acc_ref): + @pl.loop(0, pl.cdiv(n_col, block_size)) + def dx_body(i): col_idx = i * block_size + jnp.arange(block_size) mask = col_idx < n_col - a = pl.load(x_ref, (col_idx,), mask=mask, other=0., - eviction_policy="evict_last").astype(jnp.float32) - dout = pl.load(do_ref, (col_idx,), mask=mask, other=0., - eviction_policy="evict_last").astype(jnp.float32) - weight = pl.load(weight_ref, (col_idx,), mask=mask, other=0., - eviction_policy="evict_last").astype(jnp.float32) + a = plgpu.load( + x_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_last" + ).astype(jnp.float32) + dout = plgpu.load( + do_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_last" + ).astype(jnp.float32) + weight = plgpu.load( + weight_ref.at[col_idx], + mask=mask, + other=0.0, + eviction_policy="evict_last", + ).astype(jnp.float32) a_hat = a * rstd_ref[...] wdout = weight * dout da = (wdout - (a_hat * c1)) * rstd_ref[...] - pl.store(dx_ref, (col_idx,), da.astype(dx_ref.dtype), mask=mask) - for_loop(pl.cdiv(n_col, block_size), dx_body, ()) + plgpu.store(dx_ref.at[col_idx], da.astype(dx_ref.dtype), mask=mask) def rms_norm_backward_kernel_dw_db( @@ -147,24 +165,31 @@ def rms_norm_backward_kernel_dw_db( col_idx = j * block_n + jnp.arange(block_n) col_mask = col_idx < n_col - def body(i, acc_ref): + def body(i, acc): row_idx = i * block_m + jnp.arange(block_m) row_mask = row_idx < m mask = row_mask[:, None] & col_mask[None, :] - a = pl.load( - x_ref, (row_idx[:, None], col_idx[None]), mask=mask, other=0.0 + a = plgpu.load( + x_ref.at[row_idx[:, None], col_idx[None]], mask=mask, other=0.0 ).astype(jnp.float32) - dout = pl.load( - do_ref, (row_idx[:, None], col_idx[None]), mask=mask, other=0.0 + dout = plgpu.load( + do_ref.at[row_idx[:, None], col_idx[None]], mask=mask, other=0.0 ).astype(jnp.float32) - rstd = pl.load(rstd_ref, (row_idx,), mask=row_mask, other=0.).astype(jnp.float32) + rstd = plgpu.load(rstd_ref.at[row_idx], mask=row_mask, other=0.0).astype( + jnp.float32 + ) a_hat = a * rstd[:, None] - dw_acc_ref, db_acc_ref = acc_ref - dw_acc_ref[:] += (dout * a_hat).sum(axis=0) - db_acc_ref[:] += dout.sum(axis=0) - dw_acc, db_acc = for_loop(pl.cdiv(m, block_m), body, (jnp.zeros(block_n), jnp.zeros(block_n))) - pl.store(dw_ref, (col_idx,), dw_acc.astype(dw_ref.dtype), mask=col_mask) - pl.store(db_ref, (col_idx,), db_acc.astype(db_ref.dtype), mask=col_mask) + dw_acc, db_acc = acc + return (dw_acc + (dout * a_hat).sum(axis=0), db_acc + dout.sum(axis=0)) + + dw_acc, db_acc = lax.fori_loop( + 0, + pl.cdiv(m, block_m), + body, + init_val=(jnp.zeros(block_n), jnp.zeros(block_n)), + ) + plgpu.store(dw_ref.at[col_idx], dw_acc.astype(dw_ref.dtype), mask=col_mask) + plgpu.store(db_ref.at[col_idx], db_acc.astype(db_ref.dtype), mask=col_mask) def rms_norm_backward( @@ -196,7 +221,7 @@ def rms_norm_backward( out_shape_dx = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype) method = pl.pallas_call( kernel, - compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps), + compiler_params=plgpu.CompilerParams(num_warps=num_warps), grid=(), out_shape=out_shape_dx, debug=False, @@ -228,7 +253,7 @@ def rms_norm_backward( grid_ = (pl.cdiv(reshaped_x.shape[1], block_n),) method = pl.pallas_call( kernel, - compiler_params=dict(triton=dict(num_warps=num_warps)), + compiler_params=plgpu.CompilerParams(num_warps=num_warps), grid=grid_, out_shape=out_shape_dwbias, debug=False, @@ -264,8 +289,8 @@ def rms_norm( out_shape = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype) method = pl.pallas_call( kernel, - compiler_params=dict( - triton=dict(num_warps=num_warps, num_stages=num_stages) + compiler_params=plgpu.CompilerParams( + num_warps=num_warps, num_stages=num_stages ), grid=(), out_shape=out_shape, diff --git a/jax/experimental/pallas/ops/gpu/softmax.py b/jax/experimental/pallas/ops/gpu/softmax.py index 7fc6a0f50cb4..27af8e254809 100644 --- a/jax/experimental/pallas/ops/gpu/softmax.py +++ b/jax/experimental/pallas/ops/gpu/softmax.py @@ -34,16 +34,16 @@ def _vmappable_softmax_kernel( row_len = input_ref.shape[-1] mask = jnp.arange(block_row) < row_len - row = pl.load( - input_ref, (pl.dslice(0, block_row),), mask=mask, other=-float("inf") + row = plgpu.load( + input_ref.at[pl.ds(0, block_row)], mask=mask, other=-float("inf") ) row_max = jnp.max(row, axis=0) numerator = jnp.exp((row - row_max).astype(jnp.float32)) denominator = jnp.sum(numerator, axis=0) - pl.store( - probs_ref, (pl.dslice(0, block_row),), + plgpu.store( + probs_ref.at[pl.ds(0, block_row)], (numerator / denominator).astype(probs_ref.dtype), mask=mask ) @@ -80,7 +80,7 @@ def softmax( kernel = functools.partial(_vmappable_softmax_kernel, block_row=block_row) f = pl.pallas_call( kernel, - compiler_params=plgpu.TritonCompilerParams( + compiler_params=plgpu.CompilerParams( num_warps=num_warps, num_stages=1), grid=(), out_shape=out_shape, diff --git a/jax/experimental/pallas/ops/gpu/transposed_ragged_dot_mgpu.py b/jax/experimental/pallas/ops/gpu/transposed_ragged_dot_mgpu.py new file mode 100644 index 000000000000..a4e8c53dc226 --- /dev/null +++ b/jax/experimental/pallas/ops/gpu/transposed_ragged_dot_mgpu.py @@ -0,0 +1,298 @@ +# Copyright 2025 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Transposed ragged dot Pallas-Mosaic-GPU implementation.""" + +import functools +import itertools +import jax +from jax import lax +from jax import numpy as jnp +from jax import random +from jax._src import test_util as jtu # noqa: F401 +from jax.experimental import pallas as pl +from jax.experimental.mosaic.gpu import profiler +from jax.experimental.pallas import mosaic_gpu as plgpu +import numpy as np + + +def transposed_ragged_dot( + lhs, # (K, M) + rhs, # (K, N) + *, + group_sizes, # (G,) + block_m: int, + block_n: int, + block_k: int, + max_concurrent_steps: int, + grid_block_n: int, +) -> jax.Array: + if lhs.dtype != rhs.dtype: + raise NotImplementedError( + f"lhs and rhs must have the same dtype, got {lhs.dtype} and {rhs.dtype}" + ) + k, m = lhs.shape + k2, n = rhs.shape + g = group_sizes.shape[0] + + if k != k2: + raise ValueError(f"lhs.shape={k} must match rhs.shape={k2}") + + if m % block_m != 0: + raise ValueError(f"m={m} must be a multiple of block_m={block_m}") + if n % block_n != 0: + raise ValueError(f"n={n} must be a multiple of block_n={block_n}") + + group_sizes = group_sizes.astype(int) + group_starts = jnp.concatenate( + [jnp.zeros(1, dtype=int), jnp.cumsum(group_sizes)[:-1]] + ).astype(int) + group_ends = jnp.cumsum(group_sizes) + group_block_starts = group_starts // block_k * block_k + group_block_ends = -(group_ends // -block_k) * block_k + group_num_blocks = (group_block_ends - group_block_starts) // block_k + + swizzle = plgpu.find_swizzle(block_k * jnp.dtype(lhs.dtype).itemsize * 8) + swizzle_elems = swizzle // jnp.dtype(lhs.dtype).itemsize + transforms = ( + plgpu.TilingTransform((8, swizzle_elems)), plgpu.SwizzleTransform(swizzle) + ) + + def body( + group_sizes_gmem, + group_starts_gmem, + group_ends_gmem, + group_num_blocks_gmem, + group_block_starts_gmem, + lhs_gmem, + rhs_gmem, + o_gmem, + ): + + grid_m = pl.cdiv(m, block_m) + grid_n = pl.cdiv(n, block_n) + + @plgpu.nd_loop((g, grid_m * grid_n), collective_axes="sm") + def mn_loop(loop_info: plgpu.NDLoopInfo): + g_i = loop_info.index[0] + m_i, n_i = plgpu.planar_snake( + loop_info.index[1], + (grid_m, grid_n), + 1, + grid_block_n, + ) + + # This slice is potentially out of bounds, but we never access the + # out of bound part in emit_pipeline. + gmem_slice = pl.ds(group_block_starts_gmem[g_i], k) + + def acc_scope(acc_ref): + def block_matmul(block_idx, lhs_smem, rhs_smem): + block_idx = block_idx[0] + + @pl.when(block_idx == 0) + def _(): + # Handles the first block of the group, where there might be + # data from the previous group in the beginning of the block. + lhs_reg = lhs_smem[...] + start_index = lax.rem(group_starts_gmem[g_i], block_k) + indices = plgpu.layout_cast( + jax.lax.broadcasted_iota(jnp.int32, (block_k, block_m), 0), + plgpu.Layout.WGMMA + ) + lhs_mask = (indices >= start_index).astype(lhs_smem.dtype) + + lhs_reg = lhs_reg * lhs_mask + lhs_smem[...] = lhs_reg + plgpu.commit_smem() + + @pl.when(block_idx == group_num_blocks_gmem[g_i] - 1) + def _(): + # Handles the last block of the group, where there might be + # data from the next group in the end of the block. + lhs_reg = lhs_smem[...] + last_index = lax.rem(group_ends_gmem[g_i] - 1, block_k) + indices = plgpu.layout_cast( + jax.lax.broadcasted_iota(jnp.int32, (block_k, block_m), 0), + plgpu.Layout.WGMMA + ) + lhs_mask = (indices <= last_index).astype(lhs_smem.dtype) + + lhs_reg = lhs_reg * lhs_mask + lhs_smem[...] = lhs_reg + plgpu.commit_smem() + + plgpu.wgmma(acc_ref, plgpu.transpose_ref(lhs_smem, (1, 0)), rhs_smem) + if max_concurrent_steps == 1: + # Without delayed release, we won't have at least two separate + # smem blocks in flight. Therefore, we cannot rely on the implicit + # wait of wgmma to gaurantee that the data in smem is ready to be + # overwritten by the next pipeline iteration. + plgpu.wgmma_wait(0) + + @pl.when(group_sizes_gmem[g_i] > 0) # Skip the group if it is empty. + def _(): + plgpu.emit_pipeline( + block_matmul, + grid=(group_num_blocks_gmem[g_i],), + in_specs=[ + plgpu.BlockSpec( + (block_k, block_m), + lambda k_i: (k_i, m_i), + delay_release=1 if max_concurrent_steps > 1 else 0, + transforms=transforms, + ), + plgpu.BlockSpec( + (block_k, block_n), + lambda k_i: (k_i, n_i), + delay_release=1 if max_concurrent_steps > 1 else 0, + transforms=transforms, + ), + ], + max_concurrent_steps=max_concurrent_steps, + )(lhs_gmem.at[gmem_slice, :], rhs_gmem.at[gmem_slice, :]) + + return acc_ref[...] + + acc = pl.run_scoped(acc_scope, plgpu.ACC((block_m, block_n))) + + @functools.partial( + pl.run_scoped, + o_smem=plgpu.SMEM( + (block_m, block_n), + dtype=o_gmem.dtype, + transforms=transforms, + ) + ) + def store_scope(o_smem): + o_smem[...] = acc.astype(o_smem.dtype) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem( + o_smem, o_gmem.at[ + g_i, + pl.ds(m_i * block_m, block_m), + pl.ds(n_i * block_n, block_n) + ] + ) + plgpu.wait_smem_to_gmem(0, wait_read_only=True) + + # There are 132 SMs on a H100 SXM GPU. + num_sms = jax.devices()[0].core_count + kernel = plgpu.kernel( + body, + out_shape=jax.ShapeDtypeStruct((g, m, n), lhs.dtype), + grid=(num_sms,), + grid_names=("sm",), + ) + return kernel( + group_sizes, + group_starts, + group_ends, + group_num_blocks, + group_block_starts, + lhs, + rhs, + ) + + +def ref_transposed_ragged_dot(lhs, rhs, group_sizes): + return jax.lax.ragged_dot_general( + lhs, rhs, group_sizes, + ragged_dot_dimension_numbers=jax.lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(((0,), (0,)), ((), ())), + lhs_ragged_dimensions=[0], + rhs_group_dimensions=[], + ) + ) + + +def main(unused_argv): + k, m, n, num_groups = 16 * 1024, 2048, 2048, 16 + kx, ky, kz = random.split(random.key(1234), num=3) + + lhs = jax.random.normal(kx, (k, m), jnp.float16) + rhs = jax.random.normal(ky, (k, n), jnp.float16) + group_boundaries = jax.lax.sort( + jax.random.randint(kz, (num_groups - 1,), 0, k, jnp.int32) + ) + group_starts = lax.concatenate( + [jnp.array([0], dtype=jnp.int32), group_boundaries], 0 + ) + group_ends = lax.concatenate( + [group_boundaries, jnp.array([k], dtype=jnp.int32)], 0 + ) + group_sizes = group_ends - group_starts + assert group_sizes.shape == (num_groups,) + + block_m = block_n = [64, 128] + block_k = [64, 128] + max_concurrent_steps = [1, 2, 4, 5, 6] + grid_block_n = [1, 2, 4, 8, 16] + + configs = itertools.product( + block_m, block_n, block_k, max_concurrent_steps, grid_block_n + ) + names = ( + "block_m", "block_n", "block_k", "max_concurrent_steps", "grid_block_n", + ) + best_runtime = float("inf") + best_kwargs = {} + for config in configs: + kwargs = dict(zip(names, config)) + if n % kwargs["block_n"]: + continue + try: + f = functools.partial( + transposed_ragged_dot, group_sizes=group_sizes, + **kwargs + ) + _, runtime = profiler.measure(f)(lhs, rhs) + except ValueError as e: + if "Mosaic GPU kernel exceeds available shared memory" not in str(e): + raise + runtime = float("inf") + # Enable this to get more detailed information. + else: + print( + " ".join(f"{k}={v}" for k, v in kwargs.items()), + f"{int(runtime * 1000):.1f} us", + ) + if runtime < best_runtime: # pytype: disable=unsupported-operands + best_runtime = runtime + best_kwargs = kwargs + if not best_kwargs: + raise ValueError("No valid configuration found") + + ref, ref_runtime = profiler.measure(ref_transposed_ragged_dot)( + lhs, rhs, group_sizes=group_sizes + ) + result = transposed_ragged_dot( + lhs, rhs, group_sizes=group_sizes, **best_kwargs + ) + + tflops = float(2 * k * m * n) / (best_runtime / 1e3) / 1e12 + ref_tflops = float(2 * k * m * n) / (ref_runtime / 1e3) / 1e12 + print( + "Best parameters: ", " ".join(f"{k}={v}" for k, v in best_kwargs.items()) + ) + print(f"Kernel: {best_runtime * 1000:.1f} us = {tflops:.1f} TFLOPS") + print(f"Reference: {ref_runtime * 1000:.1f} us = {ref_tflops:.1f} TFLOPS") + np.testing.assert_allclose(result, ref, atol=1e-3, rtol=1e-3) + + +if __name__ == "__main__": + from absl import app + + jax.config.config_with_absl() + app.run(main) diff --git a/jax/experimental/pallas/ops/tpu/all_gather.py b/jax/experimental/pallas/ops/tpu/all_gather.py index 8fb975504e26..ce80a443547e 100644 --- a/jax/experimental/pallas/ops/tpu/all_gather.py +++ b/jax/experimental/pallas/ops/tpu/all_gather.py @@ -30,7 +30,7 @@ import jax from jax import lax from jax.experimental import pallas as pl -from jax.experimental import shard_map +from jax._src import shard_map from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp @@ -48,7 +48,7 @@ def get_neighbor( idx if i == which_axis else lax.axis_index(a) for i, a in enumerate(axis_names) ] - axis_size = lax.psum(1, axis_name) + axis_size = lax.axis_size(axis_name) if direction == "right": next_idx = lax.rem(idx + 1, axis_size) else: @@ -67,7 +67,7 @@ def ag_kernel(x_ref, o_ref, send_sem, recv_sem, *, axis_name: str, pltpu.async_copy(x_ref, o_ref.at[my_id], recv_sem[0]).wait() with jax.named_scope("neighbour_lookup"): - axis_size = lax.psum(1, axis_name) + axis_size = lax.axis_size(axis_name) left_neighbor = get_neighbor(my_id, mesh, axis_name, direction="left") right_neighbor = get_neighbor(my_id, mesh, axis_name, direction="right") @@ -120,7 +120,7 @@ def ag_kernel(x_ref, o_ref, send_sem, recv_sem, *, axis_name: str, jax.jit, static_argnames=["mesh", "axis_name", "memory_space"] ) def all_gather(x, *, mesh: jax.sharding.Mesh, axis_name: str | Sequence[str], - memory_space: pltpu.TPUMemorySpace = pltpu.VMEM): + memory_space: pltpu.MemorySpace = pltpu.VMEM): if isinstance(axis_name, str): axis_name = (axis_name,) # TODO(sharadmv): enable all gather over multiple axes @@ -131,12 +131,12 @@ def all_gather(x, *, mesh: jax.sharding.Mesh, axis_name: str | Sequence[str], # We can short-circuit here if our axis size is 1 return x def ag_local(x_shard): - axis_size = lax.psum(1, axis_name) + axis_size = lax.axis_size(axis_name) out_shape = jax.ShapeDtypeStruct((axis_size, *x_shard.shape), x_shard.dtype) out = pl.pallas_call( functools.partial(ag_kernel, axis_name=axis_name, mesh=mesh), out_shape=out_shape, - compiler_params=pltpu.TPUCompilerParams(collective_id=0), + compiler_params=pltpu.CompilerParams(collective_id=0), grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, scratch_shapes=( @@ -151,5 +151,5 @@ def ag_local(x_shard): return shard_map.shard_map( ag_local, mesh=mesh, in_specs=P(axis_name), out_specs=P(None), - check_rep=False + check_vma=False )(x) diff --git a/jax/experimental/pallas/ops/tpu/flash_attention.py b/jax/experimental/pallas/ops/tpu/flash_attention.py index 0cb3d798d09e..731405fdba31 100644 --- a/jax/experimental/pallas/ops/tpu/flash_attention.py +++ b/jax/experimental/pallas/ops/tpu/flash_attention.py @@ -383,17 +383,14 @@ def start_new_sequence(): @pl.when(should_run) def run(): - @functools.partial( - lax.fori_loop, 0, block_k_major // block_k, init_val=None, unroll=True - ) - def body(i, _): + @pl.loop(0, block_k_major, step=block_k, unroll=True) + def _body(start_k): m_prev = m_scratch_ref[batch_idx] l_prev = l_scratch_ref[batch_idx] q = q_tile_ref[batch_idx] # [block_q, head_dim] - start_k = i * block_k - k = pl.load( - k_tile_ref, (*batch_idx, pl.dslice(start_k, block_k), slice(None)) - ) # [block_k, head_dim] + k = k_tile_ref[ + (*batch_idx, pl.dslice(start_k, block_k), slice(None)) + ] # [block_k, head_dim] s = jax.lax.dot_general( q, k, TRANS_B_DIM_NUMBERS, preferred_element_type=jnp.float32 @@ -403,10 +400,9 @@ def body(i, _): # TODO(tanburn) Should the attention bias be added before or after # multiplication by sm_scale? if ab_tile_ref is not None: - ab = pl.load( - ab_tile_ref, + ab = ab_tile_ref[ (*batch_idx, pl.dslice(None), pl.dslice(start_k, block_k)) - ).astype(jnp.float32) + ].astype(jnp.float32) s += ab if sm_scale != 1.0: @@ -419,13 +415,12 @@ def body(i, _): raise NotImplementedError( f"kv block size must be a multiple of {NUM_LANES}" ) - q_segment_ids = pltpu.repeat( - q_segment_ids_tile_ref[batch_idx[0]], repeats, axis=1 + q_segment_ids = jnp.tile( + q_segment_ids_tile_ref[batch_idx[0]], (1, repeats) ) # [block_q, block_k]. - kv_segment_ids = pl.load( - kv_segment_ids_tile_ref, - (batch_idx[0], pl.dslice(1), pl.dslice(start_k, block_k)), - ) # [1, block_k]. + kv_segment_ids = kv_segment_ids_tile_ref[ + batch_idx[0], :1, pl.dslice(start_k, block_k) + ] # [1, block_k]. mask = jnp.equal(q_segment_ids, kv_segment_ids).astype(jnp.bool_) if causal: @@ -449,7 +444,7 @@ def body(i, _): raise NotImplementedError( f"{block_k=} should be a multiple of {MIN_BLOCK_SIZE}" ) - p = jnp.exp(s - pltpu.repeat(m_next, block_k_repeats, 1)) + p = jnp.exp(s - jnp.tile(m_next, (1, block_k_repeats))) alpha = jnp.exp(m_prev - m_next) # Shape [block_q, 128]. @@ -458,7 +453,7 @@ def body(i, _): l_next = jnp.sum(p, axis=1)[:, None] + l_corr # Shape [block_q, 128] head_dim_repeats, rem = divmod(head_dim, MIN_BLOCK_SIZE) - l_broadcast = lambda l: pltpu.repeat(l, head_dim_repeats, 1) + l_broadcast = lambda l: jnp.tile(l, (1, head_dim_repeats)) if rem: if head_dim_repeats == 0: l_broadcast = lambda l: l[:, :head_dim] @@ -471,9 +466,7 @@ def body(i, _): l_next_inv_safe = jnp.where(l_next == 0.0, 1.0, 1.0 / l_next) acc_scratch_ref[batch_idx] *= l_broadcast(l_corr * l_next_inv_safe) - v = pl.load( - v_tile_ref, (*batch_idx, pl.dslice(start_k, block_k), slice(None)) - ) + v = v_tile_ref[(*batch_idx, pl.dslice(start_k, block_k), slice(None))] o_curr = jax.lax.dot( p.astype(v.dtype), v, preferred_element_type=jnp.float32 ) @@ -529,15 +522,13 @@ def _flash_attention_kernel_single_batch_single_step( raise NotImplementedError( f"kv block size must be a multiple of {NUM_LANES}" ) - q_segment_ids = pl.load( - q_segment_ids_tile_ref, (batch_idx[0],) - ) # [block_q, NUM_LANES]. - q_segment_ids = pltpu.repeat( - q_segment_ids, repeats, axis=1 + q_segment_ids = q_segment_ids_tile_ref[ + batch_idx[0] + ] # [block_q, NUM_LANES]. + q_segment_ids = jnp.tile( + q_segment_ids, (1, repeats) ) # [block_q, block_k]. - kv_segment_ids = pl.load( - kv_segment_ids_tile_ref, (batch_idx[0], pl.dslice(1)) - ) # [1, block_k]. + kv_segment_ids = kv_segment_ids_tile_ref[batch_idx[0], :1] # [1, block_k]. mask = jnp.equal(q_segment_ids, kv_segment_ids).astype(jnp.bool_) if causal: @@ -775,7 +766,7 @@ def kv_segment_ids_index_map( ), out_shape=out_shape, debug=debug, - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=( "parallel", "parallel", @@ -840,33 +831,27 @@ def q_body(j, _): start_q = j * block_q def k_body(i, _): start_k = i * block_k - k = pl.load(k_tile_ref, (0, 0, pl.ds(start_k, block_k), slice(None))) - v = pl.load(v_tile_ref, (0, 0, pl.ds(start_k, block_k), slice(None))) - q = pl.load(q_tile_ref, (0, 0, pl.ds(start_q, block_q), slice(None)) - ) # [block_q, head_dim] - l = pl.load(l_tile_ref, (0, 0, pl.ds(start_q, block_q), slice(None)) - ) # [block_q, 128] - m = pl.load(m_tile_ref, (0, 0, pl.ds(start_q, block_q), slice(None)) - ) # [block_q, 128] - do = pl.load(do_tile_ref, (0, 0, pl.ds(start_q, block_q), slice(None)) - ) # [block_q, 128] - di = pl.load(di_tile_ref, (0, 0, pl.ds(start_q, block_q), slice(None)) - ).astype(jnp.float32) # [block_q, 128] + k = k_tile_ref[0, 0, pl.ds(start_k, block_k), :] + v = v_tile_ref[0, 0, pl.ds(start_k, block_k), :] + q = q_tile_ref[0, 0, pl.ds(start_q, block_q), :] # [block_q, head_dim] + l = l_tile_ref[0, 0, pl.ds(start_q, block_q), :] # [block_q, 128] + m = m_tile_ref[0, 0, pl.ds(start_q, block_q), :] # [block_q, 128] + do = do_tile_ref[0, 0, pl.ds(start_q, block_q), :] # [block_q, 128] + di = di_tile_ref[0, 0, pl.ds(start_q, block_q), :].astype( + jnp.float32 + ) # [block_q, 128] capped_logits = lax.dot_general( q, k, TRANS_B_DIM_NUMBERS, preferred_element_type=jnp.float32 ) # [block_q_major, block_k] if ab_tile_ref is not None: - ab = pl.load( - ab_tile_ref, - ( - 0, - 0, - pl.dslice(j * block_q, block_q), - pl.dslice(i * block_k, block_k), - ), - ).astype(jnp.float32) + ab = ab_tile_ref[ + 0, + 0, + pl.dslice(j * block_q, block_q), + pl.dslice(i * block_k, block_k), + ].astype(jnp.float32) capped_logits += ab if sm_scale != 1.0: @@ -878,15 +863,15 @@ def k_body(i, _): if rem: raise NotImplementedError( ) - q_segment_ids = pl.load( - q_segment_ids_tile_ref, (0, pl.ds(start_q, block_q), slice(None)) - ) # [block_q, NUM_LANES]. - q_segment_ids = pltpu.repeat( - q_segment_ids, repeats, axis=1 + q_segment_ids = q_segment_ids_tile_ref[ + 0, pl.ds(start_q, block_q), : + ] # [block_q, NUM_LANES]. + q_segment_ids = jnp.tile( + q_segment_ids, (1, repeats) ) # [block_q, block_k]. - kv_segment_ids = pl.load( - kv_segment_ids_tile_ref, (slice(None), 0, pl.ds(start_k, block_k)) - ) # [1, block_k]. + kv_segment_ids = kv_segment_ids_tile_ref[ + :, 0, pl.ds(start_k, block_k) + ] # [1, block_k]. mask = jnp.equal(q_segment_ids, kv_segment_ids).astype(jnp.bool_) if causal: @@ -907,15 +892,15 @@ def k_body(i, _): ) p = jnp.exp( - capped_logits - pltpu.repeat(m, block_k // MIN_BLOCK_SIZE, axis=1) + capped_logits - jnp.tile(m, (1, block_k // MIN_BLOCK_SIZE)) ) - p = p * pltpu.repeat( - 1 / l, block_k // MIN_BLOCK_SIZE, axis=1 + p = p * jnp.tile( + 1 / l, (1, block_k // MIN_BLOCK_SIZE) ) # [block_q_major, block_k_major] dv = lax.dot(p.T.astype(do.dtype), do, preferred_element_type=jnp.float32) - pl.store(dv_scratch_ref, (pl.ds(start_k, block_k), slice(None)), - pl.load(dv_scratch_ref, (pl.ds(start_k, block_k), slice(None))) - + dv.astype(dv_scratch_ref.dtype)) + dv_scratch_ref[pl.ds(start_k, block_k), :] += dv.astype( + dv_scratch_ref.dtype + ) # di: [block_q, 128] # do: [block_q, head_dim] @@ -923,7 +908,7 @@ def k_body(i, _): dp = lax.dot_general( do, v, TRANS_B_DIM_NUMBERS, preferred_element_type=jnp.float32 ) - ds = (dp - pltpu.repeat(di, block_k // MIN_BLOCK_SIZE, axis=1)) * p + ds = (dp - jnp.tile(di, (1, block_k // MIN_BLOCK_SIZE))) * p if sm_scale != 1.0: ds = ds * sm_scale @@ -931,9 +916,9 @@ def k_body(i, _): # ds: [block_q_major, block_k_major] # q: [block_q_major, head_dim] dk = lax.dot(ds.T.astype(do.dtype), q, preferred_element_type=jnp.float32) - pl.store(dk_scratch_ref, (pl.ds(start_k, block_k), slice(None)), - pl.load(dk_scratch_ref, (pl.ds(start_k, block_k), slice(None))) - + dk.astype(dk_scratch_ref.dtype)) + dk_scratch_ref[pl.ds(start_k, block_k), :] += dk.astype( + dk_scratch_ref.dtype + ) lax.fori_loop(0, block_k_major // block_k, k_body, None, unroll=True) if causal: @@ -1144,7 +1129,7 @@ def dkv_index_map(batch_index, head_index, kv_seq_index, _): ), out_shape=out_shapes, debug=debug, - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=( "parallel", "parallel", @@ -1192,12 +1177,8 @@ def start_new_sequence(): def body(i, _): k_slice = pl.ds(i * block_k, block_k) q = q_tile_ref[0, 0, :, :] - k = pl.load( - k_tile_ref, (0, 0, k_slice, slice(None)), - ) # [block_k, head_dim] - v = pl.load( - v_tile_ref, (0, 0, k_slice, slice(None)), - ) # [block_k, head_dim] + k = k_tile_ref[0, 0, k_slice, :] # [block_k, head_dim] + v = v_tile_ref[0, 0, k_slice, :] # [block_k, head_dim] l = l_tile_ref[0, 0, :, :] # [block_q_major, 128] m = m_tile_ref[0, 0, :, :] # [block_q_major, 128] do = do_tile_ref[0, 0, :, :] # [block_q_major, head_dim] @@ -1208,9 +1189,9 @@ def body(i, _): ) if ab_tile_ref is not None: - ab = pl.load( - ab_tile_ref, (0, 0, pl.dslice(None), pl.dslice(i * block_k, block_k)) - ).astype(jnp.float32) + ab = ab_tile_ref[0, 0, :, pl.dslice(i * block_k, block_k)].astype( + jnp.float32 + ) capped_logits += ab if sm_scale != 1.0: @@ -1223,12 +1204,10 @@ def body(i, _): raise NotImplementedError( f"kv block size must be a multiple of {NUM_LANES}" ) - q_segment_ids = pltpu.repeat( - q_segment_ids_tile_ref[0], repeats, axis=1 + q_segment_ids = jnp.tile( + q_segment_ids_tile_ref[0], (1, repeats) ) # [block_q, block_k]. - kv_segment_ids = pl.load( - kv_segment_ids_tile_ref, (slice(None), 0, k_slice) - ) # [1, block_k]. + kv_segment_ids = kv_segment_ids_tile_ref[:, 0, k_slice] # [1, block_k]. mask = jnp.equal(q_segment_ids, kv_segment_ids).astype(jnp.bool_) if causal: @@ -1246,10 +1225,10 @@ def body(i, _): ) p = jnp.exp( - capped_logits - pltpu.repeat(m, block_k // MIN_BLOCK_SIZE, axis=1) + capped_logits - jnp.tile(m, (1, block_k // MIN_BLOCK_SIZE)) ) - p = p * pltpu.repeat( - 1 / l, block_k // MIN_BLOCK_SIZE, axis=1 + p = p * jnp.tile( + 1 / l, (1, block_k // MIN_BLOCK_SIZE) ) # [block_q_major, block_k] # di: [block_q_major, 128] @@ -1261,7 +1240,7 @@ def body(i, _): TRANS_B_DIM_NUMBERS, preferred_element_type=jnp.float32, ) - ds = (dp - pltpu.repeat(di, block_k // MIN_BLOCK_SIZE, axis=1)) * p + ds = (dp - jnp.tile(di, (1, block_k // MIN_BLOCK_SIZE))) * p # dp = jnp.dot(do, v.T) # ds = (dp - (dp * p).sum(axis=1)[:, None]) * p @@ -1269,10 +1248,8 @@ def body(i, _): ds = ds * sm_scale if ds_tile_ref is not None: - pl.store( - ds_tile_ref, - (0, 0, pl.dslice(None), pl.dslice(i * block_k, block_k)), - ds.astype(ds_tile_ref.dtype), + ds_tile_ref[0, 0, :, pl.dslice(i * block_k, block_k)] = ds.astype( + ds_tile_ref.dtype ) # dp: [block_q_major, block_k] @@ -1487,7 +1464,7 @@ def kv_segment_ids_index_map( ), out_shape=out_shapes, debug=debug, - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=( "parallel", "parallel", diff --git a/jax/experimental/pallas/ops/tpu/matmul.py b/jax/experimental/pallas/ops/tpu/matmul.py index 4ff82acbb5dd..341aa93fa258 100644 --- a/jax/experimental/pallas/ops/tpu/matmul.py +++ b/jax/experimental/pallas/ops/tpu/matmul.py @@ -14,7 +14,7 @@ """Example matmul TPU kernel. -See discussion in https://jax.readthedocs.io/en/latest/pallas/tpu/matmul.html. +See discussion in https://docs.jax.dev/en/latest/pallas/tpu/matmul.html. """ import functools @@ -78,7 +78,7 @@ def matmul( grid=(x.shape[0] // l, y.shape[1] // r, x.shape[1] // block_k), scratch_shapes=[pltpu.VMEM((l, r), acc_dtype)], ), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=("parallel", "parallel", "arbitrary")), debug=debug, )(x, y) diff --git a/jax/experimental/pallas/ops/tpu/megablox/common.py b/jax/experimental/pallas/ops/tpu/megablox/common.py index bd843cf46ca4..a4a91185f67b 100644 --- a/jax/experimental/pallas/ops/tpu/megablox/common.py +++ b/jax/experimental/pallas/ops/tpu/megablox/common.py @@ -29,13 +29,15 @@ def tpu_kind() -> str: return jax.devices()[0].device_kind -_TPU_KIND_PATTERN = re.compile(r"TPU v(\d+)") +# Most TPU devices follow the pattern "TPU v{version}{variant}", e.g. "TPU v5p" +# TPU v7 has a different pattern (i.e. "TPU7x") +_TPU_KIND_PATTERN = re.compile(r"TPU( v)?(\d+)") def tpu_generation() -> int: """Generation number of the currently attached TPU.""" if version := _TPU_KIND_PATTERN.match(tpu_kind()): - return int(version[1]) + return int(version[2]) raise NotImplementedError("only TPU devices are supported") diff --git a/jax/experimental/pallas/ops/tpu/megablox/gmm.py b/jax/experimental/pallas/ops/tpu/megablox/gmm.py index 5c2f938597e7..cb185fc45f1d 100644 --- a/jax/experimental/pallas/ops/tpu/megablox/gmm.py +++ b/jax/experimental/pallas/ops/tpu/megablox/gmm.py @@ -538,7 +538,7 @@ def out_transform_indices(n_i, grid_id, k_i, group_metadata, group_offset): scratch_shapes=[pltpu.VMEM((tm, tn), jnp.float32)], ), input_output_aliases=input_output_aliases, - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=("parallel", "arbitrary", "arbitrary")), interpret=interpret, cost_estimate=cost_estimate, @@ -777,7 +777,7 @@ def out_transform_indices(n_i, k_i, grid_id, group_metadata, group_offset): scratch_shapes=[pltpu.VMEM((tk, tn), jnp.float32)], ), input_output_aliases=input_output_aliases, - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=("parallel", "arbitrary", "arbitrary")), interpret=interpret, cost_estimate=cost_estimate, diff --git a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py index eb1e11df17da..b35d6d6dc8d5 100644 --- a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py +++ b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py @@ -114,7 +114,7 @@ def paged_flash_attention_kernel( lengths_ref, page_indices_ref, buffer_index_ref, - step_ref, + init_flag_ref, q_ref, k_pages_hbm_ref, k_scales_pages_hbm_ref, @@ -127,7 +127,8 @@ def paged_flash_attention_kernel( k_scales_vmem_buffer, v_vmem_buffer, v_scales_vmem_buffer, - sem, + k_sems, + v_sems, *, batch_size: int, pages_per_compute_block: int, @@ -176,7 +177,9 @@ def advance_to_next_non_zero_length(): return ( lax.cond( - jnp.logical_and(next_b < batch_size, lengths_ref[next_b] == 0), + jnp.logical_and( + next_b < batch_size, + lengths_ref[lax.clamp(0, next_b, batch_size - 1)] == 0), advance_to_next_non_zero_length, lambda: next_b, ), @@ -200,7 +203,7 @@ def create_kv_async_copy_descriptors(b, h, i, buffer_index): k_scales_vmem_buffer.at[buffer_index] if k_scales_vmem_buffer is not None else None, - sem, + k_sems.at[buffer_index], page_indices_ref, page_offset, pages_to_load, @@ -213,7 +216,7 @@ def create_kv_async_copy_descriptors(b, h, i, buffer_index): v_scales_vmem_buffer.at[buffer_index] if v_scales_vmem_buffer is not None else None, - sem, + v_sems.at[buffer_index], page_indices_ref, page_offset, pages_to_load, @@ -223,16 +226,12 @@ def create_kv_async_copy_descriptors(b, h, i, buffer_index): @pl.when(i * bk < length) def flash_attention(): # pylint: disable=unused-variable - step = step_ref[0] + init_flag = init_flag_ref[0] + init_flag_ref[0] = 0 buffer_index = buffer_index_ref[0] + next_b, next_h, next_i = compute_block_indices(b, h, i + 1) - @pl.when(i == 0) - def init(): # pylint: disable=unused-variable - m_ref[...] = jnp.full_like(m_ref, -jnp.inf) - l_ref[...] = jnp.zeros_like(l_ref) - o_ref[...] = jnp.zeros_like(o_ref) - - @pl.when(step == 0) + @pl.when(init_flag) def prefetch_first_block(): # pylint: disable=unused-variable async_copy_k, async_copy_v = create_kv_async_copy_descriptors( b, h, i, buffer_index @@ -240,7 +239,11 @@ def prefetch_first_block(): # pylint: disable=unused-variable async_copy_k.start() async_copy_v.start() - next_b, next_h, next_i = compute_block_indices(b, h, i + 1) + @pl.when(i == 0) + def init(): # pylint: disable=unused-variable + m_ref[...] = jnp.full_like(m_ref, -jnp.inf) + l_ref[...] = jnp.zeros_like(l_ref) + o_ref[...] = jnp.zeros_like(o_ref) @pl.when(next_b < batch_size) def prefetch_next_block(): # pylint: disable=unused-variable @@ -257,7 +260,7 @@ def prefetch_next_block(): # pylint: disable=unused-variable ) q = q_ref[...].astype(jnp.float32) k = async_copy_k.wait_and_get_loaded() - qk = jnp.einsum('hd,td->ht', q, k, preferred_element_type=jnp.float32) + qk = jnp.einsum("gd,td->gt", q, k, preferred_element_type=jnp.float32) if attn_logits_soft_cap is not None: capped_qk = jnp.tanh(qk / attn_logits_soft_cap) qk = capped_qk * attn_logits_soft_cap @@ -274,24 +277,21 @@ def prefetch_next_block(): # pylint: disable=unused-variable alpha = jnp.exp(m_prev - m_next) beta = jnp.exp(m_curr - m_next) l_next = alpha * l_prev + beta * l_curr - l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next) + m_ref[...], l_ref[...] = m_next, l_next v = async_copy_v.wait_and_get_loaded() - o_curr_times_l_curr = jnp.dot(s_curr, v) + o_curr = jnp.einsum("gt,td->gd", s_curr, v) - m_ref[...], l_ref[...] = m_next, l_next_safe o_ref[...] = ( - (l_prev * alpha * o_ref[...] + beta * o_curr_times_l_curr) / l_next_safe + (l_prev * alpha * o_ref[...] + beta * o_curr) / l_next ).astype(o_ref.dtype) - step_ref[0] = step + 1 - def paged_flash_attention_kernel_inline_seq_dim( lengths_ref, page_indices_ref, buffer_index_ref, - step_ref, + init_flag_ref, q_ref, k_pages_hbm_ref, k_scales_pages_hbm_ref, @@ -304,7 +304,8 @@ def paged_flash_attention_kernel_inline_seq_dim( k_scales_vmem_buffer, v_vmem_buffer, v_scales_vmem_buffer, - sem, + k_sems, + v_sems, *, batch_size: int, pages_per_compute_block: int, @@ -326,7 +327,7 @@ def body(i, _): lengths_ref, page_indices_ref, buffer_index_ref, - step_ref, + init_flag_ref, q_ref, k_pages_hbm_ref, k_scales_pages_hbm_ref, @@ -339,7 +340,8 @@ def body(i, _): k_scales_vmem_buffer, v_vmem_buffer, v_scales_vmem_buffer, - sem, + k_sems, + v_sems, batch_size=batch_size, pages_per_compute_block=pages_per_compute_block, pages_per_sequence=pages_per_sequence, @@ -387,7 +389,7 @@ def paged_attention( """Paged grouped query attention. Args: - q: A [batch_size, num_heads, head_dim] jax.Array. + q: A [batch_size, num_q_heads, head_dim] jax.Array. k_pages: A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array. v_pages: A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array. lengths: A i32[batch_size] jax.Array the length of each example. @@ -412,7 +414,7 @@ def paged_attention( one kernel. Returns: - The output of attention([batch_size, num_heads, head_dim]). + The output of attention([batch_size, num_q_heads, head_dim]). """ if isinstance(k_pages, quantization_utils.QuantizedTensor): k_pages, k_scales_pages = k_pages.weight, k_pages.scales @@ -431,7 +433,7 @@ def paged_attention( else: v_scales_pages = None - batch_size, num_heads, head_dim = q.shape + batch_size, num_q_heads, head_dim = q.shape num_kv_heads, _, page_size, head_dim_k = k_pages.shape batch_size_paged_indices, pages_per_sequence = page_indices.shape @@ -440,10 +442,10 @@ def paged_attention( f"k_pages and v_pages must have the same shape. Got {k_pages.shape} and" f" {v_pages.shape}" # pytype: disable=attribute-error ) - if num_heads % num_kv_heads != 0: + if num_q_heads % num_kv_heads != 0: raise ValueError( "Number of Q heads must be divisible by number of KV heads. Got" - f" {num_heads} and {num_kv_heads}." + f" {num_q_heads} and {num_kv_heads}." ) if head_dim_k != head_dim: raise ValueError( @@ -461,7 +463,7 @@ def paged_attention( raise ValueError("`page_indices` and `q` must have the same batch size") if lengths.dtype != jnp.int32: raise ValueError( - "The dtype of `lengths` must be int32. Got {lengths.dtype}" + f"The dtype of `lengths` must be int32. Got {lengths.dtype}" ) # TODO(dinghua): get the actual cores per chip once there's an official API. @@ -480,40 +482,41 @@ def paged_attention( else: raise ValueError("megacore_mode must be one of ['kv_head', 'batch', None]") - if (num_heads // num_kv_heads) % 8 != 0: + num_groups = num_q_heads // num_kv_heads + if (num_groups) % 8 != 0: # Reshape q to hint XLA to pick a <1x128> layout otherwise it will pick a # <8x128> layout for a <1x128> memref inside the kernel and error out. - q = q.reshape(batch_size, num_heads, 1, head_dim) + q = q.reshape(batch_size, num_q_heads, 1, head_dim) if megacore_mode == "kv_head": q_block_spec = pl.BlockSpec( - (None, num_heads // num_kv_heads, None, head_dim), + (None, num_groups, None, head_dim), lambda core_index, b, h, *_: (b, h * num_cores + core_index, 0, 0), ) elif megacore_mode == "batch": q_block_spec = pl.BlockSpec( - (None, num_heads // num_kv_heads, None, head_dim), + (None, num_groups, None, head_dim), lambda core_index, b, h, *_: (b * num_cores + core_index, h, 0, 0), ) else: q_block_spec = pl.BlockSpec( - (None, num_heads // num_kv_heads, None, head_dim), + (None, num_groups, None, head_dim), lambda core_index, b, h, *_: (b, h, 0, 0), ) q_dtype_for_kernel_launch = jnp.float32 else: if megacore_mode == "kv_head": q_block_spec = pl.BlockSpec( - (None, num_heads // num_kv_heads, head_dim), + (None, num_groups, head_dim), lambda core_index, b, h, *_: (b, h * num_cores + core_index, 0), ) elif megacore_mode == "batch": q_block_spec = pl.BlockSpec( - (None, num_heads // num_kv_heads, head_dim), + (None, num_groups, head_dim), lambda core_index, b, h, *_: (b * num_cores + core_index, h, 0), ) else: q_block_spec = pl.BlockSpec( - (None, num_heads // num_kv_heads, head_dim), + (None, num_groups, head_dim), lambda core_index, b, h, *_: (b, h, 0), ) q_dtype_for_kernel_launch = q.dtype @@ -544,10 +547,10 @@ def paged_attention( if k_scales_pages is not None and v_scales_pages is not None: in_specs = [ q_block_spec, - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pl.ANY), ] scratch_shapes = ( pltpu.VMEM( @@ -586,14 +589,15 @@ def paged_attention( ), v_scales_pages.dtype, # pytype: disable=attribute-error ), # v_scales_pages buffer - pltpu.SemaphoreType.DMA, + pltpu.SemaphoreType.DMA((2,)), + pltpu.SemaphoreType.DMA((2,)), ) else: in_specs = [ q_block_spec, - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), None, # type: ignore[list-item] - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), None, # type: ignore[list-item] ] scratch_shapes = ( @@ -617,7 +621,8 @@ def paged_attention( v_pages.dtype, ), # v_pages buffer None, - pltpu.SemaphoreType.DMA, + pltpu.SemaphoreType.DMA((2,)), + pltpu.SemaphoreType.DMA((2,)), ) out, _, _ = pl.pallas_call( @@ -632,7 +637,7 @@ def paged_attention( ), grid_spec=pltpu.PrefetchScalarGridSpec( # There are 4 scalars prefetched per kernel call: `lengths_ref`, - # `page_indices_ref`, `buffer_index_ref`, `step_ref` + # `page_indices_ref`, `buffer_index_ref`, `init_flag_ref` num_scalar_prefetch=4, in_specs=in_specs, out_specs=[ @@ -643,8 +648,9 @@ def paged_attention( grid=grid, scratch_shapes=scratch_shapes, ), - compiler_params=pltpu.TPUCompilerParams( - dimension_semantics=dimension_semantics), + compiler_params=pltpu.CompilerParams( + dimension_semantics=dimension_semantics + ), out_shape=[ jax.ShapeDtypeStruct(q.shape, q_dtype_for_kernel_launch), jax.ShapeDtypeStruct((*q.shape[:-1], 1), jnp.float32), @@ -654,11 +660,11 @@ def paged_attention( lengths, page_indices.reshape(-1), jnp.zeros((1,), jnp.int32), # buffer index - jnp.zeros((1,), jnp.int32), # step + jnp.ones((1,), jnp.int32), # init flag q.astype(q_dtype_for_kernel_launch), k_pages, k_scales_pages, v_pages, v_scales_pages, ) - return out.reshape(batch_size, num_heads, head_dim).astype(q.dtype) + return out.reshape(batch_size, num_q_heads, head_dim).astype(q.dtype) diff --git a/jax/experimental/pallas/ops/tpu/paged_attention/util.py b/jax/experimental/pallas/ops/tpu/paged_attention/util.py new file mode 100644 index 000000000000..92aa3a7a1b2c --- /dev/null +++ b/jax/experimental/pallas/ops/tpu/paged_attention/util.py @@ -0,0 +1,82 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""JAX reference implementation of grouped query attention.""" + +import jax +from jax.experimental.pallas.ops.tpu.paged_attention import quantization_utils +import jax.numpy as jnp + +MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max) + + +def grouped_query_attention_reference( + queries: jax.Array, # [batch_size, num_q_heads, head_dim] + k_pages: jax.Array, # [batch_size, num_kv_heads, max_seq_len, head_dim] + v_pages: jax.Array, # [batch_size, num_kv_heads, max_seq_len, head_dim] + seq_lens: jax.Array, # i32[batch_size] + soft_cap: float | None = None, + debug: bool = False, +) -> jax.Array: # [batch_size, num_q_heads, head_dim] + """Grouped query attention with a single query per request.""" + # Check input shapes + assert k_pages.shape == v_pages.shape + batch_size, num_q_heads, head_dim = queries.shape + batch_size2, num_kv_heads, max_seq_len, head_dim2 = k_pages.shape + assert batch_size2 == batch_size + assert head_dim2 == head_dim + + # Unquantize kv pages if necessary + if isinstance(k_pages, quantization_utils.QuantizedTensor): + k_pages = quantization_utils.unquantize_from_int8( + k_pages, dtype=jnp.float32 + ) + if isinstance(v_pages, quantization_utils.QuantizedTensor): + v_pages = quantization_utils.unquantize_from_int8( + v_pages, dtype=jnp.float32 + ) + + # Reshape for num_groups queries per k head + assert num_q_heads % num_kv_heads == 0 + num_groups = num_q_heads // num_kv_heads + queries = queries.reshape(batch_size, num_kv_heads, num_groups, head_dim) + + # Compute the dot product q*k and apply soft cap if necessary + qk = jnp.einsum( + "bhgd,bhtd->bhgt", + queries.astype(jnp.float32), + k_pages.astype(jnp.float32), + ) + if soft_cap is not None and soft_cap != 0.0: + qk = jnp.tanh(qk / soft_cap) * soft_cap + assert qk.shape == (batch_size, num_kv_heads, num_groups, max_seq_len) + if debug: + jax.debug.print("qk: {qk}", qk=qk) + + # Enforce causal mask (adding dimensions when necessary) + mask = jnp.arange(max_seq_len)[None] < seq_lens[:, None] + qk += jnp.where(mask, 0.0, MASK_VALUE)[:, None, None, :] + if debug: + jax.debug.print("masked: {qk}", qk=qk) + + # Generate probability distribution using softmax + probs = jax.nn.softmax(qk, axis=-1).astype(v_pages.dtype) + assert probs.shape == (batch_size, num_kv_heads, num_groups, max_seq_len) + if debug: + jax.debug.print("softmax: {probs}", probs=probs) + + # Attention is probability-weighted sum of v heads + attention = jnp.einsum("bhgt,bhtd->bhgd", probs, v_pages) + assert attention.shape == (batch_size, num_kv_heads, num_groups, head_dim) + return attention.reshape(batch_size, num_q_heads, head_dim) diff --git a/jax/experimental/pallas/ops/tpu/ragged_paged_attention/__init__.py b/jax/experimental/pallas/ops/tpu/ragged_paged_attention/__init__.py new file mode 100644 index 000000000000..8abc4695cf96 --- /dev/null +++ b/jax/experimental/pallas/ops/tpu/ragged_paged_attention/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from jax.experimental.pallas.ops.tpu.ragged_paged_attention import kernel +from jax.experimental.pallas.ops.tpu.ragged_paged_attention import tuned_block_sizes + +dynamic_validate_inputs = kernel.dynamic_validate_inputs +ragged_paged_attention = kernel.ragged_paged_attention +ref_ragged_paged_attention = kernel.ref_ragged_paged_attention +static_validate_inputs = kernel.static_validate_inputs +get_tuned_block_sizes = tuned_block_sizes.get_tuned_block_sizes diff --git a/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py b/jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py similarity index 55% rename from jax/experimental/pallas/ops/tpu/ragged_paged_attention.py rename to jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py index 6600d765024c..5ddeab270657 100644 --- a/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py +++ b/jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py @@ -19,14 +19,16 @@ specifications. It supports mixed prefill and decoding, enhancing throughput during inference. """ - import functools import jax from jax import lax +from jax._src import dtypes from jax.experimental import pallas as pl from jax.experimental.pallas import tpu as pltpu +from jax.experimental.pallas.ops.tpu.ragged_paged_attention.tuned_block_sizes import get_tuned_block_sizes import jax.numpy as jnp + DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max) @@ -35,23 +37,20 @@ class MultiPageAsyncCopyDescriptor: def __init__( self, - pages_hbm_ref, # [total_num_pages, page_size, num_kv_heads_per_blk, head_dim] - vmem_buf, # [num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, head_dim] + pages_hbm_ref, # [total_num_pages, page_size, num_combined_kv_heads_per_blk, head_dim] + vmem_buf, # [num_kv_pages_per_blk, page_size, num_combined_kv_heads_per_blk, head_dim] sem, page_indices_ref, # i32[max_num_seqs, pages_per_seq] - offset, # [seq_idx, kv_pages_start] + metadata, # [seq_idx, start_page_idx, end_page_idx] ): self._vmem_buf = vmem_buf - seq_id, kv_pages_start = offset - pages_per_seq = page_indices_ref.shape[1] + seq_id, start_page_idx, end_page_idx = metadata self._async_copies = [] # TODO(jevinjiang): Only fetch dynamic shape in need! This will insert # a bunch of if-ops. Check the performance when we have benchmarking setup. for i in range(vmem_buf.shape[0]): - page_idx = kv_pages_start + i - page_idx = jax.lax.select( - page_idx < pages_per_seq, page_idx, pages_per_seq - 1 - ) + page_idx = start_page_idx + i + page_idx = jax.lax.select(page_idx < end_page_idx, page_idx, 0) self._async_copies.append( pltpu.make_async_copy( pages_hbm_ref.at[page_indices_ref[seq_id, page_idx]], @@ -73,17 +72,38 @@ def wait(self): def ref_ragged_paged_attention( queries: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] - k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] - v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + kv_pages: jax.Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] kv_lens: jax.Array, # i32[max_num_seqs] page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] cu_q_lens: jax.Array, # i32[max_num_seqs + 1] num_seqs: jax.Array, # i32[1], *, sm_scale: float = 1.0, - mask_value: float = DEFAULT_MASK_VALUE, + sliding_window: int | None = None, + soft_cap: float | None = None, + mask_value: float | None = DEFAULT_MASK_VALUE, + k_scale: float | None = None, + v_scale: float | None = None, ): - _, _, num_kv_heads, head_dim = k_pages.shape + static_validate_inputs( + queries, + kv_pages, + kv_lens, + page_indices, + cu_q_lens, + num_seqs, + sm_scale=sm_scale, + k_scale=k_scale, + v_scale=v_scale, + sliding_window=sliding_window, + soft_cap=soft_cap, + mask_value=mask_value, + ) + if mask_value is None: + mask_value = DEFAULT_MASK_VALUE + _, _, num_combined_kv_heads, head_dim = kv_pages.shape + assert num_combined_kv_heads % 2 == 0 + num_kv_heads = num_combined_kv_heads // 2 num_q_heads = queries.shape[1] assert num_q_heads % num_kv_heads == 0 num_query_per_kv = num_q_heads // num_kv_heads @@ -95,8 +115,18 @@ def ref_ragged_paged_attention( kv_len = kv_lens[i] indices = page_indices[i] q = queries[q_start:q_end] - k = k_pages[indices, :, :, :].reshape(-1, num_kv_heads, head_dim)[:kv_len] - v = v_pages[indices, :, :, :].reshape(-1, num_kv_heads, head_dim)[:kv_len] + k = kv_pages[indices, :, 0::2, :].reshape(-1, num_kv_heads, head_dim)[ + :kv_len + ] + v = kv_pages[indices, :, 1::2, :].reshape(-1, num_kv_heads, head_dim)[ + :kv_len + ] + if k_scale is not None: + k = k.astype(jnp.float32) * k_scale + k = k.astype(q.dtype) + if v_scale is not None: + v = v.astype(jnp.float32) * v_scale + v = v.astype(q.dtype) k = jnp.repeat(k, num_query_per_kv, axis=1) v = jnp.repeat(v, num_query_per_kv, axis=1) attn = jnp.einsum("qhd,khd->hqk", q, k, preferred_element_type=jnp.float32) @@ -105,7 +135,12 @@ def ref_ragged_paged_attention( jnp.int32, attn.shape, 1 ) kv_span = jax.lax.broadcasted_iota(jnp.int32, attn.shape, 2) - attn += jnp.where(q_span < kv_span, mask_value, 0.0) + mask = q_span < kv_span + if sliding_window is not None: + mask = jnp.logical_or(mask, q_span - sliding_window >= kv_span) + if soft_cap is not None: + attn = soft_cap * jnp.tanh(attn / soft_cap) + attn += jnp.where(mask, mask_value, 0.0) attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype) out = jnp.einsum("hqk,khd->qhd", attn, v).astype(queries.dtype) outputs.append(out) @@ -113,26 +148,51 @@ def ref_ragged_paged_attention( return jnp.concatenate(outputs, axis=0) -# Expect to run these checkes during runtime. -def validate_inputs_on_runtime( +# Expect to run these checks during runtime. +def dynamic_validate_inputs( q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] - k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] - v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + kv_pages: jax.Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] kv_lens: jax.Array, # i32[max_num_seqs] page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] cu_q_lens: jax.Array, # i32[max_num_seqs + 1] - num_seqs, # i32[1] + num_seqs: jax.Array, # i32[1] + *, + # These inputs are optional. If not specified, we will not validate them. + sm_scale: float | None = None, + sliding_window: int | None = None, + soft_cap: float | None = None, + mask_value: float | None = None, + k_scale: float | None = None, + v_scale: float | None = None, + # Kernel tuning params. + num_kv_pages_per_block: int | None = None, + num_queries_per_block: int | None = None, + vmem_limit_bytes: int | None = None, ): - check_inputs_shapes( - q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, num_seqs + static_validate_inputs( + q, + kv_pages, + kv_lens, + page_indices, + cu_q_lens, + num_seqs, + sm_scale=sm_scale, + sliding_window=sliding_window, + soft_cap=soft_cap, + mask_value=mask_value, + k_scale=k_scale, + v_scale=v_scale, + num_kv_pages_per_block=num_kv_pages_per_block, + num_queries_per_block=num_queries_per_block, + vmem_limit_bytes=vmem_limit_bytes, ) max_num_batched_tokens = q.shape[0] - page_size = k_pages.shape[1] + page_size = kv_pages.shape[1] max_num_seqs, pages_per_seq = page_indices.shape if num_seqs[0] > max_num_seqs: raise ValueError(f"{num_seqs[0]=} must be less or equal to {max_num_seqs=}") max_kv_len = jnp.max(kv_lens) - min_pages_per_seq = ceil_div(max_kv_len, page_size) + min_pages_per_seq = pl.cdiv(max_kv_len, page_size) if pages_per_seq < min_pages_per_seq: raise ValueError( f"{pages_per_seq=} must be greater or equal to" @@ -153,24 +213,35 @@ def validate_inputs_on_runtime( # Expect to run these checks during compile time. -def check_inputs_shapes( +def static_validate_inputs( q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] - k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] - v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + kv_pages: jax.Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] kv_lens: jax.Array, # i32[max_num_seqs] page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] cu_q_lens: jax.Array, # i32[max_num_seqs + 1] - num_seqs, # i32[1] + num_seqs: jax.Array, # i32[1] + *, + # These inputs are optional. If not specified, we will not validate them. + sm_scale: float | None = None, + sliding_window: int | None = None, + soft_cap: float | None = None, + mask_value: float | None = None, + k_scale: float | None = None, + v_scale: float | None = None, + # Kernel tuning params. + num_kv_pages_per_block: int | None = None, + num_queries_per_block: int | None = None, + vmem_limit_bytes: int | None = None, ): _, num_q_heads, head_dim = q.shape - _, _, num_kv_heads, head_dim_k = k_pages.shape - max_num_seqs, _ = page_indices.shape + _, _, num_combined_kv_heads, head_dim_k = kv_pages.shape + assert num_combined_kv_heads % 2 == 0 + assert isinstance(k_scale, float) or k_scale is None + assert isinstance(v_scale, float) or v_scale is None + num_kv_heads = num_combined_kv_heads // 2 + max_num_seqs, pages_per_seq = page_indices.shape if num_seqs.shape != (1,): raise ValueError(f"{num_seqs.shape=} must be (1,)") - if k_pages.shape != v_pages.shape: - raise ValueError( - f"{k_pages.shape=} and {v_pages.shape=} must have the same shape." - ) if head_dim_k != head_dim: raise ValueError( f"Q head_dim {head_dim} must be the same as that of K/V {head_dim_k}." @@ -197,6 +268,23 @@ def check_inputs_shapes( ) if num_q_heads % num_kv_heads != 0: raise ValueError(f"{num_q_heads=} must be divisible by {num_kv_heads=}") + if sliding_window is not None and sliding_window <= 0: + raise ValueError(f"{sliding_window=} must be positive.") + if soft_cap is not None and soft_cap == 0.0: + raise ValueError(f"{soft_cap=} must not be 0.0.") + if ( + num_kv_pages_per_block is not None + and not 0 < num_kv_pages_per_block <= pages_per_seq + ): + raise ValueError( + f"{num_kv_pages_per_block=} must be in range (0, {pages_per_seq}]." + ) + if num_queries_per_block is not None and num_queries_per_block <= 0: + raise ValueError(f"{num_queries_per_block=} must be positive.") + if vmem_limit_bytes is not None and vmem_limit_bytes <= 0: + raise ValueError(f"{vmem_limit_bytes=} must be positive.") + del sm_scale # No constraints on sm_scale. + del mask_value # No consstraints on mask_value. def ragged_paged_attention_kernel( @@ -209,23 +297,32 @@ def ragged_paged_attention_kernel( num_seqs_ref, # Input q_ref, # [num_q_per_blk, num_q_heads_per_blk, head_dim] - k_pages_hbm_ref, # [total_num_pages, page_size, num_kv_heads, head_dim] - v_pages_hbm_ref, # [total_num_pages, page_size, num_kv_heads, head_dim] + kv_pages_hbm_ref, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] # Output o_ref, # [num_q_per_blk, num_q_heads_per_blk, head_dim] # Scratch - k_bufs, # [2, num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, head_dim] - v_bufs, # [2, num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, head_dim] + kv_bufs, # [2, num_kv_pages_per_blk, page_size, num_combined_kv_heads_per_blk, head_dim] sems, # [2, 2] l_ref, # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128] m_ref, # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128] + acc_ref, # [num_q_per_blk, num_q_heads_per_blk, head_dim] *, sm_scale: float, - mask_value: float, + sliding_window: int | None = None, + soft_cap: float | None = None, + mask_value: float | None = DEFAULT_MASK_VALUE, + k_scale: float | None = None, + v_scale: float | None = None, ): + if mask_value is None: + mask_value = DEFAULT_MASK_VALUE num_q_per_blk, num_q_heads_per_blk, head_dim = q_ref.shape + pages_per_seq = page_indices_ref.shape[-1] num_seqs = num_seqs_ref[0] - _, num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, _ = k_bufs.shape + _, num_kv_pages_per_blk, page_size, num_combined_kv_heads_per_blk, _ = ( + kv_bufs.shape + ) + num_kv_heads_per_blk = num_combined_kv_heads_per_blk // 2 num_kv_per_blk = num_kv_pages_per_blk * page_size num_q_heads_per_kv_head = num_q_heads_per_blk // num_kv_heads_per_blk heads_blk_idx, q_blk_idx = ( @@ -241,42 +338,59 @@ def ragged_paged_attention_kernel( def create_kv_async_copy_descriptors( heads_blk_idx, seq_idx, kv_blk_idx, buf_idx ): - offset = (seq_idx, kv_blk_idx * num_kv_pages_per_blk) - heads_start = heads_blk_idx * num_kv_heads_per_blk - async_copy_k = MultiPageAsyncCopyDescriptor( - k_pages_hbm_ref.at[:, :, pl.ds(heads_start, num_kv_heads_per_blk), :], - k_bufs.at[buf_idx], - sems.at[buf_idx, 0], - page_indices_ref, - offset, + start_kv_page_idx = kv_blk_idx * num_kv_pages_per_blk + end_kv_page_idx = jnp.minimum( + pages_per_seq, pl.cdiv(kv_lens_ref[seq_idx], page_size) ) - async_copy_v = MultiPageAsyncCopyDescriptor( - v_pages_hbm_ref.at[:, :, pl.ds(heads_start, num_kv_heads_per_blk), :], - v_bufs.at[buf_idx], - sems.at[buf_idx, 1], + metadata = (seq_idx, start_kv_page_idx, end_kv_page_idx) + heads_start = heads_blk_idx * num_combined_kv_heads_per_blk + async_copy_kv = MultiPageAsyncCopyDescriptor( + kv_pages_hbm_ref.at[ + :, :, pl.ds(heads_start, num_combined_kv_heads_per_blk), : + ], + kv_bufs.at[buf_idx], + sems.at[buf_idx], page_indices_ref, - offset, + metadata, ) - return async_copy_k, async_copy_v + return async_copy_kv # TODO(jevinjiang): Add these to Mosaic: - # 1. Support arbitrary strided load/store for any dtype. + # 1. Support arbitrary strided load/store for int4 and int8 dtype. # 2. Support arbitrary strided load/store for any last dimension. def strided_load_kv(ref, start, step): - if ref.dtype == jnp.float32: - return ref[start::step, :] packing = get_dtype_packing(ref.dtype) - assert ref.dtype == jnp.bfloat16 + if packing == 1: + return [ref[start::step, :]], [ref[start + 1 :: step, :]] + assert packing in (2, 4, 8) assert step % packing == 0 + k_list, v_list = [], [] b_start = start // packing - b_offset = start % packing b_step = step // packing - b_ref = ref.bitcast(jnp.int32) + b_ref = ref.bitcast(jnp.uint32) b = b_ref[b_start::b_step, :] - bw = 32 // packing - b = jnp.right_shift(b, bw * b_offset) - b = jnp.left_shift(b, bw * (packing - 1)) - return pltpu.bitcast(b, jnp.float32).astype(jnp.bfloat16) + + # TODO(chengjiyao): use the general strided loading logic for bf16 after + # fixing the issue in mosaic's infer vector layout pass + if ref.dtype == jnp.bfloat16: + bk = b << 16 + bv = b & jnp.uint32(0xFFFF0000) + k = pltpu.bitcast(bk, jnp.float32).astype(jnp.bfloat16) + v = pltpu.bitcast(bv, jnp.float32).astype(jnp.bfloat16) + k_list.append(k) + v_list.append(v) + else: + bitwidth = 32 // packing + bitcast_dst_dtype = jnp.dtype(f"uint{bitwidth}") + for i in range(0, packing, 2): + bk = b >> (i * bitwidth) + k = pltpu.bitcast(bk.astype(bitcast_dst_dtype), ref.dtype) + k_list.append(k) + bv = b >> ((i + 1) * bitwidth) + v = pltpu.bitcast(bv.astype(bitcast_dst_dtype), ref.dtype) + v_list.append(v) + + return k_list, v_list def fold_on_2nd_minor(vec): assert vec.dtype == jnp.bfloat16 or vec.dtype == jnp.float32 @@ -289,15 +403,16 @@ def fold_on_2nd_minor(vec): @pl.when(heads_blk_idx + q_blk_idx == 0) def prefetch_first_kv_blk(): - async_copy_k, async_copy_v = create_kv_async_copy_descriptors( + async_copy_kv = create_kv_async_copy_descriptors( heads_blk_idx, init_seq_idx, 0, init_buf_idx ) - async_copy_k.start() - async_copy_v.start() + async_copy_kv.start() def is_cur_q_blk_needed(q_states): done, cur_seq_idx, _ = q_states - return jnp.logical_and(done == 0, cur_seq_idx < num_seqs) + should_run = jnp.logical_and(q_len_start < cu_q_lens_ref[num_seqs], + cur_seq_idx < num_seqs) + return jnp.logical_and(done == 0, should_run) def compute_with_cur_q_blk(q_states): done, cur_seq_idx, cur_buf_idx = q_states @@ -342,7 +457,7 @@ def flash_attention( v, # [num_kv_per_blk, head_dim] head_l_ref, # [num_q_per_blk * num_q_heads_per_kv_head, 128] head_m_ref, # [num_q_per_blk * num_q_heads_per_kv_head, 128] - head_o_ref, # [num_q_per_blk, num_q_heads_per_kv_head, head_dim] + head_acc_ref, # [num_q_per_blk, num_q_heads_per_kv_head, head_dim] *, kv_blk_idx, ): @@ -350,20 +465,24 @@ def flash_attention( num_q_per_blk * num_q_heads_per_kv_head, head_dim, ) - assert k.shape == ( - num_kv_per_blk, - head_dim, - ), f"{k.shape=}, {(num_kv_per_blk, head_dim)=} {k.dtype=}" - assert v.shape == (num_kv_per_blk, head_dim) - assert head_m_ref.shape == ( - num_q_per_blk * num_q_heads_per_kv_head, - 128, + assert ( + k.shape + == v.shape + == ( + num_kv_per_blk, + head_dim, + ) ) - assert head_l_ref.shape == ( - num_q_per_blk * num_q_heads_per_kv_head, - 128, + assert k.dtype == v.dtype + assert ( + head_m_ref.shape + == head_l_ref.shape + == ( + num_q_per_blk * num_q_heads_per_kv_head, + 128, + ) ) - assert head_o_ref.shape == ( + assert head_acc_ref.shape == ( num_q_per_blk, num_q_heads_per_kv_head, head_dim, @@ -372,8 +491,19 @@ def flash_attention( def masked_store(ref, val, start, end, group=1): iota = lax.broadcasted_iota(jnp.int32, ref.shape, 0) // group - mask = jnp.logical_and(iota >= start, iota < end) - pl.store(ref, tuple(slice(None) for _ in ref.shape), val, mask=mask) + pltpu.store(ref, val, mask=jnp.logical_and(iota >= start, iota < end)) + + def load_with_init(ref, init_val): + return jnp.where( + kv_blk_idx == 0, jnp.full_like(ref, init_val), ref[...] + ) + + # kv lens will be contracting dim, we should mask out the NaNs. + kv_mask = ( + lax.broadcasted_iota(jnp.int32, k.shape, 0) < kv_len - kv_len_start + ) + k = jnp.where(kv_mask, k.astype(jnp.float32), 0).astype(k.dtype) + v = jnp.where(kv_mask, v.astype(jnp.float32), 0).astype(v.dtype) qk = ( jnp.einsum("nd,md->nm", q, k, preferred_element_type=jnp.float32) @@ -382,29 +512,6 @@ def masked_store(ref, val, start, end, group=1): store_start = jnp.maximum(q_start - q_len_start, 0) store_end = jnp.minimum(q_end - q_len_start, num_q_per_blk) - @pl.when(kv_blk_idx == 0) - def init_scratch_ref(): - masked_store( - head_m_ref, - jnp.full_like(head_m_ref, -jnp.inf), - store_start, - store_end, - num_q_heads_per_kv_head, - ) - masked_store( - head_l_ref, - jnp.zeros_like(head_l_ref), - store_start, - store_end, - num_q_heads_per_kv_head, - ) - masked_store( - head_o_ref, - jnp.zeros_like(head_o_ref), - store_start, - store_end, - ) - row_ids = ( (kv_len - q_len) + q_len_start @@ -422,6 +529,11 @@ def init_scratch_ref(): 1, ) causal_mask = row_ids < col_ids + if sliding_window is not None: + causal_mask = jnp.logical_or(causal_mask, + row_ids - sliding_window >= col_ids) + if soft_cap is not None: + qk = soft_cap * jnp.tanh(qk / soft_cap) qk += jnp.where(causal_mask, mask_value, 0.0) m_curr = jnp.max(qk, axis=1, keepdims=True) s_curr = jnp.exp(qk - m_curr) @@ -431,8 +543,8 @@ def init_scratch_ref(): l_curr = jnp.broadcast_to( s_curr.sum(axis=1, keepdims=True), lm_store_shape ) - m_prev = head_m_ref[...] - l_prev = head_l_ref[...] + m_prev = load_with_init(head_m_ref, -jnp.inf) + l_prev = load_with_init(head_l_ref, 0.0) m_next = jnp.maximum(m_prev, m_curr) masked_store( head_m_ref, m_next, store_start, store_end, num_q_heads_per_kv_head @@ -461,17 +573,17 @@ def broadcast_to_shape(arr, shape): [arr for _ in range(shape[1] // arr.shape[1])], axis=1 ) - o_curr = head_o_ref[...].reshape(-1, head_dim) + o_curr = load_with_init(head_acc_ref, 0.0).reshape(-1, head_dim) l_alpha = broadcast_to_shape(l_alpha, qkv.shape) beta = broadcast_to_shape(beta, qkv.shape) l_next_safe = broadcast_to_shape(l_next_safe, qkv.shape) out = lax.div( l_alpha * o_curr + beta * qkv, l_next_safe, - ).astype(head_o_ref.dtype) + ) masked_store( - head_o_ref, - out.reshape(head_o_ref.shape), + head_acc_ref, + out.reshape(head_acc_ref.shape), store_start, store_end, ) @@ -493,39 +605,54 @@ def prefetch_next_kv_blk(): # TODO(jevinjiang): reuse the same buffer if it is already prefetched! # TODO(jevinjiang): only fetch effective dynamic size to hold kv_len and # DMA to fixed size buffer! - next_async_copy_k, next_async_copy_v = create_kv_async_copy_descriptors( + next_async_copy_kv = create_kv_async_copy_descriptors( next_heads_blk_idx, next_seq_idx, next_kv_blk_idx, next_buf_idx ) - next_async_copy_k.start() - next_async_copy_v.start() + next_async_copy_kv.start() - cur_async_copy_k, cur_async_copy_v = create_kv_async_copy_descriptors( + cur_async_copy_kv = create_kv_async_copy_descriptors( heads_blk_idx, cur_seq_idx, kv_blk_idx, cur_buf_idx ) - kv_to_load_shape = ( - num_kv_pages_per_blk * page_size * num_kv_heads_per_blk, + kv_ref = cur_async_copy_kv.wait().reshape( + num_kv_pages_per_blk * page_size * num_combined_kv_heads_per_blk, head_dim, ) - k_ref = cur_async_copy_k.wait().reshape(kv_to_load_shape) - v_ref = cur_async_copy_v.wait().reshape(kv_to_load_shape) - for kv_head_idx in range(num_kv_heads_per_blk): - q_head_idx = kv_head_idx * num_q_heads_per_kv_head - # TODO(jevinjiang): extra handlig for packed type that can start at - # unaligned position! - q = fold_on_2nd_minor( - q_ref[:, q_head_idx : q_head_idx + num_q_heads_per_kv_head, :] - ) - k = strided_load_kv(k_ref, kv_head_idx, num_kv_heads_per_blk) - v = strided_load_kv(v_ref, kv_head_idx, num_kv_heads_per_blk) - flash_attention( - q, - k, - v, - l_ref.at[kv_head_idx], - m_ref.at[kv_head_idx], - o_ref.at[:, q_head_idx : q_head_idx + num_q_heads_per_kv_head, :], - kv_blk_idx=kv_blk_idx, + kv_packing = get_dtype_packing(kv_ref.dtype) + # NOTE: kv_packing is divided by 2 because k and v are packed together. + kv_load_step = max(1, kv_packing // 2) + for kv_head_chunk_idx in range(0, num_kv_heads_per_blk, kv_load_step): + k_list, v_list = strided_load_kv( + kv_ref, kv_head_chunk_idx * 2, num_combined_kv_heads_per_blk ) + for step_idx in range(kv_load_step): + k = k_list[step_idx] + v = v_list[step_idx] + if k_scale is not None: + # NOTE: Conversion between arbitrary data types is not supported. + # That's why it is converted to float32 first. + k = k.astype(jnp.float32) * k_scale + k = k.astype(q_ref.dtype) + if v_scale is not None: + v = v.astype(jnp.float32) * v_scale + v = v.astype(q_ref.dtype) + kv_head_idx = kv_head_chunk_idx + step_idx + q_head_idx = kv_head_idx * num_q_heads_per_kv_head + # TODO(jevinjiang): extra handling for packed type that can start at + # unaligned position! + q = fold_on_2nd_minor( + q_ref[:, q_head_idx : q_head_idx + num_q_heads_per_kv_head, :] + ) + flash_attention( + q, + k, + v, + l_ref.at[kv_head_idx], + m_ref.at[kv_head_idx], + acc_ref.at[ + :, q_head_idx : q_head_idx + num_q_heads_per_kv_head, : + ], + kv_blk_idx=kv_blk_idx, + ) return kv_blk_idx + 1, next_buf_idx _, next_buf_idx = lax.while_loop( @@ -545,26 +672,17 @@ def prefetch_next_kv_blk(): # Reset seq_idx for next kv_heads_blk if run out of seqs! seq_buf_idx_ref[0] = lax.select(seq_idx < num_seqs, seq_idx, 0) seq_buf_idx_ref[1] = buf_idx + o_ref[...] = acc_ref[...].astype(q_ref.dtype) -def ceil_div(a, b): - assert b != 0 - return (a + b - 1) // b +def get_dtype_packing(dtype): + bits = dtypes.itemsize_bits(dtype) + return 32 // bits -def get_dtype_packing(dtype): - if dtype == jnp.float32: - return 1 - if dtype == jnp.bfloat16: - return 2 - if dtype == jnp.int8: - return 4 - if dtype == jnp.int4: - return 8 - raise ValueError(f"Not implemented: unsupported {dtype=}") - - -def get_min_heads_per_blk(num_q_heads, num_kv_heads, q_dtype, kv_dtype): +def get_min_heads_per_blk( + num_q_heads, num_combined_kv_heads, q_dtype, kv_dtype +): q_packing = get_dtype_packing(q_dtype) kv_packing = get_dtype_packing(kv_dtype) @@ -575,22 +693,26 @@ def can_be_xla_fully_tiled(x, packing): return x in (1, 2, 4, 8) or x % 8 == 0 # TODO(jevinjiang): support unaligned number of heads! - if not can_be_xla_fully_tiled(num_kv_heads, kv_packing): + if not can_be_xla_fully_tiled(num_combined_kv_heads, kv_packing): raise ValueError( - f"Not implemented: {num_kv_heads=} can not be XLA fully tiled." + f"Not implemented: {num_combined_kv_heads=} can not be XLA fully tiled." ) + assert num_combined_kv_heads % 2 == 0 + num_kv_heads = num_combined_kv_heads // 2 assert num_q_heads % num_kv_heads == 0 ratio = num_q_heads // num_kv_heads # TODO(jevinjiang): we can choose smaller tiling for packed type if large # second minor tiling is not on. - max_kv_tiling = 8 * kv_packing - min_kv_heads = ( - max_kv_tiling if num_kv_heads % max_kv_tiling == 0 else num_kv_heads + max_combined_kv_tiling = 8 * kv_packing + min_combined_kv_heads = ( + max_combined_kv_tiling + if num_combined_kv_heads % max_combined_kv_tiling == 0 + else num_combined_kv_heads ) - min_q_heads = min_kv_heads * ratio + min_q_heads = min_combined_kv_heads // 2 * ratio if can_be_xla_fully_tiled(min_q_heads, q_packing): - return min_q_heads, min_kv_heads - return num_q_heads, num_kv_heads + return min_q_heads, min_combined_kv_heads + return num_q_heads, num_combined_kv_heads @functools.partial( @@ -601,30 +723,36 @@ def can_be_xla_fully_tiled(x, packing): "num_kv_pages_per_block", "num_queries_per_block", "vmem_limit_bytes", + "sliding_window", + "soft_cap", + "k_scale", + "v_scale", ], ) def ragged_paged_attention( q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] # TODO(jevinjiang): create a write_to_kv_cache kernel! - k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] - v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + kv_pages: jax.Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] kv_lens: jax.Array, # i32[max_num_seqs] page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] cu_q_lens: jax.Array, # i32[max_num_seqs + 1] num_seqs: jax.Array, # i32[1] *, sm_scale: float = 1.0, - mask_value: float = DEFAULT_MASK_VALUE, - num_kv_pages_per_block: int = 16, - num_queries_per_block: int = 128, + sliding_window: int | None = None, + soft_cap: float | None = None, + mask_value: float | None = DEFAULT_MASK_VALUE, + k_scale: float | None = None, + v_scale: float | None = None, + num_kv_pages_per_block: int | None = None, + num_queries_per_block: int | None = None, vmem_limit_bytes: int | None = None, ): """Ragged paged attention that supports mixed prefill and decode. Args: q: concatenated all sequences' queries. - k_pages: paged K cache. Normally in HBM. - v_pages: paged V cache. Normally in HBM. + kv_pages: paged KV cache. Normally in HBM. kv_lens: padded kv lengths. Only the first num_seqs values are valid. page_indices: the first index indicates which page to use in the kv cache for each sequence. Only the first num_seqs values are valid. @@ -632,7 +760,11 @@ def ragged_paged_attention( kv_lens, only the first num_seqs+1 values are valid. num_seqs: the dynamic number of sequences. sm_scale: the softmax scale which will be applied to the Q@K^T. + sliding_window: the sliding window size for the attention. + soft_cap: the logit soft cap for the attention. mask_value: mask value for causal mask. + k_scale: the scale for the key cache. + v_scale: the scale for the value cache. num_kv_pages_per_block: number of kv pages to be processed in one flash attention block in the pallas kernel. num_queries_per_block: number of kv pages to be processed in one flash @@ -642,18 +774,50 @@ def ragged_paged_attention( Returns: The output of the attention. """ - check_inputs_shapes( - q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, num_seqs + static_validate_inputs( + q, + kv_pages, + kv_lens, + page_indices, + cu_q_lens, + num_seqs, + sm_scale=sm_scale, + sliding_window=sliding_window, + soft_cap=soft_cap, + mask_value=mask_value, + k_scale=k_scale, + v_scale=v_scale, + num_kv_pages_per_block=num_kv_pages_per_block, + num_queries_per_block=num_queries_per_block, + vmem_limit_bytes=vmem_limit_bytes, + ) + if mask_value is None: + mask_value = DEFAULT_MASK_VALUE + num_q_tokens, num_q_heads, head_dim = q.shape + _, page_size, num_combined_kv_heads, _ = kv_pages.shape + assert num_combined_kv_heads % 2 == 0 + num_kv_heads = num_combined_kv_heads // 2 + _, pages_per_seq = page_indices.shape + num_q_heads_per_blk, num_combined_kv_heads_per_blk = get_min_heads_per_blk( + num_q_heads, num_combined_kv_heads, q.dtype, kv_pages.dtype ) - _, num_q_heads, head_dim = q.shape - _, page_size, num_kv_heads, _ = k_pages.shape num_q_per_blk = num_queries_per_block num_kv_pages_per_blk = num_kv_pages_per_block + if num_q_per_blk is None or num_kv_pages_per_blk is None: + num_kv_pages_per_blk, num_q_per_blk = get_tuned_block_sizes( + q.dtype, + kv_pages.dtype, + num_q_heads_per_blk, + num_combined_kv_heads_per_blk // 2, + head_dim, + page_size, + num_q_tokens, + pages_per_seq, + ) num_q_heads_per_kv_head = num_q_heads // num_kv_heads - num_q_blks = ceil_div(cu_q_lens[num_seqs[0]], num_q_per_blk) - num_q_heads_per_blk, num_kv_heads_per_blk = get_min_heads_per_blk( - num_q_heads, num_kv_heads, q.dtype, k_pages.dtype - ) + num_q_blks = pl.cdiv(num_q_tokens, num_q_per_blk) + assert num_combined_kv_heads_per_blk % 2 == 0 + num_kv_heads_per_blk = num_combined_kv_heads_per_blk // 2 assert num_q_heads_per_blk % num_q_heads_per_kv_head == 0 num_heads_blks = num_q_heads // num_q_heads_per_blk grid = (num_heads_blks, num_q_blks) @@ -667,8 +831,7 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_): ) in_specs = [ q_block_spec, - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), ] out_specs = q_block_spec lm_scratch = pltpu.VMEM( @@ -677,22 +840,26 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_): (num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128), jnp.float32, ) + acc_scratch = pltpu.VMEM( + (num_q_per_blk, num_q_heads_per_blk, head_dim), + jnp.float32, + ) double_buf_scratch = pltpu.VMEM( ( 2, # For double buffering during DMA copies. num_kv_pages_per_blk, page_size, - num_kv_heads_per_blk, + num_combined_kv_heads_per_blk, head_dim, ), - k_pages.dtype, + kv_pages.dtype, ) scratch_shapes = [ - double_buf_scratch, # k_bufs - double_buf_scratch, # v_bufs - pltpu.SemaphoreType.DMA((2, 2)), # [double_buffers, k_sem/v_sem] + double_buf_scratch, # kv_bufs + pltpu.SemaphoreType.DMA((2,)), # Semaphores for double buffers. lm_scratch, # l_ref lm_scratch, # m_ref + acc_scratch, ] scalar_prefetches = ( kv_lens, @@ -705,7 +872,11 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_): functools.partial( ragged_paged_attention_kernel, sm_scale=sm_scale, + sliding_window=sliding_window, + soft_cap=soft_cap, mask_value=mask_value, + k_scale=k_scale, + v_scale=v_scale, ), grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=len(scalar_prefetches), @@ -714,16 +885,15 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_): grid=grid, scratch_shapes=scratch_shapes, ), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=( "arbitrary", "arbitrary", ), vmem_limit_bytes=vmem_limit_bytes, ), - out_shape=jax.ShapeDtypeStruct(shape=q.shape, dtype=jnp.float32), + out_shape=jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype), name="ragged_paged_attention_kernel", ) - # TODO(jevinjiang): Use f32 acc scratch for output! So we only need - # to transfer output with desired dtype back to HBM. - return kernel(*scalar_prefetches, q, k_pages, v_pages).astype(q.dtype) + + return kernel(*scalar_prefetches, q, kv_pages) diff --git a/jax/experimental/pallas/ops/tpu/ragged_paged_attention/tuned_block_sizes.py b/jax/experimental/pallas/ops/tpu/ragged_paged_attention/tuned_block_sizes.py new file mode 100644 index 000000000000..f19fb55c6a52 --- /dev/null +++ b/jax/experimental/pallas/ops/tpu/ragged_paged_attention/tuned_block_sizes.py @@ -0,0 +1,1482 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Auto-tuned block sizes for ragged paged attention.""" + +import jax +import jax.numpy as jnp + +# The page size is too small. We only have 32 SREGs in TC. If the pages +# per seq is too large, SREGs will spill. +MAX_PAGES_PER_SEQ = 16 + +# key: +# - q_dtype_name +# - kv_dtype_name +# - num_q_heads_per_blk +# - num_kv_heads_per_blk +# - head_dim +# - page_size +# - max_num_batched_tokens +# - max_model_len = page_size * pages_per_seq +# value: +# - num_kv_pages_per_block +# - num_queries_per_block +TUNED_BLOCK_SIZES = { + 'TPU v6': { + # go/keep-sorted start + ('bfloat16', 'bfloat16', 12, 2, 128, 128, 1024, 1024): (8, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 128, 1024, 1280): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 128, 1024, 2048): (16, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 128, 1024, 512): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 128, 2048, 1024): (8, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 128, 2048, 1280): (8, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 128, 2048, 2048): (16, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 128, 2048, 512): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 128, 4096, 1024): (8, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 128, 4096, 1280): (8, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 128, 4096, 2048): (16, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 128, 4096, 512): (4, 64), + ('bfloat16', 'bfloat16', 12, 2, 128, 128, 512, 1024): (8, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 128, 512, 1280): (8, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 128, 512, 2048): (16, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 128, 512, 512): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 16, 1024, 128): (8, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 16, 1024, 256): (16, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 16, 1024, 64): (4, 64), + ('bfloat16', 'bfloat16', 12, 2, 128, 16, 2048, 128): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 16, 2048, 256): (16, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 16, 2048, 64): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 16, 4096, 128): (8, 64), + ('bfloat16', 'bfloat16', 12, 2, 128, 16, 4096, 256): (16, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 16, 4096, 64): (4, 64), + ('bfloat16', 'bfloat16', 12, 2, 128, 16, 512, 128): (8, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 16, 512, 256): (16, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 16, 512, 64): (4, 64), + ('bfloat16', 'bfloat16', 12, 2, 128, 256, 1024, 1024): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 256, 1024, 1280): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 256, 1024, 2048): (8, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 256, 1024, 4096): (16, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 256, 2048, 1024): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 256, 2048, 1280): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 256, 2048, 2048): (8, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 256, 2048, 4096): (16, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 256, 4096, 1024): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 256, 4096, 1280): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 256, 4096, 2048): (8, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 256, 4096, 4096): (16, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 256, 512, 1024): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 256, 512, 1280): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 256, 512, 2048): (8, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 256, 512, 4096): (16, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 32, 1024, 128): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 32, 1024, 256): (8, 64), + ('bfloat16', 'bfloat16', 12, 2, 128, 32, 1024, 512): (16, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 32, 2048, 128): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 32, 2048, 256): (8, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 32, 2048, 512): (16, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 32, 4096, 128): (4, 64), + ('bfloat16', 'bfloat16', 12, 2, 128, 32, 4096, 256): (8, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 32, 4096, 512): (16, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 32, 512, 128): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 32, 512, 256): (8, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 32, 512, 512): (16, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 64, 1024, 1024): (16, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 64, 1024, 256): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 64, 1024, 512): (8, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 64, 2048, 1024): (16, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 64, 2048, 256): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 64, 2048, 512): (8, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 64, 4096, 1024): (16, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 64, 4096, 256): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 64, 4096, 512): (8, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 64, 512, 1024): (16, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 64, 512, 256): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 64, 512, 512): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 128, 1024, 1024): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 128, 1024, 1280): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 128, 1024, 2048): (16, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 128, 1024, 512): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 128, 2048, 1024): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 128, 2048, 1280): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 128, 2048, 2048): (16, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 128, 2048, 512): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 128, 4096, 1024): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 128, 4096, 1280): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 128, 4096, 2048): (16, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 128, 4096, 512): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 128, 512, 1024): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 128, 512, 1280): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 128, 512, 2048): (16, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 128, 512, 512): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 16, 1024, 128): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 16, 1024, 256): (16, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 16, 1024, 64): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 16, 2048, 128): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 16, 2048, 256): (16, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 16, 2048, 64): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 16, 4096, 128): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 16, 4096, 256): (16, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 16, 4096, 64): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 16, 512, 128): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 16, 512, 256): (16, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 16, 512, 64): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 256, 1024, 1024): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 256, 1024, 1280): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 256, 1024, 2048): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 256, 1024, 4096): (16, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 256, 2048, 1024): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 256, 2048, 1280): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 256, 2048, 2048): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 256, 2048, 4096): (16, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 256, 4096, 1024): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 256, 4096, 1280): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 256, 4096, 2048): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 256, 4096, 4096): (16, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 256, 512, 1024): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 256, 512, 1280): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 256, 512, 2048): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 256, 512, 4096): (16, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 32, 1024, 128): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 32, 1024, 256): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 32, 1024, 512): (16, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 32, 2048, 128): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 32, 2048, 256): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 32, 2048, 512): (16, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 32, 4096, 128): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 32, 4096, 256): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 32, 4096, 512): (16, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 32, 512, 128): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 32, 512, 256): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 32, 512, 512): (16, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 64, 1024, 1024): (16, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 64, 1024, 256): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 64, 1024, 512): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 64, 2048, 1024): (16, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 64, 2048, 256): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 64, 2048, 512): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 64, 4096, 1024): (16, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 64, 4096, 256): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 64, 4096, 512): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 64, 512, 1024): (16, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 64, 512, 256): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 64, 512, 512): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 1024): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 1280): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 512): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 1024): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 1280): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 512): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 1024): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 1280): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 512): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 1024): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 1280): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 512): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 128): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 256): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 64): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 128): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 256): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 64): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 128): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 256): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 64): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 128): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 256): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 64): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 1024, 1024): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 1024, 1280): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 2048, 1024): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 2048, 1280): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 4096, 1024): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 4096, 1280): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 512, 1024): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 512, 1280): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 128): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 256): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 512): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 128): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 256): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 512): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 128): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 256): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 512): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 128): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 256): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 512): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 1024): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 256): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 512): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 1024): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 256): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 512): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 1024): (16, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 256): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 512): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 1024): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 256): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 512): (8, 32), + ('bfloat16', 'bfloat16', 4, 1, 128, 128, 1024, 1024): (8, 128), + ('bfloat16', 'bfloat16', 4, 1, 128, 128, 1024, 1280): (4, 128), + ('bfloat16', 'bfloat16', 4, 1, 128, 128, 1024, 2048): (8, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 128, 1024, 512): (4, 32), + ('bfloat16', 'bfloat16', 4, 1, 128, 128, 2048, 1024): (8, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 128, 2048, 1280): (4, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 128, 2048, 2048): (8, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 128, 2048, 512): (4, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 128, 4096, 1024): (8, 128), + ('bfloat16', 'bfloat16', 4, 1, 128, 128, 4096, 1280): (8, 128), + ('bfloat16', 'bfloat16', 4, 1, 128, 128, 4096, 2048): (16, 128), + ('bfloat16', 'bfloat16', 4, 1, 128, 128, 4096, 512): (4, 128), + ('bfloat16', 'bfloat16', 4, 1, 128, 128, 512, 1024): (8, 32), + ('bfloat16', 'bfloat16', 4, 1, 128, 128, 512, 1280): (4, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 128, 512, 2048): (16, 32), + ('bfloat16', 'bfloat16', 4, 1, 128, 128, 512, 512): (4, 32), + ('bfloat16', 'bfloat16', 4, 1, 128, 16, 1024, 128): (8, 128), + ('bfloat16', 'bfloat16', 4, 1, 128, 16, 1024, 256): (16, 32), + ('bfloat16', 'bfloat16', 4, 1, 128, 16, 1024, 64): (4, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 16, 2048, 128): (8, 128), + ('bfloat16', 'bfloat16', 4, 1, 128, 16, 2048, 256): (16, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 16, 2048, 64): (4, 128), + ('bfloat16', 'bfloat16', 4, 1, 128, 16, 4096, 128): (8, 128), + ('bfloat16', 'bfloat16', 4, 1, 128, 16, 4096, 256): (16, 256), + ('bfloat16', 'bfloat16', 4, 1, 128, 16, 4096, 64): (4, 256), + ('bfloat16', 'bfloat16', 4, 1, 128, 16, 512, 128): (8, 256), + ('bfloat16', 'bfloat16', 4, 1, 128, 16, 512, 256): (16, 128), + ('bfloat16', 'bfloat16', 4, 1, 128, 16, 512, 64): (4, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 256, 1024, 1024): (4, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 256, 1024, 1280): (4, 128), + ('bfloat16', 'bfloat16', 4, 1, 128, 256, 1024, 2048): (4, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 256, 1024, 4096): (8, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 256, 2048, 1024): (4, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 256, 2048, 1280): (4, 128), + ('bfloat16', 'bfloat16', 4, 1, 128, 256, 2048, 2048): (8, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 256, 2048, 4096): (16, 32), + ('bfloat16', 'bfloat16', 4, 1, 128, 256, 4096, 1024): (4, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 256, 4096, 1280): (4, 128), + ('bfloat16', 'bfloat16', 4, 1, 128, 256, 4096, 2048): (8, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 256, 4096, 4096): (16, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 256, 512, 1024): (4, 32), + ('bfloat16', 'bfloat16', 4, 1, 128, 256, 512, 1280): (4, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 256, 512, 2048): (8, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 256, 512, 4096): (16, 32), + ('bfloat16', 'bfloat16', 4, 1, 128, 32, 1024, 128): (4, 32), + ('bfloat16', 'bfloat16', 4, 1, 128, 32, 1024, 256): (8, 128), + ('bfloat16', 'bfloat16', 4, 1, 128, 32, 1024, 512): (16, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 32, 2048, 128): (4, 128), + ('bfloat16', 'bfloat16', 4, 1, 128, 32, 2048, 256): (8, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 32, 2048, 512): (16, 128), + ('bfloat16', 'bfloat16', 4, 1, 128, 32, 4096, 128): (4, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 32, 4096, 256): (8, 128), + ('bfloat16', 'bfloat16', 4, 1, 128, 32, 4096, 512): (16, 256), + ('bfloat16', 'bfloat16', 4, 1, 128, 32, 512, 128): (4, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 32, 512, 256): (4, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 32, 512, 512): (16, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 64, 1024, 1024): (16, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 64, 1024, 256): (4, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 64, 1024, 512): (8, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 64, 2048, 1024): (16, 128), + ('bfloat16', 'bfloat16', 4, 1, 128, 64, 2048, 256): (4, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 64, 2048, 512): (8, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 64, 4096, 1024): (16, 128), + ('bfloat16', 'bfloat16', 4, 1, 128, 64, 4096, 256): (4, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 64, 4096, 512): (8, 128), + ('bfloat16', 'bfloat16', 4, 1, 128, 64, 512, 1024): (16, 32), + ('bfloat16', 'bfloat16', 4, 1, 128, 64, 512, 256): (4, 32), + ('bfloat16', 'bfloat16', 4, 1, 128, 64, 512, 512): (8, 128), + ('bfloat16', 'bfloat16', 4, 2, 128, 128, 1024, 1024): (8, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 128, 1024, 1280): (8, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 128, 1024, 2048): (16, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 128, 1024, 512): (4, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 128, 2048, 1024): (8, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 128, 2048, 1280): (8, 128), + ('bfloat16', 'bfloat16', 4, 2, 128, 128, 2048, 2048): (16, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 128, 2048, 512): (4, 128), + ('bfloat16', 'bfloat16', 4, 2, 128, 128, 4096, 1024): (8, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 128, 4096, 1280): (4, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 128, 4096, 2048): (16, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 128, 4096, 512): (4, 128), + ('bfloat16', 'bfloat16', 4, 2, 128, 128, 512, 1024): (8, 32), + ('bfloat16', 'bfloat16', 4, 2, 128, 128, 512, 1280): (4, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 128, 512, 2048): (16, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 128, 512, 512): (4, 32), + ('bfloat16', 'bfloat16', 4, 2, 128, 16, 1024, 128): (8, 32), + ('bfloat16', 'bfloat16', 4, 2, 128, 16, 1024, 256): (16, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 16, 1024, 64): (4, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 16, 2048, 128): (8, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 16, 2048, 256): (16, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 16, 2048, 64): (4, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 16, 4096, 128): (8, 128), + ('bfloat16', 'bfloat16', 4, 2, 128, 16, 4096, 256): (16, 128), + ('bfloat16', 'bfloat16', 4, 2, 128, 16, 4096, 64): (4, 128), + ('bfloat16', 'bfloat16', 4, 2, 128, 16, 512, 128): (4, 32), + ('bfloat16', 'bfloat16', 4, 2, 128, 16, 512, 256): (16, 32), + ('bfloat16', 'bfloat16', 4, 2, 128, 16, 512, 64): (4, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 256, 1024, 1024): (4, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 256, 1024, 1280): (4, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 256, 1024, 2048): (8, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 256, 1024, 4096): (16, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 256, 2048, 1024): (4, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 256, 2048, 1280): (4, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 256, 2048, 2048): (8, 128), + ('bfloat16', 'bfloat16', 4, 2, 128, 256, 2048, 4096): (16, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 256, 4096, 1024): (4, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 256, 4096, 1280): (4, 128), + ('bfloat16', 'bfloat16', 4, 2, 128, 256, 4096, 2048): (8, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 256, 4096, 4096): (16, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 256, 512, 1024): (4, 32), + ('bfloat16', 'bfloat16', 4, 2, 128, 256, 512, 1280): (4, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 256, 512, 2048): (8, 32), + ('bfloat16', 'bfloat16', 4, 2, 128, 256, 512, 4096): (16, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 32, 1024, 128): (4, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 32, 1024, 256): (8, 32), + ('bfloat16', 'bfloat16', 4, 2, 128, 32, 1024, 512): (16, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 32, 2048, 128): (4, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 32, 2048, 256): (8, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 32, 2048, 512): (16, 128), + ('bfloat16', 'bfloat16', 4, 2, 128, 32, 4096, 128): (4, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 32, 4096, 256): (8, 128), + ('bfloat16', 'bfloat16', 4, 2, 128, 32, 4096, 512): (16, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 32, 512, 128): (4, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 32, 512, 256): (8, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 32, 512, 512): (16, 128), + ('bfloat16', 'bfloat16', 4, 2, 128, 64, 1024, 1024): (16, 128), + ('bfloat16', 'bfloat16', 4, 2, 128, 64, 1024, 256): (4, 128), + ('bfloat16', 'bfloat16', 4, 2, 128, 64, 1024, 512): (8, 128), + ('bfloat16', 'bfloat16', 4, 2, 128, 64, 2048, 1024): (16, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 64, 2048, 256): (4, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 64, 2048, 512): (8, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 64, 4096, 1024): (16, 128), + ('bfloat16', 'bfloat16', 4, 2, 128, 64, 4096, 256): (4, 128), + ('bfloat16', 'bfloat16', 4, 2, 128, 64, 4096, 512): (8, 128), + ('bfloat16', 'bfloat16', 4, 2, 128, 64, 512, 1024): (16, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 64, 512, 256): (4, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 64, 512, 512): (8, 64), + ('bfloat16', 'bfloat16', 5, 1, 128, 128, 1024, 1024): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 128, 1024, 1280): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 128, 1024, 2048): (16, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 128, 1024, 512): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 128, 2048, 1024): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 128, 2048, 1280): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 128, 2048, 2048): (16, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 128, 2048, 512): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 128, 4096, 1024): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 128, 4096, 1280): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 128, 4096, 2048): (16, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 128, 4096, 512): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 128, 512, 1024): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 128, 512, 1280): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 128, 512, 2048): (16, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 128, 512, 512): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 16, 1024, 128): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 16, 1024, 256): (16, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 16, 1024, 64): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 16, 2048, 128): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 16, 2048, 256): (16, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 16, 2048, 64): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 16, 4096, 128): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 16, 4096, 256): (16, 64), + ('bfloat16', 'bfloat16', 5, 1, 128, 16, 4096, 64): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 16, 512, 128): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 16, 512, 256): (16, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 16, 512, 64): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 256, 1024, 1024): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 256, 1024, 1280): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 256, 1024, 2048): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 256, 1024, 4096): (16, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 256, 2048, 1024): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 256, 2048, 1280): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 256, 2048, 2048): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 256, 2048, 4096): (16, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 256, 4096, 1024): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 256, 4096, 1280): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 256, 4096, 2048): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 256, 4096, 4096): (16, 64), + ('bfloat16', 'bfloat16', 5, 1, 128, 256, 512, 1024): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 256, 512, 1280): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 256, 512, 2048): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 256, 512, 4096): (16, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 32, 1024, 128): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 32, 1024, 256): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 32, 1024, 512): (16, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 32, 2048, 128): (4, 64), + ('bfloat16', 'bfloat16', 5, 1, 128, 32, 2048, 256): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 32, 2048, 512): (16, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 32, 4096, 128): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 32, 4096, 256): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 32, 4096, 512): (16, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 32, 512, 128): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 32, 512, 256): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 32, 512, 512): (16, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 64, 1024, 1024): (16, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 64, 1024, 256): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 64, 1024, 512): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 64, 2048, 1024): (16, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 64, 2048, 256): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 64, 2048, 512): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 64, 4096, 1024): (16, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 64, 4096, 256): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 64, 4096, 512): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 64, 512, 1024): (16, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 64, 512, 256): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 64, 512, 512): (8, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 128, 1024, 1024): (8, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 128, 1024, 1280): (8, 64), + ('bfloat16', 'bfloat16', 6, 1, 128, 128, 1024, 2048): (16, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 128, 1024, 512): (4, 64), + ('bfloat16', 'bfloat16', 6, 1, 128, 128, 2048, 1024): (8, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 128, 2048, 1280): (8, 64), + ('bfloat16', 'bfloat16', 6, 1, 128, 128, 2048, 2048): (16, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 128, 2048, 512): (4, 64), + ('bfloat16', 'bfloat16', 6, 1, 128, 128, 4096, 1024): (8, 64), + ('bfloat16', 'bfloat16', 6, 1, 128, 128, 4096, 1280): (8, 64), + ('bfloat16', 'bfloat16', 6, 1, 128, 128, 4096, 2048): (16, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 128, 4096, 512): (4, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 128, 512, 1024): (8, 64), + ('bfloat16', 'bfloat16', 6, 1, 128, 128, 512, 1280): (4, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 128, 512, 2048): (16, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 128, 512, 512): (4, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 16, 1024, 128): (8, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 16, 1024, 256): (16, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 16, 1024, 64): (4, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 16, 2048, 128): (8, 64), + ('bfloat16', 'bfloat16', 6, 1, 128, 16, 2048, 256): (16, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 16, 2048, 64): (4, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 16, 4096, 128): (8, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 16, 4096, 256): (16, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 16, 4096, 64): (4, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 16, 512, 128): (4, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 16, 512, 256): (16, 64), + ('bfloat16', 'bfloat16', 6, 1, 128, 16, 512, 64): (4, 64), + ('bfloat16', 'bfloat16', 6, 1, 128, 256, 1024, 1024): (4, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 256, 1024, 1280): (4, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 256, 1024, 2048): (8, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 256, 1024, 4096): (8, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 256, 2048, 1024): (4, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 256, 2048, 1280): (4, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 256, 2048, 2048): (8, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 256, 2048, 4096): (16, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 256, 4096, 1024): (4, 64), + ('bfloat16', 'bfloat16', 6, 1, 128, 256, 4096, 1280): (4, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 256, 4096, 2048): (8, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 256, 4096, 4096): (16, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 256, 512, 1024): (4, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 256, 512, 1280): (4, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 256, 512, 2048): (8, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 256, 512, 4096): (16, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 32, 1024, 128): (4, 64), + ('bfloat16', 'bfloat16', 6, 1, 128, 32, 1024, 256): (8, 64), + ('bfloat16', 'bfloat16', 6, 1, 128, 32, 1024, 512): (16, 64), + ('bfloat16', 'bfloat16', 6, 1, 128, 32, 2048, 128): (4, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 32, 2048, 256): (8, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 32, 2048, 512): (16, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 32, 4096, 128): (4, 64), + ('bfloat16', 'bfloat16', 6, 1, 128, 32, 4096, 256): (8, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 32, 4096, 512): (16, 128), + ('bfloat16', 'bfloat16', 6, 1, 128, 32, 512, 128): (4, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 32, 512, 256): (8, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 32, 512, 512): (16, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 64, 1024, 1024): (16, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 64, 1024, 256): (4, 64), + ('bfloat16', 'bfloat16', 6, 1, 128, 64, 1024, 512): (8, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 64, 2048, 1024): (16, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 64, 2048, 256): (4, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 64, 2048, 512): (8, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 64, 4096, 1024): (16, 64), + ('bfloat16', 'bfloat16', 6, 1, 128, 64, 4096, 256): (4, 128), + ('bfloat16', 'bfloat16', 6, 1, 128, 64, 4096, 512): (8, 64), + ('bfloat16', 'bfloat16', 6, 1, 128, 64, 512, 1024): (8, 64), + ('bfloat16', 'bfloat16', 6, 1, 128, 64, 512, 256): (4, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 64, 512, 512): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 1024): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 1280): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 2048): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 512): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 1024): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 1280): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 2048): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 512): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 1024): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 1280): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 2048): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 512): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 1024): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 1280): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 2048): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 512): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 128): (8, 128), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 256): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 64): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 128): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 256): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 64): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 128): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 256): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 64): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 128): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 256): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 64): (4, 128), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 1024): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 1280): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 2048): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 4096): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 1024): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 1280): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 2048): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 4096): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 1024): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 1280): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 2048): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 4096): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 1024): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 1280): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 2048): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 4096): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 128): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 256): (8, 128), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 512): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 128): (4, 128), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 256): (8, 128), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 512): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 128): (4, 128), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 256): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 512): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 128): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 256): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 512): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 1024): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 256): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 512): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 1024): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 256): (4, 128), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 512): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 1024): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 256): (4, 128), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 512): (8, 128), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 1024): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 256): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 512): (8, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 128, 1024, 1024): (8, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 128, 1024, 1280): (8, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 128, 1024, 2048): (16, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 128, 1024, 512): (4, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 128, 2048, 1024): (8, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 128, 2048, 1280): (8, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 128, 2048, 2048): (16, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 128, 2048, 512): (4, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 128, 4096, 1024): (8, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 128, 4096, 1280): (4, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 128, 4096, 2048): (16, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 128, 4096, 512): (4, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 128, 512, 1024): (8, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 128, 512, 1280): (8, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 128, 512, 2048): (16, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 128, 512, 512): (4, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 16, 1024, 128): (8, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 16, 1024, 256): (16, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 16, 1024, 64): (4, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 16, 2048, 128): (8, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 16, 2048, 256): (16, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 16, 2048, 64): (4, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 16, 4096, 128): (8, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 16, 4096, 256): (16, 128), + ('bfloat16', 'bfloat16', 8, 2, 128, 16, 4096, 64): (4, 128), + ('bfloat16', 'bfloat16', 8, 2, 128, 16, 512, 128): (8, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 16, 512, 256): (16, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 16, 512, 64): (4, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 256, 1024, 1024): (4, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 256, 1024, 1280): (4, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 256, 1024, 2048): (8, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 256, 1024, 4096): (16, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 256, 2048, 1024): (4, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 256, 2048, 1280): (4, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 256, 2048, 2048): (8, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 256, 2048, 4096): (16, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 256, 4096, 1024): (4, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 256, 4096, 1280): (4, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 256, 4096, 2048): (8, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 256, 4096, 4096): (16, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 256, 512, 1024): (4, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 256, 512, 1280): (4, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 256, 512, 2048): (8, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 256, 512, 4096): (16, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 32, 1024, 128): (4, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 32, 1024, 256): (8, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 32, 1024, 512): (16, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 32, 2048, 128): (4, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 32, 2048, 256): (8, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 32, 2048, 512): (16, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 32, 4096, 128): (4, 128), + ('bfloat16', 'bfloat16', 8, 2, 128, 32, 4096, 256): (8, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 32, 4096, 512): (16, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 32, 512, 128): (4, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 32, 512, 256): (8, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 32, 512, 512): (16, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 64, 1024, 1024): (16, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 64, 1024, 256): (4, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 64, 1024, 512): (8, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 64, 2048, 1024): (16, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 64, 2048, 256): (4, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 64, 2048, 512): (8, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 64, 4096, 1024): (16, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 64, 4096, 256): (4, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 64, 4096, 512): (8, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 64, 512, 1024): (16, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 64, 512, 256): (4, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 64, 512, 512): (8, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 128, 1024, 1024): (8, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 128, 1024, 1280): (4, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 128, 1024, 2048): (16, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 128, 1024, 512): (4, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 128, 2048, 1024): (8, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 128, 2048, 1280): (8, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 128, 2048, 2048): (16, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 128, 2048, 512): (4, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 128, 4096, 1024): (8, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 128, 4096, 1280): (8, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 128, 4096, 2048): (16, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 128, 4096, 512): (4, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 128, 512, 1024): (8, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 128, 512, 1280): (8, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 128, 512, 2048): (16, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 128, 512, 512): (4, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 16, 1024, 128): (8, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 16, 1024, 256): (16, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 16, 1024, 64): (4, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 16, 2048, 128): (8, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 16, 2048, 256): (16, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 16, 2048, 64): (4, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 16, 4096, 128): (8, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 16, 4096, 256): (16, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 16, 4096, 64): (4, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 16, 512, 128): (8, 128), + ('bfloat16', 'bfloat16', 8, 4, 128, 16, 512, 256): (16, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 16, 512, 64): (4, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 256, 1024, 1024): (4, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 256, 1024, 1280): (4, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 256, 1024, 2048): (8, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 256, 1024, 4096): (16, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 256, 2048, 1024): (4, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 256, 2048, 1280): (4, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 256, 2048, 2048): (8, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 256, 2048, 4096): (16, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 256, 4096, 1024): (4, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 256, 4096, 1280): (4, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 256, 4096, 2048): (8, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 256, 4096, 4096): (16, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 256, 512, 1024): (4, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 256, 512, 1280): (4, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 256, 512, 2048): (8, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 256, 512, 4096): (16, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 32, 1024, 128): (4, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 32, 1024, 256): (8, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 32, 1024, 512): (16, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 32, 2048, 128): (4, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 32, 2048, 256): (8, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 32, 2048, 512): (16, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 32, 4096, 128): (4, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 32, 4096, 256): (8, 128), + ('bfloat16', 'bfloat16', 8, 4, 128, 32, 4096, 512): (16, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 32, 512, 128): (4, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 32, 512, 256): (8, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 32, 512, 512): (16, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 64, 1024, 1024): (16, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 64, 1024, 256): (4, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 64, 1024, 512): (8, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 64, 2048, 1024): (16, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 64, 2048, 256): (4, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 64, 2048, 512): (8, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 64, 4096, 1024): (16, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 64, 4096, 256): (4, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 64, 4096, 512): (8, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 64, 512, 1024): (16, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 64, 512, 256): (4, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 64, 512, 512): (8, 32), + # go/keep-sorted end + }, + 'TPU v5': { + # go/keep-sorted start + ('bfloat16', 'bfloat16', 12, 2, 128, 128, 1024, 1024): (8, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 128, 1024, 1280): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 128, 1024, 2048): (16, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 128, 1024, 512): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 128, 2048, 1024): (8, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 128, 2048, 1280): (8, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 128, 2048, 2048): (16, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 128, 2048, 512): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 128, 4096, 1024): (8, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 128, 4096, 1280): (8, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 128, 4096, 2048): (16, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 128, 4096, 512): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 128, 512, 1024): (8, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 128, 512, 1280): (8, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 128, 512, 2048): (16, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 128, 512, 512): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 16, 1024, 128): (8, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 16, 1024, 256): (16, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 16, 1024, 64): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 16, 2048, 128): (8, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 16, 2048, 256): (16, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 16, 2048, 64): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 16, 4096, 128): (8, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 16, 4096, 256): (16, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 16, 4096, 64): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 16, 512, 128): (8, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 16, 512, 256): (16, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 16, 512, 64): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 256, 1024, 1024): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 256, 1024, 1280): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 256, 1024, 2048): (8, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 256, 1024, 4096): (16, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 256, 2048, 1024): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 256, 2048, 1280): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 256, 2048, 2048): (8, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 256, 2048, 4096): (16, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 256, 4096, 1024): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 256, 4096, 1280): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 256, 4096, 2048): (8, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 256, 4096, 4096): (16, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 256, 512, 1024): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 256, 512, 1280): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 256, 512, 2048): (8, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 256, 512, 4096): (16, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 32, 1024, 128): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 32, 1024, 256): (8, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 32, 1024, 512): (16, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 32, 2048, 128): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 32, 2048, 256): (8, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 32, 2048, 512): (16, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 32, 4096, 128): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 32, 4096, 256): (8, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 32, 4096, 512): (16, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 32, 512, 128): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 32, 512, 256): (8, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 32, 512, 512): (16, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 64, 1024, 1024): (16, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 64, 1024, 256): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 64, 1024, 512): (8, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 64, 2048, 1024): (16, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 64, 2048, 256): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 64, 2048, 512): (8, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 64, 4096, 1024): (16, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 64, 4096, 256): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 64, 4096, 512): (8, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 64, 512, 1024): (16, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 64, 512, 256): (4, 32), + ('bfloat16', 'bfloat16', 12, 2, 128, 64, 512, 512): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 128, 1024, 1024): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 128, 1024, 1280): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 128, 1024, 2048): (16, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 128, 1024, 512): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 128, 2048, 1024): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 128, 2048, 1280): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 128, 2048, 2048): (16, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 128, 2048, 512): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 128, 4096, 1024): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 128, 4096, 1280): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 128, 4096, 2048): (16, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 128, 4096, 512): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 128, 512, 1024): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 128, 512, 1280): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 128, 512, 2048): (16, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 128, 512, 512): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 16, 1024, 128): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 16, 1024, 256): (16, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 16, 1024, 64): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 16, 2048, 128): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 16, 2048, 256): (16, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 16, 2048, 64): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 16, 4096, 128): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 16, 4096, 256): (16, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 16, 4096, 64): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 16, 512, 128): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 16, 512, 256): (16, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 16, 512, 64): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 256, 1024, 1024): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 256, 1024, 1280): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 256, 1024, 2048): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 256, 1024, 4096): (16, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 256, 2048, 1024): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 256, 2048, 1280): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 256, 2048, 2048): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 256, 2048, 4096): (16, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 256, 4096, 1024): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 256, 4096, 1280): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 256, 4096, 2048): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 256, 4096, 4096): (16, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 256, 512, 1024): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 256, 512, 1280): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 256, 512, 2048): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 256, 512, 4096): (16, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 32, 1024, 128): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 32, 1024, 256): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 32, 1024, 512): (16, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 32, 2048, 128): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 32, 2048, 256): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 32, 2048, 512): (16, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 32, 4096, 128): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 32, 4096, 256): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 32, 4096, 512): (16, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 32, 512, 128): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 32, 512, 256): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 32, 512, 512): (16, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 64, 1024, 1024): (16, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 64, 1024, 256): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 64, 1024, 512): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 64, 2048, 1024): (16, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 64, 2048, 256): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 64, 2048, 512): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 64, 4096, 1024): (16, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 64, 4096, 256): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 64, 4096, 512): (8, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 64, 512, 1024): (16, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 64, 512, 256): (4, 32), + ('bfloat16', 'bfloat16', 28, 4, 128, 64, 512, 512): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 1024): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 1280): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 512): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 1024): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 1280): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 512): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 1024): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 1280): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 512): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 1024): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 1280): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 512): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 128): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 256): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 64): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 128): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 256): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 64): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 128): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 256): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 64): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 128): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 256): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 64): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 1024, 1024): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 1024, 1280): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 2048, 1024): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 2048, 1280): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 4096, 1024): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 4096, 1280): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 512, 1024): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 512, 1280): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 128): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 256): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 512): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 128): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 256): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 512): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 128): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 256): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 512): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 128): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 256): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 512): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 1024): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 256): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 512): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 1024): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 256): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 512): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 1024): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 256): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 512): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 1024): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 256): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 512): (8, 32), + ('bfloat16', 'bfloat16', 4, 1, 128, 128, 1024, 1024): (8, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 128, 1024, 1280): (8, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 128, 1024, 2048): (16, 32), + ('bfloat16', 'bfloat16', 4, 1, 128, 128, 1024, 512): (4, 32), + ('bfloat16', 'bfloat16', 4, 1, 128, 128, 2048, 1024): (8, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 128, 2048, 1280): (4, 32), + ('bfloat16', 'bfloat16', 4, 1, 128, 128, 2048, 2048): (16, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 128, 2048, 512): (4, 32), + ('bfloat16', 'bfloat16', 4, 1, 128, 128, 4096, 1024): (8, 32), + ('bfloat16', 'bfloat16', 4, 1, 128, 128, 4096, 1280): (4, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 128, 4096, 2048): (16, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 128, 4096, 512): (4, 32), + ('bfloat16', 'bfloat16', 4, 1, 128, 128, 512, 1024): (8, 32), + ('bfloat16', 'bfloat16', 4, 1, 128, 128, 512, 1280): (4, 32), + ('bfloat16', 'bfloat16', 4, 1, 128, 128, 512, 2048): (16, 32), + ('bfloat16', 'bfloat16', 4, 1, 128, 128, 512, 512): (4, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 16, 1024, 128): (8, 32), + ('bfloat16', 'bfloat16', 4, 1, 128, 16, 1024, 256): (16, 128), + ('bfloat16', 'bfloat16', 4, 1, 128, 16, 1024, 64): (4, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 16, 2048, 128): (8, 32), + ('bfloat16', 'bfloat16', 4, 1, 128, 16, 2048, 256): (16, 128), + ('bfloat16', 'bfloat16', 4, 1, 128, 16, 2048, 64): (4, 32), + ('bfloat16', 'bfloat16', 4, 1, 128, 16, 4096, 128): (4, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 16, 4096, 256): (16, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 16, 4096, 64): (4, 32), + ('bfloat16', 'bfloat16', 4, 1, 128, 16, 512, 128): (8, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 16, 512, 256): (16, 128), + ('bfloat16', 'bfloat16', 4, 1, 128, 16, 512, 64): (4, 32), + ('bfloat16', 'bfloat16', 4, 1, 128, 256, 1024, 1024): (4, 128), + ('bfloat16', 'bfloat16', 4, 1, 128, 256, 1024, 1280): (4, 32), + ('bfloat16', 'bfloat16', 4, 1, 128, 256, 1024, 2048): (8, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 256, 1024, 4096): (16, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 256, 2048, 1024): (4, 32), + ('bfloat16', 'bfloat16', 4, 1, 128, 256, 2048, 1280): (4, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 256, 2048, 2048): (4, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 256, 2048, 4096): (16, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 256, 4096, 1024): (4, 32), + ('bfloat16', 'bfloat16', 4, 1, 128, 256, 4096, 1280): (4, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 256, 4096, 2048): (8, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 256, 4096, 4096): (16, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 256, 512, 1024): (4, 32), + ('bfloat16', 'bfloat16', 4, 1, 128, 256, 512, 1280): (4, 32), + ('bfloat16', 'bfloat16', 4, 1, 128, 256, 512, 2048): (8, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 256, 512, 4096): (16, 32), + ('bfloat16', 'bfloat16', 4, 1, 128, 32, 1024, 128): (4, 32), + ('bfloat16', 'bfloat16', 4, 1, 128, 32, 1024, 256): (8, 32), + ('bfloat16', 'bfloat16', 4, 1, 128, 32, 1024, 512): (16, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 32, 2048, 128): (4, 32), + ('bfloat16', 'bfloat16', 4, 1, 128, 32, 2048, 256): (8, 32), + ('bfloat16', 'bfloat16', 4, 1, 128, 32, 2048, 512): (16, 32), + ('bfloat16', 'bfloat16', 4, 1, 128, 32, 4096, 128): (4, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 32, 4096, 256): (8, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 32, 4096, 512): (8, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 32, 512, 128): (4, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 32, 512, 256): (8, 32), + ('bfloat16', 'bfloat16', 4, 1, 128, 32, 512, 512): (16, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 64, 1024, 1024): (8, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 64, 1024, 256): (4, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 64, 1024, 512): (8, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 64, 2048, 1024): (8, 32), + ('bfloat16', 'bfloat16', 4, 1, 128, 64, 2048, 256): (4, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 64, 2048, 512): (8, 32), + ('bfloat16', 'bfloat16', 4, 1, 128, 64, 4096, 1024): (16, 32), + ('bfloat16', 'bfloat16', 4, 1, 128, 64, 4096, 256): (4, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 64, 4096, 512): (4, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 64, 512, 1024): (16, 32), + ('bfloat16', 'bfloat16', 4, 1, 128, 64, 512, 256): (4, 64), + ('bfloat16', 'bfloat16', 4, 1, 128, 64, 512, 512): (8, 32), + ('bfloat16', 'bfloat16', 4, 2, 128, 128, 1024, 1024): (8, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 128, 1024, 1280): (8, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 128, 1024, 2048): (16, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 128, 1024, 512): (4, 32), + ('bfloat16', 'bfloat16', 4, 2, 128, 128, 2048, 1024): (8, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 128, 2048, 1280): (4, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 128, 2048, 2048): (16, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 128, 2048, 512): (4, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 128, 4096, 1024): (8, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 128, 4096, 1280): (4, 32), + ('bfloat16', 'bfloat16', 4, 2, 128, 128, 4096, 2048): (16, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 128, 4096, 512): (4, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 128, 512, 1024): (8, 32), + ('bfloat16', 'bfloat16', 4, 2, 128, 128, 512, 1280): (4, 32), + ('bfloat16', 'bfloat16', 4, 2, 128, 128, 512, 2048): (16, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 128, 512, 512): (4, 32), + ('bfloat16', 'bfloat16', 4, 2, 128, 16, 1024, 128): (4, 32), + ('bfloat16', 'bfloat16', 4, 2, 128, 16, 1024, 256): (8, 32), + ('bfloat16', 'bfloat16', 4, 2, 128, 16, 1024, 64): (4, 32), + ('bfloat16', 'bfloat16', 4, 2, 128, 16, 2048, 128): (8, 32), + ('bfloat16', 'bfloat16', 4, 2, 128, 16, 2048, 256): (16, 32), + ('bfloat16', 'bfloat16', 4, 2, 128, 16, 2048, 64): (4, 32), + ('bfloat16', 'bfloat16', 4, 2, 128, 16, 4096, 128): (8, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 16, 4096, 256): (16, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 16, 4096, 64): (4, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 16, 512, 128): (8, 32), + ('bfloat16', 'bfloat16', 4, 2, 128, 16, 512, 256): (16, 32), + ('bfloat16', 'bfloat16', 4, 2, 128, 16, 512, 64): (4, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 256, 1024, 1024): (4, 32), + ('bfloat16', 'bfloat16', 4, 2, 128, 256, 1024, 1280): (4, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 256, 1024, 2048): (8, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 256, 1024, 4096): (16, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 256, 2048, 1024): (4, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 256, 2048, 1280): (4, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 256, 2048, 2048): (8, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 256, 2048, 4096): (8, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 256, 4096, 1024): (4, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 256, 4096, 1280): (4, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 256, 4096, 2048): (8, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 256, 4096, 4096): (16, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 256, 512, 1024): (4, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 256, 512, 1280): (4, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 256, 512, 2048): (8, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 256, 512, 4096): (16, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 32, 1024, 128): (4, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 32, 1024, 256): (8, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 32, 1024, 512): (16, 32), + ('bfloat16', 'bfloat16', 4, 2, 128, 32, 2048, 128): (4, 32), + ('bfloat16', 'bfloat16', 4, 2, 128, 32, 2048, 256): (8, 128), + ('bfloat16', 'bfloat16', 4, 2, 128, 32, 2048, 512): (16, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 32, 4096, 128): (4, 32), + ('bfloat16', 'bfloat16', 4, 2, 128, 32, 4096, 256): (8, 32), + ('bfloat16', 'bfloat16', 4, 2, 128, 32, 4096, 512): (16, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 32, 512, 128): (4, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 32, 512, 256): (8, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 32, 512, 512): (16, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 64, 1024, 1024): (16, 32), + ('bfloat16', 'bfloat16', 4, 2, 128, 64, 1024, 256): (4, 32), + ('bfloat16', 'bfloat16', 4, 2, 128, 64, 1024, 512): (8, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 64, 2048, 1024): (16, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 64, 2048, 256): (4, 32), + ('bfloat16', 'bfloat16', 4, 2, 128, 64, 2048, 512): (8, 32), + ('bfloat16', 'bfloat16', 4, 2, 128, 64, 4096, 1024): (16, 32), + ('bfloat16', 'bfloat16', 4, 2, 128, 64, 4096, 256): (4, 32), + ('bfloat16', 'bfloat16', 4, 2, 128, 64, 4096, 512): (8, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 64, 512, 1024): (16, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 64, 512, 256): (4, 64), + ('bfloat16', 'bfloat16', 4, 2, 128, 64, 512, 512): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 128, 1024, 1024): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 128, 1024, 1280): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 128, 1024, 2048): (16, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 128, 1024, 512): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 128, 2048, 1024): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 128, 2048, 1280): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 128, 2048, 2048): (16, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 128, 2048, 512): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 128, 4096, 1024): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 128, 4096, 1280): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 128, 4096, 2048): (16, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 128, 4096, 512): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 128, 512, 1024): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 128, 512, 1280): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 128, 512, 2048): (16, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 128, 512, 512): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 16, 1024, 128): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 16, 1024, 256): (16, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 16, 1024, 64): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 16, 2048, 128): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 16, 2048, 256): (16, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 16, 2048, 64): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 16, 4096, 128): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 16, 4096, 256): (16, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 16, 4096, 64): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 16, 512, 128): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 16, 512, 256): (16, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 16, 512, 64): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 256, 1024, 1024): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 256, 1024, 1280): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 256, 1024, 2048): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 256, 1024, 4096): (16, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 256, 2048, 1024): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 256, 2048, 1280): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 256, 2048, 2048): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 256, 2048, 4096): (16, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 256, 4096, 1024): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 256, 4096, 1280): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 256, 4096, 2048): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 256, 4096, 4096): (16, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 256, 512, 1024): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 256, 512, 1280): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 256, 512, 2048): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 256, 512, 4096): (16, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 32, 1024, 128): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 32, 1024, 256): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 32, 1024, 512): (16, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 32, 2048, 128): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 32, 2048, 256): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 32, 2048, 512): (16, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 32, 4096, 128): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 32, 4096, 256): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 32, 4096, 512): (16, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 32, 512, 128): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 32, 512, 256): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 32, 512, 512): (16, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 64, 1024, 1024): (16, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 64, 1024, 256): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 64, 1024, 512): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 64, 2048, 1024): (16, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 64, 2048, 256): (4, 64), + ('bfloat16', 'bfloat16', 5, 1, 128, 64, 2048, 512): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 64, 4096, 1024): (16, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 64, 4096, 256): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 64, 4096, 512): (8, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 64, 512, 1024): (16, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 64, 512, 256): (4, 32), + ('bfloat16', 'bfloat16', 5, 1, 128, 64, 512, 512): (8, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 128, 1024, 1024): (8, 64), + ('bfloat16', 'bfloat16', 6, 1, 128, 128, 1024, 1280): (8, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 128, 1024, 2048): (16, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 128, 1024, 512): (4, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 128, 2048, 1024): (8, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 128, 2048, 1280): (8, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 128, 2048, 2048): (16, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 128, 2048, 512): (4, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 128, 4096, 1024): (8, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 128, 4096, 1280): (4, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 128, 4096, 2048): (16, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 128, 4096, 512): (4, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 128, 512, 1024): (8, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 128, 512, 1280): (4, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 128, 512, 2048): (16, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 128, 512, 512): (4, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 16, 1024, 128): (8, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 16, 1024, 256): (16, 64), + ('bfloat16', 'bfloat16', 6, 1, 128, 16, 1024, 64): (4, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 16, 2048, 128): (8, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 16, 2048, 256): (16, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 16, 2048, 64): (4, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 16, 4096, 128): (8, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 16, 4096, 256): (16, 64), + ('bfloat16', 'bfloat16', 6, 1, 128, 16, 4096, 64): (4, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 16, 512, 128): (8, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 16, 512, 256): (16, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 16, 512, 64): (4, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 256, 1024, 1024): (4, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 256, 1024, 1280): (4, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 256, 1024, 2048): (8, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 256, 1024, 4096): (16, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 256, 2048, 1024): (4, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 256, 2048, 1280): (4, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 256, 2048, 2048): (8, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 256, 2048, 4096): (16, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 256, 4096, 1024): (4, 64), + ('bfloat16', 'bfloat16', 6, 1, 128, 256, 4096, 1280): (4, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 256, 4096, 2048): (8, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 256, 4096, 4096): (16, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 256, 512, 1024): (4, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 256, 512, 1280): (4, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 256, 512, 2048): (8, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 256, 512, 4096): (16, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 32, 1024, 128): (4, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 32, 1024, 256): (8, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 32, 1024, 512): (16, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 32, 2048, 128): (4, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 32, 2048, 256): (8, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 32, 2048, 512): (16, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 32, 4096, 128): (4, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 32, 4096, 256): (8, 64), + ('bfloat16', 'bfloat16', 6, 1, 128, 32, 4096, 512): (16, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 32, 512, 128): (4, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 32, 512, 256): (8, 64), + ('bfloat16', 'bfloat16', 6, 1, 128, 32, 512, 512): (16, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 64, 1024, 1024): (16, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 64, 1024, 256): (4, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 64, 1024, 512): (8, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 64, 2048, 1024): (8, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 64, 2048, 256): (4, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 64, 2048, 512): (8, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 64, 4096, 1024): (16, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 64, 4096, 256): (4, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 64, 4096, 512): (8, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 64, 512, 1024): (16, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 64, 512, 256): (4, 32), + ('bfloat16', 'bfloat16', 6, 1, 128, 64, 512, 512): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 1024): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 1280): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 2048): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 512): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 1024): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 1280): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 2048): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 512): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 1024): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 1280): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 2048): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 512): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 1024): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 1280): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 2048): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 512): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 128): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 256): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 64): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 128): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 256): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 64): (4, 128), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 128): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 256): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 64): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 128): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 256): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 64): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 1024): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 1280): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 2048): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 4096): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 1024): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 1280): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 2048): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 4096): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 1024): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 1280): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 2048): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 4096): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 1024): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 1280): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 2048): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 4096): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 128): (4, 128), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 256): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 512): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 128): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 256): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 512): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 128): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 256): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 512): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 128): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 256): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 512): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 1024): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 256): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 512): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 1024): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 256): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 512): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 1024): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 256): (4, 128), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 512): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 1024): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 256): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 512): (8, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 128, 1024, 1024): (8, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 128, 1024, 1280): (4, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 128, 1024, 2048): (16, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 128, 1024, 512): (4, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 128, 2048, 1024): (8, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 128, 2048, 1280): (8, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 128, 2048, 2048): (16, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 128, 2048, 512): (4, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 128, 4096, 1024): (8, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 128, 4096, 1280): (4, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 128, 4096, 2048): (16, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 128, 4096, 512): (4, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 128, 512, 1024): (8, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 128, 512, 1280): (4, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 128, 512, 2048): (8, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 128, 512, 512): (4, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 16, 1024, 128): (8, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 16, 1024, 256): (16, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 16, 1024, 64): (4, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 16, 2048, 128): (8, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 16, 2048, 256): (16, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 16, 2048, 64): (4, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 16, 4096, 128): (8, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 16, 4096, 256): (16, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 16, 4096, 64): (4, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 16, 512, 128): (8, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 16, 512, 256): (16, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 16, 512, 64): (4, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 256, 1024, 1024): (4, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 256, 1024, 1280): (4, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 256, 1024, 2048): (4, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 256, 1024, 4096): (8, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 256, 2048, 1024): (4, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 256, 2048, 1280): (4, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 256, 2048, 2048): (4, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 256, 2048, 4096): (16, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 256, 4096, 1024): (4, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 256, 4096, 1280): (4, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 256, 4096, 2048): (8, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 256, 4096, 4096): (16, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 256, 512, 1024): (4, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 256, 512, 1280): (4, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 256, 512, 2048): (8, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 256, 512, 4096): (16, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 32, 1024, 128): (4, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 32, 1024, 256): (8, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 32, 1024, 512): (16, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 32, 2048, 128): (4, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 32, 2048, 256): (8, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 32, 2048, 512): (16, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 32, 4096, 128): (4, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 32, 4096, 256): (8, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 32, 4096, 512): (16, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 32, 512, 128): (4, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 32, 512, 256): (8, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 32, 512, 512): (16, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 64, 1024, 1024): (16, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 64, 1024, 256): (4, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 64, 1024, 512): (8, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 64, 2048, 1024): (16, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 64, 2048, 256): (4, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 64, 2048, 512): (8, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 64, 4096, 1024): (8, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 64, 4096, 256): (4, 64), + ('bfloat16', 'bfloat16', 8, 2, 128, 64, 4096, 512): (8, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 64, 512, 1024): (8, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 64, 512, 256): (4, 32), + ('bfloat16', 'bfloat16', 8, 2, 128, 64, 512, 512): (8, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 128, 1024, 1024): (8, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 128, 1024, 1280): (8, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 128, 1024, 2048): (16, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 128, 1024, 512): (4, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 128, 2048, 1024): (8, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 128, 2048, 1280): (8, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 128, 2048, 2048): (16, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 128, 2048, 512): (4, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 128, 4096, 1024): (8, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 128, 4096, 1280): (8, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 128, 4096, 2048): (16, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 128, 4096, 512): (4, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 128, 512, 1024): (8, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 128, 512, 1280): (4, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 128, 512, 2048): (16, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 128, 512, 512): (4, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 16, 1024, 128): (4, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 16, 1024, 256): (16, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 16, 1024, 64): (4, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 16, 2048, 128): (8, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 16, 2048, 256): (16, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 16, 2048, 64): (4, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 16, 4096, 128): (4, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 16, 4096, 256): (16, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 16, 4096, 64): (4, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 16, 512, 128): (8, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 16, 512, 256): (16, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 16, 512, 64): (4, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 256, 1024, 1024): (4, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 256, 1024, 1280): (4, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 256, 1024, 2048): (8, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 256, 1024, 4096): (8, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 256, 2048, 1024): (4, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 256, 2048, 1280): (4, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 256, 2048, 2048): (8, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 256, 2048, 4096): (16, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 256, 4096, 1024): (4, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 256, 4096, 1280): (4, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 256, 4096, 2048): (8, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 256, 4096, 4096): (8, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 256, 512, 1024): (4, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 256, 512, 1280): (4, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 256, 512, 2048): (8, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 256, 512, 4096): (16, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 32, 1024, 128): (4, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 32, 1024, 256): (8, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 32, 1024, 512): (16, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 32, 2048, 128): (4, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 32, 2048, 256): (8, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 32, 2048, 512): (16, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 32, 4096, 128): (4, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 32, 4096, 256): (8, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 32, 4096, 512): (16, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 32, 512, 128): (4, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 32, 512, 256): (8, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 32, 512, 512): (16, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 64, 1024, 1024): (16, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 64, 1024, 256): (4, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 64, 1024, 512): (8, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 64, 2048, 1024): (16, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 64, 2048, 256): (4, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 64, 2048, 512): (8, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 64, 4096, 1024): (16, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 64, 4096, 256): (4, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 64, 4096, 512): (8, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 64, 512, 1024): (16, 64), + ('bfloat16', 'bfloat16', 8, 4, 128, 64, 512, 256): (4, 32), + ('bfloat16', 'bfloat16', 8, 4, 128, 64, 512, 512): (8, 32), + # go/keep-sorted end + }, +} + + +def next_power_of_2(x: int): + """Finds the smallest power of 2 >= x using bit manipulation. + + Args: + x: The input number (should be an integer). + + Returns: + The smallest integer power of 2 that is >= x. + """ + assert x > 0 + if x == 1: + return 1 + return 1 << (x - 1).bit_length() + + +def simplify_key(key): + """Simplify the key to reduce the number of combinations.""" + ( + q_dtype, + kv_dtype, + num_q_heads_per_blk, + num_kv_heads_per_blk, + head_dim, + page_size, + max_num_batched_tokens, + pages_per_seq, + ) = key + return ( + jnp.dtype(q_dtype).name, + jnp.dtype(kv_dtype).name, + next_power_of_2(num_q_heads_per_blk), + next_power_of_2(num_kv_heads_per_blk), + (head_dim + 127) // 128 * 128, + next_power_of_2(page_size), + next_power_of_2(max_num_batched_tokens), + next_power_of_2(page_size * pages_per_seq), + ) + + +def get_tpu_version() -> int: + """Returns the numeric version of the TPU, or -1 if not on TPU.""" + kind = jax.devices()[0].device_kind + if 'TPU' not in kind: + return -1 + if kind.endswith(' lite'): + kind = kind[: -len(' lite')] + assert kind[:-1] == 'TPU v', kind + return int(kind[-1]) + + +def get_device_name(num_devices: int | None = None): + name = ' '.join(jax.devices()[0].device_kind.split()[:2]) + if num_devices is not None: + name += f'-{num_devices}' + return name + + +def get_tuned_block_sizes( + q_dtype, + kv_dtype, + num_q_heads_per_blk, + num_kv_heads_per_blk, + head_dim, + page_size, + max_num_batched_tokens, + pages_per_seq, +) -> tuple[int, int]: + """Look up for the best (num_kv_pages_per_blk, num_queries_per_blk) from auto-tuned table.""" + tpu_version = get_tpu_version() + if tpu_version < 4: + raise NotImplementedError('TPU version must be 4 or higher.') + key = ( + q_dtype, + kv_dtype, + num_q_heads_per_blk, + num_kv_heads_per_blk, + head_dim, + page_size, + max_num_batched_tokens, + pages_per_seq, + ) + key = simplify_key(key) + device_name = get_device_name() + + # Default block sizes. + bkv, bq = (128, 32) + if tpu_version == 4: + # This default block size is not tuned, only make sure there's no + # OOM in vmem + bkv, bq = (32, 32) + elif device_name in TUNED_BLOCK_SIZES: + if key in TUNED_BLOCK_SIZES[device_name]: + bkv, bq = TUNED_BLOCK_SIZES[device_name][key] + return (min(pages_per_seq, bkv), min(max_num_batched_tokens, bq)) + + +def get_min_page_size(max_model_len, min_page_size=16): + """Recommended min page size for high-performance kernel.""" + return max(next_power_of_2(max_model_len) // MAX_PAGES_PER_SEQ, min_page_size) diff --git a/jax/experimental/pallas/ops/tpu/random/__init__.py b/jax/experimental/pallas/ops/tpu/random/__init__.py new file mode 100644 index 000000000000..3da0dd1fa3ca --- /dev/null +++ b/jax/experimental/pallas/ops/tpu/random/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== diff --git a/jax/experimental/pallas/ops/tpu/random/philox.py b/jax/experimental/pallas/ops/tpu/random/philox.py index 28e627cfb298..cb108c319507 100644 --- a/jax/experimental/pallas/ops/tpu/random/philox.py +++ b/jax/experimental/pallas/ops/tpu/random/philox.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Implementation of the Philox PRNG as a Pallas kernel.""" -from typing import Sequence +from collections.abc import Sequence import jax from jax import typing from jax._src import prng @@ -46,7 +46,9 @@ def mul32_hi_lo(x: jax.Array, y: jax.Array) -> tuple[jax.Array, jax.Array]: cross_xy = xhi * ylo cross_yx = xlo * yhi carry = (cross_xy & 0xffff) + (cross_yx & 0xffff) + (xy_lo >> 16) - return xy_hi + (cross_xy >> 16) + (cross_yx >> 16) + (carry >> 16), xy_lo + result_hi = xy_hi + (cross_xy >> 16) + (cross_yx >> 16) + (carry >> 16) + result_lo = (carry << 16) + (xy_lo & 0xffff) + return result_hi, result_lo def philox_4x32(hi0, lo0, hi1, lo1, k_hi, k_lo, rounds = 10): @@ -115,7 +117,7 @@ def kernel(offset_ref, key_ref, out_ref): offset = prng_utils.compute_scalar_offset( counts_idx, unpadded_shape, block_shape) counts_lo = prng_utils.blocked_iota(block_size, unpadded_shape) - counts_lo = counts_lo + offset + offset_ref[0] + counts_lo = counts_lo + offset.astype(jnp.uint32) + offset_ref[0] counts_lo = counts_lo.astype(jnp.uint32) # TODO(justinfu): Support hi bits on count. _zeros = jnp.zeros_like(counts_lo) @@ -140,8 +142,8 @@ def kernel(offset_ref, key_ref, out_ref): return pl.pallas_call( kernel, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + pl.BlockSpec(memory_space=pltpu.SMEM), + pl.BlockSpec(memory_space=pltpu.SMEM), ], out_specs=out_spec, grid=grid_dims, diff --git a/jax/experimental/pallas/ops/tpu/random/prng_utils.py b/jax/experimental/pallas/ops/tpu/random/prng_utils.py index e5a3ac155eea..3014c7748f22 100644 --- a/jax/experimental/pallas/ops/tpu/random/prng_utils.py +++ b/jax/experimental/pallas/ops/tpu/random/prng_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Helper functions for PRNG kernels.""" -from typing import Sequence +from collections.abc import Sequence from jax import lax import jax.numpy as jnp diff --git a/jax/experimental/pallas/ops/tpu/random/threefry.py b/jax/experimental/pallas/ops/tpu/random/threefry.py index 5c460d491f48..06a82f4abac8 100644 --- a/jax/experimental/pallas/ops/tpu/random/threefry.py +++ b/jax/experimental/pallas/ops/tpu/random/threefry.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Implementation of the Threefry PRNG as a Pallas kernel.""" -from typing import Sequence +from collections.abc import Sequence import jax from jax._src import prng from jax.experimental import pallas as pl @@ -63,7 +63,7 @@ def kernel(key_ref, out_ref): offset = prng_utils.compute_scalar_offset( counts_idx, unpadded_shape, block_shape) counts_lo = prng_utils.blocked_iota(block_size, unpadded_shape) - counts_lo = counts_lo + offset + counts_lo = counts_lo + offset.astype(jnp.uint32) counts_lo = counts_lo.astype(jnp.uint32) # TODO(justinfu): Support hi bits on count. counts_hi = jnp.zeros_like(counts_lo) @@ -79,7 +79,7 @@ def kernel(key_ref, out_ref): block_shape = (1,) * (len(shape)-2) + block_size result = pl.pallas_call( kernel, - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM)], + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], out_specs=pl.BlockSpec(block_shape, lambda *idxs: idxs), grid=grid_dims, out_shape=out, diff --git a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py index 4b6e4a41c43b..c30214f36f29 100644 --- a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py +++ b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py @@ -16,10 +16,11 @@ from __future__ import annotations -from collections.abc import Callable, Mapping +from collections.abc import Callable import dataclasses import enum import functools +import json from typing import Any, Literal, NamedTuple, Optional, Union, overload import jax @@ -80,6 +81,7 @@ class SegmentIds(NamedTuple): jax.Array, # k jax.Array, # v Optional[SegmentIds], # segment_ids + Optional[jax.Array], # sinks jax.Array, # out jax.Array, # logsumexp Optional[mask_info_lib.MaskInfo], # dq_mask_info @@ -90,7 +92,6 @@ class SegmentIds(NamedTuple): def get_kernel_name( - block_metadata: Mapping[str, Any], is_mqa: bool, save_residuals: bool, is_segmented: bool, @@ -100,16 +101,10 @@ def get_kernel_name( assert phase == "dq" or phase == "dkv" or phase == "fwd" # Saving residuals is supported only for the fwd phase. assert not save_residuals or phase == "fwd" - residuals = "" - if save_residuals: - residuals = "_residuals" - elif phase == "fwd": - residuals = "_no_residuals" + residuals = "_residuals" if save_residuals else "_no_residuals" attention_type = "mqa" if is_mqa else "mha" segments = "_segmented" if is_segmented else "" - return f"splash_{attention_type}_{phase}{segments}{residuals}_" + "_".join( - f"{k}={v}" for k, v in sorted(block_metadata.items()) - ) + return f"splash_{attention_type}_{phase}{segments}{residuals}" # Reference attention implementations @@ -122,6 +117,7 @@ def _attention_reference( k: jax.Array, v: jax.Array, segment_ids: SegmentIds | None, + sinks: jax.Array | None, save_residuals: Literal[False], mask_value: float, custom_type: str, @@ -137,6 +133,7 @@ def _attention_reference( k: jax.Array, v: jax.Array, segment_ids: SegmentIds | None, + sinks: jax.Array | None, save_residuals: Literal[True], mask_value: float, custom_type: str, @@ -151,6 +148,7 @@ def _attention_reference( k: jax.Array, # [kv_seq_len, head_dim] v: jax.Array, # [kv_seq_len, head_dim] segment_ids: SegmentIds | None, + sinks: jax.Array | None, mask_value: float, save_residuals: bool, custom_type: str, @@ -162,6 +160,7 @@ def _attention_reference( k, v, segment_ids, + sinks, mask_value, save_residuals, custom_type, @@ -175,6 +174,7 @@ def _attention_reference_default( k: jax.Array, # [kv_seq_len, head_dim] v: jax.Array, # [kv_seq_len, head_dim] segment_ids: SegmentIds | None, + sinks: jax.Array | None, # [] one scalar per qhead mask_value: float, save_residuals: bool, custom_type: str, @@ -194,8 +194,10 @@ def _attention_reference_default( logits = jnp.where(mask, logits, mask_value) m = logits.max(axis=-1) + sinks = None if sinks is None else sinks.astype(logits.dtype) + m = m if sinks is None else jnp.maximum(m, sinks) s = jnp.exp(logits - m[..., None]) - l = s.sum(axis=-1) + l = s.sum(axis=-1) + (0 if sinks is None else jnp.exp(sinks - m)) s = s / l[..., None] o = jnp.einsum("st,td->sd", s, v.astype(jnp.float32)) @@ -212,6 +214,7 @@ def attention_reference( k: jax.Array, # [kv_seq_len, head_dim] v: jax.Array, # [kv_seq_len, head_dim] segment_ids: SegmentIds | None, + sinks: jax.Array | None = None, *, mask_value: float = DEFAULT_MASK_VALUE, save_residuals: bool = False, @@ -224,6 +227,7 @@ def attention_reference( k, v, segment_ids, + sinks, mask_value=mask_value, save_residuals=save_residuals, custom_type=custom_type, @@ -237,6 +241,7 @@ def _attention_reference_custom_fwd( k: jax.Array, # [kv_seq_len, head_dim] v: jax.Array, # [kv_seq_len, head_dim] segment_ids: SegmentIds | None, + sinks: jax.Array | None, mask_value: float, save_residuals: bool, custom_type: str, @@ -251,12 +256,13 @@ def _attention_reference_custom_fwd( k, v, segment_ids, + sinks, mask_value=mask_value, save_residuals=True, custom_type=custom_type, attn_logits_soft_cap=attn_logits_soft_cap, ) - return o, (mask, q, k, v, segment_ids, o, logsumexp) + return o, (mask, q, k, v, segment_ids, sinks, o, logsumexp) def _attention_reference_custom_bwd( @@ -266,9 +272,9 @@ def _attention_reference_custom_bwd( attn_logits_soft_cap: float | None, res, do: jax.Array, -) -> tuple[None, jax.Array, jax.Array, jax.Array, None]: +) -> tuple[None, jax.Array, jax.Array, jax.Array, None, jax.Array | None]: del save_residuals - mask, q, k, v, segment_ids, o, logsumexp = res + mask, q, k, v, segment_ids, sinks, o, logsumexp = res uncapped_logits = jnp.einsum( "qc,kc->qk", q, k, preferred_element_type=jnp.float32) @@ -306,11 +312,17 @@ def _attention_reference_custom_bwd( ds = g + g * d dk = jnp.einsum("sd,st->td", q.astype(jnp.float32), ds).astype(k.dtype) dq = jnp.einsum("st,td->sd", ds, k.astype(jnp.float32)).astype(q.dtype) - return None, dq, dk, dv, None + dsinks = None + if sinks is not None: # the gradient is ``sum(-exp(s) / exp(lse) * o * do)`` + sinks_exp = -jnp.exp(sinks[..., None, None].astype(jnp.float32) + - logsumexp[..., None].astype(jnp.float32)) + dsinks = jnp.sum(sinks_exp.astype(o.dtype) * do * o) + return None, dq, dk, dv, None, dsinks _attention_reference_custom = jax.custom_vjp( - _attention_reference, nondiff_argnums=(5, 6, 7, 8) + _attention_reference, nondiff_argnames=( + "mask_value", "save_residuals", "custom_type", "attn_logits_soft_cap") ) _attention_reference_custom.defvjp(_attention_reference_custom_fwd, _attention_reference_custom_bwd) @@ -322,6 +334,7 @@ def attention_reference_custom( k: jax.Array, # [kv_seq_len, head_dim] v: jax.Array, # [kv_seq_len, head_dim] segment_ids: SegmentIds | None, + sinks: jax.Array | None = None, *, mask_value: float = DEFAULT_MASK_VALUE, save_residuals: bool = False, @@ -334,6 +347,7 @@ def attention_reference_custom( k, v, segment_ids, + sinks, mask_value, save_residuals, custom_type=custom_type, @@ -361,6 +375,7 @@ def _wrapped( k: jax.Array, v: jax.Array, segment_ids: SegmentIds | None = None, + sinks: jax.Array | None = None, *, mask_value: float = DEFAULT_MASK_VALUE, save_residuals: bool = False, @@ -385,7 +400,7 @@ def _wrapped( ) if is_mqa: - func = jax.vmap(func, in_axes=(0, 0, None, None, None)) + func = jax.vmap(func, in_axes=(0, 0, None, None, None, 0)) is_grouped = False else: # In grouped attention (1 < num_kv_heads && num_kv_heads < num_q_heads). @@ -412,14 +427,16 @@ def _wrapped( q_heads_per_kv_head = q_heads // kv_heads q = q.reshape((kv_heads, q_heads_per_kv_head, q_seq_len, head_dim)) mask = mask.reshape((kv_heads, q_heads_per_kv_head, *mask.shape[1:])) + if sinks is not None: + sinks = sinks.reshape((kv_heads, q_heads_per_kv_head)) # Inner-most vmap: iterate over the q heads. - func = jax.vmap(func, in_axes=(0, 0, None, None, None)) + func = jax.vmap(func, in_axes=(0, 0, None, None, None, 0)) # Outer-most vmap: iterate over the kv heads. - func = jax.vmap(func, in_axes=(0, 0, 0, 0, None)) + func = jax.vmap(func, in_axes=(0, 0, 0, 0, None, 0)) - out = func(mask, q, k, v, segment_ids) + out = func(mask, q, k, v, segment_ids, sinks) if is_grouped: @@ -599,9 +616,9 @@ def _apply_mask_and_soft_cap( masks = [] if mask_ref is not None: if k_in_lanes: - mask = pl.load(mask_ref, (slice(None), k_slice)) + mask = mask_ref[:, k_slice] else: - mask = pl.load(mask_ref, (k_slice, slice(None))) + mask = mask_ref[k_slice, :] masks.append( jnp.bitwise_or(mask, jnp.broadcast_to(should_not_mask, mask.shape)) @@ -621,8 +638,8 @@ def _apply_mask_and_soft_cap( repeats, rem = divmod(k_slice.size, NUM_LANES) assert rem == 0 - q_sequence = pltpu.repeat( - q_sequence_ref[...], repeats, axis=1 + q_sequence = jnp.tile( + q_sequence_ref[...], (1, repeats) ) # [bq, k_slice.size] else: assert q_sequence_ref.shape == (NUM_SUBLANES, bq) @@ -630,7 +647,7 @@ def _apply_mask_and_soft_cap( k_sequence = k_offset + jax.lax.broadcasted_iota( jnp.int32, (k_slice.size, bq), 0 ) - q_sequence = pl.load(q_sequence_ref, (pl.ds(1), slice(None))) # [1, bq] + q_sequence = q_sequence_ref[:1, :] # [1, bq] q_sequence = jnp.broadcast_to(q_sequence, (k_slice.size, bq)) assert q_sequence.shape == k_sequence.shape @@ -644,20 +661,20 @@ def _apply_mask_and_soft_cap( if q_segment_ids_ref is not None: if k_in_lanes: - kv_ids = pl.load(kv_segment_ids_ref, (pl.ds(1), k_slice)) # [1, k_slice] + kv_ids = kv_segment_ids_ref[:1, k_slice] # [1, k_slice] repeats, rem = divmod(kv_ids.shape[1], NUM_LANES) if rem: raise NotImplementedError(f"block_kv must be a multiple of {NUM_LANES}") - q_ids = pltpu.repeat(q_segment_ids_ref[:], repeats, axis=1) # [bq, bkv] + q_ids = jnp.tile(q_segment_ids_ref[:], (1, repeats)) # [bq, bkv] else: assert bq == q_segment_ids_ref.shape[-1] repeats, rem = divmod(bq, NUM_LANES) if rem: raise NotImplementedError(f"block_q must be a multiple of {NUM_LANES}") - kv_ids = pltpu.repeat( - pl.load(kv_segment_ids_ref, (k_slice, slice(None))), repeats, axis=1 + kv_ids = jnp.tile( + kv_segment_ids_ref[k_slice, :], (1, repeats) ) # [k_slice, bq] - q_ids = pl.load(q_segment_ids_ref, (pl.ds(1), slice(None))) # [1, bq] + q_ids = q_segment_ids_ref[:1, :] # [1, bq] masks.append(q_ids == kv_ids) def cap_logits(logits): @@ -687,6 +704,7 @@ def flash_attention_kernel( v_ref, q_segment_ids_ref, kv_segment_ids_ref, + sinks_ref, mask_ref, q_sequence_ref, # Outputs @@ -710,20 +728,21 @@ def flash_attention_kernel( ): float32 = jnp.float32 HEAD_DIM_MINOR = QKVLayout.HEAD_DIM_MINOR - - head_dim_v_repeats, rem = divmod(head_dim_v, NUM_LANES) - if rem != 0: - raise NotImplementedError( - f"{head_dim_v=} should be a multiple of {NUM_LANES}" - ) + head_dim_v_repeats = pl.cdiv(head_dim_v, NUM_LANES) h, i, j = pl.program_id(0), pl.program_id(1), pl.program_id(2) @pl.when(j == 0) def init(): o_scratch_ref[...] = jnp.zeros_like(o_scratch_ref) - m_scratch_ref[...] = jnp.full_like(m_scratch_ref, mask_value) - l_scratch_ref[...] = jnp.zeros_like(l_scratch_ref) + if sinks_ref is not None: + sinks = sinks_ref[0, h].astype(m_scratch_ref.dtype) + # initialize `max = sinks`, so `exp(sinks - max = 0) = 1` + m_scratch_ref[...] = sinks * jnp.ones_like(m_scratch_ref) + l_scratch_ref[...] = jnp.ones_like(l_scratch_ref) + else: + m_scratch_ref[...] = jnp.full_like(m_scratch_ref, mask_value) + l_scratch_ref[...] = jnp.zeros_like(l_scratch_ref) global_kv_index, _, should_run, should_not_mask = _next_nonzero( h, @@ -743,9 +762,9 @@ def body(kv_compute_index, _): q = q_ref[...] if q_layout == HEAD_DIM_MINOR else q_ref[...].T qk_dims = NT_DIM_NUMBERS if k_layout == HEAD_DIM_MINOR else NN_DIM_NUMBERS if k_layout == HEAD_DIM_MINOR: - k = pl.load(k_ref, (slice_k, slice(None))) + k = k_ref[slice_k, :] else: - k = pl.load(k_ref, (slice(None), slice_k)) + k = k_ref[:, slice_k] qk = lax.dot_general(q, k, qk_dims, preferred_element_type=float32) assert qk.shape == (bq, bkv_compute) @@ -782,7 +801,7 @@ def body(kv_compute_index, _): f"{bkv_compute=} should be a multiple of {NUM_LANES}" ) - s_curr = jnp.exp(qk - pltpu.repeat(m_next, bkv_repeats, axis=1)) + s_curr = jnp.exp(qk - jnp.tile(m_next, (1, bkv_repeats))) assert s_curr.shape == (bq, bkv_compute) l_curr = jax.lax.broadcast_in_dim(s_curr.sum(axis=-1), l_prev.shape, (0,)) @@ -794,13 +813,14 @@ def body(kv_compute_index, _): sv_dims = NN_DIM_NUMBERS if v_layout == HEAD_DIM_MINOR else NT_DIM_NUMBERS if v_layout == HEAD_DIM_MINOR: - v = pl.load(v_ref, (slice_k, slice(None))) + v = v_ref[slice_k, :] else: - v = pl.load(v_ref, (slice(None), slice_k)) + v = v_ref[:, slice_k] v = v.astype(float32) o_curr = lax.dot_general(s_curr, v, sv_dims) - alpha_o = pltpu.repeat(alpha, head_dim_v_repeats, axis=1) + alpha_o = jnp.tile( + alpha, (1, head_dim_v_repeats))[..., :o_scratch_ref.shape[-1]] o_scratch_ref[:] = alpha_o * o_scratch_ref[:] + o_curr @pl.when(should_run) @@ -814,7 +834,8 @@ def run(): @pl.when(j == grid_width - 1) def end(): l = l_scratch_ref[...] - l_inv = pltpu.repeat(1.0 / l, head_dim_v_repeats, axis=1) + l_inv = jnp.tile( + 1.0 / l, (1, head_dim_v_repeats))[..., :o_scratch_ref.shape[-1]] o_ref[...] = (o_scratch_ref[...] * l_inv).astype(o_ref.dtype) if logsumexp_ref is not None: assert logsumexp_ref.shape == (bq, NUM_LANES) @@ -852,6 +873,7 @@ def _splash_attention_forward( k: jax.Array, v: jax.Array, segment_ids: SegmentIds | None, + sinks: jax.Array | None, mask_value: float, is_mqa: bool, block_sizes: BlockSizes, @@ -876,6 +898,7 @@ def _splash_attention_forward( k: jax.Array, v: jax.Array, segment_ids: SegmentIds | None, + sinks: jax.Array | None, mask_value: float, is_mqa: bool, block_sizes: BlockSizes, @@ -1001,11 +1024,13 @@ def kv_segment_ids_index_map(h, i, j, data_next_ref, block_mask_ref, # Convert the logical shape from head-minor to sequence-minor. in_specs = [ pl.BlockSpec( - from_head_minor((None, bq, head_dim_qk), q_layout), q_index_map + from_head_minor((None, bq, head_dim_qk), q_layout), + q_index_map ), pl.BlockSpec( from_head_minor( - (bkv, head_dim_qk) if is_mqa else (None, bkv, head_dim_qk), k_layout + (bkv, head_dim_qk) + if is_mqa else (None, bkv, head_dim_qk), k_layout ), k_index_map, ), @@ -1031,6 +1056,18 @@ def kv_segment_ids_index_map(h, i, j, data_next_ref, block_mask_ref, in_specs += [None, None] q_segment_ids = kv_segment_ids = None + if sinks is not None: + assert sinks.shape == (num_q_heads,) + # align sinks to sublanes to allow vmap and shard_map over the kernel + in_specs += [ + pl.BlockSpec((NUM_SUBLANES, num_q_heads), lambda h, i, j, *_: (0, 0), + memory_space=pltpu.SMEM) + ] + sinks = jnp.broadcast_to(sinks.astype(jnp.float32)[None, :], + (NUM_SUBLANES, num_q_heads)) + else: + in_specs += [None] + if fwd_mask_info.partial_mask_blocks is not None: in_specs.append(pl.BlockSpec((None, bq, bkv), mask_index_map)) else: @@ -1083,12 +1120,12 @@ def logsumexp_index_map(h, i, *_): out_specs += [None] kernel_name = get_kernel_name( - dataclasses.asdict(block_sizes), is_mqa=is_mqa, save_residuals=save_residuals, is_segmented=segment_ids is not None, phase="fwd", ) + metadata = {"xprof_metadata": json.dumps(dataclasses.asdict(block_sizes))} if fwd_mask_info.data_next is not None: grid_width = fwd_mask_info.data_next.shape[-1] @@ -1118,12 +1155,13 @@ def logsumexp_index_map(h, i, *_): out_specs=out_specs, grid=grid, ), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=("parallel", "arbitrary", "arbitrary"), ), out_shape=out_shapes, name=kernel_name, interpret=interpret, + metadata=metadata, )( fwd_mask_info.data_next, fwd_mask_info.block_mask, @@ -1133,6 +1171,7 @@ def logsumexp_index_map(h, i, *_): v if v_layout == QKVLayout.HEAD_DIM_MINOR else v.swapaxes(-1, -2), q_segment_ids, kv_segment_ids, + sinks, fwd_mask_info.partial_mask_blocks, q_sequence, ) @@ -1160,7 +1199,11 @@ def logsumexp_index_map(h, i, *_): return out -@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14)) +@partial(jax.custom_vjp, nondiff_argnames=( + "save_residuals", "mask_value", "is_mqa", "block_sizes", + "residual_checkpoint_name", "mask_function", "attn_logits_soft_cap", + "interpret") +) def _splash_attention_custom( fwd_mask_info: mask_info_lib.MaskInfo, dq_mask_info: mask_info_lib.MaskInfo | None, @@ -1169,6 +1212,7 @@ def _splash_attention_custom( k: jax.Array, v: jax.Array, segment_ids: SegmentIds | None, + sinks: jax.Array | None, save_residuals: bool, mask_value: float, is_mqa: bool, @@ -1195,6 +1239,7 @@ def _splash_attention_custom( k, v, segment_ids, + sinks=sinks, mask_value=mask_value, is_mqa=is_mqa, block_sizes=block_sizes, @@ -1214,6 +1259,7 @@ def _splash_attention_fwd( k: jax.Array, v: jax.Array, segment_ids: SegmentIds | None, + sinks: jax.Array | None, save_residuals: bool, mask_value: float, is_mqa: bool, @@ -1235,6 +1281,7 @@ def _splash_attention_fwd( k, v, segment_ids, + sinks, mask_value=mask_value, is_mqa=is_mqa, block_sizes=block_sizes, @@ -1249,6 +1296,7 @@ def _splash_attention_fwd( k, v, segment_ids, + sinks, out, logsumexp, dq_mask_info, @@ -1267,6 +1315,7 @@ def _flash_attention_dq_kernel( v_ref, q_segment_ids_ref, kv_segment_ids_ref, + sinks_ref, logsumexp_ref, do_ref, di_ref, @@ -1286,6 +1335,7 @@ def _flash_attention_dq_kernel( v_layout: QKVLayout, mask_function: MaskFunctionType | None, ): + del sinks_ref # potentially fuse dsinks computation into the kernel later float32 = jnp.float32 HEAD_DIM_MINOR = QKVLayout.HEAD_DIM_MINOR @@ -1357,6 +1407,7 @@ def _splash_attention_bwd_dq( k, v, segment_ids, + sinks, logsumexp, do, di, @@ -1492,6 +1543,18 @@ def kv_segment_ids_index_map( q_segment_spec = kv_segment_spec = None q_segment_ids = kv_segment_ids = None + if sinks is not None: + assert sinks.shape == (num_q_heads,) + # align sinks to sublanes to allow vmap and shard_map over the kernel + sinks_spec = pl.BlockSpec( + (NUM_SUBLANES, num_q_heads), lambda h, i, j, *_: (0, 0), + memory_space=pltpu.SMEM + ) + sinks = jnp.broadcast_to(sinks.astype(jnp.float32)[None, :], + (NUM_SUBLANES, num_q_heads)) + else: + sinks_spec = None + do_spec = o_spec def logsumexp_index_map(h, i, *_): @@ -1511,6 +1574,7 @@ def logsumexp_index_map(h, i, *_): v_spec, q_segment_spec, kv_segment_spec, + sinks_spec, logsumexp_spec, do_spec, di_spec, @@ -1555,18 +1619,18 @@ def logsumexp_index_map(h, i, *_): num_scalar_prefetch = 3 kernel_name = get_kernel_name( - dict( - block_q_dq=bq, - block_kv_dq=bkv, - q_layout=q_layout, - k_layout=k_layout, - v_layout=v_layout, - ), is_mqa=is_mqa, save_residuals=False, is_segmented=segment_ids is not None, phase="dq", ) + metadata = {"xprof_metadata": json.dumps(dict( + block_q_dq=bq, + block_kv_dq=bkv, + q_layout=q_layout, + k_layout=k_layout, + v_layout=v_layout, + ))} with jax.named_scope(kernel_name): _, dq = pl.pallas_call( kernel, @@ -1577,11 +1641,12 @@ def logsumexp_index_map(h, i, *_): grid=grid, ), out_shape=out_shapes, - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=("arbitrary", "arbitrary", "arbitrary"), ), name=kernel_name, interpret=interpret, + metadata=metadata, )( mask_info.data_next, mask_info.block_mask, @@ -1591,6 +1656,7 @@ def logsumexp_index_map(h, i, *_): v if v_layout == QKVLayout.HEAD_DIM_MINOR else v.swapaxes(-1, -2), q_segment_ids, kv_segment_ids, + sinks, logsumexp, do, di, @@ -1611,6 +1677,7 @@ def _flash_attention_dkv_kernel( v_ref, q_segment_ids_ref, kv_segment_ids_ref, + sinks_ref, logsumexp_ref, do_ref, di_ref, @@ -1638,6 +1705,7 @@ def _flash_attention_dkv_kernel( bkv: int, mask_function: MaskFunctionType | None, ): + del sinks_ref # potentially fuse dsinks computation into the kernel later HEAD_DIM_MINOR = QKVLayout.HEAD_DIM_MINOR kv_index, q_head_index, q_index = ( pl.program_id(0), @@ -1688,13 +1756,13 @@ def body(i, _): q = q_ref[...] # We keep q potentially transposed, since it's always RHS def _load_kv(ref, layout): if layout == HEAD_DIM_MINOR: - return pl.load(ref, (slice_k, slice(None))) - return pl.load(ref, (slice(None), slice_k)).T + return ref[slice_k, :] + return ref[:, slice_k].T k = _load_kv(k_ref, k_layout) v = _load_kv(v_ref, v_layout) - logsumexp = pl.load(logsumexp_ref, (pl.ds(1), slice(None))) + logsumexp = logsumexp_ref[:1, :] do = do_ref[...] - di = pl.load(di_ref, (pl.ds(1), slice(None))) + di = di_ref[:1, :] qk_dims = NT_DIM_NUMBERS if q_layout == HEAD_DIM_MINOR else NN_DIM_NUMBERS qk_uncapped = lax.dot_general( @@ -1718,10 +1786,8 @@ def _load_kv(ref, layout): ) p = jnp.exp(qk - logsumexp) dv = lax.dot(p.astype(do.dtype), do, preferred_element_type=jnp.float32) - dv = dv.astype(dv_scratch_ref.dtype) + pl.load( - dv_scratch_ref, (slice_k, slice(None)) - ) - pl.store(dv_scratch_ref, (slice_k, slice(None)), dv) + dv = dv.astype(dv_scratch_ref.dtype) + dv_scratch_ref[slice_k, :] + dv_scratch_ref[slice_k, :] = dv dp = lax.dot_general( v, do, NT_DIM_NUMBERS, @@ -1737,10 +1803,8 @@ def _load_kv(ref, layout): dk = lax.dot_general( ds.astype(do.dtype), q, dk_dims, preferred_element_type=jnp.float32 ) - dk = dk.astype(dk_scratch_ref.dtype) + pl.load( - dk_scratch_ref, (slice_k, slice(None)) - ) - pl.store(dk_scratch_ref, (slice_k, slice(None)), dk) + dk = dk.astype(dk_scratch_ref.dtype) + dk_scratch_ref[slice_k, :] + dk_scratch_ref[slice_k, :] = dk if dq_scratch_ref is not None or dq_ref is not None: dq = lax.dot_general( ds.T.astype(k.dtype), k, NN_DIM_NUMBERS, @@ -1795,6 +1859,7 @@ def _splash_attention_bwd_dkv( k, v, segment_ids, + sinks, logsumexp, do, di, @@ -1949,8 +2014,7 @@ def dkv_index_map(kv_index, head_index, *_): ) dv_spec = pl.BlockSpec( - (bkv, head_dim_v) if is_mqa else (None, bkv, head_dim_v), - dkv_index_map, + (bkv, head_dim_v) if is_mqa else (None, bkv, head_dim_v), dkv_index_map, ) def mask_index_map( @@ -2009,6 +2073,18 @@ def kv_segment_ids_index_map(kv_index, *_): q_segment_spec = kv_segment_spec = None q_segment_ids = kv_segment_ids = None + if sinks is not None: + assert sinks.shape == (num_q_heads,) + # align sinks to sublanes to allow vmap and shard_map over the kernel + sinks_spec = pl.BlockSpec( + (NUM_SUBLANES, num_q_heads), lambda h, i, j, *_: (0, 0), + memory_space=pltpu.SMEM + ) + sinks = jnp.broadcast_to(sinks.astype(jnp.float32)[None, :], + (NUM_SUBLANES, num_q_heads)) + else: + sinks_spec = None + do_spec = o_spec def logsumexp_index_map( @@ -2048,6 +2124,7 @@ def logsumexp_index_map( v_spec, q_segment_spec, kv_segment_spec, + sinks_spec, logsumexp_spec, do_spec, di_spec, @@ -2102,19 +2179,19 @@ def logsumexp_index_map( num_scalar_prefetch = 3 kernel_name = get_kernel_name( - dict( - block_q_dkv=bq, - block_kv_dkv=bkv, - block_kv_dkv_compute=bkv_compute, - q_layout=q_layout, - k_layout=k_layout, - v_layout=v_layout, - ), is_mqa=is_mqa, save_residuals=False, is_segmented=segment_ids is not None, phase="dkv", ) + metadata = {"xprof_metadata": json.dumps(dict( + block_q_dkv=bq, + block_kv_dkv=bkv, + block_kv_dkv_compute=bkv_compute, + q_layout=q_layout, + k_layout=k_layout, + v_layout=v_layout, + ))} with jax.named_scope(kernel_name): _, _, _, dq_unreduced, dk, dv = pl.pallas_call( kernel, @@ -2130,11 +2207,12 @@ def logsumexp_index_map( # megacore # 2) for heads, we are reducing over heads # 3) for q_seq_len, we are reducing over it to compute dkv - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=("arbitrary", "arbitrary", "arbitrary"), ), name=kernel_name, interpret=interpret, + metadata=metadata, )( mask_info.data_next, mask_info.block_mask, @@ -2144,6 +2222,7 @@ def logsumexp_index_map( v if v_layout == QKVLayout.HEAD_DIM_MINOR else v.swapaxes(-1, -2), q_segment_ids, kv_segment_ids, + sinks, logsumexp, do, di, @@ -2178,6 +2257,7 @@ def _splash_attention_bwd( jax.Array, # k jax.Array, # v SegmentIds | None, # segmend_ids + jax.Array | None, # sinks ]: del save_residuals, residual_checkpoint_name if not block_sizes.has_backward_blocks: @@ -2194,6 +2274,7 @@ def _splash_attention_bwd( k, v, segment_ids, + sinks, o, logsumexp, dq_mask_info, @@ -2207,6 +2288,7 @@ def _splash_attention_bwd( k, v, segment_ids, + sinks, logsumexp, do, di, @@ -2231,6 +2313,7 @@ def _splash_attention_bwd( k, v, segment_ids, + sinks, logsumexp, do, di, @@ -2248,6 +2331,11 @@ def _splash_attention_bwd( ) # Match the signature of the fwd function. assert dq is not None + dsinks = None + if sinks is not None: + sinks_exp = -jnp.exp(sinks[..., None, None].astype(jnp.float32) + - logsumexp[..., None].astype(jnp.float32)) + dsinks = jnp.sum(sinks_exp.astype(o.dtype) * o * do, axis=(-1, -2)) return ( None, # fwd_mask_info None, # dq_mask_info @@ -2256,6 +2344,7 @@ def _splash_attention_bwd( dk, # k dv, # v None, # segment_ids + dsinks, # sinks ) @@ -2283,6 +2372,7 @@ def _splash_attention( k: jax.Array, v: jax.Array, segment_ids: SegmentIds | None = None, + sinks: jax.Array | None = None, *, is_mqa: bool, block_sizes: BlockSizes | None, @@ -2293,6 +2383,26 @@ def _splash_attention( mask_function: MaskFunctionType | None, interpret: bool, ) -> SplashCustomReturnType: + """ + For dynamic masks, `partial_mask_blocks` has shape (head_count, q_blocks, kv_blocks, block_q, block_kv). + This shape allows sharding across both head count and query sequence dimensions. + + Note: The leading dimensions (head_count, q_blocks, kv_blocks) must be + collapsed into a single dimension before being passed to the kernel. + """ + def _collapse_partial_mask_blocks(mask_info: mask_info_lib.MaskInfo | None): + if mask_info is None or mask_info.partial_mask_blocks is None: + return mask_info + + return mask_info._replace( + partial_mask_blocks=mask_info.partial_mask_blocks.reshape( + -1, *mask_info.partial_mask_blocks.shape[-2:] + ) + ) + + fwd_mask_info = _collapse_partial_mask_blocks(fwd_mask_info) + dq_mask_info = _collapse_partial_mask_blocks(dq_mask_info) + dkv_mask_info = _collapse_partial_mask_blocks(dkv_mask_info) return _splash_attention_custom( fwd_mask_info, dq_mask_info, @@ -2301,6 +2411,7 @@ def _splash_attention( k, v, segment_ids, + sinks, mask_value=mask_value, is_mqa=is_mqa, block_sizes=block_sizes, @@ -2352,13 +2463,16 @@ def manual_sharding_spec(self, sharding: jax.sharding.NamedSharding): spec = sharding.spec assert len(spec) == 2 replicated = jax.sharding.PartitionSpec() + partial_mask_blocks_spec = ( + spec if self.fwd_mask_info.is_dynamic_mask else replicated + ) # Shard q_sequence over the sequence dimension only. q_sequence_spec = jax.sharding.PartitionSpec(spec[1]) mask_info_specs = mask_info_lib.MaskInfo( # pytype: disable=wrong-arg-types data_next=spec if self.fwd_mask_info.data_next is not None else None, mask_next=spec if self.fwd_mask_info.mask_next is not None else None, block_mask=spec if self.fwd_mask_info.block_mask is not None else None, - partial_mask_blocks=replicated + partial_mask_blocks=partial_mask_blocks_spec if self.fwd_mask_info.partial_mask_blocks is not None else None, q_sequence=q_sequence_spec diff --git a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py index eab2a695dc02..354fdb24f9df 100644 --- a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py +++ b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py @@ -92,6 +92,35 @@ def make_local_attention_mask( return mask.astype(np.bool_) +def make_chunk_attention_mask( + shape: tuple[int, int], chunk_size: int +) -> np.ndarray: + """Makes a chunked causal attention mask. + + Args: + shape: The desired shape of the mask (q_seq_len, kv_seq_len). + chunk_size: The size of the attention chunks. + + Returns: + A boolean mask of shape `mask_shape` where True indicates attention is + allowed according to chunked causal rules, and False otherwise. + + Raises: + ValueError: If chunk_window_size is None or not positive. + """ + if chunk_size <= 0: + raise ValueError('chunk_size must be positive') + + q_seq_len, kv_seq_len = shape + q_idx = np.arange(q_seq_len, dtype=np.int32) + kv_idx = np.arange(kv_seq_len, dtype=np.int32) + + # chunk mask calculation + same_chunk = (q_idx[:, None] // chunk_size) == (kv_idx[None, :] // chunk_size) + mask = same_chunk & (q_idx[:, None] >= kv_idx[None, :]) + return mask + + def make_random_mask( shape: tuple[int, int], sparsity: float, seed: int ) -> np.ndarray: @@ -196,15 +225,20 @@ def __hash__(self): class _ComputableMask(Mask): """Superclass for all masks that can be computed inside the kernel using a callable object. + This subclass is designed to be used with Splash Attention. + It allows the mask logic to be computed on-the-fly or fused into the attention + kernel, avoiding the memory cost of materializing the full + (sequence_length, sequence_length) boolean mask array, which can be excessive + for long sequences. + Attributes: _shape: Shape of the 2-dim mask: (q_seq_len, kv_seq_len). offset: Offset of q start wrt kv. A positive offset shifts the bottom triangle upward, a negative one shifts it downward. A negative offset makes the first 'offset' rows of the attention matrix all 0s which leads to undefined softmax. - q_sequence: Indices of Q sequence. - q_sequence is reused across __getitem__ calls which is important for - compile-time performance. + q_sequence: Indices of Q sequence. q_sequence is reused across __getitem__ + calls which is important for compile-time performance. mask_function: Function used by the SplashAttention kernel to compute the mask rather than loading it. """ @@ -314,26 +348,80 @@ def __hash__(self): )) -class LocalMask(Mask): +class ChunkedCausalMask(_ComputableMask): + """Lazy chunked causal mask. + + Attention is causal within each chunk (0, K), (K, 2K), (2K, 3K), ... tokens + attend to each other but not across chunks. + Llama4 models use interleaved chunk attention along with global attention. + + + Attributes: + chunk_size: The size of each attention chunk. + """ + + chunk_size: int + + def __init__( + self, + shape: tuple[int, int], + chunk_size: int, + shard_count: int = 1, + ): + if chunk_size <= 0: + raise ValueError('chunk_size must be positive') + self.chunk_size = chunk_size + + # Define the mask function for chunk attention + def chunked_causal_mask_function(q_ids, kv_ids): + """Computes the mask logic for the given slice indices.""" + # Condition 1: Same chunk + same_chunk = (q_ids // self.chunk_size) == (kv_ids // self.chunk_size) + + # Condition 2: Causal + causal = q_ids >= kv_ids + + return same_chunk & causal + + super().__init__( + shape=shape, + mask_function=chunked_causal_mask_function, + shard_count=shard_count, + ) + + def __eq__(self, other: object): + if not isinstance(other, type(self)): + return NotImplemented + + return ( + self.shape == other.shape + and self.chunk_size == other.chunk_size + and np.array_equal(self.q_sequence, other.q_sequence) + ) + + def __hash__(self): + return hash(( + type(self), + self.shape, + self.chunk_size, + self.q_sequence.tobytes() if self.q_sequence is not None else None, + )) + + +class LocalMask(_ComputableMask): """Lazy local mask, prevents model from attending to tokens outside window. Attributes: - _shape: Shape of the 2-dim mask: (q_seq_len, kv_seq_len). - window_size: Size of the two sides of the local window (None identifes no + window_size: Size of the two sides of the local window (None identifies no limit for the given side). offset: Offset of q start wrt kv. A positive offset shifts the bottom triangle upward, a negative one shifts it downward. A negative offset makes the first 'offset' rows of the attention matrix all 0s which leads to undefined softmax. - _q_sequence: Important for performance. """ - # TODO(amagni): Transform LocalMask into a _ComputableMask. - - _shape: tuple[int, int] window_size: tuple[int | None, int | None] offset: int - _q_sequence: np.ndarray | None = None def __init__( self, @@ -342,68 +430,50 @@ def __init__( offset: int, shard_count: int = 1, ): - self._shape = shape self.window_size = window_size self.offset = offset - if self.shape[0] % (shard_count * shard_count) != 0: - raise ValueError( - f'Shard count squared ({shard_count * shard_count}) must' - f' divide Q seq_len ({self.shape[0]}) evenly.' - ) - - @property - def shape(self) -> tuple[int, int]: - return self._shape - - def __getitem__(self, idx) -> np.ndarray: - if len(idx) != 2: - raise NotImplementedError(f'Unsupported slice: {idx}') - q_slice, kv_slice = idx - if not isinstance(q_slice, slice) or not isinstance(kv_slice, slice): - raise NotImplementedError(f'Unsupported slice: {idx}') - - q_slice = _fill_slice(q_slice, self.shape[0]) - kv_slice = _fill_slice(kv_slice, self.shape[1]) + def local_mask_function(q_ids, kv_ids): + """Computes the local attention mask for the given slice indices.""" + left_size, right_size = self.window_size - if self._q_sequence is None: - rows = np.arange(q_slice.start, q_slice.stop) - else: - rows = self._q_sequence[q_slice] - - cols = np.arange(kv_slice.start, kv_slice.stop) + assert q_ids.ndim == 2 + assert kv_ids.ndim == 2 - left_size, right_size = self.window_size + if left_size is None and right_size is None: + return np.ones((q_ids.shape[0], kv_ids.shape[1]), dtype=np.bool_) - if left_size is None and right_size is None: - return np.ones((rows.shape[0], cols.shape[0]), dtype=np.bool_) - else: - expanded_cols = cols[None, :] - if self.offset != 0: - expanded_rows = rows[:, None] + self.offset + # Avoid the addition when possible to avoid instantiating an actual array. + if offset != 0: + shifted_q_ids = q_ids + self.offset else: - expanded_rows = rows[:, None] - if left_size is not None and right_size is not None: - return (expanded_rows <= expanded_cols + left_size) & ( - expanded_cols - right_size <= expanded_rows - ) + shifted_q_ids = q_ids + + mask = None + if left_size is not None: + mask = shifted_q_ids - left_size <= kv_ids + if right_size is not None: + if mask is None: + mask = shifted_q_ids + right_size >= kv_ids + else: + mask &= shifted_q_ids + right_size >= kv_ids + return mask - elif left_size is not None and right_size is None: - return expanded_rows <= expanded_cols + left_size - else: - assert left_size is None and right_size is not None - return expanded_cols - right_size <= expanded_rows + super().__init__( + shape=shape, + mask_function=local_mask_function, + shard_count=shard_count, + ) def __eq__(self, other: object): if not isinstance(other, type(self)): - return NotImplemented + return False return ( self.shape == other.shape and self.window_size == other.window_size and self.offset == other.offset - and (True if self._q_sequence is None else - np.array_equal(self._q_sequence, other._q_sequence)) + and np.array_equal(self.q_sequence, other.q_sequence) ) def __hash__(self): @@ -412,7 +482,7 @@ def __hash__(self): self.shape, self.window_size, self.offset, - self._q_sequence.tobytes() if self._q_sequence is not None else None, + self.q_sequence.tobytes() if self.q_sequence is not None else None, )) diff --git a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py index 65081e79c0cf..1a8f9dcc55ff 100644 --- a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py +++ b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py @@ -22,7 +22,7 @@ from typing import NamedTuple import jax -from jax import util as jax_util +from jax._src import util as jax_util from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib import jax.numpy as jnp import numpy as np @@ -67,6 +67,10 @@ class MaskInfo(NamedTuple): q_sequence: A i32[q_sequence_length] NumPy array. When using causal masking, this contains the list of indices that correspond to q tokens. For plain causal this is just np.arange(q_sequence_length). + is_dynamic_mask: A bool indicating whether the mask is dynamic or static. + When True, the leading dimensions of `partial_mask_blocks` (num_heads, + q_blocks, kv_blocks) are not collapsed, allowing us to shard it along + those dimensions. """ data_next: np.ndarray | jax.Array | None @@ -74,6 +78,7 @@ class MaskInfo(NamedTuple): block_mask: np.ndarray | jax.Array | None partial_mask_blocks: np.ndarray | jax.Array | None q_sequence: np.ndarray | None + is_dynamic_mask: bool = None def _downcast_to_small_type(array: np.ndarray) -> np.ndarray: @@ -168,7 +173,7 @@ def __eq__(self, other: object) -> bool: def _get_mask_info_for_shard( output_shape: tuple[int, int, int], has_mask_next: bool, - mask: mask_lib.MultiHeadMask, + mask: mask_lib.MultiHeadMask | jax.Array, block_shape: tuple[int, int], coords_to_partial_mask_block_index: dict[tuple[int, int, int], int], masks_per_head_shard: int, @@ -338,7 +343,8 @@ def _process_dynamic_mask( launched. q_seq_shards: Number of Q sequence shards of the mesh in which the kernel is launched. - shrink_grid: Whether or not we should apply the grid shrinking optimization. This is currently ignored. + shrink_grid: Whether or not we should apply the grid shrinking optimization. + This is currently ignored. Returns: `MaskInfo`, a sparse representation of the dense mask. @@ -349,11 +355,6 @@ def _process_dynamic_mask( """ del shrink_grid - - # TODO(pobudzey): Properly support sharding. - if head_shards != 1 or q_seq_shards != 1: - raise ValueError('Dynamic mask processing does not support sharding.') - if len(mask.shape) != 3: raise ValueError(f'Expected a 3-dim mask, instead got: {mask.shape}.') @@ -370,6 +371,18 @@ def _process_dynamic_mask( if kv_mod != 0: raise ValueError(f'{kv_block_size=} should divide {kv_seq_len=}.') + q_seq_len_per_shard, mod = divmod(q_seq_len, q_seq_shards) + if mod != 0: + raise ValueError(f'{q_seq_shards=} should divide {q_seq_len=}.') + + q_blocks_per_shard, mod = divmod(q_seq_len_per_shard, q_block_size) + if mod != 0: + raise ValueError(f'{q_block_size=} should divide {q_seq_len_per_shard=}.') + + heads_per_shard, mod = divmod(head_count, head_shards) + if mod != 0: + raise ValueError(f'{head_shards=} should divide {head_count=}.') + block_mask_shape = ( head_count, q_blocks_count, @@ -398,26 +411,66 @@ def _process_dynamic_mask( block_mask = jnp.where(is_full_mask, 2, block_mask) block_mask = jnp.where(is_empty_mask, 0, block_mask) - # TODO(pobudzey): Return the next valid mask index instead of 0 for a more efficient pipeline. - mask_next = jnp.where( - jnp.logical_or(is_empty_mask, is_full_mask), - 0, - jnp.arange(math.prod(block_mask_shape), dtype=np.int32).reshape( - block_mask_shape - ), - ) + q_sequence_axis = 1 + head_axis = 0 - # data_next stores the index of the next non-empty data block in the sequence. - # The indices of empty blocks are set to 0 to avoid copying extra data when - # pipeling. - if is_dkv: - data_next = jnp.arange(q_blocks_count, dtype=np.int32)[None, :, None] - else: - data_next = jnp.arange(kv_blocks_count, dtype=np.int32)[None, None, :] - data_next = jnp.broadcast_to(data_next, block_mask_shape) - data_next = jnp.where(is_empty_mask, 0, data_next) + # Each iteration of the loop processes a slice of the mask info + # tensors of this shape: + mask_info_slice_shape = (heads_per_shard, q_blocks_per_shard, kv_blocks_count) + + # Collect mask_info shards along the head dimension, concatenate (or + # broadcast) them after the loop. + data_next_per_head_list, mask_next_per_head_list = [], [] + for head_shard in range(head_shards): + head_start = head_shard * heads_per_shard + mask_head_slice = slice(head_start, head_start + heads_per_shard) + + # Collect mask_info shards along the q_sequence dimension, concatenate them + # after the loop. + data_next_sequence_slices, mask_next_sequence_slices = [], [] + for q_seq_len_shard in range(q_seq_shards): + q_seq_len_start = q_seq_len_shard * q_blocks_per_shard + blocked_q_seq_len_slice = slice( + q_seq_len_start, q_seq_len_start + q_blocks_per_shard + ) + local_block_mask = block_mask[mask_head_slice, blocked_q_seq_len_slice] + + mask_next_slice = jnp.arange( + math.prod(mask_info_slice_shape), dtype=np.int32 + ).reshape(mask_info_slice_shape) + mask_next_slice = jnp.where(local_block_mask == 1, mask_next_slice, 0) + + # data_next stores the index of the next non-empty data block in the sequence. + # The indices of empty blocks are set to 0 to avoid copying extra data when + # pipeling. + if is_dkv: + data_next_slice = jnp.arange(q_blocks_per_shard, dtype=np.int32)[ + None, :, None + ] + else: + data_next_slice = jnp.arange(kv_blocks_count, dtype=np.int32)[ + None, None, : + ] + data_next_slice = jnp.broadcast_to(data_next_slice, mask_info_slice_shape) + data_next_slice = jnp.where(local_block_mask == 0, 0, data_next_slice) + + data_next_sequence_slices.append(data_next_slice) + mask_next_sequence_slices.append(mask_next_slice) + + # Concatenate the sequence shards. + data_next_per_head = jnp.concatenate( + data_next_sequence_slices, axis=q_sequence_axis + ) + data_next_per_head_list.append(data_next_per_head) + mask_next_per_head = jnp.concatenate( + mask_next_sequence_slices, axis=q_sequence_axis + ) + mask_next_per_head_list.append(mask_next_per_head) + + # Concatenate (or broadcast) the head shards. + data_next = jnp.concatenate(data_next_per_head_list, axis=head_axis) + mask_next = jnp.concatenate(mask_next_per_head_list, axis=head_axis) - partial_mask_blocks = partial_mask_blocks.reshape(-1, *block_shape) if is_dkv: partial_mask_blocks = partial_mask_blocks.swapaxes(-1, -2) @@ -438,9 +491,11 @@ def _downcast(array: jax.Array, max_value: int) -> jax.Array: if downcast_smem_data: block_mask = block_mask.astype(np.int8) # values are in the range [0, 1, 2] data_next = _downcast( - data_next, q_blocks_count if is_dkv else kv_blocks_count + data_next, q_blocks_per_shard if is_dkv else kv_blocks_count + ) + mask_next = _downcast( + mask_next, heads_per_shard * q_blocks_per_shard * kv_blocks_count ) - mask_next = _downcast(mask_next, math.prod(block_mask_shape)) return ( MaskInfo( @@ -449,6 +504,7 @@ def _downcast(array: jax.Array, max_value: int) -> jax.Array: block_mask=block_mask, partial_mask_blocks=partial_mask_blocks, q_sequence=None, + is_dynamic_mask=True, ), None, ) @@ -577,7 +633,7 @@ def assign_unique_ids(objects): ] # TODO(amagni): checking the validity of the masks is slow for large masks. - # Disable it for now, reevalute in the future. + # Disable it for now, reevaluate in the future. partial_mask_block_ids: dict[_HashableNDArray, int] = collections.defaultdict( lambda: len(partial_mask_block_ids) @@ -691,7 +747,7 @@ def set_block_mask(mask_id: int, q_index: int, kv_index: int, value: int): q_sequence_axis = 1 head_axis = 0 - # Collect mask_info shards along the head dimension, concatentate (or + # Collect mask_info shards along the head dimension, concatenate (or # broadcast) them after the loop. data_next_per_head_list, mask_next_per_head_list = [], [] for head_shard in range(shards_to_process): diff --git a/jax/experimental/pallas/tpu.py b/jax/experimental/pallas/tpu.py index ecc9d0d15120..317c34b17b0b 100644 --- a/jax/experimental/pallas/tpu.py +++ b/jax/experimental/pallas/tpu.py @@ -13,60 +13,105 @@ # limitations under the License. """Mosaic-specific Pallas APIs.""" +import typing from jax._src.pallas.mosaic import core as core -from jax._src.pallas.mosaic.core import ARBITRARY as ARBITRARY from jax._src.pallas.mosaic.core import create_tensorcore_mesh as create_tensorcore_mesh from jax._src.pallas.mosaic.core import dma_semaphore as dma_semaphore from jax._src.pallas.mosaic.core import GridDimensionSemantics as GridDimensionSemantics -from jax._src.pallas.mosaic.core import PARALLEL as PARALLEL +from jax._src.pallas.mosaic.core import KernelType as KernelType from jax._src.pallas.mosaic.core import PrefetchScalarGridSpec as PrefetchScalarGridSpec -from jax._src.pallas.mosaic.core import semaphore as semaphore from jax._src.pallas.mosaic.core import SemaphoreType as SemaphoreType -from jax._src.pallas.mosaic.core import TPUMemorySpace as TPUMemorySpace -from jax._src.pallas.mosaic.core import TPUCompilerParams as TPUCompilerParams -from jax._src.pallas.mosaic.core import runtime_assert_enabled as runtime_assert_enabled -from jax._src.pallas.mosaic.core import _ENABLE_RUNTIME_ASSERT as enable_runtime_assert # noqa: F401 +from jax._src.pallas.mosaic.core import SideEffectType as SideEffectType +from jax._src.pallas.mosaic.core import MemorySpace as MemorySpace +from jax._src.pallas.mosaic.core import CompilerParams as CompilerParams from jax._src.pallas.mosaic.helpers import sync_copy as sync_copy from jax._src.pallas.mosaic.helpers import core_barrier as core_barrier from jax._src.pallas.mosaic.helpers import run_on_first_core as run_on_first_core +from jax._src.pallas.mosaic.interpret.interpret_pallas_call import InterpretParams as InterpretParams +from jax._src.pallas.mosaic.interpret.interpret_pallas_call import force_tpu_interpret_mode as force_tpu_interpret_mode +from jax._src.pallas.mosaic.interpret.interpret_pallas_call import reset_tpu_interpret_mode_state as reset_tpu_interpret_mode_state +from jax._src.pallas.mosaic.interpret.interpret_pallas_call import set_tpu_interpret_mode as set_tpu_interpret_mode from jax._src.pallas.mosaic.lowering import LoweringException as LoweringException from jax._src.pallas.mosaic.pipeline import BufferedRef as BufferedRef +from jax._src.pallas.mosaic.pipeline import BufferedRefBase as BufferedRefBase from jax._src.pallas.mosaic.pipeline import emit_pipeline as emit_pipeline from jax._src.pallas.mosaic.pipeline import emit_pipeline_with_allocations as emit_pipeline_with_allocations from jax._src.pallas.mosaic.pipeline import get_pipeline_schedule as get_pipeline_schedule from jax._src.pallas.mosaic.pipeline import make_pipeline_allocations as make_pipeline_allocations +from jax._src.pallas.mosaic.pipeline import Tiling as Tiling from jax._src.pallas.mosaic.primitives import async_copy as async_copy from jax._src.pallas.mosaic.primitives import async_remote_copy as async_remote_copy from jax._src.pallas.mosaic.primitives import bitcast as bitcast -from jax._src.pallas.mosaic.primitives import delay as delay -from jax._src.pallas.mosaic.primitives import device_id as device_id -from jax._src.pallas.mosaic.primitives import DeviceIdType as DeviceIdType from jax._src.pallas.mosaic.primitives import get_barrier_semaphore as get_barrier_semaphore +from jax._src.pallas.mosaic.primitives import load as load from jax._src.pallas.mosaic.primitives import make_async_copy as make_async_copy from jax._src.pallas.mosaic.primitives import make_async_remote_copy as make_async_remote_copy +from jax._src.pallas.mosaic.primitives import pack_elementwise as pack_elementwise from jax._src.pallas.mosaic.primitives import prng_random_bits as prng_random_bits from jax._src.pallas.mosaic.primitives import prng_seed as prng_seed from jax._src.pallas.mosaic.primitives import repeat as repeat from jax._src.pallas.mosaic.primitives import roll as roll -from jax._src.pallas.mosaic.primitives import semaphore_read as semaphore_read -from jax._src.pallas.mosaic.primitives import semaphore_signal as semaphore_signal -from jax._src.pallas.mosaic.primitives import semaphore_wait as semaphore_wait +from jax._src.pallas.mosaic.primitives import stochastic_round as stochastic_round +from jax._src.pallas.mosaic.primitives import store as store +from jax._src.pallas.mosaic.primitives import touch as touch +from jax._src.pallas.mosaic.primitives import trace_value as trace_value +from jax._src.pallas.mosaic.primitives import unpack_elementwise as unpack_elementwise +from jax._src.pallas.mosaic.primitives import with_memory_space_constraint as with_memory_space_constraint from jax._src.pallas.mosaic.random import sample_block as sample_block +from jax._src.pallas.mosaic.random import stateful_bernoulli as stateful_bernoulli +from jax._src.pallas.mosaic.random import stateful_bits as stateful_bits +from jax._src.pallas.mosaic.random import stateful_normal as stateful_normal +from jax._src.pallas.mosaic.random import stateful_uniform as stateful_uniform from jax._src.pallas.mosaic.random import to_pallas_key as to_pallas_key +from jax._src.pallas.mosaic.tpu_info import ChipVersion as ChipVersion +from jax._src.pallas.mosaic.tpu_info import get_tpu_info as get_tpu_info +from jax._src.pallas.mosaic.tpu_info import is_tpu_device as is_tpu_device +from jax._src.pallas.mosaic.tpu_info import TpuInfo as TpuInfo -import types -from jax._src.pallas.mosaic.verification import assume -from jax._src.pallas.mosaic.verification import pretend -from jax._src.pallas.mosaic.verification import skip -from jax._src.pallas.mosaic.verification import define_model -verification = types.SimpleNamespace( - assume=assume, pretend=pretend, skip=skip, define_model=define_model -) -del types, assume, pretend, skip, define_model # Clean up. +# Those primitives got moved to Pallas core. Keeping the updated imports +# here for backward compatibility. +from jax._src.pallas import primitives as pl_primitives +from jax._src.pallas.core import semaphore as semaphore +from jax._src.pallas.core import MemorySpace as GeneralMemorySpace +from jax._src.pallas.primitives import DeviceIdType as DeviceIdType +from jax._src.pallas.primitives import semaphore_read as semaphore_read +from jax._src.pallas.primitives import semaphore_signal as semaphore_signal +from jax._src.pallas.primitives import semaphore_wait as semaphore_wait -ANY = TPUMemorySpace.ANY -CMEM = TPUMemorySpace.CMEM -SMEM = TPUMemorySpace.SMEM -VMEM = TPUMemorySpace.VMEM -SEMAPHORE = TPUMemorySpace.SEMAPHORE +PARALLEL = GridDimensionSemantics.PARALLEL +CORE_PARALLEL = GridDimensionSemantics.CORE_PARALLEL +SUBCORE_PARALLEL = GridDimensionSemantics.SUBCORE_PARALLEL +ARBITRARY = GridDimensionSemantics.ARBITRARY + +CMEM = MemorySpace.CMEM +SMEM = MemorySpace.SMEM +VMEM = MemorySpace.VMEM +VMEM_SHARED = MemorySpace.VMEM_SHARED +HBM = MemorySpace.HBM +HOST = MemorySpace.HOST +SEMAPHORE = MemorySpace.SEMAPHORE + +_deprecations = { + # Added Oct 31, 2025 + "delay": ( + "pltpu.delay is deprecated, use pl.delay instead.", + pl_primitives.delay + ), + # Added Dec 10, 2025 + "ANY": ( + "pltpu.ANY is deprecated, use pl.ANY instead.", + GeneralMemorySpace.ANY + ), +} + +if typing.TYPE_CHECKING: + delay = pl_primitives.delay + ANY = GeneralMemorySpace.ANY +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del typing +del pl_primitives +del GeneralMemorySpace diff --git a/jax/experimental/pallas/tpu_sc.py b/jax/experimental/pallas/tpu_sc.py new file mode 100644 index 000000000000..e9a90ac9f7ac --- /dev/null +++ b/jax/experimental/pallas/tpu_sc.py @@ -0,0 +1,39 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TPU SparseCore Extensions to Pallas.""" + +from jax._src.pallas.mosaic.sc_core import BlockSpec as BlockSpec +from jax._src.pallas.mosaic.sc_core import get_sparse_core_info as get_sparse_core_info +from jax._src.pallas.mosaic.sc_core import MemoryRef as MemoryRef +from jax._src.pallas.mosaic.sc_core import ScalarSubcoreMesh as ScalarSubcoreMesh +from jax._src.pallas.mosaic.sc_core import VectorSubcoreMesh as VectorSubcoreMesh +from jax._src.pallas.mosaic.sc_primitives import addupdate as addupdate +from jax._src.pallas.mosaic.sc_primitives import addupdate_compressed as addupdate_compressed +from jax._src.pallas.mosaic.sc_primitives import addupdate_scatter as addupdate_scatter +from jax._src.pallas.mosaic.sc_primitives import all_reduce_ffs as all_reduce_ffs +from jax._src.pallas.mosaic.sc_primitives import all_reduce_population_count as all_reduce_population_count +from jax._src.pallas.mosaic.sc_primitives import bitcast as bitcast +from jax._src.pallas.mosaic.sc_primitives import cummax as cummax +from jax._src.pallas.mosaic.sc_primitives import cumsum as cumsum +from jax._src.pallas.mosaic.sc_primitives import load_expanded as load_expanded +from jax._src.pallas.mosaic.sc_primitives import load_gather as load_gather +from jax._src.pallas.mosaic.sc_primitives import pack as pack +from jax._src.pallas.mosaic.sc_primitives import PackFormat as PackFormat +from jax._src.pallas.mosaic.sc_primitives import parallel_loop as parallel_loop +from jax._src.pallas.mosaic.sc_primitives import scan_count as scan_count +from jax._src.pallas.mosaic.sc_primitives import sort_key_val as sort_key_val +from jax._src.pallas.mosaic.sc_primitives import store_compressed as store_compressed +from jax._src.pallas.mosaic.sc_primitives import store_scatter as store_scatter +from jax._src.pallas.mosaic.sc_primitives import subcore_barrier as subcore_barrier +from jax._src.pallas.mosaic.sc_primitives import unpack as unpack diff --git a/jax/experimental/pallas/triton.py b/jax/experimental/pallas/triton.py index 06adb9e6da7e..3878a5a8af0c 100644 --- a/jax/experimental/pallas/triton.py +++ b/jax/experimental/pallas/triton.py @@ -14,7 +14,18 @@ """Triton-specific Pallas APIs.""" -from jax._src.pallas.triton.core import TritonCompilerParams as TritonCompilerParams +from jax._src.pallas.primitives import atomic_add as atomic_add +from jax._src.pallas.primitives import atomic_and as atomic_and +from jax._src.pallas.primitives import atomic_cas as atomic_cas +from jax._src.pallas.primitives import atomic_max as atomic_max +from jax._src.pallas.primitives import atomic_min as atomic_min +from jax._src.pallas.primitives import atomic_or as atomic_or +from jax._src.pallas.primitives import atomic_xchg as atomic_xchg +from jax._src.pallas.primitives import atomic_xor as atomic_xor +from jax._src.pallas.primitives import max_contiguous as max_contiguous +from jax._src.pallas.triton.core import CompilerParams as CompilerParams from jax._src.pallas.triton.primitives import approx_tanh as approx_tanh from jax._src.pallas.triton.primitives import debug_barrier as debug_barrier from jax._src.pallas.triton.primitives import elementwise_inline_asm as elementwise_inline_asm +from jax._src.pallas.triton.primitives import load as load +from jax._src.pallas.triton.primitives import store as store diff --git a/jax/experimental/pjit.py b/jax/experimental/pjit.py index 8ba7eb25d646..5a701707dcb8 100644 --- a/jax/experimental/pjit.py +++ b/jax/experimental/pjit.py @@ -15,10 +15,28 @@ # ruff: noqa from jax._src.pjit import ( - pjit as pjit, - pjit_p as pjit_p, + pjit as _deprecated_pjit, ) from jax._src.sharding_impls import ( AUTO as AUTO, - UNSPECIFIED as _UNSPECIFIED, ) + +_deprecations = { + # Added Oct 13, 2025 + "pjit": ( + ( + "jax.experimental.pjit.pjit has been deprecated. Please use" + " `jax.jit` instead." + ), + _deprecated_pjit, + ) +} + +import typing as _typing +if _typing.TYPE_CHECKING: + pjit = _deprecated_pjit +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del _typing diff --git a/jax/experimental/profiler.py b/jax/experimental/profiler.py index 766d20472155..f22fba50092b 100644 --- a/jax/experimental/profiler.py +++ b/jax/experimental/profiler.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jax._src.lib import xla_client +from jax._src.lib import _profiler def get_profiled_instructions_proto(tensorboard_dir: str) -> bytes: @@ -30,4 +30,4 @@ def get_profiled_instructions_proto(tensorboard_dir: str) -> bytes: Serialized [ProfiledInstructionsProto](https://github.com/openxla/xla/blob/main/third_party/tsl/tsl/profiler/protobuf/profiled_instructions.proto). """ - return xla_client.profiler.get_profiled_instructions_proto(tensorboard_dir) + return _profiler.get_profiled_instructions_proto(tensorboard_dir) diff --git a/jax/experimental/random.py b/jax/experimental/random.py new file mode 100644 index 000000000000..698fca2d8393 --- /dev/null +++ b/jax/experimental/random.py @@ -0,0 +1,20 @@ +# Copyright 2026 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Experimental random APIs.""" + +from jax._src.stateful_rng import ( + stateful_rng as stateful_rng, + StatefulPRNG as StatefulPRNG, +) diff --git a/jax/experimental/rnn.py b/jax/experimental/rnn.py index 55cf2b3bae70..a5ec0cf262f0 100644 --- a/jax/experimental/rnn.py +++ b/jax/experimental/rnn.py @@ -89,8 +89,8 @@ import jax import numpy as np from jax._src import core +from jax._src import dispatch from jax.interpreters import mlir -from jax.interpreters import xla from jax._src.custom_derivatives import custom_vjp from jax._src.typing import Array, Shape from jax._src.lax import lax @@ -453,17 +453,13 @@ def rnn_abstract_eval(x_aval, h_0_aval, c_0_aval, w_aval, seq_lengths_aval, return output_aval, h_0_aval, c_0_aval, reserve_space_aval -def _gpu_lowering_strip_tf32(fn, *args, cudnn_allow_tf32, **kw): - del cudnn_allow_tf32 - return fn(*args, **kw) - rnn_fwd_p = core.Primitive('rnn_fwd') rnn_fwd_p.multiple_results = True -rnn_fwd_p.def_impl(partial(xla.apply_primitive, rnn_fwd_p)) +rnn_fwd_p.def_impl(partial(dispatch.apply_primitive, rnn_fwd_p)) rnn_fwd_p.def_abstract_eval(rnn_abstract_eval) if gpu_rnn: mlir.register_lowering(rnn_fwd_p, gpu_rnn.cudnn_rnn_lowering, platform='cuda') - if hasattr(gpu_rnn, "miopen_rnn_fwd_lowering"): + if hasattr(gpu_rnn, "miopen_rnn_lowering"): mlir.register_lowering(rnn_fwd_p, gpu_rnn.miopen_rnn_lowering, platform='rocm') @@ -503,7 +499,7 @@ def rnn_bwd_abstract_eval(dy_aval, dhn_aval, dcn_aval, x_aval, h0_aval, c0_aval, rnn_bwd_p = core.Primitive('rnn_bwd') rnn_bwd_p.multiple_results = True -rnn_bwd_p.def_impl(partial(xla.apply_primitive, rnn_bwd_p)) +rnn_bwd_p.def_impl(partial(dispatch.apply_primitive, rnn_bwd_p)) rnn_bwd_p.def_abstract_eval(rnn_bwd_abstract_eval) if gpu_rnn: mlir.register_lowering( diff --git a/jax/experimental/roofline/roofline.py b/jax/experimental/roofline/roofline.py index 6a7f2916b503..3d6a0e0501ff 100644 --- a/jax/experimental/roofline/roofline.py +++ b/jax/experimental/roofline/roofline.py @@ -14,13 +14,16 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Any, Callable, Protocol, Sequence +from typing import Any, Protocol +from collections.abc import Callable, Sequence import numpy as np +from absl import logging import jax.numpy as jnp from jax.sharding import NamedSharding from jax._src import api from jax._src import core +from jax._src import prng from jax._src import source_info_util from jax._src import traceback_util from jax._src import util @@ -29,11 +32,12 @@ from jax._src.mesh import AbstractMesh, Mesh from jax._src.tree_util import broadcast_prefix, tree_flatten, tree_unflatten, tree_map from jax._src.util import foreach -from jax.experimental import shard_map +from jax._src.shard_map import shard_map, shard_map_p ShapeDtypeStructTree = Any - +Specs = Any +ValidRooflineDtype = np.dtype | prng.KeyTy map = util.safe_map @@ -53,14 +57,16 @@ class RooflineRuleContext: @dataclass(frozen=True, slots=True, kw_only=True) class RooflineShape: shape: tuple[int, ...] - dtype: np.dtype + dtype: ValidRooflineDtype @classmethod - def from_aval(cls, aval: core.AbstractValue) -> "RooflineShape": + def from_aval(cls, aval: core.AbstractValue) -> RooflineShape: if not isinstance(aval, core.ShapedArray): raise TypeError(f"Expected ShapedArray, got {type(aval)}.") - if not isinstance(aval.dtype, np.dtype): - raise TypeError(f"Expected numpy dtype, got {type(aval.dtype)}.") + if not isinstance(aval.dtype, ValidRooflineDtype): + raise TypeError( + f"Expected numpy or prng.KeyTy dtype, got {type(aval.dtype)}." + ) return cls(shape=aval.shape, dtype=aval.dtype) @property @@ -87,10 +93,10 @@ class RooflineResult: unfused_hbm_bytes: int = 0 @classmethod - def zeros(cls) -> "RooflineResult": + def zeros(cls) -> RooflineResult: return cls() - def __add__(self, other: "RooflineResult") -> "RooflineResult": + def __add__(self, other: RooflineResult) -> RooflineResult: def merge_ici_dicts(d1: dict[str, int], d2: dict[str, int]) -> dict[str, int]: return {k: d1.get(k, 0) + d2.get(k, 0) for k in set(d1) | set(d2)} @@ -104,7 +110,7 @@ def merge_ici_dicts(d1: dict[str, int], d2: dict[str, int]) -> dict[str, int]: unfused_hbm_bytes=self.unfused_hbm_bytes + other.unfused_hbm_bytes, ) - def __mul__(self, constant: int | float) -> "RooflineResult": + def __mul__(self, constant: int | float) -> RooflineResult: return RooflineResult( flops=int(self.flops * constant), unfused_flops=int(self.unfused_flops * constant), @@ -115,7 +121,7 @@ def __mul__(self, constant: int | float) -> "RooflineResult": unfused_hbm_bytes=int(self.unfused_hbm_bytes * constant), ) - def __rmul__(self, constant: int | float) -> "RooflineResult": + def __rmul__(self, constant: int | float) -> RooflineResult: return self.__mul__(constant) @@ -136,7 +142,7 @@ def _roofline_interpreter( pin_lhs_in_vmem: bool = False, pin_rhs_in_vmem: bool = False, ) -> RooflineResult: - name_stack = source_info_util.new_name_stack(util.wrap_name(f_name, "roofline")) + name_stack = source_info_util.new_name_stack(util.wrap_name("roofline", f_name)) result = RooflineResult.zeros() @@ -159,10 +165,8 @@ def aval(v: core.Atom) -> core.AbstractValue: else: return v.aval - def calculate_peak_hbm_bytes() -> int: - return int( - sum(np.prod(shape.shape) * shape.dtype.itemsize for shape in env.values()) - ) + def sum_bytes(shapes: Sequence[RooflineShape]) -> int: + return sum(shape.bytes for shape in shapes) jaxpr = jaxpr.jaxpr if isinstance(jaxpr, core.ClosedJaxpr) else jaxpr make_roofline_shape = lambda x: RooflineShape.from_aval(aval(x)) @@ -173,6 +177,10 @@ def calculate_peak_hbm_bytes() -> int: ) foreach(write, jaxpr.invars, map(make_roofline_shape, jaxpr.invars)) last_used = core.last_used(jaxpr) + + current_hbm_bytes = sum_bytes(list(env.values())) + peak_hbm_bytes = current_hbm_bytes + for eqn in jaxpr.eqns: source_info = eqn.source_info.replace( name_stack=name_stack + eqn.source_info.name_stack @@ -182,19 +190,29 @@ def calculate_peak_hbm_bytes() -> int: ): if "jaxpr" in eqn.params: result += _roofline_interpreter( - util.wrap_name(f_name, eqn.primitive.name), + util.wrap_name(eqn.primitive.name, f_name), eqn.params["jaxpr"], mesh, pin_lhs_in_vmem=pin_lhs_in_vmem, pin_rhs_in_vmem=pin_rhs_in_vmem, ) + elif "call_jaxpr" in eqn.params: + # Used for custom_jvp_call_p. Recursively calculates roofline result for + # all primitives in the custom function. + result += _roofline_interpreter( + util.wrap_name(eqn.primitive.name, f_name), + eqn.params['call_jaxpr'], + mesh, + pin_lhs_in_vmem=pin_lhs_in_vmem, + pin_rhs_in_vmem=pin_rhs_in_vmem, + ) + elif eqn.primitive not in _rooflines: + msg = f"No roofline rule for {eqn.primitive}, skipping..." + for attr in dir(eqn): + if not attr.startswith("_"): + msg += f"\n{attr}: {getattr(eqn, attr)}" + logging.warning(msg) else: - if eqn.primitive not in _rooflines: - msg = f"No roofline rule for {eqn.primitive}." - for attr in dir(eqn): - if not attr.startswith("_"): - msg += f"\n{attr}: {getattr(eqn, attr)}" - raise NotImplementedError(msg) rule = _rooflines[eqn.primitive] result += rule( RooflineRuleContext( @@ -211,10 +229,22 @@ def calculate_peak_hbm_bytes() -> int: **eqn.params, ) - foreach(write, eqn.outvars, map(make_roofline_shape, eqn.outvars)) + # Add bytes for the newly-created output variables. + outvar_shapes = map(make_roofline_shape, eqn.outvars) + current_hbm_bytes += sum_bytes(outvar_shapes) + foreach(write, eqn.outvars, outvar_shapes) + + # Remove bytes for the no-longer-needed input variables. + removed_shapes = [ + env[v] for v in eqn.invars + if not isinstance(v, core.Literal) and last_used[v] is eqn + ] + current_hbm_bytes -= sum_bytes(removed_shapes) core.clean_up_dead_vars(eqn, env, last_used) - result += RooflineResult(peak_hbm_bytes=calculate_peak_hbm_bytes()) + peak_hbm_bytes = max(peak_hbm_bytes, current_hbm_bytes) + + result += RooflineResult(peak_hbm_bytes=peak_hbm_bytes) return result @@ -230,8 +260,8 @@ def wrapped(*args): def roofline( f: Callable, mesh: Mesh | AbstractMesh | None = None, - in_specs: shard_map.Specs | None = None, - out_specs: shard_map.Specs | None = None, + in_specs: Specs | None = None, + out_specs: Specs | None = None, *, pin_lhs_in_vmem: bool = False, pin_rhs_in_vmem: bool = False, @@ -243,14 +273,15 @@ def roofline( def wrapped(*args): wrapped_f = f if in_specs is not None and out_specs is not None and mesh is not None: - wrapped_f = shard_map.shard_map(wrapped_f, mesh, in_specs, out_specs) + wrapped_f = shard_map(wrapped_f, mesh=mesh, in_specs=in_specs, + out_specs=out_specs) if vjp: wrapped_f = _f_with_vjp(wrapped_f) jaxpr, out_shapes = make_jaxpr(wrapped_f, return_shape=True)(*args) def make_sharded_shape_dtype_struct( - shape: api.ShapeDtypeStruct, out_spec: shard_map.Specs + shape: api.ShapeDtypeStruct, out_spec: Specs ) -> api.ShapeDtypeStruct: return api.ShapeDtypeStruct( shape.shape, shape.dtype, sharding=NamedSharding(mesh, out_spec) # type: ignore @@ -267,7 +298,7 @@ def make_sharded_shape_dtype_struct( used_outputs = (True,) * len(jaxpr.jaxpr.outvars) jaxpr, _ = dce_jaxpr(jaxpr.jaxpr, used_outputs) shard_map_eqns = [ - e for e in jaxpr.eqns if e.primitive == shard_map.shard_map_p + e for e in jaxpr.eqns if e.primitive == shard_map_p ] if shard_map_eqns: try: @@ -307,8 +338,8 @@ def standard_rule(ctx: RooflineRuleContext, *args, **kwargs): def roofline_and_grad( f: Callable, mesh: Mesh | AbstractMesh, - in_specs: shard_map.Specs, - out_specs: shard_map.Specs, + in_specs: Specs, + out_specs: Specs, *, pin_lhs_in_vmem: bool = False, pin_rhs_in_vmem: bool = False, diff --git a/jax/experimental/roofline/rooflines.py b/jax/experimental/roofline/rooflines.py index 1edd1e0649b1..87572c42181b 100644 --- a/jax/experimental/roofline/rooflines.py +++ b/jax/experimental/roofline/rooflines.py @@ -14,15 +14,24 @@ from collections import defaultdict from dataclasses import replace import itertools as it +from collections.abc import Sequence import numpy as np +from jax._src import api +from jax._src import ad_checkpoint from jax._src import ad_util from jax._src import core, util +from jax._src import dispatch from jax._src import ops +from jax._src import pjit from jax._src import prng from jax._src import random +from jax._src import shard_map +from jax._src import callback +from jax._src import debugging from jax._src.lax import ( ann, + control_flow, convolution, fft, lax, @@ -33,17 +42,23 @@ windowed_reductions, ) from jax.experimental import roofline -from jax.experimental import shard_map +# One FMA (Fused Multiply Add) takes 2 flops to compute. +_FMA_FLOPS_FACTOR = 2 for prim in it.chain( + ad_checkpoint.__dict__.values(), ad_util.__dict__.values(), ann.__dict__.values(), + callback.__dict__.values(), + control_flow.__dict__.values(), convolution.__dict__.values(), + dispatch.__dict__.values(), fft.__dict__.values(), lax.__dict__.values(), linalg.__dict__.values(), ops.__dict__.values(), + [pjit.sharding_constraint_p], prng.__dict__.values(), random.__dict__.values(), shard_map.__dict__.values(), @@ -106,6 +121,8 @@ def _unary_p_roofline( roofline.register_roofline(special.erfc_p)(_unary_p_roofline) roofline.register_roofline(special.lgamma_p)(_unary_p_roofline) +roofline.register_standard_roofline(core.pvary_p) + def _binary_p_roofline( ctx: roofline.RooflineRuleContext, *args, @@ -143,6 +160,50 @@ def _binary_p_roofline( roofline.register_roofline(lax.min_p)(_binary_p_roofline) roofline.register_roofline(lax.max_p)(_binary_p_roofline) +def _cumulative_p_roofline( + ctx: roofline.RooflineRuleContext, + *args, + axis: int, + **kw, +) -> roofline.RooflineResult: + (x,) = (roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in) + out = roofline.RooflineShape.from_aval(ctx.avals_out[0]) + return roofline.RooflineResult( + # `cum{max, min, prod, sum}` only calculate values for one axis. + unfused_flops=x.shape[axis], + unfused_hbm_bytes=( + x.dtype.itemsize * x.size + out.dtype.itemsize * out.size + ), + ) + +roofline.register_roofline(control_flow.cummax_p)(_cumulative_p_roofline) +roofline.register_roofline(control_flow.cummin_p)(_cumulative_p_roofline) +roofline.register_roofline(control_flow.cumprod_p)(_cumulative_p_roofline) +roofline.register_roofline(control_flow.cumsum_p)(_cumulative_p_roofline) + +@roofline.register_roofline(control_flow.cumlogsumexp_p) +def _cumlogsumexp_p_roofline( + ctx: roofline.RooflineRuleContext, + *args, + axis: int, + **kw, +) -> roofline.RooflineResult: + (x,) = (roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in) + out = roofline.RooflineShape.from_aval(ctx.avals_out[0]) + return roofline.RooflineResult( + # Similar to `cum{max, min, prod, sum}`, `cumlogsumexp` only calculates + # values for one axis. But for `x.shape[axis] = S`, it computes (for a + # naive implementation): + # S `exp` ops. + # S-1 `add` ops. + # 1 log op. + # Thus, the total number of flops is 2 * S. + unfused_flops=x.shape[axis] * 2, + unfused_hbm_bytes=( + x.dtype.itemsize * x.size + out.dtype.itemsize * out.size + ), + ) + @roofline.register_roofline(lax.dot_general_p) def _dot_general_roofline( @@ -156,7 +217,7 @@ def _dot_general_roofline( (lhs_contract, _), (lhs_batch, _) = dimension_numbers flops = ( - 2 + _FMA_FLOPS_FACTOR * lhs.size * rhs.size / np.prod([lhs.shape[i] for i in lhs_contract]) @@ -177,16 +238,208 @@ def _dot_general_roofline( unfused_hbm_bytes=hbm_bytes, ) + +def _get_spatial_valid_position_count_for_one_dim( + window_dim_stride: int, + base_dilation: int, + window_dilation: int, + kernel_limit: int, + input_limit: int, + output_limit: int, + padding: tuple[int, int], +) -> int: + """Gets the valid position count for conv for a single spatial dimension. + + Args: + window_dim_stride: The stride of the window along this dimension. + base_dilation: The base dilation factor along this dimension. + window_dilation: The window dilation factor along this dimension. + kernel_limit: The size of the kernel along this dimension. + input_limit: The size of the input along this dimension. + output_limit: The size of the output along this dimension. + padding: The padding applied to the input along this dimension. + """ + padding_low = padding[0] + padding_high = padding[1] + + # These two conditions will create an N^2 iteration pattern with only N + # valid elements. This is a performance optimization and produces the same + # result as the whole loop. + if ( + input_limit == output_limit + and kernel_limit == output_limit + and input_limit == base_dilation + and window_dilation == 1 + and max(1, input_limit - 1) == window_dim_stride + and padding_low == 0 + and padding_high == 0 + ): + return input_limit + + if ( + input_limit == 1 + and kernel_limit == output_limit + and window_dilation == 1 + and base_dilation == 1 + and window_dim_stride == 1 + and padding_low == output_limit - 1 + and padding_high == output_limit - 1 + ): + return output_limit + + valid_position_count = 0 + # Loop over each point in the kernel + for kernel_idx in range(kernel_limit): + + # Skip loop for trivial stride and base_dilation + if window_dim_stride == 1 and base_dilation == 1: + undilated_index_base = padding_low - kernel_idx * window_dilation + upper_limit = min( + input_limit + undilated_index_base, + output_limit, + ) + lower_limit = max(0, undilated_index_base) + + valid_position_count += max(upper_limit - lower_limit, 0) + continue + + # Loop over each point in the output + for output_idx in range(output_limit): + # Calculate lhs (input) index without taking base dilation into account + undilated_index = ( + output_idx * window_dim_stride + - padding_low + + kernel_idx * window_dilation + ) + # Calculate the actual lhs (input) index after dilation + lhs_spatial_index = int(undilated_index / base_dilation) + + # Skip if the lhs (input) index is to be dilated. + if undilated_index != lhs_spatial_index * base_dilation: + continue + # Skip if input index is not in bound. + if lhs_spatial_index < 0 or lhs_spatial_index >= input_limit: + continue + + valid_position_count += 1 + return valid_position_count + + +def _get_spatial_valid_position_count( + dnums: convolution.ConvDimensionNumbers, + lhs: roofline.RooflineShape, + rhs: roofline.RooflineShape, + out: roofline.RooflineShape, + window_strides: Sequence[int], + padding: Sequence[tuple[int, int]], + lhs_dilation: Sequence[int], + rhs_dilation: Sequence[int], +) -> int: + """Gets the number of valid spatial positions for conv_general_dilated. + + Args: + dnums: The dimension numbers for the convolution. + lhs: The shape of the left-hand side of the convolution. + rhs: The shape of the right-hand side of the convolution. + out: The shape of the output of the convolution. + window_strides: The stride of the window along each spatial dimension. + padding: The padding applied to the input along each spatial dimension. + lhs_dilation: The dilation factor for the left-hand side along each spatial + dimension. + rhs_dilation: The dilation factor for the right-hand side along each spatial + dimension. + """ + input_spatial_dims, kernel_spatial_dims, out_spatial_dims = ( + dnums.lhs_spec[2:], + dnums.rhs_spec[2:], + dnums.out_spec[2:], + ) + + valid_position_counts = 1 + # Loop over each spatial dimension and determine how many valid positions + # there are for each dimension. + for d in range(len(input_spatial_dims)): + valid_position_counts *= _get_spatial_valid_position_count_for_one_dim( + window_dim_stride=window_strides[d], + base_dilation=lhs_dilation[d], + window_dilation=rhs_dilation[d], + kernel_limit=rhs.shape[kernel_spatial_dims[d]], + input_limit=lhs.shape[input_spatial_dims[d]], + output_limit=out.shape[out_spatial_dims[d]], + padding=padding[d], + ) + + return valid_position_counts + + +def _calculate_conv_flops( + lhs: roofline.RooflineShape, + rhs: roofline.RooflineShape, + out: roofline.RooflineShape, + window_strides: Sequence[int], + padding: Sequence[tuple[int, int]], + lhs_dilation: Sequence[int], + rhs_dilation: Sequence[int], + dimension_numbers: convolution.ConvGeneralDilatedDimensionNumbers, + batch_group_count: int, +) -> int: + """Calculates roofline unfused flops for Jax's conv_general_dilated primitive. + + See `jax.lax.conv_general_dilated` for details on the arguments. + """ + dnums = convolution.conv_dimension_numbers( + lhs.shape, rhs.shape, dimension_numbers + ) + + spatial_valid_position_counts = _get_spatial_valid_position_count( + dnums, lhs, rhs, out, window_strides, padding, lhs_dilation, rhs_dilation + ) + + batch = lhs.shape[dnums.lhs_spec[0]] + num_output_features = out.shape[dnums.out_spec[1]] + num_input_features = rhs.shape[dnums.rhs_spec[1]] + num_output_batch = batch / batch_group_count + + non_spatial_dims_factor = ( + num_input_features * num_output_features * num_output_batch + ) + + fma_count = non_spatial_dims_factor * spatial_valid_position_counts + flops = fma_count * _FMA_FLOPS_FACTOR + return int(flops) + + @roofline.register_roofline(convolution.conv_general_dilated_p) def _conv_general_dilated_roofline( - ctx: roofline.RooflineRuleContext, - *args, - **kw, + ctx: roofline.RooflineRuleContext, + *args, + window_strides: Sequence[int], + padding: Sequence[tuple[int, int]], + lhs_dilation: Sequence[int], + rhs_dilation: Sequence[int], + dimension_numbers: convolution.ConvGeneralDilatedDimensionNumbers, + batch_group_count: int, + **kw, ) -> roofline.RooflineResult: + """Roofline for Jax's conv_general_dilated primitive. + + See `jax.lax.conv_general_dilated` for details on the arguments. + """ lhs, rhs = (roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in) out = roofline.RooflineShape.from_aval(ctx.avals_out[0]) - # TODO(b/394648206): support computing unfused_flops for conv. + return roofline.RooflineResult( + unfused_flops=_calculate_conv_flops( + lhs, + rhs, + out, + window_strides, + padding, + lhs_dilation, + rhs_dilation, + dimension_numbers, + batch_group_count, + ), unfused_hbm_bytes=( lhs.dtype.itemsize * lhs.size + rhs.dtype.itemsize * rhs.size @@ -257,11 +510,124 @@ def _ring_collective_roofline( ) +def _calculate_gather_flops( + mode: slicing.GatherScatterMode, + indices_size: int, + output_size: int, +) -> int: + """Calculates roofline unfused flops for Jax's gather primitive.""" + + if mode == slicing.GatherScatterMode.FILL_OR_DROP: + # With FILL_OR_DROP, we have 4 steps to check whether to fill (or drop): + # 1. Check if the index is within upper bound. + # 2. Check if the index is within lower bound. + # 3. Call `and` on #1 and #2 to check the index is "in bounds". + # 4. `reduce` the result to a single boolean per window. + # Each of the steps is a single elementwise op on the indices. + index_check_flops = indices_size * 4 + + # Once we know whether to fill or drop (per window), there are 2 steps to + # mask the output: + # 1. Broadcast the per-window boolean to the output shape. + # 2. Choose whether to fill (from `operand`) if in-bounds, or drop if + # out-of-bounds. + # Broadcasting is free, but choosing whether to fill or drop involves an + # elementwise op the size of the output. + output_mask_flops = output_size + return index_check_flops + output_mask_flops + + return 0 + + +@roofline.register_roofline(slicing.gather_p) +def _gather_roofline( + ctx: roofline.RooflineRuleContext, + *args, + mode: slicing.GatherScatterMode, + **kw, +) -> roofline.RooflineResult: + _, indices = (roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in) + out = roofline.RooflineShape.from_aval(ctx.avals_out[0]) + + # Gather doesn't read the whole input buffer, it's equivalent to a copy the + # size of the output shape and a read of the gather indices. + unfused_hbm_bytes = ( + out.dtype.itemsize * out.size * 2 + indices.dtype.itemsize * indices.size + ) + + return roofline.RooflineResult( + unfused_flops=_calculate_gather_flops(mode, indices.size, out.size), + unfused_hbm_bytes=unfused_hbm_bytes, + ) + + +def _scatter_roofline( + ctx: roofline.RooflineRuleContext, + *args, + **kw, +) -> roofline.RooflineResult: + """Roofline for Jax's `scatter*` primitives. + + The `scatter` functionality itself is a simple data read and write, which + contributes 0 flops. + + But, the jaxpr for each `scatter*` function (aside from `jax.lax.scatter`) + contains an `update_jaxpr` that gets applied to the operand & scattered + updates (e.g. `add` for `scatter_add`, or arbitrary unary function for + `scatter_apply`), which *does* contribute flops. This `update_jaxpr` gets + applied to every element of the scattered updates. + + Thus, + flops = [# flops for `update_jaxpr` per element] * [# elements in `updates`]. + + To calculate # flops for `update_jaxpr`, we convert the `update_jaxpr` back to + a callable, and then call `roofline` on that callable. `update_jaxpr` does not + contain any information about input shapes or dtypes; it expects scalars. It + will therefore give us a # flops-per-element result, which we multiply by + the size of the updates to get the total flops. + """ + (_, indices, updates) = ( + roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in + ) + + update_jaxpr = kw.get('update_jaxpr') + + flops = 0 + if update_jaxpr: + update_fn = lambda *inputs: core.eval_jaxpr(update_jaxpr, [], *inputs) + # Create dummy scalar inputs. + dummy_inputs = [ + api.ShapeDtypeStruct((), updates.dtype) for _ in update_jaxpr.invars + ] + # Calculate the flops for the `update_jaxpr` on scalar inputs. + _, roofline_result = roofline.roofline(update_fn)(*dummy_inputs) + # Multiply by the size of the updates to get the total flops. + flops = roofline_result.unfused_flops * updates.size + + return roofline.RooflineResult( + unfused_flops=flops, + # Scatter accesses the equivalent of 3N update shapes (input, output, and + # updates), and the scatter indices. + unfused_hbm_bytes=( + 3 * updates.dtype.itemsize * updates.size + + indices.dtype.itemsize * indices.size + ), + ) + + +roofline.register_roofline(slicing.scatter_add_p)(_scatter_roofline) +roofline.register_roofline(slicing.scatter_max_p)(_scatter_roofline) +roofline.register_roofline(slicing.scatter_min_p)(_scatter_roofline) +roofline.register_roofline(slicing.scatter_mul_p)(_scatter_roofline) +roofline.register_roofline(slicing.scatter_sub_p)(_scatter_roofline) +# Also registers `jax.lax.scatter_apply`, which uses the `scatter_p` primitive. +roofline.register_roofline(slicing.scatter_p)(_scatter_roofline) + def _scalar_collective_roofline( - ctx: roofline.RooflineRuleContext, - *args, - axes: tuple[str, ...], - **kw, + ctx: roofline.RooflineRuleContext, + *args, + axes: tuple[str, ...], + **kw, ) -> roofline.RooflineResult: shapes = [roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in] ctx = replace(ctx, avals_in=[core.ShapedArray((1,), shape.dtype) for shape in shapes]) @@ -272,7 +638,7 @@ def _scalar_collective_roofline( roofline.register_roofline(lax_parallel.pmax_p)(_scalar_collective_roofline) -@roofline.register_roofline(shard_map.psum2_p) +@roofline.register_roofline(lax_parallel.psum_invariant_p) def _psum2_roofline( ctx: roofline.RooflineRuleContext, *args, @@ -405,3 +771,52 @@ def _reduce_sum_p_roofline( # as accumulator.) unfused_hbm_bytes=int(x.dtype.itemsize * (x.size + result_size)), ) + +@roofline.register_roofline(lax.select_n_p) +def _select_n_p_roofline( + ctx: roofline.RooflineRuleContext, + *args, + **kw, +) -> roofline.RooflineResult: + (x, *_) = (roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in) + out = roofline.RooflineShape.from_aval(ctx.avals_out[0]) + + return roofline.RooflineResult( + unfused_flops=out.size, + unfused_hbm_bytes=( + x.dtype.itemsize * x.size + out.dtype.itemsize * out.size + ), + ) + + +@roofline.register_roofline(callback.pure_callback_p) +@roofline.register_roofline(callback.io_callback_p) +def _callback_with_output_roofline( + ctx: roofline.RooflineRuleContext, + *args, + **kw, +) -> roofline.RooflineResult: + avals_in = ctx.avals_in + avals_out = ctx.avals_out + # HBM bytes for transferring inputs to host and results back to device. + hbm_bytes = roofline.RooflineShape.total_bytes( + avals_in + ) + roofline.RooflineShape.total_bytes(avals_out) + # We don't have access to the `callback_func`, so we assume it contributes 0 + # flops. + return roofline.RooflineResult(unfused_hbm_bytes=hbm_bytes) + + +@roofline.register_roofline(debugging.debug_callback_p) +def _debug_callback_roofline( + ctx: roofline.RooflineRuleContext, + *args, + **kw, +) -> roofline.RooflineResult: + avals_in = ctx.avals_in + # `debug_callback` does not return values to the JAX program, so only input + # HBM bytes are considered. + hbm_bytes = roofline.RooflineShape.total_bytes(avals_in) + # We don't have access to the `callback_func`, so we assume it contributes 0 + # flops. + return roofline.RooflineResult(unfused_hbm_bytes=hbm_bytes) diff --git a/jax/experimental/scheduling_groups.py b/jax/experimental/scheduling_groups.py new file mode 100644 index 000000000000..85fef11010eb --- /dev/null +++ b/jax/experimental/scheduling_groups.py @@ -0,0 +1,184 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial + +from jax._src import core +from jax._src import dispatch +from jax._src import linear_util as lu +from jax._src.api_util import debug_info, flatten_fun +from jax._src.util import (safe_map, safe_zip, weakref_lru_cache, unzip2, + split_list) +from jax._src.tree_util import tree_flatten, tree_unflatten +from jax._src.interpreters import ad, mlir, partial_eval as pe, batching +from jax._src.lib.mlir.dialects import func as func_dialect +from jax._src.lib.mlir import ir + +map, unsafe_map = safe_map, map +zip, unsafe_zip = safe_zip, zip + +def scheduling_group(name): + return xla_metadata_call(scheduling_group=name) + +def xla_metadata_call(f=None, **meta): + if f is None: + return lambda g: _xla_metadata_call(g, **meta) + return _xla_metadata_call(f, **meta) + +# TODO(yashkatariya): Figure out a way to reuse code with compute_on2_p, fused_p +def _xla_metadata_call(fun, **meta): + def wrapped(*args, **kwargs): + dbg = debug_info('xla_metadata_call', fun, args, kwargs) + args_flat, in_tree = tree_flatten((args, kwargs)) + f = lu.wrap_init(fun, debug_info=dbg) + f, out_tree = flatten_fun(f, in_tree) + in_avals = tuple(core.shaped_abstractify(x) for x in args_flat) + jaxpr = _trace_to_jaxpr(f, in_avals) + outs_flat = xla_metadata_call_p.bind(*args_flat, jaxpr=jaxpr, **meta) + return tree_unflatten(out_tree(), outs_flat) + return wrapped + +@lu.cache +def _trace_to_jaxpr(flat_fun, in_avals): + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals) + return core.ClosedJaxpr(jaxpr, consts) + +xla_metadata_call_p = core.Primitive('xla_metadata_call') +xla_metadata_call_p.multiple_results = True +dispatch.simple_impl(xla_metadata_call_p) + + +def _xla_metadata_call_abstract_eval(*in_avals, jaxpr, **meta): + return jaxpr.out_avals +xla_metadata_call_p.def_abstract_eval(_xla_metadata_call_abstract_eval) + + +def attr_get(x): + if isinstance(x, str): + return ir.StringAttr.get(x) + else: + raise NotImplementedError(f'mlir attr handler for {type(x)=}') + +def _xla_metadata_call_lowering(ctx, *args, jaxpr, **meta): + const_args_and_avals = core.jaxpr_const_args(jaxpr.jaxpr) + const_args, const_avals = unzip2(const_args_and_avals) + const_arg_values = [ + mlir.ir_constant(c, const_lowering=ctx.const_lowering, aval=aval) + for c, aval in const_args_and_avals] + in_avals = (*const_avals, *ctx.avals_in) + func_op, output_types, effects = mlir.lower_called_computation( + "xla_metadata_call", jaxpr, ctx.module_context, len(const_args), in_avals, + ctx.avals_out, ctx.tokens_in) + + symbol_name = func_op.name.value + flat_output_types = mlir.flatten_ir_types(output_types) + tokens = [ctx.tokens_in.get(eff) for eff in effects] + args = (*ctx.dim_var_values, *tokens, *const_arg_values, *args) + call = func_dialect.CallOp( + flat_output_types, ir.FlatSymbolRefAttr.get(symbol_name), + mlir.flatten_ir_values(args)) + call.operation.attributes['mhlo.frontend_attributes'] = ir.DictAttr.get( + {k: attr_get(v) for k, v in meta.items()}) + + out_nodes = mlir.unflatten_ir_values_like_types(call.results, output_types) + tokens, out_nodes = split_list(out_nodes, [len(effects)]) + tokens_out = ctx.tokens_in.update_tokens(mlir.TokenSet(zip(effects, tokens))) + ctx.set_tokens_out(tokens_out) + return out_nodes +mlir.register_lowering(xla_metadata_call_p, _xla_metadata_call_lowering) + + +def _xla_metadata_call_batcher(axis_data, vals_in, dims_in, *, jaxpr, **meta): + batched_jaxpr, dims_out = batching.batch_jaxpr2(jaxpr, axis_data, dims_in) + outs = xla_metadata_call_p.bind(*vals_in, jaxpr=batched_jaxpr, **meta) + return outs, dims_out +batching.fancy_primitive_batchers[xla_metadata_call_p] = _xla_metadata_call_batcher + + +def _xla_metadata_call_jvp(primals, tangents, *, jaxpr, **meta): + nzs = [not isinstance(t, ad.Zero) for t in tangents] + jaxpr_jvp, out_nzs = ad.jvp_jaxpr(jaxpr, nzs, False) + nz_tangents = [t for t in tangents if not isinstance(t, ad.Zero)] + outs = xla_metadata_call_p.bind(*primals, *nz_tangents, jaxpr=jaxpr_jvp, **meta) + primals_out, nz_tangents_out = outs[:len(out_nzs)], outs[len(out_nzs):] + nz_outs = iter(nz_tangents_out) + tangents_out = [next(nz_outs) if nz else ad.Zero(aval.to_tangent_aval()) + for aval, nz in zip(jaxpr.out_avals, out_nzs)] + assert next(nz_outs, None) is None + return primals_out, tangents_out +ad.primitive_jvps[xla_metadata_call_p] = _xla_metadata_call_jvp + + +def _xla_metadata_call_lin(nzs, *primals, jaxpr, **meta): + jaxpr_jvp, out_nzs = ad.jvp_jaxpr(jaxpr, nzs, False) + lin_outs = [False] * len(out_nzs) + [True] * sum(out_nzs) + jaxpr_lin_, used_inputs = pe.dce_jaxpr(jaxpr_jvp.jaxpr, lin_outs, False) + jaxpr_lin = pe.close_jaxpr(jaxpr_lin_) + primals_out = xla_metadata_call_p.bind(*primals, jaxpr=jaxpr, **meta) + tangent_avals_out = [a.to_tangent_aval() for a in jaxpr.out_avals] + + def xla_metadata_call_lin(primals, *tangents): + nz_tangents = [t for t in tangents if not isinstance(t, ad.Zero)] + inputs = [x for x, u in zip([*primals, *nz_tangents], used_inputs) if u] + nz_outs = xla_metadata_call_p.bind(*inputs, jaxpr=jaxpr_lin, **meta) + nz_outs_ = iter(nz_outs) + outs = [next(nz_outs_) if nz else ad.Zero(a) + for nz, a in zip(out_nzs, tangent_avals_out)] + assert next(nz_outs_, None) is None + return outs + return primals_out, out_nzs, primals, xla_metadata_call_lin +ad.primitive_linearizations[xla_metadata_call_p] = _xla_metadata_call_lin + + +pe.partial_eval_jaxpr_custom_rules[xla_metadata_call_p] = \ + partial(pe.closed_call_partial_eval_custom_rule, 'jaxpr', + lambda _, __, ___, ____, _____, ______, x, y: (x, y)) + +@weakref_lru_cache +def _transpose_jaxpr(jaxpr, in_avals, in_tree): + cell = lambda: None + def transposed(*in_flat): + primals_in, cts_in = tree_unflatten(in_tree, in_flat) + out = ad.backward_pass(jaxpr.jaxpr, False, jaxpr.consts, primals_in, cts_in) + out = [ct if not isinstance(ct, ad.Zero) else None for ct in out] + cts_out, cell.out_tree = tree_flatten(out) # type: ignore + return cts_out + dbg = jaxpr.jaxpr.debug_info.with_unknown_names() + trans_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(transposed, debug_info=dbg), in_avals) + return core.ClosedJaxpr(trans_jaxpr, consts), cell.out_tree # type: ignore + +def _xla_metadata_call_transpose(cts_in, *primals_in, jaxpr, **meta): + in_flat, in_tree = tree_flatten((primals_in, cts_in)) + in_avals = tuple(core.typeof(x) for x in in_flat) + trans_jaxpr, out_tree = _transpose_jaxpr(jaxpr, in_avals, in_tree) + cts_out = xla_metadata_call_p.bind(*in_flat, jaxpr=trans_jaxpr, **meta) + return tree_unflatten(out_tree, cts_out) +ad.primitive_transposes[xla_metadata_call_p] = _xla_metadata_call_transpose + + +def dce_jaxpr_xla_metadata_rule(used_outputs: list[bool], eqn: pe.JaxprEqn + ) -> tuple[list[bool], pe.JaxprEqn | None]: + if not any(used_outputs) and not pe.has_effects(eqn): + return [False] * len(eqn.invars), None + jaxpr_ = eqn.params['jaxpr'] + closed_jaxpr, used_inputs = pe._cached_closed_call_dce( + jaxpr_, tuple(used_outputs)) + new_params = dict(eqn.params, jaxpr=closed_jaxpr) + new_eqn = pe.new_jaxpr_eqn( + [v for v, used in zip(eqn.invars, used_inputs) if used], + [v for v, used in zip(eqn.outvars, used_outputs) if used], + eqn.primitive, new_params, closed_jaxpr.effects, eqn.source_info, eqn.ctx) + return used_inputs, new_eqn +pe.dce_rules[xla_metadata_call_p] = dce_jaxpr_xla_metadata_rule diff --git a/jax/experimental/serialize_executable.py b/jax/experimental/serialize_executable.py index 2d65141a22ea..af5442c7a7b1 100644 --- a/jax/experimental/serialize_executable.py +++ b/jax/experimental/serialize_executable.py @@ -20,6 +20,7 @@ import jax from jax._src.lib import xla_client as xc +from collections.abc import Sequence def serialize(compiled: jax.stages.Compiled): @@ -32,8 +33,12 @@ def serialize(compiled: jax.stages.Compiled): '_unloaded_executable', None) if unloaded_executable is None: raise ValueError("Compilation does not support serialization") + if getattr(unloaded_executable, 'mut', None) and unloaded_executable.mut.in_mut: + raise ValueError("can't serialize with a closed-over mutable array ref") args_info_flat, in_tree = jax.tree_util.tree_flatten(compiled.args_info) - + # TODO(necula): deal with constants in serialized executables + if compiled._params.const_args: + raise NotImplementedError("serialize_executables with const_args") with io.BytesIO() as file: _JaxPjrtPickler(file).dump( (unloaded_executable, args_info_flat, compiled._no_kwargs)) @@ -43,21 +48,34 @@ def serialize(compiled: jax.stages.Compiled): def deserialize_and_load(serialized, in_tree, out_tree, - backend: str | xc.Client | None = None): + backend: str | xc.Client | None = None, + execution_devices: Sequence[xc.Device] | None = None): """Constructs a jax.stages.Compiled from a serialized executable.""" if backend is None or isinstance(backend, str): backend = jax.devices(backend)[0].client + if execution_devices is None: + execution_devices = backend.devices() + else: + device_backend = execution_devices[0].client + if device_backend != backend: + raise ValueError( + 'Execution devices belong to a client other than `backend`. Got ' + f'backend client: {(backend.platform, backend.platform_version)} and ' + 'execution devices client: ' + f'{(device_backend.platform, device_backend.platform_version)}') + (unloaded_executable, args_info_flat, - no_kwargs) = _JaxPjrtUnpickler(io.BytesIO(serialized), backend).load() + no_kwargs) = _JaxPjrtUnpickler( + io.BytesIO(serialized), backend, execution_devices).load() args_info = in_tree.unflatten(args_info_flat) loaded_compiled_obj = unloaded_executable.load() - + # TODO(necula): deal with constants in serialized executables return jax.stages.Compiled( - loaded_compiled_obj, args_info, out_tree, no_kwargs=no_kwargs) + loaded_compiled_obj, [], args_info, out_tree, no_kwargs=no_kwargs) class _JaxPjrtPickler(pickle.Pickler): @@ -77,14 +95,26 @@ def persistent_id(self, obj): class _JaxPjrtUnpickler(pickle.Unpickler): - def __init__(self, file, backend): + def __init__(self, file, backend, execution_devices=None): super().__init__(file) self.backend = backend - self.devices_by_id = {d.id: d for d in backend.devices()} + if execution_devices is None: + execution_devices = backend.devices() + else: + device_backend = execution_devices[0].client + if device_backend != backend: + raise ValueError( + 'Execution devices belong to a client other than `backend`. Got ' + f'backend client: {(backend.platform, backend.platform_version)} ' + 'and execution devices client: ' + f'{(device_backend.platform, device_backend.platform_version)}') + self.devices_by_id = {d.id: d for d in execution_devices} + self.execution_devices = xc.DeviceList(tuple(execution_devices)) def persistent_load(self, pid): if pid[0] == 'exec': - return self.backend.deserialize_executable(pid[1]) + return self.backend.deserialize_executable( + pid[1], executable_devices=self.execution_devices) if pid[0] == 'device': return self.devices_by_id[pid[1]] if pid[0] == 'client': diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 66b70c6c2d34..d549a5c57ccf 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -11,2199 +11,31 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations -from collections.abc import Callable, Hashable, Sequence -import enum -from functools import partial -import inspect -import itertools as it -from math import prod -import operator as op -from typing import Any, TypeVar, Union - -import numpy as np - -import jax -import jax.numpy as jnp -from jax.sharding import NamedSharding, PartitionSpec -from jax._src import ad_checkpoint -from jax._src import ad_util -from jax._src import api_util -from jax._src import callback -from jax._src import config -from jax._src import core -from jax._src import custom_derivatives -from jax._src import debugging -from jax._src import dispatch -from jax._src import dtypes -from jax._src import linear_util as lu -from jax._src import ops -from jax._src import pjit -from jax._src import prng -from jax._src import random -from jax._src import sharding_impls -from jax._src import source_info_util from jax._src import traceback_util -from jax._src import util -from jax._src.core import Tracer -from jax._src.mesh import (AbstractMesh, Mesh, AxisType, use_abstract_mesh, - get_abstract_mesh) -from jax._src.api import _shared_code_pmap, _prepare_pmap -from jax._src.lax import (lax, parallel as lax_parallel, slicing, - windowed_reductions, convolution, fft, linalg, - special, control_flow, ann) -from jax._src import ffi -from jax._src.lib.mlir import ir -from jax._src.lib.mlir.dialects import sdy -from jax._src.util import (HashableFunction, HashablePartial, unzip2, - as_hashable_function, memoize, partition_list, - merge_lists, split_list, subs_list2, foreach) -from jax._src.interpreters import batching -from jax._src.interpreters import mlir -from jax._src.interpreters import partial_eval as pe -from jax._src.interpreters import pxla -from jax._src.interpreters import ad -from jax.tree_util import (tree_map, tree_flatten, tree_unflatten, - tree_structure, tree_leaves, keystr) -from jax._src.tree_util import (broadcast_prefix, prefix_errors, PyTreeDef, - generate_key_paths, KeyPath) -from jax.experimental.multihost_utils import (host_local_array_to_global_array, - global_array_to_host_local_array) - -P = PartitionSpec - -map, unsafe_map = util.safe_map, map -zip, unsafe_zip = util.safe_zip, zip -traceback_util.register_exclusion(__file__) - -# API - -Specs = Any # PyTree[PartitionSpec] -AxisName = Hashable - +from jax._src import shard_map as jshmap @traceback_util.api_boundary -def shard_map(f: Callable, mesh: Mesh | AbstractMesh, in_specs: Specs, - out_specs: Specs, check_rep: bool = True, - auto: frozenset[AxisName] = frozenset()): - """Map a function over shards of data. - - Note: - ``shard_map`` is an experimental API, and still subject to change. For an - introduction to sharded data, refer to :ref:`sharded-computation`. For a more - in-depth look at using ``shard_map``, refer to `SPMD multi-device parallelism with shard_map`_. - - Args: - f: callable to be mapped. Each application of ``f``, or "instance" of ``f``, - takes as input a shard of the mapped-over arguments and produces a shard - of the output. - mesh: a ``jax.sharding.Mesh`` representing the array of devices over which - to shard the data and on which to execute instances of ``f``. The names of - the ``Mesh`` can be used in collective communication operations in ``f``. - This is typically created by a utility function like - :func:`jax.experimental.mesh_utils.create_device_mesh`. - in_specs: a pytree with :class:`~jax.sharding.PartitionSpec` instances as leaves, - with a tree structure that is a tree prefix of the args tuple to be mapped - over. Similar to :class:`~jax.sharding.NamedSharding`, each ``PartitionSpec`` - represents how the corresponding argument (or subtree of arguments) should - be sharded along the named axes of ``mesh``. In each ``PartitionSpec``, - mentioning a ``mesh`` axis name at a position expresses sharding the - corresponding argument array axis along that positional axis; not - mentioning an axis name expresses replication. If an argument, or argument - subtree, has a corresponding spec of None, that argument is not sharded. - out_specs: a pytree with :class:`~jax.sharding.PartitionSpec` instances as leaves, - with a tree structure that is a tree prefix of the output of ``f``. Each - ``PartitionSpec`` represents how the corresponding output shards should be - concatenated. In each ``PartitionSpec``, metioning a ``mesh`` axis name at - a position expresses concatenation of that mesh axis's shards along the - corresponding positional axis. Not mentioning a ``mesh`` axis name - expresses a promise that the output values are equal along that mesh axis, - and that rather than concatenating only a single value should be produced. - check_rep: If True (default) enable additional validity checks and automatic - differentiation optimizations. The validity checks concern whether any mesh - axis names not mentioned in ``out_specs`` are consistent with how the outputs - of ``f`` are replicated. Must be set False if using a Pallas kernel in ``f``. - auto: (experimental) an optional set of axis names from ``mesh`` over which we - do not shard the data or map the function, but rather we allow the - compiler to control sharding. These names cannot be used in ``in_specs``, - ``out_specs``, or in communication collectives in ``f``. - - Returns: - A callable that applies the input function ``f`` across data sharded according to - the ``mesh`` and ``in_specs``. - - Examples: - For examples, refer to :ref:`sharded-computation` or `SPMD multi-device parallelism with shard_map`_. - - .. _SPMD multi-device parallelism with shard_map: https://jax.readthedocs.io/en/latest/notebooks/shard_map.html - """ - return _shard_map(f, mesh, in_specs, out_specs, check_rep, auto) - -def _shard_map(f: Callable, mesh: Mesh | AbstractMesh, in_specs: Specs, - out_specs: Specs | Callable[[], Specs], - check_rep: bool, auto: frozenset[AxisName]): - if not callable(f): - raise TypeError("shard_map requires a callable for its first argument, " - f"but got {f} of type {type(f)}.") - if not isinstance(mesh, (Mesh, AbstractMesh)): - raise TypeError("shard_map requires a `jax.sharding.Mesh` or a " - "`jax.sharding.AbstractMesh` instance for its " - f"second argument, but got {mesh} of type {type(mesh)}.") - if not auto.issubset(mesh.axis_names): - raise ValueError(f"shard_map requires auto={auto} to be a subset of " - f"mesh.axis_names={mesh.axis_names}") - _check_specs(SpecErrorType.input, in_specs, auto) - if not callable(out_specs): - _check_specs(SpecErrorType.out, out_specs, auto) - - @util.wraps(f) - @traceback_util.api_boundary - def wrapped(*args): - fun = lu.wrap_init(f, - debug_info=api_util.debug_info("shard_map", f, args, {})) - args_flat, in_tree = tree_flatten(args) - fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree) - try: in_specs_flat = broadcast_prefix(in_specs, args, - is_leaf=lambda x: x is None) - except ValueError: - e, *_ = prefix_errors(in_specs, args) - raise e('shard_map in_specs') from None - dyn_argnums, in_specs_flat = unzip2((i, s) for i, s in enumerate(in_specs_flat) - if s is not None) - fun, args_flat = api_util.argnums_partial(fun, dyn_argnums, args_flat, False) - _check_specs_vs_args(f, mesh, in_tree, in_specs, dyn_argnums, in_specs_flat, args_flat) - in_names_flat = tuple(map(_canonicalize_spec, in_specs_flat)) - - @memoize - def out_names_thunk(): - if callable(out_specs): - out_specs_ = out_specs() - _check_specs(SpecErrorType.out, out_specs_, auto) - else: - out_specs_ = out_specs - dummy = tree_unflatten(out_tree(), [object()] * out_tree().num_leaves) - try: out_specs_flat = broadcast_prefix(out_specs_, dummy) - except ValueError: - e, *_ = prefix_errors(out_specs_, dummy) - raise e('shard_map out_specs') from None - return tuple(map(_canonicalize_spec, out_specs_flat)) - - if rewrite := check_rep: - fun = _efficient_transpose_rewrite(fun, mesh, in_names_flat, out_names_thunk) - - try: - out_flat = shard_map_p.bind( - fun, *args_flat, mesh=mesh, in_names=in_names_flat, - out_names_thunk=out_names_thunk, check_rep=check_rep, rewrite=rewrite, - auto=auto) - except _SpecError as e: - fails, = e.args - if not callable(out_specs): - msg = _spec_rank_error(SpecErrorType.out, f, out_tree(), out_specs, fails) - if any(fail is not no_fail and not fail.shape for fail in fails): - msg += (" In particular, for rank 0 outputs which are not constant " - "over the mesh, add at least one (singleton) axis to them so " - "that they can be concatenated using out_specs.") - raise ValueError(msg) from None - except _RepError as e: - fails, = e.args - if not callable(out_specs): - msg = _inout_rep_error(f, mesh, out_tree(), out_specs, fails) - raise ValueError(msg) from None - return tree_unflatten(out_tree(), out_flat) - return wrapped - -# Internally use AxisNames = dict[int, tuple[AxisName, ...]], not PartitionSpecs -AxisNames = dict[int, tuple[AxisName, ...]] # TODO(mattjj): make it hashable -def _canonicalize_spec(spec: PartitionSpec) -> AxisNames: - if isinstance(spec, PartitionSpec): - return {i: names if isinstance(names, tuple) else (names,) - for i, names in enumerate(spec) if names is not None} - else: - return spec - -# Error checking and messages - -SpecErrorType = enum.Enum('SpecErrorType', ['input', 'out']) - -def _check_specs(error_type: SpecErrorType, specs: Any, auto) -> None: - if error_type == SpecErrorType.input and specs is None: - raise TypeError( - "shard_map in_specs argument must be a pytree of " - "`jax.sharding.PartitionSpec` instances, but it was None.\n" - "Instead of `in_specs=None`, did you mean `in_specs=P()`, " - "where `P = jax.sharding.PartitionSpec`?") - def check_spec(p): - if not isinstance(p, PartitionSpec): - return False - for names in p: - if not isinstance(names, tuple): - names = (names,) - for name in names: - if name in auto: - return False - return True - if all(check_spec(p) for p in tree_leaves(specs)): return - prefix = 'in' if error_type == SpecErrorType.input else 'out' - msgs = [f" {prefix}_specs{keystr(key)} is {x} of type {type(x).__name__}, " - for key, x in generate_key_paths(specs) if not isinstance(x, P)] - if not msgs: - for key, p in generate_key_paths(specs): - for names in p: - if not isinstance(names, tuple): - names = (names,) - for name in names: - if name in auto: - msgs.append(f" {prefix}_specs{keystr(key)} refers to {repr(name)}") - raise ValueError( - f"shard_map {prefix}_specs argument cannot refer to an axis " - f"marked auto ({auto}), but:\n\n" - + '\n\n'.join(msgs) + '\n\n' - f"Check the {prefix}_specs values passed to shard_map.") - raise TypeError( - f"shard_map {prefix}_specs argument must be a pytree of " - f"`jax.sharding.PartitionSpec` instances, but:\n\n" - + '\n\n'.join(msgs) + '\n\n' - f"Check the {prefix}_specs values passed to shard_map.") - -class NoFail: pass -no_fail = NoFail() - -def _check_specs_vs_args( - f: Callable, mesh: Mesh, in_tree: PyTreeDef, in_specs: Specs, - dyn_argnums: Sequence[int], in_specs_flat: Sequence[P], - xs: Sequence) -> None: - in_avals = map(core.shaped_abstractify, xs) - fail = [a if not len(p) <= a.ndim else no_fail - for p, a in zip(in_specs_flat, in_avals)] - if any(f is not no_fail for f in fail): - fail = _expand_fail(in_tree, dyn_argnums, fail) - msg = _spec_rank_error(SpecErrorType.input, f, in_tree, in_specs, fail) - raise ValueError(msg) - in_names_flat = tuple(map(_canonicalize_spec, in_specs_flat)) - fail = [a if any(a.shape[d] % prod(mesh.shape[n] for n in ns) - for d, ns in names.items()) else no_fail - for a, names in zip(in_avals, in_names_flat)] - if any(f is not no_fail for f in fail): - fail = _expand_fail(in_tree, dyn_argnums, fail) - msg = _spec_divisibility_error(f, mesh, in_tree, in_specs, fail) - raise ValueError(msg) - -def _expand_fail(in_tree: PyTreeDef, dyn_argnums: Sequence[int], - fail: Sequence[core.ShapedArray | NoFail] - ) -> list[core.ShapedArray | NoFail]: - fail_: list[core.ShapedArray | NoFail] = [no_fail] * in_tree.num_leaves - for i, f in zip(dyn_argnums, fail): - fail_[i] = f - return fail_ - -def _spec_rank_error( - error_type: SpecErrorType, f: Callable, tree: PyTreeDef, specs: Specs, - fails: list[core.ShapedArray | NoFail]) -> str: - fun_name = getattr(f, '__name__', str(f)) - if error_type == SpecErrorType.input: - prefix, base = 'in', 'args' - ba = _try_infer_args(f, tree) - else: - prefix, base = 'out', f'{fun_name}(*args)' - msgs = [] - for (spec_key, spec), (fail_key, aval) in _iter_paths(tree, specs, fails): - extra = "" - if error_type == SpecErrorType.input and ba is not None: - arg_key, *_ = fail_key - if arg_key.idx < len(ba.arguments): - param_name = list(ba.arguments.keys())[arg_key.idx] - extra = (f", where {base}{arg_key} is bound to {fun_name}'s " - f"parameter '{param_name}',") - else: - param = list(ba.signature.parameters.values())[-1] - assert param.kind == inspect.Parameter.VAR_POSITIONAL - extra = (f", where {base}{arg_key} is the index " - f"{arg_key.idx - len(ba.signature.parameters) + 1} component " - f"of {fun_name}'s varargs parameter '{param.name}',") - msgs.append( - f"* {prefix}_specs{keystr(spec_key)} is {spec} which has length " - f"{len(spec)}, but " - f"{base}{keystr(fail_key)}{extra} has shape {aval.str_short()}, " - f"which has rank {aval.ndim} (and {aval.ndim} < {len(spec)})") - assert msgs - if len(msgs) == 1: msgs = [msgs[0][2:]] # remove the bullet point - msg = (f"shard_map applied to the function '{fun_name}' was given an " - f"{prefix}_specs entry which is too long to be compatible with the " - f"corresponding {prefix}put value from the function:\n\n" - + '\n\n'.join(msgs) + '\n\n' + - f"Entries in {prefix}_specs must be of length no greater than the " - f"number of axes in the corresponding {prefix}put value.\n\n" - f"Either revise the spec to be shorter, or modify '{fun_name}' so " - f"that its {prefix}puts have sufficient rank.") - if any(not aval.ndim for _, (_, aval) in _iter_paths(tree, specs, fails)): - msg += (f"\n\nFor scalar values (rank 0), consider using an {prefix}_specs " - "entry of `P()`, where `P = jax.sharding.PartitionSpec`.") - return msg - -def _spec_divisibility_error( - f: Callable, mesh: Mesh, tree: PyTreeDef, specs: Specs, - fails: list[core.ShapedArray | NoFail]) -> str: - ba = _try_infer_args(f, tree) - fun_name = getattr(f, '__name__', str(f)) - msgs = [] - for (spec_key, spec), (fail_key, aval) in _iter_paths(tree, specs, fails): - extra = "" - if ba is not None: - arg_key, *_ = fail_key - if arg_key.idx < len(ba.arguments): - param_name = list(ba.arguments.keys())[arg_key.idx] - extra = (f", where args{arg_key} is bound to {fun_name}'s " - f"parameter '{param_name}',") - else: - param = list(ba.signature.parameters.values())[-1] - assert param.kind == inspect.Parameter.VAR_POSITIONAL - extra = (f", where args{arg_key} is the index " - f"{arg_key.idx - len(ba.signature.parameters) + 1} component " - f"of {fun_name}'s varargs parameter '{param.name}',") - names = _canonicalize_spec(spec) - for d, ns in names.items(): - if aval.shape[d] % prod(mesh.shape[n] for n in ns): - axis = f"axes {ns}" if len(ns) > 1 else f"axis '{ns[0]}'" - total = 'total ' if len(ns) > 1 else '' - sz = prod(mesh.shape[n] for n in ns) - msgs.append( - f"* args{keystr(fail_key)} of shape {aval.str_short()}{extra} " - f"corresponds to in_specs{keystr(spec_key)} of value {spec}, " - f"which maps array axis {d} (of size {aval.shape[d]}) to mesh " - f"{axis} (of {total}size {sz}), but {sz} does not evenly divide " - f"{aval.shape[d]}") - assert msgs - if len(msgs) == 1: msgs = [msgs[0][2:]] # remove the bullet point - msg = (f"shard_map applied to the function '{fun_name}' was given argument " - f"arrays with axis sizes that are not evenly divisible by the " - f"corresponding mesh axis sizes:\n\n" - f"The mesh given has shape {tuple(mesh.shape.values())} with " - f"corresponding axis names {mesh.axis_names}.\n\n" - + '\n\n'.join(msgs) + '\n\n' + - f"Array arguments' axis sizes must be evenly divisible by the mesh " - f"axis or axes indicated by the corresponding elements of the " - f"argument's in_specs entry. Consider checking that in_specs are " - f"correct, and if so consider changing the mesh axis sizes or else " - f"padding the input and adapting '{fun_name}' appropriately.") - return msg - -def _inout_rep_error(f: Callable, mesh: Mesh, tree: PyTreeDef, specs: Specs, - fails: list[set | NoFail]) -> str: - fun_name = getattr(f, '__name__', str(f)) - msgs = [] - for (spec_key, spec), (fail_key, rep) in _iter_paths(tree, specs, fails): - dst = _canonicalize_spec(spec) - unmentioned = _unmentioned(mesh, dst) - if len(unmentioned) > 1: - need_rep = ','.join(map(str, unmentioned)) - got_rep = ','.join(map(str, rep)) - diff = ','.join(map(str, [n for n in unmentioned if n not in rep])) - msgs.append( - f"* out_specs{keystr(spec_key)} is {spec} which implies that the " - f"corresponding output value is replicated across mesh axes " - f"{{{need_rep}}}, but could only infer replication over {{{got_rep}}}, " - f"which is missing the required axes {diff}") - else: - need_rep_, = unmentioned - msgs.append( - f"* out_specs{keystr(spec_key)} is {spec} which implies that the " - f"corresponding output value is replicated across mesh axis " - f"'{need_rep_}', but could not infer replication over any axes") - assert msgs - if len(msgs) == 1: msgs = [msgs[0][2:]] # remove the bullet point - msg = (f"shard_map applied to the function '{fun_name}' was given " - f"out_specs which require replication which can't be statically " - f"inferred given the mesh:\n\n" - f"The mesh given has shape {tuple(mesh.shape.values())} with " - f"corresponding axis names {mesh.axis_names}.\n\n" - + '\n\n'.join(msgs) + '\n\n' + - "Check if these output values are meant to be replicated over those " - "mesh axes. If not, consider revising the corresponding out_specs " - "entries. If so, consider disabling the check by passing the " - "check_rep=False argument to shard_map.") - return msg - -def _unmentioned(mesh: Mesh, names: AxisNames) -> list[AxisName]: - name_set = {n for ns in names.values() for n in ns} - return [n for n in mesh.axis_names if n not in name_set] - - -def _try_infer_args(f, tree): - dummy_args = tree_unflatten(tree, [False] * tree.num_leaves) - try: - return inspect.signature(f).bind(*dummy_args) - except (TypeError, ValueError): - return None - -T = TypeVar('T') -def _iter_paths(tree: PyTreeDef, specs: Specs, fails: list[T | NoFail] - ) -> list[tuple[tuple[KeyPath, P], tuple[KeyPath, T]]]: - failures = tree_unflatten(tree, fails) - failures_aug = generate_key_paths(failures) - specs_ = tree_unflatten(tree_structure(specs), generate_key_paths(specs)) - leaf = lambda x: x is None or type(x) is tuple and len(x) == 2 and type(x[1]) is P - specs_aug = broadcast_prefix(specs_, failures, is_leaf=leaf) - return [(s, (fail_key, fail_data)) for s, (fail_key, fail_data) - in zip(specs_aug, failures_aug) - if s is not None and fail_data is not no_fail] - -# Primitive - -JaxType = Any -MaybeTracer = Union[JaxType, Tracer] - -class ShardMapPrimitive(core.Primitive): - multiple_results = True - - def bind(self, *args, **params): - return self._true_bind(*args, **params) - - def bind_with_trace(self, trace, fun_and_args, params): - fun: lu.WrappedFun - fun, *args = fun_and_args - return trace.process_shard_map(shard_map_p, fun, args, **params) - - def get_bind_params(self, params): - new_params = dict(params) - jaxpr: core.Jaxpr = new_params.pop('jaxpr') - subfun = lu.hashable_partial(lu.wrap_init(core.eval_jaxpr, - debug_info=jaxpr.debug_info), - jaxpr, ()) - axes = new_params.pop('out_names') - new_params['out_names_thunk'] = HashableFunction(lambda: axes, closure=axes) - return [subfun], new_params - -shard_map_p = ShardMapPrimitive('shard_map') - -# Staging - -@util.cache(max_size=256, trace_context_in_key=True) -def _as_manual_mesh(mesh, auto: frozenset): - manual_axes = tuple(set(mesh.axis_names) - auto) - cur_mesh = get_abstract_mesh() - if cur_mesh.empty: - cur_mesh = mesh - explicit_axes, auto_axes = set(), set() # type: ignore - for a in auto: - if cur_mesh._name_to_type[a] == AxisType.Auto: - auto_axes.add(a) - else: - assert cur_mesh._name_to_type[a] == AxisType.Explicit - explicit_axes.add(a) - - new_axis_types = [] - for n in mesh.axis_names: - if n in manual_axes: - new_axis_types.append(AxisType.Manual) - elif n in auto_axes: - new_axis_types.append(AxisType.Auto) - else: - assert n in explicit_axes - new_axis_types.append(AxisType.Explicit) - return AbstractMesh(mesh.axis_sizes, mesh.axis_names, - axis_types=tuple(new_axis_types)) - - -def _extend_axis_env(mesh, auto): - return core.extend_axis_env_nd([(k, v) for k, v in mesh.shape.items() - if k not in auto]) - -def _shard_map_staging( - trace: pe.DynamicJaxprTrace, prim: core.Primitive, f: lu.WrappedFun, - in_tracers: Sequence[Any], *, mesh: Mesh, - in_names: tuple[AxisNames, ...], - out_names_thunk: Callable[[], tuple[AxisNames, ...]], - check_rep: bool, - rewrite: bool, - auto: frozenset, - ) -> Sequence[pe.DynamicJaxprTracer]: - in_tracers = map(trace.to_jaxpr_tracer, in_tracers) - in_avals = [t.aval for t in in_tracers] - in_avals_ = map(partial(_shard_aval, mesh, auto), in_names, in_avals) - manual_mesh = _as_manual_mesh(mesh, auto) - with _extend_axis_env(mesh, auto), use_abstract_mesh(manual_mesh): - jaxpr, out_avals_, consts, () = pe.trace_to_jaxpr_dynamic(f, in_avals_) - _check_names(out_names_thunk(), out_avals_) - if check_rep: - in_rep = map(partial(_in_names_to_rep, mesh), in_names) - out_rep = _check_rep(mesh, jaxpr, in_rep) - _check_reps(mesh, out_names_thunk(), out_rep) - out_avals = map(_check_shapedarray, out_avals_) - out_avals = [_check_shapedarray(_unshard_aval(mesh, names, aval)) - for names, aval in zip(out_names_thunk(), out_avals)] - source_info = source_info_util.current() - out_tracers = [pe.DynamicJaxprTracer(trace, a, source_info) for a in out_avals] - invars = map(trace.getvar, in_tracers) - constvars = map(trace.getvar, map(trace.to_jaxpr_tracer, consts)) - outvars = map(trace.makevar, out_tracers) - in_names_staged = ({},) * len(consts) + tuple(in_names) # type: ignore - with _extend_axis_env(mesh, auto), use_abstract_mesh(manual_mesh): - jaxpr = pe.convert_constvars_jaxpr(jaxpr) - params = dict(mesh=mesh, in_names=in_names_staged, - out_names=tuple(out_names_thunk()), jaxpr=jaxpr, - check_rep=check_rep, rewrite=rewrite, auto=auto) - effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names) - eqn = pe.new_jaxpr_eqn([*constvars, *invars], outvars, prim, params, - effs, source_info) - trace.frame.add_eqn(eqn) - return out_tracers -pe.DynamicJaxprTrace.process_shard_map = _shard_map_staging - -# TODO add underscore version, for direct-linearize to consume - -def _check_shapedarray(aval: core.AbstractValue) -> core.ShapedArray: - assert isinstance(aval, core.ShapedArray) - return aval - -def _shard_aval(mesh: Mesh, auto, names: AxisNames, aval: core.AbstractValue - ) -> core.AbstractValue: - if type(aval) in core.shard_aval_handlers: - return core.shard_aval_handlers[type(aval)](mesh, auto, names, aval) - raise NotImplementedError(f"Unsupported aval type: {type(aval)}") - -def _unshard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue - ) -> core.AbstractValue: - if type(aval) in core.unshard_aval_handlers: - return core.unshard_aval_handlers[type(aval)](mesh, names, aval) - else: - raise NotImplementedError(f"Unsupported aval type: {type(aval)}") - -def _shard_shaped_array(mesh: Mesh, auto: frozenset, names: AxisNames, - aval: core.AbstractValue) -> core.AbstractValue: - assert isinstance(aval, core.ShapedArray) - new_shape = tuple(sz // prod(mesh.shape[n] for n in names.get(i, ())) - for i, sz in enumerate(aval.shape)) - manual_mesh = _as_manual_mesh(mesh, auto) - new_sharding = NamedSharding(manual_mesh, aval.sharding.spec) - return aval.update(shape=new_shape, sharding=new_sharding) -core.shard_aval_handlers[core.ShapedArray] = _shard_shaped_array - -def _unshard_shaped_array(mesh: Mesh, names: AxisNames, - aval: core.AbstractValue,) -> core.AbstractValue: - assert isinstance(aval, core.ShapedArray) - new_shape = tuple(sz * prod(mesh.shape[n] for n in names.get(i, ())) - for i, sz in enumerate(aval.shape)) - names_spec = _names_to_pspec(names)._normalized_spec_for_aval(aval.ndim) - if aval.ndim == 0: - out_spec = names_spec - else: - out_spec = [] # type: ignore - for name_s, aval_s in zip(names_spec, aval.sharding.spec): - if name_s and not aval_s: - out_spec.append(name_s) - elif aval_s and not name_s: - out_spec.append(aval_s) - elif not name_s and not aval_s: - out_spec.append(None) - else: - assert name_s and aval_s - name_s = name_s if isinstance(name_s, tuple) else (name_s,) - aval_s = aval_s if isinstance(aval_s, tuple) else (aval_s,) - out_spec.append(name_s + aval_s) - out_spec = PartitionSpec(*out_spec) - new_mesh = (mesh.abstract_mesh if get_abstract_mesh().empty else - get_abstract_mesh()) - new_sharding = NamedSharding(new_mesh, out_spec) - return aval.update(shape=new_shape, sharding=new_sharding) -core.unshard_aval_handlers[core.ShapedArray] = _unshard_shaped_array - -# Type-checking - -RepType = Union[set[AxisName], None] - -def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_names, out_names, - check_rep, rewrite, auto): - # TODO(mattjj,parkers): check auto - for v, x, in_name in zip(jaxpr.invars, in_atoms, in_names): - if not core.typecompat(v.aval, _shard_aval(mesh, auto, in_name, x.aval)): - raise core.JaxprTypeError("shard_map argument avals not compatible with " - "jaxpr binder avals and in_names") - with _extend_axis_env(mesh, auto): - core.check_jaxpr(jaxpr) - if check_rep: - in_rep = map(partial(_in_names_to_rep, mesh), in_names) - out_rep = _check_rep(mesh, jaxpr, in_rep) - for rep, dst in zip(out_rep, out_names): - if not _valid_repeats(mesh, rep, dst): - raise core.JaxprTypeError("shard_map can't prove output is " - "sufficiently replicated") - out_avals_sharded = [x.aval for x in jaxpr.outvars] - out_avals = map(partial(_unshard_aval, mesh), out_names, out_avals_sharded) - effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names) - return out_avals, effs -core.custom_typechecks[shard_map_p] = _shard_map_typecheck - -def _in_names_to_rep(mesh: Mesh, names: AxisNames) -> set[AxisName]: - return set(mesh.axis_names) - {n for ns in names.values() for n in ns} - -def _check_rep(mesh: Mesh, jaxpr: core.Jaxpr, in_rep: Sequence[RepType] - ) -> Sequence[RepType]: - env: dict[core.Var, RepType] = {} - - def read(x: core.Atom) -> RepType: - return env[x] if type(x) is core.Var else None - - def write(v: core.Var, val: RepType) -> None: - env[v] = val - - foreach(write, jaxpr.constvars, [set(mesh.axis_names)] * len(jaxpr.constvars)) - foreach(write, jaxpr.invars, in_rep) - last_used = core.last_used(jaxpr) - for e in jaxpr.eqns: - rule = _check_rules.get(e.primitive, partial(_rule_missing, e.primitive)) - out_rep = rule(mesh, *map(read, e.invars), **e.params) - if e.primitive.multiple_results: - out_rep = [out_rep] * len(e.outvars) if type(out_rep) is set else out_rep - foreach(write, e.outvars, out_rep) - else: - write(e.outvars[0], out_rep) - core.clean_up_dead_vars(e, env, last_used) - return map(read, jaxpr.outvars) - -def _valid_repeats(mesh: Mesh, rep: RepType, dst: AxisNames) -> bool: - return rep is None or set(_unmentioned(mesh, dst)).issubset(rep) - -def _rule_missing(prim: core.Primitive, *_, **__): - raise NotImplementedError( - f"No replication rule for {prim}. As a workaround, pass the " - "`check_rep=False` argument to `shard_map`. To get this fixed, open an " - "issue at https://github.com/jax-ml/jax/issues") - -# Lowering - - -def _shardy_shard_map_sharding( - ctx: mlir.LoweringRuleContext, mesh, auto, names, aval_in -) -> sharding_impls.SdyArraySharding: - axes = {name: i for i, ns in names.items() for name in ns} - ns = _make_scoped_manual_sharding(ctx, mesh, axes) - if dtypes.issubdtype(aval_in.dtype, dtypes.extended): - ns = sharding_impls.physical_sharding(aval_in, ns) - aval_in = core.physical_aval(aval_in) - sdy_sharding = ns._to_sdy_sharding(aval_in.ndim) - if auto: - for dim_sharding in sdy_sharding.dimension_shardings: - # Only allow dimensions which have no sharding to be auto-sharded. - if not dim_sharding.axes: - dim_sharding.is_closed = False - return sdy_sharding - - -def _shard_map_lowering_shardy( - ctx, in_nodes, jaxpr, mesh, in_names, out_names, auto): - in_avals_ = [v.aval for v in jaxpr.invars] - if isinstance(ctx.module_context.axis_context, sharding_impls.SPMDAxisContext): - # Nested `ManualComputationOp`s cannot refer to axes that are already - # manual. So figure out what axes are free thus far. - free_axes = frozenset(mesh.axis_names) - ctx.module_context.axis_context.manual_axes - shardy_manual_axes = free_axes - auto - else: - shardy_manual_axes = frozenset(mesh.axis_names) - auto - new_axis_context = sharding_impls.SPMDAxisContext( - mesh, frozenset(mesh.axis_names) - auto) - sub_ctx = ctx.module_context.replace(axis_context=new_axis_context) - - # The order of manual axes should match the order of mesh.axis_names to avoid - # non-determinism issues. - manual_axes = [a for a in mesh.axis_names - if a in shardy_manual_axes] - if np.prod([mesh.shape[a] for a in manual_axes]) == 1: - # No need for a `ManualComputationOp` if all manual axes are size 1. - with _extend_axis_env(mesh, auto): - out_nodes, _ = mlir.jaxpr_subcomp( - sub_ctx, jaxpr, ctx.name_stack, mlir.TokenSet(), (), *in_nodes, - dim_var_values=ctx.dim_var_values) - return out_nodes - - in_shardings = sharding_impls.SdyArrayShardingList(map( - partial(_shardy_shard_map_sharding, ctx, mesh, auto), - in_names, ctx.avals_in)).build() - out_shardings = sharding_impls.SdyArrayShardingList(map( - partial(_shardy_shard_map_sharding, ctx, mesh, auto), - out_names, ctx.avals_out)).build() - output_types = map(mlir.aval_to_ir_type, ctx.avals_out) - manual_computation_op = sdy.ManualComputationOp( - output_types, in_nodes, in_shardings, out_shardings, - sdy.ManualAxesAttr.get( - ir.ArrayAttr.get([ir.StringAttr.get(i) for i in manual_axes]))) - block = ir.Block.create_at_start( - manual_computation_op.body, map(mlir.aval_to_ir_type, in_avals_)) - with ir.InsertionPoint(block), _extend_axis_env(mesh, auto): - out_nodes_, _ = mlir.jaxpr_subcomp( - sub_ctx, jaxpr, ctx.name_stack, mlir.TokenSet(), (), *block.arguments, - dim_var_values=ctx.dim_var_values) - sdy.ReturnOp([ir.Value(x) for x in out_nodes_]) - - return manual_computation_op.results - - -def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names, - check_rep, rewrite, auto): - del check_rep, rewrite - - if config.use_shardy_partitioner.value: - return _shard_map_lowering_shardy( - ctx, in_nodes, jaxpr, mesh, in_names, out_names, auto) - - in_avals_ = [v.aval for v in jaxpr.invars] - out_avals_ = [x.aval for x in jaxpr.outvars] - in_nodes_ = map(partial(_xla_shard, ctx, mesh, auto), in_names, ctx.avals_in, - in_avals_, in_nodes) - manual_axes = frozenset(mesh.axis_names) - auto - new_axis_context = sharding_impls.SPMDAxisContext(mesh, manual_axes) - sub_ctx = ctx.module_context.replace(axis_context=new_axis_context) - with _extend_axis_env(mesh, auto): - out_nodes_, tokens_out = mlir.call_lowering( - "shmap_body", ctx.name_stack, jaxpr, None, sub_ctx, in_avals_, - out_avals_, ctx.tokens_in, *in_nodes_, dim_var_values=ctx.dim_var_values, - arg_names=map(_pspec_mhlo_attrs, in_names, in_avals_), - result_names=map(_pspec_mhlo_attrs, out_names, out_avals_)) - ctx.set_tokens_out(tokens_out) - return map(partial(_xla_unshard, ctx, mesh, auto), out_names, out_avals_, - ctx.avals_out, out_nodes_) -mlir.register_lowering(shard_map_p, _shard_map_lowering) - -def _make_scoped_manual_sharding(ctx, mesh, axes): - axis_ctx = ctx.module_context.axis_context - if isinstance(axis_ctx, sharding_impls.SPMDAxisContext): - manual_axes = axis_ctx.manual_axes - else: - manual_axes = frozenset({}) - return NamedSharding( - mesh, sharding_impls.array_mapping_to_axis_resources(axes), # pytype: disable=wrong-arg-types - _manual_axes=manual_axes) - -def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, auto, names, - aval_in, aval_out, x): - if prod([size for n, size in mesh.shape.items() if n not in auto]) == 1: - return x - axes = {name: i for i, ns in names.items() for name in ns} - ns = _make_scoped_manual_sharding(ctx, mesh, axes) - if dtypes.issubdtype(aval_in.dtype, dtypes.extended): - ns = sharding_impls.physical_sharding(aval_in, ns) - aval_in = core.physical_aval(aval_in) - shard_proto = ns._to_xla_hlo_sharding(aval_in.ndim).to_proto() - unspecified = set(range(aval_in.ndim)) if auto else set() - sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, shard_proto, - unspecified_dims=unspecified) - manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh) - return mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, manual_proto, unspecified) - -def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, auto, names, - aval_in, aval_out, x): - if prod([size for n, size in mesh.shape.items() if n not in auto]) == 1: - return x - axes = {name: i for i, ns in names.items() for name in ns} - ns = _make_scoped_manual_sharding(ctx, mesh, axes) - if dtypes.issubdtype(aval_out.dtype, dtypes.extended): - ns = sharding_impls.physical_sharding(aval_out, ns) - aval_out = core.physical_aval(aval_out) - unspecified = set(range(aval_out.ndim)) if auto else set() - if dtypes.issubdtype(aval_in.dtype, dtypes.extended): - aval_in = core.physical_aval(aval_in) - manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh) - sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, manual_proto, unspecified_dims=unspecified) - shard_proto = ns._to_xla_hlo_sharding(aval_out.ndim).to_proto() - return mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, shard_proto, - unspecified) - -def _pspec_mhlo_attrs(names: AxisNames, aval: core.AbstractValue) -> str: - if isinstance(aval, core.ShapedArray): - return str(map(names.get, range(aval.ndim))) - return '' - -# Eager evaluation - -def get_mesh_from_args(args_flat, mesh): - for a in args_flat: - if hasattr(a, 'sharding') and isinstance(a.sharding, NamedSharding): - if a.sharding.mesh.shape_tuple != mesh.shape_tuple: - aval = core.shaped_abstractify(a) - raise ValueError( - f"Mesh shape of the input {a.sharding.mesh.shape_tuple} does not" - " match the mesh shape passed to shard_map " - f" {mesh.shape_tuple} for shape {aval.str_short()}") - mesh = a.sharding.mesh - if isinstance(mesh, AbstractMesh): - raise ValueError( - "Please pass `jax.Array`s with a `NamedSharding` as input to" - " `shard_map` when passing `AbstractMesh` to the mesh argument.") - assert isinstance(mesh, Mesh) - return mesh - -def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk, - check_rep, rewrite, auto): - if auto: raise NotImplementedError - del prim - if isinstance(mesh, AbstractMesh): - mesh = get_mesh_from_args(args, mesh) - args = map(partial(_unmatch_spec, mesh, context_mesh=get_abstract_mesh()), - in_names, args) - in_rep = map(partial(_in_names_to_rep, mesh), in_names) - outs, out_rep = _run_shmap(fun, mesh, auto, args, in_rep, check_rep, - get_abstract_mesh()) - out_avals = [core.mapped_aval(x.shape[0], 0, core.get_aval(x)) for x in outs] - _check_names(out_names_thunk(), out_avals) # pytype: disable=wrong-arg-types - if check_rep: - _check_reps(mesh, out_names_thunk(), out_rep) - pspecs = map(_names_to_pspec, out_names_thunk()) - return map(partial(_match_spec, mesh, check_rep), pspecs, outs) -core.EvalTrace.process_shard_map = _shard_map_impl - -def _run_shmap(f, mesh, auto, args, reps, check_rep, context_mesh): - trace = ShardMapTrace(mesh, auto, check_rep, context_mesh) - in_tracers = map(partial(ShardMapTracer, trace), reps, args) - manual_mesh = _as_manual_mesh(mesh, auto) - with (core.set_current_trace(trace), _extend_axis_env(mesh, auto), - use_abstract_mesh(manual_mesh)): - ans = f.call_wrapped(*in_tracers) - outs, out_rep = unzip2(map(trace.to_val_rep_pair, ans)) - return outs, out_rep - -def _names_to_pspec(names: AxisNames) -> PartitionSpec: - ndmin = max(names) + 1 if names else 0 - unpack = lambda t: t[0] if t is not None and len(t) == 1 else t - return PartitionSpec(*(unpack(names.get(i)) for i in range(ndmin))) - -def _unmatch_spec(mesh: Mesh, src: AxisNames, x: JaxType, context_mesh) -> JaxType: - with (core.eval_context(), jax.disable_jit(False), - use_abstract_mesh(context_mesh)): - return jax.jit(HashablePartial(_unmatch, mesh, tuple(src.items())))(x) - -def _unmatch(mesh, src_tup, x): - src = _names_to_pspec(dict(src_tup)) - dst = P(mesh.axis_names) - return shard_map(_add_singleton, mesh, (src,), dst, check_rep=False)(x) - -def _check_names(names: Sequence[AxisNames], avals: Sequence[core.ShapedArray] - ) -> None: - fail = [a if n and not max(n) < a.ndim else no_fail - for n, a in zip(names, avals)] - if any(f is not no_fail for f in fail): raise _SpecError(fail) -class _SpecError(Exception): pass - -def _check_reps(mesh, names, reps): - fail = [r if not _valid_repeats(mesh, r, n) else no_fail - for n, r in zip(names, reps)] - if any(f is not no_fail for f in fail): raise _RepError(fail) -class _RepError(Exception): pass - -def _check_reps2(mesh, reps_dest, reps): - fail = [src if not dst.issubset(src) else no_fail - for dst, src in zip(reps_dest, reps)] - if any(f is not no_fail for f in fail): raise _RepError(fail) - -def _match_spec(mesh: Mesh, check_rep: bool, - pspec: PartitionSpec, x: JaxType) -> JaxType: - fn = HashablePartial(_match, mesh, check_rep, pspec) - with core.eval_context(), jax.disable_jit(False): - return jax.jit(fn, out_shardings=NamedSharding(mesh, pspec))(x) - -def _match(mesh, check_rep, pspec, x): - src = P(mesh.axis_names) - return shard_map(_rem_singleton, mesh, (src,), pspec, check_rep=False)(x) - -def _rem_singleton(x): return jnp.squeeze(x, axis=0) -def _add_singleton(x): return jnp.expand_dims(x, axis=0) - -def _maybe_check_special(outs): - if not config.debug_nans.value and not config.debug_infs.value: return - bufs = [s.data for leaf in tree_leaves(outs) - for s in getattr(leaf, 'addressable_shards', [])] - try: - dispatch.check_special('shard_map', bufs) - except dispatch.InternalFloatingPointError as e: - raise FloatingPointError(f'Invalid value ({e.ty}) encountered in sharded computation.') from None - -class ShardMapTrace(core.Trace): - __slots__ = ("mesh", "auto", "check", "context_mesh") - - mesh: Mesh - auto: frozenset[AxisName] - check: bool - context_mesh: AbstractMesh - - def __init__(self, mesh, auto, check, context_mesh): - super().__init__() - self.mesh = mesh - self.auto = auto - self.check = check - self.context_mesh = context_mesh - - def to_val_rep_pair(self, val): - if isinstance(val, ShardMapTracer): - return val.val, val.rep - elif isinstance(val, Tracer): - raise Exception(f"Shouldn't have any non-shard_map tracers: {val}") - else: - val_ = _unmatch_spec(self.mesh, {}, val, self.context_mesh) - return val_, None - - def process_primitive(self, prim, tracers, params): - in_vals, in_rep = unzip2(map(self.to_val_rep_pair, tracers)) - eager_rule = eager_rules.get(prim) - if eager_rule: - out_vals = eager_rule(self.mesh, *in_vals, **params) - else: - f = HashablePartial(_prim_applier, prim, tuple(params.items()), self.mesh) - with (core.eval_context(), jax.disable_jit(False), jax.debug_nans(False), - jax.debug_infs(False), use_abstract_mesh(self.context_mesh)): - out_vals = jax.jit(f)(*in_vals) - _maybe_check_special(out_vals) - rep_rule = _check_rules.get(prim, partial(_rule_missing, prim)) - out_rep = rep_rule(self.mesh, *in_rep, **params) if self.check else set() - if prim.multiple_results: - out_rep = [out_rep] * len(out_vals) if type(out_rep) is set else out_rep - return map(partial(ShardMapTracer, self), out_rep, out_vals) - return ShardMapTracer(self, out_rep, out_vals) - - def process_call(self, call_primitive, fun, tracers, params): - raise NotImplementedError( - f"Eager evaluation of `{call_primitive}` inside a `shard_map` isn't " - "yet supported. Put a `jax.jit` around the `shard_map`-decorated " - "function, and open a feature request at " - "https://github.com/jax-ml/jax/issues !") - - def process_map(self, map_primitive, fun, tracers, params): - raise NotImplementedError( - "Eager evaluation of `pmap` inside a `shard_map` isn't yet supported." - "Put a `jax.jit` around the `shard_map`-decorated function, and open " - "a feature request at https://github.com/jax-ml/jax/issues !") - - def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): - # Since ShardMapTrace is only used as a base main, we can drop the jvp. - if symbolic_zeros: - msg = ("custom_jvp symbolic_zeros support with shard_map is not " - "implemented; please open an issue at " - "https://github.com/jax-ml/jax/issues") - raise NotImplementedError(msg) - del prim, jvp, symbolic_zeros - in_vals, in_rep = unzip2(map(self.to_val_rep_pair, tracers)) - out_vals, out_rep = _run_shmap(fun, self.mesh, self.auto, in_vals, in_rep, self.check, - self.context_mesh) - return map(partial(ShardMapTracer, self), out_rep, out_vals) - - def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, - symbolic_zeros): - if symbolic_zeros: - msg = ("custom_vjp symbolic_zeros support with shard_map is not " - "implemented; please open an issue at " - "https://github.com/jax-ml/jax/issues") - raise NotImplementedError(msg) - del prim, fwd, bwd, out_trees, symbolic_zeros - in_vals, in_rep = unzip2(map(self.to_val_rep_pair, tracers)) - out_vals, out_rep = _run_shmap(fun, self.mesh, self.auto, in_vals, in_rep, self.check, - self.context_mesh) - return map(partial(ShardMapTracer, self), out_rep, out_vals) - - -class ShardMapTracer(core.Tracer): - rep: RepType - val: JaxType - - def __init__(self, trace, rep, val): - self._trace = trace - self.rep = rep - self.val = val - - @property - def aval(self): - aval = core.get_aval(self.val) - out = core.mapped_aval(self._trace.mesh.size, 0, aval) - new_sharding = NamedSharding( - _as_manual_mesh(self._trace.mesh, self._trace.auto), - out.sharding.spec) # pytype: disable=attribute-error - return out.update(sharding=new_sharding) - - def to_concrete_value(self): - if self.rep == set(self._trace.mesh.axis_names): - with core.eval_context(), use_abstract_mesh(self._trace.context_mesh): - return core.to_concrete_value(self.val[0]) - else: - return None - - def __str__(self) -> str: - with core.eval_context(), use_abstract_mesh(self._trace.context_mesh): - blocks = list(self.val) - mesh = self._trace.mesh - axis_names = f"({', '.join(map(str, mesh.axis_names))},)" - return '\n'.join( - f"On {device} at mesh coordinates {axis_names} = {idx}:\n{block}\n" - for (idx, device), block in zip(np.ndenumerate(mesh.devices), blocks)) - __repr__ = __str__ # for debuggers, like `p x` - -def _prim_applier(prim, params_tup, mesh, *args): - def apply(*args): - outs = prim.bind(*map(_rem_singleton, args), **dict(params_tup)) - return tree_map(_add_singleton, outs) - spec = P(mesh.axis_names) - return shard_map(apply, mesh, spec, spec, False)(*args) - -eager_rules: dict[core.Primitive, Callable] = {} - -# TODO(mattjj): working around an apparent XLA or PjRt bug, remove eventually -def _debug_callback_eager_rule(mesh, *args, callback: Callable[..., Any], - effect: debugging.DebugEffect): - del effect - with core.eval_context(): - all_blocks = zip(*map(list, args)) - for (idx, device), blocks in zip(np.ndenumerate(mesh.devices), all_blocks): - callback(*blocks) - return [] -eager_rules[debugging.debug_callback_p] = _debug_callback_eager_rule - -def _device_put_eager_rule(mesh, *xs, srcs, devices, copy_semantics): - del mesh, srcs, copy_semantics - for device in devices: - if device is not None: - raise ValueError("device_put with explicit device not allowed within " - f"shard_map-decorated functions, but got device {device}") - return xs -eager_rules[dispatch.device_put_p] = _device_put_eager_rule - -# New primitives for efficient transposition - -# psum2_p is like psum_p except has a different transpose, so mostly copied: -psum2_p = core.Primitive('psum2') -psum2_p.multiple_results = True -psum2_p.def_impl(lax_parallel.psum_p.impl) -psum2_p.def_effectful_abstract_eval(lax_parallel.psum_p.abstract_eval) -mlir.register_lowering(psum2_p, mlir._lowerings[lax_parallel.psum_p]) -batching.fancy_primitive_batchers[psum2_p] = \ - partial(lax_parallel._batched_reduction_collective, psum2_p, - lambda v, axis_size: axis_size * v) -batching.skippable_batchers[psum2_p] = partial(lax_parallel._names_in_param, 'axes') - -def _psum2_transpose_rule(cts, *args, axes, axis_index_groups): - del args - return pbroadcast_p.bind(*cts, axes=axes, axis_index_groups=axis_index_groups) -ad.deflinear2(psum2_p, _psum2_transpose_rule) - -# pbroadcast_p is exactly the transpose of psum2_p -def pbroadcast(x, axis_name): - axes = (axis_name,) if not isinstance(axis_name, tuple) else axis_name - if not axis_name: return x - xs, treedef = tree_flatten(x) - ys = pbroadcast_p.bind(*xs, axes=axes, axis_index_groups=None) - return tree_unflatten(treedef, ys) -pbroadcast_p = core.Primitive('pbroadcast') -pbroadcast_p.multiple_results = True -pbroadcast_p.def_impl(lambda *args, axes, axis_index_groups: args) -pbroadcast_p.def_abstract_eval(lambda *args, axes, axis_index_groups: args) -mlir.register_lowering(pbroadcast_p, lambda ctx, *x, axes, axis_index_groups: x) -def _pbroadcast_batcher(vals_in, dims_in, *, axes, axis_index_groups): - if any(type(axis) is int for axis in axes): raise NotImplementedError - vals_out = pbroadcast_p.bind(*vals_in, axes=axes, - axis_index_groups=axis_index_groups) - return vals_out, dims_in -batching.primitive_batchers[pbroadcast_p] = _pbroadcast_batcher -ad.deflinear2(pbroadcast_p, - lambda cts, *_, axes, axis_index_groups: - psum2_p.bind(*cts, axes=axes, axis_index_groups=axis_index_groups)) - -# Rewrite rules and static replication checking for efficient transposition - -_rewrite_rules: dict[core.Primitive, Callable] = {} -register_rewrite = lambda prim: lambda r: _rewrite_rules.setdefault(prim, r) -register_standard_rewrite = lambda prim: \ - _rewrite_rules.setdefault(prim, partial(_standard_rewrite_rule, prim)) -register_norewrite = lambda p: \ - _rewrite_rules.setdefault(p, partial(_no_rewrite, p, _check_rules[p])) - -_check_rules: dict[core.Primitive, Callable] = {} -register_check = lambda prim: lambda rule: _check_rules.setdefault(prim, rule) -register_standard_check = \ - lambda prim: _check_rules.setdefault(prim, partial(_standard_check, prim)) - -def _no_rewrite(prim, rule, mesh, in_rep, *args, **params): - out_vals = prim.bind(*args,**params) - out_rep = rule(mesh, *in_rep, **params) - if prim.multiple_results: - out_rep_ = out_rep if type(out_rep) is list else [out_rep] * len(out_vals) - else: - out_vals, out_rep_ = [out_vals], [out_rep] - return out_vals, out_rep_ - -def _standard_rewrite_rule(prim, mesh, in_rep, *args, **params): - # The standard rewrite inserts pbroadcasts but doesn't change the primitive. - out_rep_ = set.intersection(*in_rep) if in_rep else set(mesh.axis_names) - args_ = [pbroadcast(x, tuple(n for n in src if n not in out_rep_)) - if src - out_rep_ else x for x, src in zip(args, in_rep)] - out_vals_ = prim.bind(*args_, **params) - out_rep = [out_rep_] * len(out_vals_) if prim.multiple_results else [out_rep_] - out_vals = [out_vals_] if not prim.multiple_results else out_vals_ - return out_vals, out_rep - -def _standard_check(prim, mesh, *in_rep, **__): - # The standard check require args' and outputs' replications to be the same, - # except for Nones which correspond to constants. - in_rep_ = [r for r in in_rep if r is not None] - if in_rep_ and not in_rep_[:-1] == in_rep_[1:]: - raise Exception(f"Primitive {prim} requires argument replication types " - f"to match, but got {in_rep}. Please open an issue at " - "https://github.com/jax-ml/jax/issues and as a temporary " - "workaround pass the check_rep=False argument to shard_map") - return in_rep_[0] if in_rep_ else None - -def register_standard_collective(prim): - register_check(prim)(partial(_standard_collective_check, prim)) - register_rewrite(prim)(partial(_standard_collective_rewrite, prim)) - -def register_reduction_collective(prim): - register_check(prim)(partial(_reduction_collective_check, prim)) - register_rewrite(prim)(partial(_reduction_collective_rewrite, prim)) - -def _standard_collective_check(prim, mesh, x_rep, *, axis_name, **params): - # The standard collective check is varying -> varying over axis_name. - del mesh, params - if x_rep is None or axis_name in x_rep: - raise Exception(f"Collective {prim} must be applied to a device-varying " - f"replication type, but got {x_rep} for collective acting " - f"over axis name {axis_name}. Please open an issue at " - "https://github.com/jax-ml/jax/issues and as a temporary " - "workaround pass the check_rep=False argument to shard_map") - return x_rep - -def _standard_collective_rewrite(prim, mesh, in_rep, x, axis_name, **params): - # The standard collective rewrite may insert a pbroadcast on the input. - axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name - x_rep, = in_rep - axis_name_set = set(axis_name) - if pbroadcast_axis_name := axis_name_set & x_rep: - x = pbroadcast(x, tuple(pbroadcast_axis_name)) - out_val = prim.bind(x, axis_name=axis_name, **params) - return [out_val], [x_rep - axis_name_set] - -def _reduction_collective_check(prim, mesh, x_rep, *, axes, **params): - # The reduction collective check is varying -> replicated over axes. - del mesh, params - axes = (axes,) if not isinstance(axes, tuple) else axes - if x_rep is None or any(a in x_rep for a in axes): - raise Exception(f"Collective {prim} must be applied to a device-varying " - f"replication type, but got {x_rep} for collective acting " - f"over axis name {axes}. Please open an issue at " - "https://github.com/jax-ml/jax/issues and as a temporary " - "workaround pass the check_rep=False argument to shard_map") - return x_rep | set(axes) - -def _reduction_collective_rewrite(prim, mesh, in_rep, x, axes, **params): - # The standard collective rewrite may insert a pbroadcast on the input. - axes = (axes,) if not isinstance(axes, tuple) else axes - x_rep, = in_rep - axes_set = set(axes) - if pbroadcast_axes := axes_set & x_rep: - x = pbroadcast(x, tuple(pbroadcast_axes)) - out_val, = prim.bind(x, axes=axes, **params) - return [out_val], [x_rep | axes_set] - - -for o in it.chain(lax.__dict__.values(), slicing.__dict__.values(), - windowed_reductions.__dict__.values(), - special.__dict__.values(), convolution.__dict__.values(), - fft.__dict__.values(), linalg.__dict__.values(), - ops.__dict__.values(), ad_util.__dict__.values(), - prng.__dict__.values(), ann.__dict__.values(), - random.__dict__.values()): - if isinstance(o, core.Primitive): - register_standard_check(o) - register_standard_rewrite(o) - -for p in [control_flow.loops.cumsum_p, control_flow.loops.cumlogsumexp_p, - control_flow.loops.cumprod_p, control_flow.loops.cummax_p, - control_flow.loops.cummin_p, pjit.sharding_constraint_p, - pjit.mesh_cast_p]: - register_standard_check(p) - register_standard_rewrite(p) - - -@register_check(lax_parallel.psum_p) -def _psum_check(_, *in_rep, axes, axis_index_groups): - assert False # should be rewritten away - -@register_rewrite(lax_parallel.psum_p) -def _psum_rewrite(mesh, in_rep, *args, axes, axis_index_groups): - # Replace the psum with psum2, insert pbroadcasts on input, replicated output. - if axis_index_groups is not None: raise NotImplementedError - axes = (axes,) if not isinstance(axes, tuple) else axes - axes_ = set(axes) - out_rep = [r | axes_ for r in in_rep] # TODO determinism (and elsewhere) - args_ = [pbroadcast(x, tuple(n for n in mesh.axis_names if n in axes_ & src)) - for x, src in zip(args, in_rep)] - out_val = psum2_p.bind(*args_, axes=axes, axis_index_groups=axis_index_groups) - return out_val, out_rep - - -@register_check(psum2_p) -def _psum2_check(mesh, *in_rep, axes, axis_index_groups): - assert type(axes) is tuple - if any(set(axes) & r for r in in_rep if r is not None): - raise Exception("Collective psum must be applied to a device-varying " - f"replication type, but got {in_rep} for collective acting " - f"over axis name {axes}. Please open an issue at " - "https://github.com/jax-ml/jax/issues, and as a temporary " - "workaround pass the check_rep=False argument to shard_map") - in_rep = tuple(set(mesh.axis_names) if r is None else r for r in in_rep) - return [r | set(axes) for r in in_rep] -register_norewrite(psum2_p) - - -@register_check(pbroadcast_p) -def _pbroadcast_check(mesh, *in_rep, axes, axis_index_groups): - assert type(axes) is tuple - if not all(r is None or set(axes) & r for r in in_rep): - raise Exception("Collective pbroadcast must be applied to a " - "non-device-varying " - f"replication type, but got {in_rep} for collective acting " - f"over axis name {axes}. Please open an issue at " - "https://github.com/jax-ml/jax/issues, and as a temporary " - "workaround pass the check_rep=False argument to shard_map") - in_rep = tuple(set(mesh.axis_names) if r is None else r for r in in_rep) - return [r - set(axes) for r in in_rep] -register_norewrite(pbroadcast_p) - - -register_standard_collective(lax_parallel.all_gather_p) -register_standard_collective(lax_parallel.all_to_all_p) -register_standard_collective(lax_parallel.ppermute_p) -register_standard_collective(lax_parallel.reduce_scatter_p) -register_reduction_collective(lax_parallel.pmin_p) -register_reduction_collective(lax_parallel.pmax_p) - - -@register_check(lax_parallel.axis_index_p) -def _axis_index_check(mesh, *, axis_name): - axis_name = (axis_name,) if not type(axis_name) is tuple else axis_name - return set(mesh.shape) - set(axis_name) -register_norewrite(lax_parallel.axis_index_p) - - -@register_rewrite(pjit.pjit_p) -def _pjit_rewrite(mesh, in_rep, *args, jaxpr, **kwargs): - jaxpr_, out_rep = _replication_rewrite_nomatch(mesh, jaxpr, in_rep) - out_vals = pjit.pjit_p.bind(*args, jaxpr=jaxpr_, **kwargs) - return out_vals, out_rep - -@register_check(pjit.pjit_p) -def _pjit_check(mesh, *in_rep, jaxpr, **kwargs): - return _check_rep(mesh, jaxpr.jaxpr, in_rep) - - -@register_rewrite(ad_checkpoint.remat_p) -def _remat_rewrite(mesh, in_rep, *args, jaxpr, **kwargs): - jaxpr_ = pe.close_jaxpr(jaxpr) - jaxpr_, out_rep = _replication_rewrite_nomatch(mesh, jaxpr_, in_rep) - jaxpr, () = jaxpr_.jaxpr, jaxpr_.consts - out_vals = ad_checkpoint.remat_p.bind(*args, jaxpr=jaxpr, **kwargs) - return out_vals, out_rep - -@register_check(ad_checkpoint.remat_p) -def _remat_check(mesh, *in_rep, jaxpr, **kwargs): - return _check_rep(mesh, jaxpr, in_rep) - - -@register_check(core.call_p) -def _core_call_check(mesh, *in_rep, call_jaxpr, **kwargs): - return _check_rep(mesh, call_jaxpr, in_rep) - - -@register_check(debugging.debug_callback_p) -def _debug_callback_rule(mesh, *in_rep, **_): - return [] -register_norewrite(debugging.debug_callback_p) - - -@register_check(callback.pure_callback_p) -def _pure_callback_rule(mesh, *_, result_avals, **__): - return [set()] * len(result_avals) -register_norewrite(callback.pure_callback_p) - - -@register_check(callback.io_callback_p) -def _io_callback_rule(mesh, *_, result_avals, **__): - return [set()] * len(result_avals) -register_norewrite(callback.io_callback_p) - - -@register_check(dispatch.device_put_p) -def _device_put_rule(mesh, *xs, **_): - return list(xs) -register_norewrite(dispatch.device_put_p) - - -@register_check(ad.custom_lin_p) -def _custom_lin_rule(mesh, *_, out_avals, **__): - return [set()] * len(out_avals) -register_norewrite(ad.custom_lin_p) - - -@register_check(control_flow.loops.scan_p) -def _scan_check(mesh, *in_rep, jaxpr, num_consts, num_carry, **_): - _, carry_rep_in, _ = split_list(in_rep, [num_consts, num_carry]) - out_rep = _check_rep(mesh, jaxpr.jaxpr, in_rep) - carry_rep_out, _ = split_list(out_rep, [num_carry]) - if not carry_rep_in == carry_rep_out: - raise Exception("Scan carry input and output got mismatched replication " - f"types {carry_rep_in} and {carry_rep_out}. Please open an " - "issue at https://github.com/jax-ml/jax/issues, and as a " - "temporary workaround pass the check_rep=False argument to " - "shard_map") - return out_rep - -@register_rewrite(control_flow.loops.scan_p) -def _scan_rewrite(mesh, in_rep, *args, jaxpr, num_consts, num_carry, **params): - const_rep, carry_rep_in, xs_rep = split_list(in_rep, [num_consts, num_carry]) - for _ in range(1 + num_carry): - in_rep_ = [*const_rep, *carry_rep_in, *xs_rep] - _, out_rep = _replication_rewrite_nomatch(mesh, jaxpr, in_rep_) - carry_rep_out, ys_rep = split_list(out_rep, [num_carry]) - carry_rep_out = map(op.and_, carry_rep_in, carry_rep_out) - if carry_rep_in == carry_rep_out: - break - else: - carry_rep_in = carry_rep_out - else: - assert False, 'Fixpoint not reached' - - args = [pbroadcast(x, tuple(n for n in src if n not in dst)) - if src - dst else x for x, src, dst in zip(args, in_rep, in_rep_)] - out_rep = [*carry_rep_out, *ys_rep] - jaxpr_ = _replication_rewrite_match(mesh, jaxpr, in_rep_, out_rep) - - out_vals = control_flow.loops.scan_p.bind( - *args, jaxpr=jaxpr_, num_consts=num_consts, num_carry=num_carry, **params) - return out_vals, out_rep - -@register_check(control_flow.conditionals.cond_p) -def _cond_rule(mesh, *in_rep, branches): - _, *args_rep = in_rep - out_rep = _check_rep(mesh, branches[0].jaxpr, args_rep) - for branch in branches[1:]: - out_rep_ = _check_rep(mesh, branch.jaxpr, args_rep) - if not out_rep_ == out_rep: - raise Exception("The branches of cond produced mismatched replication " - "types. Please open an issue at " - "https://github.com/jax-ml/jax/issues, and as a " - "temporary workaround pass the check_rep=False argument " - "to shard_map") - return out_rep - -@register_rewrite(control_flow.conditionals.cond_p) -def _cond_rewrite(mesh, in_rep, *args, branches): - pred_rep, *args_rep = in_rep - _, out_rep = _replication_rewrite_nomatch(mesh, branches[0], args_rep) - for branch in branches[1:]: - _, out_rep_ = _replication_rewrite_nomatch(mesh, branch, args_rep) - if out_rep: - out_rep = map(op.and_, out_rep, out_rep_) - else: - out_rep = out_rep_ - out_rep = map(partial(op.and_, pred_rep), out_rep) - branches_ = tuple(_replication_rewrite_match(mesh, branch, args_rep, out_rep) - for branch in branches) - out_vals = control_flow.conditionals.cond_p.bind(*args, branches=branches_) - return out_vals, out_rep - -@register_check(control_flow.conditionals.platform_index_p) -def _platform_index_rule(mesh, *_, **__): - return set(mesh.axis_names) -register_norewrite(control_flow.conditionals.platform_index_p) - -@register_rewrite(core.closed_call_p) -def _closed_call_rewrite(mesh, in_rep, *args, call_jaxpr, **kwargs): - new_jaxpr, out_rep = _replication_rewrite_nomatch(mesh, call_jaxpr, in_rep) - out_vals = core.closed_call_p.bind(*args, jaxpr=new_jaxpr, **kwargs) - return out_vals, out_rep - -@register_check(core.closed_call_p) -def _closed_call_check(mesh, *in_rep, call_jaxpr, **kwargs): - return _check_rep(mesh, call_jaxpr.jaxpr, in_rep) - - -@register_check(custom_derivatives.custom_jvp_call_p) -def _custom_jvp_call_check(mesh, *in_rep, call_jaxpr, jvp_jaxpr_fun, - num_consts, symbolic_zeros): - return _check_rep(mesh, call_jaxpr.jaxpr, in_rep) - -@register_rewrite(custom_derivatives.custom_vjp_call_jaxpr_p) -def _custom_vjp_call_jaxpr_rewrite( - mesh, in_rep, *args, fun_jaxpr, fwd_jaxpr_thunk, bwd, num_consts, out_trees, - symbolic_zeros): - if symbolic_zeros: - msg = ("Please open an issue at https://github.com/jax-ml/jax/issues and as" - " a temporary workaround pass the check_rep=False argument to " - "shard_map") - raise NotImplementedError(msg) - - fun_jaxpr_, out_rep = _replication_rewrite_nomatch(mesh, fun_jaxpr, in_rep) - _, in_rep_ = split_list(in_rep, [num_consts]) - out_rep2 = [] - - @pe._memoize - def fwd_jaxpr_thunk_(*zeros): - fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk(*zeros)) - fwd_jaxpr_, out_rep = _replication_rewrite_nomatch(mesh, fwd_jaxpr, in_rep_) - out_rep2.append(out_rep) - return fwd_jaxpr_.jaxpr, fwd_jaxpr_.consts - - bwd_ = _rewrite_bwd(bwd, mesh, lambda: out_rep2[0], in_rep_) - - outs = custom_derivatives.custom_vjp_call_jaxpr_p.bind( - *args, fun_jaxpr=fun_jaxpr_, fwd_jaxpr_thunk=fwd_jaxpr_thunk_, bwd=bwd_, - num_consts=num_consts, out_trees=out_trees, symbolic_zeros=symbolic_zeros) - out_rep = out_rep2[0] if out_rep2 else out_rep - return outs, out_rep - -@register_check(custom_derivatives.custom_vjp_call_jaxpr_p) -def _custom_vjp_call_jaxpr_check(mesh, *in_rep, fun_jaxpr, **_): - return _check_rep(mesh, fun_jaxpr.jaxpr, in_rep) - -@register_check(control_flow.solves.linear_solve_p) -def _linear_solve_check(mesh, *in_rep, jaxprs, **_): - out_rep = _standard_check(control_flow.solves.linear_solve_p, mesh, *in_rep) - return [out_rep] * len(jaxprs.solve.out_avals) -register_standard_rewrite(control_flow.solves.linear_solve_p) - -@register_check(ffi.ffi_call_p) -def _ffi_call_check(mesh, *in_rep, result_avals, **_): - out_rep = _standard_check(ffi.ffi_call_p, mesh, *in_rep) - return [out_rep] * len(result_avals) -register_standard_rewrite(ffi.ffi_call_p) - -del _check_rules[lax.tie_p] - -@register_check(lax.tie_p) -def _tie_check(mesh, x_rep, y_rep): - return x_rep -register_norewrite(lax.tie_p) - - -# Batching - -def _shard_map_batch( - trace: batching.BatchTrace, prim: core.Primitive, fun: lu.WrappedFun, - in_tracers: Sequence[batching.BatchTracer], mesh: Mesh, - in_names: tuple[AxisNames, ...], - out_names_thunk: Callable[[], tuple[AxisNames, ...]], - check_rep: bool, - rewrite: bool, - auto: frozenset) -> Sequence[batching.BatchTracer]: - in_vals, in_dims = unzip2(map(trace.to_batch_info, in_tracers)) - if any(isinstance(d, batching.RaggedAxis) for d in in_dims): - raise NotImplementedError - new_in_names = [{ax + (d is not batching.not_mapped and d <= ax): names[ax] - for ax in names} for names, d in zip(in_names, in_dims)] - spmd_axis_name = trace.axis_data.spmd_name - if spmd_axis_name is not None: - used = {n for names in in_names for ns in names.values() for n in ns} - if not config.disable_vmap_shmap_error.value and set(spmd_axis_name) & used: - raise ValueError("vmap spmd_axis_name cannot appear in shard_map in_specs") - new_in_names = [{**ns, d:spmd_axis_name} if d is not batching.not_mapped - else ns for ns, d in zip(new_in_names, in_dims)] - new_size = trace.axis_data.size // prod(mesh.shape[n] for n in spmd_axis_name) - new_axis_data = batching.AxisData(trace.axis_data.name, new_size, - trace.axis_data.spmd_name, None) - else: - new_axis_data = trace.axis_data - fun, out_dims = batching.batch_subtrace(fun, trace.tag, new_axis_data, tuple(in_dims)) - @as_hashable_function(closure=out_names_thunk) - def new_out_names_thunk(): - return _batch_out_names(spmd_axis_name, out_dims(), out_names_thunk()) - - new_params = dict(mesh=mesh, in_names=new_in_names, - out_names_thunk=new_out_names_thunk, check_rep=check_rep, - rewrite=rewrite, auto=auto) - with core.set_current_trace(trace.parent_trace): - out_vals = prim.bind(fun, *in_vals, **new_params) - make_tracer = partial(batching.BatchTracer, trace, - source_info=source_info_util.current()) - return map(make_tracer, out_vals, out_dims()) -batching.BatchTrace.process_shard_map = _shard_map_batch - -def _batch_out_names(spmd_axis_name, dims, out_names): - out_names_ = [{ax + (d is not batching.not_mapped and d <= ax): names[ax] - for ax in names} for names, d in zip(out_names, dims)] - if spmd_axis_name is not None: - used = {n for names in out_names for ns in names.values() for n in ns} - if not config.disable_vmap_shmap_error.value and set(spmd_axis_name) & used: - raise ValueError("vmap spmd_axis_name cannot appear in shard_map out_specs") - out_names_ = [{**ns, d:spmd_axis_name} if d is not batching.not_mapped - else ns for ns, d in zip(out_names_, dims)] - return out_names_ - - -# Autodiff - -def _shard_map_jvp(trace, shard_map_p, f, tracers, mesh, in_names, - out_names_thunk, check_rep, rewrite, auto): - primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers)) - which_nz = [ type(t) is not ad.Zero for t in tangents] - tangents = [t if type(t) is not ad.Zero else None for t in tangents] - args, in_tree = tree_flatten((primals, tangents)) - f_jvp = ad.jvp_subtrace(f, trace.tag) - f_jvp, which_nz_out = ad.nonzero_tangent_outputs(f_jvp) - tangent_in_names = [ax for ax, nz in zip(in_names, which_nz) if nz] - - @as_hashable_function(closure=out_names_thunk) - def new_out_names_thunk(): - out_ax = out_names_thunk() - return (*out_ax, *(ax for ax, nz in zip(out_ax, which_nz_out()) if nz)) - params = dict(mesh=mesh, in_names=(*in_names, *tangent_in_names), - out_names_thunk=new_out_names_thunk, check_rep=check_rep, - rewrite=rewrite, auto=auto) - f_jvp, out_tree = ad.traceable(f_jvp, in_tree) - result = shard_map_p.bind_with_trace(trace.parent_trace, (f_jvp,) + tuple(args), params) - primal_out, tangent_out = tree_unflatten(out_tree(), result) - tangent_out = [ad.Zero(core.get_aval(p).to_tangent_aval()) if t is None else t - for p, t in zip(primal_out, tangent_out)] - return [ad.JVPTracer(trace, p, t) for p, t in zip(primal_out, tangent_out)] -ad.JVPTrace.process_shard_map = _shard_map_jvp - -def _shard_map_partial_eval(trace: pe.JaxprTrace, shard_map_p, - f: lu.WrappedFun, tracers, mesh, in_names, - out_names_thunk, check_rep, rewrite, auto): - tracers = map(trace.to_jaxpr_tracer, tracers) - in_pvals = [t.pval for t in tracers] - in_knowns, in_avals, in_consts = pe.partition_pvals(in_pvals) - unk_in_names, known_in_names = pe.partition_list(in_knowns, in_names) - all_names = _all_newly_manual_mesh_names(mesh, auto, trace) - in_avals_sharded = map(partial(_shard_aval, mesh, auto), unk_in_names, in_avals) - f = pe.trace_to_subjaxpr_nounits_fwd2(f, trace.tag, f.debug_info, False) - f = _promote_scalar_residuals(f) - f_known, aux = pe.partial_eval_wrapper_nounits( - f, (*in_knowns,), (*in_avals_sharded,)) - - @as_hashable_function(closure=out_names_thunk) - def known_out_names(): - in_fwd, out_fwd, out_knowns, _, jaxpr, _ = aux() - _, out_known_names = pe.partition_list(out_knowns, out_names_thunk()) - num_res = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)) - return (*out_known_names, *({0: all_names},) * num_res) - - known_params = dict(mesh=mesh, in_names=(*known_in_names,), - out_names_thunk=known_out_names, check_rep=check_rep, - rewrite=rewrite, auto=auto) - out = shard_map_p.bind_with_trace(trace.parent_trace, (f_known, *in_consts), known_params) - in_fwd, out_fwd, out_knowns, out_avals_sharded, jaxpr, env = aux() - num_res = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)) - out_consts, non_fwd_res = split_list(out, [len(out) - num_res]) - assert not jaxpr.constvars - unk_out_names, _ = pe.partition_list(out_knowns, out_names_thunk()) - known_out_names_ = known_out_names() - res = subs_list2(in_fwd, out_fwd, in_consts, out_consts, non_fwd_res) - res_names = [known_in_names[f1] if f1 is not None else - known_out_names_[f2] if f2 is not None else - {0: all_names} for f1, f2 in zip(in_fwd, out_fwd)] - unk_in_names = (*res_names,) + ({},) * len(env) + (*unk_in_names,) # type: ignore[assignment] - const_tracers = map(trace.new_instantiated_const, res) - env_tracers = map(trace.to_jaxpr_tracer, env) - unk_arg_tracers = [t for t in tracers if not t.is_known()] - unk_params = dict(mesh=mesh, in_names=unk_in_names, - out_names=unk_out_names, jaxpr=jaxpr, check_rep=False, - rewrite=rewrite, auto=auto) - out_avals = map(partial(_unshard_aval, mesh), unk_out_names, out_avals_sharded) - out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None) - for a in out_avals] - effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names) - eqn = pe.new_eqn_recipe((*const_tracers, *env_tracers, *unk_arg_tracers), - out_tracers, shard_map_p, unk_params, - effs, source_info_util.current()) - for t in out_tracers: t.recipe = eqn - return merge_lists(out_knowns, out_tracers, out_consts) -pe.JaxprTrace.process_shard_map = _shard_map_partial_eval - -def _shard_map_linearize(trace, shard_map_p, f: lu.WrappedFun, - tracers, mesh, in_names, - out_names_thunk, check_rep, rewrite, auto): - primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers)) - nzs_in = tuple(type(t) is not ad.Zero for t in tangents) - f_primal, linearize_outs_thunk = ad.linearize_subtrace(f, trace.tag, nzs_in, f.debug_info) - f_primal = _promote_scalar_residuals_lin(f_primal, linearize_outs_thunk) - tangent_in_names = [ax for ax, nz in zip(in_names, nzs_in) if nz] - res_names = _all_newly_manual_mesh_names(mesh, auto, trace) - - @as_hashable_function(closure=linearize_outs_thunk) - def fwd_out_names_thunk(): - _, _, _, _, in_fwd, out_fwd = linearize_outs_thunk() - out_names = out_names_thunk() - num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)) - # This is incorrect so we set `check_rep=False` in the tangent (as in JVP). - return (*({0: res_names} for _ in range(num_res_out)), *out_names) - fwd_params = dict( - mesh=mesh, in_names=in_names, - out_names_thunk=fwd_out_names_thunk, check_rep=check_rep, - rewrite=rewrite, auto=auto) - all_fwd_results = shard_map_p.bind_with_trace( - trace.parent_trace, (f_primal, *primals), fwd_params) - residual_avals, nzs_out, lin_jaxpr, env, in_fwd, out_fwd = linearize_outs_thunk() - num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)) - non_fwd_res = all_fwd_results[:num_res_out] - primals_out = all_fwd_results[num_res_out:] - residuals = subs_list2(in_fwd, out_fwd, primals, primals_out, non_fwd_res) - args_to_promote = [getattr(aval, 'shape', ()) == () and f1 is None and f2 is None - for aval, f1, f2 in zip(residual_avals, in_fwd, out_fwd)] - with _extend_axis_env(mesh, auto), use_abstract_mesh(_as_manual_mesh(mesh, auto)): - lin_jaxpr = _promote_scalar_residuals_jaxpr(lin_jaxpr, args_to_promote) - out_names = out_names_thunk() - residual_names = [in_names[f1] if f1 is not None else - out_names[f2] if f2 is not None else - {0: res_names} for f1, f2 in zip(in_fwd, out_fwd)] - new_in_names = (*residual_names, *({} for _ in range(len(env))), - *(ax for ax, nz in zip(in_names, nzs_in) if nz)) - tangent_out_names = tuple(ax for ax, nz in zip(out_names_thunk(), nzs_out) if nz) - @as_hashable_function(closure=tangent_out_names) - def tangent_out_names_thunk(): - return tangent_out_names - tangent_params = dict( - mesh=mesh, in_names=new_in_names, out_names_thunk=tangent_out_names_thunk, - check_rep=False, rewrite=rewrite, auto=auto) - - # TODO(mattjj): avoid round-tripping the jaxpr through eval_jaxpr here - def f_tangent(*args): - return core.eval_jaxpr(lin_jaxpr, (), *args) - - nz_tangents_in = [t for (t, nz) in zip(tangents, nzs_in) if nz] - nz_tangents_out = shard_map_p.bind_with_trace( - trace.tangent_trace, - (lu.wrap_init(f_tangent, debug_info=lin_jaxpr.debug_info), - *residuals, *env, *nz_tangents_in), tangent_params) - nz_tangents_out_iter = iter(nz_tangents_out) - tangents_out = [next(nz_tangents_out_iter) if nz else ad.Zero.from_primal_value(primal) - for nz, primal in zip(nzs_out, primals_out)] - return map(partial(ad.maybe_linearize_tracer, trace), primals_out, nzs_out, tangents_out) -ad.LinearizeTrace.process_shard_map = _shard_map_linearize - -@lu.transformation2 -def _promote_scalar_residuals_lin(f, linearize_outs_thunk, *args, **kwargs): - ans = f(*args, **kwargs) - _, _, _, _, in_fwd, out_fwd = linearize_outs_thunk() - num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)) - residuals = ans[:num_res_out] - primals = ans[num_res_out:] - residuals = [jax.lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x - for x in residuals] - return *residuals, *primals - -@lu.transformation2 -def _promote_scalar_residuals(f: Callable, *args, **kwargs): - jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) = f(*args, **kwargs) - which = [f1 is None and f2 is None and not v.aval.shape - for f1, f2, v in zip(in_fwds, out_fwds, jaxpr.constvars)] - jaxpr = _promote_scalar_residuals_jaxpr(jaxpr, which) - out_consts = [jax.lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x - for x in out_consts] - return jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) - -def _promote_scalar_residuals_jaxpr(jaxpr: core.Jaxpr, - which: Sequence[bool]): - def fun(*res_and_args): - res, args = split_list(res_and_args, [len(jaxpr.constvars)]) - res = [_rem_singleton(x) if w else x for x, w in zip(res, which)] - return core.eval_jaxpr(jaxpr, res, *args) - res_avals = [core.unmapped_aval(1, 0, v.aval) if w else v.aval - for v, w in zip(jaxpr.constvars, which)] - in_avals = [*res_avals, *[v.aval for v in jaxpr.invars]] - jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(fun, debug_info=jaxpr.debug_info), in_avals) - return jaxpr - - -def _unmentioned2(mesh: Mesh, names: AxisNames, - auto: frozenset[AxisName]) -> list[AxisName]: - # We use a filtered-down version of unmentioned to avoid defensive-psum over - # more chips than required in the transpose-no-check-rep case. - name_set = {n for ns in names.values() for n in ns} | auto - return [n for n in _all_mesh_names_except_spmd(mesh, auto) - if n not in name_set] - - -def _shard_map_transpose(out_cts, *args, - jaxpr: core.Jaxpr, mesh, in_names, out_names, - check_rep, rewrite, auto): - mb_div = lambda x, y: x / y if y != 1 else x - out_cts = [ - ad.Zero(_shard_aval(mesh, auto, ns, x.aval)) if type(x) is ad.Zero - else x if rewrite or dtypes.dtype(x) == dtypes.float0 - else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, ns, auto)))) - for ns, x in zip(out_names, out_cts) - ] - args = tuple(x if type(x) is not ad.UndefinedPrimal else - ad.UndefinedPrimal(_shard_aval(mesh, auto, ns, x.aval)) - for ns, x in zip(in_names, args)) - all_args, in_tree = tree_flatten((out_cts, args)) - - def fun_trans_callable(out_cts, args): - # TODO(mattjj): when #26811 lands, delete this and just run backward_pass - in_undef = map(ad.is_undefined_primal, args) - res, undefs = partition_list(in_undef, args) - jaxpr_known, jaxpr_unknown, _, _ = pe.partial_eval_jaxpr_nounits( - pe.close_jaxpr(jaxpr), in_undef, False) - res_reshaped = core.jaxpr_as_fun(jaxpr_known)(*res) - in_cts = ad.backward_pass( - jaxpr_unknown.jaxpr, False, (), (*res_reshaped, *undefs), out_cts - )[len(res_reshaped):] - _, in_ct_names = partition_list(in_undef, in_names) - in_cts = [ad.Zero(_unshard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero - else x if rewrite - else jax.lax.psum(x, tuple(_unmentioned2(mesh, ns, auto))) - for ns, x in zip(in_ct_names, in_cts)] - res_zeros = [ad_util.zero_from_primal(r) for r in res] - return merge_lists(in_undef, res_zeros, in_cts) - - fun_trans = lu.wrap_init(fun_trans_callable, debug_info=jaxpr.debug_info) - fun_trans, nz_arg_cts = ad.nonzero_outputs(fun_trans) - fun_trans_flat, out_tree = api_util.flatten_fun_nokwargs(fun_trans, in_tree) - - new_in_names = \ - [n for n, x in zip(out_names, out_cts) if type(x) is not ad.Zero] + \ - [n for n, x in zip(in_names, args) if type(x) is not ad.UndefinedPrimal] - - def new_out_names_thunk(): - return tuple(names for names, nz in zip(in_names, nz_arg_cts()) if nz) - - try: - out_flat = shard_map_p.bind( - fun_trans_flat, *all_args, mesh=mesh, in_names=tuple(new_in_names), - out_names_thunk=new_out_names_thunk, check_rep=check_rep, rewrite=rewrite, - auto=auto) - except (FloatingPointError, ZeroDivisionError) as e: - print("Invalid nan value encountered in the backward pass of a shard_map " - "function. Calling the de-optimized backward pass.") - try: - # TODO(mattjj): Remove this and do `fun_trans.call_wrapped(out_cts, args)` - # in eager mode so that output of shmap are not manual. - with jax.disable_jit(True): - _ = shard_map_p.bind( - fun_trans_flat, *all_args, mesh=mesh, in_names=tuple(new_in_names), - out_names_thunk=new_out_names_thunk, check_rep=check_rep, - rewrite=rewrite, auto=auto) - except (FloatingPointError, ZeroDivisionError) as e2: - raise e2 from None - else: - dispatch._raise_no_nan_in_deoptimized(e) - return tree_unflatten(out_tree(), out_flat) -ad.primitive_transposes[shard_map_p] = _shard_map_transpose - -# Remat - -def _partial_eval_jaxpr_custom_rule( - saveable: Callable[..., pe.RematCases_], unks_in: Sequence[bool], - inst_in: Sequence[bool], eqn: core.JaxprEqn -) -> tuple[core.JaxprEqn, core.JaxprEqn, Sequence[bool], Sequence[bool], - list[core.Var]]: - jaxpr, mesh = eqn.params['jaxpr'], eqn.params['mesh'] - auto = eqn.params['auto'] - with _extend_axis_env(mesh, auto): - jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res = \ - pe.partial_eval_jaxpr_custom(jaxpr, unks_in, inst_in, False, False, saveable) - num_out_primals = len(jaxpr_known.outvars) - num_res - in_fwd = pe._jaxpr_forwarding(jaxpr_known)[num_out_primals:] - out_vars, res_vars = split_list(jaxpr_known.outvars, [num_out_primals]) - idx_map = {id(v): i for i, v in enumerate(out_vars)} - out_fwd = [idx_map.get(id(v)) for v in res_vars] - which = [f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)] - mesh = eqn.params['mesh'] - with (_extend_axis_env(mesh, auto), - use_abstract_mesh(_as_manual_mesh(mesh, auto))): - jaxpr_known = pe.prune_jaxpr_outputs(jaxpr_known, [True] * num_out_primals + which) - jaxpr_known, jaxpr_staged = _add_reshapes(which, jaxpr_known, jaxpr_staged) - jaxpr_known = core.remove_named_axis_effects(jaxpr_known, mesh.axis_names) - jaxpr_staged = core.remove_named_axis_effects(jaxpr_staged, mesh.axis_names) - ins_known, _ = partition_list(unks_in, eqn.invars) - out_binders_known, _ = partition_list(unks_out, eqn.outvars) - _, ins_staged = partition_list(inst_in, eqn.invars) - _, out_binders_staged = partition_list(inst_out, eqn.outvars) - newvar = core.gensym() - params_known, params_staged, res_names = _pe_custom_params( - unks_in, inst_in, map(op.not_, unks_out), inst_out, in_fwd, out_fwd, which, - dict(eqn.params, jaxpr=jaxpr_known), dict(eqn.params, jaxpr=jaxpr_staged)) - residuals = [newvar(_unshard_aval(mesh, {0: res_names}, var.aval)) - for var, w in zip(jaxpr_staged.invars[:num_res], which) if w] - eqn_known = pe.new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals], - eqn.primitive, params_known, jaxpr_known.effects, - eqn.source_info, eqn.ctx) - full_res = subs_list2(in_fwd, out_fwd, ins_known, out_binders_known, residuals) - eqn_staged = pe.new_jaxpr_eqn([*full_res, *ins_staged], out_binders_staged, - eqn.primitive, params_staged, - jaxpr_staged.effects, eqn.source_info, eqn.ctx) - assert len(eqn_staged.invars) == len(jaxpr_staged.invars) - new_inst = [x for x, inst in zip(eqn.invars, inst_in) - if type(x) is core.Var and not inst] - new_inst += [out_binders_known[f] for f in {i for i in out_fwd if i is not None}] - return eqn_known, eqn_staged, unks_out, inst_out, new_inst + residuals -pe.partial_eval_jaxpr_custom_rules[shard_map_p] = \ - _partial_eval_jaxpr_custom_rule - -def _add_reshapes(which: Sequence[bool], - jaxpr_known: core.Jaxpr, - jaxpr_staged: core.Jaxpr) -> tuple[core.Jaxpr, core.Jaxpr]: - # add singleton axes to residuals which are from jaxpr_known and are scalars - which_ = [w and not v.aval.shape # pytype: disable=attribute-error - for w, v in zip(which, jaxpr_staged.invars[:len(which)])] - if not any(which_): return jaxpr_known, jaxpr_staged - assert not jaxpr_known.constvars and not jaxpr_staged.constvars - - def known(*args): - out = core.eval_jaxpr(jaxpr_known, (), *args) - out_known, res = split_list(out, [len(out) - sum(which)]) - res = [_add_singleton(x) if not x.shape else x for x in res] - return [*out_known, *res] - avals_in = [v.aval for v in jaxpr_known.invars] - jaxpr_known, _, (), () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(known, debug_info=jaxpr_known.debug_info), avals_in) - - def staged(*args): - res_, ins = split_list(args, [len(which)]) - res = [_rem_singleton(x) if w else x for x, w in zip(res_, which_)] - return core.eval_jaxpr(jaxpr_staged, (), *res, *ins) - res_avals = [core.unmapped_aval(1, 0, v.aval) if w else v.aval - for w, v in zip(which_, jaxpr_staged.invars[:len(which)])] - avals_in = [*res_avals, *[v.aval for v in jaxpr_staged.invars[len(which):]]] - jaxpr_staged, _, (), () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(staged, debug_info=jaxpr_staged.debug_info), avals_in) - - return jaxpr_known, jaxpr_staged - -def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged, - in_fwd, out_fwd, which, params_known, params_staged): - # prune inputs to jaxpr_known according to unks_in - mesh = params_known['mesh'] - auto = params_known['auto'] - res_names_ = _all_newly_manual_mesh_names(mesh, auto) - in_names_known, _ = partition_list(unks_in, params_known['in_names']) - _, out_names_known = partition_list(kept_outs_known, params_known['out_names']) - out_names_known = out_names_known + [{0: res_names_}] * sum(which) - new_params_known = dict(params_known, in_names=tuple(in_names_known), - out_names=tuple(out_names_known)) - - # added num_res new inputs to jaxpr_staged, pruning according to inst_in - _, in_names_staged = partition_list(inst_in, params_staged['in_names']) - res_names = [in_names_known[f1] if f1 is not None else - out_names_known[f2] if f2 is not None else - {0: res_names_} for f1, f2 in zip(in_fwd, out_fwd)] - in_names_staged = res_names + in_names_staged - _, out_names_staged = partition_list(kept_outs_staged, params_staged['out_names']) - new_params_staged = dict(params_staged, in_names=tuple(in_names_staged), - out_names=tuple(out_names_staged), check_rep=False) - return new_params_known, new_params_staged, res_names_ - -# TODO(mattjj): remove this mechanism when we revise mesh scopes -def _all_mesh_names_except_spmd( - mesh: Mesh, auto: frozenset[AxisName], trace=None -) -> tuple[AxisName, ...]: - axis_env = core.get_axis_env() - spmd_names = axis_env.spmd_axis_names - return tuple(name for name in mesh.axis_names if name not in spmd_names and - name not in auto) - -def _all_newly_manual_mesh_names( - mesh: Mesh, auto: frozenset[AxisName], trace=None -) -> tuple[AxisName, ...]: - axis_env = core.get_axis_env() - vmap_spmd_names = set(axis_env.spmd_axis_names) - if not (ctx_mesh := get_abstract_mesh()).empty: - mesh = ctx_mesh - already_manual_names = set(ctx_mesh._axis_types_dict.get(AxisType.Manual, ())) - else: - # TODO(mattjj): remove this mechanism when we revise mesh scopes - already_manual_names = set(axis_env.axis_sizes) # may include vmap axis_names - return tuple(name for name in mesh.axis_names - if name not in auto | vmap_spmd_names | already_manual_names) - - -# DCE - -# TODO(mattjj): de-duplicate with pe.dce_jaxpr_call_rule, and/or _pmap_dce_rule? -def _shard_map_dce(used_outputs: list[bool], eqn: core.JaxprEqn - ) -> tuple[list[bool], core.JaxprEqn | None]: - if not any(used_outputs) and not pe.has_effects(eqn): - return [False] * len(eqn.invars), None - mesh = eqn.params["mesh"] - auto = eqn.params["auto"] - with _extend_axis_env(mesh, auto): - jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs) - if not any(used_inputs) and not any(used_outputs) and not jaxpr.effects: - return used_inputs, None - else: - _, in_names = partition_list(used_inputs, eqn.params['in_names']) - _, out_names = partition_list(used_outputs, eqn.params['out_names']) - new_params = dict(eqn.params, jaxpr=jaxpr, in_names=tuple(in_names), - out_names=tuple(out_names)) - effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names) - new_eqn = pe.new_jaxpr_eqn( - [v for v, used in zip(eqn.invars, used_inputs) if used], - [x for x, used in zip(eqn.outvars, used_outputs) if used], - eqn.primitive, new_params, effs, eqn.source_info, eqn.ctx) - return used_inputs, new_eqn -pe.dce_rules[shard_map_p] = _shard_map_dce - -# Implementing pmap in terms of shard_map - -def pmap(f, axis_name=None, *, in_axes=0, out_axes=0, - static_broadcasted_argnums=(), devices=None, backend=None, - axis_size=None, donate_argnums=(), global_arg_shapes=None): - devices = tuple(devices) if devices is not None else devices - axis_name, static_broadcasted_tuple, donate_tuple = _shared_code_pmap( - f, axis_name, static_broadcasted_argnums, donate_argnums, in_axes, out_axes) - - def infer_params(*args, **kwargs): - p = _prepare_pmap(f, in_axes, out_axes, static_broadcasted_tuple, - donate_tuple, devices, backend, axis_size, args, kwargs) - for arg in p.flat_args: - dispatch.check_arg(arg) - mesh = Mesh(_get_devices(p, backend), (axis_name,)) - _pmapped, in_specs, out_specs = _cached_shard_map( - p.flat_fun, mesh, p.in_axes_flat, p.out_axes_thunk, axis_name) - flat_global_args = host_local_array_to_global_array( - p.flat_args, mesh, list(in_specs)) - jitted_f = jax.jit( - _pmapped, - donate_argnums=(i for i, val in enumerate(p.donated_invars) if val)) - return jitted_f, flat_global_args, p.out_tree, mesh, out_specs - - def wrapped(*args, **kwargs): - (jitted_f, flat_global_args, out_tree, mesh, - out_specs) = infer_params(*args, **kwargs) - outs = jitted_f(*flat_global_args) - outs = global_array_to_host_local_array(outs, mesh, out_specs()) - return tree_unflatten(out_tree(), outs) - - def lower(*args, **kwargs): - jitted_f, _, _, _, _ = infer_params(*args, **kwargs) - return jitted_f.lower(*args, **kwargs) - wrapped.lower = lower - - return wrapped - - -@lu.cache -def _cached_shard_map(flat_fun, mesh, in_axes_flat, out_axes_thunk, axis_name): - in_specs = tuple(map(partial(_axis_to_spec, axis_name), in_axes_flat)) - out_specs = lambda: map(partial(_axis_to_spec, axis_name), out_axes_thunk()) - fun = _handle_reshapes(flat_fun, in_axes_flat, out_axes_thunk) - return (_shard_map(fun.call_wrapped, mesh, in_specs, out_specs, - check_rep=False, auto=frozenset()), - in_specs, out_specs) - -@lu.transformation2 -def _handle_reshapes(f, in_axes, out_axes_thunk, *args, **kwargs): - args = tree_map(lambda x, ax: x if ax is None else jnp.squeeze(x, axis=ax), - list(args), list(in_axes)) - out = f(*args) - return tree_map(lambda x, ax: x if ax is None else jnp.expand_dims(x, axis=ax), - list(out), list(out_axes_thunk())) - -def _axis_to_spec(axis_name, ax): - if isinstance(ax, int): - specs = [None] * ax + [axis_name] - return P(*specs) - elif ax is None: - return P() - else: - raise TypeError(ax) - -def _get_devices(p, backend): - if backend is not None and p.devices is None: - devs = jax.devices(backend=backend) - else: - devs = jax.devices() if p.devices is None else p.devices - if jax.process_count() > 1: - return devs[:p.global_axis_size] - return devs[:p.local_axis_size] - - -### Rewrite! - -Val = Any - -class RewriteTracer(core.Tracer): - rep: set[AxisName] - val: Val - - def __init__(self, trace, rep, val): - self._trace = trace - self.rep = rep - self.val = val - - @property - def aval(self) -> core.AbstractValue: - return core.get_aval(self.val) - - def to_concrete_value(self): - return core.to_concrete_value(self.val) - - def __str__(self) -> str: - return str(self.val) # TODO(mattjj): could show replication info here - __repr__ = __str__ # for debuggers, like `p x` - -class RewriteTrace(core.Trace): - __slots__ = ("parent_trace", "tag", "mesh") - - parent_trace : core.Trace - tag : core.TraceTag - mesh: Mesh - - def __init__(self, parent_trace, tag, mesh): - super().__init__() - self.parent_trace = parent_trace - self.tag = tag - self.mesh = mesh - - def to_val_rep_pair(self, val): - # TODO: add a tag to tell if self - if isinstance(val, RewriteTracer) and val._trace.tag is self.tag: - return val.val, val.rep - else: - return val, set(self.mesh.axis_names) - - def process_primitive(self, prim, in_tracers, params): - rule = _rewrite_rules.get(prim, partial(_rule_missing, prim)) - in_vals, in_reps = unzip2(map(self.to_val_rep_pair, in_tracers)) - with core.set_current_trace(self.parent_trace): - out_vals, out_reps = rule(self.mesh, in_reps, *in_vals, **params) - out_tracers = map(partial(RewriteTracer, self), out_reps, out_vals) - return out_tracers if prim.multiple_results else out_tracers[0] - - def process_call(self, call_primitive, f, in_tracers, params): - in_vals, in_reps = unzip2(map(self.to_val_rep_pair, in_tracers)) - f, out_reps = _rewrite_subtrace(f, self.tag, self.mesh, tuple(in_reps)) - with core.set_current_trace(self.parent_trace): - out_vals = call_primitive.bind(f, *in_vals, **params) - return map(partial(RewriteTracer, self), out_reps(), out_vals) - - def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): - if symbolic_zeros: - msg = ("Please open an issue at https://github.com/jax-ml/jax/issues and " - "as a temporary workaround pass the check_rep=False argument to " - "shard_map") - raise NotImplementedError(msg) - in_vals, in_reps = unzip2(map(self.to_val_rep_pair, tracers)) - fun, out_reps1 = _rewrite_subtrace(fun, self.tag, self.mesh, in_reps) - jvp, out_reps2 = _rewrite_subtrace(jvp, self.tag, self.mesh, in_reps * 2) - with core.set_current_trace(self.parent_trace): - out_vals = prim.bind(fun, jvp, *in_vals, symbolic_zeros=symbolic_zeros) - fst, out_reps = lu.merge_linear_aux(out_reps1, out_reps2) - if not fst: - assert out_reps == out_reps[:len(out_reps) // 2] * 2 - out_reps = out_reps[:len(out_reps) // 2] - return map(partial(RewriteTracer, self), out_reps, out_vals) - - def process_custom_vjp_call(self, prim: core.Primitive, fun: lu.WrappedFun, - fwd: lu.WrappedFun, bwd: lu.WrappedFun, - tracers, - out_trees: Callable[[], Sequence[PyTreeDef]], - symbolic_zeros: bool): - if symbolic_zeros: - msg = ("Please open an issue at https://github.com/jax-ml/jax/issues and " - "as a temporary workaround pass the check_rep=False argument to " - "shard_map") - raise NotImplementedError(msg) - in_vals, in_reps = unzip2(map(self.to_val_rep_pair, tracers)) - fun, out_reps1 = _rewrite_subtrace(fun, self.tag, self.mesh, in_reps) - fwd_in_reps = [r_ for r in in_reps for r_ in [r, set(self.mesh.axis_names)]] - fwd, out_reps2 = _rewrite_subtrace(fwd, self.tag, self.mesh, fwd_in_reps) - bwd = _rewrite_bwd(bwd, self.mesh, out_reps2, in_reps) - with core.set_current_trace(self.parent_trace): - out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees, - symbolic_zeros=symbolic_zeros) - fst, out_reps = lu.merge_linear_aux(out_reps1, out_reps2) - if not fst: - _, res_tree = out_trees() - _, out_reps = split_list(out_reps, [res_tree.num_leaves]) - return map(partial(RewriteTracer, self), out_reps, out_vals) - -def _efficient_transpose_rewrite(fun, mesh, in_names, out_names_thunk): - in_reps = map(partial(_in_names_to_rep, mesh), in_names) - out_reps_dst = lambda: [set(_unmentioned(mesh, n)) for n in out_names_thunk()] - fun, out_reps_src = _efficient_transpose_rewrite_nomatch(fun, mesh, in_reps) - return _match_rep(fun, mesh, out_reps_src, out_reps_dst) - -@lu.transformation_with_aux2 -def _efficient_transpose_rewrite_nomatch(f, store, mesh, in_reps, *args): - with core.take_current_trace() as parent: - tag = core.TraceTag() - t = RewriteTrace(parent_trace = parent, tag = tag, mesh=mesh) - in_tracers = map(partial(RewriteTracer, t), in_reps, args) - with core.set_current_trace(t): - ans = f(*in_tracers) - out_vals, out_reps = unzip2(map(t.to_val_rep_pair, ans)) - del t, in_tracers, ans - store.store(out_reps) - return out_vals - -@lu.transformation2 -def _match_rep(f, mesh, out_reps_src_, out_reps_dst_, *args): - outs = f(*args) - out_reps_src = out_reps_src_() if callable(out_reps_src_) else out_reps_src_ - out_reps_dst = out_reps_dst_() if callable(out_reps_dst_) else out_reps_dst_ - _check_reps2(mesh, out_reps_dst, out_reps_src) - outs = [pbroadcast(x, tuple(n for n in src if n not in dst)) if src - dst - else x for x, src, dst in zip(outs, out_reps_src, out_reps_dst)] - return outs - -# TODO(mattjj): caching -def _replication_rewrite_match( - mesh: Mesh, - jaxpr: core.ClosedJaxpr, - in_rep: Sequence[set[AxisName]], - out_rep_dst: Sequence[set[AxisName]], -) -> core.ClosedJaxpr: - f = lu.wrap_init(partial(core.eval_jaxpr, jaxpr.jaxpr, jaxpr.consts), - debug_info=jaxpr.jaxpr.debug_info) - f, out_rep = _efficient_transpose_rewrite_nomatch(f, mesh, in_rep) - f = _match_rep(f, mesh, out_rep, out_rep_dst) - jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals) - return core.ClosedJaxpr(jaxpr_, consts) - -# TODO(mattjj): caching -def _replication_rewrite_nomatch( - mesh: Mesh, - jaxpr: core.ClosedJaxpr, - in_rep: Sequence[set[AxisName]], -) -> tuple[core.ClosedJaxpr, list[set[AxisName]]]: - f = lu.wrap_init(partial(core.eval_jaxpr, jaxpr.jaxpr, jaxpr.consts), - debug_info=jaxpr.jaxpr.debug_info) - f, out_rep = _efficient_transpose_rewrite_nomatch(f, mesh, in_rep) - jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals) - return core.ClosedJaxpr(jaxpr_, consts), out_rep() - -@lu.transformation_with_aux2 -def _rewrite_subtrace(f: Callable, store: lu.Store, - tag: core.TraceTag, mesh: Mesh, in_reps, *in_vals): - with core.take_current_trace() as parent_trace: - assert len(in_reps) == len(in_vals), (len(in_reps), len(in_vals)) - t = RewriteTrace(parent_trace, tag, mesh) - in_tracers = map(partial(RewriteTracer, t), in_reps, in_vals) - with core.set_current_trace(t): - outs = f(*in_tracers) - out_vals, out_reps = unzip2(map(t.to_val_rep_pair, outs)) - store.store(out_reps) - return out_vals - -def _rewrite_bwd(bwd: lu.WrappedFun, - mesh: Mesh, in_reps, reps_dst) -> lu.WrappedFun: - def new_bwd(*args): - tag = core.TraceTag() - bwd_, reps_thunk = _rewrite_subtrace(bwd, tag, mesh, in_reps()) - out = bwd_.call_wrapped(*args) - return map(_match_replication, reps_thunk(), reps_dst, out) - return lu.wrap_init(new_bwd, debug_info=bwd.debug_info) - -def _match_replication(src, dst, x): - if dst - src: - x, = psum2_p.bind(x, axes=tuple(n for n in dst if n not in src), - axis_index_groups=None) - if src - dst: - x = pbroadcast(x, tuple(n for n in src if n not in dst)) - return x - -# TODO(parkers,mattjj): change implementation when we have sharding-in-types. -def get_replication(x: jax.Array) -> set[AxisName]: - """For a jax.Array, return what axes it is known to be replicated along.""" - - if isinstance(x, RewriteTracer): - return x.rep - if isinstance(x, batching.BatchTracer): - return get_replication(x.val) - raise ValueError("get_replication not defined on %s" % repr(type(x))) +def shard_map(f, mesh, in_specs, out_specs, check_rep=True): + """Please use `jax.shard_map`. `jax.experimental.shard_map.shard_map` + has been deprecated.""" + return jshmap._shard_map(f, mesh=mesh, in_specs=in_specs, out_specs=out_specs, + axis_names=set(), check_vma=check_rep) + +_deprecations = { + # Deprecated in v0.8.0; we plan to keep this as a deprecated legacy API. + "shard_map": ( + "jax.experimental.shard_map is deprecated in v0.8.0. Used jax.shard_map instead.", + shard_map + ) +} + +import typing +if typing.TYPE_CHECKING: + pass +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr + del shard_map +del typing diff --git a/jax/experimental/slab/djax.py b/jax/experimental/slab/djax.py index 18f47515ac35..838f9860e8bb 100644 --- a/jax/experimental/slab/djax.py +++ b/jax/experimental/slab/djax.py @@ -14,7 +14,6 @@ from __future__ import annotations -import collections from collections.abc import Callable from functools import partial import sys @@ -55,15 +54,6 @@ def interp(djaxpr, slab, sizes, args): slab, outs = eval_djaxpr(djaxpr, slab, *sizes.values(), *views) return slab, outs -def _check_axis_size_conflicts(all_axes, sizes): - if len(all_axes) != len(set(all_axes)): - d = collections.defaultdict(list) - for name, sz in zip(all_axes, sizes): - d[name].append(sz) - msg = '; '.join([f'{name}: {" != ".join(map(str, sizes))}' - for name, sizes in d.items() if len(sizes) > 1]) - raise ValueError(f'abstracted axes resolve to conflicting sizes. {msg}') - def djit(f, abstracted_axes, **djit_kwargs): # TODO(frostig,mattjj): un/flatten f def f_wrapped(slab, *args): # TODO(frostig,mattjj): kw support diff --git a/jax/experimental/source_mapper/common.py b/jax/experimental/source_mapper/common.py index f7d10bc88f10..471fc0a7a877 100644 --- a/jax/experimental/source_mapper/common.py +++ b/jax/experimental/source_mapper/common.py @@ -15,7 +15,8 @@ import contextlib import dataclasses import re -from typing import Any, Protocol, Sequence +from typing import Any, Protocol +from collections.abc import Sequence from absl import flags import jax diff --git a/jax/experimental/source_mapper/generate_map.py b/jax/experimental/source_mapper/generate_map.py index 76fd0f744463..0066e35285fb 100644 --- a/jax/experimental/source_mapper/generate_map.py +++ b/jax/experimental/source_mapper/generate_map.py @@ -14,7 +14,8 @@ """Generates source maps for JAX functions.""" import os import tempfile -from typing import Sequence, Protocol +from typing import Protocol +from collections.abc import Sequence from jax.experimental.source_mapper import common diff --git a/jax/experimental/source_mapper/hlo.py b/jax/experimental/source_mapper/hlo.py index 5c9f1c01ac96..ad189c80f0b5 100644 --- a/jax/experimental/source_mapper/hlo.py +++ b/jax/experimental/source_mapper/hlo.py @@ -30,15 +30,17 @@ class HloPass(enum.Enum): METADATA_REGEX = re.compile( - r"metadata={op_name=\"(?P.*)\" source_file=\"(?P.*)\"" - r" source_line=(?P[0-9]+)\}" + r"metadata={.*op_name=\"(?P.*)\"" + r" source_file=\"(?P.*)\"" + r" source_line=(?P[0-9]+).*?}" ) -def parse_hlo_dump(text: str) -> sourcemap.SourceMap: +# TODO(justinfu): Remove when new format is the default. +def _parse_hlo_old_format(lines: list[str]) -> sourcemap.SourceMap: mappings = sourcemap.MappingsGenerator() used_source_files = [] - for line in text.split("\n"): + for line in lines: mappings.new_group() match = METADATA_REGEX.search(line) if match: @@ -63,6 +65,96 @@ def parse_hlo_dump(text: str) -> sourcemap.SourceMap: ) +def _parse_hlo_new_format(lines: list[str]) -> sourcemap.SourceMap: + file_names = {} + file_locations = {} + stack_frames = {} + current_section = None + for line in lines: + line = line.strip() + if not line: + continue + + if line in ["FileNames", "FunctionNames", "FileLocations", "StackFrames"]: + current_section = line + continue + + if current_section == "FileNames": + match = re.match(r"(\d+)\s+\"(.*)\"", line) + if match: + file_names[int(match.group(1))] = match.group(2) + elif current_section == "FileLocations": + # Format: 1 {file_name_id=1 function_name_id=1 line=153 end_line=153 column=2 end_column=31} + match = re.match(r"(\d+)\s+{(.*)}", line) + if match: + loc_id = int(match.group(1)) + attrs = match.group(2) + loc_data = {} + for part in attrs.split(): + if "=" in part: + k, v = part.split("=") + if k not in ["file_name_id", "function_name_id", "line", + "end_line", "column", "end_column"]: + raise ValueError(f"Unknown attribute for FileLocations: {k}") + loc_data[k] = int(v) + file_locations[loc_id] = loc_data + elif current_section == "StackFrames": + # Format: 1 {file_location_id=1 parent_frame_id=1} + match = re.match(r"(\d+)\s+{(.*)}", line) + if match: + frame_id = int(match.group(1)) + attrs = match.group(2) + frame_data = {} + for part in attrs.split(): + if "=" in part: + k, v = part.split("=") + if k not in ["file_location_id", "parent_frame_id"]: + raise ValueError(f"Unknown attribute for StackFrames: {k}") + frame_data[k] = int(v) + stack_frames[frame_id] = frame_data + + mappings = sourcemap.MappingsGenerator() + used_source_files = [] + + for line in lines: + mappings.new_group() + if "metadata={" in line: + match = re.search(r"stack_frame_id=(\d+)", line) + if match: + stack_frame_id = int(match.group(1)) + if stack_frame_id in stack_frames: + frame = stack_frames[stack_frame_id] + file_loc = file_locations.get(frame["file_location_id"]) + if file_loc: + file_name = file_names.get(file_loc["file_name_id"]) + if file_name: + if file_name not in used_source_files: + used_source_files.append(file_name) + src_file_idx = used_source_files.index(file_name) + src_line = file_loc["line"] - 1 + first_col = line.index(line.strip()[0]) + mappings.new_segment(first_col, src_file_idx, src_line, 0) + else: + raise ValueError(f"Could not find mapping for {file_loc=}") + else: + raise ValueError(f"Could not find mapping for {stack_frame_id=}") + mappings.new_group() + return sourcemap.SourceMap( + version=3, + sources=used_source_files, + sources_content=[], + mappings=mappings.mappings(), + names=[], + ) + + +def parse_hlo_dump(text: str) -> sourcemap.SourceMap: + lines = text.split("\n") + if "FileNames" in text: + return _parse_hlo_new_format(lines) + return _parse_hlo_old_format(lines) + + def trace_and_lower(work_dir, f, f_args, f_kwargs, **_): lowered = jax.jit(lambda *args: f(*args, **f_kwargs)).lower(*f_args) return (lowered, work_dir) diff --git a/jax/experimental/sparse/__init__.py b/jax/experimental/sparse/__init__.py index f388cd527cf9..dbd21e343bb7 100644 --- a/jax/experimental/sparse/__init__.py +++ b/jax/experimental/sparse/__init__.py @@ -15,10 +15,17 @@ """ .. currentmodule:: jax.experimental.sparse +.. note:: + + The methods in ``jax.experimental.sparse`` are experimental reference implementations, + and not recommended for use in performance-critical applications. The submodule is no + longer being actively developed, but the team will continue supporting existing features + as best we can. + The :mod:`jax.experimental.sparse` module includes experimental support for sparse matrix -operations in JAX. It is under active development, and the API is subject to change. The -primary interfaces made available are the :class:`BCOO` sparse array type, and the -:func:`sparsify` transform. +operations in JAX. The primary interfaces made available are the :class:`BCOO` sparse array +type, and the :func:`sparsify` transform. + Batched-coordinate (BCOO) sparse matrices ----------------------------------------- diff --git a/jax/experimental/sparse/_base.py b/jax/experimental/sparse/_base.py index 7739af0291f1..36d84cb0db62 100644 --- a/jax/experimental/sparse/_base.py +++ b/jax/experimental/sparse/_base.py @@ -19,18 +19,8 @@ import jax from jax._src import core -from jax._src import ffi from jax._src import util from jax._src.typing import Array -from jax._src.lib import gpu_sparse - - -if hasattr(gpu_sparse, "registrations"): - for platform, targets in gpu_sparse.registrations().items(): - for name, value, api_version in targets: - ffi.register_ffi_target( - name, value, platform=platform, api_version=api_version - ) class JAXSparse(util.StrictABC): diff --git a/jax/experimental/sparse/_lowerings.py b/jax/experimental/sparse/_lowerings.py index 6962ef78bcff..c5a28d7a4607 100644 --- a/jax/experimental/sparse/_lowerings.py +++ b/jax/experimental/sparse/_lowerings.py @@ -18,13 +18,40 @@ """ from functools import partial +from typing import Any from jax._src import core from jax._src import dispatch +from jax._src import ffi from jax._src.interpreters import mlir from jax._src.lib import gpu_sparse +from jax._src.lib import has_cpu_sparse import numpy as np +if hasattr(gpu_sparse, "registrations"): + for platform, targets in gpu_sparse.registrations().items(): + for name, value, api_version in targets: + ffi.register_ffi_target( + name, value, platform=platform, api_version=api_version + ) + +if has_cpu_sparse: + from jax._src.lib import cpu_sparse + + if hasattr(cpu_sparse, "registrations"): + for platform, targets in cpu_sparse.registrations().items(): + for name, value, api_version in targets: + ffi.register_ffi_target( + name, value, platform=platform, api_version=api_version + ) + +def _get_module(target_name_prefix: str) -> Any: + if target_name_prefix == "cu": + return gpu_sparse._cusparse + elif target_name_prefix == "hip": + return gpu_sparse._hipsparse + else: + raise ValueError(f"Unsupported target_name_prefix: {target_name_prefix}") SUPPORTED_DATA_DTYPES = [np.float32, np.float64, np.complex64, np.complex128] SUPPORTED_INDEX_DTYPES = [np.int32] @@ -54,28 +81,29 @@ def _coo_spmv_abstract_eval(data, row, col, x, *, transpose, shape): shape=shape[1:] if transpose else shape[:1], dtype=x.dtype) -def _coo_spmv_gpu_lowering(coo_spmv_hlo, ctx, data, row, col, x, *, transpose, shape): +def _coo_spmv_gpu_lowering(ctx, data, row, col, x, *, transpose, shape, + target_name_prefix): + rows, cols = shape data_aval, row_aval, _, x_aval = ctx.avals_in - return [coo_spmv_hlo( - data, row, col, x, - shape=shape, - transpose=transpose, - data_dtype=data_aval.dtype, - index_dtype=row_aval.dtype, - x_dtype=x_aval.dtype)] + nnz, = data_aval.shape + buffer_size, opaque = _get_module(target_name_prefix).build_coo_matvec_descriptor( + data_aval.dtype, x_aval.dtype, data_aval.dtype, row_aval.dtype, + rows, cols, nnz, transpose) + buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8) + sub_ctx = ctx.replace(avals_out=[ctx.avals_out[0], buffer_aval]) + rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_coo_matvec_ffi") + return rule(sub_ctx, data, row, col, x, opaque=opaque)[:1] coo_spmv_p.def_abstract_eval(_coo_spmv_abstract_eval) dispatch.simple_impl(coo_spmv_p) -if gpu_sparse.cuda_is_supported: - mlir.register_lowering( - coo_spmv_p, - partial(_coo_spmv_gpu_lowering, gpu_sparse.cuda_coo_matvec), - platform='cuda') -if gpu_sparse.rocm_is_supported: - mlir.register_lowering( - coo_spmv_p, - partial(_coo_spmv_gpu_lowering, gpu_sparse.rocm_coo_matvec), - platform='rocm') +mlir.register_lowering( + coo_spmv_p, + partial(_coo_spmv_gpu_lowering, target_name_prefix='cu'), + platform='cuda') +mlir.register_lowering( + coo_spmv_p, + partial(_coo_spmv_gpu_lowering, target_name_prefix='hip'), + platform='rocm') # coo_spmm_p @@ -103,28 +131,50 @@ def _coo_spmm_abstract_eval(data, row, col, x, *, transpose, shape): shape=(shape[1] if transpose else shape[0], x.shape[1]), dtype=x.dtype) -def _coo_spmm_gpu_lowering(coo_spmm_hlo, ctx, data, row, col, x, *, transpose, shape): +def _coo_spmm_gpu_lowering(ctx, data, row, col, x, *, transpose, shape, + target_name_prefix): data_aval, row_aval, _, x_aval = ctx.avals_in - return [coo_spmm_hlo( - data, row, col, x, - shape=shape, - transpose=transpose, - data_dtype=data_aval.dtype, - index_dtype=row_aval.dtype, - x_dtype=x_aval.dtype)] + nnz, = data_aval.shape + _, Ccols = x_aval.shape + + batch_count = 1 + if len(shape) == 2: + rows, cols = shape + elif len(shape) == 3: + batch_count, rows, cols = shape + nnz = nnz // batch_count + else: + raise NotImplementedError(f"Unsupported shape: {shape}") + + # TODO(tianjianlu): use batch stride to trigger different mode of batch + # computation. Currently batch_stride = 0 is not allowed because of the issue + # in cusparse https://github.com/NVIDIA/CUDALibrarySamples/issues/81#issuecomment-1205562643 + # Set batch stride to be the matrix size for now. + lhs_batch_stride = nnz + B_rows = rows if transpose else cols + rhs_batch_stride = B_rows * Ccols + + buffer_size, opaque = _get_module(target_name_prefix).build_coo_matmat_descriptor( + data_aval.dtype, x_aval.dtype, data_aval.dtype, row_aval.dtype, + rows, cols, Ccols, nnz, transpose, batch_count, lhs_batch_stride, + rhs_batch_stride) + + buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8) + sub_ctx = ctx.replace(avals_out=[ctx.avals_out[0], buffer_aval]) + rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_coo_matmat_ffi") + return rule(sub_ctx, data, row, col, x, opaque=opaque)[:1] + coo_spmm_p.def_abstract_eval(_coo_spmm_abstract_eval) dispatch.simple_impl(coo_spmm_p) -if gpu_sparse.cuda_is_supported: - mlir.register_lowering( - coo_spmm_p, - partial(_coo_spmm_gpu_lowering, gpu_sparse.cuda_coo_matmat), - platform='cuda') -if gpu_sparse.rocm_is_supported: - mlir.register_lowering( - coo_spmm_p, - partial(_coo_spmm_gpu_lowering, gpu_sparse.rocm_coo_matmat), - platform='rocm') +mlir.register_lowering( + coo_spmm_p, + partial(_coo_spmm_gpu_lowering, target_name_prefix='cu'), + platform='cuda') +mlir.register_lowering( + coo_spmm_p, + partial(_coo_spmm_gpu_lowering, target_name_prefix='hip'), + platform='rocm') # csr_spmv_p # This is an internal-only primitive that calls into cusparse csr SpMV. @@ -151,30 +201,31 @@ def _csr_spmv_abstract_eval(data, indices, indptr, x, *, transpose, shape): shape=shape[1:] if transpose else shape[:1], dtype=x.dtype) -def _csr_spmv_gpu_lowering(csr_spmv_hlo, ctx, data, indices, indptr, x, *, transpose, shape): +def _csr_spmv_gpu_lowering(ctx, data, indices, indptr, x, *, transpose, shape, + target_name_prefix): + rows, cols = shape data_aval, indices_aval, _, x_aval = ctx.avals_in - return [csr_spmv_hlo( - data, indices, indptr, x, - shape=shape, - transpose=transpose, - data_dtype=data_aval.dtype, - index_dtype=indices_aval.dtype, - x_dtype=x_aval.dtype)] + nnz, = data_aval.shape + buffer_size, opaque = _get_module(target_name_prefix).build_csr_matvec_descriptor( + data_aval.dtype, x_aval.dtype, data_aval.dtype, indices_aval.dtype, + rows, cols, nnz, transpose) + buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8) + sub_ctx = ctx.replace(avals_out=[ctx.avals_out[0], buffer_aval]) + rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_csr_matvec_ffi") + return rule(sub_ctx, data, indices, indptr, x, opaque=opaque)[:1] csr_spmv_p.def_abstract_eval(_csr_spmv_abstract_eval) dispatch.simple_impl(csr_spmv_p) -if gpu_sparse.cuda_is_supported: - mlir.register_lowering( - csr_spmv_p, - partial(_csr_spmv_gpu_lowering, gpu_sparse.cuda_csr_matvec), - platform='cuda') -if gpu_sparse.rocm_is_supported: - mlir.register_lowering( - csr_spmv_p, - partial(_csr_spmv_gpu_lowering, gpu_sparse.rocm_csr_matvec), - platform='rocm') - - # csr_spmm_p +mlir.register_lowering( + csr_spmv_p, + partial(_csr_spmv_gpu_lowering, target_name_prefix='cu'), + platform='cuda') +mlir.register_lowering( + csr_spmv_p, + partial(_csr_spmv_gpu_lowering, target_name_prefix='hip'), + platform='rocm') + +# csr_spmm_p # This is an internal-only primitive that calls into cusparse CSR SpMM. # This is a raw lowering that does no validation of inputs; the indices are # assumed to be lexicographically sorted, deduplicated, and in-bounds. @@ -199,25 +250,89 @@ def _csr_spmm_abstract_eval(data, indices, indptr, x, *, transpose, shape): shape=(shape[1] if transpose else shape[0], x.shape[1]), dtype=x.dtype) -def _csr_spmm_gpu_lowering(csr_spmm_hlo, ctx, data, indices, indptr, x, *, transpose, shape): +def _csr_spmm_gpu_lowering(ctx, data, indices, indptr, x, *, transpose, shape, + target_name_prefix): + rows, cols = shape data_aval, indices_aval, _, x_aval = ctx.avals_in - return [csr_spmm_hlo( - data, indices, indptr, x, - shape=shape, - transpose=transpose, - data_dtype=data_aval.dtype, - index_dtype=indices_aval.dtype, - B_dtype=x_aval.dtype)] + nnz, = data_aval.shape + _, Ccols = x_aval.shape + buffer_size, opaque = _get_module(target_name_prefix).build_csr_matmat_descriptor( + data_aval.dtype, x_aval.dtype, data_aval.dtype, indices_aval.dtype, + rows, cols, Ccols, nnz, transpose) + buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8) + sub_ctx = ctx.replace(avals_out=[ctx.avals_out[0], buffer_aval]) + rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_csr_matmat_ffi") + return rule(sub_ctx, data, indices, indptr, x, opaque=opaque)[:1] csr_spmm_p.def_abstract_eval(_csr_spmm_abstract_eval) dispatch.simple_impl(csr_spmm_p) -if gpu_sparse.cuda_is_supported: +mlir.register_lowering( + csr_spmm_p, + partial(_csr_spmm_gpu_lowering, target_name_prefix='cu'), + platform='cuda') +mlir.register_lowering( + csr_spmm_p, + partial(_csr_spmm_gpu_lowering, target_name_prefix='hip'), + platform='rocm') + + +if has_cpu_sparse: + def _csr_spmm_cpu_lowering(ctx, data, outer_indices, inner_indices, rhs): + rule = ffi.ffi_lowering("cpu_csr_sparse_dense_ffi") + return rule(ctx, data, outer_indices, inner_indices, rhs) + + + # _csr_spmm_cpu_lowering can handle both matrix-matrix and matrix-vector + # multiplication. mlir.register_lowering( - csr_spmm_p, - partial(_csr_spmm_gpu_lowering, gpu_sparse.cuda_csr_matmat), - platform='cuda') -if gpu_sparse.rocm_is_supported: + csr_spmv_p, + _csr_spmm_cpu_lowering, + platform="cpu", + ) mlir.register_lowering( - csr_spmm_p, - partial(_csr_spmm_gpu_lowering, gpu_sparse.rocm_csr_matmat), - platform='rocm') + csr_spmm_p, + _csr_spmm_cpu_lowering, + platform="cpu", + ) + +def coo_todense_gpu_lowering(ctx, data, row, col, *, shape, target_name_prefix): + data_aval, row_aval, _ = ctx.avals_in + nnz, = data_aval.shape + rows, cols = shape + buffer_size, opaque = _get_module(target_name_prefix).build_coo_todense_descriptor( + data_aval.dtype, row_aval.dtype, rows, cols, nnz) + buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8) + sub_ctx = ctx.replace(avals_out=[ctx.avals_out[0], buffer_aval]) + rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_coo_todense_ffi") + return rule(sub_ctx, data, row, col, opaque=opaque)[0] + +def coo_fromdense_gpu_lowering(ctx, mat, *, nnz, index_dtype, target_name_prefix): + mat_aval, = ctx.avals_in + rows, cols = mat_aval.shape + buffer_size, opaque = _get_module(target_name_prefix).build_coo_fromdense_descriptor( + mat_aval.dtype, np.dtype(index_dtype), rows, cols, nnz) + buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8) + sub_ctx = ctx.replace(avals_out=[*ctx.avals_out, buffer_aval]) + rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_coo_fromdense_ffi") + return rule(sub_ctx, mat, opaque=opaque)[:3] + +def csr_todense_gpu_lowering(ctx, data, indices, indptr, *, shape, target_name_prefix): + data_aval, indices_aval, _, = ctx.avals_in + nnz, = data_aval.shape + rows, cols = shape + buffer_size, opaque = _get_module(target_name_prefix).build_csr_todense_descriptor( + data_aval.dtype, indices_aval.dtype, rows, cols, nnz) + buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8) + sub_ctx = ctx.replace(avals_out=[ctx.avals_out[0], buffer_aval]) + rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_csr_todense_ffi") + return rule(sub_ctx, data, indices, indptr, opaque=opaque)[0] + +def csr_fromdense_gpu_lowering(ctx, mat, *, nnz, index_dtype, target_name_prefix): + mat_aval, = ctx.avals_in + rows, cols = mat_aval.shape + buffer_size, opaque = _get_module(target_name_prefix).build_csr_fromdense_descriptor( + mat_aval.dtype, np.dtype(index_dtype), rows, cols, nnz) + buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8) + sub_ctx = ctx.replace(avals_out=[*ctx.avals_out, buffer_aval]) + rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_csr_fromdense_ffi") + return rule(sub_ctx, mat, opaque=opaque)[:3] diff --git a/jax/experimental/sparse/ad.py b/jax/experimental/sparse/ad.py index 018047e3d5e1..861ef5289cdd 100644 --- a/jax/experimental/sparse/ad.py +++ b/jax/experimental/sparse/ad.py @@ -22,7 +22,7 @@ from jax._src import core from jax import tree_util from jax._src.api_util import _ensure_index, _ensure_index_tuple -from jax.util import safe_zip +from jax._src.util import safe_zip from jax._src.util import split_list, wraps from jax._src.traceback_util import api_boundary from jax.experimental.sparse._base import JAXSparse diff --git a/jax/experimental/sparse/bcoo.py b/jax/experimental/sparse/bcoo.py index 42820fe73651..60c217fc6113 100644 --- a/jax/experimental/sparse/bcoo.py +++ b/jax/experimental/sparse/bcoo.py @@ -38,7 +38,7 @@ from jax.experimental.sparse._lowerings import coo_spmv_p, coo_spmm_p from jax._src.interpreters import mlir import jax.numpy as jnp -from jax.util import safe_zip, unzip2, split_list +from jax._src.util import safe_zip, unzip2, split_list from jax._src import api_util from jax._src import config from jax._src import core @@ -49,7 +49,6 @@ from jax._src.lax.lax import ( _const, ranges_like, remaining, _dot_general_batch_dim_nums, DotDimensionNumbers) from jax._src.lax.slicing import GatherDimensionNumbers, GatherScatterMode -from jax._src.lib import gpu_sparse from jax._src.numpy.setops import _unique from jax._src.typing import Array, ArrayLike, DTypeLike from jax._src.util import canonicalize_axis @@ -923,12 +922,10 @@ def _bcoo_dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers, mlir.register_lowering(bcoo_dot_general_p, _bcoo_dot_general_default_lowering) dispatch.simple_impl(bcoo_dot_general_p) -if gpu_sparse.cuda_is_supported: - mlir.register_lowering( - bcoo_dot_general_p, _bcoo_dot_general_gpu_lowering, platform='cuda') -if gpu_sparse.rocm_is_supported: - mlir.register_lowering( - bcoo_dot_general_p, _bcoo_dot_general_gpu_lowering, platform='rocm') +mlir.register_lowering( + bcoo_dot_general_p, _bcoo_dot_general_gpu_lowering, platform='cuda') +mlir.register_lowering( + bcoo_dot_general_p, _bcoo_dot_general_gpu_lowering, platform='rocm') #---------------------------------------------------------------------- @@ -1043,7 +1040,7 @@ def _bcoo_dot_general_sampled_impl(A, B, indices, *, dimension_numbers): @bcoo_dot_general_sampled_p.def_abstract_eval def _bcoo_dot_general_sampled_abstract_eval(A, B, indices, *, dimension_numbers): dbg = api_util.debug_info("bcoo_dot_general_sampled_abstract_eval", - lax.dot_general, (A, B), dict(dimension_numbers=dimension_numbers)) + lax.dot_general, (A, B), {}) dense_result, = pe.abstract_eval_fun(lambda *args: [lax.dot_general(*args, dimension_numbers=dimension_numbers)], A, B, debug_info=dbg) dbg = api_util.debug_info("bcoo_dot_general_sampled_abstract_eval", @@ -1418,7 +1415,7 @@ def _bcoo_sum_duplicates_impl(data, indices, *, spinfo, nse): nse = 1 if props.n_sparse == 0 else nse_batched.max() indices_out = _adjust_indices_nse(indices_out, nse=nse, shape=spinfo.shape) if props.n_sparse == 0: - data = data.sum(props.n_batch, keepdims=True) + data = data.sum(props.n_batch, keepdims=True, dtype=data.dtype) data_out = jnp.empty((*map(max, indices.shape[:props.n_batch], data.shape[:props.n_batch]), nse, *data.shape[props.n_batch + 1:]), dtype=data.dtype) permute = lambda d_out, m, d: d_out.at[m].add(d, mode='drop') @@ -1539,8 +1536,8 @@ def _bcoo_sum_duplicates_jvp(primals, tangents, *, spinfo, nse): "jit, vmap, and other transformations requiring abstract evaluation.") indices_out = _adjust_indices_nse(indices_out, nse=nse, shape=spinfo.shape) if props.n_sparse == 0: - data = data.sum(props.n_batch, keepdims=True) - data_dot = data_dot.sum(props.n_batch, keepdims=True) + data = data.sum(props.n_batch, keepdims=True, dtype=data.dtype) + data_dot = data_dot.sum(props.n_batch, keepdims=True, dtype=data_dot.dtype) data_out = jnp.empty((*map(max, indices.shape[:props.n_batch], data.shape[:props.n_batch]), nse, *data.shape[props.n_batch + 1:]), dtype=data.dtype) data_dot_out = data_out @@ -1835,7 +1832,7 @@ def bcoo_concatenate(operands: Sequence[BCOO], *, dimension: int) -> BCOO: def bcoo_reshape(mat: BCOO, *, new_sizes: Sequence[int], dimensions: Sequence[int] | None = None, sharding=None) -> BCOO: - """Sparse implementation of {func}`jax.lax.reshape`. + """Sparse implementation of :func:`jax.lax.reshape`. Args: operand: BCOO array to be reshaped. @@ -1898,7 +1895,7 @@ def bcoo_reshape(mat: BCOO, *, new_sizes: Sequence[int], def bcoo_rev(operand, dimensions): - """Sparse implementation of {func}`jax.lax.rev`""" + """Sparse implementation of :func:`jax.lax.rev`""" # Check validity of dimensions via original implementation. _ = jax.jit(lax.rev, static_argnames=("dimensions",)).eval_shape( jax.ShapeDtypeStruct(operand.shape, operand.dtype), @@ -1926,7 +1923,7 @@ def bcoo_rev(operand, dimensions): def bcoo_squeeze(arr: BCOO, *, dimensions: Sequence[int]) -> BCOO: - """Sparse implementation of {func}`jax.lax.squeeze`. + """Sparse implementation of :func:`jax.lax.squeeze`. Squeeze any number of size 1 dimensions from an array. @@ -1955,7 +1952,7 @@ def bcoo_squeeze(arr: BCOO, *, dimensions: Sequence[int]) -> BCOO: def bcoo_slice(mat: BCOO, *, start_indices: Sequence[int], limit_indices: Sequence[int], strides: Sequence[int] | None = None) -> BCOO: - """Sparse implementation of {func}`jax.lax.slice`. + """Sparse implementation of :func:`jax.lax.slice`. Args: mat: BCOO array to be reshaped. @@ -2031,7 +2028,7 @@ def bcoo_slice(mat: BCOO, *, start_indices: Sequence[int], limit_indices: Sequen return BCOO((new_data, new_indices), shape=new_shape) def bcoo_dynamic_slice(mat: BCOO, start_indices: Sequence[Any], slice_sizes: Sequence[int]) -> BCOO: - """Sparse implementation of {func}`jax.lax.dynamic_slice`. + """Sparse implementation of :func:`jax.lax.dynamic_slice`. Args: mat: BCOO array to slice. @@ -2358,14 +2355,16 @@ def slice_func(indices): def bcoo_conv_general_dilated(lhs, rhs, *, window_strides, padding, lhs_dilation=None, rhs_dilation=None, dimension_numbers=None, feature_group_count=1, batch_group_count=1, precision=None, - preferred_element_type=None) -> BCOO: + preferred_element_type=None, + out_sharding=None) -> BCOO: # Validate and process parameters using lax.conv_general_dilated abstract evaluation. func = functools.partial( lax.conv_general_dilated, window_strides=window_strides, padding=padding, lhs_dilation=lhs_dilation, rhs_dilation=rhs_dilation, dimension_numbers=dimension_numbers, feature_group_count=feature_group_count, batch_group_count=batch_group_count, - precision=precision, preferred_element_type=preferred_element_type) + precision=precision, preferred_element_type=preferred_element_type, + out_sharding=out_sharding) jaxpr = jax.make_jaxpr(func)(jax.ShapeDtypeStruct(lhs.shape, lhs.dtype), jax.ShapeDtypeStruct(rhs.shape, rhs.dtype)) assert isinstance(jaxpr, core.ClosedJaxpr) and len(jaxpr.eqns) == 1 diff --git a/jax/experimental/sparse/bcsr.py b/jax/experimental/sparse/bcsr.py index 7fefd1572f45..cfa52fe35dfe 100644 --- a/jax/experimental/sparse/bcsr.py +++ b/jax/experimental/sparse/bcsr.py @@ -27,19 +27,19 @@ import jax.numpy as jnp from jax import lax from jax import tree_util +from jax.experimental.sparse import _lowerings from jax.experimental.sparse._base import JAXSparse from jax.experimental.sparse import bcoo from jax.experimental.sparse.util import ( - nfold_vmap, _count_stored_elements, - _csr_to_coo, CuSparseEfficiencyWarning, SparseInfo, Shape) -from jax.util import split_list, safe_zip + nfold_vmap, _count_stored_elements, _csr_to_coo, + SparseEfficiencyWarning, CuSparseEfficiencyWarning, SparseInfo, Shape) +from jax._src.util import split_list, safe_zip from jax._src import api_util from jax._src import config from jax._src import core from jax._src import dispatch from jax._src.lax.lax import DotDimensionNumbers, _dot_general_batch_dim_nums -from jax._src.lib import gpu_sparse from jax._src.lib.mlir.dialects import hlo from jax._src.interpreters import ad from jax._src.interpreters import batching @@ -144,7 +144,7 @@ def _bcsr_to_bcoo(indices: jax.Array, indptr: jax.Array, *, def _bcoo_to_bcsr(indices: Array, *, shape: Sequence[int], - index_dtype: DTypeLike = jnp.int32) -> tuple[Array, Array]: + index_dtype: DTypeLike) -> tuple[Array, Array]: """Given BCOO (indices), return BCSR (indices, indptr). Note: this assumes that ``indices`` are lexicographically sorted within each batch. @@ -237,7 +237,9 @@ def _bcsr_fromdense_impl(mat, *, nse, n_batch, n_dense, index_dtype): raise ValueError("bcsr_fromdense: must have 2 sparse dimensions.") bcoo_mat = bcoo.bcoo_fromdense(mat, nse=nse, index_dtype=index_dtype, n_dense=n_dense, n_batch=n_batch) - indices, indptr = _bcoo_to_bcsr(bcoo_mat.indices, shape=mat.shape) + indices, indptr = _bcoo_to_bcsr( + bcoo_mat.indices, shape=mat.shape, index_dtype=index_dtype + ) return bcoo_mat.data, indices, indptr @@ -620,9 +622,9 @@ def _bcsr_correct_out_of_bound_indices(data, indices, indptr, rhs, *, shape): _bcsr_correct_out_of_bound_indices, multiple_results=True) def _bcsr_dot_general_gpu_lowering( - csr_matvec_lowering, csr_matmat_lowering, + # csr_matvec_lowering, csr_matmat_lowering, ctx, lhs_data, lhs_indices, lhs_indptr, rhs, *, dimension_numbers, - preferred_element_type, lhs_spinfo: SparseInfo): + preferred_element_type, lhs_spinfo: SparseInfo, target_name_prefix): if not config.bcoo_cusparse_lowering.value: return _bcsr_dot_general_default_lowering( @@ -674,22 +676,112 @@ def _bcsr_dot_general_gpu_lowering( lhs_data, lhs_indices = _bcsr_correct_out_of_bound_indices_lowered( ctx, lhs_data, lhs_indices, lhs_indptr, rhs, shape=lhs_spinfo.shape) + sub_ctx = ctx if rhs_aval.ndim == 1: - dot_general_fn = csr_matvec_lowering - x_dtype = 'x_dtype' + dot_general_fn = _lowerings._csr_spmv_gpu_lowering elif rhs_aval.ndim == 2: - dot_general_fn = csr_matmat_lowering - x_dtype = 'B_dtype' + dot_general_fn = _lowerings._csr_spmm_gpu_lowering if rhs_contract[0] == 1: rhs = hlo.transpose(rhs, permutation=mlir.dense_int_array([1, 0])) + *avals_in, rhs_aval = sub_ctx.avals_in + rhs_aval = core.ShapedArray( + shape=(rhs_aval.shape[1], rhs_aval.shape[0]), dtype=rhs_aval.dtype) + sub_ctx = sub_ctx.replace(avals_in=[*avals_in, rhs_aval]) else: raise ValueError(f"rhs has to be 1d or 2d; get {rhs_aval.ndim}d.") - return [dot_general_fn(lhs_data, lhs_indices, lhs_indptr, rhs, - shape=lhs_spinfo.shape, transpose=False, - data_dtype=lhs_data_aval.dtype, - index_dtype=lhs_indices_aval.dtype, - **{x_dtype: rhs_aval.dtype})] + return dot_general_fn(sub_ctx, lhs_data, lhs_indices, lhs_indptr, rhs, + shape=lhs_spinfo.shape, transpose=False, + target_name_prefix=target_name_prefix) + + +def _bcsr_dot_general_cpu_lowering( + # csr_matvec_lowering, csr_matmat_lowering, + ctx, + lhs_data, + lhs_indices, + lhs_indptr, + rhs, + *, + dimension_numbers, + preferred_element_type, + lhs_spinfo: SparseInfo, +): + + (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers + lhs_data_aval, lhs_indices_aval, lhs_indptr_aval, rhs_aval = ctx.avals_in + props = _validate_bcsr( + lhs_data_aval, lhs_indices_aval, lhs_indptr_aval, lhs_spinfo.shape + ) + + use_default_lowering = False + dtype = lhs_data_aval.dtype + if lhs_batch or rhs_batch: + # TODO(willfroom): Add support for batched matrices. + use_default_lowering = True + elif lhs_data_aval.dtype != rhs_aval.dtype: + use_default_lowering = True + elif ( + preferred_element_type is not None + and preferred_element_type != lhs_data_aval.dtype + ): + use_default_lowering = True + elif len(lhs_spinfo.shape) != 2 or rhs_aval.ndim not in [1, 2]: + # only matmat / matvec supported + use_default_lowering = True + elif props.n_batch or props.n_dense: + # batch and dense dimensions in BCSR not supported + use_default_lowering = True + elif list(lhs_contract) != [1] or list(rhs_contract) != [0]: + # TODO(willfroom): Add support for non-canonical dots. + use_default_lowering = True + elif lhs_indices_aval.dtype != lhs_indptr_aval.dtype: + warnings.warn( + "bcsr_dot_general cpu lowering not available, " + f" {lhs_indices_aval.dtype=} and {lhs_indptr_aval.dtype=} do not match." + " Falling back to default implementation.", + SparseEfficiencyWarning, + ) + use_default_lowering = True + elif lhs_indices_aval.dtype not in [np.int32, np.int64]: + use_default_lowering = True + warnings.warn( + "bcsr_dot_general cpu lowering not available for" + f" {lhs_indices_aval.dtype=}. Falling back to default implementation.", + SparseEfficiencyWarning, + ) + elif dtype not in [ + np.int32, + np.int64, + np.float32, + np.float64, + np.complex64, + np.complex128, + ]: + # This would be supported if not for the dtype. + warnings.warn( + "bcsr_dot_general cpu lowering not available " + f"for {dtype=}. Falling back to default implementation.", + SparseEfficiencyWarning, + ) + use_default_lowering = True + + if use_default_lowering: + return _bcsr_dot_general_default_lowering( + ctx, + lhs_data, + lhs_indices, + lhs_indptr, + rhs, + dimension_numbers=dimension_numbers, + preferred_element_type=preferred_element_type, + lhs_spinfo=lhs_spinfo, + ) + + return _lowerings._csr_spmm_cpu_lowering( + ctx, lhs_data, lhs_indptr, lhs_indices, rhs + ) + _bcsr_dot_general_default_lowering = mlir.lower_fun( _bcsr_dot_general_impl, multiple_results=False) @@ -697,19 +789,20 @@ def _bcsr_dot_general_gpu_lowering( bcsr_dot_general_p, _bcsr_dot_general_default_lowering) dispatch.simple_impl(bcsr_dot_general_p) -if gpu_sparse.cuda_is_supported: - mlir.register_lowering(bcsr_dot_general_p, - partial(_bcsr_dot_general_gpu_lowering, - gpu_sparse.cuda_csr_matvec, - gpu_sparse.cuda_csr_matmat), - platform='cuda') -if gpu_sparse.rocm_is_supported: - mlir.register_lowering(bcsr_dot_general_p, - partial(_bcsr_dot_general_gpu_lowering, - gpu_sparse.rocm_csr_matvec, - gpu_sparse.rocm_csr_matmat), - platform='rocm') +mlir.register_lowering(bcsr_dot_general_p, + partial(_bcsr_dot_general_gpu_lowering, + target_name_prefix='cu'), + platform='cuda') +mlir.register_lowering(bcsr_dot_general_p, + partial(_bcsr_dot_general_gpu_lowering, + target_name_prefix='hip'), + platform='rocm') + +if _lowerings.has_cpu_sparse: + mlir.register_lowering( + bcsr_dot_general_p, _bcsr_dot_general_cpu_lowering, platform="cpu" + ) #---------------------------------------------------------------------- # BCOO functions that maybe should be primitives? @@ -867,7 +960,9 @@ def from_bcoo(cls, arr: bcoo.BCOO) -> BCSR: raise NotImplementedError(f"BSCR.from_bcoo requires n_sparse=2; got {arr.n_sparse=}") if not arr.indices_sorted: arr = arr.sort_indices() - indices, indptr = _bcoo_to_bcsr(arr.indices, shape=arr.shape) + indices, indptr = _bcoo_to_bcsr( + arr.indices, shape=arr.shape, index_dtype=arr.indices.dtype + ) return cls((arr.data, indices, indptr), shape=arr.shape) @classmethod diff --git a/jax/experimental/sparse/coo.py b/jax/experimental/sparse/coo.py index c65bc87235d6..d29a80218e66 100644 --- a/jax/experimental/sparse/coo.py +++ b/jax/experimental/sparse/coo.py @@ -26,6 +26,7 @@ import jax from jax import lax from jax.interpreters import mlir +from jax.experimental.sparse import _lowerings from jax.experimental.sparse._base import JAXSparse from jax.experimental.sparse.util import _coo_extract, CuSparseEfficiencyWarning from jax import tree_util @@ -34,7 +35,6 @@ from jax._src.interpreters import ad from jax._src.lax.lax import _const from jax._src.lib.mlir.dialects import hlo -from jax._src.lib import gpu_sparse from jax._src.numpy.util import promote_dtypes from jax._src.typing import Array, ArrayLike, DTypeLike import jax.numpy as jnp @@ -205,7 +205,7 @@ def _coo_todense_abstract_eval(data, row, col, *, spinfo): _coo_todense_lowering = mlir.lower_fun( _coo_todense_impl, multiple_results=False) -def _coo_todense_gpu_lowering(coo_todense_hlo, ctx, data, row, col, *, spinfo): +def _coo_todense_gpu_lowering(ctx, data, row, col, *, spinfo, target_name_prefix): data_aval, row_aval, _ = ctx.avals_in dtype = data_aval.dtype if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)): @@ -226,8 +226,13 @@ def _coo_todense_gpu_lowering(coo_todense_hlo, ctx, data, row, col, *, spinfo): "back to the default implementation.", CuSparseEfficiencyWarning) return _coo_todense_lowering(ctx, data, row, col, spinfo=spinfo) - result = coo_todense_hlo( - data, row, col, shape=shape, data_dtype=dtype, index_dtype=row_aval.dtype) + sub_ctx = ctx + if transpose: + out_aval, = ctx.avals_out + out_aval = core.ShapedArray(shape=out_aval.shape[::-1], dtype=out_aval.dtype) + sub_ctx = sub_ctx.replace(avals_out=[out_aval]) + result = _lowerings.coo_todense_gpu_lowering( + sub_ctx, data, row, col, shape=shape, target_name_prefix=target_name_prefix) return ( [hlo.transpose(result, mlir.dense_int_array([1, 0]))] if transpose else [result]) @@ -252,16 +257,14 @@ def _coo_todense_transpose(ct, data, row, col, *, spinfo): mlir.register_lowering(coo_todense_p, _coo_todense_lowering) dispatch.simple_impl(coo_todense_p) -if gpu_sparse.cuda_is_supported: - mlir.register_lowering( - coo_todense_p, - partial(_coo_todense_gpu_lowering, gpu_sparse.cuda_coo_todense), - platform='cuda') -if gpu_sparse.rocm_is_supported: - mlir.register_lowering( - coo_todense_p, - partial(_coo_todense_gpu_lowering, gpu_sparse.rocm_coo_todense), - platform='rocm') +mlir.register_lowering( + coo_todense_p, + partial(_coo_todense_gpu_lowering, target_name_prefix='cu'), + platform='cuda') +mlir.register_lowering( + coo_todense_p, + partial(_coo_todense_gpu_lowering, target_name_prefix='hip'), + platform='rocm') #-------------------------------------------------------------------- # coo_fromdense @@ -325,20 +328,15 @@ def _coo_fromdense_abstract_eval(mat, *, nse, index_dtype): _coo_fromdense_lowering = mlir.lower_fun( _coo_fromdense_impl, multiple_results=True) -def _coo_fromdense_gpu_lowering(coo_fromdense_hlo, ctx, mat, *, nse, - index_dtype): +def _coo_fromdense_gpu_lowering(ctx, mat, *, nse, index_dtype, target_name_prefix): dtype = ctx.avals_in[0].dtype if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)): warnings.warn(f"coo_fromdense cusparse/hipsparse lowering not available for {dtype=}. " "Falling back to default implementation.", CuSparseEfficiencyWarning) return _coo_fromdense_lowering(ctx, mat, nse=nse, index_dtype=index_dtype) - data, row, col = coo_fromdense_hlo( - mat, nnz=nse, - data_dtype=dtype, - index_dtype=np.dtype(index_dtype), - index_type=mlir.dtype_to_ir_type(np.dtype(index_dtype))) - return [data, row, col] - + return _lowerings.coo_fromdense_gpu_lowering( + ctx, mat, nnz=nse, index_dtype=index_dtype, + target_name_prefix=target_name_prefix) def _coo_fromdense_jvp(primals, tangents, *, nse, index_dtype): M, = primals @@ -370,16 +368,14 @@ def _coo_fromdense_transpose(ct, M, *, nse, index_dtype): mlir.register_lowering(coo_fromdense_p, _coo_fromdense_lowering) dispatch.simple_impl(coo_fromdense_p) -if gpu_sparse.cuda_is_supported: - mlir.register_lowering( - coo_fromdense_p, - partial(_coo_fromdense_gpu_lowering, gpu_sparse.cuda_coo_fromdense), - platform='cuda') -if gpu_sparse.rocm_is_supported: - mlir.register_lowering( - coo_fromdense_p, - partial(_coo_fromdense_gpu_lowering, gpu_sparse.rocm_coo_fromdense), - platform='rocm') +mlir.register_lowering( + coo_fromdense_p, + partial(_coo_fromdense_gpu_lowering, target_name_prefix='cu'), + platform='cuda') +mlir.register_lowering( + coo_fromdense_p, + partial(_coo_fromdense_gpu_lowering, target_name_prefix='hip'), + platform='rocm') #-------------------------------------------------------------------- # coo_matvec @@ -444,8 +440,8 @@ def _coo_matvec_abstract_eval(data, row, col, v, *, spinfo, transpose): _coo_matvec_lowering = mlir.lower_fun( _coo_matvec_impl, multiple_results=False) -def _coo_matvec_gpu_lowering(coo_matvec_hlo, ctx, data, row, col, v, *, spinfo, - transpose): +def _coo_matvec_gpu_lowering(ctx, data, row, col, v, *, spinfo, transpose, + target_name_prefix): data_aval, row_aval, _, x_aval = ctx.avals_in dtype = data_aval.dtype if dtype not in [np.float32, np.float64, np.complex64, np.complex128]: @@ -466,9 +462,9 @@ def _coo_matvec_gpu_lowering(coo_matvec_hlo, ctx, data, row, col, v, *, spinfo, return _coo_matvec_lowering(ctx, data, row, col, v, spinfo=spinfo, transpose=transpose) - return [coo_matvec_hlo( - data, row, col, v, shape=shape, transpose=transpose, - index_dtype=row_aval.dtype, data_dtype=dtype, x_dtype=x_aval.dtype)] + return _lowerings._coo_spmv_gpu_lowering( + ctx, data, row, col, v, transpose=transpose, shape=shape, + target_name_prefix=target_name_prefix) def _coo_matvec_jvp_mat(data_dot, data, row, col, v, *, spinfo, transpose): @@ -494,16 +490,14 @@ def _coo_matvec_transpose(ct, data, row, col, v, *, spinfo, transpose): mlir.register_lowering(coo_matvec_p, _coo_matvec_lowering) dispatch.simple_impl(coo_matvec_p) -if gpu_sparse.cuda_is_supported: - mlir.register_lowering( - coo_matvec_p, - partial(_coo_matvec_gpu_lowering, gpu_sparse.cuda_coo_matvec), - platform='cuda') -if gpu_sparse.rocm_is_supported: - mlir.register_lowering( - coo_matvec_p, - partial(_coo_matvec_gpu_lowering, gpu_sparse.rocm_coo_matvec), - platform='rocm') +mlir.register_lowering( + coo_matvec_p, + partial(_coo_matvec_gpu_lowering, target_name_prefix='cu'), + platform='cuda') +mlir.register_lowering( + coo_matvec_p, + partial(_coo_matvec_gpu_lowering, target_name_prefix='hip'), + platform='rocm') #-------------------------------------------------------------------- @@ -567,8 +561,8 @@ def _coo_matmat_abstract_eval(data, row, col, B, *, spinfo, transpose): _coo_matmat_lowering = mlir.lower_fun(_coo_matmat_impl, multiple_results=False) -def _coo_matmat_gpu_lowering(coo_matmat_hlo, ctx, data, row, col, B, *, spinfo, - transpose): +def _coo_matmat_gpu_lowering(ctx, data, row, col, B, *, spinfo, transpose, + target_name_prefix): data_aval, row_aval, _, B_aval = ctx.avals_in dtype = data_aval.dtype if dtype not in [np.float32, np.float64, np.complex64, np.complex128]: @@ -589,10 +583,9 @@ def _coo_matmat_gpu_lowering(coo_matmat_hlo, ctx, data, row, col, B, *, spinfo, return _coo_matmat_lowering(ctx, data, row, col, B, spinfo=spinfo, transpose=transpose) - return [coo_matmat_hlo(data, row, col, B, shape=shape, - transpose=transpose, x_dtype=B_aval.dtype, - data_dtype=data_aval.dtype, - index_dtype=row_aval.dtype)] + return _lowerings._coo_spmm_gpu_lowering( + ctx, data, row, col, B, transpose=transpose, shape=shape, + target_name_prefix=target_name_prefix) def _coo_matmat_jvp_left(data_dot, data, row, col, B, *, spinfo, transpose): @@ -615,13 +608,11 @@ def _coo_matmat_transpose(ct, data, row, col, B, *, spinfo, transpose): mlir.register_lowering(coo_matmat_p, _coo_matmat_lowering) dispatch.simple_impl(coo_matmat_p) -if gpu_sparse.cuda_is_supported: - mlir.register_lowering( - coo_matmat_p, - partial(_coo_matmat_gpu_lowering, gpu_sparse.cuda_coo_matmat), - platform='cuda') -if gpu_sparse.rocm_is_supported: - mlir.register_lowering( - coo_matmat_p, - partial(_coo_matmat_gpu_lowering, gpu_sparse.rocm_coo_matmat), - platform='rocm') +mlir.register_lowering( + coo_matmat_p, + partial(_coo_matmat_gpu_lowering, target_name_prefix='cu'), + platform='cuda') +mlir.register_lowering( + coo_matmat_p, + partial(_coo_matmat_gpu_lowering, target_name_prefix='hip'), + platform='rocm') diff --git a/jax/experimental/sparse/csr.py b/jax/experimental/sparse/csr.py index 84171855b85e..5b93267384d7 100644 --- a/jax/experimental/sparse/csr.py +++ b/jax/experimental/sparse/csr.py @@ -23,6 +23,7 @@ import jax from jax.interpreters import mlir +from jax.experimental.sparse import _lowerings from jax.experimental.sparse._base import JAXSparse from jax.experimental.sparse.coo import _coo_matmat, _coo_matvec, _coo_todense, COOInfo from jax.experimental.sparse.util import _csr_to_coo, _csr_extract, CuSparseEfficiencyWarning @@ -32,7 +33,6 @@ from jax._src import dispatch from jax._src.interpreters import ad from jax._src.lax.lax import _const -from jax._src.lib import gpu_sparse from jax._src.numpy.util import promote_dtypes from jax._src.typing import Array, DTypeLike import jax.numpy as jnp @@ -249,17 +249,16 @@ def _csr_todense_abstract_eval(data, indices, indptr, *, shape): _csr_todense_lowering = mlir.lower_fun( _csr_todense_impl, multiple_results=False) -def _csr_todense_gpu_lowering(csr_todense_hlo, ctx, data, indices, indptr, *, - shape): +def _csr_todense_gpu_lowering(ctx, data, indices, indptr, *, shape, target_name_prefix): data_aval, indices_aval, _ = ctx.avals_in dtype = data_aval.dtype if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)): warnings.warn(f"csr_todense cusparse/hipsparse lowering not available for {dtype=}. " "Falling back to default implementation.", CuSparseEfficiencyWarning) return _csr_todense_lowering(ctx, data, indices, indptr, shape=shape) - return [csr_todense_hlo( - data, indices, indptr, shape=shape, data_dtype=dtype, - index_dtype=indices_aval.dtype)] + return [_lowerings.csr_todense_gpu_lowering( + ctx, data, indices, indptr, shape=shape, + target_name_prefix=target_name_prefix)] def _csr_todense_jvp(data_dot, data, indices, indptr, *, shape): @@ -281,16 +280,14 @@ def _csr_todense_transpose(ct, data, indices, indptr, *, shape): mlir.register_lowering(csr_todense_p, _csr_todense_lowering) dispatch.simple_impl(csr_todense_p) -if gpu_sparse.cuda_is_supported: - mlir.register_lowering( - csr_todense_p, - partial(_csr_todense_gpu_lowering, gpu_sparse.cuda_csr_todense), - platform='cuda') -if gpu_sparse.rocm_is_supported: - mlir.register_lowering( - csr_todense_p, - partial(_csr_todense_gpu_lowering, gpu_sparse.rocm_csr_todense), - platform='rocm') +mlir.register_lowering( + csr_todense_p, + partial(_csr_todense_gpu_lowering, target_name_prefix='cu'), + platform='cuda') +mlir.register_lowering( + csr_todense_p, + partial(_csr_todense_gpu_lowering, target_name_prefix='hip'), + platform='rocm') #-------------------------------------------------------------------- @@ -359,16 +356,16 @@ def _csr_fromdense_abstract_eval(mat, *, nse, index_dtype): _csr_fromdense_lowering = mlir.lower_fun(_csr_fromdense_impl, multiple_results=True) -def _csr_fromdense_gpu_lowering(csr_fromdense_hlo, ctx, mat, *, nse, index_dtype): +def _csr_fromdense_gpu_lowering(ctx, mat, *, nse, index_dtype, + target_name_prefix): dtype = ctx.avals_in[0].dtype if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)): warnings.warn(f"csr_fromdense cusparse/hipsparse lowering not available for {dtype=}. " "Falling back to default implementation.", CuSparseEfficiencyWarning) return _csr_fromdense_lowering(ctx, mat, nse=nse, index_dtype=index_dtype) - data, indices, indptr = csr_fromdense_hlo( - mat, nnz=nse, index_dtype=np.dtype(index_dtype), - data_dtype=dtype, index_type=mlir.dtype_to_ir_type(np.dtype(index_dtype))) - return [data, indices, indptr] + return _lowerings.csr_fromdense_gpu_lowering( + ctx, mat, nnz=nse, index_dtype=index_dtype, + target_name_prefix=target_name_prefix) def _csr_fromdense_jvp(primals, tangents, *, nse, index_dtype): @@ -401,16 +398,14 @@ def _csr_fromdense_transpose(ct, M, *, nse, index_dtype): mlir.register_lowering(csr_fromdense_p, _csr_fromdense_lowering) dispatch.simple_impl(csr_fromdense_p) -if gpu_sparse.cuda_is_supported: - mlir.register_lowering( - csr_fromdense_p, - partial(_csr_fromdense_gpu_lowering, gpu_sparse.cuda_csr_fromdense), - platform='cuda') -if gpu_sparse.rocm_is_supported: - mlir.register_lowering( - csr_fromdense_p, - partial(_csr_fromdense_gpu_lowering, gpu_sparse.rocm_csr_fromdense), - platform='rocm') +mlir.register_lowering( + csr_fromdense_p, + partial(_csr_fromdense_gpu_lowering, target_name_prefix='cu'), + platform='cuda') +mlir.register_lowering( + csr_fromdense_p, + partial(_csr_fromdense_gpu_lowering, target_name_prefix='hip'), + platform='rocm') #-------------------------------------------------------------------- # csr_matvec @@ -470,8 +465,8 @@ def _csr_matvec_abstract_eval(data, indices, indptr, v, *, shape, transpose): _csr_matvec_lowering = mlir.lower_fun(_csr_matvec_impl, multiple_results=False) -def _csr_matvec_gpu_lowering(csr_matvec_hlo, ctx, data, indices, indptr, v, *, - shape, transpose): +def _csr_matvec_gpu_lowering(ctx, data, indices, indptr, v, *, shape, transpose, + target_name_prefix): data_aval, indices_aval, _, v_aval = ctx.avals_in dtype = data_aval.dtype if dtype not in [np.float32, np.float64, np.complex64, np.complex128]: @@ -479,10 +474,9 @@ def _csr_matvec_gpu_lowering(csr_matvec_hlo, ctx, data, indices, indptr, v, *, "Falling back to default implementation.", CuSparseEfficiencyWarning) return _csr_matvec_lowering(ctx, data, indices, indptr, v, shape=shape, transpose=transpose) - return [csr_matvec_hlo( - data, indices, indptr, v, shape=shape, transpose=transpose, - data_dtype=dtype, index_dtype=indices_aval.dtype, x_dtype=v_aval.dtype)] - + return _lowerings._csr_spmv_gpu_lowering( + ctx, data, indices, indptr, v, shape=shape, transpose=transpose, + target_name_prefix=target_name_prefix) def _csr_matvec_jvp_mat(data_dot, data, indices, indptr, v, *, shape, transpose): return _csr_matvec(data_dot, indices, indptr, v, shape=shape, transpose=transpose) @@ -508,16 +502,14 @@ def _csr_matvec_transpose(ct, data, indices, indptr, v, *, shape, transpose): mlir.register_lowering(csr_matvec_p, _csr_matvec_lowering) dispatch.simple_impl(csr_matvec_p) -if gpu_sparse.cuda_is_supported: - mlir.register_lowering( - csr_matvec_p, - partial(_csr_matvec_gpu_lowering, gpu_sparse.cuda_csr_matvec), - platform='cuda') -if gpu_sparse.rocm_is_supported: - mlir.register_lowering( - csr_matvec_p, - partial(_csr_matvec_gpu_lowering, gpu_sparse.rocm_csr_matvec), - platform='rocm') +mlir.register_lowering( + csr_matvec_p, + partial(_csr_matvec_gpu_lowering, target_name_prefix='cu'), + platform='cuda') +mlir.register_lowering( + csr_matvec_p, + partial(_csr_matvec_gpu_lowering, target_name_prefix='hip'), + platform='rocm') #-------------------------------------------------------------------- @@ -580,8 +572,8 @@ def _csr_matmat_abstract_eval(data, indices, indptr, B, *, shape, transpose): _csr_matmat_lowering = mlir.lower_fun(_csr_matmat_impl, multiple_results=False) -def _csr_matmat_gpu_lowering(csr_matmat_hlo, ctx, data, indices, indptr, B, *, - shape, transpose): +def _csr_matmat_gpu_lowering(ctx, data, indices, indptr, B, *, shape, transpose, + target_name_prefix): data_aval, indices_aval, _, B_aval = ctx.avals_in dtype = data_aval.dtype if dtype not in [np.float32, np.float64, np.complex64, np.complex128]: @@ -589,11 +581,9 @@ def _csr_matmat_gpu_lowering(csr_matmat_hlo, ctx, data, indices, indptr, B, *, "Falling back to default implementation.", CuSparseEfficiencyWarning) return _csr_matmat_lowering(ctx, data, indices, indptr, B, shape=shape, transpose=transpose) - return [csr_matmat_hlo( - data, indices, indptr, B, shape=shape, transpose=transpose, - index_dtype=indices_aval.dtype, data_dtype=data_aval.dtype, - B_dtype=B_aval.dtype)] - + return _lowerings._csr_spmm_gpu_lowering( + ctx, data, indices, indptr, B, shape=shape, transpose=transpose, + target_name_prefix=target_name_prefix) def _csr_matmat_jvp_left(data_dot, data, indices, indptr, B, *, shape, transpose): return _csr_matmat(data_dot, indices, indptr, B, shape=shape, transpose=transpose) @@ -617,14 +607,11 @@ def _csr_matmat_transpose(ct, data, indices, indptr, B, *, shape, transpose): mlir.register_lowering(csr_matmat_p, _csr_matmat_lowering) dispatch.simple_impl(csr_matmat_p) -if gpu_sparse: - if gpu_sparse.cuda_is_supported: - mlir.register_lowering( - csr_matmat_p, - partial(_csr_matmat_gpu_lowering, gpu_sparse.cuda_csr_matmat), - platform='cuda') - if gpu_sparse.rocm_is_supported: - mlir.register_lowering( - csr_matmat_p, - partial(_csr_matmat_gpu_lowering, gpu_sparse.rocm_csr_matmat), - platform='rocm') +mlir.register_lowering( + csr_matmat_p, + partial(_csr_matmat_gpu_lowering, target_name_prefix='cu'), + platform='cuda') +mlir.register_lowering( + csr_matmat_p, + partial(_csr_matmat_gpu_lowering, target_name_prefix='hip'), + platform='rocm') diff --git a/jax/experimental/sparse/linalg.py b/jax/experimental/sparse/linalg.py index a931b0a30dcf..c201ec8ea0fd 100644 --- a/jax/experimental/sparse/linalg.py +++ b/jax/experimental/sparse/linalg.py @@ -24,15 +24,14 @@ from jax.experimental import sparse from jax.interpreters import mlir -from jax.interpreters import xla from jax._src import core +from jax._src import dispatch from jax._src import ffi from jax._src.interpreters import ad -from jax._src.lib import gpu_solver import numpy as np -from scipy.sparse import csr_matrix, linalg +import scipy.sparse def lobpcg_standard( @@ -265,7 +264,7 @@ def _check_inputs(A, X): def _mm(a, b, precision=jax.lax.Precision.HIGHEST): - return jax.lax.dot(a, b, (precision, precision)) + return jax.lax.dot(a, b, precision=(precision, precision)) def _generate_diagnostics(prev_XPR, X, P, R, theta, converged, adj_resid): k = X.shape[1] @@ -534,11 +533,6 @@ def _spsolve_abstract_eval(data, indices, indptr, b, *, tol, reorder): def _spsolve_gpu_lowering(ctx, data, indices, indptr, b, *, tol, reorder): - # TODO(danfm): remove after JAX 0.5.1 release. - if hasattr(gpu_solver, "cuda_csrlsvqr"): - data_aval, _, _, _, = ctx.avals_in - return gpu_solver.cuda_csrlsvqr(data_aval.dtype, data, indices, - indptr, b, tol, reorder) return ffi.ffi_lowering("cusolver_csrlsvqr_ffi")( ctx, data, indices, indptr, b, tol=np.float64(tol), reorder=np.int32(reorder)) @@ -548,12 +542,12 @@ def _spsolve_cpu_lowering(ctx, data, indices, indptr, b, tol, reorder): args = [data, indices, indptr, b] def _callback(data, indices, indptr, b, **kwargs): - A = csr_matrix((data, indices, indptr), shape=(b.size, b.size)) - return (linalg.spsolve(A, b).astype(b.dtype),) + A = scipy.sparse.csr_matrix((data, indices, indptr), shape=(b.size, b.size)) + return (scipy.sparse.linalg.spsolve(A, b).astype(b.dtype),) result, _, _ = mlir.emit_python_callback( ctx, _callback, None, args, ctx.avals_in, ctx.avals_out, - has_side_effect=False) + has_side_effect=False, returns_token=False) return result @@ -595,7 +589,7 @@ def _spsolve_transpose(ct, data, indices, indptr, b, **kwds): spsolve_p = core.Primitive('spsolve') -spsolve_p.def_impl(functools.partial(xla.apply_primitive, spsolve_p)) +spsolve_p.def_impl(functools.partial(dispatch.apply_primitive, spsolve_p)) spsolve_p.def_abstract_eval(_spsolve_abstract_eval) ad.defjvp(spsolve_p, _spsolve_jvp_lhs, None, None, _spsolve_jvp_rhs) ad.primitive_transposes[spsolve_p] = _spsolve_transpose diff --git a/jax/experimental/sparse/nm.py b/jax/experimental/sparse/nm.py index f9d28f5ff83c..766158522773 100644 --- a/jax/experimental/sparse/nm.py +++ b/jax/experimental/sparse/nm.py @@ -17,7 +17,6 @@ from jax._src import core from jax._src import dispatch from jax._src.lax.lax import DotDimensionNumbers -from jax._src.lib import gpu_sparse from jax._src.lib.mlir.dialects import mhlo from jax._src.typing import Array, DTypeLike from jax.interpreters import mlir @@ -37,7 +36,7 @@ def nm_spmm( lhs: Array, rhs: Array, metadata: Array, - dimension_numbers: DotDimensionNumbers = (((1,), (0,)), (tuple(), tuple())), + dimension_numbers: DotDimensionNumbers = (((1,), (0,)), ((), ())), sparse_operand_idx: int = 0, output_dtype: DTypeLike = jnp.bfloat16, ) -> Array: @@ -178,11 +177,8 @@ def _nm_spmm_abstract_eval( mlir.register_lowering(nm_spmm_p, _nm_spmm_default_lowering) dispatch.simple_impl(nm_spmm_p) -if gpu_sparse.cuda_is_supported: - mlir.register_lowering(nm_spmm_p, _nm_spmm_gpu_lowering, platform="cuda") - -if gpu_sparse.rocm_is_supported: - mlir.register_lowering(nm_spmm_p, _nm_spmm_gpu_lowering, platform="rocm") +mlir.register_lowering(nm_spmm_p, _nm_spmm_gpu_lowering, platform="cuda") +mlir.register_lowering(nm_spmm_p, _nm_spmm_gpu_lowering, platform="rocm") # -------------------------------------------------------------------- # nm_pack diff --git a/jax/experimental/sparse/random.py b/jax/experimental/sparse/random.py index f90c2572d282..67d58ae9db23 100644 --- a/jax/experimental/sparse/random.py +++ b/jax/experimental/sparse/random.py @@ -15,10 +15,10 @@ import math import operator -from jax import dtypes from jax import vmap from jax import random -from jax.util import split_list +from jax._src import dtypes +from jax._src.util import split_list import jax.numpy as jnp from jax.experimental import sparse @@ -69,7 +69,7 @@ def random_bcoo(key, shape, *, dtype=jnp.float_, indices_dtype=None, data_shape = batch_shape + (nse,) + dense_shape indices_shape = batch_shape + (nse, n_sparse) if indices_dtype is None: - indices_dtype = dtypes.canonicalize_dtype(jnp.int_) + indices_dtype = dtypes.default_int_dtype() if sparse_size > jnp.iinfo(indices_dtype).max: raise ValueError(f"{indices_dtype=} does not have enough range to generate " f"sparse indices of size {sparse_size}.") diff --git a/jax/experimental/sparse/test_util.py b/jax/experimental/sparse/test_util.py index 77c97513041c..63e035d2d1ac 100644 --- a/jax/experimental/sparse/test_util.py +++ b/jax/experimental/sparse/test_util.py @@ -29,7 +29,7 @@ from jax._src.typing import DTypeLike from jax.experimental import sparse import jax.numpy as jnp -from jax.util import safe_zip, split_list +from jax._src.util import safe_zip, split_list import numpy as np MATMUL_TOL = { diff --git a/jax/experimental/sparse/transform.py b/jax/experimental/sparse/transform.py index ce1d3f4af9d0..66f831b58b01 100644 --- a/jax/experimental/sparse/transform.py +++ b/jax/experimental/sparse/transform.py @@ -62,17 +62,17 @@ from jax._src import linear_util as lu from jax._src import pjit from jax._src import sharding_impls -from jax.experimental.sparse.bcoo import bcoo_multiply_dense, bcoo_multiply_sparse +from jax.experimental.sparse.bcoo import bcoo_multiply_dense, bcoo_multiply_sparse, BCOO +from jax.experimental.sparse.bcsr import BCSR import jax.numpy as jnp from jax._src.api_util import flatten_fun_nokwargs from jax._src.lib import pytree from jax._src.interpreters import partial_eval as pe from jax.tree_util import tree_flatten, tree_map, tree_unflatten -from jax.util import safe_map, safe_zip, split_list +from jax._src.util import safe_map, safe_zip, split_list from jax._src.lax.control_flow import _check_tree_and_avals from jax._src.numpy import indexing as jnp_indexing from jax.experimental import sparse -from jax.experimental.sparse import BCOO, BCSR sparse_rules_bcoo : dict[core.Primitive, Callable] = {} sparse_rules_bcsr : dict[core.Primitive, Callable] = {} @@ -448,9 +448,10 @@ def wrapped( lu.wrap_init( f, params, debug_info=api_util.debug_info("sparsify", f, - spvalues_to_arrays(spenv, spvalues), {})), + in_tree.unflatten([True] * len(in_avals_flat)), + {})), in_tree) - jaxpr, out_avals_flat, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat) + jaxpr, out_avals_flat, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat) result = eval_sparse(jaxpr, consts, spvalues_flat, spenv) if len(out_avals_flat) != len(result): raise Exception("Internal: eval_sparse does not return expected number of arguments. " @@ -701,7 +702,7 @@ def _div_sparse(spenv, *spvalues): sparse_rules_bcoo[lax.div_p] = _div_sparse -def _reduce_sum_sparse(spenv, *spvalues, axes): +def _reduce_sum_sparse(spenv, *spvalues, axes, out_sharding): X, = spvalues X_promoted = spvalues_to_arrays(spenv, X) mat = sparse.bcoo_reduce_sum(X_promoted, axes=axes) @@ -750,8 +751,8 @@ def wrapped(*args_flat): args = spvalues_to_arrays(spenv, spvalues) args_flat, in_tree = tree_flatten(args) avals_flat = [core.get_aval(arg) for arg in args_flat] - sp_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(wrapped, debug_info=jaxpr.jaxpr.debug_info), avals_flat) + sp_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(wrapped, debug_info=jaxpr.jaxpr.debug_info.with_unknown_names()), avals_flat) sp_jaxpr = pe.ClosedJaxpr(sp_jaxpr, consts) assert out_tree is not None return sp_jaxpr, out_tree @@ -802,7 +803,7 @@ def _pjit_sparse(spenv, *spvalues, jaxpr, in_shardings, out_shardings, None for _ in range(len(sp_call_jaxpr.out_avals) - len(out_layouts)) ) - out_flat = pjit.pjit_p.bind( + out_flat = pjit.jit_p.bind( *args_flat, jaxpr=sp_call_jaxpr, in_shardings=in_shardings, @@ -817,7 +818,7 @@ def _pjit_sparse(spenv, *spvalues, jaxpr, in_shardings, out_shardings, compiler_options_kvs=compiler_options_kvs) return arrays_to_spvalues(spenv, tree_unflatten(out_tree, out_flat)) -sparse_rules_bcoo[pjit.pjit_p] = _pjit_sparse +sparse_rules_bcoo[pjit.jit_p] = _pjit_sparse def _duplicate_for_sparse_spvalues(spvalues, params): @@ -861,7 +862,7 @@ def _cond_sparse(spenv, pred, *operands, branches, **params): "sparsified false_fun output", treedefs[1], sp_branches[1].out_avals) args, _ = tree_flatten(spvalues_to_arrays(spenv, (pred, *operands))) - out_flat = lax.cond_p.bind(*args, branches=sp_branches, **params) + out_flat = lax.cond_p.bind(*args, branches=tuple(sp_branches), **params) out = tree_unflatten(treedefs[0], out_flat) return arrays_to_spvalues(spenv, out) diff --git a/jax/experimental/sparse/util.py b/jax/experimental/sparse/util.py index 36e9a9c51664..7c6bfb1ec345 100644 --- a/jax/experimental/sparse/util.py +++ b/jax/experimental/sparse/util.py @@ -25,7 +25,7 @@ from jax._src import core from jax._src.api_util import flatten_axes import jax.numpy as jnp -from jax.util import safe_zip +from jax._src.util import safe_zip from jax._src.lax.lax import _dot_general_shape_rule, DotDimensionNumbers from jax._src.typing import Array diff --git a/jax/experimental/topologies.py b/jax/experimental/topologies.py index 06be2b74853f..e5b04597b056 100644 --- a/jax/experimental/topologies.py +++ b/jax/experimental/topologies.py @@ -18,11 +18,10 @@ import jax from jax.experimental import mesh_utils -from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension +from jax._src.lib import _jax from jax._src import xla_bridge as xb -Device = xc.Device +Device = _jax.Device class TopologyDescription: @@ -45,8 +44,8 @@ def get_topology_desc( ) try: topology = xb.make_pjrt_topology(platform, topology_name, **kwargs) - return TopologyDescription(topology._make_compile_only_devices()) - except xla_extension.XlaRuntimeError as e: + return TopologyDescription(topology._make_compile_only_devices()) # pytype: disable=attribute-error + except _jax.JaxRuntimeError as e: msg, *_ = e.args if msg.startswith("UNIMPLEMENTED"): raise NotImplementedError(msg) from e diff --git a/jax/experimental/transfer.py b/jax/experimental/transfer.py index 1522df2cfc36..f6849521dabd 100644 --- a/jax/experimental/transfer.py +++ b/jax/experimental/transfer.py @@ -70,3 +70,8 @@ def await_pull(self, uuid: int, arrays: Any) -> Any: TransferServer = use_cpp_class(_xc._xla.TransferServer)(TransferServer) start_transfer_server = _xc._xla.start_transfer_server +if hasattr(_xc._xla, "_make_error_array"): + + def make_error_array(aval, message): + backend = next(iter(aval.sharding.device_set)).client + return _xc._xla._make_error_array(backend, aval, str(message)) diff --git a/jax/experimental/x64_context.py b/jax/experimental/x64_context.py index 1772d466b006..f98408908611 100644 --- a/jax/experimental/x64_context.py +++ b/jax/experimental/x64_context.py @@ -14,52 +14,23 @@ """Context managers for toggling X64 mode. -**Experimental: please give feedback, and expect changes.** +**Deprecated: use :func:`jax.enable_x64` instead.** """ -# This file provides -# 1. a jax.experimental API endpoint; -# 2. the `disable_x64` wrapper. -# TODO(jakevdp): remove this file, and consider removing `disable_x64` for -# uniformity - -from contextlib import contextmanager -from jax._src import config - -@contextmanager -def enable_x64(new_val: bool = True): - """Experimental context manager to temporarily enable X64 mode. - - Usage:: - - >>> x = np.arange(5, dtype='float64') - >>> with enable_x64(): - ... print(jnp.asarray(x).dtype) - ... - float64 - - See Also - -------- - jax.experimental.enable_x64 : temporarily enable X64 mode. - """ - with config.enable_x64(new_val): - yield - -@contextmanager -def disable_x64(): - """Experimental context manager to temporarily disable X64 mode. - - Usage:: - - >>> x = np.arange(5, dtype='float64') - >>> with disable_x64(): - ... print(jnp.asarray(x).dtype) - ... - float32 - - See Also - -------- - jax.experimental.enable_x64 : temporarily enable X64 mode. - """ - with config.enable_x64(False): - yield +_deprecations = { + # Remove in v0.10.0 + "disable_x64": ( + ("jax.experimental.x64_context.disable_x64 was removed in JAX v0.9.0;" + " use jax.enable_x64(False) instead."), + None + ), + "enable_x64": ( + ("jax.experimental.x64_context.enable_x64 was removed in JAX v0.9.0;" + " use jax.enable_x64(True) instead."), + None + ), +} + +from jax._src.deprecations import deprecation_getattr as _deprecation_getattr +__getattr__ = _deprecation_getattr(__name__, _deprecations) +del _deprecation_getattr diff --git a/jax/extend/BUILD b/jax/extend/BUILD index 59958c1da389..1a37fed88fca 100644 --- a/jax/extend/BUILD +++ b/jax/extend/BUILD @@ -29,10 +29,11 @@ pytype_strict_library( deps = [ ":backend", ":core", - ":ffi", ":linear_util", ":random", + ":sharding", ":source_info_util", + "//jax/extend/mlir", ], ) @@ -42,43 +43,58 @@ py_library_providing_imports_info( lib_rule = pytype_strict_library, deps = [ "//jax", - "//jax:abstract_arrays", - "//jax:ad_util", - "//jax:core", + "//jax/_src:abstract_arrays", + "//jax/_src:ad", + "//jax/_src:ad_util", + "//jax/_src:api", + "//jax/_src:core", + "//jax/_src:custom_derivatives", + "//jax/_src:lax", + "//jax/_src:random", ], ) pytype_strict_library( name = "linear_util", srcs = ["linear_util.py"], - deps = ["//jax:core"], + deps = ["//jax/_src:core"], ) pytype_strict_library( name = "backend", srcs = ["backend.py"], deps = [ - "//jax", - "//jax:xla_bridge", + "//jax/_src:api", + "//jax/_src:compiler", + "//jax/_src:util", + "//jax/_src:xla_bridge", + "//jax/_src/lib", ], ) pytype_strict_library( name = "random", srcs = ["random.py"], - deps = ["//jax"], + deps = [ + "//jax", + "//jax/_src:extend_src", + "//jax/_src:random", + ], ) pytype_strict_library( - name = "source_info_util", - srcs = ["source_info_util.py"], - deps = ["//jax:source_info_util"], + name = "sharding", + srcs = ["sharding.py"], + deps = [ + "//jax/_src:sharding_impls", + "//jax/_src/lib", + ], ) pytype_strict_library( - name = "ffi", - srcs = ["ffi.py"], - deps = ["//jax"], + name = "source_info_util", + srcs = ["source_info_util.py"], + deps = ["//jax/_src:source_info_util"], ) pytype_strict_library( diff --git a/jax/extend/__init__.py b/jax/extend/__init__.py index bbb5925ab41a..c892daed00f4 100644 --- a/jax/extend/__init__.py +++ b/jax/extend/__init__.py @@ -16,31 +16,32 @@ The :mod:`jax.extend` module provides modules for access to JAX internal machinery. See -`JEP #15856 `_. +`JEP #15856 `_. This module is not the only means by which JAX aims to be extensible. For example, the main JAX API offers mechanisms for `customizing derivatives -`_, +`_, `registering custom pytree definitions -`_, +`_, and more. API policy ---------- Unlike the -`public API `_, +`public API `_, this module offers **no compatibility guarantee** across releases. Breaking changes will be announced via the -`JAX project changelog `_. +`JAX project changelog `_. """ from jax.extend import ( backend as backend, core as core, - ffi as ffi, linear_util as linear_util, + mlir as mlir, random as random, + sharding as sharding, source_info_util as source_info_util, ) diff --git a/jax/extend/backend.py b/jax/extend/backend.py index 8d5488baba16..ced05e83605a 100644 --- a/jax/extend/backend.py +++ b/jax/extend/backend.py @@ -18,6 +18,9 @@ from jax._src.api import ( clear_backends as clear_backends, ) +from jax._src.compiler import ( + get_compile_options as get_compile_options, +) from jax._src.xla_bridge import ( backends as backends, backend_xla_version as backend_xla_version, @@ -27,3 +30,13 @@ from jax._src.interpreters.pxla import ( get_default_device as get_default_device ) +from jax._src import ( + util as _util +) +register_backend_cache = _util.register_cache # type: ignore + +from jax._src.lib import ( + ifrt_proxy as ifrt_proxy +) + +del _util diff --git a/jax/extend/core/__init__.py b/jax/extend/core/__init__.py index 9f1632fb37a9..ed39315bd74d 100644 --- a/jax/extend/core/__init__.py +++ b/jax/extend/core/__init__.py @@ -24,9 +24,11 @@ Jaxpr as Jaxpr, JaxprEqn as JaxprEqn, jaxpr_as_fun as jaxpr_as_fun, + mapped_aval as mapped_aval, Literal as Literal, Primitive as Primitive, Token as Token, + unmapped_aval as unmapped_aval, Var as Var, ) diff --git a/jax/extend/core/primitives.py b/jax/extend/core/primitives.py index d8a10154cf4a..348df455802d 100644 --- a/jax/extend/core/primitives.py +++ b/jax/extend/core/primitives.py @@ -15,6 +15,11 @@ # Note: import as is required for names to be exported. # See PEP 484 & https://github.com/jax-ml/jax/issues/7570 +from jax._src.ad_checkpoint import ( + name_p as name_p, + remat_p as remat_p, +) + from jax._src.ad_util import stop_gradient_p as stop_gradient_p from jax._src.core import ( @@ -24,9 +29,7 @@ from jax._src.custom_derivatives import ( custom_jvp_call_p as custom_jvp_call_p, - custom_jvp_call_jaxpr_p as custom_jvp_call_jaxpr_p, custom_vjp_call_p as custom_vjp_call_p, - custom_vjp_call_jaxpr_p as custom_vjp_call_jaxpr_p, ) from jax._src.dispatch import device_put_p as device_put_p @@ -34,7 +37,6 @@ from jax._src.interpreters.ad import ( add_jaxvals_p as add_jaxvals_p, custom_lin_p as custom_lin_p, - zeros_like_p as zeros_like_p, ) from jax._src.interpreters.pxla import xla_pmap_p as xla_pmap_p @@ -78,7 +80,6 @@ ge_p as ge_p, gt_p as gt_p, imag_p as imag_p, - infeed_p as infeed_p, integer_pow_p as integer_pow_p, iota_p as iota_p, is_finite_p as is_finite_p, @@ -97,7 +98,6 @@ nextafter_p as nextafter_p, not_p as not_p, or_p as or_p, - outfeed_p as outfeed_p, pad_p as pad_p, population_count_p as population_count_p, pow_p as pow_p, @@ -135,6 +135,7 @@ top_k_p as top_k_p, transpose_p as transpose_p, xor_p as xor_p, + empty2_p as empty2_p, ) from jax._src.lax.special import ( @@ -149,7 +150,6 @@ igamma_p as igamma_p, lgamma_p as lgamma_p, polygamma_p as polygamma_p, - random_gamma_grad_p as random_gamma_grad_p, regularized_incomplete_beta_p as regularized_incomplete_beta_p, zeta_p as zeta_p, ) @@ -226,7 +226,10 @@ schur_p as schur_p, ) -from jax._src.pjit import sharding_constraint_p as sharding_constraint_p +from jax._src.pjit import ( + jit_p as jit_p, + sharding_constraint_p as sharding_constraint_p, +) from jax._src.prng import ( random_bits_p as random_bits_p, diff --git a/jax/extend/ffi.py b/jax/extend/ffi.py deleted file mode 100644 index 21642055993b..000000000000 --- a/jax/extend/ffi.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright 2024 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from jax._src import ffi as _ffi - -_deprecations = { - # Added 2024-12-20 - "ffi_call": ( - "jax.extend.ffi.ffi_call is deprecated, use jax.ffi.ffi_call instead.", - _ffi.ffi_call, - ), - "ffi_lowering": ( - "jax.extend.ffi.ffi_lowering is deprecated, use jax.ffi.ffi_lowering instead.", - _ffi.ffi_lowering, - ), - "include_dir": ( - "jax.extend.ffi.include_dir is deprecated, use jax.ffi.include_dir instead.", - _ffi.include_dir, - ), - "pycapsule": ( - "jax.extend.ffi.pycapsule is deprecated, use jax.ffi.pycapsule instead.", - _ffi.pycapsule, - ), - "register_ffi_target": ( - "jax.extend.ffi.register_ffi_target is deprecated, use jax.ffi.register_ffi_target instead.", - _ffi.register_ffi_target, - ), -} - -import typing -if typing.TYPE_CHECKING: - ffi_call = _ffi.ffi_call - ffi_lowering = _ffi.ffi_lowering - include_dir = _ffi.include_dir - pycapsule = _ffi.pycapsule - register_ffi_target = _ffi.register_ffi_target -else: - from jax._src.deprecations import deprecation_getattr as _deprecation_getattr - __getattr__ = _deprecation_getattr(__name__, _deprecations) - del _deprecation_getattr -del typing -del _ffi diff --git a/jax/extend/ifrt_programs.py b/jax/extend/ifrt_programs.py index 715dfd43592c..13ba9088bc55 100644 --- a/jax/extend/ifrt_programs.py +++ b/jax/extend/ifrt_programs.py @@ -15,8 +15,8 @@ # Note: import as is required for names to be exported. # See PEP 484 & https://github.com/jax-ml/jax/issues/7570 -from jax._src.lib import xla_extension as _xe +from jax._src.lib import _jax -ifrt_programs = _xe.ifrt_programs +ifrt_programs = _jax.ifrt_programs -del _xe +del _jax diff --git a/jax/extend/linear_util.py b/jax/extend/linear_util.py index 0cf9a013a9e4..ad67f6ac8f73 100644 --- a/jax/extend/linear_util.py +++ b/jax/extend/linear_util.py @@ -15,7 +15,7 @@ # Note: import as is required for names to be exported. # See PEP 484 & https://github.com/jax-ml/jax/issues/7570 -from typing import Callable +from collections.abc import Callable from jax._src.linear_util import ( StoreException as StoreException, diff --git a/jax/extend/mlir/BUILD b/jax/extend/mlir/BUILD index 17d5aab3a7d3..cdc778b1f615 100644 --- a/jax/extend/mlir/BUILD +++ b/jax/extend/mlir/BUILD @@ -29,6 +29,7 @@ pytype_strict_library( deps = [ ":ir", ":pass_manager", + "//jax/_src/lib", ], ) diff --git a/jax/extend/mlir/__init__.py b/jax/extend/mlir/__init__.py index 38d13f42da99..fca1a7556040 100644 --- a/jax/extend/mlir/__init__.py +++ b/jax/extend/mlir/__init__.py @@ -11,3 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from jax._src.lib import ( + _jax as _jax +) + +deserialize_portable_artifact = _jax.mlir.deserialize_portable_artifact +serialize_portable_artifact = _jax.mlir.serialize_portable_artifact +refine_polymorphic_shapes = _jax.mlir.refine_polymorphic_shapes +hlo_to_stablehlo = _jax.mlir.hlo_to_stablehlo + +del _jax diff --git a/jax/extend/mlir/dialects/BUILD b/jax/extend/mlir/dialects/BUILD index 75275b45bd27..82bbea5c70e1 100644 --- a/jax/extend/mlir/dialects/BUILD +++ b/jax/extend/mlir/dialects/BUILD @@ -33,6 +33,7 @@ pytype_strict_library( ":func_dialect", ":math_dialect", ":memref_dialect", + ":mpmd_dialect", ":scf_dialect", ":sdy_dialect", ":sparse_tensor_dialect", @@ -77,6 +78,12 @@ pytype_strict_library( deps = if_building_jaxlib(["//jaxlib/mlir:memref_dialect"]), ) +pytype_strict_library( + name = "mpmd_dialect", + srcs = ["mpmd.py"], + deps = if_building_jaxlib(["//jaxlib/mlir:mpmd_dialect"]), +) + pytype_strict_library( name = "scf_dialect", srcs = ["scf.py"], diff --git a/jax/experimental/shard.py b/jax/extend/mlir/dialects/mpmd.py similarity index 73% rename from jax/experimental/shard.py rename to jax/extend/mlir/dialects/mpmd.py index c7acbff306ce..b3e6b190d509 100644 --- a/jax/experimental/shard.py +++ b/jax/extend/mlir/dialects/mpmd.py @@ -9,11 +9,9 @@ # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the ific language governing permissions and +# See the License for the specific language governing permissions and # limitations under the License. -from jax._src.pjit import ( - reshard as reshard, - auto_axes as auto_axes, - explicit_axes as explicit_axes, -) +# ruff: noqa: F403 + +from jaxlib.mlir.dialects.mpmd import * diff --git a/jax/extend/mlir/dialects/sdy.py b/jax/extend/mlir/dialects/sdy.py index 48586cc26760..d83fd90ecdf4 100644 --- a/jax/extend/mlir/dialects/sdy.py +++ b/jax/extend/mlir/dialects/sdy.py @@ -14,8 +14,4 @@ # ruff: noqa: F403 -# TODO(bartchr): Once JAX is released with SDY, remove the try/except. -try: - from jaxlib.mlir.dialects.sdy import * -except ImportError: - pass +from jaxlib.mlir.dialects.sdy import * diff --git a/jax/extend/sharding.py b/jax/extend/sharding.py new file mode 100644 index 000000000000..e61dbc7c099c --- /dev/null +++ b/jax/extend/sharding.py @@ -0,0 +1,36 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# TODO(yashkatariya): Remove this after NamedSharding supports more complicated +# shardings like sub-axes, strided shardings, etc. +from jax._src.lib import xla_client +from jax._src.sharding_impls import GSPMDSharding as GSPMDSharding + + +def get_op_sharding_from_serialized_proto( + sharding: bytes) -> xla_client.OpSharding: + proto = xla_client.OpSharding() + proto.ParseFromString(sharding) + return proto + + +def get_hlo_sharding_from_serialized_proto( + sharding: bytes) -> xla_client.HloSharding: + return xla_client.HloSharding.from_proto( + get_op_sharding_from_serialized_proto(sharding)) + + +def get_serialized_proto_from_hlo_sharding( + sharding: xla_client.HloSharding) -> bytes: + return sharding.to_proto().SerializeToString() diff --git a/jax/ffi.py b/jax/ffi.py index 6606c58a0353..a4eb2a29fbd0 100644 --- a/jax/ffi.py +++ b/jax/ffi.py @@ -16,11 +16,13 @@ # See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.ffi import ( + build_ffi_lowering_function as build_ffi_lowering_function, ffi_call as ffi_call, ffi_lowering as ffi_lowering, include_dir as include_dir, pycapsule as pycapsule, register_ffi_target as register_ffi_target, register_ffi_type_id as register_ffi_type_id, + register_ffi_type as register_ffi_type, register_ffi_target_as_batch_partitionable as register_ffi_target_as_batch_partitionable, ) diff --git a/jax/interpreters/ad.py b/jax/interpreters/ad.py index 4ded4a803ae0..85f58c4ce4fd 100644 --- a/jax/interpreters/ad.py +++ b/jax/interpreters/ad.py @@ -17,6 +17,8 @@ from __future__ import annotations +from jax._src.interpreters import ad as _src_ad + from jax._src.interpreters.ad import ( JVPTrace as JVPTrace, JVPTracer as JVPTracer, @@ -25,54 +27,131 @@ add_jaxvals as add_jaxvals, add_jaxvals_p as add_jaxvals_p, add_tangents as add_tangents, - backward_pass as backward_pass_internal, - bilinear_transpose as bilinear_transpose, - call_param_updaters as call_param_updaters, - call_transpose as call_transpose, - call_transpose_param_updaters as call_transpose_param_updaters, - closed_backward_pass as closed_backward_pass, - custom_lin_p as custom_lin_p, defbilinear as defbilinear, defjvp as defjvp, defjvp2 as defjvp2, - defjvp_zero as defjvp_zero, deflinear as deflinear, deflinear2 as deflinear2, - f_jvp_traceable as f_jvp_traceable, get_primitive_transpose as get_primitive_transpose, instantiate_zeros as instantiate_zeros, is_undefined_primal as is_undefined_primal, jvp as jvp, - jvp_jaxpr as jvp_jaxpr, - jvp_subtrace as jvp_subtrace, - jvp_subtrace_aux as jvp_subtrace_aux, - jvpfun as jvpfun, - linear_jvp as linear_jvp, - linear_transpose as linear_transpose, - linear_transpose2 as linear_transpose2, linearize as linearize, - map_transpose as map_transpose, - nonzero_outputs as nonzero_outputs, - nonzero_tangent_outputs as nonzero_tangent_outputs, primitive_jvps as primitive_jvps, primitive_transposes as primitive_transposes, - rearrange_binders as rearrange_binders, - reducing_transposes as reducing_transposes, - standard_jvp as standard_jvp, - standard_jvp2 as standard_jvp2, - traceable as traceable, - unpair_pval as unpair_pval, - vjp as vjp, - zero_jvp as zero_jvp, zeros_like_aval as zeros_like_aval, - zeros_like_p as zeros_like_p, ) -def backward_pass(jaxpr, reduce_axes, transform_stack, - consts, primals_in, cotangents_in): - if reduce_axes: - raise NotImplementedError("reduce_axes on ad.backward_pass is deprecated") - del reduce_axes - return backward_pass_internal( - jaxpr, transform_stack, consts, primals_in, cotangents_in) +_deprecations = { + # Deprecated for JAX v0.7.1; finalized in JAX v0.9.0; Remove in v0.10.0. + "zeros_like_p": ( + "jax.interpreters.ad.zeros_like_p was removed in JAX v0.9.0.", + None, + ), + "bilinear_transpose": ( + "jax.interpreters.ad.bilinear_transpose was removed in JAX v0.9.0.", + None, + ), + "call_param_updaters": ( + "jax.interpreters.ad.call_param_updaters was removed in JAX v0.9.0.", + None, + ), + "call_transpose": ( + "jax.interpreters.ad.call_transpose was removed in JAX v0.9.0.", + None, + ), + "call_transpose_param_updaters": ( + "jax.interpreters.ad.call_transpose_param_updaters was removed in JAX v0.9.0.", + None, + ), + "custom_lin_p": ( + "jax.interpreters.ad.custom_lin_p was removed in JAX v0.9.0.", + None, + ), + "defjvp_zero": ( + "jax.interpreters.ad.defjvp_zero was removed in JAX v0.9.0.", + None, + ), + "f_jvp_traceable": ( + "jax.interpreters.ad.f_jvp_traceable was removed in JAX v0.9.0.", + None, + ), + "jvp_jaxpr": ( + "jax.interpreters.ad.jvp_jaxpr was removed in JAX v0.9.0.", + None, + ), + "jvp_subtrace": ( + "jax.interpreters.ad.jvp_subtrace was removed in JAX v0.9.0.", + None, + ), + "jvp_subtrace_aux": ( + "jax.interpreters.ad.jvp_subtrace_aux was removed in JAX v0.9.0.", + None, + ), + "jvpfun": ( + "jax.interpreters.ad.jvpfun was removed in JAX v0.9.0.", + None, + ), + "linear_jvp": ( + "jax.interpreters.ad.linear_jvp was removed in JAX v0.9.0.", + None, + ), + "linear_transpose": ( + "jax.interpreters.ad.linear_transpose was removed in JAX v0.9.0.", + None, + ), + "linear_transpose2": ( + "jax.interpreters.ad.linear_transpose2 was removed in JAX v0.9.0.", + None, + ), + "map_transpose": ( + "jax.interpreters.ad.map_transpose was removed in JAX v0.9.0.", + None, + ), + "nonzero_outputs": ( + "jax.interpreters.ad.nonzero_outputs was removed in JAX v0.9.0.", + None, + ), + "nonzero_tangent_outputs": ( + "jax.interpreters.ad.nonzero_tangent_outputs was removed in JAX v0.9.0.", + None, + ), + "rearrange_binders": ( + "jax.interpreters.ad.rearrange_binders was removed in JAX v0.9.0.", + None, + ), + "standard_jvp": ( + "jax.interpreters.ad.standard_jvp was removed in JAX v0.9.0.", + None, + ), + "standard_jvp2": ( + "jax.interpreters.ad.standard_jvp2 was removed in JAX v0.9.0.", + None, + ), + "traceable": ( + "jax.interpreters.ad.traceable was removed in JAX v0.9.0.", + None, + ), + "zero_jvp": ( + "jax.interpreters.ad.zero_jvp was removed in JAX v0.9.0.", + None, + ), + # Deprecated for JAX v0.9.0; finalize in JAX v0.10.0. + "reducing_transposes": ( + ( + "jax.interpreters.ad.reducing_transposes is deprecated in JAX v0.9.0." + " It has been unused since v0.4.38." + ), + _src_ad.reducing_transposes, + ), +} + +import typing +if typing.TYPE_CHECKING: + reducing_transposes = _src_ad.reducing_transposes +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del typing diff --git a/jax/interpreters/batching.py b/jax/interpreters/batching.py index 7a93a6942c21..c6f9daa04942 100644 --- a/jax/interpreters/batching.py +++ b/jax/interpreters/batching.py @@ -15,61 +15,185 @@ # Note: import as is required for names to be exported. # See PEP 484 & https://github.com/jax-ml/jax/issues/7570 +from jax._src.interpreters import batching as _src_batching + from jax._src.interpreters.batching import ( - Array as Array, - AxisSize as AxisSize, - BatchTrace as BatchTrace, - BatchTracer as BatchTracer, - BatchingRule as BatchingRule, - RaggedAxis as RaggedAxis, - Elt as Elt, - FromEltHandler as FromEltHandler, - GetIdx as GetIdx, - IndexedAxisSize as IndexedAxisSize, - MakeIotaHandler as MakeIotaHandler, - MapSpec as MapSpec, - NotMapped as NotMapped, - Jumble as Jumble, - JumbleAxis as JumbleAxis, - JumbleTy as JumbleTy, - ToEltHandler as ToEltHandler, - Vmappable as Vmappable, - Zero as Zero, - ZeroIfMapped as ZeroIfMapped, axis_primitive_batchers as axis_primitive_batchers, - batch as batch, - batch_custom_jvp_subtrace as batch_custom_jvp_subtrace, - batch_custom_vjp_bwd as batch_custom_vjp_bwd, - batch_jaxpr as batch_jaxpr, - batch_jaxpr2 as batch_jaxpr2, - batch_jaxpr_axes as batch_jaxpr_axes, - batch_subtrace as batch_subtrace, bdim_at_front as bdim_at_front, broadcast as broadcast, - broadcast_batcher as broadcast_batcher, defbroadcasting as defbroadcasting, defreducer as defreducer, defvectorized as defvectorized, fancy_primitive_batchers as fancy_primitive_batchers, - flatten_fun_for_vmap as flatten_fun_for_vmap, - from_elt as from_elt, - from_elt_handlers as from_elt_handlers, - is_vmappable as is_vmappable, - make_iota as make_iota, - make_iota_handlers as make_iota_handlers, - matchaxis as matchaxis, - moveaxis as moveaxis, not_mapped as not_mapped, - jumble_axis as jumble_axis, primitive_batchers as primitive_batchers, - reducer_batcher as reducer_batcher, register_vmappable as register_vmappable, - spec_types as spec_types, - to_elt as to_elt, - to_elt_handlers as to_elt_handlers, unregister_vmappable as unregister_vmappable, - vectorized_batcher as vectorized_batcher, - vmappables as vmappables, - vtile as vtile, - zero_if_mapped as zero_if_mapped, ) + + +_deprecations = { + # Deprecated for JAX v0.7.1; finalize in JAX v0.9.0. + "AxisSize": ( + "jax.interpreters.batching.AxisSize is deprecated.", + None, + ), + "Array": ( + "jax.interpreters.batching.Array is deprecated. Use jax.Array directly.", + None, + ), + "BatchTrace": ( + "jax.interpreters.batching.BatchTrace is deprecated.", + None, + ), + "BatchTracer": ( + "jax.interpreters.batching.BatchTracer is deprecated.", + None, + ), + "BatchingRule": ( + "jax.interpreters.batching.BatchingRule is deprecated.", + None, + ), + "Elt": ( + "jax.interpreters.batching.Elt is deprecated.", + None, + ), + "FromEltHandler": ( + "jax.interpreters.batching.FromEltHandler is deprecated.", + None, + ), + "GetIdx": ( + "jax.interpreters.batching.GetIdx is deprecated.", + None, + ), + "MakeIotaHandler": ( + "jax.interpreters.batching.MakeIotaHandler is deprecated.", + None, + ), + "MapSpec": ( + "jax.interpreters.batching.MapSpec is deprecated.", + None, + ), + "NotMapped": ( + "jax.interpreters.batching.NotMapped is deprecated.", + _src_batching.NotMapped, + ), + "ToEltHandler": ( + "jax.interpreters.batching.ToEltHandler is deprecated.", + None, + ), + "Vmappable": ( + "jax.interpreters.batching.Vmappable is deprecated.", + None, + ), + "Zeros": ( + "jax.interpreters.batching.Zero is deprecated. Use jax.interpreters.ad.Zero.", + None, + ), + "ZeroIfMapped": ( + "jax.interpreters.batching.ZeroIfMapped is deprecated. It is an internal type.", + None, + ), + "batch": ( + "jax.interpreters.batching.batch is deprecated. It is an internal API.", + None, + ), + "batch_custom_jvp_subtrace": ( + "jax.interpreters.batching.batch_custom_jvp_subtrace is deprecated. It is an internal API.", + None, + ), + "batch_custom_vjp_bwd": ( + "jax.interpreters.batching.batch_custom_vjp_bwd is deprecated. It is an internal API.", + None, + ), + "batch_jaxpr": ( + "jax.interpreters.batching.batch_jaxpr is deprecated. It is an internal API.", + None, + ), + "batch_jaxpr_axes": ( + "jax.interpreters.batching.batch_jaxpr_axes is deprecated. It is an internal API.", + None, + ), + "batch_subtrace": ( + "jax.interpreters.batching.batch_subtrace is deprecated. It is an internal API.", + None, + ), + "broadcast_batcher": ( + "jax.interpreters.batching.broadcast_batcher is deprecated. It is an internal API.", + None, + ), + "flatten_fun_for_vmap": ( + "jax.interpreters.batching.flatten_fun_for_vmap is deprecated. It is an internal API.", + None, + ), + "from_elt": ( + "jax.interpreters.batching.from_elt is deprecated. It is an internal API.", + None, + ), + "from_elt_handlers": ( + "jax.interpreters.batching.from_elt_handlers is deprecated. It is an internal API.", + None, + ), + "is_vmappable": ( + "jax.interpreters.batching.is_vmappable is deprecated. It is an internal API.", + None, + ), + "make_iota": ( + "jax.interpreters.batching.make_iota is deprecated. It is an internal API.", + None, + ), + "make_iota_handlers": ( + "jax.interpreters.batching.make_iota_handlers is deprecated. It is an internal API.", + None, + ), + "matchaxis": ( + "jax.interpreters.batching.matchaxis is deprecated. It is an internal API.", + None, + ), + "moveaxis": ( + "jax.interpreters.batching.moveaxis is deprecated. Use jax.numpy.moveaxis.", + None, + ), + "reducer_batcher": ( + "jax.interpreters.batching.reducer_batcher is deprecated. It is an internal API.", + None, + ), + "spec_types": ( + "jax.interpreters.batching.spec_types is deprecated. It is an internal API.", + None, + ), + "to_elt": ( + "jax.interpreters.batching.to_elt is deprecated. It is an internal API.", + None, + ), + "to_elt_handlers": ( + "jax.interpreters.batching.to_elt_handlers is deprecated. It is an internal API.", + None, + ), + "vectorized_batcher": ( + "jax.interpreters.batching.vectorized_batcher is deprecated. It is an internal API.", + None, + ), + "vmappables": ( + "jax.interpreters.batching.vmappables is deprecated. It is an internal API.", + None, + ), + "vtile": ( + "jax.interpreters.batching.vtile is deprecated. It is an internal API.", + None, + ), + "zero_if_mapped": ( + "jax.interpreters.batching.zero_if_mapped is deprecated. It is an internal API.", + None, + ), +} + + +import typing as _typing +if _typing.TYPE_CHECKING: + NotMapped = _src_batching.NotMapped +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del _typing diff --git a/jax/interpreters/mlir.py b/jax/interpreters/mlir.py index 0f32799f7ea9..463ea47cfbe1 100644 --- a/jax/interpreters/mlir.py +++ b/jax/interpreters/mlir.py @@ -33,28 +33,23 @@ aval_to_ir_type as aval_to_ir_type, aval_to_ir_types as aval_to_ir_types, core_call_lowering as core_call_lowering, - custom_call as custom_call, dense_bool_elements as dense_bool_elements, - dense_bool_array as dense_bool_array, dense_int_array as dense_int_array, dense_int_elements as dense_int_elements, dtype_to_ir_type as dtype_to_ir_type, flatten_ir_types as flatten_ir_types, - flatten_ir_values as flatten_lowering_ir_args, # TODO(phawkins): remove me # noqa: F401 flatten_ir_values as flatten_ir_values, unflatten_ir_values_like_types as unflatten_ir_values_like_types, - func_dialect as func_dialect, - hlo as hlo, i32_attr as i32_attr, i64_attr as i64_attr, ir as ir, + ir_attribute as ir_attribute, ir_constant as ir_constant, ir_type_handlers as ir_type_handlers, jaxpr_subcomp as jaxpr_subcomp, lower_fun as lower_fun, lower_jaxpr_to_fun as lower_jaxpr_to_fun, lower_jaxpr_to_module as lower_jaxpr_to_module, - lowerable_effects as lowerable_effects, make_ir_context as make_ir_context, merge_mlir_modules as merge_mlir_modules, module_to_bytecode as module_to_bytecode, @@ -63,7 +58,6 @@ register_lowering as register_lowering, shape_tensor as shape_tensor, token_type as token_type, - xla_computation_to_mlir_module as xla_computation_to_mlir_module, ) from jax._src.mesh import Mesh as Mesh @@ -73,6 +67,7 @@ SPMDAxisContext as SPMDAxisContext, ShardingContext as ShardingContext, ) +from jax._src.effects import lowerable_effects as lowerable_effects # TODO(dsuo): Temporarily maintain symbols related to callback lowering for sake @@ -80,3 +75,25 @@ from jax._src.callback import ( emit_python_callback as emit_python_callback, ) + +_deprecations = { + # Added Apr 7 2025 + "custom_call": ( + ( + "mlir.custom_call was removed in JAX v0.8.0; use the APIs provided" + " by jax.ffi instead." + ), + None, + ) +} + +import typing as _typing + +if _typing.TYPE_CHECKING: + pass +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del _typing diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index b546d774a2e9..3a34c2d5ce10 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -12,84 +12,220 @@ # See the License for the specific language governing permissions and # limitations under the License. + from jax._src.interpreters.partial_eval import ( - AbstractedAxesSpec as AbstractedAxesSpec, - AbstractedAxisName as AbstractedAxisName, - BoundedAxisSize as BoundedAxisSize, - Const as Const, - ConstFoldRule as ConstFoldRule, - ConstVar as ConstVar, - DCERule as DCERule, - DynamicJaxprTrace as DynamicJaxprTrace, DynamicJaxprTracer as DynamicJaxprTracer, - ForwardingRule as ForwardingRule, - FreeVar as FreeVar, - JaxprEqnRecipe as JaxprEqnRecipe, - JaxprStackFrame as JaxprStackFrame, - JaxprTrace as JaxprTrace, JaxprTracer as JaxprTracer, - JaxprTracerRecipe as JaxprTracerRecipe, - LambdaBinding as LambdaBinding, - ParamsUpdater as ParamsUpdater, - PartialEvalCustomResult as PartialEvalCustomResult, - PartialEvalCustomRule as PartialEvalCustomRule, PartialVal as PartialVal, - ResAvalUpdater as ResAvalUpdater, - TracerAsName as TracerAsName, - TracerId as TracerId, Val as Val, - abstract_eval_fun as abstract_eval_fun, - call_padding_rule as call_padding_rule, - call_param_updaters as call_param_updaters, - call_partial_eval_custom_rule as call_partial_eval_custom_rule, - call_partial_eval_rules as call_partial_eval_rules, - close_jaxpr as close_jaxpr, - closed_call_partial_eval_custom_rule as closed_call_partial_eval_custom_rule, - config as config, - const_fold_rules as const_fold_rules, - convert_constvars_jaxpr as convert_constvars_jaxpr, - convert_envvars_to_constvars as convert_envvars_to_constvars, - convert_invars_to_constvars as convert_invars_to_constvars, custom_partial_eval_rules as custom_partial_eval_rules, - custom_staging_rules as custom_staging_rules, dce_jaxpr as dce_jaxpr, dce_jaxpr_call_rule as dce_jaxpr_call_rule, dce_jaxpr_closed_call_rule as dce_jaxpr_closed_call_rule, dce_jaxpr_consts as dce_jaxpr_consts, dce_rules as dce_rules, - def_trivial_padding as def_trivial_padding, - forwarding_rules as forwarding_rules, - has_effects as has_effects, - infer_lambda_input_type as infer_lambda_input_type, - instantiate_const_at as instantiate_const_at, - make_jaxpr_effects as make_jaxpr_effects, - move_binders_to_back as move_binders_to_back, - move_binders_to_front as move_binders_to_front, - new_eqn_recipe as new_eqn_recipe, - pad_jaxpr as pad_jaxpr, - padding_rules as padding_rules, - partial_eval_jaxpr_custom as partial_eval_jaxpr_custom, - partial_eval_jaxpr_custom_rule_not_implemented as partial_eval_jaxpr_custom_rule_not_implemented, partial_eval_jaxpr_custom_rules as partial_eval_jaxpr_custom_rules, - partial_eval_jaxpr_nounits as partial_eval_jaxpr_nounits, - partial_eval_wrapper_nounits as partial_eval_wrapper_nounits, - partition_pvals as partition_pvals, - recipe_to_eqn as recipe_to_eqn, - trace_to_jaxpr_dynamic as _trace_to_jaxpr_dynamic, - trace_to_jaxpr_dynamic2 as trace_to_jaxpr_dynamic2, + trace_to_jaxpr_dynamic as trace_to_jaxpr_dynamic, trace_to_jaxpr_nounits as trace_to_jaxpr_nounits, - trace_to_subjaxpr_nounits as trace_to_subjaxpr_nounits, - trace_to_subjaxpr_nounits_fwd as trace_to_subjaxpr_nounits_fwd, - tracers_to_jaxpr as tracers_to_jaxpr, - trivial_ctx as trivial_ctx, ) -# TODO(mattjj): remove temporary shim when trace_to_jaxpr_dynamic sig stabilizes -def trace_to_jaxpr_dynamic(fun, in_avals, *, keep_inputs=None): # noqa - jaxpr, out_avals, consts, () = _trace_to_jaxpr_dynamic( - fun, in_avals, keep_inputs=keep_inputs) - return jaxpr, out_avals, consts - +_deprecations = { + # Remove in v0.10.0 + "Const": ( + "jax.interpreters.partial_eval.Const is deprecated.", + None, + ), + "ConstFoldRule": ( + "jax.interpreters.partial_eval.ConstFoldRule is deprecated.", + None, + ), + "ConstVar": ( + "jax.interpreters.partial_eval.ConstVar is deprecated.", + None, + ), + "DCERule": ( + "jax.interpreters.partial_eval.DCERule is deprecated.", + None, + ), + "DynamicJaxprTrace": ( + "jax.interpreters.partial_eval.DynamicJaxprTrace is deprecated.", + None, + ), + "ForwardingRule": ( + "jax.interpreters.partial_eval.ForwardingRule is deprecated.", + None, + ), + "FreeVar": ( + "jax.interpreters.partial_eval.FreeVar is deprecated.", + None, + ), + "Jaxpr": ( + ( + "jax.interpreters.partial_eval.Jaxpr is deprecated. Use" + " jax.extend.core.Jaxpr, and please note that you must" + " `import jax.extend` explicitly." + ), + None, + ), + "JaxprEqnRecipe": ( + "jax.interpreters.partial_eval.JaxprEqnRecipe is deprecated.", + None, + ), + "JaxprStackFrame": ( + "jax.interpreters.partial_eval.JaxprStackFrame is deprecated.", + None, + ), + "JaxprTrace": ( + "jax.interpreters.partial_eval.JaxprTrace is deprecated.", + None, + ), + "JaxprTracerRecipe": ( + "jax.interpreters.partial_eval.JaxprTracerRecipe is deprecated.", + None, + ), + "LambdaBinding": ( + "jax.interpreters.partial_eval.LambdaBinding is deprecated.", + None, + ), + "ParamsUpdater": ( + "jax.interpreters.partial_eval.ParamsUpdater is deprecated.", + None, + ), + "PartialEvalCustomResult": ( + "jax.interpreters.partial_eval.PartialEvalCustomResult is deprecated.", + None, + ), + "PartialEvalCustomRule": ( + "jax.interpreters.partial_eval.PartialEvalCustomRule is deprecated.", + None, + ), + "ResAvalUpdater": ( + "jax.interpreters.partial_eval.ResAvalUpdater is deprecated.", + None, + ), + "TracerAsName": ( + "jax.interpreters.partial_eval.TracerAsName is deprecated.", + None, + ), + "TracerId": ( + "jax.interpreters.partial_eval.TracerId is deprecated.", + None, + ), + "abstract_eval_fun": ( + "jax.interpreters.partial_eval.abstract_eval_fun is deprecated.", + None, + ), + "call_param_updaters": ( + "jax.interpreters.partial_eval.call_param_updaters is deprecated.", + None, + ), + "call_partial_eval_custom_rule": ( + "jax.interpreters.partial_eval.call_partial_eval_custom_rule is deprecated.", + None, + ), + "call_partial_eval_rules": ( + "jax.interpreters.partial_eval.call_partial_eval_rules is deprecated.", + None, + ), + "close_jaxpr": ( + "jax.interpreters.partial_eval.close_jaxpr is deprecated.", + None, + ), + "closed_call_partial_eval_custom_rule": ( + "jax.interpreters.partial_eval.closed_call_partial_eval_custom_rule is deprecated.", + None, + ), + "config": ( + "jax.interpreters.partial_eval.config is deprecated; use jax.config directly.", + None, + ), + "const_fold_rules": ( + "jax.interpreters.partial_eval.const_fold_rules is deprecated.", + None, + ), + "convert_constvars_jaxpr": ( + "jax.interpreters.partial_eval.convert_constvars_jaxpr is deprecated.", + None, + ), + "convert_envvars_to_constvars": ( + "jax.interpreters.partial_eval.convert_envvars_to_constvars is deprecated.", + None, + ), + "convert_invars_to_constvars": ( + "jax.interpreters.partial_eval.convert_invars_to_constvars is deprecated.", + None, + ), + "custom_staging_rules": ( + "jax.interpreters.partial_eval.custom_staging_rules is deprecated.", + None, + ), + "forwarding_rules": ( + "jax.interpreters.partial_eval.forwarding_rules is deprecated.", + None, + ), + "has_effects": ( + "jax.interpreters.partial_eval.has_effects is deprecated.", + None, + ), + "instantiate_const_at": ( + "jax.interpreters.partial_eval.instantiate_const_at is deprecated.", + None, + ), + "make_jaxpr_effects": ( + "jax.interpreters.partial_eval.make_jaxpr_effects is deprecated.", + None, + ), + "move_binders_to_back": ( + "jax.interpreters.partial_eval.move_binders_to_back is deprecated.", + None, + ), + "move_binders_to_front": ( + "jax.interpreters.partial_eval.move_binders_to_front is deprecated.", + None, + ), + "new_eqn_recipe": ( + "jax.interpreters.partial_eval.new_eqn_recipe is deprecated.", + None, + ), + "partial_eval_jaxpr_custom": ( + "jax.interpreters.partial_eval.partial_eval_jaxpr_custom is deprecated.", + None, + ), + "partial_eval_jaxpr_custom_rule_not_implemented": ( + "jax.interpreters.partial_eval.partial_eval_jaxpr_custom_rule_not_implemented is deprecated.", + None, + ), + "partial_eval_jaxpr_nounits": ( + "jax.interpreters.partial_eval.partial_eval_jaxpr_nounits is deprecated.", + None, + ), + "partial_eval_wrapper_nounits": ( + "jax.interpreters.partial_eval.partial_eval_wrapper_nounits is deprecated.", + None, + ), + "partition_pvals": ( + "jax.interpreters.partial_eval.partition_pvals is deprecated.", + None, + ), + "recipe_to_eqn": ( + "jax.interpreters.partial_eval.recipe_to_eqn is deprecated.", + None, + ), + "trace_to_subjaxpr_nounits": ( + "jax.interpreters.partial_eval.trace_to_subjaxpr_nounits is deprecated.", + None, + ), + "trace_to_subjaxpr_nounits_fwd": ( + "jax.interpreters.partial_eval.trace_to_subjaxpr_nounits_fwd is deprecated.", + None, + ), + "tracers_to_jaxpr": ( + "jax.interpreters.partial_eval.tracers_to_jaxpr is deprecated.", + None, + ), +} -from jax._src.core import Jaxpr as Jaxpr +from jax._src.deprecations import deprecation_getattr as _deprecation_getattr +__getattr__ = _deprecation_getattr(__name__, _deprecations) +del _deprecation_getattr diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index f3fd8bac558c..af179b1905a5 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -12,42 +12,153 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jax._src.interpreters.pxla import ( - Index as Index, - MapTracer as MapTracer, - MeshAxisName as MeshAxisName, - MeshComputation as MeshComputation, - MeshExecutable as MeshExecutable, - PmapExecutable as PmapExecutable, - global_aval_to_result_handler as global_aval_to_result_handler, - global_avals_to_results_handler as global_avals_to_results_handler, - global_result_handlers as global_result_handlers, - parallel_callable as parallel_callable, - shard_args as shard_args, - xla_pmap_p as xla_pmap_p, -) -from jax._src.mesh import ( - thread_resources as thread_resources, -) +# Note: import as is required for names to be exported. +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 -from jax._src.op_shardings import ( - are_op_shardings_equal as are_op_shardings_equal, - is_op_sharding_replicated as is_op_sharding_replicated, - op_sharding_to_indices as op_sharding_to_indices, -) +from jax._src.interpreters import pxla as _deprecated_pxla +from jax._src import mesh as _deprecated_mesh +from jax._src import op_shardings as _deprecated_op_shardings +from jax._src import sharding_impls as _deprecated_sharding_impls +from jax._src import sharding_specs as _deprecated_sharding_specs -from jax._src.sharding_impls import ( - ArrayMapping as ArrayMapping, - UNSPECIFIED as _UNSPECIFIED, # noqa: F401 - array_mapping_to_axis_resources as array_mapping_to_axis_resources, -) +_deprecations = { + # deprecated as of JAX v0.8.2 (Dec 2025) + "Index": ( + "jax.interpreters.pxla.Index is deprecated as of JAX v0.8.2.", + _deprecated_pxla.Index, + ), + "MapTracer": ( + "jax.interpreters.pxla.MapTracer is deprecated as of JAX v0.8.2.", + _deprecated_pxla.MapTracer, + ), + "MeshAxisName": ( + "jax.interpreters.pxla.MeshAxisName is deprecated as of JAX v0.8.2. Use jax.sharding.Mesh axis names directly.", + _deprecated_pxla.MeshAxisName, + ), + "MeshComputation": ( + "jax.interpreters.pxla.MeshComputation is deprecated as of JAX v0.8.2.", + _deprecated_pxla.MeshComputation, + ), + "MeshExecutable": ( + "jax.interpreters.pxla.MeshExecutable is deprecated as of JAX v0.8.2.", + _deprecated_pxla.MeshExecutable, + ), + "PmapExecutable": ( + "jax.interpreters.pxla.PmapExecutable is deprecated as of JAX v0.8.2.", + _deprecated_pxla.PmapExecutable, + ), + "global_aval_to_result_handler": ( + "jax.interpreters.pxla.global_aval_to_result_handler is deprecated as of JAX v0.8.2.", + _deprecated_pxla.global_aval_to_result_handler, + ), + "global_avals_to_results_handler": ( + "jax.interpreters.pxla.global_avals_to_results_handler is deprecated as of JAX v0.8.2.", + _deprecated_pxla.global_avals_to_results_handler, + ), + "global_result_handlers": ( + "jax.interpreters.pxla.global_result_handlers is deprecated as of JAX v0.8.2.", + _deprecated_pxla.global_result_handlers, + ), + "parallel_callable": ( + "jax.interpreters.pxla.parallel_callable is deprecated as of JAX v0.8.2.", + _deprecated_pxla.parallel_callable, + ), + "shard_args": ( + "jax.interpreters.pxla.shard_args is deprecated as of JAX v0.8.2.", + _deprecated_pxla.shard_args, + ), + "xla_pmap_p": ( + "jax.interpreters.pxla.xla_pmap_p is deprecated as of JAX v0.8.2.", + _deprecated_pxla.xla_pmap_p, + ), + "thread_resources": ( + "jax.interpreters.pxla.thread_resources is deprecated as of JAX v0.8.2.", + _deprecated_mesh.thread_resources, + ), + "are_hlo_shardings_equal": ( + "jax.interpreters.pxla.are_hlo_shardings_equal is deprecated as of JAX v0.8.2.", + _deprecated_op_shardings.are_hlo_shardings_equal, + ), + "is_hlo_sharding_replicated": ( + "jax.interpreters.pxla.is_hlo_sharding_replicated is deprecated as of JAX v0.8.2.", + _deprecated_op_shardings.is_hlo_sharding_replicated, + ), + "op_sharding_to_indices": ( + "jax.interpreters.pxla.op_sharding_to_indices is deprecated as of JAX v0.8.2.", + _deprecated_op_shardings.op_sharding_to_indices, + ), + "ArrayMapping": ( + "jax.interpreters.pxla.ArrayMapping is deprecated as of JAX v0.8.2.", + _deprecated_sharding_impls.ArrayMapping, + ), + "_UNSPECIFIED": ( + "jax.interpreters.pxla._UNSPECIFIED is deprecated as of JAX v0.8.2.", + _deprecated_sharding_impls.UNSPECIFIED, + ), + "array_mapping_to_axis_resources": ( + "jax.interpreters.pxla.array_mapping_to_axis_resources is deprecated as of JAX v0.8.2.", + _deprecated_sharding_impls.array_mapping_to_axis_resources, + ), + "Chunked": ( + "jax.interpreters.pxla.Chunked is deprecated as of JAX v0.8.2.", + _deprecated_sharding_specs.Chunked, + ), + "NoSharding": ( + "jax.interpreters.pxla.NoSharding is deprecated as of JAX v0.8.2.", + _deprecated_sharding_specs.NoSharding, + ), + "Replicated": ( + "jax.interpreters.pxla.Replicated is deprecated as of JAX v0.8.2.", + _deprecated_sharding_specs.Replicated, + ), + "ShardedAxis": ( + "jax.interpreters.pxla.ShardedAxis is deprecated as of JAX v0.8.2.", + _deprecated_sharding_specs.ShardedAxis, + ), + "ShardingSpec": ( + "jax.interpreters.pxla.ShardingSpec is deprecated as of JAX v0.8.2.", + _deprecated_sharding_specs.ShardingSpec, + ), + "Unstacked": ( + "jax.interpreters.pxla.Unstacked is deprecated as of JAX v0.8.2.", + _deprecated_sharding_specs.Unstacked, + ), + "spec_to_indices": ( + "jax.interpreters.pxla.spec_to_indices is deprecated as of JAX v0.8.2.", + _deprecated_sharding_specs.spec_to_indices, + ), +} -from jax._src.sharding_specs import ( - Chunked as Chunked, - NoSharding as NoSharding, - Replicated as Replicated, - ShardedAxis as ShardedAxis, - ShardingSpec as ShardingSpec, - Unstacked as Unstacked, - spec_to_indices as spec_to_indices, -) +import typing as _typing +if _typing.TYPE_CHECKING: + Index = _deprecated_pxla.Index + MapTracer = _deprecated_pxla.MapTracer + MeshAxisName = _deprecated_pxla.MeshAxisName + MeshComputation = _deprecated_pxla.MeshComputation + MeshExecutable = _deprecated_pxla.MeshExecutable + PmapExecutable = _deprecated_pxla.PmapExecutable + global_aval_to_result_handler = _deprecated_pxla.global_aval_to_result_handler + global_avals_to_results_handler = _deprecated_pxla.global_avals_to_results_handler + global_result_handlers = _deprecated_pxla.global_result_handlers + parallel_callable = _deprecated_pxla.parallel_callable + shard_args = _deprecated_pxla.shard_args + xla_pmap_p = _deprecated_pxla.xla_pmap_p + thread_resources = _deprecated_mesh.thread_resources + are_hlo_shardings_equal = _deprecated_op_shardings.are_hlo_shardings_equal + is_hlo_sharding_replicated = _deprecated_op_shardings.is_hlo_sharding_replicated + op_sharding_to_indices = _deprecated_op_shardings.op_sharding_to_indices + ArrayMapping = _deprecated_sharding_impls.ArrayMapping + _UNSPECIFIED = _deprecated_sharding_impls.UNSPECIFIED + array_mapping_to_axis_resources = _deprecated_sharding_impls.array_mapping_to_axis_resources + Chunked = _deprecated_sharding_specs.Chunked + NoSharding = _deprecated_sharding_specs.NoSharding + Replicated = _deprecated_sharding_specs.Replicated + ShardedAxis = _deprecated_sharding_specs.ShardedAxis + ShardingSpec = _deprecated_sharding_specs.ShardingSpec + Unstacked = _deprecated_sharding_specs.Unstacked + spec_to_indices = _deprecated_sharding_specs.spec_to_indices +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del _typing diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index bd3b83e37d24..4d7cc994f341 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jax._src.interpreters.xla import ( - canonicalize_dtype as canonicalize_dtype, - canonicalize_dtype_handlers as canonicalize_dtype_handlers, +__all__ = ["apply_primitive", "canonicalize_dtype_handlers", "Backend"] + +from jax._src.dtypes import ( + canonicalize_value_handlers as canonicalize_dtype_handlers ) from jax._src.dispatch import ( @@ -24,42 +25,3 @@ from jax._src.lib import xla_client as _xc Backend = _xc._xla.Client del _xc - -from jax._src import core as _src_core - -# Deprecations -_deprecations = { - # Added 2024-12-17 - "abstractify": ( - "jax.interpreters.xla.abstractify is deprecated.", - _src_core.abstractify - ), - "pytype_aval_mappings": ( - "jax.interpreters.xla.pytype_aval_mappings is deprecated.", - _src_core.pytype_aval_mappings - ), - # Finalized 2024-10-24; remove after 2025-01-24 - "xb": ( - ("jax.interpreters.xla.xb was removed in JAX v0.4.36. " - "Use jax.lib.xla_bridge instead."), None - ), - "xc": ( - ("jax.interpreters.xla.xc was removed in JAX v0.4.36. " - "Use jax.lib.xla_client instead."), None - ), - "xe": ( - ("jax.interpreters.xla.xe was removed in JAX v0.4.36. " - "Use jax.lib.xla_extension instead."), None - ), -} - -import typing as _typing -from jax._src.deprecations import deprecation_getattr as _deprecation_getattr -if _typing.TYPE_CHECKING: - abstractify = _src_core.abstractify - pytype_aval_mappings = _src_core.pytype_aval_mappings -else: - __getattr__ = _deprecation_getattr(__name__, _deprecations) -del _deprecation_getattr -del _typing -del _src_core diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index 4e376fb666d1..8a8acee84755 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -18,6 +18,8 @@ from jax._src.lax.lax import ( DotDimensionNumbers as DotDimensionNumbers, RaggedDotDimensionNumbers as RaggedDotDimensionNumbers, + AccuracyMode as AccuracyMode, + Tolerance as Tolerance, Precision as Precision, PrecisionLike as PrecisionLike, DotAlgorithm as DotAlgorithm, @@ -82,6 +84,8 @@ convert_element_type_p as convert_element_type_p, copy_p as copy_p, cos as cos, + dce_sink_p as dce_sink_p, + dce_sink as dce_sink, cos_p as cos_p, cosh as cosh, cosh_p as cosh_p, @@ -113,8 +117,6 @@ gt_p as gt_p, imag as imag, imag_p as imag_p, - infeed as infeed, - infeed_p as infeed_p, integer_pow as integer_pow, integer_pow_p as integer_pow_p, iota as iota, @@ -149,8 +151,6 @@ optimization_barrier as optimization_barrier, optimization_barrier_p as optimization_barrier_p, or_p as or_p, - outfeed as outfeed, - outfeed_p as outfeed_p, pad as pad, pad_p as pad_p, padtype_to_pads as padtype_to_pads, @@ -198,6 +198,7 @@ select as select, select_n as select_n, select_n_p as select_n_p, + shape_as_value as shape_as_value, shift_left as shift_left, shift_left_p as shift_left_p, shift_right_arithmetic as shift_right_arithmetic, @@ -228,12 +229,14 @@ tan_p as tan_p, tanh as tanh, tanh_p as tanh_p, + tile as tile, + tile_p as tile_p, top_k as top_k, top_k_p as top_k_p, transpose as transpose, transpose_p as transpose_p, xor_p as xor_p, - zeros_like_array as zeros_like_array, + empty as empty, ) from jax._src.lax.special import ( bessel_i0e as bessel_i0e, @@ -260,7 +263,6 @@ polygamma as polygamma, polygamma_p as polygamma_p, random_gamma_grad as random_gamma_grad, - random_gamma_grad_p as random_gamma_grad_p, regularized_incomplete_beta_p as regularized_incomplete_beta_p, zeta as zeta, zeta_p as zeta_p, @@ -356,11 +358,13 @@ ) from jax._src.lax.parallel import ( all_gather as all_gather, + pcast as pcast, all_gather_p as all_gather_p, all_to_all as all_to_all, all_to_all_p as all_to_all_p, axis_index as axis_index, axis_index_p as axis_index_p, + axis_size as axis_size, pbroadcast as pbroadcast, pmax as pmax, pmax_p as pmax_p, @@ -369,6 +373,8 @@ pmin_p as pmin_p, ppermute as ppermute, ppermute_p as ppermute_p, + psend as psend, + precv as precv, pshuffle as pshuffle, psum as psum, psum_p as psum_p, @@ -377,6 +383,9 @@ ragged_all_to_all as ragged_all_to_all, ragged_all_to_all_p as ragged_all_to_all_p, ) +from jax._src.core import ( + pvary as _deprecated_pvary, +) from jax._src.lax.other import ( conv_general_dilated_local as conv_general_dilated_local, conv_general_dilated_patches as conv_general_dilated_patches @@ -392,3 +401,31 @@ from jax._src.pjit import with_sharding_constraint as with_sharding_constraint from jax._src.pjit import sharding_constraint_p as sharding_constraint_p from jax._src.dispatch import device_put_p as device_put_p + +_deprecations = { + # Deprecated in v0.7.1; finalized in v0.9.0. + # TODO(jakevdp) remove entry in v0.10.0. + "zeros_like_array": ( + ( + "jax.lax.zeros_like_array was deprecated in JAX 0.7.1 and removed" + " in JAX v0.9.0. Use jax.numpy.zeros_like instead." + ), + None, + ), + # Deprecated in v0.8.2. + # TODO(jakevdp) finalize in v0.10.0. + "pvary": ( + "jax.lax.pvary is deprecated. Use `jax.lax.pcast(..., to='varying')", + _deprecated_pvary, + ), +} + +import typing as _typing +if _typing.TYPE_CHECKING: + pvary = _deprecated_pvary +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del _deprecated_pvary +del _typing diff --git a/jax/lax/linalg.py b/jax/lax/linalg.py index 343073ca56d0..0bda73f0c311 100644 --- a/jax/lax/linalg.py +++ b/jax/lax/linalg.py @@ -17,9 +17,11 @@ cholesky_p as cholesky_p, cholesky_update as cholesky_update, cholesky_update_p as cholesky_update_p, + EigImplementation as EigImplementation, eig as eig, eig_p as eig_p, eigh as eigh, + EighImplementation as EighImplementation, eigh_p as eigh_p, hessenberg as hessenberg, hessenberg_p as hessenberg_p, @@ -46,6 +48,6 @@ tridiagonal_solve_p as tridiagonal_solve_p, ) -from jax._src.lax.qdwh import ( +from jax._src.tpu.linalg.qdwh import ( qdwh as qdwh ) diff --git a/jax/lib/__init__.py b/jax/lib/__init__.py index 7534cd6c700f..46b3668e0fdf 100644 --- a/jax/lib/__init__.py +++ b/jax/lib/__init__.py @@ -16,8 +16,3 @@ from jax._src.lib import ( version_str as __version__, ) -from jax.lib import ( - xla_bridge as xla_bridge, - xla_client as xla_client, - xla_extension as xla_extension, -) diff --git a/jax/lib/xla_bridge.py b/jax/lib/xla_bridge.py deleted file mode 100644 index b158d9b1ff51..000000000000 --- a/jax/lib/xla_bridge.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright 2018 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# ruff: noqa: F401 -from jax._src.xla_bridge import ( - get_backend as _deprecated_get_backend, -) - -from jax._src.compiler import ( - get_compile_options as get_compile_options, -) - -_deprecations = { - # Added July 31, 2024 - "get_backend": ( - "jax.lib.xla_bridge.get_backend is deprecated; use jax.extend.backend.get_backend.", - _deprecated_get_backend - ), - # Finalized 2024-12-11; remove after 2025-3-11 - "xla_client": ( - "jax.lib.xla_bridge.xla_client was removed in JAX v0.4.38; use jax.lib.xla_client directly.", - None - ), - "default_backend": ( - "jax.lib.xla_bridge.default_backend was removed in JAX v0.4.38; use jax.default_backend.", - None - ), -} - -import typing as _typing -if _typing.TYPE_CHECKING: - from jax._src.xla_bridge import get_backend as get_backend -else: - from jax._src.deprecations import deprecation_getattr as _deprecation_getattr - __getattr__ = _deprecation_getattr(__name__, _deprecations) - del _deprecation_getattr -del _typing diff --git a/jax/lib/xla_client.py b/jax/lib/xla_client.py deleted file mode 100644 index 86e7307c804b..000000000000 --- a/jax/lib/xla_client.py +++ /dev/null @@ -1,127 +0,0 @@ -# Copyright 2024 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from jax._src.lax.fft import FftType as _FftType -from jax._src.lib import xla_client as _xc - -get_topology_for_devices = _xc.get_topology_for_devices -heap_profile = _xc.heap_profile -mlir_api_version = _xc.mlir_api_version -Client = _xc.Client -CompileOptions = _xc.CompileOptions -DeviceAssignment = _xc.DeviceAssignment -Frame = _xc.Frame -HloSharding = _xc.HloSharding -OpSharding = _xc.OpSharding -Traceback = _xc.Traceback - -_deprecations = { - # Finalized 2024-12-11; remove after 2025-3-11 - "_xla": ( - "jax.lib.xla_client._xla was removed in JAX v0.4.38; use jax.lib.xla_extension.", - None, - ), - "bfloat16": ( - "jax.lib.xla_client.bfloat16 was removed in JAX v0.4.38; use ml_dtypes.bfloat16.", - None, - ), - # Finalized 2024-12-23; remove after 2024-03-23 - "Device": ( - "jax.lib.xla_client.Device is deprecated; use jax.Device instead.", - None, - ), - "XlaRuntimeError": ( - ( - "jax.lib.xla_client.XlaRuntimeError is deprecated; use" - " jax.errors.JaxRuntimeError." - ), - None, - ), - # Added Oct 10 2024 - "FftType": ( - "jax.lib.xla_client.FftType is deprecated; use jax.lax.FftType.", - _FftType, - ), - "PaddingType": ( - ( - "jax.lib.xla_client.PaddingType is deprecated; this type is unused" - " by JAX so there is no replacement." - ), - _xc.PaddingType, - ), - # Added Oct 11 2024 - "dtype_to_etype": ( - "dtype_to_etype is deprecated; use StableHLO instead.", - _xc.dtype_to_etype, - ), - "ops": ( - "ops is deprecated; use StableHLO instead.", - _xc.ops, - ), - "register_custom_call_target": ( - "register_custom_call_target is deprecated; use the JAX FFI instead " - "(https://jax.readthedocs.io/en/latest/ffi.html)", - _xc.register_custom_call_target, - ), - "shape_from_pyval": ( - "shape_from_pyval is deprecated; use StableHLO instead.", - _xc.shape_from_pyval, - ), - "PrimitiveType": ( - "PrimitiveType is deprecated; use StableHLO instead.", - _xc.PrimitiveType, - ), - "Shape": ( - "Shape is deprecated; use StableHLO instead.", - _xc.Shape, - ), - "XlaBuilder": ( - "XlaBuilder is deprecated; use StableHLO instead.", - _xc.XlaBuilder, - ), - "XlaComputation": ( - "XlaComputation is deprecated; use StableHLO instead.", - _xc.XlaComputation, - ), - # Added Nov 20 2024 - "ArrayImpl": ( - "jax.lib.xla_client.ArrayImpl is deprecated; use jax.Array instead.", - _xc.ArrayImpl, - ), -} - -import typing as _typing - -if _typing.TYPE_CHECKING: - dtype_to_etype = _xc.dtype_to_etype - ops = _xc.ops - register_custom_call_target = _xc.register_custom_call_target - shape_from_pyval = _xc.shape_from_pyval - ArrayImpl = _xc.ArrayImpl - Device = _xc.Device - FftType = _FftType - PaddingType = _xc.PaddingType - PrimitiveType = _xc.PrimitiveType - Shape = _xc.Shape - XlaBuilder = _xc.XlaBuilder - XlaComputation = _xc.XlaComputation - XlaRuntimeError = _xc.XlaRuntimeError -else: - from jax._src.deprecations import deprecation_getattr as _deprecation_getattr - - __getattr__ = _deprecation_getattr(__name__, _deprecations) - del _deprecation_getattr -del _typing -del _FftType -del _xc diff --git a/jax/lib/xla_extension.py b/jax/lib/xla_extension.py deleted file mode 100644 index 52fe94e231d1..000000000000 --- a/jax/lib/xla_extension.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2024 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from jax._src.lib import xla_extension as _xe - -get_distributed_runtime_client = _xe.get_distributed_runtime_client -get_distributed_runtime_service = _xe.get_distributed_runtime_service -hlo_module_cost_analysis = _xe.hlo_module_cost_analysis -hlo_module_to_dot_graph = _xe.hlo_module_to_dot_graph -ifrt_proxy = _xe.ifrt_proxy -jax_jit = _xe.jax_jit -mlir = _xe.mlir -pmap_lib = _xe.pmap_lib -profiler = _xe.profiler -pytree = _xe.pytree -Device = _xe.Device -DistributedRuntimeClient = _xe.DistributedRuntimeClient -HloModule = _xe.HloModule -HloPrintOptions = _xe.HloPrintOptions -OpSharding = _xe.OpSharding -PjitFunctionCache = _xe.PjitFunctionCache -PjitFunction = _xe.PjitFunction -PmapFunction = _xe.PmapFunction - -_deprecations = { - # Added Nov 20 2024 - "ArrayImpl": ( - "jax.lib.xla_extension.ArrayImpl is deprecated; use jax.Array instead.", - _xe.ArrayImpl, - ), - "XlaRuntimeError": ( - "jax.lib.xla_extension.XlaRuntimeError is deprecated; use jax.errors.JaxRuntimeError instead.", - _xe.XlaRuntimeError, - ), -} - -import typing as _typing - -if _typing.TYPE_CHECKING: - ArrayImpl = _xe.ArrayImpl - XlaRuntimeError = _xe.XlaRuntimeError -else: - from jax._src.deprecations import deprecation_getattr as _deprecation_getattr - - __getattr__ = _deprecation_getattr(__name__, _deprecations) - del _deprecation_getattr -del _typing -del _xe diff --git a/jax/memory.py b/jax/memory.py new file mode 100644 index 000000000000..3269f5887186 --- /dev/null +++ b/jax/memory.py @@ -0,0 +1,15 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from jax._src.memory import Space as Space diff --git a/jax/monitoring.py b/jax/monitoring.py index 4c9996da582c..3c3a11208403 100644 --- a/jax/monitoring.py +++ b/jax/monitoring.py @@ -26,7 +26,13 @@ record_event_duration_secs as record_event_duration_secs, record_event_time_span as record_event_time_span, record_event as record_event, + record_scalar as record_scalar, register_event_duration_secs_listener as register_event_duration_secs_listener, register_event_listener as register_event_listener, register_event_time_span_listener as register_event_time_span_listener, + register_scalar_listener as register_scalar_listener, + unregister_event_duration_listener as unregister_event_duration_listener, + unregister_event_listener as unregister_event_listener, + unregister_event_time_span_listener as unregister_event_time_span_listener, + unregister_scalar_listener as unregister_scalar_listener, ) diff --git a/jax/nn/__init__.py b/jax/nn/__init__.py index 3f08e1c0fd12..7c2da998d065 100644 --- a/jax/nn/__init__.py +++ b/jax/nn/__init__.py @@ -31,12 +31,15 @@ leaky_relu as leaky_relu, log_sigmoid as log_sigmoid, log_softmax as log_softmax, + logmeanexp as logmeanexp, logsumexp as logsumexp, standardize as standardize, one_hot as one_hot, relu as relu, + identity as identity, relu6 as relu6, dot_product_attention as dot_product_attention, + get_scaled_dot_general_config as get_scaled_dot_general_config, scaled_dot_general as scaled_dot_general, scaled_matmul as scaled_matmul, selu as selu, @@ -50,4 +53,5 @@ swish as swish, squareplus as squareplus, mish as mish, + log1mexp as log1mexp, ) diff --git a/jax/nn/__init__.pyi b/jax/nn/__init__.pyi new file mode 100644 index 000000000000..9e01197f3ff5 --- /dev/null +++ b/jax/nn/__init__.pyi @@ -0,0 +1,160 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, List, Literal, overload, Sequence + +from jax._src.core import AxisName +from jax._src.cudnn.scaled_matmul_stablehlo import BlockScaleConfig +from jax._src.lax.lax import DotDimensionNumbers +from jax._src.typing import Array, ArrayLike, DTypeLike + +from jax.nn import initializers as initializers + +Axis = int | Sequence[int] | None + + +def celu(x: ArrayLike, alpha: ArrayLike = ...) -> Array: ... +@overload +def dot_product_attention( + query: ArrayLike, + key: ArrayLike, + value: ArrayLike, + bias: ArrayLike | None = ..., + mask: ArrayLike | None = ..., + *, + scale: float | None = ..., + is_causal: bool = ..., + query_seq_lengths: ArrayLike | None = ..., + key_value_seq_lengths: ArrayLike | None = ..., + local_window_size: int | tuple[int, int] | None = ..., + implementation: Literal['xla', 'cudnn'] | None = ..., + return_residual: Literal[False] = ..., + ) -> Array: ... +@overload +def dot_product_attention( + query: ArrayLike, + key: ArrayLike, + value: ArrayLike, + bias: ArrayLike | None = ..., + mask: ArrayLike | None = ..., + *, + scale: float | None = ..., + is_causal: bool = ..., + query_seq_lengths: ArrayLike | None = ..., + key_value_seq_lengths: ArrayLike | None = ..., + local_window_size: int | tuple[int, int] | None = ..., + implementation: Literal['xla', 'cudnn'] | None = ..., + return_residual: Literal[True] = ..., + ) -> tuple[Array, Array]: ... +def elu(x: ArrayLike, alpha: ArrayLike = ...) -> Array: ... +def gelu(x: ArrayLike, approximate: bool = ...) -> Array: ... +def get_scaled_dot_general_config( + mode: Literal['nvfp4', 'mxfp8'], + global_scale: Array | None = ..., + ) -> BlockScaleConfig: ... +def glu(x: ArrayLike, axis: int = ...) -> Array: ... +def hard_sigmoid(x: ArrayLike) -> Array: ... +def hard_silu(x: ArrayLike) -> Array: ... +def hard_swish(x: ArrayLike) -> Array: ... +def hard_tanh(x: ArrayLike) -> Array: ... +def identity(x: ArrayLike) -> Array: ... +def leaky_relu(x: ArrayLike, negative_slope: ArrayLike = ...) -> Array: ... +def log_sigmoid(x: ArrayLike) -> Array: ... +def log_softmax( + x: ArrayLike, + axis: Axis = ..., + where: ArrayLike | None = ..., + ) -> Array: ... +def logmeanexp( + x: ArrayLike, + axis: Axis = None, + where: ArrayLike | None = None, + keepdims: bool = False, +) -> Array: ... +@overload +def logsumexp( + a: ArrayLike, + axis: Axis = ..., + b: ArrayLike | None = ..., + keepdims: bool = ..., + return_sign: Literal[False] = ..., + where: ArrayLike | None = ..., + ) -> Array: ... +@overload +def logsumexp( + a: ArrayLike, + axis: Axis = ..., + b: ArrayLike | None = ..., + keepdims: bool = ..., + *, + return_sign: Literal[True], + where: ArrayLike | None = ..., + ) -> tuple[Array, Array]: ... +@overload +def logsumexp( + a: ArrayLike, + axis: Axis = ..., + b: ArrayLike | None = ..., + keepdims: bool = ..., + return_sign: bool = ..., + where: ArrayLike | None = ..., + ) -> Array | tuple[Array, Array]: ... +def mish(x: ArrayLike) -> Array: ... +def one_hot( + x: Any, + num_classes: int, + *, + dtype: Any = ..., + axis: int | AxisName = ... + ) -> Array: ... +def relu(x: ArrayLike) -> Array: ... +def relu6(x: ArrayLike) -> Array: ... +def scaled_dot_general( + lhs: ArrayLike, rhs: ArrayLike, + dimension_numbers: DotDimensionNumbers, + preferred_element_type: DTypeLike = ..., + configs: List[BlockScaleConfig] | None = ..., + implementation: Literal['cudnn'] | None = ..., + ) -> Array: ... +def scaled_matmul( + lhs: Array, + rhs: Array, + lhs_scales: Array, + rhs_scales: Array, + preferred_element_type: DTypeLike = ..., + ) -> Array: ... +def selu(x: ArrayLike) -> Array: ... +def sigmoid(x: ArrayLike) -> Array: ... +def silu(x: ArrayLike) -> Array: ... +def soft_sign(x: ArrayLike) -> Array: ... +def softmax( + x: ArrayLike, + axis: Axis = ..., + where: ArrayLike | None = ... + ) -> Array: ... +def softplus(x: ArrayLike) -> Array: ... +def sparse_plus(x: ArrayLike) -> Array: ... +def sparse_sigmoid(x: ArrayLike) -> Array: ... +def squareplus(x: ArrayLike, b: ArrayLike = ...) -> Array: ... +def standardize( + x: ArrayLike, + axis: Axis = ..., + mean: ArrayLike | None = ..., + variance: ArrayLike | None = ..., + epsilon: ArrayLike = ..., + where: ArrayLike | None = ... + ) -> Array: ... +def swish(x: ArrayLike) -> Array: ... +def tanh(x: ArrayLike, /) -> Array: ... +def log1mexp(x: ArrayLike) -> Array: ... diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index cb291bdca79a..c35984b691e5 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -24,6 +24,11 @@ isdtype as isdtype, ) +from jax._src.numpy.array_constructors import ( + array as array, + asarray as asarray, +) + from jax._src.numpy.lax_numpy import ( ComplexWarning as ComplexWarning, allclose as allclose, @@ -36,12 +41,10 @@ argmin as argmin, argwhere as argwhere, around as around, - array as array, array_equal as array_equal, array_equiv as array_equiv, array_split as array_split, astype as astype, - asarray as asarray, atleast_1d as atleast_1d, atleast_2d as atleast_2d, atleast_3d as atleast_3d, @@ -79,7 +82,7 @@ eye as eye, fill_diagonal as fill_diagonal, finfo as finfo, - fix as fix, + fix as _deprecated_fix, flatnonzero as flatnonzero, flip as flip, fliplr as fliplr, @@ -93,7 +96,6 @@ fromstring as fromstring, from_dlpack as from_dlpack, gcd as gcd, - geomspace as geomspace, get_printoptions as get_printoptions, gradient as gradient, histogram as histogram, @@ -118,9 +120,7 @@ ix_ as ix_, kron as kron, lcm as lcm, - linspace as linspace, load as load, - logspace as logspace, mask_indices as mask_indices, matrix_transpose as matrix_transpose, meshgrid as meshgrid, @@ -180,6 +180,9 @@ empty_like as empty_like, full as full, full_like as full_like, + geomspace as geomspace, + linspace as linspace, + logspace as logspace, ones as ones, ones_like as ones_like, zeros as zeros, @@ -211,13 +214,18 @@ double as double, float16 as float16, float32 as float32, + float4_e2m1fn as float4_e2m1fn, float64 as float64, + float8_e3m4 as float8_e3m4, + float8_e4m3 as float8_e4m3, float8_e4m3b11fnuz as float8_e4m3b11fnuz, float8_e4m3fn as float8_e4m3fn, float8_e4m3fnuz as float8_e4m3fnuz, float8_e5m2 as float8_e5m2, float8_e5m2fnuz as float8_e5m2fnuz, + float8_e8m0fnu as float8_e8m0fnu, float_ as float_, + int2 as int2, int4 as int4, int8 as int8, int16 as int16, @@ -226,6 +234,7 @@ int_ as int_, single as single, uint as uint, + uint2 as uint2, uint4 as uint4, uint8 as uint8, uint16 as uint16, @@ -295,26 +304,6 @@ unsignedinteger as unsignedinteger, ) -# TODO(slebedev): Remove the try-except once we upgrade to ml_dtypes 0.4.1. -try: - from jax._src.numpy.scalar_types import ( - int2 as int2, - uint2 as uint2, - ) -except ImportError: - pass - -# TODO: Remove the try-except once we upgrade to ml_dtypes 0.5.0 -try: - from jax._src.numpy.scalar_types import ( - float8_e3m4 as float8_e3m4, - float8_e4m3 as float8_e4m3, - float8_e8m0fnu as float8_e8m0fnu, - float4_e2m1fn as float4_e2m1fn, - ) -except ImportError: - pass - from jax._src.numpy.array_api_metadata import ( __array_api_version__ as __array_api_version__, __array_namespace_info__ as __array_namespace_info__, @@ -509,16 +498,22 @@ _deprecations = { - # Finalized 2024-12-13; remove after 2024-3-13 - "round_": ( - "jnp.round_ was deprecated in JAX 0.4.38; use jnp.round instead.", - None + # Deprecated in v0.9.0 + "fix": ( + ( + "jax.numpy.fix was deprecated in JAX v0.9.0, and will be" + " removed in JAX v0.10.0. Use jax.numpy.trunc instead." + ), + _deprecated_fix, ), } -import typing -if not typing.TYPE_CHECKING: +import typing as _typing +if _typing.TYPE_CHECKING: + fix = _deprecated_fix +else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr __getattr__ = _deprecation_getattr(__name__, _deprecations) del _deprecation_getattr -del typing +del _typing +del _deprecated_fix diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index b73a3b95b9a5..2c0ce335441b 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -15,7 +15,7 @@ from jax._src.numpy.index_tricks import _Mgrid, _Ogrid, CClass as _CClass, RClas from jax._src.numpy.array_api_metadata import ArrayNamespaceInfo from jax._src.typing import ( Array, ArrayLike, DType, DTypeLike, DeprecatedArg, - DimSize, DuckTypedArray, Shape, StaticScalar, + DimSize, DuckTypedArray, Shape, StaticScalar, SupportsNdim, SupportsShape, SupportsSize, ) from jax._src.sharding_impls import NamedSharding, PartitionSpec as P from jax.numpy import fft as fft, linalg as linalg @@ -135,6 +135,7 @@ def arange( step: ArrayLike | None = ..., dtype: DTypeLike | None = ..., *, device: _Device | _Sharding | None = ..., + out_sharding: NamedSharding | P | None = ..., ) -> Array: ... def arccos(x: ArrayLike, /) -> Array: ... def arccosh(x: ArrayLike, /) -> Array: ... @@ -174,7 +175,8 @@ def argwhere( def around(a: ArrayLike, decimals: int = ..., out: None = ...) -> Array: ... def array(object: Any, dtype: DTypeLike | None = ..., copy: builtins.bool = True, order: str | None = ..., ndmin: int = ..., *, - device: _Device | _Sharding | None = None) -> Array: ... + device: _Device | _Sharding | None = None, + out_sharding: NamedSharding | P | None = None) -> Array: ... def array_equal( a1: ArrayLike, a2: ArrayLike, equal_nan: builtins.bool = ... ) -> Array: ... @@ -190,6 +192,7 @@ def asarray( a: Any, dtype: DTypeLike | None = ..., order: str | None = ..., *, copy: builtins.bool | None = ..., device: _Device | _Sharding | None = ..., + out_sharding: NamedSharding | P | None = ..., ) -> Array: ... def asin(x: ArrayLike, /) -> Array: ... def asinh(x: ArrayLike, /) -> Array: ... @@ -253,7 +256,8 @@ def broadcast_shapes(*shapes: Sequence[int]) -> tuple[int, ...]: ... def broadcast_shapes(*shapes: Sequence[int | _core.Tracer] ) -> tuple[int | _core.Tracer, ...]: ... -def broadcast_to(array: ArrayLike, shape: DimSize | Shape) -> Array: ... +def broadcast_to(array: ArrayLike, shape: DimSize | Shape, *, + out_sharding: NamedSharding | P | None = None) -> Array: ... c_: _CClass can_cast = _np.can_cast def cbrt(x: ArrayLike, /) -> Array: ... @@ -267,6 +271,7 @@ def clip( /, min: ArrayLike | None = ..., max: ArrayLike | None = ..., + *, a: ArrayLike | DeprecatedArg | None = ..., a_min: ArrayLike | DeprecatedArg | None = ..., a_max: ArrayLike | DeprecatedArg | None = ... @@ -278,7 +283,7 @@ complex128: Any complex64: Any complex_: Any complexfloating = _np.complexfloating -def compress(condition: ArrayLike, a: ArrayLike, axis: int | None = ..., +def compress(condition: ArrayLike, a: ArrayLike, axis: int | None = ..., *, size: int | None = ..., fill_value: ArrayLike = ..., out: None = ...) -> Array: ... def concat(arrays: Sequence[ArrayLike], /, *, axis: int | None = 0) -> Array: ... def concatenate( @@ -293,7 +298,10 @@ def convolve(a: ArrayLike, v: ArrayLike, mode: str = ..., *, preferred_element_type: DTypeLike | None = ...) -> Array: ... def copy(a: ArrayLike, order: str | None = ...) -> Array: ... def copysign(x: ArrayLike, y: ArrayLike, /) -> Array: ... -def corrcoef(x: ArrayLike, y: ArrayLike | None = ..., rowvar: builtins.bool = ...) -> Array: ... +def corrcoef(x: ArrayLike, + y: ArrayLike | None = ..., + rowvar: builtins.bool = ..., + dtype: DTypeLike | None = ...) -> Array: ... def correlate(a: ArrayLike, v: ArrayLike, mode: str = ..., *, precision: PrecisionLike = ..., preferred_element_type: DTypeLike | None = ...) -> Array: ... @@ -304,7 +312,8 @@ def count_nonzero(a: ArrayLike, axis: _Axis = ..., def cov(m: ArrayLike, y: ArrayLike | None = ..., rowvar: builtins.bool = ..., bias: builtins.bool = ..., ddof: int | None = ..., fweights: ArrayLike | None = ..., - aweights: ArrayLike | None = ...) -> Array: ... + aweights: ArrayLike | None = ..., + dtype: DTypeLike | None = ...) -> Array: ... def cross( a: ArrayLike, b: ArrayLike, @@ -314,9 +323,9 @@ def cross( axis: int | None = ..., ) -> Array: ... csingle: Any -def cumprod(a: ArrayLike, axis: int | None = ..., dtype: DTypeLike = ..., +def cumprod(a: ArrayLike, axis: int | None = ..., dtype: DTypeLike | None = ..., out: None = ...) -> Array: ... -def cumsum(a: ArrayLike, axis: int | None = ..., dtype: DTypeLike = ..., +def cumsum(a: ArrayLike, axis: int | None = ..., dtype: DTypeLike | None = ..., out: None = ...) -> Array: ... def cumulative_prod(x: ArrayLike, /, *, axis: int | None = ..., dtype: DTypeLike | None = ..., @@ -350,7 +359,8 @@ def divide(x: ArrayLike, y: ArrayLike, /) -> Array: ... def divmod(x: ArrayLike, y: ArrayLike, /) -> tuple[Array, Array]: ... def dot( a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = ..., - preferred_element_type: DTypeLike | None = ...) -> Array: ... + preferred_element_type: DTypeLike | None = ..., + out_sharding: NamedSharding | P | None = ...) -> Array: ... double: Any def dsplit( ary: ArrayLike, indices_or_sections: int | ArrayLike @@ -370,7 +380,6 @@ def einsum( optimize: str | builtins.bool | list[tuple[int, ...]] = ..., precision: PrecisionLike = ..., preferred_element_type: DTypeLike | None = ..., - _use_xeinsum: builtins.bool = False, _dot_general: Callable[..., Array] = ..., out_sharding: NamedSharding | P | None = ..., ) -> Array: ... @@ -384,7 +393,6 @@ def einsum( optimize: str | builtins.bool | list[tuple[int, ...]] = ..., precision: PrecisionLike = ..., preferred_element_type: DTypeLike | None = ..., - _use_xeinsum: builtins.bool = False, _dot_general: Callable[..., Array] = ..., out_sharding: NamedSharding | P | None = ..., ) -> Array: ... @@ -396,7 +404,6 @@ def einsum( optimize: str | builtins.bool | list[tuple[int, ...]] = ..., precision: PrecisionLike = ..., preferred_element_type: DTypeLike | None = ..., - _use_xeinsum: builtins.bool = ..., _dot_general: Callable[..., Array] = ..., out_sharding: NamedSharding | P | None = ..., ) -> Array: ... @@ -421,8 +428,9 @@ def einsum_path( optimize: str | builtins.bool | list[tuple[int, ...]] = ..., ) -> tuple[list[tuple[int, ...]], Any]: ... -def empty(shape: Any, dtype: DTypeLike | None = ..., - device: _Device | _Sharding | None = ...) -> Array: ... +def empty(shape: Any, dtype: DTypeLike | None = ..., *, + device: _Device | _Sharding | None = ..., + out_sharding: NamedSharding | P | None = ...) -> Array: ... def empty_like(prototype: ArrayLike | DuckTypedArray, dtype: DTypeLike | None = ..., shape: Any = ..., *, @@ -456,12 +464,16 @@ def fliplr(m: ArrayLike) -> Array: ... def flipud(m: ArrayLike) -> Array: ... float16: Any float32: Any +float4_e2m1fn: Any float64: Any +float8_e3m4: Any +float8_e4m3: Any float8_e4m3b11fnuz: Any float8_e4m3fn: Any float8_e4m3fnuz: Any float8_e5m2: Any float8_e5m2fnuz: Any +float8_e8m0fnu: Any float_: Any def float_power(x: ArrayLike, y: ArrayLike, /) -> Array: ... floating = _np.floating @@ -562,6 +574,7 @@ def inner( def insert(arr: ArrayLike, obj: ArrayLike | slice, values: ArrayLike, axis: int | None = ...) -> Array: ... int16: Any +int2: Any int32: Any int4: Any int64: Any @@ -578,17 +591,17 @@ def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: builtins.bool = . def invert(x: ArrayLike, /) -> Array: ... def isclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = ..., atol: ArrayLike = ..., equal_nan: builtins.bool = ...) -> Array: ... -def iscomplex(m: ArrayLike) -> Array: ... +def iscomplex(x: ArrayLike) -> Array: ... def iscomplexobj(x: Any) -> builtins.bool: ... def isdtype(dtype: DTypeLike, kind: DType | str | tuple[DType | str, ...]) -> builtins.bool: ... def isfinite(x: ArrayLike, /) -> Array: ... -def isin(element: ArrayLike, test_elements: ArrayLike, - assume_unique: builtins.bool = ..., invert: builtins.bool = ...) -> Array: ... +def isin(element: ArrayLike, test_elements: ArrayLike, assume_unique: builtins.bool = ..., + invert: builtins.bool = ..., *, method: str = ...) -> Array: ... def isinf(x: ArrayLike, /) -> Array: ... def isnan(x: ArrayLike, /) -> Array: ... def isneginf(x: ArrayLike, /) -> Array: ... def isposinf(x: ArrayLike, /) -> Array: ... -def isreal(m: ArrayLike) -> Array: ... +def isreal(x: ArrayLike) -> Array: ... def isrealobj(x: Any) -> builtins.bool: ... def isscalar(element: Any) -> builtins.bool: ... def issubdtype(arg1: DTypeLike, arg2: DTypeLike) -> builtins.bool: ... @@ -643,18 +656,19 @@ def logspace(start: ArrayLike, stop: ArrayLike, num: int = ..., endpoint: builtins.bool = ..., base: ArrayLike = ..., dtype: DTypeLike | None = ..., axis: int = ...) -> Array: ... def mask_indices( - n: int, mask_func: Callable, k: int = ... + n: int, mask_func: Callable, k: int = ..., *, size: int | None = ... ) -> tuple[Array, ...]: ... def matmul( a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = ..., - preferred_element_type: DTypeLike | None = ...) -> Array: ... + preferred_element_type: DTypeLike | None = ..., + out_sharding: NamedSharding | P | None = ...) -> Array: ... def matrix_transpose(x: ArrayLike, /) -> Array: ... def matvec(x1: ArrayLike, x2: ArrayLike, /) -> Array: ... def max(a: ArrayLike, axis: _Axis = ..., out: None = ..., keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., where: ArrayLike | None = ...) -> Array: ... -def maximum(x: ArrayLike, y: ArrayLike, /) -> Array: ... -def mean(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., +maximum: BinaryUfunc +def mean(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike | None = ..., out: None = ..., keepdims: builtins.bool = ..., *, where: ArrayLike | None = ...) -> Array: ... def median(a: ArrayLike, axis: int | tuple[int, ...] | None = ..., @@ -666,7 +680,7 @@ mgrid: _Mgrid def min(a: ArrayLike, axis: _Axis = ..., out: None = ..., keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., where: ArrayLike | None = ...) -> Array: ... -def minimum(x: ArrayLike, y: ArrayLike, /) -> Array: ... +minimum: BinaryUfunc def mod(x: ArrayLike, y: ArrayLike, /) -> Array: ... def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]: ... def moveaxis(a: ArrayLike, source: int | Sequence[int], @@ -688,14 +702,14 @@ def nanargmin( out: None = ..., keepdims: builtins.bool | None = ..., ) -> Array: ... -def nancumprod(a: ArrayLike, axis: int | None = ..., dtype: DTypeLike = ..., +def nancumprod(a: ArrayLike, axis: int | None = ..., dtype: DTypeLike | None = ..., out: None = ...) -> Array: ... -def nancumsum(a: ArrayLike, axis: int | None = ..., dtype: DTypeLike = ..., +def nancumsum(a: ArrayLike, axis: int | None = ..., dtype: DTypeLike | None = ..., out: None = ...) -> Array: ... def nanmax(a: ArrayLike, axis: _Axis = ..., out: None = ..., keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., where: ArrayLike | None = ...) -> Array: ... -def nanmean(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., +def nanmean(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike | None = ..., out: None = ..., keepdims: builtins.bool = ..., where: ArrayLike | None = ...) -> Array: ... @@ -709,26 +723,26 @@ def nanpercentile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = ..., out: None = ..., overwrite_input: builtins.bool = ..., method: str = ..., keepdims: builtins.bool = ..., *, interpolation: DeprecatedArg | str = ...) -> Array: ... -def nanprod(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., +def nanprod(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike | None = ..., out: None = ..., keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., where: ArrayLike | None = ...) -> Array: ... def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = ..., out: None = ..., overwrite_input: builtins.bool = ..., method: str = ..., keepdims: builtins.bool = ..., *, interpolation: DeprecatedArg | str = ...) -> Array: ... -def nanstd(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ..., - ddof: int = ..., keepdims: builtins.bool = ..., +def nanstd(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike | None = ..., + out: None = ..., ddof: int = ..., keepdims: builtins.bool = ..., where: ArrayLike | None = ...) -> Array: ... -def nansum(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., +def nansum(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike | None = ..., out: None = ..., keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., where: ArrayLike | None = ...) -> Array: ... -def nanvar(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., +def nanvar(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike | None = ..., out: None = ..., ddof: int = 0, keepdims: builtins.bool = False, where: ArrayLike | None = ...) -> Array: ... ndarray = Array -def ndim(a: ArrayLike) -> int: ... +def ndim(a: ArrayLike | SupportsNdim) -> int: ... def negative(x: ArrayLike, /) -> Array: ... newaxis = None def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array: ... @@ -739,12 +753,14 @@ def not_equal(x: ArrayLike, y: ArrayLike, /) -> Array: ... number = _np.number object_ = _np.object_ ogrid: _Ogrid -def ones(shape: Any, dtype: DTypeLike | None = ..., - device: _Device | _Sharding | None = ...) -> Array: ... +def ones(shape: Any, dtype: DTypeLike | None = ..., *, + device: _Device | _Sharding | None = ..., + out_sharding: NamedSharding | P | None = ...) -> Array: ... def ones_like(a: ArrayLike | DuckTypedArray, dtype: DTypeLike | None = ..., shape: Any = ..., *, - device: _Device | _Sharding | None = ...) -> Array: ... + device: _Device | _Sharding | None = ..., + out_sharding: NamedSharding | P | None = ...) -> Array: ... def outer(a: ArrayLike, b: Array, out: None = ...) -> Array: ... def packbits( a: ArrayLike, axis: int | None = ..., bitorder: str = ... @@ -781,7 +797,7 @@ def positive(x: ArrayLike, /) -> Array: ... def pow(x: ArrayLike, y: ArrayLike, /) -> Array: ... def power(x: ArrayLike, y: ArrayLike, /) -> Array: ... printoptions = _np.printoptions -def prod(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., +def prod(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike | None = ..., out: None = ..., keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., where: ArrayLike | None = ..., promote_integers: builtins.bool = ...) -> Array: ... @@ -798,18 +814,19 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = .. r_: _RClass def rad2deg(x: ArrayLike, /) -> Array: ... def radians(x: ArrayLike, /) -> Array: ... -def ravel(a: ArrayLike, order: str = ...) -> Array: ... +def ravel(a: ArrayLike, order: str = ..., *, + out_sharding: NamedSharding | P | None = ...) -> Array: ... def ravel_multi_index(multi_index: Sequence[ArrayLike], dims: Sequence[int], mode: str = ..., order: str = ...) -> Array: ... def real(x: ArrayLike, /) -> Array: ... def reciprocal(x: ArrayLike, /) -> Array: ... -register_jax_array_methods: Any def remainder(x: ArrayLike, y: ArrayLike, /) -> Array: ... def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = ..., *, - total_repeat_length: int | None = ...) -> Array: ... + total_repeat_length: int | None = ..., + out_sharding: NamedSharding | P | None = None) -> Array: ... def reshape( - a: ArrayLike, shape: DimSize | Shape = ..., - newshape: DimSize | Shape | None = ..., order: str = ... + a: ArrayLike, shape: DimSize | Shape, order: str = ..., *, copy: bool | None = ..., + out_sharding: NamedSharding | P | None = ..., ) -> Array: ... def resize(a: ArrayLike, new_shape: Shape) -> Array: ... @@ -841,8 +858,9 @@ def setdiff1d( size: int | None = ..., fill_value: ArrayLike | None = ..., ) -> Array: ... -def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: builtins.bool = ...) -> Array: ... -def shape(a: ArrayLike) -> tuple[int, ...]: ... +def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: builtins.bool = ..., *, + size: int | None = ..., fill_value: ArrayLike | None = ...) -> Array: ... +def shape(a: ArrayLike | SupportsShape) -> tuple[int, ...]: ... def sign(x: ArrayLike, /) -> Array: ... def signbit(x: ArrayLike, /) -> Array: ... signedinteger = _np.signedinteger @@ -850,7 +868,7 @@ def sin(x: ArrayLike, /) -> Array: ... def sinc(x: ArrayLike, /) -> Array: ... single: Any def sinh(x: ArrayLike, /) -> Array: ... -def size(a: ArrayLike, axis: int | None = None) -> int: ... +def size(a: ArrayLike | SupportsSize, axis: _Axis = None) -> int: ... def sort( a: ArrayLike, axis: int | None = ..., @@ -879,14 +897,14 @@ def stack( out: None = ..., dtype: DTypeLike | None = ..., ) -> Array: ... -def std(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., +def std(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike | None = ..., out: None = ..., ddof: int = ..., keepdims: builtins.bool = ..., *, where: ArrayLike | None = ..., correction: int | float | None = ...) -> Array: ... subtract: BinaryUfunc def sum( a: ArrayLike, axis: _Axis = ..., - dtype: DTypeLike = ..., + dtype: DTypeLike | None = ..., out: None = ..., keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., @@ -907,7 +925,7 @@ def take( def take_along_axis( arr: ArrayLike, indices: ArrayLike, - axis: int | None, + axis: int | None = ..., mode: str | GatherScatterMode | None = ..., fill_value: StaticScalar | None = None, ) -> Array: ... @@ -916,7 +934,8 @@ def tanh(x: ArrayLike, /) -> Array: ... def tensordot(a: ArrayLike, b: ArrayLike, axes: int | Sequence[int] | Sequence[Sequence[int]] = ..., *, precision: PrecisionLike = ..., - preferred_element_type: DTypeLike | None = ...) -> Array: ... + preferred_element_type: DTypeLike | None = ..., + out_sharding: NamedSharding | P | None = ...) -> Array: ... def tile(A: ArrayLike, reps: DimSize | Sequence[DimSize]) -> Array: ... def trace(a: ArrayLike, offset: int | ArrayLike = ..., axis1: int = ..., axis2: int = ..., dtype: DTypeLike | None = ..., out: None = ...) -> Array: ... @@ -924,24 +943,25 @@ def transpose(a: ArrayLike, axes: Sequence[int] | None = ...) -> Array: ... def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = ..., axis: int = ...) -> Array: ... def tri( - N: int, M: int | None = ..., k: int = ..., dtype: DTypeLike = ... + N: int, M: int | None = ..., k: int = ..., dtype: DTypeLike | None = ... ) -> Array: ... def tril(m: ArrayLike, k: int = ...) -> Array: ... def tril_indices( n: int, k: int = ..., m: int | None = ... ) -> tuple[Array, Array]: ... -def tril_indices_from(arr: ArrayLike, k: int = ...) -> tuple[Array, Array]: ... +def tril_indices_from(arr: ArrayLike | SupportsShape, k: int = ...) -> tuple[Array, Array]: ... def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap: builtins.bool = ..., *, inplace: builtins.bool = ...) -> Array: ... -def trim_zeros(filt: ArrayLike, trim: str = ...) -> Array: ... +def trim_zeros(filt: ArrayLike, trim: str = ..., axis: int | Sequence[int] | None = ...) -> Array: ... def triu(m: ArrayLike, k: int = ...) -> Array: ... def triu_indices( n: int, k: int = ..., m: int | None = ... ) -> tuple[Array, Array]: ... -def triu_indices_from(arr: ArrayLike, k: int = ...) -> tuple[Array, Array]: ... +def triu_indices_from(arr: ArrayLike | SupportsShape, k: int = ...) -> tuple[Array, Array]: ... def true_divide(x: ArrayLike, y: ArrayLike, /) -> Array: ... def trunc(x: ArrayLike, /) -> Array: ... uint: Any uint16: Any +uint2: Any uint32: Any uint4: Any uint64: Any @@ -967,7 +987,7 @@ class _UniqueInverseResult(NamedTuple): def unique(ar: ArrayLike, return_index: builtins.bool = ..., return_inverse: builtins.bool = ..., return_counts: builtins.bool = ..., axis: int | None = ..., *, equal_nan: builtins.bool = ..., size: int | None = ..., - fill_value: ArrayLike | None = ... + fill_value: ArrayLike | None = ..., sorted: bool = ..., ): ... def unique_all(x: ArrayLike, /, *, size: int | None = ..., fill_value: ArrayLike | None = ...) -> _UniqueAllResult: ... @@ -991,7 +1011,7 @@ def unwrap(p: ArrayLike, discont: ArrayLike | None = ..., def vander( x: ArrayLike, N: int | None = ..., increasing: builtins.bool = ... ) -> Array: ... -def var(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., +def var(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike | None = ..., out: None = ..., ddof: int = ..., keepdims: builtins.bool = ..., *, where: ArrayLike | None = ..., correction: int | float | None = ...) -> Array: ... def vdot( @@ -1026,11 +1046,13 @@ def where(condition: ArrayLike, x: ArrayLike | None = ..., fill_value: None | ArrayLike | tuple[ArrayLike, ...] = ... ) -> Array | tuple[Array, ...]: ... -def zeros(shape: Any, dtype: DTypeLike | None = ..., - device: _Device | _Sharding | None = ...) -> Array: ... +def zeros(shape: Any, dtype: DTypeLike | None = ..., *, + device: _Device | _Sharding | None = ..., + out_sharding: NamedSharding | P | None = ...) -> Array: ... def zeros_like(a: ArrayLike | DuckTypedArray, dtype: DTypeLike | None = ..., shape: Any = ..., *, - device: _Device | _Sharding | None = ...) -> Array: ... + device: _Device | _Sharding | None = ..., + out_sharding: NamedSharding | P | None = ...) -> Array: ... def vectorize(pyfunc, *, excluded = ..., signature = ...) -> Callable: ... diff --git a/jax/profiler.py b/jax/profiler.py index 77157dc02a13..86711a815ef8 100644 --- a/jax/profiler.py +++ b/jax/profiler.py @@ -14,16 +14,19 @@ # Note: import as is required for names to be exported. # See PEP 484 & https://github.com/jax-ml/jax/issues/7570 - from jax._src.profiler import ( - StepTraceAnnotation as StepTraceAnnotation, - TraceAnnotation as TraceAnnotation, - device_memory_profile as device_memory_profile, - save_device_memory_profile as save_device_memory_profile, - start_server as start_server, - stop_server as stop_server, - start_trace as start_trace, - stop_trace as stop_trace, - trace as trace, - annotate_function as annotate_function, + ProfileData as ProfileData, + ProfileEvent as ProfileEvent, + ProfileOptions as ProfileOptions, + ProfilePlane as ProfilePlane, + StepTraceAnnotation as StepTraceAnnotation, + TraceAnnotation as TraceAnnotation, + annotate_function as annotate_function, + device_memory_profile as device_memory_profile, + save_device_memory_profile as save_device_memory_profile, + start_server as start_server, + start_trace as start_trace, + stop_server as stop_server, + stop_trace as stop_trace, + trace as trace, ) diff --git a/jax/random.py b/jax/random.py index 9db584895cf1..42b6df56e6d0 100644 --- a/jax/random.py +++ b/jax/random.py @@ -92,7 +92,7 @@ To learn more about this upgrade, and the design of key types, see `JEP 9263 - `_. + `_. Advanced -------- @@ -166,10 +166,11 @@ In order for ``jax.jit`` to efficiently auto-partition functions that generate sharded random number arrays (or key arrays), all PRNG -implementations require extra flags: +implementations depend on extra flags: -- For ``"threefry2x32"``, and ``"rbg"`` key derivation, set - ``jax_threefry_partitionable=True``. +- For ``"threefry2x32"``, and ``"rbg"`` key derivation, have + ``jax_threefry_partitionable=True``. As of JAX v.0.5.0, this is the + default. - For ``"unsafe_rbg"``, and ``"rbg"`` random generation", set the XLA flag ``--xla_tpu_spmd_rng_bit_generator_unsafe=1``. @@ -178,7 +179,7 @@ ``XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1``. For more about ``jax_threefry_partitionable``, see -https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers +https://github.com/jax-ml/jax/discussions/18480 **Summary:** @@ -195,7 +196,7 @@ exact ``jax.vmap`` over keys ✅ ✅ ================================= ======== ========= === ========== ===== ============ -(*): with ``jax_threefry_partitionable=1`` set +(*): with ``jax_threefry_partitionable=1`` set (default as of JAX v0.5.0) (**): with ``XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1`` set """ diff --git a/jax/ref.py b/jax/ref.py new file mode 100644 index 000000000000..f51c1012eb93 --- /dev/null +++ b/jax/ref.py @@ -0,0 +1,44 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__all__ = [ + 'AbstractRef', 'Ref', 'addupdate', 'freeze', 'get', 'new_ref', 'set', 'swap' +] + +from jax._src.core import Ref, freeze +from jax._src.ref import new_ref +from jax._src.state.types import AbstractRef +from jax._src.state.primitives import ( + ref_get as get, + ref_set as set, + ref_swap as swap, + ref_addupdate as addupdate, +) + + +_deprecations = { + # Remove in v0.10.0 + "array_ref": ( + "jax.array_ref was removed in JAX v0.9.0; use jax.new_ref instead.", + None + ), + "ArrayRef": ( + "jax.ArrayRef was removed in JAX v0.9.0; use jax.Ref instead.", + None + ), +} + +from jax._src.deprecations import deprecation_getattr as _deprecation_getattr +__getattr__ = _deprecation_getattr(__name__, _deprecations) +del _deprecation_getattr diff --git a/jax/scipy/linalg.py b/jax/scipy/linalg.py index 64bc0544000b..dd12456dbd15 100644 --- a/jax/scipy/linalg.py +++ b/jax/scipy/linalg.py @@ -31,12 +31,14 @@ lu as lu, lu_factor as lu_factor, lu_solve as lu_solve, + pascal as pascal, polar as polar, qr as qr, rsf2csf as rsf2csf, schur as schur, sqrtm as sqrtm, solve as solve, + solve_sylvester as solve_sylvester, solve_triangular as solve_triangular, svd as svd, toeplitz as toeplitz, diff --git a/jax/scipy/special.py b/jax/scipy/special.py index 2ffc65a1abe1..dbe77ad276bc 100644 --- a/jax/scipy/special.py +++ b/jax/scipy/special.py @@ -37,6 +37,7 @@ gammaln as gammaln, gammasgn as gammasgn, hyp1f1 as hyp1f1, + hyp2f1 as hyp2f1, i0 as i0, i0e as i0e, i1 as i1, @@ -54,9 +55,10 @@ poch as poch, polygamma as polygamma, rel_entr as rel_entr, + sici as sici, softmax as softmax, spence as spence, - sph_harm as sph_harm, + sph_harm as _deprecated_sph_harm, sph_harm_y as sph_harm_y, xlog1py as xlog1py, xlogy as xlogy, @@ -77,12 +79,18 @@ "jax.scipy.special.lpmn_values is deprecated; no replacement is planned.", _deprecated_lpmn_values, ), + # Added Jul 7 2025 + "sph_harm": ( + "jax.scipy.special.sph_harm is deprecated; use sph_harm_y instead.", + _deprecated_sph_harm + ) } import typing as _typing if _typing.TYPE_CHECKING: lpmn = _deprecated_lpmn lpmn_values = _deprecated_lpmn_values + sph_harm = _deprecated_sph_harm else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr __getattr__ = _deprecation_getattr(__name__, _deprecations) diff --git a/jax/scipy/stats/__init__.py b/jax/scipy/stats/__init__.py index 7719945f23df..842113192715 100644 --- a/jax/scipy/stats/__init__.py +++ b/jax/scipy/stats/__init__.py @@ -41,3 +41,5 @@ from jax._src.scipy.stats._core import mode as mode, rankdata as rankdata, sem as sem from jax.scipy.stats import vonmises as vonmises from jax.scipy.stats import wrapcauchy as wrapcauchy +from jax.scipy.stats import gumbel_r as gumbel_r +from jax.scipy.stats import gumbel_l as gumbel_l diff --git a/jax/util.py b/jax/scipy/stats/gumbel_l.py similarity index 59% rename from jax/util.py rename to jax/scipy/stats/gumbel_l.py index 8071f77dffe2..dd5764487a1f 100644 --- a/jax/util.py +++ b/jax/scipy/stats/gumbel_l.py @@ -1,4 +1,4 @@ -# Copyright 2018 The JAX Authors. +# Copyright 2025 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,19 +15,12 @@ # Note: import as is required for names to be exported. # See PEP 484 & https://github.com/jax-ml/jax/issues/7570 -from jax._src.util import ( - HashableFunction as HashableFunction, - as_hashable_function as as_hashable_function, - cache as cache, - safe_map as safe_map, - safe_zip as safe_zip, - split_dict as split_dict, - split_list as split_list, - split_list_checked as split_list_checked, - split_merge as split_merge, - subvals as subvals, - toposort as toposort, - unzip2 as unzip2, - wrap_name as wrap_name, - wraps as wraps, +from jax._src.scipy.stats.gumbel_l import ( + logpdf as logpdf, + pdf as pdf, + logcdf as logcdf, + cdf as cdf, + ppf as ppf, + sf as sf, + logsf as logsf ) diff --git a/jax/scipy/stats/gumbel_r.py b/jax/scipy/stats/gumbel_r.py new file mode 100644 index 000000000000..8ac1fe241460 --- /dev/null +++ b/jax/scipy/stats/gumbel_r.py @@ -0,0 +1,26 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Note: import as is required for names to be exported. +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 + +from jax._src.scipy.stats.gumbel_r import ( + logpdf as logpdf, + pdf as pdf, + logcdf as logcdf, + cdf as cdf, + ppf as ppf, + sf as sf, + logsf as logsf +) diff --git a/jax/scipy/stats/pareto.py b/jax/scipy/stats/pareto.py index 5e46fd5d0bc7..fb3a2ac698ef 100644 --- a/jax/scipy/stats/pareto.py +++ b/jax/scipy/stats/pareto.py @@ -16,6 +16,11 @@ # See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.pareto import ( + logcdf as logcdf, logpdf as logpdf, + logsf as logsf, + cdf as cdf, pdf as pdf, + ppf as ppf, + sf as sf, ) diff --git a/jax/scipy/stats/poisson.py b/jax/scipy/stats/poisson.py index 5fcde905f89b..ac7cfa141063 100644 --- a/jax/scipy/stats/poisson.py +++ b/jax/scipy/stats/poisson.py @@ -19,4 +19,5 @@ logpmf as logpmf, pmf as pmf, cdf as cdf, + entropy as entropy ) diff --git a/jax/sharding.py b/jax/sharding.py index 55ff0f6aea0b..c592abec393f 100644 --- a/jax/sharding.py +++ b/jax/sharding.py @@ -19,33 +19,42 @@ from jax._src.sharding_impls import ( NamedSharding as NamedSharding, SingleDeviceSharding as SingleDeviceSharding, - PmapSharding as PmapSharding, - GSPMDSharding as GSPMDSharding, - PositionalSharding as PositionalSharding, - use_mesh as use_mesh, + PmapSharding as _deprecated_PmapSharding, set_mesh as set_mesh, + get_mesh as get_mesh, ) from jax._src.partition_spec import ( PartitionSpec as PartitionSpec, ) from jax._src.mesh import ( Mesh as Mesh, + AbstractDevice as AbstractDevice, AbstractMesh as AbstractMesh, AxisType as AxisType, get_abstract_mesh as get_abstract_mesh, + use_abstract_mesh as use_abstract_mesh, +) + +from jax._src.pjit import ( + reshard as reshard, + auto_axes as auto_axes, + explicit_axes as explicit_axes, ) _deprecations = { - # Finalized 2024-10-01; remove after 2025-01-01. - "XLACompatibleSharding": ( - ( - "jax.sharding.XLACompatibleSharding was removed in JAX v0.4.34. " - "Use jax.sharding.Sharding instead." - ), - None, - ) + # Added for v0.8.1 + "PmapSharding": ( + "jax.sharding.PmapSharding is deprecated; use jax.sharding.NamedSharding instead.", + _deprecated_PmapSharding + ), } -from jax._src.deprecations import deprecation_getattr as _deprecation_getattr -__getattr__ = _deprecation_getattr(__name__, _deprecations) -del _deprecation_getattr +import typing as _typing +if _typing.TYPE_CHECKING: + PmapSharding = _deprecated_PmapSharding +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del _typing +del _deprecated_PmapSharding diff --git a/jax/stages.py b/jax/stages.py index 3e7e461c385b..0c47ce41fe20 100644 --- a/jax/stages.py +++ b/jax/stages.py @@ -18,7 +18,7 @@ lowering and compilation *ahead of time*. This module defines types that represent the stages of this process. -For more, see the `AOT walkthrough `_. +For more, see the `AOT walkthrough `_. """ # Note: import as is required for names to be exported. @@ -30,6 +30,5 @@ Lowered as Lowered, Wrapped as Wrapped, ArgInfo as ArgInfo, - OutInfo as OutInfo, Traced as Traced, ) diff --git a/jax/tools/build_defs.bzl b/jax/tools/build_defs.bzl index 06f5e69833c5..5c31a6cc53bc 100644 --- a/jax/tools/build_defs.bzl +++ b/jax/tools/build_defs.bzl @@ -15,21 +15,7 @@ """JAX tools.""" load("@rules_python//python:py_binary.bzl", "py_binary") - -def _shell_quote(s): - """Copy of bazel-skylib's shell.quote. - - Quotes the given string for use in a shell command. - - This function quotes the given string (in case it contains spaces or other - shell metacharacters.) - - Args: - s: The string to quote. - Returns: - A quoted version of the string that can be passed to a shell command. - """ - return "'" + s.replace("'", "'\\''") + "'" +load("@bazel_skylib//lib:shell.bzl", "shell") def jax_to_hlo(name, deps, fn, input_shapes, constants = None): jax_to_ir(name, deps, fn, input_shapes, constants = constants, format = "HLO") @@ -182,9 +168,9 @@ EOF """.format( name = name, fn = fn, - input_shapes = _shell_quote(str(input_shapes)), - constants = _shell_quote(str(constants)), + input_shapes = shell.quote(str(input_shapes)), + constants = shell.quote(str(constants)), runner = runner, - format = _shell_quote(format), + format = shell.quote(format), ), ) diff --git a/jax/tools/jax_to_ir.py b/jax/tools/jax_to_ir.py index 904ce509a87e..47b85382f8bf 100644 --- a/jax/tools/jax_to_ir.py +++ b/jax/tools/jax_to_ir.py @@ -240,16 +240,12 @@ def parse_shape_str(s): _DT = { 'pred': jnp.bool_, - 'u4': jnp.uint4, 'u8': jnp.uint8, 'u16': jnp.uint16, 'u32': jnp.uint32, 'u64': jnp.uint64, - 's4': jnp.int4, 's8': jnp.int8, 's16': jnp.int16, 's32': jnp.int32, 's64': jnp.int64, + 'u2': jnp.uint2, 'u4': jnp.uint4, 'u8': jnp.uint8, 'u16': jnp.uint16, 'u32': jnp.uint32, 'u64': jnp.uint64, + 's2': jnp.int2, 's4': jnp.int4, 's8': jnp.int8, 's16': jnp.int16, 's32': jnp.int32, 's64': jnp.int64, 'bf16': jnp.bfloat16, 'f16': jnp.float16, 'f32': jnp.float32, 'f64': jnp.float64, 'c64': jnp.complex64, 'c128': jnp.complex128 } -if hasattr(jnp, 'int2'): - _DT['s2'] = jnp.int2 -if hasattr(jnp, 'uint2'): - _DT['u2'] = jnp.uint2 _SHAPE_RE = re.compile(f"^({'|'.join(_DT)})\\[\\s*(\\d*[\\s*,\\d+]*)\\s*\\]$") diff --git a/jax/tools/pgo_nsys_converter.py b/jax/tools/pgo_nsys_converter.py index 10209c9a85ba..d34f6ec5a12d 100644 --- a/jax/tools/pgo_nsys_converter.py +++ b/jax/tools/pgo_nsys_converter.py @@ -54,7 +54,7 @@ proc.wait() thunk_re = re.compile("hlo_op=(.*)#") - cost_dictionary: dict[str, list] = dict() + cost_dictionary: dict[str, list] = {} with open(f"{args.pgle_output_path}", 'w', newline='') as protofile: with open(f"{pgle_folder}{pgle_filename}.pbtxt_{report_name}.csv", newline='') as csvfile: reader = csv.DictReader(csvfile) @@ -64,7 +64,7 @@ m = thunk_re.search(name) if m is not None: if args.post_process: - cost_dictionary.setdefault(m.group(1), []).append((time_ns/1000.0)) + cost_dictionary.setdefault(m.group(1), []).append(time_ns/1000.0) else: protofile.write(f'costs {{ name: "{m.group(1)}" cost_us: {time_ns / 1000.0} }}\n') if args.post_process: diff --git a/jax/tree.py b/jax/tree.py index 270c34fe9647..854ee9c8ea7e 100644 --- a/jax/tree.py +++ b/jax/tree.py @@ -19,6 +19,7 @@ from jax._src.tree import ( all as all, + broadcast as broadcast, flatten_with_path as flatten_with_path, flatten as flatten, leaves_with_path as leaves_with_path, @@ -26,6 +27,7 @@ map_with_path as map_with_path, map as map, reduce as reduce, + reduce_associative as reduce_associative, structure as structure, transpose as transpose, unflatten as unflatten, diff --git a/jax/tree_util.py b/jax/tree_util.py index 956d79b9b4ef..486c1d017953 100644 --- a/jax/tree_util.py +++ b/jax/tree_util.py @@ -48,16 +48,16 @@ PyTreeDef as PyTreeDef, SequenceKey as SequenceKey, all_leaves as all_leaves, - build_tree as build_tree, default_registry as default_registry, keystr as keystr, + register_dataclass as register_dataclass, register_pytree_node_class as register_pytree_node_class, register_pytree_node as register_pytree_node, register_pytree_with_keys_class as register_pytree_with_keys_class, - register_dataclass as register_dataclass, register_pytree_with_keys as register_pytree_with_keys, register_static as register_static, tree_all as tree_all, + tree_broadcast as tree_broadcast, tree_flatten_with_path as tree_flatten_with_path, tree_flatten as tree_flatten, tree_leaves_with_path as tree_leaves_with_path, @@ -65,6 +65,7 @@ tree_map_with_path as tree_map_with_path, tree_map as tree_map, tree_reduce as tree_reduce, + tree_reduce_associative as tree_reduce_associative, tree_structure as tree_structure, tree_transpose as tree_transpose, tree_unflatten as tree_unflatten, diff --git a/jax/typing.py b/jax/typing.py index 89efa1f2ca66..0530c69e60ca 100644 --- a/jax/typing.py +++ b/jax/typing.py @@ -15,7 +15,7 @@ """ The JAX typing module is where JAX-specific static type annotations live. This submodule is a work in progress; to see the proposal behind the types exported -here, see https://jax.readthedocs.io/en/latest/jep/12049-type-annotations.html. +here, see https://docs.jax.dev/en/latest/jep/12049-type-annotations.html. The currently-available types are: @@ -67,7 +67,7 @@ def my_function(x: ArrayLike) -> Array: batch-wise transforms like :func:`~jax.vmap` or :func:`jax.pmap`. For more information on this, see `Non-array inputs NumPy vs JAX`_ -.. _Non-array inputs NumPy vs JAX: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#non-array-inputs-numpy-vs-jax +.. _Non-array inputs NumPy vs JAX: https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#non-array-inputs-numpy-vs-jax """ from jax._src.typing import ( ArrayLike as ArrayLike, diff --git a/jax/version.py b/jax/version.py index be20aca06358..6bf81f71614e 100644 --- a/jax/version.py +++ b/jax/version.py @@ -21,7 +21,7 @@ import pathlib import subprocess -_version = "0.5.3" +_version = "0.9.0" # The following line is overwritten by build scripts in distributions & # releases. Do not modify this manually, or jax/jaxlib build will fail. _release_version: str | None = None @@ -93,6 +93,12 @@ def _get_version_for_build() -> str: return _version_from_git_tree(_version) or _version_from_todays_date(_version) +def _is_prerelease() -> bool: + """Determine if this is a pre-release ("rc" wheels) build.""" + rc_version = os.getenv("WHEEL_VERSION_SUFFIX", "") + return True if rc_version.startswith("rc") else False + + def _write_version(fname: str) -> None: """Used by setup.py to write the specified version info into the source tree.""" release_version = _get_version_for_build() @@ -146,7 +152,7 @@ def make_release_tree(self, base_dir, files): __version__ = _get_version_string() -_minimum_jaxlib_version = "0.5.1" +_minimum_jaxlib_version = '0.9.0' def _version_as_tuple(version_str): return tuple(int(i) for i in version_str.split(".") if i.isdigit()) diff --git a/jax_plugins/cuda/BUILD.bazel b/jax_plugins/cuda/BUILD.bazel index 1f4e5a08dcb9..7070bf6bc495 100644 --- a/jax_plugins/cuda/BUILD.bazel +++ b/jax_plugins/cuda/BUILD.bazel @@ -12,15 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -licenses(["notice"]) - load( - "//jaxlib:jax.bzl", - "if_windows", - "py_library_providing_imports_info", - "pytype_library", + "//jaxlib:jax.bzl", + "py_library_providing_imports_info", + "pytype_library", ) +licenses(["notice"]) + package( default_applicable_licenses = [], default_visibility = ["//:__subpackages__"], @@ -34,46 +33,27 @@ exports_files([ "setup.py", ]) +cc_binary( + name = "pjrt_c_api_gpu_plugin.so", + linkopts = [ + "-Wl,--version-script,$(location :gpu_version_script.lds)", + "-Wl,--no-undefined", + ], + linkshared = True, + deps = [ + ":gpu_version_script.lds", + "//jaxlib/mosaic/gpu:custom_call", + "@xla//xla/pjrt/c:pjrt_c_api_gpu", + "@xla//xla/service:gpu_plugin", + "@xla//xla/stream_executor:cuda_platform", + ], +) + py_library_providing_imports_info( name = "cuda_plugin", srcs = [ "__init__.py", ], - data = if_windows( - ["@xla//xla/pjrt/c/pjrt_c_api_gpu_plugin.pyd"], - ["@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so"], - ), + data = [":pjrt_c_api_gpu_plugin.so"], lib_rule = pytype_library, ) - -config_setting( - name = "disable_jaxlib_for_cpu_build", - flag_values = { - "//jax:build_jaxlib": "false", - "@local_config_cuda//:enable_cuda": "False", - }, -) - -config_setting( - name = "disable_jaxlib_for_cuda12_build", - flag_values = { - "//jax:build_jaxlib": "false", - "@local_config_cuda//:enable_cuda": "True", - }, -) - -config_setting( - name = "enable_py_import_for_cpu_build", - flag_values = { - "//jax:build_jaxlib": "wheel", - "@local_config_cuda//:enable_cuda": "False", - }, -) - -config_setting( - name = "enable_py_import_for_cuda12_build", - flag_values = { - "//jax:build_jaxlib": "wheel", - "@local_config_cuda//:enable_cuda": "True", - }, -) diff --git a/jax_plugins/cuda/__init__.py b/jax_plugins/cuda/__init__.py index f6540e986024..213317cbdabc 100644 --- a/jax_plugins/cuda/__init__.py +++ b/jax_plugins/cuda/__init__.py @@ -12,27 +12,41 @@ # See the License for the specific language governing permissions and # limitations under the License. +import ctypes import functools import importlib import logging import os import pathlib +import traceback +from typing import Any from jax._src.lib import triton from jax._src.lib import xla_client import jax._src.xla_bridge as xb -# cuda_plugin_extension locates inside jaxlib. `jaxlib` is for testing without -# preinstalled jax cuda plugin packages. -for pkg_name in ['jax_cuda12_plugin', 'jaxlib.cuda']: - try: - cuda_plugin_extension = importlib.import_module( - f'{pkg_name}.cuda_plugin_extension' - ) - except ImportError: - cuda_plugin_extension = None - else: - break +cuda_plugin_extension = None +cuda_versions = None + +def _import_extensions(): + global cuda_plugin_extension + global cuda_versions + + # cuda_plugin_extension locates inside jaxlib. `jaxlib` is for testing without + # preinstalled jax cuda plugin packages. + for pkg_name in ['jax_cuda13_plugin', 'jax_cuda12_plugin', 'jaxlib.cuda']: + try: + cuda_plugin_extension = importlib.import_module( + f'{pkg_name}.cuda_plugin_extension' + ) + cuda_versions = importlib.import_module( + f'{pkg_name}._versions' + ) + except ImportError: + cuda_plugin_extension = None + cuda_versions = None + else: + break logger = logging.getLogger(__name__) @@ -51,7 +65,7 @@ def _get_library_path(): runfiles_dir = os.getenv('RUNFILES_DIR', None) if runfiles_dir: local_path = os.path.join( - runfiles_dir, 'xla/xla/pjrt/c/pjrt_c_api_gpu_plugin.so' + runfiles_dir, '__main__/jax_plugins/cuda/pjrt_c_api_gpu_plugin.so' ) if os.path.exists(local_path): @@ -76,30 +90,271 @@ def _get_library_path(): return None +def _load(module, libraries): + try: + m = importlib.import_module(f"nvidia.{module}") + except ImportError: + m = None + + for lib in libraries: + excs = [] + if m is not None: + path = pathlib.Path(m.__path__[0]) / "lib" / lib + try: + ctypes.cdll.LoadLibrary(path) + continue + except OSError as e: + excs.append(e) + + # TODO(phawkins): check the non-Python path here and error if not found. + # # Try again, without the Python module path. + # try: + # ctypes.cdll.LoadLibrary(lib) + # continue + # except OSError as e: + # excs.append(e) + # + # raise ExceptionGroup(f"Unable to load CUDA library {lib}", excs) # noqa: F821 + + +def _load_nvidia_libraries(): + """Attempts to load NVIDIA's libraries. + + We prefer the Python packages, if present. If not, we fall back to loading + them from LD_LIBRARY_PATH. By loading the libraries here, later lookups will + find these copies.""" + _load("cuda_runtime", ["libcudart.so.12"]) + _load("cu13", ["libcudart.so.13"]) + # cuda_nvrtc isn't directly a dependency of JAX, but CUDNN appears to need it + # and at least in CUDA 12.9 has RUNPATHs misconfigured to refer to + # nvidia/nvrtc instead of nvidia/cuda_nvrtc. + _load("cuda_nvrtc", ["libnvrtc.so.12"]) + _load("cu13", ["libnvrtc.so.13"]) + _load("cublas", ["libcublas.so.12", "libcublasLt.so.12"]) + _load("cu13", ["libcublas.so.13", "libcublasLt.so.13"]) + _load("nccl", ["libnccl.so.2"]) + _load("cuda_cupti", ["libcupti.so.12"]) + _load("cu13", ["libcupti.so.13"]) + _load("cusparse", ["libcusparse.so.12"]) + _load("cu13", ["libcusparse.so.12"]) + _load("cusolver", ["libcusolver.so.11"]) + _load("cu13", ["libcusolver.so.12"]) + _load("cufft", ["libcufft.so.11"]) + _load("cu13", ["libcufft.so.12"]) + _load("nvshmem", ["libnvshmem_host.so.3"]) + _load("cudnn", ["libcudnn.so.9"]) + + +def _check_cuda_versions(raise_on_first_error: bool = False, + debug: bool = False): + assert cuda_versions is not None + results: list[dict[str, Any]] = [] + + def _make_msg(name: str, + runtime_version: int, + build_version: int, + min_supported: int, + debug_msg: bool = False): + if debug_msg: + return (f"Package: {name}\n" + f"Version JAX was built against: {build_version}\n" + f"Minimum supported: {min_supported}\n" + f"Installed version: {runtime_version}") + if min_supported: + req_str = (f"The local installation version must be no lower than " + f"{min_supported}.") + else: + req_str = ("The local installation must be the same version as " + "the version against which JAX was built.") + msg = (f"Outdated {name} installation found.\n" + f"Version JAX was built against: {build_version}\n" + f"Minimum supported: {min_supported}\n" + f"Installed version: {runtime_version}\n" + f"{req_str}") + return msg + + + def _version_check(name: str, + get_version, + get_build_version, + scale_for_comparison: int = 1, + min_supported_version: int = 0) -> int | None: + """Checks the runtime CUDA component version against the JAX one. + + Args: + name: Of the CUDA component. + get_version: A function to get the local runtime version of the component. + get_build_version: A function to get the build version of the component. + scale_for_comparison: For rounding down a version to ignore patch/minor. + min_supported_version: An absolute minimum version required. Must be + passed without rounding down. + + Returns: the runtime version, or None if the component is not found. + + Raises: + RuntimeError: If the component is not found, or is of unsupported version, + and if raising the error is not deferred till later. + """ + + build_version = get_build_version() + try: + version = get_version() + except Exception as e: + err_msg = f"Unable to load {name}. Is it installed?" + if raise_on_first_error: + raise RuntimeError(err_msg) from e + err_msg += f"\n{traceback.format_exc()}" + results.append({"name": name, "installed": False, "msg": err_msg}) + return + + if not min_supported_version: + min_supported_version = build_version // scale_for_comparison + passed = min_supported_version <= version + + if not passed or debug: + msg = _make_msg(name=name, + runtime_version=version, + build_version=build_version, + min_supported=min_supported_version, + debug_msg=passed) + if not passed and raise_on_first_error: + raise RuntimeError(msg) + else: + record = {"name": name, + "installed": True, + "msg": msg, + "passed": passed, + "build_version": build_version, + "version": version, + "minimum_supported": min_supported_version} + results.append(record) + return version + + _version_check("CUDA", cuda_versions.cuda_runtime_get_version, + cuda_versions.cuda_runtime_build_version, + scale_for_comparison=10, + min_supported_version=12010) + cudnn_version = _version_check( + "cuDNN", + cuda_versions.cudnn_get_version, + cuda_versions.cudnn_build_version, + # NVIDIA promise both backwards and forwards compatibility for cuDNN patch + # versions: + # https://docs.nvidia.com/deeplearning/cudnn/backend/latest/developer/forward-compatibility.html#cudnn-api-compatibility + scale_for_comparison=100, + ) + _version_check("cuFFT", cuda_versions.cufft_get_version, + cuda_versions.cufft_build_version, + # Ignore patch versions. + scale_for_comparison=100) + # TODO(phawkins): for some reason this check fails with a cusolver internal + # error when fetching the version. This may be a path error from our stubs. + # Figure out what's happening here and re-enable. + # _version_check("cuSOLVER", cuda_versions.cusolver_get_version, + # cuda_versions.cusolver_build_version, + # # Ignore patch versions. + # scale_for_comparison=100, + # min_supported_version=11400) + _version_check("cuPTI", cuda_versions.cupti_get_version, + cuda_versions.cupti_build_version, + min_supported_version=18) + cublas_version = _version_check("cuBLAS", cuda_versions.cublas_get_version, + cuda_versions.cublas_build_version, + # Ignore patch versions. + scale_for_comparison=100, + min_supported_version=120100) + _version_check("cuSPARSE", cuda_versions.cusparse_get_version, + cuda_versions.cusparse_build_version, + # Ignore patch versions. + scale_for_comparison=100, + min_supported_version=12100) + + # https://docs.nvidia.com/deeplearning/cudnn/backend/latest/release-notes.html#cudnn-9-10-1 + if (cudnn_version is not None and cudnn_version == 91000 + and cuda_versions.cudnn_build_version() != 91000): + msg = ("cuDNN 9.10.0 had a binary backward-compatibility issue due to reordered enum " + f"values affecting block-scale datatypes. Found runtime version {cudnn_version} " + f"and build version {cuda_versions.cudnn_build_version()}. Please upgrade to " + "9.10.1 or above.") + if raise_on_first_error: + raise RuntimeError(msg) + else: + results.append({"installed": True, "msg": msg, "passed": False}) + # xb.local_device_count() cannot safely be called at this point + if xb.CUDA_VISIBLE_DEVICES.value == "all": + local_device_count = cuda_versions.cuda_device_count() + else: + local_device_count = len(xb.CUDA_VISIBLE_DEVICES.value.split(",")) + # https://docs.nvidia.com/deeplearning/cudnn/backend/latest/release-notes.html#cudnn-9-10-0 + if (cudnn_version is not None and cudnn_version < 91001 + and cublas_version is not None and cublas_version >= 120900 + and local_device_count > 1): + msg = (f"cuDNN < 9.10.0 ({cudnn_version} found) had an issue that caused some multi-GPU " + "matmuls, in which the same finalized execution plan is used across different " + f"GPUs, to be functionally incorrect when run with cublasLt >= 12.9 ({cublas_version} " + "found). Please upgrade to 9.10.1 or above.") + if raise_on_first_error: + raise RuntimeError(msg) + else: + results.append({"installed": True, "msg": msg, "passed": False}) + + errors = [] + debug_results = [] + for result in results: + message: str = result['msg'] + if not result['installed'] or not result['passed']: + errors.append(message) + else: + debug_results.append(message) + + join_str = f'\n{"-" * 50}\n' + if debug_results: + print(f'CUDA components status (debug):\n' + f'{join_str.join(debug_results)}') + if errors: + raise RuntimeError(f'Unable to use CUDA because of the ' + f'following issues with CUDA components:\n' + f'{join_str.join(errors)}') + + def initialize(): + _load_nvidia_libraries() + _import_extensions() path = _get_library_path() if path is None: return + if not os.getenv("JAX_SKIP_CUDA_CONSTRAINTS_CHECK"): + _check_cuda_versions(raise_on_first_error=True) + else: + logger.debug('Skipped CUDA versions constraints check due to the ' + 'JAX_SKIP_CUDA_CONSTRAINTS_CHECK env var being set.') + options = xla_client.generate_pjrt_gpu_plugin_options() c_api = xb.register_plugin( 'cuda', priority=500, library_path=str(path), options=options ) if cuda_plugin_extension: - xla_client.register_custom_call_handler( + xla_client.register_custom_type_handler( "CUDA", functools.partial( - cuda_plugin_extension.register_custom_call_target, c_api + cuda_plugin_extension.register_custom_type, c_api ), ) - for _name, _value in cuda_plugin_extension.registrations().items(): - xla_client.register_custom_call_target(_name, _value, platform="CUDA") - xla_client.register_custom_type_id_handler( + xla_client.register_custom_call_handler( "CUDA", functools.partial( - cuda_plugin_extension.register_custom_type_id, c_api + cuda_plugin_extension.register_custom_call_target, c_api ), ) + for _name, _value in cuda_plugin_extension.ffi_types().items(): + xla_client.register_custom_type( + _name, _value, platform='CUDA' + ) + for _name, _value in cuda_plugin_extension.ffi_handlers().items(): + xla_client.register_custom_call_target( + _name, _value, platform='CUDA', api_version=1 + ) triton.register_compilation_handler( "CUDA", functools.partial( diff --git a/jaxlib/tools/gpu_version_script.lds b/jax_plugins/cuda/gpu_version_script.lds similarity index 100% rename from jaxlib/tools/gpu_version_script.lds rename to jax_plugins/cuda/gpu_version_script.lds diff --git a/jax_plugins/cuda/plugin_setup.py b/jax_plugins/cuda/plugin_setup.py index ce31684de46f..d460e89f4a24 100644 --- a/jax_plugins/cuda/plugin_setup.py +++ b/jax_plugins/cuda/plugin_setup.py @@ -21,6 +21,20 @@ cuda_version = 0 # placeholder project_name = f"jax-cuda{cuda_version}-plugin" package_name = f"jax_cuda{cuda_version}_plugin" +cuda_wheel_suffix = '' # placeholder + +nvidia_cublas_version = '' # placeholder +nvidia_cuda_cupti_version = '' # placeholder +nvidia_cuda_nvcc_version = '' # placeholder +nvidia_cuda_runtime_version = '' # placeholder +nvidia_cudnn_version = '' # placeholder +nvidia_cufft_version = '' # placeholder +nvidia_cusolver_version = '' # placeholder +nvidia_cusparse_version = '' # placeholder +nvidia_nccl_version = '' # placeholder +nvidia_nvjitlink_version = '' # placeholder +nvidia_cuda_nvrtc_version = '' # placeholder +nvidia_nvshmem_version = '' # placeholder def load_version_module(pkg_path): spec = importlib.util.spec_from_file_location( @@ -49,19 +63,19 @@ def has_ext_modules(self): author="JAX team", author_email="jax-dev@google.com", packages=[package_name], - python_requires=">=3.10", + python_requires=">=3.11", install_requires=[f"jax-cuda{cuda_version}-pjrt=={__version__}"], extras_require={ - 'with_cuda': [ - "nvidia-cublas-cu12>=12.1.3.1", - "nvidia-cuda-cupti-cu12>=12.1.105", - "nvidia-cuda-nvcc-cu12>=12.6.85", - "nvidia-cuda-runtime-cu12>=12.1.105", - "nvidia-cudnn-cu12>=9.1,<10.0", - "nvidia-cufft-cu12>=11.0.2.54", - "nvidia-cusolver-cu12>=11.4.5.107", - "nvidia-cusparse-cu12>=12.1.0.106", - "nvidia-nccl-cu12>=2.18.1", + 'with-cuda': [ + f"nvidia-cublas{cuda_wheel_suffix}{nvidia_cublas_version}", + f"nvidia-cuda-cupti{cuda_wheel_suffix}{nvidia_cuda_cupti_version}", + f"nvidia-cuda-nvcc{cuda_wheel_suffix}{nvidia_cuda_nvcc_version}", + f"nvidia-cuda-runtime{cuda_wheel_suffix}{nvidia_cuda_runtime_version}", + f"nvidia-cudnn-cu{cuda_version}{nvidia_cudnn_version}", + f"nvidia-cufft{cuda_wheel_suffix}{nvidia_cufft_version}", + f"nvidia-cusolver{cuda_wheel_suffix}{nvidia_cusolver_version}", + f"nvidia-cusparse{cuda_wheel_suffix}{nvidia_cusparse_version}", + f"nvidia-nccl-cu{cuda_version}{nvidia_nccl_version}", # nvjitlink is not a direct dependency of JAX, but it is a transitive # dependency via, for example, cuSOLVER. NVIDIA's cuSOLVER packages # do not have a version constraint on their dependencies, so the @@ -69,16 +83,23 @@ def has_ext_modules(self): # problems (https://github.com/jax-ml/jax/issues/18027#issuecomment-1756305196) # Until NVIDIA add version constraints, add a version constraint # here. - "nvidia-nvjitlink-cu12>=12.1.105", - ], + f"nvidia-nvjitlink{cuda_wheel_suffix}{nvidia_nvjitlink_version}", + # nvrtc is a transitive and undeclared dep of cudnn. + f"nvidia-cuda-nvrtc{cuda_wheel_suffix}{nvidia_cuda_nvrtc_version}", + # NVSHMEM is used by Mosaic GPU collectives and can be used by XLA to + # speed up collectives too. + f"nvidia-nvshmem-cu{cuda_version}{nvidia_nvshmem_version}", + ] + (["nvidia-nvvm"] if cuda_version == 13 else []), }, url="https://github.com/jax-ml/jax", license="Apache-2.0", classifiers=[ - "Development Status :: 3 - Alpha", - "Programming Language :: Python :: 3.10", + "Development Status :: 5 - Production/Stable", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", + "Programming Language :: Python :: Free Threading :: 3 - Stable", ], package_data={ package_name: [ diff --git a/jax_plugins/cuda/setup.py b/jax_plugins/cuda/setup.py index 1ce555978dac..b2c89285e7fd 100644 --- a/jax_plugins/cuda/setup.py +++ b/jax_plugins/cuda/setup.py @@ -51,8 +51,9 @@ def load_version_module(pkg_path): url="https://github.com/jax-ml/jax", license="Apache-2.0", classifiers=[ - "Development Status :: 3 - Alpha", + "Development Status :: 5 - Production/Stable", "Programming Language :: Python :: 3", + "Programming Language :: Python :: Free Threading :: 3 - Stable", ], package_data={ package_name: ["xla_cuda_plugin.so"], diff --git a/jax_plugins/rocm/BUILD.bazel b/jax_plugins/rocm/BUILD.bazel index 6e265bcd18cf..7ee0726e7960 100644 --- a/jax_plugins/rocm/BUILD.bazel +++ b/jax_plugins/rocm/BUILD.bazel @@ -16,7 +16,6 @@ licenses(["notice"]) load( "//jaxlib:jax.bzl", - "if_windows", "py_library_providing_imports_info", "pytype_library", ) @@ -34,14 +33,26 @@ exports_files([ "setup.py", ]) +cc_binary( + name = "pjrt_c_api_gpu_plugin.so", + linkopts = [ + "-Wl,--version-script,$(location :gpu_version_script.lds)", + "-Wl,--no-undefined", + ], + linkshared = True, + deps = [ + ":gpu_version_script.lds", + "@xla//xla/pjrt/c:pjrt_c_api_gpu", + "@xla//xla/service:gpu_plugin", + "@xla//xla/stream_executor:rocm_platform", + ], +) + py_library_providing_imports_info( name = "rocm_plugin", srcs = [ "__init__.py", ], - data = if_windows( - ["@xla//xla/pjrt/c/pjrt_c_api_gpu_plugin.pyd"], - ["@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so"], - ), + data = [":pjrt_c_api_gpu_plugin.so"], lib_rule = pytype_library, ) diff --git a/jax_plugins/rocm/__init__.py b/jax_plugins/rocm/__init__.py index c48a681bf337..a231cb20be82 100644 --- a/jax_plugins/rocm/__init__.py +++ b/jax_plugins/rocm/__init__.py @@ -51,7 +51,7 @@ def _get_library_path(): runfiles_dir = os.getenv('RUNFILES_DIR', None) if runfiles_dir: local_path = pathlib.Path( - os.path.join(runfiles_dir, 'xla/xla/pjrt/c/pjrt_c_api_gpu_plugin.so') + os.path.join(runfiles_dir, '__main__/jax_plugins/rocm/pjrt_c_api_gpu_plugin.so') ) if local_path.exists(): @@ -86,19 +86,25 @@ def initialize(): 'rocm', priority=500, library_path=str(path), options=options ) if rocm_plugin_extension: - xla_client.register_custom_call_handler( + xla_client.register_custom_type_handler( "ROCM", functools.partial( - rocm_plugin_extension.register_custom_call_target, c_api + rocm_plugin_extension.register_custom_type, c_api ), ) - for _name, _value in rocm_plugin_extension.registrations().items(): - xla_client.register_custom_call_target(_name, _value, platform="ROCM") - xla_client.register_custom_type_id_handler( + xla_client.register_custom_call_handler( "ROCM", functools.partial( - rocm_plugin_extension.register_custom_type_id, c_api + rocm_plugin_extension.register_custom_call_target, c_api ), ) + for _name, _value in rocm_plugin_extension.ffi_types().items(): + xla_client.register_custom_type( + _name, _value, platform='ROCM' + ) + for _name, _value in rocm_plugin_extension.ffi_handlers().items(): + xla_client.register_custom_call_target( + _name, _value, platform='ROCM', api_version=1 + ) else: logger.warning('rocm_plugin_extension is not found.') diff --git a/jax_plugins/rocm/gpu_version_script.lds b/jax_plugins/rocm/gpu_version_script.lds new file mode 100644 index 000000000000..cbac4549bde3 --- /dev/null +++ b/jax_plugins/rocm/gpu_version_script.lds @@ -0,0 +1,9 @@ +VERS_1.0 { + global: + extern "C" { + GetPjrtApi; + }; + + local: + *; +}; diff --git a/jax_plugins/rocm/plugin_setup.py b/jax_plugins/rocm/plugin_setup.py index d504d0a11666..f528820bae5c 100644 --- a/jax_plugins/rocm/plugin_setup.py +++ b/jax_plugins/rocm/plugin_setup.py @@ -54,16 +54,16 @@ def has_ext_modules(self): author="Ruturaj4", author_email="Ruturaj.Vaidya@amd.com", packages=[package_name], - python_requires=">=3.9", + python_requires=">=3.11", install_requires=[f"jax-rocm{rocm_version}-pjrt=={__version__}"], url="https://github.com/jax-ml/jax", license="Apache-2.0", classifiers=[ "Development Status :: 3 - Alpha", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", ], package_data={ package_name: [ diff --git a/jaxlib/BUILD b/jaxlib/BUILD index a35eabc9a505..a42c75f2f2c7 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -14,11 +14,24 @@ # JAX is Autograd and XLA +load("@rules_cc//cc:cc_library.bzl", "cc_library") load( "//jaxlib:jax.bzl", + "cc_proto_library", + "if_oss", + "jax_visibility", "nanobind_extension", - "py_library_providing_imports_info", + "proto_library", + "py_deps", + "py_strict_test", "pytype_library", + "pytype_strict_library", +) +load( + "//jaxlib:pywrap.bzl", + "nanobind_pywrap_extension", + "pywrap_binaries", + "pywrap_library", ) load("//jaxlib:symlink_files.bzl", "symlink_files") @@ -26,44 +39,38 @@ licenses(["notice"]) package( default_applicable_licenses = [], - default_visibility = ["//jax:internal"], + default_visibility = ["//visibility:public"], ) -# This makes xla_extension module accessible from jax._src.lib. -genrule( - name = "xla_extension_py", - outs = ["xla_extension.py"], - cmd = "echo 'from xla.xla.python.xla_extension import *\n' > $@", +package_group( + name = "xla_python", + includes = [ + "//jax:internal", + ], ) -py_library_providing_imports_info( +pytype_strict_library( name = "jaxlib", - srcs = [ - "gpu_common_utils.py", - "gpu_linalg.py", - "gpu_prng.py", - "gpu_rnn.py", - "gpu_solver.py", - "gpu_sparse.py", - "gpu_triton.py", - "hlo_helpers.py", - "init.py", - "lapack.py", - "plugin_support.py", - ":version", - ":xla_client", - ":xla_extension_py", - ], data = [":ffi_headers"], - lib_rule = pytype_library, deps = [ + ":_ifrt_proxy", + ":_jax", + ":_pathways", + ":_pretty_printer", + ":_sdy_mpmd", ":cpu_feature_guard", + ":jax", + ":jaxlib_files", ":utils", + ":weakref_lru_cache", + ":xla_client", "//jaxlib/cpu:_lapack", + "//jaxlib/cpu:_sparse", "//jaxlib/mlir", "//jaxlib/mlir:arithmetic_dialect", "//jaxlib/mlir:builtin_dialect", "//jaxlib/mlir:chlo_dialect", + "//jaxlib/mlir:control_flow_dialect", "//jaxlib/mlir:func_dialect", "//jaxlib/mlir:gpu_dialect", "//jaxlib/mlir:ir", @@ -71,6 +78,7 @@ py_library_providing_imports_info( "//jaxlib/mlir:math_dialect", "//jaxlib/mlir:memref_dialect", "//jaxlib/mlir:mhlo_dialect", + "//jaxlib/mlir:mpmd_dialect", "//jaxlib/mlir:nvgpu_dialect", "//jaxlib/mlir:nvvm_dialect", "//jaxlib/mlir:pass_manager", @@ -79,22 +87,45 @@ py_library_providing_imports_info( "//jaxlib/mlir:sparse_tensor_dialect", "//jaxlib/mlir:stablehlo_dialect", "//jaxlib/mlir:vector_dialect", + "//jaxlib/mlir/_mlir_libs:_jax_mlir_ext", "//jaxlib/mosaic", + "//jaxlib/mosaic/python:gpu_dialect", + "//jaxlib/mosaic/python:tpu_dialect", "//jaxlib/triton", - "@xla//xla/python:xla_extension", + "@xla//xla/python:_profile_data", + "@xla//xla/python:_profiler", ], ) -symlink_files( - name = "version", - srcs = ["//jax:version.py"], - dst = ".", - flatten = True, +pytype_library( + name = "jaxlib_files", + srcs = [ + "cpu_sparse.py", + "gpu_common_utils.py", + "gpu_linalg.py", + "gpu_prng.py", + "gpu_rnn.py", + "gpu_solver.py", + "gpu_sparse.py", + "gpu_triton.py", + "init.py", + "lapack.py", + "plugin_support.py", + "xla_client.py", + ":version", + ], + deps = [ + ":_jax", + "//jaxlib/cpu:_lapack", + "//jaxlib/cpu:_sparse", + "//jaxlib/mlir:ir", + "//jaxlib/mlir:stablehlo_dialect", + ], ) symlink_files( - name = "xla_client", - srcs = ["@xla//xla/python:xla_client"], + name = "version", + srcs = ["//jax:version.py"], dst = ".", flatten = True, ) @@ -111,6 +142,52 @@ exports_files([ "setup.py", ]) +pywrap_library( + name = "jax", + common_lib_def_files_or_filters = { + "jaxlib/jax_common": "jax_common.json", + }, + common_lib_version_scripts = { + "jaxlib/jax_common": select({ + "@bazel_tools//src/conditions:windows": None, + "@bazel_tools//src/conditions:darwin": "libjax_common_darwin.lds", + "//conditions:default": "libjax_common.lds", + }), + }, + deps = [ + ":_ifrt_proxy", + ":_jax", + ":_pathways", + ":_pretty_printer", + ":_sdy_mpmd", + ":utils", + ":weakref_lru_cache", + "//jaxlib/mlir/_mlir_libs:_chlo", + "//jaxlib/mlir/_mlir_libs:_jax_mlir_ext", + "//jaxlib/mlir/_mlir_libs:_mlir", + "//jaxlib/mlir/_mlir_libs:_mlirDialectsGPU", + "//jaxlib/mlir/_mlir_libs:_mlirDialectsLLVM", + "//jaxlib/mlir/_mlir_libs:_mlirDialectsNVGPU", + "//jaxlib/mlir/_mlir_libs:_mlirDialectsSparseTensor", + "//jaxlib/mlir/_mlir_libs:_mlirGPUPasses", + "//jaxlib/mlir/_mlir_libs:_mlirHlo", + "//jaxlib/mlir/_mlir_libs:_mlirSparseTensorPasses", + "//jaxlib/mlir/_mlir_libs:_mosaic_gpu_ext", + "//jaxlib/mlir/_mlir_libs:_sdy", + "//jaxlib/mlir/_mlir_libs:_sdyMpmd", + "//jaxlib/mlir/_mlir_libs:_stablehlo", + "//jaxlib/mlir/_mlir_libs:_tpu_ext", + "//jaxlib/mlir/_mlir_libs:_triton_ext", + "@xla//xla/python:_profile_data", + "@xla//xla/python:_profiler", + ], +) + +pywrap_binaries( + name = "jaxlib_binaries", + dep = ":jax", +) + cc_library( name = "absl_status_casters", hdrs = ["absl_status_casters.h"], @@ -128,6 +205,7 @@ cc_library( cc_library( name = "ffi_helpers", hdrs = ["ffi_helpers.h"], + # compatible with libtpu features = ["-use_header_modules"], deps = [ "@com_google_absl//absl/algorithm:container", @@ -167,58 +245,1204 @@ cc_library( features = ["-use_header_modules"], deps = [ "@com_google_absl//absl/base", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", ], ) +# This isn't a CPU kernel. This exists to catch cases where jaxlib is built for the wrong +# target architecture. +nanobind_extension( + name = "cpu_feature_guard", + srcs = ["cpu_feature_guard.c"], + module_name = "cpu_feature_guard", + deps = [ + "@xla//third_party/python_runtime:headers", + ], +) + +nanobind_pywrap_extension( + name = "_pretty_printer", + srcs = ["_pretty_printer.cc"], + deps = [ + ":nb_class_ptr", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@nanobind", + ], +) + +nanobind_pywrap_extension( + name = "weakref_lru_cache", + srcs = ["weakref_lru_cache.cc"], + pytype_srcs = ["weakref_lru_cache.pyi"], + deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@nanobind", + "@xla//third_party/python_runtime:headers", + "@xla//xla/pjrt:lru_cache", + "@xla//xla/tsl/platform:logging", + ], +) + +py_strict_test( + name = "weakref_lru_cache_test", + srcs = ["weakref_lru_cache_test.py"], + deps = [ + ":weakref_lru_cache", + ] + py_deps([ + "absl/flags", + "absl/logging", + "absl/testing", + ]), +) + +nanobind_pywrap_extension( + name = "utils", + srcs = ["utils.cc"], + deps = [ + "@com_google_absl//absl/base:log_severity", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/debugging:failure_signal_handler", + "@com_google_absl//absl/log:globals", + "@com_google_absl//absl/synchronization", + "@nanobind", + "@tsl//tsl/platform", + "@xla//third_party/python_runtime:headers", + ], +) + +nanobind_pywrap_extension( + name = "_jax", + srcs = ["jax.cc"], + additional_stubgen_deps = [ + "//third_party/py/numpy", + "//jaxlib/mlir:ir", + ], + enable_stub_generation = True, + pytype_deps = py_deps(["numpy"]), + pytype_srcs = glob(["_jax/*.pyi"]), + stub_replacement_patterns = { + "jax.jaxlib._jax.Array$": "Array: Any", + "jax.jaxlib._jax.ArrayImpl$": "ArrayImpl: Any", + }, + visibility = jax_visibility("jaxlib/_jax"), + deps = [ + ":call_location", + ":config", + ":custom_call_sharding", + ":dlpack", + ":ffi", + ":guard_lib", + ":jax_jit", + ":mlir", + ":nb_class_ptr", + ":pjit", + ":pmap_lib", + ":pprof_profile_builder", + ":py_client", + ":python_ref_manager", + ":pytree", + ":traceback", + ":util", + ":xla_compiler", + "@com_google_absl//absl/base", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log:initialize", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@nanobind", + "@tsl//tsl/platform", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:literal", + "@xla//xla:shape_util", + "@xla//xla:types", + "@xla//xla:util", + "@xla//xla/backends/cpu/collectives:cpu_collectives", + "@xla//xla/ffi:ffi_api", + "@xla//xla/hlo/builder/lib:approx_topk_shape", + "@xla//xla/pjrt:exceptions", + "@xla//xla/pjrt:mlir_to_hlo", + "@xla//xla/pjrt:pjrt_api", + "@xla//xla/pjrt:pjrt_client", + "@xla//xla/pjrt:pjrt_common", + "@xla//xla/pjrt:pjrt_compiler", + "@xla//xla/pjrt:pjrt_executable", + "@xla//xla/pjrt:pjrt_layout", + "@xla//xla/pjrt:status_casters", + "@xla//xla/pjrt/c:pjrt_c_api_hdrs", + "@xla//xla/pjrt/c:pjrt_c_api_raw_buffer_external", + "@xla//xla/pjrt/c_api_client:pjrt_c_api_client", + "@xla//xla/pjrt/cpu:cpu_client", + "@xla//xla/pjrt/distributed", + "@xla//xla/pjrt/distributed:client", + "@xla//xla/pjrt/distributed:key_value_store_interface", + "@xla//xla/pjrt/distributed:protocol_proto_cc", + "@xla//xla/pjrt/distributed:service", + "@xla//xla/pjrt/distributed/preemption:preemption_sync_manager", + "@xla//xla/pjrt/plugin/xla_cpu:cpu_client_options", + "@xla//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client", + "@xla//xla/python:logging", + "@xla//xla/python:nb_absl_flat_hash_map", + "@xla//xla/python:nb_absl_span", + "@xla//xla/python:refine_polymorphic_shapes", + "@xla//xla/python:types", + "@xla//xla/python:version", + "@xla//xla/python/ifrt", + "@xla//xla/python/ifrt:attribute_map", + "@xla//xla/python/ifrt:plugin_program", + "@xla//xla/python/ifrt:plugin_program_serdes", + "@xla//xla/python/pjrt_ifrt", + "@xla//xla/python/pjrt_ifrt:pjrt_attribute_map_util", + "@xla//xla/python/pjrt_ifrt:transfer_server_interface", + "@xla//xla/python/pjrt_ifrt:xla_ifrt", + "@xla//xla/tsl/concurrency:ref_count", + "@xla//xla/tsl/distributed_runtime/preemption:preemption_sync_manager", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:status", + "@xla//xla/tsl/platform:statusor", + "@xla//xla/tsl/platform/cloud:gcs_file_system", + "@xla//xla/tsl/python/lib/core:numpy", + ] + select({ + # gloo tcp transport only builds on linux + "@xla//xla/tsl:macos": [ + "@gloo//:transport_uv", + "@xla//xla/backends/cpu/collectives:gloo_collectives", + "@xla//xla/backends/cpu/collectives:gloo_kv_store", + ], + "@xla//xla/tsl:windows": [], + "//conditions:default": [ + ":py_socket_transfer", + "@gloo//:transport_tcp", + "@xla//xla/backends/cpu/collectives:gloo_collectives", + "@xla//xla/backends/cpu/collectives:gloo_kv_store", + ], + }) + select({ + # mpitrampoline does not build on windows + "@xla//xla/tsl:windows": [], + # we support MPI collectives only in OSS builds + "//conditions:default": if_oss(["@xla//xla/backends/cpu/collectives:mpi_collectives"]), + }), +) + cc_library( - name = "pass_boilerplate", - hdrs = ["pass_boilerplate.h"], - # compatible with libtpu + name = "pprof_profile_builder", + srcs = ["pprof_profile_builder.cc"], + hdrs = ["pprof_profile_builder.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@nanobind", + "@tsl//tsl/platform:protobuf", + "@tsl//tsl/profiler/protobuf:profile_proto_cc", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:util", + "@xla//xla/tsl/platform:logging", + ], +) + +cc_library( + name = "callback", + srcs = [ + "callback.cc", + ], + hdrs = [ + "callback.h", + ], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":python_ref_manager", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@nanobind", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:comparison_util", + "@xla//xla:xla_data_proto_cc", + "@xla//xla/pjrt:host_callback", + "@xla//xla/pjrt:transpose", + "@xla//xla/python:nb_numpy", + "@xla//xla/tsl/platform:statusor", + "@xla//xla/tsl/python/lib/core:numpy", + ], +) + +cc_library( + name = "call_location", + srcs = ["call_location.cc"], + hdrs = ["call_location.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":py_user_context", + ":traceback", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@nanobind", + "@xla//xla/python/ifrt", + "@xla//xla/python/ifrt:attribute_map", + "@xla//xla/python/ifrt:user_context", + "@xla//xla/python/pjrt_ifrt", + ], +) + +cc_library( + name = "config", + srcs = ["config.cc"], + hdrs = ["config.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":python_ref_manager", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@nanobind", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla/tsl/platform:logging", + ], +) + +cc_library( + name = "custom_call_sharding", + srcs = ["custom_call_sharding.cc"], + hdrs = ["custom_call_sharding.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@nanobind", + "@xla//third_party/python_runtime:headers", + "@xla//xla:shape_util", + "@xla//xla:util", + "@xla//xla/hlo/ir:hlo", + "@xla//xla/hlo/utils:hlo_sharding_util", + "@xla//xla/pjrt:status_casters", + "@xla//xla/pjrt/c:pjrt_c_api_custom_partitioner_extension_hdrs", + "@xla//xla/pjrt/c:pjrt_c_api_hdrs", + "@xla//xla/pjrt/c:pjrt_c_api_helpers", + "@xla//xla/python:custom_call_batch_partitioner", + "@xla//xla/python:custom_partition_callback", + "@xla//xla/python:debug_callback_partitioner", + "@xla//xla/python:inspect_sharding", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:statusor", + ], +) + +cc_library( + name = "dlpack", + srcs = ["dlpack.cc"], + hdrs = ["dlpack.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":nb_class_ptr", + ":py_client", + ":py_user_context", + ":python_ref_manager", + ":util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@dlpack", + "@llvm-project//llvm:Support", + "@nanobind", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:shape_util", + "@xla//xla:util", + "@xla//xla:xla_data_proto_cc", + "@xla//xla/pjrt:exceptions", + "@xla//xla/pjrt:pjrt_client", + "@xla//xla/pjrt:pjrt_common", + "@xla//xla/pjrt:pjrt_compiler", + "@xla//xla/python:dlpack_types", + "@xla//xla/python:types", + "@xla//xla/python/ifrt", + "@xla//xla/python/pjrt_ifrt", + "@xla//xla/tsl/platform:errors", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:statusor", + ], +) + +cc_library( + name = "ffi", + srcs = ["ffi.cc"], + hdrs = ["ffi.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/base", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@dlpack", + "@nanobind", + "@xla//third_party/python_runtime:headers", + "@xla//xla:xla_data_proto_cc", + "@xla//xla/ffi:ffi_api", + "@xla//xla/ffi/api:c_api", + "@xla//xla/ffi/api:ffi", + "@xla//xla/pjrt:host_callback", + "@xla//xla/pjrt:status_casters", + "@xla//xla/python:dlpack_types", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:types", + "@xla//xla/service:custom_call_target_registry", + "@xla//xla/tsl/platform:statusor", + ], +) + +cc_library( + name = "guard_lib", + srcs = ["guard_lib.cc"], + hdrs = ["guard_lib.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/synchronization", + "@nanobind", + "@xla//xla:util", + "@xla//xla/pjrt:status_casters", + "@xla//xla/python/ifrt", + ], +) + +nanobind_pywrap_extension( + name = "_ifrt_proxy", + srcs = ["ifrt_proxy.cc"], + pytype_srcs = ["_ifrt_proxy.pyi"], deps = [ + ":nb_class_ptr", + ":py_client", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/log:log_entry", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/time", + "@nanobind", + "@xla//xla/pjrt:status_casters", + "@xla//xla/python/ifrt", + "@xla//xla/python/ifrt:attribute_map", + "@xla//xla/python/ifrt_proxy/client:grpc_client", + "@xla//xla/python/ifrt_proxy/client:registry", + "@xla//xla/tsl/platform:env", + "@xla//xla/tsl/platform:statusor", + ], +) + +nanobind_pywrap_extension( + name = "_pathways", + srcs = ["pathways.cc"], + pytype_srcs = ["_pathways.pyi"], + visibility = jax_visibility("jaxlib/_pathways"), + deps = [ + ":nb_class_ptr", + ":py_client", + ":py_user_context", + ":traceback", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@nanobind", + "@xla//third_party/python_runtime:headers", + "@xla//xla:util", + "@xla//xla/pjrt:status_casters", + "@xla//xla/python:nb_absl_span", + "@xla//xla/python:types", + "@xla//xla/python/ifrt", + "@xla//xla/python/ifrt:user_context", + "@xla//xla/tsl/concurrency:ref_count", + "@xla//xla/tsl/platform:errors", + "@xla//xla/tsl/platform:status", + "@xla//xla/tsl/platform:statusor", + ], +) + +cc_library( + name = "jax_jit", + srcs = ["jax_jit.cc"], + hdrs = ["jax_jit.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":config", + ":nb_class_ptr", + ":py_client", + ":pytree", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@nanobind", + "@tsl//tsl/profiler/lib:traceme", + "@xla//third_party/python_runtime:headers", # build_cleaner: keep + "@xla//xla/pjrt:pjrt_client", + "@xla//xla/pjrt:pjrt_layout", + "@xla//xla/pjrt:status_casters", + "@xla//xla/python:nb_absl_inlined_vector", + "@xla//xla/python:nb_absl_span", + "@xla//xla/python:types", + "@xla//xla/tsl/platform:logging", + ], +) + +cc_library( + name = "mlir", + srcs = ["mlir.cc"], + hdrs = ["mlir.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:BytecodeWriter", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ReconcileUnrealizedCasts", "@llvm-project//mlir:Support", + "@nanobind", + "@shardy//shardy/dialect/mpmd/ir:dialect", + "@shardy//shardy/dialect/sdy/ir:dialect", + "@stablehlo//:stablehlo_serialization", + "@xla//xla/hlo/builder:xla_computation", + "@xla//xla/hlo/translate:stablehlo", + "@xla//xla/mlir_hlo:mhlo_passes", + "@xla//xla/pjrt:mlir_to_hlo", + "@xla//xla/pjrt:status_casters", + "@xla//xla/python:refine_polymorphic_shapes", + "@xla//xla/python:version", + "@xla//xla/service:hlo_proto_cc", + "@xla//xla/tsl/platform:errors", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:statusor", ], ) cc_library( - name = "handle_pool", - hdrs = ["handle_pool.h"], + name = "nb_class_ptr", + hdrs = ["nb_class_ptr.h"], + copts = ["-fexceptions"], + features = ["-use_header_modules"], + visibility = jax_visibility("jaxlib/nb_class_ptr"), + deps = ["@nanobind"], +) + +cc_library( + name = "pjit", + srcs = ["pjit.cc"], + hdrs = ["pjit.h"], + compatible_with = [], copts = [ "-fexceptions", "-fno-strict-aliasing", ], features = ["-use_header_modules"], deps = [ + ":call_location", + ":config", + ":guard_lib", + ":jax_jit", + ":nb_class_ptr", + ":py_client", + ":py_user_context", + ":python_ref_manager", + ":pytree", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@nanobind", + "@tsl//tsl/profiler/lib:traceme", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:shape_util", + "@xla//xla:util", + "@xla//xla/pjrt:exceptions", + "@xla//xla/pjrt:lru_cache", + "@xla//xla/python:nb_helpers", + "@xla//xla/python:nb_numpy", + "@xla//xla/python/ifrt", + "@xla//xla/python/ifrt:user_context", + "@xla//xla/tsl/concurrency:ref_count", + "@xla//xla/tsl/platform:env", + "@xla//xla/tsl/platform:errors", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:statusor", ], ) -# This isn't a CPU kernel. This exists to catch cases where jaxlib is built for the wrong -# target architecture. -nanobind_extension( - name = "cpu_feature_guard", - srcs = ["cpu_feature_guard.c"], - module_name = "cpu_feature_guard", +cc_library( + name = "pmap_lib", + srcs = ["pmap_lib.cc"], + hdrs = ["pmap_lib.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], deps = [ - "@xla//third_party/python_runtime:headers", + ":call_location", + ":config", + ":jax_jit", + ":nb_class_ptr", + ":py_client", + ":py_user_context", + ":python_ref_manager", + ":pytree", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@nanobind", + "@tsl//tsl/profiler/lib:traceme", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:status_macros", + "@xla//xla:util", + "@xla//xla:xla_data_proto_cc", + "@xla//xla/pjrt:exceptions", + "@xla//xla/pjrt:status_casters", + "@xla//xla/python:nb_helpers", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:safe_static_init", + "@xla//xla/python:types", + "@xla//xla/python/ifrt", + "@xla//xla/python/ifrt:user_context", + "@xla//xla/tsl/concurrency:ref_count", + "@xla//xla/tsl/platform:env", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:statusor", + "@xla//xla/tsl/python/lib/core:numpy", ], ) -nanobind_extension( - name = "utils", - srcs = ["utils.cc"], - module_name = "utils", +cc_library( + name = "cached_py_object", + hdrs = ["cached_py_object.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], deps = [ - "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/functional:function_ref", + "@nanobind", + ], +) + +cc_library( + name = "py_client", + srcs = [ + "partition_spec.cc", + "py_array.cc", + "py_client.cc", + "py_compile_only_client.cc", + "py_device.cc", + "py_device_list.cc", + "py_executable.cc", + "py_memory_space.cc", + "py_program.cc", + "py_values.cc", + "sharding.cc", + "to_ifrt_sharding.cc", + ], + hdrs = [ + "partition_spec.h", + "py_array.h", + "py_client.h", + "py_compile_only_client.h", + "py_device.h", + "py_device_list.h", + "py_executable.h", + "py_memory_space.h", + "py_program.h", + "py_values.h", + "sharded_device_array.h", + "sharding.h", + "to_ifrt_sharding.h", + ], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + visibility = jax_visibility("jaxlib/py_client"), + deps = [ + ":cached_py_object", + ":call_location", + ":guard_lib", + ":nb_class_ptr", + ":pprof_profile_builder", + ":py_client_cpu", + ":py_host_callback", + ":py_user_context", + ":python_ref_manager", + ":traceback", + ":util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:CAPIIR", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MLIRBindingsPythonNanobindHeadersAndDeps", + "@llvm-project//mlir:Pass", + "@nanobind", + "@tsl//tsl/platform:fingerprint", + "@tsl//tsl/platform:ml_dtypes", + "@tsl//tsl/profiler/lib:traceme", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:future", + "@xla//xla:literal", + "@xla//xla:shape_util", + "@xla//xla:status_macros", + "@xla//xla:types", + "@xla//xla:util", + "@xla//xla:xla_data_proto_cc", + "@xla//xla/hlo/ir:hlo", + "@xla//xla/pjrt:exceptions", + "@xla//xla/pjrt:lru_cache", + "@xla//xla/pjrt:mlir_to_hlo", + "@xla//xla/pjrt:pjrt_client", + "@xla//xla/pjrt:pjrt_compiler", + "@xla//xla/pjrt:pjrt_executable", + "@xla//xla/pjrt:pjrt_layout", + "@xla//xla/pjrt:status_casters", + "@xla//xla/python:nb_absl_flat_hash_map", + "@xla//xla/python:nb_absl_span", + "@xla//xla/python:nb_helpers", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:safe_static_init", + "@xla//xla/python:types", + "@xla//xla/python:version", + "@xla//xla/python/compile_only_ifrt:client", + "@xla//xla/python/ifrt", + "@xla//xla/python/ifrt:attribute_map", + "@xla//xla/python/ifrt:custom_call_program", + "@xla//xla/python/ifrt:plugin_program", + "@xla//xla/python/ifrt:plugin_program_serdes", + "@xla//xla/python/ifrt:sharding_serdes", + "@xla//xla/python/ifrt:user_context", + "@xla//xla/python/ifrt:user_context_status_util", + "@xla//xla/python/ifrt/hlo:hlo_program", + "@xla//xla/python/pjrt_ifrt", + "@xla//xla/python/pjrt_ifrt:pjrt_attribute_map_util", + "@xla//xla/python/pjrt_ifrt:pjrt_dtype", + "@xla//xla/python/pjrt_ifrt:pjrt_layout_serdes", + "@xla//xla/python/pjrt_ifrt:xla_executable_version_serdes", + "@xla//xla/python/pjrt_ifrt:xla_ifrt", + "@xla//xla/python/pjrt_ifrt:xla_sharding_serdes", + "@xla//xla/service:platform_util", + "@xla//xla/service/spmd/shardy:utils", + "@xla//xla/tsl/concurrency:future", + "@xla//xla/tsl/concurrency:ref_count", + "@xla//xla/tsl/framework:allocator", + "@xla//xla/tsl/platform:env", + "@xla//xla/tsl/platform:errors", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:status", + "@xla//xla/tsl/platform:statusor", + "@xla//xla/tsl/python/lib/core:numpy", + ], +) + +cc_library( + name = "py_client_cpu", + srcs = ["py_client_cpu.cc"], + hdrs = ["py_client_cpu.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":ffi", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@dlpack", "@nanobind", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:shape_util", + "@xla//xla:util", + "@xla//xla:xla_data_proto_cc", + "@xla//xla/ffi:ffi_api", + "@xla//xla/ffi/api:ffi", + "@xla//xla/pjrt:host_callback", + "@xla//xla/pjrt:transpose", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:types", + ], + alwayslink = 1, +) + +cc_library( + name = "py_host_callback", + srcs = ["py_host_callback.cc"], + hdrs = ["py_host_callback.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":callback", + ":py_host_callback_cc_proto", + ":python_ref_manager", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@nanobind", + "@xla//xla:shape_util", + "@xla//xla:status_macros", + "@xla//xla:util", + "@xla//xla:xla_data_proto_cc", + "@xla//xla/pjrt:host_callback", + "@xla//xla/python:types", + "@xla//xla/python/ifrt", + "@xla//xla/python/pjrt_ifrt", + "@xla//xla/python/pjrt_ifrt:xla_host_callback_proto_cc", + "@xla//xla/tsl/concurrency:ref_count", + "@xla//xla/tsl/platform:statusor", + ], +) + +proto_library( + name = "py_host_callback_proto", + srcs = ["py_host_callback.proto"], +) + +cc_proto_library( + name = "py_host_callback_cc_proto", + visibility = jax_visibility("jaxlib/py_host_callback_cc_proto"), + deps = [":py_host_callback_proto"], +) + +cc_library( + name = "py_socket_transfer", + srcs = ["py_socket_transfer.cc"], + hdrs = ["py_socket_transfer.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":nb_class_ptr", + ":py_client", + ":py_user_context", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/time", + "@llvm-project//llvm:Support", + "@nanobind", + "@tsl//tsl/platform:casts", + "@xla//xla:future", + "@xla//xla:util", + "@xla//xla/pjrt:pjrt_client", + "@xla//xla/pjrt:status_casters", + "@xla//xla/pjrt/distributed:client", + "@xla//xla/pjrt/distributed:key_value_store_interface", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:types", + "@xla//xla/python/ifrt", + "@xla//xla/python/pjrt_ifrt", + "@xla//xla/python/pjrt_ifrt:pjrt_dtype", + "@xla//xla/python/pjrt_ifrt:transfer_server_interface", + "@xla//xla/python/transfer:event_loop", + "@xla//xla/python/transfer:pjrt_transfer_server", + "@xla//xla/python/transfer:socket-server", + "@xla//xla/python/transfer:socket_bulk_transport", + "@xla//xla/python/transfer:streaming", + "@xla//xla/python/transfer:streaming_ifrt", + "@xla//xla/python/transfer:transfer_socket_proto_cc", + "@xla//xla/tsl/concurrency:ref_count", + "@xla//xla/tsl/platform:statusor", + ], +) + +cc_library( + name = "py_user_context", + srcs = ["py_user_context.cc"], + hdrs = ["py_user_context.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + visibility = jax_visibility("jaxlib/py_user_context"), + deps = [ + ":python_ref_manager", + ":traceback", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@llvm-project//llvm:Support", + "@nanobind", + "@tsl//tsl/platform:random", "@xla//third_party/python_runtime:headers", + "@xla//xla/python:version", + "@xla//xla/python/ifrt:user_context", + "@xla//xla/service:slow_operation_alarm", + "@xla//xla/tsl/concurrency:ref_count", + ], +) + +cc_library( + name = "python_ref_manager", + srcs = ["python_ref_manager.cc"], + hdrs = ["python_ref_manager.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + visibility = jax_visibility("jaxlib/python_ref_manager"), + deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@nanobind", + "@tsl//tsl/profiler/lib:traceme", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep ], ) + +proto_library( + name = "pytree_proto", + srcs = ["pytree.proto"], +) + +cc_proto_library( + name = "pytree_cc_proto", + deps = [":pytree_proto"], +) + +cc_library( + name = "pytree", + srcs = ["pytree.cc"], + hdrs = ["pytree.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + visibility = jax_visibility("jaxlib/pytree"), + deps = [ + ":nb_class_ptr", + ":pytree_cc_proto", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@nanobind", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla/pjrt:exceptions", + "@xla//xla/tsl/platform:logging", + ], +) + +nanobind_pywrap_extension( + name = "_sdy_mpmd", + srcs = ["sdy_mpmd.cc"], + deps = [ + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", + "@nanobind", + "@shardy//shardy/dialect/mpmd/ir:dialect", + "@shardy//shardy/dialect/mpmd/ir:fragment_execution_rules", + "@shardy//shardy/dialect/mpmd/transforms/import:mesh_assignment_map", + "@shardy//shardy/integrations/python/jax/mpmd/jaxlib:mpmd_program", + "@xla//xla/pjrt:status_casters", + "@xla//xla/python:nb_absl_flat_hash_map", + "@xla//xla/python/ifrt/ir/conversions/mpmd:lower_to_ifrt", + ], +) + +cc_library( + name = "traceback", + srcs = ["traceback.cc"], + hdrs = ["traceback.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + visibility = jax_visibility("jaxlib/traceback"), + deps = [ + ":nb_class_ptr", + "@com_google_absl//absl/base", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@nanobind", + "@tsl//tsl/platform", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla/pjrt:exceptions", + "@xla//xla/python:nb_helpers", + "@xla//xla/tsl/platform:logging", + ], +) + +cc_library( + name = "util", + srcs = ["util.cc"], + hdrs = ["util.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@nanobind", + "@xla//xla:future", + "@xla//xla:util", + "@xla//xla/python:version", + "@xla//xla/python/ifrt", + "@xla//xla/tsl/concurrency:async_value", + "@xla//xla/tsl/concurrency:future", + "@xla//xla/tsl/concurrency:ref_count", + ], +) + +cc_library( + name = "xla_compiler", + srcs = ["xla_compiler.cc"], + hdrs = ["xla_compiler.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":dlpack", + ":py_client", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@llvm-project//mlir:Support", + "@nanobind", + "@xla//xla:array", + "@xla//xla:debug_options_flags", + "@xla//xla:literal", + "@xla//xla:shape_util", + "@xla//xla:util", + "@xla//xla:xla_data_proto_cc", + "@xla//xla:xla_proto_cc", + "@xla//xla/client:executable_build_options", + "@xla//xla/hlo/builder:xla_computation", + "@xla//xla/hlo/ir:hlo", + "@xla//xla/hlo/parser:hlo_parser", + "@xla//xla/pjrt:exceptions", + "@xla//xla/pjrt:pjrt_executable", + "@xla//xla/pjrt:status_casters", + "@xla//xla/pjrt/proto:compile_options_proto_cc", + "@xla//xla/python:nb_absl_span", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:types", + "@xla//xla/service:computation_placer", + "@xla//xla/service:hlo_graph_dumper", + "@xla//xla/service:hlo_module_config", + "@xla//xla/service:hlo_proto_cc", + "@xla//xla/service/spmd/shardy/stablehlo_round_trip:stablehlo_import", + "@xla//xla/tsl/lib/strings:proto_serialization", + "@xla//xla/tsl/platform:env", + "@xla//xla/tsl/platform:errors", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:statusor", + ], +) + +pytype_strict_library( + name = "xla_client", + srcs = ["xla_client.py"], + visibility = [":xla_python"], + deps = py_deps([ + "numpy", + "ml_dtypes", + ]) + [":_jax"], +) + +py_strict_test( + name = "pytree_test", + srcs = ["pytree_test.py"], + deps = [ + ":xla_client", + ] + py_deps([ + "absl/flags", + "absl/logging", + "absl/testing", + ]), +) + +py_strict_test( + name = "config_test", + srcs = ["config_test.py"], + deps = [ + ":xla_client", + ] + py_deps([ + "absl/flags", + "absl/logging", + "absl/testing", + ]), +) diff --git a/jaxlib/_ifrt_proxy.pyi b/jaxlib/_ifrt_proxy.pyi new file mode 100644 index 000000000000..c8f6b1141a4e --- /dev/null +++ b/jaxlib/_ifrt_proxy.pyi @@ -0,0 +1,31 @@ +# Copyright 2024 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from collections.abc import Callable +from typing import Any + +from jaxlib import _jax + +_Status = Any +Client = _jax.Client + +class ClientConnectionOptions: + on_disconnect: Callable[[_Status], None] | None = None + on_connection_update: Callable[[str], None] | None = None + connection_timeout_in_seconds: int | None = None + +def get_client( + proxy_server_address: str, options: ClientConnectionOptions +) -> Client: ... diff --git a/jaxlib/_jax/__init__.pyi b/jaxlib/_jax/__init__.pyi new file mode 100644 index 000000000000..326ddec6c5aa --- /dev/null +++ b/jaxlib/_jax/__init__.pyi @@ -0,0 +1,1627 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable, Iterator, Mapping, Sequence +import enum +import inspect +import traceback +import types +from typing import Annotated, Any, TypeAlias, overload + +import numpy +from numpy.typing import NDArray +import typing_extensions + +from . import ( + config as config, + ffi as ffi, + guard_lib as guard_lib, + hlo_sharding_util as hlo_sharding_util, + ifrt_programs as ifrt_programs, + jax_jit as jax_jit, + mlir as mlir, + pmap_lib as pmap_lib, + pytree as pytree, +) +from .pmap_lib import PmapFunction as PmapFunction +from .pytree import (PyTreeDef as PyTreeDef, PyTreeRegistry as _PyTreeRegistry) + +class JaxRuntimeError(RuntimeError): + """Runtime errors thrown by the JAX runtime. + + While the JAX runtime may raise other exceptions as well, most exceptions + thrown by the runtime are instances of this class. + """ + +class PrimitiveType(enum.IntEnum): + PRIMITIVE_TYPE_INVALID = 0 + + PRED = 1 + + S4 = 21 + + S8 = 2 + + S16 = 3 + + S32 = 4 + + S64 = 5 + + U4 = 22 + + U8 = 6 + + U16 = 7 + + U32 = 8 + + U64 = 9 + + F16 = 10 + + F4E2M1FN = 32 + + F8E3M4 = 29 + + F8E4M3 = 28 + + F8E4M3FN = 20 + + F8E4M3B11FNUZ = 23 + + F8E4M3FNUZ = 25 + + F8E5M2 = 19 + + F8E5M2FNUZ = 24 + + F8E8M0FNU = 33 + + BF16 = 16 + + F32 = 11 + + F64 = 12 + + C64 = 15 + + C128 = 18 + + TUPLE = 13 + + OPAQUE_TYPE = 14 + + TOKEN = 17 + +class Layout: + @overload + def __init__(self, arg: Sequence[int], /) -> None: ... + @overload + def __init__( + self, arg0: Sequence[int], arg1: Sequence[tuple[int, ...]], arg2: int, / + ) -> None: ... + def minor_to_major(self) -> tuple[int, ...]: ... + def element_size_in_bits(self) -> int: ... + def tiling(self) -> list[tuple[int, ...]]: ... + def __eq__(self, other: object, /) -> bool: ... + def __ne__(self, other: object, /) -> bool: ... + def __str__(self) -> str: ... + def __hash__(self) -> int: ... + def to_string(self) -> str: ... + def __getstate__(self) -> tuple: ... + def __setstate__(self, arg: tuple, /) -> None: ... + +class Shape: + def __init__(self, arg: str, /) -> None: ... + @staticmethod + def tuple_shape(arg: Sequence[Shape], /) -> Shape: + """Constructs a tuple shape.""" + + @overload + @staticmethod + def array_shape( + type: PrimitiveType, + dims: Sequence[int], + layout: Sequence[int] | None = ..., + dynamic_dimensions: Sequence[bool] | None = ..., + ) -> Shape: + """Constructs an array shape.""" + + @overload + @staticmethod + def array_shape( + type: numpy.dtype, + dims: Sequence[int], + layout: Sequence[int] | None = ..., + dynamic_dimensions: Sequence[bool] | None = ..., + ) -> Shape: ... + @staticmethod + def token_shape() -> Shape: ... + @overload + @staticmethod + def scalar_shape(type: PrimitiveType) -> Shape: + """Constructs a scalar shape.""" + + @overload + @staticmethod + def scalar_shape(type: numpy.dtype) -> Shape: ... + def dimensions(self) -> tuple[int, ...]: ... + def layout(self) -> Layout: ... + def xla_element_type(self) -> PrimitiveType: ... + def element_type(self) -> numpy.dtype: ... + def numpy_dtype(self) -> numpy.dtype: ... + def is_tuple(self) -> bool: ... + def is_array(self) -> bool: ... + def is_token(self) -> bool: ... + def is_static(self) -> bool: ... + def is_dynamic(self) -> bool: ... + def is_dynamic_dimension(self, dimension: int) -> bool: ... + def set_dynamic_dimension(self, dimension: int, is_dynamic: bool) -> None: ... + def rank(self) -> int: ... + def to_serialized_proto(self) -> bytes: ... + def tuple_shapes(self) -> list[Shape]: ... + def leaf_count(self) -> int: ... + def with_major_to_minor_layout_if_absent(self) -> Shape: + """Returns a copy of a shape with missing layouts set to major-to-minor.""" + + def __eq__(self, other: object, /) -> bool: ... + def __ne__(self, other: object, /) -> bool: ... + def __hash__(self) -> int: ... + def __repr__(self) -> str: ... + +class ProgramShape: + def __init__(self, arg0: Sequence[Shape], arg1: Shape, /) -> None: ... + def parameter_shapes(self) -> list[Shape]: ... + def result_shape(self) -> Shape: ... + def __repr__(self) -> str: ... + +class Literal: + def __init__(self, arg: Shape, /) -> None: ... + def __repr__(self) -> str: ... + def __array__( + self, dtype: object | None = ..., copy: bool | None = ... + ) -> NDArray: ... + def shape(self) -> Shape: ... + +class XlaComputation: + def __init__(self, arg: bytes, /) -> None: ... + def get_hlo_module(self) -> HloModule: ... + def program_shape(self) -> ProgramShape: ... + def name(self) -> str: ... + def as_serialized_hlo_module_proto(self) -> bytes: ... + def as_hlo_text(self, print_large_constants: bool = ...) -> str: ... + def as_hlo_dot_graph(self) -> str: ... + def hash(self) -> int: ... + def as_hlo_module(self) -> HloModule: ... + +class HloPrintOptions: + def __init__(self) -> None: ... + @staticmethod + def short_parsable() -> HloPrintOptions: ... + @staticmethod + def canonical() -> HloPrintOptions: ... + @staticmethod + def fingerprint() -> HloPrintOptions: ... + @property + def print_large_constants(self) -> bool: ... + @print_large_constants.setter + def print_large_constants(self, arg: bool, /) -> HloPrintOptions: ... + @property + def print_metadata(self) -> bool: ... + @print_metadata.setter + def print_metadata(self, arg: bool, /) -> HloPrintOptions: ... + @property + def print_backend_config(self) -> bool: ... + @print_backend_config.setter + def print_backend_config(self, arg: bool, /) -> HloPrintOptions: ... + @property + def print_result_shape(self) -> bool: ... + @print_result_shape.setter + def print_result_shape(self, arg: bool, /) -> HloPrintOptions: ... + @property + def print_operand_shape(self) -> bool: ... + @print_operand_shape.setter + def print_operand_shape(self, arg: bool, /) -> HloPrintOptions: ... + @property + def print_operand_names(self) -> bool: ... + @print_operand_names.setter + def print_operand_names(self, arg: bool, /) -> HloPrintOptions: ... + @property + def print_ids(self) -> bool: ... + @print_ids.setter + def print_ids(self, arg: bool, /) -> HloPrintOptions: ... + @property + def print_extra_attributes(self) -> bool: ... + @print_extra_attributes.setter + def print_extra_attributes(self, arg: bool, /) -> HloPrintOptions: ... + @property + def print_program_shape(self) -> bool: ... + @print_program_shape.setter + def print_program_shape(self, arg: bool, /) -> HloPrintOptions: ... + @property + def print_percent(self) -> bool: ... + @print_percent.setter + def print_percent(self, arg: bool, /) -> HloPrintOptions: ... + @property + def print_control_dependencies(self) -> bool: ... + @print_control_dependencies.setter + def print_control_dependencies(self, arg: bool, /) -> HloPrintOptions: ... + @property + def compact_operands(self) -> bool: ... + @compact_operands.setter + def compact_operands(self, arg: bool, /) -> HloPrintOptions: ... + @property + def include_layout_in_shapes(self) -> bool: ... + @include_layout_in_shapes.setter + def include_layout_in_shapes(self, arg: bool, /) -> HloPrintOptions: ... + @property + def canonicalize_instruction_names(self) -> bool: ... + @canonicalize_instruction_names.setter + def canonicalize_instruction_names(self, arg: bool, /) -> HloPrintOptions: ... + @property + def canonicalize_computations(self) -> bool: ... + @canonicalize_computations.setter + def canonicalize_computations(self, arg: bool, /) -> HloPrintOptions: ... + @property + def indent_amount(self) -> int: ... + @indent_amount.setter + def indent_amount(self, arg: int, /) -> HloPrintOptions: ... + @property + def is_in_nested_computation(self) -> int: ... + @is_in_nested_computation.setter + def is_in_nested_computation(self, arg: bool, /) -> HloPrintOptions: ... + +class HloComputation: + @property + def name(self) -> str: ... + def render_html(self, arg: str, /) -> None: ... + +class HloModule: + @property + def name(self) -> str: ... + def to_string(self, options: HloPrintOptions = ...) -> str: ... + def as_serialized_hlo_module_proto(self) -> bytes: ... + def from_serialized_hlo_module_proto(self) -> HloModule: ... + def computations(self) -> list[HloComputation]: ... + @property + def spmd_output_sharding(self) -> OpSharding | None: ... + @property + def spmd_parameters_shardings(self) -> list[OpSharding] | None: ... + +def hlo_module_to_dot_graph(arg: HloModule, /) -> str: ... +def hlo_module_cost_analysis(arg0: Client, arg1: HloModule, /) -> dict: ... +def hlo_module_from_text(arg: str, /) -> HloModule: ... + +class DeviceAssignment: + @staticmethod + def create( + arg: Annotated[NDArray[numpy.int32], dict(shape=(None, None))], / + ) -> DeviceAssignment: ... + def replica_count(self) -> int: ... + def computation_count(self) -> int: ... + def __repr__(self) -> str: ... + def serialize(self) -> bytes: ... + +class CompileOptions: + def __init__(self) -> None: ... + def __getstate__(self) -> tuple: ... + def __setstate__(self, arg: tuple, /) -> None: ... + def SerializeAsString(self) -> bytes: ... + @staticmethod + def ParseFromString(arg: bytes, /) -> CompileOptions: ... + @property + def argument_layouts(self) -> list[Shape] | None: ... + @argument_layouts.setter + def argument_layouts(self, arg: Sequence[Shape], /) -> None: ... + @property + def parameter_is_tupled_arguments(self) -> bool: ... + @parameter_is_tupled_arguments.setter + def parameter_is_tupled_arguments(self, arg: bool, /) -> None: ... + @property + def compile_portable_executable(self) -> bool: ... + @compile_portable_executable.setter + def compile_portable_executable(self, arg: bool, /) -> None: ... + @property + def executable_build_options(self) -> ExecutableBuildOptions: ... + @property + def env_option_overrides( + self, + ) -> list[tuple[str, str | bool | int | float]]: ... + @env_option_overrides.setter + def env_option_overrides( + self, arg: Sequence[tuple[str, str | bool | int | float]], / + ) -> None: ... + @property + def num_replicas(self) -> int: ... + @num_replicas.setter + def num_replicas(self, arg: int, /) -> None: ... + @property + def num_partitions(self) -> int: ... + @num_partitions.setter + def num_partitions(self, arg: int, /) -> None: ... + @property + def profile_version(self) -> int: ... + @profile_version.setter + def profile_version(self, arg: int, /) -> None: ... + @property + def device_assignment(self) -> DeviceAssignment | None: ... + @device_assignment.setter + def device_assignment(self, arg: DeviceAssignment, /) -> None: ... + +class AutotuneCacheMode(enum.Enum): + UNSPECIFIED = 0 + + UPDATE = 1 + + READ = 2 + +class DebugOptions: + def __repr__(self) -> str: ... + @property + def xla_backend_optimization_level(self) -> int: ... + @xla_backend_optimization_level.setter + def xla_backend_optimization_level(self, arg: int, /) -> None: ... + @property + def xla_cpu_enable_fast_math(self) -> bool: ... + @xla_cpu_enable_fast_math.setter + def xla_cpu_enable_fast_math(self, arg: bool, /) -> None: ... + @property + def xla_cpu_enable_xprof_traceme(self) -> bool: ... + @xla_cpu_enable_xprof_traceme.setter + def xla_cpu_enable_xprof_traceme(self, arg: bool, /) -> None: ... + @property + def xla_cpu_fast_math_honor_infs(self) -> bool: ... + @xla_cpu_fast_math_honor_infs.setter + def xla_cpu_fast_math_honor_infs(self, arg: bool, /) -> None: ... + @property + def xla_cpu_fast_math_honor_nans(self) -> bool: ... + @xla_cpu_fast_math_honor_nans.setter + def xla_cpu_fast_math_honor_nans(self, arg: bool, /) -> None: ... + @property + def xla_cpu_fast_math_honor_division(self) -> bool: ... + @xla_cpu_fast_math_honor_division.setter + def xla_cpu_fast_math_honor_division(self, arg: bool, /) -> None: ... + @property + def xla_cpu_fast_math_honor_functions(self) -> bool: ... + @xla_cpu_fast_math_honor_functions.setter + def xla_cpu_fast_math_honor_functions(self, arg: bool, /) -> None: ... + @property + def xla_detailed_logging(self) -> bool: ... + @xla_detailed_logging.setter + def xla_detailed_logging(self, arg: bool, /) -> None: ... + @property + def xla_enable_dumping(self) -> bool: ... + @xla_enable_dumping.setter + def xla_enable_dumping(self, arg: bool, /) -> None: ... + @property + def xla_gpu_enable_fast_min_max(self) -> bool: ... + @xla_gpu_enable_fast_min_max.setter + def xla_gpu_enable_fast_min_max(self, arg: bool, /) -> None: ... + @property + def xla_gpu_dump_autotune_results_to(self) -> str: ... + @xla_gpu_dump_autotune_results_to.setter + def xla_gpu_dump_autotune_results_to(self, arg: str, /) -> None: ... + @property + def xla_gpu_load_autotune_results_from(self) -> str: ... + @xla_gpu_load_autotune_results_from.setter + def xla_gpu_load_autotune_results_from(self, arg: str, /) -> None: ... + @property + def xla_gpu_cuda_data_dir(self) -> str: ... + @xla_gpu_cuda_data_dir.setter + def xla_gpu_cuda_data_dir(self, arg: str, /) -> None: ... + @property + def xla_llvm_disable_expensive_passes(self) -> bool: ... + @xla_llvm_disable_expensive_passes.setter + def xla_llvm_disable_expensive_passes(self, arg: bool, /) -> None: ... + @property + def xla_disable_hlo_passes(self) -> str: ... + @xla_disable_hlo_passes.setter + def xla_disable_hlo_passes(self, arg: str, /) -> None: ... + @property + def xla_enable_hlo_passes_only(self) -> str: ... + @xla_enable_hlo_passes_only.setter + def xla_enable_hlo_passes_only(self, arg: str, /) -> None: ... + @property + def xla_test_all_input_layouts(self) -> bool: ... + @xla_test_all_input_layouts.setter + def xla_test_all_input_layouts(self, arg: bool, /) -> None: ... + @property + def xla_force_host_platform_device_count(self) -> int: ... + @xla_force_host_platform_device_count.setter + def xla_force_host_platform_device_count(self, arg: int, /) -> None: ... + @property + def xla_dump_to(self) -> str: ... + @xla_dump_to.setter + def xla_dump_to(self, arg: str, /) -> None: ... + @property + def xla_dump_hlo_module_re(self) -> str: ... + @xla_dump_hlo_module_re.setter + def xla_dump_hlo_module_re(self, arg: str, /) -> None: ... + @property + def xla_dump_hlo_pass_re(self) -> str: ... + @xla_dump_hlo_pass_re.setter + def xla_dump_hlo_pass_re(self, arg: str, /) -> None: ... + @property + def xla_dump_hlo_as_text(self) -> bool: ... + @xla_dump_hlo_as_text.setter + def xla_dump_hlo_as_text(self, arg: bool, /) -> None: ... + @property + def xla_dump_hlo_as_proto(self) -> bool: ... + @xla_dump_hlo_as_proto.setter + def xla_dump_hlo_as_proto(self, arg: bool, /) -> None: ... + @property + def xla_dump_hlo_as_dot(self) -> bool: ... + @xla_dump_hlo_as_dot.setter + def xla_dump_hlo_as_dot(self, arg: bool, /) -> None: ... + @property + def xla_dump_hlo_as_url(self) -> bool: ... + @xla_dump_hlo_as_url.setter + def xla_dump_hlo_as_url(self, arg: bool, /) -> None: ... + @property + def xla_dump_hlo_as_html(self) -> bool: ... + @xla_dump_hlo_as_html.setter + def xla_dump_hlo_as_html(self, arg: bool, /) -> None: ... + @property + def xla_dump_fusion_visualization(self) -> bool: ... + @xla_dump_fusion_visualization.setter + def xla_dump_fusion_visualization(self, arg: bool, /) -> None: ... + @property + def xla_dump_hlo_snapshots(self) -> bool: ... + @xla_dump_hlo_snapshots.setter + def xla_dump_hlo_snapshots(self, arg: bool, /) -> None: ... + @property + def xla_dump_max_hlo_modules(self) -> int: ... + @xla_dump_max_hlo_modules.setter + def xla_dump_max_hlo_modules(self, arg: int, /) -> None: ... + @property + def xla_dump_module_metadata(self) -> bool: ... + @xla_dump_module_metadata.setter + def xla_dump_module_metadata(self, arg: bool, /) -> None: ... + @property + def xla_dump_compress_protos(self) -> bool: ... + @xla_dump_compress_protos.setter + def xla_dump_compress_protos(self, arg: bool, /) -> None: ... + @property + def xla_dump_hlo_as_long_text(self) -> bool: ... + @xla_dump_hlo_as_long_text.setter + def xla_dump_hlo_as_long_text(self, arg: bool, /) -> None: ... + @property + def xla_dump_disable_metadata(self) -> bool: ... + @xla_dump_disable_metadata.setter + def xla_dump_disable_metadata(self, arg: bool, /) -> None: ... + @property + def xla_dump_hlo_pipeline_re(self) -> str: ... + @xla_dump_hlo_pipeline_re.setter + def xla_dump_hlo_pipeline_re(self, arg: str, /) -> None: ... + @property + def xla_gpu_dump_autotune_logs_to(self) -> str: ... + @xla_gpu_dump_autotune_logs_to.setter + def xla_gpu_dump_autotune_logs_to(self, arg: str, /) -> None: ... + @property + def xla_gpu_kernel_cache_file(self) -> str: ... + @xla_gpu_kernel_cache_file.setter + def xla_gpu_kernel_cache_file(self, arg: str, /) -> None: ... + @property + def xla_gpu_enable_llvm_module_compilation_parallelism(self) -> bool: ... + @xla_gpu_enable_llvm_module_compilation_parallelism.setter + def xla_gpu_enable_llvm_module_compilation_parallelism( + self, arg: bool, / + ) -> None: ... + @property + def xla_gpu_per_fusion_autotune_cache_dir(self) -> str: ... + @xla_gpu_per_fusion_autotune_cache_dir.setter + def xla_gpu_per_fusion_autotune_cache_dir(self, arg: str, /) -> None: ... + @property + def xla_gpu_experimental_autotune_cache_mode(self) -> AutotuneCacheMode: ... + @xla_gpu_experimental_autotune_cache_mode.setter + def xla_gpu_experimental_autotune_cache_mode( + self, arg: AutotuneCacheMode, / + ) -> None: ... + +class ExecutableBuildOptions: + def __init__(self) -> None: ... + def __repr__(self) -> str: ... + @property + def fdo_profile(self) -> bytes: ... + @fdo_profile.setter + def fdo_profile(self, arg: bytes, /) -> None: ... + @property + def result_layout(self) -> Shape | None: ... + @result_layout.setter + def result_layout(self, arg: Shape, /) -> ExecutableBuildOptions: ... + @property + def num_replicas(self) -> int: ... + @num_replicas.setter + def num_replicas(self, arg: int, /) -> ExecutableBuildOptions: ... + @property + def num_partitions(self) -> int: ... + @num_partitions.setter + def num_partitions(self, arg: int, /) -> ExecutableBuildOptions: ... + @property + def debug_options(self) -> DebugOptions: ... + @property + def device_assignment(self) -> DeviceAssignment | None: ... + @device_assignment.setter + def device_assignment( + self, arg: DeviceAssignment, / + ) -> ExecutableBuildOptions: ... + def compilation_environments_from_serialized_proto( + self, arg: bytes, / + ) -> None: ... + @property + def exec_time_optimization_effort(self) -> float: ... + @exec_time_optimization_effort.setter + def exec_time_optimization_effort( + self, arg: float, / + ) -> ExecutableBuildOptions: ... + @property + def memory_fitting_effort(self) -> float: ... + @memory_fitting_effort.setter + def memory_fitting_effort(self, arg: float, /) -> ExecutableBuildOptions: ... + @property + def optimization_level(self) -> int: ... + @optimization_level.setter + def optimization_level(self, arg: int, /) -> None: ... + @property + def memory_fitting_level(self) -> int: ... + @memory_fitting_level.setter + def memory_fitting_level(self, arg: int, /) -> None: ... + @property + def use_spmd_partitioning(self) -> bool: ... + @use_spmd_partitioning.setter + def use_spmd_partitioning(self, arg: bool, /) -> ExecutableBuildOptions: ... + @property + def use_auto_spmd_partitioning(self) -> bool: ... + @use_auto_spmd_partitioning.setter + def use_auto_spmd_partitioning( + self, arg: bool, / + ) -> ExecutableBuildOptions: ... + @property + def auto_spmd_partitioning_mesh_shape(self) -> list[int]: ... + @auto_spmd_partitioning_mesh_shape.setter + def auto_spmd_partitioning_mesh_shape( + self, arg: Sequence[int], / + ) -> ExecutableBuildOptions: ... + @property + def auto_spmd_partitioning_mesh_ids(self) -> list[int]: ... + @auto_spmd_partitioning_mesh_ids.setter + def auto_spmd_partitioning_mesh_ids( + self, arg: Sequence[int], / + ) -> ExecutableBuildOptions: ... + @property + def allow_spmd_sharding_propagation_to_parameters(self) -> list[bool]: ... + @allow_spmd_sharding_propagation_to_parameters.setter + def allow_spmd_sharding_propagation_to_parameters( + self, arg: Sequence[bool], / + ) -> None: ... + @property + def allow_spmd_sharding_propagation_to_output(self) -> list[bool]: ... + @allow_spmd_sharding_propagation_to_output.setter + def allow_spmd_sharding_propagation_to_output( + self, arg: Sequence[bool], / + ) -> None: ... + @property + def use_shardy_partitioner(self) -> bool: ... + @use_shardy_partitioner.setter + def use_shardy_partitioner(self, arg: bool, /) -> ExecutableBuildOptions: ... + +class OpSharding_Type(enum.IntEnum): + REPLICATED = 0 + + MAXIMAL = 1 + + MANUAL = 4 + + UNREDUCED = 6 + + TUPLE = 2 + + OTHER = 3 + + UNKNOWN = 5 + +class OpSharding_ShardGroupType(enum.Enum): + AS = 0 + + LIKE = 1 + +class OpSharding: + def __init__(self) -> None: ... + + Type: TypeAlias = OpSharding_Type + + ShardGroupType: TypeAlias = OpSharding_ShardGroupType + + def __getstate__(self) -> tuple: ... + def __setstate__(self, arg: tuple, /) -> None: ... + @property + def type(self) -> OpSharding_Type: ... + @type.setter + def type(self, arg: OpSharding_Type, /) -> None: ... + @property + def replicate_on_last_tile_dim(self) -> bool: ... + @replicate_on_last_tile_dim.setter + def replicate_on_last_tile_dim(self, arg: bool, /) -> None: ... + @property + def is_shard_group(self) -> bool: ... + @is_shard_group.setter + def is_shard_group(self, arg: bool, /) -> None: ... + @property + def shard_group_id(self) -> int: ... + @shard_group_id.setter + def shard_group_id(self, arg: int, /) -> None: ... + @property + def shard_group_type(self) -> OpSharding_ShardGroupType: ... + @shard_group_type.setter + def shard_group_type(self, arg: OpSharding_ShardGroupType, /) -> None: ... + def __repr__(self) -> str: ... + def ParseFromString(self, arg: bytes, /) -> None: ... + def SerializeToString(self) -> bytes: ... + def clone(self) -> OpSharding: ... + @property + def tile_assignment_dimensions(self) -> list[int]: ... + @tile_assignment_dimensions.setter + def tile_assignment_dimensions(self, arg: Sequence[int], /) -> None: ... + @property + def tile_assignment_devices(self) -> list[int]: ... + @tile_assignment_devices.setter + def tile_assignment_devices(self, arg: Sequence[int], /) -> None: ... + @property + def iota_reshape_dims(self) -> list[int]: ... + @iota_reshape_dims.setter + def iota_reshape_dims(self, arg: Sequence[int], /) -> None: ... + @property + def iota_transpose_perm(self) -> list[int]: ... + @iota_transpose_perm.setter + def iota_transpose_perm(self, arg: Sequence[int], /) -> None: ... + @property + def tuple_shardings(self) -> list[OpSharding]: ... + @tuple_shardings.setter + def tuple_shardings(self, arg: Sequence[OpSharding], /) -> None: ... + @property + def last_tile_dims(self) -> list[int]: ... + @last_tile_dims.setter + def last_tile_dims(self, arg: Sequence[int], /) -> None: ... + +class HloSharding: + @staticmethod + def from_proto(arg: OpSharding, /) -> HloSharding: ... + @staticmethod + def from_string(arg: str, /) -> HloSharding: ... + @staticmethod + def tuple_sharding( + arg0: Shape, arg1: Sequence[HloSharding], / + ) -> HloSharding: + """Constructs a tuple sharding.""" + + @staticmethod + def iota_tile( + dims: Sequence[int], + reshape_dims: Sequence[int] = ..., + transpose_perm: Sequence[int] = ..., + subgroup_types: Sequence[OpSharding_Type] = ..., + ) -> HloSharding: ... + @staticmethod + def manual() -> HloSharding: ... + @staticmethod + def replicate() -> HloSharding: ... + @staticmethod + def unreduced() -> HloSharding: ... + @staticmethod + def unknown() -> HloSharding: ... + @staticmethod + def subgroup_with_device_ordering( + tile_assignment: Annotated[NDArray[numpy.int64], dict(order='C')], + subgroup_types: Sequence[OpSharding_Type] = ..., + ) -> HloSharding: ... + def __eq__(self, other: object, /) -> bool: ... + def __ne__(self, other: object, /) -> bool: ... + def __hash__(self) -> int: ... + def is_replicated(self) -> bool: ... + def is_manual(self) -> bool: ... + def is_unreduced(self) -> bool: ... + def is_unknown(self) -> bool: ... + def is_tiled(self) -> bool: ... + def is_maximal(self) -> bool: ... + def tile(self, arg: Shape, /) -> Shape: ... + def tuple_elements(self) -> list[HloSharding]: ... + def num_devices(self) -> int: ... + def num_dimensions(self) -> int: ... + def is_tile_assignment_iota(self) -> bool: ... + def tile_assignment_dimensions(self) -> Sequence[int]: ... + def tile_assignment_devices(self) -> Sequence[int]: ... + def replicate_on_last_tile_dim(self) -> bool: ... + def subgroup_types(self) -> list[OpSharding_Type]: ... + def __repr__(self) -> str: ... + def to_proto(self) -> OpSharding: ... + def get_axis_sizes(self) -> list[int]: ... + +class Device: + """A descriptor of an available device. + + Subclasses are used to represent specific types of devices, e.g. CPUs, GPUs. + Subclasses may have additional properties specific to that device type. + """ + + @property + def id(self) -> int: + """Integer ID of this device. + + Unique across all available devices of this type, including remote devices + on multi-host platforms. + """ + + @property + def process_index(self) -> int: + """Integer index of this device's process. + + This is always 0 except on multi-process platforms. + """ + + @property + def host_id(self) -> int: + """Deprecated; please use process_index""" + + @property + def task_id(self) -> int: + """Deprecated; please use process_index""" + + @property + def platform(self) -> str: ... + @property + def device_kind(self) -> str: ... + @property + def client(self) -> Client: ... + @property + def local_hardware_id(self) -> int | None: + """Opaque hardware ID, e.g., the CUDA device number. + + In general, not guaranteed to be dense, and not guaranteed to be defined on + all platforms. + """ + + def __str__(self) -> str: ... + def __repr__(self) -> str: ... + def memory(self, kind: str) -> Memory: ... + def default_memory(self) -> Memory: + """Returns the default memory of a device.""" + + def addressable_memories(self) -> list[Memory]: + """Returns all the memories that a device can address.""" + + def live_buffers(self) -> list: ... + def memory_stats(self) -> dict[str, int] | None: + """Returns memory statistics for this device keyed by name. + + May not be implemented on all platforms, and different platforms may return + different stats, or -1 for unavailable stats. 'bytes_in_use' is usually + available. Intended for diagnostic use. + """ + + def get_stream_for_external_ready_events(self) -> int: ... + + __getattr__: types.MethodDescriptorType = ... + +class Memory: + @property + def process_index(self) -> int: ... + @property + def platform(self) -> str: ... + @property + def kind(self) -> str: ... + def __str__(self) -> str: ... + def __repr__(self) -> str: ... + def addressable_by_devices(self) -> list[Device]: + """Returns devices that can address this memory.""" + +class HostBufferSemantics(enum.Enum): + IMMUTABLE_ONLY_DURING_CALL = 0 + + IMMUTABLE_UNTIL_TRANSFER_COMPLETES = 1 + + ZERO_COPY = 2 + +class Client: + @property + def platform(self) -> str: ... + @property + def _raw_platform(self) -> str: ... + @property + def platform_version(self) -> str: ... + @property + def runtime_type(self) -> str: ... + def device_count(self) -> int: ... + def local_device_count(self) -> int: ... + def devices(self) -> list[Device]: ... + def local_devices(self) -> list[Device]: ... + def _get_all_devices(self) -> list[Device]: ... + def device_from_local_hardware_id(self, arg: int, /) -> Device: ... + def live_executables(self) -> list[LoadedExecutable]: ... + def live_arrays(self) -> list[Array]: ... + def live_buffers(self) -> list[Array]: ... + def process_index(self) -> int: ... + def host_id(self) -> int: ... + def task_id(self) -> int: ... + def buffer_from_pyval( + self, + argument: object, + device: Device | None = ..., + force_copy: bool = ..., + host_buffer_semantics: HostBufferSemantics = HostBufferSemantics.ZERO_COPY, + ) -> object: ... + def compile( + self, + computation: object, + executable_devices: DeviceList, + compile_options: CompileOptions = ..., + ) -> Executable: ... + @overload + def compile_and_load( + self, + computation: object, + executable_devices: DeviceList, + compile_options: CompileOptions = ..., + host_callbacks: Sequence[typing_extensions.CapsuleType] = ..., + ) -> LoadedExecutable: ... + @overload + def compile_and_load( + self, + computation: object, + executable_devices: DeviceList, + compile_options: CompileOptions = ..., + host_callbacks: Sequence[Callable[..., Any]] = ..., + ) -> LoadedExecutable: ... + @overload + def compile_and_load( + self, + computation: bytes, + executable_devices: Sequence, + compile_options: CompileOptions = ..., + ) -> LoadedExecutable: ... + @overload + def compile_and_load( + self, + computation: str, + executable_devices: Sequence, + compile_options: CompileOptions = ..., + ) -> LoadedExecutable: ... + def compile_ifrt_program( + self, arg0: ifrt_programs.Program, arg1: ifrt_programs.CompileOptions, / + ) -> LoadedExecutable: ... + def compile_and_load_ifrt_program( + self, arg0: ifrt_programs.Program, arg1: ifrt_programs.CompileOptions, / + ) -> LoadedExecutable: ... + def serialize_executable(self, arg: LoadedExecutable, /) -> bytes: ... + @overload + def deserialize_executable( + self, + serialized: bytes, + executable_devices: DeviceList, + compile_options: CompileOptions | None = ..., + host_callbacks: Sequence[typing_extensions.CapsuleType] = ..., + ) -> LoadedExecutable: ... + @overload + def deserialize_executable( + self, + serialized: bytes, + executable_devices: DeviceList, + compile_options: CompileOptions | None = ..., + host_callbacks: Sequence[Callable] = ..., + ) -> LoadedExecutable: ... + @overload + def deserialize_executable( + self, + serialized: bytes, + executable_devices: Sequence, + compile_options: CompileOptions | None = ..., + ) -> LoadedExecutable: ... + def heap_profile(self) -> bytes: ... + def defragment(self) -> None: ... + def make_python_callback_from_host_send_and_recv( + self, + callable: Callable, + operand_shapes: Sequence[Shape], + result_shapes: Sequence[Shape], + send_channel_ids: Sequence[int], + recv_channel_ids: Sequence[int], + serializer: Callable | None = ..., + ) -> object: ... + def get_default_layout( + self, dtype: numpy.dtype, shard_shape: Sequence, device: Device + ) -> PjRtLayout: ... + def __getattr__(self, arg: str, /) -> object: ... + +class ArrayCopySemantics(enum.IntEnum): + ALWAYS_COPY = 0 + + REUSE_INPUT = 1 + + DONATE_INPUT = 2 + +class PjRtLayout: + def __str__(self) -> str: ... + def __eq__(self, arg: object, /) -> bool: ... + def __hash__(self) -> int: ... + def _xla_layout(self) -> Layout: ... + def __getstate__(self) -> tuple: ... + def __setstate__(self, arg: tuple, /) -> None: ... + +class CpuCollectives: + def Init(self) -> None: ... + def Finalize(self) -> None: ... + +def make_gloo_tcp_collectives( + distributed_client: DistributedRuntimeClient, + hostname: str | None = ..., + interface: str | None = ..., +) -> CpuCollectives: ... +def make_mpi_collectives() -> CpuCollectives: ... +def get_tfrt_cpu_client( + asynchronous: bool = ..., + distributed_client: DistributedRuntimeClient | None = ..., + node_id: int = ..., + num_nodes: int = ..., + collectives: CpuCollectives | None = ..., + num_devices: int | None = ..., + get_local_topology_timeout_minutes: int | None = ..., + get_global_topology_timeout_minutes: int | None = ..., + transfer_server_factory: TransferServerInterfaceFactory | None = ..., +) -> Client: ... +def pjrt_plugin_loaded(arg: str, /) -> bool: ... +def load_pjrt_plugin( + platform_name: str, + library_path: str | None = ..., + c_api: typing_extensions.CapsuleType | None = ..., +) -> typing_extensions.CapsuleType: ... +def pjrt_plugin_initialized(arg: str, /) -> bool: ... +def initialize_pjrt_plugin(arg: str, /) -> None: ... +def get_c_api_client( + platform_name: str, + options: Mapping[str, str | bool | int | Sequence[int] | float] = ..., + distributed_client: DistributedRuntimeClient | None = ..., + transfer_server_factory: TransferServerInterfaceFactory | None = ..., + force_dcn_cross_host_transfers: bool = ..., +) -> Client: ... +def get_default_c_api_topology( + arg0: str, + arg1: str, + arg2: Mapping[str, str | bool | int | Sequence[int] | float], + /, +) -> DeviceTopology: ... +def get_c_api_topology( + arg0: typing_extensions.CapsuleType, + arg1: str, + arg2: Mapping[str, str | bool | int | Sequence[int] | float], + /, +) -> DeviceTopology: ... +def get_topology_for_devices(arg: Sequence[Device], /) -> DeviceTopology: ... + +class ArrayMeta(type): + def __instancecheck__(self, x: object | None) -> bool: ... + +Array: Any + +def set_tracer_class(arg: object, /) -> None: ... + +ArrayImpl: Any + +def batched_copy_array_to_devices_with_sharding( + arg0: Sequence[Array], + arg1: Sequence[DeviceList], + arg2: Sequence[object], + arg3: Sequence[ArrayCopySemantics], + /, +) -> list[Array]: ... +def array_result_handler( + aval: object, sharding: object, committed: bool, _skip_checks: bool = ... +) -> ResultHandler: ... + +class ResultHandler: + def __call__(self, arg: Array | Sequence[Array], /) -> Array: ... + def wrap(self, arg: Callable, /) -> ResultHandler: ... + def pre_wrap(self, arg: Callable, /) -> ResultHandler: ... + +class DeviceList: + def __init__(self, arg: tuple[Device, ...], /) -> None: ... + def __hash__(self) -> int: ... + def __eq__(self, arg: object, /) -> bool: ... + def __ne__(self, arg: object, /) -> bool: ... + def __len__(self) -> int: ... + @overload + def __getitem__(self, index: int, /) -> Device: ... + @overload + def __getitem__(self, slice: slice, /) -> Sequence[Device]: ... + def __iter__(self) -> Iterator: ... + def __str__(self) -> str: ... + def __repr__(self) -> str: ... + def __getstate__(self) -> tuple: ... + def __setstate__(self, arg: tuple, /) -> None: ... + @property + def is_fully_addressable(self) -> bool: ... + @property + def addressable_device_list(self) -> DeviceList: ... + @property + def process_indices(self) -> set[int]: ... + @property + def default_memory_kind(self) -> str | None: ... + @property + def memory_kinds(self) -> tuple[str, ...]: ... + @property + def device_kind(self) -> str: ... + +class Sharding: + def __init__(self) -> None: ... + +class NamedSharding(Sharding): + def __init__( + self, + mesh: object, + spec: PartitionSpec, + memory_kind: object | None = ..., + _logical_device_ids: object | None = ..., + ) -> None: ... + @property + def mesh(self) -> object: ... + @property + def spec(self) -> PartitionSpec: ... + @property + def _memory_kind(self) -> object: ... + @property + def _logical_device_ids(self) -> object: ... + @property + def _internal_device_list(self) -> DeviceList: ... + def __eq__(self, arg: object) -> bool: ... + def __hash__(self) -> int: ... + +class SingleDeviceSharding(Sharding): + def __init__( + self, device: object, memory_kind: object | None = ... + ) -> None: ... + @property + def _device(self) -> object: ... + @property + def _memory_kind(self) -> object: ... + @property + def _internal_device_list(self) -> DeviceList: ... + +class PmapSharding(Sharding): + def __init__( + self, devices: object, sharding_spec: pmap_lib.ShardingSpec + ) -> None: ... + @property + def devices(self) -> numpy.ndarray: ... + @property + def sharding_spec(self) -> pmap_lib.ShardingSpec: ... + @property + def _internal_device_list(self) -> DeviceList: ... + +class GSPMDSharding(Sharding): + @overload + def __init__( + self, + devices: DeviceList, + op_sharding: OpSharding, + memory_kind: object | None = ..., + ) -> None: ... + @overload + def __init__( + self, + devices: DeviceList, + op_sharding: HloSharding, + memory_kind: object | None = ..., + ) -> None: ... + @overload + def __init__( + self, + devices: Sequence[Device], + op_sharding: OpSharding, + memory_kind: object | None = ..., + ) -> None: ... + @overload + def __init__( + self, + devices: Sequence[Device], + op_sharding: HloSharding, + memory_kind: object | None = ..., + ) -> None: ... + @property + def _devices(self) -> DeviceList: ... + @property + def _hlo_sharding(self) -> HloSharding: ... + @property + def _memory_kind(self) -> object: ... + @property + def _internal_device_list(self) -> DeviceList: ... + +class CompiledMemoryStats: + @property + def generated_code_size_in_bytes(self) -> int: ... + @generated_code_size_in_bytes.setter + def generated_code_size_in_bytes(self, arg: int, /) -> None: ... + @property + def argument_size_in_bytes(self) -> int: ... + @argument_size_in_bytes.setter + def argument_size_in_bytes(self, arg: int, /) -> None: ... + @property + def output_size_in_bytes(self) -> int: ... + @output_size_in_bytes.setter + def output_size_in_bytes(self, arg: int, /) -> None: ... + @property + def alias_size_in_bytes(self) -> int: ... + @alias_size_in_bytes.setter + def alias_size_in_bytes(self, arg: int, /) -> None: ... + @property + def temp_size_in_bytes(self) -> int: ... + @temp_size_in_bytes.setter + def temp_size_in_bytes(self, arg: int, /) -> None: ... + @property + def host_generated_code_size_in_bytes(self) -> int: ... + @host_generated_code_size_in_bytes.setter + def host_generated_code_size_in_bytes(self, arg: int, /) -> None: ... + @property + def host_argument_size_in_bytes(self) -> int: ... + @host_argument_size_in_bytes.setter + def host_argument_size_in_bytes(self, arg: int, /) -> None: ... + @property + def host_output_size_in_bytes(self) -> int: ... + @host_output_size_in_bytes.setter + def host_output_size_in_bytes(self, arg: int, /) -> None: ... + @property + def host_alias_size_in_bytes(self) -> int: ... + @host_alias_size_in_bytes.setter + def host_alias_size_in_bytes(self, arg: int, /) -> None: ... + @property + def host_temp_size_in_bytes(self) -> int: ... + @host_temp_size_in_bytes.setter + def host_temp_size_in_bytes(self, arg: int, /) -> None: ... + @property + def serialized_buffer_assignment_proto(self) -> bytes: ... + @property + def peak_memory_in_bytes(self) -> int: ... + @peak_memory_in_bytes.setter + def peak_memory_in_bytes(self, arg: int, /) -> None: ... + def __str__(self) -> str: ... + +def get_execution_stream_id() -> int: ... +def set_execution_stream_id(arg: int, /) -> None: ... + +class LoadedExecutable: + @property + def client(self) -> Client: ... + def local_devices(self) -> list[Device]: ... + def get_hlo_text(self) -> str: ... + def serialize(self) -> bytes: ... + def size_of_generated_code_in_bytes(self) -> int: ... + def get_compiled_memory_stats(self) -> CompiledMemoryStats: ... + def execute_sharded( + self, arguments: Sequence[Array], with_tokens: bool = ... + ) -> ExecuteResults: ... + def hlo_modules(self) -> list[HloModule]: ... + def get_output_memory_kinds(self) -> list[list[str]]: ... + def get_output_shardings(self) -> list[OpSharding] | None: ... + def get_parameter_layouts(self) -> list[PjRtLayout]: ... + def get_output_layouts(self) -> list[PjRtLayout]: ... + def get_parameter_shardings(self) -> list[OpSharding] | None: ... + def keep_alive(self, arg: object, /) -> None: ... + def cost_analysis( + self, + ) -> dict[str, str | bool | int | list[int] | float]: ... + @property + def traceback(self) -> Traceback | None: ... + @property + def fingerprint(self) -> object: ... + +class ExecuteResults: + def __len__(self) -> int: ... + def disassemble_into_single_device_arrays(self) -> list[list[Array]]: ... + def disassemble_prefix_into_single_device_arrays( + self, arg: int, / + ) -> list[list[Array]]: ... + def consume_with_handlers( + self, out_handlers: Sequence[ResultHandler | object], strict: bool = ... + ) -> list[object]: ... + def consume_token(self) -> ShardedToken: ... + +class Token: + def block_until_ready(self) -> None: ... + +class ShardedToken: + def block_until_ready(self) -> None: ... + def get_token(self, arg: int, /) -> Token: ... + +class Executable: + def hlo_modules(self) -> list[HloModule]: ... + def get_output_memory_kinds(self) -> list[list[str]]: ... + def get_output_shardings(self) -> list[OpSharding] | None: ... + def get_parameter_layouts(self) -> list[PjRtLayout]: ... + def get_output_layouts(self) -> list[PjRtLayout]: ... + def get_parameter_shardings(self) -> list[OpSharding] | None: ... + def get_compiled_memory_stats(self) -> CompiledMemoryStats: ... + def serialize(self) -> bytes: ... + def cost_analysis( + self, + ) -> dict[str, str | bool | int | list[int] | float]: ... + +def buffer_to_dlpack_managed_tensor( + buffer: object, stream: int | None = ... +) -> typing_extensions.CapsuleType: ... +def dlpack_managed_tensor_to_buffer( + dlpack: typing_extensions.CapsuleType, + device: Device, + stream: int | None, + copy: bool | None = ..., +) -> ArrayImpl: ... +def cuda_array_interface_to_buffer( + cai: dict, gpu_backend: Client | None = ..., device_id: int | None = ... +) -> object: ... + +class RuntimeTracebackMode(enum.Enum): + OFF = 0 + + ON = 1 + + FULL = 2 + +def add_exclude_path(arg: str, /) -> None: + """Adds a path to exclude from tracebacks.""" + +def set_send_traceback_to_runtime_global( + arg: RuntimeTracebackMode, / +) -> None: ... +def set_send_traceback_to_runtime_thread_local( + mode: RuntimeTracebackMode | None, +) -> None: ... + +class PjitFunctionCache: + def __init__(self, capacity: int = ...) -> None: ... + def size(self) -> int: ... + def capacity(self) -> int: ... + def clear(self) -> None: ... + @staticmethod + def clear_all() -> None: ... + def __getstate__(self) -> dict: ... + def __setstate__(self, arg: dict, /) -> None: ... + +class PjitFunction: + def __repr__(self, /): + """Return repr(self).""" + + def __call__(self, /, *args, **kwargs): + """Call self as a function.""" + + def __get__(self, instance, owner=..., /): + """Return an attribute of instance, which is of type owner.""" + __vectorcalloffset__: types.MemberDescriptorType = ... + + def __getstate__(self) -> dict: ... + def __setstate__(self, arg: dict, /) -> None: ... + @property + def __signature__(self) -> inspect.Signature: ... + @property + def _cache_miss(self) -> Callable: ... + def _cache_size(self) -> int: ... + def _clear_cache(self) -> None: ... + +def pjit( + function_name: str, + fun: Callable[..., Any] | None, + cache_miss: Callable[..., Any], + static_argnums: Sequence[int], + static_argnames: Sequence[str], + global_cache_key: Any, + pytree_registry: _PyTreeRegistry, + shard_arg_fallback: Callable[..., Any], + cache: PjitFunctionCache | None = ..., +) -> PjitFunction: ... + +class Frame: + def __init__(self, arg0: str, arg1: str, arg2: int, arg3: int, /) -> None: ... + @property + def file_name(self) -> str: ... + @property + def function_name(self) -> str: ... + @property + def function_start_line(self) -> int: ... + @property + def line_num(self) -> int: ... + def __repr__(self) -> str: ... + +class Traceback: + def __hash__(self, /): + """Return hash(self).""" + + def __str__(self, /): + """Return str(self).""" + + def __lt__(self, value, /): + """Return selfvalue.""" + + def __ge__(self, value, /): + """Return self>=value.""" + + @staticmethod + def get_traceback() -> Traceback | None: + """Returns a :class:`Traceback` for the current thread. + + If ``Traceback.enabled`` is ``True``, returns a :class:`Traceback` + object that describes the Python stack of the calling thread. Stack + trace collection has a small overhead, so it is disabled by default. If + traceback collection is disabled, returns ``None``. + """ + + @property + def frames(self) -> list[Frame]: ... + def raw_frames(self) -> tuple[list[types.CodeType], list[int]]: ... + def as_python_traceback(self) -> traceback.TracebackType: ... + @staticmethod + def traceback_from_frames(frames: list[Frame]) -> traceback.TracebackType: + """Creates a traceback from a list of frames.""" + + @staticmethod + def code_addr2line(code: types.CodeType, lasti: int) -> int: + """Python wrapper around the Python C API function PyCode_Addr2Line""" + + @staticmethod + def code_addr2location( + code: types.CodeType, lasti: int + ) -> tuple[int, int, int, int]: + """Python wrapper around the Python C API function PyCode_Addr2Location""" + +def tracebacks_enabled() -> bool: ... +def set_tracebacks_enabled(arg: bool, /) -> None: ... +def register_custom_call_partitioner( + name: str, + prop_user_sharding: object, + partition: object, + infer_sharding_from_operands: object, + can_side_effecting_have_replicated_sharding: bool = ..., + c_api: typing_extensions.CapsuleType | None = ..., +) -> None: + """Registers a partitioner for a custom-call operation. + + Args: + name: custom_call_target to match. + prop_user_sharding: Custom backwards sharding propagation rule. Takes result + sharding and returns the instruction sharding. + partition: Lowering rule. Takes operand and result shardings and returns a + generated HLO and sharding specs. The spmd lowerer first reshards to match + the returned sharding specs and then inserts the generated hlo. + infer_sharding_from_operands: Custom forwards sharding propagation rule. + Takes operand sharding and returns the instruction sharding. + can_side_effecting_have_replicated_sharding: Side effecting ops are not + allowed to have replicated sharding. Pass true to disable this check. + c_api: Optional `PJRT_Api*` if it is called with a plugin. This is safe to + call on plugins that do not implement the custom partitioner extension + """ + +def encode_inspect_sharding_callback(arg: object, /) -> bytes: ... +def register_custom_call_as_batch_partitionable( + target_name: str, c_api: typing_extensions.CapsuleType | None = ... +) -> None: + """Registers a custom call as batch partitionable. + + If a custom call is "batch partitionable", it means that it can be trivially + partitioned on some number of (leading) dimensions, with the same call being + executed independently on each shard of data. If the data are sharded on + non-batch dimensions, partitioning will re-shard the data to be replicated on + the non-batch dimensions. + + Args: + target_name: the target name of the batch partitionable custom call. + c_api: optional `PJRT_Api*` to support registration via a PJRT plugin. + """ + +def register_custom_call_target( + fn_name: object, + fn: object, + platform: str, + api_version: int = ..., + traits: int = ..., +) -> None: ... +def custom_call_targets(platform: str) -> dict: ... +def register_custom_type(type_name: str, type_id: object) -> None: ... + +class TransferConnection: + def _testonly_inject_failure(self) -> None: ... + def _poison_connection(self) -> None: ... + def _pull_flat( + self, arg0: int, arg1: Client, arg2: Sequence[object], / + ) -> list[Array]: ... + def _pull_into_flat( + self, arg0: int, arg1: Sequence[Array], arg2: Sequence[slice], / + ) -> list[Token]: ... + +class TransferServer: + def address(self) -> str: ... + def _await_pull_flat(self, arg0: int, arg1: Sequence[Array], /) -> None: ... + def _reset_rendevous_table(self) -> None: ... + def connect(self, arg: str, /) -> TransferConnection: ... + +def _make_error_array(arg0: Client, arg1: object, arg2: str, /) -> Array: ... +def start_transfer_server( + client: Client, + address: str = ..., + transport_addresses: Sequence[str] = ..., + max_num_parallel_copies: int = ..., + transfer_size: int = ..., + supports_pinned_allocator: bool = ..., + use_raw_buffers: bool = ..., +) -> TransferServer: ... +def make_transfer_server_interface_factory( + transfer_size: int = ..., + cross_host_transfer_timeout_seconds: int = ..., + distributed_client: DistributedRuntimeClient | None = ..., + socket_address: str = ..., + transport_addresses: Sequence[str] = ..., +) -> TransferServerInterfaceFactory: ... + +class PreemptionSyncManager: + def initialize( + self, distributed_client: DistributedRuntimeClient + ) -> None: ... + def reached_sync_point(self, arg: int, /) -> bool: ... + def shutdown(self) -> None: ... + +def create_preemption_sync_manager() -> PreemptionSyncManager: ... + +class DistributedRuntimeService: + def shutdown(self) -> None: ... + +class DistributedRuntimeClient: + def connect(self) -> None: ... + def shutdown(self) -> None: ... + def blocking_key_value_get(self, key: str, timeout_in_ms: int) -> str: ... + def blocking_key_value_get_bytes( + self, key: str, timeout_in_ms: int + ) -> bytes: ... + def key_value_try_get(self, key: str) -> str: ... + def key_value_try_get_bytes(self, key: str) -> bytes: ... + def key_value_increment(self, key: str, increment: int) -> int: ... + def wait_at_barrier( + self, + barrier_id: str, + timeout_in_ms: int, + process_ids: Sequence[int] | None = ..., + ) -> None: ... + def get_live_nodes(self, process_ids: Sequence[int]) -> dict[int, int]: ... + def key_value_set( + self, key: str, value: str, allow_overwrite: bool = ... + ) -> None: ... + def key_value_set_bytes( + self, key: str, value: bytes, allow_overwrite: bool = ... + ) -> None: ... + def key_value_dir_get(self, key: str) -> list[tuple[str, str]]: ... + def key_value_dir_get_bytes(self, key: str) -> list[tuple[str, bytes]]: ... + def key_value_delete(self, key: str) -> None: ... + +def get_distributed_runtime_service( + address: str, + num_nodes: int, + heartbeat_timeout: int | None = ..., + cluster_register_timeout: int | None = ..., + shutdown_timeout: int | None = ..., +) -> DistributedRuntimeService: ... +def get_distributed_runtime_client( + address: str, + node_id: int, + rpc_timeout: int | None = ..., + init_timeout: int | None = ..., + shutdown_timeout: int | None = ..., + heartbeat_timeout: int | None = ..., + missed_heartbeat_callback: Callable | None = ..., + shutdown_on_destruction: bool | None = ..., + use_compression: bool | None = ..., + recoverable: bool | None = ..., +) -> DistributedRuntimeClient: ... +def collect_garbage() -> None: ... +def is_optimized_build() -> bool: ... +def json_to_pprof_profile(arg: str, /) -> bytes: + """Encodes the JSON representation of a pprof Profile into its binary protocol buffer encoding.""" + +def pprof_profile_to_json(arg: bytes, /) -> str: + """Decodes an uncompressed pprof Profile protocol buffer into a JSON representation""" + +class CompileOnlyPyClient(Client): + def compile( + self, + computation: object, + executable_devices: DeviceList, + compile_options: CompileOptions = ..., + host_callbacks: Sequence[typing_extensions.CapsuleType] = ..., + ) -> Executable: ... + +class DeviceTopology: + def _make_compile_only_devices(self) -> list[Device]: ... + @property + def platform(self) -> str: ... + @property + def platform_version(self) -> str: ... + def serialize(self) -> bytes: ... + def __getattr__(self, arg: str, /) -> object: ... + +class TransferServerInterfaceFactory: + pass + +def is_asan() -> bool: ... +def is_msan() -> bool: ... +def is_tsan() -> bool: ... +def is_sanitized() -> bool: ... +def batched_device_put( + aval: object, + sharding: object, + xs: Sequence[object], + devices: Sequence[Device], + committed: bool = ..., + force_copy: bool = ..., + host_buffer_semantics: HostBufferSemantics = HostBufferSemantics.ZERO_COPY, + enable_x64: bool | None = ..., +) -> object: ... +def reorder_shards( + x: Array, dst_sharding: object, array_copy_semantics: ArrayCopySemantics +) -> Array: ... +def batched_block_until_ready(arg: Sequence[object], /) -> None: ... +def check_and_canonicalize_memory_kind( + memory_kind: object | None, device_list: DeviceList +) -> object: ... + +ifrt_version_number: int = ... + +def approx_top_k_reduction_output_size( + input_size: int, + rank: int, + top_k: int, + recall_target: float, + aggregate_to_topk: bool = ..., + input_size_override: int = ..., +) -> tuple[int, int]: ... +def get_internal_device_put_info() -> dict[str, int]: ... + +class UnconstrainedSingleton: + def __repr__(self) -> str: ... + def __reduce__(self) -> str: ... + +UNCONSTRAINED_PARTITION: UnconstrainedSingleton = ... + +def canonicalize_partition(arg: object, /) -> object: ... + +class PartitionSpec(Any): + def __init__( + self, *partitions, unreduced: object = ..., reduced: object = ... + ) -> None: ... + @property + def _partitions(self) -> tuple: ... + @property + def unreduced(self) -> frozenset: ... + @property + def reduced(self) -> frozenset: ... + def __eq__(self, arg: object) -> bool: ... + def __hash__(self) -> int: ... + +def set_typed_int_type(arg: object, /) -> None: ... +def set_typed_float_type(arg: object, /) -> None: ... +def set_typed_complex_type(arg: object, /) -> None: ... +def set_typed_ndarray_type(arg: object, /) -> None: ... diff --git a/jaxlib/_jax/config.pyi b/jaxlib/_jax/config.pyi new file mode 100644 index 000000000000..85d3ba4c6bdb --- /dev/null +++ b/jaxlib/_jax/config.pyi @@ -0,0 +1,40 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Generic, TypeVar + +unset: object = ... + +_T = TypeVar("_T") + +class Config(Generic[_T]): + def __init__( + self, + name: str, + value: _T, + *, + include_in_jit_key: bool = ..., + include_in_trace_context: bool = ..., + ) -> None: ... + @property + def value(self) -> _T: ... + @property + def name(self) -> str: ... + def get_local(self) -> Any: ... + def get_global(self) -> _T: ... + def set_local(self, value: Any | None) -> None: ... + def swap_local(self, value: Any | None) -> Any: ... + def set_global(self, value: Any | None) -> None: ... + +def trace_context() -> tuple: ... diff --git a/jaxlib/_jax/ffi.pyi b/jaxlib/_jax/ffi.pyi new file mode 100644 index 000000000000..3804100a2efb --- /dev/null +++ b/jaxlib/_jax/ffi.pyi @@ -0,0 +1,55 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import enum +import numpy +import typing_extensions + +class Buffer: + @property + def dtype(self) -> numpy.dtype: ... + @property + def ndim(self) -> int: ... + @property + def shape(self) -> tuple: ... + @property + def writeable(self) -> bool: ... + def __array__( + self, dtype: object | None = ..., copy: object | None = ... + ) -> numpy.ndarray: ... + @property + def __cuda_array_interface__(self) -> dict: ... + def __dlpack__( + self, + stream: object | None = ..., + max_version: object | None = ..., + dl_device: object | None = ..., + copy: object | None = ..., + ) -> typing_extensions.CapsuleType: ... + def __dlpack_device__(self) -> tuple: ... + +class ExecutionStage(enum.Enum): + INSTANTIATE = 0 + + PREPARE = 1 + + INITIALIZE = 2 + + EXECUTE = 3 + +class ExecutionContext: + @property + def stage(self) -> ExecutionStage: ... + @property + def stream(self) -> int: ... diff --git a/jaxlib/_jax/guard_lib.pyi b/jaxlib/_jax/guard_lib.pyi new file mode 100644 index 000000000000..e14ce8a14155 --- /dev/null +++ b/jaxlib/_jax/guard_lib.pyi @@ -0,0 +1,65 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import enum + +class TransferGuardLevel(enum.Enum): + ALLOW = 0 + + LOG = 1 + + DISALLOW = 2 + + LOG_EXPLICIT = 3 + + DISALLOW_EXPLICIT = 4 + +class GarbageCollectionGuardLevel(enum.Enum): + ALLOW = 0 + + LOG = 1 + + FATAL = 2 + +class GuardState: + @property + def host_to_device(self) -> TransferGuardLevel | None: ... + @host_to_device.setter + def host_to_device(self, arg: TransferGuardLevel | None) -> None: ... + @property + def device_to_device(self) -> TransferGuardLevel | None: ... + @device_to_device.setter + def device_to_device(self, arg: TransferGuardLevel | None) -> None: ... + @property + def device_to_host(self) -> TransferGuardLevel | None: ... + @device_to_host.setter + def device_to_host(self, arg: TransferGuardLevel | None) -> None: ... + @property + def explicit_device_put(self) -> bool: ... + @explicit_device_put.setter + def explicit_device_put(self, arg: bool, /) -> None: ... + @property + def explicit_device_get(self) -> bool: ... + @explicit_device_get.setter + def explicit_device_get(self, arg: bool, /) -> None: ... + @property + def garbage_collect_array(self) -> GarbageCollectionGuardLevel | None: ... + @garbage_collect_array.setter + def garbage_collect_array( + self, arg: GarbageCollectionGuardLevel | None + ) -> None: ... + +def global_state() -> GuardState: ... +def thread_local_state() -> GuardState: ... +def update_thread_guard_global_state(arg: bool) -> None: ... diff --git a/jaxlib/_jax/hlo_sharding_util.pyi b/jaxlib/_jax/hlo_sharding_util.pyi new file mode 100644 index 000000000000..bad23a1a736b --- /dev/null +++ b/jaxlib/_jax/hlo_sharding_util.pyi @@ -0,0 +1,21 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Sequence + +from .import HloSharding as _HloSharding + +def PartiallyReplicateTiledShardingOnDims( + sharding: _HloSharding, dims: Sequence[int], / +) -> _HloSharding: ... diff --git a/jaxlib/_jax/ifrt_programs.pyi b/jaxlib/_jax/ifrt_programs.pyi new file mode 100644 index 000000000000..ee74bc4bf877 --- /dev/null +++ b/jaxlib/_jax/ifrt_programs.pyi @@ -0,0 +1,52 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Sequence +from typing import Any, overload + +from .import ( + CompileOptions as _CompileOptions, + DeviceList as _DeviceList, + Device as _Device, +) +import typing_extensions + +class Program: + pass + +class CompileOptions: + pass + +@overload +def make_hlo_program(mlir_module: str) -> Program: ... +@overload +def make_hlo_program(mlir_module: bytes) -> Program: ... +def make_colocated_python_program( + name: str, + picked_function: bytes, + devices: Sequence[_Device] | _DeviceList, + input_avals: Sequence[Any], + output_avals: Sequence[Any], +) -> Program: ... +@overload +def make_plugin_program(data: str) -> Program: ... +@overload +def make_plugin_program(data: bytes) -> Program: ... +def make_xla_compile_options( + options: _CompileOptions, + executable_devices: Sequence[_Device], + host_callbacks: Sequence[typing_extensions.CapsuleType], +) -> CompileOptions: ... +def make_colocated_python_compile_options() -> CompileOptions: ... +def make_plugin_compile_options() -> CompileOptions: ... diff --git a/jaxlib/_jax/jax_jit.pyi b/jaxlib/_jax/jax_jit.pyi new file mode 100644 index 000000000000..07e02d4731c8 --- /dev/null +++ b/jaxlib/_jax/jax_jit.pyi @@ -0,0 +1,74 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable, Sequence +from .config import Config as _Config +from .pytree import ( + PyTreeDef as _PyTreeDef, + PyTreeRegistry as _PyTreeRegistry, +) +import numpy + +def set_disable_jit_state(config: _Config) -> None: ... +def set_enable_x64_state(config: _Config) -> None: ... +def set_post_hook_state(config: _Config) -> None: ... +def set_thread_local_state_initialization_callback( + f: Callable[[], None], +) -> None: ... + +class PyArgSignature: + @property + def dtype(self) -> numpy.dtype: ... + @property + def shape(self) -> tuple[int, ...]: ... + @property + def weak_type(self) -> bool: ... + +def _ArgSignatureOfValue(arg0: object, arg1: bool, /) -> PyArgSignature: ... + +class ArgumentSignature: + @property + def static_args(self) -> list[object]: ... + @property + def static_arg_names(self) -> list[str]: ... + @property + def dynamic_arg_names(self) -> list[str]: ... + @property + def dynamic_arg_treedefs(self) -> Sequence[_PyTreeDef]: ... + def __repr__(self) -> str: ... + def __str__(self) -> str: ... + def __hash__(self) -> int: ... + def __eq__(self, arg: object, /) -> bool: ... + def __ne__(self, arg: object, /) -> bool: ... + +def parse_arguments( + positional_args: Sequence[object], + keyword_args: Sequence[object], + kwnames: tuple[str, ...], + static_argnums: Sequence[int], + static_argnames: Sequence[str], + pytree_registry: _PyTreeRegistry, +) -> tuple[ArgumentSignature, list[object]]: + """Parses the arguments to a function as jax.jit would. + + Returns a ArgumentSignature and the flattened dynamic arguments. + + Args: + positional_args: The positional arguments. + keyword_args: The keyword arguments. + kwnames: The keyword names. + static_argnums: The static argument numbers. + static_argnames: The static argument names. + pytree_registry: The pytree registry. + """ diff --git a/jaxlib/_jax/mlir.pyi b/jaxlib/_jax/mlir.pyi new file mode 100644 index 000000000000..eaf0166a686b --- /dev/null +++ b/jaxlib/_jax/mlir.pyi @@ -0,0 +1,54 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import overload + +from .import XlaComputation as _XlaComputation + +def hlo_to_stablehlo(computation: bytes) -> bytes: ... +def xla_computation_to_mlir_module(computation: _XlaComputation) -> str: ... +@overload +def mlir_module_to_xla_computation( + mlir_module: bytes, use_tuple_args: bool = ..., return_tuple: bool = ... +) -> _XlaComputation: ... +@overload +def mlir_module_to_xla_computation( + mlir_module: str, use_tuple_args: bool = ..., return_tuple: bool = ... +) -> _XlaComputation: ... +@overload +def mhlo_to_stablehlo(mlir_module: bytes) -> bytes: ... +@overload +def mhlo_to_stablehlo(mlir_module: str) -> bytes: ... +@overload +def serialize_portable_artifact( + mlir_module: bytes, target: str, use_mixed_serialization: bool = ... +) -> bytes: ... +@overload +def serialize_portable_artifact( + mlir_module: str, target: str, use_mixed_serialization: bool = ... +) -> bytes: ... +def deserialize_portable_artifact(mlir_module: bytes) -> str: ... +def refine_polymorphic_shapes( + mlir_module: bytes, + enable_shape_assertions: bool = ..., + validate_static_shapes: bool = ..., + enable_shardy: bool = ..., +) -> bytes: + """Refines the dynamic shapes for a module. + + The "main" function must have static shapes and all the + intermediate dynamic shapes depend only on the input static + shapes. Optionally, also validates that the resulting module has + only static shapes. + """ diff --git a/jaxlib/_jax/pmap_lib.pyi b/jaxlib/_jax/pmap_lib.pyi new file mode 100644 index 000000000000..222d29299b44 --- /dev/null +++ b/jaxlib/_jax/pmap_lib.pyi @@ -0,0 +1,102 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable, Iterable, Sequence +import inspect +import types +from typing import Any + +from .pytree import PyTreeRegistry as _PyTreeRegistry + +class NoSharding: + def __init__(self) -> None: ... + def __getstate__(self) -> tuple: ... + def __setstate__(self, arg: tuple, /) -> None: ... + def __repr__(self) -> str: ... + def __eq__(self, arg: object, /) -> bool: ... + def __hash__(self) -> int: ... + +class Chunked: + def __init__(self, arg: Sequence[int], /) -> None: ... + def __getstate__(self) -> tuple: ... + def __setstate__(self, arg: tuple, /) -> None: ... + @property + def chunks(self) -> list[int]: ... + def __repr__(self) -> str: ... + def __eq__(self, arg: object, /) -> bool: ... + +class Unstacked: + def __init__(self, arg: int, /) -> None: ... + def __getstate__(self) -> tuple: ... + def __setstate__(self, arg: tuple, /) -> None: ... + @property + def size(self) -> int: ... + def __repr__(self) -> str: ... + def __eq__(self, arg: object, /) -> bool: ... + +class ShardedAxis: + def __init__(self, arg: int, /) -> None: ... + def __getstate__(self) -> tuple: ... + def __setstate__(self, arg: tuple, /) -> None: ... + @property + def axis(self) -> int: ... + def __repr__(self) -> str: ... + def __eq__(self, arg: object, /) -> bool: ... + +class Replicated: + def __init__(self, arg: int, /) -> None: ... + def __getstate__(self) -> tuple: ... + def __setstate__(self, arg: tuple, /) -> None: ... + @property + def replicas(self) -> int: ... + def __repr__(self) -> str: ... + def __eq__(self, arg: object, /) -> bool: ... + +class ShardingSpec(Any): + def __init__(self, sharding: Iterable, mesh_mapping: Iterable) -> None: ... + def __getstate__(self) -> tuple: ... + def __setstate__(self, arg: tuple, /) -> None: ... + @property + def sharding(self) -> tuple[NoSharding | Chunked | Unstacked, ...]: ... + @property + def mesh_mapping(self) -> tuple[ShardedAxis | Replicated, ...]: ... + def __eq__(self, arg: object, /) -> bool: ... + def __hash__(self) -> int: ... + +class PmapFunction: + def __call__(self, /, *args, **kwargs): + """Call self as a function.""" + + def __get__(self, instance, owner=..., /): + """Return an attribute of instance, which is of type owner.""" + __vectorcalloffset__: types.MemberDescriptorType = ... + + @property + def __signature__(self) -> inspect.Signature: ... + @property + def _cache_miss(self) -> Callable: ... + def __getstate__(self) -> dict: ... + def __setstate__(self, arg: dict, /) -> None: ... + @property + def _cache_size(self) -> int: ... + def _cache_clear(self) -> None: ... + def _debug_cache_keys(self) -> str: ... + +def pmap( + fun: Callable[..., Any], + cache_miss: Callable[..., Any], + static_argnums: Sequence[int], + shard_arg_fallback: Callable[..., Any], + pytree_registry: _PyTreeRegistry, +) -> PmapFunction: ... diff --git a/jaxlib/_jax/profiler.pyi b/jaxlib/_jax/profiler.pyi new file mode 100644 index 000000000000..a2fcc67fbcb7 --- /dev/null +++ b/jaxlib/_jax/profiler.pyi @@ -0,0 +1,59 @@ +# Copyright 2021 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from types import TracebackType +from typing import Any + +_Status = Any + +class ProfilerServer: ... +def start_server(port: int) -> ProfilerServer: ... + +def register_plugin_profiler(c_api: Any) -> None: ... + +def get_profiled_instructions_proto(tensorboard_dir: str) -> bytes: ... +def get_instructins_profile(tensorboard_dir: str) -> list[tuple[str, float]]: ... +def get_fdo_profile( + xspace: bytes, as_textproto: bool = ... +) -> bytes | str: ... + +class ProfilerSession: + def __init__(self, options: ProfileOptions | None = ...) -> None: ... + def stop(self) -> bytes: ... + def export(self, xspace: bytes, tensorboard_dir: str) -> _Status:... + +class ProfileOptions: + include_dataset_ops: bool + host_tracer_level: int + python_tracer_level: int + enable_hlo_proto: bool + start_timestamp_ns: int + duration_ms: int + repository_path: str + raise_error_on_start_failure: bool + +def aggregate_profiled_instructions(profiles: list[bytes], percentile: int) -> str: ... + +class TraceMe: + def __init__(self, name: str, **kwargs: Any) -> None: ... + def __enter__(self) -> TraceMe: ... + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + exc_tb: TracebackType | None) -> bool | None:... + def set_metadata(self, **kwargs): ... + @staticmethod + def is_enabled() -> bool: ... diff --git a/jaxlib/_jax/pytree.pyi b/jaxlib/_jax/pytree.pyi new file mode 100644 index 000000000000..02332bae96fe --- /dev/null +++ b/jaxlib/_jax/pytree.pyi @@ -0,0 +1,184 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable, Hashable, Iterable, Sequence +from typing import Any, TypeVar + +version: int = ... + +_T = TypeVar("_T") + +_Children = TypeVar("_Children", bound=Iterable[Any]) + +_KeyLeafPair = TypeVar("_KeyLeafPair", bound=tuple[Any, Any]) + +_KeyLeafPairs = TypeVar("_KeyLeafPairs", bound=Iterable[tuple[Any, Any]]) + +_KeyPath = TypeVar("_KeyPath", bound=tuple[Any, ...]) + +_AuxData = TypeVar("_AuxData", bound=Hashable) + +class PyTreeRegistry: + def __init__( + self, + enable_none: bool = ..., + enable_tuple: bool = ..., + enable_namedtuple: bool = ..., + enable_list: bool = ..., + enable_dict: bool = ..., + ) -> None: ... + def flatten( + self, + tree: object | None, + leaf_predicate: Callable[[Any], bool] | None = None, + ) -> tuple[list[Any], PyTreeDef]: ... + def flatten_one_level( + self, tree: object | None + ) -> tuple[Iterable[Any], Any] | None: ... + def flatten_one_level_with_keys( + self, tree: object | None + ) -> tuple[Iterable[_KeyLeafPair], Any] | None: ... + def flatten_with_path( + self, + tree: object | None, + leaf_predicate: Callable[[Any, Any], bool] | None = None, + ) -> tuple[list[tuple[_KeyPath, Any]], PyTreeDef]: ... + def register_node( + self, + type: type[_T], + to_iterable: Callable[[_T], tuple[_Children, _AuxData]], + from_iterable: Callable[[_AuxData, _Children], _T], + to_iterable_with_keys: ( + Callable[[_T], tuple[_KeyLeafPairs, _AuxData]] | None + ) = None, + ) -> Any: ... + def register_dataclass_node( + self, + type: type, + data_fields: Sequence[str], + meta_fields: Sequence[str], + /, + ) -> Any: ... + def __reduce__(self) -> str: ... + +_default_registry: PyTreeRegistry = ... + +def default_registry() -> PyTreeRegistry: ... +def treedef_tuple( + registry: PyTreeRegistry, arg0: Sequence[PyTreeDef], / +) -> PyTreeDef: ... +def all_leaves(arg0: PyTreeRegistry, arg1: Iterable, /) -> bool: ... + +class PyTreeDef: + def unflatten(self, arg: Iterable[Any], /) -> Any: ... + def flatten_up_to(self, tree: object | None) -> list: ... + def compose(self, arg: PyTreeDef, /) -> PyTreeDef: ... + def walk( + self, + __f_node: Callable[[Any, Any], Any], + __f_leaf: Callable[[_T], Any] | None, + leaves: Iterable[Any], + /, + ) -> Any: + """Walk pytree, calling f_node(node, node_data) at nodes, and f_leaf at leaves""" + + def from_iterable_tree(self, arg: object, /) -> object: ... + def children(self) -> list[PyTreeDef]: ... + @property + def num_leaves(self) -> int: ... + @property + def num_nodes(self) -> int: ... + def __repr__(self) -> str: ... + def __eq__(self, arg: object, /) -> bool: ... + def __ne__(self, arg: object, /) -> bool: ... + def __hash__(self) -> int: ... + def serialize_using_proto(self) -> bytes: ... + @staticmethod + def deserialize_using_proto( + registry: PyTreeRegistry, data: bytes + ) -> PyTreeDef: ... + def node_data(self) -> tuple[type, Any] | None: + """Returns None if a leaf-pytree, else (type, node_data)""" + + @staticmethod + def from_node_data_and_children( + self, + registry: PyTreeRegistry, + node_data: tuple[type, Any] | None, + children: Iterable[PyTreeDef], + ) -> PyTreeDef: + """Reconstructs a pytree from `node_data()` and `children()`.""" + + def __getstate__(self) -> object: ... + def __setstate__(self, arg: object, /) -> None: ... + +class SequenceKey(Hashable): + def __init__(self, idx: int) -> None: ... + def __str__(self) -> str: ... + def __repr__(self) -> str: ... + def __eq__(self, arg: object, /) -> bool: ... + def __hash__(self) -> int: ... + @property + def idx(self) -> int: ... + + __match_args__: tuple = ... + """(arg: object, /) -> tuple""" + + def __getstate__(self) -> tuple: ... + def __setstate__(self, arg: tuple, /) -> None: ... + +class DictKey(Hashable): + def __init__(self, key: object) -> None: ... + def __str__(self) -> str: ... + def __repr__(self) -> str: ... + def __eq__(self, arg: object, /) -> bool: ... + def __hash__(self) -> int: ... + @property + def key(self) -> object: ... + + __match_args__: tuple = ... + """(arg: object, /) -> tuple""" + + def __getstate__(self) -> tuple: ... + def __setstate__(self, arg: tuple, /) -> None: ... + +class GetAttrKey(Hashable): + def __init__(self, name: str) -> None: ... + def __str__(self) -> str: ... + def __repr__(self) -> str: ... + def __eq__(self, arg: object, /) -> bool: ... + def __hash__(self) -> int: ... + @property + def name(self) -> str: ... + + __match_args__: tuple = ... + """(arg: object, /) -> tuple""" + + def __getstate__(self) -> tuple: ... + def __setstate__(self, arg: tuple, /) -> None: ... + +class FlattenedIndexKey(Hashable): + def __init__(self, key: int) -> None: ... + def __str__(self) -> str: ... + def __repr__(self) -> str: ... + def __eq__(self, arg: object, /) -> bool: ... + def __hash__(self) -> int: ... + @property + def key(self) -> int: ... + + __match_args__: tuple = ... + """(arg: object, /) -> tuple""" + + def __getstate__(self) -> tuple: ... + def __setstate__(self, arg: tuple, /) -> None: ... diff --git a/jaxlib/_jax/transfer_guard_lib.pyi b/jaxlib/_jax/transfer_guard_lib.pyi new file mode 100644 index 000000000000..d293f7c59798 --- /dev/null +++ b/jaxlib/_jax/transfer_guard_lib.pyi @@ -0,0 +1,39 @@ +# Copyright 2022 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Any + +class TransferGuardLevel: + ALLOW: Any + LOG: Any + DISALLOW: Any + LOG_EXPLICIT: Any + DISALLOW_EXPLICIT: Any + +class TransferGuardState: + host_to_device: TransferGuardLevel | None + device_to_device: TransferGuardLevel | None + device_to_host: TransferGuardLevel | None + + explicit_device_put: bool + explicit_device_get: bool + +def global_state() -> TransferGuardState: ... +def thread_local_state() -> TransferGuardState: ... + +class _TestingScopedLogSink: + def __enter__(self) -> _TestingScopedLogSink: ... + def __exit__(self, *args, **kwargs) -> None: ... + def logs(self) -> list[str]: ... diff --git a/jaxlib/_pathways.pyi b/jaxlib/_pathways.pyi new file mode 100644 index 000000000000..c5d374572170 --- /dev/null +++ b/jaxlib/_pathways.pyi @@ -0,0 +1,35 @@ +# Copyright 2025 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from collections.abc import Sequence +from typing import Any + +from jaxlib import _jax + +def _transfer_to_shardings( + arrays: Sequence[_jax.ArrayImpl], + out_shardings: Sequence[Any], + donate: bool = ..., +) -> Sequence[_jax.ArrayImpl]: ... + +def _split_by_mesh_axis( + arrays: Sequence[_jax.ArrayImpl], + sharded_dim_idxs: Sequence[int], + mesh_axis_sizes: Sequence[int], + mesh_axis_idx: int, + mesh_axis_sections: Sequence[int], + submesh_shardings: Sequence[Sequence[int]], + donate: bool = ..., +) -> Sequence[Sequence[_jax.ArrayImpl]]: ... diff --git a/jaxlib/_pretty_printer.cc b/jaxlib/_pretty_printer.cc new file mode 100644 index 000000000000..1ac125c5bd6b --- /dev/null +++ b/jaxlib/_pretty_printer.cc @@ -0,0 +1,755 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/nb_class_ptr.h" + +namespace nb = nanobind; + +namespace jax { + +enum class Color { + kBlack = 30, + kRed = 31, + kGreen = 32, + kYellow = 33, + kBlue = 34, + kMagenta = 35, + kCyan = 36, + kWhite = 37, + kReset = 39, +}; + +std::string ColorToString(Color color) { + switch (color) { + case Color::kBlack: + return "black"; + case Color::kRed: + return "red"; + case Color::kGreen: + return "green"; + case Color::kYellow: + return "yellow"; + case Color::kBlue: + return "blue"; + case Color::kMagenta: + return "magenta"; + case Color::kCyan: + return "cyan"; + case Color::kWhite: + return "white"; + case Color::kReset: + return "reset"; + } +} + +enum class Intensity { + kNormal = 22, + kDim = 2, + kBright = 1, +}; + +std::string IntensityToString(Intensity intensity) { + switch (intensity) { + case Intensity::kNormal: + return "normal"; + case Intensity::kDim: + return "dim"; + case Intensity::kBright: + return "bright"; + } +} + +struct FormatState; +struct FormatAgendum; + +class Doc { + public: + Doc(int num_annotations) : num_annotations_(num_annotations) {} + virtual ~Doc() = default; + virtual std::string Repr() const = 0; + + int num_annotations() const { return num_annotations_; } + + virtual void Fits(std::stack& agenda, int& width) const = 0; + + // Returns true if the doc may be sparse, i.e. there are no breaks between + // annotations. Returns false if the doc is known not to be sparse. + virtual bool Sparse(std::stack& agenda, int& num_annotations, + bool& seen_break) const = 0; + + virtual void Format(const FormatAgendum& agendum, + FormatState& state) const = 0; + + private: + int num_annotations_; +}; + +class NilDoc final : public Doc { + public: + NilDoc() : Doc(/*num_annotations=*/0) {} + std::string Repr() const override; + + void Fits(std::stack& agenda, int& width) const override; + bool Sparse(std::stack& agenda, int& num_annotations, + bool& seen_break) const override; + virtual void Format(const FormatAgendum& agendum, + FormatState& state) const override; +}; + +class TextDoc final : public Doc { + public: + TextDoc(std::string text, std::optional annotation) + : Doc(annotation.has_value() ? 1 : 0), + text_(std::move(text)), + annotation_(std::move(annotation)) {} + std::string Repr() const override; + void Fits(std::stack& agenda, int& width) const override; + bool Sparse(std::stack& agenda, int& num_annotations, + bool& seen_break) const override; + virtual void Format(const FormatAgendum& agendum, + FormatState& state) const override; + + private: + std::string text_; + std::optional annotation_; +}; + +class ConcatDoc final : public Doc { + public: + explicit ConcatDoc(std::vector> children) + : Doc(TotalNumAnnotations(children)), children_(std::move(children)) {} + std::string Repr() const override; + + void Fits(std::stack& agenda, int& width) const override; + bool Sparse(std::stack& agenda, int& num_annotations, + bool& seen_break) const override; + virtual void Format(const FormatAgendum& agendum, + FormatState& state) const override; + + private: + static int TotalNumAnnotations(absl::Span> children) { + int total = 0; + for (const auto& child : children) { + total += child->num_annotations(); + } + return total; + } + std::vector> children_; +}; + +class BreakDoc final : public Doc { + public: + explicit BreakDoc(std::string text) + : Doc(/*num_annotations=*/0), text_(std::move(text)) {} + std::string Repr() const override; + void Fits(std::stack& agenda, int& width) const override; + bool Sparse(std::stack& agenda, int& num_annotations, + bool& seen_break) const override; + virtual void Format(const FormatAgendum& agendum, + FormatState& state) const override; + + private: + std::string text_; +}; + +class GroupDoc final : public Doc { + public: + explicit GroupDoc(nb_class_ptr child) + : Doc(/*num_annotations=*/child->num_annotations()), + child_(std::move(child)) {} + std::string Repr() const override; + void Fits(std::stack& agenda, int& width) const override; + bool Sparse(std::stack& agenda, int& num_annotations, + bool& seen_break) const override; + virtual void Format(const FormatAgendum& agendum, + FormatState& state) const override; + + private: + nb_class_ptr child_; +}; + +class NestDoc final : public Doc { + public: + explicit NestDoc(int n, nb_class_ptr child) + : Doc(child->num_annotations()), n_(n), child_(std::move(child)) {} + std::string Repr() const override; + void Fits(std::stack& agenda, int& width) const override; + bool Sparse(std::stack& agenda, int& num_annotations, + bool& seen_break) const override; + virtual void Format(const FormatAgendum& agendum, + FormatState& state) const override; + + private: + int n_; + nb_class_ptr child_; +}; + +class SourceMapDoc final : public Doc { + public: + explicit SourceMapDoc(nb_class_ptr child, nb::object source) + : Doc(child->num_annotations()), + child_(std::move(child)), + source_(std::move(source)) {} + std::string Repr() const override; + void Fits(std::stack& agenda, int& width) const override; + bool Sparse(std::stack& agenda, int& num_annotations, + bool& seen_break) const override; + virtual void Format(const FormatAgendum& agendum, + FormatState& state) const override; + + private: + nb_class_ptr child_; + nb::object source_; +}; + +class ColorDoc final : public Doc { + public: + explicit ColorDoc(nb_class_ptr child, std::optional foreground, + std::optional background, + std::optional intensity) + : Doc(child->num_annotations()), + child_(std::move(child)), + foreground_(foreground), + background_(background), + intensity_(intensity) {} + + std::string Repr() const override; + void Fits(std::stack& agenda, int& width) const override; + bool Sparse(std::stack& agenda, int& num_annotations, + bool& seen_break) const override; + virtual void Format(const FormatAgendum& agendum, + FormatState& state) const override; + + private: + nb_class_ptr child_; + std::optional foreground_; + std::optional background_; + std::optional intensity_; +}; + +std::string NilDoc::Repr() const { return "nil"; } + +std::string TextDoc::Repr() const { + if (annotation_.has_value()) { + return absl::StrFormat("text(\"%s\", annotation=\"%s\")", text_, + *annotation_); + } else { + return absl::StrFormat("text(\"%s\")", text_); + } +} + +std::string ConcatDoc::Repr() const { + return absl::StrFormat( + "concat(%s)", + absl::StrJoin(children_, ", ", [](std::string* out, const auto& child) { + absl::StrAppend(out, child->Repr()); + })); +} + +std::string BreakDoc::Repr() const { + return absl::StrFormat("break(\"%s\")", text_); +} + +std::string GroupDoc::Repr() const { + return absl::StrFormat("group(%s)", child_->Repr()); +} + +std::string NestDoc::Repr() const { + return absl::StrFormat("nest(%d, %s)", n_, child_->Repr()); +} + +std::string SourceMapDoc::Repr() const { + return absl::StrFormat("source(%s, %s)", child_->Repr(), + nb::cast(nb::repr(source_))); +} + +std::string ColorDoc::Repr() const { + std::string foreground_str = + foreground_.has_value() ? ColorToString(*foreground_) : "None"; + std::string background_str = + background_.has_value() ? ColorToString(*background_) : "None"; + std::string intensity_str = + intensity_.has_value() ? IntensityToString(*intensity_) : "None"; + return absl::StrFormat("color(%s, %s, %s, %s)", child_->Repr(), + foreground_str, background_str, intensity_str); +} + +// Fits method implementations + +void NilDoc::Fits(std::stack& agenda, int& width) const {} + +void TextDoc::Fits(std::stack& agenda, int& width) const { + width -= text_.size(); +} + +void ConcatDoc::Fits(std::stack& agenda, int& width) const { + for (auto it = children_.rbegin(); it != children_.rend(); ++it) { + agenda.push(it->get()); + } +} + +void BreakDoc::Fits(std::stack& agenda, int& width) const { + width -= static_cast(text_.size()); +} + +void GroupDoc::Fits(std::stack& agenda, int& width) const { + agenda.push(child_.get()); +} + +void NestDoc::Fits(std::stack& agenda, int& width) const { + agenda.push(child_.get()); +} + +void SourceMapDoc::Fits(std::stack& agenda, int& width) const { + agenda.push(child_.get()); +} + +void ColorDoc::Fits(std::stack& agenda, int& width) const { + agenda.push(child_.get()); +} + +bool Fits(const Doc* doc, int width) { + std::stack agenda; + agenda.push(doc); + while (width >= 0 && !agenda.empty()) { + const Doc* doc = agenda.top(); + agenda.pop(); + doc->Fits(agenda, width); + } + return width >= 0; +} + +// Sparse method implementations + +bool NilDoc::Sparse(std::stack& agenda, int& num_annotations, + bool& seen_break) const { + return true; +} + +bool TextDoc::Sparse(std::stack& agenda, int& num_annotations, + bool& seen_break) const { + if (annotation_.has_value()) { + if (num_annotations >= 1 && seen_break) { + return false; + } + num_annotations -= 1; + } + return true; +} + +bool ConcatDoc::Sparse(std::stack& agenda, int& num_annotations, + bool& seen_break) const { + for (auto it = children_.rbegin(); it != children_.rend(); ++it) { + agenda.push(it->get()); + } + return true; +} + +bool BreakDoc::Sparse(std::stack& agenda, int& num_annotations, + bool& seen_break) const { + seen_break = true; + return true; +} + +bool GroupDoc::Sparse(std::stack& agenda, int& num_annotations, + bool& seen_break) const { + agenda.push(child_.get()); + return true; +} + +bool NestDoc::Sparse(std::stack& agenda, int& num_annotations, + bool& seen_break) const { + agenda.push(child_.get()); + return true; +} + +bool SourceMapDoc::Sparse(std::stack& agenda, int& num_annotations, + bool& seen_break) const { + agenda.push(child_.get()); + return true; +} + +bool ColorDoc::Sparse(std::stack& agenda, int& num_annotations, + bool& seen_break) const { + agenda.push(child_.get()); + return true; +} + +// Returns true if the doc is sparse, i.e. there are no breaks between +// annotations. +bool Sparse(const Doc* doc) { + if (doc->num_annotations() == 0) { + return true; + } + std::stack agenda; + agenda.push(doc); + int num_annotations = 0; + bool seen_break = false; + while (!agenda.empty()) { + const Doc* doc = agenda.top(); + agenda.pop(); + if (!doc->Sparse(agenda, num_annotations, seen_break)) { + return false; + } + } + return true; +} + +struct ColorState { + Color foreground; + Color background; + Intensity intensity; + + bool operator==(const ColorState& other) const { + return foreground == other.foreground && background == other.background && + intensity == other.intensity; + } + bool operator!=(const ColorState& other) const { return !operator==(other); } +}; + +constexpr ColorState kDefaultColors = + ColorState{Color::kReset, Color::kReset, Intensity::kNormal}; +constexpr ColorState kAnnotationColors = + ColorState{Color::kReset, Color::kReset, Intensity::kDim}; + +enum class BreakMode { kFlat, kBreak }; + +struct FormatAgendum { + int indent; + BreakMode mode; + const Doc* doc; + ColorState color; + nb::object source; +}; + +struct Line { + std::string text; + int width; + std::vector annotations; +}; + +// Format method implementations + +struct FormatState { + int width; + std::stack agenda; + std::string line_text; + int k; + std::vector line_annotations; + std::optional color; + std::optional source_map; + nb::list line_source_map; + int source_start; + nb::object source; + std::vector lines; +}; + +std::string UpdateColor(std::optional& state, + const ColorState& update) { + if (!state.has_value() || *state == update) { + return ""; + } + std::string result = "\033["; + absl::InlinedVector codes; + if (state->foreground != update.foreground) { + codes.push_back(absl::StrCat(static_cast(update.foreground))); + } + if (state->background != update.background) { + codes.push_back(absl::StrCat(static_cast(update.background) + 10)); + } + if (state->intensity != update.intensity) { + codes.push_back(absl::StrCat(static_cast(update.intensity))); + } + absl::StrAppend(&result, absl::StrJoin(codes, ";"), "m"); + state = update; + return result; +} + +void NilDoc::Format(const FormatAgendum& agendum, FormatState& state) const {} + +void TextDoc::Format(const FormatAgendum& agendum, FormatState& state) const { + absl::StrAppend(&state.line_text, UpdateColor(state.color, agendum.color), + text_); + if (annotation_.has_value()) { + state.line_annotations.push_back(*annotation_); + } + state.k += text_.size(); +} + +void ConcatDoc::Format(const FormatAgendum& agendum, FormatState& state) const { + for (auto it = children_.rbegin(); it != children_.rend(); ++it) { + state.agenda.push(FormatAgendum{agendum.indent, agendum.mode, it->get(), + agendum.color, state.source}); + } +} + +void BreakDoc::Format(const FormatAgendum& agendum, FormatState& state) const { + if (agendum.mode == BreakMode::kBreak) { + if (!state.line_annotations.empty()) { + absl::StrAppend(&state.line_text, + UpdateColor(state.color, kAnnotationColors)); + } + if (state.source_map.has_value()) { + int pos = state.line_text.size(); + if (state.source_start != pos && state.source.ptr() != nullptr) { + state.line_source_map.append( + nb::make_tuple(state.source_start, pos, state.source)); + } + state.source_map->append(state.line_source_map); + state.line_source_map = nb::list(); + state.source_start = agendum.indent; + } + state.lines.push_back(Line{std::move(state.line_text), state.k, + std::move(state.line_annotations)}); + state.line_text = std::string(agendum.indent, ' '); + state.line_annotations.clear(); + state.k = agendum.indent; + } else { + absl::StrAppend(&state.line_text, UpdateColor(state.color, agendum.color), + text_); + state.k += text_.size(); + } +} + +void GroupDoc::Format(const FormatAgendum& agendum, FormatState& state) const { + // In Lindig's paper, _fits is passed the remainder of the document. + // I'm pretty sure that's a bug and we care only if the current group fits! + bool fits = ::jax::Fits(agendum.doc, state.width - state.k) && + ::jax::Sparse(agendum.doc); + state.agenda.push(FormatAgendum{agendum.indent, + fits ? BreakMode::kFlat : BreakMode::kBreak, + child_.get(), agendum.color, state.source}); +} + +void NestDoc::Format(const FormatAgendum& agendum, FormatState& state) const { + state.agenda.push(FormatAgendum{agendum.indent + n_, agendum.mode, + child_.get(), agendum.color, state.source}); +} + +void SourceMapDoc::Format(const FormatAgendum& agendum, + FormatState& state) const { + state.agenda.push(FormatAgendum{agendum.indent, agendum.mode, child_.get(), + agendum.color, source_}); +} + +void ColorDoc::Format(const FormatAgendum& agendum, FormatState& state) const { + ColorState color = agendum.color; + if (foreground_.has_value()) { + color.foreground = *foreground_; + } + if (background_.has_value()) { + color.background = *background_; + } + if (intensity_.has_value()) { + color.intensity = *intensity_; + } + state.agenda.push(FormatAgendum{agendum.indent, agendum.mode, child_.get(), + color, state.source}); +} + +std::string Format(const Doc* doc, int width, bool use_color, + std::string annotation_prefix, + std::optional source_map) { + FormatState state; + if (use_color) { + state.color = kDefaultColors; + } + state.width = width; + state.source_start = 0; + state.source_map = source_map; + state.agenda.push( + FormatAgendum{0, BreakMode::kBreak, doc, kDefaultColors, nb::object()}); + state.k = 0; + while (!state.agenda.empty()) { + FormatAgendum agendum = state.agenda.top(); + state.agenda.pop(); + if (source_map.has_value() && agendum.source.ptr() != state.source.ptr()) { + int pos = state.line_text.size(); + if (state.source_start != pos && state.source.ptr() != nullptr) { + state.line_source_map.append( + nb::make_tuple(state.source_start, pos, state.source)); + } + state.source = agendum.source; + state.source_start = pos; + } + agendum.doc->Format(agendum, state); + } + if (!state.line_annotations.empty()) { + absl::StrAppend(&state.line_text, + UpdateColor(state.color, kAnnotationColors)); + } + if (state.source_map.has_value()) { + int pos = state.line_text.size(); + if (state.source_start != pos && state.source.ptr() != nullptr) { + state.line_source_map.append( + nb::make_tuple(state.source_start, pos, state.source)); + } + state.source_map->append(state.line_source_map); + } + state.lines.push_back(Line{std::move(state.line_text), state.k, + std::move(state.line_annotations)}); + + int max_width = 0; + for (const auto& line : state.lines) { + max_width = std::max(max_width, line.width); + } + std::string out = + absl::StrJoin(state.lines, "\n", [&](std::string* out, const Line& line) { + if (line.annotations.empty()) { + absl::StrAppend(out, line.text); + } else { + absl::StrAppend(out, line.text, + std::string(max_width - line.width, ' '), + annotation_prefix, line.annotations[0]); + for (int i = 1; i < line.annotations.size(); ++i) { + absl::StrAppend(out, std::string(max_width, ' '), annotation_prefix, + line.annotations[i]); + } + } + }); + absl::StrAppend(&out, UpdateColor(state.color, kDefaultColors)); + return out; +} + +NB_MODULE(_pretty_printer, m) { + nb::enum_(m, "Color") + .value("BLACK", Color::kBlack) + .value("RED", Color::kRed) + .value("GREEN", Color::kGreen) + .value("YELLOW", Color::kYellow) + .value("BLUE", Color::kBlue) + .value("MAGENTA", Color::kMagenta) + .value("CYAN", Color::kCyan) + .value("WHITE", Color::kWhite) + .value("RESET", Color::kReset); + + nb::enum_(m, "Intensity") + .value("DIM", Intensity::kDim) + .value("NORMAL", Intensity::kNormal) + .value("BRIGHT", Intensity::kBright); + + nb::class_(m, "Doc") + .def("__repr__", &Doc::Repr) + .def("__add__", + [](nb_class_ptr self, + nb_class_ptr other) -> nb_class_ptr { + return make_nb_class(std::vector>{ + std::move(self), std::move(other)}); + }) + .def("_format", &Format, nb::arg("width"), nb::arg("use_color"), + nb::arg("annotation_prefix"), nb::arg("source_map").none()); + + nb::class_(m, "NilDoc"); + nb::class_(m, "TextDoc"); + nb::class_(m, "ConcatDoc"); + nb::class_(m, "BreakDoc"); + nb::class_(m, "GroupDoc"); + nb::class_(m, "NestDoc"); + nb::class_(m, "ColorDoc"); + nb::class_(m, "SourceMapDoc"); + + m.def( + "nil", []() -> nb_class_ptr { return make_nb_class(); }, + "An empty document."); + m.def( + "text", + [](std::string text, + std::optional annotation) -> nb_class_ptr { + return make_nb_class(std::move(text), std::move(annotation)); + }, + nb::arg("text"), nb::arg("annotation").none() = std::nullopt, + "Literal text."); + m.def( + "concat", + [](std::vector> children) -> nb_class_ptr { + return make_nb_class(std::move(children)); + }, + nb::arg("children"), "Concatenation of documents."); + m.def( + "brk", + [](std::string text) -> nb_class_ptr { + return make_nb_class(text); + }, + nb::arg("text") = std::string(" "), + R"(A break. + +Prints either as a newline or as `text`, depending on the enclosing group. +)"); + m.def( + "group", + [](nb_class_ptr child) -> nb_class_ptr { + return make_nb_class(std::move(child)); + }, + R"(Layout alternative groups. + +Prints the group with its breaks as their text (typically spaces) if the +entire group would fit on the line when printed that way. Otherwise, breaks +inside the group as printed as newlines. +)"); + m.def( + "nest", + [](int n, nb_class_ptr child) -> nb_class_ptr { + return make_nb_class(n, std::move(child)); + }, + "Increases the indentation level by `n`."); + m.def( + "color", + [](nb_class_ptr child, std::optional foreground, + std::optional background, + std::optional intensity) -> nb_class_ptr { + return make_nb_class(std::move(child), foreground, background, + intensity); + }, + nb::arg("child"), nb::arg("foreground").none() = std::nullopt, + nb::arg("background").none() = std::nullopt, + nb::arg("intensity").none() = std::nullopt, + R"(ANSI colors. + +Overrides the foreground/background/intensity of the text for the child doc. +Requires use_colors=True to be set when printing; otherwise does nothing. +)"); + m.def( + "source_map", + [](nb_class_ptr child, nb::object source) -> nb_class_ptr { + return make_nb_class(std::move(child), std::move(source)); + }, + nb::arg("doc"), nb::arg("source"), + R"(Source mapping. + +A source map associates a region of the pretty-printer's text output with a +source location that produced it. For the purposes of the pretty printer a +``source`` may be any object: we require only that we can compare sources for +equality. A text region to source object mapping can be populated as a side +output of the ``format`` method. +)"); +} + +} // namespace jax diff --git a/jaxlib/_pretty_printer.pyi b/jaxlib/_pretty_printer.pyi new file mode 100644 index 000000000000..8dbb32f7fc1a --- /dev/null +++ b/jaxlib/_pretty_printer.pyi @@ -0,0 +1,105 @@ +from collections.abc import Sequence +import enum + + +class Color(enum.Enum): + BLACK = 30 + + RED = 31 + + GREEN = 32 + + YELLOW = 33 + + BLUE = 34 + + MAGENTA = 35 + + CYAN = 36 + + WHITE = 37 + + RESET = 39 + +class Intensity(enum.Enum): + DIM = 2 + + NORMAL = 22 + + BRIGHT = 1 + +class Doc: + def __repr__(self) -> str: ... + + def __add__(self, arg: Doc, /) -> Doc: ... + +class NilDoc(Doc): + pass + +class TextDoc(Doc): + pass + +class ConcatDoc(Doc): + pass + +class BreakDoc(Doc): + pass + +class GroupDoc(Doc): + pass + +class NestDoc(Doc): + pass + +class ColorDoc(Doc): + pass + +class SourceMapDoc(Doc): + pass + +def nil() -> Doc: + """An empty document.""" + +def text(text: str, annotation: str | None = None) -> Doc: + """Literal text.""" + +def concat(children: Sequence[Doc]) -> Doc: + """Concatenation of documents.""" + +def brk(text: str = ' ') -> Doc: + """ + A break. + + Prints either as a newline or as `text`, depending on the enclosing group. + """ + +def group(arg: Doc, /) -> Doc: + """ + Layout alternative groups. + + Prints the group with its breaks as their text (typically spaces) if the + entire group would fit on the line when printed that way. Otherwise, breaks + inside the group as printed as newlines. + """ + +def nest(arg0: int, arg1: Doc, /) -> Doc: + """Increases the indentation level by `n`.""" + +def color(child: Doc, foreground: Color | None = None, background: Color | None = None, intensity: Intensity | None = None) -> Doc: + """ + ANSI colors. + + Overrides the foreground/background/intensity of the text for the child doc. + Requires use_colors=True to be set when printing; otherwise does nothing. + """ + +def source_map(doc: Doc, source: object) -> Doc: + """ + Source mapping. + + A source map associates a region of the pretty-printer's text output with a + source location that produced it. For the purposes of the pretty printer a + ``source`` may be any object: we require only that we can compare sources for + equality. A text region to source object mapping can be populated as a side + output of the ``format`` method. + """ diff --git a/jaxlib/cached_py_object.h b/jaxlib/cached_py_object.h new file mode 100644 index 000000000000..b934fa203a44 --- /dev/null +++ b/jaxlib/cached_py_object.h @@ -0,0 +1,61 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAX_JAXLIB_CACHED_PY_OBJECT_H_ +#define JAX_JAXLIB_CACHED_PY_OBJECT_H_ + +#include + +#include "absl/functional/function_ref.h" +#include "nanobind/nanobind.h" + +namespace jax { + +// A lock-free thread-safe cache for a single Python object. +// Example use case: caching a hash value in an object. +class CachedPyObject { + public: + CachedPyObject() = default; + ~CachedPyObject() { + PyObject* value = value_.load(); + Py_XDECREF(value); + } + + // Returns the cached value of the object. If the object is not present, + // factory() will be called to create it and the cache will be populated. + // Note: factory() may be called multiple times if used concurrently. The + // returned value will be one of the returned values of factory(). + // Thread-safe. + nanobind::object Get(absl::FunctionRef factory) { + PyObject* v = value_.load(); + if (v) { + return nanobind::borrow(v); + } + nanobind::object new_value = factory(); + if (value_.compare_exchange_strong(v, new_value.inc_ref().ptr())) { + return new_value; + } else { + new_value.dec_ref(); + return nanobind::borrow(v); + } + } + + private: + std::atomic value_ = nullptr; +}; + +} // namespace jax + +#endif // JAX_JAXLIB_CACHED_PY_OBJECT_H_ diff --git a/jaxlib/call_location.cc b/jaxlib/call_location.cc new file mode 100644 index 000000000000..96855b114fb2 --- /dev/null +++ b/jaxlib/call_location.cc @@ -0,0 +1,153 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/call_location.h" + +#include +#include +#include +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/synchronization/mutex.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/py_user_context.h" +#include "jaxlib/traceback.h" +#include "xla/python/ifrt/attribute_map.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/user_context.h" +#include "xla/python/pjrt_ifrt/pjrt_executable.h" + +namespace nb = nanobind; + +namespace jax { + +namespace { + +std::atomic global_runtime_traceback_mode_ = + RuntimeTracebackMode::kOff; + +thread_local std::optional + runtime_traceback_mode_thread_local_ = std::nullopt; + +RuntimeTracebackMode GetRuntimeTracebackMode() { + if (runtime_traceback_mode_thread_local_.has_value()) { + return *runtime_traceback_mode_thread_local_; + } + return global_runtime_traceback_mode_.load(); +} + +static absl::Mutex shared_data_mu; + +static absl::NoDestructor> exclude_paths_from_python + ABSL_GUARDED_BY(shared_data_mu); + +static absl::NoDestructor> + known_code_objects ABSL_GUARDED_BY(shared_data_mu); + +// Returns true if the code object is an internal JAX frame (Cached) +bool IsJaxInternalFrame(PyCodeObject* code) { + nb::str file_name = nb::borrow(code->co_filename); + std::string file_name_sv = nb::cast(file_name); + + absl::MutexLock lock(shared_data_mu); + auto it = known_code_objects->find(file_name_sv); + if (it != known_code_objects->end()) { + return it->second; + } + + bool is_internal = false; + for (const auto& prefix : *exclude_paths_from_python) { + if (absl::StartsWith(file_name_sv, prefix)) { + is_internal = true; + break; + } + } + (*known_code_objects)[file_name_sv] = is_internal; + return is_internal; +} + +// Returns the first non-JAX internal frame in the format "file:line" +std::string GetCallLocation(const jax::Traceback& traceback) { + auto frames = traceback.RawFrames(); + for (const auto& frame : frames) { + if (!IsJaxInternalFrame(frame.code)) { + nb::str file_name = nb::borrow(frame.code->co_filename); + int line_num = PyCode_Addr2Line(frame.code, frame.lasti); + return absl::StrCat(nb::cast(file_name), ":", line_num); + } + } + return ""; +} + +} // namespace + +// Populates the "call_location" field in the execute options if +// jax.config.jax_send_traceback_to_runtime is not 'off'. +// A traceback will be collected from the current user context. +void PopulateCallLocation(xla::ifrt::ExecuteOptions& options, + const xla::ifrt::UserContext* user_context) { + RuntimeTracebackMode mode = GetRuntimeTracebackMode(); + if (mode == RuntimeTracebackMode::kOff) { // Default case + return; + } + std::optional traceback = GetTraceback(user_context); + if (!traceback.has_value()) { + return; + } + + std::string call_location_str; + if (mode == RuntimeTracebackMode::kFull) { + call_location_str = traceback->ToString(); + } else { // mode == RuntimeTracebackMode::kOn + call_location_str = GetCallLocation(*traceback); + } + + if (!call_location_str.empty()) { + if (!options.custom_options.has_value()) { + options.custom_options.emplace(xla::ifrt::AttributeMap({})); + } + CHECK_OK(options.custom_options->Set( + std::string(xla::ifrt::PjRtCompatibleLoadedExecutable::kCallLocation), + std::move(call_location_str))); + } +} + +// Function to be called from Python to add a single path +void AddExcludePath(std::string path) { + absl::MutexLock lock(shared_data_mu); + exclude_paths_from_python->push_back(std::move(path)); + known_code_objects->clear(); +} + +void SetSendTracebackToRuntimeGlobal(RuntimeTracebackMode mode) { + global_runtime_traceback_mode_.store(mode); +} + +void SetSendTracebackToRuntimeThreadLocal( + std::optional mode) { + runtime_traceback_mode_thread_local_ = mode; +} + +} // namespace jax diff --git a/jaxlib/call_location.h b/jaxlib/call_location.h new file mode 100644 index 000000000000..4413b7112df2 --- /dev/null +++ b/jaxlib/call_location.h @@ -0,0 +1,44 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_CALL_LOCATION_H_ +#define JAXLIB_CALL_LOCATION_H_ + +#include +#include + +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/user_context.h" + +namespace jax { + +enum class RuntimeTracebackMode { + kOff = 0, + kOn = 1, + kFull = 2, +}; + +void PopulateCallLocation(xla::ifrt::ExecuteOptions& options, + const xla::ifrt::UserContext* user_context); + +void AddExcludePath(std::string path); +void SetSendTracebackToRuntimeGlobal(RuntimeTracebackMode mode); +void SetSendTracebackToRuntimeThreadLocal( + std::optional mode); + + +} // namespace jax + +#endif // JAX_JAXLIB_CALL_LOCATION_H_ diff --git a/jaxlib/callback.cc b/jaxlib/callback.cc new file mode 100644 index 000000000000..81263d149c64 --- /dev/null +++ b/jaxlib/callback.cc @@ -0,0 +1,175 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/callback.h" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "jaxlib/python_ref_manager.h" +#include "xla/pjrt/host_callback.h" +#include "xla/pjrt/transpose.h" +#include "xla/primitive_util.h" +#include "xla/python/nb_numpy.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/python/lib/core/numpy.h" +#include "xla/xla_data.pb.h" + +namespace nb = nanobind; + +namespace jax { + +CpuCallback::~CpuCallback() { + // The destructor may be called without GIL held. In that case, we defer it + // to GlobalPyRefManager. + std::vector objects; + objects.push_back(std::move(callable_)); + for (auto& arg : args_) { + objects.push_back(std::move(arg.dtype)); + } + + GlobalPyRefManager()->AddGarbage(absl::MakeSpan(objects)); +} + +absl::Status CpuCallback::PrepareAndCall(void** result, void** arg_ptrs) { + absl::Span inputs(arg_ptrs, args_.size()); + absl::Span outputs(result, results_.size()); + + nb::gil_scoped_acquire gil; + nb::tuple args = nb::steal(PyTuple_New(inputs.size())); + for (size_t i = 0; i < inputs.size(); ++i) { + if (args_[i].type == xla::TOKEN) { + PyTuple_SET_ITEM(args.ptr(), i, nb::none().release().ptr()); + } else { + xla::nb_numpy_ndarray array = + xla::nb_numpy_ndarray(args_[i].dtype, args_[i].dims, args_[i].strides, + const_cast(inputs[i])); + array.attr("flags").attr("writeable") = nb::bool_(false); + PyTuple_SET_ITEM(args.ptr(), i, array.release().ptr()); + } + } + + absl::StatusOr maybe_result_tuple; + { + xla::HostCallbackScope scope; + maybe_result_tuple = Call(std::move(args)); + } + TF_ASSIGN_OR_RETURN(auto result_tuple, maybe_result_tuple); + + for (size_t i = 0; i < results_.size(); ++i) { + if (results_[i].type == xla::TOKEN) { + continue; + } + nb::object output = + nb::borrow(PyTuple_GetItem(result_tuple.ptr(), i)); + xla::nb_numpy_ndarray array = + xla::nb_numpy_ndarray::ensure(std::move(output)); + absl::Span dims( + reinterpret_cast(array.shape()), array.ndim()); + absl::Span strides( + reinterpret_cast(array.strides()), array.ndim()); + if (strides == results_[i].expected_strides) { + std::memcpy(outputs[i], array.data(), results_[i].size_in_bytes); + } else { + xla::TransposePlan::Options options; + options.elem_size_in_bytes = + xla::primitive_util::ByteWidth(results_[i].type); + options.dims = dims; + options.permutation = results_[i].reversed_layout; + options.input_striding = xla::TransposePlan::Striding{strides}; + absl::StatusOr> plan = + transpose_cache_.GetOrCreate(options); + if (!plan.ok()) { + return std::move(plan).status(); + } + plan.value()->Execute(array.data(), outputs[i]); + } + } + + return absl::OkStatus(); +} + +absl::StatusOr CpuCallback::Call(nb::tuple args) { + auto py_error_to_status = [](nb::python_error& e) { + std::string error_message = e.what(); + return absl::InternalError( + absl::StrFormat("CpuCallback error: %s", error_message)); + }; + nb::object result_object; + try { + result_object = callable_(*nb::borrow(args)); + } catch (nb::python_error& e) { + return py_error_to_status(e); + } + if (!PyTuple_Check(result_object.ptr())) { + return absl::InternalError( + absl::StrFormat("CPU callback expected a tuple result, got %s", + nb::cast(nb::repr(result_object)))); + } + if (PyTuple_Size(result_object.ptr()) != results_.size()) { + return absl::InternalError( + absl::StrFormat("CPU callback expected a tuple with %d results, got %d", + results_.size(), PyTuple_Size(result_object.ptr()))); + } + nb::tuple result_tuple = nb::cast(result_object); + for (size_t i = 0; i < results_.size(); ++i) { + nb::object output = + nb::borrow(PyTuple_GetItem(result_tuple.ptr(), i)); + if (results_[i].type == xla::TOKEN) { + if (!output.is_none()) { + return absl::InternalError(absl::StrFormat( + "Token output from Python callback should be None, got %s", + nb::cast(nb::repr(output)))); + } + continue; + } + xla::nb_numpy_ndarray array; + try { + array = xla::nb_numpy_ndarray::from_any(output, NPY_ARRAY_ENSUREARRAY); + } catch (nb::python_error& e) { + return py_error_to_status(e); + } + static_assert(sizeof(ssize_t) == sizeof(int64_t), + "Expected ssize_t to be of equal size to int64_t"); + absl::Span dims( + reinterpret_cast(array.shape()), array.ndim()); + if (dims != results_[i].expected_dims) { + return absl::InternalError(absl::StrFormat( + "Mismatched result shape for %d-th return value from CPU callback; " + "expected array with dimensions %s, got %s", + i, absl::StrJoin(results_[i].expected_dims, ","), + absl::StrJoin(dims, ","))); + } + } + return result_tuple; +} + +} // namespace jax diff --git a/jaxlib/callback.h b/jaxlib/callback.h new file mode 100644 index 000000000000..c74e15e14f69 --- /dev/null +++ b/jaxlib/callback.h @@ -0,0 +1,87 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_CALLBACK_H_ +#define JAXLIB_CALLBACK_H_ + +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "nanobind/nanobind.h" +#include "xla/pjrt/transpose.h" +#include "xla/python/nb_numpy.h" +#include "xla/xla_data.pb.h" + +namespace jax { + +class CpuCallback { + public: + struct Arg { + xla::PrimitiveType type; // XLA type + xla::nb_dtype dtype; // NumPy type, for array types. + absl::InlinedVector dims; // Dimensions, for array types. + std::vector strides; // Byte strides, for array types. + size_t size_in_bytes; // Size of the array in bytes. + }; + struct Result { + xla::PrimitiveType type; // XLA type + // Expected output shape, for array types + absl::InlinedVector expected_dims; + // Expected output byte strides, for array types. If the strides do not + // match the output will be transposed into the expected layout. + std::vector expected_strides; + // The desired order of output dimensions in major-to-minor order. + absl::InlinedVector reversed_layout; + // Size of the array in bytes. + size_t size_in_bytes; + }; + + explicit CpuCallback(nanobind::callable callable, std::vector args, + std::vector results) + : callable_(std::move(callable)), + args_(std::move(args)), + results_(std::move(results)), + transpose_cache_(/*capacity=*/16) {} + + ~CpuCallback(); + + const std::vector& args() const { return args_; } + size_t num_args() const { return args_.size(); } + + const std::vector& results() const { return results_; } + size_t num_results() const { return results_.size(); } + void* callback() const { return callable_.ptr(); } + + xla::TransposePlanCache& transpose_cache() { return transpose_cache_; } + + absl::Status PrepareAndCall(void** result, void** arg_ptrs); + + absl::StatusOr Call(nanobind::tuple args); + + private: + nanobind::callable callable_; + std::vector args_; + std::vector results_; + xla::TransposePlanCache transpose_cache_; +}; + +} // namespace jax + +#endif // JAXLIB_CALLBACK_H_ diff --git a/jaxlib/config.cc b/jaxlib/config.cc new file mode 100644 index 000000000000..01d9143ac657 --- /dev/null +++ b/jaxlib/config.cc @@ -0,0 +1,389 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/config.h" + +#include + +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_set.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/typing.h" +#include "jaxlib/python_ref_manager.h" +#include "xla/tsl/platform/logging.h" + +namespace jax { + +namespace nb = nanobind; + +// Singleton object used to represent "value not set" in thread-local configs. +nb::object UnsetObject() { + return nb::steal(PyObject_CallObject( + reinterpret_cast(&PyBaseObject_Type), nullptr)); +} + +// Each configuration object has: +// * a global value, and +// * a thread-local value. +// When querying the state of a config, the thread-local value is used if it is +// set. Otherwise, the global value is used. + +// This class represents all of the thread-local configuration state for a +// thread. +class ThreadLocalConfigState { + public: + ThreadLocalConfigState(); + ~ThreadLocalConfigState(); + + static ThreadLocalConfigState& Instance() { + thread_local auto state = std::make_unique(); + return *state; + } + + nb::object Get(int key) { + DCHECK_GE(key, 0); + return key >= entries_.size() ? nb::object() : entries_[key]; + } + + void Set(int key, nb::object value); + + private: + friend class GlobalConfigState; + + // These values are accessed in one of two ways: + // * The owning thread reads or writes them, while holding the GIL, or, under + // free-threading, while the owning thread is in ATTACHED gc state. + // * Other threads may read or clear values while performing a garbage + // collection. + // No locking is needed because a GC thread cannot run concurrently with other + // Python threads; even under free-threading Python uses a stop-the-world GC. + std::vector entries_; +}; + +// This class represents all of the global configuration state. +class GlobalConfigState { + public: + static GlobalConfigState& Instance() { + static auto state = new GlobalConfigState(); + return *state; + } + + nb::object Get(int key) const; + void Set(int key, nb::object value); + + // Adds or removes a thread-local state from the set of thread-local states. + void AddThreadLocalState(ThreadLocalConfigState* state) { + absl::MutexLock lock(mu_); + thread_local_states_.insert(state); + } + void RemoveThreadLocalState(ThreadLocalConfigState* state) { + absl::MutexLock lock(mu_); + thread_local_states_.erase(state); + } + + // Python GC helpers. These are called from the tp_traverse and tp_clear + // methods of the Config class. + int tp_traverse(int key, PyObject* self, visitproc visit, void* arg); + int tp_clear(int key, PyObject* self); + + // Returns the singleton object representing "value not set". + const nb::object& unset() const { return unset_; } + + // Returns the set of keys that should be included in the jit key. + absl::Span include_in_jit_key() const { + return include_in_jit_key_; + } + + // Returns the set of keys that should be included in the trace context. + absl::Span include_in_trace_context() const { + return include_in_trace_context_; + } + + absl::Span names() const { return names_; } + const std::string& name(int key) const { return names_[key]; } + + private: + friend class Config; + + // The set of thread-local states. This is used during garbage collection to + // visit thread-local values. + absl::Mutex mu_; + absl::flat_hash_set thread_local_states_ + ABSL_GUARDED_BY(mu_); + std::vector names_; + std::vector entries_; + std::vector include_in_jit_key_; + std::vector include_in_trace_context_; + nb::object unset_ = UnsetObject(); +}; + +ThreadLocalConfigState::ThreadLocalConfigState() { + GlobalConfigState::Instance().AddThreadLocalState(this); +} + +ThreadLocalConfigState::~ThreadLocalConfigState() { + // It's important that we remove the thread-local state before we access + // entries_. This ensures that accesses to entries_ are ordered with respect + // any garbage collection. + GlobalConfigState::Instance().RemoveThreadLocalState(this); + // We do not hold the GIL, so we must use deferred destruction. + GlobalPyRefManager()->AddGarbage(absl::MakeSpan(entries_)); +} + +void ThreadLocalConfigState::Set(int key, nb::object value) { + DCHECK_GE(key, 0); + if (key >= entries_.size()) { + entries_.resize(key + 1); + } + std::swap(entries_[key], value); +} + +nb::object GlobalConfigState::Get(int key) const { + DCHECK_GE(key, 0); + DCHECK_LT(key, entries_.size()); + return entries_[key]; +} + +void GlobalConfigState::Set(int key, nb::object value) { + DCHECK_GE(key, 0); + DCHECK_LT(key, entries_.size()); + std::swap(entries_[key], value); +} + +int GlobalConfigState::tp_traverse(int key, PyObject* self, visitproc visit, + void* arg) { + DCHECK_GE(key, 0); + if (key < entries_.size()) { + PyObject* value = entries_[key].ptr(); + Py_VISIT(value); + } + absl::MutexLock lock(mu_); + for (const auto* state : thread_local_states_) { + if (key < state->entries_.size()) { + PyObject* value = state->entries_[key].ptr(); + Py_VISIT(value); + } + } + return 0; +} + +int GlobalConfigState::tp_clear(int key, PyObject* self) { + if (key < entries_.size()) { + nb::object tmp; + std::swap(entries_[key], tmp); + } + // We destroy the python objects outside of the lock out of an abundance of + // caution. + std::vector to_destroy; + absl::MutexLock lock(mu_); + to_destroy.reserve(thread_local_states_.size()); + for (auto* state : thread_local_states_) { + if (key < state->entries_.size()) { + nb::object tmp; + std::swap(state->entries_[key], tmp); + to_destroy.push_back(std::move(tmp)); + } + } + return 0; +} + +Config::Config(std::string name, nb::object value, bool include_in_jit_key, + bool include_in_trace_context) { + auto& instance = GlobalConfigState::Instance(); + key_ = instance.entries_.size(); + instance.names_.push_back(std::move(name)); + instance.entries_.push_back(std::move(value)); + if (include_in_jit_key) { + instance.include_in_jit_key_.push_back(key_); + } + if (include_in_trace_context) { + instance.include_in_trace_context_.push_back(key_); + } +} + +const std::string& Config::Name() { + return GlobalConfigState::Instance().name(key_); +} + +nb::object Config::GetLocal() { + nb::object result = ThreadLocalConfigState::Instance().Get(key_); + if (!result.is_valid()) { + return GlobalConfigState::Instance().unset(); + } + return result; +} + +nb::object Config::GetGlobal() { + return GlobalConfigState::Instance().Get(key_); +} + +nb::object Config::Get() { + nb::object local = ThreadLocalConfigState::Instance().Get(key_); + if (local.is_valid()) { + return local; + } + return GetGlobal(); +} + +void Config::SetLocal(nb::object value) { + const auto& instance = GlobalConfigState::Instance(); + if (value.ptr() == instance.unset().ptr()) { + value = nb::object(); + } + ThreadLocalConfigState::Instance().Set(key_, std::move(value)); +} + +nb::object Config::SwapLocal(nb::object value) { + const auto& global_instance = GlobalConfigState::Instance(); + auto& instance = ThreadLocalConfigState::Instance(); + auto result = instance.Get(key_); + if (value.ptr() == global_instance.unset().ptr()) { + value = nb::object(); + } + instance.Set(key_, std::move(value)); + if (!result.is_valid()) { + return global_instance.unset(); + } + return result; +} + +void Config::SetGlobal(nb::object value) { + GlobalConfigState::Instance().Set(key_, value); +} + +/* static */ int Config::tp_traverse(PyObject* self, visitproc visit, + void* arg) { + Py_VISIT(Py_TYPE(self)); + if (!nb::inst_ready(self)) { + return 0; + } + Config* c = nb::inst_ptr(self); + // For the purposes of GC, we pretend that this object owns both the global + // and any thread-local values corresponding to this key. + return GlobalConfigState::Instance().tp_traverse(c->key_, self, visit, arg); +} + +/* static */ int Config::tp_clear(PyObject* self) { + Config* c = nb::inst_ptr(self); + return GlobalConfigState::Instance().tp_clear(c->key_, self); +} + +PyType_Slot Config::slots_[] = { + {Py_tp_traverse, reinterpret_cast(Config::tp_traverse)}, + {Py_tp_clear, reinterpret_cast(Config::tp_clear)}, + {0, nullptr}, +}; + +/* static */ const nb::object& Config::UnsetObject() { + return GlobalConfigState::Instance().unset(); +} + +std::vector JitConfigs() { + auto& instance = GlobalConfigState::Instance(); + auto& thread_local_instance = ThreadLocalConfigState::Instance(); + std::vector result; + result.reserve(instance.include_in_jit_key().size()); + for (int i : instance.include_in_jit_key()) { + nb::object local = thread_local_instance.Get(i); + if (local.is_valid()) { + result.push_back(std::move(local)); + } else { + result.push_back(instance.Get(i)); + } + } + return result; +} + +std::vector JitConfigNames() { + auto& instance = GlobalConfigState::Instance(); + std::vector result; + result.reserve(instance.include_in_jit_key().size()); + for (int i : instance.include_in_jit_key()) { + result.push_back(instance.name(i)); + } + return result; +} + +nanobind::tuple TraceContext() { + auto& instance = GlobalConfigState::Instance(); + auto& thread_local_instance = ThreadLocalConfigState::Instance(); + nb::tuple result = nb::steal( + PyTuple_New(instance.include_in_trace_context().size())); + int pos = 0; + for (int i : instance.include_in_trace_context()) { + nb::object local = thread_local_instance.Get(i); + if (local.is_valid()) { + PyTuple_SET_ITEM(result.ptr(), pos, local.release().ptr()); + } else { + nb::object global = instance.Get(i); + PyTuple_SET_ITEM(result.ptr(), pos, global.release().ptr()); + } + ++pos; + } + return result; +} + +void BuildConfigSubmodule(nanobind::module_& m) { + nb::module_ config_module = m.def_submodule("config", "Config library"); + + config_module.attr("unset") = GlobalConfigState::Instance().unset(); + + config_module.attr("_T") = nb::type_var("_T"); + + nb::class_ config(config_module, "Config", + nb::type_slots(Config::slots_), nb::is_generic(), + nb::sig("class Config(typing.Generic[_T])")); + config.def(nb::init(), nb::arg("name"), + nb::arg("value").none(), nb::kw_only(), + nb::arg("include_in_jit_key") = false, + nb::arg("include_in_trace_context") = false, + nb::sig( + // clang-format off + "def __init__(" + "self, " + "name: str, " + "value: _T, *, " + "include_in_jit_key: bool = ..., " + "include_in_trace_context: bool = ..." + ") -> None" + // clang-format on + )); + config.def_prop_ro("value", &Config::Get, nb::sig("def value(self) -> _T")); + config.def_prop_ro("name", &Config::Name); + // TODO(slebedev): All getters and setters should be using _T. + config.def("get_local", &Config::GetLocal, + nb::sig("def get_local(self) -> typing.Any")); + config.def("get_global", &Config::GetGlobal, + nb::sig("def get_global(self) -> _T")); + config.def("set_local", &Config::SetLocal, nb::arg("value").none(), + nb::sig("def set_local(self, value: Any | None) -> None")); + config.def("swap_local", &Config::SwapLocal, nb::arg("value").none(), + nb::sig("def swap_local(self, value: Any | None) -> Any")); + config.def("set_global", &Config::SetGlobal, nb::arg("value").none(), + nb::sig("def set_global(self, value: Any | None) -> None")); + + config_module.def("trace_context", &TraceContext); +} + +} // namespace jax diff --git a/jaxlib/config.h b/jaxlib/config.h new file mode 100644 index 000000000000..aebf6f958da6 --- /dev/null +++ b/jaxlib/config.h @@ -0,0 +1,83 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_CONFIG_H_ +#define JAXLIB_CONFIG_H_ + +#include +#include + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" + +namespace jax { + +// A Config object represents a configurable object with both global and +// thread-local state. This class is wrapped using nanobind and exposed to +// Python. +class Config { + public: + Config(std::string name, nanobind::object value, bool include_in_jit_key, + bool include_in_trace_context); + + // Returns the name of the config. + const std::string& Name(); + + // Returns the thread-local value if it is set, otherwise the global value. + nanobind::object Get(); + + // Returns the global value. + nanobind::object GetGlobal(); + + // Sets the global value. + void SetGlobal(nanobind::object value); + + // Returns the thread-local value. + nanobind::object GetLocal(); + + // Sets the thread-local value. May be `unset`. + void SetLocal(nanobind::object value); + + // Swaps the thread-local value with `value`. Returns the previous value. + // Either may be `unset`. + nanobind::object SwapLocal(nanobind::object value); + + // This class doesn't actually hold any data, but it's the only type + // known to Python. We pretend that this object owns both the global and any + // thread-local values corresponding to this key. + static int tp_traverse(PyObject* self, visitproc visit, void* arg); + static int tp_clear(PyObject* self); + static PyType_Slot slots_[]; + + static const nanobind::object& UnsetObject(); + + private: + int key_; +}; + +// Returns the set of configuration values that should be included in the JIT +// cache key. +std::vector JitConfigs(); + +// The corresponding config names, for debugging. +std::vector JitConfigNames(); + +nanobind::tuple TraceContext(); + +void BuildConfigSubmodule(nanobind::module_& m); + +} // namespace jax + +#endif // JAXLIB_CONFIG_H_ diff --git a/jaxlib/config_test.py b/jaxlib/config_test.py new file mode 100644 index 000000000000..5feda772930a --- /dev/null +++ b/jaxlib/config_test.py @@ -0,0 +1,78 @@ +# Copyright 2024 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import threading + +from absl.testing import absltest + +from jax.jaxlib import xla_client + +config = xla_client._xla.config + + +class ConfigTest(absltest.TestCase): + + def testBasic(self): + if xla_client._version >= 376: + c = config.Config("test", 1) + self.assertEqual(c.name, "test") + else: + c = config.Config(1) + self.assertEqual(c.value, 1) + self.assertEqual(c.get_global(), 1) + self.assertEqual(c.get_local(), config.unset) + + c.set_global(2) + self.assertEqual(c.value, 2) + self.assertEqual(c.get_global(), 2) + self.assertEqual(c.get_local(), config.unset) + + c.set_local(3) + self.assertEqual(c.value, 3) + self.assertEqual(c.get_global(), 2) + self.assertEqual(c.get_local(), 3) + + c.set_global(4) + self.assertEqual(c.value, 3) + self.assertEqual(c.get_global(), 4) + self.assertEqual(c.get_local(), 3) + + c.set_local(config.unset) + self.assertEqual(c.value, 4) + self.assertEqual(c.get_global(), 4) + self.assertEqual(c.get_local(), config.unset) + + def testThreading(self): + if xla_client._version >= 376: + c = config.Config("test", 1) + else: + c = config.Config(1) + + def Body(): + for i in range(100): + c.set_local(i) + self.assertEqual(c.get_local(), i) + self.assertEqual(c.get_global(), 1) + self.assertEqual(c.value, i) + + threads = [threading.Thread(target=Body) for _ in range(4)] + for t in threads: + t.start() + for t in threads: + t.join() + + +if __name__ == "__main__": + absltest.main() diff --git a/jaxlib/cpu/BUILD b/jaxlib/cpu/BUILD index 76934df6c37b..229f630bfb35 100644 --- a/jaxlib/cpu/BUILD +++ b/jaxlib/cpu/BUILD @@ -14,6 +14,7 @@ # JAX is Autograd and XLA +load("@rules_cc//cc:cc_library.bzl", "cc_library") load( "//jaxlib:jax.bzl", "nanobind_extension", @@ -32,6 +33,7 @@ cc_library( name = "lapack_kernels", srcs = ["lapack_kernels.cc"], hdrs = ["lapack_kernels.h"], + # compatible with libtpu copts = ["-fexceptions"], features = ["-use_header_modules"], deps = [ @@ -42,13 +44,13 @@ cc_library( "@com_google_absl//absl/types:span", "@xla//xla/ffi/api:c_api", "@xla//xla/ffi/api:ffi", - "@xla//xla/service:custom_call_status", ], ) cc_library( name = "lapack_kernels_using_lapack", srcs = ["lapack_kernels_using_lapack.cc"], + # compatible with libtpu deps = [":lapack_kernels"], alwayslink = 1, ) @@ -81,13 +83,47 @@ nanobind_extension( cc_library( name = "cpu_kernels", srcs = ["cpu_kernels.cc"], + # compatible with libtpu visibility = ["//visibility:public"], deps = [ ":lapack_kernels", ":lapack_kernels_using_lapack", + ":sparse_kernels", "@xla//xla/ffi/api:c_api", "@xla//xla/ffi/api:ffi", - "@xla//xla/service:custom_call_target_registry", ], alwayslink = 1, ) + +cc_library( + name = "sparse_kernels", + srcs = ["sparse_kernels.cc"], + hdrs = ["sparse_kernels.h"], + # compatible with libtpu + deps = [ + "@eigen_archive//:eigen3", + "@xla//xla/ffi/api:ffi", + ], +) + +nanobind_extension( + name = "_sparse", + srcs = ["sparse.cc"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + enable_stub_generation = False, + features = ["-use_header_modules"], + module_name = "_sparse", + pytype_srcs = [ + "_sparse/__init__.pyi", + ], + deps = [ + ":sparse_kernels", + "//jaxlib:kernel_nanobind_helpers", + "@com_google_absl//absl/base", + "@nanobind", + "@xla//xla/ffi/api:ffi", + ], +) diff --git a/jaxlib/cpu/_lapack/__init__.pyi b/jaxlib/cpu/_lapack/__init__.pyi index 4275d8e48813..f8b9a023b480 100644 --- a/jaxlib/cpu/_lapack/__init__.pyi +++ b/jaxlib/cpu/_lapack/__init__.pyi @@ -17,39 +17,3 @@ from . import eig as eig def initialize() -> None: ... def registrations() -> dict: ... - - -# Old-style LAPACK Workspace Size Queries -def cgesdd_rwork_size(m: int, n: int, compute_uv: int) -> int: ... -def cgesdd_work_size(m: int, n: int, job_opt_compute_uv: bool, job_opt_full_matrices: bool) -> int: ... -def dgesdd_work_size(m: int, n: int, job_opt_compute_uv: bool, job_opt_full_matrices: bool) -> int: ... -def gesdd_iwork_size(m: int, n: int) -> int: ... -def heevd_rwork_size(n: int) -> int: ... -def heevd_work_size(n: int) -> int: ... -def lapack_cgehrd_workspace(lda: int, n: int, ilo: int, ihi: int) -> int: ... -def lapack_cgeqrf_workspace(m: int, n: int) -> int: ... -def lapack_chetrd_workspace(lda: int, n: int) -> int: ... -def lapack_cungqr_workspace(m: int, n: int, k: int) -> int: ... -def lapack_dgehrd_workspace(lda: int, n: int, ilo: int, ihi: int) -> int: ... -def lapack_dgeqrf_workspace(m: int, n: int) -> int: ... -def lapack_dorgqr_workspace(m: int, n: int, k: int) -> int: ... -def lapack_dsytrd_workspace(lda: int, n: int) -> int: ... -def lapack_sgehrd_workspace(lda: int, n: int, ilo: int, ihi: int) -> int: ... -def lapack_sgeqrf_workspace(m: int, n: int) -> int: ... -def lapack_sorgqr_workspace(m: int, n: int, k: int) -> int: ... -def lapack_ssytrd_workspace(lda: int, n: int) -> int: ... -def lapack_zgehrd_workspace(lda: int, n: int, ilo: int, ihi: int) -> int: ... -def lapack_zgeqrf_workspace(m: int, n: int) -> int: ... -def lapack_zhetrd_workspace(lda: int, n: int) -> int: ... -def lapack_zungqr_workspace(m: int, n: int, k: int) -> int: ... -def sgesdd_work_size(m: int, n: int, job_opt_compute_uv: bool, job_opt_full_matrices: bool) -> int: ... -def syevd_iwork_size(n: int) -> int: ... -def syevd_work_size(n: int) -> int: ... -def zgesdd_work_size(m: int, n: int, job_opt_compute_uv: bool, job_opt_full_matrices: bool) -> int: ... - - -# FFI Kernel LAPACK Workspace Size Queries -def lapack_cungqr_workspace_ffi(m: int, n: int, k: int) -> int: ... -def lapack_dorgqr_workspace_ffi(m: int, n: int, k: int) -> int: ... -def lapack_sorgqr_workspace_ffi(m: int, n: int, k: int) -> int: ... -def lapack_zungqr_workspace_ffi(m: int, n: int, k: int) -> int: ... diff --git a/jaxlib/cpu/_sparse/__init__.pyi b/jaxlib/cpu/_sparse/__init__.pyi new file mode 100644 index 000000000000..a82f83b267b7 --- /dev/null +++ b/jaxlib/cpu/_sparse/__init__.pyi @@ -0,0 +1,15 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +def registrations() -> dict: ... diff --git a/jaxlib/cpu/cpu_kernels.cc b/jaxlib/cpu/cpu_kernels.cc index 6ed42496f2f2..bc4d0c74734a 100644 --- a/jaxlib/cpu/cpu_kernels.cc +++ b/jaxlib/cpu/cpu_kernels.cc @@ -16,12 +16,10 @@ limitations under the License. // This file is not used by JAX itself, but exists to assist with running // JAX-generated HLO code from outside of JAX. -#include - #include "jaxlib/cpu/lapack_kernels.h" +#include "jaxlib/cpu/sparse_kernels.h" #include "xla/ffi/api/c_api.h" #include "xla/ffi/api/ffi.h" -#include "xla/service/custom_call_target_registry.h" #define JAX_CPU_REGISTER_HANDLER(name) \ XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), #name, "Host", name); @@ -29,94 +27,6 @@ limitations under the License. namespace jax { namespace { -// Old-style kernels -// TODO(b/344892332): To be removed after the 6M compatibility period is over. - -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("blas_strsm", Trsm::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("blas_dtrsm", Trsm::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("blas_ctrsm", - Trsm>::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("blas_ztrsm", - Trsm>::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_sgetrf", Getrf::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_dgetrf", Getrf::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_cgetrf", - Getrf>::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_zgetrf", - Getrf>::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_sgeqrf", Geqrf::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_dgeqrf", Geqrf::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_cgeqrf", - Geqrf>::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_zgeqrf", - Geqrf>::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_sorgqr", Orgqr::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_dorgqr", Orgqr::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_cungqr", - Orgqr>::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_zungqr", - Orgqr>::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_spotrf", Potrf::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_dpotrf", Potrf::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_cpotrf", - Potrf>::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_zpotrf", - Potrf>::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_sgesdd", - RealGesdd::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_dgesdd", - RealGesdd::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( - "lapack_cgesdd", ComplexGesdd>::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( - "lapack_zgesdd", ComplexGesdd>::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_ssyevd", - RealSyevd::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_dsyevd", - RealSyevd::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( - "lapack_cheevd", ComplexHeevd>::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( - "lapack_zheevd", ComplexHeevd>::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_sgeev", - RealGeev::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_dgeev", - RealGeev::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( - "lapack_cgeev", ComplexGeev>::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( - "lapack_zgeev", ComplexGeev>::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_sgees", - RealGees::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_dgees", - RealGees::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( - "lapack_cgees", ComplexGees>::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( - "lapack_zgees", ComplexGees>::Kernel, "Host"); - -// FFI Kernels - JAX_CPU_REGISTER_HANDLER(lapack_strsm_ffi); JAX_CPU_REGISTER_HANDLER(lapack_dtrsm_ffi); JAX_CPU_REGISTER_HANDLER(lapack_ctrsm_ffi); @@ -174,6 +84,8 @@ JAX_CPU_REGISTER_HANDLER(lapack_dgtsv_ffi); JAX_CPU_REGISTER_HANDLER(lapack_cgtsv_ffi); JAX_CPU_REGISTER_HANDLER(lapack_zgtsv_ffi); +JAX_CPU_REGISTER_HANDLER(cpu_csr_sparse_dense_ffi); + #undef JAX_CPU_REGISTER_HANDLER } // namespace diff --git a/jaxlib/cpu/lapack.cc b/jaxlib/cpu/lapack.cc index c104019777e5..9e89a69cd632 100644 --- a/jaxlib/cpu/lapack.cc +++ b/jaxlib/cpu/lapack.cc @@ -13,10 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include - -#include "nanobind/nanobind.h" #include "absl/base/call_once.h" +#include "nanobind/nanobind.h" #include "jaxlib/cpu/lapack_kernels.h" #include "jaxlib/kernel_nanobind_helpers.h" @@ -29,6 +27,9 @@ using ::xla::ffi::DataType; void GetLapackKernelsFromScipy() { static absl::once_flag initialized; + if (lapack_kernels_initialized) { + return; + } // For reasons I'm not entirely sure of, if the import_ call is done inside // the call_once scope, we sometimes observe deadlocks in the test suite. // However it probably doesn't do much harm to just import them a second time, @@ -45,10 +46,6 @@ void GetLapackKernelsFromScipy() { return nb::cast(blas_capi[name]).data(); }; - AssignKernelFn>(blas_ptr("strsm")); - AssignKernelFn>(blas_ptr("dtrsm")); - AssignKernelFn>>(blas_ptr("ctrsm")); - AssignKernelFn>>(blas_ptr("ztrsm")); AssignKernelFn>(blas_ptr("strsm")); AssignKernelFn>(blas_ptr("dtrsm")); AssignKernelFn>(blas_ptr("ctrsm")); @@ -58,19 +55,11 @@ void GetLapackKernelsFromScipy() { auto lapack_ptr = [&](const char* name) { return nb::cast(lapack_capi[name]).data(); }; - AssignKernelFn>(lapack_ptr("sgetrf")); - AssignKernelFn>(lapack_ptr("dgetrf")); - AssignKernelFn>>(lapack_ptr("cgetrf")); - AssignKernelFn>>(lapack_ptr("zgetrf")); AssignKernelFn>(lapack_ptr("sgetrf")); AssignKernelFn>(lapack_ptr("dgetrf")); AssignKernelFn>(lapack_ptr("cgetrf")); AssignKernelFn>(lapack_ptr("zgetrf")); - AssignKernelFn>(lapack_ptr("sgeqrf")); - AssignKernelFn>(lapack_ptr("dgeqrf")); - AssignKernelFn>>(lapack_ptr("cgeqrf")); - AssignKernelFn>>(lapack_ptr("zgeqrf")); AssignKernelFn>(lapack_ptr("sgeqrf")); AssignKernelFn>(lapack_ptr("dgeqrf")); AssignKernelFn>(lapack_ptr("cgeqrf")); @@ -85,28 +74,16 @@ void GetLapackKernelsFromScipy() { AssignKernelFn>( lapack_ptr("zgeqp3")); - AssignKernelFn>(lapack_ptr("sorgqr")); - AssignKernelFn>(lapack_ptr("dorgqr")); - AssignKernelFn>>(lapack_ptr("cungqr")); - AssignKernelFn>>(lapack_ptr("zungqr")); AssignKernelFn>(lapack_ptr("sorgqr")); AssignKernelFn>(lapack_ptr("dorgqr")); AssignKernelFn>(lapack_ptr("cungqr")); AssignKernelFn>(lapack_ptr("zungqr")); - AssignKernelFn>(lapack_ptr("spotrf")); - AssignKernelFn>(lapack_ptr("dpotrf")); - AssignKernelFn>>(lapack_ptr("cpotrf")); - AssignKernelFn>>(lapack_ptr("zpotrf")); AssignKernelFn>(lapack_ptr("spotrf")); AssignKernelFn>(lapack_ptr("dpotrf")); AssignKernelFn>(lapack_ptr("cpotrf")); AssignKernelFn>(lapack_ptr("zpotrf")); - AssignKernelFn>(lapack_ptr("sgesdd")); - AssignKernelFn>(lapack_ptr("dgesdd")); - AssignKernelFn>>(lapack_ptr("cgesdd")); - AssignKernelFn>>(lapack_ptr("zgesdd")); AssignKernelFn>(lapack_ptr("sgesdd")); AssignKernelFn>(lapack_ptr("dgesdd")); AssignKernelFn>(lapack_ptr("cgesdd")); @@ -116,10 +93,6 @@ void GetLapackKernelsFromScipy() { AssignKernelFn>(lapack_ptr("cgesvd")); AssignKernelFn>(lapack_ptr("zgesvd")); - AssignKernelFn>(lapack_ptr("ssyevd")); - AssignKernelFn>(lapack_ptr("dsyevd")); - AssignKernelFn>>(lapack_ptr("cheevd")); - AssignKernelFn>>(lapack_ptr("zheevd")); AssignKernelFn>( lapack_ptr("ssyevd")); AssignKernelFn>( @@ -129,10 +102,6 @@ void GetLapackKernelsFromScipy() { AssignKernelFn>( lapack_ptr("zheevd")); - AssignKernelFn>(lapack_ptr("sgeev")); - AssignKernelFn>(lapack_ptr("dgeev")); - AssignKernelFn>>(lapack_ptr("cgeev")); - AssignKernelFn>>(lapack_ptr("zgeev")); AssignKernelFn>(lapack_ptr("sgeev")); AssignKernelFn>(lapack_ptr("dgeev")); AssignKernelFn>( @@ -140,10 +109,6 @@ void GetLapackKernelsFromScipy() { AssignKernelFn>( lapack_ptr("zgeev")); - AssignKernelFn>(lapack_ptr("sgees")); - AssignKernelFn>(lapack_ptr("dgees")); - AssignKernelFn>>(lapack_ptr("cgees")); - AssignKernelFn>>(lapack_ptr("zgees")); AssignKernelFn>(lapack_ptr("sgees")); AssignKernelFn>(lapack_ptr("dgees")); AssignKernelFn>( @@ -151,10 +116,6 @@ void GetLapackKernelsFromScipy() { AssignKernelFn>( lapack_ptr("zgees")); - AssignKernelFn>(lapack_ptr("sgehrd")); - AssignKernelFn>(lapack_ptr("dgehrd")); - AssignKernelFn>>(lapack_ptr("cgehrd")); - AssignKernelFn>>(lapack_ptr("zgehrd")); AssignKernelFn>( lapack_ptr("sgehrd")); AssignKernelFn>( @@ -164,10 +125,6 @@ void GetLapackKernelsFromScipy() { AssignKernelFn>( lapack_ptr("zgehrd")); - AssignKernelFn>(lapack_ptr("ssytrd")); - AssignKernelFn>(lapack_ptr("dsytrd")); - AssignKernelFn>>(lapack_ptr("chetrd")); - AssignKernelFn>>(lapack_ptr("zhetrd")); AssignKernelFn>(lapack_ptr("ssytrd")); AssignKernelFn>(lapack_ptr("dsytrd")); AssignKernelFn>(lapack_ptr("chetrd")); @@ -182,74 +139,6 @@ void GetLapackKernelsFromScipy() { nb::dict Registrations() { nb::dict dict; - dict["blas_strsm"] = EncapsulateFunction(Trsm::Kernel); - dict["blas_dtrsm"] = EncapsulateFunction(Trsm::Kernel); - dict["blas_ctrsm"] = EncapsulateFunction(Trsm>::Kernel); - dict["blas_ztrsm"] = EncapsulateFunction(Trsm>::Kernel); - dict["lapack_sgetrf"] = EncapsulateFunction(Getrf::Kernel); - dict["lapack_dgetrf"] = EncapsulateFunction(Getrf::Kernel); - dict["lapack_cgetrf"] = - EncapsulateFunction(Getrf>::Kernel); - dict["lapack_zgetrf"] = - EncapsulateFunction(Getrf>::Kernel); - dict["lapack_sgeqrf"] = EncapsulateFunction(Geqrf::Kernel); - dict["lapack_dgeqrf"] = EncapsulateFunction(Geqrf::Kernel); - dict["lapack_cgeqrf"] = - EncapsulateFunction(Geqrf>::Kernel); - dict["lapack_zgeqrf"] = - EncapsulateFunction(Geqrf>::Kernel); - dict["lapack_sorgqr"] = EncapsulateFunction(Orgqr::Kernel); - dict["lapack_dorgqr"] = EncapsulateFunction(Orgqr::Kernel); - dict["lapack_cungqr"] = - EncapsulateFunction(Orgqr>::Kernel); - dict["lapack_zungqr"] = - EncapsulateFunction(Orgqr>::Kernel); - dict["lapack_spotrf"] = EncapsulateFunction(Potrf::Kernel); - dict["lapack_dpotrf"] = EncapsulateFunction(Potrf::Kernel); - dict["lapack_cpotrf"] = - EncapsulateFunction(Potrf>::Kernel); - dict["lapack_zpotrf"] = - EncapsulateFunction(Potrf>::Kernel); - dict["lapack_sgesdd"] = EncapsulateFunction(RealGesdd::Kernel); - dict["lapack_dgesdd"] = EncapsulateFunction(RealGesdd::Kernel); - dict["lapack_cgesdd"] = - EncapsulateFunction(ComplexGesdd>::Kernel); - dict["lapack_zgesdd"] = - EncapsulateFunction(ComplexGesdd>::Kernel); - dict["lapack_ssyevd"] = EncapsulateFunction(RealSyevd::Kernel); - dict["lapack_dsyevd"] = EncapsulateFunction(RealSyevd::Kernel); - dict["lapack_cheevd"] = - EncapsulateFunction(ComplexHeevd>::Kernel); - dict["lapack_zheevd"] = - EncapsulateFunction(ComplexHeevd>::Kernel); - dict["lapack_sgeev"] = EncapsulateFunction(RealGeev::Kernel); - dict["lapack_dgeev"] = EncapsulateFunction(RealGeev::Kernel); - dict["lapack_cgeev"] = - EncapsulateFunction(ComplexGeev>::Kernel); - dict["lapack_zgeev"] = - EncapsulateFunction(ComplexGeev>::Kernel); - - dict["lapack_sgees"] = EncapsulateFunction(RealGees::Kernel); - dict["lapack_dgees"] = EncapsulateFunction(RealGees::Kernel); - dict["lapack_cgees"] = - EncapsulateFunction(ComplexGees>::Kernel); - dict["lapack_zgees"] = - EncapsulateFunction(ComplexGees>::Kernel); - - dict["lapack_sgehrd"] = EncapsulateFunction(Gehrd::Kernel); - dict["lapack_dgehrd"] = EncapsulateFunction(Gehrd::Kernel); - dict["lapack_cgehrd"] = - EncapsulateFunction(Gehrd>::Kernel); - dict["lapack_zgehrd"] = - EncapsulateFunction(Gehrd>::Kernel); - - dict["lapack_ssytrd"] = EncapsulateFunction(Sytrd::Kernel); - dict["lapack_dsytrd"] = EncapsulateFunction(Sytrd::Kernel); - dict["lapack_chetrd"] = - EncapsulateFunction(Sytrd>::Kernel); - dict["lapack_zhetrd"] = - EncapsulateFunction(Sytrd>::Kernel); - dict["lapack_strsm_ffi"] = EncapsulateFunction(lapack_strsm_ffi); dict["lapack_dtrsm_ffi"] = EncapsulateFunction(lapack_dtrsm_ffi); dict["lapack_ctrsm_ffi"] = EncapsulateFunction(lapack_ctrsm_ffi); @@ -335,73 +224,6 @@ NB_MODULE(_lapack, m) { nb::enum_(schur, "Sort") .value("kNoSortEigenvalues", schur::Sort::kNoSortEigenvalues) .value("kSortEigenvalues", schur::Sort::kSortEigenvalues); - - // Old-style LAPACK Workspace Size Queries - m.def("lapack_sgeqrf_workspace", &Geqrf::Workspace, nb::arg("m"), - nb::arg("n")); - m.def("lapack_dgeqrf_workspace", &Geqrf::Workspace, nb::arg("m"), - nb::arg("n")); - m.def("lapack_cgeqrf_workspace", &Geqrf>::Workspace, - nb::arg("m"), nb::arg("n")); - m.def("lapack_zgeqrf_workspace", &Geqrf>::Workspace, - nb::arg("m"), nb::arg("n")); - m.def("lapack_sorgqr_workspace", &Orgqr::Workspace, nb::arg("m"), - nb::arg("n"), nb::arg("k")); - m.def("lapack_dorgqr_workspace", &Orgqr::Workspace, nb::arg("m"), - nb::arg("n"), nb::arg("k")); - m.def("lapack_cungqr_workspace", &Orgqr>::Workspace, - nb::arg("m"), nb::arg("n"), nb::arg("k")); - m.def("lapack_zungqr_workspace", &Orgqr>::Workspace, - nb::arg("m"), nb::arg("n"), nb::arg("k")); - m.def("gesdd_iwork_size", &GesddIworkSize, nb::arg("m"), nb::arg("n")); - m.def("sgesdd_work_size", &RealGesdd::Workspace, nb::arg("m"), - nb::arg("n"), nb::arg("job_opt_compute_uv"), - nb::arg("job_opt_full_matrices")); - m.def("dgesdd_work_size", &RealGesdd::Workspace, nb::arg("m"), - nb::arg("n"), nb::arg("job_opt_compute_uv"), - nb::arg("job_opt_full_matrices")); - m.def("cgesdd_rwork_size", &ComplexGesddRworkSize, nb::arg("m"), nb::arg("n"), - nb::arg("compute_uv")); - m.def("cgesdd_work_size", &ComplexGesdd>::Workspace, - nb::arg("m"), nb::arg("n"), nb::arg("job_opt_compute_uv"), - nb::arg("job_opt_full_matrices")); - m.def("zgesdd_work_size", &ComplexGesdd>::Workspace, - nb::arg("m"), nb::arg("n"), nb::arg("job_opt_compute_uv"), - nb::arg("job_opt_full_matrices")); - m.def("syevd_work_size", &SyevdWorkSize, nb::arg("n")); - m.def("syevd_iwork_size", &SyevdIworkSize, nb::arg("n")); - m.def("heevd_work_size", &HeevdWorkSize, nb::arg("n")); - m.def("heevd_rwork_size", &HeevdRworkSize, nb::arg("n")); - - m.def("lapack_sgehrd_workspace", &Gehrd::Workspace, nb::arg("lda"), - nb::arg("n"), nb::arg("ilo"), nb::arg("ihi")); - m.def("lapack_dgehrd_workspace", &Gehrd::Workspace, nb::arg("lda"), - nb::arg("n"), nb::arg("ilo"), nb::arg("ihi")); - m.def("lapack_cgehrd_workspace", &Gehrd>::Workspace, - nb::arg("lda"), nb::arg("n"), nb::arg("ilo"), nb::arg("ihi")); - m.def("lapack_zgehrd_workspace", &Gehrd>::Workspace, - nb::arg("lda"), nb::arg("n"), nb::arg("ilo"), nb::arg("ihi")); - m.def("lapack_ssytrd_workspace", &Sytrd::Workspace, nb::arg("lda"), - nb::arg("n")); - m.def("lapack_dsytrd_workspace", &Sytrd::Workspace, nb::arg("lda"), - nb::arg("n")); - m.def("lapack_chetrd_workspace", &Sytrd>::Workspace, - nb::arg("lda"), nb::arg("n")); - m.def("lapack_zhetrd_workspace", &Sytrd>::Workspace, - nb::arg("lda"), nb::arg("n")); - // FFI Kernel LAPACK Workspace Size Queries - m.def("lapack_sorgqr_workspace_ffi", - &OrthogonalQr::GetWorkspaceSize, nb::arg("m"), - nb::arg("n"), nb::arg("k")); - m.def("lapack_dorgqr_workspace_ffi", - &OrthogonalQr::GetWorkspaceSize, nb::arg("m"), - nb::arg("n"), nb::arg("k")); - m.def("lapack_cungqr_workspace_ffi", - &OrthogonalQr::GetWorkspaceSize, nb::arg("m"), - nb::arg("n"), nb::arg("k")); - m.def("lapack_zungqr_workspace_ffi", - &OrthogonalQr::GetWorkspaceSize, nb::arg("m"), - nb::arg("n"), nb::arg("k")); } } // namespace diff --git a/jaxlib/cpu/lapack_kernels.cc b/jaxlib/cpu/lapack_kernels.cc index ddc93261eeb5..2ede89572365 100644 --- a/jaxlib/cpu/lapack_kernels.cc +++ b/jaxlib/cpu/lapack_kernels.cc @@ -18,14 +18,11 @@ limitations under the License. #include #include #include -#include #include -#include #include -#include #include #include -#include +#include #include "absl/algorithm/container.h" #include "absl/base/dynamic_annotations.h" @@ -34,45 +31,27 @@ limitations under the License. #include "jaxlib/ffi_helpers.h" #include "xla/ffi/api/c_api.h" #include "xla/ffi/api/ffi.h" -#include "xla/service/custom_call_status.h" static_assert(sizeof(jax::lapack_int) == sizeof(int32_t), "Expected LAPACK integers to be 32-bit"); namespace ffi = xla::ffi; -#define REGISTER_CHAR_ENUM_ATTR_DECODING(type) \ - std::optional xla::ffi::AttrDecoding::Decode( \ - XLA_FFI_AttrType attr_type, void* attr, DiagnosticEngine& diagnostic) { \ - if (attr_type != XLA_FFI_AttrType_SCALAR) [[unlikely]] { \ - return diagnostic.Emit("Wrong attribute type: expected ") \ - << XLA_FFI_AttrType_SCALAR << " but got" << attr_type; \ - } \ - auto* scalar = reinterpret_cast(attr); \ - if (scalar->dtype != XLA_FFI_DataType_U8) [[unlikely]] { \ - return diagnostic.Emit("Wrong scalar data type: expected ") \ - << XLA_FFI_DataType_U8 << " but got " << scalar->dtype; \ - } \ - auto underlying = \ - *reinterpret_cast*>(scalar->value); \ - return static_cast(underlying); \ - } - -REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Side); -REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Transpose); -REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Diag); -REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::UpLo); -REGISTER_CHAR_ENUM_ATTR_DECODING(jax::svd::ComputationMode); -REGISTER_CHAR_ENUM_ATTR_DECODING(jax::eig::ComputationMode); -REGISTER_CHAR_ENUM_ATTR_DECODING(jax::schur::ComputationMode); -REGISTER_CHAR_ENUM_ATTR_DECODING(jax::schur::Sort); - -#undef REGISTER_CHAR_ENUM_ATTR_DECODING +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(jax::MatrixParams::Side); +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(jax::MatrixParams::Transpose); +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(jax::MatrixParams::Diag); +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(jax::MatrixParams::UpLo); +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(jax::svd::ComputationMode); +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(jax::eig::ComputationMode); +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(jax::schur::ComputationMode); +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(jax::schur::Sort); namespace jax { +bool lapack_kernels_initialized = false; + template -inline T CastNoOverflow(int64_t value, const std::string& source = __FILE__) { +inline T CastNoOverflow(int64_t value, std::string_view source = __FILE__) { auto result = MaybeCastNoOverflow(value, source); if (!result.ok()) { throw std::overflow_error{std::string(result.status().message())}; @@ -90,69 +69,11 @@ void CopyIfDiffBuffer(ffi::Buffer x, ffi::ResultBuffer x_out) { //== Triangular System Solver ==// -// lapack trsm - -template -typename Trsm::FnType* Trsm::fn = nullptr; - -template -void Trsm::Kernel(void* out, void** data, XlaCustomCallStatus*) { - int32_t left_side = *reinterpret_cast(data[0]); - int32_t lower = *reinterpret_cast(data[1]); - int32_t trans_a = *reinterpret_cast(data[2]); - int32_t diag = *reinterpret_cast(data[3]); - int m = *reinterpret_cast(data[4]); - int n = *reinterpret_cast(data[5]); - int batch = *reinterpret_cast(data[6]); - T* alpha = reinterpret_cast(data[7]); - T* a = reinterpret_cast(data[8]); - T* b = reinterpret_cast(data[9]); - - T* x = reinterpret_cast(out); - if (x != b) { - std::memcpy(x, b, - static_cast(batch) * static_cast(m) * - static_cast(n) * sizeof(T)); - } - - char cside = left_side ? 'L' : 'R'; - char cuplo = lower ? 'L' : 'U'; - char ctransa = 'N'; - if (trans_a == 1) { - ctransa = 'T'; - } else if (trans_a == 2) { - ctransa = 'C'; - } - char cdiag = diag ? 'U' : 'N'; - int lda = left_side ? m : n; - int ldb = m; - - int64_t x_plus = static_cast(m) * static_cast(n); - int64_t a_plus = static_cast(lda) * static_cast(lda); - - for (int i = 0; i < batch; ++i) { - fn(&cside, &cuplo, &ctransa, &cdiag, &m, &n, alpha, a, &lda, x, &ldb); - x += x_plus; - a += a_plus; - } -} - -template struct Trsm; -template struct Trsm; -template struct Trsm>; -template struct Trsm>; - -// FFI Kernel - template ffi::Error TriMatrixEquationSolver::Kernel( - ffi::Buffer x, ffi::Buffer y, - // TODO(b/397715595): Remove RemainingArgs no earlier than 180 days after - // the release of JAX 0.5.2. - ffi::RemainingArgs, - ffi::ResultBuffer y_out, MatrixParams::Side side, - MatrixParams::UpLo uplo, MatrixParams::Transpose trans_x, - MatrixParams::Diag diag) { + ffi::Buffer x, ffi::Buffer y, ffi::ResultBuffer y_out, + MatrixParams::Side side, MatrixParams::UpLo uplo, + MatrixParams::Transpose trans_x, MatrixParams::Diag diag) { CopyIfDiffBuffer(y, y_out); FFI_ASSIGN_OR_RETURN((auto [batch_count, y_rows, y_cols]), SplitBatch2D(y.dimensions())); @@ -189,42 +110,6 @@ template struct TriMatrixEquationSolver; //== LU Decomposition ==// -// lapack getrf - -template -typename Getrf::FnType* Getrf::fn = nullptr; - -template -void Getrf::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) { - int b = *(reinterpret_cast(data[0])); - int m = *(reinterpret_cast(data[1])); - int n = *(reinterpret_cast(data[2])); - const T* a_in = reinterpret_cast(data[3]); - - void** out = reinterpret_cast(out_tuple); - T* a_out = reinterpret_cast(out[0]); - int* ipiv = reinterpret_cast(out[1]); - int* info = reinterpret_cast(out[2]); - if (a_out != a_in) { - std::memcpy(a_out, a_in, - static_cast(b) * static_cast(m) * - static_cast(n) * sizeof(T)); - } - for (int i = 0; i < b; ++i) { - fn(&m, &n, a_out, &m, ipiv, info); - a_out += static_cast(m) * static_cast(n); - ipiv += std::min(m, n); - ++info; - } -} - -template struct Getrf; -template struct Getrf; -template struct Getrf>; -template struct Getrf>; - -// FFI Kernel - template ffi::Error LuDecomposition::Kernel( ffi::Buffer x, ffi::ResultBuffer x_out, @@ -261,55 +146,6 @@ template struct LuDecomposition; //== QR Factorization ==// -// lapack geqrf - -template -typename Geqrf::FnType* Geqrf::fn = nullptr; - -template -void Geqrf::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) { - int b = *(reinterpret_cast(data[0])); - int m = *(reinterpret_cast(data[1])); - int n = *(reinterpret_cast(data[2])); - int lwork = *(reinterpret_cast(data[3])); - const T* a_in = reinterpret_cast(data[4]); - - void** out = reinterpret_cast(out_tuple); - T* a_out = reinterpret_cast(out[0]); - T* tau = reinterpret_cast(out[1]); - int* info = reinterpret_cast(out[2]); - T* work = reinterpret_cast(out[3]); - - if (a_out != a_in) { - std::memcpy(a_out, a_in, - static_cast(b) * static_cast(m) * - static_cast(n) * sizeof(T)); - } - - for (int i = 0; i < b; ++i) { - fn(&m, &n, a_out, &m, tau, work, &lwork, info); - a_out += static_cast(m) * static_cast(n); - tau += std::min(m, n); - ++info; - } -} - -template -int64_t Geqrf::Workspace(lapack_int m, lapack_int n) { - T work = 0; - lapack_int lwork = -1; - lapack_int info = 0; - fn(&m, &n, nullptr, &m, nullptr, &work, &lwork, &info); - return info == 0 ? static_cast(std::real(work)) : -1; -} - -template struct Geqrf; -template struct Geqrf; -template struct Geqrf>; -template struct Geqrf>; - -// FFI Kernel - template ffi::Error QrFactorization::Kernel(ffi::Buffer x, ffi::ResultBuffer x_out, @@ -430,56 +266,6 @@ template struct PivotingQrFactorization; //== Orthogonal QR ==// //== Computes orthogonal matrix Q from QR Decomposition ==// -// lapack orgqr - -template -typename Orgqr::FnType* Orgqr::fn = nullptr; - -template -void Orgqr::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) { - int b = *(reinterpret_cast(data[0])); - int m = *(reinterpret_cast(data[1])); - int n = *(reinterpret_cast(data[2])); - int k = *(reinterpret_cast(data[3])); - int lwork = *(reinterpret_cast(data[4])); - const T* a_in = reinterpret_cast(data[5]); - T* tau = reinterpret_cast(data[6]); - - void** out = reinterpret_cast(out_tuple); - T* a_out = reinterpret_cast(out[0]); - int* info = reinterpret_cast(out[1]); - T* work = reinterpret_cast(out[2]); - - if (a_out != a_in) { - std::memcpy(a_out, a_in, - static_cast(b) * static_cast(m) * - static_cast(n) * sizeof(T)); - } - - for (int i = 0; i < b; ++i) { - fn(&m, &n, &k, a_out, &m, tau, work, &lwork, info); - a_out += static_cast(m) * static_cast(n); - tau += k; - ++info; - } -} - -template -int64_t Orgqr::Workspace(int m, int n, int k) { - T work = 0; - int lwork = -1; - int info = 0; - fn(&m, &n, &k, nullptr, &m, nullptr, &work, &lwork, &info); - return info ? -1 : static_cast(std::real(work)); -} - -template struct Orgqr; -template struct Orgqr; -template struct Orgqr>; -template struct Orgqr>; - -// FFI Kernel - template ffi::Error OrthogonalQr::Kernel(ffi::Buffer x, ffi::Buffer tau, @@ -535,42 +321,6 @@ template struct OrthogonalQr; //== Cholesky Factorization ==// -// lapack potrf - -template -typename Potrf::FnType* Potrf::fn = nullptr; - -template -void Potrf::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) { - int32_t lower = *(reinterpret_cast(data[0])); - int b = *(reinterpret_cast(data[1])); - int n = *(reinterpret_cast(data[2])); - const T* a_in = reinterpret_cast(data[3]); - char uplo = lower ? 'L' : 'U'; - - void** out = reinterpret_cast(out_tuple); - T* a_out = reinterpret_cast(out[0]); - int* info = reinterpret_cast(out[1]); - if (a_out != a_in) { - std::memcpy(a_out, a_in, - static_cast(b) * static_cast(n) * - static_cast(n) * sizeof(T)); - } - - for (int i = 0; i < b; ++i) { - fn(&uplo, &n, a_out, &n, info); - a_out += static_cast(n) * static_cast(n); - ++info; - } -} - -template struct Potrf; -template struct Potrf; -template struct Potrf>; -template struct Potrf>; - -// FFI Kernel - template ffi::Error CholeskyFactorization::Kernel( ffi::Buffer x, MatrixParams::UpLo uplo, @@ -604,162 +354,6 @@ template struct CholeskyFactorization; //== Singular Value Decomposition (SVD) ==// //== using a divide and conquer method ==// -// lapack gesdd - -static char GesddJobz(bool job_opt_compute_uv, bool job_opt_full_matrices) { - if (!job_opt_compute_uv) { - return 'N'; - } else if (!job_opt_full_matrices) { - return 'S'; - } - return 'A'; -} - -lapack_int GesddIworkSize(int64_t m, int64_t n) { - return CastNoOverflow(8 * std::min(m, n), "gesdd iwork"); -} - -template -typename RealGesdd::FnType* RealGesdd::fn = nullptr; - -template -void RealGesdd::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) { - int32_t job_opt_full_matrices = *(reinterpret_cast(data[0])); - int32_t job_opt_compute_uv = *(reinterpret_cast(data[1])); - int b = *(reinterpret_cast(data[2])); - int m = *(reinterpret_cast(data[3])); - int n = *(reinterpret_cast(data[4])); - int lwork = *(reinterpret_cast(data[5])); - T* a_in = reinterpret_cast(data[6]); - - void** out = reinterpret_cast(out_tuple); - T* a_out = reinterpret_cast(out[0]); - T* s = reinterpret_cast(out[1]); - T* u = reinterpret_cast(out[2]); - T* vt = reinterpret_cast(out[3]); - int* info = reinterpret_cast(out[4]); - int* iwork = reinterpret_cast(out[5]); - T* work = reinterpret_cast(out[6]); - - if (a_out != a_in) { - std::memcpy(a_out, a_in, - static_cast(b) * static_cast(m) * - static_cast(n) * sizeof(T)); - } - - char jobz = GesddJobz(job_opt_compute_uv, job_opt_full_matrices); - - int lda = m; - int ldu = m; - int tdu = job_opt_full_matrices ? m : std::min(m, n); - int ldvt = job_opt_full_matrices ? n : std::min(m, n); - - for (int i = 0; i < b; ++i) { - fn(&jobz, &m, &n, a_out, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, iwork, - info); - a_out += static_cast(m) * n; - s += std::min(m, n); - u += static_cast(m) * tdu; - vt += static_cast(ldvt) * n; - ++info; - } -} - -template -int64_t RealGesdd::Workspace(lapack_int m, lapack_int n, - bool job_opt_compute_uv, - bool job_opt_full_matrices) { - T work = 0; - int lwork = -1; - int info = 0; - int ldvt = job_opt_full_matrices ? n : std::min(m, n); - char jobz = GesddJobz(job_opt_compute_uv, job_opt_full_matrices); - fn(&jobz, &m, &n, nullptr, &m, nullptr, nullptr, &m, nullptr, &ldvt, &work, - &lwork, nullptr, &info); - return info ? -1 : static_cast(work); -} - -lapack_int ComplexGesddRworkSize(int64_t m, int64_t n, int compute_uv) { - int64_t mn = std::min(m, n); - if (compute_uv == 0) { - return CastNoOverflow(7 * mn, "complex gesdd rwork"); - } - int64_t mx = std::max(m, n); - return CastNoOverflow( - std::max(5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn), - "complex gesdd rwork"); -} - -template -typename ComplexGesdd::FnType* ComplexGesdd::fn = nullptr; - -template -void ComplexGesdd::Kernel(void* out_tuple, void** data, - XlaCustomCallStatus*) { - int32_t job_opt_full_matrices = *(reinterpret_cast(data[0])); - int32_t job_opt_compute_uv = *(reinterpret_cast(data[1])); - int b = *(reinterpret_cast(data[2])); - int m = *(reinterpret_cast(data[3])); - int n = *(reinterpret_cast(data[4])); - int lwork = *(reinterpret_cast(data[5])); - T* a_in = reinterpret_cast(data[6]); - - void** out = reinterpret_cast(out_tuple); - T* a_out = reinterpret_cast(out[0]); - typename T::value_type* s = reinterpret_cast(out[1]); - T* u = reinterpret_cast(out[2]); - T* vt = reinterpret_cast(out[3]); - int* info = reinterpret_cast(out[4]); - int* iwork = reinterpret_cast(out[5]); - typename T::value_type* rwork = - reinterpret_cast(out[6]); - T* work = reinterpret_cast(out[7]); - - if (a_out != a_in) { - std::memcpy(a_out, a_in, - static_cast(b) * static_cast(m) * - static_cast(n) * sizeof(T)); - } - - char jobz = GesddJobz(job_opt_compute_uv, job_opt_full_matrices); - - int lda = m; - int ldu = m; - int tdu = job_opt_full_matrices ? m : std::min(m, n); - int ldvt = job_opt_full_matrices ? n : std::min(m, n); - - for (int i = 0; i < b; ++i) { - fn(&jobz, &m, &n, a_out, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, rwork, - iwork, info); - a_out += static_cast(m) * n; - s += std::min(m, n); - u += static_cast(m) * tdu; - vt += static_cast(ldvt) * n; - ++info; - } -} - -template -int64_t ComplexGesdd::Workspace(lapack_int m, lapack_int n, - bool job_opt_compute_uv, - bool job_opt_full_matrices) { - T work = 0; - int lwork = -1; - int info = 0; - int ldvt = job_opt_full_matrices ? n : std::min(m, n); - char jobz = GesddJobz(job_opt_compute_uv, job_opt_full_matrices); - fn(&jobz, &m, &n, nullptr, &m, nullptr, nullptr, &m, nullptr, &ldvt, &work, - &lwork, nullptr, nullptr, &info); - return info ? -1 : static_cast(work.real()); -} - -template struct RealGesdd; -template struct RealGesdd; -template struct ComplexGesdd>; -template struct ComplexGesdd>; - -// FFI Kernel - namespace internal { template @@ -949,16 +543,16 @@ static ffi::Error SvdQRKernel( for (int64_t i = 0; i < batch_count; ++i) { if constexpr (ffi::IsComplexType()) { - svd::SVDQRType::fn(&mode_v, &mode_v, &x_rows_v, &x_cols_v, x_out_data, - &x_leading_dim_v, singular_values_data, u_data, - &u_leading_dim_v, vt_data, &vt_leading_dim_v, - work_data.get(), &workspace_dim_v, rwork.get(), - info_data); + svd::SVDQRType::fn(&mode_v, &mode_v, &x_rows_v, &x_cols_v, + x_out_data, &x_leading_dim_v, + singular_values_data, u_data, &u_leading_dim_v, + vt_data, &vt_leading_dim_v, work_data.get(), + &workspace_dim_v, rwork.get(), info_data); } else { - svd::SVDQRType::fn(&mode_v, &mode_v, &x_rows_v, &x_cols_v, x_out_data, - &x_leading_dim_v, singular_values_data, u_data, - &u_leading_dim_v, vt_data, &vt_leading_dim_v, - work_data.get(), &workspace_dim_v, info_data); + svd::SVDQRType::fn( + &mode_v, &mode_v, &x_rows_v, &x_cols_v, x_out_data, &x_leading_dim_v, + singular_values_data, u_data, &u_leading_dim_v, vt_data, + &vt_leading_dim_v, work_data.get(), &workspace_dim_v, info_data); } x_out_data += x_out_step; singular_values_data += singular_values_step; @@ -970,9 +564,8 @@ static ffi::Error SvdQRKernel( } template -static absl::StatusOr SvdQRGetWorkspaceSize(lapack_int x_rows, - lapack_int x_cols, - svd::ComputationMode mode) { +static absl::StatusOr SvdQRGetWorkspaceSize( + lapack_int x_rows, lapack_int x_cols, svd::ComputationMode mode) { ffi::NativeType optimal_size = {}; lapack_int info = 0; lapack_int workspace_query = -1; @@ -994,7 +587,8 @@ static absl::StatusOr SvdQRGetWorkspaceSize(lapack_int x_rows, &u_leading_dim_v, nullptr, &vt_leading_dim_v, &optimal_size, &workspace_query, &info); } - return info == 0 ? MaybeCastNoOverflow(std::real(optimal_size)) : -1; + return info == 0 ? MaybeCastNoOverflow(std::real(optimal_size)) + : -1; } } // namespace internal @@ -1053,7 +647,8 @@ ffi::Error SingularValueDecompositionQRComplex::Kernel( } template -absl::StatusOr SingularValueDecompositionQR::GetWorkspaceSize( +absl::StatusOr +SingularValueDecompositionQR::GetWorkspaceSize( lapack_int x_rows, lapack_int x_cols, svd::ComputationMode mode) { return internal::SvdQRGetWorkspaceSize(x_rows, x_cols, mode); } @@ -1077,7 +672,8 @@ absl::StatusOr svd::GetRealWorkspaceSize( 2 * max_dim * min_dim + 2 * min_dim * min_dim + min_dim)); } -absl::StatusOr svd::GetRealWorkspaceSizeQR(int64_t x_rows, int64_t x_cols) { +absl::StatusOr svd::GetRealWorkspaceSizeQR(int64_t x_rows, + int64_t x_cols) { return CastNoOverflow(5 * std::min(x_rows, x_cols)); } @@ -1098,109 +694,6 @@ template struct SingularValueDecompositionQRComplex; //== Eigenvalues and eigenvectors ==// -// lapack syevd/heevd - -// # Workspace sizes, taken from the LAPACK documentation. -lapack_int SyevdWorkSize(int64_t n) { - return CastNoOverflow(1 + 6 * n + 2 * n * n, "syevd lwork"); -} - -lapack_int SyevdIworkSize(int64_t n) { - return CastNoOverflow(3 + 5 * n, "syevd iwork"); -} - -template -typename RealSyevd::FnType* RealSyevd::fn = nullptr; - -template -void RealSyevd::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) { - int32_t lower = *(reinterpret_cast(data[0])); - int b = *(reinterpret_cast(data[1])); - int n = *(reinterpret_cast(data[2])); - const T* a_in = reinterpret_cast(data[3]); - void** out = reinterpret_cast(out_tuple); - T* a_out = reinterpret_cast(out[0]); - T* w_out = reinterpret_cast(out[1]); - int* info_out = reinterpret_cast(out[2]); - T* work = reinterpret_cast(out[3]); - int* iwork = reinterpret_cast(out[4]); - if (a_out != a_in) { - std::memcpy(a_out, a_in, - static_cast(b) * static_cast(n) * - static_cast(n) * sizeof(T)); - } - - char jobz = 'V'; - char uplo = lower ? 'L' : 'U'; - - lapack_int lwork = SyevdWorkSize(n); - lapack_int liwork = SyevdIworkSize(n); - for (int i = 0; i < b; ++i) { - fn(&jobz, &uplo, &n, a_out, &n, w_out, work, &lwork, iwork, &liwork, - info_out); - a_out += static_cast(n) * n; - w_out += n; - ++info_out; - } -} - -// Workspace sizes, taken from the LAPACK documentation. -lapack_int HeevdWorkSize(int64_t n) { - return CastNoOverflow(1 + 2 * n + n * n, "heevd work"); -} - -lapack_int HeevdRworkSize(int64_t n) { - return CastNoOverflow(1 + 5 * n + 2 * n * n, "heevd rwork"); -} - -template -typename ComplexHeevd::FnType* ComplexHeevd::fn = nullptr; - -template -void ComplexHeevd::Kernel(void* out_tuple, void** data, - XlaCustomCallStatus*) { - int32_t lower = *(reinterpret_cast(data[0])); - int b = *(reinterpret_cast(data[1])); - int n = *(reinterpret_cast(data[2])); - const T* a_in = reinterpret_cast(data[3]); - - void** out = reinterpret_cast(out_tuple); - T* a_out = reinterpret_cast(out[0]); - typename T::value_type* w_out = - reinterpret_cast(out[1]); - int* info_out = reinterpret_cast(out[2]); - T* work = reinterpret_cast(out[3]); - typename T::value_type* rwork = - reinterpret_cast(out[4]); - int* iwork = reinterpret_cast(out[5]); - if (a_out != a_in) { - std::memcpy(a_out, a_in, - static_cast(b) * static_cast(n) * - static_cast(n) * sizeof(T)); - } - - char jobz = 'V'; - char uplo = lower ? 'L' : 'U'; - - lapack_int lwork = HeevdWorkSize(n); - lapack_int lrwork = HeevdRworkSize(n); - lapack_int liwork = SyevdIworkSize(n); - for (int i = 0; i < b; ++i) { - fn(&jobz, &uplo, &n, a_out, &n, w_out, work, &lwork, rwork, &lrwork, iwork, - &liwork, info_out); - a_out += static_cast(n) * n; - w_out += n; - ++info_out; - } -} - -template struct RealSyevd; -template struct RealSyevd; -template struct ComplexHeevd>; -template struct ComplexHeevd>; - -// FFI Kernel - absl::StatusOr eig::GetWorkspaceSize(int64_t x_cols, ComputationMode mode) { switch (mode) { @@ -1339,155 +832,6 @@ template struct EigenvalueDecompositionSymmetric; template struct EigenvalueDecompositionHermitian; template struct EigenvalueDecompositionHermitian; -// lapack geev - -template -typename RealGeev::FnType* RealGeev::fn = nullptr; - -template -void RealGeev::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) { - int b = *(reinterpret_cast(data[0])); - int n_int = *(reinterpret_cast(data[1])); - int64_t n = n_int; - char jobvl = *(reinterpret_cast(data[2])); - char jobvr = *(reinterpret_cast(data[3])); - - const T* a_in = reinterpret_cast(data[4]); - - void** out = reinterpret_cast(out_tuple); - T* a_work = reinterpret_cast(out[0]); - T* vl_work = reinterpret_cast(out[1]); - T* vr_work = reinterpret_cast(out[2]); - - T* wr_out = reinterpret_cast(out[3]); - T* wi_out = reinterpret_cast(out[4]); - std::complex* vl_out = reinterpret_cast*>(out[5]); - std::complex* vr_out = reinterpret_cast*>(out[6]); - int* info_out = reinterpret_cast(out[7]); - - // TODO(phawkins): preallocate workspace using XLA. - T work_query; - int lwork = -1; - fn(&jobvl, &jobvr, &n_int, a_work, &n_int, wr_out, wi_out, vl_work, &n_int, - vr_work, &n_int, &work_query, &lwork, info_out); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&work_query, sizeof(work_query)); - lwork = static_cast(work_query); - T* work = new T[lwork]; - - auto is_finite = [](T* a_work, int64_t n) { - for (int64_t j = 0; j < n; ++j) { - for (int64_t k = 0; k < n; ++k) { - if (!std::isfinite(a_work[j * n + k])) { - return false; - } - } - } - return true; - }; - for (int i = 0; i < b; ++i) { - size_t a_size = n * n * sizeof(T); - std::memcpy(a_work, a_in, a_size); - if (is_finite(a_work, n)) { - fn(&jobvl, &jobvr, &n_int, a_work, &n_int, wr_out, wi_out, vl_work, - &n_int, vr_work, &n_int, work, &lwork, info_out); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(a_work, a_size); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wr_out, sizeof(T) * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wi_out, sizeof(T) * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vl_work, sizeof(T) * n * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vr_work, sizeof(T) * n * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int)); - if (info_out[0] == 0) { - UnpackEigenvectors(n, wi_out, vl_work, vl_out); - UnpackEigenvectors(n, wi_out, vr_work, vr_out); - } - } else { - *info_out = -4; - } - a_in += n * n; - wr_out += n; - wi_out += n; - vl_out += n * n; - vr_out += n * n; - ++info_out; - } - delete[] work; -} - -template -typename ComplexGeev::FnType* ComplexGeev::fn = nullptr; - -template -void ComplexGeev::Kernel(void* out_tuple, void** data, - XlaCustomCallStatus*) { - int b = *(reinterpret_cast(data[0])); - int n_int = *(reinterpret_cast(data[1])); - int64_t n = n_int; - char jobvl = *(reinterpret_cast(data[2])); - char jobvr = *(reinterpret_cast(data[3])); - - const T* a_in = reinterpret_cast(data[4]); - - void** out = reinterpret_cast(out_tuple); - T* a_work = reinterpret_cast(out[0]); - typename T::value_type* r_work = - reinterpret_cast(out[1]); - - T* w_out = reinterpret_cast(out[2]); - T* vl_out = reinterpret_cast(out[3]); - T* vr_out = reinterpret_cast(out[4]); - int* info_out = reinterpret_cast(out[5]); - - // TODO(phawkins): preallocate workspace using XLA. - T work_query; - int lwork = -1; - fn(&jobvl, &jobvr, &n_int, a_work, &n_int, w_out, vl_out, &n_int, vr_out, - &n_int, &work_query, &lwork, r_work, info_out); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&work_query, sizeof(work_query)); - lwork = static_cast(work_query.real()); - T* work = new T[lwork]; - - auto is_finite = [](T* a_work, int64_t n) { - for (int64_t j = 0; j < n; ++j) { - for (int64_t k = 0; k < n; ++k) { - T v = a_work[j * n + k]; - if (!std::isfinite(v.real()) || !std::isfinite(v.imag())) { - return false; - } - } - } - return true; - }; - - for (int i = 0; i < b; ++i) { - size_t a_size = n * n * sizeof(T); - std::memcpy(a_work, a_in, a_size); - if (is_finite(a_work, n)) { - fn(&jobvl, &jobvr, &n_int, a_work, &n_int, w_out, vl_out, &n_int, vr_out, - &n_int, work, &lwork, r_work, info_out); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(a_work, a_size); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(w_out, sizeof(T) * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vl_out, sizeof(T) * n * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vr_out, sizeof(T) * n * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int)); - } else { - *info_out = -4; - } - a_in += n * n; - w_out += n; - vl_out += n * n; - vr_out += n * n; - info_out += 1; - } - delete[] work; -} - -template struct RealGeev; -template struct RealGeev; -template struct ComplexGeev>; -template struct ComplexGeev>; - -// FFI Kernel - template ffi::Error EigenvalueDecomposition::Kernel( ffi::Buffer x, eig::ComputationMode compute_left, @@ -1662,138 +1006,6 @@ template struct EigenvalueDecompositionComplex; //== Schur Decomposition ==// -// lapack gees - -template -typename RealGees::FnType* RealGees::fn = nullptr; - -template -void RealGees::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) { - int b = *(reinterpret_cast(data[0])); - int n_int = *(reinterpret_cast(data[1])); - int64_t n = n_int; - char jobvs = *(reinterpret_cast(data[2])); - char sort = *(reinterpret_cast(data[3])); - - const T* a_in = reinterpret_cast(data[4]); - - // bool* select (T, T) = reinterpret_cast(data[5]); - bool (*select)(T, T) = nullptr; - - void** out = reinterpret_cast(out_tuple); - T* a_out = reinterpret_cast(out[0]); - - T* wr_out = reinterpret_cast(out[1]); - T* wi_out = reinterpret_cast(out[2]); - T* vs_out = reinterpret_cast(out[3]); - int* sdim_out = reinterpret_cast(out[4]); - int* info_out = reinterpret_cast(out[5]); - - bool* b_work = (sort != 'N') ? (new bool[n]) : nullptr; - - T work_query; - int lwork = -1; - fn(&jobvs, &sort, select, &n_int, a_out, &n_int, sdim_out, wr_out, wi_out, - vs_out, &n_int, &work_query, &lwork, b_work, info_out); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&work_query, sizeof(work_query)); - lwork = static_cast(work_query); - T* work = new T[lwork]; - - size_t a_size = static_cast(n) * static_cast(n) * sizeof(T); - if (a_out != a_in) { - std::memcpy(a_out, a_in, static_cast(b) * a_size); - } - - for (int i = 0; i < b; ++i) { - fn(&jobvs, &sort, select, &n_int, a_out, &n_int, sdim_out, wr_out, wi_out, - vs_out, &n_int, work, &lwork, b_work, info_out); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(a_out, a_size); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(sdim_out, sizeof(int)); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wr_out, sizeof(T) * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wi_out, sizeof(T) * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vs_out, sizeof(T) * n * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int)); - - a_in += n * n; - a_out += n * n; - wr_out += n; - wi_out += n; - vs_out += n * n; - ++sdim_out; - ++info_out; - } - delete[] work; - delete[] b_work; -} - -template -typename ComplexGees::FnType* ComplexGees::fn = nullptr; - -template -void ComplexGees::Kernel(void* out_tuple, void** data, - XlaCustomCallStatus*) { - int b = *(reinterpret_cast(data[0])); - int n_int = *(reinterpret_cast(data[1])); - int64_t n = n_int; - char jobvs = *(reinterpret_cast(data[2])); - char sort = *(reinterpret_cast(data[3])); - - const T* a_in = reinterpret_cast(data[4]); - - // bool* select (T, T) = reinterpret_cast(data[5]); - bool (*select)(T) = nullptr; - - void** out = reinterpret_cast(out_tuple); - T* a_out = reinterpret_cast(out[0]); - typename T::value_type* r_work = - reinterpret_cast(out[1]); - T* w_out = reinterpret_cast(out[2]); - T* vs_out = reinterpret_cast(out[3]); - int* sdim_out = reinterpret_cast(out[4]); - int* info_out = reinterpret_cast(out[5]); - - bool* b_work = (sort != 'N') ? (new bool[n]) : nullptr; - - T work_query; - int lwork = -1; - fn(&jobvs, &sort, select, &n_int, a_out, &n_int, sdim_out, w_out, vs_out, - &n_int, &work_query, &lwork, r_work, b_work, info_out); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&work_query, sizeof(work_query)); - lwork = static_cast(work_query.real()); - T* work = new T[lwork]; - - if (a_out != a_in) { - std::memcpy(a_out, a_in, - static_cast(b) * static_cast(n) * - static_cast(n) * sizeof(T)); - } - - for (int i = 0; i < b; ++i) { - fn(&jobvs, &sort, select, &n_int, a_out, &n_int, sdim_out, w_out, vs_out, - &n_int, work, &lwork, r_work, b_work, info_out); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(w_out, sizeof(T) * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vs_out, sizeof(T) * n * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int)); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(sdim_out, sizeof(int)); - - a_in += n * n; - a_out += n * n; - w_out += n; - vs_out += n * n; - ++info_out; - ++sdim_out; - } - delete[] work; - delete[] b_work; -} - -template struct RealGees; -template struct RealGees; -template struct ComplexGees>; -template struct ComplexGees>; - -// FFI Kernel - template ffi::Error SchurDecomposition::Kernel( ffi::Buffer x, schur::ComputationMode mode, schur::Sort sort, @@ -1968,60 +1180,6 @@ template struct SchurDecompositionComplex; //== Hessenberg Decomposition ==// -// lapack gehrd - -template -typename Gehrd::FnType* Gehrd::fn = nullptr; - -template -void Gehrd::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) { - int32_t n = *reinterpret_cast(data[0]); - int32_t ilo = *reinterpret_cast(data[1]); - int32_t ihi = *reinterpret_cast(data[2]); - int32_t lda = *reinterpret_cast(data[3]); - int32_t batch = *reinterpret_cast(data[4]); - int32_t lwork = *reinterpret_cast(data[5]); - T* a = reinterpret_cast(data[6]); - - void** out = reinterpret_cast(out_tuple); - T* a_out = reinterpret_cast(out[0]); - T* tau = reinterpret_cast(out[1]); - int* info = reinterpret_cast(out[2]); - T* work = reinterpret_cast(out[3]); - - if (a_out != a) { - std::memcpy(a_out, a, - static_cast(batch) * static_cast(n) * - static_cast(n) * sizeof(T)); - } - - int64_t a_plus = static_cast(lda) * static_cast(n); - - for (int i = 0; i < batch; ++i) { - fn(&n, &ilo, &ihi, a_out, &lda, tau, work, &lwork, info); - a_out += a_plus; - tau += n - 1; - ++info; - } -} - -template -int64_t Gehrd::Workspace(lapack_int lda, lapack_int n, lapack_int ilo, - lapack_int ihi) { - T work = 0; - lapack_int lwork = -1; - lapack_int info = 0; - fn(&n, &ilo, &ihi, nullptr, &lda, nullptr, &work, &lwork, &info); - return info == 0 ? static_cast(std::real(work)) : -1; -} - -template struct Gehrd; -template struct Gehrd; -template struct Gehrd>; -template struct Gehrd>; - -// FFI Kernel - template ffi::Error HessenbergDecomposition::Kernel( ffi::Buffer x, lapack_int low, lapack_int high, @@ -2075,67 +1233,6 @@ template struct HessenbergDecomposition; //== Tridiagonal Reduction ==// -// lapack sytrd/hetrd - -template -typename Sytrd::FnType* Sytrd::fn = nullptr; - -template -void Sytrd::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) { - int32_t n = *reinterpret_cast(data[0]); - int32_t lower = *reinterpret_cast(data[1]); - int32_t lda = *reinterpret_cast(data[2]); - int32_t batch = *reinterpret_cast(data[3]); - int32_t lwork = *reinterpret_cast(data[4]); - T* a = reinterpret_cast(data[5]); - - void** out = reinterpret_cast(out_tuple); - T* a_out = reinterpret_cast(out[0]); - typedef typename real_type::type Real; - Real* d = reinterpret_cast(out[1]); - Real* e = reinterpret_cast(out[2]); - T* tau = reinterpret_cast(out[3]); - int* info = reinterpret_cast(out[4]); - T* work = reinterpret_cast(out[5]); - - if (a_out != a) { - std::memcpy(a_out, a, - static_cast(batch) * static_cast(n) * - static_cast(n) * sizeof(T)); - } - - char cuplo = lower ? 'L' : 'U'; - - int64_t a_plus = static_cast(lda) * static_cast(n); - - for (int i = 0; i < batch; ++i) { - fn(&cuplo, &n, a_out, &lda, d, e, tau, work, &lwork, info); - a_out += a_plus; - d += n; - e += n - 1; - tau += n - 1; - ++info; - } -} - -template -int64_t Sytrd::Workspace(lapack_int lda, lapack_int n) { - char cuplo = 'L'; - T work = 0; - lapack_int lwork = -1; - lapack_int info = 0; - fn(&cuplo, &n, nullptr, &lda, nullptr, nullptr, nullptr, &work, &lwork, - &info); - return info == 0 ? static_cast(std::real(work)) : -1; -} - -template struct Sytrd; -template struct Sytrd; -template struct Sytrd>; -template struct Sytrd>; - -// FFI Kernel - template ffi::Error TridiagonalReduction::Kernel( ffi::Buffer x, MatrixParams::UpLo uplo, @@ -2250,7 +1347,6 @@ template struct TridiagonalSolver; ::xla::ffi::Ffi::Bind() \ .Arg<::xla::ffi::Buffer>(/*x*/) \ .Arg<::xla::ffi::Buffer>(/*y*/) \ - .RemainingArgs() \ .Ret<::xla::ffi::Buffer>(/*y_out*/) \ .Attr("side") \ .Attr("uplo") \ diff --git a/jaxlib/cpu/lapack_kernels.h b/jaxlib/cpu/lapack_kernels.h index e075ff29387f..fe285347ac9e 100644 --- a/jaxlib/cpu/lapack_kernels.h +++ b/jaxlib/cpu/lapack_kernels.h @@ -18,13 +18,10 @@ limitations under the License. #include #include -#include #include #include "absl/status/statusor.h" -#include "xla/ffi/api/c_api.h" #include "xla/ffi/api/ffi.h" -#include "xla/service/custom_call_status.h" // Underlying function pointers (i.e., KERNEL_CLASS::Fn) are initialized either // by the nanobind wrapper that links them to an existing SciPy lapack instance, @@ -33,6 +30,8 @@ limitations under the License. namespace jax { +extern bool lapack_kernels_initialized; + struct MatrixParams { enum class Side : char { kLeft = 'L', kRight = 'R' }; enum class UpLo : char { kLower = 'L', kUpper = 'U' }; @@ -93,26 +92,6 @@ void AssignKernelFn(typename KernelType::FnType* func) { } // namespace jax -#define DEFINE_CHAR_ENUM_ATTR_DECODING(ATTR) \ - template <> \ - struct xla::ffi::AttrDecoding { \ - using Type = ATTR; \ - static std::optional Decode(XLA_FFI_AttrType type, void* attr, \ - DiagnosticEngine& diagnostic); \ - } - -// XLA needs attributes to have deserialization method specified -DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Side); -DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::UpLo); -DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Transpose); -DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Diag); -DEFINE_CHAR_ENUM_ATTR_DECODING(jax::svd::ComputationMode); -DEFINE_CHAR_ENUM_ATTR_DECODING(jax::eig::ComputationMode); -DEFINE_CHAR_ENUM_ATTR_DECODING(jax::schur::ComputationMode); -DEFINE_CHAR_ENUM_ATTR_DECODING(jax::schur::Sort); - -#undef DEFINE_CHAR_ENUM_ATTR_DECODING - namespace jax { using lapack_int = int; @@ -122,20 +101,6 @@ static_assert( //== Triangular System Solver ==// -// lapack trsm - -template -struct Trsm { - using FnType = void(char* side, char* uplo, char* transa, char* diag, - lapack_int* m, lapack_int* n, T* alpha, T* a, - lapack_int* lda, T* b, lapack_int* ldb); - - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); -}; - -// FFI Kernel - template <::xla::ffi::DataType dtype> struct TriMatrixEquationSolver { using ValueType = ::xla::ffi::NativeType; @@ -145,28 +110,17 @@ struct TriMatrixEquationSolver { lapack_int* ldb); inline static FnType* fn = nullptr; - static ::xla::ffi::Error Kernel( - ::xla::ffi::Buffer x, ::xla::ffi::Buffer y, - ::xla::ffi::RemainingArgs, ::xla::ffi::ResultBuffer y_out, - MatrixParams::Side side, MatrixParams::UpLo uplo, - MatrixParams::Transpose trans_x, MatrixParams::Diag diag); + static ::xla::ffi::Error Kernel(::xla::ffi::Buffer x, + ::xla::ffi::Buffer y, + ::xla::ffi::ResultBuffer y_out, + MatrixParams::Side side, + MatrixParams::UpLo uplo, + MatrixParams::Transpose trans_x, + MatrixParams::Diag diag); }; //== LU Decomposition ==// -// lapack getrf - -template -struct Getrf { - using FnType = void(lapack_int* m, lapack_int* n, T* a, lapack_int* lda, - lapack_int* ipiv, lapack_int* info); - - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); -}; - -// FFI Kernel - template <::xla::ffi::DataType dtype> struct LuDecomposition { using ValueType = ::xla::ffi::NativeType; @@ -182,21 +136,6 @@ struct LuDecomposition { //== QR Factorization ==// -// lapack geqrf - -template -struct Geqrf { - using FnType = void(lapack_int* m, lapack_int* n, T* a, lapack_int* lda, - T* tau, T* work, lapack_int* lwork, lapack_int* info); - - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); - - static int64_t Workspace(lapack_int m, lapack_int n); -}; - -// FFI Kernel - template <::xla::ffi::DataType dtype> struct QrFactorization { using ValueType = ::xla::ffi::NativeType; @@ -240,23 +179,8 @@ struct PivotingQrFactorization { static int64_t GetWorkspaceSize(lapack_int x_rows, lapack_int x_cols); }; - //== Orthogonal QR ==// -// lapack orgqr - -template -struct Orgqr { - using FnType = void(lapack_int* m, lapack_int* n, lapack_int* k, T* a, - lapack_int* lda, T* tau, T* work, lapack_int* lwork, - lapack_int* info); - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); - static int64_t Workspace(lapack_int m, lapack_int n, lapack_int k); -}; - -// FFI Kernel - template <::xla::ffi::DataType dtype> struct OrthogonalQr { using ValueType = ::xla::ffi::NativeType; @@ -276,16 +200,6 @@ struct OrthogonalQr { //== Cholesky Factorization ==// -// lapack potrf - -template -struct Potrf { - using FnType = void(char* uplo, lapack_int* n, T* a, lapack_int* lda, - lapack_int* info); - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); -}; - template <::xla::ffi::DataType dtype> struct CholeskyFactorization { using ValueType = ::xla::ffi::NativeType; @@ -302,41 +216,6 @@ struct CholeskyFactorization { //== Singular Value Decomposition (SVD) ==// -// lapack gesdd - -lapack_int GesddIworkSize(int64_t m, int64_t n); - -template -struct RealGesdd { - using FnType = void(char* jobz, lapack_int* m, lapack_int* n, T* a, - lapack_int* lda, T* s, T* u, lapack_int* ldu, T* vt, - lapack_int* ldvt, T* work, lapack_int* lwork, - lapack_int* iwork, lapack_int* info); - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); - - static int64_t Workspace(lapack_int m, lapack_int n, bool job_opt_compute_uv, - bool job_opt_full_matrices); -}; - -lapack_int ComplexGesddRworkSize(int64_t m, int64_t n, int compute_uv); - -template -struct ComplexGesdd { - using FnType = void(char* jobz, lapack_int* m, lapack_int* n, T* a, - lapack_int* lda, typename T::value_type* s, T* u, - lapack_int* ldu, T* vt, lapack_int* ldvt, T* work, - lapack_int* lwork, typename T::value_type* rwork, - lapack_int* iwork, lapack_int* info); - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); - - static int64_t Workspace(lapack_int m, lapack_int n, bool job_opt_compute_uv, - bool job_opt_full_matrices); -}; - -// FFI Kernel - template <::xla::ffi::DataType dtype> struct SingularValueDecomposition { static_assert(!::xla::ffi::IsComplexType(), @@ -407,8 +286,8 @@ struct SingularValueDecompositionQR { ::xla::ffi::ResultBuffer info, svd::ComputationMode mode); static absl::StatusOr GetWorkspaceSize(lapack_int x_rows, - lapack_int x_cols, - svd::ComputationMode mode); + lapack_int x_cols, + svd::ComputationMode mode); }; template <::xla::ffi::DataType dtype> @@ -432,8 +311,8 @@ struct SingularValueDecompositionQRComplex { ::xla::ffi::ResultBuffer info, svd::ComputationMode mode); static absl::StatusOr GetWorkspaceSize(lapack_int x_rows, - lapack_int x_cols, - svd::ComputationMode mode); + lapack_int x_cols, + svd::ComputationMode mode); }; namespace svd { @@ -451,42 +330,13 @@ using SVDQRType = std::conditional_t<::xla::ffi::IsComplexType(), absl::StatusOr GetIntWorkspaceSize(int64_t x_rows, int64_t x_cols); absl::StatusOr GetRealWorkspaceSize(int64_t x_rows, int64_t x_cols, ComputationMode mode); -absl::StatusOr GetRealWorkspaceSizeQR(int64_t x_rows, int64_t x_cols); +absl::StatusOr GetRealWorkspaceSizeQR(int64_t x_rows, + int64_t x_cols); } // namespace svd //== Eigenvalues and eigenvectors ==// -// lapack syevd/heevd - -lapack_int SyevdWorkSize(int64_t n); -lapack_int SyevdIworkSize(int64_t n); - -template -struct RealSyevd { - using FnType = void(char* jobz, char* uplo, lapack_int* n, T* a, - lapack_int* lda, T* w, T* work, lapack_int* lwork, - lapack_int* iwork, lapack_int* liwork, lapack_int* info); - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); -}; - -lapack_int HeevdWorkSize(int64_t n); -lapack_int HeevdRworkSize(int64_t n); - -template -struct ComplexHeevd { - using FnType = void(char* jobz, char* uplo, lapack_int* n, T* a, - lapack_int* lda, typename T::value_type* w, T* work, - lapack_int* lwork, typename T::value_type* rwork, - lapack_int* lrwork, lapack_int* iwork, lapack_int* liwork, - lapack_int* info); - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); -}; - -// FFI Kernel - namespace eig { // Eigenvalue Decomposition @@ -544,8 +394,6 @@ struct EigenvalueDecompositionHermitian { ::xla::ffi::ResultBuffer info, eig::ComputationMode mode); }; -// lapack geev - // LAPACK uses a packed representation to represent a mixture of real // eigenvectors and complex conjugate pairs. This helper unpacks the // representation into regular complex matrices. @@ -574,28 +422,6 @@ static void UnpackEigenvectors(Int n, const T* eigenvals_imag, const T* packed, } } -template -struct RealGeev { - using FnType = void(char* jobvl, char* jobvr, lapack_int* n, T* a, - lapack_int* lda, T* wr, T* wi, T* vl, lapack_int* ldvl, - T* vr, lapack_int* ldvr, T* work, lapack_int* lwork, - lapack_int* info); - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); -}; - -template -struct ComplexGeev { - using FnType = void(char* jobvl, char* jobvr, lapack_int* n, T* a, - lapack_int* lda, T* w, T* vl, lapack_int* ldvl, T* vr, - lapack_int* ldvr, T* work, lapack_int* lwork, - typename T::value_type* rwork, lapack_int* info); - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); -}; - -// FFI Kernel - template <::xla::ffi::DataType dtype> struct EigenvalueDecomposition { static_assert(!::xla::ffi::IsComplexType(), @@ -653,31 +479,6 @@ struct EigenvalueDecompositionComplex { //== Schur Decomposition ==// -// lapack gees - -template -struct RealGees { - using FnType = void(char* jobvs, char* sort, bool (*select)(T, T), - lapack_int* n, T* a, lapack_int* lda, lapack_int* sdim, - T* wr, T* wi, T* vs, lapack_int* ldvs, T* work, - lapack_int* lwork, bool* bwork, lapack_int* info); - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); -}; - -template -struct ComplexGees { - using FnType = void(char* jobvs, char* sort, bool (*select)(T), lapack_int* n, - T* a, lapack_int* lda, lapack_int* sdim, T* w, T* vs, - lapack_int* ldvs, T* work, lapack_int* lwork, - typename T::value_type* rwork, bool* bwork, - lapack_int* info); - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); -}; - -// FFI Kernel - template <::xla::ffi::DataType dtype> struct SchurDecomposition { static_assert(!::xla::ffi::IsComplexType(), @@ -737,32 +538,6 @@ struct SchurDecompositionComplex { //== Hessenberg Decomposition ==// //== Reduces a non-symmetric square matrix to upper Hessenberg form ==// -// lapack gehrd - -template -struct Gehrd { - using FnType = void(lapack_int* n, lapack_int* ilo, lapack_int* ihi, T* a, - lapack_int* lda, T* tau, T* work, lapack_int* lwork, - lapack_int* info); - - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); - - static int64_t Workspace(lapack_int lda, lapack_int n, lapack_int ilo, - lapack_int ihi); -}; - -template -struct real_type { - typedef T type; -}; -template -struct real_type> { - typedef T type; -}; - -// FFI Kernel - template <::xla::ffi::DataType dtype> struct HessenbergDecomposition { using ValueType = ::xla::ffi::NativeType; @@ -785,23 +560,6 @@ struct HessenbergDecomposition { //== Tridiagonal Reduction ==// //== Reduces a Symmetric/Hermitian square matrix to tridiagonal form ==// -// lapack sytrd/hetrd - -template -struct Sytrd { - using FnType = void(char* uplo, lapack_int* n, T* a, lapack_int* lda, - typename real_type::type* d, - typename real_type::type* e, T* tau, T* work, - lapack_int* lwork, lapack_int* info); - - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); - - static int64_t Workspace(lapack_int lda, lapack_int n); -}; - -// FFI Kernel - template <::xla::ffi::DataType dtype> struct TridiagonalReduction { using ValueType = ::xla::ffi::NativeType; diff --git a/jaxlib/cpu/lapack_kernels_using_lapack.cc b/jaxlib/cpu/lapack_kernels_using_lapack.cc index 3c8ddf11cf29..1c3694af5b98 100644 --- a/jaxlib/cpu/lapack_kernels_using_lapack.cc +++ b/jaxlib/cpu/lapack_kernels_using_lapack.cc @@ -13,9 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include - #include "jaxlib/cpu/lapack_kernels.h" // From a Python binary, JAX obtains its LAPACK/BLAS kernels from Scipy, but @@ -100,241 +97,7 @@ jax::TridiagonalSolver::FnType zgtsv_; namespace jax { -#define JAX_KERNEL_FNTYPE_MISMATCH_MSG "FFI Kernel FnType mismatch" - -static_assert( - std::is_same_v::FnType, - jax::Trsm::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Trsm::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Trsm>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Trsm>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Getrf::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Getrf::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Getrf>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Getrf>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Geqrf::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Geqrf::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Geqrf>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Geqrf>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Orgqr::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Orgqr::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Orgqr>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Orgqr>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Potrf::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Potrf::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Potrf>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Potrf>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::RealGesdd::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::RealGesdd::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v< - jax::SingularValueDecompositionComplex::FnType, - jax::ComplexGesdd>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v< - jax::SingularValueDecompositionComplex::FnType, - jax::ComplexGesdd>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v< - jax::EigenvalueDecompositionSymmetric::FnType, - jax::RealSyevd::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v< - jax::EigenvalueDecompositionSymmetric::FnType, - jax::RealSyevd::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v< - jax::EigenvalueDecompositionHermitian::FnType, - jax::ComplexHeevd>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v< - jax::EigenvalueDecompositionHermitian::FnType, - jax::ComplexHeevd>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::RealGeev::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::RealGeev::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v< - jax::EigenvalueDecompositionComplex::FnType, - jax::ComplexGeev>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v< - jax::EigenvalueDecompositionComplex::FnType, - jax::ComplexGeev>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Sytrd::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Sytrd::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Sytrd>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Sytrd>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::RealGees::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::RealGees::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::ComplexGees>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::ComplexGees>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Gehrd::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Gehrd::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Gehrd>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Gehrd>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); - -#undef JAX_KERNEL_FNTYPE_MISMATCH_MSG - static auto init = []() -> int { - AssignKernelFn>(strsm_); - AssignKernelFn>(dtrsm_); - AssignKernelFn>>(ctrsm_); - AssignKernelFn>>(ztrsm_); - - AssignKernelFn>(sgetrf_); - AssignKernelFn>(dgetrf_); - AssignKernelFn>>(cgetrf_); - AssignKernelFn>>(zgetrf_); - - AssignKernelFn>(sgeqrf_); - AssignKernelFn>(dgeqrf_); - AssignKernelFn>>(cgeqrf_); - AssignKernelFn>>(zgeqrf_); - - AssignKernelFn>(sorgqr_); - AssignKernelFn>(dorgqr_); - AssignKernelFn>>(cungqr_); - AssignKernelFn>>(zungqr_); - - AssignKernelFn>(spotrf_); - AssignKernelFn>(dpotrf_); - AssignKernelFn>>(cpotrf_); - AssignKernelFn>>(zpotrf_); - - AssignKernelFn>(sgesdd_); - AssignKernelFn>(dgesdd_); - AssignKernelFn>>(cgesdd_); - AssignKernelFn>>(zgesdd_); - - AssignKernelFn>(ssyevd_); - AssignKernelFn>(dsyevd_); - AssignKernelFn>>(cheevd_); - AssignKernelFn>>(zheevd_); - - AssignKernelFn>(sgeev_); - AssignKernelFn>(dgeev_); - AssignKernelFn>>(cgeev_); - AssignKernelFn>>(zgeev_); - - AssignKernelFn>(sgees_); - AssignKernelFn>(dgees_); - AssignKernelFn>>(cgees_); - AssignKernelFn>>(zgees_); - - AssignKernelFn>(sgehrd_); - AssignKernelFn>(dgehrd_); - AssignKernelFn>>(cgehrd_); - AssignKernelFn>>(zgehrd_); - - AssignKernelFn>(ssytrd_); - AssignKernelFn>(dsytrd_); - AssignKernelFn>>(chetrd_); - AssignKernelFn>>(zhetrd_); - - // FFI Kernels - AssignKernelFn>(strsm_); AssignKernelFn>(dtrsm_); AssignKernelFn>(ctrsm_); @@ -410,6 +173,7 @@ static auto init = []() -> int { AssignKernelFn>(cgtsv_); AssignKernelFn>(zgtsv_); + lapack_kernels_initialized = true; return 0; }(); diff --git a/jaxlib/cpu/sparse.cc b/jaxlib/cpu/sparse.cc new file mode 100644 index 000000000000..15f5c0f1984f --- /dev/null +++ b/jaxlib/cpu/sparse.cc @@ -0,0 +1,37 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "nanobind/nanobind.h" +#include "jaxlib/cpu/sparse_kernels.h" +#include "jaxlib/kernel_nanobind_helpers.h" + +namespace jax { +namespace { + +namespace nb = nanobind; + +nb::dict Registrations() { + nb::dict dict; + + dict["cpu_csr_sparse_dense_ffi"] = + EncapsulateFunction(cpu_csr_sparse_dense_ffi); + + return dict; +} + +NB_MODULE(_sparse, m) { m.def("registrations", &Registrations); } + +} // namespace +} // namespace jax diff --git a/jaxlib/cpu/sparse_kernels.cc b/jaxlib/cpu/sparse_kernels.cc new file mode 100644 index 000000000000..8000abca65cc --- /dev/null +++ b/jaxlib/cpu/sparse_kernels.cc @@ -0,0 +1,215 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/cpu/sparse_kernels.h" + +#include +#include +#include +#include + +#include "Eigen/Core" +#include "Eigen/SparseCore" +#include "xla/ffi/api/ffi.h" + +namespace ffi = xla::ffi; + +namespace jax { + +template +using SparseMatrixType = + Eigen::SparseMatrix; +template +using DenseMatrixType = + Eigen::Matrix; + +template +using InputMap = Eigen::Map; +template +using OutputMap = Eigen::Map; + +template +static ffi::Future CsrSparseDenseKernelImpl( + const InputMap>& lhs_matrix, + const InputMap>& rhs_matrix, + OutputMap>& out_matrix, + ffi::ThreadPool& thread_pool) { + // Rule of thumb to give each task at least 100k cycles to hide the cost of + // task scheduling. + // TODO(willfroom) Do we want to make this configurable? + constexpr int64_t kTargetCyclesPerTask = 100'000; + // Based on AVX (CPI 0.5 -> 2 IPC) + constexpr int64_t kScalarProductsPerCycle = 2 * 32 / sizeof(ElementType); + constexpr int64_t kTaskSize = kTargetCyclesPerTask * kScalarProductsPerCycle; + + if (lhs_matrix.nonZeros() * rhs_matrix.cols() <= kTaskSize || + thread_pool.num_threads() == 0) { + out_matrix.noalias() = lhs_matrix * rhs_matrix; + + ffi::Promise promise; + promise.SetAvailable(); + return ffi::Future(promise); + } else { + std::vector batch_sizes; + { + int64_t running_batch_nnz = 0; + int64_t running_number_rows = 0; + for (int row = 0; row < lhs_matrix.rows(); ++row) { + int64_t row_nnz = lhs_matrix.outerIndexPtr()[row + 1] - + lhs_matrix.outerIndexPtr()[row]; + // If there is no non-zero elements in a row the task still needs to + // write out a zero row we give each row a non-zero contribution to + // avoid the pathological case of a task having to write many rows where + // there is a large block of zero inputs. + running_batch_nnz += std::max(row_nnz, static_cast(1)); + running_number_rows++; + if (running_batch_nnz * rhs_matrix.cols() > kTaskSize) { + batch_sizes.push_back(running_number_rows); + running_batch_nnz = 0; + running_number_rows = 0; + } else if (row == lhs_matrix.rows() - 1 && running_number_rows > 0) { + batch_sizes.push_back(running_number_rows); + } + } + } + + ffi::CountDownPromise promise(batch_sizes.size()); + ffi::Future future(promise); + int64_t batch_start = 0; + for (int64_t size : batch_sizes) { + thread_pool.Schedule([out_matrix, lhs_matrix, rhs_matrix, batch_start, + size, promise]() mutable { + out_matrix.middleRows(batch_start, size).noalias() = + lhs_matrix.middleRows(batch_start, size) * rhs_matrix; + promise.CountDown(); + }); + batch_start += size; + } + return future; + } +} + +template +static ffi::Future CsrSparseDenseKernelTypedDispatch( + ffi::AnyBuffer lhs_data, ffi::AnyBuffer lhs_outer_indicies, + ffi::AnyBuffer lhs_inner_indicies, ffi::AnyBuffer rhs, + ffi::Result out, ffi::ThreadPool thread_pool) { + ffi::Span rhs_shape = rhs.dimensions(); + ffi::Span out_shape = out->dimensions(); + + InputMap> lhs_matrix( + out_shape[0], rhs_shape[0], lhs_data.element_count(), + lhs_outer_indicies.reinterpret_data(), + lhs_inner_indicies.reinterpret_data(), + lhs_data.reinterpret_data()); + + InputMap> rhs_matrix( + rhs.reinterpret_data(), rhs_shape[0], + rhs_shape.size() > 1 ? rhs_shape[1] : 1); + OutputMap> out_matrix( + out->reinterpret_data(), lhs_matrix.rows(), + rhs_matrix.cols()); + + return CsrSparseDenseKernelImpl( + lhs_matrix, rhs_matrix, out_matrix, thread_pool); +} + +template +static ffi::Future CsrSparseDenseKernelTypedDispatch( + ffi::AnyBuffer lhs_data, ffi::AnyBuffer lhs_outer_indicies, + ffi::AnyBuffer lhs_inner_indicies, ffi::AnyBuffer rhs, + ffi::Result out, ffi::ThreadPool thread_pool) { + if (lhs_outer_indicies.element_type() != lhs_inner_indicies.element_type()) { + ffi::Promise promise; + promise.SetError(ffi::Error(ffi::ErrorCode::kInvalidArgument, + "Sparse index type mismatch")); + return ffi::Future(promise); + } + + switch (lhs_outer_indicies.element_type()) { + case ffi::DataType::S32: + return CsrSparseDenseKernelTypedDispatch( + lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out, + thread_pool); + case ffi::DataType::S64: + return CsrSparseDenseKernelTypedDispatch( + lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out, + thread_pool); + default: + ffi::Promise promise; + promise.SetError(ffi::Error(ffi::ErrorCode::kInvalidArgument, + "Invalid index data type")); + return ffi::Future(promise); + } +} + +static ffi::Future CsrSparseDenseKernelDispatch( + ffi::AnyBuffer lhs_data, ffi::AnyBuffer lhs_outer_indicies, + ffi::AnyBuffer lhs_inner_indicies, ffi::AnyBuffer rhs, + ffi::Result out, ffi::ThreadPool thread_pool) { + if (lhs_data.element_type() != rhs.element_type() || + lhs_data.element_type() != out->element_type()) { + ffi::Promise promise; + promise.SetError( + ffi::Error(ffi::ErrorCode::kInvalidArgument, "Element type mismatch")); + return ffi::Future(promise); + } + + switch (lhs_data.element_type()) { + case ffi::DataType::S32: + return CsrSparseDenseKernelTypedDispatch( + lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out, + thread_pool); + case ffi::DataType::S64: + return CsrSparseDenseKernelTypedDispatch( + lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out, + thread_pool); + case ffi::DataType::F32: + return CsrSparseDenseKernelTypedDispatch( + lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out, + thread_pool); + case ffi::DataType::F64: + return CsrSparseDenseKernelTypedDispatch( + lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out, + thread_pool); + case ffi::DataType::C64: + return CsrSparseDenseKernelTypedDispatch>( + lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out, + thread_pool); + case ffi::DataType::C128: + return CsrSparseDenseKernelTypedDispatch>( + lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out, + thread_pool); + default: + ffi::Promise promise; + promise.SetError( + ffi::Error(ffi::ErrorCode::kInvalidArgument, "Invalid data type")); + return ffi::Future(promise); + } +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(cpu_csr_sparse_dense_ffi, + CsrSparseDenseKernelDispatch, + (ffi::Ffi::Bind() + .Arg(/*lhs_data*/) + .Arg( + /*lhs_outer_indicies*/) + .Arg( + /*lhs_inner_indicies*/) + .Arg(/*rhs*/) + .Ret(/*out*/) + .Ctx(/*thread_pool*/))); + +} // namespace jax diff --git a/jaxlib/cpu/sparse_kernels.h b/jaxlib/cpu/sparse_kernels.h new file mode 100644 index 000000000000..856b1da9d36c --- /dev/null +++ b/jaxlib/cpu/sparse_kernels.h @@ -0,0 +1,27 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_PY_JAX_JAXLIB_CPU_SPARSE_KERNELS_H_ +#define THIRD_PARTY_PY_JAX_JAXLIB_CPU_SPARSE_KERNELS_H_ + +#include "xla/ffi/api/ffi.h" + +namespace jax { + +XLA_FFI_DECLARE_HANDLER_SYMBOL(cpu_csr_sparse_dense_ffi); + +} // namespace jax + +#endif // THIRD_PARTY_PY_JAX_JAXLIB_CPU_SPARSE_KERNELS_H_ diff --git a/jaxlib/cpu_sparse.py b/jaxlib/cpu_sparse.py new file mode 100644 index 000000000000..ed43b3ee0f92 --- /dev/null +++ b/jaxlib/cpu_sparse.py @@ -0,0 +1,27 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +from .cpu import _sparse + + +def registrations() -> dict[str, list[tuple[str, Any, int]]]: + api_version = 1 + return { + "cpu": [ + (name, value, api_version) + for name, value in _sparse.registrations().items() + ] + } diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index a9bd35b7768d..e8e405a44f47 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -14,6 +14,7 @@ # NVIDIA CUDA kernels +load("@rules_cc//cc:cc_library.bzl", "cc_library") load("@rules_python//python:defs.bzl", "py_library") load( "//jaxlib:jax.bzl", @@ -64,7 +65,6 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@local_config_cuda//cuda:cublas_headers", "@local_config_cuda//cuda:cuda_headers", "@xla//xla/tsl/cuda:cupti", "@xla//xla/tsl/cuda:cusolver", @@ -89,7 +89,7 @@ cc_library( deps = [ ":cuda_gpu_kernel_helpers", ":cuda_vendor", - "//jaxlib:handle_pool", + "//jaxlib/gpu:handle_pool", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", @@ -98,55 +98,6 @@ cc_library( ], ) -cc_library( - name = "cublas_kernels", - srcs = ["//jaxlib/gpu:blas_kernels.cc"], - hdrs = ["//jaxlib/gpu:blas_kernels.h"], - deps = [ - ":cuda_blas_handle_pool", - ":cuda_gpu_kernel_helpers", - ":cuda_make_batch_pointers", - ":cuda_vendor", - "//jaxlib:kernel_helpers", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@local_config_cuda//cuda:cublas_headers", - "@local_config_cuda//cuda:cuda_headers", - "@xla//xla/service:custom_call_status", - "@xla//xla/tsl/cuda:cublas", - "@xla//xla/tsl/cuda:cudart", - ], -) - -nanobind_extension( - name = "_blas", - srcs = ["//jaxlib/gpu:blas.cc"], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = ["-use_header_modules"], - module_name = "_blas", - deps = [ - ":cublas_kernels", - ":cuda_vendor", - "//jaxlib:kernel_nanobind_helpers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings:str_format", - "@nanobind", - "@xla//xla/tsl/cuda:cublas", - "@xla//xla/tsl/cuda:cudart", - "@xla//xla/tsl/python/lib/core:numpy", - ], -) - cc_library( name = "cudnn_rnn_kernels", srcs = ["//jaxlib/gpu:rnn_kernels.cc"], @@ -155,14 +106,14 @@ cc_library( ":cuda_gpu_kernel_helpers", ":cuda_vendor", ":ffi_wrapper", - "//jaxlib:handle_pool", "//jaxlib:kernel_helpers", + "//jaxlib/gpu:handle_pool", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", "@xla//xla/ffi/api:ffi", - "@xla//xla/service:custom_call_status", "@xla//xla/tsl/cuda:cudart", "@xla//xla/tsl/cuda:cudnn", ], @@ -195,7 +146,7 @@ cc_library( deps = [ ":cuda_gpu_kernel_helpers", ":cuda_vendor", - "//jaxlib:handle_pool", + "//jaxlib/gpu:handle_pool", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", @@ -204,24 +155,6 @@ cc_library( ], ) -cc_library( - name = "cusolver_kernels", - srcs = ["//jaxlib/gpu:solver_kernels.cc"], - hdrs = ["//jaxlib/gpu:solver_kernels.h"], - deps = [ - ":cuda_gpu_kernel_helpers", - ":cuda_solver_handle_pool", - ":cuda_vendor", - "//jaxlib:kernel_helpers", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@local_config_cuda//cuda:cuda_headers", - "@xla//xla/service:custom_call_status", - "@xla//xla/tsl/cuda:cudart", - "@xla//xla/tsl/cuda:cusolver", - ], -) - cc_library( name = "cusolver_interface", srcs = ["//jaxlib/gpu:solver_interface.cc"], @@ -272,21 +205,14 @@ nanobind_extension( features = ["-use_header_modules"], module_name = "_solver", deps = [ - ":cuda_gpu_kernel_helpers", - ":cuda_solver_handle_pool", ":cuda_vendor", - ":cusolver_kernels", ":cusolver_kernels_ffi", "//jaxlib:kernel_nanobind_helpers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:str_format", "@local_config_cuda//cuda:cuda_headers", "@nanobind", "@xla//xla/tsl/cuda:cublas", "@xla//xla/tsl/cuda:cudart", "@xla//xla/tsl/cuda:cusolver", - "@xla//xla/tsl/python/lib/core:numpy", ], ) @@ -308,15 +234,16 @@ cc_library( ":cuda_gpu_kernel_helpers", ":cuda_vendor", ":ffi_wrapper", - "//jaxlib:handle_pool", + "//jaxlib:ffi_helpers", "//jaxlib:kernel_helpers", + "//jaxlib/gpu:handle_pool", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", "@xla//xla/ffi/api:ffi", - "@xla//xla/service:custom_call_status", "@xla//xla/tsl/cuda:cudart", "@xla//xla/tsl/cuda:cusparse", ], @@ -423,7 +350,6 @@ cc_library( "@local_config_cuda//cuda:cuda_headers", "@xla//xla/ffi/api:c_api", "@xla//xla/ffi/api:ffi", - "@xla//xla/service:custom_call_status", ], ) @@ -439,7 +365,6 @@ cuda_library( "//jaxlib:kernel_helpers", "@local_config_cuda//cuda:cuda_headers", "@xla//xla/ffi/api:ffi", - "@xla//xla/service:custom_call_status", ], ) @@ -455,6 +380,7 @@ nanobind_extension( deps = [ ":cuda_gpu_kernel_helpers", ":cuda_prng_kernels", + ":cuda_vendor", "//jaxlib:kernel_nanobind_helpers", "@local_config_cuda//cuda:cuda_headers", "@nanobind", @@ -511,15 +437,14 @@ cc_library( srcs = ["//jaxlib/gpu:gpu_kernels.cc"], visibility = ["//visibility:public"], deps = [ - ":cublas_kernels", ":cuda_linalg_kernels", ":cuda_prng_kernels", ":cuda_vendor", ":cudnn_rnn_kernels", - ":cusolver_kernels", ":cusolver_kernels_ffi", ":cusparse_kernels", ":triton_kernels", + "//jaxlib/mosaic/gpu:mosaic_gpu_support", "@xla//xla/ffi/api:c_api", "@xla//xla/ffi/api:ffi", "@xla//xla/service:custom_call_target_registry", @@ -546,9 +471,9 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", - "@tsl//tsl/platform:env", "@xla//xla/service:custom_call_status", "@xla//xla/stream_executor/cuda:cuda_asm_compiler", + "@xla//xla/stream_executor/cuda:cuda_compute_capability", "@xla//xla/tsl/cuda:cudart", ], ) @@ -586,6 +511,7 @@ nanobind_extension( "//jaxlib:absl_status_casters", "//jaxlib:kernel_nanobind_helpers", "//jaxlib/gpu:triton_cc_proto", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@nanobind", ], @@ -644,7 +570,6 @@ nanobind_extension( py_library( name = "cuda_gpu_support", deps = [ - ":_blas", ":_hybrid", ":_linalg", ":_prng", @@ -657,11 +582,51 @@ py_library( ], ) +cc_library( + name = "py_client_gpu", + srcs = ["//jaxlib/gpu:py_client_gpu.cc"], + hdrs = ["//jaxlib/gpu:py_client_gpu.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":cuda_vendor", + "//jaxlib:ffi", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@dlpack", + "@nanobind", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:comparison_util", + "@xla//xla:shape_util", + "@xla//xla:util", + "@xla//xla:xla_data_proto_cc", + "@xla//xla/ffi:ffi_api", + "@xla//xla/ffi/api:ffi", + "@xla//xla/pjrt:host_callback", + "@xla//xla/pjrt:transpose", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:types", + "@xla//xla/service:platform_util", + ], +) + nanobind_extension( name = "cuda_plugin_extension", srcs = ["cuda_plugin_extension.cc"], module_name = "cuda_plugin_extension", deps = [ + ":py_client_gpu", + "//jaxlib:kernel_nanobind_helpers", "//jaxlib/gpu:gpu_plugin_extension", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", diff --git a/jaxlib/cuda/cuda_plugin_extension.cc b/jaxlib/cuda/cuda_plugin_extension.cc index 8d8514bd2740..20b0785cf439 100644 --- a/jaxlib/cuda/cuda_plugin_extension.cc +++ b/jaxlib/cuda/cuda_plugin_extension.cc @@ -16,17 +16,20 @@ limitations under the License. #include #include -#include "nanobind/nanobind.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "third_party/gpus/cuda/include/cuda.h" +#include "nanobind/nanobind.h" #include "jaxlib/gpu/gpu_plugin_extension.h" +#include "jaxlib/gpu/py_client_gpu.h" +#include "jaxlib/kernel_nanobind_helpers.h" #include "xla/pjrt/status_casters.h" namespace nb = nanobind; -namespace xla { +namespace jax { namespace { + static std::string ToString(CUresult result) { const char* error_name; if (cuGetErrorName(result, &error_name)) { @@ -38,10 +41,44 @@ static std::string ToString(CUresult result) { } return absl::StrCat(error_name, ": ", error_string); } + +static nb::dict GpuTransposePlanCacheType() { + auto [type_id, type_info] = cuda::GpuTransposePlanCacheTypeInfo(); + nb::dict d; + d["type_id"] = nb::capsule(type_id); + d["type_info"] = nb::capsule(type_info); + return d; +} + +nb::dict FfiTypes() { + nb::dict dict; + dict["GpuTransposePlanCache"] = GpuTransposePlanCacheType(); + return dict; +} + +nb::dict FfiHandlers() { + nb::dict dict; + nb::dict gpu_callback_dict; + gpu_callback_dict["instantiate"] = + EncapsulateFfiHandler(cuda::kGpuTransposePlanCacheInstantiate); + gpu_callback_dict["execute"] = + EncapsulateFfiHandler(cuda::kXlaFfiPythonGpuCallback); + dict["xla_ffi_python_gpu_callback"] = gpu_callback_dict; + dict["xla_ffi_partitioned_python_gpu_callback"] = gpu_callback_dict; + dict["xla_buffer_python_gpu_callback"] = + EncapsulateFfiHandler(cuda::kXlaBufferPythonGpuCallback); + dict["xla_buffer_python_gpu_callback_cmd_buffer"] = + EncapsulateFfiHandler(cuda::kXlaBufferPythonGpuCallbackCmdBuffer); + return dict; +} + } // namespace NB_MODULE(cuda_plugin_extension, m) { BuildGpuPluginExtension(m); + m.def("ffi_types", &FfiTypes); + m.def("ffi_handlers", &FfiHandlers); + m.def( "get_device_ordinal", [](std::intptr_t data_value) { @@ -62,4 +99,4 @@ NB_MODULE(cuda_plugin_extension, m) { }, nb::arg("data_value")); } -} // namespace xla +} // namespace jax diff --git a/jaxlib/cuda/versions.cc b/jaxlib/cuda/versions.cc index 8d6577f46709..d9f9f4c86865 100644 --- a/jaxlib/cuda/versions.cc +++ b/jaxlib/cuda/versions.cc @@ -13,9 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "jaxlib/cuda/versions_helpers.h" - #include "nanobind/nanobind.h" +#include "jaxlib/cuda/versions_helpers.h" #include "jaxlib/gpu/vendor.h" namespace jax::cuda { diff --git a/jaxlib/cuda/versions_helpers.cc b/jaxlib/cuda/versions_helpers.cc index d42199d37467..508a92c326cb 100644 --- a/jaxlib/cuda/versions_helpers.cc +++ b/jaxlib/cuda/versions_helpers.cc @@ -16,6 +16,7 @@ limitations under the License. #include "jaxlib/cuda/versions_helpers.h" #include +#include #include #include "absl/base/dynamic_annotations.h" diff --git a/jaxlib/custom_call_sharding.cc b/jaxlib/custom_call_sharding.cc new file mode 100644 index 000000000000..29743656fb84 --- /dev/null +++ b/jaxlib/custom_call_sharding.cc @@ -0,0 +1,357 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/custom_call_sharding.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/tuple.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/utils/hlo_sharding_util.h" +#include "xla/pjrt/c/pjrt_c_api.h" +#include "xla/pjrt/c/pjrt_c_api_custom_partitioner_extension.h" +#include "xla/pjrt/c/pjrt_c_api_helpers.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/custom_call_batch_partitioner.h" +#include "xla/python/custom_partition_callback.h" +#include "xla/python/inspect_sharding.h" +#include "xla/shape.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" + +namespace jax { + +namespace nb = ::nanobind; + +class PyCustomCallPartitionerCallbacks { + public: + PyCustomCallPartitionerCallbacks(nb::object prop_user_sharding, + nb::object partition, + nb::object infer_sharding_from_operands) + : prop_user_sharding_(prop_user_sharding), + partition_(partition), + infer_sharding_from_operands_(infer_sharding_from_operands) { + callbacks_.version = 0; + callbacks_.private_data = this; + callbacks_.dtor = +[](JAX_CustomCallPartitioner_Callbacks* self) { + delete GetSelfPtr(self); + }; + callbacks_.partition = +[](JAX_CustomCallPartitioner_Callbacks* self, + JAX_CustomCallPartitioner_Partition_Args* args) { + PopulateResults(GetSelfPtr(self)->CallPartition(args), args); + }; + callbacks_.infer_sharding = + +[](JAX_CustomCallPartitioner_Callbacks* self, + JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args) { + PopulateResults(GetSelfPtr(self)->CallInferShardingFromOperands(args), + args); + }; + callbacks_.propagate_user_sharding = + +[](JAX_CustomCallPartitioner_Callbacks* self, + JAX_CustomCallPartitioner_PropagateUserSharding_Args* args) { + PopulateResults(GetSelfPtr(self)->CallPropagateUserSharding(args), + args); + }; + } + + absl::StatusOr< + std::tuple, xla::HloSharding>> + CallPartition(JAX_CustomCallPartitioner_Partition_Args* args) const { + if (args->header.api_version != 0) { + return absl::InternalError("API version mismatch."); + } + TF_ASSIGN_OR_RETURN(auto args_tuple, ReadArgs(args)); + std::vector shapes = std::move(std::get<0>(args_tuple)); + std::vector> shardings = + std::move(std::get<1>(args_tuple)); + xla::Shape result_shape = std::move(std::get<2>(args_tuple)); + std::optional result_sharding = + std::move(std::get<3>(args_tuple)); + std::string_view backend_config = std::move(std::get<4>(args_tuple)); + + { + nb::gil_scoped_acquire gil; + try { + auto py_result = + partition_(shapes, shardings, result_shape, result_sharding, + nb::bytes(backend_config.data(), backend_config.size())); + try { + auto [ir, arg_shardings, result_sharding] = + nb::cast, + xla::HloSharding>>(py_result); + if (arg_shardings.size() != args->num_args) { + return xla::Internal( + "Shardings returned from partitioning: lengths must match: %d " + "vs %d", + arg_shardings.size(), args->num_args); + } + return std::make_tuple(std::string(ir.c_str(), ir.size()), + std::move(arg_shardings), + std::move(result_sharding)); + } catch (const nb::cast_error& e) { + return xla::Internal( + "Shardings returned from partitioning: expected " + "Tuple[bytes, List[HloSharding], HloSharding] got: %s", + nb::cast(nb::repr(py_result))); + } + } catch (const nb::python_error& e) { + return xla::Internal("custom_partitioner: %s", e.what()); + } + } + } + + absl::StatusOr> CallInferShardingFromOperands( + JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args) const { + if (args->header.api_version != 0) { + return absl::InternalError("API version mismatch."); + } + TF_ASSIGN_OR_RETURN(auto args_tuple, ReadArgs(args)); + std::vector arg_shapes = std::move(std::get<0>(args_tuple)); + std::vector> arg_shardings = + std::move(std::get<1>(args_tuple)); + xla::Shape result_shape = std::move(std::get<2>(args_tuple)); + std::string_view backend_config = std::move(std::get<3>(args_tuple)); + + std::optional result; + nb::gil_scoped_acquire gil; + try { + auto py_result = infer_sharding_from_operands_( + arg_shapes, arg_shardings, result_shape, + nb::bytes(backend_config.data(), backend_config.size())); + if (py_result.is_none()) { + return std::nullopt; + } + return nb::cast(py_result); + } catch (const nb::python_error& e) { + return xla::Internal("custom_partitioner: %s", e.what()); + } + } + + absl::StatusOr CallPropagateUserSharding( + JAX_CustomCallPartitioner_PropagateUserSharding_Args* args) const { + if (args->header.api_version != 0) { + return absl::InternalError("API version mismatch."); + } + TF_ASSIGN_OR_RETURN(auto args_tuple, ReadArgs(args)); + xla::HloSharding result_sharding = std::move(std::get<0>(args_tuple)); + xla::Shape result_shape = std::move(std::get<1>(args_tuple)); + std::string_view backend_config = std::move(std::get<2>(args_tuple)); + + nb::gil_scoped_acquire gil; + try { + // TODO(parkers): expand this API to handle the `user` sharding. + // The user is used when the custom call returns a Tuple and + // the user is a get-tuple-element. In this case we must update only + // part of the sharding spec. + auto result = nb::cast(prop_user_sharding_( + result_sharding, result_shape, + nb::bytes(backend_config.data(), backend_config.size()))); + return result; + } catch (const nb::python_error& e) { + return xla::Internal("custom_partitioner: %s", e.what()); + } + } + + JAX_CustomCallPartitioner_Callbacks* callbacks() { return &callbacks_; } + + private: + static PyCustomCallPartitionerCallbacks* GetSelfPtr( + JAX_CustomCallPartitioner_Callbacks* callbacks) { + return reinterpret_cast( + callbacks->private_data); + } + + JAX_CustomCallPartitioner_Callbacks callbacks_; + nb::object prop_user_sharding_; + nb::object partition_; + nb::object infer_sharding_from_operands_; +}; + +namespace { + +void CallInspectSharding(void* obj, JAX_InspectSharding_Callback_Args* args) { + std::optional arg = InspectShardingReadArgs(args); + if (!arg.has_value()) { + return; + } + try { + nb::gil_scoped_acquire gil; + nb::handle(reinterpret_cast(obj))(*std::move(arg)); + } catch (const nb::python_error& e) { + InspectShardingSetError(args, std::string(e.what())); + } +} + +} // namespace + +void BuildCustomCallShardingPybindAPI(nb::module_& m) { + m.def( + "register_custom_call_partitioner", + [](std::string name, nb::object prop_user_sharding, nb::object partition, + nb::object infer_sharding_from_operands, + bool can_side_effecting_have_replicated_sharding, + std::optional c_api) { + auto* c_fns = + (new PyCustomCallPartitionerCallbacks(prop_user_sharding, partition, + infer_sharding_from_operands)) + ->callbacks(); + c_fns->can_side_effecting_have_replicated_sharding = + can_side_effecting_have_replicated_sharding; + if (!c_api.has_value()) { + RegisterCustomCallPartitioner(name, + CreateCApiCustomCallPartitioner(c_fns)); + return; + } + + if (std::string_view(c_api->name()) != "pjrt_c_api") { + throw absl::InvalidArgumentError( + "Argument to register_custom_call_partitioner was not a " + "pjrt_c_api capsule."); + } + auto* c_api_value = static_cast(c_api->data()); + PJRT_Custom_Partitioner_Extension* extension = + pjrt::FindExtension( + c_api_value, + PJRT_Extension_Type::PJRT_Extension_Type_Custom_Partitioner); + if (extension == nullptr) { + return; + } + PJRT_Register_Custom_Partitioner_Args args; + args.struct_size = PJRT_Register_Custom_Partitioner_Args_STRUCT_SIZE; + args.name = name.c_str(); + args.name_size = name.size(); + args.callbacks = c_fns; + PJRT_Error* error = + reinterpret_cast( + extension) + ->register_custom_partitioner(&args); + std::unique_ptr error_ptr( + error, pjrt::MakeErrorDeleter(c_api_value)); + xla::ThrowIfError( + pjrt::PjrtErrorToStatus(error_ptr.get(), c_api_value)); + }, + R"(Registers a partitioner for a custom-call operation. + +Args: + name: custom_call_target to match. + prop_user_sharding: Custom backwards sharding propagation rule. + Takes result sharding and returns the instruction sharding. + partition: Lowering rule. Takes operand and result shardings and returns + a generated HLO and sharding specs. The spmd lowerer first reshards + to match the returned sharding specs and then inserts the generated hlo. + infer_sharding_from_operands: Custom forwards sharding propagation rule. + Takes operand sharding and returns the instruction sharding. + can_side_effecting_have_replicated_sharding: Side effecting ops are not + allowed to have replicated sharding. Pass true to disable this check. + c_api: Optional `PJRT_Api*` if it is called with a plugin. This is safe to + call on plugins that do not implement the custom partitioner extension +)", + nb::arg("name"), nb::arg("prop_user_sharding"), nb::arg("partition"), + nb::arg("infer_sharding_from_operands"), + nb::arg("can_side_effecting_have_replicated_sharding") = false, + nb::arg("c_api").none() = std::nullopt); + m.def("encode_inspect_sharding_callback", + [](nb::object handler) -> nb::bytes { + JAX_InspectSharding_Callback cb; + cb.call = &CallInspectSharding; + cb.data = handler.ptr(); + char bytes[sizeof(JAX_InspectSharding_Callback)]; + std::memcpy(&bytes, &cb, sizeof(JAX_InspectSharding_Callback)); + return nb::bytes(bytes, sizeof(JAX_InspectSharding_Callback)); + }); + + nb::module_ hlo_sharding_util_m = m.def_submodule( + "hlo_sharding_util", "Utilities for manipulating HloSharding."); + hlo_sharding_util_m.attr("_HloSharding") = m.attr("HloSharding"); + hlo_sharding_util_m.def( + "PartiallyReplicateTiledShardingOnDims", + [](const xla::HloSharding& sharding, std::vector dims) { + return xla::hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( + sharding, dims); + }, + nb::sig( + // clang-format off + "def PartiallyReplicateTiledShardingOnDims(" + "sharding: _HloSharding, " + "dims: typing.Sequence[int], /" + ") -> _HloSharding" + // clang-format on + )); + + m.def( + "register_custom_call_as_batch_partitionable", + [](std::string target_name, std::optional c_api) { + if (!c_api.has_value()) { + RegisterCustomCallPartitioner( + target_name, std::make_unique()); + return; + } + if (std::string_view(c_api->name()) != "pjrt_c_api") { + throw absl::InvalidArgumentError( + "Argument to register_custom_call_partitioner was not a " + "pjrt_c_api capsule."); + } + auto* c_api_value = static_cast(c_api->data()); + PJRT_Custom_Partitioner_Extension* extension = + pjrt::FindExtension( + c_api_value, + PJRT_Extension_Type::PJRT_Extension_Type_Custom_Partitioner); + if (extension == nullptr) { + return; + } + PJRT_Register_Batch_Partitionable_Args args; + args.struct_size = PJRT_Register_Batch_Partitionable_Args_STRUCT_SIZE; + args.name = target_name.c_str(); + args.name_size = target_name.size(); + PJRT_Error* error = extension->register_batch_partitionable(&args); + std::unique_ptr error_ptr( + error, pjrt::MakeErrorDeleter(c_api_value)); + xla::ThrowIfError( + pjrt::PjrtErrorToStatus(error_ptr.get(), c_api_value)); + }, + R"(Registers a custom call as batch partitionable. + +If a custom call is "batch partitionable", it means that it can be trivially +partitioned on some number of (leading) dimensions, with the same call being +executed independently on each shard of data. If the data are sharded on +non-batch dimensions, partitioning will re-shard the data to be replicated on +the non-batch dimensions. + +Args: + target_name: the target name of the batch partitionable custom call. + c_api: optional `PJRT_Api*` to support registration via a PJRT plugin. +)", + nb::arg("target_name"), nb::arg("c_api").none() = std::nullopt); +} + +} // namespace jax diff --git a/jaxlib/custom_call_sharding.h b/jaxlib/custom_call_sharding.h new file mode 100644 index 000000000000..cb6054c5237e --- /dev/null +++ b/jaxlib/custom_call_sharding.h @@ -0,0 +1,28 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_CUSTOM_CALL_SHARDING_H_ +#define JAXLIB_CUSTOM_CALL_SHARDING_H_ + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" + +namespace jax { + +void BuildCustomCallShardingPybindAPI(nanobind::module_& m); + +} // namespace jax + +#endif // JAXLIB_CUSTOM_CALL_SHARDING_H_ diff --git a/jaxlib/dlpack.cc b/jaxlib/dlpack.cc new file mode 100644 index 000000000000..14ccb51a85c1 --- /dev/null +++ b/jaxlib/dlpack.cc @@ -0,0 +1,415 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/dlpack.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" +#include "include/dlpack/dlpack.h" +#include "llvm/Support/Casting.h" +#include "nanobind/nanobind.h" +#include "nanobind/ndarray.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_array.h" +#include "jaxlib/py_client.h" +#include "jaxlib/py_user_context.h" +#include "jaxlib/python_ref_manager.h" +#include "jaxlib/util.h" +#include "xla/layout.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_common.h" +#include "xla/pjrt/pjrt_compiler.h" +#include "xla/python/dlpack_types.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/python/pjrt_ifrt/pjrt_device.h" +#include "xla/python/types.h" +#include "xla/python/version.h" +#include "xla/shape_util.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace ifrt = xla::ifrt; +namespace nb = nanobind; + +namespace jax { +namespace { + +const char* const kDlTensorCapsuleName = "dltensor"; + +struct DLPackTensor { + ~DLPackTensor(); + + // `buffer_reference` is populated if we have shared (read-only) access. + nb::object buffer_reference; + + // `external_reference` is always populated. + std::unique_ptr external_reference; + + std::vector shape; + std::vector strides; + DLManagedTensor tensor; +}; + +DLPackTensor::~DLPackTensor() { + // We must release the external reference first before deleting the array. + external_reference.reset(); + if (buffer_reference) { + GlobalPyRefManager()->AddGarbage( + absl::MakeSpan(&buffer_reference, /*size=*/1)); + } +} + +void DLPackTensorDeleter(DLManagedTensor* t) { + if (t) { + delete static_cast(t->manager_ctx); + } +} + +absl::StatusOr> StridesToLayout( + absl::Span dims, absl::Span strides) { + CHECK_EQ(dims.size(), strides.size()); + std::vector minor_to_major(dims.size()); + std::iota(minor_to_major.begin(), minor_to_major.end(), 0); + absl::c_sort(minor_to_major, [&](int a, int b) { + if (strides[a] < strides[b]) { + return true; + } + if (strides[a] > strides[b]) { + return false; + } + // If two dimensions have the same stride, prefer the major-to-minor + // interpretation of the ordering, since that's what JAX wants. + return b < a; + }); + int64_t stride = 1; + for (int64_t d : minor_to_major) { + if (dims[d] > 1 && strides[d] != stride) { + return xla::Unimplemented( + "Only DLPack tensors with trivial (compact) striding are supported; " + "i.e., tensors whose striding represents a transposition of the " + "underlying buffer but not broadcasting. Dimensions were: [%s], " + "strides were [%s].", + absl::StrJoin(dims, ","), absl::StrJoin(strides, ",")); + } + stride *= dims[d]; + } + return minor_to_major; +} + +absl::StatusOr DLDeviceTypeForDevice( + const xla::PjRtDevice& device) { + if (device.client()->platform_id() == xla::CpuId()) { + return kDLCPU; + } else if (device.client()->platform_id() == xla::CudaId()) { + return kDLCUDA; + } else if (device.client()->platform_id() == xla::RocmId()) { + return kDLROCM; + } + return xla::InvalidArgument("Device %s cannot be used as a DLPack device.", + device.DebugString()); +} + +absl::StatusOr DLDeviceForDevice(const xla::PjRtDevice& device) { + DLDevice context; + TF_ASSIGN_OR_RETURN(context.device_type, DLDeviceTypeForDevice(device)); + context.device_id = device.local_hardware_id().value(); + return context; +} + +absl::Status VerifyDType(const DLTensor& dl_tensor) { + if (dl_tensor.dtype.bits % 8 != 0) { + return xla::InvalidArgument( + "Unsupported DLPack tensor dtype: bits should be a multiple of 8, got " + "%d", + dl_tensor.dtype.bits); + } + + if (dl_tensor.dtype.lanes != 1) { + return xla::InvalidArgument( + "Unsupported DLPack tensor dtype: lanes should be equal to 1, got %d", + dl_tensor.dtype.lanes); + } + + return absl::OkStatus(); +} + +absl::StatusOr> GetByteStrides(const DLTensor& dl_tensor) { + TF_RETURN_IF_ERROR(VerifyDType(dl_tensor)); + + // Convert element strides from the number of elements to the number of bytes. + std::vector strides; + strides.reserve(dl_tensor.ndim); + for (int i = 0; i < dl_tensor.ndim; ++i) { + strides.push_back(dl_tensor.strides[i] * dl_tensor.dtype.bits / 8); + } + return strides; +} + +// Makes a PjRtBuffer from a DLPack tensor. Returns a pair where the second +// element is true if a copy actually happened. +absl::StatusOr, bool>> +MakePjrtBuffer(xla::PjRtDevice& device, ::DLManagedTensor* dlmt, + const xla::Shape& shape, xla::PrimitiveType element_type, + absl::Span dimensions, + std::optional copy = std::nullopt, + std::optional stream = std::nullopt) { + std::function on_delete_callback; + if (dlmt->deleter) { + on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); }; + } + + void* data = + static_cast(dlmt->dl_tensor.data) + dlmt->dl_tensor.byte_offset; + + // On CPU, creating a view may fail because of unaligned data buffer + // in which case we'll fallback to copy. On non-CPU, array-api copy + // semantics is handled in dlpack._place_array function. + bool fallback_to_copy = + !copy.has_value() && dlmt->dl_tensor.device.device_type == kDLCPU; + + // Create a view. + if (!copy.value_or(false)) { + auto result = device.client()->CreateViewOfDeviceBuffer( + data, shape, *device.default_memory_space(), on_delete_callback, + stream); + if (!(result.status().code() == absl::StatusCode::kInvalidArgument && + fallback_to_copy)) { + TF_RETURN_IF_ERROR(result.status()); + return std::make_pair(*std::move(result), false); + } + } + + // Convert tensor strides (expressed in number of elements) to byte strides. + std::optional> byte_strides; + if (dlmt->dl_tensor.strides) { + TF_ASSIGN_OR_RETURN(byte_strides, GetByteStrides(dlmt->dl_tensor)); + } + + TF_ASSIGN_OR_RETURN(auto* memory_space, device.default_memory_space()); + + // Create a copy. + TF_ASSIGN_OR_RETURN( + auto buffer, + device.client()->BufferFromHostBuffer( + data, element_type, dimensions, byte_strides, + xla::PjRtClient::HostBufferSemantics::kMutableZeroCopy, + on_delete_callback, memory_space, /*device_layout=*/nullptr)); + return std::make_pair(std::move(buffer), true); +} + +} // namespace + +absl::StatusOr BufferToDLPackManagedTensor( + nb::handle py_buffer, std::optional stream) { + ifrt::Array* ifrt_array = nb::cast(py_buffer).ifrt_array(); + if (ifrt_array == nullptr) { + return xla::Unimplemented( + "BufferToDLPackManagedTensor called on deleted array."); + } + auto* arr = llvm::dyn_cast_or_null(ifrt_array); + if (arr == nullptr) { + throw xla::XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + xla::PjRtBuffer* pjrt_buffer = arr->pjrt_buffers().front().get(); + + if (pjrt_buffer->IsTuple()) { + return xla::Unimplemented( + "BufferToDLPackManagedTensor is not implemented for tuple " + "buffers."); + } + if (pjrt_buffer->has_dynamic_dimensions()) { + return xla::Unimplemented("DynamicShape is not implemented in DLPack."); + } + + auto pack = std::make_unique(); + DLTensor& dt = pack->tensor.dl_tensor; + { + // AcquireExternalReference may block; there are no API guarantees. + GlobalPyRefManager()->CollectGarbage(); + nb::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN(pack->external_reference, + pjrt_buffer->AcquireExternalReference()); + if (stream) { + TF_RETURN_IF_ERROR( + pack->external_reference->WaitUntilBufferReadyOnStream(*stream)); + } else { + TF_RETURN_IF_ERROR( + AwaitBuffersReady(absl::MakeConstSpan(&ifrt_array, 1))); + } + } + pack->buffer_reference = nb::borrow(py_buffer); + + dt.data = pack->external_reference->OpaqueDeviceMemoryDataPointer(); + pack->tensor.manager_ctx = pack.get(); + pack->tensor.deleter = DLPackTensorDeleter; + TF_ASSIGN_OR_RETURN(dt.device, DLDeviceForDevice(*pjrt_buffer->device())); + dt.device.device_id = pjrt_buffer->device()->local_hardware_id().value(); + dt.ndim = pjrt_buffer->dimensions().size(); + TF_ASSIGN_OR_RETURN(dt.dtype, + PrimitiveTypeToDLDataType(pjrt_buffer->element_type())); + + pack->shape = std::vector(pjrt_buffer->dimensions().begin(), + pjrt_buffer->dimensions().end()); + + // TODO(b/327524065): use PjRtLayout directly instead of xla::Layout + xla::Layout xla_layout = pjrt_buffer->layout()->xla_layout(); + pack->strides = StridesForShape(pjrt_buffer->element_type(), + pjrt_buffer->dimensions(), xla_layout); + + dt.shape = reinterpret_cast(pack->shape.data()); + dt.strides = reinterpret_cast(pack->strides.data()); + dt.byte_offset = 0; + + // We cannot use nanobind's capsule object constructor because we need to + // detect if the capsule name has been changed in the deleter, but nanobind + // hides the underlying Python object from the deleter. + nb::capsule capsule = nb::steal( + PyCapsule_New(&pack.release()->tensor, kDlTensorCapsuleName, + [](PyObject* obj) noexcept { +#if PY_VERSION_HEX < 0x030C0000 + PyObject *type, *value, *traceback; + PyErr_Fetch(&type, &value, &traceback); +#else // PY_VERSION_HEX < 0x030C0000 + PyObject* exc = PyErr_GetRaisedException(); +#endif // PY_VERSION_HEX < 0x030C0000 + DLManagedTensor* dlmt = static_cast( + PyCapsule_GetPointer(obj, kDlTensorCapsuleName)); + if (dlmt) { + DLPackTensorDeleter(dlmt); + } + // PyCapsule_GetPointer may have raised. Restore the + // previous exception if there was one. +#if PY_VERSION_HEX < 0x030C0000 + PyErr_Restore(type, value, traceback); +#else // PY_VERSION_HEX < 0x030C0000 + PyErr_SetRaisedException(exc); +#endif // PY_VERSION_HEX < 0x030C0000 + })); + if (!capsule.ptr()) { + throw nb::python_error(); + } + return capsule; +} + +absl::StatusOr DLPackManagedTensorToBuffer( + const nb::capsule& tensor, ifrt::Device* ifrt_device, + nb_class_ptr client, std::optional stream, + std::optional copy) { + ifrt::PjRtDevice* device = + llvm::dyn_cast_or_null(ifrt_device); + if (device == nullptr) { + throw xla::XlaRuntimeError( + "DLPack is supported for PjRt-compatible backends only."); + } + if (!device->IsAddressable()) { + throw xla::XlaRuntimeError( + "DLPack is only supported for devices addressable by the current " + "process."); + } + if (std::string_view(tensor.name()) != kDlTensorCapsuleName) { + return xla::InvalidArgument( + "DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". " + "Note that a DLPack tensor may be consumed at most once.", + std::string_view(tensor.name())); + } + DLManagedTensor* dlmt = static_cast(tensor.data()); + if (dlmt->dl_tensor.ndim < 0) { + return xla::InvalidArgument( + "Number of dimensions in DLManagedTensor must be nonnegative, got %d", + dlmt->dl_tensor.ndim); + } + absl::Span dimensions( + reinterpret_cast(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim); + TF_ASSIGN_OR_RETURN(xla::PrimitiveType element_type, + xla::DLDataTypeToPrimitiveType(dlmt->dl_tensor.dtype)); + + bool has_custom_layout = dlmt->dl_tensor.strides != nullptr; + std::vector minor_to_major; + if (dlmt->dl_tensor.strides && + absl::c_find(dimensions, 0) == dimensions.end()) { + absl::Span strides( + reinterpret_cast(dlmt->dl_tensor.strides), + dlmt->dl_tensor.ndim); + TF_ASSIGN_OR_RETURN(minor_to_major, StridesToLayout(dimensions, strides)); + } else { + minor_to_major.resize(dlmt->dl_tensor.ndim); + std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0); + } + xla::Shape shape = xla::ShapeUtil::MakeShapeWithDenseLayout( + element_type, dimensions, minor_to_major); + + TF_ASSIGN_OR_RETURN(auto pjrt_buffer_and_copied, + MakePjrtBuffer(*device->pjrt_device(), dlmt, shape, + element_type, dimensions, copy, stream)); + if (pjrt_buffer_and_copied.second) { + // A PjRtBuffer uses a default layout if it has been created using copy. + has_custom_layout = false; + } + + // We have taken ownership of the array inside the capsule; make sure the + // capsule it cannot be used again. + PyCapsule_SetName(tensor.ptr(), "used_dltensor"); + PyCapsule_SetDestructor(tensor.ptr(), nullptr); + + auto* ifrt_client = + llvm::dyn_cast_or_null(client->ifrt_client()); + if (ifrt_client == nullptr) { + throw xla::XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + PyUserContextScope user_context_scope; + TF_ASSIGN_OR_RETURN( + auto ifrt_array, + ifrt_client->CreatePjRtArray(std::move(pjrt_buffer_and_copied.first), + has_custom_layout)); + return PyArray::MakeFromSingleDeviceArray(std::move(client), + std::move(ifrt_array), false, true); +} + +absl::StatusOr PrimitiveTypeToNbDLDataType( + xla::PrimitiveType type) { + TF_ASSIGN_OR_RETURN(DLDataType dl_type, PrimitiveTypeToDLDataType(type)); + + nanobind::dlpack::dtype nb_type; + nb_type.lanes = dl_type.lanes; + nb_type.bits = dl_type.bits; + nb_type.code = dl_type.code; + + return nb_type; +} + +} // namespace jax diff --git a/jaxlib/dlpack.h b/jaxlib/dlpack.h new file mode 100644 index 000000000000..bed6006a37d8 --- /dev/null +++ b/jaxlib/dlpack.h @@ -0,0 +1,54 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_DLPACK_H_ +#define JAXLIB_DLPACK_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "nanobind/nanobind.h" +#include "nanobind/ndarray.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" +#include "xla/python/ifrt/device.h" +#include "xla/xla_data.pb.h" + +namespace jax { + +// If take_ownership is true, ownership of the buffer is handed to DLPack, and +// the receiver may mutate the buffer as they see fit. Otherwise PjRt retains +// ownership of the buffer and it should be immutable. +// +// stream, if set, is a GPU stream, e.g. cudaStream_t for CUDA GPUs, that should +// be synchronized to the buffer as per +// https://dmlc.github.io/dlpack/latest/python_spec.html#python-specification-for-dlpack. +absl::StatusOr BufferToDLPackManagedTensor( + nanobind::handle buffer, std::optional stream); + +absl::StatusOr DLPackManagedTensorToBuffer( + const nanobind::capsule& tensor, xla::ifrt::Device* device, + nb_class_ptr client, std::optional stream, + std::optional copy); + +// Converts a PrimitiveType to the nanobind specific implementation of +// DLDataType. +absl::StatusOr PrimitiveTypeToNbDLDataType( + xla::PrimitiveType type); + +} // namespace jax + +#endif // JAXLIB_DLPACK_H_ diff --git a/jaxlib/ffi.cc b/jaxlib/ffi.cc new file mode 100644 index 000000000000..81e83896afc2 --- /dev/null +++ b/jaxlib/ffi.cc @@ -0,0 +1,547 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/ffi.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/casts.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/types/span.h" +#include "include/dlpack/dlpack.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/api/ffi.h" +#include "xla/ffi/ffi_api.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/dlpack_types.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/types.h" +#include "xla/service/custom_call_target_registry.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/xla_data.pb.h" + +namespace jax { + +namespace ffi = xla::ffi; +namespace nb = nanobind; + +namespace { +const char* const kDlTensorCapsuleName = "dltensor"; +const char* const kDlTensorVersionedCapsuleName = "dltensor_versioned"; + +template +struct DLPackTensor { + std::vector shape; + ManagedTensor tensor; +}; + +template +void DLPackTensorDeleter(ManagedTensor* t) { + if (t) { + delete static_cast*>(t->manager_ctx); + } +} + +xla::PrimitiveType PrimitiveTypeForFfiDataType(ffi::DataType dtype) { + switch (dtype) { + case ffi::DataType::INVALID: + return xla::PrimitiveType::PRIMITIVE_TYPE_INVALID; + case ffi::PRED: + return xla::PrimitiveType::PRED; + case ffi::S1: + return xla::PrimitiveType::S1; + case ffi::S2: + return xla::PrimitiveType::S2; + case ffi::S4: + return xla::PrimitiveType::S4; + case ffi::S8: + return xla::PrimitiveType::S8; + case ffi::S16: + return xla::PrimitiveType::S16; + case ffi::S32: + return xla::PrimitiveType::S32; + case ffi::S64: + return xla::PrimitiveType::S64; + case ffi::U1: + return xla::PrimitiveType::U1; + case ffi::U2: + return xla::PrimitiveType::U2; + case ffi::U4: + return xla::PrimitiveType::U4; + case ffi::U8: + return xla::PrimitiveType::U8; + case ffi::U16: + return xla::PrimitiveType::U16; + case ffi::U32: + return xla::PrimitiveType::U32; + case ffi::U64: + return xla::PrimitiveType::U64; + case ffi::F16: + return xla::PrimitiveType::F16; + case ffi::F32: + return xla::PrimitiveType::F32; + case ffi::F64: + return xla::PrimitiveType::F64; + case ffi::BF16: + return xla::PrimitiveType::BF16; + case ffi::C64: + return xla::PrimitiveType::C64; + case ffi::C128: + return xla::PrimitiveType::C128; + case ffi::TOKEN: + return xla::PrimitiveType::TOKEN; + case ffi::F8E5M2: + return xla::PrimitiveType::F8E5M2; + case ffi::F8E4M3: + return xla::PrimitiveType::F8E4M3; + case ffi::F8E4M3FN: + return xla::PrimitiveType::F8E4M3FN; + case ffi::F8E4M3B11FNUZ: + return xla::PrimitiveType::F8E4M3B11FNUZ; + case ffi::F8E5M2FNUZ: + return xla::PrimitiveType::F8E5M2FNUZ; + case ffi::F8E4M3FNUZ: + return xla::PrimitiveType::F8E4M3FNUZ; + case ffi::F8E3M4: + return xla::PrimitiveType::F8E3M4; + case ffi::F4E2M1FN: + return xla::PrimitiveType::F4E2M1FN; + case ffi::F8E8M0FNU: + return xla::PrimitiveType::F8E8M0FNU; + } +} +// Registers a 'fn' as a custom call target. +// +// `fn` must be a custom call implementation function pointer (XLA_FFI_Handler* +// when implemented as FFI handler) encapsulated in a PyCapsule object or a +// a dictionary of function pointers (also encapsulated in a PyCapsule). +// +// See XLA_FFI_ExecutionStage documentation for more details about the +// custom execution stages. +absl::Status PyRegisterCustomCallTarget(const std::string& fn_name, + nb::object fn, + const std::string& platform, + int api_version, + XLA_FFI_Handler_Traits traits) { + // Register legacy custom call target (untyped void* API). + if (api_version == 0) { + if (traits != 0) { + return absl::InvalidArgumentError( + "Custom call target registration with traits is not supported for " + "api_version=0"); + } + + nb::capsule capsule; + if (!nb::try_cast(fn, capsule)) { + return absl::InvalidArgumentError( + "Custom call target registration with api_version=0 requires a " + "PyCapsule fn object"); + } + + xla::CustomCallTargetRegistry::Global()->Register( + fn_name, static_cast(capsule.data()), platform); + return absl::OkStatus(); + } + + // Register XLA FFI handler (typed API with explicit function signatures). + if (api_version == 1) { + nb::capsule capsule; + if (nb::try_cast(fn, capsule)) { + return ffi::TakeStatus(ffi::Ffi::RegisterStaticHandler( + xla::ffi::GetXlaFfiApi(), fn_name, platform, + reinterpret_cast( + static_cast(capsule.data())))); + } + + nb::dict bundle; + if (nb::try_cast(fn, bundle)) { + auto handler = [&](const char* name) -> absl::StatusOr { + if (!bundle.contains(name)) return nullptr; + + nb::capsule capsule; + if (!nb::try_cast(bundle[name], capsule)) { + return absl::InvalidArgumentError( + "Custom call target registration with api_version=1 requires a " + "PyCapsule fn object for all dict keys"); + } + + return reinterpret_cast(capsule.data()); + }; + + XLA_FFI_Handler_Bundle bundle; + TF_ASSIGN_OR_RETURN(bundle.instantiate, handler("instantiate")); + TF_ASSIGN_OR_RETURN(bundle.prepare, handler("prepare")); + TF_ASSIGN_OR_RETURN(bundle.initialize, handler("initialize")); + TF_ASSIGN_OR_RETURN(bundle.execute, handler("execute")); + + return ffi::TakeStatus(ffi::Ffi::RegisterStaticHandler( + xla::ffi::GetXlaFfiApi(), fn_name, platform, bundle, traits)); + } + + return absl::InvalidArgumentError( + "Unsupported custom call target type for api_version=1"); + } + + return absl::UnimplementedError(absl::StrFormat( + "API version %d is not supported by RegisterCustomCallTarget. " + "Supported versions are 0 and 1.", + api_version)); +} + +absl::Status PyRegisterCustomType(std::string_view type_name, nb::object type) { + XLA_FFI_TypeId* type_id = nullptr; + XLA_FFI_TypeInfo* type_info = nullptr; + + auto as_capsule = [](nb::object obj) -> absl::StatusOr { + nb::capsule capsule; + if (!nb::try_cast(obj, capsule)) { + return absl::InvalidArgumentError( + "Custom type registration requires handlers as PyCapsules"); + } + return capsule; + }; + + // Extract XLA_FFI_TypeId and optional XLA_FFI_TypeInfo from the type dict. + nb::dict type_dict; + if (!nb::try_cast(type, type_dict) || + !type_dict.contains("type_id")) { + return absl::InvalidArgumentError( + "The type_id argument to register_custom_call_type must be a " + "dictionary holding a pointer to a XLA_FFI_TypeId in `type_id` and " + "optional pointer to a XLA_FFI_TypeInfo in `type_info` fields."); + } + + TF_ASSIGN_OR_RETURN(auto type_id_capsule, as_capsule(type_dict["type_id"])); + type_id = static_cast(type_id_capsule.data()); + + if (type_dict.contains("type_info")) { + TF_ASSIGN_OR_RETURN(auto type_info_capsule, + as_capsule(type_dict["type_info"])); + type_info = static_cast(type_info_capsule.data()); + } + + return ffi::TakeStatus( + ffi::Ffi::RegisterTypeId(xla::ffi::GetXlaFfiApi(), type_name, type_id, + type_info ? *type_info : XLA_FFI_TypeInfo{})); +} +} // namespace + +PyFfiContext::PyFfiContext(const XLA_FFI_Api* api, + XLA_FFI_ExecutionContext* ctx, + XLA_FFI_ExecutionStage stage) + : api_(api), ctx_(ctx), stage_(stage) {} + +PyFfiContext::Stage PyFfiContext::stage() const { + return static_cast(stage_); +} + +absl::StatusOr PyFfiContext::stream() const { + XLA_FFI_Stream_Get_Args args; + args.struct_size = XLA_FFI_Stream_Get_Args_STRUCT_SIZE; + args.extension_start = nullptr; + args.ctx = ctx_; + args.stream = nullptr; + if (XLA_FFI_Error* error = api_->XLA_FFI_Stream_Get(&args)) { + return ffi::TakeStatus(error); + } + return absl::bit_cast(args.stream); +} + +PyFfiAnyBuffer::PyFfiAnyBuffer(DLDeviceType device_type, int32_t device_ordinal, + void* data, ffi::Span dimensions, + ffi::DataType element_type, bool writeable) + : device_type_(device_type), + device_ordinal_(device_ordinal), + data_(data), + dimensions_(dimensions.begin(), dimensions.size()), + element_type_(PrimitiveTypeForFfiDataType(element_type)), + writeable_(writeable) {} + +PyFfiAnyBuffer::PyFfiAnyBuffer(DLDeviceType device_type, int32_t device_ordinal, + ffi::AnyBuffer buf) + : PyFfiAnyBuffer(device_type, device_ordinal, buf.untyped_data(), + buf.dimensions(), buf.element_type(), + /*writeable=*/false) {} + +PyFfiAnyBuffer::PyFfiAnyBuffer(DLDeviceType device_type, int32_t device_ordinal, + ffi::Result buf) + : PyFfiAnyBuffer(device_type, device_ordinal, buf->untyped_data(), + buf->dimensions(), buf->element_type(), + /*writeable=*/true) {} + +absl::StatusOr PyFfiAnyBuffer::dtype() const { + return xla::PrimitiveTypeToNbDtype(element_type_); +} + +size_t PyFfiAnyBuffer::ndim() const { return dimensions_.size(); } + +nb::tuple PyFfiAnyBuffer::shape() const { + return xla::SpanToNbTuple(dimensions_); +} + +bool PyFfiAnyBuffer::writeable() const { return writeable_; } + +absl::StatusOr PyFfiAnyBuffer::NumpyArray() const { + if (device_type_ != kDLCPU) { + return absl::UnimplementedError( + "Buffer.__array__ is only supported on CPU."); + } + + TF_ASSIGN_OR_RETURN(auto dtype, this->dtype()); + xla::nb_numpy_ndarray array(dtype, dimensions_, /* strides= */ std::nullopt, + data_, nb::cast(this)); + + // TODO(danfm): We don't seem to be allowed to set this flag like this + // because the array doesn't own its data. + // array.attr("flags").attr("writeable") = nb::bool_(writeable_); + + return array; +} + +absl::StatusOr PyFfiAnyBuffer::CudaArrayInterface() const { + if (device_type_ != kDLCUDA && device_type_ != kDLROCM) { + return absl::UnimplementedError( + "Buffer.__cuda_array_interface__ is only supported on CUDA and ROCm."); + } + + nb::dict result; + result["shape"] = xla::SpanToNbTuple(dimensions_); + TF_ASSIGN_OR_RETURN(result["typestr"], + TypeDescriptorForPrimitiveType(element_type_)); + result["data"] = nb::make_tuple( + nb::int_(absl::bit_cast(data_)), !writeable_); + result["version"] = nb::int_(2); + return result; +} + +absl::StatusOr PyFfiAnyBuffer::DLPack() const { + auto pack = std::make_unique>(); + pack->tensor.manager_ctx = pack.get(); + pack->tensor.deleter = DLPackTensorDeleter; + + DLTensor& dt = pack->tensor.dl_tensor; + dt.data = data_; + dt.device = DLDevice{device_type_, device_ordinal_}; + dt.ndim = dimensions_.size(); + TF_ASSIGN_OR_RETURN(dt.dtype, xla::PrimitiveTypeToDLDataType(element_type_)); + pack->shape = std::vector(dimensions_.begin(), dimensions_.end()); + dt.shape = reinterpret_cast(pack->shape.data()); + dt.strides = nullptr; + dt.byte_offset = 0; + + // We cannot use nanobind's capsule object constructor because we need to + // detect if the capsule name has been changed in the deleter, but nanobind + // hides the underlying Python object from the deleter. + nb::capsule capsule = nb::steal( + PyCapsule_New(&pack.release()->tensor, kDlTensorCapsuleName, + [](PyObject* obj) noexcept { + DLManagedTensor* dlmt = static_cast( + PyCapsule_GetPointer(obj, kDlTensorCapsuleName)); + if (dlmt) { + DLPackTensorDeleter(dlmt); + } else { + // The tensor has been deleted. Clear any error from + // PyCapsule_GetPointer. + PyErr_Clear(); + } + })); + if (!capsule.ptr()) { + throw nb::python_error(); + } + + return capsule; +} + +absl::StatusOr PyFfiAnyBuffer::DLPackVersioned() const { + auto pack = std::make_unique>(); + pack->tensor.version = + DLPackVersion{DLPACK_MAJOR_VERSION, DLPACK_MINOR_VERSION}; + pack->tensor.manager_ctx = pack.get(); + pack->tensor.deleter = DLPackTensorDeleter; + pack->tensor.flags = writeable_ ? 0 : DLPACK_FLAG_BITMASK_READ_ONLY; + + DLTensor& dt = pack->tensor.dl_tensor; + dt.data = data_; + dt.device = DLDevice{device_type_, device_ordinal_}; + dt.ndim = dimensions_.size(); + TF_ASSIGN_OR_RETURN(dt.dtype, xla::PrimitiveTypeToDLDataType(element_type_)); + pack->shape = std::vector(dimensions_.begin(), dimensions_.end()); + dt.shape = reinterpret_cast(pack->shape.data()); + dt.strides = nullptr; + dt.byte_offset = 0; + + // We cannot use nanobind's capsule object constructor because we need to + // detect if the capsule name has been changed in the deleter, but nanobind + // hides the underlying Python object from the deleter. + nb::capsule capsule = nb::steal(PyCapsule_New( + &pack.release()->tensor, kDlTensorVersionedCapsuleName, + [](PyObject* obj) noexcept { + DLManagedTensorVersioned* dlmt = static_cast( + PyCapsule_GetPointer(obj, kDlTensorVersionedCapsuleName)); + if (dlmt) { + DLPackTensorDeleter(dlmt); + } else { + // The tensor has been deleted. Clear any error from + // PyCapsule_GetPointer. + PyErr_Clear(); + } + })); + if (!capsule.ptr()) { + throw nb::python_error(); + } + + return capsule; +} + +nb::tuple PyFfiAnyBuffer::DLPackDevice() const { + return nb::make_tuple(static_cast(device_type_), device_ordinal_); +} + +void RegisterFfiApis(nb::module_& m) { + nb::module_ ffi_module = + m.def_submodule("ffi", "Python bindings for the XLA FFI."); + + nb::class_ buffer(ffi_module, "Buffer"); + buffer.def_prop_ro("dtype", xla::ValueOrThrowWrapper(&PyFfiAnyBuffer::dtype)); + buffer.def_prop_ro("ndim", &PyFfiAnyBuffer::ndim); + buffer.def_prop_ro("shape", &PyFfiAnyBuffer::shape); + buffer.def_prop_ro("writeable", &PyFfiAnyBuffer::writeable); + buffer.def( + "__array__", + [](PyFfiAnyBuffer self, nb::object dtype, nb::object copy) { + if (!dtype.is_none()) { + throw nb::value_error( + "dtype parameter is not supported by Buffer.__array__."); + } + if (!copy.is_none() && nb::cast(copy)) { + throw nb::value_error( + "Buffer.__array__ with copy=True is not supported."); + } + return xla::ValueOrThrow(self.NumpyArray()); + }, + nb::arg("dtype") = nb::none(), nb::arg("copy") = nb::none()); + buffer.def_prop_ro( + "__cuda_array_interface__", + xla::ValueOrThrowWrapper(&PyFfiAnyBuffer::CudaArrayInterface)); + buffer.def( + "__dlpack__", + [](PyFfiAnyBuffer self, nb::object stream, nb::object max_version, + nb::object dl_device, nb::object copy) { + if (!copy.is_none() && nb::cast(copy)) { + throw nb::value_error( + "Buffer.__dlpack__ with copy=True is not supported."); + } + + // Fall back on the non-versioned API if unsupported by the requested + // max_version. + nb::tuple max_version_tuple; + int64_t max_version_major; + if (!nb::try_cast(max_version, max_version_tuple) || + max_version_tuple.size() < 2 || + !nb::try_cast(max_version_tuple[0], max_version_major) || + max_version_major < 1) { + return xla::ValueOrThrow(self.DLPack()); + } + + // TODO(danfm): Handle other optional inputs. + return xla::ValueOrThrow(self.DLPackVersioned()); + }, + nb::arg("stream") = nb::none(), nb::arg("max_version") = nb::none(), + nb::arg("dl_device") = nb::none(), nb::arg("copy") = nb::none()); + buffer.def("__dlpack_device__", &PyFfiAnyBuffer::DLPackDevice); + + nb::enum_(ffi_module, "ExecutionStage") + .value("INSTANTIATE", PyFfiContext::Stage::kInstantiate) + .value("PREPARE", PyFfiContext::Stage::kPrepare) + .value("INITIALIZE", PyFfiContext::Stage::kInitialize) + .value("EXECUTE", PyFfiContext::Stage::kExecute); + + nb::class_ context(ffi_module, "ExecutionContext"); + context.def_prop_ro("stage", &PyFfiContext::stage); + context.def_prop_ro("stream", + xla::ValueOrThrowWrapper(&PyFfiContext::stream)); + + // Custom-call targets. + m.def( + "register_custom_call_target", + [](nb::object fn_name_py, nb::object fn, const std::string& platform, + int api_version, XLA_FFI_Handler_Traits traits) { + std::string fn_name; + if (!nb::try_cast(fn_name_py, fn_name)) { + nb::bytes bytes = nb::cast(fn_name_py); + fn_name = std::string(bytes.c_str(), bytes.size()); + } + xla::ThrowIfError(PyRegisterCustomCallTarget( + fn_name, std::move(fn), platform, api_version, traits)); + }, + nb::arg("fn_name"), nb::arg("fn"), nb::arg("platform"), + nb::arg("api_version") = 0, nb::arg("traits") = 0); + + m.def( + "custom_call_targets", + [](const std::string& platform) -> nb::dict { + nb::dict targets; + for (const auto& [name, target] : + xla::CustomCallTargetRegistry::Global()->registered_symbols( + platform)) { + targets[nb::str(name.data(), name.size())] = nb::capsule(target); + } + + auto ffi_handlers = ffi::StaticRegisteredHandlers(platform); + if (!ffi_handlers.ok()) return targets; + + for (const auto& [name, registration] : *ffi_handlers) { + nb::dict bundle; + auto export_handler = [&](std::string_view name, XLA_FFI_Handler* h) { + if (h != nullptr) { + bundle[nb::str(name.data(), name.size())] = + nb::capsule(reinterpret_cast(h)); + } + }; + export_handler("prepare", registration.bundle.prepare); + export_handler("initialize", registration.bundle.initialize); + export_handler("execute", registration.bundle.execute); + targets[nb::str(name.data(), name.size())] = std::move(bundle); + } + return targets; + }, + nb::arg("platform")); + + m.def( + "register_custom_type", + [](std::string_view type_name, nb::object type) { + xla::ThrowIfError(PyRegisterCustomType(type_name, type)); + }, + nb::arg("type_name"), nb::arg("type_id")); +} + +} // namespace jax diff --git a/jaxlib/ffi.h b/jaxlib/ffi.h new file mode 100644 index 000000000000..eb3db50a985d --- /dev/null +++ b/jaxlib/ffi.h @@ -0,0 +1,150 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_FFI_H_ +#define JAXLIB_XLA_FFI_H_ + +#include + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/types/span.h" +#include "include/dlpack/dlpack.h" +#include "nanobind/nanobind.h" +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/api/ffi.h" +#include "xla/pjrt/host_callback.h" +#include "xla/python/nb_numpy.h" +#include "xla/xla_data.pb.h" + +namespace jax { + +namespace ffi = xla::ffi; +namespace nb = nanobind; + +// Wrapper class for XLA FFI execution context. +// +// This class provides a Python interface to the XLA FFI execution context, +// exposing metadata such as the execution stage, device ordinal, and stream. +class PyFfiContext { + public: + enum class Stage { + kInstantiate, + kPrepare, + kInitialize, + kExecute, + }; + + PyFfiContext(const XLA_FFI_Api* api, XLA_FFI_ExecutionContext* ctx, + XLA_FFI_ExecutionStage stage); + Stage stage() const; + absl::StatusOr stream() const; + + private: + const XLA_FFI_Api* api_; + XLA_FFI_ExecutionContext* ctx_; + XLA_FFI_ExecutionStage stage_; +}; + +// Wrapper class for XLA FFI AnyBuffer. +// +// This class provides a Python interface to the XLA FFI `AnyBuffer` class. +// From Python, this object looks like an array (with `.dtype` and `.shape` +// attributes), but it also provides methods zero-copy conversions to standard +// transport formats: `__array__`, `__cuda_array_interface__`, and `__dlpack__`. +class PyFfiAnyBuffer { + public: + PyFfiAnyBuffer(DLDeviceType device_type, int32_t device_ordinal, void* data, + ffi::Span dimensions, + ffi::DataType element_type, bool writeable); + PyFfiAnyBuffer(DLDeviceType device_type, int32_t device_ordinal, + ffi::AnyBuffer buf); + PyFfiAnyBuffer(DLDeviceType device_type, int32_t device_ordinal, + ffi::Result buf); + + absl::StatusOr dtype() const; + size_t ndim() const; + nb::tuple shape() const; + bool writeable() const; + + absl::StatusOr NumpyArray() const; + absl::StatusOr CudaArrayInterface() const; + absl::StatusOr DLPack() const; + absl::StatusOr DLPackVersioned() const; + nb::tuple DLPackDevice() const; + + private: + DLDeviceType device_type_; + int32_t device_ordinal_; + void* data_; + absl::Span dimensions_; + xla::PrimitiveType element_type_; + bool writeable_; +}; + +template +ffi::Error XlaBufferCallback(ffi::Context ctx, int32_t device_ordinal, + xla::FfiLoadedHostCallbacks* callbacks, + uint64_t index, ffi::RemainingArgs args, + ffi::RemainingRets rets) { + nb::gil_scoped_acquire gil; + auto callback = nb::borrow( + static_cast(callbacks->callbacks[index])); + auto nb_args = + nb::steal(PyTuple_New(1 + args.size() + rets.size())); + + PyFfiContext py_ctx(ctx.api(), ctx.ctx(), XLA_FFI_ExecutionStage_EXECUTE); + PyTuple_SET_ITEM(nb_args.ptr(), 0, nb::cast(py_ctx).release().ptr()); + + size_t offset = 1; + for (size_t i = 0; i < args.size(); ++i, ++offset) { + auto arg = args.get(i); + if (arg.has_error()) { + return arg.error(); + } + PyFfiAnyBuffer py_buffer(DeviceType, device_ordinal, arg.value()); + PyTuple_SET_ITEM(nb_args.ptr(), offset, + nb::cast(py_buffer).release().ptr()); + } + + for (size_t i = 0; i < rets.size(); ++i, ++offset) { + auto ret = rets.get(i); + if (ret.has_error()) { + return ret.error(); + } + PyFfiAnyBuffer py_buffer(DeviceType, device_ordinal, ret.value()); + PyTuple_SET_ITEM(nb_args.ptr(), offset, + nb::cast(py_buffer).release().ptr()); + } + + xla::HostCallbackScope cleanup; + try { + callback(*nb::borrow(nb_args)); + } catch (nb::python_error& e) { + return ffi::Error::Internal( + absl::StrFormat("Error when calling buffer callback: %s", e.what())); + } + + return ffi::Error::Success(); +} + +void RegisterFfiApis(nanobind::module_& m); + +} // namespace jax + +#endif // JAXLIB_XLA_FFI_H_ diff --git a/jaxlib/ffi_helpers.h b/jaxlib/ffi_helpers.h index 5c6d80093df5..7c4dfce81311 100644 --- a/jaxlib/ffi_helpers.h +++ b/jaxlib/ffi_helpers.h @@ -1,3 +1,18 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #ifndef JAXLIB_FFI_HELPERS_H_ #define JAXLIB_FFI_HELPERS_H_ @@ -74,7 +89,7 @@ namespace jax { template inline absl::StatusOr MaybeCastNoOverflow( - std::int64_t value, const std::string& source = __FILE__) { + std::int64_t value, std::string_view source = __FILE__) { if constexpr (sizeof(T) == sizeof(std::int64_t)) { return value; } else { diff --git a/jaxlib/gpu/BUILD b/jaxlib/gpu/BUILD index b5292746dd10..6169c882aa05 100644 --- a/jaxlib/gpu/BUILD +++ b/jaxlib/gpu/BUILD @@ -14,13 +14,15 @@ # Shared CUDA/ROCM GPU kernels. +load("@rules_cc//cc:cc_library.bzl", "cc_library") +# Placeholder: load proto_library + load( "//jaxlib:jax.bzl", "cc_proto_library", "jax_visibility", "xla_py_proto_library", ) -# Placeholder: load proto_library licenses(["notice"]) @@ -30,11 +32,8 @@ package( ) exports_files(srcs = [ - "blas.cc", "blas_handle_pool.cc", "blas_handle_pool.h", - "blas_kernels.cc", - "blas_kernels.h", "ffi_wrapper.h", "gpu_kernel_helpers.cc", "gpu_kernel_helpers.h", @@ -52,6 +51,8 @@ exports_files(srcs = [ "prng_kernels.cc", "prng_kernels.cu.cc", "prng_kernels.h", + "py_client_gpu.cc", + "py_client_gpu.h", "rnn.cc", "rnn_kernels.cc", "rnn_kernels.h", @@ -60,8 +61,6 @@ exports_files(srcs = [ "solver_handle_pool.h", "solver_interface.cc", "solver_interface.h", - "solver_kernels.cc", - "solver_kernels.h", "solver_kernels_ffi.cc", "solver_kernels_ffi.h", "sparse.cc", @@ -82,6 +81,11 @@ proto_library( cc_proto_library( name = "triton_cc_proto", + compatible_with = None, + visibility = [ + "//jax:internal", + "//third_party/py/enzyme_ad:__subpackages__", + ], deps = [":triton_proto"], ) @@ -91,6 +95,21 @@ xla_py_proto_library( deps = [":triton_proto"], ) +cc_library( + name = "handle_pool", + hdrs = ["handle_pool.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/synchronization", + ], +) + cc_library( name = "gpu_plugin_extension", srcs = ["gpu_plugin_extension.cc"], @@ -105,7 +124,6 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", "@nanobind", "@xla//xla:util", "@xla//xla/ffi/api:c_api", @@ -115,7 +133,7 @@ cc_library( "@xla//xla/pjrt/c:pjrt_c_api_hdrs", "@xla//xla/pjrt/c:pjrt_c_api_helpers", "@xla//xla/pjrt/c:pjrt_c_api_triton_extension_hdrs", - "@xla//xla/python:py_client_gpu", + "@xla//xla/tsl/platform:statusor", "@xla//xla/tsl/python/lib/core:numpy", ], ) diff --git a/jaxlib/gpu/blas.cc b/jaxlib/gpu/blas.cc deleted file mode 100644 index e8761bd32ac9..000000000000 --- a/jaxlib/gpu/blas.cc +++ /dev/null @@ -1,85 +0,0 @@ -/* Copyright 2019 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include - -#include "nanobind/nanobind.h" -#include "nanobind/stl/pair.h" -#include "absl/container/flat_hash_map.h" -#include "absl/strings/str_format.h" -#include "jaxlib/gpu/blas_kernels.h" -#include "jaxlib/gpu/vendor.h" -#include "jaxlib/kernel_nanobind_helpers.h" -#include "xla/tsl/python/lib/core/numpy.h" - -namespace jax { -namespace JAX_GPU_NAMESPACE { -namespace { - -namespace nb = nanobind; - -// Converts a NumPy dtype to a Type. -BlasType DtypeToBlasType(const dtype& np_type) { - static auto* types = new absl::flat_hash_map, BlasType>({ - {{'f', 4}, BlasType::F32}, - {{'f', 8}, BlasType::F64}, - {{'c', 8}, BlasType::C64}, - {{'c', 16}, BlasType::C128}, - }); - auto it = types->find({np_type.kind(), np_type.itemsize()}); - if (it == types->end()) { - nb::str repr = nb::repr(np_type); - throw std::invalid_argument( - absl::StrFormat("Unsupported dtype %s", repr.c_str())); - } - return it->second; -} - -// Returns the descriptor for a GetrfBatched operation. -std::pair BuildGetrfBatchedDescriptor(const dtype& dtype, - int b, int n) { - BlasType type = DtypeToBlasType(dtype); - size_t size = b * sizeof(void*); - return {size, PackDescriptor(GetrfBatchedDescriptor{type, b, n})}; -} - -// Returns the descriptor for a GetrfBatched operation. -std::pair BuildGeqrfBatchedDescriptor(const dtype& dtype, - int b, int m, int n) { - BlasType type = DtypeToBlasType(dtype); - size_t size = b * sizeof(void*); - return {size, PackDescriptor(GeqrfBatchedDescriptor{type, b, m, n})}; -} - -nb::dict Registrations() { - nb::dict dict; - dict[JAX_GPU_PREFIX "blas_getrf_batched"] = EncapsulateFunction(GetrfBatched); - dict[JAX_GPU_PREFIX "blas_geqrf_batched"] = EncapsulateFunction(GeqrfBatched); - return dict; -} - -NB_MODULE(_blas, m) { - tsl::ImportNumpy(); - - m.def("registrations", &Registrations); - m.def("build_getrf_batched_descriptor", &BuildGetrfBatchedDescriptor); - m.def("build_geqrf_batched_descriptor", &BuildGeqrfBatchedDescriptor); -} - -} // namespace -} // namespace JAX_GPU_NAMESPACE -} // namespace jax diff --git a/jaxlib/gpu/blas_handle_pool.cc b/jaxlib/gpu/blas_handle_pool.cc index 2ce204453039..ab449ff12f6d 100644 --- a/jaxlib/gpu/blas_handle_pool.cc +++ b/jaxlib/gpu/blas_handle_pool.cc @@ -19,7 +19,7 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/vendor.h" -#include "jaxlib/handle_pool.h" +#include "jaxlib/gpu/handle_pool.h" namespace jax { @@ -27,7 +27,7 @@ template <> /*static*/ absl::StatusOr BlasHandlePool::Borrow( gpuStream_t stream) { BlasHandlePool* pool = Instance(); - absl::MutexLock lock(&pool->mu_); + absl::MutexLock lock(pool->mu_); gpublasHandle_t handle; if (pool->handles_[stream].empty()) { JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasCreate(&handle))); diff --git a/jaxlib/gpu/blas_handle_pool.h b/jaxlib/gpu/blas_handle_pool.h index b3cdbaa88867..43724baab45e 100644 --- a/jaxlib/gpu/blas_handle_pool.h +++ b/jaxlib/gpu/blas_handle_pool.h @@ -18,7 +18,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "jaxlib/gpu/vendor.h" -#include "jaxlib/handle_pool.h" +#include "jaxlib/gpu/handle_pool.h" namespace jax { diff --git a/jaxlib/gpu/blas_kernels.cc b/jaxlib/gpu/blas_kernels.cc deleted file mode 100644 index ac30aa9cc520..000000000000 --- a/jaxlib/gpu/blas_kernels.cc +++ /dev/null @@ -1,198 +0,0 @@ -/* Copyright 2019 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "jaxlib/gpu/blas_kernels.h" - -#include -#include -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_format.h" -#include "jaxlib/gpu/blas_handle_pool.h" -#include "jaxlib/gpu/gpu_kernel_helpers.h" -#include "jaxlib/gpu/make_batch_pointers.h" -#include "jaxlib/gpu/vendor.h" -#include "jaxlib/kernel_helpers.h" -#include "xla/service/custom_call_status.h" - -namespace jax { - -namespace JAX_GPU_NAMESPACE { - -namespace { - -int SizeOfBlasType(BlasType type) { - switch (type) { - case BlasType::F32: - return sizeof(float); - case BlasType::F64: - return sizeof(double); - case BlasType::C64: - return sizeof(gpublasComplex); - case BlasType::C128: - return sizeof(gpublasDoubleComplex); - } -} - -} // namespace - -// Batched LU decomposition: getrfbatched - -static absl::Status GetrfBatched_(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const GetrfBatchedDescriptor& d = **s; - auto h = BlasHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - if (buffers[0] != buffers[1]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( - buffers[1], buffers[0], SizeOfBlasType(d.type) * d.batch * d.n * d.n, - gpuMemcpyDeviceToDevice, stream))); - } - - int* ipiv = static_cast(buffers[2]); - int* info = static_cast(buffers[3]); - MakeBatchPointersAsync(stream, buffers[1], buffers[4], d.batch, - SizeOfBlasType(d.type) * d.n * d.n); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError())); - switch (d.type) { - case BlasType::F32: { - float** batch_ptrs = static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasSgetrfBatched( - handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch))); - break; - } - case BlasType::F64: { - double** batch_ptrs = static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasDgetrfBatched( - handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch))); - break; - } - case BlasType::C64: { - gpublasComplex** batch_ptrs = static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasCgetrfBatched( - handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch))); - break; - } - case BlasType::C128: { - gpublasDoubleComplex** batch_ptrs = - static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasZgetrfBatched( - handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch))); - break; - } - } - return absl::OkStatus(); -} - -void GetrfBatched(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = GetrfBatched_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - -// Batched QR decomposition: geqrfbatched - -static absl::Status GeqrfBatched_(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const GeqrfBatchedDescriptor& d = **s; - auto h = BlasHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - if (buffers[0] != buffers[1]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( - buffers[1], buffers[0], SizeOfBlasType(d.type) * d.batch * d.m * d.n, - gpuMemcpyDeviceToDevice, stream))); - } - - std::vector info(d.batch); - MakeBatchPointersAsync(stream, buffers[1], buffers[3], d.batch, - SizeOfBlasType(d.type) * d.m * d.n); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError())); - MakeBatchPointersAsync(stream, buffers[2], buffers[4], d.batch, - SizeOfBlasType(d.type) * std::min(d.m, d.n)); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError())); - switch (d.type) { - case BlasType::F32: { - float** a_batch_ptrs = static_cast(buffers[3]); - float** tau_batch_ptrs = static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpublasSgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m, - tau_batch_ptrs, info.data(), d.batch))); - break; - } - case BlasType::F64: { - double** a_batch_ptrs = static_cast(buffers[3]); - double** tau_batch_ptrs = static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpublasDgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m, - tau_batch_ptrs, info.data(), d.batch))); - break; - } - case BlasType::C64: { - gpublasComplex** a_batch_ptrs = static_cast(buffers[3]); - gpublasComplex** tau_batch_ptrs = - static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpublasCgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m, - tau_batch_ptrs, info.data(), d.batch))); - break; - } - case BlasType::C128: { - gpublasDoubleComplex** a_batch_ptrs = - static_cast(buffers[3]); - gpublasDoubleComplex** tau_batch_ptrs = - static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpublasZgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m, - tau_batch_ptrs, info.data(), d.batch))); - break; - } - } - auto it = - std::find_if(info.begin(), info.end(), [](int i) { return i != 0; }); - - if (it != info.end()) { - return absl::InvalidArgumentError( - absl::StrFormat("QR decomposition failed with status %d for batch " - "element %d", - *it, std::distance(info.begin(), it))); - } - - return absl::OkStatus(); -} - -void GeqrfBatched(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = GeqrfBatched_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - -} // namespace JAX_GPU_NAMESPACE -} // namespace jax diff --git a/jaxlib/gpu/blas_kernels.h b/jaxlib/gpu/blas_kernels.h deleted file mode 100644 index 724565ea73d1..000000000000 --- a/jaxlib/gpu/blas_kernels.h +++ /dev/null @@ -1,58 +0,0 @@ -/* Copyright 2019 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef JAXLIB_GPU_BLAS_KERNELS_H_ -#define JAXLIB_GPU_BLAS_KERNELS_H_ - -#include - -#include "jaxlib/gpu/vendor.h" -#include "xla/service/custom_call_status.h" - -namespace jax { -namespace JAX_GPU_NAMESPACE { - -// Set of types known to Cusolver. -enum class BlasType { - F32, - F64, - C64, - C128, -}; - -// Batched LU decomposition: getrfbatched - -struct GetrfBatchedDescriptor { - BlasType type; - int batch, n; -}; - -void GetrfBatched(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -// Batched QR decomposition: geqrfbatched - -struct GeqrfBatchedDescriptor { - BlasType type; - int batch, m, n; -}; - -void GeqrfBatched(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -} // namespace JAX_GPU_NAMESPACE -} // namespace jax - -#endif // JAXLIB_GPU_BLAS_KERNELS_H_ diff --git a/jaxlib/gpu/gpu_kernel_helpers.cc b/jaxlib/gpu/gpu_kernel_helpers.cc index 5a434f4b6ad5..c88fd753bac4 100644 --- a/jaxlib/gpu/gpu_kernel_helpers.cc +++ b/jaxlib/gpu/gpu_kernel_helpers.cc @@ -15,12 +15,15 @@ limitations under the License. #include "jaxlib/gpu/gpu_kernel_helpers.h" +#include +#include + #include "absl/base/optimization.h" #include "absl/log/check.h" -#include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "jaxlib/gpu/vendor.h" namespace jax { namespace JAX_GPU_NAMESPACE { @@ -141,18 +144,12 @@ std::string ErrorString(cufftResult status) { return "cuFFT invalid size"; case CUFFT_UNALIGNED_DATA: return "cuFFT unaligned data"; - case CUFFT_INCOMPLETE_PARAMETER_LIST: - return "cuFFT incomplete parameter list"; case CUFFT_INVALID_DEVICE: return "cuFFT invalid device"; - case CUFFT_PARSE_ERROR: - return "cuFFT parse error"; case CUFFT_NO_WORKSPACE: return "cuFFT no workspace"; case CUFFT_NOT_IMPLEMENTED: return "cuFFT not implemented"; - case CUFFT_LICENSE_ERROR: - return "cuFFT license error"; case CUFFT_NOT_SUPPORTED: return "cuFFT not supported"; default: diff --git a/jaxlib/gpu/gpu_kernel_helpers.h b/jaxlib/gpu/gpu_kernel_helpers.h index aecb8a4fdcf1..0326d7f44620 100644 --- a/jaxlib/gpu/gpu_kernel_helpers.h +++ b/jaxlib/gpu/gpu_kernel_helpers.h @@ -16,11 +16,10 @@ limitations under the License. #ifndef JAXLIB_GPU_GPU_KERNEL_HELPERS_H_ #define JAXLIB_GPU_GPU_KERNEL_HELPERS_H_ -#include +#include #include "absl/base/optimization.h" #include "absl/status/status.h" -#include "absl/status/statusor.h" #include "jaxlib/gpu/vendor.h" #define JAX_AS_STATUS(expr) \ diff --git a/jaxlib/gpu/gpu_kernels.cc b/jaxlib/gpu/gpu_kernels.cc index 242078357254..5004747d17f2 100644 --- a/jaxlib/gpu/gpu_kernels.cc +++ b/jaxlib/gpu/gpu_kernels.cc @@ -16,11 +16,9 @@ limitations under the License. // This file is not used by JAX itself, but exists to assist with running // JAX-generated HLO code from outside of JAX. -#include "jaxlib/gpu/blas_kernels.h" #include "jaxlib/gpu/linalg_kernels.h" #include "jaxlib/gpu/prng_kernels.h" #include "jaxlib/gpu/rnn_kernels.h" -#include "jaxlib/gpu/solver_kernels.h" #include "jaxlib/gpu/solver_kernels_ffi.h" #include "jaxlib/gpu/sparse_kernels.h" #include "jaxlib/gpu/triton_kernels.h" @@ -33,39 +31,31 @@ namespace jax { namespace JAX_GPU_NAMESPACE { namespace { -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cublas_getrf_batched", GetrfBatched, - "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cublas_geqrf_batched", GeqrfBatched, - "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cudnn_rnn", RNNForward, "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cudnn_rnn_bwd", RNNBackward, "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_getrf", Getrf, "CUDA"); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cudnn_rnn", "CUDA", RNNForwardFfi); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cudnn_rnn_bwd", "CUDA", + RNNBackwardFfi); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_getrf_ffi", "CUDA", GetrfFfi); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_syrk_ffi", "CUDA", SyrkFfi); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_geqrf", Geqrf, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_geqrf_ffi", "CUDA", GeqrfFfi); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_csrlsvqr", Csrlsvqr, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_csrlsvqr_ffi", "CUDA", CsrlsvqrFfi); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_orgqr", Orgqr, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_orgqr_ffi", "CUDA", OrgqrFfi); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_syevd", Syevd, "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_syevj", Syevj, "CUDA"); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_potrf_ffi", "CUDA", + PotrfFfi); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_syevd_ffi", "CUDA", SyevdFfi); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_sytrd", Sytrd, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_sytrd_ffi", "CUDA", SytrdFfi); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_gesvd", Gesvd, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_gesvd_ffi", "CUDA", GesvdFfi); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_gesvdj", Gesvdj, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_gesvdj_ffi", "CUDA", GesvdjFfi); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_gesvdp_ffi", "CUDA", + GesvdpFfi); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cu_cholesky_update_ffi", "CUDA", CholeskyUpdateFfi); @@ -74,28 +64,24 @@ XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cu_lu_pivots_to_permutation", XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cu_threefry2x32_ffi", "CUDA", ThreeFry2x32Ffi); -#if JAX_CUSPARSE_11300 -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_csr_todense", CsrToDense, - "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_csr_fromdense", CsrFromDense, - "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_csr_matvec", CsrMatvec, - "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_csr_matmat", CsrMatmat, - "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_coo_todense", CooToDense, - "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_coo_fromdense", CooFromDense, - "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_coo_matvec", CooMatvec, - "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_coo_matmat", CooMatmat, - "CUDA"); -#endif -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_gtsv2_f32", gtsv2_f32, - "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_gtsv2_f64", gtsv2_f64, - "CUDA"); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusparse_csr_todense_ffi", "CUDA", + CsrToDenseFfi); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusparse_csr_fromdense_ffi", "CUDA", + CsrFromDenseFfi); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusparse_csr_matvec_ffi", "CUDA", + CsrMatvecFfi); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusparse_csr_matmat_ffi", "CUDA", + CsrMatmatFfi); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusparse_coo_todense_ffi", "CUDA", + CooToDenseFfi); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusparse_coo_fromdense_ffi", "CUDA", + CooFromDenseFfi); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusparse_coo_matvec_ffi", "CUDA", + CooMatvecFfi); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusparse_coo_matmat_ffi", "CUDA", + CooMatmatFfi); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusparse_gtsv2_ffi", "CUDA", + kGtsv2); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("triton_kernel_call", TritonKernelCall, "CUDA"); diff --git a/jaxlib/gpu/gpu_plugin_extension.cc b/jaxlib/gpu/gpu_plugin_extension.cc index b56cb8337f1b..d3f411cc87b6 100644 --- a/jaxlib/gpu/gpu_plugin_extension.cc +++ b/jaxlib/gpu/gpu_plugin_extension.cc @@ -18,15 +18,15 @@ limitations under the License. #include #include #include +#include #include -#include "nanobind/nanobind.h" -#include "nanobind/stl/string.h" // IWYU pragma: keep -#include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "jaxlib/kernel_nanobind_helpers.h" #include "xla/ffi/api/c_api.h" #include "xla/pjrt/c/pjrt_c_api.h" @@ -35,32 +35,29 @@ limitations under the License. #include "xla/pjrt/c/pjrt_c_api_helpers.h" #include "xla/pjrt/c/pjrt_c_api_triton_extension.h" #include "xla/pjrt/status_casters.h" -#include "xla/python/py_client_gpu.h" +#include "xla/tsl/platform/statusor.h" #include "xla/tsl/python/lib/core/numpy.h" #include "xla/util.h" namespace nb = nanobind; -namespace xla { +namespace jax { namespace { struct TritonCompilationResult { std::string asm_text; int64_t smem_bytes; - int cluster_dim_x; - int cluster_dim_y; - int cluster_dim_z; }; absl::StatusOr CompileTritonToASM( - const PJRT_Api* c_api, absl::string_view module, - absl::string_view arch_name, int num_warps, int num_ctas, int num_stages) { + const PJRT_Api* c_api, std::string_view module, std::string_view arch_name, + int num_warps, int num_ctas, int num_stages) { const PJRT_Triton_Extension* triton_ext = pjrt::FindExtension( c_api, PJRT_Extension_Type::PJRT_Extension_Type_Triton); if (triton_ext == nullptr) { - return Unimplemented("The plugin does not have a Triton extension."); + return xla::Unimplemented("The plugin does not have a Triton extension."); } PJRT_Triton_Compile_Args args; args.struct_size = PJRT_Triton_Compile_Args_STRUCT_SIZE; @@ -77,9 +74,6 @@ absl::StatusOr CompileTritonToASM( return TritonCompilationResult{ .asm_text = asm_text, .smem_bytes = args.out_smem_bytes, - .cluster_dim_x = args.out_cluster_dim_x, - .cluster_dim_y = args.out_cluster_dim_y, - .cluster_dim_z = args.out_cluster_dim_z, }; } @@ -92,13 +86,15 @@ absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api, pjrt::FindExtension( c_api, PJRT_Extension_Type::PJRT_Extension_Type_Gpu_Custom_Call); if (custom_call_ext == nullptr) { - return Unimplemented("The plugin does not have a custom call extension."); + return xla::Unimplemented( + "The plugin does not have a custom call extension."); } PJRT_Gpu_Register_Custom_Call* register_custom_call = custom_call_ext->custom_call; if (traits != 0) { - return Unimplemented("The plugin does not support custom call traits."); + return xla::Unimplemented( + "The plugin does not support custom call traits."); } PJRT_Gpu_Register_Custom_Call_Args args; @@ -174,39 +170,61 @@ absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api, #endif } -absl::Status RegisterCustomTypeId(const PJRT_Api* c_api, - const char* type_name_c_str, - size_t type_name_size, nb::object type_id) { +absl::Status RegisterCustomType(const PJRT_Api* c_api, + const char* type_name_c_str, + size_t type_name_size, nb::object type) { const PJRT_FFI_Extension* ffi_ext = pjrt::FindExtension( c_api, PJRT_Extension_Type::PJRT_Extension_Type_FFI); if (ffi_ext == nullptr) { - return Unimplemented("The plugin does not have the FFI extension."); + return xla::Unimplemented("The plugin does not have the FFI extension."); } - PJRT_FFI_TypeID_Register_Args args; - args.struct_size = PJRT_FFI_TypeID_Register_Args_STRUCT_SIZE; - args.type_name = type_name_c_str; - args.type_name_size = type_name_size; - RETURN_STATUS_IF_PJRT_ERROR(ffi_ext->type_id_register(&args), c_api); + XLA_FFI_TypeId* type_id = nullptr; + XLA_FFI_TypeInfo* type_info = nullptr; + + auto as_capsule = [](nb::object obj) -> absl::StatusOr { + nb::capsule capsule; + if (!nb::try_cast(obj, capsule)) { + return absl::InvalidArgumentError( + "Custom type registration requires handlers as PyCapsules"); + } + return capsule; + }; - nb::capsule capsule; - if (!nb::try_cast(type_id, capsule)) { + // Extract XLA_FFI_TypeId and optional XLA_FFI_TypeInfo from the type dict. + nb::dict type_dict; + if (!nb::try_cast(type, type_dict) || + !type_dict.contains("type_id")) { return absl::InvalidArgumentError( - "The type_id argument to register_custom_call_type_id must be a " - "PyCapsule object holding a pointer to a XLA_FFI_TypeId."); + "The type_id argument to register_custom_call_type must be a " + "dictionary holding a pointer to a XLA_FFI_TypeId in `type_id` and " + "optional pointer to a XLA_FFI_TypeInfo in `type_info` fields."); } - XLA_FFI_TypeId* type_id_ptr = - reinterpret_cast(static_cast(capsule.data())); - type_id_ptr->type_id = args.type_id; - return absl::OkStatus(); -} + TF_ASSIGN_OR_RETURN(auto type_id_capsule, as_capsule(type_dict["type_id"])); + type_id = static_cast(type_id_capsule.data()); + + if (type_dict.contains("type_info")) { + TF_ASSIGN_OR_RETURN(auto type_info_capsule, + as_capsule(type_dict["type_info"])); + type_info = static_cast(type_info_capsule.data()); + } -nb::dict Registrations() { - nb::dict dict; - dict["xla_python_gpu_callback"] = - jax::EncapsulateFunction(xla::XlaPythonGpuCallback); - return dict; + PJRT_FFI_Type_Info pjrt_type_info{ + /*deleter=*/type_info ? type_info->deleter : nullptr, + }; + + PJRT_FFI_Type_Register_Args args; + args.struct_size = PJRT_FFI_Type_Register_Args_STRUCT_SIZE; + args.type_name = type_name_c_str; + args.type_name_size = type_name_size; + args.type_id = type_id->type_id; + args.type_info = &pjrt_type_info; + RETURN_STATUS_IF_PJRT_ERROR(ffi_ext->type_register(&args), c_api); + + // Return registered type id to the caller. + type_id->type_id = args.type_id; + return absl::OkStatus(); } } // namespace @@ -216,18 +234,15 @@ void BuildGpuPluginExtension(nanobind::module_& m) { nb::class_(m, "TritonCompilationResult") .def_ro("asm", &TritonCompilationResult::asm_text) - .def_ro("smem_bytes", &TritonCompilationResult::smem_bytes) - .def_ro("cluster_dim_x", &TritonCompilationResult::cluster_dim_x) - .def_ro("cluster_dim_y", &TritonCompilationResult::cluster_dim_y) - .def_ro("cluster_dim_z", &TritonCompilationResult::cluster_dim_z); + .def_ro("smem_bytes", &TritonCompilationResult::smem_bytes); m.def("compile_triton_to_asm", - [](nb::capsule c_api, nb::bytes module, absl::string_view arch_name, + [](nb::capsule c_api, nb::bytes module, std::string_view arch_name, int num_warps, int num_ctas, int num_stages) { return xla::ValueOrThrow(CompileTritonToASM( static_cast(c_api.data()), - absl::string_view(static_cast(module.data()), - module.size()), + std::string_view(static_cast(module.data()), + module.size()), arch_name, num_warps, num_ctas, num_stages)); }); @@ -255,16 +270,15 @@ void BuildGpuPluginExtension(nanobind::module_& m) { nb::arg("xla_platform_name"), nb::arg("api_version") = 0, nb::arg("traits") = 0); m.def( - "register_custom_type_id", - [](nb::capsule c_api, nb::str type_name_py, nb::object type_id) { + "register_custom_type", + [](nb::capsule c_api, nb::str type_name_py, nb::object type) { const char* type_name_c_str = type_name_py.c_str(); size_t type_name_size = nb::len(type_name_py); - xla::ThrowIfError(RegisterCustomTypeId( + xla::ThrowIfError(RegisterCustomType( static_cast(c_api.data()), type_name_c_str, - type_name_size, std::move(type_id))); + type_name_size, std::move(type))); }, nb::arg("c_api"), nb::arg("type_name"), nb::arg("type_id")); - m.def("registrations", &Registrations); } -} // namespace xla +} // namespace jax diff --git a/jaxlib/gpu/gpu_plugin_extension.h b/jaxlib/gpu/gpu_plugin_extension.h index 70c74454ecc6..3845db85c9d1 100644 --- a/jaxlib/gpu/gpu_plugin_extension.h +++ b/jaxlib/gpu/gpu_plugin_extension.h @@ -18,10 +18,10 @@ limitations under the License. #include "nanobind/nanobind.h" -namespace xla { +namespace jax { void BuildGpuPluginExtension(nanobind::module_& m); -} // namespace xla +} // namespace jax #endif // JAXLIB_GPU_GPU_PLUGIN_EXTENSION_H_ diff --git a/jaxlib/handle_pool.h b/jaxlib/gpu/handle_pool.h similarity index 96% rename from jaxlib/handle_pool.h rename to jaxlib/gpu/handle_pool.h index 9201d8d579c5..9189bb174b06 100644 --- a/jaxlib/handle_pool.h +++ b/jaxlib/gpu/handle_pool.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef JAXLIB_HANDLE_POOL_H_ -#define JAXLIB_HANDLE_POOL_H_ +#ifndef JAXLIB_GPU_HANDLE_POOL_H_ +#define JAXLIB_GPU_HANDLE_POOL_H_ #include #include @@ -107,4 +107,4 @@ void HandlePool::Return(HandleType handle, } // namespace jax -#endif // JAXLIB_HANDLE_POOL_H_ +#endif // JAXLIB_GPU_HANDLE_POOL_H_ diff --git a/jaxlib/gpu/hybrid.cc b/jaxlib/gpu/hybrid.cc index 94975a5b969f..d5b91db496fd 100644 --- a/jaxlib/gpu/hybrid.cc +++ b/jaxlib/gpu/hybrid.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "nanobind/nanobind.h" #include "absl/base/call_once.h" +#include "nanobind/nanobind.h" #include "jaxlib/cpu/lapack_kernels.h" #include "jaxlib/gpu/hybrid_kernels.h" #include "jaxlib/gpu/vendor.h" @@ -29,6 +29,9 @@ namespace nb = nanobind; void GetLapackKernelsFromScipy() { static absl::once_flag initialized; + if (lapack_kernels_initialized) { + return; + } // For reasons I'm not entirely sure of, if the import_ call is done inside // the call_once scope, we sometimes observe deadlocks in the test suite. // However it probably doesn't do much harm to just import them a second time, diff --git a/jaxlib/gpu/hybrid_kernels.cc b/jaxlib/gpu/hybrid_kernels.cc index 8caa0e1d75a6..9af772dc9aeb 100644 --- a/jaxlib/gpu/hybrid_kernels.cc +++ b/jaxlib/gpu/hybrid_kernels.cc @@ -110,21 +110,25 @@ struct MagmaGeqp3 { template <> struct MagmaGeqp3 { static constexpr char name[] = "magma_sgeqp3_gpu"; + static constexpr char expert_name[] = "magma_sgeqp3_expert_gpu_work"; static constexpr char block_size_name[] = "magma_get_sgeqp3_nb"; }; template <> struct MagmaGeqp3 { static constexpr char name[] = "magma_dgeqp3_gpu"; + static constexpr char expert_name[] = "magma_dgeqp3_expert_gpu_work"; static constexpr char block_size_name[] = "magma_get_dgeqp3_nb"; }; template <> struct MagmaGeqp3 { static constexpr char name[] = "magma_cgeqp3_gpu"; + static constexpr char expert_name[] = "magma_cgeqp3_expert_gpu_work"; static constexpr char block_size_name[] = "magma_get_cgeqp3_nb"; }; template <> struct MagmaGeqp3 { static constexpr char name[] = "magma_zgeqp3_gpu"; + static constexpr char expert_name[] = "magma_zgeqp3_expert_gpu_work"; static constexpr char block_size_name[] = "magma_get_zgeqp3_nb"; }; @@ -221,7 +225,7 @@ absl::StatusOr MagmaLookup::Find(const char name[]) { absl::StatusOr FindMagmaSymbol(const char name[]) { static absl::Mutex mu; static MagmaLookup& lookup = *new MagmaLookup ABSL_GUARDED_BY(mu); - absl::MutexLock lock(&mu); + absl::MutexLock lock(mu); auto status = lookup.Initialize(); if (!status.ok()) { return status; @@ -371,24 +375,32 @@ class PivotingQrFactorizationMagma { private: Fn* fn_ = nullptr; - BlockSizeFn* block_size_fn_ = nullptr; absl::StatusOr lwork(int m, int n) { - // `{c,d,s,z}_geqp3_gpu` do not support a workspace query, but we can still - // assign the symbol here. + // `{c,d,s,z}_geqp3_gpu` do not support a workspace query, but we can call + // the expert API instead. auto maybe_ptr = FindMagmaSymbol(MagmaGeqp3::name); if (!maybe_ptr.ok()) return maybe_ptr.status(); fn_ = reinterpret_cast(*maybe_ptr); - auto block_size_maybe_ptr = - FindMagmaSymbol(MagmaGeqp3::block_size_name); - if (!block_size_maybe_ptr.ok()) return block_size_maybe_ptr.status(); - block_size_fn_ = reinterpret_cast(*block_size_maybe_ptr); - int optimal_block_size = block_size_fn_(m, n); - if constexpr (ffi::IsComplexType()) { - return (n + 1) * optimal_block_size; + auto maybe_expert_ptr = FindMagmaSymbol(MagmaGeqp3::expert_name); + if (!maybe_expert_ptr.ok()) return maybe_expert_ptr.status(); + using ExpertFn = int(int m, int n, ValueType* dA, int ldda, int* jpvt, + ValueType* tau, void* host_work, int* lwork_host, + void* device_work, int* lwork_device, int* info, + void* queue); + auto* expert_fn = reinterpret_cast(*maybe_expert_ptr); + + int lwork_host = -1; + int lwork_device = -1; + int info = 0; + expert_fn(m, n, nullptr, std::max(1, m), nullptr, nullptr, nullptr, + &lwork_host, nullptr, &lwork_device, &info, nullptr); + if (info != 0) { + return absl::InternalError(absl::StrFormat( + "MAGMA expert geqp3 workspace query failed with info=%d", info)); } - return (n + 1) * optimal_block_size + 2 * n; + return lwork_device / sizeof(ValueType); } }; diff --git a/jaxlib/gpu/linalg_kernels.cc b/jaxlib/gpu/linalg_kernels.cc index 2293bef89b7d..b48e64f2181d 100644 --- a/jaxlib/gpu/linalg_kernels.cc +++ b/jaxlib/gpu/linalg_kernels.cc @@ -90,8 +90,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(CholeskyUpdateFfi, CholeskyUpdateFfiImpl, namespace { ffi::Error LuPivotsToPermutationImpl( - gpuStream_t stream, ffi::Dictionary /* unused */, - ffi::Buffer pivots, + gpuStream_t stream, ffi::Buffer pivots, ffi::Result> permutation) { FFI_ASSIGN_OR_RETURN((auto [batch_size, pivot_size]), SplitBatch1D(pivots.dimensions())); @@ -119,10 +118,6 @@ ffi::Error LuPivotsToPermutationImpl( XLA_FFI_DEFINE_HANDLER_SYMBOL(LuPivotsToPermutation, LuPivotsToPermutationImpl, ffi::Ffi::Bind() .Ctx>() - // TODO(b/358275922): remove Attrs (and the - // unused Dictionary above) 12 weeks after - // release of jaxlib v0.4.32. - .Attrs() .Arg>() .Ret>()); diff --git a/jaxlib/gpu/make_batch_pointers.cu.cc b/jaxlib/gpu/make_batch_pointers.cu.cc index 3a24e355ead0..1d05fa8adcac 100644 --- a/jaxlib/gpu/make_batch_pointers.cu.cc +++ b/jaxlib/gpu/make_batch_pointers.cu.cc @@ -16,6 +16,7 @@ limitations under the License. #include "jaxlib/gpu/make_batch_pointers.h" #include +#include #include #include "jaxlib/gpu/vendor.h" diff --git a/jaxlib/gpu/prng.cc b/jaxlib/gpu/prng.cc index 1ce428d7f9dc..007e51b76de7 100644 --- a/jaxlib/gpu/prng.cc +++ b/jaxlib/gpu/prng.cc @@ -15,6 +15,7 @@ limitations under the License. #include "nanobind/nanobind.h" #include "jaxlib/gpu/prng_kernels.h" +#include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_nanobind_helpers.h" namespace jax { diff --git a/jaxlib/gpu/prng_kernels.cc b/jaxlib/gpu/prng_kernels.cc index f5d6abef83f8..1dac1e47bd44 100644 --- a/jaxlib/gpu/prng_kernels.cc +++ b/jaxlib/gpu/prng_kernels.cc @@ -17,16 +17,12 @@ limitations under the License. #include #include -#include #include "absl/algorithm/container.h" -#include "absl/status/status.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/ffi_helpers.h" -#include "jaxlib/kernel_helpers.h" #include "xla/ffi/api/ffi.h" -#include "xla/service/custom_call_status.h" namespace jax { namespace JAX_GPU_NAMESPACE { diff --git a/jaxlib/gpu/prng_kernels.cu.cc b/jaxlib/gpu/prng_kernels.cu.cc index d4aaec62320d..e42165f95d15 100644 --- a/jaxlib/gpu/prng_kernels.cu.cc +++ b/jaxlib/gpu/prng_kernels.cu.cc @@ -15,8 +15,7 @@ limitations under the License. #include "jaxlib/gpu/prng_kernels.h" -#include -#include +#include #include #include "jaxlib/gpu/vendor.h" diff --git a/jaxlib/gpu/prng_kernels.h b/jaxlib/gpu/prng_kernels.h index c98fd485700d..4d64d2b4a4e4 100644 --- a/jaxlib/gpu/prng_kernels.h +++ b/jaxlib/gpu/prng_kernels.h @@ -16,12 +16,10 @@ limitations under the License. #ifndef JAXLIB_GPU_PRNG_KERNELS_H_ #define JAXLIB_GPU_PRNG_KERNELS_H_ -#include #include #include "jaxlib/gpu/vendor.h" #include "xla/ffi/api/ffi.h" -#include "xla/service/custom_call_status.h" namespace jax { namespace JAX_GPU_NAMESPACE { diff --git a/jaxlib/gpu/py_client_gpu.cc b/jaxlib/gpu/py_client_gpu.cc new file mode 100644 index 000000000000..f6678e023c6e --- /dev/null +++ b/jaxlib/gpu/py_client_gpu.cc @@ -0,0 +1,320 @@ +/* Copyright 2022 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/gpu/py_client_gpu.h" + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_format.h" +#include "absl/types/span.h" +#include "include/dlpack/dlpack.h" +#include "nanobind/nanobind.h" +#include "jaxlib/ffi.h" +#include "jaxlib/gpu/vendor.h" +#include "xla/ffi/api/ffi.h" +#include "xla/ffi/ffi_api.h" +#include "xla/pjrt/host_callback.h" +#include "xla/pjrt/transpose.h" +#include "xla/primitive_util.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/types.h" +#include "xla/shape_util.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace nb = nanobind; + +namespace jax { +namespace JAX_GPU_NAMESPACE { + +struct GpuTransposePlanCache { + static xla::ffi::TypeId id; + static xla::ffi::TypeInfo info; + + explicit GpuTransposePlanCache(int capacity) : cache(capacity) {} + xla::TransposePlanCache cache; +}; + +xla::ffi::TypeId GpuTransposePlanCache::id = {}; +xla::ffi::TypeInfo GpuTransposePlanCache::info = + xla::ffi::MakeTypeInfo(); + +XLA_FFI_REGISTER_TYPE(xla::ffi::GetXlaFfiApi(), "GpuTransposePlanCache", + &GpuTransposePlanCache::id, &GpuTransposePlanCache::info); + +std::pair +GpuTransposePlanCacheTypeInfo() { + return std::make_pair(&GpuTransposePlanCache::id, + &GpuTransposePlanCache::info); +} + +static xla::ffi::ErrorOr> +GpuTransposePlanCacheInstantiate(uint64_t index) { + return std::make_unique(16); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + kGpuTransposePlanCacheInstantiate, GpuTransposePlanCacheInstantiate, + xla::ffi::Ffi::BindInstantiate().Attr("index")); + +xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, + xla::FfiLoadedHostCallbacks* callbacks, + GpuTransposePlanCache* transpose_cache, + uint64_t index, + xla::ffi::RemainingArgs args, + xla::ffi::RemainingRets rets) { + size_t arity = args.size(); + std::vector host_input_buffers(arity); + // Copy input GPU buffers to host + for (size_t i = 0; i < arity; ++i) { + auto arg = args.get(i); + auto ptype = static_cast(arg->element_type()); + // TODO(b/395428868): Remove this check once we support subbyte types. + if (ptype == xla::S1 || ptype == xla::U1) { + return xla::ffi::Error(xla::ffi::ErrorCode::kUnimplemented, + absl::StrFormat("Unsupported primitive type: %s", + PrimitiveType_Name(ptype))); + } + if (ptype == xla::TOKEN) { + host_input_buffers[i] = nullptr; + continue; + } + size_t size_bytes = arg->size_bytes(); + // NOTE(dsuo): FFI arguments and return buffers are sized assuming + // minimum 1-byte element sizes, even if the data itself is packed. We + // assume that 2-bit and 4-bit types are packed. + size_t bits_per_element = xla::primitive_util::BitWidth(ptype); + if (bits_per_element == 2 || bits_per_element == 4) { + size_bytes = arg->element_count() * bits_per_element / 8; + } + host_input_buffers[i] = new char[size_bytes]; + // TODO(b/238441608): Use pinned memory here to speed up the transfer. + auto gpu_res = + gpuMemcpyAsync(host_input_buffers[i], arg.value().untyped_data(), + size_bytes, gpuMemcpyDeviceToHost, stream); + CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; + } + CHECK_EQ(gpuStreamSynchronize(stream), gpuSuccess) + << "Failed to gpuStreamSynchronize"; + nb::gil_scoped_acquire gil; + auto callback = nb::borrow( + static_cast(callbacks->callbacks[index])); + nb::tuple host_input_arrays = nb::steal(PyTuple_New(arity)); + for (size_t i = 0; i < arity; ++i) { + auto arg = args.get(i); + auto ptype = static_cast(arg->element_type()); + if (ptype == xla::TOKEN) { + PyTuple_SET_ITEM(host_input_arrays.ptr(), i, nb::none().inc_ref().ptr()); + continue; + } + auto maybe_dtype = PrimitiveTypeToNbDtype(ptype); + if (!maybe_dtype.ok()) { + return xla::ffi::Error::Internal(maybe_dtype.status().ToString()); + } + auto dtype = maybe_dtype.value(); + auto dims = absl::Span(arg->dimensions().begin(), + arg->dimensions().size()); + // TODO(b/402422886): Remove this once we form Jax arrays directly instead + // of packing/unpacking to/from numpy arrays. + // We pass in data using default numpy layout i.e., std::nullopt. + size_t bits_per_element = xla::primitive_util::BitWidth(ptype); + if (bits_per_element == 2 || bits_per_element == 4) { + // NOTE(dsuo): FFI arguments and return buffers are sized assuming + // minimum 1-byte element sizes, even if the data itself is packed. We + // assume that 2-bit and 4-bit types are packed. + auto size_bytes = arg->element_count() * bits_per_element / 8; + auto buffer = xla::UnpackIntN( + bits_per_element, static_cast(host_input_buffers[i]), + size_bytes); + delete[] static_cast(host_input_buffers[i]); + host_input_buffers[i] = buffer.release(); + } + nb::capsule base(host_input_buffers[i], [](void* ptr) noexcept { + delete[] static_cast(ptr); + }); + auto array = xla::nb_numpy_ndarray(dtype, dims, std::nullopt, + host_input_buffers[i], base); + array.attr("flags").attr("writeable") = nb::bool_(false); + PyTuple_SET_ITEM(host_input_arrays.ptr(), i, array.inc_ref().ptr()); + } + + // TODO(dsuo): Change this to use the Python vectorcall protocol, which allows + // you to avoid constructing a tuple for the arguments. + nb::tuple result_tuple; + { + xla::HostCallbackScope scope; + try { + auto result_object = callback(*nb::borrow(host_input_arrays)); + result_tuple = nb::cast(result_object); + } catch (nb::python_error& e) { + return xla::ffi::Error::Internal( + absl::StrFormat("CpuCallback error calling callback: %s", e.what())); + } + } + + std::vector temp_buffers; + for (size_t i = 0; i < rets.size(); ++i) { + auto ret = rets.get(i).value(); + auto ptype = static_cast(ret->element_type()); + // TODO(b/395428868): Remove this check once we support subbyte types. + if (ptype == xla::S1 || ptype == xla::U1) { + return xla::ffi::Error(xla::ffi::ErrorCode::kUnimplemented, + absl::StrFormat("Unsupported primitive type: %s", + PrimitiveType_Name(ptype))); + } + if (ptype == xla::TOKEN) continue; + nb::object output = + nb::borrow(PyTuple_GetItem(result_tuple.ptr(), i)); + auto array = xla::nb_numpy_ndarray::ensure(std::move(output)); + absl::Span strides( + reinterpret_cast(array.strides()), array.ndim()); + // We expect the output to be in default numpy layout. + auto dims = absl::Span(ret->dimensions().begin(), + ret->dimensions().size()); + auto maybe_expected_shape = xla::ShapeUtil::MakeValidatedShape(ptype, dims); + if (!maybe_expected_shape.ok()) { + return xla::ffi::Error::Internal( + maybe_expected_shape.status().ToString()); + } + auto expected_shape = maybe_expected_shape.value(); + auto expected_strides = xla::ByteStridesForShape(expected_shape); + + const void* data = array.data(); + size_t size_bytes = array.size() * array.itemsize(); + if (strides != expected_strides) { + xla::TransposePlan::Options options; + options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); + options.dims = absl::Span( + reinterpret_cast(array.shape()), array.ndim()); + absl::InlinedVector reversed_layout; + reversed_layout.resize(expected_shape.dimensions().size()); + absl::c_reverse_copy(expected_shape.layout().minor_to_major(), + reversed_layout.begin()); + options.permutation = reversed_layout; + options.input_striding = xla::TransposePlan::Striding{strides}; + auto maybe_plan = transpose_cache->cache.GetOrCreate(options); + if (!maybe_plan.ok()) { + return xla::ffi::Error::Internal(maybe_plan.status().ToString()); + } + auto plan = maybe_plan.value(); + void* temp = new char[size_bytes]; + temp_buffers.push_back(temp); + plan->Execute(data, temp); + data = temp; + } + + // TODO(b/402422886): Remove this once we form Jax arrays directly instead + // of packing/unpacking to/from numpy arrays. + std::unique_ptr buffer; + size_t bits_per_element = xla::primitive_util::BitWidth(ptype); + if (bits_per_element == 2 || bits_per_element == 4) { + // NOTE(dsuo): FFI arguments and return buffers are sized assuming + // minimum 1-byte element sizes, even if the data itself is packed. We + // assume that 2-bit and 4-bit types are packed. + buffer = xla::PackIntN(bits_per_element, static_cast(data), + size_bytes); + data = buffer.get(); + size_bytes = (size_bytes * bits_per_element) / 8; + } + + auto gpu_res = gpuMemcpyAsync(ret->untyped_data(), data, size_bytes, + gpuMemcpyHostToDevice, stream); + CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; + } + nb::gil_scoped_release release; + CHECK_EQ(gpuStreamSynchronize(stream), gpuSuccess) + << "Failed to gpuStreamSynchronize"; + for (int i = 0; i < temp_buffers.size(); ++i) { + delete[] static_cast(temp_buffers[i]); + } + return xla::ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + kXlaFfiPythonGpuCallback, XlaFfiPythonGpuCallback, + xla::ffi::Ffi::Bind() + .Ctx>() + .Ctx>() + .Ctx>() + .Attr("index") + .RemainingArgs() + .RemainingRets()); +XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), + "xla_ffi_python_gpu_callback", + absl::AsciiStrToUpper(JAX_GPU_PLUGIN_NAME), + {kGpuTransposePlanCacheInstantiate, nullptr, nullptr, + kXlaFfiPythonGpuCallback}); +XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), + "xla_ffi_partitioned_python_gpu_callback", + absl::AsciiStrToUpper(JAX_GPU_PLUGIN_NAME), + {kGpuTransposePlanCacheInstantiate, nullptr, nullptr, + kXlaFfiPythonGpuCallback}); + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + kXlaBufferPythonGpuCallback, +#ifdef JAX_GPU_CUDA + (jax::XlaBufferCallback), +#else + (jax::XlaBufferCallback), +#endif + xla::ffi::Ffi::Bind() + .Ctx() + .Ctx() + .Ctx>() + .Attr("index") + .RemainingArgs() + .RemainingRets()); + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + kXlaBufferPythonGpuCallbackCmdBuffer, +#ifdef JAX_GPU_CUDA + (jax::XlaBufferCallback), +#else + (jax::XlaBufferCallback), +#endif + xla::ffi::Ffi::Bind() + .Ctx() + .Ctx() + .Ctx>() + .Attr("index") + .RemainingArgs() + .RemainingRets(), + {ffi::Traits::kCmdBufferCompatible}); + +XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), + "xla_buffer_python_gpu_callback", + absl::AsciiStrToUpper(JAX_GPU_PLUGIN_NAME), + kXlaBufferPythonGpuCallback); +XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), + "xla_buffer_python_gpu_callback_cmd_buffer", + absl::AsciiStrToUpper(JAX_GPU_PLUGIN_NAME), + kXlaBufferPythonGpuCallbackCmdBuffer); + +} // namespace JAX_GPU_NAMESPACE +} // namespace jax diff --git a/jaxlib/gpu/py_client_gpu.h b/jaxlib/gpu/py_client_gpu.h new file mode 100644 index 000000000000..6c9fd7912a2c --- /dev/null +++ b/jaxlib/gpu/py_client_gpu.h @@ -0,0 +1,38 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAX_JAXLIB_GPU_PY_CLIENT_GPU_H_ +#define JAX_JAXLIB_GPU_PY_CLIENT_GPU_H_ + +#include + +#include "jaxlib/gpu/vendor.h" +#include "xla/ffi/api/ffi.h" + +namespace jax { +namespace JAX_GPU_NAMESPACE { + +std::pair +GpuTransposePlanCacheTypeInfo(); + +XLA_FFI_DECLARE_HANDLER_SYMBOL(kGpuTransposePlanCacheInstantiate); +XLA_FFI_DECLARE_HANDLER_SYMBOL(kXlaFfiPythonGpuCallback); +XLA_FFI_DECLARE_HANDLER_SYMBOL(kXlaBufferPythonGpuCallback); +XLA_FFI_DECLARE_HANDLER_SYMBOL(kXlaBufferPythonGpuCallbackCmdBuffer); + +} // namespace JAX_GPU_NAMESPACE +} // namespace jax + +#endif // JAX_JAXLIB_GPU_PY_CLIENT_GPU_H_ diff --git a/jaxlib/gpu/rnn.cc b/jaxlib/gpu/rnn.cc index eaa815d33e68..c235aa9fecfb 100644 --- a/jaxlib/gpu/rnn.cc +++ b/jaxlib/gpu/rnn.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include "nanobind/nanobind.h" -#include "nanobind/stl/pair.h" +#include "nanobind/stl/pair.h" // IWYU pragma: keep #include "jaxlib/absl_status_casters.h" #include "jaxlib/gpu/rnn_kernels.h" #include "jaxlib/gpu/vendor.h" @@ -39,8 +39,6 @@ nb::bytes BuildRnnDescriptor(int input_size, int hidden_size, int num_layers, nb::dict Registrations() { nb::dict dict; - dict[JAX_GPU_PREFIX "dnn_rnn"] = EncapsulateFunction(RNNForward); - dict[JAX_GPU_PREFIX "dnn_rnn_bwd"] = EncapsulateFunction(RNNBackward); dict[JAX_GPU_PREFIX "dnn_rnn_ffi"] = EncapsulateFfiHandler(RNNForwardFfi); dict[JAX_GPU_PREFIX "dnn_rnn_bwd_ffi"] = EncapsulateFfiHandler(RNNBackwardFfi); diff --git a/jaxlib/gpu/rnn_kernels.cc b/jaxlib/gpu/rnn_kernels.cc index e9820bc31f1e..5020c9a1d36f 100644 --- a/jaxlib/gpu/rnn_kernels.cc +++ b/jaxlib/gpu/rnn_kernels.cc @@ -16,16 +16,20 @@ limitations under the License. #include "jaxlib/gpu/rnn_kernels.h" #include +#include +#include #include #include #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "absl/synchronization/mutex.h" #include "jaxlib/gpu/ffi_wrapper.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" -#include "jaxlib/handle_pool.h" +#include "jaxlib/gpu/handle_pool.h" +#include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_helpers.h" -#include "xla/service/custom_call_status.h" namespace jax { @@ -56,7 +60,7 @@ template <> /*static*/ absl::StatusOr DnnHandlePool::Borrow( gpuStream_t stream) { DnnHandlePool* pool = Instance(); - absl::MutexLock lock(&pool->mu_); + absl::MutexLock lock(pool->mu_); gpudnnHandle_t handle; if (pool->handles_[stream].empty()) { JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnCreate(&handle))); @@ -168,6 +172,9 @@ DoRnnComputeWorkspaceReserveSpaceSizes(int input_size, int hidden_size, JAX_RETURN_IF_ERROR( JAX_AS_STATUS(gpudnnDestroyRNNDataDescriptor(input_data_desc))); JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnDestroyRNNDescriptor(rnn_desc))); +#ifdef JAX_GPU_HIP + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuFree(dropout_states_dev))); +#endif // Round up to nearest multiples of 4 so we can return them as f32 arrays. workSpaceSize += (workSpaceSize % 4); @@ -347,6 +354,10 @@ static absl::Status DnnRNNForward_(gpuStream_t stream, void** buffers, JAX_RETURN_IF_ERROR( JAX_AS_STATUS(gpudnnDestroyDropoutDescriptor(dropout_desc))); JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnDestroyRNNDescriptor(rnn_desc))); +#ifdef JAX_GPU_HIP + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuFree(dropout_states_dev))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnDestroyTensorDescriptor(input_tensor_desc))); +#endif return absl::OkStatus(); } @@ -532,28 +543,14 @@ static absl::Status DnnRNNBackward_(gpuStream_t stream, void** buffers, JAX_RETURN_IF_ERROR( JAX_AS_STATUS(gpudnnDestroyDropoutDescriptor(dropout_desc))); JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnDestroyRNNDescriptor(rnn_desc))); +#ifdef JAX_GPU_HIP + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuFree(dropout_states_dev))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnDestroyTensorDescriptor(input_tensor_desc))); +#endif return absl::OkStatus(); } -void RNNForward(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = DnnRNNForward_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - -void RNNBackward(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = DnnRNNBackward_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - JAX_GPU_REGISTER_WRAPPED_LEGACY_KERNEL(RNNForwardFfi, DnnRNNForward_); JAX_GPU_REGISTER_WRAPPED_LEGACY_KERNEL(RNNBackwardFfi, DnnRNNBackward_); diff --git a/jaxlib/gpu/rnn_kernels.h b/jaxlib/gpu/rnn_kernels.h index e95b7788382a..c1d6712a9eac 100644 --- a/jaxlib/gpu/rnn_kernels.h +++ b/jaxlib/gpu/rnn_kernels.h @@ -17,11 +17,11 @@ limitations under the License. #define JAXLIB_GPU_RNN_KERNELS_H_ #include +#include #include "absl/status/statusor.h" #include "jaxlib/gpu/vendor.h" #include "xla/ffi/api/ffi.h" -#include "xla/service/custom_call_status.h" namespace jax { namespace JAX_GPU_NAMESPACE { @@ -46,12 +46,6 @@ absl::StatusOr> RnnComputeWorkspaceReserveSpaceSizes( int max_seq_length, float dropout, bool bidirectional, bool cudnn_allow_tf32); -void RNNForward(gpuStream_t stream, void **buffers, const char *opaque, - size_t opaque_len, XlaCustomCallStatus *status); - -void RNNBackward(gpuStream_t stream, void **buffers, const char *opaque, - size_t opaque_len, XlaCustomCallStatus *status); - XLA_FFI_DECLARE_HANDLER_SYMBOL(RNNForwardFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(RNNBackwardFfi); diff --git a/jaxlib/gpu/solver.cc b/jaxlib/gpu/solver.cc index 357a38eecfd5..6eb76b570722 100644 --- a/jaxlib/gpu/solver.cc +++ b/jaxlib/gpu/solver.cc @@ -13,22 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include -#include - #include "nanobind/nanobind.h" -#include "nanobind/stl/pair.h" -#include "absl/container/flat_hash_map.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_format.h" -#include "jaxlib/gpu/gpu_kernel_helpers.h" -#include "jaxlib/gpu/solver_handle_pool.h" -#include "jaxlib/gpu/solver_kernels.h" #include "jaxlib/gpu/solver_kernels_ffi.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_nanobind_helpers.h" -#include "xla/tsl/python/lib/core/numpy.h" namespace jax { namespace JAX_GPU_NAMESPACE { @@ -36,477 +24,34 @@ namespace { namespace nb = nanobind; -// Converts a NumPy dtype to a Type. -SolverType DtypeToSolverType(const dtype& np_type) { - static auto* types = - new absl::flat_hash_map, SolverType>({ - {{'f', 4}, SolverType::F32}, - {{'f', 8}, SolverType::F64}, - {{'c', 8}, SolverType::C64}, - {{'c', 16}, SolverType::C128}, - }); - auto it = types->find({np_type.kind(), np_type.itemsize()}); - if (it == types->end()) { - nb::str repr = nb::repr(np_type); - throw std::invalid_argument( - absl::StrFormat("Unsupported dtype %s", repr.c_str())); - } - return it->second; -} - -// getrf: LU decomposition - -// Returns the workspace size and a descriptor for a getrf operation. -std::pair BuildGetrfDescriptor(const dtype& dtype, int b, int m, - int n) { - SolverType type = DtypeToSolverType(dtype); - auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); - JAX_THROW_IF_ERROR(h.status()); - auto& handle = *h; - int lwork; - switch (type) { - case SolverType::F32: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnSgetrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); - break; - case SolverType::F64: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnDgetrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); - break; - case SolverType::C64: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnCgetrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); - break; - case SolverType::C128: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnZgetrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); - break; - } - return {lwork, PackDescriptor(GetrfDescriptor{type, b, m, n, lwork})}; -} - -// geqrf: QR decomposition - -// Returns the workspace size and a descriptor for a geqrf operation. -std::pair BuildGeqrfDescriptor(const dtype& dtype, int b, int m, - int n) { - SolverType type = DtypeToSolverType(dtype); - auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); - JAX_THROW_IF_ERROR(h.status()); - auto& handle = *h; - int lwork; - switch (type) { - case SolverType::F32: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnSgeqrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); - break; - case SolverType::F64: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnDgeqrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); - break; - case SolverType::C64: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnCgeqrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); - break; - case SolverType::C128: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnZgeqrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); - break; - } - return {lwork, PackDescriptor(GeqrfDescriptor{type, b, m, n, lwork})}; -} - -#ifdef JAX_GPU_CUDA - -// csrlsvqr: Linear system solve via Sparse QR - -// Returns a descriptor for a csrlsvqr operation. -nb::bytes BuildCsrlsvqrDescriptor(const dtype& dtype, int n, int nnzA, - int reorder, double tol) { - SolverType type = DtypeToSolverType(dtype); - return PackDescriptor(CsrlsvqrDescriptor{type, n, nnzA, reorder, tol}); -} - -#endif // JAX_GPU_CUDA - -// orgqr/ungqr: apply elementary Householder transformations - -// Returns the workspace size and a descriptor for a geqrf operation. -std::pair BuildOrgqrDescriptor(const dtype& dtype, int b, int m, - int n, int k) { - SolverType type = DtypeToSolverType(dtype); - auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); - JAX_THROW_IF_ERROR(h.status()); - auto& handle = *h; - int lwork; - switch (type) { - case SolverType::F32: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnSorgqr_bufferSize(handle.get(), m, n, k, - /*A=*/nullptr, - /*lda=*/m, - /*tau=*/nullptr, &lwork))); - break; - case SolverType::F64: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnDorgqr_bufferSize(handle.get(), m, n, k, - /*A=*/nullptr, - /*lda=*/m, - /*tau=*/nullptr, &lwork))); - break; - case SolverType::C64: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnCungqr_bufferSize(handle.get(), m, n, k, - /*A=*/nullptr, - /*lda=*/m, - /*tau=*/nullptr, &lwork))); - break; - case SolverType::C128: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnZungqr_bufferSize(handle.get(), m, n, k, - /*A=*/nullptr, - /*lda=*/m, - /*tau=*/nullptr, &lwork))); - break; - } - return {lwork, PackDescriptor(OrgqrDescriptor{type, b, m, n, k, lwork})}; -} - -// Symmetric (Hermitian) eigendecomposition, QR algorithm: syevd/heevd - -// Returns the workspace size and a descriptor for a syevd operation. -std::pair BuildSyevdDescriptor(const dtype& dtype, bool lower, - int b, int n) { - SolverType type = DtypeToSolverType(dtype); - auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); - JAX_THROW_IF_ERROR(h.status()); - auto& handle = *h; - int lwork; - gpusolverEigMode_t jobz = GPUSOLVER_EIG_MODE_VECTOR; - gpusolverFillMode_t uplo = - lower ? GPUSOLVER_FILL_MODE_LOWER : GPUSOLVER_FILL_MODE_UPPER; - switch (type) { - case SolverType::F32: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnSsyevd_bufferSize( - handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr, - &lwork))); - break; - case SolverType::F64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnDsyevd_bufferSize( - handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr, - &lwork))); - break; - case SolverType::C64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnCheevd_bufferSize( - handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr, - &lwork))); - break; - case SolverType::C128: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnZheevd_bufferSize( - handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr, - &lwork))); - break; - } - return {lwork, PackDescriptor(SyevdDescriptor{type, uplo, b, n, lwork})}; -} - -// Symmetric (Hermitian) eigendecomposition, Jacobi algorithm: syevj/heevj -// Supports batches of matrices up to size 32. - -// Returns the workspace size and a descriptor for a syevj_batched operation. -std::pair BuildSyevjDescriptor(const dtype& dtype, bool lower, - int batch, int n) { - SolverType type = DtypeToSolverType(dtype); - auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); - JAX_THROW_IF_ERROR(h.status()); - auto& handle = *h; - int lwork; - gpuSyevjInfo_t params; - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnCreateSyevjInfo(¶ms))); - std::unique_ptr params_cleanup( - params, [](gpuSyevjInfo_t p) { gpusolverDnDestroySyevjInfo(p); }); - gpusolverEigMode_t jobz = GPUSOLVER_EIG_MODE_VECTOR; - gpusolverFillMode_t uplo = - lower ? GPUSOLVER_FILL_MODE_LOWER : GPUSOLVER_FILL_MODE_UPPER; - if (batch == 1) { - switch (type) { - case SolverType::F32: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnSsyevj_bufferSize( - handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, - /*W=*/nullptr, &lwork, params))); - break; - case SolverType::F64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnDsyevj_bufferSize( - handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, - /*W=*/nullptr, &lwork, params))); - break; - case SolverType::C64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnCheevj_bufferSize( - handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, - /*W=*/nullptr, &lwork, params))); - break; - case SolverType::C128: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnZheevj_bufferSize( - handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, - /*W=*/nullptr, &lwork, params))); - break; - } - } else { - switch (type) { - case SolverType::F32: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnSsyevjBatched_bufferSize( - handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, - /*W=*/nullptr, &lwork, params, batch))); - break; - case SolverType::F64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnDsyevjBatched_bufferSize( - handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, - /*W=*/nullptr, &lwork, params, batch))); - break; - case SolverType::C64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnCheevjBatched_bufferSize( - handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, - /*W=*/nullptr, &lwork, params, batch))); - break; - case SolverType::C128: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnZheevjBatched_bufferSize( - handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, - /*W=*/nullptr, &lwork, params, batch))); - break; - } - } - return {lwork, PackDescriptor(SyevjDescriptor{type, uplo, batch, n, lwork})}; -} - -// Singular value decomposition using QR algorithm: gesvd - -// Returns the workspace size and a descriptor for a gesvd operation. -std::pair BuildGesvdDescriptor(const dtype& dtype, int b, int m, - int n, bool compute_uv, - bool full_matrices) { - SolverType type = DtypeToSolverType(dtype); - auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); - JAX_THROW_IF_ERROR(h.status()); - auto& handle = *h; - int lwork; - signed char jobu, jobvt; - if (compute_uv) { - if (full_matrices) { - jobu = jobvt = 'A'; - } else { - jobu = jobvt = 'S'; - } - } else { - jobu = jobvt = 'N'; - } - switch (type) { - case SolverType::F32: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnSgesvd_bufferSize( - handle.get(), jobu, jobvt, m, n, &lwork))); - break; - case SolverType::F64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnDgesvd_bufferSize( - handle.get(), jobu, jobvt, m, n, &lwork))); - break; - case SolverType::C64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnCgesvd_bufferSize( - handle.get(), jobu, jobvt, m, n, &lwork))); - break; - case SolverType::C128: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnZgesvd_bufferSize( - handle.get(), jobu, jobvt, m, n, &lwork))); - break; - } - return {lwork, - PackDescriptor(GesvdDescriptor{type, b, m, n, lwork, jobu, jobvt})}; -} - -#ifdef JAX_GPU_CUDA - -// Singular value decomposition using Jacobi algorithm: gesvdj - -// Returns the workspace size and a descriptor for a gesvdj operation. -std::pair BuildGesvdjDescriptor(const dtype& dtype, int batch, - int m, int n, bool compute_uv, - int econ) { - SolverType type = DtypeToSolverType(dtype); - auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); - JAX_THROW_IF_ERROR(h.status()); - auto& handle = *h; - int lwork; - gpusolverEigMode_t jobz = - compute_uv ? GPUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR; - gesvdjInfo_t params; - JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnCreateGesvdjInfo(¶ms))); - std::unique_ptr params_cleanup( - params, [](gesvdjInfo* p) { cusolverDnDestroyGesvdjInfo(p); }); - if (batch <= 1 || m > 32 || n > 32 || econ) { - switch (type) { - case SolverType::F32: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnSgesvdj_bufferSize( - handle.get(), jobz, econ, m, n, - /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr, - /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr, - /*ldv=*/n, &lwork, params))); - break; - case SolverType::F64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnDgesvdj_bufferSize( - handle.get(), jobz, econ, m, n, - /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr, - /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr, - /*ldv=*/n, &lwork, params))); - break; - case SolverType::C64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnCgesvdj_bufferSize( - handle.get(), jobz, econ, m, n, - /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr, - /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr, - /*ldv=*/n, &lwork, params))); - break; - case SolverType::C128: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnZgesvdj_bufferSize( - handle.get(), jobz, econ, m, n, - /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr, - /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr, - /*ldv=*/n, &lwork, params))); - break; - } - } else { - switch (type) { - case SolverType::F32: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnSgesvdjBatched_bufferSize( - handle.get(), jobz, m, n, - /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr, - /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr, - /*ldv=*/n, &lwork, params, batch))); - break; - case SolverType::F64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnDgesvdjBatched_bufferSize( - handle.get(), jobz, m, n, - /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr, - /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr, - /*ldv=*/n, &lwork, params, batch))); - break; - case SolverType::C64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnCgesvdjBatched_bufferSize( - handle.get(), jobz, m, n, - /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr, - /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr, - /*ldv=*/n, &lwork, params, batch))); - break; - case SolverType::C128: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnZgesvdjBatched_bufferSize( - handle.get(), jobz, m, n, - /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr, - /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr, - /*ldv=*/n, &lwork, params, batch))); - break; - } - } - return {lwork, PackDescriptor( - GesvdjDescriptor{type, batch, m, n, lwork, jobz, econ})}; -} - -#endif // JAX_GPU_CUDA - -// Returns the workspace size and a descriptor for a geqrf operation. -std::pair BuildSytrdDescriptor(const dtype& dtype, bool lower, - int b, int n) { - SolverType type = DtypeToSolverType(dtype); - auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); - JAX_THROW_IF_ERROR(h.status()); - auto& handle = *h; - int lwork; - gpusolverFillMode_t uplo = - lower ? GPUSOLVER_FILL_MODE_LOWER : GPUSOLVER_FILL_MODE_UPPER; - switch (type) { - case SolverType::F32: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnSsytrd_bufferSize( - handle.get(), uplo, n, /*A=*/nullptr, /*lda=*/n, /*D=*/nullptr, - /*E=*/nullptr, /*tau=*/nullptr, &lwork))); - break; - case SolverType::F64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnDsytrd_bufferSize( - handle.get(), uplo, n, /*A=*/nullptr, /*lda=*/n, /*D=*/nullptr, - /*E=*/nullptr, /*tau=*/nullptr, &lwork))); - break; - case SolverType::C64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnChetrd_bufferSize( - handle.get(), uplo, n, /*A=*/nullptr, /*lda=*/n, /*D=*/nullptr, - /*E=*/nullptr, /*tau=*/nullptr, &lwork))); - break; - case SolverType::C128: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnZhetrd_bufferSize( - handle.get(), uplo, n, /*A=*/nullptr, /*lda=*/n, /*D=*/nullptr, - /*E=*/nullptr, /*tau=*/nullptr, &lwork))); - break; - } - return {lwork, PackDescriptor(SytrdDescriptor{type, uplo, b, n, n, lwork})}; -} - nb::dict Registrations() { nb::dict dict; - dict[JAX_GPU_PREFIX "solver_getrf"] = EncapsulateFunction(Getrf); - dict[JAX_GPU_PREFIX "solver_geqrf"] = EncapsulateFunction(Geqrf); - dict[JAX_GPU_PREFIX "solver_orgqr"] = EncapsulateFunction(Orgqr); - dict[JAX_GPU_PREFIX "solver_syevd"] = EncapsulateFunction(Syevd); - dict[JAX_GPU_PREFIX "solver_syevj"] = EncapsulateFunction(Syevj); - dict[JAX_GPU_PREFIX "solver_gesvd"] = EncapsulateFunction(Gesvd); - dict[JAX_GPU_PREFIX "solver_sytrd"] = EncapsulateFunction(Sytrd); - -#ifdef JAX_GPU_CUDA - dict["cusolver_csrlsvqr"] = EncapsulateFunction(Csrlsvqr); - dict["cusolver_gesvdj"] = EncapsulateFunction(Gesvdj); - -#endif // JAX_GPU_CUDA dict[JAX_GPU_PREFIX "solver_getrf_ffi"] = EncapsulateFfiHandler(GetrfFfi); dict[JAX_GPU_PREFIX "solver_geqrf_ffi"] = EncapsulateFfiHandler(GeqrfFfi); dict[JAX_GPU_PREFIX "solver_orgqr_ffi"] = EncapsulateFfiHandler(OrgqrFfi); + dict[JAX_GPU_PREFIX "solver_potrf_ffi"] = EncapsulateFfiHandler(PotrfFfi); dict[JAX_GPU_PREFIX "solver_syevd_ffi"] = EncapsulateFfiHandler(SyevdFfi); dict[JAX_GPU_PREFIX "solver_syrk_ffi"] = EncapsulateFfiHandler(SyrkFfi); dict[JAX_GPU_PREFIX "solver_gesvd_ffi"] = EncapsulateFfiHandler(GesvdFfi); + dict[JAX_GPU_PREFIX "solver_gesvdj_ffi"] = EncapsulateFfiHandler(GesvdjFfi); dict[JAX_GPU_PREFIX "solver_sytrd_ffi"] = EncapsulateFfiHandler(SytrdFfi); #ifdef JAX_GPU_CUDA - dict[JAX_GPU_PREFIX "solver_gesvdj_ffi"] = EncapsulateFfiHandler(GesvdjFfi); + dict[JAX_GPU_PREFIX "solver_gesvdp_ffi"] = EncapsulateFfiHandler(GesvdpFfi); dict[JAX_GPU_PREFIX "solver_csrlsvqr_ffi"] = EncapsulateFfiHandler(CsrlsvqrFfi); #endif // JAX_GPU_CUDA +#if JAX_GPU_HAVE_SOLVER_GEEV + dict[JAX_GPU_PREFIX "solver_geev_ffi"] = EncapsulateFfiHandler(GeevFfi); +#endif // JAX_GPU_HAVE_SOLVER_GEEV + return dict; } NB_MODULE(_solver, m) { - tsl::ImportNumpy(); m.def("registrations", &Registrations); - m.def("build_getrf_descriptor", &BuildGetrfDescriptor); - m.def("build_geqrf_descriptor", &BuildGeqrfDescriptor); - m.def("build_orgqr_descriptor", &BuildOrgqrDescriptor); - m.def("build_syevd_descriptor", &BuildSyevdDescriptor); - m.def("build_syevj_descriptor", &BuildSyevjDescriptor); - m.def("build_gesvd_descriptor", &BuildGesvdDescriptor); - m.def("build_sytrd_descriptor", &BuildSytrdDescriptor); -#ifdef JAX_GPU_CUDA - m.def("build_csrlsvqr_descriptor", &BuildCsrlsvqrDescriptor); - m.def("build_gesvdj_descriptor", &BuildGesvdjDescriptor); -#endif // JAX_GPU_CUDA } } // namespace diff --git a/jaxlib/gpu/solver_handle_pool.cc b/jaxlib/gpu/solver_handle_pool.cc index c55ea923b21b..9abe6e334c58 100644 --- a/jaxlib/gpu/solver_handle_pool.cc +++ b/jaxlib/gpu/solver_handle_pool.cc @@ -19,7 +19,7 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/vendor.h" -#include "jaxlib/handle_pool.h" +#include "jaxlib/gpu/handle_pool.h" #ifdef JAX_GPU_CUDA #include "third_party/gpus/cuda/include/cusolverSp.h" @@ -31,7 +31,7 @@ template <> /*static*/ absl::StatusOr SolverHandlePool::Borrow( gpuStream_t stream) { SolverHandlePool* pool = Instance(); - absl::MutexLock lock(&pool->mu_); + absl::MutexLock lock(pool->mu_); gpusolverDnHandle_t handle; if (pool->handles_[stream].empty()) { JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnCreate(&handle))); @@ -51,7 +51,7 @@ template <> /*static*/ absl::StatusOr SpSolverHandlePool::Borrow(gpuStream_t stream) { SpSolverHandlePool* pool = Instance(); - absl::MutexLock lock(&pool->mu_); + absl::MutexLock lock(pool->mu_); cusolverSpHandle_t handle; if (pool->handles_[stream].empty()) { JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverSpCreate(&handle))); diff --git a/jaxlib/gpu/solver_handle_pool.h b/jaxlib/gpu/solver_handle_pool.h index c46c062b3054..4e369ea85520 100644 --- a/jaxlib/gpu/solver_handle_pool.h +++ b/jaxlib/gpu/solver_handle_pool.h @@ -18,7 +18,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "jaxlib/gpu/vendor.h" -#include "jaxlib/handle_pool.h" +#include "jaxlib/gpu/handle_pool.h" #ifdef JAX_GPU_CUDA #include "third_party/gpus/cuda/include/cusolverSp.h" diff --git a/jaxlib/gpu/solver_interface.cc b/jaxlib/gpu/solver_interface.cc index f439413215b2..8866c7bea2fe 100644 --- a/jaxlib/gpu/solver_interface.cc +++ b/jaxlib/gpu/solver_interface.cc @@ -62,8 +62,8 @@ JAX_GPU_DEFINE_GETRF(gpuDoubleComplex, gpusolverDnZgetrf); JAX_GPU_DEFINE_GETRF_BATCHED(float, gpublasSgetrfBatched); JAX_GPU_DEFINE_GETRF_BATCHED(double, gpublasDgetrfBatched); -JAX_GPU_DEFINE_GETRF_BATCHED(gpublasComplex, gpublasCgetrfBatched); -JAX_GPU_DEFINE_GETRF_BATCHED(gpublasDoubleComplex, gpublasZgetrfBatched); +JAX_GPU_DEFINE_GETRF_BATCHED(gpuComplex, gpublasCgetrfBatched); +JAX_GPU_DEFINE_GETRF_BATCHED(gpuDoubleComplex, gpublasZgetrfBatched); #undef JAX_GPU_DEFINE_GETRF_BATCHED // QR decomposition: geqrf @@ -101,8 +101,8 @@ JAX_GPU_DEFINE_GEQRF(gpuDoubleComplex, gpusolverDnZgeqrf); JAX_GPU_DEFINE_GEQRF_BATCHED(float, gpublasSgeqrfBatched); JAX_GPU_DEFINE_GEQRF_BATCHED(double, gpublasDgeqrfBatched); -JAX_GPU_DEFINE_GEQRF_BATCHED(gpublasComplex, gpublasCgeqrfBatched); -JAX_GPU_DEFINE_GEQRF_BATCHED(gpublasDoubleComplex, gpublasZgeqrfBatched); +JAX_GPU_DEFINE_GEQRF_BATCHED(gpuComplex, gpublasCgeqrfBatched); +JAX_GPU_DEFINE_GEQRF_BATCHED(gpuDoubleComplex, gpublasZgeqrfBatched); #undef JAX_GPU_DEFINE_GEQRF_BATCHED // Householder transformations: orgqr @@ -131,6 +131,46 @@ JAX_GPU_DEFINE_ORGQR(gpuComplex, gpusolverDnCungqr); JAX_GPU_DEFINE_ORGQR(gpuDoubleComplex, gpusolverDnZungqr); #undef JAX_GPU_DEFINE_ORGQR +// Cholesky decomposition: potrf + +#define JAX_GPU_DEFINE_POTRF(Type, Name) \ + template <> \ + absl::StatusOr PotrfBufferSize(gpusolverDnHandle_t handle, \ + gpusolverFillMode_t uplo, int n) { \ + int lwork; \ + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \ + Name##_bufferSize(handle, uplo, n, /*A=*/nullptr, n, &lwork))); \ + return lwork; \ + } \ + \ + template <> \ + absl::Status Potrf(gpusolverDnHandle_t handle, \ + gpusolverFillMode_t uplo, int n, Type *a, \ + Type *workspace, int lwork, int *info) { \ + return JAX_AS_STATUS( \ + Name(handle, uplo, n, a, n, workspace, lwork, info)); \ + } + +JAX_GPU_DEFINE_POTRF(float, gpusolverDnSpotrf); +JAX_GPU_DEFINE_POTRF(double, gpusolverDnDpotrf); +JAX_GPU_DEFINE_POTRF(gpuComplex, gpusolverDnCpotrf); +JAX_GPU_DEFINE_POTRF(gpuDoubleComplex, gpusolverDnZpotrf); +#undef JAX_GPU_DEFINE_POTRF + +#define JAX_GPU_DEFINE_POTRF_BATCHED(Type, Name) \ + template <> \ + absl::Status PotrfBatched(gpusolverDnHandle_t handle, \ + gpusolverFillMode_t uplo, int n, Type **a, \ + int lda, int *info, int batch) { \ + return JAX_AS_STATUS(Name(handle, uplo, n, a, lda, info, batch)); \ + } + +JAX_GPU_DEFINE_POTRF_BATCHED(float, gpusolverDnSpotrfBatched); +JAX_GPU_DEFINE_POTRF_BATCHED(double, gpusolverDnDpotrfBatched); +JAX_GPU_DEFINE_POTRF_BATCHED(gpuComplex, gpusolverDnCpotrfBatched); +JAX_GPU_DEFINE_POTRF_BATCHED(gpuDoubleComplex, gpusolverDnZpotrfBatched); +#undef JAX_GPU_DEFINE_POTRF_BATCHED + // Symmetric (Hermitian) eigendecomposition: // * Jacobi algorithm: syevj/heevj (batches of matrices up to 32) // * QR algorithm: syevd/heevd @@ -232,8 +272,8 @@ JAX_GPU_DEFINE_SYEVD(gpuDoubleComplex, gpusolverDnZheevd); JAX_GPU_DEFINE_SYRK(float, gpublasSsyrk); JAX_GPU_DEFINE_SYRK(double, gpublasDsyrk); -JAX_GPU_DEFINE_SYRK(gpublasComplex, gpublasCsyrk); -JAX_GPU_DEFINE_SYRK(gpublasDoubleComplex, gpublasZsyrk); +JAX_GPU_DEFINE_SYRK(gpuComplex, gpublasCsyrk); +JAX_GPU_DEFINE_SYRK(gpuDoubleComplex, gpublasZsyrk); #undef JAX_GPU_DEFINE_SYRK // Singular Value Decomposition: gesvd @@ -262,8 +302,6 @@ JAX_GPU_DEFINE_GESVD(gpuComplex, gpusolverDnCgesvd); JAX_GPU_DEFINE_GESVD(gpuDoubleComplex, gpusolverDnZgesvd); #undef JAX_GPU_DEFINE_GESVD -#ifdef JAX_GPU_CUDA - #define JAX_GPU_DEFINE_GESVDJ(Type, Name) \ template <> \ absl::StatusOr GesvdjBufferSize( \ @@ -319,6 +357,8 @@ JAX_GPU_DEFINE_GESVDJ_BATCHED(gpuComplex, gpusolverDnCgesvdjBatched); JAX_GPU_DEFINE_GESVDJ_BATCHED(gpuDoubleComplex, gpusolverDnZgesvdjBatched); #undef JAX_GPU_DEFINE_GESVDJ_BATCHED +#ifdef JAX_GPU_CUDA + #define JAX_GPU_DEFINE_CSRLSVQR(Type, Scalar, Name) \ template <> \ absl::Status Csrlsvqr( \ diff --git a/jaxlib/gpu/solver_interface.h b/jaxlib/gpu/solver_interface.h index fa11f3d0e752..fb284ec984dd 100644 --- a/jaxlib/gpu/solver_interface.h +++ b/jaxlib/gpu/solver_interface.h @@ -117,6 +117,25 @@ JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr, OrgqrBufferSize); JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Orgqr); #undef JAX_GPU_SOLVER_Orgqr_ARGS +// Cholesky decomposition: potrf + +#define JAX_GPU_SOLVER_PotrfBufferSize_ARGS(Type, ...) \ + gpusolverDnHandle_t handle, gpusolverFillMode_t uplo, int n +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr, PotrfBufferSize); +#undef JAX_GPU_SOLVER_PotrfBufferSize_ARGS + +#define JAX_GPU_SOLVER_Potrf_ARGS(Type, ...) \ + gpusolverDnHandle_t handle, gpusolverFillMode_t uplo, int n, Type *a, \ + Type *workspace, int lwork, int *info +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Potrf); +#undef JAX_GPU_SOLVER_Potrf_ARGS + +#define JAX_GPU_SOLVER_PotrfBatched_ARGS(Type, ...) \ + gpusolverDnHandle_t handle, gpusolverFillMode_t uplo, int n, Type **a, \ + int lda, int *info, int batch +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, PotrfBatched); +#undef JAX_GPU_SOLVER_PotrfBatched_ARGS + // Symmetric (Hermitian) eigendecomposition: // * Jacobi algorithm: syevj/heevj (batches of matrices up to 32) // * QR algorithm: syevd/heevd @@ -182,18 +201,16 @@ JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr, GesvdBufferSize); JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Gesvd); #undef JAX_GPU_SOLVER_Gesvd_ARGS -#ifdef JAX_GPU_CUDA - #define JAX_GPU_SOLVER_GesvdjBufferSize_ARGS(Type, ...) \ gpusolverDnHandle_t handle, gpusolverEigMode_t job, int econ, int m, int n, \ - gesvdjInfo_t params + gpuGesvdjInfo_t params JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr, GesvdjBufferSize); #undef JAX_GPU_SOLVER_GesvdjBufferSize_ARGS #define JAX_GPU_SOLVER_Gesvdj_ARGS(Type, Real) \ gpusolverDnHandle_t handle, gpusolverEigMode_t job, int econ, int m, int n, \ Type *a, Real *s, Type *u, Type *v, Type *workspace, int lwork, \ - int *info, gesvdjInfo_t params + int *info, gpuGesvdjInfo_t params JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Gesvdj); #undef JAX_GPU_SOLVER_Gesvdj_ARGS @@ -210,6 +227,8 @@ JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr, GesvdjBatchedBufferSize); JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, GesvdjBatched); #undef JAX_GPU_SOLVER_GesvdjBatched_ARGS +#ifdef JAX_GPU_CUDA + #define JAX_GPU_SOLVER_Csrlsvqr_ARGS(Type, ...) \ cusolverSpHandle_t handle, int n, int nnz, cusparseMatDescr_t matdesc, \ const Type *csrValA, const int *csrRowPtrA, const int *csrColIndA, \ diff --git a/jaxlib/gpu/solver_kernels.cc b/jaxlib/gpu/solver_kernels.cc deleted file mode 100644 index 8c22dfcdbca7..000000000000 --- a/jaxlib/gpu/solver_kernels.cc +++ /dev/null @@ -1,978 +0,0 @@ -/* Copyright 2019 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "jaxlib/gpu/solver_kernels.h" - -#include -#include -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "jaxlib/gpu/gpu_kernel_helpers.h" -#include "jaxlib/gpu/solver_handle_pool.h" -#include "jaxlib/gpu/vendor.h" -#include "jaxlib/kernel_helpers.h" -#include "xla/service/custom_call_status.h" - -#ifdef JAX_GPU_CUDA -#include "third_party/gpus/cuda/include/cusolverSp.h" -#endif // JAX_GPU_CUDA - -namespace jax { - -namespace JAX_GPU_NAMESPACE { - -static int SizeOfSolverType(SolverType type) { - switch (type) { - case SolverType::F32: - return sizeof(float); - case SolverType::F64: - return sizeof(double); - case SolverType::C64: - return sizeof(gpuComplex); - case SolverType::C128: - return sizeof(gpuDoubleComplex); - } -} - -// getrf: LU decomposition - -static absl::Status Getrf_(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const GetrfDescriptor& d = **s; - auto h = SolverHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - if (buffers[1] != buffers[0]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( - buffers[1], buffers[0], - SizeOfSolverType(d.type) * static_cast(d.batch) * - static_cast(d.m) * static_cast(d.n), - gpuMemcpyDeviceToDevice, stream))); - } - - int* ipiv = static_cast(buffers[2]); - int* info = static_cast(buffers[3]); - void* workspace = buffers[4]; - switch (d.type) { - case SolverType::F32: { - float* a = static_cast(buffers[1]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnSgetrf( - handle.get(), d.m, d.n, a, d.m, static_cast(workspace), - d.lwork, ipiv, info))); - a += d.m * d.n; - ipiv += std::min(d.m, d.n); - ++info; - } - break; - } - case SolverType::F64: { - double* a = static_cast(buffers[1]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnDgetrf( - handle.get(), d.m, d.n, a, d.m, static_cast(workspace), - d.lwork, ipiv, info))); - a += d.m * d.n; - ipiv += std::min(d.m, d.n); - ++info; - } - break; - } - case SolverType::C64: { - gpuComplex* a = static_cast(buffers[1]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnCgetrf( - handle.get(), d.m, d.n, a, d.m, static_cast(workspace), - d.lwork, ipiv, info))); - a += d.m * d.n; - ipiv += std::min(d.m, d.n); - ++info; - } - break; - } - case SolverType::C128: { - gpuDoubleComplex* a = static_cast(buffers[1]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnZgetrf( - handle.get(), d.m, d.n, a, d.m, - static_cast(workspace), d.lwork, ipiv, info))); - a += d.m * d.n; - ipiv += std::min(d.m, d.n); - ++info; - } - break; - } - } - return absl::OkStatus(); -} - -void Getrf(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = Getrf_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - -// geqrf: QR decomposition - -static absl::Status Geqrf_(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const GeqrfDescriptor& d = **s; - auto h = SolverHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - if (buffers[1] != buffers[0]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( - buffers[1], buffers[0], - SizeOfSolverType(d.type) * static_cast(d.batch) * - static_cast(d.m) * static_cast(d.n), - gpuMemcpyDeviceToDevice, stream))); - } - - int* info = static_cast(buffers[3]); - void* workspace = buffers[4]; - switch (d.type) { - case SolverType::F32: { - float* a = static_cast(buffers[1]); - float* tau = static_cast(buffers[2]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpusolverDnSgeqrf(handle.get(), d.m, d.n, a, d.m, tau, - static_cast(workspace), d.lwork, info))); - a += d.m * d.n; - tau += std::min(d.m, d.n); - ++info; - } - break; - } - case SolverType::F64: { - double* a = static_cast(buffers[1]); - double* tau = static_cast(buffers[2]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpusolverDnDgeqrf(handle.get(), d.m, d.n, a, d.m, tau, - static_cast(workspace), d.lwork, info))); - a += d.m * d.n; - tau += std::min(d.m, d.n); - ++info; - } - break; - } - case SolverType::C64: { - gpuComplex* a = static_cast(buffers[1]); - gpuComplex* tau = static_cast(buffers[2]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnCgeqrf( - handle.get(), d.m, d.n, a, d.m, tau, - static_cast(workspace), d.lwork, info))); - a += d.m * d.n; - tau += std::min(d.m, d.n); - ++info; - } - break; - } - case SolverType::C128: { - gpuDoubleComplex* a = static_cast(buffers[1]); - gpuDoubleComplex* tau = static_cast(buffers[2]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnZgeqrf( - handle.get(), d.m, d.n, a, d.m, tau, - static_cast(workspace), d.lwork, info))); - a += d.m * d.n; - tau += std::min(d.m, d.n); - ++info; - } - break; - } - } - return absl::OkStatus(); -} - -void Geqrf(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = Geqrf_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - -#ifdef JAX_GPU_CUDA - -// csrlsvqr: Linear system solve via Sparse QR - -static absl::Status Csrlsvqr_(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len, - int& singularity) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const CsrlsvqrDescriptor& d = **s; - - // This is the handle to the CUDA session. Gets a cusolverSp handle. - auto h = SpSolverHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - - cusparseMatDescr_t matdesc = nullptr; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateMatDescr(&matdesc))); - JAX_RETURN_IF_ERROR( - JAX_AS_STATUS(cusparseSetMatType(matdesc, CUSPARSE_MATRIX_TYPE_GENERAL))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - cusparseSetMatIndexBase(matdesc, CUSPARSE_INDEX_BASE_ZERO))); - - switch (d.type) { - case SolverType::F32: { - float* csrValA = static_cast(buffers[0]); - int* csrRowPtrA = static_cast(buffers[1]); - int* csrColIndA = static_cast(buffers[2]); - float* b = static_cast(buffers[3]); - float* x = static_cast(buffers[4]); - - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverSpScsrlsvqr( - handle.get(), d.n, d.nnz, matdesc, csrValA, csrRowPtrA, csrColIndA, b, - (float)d.tol, d.reorder, x, &singularity))); - - break; - } - case SolverType::F64: { - double* csrValA = static_cast(buffers[0]); - int* csrRowPtrA = static_cast(buffers[1]); - int* csrColIndA = static_cast(buffers[2]); - double* b = static_cast(buffers[3]); - double* x = static_cast(buffers[4]); - - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverSpDcsrlsvqr( - handle.get(), d.n, d.nnz, matdesc, csrValA, csrRowPtrA, csrColIndA, b, - d.tol, d.reorder, x, &singularity))); - - break; - } - case SolverType::C64: { - gpuComplex* csrValA = static_cast(buffers[0]); - int* csrRowPtrA = static_cast(buffers[1]); - int* csrColIndA = static_cast(buffers[2]); - gpuComplex* b = static_cast(buffers[3]); - gpuComplex* x = static_cast(buffers[4]); - - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverSpCcsrlsvqr( - handle.get(), d.n, d.nnz, matdesc, csrValA, csrRowPtrA, csrColIndA, b, - (float)d.tol, d.reorder, x, &singularity))); - - break; - } - case SolverType::C128: { - gpuDoubleComplex* csrValA = static_cast(buffers[0]); - int* csrRowPtrA = static_cast(buffers[1]); - int* csrColIndA = static_cast(buffers[2]); - gpuDoubleComplex* b = static_cast(buffers[3]); - gpuDoubleComplex* x = static_cast(buffers[4]); - - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverSpZcsrlsvqr( - handle.get(), d.n, d.nnz, matdesc, csrValA, csrRowPtrA, csrColIndA, b, - (float)d.tol, d.reorder, x, &singularity))); - - break; - } - } - - cusparseDestroyMatDescr(matdesc); - return absl::OkStatus(); -} - -void Csrlsvqr(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - // Is >= 0 if A is singular. - int singularity = -1; - - auto s = Csrlsvqr_(stream, buffers, opaque, opaque_len, singularity); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } - - if (singularity >= 0) { - auto s = std::string("Singular matrix in linear solve."); - XlaCustomCallStatusSetFailure(status, s.c_str(), s.length()); - } -} - -#endif // JAX_GPU_CUDA - -// orgqr/ungqr: apply elementary Householder transformations - -static absl::Status Orgqr_(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const OrgqrDescriptor& d = **s; - auto h = SolverHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - if (buffers[2] != buffers[0]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( - buffers[2], buffers[0], - SizeOfSolverType(d.type) * static_cast(d.batch) * - static_cast(d.m) * static_cast(d.n), - gpuMemcpyDeviceToDevice, stream))); - } - - int* info = static_cast(buffers[3]); - void* workspace = buffers[4]; - switch (d.type) { - case SolverType::F32: { - float* a = static_cast(buffers[2]); - float* tau = static_cast(buffers[1]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpusolverDnSorgqr(handle.get(), d.m, d.n, d.k, a, d.m, tau, - static_cast(workspace), d.lwork, info))); - a += d.m * d.n; - tau += d.k; - ++info; - } - break; - } - case SolverType::F64: { - double* a = static_cast(buffers[2]); - double* tau = static_cast(buffers[1]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpusolverDnDorgqr(handle.get(), d.m, d.n, d.k, a, d.m, tau, - static_cast(workspace), d.lwork, info))); - a += d.m * d.n; - tau += d.k; - ++info; - } - break; - } - case SolverType::C64: { - gpuComplex* a = static_cast(buffers[2]); - gpuComplex* tau = static_cast(buffers[1]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnCungqr( - handle.get(), d.m, d.n, d.k, a, d.m, tau, - static_cast(workspace), d.lwork, info))); - a += d.m * d.n; - tau += d.k; - ++info; - } - break; - } - case SolverType::C128: { - gpuDoubleComplex* a = static_cast(buffers[2]); - gpuDoubleComplex* tau = static_cast(buffers[1]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnZungqr( - handle.get(), d.m, d.n, d.k, a, d.m, tau, - static_cast(workspace), d.lwork, info))); - a += d.m * d.n; - tau += d.k; - ++info; - } - break; - } - } - return absl::OkStatus(); -} - -void Orgqr(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = Orgqr_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - -// Symmetric (Hermitian) eigendecomposition, QR algorithm: syevd/heevd - -static absl::Status Syevd_(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const SyevdDescriptor& d = **s; - auto h = SolverHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - - std::int64_t batch = d.batch; - int output_idx = 1; // with static shapes buffers[1] is the first output - if (d.batch == -1) { - // the batch is passed as a second operand - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( - (void*)&batch, reinterpret_cast(buffers[1]), - sizeof(batch), gpuMemcpyDeviceToHost, stream))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuStreamSynchronize(stream))); - output_idx = 2; - } - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( - buffers[output_idx], buffers[0], - SizeOfSolverType(d.type) * batch * static_cast(d.n) * - static_cast(d.n), - gpuMemcpyDeviceToDevice, stream))); - gpusolverEigMode_t jobz = GPUSOLVER_EIG_MODE_VECTOR; - int* info = static_cast(buffers[output_idx + 2]); - void* work = buffers[output_idx + 3]; - switch (d.type) { - case SolverType::F32: { - float* a = static_cast(buffers[output_idx]); - float* w = static_cast(buffers[output_idx + 1]); - for (int i = 0; i < batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpusolverDnSsyevd(handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info))); - a += d.n * d.n; - w += d.n; - ++info; - } - break; - } - case SolverType::F64: { - double* a = static_cast(buffers[output_idx]); - double* w = static_cast(buffers[output_idx + 1]); - for (int i = 0; i < batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpusolverDnDsyevd(handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info))); - a += d.n * d.n; - w += d.n; - ++info; - } - break; - } - case SolverType::C64: { - gpuComplex* a = static_cast(buffers[output_idx]); - float* w = static_cast(buffers[output_idx + 1]); - for (int i = 0; i < batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpusolverDnCheevd(handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info))); - a += d.n * d.n; - w += d.n; - ++info; - } - break; - } - case SolverType::C128: { - gpuDoubleComplex* a = static_cast(buffers[output_idx]); - double* w = static_cast(buffers[output_idx + 1]); - for (int i = 0; i < batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnZheevd( - handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info))); - a += d.n * d.n; - w += d.n; - ++info; - } - break; - } - } - return absl::OkStatus(); -} - -void Syevd(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = Syevd_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - -// Symmetric (Hermitian) eigendecomposition, Jacobi algorithm: syevj/heevj -// Supports batches of matrices up to size 32. - -absl::Status Syevj_(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const SyevjDescriptor& d = **s; - auto h = SolverHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - if (buffers[1] != buffers[0]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( - buffers[1], buffers[0], - SizeOfSolverType(d.type) * static_cast(d.batch) * - static_cast(d.n) * static_cast(d.n), - gpuMemcpyDeviceToDevice, stream))); - } - gpuSyevjInfo_t params; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnCreateSyevjInfo(¶ms))); - std::unique_ptr params_cleanup( - params, [](gpuSyevjInfo_t p) { gpusolverDnDestroySyevjInfo(p); }); - - gpusolverEigMode_t jobz = GPUSOLVER_EIG_MODE_VECTOR; - int* info = static_cast(buffers[3]); - void* work = buffers[4]; - if (d.batch == 1) { - switch (d.type) { - case SolverType::F32: { - float* a = static_cast(buffers[1]); - float* w = static_cast(buffers[2]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnSsyevj( - handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info, params))); - break; - } - case SolverType::F64: { - double* a = static_cast(buffers[1]); - double* w = static_cast(buffers[2]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnDsyevj( - handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info, params))); - break; - } - case SolverType::C64: { - gpuComplex* a = static_cast(buffers[1]); - float* w = static_cast(buffers[2]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnCheevj( - handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info, params))); - break; - } - case SolverType::C128: { - gpuDoubleComplex* a = static_cast(buffers[1]); - double* w = static_cast(buffers[2]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnZheevj( - handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info, params))); - break; - } - } - } else { - switch (d.type) { - case SolverType::F32: { - float* a = static_cast(buffers[1]); - float* w = static_cast(buffers[2]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnSsyevjBatched( - handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info, params, d.batch))); - break; - } - case SolverType::F64: { - double* a = static_cast(buffers[1]); - double* w = static_cast(buffers[2]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnDsyevjBatched( - handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info, params, d.batch))); - break; - } - case SolverType::C64: { - gpuComplex* a = static_cast(buffers[1]); - float* w = static_cast(buffers[2]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnCheevjBatched( - handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info, params, d.batch))); - break; - } - case SolverType::C128: { - gpuDoubleComplex* a = static_cast(buffers[1]); - double* w = static_cast(buffers[2]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpusolverDnZheevjBatched(handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), - d.lwork, info, params, d.batch))); - break; - } - } - } - return absl::OkStatus(); -} - -void Syevj(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = Syevj_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - -// Singular value decomposition using QR algorithm: gesvd - -static absl::Status Gesvd_(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const GesvdDescriptor& d = **s; - auto h = SolverHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - if (buffers[1] != buffers[0]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( - buffers[1], buffers[0], - SizeOfSolverType(d.type) * static_cast(d.batch) * - static_cast(d.m) * static_cast(d.n), - gpuMemcpyDeviceToDevice, stream))); - } - int* info = static_cast(buffers[5]); - void* work = buffers[6]; - int64_t k = d.jobu == 'A' ? d.m : d.n; - switch (d.type) { - case SolverType::F32: { - float* a = static_cast(buffers[1]); - float* s = static_cast(buffers[2]); - float* u = static_cast(buffers[3]); - float* vt = static_cast(buffers[4]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnSgesvd( - handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s, u, d.m, vt, d.n, - static_cast(work), d.lwork, - /*rwork=*/nullptr, info))); - a += d.m * d.n; - s += std::min(d.m, d.n); - u += d.m * k; - vt += d.n * d.n; - ++info; - } - break; - } - case SolverType::F64: { - double* a = static_cast(buffers[1]); - double* s = static_cast(buffers[2]); - double* u = static_cast(buffers[3]); - double* vt = static_cast(buffers[4]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnDgesvd( - handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s, u, d.m, vt, d.n, - static_cast(work), d.lwork, - /*rwork=*/nullptr, info))); - a += d.m * d.n; - s += std::min(d.m, d.n); - u += d.m * k; - vt += d.n * d.n; - ++info; - } - break; - } - case SolverType::C64: { - gpuComplex* a = static_cast(buffers[1]); - float* s = static_cast(buffers[2]); - gpuComplex* u = static_cast(buffers[3]); - gpuComplex* vt = static_cast(buffers[4]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnCgesvd( - handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s, u, d.m, vt, d.n, - static_cast(work), d.lwork, /*rwork=*/nullptr, info))); - a += d.m * d.n; - s += std::min(d.m, d.n); - u += d.m * k; - vt += d.n * d.n; - ++info; - } - break; - } - case SolverType::C128: { - gpuDoubleComplex* a = static_cast(buffers[1]); - double* s = static_cast(buffers[2]); - gpuDoubleComplex* u = static_cast(buffers[3]); - gpuDoubleComplex* vt = static_cast(buffers[4]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnZgesvd( - handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s, u, d.m, vt, d.n, - static_cast(work), d.lwork, - /*rwork=*/nullptr, info))); - a += d.m * d.n; - s += std::min(d.m, d.n); - u += d.m * k; - vt += d.n * d.n; - ++info; - } - break; - } - } - return absl::OkStatus(); -} - -void Gesvd(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = Gesvd_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - -#ifdef JAX_GPU_CUDA - -// Singular value decomposition using Jacobi algorithm: gesvdj - -static absl::Status Gesvdj_(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const GesvdjDescriptor& d = **s; - auto h = SolverHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - if (buffers[1] != buffers[0]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( - buffers[1], buffers[0], - SizeOfSolverType(d.type) * static_cast(d.batch) * - static_cast(d.m) * static_cast(d.n), - gpuMemcpyDeviceToDevice, stream))); - } - int* info = static_cast(buffers[5]); - void* work = buffers[6]; - gesvdjInfo_t params; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnCreateGesvdjInfo(¶ms))); - std::unique_ptr params_cleanup( - params, [](gesvdjInfo* p) { cusolverDnDestroyGesvdjInfo(p); }); - if (d.batch <= 1 || d.m > 32 || d.n > 32 || d.econ) { - int k = std::min(d.m, d.n); - switch (d.type) { - case SolverType::F32: { - float* a = static_cast(buffers[1]); - float* s = static_cast(buffers[2]); - float* u = static_cast(buffers[3]); - float* v = static_cast(buffers[4]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnSgesvdj( - handle.get(), d.jobz, d.econ, d.m, d.n, a, d.m, s, u, d.m, v, d.n, - static_cast(work), d.lwork, info, params))); - a += d.m * d.n; - s += k; - u += d.m * (d.econ ? k : d.m); - v += (d.econ ? k : d.n) * d.n; - ++info; - } - break; - } - case SolverType::F64: { - double* a = static_cast(buffers[1]); - double* s = static_cast(buffers[2]); - double* u = static_cast(buffers[3]); - double* v = static_cast(buffers[4]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnDgesvdj( - handle.get(), d.jobz, d.econ, d.m, d.n, a, d.m, s, u, d.m, v, d.n, - static_cast(work), d.lwork, info, params))); - a += d.m * d.n; - s += k; - u += d.m * (d.econ ? k : d.m); - v += (d.econ ? k : d.n) * d.n; - ++info; - } - break; - } - case SolverType::C64: { - gpuComplex* a = static_cast(buffers[1]); - float* s = static_cast(buffers[2]); - gpuComplex* u = static_cast(buffers[3]); - gpuComplex* v = static_cast(buffers[4]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnCgesvdj( - handle.get(), d.jobz, d.econ, d.m, d.n, a, d.m, s, u, d.m, v, d.n, - static_cast(work), d.lwork, info, params))); - a += d.m * d.n; - s += k; - u += d.m * (d.econ ? k : d.m); - v += (d.econ ? k : d.n) * d.n; - ++info; - } - break; - } - case SolverType::C128: { - gpuDoubleComplex* a = static_cast(buffers[1]); - double* s = static_cast(buffers[2]); - gpuDoubleComplex* u = static_cast(buffers[3]); - gpuDoubleComplex* v = static_cast(buffers[4]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnZgesvdj( - handle.get(), d.jobz, d.econ, d.m, d.n, a, d.m, s, u, d.m, v, d.n, - static_cast(work), d.lwork, info, params))); - a += d.m * d.n; - s += k; - u += d.m * (d.econ ? k : d.m); - v += (d.econ ? k : d.n) * d.n; - ++info; - } - break; - } - } - } else { - switch (d.type) { - case SolverType::F32: { - float* a = static_cast(buffers[1]); - float* s = static_cast(buffers[2]); - float* u = static_cast(buffers[3]); - float* v = static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnSgesvdjBatched( - handle.get(), d.jobz, d.m, d.n, a, d.m, s, u, d.m, v, d.n, - static_cast(work), d.lwork, info, params, d.batch))); - break; - } - case SolverType::F64: { - double* a = static_cast(buffers[1]); - double* s = static_cast(buffers[2]); - double* u = static_cast(buffers[3]); - double* v = static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnDgesvdjBatched( - handle.get(), d.jobz, d.m, d.n, a, d.m, s, u, d.m, v, d.n, - static_cast(work), d.lwork, info, params, d.batch))); - break; - } - case SolverType::C64: { - gpuComplex* a = static_cast(buffers[1]); - float* s = static_cast(buffers[2]); - gpuComplex* u = static_cast(buffers[3]); - gpuComplex* v = static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnCgesvdjBatched( - handle.get(), d.jobz, d.m, d.n, a, d.m, s, u, d.m, v, d.n, - static_cast(work), d.lwork, info, params, d.batch))); - break; - } - case SolverType::C128: { - gpuDoubleComplex* a = static_cast(buffers[1]); - double* s = static_cast(buffers[2]); - gpuDoubleComplex* u = static_cast(buffers[3]); - gpuDoubleComplex* v = static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnZgesvdjBatched( - handle.get(), d.jobz, d.m, d.n, a, d.m, s, u, d.m, v, d.n, - static_cast(work), d.lwork, info, params, - d.batch))); - break; - } - } - } - return absl::OkStatus(); -} - -void Gesvdj(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = Gesvdj_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - -#endif // JAX_GPU_CUDA - -// sytrd/hetrd: symmetric (Hermitian) tridiagonal reduction - -static absl::Status Sytrd_(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const SytrdDescriptor& d = **s; - auto h = SolverHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - if (buffers[1] != buffers[0]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( - buffers[1], buffers[0], - SizeOfSolverType(d.type) * static_cast(d.batch) * - static_cast(d.n) * static_cast(d.lda), - gpuMemcpyDeviceToDevice, stream))); - } - - int* info = static_cast(buffers[5]); - void* workspace = buffers[6]; - switch (d.type) { - case SolverType::F32: { - float* a = static_cast(buffers[1]); - float* d_out = static_cast(buffers[2]); - float* e_out = static_cast(buffers[3]); - float* tau = static_cast(buffers[4]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnSsytrd( - handle.get(), d.uplo, d.n, a, d.lda, d_out, e_out, tau, - static_cast(workspace), d.lwork, info))); - a += d.lda * d.n; - d_out += d.n; - e_out += d.n - 1; - tau += d.n - 1; - ++info; - } - break; - } - case SolverType::F64: { - double* a = static_cast(buffers[1]); - double* d_out = static_cast(buffers[2]); - double* e_out = static_cast(buffers[3]); - double* tau = static_cast(buffers[4]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnDsytrd( - handle.get(), d.uplo, d.n, a, d.lda, d_out, e_out, tau, - static_cast(workspace), d.lwork, info))); - a += d.lda * d.n; - d_out += d.n; - e_out += d.n - 1; - tau += d.n - 1; - ++info; - } - break; - } - case SolverType::C64: { - gpuComplex* a = static_cast(buffers[1]); - float* d_out = static_cast(buffers[2]); - float* e_out = static_cast(buffers[3]); - gpuComplex* tau = static_cast(buffers[4]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnChetrd( - handle.get(), d.uplo, d.n, a, d.lda, d_out, e_out, tau, - static_cast(workspace), d.lwork, info))); - a += d.lda * d.n; - d_out += d.n; - e_out += d.n - 1; - tau += d.n - 1; - ++info; - } - break; - } - case SolverType::C128: { - gpuDoubleComplex* a = static_cast(buffers[1]); - double* d_out = static_cast(buffers[2]); - double* e_out = static_cast(buffers[3]); - gpuDoubleComplex* tau = static_cast(buffers[4]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnZhetrd( - handle.get(), d.uplo, d.n, a, d.lda, d_out, e_out, tau, - static_cast(workspace), d.lwork, info))); - a += d.lda * d.n; - d_out += d.n; - e_out += d.n - 1; - tau += d.n - 1; - ++info; - } - break; - } - } - return absl::OkStatus(); -} - -void Sytrd(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = Sytrd_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - -} // namespace JAX_GPU_NAMESPACE -} // namespace jax diff --git a/jaxlib/gpu/solver_kernels.h b/jaxlib/gpu/solver_kernels.h deleted file mode 100644 index 51082f2fe812..000000000000 --- a/jaxlib/gpu/solver_kernels.h +++ /dev/null @@ -1,148 +0,0 @@ -/* Copyright 2019 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef JAXLIB_CUSOLVER_KERNELS_H_ -#define JAXLIB_CUSOLVER_KERNELS_H_ - -#include - -#include "jaxlib/gpu/vendor.h" -#include "xla/service/custom_call_status.h" - -namespace jax { - -namespace JAX_GPU_NAMESPACE { - -// Set of types known to Cusolver. -enum class SolverType { - F32, - F64, - C64, - C128, -}; - -// getrf: LU decomposition - -struct GetrfDescriptor { - SolverType type; - int batch, m, n, lwork; -}; - -void Getrf(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -// geqrf: QR decomposition - -struct GeqrfDescriptor { - SolverType type; - int batch, m, n, lwork; -}; - -void Geqrf(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -#ifdef JAX_GPU_CUDA - -// csrlsvpr: Linear system solve via Sparse QR - -struct CsrlsvqrDescriptor { - SolverType type; - int n, nnz, reorder; - double tol; -}; - -void Csrlsvqr(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -#endif // JAX_GPU_CUDA - -// orgqr/ungqr: apply elementary Householder transformations - -struct OrgqrDescriptor { - SolverType type; - int batch, m, n, k, lwork; -}; - -void Orgqr(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -// Symmetric (Hermitian) eigendecomposition, QR algorithm: syevd/heevd - -struct SyevdDescriptor { - SolverType type; - gpusolverFillMode_t uplo; - int batch, n; // batch may be -1 in which case it is passed as operand. - int lwork; -}; - -void Syevd(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -// Symmetric (Hermitian) eigendecomposition, Jacobi algorithm: syevj/heevj -// Supports batches of matrices up to size 32. - -struct SyevjDescriptor { - SolverType type; - gpusolverFillMode_t uplo; - int batch, n; - int lwork; -}; - -void Syevj(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -// Singular value decomposition using QR algorithm: gesvd - -struct GesvdDescriptor { - SolverType type; - int batch, m, n; - int lwork; - signed char jobu, jobvt; -}; - -void Gesvd(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -#ifdef JAX_GPU_CUDA - -// Singular value decomposition using Jacobi algorithm: gesvdj - -struct GesvdjDescriptor { - SolverType type; - int batch, m, n; - int lwork; - gpusolverEigMode_t jobz; - int econ; -}; - -void Gesvdj(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); -#endif // JAX_GPU_CUDA - -// sytrd/hetrd: Reduction of a symmetric (Hermitian) matrix to tridiagonal form. -struct SytrdDescriptor { - SolverType type; - gpusolverFillMode_t uplo; - int batch, n, lda, lwork; -}; - -void Sytrd(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - - -} // namespace JAX_GPU_NAMESPACE -} // namespace jax - -#endif // JAXLIB_CUSOLVER_KERNELS_H_ diff --git a/jaxlib/gpu/solver_kernels_ffi.cc b/jaxlib/gpu/solver_kernels_ffi.cc index 79b5dff6ceff..e568a49d58f0 100644 --- a/jaxlib/gpu/solver_kernels_ffi.cc +++ b/jaxlib/gpu/solver_kernels_ffi.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include #include -#include #include #if JAX_GPU_HAVE_64_BIT @@ -51,8 +50,6 @@ namespace JAX_GPU_NAMESPACE { namespace ffi = ::xla::ffi; -#if JAX_GPU_HAVE_64_BIT - // Map an FFI buffer element type to the appropriate GPU solver type. inline absl::StatusOr SolverDataType(ffi::DataType dataType, std::string_view func) { @@ -71,8 +68,6 @@ inline absl::StatusOr SolverDataType(ffi::DataType dataType, } } -#endif - #define SOLVER_DISPATCH_IMPL(impl, ...) \ switch (dataType) { \ case ffi::F32: \ @@ -394,6 +389,112 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(OrgqrFfi, OrgqrDispatch, .Ret() // out ); +// Cholesky decomposition: potrf + +template +ffi::Error PotrfImpl(int64_t batch, int64_t size, gpuStream_t stream, + ffi::ScratchAllocator& scratch, bool lower, + ffi::AnyBuffer a, ffi::Result out, + ffi::Result> info) { + FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow(size)); + FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream)); + + gpusolverFillMode_t uplo = + lower ? GPUSOLVER_FILL_MODE_LOWER : GPUSOLVER_FILL_MODE_UPPER; + + FFI_ASSIGN_OR_RETURN(int lwork, + solver::PotrfBufferSize(handle.get(), uplo, n)); + FFI_ASSIGN_OR_RETURN(auto workspace, + AllocateWorkspace(scratch, lwork, "potrf")); + + auto a_data = static_cast(a.untyped_data()); + auto out_data = static_cast(out->untyped_data()); + auto info_data = info->typed_data(); + if (a_data != out_data) { + JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync( + out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)); + } + + int out_step = n * n; + for (auto i = 0; i < batch; ++i) { + FFI_RETURN_IF_ERROR_STATUS(solver::Potrf(handle.get(), uplo, n, out_data, + workspace, lwork, info_data)); + out_data += out_step; + ++info_data; + } + return ffi::Error::Success(); +} + +template +ffi::Error PotrfBatchedImpl(int64_t batch, int64_t size, gpuStream_t stream, + ffi::ScratchAllocator& scratch, bool lower, + ffi::AnyBuffer a, ffi::Result out, + ffi::Result> info) { + FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow(size)); + FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream)); + FFI_ASSIGN_OR_RETURN(auto batch_ptrs, + AllocateWorkspace(scratch, batch, "batched potrf")); + + gpusolverFillMode_t uplo = + lower ? GPUSOLVER_FILL_MODE_LOWER : GPUSOLVER_FILL_MODE_UPPER; + + auto a_data = a.untyped_data(); + auto out_data = out->untyped_data(); + auto info_data = info->typed_data(); + if (a_data != out_data) { + JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync( + out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)); + } + + MakeBatchPointersAsync(stream, out_data, batch_ptrs, batch, + sizeof(T) * n * n); + JAX_FFI_RETURN_IF_GPU_ERROR(gpuGetLastError()); + + FFI_RETURN_IF_ERROR_STATUS(solver::PotrfBatched( + handle.get(), uplo, n, batch_ptrs, n, info_data, batch)); + + return ffi::Error::Success(); +} + +ffi::Error PotrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, + bool lower, ffi::AnyBuffer a, + ffi::Result out, + ffi::Result> info) { + auto dataType = a.element_type(); + if (dataType != out->element_type()) { + return ffi::Error::InvalidArgument( + "The input and output to potrf must have the same element type"); + } + FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]), + SplitBatch2D(a.dimensions())); + if (rows != cols) { + return ffi::Error::InvalidArgument( + "The input matrix to potrf must be square"); + } + FFI_RETURN_IF_ERROR( + CheckShape(out->dimensions(), {batch, rows, cols}, "out", "potrf")); + FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "potrf")); + if (batch > 1) { + SOLVER_DISPATCH_IMPL(PotrfBatchedImpl, batch, rows, stream, scratch, lower, + a, out, info); + } else { + SOLVER_DISPATCH_IMPL(PotrfImpl, batch, rows, stream, scratch, lower, a, out, + info); + } + return ffi::Error::InvalidArgument(absl::StrFormat( + "Unsupported dtype %s in potrf", absl::FormatStreamed(dataType))); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(PotrfFfi, PotrfDispatch, + ffi::Ffi::Bind() + .Ctx>() + .Ctx() + .Attr("lower") + .Arg() // a + .Ret() // out + .Ret>() // info +); + // Symmetric (Hermitian) eigendecomposition: // * Jacobi algorithm: syevj/heevj (batches of matrices up to 32) // * QR algorithm: syevd/heevd @@ -403,6 +504,16 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(OrgqrFfi, OrgqrDispatch, #if JAX_GPU_HAVE_64_BIT +absl::StatusOr IsSyevBatchedSupported() { + int version; + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverGetVersion(&version))); + // According to + // https://docs.nvidia.com/cuda/archive/12.6.2/cuda-toolkit-release-notes/index.html + // syevBatched is supported since CUDA 12.6.2, where CUSOLVER + // version is 11.7.1. + return version >= 11701; +} + ffi::Error Syevd64Impl(int64_t batch, int64_t n, gpuStream_t stream, ffi::ScratchAllocator& scratch, bool lower, ffi::AnyBuffer a, ffi::Result out, @@ -423,16 +534,39 @@ ffi::Error Syevd64Impl(int64_t batch, int64_t n, gpuStream_t stream, params_cleanup( params, [](gpusolverDnParams_t p) { gpusolverDnDestroyParams(p); }); + int64_t batch_step = 1; + FFI_ASSIGN_OR_RETURN(bool is_batched_syev_supported, + IsSyevBatchedSupported()); + if (is_batched_syev_supported && n > 0) { + int64_t matrix_size = n * n * ffi::ByteWidth(dataType); + batch_step = + std::max(int64_t(1), std::numeric_limits::max() / matrix_size); + if (batch_step >= 32 * 1024) { + batch_step = 32 * 1024; + } + } size_t workspaceInBytesOnDevice, workspaceInBytesOnHost; - JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnXsyevd_bufferSize( - handle.get(), params, jobz, uplo, n, aType, /*a=*/nullptr, n, wType, - /*w=*/nullptr, aType, &workspaceInBytesOnDevice, - &workspaceInBytesOnHost)); + if (is_batched_syev_supported) { + JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnXsyevBatched_bufferSize( + handle.get(), params, jobz, uplo, n, aType, /*a=*/nullptr, n, wType, + /*w=*/nullptr, aType, &workspaceInBytesOnDevice, + &workspaceInBytesOnHost, std::min(batch, batch_step))); + } else { + if (batch_step != 1) { + return ffi::Error( + ffi::ErrorCode::kInternal, + "Syevd64Impl: batch_step != 1 but batched syev is not supported"); + } + JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnXsyevd_bufferSize( + handle.get(), params, jobz, uplo, n, aType, /*a=*/nullptr, n, wType, + /*w=*/nullptr, aType, &workspaceInBytesOnDevice, + &workspaceInBytesOnHost)); + } auto maybe_workspace = scratch.Allocate(workspaceInBytesOnDevice); if (!maybe_workspace.has_value()) { return ffi::Error(ffi::ErrorCode::kResourceExhausted, - "Unable to allocate device workspace for syevd"); + "Unable to allocate device workspace for syevBatched"); } auto workspaceOnDevice = maybe_workspace.value(); auto workspaceOnHost = @@ -447,17 +581,31 @@ ffi::Error Syevd64Impl(int64_t batch, int64_t n, gpuStream_t stream, out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)); } - size_t out_step = n * n * ffi::ByteWidth(dataType); - size_t w_step = n * ffi::ByteWidth(ffi::ToReal(dataType)); - - for (auto i = 0; i < batch; ++i) { - JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnXsyevd( - handle.get(), params, jobz, uplo, n, aType, out_data, n, wType, w_data, - aType, workspaceOnDevice, workspaceInBytesOnDevice, - workspaceOnHost.get(), workspaceInBytesOnHost, info_data)); + size_t out_step = n * n * ffi::ByteWidth(dataType) * batch_step; + size_t w_step = n * ffi::ByteWidth(ffi::ToReal(dataType)) * batch_step; + + for (int64_t i = 0; i < batch; i += batch_step) { + size_t batch_size = static_cast(std::min(batch_step, batch - i)); + if (is_batched_syev_supported && batch_step > 1) { + JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnXsyevBatched( + handle.get(), params, jobz, uplo, n, aType, out_data, n, wType, + w_data, aType, workspaceOnDevice, workspaceInBytesOnDevice, + workspaceOnHost.get(), workspaceInBytesOnHost, info_data, + batch_size)); + } else { + if (batch_step != 1) { + return ffi::Error( + ffi::ErrorCode::kInternal, + "Syevd64Impl: batch_step != 1 but batched syev is not supported"); + } + JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnXsyevd( + handle.get(), params, jobz, uplo, n, aType, out_data, n, wType, + w_data, aType, workspaceOnDevice, workspaceInBytesOnDevice, + workspaceOnHost.get(), workspaceInBytesOnHost, info_data)); + } out_data += out_step; w_data += w_step; - ++info_data; + info_data += batch_step; } return ffi::Error::Success(); @@ -876,8 +1024,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GesvdFfi, GesvdDispatch, .Ret>() // info ); -#ifdef JAX_GPU_CUDA - template ffi::Error GesvdjImpl(int64_t batch, int64_t rows, int64_t cols, gpuStream_t stream, ffi::ScratchAllocator& scratch, @@ -1000,6 +1146,141 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GesvdjFfi, GesvdjDispatch, .Ret>() // info ); +// Singular Value Decomposition: gesvdp (Polar decomposition) + +#ifdef JAX_GPU_CUDA + +ffi::Error GesvdpImpl(int64_t batch, int64_t m, int64_t n, gpuStream_t stream, + ffi::ScratchAllocator& scratch, bool compute_uv, + bool econ, ffi::AnyBuffer a, + ffi::Result out, + ffi::Result s, + ffi::Result u, + ffi::Result v, + ffi::Result> info) { + FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream)); + gpusolverEigMode_t job = + compute_uv ? GPUSOLVER_EIG_MODE_VECTOR : GPUSOLVER_EIG_MODE_NOVECTOR; + auto dataType = a.element_type(); + FFI_ASSIGN_OR_RETURN(auto aType, SolverDataType(dataType, "gesvdp")); + FFI_ASSIGN_OR_RETURN(auto sType, SolverDataType(s->element_type(), "gesvdp")); + + gpusolverDnParams_t params; + JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnCreateParams(¶ms)); + std::unique_ptr + params_cleanup( + params, [](gpusolverDnParams_t p) { gpusolverDnDestroyParams(p); }); + + size_t workspaceInBytesOnDevice, workspaceInBytesOnHost; + JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnXgesvdp_bufferSize( + handle.get(), params, job, econ ? 1 : 0, m, n, aType, + /*a=*/nullptr, m, sType, /*s=*/nullptr, aType, /*u=*/nullptr, m, aType, + /*v=*/nullptr, n, aType, &workspaceInBytesOnDevice, + &workspaceInBytesOnHost)); + + auto maybe_workspace = scratch.Allocate(workspaceInBytesOnDevice); + if (!maybe_workspace.has_value()) { + return ffi::Error(ffi::ErrorCode::kResourceExhausted, + "Unable to allocate device workspace for gesvd"); + } + auto workspaceOnDevice = maybe_workspace.value(); + auto workspaceOnHost = + std::unique_ptr(new char[workspaceInBytesOnHost]); + + const char* a_data = static_cast(a.untyped_data()); + char* out_data = static_cast(out->untyped_data()); + char* s_data = static_cast(s->untyped_data()); + char* u_data = static_cast(u->untyped_data()); + char* v_data = static_cast(v->untyped_data()); + int* info_data = info->typed_data(); + if (a_data != out_data) { + JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync( + out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)); + } + + size_t out_step = m * n * ffi::ByteWidth(dataType); + size_t s_step = std::min(m, n) * ffi::ByteWidth(ffi::ToReal(dataType)); + size_t u_step = 0; + size_t v_step = 0; + if (compute_uv) { + u_step = m * (econ ? std::min(m, n) : m) * ffi::ByteWidth(dataType); + v_step = n * (econ ? std::min(m, n) : n) * ffi::ByteWidth(dataType); + } + + // TODO(phawkins): figure out a useful way to plumb out h_err. + double h_err; + for (auto i = 0; i < batch; ++i) { + JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnXgesvdp( + handle.get(), params, job, econ ? 1 : 0, m, n, aType, out_data, m, + sType, s_data, aType, u_data, m, aType, v_data, n, aType, + workspaceOnDevice, workspaceInBytesOnDevice, workspaceOnHost.get(), + workspaceInBytesOnHost, info_data, &h_err)); + out_data += out_step; + s_data += s_step; + u_data += u_step; + v_data += v_step; + ++info_data; + } + + return ffi::Error::Success(); +} + +ffi::Error GesvdpDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, + bool full_matrices, bool compute_uv, ffi::AnyBuffer a, + ffi::Result out, + ffi::Result s, + ffi::Result u, + ffi::Result v, + ffi::Result> info) { + auto dataType = a.element_type(); + if (out->element_type() != dataType || + s->element_type() != ffi::ToReal(dataType) || + u->element_type() != dataType || v->element_type() != dataType) { + return ffi::Error::InvalidArgument( + "The inputs and outputs to gesvdp must have the same element type"); + } + FFI_ASSIGN_OR_RETURN((auto [batch, m, n]), SplitBatch2D(a.dimensions())); + FFI_RETURN_IF_ERROR( + CheckShape(out->dimensions(), {batch, m, n}, "out", "gesvdp")); + int64_t k = std::min(m, n); + FFI_RETURN_IF_ERROR(CheckShape(s->dimensions(), {batch, k}, "s", "gesvdp")); + if (compute_uv) { + if (full_matrices) { + FFI_RETURN_IF_ERROR( + CheckShape(u->dimensions(), {batch, m, m}, "u", "gesvdp")); + FFI_RETURN_IF_ERROR( + CheckShape(v->dimensions(), {batch, n, n}, "v", "gesvdp")); + } else { + FFI_RETURN_IF_ERROR( + CheckShape(u->dimensions(), {batch, m, k}, "u", "gesvdp")); + FFI_RETURN_IF_ERROR( + CheckShape(v->dimensions(), {batch, n, k}, "v", "gesvdp")); + } + } + FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "gesvdp")); + + return GesvdpImpl(batch, m, n, stream, scratch, compute_uv, !full_matrices, a, + out, s, u, v, info); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(GesvdpFfi, GesvdpDispatch, + ffi::Ffi::Bind() + .Ctx>() + .Ctx() + .Attr("full_matrices") + .Attr("compute_uv") + .Arg() // a + .Ret() // out + .Ret() // s + .Ret() // u + .Ret() // v + .Ret>() // info +); + +#endif // JAX_GPU_CUDA + +#ifdef JAX_GPU_CUDA + // csrlsvqr: Linear system solve via Sparse QR template @@ -1173,6 +1454,132 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(SytrdFfi, SytrdDispatch, .Ret>() // info ); +// General eigenvalue decomposition: geev + +#if JAX_GPU_HAVE_SOLVER_GEEV + +ffi::Error GeevImpl(gpuStream_t stream, ffi::ScratchAllocator scratch, + bool left, bool right, ffi::AnyBuffer a, + ffi::Result out, + ffi::Result w, + ffi::Result vl, + ffi::Result vr, + ffi::Result> info) { + auto dataType = a.element_type(); + if (dataType != vr->element_type()) { + return ffi::Error::InvalidArgument( + "The inputs and outputs to geev must have the same element type"); + } + if (dataType != w->element_type() && + !(dataType == ffi::F32 && w->element_type() == ffi::C64) && + !(dataType == ffi::F64 && w->element_type() == ffi::C128)) { + return ffi::Error::InvalidArgument( + "The eigenvector output type of geev must match the input type or " + "be its complex counterpart."); + } + + FFI_ASSIGN_OR_RETURN((auto [batch, m, n]), SplitBatch2D(a.dimensions())); + if (m != n) { + return ffi::Error::InvalidArgument( + "The input matrix to geev must be square"); + } + int w_len; + if (w->element_type() == ffi::F32 || w->element_type() == ffi::F64) { + w_len = 2 * n; + FFI_RETURN_IF_ERROR( + CheckShape(w->dimensions(), {batch, 2 * n}, "w", "geev")); + } else { + FFI_RETURN_IF_ERROR(CheckShape(w->dimensions(), {batch, n}, "w", "geev")); + w_len = n; + } + if (left) { + FFI_RETURN_IF_ERROR( + CheckShape(vl->dimensions(), {batch, n, n}, "vl", "geev")); + } + if (right) { + FFI_RETURN_IF_ERROR( + CheckShape(vr->dimensions(), {batch, n, n}, "vr", "geev")); + } + FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "geev")); + + FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream)); + FFI_ASSIGN_OR_RETURN(auto aType, SolverDataType(dataType, "geev")); + FFI_ASSIGN_OR_RETURN(auto wType, SolverDataType(w->element_type(), "geev")); + + // At the time of writing, cusolver only supports computing right + // eigenvectors, but has the option for left eigenvectors in its API. Let us + // assume that they intend to add support for left eigenvectors in the future. + gpusolverEigMode_t jobvl = + left ? GPUSOLVER_EIG_MODE_VECTOR : GPUSOLVER_EIG_MODE_NOVECTOR; + gpusolverEigMode_t jobvr = + right ? GPUSOLVER_EIG_MODE_VECTOR : GPUSOLVER_EIG_MODE_NOVECTOR; + + gpusolverDnParams_t params; + JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnCreateParams(¶ms)); + std::unique_ptr + params_cleanup( + params, [](gpusolverDnParams_t p) { gpusolverDnDestroyParams(p); }); + + size_t workspaceInBytesOnDevice, workspaceInBytesOnHost; + JAX_FFI_RETURN_IF_GPU_ERROR(cusolverDnXgeev_bufferSize( + handle.get(), params, jobvl, jobvr, n, aType, /*a=*/nullptr, n, wType, + /*w=*/nullptr, aType, /*vl=*/nullptr, n, aType, /*vr=*/nullptr, n, aType, + &workspaceInBytesOnDevice, &workspaceInBytesOnHost)); + + auto maybe_workspace = scratch.Allocate(workspaceInBytesOnDevice); + if (!maybe_workspace.has_value()) { + return ffi::Error(ffi::ErrorCode::kResourceExhausted, + "Unable to allocate device workspace for syevd"); + } + auto workspaceOnDevice = maybe_workspace.value(); + auto workspaceOnHost = + std::unique_ptr(new char[workspaceInBytesOnHost]); + + const char* a_data = static_cast(a.untyped_data()); + char* out_data = static_cast(out->untyped_data()); + char* w_data = static_cast(w->untyped_data()); + char* vl_data = static_cast(vl->untyped_data()); + char* vr_data = static_cast(vr->untyped_data()); + int* info_data = info->typed_data(); + if (a_data != out_data) { + JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync( + out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)); + } + + size_t out_step = n * n * ffi::ByteWidth(dataType); + size_t w_step = w_len * ffi::ByteWidth(w->element_type()); + + for (auto i = 0; i < batch; ++i) { + JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnXgeev( + handle.get(), params, jobvl, jobvr, n, aType, out_data, n, wType, + w_data, aType, vl_data, n, aType, vr_data, n, aType, workspaceOnDevice, + workspaceInBytesOnDevice, workspaceOnHost.get(), workspaceInBytesOnHost, + info_data)); + out_data += out_step; + w_data += w_step; + vr_data += out_step; + ++info_data; + } + + return ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(GeevFfi, GeevImpl, + ffi::Ffi::Bind() + .Ctx>() + .Ctx() + .Attr("left") + .Attr("right") + .Arg() // a + .Ret() // out + .Ret() // w + .Ret() // vl + .Ret() // vr + .Ret>() // info +); + +#endif // JAX_GPU_HAVE_SOLVER_GEEV + #undef SOLVER_DISPATCH_IMPL #undef SOLVER_BLAS_DISPATCH_IMPL diff --git a/jaxlib/gpu/solver_kernels_ffi.h b/jaxlib/gpu/solver_kernels_ffi.h index 8e90a310e170..176ab9932886 100644 --- a/jaxlib/gpu/solver_kernels_ffi.h +++ b/jaxlib/gpu/solver_kernels_ffi.h @@ -26,23 +26,29 @@ namespace JAX_GPU_NAMESPACE { enum class SyevdAlgorithm : uint8_t { kDefault = 0, - kDivideAndConquer, - kJacobi, + kDivideAndConquer = 1, + kJacobi = 2, }; XLA_FFI_DECLARE_HANDLER_SYMBOL(GetrfFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(GeqrfFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(OrgqrFfi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(PotrfFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(SyevdFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(SyrkFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(GesvdFfi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(GesvdjFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(SytrdFfi); #ifdef JAX_GPU_CUDA -XLA_FFI_DECLARE_HANDLER_SYMBOL(GesvdjFfi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(GesvdpFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(CsrlsvqrFfi); #endif // JAX_GPU_CUDA +#if JAX_GPU_HAVE_SOLVER_GEEV +XLA_FFI_DECLARE_HANDLER_SYMBOL(GeevFfi); +#endif // JAX_GPU_HAVE_SOLVER_GEEV + } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/gpu/sparse.cc b/jaxlib/gpu/sparse.cc index 429c8018dc7a..051bfcddad57 100644 --- a/jaxlib/gpu/sparse.cc +++ b/jaxlib/gpu/sparse.cc @@ -13,19 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include +#include #include #include -#include -#include "nanobind/nanobind.h" -#include "nanobind/stl/pair.h" -#include "absl/base/casts.h" -#include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/strings/str_format.h" -#include "absl/synchronization/mutex.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/pair.h" // IWYU pragma: keep #include "jaxlib/absl_status_casters.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/sparse_kernels.h" @@ -69,9 +64,7 @@ gpuDataType DtypeToCudaDataType(const dtype& np_type) { {{'u', 1}, CUDA_R_8U}, {{'i', 4}, CUDA_R_32I}, {{'u', 4}, CUDA_R_32U}, -#if JAX_GPU_HAVE_SPARSE {{'V', 2}, CUDA_R_16BF}, -#endif // JAX_GPU_HAVE_SPARSE #endif // JAX_GPU_CUDA }); auto it = types->find({np_type.kind(), np_type.itemsize()}); @@ -107,7 +100,6 @@ DenseVecDescriptor BuildDenseVecDescriptor(const dtype& data_dtype, int size) { return DenseVecDescriptor{value_type, size}; } -#if JAX_GPU_HAVE_SPARSE // CsrToDense: Convert CSR matrix to dense matrix // Returns the descriptor for a Sparse matrix. @@ -147,45 +139,6 @@ std::pair BuildCsrToDenseDescriptor(const dtype& data_dtype, return {buffer_size, PackDescriptor(d)}; } -absl::Status CsrToDense_(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const SparseMatDescriptor& d = **s; - auto h = SparseHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - - gpusparseSpMatDescr_t mat_a = 0; - gpusparseDnMatDescr_t mat_b = 0; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpusparseCreateCsr(&mat_a, d.rows, d.cols, d.nnz, - /*csrRowOffsets=*/buffers[2], - /*csrColInd=*/buffers[1], - /*csrValues=*/buffers[0], d.index_type, d.index_type, - GPUSPARSE_INDEX_BASE_ZERO, d.value_type))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat( - &mat_b, d.rows, d.cols, - /*ld=*/d.cols, buffers[3], d.value_type, GPUSPARSE_ORDER_ROW))); - - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpusparseSparseToDense(handle.get(), mat_a, mat_b, - GPUSPARSE_SPARSETODENSE_ALG_DEFAULT, buffers[4]))); - - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b))); - return absl::OkStatus(); -} - -void CsrToDense(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = CsrToDense_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - // CsrFromDense: Convert dense matrix to CSR matrix // Returns the descriptor for a CsrFromDense operation. @@ -222,46 +175,6 @@ std::pair BuildCsrFromDenseDescriptor( return {buffer_size, PackDescriptor(d)}; } -absl::Status CsrFromDense_(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const SparseMatDescriptor& d = **s; - auto h = SparseHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - - gpusparseDnMatDescr_t mat_a = 0; - gpusparseSpMatDescr_t mat_b = 0; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat( - &mat_a, d.rows, d.cols, - /*ld=*/d.cols, buffers[0], d.value_type, GPUSPARSE_ORDER_ROW))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpusparseCreateCsr(&mat_b, d.rows, d.cols, d.nnz, - /*csrRowOffsets=*/buffers[3], - /*csrColInd=*/buffers[2], - /*csrValues=*/buffers[1], d.index_type, d.index_type, - GPUSPARSE_INDEX_BASE_ZERO, d.value_type))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDenseToSparse_analysis( - handle.get(), mat_a, mat_b, GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT, - buffers[4]))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDenseToSparse_convert( - handle.get(), mat_a, mat_b, GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT, - buffers[4]))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_a))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_b))); - return absl::OkStatus(); -} - -void CsrFromDense(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = CsrFromDense_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - // CsrMatvec: Product of CSR matrix and dense vector. // Returns the descriptor for a CsrMatvec operation. @@ -552,46 +465,8 @@ std::pair BuildCooMatmatDescriptor( return {buffer_size, PackDescriptor(CooMatmatDescriptor{A, B, C, op_A})}; } -#endif // if JAX_GPU_HAVE_SPARSE - -nb::bytes BuildGtsv2Descriptor(int b, int m, int n, int ldb) { - return PackDescriptor(Gtsv2Descriptor{b, m, n, ldb}); -} - -template -size_t Gtsv2BufferSize(F f, int m, int n, int ldb) { - auto h = SparseHandlePool::Borrow(/*stream=*/nullptr); - JAX_THROW_IF_ERROR(h.status()); - auto& handle = *h; - size_t size; - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(f(handle.get(), m, n, /*dl=*/nullptr, /*d=*/nullptr, - /*du=*/nullptr, /*B=*/nullptr, ldb, &size))); - return size; -} - -size_t Gtsv2BufferSizeF32(int m, int n, int ldb) { - return Gtsv2BufferSize(gpusparseSgtsv2_bufferSizeExt, m, n, ldb); -} - -size_t Gtsv2BufferSizeF64(int m, int n, int ldb) { - return Gtsv2BufferSize(gpusparseDgtsv2_bufferSizeExt, m, n, ldb); -} - nb::dict Registrations() { nb::dict dict; -#if JAX_GPU_HAVE_SPARSE - dict[JAX_GPU_PREFIX "sparse_csr_todense"] = EncapsulateFunction(CsrToDense); - dict[JAX_GPU_PREFIX "sparse_csr_fromdense"] = - EncapsulateFunction(CsrFromDense); - dict[JAX_GPU_PREFIX "sparse_csr_matvec"] = EncapsulateFunction(CsrMatvec); - dict[JAX_GPU_PREFIX "sparse_csr_matmat"] = EncapsulateFunction(CsrMatmat); - dict[JAX_GPU_PREFIX "sparse_coo_todense"] = EncapsulateFunction(CooToDense); - dict[JAX_GPU_PREFIX "sparse_coo_fromdense"] = - EncapsulateFunction(CooFromDense); - dict[JAX_GPU_PREFIX "sparse_coo_matvec"] = EncapsulateFunction(CooMatvec); - dict[JAX_GPU_PREFIX "sparse_coo_matmat"] = EncapsulateFunction(CooMatmat); - dict[JAX_GPU_PREFIX "sparse_csr_todense_ffi"] = EncapsulateFfiHandler(CsrToDenseFfi); dict[JAX_GPU_PREFIX "sparse_csr_fromdense_ffi"] = @@ -608,22 +483,15 @@ nb::dict Registrations() { EncapsulateFfiHandler(CooMatvecFfi); dict[JAX_GPU_PREFIX "sparse_coo_matmat_ffi"] = EncapsulateFfiHandler(CooMatmatFfi); -#endif - dict[JAX_GPU_PREFIX "sparse_gtsv2_f32"] = EncapsulateFunction(gtsv2_f32); - dict[JAX_GPU_PREFIX "sparse_gtsv2_f64"] = EncapsulateFunction(gtsv2_f64); - dict[JAX_GPU_PREFIX "sparse_gtsv2_f32_ffi"] = - EncapsulateFfiHandler(gtsv2_f32_ffi); - dict[JAX_GPU_PREFIX "sparse_gtsv2_f64_ffi"] = - EncapsulateFfiHandler(gtsv2_f64_ffi); + dict[JAX_GPU_PREFIX "sparse_gtsv2_ffi"] = EncapsulateFfiHandler(kGtsv2); + // TODO(tomhennigan): Add support for gtsv2 complex 32/64. return dict; } NB_MODULE(_sparse, m) { tsl::ImportNumpy(); - m.attr("sparse_supported") = nb::cast(JAX_GPU_HAVE_SPARSE); m.def("registrations", &Registrations); -#if JAX_GPU_HAVE_SPARSE m.def("build_csr_todense_descriptor", &BuildCsrToDenseDescriptor); m.def("build_csr_fromdense_descriptor", &BuildCsrFromDenseDescriptor); m.def("build_csr_matvec_descriptor", &BuildCsrMatvecDescriptor); @@ -632,10 +500,6 @@ NB_MODULE(_sparse, m) { m.def("build_coo_fromdense_descriptor", &BuildCooFromDenseDescriptor); m.def("build_coo_matvec_descriptor", &BuildCooMatvecDescriptor); m.def("build_coo_matmat_descriptor", &BuildCooMatmatDescriptor); -#endif - m.def("gtsv2_f32_buffer_size", &Gtsv2BufferSizeF32); - m.def("gtsv2_f64_buffer_size", &Gtsv2BufferSizeF64); - m.def("build_gtsv2_descriptor", &BuildGtsv2Descriptor); } } // namespace diff --git a/jaxlib/gpu/sparse_kernels.cc b/jaxlib/gpu/sparse_kernels.cc index 5b620a05236d..681c7fc234c3 100644 --- a/jaxlib/gpu/sparse_kernels.cc +++ b/jaxlib/gpu/sparse_kernels.cc @@ -15,22 +15,28 @@ limitations under the License. #include "jaxlib/gpu/sparse_kernels.h" -#include +#include #include -#include -#include -#include +#include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" +#include "jaxlib/ffi_helpers.h" #include "jaxlib/gpu/ffi_wrapper.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" +#include "jaxlib/gpu/handle_pool.h" #include "jaxlib/gpu/vendor.h" -#include "jaxlib/handle_pool.h" #include "jaxlib/kernel_helpers.h" -#include "xla/service/custom_call_status.h" +#include "xla/ffi/api/ffi.h" + +#define JAX_FFI_RETURN_IF_GPU_ERROR(...) \ + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(__VA_ARGS__)) + +namespace ffi = ::xla::ffi; namespace jax { @@ -38,7 +44,7 @@ template <> /*static*/ absl::StatusOr SparseHandlePool::Borrow( gpuStream_t stream) { SparseHandlePool* pool = Instance(); - absl::MutexLock lock(&pool->mu_); + absl::MutexLock lock(pool->mu_); gpusparseHandle_t handle; if (pool->handles_[stream].empty()) { JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreate(&handle))); @@ -65,24 +71,19 @@ absl::StatusOr ConstOne(gpuDataType type) { std::memset(&c, 0, sizeof(c)); switch (type) { #ifdef JAX_GPU_CUDA -#if JAX_GPU_HAVE_SPARSE // TODO(jakevdp): 4I/4U here might break on big endian platforms. case CUDA_R_4I: case CUDA_C_4I: -#endif case CUDA_R_8I: case CUDA_C_8I: c.i8[0] = 1; break; -#if JAX_GPU_HAVE_SPARSE case CUDA_R_4U: case CUDA_C_4U: -#endif case CUDA_R_8U: case CUDA_C_8U: c.u8[0] = 1; break; -#if JAX_GPU_HAVE_SPARSE case CUDA_R_16I: case CUDA_C_16I: c.i16[0] = 1; @@ -91,7 +92,6 @@ absl::StatusOr ConstOne(gpuDataType type) { case CUDA_C_16U: c.u16[0] = 1; break; -#endif case CUDA_R_32I: case CUDA_C_32I: c.i32[0] = 1; @@ -100,7 +100,6 @@ absl::StatusOr ConstOne(gpuDataType type) { case CUDA_C_32U: c.u32[0] = 1; break; -#if JAX_GPU_HAVE_SPARSE case CUDA_R_64I: case CUDA_C_64I: c.i64[0] = 1; @@ -109,7 +108,6 @@ absl::StatusOr ConstOne(gpuDataType type) { case CUDA_C_64U: c.u64[0] = 1; break; -#endif #if JAX_GPU_HAVE_FP8 case CUDA_R_8F_E4M3: c.u8[0] = __nv_cvt_float_to_fp8(1.0f, __NV_NOSAT, __NV_E4M3); @@ -118,12 +116,10 @@ absl::StatusOr ConstOne(gpuDataType type) { c.u8[0] = __nv_cvt_float_to_fp8(1.0f, __NV_NOSAT, __NV_E5M2); break; #endif -#if JAX_GPU_HAVE_SPARSE case CUDA_R_16BF: case CUDA_C_16BF: c.u16[0] = 0b11111110000000; // 1.0 in little-endian bfloat16 break; -#endif #endif // JAX_GPU_CUDA // TODO(rocm): add more data types if new rocm supports them. @@ -147,7 +143,6 @@ absl::StatusOr ConstOne(gpuDataType type) { return c; } -#if JAX_GPU_HAVE_SPARSE // CsrToDense: Convert CSR matrix to dense matrix static absl::Status CsrToDense_(gpuStream_t stream, void** buffers, @@ -182,15 +177,6 @@ static absl::Status CsrToDense_(gpuStream_t stream, void** buffers, JAX_GPU_REGISTER_WRAPPED_LEGACY_KERNEL(CsrToDenseFfi, CsrToDense_); -void CsrToDense(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = CsrToDense_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - // CsrFromDense: Convert dense matrix to CSR matrix static absl::Status CsrFromDense_(gpuStream_t stream, void** buffers, @@ -226,15 +212,6 @@ static absl::Status CsrFromDense_(gpuStream_t stream, void** buffers, JAX_GPU_REGISTER_WRAPPED_LEGACY_KERNEL(CsrFromDenseFfi, CsrFromDense_); -void CsrFromDense(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = CsrFromDense_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - // CsrMatvec: Product of CSR matrix and dense vector. static absl::Status CsrMatvec_(gpuStream_t stream, void** buffers, @@ -285,15 +262,6 @@ static absl::Status CsrMatvec_(gpuStream_t stream, void** buffers, JAX_GPU_REGISTER_WRAPPED_LEGACY_KERNEL(CsrMatvecFfi, CsrMatvec_); -void CsrMatvec(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = CsrMatvec_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - // CsrMatmat: Product of CSR matrix and dense matrix. static absl::Status CsrMatmat_(gpuStream_t stream, void** buffers, @@ -345,15 +313,6 @@ static absl::Status CsrMatmat_(gpuStream_t stream, void** buffers, JAX_GPU_REGISTER_WRAPPED_LEGACY_KERNEL(CsrMatmatFfi, CsrMatmat_); -void CsrMatmat(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = CsrMatmat_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - // CooToDense: Convert COO matrix to dense matrix static absl::Status CooToDense_(gpuStream_t stream, void** buffers, @@ -388,15 +347,6 @@ static absl::Status CooToDense_(gpuStream_t stream, void** buffers, JAX_GPU_REGISTER_WRAPPED_LEGACY_KERNEL(CooToDenseFfi, CooToDense_); -void CooToDense(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = CooToDense_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - // CooFromDense: Convert dense matrix to COO matrix static absl::Status CooFromDense_(gpuStream_t stream, void** buffers, @@ -432,15 +382,6 @@ static absl::Status CooFromDense_(gpuStream_t stream, void** buffers, JAX_GPU_REGISTER_WRAPPED_LEGACY_KERNEL(CooFromDenseFfi, CooFromDense_); -void CooFromDense(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = CooFromDense_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - // CooMatvec: Product of COO matrix and dense vector. static absl::Status CooMatvec_(gpuStream_t stream, void** buffers, @@ -490,15 +431,6 @@ static absl::Status CooMatvec_(gpuStream_t stream, void** buffers, JAX_GPU_REGISTER_WRAPPED_LEGACY_KERNEL(CooMatvecFfi, CooMatvec_); -void CooMatvec(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = CooMatvec_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - // CooMatmat: Product of COO matrix and dense matrix. static absl::Status CooMatmat_(gpuStream_t stream, void** buffers, @@ -558,90 +490,162 @@ static absl::Status CooMatmat_(gpuStream_t stream, void** buffers, JAX_GPU_REGISTER_WRAPPED_LEGACY_KERNEL(CooMatmatFfi, CooMatmat_); -void CooMatmat(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = CooMatmat_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); +template +ffi::Error Gtsv2Impl(BufferSizeF getBufferSize, KernelF kernel, int64_t batch, + int64_t rows, int64_t cols, gpuStream_t stream, + ffi::ScratchAllocator& scratch, ffi::AnyBuffer dl, + ffi::AnyBuffer d, ffi::AnyBuffer du, ffi::AnyBuffer b, + ffi::Result out) { + FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow(rows)); + FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow(cols)); + + FFI_ASSIGN_OR_RETURN(auto handle, SparseHandlePool::Borrow(stream)); + size_t buffer_size_in_bytes; + JAX_FFI_RETURN_IF_GPU_ERROR(getBufferSize(handle.get(), m, n, nullptr, + nullptr, nullptr, nullptr, m, + &buffer_size_in_bytes)); + auto maybe_workspace = scratch.Allocate(buffer_size_in_bytes); + if (!maybe_workspace.has_value()) { + return ffi::Error::Internal("Unable to allocate workspace for gtsv2"); + } + void* workspace = maybe_workspace.value(); + + auto dl_data = static_cast(dl.untyped_data()); + auto d_data = static_cast(d.untyped_data()); + auto du_data = static_cast(du.untyped_data()); + auto b_data = static_cast(b.untyped_data()); + auto out_data = static_cast(out->untyped_data()); + if (b_data != out_data) { + JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync( + out_data, b_data, b.size_bytes(), gpuMemcpyDeviceToDevice, stream)); } -} -#endif // if JAX_GPU_HAVE_SPARSE -template -static absl::Status gtsv2(F computeGtsv2, gpuStream_t stream, void** buffers, - const char* opaque, std::size_t opaque_len) { - auto h = SparseHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; + for (int64_t i = 0; i < batch; ++i) { + JAX_FFI_RETURN_IF_GPU_ERROR(kernel(handle.get(), m, n, dl_data, d_data, + du_data, out_data, m, workspace)); + dl_data += m; + d_data += m; + du_data += m; + out_data += m * n; + } + return ffi::Error::Success(); +} - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const Gtsv2Descriptor& descriptor = **s; - int batch = descriptor.batch; - int m = descriptor.m; - int n = descriptor.n; - int ldb = descriptor.ldb; - - T* dl = static_cast(buffers[0]); - T* d = static_cast(buffers[1]); - T* du = static_cast(buffers[2]); - T* B = static_cast(buffers[3]); - T* X = static_cast(buffers[4]); - void* buffer = static_cast(buffers[5]); - - // The solution X is written in place to B. We need to therefore copy the - // contents of B into the output buffer X and pass that into the kernel as B. - // Once copy insertion is supported for custom call aliasing, we could alias B - // with X and avoid the copy, the code below is written defensively assuming B - // and X might alias, but today we know they will not. - // TODO(b/182906199): Update the comment here once copy insertion is WAI. - if (X != B) { - size_t B_bytes = ldb * n * sizeof(T) * batch; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpuMemcpyAsync(X, B, B_bytes, gpuMemcpyDeviceToDevice, stream))); +template +ffi::Error Gtsv2BatchedImpl(BufferSizeF getBufferSize, KernelF kernel, + int64_t batch, int64_t rows, gpuStream_t stream, + ffi::ScratchAllocator& scratch, ffi::AnyBuffer dl, + ffi::AnyBuffer d, ffi::AnyBuffer du, + ffi::AnyBuffer b, ffi::Result out) { + FFI_ASSIGN_OR_RETURN(auto batch_count, MaybeCastNoOverflow(batch)); + FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow(rows)); + + FFI_ASSIGN_OR_RETURN(auto handle, SparseHandlePool::Borrow(stream)); + size_t buffer_size_in_bytes; + JAX_FFI_RETURN_IF_GPU_ERROR(getBufferSize(handle.get(), m, nullptr, nullptr, + nullptr, nullptr, batch_count, m, + &buffer_size_in_bytes)); + auto maybe_workspace = scratch.Allocate(buffer_size_in_bytes); + if (!maybe_workspace.has_value()) { + return ffi::Error::Internal("Unable to allocate workspace for gtsv2"); } - for (int i = 0; i < batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - computeGtsv2(handle.get(), m, n, dl, d, du, X, ldb, buffer))); - dl += m; - d += m; - du += m; - X += m * n; + void* workspace = maybe_workspace.value(); + + auto dl_data = static_cast(dl.untyped_data()); + auto d_data = static_cast(d.untyped_data()); + auto du_data = static_cast(du.untyped_data()); + auto b_data = static_cast(b.untyped_data()); + auto out_data = static_cast(out->untyped_data()); + if (b_data != out_data) { + JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync( + out_data, b_data, b.size_bytes(), gpuMemcpyDeviceToDevice, stream)); } - return absl::OkStatus(); + + JAX_FFI_RETURN_IF_GPU_ERROR(kernel(handle.get(), m, dl_data, d_data, du_data, + out_data, batch_count, m, workspace)); + return ffi::Error::Success(); } -JAX_GPU_REGISTER_WRAPPED_LEGACY_KERNEL( - gtsv2_f32_ffi, [](gpuStream_t stream, void** buffers, const char* opaque, - std::size_t opaque_len) { - return gtsv2(gpusparseSgtsv2, stream, buffers, opaque, opaque_len); - }); - -JAX_GPU_REGISTER_WRAPPED_LEGACY_KERNEL( - gtsv2_f64_ffi, [](gpuStream_t stream, void** buffers, const char* opaque, - std::size_t opaque_len) { - return gtsv2(gpusparseDgtsv2, stream, buffers, opaque, - opaque_len); - }); - -void gtsv2_f32(gpuStream_t stream, void** buffers, const char* opaque, - std::size_t opaque_len, XlaCustomCallStatus* status) { - auto s = gtsv2(gpusparseSgtsv2, stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); +ffi::Error Gtsv2(gpuStream_t stream, ffi::ScratchAllocator scratch, + ffi::AnyBuffer dl, ffi::AnyBuffer d, ffi::AnyBuffer du, + ffi::AnyBuffer b, ffi::Result out) { + auto dataType = dl.element_type(); + if (dataType != d.element_type() || dataType != du.element_type() || + dataType != b.element_type() || dataType != out->element_type()) { + return ffi::Error::InvalidArgument( + "The inputs and outputs to gtsv2 must have the same element type"); } -} + FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]), + SplitBatch2D(b.dimensions())); + FFI_RETURN_IF_ERROR( + CheckShape(out->dimensions(), {batch, rows, cols}, "out", "gtsv2")); + FFI_RETURN_IF_ERROR( + CheckShape(dl.dimensions(), {batch, rows}, "dl", "gtsv2")); + FFI_RETURN_IF_ERROR(CheckShape(d.dimensions(), {batch, rows}, "d", "gtsv2")); + FFI_RETURN_IF_ERROR( + CheckShape(du.dimensions(), {batch, rows}, "du", "gtsv2")); + if (batch > 1 && cols == 1) { + switch (dataType) { + case ffi::F32: + return Gtsv2BatchedImpl( + gpusparseSgtsv2StridedBatch_bufferSizeExt, + gpusparseSgtsv2StridedBatch, batch, rows, stream, scratch, dl, d, + du, b, out); + case ffi::F64: + return Gtsv2BatchedImpl( + gpusparseDgtsv2StridedBatch_bufferSizeExt, + gpusparseDgtsv2StridedBatch, batch, rows, stream, scratch, dl, d, + du, b, out); + case ffi::C64: + return Gtsv2BatchedImpl( + gpusparseCgtsv2StridedBatch_bufferSizeExt, + gpusparseCgtsv2StridedBatch, batch, rows, stream, scratch, dl, d, + du, b, out); + case ffi::C128: + return Gtsv2BatchedImpl( + gpusparseZgtsv2StridedBatch_bufferSizeExt, + gpusparseZgtsv2StridedBatch, batch, rows, stream, scratch, dl, d, + du, b, out); + default: + break; + } -void gtsv2_f64(gpuStream_t stream, void** buffers, const char* opaque, - std::size_t opaque_len, XlaCustomCallStatus* status) { - auto s = gtsv2(gpusparseDgtsv2, stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); + } else { + switch (dataType) { + case ffi::F32: + return Gtsv2Impl(gpusparseSgtsv2_bufferSizeExt, gpusparseSgtsv2, + batch, rows, cols, stream, scratch, dl, d, du, + b, out); + case ffi::F64: + return Gtsv2Impl(gpusparseDgtsv2_bufferSizeExt, gpusparseDgtsv2, + batch, rows, cols, stream, scratch, dl, d, du, + b, out); + case ffi::C64: + return Gtsv2Impl(gpusparseCgtsv2_bufferSizeExt, + gpusparseCgtsv2, batch, rows, cols, stream, + scratch, dl, d, du, b, out); + case ffi::C128: + return Gtsv2Impl(gpusparseZgtsv2_bufferSizeExt, + gpusparseZgtsv2, batch, rows, cols, + stream, scratch, dl, d, du, b, out); + default: + break; + } } + return ffi::Error::InvalidArgument(absl::StrFormat( + "Unsupported dtype %s in gtsv2", absl::FormatStreamed(dataType))); } +XLA_FFI_DEFINE_HANDLER_SYMBOL(kGtsv2, Gtsv2, + ffi::Ffi::Bind() + .Ctx>() + .Ctx() + .Arg() // dl + .Arg() // d + .Arg() // du + .Arg() // b + .Ret() // out +); + } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/gpu/sparse_kernels.h b/jaxlib/gpu/sparse_kernels.h index 323431812758..bed1a753d097 100644 --- a/jaxlib/gpu/sparse_kernels.h +++ b/jaxlib/gpu/sparse_kernels.h @@ -16,17 +16,12 @@ limitations under the License. #ifndef JAXLIB_GPU_SPARSE_KERNELS_H_ #define JAXLIB_GPU_SPARSE_KERNELS_H_ -#include #include -#include -#include -#include #include "absl/status/statusor.h" +#include "jaxlib/gpu/handle_pool.h" #include "jaxlib/gpu/vendor.h" -#include "jaxlib/handle_pool.h" #include "xla/ffi/api/ffi.h" -#include "xla/service/custom_call_status.h" namespace jax { @@ -74,82 +69,30 @@ struct DenseVecDescriptor { int size; }; -#if JAX_GPU_HAVE_SPARSE -// CsrToDense: Convert CSR matrix to dense matrix - -void CsrToDense(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -// CsrFromDense: Convert dense matrix to CSR matrix - -void CsrFromDense(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -// CsrMatvec: Product of CSR matrix and dense vector. - struct CsrMatvecDescriptor { SparseMatDescriptor A; DenseVecDescriptor x, y; gpusparseOperation_t op; }; -void CsrMatvec(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -// CsrMatmat: Product of CSR matrix and dense matrix. - struct CsrMatmatDescriptor { SparseMatDescriptor A; DenseMatDescriptor B, C; gpusparseOperation_t op_A; }; -void CsrMatmat(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -// CooToDense: Convert COO matrix to dense matrix - -void CooToDense(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -// CooFromDense: Convert dense matrix to COO matrix - -void CooFromDense(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -// CooMatvec: Product of COO matrix and dense vector. - struct CooMatvecDescriptor { SparseMatDescriptor A; DenseVecDescriptor x, y; gpusparseOperation_t op; }; -void CooMatvec(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -// CooMatmat: Product of COO matrix and dense matrix. - struct CooMatmatDescriptor { SparseMatDescriptor A; DenseMatDescriptor B, C; gpusparseOperation_t op_A; }; -void CooMatmat(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); -#endif // JAX_GPU_HAVE_SPARSE - -struct Gtsv2Descriptor { - int batch, m, n, ldb; -}; - -void gtsv2_f32(gpuStream_t stream, void** buffers, const char* opaque, - std::size_t opaque_len, XlaCustomCallStatus* status); - -void gtsv2_f64(gpuStream_t stream, void** buffers, const char* opaque, - std::size_t opaque_len, XlaCustomCallStatus* status); - XLA_FFI_DECLARE_HANDLER_SYMBOL(CsrToDenseFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(CsrFromDenseFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(CsrMatvecFfi); @@ -158,8 +101,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(CooToDenseFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(CooFromDenseFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(CooMatvecFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(CooMatmatFfi); -XLA_FFI_DECLARE_HANDLER_SYMBOL(gtsv2_f32_ffi); -XLA_FFI_DECLARE_HANDLER_SYMBOL(gtsv2_f64_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(kGtsv2); } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/gpu/triton.cc b/jaxlib/gpu/triton.cc index 500034af3ebb..a1bb10ed510f 100644 --- a/jaxlib/gpu/triton.cc +++ b/jaxlib/gpu/triton.cc @@ -1,17 +1,34 @@ +/* Copyright 2022 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include #include -#include #include #include #include +#include #include -#include "nanobind/nanobind.h" -#include "nanobind/stl/pair.h" -#include "nanobind/stl/string.h" -#include "nanobind/stl/string_view.h" -#include "nanobind/stl/tuple.h" -#include "nanobind/stl/vector.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/tuple.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep #include "jaxlib/absl_status_casters.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/triton.pb.h" @@ -28,8 +45,8 @@ namespace jax::JAX_GPU_NAMESPACE { NB_MODULE(_triton, m) { nb::class_(m, "TritonKernel") - .def(nb::init()); + .def(nb::init()); nb::class_(m, "TritonParameter"); @@ -132,27 +149,25 @@ NB_MODULE(_triton, m) { return major * 10 + minor; })); - m.def( - "get_arch_details", - ValueOrThrowWrapper([](int device) -> absl::StatusOr { + m.def("get_arch_details", + ValueOrThrowWrapper([](int device) -> absl::StatusOr { #ifdef JAX_GPU_HIP - hipDeviceProp_t prop; - hipGetDeviceProperties(&prop, 0); - return prop.gcnArchName; + hipDeviceProp_t prop; + hipGetDeviceProperties(&prop, 0); + return prop.gcnArchName; #else - return absl::UnimplementedError("Not a HIP GPU"); + return absl::UnimplementedError("Not a HIP GPU"); #endif - })); + })); m.def("get_serialized_metadata", - ValueOrThrowWrapper( - [](nb::bytes opaque) -> absl::StatusOr { - JAX_ASSIGN_OR_RETURN( - std::string metadata, - GetTritonKernelCallSerializedMetadata( - absl::string_view(opaque.c_str(), opaque.size()))); - return nb::bytes(metadata.c_str(), metadata.size()); - })); + ValueOrThrowWrapper([](nb::bytes opaque) -> absl::StatusOr { + JAX_ASSIGN_OR_RETURN( + std::string metadata, + GetTritonKernelCallSerializedMetadata( + std::string_view(opaque.c_str(), opaque.size()))); + return nb::bytes(metadata.c_str(), metadata.size()); + })); } } // namespace jax::JAX_GPU_NAMESPACE diff --git a/jaxlib/gpu/triton.proto b/jaxlib/gpu/triton.proto index 786b07afbdbe..2ec1f12e3301 100644 --- a/jaxlib/gpu/triton.proto +++ b/jaxlib/gpu/triton.proto @@ -5,13 +5,13 @@ package jax_triton; message TritonKernel { string kernel_name = 1; // Kernel function name within module. uint32 num_warps = 2; + optional uint32 num_ctas = 10; uint32 shared_mem_bytes = 3; string ptx = 4; string ttir = 5; uint32 compute_capability = 6; - uint32 cluster_dim_0 = 7; - uint32 cluster_dim_1 = 8; - uint32 cluster_dim_2 = 9; + + reserved 7, 8, 9; } message TritonKernelCall { diff --git a/jaxlib/gpu/triton_kernels.cc b/jaxlib/gpu/triton_kernels.cc index 22397ff908bc..0ad86f522d9d 100644 --- a/jaxlib/gpu/triton_kernels.cc +++ b/jaxlib/gpu/triton_kernels.cc @@ -1,3 +1,18 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include "jaxlib/gpu/triton_kernels.h" #include @@ -34,10 +49,12 @@ #ifdef JAX_GPU_CUDA #include "xla/stream_executor/cuda/cuda_asm_compiler.h" +#include "xla/stream_executor/cuda/cuda_compute_capability.h" #endif // JAX_GPU_CUDA #ifdef JAX_GPU_HIP -#include "tsl/platform/env.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" #endif // JAX_GPU_HIP #define GPU_RETURN_IF_ERROR(expr) JAX_RETURN_IF_ERROR(JAX_AS_STATUS(expr)) @@ -94,7 +111,7 @@ absl::StatusOr GetModuleImage(std::string kernel_name, *new absl::flat_hash_map> ABSL_GUARDED_BY(mutex); - absl::MutexLock lock(&mutex); + absl::MutexLock lock(mutex); auto it = module_images.find(key); if (it != module_images.end()) return it->second.get(); @@ -108,9 +125,17 @@ absl::StatusOr GetModuleImage(std::string kernel_name, // TODO(cjfj): Support `TRITON_PTXAS_PATH` environment variable? int cc_major = compute_capability / 10; int cc_minor = compute_capability % 10; + + bool has_accelerated_features = cc_major >= 9; + using FeatureExtension = + stream_executor::CudaComputeCapability::FeatureExtension; + const stream_executor::CudaComputeCapability cc( + cc_major, cc_minor, + has_accelerated_features ? FeatureExtension::kAcceleratedFeatures + : FeatureExtension::kNone); JAX_ASSIGN_OR_RETURN( std::vector module_image, - stream_executor::CompileGpuAsm(cc_major, cc_minor, ptx.data(), + stream_executor::CompileGpuAsm(cc, std::string(ptx), stream_executor::GpuAsmOpts{})); #endif @@ -141,7 +166,7 @@ absl::StatusOr Benchmark(gpuStream_t stream, KernelCall& kernel_call, return elapsed_ms; } -absl::StatusOr GetKernelCall(absl::string_view opaque, +absl::StatusOr GetKernelCall(std::string_view opaque, gpuStream_t stream, void** buffers) { static absl::Mutex mutex; static auto& kernel_calls = @@ -151,7 +176,7 @@ absl::StatusOr GetKernelCall(absl::string_view opaque, { // Fast path uses reader lock (as hash map look-up is relatively slow). - absl::ReaderMutexLock lock(&mutex); + absl::ReaderMutexLock lock(mutex); auto it = kernel_calls.find(opaque); if (ABSL_PREDICT_TRUE(it != kernel_calls.end())) { JAX_RETURN_IF_ERROR(it->second.status()); @@ -163,7 +188,7 @@ absl::StatusOr GetKernelCall(absl::string_view opaque, return absl::InvalidArgumentError("Opaque data is empty."); } - absl::MutexLock lock(&mutex); + absl::MutexLock lock(mutex); auto get_kernel_call = [&]() -> absl::StatusOr> { // The opaque data is a zlib compressed protobuf. @@ -212,7 +237,7 @@ class ModuleImage { shared_mem_bytes_(shared_mem_bytes) {} absl::StatusOr GetFunctionForContext(gpuContext_t context) { - absl::MutexLock lock(&mutex_); + absl::MutexLock lock(mutex_); auto it = functions_.find(context); if (ABSL_PREDICT_TRUE(it != functions_.end())) { return it->second; @@ -290,17 +315,16 @@ class ModuleImage { ABSL_GUARDED_BY(mutex_); }; -Kernel::Kernel(std::string kernel_name, uint32_t num_warps, +Kernel::Kernel(std::string kernel_name, uint32_t num_warps, uint32_t num_ctas, uint32_t shared_mem_bytes, std::string ptx, std::string ttir, - int compute_capability, uint32_t cluster_dim_0, - uint32_t cluster_dim_1, uint32_t cluster_dim_2) + int compute_capability) : kernel_name_(std::move(kernel_name)), block_dim_x_(num_warps * kNumThreadsPerWarp), + num_ctas_(num_ctas), shared_mem_bytes_(shared_mem_bytes), ptx_(std::move(ptx)), ttir_(std::move(ttir)), - compute_capability_(compute_capability), - cluster_dims_{cluster_dim_0, cluster_dim_1, cluster_dim_2} {} + compute_capability_(compute_capability) {} absl::Status Kernel::Launch(gpuStream_t stream, uint32_t grid[3], void** params) { @@ -337,9 +361,7 @@ absl::Status Kernel::Launch(gpuStream_t stream, uint32_t grid[3], JAX_ASSIGN_OR_RETURN(gpuFunction_t kernel, module_image_->GetFunctionForContext(context)); - const uint32_t cluster_size = - cluster_dims_[0] * cluster_dims_[1] * cluster_dims_[2]; - if (cluster_size <= 1) { + if (num_ctas_ == 1) { return JAX_AS_STATUS(gpuLaunchKernel( kernel, grid[0], grid[1], grid[2], block_dim_x_, /*blockDimY=*/1, /*blockDimZ=*/1, shared_mem_bytes_, stream, params, @@ -347,16 +369,16 @@ absl::Status Kernel::Launch(gpuStream_t stream, uint32_t grid[3], } CUlaunchAttribute launch_attrs[2]; launch_attrs[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; - launch_attrs[0].value.clusterDim.x = cluster_dims_[0]; - launch_attrs[0].value.clusterDim.y = cluster_dims_[1]; - launch_attrs[0].value.clusterDim.z = cluster_dims_[2]; + launch_attrs[0].value.clusterDim.x = num_ctas_; + launch_attrs[0].value.clusterDim.y = 1; + launch_attrs[0].value.clusterDim.z = 1; launch_attrs[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE; launch_attrs[1].value.clusterSchedulingPolicyPreference = CU_CLUSTER_SCHEDULING_POLICY_SPREAD; CUlaunchConfig launch_config = { - /*gridDimX=*/grid[0] * cluster_dims_[0], - /*gridDimY=*/grid[1] * cluster_dims_[1], - /*gridDimZ=*/grid[2] * cluster_dims_[2], + /*gridDimX=*/grid[0] * num_ctas_, + /*gridDimY=*/grid[1], + /*gridDimZ=*/grid[2], /*blockDimX=*/block_dim_x_, /*blockDimY=*/1, /*blockDimZ=*/1, @@ -371,23 +393,23 @@ absl::Status Kernel::Launch(gpuStream_t stream, uint32_t grid[3], } /*static*/ Kernel Kernel::FromProto(const jax_triton::TritonKernel& proto) { - return Kernel(proto.kernel_name(), proto.num_warps(), + // Use 1 as default value if not specified in already serialized kernels. + int num_ctas = proto.has_num_ctas() ? proto.num_ctas() : 1; + + return Kernel(proto.kernel_name(), proto.num_warps(), num_ctas, proto.shared_mem_bytes(), proto.ptx(), proto.ttir(), - proto.compute_capability(), proto.cluster_dim_0(), - proto.cluster_dim_1(), proto.cluster_dim_2()); + proto.compute_capability()); } jax_triton::TritonKernel Kernel::ToProto() const { jax_triton::TritonKernel proto; proto.set_kernel_name(kernel_name_); proto.set_num_warps(block_dim_x_ / kNumThreadsPerWarp); + proto.set_num_ctas(num_ctas_); proto.set_shared_mem_bytes(shared_mem_bytes_); proto.set_ptx(ptx_); proto.set_ttir(ttir_); proto.set_compute_capability(compute_capability_); - proto.set_cluster_dim_0(cluster_dims_[0]); - proto.set_cluster_dim_1(cluster_dims_[1]); - proto.set_cluster_dim_2(cluster_dims_[2]); return proto; } @@ -499,8 +521,10 @@ absl::Status KernelCall::Launch(gpuStream_t stream, void** buffers) { // pointer. // TODO: b/381242007 - Allocate a proper buffer if we want to use // device-side TMA APIs. - void* scratch_ptr = nullptr; // Alive until kernel_.Launch returns. - params.push_back(&scratch_ptr); + void* tma_descriptor_buffer = nullptr; // Alive until kernel_.Launch returns. + params.push_back(&tma_descriptor_buffer); + void* profiling_buffer = nullptr; // Alive until kernel_.Launch returns. + params.push_back(&profiling_buffer); return kernel_.Launch(stream, grid_, params.data()); } @@ -688,11 +712,11 @@ void TritonKernelCall(gpuStream_t stream, void** buffers, const char* opaque, absl::Status result = [=] { JAX_ASSIGN_OR_RETURN( KernelCall * kernel_call, - GetKernelCall(absl::string_view(opaque, opaque_len), stream, buffers)); + GetKernelCall(std::string_view(opaque, opaque_len), stream, buffers)); return kernel_call->Launch(stream, buffers); }(); if (!result.ok()) { - absl::string_view msg = result.message(); + std::string_view msg = result.message(); XlaCustomCallStatusSetFailure(status, msg.data(), msg.length()); } } diff --git a/jaxlib/gpu/triton_kernels.h b/jaxlib/gpu/triton_kernels.h index c3457093c4f8..08320a104183 100644 --- a/jaxlib/gpu/triton_kernels.h +++ b/jaxlib/gpu/triton_kernels.h @@ -1,8 +1,23 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #ifndef JAXLIB_GPU_TRITON_H_ #define JAXLIB_GPU_TRITON_H_ +#include #include -#include #include #include #include @@ -10,7 +25,6 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/triton.pb.h" #include "jaxlib/gpu/vendor.h" #include "xla/service/custom_call_status.h" @@ -24,10 +38,9 @@ class ModuleImage; class Kernel { public: - Kernel(std::string kernel_name, uint32_t num_warps, uint32_t shared_mem_bytes, - std::string ptx, std::string ttir, int compute_capability, - uint32_t cluster_dim_0, uint32_t cluster_dim_1, - uint32_t cluster_dim_2); + Kernel(std::string kernel_name, uint32_t num_warps, uint32_t num_ctas, + uint32_t shared_mem_bytes, std::string ptx, std::string ttir, + int compute_capability); absl::Status Launch(gpuStream_t stream, uint32_t grid[3], void** params); @@ -40,11 +53,11 @@ class Kernel { private: std::string kernel_name_; uint32_t block_dim_x_; + uint32_t num_ctas_; uint32_t shared_mem_bytes_; std::string ptx_; std::string ttir_; int compute_capability_; - uint32_t cluster_dims_[3]; ModuleImage* module_image_ = nullptr; }; @@ -93,8 +106,7 @@ class AutotunedKernelCall { AutotunedKernelCall( std::string name, std::vector configs, - std::vector> input_output_aliases); + std::vector> input_output_aliases); static absl::StatusOr Autotune(AutotunedKernelCall kernel_call, gpuStream_t stream, diff --git a/jaxlib/gpu/triton_utils.cc b/jaxlib/gpu/triton_utils.cc index b3a0779118de..eb8eb5bcb621 100644 --- a/jaxlib/gpu/triton_utils.cc +++ b/jaxlib/gpu/triton_utils.cc @@ -1,18 +1,34 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include "jaxlib/gpu/triton_utils.h" #include #include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/string_view.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/triton.pb.h" +#include "jaxlib/gpu/vendor.h" namespace jax::JAX_GPU_NAMESPACE { -absl::StatusOr ZlibUncompress(absl::string_view compressed) { +absl::StatusOr ZlibUncompress(std::string_view compressed) { std::string data; uLongf dest_len = 5 * compressed.size(); while (true) { @@ -33,7 +49,7 @@ absl::StatusOr ZlibUncompress(absl::string_view compressed) { return data; } -absl::StatusOr GetTritonKernelCallName(absl::string_view opaque) { +absl::StatusOr GetTritonKernelCallName(std::string_view opaque) { JAX_ASSIGN_OR_RETURN(std::string serialized, ZlibUncompress(opaque)); jax_triton::TritonAnyKernelCall proto; if (!proto.ParseFromString(serialized)) { @@ -43,7 +59,7 @@ absl::StatusOr GetTritonKernelCallName(absl::string_view opaque) { } absl::StatusOr GetTritonKernelCallSerializedMetadata( - absl::string_view opaque) { + std::string_view opaque) { JAX_ASSIGN_OR_RETURN(std::string serialized, ZlibUncompress(opaque)); jax_triton::TritonAnyKernelCall proto; if (!proto.ParseFromString(serialized)) { diff --git a/jaxlib/gpu/triton_utils.h b/jaxlib/gpu/triton_utils.h index 0c286391e296..9645f7822b61 100644 --- a/jaxlib/gpu/triton_utils.h +++ b/jaxlib/gpu/triton_utils.h @@ -1,19 +1,33 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #ifndef JAXLIB_GPU_TRITON_UTILS_H_ #define JAXLIB_GPU_TRITON_UTILS_H_ #include +#include -#include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/string_view.h" #include "jaxlib/gpu/vendor.h" namespace jax::JAX_GPU_NAMESPACE { -absl::StatusOr ZlibUncompress(absl::string_view compressed); -absl::StatusOr GetTritonKernelCallName(absl::string_view opaque); +absl::StatusOr ZlibUncompress(std::string_view compressed); +absl::StatusOr GetTritonKernelCallName(std::string_view opaque); absl::StatusOr GetTritonKernelCallSerializedMetadata( - absl::string_view opaque); + std::string_view opaque); } // namespace jax::JAX_GPU_NAMESPACE diff --git a/jaxlib/gpu/vendor.h b/jaxlib/gpu/vendor.h index 7334d4690b59..becfb049de88 100644 --- a/jaxlib/gpu/vendor.h +++ b/jaxlib/gpu/vendor.h @@ -20,6 +20,7 @@ limitations under the License. #ifndef JAXLIB_GPU_VENDOR_H_ #define JAXLIB_GPU_VENDOR_H_ +#include #if defined(JAX_GPU_CUDA) // IWYU pragma: begin_exports @@ -29,7 +30,7 @@ limitations under the License. #include "third_party/gpus/cuda/include/cublas_v2.h" #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/cuda_fp8.h" -#include "third_party/gpus/cuda/include/cuda_runtime_api.h" +#include "cuda_runtime_api.h" #include "third_party/gpus/cuda/include/cufft.h" #include "third_party/gpus/cuda/include/cusolverDn.h" #include "third_party/gpus/cuda/include/cusolver_common.h" @@ -37,17 +38,16 @@ limitations under the License. #include "third_party/gpus/cudnn/cudnn.h" // IWYU pragma: end_exports -#if CUDA_VERSION < 11080 -#error "JAX requires CUDA 11.8 or newer." -#endif // CUDA_VERSION < 11080 - -#define JAX_GPU_HAVE_SPARSE 1 +#if CUDA_VERSION < 12000 +#error "JAX requires CUDA 12.0 or newer." +#endif // CUDA_VERSION < 12000 // CUDA-11.8 introduces FP8 E4M3/E5M2 types. #define JAX_GPU_HAVE_FP8 1 #define JAX_GPU_NAMESPACE cuda #define JAX_GPU_PREFIX "cu" +#define JAX_GPU_PLUGIN_NAME "cuda" typedef cuComplex gpuComplex; typedef cuDoubleComplex gpuDoubleComplex; @@ -150,7 +150,8 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t; #define GPUDNN_STATUS_SUCCESS CUDNN_STATUS_SUCCESS #define GPUDNN_WGRAD_MODE_ADD CUDNN_WGRAD_MODE_ADD #define GPUDNN_RNN_ALGO_STANDARD CUDNN_RNN_ALGO_STANDARD -#define GPUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED +#define GPUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED \ + CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED #define GPUDNN_RNN_PADDED_IO_ENABLED CUDNN_RNN_PADDED_IO_ENABLED #define GPUDNN_DEFAULT_MATH CUDNN_DEFAULT_MATH #define GPUDNN_FMA_MATH CUDNN_FMA_MATH @@ -196,6 +197,18 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t; #define gpusolverDnDorgqr_bufferSize cusolverDnDorgqr_bufferSize #define gpusolverDnCungqr_bufferSize cusolverDnCungqr_bufferSize #define gpusolverDnZungqr_bufferSize cusolverDnZungqr_bufferSize +#define gpusolverDnSpotrf cusolverDnSpotrf +#define gpusolverDnDpotrf cusolverDnDpotrf +#define gpusolverDnCpotrf cusolverDnCpotrf +#define gpusolverDnZpotrf cusolverDnZpotrf +#define gpusolverDnSpotrf_bufferSize cusolverDnSpotrf_bufferSize +#define gpusolverDnDpotrf_bufferSize cusolverDnDpotrf_bufferSize +#define gpusolverDnCpotrf_bufferSize cusolverDnCpotrf_bufferSize +#define gpusolverDnZpotrf_bufferSize cusolverDnZpotrf_bufferSize +#define gpusolverDnSpotrfBatched cusolverDnSpotrfBatched +#define gpusolverDnDpotrfBatched cusolverDnDpotrfBatched +#define gpusolverDnCpotrfBatched cusolverDnCpotrfBatched +#define gpusolverDnZpotrfBatched cusolverDnZpotrfBatched #define gpusolverDnSsyevd cusolverDnSsyevd #define gpusolverDnDsyevd cusolverDnDsyevd #define gpusolverDnCheevd cusolverDnCheevd @@ -287,10 +300,28 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t; #define gpusparseSpMM_bufferSize cusparseSpMM_bufferSize #define gpusparseSpMV cusparseSpMV #define gpusparseSpMV_bufferSize cusparseSpMV_bufferSize + #define gpusparseSgtsv2 cusparseSgtsv2 #define gpusparseDgtsv2 cusparseDgtsv2 +#define gpusparseCgtsv2 cusparseCgtsv2 +#define gpusparseZgtsv2 cusparseZgtsv2 #define gpusparseSgtsv2_bufferSizeExt cusparseSgtsv2_bufferSizeExt #define gpusparseDgtsv2_bufferSizeExt cusparseDgtsv2_bufferSizeExt +#define gpusparseCgtsv2_bufferSizeExt cusparseCgtsv2_bufferSizeExt +#define gpusparseZgtsv2_bufferSizeExt cusparseZgtsv2_bufferSizeExt + +#define gpusparseSgtsv2StridedBatch_bufferSizeExt \ + cusparseSgtsv2StridedBatch_bufferSizeExt +#define gpusparseDgtsv2StridedBatch_bufferSizeExt \ + cusparseDgtsv2StridedBatch_bufferSizeExt +#define gpusparseCgtsv2StridedBatch_bufferSizeExt \ + cusparseCgtsv2StridedBatch_bufferSizeExt +#define gpusparseZgtsv2StridedBatch_bufferSizeExt \ + cusparseZgtsv2StridedBatch_bufferSizeExt +#define gpusparseSgtsv2StridedBatch cusparseSgtsv2StridedBatch +#define gpusparseDgtsv2StridedBatch cusparseDgtsv2StridedBatch +#define gpusparseCgtsv2StridedBatch cusparseCgtsv2StridedBatch +#define gpusparseZgtsv2StridedBatch cusparseZgtsv2StridedBatch #define GPUSPARSE_INDEX_16U CUSPARSE_INDEX_16U #define GPUSPARSE_INDEX_32I CUSPARSE_INDEX_32I @@ -389,11 +420,26 @@ typedef cusolverDnParams_t gpusolverDnParams_t; #define gpusolverDnCreateParams cusolverDnCreateParams #define gpusolverDnDestroyParams cusolverDnDestroyParams +#define gpusolverGetVersion cusolverGetVersion + #define gpusolverDnXsyevd_bufferSize cusolverDnXsyevd_bufferSize #define gpusolverDnXsyevd cusolverDnXsyevd +#define gpusolverDnXsyevBatched_bufferSize cusolverDnXsyevBatched_bufferSize +#define gpusolverDnXsyevBatched cusolverDnXsyevBatched #define gpusolverDnXgesvd_bufferSize cusolverDnXgesvd_bufferSize #define gpusolverDnXgesvd cusolverDnXgesvd +#define gpusolverDnXgesvdp_bufferSize cusolverDnXgesvdp_bufferSize +#define gpusolverDnXgesvdp cusolverDnXgesvdp + +#if CUDA_VERSION >= 12060 +#define JAX_GPU_HAVE_SOLVER_GEEV 1 +#define gpusolverDnXgeev_bufferSize cusolverDnXgeev_bufferSize +#define gpusolverDnXgeev cusolverDnXgeev +#else +#define JAX_GPU_HAVE_SOLVER_GEEV 0 +#endif // CUDA_VERSION >= 12060 + namespace jax::JAX_GPU_NAMESPACE { namespace { constexpr uint32_t kNumThreadsPerWarp = 32; @@ -402,6 +448,8 @@ constexpr uint32_t kNumThreadsPerWarp = 32; #elif defined(JAX_GPU_HIP) +#define HIPBLAS_V2 1 + // IWYU pragma: begin_exports #include "rocm/include/hip/hip_cooperative_groups.h" #include "rocm/include/hip/hip_runtime_api.h" @@ -409,27 +457,28 @@ constexpr uint32_t kNumThreadsPerWarp = 32; #include "rocm/include/hipsolver/hipsolver.h" #include "rocm/include/hipsparse/hipsparse.h" #include "rocm/include/miopen/miopen.h" +#include "rocm/rocm_config.h" // IWYU pragma: end_exports #define JAX_GPU_NAMESPACE hip #define JAX_GPU_PREFIX "hip" +#define JAX_GPU_PLUGIN_NAME "rocm" -#define JAX_GPU_HAVE_SPARSE 1 #define JAX_GPU_HAVE_64_BIT 0 #define JAX_GPU_HAVE_FP8 0 // TODO(Ruturaj4): Currently equivalent API does exist in // MIOpen lib. Remove when MIOpen support is complete. #define MIOPEN_STATUS_SUCCESS 0 -typedef hipFloatComplex gpuComplex; +typedef hipComplex gpuComplex; typedef hipDoubleComplex gpuDoubleComplex; -typedef hipblasComplex gpublasComplex; -typedef hipblasDoubleComplex gpublasDoubleComplex; -typedef hipsolverHandle_t gpusolverDnHandle_t; +typedef hipComplex gpublasComplex; +typedef hipDoubleComplex gpublasDoubleComplex; +typedef struct hipsolverHandle_* gpusolverDnHandle_t; typedef hipblasFillMode_t gpublasFillMode_t; typedef hipsolverFillMode_t gpusolverFillMode_t; -typedef hipblasHandle_t gpublasHandle_t; +typedef struct hipblasHandle_* gpublasHandle_t; typedef hipblasOperation_t gpublasOperation_t; typedef hipblasStatus_t gpublasStatus_t; typedef hipCtx_t gpuContext_t; @@ -458,10 +507,12 @@ typedef miopenRNNFWDMode_t gpudnnForwardMode_t; typedef hipModule_t gpuModule_t; typedef void gpuSyevjInfo; typedef hipsolverSyevjInfo_t gpuSyevjInfo_t; +typedef void gpuGesvdjInfo; +typedef hipsolverGesvdjInfo_t gpuGesvdjInfo_t; typedef hipsolverEigMode_t gpusolverEigMode_t; typedef hipsolverStatus_t gpusolverStatus_t; typedef hipsparseIndexType_t gpusparseIndexType_t; -typedef hipsparseHandle_t gpusparseHandle_t; +typedef struct hipsparseHandle_* gpusparseHandle_t; typedef hipsparseOperation_t gpusparseOperation_t; typedef hipsparseStatus_t gpusparseStatus_t; typedef hipsparseSpMatDescr_t gpusparseSpMatDescr_t; @@ -475,7 +526,12 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t; #define GPU_C_64F HIP_C_64F #define GPU_R_64F HIP_R_64F -#define gpublasCreate hipblasCreate +namespace{ +inline hipblasStatus_t gpublasCreate(gpublasHandle_t* handle) { + return hipblasCreate(reinterpret_cast(handle)); +} +} + #define gpublasSetStream hipblasSetStream #define gpublasSgeqrfBatched hipblasSgeqrfBatched #define gpublasDgeqrfBatched hipblasDgeqrfBatched @@ -526,10 +582,18 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t; #define GPUDNN_LSTM miopenLSTM #define GPUDNN_BIDIRECTIONAL miopenRNNbidirection -#define gpusolverDnCreate hipsolverCreate +// Wrapper functions for SOLVER handles to ensure unique types +namespace{ +inline hipsolverStatus_t gpusolverDnCreate(gpusolverDnHandle_t* handle) { + return hipsolverCreate(reinterpret_cast(handle)); +} +} + #define gpusolverDnSetStream hipsolverSetStream #define gpusolverDnCreateSyevjInfo hipsolverCreateSyevjInfo #define gpusolverDnDestroySyevjInfo hipsolverDestroySyevjInfo +#define gpusolverDnCreateGesvdjInfo hipsolverCreateGesvdjInfo +#define gpusolverDnDestroyGesvdjInfo hipsolverDestroyGesvdjInfo #define gpusolverDnSgeqrf hipsolverSgeqrf #define gpusolverDnDgeqrf hipsolverDgeqrf #define gpusolverDnCgeqrf hipsolverCgeqrf @@ -558,6 +622,18 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t; #define gpusolverDnDorgqr_bufferSize hipsolverDorgqr_bufferSize #define gpusolverDnCungqr_bufferSize hipsolverCungqr_bufferSize #define gpusolverDnZungqr_bufferSize hipsolverZungqr_bufferSize +#define gpusolverDnSpotrf hipsolverSpotrf +#define gpusolverDnDpotrf hipsolverDpotrf +#define gpusolverDnCpotrf hipsolverCpotrf +#define gpusolverDnZpotrf hipsolverZpotrf +#define gpusolverDnSpotrf_bufferSize hipsolverSpotrf_bufferSize +#define gpusolverDnDpotrf_bufferSize hipsolverDpotrf_bufferSize +#define gpusolverDnCpotrf_bufferSize hipsolverCpotrf_bufferSize +#define gpusolverDnZpotrf_bufferSize hipsolverZpotrf_bufferSize +#define gpusolverDnSpotrfBatched hipsolverDnSpotrfBatched +#define gpusolverDnDpotrfBatched hipsolverDnDpotrfBatched +#define gpusolverDnCpotrfBatched hipsolverDnCpotrfBatched +#define gpusolverDnZpotrfBatched hipsolverDnZpotrfBatched #define gpusolverDnSsyevd hipsolverSsyevd #define gpusolverDnDsyevd hipsolverDsyevd #define gpusolverDnCheevd hipsolverCheevd @@ -594,6 +670,22 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t; hipsolverCgesvd_bufferSize(h, jobu, jobvt, m, n, lwork) #define gpusolverDnZgesvd_bufferSize(h, jobu, jobvt, m, n, lwork) \ hipsolverZgesvd_bufferSize(h, jobu, jobvt, m, n, lwork) +#define gpusolverDnSgesvdj hipsolverSgesvdj +#define gpusolverDnDgesvdj hipsolverDgesvdj +#define gpusolverDnCgesvdj hipsolverCgesvdj +#define gpusolverDnZgesvdj hipsolverZgesvdj +#define gpusolverDnSgesvdj_bufferSize hipsolverSgesvdj_bufferSize +#define gpusolverDnDgesvdj_bufferSize hipsolverDgesvdj_bufferSize +#define gpusolverDnCgesvdj_bufferSize hipsolverCgesvdj_bufferSize +#define gpusolverDnZgesvdj_bufferSize hipsolverZgesvdj_bufferSize +#define gpusolverDnSgesvdjBatched hipsolverSgesvdjBatched +#define gpusolverDnDgesvdjBatched hipsolverDgesvdjBatched +#define gpusolverDnCgesvdjBatched hipsolverCgesvdjBatched +#define gpusolverDnZgesvdjBatched hipsolverZgesvdjBatched +#define gpusolverDnSgesvdjBatched_bufferSize hipsolverSgesvdjBatched_bufferSize +#define gpusolverDnDgesvdjBatched_bufferSize hipsolverDgesvdjBatched_bufferSize +#define gpusolverDnCgesvdjBatched_bufferSize hipsolverCgesvdjBatched_bufferSize +#define gpusolverDnZgesvdjBatched_bufferSize hipsolverZgesvdjBatched_bufferSize #define gpusolverDnSsytrd_bufferSize hipsolverDnSsytrd_bufferSize #define gpusolverDnDsytrd_bufferSize hipsolverDnDsytrd_bufferSize #define gpusolverDnChetrd_bufferSize hipsolverDnChetrd_bufferSize @@ -614,7 +706,13 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t; #define GPUBLAS_OP_C HIPBLAS_OP_C #define gpusparseCooSetStridedBatch hipsparseCooSetStridedBatch -#define gpusparseCreate hipsparseCreate + +namespace{ +inline hipsparseStatus_t gpusparseCreate(gpusparseHandle_t* handle) { + return hipsparseCreate(reinterpret_cast(handle)); +} +} + #define gpusparseSetStream hipsparseSetStream #define gpusparseCreateCoo hipsparseCreateCoo #define gpusparseCreateCsr hipsparseCreateCsr @@ -633,19 +731,37 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t; #define gpusparseSpMM_bufferSize hipsparseSpMM_bufferSize #define gpusparseSpMV hipsparseSpMV #define gpusparseSpMV_bufferSize hipsparseSpMV_bufferSize + #define gpusparseSgtsv2 hipsparseSgtsv2 #define gpusparseDgtsv2 hipsparseDgtsv2 +#define gpusparseCgtsv2 hipsparseCgtsv2 +#define gpusparseZgtsv2 hipsparseZgtsv2 #define gpusparseSgtsv2_bufferSizeExt hipsparseSgtsv2_bufferSizeExt #define gpusparseDgtsv2_bufferSizeExt hipsparseDgtsv2_bufferSizeExt +#define gpusparseCgtsv2_bufferSizeExt hipsparseCgtsv2_bufferSizeExt +#define gpusparseZgtsv2_bufferSizeExt hipsparseZgtsv2_bufferSizeExt + +#define gpusparseSgtsv2StridedBatch_bufferSizeExt \ + hipsparseSgtsv2StridedBatch_bufferSizeExt +#define gpusparseDgtsv2StridedBatch_bufferSizeExt \ + hipsparseDgtsv2StridedBatch_bufferSizeExt +#define gpusparseCgtsv2StridedBatch_bufferSizeExt \ + hipsparseCgtsv2StridedBatch_bufferSizeExt +#define gpusparseZgtsv2StridedBatch_bufferSizeExt \ + hipsparseZgtsv2StridedBatch_bufferSizeExt +#define gpusparseSgtsv2StridedBatch hipsparseSgtsv2StridedBatch +#define gpusparseDgtsv2StridedBatch hipsparseDgtsv2StridedBatch +#define gpusparseCgtsv2StridedBatch hipsparseCgtsv2StridedBatch +#define gpusparseZgtsv2StridedBatch hipsparseZgtsv2StridedBatch #define GPUSPARSE_INDEX_16U HIPSPARSE_INDEX_16U #define GPUSPARSE_INDEX_32I HIPSPARSE_INDEX_32I #define GPUSPARSE_INDEX_64I HIPSPARSE_INDEX_64I #define GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT HIPSPARSE_DENSETOSPARSE_ALG_DEFAULT -#define GPUSPARSE_SPMV_COO_ALG HIPSPARSE_MV_ALG_DEFAULT -#define GPUSPARSE_SPMV_CSR_ALG HIPSPARSE_MV_ALG_DEFAULT -#define GPUSPARSE_SPMM_COO_ALG HIPSPARSE_SPMM_ALG_DEFAULT -#define GPUSPARSE_SPMM_CSR_ALG HIPSPARSE_SPMM_ALG_DEFAULT +#define GPUSPARSE_SPMV_COO_ALG HIPSPARSE_COOMV_ALG +#define GPUSPARSE_SPMV_CSR_ALG HIPSPARSE_CSRMV_ALG1 +#define GPUSPARSE_SPMM_COO_ALG HIPSPARSE_SPMM_COO_ALG1 +#define GPUSPARSE_SPMM_CSR_ALG HIPSPARSE_SPMM_CSR_ALG1 #define GPUSPARSE_INDEX_BASE_ZERO HIPSPARSE_INDEX_BASE_ZERO #define GPUSPARSE_OPERATION_NON_TRANSPOSE HIPSPARSE_OPERATION_NON_TRANSPOSE #define GPUSPARSE_OPERATION_TRANSPOSE HIPSPARSE_OPERATION_TRANSPOSE @@ -658,7 +774,8 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t; #define GPU_STREAM_NON_BLOCKING hipStreamNonBlocking #define gpuMalloc hipMalloc -#define gpuGetLastError hipGetLastError +#define gpuFree hipFree +#define gpuGetLastError hipExtGetLastError #define gpuGetErrorString hipGetErrorString #define gpuMemcpyAsync hipMemcpyAsync #define gpuMemcpyDeviceToDevice hipMemcpyDeviceToDevice @@ -713,6 +830,8 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t; #define gpuGetDeviceProperties hipGetDeviceProperties #define gpuLaunchCooperativeKernel hipLaunchCooperativeKernel +#define JAX_GPU_HAVE_SOLVER_GEEV 0 + namespace jax::JAX_GPU_NAMESPACE { namespace { constexpr uint32_t kNumThreadsPerWarp = 64; diff --git a/jaxlib/gpu_linalg.py b/jaxlib/gpu_linalg.py index c747c0abbe8b..967dacdbacff 100644 --- a/jaxlib/gpu_linalg.py +++ b/jaxlib/gpu_linalg.py @@ -19,12 +19,17 @@ _cuda_linalg = import_from_plugin("cuda", "_linalg") _hip_linalg = import_from_plugin("rocm", "_linalg") + def registrations() -> dict[str, list[tuple[str, Any, int]]]: - registrations = {"CUDA": [], "ROCM": []} + registrations: dict[str, list[tuple[str, Any, int]]] = { + "CUDA": [], + "ROCM": [], + } for platform, module in [("CUDA", _cuda_linalg), ("ROCM", _hip_linalg)]: if module: registrations[platform].extend( - (*i, 1) for i in module.registrations().items()) + (*i, 1) for i in module.registrations().items() + ) return registrations # pytype: disable=bad-return-type diff --git a/jaxlib/gpu_prng.py b/jaxlib/gpu_prng.py index 6f74d5813ce4..17da46de699f 100644 --- a/jaxlib/gpu_prng.py +++ b/jaxlib/gpu_prng.py @@ -12,79 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations -from functools import partial -import itertools +from typing import Any -import jaxlib.mlir.ir as ir - -from jaxlib import xla_client - -from .hlo_helpers import custom_call from .plugin_support import import_from_plugin _cuda_prng = import_from_plugin("cuda", "_prng") _hip_prng = import_from_plugin("rocm", "_prng") -if _cuda_prng: - for _name, _value in _cuda_prng.registrations().items(): - # TODO(danfm): remove after JAX 0.5.1 release - api_version = 1 if "_ffi" in _name else 0 - xla_client.register_custom_call_target(_name, _value, platform="CUDA", - api_version=api_version) - -if _hip_prng: - for _name, _value in _hip_prng.registrations().items(): - # TODO(danfm): remove after JAX 0.5.1 release - api_version = 1 if "_ffi" in _name else 0 - xla_client.register_custom_call_target(_name, _value, platform="ROCM", - api_version=api_version) - - -def _threefry2x32_lowering(prng, platform: str, keys, data, - length: int | ir.Value | None = None, - output_shape: ir.Value | None = None, - forward_compatibility_mode: bool = False): - """ThreeFry2x32 kernel for GPU. - - In presence of dynamic shapes, `length` is an `ir.Value` and `output_shape` - is a 1D tensor describing the shape of the two outputs. - """ - del forward_compatibility_mode - assert len(keys) == 2, keys - assert len(data) == 2, data - assert (ir.RankedTensorType(keys[0].type).element_type == - ir.IntegerType.get_unsigned(32)), keys[0].type - - typ = keys[0].type - dims = ir.RankedTensorType(typ).shape - - for x in itertools.chain(keys, data): - assert x.type == typ, (x.type, typ) - ndims = len(dims) - layout = tuple(range(ndims - 1, -1, -1)) - operand_layouts = [layout] * 4 - operands = [keys[0], keys[1], data[0], data[1]] - - opaque = {} # Use if not forward_compatibility_mode to trigger the FFI (v4). - if isinstance(length, int): - result_shapes = None - else: - assert output_shape is not None - # We also need to pass separately the shapes of the outputs. - result_shapes = [output_shape, output_shape] - - custom_call_target = f"{platform}_threefry2x32_ffi" - return custom_call( - custom_call_target, - api_version=4, - result_types=[typ, typ], - operands=operands, - backend_config=opaque, - operand_layouts=operand_layouts, - result_layouts=[layout] * 2, - result_shapes=result_shapes).results - -cuda_threefry2x32 = partial(_threefry2x32_lowering, _cuda_prng, "cu") -rocm_threefry2x32 = partial(_threefry2x32_lowering, _hip_prng, "hip") +def registrations() -> dict[str, list[tuple[str, Any, int]]]: + registrations: dict[str, list[tuple[str, Any, int]]] = { + "CUDA": [], + "ROCM": [], + } + for platform, module in [("CUDA", _cuda_prng), ("ROCM", _hip_prng)]: + if module: + registrations[platform].extend( + (name, value, int(name.endswith("_ffi"))) + for name, value in module.registrations().items() + ) + return registrations diff --git a/jaxlib/gpu_solver.py b/jaxlib/gpu_solver.py index efb58f9a4164..cdcd2b6199f9 100644 --- a/jaxlib/gpu_solver.py +++ b/jaxlib/gpu_solver.py @@ -16,21 +16,18 @@ from .plugin_support import import_from_plugin -_cublas = import_from_plugin("cuda", "_blas") _cusolver = import_from_plugin("cuda", "_solver") _cuhybrid = import_from_plugin("cuda", "_hybrid") -_hipblas = import_from_plugin("rocm", "_blas") _hipsolver = import_from_plugin("rocm", "_solver") _hiphybrid = import_from_plugin("rocm", "_hybrid") def registrations() -> dict[str, list[tuple[str, Any, int]]]: - registrations = {"CUDA": [], "ROCM": []} - for platform, module in [("CUDA", _cublas), ("ROCM", _hipblas)]: - if module: - registrations[platform].extend( - (*i, 0) for i in module.registrations().items()) + registrations: dict[str, list[tuple[str, Any, int]]] = { + "CUDA": [], + "ROCM": [], + } for platform, module in [("CUDA", _cusolver), ("ROCM", _hipsolver)]: if module: registrations[platform].extend( @@ -40,17 +37,17 @@ def registrations() -> dict[str, list[tuple[str, Any, int]]]: for platform, module in [("CUDA", _cuhybrid), ("ROCM", _hiphybrid)]: if module: registrations[platform].extend( - (*i, 1) for i in module.registrations().items()) + (*i, 1) for i in module.registrations().items() + ) return registrations # pytype: disable=bad-return-type def batch_partitionable_targets() -> list[str]: - targets = [] + targets: list[str] = [] for module in [_cusolver, _hipsolver]: if module: targets.extend( - name for name in module.registrations() - if name.endswith("_ffi") + name for name in module.registrations() if name.endswith("_ffi") ) for module in [_cuhybrid, _hiphybrid]: if module: diff --git a/jaxlib/gpu_sparse.py b/jaxlib/gpu_sparse.py index d8645041c946..08231794a2ee 100644 --- a/jaxlib/gpu_sparse.py +++ b/jaxlib/gpu_sparse.py @@ -11,373 +11,32 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -cusparse wrappers for performing sparse matrix computations in JAX -""" -import math -from functools import partial from typing import Any -import jaxlib.mlir.ir as ir - -import numpy as np - -from .hlo_helpers import custom_call, mk_result_types_and_shapes - from .plugin_support import import_from_plugin _cusparse = import_from_plugin("cuda", "_sparse") _hipsparse = import_from_plugin("rocm", "_sparse") def registrations() -> dict[str, list[tuple[str, Any, int]]]: - registrations = {"CUDA": [], "ROCM": []} + registrations: dict[str, list[tuple[str, Any, int]]] = { + "CUDA": [], + "ROCM": [], + } for platform, module in [("CUDA", _cusparse), ("ROCM", _hipsparse)]: if module: registrations[platform].extend( (name, value, int(name.endswith("_ffi"))) - for name, value in module.registrations().items()) + for name, value in module.registrations().items() + ) return registrations # pytype: disable=bad-return-type - -cuda_is_supported = bool(_cusparse and _cusparse.sparse_supported) -rocm_is_supported = bool(_hipsparse and _hipsparse.sparse_supported) - - -def _validate_csr_hlo(data, indices, indptr, shape): - data_type = ir.RankedTensorType(data.type) - indices_type = ir.RankedTensorType(indices.type) - indptr_type = ir.RankedTensorType(indptr.type) - - nnz, = data_type.shape - assert indices_type.shape == [nnz] - assert indptr_type.element_type == indices_type.element_type - assert indptr_type.shape == [shape[0] + 1] - return data_type.element_type, indices_type.element_type, nnz - -def _validate_coo_hlo(data, row, col): - data_type = ir.RankedTensorType(data.type) - row_type = ir.RankedTensorType(row.type) - col_type = ir.RankedTensorType(col.type) - - nnz, = data_type.shape - assert row_type.shape == [nnz] - assert col_type.element_type == row_type.element_type - assert col_type.shape == [nnz] - return data_type.element_type, row_type.element_type, nnz - - -def _csr_todense_hlo(platform, gpu_sparse, data, indices, indptr, *, shape, - data_dtype, index_dtype): - """CSR to dense matrix.""" - data_type, index_type, nnz = _validate_csr_hlo(data, indices, indptr, shape) - rows, cols = shape - - buffer_size, opaque = gpu_sparse.build_csr_todense_descriptor( - data_dtype, index_dtype, rows, cols, nnz) - - out = custom_call( - f"{platform}sparse_csr_todense_ffi", - result_types=[ - ir.RankedTensorType.get(shape, data_type), - ir.RankedTensorType.get([buffer_size], - ir.IntegerType.get_signless(8)), - ], - operands=[data, indices, indptr], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[[0]] * 3, - result_layouts=[[1, 0], [0]]).results - return out[0] - -cuda_csr_todense = partial(_csr_todense_hlo, "cu", _cusparse) -rocm_csr_todense = partial(_csr_todense_hlo, "hip", _hipsparse) - - -def _csr_fromdense_hlo(platform, gpu_sparse, mat, *, nnz, index_dtype, - data_dtype, index_type): - """CSR from dense matrix.""" - mat_type = ir.RankedTensorType(mat.type) - rows, cols = mat_type.shape - - buffer_size, opaque = gpu_sparse.build_csr_fromdense_descriptor( - data_dtype, index_dtype, rows, cols, nnz) - - out = custom_call( - f"{platform}sparse_csr_fromdense_ffi", - result_types=[ - ir.RankedTensorType.get([nnz], mat_type.element_type), - ir.RankedTensorType.get([nnz], index_type), - ir.RankedTensorType.get([rows + 1], index_type), - ir.RankedTensorType.get([buffer_size], - ir.IntegerType.get_signless(8)), - ], - operands=[mat], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[[1, 0]], - result_layouts=[[0]] * 4).results - return out[:3] - -cuda_csr_fromdense = partial(_csr_fromdense_hlo, "cu", _cusparse) -rocm_csr_fromdense = partial(_csr_fromdense_hlo, "hip", _hipsparse) - - -def _csr_matvec_hlo(platform, gpu_sparse, data, indices, indptr, x, *, shape, - transpose=False, compute_dtype=None, compute_type=None, - data_dtype, index_dtype, x_dtype): - """CSR matrix/vector multiply.""" - data_type, index_type, nnz = _validate_csr_hlo(data, indices, indptr, shape) - rows, cols = shape - - if compute_dtype is None: - compute_dtype = data_dtype - compute_type = data_type - - buffer_size, opaque = gpu_sparse.build_csr_matvec_descriptor( - data_dtype, x_dtype, compute_dtype, index_dtype, - rows, cols, nnz, transpose) - out_size = cols if transpose else rows - - out = custom_call( - f"{platform}sparse_csr_matvec_ffi", - result_types=[ - ir.RankedTensorType.get([out_size], compute_type), - ir.RankedTensorType.get([buffer_size], - ir.IntegerType.get_signless(8)), - ], - operands=[data, indices, indptr, x], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[[0]] * 4, - result_layouts=[[0]] * 2).results - return out[0] - -cuda_csr_matvec = partial(_csr_matvec_hlo, "cu", _cusparse) -rocm_csr_matvec = partial(_csr_matvec_hlo, "hip", _hipsparse) - - -def _csr_matmat_hlo(platform, gpu_sparse, data, indices, indptr, B, *, shape, - transpose=False, compute_dtype=None, compute_type=None, - index_dtype, data_dtype, B_dtype): - """CSR from dense matrix.""" - data_type, index_type, nnz = _validate_csr_hlo(data, indices, indptr, shape) - rows, cols = shape - B_shape = ir.RankedTensorType(B.type).shape - _, Ccols = B_shape - - if compute_dtype is None: - compute_dtype = data_dtype - compute_type = data_type - - buffer_size, opaque = gpu_sparse.build_csr_matmat_descriptor( - data_dtype, B_dtype, compute_dtype, index_dtype, - rows, cols, Ccols, nnz, transpose) - out_size = cols if transpose else rows - - out = custom_call( - f"{platform}sparse_csr_matmat_ffi", - result_types=[ - ir.RankedTensorType.get([out_size, Ccols], compute_type), - ir.RankedTensorType.get([buffer_size], - ir.IntegerType.get_signless(8)), - ], - operands=[data, indices, indptr, B], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[[0], [0], [0], [1, 0]], - result_layouts=[[1, 0], [0]]).results - return out[0] - -cuda_csr_matmat = partial(_csr_matmat_hlo, "cu", _cusparse) -rocm_csr_matmat = partial(_csr_matmat_hlo, "hip", _hipsparse) - - -def _coo_todense_hlo(platform, gpu_sparse, data, row, col, *, shape, - data_dtype, index_dtype): - """COO to dense matrix.""" - data_type, _, nnz = _validate_coo_hlo(data, row, col) - rows, cols = shape - - buffer_size, opaque = gpu_sparse.build_coo_todense_descriptor( - data_dtype, index_dtype, rows, cols, nnz) - - out = custom_call( - f"{platform}sparse_coo_todense_ffi", - result_types=[ - ir.RankedTensorType.get(shape, data_type), - ir.RankedTensorType.get([buffer_size], - ir.IntegerType.get_signless(8)), - ], - operands=[data, row, col], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[[0]] * 3, - result_layouts=[[1, 0], [0]]).results - return out[0] - -cuda_coo_todense = partial(_coo_todense_hlo, "cu", _cusparse) -rocm_coo_todense = partial(_coo_todense_hlo, "hip", _hipsparse) - - -def _coo_fromdense_hlo(platform, gpu_sparse, mat, *, nnz, data_dtype, - index_dtype, index_type): - """COO from dense matrix.""" - mat_type = ir.RankedTensorType(mat.type) - rows, cols = mat_type.shape - - buffer_size, opaque = gpu_sparse.build_coo_fromdense_descriptor( - data_dtype, index_dtype, rows, cols, nnz) - - out = custom_call( - f"{platform}sparse_coo_fromdense_ffi", - result_types=[ - ir.RankedTensorType.get([nnz], mat_type.element_type), - ir.RankedTensorType.get([nnz], index_type), - ir.RankedTensorType.get([nnz], index_type), - ir.RankedTensorType.get([buffer_size], - ir.IntegerType.get_signless(8)), - ], - operands=[mat], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[[1, 0]], - result_layouts=[[0]] * 4).results - return out[:3] - -cuda_coo_fromdense = partial(_coo_fromdense_hlo, "cu", _cusparse) -rocm_coo_fromdense = partial(_coo_fromdense_hlo, "hip", _hipsparse) - - -def _coo_matvec_hlo(platform, gpu_sparse, data, row, col, x, *, shape, - transpose=False, compute_dtype=None, compute_type=None, - index_dtype, data_dtype, x_dtype): - """COO matrix/vector multiply.""" - data_type, _, nnz = _validate_coo_hlo(data, row, col) - rows, cols = shape - - if compute_dtype is None: - compute_dtype = data_dtype - compute_type = data_type - - buffer_size, opaque = gpu_sparse.build_coo_matvec_descriptor( - data_dtype, x_dtype, compute_dtype, index_dtype, - rows, cols, nnz, transpose) - out_size = cols if transpose else rows - - out = custom_call( - f"{platform}sparse_coo_matvec_ffi", - result_types=[ - ir.RankedTensorType.get([out_size], compute_type), - ir.RankedTensorType.get([buffer_size], - ir.IntegerType.get_signless(8)), - ], - operands=[data, row, col, x], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[[0]] * 4, - result_layouts=[[0]] * 2).results - return out[0] - -cuda_coo_matvec = partial(_coo_matvec_hlo, "cu", _cusparse) -rocm_coo_matvec = partial(_coo_matvec_hlo, "hip", _hipsparse) - - -def _coo_matmat_hlo(platform, gpu_sparse, data, row, col, B, *, shape, - transpose=False, compute_dtype=None, compute_type=None, - x_dtype, data_dtype, index_dtype): - """COO from dense matrix.""" - data_type, _, nnz = _validate_coo_hlo(data, row, col) - is_batched_matmat = False - batch_count = 1 - if len(shape) == 2: - rows, cols = shape - elif len(shape) == 3: - is_batched_matmat = True - batch_count, rows, cols = shape - # Redefine nnz as nnz per batch. - nnz = nnz // batch_count - - B_shape = ir.RankedTensorType(B.type).shape - _, Ccols = B_shape - - if compute_dtype is None: - compute_dtype = data_dtype - compute_type = data_type - - # TODO(tianjianlu): use batch stride to trigger different mode of batch - # computation. Currently batch_stride = 0 is not allowed because of the issue - # in cusparse https://github.com/NVIDIA/CUDALibrarySamples/issues/81#issuecomment-1205562643 - # Set batch stride to be the matrix size for now. - lhs_batch_stride = nnz - B_rows = rows if transpose else cols - rhs_batch_stride = B_rows * Ccols - - buffer_size, opaque = gpu_sparse.build_coo_matmat_descriptor( - data_dtype, x_dtype, compute_dtype, index_dtype, - rows, cols, Ccols, nnz, transpose, batch_count, lhs_batch_stride, - rhs_batch_stride) - out_size = cols if transpose else rows - - if is_batched_matmat: - out_shape = [batch_count, out_size, Ccols] - out_layout = [2, 1, 0] - else: - out_shape = [out_size, Ccols] - out_layout = [1, 0] - - out = custom_call( - f"{platform}sparse_coo_matmat_ffi", - result_types=[ - ir.RankedTensorType.get(out_shape, compute_type), - ir.RankedTensorType.get([buffer_size], - ir.IntegerType.get_signless(8)), - ], - operands=[data, row, col, B], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[[0], [0], [0], [1, 0]], - result_layouts=[out_layout, [0]]).results - return out[0] - -cuda_coo_matmat = partial(_coo_matmat_hlo, "cu", _cusparse) -rocm_coo_matmat = partial(_coo_matmat_hlo, "hip", _hipsparse) - - -def _gtsv2_hlo( - platform, gpu_sparse, dl, d, du, B, *, m, n, ldb, t, b_shape_vals=None): - """Calls `cusparsegtsv2(dl, d, du, B, m, n, ldb)`.""" - assert len(b_shape_vals) >= 2 - batch_dim_vals = b_shape_vals[:-2] - batch_size = math.prod(batch_dim_vals) - num_bd = len(b_shape_vals) - 2 - f32 = (t == np.float32) - if f32: - buffer_size = gpu_sparse.gtsv2_f32_buffer_size(m, n, ldb) - else: - buffer_size = gpu_sparse.gtsv2_f64_buffer_size(m, n, ldb) - - b_layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) - d_layout = (num_bd,) + tuple(range(num_bd - 1, -1, -1)) - b_type = ir.RankedTensorType(B.type) - - shape_type_pairs = [ - (batch_dim_vals + (ldb, n), b_type.element_type), - ((buffer_size,), ir.IntegerType.get_signless(8)) - ] - result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs) - opaque = gpu_sparse.build_gtsv2_descriptor(batch_size, m, n, ldb) - out = custom_call( - f"{platform}sparse_gtsv2_" + ("f32" if f32 else "f64") + "_ffi", - result_types=result_types, - operands=[dl, d, du, B], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[d_layout] * 3 + [b_layout], - result_layouts=[b_layout, [0]], - operand_output_aliases={3: 0}, - result_shapes=result_shapes).results - return out[0] - -cuda_gtsv2 = partial(_gtsv2_hlo, "cu", _cusparse) -rocm_gtsv2 = partial(_gtsv2_hlo, "hip", _hipsparse) +def batch_partitionable_targets() -> list[str]: + targets: list[str] = [] + for module in [_cusparse, _hipsparse]: + if module: + targets.extend( + name for name in module.registrations() if name.endswith("gtsv2_ffi") + ) + return targets diff --git a/jaxlib/guard_lib.cc b/jaxlib/guard_lib.cc new file mode 100644 index 000000000000..15f7ec308f01 --- /dev/null +++ b/jaxlib/guard_lib.cc @@ -0,0 +1,270 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This files implements the configuration management for different types of +// guards. +// C++ backends are responsible for enforcing transfer guard levels. + +#include "jaxlib/guard_lib.h" + +#include +#include +#include +#include // NOLINT + +#include "absl/base/attributes.h" +#include "absl/base/const_init.h" +#include "absl/functional/function_ref.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/synchronization/mutex.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/util.h" + +namespace jax { + +namespace nb = ::nanobind; + +namespace { + +// Protected by the GIL. +GuardState& global_state = *new GuardState(); + +ABSL_CONST_INIT thread_local GuardState thread_local_state; + +// The default transfer guard level. +constexpr TransferGuardLevel kDefaultGuardLevel = TransferGuardLevel::kAllow; + +// The default garbage collection guard level. +constexpr GarbageCollectionGuardLevel kDefaultGarbageCollectionGuardLevel = + GarbageCollectionGuardLevel::kAllow; + +// Returns the transfer guard action for a transfer. +TransferGuardAction GetTransferGuardAction(TransferGuardLevel guard_level, + bool explicit_transfer) { + switch (guard_level) { + case TransferGuardLevel::kAllow: + return TransferGuardAction::kAllow; + case TransferGuardLevel::kLog: + if (explicit_transfer) { + return TransferGuardAction::kAllow; + } else { + return TransferGuardAction::kLog; + } + case TransferGuardLevel::kDisallow: + if (explicit_transfer) { + return TransferGuardAction::kAllow; + } else { + return TransferGuardAction::kDisallow; + } + case TransferGuardLevel::kLogExplicit: + return TransferGuardAction::kLog; + case TransferGuardLevel::kDisallowExplicit: + return TransferGuardAction::kDisallow; + default: + // Unreachable; gracefully handle the unexpected guard level and prevent a + // compiler warning. + return TransferGuardAction::kDisallow; + } +} + +// Returns the transfer guard action for a host-to-device transfer. +// REQUIRES: Python GIL. +TransferGuardAction GetTransferGuardActionForHostToDevice() { + return GetTransferGuardAction( + thread_local_state.host_to_device.value_or( + global_state.host_to_device.value_or(kDefaultGuardLevel)), + thread_local_state.explicit_device_put); +} + +// Returns the transfer guard action for a device-to-device transfer. +// REQUIRES: Python GIL. +TransferGuardAction GetTransferGuardActionForDeviceToDevice() { + return GetTransferGuardAction( + thread_local_state.device_to_device.value_or( + global_state.device_to_device.value_or(kDefaultGuardLevel)), + thread_local_state.explicit_device_put); +} + +// Returns the transfer guard action for a device-to-host transfer. +// REQUIRES: Python GIL. +TransferGuardAction GetTransferGuardActionForDeviceToHost() { + return GetTransferGuardAction( + thread_local_state.device_to_host.value_or( + global_state.device_to_host.value_or(kDefaultGuardLevel)), + thread_local_state.explicit_device_get); +} + +// Guards the global state's thread ID. +ABSL_CONST_INIT absl::Mutex thread_id_mu(absl::kConstInit); + +} // namespace + +absl::Status ApplyTransferGuardToHostToDevice( + absl::FunctionRef formatter) { + switch (GetTransferGuardActionForHostToDevice()) { + case TransferGuardAction::kAllow: + break; + case TransferGuardAction::kLog: + LOG(WARNING) << "host-to-device transfer: " << formatter(); + break; + case TransferGuardAction::kDisallow: + return xla::InvalidArgument("Disallowed host-to-device transfer: %s", + formatter()); + } + return absl::OkStatus(); +} + +absl::Status ApplyTransferGuardToDeviceToDevice( + absl::FunctionRef formatter) { + switch (GetTransferGuardActionForDeviceToDevice()) { + case TransferGuardAction::kAllow: + break; + case TransferGuardAction::kLog: + LOG(WARNING) << "device-to-device transfer: " << formatter(); + break; + case TransferGuardAction::kDisallow: + return xla::InvalidArgument("Disallowed device-to-device transfer: %s", + formatter()); + } + return absl::OkStatus(); +} + +absl::Status ApplyTransferGuardToDeviceToHost( + absl::FunctionRef formatter) { + switch (GetTransferGuardActionForDeviceToHost()) { + case TransferGuardAction::kAllow: + break; + case TransferGuardAction::kLog: + LOG(WARNING) << "device-to-host transfer: " << formatter(); + break; + case TransferGuardAction::kDisallow: + return xla::InvalidArgument("Disallowed device-to-host transfer: %s", + formatter()); + } + return absl::OkStatus(); +} + +GarbageCollectionGuardLevel GetGarbageCollectArrayGuard() { + return thread_local_state.garbage_collect_array.value_or( + global_state.garbage_collect_array.value_or( + kDefaultGarbageCollectionGuardLevel)); +} + +absl::Status CheckThreadGuard(xla::ifrt::DeviceListRef devices) { + absl::MutexLock lock(thread_id_mu); + // If the thread id is not set, then the thread guard is not enabled. + if (!global_state.thread_id.has_value()) { + return absl::OkStatus(); + } + + // Detect if the devices span multiple processes; the thread guard applies + // only to multi-process operations. + // TODO(emilyaf): Allow disjoint subsets of devices in different threads. + bool is_multiprocess = false; + CHECK(!devices->devices().empty()); + int first_process_index = devices->devices()[0]->ProcessIndex(); + for (const auto& device : devices->devices()) { + if (device->ProcessIndex() != first_process_index) { + is_multiprocess = true; + break; + } + } + if (!is_multiprocess) { + return absl::OkStatus(); + } + + // The thread guard is active, so check that the current thread is the owner. + std::thread::id current_thread_id = std::this_thread::get_id(); + if (current_thread_id != *global_state.thread_id) { + std::stringstream ss_current, ss_owner; + ss_current << current_thread_id; + ss_owner << *global_state.thread_id; + return xla::FailedPrecondition( + "A multi-process JAX operation was called from thread %s. This is not " + "allowed because the thread guard was set in thread %s.", + ss_current.str(), ss_owner.str()); + } + return absl::OkStatus(); +} + +absl::Status UpdateThreadGuardGlobalState(bool set_thread_id) { + absl::MutexLock lock(thread_id_mu); + // If set_thread_id is true, then the thread guard context was entered and the + // thread id should be set. If the thread ID is already set, then a thread + // guard is nested, which is allowed only in the same thread. + // If set_thread_id is false, the thread guard context was exited and the + // thread id should be cleared. + if (set_thread_id) { + if (global_state.thread_id.has_value()) { + if (global_state.thread_id.value() != std::this_thread::get_id()) { + return xla::FailedPrecondition( + "The thread guard's global thread ID is already set. Nested thread " + "guards in different threads are not supported."); + } + } else { + global_state.thread_id = std::this_thread::get_id(); + } + } else { + global_state.thread_id = std::nullopt; + } + return absl::OkStatus(); +} + +void BuildGuardSubmodule(nb::module_& m) { + nb::module_ glib = + m.def_submodule("guard_lib", "Jax support library for guards"); + + nb::enum_ tglevel(glib, "TransferGuardLevel"); + tglevel.value("ALLOW", TransferGuardLevel::kAllow); + tglevel.value("LOG", TransferGuardLevel::kLog); + tglevel.value("DISALLOW", TransferGuardLevel::kDisallow); + tglevel.value("LOG_EXPLICIT", TransferGuardLevel::kLogExplicit); + tglevel.value("DISALLOW_EXPLICIT", TransferGuardLevel::kDisallowExplicit); + + nb::enum_ gcglevel( + glib, "GarbageCollectionGuardLevel"); + gcglevel.value("ALLOW", GarbageCollectionGuardLevel::kAllow); + gcglevel.value("LOG", GarbageCollectionGuardLevel::kLog); + gcglevel.value("FATAL", GarbageCollectionGuardLevel::kFatal); + + nb::class_ tgstate(glib, "GuardState"); + tgstate.def_rw("host_to_device", &GuardState::host_to_device, + nb::arg().none()); + tgstate.def_rw("device_to_device", &GuardState::device_to_device, + nb::arg().none()); + tgstate.def_rw("device_to_host", &GuardState::device_to_host, + nb::arg().none()); + tgstate.def_rw("explicit_device_put", &GuardState::explicit_device_put); + tgstate.def_rw("explicit_device_get", &GuardState::explicit_device_get); + tgstate.def_rw("garbage_collect_array", &GuardState::garbage_collect_array, + nb::arg().none()); + + glib.def( + "global_state", [&]() { return &global_state; }, + nb::rv_policy::reference); + glib.def( + "thread_local_state", [&]() { return &thread_local_state; }, + nb::rv_policy::reference); + glib.def("update_thread_guard_global_state", + xla::ThrowIfErrorWrapper(UpdateThreadGuardGlobalState), + nb::arg()); +} + +} // namespace jax diff --git a/jaxlib/guard_lib.h b/jaxlib/guard_lib.h new file mode 100644 index 000000000000..2adbf0a50c2c --- /dev/null +++ b/jaxlib/guard_lib.h @@ -0,0 +1,130 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_GUARD_LIB_H_ +#define JAXLIB_GUARD_LIB_H_ + +#include +#include +#include // NOLINT + +// placeholder for index annotation headers +#include "absl/functional/function_ref.h" +#include "absl/status/status.h" +#include "nanobind/nanobind.h" +#include "xla/python/ifrt/device_list.h" + +namespace jax { + +// Transfer guard level chosen by the user code. +enum class TransferGuardLevel { + // Explicit transfers: allow + // Implicit transfers: allow + kAllow, + // Explicit transfers: allow + // Implicit transfers: log + kLog, + // Explicit transfers: allow + // Implicit transfers: disallow + kDisallow, + // Explicit transfers: log + // Implicit transfers: log + kLogExplicit, + // Explicit transfers: disallow + // Implicit transfers: disallow + kDisallowExplicit, +}; + +// Garbage collection guard level chose by the user code. +enum class GarbageCollectionGuardLevel { + // Silently allow the object to be garbage collected. + kAllow, + // Log and allow the object to be garbage collected. + kLog, + // Fatal crash on object garbage collection. + kFatal, +}; + +// Flags for guard levels are controlled by: +// - a global flag value, +// e.g., associated to --jax_transfer_guard_device_to_host +// which defaults to TransferGuardLevel::kAllow. +// - possibly a thread-local value, which initially is std::nullopt and +// overrides the global value if set. The thread-local state is used to +// implement context managers that locally override the global state. +// +// Explicit device_put/device_get contexts are tracked by context managers. +struct GuardState { + std::optional host_to_device; + std::optional device_to_device; + std::optional device_to_host; + bool explicit_device_put = false; + bool explicit_device_get = false; + + std::optional garbage_collect_array; + std::optional thread_id; +}; + +// Resulting action for a transfer given the transfer guard level and the +// transfer type. +enum class TransferGuardAction { + // Silently allow the transfer. + kAllow, + // Log and allow the transfer. + kLog, + // Disallow the transfer. + kDisallow, +}; + +// Guards a host-to-device transfer. formatter is called to describe the +// transfer in a log message or error status. +// REQUIRES: Python GIL. +absl::Status ApplyTransferGuardToHostToDevice( + absl::FunctionRef formatter); + +// Guards a device-to-device transfer. formatter is called to describe the +// transfer in a log message or error status. +// REQUIRES: Python GIL. +absl::Status ApplyTransferGuardToDeviceToDevice( + absl::FunctionRef formatter); + +// Guards a device-to-host transfer. formatter is called to describe the +// transfer in a log message or error status. +// REQUIRES: Python GIL. +absl::Status ApplyTransferGuardToDeviceToHost( + absl::FunctionRef formatter); + +// Returns the garbage collection guard level for "jax.Array" objects. +// REQUIRES: Python GIL. +GarbageCollectionGuardLevel GetGarbageCollectArrayGuard(); + +// Updates the global thread guard state. If `set_thread_id` is true, the global +// thread guard state is set with the current thread ID. If `set_thread_id` is +// false, the global thread guard state is cleared. A global mutex ensures the +// update is atomic. +// A failed status is returned if `set_thread_id` is true and the global thread +// ID is already set and not equal to the current thread ID. +absl::Status UpdateThreadGuardGlobalState(bool set_thread_id); + +// Checks if the thread guard should prevent execution. A global mutex ensures +// the thread ID is read atomically. +absl::Status CheckThreadGuard(xla::ifrt::DeviceListRef devices); + +// The function to call in `xla.cc` to add the bindings for this module. +void BuildGuardSubmodule(nanobind::module_& m); + +} // namespace jax + +#endif // JAXLIB_GUARD_LIB_H_ diff --git a/jaxlib/hlo_helpers.py b/jaxlib/hlo_helpers.py deleted file mode 100644 index 0d57a04f1aa7..000000000000 --- a/jaxlib/hlo_helpers.py +++ /dev/null @@ -1,248 +0,0 @@ -# Copyright 2022 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""A small library of helpers for use in jaxlib to build MLIR operations.""" - -from __future__ import annotations - -from collections.abc import Callable, Sequence -from functools import partial -from typing import Union - -import jaxlib.mlir.ir as ir -import jaxlib.mlir.dialects.stablehlo as hlo -import numpy as np - - -_dtype_to_ir_type_factory : dict[np.dtype, Callable[[], ir.Type]] = { - np.dtype(np.bool_): partial(ir.IntegerType.get_signless, 1), - np.dtype(np.int8): partial(ir.IntegerType.get_signless, 8), - np.dtype(np.int16): partial(ir.IntegerType.get_signless, 16), - np.dtype(np.int32): partial(ir.IntegerType.get_signless, 32), - np.dtype(np.int64): partial(ir.IntegerType.get_signless, 64), - np.dtype(np.uint8): partial(ir.IntegerType.get_unsigned, 8), - np.dtype(np.uint16): partial(ir.IntegerType.get_unsigned, 16), - np.dtype(np.uint32): partial(ir.IntegerType.get_unsigned, 32), - np.dtype(np.uint64): partial(ir.IntegerType.get_unsigned, 64), - np.dtype(np.float16): ir.F16Type.get, - np.dtype(np.float32): ir.F32Type.get, - np.dtype(np.float64): ir.F64Type.get, - np.dtype(np.complex64): lambda: ir.ComplexType.get(ir.F32Type.get()), - np.dtype(np.complex128): lambda: ir.ComplexType.get(ir.F64Type.get()), -} -def dtype_to_ir_type(dtype) -> ir.Type: - return _dtype_to_ir_type_factory[np.dtype(dtype)]() - - -def shape_dtype_to_ir_type(shape: Sequence[int], dtype) -> ir.Type: - return ir.RankedTensorType.get(shape, dtype_to_ir_type(dtype)) - - -# When we generate custom calls with dynamic shapes we have to pass -# both the result_types, with ir.ShapedType.get_dynamic_size in place of -# the dynamic dimensions, and also result_shapes, which are ir.Value -# representing 1D int32 tensors. If all the shapes are static we can use -# result_shapes=None. We first construct for each result a pair with the shape -# and element type, the shape containing either integer or ir.Value. -DimensionSize = Union[int, ir.Value] # an ir.Value if not static dimension -ShapeTypePair = tuple[Sequence[DimensionSize], ir.Type] - -def mk_result_types_and_shapes( - shape_type_pairs: Sequence[ShapeTypePair] -) -> tuple[list[ir.Type], list[ir.Value] | None]: - result_types: list[ir.Type] = [] - result_shapes: list[ir.Value] = [] - has_dynamic_shapes = any( - any(not isinstance(d, int) for d in rshape) - for rshape, _ in shape_type_pairs) - for (rshape, rtype) in shape_type_pairs: - if has_dynamic_shapes: - result_shapes.append(shape_tensor(rshape)) - result_types.append( - ir.RankedTensorType.get( - [d if isinstance(d, int) else ir.ShapedType.get_dynamic_size() - for d in rshape], - rtype)) - return (result_types, - result_shapes if has_dynamic_shapes else None) - -# TODO(necula): share this with mlir.shape_tensor -def shape_tensor(sizes: Sequence[int | ir.Value]) -> ir.Value: - int1d = shape_dtype_to_ir_type((1,), np.int32) - i32_type = shape_dtype_to_ir_type((), np.int32) - def dim_to_i32x1(d): - if type(d) is int: - return hlo_const(np.array([d], dtype=np.int32)) - else: - if d.type != i32_type: - d = hlo.convert(i32_type, d) - return hlo.reshape(int1d, d) - ds = [dim_to_i32x1(sz) for sz in sizes] - if not ds: - return hlo_const(np.array([], np.int32)) - elif len(ds) == 1: - return ds[0] - else: - return hlo.concatenate( - ds, ir.IntegerAttr.get(ir.IntegerType.get_signless(64), 0)) - -def hlo_const(x: np.ndarray) -> ir.Value: - assert isinstance(x, np.ndarray) - return hlo.constant( - ir.DenseElementsAttr.get(x, type=dtype_to_ir_type(x.dtype))) - -def hlo_u8(x: int): - return hlo_const(np.array(x, dtype=np.uint8)) -def hlo_s32(x: int): - return hlo_const(np.array(x, dtype=np.int32)) - -def ensure_hlo_s32(x: DimensionSize): - return hlo_s32(x) if isinstance(x, int) else x - -def dense_int_array(xs) -> ir.DenseI64ArrayAttr: - return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64)) - -def hlo_min(x: DimensionSize, y: DimensionSize) -> DimensionSize: - if type(x) is int: - if type(y) is int: - return min(x, y) - x = hlo_s32(x) - if type(y) is int: - y = hlo_s32(y) - return hlo.minimum(x, y) - - -def hlo_add(x: DimensionSize, y: DimensionSize) -> DimensionSize: - if type(x) is int: - if type(y) is int: - return x + y - x = hlo_s32(x) - if type(y) is int: - y = hlo_s32(y) - return hlo.add(x, y) - - -# TODO(necula): this is identical with mlir.custom_call, but meant for use -# in jaxlib. Find a way to share these implementations. -def custom_call( - call_target_name: str, - *, - result_types: Sequence[ir.Type], - operands: Sequence[ir.Value], - backend_config: str | bytes | dict[str, ir.Attribute] = "", - has_side_effect: bool = False, - result_shapes: Sequence[ir.Value] | None = None, - called_computations: Sequence[str] = (), - api_version: int = 2, - operand_output_aliases: dict[int, int] | None = None, - operand_layouts: Sequence[Sequence[int]] | None = None, - result_layouts: Sequence[Sequence[int]] | None = None, - extra_attributes: dict[str, ir.Attribute] | None = None, -) -> ir.Operation: - """Helper function for building an hlo.CustomCall. - - Args: - call_target_name: the name of the custom call target - result_types: the MLIR types of the results of the custom call - operands: the MLIR IR values that are arguments to the custom call - backend_config: an opaque string passed to the custom call kernel - has_side_effect: if True, marks the custom call as effectful - result_shapes: tensors that represent the result shapes, to be used when - the results have dynamic shapes. If not-None, its length must match the - number of the results. - called_computations: the list of function names called by the custom call. - api_version: the ABI contract version of the custom call - operand_output_aliases: a dict mapping operand numbers to outputs they alias - operand_layouts: a sequence of layouts (dimension orders) for each operand - result_layouts: a sequence of layouts (dimension orders) for each result - extra_attributes: additional IR attributes to apply to the custom_call. - """ - operands = list(operands) - - if backend_config is None: - backend_config_attr = ir.StringAttr.get("") - elif isinstance(backend_config, (str, bytes)): - backend_config_attr = ir.StringAttr.get(backend_config) - elif isinstance(backend_config, dict): - # TODO(necula): it seems that the CustomCallOp constructor requires that - # backend_config_attr be a string attribute, even though in some cases we - # need it to be a DictAttr, e.g., for ApproxTopK on TPU. - # "Verification failed: 'stablehlo.custom_call' op attribute 'backend_config' failed to satisfy constraint: string attribute" - # To workaround this limitation we first set it to the empty string and we - # use an unregistered attribute mhlo.backend_config to hold the DictAttr. - # We must also use api_version=1 to ensure that mhlo.backend_config is - # handled properly. - backend_config_attr = ir.StringAttr.get("") - api_version = 1 - else: - raise ValueError("custom_call backend_config unexpected type: " + str(backend_config)) - attributes = dict( - call_target_name=ir.StringAttr.get(call_target_name), - has_side_effect=ir.BoolAttr.get(has_side_effect), - backend_config=backend_config_attr, - api_version=ir.IntegerAttr.get( - ir.IntegerType.get_signless(32), api_version), - called_computations=ir.ArrayAttr.get( - [ir.FlatSymbolRefAttr.get(name) for name in called_computations] - ), - ) - if operand_output_aliases is not None: - attributes["output_operand_aliases"] = ir.ArrayAttr.get([ - hlo.OutputOperandAlias.get( - # if len(result_types) == 1 then the aliasing refers implicitly to - # the only output. - output_tuple_indices=[output_idx] if len(result_types) > 1 else [], - operand_index=input_idx, - operand_tuple_indices=[], - ) - for input_idx, output_idx in (operand_output_aliases.items() or ()) - ]) - - if extra_attributes is not None: - attributes.update(extra_attributes) - - if result_shapes is not None: - # We add the result_shapes at the end of the operands, and must pass - # the indices_of_output_operands attribute. This attribute is not yet - # accepted by the CustomCall constructor, so we use build_generic - attributes["indices_of_shape_operands"] = ir.DenseIntElementsAttr.get( - np.asarray(list(range(len(operands), len(operands) + len(result_shapes))), - dtype=np.int64)) - if operand_layouts is not None: - assert len(operand_layouts) == len(operands), (operand_layouts, operands) - operand_layouts = list(operand_layouts) + [(0,)] * len(result_shapes) - operands = list(operands) + list(result_shapes) - - if operand_layouts is not None: - attributes["operand_layouts"] = ir.ArrayAttr.get([ - ir.DenseIntElementsAttr.get( - np.atleast_1d(np.asarray(l, dtype=np.int64)), - type=ir.IndexType.get()) for l in operand_layouts - ]) - if result_layouts is not None: - assert result_layouts is not None - assert len(result_layouts) == len(result_types), ( - result_layouts, result_types) - attributes["result_layouts"] = ir.ArrayAttr.get([ - ir.DenseIntElementsAttr.get( - np.atleast_1d(np.asarray(l, dtype=np.int64)), - type=ir.IndexType.get()) for l in result_layouts - ]) - - op = hlo.CustomCallOp.build_generic(results=result_types, operands=operands, - attributes=attributes) - if isinstance(backend_config, dict): - backend_config_attr = ir.DictAttr.get(backend_config) - op.operation.attributes["mhlo.backend_config"] = backend_config_attr - return op diff --git a/jaxlib/ifrt_proxy.cc b/jaxlib/ifrt_proxy.cc new file mode 100644 index 000000000000..b15536ad2a72 --- /dev/null +++ b/jaxlib/ifrt_proxy.cc @@ -0,0 +1,158 @@ +// Copyright 2023 The JAX Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/log/log_entry.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/time/time.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/function.h" // IWYU pragma: keep +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/unordered_map.h" // IWYU pragma: keep +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/attribute_map.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt_proxy/client/registry.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/statusor.h" + +namespace nb = ::nanobind; + +namespace xla { +namespace ifrt { +namespace proxy { +namespace { + +struct PyClientConnectionOptions { + std::optional> on_disconnect; + std::optional> on_connection_update; + std::optional connection_timeout_in_seconds; + std::optional< + std::unordered_map>> + initialization_data; +}; + +absl::StatusOr> GetClient( + std::string proxy_server_address, + const PyClientConnectionOptions& py_options) { + DCHECK(PyGILState_Check()); + std::unique_ptr client; + + ClientConnectionOptions options; + if (py_options.on_disconnect) { + // While it is possible to pass around `py_options.on_disconnect` without + // wrapping it via a shared_ptr, copying the `py_options.on_disconnect` + // object can internally attempt to acquire the GIL [1], and can thus block + // or even deadlock. A unique_ptr or `absl::AnyInvocable` is not sufficient + // because downstream code can make copies. Reference: + // https://pybind11.readthedocs.io/en/stable/advanced/misc.html#common-sources-of-global-interpreter-lock-errors + auto py_on_disconnect = std::make_shared>( + std::move(*py_options.on_disconnect)); + + options.on_disconnect = + [on_disconnect = std::move(py_on_disconnect)](absl::Status s) mutable { + LOG(WARNING) << "Connection to server failed, calling supplied " + << "`on_disconnect` function: " << s; + tsl::Env::Default()->SchedClosure([s, on_disconnect]() mutable { + nb::gil_scoped_acquire gil_acquire; + (*on_disconnect)(s.ToString()); + on_disconnect = nullptr; + }); + }; + } + + if (py_options.on_connection_update) { + auto fn = std::make_shared>( + std::move(*py_options.on_connection_update)); + options.on_connection_update = [fn](std::string_view log_line) -> void { + tsl::Env::Default()->SchedClosure([fn, str = std::string(log_line)] { + nb::gil_scoped_acquire gil_acquire; + (*fn)(std::string(str)); + }); + }; + } + + if (py_options.connection_timeout_in_seconds.has_value()) { + options.connection_timeout = + absl::Seconds(*py_options.connection_timeout_in_seconds); + } + + if (py_options.initialization_data.has_value()) { + AttributeMap::Map attribute_map; + for (const auto& [key, py_value] : *py_options.initialization_data) { + if (std::holds_alternative(py_value)) { + nb::bytes value = std::get(py_value); + attribute_map.insert({key, AttributeMap::StringValue(std::string( + value.c_str(), value.size()))}); + } else if (std::holds_alternative(py_value)) { + attribute_map.insert( + {key, AttributeMap::BoolValue(std::get(py_value))}); + } else { + CHECK(std::holds_alternative(py_value)); + attribute_map.insert( + {key, AttributeMap::Int64Value(std::get(py_value))}); + } + } + options.initialization_data = AttributeMap(std::move(attribute_map)); + } + + { + nb::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN(client, CreateClient(proxy_server_address, options)); + } + + // Constructing `jax::PyClient` requires GIL as it may dec-ref Python objects. + return jax::PyClient::Make(std::move(client)); +} + +} // namespace + +NB_MODULE(_ifrt_proxy, m) { + nb::class_(m, "ClientConnectionOptions") + .def(nb::init<>()) + .def_rw("on_disconnect", &PyClientConnectionOptions::on_disconnect, + nb::arg().none()) + .def_rw("on_connection_update", + &PyClientConnectionOptions::on_connection_update, + nb::arg().none()) + .def_rw("connection_timeout_in_seconds", + &PyClientConnectionOptions::connection_timeout_in_seconds, + nb::arg().none()) + .def_rw("initialization_data", + &PyClientConnectionOptions::initialization_data, + nb::arg().none()); + + m.def("get_client", xla::ValueOrThrowWrapper(GetClient), + nb::arg("proxy_server_address"), nb::arg("options")); +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 89f1545995d5..b888785d98a6 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -20,10 +20,12 @@ load("@jax_wheel//:wheel.bzl", "WHEEL_VERSION") load("@jax_wheel_version_suffix//:wheel_version_suffix.bzl", "WHEEL_VERSION_SUFFIX") load("@local_config_cuda//cuda:build_defs.bzl", _cuda_library = "cuda_library", _if_cuda_is_configured = "if_cuda_is_configured") load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured", _rocm_library = "rocm_library") -load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION") +load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION", "HERMETIC_PYTHON_VERSION_KIND") load("@rules_cc//cc:defs.bzl", _cc_proto_library = "cc_proto_library") -load("@rules_python//python:defs.bzl", "py_test") -load("@xla//xla/tsl:tsl.bzl", _if_windows = "if_windows", _pybind_extension = "tsl_pybind_extension_opensource") +load("@rules_python//python:defs.bzl", "py_library", "py_test") +load("@test_shard_count//:test_shard_count.bzl", "USE_MINIMAL_SHARD_COUNT") +load("@xla//third_party/py:python_wheel.bzl", "collect_data_files", "transitive_py_deps") +load("@xla//xla/tsl:tsl.bzl", "transitive_hdrs", _if_windows = "if_windows", _pybind_extension = "tsl_pybind_extension_opensource") load("@xla//xla/tsl/platform:build_config_root.bzl", _tf_cuda_tests_tags = "tf_cuda_tests_tags", _tf_exec_properties = "tf_exec_properties") # Explicitly re-exports names to avoid "unused variable" warnings from .bzl @@ -31,7 +33,7 @@ load("@xla//xla/tsl/platform:build_config_root.bzl", _tf_cuda_tests_tags = "tf_c cc_proto_library = _cc_proto_library cuda_library = _cuda_library rocm_library = _rocm_library -pytype_test = native.py_test +proto_library = native.proto_library nanobind_extension = _pybind_extension if_cuda_is_configured = _if_cuda_is_configured if_rocm_is_configured = _if_rocm_is_configured @@ -42,13 +44,14 @@ tf_cuda_tests_tags = _tf_cuda_tests_tags jax_internal_packages = [] jax_extend_internal_users = [] +experimental_transfer_users = [] mosaic_gpu_internal_users = [] mosaic_internal_users = [] pallas_gpu_internal_users = [] -pallas_tpu_internal_users = [] +pallas_sc_internal_users = [] pallas_fuser_users = [] -mosaic_extension_deps = [] serialize_executable_internal_users = [] +buffer_callback_internal_users = [] jax_internal_export_back_compat_test_util_visibility = [] jax_internal_test_harnesses_visibility = [] @@ -56,42 +59,50 @@ jax_test_util_visibility = [] loops_visibility = [] PLATFORM_TAGS_DICT = { - ("Linux", "x86_64"): ("manylinux2014", "x86_64"), - ("Linux", "aarch64"): ("manylinux2014", "aarch64"), + ("Linux", "x86_64"): ("manylinux_2_27", "x86_64"), + ("Linux", "aarch64"): ("manylinux_2_27", "aarch64"), ("Linux", "ppc64le"): ("manylinux2014", "ppc64le"), ("Darwin", "x86_64"): ("macosx_11_0", "x86_64"), ("Darwin", "arm64"): ("macosx_11_0", "arm64"), ("Windows", "AMD64"): ("win", "amd64"), } -# TODO(vam): remove this once zstandard builds against Python 3.13 -def get_zstandard(): - if HERMETIC_PYTHON_VERSION == "3.13" or HERMETIC_PYTHON_VERSION == "3.13-ft": +def get_optional_dep(package, excluded_py_versions = ["3.14", "3.14-ft"]): + py_ver = HERMETIC_PYTHON_VERSION + if HERMETIC_PYTHON_VERSION_KIND == "ft": + py_ver += "-ft" + if py_ver in excluded_py_versions: return [] - return ["@pypi_zstandard//:pkg"] + return [package] _py_deps = { - "absl/logging": ["@pypi_absl_py//:pkg"], - "absl/testing": ["@pypi_absl_py//:pkg"], - "absl/flags": ["@pypi_absl_py//:pkg"], - "cloudpickle": ["@pypi_cloudpickle//:pkg"], - "colorama": ["@pypi_colorama//:pkg"], - "epath": ["@pypi_etils//:pkg"], # etils.epath - "filelock": ["@pypi_filelock//:pkg"], - "flatbuffers": ["@pypi_flatbuffers//:pkg"], - "hypothesis": ["@pypi_hypothesis//:pkg"], + "absl-all": ["@pypi//absl_py"], + "absl/logging": ["@pypi//absl_py"], + "absl/testing": ["@pypi//absl_py"], + "absl/testing:flagsaver": ["@pypi//absl_py"], + "absl/flags": ["@pypi//absl_py"], + "cloudpickle": get_optional_dep("@pypi//cloudpickle"), + "disable_pmap_shmap_merge": [], + "epath": get_optional_dep("@pypi//etils"), # etils.epath + "filelock": get_optional_dep("@pypi//filelock"), + "flatbuffers": ["@pypi//flatbuffers"], + "hypothesis": ["@pypi//hypothesis"], "magma": [], - "matplotlib": ["@pypi_matplotlib//:pkg"], + "matplotlib": get_optional_dep("@pypi//matplotlib"), "mpmath": [], - "opt_einsum": ["@pypi_opt_einsum//:pkg"], - "pil": ["@pypi_pillow//:pkg"], - "portpicker": ["@pypi_portpicker//:pkg"], - "ml_dtypes": ["@pypi_ml_dtypes//:pkg"], - "numpy": ["@pypi_numpy//:pkg"], - "scipy": ["@pypi_scipy//:pkg"], + "opt_einsum": ["@pypi//opt_einsum"], + "pil": get_optional_dep("@pypi//pillow"), + "portpicker": ["@pypi//portpicker"], + "ml_dtypes": ["@pypi//ml_dtypes"], + "numpy": ["@pypi//numpy"], + "scipy": ["@pypi//scipy"], "tensorflow_core": [], + "tensorstore": get_optional_dep("@pypi//tensorstore"), "torch": [], - "zstandard": get_zstandard(), + "tensorflow": get_optional_dep("@pypi//tensorflow", ["3.13-ft", "3.14", "3.14-ft"]), + "tpu_ops": [], + # TODO(vam): remove this once zstandard builds against Python >3.13 + "zstandard": get_optional_dep("@pypi//zstandard", ["3.13", "3.13-ft", "3.14", "3.14-ft"]), } def all_py_deps(excluded = []): @@ -115,9 +126,10 @@ def py_deps(_package): def jax_visibility(_target): """Returns the additional Bazel visibilities for `target`.""" - - # This is only useful as part of a larger Bazel repository. - return [] + return [ + "//jax:__subpackages__", + "//jaxlib:__subpackages__", + ] jax_extra_deps = [] jax_gpu_support_deps = [] @@ -125,134 +137,106 @@ jax2tf_deps = [] def pytype_library(name, pytype_srcs = None, **kwargs): _ = pytype_srcs # @unused - native.py_library(name = name, **kwargs) + kwargs.pop("lazy_imports", None) + py_library(name = name, **kwargs) def pytype_strict_library(name, pytype_srcs = [], **kwargs): data = pytype_srcs + (kwargs["data"] if "data" in kwargs else []) new_kwargs = {k: v for k, v in kwargs.items() if k != "data"} - native.py_library(name = name, data = data, **new_kwargs) + new_kwargs.pop("lazy_imports", None) + py_library(name = name, data = data, **new_kwargs) + +py_strict_library = py_library +py_strict_test = py_test -def py_library_providing_imports_info(*, name, lib_rule = native.py_library, pytype_srcs = [], **kwargs): +def py_library_providing_imports_info(*, name, lib_rule = py_library, pytype_srcs = [], **kwargs): data = pytype_srcs + (kwargs["data"] if "data" in kwargs else []) new_kwargs = {k: v for k, v in kwargs.items() if k != "data"} + new_kwargs.pop("lazy_imports", None) lib_rule(name = name, data = data, **new_kwargs) def py_extension(name, srcs, copts, deps, linkopts = []): nanobind_extension(name, srcs = srcs, copts = copts, linkopts = linkopts, deps = deps, module_name = name) -def windows_cc_shared_mlir_library(name, out, deps = [], srcs = [], exported_symbol_prefixes = []): - """Workaround DLL building issue. - - 1. cc_binary with linkshared enabled cannot produce DLL with symbol - correctly exported. - 2. Even if the DLL is correctly built, the resulting target cannot be - correctly consumed by other targets. - - Args: - name: the name of the output target - out: the name of the output DLL filename - deps: deps - srcs: srcs - """ - - # create a dummy library to get the *.def file - dummy_library_name = name + ".dummy.dll" - native.cc_binary( - name = dummy_library_name, - linkshared = 1, - linkstatic = 1, - deps = deps, - target_compatible_with = ["@platforms//os:windows"], - ) - - # .def file with all symbols, not usable - full_def_name = name + ".full.def" - native.filegroup( - name = full_def_name, - srcs = [dummy_library_name], - output_group = "def_file", - target_compatible_with = ["@platforms//os:windows"], - ) - - # say filtered_symbol_prefixes == ["mlir", "chlo"], then construct the regex - # pattern as "^\\s*(mlir|clho)" to use grep - pattern = "^\\s*(" + "|".join(exported_symbol_prefixes) + ")" - - # filtered def_file, only the needed symbols are included - filtered_def_name = name + ".filtered.def" - filtered_def_file = out + ".def" - native.genrule( - name = filtered_def_name, - srcs = [full_def_name], - outs = [filtered_def_file], - cmd = """echo 'LIBRARY {}\nEXPORTS ' > $@ && grep -E '{}' $(location :{}) >> $@""".format(out, pattern, full_def_name), - target_compatible_with = ["@platforms//os:windows"], - ) - - # create the desired library - native.cc_binary( - name = out, # this name must be correct, it will be the filename - linkshared = 1, - deps = deps, - win_def_file = filtered_def_file, - target_compatible_with = ["@platforms//os:windows"], - ) - - # however, the created cc_library (a shared library) cannot be correctly - # consumed by other cc_*... - interface_library_file = out + ".if.lib" - native.filegroup( - name = interface_library_file, - srcs = [out], - output_group = "interface_library", - target_compatible_with = ["@platforms//os:windows"], - ) - - # but this one can be correctly consumed, this is our final product - native.cc_import( - name = name, - interface_library = interface_library_file, - shared_library = out, - target_compatible_with = ["@platforms//os:windows"], - ) - ALL_BACKENDS = ["cpu", "gpu", "tpu"] +TEST_SUITE_SUFFIX = "_tests" +BACKEND_INDEPENDENT_TESTS = "backend_independent_tests" def if_building_jaxlib( if_building, if_not_building = [ - "@pypi_jaxlib//:pkg", - "@pypi_jax_cuda12_plugin//:pkg", - "@pypi_jax_cuda12_pjrt//:pkg", - ], - if_not_building_for_cpu = ["@pypi_jaxlib//:pkg"], - if_py_import = [ - "//jaxlib/tools:jaxlib_py_import", - "//jaxlib/tools:jax_cuda_plugin_py_import", - "//jaxlib/tools:jax_cuda_pjrt_py_import", - ], - if_py_import_for_cpu = [ - "//jaxlib/tools:jaxlib_py_import", + "@pypi//jaxlib", ]): - """Adds jaxlib and jaxlib cuda plugin wheels as dependencies instead of depending on sources. + """Adds jaxlib wheels as dependencies instead of depending on sources. This allows us to test prebuilt versions of jaxlib wheels against the rest of the JAX codebase. Args: if_building: the source code targets to depend on in case we don't depend on the jaxlib wheels - if_not_building: the jaxlib wheels to depend on including gpu-specific plugins in case of - gpu-enabled builds - if_not_building_for_cpu: the jaxlib wheels to depend on in case of cpu-only builds - if_py_import: the py_import targets to depend on in case of gpu-enabled builds - if_py_import_for_cpu: the py_import targets to depend on in case of cpu-only builds + if_not_building: the wheels to depend on if we are not depending directly on //jaxlib. """ + return select({ + "//jax:config_build_jaxlib_true": if_building, + "//jax:config_build_jaxlib_false": if_not_building, + "//jax:config_build_jaxlib_wheel": [], + }) +def _cpu_test_deps(): + """Returns the test dependencies needed for a CPU-only JAX test.""" return select({ - "//jax:enable_jaxlib_build": if_building, - "//jax_plugins/cuda:disable_jaxlib_for_cpu_build": if_not_building_for_cpu, - "//jax_plugins/cuda:disable_jaxlib_for_cuda12_build": if_not_building, - "//jax_plugins/cuda:enable_py_import_for_cpu_build": if_py_import_for_cpu, - "//jax_plugins/cuda:enable_py_import_for_cuda12_build": if_py_import, + "//jax:config_build_jaxlib_true": [], + "//jax:config_build_jaxlib_false": ["@pypi//jaxlib"], + "//jax:config_build_jaxlib_wheel": ["//jaxlib/tools:jaxlib_py_import"], + }) + +def _gpu_test_deps(): + """Returns the additional dependencies needed for a GPU test.""" + return select({ + "//jax:config_build_jaxlib_true": [ + "//jaxlib/cuda:gpu_only_test_deps", + "//jaxlib/rocm:gpu_only_test_deps", + "//jax_plugins:gpu_plugin_only_test_deps", + ], + "//jax:config_build_jaxlib_false": [ + "//jaxlib/tools:pypi_jax_cuda_plugin_with_cuda_deps", + "//jaxlib/tools:pypi_jax_cuda_pjrt_with_cuda_deps", + ], + "//jax:config_build_jaxlib_wheel": [ + "//jaxlib/tools:jax_cuda_plugin_py_import", + "//jaxlib/tools:jax_cuda_pjrt_py_import", + ], + }) + +def _get_jax_test_deps(deps): + """Returns the jax build deps, pypi jax wheel dep, or jax py_import dep for the given backend. + + Args: + deps: the full list of test dependencies + + Returns: + A list of jax test deps. + + If --//jax:build_jax=true, returns jax build deps. + If --//jax:build_jax=false, returns jax pypi wheel dep and transitive pypi test deps. + If --//jax:build_jax=wheel, returns jax py_import dep and transitive pypi test deps. + """ + non_pypi_deps = [d for d in deps if not d.startswith("@pypi//")] + + # A lot of tests don't have explicit dependencies on scipy, ml_dtypes, etc. But the tests + # transitively depends on them via //jax. So we need to make sure that these dependencies are + # included in the test when JAX is built from source. + pypi_deps = depset([d for d in deps if d.startswith("@pypi//")]) + pypi_deps = depset(py_deps([ + "ml_dtypes", + "scipy", + "opt_einsum", + "flatbuffers", + ]), transitive = [pypi_deps]).to_list() + + return pypi_deps + select({ + "//jax:config_build_jax_false": ["//:jax_wheel_with_internal_test_util"], + "//jax:config_build_jax_wheel": ["//:jax_py_import"], + "//jax:config_build_jax_true": non_pypi_deps, }) # buildifier: disable=function-docstring @@ -262,16 +246,18 @@ def jax_multiplatform_test( args = [], env = {}, shard_count = None, + minimal_shard_count = None, deps = [], data = [], enable_backends = None, - backend_variant_args = {}, # buildifier: disable=unused-variable + backend_variant_args = {}, backend_tags = {}, # buildifier: disable=unused-variable disable_configs = None, # buildifier: disable=unused-variable enable_configs = [], config_tags_overrides = None, # buildifier: disable=unused-variable tags = [], main = None, + size = None, # buildifier: disable=unused-variable pjrt_c_api_bypass = False): # buildifier: disable=unused-variable # enable_configs and disable_configs do not do anything in OSS, only in Google's CI. # The order in which `enable_backends`, `enable_configs`, and `disable_configs` are applied is @@ -286,40 +272,65 @@ def jax_multiplatform_test( else: fail("Must set a main file to test multiple source files.") + env = dict(env) + env.setdefault("PYTHONWARNINGS", "error") + for backend in ALL_BACKENDS: - if shard_count == None or type(shard_count) == type(0): - test_shards = shard_count + test_shard_count = minimal_shard_count if USE_MINIMAL_SHARD_COUNT else shard_count + if test_shard_count == None or type(test_shard_count) == type(0): + test_shards = test_shard_count else: - test_shards = shard_count.get(backend, 1) + test_shards = test_shard_count.get(backend, 1) test_args = list(args) + [ "--jax_test_dut=" + backend, "--jax_platform_name=" + backend, ] + test_args += backend_variant_args.get(backend, []) test_tags = list(tags) + ["jax_test_%s" % backend] + backend_tags.get(backend, []) if enable_backends != None and backend not in enable_backends and not any([config.startswith(backend) for config in enable_configs]): test_tags.append("manual") + test_deps = _cpu_test_deps() + _get_jax_test_deps([ + "//jax", + "//jax/_src:test_util", + ] + deps) if backend == "gpu": + test_deps += _gpu_test_deps() test_tags += tf_cuda_tests_tags() - native.py_test( + elif backend == "tpu": + test_deps += ["@pypi//libtpu"] + py_test( name = name + "_" + backend, srcs = srcs, args = test_args, env = env, - deps = [ - "//jax", - "//jax:test_util", - ] + deps + if_building_jaxlib([ - "//jaxlib/cuda:gpu_only_test_deps", - "//jaxlib/rocm:gpu_only_test_deps", - "//jax_plugins:gpu_plugin_only_test_deps", - ]), + deps = test_deps, data = data, shard_count = test_shards, tags = test_tags, main = main, exec_properties = tf_exec_properties({"tags": test_tags}), + visibility = jax_visibility(name), ) +def get_test_suite_list(paths, backends = []): + """Returns a list of test suite targets for the given paths and backends. + + Args: + paths: the paths to the test suites. + backends: the set of backends for which rules should be generated. Defaults to all backends. + + Returns: + A list of backend specific test suite targets. + """ + test_suite_list = [] + if not backends: + backends = ALL_BACKENDS + for path in paths: + for backend in backends: + test_suite_list.append("//{}:{}{}".format(path, backend, TEST_SUITE_SUFFIX)) + test_suite_list.append("//{}:{}".format(path, BACKEND_INDEPENDENT_TESTS)) + return test_suite_list + def jax_generate_backend_suites(backends = []): """Generates test suite targets named cpu_tests, gpu_tests, etc. @@ -330,12 +341,14 @@ def jax_generate_backend_suites(backends = []): backends = ALL_BACKENDS for backend in backends: native.test_suite( - name = "%s_tests" % backend, + name = backend + TEST_SUITE_SUFFIX, tags = ["jax_test_%s" % backend, "-manual"], + visibility = jax_visibility(backend + TEST_SUITE_SUFFIX), ) native.test_suite( - name = "backend_independent_tests", + name = BACKEND_INDEPENDENT_TESTS, tags = ["-jax_test_%s" % backend for backend in backends] + ["-manual"], + visibility = jax_visibility(BACKEND_INDEPENDENT_TESTS), ) def _get_full_wheel_name( @@ -362,7 +375,7 @@ def _get_full_wheel_name( free_threaded_suffix = "t" if py_freethreaded.lower() == "yes" else "", ) -def _get_source_distribution_name(package_name, wheel_version): +def _get_source_package_name(package_name, wheel_version): return "{package_name}-{wheel_version}.tar.gz".format( package_name = package_name, wheel_version = wheel_version, @@ -394,37 +407,47 @@ def _jax_wheel_impl(ctx): no_abi = ctx.attr.no_abi platform_independent = ctx.attr.platform_independent build_wheel_only = ctx.attr.build_wheel_only + build_source_package_only = ctx.attr.build_source_package_only editable = ctx.attr.editable platform_name = ctx.attr.platform_name + + output_dir_path = "" + outputs = [] if editable: output_dir = ctx.actions.declare_directory(output_path + "/" + ctx.attr.wheel_name) - wheel_dir = output_dir.path + output_dir_path = output_dir.path outputs = [output_dir] args.add("--editable") else: - wheel_name = _get_full_wheel_name( - package_name = ctx.attr.wheel_name, - no_abi = no_abi, - platform_independent = platform_independent, - platform_name = platform_name, - cpu_name = cpu, - wheel_version = full_wheel_version, - py_freethreaded = py_freethreaded, - ) - wheel_file = ctx.actions.declare_file(output_path + - "/" + wheel_name) - wheel_dir = wheel_file.path[:wheel_file.path.rfind("/")] - outputs = [wheel_file] - if not build_wheel_only: - source_distribution_name = _get_source_distribution_name( + if build_wheel_only: + wheel_name = _get_full_wheel_name( package_name = ctx.attr.wheel_name, + no_abi = no_abi, + platform_independent = platform_independent, + platform_name = platform_name, + cpu_name = cpu, wheel_version = full_wheel_version, + py_freethreaded = py_freethreaded, ) - source_distribution_file = ctx.actions.declare_file(output_path + - "/" + source_distribution_name) - outputs.append(source_distribution_file) - - args.add("--output_path", wheel_dir) # required argument + wheel_file = ctx.actions.declare_file(output_path + + "/" + wheel_name) + output_dir_path = wheel_file.path[:wheel_file.path.rfind("/")] + outputs = [wheel_file] + if ctx.attr.wheel_name == "jax": + args.add("--build-wheel-only", "True") + if build_source_package_only: + source_package_name = _get_source_package_name( + package_name = ctx.attr.wheel_name, + wheel_version = full_wheel_version, + ) + source_package_file = ctx.actions.declare_file(output_path + + "/" + source_package_name) + output_dir_path = source_package_file.path[:source_package_file.path.rfind("/")] + outputs = [source_package_file] + if ctx.attr.wheel_name == "jax": + args.add("--build-source-package-only", "True") + + args.add("--output_path", output_dir_path) # required argument if not platform_independent: args.add("--cpu", cpu) args.add("--jaxlib_git_hash", git_hash) # required argument @@ -464,24 +487,24 @@ def _jax_wheel_impl(ctx): _jax_wheel = rule( attrs = { "wheel_binary": attr.label( - default = Label("//jaxlib/tools:build_wheel"), + default = Label("//jaxlib/tools:build_wheel_tool"), executable = True, - # b/365588895 Investigate cfg = "exec" for multi platform builds - cfg = "target", + cfg = "exec", ), "wheel_name": attr.string(mandatory = True), "no_abi": attr.bool(default = False), "platform_independent": attr.bool(default = False), - "build_wheel_only": attr.bool(default = True), + "build_wheel_only": attr.bool(mandatory = True, default = True), + "build_source_package_only": attr.bool(mandatory = True, default = False), "editable": attr.bool(default = False), - "cpu": attr.string(mandatory = True), - "platform_name": attr.string(mandatory = True), + "cpu": attr.string(), + "platform_name": attr.string(), "git_hash": attr.label(default = Label("//jaxlib/tools:jaxlib_git_hash")), "source_files": attr.label_list(allow_files = True), "output_path": attr.label(default = Label("//jaxlib/tools:output_path")), "enable_cuda": attr.bool(default = False), # A cuda/rocm version is required for gpu wheels; for cpu wheels, it can be an empty string. - "platform_version": attr.string(mandatory = True, default = ""), + "platform_version": attr.string(), "skip_gpu_kernels": attr.bool(default = False), "enable_rocm": attr.bool(default = False), "include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:include_cuda_libs")), @@ -498,7 +521,6 @@ def jax_wheel( wheel_name, no_abi = False, platform_independent = False, - build_wheel_only = True, editable = False, enable_cuda = False, enable_rocm = False, @@ -509,11 +531,10 @@ def jax_wheel( Common artifact attributes are grouped within a single macro. Args: - name: the name of the wheel + name: the target name wheel_binary: the binary to use to build the wheel wheel_name: the name of the wheel no_abi: whether to build a wheel without ABI - build_wheel_only: whether to build a wheel without source distribution editable: whether to build an editable wheel platform_independent: whether to build a wheel without platform tag enable_cuda: whether to build a cuda wheel @@ -522,7 +543,7 @@ def jax_wheel( source_files: the source files to include in the wheel Returns: - A directory containing the wheel + A wheel file or a wheel directory. """ _jax_wheel( name = name, @@ -530,7 +551,8 @@ def jax_wheel( wheel_name = wheel_name, no_abi = no_abi, platform_independent = platform_independent, - build_wheel_only = build_wheel_only, + build_wheel_only = True, + build_source_package_only = False, editable = editable, enable_cuda = enable_cuda, enable_rocm = enable_rocm, @@ -554,6 +576,34 @@ def jax_wheel( source_files = source_files, ) +def jax_source_package( + name, + source_package_binary, + source_package_name, + source_files = []): + """Create jax source package. + + Common artifact attributes are grouped within a single macro. + + Args: + name: the target name + source_package_binary: the binary to use to build the package + source_package_name: the name of the source package + source_files: the source files to include in the package + + Returns: + A jax source package file. + """ + _jax_wheel( + name = name, + wheel_binary = source_package_binary, + wheel_name = source_package_name, + build_source_package_only = True, + build_wheel_only = False, + platform_independent = True, + source_files = source_files, + ) + jax_test_file_visibility = [] jax_export_file_visibility = [] @@ -566,6 +616,304 @@ def jax_py_test( env = {}, **kwargs): env = dict(env) - if "PYTHONWARNINGS" not in env: - env["PYTHONWARNINGS"] = "error" + env.setdefault("PYTHONWARNINGS", "error") + deps = kwargs.get("deps", []) + test_deps = _cpu_test_deps() + _get_jax_test_deps(deps) + kwargs["deps"] = test_deps py_test(name = name, env = env, **kwargs) + +def pytype_test(name, **kwargs): + deps = kwargs.get("deps", []) + test_deps = _cpu_test_deps() + _get_jax_test_deps(deps) + kwargs["deps"] = test_deps + py_test(name = name, **kwargs) + +def if_oss(oss_value, google_value = []): + """Returns one of the arguments based on the non-configurable build env. + + Specifically, it does not return a `select`, and can be used to e.g. + compute elements of list attributes. + """ + _ = (google_value, oss_value) # buildifier: disable=unused-variable + return oss_value + +def wheel_sources( + name, + py_srcs = [], + data_srcs = [], + symlink_data_srcs = [], + hdr_srcs = [], + static_srcs = []): + """Create a filegroup containing the list of source files for a wheel. + + The sources are collected from the static files and from the transitive dependencies of the + given srcs. + + Args: + name: the target name + py_srcs: targets which transitive python dependencies should be included in the wheel + data_srcs: targets which platform-dependent data dependencies should be included in the wheel + symlink_data_srcs: targets which symlinked data dependencies should be included in the wheel + hdr_srcs: targets which transitive header dependencies should be included in the wheel + static_srcs: the platform-independent file dependencies of the wheel + """ + transitive_py_deps(name = "{}_py".format(name), deps = py_srcs) + collect_data_files( + name = "{}_data".format(name), + deps = data_srcs, + symlink_deps = symlink_data_srcs, + ) + transitive_hdrs(name = "{}_hdrs".format(name), deps = hdr_srcs) + native.filegroup( + name = name, + srcs = [ + ":{}_py".format(name), + ":{}_data".format(name), + ":{}_hdrs".format(name), + ] + static_srcs, + visibility = jax_visibility(name), + ) + +def if_pypi_cuda_wheel_deps(if_true, if_false = []): + """ select() on whether we're adding pypi CUDA wheel deps. """ + return select({ + "//jaxlib/tools:pypi_cuda_wheel_deps": if_true, + "//conditions:default": if_false, + }) + +def jax_multiprocess_test( + name, + srcs, + args = [], + env = {}, + shard_count = None, + minimal_shard_count = None, + deps = [], + data = [], + enable_backends = None, + backend_variant_args = {}, + backend_tags = {}, + disable_configs = None, + enable_configs = [], + config_tags_overrides = None, + tags = [], + main = None): + # TODO(emilyaf): Avoid hard-coding the number of processes and chips/gpus per process. + multiprocess_backend_args = { + "cpu": backend_variant_args.get("cpu", []) + [ + "--num_processes=4", + ], + "gpu": backend_variant_args.get("gpu", []) + [ + "--num_processes=4", + "--gpus_per_process=2", + ], + "tpu": backend_variant_args.get("tpu", []) + [ + "--num_processes=4", + "--tpu_chips_per_process=1", + ], + } + tags = tags + ["multiaccelerator"] + deps = deps + py_deps(["absl-all", "portpicker"]) + return jax_multiplatform_test( + name = name, + srcs = srcs, + args = args, + env = env, + shard_count = shard_count, + minimal_shard_count = minimal_shard_count, + deps = deps, + data = data, + enable_backends = enable_backends, + backend_variant_args = multiprocess_backend_args, + backend_tags = backend_tags, + disable_configs = disable_configs, + enable_configs = enable_configs, + config_tags_overrides = config_tags_overrides, + tags = tags, + main = main, + ) + +def jax_multiprocess_generate_backend_suites(name = None, backends = []): + return jax_generate_backend_suites(backends = backends) + +TransitiveSrcsInfo = provider( + "Provider to collect transitive source files and dependencies", + fields = {"files": "depset of files"}, +) + +def _collect_transitive_srcs_aspect_impl(_, ctx): + files = [] + attrs_to_traverse = ["srcs", "deps"] + if ctx.rule.kind == "test_suite": + attrs_to_traverse.append("_implicit_tests") + is_test_rule = ctx.rule.kind in ["py_test", "test_suite"] + + for attr in attrs_to_traverse: + if not hasattr(ctx.rule.attr, attr): + continue + attr_val = getattr(ctx.rule.attr, attr) + if type(attr_val) != "list": + continue + + for dep in attr_val: + transitive_sources = {} + if not PyInfo in dep: + continue + if is_test_rule: + source_files = [f for f in dep[DefaultInfo].files.to_list() if f.is_source] + else: + source_files = [] + for ts in dep[PyInfo].transitive_sources.to_list(): + if (not ts.owner.package or ts in source_files): + # Skip test source files and files in @pypi wheels + continue + else: + transitive_sources[ts] = True + files.append(depset(transitive_sources.keys())) + return [TransitiveSrcsInfo(files = depset(transitive = files))] + +collect_srcs_aspect = aspect( + implementation = _collect_transitive_srcs_aspect_impl, + attr_aspects = ["srcs", "deps"], +) + +collect_test_deps_aspect = aspect( + implementation = _collect_transitive_srcs_aspect_impl, + attr_aspects = ["tests"], +) + +def _compare_srcs_and_test_deps_test_impl(ctx): + build_jaxlib = ctx.attr.build_jaxlib[BuildSettingInfo].value + build_jax = ctx.attr.build_jax[BuildSettingInfo].value + message = "PASSED: All test dependencies are present in the wheel." + test_result = 0 + doc_link = "https://github.com/jax-ml/jax/blob/main/docs/contributing.md#wheel-sources-update" + + if build_jax == "true" and build_jaxlib == "true": + srcs_list = [] + for src in ctx.attr.srcs: + if TransitiveSrcsInfo in src: + srcs_list.append(src[TransitiveSrcsInfo].files) + + srcs_depset = depset(transitive = srcs_list) + srcs_map = { + f.short_path: True + for f in srcs_depset.to_list() + ctx.files.srcs + } + + test_dependencies_list = [] + for test in ctx.attr.tests: + if TransitiveSrcsInfo in test: + test_dependencies_list.append(test[TransitiveSrcsInfo].files) + + test_dependencies_depset = depset(transitive = test_dependencies_list) + test_dependencies_map = {} + + # We need to add __init__.py files for all python modules to make them available via API. + for f in test_dependencies_depset.to_list(): + test_dependencies_map[f.short_path] = True + init_py_path = f.short_path.replace(f.basename, "__init__.py") + for root_folder_name in ctx.attr.root_package_names: + if (f.short_path.startswith(root_folder_name) and + init_py_path not in ctx.attr.ignored_init_py_files): + test_dependencies_map[init_py_path] = True + break + + test_dependencies_paths = [k for k in test_dependencies_map.keys()] + srcs_paths = srcs_map.keys() + + if srcs_paths != test_dependencies_paths: + missing_in_srcs = sorted([ + p + for p in test_dependencies_paths + if p not in srcs_map + ]) + + if missing_in_srcs: + message = ("FAILED: Files in test dependencies not found in sources: %s" % + missing_in_srcs + "\n" + + "See instructions in %s" % doc_link) + test_result = 1 + + else: + message = "SKIPPED: The test will be executed only with //jax:build_jax=true and //jax:build_jaxlib=true." + test_result = 0 + + if ctx.attr.platform_name == "Windows": + script_content = """@ECHO OFF +ECHO {message} +EXIT /B {test_result}""".format(message = message, test_result = test_result) + script_filename = ctx.label.name + "_runner.bat" + else: + script_content = """#!/bin/bash +echo "{message}" +exit {test_result}""".format(message = message, test_result = test_result) + script_filename = ctx.label.name + "_runner.sh" + + test_runner_script = ctx.actions.declare_file(script_filename) + ctx.actions.write( + output = test_runner_script, + content = script_content, + is_executable = True, + ) + + runfiles = ctx.runfiles(files = []) + + return [ + DefaultInfo( + executable = test_runner_script, + runfiles = runfiles, + ), + ] + +_compare_srcs_and_test_deps_test = rule( + implementation = _compare_srcs_and_test_deps_test_impl, + attrs = { + "srcs": attr.label_list( + allow_files = True, + mandatory = True, + aspects = [collect_srcs_aspect], + ), + "tests": attr.label_list( + allow_empty = False, + mandatory = True, + aspects = [collect_test_deps_aspect], + ), + "ignored_init_py_files": attr.string_list( + mandatory = True, + ), + "build_jaxlib": attr.label(default = Label("//jax:build_jaxlib")), + "build_jax": attr.label(default = Label("//jax:build_jax")), + "root_package_names": attr.string_list(mandatory = True), + "platform_name": attr.string(mandatory = True), + }, + test = True, +) + +def compare_srcs_and_test_deps_test(name, srcs, tests, ignored_init_py_files, root_package_names, tags = []): + """Compares the source files against the test dependencies. + + Args: + srcs: The source files to compare. + tests: The test dependencies to compare. + ignored_init_py_files: The init python files to ignore. + build_jaxlib: The build setting for jaxlib. + build_jax: The build setting for jax. + root_package_names: The root folder names to compare. + tags: The tags to apply to the test. + """ + _compare_srcs_and_test_deps_test( + name = name, + srcs = srcs, + tests = tests, + ignored_init_py_files = ignored_init_py_files, + root_package_names = root_package_names, + platform_name = select({ + "@platforms//os:osx": "Darwin", + "@platforms//os:macos": "Darwin", + "@platforms//os:windows": "Windows", + "@platforms//os:linux": "Linux", + }), + tags = tags, + testonly = True, + ) diff --git a/jaxlib/jax.cc b/jaxlib/jax.cc new file mode 100644 index 000000000000..2062d98d11e7 --- /dev/null +++ b/jaxlib/jax.cc @@ -0,0 +1,971 @@ +/* Copyright 2019 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/casts.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "nanobind/nanobind.h" +#include "nanobind/nb_defs.h" +#include "nanobind/stl/function.h" // IWYU pragma: keep +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/set.h" // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/unordered_map.h" // IWYU pragma: keep +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/ffi.h" +#include "jaxlib/py_client.h" +#include "jaxlib/py_program.h" +#include "jaxlib/py_values.h" +#include "xla/backends/cpu/collectives/cpu_collectives.h" +#include "xla/pjrt/c/pjrt_c_api.h" +#include "xla/pjrt/cpu/cpu_client.h" +#include "xla/pjrt/distributed/client.h" +#include "xla/pjrt/distributed/distributed.h" +#include "xla/pjrt/distributed/protocol.pb.h" +#include "xla/pjrt/distributed/service.h" +#include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/plugin/xla_cpu/cpu_client_options.h" +#include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/attribute_map.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/topology.h" +#include "xla/python/pjrt_ifrt/pjrt_attribute_map_util.h" +#include "xla/python/pjrt_ifrt/transfer_server_interface.h" +#include "xla/python/version.h" +#include "xla/tsl/python/lib/core/numpy.h" // NOLINT + +#if defined(__linux__) +#include "gloo/transport/tcp/attr.h" +#include "gloo/transport/tcp/device.h" +#include "jaxlib/py_socket_transfer.h" +#include "xla/backends/cpu/collectives/gloo_collectives.h" +#include "xla/backends/cpu/collectives/gloo_kv_store.h" +#elif defined(__APPLE__) +#include "gloo/transport/uv/device.h" +#include "xla/backends/cpu/collectives/gloo_collectives.h" // NOLINT +#include "xla/backends/cpu/collectives/gloo_kv_store.h" // NOLINT +#endif // defined(__linux__) + +#if !defined(_WIN32) && !defined(PLATFORM_GOOGLE) +#include "xla/backends/cpu/collectives/mpi_collectives.h" +#endif // !_WIN32 && !PLATFORM_GOOGLE + +#include "jaxlib/call_location.h" +#include "jaxlib/config.h" +#include "jaxlib/custom_call_sharding.h" +#include "jaxlib/dlpack.h" +#include "jaxlib/guard_lib.h" +#include "jaxlib/jax_jit.h" +#include "jaxlib/mlir.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/partition_spec.h" +#include "jaxlib/pjit.h" +#include "jaxlib/pmap_lib.h" +#include "jaxlib/pprof_profile_builder.h" +#include "jaxlib/py_array.h" +#include "jaxlib/py_compile_only_client.h" +#include "jaxlib/py_device.h" +#include "jaxlib/py_device_list.h" +#include "jaxlib/py_executable.h" +#include "jaxlib/py_memory_space.h" +#include "jaxlib/python_ref_manager.h" +#include "jaxlib/pytree.h" +#include "jaxlib/sharding.h" +#include "jaxlib/traceback.h" +#include "jaxlib/xla_compiler.h" +#include "xla/hlo/builder/lib/approx_topk_shape.h" +#include "xla/pjrt/c_api_client/pjrt_c_api_client.h" +#include "xla/pjrt/distributed/key_value_store_interface.h" +#include "xla/pjrt/distributed/preemption/preemption_sync_manager.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/pjrt_api.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_common.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/python/logging.h" // IWYU pragma: keep +#include "xla/python/nb_absl_flat_hash_map.h" // IWYU pragma: keep +#include "xla/python/nb_absl_span.h" // IWYU pragma: keep +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/python/pjrt_ifrt/pjrt_topology.h" +#include "xla/tsl/distributed_runtime/preemption/preemption_sync_manager.h" +#include "xla/tsl/platform/status.h" +#include "tsl/platform/platform.h" + +// TODO(phawkins): remove host_id properties after JAX is update to avoid them. + +namespace nb = nanobind; + +namespace jax { +namespace { + +bool IsOptimizedBuild() { +#if NDEBUG + return true; +#else + return false; +#endif // NDEBUG +} + +// Is*san reports whether the build is under that particular sanitizer. +bool IsAsan() { +#if defined(__SANITIZE_ADDRESS__) + return true; // GCC and newer MSVC +#elif defined(__has_feature) && __has_feature(address_sanitizer) + return true; // Clang +#else + return false; +#endif +} + +bool IsMsan() { +#if defined(__has_feature) && __has_feature(memory_sanitizer) + return true; // Clang (MSan is typically Clang-only) +#elif defined(__SANITIZE_MEMORY__) + return true; // GCC (rare, but future-proof) +#else + return false; +#endif +} + +bool IsTsan() { +#if defined(__SANITIZE_THREAD__) + return true; // GCC +#elif defined(__has_feature) && __has_feature(thread_sanitizer) + return true; // Clang +#else + return false; +#endif +} + +// IsSanitized reports whether the build is under any sanitizer. +bool IsSanitized() { return IsAsan() || IsMsan() || IsTsan(); } + +} // namespace + +NB_MODULE(_jax, m) { + // Initialize ABSL logging because code within XLA uses it. +#ifndef PLATFORM_GOOGLE + xla::InitializeAbslLogging(); +#endif // PLATFORM_GOOGLE + + // We seem to get a fair number of leak warnings from nanobind. It's unclear + // whether these are false positives or not. + nb::set_leak_warnings(false); + + tsl::ImportNumpy(); + + // Exceptions + nb::exception xla_runtime_error(m, "JaxRuntimeError", + PyExc_RuntimeError); + xla_runtime_error.attr("__doc__") = nb::str( + "Runtime errors thrown by the JAX runtime. While the JAX runtime may " + "raise other exceptions as well, most exceptions thrown by the runtime " + "are instances of this class."); + + // Must be before PyClient.compile. + xla::BuildXlaCompilerSubmodule(m); + + PyDevice::Register(m); + PyMemorySpace::Register(m); + PyClient::Register(m); + + nb::enum_(m, "ArrayCopySemantics", + nb::is_arithmetic()) + .value("ALWAYS_COPY", xla::ifrt::ArrayCopySemantics::kAlwaysCopy) + .value("REUSE_INPUT", xla::ifrt::ArrayCopySemantics::kReuseInput) + .value("DONATE_INPUT", xla::ifrt::ArrayCopySemantics::kDonateInput); + + nb::class_(m, "PjRtLayout") + .def("__str__", &xla::PjRtLayout::ToString) + .def("__eq__", + [](const xla::PjRtLayout& layout, nb::object other) { + return nb::isinstance(other) && + layout == nb::cast(other); + }) + .def("__hash__", + [](const xla::PjRtLayout& layout) { return absl::HashOf(layout); }) + .def("_xla_layout", &xla::PjRtLayout::xla_layout) + .def("__getstate__", + [](const xla::PjRtLayout& layout) -> nb::tuple { + absl::StatusOr serialized = layout.Serialize(); + xla::ThrowIfError(serialized.status()); + return nb::make_tuple( + nb::bytes(serialized->data(), serialized->size())); + }) + .def("__setstate__", [](xla::PjRtLayout* self, nb::tuple t) { + nb::bytes serialized = nb::cast(t[0]); + absl::StatusOr> layout = + xla::PjRtLayout::Deserialize( + std::string_view(serialized.c_str(), serialized.size())); + xla::ThrowIfError(layout.status()); + new (self) xla::PjRtLayout((*layout)->xla_layout()); + }); + + nb::class_ cpu_collectives(m, "CpuCollectives"); + cpu_collectives + .def("Init", + [](xla::cpu::CpuCollectives*) { + throw std::runtime_error("Init is not implemented"); + }) + .def("Finalize", [](xla::cpu::CpuCollectives*) { + throw std::runtime_error("Finalize is not implemented"); + }); + + m.def( + "make_gloo_tcp_collectives", + [](std::shared_ptr distributed_client, + + std::optional hostname, + std::optional interface) + -> std::shared_ptr { +#if defined(__linux__) + std::shared_ptr kv_store = nullptr; + if (distributed_client != nullptr) { + kv_store = GetDistributedKeyValueStore(distributed_client, + /*key_prefix=*/"cpu:"); + } + auto gloo_kv_store = + std::make_unique(kv_store); + auto tcp_attrs = gloo::transport::tcp::attr(); + if (hostname) { + tcp_attrs.hostname = *hostname; + } + if (interface) { + tcp_attrs.iface = *interface; + } + auto tcp_device = gloo::transport::tcp::CreateDevice(tcp_attrs); + return std::make_shared( + std::move(gloo_kv_store), std::move(tcp_device)); +#elif defined(__APPLE__) + std::shared_ptr kv_store = nullptr; + if (distributed_client != nullptr) { + kv_store = GetDistributedKeyValueStore(distributed_client, + /*key_prefix=*/"cpu:"); + } + auto gloo_kv_store = + std::make_unique(kv_store); + auto uv_attrs = gloo::transport::uv::attr(); + if (hostname) { + uv_attrs.hostname = *hostname; + } + if (interface) { + uv_attrs.iface = *interface; + } + auto uv_device = gloo::transport::uv::CreateDevice(uv_attrs); + return std::make_shared( + std::move(gloo_kv_store), std::move(uv_device)); +#else // defined(__linux__) + throw xla::XlaRuntimeError( + "make_gloo_tcp_collectives only implemented for linux and macos"); +#endif // defined(__linux__) + }, + nb::arg("distributed_client"), nb::arg("hostname").none() = std::nullopt, + nb::arg("interface").none() = std::nullopt); + +#if !defined(_WIN32) && !defined(PLATFORM_GOOGLE) + nb::class_ mpi_collectives(m, "MpiCollectives", + cpu_collectives); + mpi_collectives.def("Init", &xla::cpu::MpiCollectives::Init); + mpi_collectives.def("Finalize", &xla::cpu::MpiCollectives::Finalize); + m.def("make_mpi_collectives", + []() -> std::shared_ptr { + return std::make_shared(); + }); +#else // !_WIN32 && !PLATFORM_GOOGLE + m.def("make_mpi_collectives", + []() -> std::shared_ptr { + throw xla::XlaRuntimeError( + "make_mpi_collectives is not implemented for Windows"); + }); +#endif // !_WIN32 && !PLATFORM_GOOGLE + + m.def( + "get_tfrt_cpu_client", + [](bool asynchronous, + std::shared_ptr distributed_client, + int node_id, int num_nodes, + std::shared_ptr collectives, + std::optional num_devices, + std::optional get_local_topology_timeout_minutes, + std::optional get_global_topology_timeout_minutes, + std::optional + transfer_server_factory) -> nb_class_ptr { + std::unique_ptr ifrt_client; + { + nb::gil_scoped_release gil_release; + xla::CpuClientOptions options; + + options.asynchronous = asynchronous; + options.collectives = std::move(collectives); + options.process_id = node_id; + options.cpu_device_count = num_devices; + std::unique_ptr client = + xla::ValueOrThrow(xla::GetPjRtCpuClient(std::move(options))); + xla::ifrt::PjRtClient::CreateOptions ifrt_options; + ifrt_options.pjrt_client = + std::shared_ptr(std::move(client)); + if (distributed_client != nullptr) { + ifrt_options.kv_store = + GetDistributedKeyValueStore(distributed_client, + /*key_prefix=*/"cpu:"); + ifrt_options.process_id = node_id; + ifrt_options.num_processes = num_nodes; + } + if (get_local_topology_timeout_minutes.has_value()) { + ifrt_options.get_local_topology_timeout = + absl::Minutes(*get_local_topology_timeout_minutes); + } + if (get_global_topology_timeout_minutes.has_value()) { + ifrt_options.get_global_topology_timeout = + absl::Minutes(*get_global_topology_timeout_minutes); + } + if (transfer_server_factory.has_value()) { + ifrt_options.transfer_server_factory = + std::move(transfer_server_factory->factory_fn); + } + ifrt_client = xla::ValueOrThrow( + xla::ifrt::PjRtClient::Create(std::move(ifrt_options))); + } + return PyClient::Make(std::move(ifrt_client)); + }, + nb::arg("asynchronous") = true, nb::arg("distributed_client") = nullptr, + nb::arg("node_id") = 0, nb::arg("num_nodes") = 1, + nb::arg("collectives").none() = + std::shared_ptr(), + nb::arg("num_devices").none() = std::nullopt, + nb::arg("get_local_topology_timeout_minutes").none() = std::nullopt, + nb::arg("get_global_topology_timeout_minutes").none() = std::nullopt, + nb::arg("transfer_server_factory").none() = std::nullopt); + m.def("pjrt_plugin_loaded", [](std::string platform_name) -> bool { + absl::StatusOr pjrt_api = pjrt::PjrtApi(platform_name); + return pjrt_api.ok(); + }); + m.def( + "load_pjrt_plugin", + [](std::string platform_name, std::optional library_path, + std::optional c_api) -> nb::capsule { + if (library_path.has_value()) { + const PJRT_Api* api = xla::ValueOrThrow( + pjrt::LoadPjrtPlugin(platform_name, *library_path)); + return nb::capsule(absl::bit_cast(api), "pjrt_c_api"); + } + if (std::string_view(c_api->name()) != "pjrt_c_api") { + throw nb::value_error( + "c_api argument to load_pjrt_plugin is not a pjrt_c_api " + "capsule."); + } + xla::ThrowIfError(pjrt::SetPjrtApi( + platform_name, static_cast(c_api->data()))); + return *c_api; + }, + nb::arg("platform_name"), nb::arg("library_path").none() = std::nullopt, + nb::arg("c_api").none() = std::nullopt); + m.def("pjrt_plugin_initialized", [](std::string platform_name) -> bool { + return xla::ValueOrThrow(pjrt::IsPjrtPluginInitialized(platform_name)); + }); + m.def("initialize_pjrt_plugin", [](std::string platform_name) { + return xla::ThrowIfError(pjrt::InitializePjrtPlugin(platform_name)); + }); + + m.def( + "get_c_api_client", + [](std::string platform_name, + const absl::flat_hash_map& options, + std::shared_ptr distributed_client, + std::optional + transfer_server_factory, + bool force_dcn_cross_host_transfers) -> nb_class_ptr { + std::unique_ptr ifrt_client; + { + nb::gil_scoped_release gil_release; + std::shared_ptr kv_store = nullptr; + if (distributed_client != nullptr) { + kv_store = GetDistributedKeyValueStore( + distributed_client, + /*key_prefix=*/absl::StrCat(platform_name, ":")); + } + std::unique_ptr c_api_client = xla::ValueOrThrow( + xla::GetCApiClient(platform_name, options, kv_store)); + xla::ifrt::PjRtClient::CreateOptions ifrt_options; + ifrt_options.pjrt_client = + std::shared_ptr(std::move(c_api_client)); + ifrt_options.kv_store = kv_store; + ifrt_options.use_kv_store_for_topology_exchange = false; + ifrt_options.distributed_client = distributed_client; + if (transfer_server_factory.has_value()) { + ifrt_options.transfer_server_factory = + std::move(transfer_server_factory->factory_fn); + } + ifrt_options.force_dcn_cross_host_transfers = + force_dcn_cross_host_transfers; + ifrt_client = xla::ValueOrThrow( + xla::ifrt::PjRtClient::Create(std::move(ifrt_options))); + } + return PyClient::Make(std::move(ifrt_client)); + }, + nb::arg("platform_name"), + nb::arg("options") = + absl::flat_hash_map(), + nb::arg("distributed_client").none() = nullptr, + nb::arg("transfer_server_factory").none() = std::nullopt, + nb::arg("force_dcn_cross_host_transfers") = false); + // TODO(b/322357665): Delete this method after TPU plugin changes to use the + // standard registration. + m.def("get_default_c_api_topology", + [](std::string platform_name, std::string topology_name, + const absl::flat_hash_map& options) + -> std::shared_ptr { + return std::make_shared(xla::ValueOrThrow( + xla::GetCApiTopology(platform_name, topology_name, options))); + }); + m.def("get_c_api_topology", + [](nb::capsule c_api, std::string topology_name, + const absl::flat_hash_map& options) + -> std::shared_ptr { + if (std::string_view(c_api.name()) != "pjrt_c_api") { + throw nb::value_error( + "Argument to get_c_api_topology was not a pjrt_c_api capsule."); + } + return std::make_shared(xla::ValueOrThrow( + xla::GetCApiTopology(static_cast(c_api.data()), + topology_name, options))); + }); + m.def("get_topology_for_devices", + [](const std::vector>& py_devices) { + if (py_devices.empty()) { + throw nb::value_error( + "get_topology_for_devices requires >= 1 devices."); + } + auto client = py_devices[0]->client(); + absl::InlinedVector ifrt_devices; + ifrt_devices.reserve(py_devices.size()); + for (const auto& py_device : py_devices) { + if (py_device->client().get() != client.get()) { + throw nb::value_error( + "devices passed to get_topology_for_devices come from " + "different clients."); + } + ifrt_devices.push_back(py_device->device()); + } + xla::ifrt::DeviceListRef device_list = xla::ValueOrThrow( + client->ifrt_client()->MakeDeviceList(ifrt_devices)); + return xla::ValueOrThrow( + client->ifrt_client()->GetTopologyForDevices(device_list)); + }); + + TF_CHECK_OK(PyArray::Register(m)); + PyDeviceList::Register(m); + RegisterSharding(m); + + nb::class_(m, "CompiledMemoryStats") + .def_rw("generated_code_size_in_bytes", + &xla::CompiledMemoryStats::generated_code_size_in_bytes) + .def_rw("argument_size_in_bytes", + &xla::CompiledMemoryStats::argument_size_in_bytes) + .def_rw("output_size_in_bytes", + &xla::CompiledMemoryStats::output_size_in_bytes) + .def_rw("alias_size_in_bytes", + &xla::CompiledMemoryStats::alias_size_in_bytes) + .def_rw("temp_size_in_bytes", + &xla::CompiledMemoryStats::temp_size_in_bytes) + .def_rw("host_generated_code_size_in_bytes", + &xla::CompiledMemoryStats::host_generated_code_size_in_bytes) + .def_rw("host_argument_size_in_bytes", + &xla::CompiledMemoryStats::host_argument_size_in_bytes) + .def_rw("host_output_size_in_bytes", + &xla::CompiledMemoryStats::host_output_size_in_bytes) + .def_rw("host_alias_size_in_bytes", + &xla::CompiledMemoryStats::host_alias_size_in_bytes) + .def_rw("host_temp_size_in_bytes", + &xla::CompiledMemoryStats::host_temp_size_in_bytes) + .def_prop_ro("serialized_buffer_assignment_proto", + [](const xla::CompiledMemoryStats& cms) -> nb::bytes { + const std::string& s = cms.serialized_buffer_assignment; + return nb::bytes(s.data(), s.size()); + }) + .def_rw("peak_memory_in_bytes", + &xla::CompiledMemoryStats::peak_memory_in_bytes) + .def("__str__", &xla::CompiledMemoryStats::DebugString); + + m.def("get_execution_stream_id", []() { return GetExecutionStreamId(); }); + m.def("set_execution_stream_id", + [](int64_t id) { GetExecutionStreamId() = id; }); + + PyLoadedExecutable::Register(m); + PyExecuteResults::Register(m); + PyToken::Register(m); + PyShardedToken::Register(m); + PyExecutable::Register(m); + + m.def("buffer_to_dlpack_managed_tensor", + xla::ValueOrThrowWrapper(BufferToDLPackManagedTensor), + nb::arg("buffer"), nb::arg("stream").none() = nb::none()); + m.def( + "dlpack_managed_tensor_to_buffer", + [](const nb::capsule& tensor, nb_class_ptr device, + std::optional stream, std::optional copy) { + return xla::ValueOrThrow(DLPackManagedTensorToBuffer( + tensor, device->device(), device->client(), stream, copy)); + }, + nb::arg("dlpack"), nb::arg("device"), nb::arg("stream").none(), + nb::arg("copy").none() = nb::none(), + nb::sig( + // clang-format off + "def dlpack_managed_tensor_to_buffer(" + "dlpack: typing_extensions.CapsuleType, " + "device: Device, " + "stream: int | None, " + "copy: bool | None = ..." + ") -> ArrayImpl" + // clang-format on + )); + m.def("cuda_array_interface_to_buffer", + xla::ValueOrThrowWrapper(CudaArrayInterfaceToBuffer), nb::arg("cai"), + nb::arg("gpu_backend").none() = nb::none(), + nb::arg("device_id").none() = nb::none()); + + nb::enum_(m, "RuntimeTracebackMode") + .value("OFF", jax::RuntimeTracebackMode::kOff) + .value("ON", jax::RuntimeTracebackMode::kOn) + .value("FULL", jax::RuntimeTracebackMode::kFull); + m.def("add_exclude_path", &jax::AddExcludePath, + "Adds a path to exclude from tracebacks."); + m.def("set_send_traceback_to_runtime_global", + &jax::SetSendTracebackToRuntimeGlobal); + m.def("set_send_traceback_to_runtime_thread_local", + &jax::SetSendTracebackToRuntimeThreadLocal, nb::arg("mode").none()); + + BuildConfigSubmodule(m); + BuildIfrtProgramsSubmodule(m); + BuildPytreeSubmodule(m); + BuildGuardSubmodule(m); + BuildJaxjitSubmodule(m); + BuildPmapSubmodule(m); + BuildPjitSubmodule(m); + Traceback::Register(m); + BuildMlirSubmodule(m); + BuildCustomCallShardingPybindAPI(m); + RegisterFfiApis(m); +#if defined(__linux__) + aux::RegisterTransferServerTypes(m); +#endif // defined(__linux__) + + nb::class_ preemption_sync_manager( + m, "PreemptionSyncManager"); + preemption_sync_manager + .def( + "initialize", + [](xla::PreemptionSyncManager& manager, + xla::DistributedRuntimeClient* client) { + xla::CoordinationServiceAgent* agent = + xla::ValueOrThrow(client->GetCoordinationServiceAgent()); + xla::ThrowIfError(manager.Initialize(agent)); + }, + nb::arg("distributed_client")) + .def("reached_sync_point", + [](xla::PreemptionSyncManager& manager, int step_counter) { + return manager.ReachedSyncPoint(step_counter); + }) + .def("shutdown", [](xla::PreemptionSyncManager& manager) { + nb::gil_scoped_release gil_release; + manager.Shutdown(); + }); + m.def("create_preemption_sync_manager", + []() { return xla::CreatePreemptionSyncManager(); }); + + nb::class_ distributed_runtime_service( + m, "DistributedRuntimeService"); + distributed_runtime_service.def("shutdown", + &xla::DistributedRuntimeService::Shutdown, + nb::call_guard()); + nb::class_ distributed_runtime_client( + m, "DistributedRuntimeClient"); + distributed_runtime_client + .def("connect", + [](xla::DistributedRuntimeClient& self) { + nb::gil_scoped_release gil_release; + xla::ThrowIfError(self.Connect()); + }) + .def("shutdown", + [](xla::DistributedRuntimeClient& self) { + nb::gil_scoped_release gil_release; + xla::ThrowIfError(self.Shutdown()); + }) + // This method assumes that the value is a Python string. Use + // `blocking_key_value_get_bytes()` if key_value_set() was called with a + // Python bytes object as its value. + .def( + "blocking_key_value_get", + [](xla::DistributedRuntimeClient& client, std::string key, + int64_t timeout_in_ms) { + nb::gil_scoped_release gil_release; + return xla::ValueOrThrow(client.BlockingKeyValueGet( + key, absl::Milliseconds(timeout_in_ms))); + }, + nb::arg("key"), nb::arg("timeout_in_ms")) + // Same as `blocking_key_value_get()`, but retrieves the raw Python byte + // values explicitly. + .def( + "blocking_key_value_get_bytes", + [](xla::DistributedRuntimeClient& client, std::string key, + int64_t timeout_in_ms) -> nb::bytes { + std::string result; + { + nb::gil_scoped_release gil_release; + result = xla::ValueOrThrow(client.BlockingKeyValueGet( + key, absl::Milliseconds(timeout_in_ms))); + } + return nb::bytes(result.data(), result.size()); + }, + nb::arg("key"), nb::arg("timeout_in_ms")) + .def( + "key_value_try_get", + [](xla::DistributedRuntimeClient& client, std::string key) { + nb::gil_scoped_release gil_release; + return xla::ValueOrThrow(client.KeyValueTryGet(key)); + }, + nb::arg("key")) + .def( + "key_value_try_get_bytes", + [](xla::DistributedRuntimeClient& client, + std::string key) -> nb::bytes { + std::string result; + { + nb::gil_scoped_release gil_release; + result = xla::ValueOrThrow(client.KeyValueTryGet(key)); + } + return nb::bytes(result.data(), result.size()); + }, + nb::arg("key")) + .def( + "key_value_increment", + [](xla::DistributedRuntimeClient& client, std::string key, + int64_t increment) { + nb::gil_scoped_release gil_release; + return xla::ValueOrThrow(client.KeyValueIncrement(key, increment)); + }, + nb::arg("key"), nb::arg("increment")) + .def( + "wait_at_barrier", + [](xla::DistributedRuntimeClient& client, std::string barrier_id, + int64_t timeout_in_ms, + std::optional> process_ids) { + nb::gil_scoped_release gil_release; + xla::ThrowIfError(client.WaitAtBarrier( + barrier_id, absl::Milliseconds(timeout_in_ms), process_ids)); + }, + nb::arg("barrier_id"), nb::arg("timeout_in_ms"), + nb::arg("process_ids") = std::nullopt) + .def( + "get_live_nodes", + [](xla::DistributedRuntimeClient& client, + std::vector process_ids) { + nb::gil_scoped_release gil_release; + // Python doesn't understand the IncarnationId type, so we convert + // to regular integers before returning. + absl::flat_hash_map nodes = + xla::ValueOrThrow( + client.GetLiveNodesWithIncarnations(process_ids)); + absl::flat_hash_map py_nodes; + for (const auto& [task_id, incarnation_id] : nodes) { + py_nodes[task_id] = incarnation_id.value(); + } + return py_nodes; + }, + nb::arg("process_ids")) + // The key must be a string, but the value can either be a Python string + // or bytes object. + // With Python string values, use `key_value_set()` and + // `blocking_key_value_get()`. + // With Python byte object values, use `key_value_set()` and + // `blocking_key_value_get_bytes()`. + .def( + "key_value_set", + [](xla::DistributedRuntimeClient& client, std::string_view key, + std::string_view value, bool allow_overwrite) { + nb::gil_scoped_release gil_release; + xla::ThrowIfError(client.KeyValueSet(key, value, allow_overwrite)); + }, + nb::arg("key"), nb::arg("value"), nb::arg("allow_overwrite") = false) + // The key must be a string, but the value must a + // Python bytes object. + // Use `key_value_set_bytes()` and `blocking_key_value_get_bytes()`. + .def( + "key_value_set_bytes", + [](xla::DistributedRuntimeClient& client, std::string_view key, + nb::bytes value, bool allow_overwrite) { + nb::gil_scoped_release gil_release; + xla::ThrowIfError(client.KeyValueSet( + key, std::string_view(value.c_str(), value.size()), + allow_overwrite)); + }, + nb::arg("key"), nb::arg("value"), nb::arg("allow_overwrite") = false) + // Assumes that all values in the directory are Python strings. + .def( + "key_value_dir_get", + [](xla::DistributedRuntimeClient& client, std::string_view key) { + nb::gil_scoped_release gil_release; + return xla::ValueOrThrow(client.KeyValueDirGet(key)); + }, + nb::arg("key")) + // Assumes that all values in the directory are Python byte objects. + // Same as `key_value_dir_get()`, but retrieves Python byte values + // explicitly. + .def( + "key_value_dir_get_bytes", + [](xla::DistributedRuntimeClient& client, std::string_view key) + -> std::vector> { + std::vector> result; + { + nb::gil_scoped_release gil_release; + result = xla::ValueOrThrow(client.KeyValueDirGet(key)); + } + // Convert std::string values to nb::bytes. + std::vector> kvs; + kvs.reserve(result.size()); + for (auto& kv : result) { + kvs.push_back( + std::pair(std::move(kv.first), + nb::bytes(kv.second.data(), kv.second.size()))); + } + return kvs; + }, + nb::arg("key")) + .def( + "key_value_delete", + [](xla::DistributedRuntimeClient& client, std::string_view key) { + nb::gil_scoped_release gil_release; + return xla::ThrowIfError(client.KeyValueDelete(key)); + }, + nb::arg("key")); + + m.def( + "get_distributed_runtime_service", + [](std::string address, int num_nodes, + std::optional heartbeat_timeout, + std::optional cluster_register_timeout, + std::optional shutdown_timeout) + -> std::unique_ptr { + xla::CoordinationServiceImpl::Options options; + options.num_nodes = num_nodes; + if (heartbeat_timeout.has_value()) { + options.heartbeat_timeout = absl::Seconds(*heartbeat_timeout); + } + if (cluster_register_timeout.has_value()) { + options.cluster_register_timeout = + absl::Seconds(*cluster_register_timeout); + } + if (shutdown_timeout.has_value()) { + options.shutdown_timeout = absl::Seconds(*shutdown_timeout); + } + std::unique_ptr service = + xla::ValueOrThrow(GetDistributedRuntimeService(address, options)); + return service; + }, + nb::arg("address"), nb::arg("num_nodes"), + nb::arg("heartbeat_timeout").none() = std::nullopt, + nb::arg("cluster_register_timeout").none() = std::nullopt, + nb::arg("shutdown_timeout").none() = std::nullopt); + + m.def( + "get_distributed_runtime_client", + [](std::string address, int node_id, std::optional rpc_timeout, + std::optional init_timeout, std::optional shutdown_timeout, + std::optional heartbeat_timeout, + std::optional missed_heartbeat_callback, + std::optional shutdown_on_destruction, + std::optional use_compression, std::optional recoverable) + -> std::shared_ptr { + bool compression = use_compression.value_or(false); + xla::DistributedRuntimeClient::Options options; + options.node_id = node_id; + if (rpc_timeout.has_value()) { + options.rpc_timeout = absl::Seconds(*rpc_timeout); + } + if (init_timeout.has_value()) { + options.init_timeout = absl::Seconds(*init_timeout); + } + if (shutdown_timeout.has_value()) { + options.shutdown_timeout = absl::Seconds(*shutdown_timeout); + } + if (heartbeat_timeout.has_value()) { + options.heartbeat_timeout = absl::Seconds(*heartbeat_timeout); + } + if (missed_heartbeat_callback.has_value()) { + options.missed_heartbeat_callback = + nb::cast>( + *missed_heartbeat_callback); + } + if (shutdown_on_destruction.has_value()) { + options.shutdown_on_destruction = *shutdown_on_destruction; + } + if (recoverable.has_value()) { + options.recoverable = *recoverable; + } + return GetDistributedRuntimeClient(address, options, compression); + }, + nb::arg("address"), nb::arg("node_id"), + nb::arg("rpc_timeout").none() = std::nullopt, + nb::arg("init_timeout").none() = std::nullopt, + nb::arg("shutdown_timeout").none() = std::nullopt, + nb::arg("heartbeat_timeout").none() = std::nullopt, + nb::arg("missed_heartbeat_callback").none() = std::nullopt, + nb::arg("shutdown_on_destruction").none() = std::nullopt, + nb::arg("use_compression").none() = std::nullopt, + nb::arg("recoverable").none() = std::nullopt); + + m.def("collect_garbage", []() { GlobalPyRefManager()->CollectGarbage(); }); + + m.def("is_optimized_build", &IsOptimizedBuild); + + m.def("json_to_pprof_profile", + xla::ValueOrThrowWrapper(xla::JsonToPprofProfile), + "Encodes the JSON representation of a pprof Profile into its binary " + "protocol buffer encoding."); + m.def("pprof_profile_to_json", + xla::ValueOrThrowWrapper(xla::PprofProfileToJson), + "Decodes an uncompressed pprof Profile protocol buffer into a JSON " + "representation"); + + CompileOnlyPyClient::Register(m); + nb::class_(m, "DeviceTopology") + .def("_make_compile_only_devices", + [](std::shared_ptr topology) { + if (!llvm::isa(*topology)) { + throw xla::XlaRuntimeError("Only PjRtTopologies are supported."); + } + return CompileOnlyPyClient::Make( + std::dynamic_pointer_cast( + topology)) + ->Devices(); + }) + .def_prop_ro("platform", + [](xla::ifrt::Topology& topology) { + return topology.platform_name(); + }) + .def_prop_ro("platform_version", + [](xla::ifrt::Topology& topology) { + return topology.platform_version(); + }) + .def("serialize", + [](xla::ifrt::Topology& topology) -> nb::bytes { + std::string serialized = xla::ValueOrThrow(topology.Serialize()); + return nb::bytes(serialized.data(), serialized.size()); + }) + .def("__getattr__", + [](xla::ifrt::Topology& topology, + std::string_view name) -> nb::object { + auto value = + topology.Attributes().Get( + std::string(name)); + if (value.ok()) { + return std::visit([](auto&& v) { return nb::cast(v.value); }, + *value); + } + throw nb::attribute_error( + absl::StrCat("Unknown attribute ", name).c_str()); + }); + + nb::class_( + m, "TransferServerInterfaceFactory"); + + m.def("is_asan", IsAsan); + m.def("is_msan", IsMsan); + m.def("is_tsan", IsTsan); + m.def("is_sanitized", IsSanitized); + + m.def( + "batched_device_put", + [](nb::object aval, nb::object sharding, std::vector xs, + std::vector dst_devices, bool committed, + bool force_copy, + xla::PjRtClient::HostBufferSemantics host_buffer_semantics, + std::optional enable_x64) -> nb::object { + return xla::ValueOrThrow(PyArray::BatchedDevicePut( + aval, sharding, std::move(xs), std::move(dst_devices), committed, + force_copy, host_buffer_semantics, + enable_x64.has_value() ? *enable_x64 : GetEnableX64())); + }, + nb::arg("aval"), nb::arg("sharding"), nb::arg("xs"), nb::arg("devices"), + nb::arg("committed") = true, nb::arg("force_copy") = false, + nb::arg("host_buffer_semantics") = + xla::PjRtClient::HostBufferSemantics::kImmutableZeroCopy, + nb::arg("enable_x64").none() = std::nullopt); + m.def( + "reorder_shards", + [](PyArray x, nb::object dst_sharding, + xla::ifrt::ArrayCopySemantics array_copy_semantics) { + return xla::ValueOrThrow(PyArray::ReorderShards( + std::move(x), std::move(dst_sharding), array_copy_semantics)); + }, + nb::arg("x"), nb::arg("dst_sharding"), nb::arg("array_copy_semantics")); + + m.def("batched_block_until_ready", [](std::vector xs) { + xla::ThrowIfError(PyArray::BatchedBlockUntilReady(std::move(xs))); + }); + + m.def("check_and_canonicalize_memory_kind", &CheckAndCanonicalizeMemoryKind, + nb::arg("memory_kind").none(), nb::arg("device_list")); + + m.attr("ifrt_version_number") = JAX_IFRT_VERSION_NUMBER; + + m.def("approx_top_k_reduction_output_size", + xla::ValueOrThrowWrapper(xla::ApproxTopKReductionOutputSize), + nb::arg("input_size"), nb::arg("rank"), nb::arg("top_k"), + nb::arg("recall_target"), nb::arg("aggregate_to_topk") = true, + nb::arg("input_size_override") = -1); + + m.def("get_internal_device_put_info", + []() { return DevicePutInfo::GetInfo(); }); + + PartitionSpec::Register(m); + + m.def("set_typed_int_type", &SetTypedIntType); + m.def("set_typed_float_type", &SetTypedFloatType); + m.def("set_typed_complex_type", &SetTypedComplexType); + m.def("set_typed_ndarray_type", &SetTypedNdArrayType); +} // NOLINT(readability/fn_size) + +} // namespace jax diff --git a/jaxlib/jax_common.json b/jaxlib/jax_common.json new file mode 100644 index 000000000000..61a2c9313897 --- /dev/null +++ b/jaxlib/jax_common.json @@ -0,0 +1,8 @@ +{ + "global": [ + "Wrapped_PyInit_*" + ], + "local": [ + "*" + ] +} diff --git a/jaxlib/jax_jit.cc b/jaxlib/jax_jit.cc new file mode 100644 index 000000000000..92f246cc823e --- /dev/null +++ b/jaxlib/jax_jit.cc @@ -0,0 +1,546 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This files implements the `jax.jit` dispatch and just-in-time feature. +// +// In a nutshell, `Jit(f)` returns a callable that will dispatch (i.e. forward +// based on passed arguments dtypes/shapes/identity) the execution to a +// just-in-time compiled XLA Executable. All of that is done in C++ for +// performance reasons. +// +// This file contains the utilities to: +// (a) inspect arguments and describe their structure, dtype/shapes, etc. +// (b) keep a mapping from function signatures to compiled XLA Executables. + +#include "jaxlib/jax_jit.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/config.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_values.h" +#include "jaxlib/pytree.h" +#include "jaxlib/sharding.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/nb_absl_inlined_vector.h" // IWYU pragma: keep +#include "xla/python/nb_absl_span.h" // IWYU pragma: keep +#include "xla/python/types.h" +#include "xla/tsl/platform/logging.h" +#include "tsl/profiler/lib/traceme.h" + +namespace jax { + +namespace nb = nanobind; + +// TODO(phawkins): Add support for Tracers. +// TODO(jblespiau): Use absl absl::Status. + +namespace { + +nb_class_ptr& disable_jit_state = *new nb_class_ptr(); +nb_class_ptr& enable_x64_state = *new nb_class_ptr(); +nb_class_ptr& post_hook_state = *new nb_class_ptr(); + +// Callback called the first time the C++ jit accesses thread-local state. +nb::object& initialize_local_state = *new nb::object(); + +} // namespace + +void InitializeThreadLocalState() { + thread_local bool initialized = false; + if (!initialized) { + initialized = true; + // Set the flag first to avoid reentrant calls to the initialization + // function. + initialize_local_state(); + } +} + +bool GetDisableJit() { + if (!disable_jit_state.ptr()) { + throw std::runtime_error("disable_jit_state is not set"); + } + return nb::cast(disable_jit_state->Get()); +} + +bool GetEnableX64() { + if (!enable_x64_state.ptr()) { + throw std::runtime_error("enable_x64_state is not set"); + } + bool out = nb::cast(enable_x64_state->Get()); + return out; +} + +std::optional GetPostHook() { + if (!post_hook_state.ptr()) { + throw std::runtime_error("post_hook_state is not set"); + } + return nb::cast>(post_hook_state->Get()); +} + +std::string ArgumentSignature::DebugString() const { + auto py_object_formatter = [](std::string* out, const nb::object& o) { + out->append(nb::cast(nb::str(o))); + }; + auto treedef_formatter = [](std::string* out, const PyTreeDef& d) { + out->append(d.ToString()); + }; + return absl::StrFormat( + "static args (positional + keyword): [%s], " + "static arg keyword names: [%s], " + "dynamic arg signatures (positional + keyword): [%s], " + "dynamic arg shardings: [%s]", + absl::StrJoin(static_args, ",", py_object_formatter), + absl::StrJoin(static_arg_names, ",", py_object_formatter), + absl::StrJoin(dynamic_arg_names, ",", py_object_formatter), + absl::StrJoin(dynamic_arg_treedefs, "| ", treedef_formatter)); +} + +bool ArgumentSignature::operator==(const ArgumentSignature& other) const { + if (dynamic_arg_treedefs != other.dynamic_arg_treedefs) { + return false; + } + auto object_ptr_equality = [](nb::handle a, nb::handle b) { + return a.ptr() == b.ptr(); + }; + if (!absl::c_equal(dynamic_arg_names, other.dynamic_arg_names, + object_ptr_equality)) { + return false; + } + if (!absl::c_equal(static_arg_names, other.static_arg_names, + object_ptr_equality)) { + return false; + } + return absl::c_equal( + static_args, other.static_args, + [](const nb::object& a, const nb::object& b) { + try { + return a.type().ptr() == b.type().ptr() && a.equal(b); + } catch (const nb::python_error& e) { + throw std::invalid_argument(absl::StrCat( + "static arguments should be comparable using __eq__." + "The following error was raised when comparing two objects of " + "types ", + nb::cast(nb::str(a.type())), " and ", + nb::cast(nb::str(b.type())), + ". The error was:\n", e.what())); + } + }); +} + +std::string CallSignature::DebugString() const { + auto py_object_formatter = [](std::string* out, const nb::object& o) { + out->append(nb::cast(nb::str(o))); + }; + auto signature_formatter = [](std::string* out, const PyArgSignature& s) { + out->append(s.DebugString()); + }; + auto layout_formatter = [](std::string* out, + const std::shared_ptr& l) { + if (l != nullptr) { + out->append(l->ToString()); + } else { + out->append("None"); + } + }; + auto bool_formatter = [](std::string* out, bool o) { + out->append(o ? "true" : "false"); + }; + std::vector config_names = JitConfigNames(); + std::vector config_strs; + config_strs.reserve(configs.size()); + for (int i = 0; i < configs.size(); ++i) { + config_strs.push_back(absl::StrFormat( + "%s: %s", i < config_names.size() ? config_names[i] : "unknown", + nb::cast(nb::str(configs[i])))); + } + return absl::StrFormat( + "arg signature: %s\n" + "dynamic arg signatures (positional + keyword): %s\n" + "dynamic arg shardings: %s\n" + "dynamic arg layouts: %s\n" + "committed args: %s\n" + "device: %s\n" + "configs: %s\n", + arg_signature.DebugString(), + absl::StrJoin(dynamic_arg_signatures, ", ", signature_formatter), + absl::StrJoin(dynamic_arg_shardings, ", ", py_object_formatter), + absl::StrJoin(dynamic_arg_layouts, ", ", layout_formatter), + absl::StrJoin(committed_args, ",", bool_formatter), + device != nullptr ? device->DebugString() : "nullptr", + absl::StrJoin(config_strs, ", ")); +} + +size_t HashShardingForJit(nb::handle sharding) { + auto type = sharding.type(); + + if (type.is(NamedSharding::type())) { + const auto* named_sharding = nb::inst_ptr(sharding); + return absl::Hash()(named_sharding->mesh().ptr()); + } + + if (type.is(GSPMDSharding::type())) { + auto* gspmd_sharding = nb::inst_ptr(sharding); + return gspmd_sharding->Hash(); + } + + if (type.is(SingleDeviceSharding::type())) { + auto* single_device_sharding = nb::inst_ptr(sharding); + return absl::Hash()(single_device_sharding->device().ptr()); + } + + try { + return nb::hash(sharding); + } catch (const nb::python_error& e) { + // Gracefully handle non-hashable sharding. We cannot let a C++ exception + // escape because this hash function may have been called from a code that + // disables C++ exception support. + return 0; + } +} + +bool EqualShardingsForJit(nb::handle a, nb::handle b) { + if (a.ptr() == b.ptr()) { + return true; + } + + auto a_type = a.type(); + auto b_type = b.type(); + + if (!a_type.is(b_type)) { + return false; + } + + if (a_type.is(NamedSharding::type())) { + auto* a_named_sharding = nb::inst_ptr(a); + auto* b_named_sharding = nb::inst_ptr(b); + return a_named_sharding->mesh().ptr() == b_named_sharding->mesh().ptr() && + *a_named_sharding->spec() == *b_named_sharding->spec() && + a_named_sharding->memory_kind().equal( + b_named_sharding->memory_kind()) && + a_named_sharding->logical_device_ids().equal( + b_named_sharding->logical_device_ids()); + } + + if (a_type.is(GSPMDSharding::type())) { + auto* a_gspmd_sharding = nb::inst_ptr(a); + auto* b_gspmd_sharding = nb::inst_ptr(b); + return *a_gspmd_sharding == *b_gspmd_sharding; + } + + if (a_type.is(SingleDeviceSharding::type())) { + auto* a_single_device_sharding = + nb::inst_ptr(a); + auto* b_single_device_sharding = + nb::inst_ptr(b); + return a_single_device_sharding->device().ptr() == + b_single_device_sharding->device().ptr() && + a_single_device_sharding->memory_kind().equal( + b_single_device_sharding->memory_kind()); + } + + return a.equal(b); +} + +bool CallSignature::operator==(const CallSignature& other) const { + if (arg_signature != other.arg_signature) { + return false; + } + if (dynamic_arg_signatures != other.dynamic_arg_signatures) { + return false; + } + if (device != other.device) { + return false; + } + if (committed_args != other.committed_args) { + return false; + } + return + // `==` on py:objects is the Python `is`. We need equal. + absl::c_equal(dynamic_arg_shardings, other.dynamic_arg_shardings, + EqualShardingsForJit) && + absl::c_equal(dynamic_arg_layouts, other.dynamic_arg_layouts, + [](const std::shared_ptr& a, + const std::shared_ptr& b) { + return (a && b) ? *a == *b : a == b; + }) && + configs.size() == other.configs.size() && + absl::c_equal( + configs, other.configs, + [](const nb::object& a, const nb::object& b) { return a.equal(b); }); +} + +// Filter out static arguments, flatten and concatenate other arguments (i.e. +// dynamic positional and keyword arguments), filling `arguments` in place. +absl::Status ParseArguments( + absl::Span positional_args, + absl::Span keyword_args, nb::handle kwnames, + absl::Span static_argnums, + absl::Span static_argnames, PyTreeRegistry* pytree_registry, + ArgumentSignature& signature, + absl::InlinedVector& flat_dynamic_args) { + tsl::profiler::TraceMe traceme("ParseArguments"); + + DCHECK(absl::c_all_of(static_argnames, [](const nb::str& name) { + return PyUnicode_CHECK_INTERNED(name.ptr()); + })); + + flat_dynamic_args.reserve(positional_args.size() + keyword_args.size()); + if (static_argnums.empty()) { + signature.dynamic_arg_treedefs.reserve(positional_args.size()); + + // Positional arguments. + for (int i = 0; i < positional_args.size(); ++i) { + signature.dynamic_arg_treedefs.emplace_back(pytree_registry); + PyTreeDef& pytree_def = signature.dynamic_arg_treedefs.back(); + pytree_def.Flatten(nb::handle(positional_args[i]), flat_dynamic_args); + } + } else { + signature.dynamic_arg_treedefs.reserve(positional_args.size()); + + // Positional arguments. + int num_positional_args = positional_args.size(); + for (int i = 0; i < positional_args.size(); ++i) { + if (std::find_if(static_argnums.begin(), static_argnums.end(), + [i, num_positional_args](int t) { + return t >= 0 ? i == t : i == t + num_positional_args; + }) == static_argnums.end()) { + signature.dynamic_arg_treedefs.emplace_back(pytree_registry); + PyTreeDef& pytree_def = signature.dynamic_arg_treedefs.back(); + pytree_def.Flatten(positional_args[i], flat_dynamic_args); + } else { + signature.static_args.emplace_back( + nb::borrow(positional_args[i])); + } + } + } + + // Keyword arguments. + if (!keyword_args.empty()) { + std::vector> kwargs(keyword_args.size()); + // We first intern the keys, then sort them (by name, as in the Python path) + // (see also PyTreeDef::Flatten) and then create the signatures. + // TODO(jblespiau): We should be able to sort the keys by interned-key + // pointers, but this requires the Python compilation to do the same. + for (int i = 0; i < keyword_args.size(); ++i) { + // Intern the key if not already interned. + PyObject* key = PyTuple_GET_ITEM(kwnames.ptr(), i); + Py_INCREF(key); + if (!PyUnicode_CHECK_INTERNED(key)) { + PyUnicode_InternInPlace(&key); + } + kwargs[i].first = key; + kwargs[i].second = keyword_args[i]; + } + + std::sort(kwargs.begin(), kwargs.end(), + [](const std::pair& a, + const std::pair& b) { + return a.first < b.first; + }); + auto kwarg_is_static = [&](nb::handle name) { + for (const auto& kw : static_argnames) { + if (kw.ptr() == name.ptr()) return true; + } + return false; + }; + + signature.dynamic_arg_names.reserve(keyword_args.size()); + for (int i = 0; i < keyword_args.size(); ++i) { + if (kwarg_is_static(kwargs[i].first)) { + signature.static_arg_names.push_back( + nb::steal(kwargs[i].first)); + signature.static_args.push_back( + nb::borrow(kwargs[i].second)); + } else { + signature.dynamic_arg_names.push_back( + nb::steal(kwargs[i].first)); + signature.dynamic_arg_treedefs.emplace_back(pytree_registry); + PyTreeDef& pytree_def = signature.dynamic_arg_treedefs.back(); + pytree_def.Flatten(nb::handle(kwargs[i].second.ptr()), + flat_dynamic_args); + } + } + } + return absl::OkStatus(); +} + +void BuildJaxjitSubmodule(nb::module_& m) { + nb::module_ jitlib = m.def_submodule("jax_jit", "Jax C++ jit library"); + + jitlib.attr("_Config") = m.attr("config").attr("Config"); + jitlib.attr("_PyTreeDef") = m.attr("pytree").attr("PyTreeDef"); + jitlib.attr("_PyTreeRegistry") = m.attr("pytree").attr("PyTreeRegistry"); + + jitlib.def( + "set_disable_jit_state", + [](nb_class_ptr config) { disable_jit_state = config; }, + nb::sig("def set_disable_jit_state(config: _Config) -> None")); + jitlib.def( + "set_enable_x64_state", + [](nb_class_ptr config) { enable_x64_state = config; }, + nb::sig("def set_enable_x64_state(config: _Config) -> None")); + jitlib.def( + "set_post_hook_state", + [](nb_class_ptr config) { post_hook_state = config; }, + nb::sig("def set_post_hook_state(config: _Config) -> None")); + + jitlib.def( + "set_thread_local_state_initialization_callback", + [](nb::object f) { initialize_local_state = f; }, + nb::sig("def set_thread_local_state_initialization_callback(" + "f: typing.Callable[[], None]) -> None")); + + nb::class_ arg_signature(jitlib, "PyArgSignature"); + arg_signature + .def_prop_ro( + "dtype", + [](const PyArgSignature& sig) { + return xla::ValueOrThrow(xla::PrimitiveTypeToNbDtype(sig.dtype)); + }) + .def_prop_ro("shape", + [](const PyArgSignature& sig) { + return xla::SpanToNbTuple(absl::MakeConstSpan(sig.shape)); + }) + .def_ro("weak_type", &PyArgSignature::weak_type); + jitlib.def("_ArgSignatureOfValue", + xla::ValueOrThrowWrapper(PyArgSignatureOfValue)); + + nb::class_ argument_signature(jitlib, "ArgumentSignature"); + argument_signature.def_ro("static_args", &ArgumentSignature::static_args) + .def_ro("static_arg_names", &ArgumentSignature::static_arg_names) + .def_ro("dynamic_arg_names", &ArgumentSignature::dynamic_arg_names) + .def_ro( + "dynamic_arg_treedefs", &ArgumentSignature::dynamic_arg_treedefs, + nb::sig( + "def dynamic_arg_treedefs(self) -> typing.Sequence[_PyTreeDef]")) + .def("__repr__", &ArgumentSignature::DebugString) + .def("__str__", &ArgumentSignature::DebugString) + .def("__hash__", + [](const ArgumentSignature& s) { return absl::HashOf(s); }) + .def( + "__eq__", + [](const ArgumentSignature& a, nb::object b) { + return nb::isinstance(b) && + a == nb::cast(b); + }, + nb::is_operator()) + .def( + "__ne__", + [](const ArgumentSignature& a, nb::object b) { + return !nb::isinstance(b) || + a != nb::cast(b); + }, + nb::is_operator()); + + jitlib.def( + "parse_arguments", + [](nb::sequence positional_args, nb::sequence keyword_args, + nb::typed kwnames, + absl::Span static_argnums, + absl::Span static_argnames, + PyTreeRegistry* pytree_registry) { + ArgumentSignature signature; + absl::InlinedVector flat_dynamic_args; + nb::object positional_args_seq = nb::steal(PySequence_Fast( + positional_args.ptr(), "positional_args must be a list or tuple")); + if (!positional_args_seq.ptr()) { + throw nb::python_error(); + } + nb::object keyword_args_seq = nb::steal(PySequence_Fast( + keyword_args.ptr(), "keyword_args must be a list or tuple")); + if (!keyword_args_seq.ptr()) { + throw nb::python_error(); + } + absl::Span positional_args_span = + absl::MakeSpan(PySequence_Fast_ITEMS(positional_args_seq.ptr()), + PySequence_Fast_GET_SIZE(positional_args_seq.ptr())); + absl::Span keyword_args_span = + absl::MakeSpan(PySequence_Fast_ITEMS(keyword_args_seq.ptr()), + PySequence_Fast_GET_SIZE(keyword_args_seq.ptr())); + + // Intern the static argument names. + std::vector static_argnames_interned; + static_argnames_interned.reserve(static_argnames.size()); + for (const nb::str& name : static_argnames) { + PyObject* s = name.inc_ref().ptr(); + PyUnicode_InternInPlace(&s); + static_argnames_interned.push_back(nb::steal(s)); + } + + xla::ThrowIfError( + ParseArguments(positional_args_span, keyword_args_span, kwnames, + static_argnums, static_argnames_interned, + pytree_registry, signature, flat_dynamic_args)); + return std::make_pair(std::move(signature), + std::move(flat_dynamic_args)); + }, + nb::arg("positional_args"), nb::arg("keyword_args"), nb::arg("kwnames"), + nb::arg("static_argnums"), nb::arg("static_argnames"), + nb::arg("pytree_registry"), + nb::sig( + // clang-format off + "def parse_arguments(" + "positional_args: Sequence[object], " + "keyword_args: Sequence[object], " + "kwnames: tuple[str, ...], " + "static_argnums: Sequence[int], " + "static_argnames: Sequence[str], " + "pytree_registry: _PyTreeRegistry" + ") -> tuple[ArgumentSignature, list[object]]" + // clang-format on + ), + R"doc(Parses the arguments to a function as jax.jit would. + +Returns a ArgumentSignature and the flattened dynamic arguments. + +Args: + positional_args: The positional arguments. + keyword_args: The keyword arguments. + kwnames: The keyword names. + static_argnums: The static argument numbers. + static_argnames: The static argument names. + pytree_registry: The pytree registry. +)doc"); +} + +} // namespace jax diff --git a/jaxlib/jax_jit.h b/jaxlib/jax_jit.h new file mode 100644 index 000000000000..8f5c2dff8500 --- /dev/null +++ b/jaxlib/jax_jit.h @@ -0,0 +1,232 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_JAX_JIT_H_ +#define JAXLIB_JAX_JIT_H_ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +// placeholder for index annotation headers +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "jaxlib/py_values.h" +#include "jaxlib/pytree.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/tsl/platform/logging.h" + +namespace jax { + +void InitializeThreadLocalState(); +bool GetDisableJit(); +bool GetEnableX64(); +std::optional GetPostHook(); + +// An ArgumentSignature describes the static arguments to a function call, and +// how the dynamic arguments are related to the arguments. Together with the +// values of the dynamic arguments, this fully describes the arguments. +struct ArgumentSignature { + // A PyTreeDef for each dynamic argument, positional arguments first + // followed by keyword arguments. Keyword arguments are in the order given + // by dynamic_arg_names. + absl::InlinedVector dynamic_arg_treedefs; + + // Dynamic keyword argument names. Interned, and sorted by the keyword + // name. Interned values are safe to compare by pointer. + std::vector dynamic_arg_names; + + // Static arguments. Contains the positional arguments sorted in argument + // order, followed by static keyword arguments in the order given by + // `static_arg_names`. + std::vector static_args; + + // Static keyword argument names. Interned, and sorted by keyword name. + std::vector static_arg_names; + + bool operator==(const ArgumentSignature& other) const; + bool operator!=(const ArgumentSignature& other) const { + return !(*this == other); + } + + std::string DebugString() const; +}; + +template +H AbslHashValue(H h, const ArgumentSignature& s) { + h = H::combine(std::move(h), s.dynamic_arg_treedefs, + s.dynamic_arg_names.size(), s.static_args.size(), + s.static_arg_names.size()); + + for (const auto& name : s.dynamic_arg_names) { + h = H::combine(std::move(h), name.ptr()); + } + for (size_t i = 0; i < s.static_args.size(); ++i) { + const auto& static_arg = s.static_args[i]; + Py_hash_t hash; + try { + hash = nanobind::hash(static_arg); + } catch (const nanobind::python_error& e) { + if (!e.matches(PyExc_TypeError)) throw; + throw std::invalid_argument(absl::StrCat( + "Non-hashable static arguments are not supported. An error occurred " + "while trying to hash an object of type ", + nanobind::cast(nanobind::str(static_arg.type())), + ", ", nanobind::cast(nanobind::str(static_arg)), + ". The error was:\n", e.what(), "\n")); + } + h = H::combine(std::move(h), hash); + } + for (const auto& name : s.static_arg_names) { + h = H::combine(std::move(h), name.ptr()); + } + return h; +} + +// Filter out static arguments, flatten and concatenate other arguments (i.e. +// dynamic positional and keyword arguments), filling `arguments` in place. +// Args: +// positional_args: positional arguments +// keyword_args: the values of the keyword arguments +// kwnames: either None or a tuple containing the keyword argument names +// static_argnums: the indices of the static arguments in the positional +// arguments +// static_argnames: the names of the static arguments, which must be interned. +// pytree_registry: the registry to use to convert the arguments to pytrees +// signature: output; describes the static arguments and the identities of the +// dynamic arguments. +// flat_dynamic_args: output; the concatenation of the dynamic positional +// arguments and sorted keyword arguments. +absl::Status ParseArguments( + absl::Span positional_args, + absl::Span keyword_args, nanobind::handle kwnames, + absl::Span static_argnums, + absl::Span static_argnames, + PyTreeRegistry* pytree_registry, ArgumentSignature& signature, + absl::InlinedVector& flat_dynamic_args); + +// The signature of Python jitted function call, partitioned into: +// - dynamic positional arguments (i.e. positional args which are not static) +// - static positional arguments (i.e. the args associated to static_argnums) +// - keyword arguments +// The CallSignature should unambiguously identify a function call, thus, +// equality is based on: +// (a) Same PyTree for all dynamic positional arguments and keyword arguments +// (a) equality of the arguments and keyword arguments ArgSignature +// (a) equality (delegated to Python) of the static arguments. +struct CallSignature { + // Not part of the signature, but we need it for error messages. + std::string_view function_name; + + ArgumentSignature arg_signature; + + // Shape and dtype for both the dynamic positional arguments and the keyword + // arguments (sorted by keyword name). + absl::InlinedVector dynamic_arg_signatures; + + // The sharding of the jax.Array arguments. + std::vector dynamic_arg_shardings; + + // The layout of the jax.Array arguments. + std::vector> dynamic_arg_layouts; + + absl::InlinedVector committed_args; + + // For JIT, we need this in the key because computation follows the data, so + // we may have multiple executables depending on the devices the data is on. + // This is not the case for PMAP, and is set to `nullptr`. + xla::PjRtDevice* device = nullptr; + + std::vector configs; + + // Cached hash of the signature. Must be filled in using `absl::HashOf` as + // part of CallSignature construction. The hash computation happens in a C++ + // exception-safe context, which simplifies using `CallSignature` as a key in + // a non-exception-safe container because `Hash()` would never throw when used + // inside the container implementation. + size_t cached_hash; + + bool operator==(const CallSignature& other) const; + bool operator!=(const CallSignature& other) const { + return !(*this == other); + } + + std::string DebugString() const; + + struct Hash { + size_t operator()(const CallSignature& s) const noexcept { + return s.cached_hash; + } + }; +}; + +// A hash and equality for shardings that may sometimes return different hashes +// for equal values, and may sometimes return "not equal" for equal values. +// These are not correct implementations of `__hash__` and `__eq__` in python, +// but they are fine for jit/pjit dispatch since they only causes spurious cache +// misses. +size_t HashShardingForJit(nanobind::handle sharding); +bool EqualShardingsForJit(nanobind::handle a, nanobind::handle b); + +template +H AbslHashValue(H h, const CallSignature& s) { + h = H::combine(std::move(h), s.arg_signature, s.dynamic_arg_signatures); + + DCHECK(s.dynamic_arg_shardings.empty() || + s.dynamic_arg_shardings.size() == s.dynamic_arg_signatures.size()); + + DCHECK(s.dynamic_arg_layouts.empty() || + s.dynamic_arg_layouts.size() == s.dynamic_arg_signatures.size()); + + // TODO(chky): For now, we are only hashing the pointer of shardings to avoid + // slow python hashing function. Consider implementing hashing function and + // equality checks in C++ in Sharding and use those here. + for (const auto& sharding : s.dynamic_arg_shardings) { + h = H::combine(std::move(h), HashShardingForJit(sharding)); + } + + for (const auto& layout : s.dynamic_arg_layouts) { + if (layout != nullptr) { + h = H::combine(std::move(h), *layout); + } + } + + h = H::combine(std::move(h), s.committed_args, s.device); + + // We do not hash the extra_jit_context fields since calling Python hash + // functions is expensive (~300ns) and we don't expect a large number of + // different contexts. + return h; +} + +// The function to call in `xla.cc` to add the bindings for this module. +void BuildJaxjitSubmodule(nanobind::module_& m); + +} // namespace jax + +#endif // JAXLIB_JAX_JIT_H_ diff --git a/jaxlib/kernel_helpers.h b/jaxlib/kernel_helpers.h index dac0355fbde6..5a053f833ce4 100644 --- a/jaxlib/kernel_helpers.h +++ b/jaxlib/kernel_helpers.h @@ -17,10 +17,10 @@ limitations under the License. #define JAXLIB_KERNEL_HELPERS_H_ #include -#include #include #include "absl/base/casts.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" namespace jax { diff --git a/jaxlib/kernel_nanobind_helpers.h b/jaxlib/kernel_nanobind_helpers.h index fde37e695349..127d89f702c8 100644 --- a/jaxlib/kernel_nanobind_helpers.h +++ b/jaxlib/kernel_nanobind_helpers.h @@ -19,8 +19,8 @@ limitations under the License. #include #include -#include "nanobind/nanobind.h" #include "absl/base/casts.h" +#include "nanobind/nanobind.h" #include "jaxlib/kernel_helpers.h" #include "xla/ffi/api/c_api.h" #include "xla/tsl/python/lib/core/numpy.h" // NOLINT diff --git a/jaxlib/libjax_common.lds b/jaxlib/libjax_common.lds new file mode 100644 index 000000000000..6130415a8d26 --- /dev/null +++ b/jaxlib/libjax_common.lds @@ -0,0 +1,7 @@ +{ + global: + Wrapped_PyInit_*; + + local: + *; +}; diff --git a/jaxlib/libjax_common_darwin.lds b/jaxlib/libjax_common_darwin.lds new file mode 100644 index 000000000000..aed9a1d7512a --- /dev/null +++ b/jaxlib/libjax_common_darwin.lds @@ -0,0 +1 @@ +*Wrapped_PyInit_* diff --git a/jaxlib/mlir.cc b/jaxlib/mlir.cc new file mode 100644 index 000000000000..61eda347d215 --- /dev/null +++ b/jaxlib/mlir.cc @@ -0,0 +1,277 @@ +/* Copyright 2021 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/mlir.h" + +#include +#include + +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Bytecode/BytecodeWriter.h" +#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/OwningOpRef.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LogicalResult.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "shardy/dialect/mpmd/ir/dialect.h" +#include "shardy/dialect/sdy/ir/dialect.h" +#include "stablehlo/dialect/Serialization.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/translate/stablehlo.h" +#include "xla/mlir_hlo/mhlo/transforms/passes.h" +#include "xla/pjrt/mlir_to_hlo.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/refine_polymorphic_shapes.h" +#include "xla/python/version.h" +#include "xla/service/hlo.pb.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" + +namespace nb = nanobind; + +namespace jax { +namespace { + +std::string PrintModule(mlir::ModuleOp module) { + std::string s; + llvm::raw_string_ostream os(s); + mlir::OpPrintingFlags flags; + flags.enableDebugInfo(); + module->print(os, flags); + return s; +} + +absl::StatusOr SerializeUsingBytecode(mlir::ModuleOp module) { + std::string bytecode; + llvm::raw_string_ostream os(bytecode); + mlir::BytecodeWriterConfig config; + if (mlir::failed(mlir::writeBytecodeToFile(module, os, config))) { + return absl::InvalidArgumentError("mlir::writeBytecodeToFile failed"); + } + return bytecode; +} + +void EnablePrintBeforeAndAfter(mlir::PassManager& pm) { + auto print_before = [](mlir::Pass*, mlir::Operation*) { return true; }; + auto print_after = [](mlir::Pass*, mlir::Operation*) { return true; }; + pm.enableIRPrinting(print_before, print_after); +} + +absl::StatusOr HloToStableHlo(const nb::bytes& hlo_module_proto) { + mlir::MLIRContext context; + if (VLOG_IS_ON(3)) context.disableMultithreading(); + xla::HloModuleProto proto; + proto.ParseFromArray(hlo_module_proto.c_str(), hlo_module_proto.size()); + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ConvertHloToStablehlo(context, &proto)); + TF_ASSIGN_OR_RETURN(std::string bytecode, SerializeUsingBytecode(*module)); + return nb::bytes(bytecode.data(), bytecode.size()); +} + +// Converts an XlaComputation to a StableHLO mlir::Module string. +// Exists for backwards compatibility. +// TODO(phawkins): port remaining users of XlaComputations to use mlir::Modules +// instead and delete this function. +absl::StatusOr PyXlaComputationToMlirModule( + const xla::XlaComputation& computation) { + mlir::MLIRContext context; + if (VLOG_IS_ON(3)) context.disableMultithreading(); + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ConvertHloToStablehlo(context, &computation.proto())); + return PrintModule(*module); +} + +absl::StatusOr PyMlirModuleToXlaComputation( + std::string_view mlir_module, bool use_tuple_args, bool return_tuple) { + mlir::MLIRContext context; + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + xla::ParseMlirModuleString(mlir_module, context)); + xla::XlaComputation computation; + TF_RETURN_IF_ERROR(MlirToXlaComputation(*module, computation, use_tuple_args, + return_tuple, + /*exec_build_options=*/nullptr)); + return computation; +} + +absl::StatusOr PyMhloToStablehlo(std::string_view mlir_module) { + mlir::MLIRContext context; + if (VLOG_IS_ON(3)) context.disableMultithreading(); + // JAX can be customized in a way that involves operations from custom + // dialects showing up in JAX IR. + // `xla::ParseMlirModuleString` won't know about these dialects, but that's + // fine since we just want to convert MHLO ops to StableHLO ops here and leave + // everything else unchanged. + // In order to achieve that, we're allowing unregistered dialects here. + context.allowUnregisteredDialects(true); + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + xla::ParseMlirModuleString(mlir_module, context)); + mlir::PassManager pm(&context); + if (VLOG_IS_ON(3)) EnablePrintBeforeAndAfter(pm); + pm.addPass(mlir::mhlo::createHloLegalizeToStablehloPass()); + if (!mlir::succeeded(pm.run(*module))) { + return tsl::errors::InvalidArgument("MHLO => StableHLO failed"); + } + // Use bytecode, passing unregistered dialects with properties causes issues + // when using textual assembly. + TF_ASSIGN_OR_RETURN(std::string bytecode, SerializeUsingBytecode(*module)); + return nb::bytes(bytecode.data(), bytecode.size()); +} + +absl::StatusOr PySerializePortableArtifact( + std::string_view mlir_module, std::string_view target, + bool use_mixed_serialization) { + mlir::MLIRContext context; + context.loadDialect(); + if (VLOG_IS_ON(3)) context.disableMultithreading(); + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + xla::ParseMlirModuleString(mlir_module, context)); + + // Serialize portable artifact + TF_ASSIGN_OR_RETURN( + std::string bytecode, + xla::SerializeUsingVersionedStablehlo(*module, target, /*inplace=*/true, + /*allow_mixed_serialization*/ + use_mixed_serialization)); + return nb::bytes(bytecode.data(), bytecode.size()); +} + +absl::StatusOr PyDeserializePortableArtifact( + const nb::bytes& bytecode_str) { + mlir::MLIRContext context; + context.loadDialect(); + mlir::OwningOpRef module = + mlir::stablehlo::deserializePortableArtifact( + std::string_view(bytecode_str.c_str(), bytecode_str.size()), + &context); + if (!module) + return tsl::errors::InvalidArgument("Failed to deserialize StableHLO"); + return PrintModule(*module); +} + +} // namespace + +void BuildMlirSubmodule(nb::module_& m) { + nb::module_ mlir_module = m.def_submodule("mlir", "MLIR/XLA integration"); + + mlir_module.attr("_XlaComputation") = m.attr("XlaComputation"); + + mlir_module.def("hlo_to_stablehlo", xla::ValueOrThrowWrapper(HloToStableHlo), + nb::arg("computation")); + + mlir_module.def("xla_computation_to_mlir_module", + xla::ValueOrThrowWrapper(PyXlaComputationToMlirModule), + nb::arg("computation"), + nb::sig( + // clang-format off + "def xla_computation_to_mlir_module(computation: _XlaComputation) -> str" + // clang-format on + )); + mlir_module.def( + "mlir_module_to_xla_computation", + [](const nb::bytes& bytecode, bool use_tuple_args, bool return_tuple) { + return xla::ValueOrThrow(PyMlirModuleToXlaComputation( + std::string_view(bytecode.c_str(), bytecode.size()), use_tuple_args, + return_tuple)); + }, + nb::arg("mlir_module"), nb::arg("use_tuple_args") = false, + nb::arg("return_tuple") = false, + nb::sig( + // clang-format off + "def mlir_module_to_xla_computation(" + "mlir_module: bytes, " + "use_tuple_args: bool = ..., " + "return_tuple: bool = ..." + ") -> _XlaComputation" + // clang-format on + )); + mlir_module.def("mlir_module_to_xla_computation", + xla::ValueOrThrowWrapper(PyMlirModuleToXlaComputation), + nb::arg("mlir_module"), nb::arg("use_tuple_args") = false, + nb::arg("return_tuple") = false, + nb::sig( + // clang-format off + "def mlir_module_to_xla_computation(" + "mlir_module: str, " + "use_tuple_args: bool = ..., " + "return_tuple: bool = ..." + ") -> _XlaComputation" + // clang-format on + )); + mlir_module.def( + "mhlo_to_stablehlo", + [](const nb::bytes& bytecode) { + return xla::ValueOrThrow(PyMhloToStablehlo( + std::string_view(bytecode.c_str(), bytecode.size()))); + }, + nb::arg("mlir_module")); + mlir_module.def("mhlo_to_stablehlo", + xla::ValueOrThrowWrapper(PyMhloToStablehlo), + nb::arg("mlir_module")); + mlir_module.def( + "serialize_portable_artifact", + [](const nb::bytes& bytecode, std::string_view target, + bool use_mixed_serialization) { + return xla::ValueOrThrow(PySerializePortableArtifact( + std::string_view(bytecode.c_str(), bytecode.size()), target, + use_mixed_serialization)); + }, + nb::arg("mlir_module"), nb::arg("target"), + nb::arg("use_mixed_serialization") = false); + mlir_module.def( + "serialize_portable_artifact", + [](std::string_view mlir_module, std::string_view target, + bool use_mixed_serialization) { + return xla::ValueOrThrow(PySerializePortableArtifact( + mlir_module, target, use_mixed_serialization)); + }, + nb::arg("mlir_module"), nb::arg("target"), + nb::arg("use_mixed_serialization") = false); + mlir_module.def("deserialize_portable_artifact", + xla::ValueOrThrowWrapper(PyDeserializePortableArtifact), + nb::arg("mlir_module")); + mlir_module.def( + "refine_polymorphic_shapes", + [](nb::bytes bytecode, bool enable_shape_assertions, + bool validate_static_shapes, bool enable_shardy) -> nb::bytes { + std::string buffer; + llvm::raw_string_ostream os(buffer); + xla::ThrowIfError(xla::RefinePolymorphicShapes( + std::string_view(bytecode.c_str(), bytecode.size()), os, + enable_shape_assertions, validate_static_shapes, enable_shardy)); + return nb::bytes(buffer.data(), buffer.size()); + }, + nb::arg("mlir_module"), nb::arg("enable_shape_assertions") = true, + nb::arg("validate_static_shapes") = true, + nb::arg("enable_shardy") = false, + R"(Refines the dynamic shapes for a module. + The "main" function must have static shapes and all the + intermediate dynamic shapes depend only on the input static + shapes. Optionally, also validates that the resulting module has + only static shapes. + )"); +} + +} // namespace jax diff --git a/jaxlib/mlir.h b/jaxlib/mlir.h new file mode 100644 index 000000000000..4d48bc3f4c3a --- /dev/null +++ b/jaxlib/mlir.h @@ -0,0 +1,28 @@ +/* Copyright 2021 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_MLIR_H_ +#define JAXLIB_MLIR_H_ + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" + +namespace jax { + +void BuildMlirSubmodule(nanobind::module_& m); + +} // namespace jax + +#endif // JAXLIB_MLIR_H_ diff --git a/jaxlib/mlir/BUILD.bazel b/jaxlib/mlir/BUILD.bazel index de7b017355fc..56bdd9de4d0e 100644 --- a/jaxlib/mlir/BUILD.bazel +++ b/jaxlib/mlir/BUILD.bazel @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("@rules_python//python:py_library.bzl", "py_library") load("//jaxlib:symlink_files.bzl", "symlink_files", "symlink_inputs") package( @@ -46,7 +47,6 @@ symlink_inputs( symlinked_inputs = {"srcs": { ".": [ "@llvm-project//mlir/python:IRPyFiles", - "@llvm-project//mlir/python:IRPyIFiles", ], }}, deps = [ @@ -65,7 +65,7 @@ symlink_inputs( name = "func_dialect", rule = py_library, symlinked_inputs = {"srcs": { - "dialects": ["@llvm-project//mlir/python:FuncPyFiles"], + "dialects": ["@llvm-project//mlir/python:FuncPyFiles"], }}, deps = [ ":core", @@ -78,7 +78,7 @@ symlink_inputs( name = "vector_dialect", rule = py_library, symlinked_inputs = {"srcs": { - "dialects": ["@llvm-project//mlir/python:VectorOpsPyFiles"], + "dialects": ["@llvm-project//mlir/python:VectorOpsPyFiles"], }}, deps = [ ":core", @@ -91,7 +91,7 @@ symlink_inputs( name = "math_dialect", rule = py_library, symlinked_inputs = {"srcs": { - "dialects": ["@llvm-project//mlir/python:MathOpsPyFiles"], + "dialects": ["@llvm-project//mlir/python:MathOpsPyFiles"], }}, deps = [ ":core", @@ -104,7 +104,7 @@ symlink_inputs( name = "arithmetic_dialect", rule = py_library, symlinked_inputs = {"srcs": { - "dialects": ["@llvm-project//mlir/python:ArithOpsPyFiles"], + "dialects": ["@llvm-project//mlir/python:ArithOpsPyFiles"], }}, deps = [ ":core", @@ -117,7 +117,20 @@ symlink_inputs( name = "memref_dialect", rule = py_library, symlinked_inputs = {"srcs": { - "dialects": ["@llvm-project//mlir/python:MemRefOpsPyFiles"], + "dialects": ["@llvm-project//mlir/python:MemRefOpsPyFiles"], + }}, + deps = [ + ":core", + ":ir", + ":mlir", + ], +) + +symlink_inputs( + name = "control_flow_dialect", + rule = py_library, + symlinked_inputs = {"srcs": { + "dialects": ["@llvm-project//mlir/python:ControlFlowOpsPyFiles"], }}, deps = [ ":core", @@ -130,7 +143,7 @@ symlink_inputs( name = "scf_dialect", rule = py_library, symlinked_inputs = {"srcs": { - "dialects": ["@llvm-project//mlir/python:SCFPyFiles"], + "dialects": ["@llvm-project//mlir/python:SCFPyFiles"], }}, deps = [ ":core", @@ -143,7 +156,7 @@ symlink_inputs( name = "builtin_dialect", rule = py_library, symlinked_inputs = {"srcs": { - "dialects": ["@llvm-project//mlir/python:BuiltinOpsPyFiles"], + "dialects": ["@llvm-project//mlir/python:BuiltinOpsPyFiles"], }}, deps = [ ":core", @@ -157,7 +170,7 @@ symlink_inputs( name = "chlo_dialect", rule = py_library, symlinked_inputs = {"srcs": { - "dialects": ["@stablehlo//:chlo_ops_py_files"], + "dialects": ["@stablehlo//:chlo_ops_py_files"], }}, deps = [ ":core", @@ -171,7 +184,7 @@ symlink_inputs( name = "sparse_tensor_dialect", rule = py_library, symlinked_inputs = {"srcs": { - "dialects": ["@llvm-project//mlir/python:SparseTensorOpsPyFiles"], + "dialects": ["@llvm-project//mlir/python:SparseTensorOpsPyFiles"], }}, deps = [ ":core", @@ -186,7 +199,7 @@ symlink_inputs( name = "mhlo_dialect", rule = py_library, symlinked_inputs = {"srcs": { - "dialects": ["@xla//xla/mlir_hlo:MhloOpsPyFiles"], + "dialects": ["@xla//xla/mlir_hlo:MhloOpsPyFiles"], }}, deps = [ ":core", @@ -202,7 +215,6 @@ symlink_inputs( symlinked_inputs = {"srcs": { ".": [ "@llvm-project//mlir/python:PassManagerPyFiles", - "@llvm-project//mlir/python:PassManagerPyIFiles", ], }}, deps = [ @@ -224,11 +236,25 @@ symlink_inputs( ], ) +symlink_inputs( + name = "mpmd_dialect", + rule = py_library, + symlinked_inputs = {"srcs": { + "dialects": ["@shardy//shardy/integrations/python/ir/mpmd:mpmd_ops_py_files"], + }}, + deps = [ + ":core", + ":ir", + ":mlir", + "//jaxlib/mlir/_mlir_libs:_sdyMpmd", + ], +) + symlink_inputs( name = "stablehlo_dialect", rule = py_library, symlinked_inputs = {"srcs": { - "dialects": ["@stablehlo//:stablehlo_ops_py_files"], + "dialects": ["@stablehlo//:stablehlo_ops_py_files"], }}, deps = [ ":core", diff --git a/jaxlib/mlir/_mlir_libs/BUILD.bazel b/jaxlib/mlir/_mlir_libs/BUILD.bazel index fb94837cff37..60220f6338de 100644 --- a/jaxlib/mlir/_mlir_libs/BUILD.bazel +++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel @@ -12,13 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("@rules_python//python:py_library.bzl", "py_library") load( "//jaxlib:jax.bzl", "if_windows", - "nanobind_extension", - "py_extension", - "windows_cc_shared_mlir_library", ) +load("//jaxlib:pywrap.bzl", "nanobind_pywrap_extension") load("//jaxlib:symlink_files.bzl", "symlink_inputs") package( @@ -33,134 +32,125 @@ COPTS = [ "-frtti", ] -LINKOPTS = select({ - "@xla//xla/tsl:macos": [ - "-Wl,-rpath,@loader_path/", - "-Wl,-rename_section,__TEXT,text_env,__TEXT,__text", - ], - "@xla//xla/tsl:windows": [], - "//conditions:default": [ - "-Wl,-rpath,$$ORIGIN/", - ], -}) - -py_extension( +nanobind_pywrap_extension( name = "_mlir", srcs = [ "@llvm-project//mlir:lib/Bindings/Python/MainModule.cpp", ], copts = COPTS, - linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:MLIRBindingsPythonCoreNoCAPI", + "@llvm-project//mlir:CAPIDebug", + "@llvm-project//mlir:CAPITransforms", + "@llvm-project//mlir:MLIRBindingsPythonCore", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@nanobind", ], ) -py_extension( +nanobind_pywrap_extension( name = "_mlirDialectsGPU", srcs = [ "@llvm-project//mlir:lib/Bindings/Python/DialectGPU.cpp", ], copts = COPTS, - linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPIGPUHeaders", - "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:CAPIDebug", + "@llvm-project//mlir:CAPIGPU", + "@llvm-project//mlir:CAPIIR", + "@llvm-project//mlir:CAPITransforms", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@nanobind", ], ) -py_extension( +nanobind_pywrap_extension( name = "_mlirGPUPasses", srcs = [ "@llvm-project//mlir:lib/Bindings/Python/GPUPasses.cpp", ], copts = COPTS, - linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPIGPUHeaders", + "@llvm-project//mlir:CAPIDebug", + "@llvm-project//mlir:CAPIGPU", + "@llvm-project//mlir:CAPITransforms", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@nanobind", ], ) -py_extension( +nanobind_pywrap_extension( name = "_mlirDialectsNVGPU", srcs = [ "@llvm-project//mlir:lib/Bindings/Python/DialectNVGPU.cpp", ], copts = COPTS, - linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPIIRHeaders", - "@llvm-project//mlir:CAPINVGPUHeaders", + "@llvm-project//mlir:CAPIDebug", + "@llvm-project//mlir:CAPIIR", + "@llvm-project//mlir:CAPINVGPU", + "@llvm-project//mlir:CAPITransforms", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@nanobind", ], ) -py_extension( +nanobind_pywrap_extension( name = "_mlirDialectsLLVM", srcs = [ "@llvm-project//mlir:lib/Bindings/Python/DialectLLVM.cpp", ], copts = COPTS, - linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPIIRHeaders", - "@llvm-project//mlir:CAPILLVMHeaders", + "@llvm-project//mlir:CAPIDebug", + "@llvm-project//mlir:CAPIIR", + "@llvm-project//mlir:CAPILLVM", + "@llvm-project//mlir:CAPITarget", + "@llvm-project//mlir:CAPITransforms", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@nanobind", ], ) -py_extension( +nanobind_pywrap_extension( name = "_mlirDialectsSparseTensor", srcs = [ "@llvm-project//mlir:lib/Bindings/Python/DialectSparseTensor.cpp", ], copts = COPTS, - linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPISparseTensorHeaders", + "@llvm-project//mlir:CAPIDebug", + "@llvm-project//mlir:CAPISparseTensor", + "@llvm-project//mlir:CAPITransforms", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@nanobind", ], ) -py_extension( +nanobind_pywrap_extension( name = "_mlirSparseTensorPasses", srcs = [ "@llvm-project//mlir:lib/Bindings/Python/SparseTensorPasses.cpp", ], copts = COPTS, - linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPISparseTensorHeaders", + "@llvm-project//mlir:CAPIDebug", + "@llvm-project//mlir:CAPISparseTensor", + "@llvm-project//mlir:CAPITransforms", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@nanobind", ], ) -py_extension( +nanobind_pywrap_extension( name = "_mosaic_gpu_ext", srcs = ["mosaic_gpu_ext.cc"], copts = COPTS, - linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi_headers", - "@llvm-project//mlir:CAPIIRHeaders", + "//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi", + "//jaxlib/mosaic/gpu:tiled_layout", + "@llvm-project//mlir:CAPIDebug", + "@llvm-project//mlir:CAPIIR", + "@llvm-project//mlir:CAPITransforms", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeadersAndDeps", "@nanobind", ], @@ -171,17 +161,17 @@ py_extension( # :jaxlib_mlir_capi_shared_library). This ensures that the RPATH works correctly # across platforms. It's not clear if Windows supports RPATH-like functionality # across different directories at all. -py_extension( +nanobind_pywrap_extension( name = "_tpu_ext", srcs = ["tpu_ext.cc"], copts = COPTS, - linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "//jaxlib/mosaic:tpu_dialect_capi_headers", + "//jaxlib/mosaic:tpu_dialect_capi", "@com_google_absl//absl/log:check", "@llvm-project//llvm:Support", - "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:CAPIDebug", + "@llvm-project//mlir:CAPIIR", + "@llvm-project//mlir:CAPITransforms", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeadersAndDeps", "@nanobind", "@xla//xla/python:nb_numpy", @@ -190,7 +180,7 @@ py_extension( ) # This target contains the extension and it's Python dependencies, which are not -# supported by the `py_extension`/`nanobind_extension` macros. +# supported by the `nanobind_pywrap_extension`/`nanobind_extension` macros. py_library( name = "_tpu_ext_lib", deps = [ @@ -200,19 +190,23 @@ py_library( ], ) -nanobind_extension( +nanobind_pywrap_extension( name = "_triton_ext", srcs = ["triton_ext.cc"], copts = COPTS, - linkopts = LINKOPTS, pytype_srcs = ["_triton_ext.pyi"], deps = [ - ":jaxlib_mlir_capi_shared_library", - "//jaxlib/triton:triton_dialect_capi_headers", - "@llvm-project//mlir:CAPIIRHeaders", - "@llvm-project//mlir:MLIRBindingsPythonNanobindHeadersAndDeps", "@nanobind", - ], + ] + if_windows( + [], + [ + "//jaxlib/triton:triton_dialect_capi", + "@llvm-project//mlir:CAPIIR", + "@llvm-project//mlir:MLIRBindingsPythonNanobindHeadersAndDeps", + "@llvm-project//mlir:CAPIDebug", + "@llvm-project//mlir:CAPITransforms", + ], + ), ) symlink_inputs( @@ -224,60 +218,73 @@ symlink_inputs( ], }}, deps = [ + ":_jax_mlir_ext", ":_mlir", - ":register_jax_dialects", ], ) -cc_library( - name = "jaxlib_mlir_capi_shims", - srcs = ["jaxlib_mlir_capi_shims.cc"], - hdrs = ["jaxlib_mlir_capi_shims.h"], - deps = [ - "@llvm-project//mlir:BuiltinToLLVMIRTranslation", - "@llvm-project//mlir:CAPIIRHeaders", - "@llvm-project//mlir:GPUPipelines", - "@llvm-project//mlir:GPUToLLVMIRTranslation", - "@llvm-project//mlir:LLVMToLLVMIRTranslation", - "@llvm-project//mlir:MemRefTransforms", - "@llvm-project//mlir:NVVMTarget", - "@llvm-project//mlir:NVVMToLLVMIRTranslation", - ], - alwayslink = 1, -) +# JAX-specific registrations. cc_library( - name = "jaxlib_mlir_capi_shims_hdrs", - hdrs = ["jaxlib_mlir_capi_shims.h"], + name = "traceback_to_location", + srcs = ["traceback_to_location.cc"], + hdrs = ["traceback_to_location.h"], + copts = ["-fexceptions"], deps = [ - "@llvm-project//mlir:CAPIIRHeaders", + "//jaxlib:traceback", + "@com_google_absl//absl/base", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/types:span", + "@llvm-project//mlir:CAPIDebug", + "@llvm-project//mlir:CAPIIR", + "@llvm-project//mlir:CAPITransforms", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", + "@nanobind", ], ) -# JAX-specific registrations. -py_extension( - name = "register_jax_dialects", - srcs = ["register_jax_dialects.cc"], +nanobind_pywrap_extension( + name = "_jax_mlir_ext", + srcs = ["jax_mlir_ext.cc"], copts = COPTS, - linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "//jaxlib/mosaic/gpu:mlir_capi_headers", - "@llvm-project//mlir:CAPIArithHeaders", - "@llvm-project//mlir:CAPIGPUHeaders", - "@llvm-project//mlir:CAPIIRHeaders", - "@llvm-project//mlir:CAPILLVMHeaders", - "@llvm-project//mlir:CAPIMathHeaders", - "@llvm-project//mlir:CAPIMemRefHeaders", - "@llvm-project//mlir:CAPINVGPUHeaders", - "@llvm-project//mlir:CAPINVVMHeaders", - "@llvm-project//mlir:CAPISCFHeaders", - "@llvm-project//mlir:CAPITransformsHeaders", - "@llvm-project//mlir:CAPIVectorHeaders", + ":traceback_to_location", + "//jaxlib/mosaic/gpu:mlir_capi", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:CAPIArith", + "@llvm-project//mlir:CAPICF", + "@llvm-project//mlir:CAPIDebug", + "@llvm-project//mlir:CAPIGPU", + "@llvm-project//mlir:CAPIIR", + "@llvm-project//mlir:CAPILLVM", + "@llvm-project//mlir:CAPIMath", + "@llvm-project//mlir:CAPIMemRef", + "@llvm-project//mlir:CAPINVGPU", + "@llvm-project//mlir:CAPINVVM", + "@llvm-project//mlir:CAPISCF", + "@llvm-project//mlir:CAPITransforms", + "@llvm-project//mlir:CAPIVector", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", - "@local_config_python//:headers", + "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", "@nanobind", - "@shardy//shardy/integrations/c:sdy_capi_headers", + "@rules_python//python/cc:current_py_cc_headers", + "@shardy//shardy/dialect/sdy/ir:dialect", + "@shardy//shardy/integrations/c:sdy_capi", + "@stablehlo//:vhlo_ops", + "@xla//xla/pjrt:status_casters", + "@xla//xla/python:nb_absl_span", + "@xla//xla/service/spmd/shardy/integrations/c:xla_sdy_capi", ], ) @@ -285,20 +292,20 @@ py_extension( # MHLO Extensions ##---------------------------------------------------------------------------## -py_extension( +nanobind_pywrap_extension( name = "_mlirHlo", srcs = [ "@xla//xla/mlir_hlo:bindings/python/MlirHloModule.cc", ], copts = COPTS, - linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:CAPIDebug", + "@llvm-project//mlir:CAPIIR", + "@llvm-project//mlir:CAPITransforms", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", - "@local_config_python//:headers", "@nanobind", - "@xla//xla/mlir_hlo:CAPIHeaders", + "@rules_python//python/cc:current_py_cc_headers", + "@xla//xla/mlir_hlo:CAPI", ], ) @@ -306,21 +313,39 @@ py_extension( # Shardy Extensions ##---------------------------------------------------------------------------## -py_extension( +nanobind_pywrap_extension( name = "_sdy", srcs = [ "@shardy//shardy/integrations/python/ir:sdy_module.cc", ], copts = COPTS, - linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:CAPIDebug", + "@llvm-project//mlir:CAPIIR", + "@llvm-project//mlir:CAPITransforms", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", + "@nanobind", + "@rules_python//python/cc:current_py_cc_headers", + "@shardy//shardy/integrations/c:sdy_capi", + ], +) + +nanobind_pywrap_extension( + name = "_sdyMpmd", + srcs = [ + "@shardy//shardy/integrations/python/ir/mpmd:mpmd_module.cc", + ], + copts = COPTS, + deps = [ + "@llvm-project//mlir:CAPIDebug", + "@llvm-project//mlir:CAPIIR", + "@llvm-project//mlir:CAPITransforms", "@llvm-project//mlir:IR", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", - "@local_config_python//:headers", "@nanobind", - "@shardy//shardy/integrations/c:sdy_capi_headers", + "@rules_python//python/cc:current_py_cc_headers", + "@shardy//shardy/integrations/c/mpmd:mpmd_capi", ], ) @@ -328,115 +353,37 @@ py_extension( # Stablehlo Extensions ##---------------------------------------------------------------------------## -py_extension( +nanobind_pywrap_extension( name = "_chlo", srcs = [ "@stablehlo//:chlo_py_api_files", ], copts = COPTS, - linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:CAPIDebug", + "@llvm-project//mlir:CAPIIR", + "@llvm-project//mlir:CAPITransforms", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", - "@local_config_python//:headers", "@nanobind", - "@stablehlo//:chlo_capi_headers", + "@rules_python//python/cc:current_py_cc_headers", + "@stablehlo//:chlo_capi", ], ) -py_extension( +nanobind_pywrap_extension( name = "_stablehlo", srcs = [ "@stablehlo//:stablehlo_py_api_files", ], copts = COPTS, - linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", "@llvm-project//llvm:Support", - "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:CAPIDebug", + "@llvm-project//mlir:CAPIIR", + "@llvm-project//mlir:CAPITransforms", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", - "@local_config_python//:headers", "@nanobind", - "@stablehlo//:stablehlo_capi_headers", - ], -) - -# Shared C++ extension library - -cc_library( - name = "jaxlib_mlir_capi_shared_library", - srcs = select({ - "@xla//xla/tsl:windows": [":jaxlib_mlir_capi.dll"], - "@xla//xla/tsl:macos": [":libjaxlib_mlir_capi.dylib"], - "//conditions:default": [":libjaxlib_mlir_capi.so"], - }), - deps = select({ - "@xla//xla/tsl:windows": [":jaxlib_mlir_capi_dll"], - "//conditions:default": [], - }), -) - -cc_library( - name = "jaxlib_mlir_capi_objects", - deps = [ - "//jaxlib/mosaic:tpu_dialect_capi_objects", - "//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi_objects", - "//jaxlib/mosaic/gpu:mlir_capi_objects", - "@llvm-project//mlir:CAPIArithObjects", - "@llvm-project//mlir:CAPIGPUObjects", - "@llvm-project//mlir:CAPIIRObjects", - "@llvm-project//mlir:CAPILLVMObjects", - "@llvm-project//mlir:CAPIMathObjects", - "@llvm-project//mlir:CAPIMemRefObjects", - "@llvm-project//mlir:CAPINVGPUObjects", - "@llvm-project//mlir:CAPINVVMObjects", - "@llvm-project//mlir:CAPISCFObjects", - "@llvm-project//mlir:CAPISparseTensorObjects", - "@llvm-project//mlir:CAPITransformsObjects", - "@llvm-project//mlir:CAPIVectorObjects", - "@llvm-project//mlir:MLIRBindingsPythonCAPIObjects", - "@shardy//shardy/integrations/c:sdy_capi_objects", - "@stablehlo//:chlo_capi_objects", - "@stablehlo//:stablehlo_capi_objects", - "@xla//xla/mlir_hlo:CAPIObjects", - ] + if_windows( - [], - [ - "//jaxlib/triton:triton_dialect_capi_objects", - ], - ), -) - -cc_binary( - name = "libjaxlib_mlir_capi.so", - linkopts = [ - "-Wl,-soname=libjaxlib_mlir_capi.so", - "-Wl,-rpath='$$ORIGIN'", - ], - linkshared = 1, - deps = [":jaxlib_mlir_capi_objects"], -) - -cc_binary( - name = "libjaxlib_mlir_capi.dylib", - linkopts = [ - "-Wl,-rpath,@loader_path/", - "-Wl,-install_name,@loader_path/libjaxlib_mlir_capi.dylib", - ], - linkshared = 1, - deps = [":jaxlib_mlir_capi_objects"], -) - -windows_cc_shared_mlir_library( - name = "jaxlib_mlir_capi_dll", - out = "jaxlib_mlir_capi.dll", - exported_symbol_prefixes = [ - "mlir", - "chlo", - "sdy", - "stablehlo", + "@rules_python//python/cc:current_py_cc_headers", + "@stablehlo//:stablehlo_capi", ], - deps = [":jaxlib_mlir_capi_objects"], ) diff --git a/jaxlib/mlir/_mlir_libs/_triton_ext.pyi b/jaxlib/mlir/_mlir_libs/_triton_ext.pyi index 1e1a67405113..93a82010043c 100644 --- a/jaxlib/mlir/_mlir_libs/_triton_ext.pyi +++ b/jaxlib/mlir/_mlir_libs/_triton_ext.pyi @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from mlir import ir +from jaxlib.mlir import ir def register_dialect(context: ir.Context, load: bool = ...) -> None: ... diff --git a/jaxlib/mlir/_mlir_libs/jax_mlir_ext.cc b/jaxlib/mlir/_mlir_libs/jax_mlir_ext.cc new file mode 100644 index 000000000000..65ee1673ca7c --- /dev/null +++ b/jaxlib/mlir/_mlir_libs/jax_mlir_ext.cc @@ -0,0 +1,250 @@ +/* Copyright 2022 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Registers MLIR dialects used by JAX. +// This module is called by mlir/__init__.py during initialization. + +#include +#include +#include +#include + +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/types/span.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include "mlir-c/Dialect/Arith.h" // IWYU pragma: keep +#include "mlir-c/Dialect/ControlFlow.h" +#include "mlir-c/Dialect/Func.h" // IWYU pragma: keep +#include "mlir-c/Dialect/GPU.h" // IWYU pragma: keep +#include "mlir-c/Dialect/LLVM.h" // IWYU pragma: keep +#include "mlir-c/Dialect/Math.h" // IWYU pragma: keep +#include "mlir-c/Dialect/MemRef.h" // IWYU pragma: keep +#include "mlir-c/Dialect/NVGPU.h" // IWYU pragma: keep +#include "mlir-c/Dialect/NVVM.h" // IWYU pragma: keep +#include "mlir-c/Dialect/SCF.h" // IWYU pragma: keep +#include "mlir-c/Dialect/Vector.h" // IWYU pragma: keep +#include "mlir-c/IR.h" +#include "mlir-c/Transforms.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" // IWYU pragma: keep +#include "mlir/CAPI/IR.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Support/LLVM.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "shardy/dialect/sdy/ir/dialect.h" +#include "shardy/integrations/c/passes.h" +#include "jaxlib/mlir/_mlir_libs/traceback_to_location.h" +#include "jaxlib/mosaic/gpu/integrations/c/passes.h" +#include "stablehlo/dialect/VhloOps.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/nb_absl_span.h" // IWYU pragma: keep +#include "xla/service/spmd/shardy/integrations/c/passes.h" + +namespace nb = ::nanobind; + +namespace jax { + +namespace { + +// Returns true if a location is a NameLoc with a FileLineColLoc child. We +// assume the NameLoc names a function name in a frame in this case. +bool IsFrameNameLocation(mlir::Location location) { + return mlir::isa(location) && + mlir::isa( + mlir::cast(location).getChildLoc()); +} + +// Split a location into an operation type and an operation name, and a tail +// location. +void ParseLocation(mlir::Location& location, llvm::StringRef& op_type, + llvm::StringRef& op_name) { + while (auto name_loc = mlir::dyn_cast(location)) { + if (IsFrameNameLocation(name_loc)) { + break; + } + llvm::StringRef name = name_loc.getName().strref(); + if (name.ends_with(":")) { + op_type = name; + } else { + op_name = name; + } + location = mlir::cast(location).getChildLoc(); + } +} + +} // namespace + +absl::StatusOr> InlinedCall( + MlirOperation c_callee, absl::Span c_args, MlirBlock block, + MlirLocation loc) { + mlir::Operation* callee = unwrap(c_callee); + mlir::func::FuncOp func = llvm::cast(callee); + mlir::Region& body = func.getBody(); + if (body.getBlocks().size() != 1) { + return absl::InvalidArgumentError( + absl::StrFormat("expected function to have exactly one block, got %d", + body.getBlocks().size())); + } + mlir::Block& body_block = body.getBlocks().front(); + + mlir::OpBuilder op_builder = mlir::OpBuilder::atBlockEnd(unwrap(block)); + mlir::IRMapping mapping; + if (body_block.getNumArguments() != c_args.size()) { + return absl::InvalidArgumentError( + absl::StrFormat("expected callee to have %d arguments, got %d", + c_args.size(), body_block.getNumArguments())); + } + for (auto [arg_value, arg] : llvm::zip(body_block.getArguments(), c_args)) { + mapping.map(arg_value, unwrap(arg)); + } + + mlir::Location parent_base_loc = unwrap(loc); + llvm::StringRef parent_op_type, parent_op_name; + ParseLocation(parent_base_loc, parent_op_type, parent_op_name); + + std::optional> results; + for (mlir::Operation& op : body_block.getOperations()) { + if (llvm::isa(op)) { + if (results.has_value()) { + return absl::InternalError( + "expected function to have exactly one return op"); + } + results.emplace(); + for (mlir::Value result : op.getOperands()) { + results->push_back(wrap(mapping.lookup(result))); + } + } else { + mlir::Operation* cloned_op = op_builder.clone(op, mapping); + cloned_op->walk([&](mlir::Operation* op) { + // Compute a new location for the cloned op. + // * The name should be "parent_op_name/child_op_name" (assuming both + // are present). + // * We use the op_type of the parent. + // * We use the traceback of the parent. We want the location of the + // equation, not the location of the lowering rule. + mlir::Location child_loc = op->getLoc(); + llvm::StringRef child_op_type, child_op_name; + ParseLocation(child_loc, child_op_type, child_op_name); + + child_loc = parent_base_loc; + if (child_op_name.empty()) { + child_loc = mlir::NameLoc::get( + op_builder.getStringAttr(parent_op_name), child_loc); + } else if (parent_op_name.empty()) { + child_loc = mlir::NameLoc::get( + op_builder.getStringAttr(child_op_name), child_loc); + } else { + std::string name = + absl::StrCat(static_cast(parent_op_name), "/", + static_cast(child_op_name)); + child_loc = + mlir::NameLoc::get(op_builder.getStringAttr(name), child_loc); + } + if (!parent_op_type.empty()) { + child_loc = mlir::NameLoc::get( + op_builder.getStringAttr(parent_op_type), child_loc); + } + op->setLoc(child_loc); + if (mlir::isa(op)) { + // Skip `ManualComputationOp`s and their nested operations, they will + // be handled separately. + return mlir::WalkResult::skip(); + } + return mlir::WalkResult::advance(); + }); + } + } + if (!results.has_value()) { + return absl::InternalError( + "expected function to have exactly one return op"); + } + return *results; +} + +NB_MODULE(_jax_mlir_ext, m) { + m.doc() = "Registers upstream MLIR dialects used by JAX."; + + m.def("register_dialects", [](MlirDialectRegistry registry) { +#define REGISTER_DIALECT(name) \ + MlirDialectHandle name##_dialect = mlirGetDialectHandle__##name##__(); \ + mlirDialectHandleInsertDialect(name##_dialect, registry) + REGISTER_DIALECT(arith); + REGISTER_DIALECT(func); + REGISTER_DIALECT(math); + REGISTER_DIALECT(memref); + REGISTER_DIALECT(scf); + REGISTER_DIALECT(vector); + // TODO(jpienaar): these don't seem to have C API targets known to Bazel + unwrap(registry)->insert(); + unwrap(registry)->insert(); + unwrap(registry)->insert(); + // For Mosaic GPU + REGISTER_DIALECT(cf); + REGISTER_DIALECT(gpu); + REGISTER_DIALECT(nvgpu); + REGISTER_DIALECT(nvvm); + REGISTER_DIALECT(llvm); +#undef REGISTER_DIALECT + + mlirMosaicGpuRegisterSerdePass(); + mlirRegisterTransformsPasses(); + // For Shardy + mlirRegisterAllSdyPassesAndPipelines(); + mlirRegisterAllXlaSdyPassesAndPipelines(); + // Transforms used by JAX. + mlirRegisterTransformsStripDebugInfo(); + }); + + m.def("enter_multi_threaded_execution", [](MlirContext context) { + unwrap(context)->enterMultiThreadedExecution(); + }); + m.def("exit_multi_threaded_execution", [](MlirContext context) { + unwrap(context)->exitMultiThreadedExecution(); + }); + + m.def("inlined_func_call", xla::ValueOrThrowWrapper(InlinedCall), + nb::arg("callee"), nb::arg("args"), nb::arg("block"), + nb::arg("loc").none() = nb::none(), + "Makes an inlined call to a function containing a single block with a " + "single return op."); + + nb::class_(m, "TracebackToLocationCache") + .def( + "__init__", + [](TracebackToLocationCache* self, nb::callable code_to_filename, + int frame_limit, MlirContext context) { + new (self) TracebackToLocationCache(code_to_filename, frame_limit, + unwrap(context)); + }, + nb::arg("code_to_filename"), nb::arg("frame_limit"), + nb::arg("context").none() = nb::none()) + .def("get", &TracebackToLocationCache::Get); +} + +} // namespace jax diff --git a/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc b/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc index c73084abc99d..6ac66b6ad5f3 100644 --- a/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc +++ b/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc @@ -14,15 +14,23 @@ limitations under the License. ==============================================================================*/ #include +#include #include +#include "absl/hash/hash.h" #include "mlir-c/IR.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" // IWYU pragma: keep #include "nanobind/nanobind.h" +#include "nanobind/operators.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/tuple.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep #include "jaxlib/mosaic/dialect/gpu/integrations/c/attributes.h" #include "jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h" +#include "jaxlib/mosaic/gpu/tiled_layout.h" namespace nb = nanobind; +namespace mgpu = jax::mosaic::gpu; NB_MODULE(_mosaic_gpu_ext, m) { m.def( @@ -35,8 +43,13 @@ NB_MODULE(_mosaic_gpu_ext, m) { } }, nb::arg("context"), nb::arg("load") = true); - m.def("private_operation_remove_from_parent", mlirOperationRemoveFromParent); - m.def("private_block_append_owned_operation", mlirBlockAppendOwnedOperation); + + m.def("register_inliner_extensions", [](MlirContext context) { + MlirDialectRegistry registry = mlirDialectRegistryCreate(); + mlirDialectRegistryInsertMosaicGpuInlinerExtensions(registry); + mlirContextAppendDialectRegistry(context, registry); + mlirDialectRegistryDestroy(registry); + }); mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "TileTransformAttr", mlirMosaicGpuIsATileTransformAttr) @@ -139,24 +152,236 @@ NB_MODULE(_mosaic_gpu_ext, m) { return mlirMosaicGpuSwizzleTransformAttrGetSwizzle(self); }); - mlir::python::nanobind_adaptors::mlir_attribute_subclass( - m, "LayoutAttr", mlirMosaicGpuIsALayoutAttr) - .def_classmethod( - "get", - [](nb::object cls, int32_t num_dimensions, - std::vector& transforms, MlirContext ctx) { - return cls(mlirMosaicGpuLayoutAttrGet( - ctx, num_dimensions, transforms.data(), transforms.size())); + nb::class_(m, "Tiling") + .def( + "__init__", + [](mgpu::Tiling* self, nb::iterable in_tiles) { + std::vector> tiles; + for (const auto& tile : in_tiles) { + tiles.push_back(nb::cast>(tile)); + } + auto result = mgpu::Tiling::Create(tiles); + if (!result.ok()) { + throw nb::value_error(result.status().message().data()); + } + new (self) mgpu::Tiling(*result); }, - nb::arg("cls"), nb::arg("num_dimensions"), nb::arg("transforms"), - nb::arg("context").none() = nb::none(), - "Creates a LayoutAttr with the given transforms.") - .def_property_readonly("transforms", [](MlirAttribute self) { - std::vector result; - for (int i = 0; i < mlirMosaicGpuLayoutAttrGetTransformsSize(self); - ++i) { - result.push_back(mlirMosaicGpuLayoutAttrGetTransform(self, i)); + nb::arg("tiles")) + .def( + "tile_shape", + [](const mgpu::Tiling& self, const std::vector& shape) { + auto result = self.TileShape(shape); + if (!result.ok()) { + throw nb::value_error(result.status().message().data()); + } + return nb::tuple(nb::cast(*result)); + }, + nb::arg("shape")) + .def( + "untile_shape", + [](const mgpu::Tiling& self, const std::vector& shape) { + auto result = self.UntileShape(shape); + if (!result.ok()) { + throw nb::value_error(result.status().message().data()); + } + return nb::tuple(nb::cast(*result)); + }, + nb::arg("shape")) + .def( + "tile_strides", + [](const mgpu::Tiling& self, const std::vector& strides) { + return nb::tuple(nb::cast(self.TileStrides(strides))); + }, + nb::arg("strides")) + .def( + "tile_indices", + [](const mgpu::Tiling& self, const std::vector& indices) { + return nb::tuple(nb::cast(self.TileIndices(indices))); + }, + nb::arg("indices")) + .def( + "untile_indices", + [](const mgpu::Tiling& self, const std::vector& indices) { + return nb::tuple(nb::cast(self.UntileIndices(indices))); + }, + nb::arg("indices")) + .def( + "tile_nested_shape_strides", + [](const mgpu::Tiling& self, + const std::vector>& shape, + const std::vector>& strides) { + auto result = self.TileNestedShapeStrides(shape, strides); + if (!result.ok()) { + throw nb::value_error(result.status().message().data()); + } + auto [tiled_shape, tiled_strides] = *result; + nb::list shape_list; + for (const auto& s : tiled_shape) { + shape_list.append(nb::tuple(nb::cast(s))); + } + nb::list strides_list; + for (const auto& s : tiled_strides) { + strides_list.append(nb::tuple(nb::cast(s))); + } + return nb::make_tuple(nb::tuple(shape_list), + nb::tuple(strides_list)); + }, + nb::arg("shape"), nb::arg("strides")) + .def( + "tile_dimension", + [](const mgpu::Tiling& self, int64_t dim) { + auto result = self.TileDimension(dim); + if (!result.ok()) { + throw nb::value_error(result.status().message().data()); + } + return nb::tuple(nb::cast(*result)); + }, + nb::arg("dim")) + .def( + "remove_dimension", + [](const mgpu::Tiling& self, int64_t dim) { + auto result = self.RemoveDimension(dim); + if (!result.ok()) { + throw nb::value_error(result.status().message().data()); + } + return *result; + }, + nb::arg("dim")) + .def("canonicalize", &mgpu::Tiling::Canonicalize) + .def_prop_ro("tiles", + [](const mgpu::Tiling& self) { + nb::list tiles_list; + for (const mgpu::Tiling::Tile& tile : self.tiles()) { + tiles_list.append(nb::tuple(nb::cast(tile))); + } + return nb::tuple(tiles_list); + }) + .def("__str__", &mgpu::Tiling::ToString) + .def("__repr__", &mgpu::Tiling::ToString) + .def(nb::self == nb::self) + .def("__hash__", [](const mgpu::Tiling& self) { + return absl::Hash{}(self); + }); + + nb::class_(m, "Replicated") + .def(nb::init(), nb::arg("times")) + .def_prop_rw( + "times", [](const mgpu::Replicated& self) { return self.times; }, + [](mgpu::Replicated& self, int64_t times) { self.times = times; }) + .def("__repr__", &mgpu::Replicated::ToString) + .def("__hash__", + [](const mgpu::Replicated& self) { + return absl::Hash{}(self); + }) + .def("__eq__", [](const mgpu::Replicated& self, nb::object other) { + if (!nb::isinstance(other)) { + return false; } - return result; + return self == nb::cast(other); }); + + nb::class_(m, "TiledLayout") + .def( + "__init__", + [](mgpu::TiledLayout* self, mgpu::Tiling tiling, + nb::iterable in_warp_dims, nb::iterable in_lane_dims, + int64_t vector_dim, bool check_canonical) { + std::vector warp_dims; + for (const auto& dim : in_warp_dims) { + if (nb::isinstance(dim)) { + warp_dims.emplace_back(nb::cast(dim)); + } else { + warp_dims.emplace_back(nb::cast(dim)); + } + } + std::vector lane_dims; + for (const auto& dim : in_lane_dims) { + if (nb::isinstance(dim)) { + lane_dims.emplace_back(nb::cast(dim)); + } else { + lane_dims.emplace_back(nb::cast(dim)); + } + } + auto result = mgpu::TiledLayout::Create( + tiling, warp_dims, lane_dims, vector_dim, check_canonical); + if (!result.ok()) { + throw nb::value_error(result.status().message().data()); + } + new (self) mgpu::TiledLayout(*result); + }, + nb::arg("tiling"), nb::arg("warp_dims"), nb::arg("lane_dims"), + nb::arg("vector_dim"), nb::arg("_check_canonical") = true) + .def_prop_ro("warp_dims", + [](const mgpu::TiledLayout& self) { + nb::list l; + for (const auto& d : self.warp_dims()) { + if (std::holds_alternative(d)) { + l.append(nb::cast(std::get(d))); + } else { + l.append(nb::cast(std::get(d))); + } + } + return nb::tuple(l); + }) + .def_prop_ro("lane_dims", + [](const mgpu::TiledLayout& self) { + nb::list l; + for (const auto& d : self.lane_dims()) { + if (std::holds_alternative(d)) { + l.append(nb::cast(std::get(d))); + } else { + l.append(nb::cast(std::get(d))); + } + } + return nb::tuple(l); + }) + .def_prop_ro("partitioned_warp_dims", + [](const mgpu::TiledLayout& self) { + return nb::tuple(nb::cast(self.PartitionedWarpDims())); + }) + .def_prop_ro("partitioned_lane_dims", + [](const mgpu::TiledLayout& self) { + return nb::tuple(nb::cast(self.PartitionedLaneDims())); + }) + .def_prop_ro("vector_length", + [](const mgpu::TiledLayout& self) { + auto result = self.VectorLength(); + if (!result.ok()) { + throw nb::value_error(result.status().message().data()); + } + return nb::cast(*result); + }) + .def_prop_ro("vector_dim", &mgpu::TiledLayout::vector_dim) + .def_prop_ro("tiling", &mgpu::TiledLayout::tiling) + .def_prop_ro("tiled_tiling_shape", + [](const mgpu::TiledLayout& self) { + auto result = self.TiledTilingShape(); + if (!result.ok()) { + throw nb::value_error(result.status().message().data()); + } + return nb::tuple(nb::cast(*self.TiledTilingShape())); + }) + .def("canonicalize", + [](const mgpu::TiledLayout& self) { + auto result = self.Canonicalize(); + if (!result.ok()) { + throw nb::value_error(result.status().message().data()); + } + return *result; + }) + .def("__str__", &mgpu::TiledLayout::ToString) + .def("__repr__", &mgpu::TiledLayout::ToString) + .def("__hash__", + [](const mgpu::TiledLayout& self) { + return absl::Hash{}(self); + }) + .def( + "__eq__", + [](const mgpu::TiledLayout& self, nb::object other) -> bool { + if (!nb::isinstance(other)) { + return false; + } + return self == nb::cast(other); + }, + nb::arg("other").none()); } diff --git a/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc b/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc deleted file mode 100644 index 9da841acc7de..000000000000 --- a/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc +++ /dev/null @@ -1,49 +0,0 @@ -// Registers MLIR dialects used by JAX. -// This module is called by mlir/__init__.py during initialization. -#include - -#include "mlir-c/Dialect/Arith.h" -#include "mlir-c/Dialect/Func.h" -#include "mlir-c/Dialect/GPU.h" -#include "mlir-c/Dialect/LLVM.h" -#include "mlir-c/Dialect/Math.h" -#include "mlir-c/Dialect/MemRef.h" -#include "mlir-c/Dialect/NVGPU.h" -#include "mlir-c/Dialect/NVVM.h" -#include "mlir-c/Dialect/SCF.h" -#include "mlir-c/Dialect/Vector.h" -#include "mlir-c/Transforms.h" -#include "mlir/Bindings/Python/NanobindAdaptors.h" -#include "shardy/integrations/c/passes.h" -#include "jaxlib/mosaic/gpu/integrations/c/passes.h" - - -namespace nb = nanobind; - -#define REGISTER_DIALECT(name) \ - MlirDialectHandle name##_dialect = mlirGetDialectHandle__##name##__(); \ - mlirDialectHandleInsertDialect(name##_dialect, registry) - -NB_MODULE(register_jax_dialects, m) { - m.doc() = "Registers upstream MLIR dialects used by JAX."; - - m.def("register_dialects", [](MlirDialectRegistry registry) { - REGISTER_DIALECT(arith); - REGISTER_DIALECT(func); - REGISTER_DIALECT(math); - REGISTER_DIALECT(memref); - REGISTER_DIALECT(scf); - REGISTER_DIALECT(vector); - // For Mosaic GPU - REGISTER_DIALECT(gpu); - REGISTER_DIALECT(nvgpu); - REGISTER_DIALECT(nvvm); - REGISTER_DIALECT(llvm); - mlirMosaicGpuRegisterPasses(); - mlirRegisterTransformsPasses(); - // For Shardy - mlirRegisterAllSdyPassesAndPipelines(); - // Transforms used by JAX. - mlirRegisterTransformsStripDebugInfo(); - }); -} diff --git a/jaxlib/mlir/_mlir_libs/tpu_ext.cc b/jaxlib/mlir/_mlir_libs/tpu_ext.cc index 2b5ec898ad3e..c8620d0fe7a6 100644 --- a/jaxlib/mlir/_mlir_libs/tpu_ext.cc +++ b/jaxlib/mlir/_mlir_libs/tpu_ext.cc @@ -13,760 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include - -#include -#include -#include -#include -#include -#include -#include #include #include -#include -#include -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/SmallVectorExtras.h" -#include "llvm/ADT/StringExtras.h" -#include "llvm/ADT/StringRef.h" -#include "mlir-c/AffineMap.h" -#include "mlir-c/BuiltinAttributes.h" -#include "mlir-c/BuiltinTypes.h" -#include "mlir-c/Diagnostics.h" #include "mlir-c/Dialect/Func.h" #include "mlir-c/IR.h" #include "mlir-c/Support.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" // IWYU pragma: keep -// clang-format off -#include "mlir-c/Bindings/Python/Interop.h" -// clang-format on #include "nanobind/nanobind.h" -#include "nanobind/stl/optional.h" // IWYU pragma: keep -#include "nanobind/stl/pair.h" // IWYU pragma: keep -#include "nanobind/stl/string.h" // IWYU pragma: keep -#include "nanobind/stl/variant.h" // IWYU pragma: keep -#include "nanobind/stl/vector.h" // IWYU pragma: keep -#include "absl/log/check.h" #include "jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h" -#include "xla/python/nb_numpy.h" -#include "xla/tsl/python/lib/core/numpy.h" - -// TODO(tlongeri): Can I add my own return type annotations to functions? -// TODO(tlongeri): I don't understand why MLIR uses the C API to implement -// Python bindings. Do we have a reason to do that? namespace nb = nanobind; -namespace { -constexpr const char LAYOUT_DEFS_MODULE[] = - "jax.jaxlib.mosaic.python.layout_defs"; -constexpr const char IR_MODULE[] = "jaxlib.mlir.ir"; -constexpr MlirTpuI64TargetTuple DEFAULT_TARGET_SHAPE{8, 128}; - -// TODO(tlongeri): Add type annotations via nanobind once there is -// a release for it (and maybe add a custom Sequence one as well). - -// TODO(tlongeri): For our use-case, we don't really need C++ exceptions - just -// setting the exception object and returning NULL to Python should suffice, but -// not sure if this is possible with nanobind. -class NotImplementedException : public std::runtime_error { - using runtime_error::runtime_error; -}; - -} // namespace - -template <> -struct nb::detail::type_caster { - NB_TYPE_CASTER(MlirTpuImplicitDim, const_name("ImplicitDim | None")); - - bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) noexcept { - if (src.is_none()) { - value = MlirTpuImplicitDimNone; - return true; - } - auto implicit_dim_cls = - nb::module_::import_(LAYOUT_DEFS_MODULE).attr("ImplicitDim"); - if (!nb::isinstance(src, implicit_dim_cls)) { - return false; - } - if (src.is(implicit_dim_cls.attr("MINOR"))) { - value = MlirTpuImplicitDimMinor; - } else if (src.is(implicit_dim_cls.attr("SECOND_MINOR"))) { - value = MlirTpuImplicitDimSecondMinor; - } else { - return false; - } - return true; - } - - static handle from_cpp(MlirTpuImplicitDim implicit_dim, rv_policy policy, - cleanup_list* cleanup) noexcept { - auto implicit_dim_cls = - nb::module_::import_(LAYOUT_DEFS_MODULE).attr("ImplicitDim"); - switch (implicit_dim) { - case MlirTpuImplicitDimNone: - return nb::none().release(); - case MlirTpuImplicitDimMinor: - return static_cast(implicit_dim_cls.attr("MINOR")) - .release(); - case MlirTpuImplicitDimSecondMinor: - return static_cast(implicit_dim_cls.attr("SECOND_MINOR")) - .release(); - } - } -}; - -template <> -struct nb::detail::type_caster { - NB_TYPE_CASTER(MlirTpuDirection, const_name("Direction")); - - bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) noexcept { - auto direction_cls = - nb::module_::import_(LAYOUT_DEFS_MODULE).attr("Direction"); - if (!nb::isinstance(src, direction_cls)) { - return false; - } - if (src.is(direction_cls.attr("LANES"))) { - value = MlirTpuDirectionLanes; - } else if (src.is(direction_cls.attr("SUBLANES"))) { - value = MlirTpuDirectionSublanes; - } else if (src.is(direction_cls.attr("SUBELEMENTS"))) { - value = MlirTpuDirectionSubelements; - } else { - return false; - } - return true; - } - - static handle from_cpp(MlirTpuDirection direction, rv_policy /* policy */, - cleanup_list* /* cleanup */) noexcept { - auto direction_cls = - nb::module_::import_(LAYOUT_DEFS_MODULE).attr("ImplicitDim"); - switch (direction) { - case MlirTpuDirectionLanes: - return static_cast(direction_cls.attr("LANES")).release(); - case MlirTpuDirectionSublanes: - return static_cast(direction_cls.attr("SUBLANES")) - .release(); - case MlirTpuDirectionSubelements: - return static_cast(direction_cls.attr("SUBELEMENTS")) - .release(); - default: - PyErr_Format(PyExc_ValueError, "Invalid MlirTpuDirection: %d", - static_cast(direction)); - return nb::handle(); - } - } -}; - -template <> -struct nb::detail::type_caster { - NB_TYPE_CASTER(MlirTpuI64TargetTuple, const_name("TargetTuple")); - - bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) noexcept { - auto target_tuple_cls = - nb::module_::import_(LAYOUT_DEFS_MODULE).attr("TargetTuple"); - if (!nb::isinstance(src, target_tuple_cls)) { - return false; - } - value = {nb::cast(src.attr("sublanes")), - nb::cast(src.attr("lanes"))}; - return true; - } - - static handle from_cpp(MlirTpuI64TargetTuple target_tuple, rv_policy policy, - cleanup_list* cleanup) noexcept { - nb::object target_tuple_cls = - nb::module_::import_(LAYOUT_DEFS_MODULE).attr("TargetTuple"); - return target_tuple_cls(target_tuple.sublane, target_tuple.lane).release(); - } -}; - -namespace { -// Handler for use with MLIR C API print functions. The 2nd parameter is an -// opaque pointer to "user data" that should always be a string. -void printToString(MlirStringRef c_mlir_str, void* opaque_string) { - std::string* str = static_cast(opaque_string); - CHECK(str != nullptr); - str->append(c_mlir_str.data, c_mlir_str.length); -} - -class DiagnosticCapture { - public: - DiagnosticCapture(MlirContext ctx) - : ctx_(ctx), - id_(mlirContextAttachDiagnosticHandler(ctx, handleDiagnostic, this, - nullptr)) {} - - ~DiagnosticCapture() { mlirContextDetachDiagnosticHandler(ctx_, id_); } - - void throwIfError() { - if (error_messages_.size() == 1) { - // Throw NotImplementedException if we got a single diagnostic that - // contains "Not implemented". - llvm::StringRef ref = error_messages_.front(); - constexpr llvm::StringRef not_implemented = "Not implemented"; - if (const size_t pos = ref.find(not_implemented); - pos != llvm::StringRef::npos) { - // We strip "Not implemented" only if it is a prefix. Sometimes it may - // come after another prefix (e.g. op prefix), in which case we leave it - if (pos == 0) { - ref = ref.drop_front(not_implemented.size()); - ref.consume_front(": "); - } - throw NotImplementedException(ref.str()); - } - } - if (!error_messages_.empty()) { - // Note that it is unusual/unexpected to get multiple diagnostics, so we - // just forward all the error messages. - throw std::runtime_error(llvm::join(error_messages_, "\n")); - } - } - - private: - static MlirLogicalResult handleDiagnostic(MlirDiagnostic diag, - void* opaque_detector) { - DiagnosticCapture* detector = - static_cast(opaque_detector); - if (mlirDiagnosticGetSeverity(diag) == MlirDiagnosticError) { - std::string& message = detector->error_messages_.emplace_back(); - mlirDiagnosticPrint(diag, printToString, &message); - } - return mlirLogicalResultFailure(); // Propagate to other handlers - } - llvm::SmallVector error_messages_; - const MlirContext ctx_; - const MlirDiagnosticHandlerID id_; -}; - -} // namespace - -namespace { -nb::object toPyLayoutOffset(int64_t offset) { - CHECK_GE(offset, -1); - if (offset == -1) { - return nb::module_::import_(LAYOUT_DEFS_MODULE).attr("REPLICATED"); - } else { - return nb::int_(offset); - } -} - -// TODO(tlongeri): Would `type_caster`s let me avoid defining all of these -// to/from functions? -int64_t offsetFromPyOffset(nb::object py_offset) { - if (nb::isinstance(py_offset)) { - int64_t offset = nb::cast(py_offset); - if (offset < 0) { - throw nb::value_error("Invalid py layout offset"); - } - return offset; - } else if (py_offset.equal( - nb::module_::import_(LAYOUT_DEFS_MODULE).attr("REPLICATED"))) { - return -1; - } else { - throw nb::type_error("Invalid layout offset type"); - } -} - -template -llvm::SmallVector sequenceToSmallVector(nb::sequence seq) { - llvm::SmallVector out; - out.reserve(nb::len(seq)); - for (nb::handle elem : seq) { - out.push_back(nb::cast(elem)); - } - return out; -} - -nb::tuple toPyTuple(const int64_t* data, size_t count) { - nb::tuple tuple = nb::steal(PyTuple_New(count)); - for (size_t i = 0; i < count; ++i) { - PyTuple_SET_ITEM(tuple.ptr(), i, nb::int_(data[i]).release().ptr()); - } - return tuple; -} - -nb::tuple toPyTuple(MlirTpuI64TargetTuple tuple) { - return nb::make_tuple(tuple.sublane, tuple.lane); -} - -// Unwraps the current default insertion point -// ValueError is raised if default insertion point is not set -MlirTpuInsertionPoint getDefaultInsertionPoint() { - nb::object insertion_point = - nb::module_::import_(IR_MODULE).attr("InsertionPoint").attr("current"); - nb::object ref_operation = insertion_point.attr("ref_operation"); - return {nb::cast(insertion_point.attr("block")), - ref_operation.is_none() - ? MlirOperation{nullptr} - : nb::cast(insertion_point.attr("ref_operation"))}; -} - -// Unwraps the current default location -// ValueError is raised if default location is not set -MlirLocation getDefaultLocation() { - return nb::cast( - nb::module_::import_(IR_MODULE).attr("Location").attr("current")); -} - -// Unwraps the current default MLIR context -// ValueError is raised if default context is not set -MlirContext getDefaultContext() { - return nb::cast( - nb::module_::import_(IR_MODULE).attr("Context").attr("current")); -} - -struct PyTpuVectorLayout { - PyTpuVectorLayout(MlirTpuVectorLayout layout) : layout(layout) {} - ~PyTpuVectorLayout() { mlirTpuVectorLayoutDestroy(layout); } - PyTpuVectorLayout(const PyTpuVectorLayout&) = delete; - PyTpuVectorLayout& operator=(const PyTpuVectorLayout&) = delete; - - MlirTpuVectorLayout layout; -}; - -} // namespace - NB_MODULE(_tpu_ext, m) { - tsl::ImportNumpy(); - mlirRegisterTPUPasses(); // Register all passes on load. mlirTpuRegisterMosaicSerdePass(); - nb::class_(m, "ApplyVectorLayoutCtx") - .def( - "__init__", - [](MlirTpuApplyVectorLayoutContext* self, int hardware_generation, - nb::tuple target_shape, nb::tuple mxu_shape, - int max_sublanes_in_scratch) { - if (target_shape.size() != 2) { - throw nb::value_error("target_shape should be of length 2"); - } - if (mxu_shape.size() != 2) { - throw nb::value_error("mxu_shape should be of length 2"); - } - new (self) MlirTpuApplyVectorLayoutContext{ - .hardware_generation = hardware_generation, - .target_shape = {nb::cast(target_shape[0]), - nb::cast(target_shape[1])}, - .mxu_shape = {nb::cast(mxu_shape[0]), - nb::cast(mxu_shape[1])}, - .max_sublanes_in_scratch = max_sublanes_in_scratch}; - }, - nb::arg("hardware_generation") = -1, - nb::arg("target_shape") = toPyTuple(DEFAULT_TARGET_SHAPE), - nb::arg("mxu_shape") = nb::make_tuple(128, 128), - nb::arg("max_sublanes_in_scratch") = 0); - - nb::class_(m, "VRegDataBounds") - .def("mask_varies_along", - [](MlirTpuVregDataBounds self, MlirTpuDirection direction, - MlirTpuI64TargetTuple target_shape) { - return mlirTpuVregDataBoundsMaskVariesAlong(self, direction, - target_shape); - }) - .def("complete", - [](MlirTpuVregDataBounds self, MlirTpuI64TargetTuple target_shape) { - return mlirTpuVregDataBoundsIsComplete(self, target_shape); - }) - .def("get_vector_mask", - [](MlirTpuVregDataBounds self, int generation, - MlirTpuI64TargetTuple target_shape) { - // TODO: Does this work? Test in Python - MlirValue mask = mlirTpuVregDataBoundsGetVectorMask( - self, getDefaultInsertionPoint(), getDefaultLocation(), - generation, target_shape); - if (mask.ptr == nullptr) { - throw std::runtime_error("getVectorMask failed"); - } - return mask; - }) - .def("get_sublane_mask", - [](MlirTpuVregDataBounds self, MlirTpuI64TargetTuple target_shape) { - return mlirTpuVregDataBoundsGetSublaneMask( - self, getDefaultContext(), target_shape); - }); - - // TODO(tlongeri): More precise argument type annotations. There currently - // seems to be no way to define your own? - nb::class_(m, "VectorLayout") - .def( - "__init__", - [](PyTpuVectorLayout* self, int bitwidth, nb::tuple offsets, - nb::tuple tiling, MlirTpuImplicitDim implicit_dim) { - if (offsets.size() != 2) { - throw nb::value_error("Offsets should be of length 2"); - } - if (tiling.size() != 2) { - throw nb::value_error("Tiling should be of length 2"); - } - MlirTpuVectorLayout layout = mlirTpuVectorLayoutCreate( - bitwidth, - {offsetFromPyOffset(offsets[0]), - offsetFromPyOffset(offsets[1])}, - {nb::cast(tiling[0]), nb::cast(tiling[1])}, - implicit_dim); - new (self) PyTpuVectorLayout(layout); - }, - nb::arg("bitwidth"), nb::arg("offsets"), nb::arg("tiling"), - nb::arg("implicit_dim").none()) - .def_prop_ro( - "bitwidth", - [](const PyTpuVectorLayout& self) { - return mlirTpuVectorLayoutGetBitwidth(self.layout); - }, - "The bitwidth of the stored values.") - .def_prop_ro( - "offsets", - [](const PyTpuVectorLayout& self) { - MlirTpuLayoutOffsets offsets = - mlirTpuVectorLayoutGetOffsets(self.layout); - return nb::make_tuple(toPyLayoutOffset(offsets.sublane), - toPyLayoutOffset(offsets.lane)); - }, - "The coordinates of the first valid element. If an offset is " - "REPLICATED, then any offset is valid as the value does not vary " - "across sublanes or lanes respectively.") - .def_prop_ro( - "tiling", - [](const PyTpuVectorLayout& self) { - return toPyTuple(mlirTpuVectorLayoutGetTiling(self.layout)); - }, - "The tiling used to lay out values (see the XLA docs). For values of " - "bitwidth < 32, an implicit (32 // bitwidth, 1) tiling is appended " - "to the one specified as an attribute.") - .def_prop_ro( - "implicit_dim", - [](const PyTpuVectorLayout& self) { - return mlirTpuVectorLayoutGetImplicitDim(self.layout); - }, - "If specified, the value has an implicit dim inserted in either " - "minormost or second minormost position.") - .def_prop_ro( - "packing", - [](const PyTpuVectorLayout& self) { - return mlirTpuVectorLayoutGetPacking(self.layout); - }, - "Returns the number of values stored in a vreg entry.") - .def_prop_ro( - "layout_rank", - [](const PyTpuVectorLayout& self) { - return mlirTpuVectorLayoutGetLayoutRank(self.layout); - }, - "The number of minormost dimensions tiled by this layout.") - .def( - "has_natural_topology", - [](const PyTpuVectorLayout& self, - MlirTpuI64TargetTuple target_shape) { - return mlirTpuVectorLayoutHasNaturalTopology(self.layout, - target_shape); - }, - nb::arg("target_shape"), - "True, if every vector register has a layout without jumps.\n" - "\n" - "By without jumps we mean that traversing vregs over (sub)lanes " - "always leads to a contiguous traversal of the (second) minormost " - "dimension of data. This is only true for 32-bit types, since " - "narrower types use two level tiling.") - .def( - "has_native_tiling", - [](const PyTpuVectorLayout& self, - MlirTpuI64TargetTuple target_shape) { - return mlirTpuVectorLayoutHasNativeTiling(self.layout, - target_shape); - }, - nb::arg("target_shape"), - "True, if every vector register has a natural \"packed\" topology.\n" - "\n" - "This is equivalent to has_natural_topology for 32-bit types, but " - "generalizes it to narrower values with packed layouts too.") - .def( - "tiles_per_vreg", - [](const PyTpuVectorLayout& self, - MlirTpuI64TargetTuple target_shape) { - return mlirTpuVectorLayoutTilesPerVreg(self.layout, target_shape); - }, - nb::arg("target_shape"), - "How many tiles fit in each vector register.") - .def( - "sublanes_per_tile", - [](const PyTpuVectorLayout& self, - MlirTpuI64TargetTuple target_shape) { - return mlirTpuVectorLayoutSublanesPerTile(self.layout, - target_shape); - }, - nb::arg("target_shape"), - "The number of sublanes necessary to store each tile.") - .def( - "vreg_slice", - [](const PyTpuVectorLayout& self, - MlirTpuI64TargetTuple target_shape) { - MlirTpuI64TargetTuple vreg_slice = - mlirTpuVectorLayoutVregSlice(self.layout, target_shape); - return nb::module_::import_(LAYOUT_DEFS_MODULE) - .attr("TargetTuple")(vreg_slice.sublane, vreg_slice.lane); - }, - nb::arg("target_shape"), - "Returns the size of a window contained in a single vreg.\n" - "\n" - "We never reuse the same vector register to store data of multiple " - "rows, so only the minormost dimension can increase.") - .def( - "implicit_shape", - [](const PyTpuVectorLayout& self, nb::sequence shape) { - llvm::SmallVector implicit_shape_vec = - sequenceToSmallVector(shape); - MlirTpuI64ArrayRef implicit_shape = - mlirTpuVectorLayoutImplicitShape( - self.layout, - {implicit_shape_vec.data(), implicit_shape_vec.size()}); - nb::tuple ret = toPyTuple(implicit_shape.ptr, implicit_shape.size); - free(implicit_shape.ptr); - return ret; - }, - nb::arg("shape")) - .def( - "tile_array_shape", - [](const PyTpuVectorLayout& self, nb::sequence shape, - MlirTpuI64TargetTuple target_shape) { - llvm::SmallVector tile_array_shape_vec = - sequenceToSmallVector(shape); - MlirTpuI64ArrayRef tile_array_shape = - mlirTpuVectorLayoutTileArrayShape( - self.layout, - {tile_array_shape_vec.data(), tile_array_shape_vec.size()}, - target_shape); - nb::tuple ret = - toPyTuple(tile_array_shape.ptr, tile_array_shape.size); - free(tile_array_shape.ptr); - return ret; - }, - nb::arg("shape"), nb::arg("target_shape"), - "Returns the shape of an ndarray of vregs needed to represent a " - "value.\n" - "\n" - "All but the last two dimensions are unrolled over vregs. In the " - "last two dims we need as many vregs as indicated by dividing the " - "point at which the value ends (given by the start offset plus the " - "dim size) divided by the respective vreg capacity in that dim (and " - "a ceiling if non-integral). If a value is replicated, then any " - "offset is valid and we pick 0 to minimize the number of vregs.\n" - "\n" - "Args:\n" - " shape: The shape of the ndarray to tile.") - .def( - "tile_data_bounds", - [](const PyTpuVectorLayout& self, nb::sequence shape, - nb::sequence ixs, MlirTpuI64TargetTuple target_shape, - std::variant allow_replicated) { - llvm::SmallVector shape_vec = - sequenceToSmallVector(shape); - llvm::SmallVector ixs_vec = - sequenceToSmallVector(ixs); - if (shape_vec.size() != ixs_vec.size()) { - throw nb::value_error( - "Expected shape and ixs to have the same size"); - } - return std::visit( - [&](auto ar) { - if constexpr (std::is_same_v) { - return mlirTpuVectorLayoutTileDataBounds( - self.layout, getDefaultContext(), shape_vec.data(), - ixs_vec.data(), shape_vec.size(), target_shape, - {ar, ar}); - } else { - return mlirTpuVectorLayoutTileDataBounds( - self.layout, getDefaultContext(), shape_vec.data(), - ixs_vec.data(), shape_vec.size(), target_shape, - {nb::cast(ar[0]), nb::cast(ar[1])}); - } - }, - allow_replicated); - }, - nb::arg("shape"), nb::arg("ixs"), nb::arg("target_shape"), - nb::arg("allow_replicated") = false, - "Returns the bounds of the given tile that hold useful data.\n" - "\n" - "Arguments:\n" - " full_shape: The shape of the full vector this layout applies to.\n" - " ixs: The indices into an array of tiles representing the full " - "vector (see tile_array_shape for bounds) selecting the tile for " - "which the bounds are queried.\n" - " allow_replicated: If False, no offset is allowed to be " - "REPLICATED. If True, offsets are allowed to be REPLICATED, but the " - "bounds will span the full dimension of the tile (i.e. potentially " - "multiple repeats of the actual data).\n" - " target_shape: The target shape of the TPU.\n" - "\n" - "Returns:\n" - " A TargetTuple of slices, indicating the span of useful data " - "within the tile selected by idx.") - .def( - "generalizes", - [](const PyTpuVectorLayout& self, const PyTpuVectorLayout& other, - std::optional shape, - MlirTpuI64TargetTuple target_shape) { - if (shape) { - llvm::SmallVector shape_vec = - sequenceToSmallVector(*shape); - return mlirTpuVectorLayoutGeneralizes( - self.layout, other.layout, - {shape_vec.data(), shape_vec.size()}, target_shape); - } - return mlirTpuVectorLayoutGeneralizes(self.layout, other.layout, - {nullptr, 0}, target_shape); - }, - nb::arg("other"), nb::kw_only(), - nb::arg("shape").none() = std::nullopt, nb::arg("target_shape"), - "Returns True if the other layout is a special case of this one.\n" - "\n" - "In here, other is considered \"a special case\" when the set of " - "vector register entries that represent a value in that layout is " - "also the set of entries in which self stores the value. This is of " - "course true for layouts that are equivalent, but it does not need " - "to hold both ways. For example, a layout that implies the value " - "does not change along an axis of the vector register is more " - "general than the layout that picks a fixed starting point for the " - "value and does not encode that assumption.\n" - "\n" - "The generalization relation is a non-strict partial order. You can " - "think of it as a partial <= on vector layouts, but we don't " - "overload Python operators since there's no clear way to decide " - "where the bottom and top should be.\n" - "\n" - "Args:\n" - " other: The layout compared against self.\n" - " shape: An optional shape of the vector to which both layouts " - "apply.\n" - " The generalization relation is larger than usual for some " - "shapes. That is, if self.generalizes(other) then also " - "self.generalizes(other, shape) for any shape, but that implication " - "does not hold the other way around for some shapes.\n" - " target_shape: The target shape of the TPU.") - .def( - "equivalent_to", - [](const PyTpuVectorLayout& self, const PyTpuVectorLayout& other, - std::optional shape, - MlirTpuI64TargetTuple target_shape) { - if (shape) { - llvm::SmallVector shape_vec = - sequenceToSmallVector(*shape); - return mlirTpuVectorLayoutEquivalentTo( - self.layout, other.layout, - {shape_vec.data(), shape_vec.size()}, target_shape); - } - return mlirTpuVectorLayoutEquivalentTo(self.layout, other.layout, - {nullptr, 0}, target_shape); - }, - nb::arg("other"), nb::kw_only(), - nb::arg("shape").none() = std::nullopt, nb::arg("target_shape"), - "Returns True if the two layouts are equivalent.\n" - "\n" - "That is, when all potential vector entries where the value can be " - "stored (there might be multiple choices for some layouts!) are " - "equal in both self and other.\n" - "\n" - "Args:\n" - " other: The layout compared against self.\n" - " shape: An optional shape of the vector to which both layouts " - "apply. More layouts are considered equivalent when the shape is " - "specified. Also see the docstring of the generalizes method.\n" - " target_shape: The target shape of the TPU.") - .def("__eq__", - [](const PyTpuVectorLayout& self, const PyTpuVectorLayout& other) { - return mlirTpuVectorLayoutEquals(self.layout, other.layout); - }) - .def("__repr__", [](const PyTpuVectorLayout& self) { - std::string str; - mlirTpuVectorLayoutPrint(self.layout, printToString, &str); - return str; - }); - - // TODO(tlongeri): Can we make the first parameter a VectorType? - m.def("assemble", - [](const MlirType ty, const PyTpuVectorLayout& layout, - nb::object np_arr_obj, - MlirTpuI64TargetTuple target_shape) -> MlirOperation { - // TODO(tlongeri): Remove nb::array::c_style, I only added it because - // I couldn't find a simple way to iterate over array data, but it - // causes yet another unnecessary copy. - xla::nb_numpy_ndarray np_arr = xla::nb_numpy_ndarray::ensure( - np_arr_obj, NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_ALIGNED); - if (!mlirTypeIsAVector(ty)) { - throw nb::type_error("Expected vector type"); - } - llvm::SmallVector vals(np_arr.size()); - for (int64_t i = 0; i < np_arr.size(); ++i) { - vals.data()[i] = nb::cast(nb::handle( - reinterpret_cast(np_arr.data())[i])); - } - llvm::SmallVector shape(np_arr.ndim()); - for (int64_t i = 0; i < np_arr.ndim(); ++i) { - shape.data()[i] = np_arr.shape()[i]; - } - return mlirTpuAssemble( - getDefaultInsertionPoint(), ty, layout.layout, - MlirTpuValueArray{MlirTpuI64ArrayRef{shape.data(), shape.size()}, - vals.data()}, - target_shape); - }); - m.def("disassemble", [](const PyTpuVectorLayout& layout, MlirValue val, - MlirTpuI64TargetTuple target_shape) { - DiagnosticCapture diag_capture(getDefaultContext()); - MlirTpuValueArray val_arr = mlirTpuDisassemble( - getDefaultInsertionPoint(), layout.layout, val, target_shape); - if (val_arr.vals == nullptr) { - diag_capture.throwIfError(); - throw nb::value_error("Failed to disassemble"); - } - xla::nb_numpy_ndarray np_vals( - /*dtype=*/xla::nb_dtype("O"), - /*shape=*/ - absl::Span(val_arr.shape.ptr, val_arr.shape.size), - /*strides=*/std::nullopt); - for (ssize_t i = 0; i < np_vals.size(); ++i) { - reinterpret_cast(np_vals.mutable_data())[i] = - nb::cast(val_arr.vals[i]).release().ptr(); - } - free(val_arr.shape.ptr); - free(val_arr.vals); - return np_vals; - }); - - m.def("apply_layout_op", - [](MlirTpuApplyVectorLayoutContext ctx, const MlirOperation c_op) { - DiagnosticCapture diag_capture(getDefaultContext()); - MlirLogicalResult res = mlirTpuApplyLayoutOp(ctx, c_op); - if (mlirLogicalResultIsFailure(res)) { - diag_capture.throwIfError(); - throw std::runtime_error("applyLayoutOp failed"); - } - }); - m.def("relayout", [](MlirValue v, const PyTpuVectorLayout& src, - const PyTpuVectorLayout& dst, - MlirTpuApplyVectorLayoutContext apply_layout_ctx) { - DiagnosticCapture diag_capture(getDefaultContext()); - MlirValue new_v = mlirTpuRelayout(getDefaultInsertionPoint(), v, src.layout, - dst.layout, apply_layout_ctx); - if (new_v.ptr == nullptr) { - diag_capture.throwIfError(); - throw nb::value_error("Failed to relayout"); - } - return new_v; - }); - nb::register_exception_translator( - [](const std::exception_ptr& p, void*) { - try { - if (p) std::rethrow_exception(p); - } catch (const NotImplementedException& e) { - PyErr_SetString(PyExc_NotImplementedError, e.what()); - } - }, - nullptr); - m.def( "register_dialect", [](MlirContext context, bool load) { @@ -778,26 +39,6 @@ NB_MODULE(_tpu_ext, m) { }, nb::arg("context"), nb::arg("load") = true); - m.def("private_is_tiled_layout", [](MlirAttribute attr) { - return mlirTPUAttributeIsATiledLayoutAttr(attr); - }); - m.def("private_get_tiles", [](MlirAttribute attr) -> nb::object { - MlirAttribute encoded_tiles = mlirTPUTiledLayoutAttrGetTiles(attr); - nb::tuple py_tiles = nb::steal( - PyTuple_New(mlirArrayAttrGetNumElements(encoded_tiles))); - for (intptr_t i = 0; i < mlirArrayAttrGetNumElements(encoded_tiles); ++i) { - MlirAttribute tile = mlirArrayAttrGetElement(encoded_tiles, i); - nb::tuple py_tile = - nb::steal(PyTuple_New(mlirDenseArrayGetNumElements(tile))); - for (intptr_t j = 0; j < mlirDenseArrayGetNumElements(tile); ++j) { - PyTuple_SET_ITEM( - py_tile.ptr(), j, - nb::cast(mlirDenseI64ArrayGetElement(tile, j)).release().ptr()); - } - PyTuple_SET_ITEM(py_tiles.ptr(), i, py_tile.release().ptr()); - } - return py_tiles; - }); m.def("private_has_communication", [](MlirOperation op) { bool has_communication; bool has_custom_barrier; @@ -807,63 +48,22 @@ NB_MODULE(_tpu_ext, m) { }); // TODO(apaszke): All of those should be upstreamed to MLIR Python bindings. - m.def("private_replace_all_uses_with", [](MlirOperation op, - std::vector vals) { - if (vals.size() != mlirOperationGetNumResults(op)) { - throw nb::value_error("length mismatch in replace_all_uses_with"); - } - for (int i = 0; i < vals.size(); ++i) { - mlirValueReplaceAllUsesOfWith(mlirOperationGetResult(op, i), vals[i]); - } - }); - m.def("private_replace_all_uses_except", - [](MlirValue old, MlirValue new_val, MlirOperation except) { - for (intptr_t i = 0; i < mlirOperationGetNumOperands(except); ++i) { - if (mlirValueEqual(mlirOperationGetOperand(except, i), new_val)) { - throw nb::value_error("new val already used in except"); - } - } - mlirValueReplaceAllUsesOfWith(old, new_val); - // Undo the replacement in the except op. - for (intptr_t i = 0; i < mlirOperationGetNumOperands(except); ++i) { - if (mlirValueEqual(mlirOperationGetOperand(except, i), new_val)) { - mlirOperationSetOperand(except, i, old); - } - } - }); - m.def("private_set_operand", - [](MlirOperation op, int idx, MlirValue new_operand) { - mlirOperationSetOperand(op, idx, new_operand); - }); - m.def("private_set_operands", [](MlirOperation op, - std::vector new_operands) { - mlirOperationSetOperands(op, new_operands.size(), new_operands.data()); - }); - m.def("private_has_no_memory_space", [](MlirType ty) { - return mlirAttributeIsNull(mlirMemRefTypeGetMemorySpace(ty)); - }); - m.def("private_is_identity", [](MlirAttribute attr) { - return mlirAffineMapIsIdentity(mlirAffineMapAttrGetValue(attr)); - }); - m.def("private_insert_argument", - [](int index, MlirBlock block, MlirType type) -> MlirValue { - return mlirBlockInsertArgument( - block, index, type, - mlirLocationUnknownGet(mlirTypeGetContext(type))); - }); m.def("private_set_arg_attr", [](MlirOperation op, unsigned i, std::string name, MlirAttribute attr) { mlirFuncSetArgAttr( op, i, mlirStringRefCreateFromCString(name.c_str()), attr); }); - m.def("private_move_all_regions", [](MlirOperation src, MlirOperation dst) { - if (mlirOperationGetNumRegions(src) != mlirOperationGetNumRegions(dst)) { - throw nb::value_error( - "Region counts do not match in src operation and dst operations"); - } - for (intptr_t i = 0; i < mlirOperationGetNumRegions(src); ++i) { - mlirRegionTakeBody(mlirOperationGetRegion(dst, i), - mlirOperationGetRegion(src, i)); - } - }); + + mlir::python::nanobind_adaptors::mlir_type_subclass(m, "Float8EXMYType", + mlirTpuIsAFloat8EXMYType) + .def_classmethod( + "get", + [](nb::object cls, MlirType exmy_type, MlirContext ctx) { + return cls(mlirTpuFloat8EXMYTypeGet(ctx, exmy_type)); + }, + nb::arg("self"), nb::arg("exmy_type") = nullptr, + nb::arg("ctx") = nullptr) + .def_property_readonly("underlying_type", [](MlirType self) { + return mlirTpuFloat8EXMYTypeGetUnderlyingType(self); + }); } diff --git a/jaxlib/mlir/_mlir_libs/traceback_to_location.cc b/jaxlib/mlir/_mlir_libs/traceback_to_location.cc new file mode 100644 index 000000000000..50b96d039ce9 --- /dev/null +++ b/jaxlib/mlir/_mlir_libs/traceback_to_location.cc @@ -0,0 +1,109 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/mlir/_mlir_libs/traceback_to_location.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/types/span.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" // IWYU pragma: keep +#include "mlir/CAPI/IR.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "jaxlib/traceback.h" + +namespace nb = ::nanobind; + +namespace jax { + +TracebackToLocationCache::TracebackToLocationCache( + nanobind::callable code_to_filename, int frame_limit, + mlir::MLIRContext* context) + : code_to_filename_(std::move(code_to_filename)), + frame_limit_(frame_limit), + context_(context) {} + +nb::object TracebackToLocationCache::Get(const Traceback& traceback) { + auto& traceback_cache_entry = traceback_to_location_cache_[traceback]; + if (!traceback_cache_entry.ptr()) { + absl::Span frames = traceback.RawFrames(); + std::vector frame_locs_vector; + frame_locs_vector.reserve(frames.size()); + for (const TracebackEntry& frame : frames) { + auto& frame_cache_entry = frame_cache_[frame]; + if (!frame_cache_entry.has_value()) { + // Canonicalize the filename, and skip it if it's not to be shown. + auto [filename_cache_it, inserted] = + code_to_filename_cache_.insert({frame.code, std::nullopt}); + auto& filename_cache_entry = filename_cache_it->second; + if (inserted) { + nb::object out = code_to_filename_( + nb::borrow(reinterpret_cast(frame.code))); + if (out.is_none()) { + filename_cache_entry = std::nullopt; + } else { + filename_cache_entry = nb::cast(out); + } + } + if (!filename_cache_entry.has_value()) { + continue; + } + const std::string& filename = *filename_cache_entry; + + int start_line, start_column, end_line, end_column; + if (!PyCode_Addr2Location(frame.code, frame.lasti, &start_line, + &start_column, &end_line, &end_column)) { + throw nb::python_error(); + } + std::string_view function_name = nb::cast( + nb::borrow(frame.code->co_qualname)); + frame_cache_entry = mlir::NameLoc::get( + mlir::StringAttr::get(context_, function_name), + mlir::FileLineColRange::get( + mlir::StringAttr::get(context_, filename), start_line, + start_column, end_line, end_column)); + } + frame_locs_vector.push_back(*frame_cache_entry); + } + absl::Span frame_locs_span = frame_locs_vector; + frame_locs_span = frame_locs_span.first( + std::min(frame_locs_span.size(), frame_limit_)); + std::optional loc; + for (auto it = frame_locs_span.rbegin(); it != frame_locs_span.rend(); + ++it) { + if (loc.has_value()) { + loc = mlir::CallSiteLoc::get(*it, *loc); + } else { + loc = *it; + } + } + traceback_cache_entry = nb::cast( + wrap(loc.has_value() ? *loc : mlir::UnknownLoc::get(context_))); + } + return traceback_cache_entry; +} + +} // namespace jax diff --git a/jaxlib/mlir/_mlir_libs/traceback_to_location.h b/jaxlib/mlir/_mlir_libs/traceback_to_location.h new file mode 100644 index 000000000000..e2cd6d0e8b74 --- /dev/null +++ b/jaxlib/mlir/_mlir_libs/traceback_to_location.h @@ -0,0 +1,77 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_MLIR_MLIR_LIBS_TRACEBACK_TO_LOCATION_H_ +#define JAXLIB_MLIR_MLIR_LIBS_TRACEBACK_TO_LOCATION_H_ + +#include +#include +#include + +#include "absl/base/casts.h" +#include "absl/container/flat_hash_map.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "nanobind/nanobind.h" +#include "jaxlib/traceback.h" + +namespace jax { + +class TracebackToLocationCache { + public: + // code_to_filename is a user provided callable that maps a code object to + // its canonicalized filename that should appeared in the MLIR location. + // Returns None if the filename should be omitted in tracebacks. + TracebackToLocationCache(nanobind::callable code_to_filename, int frame_limit, + mlir::MLIRContext* context); + + // Returns an MLIR location for the given traceback. + // If the traceback is empty, returns an unknown location. + nanobind::object Get(const Traceback& traceback); + + private: + nanobind::callable code_to_filename_; + int frame_limit_; + mlir::MLIRContext* context_; + + // Cached results of code_to_filename_. + absl::flat_hash_map> + code_to_filename_cache_; + + // Cached mapping from individual frames to MLIR locations. + absl::flat_hash_map> + frame_cache_; + + // Cached mapping from tracebacks to MLIR locations. + struct TracebackHash { + size_t operator()(const Traceback& traceback) const noexcept { + // We know the hash of a traceback will not throw. + return absl::bit_cast(nanobind::hash(traceback)); + } + }; + struct TracebackEqual { + bool operator()(const Traceback& a, const Traceback& b) const noexcept { + // We know equality of tracebacks will not throw. + return a.equal(b); + } + }; + absl::flat_hash_map + traceback_to_location_cache_; +}; + +} // namespace jax + +#endif // JAXLIB_MLIR_MLIR_LIBS_TRACEBACK_TO_LOCATION_H_ diff --git a/jaxlib/mlir/_mlir_libs/triton_ext.cc b/jaxlib/mlir/_mlir_libs/triton_ext.cc index 2a13c40d963f..edce47321442 100644 --- a/jaxlib/mlir/_mlir_libs/triton_ext.cc +++ b/jaxlib/mlir/_mlir_libs/triton_ext.cc @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#ifndef _WIN32 + +#include #include #include "mlir-c/IR.h" @@ -42,8 +45,8 @@ NB_MODULE(_triton_ext, m) { // Types. // - mlir::python::nanobind_adaptors::mlir_type_subclass(m, "PointerType", - mlirTritonIsAPointer) + mlir::python::nanobind_adaptors::mlir_type_subclass( + m, "PointerType", mlirTritonIsAPointer, mlirTritonPointerTypeGetTypeID) .def_classmethod( "get", [](nb::object cls, MlirType pointee_type, int64_t address_space) { @@ -51,9 +54,10 @@ NB_MODULE(_triton_ext, m) { }, nb::arg("cls"), nb::arg("pointee_type"), nb::arg("address_space"), "Creates a PointerType type.") - .def_property_readonly("pointee_type", [](MlirType self) { - return mlirTritonPointerTypeGetPointeeType(self); - }) + .def_property_readonly("pointee_type", + [](MlirType self) { + return mlirTritonPointerTypeGetPointeeType(self); + }) .def_property_readonly("address_space", [](MlirType self) { return mlirTritonPointerTypeGetAddressSpace(self); }); @@ -73,3 +77,11 @@ NB_MODULE(_triton_ext, m) { return encoding; }); } + +#else // _WIN32 + +#include "nanobind/nanobind.h" + +NB_MODULE(_triton_ext, m) {} + +#endif // _WIN32 diff --git a/jaxlib/mosaic/BUILD b/jaxlib/mosaic/BUILD index 4cc2530dd7ca..074f723e8d06 100644 --- a/jaxlib/mosaic/BUILD +++ b/jaxlib/mosaic/BUILD @@ -1,5 +1,4 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") -load("@rules_python//python:defs.bzl", "py_library") # Copyright 2023 The JAX Authors. # @@ -14,14 +13,16 @@ load("@rules_python//python:defs.bzl", "py_library") # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -load("//jaxlib:jax.bzl", "mosaic_extension_deps") +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") +load("@rules_python//python:defs.bzl", "py_library") licenses(["notice"]) package( default_applicable_licenses = [], default_visibility = [ - "//jax:mosaic_users", + "//jax/experimental:mosaic_users", ], ) @@ -45,35 +46,82 @@ cc_library( "dialect/tpu/tpu_ops.cc", "dialect/tpu/util.cc", "dialect/tpu/vreg_util.cc", - ":extension_srcs", - ] + glob([ - "dialect/tpu/transforms/*.cc", - ]), + ], hdrs = [ "dialect/tpu/array_util.h", "dialect/tpu/layout.h", "dialect/tpu/tpu_dialect.h", "dialect/tpu/util.h", "dialect/tpu/vreg_util.h", - ] + glob([ - "dialect/tpu/transforms/*.h", - ]), + ], # compatible with libtpu deps = [ ":tpu_inc_gen", - "//jaxlib:pass_boilerplate", - "//jaxlib/mosaic:serde", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/hash", - "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:CommonFolders", + "@llvm-project//mlir:DialectUtils", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:VectorDialect", + "@xla//xla:array", + "@xla//xla:shape_util", + "@xla//xla/tsl/platform:statusor", + ], +) + +cc_library( + name = "tpu_serde_pass", + srcs = ["dialect/tpu/transforms/serde.cc"], + hdrs = ["dialect/tpu/transforms/serde.h"], + # compatible with libtpu + deps = [ + ":pass_boilerplate", + ":serde", + ":tpu_dialect", + ":tpu_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:DataLayoutInterfaces", + "@llvm-project//mlir:Dialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:VectorDialect", + ], +) + +cc_library( + name = "tpu_linalg_vectorization_pass", + srcs = ["dialect/tpu/transforms/linalg_vectorization.cc"], + hdrs = ["dialect/tpu/transforms/linalg_vectorization.h"], + # compatible with libtpu + deps = [ + ":pass_boilerplate", + ":serde", + ":tpu_dialect", + ":tpu_inc_gen", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:CommonFolders", "@llvm-project//mlir:ControlFlowDialect", "@llvm-project//mlir:DataLayoutInterfaces", "@llvm-project//mlir:Dialect", @@ -90,83 +138,73 @@ cc_library( "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:VectorDialect", "@llvm-project//mlir:VectorTransforms", - "@tsl//tsl/platform:statusor", - "@xla//xla:array", - "@xla//xla:shape_util", - "@xla//xla:util", - "@xla//xla/tsl/platform:errors", - ] + mosaic_extension_deps, + ], ) gentbl_cc_library( name = "tpu_inc_gen", # compatible with libtpu - tbl_outs = [ - ( - ["-gen-op-decls"], - "dialect/tpu/tpu_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "dialect/tpu/tpu_ops.cc.inc", - ), - ( - ["-gen-dialect-decls"], - "dialect/tpu/tpu_dialect.h.inc", - ), - ( - ["-gen-dialect-defs"], - "dialect/tpu/tpu_dialect.cc.inc", - ), - ( - ["-gen-enum-decls"], - "dialect/tpu/tpu_enums.h.inc", - ), - ( - ["-gen-enum-defs"], - "dialect/tpu/tpu_enums.cc.inc", - ), - ( - ["-gen-attrdef-decls"], - "dialect/tpu/tpu_attr_defs.h.inc", - ), - ( - ["-gen-attrdef-defs"], - "dialect/tpu/tpu_attr_defs.cc.inc", - ), - ( - ["-gen-typedef-decls"], - "dialect/tpu/tpu_type_defs.h.inc", - ), - ( - ["-gen-typedef-defs"], - "dialect/tpu/tpu_type_defs.cc.inc", - ), - ( - [ - "-gen-pass-decls", - "-name=TPU", - ], - "dialect/tpu/tpu_passes.h.inc", - ), - ( - [ - "-gen-pass-capi-header", - "--prefix=TPU", - ], - "dialect/tpu/integrations/c/tpu_passes.capi.h.inc", - ), - ( - [ - "-gen-pass-capi-impl", - "--prefix=TPU", - ], - "dialect/tpu/integrations/c/tpu_passes.capi.cc.inc", - ), - ], + tbl_outs = { + "dialect/tpu/tpu_ops.h.inc": [ + "-gen-op-decls", + "-dialect=tpu", + ], + "dialect/tpu/tpu_ops.cc.inc": [ + "-gen-op-defs", + "-dialect=tpu", + ], + "dialect/tpu/tpu_dialect.h.inc": [ + "-gen-dialect-decls", + "-dialect=tpu", + ], + "dialect/tpu/tpu_dialect.cc.inc": [ + "-gen-dialect-defs", + "-dialect=tpu", + ], + "dialect/tpu/tpu_enums.h.inc": [ + "-gen-enum-decls", + "-dialect=tpu", + ], + "dialect/tpu/tpu_enums.cc.inc": [ + "-gen-enum-defs", + "-dialect=tpu", + ], + "dialect/tpu/tpu_attr_defs.h.inc": [ + "-gen-attrdef-decls", + "-dialect=tpu", + "--attrdefs-dialect=tpu", + ], + "dialect/tpu/tpu_attr_defs.cc.inc": [ + "-gen-attrdef-defs", + "-dialect=tpu", + "--attrdefs-dialect=tpu", + ], + "dialect/tpu/tpu_type_defs.h.inc": [ + "-gen-typedef-decls", + "-dialect=tpu", + "--typedefs-dialect=tpu", + ], + "dialect/tpu/tpu_type_defs.cc.inc": [ + "-gen-typedef-defs", + "-dialect=tpu", + "--typedefs-dialect=tpu", + ], + "dialect/tpu/tpu_passes.h.inc": [ + "-gen-pass-decls", + "-name=TPU", + ], + "dialect/tpu/integrations/c/tpu_passes.capi.h.inc": [ + "-gen-pass-capi-header", + "--prefix=TPU", + ], + "dialect/tpu/integrations/c/tpu_passes.capi.cc.inc": [ + "-gen-pass-capi-impl", + "--prefix=TPU", + ], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "dialect/tpu/tpu.td", - deps = [":tpu_td_files"], + td_file = "dialect/tpu/tpu_ops.td", + deps = [":tpu_ops_td_files"], ) td_library( @@ -177,6 +215,18 @@ td_library( # compatible with libtpu deps = [ "@llvm-project//mlir:BuiltinDialectTdFiles", + ], +) + +td_library( + name = "tpu_ops_td_files", + srcs = [ + "dialect/tpu/tpu_ops.td", + ], + # compatible with libtpu + deps = [ + ":tpu_td_files", + "@llvm-project//mlir:BuiltinDialectTdFiles", "@llvm-project//mlir:ControlFlowInterfacesTdFiles", "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", "@llvm-project//mlir:OpBaseTdFiles", @@ -204,14 +254,12 @@ cc_library( deps = [ ":tpu_dialect", ":tpu_inc_gen", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", + ":tpu_serde_pass", "@llvm-project//llvm:Support", "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", - "@xla//xla:array", ], ) @@ -233,14 +281,12 @@ cc_library( deps = [ ":tpu_dialect", ":tpu_inc_gen", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", + ":tpu_serde_pass", "@llvm-project//llvm:Support", - "@llvm-project//mlir:CAPIIRObjects", + "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", - "@xla//xla:array", ], alwayslink = True, ) @@ -270,13 +316,31 @@ cc_test( ], ) -filegroup( - name = "extension_srcs", - srcs = [ - "dialect/tpu/transforms/extensions/apply_vector_layout_extensions.cc", - "dialect/tpu/transforms/extensions/infer_vector_layout_extensions.cc", +cc_test( + name = "tpu_ops_verification_test", + srcs = ["dialect/tpu/tpu_ops_verification_test.cc"], + deps = [ + ":tpu_dialect", + "//testing/base/public:gunit_main", + "@com_google_absl//absl/status", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:Support", + "@xla//xla/mlir/utils:error_util", ], +) + +cc_library( + name = "pass_boilerplate", + hdrs = ["pass_boilerplate.h"], # compatible with libtpu + deps = [ + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + ], ) cc_library( diff --git a/jaxlib/mosaic/dialect/gpu/BUILD b/jaxlib/mosaic/dialect/gpu/BUILD index e21c8756a4e2..6f47649ee6c7 100644 --- a/jaxlib/mosaic/dialect/gpu/BUILD +++ b/jaxlib/mosaic/dialect/gpu/BUILD @@ -18,10 +18,12 @@ load( "gentbl_filegroup", "td_library", ) +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") package( default_applicable_licenses = [], - default_visibility = ["//jax:mosaic_gpu_users"], + default_visibility = ["//jax/experimental:mosaic_gpu_users"], ) td_library( @@ -31,6 +33,7 @@ td_library( deps = [ "@llvm-project//mlir:BasicPtxBuilderIntTdFiles", "@llvm-project//mlir:BuiltinDialectTdFiles", + "@llvm-project//mlir:ControlFlowInterfacesTdFiles", "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", "@llvm-project//mlir:LLVMOpsTdFiles", "@llvm-project//mlir:OpBaseTdFiles", @@ -39,66 +42,36 @@ td_library( gentbl_cc_library( name = "mosaic_gpu_inc_gen", - tbl_outs = [ - ( - [ - "-gen-dialect-decls", - "-dialect=mosaic_gpu", - ], - "mosaic_gpu_dialect.h.inc", - ), - ( - [ - "-gen-dialect-defs", - "-dialect=mosaic_gpu", - ], - "mosaic_gpu_dialect.cc.inc", - ), - ( - ["-gen-op-decls"], - "mosaic_gpu_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "mosaic_gpu_ops.cc.inc", - ), - ( - [ - "-gen-typedef-decls", - "--typedefs-dialect=mosaic_gpu", - ], - "mosaic_gpu_types.h.inc", - ), - ( - [ - "-gen-typedef-defs", - "--typedefs-dialect=mosaic_gpu", - ], - "mosaic_gpu_types.cc.inc", - ), - ( - ["-gen-enum-decls"], - "mosaic_gpu_enums.h.inc", - ), - ( - ["-gen-enum-defs"], - "mosaic_gpu_enums.cc.inc", - ), - ( - [ - "-gen-attrdef-decls", - "--attrdefs-dialect=mosaic_gpu", - ], - "mosaic_gpu_attrdefs.h.inc", - ), - ( - [ - "-gen-attrdef-defs", - "--attrdefs-dialect=mosaic_gpu", - ], - "mosaic_gpu_attrdefs.cc.inc", - ), - ], + tbl_outs = { + "mosaic_gpu_dialect.h.inc": [ + "-gen-dialect-decls", + "-dialect=mosaic_gpu", + ], + "mosaic_gpu_dialect.cc.inc": [ + "-gen-dialect-defs", + "-dialect=mosaic_gpu", + ], + "mosaic_gpu_ops.h.inc": ["-gen-op-decls"], + "mosaic_gpu_ops.cc.inc": ["-gen-op-defs"], + "mosaic_gpu_types.h.inc": [ + "-gen-typedef-decls", + "--typedefs-dialect=mosaic_gpu", + ], + "mosaic_gpu_types.cc.inc": [ + "-gen-typedef-defs", + "--typedefs-dialect=mosaic_gpu", + ], + "mosaic_gpu_enums.h.inc": ["-gen-enum-decls"], + "mosaic_gpu_enums.cc.inc": ["-gen-enum-defs"], + "mosaic_gpu_attrdefs.h.inc": [ + "-gen-attrdef-decls", + "--attrdefs-dialect=mosaic_gpu", + ], + "mosaic_gpu_attrdefs.cc.inc": [ + "-gen-attrdef-defs", + "--attrdefs-dialect=mosaic_gpu", + ], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "mosaic_gpu.td", deps = [ @@ -118,7 +91,9 @@ cc_library( "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ControlFlowInterfaces", "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", "@llvm-project//mlir:LLVMCommonConversion", @@ -127,7 +102,7 @@ cc_library( "@llvm-project//mlir:MemRefUtils", "@llvm-project//mlir:SCFUtils", "@llvm-project//mlir:Support", - "@tsl//tsl/platform:statusor", + "@xla//xla/tsl/platform:statusor", ], ) @@ -140,7 +115,6 @@ cc_test( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", "@llvm-project//llvm:Support", "@llvm-project//mlir:BufferizationInterfaces", "@llvm-project//mlir:DataLayoutInterfaces", @@ -151,7 +125,7 @@ cc_test( "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:SCFUtils", "@llvm-project//mlir:Support", - "@tsl//tsl/platform:errors", + "@xla//xla/tsl/platform:errors", ], ) @@ -216,7 +190,9 @@ cc_library( ":mosaic_gpu_inc_gen", "@llvm-project//llvm:Support", "@llvm-project//mlir:CAPIIR", + "@llvm-project//mlir:FuncExtensions", "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMIRTransforms", "@llvm-project//mlir:Support", ], ) @@ -241,7 +217,9 @@ cc_library( ":mosaic_gpu_inc_gen", "@llvm-project//llvm:Support", "@llvm-project//mlir:CAPIIRObjects", + "@llvm-project//mlir:FuncExtensions", "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMIRTransforms", "@llvm-project//mlir:Support", ], alwayslink = True, diff --git a/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.cc b/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.cc index eac1d104f07f..523b14e425c9 100644 --- a/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.cc +++ b/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.cc @@ -1,7 +1,21 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include "jaxlib/mosaic/dialect/gpu/integrations/c/attributes.h" #include -#include #include "mlir-c/IR.h" #include "mlir/CAPI/IR.h" @@ -82,36 +96,3 @@ int32_t mlirMosaicGpuSwizzleTransformAttrGetSwizzle(MlirAttribute attr) { .getSwizzle() .getValue()); } - -//===----------------------------------------------------------------------===// -// LayoutAttr -//===----------------------------------------------------------------------===// - -bool mlirMosaicGpuIsALayoutAttr(MlirAttribute attr) { - return mlir::isa(unwrap(attr)); -} - -MlirAttribute mlirMosaicGpuLayoutAttrGet(MlirContext ctx, - int32_t num_dimensions, - MlirAttribute* transforms, - int32_t transforms_size) { - std::vector unwrapped_transforms; - unwrapped_transforms.reserve(transforms_size); - for (int i = 0; i < transforms_size; ++i) { - unwrapped_transforms.push_back(unwrap(transforms[i])); - } - return wrap(mosaic_gpu::LayoutAttr::get(unwrap(ctx), num_dimensions, - unwrapped_transforms)); -} - -int32_t mlirMosaicGpuLayoutAttrGetTransformsSize(MlirAttribute attr) { - return mlir::cast(unwrap(attr)) - .getTransforms() - .size(); -} - -MlirAttribute mlirMosaicGpuLayoutAttrGetTransform(MlirAttribute attr, - int32_t index) { - return wrap( - mlir::cast(unwrap(attr)).getTransforms()[index]); -} \ No newline at end of file diff --git a/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.h b/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.h index 3b8425b6b142..3221b9220e5d 100644 --- a/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.h +++ b/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.h @@ -69,22 +69,6 @@ mlirMosaicGpuSwizzleTransformAttrGet(MlirContext ctx, int32_t swizzle); MLIR_CAPI_EXPORTED int32_t mlirMosaicGpuSwizzleTransformAttrGetSwizzle(MlirAttribute attr); -//===----------------------------------------------------------------------===// -// LayoutAttr -//===----------------------------------------------------------------------===// - -MLIR_CAPI_EXPORTED bool mlirMosaicGpuIsALayoutAttr(MlirAttribute attr); - -MLIR_CAPI_EXPORTED MlirAttribute -mlirMosaicGpuLayoutAttrGet(MlirContext ctx, int32_t num_dimensions, - MlirAttribute* transforms, int32_t transforms_size); - -MLIR_CAPI_EXPORTED int32_t -mlirMosaicGpuLayoutAttrGetTransformsSize(MlirAttribute attr); - -MLIR_CAPI_EXPORTED MlirAttribute -mlirMosaicGpuLayoutAttrGetTransform(MlirAttribute attr, int32_t index); - #ifdef __cplusplus } #endif diff --git a/jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.cc b/jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.cc index 1a854f395044..86d2cc270513 100644 --- a/jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.cc +++ b/jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.cc @@ -16,10 +16,19 @@ limitations under the License. #include "jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h" #include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/Func/Extensions/InlinerExtension.h" +#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" #include "jaxlib/mosaic/dialect/gpu/mosaic_gpu.h" extern "C" { MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(MosaicGPU, mosaic_gpu, mosaic_gpu::MosaicGPUDialect); + +void mlirDialectRegistryInsertMosaicGpuInlinerExtensions( + MlirDialectRegistry registry) { + mlir::LLVM::registerInlinerInterface(*unwrap(registry)); + mlir::func::registerInlinerExtension(*unwrap(registry)); } + +} // extern "C" diff --git a/jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h b/jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h index bb6cf6e3af4a..9ecc44c0978f 100644 --- a/jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h +++ b/jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "mlir/CAPI/Registration.h" +#include "mlir-c/IR.h" #ifdef __cplusplus extern "C" { @@ -26,6 +26,9 @@ extern "C" { MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(MosaicGPU, mosaic_gpu); +MLIR_CAPI_EXPORTED void +mlirDialectRegistryInsertMosaicGpuInlinerExtensions(MlirDialectRegistry registry); + #ifdef __cplusplus } #endif diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc index a1e7b571d20e..f1133679520d 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc @@ -16,9 +16,16 @@ limitations under the License. #include "jaxlib/mosaic/dialect/gpu/mosaic_gpu.h" #include +#include +#include #include +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep #include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" @@ -26,32 +33,31 @@ limitations under the License. #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" #include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" // IWYU pragma: keep #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/Region.h" #include "mlir/IR/TypeRange.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" #include "mlir/Support/LLVM.h" -#include "absl/algorithm/container.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/Diagnostics.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/platform/statusor.h" // Generated definitions. #include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_dialect.cc.inc" @@ -80,13 +86,13 @@ using Integer = ::mlir::TypedValue<::mlir::IntegerType>; Integer ToI64(ImplicitLocOpBuilder& b, Index index) { return llvm::cast( - b.create(b.getI64Type(), index).getResult()); + mlir::arith::IndexCastOp::create(b, b.getI64Type(), index).getResult()); } template Value Constant(ImplicitLocOpBuilder& b, T scalar, IntegerType type) { - return b.create( - type, mlir::IntegerAttr::get(type, scalar)); + return mlir::arith::ConstantOp::create(b, type, + mlir::IntegerAttr::get(type, scalar)); } template @@ -109,8 +115,9 @@ absl::StatusOr ToLLVMArray(ImplicitLocOpBuilder& b, MLIRContext* ctx = b.getContext(); mlir::LLVM::LLVMPointerType pointer_type = mlir::LLVM::LLVMPointerType::get(ctx); - Pointer array_pointer = b.create( - pointer_type, element_type, Constant(b, values.size(), b.getI64Type())); + Pointer array_pointer = + mlir::LLVM::AllocaOp::create(b, pointer_type, element_type, + Constant(b, values.size(), b.getI64Type())); for (auto [i, value] : llvm::enumerate(values)) { if (value.getType() != element_type) { @@ -120,11 +127,11 @@ absl::StatusOr ToLLVMArray(ImplicitLocOpBuilder& b, } auto element_pointer = llvm::cast( - b.create( - pointer_type, element_type, array_pointer, - mlir::ArrayRef(mlir::LLVM::GEPArg(i))) + mlir::LLVM::GEPOp::create( + b, pointer_type, element_type, array_pointer, + mlir::ArrayRef(mlir::LLVM::GEPArg(i))) .getResult()); - b.create(value, element_pointer); + mlir::LLVM::StoreOp::create(b, value, element_pointer); } return array_pointer; @@ -133,21 +140,21 @@ absl::StatusOr ToLLVMArray(ImplicitLocOpBuilder& b, // Extracts a pointer to the start of the parameter memref. Pointer FromMemref(ImplicitLocOpBuilder& b, Memref memref) { Index aligned_pointer_as_index = - b.create(memref); + mlir::memref::ExtractAlignedPointerAsIndexOp::create(b, memref); mlir::LLVM::LLVMPointerType pointer_type = mlir::LLVM::LLVMPointerType::get(b.getContext()); - Value alloc_pointer = b.create( - pointer_type, ToI64(b, aligned_pointer_as_index)); + Value alloc_pointer = mlir::LLVM::IntToPtrOp::create( + b, pointer_type, ToI64(b, aligned_pointer_as_index)); Type tensor_element_type = memref.getType().getElementType(); return mlir::cast( - b.create( - pointer_type, tensor_element_type, alloc_pointer, - mlir::ArrayRef( - mlir::LLVM::GEPArg(ToI64(b, aligned_pointer_as_index)))) + mlir::LLVM::GEPOp::create( + b, pointer_type, tensor_element_type, alloc_pointer, + mlir::ArrayRef( + mlir::LLVM::GEPArg(ToI64(b, aligned_pointer_as_index)))) .getResult()); } @@ -162,7 +169,7 @@ absl::Status InitTmaDescriptor(mlir::OpBuilder& builder, mlir::NameLoc::get(builder.getStringAttr("InitTmaDescriptor")), builder); mlir::memref::ExtractStridedMetadataOp extract_strided_metadata_op = - b.create(gmem_ref); + mlir::memref::ExtractStridedMetadataOp::create(b, gmem_ref); Type tensor_element_type = gmem_ref.getType().getElementType(); @@ -205,8 +212,8 @@ absl::Status InitTmaDescriptor(mlir::OpBuilder& builder, } // TODO(bchetioui): connect this to runtime. - b.create( - kRuntimeTmaDescriptorInitializerName, TypeRange{}, + mlir::func::CallOp::create( + b, kRuntimeTmaDescriptorInitializerName, TypeRange{}, ValueRange{/*tma_desc=*/host_pointer_to_descriptor, /*base_addr=*/tensor_base_pointer, /*elem_bytewidth=*/Constant(b, elem_bitwidth / 8, i64), @@ -225,32 +232,23 @@ void DeclareRuntimeFunctions(mlir::OpBuilder& builder) { mlir::LLVM::LLVMPointerType ptr = mlir::LLVM::LLVMPointerType::get(ctx); IntegerType i64 = builder.getI64Type(); - builder - .create( - builder.getUnknownLoc(), kRuntimeTmaDescriptorInitializerName, - builder.getFunctionType( - TypeRange{ptr, ptr, i64, i64, ptr, ptr, i64, ptr}, TypeRange{})) + mlir::func::FuncOp::create( + builder, builder.getUnknownLoc(), kRuntimeTmaDescriptorInitializerName, + builder.getFunctionType(TypeRange{ptr, ptr, i64, i64, ptr, ptr, i64, ptr}, + TypeRange{})) .setVisibility(mlir::func::FuncOp::Visibility::Private); } -bool IsContiguous(mlir::MemRefType type) { - return type.getLayout().isIdentity() || - (type.hasStaticShape() && type.getNumElements() > 0 && - mlir::memref::isStaticShapeAndContiguousRowMajor(type)); -} - namespace { llvm::LogicalResult VerifyCommonLoadStoreOp( - mlir::Location loc, mlir::MemRefType gmem_type, absl::string_view gmem_name, - mlir::MemRefType smem_type, absl::string_view smem_name, - mlir::ArrayRef slice_lengths, int num_indices) { - auto error = [loc](auto... params) { - return emitError(loc, llvm::formatv(params...)); + mlir::Operation* op, mlir::MemRefType gmem_type, std::string_view gmem_name, + mlir::MemRefType smem_type, std::string_view smem_name, + mlir::ArrayRef slice_lengths, + mlir::Operation::operand_range indices) { + auto error = [op](auto... params) { + return op->emitError(llvm::formatv(params...)); }; - if (!IsContiguous(smem_type)) { - return error("The `{0}` memref must be contiguous.", smem_name); - } if (gmem_type.getElementType() != smem_type.getElementType()) { return error( "The `source` and `destination` memrefs must have the same element " @@ -268,7 +266,7 @@ llvm::LogicalResult VerifyCommonLoadStoreOp( "by -1 values in `slice_lengths`.", gmem_name, smem_name); } - if (num_indices != gmem_type.getRank()) { + if (indices.size() != gmem_type.getRank()) { return error("The size of `indices` must be equal to the rank of `{0}`.", gmem_name); } @@ -277,14 +275,33 @@ llvm::LogicalResult VerifyCommonLoadStoreOp( "The size of `slice_lengths` must be equal to the rank of `{0}`.", gmem_name); } + int first_vector_index_dim = -1; // -1 means no vector index. + for (int i = 0; i < indices.size(); ++i) { + if (auto vec_type = + mlir::dyn_cast(indices[i].getType())) { + if (first_vector_index_dim >= 0) { + return error( + "Only one index may be a vector but got multiple vector indices " + "for dimensions {0} and {1}.", first_vector_index_dim, i); + } + first_vector_index_dim = i; + if (vec_type.getShape()[0] != slice_lengths[i]) { + return error( + "The size of the vector index must be equal to the slice length " + "but got {0} != {1}.", + vec_type.getShape()[0], slice_lengths[i]); + } + } + } return llvm::success(); } } // namespace llvm::LogicalResult AsyncLoadOp::verify() { - auto r = VerifyCommonLoadStoreOp(getLoc(), getSource().getType(), "source", - getDestination().getType(), "destination", - getSliceLengths(), getIndices().size()); + auto r = + VerifyCommonLoadStoreOp(getOperation(), getSource().getType(), "source", + getDestination().getType(), "destination", + getSliceLengths(), getIndices()); if (failed(r)) { return r; } @@ -301,82 +318,485 @@ llvm::LogicalResult AsyncLoadOp::verify() { return llvm::success(); } +llvm::LogicalResult AsyncPrefetchOp::verify() { + if (absl::c_any_of(getSliceLengths(), [](int64_t s) { return s < -1; })) { + return emitOpError( + "The `slice_lengths` attribute must not contain values less than -1."); + } + if (getIndices().size() != getSource().getType().getRank()) { + return emitOpError( + "The size of `indices` must be equal to the rank of `source`."); + } + + for (int i = 0; i < getCollective().size(); ++i) { + for (int k = i + 1; k < getCollective().size(); ++k) + if (getCollective()[i] == getCollective()[k]) { + return emitError( + "The `collective` attribute must not contain duplicate " + "dimensions."); + } + } + + return llvm::success(); +} + llvm::LogicalResult AsyncStoreOp::verify() { - return VerifyCommonLoadStoreOp(getLoc(), getDestination().getType(), + return VerifyCommonLoadStoreOp(getOperation(), getDestination().getType(), "destination", getSource().getType(), "source", - getSliceLengths(), getIndices().size()); + getSliceLengths(), getIndices()); } -namespace { -// This is the size of the M dimension in all wgmma instructions. It is fixed, -// unlike the K and N dimensions. -constexpr int kWgmmaSizeM = 64; -} // namespace +llvm::LogicalResult WGMMAOp::inferReturnTypes( + mlir::MLIRContext*, std::optional location, + mlir::ValueRange operands, mlir::DictionaryAttr attributes, + mlir::OpaqueProperties properties, mlir::RegionRange regions, + llvm::SmallVectorImpl& inferredReturnTypes) { + if (operands.empty()) { + return mlir::emitOptionalError(location, "expected non-empty operands"); + } + inferredReturnTypes.assign({operands[0].getType()}); + return mlir::success(); +} llvm::LogicalResult WGMMAOp::verify() { auto error = [this](auto... params) { - return emitOpError(llvm::formatv(params...)); + return getOperation()->emitOpError(llvm::formatv(params...)); }; - auto a_shaped_type = mlir::cast(getA().getType()); - mlir::Type element_type = a_shaped_type.getElementType(); - if (element_type != getB().getType().getElementType()) { + auto a_type = mlir::cast(getA().getType()); + auto b_type = getB().getType(); + auto acc_type = getAccumulator().getType(); + + if (a_type.getElementType() != b_type.getElementType()) { return error("The `a` and `b` inputs must have the same element type."); } - auto a_shape = a_shaped_type.getShape(); - if (a_shape.size() != 2) { - return error("The `a` input must have rank 2."); - } + auto a_shape = a_type.getShape(); + auto b_shape = b_type.getShape(); + auto acc_shape = acc_type.getShape(); - auto b_shape = getB().getType().getShape(); - if (b_shape.size() != 2) { - return error("The `b` input must have rank 2."); + int M = acc_shape[0]; + if (M != a_shape[0]) { + return error( + "The accumulator's first dimension {0} must be equal to the first " + "dimensions of `a`: {1}.", + M, a_shape[0]); } - - auto accShape = getAccumulator().getType().getShape(); - if (accShape.size() != 2) { - return error("The accumulator must have rank 2."); + int K = a_shape[1]; // groups_k * k + if (K != b_shape[0]) { + return error( + "`a`'s contracting dimension {0} must be equal to the first dimension " + "of `b`: {1}.", + K, b_shape[0]); + } + int N = b_shape[1]; // groups_n * k + if (N != acc_shape[1]) { + return error( + "`b`'s non-contracting dimension {0} must be equal to the " + "accumulator's second dimension {1}.", + N, acc_shape[1]); } - if (accShape[0] % kWgmmaSizeM) { + // This is the size of the M dimension in all wgmma instructions. It is fixed, + // unlike the K and N dimensions. + constexpr int kWgmmaSizeM = 64; + if (M % kWgmmaSizeM != 0) { return error( "The accumulator's first dimension must be a multiple of {0}, but got " "{1}.", - kWgmmaSizeM, accShape[0]); + kWgmmaSizeM, M); } - int M = accShape[0]; // groups_m * 64 - if (M != a_shape[0] && M != a_shape[1]) { + return llvm::success(); +} + +llvm::LogicalResult TcGen05MMAOp::verify() { + auto error = [this](auto... params) { + return getOperation()->emitOpError(llvm::formatv(params...)); + }; + + auto a_type = getA().getType(); + auto b_type = getB().getType(); + auto acc_type = getAccumulator().getType(); + + if (a_type.getElementType() != b_type.getElementType()) { + return error("The `a` and `b` inputs must have the same element type."); + } + + auto a_shape = a_type.getShape(); + auto b_shape = b_type.getShape(); + auto acc_shape = acc_type.getShape(); + + int M = acc_shape[0]; + if (M != a_shape[0]) { return error( - "The accumulator's first dimension {0} must be equal to one " - "of the dimensions of `a` - ({1}, {2}).", - M, a_shape[0], a_shape[1]); + "The accumulator's first dimension {0} must be equal to the first " + "dimensions of `a`: {1}.", + M, a_shape[0]); } - int K = (a_shape[0] == M ? a_shape[1] : a_shape[0]); // groups_k * k - if (K != b_shape[0] && K != b_shape[1]) { + int K = a_shape[1]; // groups_k * k + if (K != b_shape[0]) { return error( - "`a`'s contracting dimension {0} must be equal to one " - "of the dimensions of `b` - ({1}, {2}).", - K, b_shape[0], b_shape[1]); + "`a`'s contracting dimension {0} must be equal to the first dimension " + "of `b`: {1}.", + K, b_shape[0]); } - int N = (b_shape[0] == K ? b_shape[1] : b_shape[0]); // groups_n * k - if (N != accShape[1]) { + int N = b_shape[1]; // groups_n * k + if (N != acc_shape[1] && !getCollective()) { return error( "`b`'s non-contracting dimension {0} must be equal to the " "accumulator's second dimension {1}.", - N, accShape[1]); + N, acc_shape[1]); + } + if (N * 2 != acc_shape[1] && getCollective()) { + return error( + "`b`'s non-contracting dimension {0} must be half the accumulator's " + "second dimension {1} for collective MMA.", + N, acc_shape[1]); + } + + // This is the size of the M dimension in all `tcgen05.mma` instructions. It + // is fixed, unlike the K and N dimensions. + constexpr int kTcGen05MmaMinSizeM = 32; + if (M % kTcGen05MmaMinSizeM != 0) { + return error( + "The accumulator's first dimension must be a multiple of {0} but got " + "{1}.", + kTcGen05MmaMinSizeM, M); + } + + mlir::Attribute tmem = TmemAttr::get(getContext()); + mlir::Attribute smem = mlir::gpu::AddressSpaceAttr::get( + getContext(), mlir::gpu::AddressSpace::Workgroup); + + mlir::Attribute acc_mem_space = getAccumulator().getType().getMemorySpace(); + if (acc_mem_space != tmem) { + return error("The accumulator must be in TMEM, but got {0}.", + acc_mem_space); + } + mlir::Attribute a_mem_space = getA().getType().getMemorySpace(); + if (a_mem_space != tmem && a_mem_space != smem) { + return error("The `a` input must be in TMEM or SMEM, but got {0}.", + a_mem_space); + } + mlir::Attribute b_mem_space = getB().getType().getMemorySpace(); + if (b_mem_space != smem) { + return error("The `b` input must be in SMEM, but got {0}.", b_mem_space); + } + + mlir::TypedValue a_scale = getAScale(); + mlir::TypedValue b_scale = getBScale(); + if (static_cast(a_scale) != static_cast(b_scale)) { + return error("Either none or both scales should be provided."); + } + + if (a_scale) { + mlir::Attribute a_scale_mem_space = a_scale.getType().getMemorySpace(); + if (a_scale_mem_space != tmem) { + return error("The `a_scale` input must be in TMEM, but got {0}.", + a_scale_mem_space); + } + mlir::Attribute b_scale_mem_space = b_scale.getType().getMemorySpace(); + if (b_scale_mem_space != tmem) { + return error("The `b_scale` input must be in TMEM, but got {0}.", + b_scale_mem_space); + } + } + + return llvm::success(); +} + +llvm::LogicalResult CustomPrimitiveOp::verify() { + int num_vector_operands = 0; + int num_smem_ref_operands = 0; + mlir::Attribute smem = mlir::gpu::AddressSpaceAttr::get( + getContext(), mlir::gpu::AddressSpace::Workgroup); + for (auto operand : getOperands()) { + if (mlir::isa(operand.getType())) { + ++num_vector_operands; + } + + if (auto ref_ty = mlir::dyn_cast(operand.getType())) { + if (ref_ty.getMemorySpace() == smem) { + ++num_smem_ref_operands; + } + } + } + + if (num_vector_operands != getInLayouts().size()) { + return emitOpError( + "Custom primitive must have a layout for each vector operand."); + } + + if (num_smem_ref_operands != getInTransforms().size()) { + return emitOpError( + "Custom primitive must have transforms for each memref operand in " + "smem."); + } + + int num_vector_results = 0; + for (auto result : getResults()) { + if (mlir::isa(result.getType())) { + ++num_vector_results; + } else if (mlir::isa(result.getType())) { + return emitOpError( + "Custom primitive can only return scalars or vectors."); + } + } + + if (num_vector_results != getOutLayouts().size()) { + return emitOpError( + "Custom primitive must have a layout for each vector result."); } return llvm::success(); } -mlir::AffineMap LayoutAttr::getAffineMap() const { - // This always returns an identity map. It's technically not correct, but we - // don't actually use it anywhere. It's only called during verification of the - // layout attribute and needs to be semi-valid. - return mlir::AffineMap::getMultiDimIdentityMap(getNumDimensions(), - getContext()); +llvm::LogicalResult BroadcastInDimOp::verify() { + auto error = [this](auto... params) { + return emitOpError(llvm::formatv(params...)); + }; + + mlir::VectorType operand_type = getOperand().getType(); + mlir::VectorType result_type = getResult().getType(); + + if (operand_type.getRank() == 0) { + return error("The input vector must have rank > 0."); + } + + if (operand_type.getRank() > result_type.getRank()) { + return error( + "The rank of the input vector must be smaller or equal to the rank " + "of the result vector."); + } + + if (operand_type.getRank() != getBroadcastDimensions().size()) { + return error( + "The size of the `broadcast_dimensions` attribute must be equal to " + "the rank of the input vector."); + } + auto dims = llvm::to_vector(getBroadcastDimensions()); + for (int i = 0; i < dims.size(); ++i) { + if (dims[i] < 0 || dims[i] >= result_type.getRank()) { + return error( + "The values in the `broadcast_dimensions` attribute must be in the " + "range [0, result.shape.rank={0}).", + result_type.getRank()); + } + if (i > 0 && dims[i] <= dims[i - 1]) { + return error( + "The values in the `broadcast_dimensions` attribute must be strictly " + "increasing."); + } + } + + return llvm::success(); +} + +llvm::LogicalResult ReturnOp::verify() { + // The operand number and types must match the custom primitive signature. + const auto& results = getParentOp()->getResultTypes(); + if (getNumOperands() != results.size()) + return emitOpError("has ") + << getNumOperands() << " operands, but enclosing custom_primitive (@" + << getParentOp()->getName() << ") returns " << results.size(); + + for (unsigned i = 0, e = results.size(); i != e; ++i) + if (getOperand(i).getType() != results[i]) + return emitError() << "type of return operand " << i << " (" + << getOperand(i).getType() + << ") doesn't match the result type (" << results[i] + << ")" + << " in custom_primitive @" + << getParentOp()->getName(); + + return llvm::success(); +} + +namespace { +int kTmemMaxColumns = 512; +int kTmemCellBitwidth = 32; + +llvm::LogicalResult VerifyTmemRefType(mlir::Operation* op, + mlir::MemRefType tmem_ref_type) { + mlir::Attribute tmem = TmemAttr::get(op->getContext()); + if (tmem_ref_type.getMemorySpace() != tmem) { + return op->emitError() << "The tmem memref must have a " + "mosaic_gpu.tmem memory space but got: " + << tmem_ref_type.getMemorySpace(); + } + + return llvm::success(); +} +} // namespace + +llvm::LogicalResult TmemAllocOp::verify() { + mlir::Attribute smem = mlir::gpu::AddressSpaceAttr::get( + getContext(), mlir::gpu::AddressSpace::Workgroup); + mlir::MemRefType smem_ref_type = getSmemPtr().getType(); + if (smem_ref_type.getMemorySpace() != smem) { + return emitError() + << "The `smem_ptr` memref must have the Workgroup address " + "space but got: " + << smem_ref_type.getMemorySpace(); + } + + mlir::MemRefType tmem_ref_type = getResult().getType(); + llvm::LogicalResult result = VerifyTmemRefType(getOperation(), tmem_ref_type); + if (result.failed()) { + return result; + } + + int num_unpacked_columns = tmem_ref_type.getShape()[1]; + int packing = getPacking(); + if (packing != 1) { + if (packing * tmem_ref_type.getElementTypeBitWidth() != kTmemCellBitwidth) { + return emitError() << "Only unpacked, or fully packed allocations " + "are supported. Expected packing to be either " + "1 or 32 / element_bitwidth, but got: " + "packing = " + << packing << ", element_bitwidth = " + << tmem_ref_type.getElementTypeBitWidth(); + } + if (num_unpacked_columns % packing != 0) { + return emitError() << "The number of unpacked columns must be " + "divisible by the packing factor, but got: " + << num_unpacked_columns << " / " << packing; + } + } + + int num_allocated_columns = num_unpacked_columns / packing; + if (num_allocated_columns > kTmemMaxColumns) { + return emitError() + << "The number of allocated columns must be less than or equal to " + << kTmemMaxColumns << " but got: " << num_allocated_columns; + } + + return llvm::success(); +} + +llvm::LogicalResult TmemDeallocOp::verify() { + return VerifyTmemRefType(getOperation(), getTmemRef().getType()); +} + +llvm::LogicalResult AsyncLoadTmemOp::inferReturnTypes( + mlir::MLIRContext*, std::optional location, + mlir::ValueRange operands, mlir::DictionaryAttr attributes, + mlir::OpaqueProperties properties, mlir::RegionRange regions, + llvm::SmallVectorImpl& inferredReturnTypes) { + mlir::MemRefType memref_type = + mlir::cast(operands[0].getType()); + auto vector_type = mlir::VectorType::get(memref_type.getShape(), + memref_type.getElementType()); + inferredReturnTypes.assign({vector_type}); + return mlir::success(); +} + +llvm::LogicalResult AsyncLoadTmemOp::verify() { + if (getSource().getType().getElementType() != + getResult().getType().getElementType()) { + return emitError() << "The `source` and `result` must have " + "the same element type."; + } + if (getSource().getType().getShape() != getResult().getType().getShape()) { + return emitError() << "The `source` and `result` must have the same shape."; + } + return VerifyTmemRefType(getOperation(), getSource().getType()); +} + +llvm::LogicalResult AsyncStoreTmemOp::verify() { + if (getSource().getType().getElementType() != + getDestination().getType().getElementType()) { + return emitError() << "The `source` and `destination` must have " + "the same element type."; + } + if (getSource().getType().getShape() != + getDestination().getType().getShape()) { + return emitError() + << "The `source` and `destination` must have the same shape."; + } + return VerifyTmemRefType(getOperation(), getDestination().getType()); +} + +llvm::LogicalResult TmemLayoutCastOp::verify() { + return VerifyTmemRefType(getOperation(), getRef().getType()); +} + +llvm::LogicalResult SliceTmemOp::verify() { + if (VerifyTmemRefType(getOperation(), getSource().getType()).failed() || + VerifyTmemRefType(getOperation(), getResult().getType()).failed()) { + return llvm::failure(); + } + if (getOffset() % 4 != 0) { + return emitError() << "The offset must be a multiple of 4 but got: " + << getOffset(); + } + // TODO(allanrenucci): We can't precisely compute the number of columns in + // source/result because we need to know packing. We can however assume + // packing is either 1 (unpacked) or 32 / element_bitwidth (fully packed) and + // reject some invalid slices. + return llvm::success(); +} + +llvm::LogicalResult VectorLoadOp::inferReturnTypes( + mlir::MLIRContext*, std::optional, + mlir::ValueRange operands, mlir::DictionaryAttr, mlir::OpaqueProperties, + mlir::RegionRange, llvm::SmallVectorImpl& inferredReturnTypes) { + mlir::MemRefType memref_type = + mlir::cast(operands[0].getType()); + auto vector_type = mlir::VectorType::get(memref_type.getShape(), + memref_type.getElementType()); + inferredReturnTypes.assign({vector_type}); + return mlir::success(); +} + +llvm::LogicalResult VectorStoreOp::verify() { + mlir::VectorType src_type = getValueToStore().getType(); + mlir::MemRefType dst_type = getDestination().getType(); + if (src_type.getShape() != dst_type.getShape()) { + return emitError() + << "The source and destination must have the same shape but got " + << src_type.getShape() << " and " << dst_type.getShape(); + } + if (src_type.getElementType() != dst_type.getElementType()) { + return emitError() + << "The source and destination must have the same element type but " + "got " + << src_type.getElementType() << " and " << dst_type.getElementType(); + } + return llvm::success(); +} + +llvm::LogicalResult BroadcastedIotaOp::verify() { + mlir::VectorType result_type = getResult().getType(); + if (getDimension() >= result_type.getRank()) { + return emitError(llvm::formatv( + "dimension={0} must be smaller than the rank={1} of the result.", + getDimension(), result_type.getRank())); + } + return llvm::success(); +} + +llvm::LogicalResult PrintLayoutOp::verify() { + if (auto ref_ty = mlir::dyn_cast(getValue().getType())) { + if (VerifyTmemRefType(getOperation(), ref_ty).failed()) { + return llvm::failure(); + } + } + return llvm::success(); +} + +llvm::LogicalResult OptimizationBarrierOp::inferReturnTypes( + mlir::MLIRContext*, std::optional location, + mlir::ValueRange operands, mlir::DictionaryAttr attributes, + mlir::OpaqueProperties properties, mlir::RegionRange regions, + llvm::SmallVectorImpl& inferredReturnTypes) { + if (operands.empty()) { + return mlir::emitOptionalError(location, "expected non-empty operands"); + } + mlir::TypeRange operand_types = operands.getTypes(); + inferredReturnTypes.assign(operand_types.begin(), operand_types.end()); + return mlir::success(); } void MosaicGPUDialect::initialize() { diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h index b4f13c50bd8c..641ee71ab604 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h @@ -18,18 +18,17 @@ limitations under the License. #include #include +#include -#include "llvm/ADT/StringRef.h" +#include "absl/status/status.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Value.h" -#include "mlir/Interfaces/InferTypeOpInterface.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" // IWYU pragma: keep +#include "mlir/Interfaces/InferTypeOpInterface.h" // IWYU pragma: keep #include "mlir/Support/LLVM.h" -#include "absl/status/status.h" -#include "absl/strings/string_view.h" // Generated definitions. #include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_dialect.h.inc" // IWYU pragma: keep @@ -46,7 +45,7 @@ namespace mosaic_gpu { using Memref = ::mlir::TypedValue<::mlir::MemRefType>; using Pointer = ::mlir::TypedValue<::mlir::LLVM::LLVMPointerType>; -constexpr absl::string_view kRuntimeTmaDescriptorInitializerName = +constexpr std::string_view kRuntimeTmaDescriptorInitializerName = "mosaic_gpu_init_tma_desc"; template diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td index 0882986fcf5e..c1a9d011e58d 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td @@ -18,6 +18,7 @@ limitations under the License. include "mlir/Dialect/LLVMIR/LLVMOpBase.td" include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td" +include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/IR/AttrTypeBase.td" include "mlir/IR/BuiltinAttributeInterfaces.td" @@ -54,31 +55,30 @@ def LLVM_PointerShared : LLVM_PointerInAddressSpace<3>; def MosaicGPU_InitializeBarrierOp : Op { - let summary = "Initializes a memref of barriers"; + let summary = "Initializes barrier objects at specified location."; let description = [{ - Initializes a memref of barriers each meant to synchronize exactly + Initializes `num_barriers` barriers each meant to synchronize exactly `arrival_count` threads. - The base pointer of the result memref corresponds to `base_pointer`, which - must be a pointer to a shared memory location. + `base_pointer` must be a pointer to a shared memory location. }]; let arguments = (ins LLVM_PointerShared:$base_pointer, - ConfinedAttr:$arrival_count); - let results = (outs MemRefOf<[MosaicGPU_Barrier]>:$barriers_ref); + ConfinedAttr:$arrival_count, + ConfinedAttr:$num_barriers); +} - let assemblyFormat = [{ - $base_pointer $arrival_count attr-dict `:` type($barriers_ref) - }]; +def MosaicGPU_ArriveOp : Op { + let summary = "Executes an arrive operation on the given barrier."; + + let arguments = (ins + MemRefRankOf<[MosaicGPU_Barrier], [0]>:$barrier, + BoolAttr:$orders_tensor_core); } def MosaicGPU_ArriveExpectTxOp : Op { let summary = "Executes an arrive.expect_tx operation on the given barrier."; - let description = [{ - A single thread in the warpgroup will execute an `arrive.expect_tx` - operation on the provided barrier with the provided `expect_tx`. - }]; let arguments = (ins MemRefRankOf<[MosaicGPU_Barrier], [0]>:$barrier, @@ -142,16 +142,15 @@ def MosaicGPU_WGSplatFragLayout : AttrDef { - let summary = "1D array that is a row that can be tiled by supported WGMMA shapes."; +def MosaicGPU_Replicated : AttrDef { + let summary = "Indicates a replicated dimension in a tiled layout."; let description = [{ - This layout is used to handle rows that are fragmented across all threads - in a warpgroup that is executing a WGMMA operation. The length of the array - must be divisible by 64. + See mosaic/gpu/fragmented_array.py -> Replicated for more details. }]; - let mnemonic = "WGMMARowFragLayout"; - let assemblyFormat = ""; + let parameters = (ins "int":$times); + let mnemonic = "Replicated"; + let assemblyFormat = "`<` `times` `=` $times `>`"; } def MosaicGPU_TiledLayout : AttrDef { @@ -162,12 +161,12 @@ def MosaicGPU_TiledLayout : AttrDef { let parameters = (ins "::mlir::ArrayAttr":$tiling, - "int":$warp_dim, + "::mlir::ArrayAttr":$warp_dims, "::mlir::ArrayAttr":$lane_dims, "int":$vector_dim ); let mnemonic = "TiledLayout"; - let assemblyFormat = "`<` $tiling `,` `warp_dim` `=` $warp_dim `,` " + let assemblyFormat = "`<` $tiling `,` `warp_dims` `=` $warp_dims `,` " "`lane_dims` `=` $lane_dims `,` `vector_dim` `=` $vector_dim `>`"; } @@ -196,6 +195,25 @@ def MosaicGPU_SwizzlingMode : I32EnumAttr<"SwizzlingMode", let cppNamespace = "::mosaic_gpu"; } +def MosaicGPU_TMAReduction : I32EnumAttr<"TMAReduction", + "Reduction operation for TMA.", + [ + I32EnumAttrCase<"Add", 0, "add">, + I32EnumAttrCase<"Min", 1, "min">, + I32EnumAttrCase<"Max", 2, "max">, + I32EnumAttrCase<"Inc", 3, "inc">, + I32EnumAttrCase<"Dec", 4, "dec">, + I32EnumAttrCase<"And", 5, "and">, + I32EnumAttrCase<"Or", 6, "or">, + I32EnumAttrCase<"Xor", 7, "xor">, + I32EnumAttrCase<"Umin", 8, "umin">, + I32EnumAttrCase<"Umax", 9, "umax">, + I32EnumAttrCase<"Smin", 10, "smin">, + I32EnumAttrCase<"Smax", 11, "smax"> + ]>{ + let cppNamespace = "::mosaic_gpu"; +} + def TileTransformAttr : MosaicGPU_Attr<"TileTransform", "tile"> { let parameters = (ins ArrayRefParameter<"int32_t", "tiling">:$tiling); let summary = "Specifies a transform that tiles suffix dimensions of a memref in SMEM."; @@ -225,27 +243,6 @@ def SwizzleTransformAttr : MosaicGPU_Attr<"SwizzleTransform", "swizzle"> { let assemblyFormat = "`<` $swizzle `>`"; } -def LayoutAttr : MosaicGPU_Attr<"Layout", "layout", - [DeclareAttrInterfaceMethods]> { - let parameters = (ins - TypeParameter<"int32_t", "number of dimensions">:$num_dimensions, - ArrayRefParameter<"mlir::Attribute", "transforms">:$transforms - ); - - let summary = "Specifies a layout of a memref in SMEM."; - let description = [{ - This layout attribute is used to specify the layout of a memref in SMEM. - It is composed of a number of transforms, which are applied in the order - they are provided. The transforms can be any combination of: - - TileTransformAttr - - TransposeTransformAttr - - SwizzleTransformAttr - - The num_dimensions parameter must match the rank of the memref shape. - }]; - let assemblyFormat = "`<` $num_dimensions `,` $transforms `>`"; -} - def MosaicGPU_AsyncLoadOp : Op { let summary = "Schedules an async load of a MemRef from GMEM to SMEM"; @@ -264,36 +261,32 @@ def MosaicGPU_AsyncLoadOp : Op:$source, - MemRefOf<[AnyType]>:$destination, + AnyMemRef:$source, + AnyMemRef:$destination, MemRefRankOf<[MosaicGPU_Barrier], [0]>:$barrier, - Variadic:$indices, + Variadic]>>:$indices, PtxPredicate:$predicate, // Attributes @@ -301,15 +294,42 @@ def MosaicGPU_AsyncLoadOp : Op:$collective ); - let assemblyFormat = [{ - `source` `(` $source `:` type($source) `)` - `destination` `(` $destination `:` type($destination) `)` - `barrier` `(` $barrier `:` type($barrier) `)` - `indices` `(` $indices `)` - `predicate` `(` $predicate `)` - attr-dict + let hasVerifier = 1; +} + +def MosaicGPU_AsyncPrefetchOp : Op { + let summary = "Schedules an async prefetch of a MemRef from GMEM to L2"; + let description = [{ + Schedules an async prefetch of the contents of the `source` MemRef in GMEM + to the L2 cache, making subsequent loads of the same data from GMEM faster. + + The `indices` and `slice_lengths` inputs define what slice of the GMEM + `source` is going to be prefetched. Both `indices` and `slice_lengths` must + have a length equal to the rank of the `source`. The values in `indices` are + the starting indices of each dimension and the values in `slice_lengths` are + the lengths. Providing -1 in `slice_lengths` indicates that the slice length + is 1. If an index is a vector of ints, its elements serve as GMEM indices + from which the data should be gathered from GMEM. In this case, the + `slice_lengths` must have a value equal to the size of the vector. Only one + index may be a vector. + + The `collective` attribute can be provided to partition the prefetch over + multiple blocks in a cluster. + + The `predicate` allows scheduling the prefetch conditionally. }]; + let arguments = (ins + AnyMemRef:$source, + Variadic]>>:$indices, + PtxPredicate:$predicate, + + // Attributes + DenseI64ArrayAttr:$slice_lengths, + TypedArrayAttrBase:$collective + ); + let hasVerifier = 1; } @@ -327,65 +347,156 @@ def MosaicGPU_AsyncStoreOp : Op:$source, - MemRefOf<[AnyType]>:$destination, - Variadic:$indices, + AnyMemRef:$source, + AnyMemRef:$destination, + Variadic]>>:$indices, PtxPredicate:$predicate, // Attributes DenseI64ArrayAttr:$slice_lengths, - DefaultValuedOptionalAttr:$commit_group + DefaultValuedOptionalAttr:$commit_group, + OptionalAttr:$reduction_op ); - let assemblyFormat = [{ - `source` `(` $source `:` type($source) `)` - `destination` `(` $destination `:` type($destination) `)` - `indices` `(` $indices `)` - `predicate` `(` $predicate `)` - attr-dict + let hasVerifier = 1; +} + +def MosaicGPU_VectorLoadOp : Op, MemoryEffects<[MemRead]>]> { + let summary = "Reads an n-D slice of memory into an n-D vector."; + let description = [{ + Similar to `vector.load` (vector dialect) but supports loading from + non-contiguous memory. + + If `optimized` is true, raises an error if we cannot generate an optimised + transfer. If unset, fall back to a non-optimized transfer if unable to + generate an optimized transfer. + }]; + + let arguments = (ins + AnyNon0RankedMemRef:$source, + OptionalAttr:$optimized + ); + let results = (outs AnyFixedVectorOfNonZeroRank); +} + +def MosaicGPU_VectorStoreOp : Op]> { + let summary = "Writes an n-D vector to an n-D slice of memory."; + let description = [{ + Similar to `vector.store` (vector dialect) but supports storing to + non-contiguous memory. + + If `optimized` is true, raises an error if we cannot generate an optimised + transfer. If unset, fall back to a non-optimized transfer if unable to + generate an optimized transfer. }]; + let arguments = (ins + AnyFixedVectorOfNonZeroRank:$valueToStore, + AnyNon0RankedMemRef:$destination, + OptionalAttr:$optimized + ); + let hasVerifier = 1; } def MosaicGPU_WGMMASupportedType : AnyTypeOf<[F16, BF16, F32], "A type supported by the WGMMA operation">; -def MosaicGPU_WGMMALayout : - I32EnumAttr<"WGMMALayout", "The layout of the tiles of a WGMMA operation", [ - I32EnumAttrCase<"RowMajor", 0>, - I32EnumAttrCase<"ColumnMajor", 1> - ]> { - let cppNamespace = "::mosaic_gpu"; +def MosaicGPU_LayoutCastOp : Op { + let summary = "Casts a vector to a new layout."; + let description = [{Casts a vector value to a new strided or tiled layout.}]; + let arguments = (ins + AnyVectorOfAnyRank:$x, + + // Attributes + AnyAttrOf<[ + MosaicGPU_WGStridedFragLayout, + MosaicGPU_TiledLayout + ]>:$new_layout + ); + + let results = (outs AnyVectorOfAnyRank); + + let assemblyFormat = "`x` `(` $x `:` type($x) `)` attr-dict"; +} + +def MosaicGPU_TmemLayoutCastOp : Op { + let summary = "Casts a TMEM ref to a new TMEM layout."; + let arguments = (ins + MemRefRankOf<[AnyType], [2]>:$ref, + + // Attributes + MosaicGPU_TiledLayout:$new_layout + ); + + let results = (outs MemRefRankOf<[AnyType], [2]>); + + let hasVerifier = 1; } -def MosaicGPU_SliceSMEMOp : Op { +def MosaicGPU_BroadcastInDimOp : Op { + let summary = "Broadcasts a vector to a new shape."; + let description = [{ + `broadcast_dimensions` must have the same size as the rank of the input + vector and for each input dimension, specifies which output dimension it + corresponds to. + }]; + + let arguments = (ins + AnyVectorOfAnyRank:$operand, + + // Attributes + DenseI64ArrayAttr:$broadcast_dimensions + ); + + let results = (outs AnyVectorOfAnyRank); + let assemblyFormat = [{ + `(` $operand `:` type($operand) `)` attr-dict `->` type(results) + }]; + let hasVerifier = 1; +} + + +def MosaicGPU_SliceSMEMOp : Op { let summary = "Constructs an SMEM MemRef with the requested type that begins at the specified SMEM offset address."; let arguments = (ins I32:$offset); - let results = (outs MemRefOf<[AnyType]>); + let results = (outs AnyMemRef); } -def MosaicGPU_WGMMAOp : Op { - let summary = "Multiply two matrices asyncronously using warpgroup level matrix multiply operations."; +def MosaicGPU_WGMMASupportedABType : AnyTypeOf<[F16, BF16, TF32, F32, F8E4M3FN, F8E5M2, I8], + "A type supported by the `a` and `b` operands of the `wgmma.mma_async` instruction">; + +def MosaicGPU_WGMMASupportedAccumulatorType : AnyTypeOf<[F16, F32, I32], + "A type supported by the accumulator `wgmma.mma_async` instruction">; + + +def MosaicGPU_WGMMAOp : Op]> { + let summary = "Multiply two matrices asynchronously using warpgroup level matrix multiply operations."; let description = [{ Schedules WGMMA operations that perform the following matrix multiply and accumulate: @@ -394,19 +505,14 @@ def MosaicGPU_WGMMAOp : Op { This operation supports larger inputs than the PTX-level WGMMA operation and will schedule as many PTX-level WGMMA operations as needed to - accomplish the calculation. The `b` matrix, and optionally `a`, needs to be - provided as a 2-dimensional memref. All memrefs may have transforms that - define swizzling, tiling, and transposition. + accomplish the calculation. The `b` matrix, and optionally `a`, need to be + provided as a 2-dimensional memref. The inputs should have the following shapes: - a: [groups_m * 64, groups_k * s] - b: [groups_k * s, groups_n * s] - accumulator: [groups_m * 64, groups_n * s] - Where: - - `s == swizzle/element_bytediwth` (for `kNoSwizzle`, `swizzle` is 16.) - and the tilings are [64, s] for `a` and [s, s] for `b`. - - `a` and/or `b` may be transposed if the corresponding attribute is set - to `true`. + where `s == swizzle / element_bytewidth`. The output has an identical shape and type as the input accumulator. @@ -419,22 +525,19 @@ def MosaicGPU_WGMMAOp : Op { registers need to be synchronized with a memory fence. Usually `a` is read from shared memory if it is used directly in the WGMMA - operation. If `a` needs to be transfromed before it is used in the WGMMA + operation. If `a` needs to be transformed before it is used in the WGMMA operation, it may be more convenient to read it directly form registers. This avoids the need to store the data and wait for a fence. }]; let arguments = (ins - VectorOfAnyRankOf<[MosaicGPU_WGMMASupportedType]>:$accumulator, + VectorOfRankAndType<[2], [MosaicGPU_WGMMASupportedAccumulatorType]>:$accumulator, AnyTypeOf<[ - MemRefOf<[MosaicGPU_WGMMASupportedType]>, - VectorOfAnyRankOf<[MosaicGPU_WGMMASupportedType]>]>:$a, - MemRefOf<[MosaicGPU_WGMMASupportedType]>:$b, - - DefaultValuedOptionalAttr:$transpose_a, - DefaultValuedOptionalAttr:$transpose_b + MemRefRankOf<[MosaicGPU_WGMMASupportedABType], [2]>, + VectorOfRankAndType<[2], [MosaicGPU_WGMMASupportedABType]>]>:$a, + MemRefRankOf<[MosaicGPU_WGMMASupportedABType], [2]>:$b ); - let results = (outs VectorOfAnyRankOf<[MosaicGPU_WGMMASupportedType]>); + let results = (outs VectorOfRankAndType<[2], [MosaicGPU_WGMMASupportedAccumulatorType]>); let assemblyFormat = [{ `accumulator` `(` $accumulator `:` type($accumulator) `)` @@ -444,25 +547,286 @@ def MosaicGPU_WGMMAOp : Op { `->` type(results) }]; - let extraClassDeclaration = [{ - static llvm::LogicalResult inferReturnTypes( - mlir::MLIRContext *, - std::optional location, - mlir::ValueRange operands, - mlir::DictionaryAttr attributes, - mlir::OpaqueProperties properties, - mlir::RegionRange regions, - llvm::SmallVectorImpl &inferredReturnTypes) { - if (operands.empty()) { - return ::mlir::emitOptionalError( - location, "expected non-empty operands"); - } - inferredReturnTypes.assign({operands[0].getType()}); - return ::mlir::success(); - } + let hasVerifier = 1; +} + +def MosaicGPU_TmemAttr : MosaicGPU_Attr<"Tmem", "tmem"> { + let summary = "Tensor memory space."; +} + +def MosaicGPU_TcGen05MMASupportedABType : AnyTypeOf<[F16, F32, BF16, TF32, F8E4M3FN, F8E5M2, F6E2M3FN, F6E3M2FN, F4E2M1FN, I8], + "A type supported by the `a` and `b` operands of the `tcgen05.mma` instruction">; + +def MosaicGPU_TcGen05MMAOp : Op { + let summary = "Perform a matrix multiply-accumulate operation using the `tcgen05.mma` instruction."; + let description = [{ + Schedules `tcgen05.mma` instructions that perform the following matrix + multiply and accumulate: + + accumulator += a * b + + This operation supports larger inputs than the PTX-level MMA instruction + and will schedule as many PTX-level MMA instructions as needed to + accomplish the calculation. + + The inputs should have the following shapes: + - a: [groups_m * m, groups_k * s] + - b: [groups_k * s, groups_n * s] + - accumulator: [groups_m * m, groups_n * s] + where `s == swizzle / element_bytewidth` and `m` is specified according to + https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-shape. + + The `accumulator`, `a` and `b` matrices need to be provided as 2-dimensional + memrefs. The `accumulator` is always in TMEM and `b` is always in SMEM. + `a` can be in TMEM or SMEM. `a` and `b` must have the same element + type and when `a` is in TMEM only F16 or BF16 are supported. + + `a_scale` and `b_scale` are optional scaling matrices that reside in TMEM. + When set the operation is defined as: + + accumulator += (a * a_scale) * (b * b_scale) + + `accumulate` is a boolean that indicates whether to perform the accumulate + step. + }]; + + let arguments = (ins + MemRefRankOf<[F16, F32, I32], [2]>:$accumulator, + MemRefRankOf<[MosaicGPU_TcGen05MMASupportedABType], [2]>:$a, + MemRefRankOf<[MosaicGPU_TcGen05MMASupportedABType], [2]>:$b, + I1:$accumulate, + // TODO: allanrenucci - F8E4M3FNU (ue4m3) does not exist in MLIR. We use F8E4M3FN (e4m3) instead. + Optional>:$a_scale, + Optional>:$b_scale, + DefaultValuedAttr:$collective + ); + + let hasVerifier = 1; +} + +def MosaicGPU_OptimizationBarrierOp : Op]> { + let summary = "Prevents MLIR from moving operations across the barrier."; + + let arguments = (ins + Variadic:$operands + ); + let results = (outs Variadic); +} + +def MosaicGPU_ReturnOp : Op]> +{ + let summary = "Terminator for the region in a `CustomPrimitiveOp`"; + let description = [{ + The `return` op is a terminator that indicates the end of execution + within a `CustomPrimitiveOp`'s region. It can optionally return some values, + which become the results of the parent `CustomPrimitiveOp`. + + The declared results of the parent `CustomPrimitiveOp` must match the + operand types of this op. + }]; + + // The operand's type must match the parent CustomPrimitiveOp's result type. + let arguments = (ins Variadic:$operands); + let assemblyFormat = [{ attr-dict ($operands^ `:` type($operands))? }]; + let hasVerifier = 1; +} + +def MosaicGPU_CustomPrimitiveOp : Op]> +{ + let summary = "Allows defining a custom Mosaic GPU primitive."; + let description = [{ + Allows defining a custom Mosaic GPU primitive. + + Custom primitives should carry input and output layouts for each of their + vector operands and outputs, and input transforms for each of their memref + operands that live in SMEM. + + Custom primitives can only return vectors. + }]; + + let arguments = ( + ins Variadic:$operands, + // Attributes + ArrayAttr:$in_layouts, + ArrayAttr:$in_transforms, + ArrayAttr:$out_layouts + ); + + let results = (outs Variadic); + let regions = (region SizedRegion<1>:$body); + + let hasVerifier = 1; +} + +def MosaicGPU_WithTransformsOp : Op { + let summary = "A noop that allows manually setting transforms on a memref."; + let description = [{ + This op enforces the provided transforms on the parameter memref. + }]; + + let arguments = ( + ins AnyMemRef:$ref, + // Attributes + ArrayAttr:$transforms + ); + + let results = (outs AnyMemRef); +} + +def MosaicGPU_TmemAllocOp : Op { + let summary = "Allocates a chunk of TMEM."; + let description = [{ + This op allocates a chunk of TMEM and stores the pointer to the memory + in the provided SMEM memref. + + The `smem_ptr` is a pointer in SMEM where a pointer to the allocated + TMEM will be stored. The op returns a memref to the allocated TMEM. The + result must have a shape with dimensions [rows, logical_columns]. If + `packing` is 1, then the number of logical (unpacked) columns is equal to + the number of allocated columns in TMEM. Otherwise, these constraints + must hold: + + packing = 32 / bitwidth(element type of result) + unpacked_columns = allocated_columns * packing + + The number of allocated columns in TMEM can be any power of two in the + range [32, 512]. If the calculated number of allocated columns is less than + 32 or not a power of two, then it will be rounded up to the nearest power of + two larger or equal to 32. + + If `collective` is `true` 2 CTAs will perform the allocation collectively, + otherwise, only one CTA will perform the allocation. + }]; + + let arguments = (ins + MemRefRankOf<[I32], [0]>:$smem_ptr, + + // Attributes + DefaultValuedAttr:$collective, + DefaultValuedAttr, "1">:$packing + ); + + let results = (outs MemRefRankOf<[AnyType], [2]>); + + let assemblyFormat = [{ + `smem_ptr` `(` $smem_ptr `:` type($smem_ptr) `)` + attr-dict `->` type(results) }]; let hasVerifier = 1; } +def MosaicGPU_TmemRelinquishAllocPermitOp: Op { + let summary = "Relinquishing the right to allocate TMEM."; + let description = [{ + The instruction specifies that the CTA of the executing thread is + relinquishing the right to allocate Tensor Memory. So, it is illegal for a + CTA to perform `tmem_alloc` after any of its constituent threads execute + `tmem_relinquish_alloc_permit`. + + If `collective` is `true`, applies to collective TMEM allocations. + }]; + + let arguments = (ins + DefaultValuedAttr:$collective + ); +} + +def MosaicGPU_TmemDeallocOp : Op { + let summary = "Deallocates a chunk of TMEM."; + + let arguments = (ins MemRefRankOf<[AnyType], [2]>:$tmem_ref); + + let assemblyFormat = [{ + `tmem_ref` `(` $tmem_ref `:` type($tmem_ref) `)` + attr-dict + }]; + + let hasVerifier = 1; +} + +def MosaicGPU_AsyncLoadTmemOp : Op]> { + let summary = "Copies TMEM to registers asynchronously."; + + let arguments = (ins MemRefRankOf<[AnyType], [2]>:$source); + let results = (outs VectorOfRank<[2]>); + + let hasVerifier = 1; +} + +def MosaicGPU_AsyncStoreTmemOp : Op { + let summary = "Copies registers to TMEM asynchronously."; + + let arguments = (ins + VectorOfRank<[2]>: $source, + MemRefRankOf<[AnyType], [2]>:$destination + ); + + let hasVerifier = 1; +} + +def MosaicGPU_SliceTmemOp : Op { + let summary = "Constructs a TMEM MemRef with the requested type that begins at the specified TMEM column offset by the argument"; + let description = [{ + The principal use case for this op is to do a single TMEM allocation and + slice it into multiple smaller TMEM references. `source` is the large TMEM + allocation and `offset` is the number of columns to start slicing from. + }]; + + let arguments = (ins + MemRefRankOf<[AnyType], [2]>:$source, + ConfinedAttr:$offset + ); + + let results = (outs MemRefRankOf<[AnyType], [2]>); + + let hasVerifier = 1; +} + +def MosaicGPU_DebugPrintOp : Op { + let summary = "Prints value from inside a MGPU kernel"; + + let arguments = (ins + StrAttr:$format, + AnyVectorOfAnyRank:$value + ); +} + +def MosaicGPU_PrintLayoutOp : Op { + let summary = "Prints the layout of the given array or TMEM reference."; + + let arguments = (ins + StrAttr:$format, + AnyTypeOf<[AnyMemRef, AnyVectorOfAnyRank]>:$value + ); + + let hasVerifier = 1; +} + +def MosaicGPU_BroadcastedIotaOp : Op { + let summary = "Create a broadcasted iota vector."; + let description = [{ + Creates an array that has the specified shape and holds values starting at + zero and incrementing by one along the specified dimension. + }]; + + let arguments = (ins + ConfinedAttr:$dimension + ); + + let results = (outs AnyFixedVectorOfNonZeroRank); + + let hasVerifier = 1; +} + +// TODO (b/415721295): Remove when the minimum JAXLIB version is 0.8.3. +def TMAGatherSupportedOp : Op { + let summary = "Dummy op that will be removed. DO NOT USE."; +} + #endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_GPU_MOSAIC_GPU_TD_ diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc b/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc index 527aa7c7ce25..5e26aa722b32 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc @@ -18,33 +18,33 @@ limitations under the License. #include #include #include +#include +#include #include #include #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "llvm/include/llvm/ADT/ArrayRef.h" -#include "llvm/include/llvm/ADT/SmallVector.h" -#include "mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h" -#include "mlir/include/mlir/Conversion/LLVMCommon/StructBuilder.h" -#include "mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" -#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/include/mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/include/mlir/Dialect/SCF/Utils/Utils.h" -#include "mlir/include/mlir/IR/Builders.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/IR/Diagnostics.h" -#include "mlir/include/mlir/IR/MLIRContext.h" -#include "mlir/include/mlir/IR/OwningOpRef.h" -#include "mlir/include/mlir/IR/Types.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/IR/Verifier.h" -#include "mlir/include/mlir/Interfaces/DataLayoutInterfaces.h" -#include "mlir/include/mlir/Support/LLVM.h" -#include "tsl/platform/errors.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" +#include "mlir/Conversion/LLVMCommon/StructBuilder.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" +#include "mlir/Support/LLVM.h" +#include "xla/tsl/platform/errors.h" namespace mosaic_gpu { namespace { @@ -63,9 +63,8 @@ absl::StatusOr FromCppFunc( mlir::OpBuilder b(context); b.setInsertionPointToEnd(module.getBody()); - auto fn = b.create( - b.getUnknownLoc(), "function_wrapper", - b.getFunctionType({type1, type2}, std::nullopt)); + auto fn = mlir::func::FuncOp::create(b, b.getUnknownLoc(), "function_wrapper", + b.getFunctionType({type1, type2}, {})); fn.addEntryBlock(); b.setInsertionPointToStart(&fn.front()); @@ -73,7 +72,7 @@ absl::StatusOr FromCppFunc( mlir::cast>(fn.getArgument(1)), varargs...)); - b.create(b.getUnknownLoc()); + mlir::func::ReturnOp::create(b, b.getUnknownLoc()); if (mlir::failed(mlir::verify(module))) { return absl::InternalError("Failed to verify generated module"); @@ -95,7 +94,7 @@ class MosaicGpuTest : public ::testing::Test { mosaic_gpu::DeclareRuntimeFunctions(builder_); } - void ExpectLastErrorContains(absl::string_view substring) { + void ExpectLastErrorContains(std::string_view substring) { EXPECT_THAT(last_error_message_, HasSubstr(substring)); } diff --git a/jaxlib/mosaic/dialect/tpu/array_util.cc b/jaxlib/mosaic/dialect/tpu/array_util.cc index 4c1e79667c0f..c5c60b416ca3 100644 --- a/jaxlib/mosaic/dialect/tpu/array_util.cc +++ b/jaxlib/mosaic/dialect/tpu/array_util.cc @@ -19,10 +19,11 @@ limitations under the License. #include "absl/log/check.h" #include "absl/types/span.h" -#include "llvm/include/llvm/ADT/STLExtras.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "llvm/ADT/STLExtras.h" +#include "mlir/Support/LLVM.h" -namespace mlir::tpu::internal { +namespace mlir::tpu { +namespace internal { bool sliceIsEmpty(const absl::Span starts, const absl::Span limits) { @@ -51,4 +52,20 @@ bool incrementSliceIndex(const MutableArrayRef idx, return false; } -} // namespace mlir::tpu::internal +} // namespace internal + +bool incrementIndex(const MutableArrayRef idx, + const absl::Span limits) { + const int64_t nd = idx.size(); + CHECK_EQ(nd, limits.size()); + for (int64_t i = nd - 1; i >= 0; --i) { + ++idx[i]; + if (idx[i] < limits[i]) { + return true; + } + idx[i] = 0; + } + return false; +} + +} // namespace mlir::tpu diff --git a/jaxlib/mosaic/dialect/tpu/array_util.h b/jaxlib/mosaic/dialect/tpu/array_util.h index 1b755dbf8495..8b80e7e356ec 100644 --- a/jaxlib/mosaic/dialect/tpu/array_util.h +++ b/jaxlib/mosaic/dialect/tpu/array_util.h @@ -17,10 +17,11 @@ limitations under the License. #define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_ARRAY_UTIL_H_ #include +#include #include "absl/log/check.h" #include "absl/types/span.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "mlir/Support/LLVM.h" #include "jaxlib/mosaic/dialect/tpu/util.h" #include "xla/array.h" @@ -42,6 +43,9 @@ bool incrementSliceIndex(MutableArrayRef idx, } // namespace internal +bool incrementIndex(MutableArrayRef idx, + absl::Span limits); + template ArrayRef XlaArrayToFlatArrayRef(const xla::Array &arr) { return ArrayRef(arr.data(), arr.num_elements()); @@ -56,6 +60,33 @@ xla::Array XlaArrayFromShapeAndValues(ArrayRef sizes, Range vals) { return arr; } +// An alternative to xla::Array::Each that returns a LogicalResult +template +std::enable_if_t, T*>, + LogicalResult> +Each(xla::Array& arr, Func&& func) { + SmallVector idx(arr.num_dimensions()); + auto it = arr.begin(); + do { + RETURN_IF_FAILED(func(ArrayRef(idx), &*it)); + ++it; + } while (incrementIndex(idx, arr.dimensions())); + DCHECK(it == arr.end()); + return success(); +} +template +std::enable_if_t, T>, LogicalResult> +Each(const xla::Array& arr, Func&& func) { + SmallVector idx(arr.num_dimensions()); + auto it = arr.begin(); + do { + RETURN_IF_FAILED(func(ArrayRef(idx), *it)); + ++it; + } while (incrementIndex(idx, arr.dimensions())); + DCHECK(it == arr.end()); + return success(); +} + // An alternative to `xla::Array::UpdateSlice` that takes a single value. template void updateSlice(xla::Array &arr, const T &value, diff --git a/jaxlib/mosaic/dialect/tpu/array_util_test.cc b/jaxlib/mosaic/dialect/tpu/array_util_test.cc index 18c2f94fa8b6..bcbf417a967b 100644 --- a/jaxlib/mosaic/dialect/tpu/array_util_test.cc +++ b/jaxlib/mosaic/dialect/tpu/array_util_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include -#include "mlir/include/mlir/Support/LLVM.h" +#include "mlir/Support/LLVM.h" #include "xla/array.h" namespace mlir::tpu { diff --git a/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc b/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc index 772e87beff71..f8c04ea933fc 100644 --- a/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc +++ b/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc @@ -15,178 +15,23 @@ limitations under the License. #include "jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h" -#include -#include -#include -#include -#include #include +#include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" -#include "llvm/Support/MemAlloc.h" -#include "llvm/Support/raw_ostream.h" #include "mlir-c/IR.h" #include "mlir-c/Support.h" #include "mlir/CAPI/IR.h" #include "mlir/CAPI/Registration.h" -#include "mlir/CAPI/Utils.h" -#include "mlir/CAPI/Wrap.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Value.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" -#include "jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h" #include "jaxlib/mosaic/dialect/tpu/transforms/serde.h" -#include "xla/array.h" - -// TODO(tlongeri): null pointer checks? - -namespace { -DEFINE_C_API_PTR_METHODS(MlirTpuVectorLayout, mlir::tpu::VectorLayout); -DEFINE_C_API_PTR_METHODS(MlirTpuVregDataBounds, mlir::tpu::VRegDataBounds); - -MlirTpuImplicitDim wrap(mlir::tpu::VectorLayout::ImplicitDim implicit_dim) { - switch (implicit_dim) { - case mlir::tpu::VectorLayout::ImplicitDim::kNone: - return MlirTpuImplicitDimNone; - case mlir::tpu::VectorLayout::ImplicitDim::kMinor: - return MlirTpuImplicitDimMinor; - case mlir::tpu::VectorLayout::ImplicitDim::kSecondMinor: - return MlirTpuImplicitDimSecondMinor; - } - LOG(FATAL) << "Invalid implicit dim (C++)"; -} -mlir::tpu::VectorLayout::ImplicitDim unwrap(MlirTpuImplicitDim implicit_dim) { - switch (implicit_dim) { - case MlirTpuImplicitDimNone: - return mlir::tpu::VectorLayout::ImplicitDim::kNone; - case MlirTpuImplicitDimMinor: - return mlir::tpu::VectorLayout::ImplicitDim::kMinor; - case MlirTpuImplicitDimSecondMinor: - return mlir::tpu::VectorLayout::ImplicitDim::kSecondMinor; - } - LOG(FATAL) << "Invalid implicit dim (C)"; -} -mlir::tpu::Direction unwrap(MlirTpuDirection direction) { - switch (direction) { - case MlirTpuDirectionSublanes: - return mlir::tpu::Direction::kSublanes; - case MlirTpuImplicitDimMinor: - return mlir::tpu::Direction::kLanes; - case MlirTpuImplicitDimSecondMinor: - return mlir::tpu::Direction::kSubelements; - } - LOG(FATAL) << "Invalid direction (C)"; -} -MlirTpuLayoutOffsets wrap(mlir::tpu::LayoutOffsets offsets) { - return {offsets[0].value_or(-1), offsets[1].value_or(-1)}; -} -mlir::tpu::LayoutOffsets unwrap(MlirTpuLayoutOffsets offsets) { - auto translateOffset = [](int64_t offset) { - CHECK_GE(offset, -1); - return offset == -1 ? std::nullopt : mlir::tpu::LayoutOffset{offset}; - }; - return {translateOffset(offsets.sublane), translateOffset(offsets.lane)}; -} -std::array unwrap(MlirTpuBoolTargetTuple arr) { - return {arr.sublane, arr.lane}; -} -std::array unwrap(MlirTpuI64TargetTuple arr) { - return {arr.sublane, arr.lane}; -} -MlirTpuI64TargetTuple wrap(std::array arr) { - return {arr[0], arr[1]}; -} -mlir::tpu::ApplyVectorLayoutContext unwrap( - MlirTpuApplyVectorLayoutContext ctx) { - return mlir::tpu::ApplyVectorLayoutContext{ - .hardware_generation = ctx.hardware_generation, - .target_shape = unwrap(ctx.target_shape), - .mxu_shape = {ctx.mxu_shape.contracting_size, - ctx.mxu_shape.non_contracting_size}, - .max_sublanes_in_scratch = ctx.max_sublanes_in_scratch}; -} - -mlir::OpBuilder mlirTpuInsertionPointToOpBuilder( - MlirTpuInsertionPoint insertion_point) { - mlir::Operation *ref_operation = unwrap(insertion_point.ref_operation); - return ref_operation == nullptr - ? mlir::OpBuilder::atBlockEnd(unwrap(insertion_point.block)) - : mlir::OpBuilder(ref_operation); -} - -// We do not use the names wrap/unwrap for MlirTpuI64ArrayRef because whether -// they should refer to SmallVector or ArrayRef is ambiguous -MlirTpuI64ArrayRef mlirTpuI64ArrayRefFromLlvmSmallVector( - const mlir::SmallVector &vec) { - // TODO(tlongeri): It would be good to steal the buffer from implicit_shape, - // but there are no public member functions for this. - int64_t *ptr = - static_cast(llvm::safe_malloc(vec.size() * sizeof(int64_t))); - memcpy(ptr, vec.data(), vec.size() * sizeof(int64_t)); - return {ptr, vec.size()}; -} -llvm::ArrayRef mlirTpuI64ArrayRefToLlvmArrayRef( - MlirTpuI64ArrayRef tpu_array_ref) { - return {tpu_array_ref.ptr, tpu_array_ref.size}; -} - -// We do not use the names wrap/unwrap for MlirTpuValueArray because it -// allocates memory (i.e. they have side effects) -xla::Array MlirTpuValueArrayToXlaArray( - MlirTpuValueArray arr) { - llvm::ArrayRef shape = mlirTpuI64ArrayRefToLlvmArrayRef(arr.shape); - xla::Array res(shape); - int64_t n = res.num_elements(); - for (int64_t i = 0; i < n; ++i) { - res.data()[i] = unwrap(arr.vals[i]); - } - return res; -} -MlirTpuValueArray MlirTpuValueArrayFromXlaArray( - const xla::Array &vals) { - int64_t nd = vals.num_dimensions(); - int64_t *shape = - static_cast(llvm::safe_malloc(nd * sizeof(int64_t))); - memcpy(shape, vals.dimensions().data(), nd * sizeof(int64_t)); - int64_t n = vals.num_elements(); - MlirValue *elements = - static_cast(llvm::safe_malloc(n * sizeof(MlirValue))); - memcpy(elements, vals.data(), n * sizeof(MlirValue)); - return {{shape, static_cast(nd)}, elements}; -} - -} // namespace extern "C" { MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(TPU, tpu, mlir::tpu::TPUDialect); -bool mlirTPUAttributeIsATiledLayoutAttr(MlirAttribute attr) { - return llvm::isa(unwrap(attr)); -} - -MlirAttribute mlirTPUTiledLayoutAttrGetTiles(MlirAttribute attr) { - auto layout_attr = llvm::cast(unwrap(attr)); - std::vector tile_attrs; - tile_attrs.reserve(layout_attr.getTiles().size()); - mlir::MLIRContext *ctx = layout_attr.getContext(); - for (auto &tile : layout_attr.getTiles()) { - auto d = tile.dimensions(); - tile_attrs.push_back(mlir::DenseI64ArrayAttr::get( - ctx, llvm::ArrayRef(d.begin(), d.end()))); - } - return wrap(mlir::ArrayAttr::get(ctx, tile_attrs)); -} - void mlirTPUAnalyzePotentialCommunication(MlirOperation op, bool *has_communication, bool *has_custom_barrier) { @@ -195,226 +40,23 @@ void mlirTPUAnalyzePotentialCommunication(MlirOperation op, *has_custom_barrier = result.second; } -MlirTpuVectorLayout mlirTpuVectorLayoutCreate(int bitwidth, - MlirTpuLayoutOffsets offsets, - MlirTpuI64TargetTuple tiling, - MlirTpuImplicitDim implicit_dim) { - return wrap(new mlir::tpu::VectorLayout( - bitwidth, unwrap(offsets), unwrap(tiling), unwrap(implicit_dim))); -} - -void mlirTpuVectorLayoutDestroy(MlirTpuVectorLayout layout) { - delete unwrap(layout); -} - -int mlirTpuVectorLayoutGetBitwidth(MlirTpuVectorLayout layout) { - return unwrap(layout)->bitwidth(); -} - -MlirTpuLayoutOffsets mlirTpuVectorLayoutGetOffsets(MlirTpuVectorLayout layout) { - return wrap(unwrap(layout)->offsets()); -} - -MlirTpuI64TargetTuple mlirTpuVectorLayoutGetTiling(MlirTpuVectorLayout layout) { - return wrap(unwrap(layout)->tiling()); -} - -MlirTpuImplicitDim mlirTpuVectorLayoutGetImplicitDim( - MlirTpuVectorLayout layout) { - return wrap(unwrap(layout)->implicit_dim()); -} - -int mlirTpuVectorLayoutGetPacking(MlirTpuVectorLayout layout) { - return unwrap(layout)->packing(); -} - -int mlirTpuVectorLayoutGetLayoutRank(MlirTpuVectorLayout layout) { - return unwrap(layout)->layout_rank(); -} - -bool mlirTpuVectorLayoutEquals(MlirTpuVectorLayout lhs, - MlirTpuVectorLayout rhs) { - return *unwrap(lhs) == *unwrap(rhs); -} - -int64_t mlirTpuVectorLayoutTilesPerVreg(MlirTpuVectorLayout layout, - MlirTpuI64TargetTuple target_shape) { - return unwrap(layout)->tilesPerVreg(unwrap(target_shape)); -} - -int64_t mlirTpuVectorLayoutSublanesPerTile(MlirTpuVectorLayout layout, - MlirTpuI64TargetTuple target_shape) { - return unwrap(layout)->sublanesPerTile(unwrap(target_shape)); -} - -MlirTpuI64TargetTuple mlirTpuVectorLayoutVregSlice( - MlirTpuVectorLayout layout, MlirTpuI64TargetTuple target_shape) { - return wrap(unwrap(layout)->vregSlice(unwrap(target_shape))); -} - -MlirTpuI64ArrayRef mlirTpuVectorLayoutImplicitShape(MlirTpuVectorLayout layout, - MlirTpuI64ArrayRef shape) { - mlir::SmallVector implicit_shape = - unwrap(layout)->implicitShape(mlirTpuI64ArrayRefToLlvmArrayRef(shape)); - return mlirTpuI64ArrayRefFromLlvmSmallVector(implicit_shape); -} - -MlirTpuI64ArrayRef mlirTpuVectorLayoutTileArrayShape( - MlirTpuVectorLayout layout, MlirTpuI64ArrayRef shape, - MlirTpuI64TargetTuple target_shape) { - mlir::SmallVector tile_array_shape = unwrap(layout)->tileArrayShape( - mlirTpuI64ArrayRefToLlvmArrayRef(shape), unwrap(target_shape)); - return mlirTpuI64ArrayRefFromLlvmSmallVector(tile_array_shape); -} - -MlirTpuVregDataBounds mlirTpuVectorLayoutTileDataBounds( - MlirTpuVectorLayout layout, MlirContext ctx, int64_t *full_shape, - int64_t *idxs, size_t size, MlirTpuI64TargetTuple target_shape, - MlirTpuBoolTargetTuple allow_replicated) { - std::unique_ptr ptr = - unwrap(layout)->tileDataBounds( - unwrap(ctx), llvm::ArrayRef{full_shape, size}, - llvm::ArrayRef{idxs, size}, unwrap(target_shape), - unwrap(allow_replicated)); - return wrap(ptr.release()); -} - -bool mlirTpuVectorLayoutHasNaturalTopology(MlirTpuVectorLayout layout, - MlirTpuI64TargetTuple target_shape) { - return unwrap(layout)->hasNaturalTopology(unwrap(target_shape)); -} - -bool mlirTpuVectorLayoutHasNativeTiling(MlirTpuVectorLayout layout, - MlirTpuI64TargetTuple target_shape) { - return unwrap(layout)->hasNativeTiling(unwrap(target_shape)); -} - -bool mlirTpuVectorLayoutGeneralizes(MlirTpuVectorLayout layout, - MlirTpuVectorLayout other, - MlirTpuI64ArrayRef shape, - MlirTpuI64TargetTuple target_shape) { - return unwrap(layout)->generalizes(*unwrap(other), - mlirTpuI64ArrayRefToLlvmArrayRef(shape), - unwrap(target_shape)); -} - -bool mlirTpuVectorLayoutEquivalentTo(MlirTpuVectorLayout layout, - MlirTpuVectorLayout other, - MlirTpuI64ArrayRef shape, - MlirTpuI64TargetTuple target_shape) { - return unwrap(layout)->equivalentTo(*unwrap(other), - mlirTpuI64ArrayRefToLlvmArrayRef(shape), - unwrap(target_shape)); -} - -void mlirTpuVectorLayoutPrint( - MlirTpuVectorLayout layout, MlirStringCallback callback, void *userData) { - mlir::detail::CallbackOstream stream(callback, userData); - unwrap(layout)->print(stream); -} - -bool mlirTpuVectorLayoutIsValid(MlirTpuVectorLayout layout, - MlirTpuI64TargetTuple target_shape) { - return unwrap(layout)->isValid(unwrap(target_shape)); -} - -void mlirTpuVregDataBoundsDestroy(MlirTpuVregDataBounds data_bounds) { - delete unwrap(data_bounds); -} - -bool mlirTpuVregDataBoundsMaskVariesAlong(MlirTpuVregDataBounds data_bounds, - MlirTpuDirection direction, - MlirTpuI64TargetTuple target_shape) { - return unwrap(data_bounds) - ->maskVariesAlong(unwrap(direction), unwrap(target_shape)); -} - -bool mlirTpuVregDataBoundsIsComplete(MlirTpuVregDataBounds data_bounds, - MlirTpuI64TargetTuple target_shape) { - return unwrap(data_bounds)->isComplete(unwrap(target_shape)); -} - -MlirValue mlirTpuVregDataBoundsGetVectorMask( - MlirTpuVregDataBounds data_bounds, MlirTpuInsertionPoint insertion_point, - MlirLocation location, int generation, MlirTpuI64TargetTuple target_shape) { - mlir::OpBuilder builder = mlirTpuInsertionPointToOpBuilder(insertion_point); - auto failure_or_mask = unwrap(data_bounds) - ->getVectorMask(builder, unwrap(location), - generation, unwrap(target_shape)); - if (failed(failure_or_mask)) { - return wrap(mlir::Value()); - } else { - return wrap(failure_or_mask.value()); - } -} - -MlirAttribute mlirTpuVregDataBoundsGetSublaneMask( - MlirTpuVregDataBounds data_bounds, MlirContext ctx, - MlirTpuI64TargetTuple target_shape) { - return wrap( - unwrap(data_bounds)->getSublaneMask(unwrap(ctx), unwrap(target_shape))); -} - -MlirOperation mlirTpuAssemble(MlirTpuInsertionPoint insertion_point, - MlirType vector_type, MlirTpuVectorLayout layout, - MlirTpuValueArray vals, - MlirTpuI64TargetTuple target_shape) { - mlir::OpBuilder builder = mlirTpuInsertionPointToOpBuilder(insertion_point); - // This cast will fail and assert if the caller passed a non-vector type - auto vty = mlir::cast(unwrap(vector_type)); - return wrap(mlir::tpu::assemble(builder, vty, *unwrap(layout), - MlirTpuValueArrayToXlaArray(vals), - unwrap(target_shape)) - .getOperation()); +MLIR_CAPI_EXPORTED void mlirTpuRegisterMosaicSerdePass() { + mlir::tpu::registerMosaicSerdePass(); } -MlirTpuValueArray mlirTpuDisassemble(MlirTpuInsertionPoint insertion_point, - MlirTpuVectorLayout layout, MlirValue val, - MlirTpuI64TargetTuple target_shape) { - mlir::OpBuilder builder = mlirTpuInsertionPointToOpBuilder(insertion_point); - // This cast will fail and assert if the caller passed a non-vector - auto vector_val = mlir::cast>(unwrap(val)); - mlir::FailureOr> failure_or_vals = - mlir::tpu::disassemble(builder, *unwrap(layout), vector_val, - unwrap(target_shape)); - if (failed(failure_or_vals)) { - return {{nullptr, 0}, nullptr}; - } - return MlirTpuValueArrayFromXlaArray(std::move(failure_or_vals).value()); -} +} // extern "C" -MlirLogicalResult mlirTpuApplyLayoutOp(MlirTpuApplyVectorLayoutContext ctx, - MlirOperation op) { - mlir::tpu::ApplyVectorLayoutContext unwrapped_ctx = unwrap(ctx); - return wrap(mlir::tpu::applyLayoutOp(unwrapped_ctx, *unwrap(op))); +// Type API for Float8EXMYType +MlirType mlirTpuFloat8EXMYTypeGetUnderlyingType(MlirType exmy_type) { + return wrap(llvm::cast(unwrap(exmy_type)) + .getUnderlyingType()); } -MlirValue mlirTpuRelayout(MlirTpuInsertionPoint insertion_point, MlirValue val, - MlirTpuVectorLayout src, MlirTpuVectorLayout dst, - MlirTpuApplyVectorLayoutContext ctx) { - mlir::OpBuilder builder = mlirTpuInsertionPointToOpBuilder(insertion_point); - // This cast will fail and assert if the caller passed a non-vector - auto vector_val = mlir::cast>(unwrap(val)); - auto apply_layout_ctx = unwrap(ctx); - mlir::FailureOr> failure_or_new_val = - mlir::tpu::relayout(apply_layout_ctx, builder, vector_val, *unwrap(src), - *unwrap(dst)); - if (failed(failure_or_new_val)) { - return {nullptr}; - } - return wrap(std::move(failure_or_new_val).value()); +bool mlirTpuIsAFloat8EXMYType(MlirType type) { + return llvm::isa(unwrap(type)); } -} - -MLIR_CAPI_EXPORTED void mlirTpuRegisterMosaicSerdePass() { - mlir::tpu::registerMosaicSerdePass(); -} - -#include "mlir/CAPI/Pass.h" // IWYU pragma: keep -#include "mlir/CAPI/Support.h" // IWYU pragma: keep - -extern "C" { -using namespace mlir::tpu; -#include "jaxlib/mosaic/dialect/tpu/integrations/c/tpu_passes.capi.cc.inc" +MlirType mlirTpuFloat8EXMYTypeGet(MlirContext ctx, MlirType exmy_type) { + auto float_type = llvm::cast(unwrap(exmy_type)); + return wrap(mlir::tpu::Float8EXMYType::get(unwrap(ctx), float_type)); } diff --git a/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h b/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h index 6f883325403b..da085a7f84bf 100644 --- a/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h +++ b/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h @@ -13,9 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// Refer to the corresponding C++ declarations in layout.h and -// apply_vector_layout.h for documentation on the functions in this file - #ifndef JAXLIB_MOSAIC_DIALECT_TPU_INTEGRATIONS_C_TPU_DIALECT_H_ #define JAXLIB_MOSAIC_DIALECT_TPU_INTEGRATIONS_C_TPU_DIALECT_H_ @@ -37,209 +34,21 @@ extern "C" { MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(TPU, tpu); -MLIR_CAPI_EXPORTED bool mlirTPUAttributeIsATiledLayoutAttr(MlirAttribute attr); - -/// Encodes the tiles as an ArrayAttr of DenseI64ArrayAttrs. -MLIR_CAPI_EXPORTED MlirAttribute -mlirTPUTiledLayoutAttrGetTiles(MlirAttribute attr); - MLIR_CAPI_EXPORTED void mlirTPUAnalyzePotentialCommunication( MlirOperation op, bool* has_communication, bool* has_custom_barrier); -typedef enum MlirTpuImplicitDim { - MlirTpuImplicitDimNone = 0, - MlirTpuImplicitDimMinor = 1, - MlirTpuImplicitDimSecondMinor = 2, -} MlirTpuImplicitDim; - -typedef enum MlirTpuDirection { - MlirTpuDirectionSublanes, - MlirTpuDirectionLanes, - MlirTpuDirectionSubelements -} MlirTpuDirection; - -// Opaque reference to an owned layout -typedef struct MlirTpuVectorLayout { - void* ptr; -} MlirTpuVectorLayout; - -// Opaque reference to owned data bounds -typedef struct MlirTpuVregDataBounds { - void* ptr; -} MlirTpuVregDataBounds; - -// mlir::ArrayRef equivalent -// Unlike mlir::ArrayRef, the data may or may not be owned (this should be -// defined by the producer of the struct). -typedef struct MlirTpuI64ArrayRef { - int64_t* ptr; - size_t size; -} MlirTpuI64ArrayRef; - -// Shaped array of values -typedef struct MlirTpuValueArray { - MlirTpuI64ArrayRef shape; // May or may not be owned - MlirValue* vals; // Size given by the shape -} MlirTpuValueArray; - -typedef struct MlirTpuLayoutOffsets { - // Use -1 for replicated - int64_t sublane; - int64_t lane; -} MlirTpuLayoutOffsets; - -typedef struct MlirTpuI64TargetTuple { - int64_t sublane; - int64_t lane; -} MlirTpuI64TargetTuple; - -typedef struct MlirTpuMxuShape { - int64_t contracting_size; - int64_t non_contracting_size; -} MlirTpuMxuShape; - -typedef struct MlirTpuBoolTargetTuple { - bool sublane; - bool lane; -} MlirTpuBoolTargetTuple; - -// An insertion point within a block. -// The MLIR C API does not already have a similar struct, unfortunately. -typedef struct MlirTpuInsertionPoint { - MlirBlock block; // Only used when ref_operation is unspecified (null) - MlirOperation ref_operation; -} MlirTpuInsertionPoint; - -typedef struct MlirTpuApplyVectorLayoutContext { - int hardware_generation = -1; - MlirTpuI64TargetTuple target_shape = {8, 128}; - MlirTpuMxuShape mxu_shape = {128, 128}; - int64_t max_sublanes_in_scratch = 0; -} MlirTpuApplyVectorLayoutContext; - -// Caller owns the returned object and is responsible for calling -// mlirTpuVectorLayoutDestroy -MLIR_CAPI_EXPORTED MlirTpuVectorLayout mlirTpuVectorLayoutCreate( - int bitwidth, MlirTpuLayoutOffsets offsets, MlirTpuI64TargetTuple tiling, - MlirTpuImplicitDim implicit_dim); - -MLIR_CAPI_EXPORTED void mlirTpuVectorLayoutDestroy(MlirTpuVectorLayout); - -MLIR_CAPI_EXPORTED int mlirTpuVectorLayoutGetBitwidth( - MlirTpuVectorLayout layout); - -MLIR_CAPI_EXPORTED MlirTpuLayoutOffsets -mlirTpuVectorLayoutGetOffsets(MlirTpuVectorLayout layout); - -MLIR_CAPI_EXPORTED MlirTpuI64TargetTuple -mlirTpuVectorLayoutGetTiling(MlirTpuVectorLayout layout); - -MLIR_CAPI_EXPORTED MlirTpuImplicitDim -mlirTpuVectorLayoutGetImplicitDim(MlirTpuVectorLayout layout); - -MLIR_CAPI_EXPORTED int mlirTpuVectorLayoutGetPacking( - MlirTpuVectorLayout layout); - -MLIR_CAPI_EXPORTED int mlirTpuVectorLayoutGetLayoutRank( - MlirTpuVectorLayout layout); - -MLIR_CAPI_EXPORTED bool mlirTpuVectorLayoutEquals(MlirTpuVectorLayout lhs, - MlirTpuVectorLayout rhs); - -MLIR_CAPI_EXPORTED int64_t mlirTpuVectorLayoutTilesPerVreg( - MlirTpuVectorLayout layout, MlirTpuI64TargetTuple target_shape); - -MLIR_CAPI_EXPORTED int64_t mlirTpuVectorLayoutSublanesPerTile( - MlirTpuVectorLayout layout, MlirTpuI64TargetTuple target_shape); - -MLIR_CAPI_EXPORTED MlirTpuI64TargetTuple mlirTpuVectorLayoutVregSlice( - MlirTpuVectorLayout layout, MlirTpuI64TargetTuple target_shape); - -// Caller is responsible for calling free on the returned pointer -MLIR_CAPI_EXPORTED MlirTpuI64ArrayRef mlirTpuVectorLayoutImplicitShape( - MlirTpuVectorLayout layout, MlirTpuI64ArrayRef shape); - -// Caller is responsible for calling free on the returned pointer. -MLIR_CAPI_EXPORTED MlirTpuI64ArrayRef mlirTpuVectorLayoutTileArrayShape( - MlirTpuVectorLayout layout, MlirTpuI64ArrayRef shape, - MlirTpuI64TargetTuple target_shape); - -// Caller owns the returned object and is responsible for calling -// mlirTpuVectorLayoutVregDataBoundsDestroy -MLIR_CAPI_EXPORTED MlirTpuVregDataBounds mlirTpuVectorLayoutTileDataBounds( - MlirTpuVectorLayout layout, MlirContext ctx, int64_t* full_shape, - int64_t* idxs, size_t size, MlirTpuI64TargetTuple target_shape, - MlirTpuBoolTargetTuple allow_replicated); - -MLIR_CAPI_EXPORTED bool mlirTpuVectorLayoutHasNaturalTopology( - MlirTpuVectorLayout layout, MlirTpuI64TargetTuple target_shape); - -MLIR_CAPI_EXPORTED bool mlirTpuVectorLayoutHasNativeTiling( - MlirTpuVectorLayout layout, MlirTpuI64TargetTuple target_shape); - -// `shape` is optional, pass a shape with a null `ptr` to return true iff the -// "generalizes" relationship applies to all shapes. -MLIR_CAPI_EXPORTED bool mlirTpuVectorLayoutGeneralizes( - MlirTpuVectorLayout layout, MlirTpuVectorLayout other, - MlirTpuI64ArrayRef shape, MlirTpuI64TargetTuple target_shape); - -// `shape` is optional, pass a shape with a null `ptr` to return true iff the -// "equivalent to" relationship applies to all shapes. -MLIR_CAPI_EXPORTED bool mlirTpuVectorLayoutEquivalentTo( - MlirTpuVectorLayout layout, MlirTpuVectorLayout other, - MlirTpuI64ArrayRef shape, MlirTpuI64TargetTuple target_shape); - -MLIR_CAPI_EXPORTED void mlirTpuVectorLayoutPrint( - MlirTpuVectorLayout layout, MlirStringCallback callback, void* user_data); - -MLIR_CAPI_EXPORTED bool mlirTpuVectorLayoutIsValid( - MlirTpuVectorLayout layout, MlirTpuI64TargetTuple target_shape); - -MLIR_CAPI_EXPORTED void mlirTpuVregDataBoundsDestroy( - MlirTpuVregDataBounds data_bounds); - -MLIR_CAPI_EXPORTED bool mlirTpuVregDataBoundsMaskVariesAlong( - MlirTpuVregDataBounds data_bounds, MlirTpuDirection direction, - MlirTpuI64TargetTuple target_shape); - -MLIR_CAPI_EXPORTED bool mlirTpuVregDataBoundsIsComplete( - MlirTpuVregDataBounds data_bounds, MlirTpuI64TargetTuple target_shape); -// Returns null on failure -MLIR_CAPI_EXPORTED MlirValue mlirTpuVregDataBoundsGetVectorMask( - MlirTpuVregDataBounds data_bounds, MlirTpuInsertionPoint insertion_point, - MlirLocation location, int generation, MlirTpuI64TargetTuple target_shape); - -MLIR_CAPI_EXPORTED MlirAttribute mlirTpuVregDataBoundsGetSublaneMask( - MlirTpuVregDataBounds data_bounds, MlirContext ctx, - MlirTpuI64TargetTuple target_shape); - -// vals are copied, ownership is not stolen. -MLIR_CAPI_EXPORTED MlirOperation -mlirTpuAssemble(MlirTpuInsertionPoint insertion_point, MlirType vector_type, - MlirTpuVectorLayout layout, MlirTpuValueArray vals, - MlirTpuI64TargetTuple target_shape); - -// Returns null on failure -// Caller owns the returned object and is responsible for calling free on shape -// and vals -MLIR_CAPI_EXPORTED MlirTpuValueArray mlirTpuDisassemble( - MlirTpuInsertionPoint insertion_point, MlirTpuVectorLayout layout, - MlirValue val, MlirTpuI64TargetTuple target_shape); - -MLIR_CAPI_EXPORTED MlirLogicalResult -mlirTpuApplyLayoutOp(MlirTpuApplyVectorLayoutContext ctx, MlirOperation op); +MLIR_CAPI_EXPORTED void mlirTpuRegisterMosaicSerdePass(); -// Returns null on failure -MLIR_CAPI_EXPORTED MlirValue -mlirTpuRelayout(MlirTpuInsertionPoint insertion_point, MlirValue val, - MlirTpuVectorLayout src, MlirTpuVectorLayout dst, - MlirTpuApplyVectorLayoutContext ctx); +MLIR_CAPI_EXPORTED MlirType mlirTpuFloat8EXMYTypeGetUnderlyingType( + MlirType exmy_type); +MLIR_CAPI_EXPORTED bool mlirTpuIsAFloat8EXMYType(MlirType type); -MLIR_CAPI_EXPORTED void mlirTpuRegisterMosaicSerdePass(); +MLIR_CAPI_EXPORTED MlirType mlirTpuFloat8EXMYTypeGet( + MlirContext ctx, MlirType exmy_type); #ifdef __cplusplus -} +} // extern "C" #endif -#endif // JAXLIB_MOSAIC_DIALECT_TPU_INTEGRATIONS_C_TPU_DIALECT_H_ +#endif // JAXLIB_MOSAIC_DIALECT_TPU_INTEGRATIONS_C_TPU_DIALECT_H_ diff --git a/jaxlib/mosaic/dialect/tpu/layout.cc b/jaxlib/mosaic/dialect/tpu/layout.cc index 172f2e91b41f..127e69654b94 100644 --- a/jaxlib/mosaic/dialect/tpu/layout.cc +++ b/jaxlib/mosaic/dialect/tpu/layout.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -27,9 +26,11 @@ limitations under the License. #include #include +#include "absl/log/check.h" #include "llvm/ADT/Hashing.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/MathExtras.h" +#include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" @@ -41,49 +42,10 @@ limitations under the License. #include "mlir/IR/ValueRange.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "absl/log/check.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/util.h" namespace mlir::tpu { - -bool RectangularVregBounds::maskVariesAlong( - const Direction direction, - const std::array target_shape) const { - switch (direction) { - case Direction::kSublanes: - return starts_[0] != 0 || ends_[0] != target_shape[0]; - case Direction::kLanes: - return starts_[1] != 0 || ends_[1] != target_shape[1]; - case Direction::kSubelements: - return false; - } -} - -FailureOr> RectangularVregBounds::getVectorMask( - OpBuilder& builder, const Location loc, const int /*generation*/, - const std::array target_shape) const { - auto boundIdxConst = std::bind(IdxConst, std::placeholders::_1, builder, loc); - return cast>( - builder - .create( - loc, VectorType::get(target_shape, builder.getI1Type()), - /*low=*/ - ValueRange{boundIdxConst(starts_[0]), boundIdxConst(starts_[1])}, - /*high=*/ - ValueRange{boundIdxConst(ends_[0]), boundIdxConst(ends_[1])}) - .getResult()); -} - -DenseBoolArrayAttr RectangularVregBounds::getSublaneMask( - MLIRContext* mlir_ctx, const std::array target_shape) const { - SmallVector sublane_mask(target_shape[0], false); - for (int64_t i = starts_[0]; i < ends_[0]; ++i) { - sublane_mask[i] = true; - } - return DenseBoolArrayAttr::get(mlir_ctx, sublane_mask); -} - namespace { // Represents a subset of a (packed) 1D vector register. @@ -145,8 +107,8 @@ class SingleRowVRegBounds : public VRegDataBounds { } const auto i32_vreg = VectorType::get(target_shape, builder.getI32Type()); const auto getI32VregConstant = [&](const int32_t v) { - return builder.create( - loc, i32_vreg, DenseElementsAttr::get(i32_vreg, v)); + return arith::ConstantOp::create(builder, loc, i32_vreg, + DenseElementsAttr::get(i32_vreg, v)); }; if (layout_.bitwidth() != 32 && (start_offset_ % (target_shape[1] * layout_.packing()) != 0 || @@ -155,15 +117,15 @@ class SingleRowVRegBounds : public VRegDataBounds { } const Value start = getI32VregConstant(start_offset_ / layout_.packing()); const Value end = getI32VregConstant(stop_offset_ / layout_.packing()); - const Value iota = builder.create(loc, i32_vreg, nullptr); + const Value iota = + tpu::IotaOp::create(builder, loc, i32_vreg, ArrayRef{0, 1}); return cast>( - builder - .create( - loc, - builder.create(loc, arith::CmpIPredicate::sge, - iota, start), - builder.create(loc, arith::CmpIPredicate::slt, - iota, end)) + arith::AndIOp::create( + builder, loc, + arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::sge, iota, + start), + arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::slt, iota, + end)) .getResult()); } @@ -189,208 +151,6 @@ class SingleRowVRegBounds : public VRegDataBounds { int64_t stop_offset_; }; -// Represents the data bounds within a vector register with tiled and -// potentially packed data. -// -// Note that the (packed) sublane offset from start_offset and (packed) sublane -// bound from end_offsets apply to all tiles within a vreg. On the other hand, -// the lane offset from start_offset only applies to the first tile, while -// lane bound from end_offset only applies to the last used tile. -// -// Attributes: -// layout: The layout of the value, mainly used for its bitwidth and tiling. -// Note that the layout offsets SHOULD NOT be used. -// num_tiles: The number of tiles at the beginning of the vreg that contain -// actual data. -// start_offsets: The lane and (packed) sublane offset within the first tile. -// end_offsets: The lane and (packed) sublane offset within the last used -// tile. -class TiledRectangularVregBounds : public VRegDataBounds { - public: - TiledRectangularVregBounds(const VectorLayout& layout, - const int64_t num_tiles, - const std::array start_offsets, - const std::array end_offsets, - const std::array target_shape) - : layout_(layout), - num_tiles_(num_tiles), - start_offsets_(start_offsets), - end_offsets_(end_offsets) { - CHECK(layout_.tiling()[1] == target_shape[1]); - CHECK(0 < num_tiles_ && num_tiles_ <= layout.tilesPerVreg(target_shape)); - for (auto [o, t] : llvm::zip(start_offsets_, layout_.tiling())) { - CHECK(0 <= o && o < t); - } - for (auto [o, t] : llvm::zip(end_offsets_, layout_.tiling())) { - CHECK(0 <= o && o <= t); - } - } - - bool usesAllTiles(const std::array target_shape) const { - return num_tiles_ == layout_.tilesPerVreg(target_shape); - } - - // See base class. - bool maskVariesAlong( - const Direction direction, - const std::array target_shape) const override { - switch (direction) { - case Direction::kSublanes: - return !usesAllTiles(target_shape) || start_offsets_[0] != 0 || - end_offsets_[0] != layout_.tiling()[0]; - case Direction::kLanes: - return start_offsets_[1] != 0 || end_offsets_[1] != layout_.tiling()[1]; - case Direction::kSubelements: - return start_offsets_[0] % layout_.packing() != 0 || - end_offsets_[0] % layout_.packing() != 0; - } - } - - // See base class. - FailureOr> getVectorMask( - OpBuilder& builder, const Location loc, const int generation, - const std::array target_shape) const override { - const IntegerType i1 = builder.getI1Type(); - FAILUREOR_ASSIGN_OR_RETURN( - const VectorType mask_vreg_ty, [&]() -> FailureOr { - // I'm pretty sure this works for all bitwidths, but it's untested. - if (maskVariesAlong(Direction::kSubelements, target_shape)) { - if (layout_.packing() != 2) { - // TODO(b/300082350): Generalize this - return emitError(loc, "Not implemented: packing != 2"); - } - // For older TPUs, we virtualize masking - if (generation < 4) { - return VectorType::get(target_shape, i1); - } else { - return VectorType::get( - {target_shape[0], target_shape[1], layout_.packing()}, i1); - } - } - return VectorType::get(target_shape, i1); - }()); - if (isComplete(target_shape)) { - return cast>( - builder - .create( - loc, mask_vreg_ty, - DenseElementsAttr::get(mask_vreg_ty, - builder.getBoolAttr(true))) - .getResult()); - } - Value mask = nullptr; - CHECK_GE(num_tiles_, 0); - const int packing = layout_.packing(); - const int64_t start_sub = start_offsets_[0] / packing; - const int64_t end_sub = llvm::divideCeil(end_offsets_[0], packing); - CHECK_LE(0, start_sub); - CHECK_LT(start_sub, end_sub); - CHECK_LE(end_sub, target_shape[0]); - const int64_t sublanes_per_tile = layout_.sublanesPerTile(target_shape); - for (int64_t tile = 0; tile < num_tiles_; ++tile) { - const int64_t sublane_offset = sublanes_per_tile * tile; - const int64_t row_offset = sublane_offset * layout_.packing(); - const int64_t start_lane = tile == 0 ? start_offsets_[1] : 0; - const int64_t end_lane = - tile == num_tiles_ - 1 ? end_offsets_[1] : target_shape[1]; - CHECK_LE(0, start_lane); - CHECK_LT(start_lane, end_lane); - CHECK_LE(end_lane, target_shape[1]); - auto boundIdxConst = - std::bind(IdxConst, std::placeholders::_1, builder, loc); - // TODO(apaszke): For loads/stores whole sublanes are covered by the - // sublane mask, so we can focus only on lanes and partial sublanes. - Value tile_mask = builder.create( - loc, mask_vreg_ty, - ValueRange{boundIdxConst(sublane_offset + start_sub), - boundIdxConst(start_lane)}, - ValueRange{boundIdxConst(sublane_offset + end_sub), - boundIdxConst(end_lane)}); - if (maskVariesAlong(Direction::kSubelements, target_shape)) { - int64_t start_row = start_offsets_[0] + row_offset; - int64_t end_row = end_offsets_[0] + row_offset; - if (generation >= 4) { - // Only use non-trivial start/end if they don't fall on sublane - // boundary. Otherwise CreateMaskOp already does the right thing. This - // lets us use cheaper instruction sequences on TPUv4. - if (start_offsets_[0] % layout_.packing() == 0) { - start_row = 0; - } - if (end_offsets_[0] % layout_.packing() == 0) { - end_row = target_shape[0] * layout_.packing(); - } - auto submask = builder.create( - loc, mask_vreg_ty, start_row, end_row); - tile_mask = builder.create(loc, tile_mask, submask); - } else { // generation < 4 - const auto getMaskCst = [&](const uint64_t v) { - const auto int_mask_ty = - VectorType::get(target_shape, builder.getI32Type()); - return builder.create( - loc, int_mask_ty, - DenseElementsAttr::get( - int_mask_ty, builder.getIntegerAttr(builder.getI32Type(), - APInt(32, v)))); - }; - tile_mask = builder.create( - loc, tile_mask, getMaskCst(0xFFFFFFFF), getMaskCst(0)); - if (start_row % 2 != 0) { - auto row_mask = builder.create( - loc, mask_vreg_ty, - ValueRange{boundIdxConst(start_row / 2), boundIdxConst(0)}, - ValueRange{boundIdxConst(start_row / 2 + 1), - boundIdxConst(target_shape[1])}); - auto row_bitmask = builder.create( - loc, row_mask, getMaskCst(0xFFFF0000), getMaskCst(0xFFFFFFFF)); - tile_mask = - builder.create(loc, tile_mask, row_bitmask); - } - if (end_row % 2 != 0) { - auto row_mask = builder.create( - loc, mask_vreg_ty, - ValueRange{boundIdxConst(end_row / 2), boundIdxConst(0)}, - ValueRange{boundIdxConst(end_row / 2 + 1), - boundIdxConst(target_shape[1])}); - auto row_bitmask = builder.create( - loc, row_mask, getMaskCst(0xFFFF), getMaskCst(0xFFFFFFFF)); - tile_mask = - builder.create(loc, tile_mask, row_bitmask); - } - } - } - mask = mask == nullptr - ? tile_mask - : builder.create(loc, tile_mask, mask); - } - CHECK(mask != nullptr); - return cast>(mask); - } - - // See base class - DenseBoolArrayAttr getSublaneMask( - MLIRContext* mlir_ctx, - const std::array target_shape) const override { - SmallVector mask(target_shape[0], false); - const int64_t start = start_offsets_[0] / layout_.packing(); - const int64_t end = llvm::divideCeil(end_offsets_[0], layout_.packing()); - const int64_t sublanes_per_tile = layout_.sublanesPerTile(target_shape); - const int64_t sublane_bound = num_tiles_ * sublanes_per_tile; - for (int64_t sub = 0; sub < sublane_bound; sub += sublanes_per_tile) { - for (int64_t i = sub + start; i < sub + end; ++i) { - CHECK(!mask[i]); - mask[i] = true; - } - } - return DenseBoolArrayAttr::get(mlir_ctx, mask); - } - - private: - VectorLayout layout_; - int64_t num_tiles_; - std::array start_offsets_; - std::array end_offsets_; -}; - mlir::ParseResult parseOffset(StringRef* data, std::optional* result) { int64_t int_result; if (data->consume_front("*")) { @@ -409,9 +169,175 @@ std::array nativeTiling(const int8_t bitwidth, const int packing = 32 / bitwidth; return {target_shape[0] * packing, target_shape[1]}; } - } // namespace +bool TiledRectangularVregBounds::usesAllTiles( + const std::array target_shape) const { + return start_offsets_[1] / layout_.tiling()[1] == 0 && + llvm::divideCeil(end_offsets_[1], layout_.tiling()[1]) == + layout_.tilesPerVreg(target_shape); +} + +// See base class. +bool TiledRectangularVregBounds::maskVariesAlong( + const Direction direction, const std::array target_shape) const +/*override*/ { + switch (direction) { + case Direction::kSublanes: + return !usesAllTiles(target_shape) || start_offsets_[0] != 0 || + end_offsets_[0] != layout_.tiling()[0]; + case Direction::kLanes: + return start_offsets_[1] % layout_.tiling()[1] != 0 || + end_offsets_[1] % layout_.tiling()[1] != 0; + case Direction::kSubelements: + return start_offsets_[0] % layout_.packing() != 0 || + end_offsets_[0] % layout_.packing() != 0; + } +} + +// See base class. +FailureOr> TiledRectangularVregBounds::getVectorMask( + OpBuilder& builder, const Location loc, const int generation, + const std::array target_shape) const /*override*/ { + const int8_t bitwidth = layout_.bitwidth(); + const int packing = layout_.packing(); + const int max_subelems = generation < 4 ? 1 : generation < 5 ? 2 : 4; + const IntegerType i1 = builder.getI1Type(); + const VectorType mask_vreg_ty = [&]() { + if (maskVariesAlong(Direction::kSubelements, target_shape)) { + // When CreateSubelementMask isn't supported, we virtualize masking. + if (packing > max_subelems) { + return VectorType::get(target_shape, i1); + } else { + return VectorType::get({target_shape[0], target_shape[1], packing}, i1); + } + } + return VectorType::get(target_shape, i1); + }(); + if (isComplete(target_shape)) { + return cast>( + arith::ConstantOp::create( + builder, loc, mask_vreg_ty, + DenseElementsAttr::get(mask_vreg_ty, builder.getBoolAttr(true))) + .getResult()); + } + Value mask = nullptr; + const int64_t start_sub = start_offsets_[0] / packing; + const int64_t end_sub = llvm::divideCeil(end_offsets_[0], packing); + CHECK_LE(0, start_sub); + CHECK_LT(start_sub, end_sub); + CHECK_LE(end_sub, target_shape[0]); + const int64_t sublanes_per_tile = layout_.sublanesPerTile(target_shape); + const int64_t start_tile = start_offsets_[1] / layout_.tiling()[1]; + const int64_t end_tile = + llvm::divideCeil(end_offsets_[1], layout_.tiling()[1]); + for (int64_t tile = start_tile; tile < end_tile; ++tile) { + const int64_t sublane_offset = sublanes_per_tile * tile; + const int64_t row_offset = sublane_offset * layout_.packing(); + const int64_t start_lane = + tile == start_tile ? start_offsets_[1] % layout_.tiling()[1] : 0; + const int64_t end_lane = + tile == end_tile - 1 ? positiveMod(end_offsets_[1], layout_.tiling()[1]) + : target_shape[1]; + CHECK_LE(0, start_lane); + CHECK_LT(start_lane, end_lane); + CHECK_LE(end_lane, target_shape[1]); + auto boundIdxConst = + std::bind(IdxConst, std::placeholders::_1, builder, loc); + // TODO(apaszke): For loads/stores whole sublanes are covered by the sublane + // mask, so we can focus only on lanes and partial sublanes. + Value tile_mask = CreateMaskOp::create( + builder, loc, mask_vreg_ty, + ValueRange{boundIdxConst(sublane_offset + start_sub), + boundIdxConst(start_lane)}, + ValueRange{boundIdxConst(sublane_offset + end_sub), + boundIdxConst(end_lane)}); + if (maskVariesAlong(Direction::kSubelements, target_shape)) { + int64_t start_row = start_offsets_[0] + row_offset; + int64_t end_row = end_offsets_[0] + row_offset; + if (packing <= max_subelems) { + // Only use non-trivial start/end if they don't fall on sublane + // boundary. Otherwise CreateMaskOp already does the right thing. This + // lets us use cheaper instruction sequences on TPUv4. + if (start_offsets_[0] % packing == 0) { + start_row = 0; + } + if (end_offsets_[0] % packing == 0) { + end_row = target_shape[0] * packing; + } + auto submask = tpu::CreateSubelementMaskOp::create( + builder, loc, mask_vreg_ty, start_row, end_row); + tile_mask = arith::AndIOp::create(builder, loc, tile_mask, submask); + } else { // packing > max_subelems + const auto getMaskCst = [&](const uint64_t v) { + const auto int_mask_ty = + VectorType::get(target_shape, builder.getI32Type()); + return arith::ConstantOp::create( + builder, loc, int_mask_ty, + DenseElementsAttr::get( + int_mask_ty, + builder.getIntegerAttr(builder.getI32Type(), APInt(32, v)))); + }; + tile_mask = arith::SelectOp::create( + builder, loc, tile_mask, getMaskCst(0xFFFFFFFF), getMaskCst(0)); + if (const int64_t row_in_sublane = start_row % packing; + row_in_sublane != 0) { + auto row_mask = tpu::CreateMaskOp::create( + builder, loc, mask_vreg_ty, + ValueRange{boundIdxConst(start_row / packing), boundIdxConst(0)}, + ValueRange{boundIdxConst(start_row / packing + 1), + boundIdxConst(target_shape[1])}); + auto row_bitmask = arith::SelectOp::create( + builder, loc, row_mask, + getMaskCst(0xFFFFFFFF << row_in_sublane * bitwidth), + getMaskCst(0xFFFFFFFF)); + tile_mask = + arith::AndIOp::create(builder, loc, tile_mask, row_bitmask); + } + if (const int64_t row_in_sublane = end_row % packing; + row_in_sublane != 0) { + auto row_mask = tpu::CreateMaskOp::create( + builder, loc, mask_vreg_ty, + ValueRange{boundIdxConst(end_row / packing), boundIdxConst(0)}, + ValueRange{boundIdxConst(end_row / packing + 1), + boundIdxConst(target_shape[1])}); + auto row_bitmask = arith::SelectOp::create( + builder, loc, row_mask, + getMaskCst(0xFFFFFFFFu >> (packing - row_in_sublane) * bitwidth), + getMaskCst(0xFFFFFFFF)); + tile_mask = + arith::AndIOp::create(builder, loc, tile_mask, row_bitmask); + } + } + } + mask = mask == nullptr + ? tile_mask + : arith::OrIOp::create(builder, loc, tile_mask, mask); + } + CHECK(mask != nullptr); + return cast>(mask); +} + +// See base class +DenseBoolArrayAttr TiledRectangularVregBounds::getSublaneMask( + MLIRContext* mlir_ctx, const std::array target_shape) const +/*override*/ { + SmallVector mask(target_shape[0], false); + const int64_t start = start_offsets_[0] / layout_.packing(); + const int64_t end = llvm::divideCeil(end_offsets_[0], layout_.packing()); + const int64_t sublanes_per_tile = layout_.sublanesPerTile(target_shape); + const int64_t start_tile = start_offsets_[1] / layout_.tiling()[1]; + const int64_t end_tile = + llvm::divideCeil(end_offsets_[1], layout_.tiling()[1]); + for (int64_t tile = start_tile; tile < end_tile; ++tile) { + for (int64_t i = start; i < end; ++i) { + CHECK(!mask[tile * sublanes_per_tile + i]); + mask[tile * sublanes_per_tile + i] = true; + } + } + return DenseBoolArrayAttr::get(mlir_ctx, mask); +} + std::tuple, std::optional, int64_t, int64_t, int8_t, VectorLayout::ImplicitDim> VectorLayout::as_tuple() const { @@ -448,8 +374,12 @@ SmallVector VectorLayout::tileArrayShape( int64_t& second_minor = *(src_shape.end() - 2); int64_t& minor = *(src_shape.end() - 1); second_minor = - llvm::divideCeil(offsets_[0].value_or(0) + second_minor, vreg_slice[0]); - minor = llvm::divideCeil(offsets_[1].value_or(0) + minor, vreg_slice[1]); + offsets_[0].has_value() + ? llvm::divideCeil(*offsets_[0] + second_minor, vreg_slice[0]) + : 1; + minor = offsets_[1].has_value() + ? llvm::divideCeil(*offsets_[1] + minor, vreg_slice[1]) + : 1; if (!res_is_implicit) { CHECK_GE(src_shape.size(), 2); eraseImplicit(src_shape); @@ -464,6 +394,13 @@ std::unique_ptr VectorLayout::tileDataBounds( // TODO(apaszke): allow_replicated could have been generalized to specify // what action should be taken when a REPLICATED offset is encountered. // Right now it either disallows replication, or selects the whole dimension. + for (const int i : {0, 1}) { + if (!allow_replicated[i] && !offsets_[i].has_value()) { + emitError(UnknownLoc::get(mlir_ctx), "Unexpected replicated offset"); + return nullptr; + } + } + const std::array vreg_slice = vregSlice(target_shape); const std::array tiled_idxs = getImplicitTiledDims(idxs, 0); const int64_t s = tiled_idxs[0]; const int64_t l = tiled_idxs[1]; @@ -471,81 +408,31 @@ std::unique_ptr VectorLayout::tileDataBounds( tileArrayImplicitShape(full_shape, target_shape); const int64_t ns = *(tiles_implicit_shape.end() - 2); const int64_t nl = *(tiles_implicit_shape.end() - 1); - const std::array shape_tiled_dims = + const std::array tiled_ishape = getImplicitTiledDims(full_shape, 1); - const int64_t is = shape_tiled_dims[0]; - const int64_t il = shape_tiled_dims[1]; - - if (!hasNaturalTopology(target_shape)) { - if (!offsets_[0].has_value() || !offsets_[1].has_value()) { - emitError(UnknownLoc::get(mlir_ctx), - "Not implemented: non-natural topology with replication"); - return nullptr; - } - const int64_t so = *offsets_[0]; - const int64_t lo = *offsets_[1]; - if (tiling_[0] == 1 && tiling_[1] % target_shape[1] == 0 && - implicit_dim_ == ImplicitDim::kSecondMinor) { - const int64_t values_per_vreg = - target_shape[0] * target_shape[1] * packing(); - const int64_t start_offset = l == 0 ? lo : 0; - const int64_t end_offset = - l == nl - 1 ? lo + il - l * values_per_vreg : values_per_vreg; - return std::make_unique(*this, start_offset, - end_offset, target_shape); - } - if (tiling_[1] != target_shape[1]) { - emitError(UnknownLoc::get(mlir_ctx), - "Not implemented: Unaligned tiling on minormost dimension"); - return nullptr; - } - const int64_t start_sublanes = s == 0 ? so : 0; - const int64_t start_lanes = l == 0 ? lo : 0; - const int64_t end_sublanes = - s == ns - 1 ? (so + is - 1) % tiling_[0] + 1 : tiling_[0]; - const int64_t end_lanes = - l == nl - 1 ? (lo + il - 1) % tiling_[1] + 1 : tiling_[1]; - const int64_t tiles_per_vreg = tilesPerVreg(target_shape); - const int64_t minormost_tiles = llvm::divideCeil(lo + il, tiling_[1]); - const int64_t num_tiles = - l == nl - 1 && minormost_tiles % tiles_per_vreg != 0 - ? minormost_tiles % tiles_per_vreg - : tiles_per_vreg; - return std::make_unique( - *this, num_tiles, std::array{start_sublanes, start_lanes}, - std::array{end_sublanes, end_lanes}, target_shape); - } - // TODO(apaszke): Remove this path in favor of TiledVRegBounds - const std::array shift = {offsets_[0].value_or(0), - offsets_[1].value_or(0)}; - const int64_t sb = s == 0 ? shift[0] : 0; - const int64_t lb = l == 0 ? shift[1] : 0; - int64_t se = target_shape[0]; - int64_t le = target_shape[1]; - // First, deal with sublanes. - if (!offsets_[0].has_value()) { - if (!allow_replicated[0]) { - emitError(UnknownLoc::get(mlir_ctx), "Unexpected replicated offset"); - return nullptr; - } - // Otherwise, do nothing. We take the full slice. - } else if (s == ns - 1) { - se = shift[0] + is - s * target_shape[0]; - } - // Now, we deal with lanes. - if (!offsets_[1].has_value()) { - if (!allow_replicated[1]) { - emitError(UnknownLoc::get(mlir_ctx), "Unexpected replicated offset"); - return nullptr; - } - // Otherwise, do nothing. We take the full slice. - } else if (l == nl - 1) { - le = shift[1] + il - l * target_shape[1]; - } - CHECK_LT(sb, se); - CHECK_LT(lb, le); - return std::make_unique( - std::array{sb, lb}, std::array{se, le}); + // The starts and ends of the data within the vreg slice: + const std::array starts = { + offsets_[0] && s == 0 ? *offsets_[0] : 0, + offsets_[1] && l == 0 ? *offsets_[1] : 0}; + const std::array ends = { + offsets_[0] && s == ns - 1 + ? positiveMod(*offsets_[0] + tiled_ishape[0], vreg_slice[0]) + : vreg_slice[0], + offsets_[1] && l == nl - 1 + ? positiveMod(*offsets_[1] + tiled_ishape[1], vreg_slice[1]) + : vreg_slice[1]}; + + if (tiling_[0] == 1 && tiling_[1] % target_shape[1] == 0) { + return std::make_unique(*this, starts[1], ends[1], + target_shape); + } + if (tiling_[1] != target_shape[1]) { + emitError(UnknownLoc::get(mlir_ctx), + "Not implemented: Unaligned tiling on minormost dimension"); + return nullptr; + } + return std::make_unique(*this, starts, ends, + target_shape); } bool VectorLayout::generalizes( @@ -605,10 +492,46 @@ bool VectorLayout::generalizes( } template -void VectorLayout::print(Stream& os) const { - os << static_cast(bitwidth_) << ",{"; +Stream& printImplicitDim(Stream& os, VectorLayout::ImplicitDim dim) { + switch (dim) { + case VectorLayout::ImplicitDim::kNone: + os << "none"; + break; + case VectorLayout::ImplicitDim::kMinor: + os << "-1"; + break; + case VectorLayout::ImplicitDim::kSecondMinor: + os << "-2"; + break; + case VectorLayout::ImplicitDim::kMinorAndSecondMinor: + os << "-2,-1"; + break; + } + return os; +} + +std::ostream& operator<<(std::ostream& os, VectorLayout::ImplicitDim dim) { + return printImplicitDim(os, dim); +} + +llvm::raw_ostream& operator<<(llvm::raw_ostream& os, + VectorLayout::ImplicitDim dim) { + return printImplicitDim(os, dim); +} + +mlir::Diagnostic& operator<<(mlir::Diagnostic& diag, + VectorLayout::ImplicitDim dim) { + return printImplicitDim(diag, dim); +} + +template +static void printVectorLayout(Stream& os, const int32_t bitwidth, + const VectorLayout::ImplicitDim implicit_dim, + const LayoutOffsets offsets, + const std::array& tiling) { + os << static_cast(bitwidth) << ",{"; bool first = true; - for (auto o : offsets_) { + for (auto o : offsets) { if (first) { first = false; } else { @@ -620,17 +543,37 @@ void VectorLayout::print(Stream& os) const { os << *o; } } - os << "},(" << tiling_[0] << ',' << tiling_[1] << ")"; - if (implicit_dim_ == ImplicitDim::kMinor) { - os << ",-1"; - } else if (implicit_dim_ == ImplicitDim::kSecondMinor) { - os << ",-2"; + os << "},(" << tiling[0] << ',' << tiling[1] << ")"; + if (implicit_dim != VectorLayout::ImplicitDim::kNone) { + os << "," << implicit_dim; } } +void VectorLayout::print(llvm::raw_ostream& os) const { + printVectorLayout(os, bitwidth_, implicit_dim_, offsets_, tiling_); +} + +void VectorLayout::print(std::ostream& os) const { + printVectorLayout(os, bitwidth_, implicit_dim_, offsets_, tiling_); +} + +void VectorLayout::print(mlir::Diagnostic& diag) const { + printVectorLayout(diag, bitwidth_, implicit_dim_, offsets_, tiling_); +} + std::optional VectorLayout::join(const VectorLayout& l, const VectorLayout& r, ArrayRef shape) { + auto is_fully_replicated = [&](const VectorLayout& layout) { + const LayoutOffsets& offsets = layout.getCanonicalOffsets(shape); + return !offsets[0] && !offsets[1]; + }; + if (is_fully_replicated(l) && l.layout_rank() >= r.layout_rank()) { + return r; + } + if (is_fully_replicated(r) && r.layout_rank() >= l.layout_rank()) { + return l; + } if (l.bitwidth_ != r.bitwidth_ || l.tiling_ != r.tiling_) { return std::nullopt; } @@ -662,7 +605,9 @@ std::optional VectorLayout::parse(StringRef* data) { local.consumeInteger(10, tiling[1]) || !local.consume_front(")")) { return std::nullopt; } - if (local.consume_front(",-1")) { + if (local.consume_front(",-2,-1")) { + implicit_dim = ImplicitDim::kMinorAndSecondMinor; + } else if (local.consume_front(",-1")) { implicit_dim = ImplicitDim::kMinor; } else if (local.consume_front(",-2")) { implicit_dim = ImplicitDim::kSecondMinor; @@ -705,31 +650,6 @@ llvm::hash_code hash_value(const VectorLayout& layout) { return llvm::hash_value(layout.as_tuple()); } -template -Stream& printImplicitDim(Stream& os, VectorLayout::ImplicitDim dim) { - switch (dim) { - case VectorLayout::ImplicitDim::kNone: - os << "none"; - break; - case VectorLayout::ImplicitDim::kMinor: - os << "-1"; - break; - case VectorLayout::ImplicitDim::kSecondMinor: - os << "-2"; - break; - } - return os; -} - -std::ostream& operator<<(std::ostream& os, VectorLayout::ImplicitDim dim) { - return printImplicitDim(os, dim); -} - -mlir::Diagnostic& operator<<(mlir::Diagnostic& diag, - VectorLayout::ImplicitDim dim) { - return printImplicitDim(diag, dim); -} - std::optional parseLayout(mlir::AsmParser& parser) { std::string layout_str; if (failed(parser.parseString(&layout_str))) { diff --git a/jaxlib/mosaic/dialect/tpu/layout.h b/jaxlib/mosaic/dialect/tpu/layout.h index 2c45be62fa7d..9002db62adf5 100644 --- a/jaxlib/mosaic/dialect/tpu/layout.h +++ b/jaxlib/mosaic/dialect/tpu/layout.h @@ -18,12 +18,14 @@ limitations under the License. #include #include +#include #include -#include #include #include #include +#include "absl/log/check.h" +#include "llvm/ADT/Hashing.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/bit.h" #include "llvm/Support/ErrorHandling.h" @@ -38,7 +40,6 @@ limitations under the License. #include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "absl/log/check.h" namespace mlir::tpu { @@ -84,41 +85,6 @@ struct VRegDataBounds { MLIRContext *ctxt, std::array target_shape) const = 0; }; -// Represents a rectangular region of data within a vector register. -// -// This class is very limited in its power and should only be used for 32-bit -// values with native tiling. -// -// Attributes: -// bounds: A TargetTuple of slices encoding the bounds of the rectangular -// data region. -// TODO(tlongeri): Can this be removed in favor of the more general -// TiledRectangularVregBounds? -class RectangularVregBounds : public VRegDataBounds { - public: - RectangularVregBounds(const std::array starts, - const std::array ends) - : starts_(starts), ends_(ends) {} - - // See base class. - bool maskVariesAlong(Direction direction, - std::array target_shape) const override; - - // See base class. - FailureOr> getVectorMask( - OpBuilder &builder, Location loc, int generation, - std::array target_shape) const override; - - // See base class. - DenseBoolArrayAttr getSublaneMask( - MLIRContext *mlir_ctxt, - std::array target_shape) const override; - - private: - std::array starts_; - std::array ends_; -}; - // VectorLayout describes a mapping of an arbitrarily sized values into vregs. // // First, let us consider the simplest case, when implicit_dim is None, bitwidth @@ -205,8 +171,8 @@ class RectangularVregBounds : public VRegDataBounds { // tiling: The tiling used to lay out values (see the XLA docs). For values of // bitwidth < 32, an implicit (32 / bitwidth, 1) tiling is appended to the // one specified as an attribute. -// implicit_dim: If specified, the value has an implicit dim inserted in -// either minormost or second minormost position. +// implicit_dim: If specified, the value has implicit dims inserted in the +// minormost and/or second minormost position. // // Note: There is a special case when VectorLayout is used for an mlir::Value // of i1 type. In this case, we use it to represent a vmask, which has a smaller @@ -218,11 +184,15 @@ class RectangularVregBounds : public VRegDataBounds { // but we might want to split out a separate class if it gets used more widely. class VectorLayout { public: - enum class ImplicitDim { - kNone = 0, // To make if (implicit_dim) work. - // Also want to do dims[dims.size() - xla::to_underlying(implicit_dim)] + enum class ImplicitDim : unsigned { + // Each bit indicates whether the corresponding dimension is implicit. + // When bit 0 is set, the minor dimension is implicit. + // When bit 1 is set, the second minor dimension is implicit. + // WARNING: This should not be relied on outside of VectorLayout. + kNone = 0, kMinor = 1, kSecondMinor = 2, + kMinorAndSecondMinor = 3, }; VectorLayout(const int8_t bitwidth, const LayoutOffsets offsets, const std::array tiling, @@ -233,16 +203,15 @@ class VectorLayout { implicit_dim_(implicit_dim) { // TODO(b/275751535): Allow more bitwidths. CHECK(llvm::has_single_bit(bitwidth_) && bitwidth_ <= 32); + CHECK_GT(tiling_[0], 0); + CHECK_GT(tiling_[1], 0); + CHECK_GE(offsets_[0].value_or(0), 0); + CHECK_GE(offsets_[1].value_or(0), 0); + CHECK_LT(offsets_[0].value_or(0), tiling_[0]); } static int num_implicit_dims(const ImplicitDim implicit_dim) { - switch (implicit_dim) { - case ImplicitDim::kNone: - return 0; - case ImplicitDim::kMinor: - case ImplicitDim::kSecondMinor: - return 1; - } + return llvm::popcount(static_cast(implicit_dim)); } // The number of non-implicit dimensions that are tiled. @@ -252,9 +221,7 @@ class VectorLayout { int8_t bitwidth() const { return bitwidth_; } const LayoutOffsets &offsets() const { return offsets_; } - const LayoutOffsets getCanonicalOffsets( - const ArrayRef shape, - const std::array target_shape) const { + LayoutOffsets getCanonicalOffsets(const ArrayRef shape) const { // For (1, n) tiling with a single row, 2nd minor replication does not // change anything about the layout - it is equivalent to an offset of 0. // We choose a replicated offset as "canonical". @@ -318,9 +285,13 @@ class VectorLayout { case ImplicitDim::kNone: break; case ImplicitDim::kMinor: + vec.push_back(value); + break; case ImplicitDim::kSecondMinor: - vec.insert(vec.end() - (static_cast(implicit_dim_) - 1), - value); + vec.insert(vec.end() - 1, value); + break; + case ImplicitDim::kMinorAndSecondMinor: + vec.append(2, value); break; } } @@ -332,8 +303,13 @@ class VectorLayout { case ImplicitDim::kNone: break; case ImplicitDim::kMinor: + vec.pop_back(); + break; case ImplicitDim::kSecondMinor: - vec.erase(vec.end() - static_cast(implicit_dim_)); + vec.erase(vec.end() - 2); + break; + case ImplicitDim::kMinorAndSecondMinor: + vec.pop_back_n(2); break; } } @@ -349,6 +325,29 @@ class VectorLayout { return {*(arr.end() - 1), implicit_value}; case ImplicitDim::kSecondMinor: return {implicit_value, *(arr.end() - 1)}; + case ImplicitDim::kMinorAndSecondMinor: + return {implicit_value, implicit_value}; + } + } + // Returns the dimension of the implicit shape that corresponds to the given + // dimension of a non-implicit shape with the given `rank`. + static int64_t toImplicitDimension(const ImplicitDim implicit_dim, + const int64_t rank, int64_t dimension) { + CHECK_GE(rank, layout_rank(implicit_dim)); + if (rank - layout_rank(implicit_dim) > dimension) { + return dimension; + } + switch (implicit_dim) { + case ImplicitDim::kNone: + return dimension; + case ImplicitDim::kMinor: + CHECK_EQ(dimension, rank - 1); + return rank - 1; + case ImplicitDim::kSecondMinor: + CHECK_EQ(dimension, rank - 1); + return rank; + case ImplicitDim::kMinorAndSecondMinor: + llvm_unreachable("Invalid dimension"); } } @@ -488,8 +487,9 @@ class VectorLayout { other.generalizes(*this, shape, target_shape); } - template - void print(Stream &os) const; + void print(llvm::raw_ostream& os) const; + void print(std::ostream& os) const; + void print(mlir::Diagnostic& diag) const; static std::optional join(const VectorLayout &l, const VectorLayout &r, @@ -529,17 +529,80 @@ class VectorLayout { ImplicitDim implicit_dim_; }; +// Represents the data bounds within a vector register with tiled and +// potentially packed data. +// +// Start and end offsets for data are specified within the vreg slice. +// +// Attributes: +// layout: The layout of the value, mainly used for its bitwidth and tiling. +// Note that the layout offsets SHOULD NOT be used. +// start_offsets: The lane and (packed) sublane offset within the vreg slice. +// end_offsets: The lane and (packed) sublane offset within the vreg slice. +class TiledRectangularVregBounds : public VRegDataBounds { + public: + TiledRectangularVregBounds(const VectorLayout& layout, + const std::array start_offsets, + const std::array end_offsets, + const std::array target_shape) + : layout_(layout), + start_offsets_(start_offsets), + end_offsets_(end_offsets) { + CHECK(layout_.tiling()[1] == target_shape[1]); + const std::array vreg_slice = layout_.vregSlice(target_shape); + for (auto [start, end, vs] : + llvm::zip_equal(start_offsets_, end_offsets_, vreg_slice)) { + CHECK(0 <= start && start < end && end <= vs); + } + } + TiledRectangularVregBounds(const int8_t bitwidth, + const std::array tiling, + const std::array start_offsets, + const std::array end_offsets, + const std::array target_shape) + // The offsets and implicit dim of the layout are ignored. + : TiledRectangularVregBounds( + VectorLayout(bitwidth, LayoutOffsets(), tiling, + VectorLayout::ImplicitDim::kNone), + start_offsets, end_offsets, target_shape) {} + + bool usesAllTiles(std::array target_shape) const; + // See base class. + bool maskVariesAlong(Direction direction, + std::array target_shape) const override; + + // See base class. + FailureOr> getVectorMask( + OpBuilder& builder, Location loc, int generation, + std::array target_shape) const override; + + // See base class + DenseBoolArrayAttr getSublaneMask( + MLIRContext* mlir_ctx, + std::array target_shape) const override; + + private: + // TODO(tlongeri): Replace layout_ with bitwidth_ and tiling_ + VectorLayout layout_; + std::array start_offsets_; + std::array end_offsets_; +}; + using Layout = std::optional; extern const Layout kNoLayout; -std::ostream &operator<<(std::ostream &os, const Layout &v); -llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Layout &v); -llvm::hash_code hash_value(const VectorLayout &layout); -mlir::Diagnostic &operator<<(mlir::Diagnostic &diag, const Layout &v); std::ostream &operator<<(std::ostream &os, VectorLayout::ImplicitDim dim); +llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + VectorLayout::ImplicitDim dim); mlir::Diagnostic &operator<<(mlir::Diagnostic &diag, VectorLayout::ImplicitDim dim); +std::ostream &operator<<(std::ostream &os, const Layout &v); +llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Layout &v); +mlir::Diagnostic &operator<<(mlir::Diagnostic &diag, const Layout &v); + +llvm::hash_code hash_value(const VectorLayout &layout); + std::optional parseLayout(mlir::AsmParser &parser); } // namespace mlir::tpu diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 4b5ed34934d7..22a3ffede006 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -1,4 +1,4 @@ -/* Copyright 2023 The JAX Authors. +/* Copyright 2025 The JAX Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,18 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TPU_ATTRS -#define TPU_ATTRS +#ifndef TPU_BASE +#define TPU_BASE -include "mlir/IR/OpBase.td" -include "mlir/IR/AttrTypeBase.td" -include "mlir/IR/BuiltinAttributeInterfaces.td" +include "mlir/IR/BuiltinAttributes.td" include "mlir/IR/BuiltinTypeInterfaces.td" -include "mlir/IR/EnumAttr.td" -include "mlir/Pass/PassBase.td" -include "mlir/Interfaces/ControlFlowInterfaces.td" -include "mlir/Interfaces/SideEffectInterfaces.td" -include "mlir/Interfaces/InferTypeOpInterface.td" def TPU_Dialect : Dialect { let name = "tpu"; @@ -36,6 +29,8 @@ def TPU_Dialect : Dialect { static std::optional GetCoreTypeAttr(Operation *op); }]; + let hasConstantMaterializer = 1; + let hasCanonicalizer = 1; } class TPU_Attr traits = []> @@ -43,433 +38,13 @@ class TPU_Attr traits = []> let mnemonic = mnemonic_; } -// TODO(b/369418606): Find out the way to verify vreg size. -def TPU_Vreg : Type; - -class TPU_Type traits = []> - : TypeDef { - let mnemonic = mnemonic_; -} - -def TPU_CoreType : I32EnumAttr<"CoreType", "Core type", [ - I32EnumAttrCase<"kTc", 0, "tc">, - I32EnumAttrCase<"kScScalarSubcore", 1, "sc_scalar_subcore">, - I32EnumAttrCase<"kScVectorSubcore", 2, "sc_vector_subcore"> -]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::tpu"; -} - -def TPU_CoreTypeEnum : EnumAttr { - let assemblyFormat = "`<` $value `>`"; -} - -def TPU_PipelineMode : I32EnumAttr<"PipelineMode", "Pipeline mode", [ - I32EnumAttrCase<"kSynchronous", 1, "synchronous">, - I32EnumAttrCase<"kDoubleBuffered", 2, "double_buffered"> - ]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::tpu"; -} - -def TPU_PipelineModeEnum : EnumAttr { - let assemblyFormat = "`<` $value `>`"; -} - -def TPU_SemaphoreType : TPU_Type<"Semaphore", "semaphore", [MemRefElementTypeInterface]>; -def TPU_DMASemaphoreType : TPU_Type<"DMASemaphore", "dma_semaphore", [MemRefElementTypeInterface]>; -def TPU_SomeSemaphoreType : AnyTypeOf<[TPU_SemaphoreType, TPU_DMASemaphoreType]>; - -def TPU_DimensionSemantics : I32EnumAttr<"DimensionSemantics", "Dimension semantics", [ - I32EnumAttrCase<"parallel", 0>, - I32EnumAttrCase<"arbitrary", 1> -]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::tpu"; -} - -def TPU_DimensionSemanticsEnum - : EnumAttr { - let assemblyFormat = "`<` $value `>`"; -} - -// All indices/sizes are in element-space. -// Note that the implementation will require statically provable tile alignment. -def TPU_ElementWindowAttr : TPU_Attr<"ElementWindow", "element_window"> { - // Including low padding, to avoid backwards-incompatible changes once we add it. - let parameters = (ins - ArrayRefParameter<"int64_t", "">:$pad_low, - ArrayRefParameter<"int64_t", "">:$pad_high - ); - let assemblyFormat = "`<` `[` $pad_low `]` `,` `[` $pad_high `]` `>`"; -} - -def TPU_ContractPrecision : I32EnumAttr<"ContractPrecision", "Contraction precision", [ - I32EnumAttrCase<"kBF16", 0, "bf16">, - I32EnumAttrCase<"kFP32", 1, "fp32"> -]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::tpu"; -} - -def TPU_ContractPrecisionEnum - : EnumAttr { - let assemblyFormat = "`<` $value `>`"; -} - -def TPU_PackFormat : I32EnumAttr<"PackFormat", "Pack format", [ - I32EnumAttrCase<"kCompressed", 0, "compressed">, - I32EnumAttrCase<"kInterleaved", 1, "interleaved"> -]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::tpu"; -} - -def TPU_PackFormatEnum : EnumAttr { - let assemblyFormat = "`<` $value `>`"; -} - -def TPU_TiledCase : I32EnumAttrCase<"tiled", 0>; -def TPU_LaneCase : I32EnumAttrCase<"lanes", 1>; -def TPU_SublaneCase : I32EnumAttrCase<"sublanes", 2>; -def TPU_VectorLayoutDim : I32EnumAttr< - "VectorLayoutDim", "", [TPU_TiledCase, TPU_LaneCase, TPU_SublaneCase]>; - -def TPU_VectorLayoutAttr : TPU_Attr<"VectorLayout", "vpad"> { - let description = [{TODO}]; - - let parameters = (ins "Layout":$layout); - let hasCustomAssemblyFormat = 1; -} - -def TPU_TiledLayoutAttr - : TPU_Attr<"TiledLayout", "tiled", - [DeclareAttrInterfaceMethods]> { - let description = [{TODO}]; - let parameters = (ins - ArrayRefParameter<"::xla::Tile", "">:$tiles, - ArrayRefParameter<"int64_t", "">:$tile_strides - ); - - let hasCustomAssemblyFormat = 1; -} - -def TPU_MemorySpace : I32EnumAttr<"MemorySpace", "Memory space", [ - I32EnumAttrCase<"kAny", 4294967295, "any">, - // TODO(apaszke): Rename to kXYZ in C++ - I32EnumAttrCase<"vmem", 0, "vmem">, - I32EnumAttrCase<"smem", 1, "smem">, - I32EnumAttrCase<"kHbm", 2, "hbm">, - I32EnumAttrCase<"kCmem", 3, "cmem">, - I32EnumAttrCase<"kSemaphoreMem", 4, "semaphore_mem"> -]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::tpu"; -} - -def TPU_MemorySpaceEnum - : EnumAttr { - let assemblyFormat = "`<` $value `>`"; -} - -class TPU_Op traits = []> : - Op { -} - -def TPU_ReductionKind : I32EnumAttr<"ReductionKind", "Reduction kind", [ - I32EnumAttrCase<"SUM", 0, "sum">, - I32EnumAttrCase<"MAX", 1, "max">, - I32EnumAttrCase<"MIN", 2, "min"> -]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::tpu"; -} - -def TPU_ReductionKindAttr - : EnumAttr { - let assemblyFormat = "`<` $value `>`"; -} - -def TPU_AllReduceOp : TPU_Op<"all_reduce", [Pure, SameOperandsAndResultType]> { - let arguments = (ins AnyVectorOfNonZeroRank:$input, I64Attr:$dim, TPU_ReductionKindAttr:$kind); - let results = (outs AnyVectorOfNonZeroRank:$output); - let assemblyFormat = [{ - $input attr-dict `:` type($input) - }]; -} - -def TPU_StoreOp : TPU_Op<"store", [AttrSizedOperandSegments]> { - let arguments = (ins - TPU_Vreg:$valueToStore, - AnyType:$base, - Variadic:$indices, - DenseBoolArrayAttr:$sublane_mask, - Optional:$mask, - OptionalAttr:$sublane_stride // In sublane-sized units - ); - let results = (outs); - let assemblyFormat = [{ - $base `[` $indices `]` `,` $valueToStore (`masked` $mask^)? `sublanes` $sublane_mask (`sublane_stride` $sublane_stride^)? attr-dict `:` type($base) `,` type($valueToStore) `,` type($mask) - }]; -} - -def TPU_LoadOp : TPU_Op<"load"> { - let arguments = (ins - AnyType:$base, - Variadic:$indices, - DenseBoolArrayAttr:$sublane_mask, - OptionalAttr:$sublane_stride // In sublane-sized units - ); - let results = (outs TPU_Vreg:$result); - let assemblyFormat = [{ - $base `[` $indices `]` `sublanes` $sublane_mask (`sublane_stride` $sublane_stride^)? attr-dict `:` type($base) `,` type($result) - }]; -} - -// TODO(jevinjiang): migrate tpu.strided_store to general vector store op. -def TPU_VectorStoreOp :TPU_Op<"vector_store", [AttrSizedOperandSegments]> { - let arguments = (ins - AnyVectorOfNonZeroRank:$valueToStore, - AnyMemRef:$base, - Variadic:$indices, - DenseI32ArrayAttr:$strides, - Optional:$mask // Elementwise mask. - ); - let results = (outs); - let assemblyFormat = [{ - $base `[` $indices `]` `,` $valueToStore (`masked` $mask^)? attr-dict `:` type($base) `,` type($valueToStore) `,` type($mask) - }]; - let hasVerifier = 1; -} - -def TPU_StridedLoadOp : TPU_Op<"strided_load"> { - let arguments = (ins - AnyMemRef:$base, - Variadic:$indices, - DenseI32ArrayAttr:$strides - ); - let results = (outs AnyVectorOfNonZeroRank:$result); - let assemblyFormat = [{ - $base `[` $indices `]` attr-dict `:` type($base) `,` type($result) - }]; - let hasVerifier = 1; -} - -def TPU_StridedStoreOp : TPU_Op<"strided_store"> { - let arguments = (ins - AnyVectorOfNonZeroRank:$valueToStore, - AnyMemRef:$base, - Variadic:$indices, - DenseI32ArrayAttr:$strides - ); - let results = (outs); - let assemblyFormat = [{ - $base `[` $indices `]` `,` $valueToStore attr-dict `:` type($base) `,` type($valueToStore) - }]; - let hasVerifier = 1; -} - -def TPU_ShuffledLoadOp : TPU_Op<"shuffled_load"> { - let arguments = (ins - AnyMemRef:$base, - Variadic:$indices, - DenseBoolArrayAttr:$sublane_mask, - DenseI32ArrayAttr:$sublane_offsets - ); - let results = (outs TPU_Vreg:$result); - let assemblyFormat = [{ - $base `[` $indices `]` attr-dict `:` type($base) `,` type($result) - }]; - let hasVerifier = 1; - let hasCanonicalizeMethod = 1; -} - -def TPU_ShuffledStoreOp : TPU_Op<"shuffled_store"> { - let arguments = (ins - TPU_Vreg:$valueToStore, - AnyMemRef:$base, - Variadic:$indices, - DenseBoolArrayAttr:$sublane_mask, - DenseI32ArrayAttr:$sublane_offsets - ); - let results = (outs); - let assemblyFormat = [{ - $base `[` $indices `]` `,` $valueToStore attr-dict `:` type($base) `,` type($valueToStore) - }]; - let hasVerifier = 1; - let hasCanonicalizeMethod = 1; -} - -// TODO(jevinjiang): deprecate to use dynamic_rotate. -def TPU_RotateOp : TPU_Op<"rotate", [Pure, SameOperandsAndResultType]> { - let arguments = (ins - AnyVectorOfNonZeroRank:$value, - SI32Attr:$amount, - SI32Attr:$dimension, - // When the stride is specified, the rotation amount for each index on the - // stride dimension will be (amount + stride * index). - OptionalAttr:$stride, - OptionalAttr:$stride_dimension - ); - let results = (outs AnyVectorOfNonZeroRank:$result); - let assemblyFormat = [{ - $value `by` $amount `dim` $dimension (`stride` $stride `stride_dim` $stride_dimension^)? attr-dict `:` type($value) - }]; - let hasVerifier = 1; -} - -def TPU_DynamicRotateOp : TPU_Op<"dynamic_rotate", [Pure]> { - let arguments = (ins - AnyVectorOfNonZeroRank:$value, - I32:$amount, - SI32Attr:$dimension, - // When the stride is specified, the rotation amount for each index on the - // stride dimension will be (amount + stride * index). - OptionalAttr:$stride, - OptionalAttr:$stride_dimension - ); - let results = (outs AnyVectorOfNonZeroRank:$result); - let assemblyFormat = [{ - $value `by` $amount `dim` $dimension attr-dict `:` type($value) `,` type($amount) `->` type($result) - }]; - let hasVerifier = 1; -} - -def TPU_IotaOp : TPU_Op<"iota", [Pure]> { - let arguments = (ins OptionalAttr:$dimension); - let results = (outs AnyVectorOfNonZeroRank:$output); - let assemblyFormat = [{ attr-dict `:` type($output) }]; -} - -// TODO(mvoz): deprecated - use concat. Canonicalization will do so automatically. -// b/376295711 -def TPU_RepeatOp : TPU_Op<"repeat", [Pure]> { - let arguments = (ins - AnyVectorOfNonZeroRank:$source, - I32Attr:$dimension, - I32Attr:$times - ); - let results = (outs AnyVectorOfNonZeroRank:$output); - let assemblyFormat = [{ $source `,` $dimension `x` $times attr-dict `:` type($source) `->` type($output) }]; -} - -def TPU_BroadcastInSublanesOp : TPU_Op<"broadcast_in_sublanes", [Pure]> { - let description = [{ - For each sublane `i`, broadcasts the value in lane `lane + i` along the entire - sublane. If `lane + i` is not in [0, lane_count), then the value in sublane `i` - is not defined (can be anything). - }]; - let arguments = (ins - TPU_Vreg:$source, // All sublanes should be equal. - I32Attr:$lane // Coordinates of the first element to take. - ); - // Output shape should be the same, except for position dim which contains - // the newly inserted dimension. - let results = (outs AnyVectorOfNonZeroRank:$output); - let assemblyFormat = [{ - $source `,` $lane attr-dict `:` type($source) `->` type($output) - }]; -} - -// Integer unpacks are always signed at the moment. -def TPU_UnpackSubelementsOp : TPU_Op<"unpack_subelements", [Pure]> { - let arguments = (ins - AnyVectorOfNonZeroRank:$source, - I32Attr:$index, - TPU_PackFormatEnum:$pack_format - ); - let results = (outs AnyVectorOfNonZeroRank:$output); - let assemblyFormat = [{ $source `,` $index attr-dict `:` type($source) `->` type($output) }]; -} - -// Integer packs are always signed at the moment. -// Float to integer packing rounds to nearest even. -def TPU_PackSubelementsOp : TPU_Op<"pack_subelements", [Pure, SameTypeOperands]> { - let arguments = (ins - Variadic:$sources, - DenseI32ArrayAttr:$positions, - TPU_PackFormatEnum:$pack_format - ); - let results = (outs TPU_Vreg:$output); - let assemblyFormat = [{ $sources attr-dict `:` type($sources) `->` type($output) }]; - let builders = [ - OpBuilder<(ins "::mlir::VectorType":$output_type, "::mlir::ArrayRef<::mlir::Value>":$padded_sources, "::mlir::tpu::PackFormat":$pack_format)>, - ]; - let extraClassDeclaration = [{ - static ::mlir::SmallVector<::mlir::Value> getPaddedSources(::mlir::ValueRange sources, ::mlir::ArrayRef positions, int packing_factor); - }]; - let hasVerifier = 1; -} - -def TPU_RelayoutOp : TPU_Op<"relayout", [SameOperandsAndResultType]> { - let arguments = (ins AnyType:$input); - let results = (outs AnyType:$output); - let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }]; -} - -def TPU_PackMaskOp : TPU_Op<"pack_vmsk", [Pure, SameTypeOperands]> { - let arguments = (ins - VectorOfNonZeroRankOf<[I1]>: $low, - VectorOfNonZeroRankOf<[I1]>: $high - ); - let results = (outs VectorOfNonZeroRankOf<[I1]>:$output); - let assemblyFormat = [{ $low `,` $high `,` attr-dict `:` type($low) `,` type($high) `->` type($output) }]; -} - -def TPU_GatherOp : TPU_Op<"gather", [Pure]> { - let arguments = (ins - AnyVectorOfNonZeroRank:$source, - DenseI32ArrayAttr:$indices, - I32Attr:$dimension - ); - let results = (outs AnyVectorOfNonZeroRank:$output); - let assemblyFormat = [{ - $source `[` $indices `]` `in` $dimension attr-dict - `:` type($source) `->` type($output) - }]; -} - -def TPU_DynamicGatherOp : TPU_Op<"dynamic_gather", [Pure]> { - let arguments = (ins - AnyVectorOfNonZeroRank:$source, - AnyVectorOfNonZeroRank:$indices, - I32Attr:$dimension - ); - let results = (outs AnyVectorOfNonZeroRank:$output); - let assemblyFormat = [{ - $source `[` $indices `]` `in` $dimension attr-dict - `:` type($source) `,` type($indices) `->` type($output) - }]; - let hasVerifier = 1; -} - -def TPU_RoundingMode : I32EnumAttr<"RoundingMode", "Rounding mode", [ - I32EnumAttrCase<"kTowardsZero", 0, "towards_zero">, - I32EnumAttrCase<"kToNearestEven", 1, "to_nearest_even">, -]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::tpu"; -} - -def TPU_RoundingModeEnum : EnumAttr { - let assemblyFormat = "`<` $value `>`"; -} - -// Internal operation. All arith.fptosi operations that change the bitwidth -// must be canonicalized to this operation. -def TPU_FPToSIOp : TPU_Op<"fptosi", [Pure, ElementwiseMappable]> { - let arguments = (ins AnyVectorOfAnyRank:$input, TPU_RoundingModeEnum:$rounding_mode); - let results = (outs AnyVectorOfAnyRank:$output); - let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }]; - let hasCanonicalizeMethod = 1; -} - def TPU_DotDimensionNumbersAttr : TPU_Attr<"DotDimensionNumbers", "dot_dimension_numbers"> { let parameters = (ins ArrayRefParameter<"int64_t", "">:$lhs_contracting_dims, ArrayRefParameter<"int64_t", "">:$rhs_contracting_dims, ArrayRefParameter<"int64_t", "">:$lhs_non_contracting_dims, - ArrayRefParameter<"int64_t", "">:$rhs_non_contracting_dims, + // Empty when rhs is a 1-D vector. + OptionalArrayRefParameter<"int64_t", "">:$rhs_non_contracting_dims, // The contract is a flattened structure, wherein, each element is half of a // pair of indices. The first element is always 0 (lhs) or 1 (rhs) and the // second index is the index from the lhs or rhs. @@ -478,512 +53,35 @@ def TPU_DotDimensionNumbersAttr : TPU_Attr<"DotDimensionNumbers", "dot_dimension OptionalArrayRefParameter<"int64_t", "">:$rhs_batch_dims ); let assemblyFormat = "`<` `[` $lhs_contracting_dims `]` `,` `[` $rhs_contracting_dims `]` `,` " - "`[` $lhs_non_contracting_dims `]` `,` `[` $rhs_non_contracting_dims `]` `,` " + "`[` $lhs_non_contracting_dims `]` `,` `[` (`]`):($rhs_non_contracting_dims^ `]`)? `,` " "`[` $output_dim_order `]` `,` " "`[` (`]`):($lhs_batch_dims^ `]`)? `,` " "`[` (`]`):($rhs_batch_dims^ `]`)? `>`"; + let constBuilderCall = "::mlir::tpu::DotDimensionNumbersAttr::get($_builder.getContext(), $0)"; } -// TODO(apaszke): Think hard about precision -def TPU_MatmulOp : TPU_Op<"matmul", [Pure]> { - let arguments = (ins - AnyVectorOfNonZeroRank:$lhs, - AnyVectorOfNonZeroRank:$rhs, - AnyVectorOfNonZeroRank:$acc, - // These flags are deprecated - if dimension_numbers are defined, - // these flags are ignored. They will always be false after canonicalize. - DefaultValuedAttr:$transpose_lhs, - DefaultValuedAttr:$transpose_rhs, - OptionalAttr:$precision, - // NOTE: User-level optional, once canonicalized, always present. - OptionalAttr:$dimension_numbers - ); - let results = (outs AnyVectorOfNonZeroRank:$result); - let assemblyFormat = [{ - $lhs `,` $rhs `,` $acc attr-dict `:` type($lhs) `,` type($rhs) `,` type($acc) `->` type($result) - }]; - let hasCanonicalizer = 1; - let hasVerifier = 1; -} - -def TPU_ConcatenateOp : TPU_Op<"concatenate", [Pure]> { - let arguments = (ins - Variadic:$sources, - I32Attr:$dimension - ); - let results = (outs AnyVectorOfNonZeroRank:$output); - let assemblyFormat = [{ - $sources `in` $dimension attr-dict `:` type($sources) `->` type($output) - }]; - let hasVerifier = 1; -} - -def TPU_BitcastOp : TPU_Op<"bitcast", [Pure]> { - let arguments = (ins AnyVectorOfNonZeroRank:$input); - let results = (outs AnyVectorOfNonZeroRank:$output); - let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }]; - let hasVerifier = 1; -} - -def TPU_BitcastVregOp : TPU_Op<"bitcast_vreg", [Pure]> { - let arguments = (ins TPU_Vreg:$input); - let results = (outs TPU_Vreg:$output); - let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }]; -} - -def TPU_WeirdOp : TPU_Op<"weird", [Pure, ElementwiseMappable]> { - let arguments = (ins AnyType:$input); // F32 vector or scalar - let results = (outs AnyType:$output); // I1 vector or scalar - let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }]; - let hasVerifier = 1; -} - -def TPU_ReciprocalOp : TPU_Op<"reciprocal", [Pure, SameOperandsAndResultType, ElementwiseMappable]> { - let arguments = (ins - AnyVectorOfNonZeroRank:$input, - DefaultValuedAttr:$approx - ); - let results = (outs AnyVectorOfNonZeroRank:$output); - let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }]; - let hasVerifier = 1; -} - -def TPU_RollVectorsOp : TPU_Op<"roll_vectors", [Pure]> { - let arguments = (ins Variadic:$input); - let results = (outs AnyVectorOfNonZeroRank:$output); - let assemblyFormat = [{ - $input attr-dict `:` type($input) `->` type($output) - }]; -} - -def TPU_UnrollVectorsOp : TPU_Op<"unroll_vectors", [Pure]> { - let arguments = (ins AnyVectorOfNonZeroRank:$input); - let results = (outs Variadic:$output); - let hasCanonicalizeMethod = 1; - let assemblyFormat = [{ - $input attr-dict `:` type($input) `->` type($output) - }]; -} - -def TPU_CreateMaskOp : TPU_Op<"create_mask", [Pure, SameVariadicOperandSize]> { - // high is exclusive - let arguments = (ins Variadic:$low, Variadic:$high); - let results = (outs AnyType:$output); - let assemblyFormat = [{ - `[` $low `]``[` $high `]` attr-dict `:` type($output) - }]; -} - -def TPU_CreateSubelementMaskOp : TPU_Op<"create_subelement_mask", [Pure]> { - let summary = "Create a mask masking contiguous rows of subelements."; - let description = [{ - The "half-sublanes", "quarter-sublanes", etc. (unit is determined by - the type of `output`) of the mask are masked in the range specified by - `from` and `to`. - - - If `from <= to`, the range `[from, to)` is set and the rest is unset. - - If `to <= from`, the range `[to, from)` is unset and the rest is set. - - All lanes are set identically. - - Example: - - ```mlir - %msk = tpu.create_subelement_mask 3, 9 : vector<8x128x2xi1> - ``` - - This creates a mask `%msk` where, for all `lane`s, `%msk[*][lane][*]` is: - - ``` - [[0, 0], [0, 1], [1, 1], [1, 1], [1, 0], [0, 0], [0, 0], [0, 0]] - ``` - - It is currently only supported: - - In TPU v4, for `num_subelems` of 1 and 2. - - In TPU v5, for `num_subelems` of 1, 2, and 4. - }]; - let arguments = (ins - I32Attr:$from, // inclusive - I32Attr:$to // exclusive - ); - let results = (outs AnyType:$output); // Verify this is a vmsk with num_subelems - let assemblyFormat = [{ - $from `,` $to attr-dict `:` type($output) - }]; -} - -def TPU_AssumeMultipleOp : TPU_Op<"assume_multiple", [Pure, SameOperandsAndResultType]> { - let arguments = (ins - AnyTypeOf<[Index, AnyInteger]>:$value, - I32Attr:$multiple - ); - let results = (outs AnyTypeOf<[Index, AnyInteger]>:$result); - let hasVerifier = 1; -} - -def TPU_MemRefSliceOp : TPU_Op<"memref_slice", [Pure, AttrSizedOperandSegments]> { - let arguments = (ins - AnyMemRef:$mem_ref, - Variadic:$base_idx, - Variadic:$dynamic_sizes - ); - let results = (outs AnyMemRef:$result); - let assemblyFormat = [{ - $mem_ref `[` $base_idx `]` (`<` $dynamic_sizes^ `>`)? - attr-dict `:` type($mem_ref) `->` type($result) - }]; - let hasVerifier = 1; - let hasCanonicalizeMethod = 1; -} - -def TPU_MemRefSqueezeOp : TPU_Op<"memref_squeeze", [Pure]> { - let arguments = (ins AnyMemRef:$input); - let results = (outs AnyMemRef:$result); - let assemblyFormat = [{ - $input attr-dict `:` type($input) `->` type($result) - }]; - let hasVerifier = 1; - let hasCanonicalizeMethod = 1; -} - -def TPU_MemRefReshapeOp : TPU_Op<"memref_reshape", [Pure]> { - let arguments = (ins AnyMemRef:$input); - let results = (outs AnyMemRef:$result); - let assemblyFormat = [{ - $input attr-dict `:` type($input) `->` type($result) - }]; - let hasVerifier = 1; - let hasCanonicalizeMethod = 1; -} - -def TPU_MemRefBitcastOp : TPU_Op<"memref_bitcast", [Pure]> { - let arguments = (ins AnyMemRef:$input); - let results = (outs AnyMemRef:$result); - let assemblyFormat = [{ - $input attr-dict `:` type($input) `->` type($result) - }]; - let hasVerifier = 1; - let hasCanonicalizeMethod = 1; -} - -def TPU_ReinterpretCastOp : TPU_Op<"reinterpret_cast", [Pure]> { - let arguments = (ins AnyMemRef:$input); - let results = (outs AnyMemRef:$result); - let assemblyFormat = [{ - $input attr-dict `:` type($input) `->` type($result) - }]; - let hasVerifier = 1; -} - -def TPU_AssumeLayoutOp : TPU_Op<"assume_layout", [Pure]> { - let arguments = (ins AnyType:$input); - let results = (outs AnyType:$result); - let assemblyFormat = [{ - $input attr-dict `:` type($input) `->` type($result) - }]; -} - -def TPU_EraseLayoutOp : TPU_Op<"erase_memref_layout", [Pure]> { - let arguments = (ins AnyMemRef:$operand); - let results = (outs AnyMemRef:$result); - let assemblyFormat = [{ - $operand attr-dict `:` type($operand) `->` type($result) - }]; -} - -def TPU_DeviceIdOp : TPU_Op<"device_id", [Pure]> { - let arguments = (ins); - let results = (outs I32:$result); - let assemblyFormat = [{ attr-dict `:` type($result) }]; -} - -def TPU_SemaphoreReadOp : TPU_Op<"sem_read"> { - let arguments = (ins MemRefOf<[TPU_SemaphoreType, TPU_DMASemaphoreType]>:$semaphore); - let results = (outs I32:$result); - let assemblyFormat = [{ $semaphore attr-dict `:` type($semaphore) `->` type($result)}]; -} - -def TPU_SemaphoreWaitOp : TPU_Op<"sem_wait"> { - let arguments = (ins - MemRefOf<[TPU_SemaphoreType]>:$semaphore, - I32:$amount - ); - let results = (outs); - let assemblyFormat = [{ $semaphore `,` $amount attr-dict `:` type($semaphore)}]; - let hasVerifier = 1; -} - -def TPU_AllocaSemaphoreOp : TPU_Op<"sem_alloc"> { - let arguments = (ins); - let results = (outs MemRefOf<[TPU_SomeSemaphoreType]>:$result); - let assemblyFormat = [{ attr-dict `:` type($result) }]; -} - -def TPU_GetBarrierSemaphoreOp : TPU_Op<"sem_barrier"> { - let arguments = (ins); - let results = (outs MemRefOf<[TPU_SemaphoreType]>:$semaphore); - let assemblyFormat = [{ attr-dict `:` type($semaphore) }]; - let hasVerifier = 1; -} - -def TPU_SemaphoreSignalOp : TPU_Op<"sem_signal", [AttrSizedOperandSegments]> { - let arguments = (ins - MemRefOf<[TPU_SemaphoreType]>:$semaphore, - I32:$amount, - Optional:$device_id, // For remote DMAs - Optional:$core_id, // For megacore - OptionalAttr:$core_type - ); -let assemblyFormat = [{ - $semaphore `,` $amount (`device_id` $device_id^)? (`core_id` $core_id^)? (`core_type` $core_type^)? attr-dict `:` type($semaphore) - }]; - let hasVerifier = 1; - let builders = [ - // A backward-compatible builder that sets `core_type` to nullptr. - OpBuilder<(ins "Value":$semaphore, "Value":$amount, - "Value":$device_id, "Value":$core_id)>, - ]; -} - -def TPU_EnqueueDMAOp : TPU_Op<"enqueue_dma", [AttrSizedOperandSegments]> { - let arguments = (ins - AnyMemRef:$source, - Optional>:$source_semaphore, // For remote DMAs - AnyMemRef:$target, - MemRefOf<[TPU_DMASemaphoreType]>:$target_semaphore, - Optional:$device_id, // For remote DMAs - Optional:$core_id // For megacore - ); - let hasVerifier = 1; -} - -def TPU_WaitDMA2Op : TPU_Op<"wait_dma2"> { - let arguments = (ins - MemRefOf<[TPU_DMASemaphoreType]>:$semaphore, - AnyMemRef:$src, - AnyMemRef:$dst - ); - let hasVerifier = 1; -} - -// TODO(mvoz): Remove once a month has passed. b/395630795 -def TPU_WaitDMAOp : TPU_Op<"wait_dma"> { - let arguments = (ins - MemRefOf<[TPU_DMASemaphoreType]>:$semaphore, - AnyMemRef:$ref - ); - let hasVerifier = 1; -} - -def TPU_RegionOp : TPU_Op<"region", [RecursiveMemoryEffects, SingleBlockImplicitTerminator<"tpu::YieldOp">]> { - let arguments = (ins); - let results = (outs Variadic:$results); - let regions = (region AnyRegion:$region); - let hasVerifier = 1; -} - -def TPU_TraceOp : TPU_Op<"trace", [RecursiveMemoryEffects, SingleBlockImplicitTerminator<"tpu::YieldOp">]> { - let arguments = (ins StrAttr:$message, I32Attr:$level); - let results = (outs Variadic:$results); - let regions = (region AnyRegion:$region); -} - -def TPU_TraceStartOp : TPU_Op<"trace_start", []> { - let arguments = (ins StrAttr:$message, I32Attr:$level); - let results = (outs); -} - -def TPU_TraceStopOp : TPU_Op<"trace_stop", []> { - let arguments = (ins); - let results = (outs); -} - -def TPU_YieldOp : TPU_Op<"yield", [Pure, ReturnLike, Terminator]> { - let arguments = (ins Variadic:$results); - let assemblyFormat = [{ attr-dict ($results^ `:` type($results))? }]; -} - -def TPU_DelayOp : TPU_Op<"delay"> { - let arguments = (ins I32:$nanos); - let results = (outs); -} - -// Expands the granularity of mask to subelements. -def TPU_MaskCastOp : TPU_Op<"mask_cast", [Pure]> { - let arguments = (ins AnyVectorOfNonZeroRank:$input); - let results = (outs AnyVectorOfNonZeroRank:$result); - let assemblyFormat = [{ - $input attr-dict `:` type($input) `->` type($result) - }]; - let hasVerifier = 1; -} - -def TPU_GetIterationBoundOp : TPU_Op<"iteration_bound"> { - let arguments = (ins I32Attr:$dim); - let results = (outs I32:$result); - let assemblyFormat = [{ $dim attr-dict `:` type($result) }]; -} - -def TPU_GetInternalScratchOp : TPU_Op<"internal_scratch"> { - let arguments = (ins); - let results = (outs AnyMemRef:$result); - let assemblyFormat = [{ attr-dict `:` type($result) }]; -} - -def TPU_PRNGSeed32Op : TPU_Op<"prng_set_seed_32"> { - let arguments = (ins Variadic:$seeds); - let results = (outs); -} - -def TPU_PRNGRandomBitsOp : TPU_Op<"prng_random_bits"> { - let arguments = (ins); - let results = (outs AnyVectorOfNonZeroRank:$output); -} - -def TPU_LogOp : TPU_Op<"log"> { - let arguments = (ins - Variadic:$inputs, - StrAttr:$tag, - DefaultValuedAttr:$formatted - ); - let results = (outs); - let assemblyFormat = [{ $tag attr-dict (`:` `[` $inputs^ `]` `:` type($inputs))? }]; - let hasVerifier = 1; -} - -def TPU_LogBufferOp : TPU_Op<"log_buffer"> { - let arguments = (ins - AnyMemRef:$input, - DenseI64ArrayAttr:$shape, - StrAttr:$tag - ); - let results = (outs); - let assemblyFormat = [{ $tag attr-dict `:` $input `:` type($input) }]; - let hasVerifier = 1; -} - -def DebugAssertInsertionPass : Pass<"debug-assert-insertion", "::mlir::func::FuncOp"> { - let dependentDialects = [ - "::mlir::func::FuncDialect", - "::mlir::arith::ArithDialect", - "::mlir::cf::ControlFlowDialect", - "::mlir::vector::VectorDialect", - "::mlir::tpu::TPUDialect", - ]; - let constructor = "::mlir::tpu::createDebugAssertInsertionPass()"; -} - -def LogicalToPhysicalDeviceIdPass : Pass<"logical-to-physical-device-id", "::mlir::func::FuncOp"> { - let dependentDialects = [ - "::mlir::func::FuncDialect", - "::mlir::memref::MemRefDialect", - "::mlir::tpu::TPUDialect", - ]; - let constructor = "::mlir::tpu::createLogicalToPhysicalDeviceIdPass(-1)"; - let options = [Option<"total_devices", "total-devices", "int", "", "">]; -} - -def InferMemRefLayoutPass : Pass<"tpu-infer-memref-layout", "::mlir::func::FuncOp"> { - let dependentDialects = [ - "::mlir::func::FuncDialect", - "::mlir::memref::MemRefDialect", - ]; - let constructor = "::mlir::tpu::createInferMemRefLayoutPass()"; - let options = [ - // If hardware_generation is not set, the default value of -1 will crash on - // runOnOperation. - Option<"hardware_generation", "hardware-generation", "int", /*default=*/"-1", "">, - Option<"lane_count", "lane-count", "int", /*default=*/"128", "">, - Option<"sublane_count", "sublane-count", "int", /*default=*/"8", "">, - Option<"tpu_tiling_flags", "tpu-tiling-flags", "::mlir::tpu::TpuTilingFlags", /*default=*/"::mlir::tpu::TpuTilingFlags{}", "">, - ]; -} - -def CanonicalizeMosaicPass : Pass<"tpu-canonicalize-mosaic", "::mlir::func::FuncOp"> { - let dependentDialects = [ - "::mlir::arith::ArithDialect", - "::mlir::func::FuncDialect", - "::mlir::memref::MemRefDialect", - "::mlir::scf::SCFDialect", - "::mlir::vector::VectorDialect", - "::mlir::tpu::TPUDialect", - ]; - let constructor = "::mlir::tpu::createCanonicalizeMosaicPass()"; - let options = [ - Option<"hardware_generation", "hardware-generation", "int", /*default=*/"-1", "">, - Option<"compatibility_mode", "compatibility-mode", "bool", /*default=*/"1", "">, - ]; -} - -def InferVectorLayoutPass : Pass<"tpu-infer-vector-layout", "::mlir::func::FuncOp"> { - let dependentDialects = [ - "::mlir::arith::ArithDialect", - "::mlir::func::FuncDialect", - "::mlir::memref::MemRefDialect", - "::mlir::scf::SCFDialect", - "::mlir::vector::VectorDialect", - "::mlir::tpu::TPUDialect", - ]; - let constructor = "::mlir::tpu::createInferVectorLayoutPass()"; - let options = [ - Option<"hardware_generation", "hardware-generation", "int", /*default=*/"-1", "">, - Option<"lane_count", "lane-count", "int", /*default=*/"128", "">, - Option<"sublane_count", "sublane-count", "int", /*default=*/"8", "">, - ]; +class TPU_Type traits = [], + string baseCppType = "::mlir::Type"> + : TypeDef { + let mnemonic = mnemonic_; } -def RelayoutInsertionPass : Pass<"tpu-relayout-insertion", "::mlir::func::FuncOp"> { - let dependentDialects = [ - "::mlir::arith::ArithDialect", - "::mlir::func::FuncDialect", - "::mlir::tpu::TPUDialect", - ]; - let constructor = "::mlir::tpu::createRelayoutInsertionPass()"; - let options = [ - // If hardware_generation is not set, the default value of -1 will crash on - // runOnOperation. - Option<"hardware_generation", "hardware-generation", "int", /*default=*/"-1", "">, - Option<"lane_count", "lane-count", "int", /*default=*/"128", "">, - Option<"sublane_count", "sublane-count", "int", /*default=*/"8", "">, - ]; -} +def TPU_Float8EXMYType : TPU_Type<"Float8EXMY", "float8_exmy", + [DeclareTypeInterfaceMethods]> { + let summary = "EXMY type in a 8 bit container"; + let description = [{ + EXMY type in a 8 bit container. Meaningful bits are aligned to LSB, and + bits higher than the underlying exmy type in the container are considered + as ignored. See https://arxiv.org/abs/2405.13938 for more details. + }]; -def ApplyVectorLayoutPass : Pass<"tpu-apply-vector-layout", "::mlir::func::FuncOp"> { - let dependentDialects = [ - "::mlir::arith::ArithDialect", - "::mlir::func::FuncDialect", - "::mlir::vector::VectorDialect", - "::mlir::tpu::TPUDialect", - ]; - let constructor = "::mlir::tpu::createApplyVectorLayoutPass()"; - let options = [ - // If hardware_generation is not set, the default value of -1 will crash on - // runOnOperation. - Option<"hardware_generation", "hardware-generation", "int", /*default=*/"-1", "">, - Option<"lane_count", "lane-count", "int", /*default=*/"128", "">, - Option<"sublane_count", "sublane-count", "int", /*default=*/"8", "">, - Option<"mxu_contracting_size", "mxu-contracting-size", "int", /*default=*/"128", "">, - Option<"mxu_noncontracting_size", "mxu-noncontracting-size", "int", /*default=*/"128", "">, - Option<"max_sublanes_in_scratch", "max-sublanes-in-scratch", "int", /*default=*/"0", "">, - Option<"vmem_banks", "vmem-banks", "int", /*default=*/"-1", "">, - Option<"max_shuffle_sublane_offset", "max-shuffle-sublane-offset", "int", /*default=*/"-1", "Max sublane offset per shuffled load/store">, - ]; -} + let parameters = (ins + TypeParameter<"::mlir::FloatType", "Underlying EXMY type">:$underlying_type + ); -def LinalgVectorizationPass : Pass<"linalg-vectorization", "::mlir::func::FuncOp"> { - let dependentDialects = [ - "::mlir::func::FuncDialect", - "::mlir::memref::MemRefDialect", - "::mlir::linalg::LinalgDialect", - "::mlir::tensor::TensorDialect", - "::mlir::vector::VectorDialect", - "::mlir::tpu::TPUDialect", - ]; - let constructor = "::mlir::tpu::createLinalgVectorizationPass(false)"; - let options = [ - Option<"supports_bf16_alu_instructions", "supports-bf16-alu-instructions", "bool", "", "">, - Option<"supports_bf16_matmul", "supports-bf16-matmul", "bool", "", "">, - ]; + let assemblyFormat = [{ + `<` $underlying_type `>` + }]; } -#endif // TPU_ATTRS +#endif // TPU_BASE diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc index 59ca5d7a3437..9dee418093ca 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc @@ -16,28 +16,38 @@ limitations under the License. #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include -#include +#include #include -#include #include #include -#include +#include "absl/hash/hash.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep. +#include "llvm/Support/MathExtras.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/DialectImplementation.h" // IWYU pragma: keep. +#include "mlir/IR/Location.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "absl/hash/hash.h" -#include "absl/log/log.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" +#include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.cc.inc" #include "jaxlib/mosaic/dialect/tpu/tpu_enums.cc.inc" +#include "jaxlib/mosaic/dialect/tpu/util.h" #include "xla/layout.h" // This is a bit unclean, but we need to squat the xla namespace to make sure @@ -71,6 +81,11 @@ void TPUDialect::initialize() { >(); } +Operation *TPUDialect::materializeConstant(OpBuilder &builder, Attribute value, + Type type, Location loc) { + return arith::ConstantOp::materialize(builder, value, type, loc); +} + /* static */ std::optional TPUDialect::GetCoreTypeAttr( Operation *op) { Attribute attr = op->getAttr(GetCoreTypeKey()); @@ -83,13 +98,125 @@ void TPUDialect::initialize() { return mlir::cast(attr).getValue(); } -FailureOr> GetCoreTypeOfParentFunc(Operation &op) { +struct MemRefCastEraseLayout : public OpRewritePattern { + // Set the benefit to 0 to ensure that other patterns that fold in the cast + // are tried first. + MemRefCastEraseLayout(MLIRContext* context) + : OpRewritePattern(context, /*benefit=*/0) {} + LogicalResult matchAndRewrite(memref::CastOp cast_op, + PatternRewriter& rewriter) const final { + // Push tpu.erase_memref_layout through memref.cast + auto erase_layout_op = cast_op.getOperand().getDefiningOp(); + if (!erase_layout_op) { + return failure(); + } + TypedValue orig_value = erase_layout_op.getOperand(); + const MemRefType orig_type = orig_value.getType(); + const ArrayRef cast_shape = cast_op.getType().getShape(); + MemRefType new_cast_type = + MemRefType::Builder(orig_type).setShape(cast_shape); + auto new_cast_op = memref::CastOp::create(rewriter, cast_op.getLoc(), + new_cast_type, orig_value); + auto new_erase_layout_op = + EraseLayoutOp::create(rewriter, erase_layout_op.getLoc(), new_cast_op); + rewriter.replaceOp(cast_op, new_erase_layout_op); + return success(); + } +}; + +// Rewrites +// +// memref.dim(tpu.memref_slice(..., dynamic_sizes), i) +// +// to +// +// dynamic_sizes[dynamicDimIndex(i)] +// +// if i is a constant and refers to a dynamic dimension. +struct MemRefDimOfSlice : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::DimOp dim_op, + PatternRewriter& rewriter) const override { + auto slice_op = dim_op.getSource().getDefiningOp(); + if (!slice_op) { + return failure(); + } + const std::optional maybe_dim = + getConstantIntValue(dim_op.getDimension()); + if (!maybe_dim) { + return failure(); + } + const int64_t dim = *maybe_dim; + MemRefType result_type = slice_op.getType(); + if (dim < 0 || result_type.getRank() <= dim) { + return dim_op.emitWarning("Dimension index is out of bounds"); + } + if (result_type.getDimSize(dim) != ShapedType::kDynamic) { + return failure(); + } + const unsigned dynamic_dim_idx = result_type.getDynamicDimIndex(dim); + ValueRange dynamic_sizes = slice_op.getDynamicSizes(); + rewriter.replaceOpWithNewOp( + dim_op, dim_op.getType(), dynamic_sizes[dynamic_dim_idx]); + return success(); + } +}; + +// Rewrites memref.dim(tpu.memref_squeeze(x)) to memref.dim(x) with the +// dimension index adjusted to account for squeezed dimensions. +struct MemRefDimOfSqueeze : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::DimOp dim_op, + PatternRewriter& rewriter) const override { + auto squeeze_op = dim_op.getSource().getDefiningOp(); + if (!squeeze_op) { + return failure(); + } + const std::optional maybe_dim = + getConstantIntValue(dim_op.getDimension()); + if (!maybe_dim) { + return failure(); + } + const int64_t dim = *maybe_dim; + MemRefType result_type = squeeze_op.getType(); + if (dim < 0 || result_type.getRank() <= dim) { + return dim_op.emitWarning("Dimension index is out of bounds"); + } + if (result_type.getDimSize(dim) != ShapedType::kDynamic) { + return failure(); + } + MemRefType source_type = getMemRefType(squeeze_op.getInput()); + FAILUREOR_ASSIGN_OR_RETURN( + SmallVector squeezed, + computeSqueezedDimsChecked(squeeze_op, source_type.getShape(), + result_type.getShape())); + int64_t source_dim = dim; + for (int squeezed_dim : squeezed) { + if (squeezed_dim <= source_dim) { + ++source_dim; + } + } + rewriter.replaceOpWithNewOp(dim_op, squeeze_op.getInput(), + source_dim); + return success(); + } +}; + +void TPUDialect::getCanonicalizationPatterns(RewritePatternSet& results) const +/*override*/ { + results.add( + getContext()); +} + +FailureOr GetCoreTypeOfParentFunc(Operation &op) { mlir::Operation *func_op = op.getParentOfType(); if (func_op == nullptr) { return op.emitError() << "Operation " << op.getName() << " is not inside a func.func"; } - return TPUDialect::GetCoreTypeAttr(func_op); + return TPUDialect::GetCoreTypeAttr(func_op).value_or(CoreType::kTc); } void VectorLayoutAttr::print(AsmPrinter &printer) const { @@ -175,32 +302,210 @@ Attribute TiledLayoutAttr::parse(AsmParser &parser, Type type) { } AffineMap TiledLayoutAttr::getAffineMap() const { - AffineMap map = - AffineMap::getMultiDimIdentityMap(getTileStrides().size(), getContext()); SmallVector exprs; - for (const xla::Tile &tile : getTiles()) { - exprs.clear(); + for (int64_t i = 0; i < getRank(); ++i) { + exprs.push_back(getAffineDimExpr(i, getContext())); + } + for (const xla::Tile& tile : getTiles()) { + SmallVector new_exprs; auto dimensions = tile.dimensions(); - int64_t untiled_dims = map.getNumResults() - dimensions.size(); - if (untiled_dims < 0) { - LOG(FATAL) << "Invalid TiledLayoutAttr: Number of dims must be larger " - "or equal to the rank of the tile"; + int64_t untiled_rank = exprs.size() - dimensions.size(); + assert(untiled_rank >= 0); + for (int64_t i = 0; i < untiled_rank; ++i) { + new_exprs.push_back(exprs[i]); } - for (int64_t i = 0; i < untiled_dims; ++i) { - exprs.push_back(getAffineDimExpr(i, getContext())); + for (int64_t i = 0; i < dimensions.size(); ++i) { + new_exprs.push_back(exprs[untiled_rank + i].floorDiv(dimensions[i])); } - for (int i = 0; i < dimensions.size(); ++i) { - exprs.push_back(getAffineDimExpr(untiled_dims + i, getContext()) - .floorDiv(dimensions[i])); + for (int64_t i = 0; i < dimensions.size(); ++i) { + new_exprs.push_back(exprs[untiled_rank + i] % dimensions[i]); } - for (int i = 0; i < dimensions.size(); ++i) { - exprs.push_back(getAffineDimExpr(untiled_dims + i, getContext()) % - dimensions[i]); + exprs = std::move(new_exprs); + } + int64_t num_symbols = 0; + AffineExpr result = getAffineConstantExpr(0, getContext()); + SmallVector strides = getExpandedStrides(); + assert(strides.size() == exprs.size()); + for (int64_t i = 0; i < exprs.size(); ++i) { + AffineExpr stride_expr = + ShapedType::isDynamic(strides[i]) + ? getAffineSymbolExpr(num_symbols++, getContext()) + : getAffineConstantExpr(strides[i], getContext()); + result = result + exprs[i] * stride_expr; + } + return AffineMap::get(getRank(), num_symbols, result); +} + +namespace { +int64_t getUntiledRank(ArrayRef tiles, const int64_t rank) { + // Note: This implementation does not assume there is no nested tiling across + // the first level of tiling, though this is enforced by the verifier. + int64_t untiled_rank = rank; + int64_t tiled_rank = rank; + for (const xla::Tile& tile : tiles) { + const int64_t tile_ndims = tile.dimensions().size(); + untiled_rank = std::min(untiled_rank, tiled_rank - tile_ndims); + tiled_rank += tile_ndims; + } + return untiled_rank; +} +} // namespace + +int64_t TiledLayoutAttr::getUntiledRank() const { + return mlir::tpu::getUntiledRank(getTiles(), getRank()); +} + +namespace { +FailureOr> getExpandedShape( + const ArrayRef untiled_shape, const ArrayRef tiles, + const bool require_alignment) { + SmallVector shape(untiled_shape); + for (const xla::Tile& tile : tiles) { + const int64_t tile_ndims = tile.dimensions().size(); + const llvm::ArrayRef tiled_shape = + llvm::ArrayRef(shape).take_back(tile_ndims); + llvm::SmallVector new_tiled_shape(2 * tile_ndims); + for (int64_t i = 0; i < tile_ndims; ++i) { + if (require_alignment && (ShapedType::isDynamic(tiled_shape[i]) || + tiled_shape[i] % tile.dimension(i) != 0)) { + return failure(); + } + if (ShapedType::isDynamic(tiled_shape[i])) { + new_tiled_shape[i] = ShapedType::kDynamic; + } else { + new_tiled_shape[i] = + llvm::divideCeil(tiled_shape[i], tile.dimension(i)); + } + new_tiled_shape[tile_ndims + i] = tile.dimension(i); } - auto tile_map = AffineMap::get(map.getNumResults(), 0, exprs, getContext()); - map = tile_map.compose(map); + shape.pop_back_n(tile_ndims); + shape.append(new_tiled_shape); } - return map; + return shape; +} +} // namespace + +SmallVector TiledLayoutAttr::getDefaultTileStrides( + const ArrayRef tiles, const ArrayRef shape) { + SmallVector strides(shape.size()); + int64_t stride = 1; + const xla::Tile* const first_tile = tiles.empty() ? nullptr : &tiles.front(); + const int64_t first_tile_rank = + first_tile == nullptr ? 0 : first_tile->dimensions().size(); + for (int64_t d = shape.size() - 1; d >= 0; --d) { + assert(!ShapedType::isDynamic(shape[d])); + strides[d] = stride; + if (d >= shape.size() - first_tile_rank) { + assert(first_tile != nullptr); + const int64_t tile_d = d - (shape.size() - first_tile_rank); + stride *= llvm::divideCeil(shape[d], first_tile->dimension(tile_d)); + } else { + stride *= shape[d]; + } + } + return strides; +} + +bool TiledLayoutAttr::tilesAreKnownContiguous( + const ArrayRef shape) const { + const ArrayRef tiles = getTiles(); + const ArrayRef tile_strides = getTileStrides(); + int64_t stride = 1; + const xla::Tile* const first_tile = tiles.empty() ? nullptr : &tiles.front(); + const int64_t first_tile_rank = + first_tile == nullptr ? 0 : first_tile->dimensions().size(); + for (int64_t d = shape.size() - 1; d >= 0; --d) { + int64_t size_tiles; + if (d >= shape.size() - first_tile_rank && + shape[d] != ShapedType::kDynamic) { + assert(first_tile != nullptr); + const int64_t tile_d = d - (shape.size() - first_tile_rank); + size_tiles = llvm::divideCeil(shape[d], first_tile->dimension(tile_d)); + } else { + size_tiles = shape[d]; + } + // Dimensions with only one element/tile can have any stride. + if (stride != tile_strides[d] && size_tiles != 1) { + return false; + } + if (d == 0) { + break; + } + // When any dimension other than the leading one has a dynamic size, we + // cannot guarantee that there are no gaps. + if (size_tiles == ShapedType::kDynamic) { + return false; + } + stride *= size_tiles; + } + return true; +} + +SmallVector TiledLayoutAttr::getExpandedShape( + ArrayRef untiled_shape) const { + // getExpandedShape should never fail without require_alignment + return *mlir::tpu::getExpandedShape(untiled_shape, getTiles(), + /*require_alignment=*/false); +} + +SmallVector TiledLayoutAttr::getExpandedStrides() const { + if (getTiles().empty()) { + return SmallVector(getTileStrides()); + } + SmallVector strides(getTileStrides()); + // Expand front tile + const xla::Tile& first_tile = getTiles().front(); + const FailureOr> failure_or_expanded_tile = + mlir::tpu::getExpandedShape(first_tile.dimensions(), + getTiles().drop_front(), + /*require_alignment=*/true); + // Verification should ensure this: + assert(succeeded(failure_or_expanded_tile)); + const SmallVector& expanded_tile = *failure_or_expanded_tile; + strides.resize_for_overwrite(getRank() + expanded_tile.size()); + int64_t first_tile_size = llvm::product_of(first_tile.dimensions()); + int64_t tile_size = 1; + for (int64_t d = strides.size() - 1; d >= 0; --d) { + if (d >= getRank()) { + const int64_t new_stride = tile_size; + tile_size *= expanded_tile[d - getRank()]; + strides[d] = new_stride; + } else { + strides[d] *= first_tile_size; + } + } + return strides; +} + +LogicalResult TiledLayoutAttr::verify( + function_ref emitError, + const llvm::ArrayRef tiles, + const llvm::ArrayRef tile_strides) { + if (llvm::any_of(tile_strides, ShapedType::isDynamic)) { + return emitError() << "Not implemented: Dynamic tile strides"; + } + if (tiles.empty()) { + return success(); + } + const int64_t rank = tile_strides.size(); + const xla::Tile& first_tile = tiles.front(); + const int64_t first_tile_rank = first_tile.dimensions().size(); + // The interpretation of tile strides is unclear if there is nested tiling + // across first tiles (e.g. T(8, 128)(2, 4, 64)), and this has no applications + // anyway. + if (mlir::tpu::getUntiledRank(tiles, rank) != rank - first_tile_rank) { + return emitError() << "Not implemented: Nested tiling across first tiles"; + } + // Check that nested tiles evenly divide previous tiles (so they don't add any + // padding or change the tile size) + if (failed(mlir::tpu::getExpandedShape(first_tile.dimensions(), + tiles.drop_front(), + /*require_alignment=*/true))) { + return emitError() << "Not implemented: Nested tiles must evenly divide " + << "the first tile " << first_tile.ToString() + << " but they do not (would add padding)"; + } + return success(); } MemRefType getMemRefType(Value value) { @@ -210,6 +515,15 @@ MemRefType getMemRefType(Value value) { return cast(value.getType()); } +template +bool checkBothOperandsDivisible(Value value, int64_t divisor, int64_t fuel) { + if (auto op = value.getDefiningOp()) { + return isGuaranteedDivisible(op.getLhs(), divisor, fuel / 2) && + isGuaranteedDivisible(op.getRhs(), divisor, (fuel + 1) / 2); + } + return false; +} + bool isGuaranteedDivisible(Value value, int64_t divisor, int64_t fuel) { if (fuel <= 0) { return false; @@ -232,6 +546,17 @@ bool isGuaranteedDivisible(Value value, int64_t divisor, int64_t fuel) { if (auto cast_op = value.getDefiningOp()) { return isGuaranteedDivisible(cast_op.getOperand(), divisor, fuel - 1); } + if (checkBothOperandsDivisible(value, divisor, fuel) || + checkBothOperandsDivisible(value, divisor, fuel) || + checkBothOperandsDivisible(value, divisor, fuel) || + checkBothOperandsDivisible(value, divisor, fuel)) { + return true; + } + if (auto select_op = value.getDefiningOp()) { + return isGuaranteedDivisible(select_op.getTrueValue(), divisor, fuel / 2) && + isGuaranteedDivisible(select_op.getFalseValue(), divisor, + (fuel + 1) / 2); + } return false; } @@ -249,4 +574,51 @@ DotDimensionNumbersAttr defaultDimensionNumbers(Builder &builder, /*rhs_batch_dims=*/{}); } +const ::llvm::fltSemantics& Float8EXMYType::getFloatSemantics() const { + if (mlir::isa(getUnderlyingType())) { + return llvm::APFloat::Float6E3M2FN(); + } else if (mlir::isa(getUnderlyingType())) { + return llvm::APFloat::Float6E2M3FN(); + } + return cast(getUnderlyingType()).getFloatSemantics(); +} + +namespace { + +struct CommsAnalysisState { + bool has_communication = false; + bool has_custom_barrier = false; + + explicit operator bool() { return has_communication && has_custom_barrier; } +}; + +void analyzeCrossChipCommunication(mlir::Operation *op, + CommsAnalysisState *state) { + if (auto dma = dyn_cast(op)) { + state->has_communication |= dma.getDeviceId() != nullptr; + } else if (auto signal = dyn_cast(op)) { + state->has_communication |= signal.getDeviceId() != nullptr; + } else if (auto barrier = dyn_cast(op)) { + state->has_custom_barrier = true; + } + for (Region ®ion : op->getRegions()) { + for (Block &block : region.getBlocks()) { + for (Operation &op : block.getOperations()) { + analyzeCrossChipCommunication(&op, state); + if (*state) { + return; + } + } + } + } +} + +} // namespace + +std::pair mightCommunicateBetweenChips(mlir::Operation *op) { + CommsAnalysisState state; + analyzeCrossChipCommunication(op, &state); + return std::make_pair(state.has_communication, state.has_custom_barrier); +} + } // namespace mlir::tpu diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h index 0800a9e75087..2bfccb171030 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h @@ -16,23 +16,21 @@ limitations under the License. #ifndef JAXLIB_MOSAIC_DIALECT_TPU_DIALECT_H_ #define JAXLIB_MOSAIC_DIALECT_TPU_DIALECT_H_ -#include #include #include -#include +#include #include +#include "absl/types/span.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/Support/LogicalResult.h" -#include "jaxlib/mosaic/dialect/tpu/layout.h" +#include "mlir/Support/LogicalResult.h" +#include "jaxlib/mosaic/dialect/tpu/layout.h" // IWYU pragma: keep #include "jaxlib/mosaic/dialect/tpu/tpu_enums.h.inc" -#include "jaxlib/mosaic/dialect/tpu/transforms/serde.h" -#include "xla/layout.h" +#include "xla/layout.h" // IWYU pragma: keep namespace mlir::tpu { class TPUDialect; @@ -57,54 +55,24 @@ struct TpuTilingFlags { bool use_x4_large_second_minor = false; }; -struct ApplyVectorLayoutContext { - // TODO(tlongeri): target_shape should be determined from hardware_generation - int hardware_generation = -1; - std::array target_shape = {8, 128}; - // mxu_shape = {contracting_size, non_contracting_size} - std::array mxu_shape = {128, 128}; - int64_t max_sublanes_in_scratch = 0; - int64_t vmem_banks = -1; // -1 means "unspecified". - int32_t max_shuffle_sublane_offset = -1; // -1 means "unspecified". -}; - -std::pair mightCommunicateBetweenChips(Operation* op); +std::pair mightCommunicateBetweenChips(Operation *op); +// Creates a pass that infers the layout of memrefs in the given function. +// +// The `target_shape` can either be +// * 1D -- (lane count) SparseCore tiling; or +// * 2D -- (sublane count, lane count) TensorCore tiling. std::unique_ptr> createInferMemRefLayoutPass( - int hardware_generation = -1, - std::array target_shape = {8, 128}, - const TpuTilingFlags &tpu_tiling_flags = {}); - -std::unique_ptr> createCanonicalizeMosaicPass( - int hardware_generation = -1, bool compatibility_mode = true); - -std::unique_ptr> createInferVectorLayoutPass( - int hardware_generation = -1, - std::array target_shape = {8, 128}, - const TpuTilingFlags &tpu_tiling_flags = {}); - -std::unique_ptr> createRelayoutInsertionPass( - int hardware_generation = -1, - std::array target_shape = {8, 128}); - -std::unique_ptr> createApplyVectorLayoutPass( - const ApplyVectorLayoutContext &ctx = ApplyVectorLayoutContext{}); - -std::unique_ptr> -createLogicalToPhysicalDeviceIdPass(int64_t total_devices); - -std::unique_ptr> createLinalgVectorizationPass( - bool supports_bf16_alu_instructions = false, - bool supports_bf16_matmul = false); - -std::unique_ptr> createDebugAssertInsertionPass(); + int hardware_generation, absl::Span target_shape, + const TpuTilingFlags& tpu_tiling_flags, bool align = true); #define GEN_PASS_DECL_MOSAICSERDEPASS #include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc" // Determine the core type of the given op based on the `tpu.core_type` -// annotation of its parent function. -FailureOr> GetCoreTypeOfParentFunc(Operation &op); +// annotation of its parent function. If no such annotation is found, returns +// kTc. +FailureOr GetCoreTypeOfParentFunc(Operation &op); // Changes the memory space of the value and propagates it through the program. LogicalResult specializeMemorySpace(TypedValue value, @@ -114,7 +82,7 @@ LogicalResult specializeMemorySpace(TypedValue value, // vector ops. This functions inverts the layout erasure applied to the value. MemRefType getMemRefType(Value value); -bool isGuaranteedDivisible(Value value, int64_t divisor, int64_t fuel = 8); +bool isGuaranteedDivisible(Value value, int64_t divisor, int64_t fuel = 128); DotDimensionNumbersAttr defaultDimensionNumbers(Builder &builder, bool transpose_lhs, @@ -123,6 +91,8 @@ DotDimensionNumbersAttr defaultDimensionNumbers(Builder &builder, #define GEN_PASS_REGISTRATION #include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc" +constexpr std::string_view kLeadingTileRows = "leading_tile_rows"; + } // namespace tpu } // namespace mlir diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index c73accb09b26..3ac3d5e3556d 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -16,47 +16,103 @@ limitations under the License. #include #include #include +#include #include #include +#include "absl/algorithm/container.h" +#include "absl/log/check.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "llvm/ADT/FloatingPointMode.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/Support/FormatVariadic.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/CommonFolders.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OperationSupport.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Region.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "absl/log/check.h" -#include "absl/strings/str_format.h" -#include "mlir/include/mlir/Dialect/Math/IR/Math.h" -#include "mlir/include/mlir/IR/Builders.h" -#include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/Diagnostics.h" -#include "mlir/include/mlir/IR/IRMapping.h" -#include "mlir/include/mlir/IR/OperationSupport.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/util.h" +#include "xla/layout.h" namespace mlir { namespace tpu { +namespace { + +// This should only be used to canonicalize away EraseLayoutOps that feed ops +// that only consume memrefs and don't return them. +LogicalResult propagateTiledLayoutToConsumer(Operation* op, + PatternRewriter& rewriter) { + bool modified = false; + for (unsigned int i = 0; i < op->getNumOperands(); ++i) { + if (auto erase_layout_op = + op->getOperand(i).getDefiningOp()) { + modified = true; + rewriter.modifyOpInPlace( + op, [&]() { op->setOperand(i, erase_layout_op.getOperand()); }); + } + } + return success(modified); +} + +llvm::RoundingMode convertTpuRoundingModeToLLVMIR(tpu::RoundingMode mode) { + switch (mode) { + case tpu::RoundingMode::kToNearestEven: + return llvm::RoundingMode::NearestTiesToEven; + case tpu::RoundingMode::kTowardsZero: + return llvm::RoundingMode::TowardZero; + } +} + +// Attempts to convert `sourceValue` to an APFloat value with +// `targetSemantics` and `roundingMode`, without any information loss. +static FailureOr convertFloatValue( + APFloat sourceValue, const llvm::fltSemantics &targetSemantics, + llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) { + bool losesInfo = false; + auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo); + if (losesInfo || status != APFloat::opOK) { + return failure(); + } + + return sourceValue; +} + +} // namespace + LogicalResult UnrollVectorsOp::canonicalize(UnrollVectorsOp op, PatternRewriter &rewriter) { - RollVectorsOp roll_op = - dyn_cast_or_null(op.getOperand().getDefiningOp()); + RollVectorsOp roll_op = op.getOperand().getDefiningOp(); if (!roll_op) { - return failure(); + return failure(); } if (roll_op.getNumOperands() != op.getNumResults()) { - return failure(); + return failure(); } for (auto [v1, v2] : llvm::zip(roll_op.getOperandTypes(), op.getResultTypes())) { if (v1 != v2) { - return failure(); + return failure(); } } rewriter.replaceOp(op, roll_op.getOperands()); @@ -66,8 +122,8 @@ LogicalResult UnrollVectorsOp::canonicalize(UnrollVectorsOp op, LogicalResult BitcastOp::verify() { auto in_ty = getInput().getType(); auto out_ty = getOutput().getType(); - auto in_bitwidth = in_ty.getElementTypeBitWidth(); - auto out_bitwidth = out_ty.getElementTypeBitWidth(); + auto in_bitwidth = getElementTypeBitwidth(in_ty); + auto out_bitwidth = getElementTypeBitwidth(out_ty); if (in_bitwidth != out_bitwidth) { if (in_ty.getRank() < 2 || out_ty.getRank() < 2) { return emitError( @@ -91,9 +147,26 @@ LogicalResult BitcastOp::verify() { return success(); } +OpFoldResult BitcastVregOp::fold(FoldAdaptor adaptor) { + // Bitcast from X -> X is a no-op. + if (getType() == getInput().getType()) { + return getInput(); + } + // Bitcast from X -> Y -> ... -> Z -> X is a no-op. + Value input = getInput(); + while (auto op = dyn_cast(input.getDefiningOp())) { + input = op.getInput(); + if (getType() == input.getType()) { + return input; + } + } + return nullptr; +} + LogicalResult MemRefSliceOp::verify() { - auto source_type = getMemRefType(getMemRef()); + auto source_type = getMemRef().getType(); auto target_type = getType(); + auto source_layout = source_type.getLayout(); auto target_layout = target_type.getLayout(); auto target_memory_space = target_type.getMemorySpace(); auto indices = getBaseIdx(); @@ -102,6 +175,11 @@ LogicalResult MemRefSliceOp::verify() { return emitOpError( "Only slicing of memrefs with static shapes is supported."); } + if (getDynamicSizes().size() != target_type.getNumDynamicDims()) { + return emitOpError( + "Number of provided dynamic dimensions sizes must match the number of " + "dynamic dimensions in the target type."); + } auto source_shape = source_type.getShape(); bool is_semaphore = HasMemorySpace(source_type, tpu::MemorySpace::kSemaphoreMem); @@ -117,98 +195,164 @@ LogicalResult MemRefSliceOp::verify() { } // TODO(apaszke): Check that the result has a smaller shape. // TODO(apaszke): Check that strides are equivalent. - // Source and target attributes may be different before propagation is done by - // the canonicalizer, so we allow this when attributes are "unset" in the - // target type. Note that MemRefType does not allow a null layout so we treat - // the default identity affine map as an "unset" value instead. + // Source and target memory spaces may be different before propagation is done + // by memory space specialization. bool is_target_memory_space_provided = target_memory_space != nullptr; if (is_target_memory_space_provided && target_memory_space != source_type.getMemorySpace()) { return emitOpError( "Memory spaces must match if the target memory space is provided."); } - bool is_target_layout_identity_map = - isa(target_layout) && target_layout.isIdentity(); - if (!is_target_layout_identity_map && - target_type.getLayout() != source_type.getLayout()) { - return emitOpError( - "Layouts must match if the target layout is not an identity map."); - } - if (getDynamicSizes().size() != target_type.getNumDynamicDims()) { - return emitOpError( - "Number of provided dynamic dimensions sizes must match the number of " - "dynamic dimensions in the target type."); + if (isa(source_layout) != + isa(target_layout)) { + return emitOpError("Source and target layouts must match."); } return success(); } -LogicalResult MemRefSliceOp::canonicalize(MemRefSliceOp op, - PatternRewriter &rewriter) { - auto erase_layout = op.getMemRef().getDefiningOp(); - if (!erase_layout) { - return failure(); +struct MemRefSliceFoldConstantDynamicDim + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(MemRefSliceOp op, + PatternRewriter& rewriter) const override { + if (llvm::none_of(op.getDynamicSizes(), [](Value dynamic_size) { + APInt constant_value; // Would be nice if we could pass nullptr below + return matchPattern(dynamic_size, m_ConstantInt(&constant_value)); + })) { + return failure(); + } + SmallVector new_shape(op.getType().getShape()); + SmallVector new_dynamic_sizes; + int64_t dynamic_dim_index = 0; + for (Value dynamic_size : op.getDynamicSizes()) { + // Find the index of the corresponding dynamic dimension in the shape + while (new_shape[dynamic_dim_index] != ShapedType::kDynamic) { + ++dynamic_dim_index; + CHECK(dynamic_dim_index < new_shape.size()); + } + APInt constant_value; + if (matchPattern(dynamic_size, m_ConstantInt(&constant_value))) { + if (constant_value.getSExtValue() <= 0) { + return op.emitWarning() << "Non-positive constant for dynamic size"; + } + new_shape[dynamic_dim_index] = constant_value.getSExtValue(); + } else { + new_dynamic_sizes.push_back(dynamic_size); + } + ++dynamic_dim_index; + } + // Update the memref_slice op and create a cast op to convert to the old + // type. + MemRefType old_type = op.getType(); + MemRefType new_type = MemRefType::Builder(old_type).setShape(new_shape); + rewriter.modifyOpInPlace(op, [&]() { + op.getResult().setType(new_type); + op.getDynamicSizesMutable().assign(new_dynamic_sizes); + }); + mlir::OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointAfter(op); + auto cast_op = memref::CastOp::create(rewriter, op.getLoc(), old_type, op); + rewriter.replaceAllUsesExcept(op, cast_op, cast_op); + return success(); } - // Push layout erasure through slicing. It is important we see the layout - // for lowering and don't make it hard for other ops to query it. - auto layout_ref = erase_layout.getOperand(); - MemRefType layout_ty = layout_ref.getType(); - auto new_result_type = MemRefType::get( - op.getResult().getType().getShape(), layout_ty.getElementType(), - layout_ty.getLayout(), layout_ty.getMemorySpace()); - auto slice = - rewriter.create(op.getLoc(), new_result_type, layout_ref, - op.getBaseIdx(), op.getDynamicSizes()); - rewriter.replaceOpWithNewOp(op, op.getType(), slice); - return success(); +}; + +struct MemRefSliceEraseLayout : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(MemRefSliceOp op, + PatternRewriter& rewriter) const override { + auto erase_layout = op.getMemRef().getDefiningOp(); + if (!erase_layout) { + return failure(); + } + // Push layout erasure through slicing. It is important we see the layout + // for lowering and don't make it hard for other ops to query it. + auto layout_ref = erase_layout.getOperand(); + MemRefType layout_ty = layout_ref.getType(); + auto new_result_type = MemRefType::get( + op.getResult().getType().getShape(), layout_ty.getElementType(), + layout_ty.getLayout(), layout_ty.getMemorySpace()); + auto slice = MemRefSliceOp::create(rewriter, op.getLoc(), new_result_type, + layout_ref, op.getBaseIdx(), + op.getDynamicSizes()); + rewriter.replaceOpWithNewOp(op, slice); + return success(); + } +}; + +void MemRefSliceOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { + results.add( + context); } LogicalResult MemRefSqueezeOp::verify() { auto source_type = getMemRefType(getInput()); auto target_type = getType(); - // Source and target attributes may be different before propagation is done by - // the canonicalizer, so we allow this when attributes are "unset" in the - // target type. + if (target_type.getMemorySpace() != nullptr && target_type.getMemorySpace() != source_type.getMemorySpace()) { - emitOpError("Memory spaces do not match."); - return failure(); + return emitOpError("Memory spaces do not match."); } + if (target_type.getElementType() != source_type.getElementType()) { - this->emitOpError("Element types don't match."); - return failure(); - } - if (!HasMemorySpace(source_type, tpu::MemorySpace::kSemaphoreMem) && - source_type.getRank() > 1 && target_type.getRank() == 1) { - return emitError("Not implemented: squeeze memref to 1d."); + return emitOpError("Element types don't match."); } + auto source_shape = source_type.getShape(); auto target_shape = target_type.getShape(); - int source_index = source_shape.size() - 1; - int target_index = target_shape.size() - 1; - auto error_msg = llvm::formatv( - "Target shape is not valid. " - "Source type: {0}. Target type: {1}.", - source_type, target_type); - while (source_index >= 0 || target_index >= 0) { - int target_dim = target_index < 0 ? -1 : target_shape[target_index]; - if (source_index < 0) { - // We have run out of source shape but target shape still remains. - emitOpError(error_msg); - return failure(); - } - int source_dim = source_shape[source_index]; - if (source_dim == target_dim) { - source_index--; - target_index--; - } else { - // Only the source dim can be 1 here. - if (source_dim != 1) { - this->emitOpError(error_msg); - return failure(); - } - source_index--; + FAILUREOR_ASSIGN_OR_RETURN( + auto squeezed, + computeSqueezedDimsChecked(*this, source_shape, target_shape)); + if (squeezed.empty() && source_shape != target_shape) { + return emitOpError( + "Source and target shapes must be the same if no dimensions are " + "squeezed."); + } + + auto source_layout = source_type.getLayout(); + auto target_layout = target_type.getLayout(); + if (!isa(source_layout) && + !isa(target_layout)) { + return success(); + } + + auto tiles = cast(source_layout).getTiles(); + switch (tiles.size()) { + case 0: + break; + case 1: { + auto tile = tiles.front(); + auto tile_dims = tile.dimensions(); + int first_tiled = source_shape.size() - tile_dims.size(); + for (int dim : squeezed) { + if (dim >= first_tiled) { + int tile_idx = dim - first_tiled; + if (tile_idx < 0 || tile_idx >= static_cast(tile_dims.size())) { + return emitOpError() << "Internal error: tile index out of bounds."; + } + if (tile_dims[tile_idx] != 1) { + return emitOpError() + << "All tiled squeezed dimensions must be of size 1."; + } + } + } + break; + } + default: { + auto first_tile = tiles.front(); + for (int dim : squeezed) { + int first_tiled = source_shape.size() - first_tile.dimensions().size(); + if (dim >= first_tiled) { + return emitOpError() << "When multiple tiles are present, no tiled " + "dimensions can be squeezed."; + } + } } } + return success(); } @@ -220,42 +364,91 @@ LogicalResult MemRefSqueezeOp::canonicalize(MemRefSqueezeOp op, if (!erase_layout) { return failure(); } - // Push layout erasure through squeezing. It is important we see the layout - // for lowering and don't make it hard for other ops to query it. + auto layout_ref = erase_layout.getOperand(); - MemRefType layout_ty = layout_ref.getType(); + MemRefType layout_ty = getMemRefType(layout_ref); + auto layout_attr = dyn_cast(layout_ty.getLayout()); + if (!layout_attr) { + return failure(); + } + auto source_shape = source_type.getShape(); auto target_shape = target_type.getShape(); - int source_index = source_shape.size() - 1; - int target_index = target_shape.size() - 1; - auto old_layout = dyn_cast(layout_ty.getLayout()); - auto target_strides = old_layout.getTileStrides(); - SmallVector tile_strides(target_strides.begin(), - target_strides.end()); - // We want to remove all strides that correspond to squeezed dimensions and - // update the corresponding output layout. - while (source_index >= 0 || target_index >= 0) { - int target_dim = target_index < 0 ? -1 : target_shape[target_index]; - int source_dim = source_shape[source_index]; - if (source_dim == target_dim) { - source_index--; - target_index--; - } else { - // Source index must be 1 here (otherwise verification will have failed). - // We are safe to mutate the strides vector here because we are looping - // backwards. - tile_strides.erase(tile_strides.begin() + source_index); - source_index--; + auto squeezed_or = computeSqueezedDimsChecked(op, source_shape, target_shape); + if (failed(squeezed_or)) { + return failure(); + } + auto &squeezed = squeezed_or.value(); + if (squeezed.empty() && source_shape != target_shape) { + return failure(); + } + + SmallVector tile_strides = + llvm::to_vector(layout_attr.getTileStrides()); + for (int i = squeezed.size() - 1; i >= 0; --i) { + tile_strides.erase(tile_strides.begin() + squeezed[i]); + } + + tpu::TiledLayoutAttr new_layout; + bool target_is_1d = target_shape.size() == 1; + auto tiles = layout_attr.getTiles(); + if (target_is_1d && tiles.size() == 1) { + auto tile_dims = llvm::to_vector(tiles.front().dimensions()); + int first_tiled = source_shape.size() - tile_dims.size(); + for (int i = squeezed.size() - 1; i >= 0; --i) { + int dim = squeezed[i]; + if (dim >= first_tiled) { + int tile_idx = dim - first_tiled; + if (tile_idx < 0 || tile_idx >= static_cast(tile_dims.size())) { + return op.emitError() << "Internal error: tile index out of bounds."; + } + tile_dims.erase(tile_dims.begin() + tile_idx); + } } + new_layout = tpu::TiledLayoutAttr::get( + op.getContext(), {xla::Tile(tile_dims)}, tile_strides); + } else { + new_layout = tpu::TiledLayoutAttr::get( + op.getContext(), layout_attr.getTiles(), tile_strides); + } + + auto new_ty = MemRefType::get(target_shape, layout_ty.getElementType(), + new_layout, layout_ty.getMemorySpace()); + + auto new_squeeze = + MemRefSqueezeOp::create(rewriter, op.getLoc(), new_ty, layout_ref); + rewriter.replaceOpWithNewOp(op, new_squeeze); + return success(); +} + +LogicalResult RelayoutOp::verify() { + auto in_layout_array_attr = + getOperation()->getAttrOfType("in_layout"); + if (!in_layout_array_attr) { + return emitOpError("missing 'in_layout' attribute"); + } + if (in_layout_array_attr.size() != 1) { + return emitOpError( + "'in_layout' attribute must be an array containing a single " + "VectorLayoutAttr"); + } + if (!isa(in_layout_array_attr[0])) { + return emitOpError("'in_layout' attribute is not a VectorLayoutAttr"); + } + + auto out_layout_array_attr = + getOperation()->getAttrOfType("out_layout"); + if (!out_layout_array_attr) { + return emitOpError("missing 'out_layout' attribute"); + } + if (out_layout_array_attr.size() != 1) { + return emitOpError( + "'out_layout' attribute must be an array containing a single " + "VectorLayoutAttr"); + } + if (!isa(out_layout_array_attr[0])) { + return emitOpError("'out_layout' attribute is not a VectorLayoutAttr"); } - auto new_layout = tpu::TiledLayoutAttr::get( - source_type.getContext(), old_layout.getTiles(), tile_strides); - auto new_result_type = MemRefType::get(op.getResult().getType().getShape(), - layout_ty.getElementType(), new_layout, - layout_ty.getMemorySpace()); - auto squeeze = rewriter.create(op.getLoc(), new_result_type, - layout_ref); - rewriter.replaceOpWithNewOp(op, op.getType(), squeeze); return success(); } @@ -322,6 +515,41 @@ LogicalResult MemRefReshapeOp::verify() { return success(); } +LogicalResult TransposeOp::verify() { + auto source_type = getSourceVectorType(); + auto permutation = getPermutation(); + auto output_type = getResultVectorType(); + auto input_shape = source_type.getShape(); + auto output_shape = output_type.getShape(); + if (source_type.getElementType() != output_type.getElementType()) { + return emitOpError("Expected input and output element types to match"); + } + if (permutation.size() != source_type.getRank()) { + return emitOpError("Expected permutation rank to match input rank"); + } + if (permutation.size() != output_type.getRank()) { + return emitOpError("Expected permutation rank to match output rank"); + } + std::vector seen_dims(source_type.getRank(), false); + for (int64_t dim : permutation) { + if (dim < 0 || dim >= source_type.getRank()) { + return emitOpError("Permutation element out of bounds: ") << dim; + } + if (seen_dims[dim]) { + return emitOpError("Permutation element repeated: ") << dim; + } + seen_dims[dim] = true; + } + for (int i = 0; i < source_type.getRank(); ++i) { + if (input_shape[permutation[i]] != output_shape[i]) { + return emitOpError( + "Expected input shape permuted by the given permutation to match the " + "output shape"); + } + } + return success(); +} + LogicalResult MemRefReshapeOp::canonicalize(MemRefReshapeOp op, PatternRewriter &rewriter) { auto src_ty = op.getInput().getType(); @@ -332,8 +560,7 @@ LogicalResult MemRefReshapeOp::canonicalize(MemRefReshapeOp op, } auto layout_ref = erase_layout_op.getOperand(); auto layout_ty = layout_ref.getType(); - auto layout = - dyn_cast(layout_ty.getLayout()); + auto layout = cast(layout_ty.getLayout()); CHECK(!layout.getTiles().empty()); auto tile = layout.getTiles().front().dimensions(); auto new_tile_strides = ComputeTileStrides(dst_ty, tile); @@ -343,8 +570,8 @@ LogicalResult MemRefReshapeOp::canonicalize(MemRefReshapeOp op, MemRefType::get(dst_ty.getShape(), dst_ty.getElementType(), new_layout, layout_ty.getMemorySpace()); auto reshape = - rewriter.create(op.getLoc(), new_result_ty, layout_ref); - rewriter.replaceOpWithNewOp(op, op.getType(), reshape); + MemRefReshapeOp::create(rewriter, op.getLoc(), new_result_ty, layout_ref); + rewriter.replaceOpWithNewOp(op, reshape); return success(); } @@ -361,8 +588,8 @@ LogicalResult MemRefBitcastOp::verify() { if (src_ty.getRank() <= 1) { return emitOpError("Not implemented: 1d memref bitcast."); } - auto src_bitwidth = src_ty.getElementTypeBitWidth(); - auto tgt_bitwidth = tgt_ty.getElementTypeBitWidth(); + auto src_bitwidth = getElementTypeBitwidth(src_ty); + auto tgt_bitwidth = getElementTypeBitwidth(tgt_ty); for (int i = 0; i < src_ty.getRank(); ++i) { auto src_dim_size = src_ty.getDimSize(i); auto tgt_dim_size = tgt_ty.getDimSize(i); @@ -417,8 +644,8 @@ LogicalResult MemRefBitcastOp::canonicalize(MemRefBitcastOp op, if (!erase_layout_op) { return failure(); } - auto src_bitwidth = src_ty.getElementTypeBitWidth(); - auto tgt_bitwidth = dst_ty.getElementTypeBitWidth(); + auto src_bitwidth = getElementTypeBitwidth(src_ty); + auto tgt_bitwidth = getElementTypeBitwidth(dst_ty); auto layout_ref = erase_layout_op.getOperand(); auto layout_ty = layout_ref.getType(); auto layout = cast(layout_ty.getLayout()); @@ -427,8 +654,8 @@ LogicalResult MemRefBitcastOp::canonicalize(MemRefBitcastOp op, if (tile[0] * src_bitwidth % tgt_bitwidth != 0) { return failure(); } - SmallVector new_tiles = - {xla::Tile({tile[0] * src_bitwidth / tgt_bitwidth, 128})}; + SmallVector new_tiles = { + xla::Tile({tile[0] * src_bitwidth / tgt_bitwidth, 128})}; if (tgt_bitwidth < 32) { new_tiles.push_back(xla::Tile({32 / tgt_bitwidth, 1})); } @@ -438,14 +665,14 @@ LogicalResult MemRefBitcastOp::canonicalize(MemRefBitcastOp op, MemRefType::get(dst_ty.getShape(), dst_ty.getElementType(), new_layout, layout_ty.getMemorySpace()); auto bitcast = - rewriter.create(op.getLoc(), new_result_ty, layout_ref); - rewriter.replaceOpWithNewOp(op, op.getType(), bitcast); + MemRefBitcastOp::create(rewriter, op.getLoc(), new_result_ty, layout_ref); + rewriter.replaceOpWithNewOp(op, bitcast); return success(); } template LogicalResult verifyStridedOp(Op op, MemRefType memref_ty, - VectorType vector_ty) { + VectorType vector_ty, int64_t min_stride) { auto indices = op.getIndices(); auto strides = op.getStrides(); if (memref_ty.getRank() != indices.size()) { @@ -464,8 +691,9 @@ LogicalResult verifyStridedOp(Op op, MemRefType memref_ty, return failure(); } for (int64_t i = 0; i < memref_ty.getRank(); ++i) { - if (strides[i] < 1) { - op.emitError("Strides[") << i << "]=" << strides[i] << " must be >= 1"; + if (strides[i] < min_stride) { + op.emitError("Strides[") << i << "]=" << strides[i] << " must be >= " + << min_stride; return failure(); } } @@ -474,45 +702,190 @@ LogicalResult verifyStridedOp(Op op, MemRefType memref_ty, LogicalResult StridedLoadOp::verify() { return verifyStridedOp(*this, getMemRefType(getBase()), - getType()); + getType(), /*min_stride=*/0); } LogicalResult StridedStoreOp::verify() { return verifyStridedOp(*this, getMemRefType(getBase()), - getValueToStore().getType()); + getValueToStore().getType(), + /*min_stride=*/1); +} + +template +LogicalResult verifyStoreOp(Op op) { + MemRefType ref_ty = op.getBase().getType(); + if (!HasMemorySpace(ref_ty, MemorySpace::kVmem)) { + return op.emitOpError("Expected base memref to be in VMEM."); + } + VectorType value_ty = op.getValueToStore().getType(); + if (value_ty.getElementType() != ref_ty.getElementType()) { + return op.emitOpError( + "Expected base and valueToStore element type to match"); + } + if (op.getMask()) { + if (getElementTypeBitwidth(value_ty) != 32) { + return op.emitError( + "Not implemented: masked store with non-32-bit element type"); + } + if (value_ty.getShape() != op.getMask().getType().getShape()) + return op.emitOpError("Expected mask shape to match result shape: (") + << value_ty.getShape() << "). Got: (" + << op.getMask().getType().getShape() << ")."; + } + return success(); } LogicalResult VectorStoreOp::verify() { if (!getStrides().empty()) { return emitError("Not implemented: general vector store with strides."); } - VectorType value_ty = getValueToStore().getType(); MemRefType ref_ty = getBase().getType(); + if (llvm::size(getIndices()) != ref_ty.getRank()) { + return emitOpError("Expected ") << ref_ty.getRank() << " indices."; + } + return verifyStoreOp(*this); +} + +LogicalResult VectorStoreOp::canonicalize(VectorStoreOp op, + PatternRewriter& rewriter) { + return propagateTiledLayoutToConsumer(op, rewriter); +} +template +LogicalResult verifyLoadOp(Op op) { + MemRefType ref_ty = op.getBase().getType(); + if (!HasMemorySpace(ref_ty, MemorySpace::kVmem)) { + return op.emitOpError("Expected base memref to be in VMEM."); + } + VectorType value_ty = op.getResult().getType(); if (value_ty.getElementType() != ref_ty.getElementType()) { - return emitOpError( - "Expected base and valueToStore element type should match"); + return op.emitOpError("Expected base and result element type to match."); } + if (op.getMask()) { + if (getElementTypeBitwidth(value_ty) != 32) { + return op.emitError( + "Not implemented: masked load with non-32-bit element type"); + } + if (vector::isBroadcastableTo(op.getMask().getType(), value_ty) != + vector::BroadcastableToResult::Success) { + return op.emitOpError( + "Expected mask shape to be broadcastable to result shape."); + } + } + return success(); +} + +LogicalResult VectorLoadOp::verify() { + const MemRefType ref_ty = getBase().getType(); if (llvm::size(getIndices()) != ref_ty.getRank()) { - return emitOpError("Expected ") << ref_ty.getRank() << " indices"; + return emitOpError("Expected ") << ref_ty.getRank() << " indices."; } - if (getMask()) { - if (value_ty.getElementTypeBitWidth() != 32) { - return emitError( - "Not implemented: masked store with non-32-bit element type"); + if (!getStrides().empty()) { + if (llvm::size(getStrides()) != ref_ty.getRank()) { + return emitOpError("Expected ") << ref_ty.getRank() << " strides."; } - if (value_ty.getShape() != getMask().getType().getShape()) - return emitOpError("Expected valueToStore shape to match mask shape"); + return emitError("Not implemented: general vector load with strides."); } - return success(); + return verifyLoadOp(*this); +} + +LogicalResult VectorLoadOp::canonicalize(VectorLoadOp op, + PatternRewriter& rewriter) { + return propagateTiledLayoutToConsumer(op, rewriter); +} + +LogicalResult VectorLoadIdxOp::verify() { + VectorType value_ty = getResult().getType(); + MemRefType ref_ty = getBase().getType(); + if (llvm::size(getIndices()) != ref_ty.getRank()) { + return emitOpError( + "Expected one index vector for each dimension of the base " + "memref with dimension: ") + << ref_ty.getRank() << ". Got: " << llvm::size(getIndices()) << "."; + } + for (const auto [i, index] : llvm::enumerate(getIndices())) { + VectorType index_ty = llvm::cast(index.getType()); + if (index_ty.getShape() != value_ty.getShape()) { + return emitOpError("Expected ") + << value_ty.getShape() << " elements in indices. Got " + << index_ty.getShape() << " in index #" << i << "."; + } + } + return verifyLoadOp(*this); +} + +LogicalResult VectorLoadIdxOp::canonicalize(VectorLoadIdxOp op, + PatternRewriter& rewriter) { + return propagateTiledLayoutToConsumer(op, rewriter); +} + +LogicalResult VectorStoreIdxOp::verify() { + VectorType value_ty = getValueToStore().getType(); + MemRefType ref_ty = getBase().getType(); + if (llvm::size(getIndices()) != ref_ty.getRank()) { + return emitOpError( + "Expected one index vector for each dimension of the base " + "memref with dimension: ") + << ref_ty.getRank() << ". Got: " << llvm::size(getIndices()) << "."; + } + if (value_ty.getRank() != 1) { + return emitOpError("Expected value to have rank 1. Got: ") + << value_ty.getRank() << "."; + } + for (const auto [i, index] : llvm::enumerate(getIndices())) { + VectorType index_ty = llvm::cast(index.getType()); + if (index_ty.getShape() != value_ty.getShape()) { + return emitOpError("Expected ") + << value_ty.getShape() << " elements in indices. Got " + << index_ty.getShape() << " in index #" << i << "."; + } + } + return verifyStoreOp(*this); +} + +LogicalResult VectorStoreIdxOp::canonicalize(VectorStoreIdxOp op, + PatternRewriter& rewriter) { + return propagateTiledLayoutToConsumer(op, rewriter); } LogicalResult ReinterpretCastOp::verify() { auto source_type = getMemRefType(getInput()); auto target_type = getType(); - return success( - source_type.getMemorySpace() && // Require memory space annotations. - source_type.getMemorySpace() == target_type.getMemorySpace()); + if (source_type.getMemorySpace() != target_type.getMemorySpace()) { + return emitOpError("Source and target memory spaces must match, but got ") + << source_type.getMemorySpace() << " and " + << target_type.getMemorySpace(); + } + return success(); +} + +LogicalResult ReinterpretCastOp::canonicalize(ReinterpretCastOp op, + PatternRewriter& rewriter) { + if (auto erase_layout_op = op.getInput().getDefiningOp()) { + rewriter.modifyOpInPlace(op, [&]() { + op.getInputMutable().assign(erase_layout_op.getOperand()); + }); + return success(); + } + return failure(); +} + +LogicalResult EraseLayoutOp::inferReturnTypes( + MLIRContext* context, std::optional location, + EraseLayoutOp::Adaptor adaptor, + ::llvm::SmallVectorImpl& inferredReturnTypes) { + inferredReturnTypes.push_back( + MemRefType::Builder(cast(adaptor.getOperand().getType())) + .setLayout(nullptr)); + return success(); +} + +OpFoldResult EraseLayoutOp::fold(FoldAdaptor op) { + // If the operand has no interesting layout then there's no need to erase it. + if (getOperand().getType().getLayout().isIdentity()) { + return op.getOperand(); + } + return OpFoldResult(); } template @@ -549,6 +922,32 @@ LogicalResult DynamicRotateOp::verify() { return verifyRotateOp(*this); } +LogicalResult ScanCountOp::inferReturnTypes( + MLIRContext *context, std::optional location, + ScanCountOp::Adaptor adaptor, + ::llvm::SmallVectorImpl &inferredReturnTypes) { + inferredReturnTypes.push_back(adaptor.getInMask().getType()); + inferredReturnTypes.push_back(VectorType::get( + cast(adaptor.getValues().getType()).getShape(), + IntegerType::get(context, 32))); + return success(); +} + +LogicalResult IotaOp::verify() { + const int64_t rank = getType().getRank(); + SmallVector seen(rank, false); + for (const int32_t dim : getDimensions()) { + if (dim < 0 || dim >= getType().getRank()) { + return emitOpError("Invalid dimension: ") << dim; + } + if (seen[dim]) { + return emitOpError("Dimensions must be unique"); + } + seen[dim] = true; + } + return success(); +} + // a + matmul(l, r, 0) == matmul(l, r, a) template class CanonicalizeAddOfMatmul : public OpRewritePattern { @@ -589,7 +988,7 @@ LogicalResult MatmulOp::verify() { return emitOpError( "Not implemented: matmul acc and result have different types"); } - if (acc_ty.getElementTypeBitWidth() != 32) { + if (getElementTypeBitwidth(acc_ty) != 32) { return emitOpError("Expected matmul acc to be 32-bit"); } @@ -624,6 +1023,15 @@ LogicalResult MatmulOp::verify() { auto rhs_non_contracting_dims = dimension_numbers.getRhsNonContractingDims(); + if (!llvm::is_sorted(lhs_non_contracting_dims)) { + emitOpError("Not implemented: lhs non contracting dims must be sorted"); + return failure(); + } + if (!llvm::is_sorted(rhs_non_contracting_dims)) { + emitOpError("Not implemented: rhs non contracting dims must be sorted"); + return failure(); + } + if (lhs_contracting_dims.size() + lhs_non_contracting_dims.size() + lhs_batch_dims.size() != lhs_ty.getShape().size()) { @@ -716,15 +1124,8 @@ LogicalResult MatmulOp::verify() { const std::optional batch_dim_rhs = rhs_batch_dims.empty() ? std::nullopt : std::optional(rhs_batch_dims[0]); - if (batch_dim_lhs != batch_dim_rhs) { - emitOpError("Not Implemented: batch dims must be equal"); - return failure(); - } - if (batch_dim_lhs.has_value() && (batch_dim_lhs.value() != 0)) { - emitOpError("Not Implemented: batch dims pos must be 0"); - return failure(); - } - // Invariant above enforces only 1 batch dim atm, and that both are eq + + // Invariant above enforces only 1 batch dim atm. std::optional batch_size = std::nullopt; if (batch_dim_lhs.has_value()) { batch_size = lhs_ty.getShape()[batch_dim_lhs.value()]; @@ -744,117 +1145,35 @@ LogicalResult MatmulOp::verify() { "Illegal: output dim order must have an even number of elements."); return failure(); } - if (batch_size.has_value()) { - if (output_dim_order[0] != 0 || output_dim_order[1] != 0) { - emitOpError( - "Not implemented: Output with batch size must be the lhs 0 idx for " - "now."); - return failure(); - } - } - // Invariants above enforce a single batch idx for now, and that it is in - // position 0. Future extensions to this will be to: - // 1. Support multiple batch dims - // 2. Support batch dims in any position in the output dim order - if (lhs_non_contracting_dims.size() != 1) { - emitOpError( - "Not implemented: lhs non contracting dims must be of size 1"); - return failure(); - } - if (rhs_non_contracting_dims.size() != 1) { - emitOpError( - "Not implemented: rhs non contracting dims must be of size 1"); - return failure(); + // Invariants above enforce a single batch idx for now. Future extension to + // this will be to support multiple batch dims. + + // Verify that the output dim order is always in the form of [0, + // lhs_batch_dims, 0, lhs_non_contracting_dims, 1, + // rhs_non_contracting_dims]. + llvm::SmallVector expected_output_dim_order; + expected_output_dim_order.reserve(2 * (lhs_batch_dims.size() + + lhs_non_contracting_dims.size() + + rhs_non_contracting_dims.size())); + for (int64_t dim : lhs_batch_dims) { + expected_output_dim_order.push_back(0); + expected_output_dim_order.push_back(dim); } - - // A bit long winded, but the invariants we enforce below are: - // 1. The output order idx is 0 (lhs) or 1 (rhs) - // 2. The output dim order is in valid bounds - // 3. We saw the rhs and lhs non contracting dims in the output dim order - // 4. We never see the contracting dims in the output dim order - // 5. We only see each of the non contracting dim once - std::vector lhs_dims_seen_in_output(lhs_rank, false); - std::vector rhs_dims_seen_in_output(rhs_rank, false); - - // Iterate over the output dimension order - for (int dim_pos = 0; dim_pos < output_dim_order.size(); dim_pos += 2) { - auto idx = output_dim_order[dim_pos]; - auto dim = output_dim_order[dim_pos + 1]; - - if (idx != 0 && idx != 1) { - emitOpError("Illegal: output dim order index must be 0 or 1"); - return failure(); - } - auto is_lhs = (idx == 0); - - if (is_lhs) { - if (dim < 0 || dim >= lhs_rank) { - emitOpError("Illegal: lhs dimension index out of bounds"); - return failure(); - } - if (lhs_dims_seen_in_output[dim]) { - emitOpError("Illegal: lhs dimension ") - << dim << " appears more than once in output dim order"; - return failure(); - } - if (dim == lhs_contracting_dim) { - emitOpError("Illegal: contracting dimension ") - << dim << " appears in lhs output dim order"; - return failure(); - } - // batch_dim_lhs is either 0 or nullopt - if (dim == batch_dim_lhs) { - // Upstream invariants enforce that batch dim is in position 0 - // of the output dim order. - rhs_dims_seen_in_output[dim] = true; - } - lhs_dims_seen_in_output[dim] = true; - } else { - if (dim < 0 || dim >= rhs_rank) { - emitOpError("Illegal: rhs dimension index out of bounds"); - return failure(); - } - if (rhs_dims_seen_in_output[dim]) { - emitOpError("Illegal: rhs dimension ") - << dim << " appears more than once in output dim order"; - return failure(); - } - if (dim == rhs_contracting_dim) { - emitOpError("Illegal: contracting dimension ") - << dim << " appears in rhs output dim order"; - return failure(); - } - if (dim == batch_dim_rhs) { - // Upstream invariants enforce that batch dim is in position 0 - // of the output dim order. - lhs_dims_seen_in_output[dim] = true; - } - rhs_dims_seen_in_output[dim] = true; - } + for (int64_t dim : lhs_non_contracting_dims) { + expected_output_dim_order.push_back(0); + expected_output_dim_order.push_back(dim); } - - // Check that all dims have been seen (except contracting dims) - for (int i = 0; i < lhs_rank; ++i) { - if (i == lhs_contracting_dim) { - continue; - } - if (!lhs_dims_seen_in_output[i]) { - emitOpError("Illegal: lhs non-contracting dimension ") - << i << " is not seen in output dim order"; - return failure(); - } + for (int64_t dim : rhs_non_contracting_dims) { + expected_output_dim_order.push_back(1); + expected_output_dim_order.push_back(dim); } - - for (int i = 0; i < rhs_rank; ++i) { - if (i == rhs_contracting_dim) { - continue; - } - if (!rhs_dims_seen_in_output[i]) { - emitOpError("Illegal: rhs non-contracting dimension ") - << i << " is not seen in output dim order"; - return failure(); - } + if (!absl::c_equal(output_dim_order, expected_output_dim_order)) { + emitOpError( + "Illegal: output dim order must be in the form of [0, " + "lhs_batch_dims, 0, lhs_non_contracting_dims, 1, " + "rhs_non_contracting_dims]"); + return failure(); } } return success(); @@ -869,13 +1188,98 @@ void MatmulOp::getCanonicalizationPatterns(RewritePatternSet &patterns, LogicalResult MaskCastOp::verify() { auto input_ty = getInput().getType(); auto output_ty = getResult().getType(); - return success(input_ty.getElementType() == output_ty.getElementType() && - output_ty.getRank() == 3 && - (input_ty.getRank() == 2 || - (input_ty.getRank() == 3 && - input_ty.getDimSize(2) < output_ty.getDimSize(2))) && - input_ty.getShape().take_front(2) == - output_ty.getShape().take_front(2)); + return success(input_ty.getShape().take_front(2) == + output_ty.getShape().take_front(2)); +} + +LogicalResult ScanOp::verify() { + FailureOr issuing_core = GetCoreTypeOfParentFunc(**this); + if (failed(issuing_core)) { + return issuing_core; + } + if (issuing_core != CoreType::kScVectorSubcore) { + return emitOpError("Scan is supported only on the SC vector subcore"); + } + + VectorType input_ty = getInput().getType(); + VectorType output_ty = getOutput().getType(); + + if (input_ty.getElementType().isInteger(1)) { + if (!output_ty.getElementType().isInteger(32)) { + return emitOpError( + "Output element type must be i32 vector for i1 vector inputs."); + } + } else { + if (input_ty.getElementType() != output_ty.getElementType()) { + return emitOpError("Input and output element type mismatch."); + } + } + + if (input_ty.getShape() != output_ty.getShape()) { + return emitOpError("Input and output shape mismatch. Input shape: (") + << input_ty.getShape() << "). Output shape: (" + << output_ty.getShape() << ")."; + } + + if (input_ty.getRank() > 2) { + return emitOpError("Input must be a rank 1 or 2 vector."); + } + + if (input_ty.getElementType().isInteger(1) && + getKind() != ReductionKind::kSum) { + return emitOpError("Only sum reduction is supported for i1 vector inputs."); + } else if (getKind() != ReductionKind::kSum && + getKind() != ReductionKind::kMax && + getKind() != ReductionKind::kMin) { + return emitOpError("Only sum, max and min reductions are supported."); + } + + if (getMask() == nullptr) { + return success(); + } else if (input_ty.getElementType().isInteger(1)) { + return emitOpError("Mask is not supported for i1 vector inputs."); + } + + VectorType mask_ty = getMask().getType(); + if (mask_ty.getRank() != 1) { + return emitOpError("Mask must be a rank 1 vector."); + } + if (mask_ty.getShape()[0] != input_ty.getShape()[input_ty.getRank() - 1]) { + return emitOpError("Mask and input mismatch. Expected mask of length: ") + << input_ty.getShape()[input_ty.getRank() - 1] << ", but got " + << mask_ty.getShape()[0] << "."; + } + + return success(); +} + +LogicalResult SortOp::verify() { + VectorType keys_ty = getKeys().getType(); + VectorType values_ty = getValues().getType(); + if (keys_ty.getShape() != values_ty.getShape()) { + return emitOpError("Key and value shapes must match: ") + << keys_ty.getShape() << " vs " << values_ty.getShape(); + } + if (getMask()) { + VectorType mask_ty = getMask().getType(); + if (keys_ty.getShape() != mask_ty.getShape()) { + return emitOpError("Key and input mask shapes must match: ") + << keys_ty.getShape() << " vs " << mask_ty.getShape(); + } + } + VectorType output_mask_ty = getOutputMask().getType(); + if (keys_ty.getShape() != output_mask_ty.getShape()) { + return emitOpError("Key and output mask shapes must match: ") + << keys_ty.getShape() << " vs " << output_mask_ty.getShape(); + } + if (keys_ty != getSortedKeys().getType()) { + return emitOpError("Key and sorted_key types must match: ") + << keys_ty << " vs " << getSortedKeys().getType(); + } + if (values_ty != getSortedValues().getType()) { + return emitOpError("Value and sorted_value types must match: ") + << values_ty << " vs " << getSortedValues().getType(); + } return success(); } @@ -947,42 +1351,295 @@ LogicalResult EnqueueDMAOp::verify() { if (target_sem_type.getRank() != 0) { return emitOpError("DMA target semaphore must be rank 0"); } + auto source_ty = getMemRefType(getSource()); + auto target_ty = getMemRefType(getTarget()); + if (source_ty.getElementType() != target_ty.getElementType()) { + return emitOpError("DMA source and target element type mismatch"); + } + if (source_ty.getShape() != target_ty.getShape()) { + return emitOpError("DMA source and target shape mismatch."); + } + if (getDeviceId() || getCoreId()) { if (!getSourceSemaphore()) { return emitOpError( - "DMA source semaphore must be specified when " - "device_id or core_id is specified"); + "DMA source semaphore must be specified when device_id or core_id is " + "specified"); } } + bool is_remote = getDeviceId() || getCoreId(); if (getSourceSemaphore()) { - if (!getDeviceId() && !getCoreId()) { + if (!is_remote) { return emitOpError( "DMA destination device_id or core_id must be specified when source " "semaphore is specified"); } } + int priority = getPriority(); + if (priority < 0 || priority > 1) { + return emitOpError( + "Not implemented: only support priority 0 or 1, but got ") + << priority; + } + if (priority != 0 && is_remote) { + return emitOpError( + "Not implemented: non-zero priority is not supported for remote DMA"); + } + FailureOr issuing_core = GetCoreTypeOfParentFunc(**this); + if (failed(issuing_core)) { + return issuing_core; + } + if (getStrictOrdering() && *issuing_core != CoreType::kScScalarSubcore && + *issuing_core != CoreType::kScVectorSubcore) { + return emitOpError( + "Strict ordering is only supported on the SC scalar and vector " + "subcores"); + } return success(); } -// TODO(mvoz): Remove once a month has passed. b/395630795 +LogicalResult EnqueueDMAOp::canonicalize(EnqueueDMAOp op, + PatternRewriter& rewriter) { + return propagateTiledLayoutToConsumer(op, rewriter); +} + +LogicalResult EnqueueIndirectDMAOp::verifyGather( + MemRefType operand_ty, ArrayRef offsets_shape, + MemRefType result_ty) { + // We've already thrown an error if the target is not VMEM. so this is just a + // sanity check. + CHECK(HasMemorySpace(result_ty, MemorySpace::kVmem)); + uint64_t offsets_rank = offsets_shape.size(); + // Slice [o0, .., on] out of [o0, .., on, s0, .., sm]. + ArrayRef result_offset_dims = + result_ty.getShape().take_front(offsets_rank); + // Slice [s0, .., sm] out of [o0, .., on, s0, .., sm]. + ArrayRef result_slice_dims = + result_ty.getShape().drop_front(offsets_rank); + // Slice [s0, .., sm] out of [z0, .., zn, s0, .., sm]. + ArrayRef operand_slice_dims = + operand_ty.getShape().drop_front(offsets_rank); + uint64_t slice_rank = operand_slice_dims.size(); + + const std::string result_shape_str = + absl::StrJoin(result_ty.getShape(), ", "); + + // Make sure that the output shape is such that there is one output slice per + // offset. + // offsets shape : [o0, .., on] + // result shape : [o'0, .., o'n, s0, .., sm] + // [o0, .., on] == [o'0, .., o'n] + if (!absl::c_equal(offsets_shape, result_offset_dims)) { + return emitOpError("Offsets shape (") + << absl::StrJoin(offsets_shape, ", ") + << ") must match the majormost dimensions of the target (gather " + "result) shape (" + << result_shape_str << ")"; + } + + // At each offset, we are copying an ND slice of data. Make sure that the + // slice shape is the same in the operand and the output for the gather, and + // in the updates and the operand for the scatter. + // Operand shape : [z0, .., zn, s0, .., sm] + // Result shape : [o0, .., on, s'0, .., s'm] + // [s0, .., sm] == [s'0, .., s'm] + if (!absl::c_equal(operand_slice_dims, result_slice_dims)) { + const std::string plural = slice_rank == 1 ? "" : "s"; + return emitOpError(absl::StrFormat( + "%d minormost dimension%s of the source (gather operand) shape (%s) " + "must match the minormost dimension%s of the target (gather result) " + "shape (%s)", + slice_rank, plural, absl::StrJoin(operand_ty.getShape(), ", "), plural, + result_shape_str)); + } + return success(); +} + +LogicalResult EnqueueIndirectDMAOp::verifyScatter( + MemRefType updates_ty, ArrayRef offsets_shape, + MemRefType operand_ty) { + // We've already thrown an error if the source is not VMEM. so this is just a + // sanity check. + CHECK(HasMemorySpace(updates_ty, MemorySpace::kVmem)); + uint64_t offsets_rank = offsets_shape.size(); + // Slice [o0, .., on] out of [o0, .., on, s0, .., sm]. + ArrayRef updates_offset_dims = + updates_ty.getShape().take_front(offsets_rank); + // Slice [s0, .., sm] out of [o0, .., on, s0, .., sm]. + ArrayRef updates_slice_dims = + updates_ty.getShape().drop_front(offsets_rank); + // Slice [s0, .., sm] out of [z0, .., zn, s0, .., sm]. + ArrayRef operand_slice_dims = + operand_ty.getShape().drop_front(offsets_rank); + uint64_t slice_rank = operand_slice_dims.size(); + + const std::string updates_shape_str = + absl::StrJoin(updates_ty.getShape(), ", "); + + // Make sure that there is one slice of updates per offset + // offsets shape : [o0, .., on] + // updates shape : [o'0, .., o'n, s0, .., sm] + // [o0, .., on] == [o'0, .., o'n] + if (!absl::c_equal(offsets_shape, updates_offset_dims)) { + return emitOpError("Offsets shape (") + << absl::StrJoin(offsets_shape, ", ") + << ") must match the majormost dimensions of the source " + "(scatter updates) shape (" + << updates_shape_str << ")"; + } + + // At each offset, we are copying an ND slice of data. Make sure that the + // slice shape is the same in the operand and the output for the gather, and + // in the updates and the operand for the scatter. + // Updates shape : [o0, .., on, s0, .., sm] + // Operand shape : [z0, .., zn, s'0, .., s'm] + // [s0, .., sm] == [s'0, .., s'm] + if (!absl::c_equal(operand_slice_dims, updates_slice_dims)) { + const std::string plural = slice_rank == 1 ? "" : "s"; + return emitOpError(absl::StrFormat( + "%d minormost dimension%s of the source (scatter updates) shape (%s) " + "must match the minormost dimension%s of the target (scatter operand) " + "shape (%s)", + slice_rank, plural, updates_shape_str, plural, + absl::StrJoin(operand_ty.getShape(), ", "))); + } + return success(); +} + +namespace { +bool hasHbmOrVmemSharedMemorySpace(MemRefType ty) { + return HasMemorySpace(ty, MemorySpace::kHbm) || + HasMemorySpace(ty, MemorySpace::kVmemShared); +} + +FailureOr isGather(Operation &op, Value source, Value target) { + const MemRefType source_ty = getMemRefType(source); + const MemRefType target_ty = getMemRefType(target); + if (hasHbmOrVmemSharedMemorySpace(source_ty) && + HasMemorySpace(target_ty, MemorySpace::kVmem)) { + return true; + } + if (HasMemorySpace(source_ty, MemorySpace::kVmem) && + hasHbmOrVmemSharedMemorySpace(target_ty)) { + return false; + } + return op.emitOpError( + "The transfer must be between HBM and VMEM, or between VMEM_SHARED and " + "VMEM"); +} +} // namespace + +FailureOr EnqueueIndirectDMAOp::isGather() { + return mlir::tpu::isGather(*getOperation(), getSource(), getTarget()); +} + +LogicalResult EnqueueIndirectDMAOp::verify() { + FailureOr issuing_core = GetCoreTypeOfParentFunc(**this); + if (failed(issuing_core)) { + return issuing_core; + } + if (issuing_core != CoreType::kScVectorSubcore) { + return emitOpError( + "Enqueue indirect DMA is supported only on the SC vector subcore"); + } + + const MemRefType source_ty = getMemRefType(getSource()); + const MemRefType target_ty = getMemRefType(getTarget()); + + if (source_ty.getElementType() != target_ty.getElementType()) { + return emitOpError("Source and target element type mismatch"); + } + + FAILUREOR_ASSIGN_OR_RETURN(bool is_gather, isGather()); + + const Value offsets = getOffsets(); + ArrayRef offsets_shape; + if (auto offsets_ty = dyn_cast(offsets.getType()); + offsets_ty != nullptr) { + if (!HasMemorySpace(offsets_ty, MemorySpace::kVmem)) { + return emitOpError("Offsets memref must be in VMEM"); + } + offsets_shape = offsets_ty.getShape(); + } else if (auto offsets_ty = dyn_cast(offsets.getType()); + offsets_ty != nullptr) { + offsets_shape = offsets_ty.getShape(); + } else { + return emitOpError("Offsets must be a memref or vector type"); + } + + if (MemRefType sem_ty = getMemRefType(getSemaphore()); + sem_ty.getRank() != 0) { + return emitOpError("Semaphore must be rank 0"); + } + + if (is_gather) { + return verifyGather(/*operand_ty=*/source_ty, + /*offsets_shape=*/offsets_shape, + /*result_ty=*/target_ty); + } + return verifyScatter(/*updates_ty=*/source_ty, + /*offsets_shape=*/offsets_shape, + /*operand_ty=*/target_ty); +} + +LogicalResult EnqueueIndirectDMAOp::canonicalize(EnqueueIndirectDMAOp op, + PatternRewriter& rewriter) { + return propagateTiledLayoutToConsumer(op, rewriter); +} + +// TODO(b/395630795): Remove after 2025-08-10. LogicalResult WaitDMAOp::verify() { auto sem_type = getMemRefType(getSemaphore()); if (sem_type.getRank() != 0) { - emitOpError("DMA wait semaphore must be rank 0"); - return failure(); + return emitOpError("DMA wait semaphore must be rank 0"); } return success(); } +void WaitDMA2Op::build(OpBuilder &builder, OperationState &state, + Value semaphore, Value src, Value dst) { + build(builder, state, semaphore, src, dst, /*device_id=*/nullptr, + /*core_id=*/nullptr); +} + LogicalResult WaitDMA2Op::verify() { auto sem_type = getMemRefType(getSemaphore()); if (sem_type.getRank() != 0) { - emitOpError("DMA wait semaphore must be rank 0"); - return failure(); + return emitOpError("DMA wait semaphore must be rank 0"); } return success(); } +LogicalResult WaitDMA2Op::canonicalize(WaitDMA2Op op, + PatternRewriter& rewriter) { + return propagateTiledLayoutToConsumer(op, rewriter); +} + +FailureOr WaitIndirectDMAOp::isGather() { + return mlir::tpu::isGather(*getOperation(), getSrc(), getDst()); +} + +LogicalResult WaitIndirectDMAOp::verify() { + FailureOr issuing_core = GetCoreTypeOfParentFunc(**this); + if (failed(issuing_core)) { + return issuing_core; + } + if (*issuing_core != CoreType::kScVectorSubcore) { + return emitOpError( + "Wait indirect DMA is supported only on the SC vector subcore"); + } + MemRefType sem_type = getMemRefType(getSemaphore()); + if (sem_type.getRank() != 0) { + return emitOpError("Indirect DMA wait semaphore must be rank 0"); + } + return isGather(); +} + +LogicalResult WaitIndirectDMAOp::canonicalize(WaitIndirectDMAOp op, + PatternRewriter& rewriter) { + return propagateTiledLayoutToConsumer(op, rewriter); +} + LogicalResult RegionOp::verify() { for (auto result_type : getResultTypes()) { if (!isa(result_type)) { @@ -1084,7 +1741,7 @@ LogicalResult ConcatenateOp::verify() { if (getOperands().size() < 2) { return emitOpError("Expected at least 2 operands for concatenate op."); } - auto first_type = getOperand(0).getType().cast(); + auto first_type = cast(getOperand(0).getType()); auto first_shape = first_type.getShape(); auto first_dtype = first_type.getElementType(); for (auto operand : getOperands()) { @@ -1110,28 +1767,53 @@ LogicalResult ConcatenateOp::verify() { return success(); } -LogicalResult LogOp::verify() { - FailureOr> logging_core_type_maybe = - GetCoreTypeOfParentFunc(**this); - if (failed(logging_core_type_maybe)) { - return failure(); +/*static*/ LogicalResult ConcatenateOp::inferReturnTypes( + MLIRContext* context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl& inferredReturnTypes) { + ConcatenateOpAdaptor adaptor(operands, attributes, properties, regions); + auto dimension = adaptor.getDimension(); + for (auto [i, operand] : llvm::enumerate(operands)) { + if (auto operand_ty = dyn_cast(operand.getType()); + !operand_ty || operand_ty.getRank() <= dimension) { + return failure(); + } } - CoreType logging_core_type = logging_core_type_maybe->value_or(CoreType::kTc); - if ((logging_core_type == CoreType::kScScalarSubcore || - logging_core_type == CoreType::kScVectorSubcore) && - getFormattedAttr() != nullptr && getFormattedAttr().getValue()) { + auto first_type = cast(operands[0].getType()); + llvm::SmallVector result_shape = + llvm::to_vector(first_type.getShape()); + Type result_dtype = first_type.getElementType(); + for (int i = 1; i < operands.size(); ++i) { + result_shape[dimension] += + cast(operands[i].getType()).getDimSize(dimension); + } + inferredReturnTypes.push_back(VectorType::get(result_shape, result_dtype)); + return success(); +} + +LogicalResult LogOp::verify() { + FailureOr logging_core = GetCoreTypeOfParentFunc(**this); + if (failed(logging_core)) { + return logging_core; + } + bool is_sc_core = *logging_core == CoreType::kScScalarSubcore || + *logging_core == CoreType::kScVectorSubcore; + if (is_sc_core && getFormattedAttr() != nullptr && + getFormattedAttr().getValue()) { return emitOpError("Formatted logging is not supported on SC"); } - switch (logging_core_type) { - case CoreType::kTc: - case CoreType::kScScalarSubcore: - return success(); - case CoreType::kScVectorSubcore: - return emitOpError("Log op is not supported on the SC vector subcore"); + if (is_sc_core && getInputs().size() > 1) { + return emitOpError("SC logging only supports 0 or 1 inputs"); } - return emitOpError( - absl::StrFormat("Unexpected core type: %s", - stringifyCoreType(logging_core_type_maybe->value()))); + if (*logging_core == CoreType::kScScalarSubcore) { + for (mlir::Value input : getInputs()) { + if (llvm::isa(input.getType())) { + return emitOpError( + "SC scalar subcore does not support logging vectors"); + } + } + } + return success(); } LogicalResult WeirdOp::verify() { @@ -1176,6 +1858,56 @@ LogicalResult ReciprocalOp::verify() { return success(); } +LogicalResult UnpackSubelementsOp::verify() { + const int packing_factor = getElementTypeBitwidth(getType()) / + getElementTypeBitwidth(getSource().getType()); + if (auto index = getIndex(); index >= packing_factor) { + return emitOpError("Index must be between 0 and the packing factor (") + << packing_factor << "), got " << index; + } + return success(); +} + +LogicalResult UnpackSubelementsOp::canonicalize(UnpackSubelementsOp op, + PatternRewriter& rewriter) { + auto src_elem_ty = op.getSource().getType().getElementType(); + auto dst_elem_ty = op.getType().getElementType(); + if (!src_elem_ty.isSignlessInteger() || !dst_elem_ty.isSignlessInteger()) { + return failure(); + } + if (!op.getSignExtended()) { + // Unpack of pack with the same format is reversible if not sign extended. + if (auto pack = dyn_cast(op.getSource().getDefiningOp()); + pack && pack.getPackFormat() == op.getPackFormat() && + pack.getSources().front().getType() == op.getType()) { + Value source = pack.getPaddedSources( + pack.getSources(), pack.getPositions(), + getElementTypeBitwidth(op.getType()) / + getElementTypeBitwidth(pack.getType()))[op.getIndex()]; + if (source) { + rewriter.replaceAllOpUsesWith(op, source); + return success(); + } + } + return failure(); + } + // Set `sign_extended` to false if it's used by pack that reduces the source + // bitwidth. + for (auto user : op->getUsers()) { + auto pack = dyn_cast(user); + if (!pack) { + return failure(); + } + auto packed_elem_ty = pack.getType().getElementType(); + if (!packed_elem_ty.isSignlessInteger() || + getTypeBitwidth(packed_elem_ty) > getTypeBitwidth(src_elem_ty)) { + return failure(); + } + } + rewriter.modifyOpInPlace(op, [&]() { op.setSignExtended(false); }); + return success(); +} + void PackSubelementsOp::build(OpBuilder &builder, OperationState &state, const VectorType output_type, const ArrayRef padded_sources, @@ -1208,13 +1940,14 @@ LogicalResult PackSubelementsOp::verify() { if (getPositions().size() != getSources().size()) { return emitOpError("Size of sources and positions must match"); } - const int packing_factor = cast(getSources().front().getType()) - .getElementTypeBitWidth() / - getType().getElementTypeBitWidth(); + const int packing_factor = + getElementTypeBitwidth(cast(getSources().front().getType())) / + getElementTypeBitwidth(getType()); SmallVector seen_positions(packing_factor, false); for (const int32_t position : getPositions()) { if (position < 0 || packing_factor <= position) { - return emitOpError("Positions must be between 0 and the packing factor"); + return emitOpError("Positions must be between 0 and the packing factor (") + << packing_factor << "), got " << position; } if (seen_positions[position]) { return emitOpError("Positions must be unique"); @@ -1224,40 +1957,340 @@ LogicalResult PackSubelementsOp::verify() { return success(); } -LogicalResult DynamicGatherOp::verify() { - if (getSource().getType() != getType()) { - return emitOpError("Expected source and result types must match"); +namespace { +LogicalResult verifyElementwisePacking(Operation *op, Type unpacked_ty, + Type packed_ty) { + if (unpacked_ty.isF32() && !packed_ty.isBF16()) { + return op->emitOpError( + "Only packing/unpacking between f32 and bf16 is supported for floats"); + } + if (unpacked_ty.isSignlessInteger(32) && + !packed_ty.isSignlessInteger(16) && + !packed_ty.isSignlessInteger(8) && + !packed_ty.isSignlessInteger(4)) { + return op->emitOpError( + "Only packing/unpacking between i32 and i16/i8/i4 is supported for " + "integers"); } - if (getIndices().getType().getShape() != getIndices().getType().getShape()) { - return emitOpError("Expected indices and result shapes must match"); + return success(); +} +} // namespace + +LogicalResult PackElementwiseOp::verify() { + if (getSources().empty()) { + return emitOpError("At least one source is required"); } - if (!getIndices().getType().getElementType().isInteger(32)) { - return emitOpError("Not implemented: Only i32 indices supported"); + const auto src_vty = cast(getSources().front().getType()); + if (getElementTypeBitwidth(src_vty) != getElementTypeBitwidth(getType())) { + return emitOpError("All sources must have the same bitwidth as the result"); + } + if (!getType().getElementType().isSignlessInteger()) { + return emitOpError("Output type must be a signless integer type"); + } + + auto src_elem_ty = src_vty.getElementType(); + auto tgt_elem_ty = getTargetType(); + if (!(src_elem_ty.isF32() && tgt_elem_ty.isBF16()) && + !(src_elem_ty.isSignlessInteger() && tgt_elem_ty.isSignlessInteger())) { + return emitOpError( + "Only packing f32 -> bf16 and integer -> integer is supported"); + } + const int packing_factor = + getElementTypeBitwidth(src_vty) / getTypeBitwidth(getTargetType()); + if (packing_factor != getSources().size()) { + return emitOpError("The number of sources must match the packing factor (") + << packing_factor << "), got " << getSources().size(); } return success(); } -LogicalResult AssumeMultipleOp::verify() { - auto operand_value = getValue(); - auto divisor = getMultiple(); - if (auto cst_op = operand_value.getDefiningOp()) { - auto int_attr = dyn_cast(cst_op.getValue()); - // Illegal usage of AssumeMultipleOp. - if (!int_attr) { +LogicalResult UnpackElementwiseOp::verify() { + if (failed(verifyElementwisePacking(*this, getType(), getSourceType()))) { + return failure(); + } + const int packing_factor = + getElementTypeBitwidth(getType()) / getTypeBitwidth(getSourceType()); + if (auto index = getIndex(); index >= packing_factor) { + return emitOpError("Index must be between 0 and the packing factor (") + << packing_factor << "), got " << index; + } + return success(); +} + +LogicalResult DynamicGatherOp::verify() { + const int64_t rank = getSource().getType().getRank(); + SmallVector seen(rank, false); + for (int32_t d : getDimensions()) { + if (d < 0 || d >= rank) { + return emitOpError("Dimensions must be in [0, rank), but got ") << d; + } + if (seen[d]) { + return emitOpError("Dimensions must be unique"); + } + seen[d] = true; + } + const ArrayRef source_shape = getSource().getType().getShape(); + const ArrayRef result_shape = getType().getShape(); + if (source_shape.size() != result_shape.size()) { + return emitOpError("Source and result shapes must have the same rank"); + } + for (int32_t i = 0; i < source_shape.size(); ++i) { + if (!seen[i] && source_shape[i] != result_shape[i]) { return emitOpError( - "Illegal user annotation, expected an integer, but got ") - << cst_op.getValue(); + "Source and result shapes must match on non-gather dimensions"); + } + } + return success(); +} + +/*static*/ LogicalResult DynamicGatherOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + VectorType source_vty = cast(operands[0].getType()); + VectorType indices_vty = cast(operands[1].getType()); + inferredReturnTypes.push_back( + VectorType::get(indices_vty.getShape(), source_vty.getElementType())); + return success(); +} + +LogicalResult AllReduceOp::verify() { + auto in_ty = getInput().getType(); + auto in_bitwidth = getElementTypeBitwidth(in_ty); + auto out_ty = getOutput().getType(); + auto out_bitwidth = getElementTypeBitwidth(out_ty); + auto kind = getKind(); + + if (in_bitwidth == 1) { + // For mask vectors, the single (semantically scalar) result is broadcast + // into a vector of 32-bit ints of whatever shape the target supports (not + // necessarily the same as the input). + if (out_bitwidth != 32) { + return emitOpError("Vector mask all-reduce must have i32 output"); + } + switch (kind) { + case ReductionKind::kSum: + case ReductionKind::kFindFirstSet: + break; + default: + return emitOpError( + "Mask all-reduce only supports sum and find_first_set kinds"); + } + return success(); + } + + switch (kind) { + case ReductionKind::kSum: + case ReductionKind::kMax: + case ReductionKind::kMin: + if (in_ty != out_ty) { + return emitOpError( + "Sum, max, and min reductions must have the same " + "input and output type"); + } + break; + case ReductionKind::kArgMax: + case ReductionKind::kArgMin: + if (in_ty.getShape() != out_ty.getShape()) { + return emitOpError("Arg_max and arg_min " + "must have the same input and output shape"); + } + if (!in_ty.getElementType().isF32()) { + return emitOpError("Not Implemented: Only f32 input is supported for " + "arg_max and arg_min"); + } + if (!out_ty.getElementType().isSignlessInteger(in_bitwidth)) { + return emitOpError(absl::StrFormat( + "Arg_max and arg_min must have i%d output", in_bitwidth)); + } + break; + case ReductionKind::kFindFirstSet: + return emitOpError("Only i1 input is supported for find_first_set"); + break; + } + return success(); +} + +LogicalResult ReduceIndexOp::verify() { + auto in_ty = getInput().getType(); + auto out_ty = getOutput().getType(); + auto bitwidth = getElementTypeBitwidth(in_ty); + auto axis = getAxis(); + auto kind = getKind(); + if (kind != ReductionKind::kArgMax && + kind != ReductionKind::kArgMin) { + return emitOpError("Reduction kind must be arg_max or arg_min"); + } + if (!in_ty.getElementType().isF32()) { + return emitOpError("Not Implemented: Only f32 input is supported for " + "arg_max and arg_min"); + } + if (!out_ty.getElementType().isSignlessInteger(bitwidth)) { + return emitOpError(absl::StrFormat( + "Arg_max and arg_min must have i%d output", bitwidth)); + } + + auto in_shape = in_ty.getShape(); + auto out_shape = out_ty.getShape(); + if (axis < 0 || axis >= in_shape.size()) { + return emitOpError("Axis must be in [0, ") + << in_shape.size() << "), but got " << axis; + } + + if (in_shape.size() < 2) { + return emitOpError("Not Implemented: Only input rank > 1 is supported."); + } + if (out_shape.size() != in_shape.size() - 1) { + return emitOpError("Output rank must be one less than input rank"); + } + int out_dim = 0; + for (int i = 0; i < in_shape.size(); ++i) { + if (i == axis) { + continue; } - if (int_attr.getInt() % divisor != 0) { + if (in_shape[i] != out_shape[out_dim]) { return emitOpError( - "Illegal user annotation, expected an integer that is " - "divisible by the multiple, but got ") - << int_attr.getInt() << " % " << divisor; + "Output shape must match input shape on non-reduction dimensions. ") + << "Output shape (" << out_shape << ") does not match input shape (" + << in_shape << ") at input dimension " << i; + } + out_dim++; + } + return success(); +} + +LogicalResult AssumeMultipleOp::verify() { + if (getMultiple() < 1) { + return emitError("Multiple must be >= 1, got ") << getMultiple(); + } + if (auto value = mlir::getConstantIntValue(getValue()); + value.has_value() && (*value % getMultiple() != 0)) { + return emitError("Operand is a constant ") + << *value << " that is not a multiple of " << getMultiple(); + } + return success(); +} + +LogicalResult SublaneShuffleOp::verify() { + auto lhs = getLhs(); + auto rhs = getRhs(); + auto result = getResult(); + auto lhs_ty = dyn_cast(lhs.getType()); + auto rhs_ty = dyn_cast(rhs.getType()); + auto result_ty = dyn_cast(result.getType()); + + if (!lhs_ty || !rhs_ty || !result_ty) { + return emitOpError("Expected operands and result to be vector types"); + } + + if (lhs_ty.getShape() != rhs_ty.getShape() || + lhs_ty.getShape() != result_ty.getShape()) { + return emitOpError("Expected lhs, rhs, and result shapes to match"); + } + if (lhs_ty.getElementType() != rhs_ty.getElementType() || + lhs_ty.getElementType() != result_ty.getElementType()) { + return emitOpError("Expected lhs, rhs, and result element types to match"); + } + + auto pattern = getPattern(); + auto shape = result_ty.getShape(); + if (shape.size() < 2 || shape.size() > 3) { + return emitOpError("Vreg rank should be 2 or 3"); + } + auto sublane_count = shape[0]; + + if (pattern.size() != sublane_count) { + return emitOpError("Expected pattern size (") + << pattern.size() << ") to match result/operand sublanes (" + << sublane_count << ")"; + } + + int64_t total_input_sublanes = sublane_count * 2; + for (int32_t idx : pattern) { + if (idx < 0 || idx >= total_input_sublanes) { + return emitOpError("Pattern index ") << idx << " out of bounds [0, " + << (total_input_sublanes - 1) << "]"; } } return success(); } +OpFoldResult TruncFOp::fold(FoldAdaptor adaptor) { + auto resElemType = cast(getElementTypeOrSelf(getType())); + const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics(); + return constFoldCastOp( + adaptor.getOperands(), getType(), + [this, &targetSemantics](const APFloat &a, bool &castStatus) { + llvm::RoundingMode llvmRoundingMode = + convertTpuRoundingModeToLLVMIR(getRoundingMode()); + FailureOr result = + convertFloatValue(a, targetSemantics, llvmRoundingMode); + if (failed(result)) { + castStatus = false; + return a; + } + return *result; + }); +} + +OpFoldResult ExtFOp::fold(FoldAdaptor adaptor) { + auto resElemType = cast(getElementTypeOrSelf(getType())); + const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics(); + return constFoldCastOp( + adaptor.getOperands(), getType(), + [&targetSemantics](const APFloat &a, bool &castStatus) { + FailureOr result = convertFloatValue(a, targetSemantics); + if (failed(result)) { + castStatus = false; + return a; + } + return *result; + }); +} + +LogicalResult ReshapeOp::verify() { + auto src_ty = getSource().getType(); + auto dst_ty = getResult().getType(); + if (src_ty.getElementType() != dst_ty.getElementType()) { + return emitOpError("element type must match"); + } + if (src_ty.getNumElements() != dst_ty.getNumElements()) { + return emitOpError() << "element count must match"; + } + return success(); +} + +OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) { + // No-op reshape. + if (getSource().getType() == getType()) { + return getSource(); + } + // Reshape of a reshape is a reshape. + if (auto source_reshape = getSource().getDefiningOp()) { + setOperand(source_reshape.getSource()); + return getResult(); + } + // Reshape of a constant is a constant. + if (auto cst = dyn_cast_if_present(adaptor.getSource())) { + return cst.reshape(getType()); + } + return nullptr; +} + +LogicalResult StochasticConvertElementwiseOp::verify() { + auto dst_ty = getDstType(); + if (!dst_ty.isBF16() && + !llvm::isa(dst_ty)) { + return emitOpError( + "Only bf16, f8e5m2, f8e4m3fn, and f8e4m3b11fnuz are supported as " + "destination types."); + } + return success(); +} + } // namespace tpu } // namespace mlir diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.td b/jaxlib/mosaic/dialect/tpu/tpu_ops.td new file mode 100644 index 000000000000..ebb1cabbc30b --- /dev/null +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.td @@ -0,0 +1,1519 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TPU_OPS +#define TPU_OPS + +include "mlir/IR/OpBase.td" +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinAttributeInterfaces.td" +include "mlir/IR/BuiltinTypeInterfaces.td" +include "mlir/IR/EnumAttr.td" +include "mlir/Pass/PassBase.td" +include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "jaxlib/mosaic/dialect/tpu/tpu.td" + +// TODO(b/369418606): Find out the way to verify vreg size. +def TPU_Vreg : Type; + +def TPU_CoreType : I32EnumAttr<"CoreType", "Core type", [ + I32EnumAttrCase<"kTc", 0, "tc">, + I32EnumAttrCase<"kScScalarSubcore", 1, "sc_scalar_subcore">, + I32EnumAttrCase<"kScVectorSubcore", 2, "sc_vector_subcore"> +]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::tpu"; +} + +def TPU_CoreTypeEnum : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +def TPU_PipelineMode : I32EnumAttr<"PipelineMode", "Pipeline mode", [ + I32EnumAttrCase<"kSynchronous", 1, "synchronous">, + I32EnumAttrCase<"kDoubleBuffered", 2, "double_buffered"> + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::tpu"; +} + +def TPU_PipelineModeEnum : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +def TPU_SemaphoreType : TPU_Type<"Semaphore", "semaphore", [MemRefElementTypeInterface]>; +def TPU_DMASemaphoreType : TPU_Type<"DMASemaphore", "dma_semaphore", [MemRefElementTypeInterface]>; +def TPU_SomeSemaphoreType : AnyTypeOf<[TPU_SemaphoreType, TPU_DMASemaphoreType]>; + +def TPU_DimensionSemantics : I32EnumAttr<"DimensionSemantics", "Dimension semantics", [ + I32EnumAttrCase<"parallel", 0>, + I32EnumAttrCase<"arbitrary", 1>, + I32EnumAttrCase<"core_parallel", 2>, + I32EnumAttrCase<"subcore_parallel", 3> +]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::tpu"; +} + +def TPU_DimensionSemanticsEnum + : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +// All indices/sizes are in element-space. +// Note that the implementation will require statically provable tile alignment. +def TPU_ElementWindowAttr : TPU_Attr<"ElementWindow", "element_window"> { + // Including low padding, to avoid backwards-incompatible changes once we add it. + let parameters = (ins + ArrayRefParameter<"int64_t", "">:$pad_low, + ArrayRefParameter<"int64_t", "">:$pad_high + ); + let assemblyFormat = "`<` `[` $pad_low `]` `,` `[` $pad_high `]` `>`"; +} + +def TPU_ContractPrecision : I32EnumAttr<"ContractPrecision", "Contraction precision", [ + I32EnumAttrCase<"kBF16", 0, "bf16">, + I32EnumAttrCase<"kFP32", 1, "fp32"> +]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::tpu"; +} + +def TPU_ContractPrecisionEnum + : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +def TPU_PackFormat : I32EnumAttr<"PackFormat", "Pack format", [ + I32EnumAttrCase<"kCompressed", 0, "compressed">, + I32EnumAttrCase<"kInterleaved", 1, "interleaved"> +]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::tpu"; +} + +def TPU_PackFormatEnum : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +def TPU_TiledCase : I32EnumAttrCase<"tiled", 0>; +def TPU_LaneCase : I32EnumAttrCase<"lanes", 1>; +def TPU_SublaneCase : I32EnumAttrCase<"sublanes", 2>; +def TPU_VectorLayoutDim : I32EnumAttr< + "VectorLayoutDim", "", [TPU_TiledCase, TPU_LaneCase, TPU_SublaneCase]>; + +def TPU_VectorLayoutAttr : TPU_Attr<"VectorLayout", "vpad"> { + let description = [{TODO}]; + + let parameters = (ins "Layout":$layout); + let hasCustomAssemblyFormat = 1; +} + +def TPU_TiledLayoutAttr + : TPU_Attr<"TiledLayout", "tiled", + [DeclareAttrInterfaceMethods]> { + let description = [{ + This attribute represents tiled layouts in memrefs. + + Multiple levels of tiling are supported with the following restriction: + - Additional levels of tiling may not add any padding. + - Additional levels of tiling may not tile previously untiled dimensions, + that is, they cannot tile across first-level tiles. + + Tile strides encode the stride when moving along a given dimension. They + must have the same rank as the shape and must be decreasing with increasing + dimension number. For tiled dimensions, the stride applies only when moving + across first-level tiles. The strides are in units of the size of the first + tile, or 1 if there are no tiles. + }]; + let parameters = (ins + ArrayRefParameter<"::xla::Tile", "">:$tiles, + ArrayRefParameter<"int64_t", "">:$tile_strides + ); + let extraClassDeclaration = [{ + static ::llvm::SmallVector getDefaultTileStrides(::llvm::ArrayRef<::xla::Tile> tiles, ::llvm::ArrayRef shape); + bool tilesAreKnownContiguous(::llvm::ArrayRef shape) const; + + int64_t getRank() const { + return getTileStrides().size(); + } + int64_t getUntiledRank() const; + + ::llvm::SmallVector getExpandedShape(::llvm::ArrayRef shape) const; + ::llvm::SmallVector getExpandedStrides() const; + }]; + + let hasCustomAssemblyFormat = 1; + let genVerifyDecl = 1; +} + +def TPU_MemorySpace : I32EnumAttr<"MemorySpace", "Memory space", [ + I32EnumAttrCase<"kAny", 4294967295, "any">, + I32EnumAttrCase<"kVmem", 0, "vmem">, + I32EnumAttrCase<"kSmem", 1, "smem">, + I32EnumAttrCase<"kHbm", 2, "hbm">, + I32EnumAttrCase<"kCmem", 3, "cmem">, + I32EnumAttrCase<"kSemaphoreMem", 4, "semaphore_mem">, + I32EnumAttrCase<"kVmemShared", 5, "vmem_shared">, + I32EnumAttrCase<"kHost", 6, "host"> +]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::tpu"; +} + +def TPU_MemorySpaceEnum + : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +class TPU_Op traits = []> : + Op { +} + +def DefaultMemWrite : MemoryEffects<[MemWrite]>; +def DefaultMemRead : MemoryEffects<[MemRead]>; + +def TPU_ReductionKind : I32EnumAttr<"ReductionKind", "Reduction kind", [ + I32EnumAttrCase<"kSum", 0, "sum">, + I32EnumAttrCase<"kMax", 1, "max">, + I32EnumAttrCase<"kMin", 2, "min">, + I32EnumAttrCase<"kArgMax", 3, "arg_max">, + I32EnumAttrCase<"kArgMin", 4, "arg_min">, + I32EnumAttrCase<"kFindFirstSet", 5, "find_first_set"> +]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::tpu"; +} + +def TPU_ReductionKindAttr + : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +def TPU_AllReduceOp : TPU_Op<"all_reduce", [Pure]> { + let arguments = (ins AnyVectorOfNonZeroRank:$input, I64Attr:$dim, TPU_ReductionKindAttr:$kind); + let results = (outs AnyVectorOfNonZeroRank:$output); + let assemblyFormat = [{ + $input attr-dict `:` type($input) `->` type($output) + }]; + let hasVerifier = 1; +} + +def TPU_ReduceIndexOp : TPU_Op<"reduce_index", [Pure]> { + let arguments = (ins + AnyVectorOfNonZeroRank:$input, + I32Attr:$axis, + TPU_ReductionKindAttr:$kind + ); + let results = (outs VectorOfNonZeroRankOf<[I32]>:$output); + let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }]; + let hasVerifier = 1; +} + +// tpu.scan performs a scan across a vector. +// +// If a mask is provided, all output elements before the first unmasked input +// element is undefined. Subsequent masked elements will hold the result +// of the last unmasked element. +// +// For example, a "kSum" reduction over a input vector [1, 2, 3, 4] +// with mask [0, 1, 0, 1] will produce the output vector [X, 2, 2, 6]. +// where X is some undefined value. +// +// output : Result vector. Must have the same shape as source. +// input : Vector to scan. +// kind : Reduction operator. Must be one of "kSum", "kMax", or "kMin". +// Must be "kSum" if input is an I1 vector. +// mask : Elementwise vector mask. The scan operation starts from the +// lowest-indexed non-masked vector element (all previous elements +// have undefined values). Not taken for I1 input vectors. +def TPU_ScanOp : TPU_Op<"scan"> { + let arguments = (ins + VectorOfNonZeroRankOf<[I1, I16, I32, BF16, F32]>:$input, + TPU_ReductionKindAttr:$kind, + Optional>:$mask + ); + let results = (outs VectorOfNonZeroRankOf<[I16, I32, BF16, F32]>:$output); + let assemblyFormat = [{ + $kind `,` $input (`masked` $mask^)? attr-dict `:` type($input) `,` type($mask) `->` type($output) + }]; + let hasVerifier = 1; +} + +def TPU_SortOp : TPU_Op<"sort", [Pure]> { + let summary = "Sorts key/value pairs based on keys."; + let description = [{ + tpu.sort performs a stable sort of key/value pairs in ascending or + descending order based on keys. Masked-out keys and values are placed at the + end of the output vectors. An output mask indicates which outputs + correspond to the valid inputs. + }]; + let arguments = (ins + VectorOfNonZeroRankOf<[I32,F32]>:$keys, + VectorOfNonZeroRankOf<[I32,F32]>:$values, + Optional>:$mask, + DefaultValuedAttr:$descending + ); + let results = (outs + VectorOfNonZeroRankOf<[I1]>:$output_mask, + VectorOfNonZeroRankOf<[I32,F32]>:$sorted_keys, + VectorOfNonZeroRankOf<[I32,F32]>:$sorted_values + ); + let assemblyFormat = [{ + $keys `,` $values (`masked` $mask^)? attr-dict `:` functional-type(operands, results) + }]; + let hasVerifier = 1; +} + +def TPU_StoreOp : TPU_Op<"store", [DefaultMemWrite, AttrSizedOperandSegments]> { + let arguments = (ins + TPU_Vreg:$valueToStore, + AnyType:$base, + Variadic:$indices, + DenseBoolArrayAttr:$sublane_mask, + Optional:$mask, + OptionalAttr:$sublane_stride // In sublane-sized units + ); + let results = (outs); + let assemblyFormat = [{ + $base `[` $indices `]` `,` $valueToStore (`masked` $mask^)? `sublanes` $sublane_mask (`sublane_stride` $sublane_stride^)? attr-dict `:` type($base) `,` type($valueToStore) `,` type($mask) + }]; +} + +def TPU_LoadOp : TPU_Op<"load", [DefaultMemRead]> { + let arguments = (ins + AnyType:$base, + Variadic:$indices, + DenseBoolArrayAttr:$sublane_mask, + OptionalAttr:$sublane_stride // In sublane-sized units + ); + let results = (outs TPU_Vreg:$result); + let assemblyFormat = [{ + $base `[` $indices `]` `sublanes` $sublane_mask (`sublane_stride` $sublane_stride^)? attr-dict `:` type($base) `,` type($result) + }]; + let description = [{ + Similar to `vector::LoadOp` but with `sublane_mask` and `sublane_stride`. + When `indices` are negative, it means loading from negative offset + of `base` address. + }]; +} + +// TODO(jevinjiang): migrate tpu.strided_store to general vector store op. +def TPU_VectorStoreOp :TPU_Op<"vector_store", [DefaultMemWrite, AttrSizedOperandSegments]> { + let arguments = (ins + AnyVectorOfNonZeroRank:$valueToStore, + AnyMemRef:$base, + Variadic:$indices, + DenseI32ArrayAttr:$strides, + Optional:$mask, // Elementwise mask. + DefaultValuedAttr:$add + ); + let results = (outs); + let assemblyFormat = [{ + $base `[` $indices `]` `,` $valueToStore (`masked` $mask^)? attr-dict `:` type($base) `,` type($valueToStore) `,` type($mask) + }]; + let hasVerifier = 1; + let hasCanonicalizeMethod = 1; +} + +// tpu.vector_load loads a vector from memory into a register. +// +// base : Memref to load from. +// indices: Scalar indices into base. indices must be of the same rank as the +// base memref shape. +// strides: The stride to use for calculating the address of subsequent +// elements. If left unspecified, the stride is implicitly 1 along +// each dimension. Otherwise the stride must match the rank of the +// memref shape. +// mask : Elementwise vector mask. Must be broadcastable to the shape of the +// result vector. Depending on the core type, this may be a dynamic +// (lane) mask consumed from a register or a static (sublane) mask +// that must be the result of arith.constant. +def TPU_VectorLoadOp :TPU_Op<"vector_load", [DefaultMemRead, AttrSizedOperandSegments]> { + let arguments = (ins + AnyMemRef:$base, + Variadic:$indices, + DenseI32ArrayAttr:$strides, + Optional:$mask // Elementwise mask. + ); + let results = (outs AnyVectorOfNonZeroRank:$result); + let assemblyFormat = [{ + $base `[` $indices `]` (`masked` $mask^)? attr-dict `:` type($base) `,` type($result) `,` type($mask) + }]; + let hasVerifier = 1; + let hasCanonicalizeMethod = 1; +} + +def TPU_StridedLoadOp : TPU_Op<"strided_load", [DefaultMemRead]> { + let arguments = (ins + AnyMemRef:$base, + Variadic:$indices, + DenseI32ArrayAttr:$strides + ); + let results = (outs AnyVectorOfNonZeroRank:$result); + let assemblyFormat = [{ + $base `[` $indices `]` attr-dict `:` type($base) `,` type($result) + }]; + let hasVerifier = 1; +} + +def TPU_StridedStoreOp : TPU_Op<"strided_store", [DefaultMemWrite]> { + let arguments = (ins + AnyVectorOfNonZeroRank:$valueToStore, + AnyMemRef:$base, + Variadic:$indices, + DenseI32ArrayAttr:$strides + ); + let results = (outs); + let assemblyFormat = [{ + $base `[` $indices `]` `,` $valueToStore attr-dict `:` type($base) `,` type($valueToStore) + }]; + let hasVerifier = 1; +} + +// TODO: b/435258666 - Merge with tpu.vector_load_idx. +def TPU_ShuffledLoadOp : TPU_Op<"shuffled_load", [DefaultMemRead]> { + let arguments = (ins + AnyMemRef:$base, + Variadic:$indices, + DenseBoolArrayAttr:$sublane_mask, + DenseI32ArrayAttr:$sublane_offsets + ); + let results = (outs TPU_Vreg:$result); + let assemblyFormat = [{ + $base `[` $indices `]` attr-dict `:` type($base) `,` type($result) + }]; + let hasVerifier = 1; + let hasCanonicalizeMethod = 1; +} + +// TODO: b/435258666 - Merge with tpu.vector_store_idx. +def TPU_ShuffledStoreOp : TPU_Op<"shuffled_store", [DefaultMemWrite]> { + let arguments = (ins + TPU_Vreg:$valueToStore, + AnyMemRef:$base, + Variadic:$indices, + DenseBoolArrayAttr:$sublane_mask, + DenseI32ArrayAttr:$sublane_offsets + ); + let results = (outs); + let assemblyFormat = [{ + $base `[` $indices `]` `,` $valueToStore attr-dict `:` type($base) `,` type($valueToStore) + }]; + let hasVerifier = 1; + let hasCanonicalizeMethod = 1; +} + +// tpu.vector_load_idx loads values from arbitrary locations in memory. +// +// Each element in the output vector is loaded from an index in the base memref +// specified by the corresponding elements in the 'indices' vectors. The shape +// of each index vector must match the shape of the output vector. The number +// of index vectors must equal the rank of the base memref. +// +// For example, for a vector of length n with rank 2, the indices will look like: +// indices = [[idx0, idx1, ...], [idxn, idxn+1, ...]] +// where [idx0, idxn] is the offset of the first vector element. +// +// base : Memref specifying the base address. +// indices : Vectors of indices for each dimension of the base memref. +// mask : Optional elementwise vector mask. +def TPU_VectorLoadIdxOp :TPU_Op<"vector_load_idx", [DefaultMemRead, AttrSizedOperandSegments]> { + let arguments = (ins + MemRefOf<[I32, F32]>:$base, + Variadic>:$indices, + Optional>:$mask + ); + let results = (outs VectorOfNonZeroRankOf<[I32, F32]>:$value); + let assemblyFormat = [{ + $base `[` $indices `]` (`masked` $mask^)? attr-dict `:` type($base) `[` type($indices) `]` `,` type($value) `,` type($mask) + }]; + let hasVerifier = 1; + let hasCanonicalizeMethod = 1; +} + +// tpu.vector_store_idx stores values to arbitrary locations in memory. +// +// Each element in the input vector is stored to an index in the base memref +// specified by the corresponding elements in the 'indices' vectors. The shape +// of each index vector must match the shape of the input vector. The number +// of index vectors must equal the rank of the base memref. +// +// For example, for a vector of length n with rank 2, the indices will look like: +// indices = [[idx0, idx1, ...], [idxn, idxn+1, ...]] +// where [idx0, idxn] is the offset of the first vector element. +// +// When multiple vector elements have the same index to store to, the data from +// the highest lane will be the one stored. If add is true, then the data will +// be added from the lowest lane to the highest lane. +// +// valueToStore: Vector to be stored. +// base : Memref specifying the base address. +// indices : Vectors of indices for each dimension of the base memref. +// mask : Optional elementwise vector mask. +// add : If true, add source values to target values. Otherwise, overwrite. +def TPU_VectorStoreIdxOp :TPU_Op<"vector_store_idx", [DefaultMemWrite, AttrSizedOperandSegments]> { + let arguments = (ins + VectorOfNonZeroRankOf<[I32, F32]>:$valueToStore, + MemRefOf<[I32, F32]>:$base, + Variadic>:$indices, + Optional>:$mask, + DefaultValuedAttr:$add + ); + let results = (outs); + let assemblyFormat = [{ + $base `[` $indices `]` `,` $valueToStore (`masked` $mask^)? attr-dict `:` type($base) `[` type($indices) `]` `,` type($valueToStore) `,` type($mask) + }]; + let hasVerifier = 1; + let hasCanonicalizeMethod = 1; +} + +// TODO(jevinjiang): deprecate to use dynamic_rotate. +def TPU_RotateOp : TPU_Op<"rotate", [Pure, SameOperandsAndResultType]> { + let description = [{ + Rotates the given vector by the given amount in the given dimension, i.e., + for a 2D vector of shape (m, n), rotating dim 0 by `amount` will shift a row + at index `i` to index `(i + amount) % m` + }]; + let arguments = (ins + AnyVectorOfNonZeroRank:$value, + SI32Attr:$amount, + SI32Attr:$dimension, + // When the stride is specified, the rotation amount for each index on the + // stride dimension will be (amount + stride * index). + OptionalAttr:$stride, + OptionalAttr:$stride_dimension + ); + let results = (outs AnyVectorOfNonZeroRank:$result); + let assemblyFormat = [{ + $value `by` $amount `dim` $dimension (`stride` $stride `stride_dim` $stride_dimension^)? attr-dict `:` type($value) + }]; + let hasVerifier = 1; +} + +def TPU_DynamicRotateOp : TPU_Op<"dynamic_rotate", [Pure]> { + let arguments = (ins + AnyVectorOfNonZeroRank:$value, + I32:$amount, + SI32Attr:$dimension, + // When the stride is specified, the rotation amount for each index on the + // stride dimension will be (amount + stride * index). + OptionalAttr:$stride, + OptionalAttr:$stride_dimension + ); + let results = (outs AnyVectorOfNonZeroRank:$result); + let assemblyFormat = [{ + $value `by` $amount `dim` $dimension attr-dict `:` type($value) `,` type($amount) `->` type($result) + }]; + let hasVerifier = 1; +} + +def TPU_ScanCountOp : TPU_Op<"scan_count", [Pure, InferTypeOpAdaptor, SameOperandsAndResultShape]> { +let summary = [{ + ScanCountOp calculates the running duplicate occurrence count of the elements + in the input vector. Elements eligible for counting are specified by the + input mask vector. The output mask vector indicates one unique occurrence + per duplicate that was counted. + }]; + + let description = [{ + ScanCountOp calculates the running duplicate occurrence count of the elements + in the input vector, %values. The output vector, %counts, contains the running + duplicate occurrence count for the corresponding element in + the input vector, where the count is performed in ascending order of element + indices. For example, if the elements of %values at indices 0, 5, and 7 had + duplicate values, then the elements of %counts at indices 0, 5, and 7 would + be 1, 2, and 3, respectively. + + A mask vector, %in_mask, specifies which of the elements in the input vector + are eligible for counting. An element in %values that has its mask set to 0 + will always have a count of 1 in %counts, regardless of the position in the + vector, or whether there were duplicates or not. + }]; + + let arguments = (ins + VectorOfNonZeroRankOf<[I1]>:$in_mask, + AnyVectorOfNonZeroRank:$values + ); + let results = (outs + VectorOfNonZeroRankOf<[I1]>:$out_mask, + VectorOfNonZeroRankOf<[I32]>:$counts + ); + + let assemblyFormat = [{ + `mask` `(` $in_mask `:` type($in_mask) `)` + `value` `(` $values `:` type($values) `)` + attr-dict `:` type(results) + }]; + +} + +def TPU_IotaOp : TPU_Op<"iota", [Pure]> { + let description = [{ + Creates a vector that with values that start at 0 and increase along a + dimension resulting from collapsing the given `dimensions` together in + row-major order. + + Example: + ``` + tpu.iota {dimensions = array} : vector<4x3x2xi16> + ``` + This produces a vector with the following values: + ``` + [[[0, 4], [0, 4], [0, 4]] + [[1, 5], [1, 5], [1, 5]] + [[2, 6], [2, 6], [2, 6]] + [[3, 7], [3, 7], [3, 7]]] + ``` + }]; + let arguments = (ins DenseI32ArrayAttr:$dimensions); + let results = (outs VectorOfNonZeroRankOf<[AnyInteger, Index]>:$output); + let assemblyFormat = [{ attr-dict `:` type($output) }]; + let hasVerifier = 1; +} + +def TPU_ReshapeOp : TPU_Op<"reshape", [Pure]> { + let arguments = (ins AnyVectorOfNonZeroRank:$source); + let results = (outs AnyVectorOfNonZeroRank:$result); + let assemblyFormat = [{ $source attr-dict `:` type($source) `->` type($result) }]; + let hasVerifier = 1; + let hasFolder = 1; +} + +// TODO(mvoz): deprecated - use concat. Canonicalization will do so automatically. +// b/376295711 +def TPU_RepeatOp : TPU_Op<"repeat", [Pure]> { + let arguments = (ins + AnyVectorOfNonZeroRank:$source, + I32Attr:$dimension, + I32Attr:$times + ); + let results = (outs AnyVectorOfNonZeroRank:$output); + let assemblyFormat = [{ $source `,` $dimension `x` $times attr-dict `:` type($source) `->` type($output) }]; +} + +def TPU_BroadcastInSublanesOp : TPU_Op<"broadcast_in_sublanes", [Pure]> { + let description = [{ + For each sublane `i`, broadcasts the value in lane `lane + i` along the + entire sublane. For packed type, imagine the data is compressed unpacked + along sublane dimension, and the sublane count is multiplied by the packing + factor. + For example, for i16 with sublane count 8, `i` above is in [0, 8 * 2). + If `lane + i` is not in [0, lane_count), then the value in sublane `i` is + not defined (can be anything). + }]; + let arguments = (ins + TPU_Vreg:$source, // All sublanes should be equal. + I32Attr:$lane // Coordinates of the first element to take. + ); + // Output shape should be the same, except for position dim which contains + // the newly inserted dimension. + let results = (outs AnyVectorOfNonZeroRank:$output); + let assemblyFormat = [{ + $source `,` $lane attr-dict `:` type($source) `->` type($output) + }]; +} + +// Integer unpacks are always signed at the moment. +// +// When unpacking integers to integers, setting `sign_extended` to false will +// leave bits higher than source bitwidth as undefined. +// +// Take int4 to int16 interleaved unpacking and `index = 1` as an example: +// +// Source: +// +// Bits 28 24 20 16 12 8 4 0 +// --------abcd------------efgh---- +// +// where "a" and "e" are the sign bits of the values to be unpacked, and "-" are +// bits to be ignored. +// +// Unpacked, sign_extend = true: +// +// Bits 28 24 20 16 12 8 4 0 +// aaaaaaaaaaaaabcdeeeeeeeeeeeeefgh +// +// Unpacked, sign_extend = false: +// +// Bits 28 24 20 16 12 8 4 0 +// ------------abcd------------efgh +def TPU_UnpackSubelementsOp : TPU_Op<"unpack_subelements", [Pure]> { + let arguments = (ins + AnyVectorOfNonZeroRank:$source, + I32Attr:$index, + TPU_PackFormatEnum:$pack_format, + DefaultValuedAttr:$sign_extended + ); + let results = (outs AnyVectorOfNonZeroRank:$output); + let assemblyFormat = [{ $source `,` $index attr-dict `:` type($source) `->` type($output) }]; + let hasVerifier = 1; + let hasCanonicalizeMethod = 1; +} + +// Integer packs are always signed at the moment. +// Float to integer packing rounds to nearest even. +// WARNING: pack(pack(a, b), pack(c, d)) == pack(a, b, c, d) only holds for +// compressed packing! +// Below, we use [ ... ] to denote the bounds of the vreg and use regular parens +// ( ... ) to denote packing of multiple subelements into a single 32-bit word. +// +// Interleaved packing +// +// Interleaved packing downcasts to a narrower dtype, and packs multiple elements +// into the same word coordinate from which they originated. If a and b are packed +// values, then interleaved packing first iterates over the operand list and only +// then over the subelements within each word. +// Take 16-bit vregs A, B, C and D: +/// +// [ (A000 A001) (A010 A011) ... ] +// [ (A100 A101) (A110 A111) ... ] +// ... +// +// An interleaved pack(a, b) from 16-bit values produces: +// +// [ (A000 B000 A001 B001) (A010 B010 A011 B011) ...] +// ... +// +// While an interleaved pack(a, b, c, d) produces the following subelements in +// each vreg word: +// +// [ (A000 B000 C000 D000 A001 B001 C001 D001) ... ] +// ... +// +// Compressed packing +// +// Compressed packing downcasts each value and then packs multiple rows together. +// A compressed pack(a, b) from 16-bit values produces: +// +// [ (A000 A001 A100 A101) (A010 A011 A110 A111) ... ] +// [ (A200 A201 A300 A301) (A210 A211 A310 A311) ... ] +// ... # 2 more sublanes +// [ (B000 B001 B100 B101) (B010 B011 B110 B111) ... ] +// [ (B200 B201 B300 B301) (B210 B211 B310 B311) ... ] +// ... +def TPU_PackSubelementsOp : TPU_Op<"pack_subelements", [Pure, SameTypeOperands]> { + let arguments = (ins + Variadic:$sources, + DenseI32ArrayAttr:$positions, + TPU_PackFormatEnum:$pack_format + ); + let results = (outs TPU_Vreg:$output); + let assemblyFormat = [{ $sources attr-dict `:` type($sources) `->` type($output) }]; + let builders = [ + OpBuilder<(ins "::mlir::VectorType":$output_type, "::mlir::ArrayRef<::mlir::Value>":$padded_sources, "::mlir::tpu::PackFormat":$pack_format)>, + ]; + let extraClassDeclaration = [{ + static ::mlir::SmallVector<::mlir::Value> getPaddedSources(::mlir::ValueRange sources, ::mlir::ArrayRef positions, int packing_factor); + }]; + let hasVerifier = 1; +} + +def TPU_PackElementwiseOp : TPU_Op<"pack_elementwise", [Pure, SameTypeOperands, ElementwiseMappable]> { + let description = [{ + Packs multiple `sources` elementwise into a single vector of a narrower `target_type`. + + The number of `sources` must equal the packing factor, which is the ratio of + the element bitwidth of the `sources` to the element bitwidth of the + `target_type`. Elements from the `sources` are interleaved and packed into + the `output`, ordered from lowest to highest bits, corresponding to their + order in the `sources`. The `output` is then bitcasted to the signless + integer type of the same bitwidth as the `sources`. + + Note that for integer packing, the bits in `sources` that exceed the + bitwidth of the `target_type` are just truncated. + For example, given the `sources` are int8 xxxx'1001 and yyyy'0011, + `target_type` is int4, the output will be 0011'1001. + }]; + let arguments = (ins + Variadic>:$sources, + TypeAttr:$target_type + ); + let results = (outs VectorOfNonZeroRankOf<[AnyInteger]>:$output); + let assemblyFormat = [{ $sources attr-dict `:` type($sources) `->` type($output) }]; + let hasVerifier = 1; +} + +def TPU_UnpackElementwiseOp : TPU_Op<"unpack_elementwise", [Pure, ElementwiseMappable]> { + let description = [{ + Unpacks a single vector from `source`, which contains multiple `source_type` + vectors packed elementwise. + + The `index` selects which packed value to extract from each word of `source`. + An `index` of 0 corresponds to the lowest bits. The extracted values are + cast to the output element type. + }]; + let arguments = (ins + VectorOfNonZeroRankOf<[I32]>:$source, + TypeAttr:$source_type, + I32Attr:$index + ); + let results = (outs VectorOfNonZeroRankOf<[F32, I32]>:$output); + let assemblyFormat = [{ $source `,` $index attr-dict `:` type($source) `->` type($output) }]; + let hasVerifier = 1; +} + +def TPU_RelayoutOp : TPU_Op<"relayout", [Pure, SameOperandsAndResultType]> { + let arguments = (ins AnyVectorOfAnyRank:$input); + let results = (outs AnyVectorOfAnyRank:$output); + let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }]; + let hasVerifier = 1; +} + +def TPU_PackMaskOp : TPU_Op<"pack_vmsk", [Pure, SameTypeOperands]> { + let arguments = (ins + VectorOfNonZeroRankOf<[I1]>: $low, + VectorOfNonZeroRankOf<[I1]>: $high + ); + let results = (outs VectorOfNonZeroRankOf<[I1]>:$output); + let assemblyFormat = [{ $low `,` $high `,` attr-dict `:` type($low) `,` type($high) `->` type($output) }]; +} + +def TPU_GatherOp : TPU_Op<"gather", [Pure]> { + let arguments = (ins + AnyVectorOfNonZeroRank:$source, + DenseI32ArrayAttr:$indices, + I32Attr:$dimension + ); + let results = (outs AnyVectorOfNonZeroRank:$output); + let assemblyFormat = [{ + $source `[` $indices `]` `in` $dimension attr-dict + `:` type($source) `->` type($output) + }]; +} + +def TPU_DynamicGatherOp : TPU_Op<"dynamic_gather", [Pure, DeclareOpInterfaceMethods, AllShapesMatch<["indices", "output"]>, AllElementTypesMatch<["source", "output"]>]> { + let description = [{ + Gathers elements from `source` using `indices`. + + The specified `dimensions` of `source` are collapsed together and indexed by + `indices`. + + Given a shape `N0 x N1 x ...`, the `output[i0, i1, ...]` is given by + `collapsed_source[j0, j1, ..., indices[i0, i1, ...] mod M]` where + - `collapsed_source` is the result of collapsing `dimensions` of `source` + into a new trailing dimension of size `M`. + - `jk` is the subsequence of `in` for `n` not in `dimensions`. + + When a single dimension is specified, this is similar to + `np.take_along_axis`. + }]; + let arguments = (ins + AnyVectorOfNonZeroRank:$source, + VectorOfNonZeroRankOf<[AnyInteger]>:$indices, + DenseI32ArrayAttr:$dimensions + ); + let results = (outs AnyVectorOfNonZeroRank:$output); + let assemblyFormat = [{ + $source `[` $indices `]` `in` $dimensions attr-dict + `:` type($source) `,` type($indices) `->` type($output) + }]; + let hasVerifier = 1; +} + +def TPU_RoundingMode : I32EnumAttr<"RoundingMode", "Rounding mode", [ + I32EnumAttrCase<"kTowardsZero", 0, "towards_zero">, + I32EnumAttrCase<"kToNearestEven", 1, "to_nearest_even">, +]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::tpu"; +} + +def TPU_RoundingModeEnum : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +// Internal operation. All arith.fptosi operations that change the bitwidth +// must be canonicalized to this operation. +def TPU_FPToSIOp : TPU_Op<"fptosi", [Pure, ElementwiseMappable]> { + let arguments = (ins AnyVectorOfAnyRank:$input, TPU_RoundingModeEnum:$rounding_mode); + let results = (outs AnyVectorOfAnyRank:$output); + let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }]; + let hasCanonicalizeMethod = 1; +} + +// Internal operation. All arith.sitofp operations that change the bitwidth +// must be canonicalized to this operation. +def TPU_SIToFPOp : TPU_Op<"sitofp", [Pure, ElementwiseMappable]> { + let arguments = (ins AnyType:$in, TPU_RoundingModeEnum:$rounding_mode); + let results = (outs AnyType:$output); + let assemblyFormat = [{ $in attr-dict `:` type($in) `->` type($output) }]; +} + +// Internal operation. +def TPU_ExtFOp : TPU_Op<"extf", [Pure, ElementwiseMappable]> { + let arguments = (ins AnyType:$in); + let results = (outs AnyType:$out); + let assemblyFormat = [{ $in attr-dict `:` type($in) `->` type($out) }]; + let hasFolder = 1; +} + +// Internal operation. +def TPU_TruncFOp : TPU_Op<"truncf", [Pure, ElementwiseMappable]> { + let arguments = ( + ins AnyType:$in, + TPU_RoundingModeEnum:$rounding_mode + ); + let results = (outs AnyType:$out); + let assemblyFormat = [{ $in attr-dict `:` type($in) `->` type($out) }]; + let hasFolder = 1; +} + +// TODO(apaszke): Think hard about precision +def TPU_MatmulOp : TPU_Op<"matmul", [Pure]> { + let arguments = (ins + AnyVectorOfNonZeroRank:$lhs, + AnyVectorOfNonZeroRank:$rhs, + AnyVectorOfNonZeroRank:$acc, + // These flags are deprecated - if dimension_numbers are defined, + // these flags are ignored. They will always be false after canonicalize. + DefaultValuedAttr:$transpose_lhs, + DefaultValuedAttr:$transpose_rhs, + OptionalAttr:$precision, + // NOTE: User-level optional, once canonicalized, always present. + OptionalAttr:$dimension_numbers + ); + let results = (outs AnyVectorOfNonZeroRank:$result); + let assemblyFormat = [{ + $lhs `,` $rhs `,` $acc attr-dict `:` type($lhs) `,` type($rhs) `,` type($acc) `->` type($result) + }]; + let hasCanonicalizer = 1; + let hasVerifier = 1; +} + +def TPU_ConcatenateOp : TPU_Op<"concatenate", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins + Variadic:$sources, + I32Attr:$dimension + ); + let results = (outs AnyVectorOfNonZeroRank:$output); + let assemblyFormat = [{ + $sources `in` $dimension attr-dict `:` type($sources) `->` type($output) + }]; + let hasVerifier = 1; +} + +def TPU_BitcastOp : TPU_Op<"bitcast", [Pure]> { + let arguments = (ins AnyVectorOfNonZeroRank:$input); + let results = (outs AnyVectorOfNonZeroRank:$output); + let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }]; + let hasVerifier = 1; +} + +def TPU_BitcastVregOp : TPU_Op<"bitcast_vreg", [Pure]> { + let arguments = (ins TPU_Vreg:$input); + let results = (outs TPU_Vreg:$output); + let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }]; + let hasFolder = 1; +} + +def TPU_WeirdOp : TPU_Op<"weird", [Pure, ElementwiseMappable]> { + let arguments = (ins AnyType:$input); // F32 vector or scalar + let results = (outs AnyType:$output); // I1 vector or scalar + let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }]; + let hasVerifier = 1; +} + +def TPU_ReciprocalOp : TPU_Op<"reciprocal", [Pure, SameOperandsAndResultType, ElementwiseMappable]> { + let arguments = (ins + AnyVectorOfNonZeroRank:$input, + DefaultValuedAttr:$approx + ); + let results = (outs AnyVectorOfNonZeroRank:$output); + let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }]; + let hasVerifier = 1; +} + +def TPU_StochasticConvertOp : TPU_Op<"stochastic_convert", [Pure, SameOperandsAndResultShape]> { + let arguments = (ins + VectorOfNonZeroRankOf<[F32]>:$input, + VectorOfNonZeroRankOf<[I32]>:$random + ); + let results = (outs AnyVectorOfNonZeroRank:$output); + let assemblyFormat = [{ $input `,` $random attr-dict `:` type($input) `,` type($random) `->` type($output) }]; +} + +def TPU_StochasticConvertElementwiseOp : TPU_Op<"stochastic_convert_elementwise", [Pure, ElementwiseMappable]> { + // Stochastically converts the input to the target dtype based on the mode. + // When the target dtype is less than 32 bits, the result occupies the lowest {bitwidth} bits in the I32 output. + let arguments = (ins + VectorOfNonZeroRankOf<[F32]>:$input, + VectorOfNonZeroRankOf<[I32]>:$random, + TypeAttr:$dst_type + ); + let results = (outs VectorOfNonZeroRankOf<[I32]>:$output); + let assemblyFormat = [{ $input `,` $random attr-dict `:` type($input) `,` type($random) `->` type($output) }]; + let hasVerifier = 1; +} + +def TPU_RollVectorsOp : TPU_Op<"roll_vectors", [Pure]> { + let arguments = (ins Variadic:$input); + let results = (outs AnyVectorOfAnyRank:$output); + let assemblyFormat = [{ + $input attr-dict `:` type($input) `->` type($output) + }]; +} + +def TPU_UnrollVectorsOp : TPU_Op<"unroll_vectors", [Pure]> { + let arguments = (ins AnyVectorOfAnyRank:$input); + let results = (outs Variadic:$output); + let hasCanonicalizeMethod = 1; + let assemblyFormat = [{ + $input attr-dict `:` type($input) `->` type($output) + }]; +} + +def TPU_CreateMaskOp : TPU_Op<"create_mask", [Pure, SameVariadicOperandSize]> { + // high is exclusive + let arguments = (ins Variadic:$low, Variadic:$high); + let results = (outs AnyType:$output); + let assemblyFormat = [{ + `[` $low `]``[` $high `]` attr-dict `:` type($output) + }]; +} + +def TPU_CreateSubelementMaskOp : TPU_Op<"create_subelement_mask", [Pure]> { + let summary = "Create a mask masking contiguous rows of subelements."; + let description = [{ + The "half-sublanes", "quarter-sublanes", etc. (unit is determined by + the type of `output`) of the mask are masked in the range specified by + `from` and `to`. + + - If `from <= to`, the range `[from, to)` is set and the rest is unset. + - If `to <= from`, the range `[to, from)` is unset and the rest is set. + + All lanes are set identically. + + Example: + + ```mlir + %msk = tpu.create_subelement_mask 3, 9 : vector<8x128x2xi1> + ``` + + This creates a mask `%msk` where, for all `lane`s, `%msk[*][lane][*]` is: + + ``` + [[0, 0], [0, 1], [1, 1], [1, 1], [1, 0], [0, 0], [0, 0], [0, 0]] + ``` + + It is currently only supported: + - In TPU v4, for `num_subelems` of 1 and 2. + - In TPU v5, for `num_subelems` of 1, 2, and 4. + }]; + let arguments = (ins + I32Attr:$from, // inclusive + I32Attr:$to // exclusive + ); + let results = (outs AnyType:$output); // Verify this is a vmsk with num_subelems + let assemblyFormat = [{ + $from `,` $to attr-dict `:` type($output) + }]; +} + +def TPU_AssumeMultipleOp : TPU_Op<"assume_multiple", [Pure, SameOperandsAndResultType]> { + let summary = "Assumes that a value is a multiple of a given integer."; + let description = [{ + This operation is a hint to the compiler that the input `value` is guaranteed + to be a multiple of `multiple`. This can be used to satisfy divisibility checks + in some compiler passes. + + The result is the same as the input `value`. + + Example: + + ```mlir + %val = tpu.assume_multiple %arg0, 16 : index + ``` + }]; + let arguments = (ins + AnyTypeOf<[Index, AnyInteger]>:$value, + I32Attr:$multiple + ); + let results = (outs AnyTypeOf<[Index, AnyInteger]>:$result); + let assemblyFormat = [{$value `,` $multiple attr-dict `:` type($result)}]; + let hasVerifier = 1; +} + +def TPU_MemRefSliceOp : TPU_Op<"memref_slice", [Pure, AttrSizedOperandSegments]> { + let arguments = (ins + AnyMemRef:$mem_ref, + Variadic:$base_idx, + Variadic:$dynamic_sizes + ); + let results = (outs AnyMemRef:$result); + let assemblyFormat = [{ + $mem_ref `[` $base_idx `]` (`<` $dynamic_sizes^ `>`)? + attr-dict `:` type($mem_ref) `->` type($result) + }]; + let hasVerifier = 1; + let hasCanonicalizer = 1; +} + +def TPU_MemRefSqueezeOp : TPU_Op<"memref_squeeze", [Pure]> { + let arguments = (ins AnyMemRef:$input); + let results = (outs AnyMemRef:$result); + let assemblyFormat = [{ + $input attr-dict `:` type($input) `->` type($result) + }]; + let hasVerifier = 1; + let hasCanonicalizeMethod = 1; +} + +def TPU_MemRefReshapeOp : TPU_Op<"memref_reshape", [Pure]> { + let arguments = (ins AnyMemRef:$input); + let results = (outs AnyMemRef:$result); + let assemblyFormat = [{ + $input attr-dict `:` type($input) `->` type($result) + }]; + let hasVerifier = 1; + let hasCanonicalizeMethod = 1; +} + +def TPU_MemRefBitcastOp : TPU_Op<"memref_bitcast", [Pure]> { + let arguments = (ins AnyMemRef:$input); + let results = (outs AnyMemRef:$result); + let assemblyFormat = [{ + $input attr-dict `:` type($input) `->` type($result) + }]; + let hasVerifier = 1; + let hasCanonicalizeMethod = 1; +} + +def TPU_ReinterpretCastOp : TPU_Op<"reinterpret_cast", [Pure]> { + let arguments = (ins AnyMemRef:$input); + let results = (outs AnyMemRef:$result); + let assemblyFormat = [{ + $input attr-dict `:` type($input) `->` type($result) + }]; + let hasVerifier = 1; + let hasCanonicalizeMethod = 1; +} + +def TPU_AssumeLayoutOp : TPU_Op<"assume_layout", [Pure]> { + let arguments = (ins AnyType:$input); + let results = (outs AnyType:$result); + let assemblyFormat = [{ + $input attr-dict `:` type($input) `->` type($result) + }]; +} + +// Erases the layout attribute from the memref. +// +// The resulting memref is identical to the input, except that it has an +// identity layout. +def TPU_EraseLayoutOp : TPU_Op<"erase_memref_layout", [Pure, InferTypeOpAdaptor]> { + let arguments = (ins AnyMemRef:$operand); + let results = (outs AnyMemRef:$result); + let assemblyFormat = [{ + $operand attr-dict `:` type($operand) `->` type($result) + }]; + let hasFolder = 1; +} + +// Returns the ID of the current device. +// +// On the input to the compiler the return value is a logical ID in the XLA +// device assignment. It changes to a physical ID after the +// logical-to-physical-device-id pass. +def TPU_DeviceIdOp : TPU_Op<"device_id", [Pure]> { + let arguments = (ins); + let results = (outs I32:$result); + let assemblyFormat = [{ attr-dict `:` type($result) }]; +} + +def TPU_SemaphoreReadOp : TPU_Op<"sem_read"> { + let arguments = (ins MemRefOf<[TPU_SemaphoreType, TPU_DMASemaphoreType]>:$semaphore); + let results = (outs I32:$result); + let assemblyFormat = [{ $semaphore attr-dict `:` type($semaphore) `->` type($result)}]; +} + +def TPU_SemaphoreWaitOp : TPU_Op<"sem_wait"> { + let arguments = (ins + MemRefOf<[TPU_SemaphoreType]>:$semaphore, + I32:$amount + ); + let results = (outs); + let assemblyFormat = [{ $semaphore `,` $amount attr-dict `:` type($semaphore)}]; + let hasVerifier = 1; +} + +def TPU_AllocaSemaphoreOp : TPU_Op<"sem_alloc"> { + let arguments = (ins); + let results = (outs MemRefOf<[TPU_SomeSemaphoreType]>:$result); + let assemblyFormat = [{ attr-dict `:` type($result) }]; +} + +def TPU_GetBarrierSemaphoreOp : TPU_Op<"sem_barrier"> { + let arguments = (ins); + let results = (outs MemRefOf<[TPU_SemaphoreType]>:$semaphore); + let assemblyFormat = [{ attr-dict `:` type($semaphore) }]; + let hasVerifier = 1; +} + +def TPU_SemaphoreSignalOp : TPU_Op<"sem_signal", [AttrSizedOperandSegments]> { + let arguments = (ins + MemRefOf<[TPU_SemaphoreType]>:$semaphore, + I32:$amount, + Optional:$device_id, // For remote DMAs + Optional:$core_id, // For megacore + OptionalAttr:$core_type + ); +let assemblyFormat = [{ + $semaphore `,` $amount (`device_id` $device_id^)? (`core_id` $core_id^)? (`core_type` $core_type^)? attr-dict `:` type($semaphore) + }]; + let hasVerifier = 1; + let builders = [ + // A backward-compatible builder that sets `core_type` to nullptr. + OpBuilder<(ins "Value":$semaphore, "Value":$amount, + "Value":$device_id, "Value":$core_id)>, + ]; +} + +def TPU_BarrierOp : TPU_Op<"barrier"> { + let summary = [{Barrier synchronization across SC vector subcores.}]; + let description = [{ + Performs barrier synchronization across all SC vector subcores at the + specified barrier id. + }]; + let arguments = (ins Index:$barrier_id); + let results = (outs); + let assemblyFormat = [{ `barrier_id` `(` $barrier_id `)` attr-dict }]; +} + +// tpu.enqueue_dma enqueues a DMA operation. +// +// source : Memref to copy from. +// source_semaphore : Semaphore to signal after the DMA completes. +// target : Memref to copy to. +// target_semaphore : Semaphore to wait on before the DMA completes. +// device_id : The id of the device to copy to for remote DMAs. +// core_id : The id of the core to copy to for remote and cross-core +// DMAs. +// priority : The priority of the DMA. +// strict_ordering : True if the DMA requires strict ordering. If false, the +// ordering is either strict or relaxed depending on the +// source and destination. +def TPU_EnqueueDMAOp : TPU_Op<"enqueue_dma", [AttrSizedOperandSegments]> { + let arguments = (ins + AnyMemRef:$source, + Optional>:$source_semaphore, // For remote DMAs + AnyMemRef:$target, + MemRefOf<[TPU_DMASemaphoreType]>:$target_semaphore, + Optional:$device_id, // For remote DMAs + Optional:$core_id, // For megacore + // Smaller number means higher priority. 0 is the highest and the default. + DefaultValuedAttr:$priority, + DefaultValuedAttr:$strict_ordering + ); + let assemblyFormat = [{ + `source` `(` $source `:` type($source) `)` + `target` `(` $target `:` type($target) `)` + (`source_semaphore` `(` $source_semaphore^ `:` type($source_semaphore) `)`)? + `target_semaphore` `(` $target_semaphore `:` type($target_semaphore) `)` + (`device_id` `(` $device_id^ `)`)? + (`core_id` `(` $core_id^ `)`)? + attr-dict + }]; + let hasVerifier = 1; + let hasCanonicalizeMethod = 1; +} + +// A base class for all ops that need to differentiate between gather and +// scatter. +class IndirectDMAOp { + code extraBaseClassDeclaration = [{ + // Return true if this op performs a gather. Returns false if it performs a + // scatter. + FailureOr isGather(); + }]; +} + +// tpu.enqueue_indirect_dma copies data between HBM and VMEM, or between +// VMEM_SHARED and VMEM using indirect HBM offsets. +// +// If the source is in HBM or VMEM_SHARED and the target is in VMEM, performs a +// gather from the source (operand) at the offsets to the target (gather +// result). +// If the source is in VMEM and the target is in HBM or VMEM_SHARED, performs a +// scatter of the source (updates) to the target (operand) at the offsets. +// +// source : Memref to copy from. +// target : Memref to copy to. +// offsets : Gather or scatter offsets. +// semaphore : Semaphore to wait on; receive semaphore for scatter, send semaphore for gather. +// add : If true, add source values to target values. Otherwise, overwrite. +// offset_filter : If set, don't write values at offsets whose value is equal to +// the filter value. +def TPU_EnqueueIndirectDMAOp : TPU_Op<"enqueue_indirect_dma">, IndirectDMAOp { + let arguments = (ins + AnyMemRef:$source, + AnyMemRef:$target, + AnyTypeOf<[MemRefOf<[I32]>, VectorOfRankAndType<[1], [I32]>]>:$offsets, + MemRefOf<[TPU_DMASemaphoreType]>:$semaphore, + Optional:$offset_filter, + DefaultValuedAttr:$add + ); + let assemblyFormat = [{ + `source` `(` $source `:` type($source) `)` + `target` `(` $target `:` type($target) `)` + `offsets` `(` $offsets `:` type($offsets) `)` + (`offset_filter` `(` $offset_filter^ `)`)? + `semaphore` `(` $semaphore `:` type($semaphore) `)` + attr-dict + }]; + let hasVerifier = 1; + let extraClassDeclaration = extraBaseClassDeclaration # [{ + LogicalResult verifyGather(MemRefType operand_ty, + ArrayRef offsets_shape, + MemRefType result_ty); + LogicalResult verifyScatter(MemRefType updates_ty, + ArrayRef offsets_shape, + MemRefType operand_ty); + }]; + let hasCanonicalizeMethod = 1; +} + +// tpu.wait_dma2 waits for a DMA to complete. +// +// The number of bytes to wait for is determined based on the size of the +// destination memref. +def TPU_WaitDMA2Op : TPU_Op<"wait_dma2", [AttrSizedOperandSegments]> { + let arguments = (ins + MemRefOf<[TPU_DMASemaphoreType]>:$semaphore, + AnyMemRef:$src, + AnyMemRef:$dst, + Optional:$device_id, // For remote DMAs + Optional:$core_id, // For megacore + DefaultValuedAttr:$strict_ordering + ); + let assemblyFormat = [{ + `semaphore` `(` $semaphore `:` type($semaphore) `)` + `src` `(` $src `:` type($src) `)` + `dst` `(` $dst `:` type($dst) `)` + (`device_id` `(` $device_id^ `)`)? + (`core_id` `(` $core_id^ `)`)? + attr-dict + }]; + let hasVerifier = 1; + // A backward-compatible builder that sets `device_id` and `core_id` to nullptr. + let builders = [ + OpBuilder<(ins "Value":$semaphore, "Value":$src, "Value":$dst)> + ]; + let hasCanonicalizeMethod = 1; +} + +// TODO(b/395630795): Remove after 2025-08-10. +def TPU_WaitDMAOp : TPU_Op<"wait_dma"> { + let arguments = (ins + MemRefOf<[TPU_DMASemaphoreType]>:$semaphore, + AnyMemRef:$ref + ); + let hasVerifier = 1; +} + +// Like tpu.wait_dma2, but for indirect DMAs. +// +// The number of bytes to wait for is determined based on the size of the +// destination memref in a gather, and the size of the source memref in a +// scatter. The op differentiates between gather and scatter based on the memory +// spaces of the source and destination memrefs. +def TPU_WaitIndirectDMAOp : TPU_Op<"wait_indirect_dma">, IndirectDMAOp { + let arguments = (ins + MemRefOf<[TPU_DMASemaphoreType]>:$semaphore, + AnyMemRef:$src, + AnyMemRef:$dst + ); + let assemblyFormat = [{ + `semaphore` `(` $semaphore `:` type($semaphore) `)` + `src` `(` $src `:` type($src) `)` + `dst` `(` $dst `:` type($dst) `)` + attr-dict + }]; + let hasVerifier = 1; + let hasCanonicalizeMethod = 1; + let extraClassDeclaration = extraBaseClassDeclaration; +} + +def TPU_RegionOp : TPU_Op<"region", [RecursiveMemoryEffects, SingleBlockImplicitTerminator<"tpu::YieldOp">]> { + let arguments = (ins); + let results = (outs Variadic:$results); + let regions = (region AnyRegion:$region); + let hasVerifier = 1; +} + +def TPU_TraceOp : TPU_Op<"trace", [RecursiveMemoryEffects, SingleBlockImplicitTerminator<"tpu::YieldOp">]> { + let arguments = (ins StrAttr:$message, I32Attr:$level); + let results = (outs Variadic:$results); + let regions = (region AnyRegion:$region); +} + +def TPU_TraceStartOp : TPU_Op<"trace_start", []> { + let arguments = (ins StrAttr:$message, I32Attr:$level); + let results = (outs); +} + +def TPU_TraceStopOp : TPU_Op<"trace_stop", []> { + let arguments = (ins); + let results = (outs); +} + +def TPU_TraceValueOp : TPU_Op<"trace_value", []> { + let summary = "Emit a scalar value as a trace event."; + let description = [{ + Emits a dynamic scalar value as a trace event. + When lowered to LLO, creates its own trace scope (start/value/stop). + Supported types: i32, f32. + + Example: + ```mlir + %count = arith.constant 42 : i32 + tpu.trace_value %count, "my_value" : i32 + ``` + }]; + let arguments = (ins + AnyTypeOf<[I32, F32]>:$value, + StrAttr:$label + ); + let results = (outs); + let assemblyFormat = [{ $value `,` $label attr-dict `:` type($value) }]; +} + +def TPU_YieldOp : TPU_Op<"yield", [Pure, ReturnLike, Terminator]> { + let arguments = (ins Variadic:$results); + let assemblyFormat = [{ attr-dict ($results^ `:` type($results))? }]; +} + +def TPU_DelayOp : TPU_Op<"delay"> { + let arguments = (ins I32:$nanos); + let results = (outs); +} + +// Expands the granularity of mask to subelements. +def TPU_MaskCastOp : TPU_Op<"mask_cast", [Pure]> { + let description = [{ + Cast a mask register into a different packing. + + If casting to a type with smaller packing, then values being packed together + must be identical. For example, for 8x128x4xi1 -> 8x128x2xi1, + input[i, j, 0] == input[i, j, 1] and input[i, j, 2] == input[i, j, 3] must + hold for all i, j. Otherwise, the result is undefined. + }]; + let arguments = (ins VectorOfNonZeroRankOf<[I1]>:$input); + let results = (outs VectorOfNonZeroRankOf<[I1]>:$result); + let assemblyFormat = [{ + $input attr-dict `:` type($input) `->` type($result) + }]; + let hasVerifier = 1; +} + +def TPU_GetIterationBoundOp : TPU_Op<"iteration_bound"> { + let arguments = (ins I32Attr:$dim); + let results = (outs I32:$result); + let assemblyFormat = [{ $dim attr-dict `:` type($result) }]; +} + +def TPU_GetInternalScratchOp : TPU_Op<"internal_scratch"> { + let arguments = (ins); + let results = (outs AnyMemRef:$result); + let assemblyFormat = [{ attr-dict `:` type($result) }]; +} + +def TPU_PRNGSeed32Op : TPU_Op<"prng_set_seed_32"> { + let arguments = (ins Variadic:$seeds); + let results = (outs); +} + +def TPU_PRNGRandomBitsOp : TPU_Op<"prng_random_bits"> { + let arguments = (ins); + let results = (outs AnyVectorOfNonZeroRank:$output); +} + +def TPU_SublaneShuffleOp : TPU_Op<"sublane_shuffle", [SameOperandsAndResultType]> { + // This op takes 2 physical vregs and a pattern, applies the pattern, + // and returns the result as 1 vreg. + // + // The pattern is a list of integers, where the integer value is the + // index of the sublane in the *combined input* [lhs, rhs], and the + // position of the integer in the list is the index of the sublane + // in the *output* vreg. + // + // The pattern size must match the operand/result sublane count. + // + // Example: + // %0 = tpu.single_output_sublane_shuffle %a, %b, + // [0, 1, 2, 3, 4, 5, 6, 7] // Result is %a + // %1 = tpu.single_output_sublane_shuffle %a, %b, + // [8, 9, 10, 11, 12, 13, 14, 15] // Result is %b + // %2 = tpu.single_output_sublane_shuffle %a, %b, + // [7, 6, 5, 4, 11, 10, 9, 8] // Result uses high half of a + // // and low half of b, reversed. + let arguments = (ins + TPU_Vreg:$lhs, + TPU_Vreg:$rhs, + DenseI32ArrayAttr:$pattern + ); + let results = (outs TPU_Vreg:$result); + let assemblyFormat = [{ + $lhs `,` $rhs `,` $pattern attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) + }]; + + let hasVerifier = 1; +} + +def TPU_TransposeOp : TPU_Op<"transpose", [Pure]> { + let summary = "tpu transpose operation"; + let arguments = (ins AnyVectorOfAnyRank:$vector, + DenseI64ArrayAttr:$permutation); + let results = (outs AnyVectorOfAnyRank:$result); + + let assemblyFormat = [{ + $vector `,` $permutation attr-dict `:` type($vector) `->` type($result) + }]; + let extraClassDeclaration = [{ + VectorType getSourceVectorType() { + return ::llvm::cast(getVector().getType()); + } + VectorType getResultVectorType() { + return ::llvm::cast(getResult().getType()); + } + }]; + let hasVerifier = 1; +} + +def TPU_LogOp : TPU_Op<"log"> { + let arguments = (ins + Variadic:$inputs, + StrAttr:$tag, + DefaultValuedAttr:$formatted + ); + let results = (outs); + let assemblyFormat = [{ $tag attr-dict (`:` `[` $inputs^ `]` `:` type($inputs))? }]; + let hasVerifier = 1; +} + +def TPU_LogBufferOp : TPU_Op<"log_buffer"> { + let arguments = (ins + AnyMemRef:$input, + DenseI64ArrayAttr:$shape, + StrAttr:$tag + ); + let results = (outs); + let assemblyFormat = [{ $tag attr-dict `:` $input `:` type($input) }]; + let hasVerifier = 1; +} + +#endif // TPU_OPS diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops_verification_test.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops_verification_test.cc new file mode 100644 index 000000000000..f54cb40ac548 --- /dev/null +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops_verification_test.cc @@ -0,0 +1,1310 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include +#include +#include "absl/status/status.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Support/LLVM.h" +#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" +#include "xla/mlir/utils/error_util.h" + +namespace mlir::tpu { +namespace { + +using ::testing::_; +using ::testing::HasSubstr; +using ::testing::status::StatusIs; + +class TpuOpsVerificationTest : public ::testing::Test { + protected: + TpuOpsVerificationTest() + : context_([]() { + DialectRegistry registry; + registry.insert(); + return registry; + }()), + builder_(UnknownLoc::get(&context_), &context_) { + context_.loadAllAvailableDialects(); + context_.printOpOnDiagnostic(true); + } + ~TpuOpsVerificationTest() { + for (int i = ops_.size() - 1; i >= 0; --i) { + ops_[i]->erase(); + } + } + + template + OpTy Create(Args&&... args) { + OpTy op = OpTy::create(builder_, std::forward(args)...); + ops_.push_back(op.getOperation()); + return op; + } + + template + absl::Status VerifyOp(OpTy op) { + BaseScopedDiagnosticHandler diag(&context_); + if (op.verify().succeeded()) { + return absl::OkStatus(); + } + return diag.ConsumeStatus(); + } + + Type i32() { return builder_.getI32Type(); } + + MemRefType GetMemRefType( + ArrayRef shape, Type element_type, + std::optional memory_space = std::nullopt) { + return MemRefType::get( + shape, element_type, nullptr, + memory_space.has_value() + ? MemorySpaceAttr::get(builder_.getContext(), *memory_space) + : Attribute()); + } + + Value AllocaI32(ArrayRef shape, + std::optional memory_space = std::nullopt) { + return Create(GetMemRefType(shape, i32(), memory_space)) + .getMemref(); + } + + Value AllocaSemaphore() { + return Create( + GetMemRefType({}, SemaphoreType::get(builder_.getContext()), + MemorySpace::kSemaphoreMem)) + .getResult(); + } + + Value ConstantI1Vector(ArrayRef shape, ArrayRef values) { + return Create( + /*result=*/VectorType::get(shape, builder().getI1Type()), + /*value=*/builder().getBoolVectorAttr(values)) + .getResult(); + } + + Value ConstantI8Vector(ArrayRef shape, ArrayRef values) { + return Create( + /*result=*/VectorType::get(shape, builder().getI8Type()), + /*value=*/dyn_cast( + builder().getDenseI8ArrayAttr(values))) + .getResult(); + } + + Value ConstantIndexVector(ArrayRef shape, ArrayRef values) { + return Create( + /*result=*/VectorType::get(shape, builder().getIndexType()), + /*value=*/builder().getIndexVectorAttr(values)) + .getResult(); + } + + Value ConstantI32Vector(ArrayRef shape, ArrayRef values) { + return Create( + /*result=*/VectorType::get(shape, i32()), + /*value=*/builder().getI32VectorAttr(values)) + .getResult(); + } + + Value ConstantBF16Vector(ArrayRef shape, float value) { + VectorType bf16_vector_type = + VectorType::get(shape, builder().getBF16Type()); + return Create( + /*result=*/bf16_vector_type, + /*value=*/SplatElementsAttr::get( + bf16_vector_type, + builder().getFloatAttr(builder().getBF16Type(), value))) + .getResult(); + } + + Value ConstantF32Vector(ArrayRef shape, ArrayRef values) { + auto ty = VectorType::get(shape, builder().getF32Type()); + return Create( + /*result=*/ty, + /*value=*/DenseElementsAttr::get(ty, values)) + .getResult(); + } + + ImplicitLocOpBuilder& builder() { return builder_; } + + private: + MLIRContext context_; + ImplicitLocOpBuilder builder_; + std::vector ops_; +}; + +class TpuOpsVectorSubcoreVerificationTest : public TpuOpsVerificationTest { + protected: + TpuOpsVectorSubcoreVerificationTest() { + auto func_op = Create("vector_kernel", + builder().getFunctionType({}, {})); + func_op->setAttr( + TPUDialect::GetCoreTypeKey(), + CoreTypeAttr::get(builder().getContext(), CoreType::kScVectorSubcore)); + builder().setInsertionPointToStart(func_op.addEntryBlock()); + } +}; + +TEST_F(TpuOpsVerificationTest, VectorLoadVerificationWorks) { + auto c0 = Create(0); + Value memref = AllocaI32({8}, MemorySpace::kVmem); + auto vl = Create( + /*result=*/VectorType::get({8}, i32()), + /*base=*/memref, + /*indices=*/ValueRange{c0}, + /*strides=*/builder().getDenseI32ArrayAttr({}), + /*mask=*/nullptr); + + ASSERT_OK(VerifyOp(vl)); +} + +TEST_F(TpuOpsVerificationTest, + VectorLoadRankOfStridesDoesNotMatchBaseMemrefRank) { + auto c0 = Create(0); + Value memref = AllocaI32({8}, MemorySpace::kVmem); + auto vl = Create( + /*result=*/VectorType::get({8}, i32()), + /*base=*/memref, + /*indices=*/ValueRange{c0}, + /*strides=*/builder().getDenseI32ArrayAttr({1, 1, 1, 1}), + /*mask=*/nullptr); + ASSERT_THAT(VerifyOp(vl), StatusIs(_, HasSubstr("Expected 1 strides."))); +} + +TEST_F(TpuOpsVerificationTest, VectorLoadStridesFeatureNotImplemented) { + auto c0 = Create(0); + Value memref = AllocaI32({8}, MemorySpace::kVmem); + auto vl = Create( + /*result=*/VectorType::get({8}, i32()), + /*base=*/memref, + /*indices=*/ValueRange{c0}, + /*strides=*/builder().getDenseI32ArrayAttr({1}), + /*mask=*/nullptr); + ASSERT_THAT( + VerifyOp(vl), + StatusIs( + _, HasSubstr("Not implemented: general vector load with strides."))); +} + +TEST_F(TpuOpsVerificationTest, VectorLoadBaseAndResultTypesDoNotMatch) { + auto c0 = Create(0); + Value memref = AllocaI32({8}, MemorySpace::kVmem); + auto vl = Create( + /*result=*/VectorType::get({8}, builder().getF32Type()), + /*base=*/memref, + /*indices=*/ValueRange{c0}, + /*strides=*/builder().getDenseI32ArrayAttr({}), + /*mask=*/nullptr); + + ASSERT_THAT( + VerifyOp(vl), + StatusIs(_, + HasSubstr("Expected base and result element type to match."))); +} + +TEST_F(TpuOpsVerificationTest, + VectorLoadRankOfIndicesDoesNotMatchBaseMemrefRank) { + auto c0 = Create(0); + Value memref = AllocaI32({8}, MemorySpace::kVmem); + auto vl = Create( + /*result=*/VectorType::get({8}, i32()), + /*base=*/memref, + /*indices=*/ValueRange{c0, c0, c0}, + /*strides=*/builder().getDenseI32ArrayAttr({}), + /*mask=*/nullptr); + + ASSERT_THAT(VerifyOp(vl), StatusIs(_, HasSubstr("Expected 1 indices."))); +} + +TEST_F(TpuOpsVerificationTest, VectorLoadValidMaskSucceeds) { + auto c0 = Create(0); + Value memref = AllocaI32({8, 128}, MemorySpace::kVmem); + Value mask = ConstantI32Vector(/*shape=*/{8, 1}, + /*values=*/{1, 1, 1, 1, 1, 1, 1, 1}); + auto vl = Create( + /*result=*/VectorType::get({8, 128}, i32()), + /*base=*/memref, + /*indices=*/ValueRange{c0, c0}, + /*strides=*/builder().getDenseI32ArrayAttr({}), + /*mask=*/mask); + + ASSERT_OK(VerifyOp(vl)); +} + +TEST_F(TpuOpsVerificationTest, VectorLoadMaskInvalidResultBitWidth) { + auto c0 = Create(0); + auto memref = Create( + GetMemRefType({8, 128}, builder().getI64Type(), MemorySpace::kVmem)); + Value mask = ConstantI32Vector(/*shape=*/{8, 1}, + /*values=*/{1, 1, 1, 1, 1, 1, 1, 1}); + auto vl = Create( + /*result=*/VectorType::get({8, 128}, builder().getI64Type()), + /*base=*/memref.getMemref(), + /*indices=*/ValueRange{c0, c0}, + /*strides=*/builder().getDenseI32ArrayAttr({}), + /*mask=*/mask); + + ASSERT_THAT( + VerifyOp(vl), + StatusIs( + _, HasSubstr( + "Not implemented: masked load with non-32-bit element type"))); +} + +TEST_F(TpuOpsVerificationTest, + VectorLoadMaskNotBroadcastableToResultShapeInvalidMinor) { + auto c0 = Create(0); + Value memref = AllocaI32({8, 128}, MemorySpace::kVmem); + Value mask = ConstantI32Vector(/*shape=*/{8, 2}, + /*values=*/{1}); + auto vl = Create( + /*result=*/VectorType::get({8, 128}, i32()), + /*base=*/memref, + /*indices=*/ValueRange{c0, c0}, + /*strides=*/builder().getDenseI32ArrayAttr({}), + /*mask=*/mask); + + ASSERT_THAT( + VerifyOp(vl), + StatusIs( + _, HasSubstr( + "Expected mask shape to be broadcastable to result shape."))); +} + +TEST_F(TpuOpsVerificationTest, + VectorLoadMaskNotBroadcastableToResultShapeInvalidMajor) { + auto c0 = Create(0); + Value memref = AllocaI32({8, 128}, MemorySpace::kVmem); + Value mask = ConstantI32Vector(/*shape=*/{5, 1}, + /*values=*/{1}); + auto vl = Create( + /*result=*/VectorType::get({8, 128}, i32()), + /*base=*/memref, + /*indices=*/ValueRange{c0, c0}, + /*strides=*/builder().getDenseI32ArrayAttr({}), + /*mask=*/mask); + + ASSERT_THAT( + VerifyOp(vl), + StatusIs( + _, HasSubstr( + "Expected mask shape to be broadcastable to result shape."))); +} + +TEST_F(TpuOpsVerificationTest, VectorLoadInvalidMemorySpace) { + auto c0 = Create(0); + Value memref = AllocaI32({8}, MemorySpace::kHbm); + auto vl = Create( + /*result=*/VectorType::get({8}, i32()), + /*base=*/memref, + /*indices=*/ValueRange{c0}, + /*strides=*/builder().getDenseI32ArrayAttr({}), + /*mask=*/nullptr); + + ASSERT_THAT(VerifyOp(vl), + StatusIs(_, HasSubstr("Expected base memref to be in VMEM."))); +} + +TEST_F(TpuOpsVerificationTest, VectorStoreInvalidMemorySpace) { + auto c0 = Create(0); + Value memref = AllocaI32({8}, MemorySpace::kHbm); + Value vector_to_store = ConstantI32Vector(/*shape=*/{8}, /*values=*/{1}); + auto vs = Create( + /*valueToStore=*/vector_to_store, + /*base=*/memref, + /*indices=*/ValueRange{c0}, + /*strides=*/builder().getDenseI32ArrayAttr({}), + /*mask=*/nullptr); + + ASSERT_THAT(VerifyOp(vs), + StatusIs(_, HasSubstr("Expected base memref to be in VMEM."))); +} + +TEST_F(TpuOpsVerificationTest, UnpackSubelementsValidIndex) { + Value source = ConstantI8Vector(/*shape=*/{4, 8}, /*values=*/{1}); + auto unpack = Create( + /*output=*/VectorType::get({16}, builder().getI16Type()), source, + /*index=*/builder().getI32IntegerAttr(1), + /*pack_format=*/ + PackFormatAttr::get(builder().getContext(), PackFormat::kInterleaved)); + ASSERT_OK(VerifyOp(unpack)); +} + +TEST_F(TpuOpsVerificationTest, UnpackSubelementsInvalidIndex) { + Value source = ConstantI8Vector(/*shape=*/{4, 8}, /*values=*/{1}); + auto unpack = Create( + /*output=*/VectorType::get({16}, builder().getI16Type()), source, + /*index=*/builder().getI32IntegerAttr(4), + /*pack_format=*/ + PackFormatAttr::get(builder().getContext(), PackFormat::kInterleaved)); + ASSERT_THAT( + VerifyOp(unpack), + StatusIs( + _, HasSubstr("Index must be between 0 and the packing factor (2)"))); +} + +TEST_F(TpuOpsVerificationTest, VectorLoadIdxVerificationWorks) { + Value memref = AllocaI32({8}, MemorySpace::kVmem); + Value indices = ConstantIndexVector(/*shape=*/{8}, + /*values=*/{0, 1, 2, 3, 4, 5, 6, 7}); + auto vl = Create( + /*result=*/VectorType::get({8}, i32()), + /*base=*/memref, + /*indices=*/ValueRange{indices}, + /*mask=*/nullptr); + + ASSERT_OK(VerifyOp(vl)); +} + +TEST_F(TpuOpsVerificationTest, VectorLoadIdxInvalidMemorySpace) { + Value memref = AllocaI32({8}, MemorySpace::kHbm); + Value indices = ConstantIndexVector(/*shape=*/{8}, + /*values=*/{0, 1, 2, 3, 4, 5, 6, 7}); + auto vl = Create( + /*result=*/VectorType::get({8}, i32()), + /*base=*/memref, + /*indices=*/ValueRange{indices}, + /*mask=*/nullptr); + + ASSERT_THAT(VerifyOp(vl), + StatusIs(_, HasSubstr("Expected base memref to be in VMEM."))); +} + +TEST_F(TpuOpsVerificationTest, VectorLoadIdxInvalidElementType) { + Value memref = + Create( + GetMemRefType({8}, builder().getF32Type(), MemorySpace::kVmem)) + .getMemref(); + Value indices = ConstantIndexVector(/*shape=*/{8}, + /*values=*/{0}); + auto vl = Create( + /*result=*/VectorType::get({8}, i32()), + /*base=*/memref, + /*indices=*/ValueRange{indices}, + /*mask=*/nullptr); + + ASSERT_THAT( + VerifyOp(vl), + StatusIs(_, + HasSubstr("Expected base and result element type to match."))); +} + +TEST_F(TpuOpsVerificationTest, VectorLoadIdxInvalidIndicesDimension) { + Value memref = AllocaI32({8}, MemorySpace::kVmem); + Value indices = ConstantIndexVector(/*shape=*/{4, 1}, + /*values=*/{0}); + auto vl = Create( + /*result=*/VectorType::get({8}, i32()), + /*base=*/memref, + /*indices=*/ValueRange{indices, indices}, + /*mask=*/nullptr); + + ASSERT_THAT( + VerifyOp(vl), + StatusIs(_, HasSubstr("Expected one index vector for each dimension of " + "the base memref with dimension: 1. Got: 2."))); +} + +TEST_F(TpuOpsVerificationTest, VectorLoadIdxValidMask) { + Value memref = AllocaI32({8}, MemorySpace::kVmem); + Value indices = ConstantIndexVector(/*shape=*/{8}, + /*values=*/{0}); + Value mask = ConstantI32Vector(/*shape=*/{8}, + /*values=*/{1}); + auto vl = Create( + /*result=*/VectorType::get({8}, i32()), + /*base=*/memref, + /*indices=*/ValueRange{indices}, + /*mask=*/mask); + + ASSERT_OK(VerifyOp(vl)); +} + +TEST_F(TpuOpsVerificationTest, VectorLoadIdxInvalidMaskShape) { + Value memref = AllocaI32({8}, MemorySpace::kVmem); + Value indices = ConstantIndexVector(/*shape=*/{8}, + /*values=*/{0}); + Value mask = ConstantI32Vector(/*shape=*/{4, 2}, + /*values=*/{1}); + auto vl = Create( + /*result=*/VectorType::get({8}, i32()), + /*base=*/memref, + /*indices=*/ValueRange{indices}, + /*mask=*/mask); + + ASSERT_THAT( + VerifyOp(vl), + StatusIs( + _, HasSubstr( + "Expected mask shape to be broadcastable to result shape."))); +} + +TEST_F(TpuOpsVerificationTest, VectorStoreIdxVerificationWorks) { + Value memref = AllocaI32({8}, MemorySpace::kVmem); + Value vector_to_store = + ConstantI32Vector(/*shape=*/{8}, + /*values=*/{1, 1, 1, 1, 1, 1, 1, 1}); + Value indices = ConstantIndexVector(/*shape=*/{8}, + /*values=*/{0, 1, 2, 3, 4, 5, 6, 7}); + auto vl = Create( + /*vectorToStore=*/vector_to_store, + /*base=*/memref, + /*indices=*/ValueRange{indices}, + /*mask=*/nullptr, + /*add=*/builder().getBoolAttr(true)); + + ASSERT_OK(VerifyOp(vl)); +} + +TEST_F(TpuOpsVerificationTest, VectorStoreIdxInvalidMemorySpace) { + Value memref = AllocaI32({8}, MemorySpace::kHbm); + Value vector_to_store = + ConstantI32Vector(/*shape=*/{8}, + /*values=*/{1, 1, 1, 1, 1, 1, 1, 1}); + Value indices = ConstantIndexVector(/*shape=*/{8}, + /*values=*/{0, 1, 2, 3, 4, 5, 6, 7}); + auto vl = Create( + /*vectorToStore=*/vector_to_store, + /*base=*/memref, + /*indices=*/ValueRange{indices}, + /*mask=*/nullptr, + /*add=*/nullptr); + + ASSERT_THAT(VerifyOp(vl), + StatusIs(_, HasSubstr("Expected base memref to be in VMEM."))); +} + +TEST_F(TpuOpsVerificationTest, VectorStoreIdxInvalidElementType) { + Value memref = + Create( + GetMemRefType({8}, builder().getF32Type(), MemorySpace::kVmem)) + .getMemref(); + Value vector_to_store = ConstantI32Vector(/*shape=*/{8}, + /*values=*/{1}); + Value indices = ConstantIndexVector(/*shape=*/{8}, + /*values=*/{0}); + auto vl = Create( + /*vectorToStore=*/vector_to_store, + /*base=*/memref, + /*indices=*/ValueRange{indices}, + /*mask=*/nullptr, + /*add=*/nullptr); + + ASSERT_THAT( + VerifyOp(vl), + StatusIs(_, HasSubstr( + "Expected base and valueToStore element type to match"))); +} + +TEST_F(TpuOpsVerificationTest, VectorStoreIdxInvalidIndicesDimension) { + Value memref = AllocaI32({8}, MemorySpace::kVmem); + Value vector_to_store = ConstantI32Vector(/*shape=*/{8}, + /*values=*/{1}); + Value indices = ConstantIndexVector(/*shape=*/{4, 1}, + /*values=*/{0}); + auto vl = Create( + /*vectorToStore=*/vector_to_store, + /*base=*/memref, + /*indices=*/ValueRange{indices, indices}, + /*mask=*/nullptr, + /*add=*/nullptr); + + ASSERT_THAT( + VerifyOp(vl), + StatusIs(_, HasSubstr("Expected one index vector for each dimension of " + "the base memref with dimension: 1. Got: 2."))); +} + +TEST_F(TpuOpsVerificationTest, VectorStoreIdxInvalidValueToStoreDimension) { + Value memref = AllocaI32({8}, MemorySpace::kVmem); + Value vector_to_store = ConstantI32Vector(/*shape=*/{4, 2}, + /*values=*/{1}); + Value indices = ConstantIndexVector(/*shape=*/{8}, + /*values=*/{0}); + auto vl = Create( + /*vectorToStore=*/vector_to_store, + /*base=*/memref, + /*indices=*/ValueRange{indices}, + /*mask=*/nullptr, + /*add=*/nullptr); + + ASSERT_THAT(VerifyOp(vl), + StatusIs(_, HasSubstr("Expected value to have rank 1. Got: 2."))); +} + +TEST_F(TpuOpsVerificationTest, VectorStoreIdxValidMask) { + Value memref = AllocaI32({8}, MemorySpace::kVmem); + Value vector_to_store = ConstantI32Vector(/*shape=*/{8}, + /*values=*/{1}); + Value indices = ConstantIndexVector(/*shape=*/{8}, + /*values=*/{0}); + Value mask = ConstantI32Vector(/*shape=*/{8}, + /*values=*/{1}); + auto vl = Create( + /*vectorToStore=*/vector_to_store, + /*base=*/memref, + /*indices=*/ValueRange{indices}, + /*mask=*/mask, + /*add=*/nullptr); + + ASSERT_OK(VerifyOp(vl)); +} + +TEST_F(TpuOpsVerificationTest, VectorStoreIdxInvalidMaskShape) { + Value memref = AllocaI32({8}, MemorySpace::kVmem); + Value vector_to_store = ConstantI32Vector(/*shape=*/{8}, + /*values=*/{1}); + Value indices = ConstantIndexVector(/*shape=*/{8}, + /*values=*/{0}); + Value mask = ConstantI32Vector(/*shape=*/{4, 2}, + /*values=*/{1}); + auto vl = Create( + /*vectorToStore=*/vector_to_store, + /*base=*/memref, + /*indices=*/ValueRange{indices}, + /*mask=*/mask, + /*add=*/nullptr); + + ASSERT_THAT( + VerifyOp(vl), + StatusIs( + _, + HasSubstr( + "Expected mask shape to match result shape: (8). Got: (4, 2)."))); +} + +TEST_F(TpuOpsVectorSubcoreVerificationTest, ScanVerificationWorksI32) { + Value src = ConstantI32Vector(/*shape=*/{8}, /*values=*/{1}); + Type dst = VectorType::get(/*shape=*/{8}, /*type=*/builder().getI32Type()); + Value mask = ConstantI1Vector(/*shape=*/{8}, /*values=*/{1}); + + ASSERT_OK(VerifyOp(Create(dst, src, tpu::ReductionKind::kSum, mask))); +} + +TEST_F(TpuOpsVectorSubcoreVerificationTest, ScanVerificationWorksBF16) { + Value src = ConstantBF16Vector(/*shape=*/{2, 8}, /*value=*/1); + Type dst = + VectorType::get(/*shape=*/{2, 8}, /*type=*/builder().getBF16Type()); + Value mask = ConstantI1Vector(/*shape=*/{8}, /*values=*/{1}); + + ASSERT_OK(VerifyOp(Create(dst, src, tpu::ReductionKind::kSum, mask))); +} + +TEST_F(TpuOpsVectorSubcoreVerificationTest, ScanVerificationWorksI1) { + Value src = ConstantI1Vector(/*shape=*/{8}, /*values=*/{1}); + Type dst = VectorType::get(/*shape=*/{8}, /*type=*/builder().getI32Type()); + + ASSERT_OK(VerifyOp( + Create(dst, src, tpu::ReductionKind::kSum, /*mask=*/nullptr))); +} + +TEST_F(TpuOpsVerificationTest, ScanOnUnsupportedCore) { + auto func_op = + Create("scalar_kernel", builder().getFunctionType({}, {})); + func_op->setAttr(TPUDialect::GetCoreTypeKey(), + CoreTypeAttr::get(builder().getContext(), CoreType::kTc)); + builder().setInsertionPointToStart(func_op.addEntryBlock()); + Value src = ConstantI32Vector(/*shape=*/{8}, /*values=*/{1}); + Type dst = VectorType::get(/*shape=*/{8}, /*type=*/builder().getI32Type()); + Value mask = ConstantI1Vector(/*shape=*/{8}, /*values=*/{1}); + + ASSERT_THAT( + VerifyOp(Create(dst, src, tpu::ReductionKind::kSum, mask)), + StatusIs(_, + HasSubstr("Scan is supported only on the SC vector subcore"))); +} + +TEST_F(TpuOpsVectorSubcoreVerificationTest, + ScanVerificationInvalidOutputTypeWithI1Input) { + Value src = ConstantI1Vector(/*shape=*/{8}, /*values=*/{1}); + Type dst = VectorType::get(/*shape=*/{8}, /*type=*/builder().getI1Type()); + Value mask = ConstantI1Vector(/*shape=*/{8}, /*values=*/{1}); + + ASSERT_THAT( + VerifyOp(Create(dst, src, tpu::ReductionKind::kMin, mask)), + StatusIs( + _, + HasSubstr( + "Output element type must be i32 vector for i1 vector inputs."))); +} + +TEST_F(TpuOpsVectorSubcoreVerificationTest, + ScanVerificationMismatchElementType) { + Value src = ConstantI32Vector(/*shape=*/{8}, /*values=*/{1}); + Type dst = VectorType::get(/*shape=*/{8}, /*type=*/builder().getF32Type()); + Value mask = ConstantI1Vector(/*shape=*/{8}, /*values=*/{1}); + + ASSERT_THAT( + VerifyOp(Create(dst, src, tpu::ReductionKind::kSum, mask)), + StatusIs(_, HasSubstr("Input and output element type mismatch."))); +} + +TEST_F(TpuOpsVectorSubcoreVerificationTest, ScanVerificationMismatchShape) { + Value src = ConstantI32Vector(/*shape=*/{16}, /*values=*/{1}); + Type dst = VectorType::get(/*shape=*/{8}, /*type=*/builder().getI32Type()); + Value mask = ConstantI1Vector(/*shape=*/{16}, /*values=*/{1}); + + ASSERT_THAT( + VerifyOp(Create(dst, src, tpu::ReductionKind::kSum, mask)), + StatusIs(_, HasSubstr("Input and output shape mismatch. Input " + "shape: (16). Output shape: (8)."))); +} + +TEST_F(TpuOpsVectorSubcoreVerificationTest, ScanVerificationInvalidInputRank) { + Value src = ConstantI32Vector(/*shape=*/{8, 1, 1}, /*values=*/{1}); + Type dst = + VectorType::get(/*shape=*/{8, 1, 1}, /*type=*/builder().getI32Type()); + Value mask = ConstantI1Vector(/*shape=*/{8}, /*values=*/{1}); + + ASSERT_THAT( + VerifyOp(Create(dst, src, tpu::ReductionKind::kSum, mask)), + StatusIs(_, HasSubstr("Input must be a rank 1 or 2 vector."))); +} + +TEST_F(TpuOpsVectorSubcoreVerificationTest, + ScanVerificationInvalidReductionKind) { + Value src = ConstantI32Vector(/*shape=*/{8}, /*values=*/{1}); + Type dst = VectorType::get(/*shape=*/{8}, /*type=*/builder().getI32Type()); + Value mask = ConstantI1Vector(/*shape=*/{8}, /*values=*/{1}); + + ASSERT_THAT( + VerifyOp(Create(dst, src, tpu::ReductionKind::kArgMax, mask)), + StatusIs(_, + HasSubstr("Only sum, max and min reductions are supported."))); +} + +TEST_F(TpuOpsVectorSubcoreVerificationTest, + ScanVerificationInvalidReductionKindWithI1Input) { + Value src = ConstantI1Vector(/*shape=*/{8}, /*values=*/{1}); + Type dst = VectorType::get(/*shape=*/{8}, /*type=*/builder().getI32Type()); + Value mask = ConstantI1Vector(/*shape=*/{8}, /*values=*/{1}); + + ASSERT_THAT( + VerifyOp(Create(dst, src, tpu::ReductionKind::kMin, mask)), + StatusIs( + _, + HasSubstr("Only sum reduction is supported for i1 vector inputs."))); +} + +TEST_F(TpuOpsVectorSubcoreVerificationTest, + ScanVerificationInvalidMaskWithI1Input) { + Value src = ConstantI1Vector(/*shape=*/{8}, /*values=*/{1}); + Type dst = VectorType::get(/*shape=*/{8}, /*type=*/builder().getI32Type()); + Value mask = ConstantI1Vector(/*shape=*/{8}, /*values=*/{1}); + + ASSERT_THAT( + VerifyOp(Create(dst, src, tpu::ReductionKind::kSum, mask)), + StatusIs(_, HasSubstr("Mask is not supported for i1 vector inputs."))); +} + +TEST_F(TpuOpsVectorSubcoreVerificationTest, ScanVerificationInvalidMaskRank) { + Value src = ConstantI32Vector(/*shape=*/{1, 8}, /*values=*/{1}); + Type dst = VectorType::get(/*shape=*/{1, 8}, /*type=*/builder().getI32Type()); + Value mask = ConstantI1Vector(/*shape=*/{1, 8}, /*values=*/{1}); + + ASSERT_THAT( + VerifyOp(Create(dst, src, tpu::ReductionKind::kMax, mask)), + StatusIs(_, HasSubstr("Mask must be a rank 1 vector."))); +} + +TEST_F(TpuOpsVectorSubcoreVerificationTest, ScanVerificationInvalidMaskShape) { + Value src = ConstantI32Vector(/*shape=*/{1, 8}, /*values=*/{1}); + Type dst = VectorType::get(/*shape=*/{1, 8}, /*type=*/builder().getI32Type()); + Value mask = ConstantI1Vector(/*shape=*/{16}, /*values=*/{1}); + + ASSERT_THAT( + VerifyOp(Create(dst, src, tpu::ReductionKind::kMax, mask)), + StatusIs(_, HasSubstr("Mask and input mismatch. Expected mask of " + "length: 8, but got 16."))); +} + +TEST_F(TpuOpsVectorSubcoreVerificationTest, DmaElementTypeMismatch) { + auto dma = Create( + /*source=*/AllocaI32({1024, 256, 128}, MemorySpace::kHbm), + /*source_semaphore=*/AllocaSemaphore(), + /*target=*/ + Create(GetMemRefType({1024, 256, 128}, + builder().getI64Type(), + MemorySpace::kHbm)) + .getMemref(), + /*target_semaphore=*/AllocaSemaphore(), + /*device_id=*/nullptr, + /*core_id=*/nullptr); + + ASSERT_THAT( + VerifyOp(dma), + StatusIs(_, HasSubstr("DMA source and target element type mismatch"))); +} + +TEST_F(TpuOpsVectorSubcoreVerificationTest, DmaDynamicRankMismatch) { + auto dma = Create( + /*source=*/AllocaI32({ShapedType::kDynamic, 256, 128}, MemorySpace::kHbm), + /*source_semaphore=*/AllocaSemaphore(), + /*target=*/ + AllocaI32({ShapedType::kDynamic, ShapedType::kDynamic, 128}, + MemorySpace::kHbm), + /*target_semaphore=*/AllocaSemaphore(), + /*device_id=*/nullptr, + /*core_id=*/nullptr); + + ASSERT_THAT(VerifyOp(dma), + StatusIs(_, HasSubstr("DMA source and target shape mismatch."))); +} + +TEST_F(TpuOpsVectorSubcoreVerificationTest, DmaStrictOrderingSupported) { + auto dma = Create( + /*source=*/AllocaI32({1024, 256, 128}, MemorySpace::kHbm), + /*source_semaphore=*/nullptr, + /*target=*/AllocaI32({1024, 256, 128}, MemorySpace::kVmem), + /*target_semaphore=*/AllocaSemaphore(), + /*device_id=*/nullptr, + /*core_id=*/nullptr, + /*priority=*/0, + /*strict_ordering=*/true); + + ASSERT_OK(VerifyOp(dma)); +} + +TEST_F(TpuOpsVerificationTest, DmaStrictOrderingNotSupportedOnTc) { + auto func_op = + Create("tc_kernel", builder().getFunctionType({}, {})); + func_op->setAttr(TPUDialect::GetCoreTypeKey(), + CoreTypeAttr::get(builder().getContext(), CoreType::kTc)); + builder().setInsertionPointToStart(func_op.addEntryBlock()); + + auto dma = Create( + /*source=*/AllocaI32({1024, 256, 128}, MemorySpace::kHbm), + /*source_semaphore=*/nullptr, + /*target=*/AllocaI32({1024, 256, 128}, MemorySpace::kVmem), + /*target_semaphore=*/AllocaSemaphore(), + /*device_id=*/nullptr, + /*core_id=*/nullptr, + /*priority=*/0, + /*strict_ordering=*/true); + + ASSERT_THAT(VerifyOp(dma), + StatusIs(_, HasSubstr("Strict ordering is only supported on the " + "SC scalar and vector subcores"))); +} + +TEST_F(TpuOpsVectorSubcoreVerificationTest, + IndirectDmaHbmChunkGatherVerificationWorks) { + auto dma = Create( + /*source=*/AllocaI32({1024, 256, 128}, MemorySpace::kHbm), + /*target=*/AllocaI32({64, 256, 128}, MemorySpace::kVmem), + /*offsets=*/AllocaI32({64}, MemorySpace::kVmem), + /*semaphore=*/AllocaSemaphore(), + /*offset_filter=*/nullptr, + /*add=*/false); + + ASSERT_OK(VerifyOp(dma)); +} + +TEST_F(TpuOpsVectorSubcoreVerificationTest, + IndirectDmaVmemSharedChunkGatherVerificationWorks) { + auto dma = Create( + /*source=*/AllocaI32({1024, 256, 128}, MemorySpace::kVmemShared), + /*target=*/AllocaI32({64, 256, 128}, MemorySpace::kVmem), + /*offsets=*/AllocaI32({64}, MemorySpace::kVmem), + /*semaphore=*/AllocaSemaphore(), + /*offset_filter=*/nullptr, + /*add=*/false); + + ASSERT_OK(VerifyOp(dma)); +} + +TEST_F(TpuOpsVectorSubcoreVerificationTest, + IndirectDmaSublaneGatherVerificationWorks) { + auto dma = Create( + /*source=*/AllocaI32({1024, 256, 128}, MemorySpace::kHbm), + /*target=*/AllocaI32({64, 32, 128}, MemorySpace::kVmem), + /*offsets=*/AllocaI32({64, 32}, MemorySpace::kVmem), + /*semaphore=*/AllocaSemaphore(), + /*offset_filter=*/nullptr, + /*add=*/false); + + ASSERT_OK(VerifyOp(dma)); +} + +TEST_F(TpuOpsVectorSubcoreVerificationTest, + IndirectDmaElementGatherVerificationWorks) { + auto dma = Create( + /*source=*/AllocaI32({1024, 256, 128}, MemorySpace::kHbm), + /*target=*/AllocaI32({64, 32, 128}, MemorySpace::kVmem), + /*offsets=*/AllocaI32({64, 32, 128}, MemorySpace::kVmem), + /*semaphore=*/AllocaSemaphore(), + /*offset_filter=*/nullptr, + /*add=*/false); + + ASSERT_OK(VerifyOp(dma)); +} + +TEST_F(TpuOpsVectorSubcoreVerificationTest, + IndirectDmaFilteredGatherVerificationWorks) { + Value offset_filter = + Create(builder().getIntegerAttr(i32(), -1)) + .getResult(); + auto dma = Create( + /*source=*/AllocaI32({1024, 256, 128}, MemorySpace::kHbm), + /*target=*/AllocaI32({64, 32, 128}, MemorySpace::kVmem), + /*offsets=*/AllocaI32({64, 32}, MemorySpace::kVmem), + /*semaphore=*/AllocaSemaphore(), + /*offset_filter=*/offset_filter, + /*add=*/false); + + ASSERT_OK(VerifyOp(dma)); +} + +TEST_F(TpuOpsVectorSubcoreVerificationTest, + IndirectDmaVectorGatherVerificationWorks) { + Value vector_of_offsets = + ConstantI32Vector(/*shape=*/{8}, + /*values=*/{0, 1, 2, 3, 4, 5, 6, 7}); + auto dma = Create( + /*source=*/AllocaI32({1024, 32, 128}, MemorySpace::kHbm), + /*target=*/AllocaI32({8, 32, 128}, MemorySpace::kVmem), + /*offsets=*/vector_of_offsets, + /*semaphore=*/AllocaSemaphore(), + /*offset_filter=*/nullptr, + /*add=*/false); + + ASSERT_OK(VerifyOp(dma)); +} + +TEST_F(TpuOpsVectorSubcoreVerificationTest, + IndirectDmaHbmScatterVerificationWorks) { + auto dma = Create( + /*source=*/AllocaI32({64, 32, 128}, MemorySpace::kVmem), + /*target=*/AllocaI32({1024, 256, 128}, MemorySpace::kHbm), + /*offsets=*/AllocaI32({64, 32}, MemorySpace::kVmem), + /*semaphore=*/AllocaSemaphore(), + /*offset_filter=*/nullptr, + /*add=*/false); + + ASSERT_OK(VerifyOp(dma)); +} + +TEST_F(TpuOpsVectorSubcoreVerificationTest, + IndirectDmaVmemSharedScatterVerificationWorks) { + auto dma = Create( + /*source=*/AllocaI32({64, 32, 128}, MemorySpace::kVmem), + /*target=*/AllocaI32({1024, 256, 128}, MemorySpace::kVmemShared), + /*offsets=*/AllocaI32({64, 32}, MemorySpace::kVmem), + /*semaphore=*/AllocaSemaphore(), + /*offset_filter=*/nullptr, + /*add=*/false); + + ASSERT_OK(VerifyOp(dma)); +} + +TEST_F(TpuOpsVectorSubcoreVerificationTest, + IndirectDmaVectorScatterVerificationWorks) { + Value vector_of_offsets = ConstantI32Vector( + /*shape=*/{16}, + /*values=*/{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); + auto dma = Create( + /*source=*/AllocaI32({16, 32, 128}, MemorySpace::kVmem), + /*target=*/AllocaI32({1024, 32, 128}, MemorySpace::kHbm), + /*offsets=*/vector_of_offsets, + /*semaphore=*/AllocaSemaphore(), + /*offset_filter=*/nullptr, + /*add=*/false); + + ASSERT_OK(VerifyOp(dma)); +} + +TEST_F(TpuOpsVectorSubcoreVerificationTest, + IndirectDmaScatterAddVerificationWorks) { + auto dma = Create( + /*source=*/AllocaI32({64, 32, 128}, MemorySpace::kVmem), + /*target=*/AllocaI32({1024, 256, 128}, MemorySpace::kHbm), + /*offsets=*/AllocaI32({64, 32}, MemorySpace::kVmem), + /*semaphore=*/AllocaSemaphore(), + /*offset_filter=*/nullptr, + /*add=*/true); + + ASSERT_OK(VerifyOp(dma)); +} + +TEST_F(TpuOpsVerificationTest, IndirectDmaOnUnsupportedCore) { + std::vector unsupported_cores = {CoreType::kScScalarSubcore, + CoreType::kTc}; + for (CoreType unsupported_core : unsupported_cores) { + auto func_op = Create("scalar_kernel", + builder().getFunctionType({}, {})); + func_op->setAttr( + TPUDialect::GetCoreTypeKey(), + CoreTypeAttr::get(builder().getContext(), unsupported_core)); + builder().setInsertionPointToStart(func_op.addEntryBlock()); + auto dma = Create( + /*source=*/AllocaI32({1024, 256, 128}, MemorySpace::kHbm), + /*target=*/AllocaI32({64, 32, 128}, MemorySpace::kVmem), + /*offsets=*/AllocaI32({64, 32}, MemorySpace::kVmem), + /*semaphore=*/AllocaSemaphore(), + /*offset_filter=*/nullptr, + /*add=*/false); + + ASSERT_THAT( + VerifyOp(dma), + StatusIs(_, HasSubstr("Enqueue indirect DMA is supported only on " + "the SC vector subcore"))); + } +} + +TEST_F(TpuOpsVerificationTest, IndirectDmaOnUnsupportedTc) { + auto func_op = + Create("tc_kernel", builder().getFunctionType({}, {})); + func_op->setAttr(TPUDialect::GetCoreTypeKey(), + CoreTypeAttr::get(builder().getContext(), CoreType::kTc)); + builder().setInsertionPointToStart(func_op.addEntryBlock()); + auto dma = Create( + /*source=*/AllocaI32({1024, 256, 128}, MemorySpace::kHbm), + /*target=*/AllocaI32({64, 32, 128}, MemorySpace::kVmem), + /*offsets=*/AllocaI32({64, 32}, MemorySpace::kVmem), + /*semaphore=*/AllocaSemaphore(), + /*offset_filter=*/nullptr, + /*add=*/false); + + ASSERT_THAT(VerifyOp(dma), + StatusIs(_, HasSubstr("Enqueue indirect DMA is supported only on " + "the SC vector subcore"))); +} + +TEST_F(TpuOpsVectorSubcoreVerificationTest, + IndirectDmaGatherSourceAndTargetTypeMismatch) { + auto dma = Create( + /*source=*/AllocaI32({1024, 256, 128}, MemorySpace::kHbm), + /*target=*/ + Create(GetMemRefType({64, 32, 128}, + builder().getI64Type(), + MemorySpace::kVmem)) + .getMemref(), + /*offsets=*/AllocaI32({64, 32}, MemorySpace::kVmem), + /*semaphore=*/AllocaSemaphore(), + /*offset_filter=*/nullptr, + /*add=*/false); + + ASSERT_THAT( + VerifyOp(dma), + StatusIs(_, HasSubstr("Source and target element type mismatch"))); +} + +TEST_F(TpuOpsVectorSubcoreVerificationTest, IndirectDmaWithoutLocalMem) { + auto dma = Create( + /*source=*/AllocaI32({1024, 256, 128}, MemorySpace::kHbm), + /*target=*/AllocaI32({64, 32, 128}, MemorySpace::kHbm), + /*offsets=*/AllocaI32({64, 32}, MemorySpace::kVmem), + /*semaphore=*/AllocaSemaphore(), + /*offset_filter=*/nullptr, + /*add=*/false); + + ASSERT_THAT(VerifyOp(dma), + StatusIs(_, HasSubstr("The transfer must be between HBM and " + "VMEM, or between VMEM_SHARED and VMEM"))); +} + +TEST_F(TpuOpsVectorSubcoreVerificationTest, IndirectDmaOffsetsNotInVmem) { + auto dma = Create( + /*source=*/AllocaI32({1024, 256, 128}, MemorySpace::kHbm), + /*target=*/AllocaI32({64, 32, 128}, MemorySpace::kVmem), + /*offsets=*/AllocaI32({64, 32}, MemorySpace::kHbm), + /*semaphore=*/AllocaSemaphore(), + /*offset_filter=*/nullptr, + /*add=*/false); + + ASSERT_THAT(VerifyOp(dma), + StatusIs(_, HasSubstr("Offsets memref must be in VMEM"))); +} + +TEST_F(TpuOpsVectorSubcoreVerificationTest, IndirectDma1DSemaphore) { + auto dma = Create( + /*source=*/AllocaI32({1024, 256, 128}, MemorySpace::kHbm), + /*target=*/AllocaI32({64, 32, 128}, MemorySpace::kVmem), + /*offsets=*/AllocaI32({64, 32}, MemorySpace::kVmem), + /*semaphore=*/ + Create( + GetMemRefType({1}, SemaphoreType::get(builder().getContext()), + MemorySpace::kSemaphoreMem)) + .getResult(), + /*offset_filter=*/nullptr, + /*add=*/false); + + ASSERT_THAT(VerifyOp(dma), + StatusIs(_, HasSubstr("Semaphore must be rank 0"))); +} + +TEST_F(TpuOpsVectorSubcoreVerificationTest, + IndirectDmaGatherTargetShapeInvalid) { + auto dma = Create( + /*source=*/AllocaI32({1024, 256, 128}, MemorySpace::kHbm), + /*target=*/AllocaI32({512, 32, 128}, MemorySpace::kVmem), + /*offsets=*/AllocaI32({64, 32}, MemorySpace::kVmem), + /*semaphore=*/AllocaSemaphore(), + /*offset_filter=*/nullptr, + /*add=*/false); + + ASSERT_THAT( + VerifyOp(dma), + StatusIs(_, + HasSubstr( + "Offsets shape (64, 32) must match the majormost dimensions " + "of the target (gather result) shape (512, 32, 128)"))); +} + +TEST_F(TpuOpsVectorSubcoreVerificationTest, + IndirectDmaVectorGatherTargetShapeInvalid) { + Value vector_of_offsets = + ConstantI32Vector(/*shape=*/{8}, + /*values=*/{0, 1, 2, 3, 4, 5, 6, 7}); + auto dma = Create( + /*source=*/AllocaI32({1024, 32, 128}, MemorySpace::kHbm), + /*target=*/AllocaI32({512, 32, 128}, MemorySpace::kVmem), + /*offsets=*/vector_of_offsets, + /*semaphore=*/AllocaSemaphore(), + /*offset_filter=*/nullptr, + /*add=*/false); + + ASSERT_THAT( + VerifyOp(dma), + StatusIs( + _, HasSubstr("Offsets shape (8) must match the majormost dimensions " + "of the target (gather result) shape (512, 32, 128)"))); +} + +TEST_F(TpuOpsVectorSubcoreVerificationTest, + IndirectDmaGatherOperandShapeInvalid) { + auto dma = Create( + /*source=*/AllocaI32({1024, 256, 512}, MemorySpace::kHbm), + /*target=*/AllocaI32({64, 32, 128}, MemorySpace::kVmem), + /*offsets=*/AllocaI32({64, 32}, MemorySpace::kVmem), + /*semaphore=*/AllocaSemaphore(), + /*offset_filter=*/nullptr, + /*add=*/false); + + ASSERT_THAT( + VerifyOp(dma), + StatusIs(_, + HasSubstr( + "1 minormost dimension of the source (gather operand) shape " + "(1024, 256, 512) must match the minormost dimension of " + "the target (gather result) shape (64, 32, 128)"))); +} + +TEST_F(TpuOpsVectorSubcoreVerificationTest, + IndirectDmaScatterUpdatesShapeInvalid) { + auto dma = Create( + /*source=*/AllocaI32({512, 32, 128}, MemorySpace::kVmem), + /*target=*/AllocaI32({1024, 256, 128}, MemorySpace::kHbm), + /*offsets=*/AllocaI32({64, 32}, MemorySpace::kVmem), + /*semaphore=*/AllocaSemaphore(), + /*offset_filter=*/nullptr, + /*add=*/false); + + ASSERT_THAT( + VerifyOp(dma), + StatusIs(_, + HasSubstr( + "Offsets shape (64, 32) must match the majormost dimensions " + "of the source (scatter updates) shape (512, 32, 128)"))); +} + +TEST_F(TpuOpsVectorSubcoreVerificationTest, + IndirectDmaScatterOperandShapeInvalid) { + auto dma = Create( + /*source=*/AllocaI32({64, 32, 128}, MemorySpace::kVmem), + /*target=*/AllocaI32({1024, 256, 512}, MemorySpace::kHbm), + /*offsets=*/AllocaI32({64, 32}, MemorySpace::kVmem), + /*semaphore=*/AllocaSemaphore(), + /*offset_filter=*/nullptr, + /*add=*/false); + + ASSERT_THAT( + VerifyOp(dma), + StatusIs( + _, HasSubstr( + "1 minormost dimension of the source (scatter updates) shape " + "(64, 32, 128) must match the minormost dimension of the " + "target (scatter operand) shape (1024, 256, 512)"))); +} + +TEST_F(TpuOpsVectorSubcoreVerificationTest, + IndirectDmaVectorScatterOperandShapeInvalid) { + Value vector_of_offsets = + ConstantI32Vector(/*shape=*/{8}, + /*values=*/{0, 1, 2, 3, 4, 5, 6, 7}); + auto dma = Create( + /*source=*/AllocaI32({8, 32, 128}, MemorySpace::kVmem), + /*target=*/AllocaI32({1024, 96, 512}, MemorySpace::kHbm), + /*offsets=*/vector_of_offsets, + /*semaphore=*/AllocaSemaphore(), + /*offset_filter=*/nullptr, + /*add=*/false); + + ASSERT_THAT( + VerifyOp(dma), + StatusIs( + _, HasSubstr( + "2 minormost dimensions of the source (scatter updates) shape " + "(8, 32, 128) must match the minormost dimensions of the " + "target (scatter operand) shape (1024, 96, 512)"))); +} + +TEST_F(TpuOpsVerificationTest, IndirectDmaWaitOnUnsupportedCoreInvalid) { + static constexpr std::array unsupported_cores = { + CoreType::kScScalarSubcore, CoreType::kTc}; + for (CoreType unsupported_core : unsupported_cores) { + SCOPED_TRACE(testing::Message() + << "Testing unsupported core type: " + << stringifyCoreType(unsupported_core).str()); + auto func_op = Create("scalar_kernel", + builder().getFunctionType({}, {})); + func_op->setAttr( + TPUDialect::GetCoreTypeKey(), + CoreTypeAttr::get(builder().getContext(), unsupported_core)); + builder().setInsertionPointToStart(func_op.addEntryBlock()); + auto wait = Create( + /*semaphore=*/AllocaSemaphore(), + /*src=*/AllocaI32({1024, 256, 128}, MemorySpace::kHbm), + /*dst=*/AllocaI32({64, 256, 128}, MemorySpace::kVmem)); + + ASSERT_THAT(VerifyOp(wait), + StatusIs(_, HasSubstr("Wait indirect DMA is supported only on " + "the SC vector subcore"))); + } +} + +TEST_F(TpuOpsVectorSubcoreVerificationTest, + IndirectDmaWaitGatherVerificationWorks) { + auto wait = Create( + /*semaphore=*/AllocaSemaphore(), + /*src=*/AllocaI32({1024, 256, 128}, MemorySpace::kHbm), + /*dst=*/AllocaI32({64, 256, 128}, MemorySpace::kVmem)); + + ASSERT_OK(VerifyOp(wait)); +} + +TEST_F(TpuOpsVectorSubcoreVerificationTest, SortVerificationWorks) { + Value keys = ConstantI32Vector(/*shape=*/{8}, /*values=*/{1}); + Value values = ConstantI32Vector(/*shape=*/{8}, /*values=*/{2}); + Type mask_ty = VectorType::get({8}, builder().getI1Type()); + Type keys_ty = keys.getType(); + Type values_ty = values.getType(); + auto sort = + Create(/*result_types=*/TypeRange{mask_ty, keys_ty, values_ty}, + /*keys=*/keys, /*values=*/values, + /*mask=*/nullptr, + /*descending=*/builder().getBoolAttr(false)); + ASSERT_OK(VerifyOp(sort)); +} + +TEST_F(TpuOpsVectorSubcoreVerificationTest, SortF32KeysVerificationWorks) { + Value keys = ConstantF32Vector(/*shape=*/{8}, /*values=*/{1.0f}); + Value values = ConstantI32Vector(/*shape=*/{8}, /*values=*/{2}); + Type mask_ty = VectorType::get({8}, builder().getI1Type()); + Type keys_ty = keys.getType(); + Type values_ty = values.getType(); + auto sort = + Create(/*result_types=*/TypeRange{mask_ty, keys_ty, values_ty}, + /*keys=*/keys, /*values=*/values, + /*mask=*/nullptr, + /*descending=*/builder().getBoolAttr(false)); + ASSERT_OK(VerifyOp(sort)); +} + +TEST_F(TpuOpsVectorSubcoreVerificationTest, SortKeyValShapeMismatch) { + Value keys = ConstantI32Vector(/*shape=*/{8}, /*values=*/{1}); + Value values = ConstantI32Vector(/*shape=*/{16}, /*values=*/{2}); + Type mask_ty = VectorType::get({8}, builder().getI1Type()); + Type keys_ty = keys.getType(); + Type values_ty = values.getType(); + auto sort = + Create(/*result_types=*/TypeRange{mask_ty, keys_ty, values_ty}, + /*keys=*/keys, /*values=*/values, + /*mask=*/nullptr, + /*descending=*/builder().getBoolAttr(false)); + ASSERT_THAT(VerifyOp(sort), + StatusIs(_, HasSubstr("Key and value shapes must match"))); +} + +TEST_F(TpuOpsVectorSubcoreVerificationTest, SortResultTypeMismatch) { + Value keys = ConstantI32Vector(/*shape=*/{8}, /*values=*/{1}); + Value values = ConstantI32Vector(/*shape=*/{8}, /*values=*/{2}); + Type mask_ty = VectorType::get({8}, builder().getI1Type()); + Type keys_ty = keys.getType(); + Type f32_ty = VectorType::get({8}, builder().getF32Type()); + auto sort = + Create(/*result_types=*/TypeRange{mask_ty, keys_ty, f32_ty}, + /*keys=*/keys, /*values=*/values, + /*mask=*/nullptr, + /*descending=*/builder().getBoolAttr(false)); + ASSERT_THAT( + VerifyOp(sort), + StatusIs(_, HasSubstr("Value and sorted_value types must match"))); +} + +TEST_F(TpuOpsVectorSubcoreVerificationTest, + IndirectDmaWaitScatterVerificationWorks) { + auto wait = Create( + /*semaphore=*/AllocaSemaphore(), + /*source=*/AllocaI32({64, 32, 128}, MemorySpace::kVmem), + /*target=*/AllocaI32({1024, 256, 128}, MemorySpace::kVmemShared)); + + ASSERT_OK(VerifyOp(wait)); +} + +TEST_F(TpuOpsVectorSubcoreVerificationTest, + IndirectDmaWaitWithoutLocalMemInvalid) { + auto wait = Create( + /*semaphore=*/AllocaSemaphore(), + /*src=*/AllocaI32({1024, 256, 128}, MemorySpace::kHbm), + /*dst=*/AllocaI32({64, 256, 128}, MemorySpace::kHbm)); + + ASSERT_THAT(VerifyOp(wait), + StatusIs(_, HasSubstr("The transfer must be between HBM and " + "VMEM, or between VMEM_SHARED and VMEM"))); +} + +TEST_F(TpuOpsVectorSubcoreVerificationTest, + IndirectDmaWaitInvalidSemaphoreRank) { + auto wait = Create( + /*semaphore=*/Create( + GetMemRefType({8}, SemaphoreType::get(builder().getContext()), + MemorySpace::kSemaphoreMem)) + .getResult(), + /*source=*/AllocaI32({64, 32, 128}, MemorySpace::kVmem), + /*target=*/AllocaI32({1024, 256, 128}, MemorySpace::kVmemShared)); + + ASSERT_THAT( + VerifyOp(wait), + StatusIs(_, HasSubstr("Indirect DMA wait semaphore must be rank 0"))); +} +} // namespace +} // namespace mlir::tpu diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc deleted file mode 100644 index 1997ffe34535..000000000000 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ /dev/null @@ -1,7041 +0,0 @@ -#include "jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVectorExtras.h" -#include "llvm/ADT/StringMap.h" -#include "llvm/ADT/iterator_range.h" -#include "llvm/Support/Compiler.h" -#include "llvm/Support/MathExtras.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Traits.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Block.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/Region.h" -#include "mlir/IR/TypeRange.h" -#include "mlir/IR/Types.h" -#include "mlir/IR/Value.h" -#include "mlir/IR/ValueRange.h" -#include "mlir/IR/Visitors.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" -#include "absl/algorithm/container.h" -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "absl/types/span.h" -#include "llvm/include/llvm/ADT/APInt.h" -#include "llvm/include/llvm/Support/LogicalResult.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/include/mlir/Dialect/Math/IR/Math.h" -#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/Builders.h" -#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/include/mlir/IR/OperationSupport.h" -#include "jaxlib/mosaic/dialect/tpu/array_util.h" -#include "jaxlib/mosaic/dialect/tpu/layout.h" -#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" -#include "jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h" -#include "jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h" -#include "jaxlib/mosaic/dialect/tpu/util.h" -#include "jaxlib/mosaic/dialect/tpu/vreg_util.h" -#include "xla/array.h" -#include "xla/layout.h" -#include "xla/tsl/platform/errors.h" -#include "xla/util.h" - -// TODO(tlongeri): Prefer returning failure over CHECKs. In particular, be more -// consistent about this for layout null checks in rules. - -namespace mlir::tpu { -// TODO(tlongeri): Maybe just roll our own multi-dimensional array instead of -// using XLA's? There's too much glue for going from/to ArrayRef. - -#define GEN_PASS_DECL_APPLYVECTORLAYOUTPASS -#define GEN_PASS_DEF_APPLYVECTORLAYOUTPASS -#include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc" - - -// The minimum bound required to rotate with scratch space. The bound refers to -// the number of VREGs on rotation dim. This number was concluded from some cost -// analysis for comparing different dynamic rotation implementations. If -// actual bound is greater than this, dynamic rotation with internal scratch -// space is more efficient. -// TODO(jevinjiang): need to update it based on the generation. -static constexpr int kMinBoundToRotateWithScratch = 27; - -using RewriteContext = ApplyVectorLayoutContext; - -LogicalResult applyLayoutBlock(RewriteContext &ctx, Block &block); -namespace { - -void moveAllRegions(Operation &src, Operation &dst) { - for (auto [src_region, dst_region] : - llvm::zip_equal(src.getRegions(), dst.getRegions())) { - dst_region.takeBody(src_region); - } -} - -// Get the address of pre-allocated internal scratch space with requested shape. -// -// Arguments: -// shape: The shape of the requested scratch space. -// elem_ty: The type of the elements in the requested scratch space. -// -// Returns: -// A memref of the requested shape and type. -FailureOr> getInternalScratch( - RewriteContext &ctx, OpBuilder &builder, Location loc, - ArrayRef shape, Type elem_ty, int64_t sublane_tiling = 0) { - if (shape.empty()) { - return failure(); - } - if (shape.back() % ctx.target_shape[1] != 0) { - return emitError(loc, "Unaligned scratch shape on minormost dimension"); - } - int packing = 32 / elem_ty.getIntOrFloatBitWidth(); - int sublane_count = llvm::divideCeil( - std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()) / - ctx.target_shape[1], - packing); - - if (sublane_count > ctx.max_sublanes_in_scratch) { - return emitError( - loc, - "scratch is too small. Try to increase `internal_scratch_in_bytes`."); - } - // We can omit tpu_tiling_flags here because, for internal scratch, the - // tiling does not matter (its shape is (N, 128)). - FAILUREOR_ASSIGN_OR_RETURN( - MemRefType scratch_ref_ty, - inferMemref(MemRefType::get(shape, elem_ty), ctx.hardware_generation, - ctx.target_shape, /*tpu_tiling_flags=*/{}, sublane_tiling)); - return builder.create(loc, scratch_ref_ty) - .getResult(); -} - -// Models Numpy's np.concatenate -xla::Array concatenate(const ArrayRef> arrays, - const int64_t axis) { - CHECK(!arrays.empty()); - SmallVector dims(toArrayRef(arrays[0].dimensions())); - CHECK(0 <= axis && axis < dims.size()); - for (size_t i = 1; i < arrays.size(); ++i) { - CHECK_EQ(arrays[i].num_dimensions(), arrays[0].num_dimensions()); - for (size_t j = 0; j < arrays[i].num_dimensions(); ++j) { - if (j != axis) { - CHECK_EQ(arrays[i].dim(j), arrays[0].dim(j)); - } - } - dims[axis] += arrays[i].dim(axis); - } - xla::Array res(dims); - int64_t offset = 0; - for (xla::Array const& arr : arrays) { - arr.Each([&](const absl::Span idx, const Value v) { - SmallVector res_idx(toArrayRef(idx)); - res_idx[axis] += offset; - res(res_idx) = v; - }); - offset += arr.dim(axis); - } - return res; -} - -SmallVector> split(const xla::Array &vregs, int axis) { - CHECK(axis >= 0 && axis < vregs.num_dimensions()); - SmallVector> chunks; - chunks.reserve(vregs.dim(axis)); - SmallVector starts(vregs.num_dimensions(), 0); - SmallVector limits(vregs.dimensions().begin(), - vregs.dimensions().end()); - for (int64_t i = 0; i < vregs.dim(axis); ++i) { - starts[axis] = i; - limits[axis] = i + 1; - chunks.push_back(vregs.Slice(starts, limits)); - } - return chunks; -}; - -bool incrementIndex(const MutableArrayRef idx, - const absl::Span limits) { - const int64_t nd = idx.size(); - CHECK_EQ(nd, limits.size()); - for (int64_t i = nd - 1; i >= 0; --i) { - ++idx[i]; - if (idx[i] < limits[i]) { - return true; - } - idx[i] = 0; - } - return false; -} - -FailureOr getIntConst(Value v, bool silent = false) { - if (auto constant_op = v.getDefiningOp()) { - if (auto integer_attr = dyn_cast(constant_op.getValue())) { - return integer_attr.getValue().getSExtValue(); - } - } - if (silent) { - return failure(); - } - return emitError(v.getLoc(), "Expected an integer constant"); -} - -FailureOr> getIntConstsFromOperandRange( - ValueRange vals, bool silent = false) { - SmallVector res(vals.size()); - for (int i = 0; i < vals.size(); ++i) { - FAILUREOR_ASSIGN_OR_RETURN(res[i], getIntConst(vals[i], silent)); - } - return res; -} - -Value broadcastSublane(OpBuilder &builder, Value vreg, int sublane_idx, - const std::array target_shape) { - return builder.create( - vreg.getLoc(), vreg.getType(), vreg, - SmallVector(target_shape[0], sublane_idx), - /*dimension=*/0); -} - -FailureOr>> sliceRef( - ImplicitLocOpBuilder &builder, TypedValue base_ref, - ArrayRef slice_shape, ValueRange indices, - ArrayRef tiling) { - IntegerType i32 = builder.getI32Type(); - MemRefType ref_ty = base_ref.getType(); - - // MemRefSliceOp only allows tile-aligned slices. We pad the shape up - // accordingly with the padding. We don't include the static tiled indices - // in the slice when they can be arbitrary. But we do include dynamic tiled - // indices under the condition that they are divisible by the tile size. - SmallVector pad_slice_shape(slice_shape); - TPU_ASSERT_LE_LOC(builder.getLoc(), tiling.size(), slice_shape.size()); - for (int i = 1; i <= tiling.size(); ++i) { - auto &dim = *(pad_slice_shape.end() - i); - dim = xla::RoundUpTo(dim, *(tiling.end() - i)); - } - - SmallVector slice_base_indices; - slice_base_indices.reserve(ref_ty.getRank()); - for (auto idx : indices.drop_back(tiling.size())) { - slice_base_indices.push_back(builder.create(i32, idx)); - } - - Value c0 = nullptr; - SmallVector indices_within_slice(indices.size() - tiling.size(), 0); - for (auto tiled_idx : indices.take_back(tiling.size())) { - if (auto cst = getIntConst(tiled_idx, /*silent=*/true); succeeded(cst)) { - indices_within_slice.push_back(*cst); - if (!c0) { - c0 = builder.create(i32, - builder.getI32IntegerAttr(0)); - } - slice_base_indices.push_back(c0); - } else { - indices_within_slice.push_back(0); - // TODO: Check divisibility! - slice_base_indices.push_back( - builder.create(i32, tiled_idx)); - } - } - - // TODO(apaszke): Allow tile-aligned dynamic slicing on tiled dimensions. - Value sliced_ref = builder.create( - MemRefType::get(pad_slice_shape, ref_ty.getElementType(), - ref_ty.getLayout(), ref_ty.getMemorySpace()), - base_ref, slice_base_indices, /*dynamic_sizes=*/ValueRange()); - - return std::make_pair(sliced_ref, indices_within_slice); -} - -// Returns the first-level tiling of a (packed and tiled) memref value. -FailureOr> getMemRefTiling( - TypedValue value, const std::array target_shape) { - if (auto erase_layout_op = - dyn_cast_if_present(value.getDefiningOp())) { - value = erase_layout_op.getOperand(); - } - const MemRefType memref_ty = value.getType(); - const auto mem_layout = dyn_cast(memref_ty.getLayout()); - if (mem_layout == nullptr) { - return emitError(value.getLoc(), "Expected a tiled memref"); - } - FAILUREOR_ASSIGN_OR_RETURN(int8_t bitwidth, - getTypeBitwidth(memref_ty.getElementType())); - const int packing = 32 / bitwidth; - const ArrayRef tiles = mem_layout.getTiles(); - const xla::Tile &first_tile = tiles.front(); - if (first_tile.dimensions().size() == 1) { - const int64_t tile_size = first_tile.dimension(0); - if (tile_size % (target_shape[1] * packing) != 0) { - return emitError(value.getLoc(), "Not implemented"); - } - if (bitwidth == 32) { - if (tiles.size() > 1) { - return emitError(value.getLoc(), "Not implemented"); - } - } else if (bitwidth < 32) { - if (tiles.drop_front() != - ArrayRef{xla::Tile({target_shape[1]}), - xla::Tile({packing, 1})}) { - return emitError(value.getLoc(), "Not implemented"); - } - } - return std::array{1, tile_size}; - } - if (first_tile.dimensions().size() == 2) { - if (bitwidth == 32) { - if (tiles.size() > 1) { - return emitError(value.getLoc(), "Not implemented"); - } - return std::array{first_tile.dimension(0), - first_tile.dimension(1)}; - } - if (bitwidth < 32) { - if (tiles.size() != 2 || tiles[1] != xla::Tile({packing, 1})) { - return emitError(value.getLoc(), "Not implemented"); - } - return std::array{first_tile.dimension(0), - first_tile.dimension(1)}; - } - } - return emitError(value.getLoc(), "Not implemented"); -} - -// Hoist a vector constant as an additional argument of the function. -FailureOr appendConstant(RewriteContext &ctx, func::FuncOp func, - DenseElementsAttr value) { - MLIRContext *mlir_ctx = func.getContext(); - Block &entry_block = func.getBody().front(); - auto value_ty = cast(value.getType()); - if (value_ty.getElementType().getIntOrFloatBitWidth() != 32) { - return func.emitOpError("Not implemented: Only 32-bit constants supported"); - } - if (func->getAttr("scratch_operands")) { - return func.emitOpError("Not implemented: function has scratch_operands"); - } - // We can omit tpu_tiling_flags here since we invoke inferMemref only for - // constant operands which are kernel parameters that will have their layouts - // overridden before the pass pipeline runs anyway. - FAILUREOR_ASSIGN_OR_RETURN( - MemRefType arg_type, - inferMemref( - MemRefType::get(value_ty.getShape(), value_ty.getElementType()), - ctx.hardware_generation, ctx.target_shape, /*tpu_tiling_flags=*/{}, - /*is_kernel_argument=*/true)); - const BlockArgument argument = entry_block.insertArgument( - entry_block.getNumArguments() - 1, arg_type, UnknownLoc::get(mlir_ctx)); - const FunctionType func_ty = func.getFunctionType(); - // Adjust the function type. - SmallVector new_arg_tys(func_ty.getInputs()); - new_arg_tys.insert(new_arg_tys.begin() + (new_arg_tys.size() - 1), arg_type); - const auto new_func_ty = - FunctionType::get(mlir_ctx, new_arg_tys, func_ty.getResults()); - func.setFunctionType(new_func_ty); - // Adjust the constants attribute. - if (auto prev_cst = func->getAttrOfType("vector_constants")) { - SmallVector vector_constants(prev_cst.getValue()); - vector_constants.push_back(value); - func->setAttr("vector_constants", - ArrayAttr::get(func.getContext(), vector_constants)); - } else { - func->setAttr("vector_constants", ArrayAttr::get(func.getContext(), value)); - } - // Adjust window params for the extra operand. - if (auto window_params = func->getAttrOfType("window_params")) { - const auto iteration_bounds = - func->getAttrOfType("iteration_bounds"); - TPU_ASSERT_LOC(UnknownLoc::get(mlir_ctx), iteration_bounds); - const int64_t iteration_rank = iteration_bounds.getSize(); - const SmallVector zeros( - iteration_rank, getAffineConstantExpr(0, func.getContext())); - const auto transform_indices = - AffineMap::get(iteration_rank, 0, zeros, func.getContext()); - const auto new_param = DictionaryAttr::get( - func.getContext(), - NamedAttribute(StringAttr::get(func.getContext(), "transform_indices"), - AffineMapAttr::get(transform_indices))); - SmallVector window_params_values(window_params.getValue()); - window_params_values.insert(window_params_values.end() - 1, new_param); - func->setAttr("window_params", - ArrayAttr::get(func.getContext(), window_params_values)); - } - return argument; -} - -// Masks all values outside of bounds. -// -// Arguments: -// value: A rank 2 MLIR vector to be masked. -// bounds: A TargetTuple of slices specifying a rectangular subregion of value -// that should be preserved during masking. -// neutral: A scalar attribute specifying the value that will be inserted -// for all values outside of specified bounds. -// -// Returns: -// An MLIR value of the same type as the value argument, with all entries -// outside of bounds replaced by neutral. -FailureOr maskOOB(RewriteContext &ctx, ImplicitLocOpBuilder &builder, - TypedValue value, - const VRegDataBounds &bounds, - const Attribute neutral) { - auto native_vreg_ty = - getNativeVregType(value.getType().getElementType(), ctx.target_shape); - TPU_ASSERT_LOC(value.getLoc(), llvm::equal(value.getType().getShape(), - native_vreg_ty.getShape())); - if (bounds.isComplete(ctx.target_shape)) { - return value; - } - FAILUREOR_ASSIGN_OR_RETURN( - TypedValue mask, - bounds.getVectorMask(builder, value.getLoc(), ctx.hardware_generation, - ctx.target_shape)); - if (cast(mask.getType().getElementType()).getWidth() != 1) { - return emitError(value.getLoc(), - "Not implemented: Unsupported mask bitwidth"); - } - if (mask.getType().getShape() != native_vreg_ty.getShape()) { - mask = builder.create( - value.getLoc(), - VectorType::get(native_vreg_ty.getShape(), builder.getI1Type()), mask); - } - Value neutral_vec = getFullVector(builder, native_vreg_ty, neutral); - return builder - .create(value.getLoc(), mask, value, neutral_vec) - .getResult(); -} - -// Insert a minor dimension to the implicit shape. The original minor dimension -// becomes the new second minor dimension, laid out across sublanes. -// -// The returned vreg array uses the original tiling and the offsets specified in -// new_offsets to hold the value with the new implicit shape. -// -// Args: -// vregs: The vreg array with *implicit* array shape. -// ishape: The implicit shape of the represented value. -// layout: The layout used for the represented value. The implicit -// dimension is ignored, since this function operates directly at -// the level of the implicit shape. -// new_offsets: The offsets to use for the layout of the returned vreg array. -FailureOr> insertImplicitMinorDimension( - RewriteContext &ctx, OpBuilder &builder, const Location loc, - const xla::Array &vregs, const ArrayRef ishape, - const VectorLayout &layout, const LayoutOffsets new_offsets) { - if (layout.bitwidth() != 32 || !layout.hasNativeTiling(ctx.target_shape)) { - return emitError(loc, "Not implemented: Unsupported bitwidth or tiling"); - } - if (layout.offsets()[1].has_value()) { - if (!new_offsets[0]) { - // TODO(tlongeri): This can only be valid if the dim size is 1. - return emitError(loc, "Not implemented: Replication mismatch"); - } - if (*new_offsets[0] != *layout.offsets()[1] % ctx.target_shape[0] && - *layout.offsets()[1] + *(ishape.end() - 1) > ctx.target_shape[1]) { - // This requires blending data from different vregs. - return emitError(loc, - "Not implemented: Misaligned offsets and shape does not " - "fit in one vreg"); - } - } - // new_layout is only to get the new vreg array shape, the implicit dim is - // irrelevant (since we already have the implicit shape): - const VectorLayout new_layout(layout.bitwidth(), new_offsets, layout.tiling(), - VectorLayout::ImplicitDim::kNone); - SmallVector new_ishape(ishape); - new_ishape.push_back(1); - xla::Array new_vregs(new_layout.tileArrayShape( - /*src_is_implicit=*/true, /*res_is_implicit=*/true, std::move(new_ishape), - ctx.target_shape)); - // Preallocate an indices vector to avoid repeated allocations: - SmallVector idxs; - new_vregs.Each([&](const absl::Span dst_idx, - Value *const dst_vreg) { - // Indices of the new vreg in the new vreg array: - const int64_t new_2nd_minor_idx = *(dst_idx.end() - 2); - const int64_t new_3rd_minor_idx = *(dst_idx.end() - 3); - idxs.assign(dst_idx.begin(), dst_idx.end()); - if (!layout.offsets()[0].has_value() && new_3rd_minor_idx != 0) { - // All vregs along that dimension are the same - *(idxs.end() - 3) = 0; - *dst_vreg = new_vregs(idxs); - } else if (!layout.offsets()[1].has_value() && new_2nd_minor_idx != 0) { - // All vregs along that dimension are the same - *(idxs.end() - 2) = 0; - *dst_vreg = new_vregs(idxs); - } else { - // dst_vreg will hold slice [row_idx, col_idx:(col_idx + target_shape[0])] - // of the after-offsets source shape - const int64_t row_idx = - layout.offsets()[0] ? new_3rd_minor_idx + *layout.offsets()[0] : 0; - const int64_t col_idx = layout.offsets()[1] - ? new_2nd_minor_idx * ctx.target_shape[0] + - *layout.offsets()[1] - *new_offsets[0] - : 0; - - idxs.pop_back(); - *(idxs.end() - 2) = row_idx / ctx.target_shape[0]; - *(idxs.end() - 1) = col_idx / ctx.target_shape[1]; - Value src_vreg = vregs(idxs); - // TODO(tlongeri): We can sometimes skip operations when dst_vreg will - // hold a single non-padding element (first or last) and we don't need - // replication in the output. - if (layout.offsets()[0].has_value()) { - // [ . . . . . . . . ] [ . . . . a b c d ] - // [ . . . . a b c d ] => [ . . . . a b c d ] - // [ . . . . . . . . ] [ . . . . a b c d ] - // [ . . . . . . . . ] [ . . . . a b c d ] - src_vreg = broadcastSublane( - builder, src_vreg, - /*sublane_idx=*/row_idx % ctx.target_shape[0], ctx.target_shape); - } - if (layout.offsets()[1].has_value()) { - // [ . . . . a b c d ] [ a a a a a a a a ] - // [ . . . . a b c d ] => [ b b b b b b b b ] - // [ . . . . a b c d ] [ c c c c c c c c ] - // [ . . . . a b c d ] [ d d d d d d d d ] - src_vreg = builder.create( - loc, src_vreg.getType(), src_vreg, - /*lane=*/col_idx % ctx.target_shape[1]); - } - *dst_vreg = src_vreg; - } - }); - return new_vregs; -} - -LogicalResult elementwise_op_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - TPU_ASSERT_OP(OpTrait::hasElementwiseMappableTraits(&op)); - if (op.getNumResults() != 1) { - return op.emitError("Not implemented: Only ops with one result supported"); - } - TPU_ASSERT_EQ_OP(layouts_in.size(), op.getNumOperands()); - TPU_ASSERT_GT_OP(layouts_in.size(), 0); - TPU_ASSERT_EQ_OP(layouts_out.size(), 1); - OpBuilder builder(&op); - if (!(layouts_out.front().has_value() && - llvm::all_of(layouts_in, - [&](const Layout &l) { return l.has_value(); }))) { - return op.emitOpError( - "Not implemented: Null layout / non-vector operand in elementwise " - "operation"); - } - const auto out_ty = cast(op.getResult(0).getType()); - const VectorLayout &layout_out = *layouts_out.front(); - if (!llvm::all_of(layouts_in, [&](const Layout &l) { - return l->generalizes(layout_out, out_ty.getShape(), ctx.target_shape); - })) { - return op.emitOpError( - "Not implemented: Incompatible layouts in elementwise operation"); - } - const unsigned num_operands = op.getNumOperands(); - SmallVector> in_vreg_arrays; - in_vreg_arrays.reserve(num_operands); - for (unsigned i = 0; i < num_operands; ++i) { - FAILUREOR_ASSIGN_OR_RETURN( - xla::Array tile_array, - disassemble(builder, *layouts_in[i], - cast>(op.getOperand(i)), - ctx.target_shape)); - in_vreg_arrays.emplace_back(std::move(tile_array)); - } - - const VectorType out_vreg_ty = getNativeVregOrVmaskType( - out_ty.getElementType(), layout_out.bitwidth(), ctx.target_shape); - - NamedAttrList attributes(op.getAttrDictionary()); - attributes.erase("in_layout"); - attributes.erase("out_layout"); - - // Note that we have to broadcast to handle replicate dimensions. - SmallVector broadcasted_shape( - toArrayRef(in_vreg_arrays[0].dimensions())); - for (size_t i = 1; i < num_operands; ++i) { - SmallVector new_broadcasted_shape; - TPU_ASSERT_OP(OpTrait::util::getBroadcastedShape( - broadcasted_shape, toArrayRef(in_vreg_arrays[i].dimensions()), - new_broadcasted_shape)); - broadcasted_shape = std::move(new_broadcasted_shape); - } - TPU_ASSERT_OP(broadcasted_shape == - layout_out.tileArrayShape(out_ty.getShape(), ctx.target_shape)); - - // TODO(tlongeri): Can we avoid initializing the array before filling values? - xla::Array out_vreg_array(broadcasted_shape); - out_vreg_array.Each([&](absl::Span idx, Value *out_vreg) { - SmallVector operands(num_operands); - - for (unsigned i = 0; i < num_operands; ++i) { - // Handle indices for broadcasted dimensions - SmallVector operand_idx(toArrayRef(idx)); - for (unsigned j = 0; j < idx.size(); ++j) { - if (in_vreg_arrays[i].dim(j) == 1) { - operand_idx[j] = 0; - } - } - operands[i] = in_vreg_arrays[i](operand_idx); - } - Operation *vreg_op = - builder.create(op.getLoc(), op.getName().getIdentifier(), operands, - out_vreg_ty, attributes.getAttrs()); - CHECK(vreg_op); - CHECK_EQ(vreg_op->getNumResults(), 1); - *out_vreg = vreg_op->getResult(0); - }); - op.replaceAllUsesWith(assemble(builder, out_ty, layout_out, - std::move(out_vreg_array), ctx.target_shape)); - op.erase(); - return success(); -} - -using rule_type = std::function, ArrayRef)>; - -template -FailureOr> ext_op_rule_impl(RewriteContext &ctx, - OpBuilder &builder, OpTy op, - const VectorLayout &layout_in, - const VectorLayout &layout_out) { - const auto result_ty = cast(op.getResult().getType()); - auto source = cast>(op.getIn()); - auto output_vregs_shape = - layout_out.tileArrayImplicitShape(result_ty.getShape(), ctx.target_shape); - FAILUREOR_ASSIGN_OR_RETURN( - xla::Array input_vregs, - disassemble(builder, layout_in, source, ctx.target_shape, - /*use_implicit_shape=*/true)); - xla::Array output_vregs(output_vregs_shape); - const VectorType res_vreg_ty = - getNativeVregType(result_ty.getElementType(), ctx.target_shape); - if (layout_in.implicit_dim() != layout_out.implicit_dim()) { - return op.emitOpError( - "Not implemented: Change of implicit dim during the cast"); - } - if (layout_in.offsets() != layout_out.offsets()) { - return op.emitOpError("Not implemented: Change of offsets during the cast"); - } - const int packing = layout_out.bitwidth() / layout_in.bitwidth(); - if (layout_in.hasNativeTiling(ctx.target_shape) && - layout_out.hasNativeTiling(ctx.target_shape)) { - output_vregs.Each([&](absl::Span idxs, Value *v) { - SmallVector input_vreg_idxs(toArrayRef(idxs)); - int64_t vreg_part = *(input_vreg_idxs.end() - 2) % packing; - *(input_vreg_idxs.end() - 2) /= packing; - *v = builder.create( - op.getLoc(), res_vreg_ty, input_vregs(input_vreg_idxs), vreg_part, - tpu::PackFormat::kCompressed); - }); - } else { - if (layout_in.tiling() != layout_out.tiling()) { - return op.emitOpError("Not implemented: Changing tiling during the cast"); - } - auto tiling = layout_in.tiling(); - if (ctx.target_shape[0] % tiling[0] != 0 || - ctx.target_shape[1] != tiling[1]) { - return op.emitOpError("Not implemented: tiling not supported"); - } - output_vregs.Each([&](absl::Span idxs, Value *v) { - SmallVector input_vreg_idxs(toArrayRef(idxs)); - input_vreg_idxs.back() /= packing; - const int64_t vreg_part = idxs.back() % packing; - *v = builder.create( - op.getLoc(), res_vreg_ty, input_vregs(input_vreg_idxs), vreg_part, - tpu::PackFormat::kCompressed); - }); - } - return output_vregs; -} - -LogicalResult arith_extf_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - TPU_ASSERT_EQ_OP(layouts_in.size(), 1); - TPU_ASSERT_OP(layouts_in.front().has_value()); - TPU_ASSERT_OP(layouts_out.front().has_value()); - auto extf_op = cast(op); - if (layouts_out.front()->bitwidth() != 32) { - return op.emitOpError("Not implemented: Only support conversion to 32-bit"); - } - ImplicitLocOpBuilder builder(op.getLoc(), &op); - FAILUREOR_ASSIGN_OR_RETURN( - xla::Array output_vregs, - ext_op_rule_impl(ctx, builder, extf_op, *layouts_in.front(), - *layouts_out.front())); - const auto result_ty = cast(extf_op.getResult().getType()); - extf_op.replaceAllUsesWith(assemble(builder, result_ty, *layouts_out.front(), - std::move(output_vregs), ctx.target_shape, - /*use_implicit_shape=*/true) - .getResult()); - extf_op.erase(); - return success(); -} - -LogicalResult arith_extsi_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - TPU_ASSERT_EQ_OP(layouts_in.size(), 1); - TPU_ASSERT_OP(layouts_in.front().has_value()); - TPU_ASSERT_EQ_OP(layouts_out.size(), 1); - TPU_ASSERT_OP(layouts_out.front().has_value()); - auto extsi_op = cast(op); - ImplicitLocOpBuilder builder(op.getLoc(), &op); - FAILUREOR_ASSIGN_OR_RETURN( - xla::Array output_vregs, - ext_op_rule_impl(ctx, builder, extsi_op, *layouts_in.front(), - *layouts_out.front())); - const auto result_ty = cast(extsi_op.getResult().getType()); - extsi_op.replaceAllUsesWith(assemble(builder, result_ty, *layouts_out.front(), - std::move(output_vregs), - ctx.target_shape, - /*use_implicit_shape=*/true) - .getResult()); - extsi_op.erase(); - return success(); -} - -LogicalResult arith_extui_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - TPU_ASSERT_EQ_OP(layouts_in.size(), 1); - TPU_ASSERT_OP(layouts_in.front().has_value()); - TPU_ASSERT_EQ_OP(layouts_out.size(), 1); - TPU_ASSERT_OP(layouts_out.front().has_value()); - auto extui_op = cast(op); - const auto in_ty = cast(extui_op.getIn().getType()); - const auto out_ty = cast(extui_op.getType()); - const unsigned in_bitwidth = in_ty.getElementTypeBitWidth(); - if (in_bitwidth == 1) { - return elementwise_op_rule(ctx, op, layouts_in, layouts_out); - } - ImplicitLocOpBuilder builder(op.getLoc(), &op); - FAILUREOR_ASSIGN_OR_RETURN( - xla::Array output_vregs, - ext_op_rule_impl(ctx, builder, extui_op, *layouts_in.front(), - *layouts_out.front())); - unsigned out_bitwidth = out_ty.getElementTypeBitWidth(); - // Generate a mask to mask out the sign extension. e.g., for u8 -> u16, - // the mask is 0x00ff00ff. - unsigned mask = (1 << in_bitwidth) - 1; - while (out_bitwidth < 32) { - mask = (mask << out_bitwidth) | mask; - out_bitwidth *= 2; - } - const VectorType i32_vreg_ty = - getNativeVregType(builder.getI32Type(), ctx.target_shape); - auto mask_const = builder.create( - op.getLoc(), i32_vreg_ty, DenseIntElementsAttr::get(i32_vreg_ty, {mask})); - const VectorType out_vreg_ty = - getNativeVregType(out_ty.getElementType(), ctx.target_shape); - output_vregs.Each([&](absl::Span _, Value *v) { - Value unpacked = - builder.create(op.getLoc(), i32_vreg_ty, *v); - unpacked = builder.create(op.getLoc(), i32_vreg_ty, unpacked, - mask_const); - *v = builder.create(op.getLoc(), out_vreg_ty, unpacked); - }); - extui_op.replaceAllUsesWith(assemble(builder, out_ty, *layouts_out.front(), - std::move(output_vregs), - ctx.target_shape, - /*use_implicit_shape=*/true) - .getResult()); - extui_op.erase(); - return success(); -} - -template -LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op, - const VectorLayout &layout_in, - const VectorLayout &layout_out) { - ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation()); - auto source = cast>(op.getOperand()); - auto result_ty = cast(op.getResult().getType()); - auto output_vregs_shape = - layout_out.tileArrayImplicitShape(result_ty.getShape(), ctx.target_shape); - FAILUREOR_ASSIGN_OR_RETURN( - xla::Array input_vregs, - disassemble(builder, layout_in, source, ctx.target_shape, - /*use_implicit_shape=*/true)); - xla::Array output_vregs(output_vregs_shape); - const LayoutOffsets input_offsets = layout_in.offsets(); - const LayoutOffsets output_offsets = layout_out.offsets(); - const std::array input_vreg_slice = - layout_in.vregSlice(ctx.target_shape); - const std::array output_vreg_slice = - layout_out.vregSlice(ctx.target_shape); - const int input_sublanes_per_tile = - layout_in.sublanesPerTile(ctx.target_shape); - - if (layout_in.implicit_dim() != layout_out.implicit_dim()) { - return op.emitOpError( - "Not implemented: Truncation changes implicit dimension"); - } - for (const auto &[input_offset, output_offset, input_slice_size] : - llvm::zip_equal(input_offsets, output_offsets, input_vreg_slice)) { - if (!input_offset.has_value() && !output_offset.has_value()) { - // Replicated to replicated is okay - } else if (!input_offset.has_value() && output_offset.has_value()) { - // Replicated to non-replicated could be handled, but we don't leverage - // replication, so we don't expect a replicated input offset to be - // assigned. The materialization of replicated vregs in the vreg - // array should be handled by relayout. - return op.emitOpError( - "Not implemented: Replicated to non-replicated offset"); - } else if (input_offset.has_value() && !output_offset.has_value()) { - return op.emitOpError( - "Not implemented: Truncation introduces replication"); - } else { - DCHECK(input_offset.has_value() && output_offset.has_value()); - if (*input_offset != *output_offset % input_slice_size) { - return op.emitOpError("Not implemented: Misaligned offsets"); - } - } - } - if (output_vreg_slice[0] % input_vreg_slice[0] != 0 || - output_vreg_slice[1] % input_vreg_slice[1] != 0) { - // The output vreg slice should be a union of whole input vreg slices - return op.emitOpError("Not implemented: Unsupported tiling change"); - } - // How many rows and columns of input vregs we are packing into one output - // vreg: - const int64_t vreg_rows = output_vreg_slice[0] / input_vreg_slice[0]; - const int64_t vreg_cols = output_vreg_slice[1] / input_vreg_slice[1]; - - // Currently, we always pack across rows first, and then across columns. - // Note: Even though we combine it into a single tpu.pack_subelements op, the - // order of the operands is such that it is equivalent to packing across - // rows and then across columns. - // TODO(b/384274392): For some cases we want to pack across columns first, but - // we also need mixed compressed/interleaved packing. - - // The format for packing *across* multiple rows in the vreg array (different - // 2nd minor index): - PackFormat row_pack_format = PackFormat::kCompressed; - if (vreg_rows != 1) { - // When going from (a, b) to (a * n, b) tiling, each output tile is the - // union of n input tiles from different vregs. The ith tile of the output - // vreg is formed by packing the ith tiles of the input vregs together. - // This can only be done when tiles are one sublane (by packing interleaved) - // or when they occupy the full vreg (by packing compressed). - // Note: Currently, we always pack across rows before packing across - // columns, so we just check the source tiling. - if (input_sublanes_per_tile == 1) { - row_pack_format = PackFormat::kInterleaved; - } else if (input_sublanes_per_tile == ctx.target_shape[0]) { - row_pack_format = PackFormat::kCompressed; - } else { - return op.emitOpError( - "Not implemented: Tiling change requires interleaving tiles that are " - "not one sublane or one full vreg"); - } - } - // The tiling after packing across rows: - const std::array intermediate_tiling = { - layout_in.tiling()[0] * vreg_rows, layout_in.tiling()[1]}; - DCHECK_EQ(intermediate_tiling[0], layout_out.tiling()[0]); - - // We only support compressed packing across vreg columns, which doesn't - // change the tiling. Logically, it just stacks tiles horizontally. - if (intermediate_tiling[1] != layout_out.tiling()[1] && - // For (1, x) tiling all minor dimension tilings are equivalent, although - // some are illegal in VectorLayout. So, even though compressed packing in - // general does not change the tiling, for (1, x) we can still change to - // other minor dimension tilings (they are equivalent). - intermediate_tiling[0] != 1) { - // This could be handled, in some cases, by using interleaved packing across - // vreg columns, but we never use tilings like this. An example where we - // could use interleaved packing is (8, 128) f32 -> (8, 256) bf16. - return op.emitOpError( - "Not implemented: Truncating to increasing minor tile size"); - } - // The format for packing *across* multiple columns in the vreg array - // (different minor index): - constexpr PackFormat col_pack_format = PackFormat::kCompressed; - - if (vreg_rows != 1 && vreg_cols != 1 && row_pack_format != col_pack_format) { - // TODO(b/384274392): We can alternate interleaved and compressed packing - // but how should we expose it in tpu.pack_subelements? - return op.emitOpError( - "Not implemented: Tiling change requires mixed compressed and " - "interleaved packing"); - } - const PackFormat pack_format = - vreg_rows != 1 ? row_pack_format : col_pack_format; - - const VectorType res_vreg_ty = - getNativeVregType(result_ty.getElementType(), ctx.target_shape); - - SmallVector input_idx; - output_vregs.Each([&](absl::Span output_idx, Value *v) { - SmallVector parts; - input_idx.assign(output_idx.begin(), output_idx.end()); - auto push_col = [&]() { - if (!output_offsets[0].has_value()) { - *(input_idx.end() - 2) = 0; - // Make sure we set all rows of the column to make it replicated - parts.append(vreg_rows, input_vregs(input_idx)); - } else { - const int64_t row_offset = *output_offsets[0] / input_vreg_slice[0]; - const int64_t base_src_row = - *(output_idx.end() - 2) * vreg_rows - row_offset; - for (int64_t row = base_src_row; row < base_src_row + vreg_rows; - ++row) { - if (0 <= row && row < *(input_vregs.dimensions().end() - 2)) { - *(input_idx.end() - 2) = row; - parts.push_back(input_vregs(input_idx)); - } else { - parts.push_back(nullptr); - } - } - } - }; - if (!output_offsets[1].has_value()) { - *(input_idx.end() - 1) = 0; - // Make sure we set all column parts of the vreg to make it replicated - push_col(); - for (int64_t col = 1; col < vreg_cols; ++col) { - for (int64_t row = 0; row < vreg_rows; ++row) { - parts.push_back(parts[row]); - } - } - } else { - const int64_t col_offset = *output_offsets[1] / input_vreg_slice[1]; - const int64_t base_src_col = - *(output_idx.end() - 1) * vreg_cols - col_offset; - for (int64_t col = base_src_col; col < base_src_col + vreg_cols; ++col) { - if (0 <= col && col < *(input_vregs.dimensions().end() - 1)) { - *(input_idx.end() - 1) = col; - push_col(); - } else { - parts.append(vreg_rows, nullptr); - } - } - } - *v = builder.create(res_vreg_ty, parts, pack_format); - }); - op.replaceAllUsesWith(assemble(builder, result_ty, layout_out, - std::move(output_vregs), ctx.target_shape, - /*use_implicit_shape=*/true) - .getResult()); - op.erase(); - return success(); -} - -LogicalResult arith_truncf_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - TPU_ASSERT_EQ_OP(layouts_in.size(), 1); - TPU_ASSERT_OP(layouts_in.front().has_value()); - TPU_ASSERT_EQ_OP(layouts_out.size(), 1); - TPU_ASSERT_OP(layouts_out.front().has_value()); - auto truncf_op = cast(op); - if (layouts_in.front()->bitwidth() != 32 || - (layouts_out.front()->bitwidth() != 16 && - layouts_out.front()->bitwidth() != 8)) { - return op.emitOpError( - "Not implemented: Only 32-bit to 16-or-8-bit conversion supported"); - } - return trunc_op_rule_impl(ctx, truncf_op, *layouts_in.front(), - *layouts_out.front()); -} - -LogicalResult arith_trunci_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - TPU_ASSERT_EQ_OP(layouts_in.size(), 1); - TPU_ASSERT_OP(layouts_in.front().has_value()); - TPU_ASSERT_EQ_OP(layouts_out.size(), 1); - TPU_ASSERT_OP(layouts_out.front().has_value()); - auto trunci_op = cast(op); - return trunc_op_rule_impl(ctx, trunci_op, *layouts_in.front(), - *layouts_out.front()); -} - -LogicalResult tpu_fptosi_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - TPU_ASSERT_EQ_OP(layouts_in.size(), 1); - TPU_ASSERT_OP(layouts_in.front().has_value()); - TPU_ASSERT_EQ_OP(layouts_out.size(), 1); - TPU_ASSERT_OP(layouts_out.front().has_value()); - auto &layout_in = *layouts_in.front(); - auto &layout_out = *layouts_out.front(); - if (layout_in.bitwidth() == layout_out.bitwidth()) { - return elementwise_op_rule(ctx, op, layouts_in, layouts_out); - } else if (layout_in.bitwidth() > layout_out.bitwidth()) { - // FPToSI semantics require rounding towards zero, but packing instructions - // use rounding towards nearest even. We need to insert explicit rounding, - // unless the input is already rounded to nearest even. - auto fptosi_op = cast(op); - switch (fptosi_op.getRoundingMode()) { - case tpu::RoundingMode::kToNearestEven: - break; // That is the mode used by tpu.pack_subelements. - case tpu::RoundingMode::kTowardsZero: { - auto input = cast>(fptosi_op.getInput()); - ImplicitLocOpBuilder builder(op.getLoc(), fptosi_op); - FAILUREOR_ASSIGN_OR_RETURN( - xla::Array vregs, - disassemble(builder, layout_in, input, ctx.target_shape)); - vregs.Each([&](absl::Span idxs, Value *v) { - *v = builder.create(op.getLoc(), v->getType(), - *v); - }); - fptosi_op->replaceUsesOfWith( - input, assemble(builder, input.getType(), layout_in, vregs, - ctx.target_shape)); - } break; - } - return trunc_op_rule_impl(ctx, fptosi_op, layout_in, layout_out); - } - return op.emitOpError("Unsupported FPToSI conversion"); -} - -LogicalResult func_return_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - TPU_ASSERT_OP(layouts_out.empty()); - for (const Layout &layout_in : layouts_in) { - if (layout_in.has_value()) { - return op.emitOpError("Vector-typed return values are not supported"); - } - } - return success(); -} - -LogicalResult scf_for_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - scf::ForOp for_op = cast(op); - TPU_ASSERT_EQ_OP(layouts_in.size(), for_op->getNumOperands()); - TPU_ASSERT_EQ_OP(layouts_out.size(), for_op->getNumResults()); - FAILUREOR_ASSIGN_OR_RETURN( - const SmallVector yield_in_layouts, - getInLayouts(*for_op.getBody()->getTerminator(), ctx.target_shape)); - int out_idx = 0; - for (auto [in_layout, yield_layout, out_layout, result] : - llvm::zip_equal(layouts_in.drop_front(3), yield_in_layouts, layouts_out, - op.getResults())) { - if (auto vty = dyn_cast(result.getType())) { - TPU_ASSERT_OP(in_layout.has_value()); - TPU_ASSERT_OP(yield_layout.has_value()); - TPU_ASSERT_OP(out_layout.has_value()); - if (in_layout.value() != yield_layout.value()) { - return op.emitOpError( - "Not implemented: for loop input layout does not match with " - "yield layout ") - << out_idx; - } - if (in_layout.value() != out_layout.value()) { - return op.emitOpError( - "Not implemented: for loop input layout does not match with " - "out layout ") - << out_idx; - } - } else { - TPU_ASSERT_EQ_OP(in_layout, kNoLayout); - TPU_ASSERT_EQ_OP(yield_layout, kNoLayout); - TPU_ASSERT_EQ_OP(out_layout, kNoLayout); - } - ++out_idx; - } - - if (failed(applyLayoutBlock(ctx, *for_op.getBody()))) { - return failure(); - } - - if (op.getNumResults() == 0) { - return success(); - } - - OpBuilder builder(&op); - SmallVector unrolled_args; - for (int i = 0; i < layouts_in.size(); ++i) { - auto layout = layouts_in[i]; - auto operand = for_op.getOperand(i); - if (i < 3) { - if (layout.has_value()) { - return op.emitOpError("Expected no layout for bounds and step"); - } - continue; - } - if (auto vector_operand = dyn_cast>(operand)) { - if (!layout.has_value()) { - return op.emitOpError("Expected layout for vector operand"); - } - FAILUREOR_ASSIGN_OR_RETURN( - const xla::Array tiles, - disassemble(builder, *layout, vector_operand, ctx.target_shape)); - unrolled_args.append(tiles.begin(), tiles.end()); - } else { - if (layout.has_value()) { - return op.emitOpError("Expected no layout for scalar operand"); - } - unrolled_args.push_back(operand); - } - } - - // Create a new scf::ForOp with unrolled args. - auto new_op = builder.create( - for_op->getLoc(), for_op.getLowerBound(), for_op.getUpperBound(), - for_op.getStep(), unrolled_args); - - int num_old_args = for_op.getBody()->getNumArguments(); - SmallVector locs(new_op.getBody()->getNumArguments(), - for_op.getLoc()); - for_op.getBody()->addArguments(TypeRange(new_op.getBody()->getArguments()), - locs); - builder.setInsertionPointToStart(for_op.getBody()); - auto arg_idx = num_old_args; - // Block also has an induction variable that should have no layout, - // which conveniently matches the in layouts. - for (auto [old_arg, layout] : llvm::zip_equal( - for_op.getBody()->getArguments().take_front(num_old_args), - layouts_in.drop_front(2))) { - if (const auto vty = dyn_cast(old_arg.getType())) { - TPU_ASSERT_OP(layout.has_value()); - const SmallVector tiles_shape = - layout->tileArrayShape(vty.getShape(), ctx.target_shape); - const int64_t num_vectors = ShapedType::getNumElements(tiles_shape); - xla::Array tiles(tiles_shape); - TPU_ASSERT_LE_OP(arg_idx + num_vectors, - for_op.getBody()->getNumArguments()); - tiles.SetValues(llvm::make_range( - for_op.getBody()->getArguments().begin() + arg_idx, - for_op.getBody()->getArguments().begin() + arg_idx + num_vectors)); - arg_idx += num_vectors; - RollVectorsOp rolled_op = - assemble(builder, vty, *layout, tiles, ctx.target_shape); - old_arg.replaceUsesWithIf(rolled_op, [&](OpOperand &operand) { - return operand.getOwner() != rolled_op; - }); - } else { - TPU_ASSERT_OP(!layout.has_value()); - old_arg.replaceAllUsesWith(for_op.getBody()->getArgument(arg_idx)); - ++arg_idx; - } - } - for_op.getBody()->eraseArguments(0, num_old_args); - new_op.getRegion().takeBody(for_op.getRegion()); - - // Roll the results back to the original shapes. - builder.setInsertionPointAfter(new_op); - int64_t res_idx = 0; - SmallVector rolled_results; - for (auto [result, layout] : - llvm::zip_equal(for_op.getResults(), layouts_out)) { - if (const auto vty = dyn_cast(result.getType())) { - TPU_ASSERT_OP(layout.has_value()); - const SmallVector tiles_shape = - layout->tileArrayShape(vty.getShape(), ctx.target_shape); - const int64_t num_vectors = ShapedType::getNumElements(tiles_shape); - xla::Array tiles(tiles_shape); - TPU_ASSERT_LE_OP(res_idx + num_vectors, new_op.getResults().size()); - tiles.SetValues(llvm::make_range( - new_op.getResults().begin() + res_idx, - new_op.getResults().begin() + res_idx + num_vectors)); - res_idx += num_vectors; - RollVectorsOp rolled_op = - assemble(builder, vty, *layout, tiles, ctx.target_shape); - rolled_results.push_back(rolled_op); - } else { - TPU_ASSERT_OP(!layout.has_value()); - rolled_results.push_back(new_op.getResult(res_idx)); - ++res_idx; - } - } - - for_op.replaceAllUsesWith(rolled_results); - for_op.erase(); - return success(); -} - -LogicalResult scf_while_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - scf::WhileOp while_op = cast(op); - TPU_ASSERT_EQ_OP(layouts_in.size(), while_op->getNumOperands()); - TPU_ASSERT_EQ_OP(layouts_out.size(), while_op->getNumResults()); - TPU_ASSERT_EQ_OP(layouts_in.size(), layouts_out.size()); - - // The terminator for the before region is the condition op. - // It takes multiple arguments -- the first being the decision to execute the - // after region or branch to the exit. - FAILUREOR_ASSIGN_OR_RETURN( - const SmallVector cond_in_layouts, - getInLayouts(*while_op.getBeforeBody()->getTerminator(), - ctx.target_shape)); - - FAILUREOR_ASSIGN_OR_RETURN( - const SmallVector yield_in_layouts, - getInLayouts(*while_op.getYieldOp(), ctx.target_shape)); - int out_idx = 0; - for (auto [in_layout, cond_layout, yield_layout, out_layout, result] : - llvm::zip_equal(layouts_in, - ArrayRef(cond_in_layouts).drop_front(1), - yield_in_layouts, layouts_out, op.getResults())) { - if (auto vty = dyn_cast(result.getType())) { - TPU_ASSERT_OP(in_layout.has_value()); - TPU_ASSERT_OP(yield_layout.has_value()); - TPU_ASSERT_OP(out_layout.has_value()); - if (in_layout.value() != cond_layout.value()) { - return op.emitOpError( - "Not implemented: while loop input layout does not match " - "with condition layout ") - << out_idx; - } - if (in_layout.value() != yield_layout.value()) { - return op.emitOpError( - "Not implemented: while loop input layout does not match " - "with yield layout ") - << out_idx; - } - if (in_layout.value() != out_layout.value()) { - return op.emitOpError( - "Not implemented: while loop input layout does not match " - "with output layout ") - << out_idx; - } - } else { - TPU_ASSERT_EQ_OP(in_layout, kNoLayout); - TPU_ASSERT_EQ_OP(cond_layout, kNoLayout); - TPU_ASSERT_EQ_OP(yield_layout, kNoLayout); - TPU_ASSERT_EQ_OP(out_layout, kNoLayout); - } - ++out_idx; - } - - if (failed(applyLayoutBlock(ctx, *while_op.getBeforeBody()))) { - return failure(); - } - - if (failed(applyLayoutBlock(ctx, *while_op.getAfterBody()))) { - return failure(); - } - - if (op.getNumResults() == 0) { - return success(); - } - - OpBuilder builder(&op); - SmallVector unrolled_args; - for (int i = 0; i < layouts_in.size(); ++i) { - auto layout = layouts_in[i]; - auto operand = while_op.getOperand(i); - if (auto vector_operand = dyn_cast>(operand)) { - if (!layout.has_value()) { - return op.emitOpError("Expected layout for vector operand"); - } - FAILUREOR_ASSIGN_OR_RETURN( - const xla::Array tiles, - disassemble(builder, *layout, vector_operand, ctx.target_shape)); - unrolled_args.append(tiles.begin(), tiles.end()); - } else { - if (layout.has_value()) { - return op.emitOpError("Expected no layout for scalar operand"); - } - unrolled_args.push_back(operand); - } - } - - // Create a new scf::WhileOp with unrolled args. - auto new_op = builder.create( - while_op->getLoc(), - TypeRange(while_op.getConditionOp().getOperands().drop_front(1)), - unrolled_args, nullptr, nullptr); - - const auto tile_body_args = [&](::mlir::Block *old_body, - ::mlir::Block *new_body, - const ArrayRef layouts) { - TPU_ASSERT_OP(old_body != nullptr); - TPU_ASSERT_OP(new_body != nullptr); - int num_old_args = old_body->getNumArguments(); - SmallVector locs(new_body->getNumArguments(), while_op.getLoc()); - old_body->addArguments(TypeRange(new_body->getArguments()), locs); - builder.setInsertionPointToStart(old_body); - auto arg_idx = num_old_args; - for (auto [old_arg, layout] : llvm::zip_equal( - old_body->getArguments().take_front(num_old_args), layouts)) { - if (const auto vty = dyn_cast(old_arg.getType())) { - TPU_ASSERT_OP(layout.has_value()); - const SmallVector tiles_shape = - layout->tileArrayShape(vty.getShape(), ctx.target_shape); - const int64_t num_vectors = ShapedType::getNumElements(tiles_shape); - xla::Array tiles(tiles_shape); - TPU_ASSERT_LE_OP(arg_idx + num_vectors, old_body->getNumArguments()); - tiles.SetValues(llvm::make_range( - old_body->getArguments().begin() + arg_idx, - old_body->getArguments().begin() + arg_idx + num_vectors)); - arg_idx += num_vectors; - RollVectorsOp rolled_op = - assemble(builder, vty, *layout, tiles, ctx.target_shape); - old_arg.replaceUsesWithIf(rolled_op, [&](OpOperand &operand) { - return operand.getOwner() != rolled_op; - }); - } else { - TPU_ASSERT_OP(!layout.has_value()); - old_arg.replaceAllUsesWith(old_body->getArgument(arg_idx)); - ++arg_idx; - } - } - old_body->eraseArguments(0, num_old_args); - return success(); - }; - - const auto before_status = tile_body_args(while_op.getBeforeBody(), - new_op.getBeforeBody(), layouts_in); - if (before_status.failed()) return before_status; - new_op.getBefore().takeBody(while_op.getBefore()); - - const auto after_status = tile_body_args(while_op.getAfterBody(), - new_op.getAfterBody(), layouts_out); - if (after_status.failed()) return after_status; - new_op.getAfter().takeBody(while_op.getAfter()); - - builder.setInsertionPointAfter(new_op); - int64_t res_idx = 0; - SmallVector rolled_results; - for (auto [result, layout] : - llvm::zip_equal(while_op.getResults(), layouts_out)) { - if (const auto vty = dyn_cast(result.getType())) { - TPU_ASSERT_OP(layout.has_value()); - const SmallVector tiles_shape = - layout->tileArrayShape(vty.getShape(), ctx.target_shape); - const int64_t num_vectors = ShapedType::getNumElements(tiles_shape); - xla::Array tiles(tiles_shape); - TPU_ASSERT_LE_OP(res_idx + num_vectors, new_op.getResults().size()); - tiles.SetValues(llvm::make_range( - new_op.getResults().begin() + res_idx, - new_op.getResults().begin() + res_idx + num_vectors)); - res_idx += num_vectors; - RollVectorsOp rolled_op = - assemble(builder, vty, *layout, tiles, ctx.target_shape); - rolled_results.push_back(rolled_op); - } else { - TPU_ASSERT_OP(!layout.has_value()); - rolled_results.push_back(new_op.getResult(res_idx)); - ++res_idx; - } - } - - while_op.replaceAllUsesWith(rolled_results); - while_op.erase(); - return success(); -} - -LogicalResult scf_condition_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - OpBuilder builder(&op); - auto condition_op = cast(op); - TPU_ASSERT_EQ_OP(layouts_in.size(), condition_op.getNumOperands()); - TPU_ASSERT_EQ_OP(layouts_out.size(), 0); - SmallVector unrolled; - - for (auto [operand, layout] : - llvm::zip_equal(condition_op.getOperands(), layouts_in)) { - if (auto vector_operand = dyn_cast>(operand)) { - // When the operand has vector type, disassemble the operand. - TPU_ASSERT_OP(layout.has_value()); - FAILUREOR_ASSIGN_OR_RETURN( - const xla::Array tiles, - disassemble(builder, *layout, vector_operand, ctx.target_shape)); - unrolled.append(tiles.begin(), tiles.end()); - } else { - TPU_ASSERT_OP(!layout.has_value()); - unrolled.push_back(operand); - } - } - - // Replace the old operands with unrolled operands. - condition_op->setOperands(unrolled); - return success(); -} - -LogicalResult scf_if_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - TPU_ASSERT_EQ_OP(layouts_in.size(), 1); - TPU_ASSERT_OP(!layouts_in.front().has_value()); - ImplicitLocOpBuilder builder(op.getLoc(), &op); - scf::IfOp if_op = cast(op); - SmallVector then_yield_in_layouts; - SmallVector else_yield_in_layouts; - FAILUREOR_ASSIGN_OR_RETURN( - then_yield_in_layouts, - getInLayouts(*if_op.thenYield(), ctx.target_shape)); - if (!if_op.getElseRegion().empty()) { - FAILUREOR_ASSIGN_OR_RETURN( - else_yield_in_layouts, - getInLayouts(*if_op.elseYield(), ctx.target_shape)); - } - int out_idx = 0; - for (auto [then_layout, else_layout, result_layout, result] : - llvm::zip_equal(then_yield_in_layouts, else_yield_in_layouts, - layouts_out, op.getResults())) { - if (auto vty = dyn_cast(result.getType())) { - TPU_ASSERT_OP(then_layout.has_value()); - TPU_ASSERT_OP(else_layout.has_value()); - TPU_ASSERT_OP(result_layout.has_value()); - if (result_layout.value() != then_layout.value()) { - return op.emitOpError( - "Not implemented: yield layout from then branch does not " - "match with output layout ") - << out_idx; - } - if (result_layout.value() != else_layout.value()) { - return op.emitOpError( - "Not implemented: yield layout from else branch does not " - "match with output layout ") - << out_idx; - } - } else { - TPU_ASSERT_EQ_OP(then_layout, kNoLayout); - TPU_ASSERT_EQ_OP(else_layout, kNoLayout); - TPU_ASSERT_EQ_OP(result_layout, kNoLayout); - } - ++out_idx; - } - if (failed(applyLayoutBlock(ctx, *if_op.thenBlock()))) { - return failure(); - } - if (if_op.getElseRegion().empty()) { - TPU_ASSERT_EQ_OP(if_op->getNumResults(), 0); - TPU_ASSERT_EQ_OP(layouts_out.size(), 0); - return success(); - } - if (failed(applyLayoutBlock(ctx, *if_op.elseBlock()))) { - return failure(); - } - - // Apply layout to results after applying layout in the true and false - // regions. - if (if_op.getNumResults() == 0) { - TPU_ASSERT_EQ_OP(layouts_out.size(), 0); - return success(); - } - TPU_ASSERT_EQ_OP(if_op.getNumResults(), layouts_out.size()); - // If scf.if has results, it should have both non-empty true and false - // regions. - TPU_ASSERT_OP(!if_op.getThenRegion().empty() && - !if_op.getElseRegion().empty()); - - // Move true and false regions to the new if op whose result has same type and - // layout as yield operand's. - auto new_op = builder.create( - TypeRange(if_op.thenYield().getResults()), if_op.getCondition(), - /*withElseRegion =*/true); - moveAllRegions(*if_op, *new_op); - - int64_t index = 0; - SmallVector rolled_results; - for (auto [result, layout] : - llvm::zip_equal(if_op.getResults(), layouts_out)) { - if (const auto vty = dyn_cast(result.getType())) { - // When the result has a vector type, assemble the result. - TPU_ASSERT_OP(layout.has_value()); - const SmallVector tiles_shape = - layout->tileArrayShape(vty.getShape(), ctx.target_shape); - const int64_t num_vectors = ShapedType::getNumElements(tiles_shape); - xla::Array tiles(tiles_shape); - TPU_ASSERT_LE_OP(index + num_vectors, new_op.getResults().size()); - tiles.SetValues( - llvm::make_range(new_op.getResults().begin() + index, - new_op.getResults().begin() + index + num_vectors)); - index += num_vectors; - RollVectorsOp rolled_op = - assemble(builder, vty, *layout, tiles, ctx.target_shape); - rolled_results.push_back(rolled_op); - } else { - TPU_ASSERT_OP(!layout.has_value()); - rolled_results.push_back(new_op.getResult(index)); - ++index; - } - } - if_op.replaceAllUsesWith(rolled_results); - if_op.erase(); - return success(); -} - -LogicalResult yield_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - OpBuilder builder(&op); - TPU_ASSERT_EQ_OP(layouts_in.size(), op.getNumOperands()); - TPU_ASSERT_EQ_OP(layouts_out.size(), 0); - if (op.getNumOperands() == 0) { - return success(); - } - SmallVector unrolled; - for (auto [operand, layout] : - llvm::zip_equal(op.getOperands(), layouts_in)) { - if (auto vector_operand = dyn_cast>(operand)) { - // When the operand has vector type, disassemble the operand. - TPU_ASSERT_OP(layout.has_value()); - FAILUREOR_ASSIGN_OR_RETURN( - const xla::Array tiles, - disassemble(builder, *layout, vector_operand, ctx.target_shape)); - unrolled.append(tiles.begin(), tiles.end()); - } else { - TPU_ASSERT_OP(!layout.has_value()); - unrolled.push_back(operand); - } - } - - // Replace the old operands with unrolled operands. - op.setOperands(unrolled); - return success(); -} - -LogicalResult tpu_load_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - TPU_ASSERT_EQ_OP(layouts_out.size(), 1); - TPU_ASSERT_OP(llvm::none_of(layouts_in, - [&](const Layout &l) { return l.has_value(); })); - TPU_ASSERT_OP(layouts_out.front().has_value()); - const VectorLayout &layout_out = *layouts_out.front(); - // We expect the result is already a native-sized vreg. - // TODO(b/300493694): Support other bitwidths - if (layout_out.bitwidth() != 32) { - return op.emitOpError("Not implemented: Only 32-bit loads supported"); - } - tpu::LoadOp load_op = cast(op); - if (layout_out != VectorLayout(32, {0, 0}, ctx.target_shape, - VectorLayout::ImplicitDim::kNone)) { - return op.emitOpError("Invalid output layout for ") << load_op->getName(); - } - FAILUREOR_ASSIGN_OR_RETURN( - const SmallVector indices, - getIntConstsFromOperandRange(load_op.getIndices())); - TPU_ASSERT_EQ_OP(indices.size(), 2); - if (indices[1] % ctx.target_shape[1] != 0) { - return op.emitOpError("Not implemented: Lane index is not a multiple of ") - << ctx.target_shape[1]; - } - - OpBuilder builder(op.getContext()); - builder.setInsertionPointAfter(&op); - const RollVectorsOp roll_vectors_op = - assemble(builder, load_op.getResult().getType(), layout_out, - {{load_op.getResult()}}, ctx.target_shape); - load_op->replaceUsesWithIf(roll_vectors_op, [&](OpOperand &operand) { - return operand.getOwner() != roll_vectors_op; - }); - return success(); -} - -LogicalResult strided_op_rule_impl(RewriteContext &ctx, Operation &op, - Value base_ref, ValueRange indices, - const VectorType &vty, - const VectorLayout &layout, - const ArrayRef &strides) { - if (!isa(op)) { - return op.emitOpError("Not implemented: Unsupported strided op") - << op.getName(); - } - if (layout != VectorLayout(32, {0, 0}, ctx.target_shape, - VectorLayout::ImplicitDim::kNone)) { - return op.emitOpError("Not implemented: Unsupported vector layout in ") - << op.getName(); - } - const auto base_ty = getMemRefType(base_ref); - auto rank = base_ty.getRank(); - CHECK_EQ(rank, indices.size()); - CHECK_EQ(rank, strides.size()); - CHECK_EQ(rank, vty.getShape().size()); - if (rank < 2) { - return op.emitOpError("Not implemented: Stride on 1D vector"); - } - auto mem_layout = dyn_cast(base_ty.getLayout()); - if (!mem_layout) { - return op.emitOpError("Expected a tiled memref"); - } - auto tile_strides = mem_layout.getTileStrides(); - - // Currently we hold constraints that the last dim size of memref needs to be - // exactly same as the lane size of native vreg and the memref has never - // been sliced before on the last dim. In other words, the original base - // memref's shape needs to be (..., target_shape[1]). - if (base_ty.getShape()[rank - 1] != ctx.target_shape[1] || - tile_strides.take_back(2) != ArrayRef{1, 1}) { - return op.emitOpError("Not Implemented: The last dim size is not ") - << ctx.target_shape[1] << " in original base memref"; - } - if (strides[rank - 1] != 1) { - return op.emitOpError("Not Implemented: Stride on last dim is not 1"); - } - auto last_idx = getIntConst(indices[rank - 1], /*silent=*/true); - if (failed(last_idx)) { - return op.emitOpError("Not Implemented: Dynamic index on last dim"); - } else if (last_idx.value() != 0) { - return op.emitOpError("Not Implemented: Index on last dim is not 0"); - } - ImplicitLocOpBuilder builder(op.getLoc(), &op); - - VectorType vreg_ty = - getNativeVregType(vty.getElementType(), ctx.target_shape); - - bool is_load_op = true; - xla::Array tiles( - layout.tileArrayShape(vty.getShape(), ctx.target_shape)); - if (auto store_op = dyn_cast(op)) { - is_load_op = false; - FAILUREOR_ASSIGN_OR_RETURN( - tiles, disassemble(builder, layout, store_op.getValueToStore(), - ctx.target_shape)); - } - - tiles.Each([&](absl::Span tile_idxs, Value *v) { - CHECK_EQ(tile_idxs.size(), rank); - SmallVector idxs(rank); - for (int64_t i = 0; i < rank; ++i) { - int64_t stride = (i < rank - 2) - ? strides[i] - : (strides[i] * ctx.target_shape[i - rank + 2]); - idxs[i] = builder.create( - indices[i], IdxConst(tile_idxs[i] * stride, builder, op.getLoc())); - } - SmallVector sublane_mask(ctx.target_shape[0], true); - int64_t sublane_rem = vty.getDimSize(rank - 2) % ctx.target_shape[0]; - if (sublane_rem > 0 && tile_idxs[rank - 2] == tiles.dim(rank - 2) - 1) { - for (int64_t i = sublane_rem; i < ctx.target_shape[0]; ++i) { - sublane_mask[i] = false; - } - } - const auto sublane_mask_attr = - DenseBoolArrayAttr::get(op.getContext(), sublane_mask); - if (is_load_op) { - *v = builder.create( - vreg_ty, base_ref, idxs, sublane_mask_attr, - builder.getI32IntegerAttr(strides[rank - 2])); - } else { - builder.create( - *v, base_ref, idxs, sublane_mask_attr, - /*mask=*/nullptr, builder.getI32IntegerAttr(strides[rank - 2])); - } - }); - if (is_load_op) { - op.replaceAllUsesWith( - assemble(builder, vty, layout, std::move(tiles), ctx.target_shape)); - } - op.erase(); - return success(); -} - -// TODO(jevinjiang): maybe unify with vector load? -LogicalResult tpu_strided_load_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - TPU_ASSERT_OP(llvm::none_of(layouts_in, - [&](const Layout &l) { return l.has_value(); })); - TPU_ASSERT_EQ_OP(layouts_out.size(), 1); - TPU_ASSERT_OP(layouts_out.front().has_value()); - const VectorLayout &layout_out = *layouts_out.front(); - auto load_op = cast(op); - const auto vty = cast(load_op.getResult().getType()); - return strided_op_rule_impl(ctx, op, load_op.getBase(), load_op.getIndices(), - vty, layout_out, load_op.getStrides()); -} - -// TODO(jevinjiang): maybe unify with vector store? -LogicalResult tpu_strided_store_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - TPU_ASSERT_OP(layouts_in.front().has_value()); - TPU_ASSERT_OP(llvm::none_of(layouts_in.drop_front(), - [&](const Layout &l) { return l.has_value(); })); - TPU_ASSERT_EQ_OP(layouts_out.size(), 0); - - const VectorLayout &to_store_layout = *layouts_in.front(); - auto store_op = cast(op); - const auto vty = store_op.getValueToStore().getType(); - return strided_op_rule_impl(ctx, op, store_op.getBase(), - store_op.getIndices(), vty, to_store_layout, - store_op.getStrides()); -} - -LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - TPU_ASSERT_EQ_OP(layouts_in.size(), 3); - TPU_ASSERT_EQ_OP(layouts_out.size(), 1); - TPU_ASSERT_OP( - llvm::all_of(layouts_in, [&](const Layout &l) { return l.has_value(); })); - TPU_ASSERT_OP(layouts_out.front().has_value()); - auto matmul_op = cast(op); - if (matmul_op.getTransposeRhs()) { - return op.emitOpError( - "Transposition must have been erased into dimension numbers during " - "canonicalization"); - } - - auto dimension_numbers = matmul_op.getDimensionNumbers(); - if (!dimension_numbers.has_value()) { - return op.emitOpError( - "Dimension numbers must be provided, ensure canonicalization has been " - "run."); - } - auto transposed_mkn = isTransposedMatmul(dimension_numbers.value()); - if (!transposed_mkn.has_value()) { - return op.emitOpError( - "Dimension numbers must be MKN, ensure canonicalization has been " - "run."); - } - auto [transpose_lhs, transpose_rhs] = transposed_mkn.value(); - if (transpose_lhs) { - return op.emitOpError( - "Transposition of LHS is not supported in apply_vector_layout, ensure " - "canonicalization has been run."); - } - - auto &layout_lhs = *layouts_in[0]; - auto &layout_rhs = *layouts_in[1]; - auto &layout_acc = *layouts_in[2]; - auto &layout_out = *layouts_out[0]; - - const std::array, 4> all_layouts = - {layout_lhs, layout_rhs, layout_acc, layout_out}; - for (const VectorLayout &layout : all_layouts) { - for (const LayoutOffset offset : layout.offsets()) { - if (offset.value_or(0) != 0) { - return op.emitOpError("Not implemented: Unaligned layout in matmul"); - } - } - } - ImplicitLocOpBuilder builder(op.getLoc(), &op); - TypedValue lhs, rhs, acc, res; - if (auto tpu_matmul_op = dyn_cast(op)) { - lhs = tpu_matmul_op.getLhs(); - rhs = tpu_matmul_op.getRhs(); - acc = tpu_matmul_op.getAcc(); - res = tpu_matmul_op.getResult(); - } else { - return op.emitOpError("Expected a tpu::MatmulOp"); - } - - for (const std::optional &layout_opt : layouts_in) { - auto layout = *layout_opt; - if (layout.implicit_dim() != VectorLayout::ImplicitDim::kNone) { - return op.emitOpError( - "Not implemented: Unsupported matmul operand layout"); - } - if (!layout.hasNativeTiling(ctx.target_shape)) { - return op.emitOpError( - "Not implemented: Unsupported matmul operand tiling"); - } - } - if (acc.getType().getElementType().getIntOrFloatBitWidth() != 32) { - return op.emitOpError("Not implemented: Non-32-bit matmul acc"); - } - const ArrayRef lhs_shape = lhs.getType().getShape(); - const ArrayRef rhs_shape = rhs.getType().getShape(); - // TODO(tlongeri): This should be part of the tpu::MatmulOp verifier - TPU_ASSERT_EQ_OP(lhs_shape.size(), 2); - TPU_ASSERT_EQ_OP(rhs_shape.size(), 2); - - const int64_t padded_lhs_rows = - llvm::alignTo(lhs_shape[0], layout_lhs.tiling()[0]); - const int64_t padded_lhs_cols = - llvm::alignTo(lhs_shape[1], layout_lhs.tiling()[1]); - const int64_t padded_rhs_rows = - llvm::alignTo(rhs_shape[0], layout_rhs.tiling()[0]); - const int64_t padded_rhs_cols = - llvm::alignTo(rhs_shape[1], layout_rhs.tiling()[1]); - - FAILUREOR_ASSIGN_OR_RETURN( - xla::Array lhs_vregs, - disassemble(builder, layout_lhs, lhs, ctx.target_shape)); - FAILUREOR_ASSIGN_OR_RETURN( - xla::Array acc_vregs, - disassemble(builder, layout_acc, acc, ctx.target_shape)); - FAILUREOR_ASSIGN_OR_RETURN( - xla::Array rhs_vregs, - disassemble(builder, layout_rhs, rhs, ctx.target_shape)); - TPU_ASSERT_EQ_OP(padded_lhs_rows, lhs_vregs.dim(0) * layout_lhs.tiling()[0]); - TPU_ASSERT_EQ_OP(padded_rhs_rows, rhs_vregs.dim(0) * layout_rhs.tiling()[0]); - - auto lhs_zeros_vreg = - getZerosVector(builder, cast(lhs_vregs.begin()->getType())); - auto rhs_zeros_vreg = - getZerosVector(builder, cast(rhs_vregs.begin()->getType())); - auto acc_zeros_vreg = - getZerosVector(builder, cast(acc_vregs.begin()->getType())); - - // Only mask out the paddings on contracting dim of LHS and RHS. - RETURN_IF_FAILED( - maskNativeTilingVregs(builder, lhs_vregs, ctx.target_shape, - /*padding_bottom=*/0, - /*padding_right=*/padded_lhs_cols - lhs_shape[1])); - if (transpose_rhs) { - RETURN_IF_FAILED(maskNativeTilingVregs( - builder, rhs_vregs, ctx.target_shape, - /*padding_bottom=*/0, - /*padding_right=*/padded_rhs_cols - rhs_shape[1])); - } else { - RETURN_IF_FAILED( - maskNativeTilingVregs(builder, rhs_vregs, ctx.target_shape, - /*padding_bottom=*/padded_rhs_rows - rhs_shape[0], - /*padding_right=*/0)); - } - - // At this point, all paddings on vregs are masked out. For now, we - // append zero vregs to make LHS's second dim, both RHS's dims and ACC's - // second dim to be a multiple of mxu_size. - auto mxu_contracting_size = ctx.mxu_shape[0]; - auto mxu_noncontracting_size = ctx.mxu_shape[1]; - if (lhs.getType().getElementType().isSignlessInteger(4) && - rhs.getType().getElementType().isSignlessInteger(4)) { - mxu_contracting_size *= 2; - } - auto rhs_row_size = mxu_contracting_size; - auto rhs_col_size = mxu_noncontracting_size; - if (transpose_rhs) { - rhs_row_size = mxu_noncontracting_size; - rhs_col_size = mxu_contracting_size; - } - CHECK_EQ(rhs_row_size % ctx.target_shape[1], 0); - CHECK_EQ(rhs_col_size % ctx.target_shape[1], 0); - - // Here, a single group corresponds to a single matmul invocation in unrolled - // code. The RHS group matches the MXU shape. - auto lhs_col_vregs_per_group = mxu_contracting_size / ctx.target_shape[1]; - auto rhs_row_vregs_per_group = - rhs_row_size / (ctx.target_shape[0] * layout_rhs.packing()); - auto rhs_col_vregs_per_group = rhs_col_size / ctx.target_shape[1]; - auto acc_col_vregs_per_group = mxu_noncontracting_size / ctx.target_shape[1]; - int64_t target_lhs_col_vregs = - llvm::alignTo(lhs_vregs.dim(1), lhs_col_vregs_per_group); - int64_t target_rhs_row_vregs = - llvm::alignTo(rhs_vregs.dim(0), rhs_row_vregs_per_group); - int64_t target_rhs_col_vregs = - llvm::alignTo(rhs_vregs.dim(1), rhs_col_vregs_per_group); - int64_t target_acc_col_vregs = - llvm::alignTo(acc_vregs.dim(1), acc_col_vregs_per_group); - - xla::Array target_lhs_vregs({lhs_vregs.dim(0), target_lhs_col_vregs}, - lhs_zeros_vreg); - xla::Array target_rhs_vregs( - {target_rhs_row_vregs, target_rhs_col_vregs}, rhs_zeros_vreg); - xla::Array target_acc_vregs( - {lhs_vregs.dim(0) * layout_lhs.packing(), target_acc_col_vregs}, - acc_zeros_vreg); - target_lhs_vregs.UpdateSlice(lhs_vregs, {0, 0}); - target_rhs_vregs.UpdateSlice(rhs_vregs, {0, 0}); - target_acc_vregs.UpdateSlice(acc_vregs, {0, 0}); - - // Now we can regroup vregs from target vregs. - const auto lhs_col_ty = VectorType::get( - {padded_lhs_rows, mxu_contracting_size}, lhs.getType().getElementType()); - const auto acc_col_ty = - VectorType::get({padded_lhs_rows, mxu_noncontracting_size}, - acc.getType().getElementType()); - const ArrayAttr lhs_layout_attr = - builder.getArrayAttr({builder.getAttr(layout_lhs)}); - const ArrayAttr rhs_layout_attr = - builder.getArrayAttr({builder.getAttr(layout_rhs)}); - const ArrayAttr acc_layout_attr = - builder.getArrayAttr({builder.getAttr(layout_acc)}); - - int64_t nk = llvm::divideCeil(lhs_shape[1], mxu_contracting_size); - CHECK_EQ(nk, target_lhs_vregs.dim(1) / lhs_col_vregs_per_group); - SmallVector lhs_cols(nk); - for (int64_t i = 0; i < nk; ++i) { - const xla::Array col_vregs = target_lhs_vregs.Slice( - {0, i * lhs_col_vregs_per_group}, - {target_lhs_vregs.dim(0), (i + 1) * lhs_col_vregs_per_group}); - lhs_cols[i] = builder.create( - op.getLoc(), lhs_col_ty, XlaArrayToFlatArrayRef(col_vregs)); - lhs_cols[i]->setAttr("out_layout", lhs_layout_attr); - } - const auto rhs_group_ty = VectorType::get({rhs_row_size, rhs_col_size}, - rhs.getType().getElementType()); - const int64_t rhs_vregs_per_group = - rhs_row_vregs_per_group * rhs_col_vregs_per_group; - - int64_t nj; - if (transpose_rhs) { - nj = llvm::divideCeil(rhs_shape[0], rhs_row_size); - CHECK_EQ(nk, llvm::divideCeil(rhs_shape[1], rhs_col_size)); - CHECK_EQ(nk, target_rhs_vregs.dim(1) / rhs_col_vregs_per_group); - target_rhs_vregs.Reshape({nj, rhs_vregs_per_group / rhs_col_vregs_per_group, - nk, rhs_col_vregs_per_group}); - target_rhs_vregs.TransposeDimensions({2, 0, 1, 3}); - target_rhs_vregs.Reshape({nk, nj, rhs_vregs_per_group}); - } else { - nj = llvm::divideCeil(rhs_shape[1], rhs_col_size); - CHECK_EQ(nk, llvm::divideCeil(rhs_shape[0], rhs_row_size)); - CHECK_EQ(nk, target_rhs_vregs.dim(0) / rhs_row_vregs_per_group); - target_rhs_vregs.Reshape({nk, rhs_vregs_per_group / rhs_col_vregs_per_group, - nj, rhs_col_vregs_per_group}); - target_rhs_vregs.TransposeDimensions({0, 2, 1, 3}); - target_rhs_vregs.Reshape({nk, nj, rhs_vregs_per_group}); - } - - const tpu::ContractPrecisionAttr precision_attr = // May be null - op.getAttrOfType("precision"); - const tpu::DotDimensionNumbersAttr dot_dimension_numbers_attr = - defaultDimensionNumbers(builder, false, transpose_rhs); - for (int64_t j = 0; j < nj; ++j) { - for (int64_t k = 0; k < nk; ++k) { - // TODO(tlongeri): there should be a way to slice without copying - xla::Array rhs_group = target_rhs_vregs.Slice( - {k, j, 0}, {k + 1, j + 1, rhs_vregs_per_group}); - auto rhs_rolled_group = builder.create( - op.getLoc(), rhs_group_ty, XlaArrayToFlatArrayRef(rhs_group)); - rhs_rolled_group->setAttr("out_layout", rhs_layout_attr); - const xla::Array acc_col_vregs = target_acc_vregs.Slice( - {0, j * acc_col_vregs_per_group}, - {target_acc_vregs.dim(0), (j + 1) * acc_col_vregs_per_group}); - auto acc_col = builder.create( - op.getLoc(), acc_col_ty, XlaArrayToFlatArrayRef(acc_col_vregs)); - acc_col->setAttr("out_layout", acc_layout_attr); - auto new_acc_col = builder.create( - op.getLoc(), acc_col_ty, lhs_cols[k], rhs_rolled_group, acc_col, - /*transpose_lhs=*/false, /*transpose_rhs=*/false, precision_attr, - dot_dimension_numbers_attr); - auto new_acc_vregs = builder.create( - op.getLoc(), - TypeRange(ValueRange(XlaArrayToFlatArrayRef(acc_col_vregs))), - new_acc_col); - new_acc_vregs->setAttr("in_layout", acc_layout_attr); - updateSliceFromRange( - target_acc_vregs, new_acc_vregs->getResults(), - {0, j * acc_col_vregs_per_group}, - {target_acc_vregs.dim(0), (j + 1) * acc_col_vregs_per_group}); - } - } - op.replaceAllUsesWith( - assemble( - builder, res.getType(), layout_out, - target_acc_vregs.Slice({0, 0}, {acc_vregs.dim(0), acc_vregs.dim(1)}), - ctx.target_shape) - .getOperation()); - op.erase(); - return success(); -} - -LogicalResult tpu_store_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - TPU_ASSERT_EQ_OP(layouts_out.size(), 0); - TPU_ASSERT_OP(layouts_in.front().has_value()); // value to store layout - TPU_ASSERT_OP(llvm::none_of(layouts_in.drop_front(), - [&](const Layout &l) { return l.has_value(); })); - OpBuilder builder(&op); - const VectorLayout &to_store_layout = *layouts_in.front(); - // We expect the value to store is already a native-sized vreg. - if (to_store_layout.bitwidth() != 32) { - return op.emitOpError("Not implemented: Only 32-bit loads supported"); - } - TPU_ASSERT_OP(to_store_layout == - VectorLayout(32, {0, 0}, ctx.target_shape, - VectorLayout::ImplicitDim::kNone)); - tpu::StoreOp store_op = cast(op); - FAILUREOR_ASSIGN_OR_RETURN( - const SmallVector indices, - getIntConstsFromOperandRange(store_op.getIndices())); - TPU_ASSERT_EQ_OP(indices.size(), 2); - if (indices[1] % ctx.target_shape[1] != 0) { - return op.emitOpError("Not implemented: Lane index is not a multiple of ") - << ctx.target_shape[1]; - } - FAILUREOR_ASSIGN_OR_RETURN( - xla::Array tiles, - disassemble(builder, to_store_layout, store_op.getValueToStore(), - ctx.target_shape)); - TPU_ASSERT_OP((tiles.dimensions() == xla::DimensionVector{1, 1})); - store_op.getValueToStoreMutable().assign(tiles({0, 0})); - return success(); -} - -LogicalResult tpu_bitcast_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - TPU_ASSERT_EQ_OP(layouts_in.size(), 1); - TPU_ASSERT_EQ_OP(layouts_out.size(), 1); - TPU_ASSERT_OP(layouts_in.front().has_value()); - TPU_ASSERT_OP(layouts_out.front().has_value()); - const VectorLayout &in_layout = *layouts_in.front(); - const VectorLayout &out_layout = *layouts_out.front(); - auto in_bitwidth = in_layout.bitwidth(); - auto out_bitwidth = out_layout.bitwidth(); - auto in_tiling = in_layout.tiling(); - auto out_tiling = out_layout.tiling(); - in_tiling[0] *= in_bitwidth; - out_tiling[0] *= out_bitwidth; - if (in_tiling != out_tiling) { - return op.emitOpError( - "Expected tilings are the same after multiplying the " - "second-minor dimension by the ratio of bitwidths."); - } - auto in_offsets = in_layout.offsets(); - auto out_offsets = out_layout.offsets(); - if (!out_offsets[0].has_value() && in_bitwidth > out_bitwidth) { - return op.emitOpError( - "Expected no replicated offset on 2nd minor dimension of output when " - "bitwidth is decreased."); - } - if (in_offsets[0].has_value() != out_offsets[0].has_value() || - in_offsets[0].value_or(0) * in_bitwidth != - out_offsets[0].value_or(0) * out_bitwidth || - in_offsets[1] != out_offsets[1]) { - return op.emitOpError( - "Expected offsets are the same after multiplying the " - "second-minor dimension by the ratio of bitwidths."); - } - if (in_layout.implicit_dim() != out_layout.implicit_dim()) { - return op.emitOpError( - "Expected same implicit dim for input and output layout"); - } - auto bitcast_op = cast(op); - const auto out_ty = bitcast_op.getResult().getType(); - if (in_bitwidth != out_bitwidth) { - if (in_layout.implicit_dim() != VectorLayout::ImplicitDim::kNone) { - return op.emitOpError("Expected no implicit dim when bitwidth changes"); - } - } - ImplicitLocOpBuilder builder(op.getLoc(), &op); - const auto native_vreg_ty = - getNativeVregType(out_ty.getElementType(), ctx.target_shape); - FAILUREOR_ASSIGN_OR_RETURN( - const xla::Array in_tiles, - disassemble(builder, in_layout, bitcast_op.getInput(), ctx.target_shape)); - xla::Array out_tiles(in_tiles.dimensions()); - out_tiles.Each([&](absl::Span idxs, Value *v) { - const Value in_tile = in_tiles(idxs); - *v = builder.create(native_vreg_ty, in_tile); - }); - bitcast_op.replaceAllUsesWith( - assemble(builder, out_ty, out_layout, out_tiles, ctx.target_shape) - .getOperation()); - bitcast_op.erase(); - return success(); -} - -LogicalResult tpu_trace_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - if (op.getNumOperands() != 0 || op.getNumResults() != 0) { - return op.emitOpError( - "Not implemented: tpu.traced_block with inputs or outputs"); - } - TPU_ASSERT_EQ_OP(layouts_in.size(), 0); - TPU_ASSERT_EQ_OP(layouts_out.size(), 0); - // We don't modify the op, but we do rewrite the branch bodies. - TPU_ASSERT_EQ_OP(op.getNumRegions(), 1); - Region ®ion = op.getRegion(0); - TPU_ASSERT_OP(region.hasOneBlock()); - Block &block = region.front(); - return applyLayoutBlock(ctx, block); -} - -LogicalResult tpu_assume_layout_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - TPU_ASSERT_EQ_OP(op.getNumOperands(), 1); - TPU_ASSERT_EQ_OP(op.getNumResults(), 1); - TPU_ASSERT_EQ_OP(layouts_in.size(), 1); - TPU_ASSERT_EQ_OP(layouts_out.size(), 1); - if (layouts_in[0] != layouts_out[0]) { - return op.emitOpError("Expected same input and output layout"); - } - OpBuilder builder(&op); - auto val = op.getOperand(0); - auto layout = layouts_in[0]; - const auto vty = cast(val.getType()); - SmallVector layout_shape = - layout->tileArrayShape(vty.getShape(), ctx.target_shape); - const int64_t num_vectors = ShapedType::getNumElements(layout_shape); - VectorType vreg_ty = - getNativeVregType(vty.getElementType(), ctx.target_shape); - // We can not use disassemble here because the val is block argument. - auto unrolled_op = builder.create( - val.getLoc(), SmallVector(num_vectors, vreg_ty), val); - - op.replaceAllUsesWith(assemble(builder, vty, *layout, - XlaArrayFromShapeAndValues( - layout_shape, unrolled_op->getResults()), - ctx.target_shape)); - op.erase(); - return success(); -} - -LogicalResult tpu_relayout_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - TPU_ASSERT_EQ_OP(op.getNumOperands(), 1); - TPU_ASSERT_EQ_OP(op.getNumResults(), 1); - TPU_ASSERT_EQ_OP(layouts_in.size(), 1); - TPU_ASSERT_EQ_OP(layouts_out.size(), 1); - TPU_ASSERT_OP(layouts_in[0].has_value()); - TPU_ASSERT_OP(layouts_out[0].has_value()); - const auto& in_layout = *layouts_in[0]; - const auto& out_layout = *layouts_out[0]; - auto realyout_op = cast(op); - auto in_bitwidth = in_layout.bitwidth(); - auto out_bitwidth = out_layout.bitwidth(); - auto vty = cast(realyout_op.getType()); - ImplicitLocOpBuilder builder(op.getLoc(), &op); - if (in_layout == out_layout) { - realyout_op.replaceAllUsesWith(realyout_op.getInput()); - realyout_op.erase(); - return success(); - } - FAILUREOR_ASSIGN_OR_RETURN( - xla::Array vals, - disassemble(builder, in_layout, - cast>(realyout_op.getInput()), - ctx.target_shape, - /*use_implicit_shape=*/true)); - // Packing vector masks from 32-bit to 16-bit. - if (vty.getElementType() == builder.getI1Type() && in_bitwidth == 32 && - out_bitwidth == 16 && - in_layout.tiling()[0] == in_layout.packing() * ctx.target_shape[0] && - in_layout.tiling()[1] == ctx.target_shape[1] && - in_layout.tiling() == out_layout.tiling() && - in_layout.offsets() == out_layout.offsets() && - in_layout.implicit_dim() == out_layout.implicit_dim()) { - std::vector vmsks_shape(vals.dimensions().begin(), - vals.dimensions().end()); - *(vmsks_shape.end() - 1) = llvm::divideCeil(vmsks_shape.back(), 2); - xla::Array out_vmsks(vmsks_shape, nullptr); - SmallVector val_idx; - Value default_val = - getFullLikeVector(builder, cast>(*vals.begin()), - IntegerAttr::get(builder.getI1Type(), 0)); - out_vmsks.Each([&](absl::Span idx, Value *v) { - val_idx.assign(idx.begin(), idx.end()); - // TODO(jevinjiang): can be simplified when offset is replicated. - *(val_idx.end() - 1) *= 2; - Value low_part = *(val_idx.end() - 1) < *(vals.dimensions().end() - 1) - ? vals(val_idx) - : default_val; - *(val_idx.end() - 1) += 1; - Value high_part = *(val_idx.end() - 1) < *(vals.dimensions().end() - 1) - ? vals(val_idx) - : default_val; - const VectorType mask_ty = getNativeVregOrVmaskType( - builder.getI1Type(), in_bitwidth / 2, ctx.target_shape); - *v = builder.create(mask_ty, low_part, high_part); - }); - const RollVectorsOp rolled_op = - assemble(builder, vty, out_layout, out_vmsks, ctx.target_shape, - /*use_implicit_shape=*/true); - op.replaceAllUsesWith(rolled_op); - op.erase(); - return success(); - } - return op.emitOpError("Not implemented: unsupported layout change"); -} - -// TODO(b/347016737): Deprecate tpu.rotate and only use tpu.dynamic_rotate. So -// we do not need template for the op type and to explicitly force amount -// argument to dynamic. -template -LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount, - const VectorLayout &layout_in, - const VectorLayout &layout_out) { - auto layout = VectorLayout(32, {0, 0}, ctx.target_shape, - VectorLayout::ImplicitDim::kNone); - if (layout_in != layout) { - return op.emitOpError("Not implemented: unsupported layout for input"); - } - if (layout_out != layout) { - return op.emitOpError("Not implemented: unsupported layout for output"); - } - auto vty = op.getResult().getType(); - if (vty.getRank() < 2) { - return op.emitOpError("Not implemented: unsupported 1D shape"); - } - if (*(vty.getShape().end() - 2) % *(layout.tiling().end() - 2) != 0 || - *(vty.getShape().end() - 1) % *(layout.tiling().end() - 1) != 0) { - return op.emitOpError("Not implemented: unsupported unaliged shape"); - } - - ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation()); - - VectorType res_vreg_ty = - getNativeVregType(vty.getElementType(), ctx.target_shape); - FAILUREOR_ASSIGN_OR_RETURN( - const xla::Array in_tiles, - disassemble(builder, layout_in, op.getValue(), ctx.target_shape)); - - const VectorType i32_vreg = - getNativeVregType(builder.getI32Type(), ctx.target_shape); - - // Some helper functions for math ops. - auto mlirI32Const = [&](int d) { - return builder.create( - builder.getIntegerAttr(builder.getI32Type(), d)); - }; - auto mlirIndexConst = [&](int d) { - return builder.create( - builder.getIntegerAttr(builder.getIndexType(), d)); - }; - auto modI = [&](const Value &v, unsigned d) -> Value { - if (auto cst = getIntConst(v, /*silent=*/true); succeeded(cst)) { - return mlirI32Const(cst.value() % d); - } - return builder.create(v, mlirI32Const(d)); - }; - auto divI = [&](const Value &v, unsigned d) -> Value { - if (auto cst = getIntConst(v, /*silent=*/true); succeeded(cst)) { - return mlirI32Const(cst.value() / d); - } - return builder.create(v, mlirI32Const(d)); - }; - auto addI = [&](const Value &v, unsigned d) -> Value { - if (auto cst = getIntConst(v, /*silent=*/true); succeeded(cst)) { - return mlirI32Const(cst.value() + d); - } - return builder.create(v, mlirI32Const(d)); - }; - - // A helper function that creates a VMASK with false flags to bottom (dim = 0) - // or right (dim = 1) where the flag count corresponds to the (dim_size - - // padding). If stride is provided, the padding value is sequentially - // increased by the stride value along the dim. - // - // For example, assume VMASK shape is (4, 8) - // - // getVmaskByPaddingEnd(padding=3, dim=1) creates: - // [T, T, T, T, T, F, F, F] - // [T, T, T, T, T, F, F, F] - // [T, T, T, T, T, F, F, F] - // [T, T, T, T, T, F, F, F] - // - // getVmaskByPaddingEnd(padding=3, dim=1, stride=1) creates: - // [T, T, T, T, T, F, F, F] - // [T, T, T, T, T, T, F, F] - // [T, T, T, T, T, T, T, F] - // [T, T, T, T, T, T, T, T] - auto getVmaskByPaddingEnd = [&](Value padding, int dim, int stride = 0) { - CHECK(dim == 0 || dim == 1); - Value padding_vreg; - if (auto padding_cst = getIntConst(padding, /*silent=*/true); - succeeded(padding_cst)) { - CHECK_GE(padding_cst.value(), 0); - CHECK_LE(padding_cst.value(), ctx.target_shape[dim]); - padding_vreg = builder.create(DenseElementsAttr::get( - i32_vreg, builder.getI32IntegerAttr(padding_cst.value()))); - } else { - padding_vreg = builder.create(i32_vreg, padding); - } - - if (stride > 0) { - auto offset = builder.create( - i32_vreg, - builder.create( - i32_vreg, builder.getI32IntegerAttr(dim == 0 ? 1 : 0)), - builder.create(DenseElementsAttr::get( - i32_vreg, builder.getI32IntegerAttr(stride)))); - padding_vreg = - builder.create(i32_vreg, padding_vreg, offset); - } - return builder.create( - arith::CmpIPredicate::slt, - builder.create(i32_vreg, builder.getI32IntegerAttr(dim)), - padding_vreg); - }; - - // Apply rotation on each vreg with the assumption that shift <= VREG dim size - // and blend the data from contiguous vregs to emulate circular rotation. - auto rotateOnTilingDim = [&](const xla::Array &vregs, - const Value &shift, int axis, int stride = 0) { - if (auto shift_cst = getIntConst(shift, /*silent=*/true); - succeeded(shift_cst)) { - if (shift_cst.value() == 0 && stride == 0) { - return vregs; - } - } - int tiling_dim = axis - (vregs.num_dimensions() - 2); - CHECK((tiling_dim == 0 && stride == 0) || (tiling_dim == 1 && stride >= 0)); - xla::Array result(vregs.dimensions()); - auto chunks = split(vregs, axis); - for (int64_t i = 0; i < chunks.size(); ++i) { - chunks[i].Each([&](absl::Span idxs, Value *v) { - auto stride_attr = - stride > 0 ? builder.getSI32IntegerAttr(stride) : nullptr; - auto stride_dimension_attr = - stride > 0 ? builder.getSI32IntegerAttr(0) : nullptr; - *v = builder.create(res_vreg_ty, *v, shift, - tiling_dim, stride_attr, - stride_dimension_attr); - }); - } - auto mask = getVmaskByPaddingEnd(shift, tiling_dim, stride); - xla::Array last_chunk_copy(chunks[chunks.size() - 1]); - for (int64_t i = chunks.size() - 1; i > 0; --i) { - chunks[i].Each([&](absl::Span idxs, Value *v) { - *v = builder.create(mask, chunks[i - 1](idxs), *v); - }); - } - chunks[0].Each([&](absl::Span idxs, Value *v) { - *v = builder.create(mask, last_chunk_copy(idxs), *v); - }); - return concatenate(chunks, axis); - }; - - std::function(const xla::Array &, Value, int, int)> - rotate; - rotate = [&](const xla::Array &vregs, Value shift, int axis, - int stride) { - xla::Array result(vregs.dimensions()); - CHECK(axis >= 0 && axis < vregs.num_dimensions()); - int tiling_dim = axis - (vregs.num_dimensions() - 2); - CHECK((tiling_dim != 1 && stride == 0) || (tiling_dim == 1 && stride >= 0)); - SmallVector, 4> chunks; - // Handle rotation with static shift. - if (auto shift_cst = getIntConst(shift, /*silent=*/true); - succeeded(shift_cst)) { - int64_t static_shift = shift_cst.value(); - if (tiling_dim >= 0) { - shift = mlirI32Const(static_shift % ctx.target_shape[tiling_dim]); - static_shift /= ctx.target_shape[tiling_dim]; - chunks = split(rotateOnTilingDim(vregs, shift, axis, stride), axis); - } else { - chunks = split(vregs, axis); - } - // Now we only need to shuffle vregs. - for (int64_t i = 0; i < chunks.size(); ++i) { - SmallVector starts(result.num_dimensions(), 0); - starts[axis] = (i + static_shift) % result.dim(axis); - result.UpdateSlice(chunks[i], starts); - } - return result; - } - // Handle rotation with dynamic shift. - // TODO(jevinjiang): consider optimize with assume_multiple op. - Value in_vreg_shift = tiling_dim >= 0 - ? modI(shift, ctx.target_shape[tiling_dim]) - : mlirI32Const(0); - Value vreg_shift = - tiling_dim >= 0 ? divI(shift, ctx.target_shape[tiling_dim]) : shift; - result = tiling_dim >= 0 - ? rotateOnTilingDim(vregs, in_vreg_shift, axis, stride) - : vregs; - int bound = vregs.dim(axis); - if (bound <= ctx.max_sublanes_in_scratch / ctx.target_shape[0] && - bound >= kMinBoundToRotateWithScratch) { - // Use static store + dynamic load to implement dynamic shift. - if (auto scratch_ref = getInternalScratch( - ctx, builder, op.getLoc(), - {ctx.max_sublanes_in_scratch / ctx.target_shape[0], - ctx.target_shape[0], ctx.target_shape[1]}, - vty.getElementType()); - succeeded(scratch_ref)) { - auto cst_0 = mlirIndexConst(0); - SmallVector scratch_indices(3, cst_0); - SmallVector sublane_mask(ctx.target_shape[0], true); - const auto sublane_mask_attr = - DenseBoolArrayAttr::get(op.getContext(), sublane_mask); - chunks = split(result, axis); - chunks[0].Each([&](absl::Span idxs, Value *v) { - // Static store vregs. - for (int i = 0; i < bound; ++i) { - scratch_indices[0] = mlirIndexConst(i); - builder.create(chunks[i](idxs), scratch_ref.value(), - scratch_indices, sublane_mask_attr, - /*mask=*/nullptr, - /*sublane_stride=*/nullptr); - } - // Dynamic load vregs back from a circular buffer. - for (int i = 0; i < bound; ++i) { - scratch_indices[0] = builder.create( - builder.getIndexType(), - modI(builder.create(mlirI32Const(bound + i), - vreg_shift), - bound)); - chunks[i](idxs) = - builder.create(v->getType(), scratch_ref.value(), - scratch_indices, sublane_mask_attr, - /*sublane_stride=*/nullptr); - } - }); - return concatenate(chunks, axis); - } - } - // Convert dynamic shift to log(bound) static ops. - int roll_by = 1; - while (roll_by < bound) { - auto new_result = rotate( - result, - mlirI32Const(tiling_dim >= 0 ? roll_by * ctx.target_shape[tiling_dim] - : roll_by), - axis, /*stride=*/0); - auto mask = builder.create( - arith::CmpIPredicate::ne, - builder.create( - i32_vreg, - builder.create(vreg_shift, mlirI32Const(roll_by))), - builder.create( - DenseElementsAttr::get(i32_vreg, builder.getI32IntegerAttr(0)))); - result.Each([&](absl::Span idxs, Value *v) { - *v = builder.create(mask, new_result(idxs), *v); - }); - roll_by *= 2; - } - return result; - }; - - xla::Array out_tiles(in_tiles.dimensions()); - const auto dim = op.getDimension(); - amount = modI(amount, vty.getDimSize(dim)); - - if (op.getStride().has_value() && op.getStrideDimension().has_value()) { - auto stride_dim = op.getStrideDimension().value(); - auto stride = op.getStride().value() % vty.getDimSize(stride_dim); - if (stride_dim == dim) { - return op.emitOpError( - "Expected rotation dimension and stride dimension are not equal"); - } - if (stride_dim == vty.getRank() - 1) { - return op.emitOpError( - "Not implemented: stride dimension is the minor most"); - } else if (stride_dim == vty.getRank() - 2) { - if (dim != vty.getRank() - 1 || ctx.hardware_generation < 5) { - return op.emitOpError( - "Not implemented: only supported in TPU v5+ and rotation dimension " - "is the minor most when stride dimension is the second minor most"); - } - CHECK_GE(stride, 0); - auto chunks = split(in_tiles, stride_dim); - for (int64_t i = 0; i < chunks.size(); ++i) { - Value base_amount = modI(addI(amount, ctx.target_shape[0] * i * stride), - vty.getDimSize(dim)); - // After applying stride, we expect all shifts in a vreg are less or - // equal to the vreg's lane count for now. - if (auto base_amount_cst = getIntConst(base_amount, /*silent=*/true); - succeeded(base_amount_cst)) { - int64_t static_base_amount = base_amount_cst.value(); - auto max_shift_in_vreg = static_base_amount % ctx.target_shape[1] + - (ctx.target_shape[0] - 1) * stride; - if (max_shift_in_vreg > ctx.target_shape[1]) { - return op.emitOpError("Not implemented: the max shift in a vreg ") - << max_shift_in_vreg << " is larger than the vreg's width " - << ctx.target_shape[1]; - } - } - SmallVector starts(out_tiles.num_dimensions(), 0); - starts[stride_dim] = i; - out_tiles.UpdateSlice(rotate(chunks[i], base_amount, dim, stride), - starts); - } - } else { - // Split vregs along the stride dimension. - auto chunks = split(in_tiles, stride_dim); - for (int64_t i = 0; i < chunks.size(); ++i) { - SmallVector starts(out_tiles.num_dimensions(), 0); - starts[stride_dim] = i; - out_tiles.UpdateSlice( - rotate(chunks[i], addI(amount, i * stride), dim, /*stride=*/0), - starts); - } - } - } else { // No stride. - out_tiles = rotate(in_tiles, amount, dim, /*stride=*/0); - } - - const RollVectorsOp rolled_op = - assemble(builder, op.getResult().getType(), layout_out, out_tiles, - ctx.target_shape); - op.replaceAllUsesWith(rolled_op); - op.erase(); - return success(); -} - -// TODO(b/347016737): deprecate the static rotate. -LogicalResult tpu_rotate_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - CHECK_EQ(layouts_in.size(), 1); - CHECK_EQ(layouts_out.size(), 1); - if (!layouts_in.front().has_value()) { - return op.emitOpError("Expected non-null input layout"); - } - if (!layouts_out.front().has_value()) { - return op.emitOpError("Expected non-null output layout"); - } - auto rotate_op = cast(op); - if (rotate_op.getAmount() < 0) { - return op.emitOpError("Not implemented: shifting by negative amount"); - } - ImplicitLocOpBuilder builder(op.getLoc(), &op); - Value shift = builder.create( - builder.getIntegerAttr(builder.getI32Type(), rotate_op.getAmount())); - const VectorLayout &layout_in = *layouts_in.front(); - const VectorLayout &layout_out = *layouts_out.front(); - return rotate_rule_impl(ctx, rotate_op, shift, layout_in, layout_out); -} - -LogicalResult tpu_dynamic_rotate_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - CHECK_EQ(layouts_in.size(), 2); - CHECK_EQ(layouts_out.size(), 1); - if (!layouts_in.front().has_value()) { - return op.emitOpError("Expected non-null layout for the value to rotate"); - } - if (layouts_in[1].has_value()) { - return op.emitOpError("Expected null layout for the shift"); - } - if (!layouts_out.front().has_value()) { - return op.emitOpError("Expected non-null output layout"); - } - auto rotate_op = cast(op); - const VectorLayout &layout_in = *layouts_in.front(); - const VectorLayout &layout_out = *layouts_out.front(); - return rotate_rule_impl(ctx, rotate_op, rotate_op.getAmount(), layout_in, - layout_out); -} - -LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - TPU_ASSERT_EQ_OP(layouts_in.size(), op.getNumOperands()); - TPU_ASSERT_EQ_OP(layouts_out.size(), 1); - TPU_ASSERT_OP( - llvm::all_of(layouts_in, [](const Layout &l) { return l.has_value(); })); - TPU_ASSERT_OP(layouts_out.front().has_value()); - OpBuilder builder(&op); - auto concatenate_op = cast(op); - const VectorType res_ty = concatenate_op.getResult().getType(); - uint32_t dimension = concatenate_op.getDimension(); - SmallVector> operand_vregs; - operand_vregs.reserve(op.getNumOperands()); - - std::optional tiling_dim; - auto res_layout = layouts_out.front(); - - TPU_ASSERT_OP(res_layout.has_value()); - auto num_untiled_dims = res_ty.getRank() - res_layout->layout_rank(); - - if (res_ty.getRank() == 1 && - res_layout->implicit_dim() == VectorLayout::ImplicitDim::kSecondMinor) { - tiling_dim = 1; - } else if (dimension >= num_untiled_dims) { - tiling_dim = dimension - num_untiled_dims; - } - - // Op level invariants on layouts, other op level invariants are checked in - // the verifier. - auto res_tiling = res_layout->tiling(); - for (int i = 0; i < op.getNumOperands(); ++i) { - auto operand = op.getOperand(i); - if (!layouts_in[i].has_value()) { - return op.emitOpError("Not implemented: Expected input layout"); - } - auto const &layout = *layouts_in[i]; - - if (layout.tiling() != res_tiling) { - return op.emitOpError("Not implemented: result/input Tiling mismatch"); - } - - if (layout.implicit_dim() != res_layout->implicit_dim()) { - return op.emitOpError("Not implemented: result/input offsets mismatch."); - } - - if (layout.implicit_dim() != res_layout->implicit_dim()) { - return op.emitOpError( - "Not implemented: result/input implicit dim mismatch."); - } - - if (i > 1) { - auto curr_offsets = layout.offsets(); - auto last_operand_offsets = layouts_in[i - 1]->offsets(); - if (tiling_dim.has_value()) { - // Zero out the offset in the tiling dimension for verification. - curr_offsets[tiling_dim.value()] = 0; - last_operand_offsets[tiling_dim.value()] = 0; - } - if (curr_offsets != last_operand_offsets) { - op.emitOpError("Not implemented: non-concat dim offset mismatch."); - } - } - - FAILUREOR_ASSIGN_OR_RETURN( - xla::Array vreg_array, - disassemble(builder, layout, cast>(operand), - ctx.target_shape)); - operand_vregs.push_back(std::move(vreg_array)); - } - - CHECK_EQ(operand_vregs.size(), op.getNumOperands()); - SmallVector vreg_array_shape = - res_layout->tileArrayShape(res_ty.getShape(), ctx.target_shape); - - // Fill out out_vregs with nulls, to avoid a problem with where we have to - // blend with a vreg that has not been written to yet. - xla::Array out_vregs(vreg_array_shape, nullptr); - - auto boundIdxConst = - std::bind(IdxConst, std::placeholders::_1, builder, op.getLoc()); - - // Handle the untiled concatenation case. - if (!tiling_dim.has_value()) { - out_vregs = concatenate(operand_vregs, dimension); - } else { - bool is_rank1_with_no_implicit_dim = res_ty.getRank() == 1 && - res_layout->implicit_dim() == - VectorLayout::ImplicitDim::kNone; - if (res_layout->implicit_dim() == VectorLayout::ImplicitDim::kMinor || - is_rank1_with_no_implicit_dim) { - return op.emitOpError("Not implemented: implicit dim"); - } - if (res_layout->implicit_dim() == VectorLayout::ImplicitDim::kSecondMinor && - res_layout->bitwidth() != 32) { - return op.emitOpError( - "Not implemented: only 32-bit bitwidth supported for SecondMinor " - "implicit dim"); - } - if (res_layout->offsets()[tiling_dim.value()] != 0) { - return op.emitOpError("Not implemented: result non-zero offset."); - } - if (!res_layout->hasNativeTiling(ctx.target_shape) && - res_ty.getRank() != 1) { - return op.emitOpError("Not implemented: Non native tiling in concat."); - } - - int64_t offset_at_dim = 0; - { - for (int i = 0; i < op.getNumOperands(); ++i) { - Value operand = op.getOperand(i); - const Layout &layout = *layouts_in[i]; - xla::Array vreg_array = operand_vregs[i]; - std::array vreg_slice = layout->vregSlice(ctx.target_shape); - std::array tiling = layout->tiling(); - - VectorType vty = cast(operand.getType()); - ArrayRef shape = vty.getShape(); - - int64_t starting_point = offset_at_dim; - int64_t offset_amount = - starting_point % vreg_slice[tiling_dim.value()]; - if (offset_amount >= tiling[tiling_dim.value()]) { - return op.emitError( - "Not implemented: Input offsets outside of the first tile"); - } - if (offset_amount != layout->offsets()[tiling_dim.value()]) { - return op.emitOpError( - "Not implemented: Relayout not called, unaligned dims " - "concatenated without proper offsets. Ensure that " - "infer_vector_layout pass was called."); - } - offset_at_dim += shape[dimension]; - } - } - - // Tiled concatenation logic. - int64_t offset = 0; - for (size_t i = 0; i < operand_vregs.size(); ++i) { - auto &vreg = operand_vregs[i]; - const auto &layout = layouts_in[i]; - const int packing = res_layout->packing(); - - if (layout->tiling()[0] % packing != 0) { - return op.emitOpError( - "Illegal tiling: Non-native tiling in concat - this should " - "have been caught earlier!"); - } - - const int64_t operand_offset = *layout->offsets()[tiling_dim.value()]; - if (operand_offset != 0) { - // We are offset, so we must blend with the previous vreg. - // Or, to frame it in an another way, the prior vreg - // stored its entire dim size in the offset, but only wrote the - // last dime partially. - offset -= 1; - } - - const auto bitwidth = res_ty.getElementTypeBitWidth(); - SmallVector out_idx; - vreg.Each([&](absl::Span idx, Value *v) { - out_idx.assign(idx.begin(), idx.end()); - out_idx[dimension] += offset; - if (idx[dimension] == 0 && operand_offset != 0) { - Value mask; - const VectorType vmask_ty = getNativeVregOrVmaskType( - builder.getI1Type(), bitwidth, ctx.target_shape); - if (tiling_dim.value() == 0) { // sublane - if (operand_offset % packing != 0) { - // Packed case, degenerate where we have a half or quarter - // sublane. - // TODO(mvoz): We can probably always use the - // CreateSubelementMaskOp if (1) optimize it on TPUv4 and (2) Add - // support for unpacked types in some of the invariants in - // lower_to_llo. - mask = builder.create( - op.getLoc(), vmask_ty, 0, operand_offset); - } else { - auto sublane_offset = operand_offset / packing; - mask = builder.create( - op.getLoc(), vmask_ty, - ArrayRef{boundIdxConst(0), boundIdxConst(0)}, - ArrayRef{boundIdxConst(sublane_offset), - boundIdxConst(layout->tiling()[1])}); - } - } else { // lane - mask = builder.create( - op.getLoc(), vmask_ty, - ArrayRef{boundIdxConst(0), boundIdxConst(0)}, - ArrayRef{boundIdxConst(layout->tiling()[0] / packing), - boundIdxConst(operand_offset)}); - } - // Blend the current value with the existing value in the output. - *v = builder.create(op.getLoc(), mask, - out_vregs(out_idx), *v); - } - out_vregs(out_idx) = *v; - }); - offset += vreg.dim(dimension); - } - } - auto assembled = - assemble(builder, res_ty, *res_layout, out_vregs, ctx.target_shape); - op.replaceAllUsesWith(assembled); - op.erase(); - return success(); -} - -LogicalResult tpu_iota_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - TPU_ASSERT_EQ_OP(layouts_in.size(), 0); - TPU_ASSERT_EQ_OP(layouts_out.size(), 1); - TPU_ASSERT_OP(layouts_out.front().has_value()); - const VectorLayout &layout_out = *layouts_out.front(); - ImplicitLocOpBuilder builder(op.getLoc(), &op); - tpu::IotaOp iota_op = cast(op); - VectorType vty = iota_op.getResult().getType(); - if (const auto int_ty = dyn_cast(vty.getElementType()); - int_ty == nullptr || int_ty.getWidth() != 32) { - return iota_op.emitOpError("Not implemented: Only 32-bit Iota supported"); - } - if (!layout_out.hasNativeTiling(ctx.target_shape)) { - return iota_op.emitOpError("Not implemented: Only native tiling supported"); - } - - const auto native_vreg_ty = - getNativeVregType(vty.getElementType(), ctx.target_shape); - if (layout_out.implicit_dim() != VectorLayout::ImplicitDim::kNone) { - return op.emitOpError("Not implemented: Only 2D layouts supported"); - } - const SmallVector tile_array_shape = - layout_out.tileArrayShape(vty.getShape(), ctx.target_shape); - const std::optional dimension = iota_op.getDimension(); - if (!dimension.has_value()) { - return op.emitOpError("Not implemented: null dimension"); - } - if (*dimension == vty.getRank() - 1) { - if (layout_out.offsets()[1] != 0) { - return op.emitOpError("Not implemented: Unsupported offset"); - } - const int64_t num_tiles = tile_array_shape[tile_array_shape.size() - 1]; - SmallVector tiles(num_tiles); - auto vreg_iota = builder.create( - native_vreg_ty, - /*dimension =*/builder.getI32IntegerAttr(1)); - for (int64_t i = 0; i < num_tiles; ++i) { - Value offset = getFullVector( - builder, native_vreg_ty, - IntegerAttr::get(vty.getElementType(), - i * *(native_vreg_ty.getShape().end() - 1))); - tiles[i] = builder.create(vreg_iota, offset); - } - xla::Array broadcasted_tiles(tile_array_shape); - broadcasted_tiles.Each([&](absl::Span idxs, Value *v) { - *v = tiles[*(idxs.end() - 1)]; - }); - op.replaceAllUsesWith(assemble(builder, vty, layout_out, broadcasted_tiles, - ctx.target_shape)); - op.erase(); - return success(); - } - if (*dimension == vty.getRank() - 2) { - if (layout_out.offsets()[0] != 0) { - return op.emitOpError("Not implemented: Unsupported offset"); - } - const int64_t num_tiles = tile_array_shape[tile_array_shape.size() - 2]; - SmallVector tiles(num_tiles); - auto vreg_iota = builder.create( - native_vreg_ty, - /*dimension =*/builder.getI32IntegerAttr(0)); - for (int64_t i = 0; i < num_tiles; ++i) { - Value offset = getFullVector( - builder, native_vreg_ty, - IntegerAttr::get(vty.getElementType(), - i * *(native_vreg_ty.getShape().end() - 2))); - tiles[i] = builder.create(vreg_iota, offset); - } - xla::Array broadcasted_tiles(tile_array_shape); - broadcasted_tiles.Each([&](absl::Span idxs, Value *v) { - *v = tiles[*(idxs.end() - 2)]; - }); - op.replaceAllUsesWith(assemble(builder, vty, layout_out, broadcasted_tiles, - ctx.target_shape)); - op.erase(); - return success(); - } - // We take the iota over an untiled dimension. - CHECK_LT(*dimension, vty.getRank()); - SmallVector tiles; - tiles.reserve(vty.getDimSize(*dimension)); - for (int64_t i = 0; i < vty.getDimSize(*dimension); ++i) { - tiles.push_back(getFullVector(builder, native_vreg_ty, - IntegerAttr::get(vty.getElementType(), i))); - } - xla::Array out_tiles(tile_array_shape); - out_tiles.Each([&](absl::Span idxs, Value *v) { - *v = tiles[idxs[*dimension]]; - }); - op.replaceAllUsesWith( - assemble(builder, vty, layout_out, out_tiles, ctx.target_shape)); - op.erase(); - return success(); -} - -LogicalResult tpu_gather_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - TPU_ASSERT_EQ_OP(layouts_in.size(), 1); - TPU_ASSERT_EQ_OP(layouts_out.size(), 1); - TPU_ASSERT_OP(layouts_in.front().has_value()); - TPU_ASSERT_OP(layouts_out.front().has_value()); - const VectorLayout &layout_in = *layouts_in.front(); - const VectorLayout &layout_out = *layouts_out.front(); - if (layout_in.implicit_dim() != VectorLayout::ImplicitDim::kNone || - layout_out.implicit_dim() != VectorLayout::ImplicitDim::kNone || - layout_in.offsets() != layout_out.offsets() || - llvm::any_of(layout_in.offsets(), [&](const LayoutOffset o) { - return o.has_value() && o != 0; - })) { - return op.emitOpError("Not implemented: Only 2D layouts supported"); - } - ImplicitLocOpBuilder builder(op.getLoc(), &op); - auto gather_op = cast(op); - const VectorType vty = gather_op.getResult().getType(); - const uint32_t dimension = gather_op.getDimension(); - if (dimension + 2 < vty.getRank()) { - return op.emitOpError("Not implemented: Unsupported dimension"); - } - FAILUREOR_ASSIGN_OR_RETURN( - const xla::Array in_tiles, - disassemble(builder, layout_in, gather_op.getSource(), ctx.target_shape)); - const int64_t width = ctx.target_shape[2 - (vty.getRank() - dimension)]; - const ArrayRef indices(gather_op.getIndices()); - auto [num_sections, rem] = std::div(indices.size(), width); - SmallVector segment_indices; - if (rem == 0) { - for (int64_t i = 0; i < width; ++i) { - const int64_t offset = i - i % width; - if (!(offset <= indices[i] && indices[i] < offset + width)) { - return op.emitOpError("Not implemented: Cross-segment gather"); - } - } - for (int64_t i = width; i < indices.size(); ++i) { - const int64_t offset = i - i % width; - if (indices[i] != indices[i % width] + offset) { - return op.emitOpError( - "Not implemented: Indices varying between segments"); - } - } - segment_indices.assign(indices.begin(), indices.begin() + width); - } else if (num_sections == 0) { // Only one vreg. - segment_indices.assign(indices.begin(), indices.end()); - segment_indices.append(width - indices.size(), 0); - } else { - return op.emitOpError("Not implemented: Not a multiple of target length"); - } - xla::Array out_tiles(in_tiles.dimensions()); - if (dimension == vty.getRank() - 1) { - // TODO(b/265133497): Remove the broadcast once 2nd minor works. - const auto dyn_ix_ty = - VectorType::get(ctx.target_shape, builder.getI32Type()); - // Broadcast indices to target_shape - SmallVector dyn_ix_val; - for (int64_t i = 0; i < ctx.target_shape[0]; ++i) { // Broadcast - dyn_ix_val.append(segment_indices); - } - auto func_op = op.getParentOfType(); - if (!func_op) { - return op.emitOpError("Expected a function op"); - } - FAILUREOR_ASSIGN_OR_RETURN( - const BlockArgument dyn_ix_ref, - appendConstant(ctx, func_op, - DenseIntElementsAttr::get(dyn_ix_ty, dyn_ix_val))); - auto all_sublanes = builder.getAttr( - SmallVector(ctx.target_shape[1], true)); - auto dyn_ix = builder.create( - dyn_ix_ty, dyn_ix_ref, - SmallVector(2, IdxConst(0, builder, op.getLoc())), - /*sublane_mask=*/all_sublanes, /*sublane_stride=*/nullptr); - out_tiles.Each([&](absl::Span idxs, Value *v) { - const Value in_tile = in_tiles(idxs); - *v = builder.create(in_tile.getType(), in_tile, - dyn_ix, 1); - }); - } else { - TPU_ASSERT_EQ_OP(dimension, vty.getRank() - 2); - const auto segment_indices_attr = - builder.getAttr(segment_indices); - out_tiles.Each([&](absl::Span idxs, Value *v) { - const Value in_tile = in_tiles(idxs); - *v = builder.create(in_tile.getType(), in_tile, - segment_indices_attr, 0); - }); - } - gather_op.replaceAllUsesWith( - assemble(builder, vty, layout_out, out_tiles, ctx.target_shape) - .getOperation()); - gather_op.erase(); - return success(); -} - -LogicalResult tpu_dynamic_gather_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - TPU_ASSERT_EQ_OP(layouts_in.size(), 2); - TPU_ASSERT_EQ_OP(layouts_out.size(), 1); - TPU_ASSERT_OP(layouts_in[0].has_value()); - TPU_ASSERT_OP(layouts_in[1].has_value()); - TPU_ASSERT_OP(layouts_out[0].has_value()); - const VectorLayout &src_layout = *(layouts_in[0]); - const VectorLayout &idx_layout = *(layouts_in[1]); - const VectorLayout &out_layout = *(layouts_out[0]); - - OpBuilder builder(&op); - auto dy_gather_op = cast(op); - - // TODO(jevinjiang): we need to think harder for general vector shape. - if (dy_gather_op.getType().getShape() != - ArrayRef(ctx.target_shape)) { - return op.emitOpError( - "Not implemented: DynamicGatherOp only supports 32-bit VREG shape"); - } - - if (src_layout != out_layout || idx_layout != out_layout) { - return op.emitOpError( - "Not implemented: only support same layout for source, indices and " - "result"); - } - - if (!out_layout.hasNaturalTopology(ctx.target_shape)) { - return op.emitOpError( - "Not implemented: unsupported layout for DynamicGatherOp"); - } - - FAILUREOR_ASSIGN_OR_RETURN( - xla::Array src_vregs, - disassemble(builder, src_layout, dy_gather_op.getSource(), - ctx.target_shape)); - - FAILUREOR_ASSIGN_OR_RETURN( - xla::Array idx_vregs, - disassemble(builder, idx_layout, dy_gather_op.getIndices(), - ctx.target_shape)); - - TPU_ASSERT_EQ_OP(src_vregs.dimensions(), idx_vregs.dimensions()); - TPU_ASSERT_EQ_OP(src_vregs.num_elements(), 1); - - xla::Array out_vregs(src_vregs.dimensions()); - out_vregs.Each([&](absl::Span idxs, Value *v) { - *v = builder.create( - op.getLoc(), src_vregs(idxs).getType(), src_vregs(idxs), - idx_vregs(idxs), dy_gather_op.getDimension()); - }); - - dy_gather_op.replaceAllUsesWith( - assemble(builder, dy_gather_op.getResult().getType(), out_layout, - out_vregs, ctx.target_shape) - .getOperation()); - dy_gather_op.erase(); - return success(); -} - -LogicalResult tpu_region_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - if (op.getNumOperands() != 0) { - return op.emitOpError( - "Not implemented: tpu.region_block with inputs"); - } - TPU_ASSERT_EQ_OP(layouts_in.size(), 0); - ImplicitLocOpBuilder builder(op.getLoc(), &op); - auto region_op = cast(op); - // We don't modify the op, but we do rewrite the branch bodies. - if (failed( - applyLayoutBlock(ctx, region_op.getRegion().getBlocks().front()))) { - return op.emitOpError("Failed to apply layout to TPU region."); - } - auto yield_op = cast( - *region_op.getRegion().getBlocks().front().getTerminator()); - auto new_op = builder.create(yield_op->getOperandTypes()); - moveAllRegions(*region_op, *new_op); - - int64_t index = 0; - SmallVector rolled_results; - for (auto [result, layout] : - llvm::zip_equal(region_op.getResults(), layouts_out)) { - if (const auto vty = dyn_cast(result.getType())) { - // When the result has a vector type, assemble the result. - TPU_ASSERT_OP(layout.has_value()); - const SmallVector tiles_shape = - layout->tileArrayShape(vty.getShape(), ctx.target_shape); - const int64_t num_vectors = ShapedType::getNumElements(tiles_shape); - xla::Array tiles(tiles_shape); - TPU_ASSERT_LE_OP(index + num_vectors, new_op.getResults().size()); - tiles.SetValues( - llvm::make_range(new_op.getResults().begin() + index, - new_op.getResults().begin() + index + num_vectors)); - index += num_vectors; - RollVectorsOp rolled_op = - assemble(builder, vty, *layout, tiles, ctx.target_shape); - rolled_results.push_back(rolled_op); - } else { - TPU_ASSERT_OP(!layout.has_value()); - rolled_results.push_back(new_op.getResult(index)); - ++index; - } - } - region_op.replaceAllUsesWith(rolled_results); - region_op.erase(); - return success(); -} - -LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - TPU_ASSERT_EQ_OP(layouts_out.size(), 1); - MLIRContext *const mlir_ctx = op.getContext(); - TPU_ASSERT_OP(llvm::none_of(layouts_in, - [&](const Layout &l) { return l.has_value(); })); - TPU_ASSERT_OP(layouts_out.front().has_value()); - const VectorLayout &layout_out = *layouts_out.front(); - ImplicitLocOpBuilder builder(op.getLoc(), &op); - auto load_op = cast(op); - const auto memref_ty = getMemRefType(load_op.getBase()); - const auto vty = cast(load_op.getResult().getType()); - VectorType target_ty = - getNativeVregType(vty.getElementType(), ctx.target_shape); - if (vty.getRank() == 0) { - op.emitOpError("Not implemented: scalar loads from vmem"); - } - const bool is_1d = vty.getRank() == 1; - VectorLayout::ImplicitDim expected_dim = - is_1d ? VectorLayout::ImplicitDim::kSecondMinor - : VectorLayout::ImplicitDim::kNone; - if (layout_out.implicit_dim() != expected_dim) { - return op.emitOpError("Not implemented: unsupported layout"); - } - using Tiling = std::array; // To avoid comma in macro - FAILUREOR_ASSIGN_OR_RETURN( - Tiling memref_tiling, - getMemRefTiling(load_op.getBase(), ctx.target_shape)); - if (memref_tiling != layout_out.tiling()) { - if (memref_tiling[0] == 1 && layout_out.tiling()[0] == 1 && - memref_tiling[1] % layout_out.tiling()[1] == 0) { - // In this case, it is valid to use output tiling (1, 128 * packing) when - // loading from a 1D memref. - } else if (layout_out.bitwidth() == 32 && - layout_out.tiling() == - std::array{1, ctx.target_shape[1]}) { - // In this case, it is valid to use output tiling (1, TARGET_SHAPE.lanes) - // because we strided-load one row from each tile of the memref. This can - // save us a bunch of loads! - // TODO(b/295393167): need to support strided load for bitwidth < 32. - } else if (layout_out.bitwidth() == 32 && - canReinterpretToUntiledMemref( - load_op.getBase(), ctx.target_shape, - /*allow_minormost_padding=*/true)) { - // In this case, if the memref can be reinterpreted to untiled, it is - // valid to use any tiling for output. But using native tiling can save us - // a bunch of loads! - } else { - return op.emitOpError( - "Not implemented: dismatch in memref tiling and vector tiling in " - "load"); - } - } - // TODO(apaszke): Check that loads are from vmem! - - bool can_support_unaligned_dynamic_index = false; - bool must_support_unaligned_dynamic_index = false; - if (load_op.getIndices().size() > 1) { - auto second_minor_idx = load_op.getIndices().take_back(2)[0]; - if (failed(getIntConst(second_minor_idx, /*silent=*/true)) && - !isGuaranteedDivisible(second_minor_idx, memref_tiling[0])) { - must_support_unaligned_dynamic_index = true; - } - } - const SmallVector implicit_shape = - layout_out.implicitShape(vty.getShape()); - const int64_t ss = implicit_shape[implicit_shape.size() - 2]; - int64_t sublane_stride = 1; - // Handle special patterns that allow us to support more flexible loads. - if (layout_out.bitwidth() == 32 && - layout_out.tiling() == std::array{1, ctx.target_shape[1]} && - ss == 1) { - // Loading a single row on the 2nd minor dim into the (1, 128) layout. We - // can use sublane striding to perform the relayout as part of the load. - sublane_stride = memref_tiling[0]; - can_support_unaligned_dynamic_index = true; - } else { - // Otherwise, if the memref has a short last dimension and is contiguous - // all the tiled layouts become equivalent, so we can handle unaligned - // dynamic indices without any special case. - auto mem_layout = dyn_cast(memref_ty.getLayout()); - if (!mem_layout) { - return op.emitOpError("Expected a tiled memref"); - } - auto tile_strides = mem_layout.getTileStrides(); - if (memref_ty.getShape().back() == ctx.target_shape[1] && - tile_strides.take_back(2) == ArrayRef{1, 1}) { - can_support_unaligned_dynamic_index = true; - } - } - - auto add_idx = [&](const Value &v, int64_t d) -> Value { - if (auto cst = getIntConst(v, /*silent=*/true); succeeded(cst)) { - return IdxConst(cst.value() + d, builder, op.getLoc()); - } - return builder.create(v, IdxConst(d, builder, op.getLoc())); - }; - - int tiled_dims = is_1d ? 1 : 2; - Value base_addr = load_op.getBase(); - SmallVector base_indices = load_op.getIndices(); - - if (must_support_unaligned_dynamic_index) { - if (!can_support_unaligned_dynamic_index) { - return op.emitOpError( - "Not implemented: dynamic load with unaligned indices"); - } - } else { - // Convert dynamic load to dynamic slice + static load. This saves us a - // bunch of scalar core work. - auto slice_result = - sliceRef(builder, load_op.getBase(), load_op.getVectorType().getShape(), - load_op.getIndices(), - ArrayRef(memref_tiling).take_back(tiled_dims)); - if (failed(slice_result)) { - return failure(); - } - base_addr = slice_result->first; - CHECK_EQ(slice_result->second.size(), base_indices.size()); - for (int i = 0; i < base_indices.size(); ++i) { - base_indices[i] = IdxConst(slice_result->second[i], builder, op.getLoc()); - } - } - - // TODO(jevinjiang): ideally we should update the base addr and use static - // indices even for the cases that can skip alignment check. This can save us - // a bunch of scalar core work. - auto tile_base_idxs = ArrayRef(base_indices).take_back(tiled_dims); - auto batch_base_idxs = ArrayRef(base_indices).drop_back(tiled_dims); - const LayoutOffsets offsets = layout_out.offsets(); - AffineMap load_map; - if (offsets[1] == std::nullopt) { - return op.emitOpError( - "Not implemented: Load replicated along lanes is unsupported"); - } - if (offsets[0] == std::nullopt) { - if (ss != 1) { - return op.emitOpError( - "Not implemented: Sublane-replicated load with size > 1 is " - "unsupported"); - } - if (!layout_out.hasNativeTiling(ctx.target_shape)) { - return op.emitOpError("Not implemented"); - } - // affine_map<(..., j) -> (0, j) - load_map = - AffineMap::get(memref_ty.getRank(), 0, - {getAffineConstantExpr(0, mlir_ctx), - getAffineDimExpr(memref_ty.getRank() - 1, mlir_ctx)}, - mlir_ctx); - } - - xla::Array tiles( - layout_out.tileArrayShape(vty.getShape(), ctx.target_shape)); - const std::array vreg_slice = - layout_out.vregSlice(ctx.target_shape); - const int64_t num_dims = vty.getRank(); - const int64_t num_batch_dims = num_dims - (is_1d ? 1 : 2); - const absl::Status status = - tiles.EachStatus([&](absl::Span tile_idxs, Value * /*v*/) { - CHECK_EQ(num_dims, tile_idxs.size()); - SmallVector idxs(tile_idxs.size()); - for (int64_t i = 0; i < num_batch_dims; ++i) { - idxs[i] = add_idx(batch_base_idxs[i], tile_idxs[i]); - } - const auto base_l = tile_base_idxs.back(); - const int64_t lidx = tile_idxs[num_dims - 1]; - idxs[num_dims - 1] = - add_idx(base_l, lidx * vreg_slice[1] - offsets[1].value_or(0)); - if (!is_1d) { - CHECK_EQ(tile_base_idxs.size(), 2); - const auto base_s = tile_base_idxs.front(); - const int64_t sidx = tile_idxs[num_dims - 2]; - idxs[num_dims - 2] = - add_idx(base_s, sidx * vreg_slice[0] - offsets[0].value_or(0)); - } - TPU_ASSERT_OP(tile_idxs[num_dims - 1] + ctx.target_shape[1] <= - memref_ty.getShape()[num_dims - 1]); - std::unique_ptr bounds = layout_out.tileDataBounds( - mlir_ctx, vty.getShape(), toArrayRef(tile_idxs), ctx.target_shape, - /*allow_replicated =*/{true, false}); - Operation *tile; - if (bounds->maskVariesAlong(Direction::kSublanes, ctx.target_shape)) { - CHECK(offsets[0].has_value()); - tile = builder.create( - target_ty, base_addr, idxs, - bounds->getSublaneMask(mlir_ctx, ctx.target_shape), - builder.getI32IntegerAttr(sublane_stride)); - } else { - if (load_map) { - if (layout_out.bitwidth() != 32) { - load_op.emitOpError("Not implemented"); - return absl::UnimplementedError(""); - } - tile = builder.create( - target_ty, base_addr, idxs, load_map, - // TODO(tlongeri): Not sure whether we are obeying the semantics - // of in_bounds, but our lowering ignores it and this path will - // removed soon anyway. - SmallVector(2, true)); - } else { - const SmallVector sublane_mask(ctx.target_shape[0], true); - const auto sublane_mask_attr = - DenseBoolArrayAttr::get(mlir_ctx, sublane_mask); - tile = builder.create( - target_ty, base_addr, idxs, sublane_mask_attr, - builder.getI32IntegerAttr(sublane_stride)); - } - } - tiles(tile_idxs) = tile->getResult(0); - return absl::OkStatus(); - }); - if (!status.ok()) { - return failure(); - } - load_op->replaceAllUsesWith( - assemble(builder, vty, layout_out, std::move(tiles), ctx.target_shape)); - load_op->erase(); - return success(); -} - -LogicalResult arith_constant_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - TPU_ASSERT_EQ_OP(layouts_in.size(), 0); - TPU_ASSERT_EQ_OP(layouts_out.size(), 1); - ImplicitLocOpBuilder builder(op.getLoc(), &op); - auto constant_op = cast(op); - auto vty = dyn_cast(op.getResult(0).getType()); - if (vty) { - if (!layouts_out.front().has_value()) { - return op.emitOpError( - "Expected non-null output layout for vector constant"); - } - const VectorLayout &layout_out = *layouts_out.front(); - DenseElementsAttr value = cast(constant_op.getValue()); - const VectorType target_vty = getNativeVregOrVmaskType( - vty.getElementType(), layout_out.bitwidth(), ctx.target_shape); - if (value.isSplat()) { - if (layout_out.offsets() != LayoutOffsets{std::nullopt, std::nullopt}) { - return op.emitOpError( - "Not implemented: Non-replicated splat constants"); - } - auto new_value = - DenseElementsAttr::get(target_vty, value.getSplatValue()); - const auto tile = - builder.create(target_vty, new_value); - const xla::Array tiles( - layout_out.tileArrayShape(vty.getShape(), ctx.target_shape), - tile->getResult(0)); - op.replaceAllUsesWith(assemble(builder, vty, layout_out, std::move(tiles), - ctx.target_shape)); - op.erase(); - return success(); - } - // !value.isSplat() - if (getTypeBitwidth(vty.getElementType()) != 32) { - return op.emitOpError( - "Not implemented: Only 32-bit non-splat constants are supported"); - } - auto func_op = op.getParentOfType(); - if (!func_op) { - return op.emitOpError("Expected a function op"); - } - FAILUREOR_ASSIGN_OR_RETURN(const BlockArgument ref, - appendConstant(ctx, func_op, value)); - auto load_op = builder.create( - vty, ref, - SmallVector(vty.getRank(), IdxConst(0, builder, op.getLoc()))); - op.replaceAllUsesWith(ArrayRef{load_op.getResult()}); - op.erase(); - const SmallVector vector_load_in_layouts(vty.getRank() + 1); - return vector_load_rule(ctx, *load_op, vector_load_in_layouts, - {VectorLayout(/*bitwidth=*/32, /*offsets=*/{0, 0}, - /*tiling=*/ctx.target_shape)}); - } - return op.emitOpError("Not implemented: Unsupported arith.const type: ") - << op.getResult(0).getType(); -} - -LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - TPU_ASSERT_EQ_OP(layouts_in.size(), 1); - TPU_ASSERT_EQ_OP(layouts_out.size(), 1); - TPU_ASSERT_OP(layouts_out.front().has_value()); - const Layout &maybe_layout_in = layouts_in.front(); - const VectorLayout &layout_out = *layouts_out.front(); - ImplicitLocOpBuilder builder(op.getLoc(), &op); - vector::BroadcastOp broadcast_op = cast(op); - const VectorType dst_ty = broadcast_op.getResult().getType(); - const ArrayRef dst_shape = dst_ty.getShape(); - const SmallVector dst_tiles_shape = - layout_out.tileArrayShape(dst_shape, ctx.target_shape); - const SmallVector dst_tiles_implicit_shape = - layout_out.tileArrayImplicitShape(dst_shape, ctx.target_shape); - if (auto src = dyn_cast>(broadcast_op.getSource())) { - VectorType src_ty = src.getType(); - TPU_ASSERT_OP(maybe_layout_in.has_value()); - const VectorLayout &layout_in = *maybe_layout_in; - if (layout_in.implicit_dim() != layout_out.implicit_dim()) { - return op.emitOpError( - "Not implemented: Changing implicit dims mid-broadcast"); - } - const LayoutOffsets offsets_in = layout_in.offsets(); - const LayoutOffsets offsets_out = layout_out.offsets(); - if (layout_in.tiling() != layout_out.tiling()) { - return op.emitOpError("Not implemented: Changing tiling mid-broadcast"); - } - auto tiling = layout_in.tiling(); - - const int64_t expand_rank = dst_ty.getRank() - src_ty.getRank(); - const ArrayRef src_shape = src_ty.getShape(); - - SmallVector src_implicit_shape_padded; - // `is_logical_broadcast` stores whether each dimension of the implicit - // shape of the result is a broadcast. E.g. if the implicit shape goes from - // (2, 1, 3) to (4, 2, 5, 3) it's (true, false, true, false). - SmallVector is_logical_broadcast; - src_implicit_shape_padded.reserve(dst_shape.size() + - layout_in.num_implicit_dims()); - is_logical_broadcast.reserve(dst_shape.size() + - layout_in.num_implicit_dims()); - src_implicit_shape_padded.append(expand_rank, 1); - src_implicit_shape_padded.append(src_shape.begin(), src_shape.end()); - for (auto [i, o] : llvm::zip(src_implicit_shape_padded, dst_shape)) { - TPU_ASSERT_OP(i == o || i == 1); // Verifier should guarantee this. - is_logical_broadcast.push_back(i != o); - } - layout_in.insertImplicit(src_implicit_shape_padded, 1); - layout_in.insertImplicit(is_logical_broadcast, false); - - // Verify that the offsets are valid. - for (auto [is_logical_broadcast_on_dim, in_off, out_off] : - llvm::zip_equal(ArrayRef(is_logical_broadcast).take_back(2), - offsets_in, offsets_out)) { - if (is_logical_broadcast_on_dim) { - if (out_off.has_value()) { - // There's no reason to ever assign a non-replicated offset to a - // broadcasted dimension in the output. - return op.emitOpError( - // TODO(tlongeri): This should never be implemented but the fuzzed - // tests expect a NotImplementedError, which - // is raised with a "Not implemented" (see - // NotImplementedDetector in tpu_ext.cc). Fix. - "Not implemented: Broadcast output expected to have replicated " - "offsets."); - } - } else { // !is_logical_broadcast_on_dim - if (in_off != out_off) { - return op.emitOpError( - "Not implemented: Changing offsets mid-broadcast"); - } - } - } - - // `needs_physical_broadcast` specifies whether we need to broadcast vregs - // vregs in the sublane and lane dimensions. We only need to do this if the - // corresponding dimension of the implicit shape is logically broadcast and - // if the input vregs are not already replicated along this dimension. - const std::array needs_physical_broadcast{ - *(is_logical_broadcast.end() - 2) && offsets_in[0].has_value(), - *(is_logical_broadcast.end() - 1) && offsets_in[1].has_value()}; - FAILUREOR_ASSIGN_OR_RETURN( - xla::Array src_tiles, - disassemble(builder, layout_in, src, ctx.target_shape, - /*use_implicit_shape=*/true)); - xla::Array dst_tiles(dst_tiles_implicit_shape); - if (needs_physical_broadcast == std::array{false, false}) { // No-op - SmallVector reshape_dims(expand_rank, 1); - const absl::Span src_tiles_dims = src_tiles.dimensions(); - reshape_dims.append(src_tiles_dims.begin(), src_tiles_dims.end()); - src_tiles.Reshape(reshape_dims); - dst_tiles.Each([&](const absl::Span dst_idx, Value *tile) { - const SmallVector src_idx = llvm::map_to_vector( - llvm::zip_equal(dst_idx, is_logical_broadcast), [](auto tup) { - auto [i, is_logical_broadcast_on_dim] = tup; - return is_logical_broadcast_on_dim ? 0 : i; - }); - *tile = src_tiles(src_idx); - }); - } else { - if (tiling[1] != ctx.target_shape[1]) { - return op.emitOpError("Not implemented: unsupported tiling"); - } - int64_t num_tiles = layout_in.tilesPerVreg(ctx.target_shape); - if (needs_physical_broadcast == - std::array{true, false}) { // Sublane broadcast - const int packing = layout_in.packing(); - if (num_tiles != 1) { - return op.emitOpError( - "Not implemented: Only native tiling supported"); - } - TPU_ASSERT_EQ_OP(*(src_tiles.dimensions().end() - 2), 1); - TPU_ASSERT_OP(offsets_in[0].has_value()); - const int64_t sublane_offset = *offsets_in[0] / packing; - const int64_t subelement_offset = *offsets_in[0] % packing; - const DenseI32ArrayAttr indices = builder.getDenseI32ArrayAttr( - SmallVector(ctx.target_shape[0], sublane_offset)); - const absl::Status status = - src_tiles.EachStatus([&](const absl::Span src_idx, - Value *const src_vreg) { - Value dst_vreg = *src_vreg; - // Replicate the value within each sublane. - if (packing != 1) { - if (auto new_dst_vreg = broadcastSubelements( - builder, cast>(dst_vreg), - subelement_offset, ctx.target_shape, - ctx.hardware_generation); - succeeded(new_dst_vreg)) { - dst_vreg = *new_dst_vreg; - } else { - return absl::InternalError(""); - } - } - dst_vreg = builder.create(dst_vreg.getType(), - dst_vreg, indices, 0); - SmallVector dst_starts(dst_tiles_implicit_shape.size()); - SmallVector dst_limits(dst_tiles_implicit_shape.size()); - for (int64_t i = 0; i < dst_tiles.num_dimensions(); ++i) { - if (i < expand_rank || is_logical_broadcast[i]) { - dst_starts[i] = 0; - dst_limits[i] = dst_tiles_implicit_shape[i]; - } else { - dst_starts[i] = src_idx[i - expand_rank]; - dst_limits[i] = dst_starts[i] + 1; - } - } - updateSlice(dst_tiles, dst_vreg, dst_starts, dst_limits); - return absl::OkStatus(); - }); - if (!status.ok()) { - return failure(); - } - } else if (needs_physical_broadcast == - std::array{false, true}) { // Lane broadcast - TPU_ASSERT_EQ_OP(*(src_tiles.dimensions().end() - 1), 1); - TPU_ASSERT_OP(offsets_in[1].has_value()); - const int64_t sublanes_per_tile = - layout_in.sublanesPerTile(ctx.target_shape); - const int64_t offset = *offsets_in[1]; - const int64_t lane_offset = offset % ctx.target_shape[1]; - const int64_t tile_offset = offset / ctx.target_shape[1]; - Value lane_offset_cst = getFullVector( - builder, getNativeVregType(builder.getI32Type(), ctx.target_shape), - builder.getI32IntegerAttr(lane_offset)); - DenseI32ArrayAttr sublane_pattern; - if (num_tiles != 1) { - SmallVector pattern; - pattern.reserve(ctx.target_shape[0]); - for (int32_t t = 0; t < num_tiles; ++t) { - for (int32_t i = 0; i < sublanes_per_tile; ++i) { - pattern.push_back(sublanes_per_tile * tile_offset + i); - } - } - sublane_pattern = builder.getDenseI32ArrayAttr(pattern); - } - src_tiles.Each([&](const absl::Span src_idx, - Value *const src_tile) { - SmallVector dst_starts(dst_tiles_implicit_shape.size()); - SmallVector dst_limits(dst_tiles_implicit_shape.size()); - for (int64_t i = 0; i < dst_tiles.num_dimensions(); ++i) { - if (i < expand_rank || is_logical_broadcast[i]) { - dst_starts[i] = 0; - dst_limits[i] = dst_tiles_implicit_shape[i]; - } else { - dst_starts[i] = src_idx[i - expand_rank]; - dst_limits[i] = dst_starts[i] + 1; - } - } - Value res_vreg = builder.create( - broadcast_op.getLoc(), src_tile->getType(), *src_tile, - lane_offset_cst, - /*dimension=*/1); - if (num_tiles != 1) { - res_vreg = builder.create( - broadcast_op.getLoc(), res_vreg.getType(), res_vreg, - sublane_pattern, 0); - } - updateSlice(dst_tiles, res_vreg, dst_starts, dst_limits); - }); - } else { - TPU_ASSERT_OP((needs_physical_broadcast == std::array{true, true})); - return op.emitOpError( - "Not implemented: Broadcast in both sublanes and lanes"); - } - } - broadcast_op.replaceAllUsesWith(assemble(builder, dst_ty, layout_out, - dst_tiles, ctx.target_shape, - /*use_implicit_shape=*/true) - .getOperation()); - broadcast_op.erase(); - return success(); - } else if (layout_out.bitwidth() == 32 && - broadcast_op.getSourceType().getIntOrFloatBitWidth() == 1) { - // Broadcasting the i1 scalar involves first converting i1 to i32, followed - // by broadcasting i32 to the target shape. Finally, the comparison with 0s - // yields the vmask. - auto src_i32 = builder.create( - broadcast_op.getLoc(), builder.getI32Type(), broadcast_op.getSource()); - - const VectorType native_vreg_ty = - getNativeVregType(src_i32.getType(), ctx.target_shape); - auto tile_i32 = - builder.create(native_vreg_ty, src_i32); - Value zeros = getZerosVector(builder, tile_i32.getType()); - auto tile = - builder.create(arith::CmpIPredicate::ne, tile_i32, zeros) - .getResult(); - const xla::Array dst_tiles(dst_tiles_shape, tile); - broadcast_op.replaceAllUsesWith( - assemble(builder, dst_ty, layout_out, dst_tiles, ctx.target_shape) - .getOperation()); - broadcast_op.erase(); - return success(); - } else if (layout_out.bitwidth() < 32) { - CHECK_EQ(layout_out.bitwidth(), - broadcast_op.getSourceType().getIntOrFloatBitWidth()); - // Broadcasting the scalar with narrower type involves first packing (32 / - // bitwidth) copies to i32, followed by broadcasting i32 to the target - // shape. Finally, bitcast i32 vector back to the original narrower type - // vector. - auto loc = broadcast_op.getLoc(); - auto src_ty = broadcast_op.getSourceType(); - auto bitwidth = src_ty.getIntOrFloatBitWidth(); - auto unpacked_src = broadcast_op.getSource(); - if (!src_ty.isSignlessInteger(bitwidth)) { - unpacked_src = builder.create( - loc, builder.getIntegerType(bitwidth), unpacked_src); - } - auto src_i32 = - builder.create(loc, builder.getI32Type(), unpacked_src) - .getResult(); - for (int i = 1; i < (32 / bitwidth); ++i) { - auto shift_width = builder.create( - loc, builder.getIntegerAttr(builder.getI32Type(), i * bitwidth)); - src_i32 = builder.create( - loc, src_i32, - builder.create(loc, src_i32, shift_width)); - } - - const VectorType i32_vreg_ty = - getNativeVregType(src_i32.getType(), ctx.target_shape); - auto tile_i32 = builder.create(i32_vreg_ty, src_i32); - - const VectorType native_vreg_ty = - getNativeVregType(src_ty, ctx.target_shape); - auto tile = builder.create(native_vreg_ty, tile_i32); - - const xla::Array dst_tiles(dst_tiles_shape, tile); - broadcast_op.replaceAllUsesWith( - assemble(builder, dst_ty, layout_out, dst_tiles, ctx.target_shape) - .getOperation()); - broadcast_op.erase(); - return success(); - } else { - const VectorType native_vreg_ty = - getNativeVregType(broadcast_op.getSourceType(), ctx.target_shape); - auto tile = builder.create(native_vreg_ty, - broadcast_op.getSource()); - const xla::Array dst_tiles(dst_tiles_shape, tile); - broadcast_op.replaceAllUsesWith( - assemble(builder, dst_ty, layout_out, dst_tiles, ctx.target_shape) - .getOperation()); - broadcast_op.erase(); - return success(); - } -} - -// Returns slice of vregs containing a given slice of elements, obtained from -// the result of a vector.extract or vector.extract_strided_slice op. -// -// Takes offsets and sizes describing the slice of elements. If their size is -// less than the rank of the input vector, they describe a prefix i.e. they -// apply to the first (majormost) dimensions and the remaining dimensions are -// not sliced. -// -// Args: -// - ctx: Rewrite context (for disassembling, which may create an op). -// - op: Source vector.extract or vector.extract_strided_slice op. -// - offsets: Prefix of offsets of slice of elements. Must have the same size -// as sizes. -// - sizes: Prefix of sizes of slice of elements. Must have the same size -// as offsets. -// - layout_in: Layout of src_vector. -// - layout_out: Layout that will be used to reassemble the slice (by caller). -// Used only to check that the reassembling is valid. -FailureOr> vector_extract_slice_impl( - RewriteContext &ctx, Operation &op, const ArrayRef sizes, - const ArrayRef offsets, const VectorLayout &layout_in, - const VectorLayout &layout_out) { - if (layout_in.tiling() != layout_out.tiling() || - layout_in.bitwidth() != layout_out.bitwidth()) { - return op.emitOpError( - "Not implemented: Expected layout_in and layout_out tiling and packing " - "to match"); - } - - // Both extract_strided_slice and extract have their input vector at index 0 - // and a single result. - CHECK((isa(op))); - auto src_vector = cast>(op.getOperand(0)); - auto result = cast>(op.getResult(0)); - - const VectorType dst_ty = result.getType(); - if (layout_in.implicit_dim() != layout_out.implicit_dim() && - !(layout_in.implicit_dim() == VectorLayout::ImplicitDim::kNone && - layout_out.implicit_dim() == VectorLayout::ImplicitDim::kSecondMinor && - dst_ty.getRank() == 1)) { - return op.emitOpError( - "Not implemented: Unexpected change in implicit dimension that may not " - "be a no-op"); - } - - const ArrayRef src_vector_shape = src_vector.getType().getShape(); - const int64_t src_vector_rank = src_vector_shape.size(); - const int64_t num_indices = offsets.size(); - TPU_ASSERT_EQ_OP(num_indices, sizes.size()); - - SmallVector full_sizes; - full_sizes.reserve(src_vector_rank + layout_in.num_implicit_dims()); - full_sizes.append(sizes.begin(), sizes.end()); - full_sizes.append(src_vector_shape.begin() + num_indices, - src_vector_shape.end()); - layout_in.insertImplicit(full_sizes, 1); - - SmallVector full_offsets; - full_offsets.reserve(src_vector_rank + layout_in.num_implicit_dims()); - full_offsets.append(offsets.begin(), offsets.end()); - full_offsets.append(src_vector_rank - num_indices, 0); - layout_in.insertImplicit(full_offsets, 0); - - // We currently only support no-op cases - that is, those where we effectively - // just extract a slice of vregs without doing any operations (e.g. shifts) on - // them. - for (auto [index_offset, in_offset, vreg_slice, out_offset] : llvm::zip_equal( - ArrayRef(full_offsets).take_back(2), layout_in.offsets(), - layout_in.vregSlice(ctx.target_shape), layout_out.offsets())) { - if (in_offset.has_value() != out_offset.has_value()) { - return op.emitOpError( - "Unexpected mismatch in replication between input and output " - "layouts"); - } - if (in_offset.has_value() && - (index_offset + *in_offset) % vreg_slice != *out_offset) { - return op.emitOpError("Not implemented: Only no-op tiles"); - } - } - - const std::array vreg_slice = - layout_in.vregSlice(ctx.target_shape); - SmallVector slice_tiled_starts(full_offsets); - *(slice_tiled_starts.end() - 2) = - (layout_in.offsets()[0].value_or(0) + *(full_offsets.end() - 2)) / - vreg_slice[0]; - *(slice_tiled_starts.end() - 1) = - (layout_in.offsets()[1].value_or(0) + *(full_offsets.end() - 1)) / - vreg_slice[1]; - layout_in.eraseImplicit(slice_tiled_starts); - SmallVector slice_tiled_limits(full_offsets); - for (int64_t i = 0; i < full_offsets.size() - layout_in.layout_rank(); ++i) { - slice_tiled_limits[i] += full_sizes[i]; - } - *(slice_tiled_limits.end() - 2) = - llvm::divideCeil(layout_in.offsets()[0].value_or(0) + - *(full_offsets.end() - 2) + *(full_sizes.end() - 2), - vreg_slice[0]); - *(slice_tiled_limits.end() - 1) = - llvm::divideCeil(layout_in.offsets()[1].value_or(0) + - *(full_offsets.end() - 1) + *(full_sizes.end() - 1), - vreg_slice[1]); - layout_in.eraseImplicit(slice_tiled_limits); - - OpBuilder builder(&op); - FAILUREOR_ASSIGN_OR_RETURN( - const xla::Array input_tiles, - disassemble(builder, layout_in, src_vector, ctx.target_shape)); - return input_tiles.Slice(slice_tiled_starts, slice_tiled_limits); -} - -LogicalResult vector_extract_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - ImplicitLocOpBuilder builder(op.getLoc(), &op); - vector::ExtractOp extract_op = cast(op); - if (extract_op.hasDynamicPosition()) { - return op.emitOpError("Not implemented: dynamic indices"); - } - TPU_ASSERT_EQ_OP(layouts_in.size(), 1); - TPU_ASSERT_EQ_OP(layouts_out.size(), 1); - TPU_ASSERT_OP(layouts_in.front().has_value()); - const VectorLayout &layout_in = *layouts_in.front(); - if (layout_in.bitwidth() != 32) { - return op.emitOpError( - "Not implemented: Only 32-bit vector.extract supported"); - } - const VectorType res_vty = - dyn_cast(extract_op.getResult().getType()); - if (res_vty != nullptr) { - TPU_ASSERT_OP(layouts_out.front().has_value()); - const VectorLayout &layout_out = *layouts_out.front(); - const int64_t num_indices = extract_op.getStaticPosition().size(); - const SmallVector sizes(num_indices, 1); - FAILUREOR_ASSIGN_OR_RETURN( - xla::Array dst_vregs, - vector_extract_slice_impl(ctx, *extract_op, sizes, - extract_op.getStaticPosition(), layout_in, - *layouts_out.front())); - // Squeeze leading singleton dimensions. - TPU_ASSERT_EQ_OP(res_vty.getRank(), - extract_op.getSourceVectorType().getRank() - num_indices); - TPU_ASSERT_OP( - llvm::all_of(toArrayRef(dst_vregs.dimensions()).take_front(num_indices), - [](const int64_t d) { return d == 1; })); - // Copy dims to temporary before passing to xla::Array::Reshape - it cannot - // take a pointer to its own data. - dst_vregs.Reshape(SmallVector( - toArrayRef(dst_vregs.dimensions()).drop_front(num_indices))); - op.replaceAllUsesWith( - assemble(builder, res_vty, layout_out, dst_vregs, ctx.target_shape) - .getOperation()); - op.erase(); - return success(); - } else { - // TODO(b/367459476): Support non-zero offsets. - if (layout_in.offsets() != LayoutOffsets{0, 0}) { - return op.emitOpError("Not implemented: Unsupported layout"); - } - auto [sub_tile, lane_tile] = layout_in.tiling(); - FAILUREOR_ASSIGN_OR_RETURN( - const xla::Array vregs, - disassemble(builder, layout_in, extract_op.getVector(), - ctx.target_shape)); - TPU_ASSERT_GT_OP(vregs.num_elements(), 0); - - SmallVector indices(extract_op.getStaticPosition()); - auto vreg_slice = layout_in.vregSlice(ctx.target_shape); - std::array position = {0, 0}; - SmallVector vreg_index(indices); - // TODO(b/367459476): Support non-VREG-aligned tiling. - CHECK_EQ(lane_tile, ctx.target_shape[1]); - layout_in.insertImplicit(indices, static_cast(0)); - layout_in.insertImplicit(vreg_index, static_cast(0)); - int i = *(indices.end()-2); - int j = *(indices.end()-1); - *(vreg_index.end() -2) = i / vreg_slice[0]; - *(vreg_index.end() -1) = j / vreg_slice[1]; - layout_in.eraseImplicit(vreg_index); - position[0] = ((j % vreg_slice[1]) / lane_tile * sub_tile - ) + i % sub_tile; - position[1] = j % lane_tile; - - TPU_ASSERT_LT_OP(vreg_index, vregs.dimensions()); - Value extracted_vreg = vregs(vreg_index); - - // Invert the offsets to get the rotation amount. - position[0] = (ctx.target_shape[0] - position[0]) % ctx.target_shape[0]; - position[1] = (ctx.target_shape[1] - position[1]) % ctx.target_shape[1]; - auto res_vreg_ty = extracted_vreg.getType(); - Value shift = builder.create( - builder.getIntegerAttr(builder.getI32Type(), position[0])); - Value rotated_vreg = builder.create( - res_vreg_ty, extracted_vreg, shift, 0, /*stride*/nullptr, nullptr); - shift = builder.create( - builder.getIntegerAttr(builder.getI32Type(), position[1])); - rotated_vreg = builder.create( - res_vreg_ty, rotated_vreg, shift, 1, /*stride*/nullptr, nullptr); - extract_op.replaceAllUsesWith( - builder.create( - op.getLoc(), rotated_vreg, - ArrayRef{0, 0}) - .getResult()); - } - extract_op.erase(); - return success(); -} - -LogicalResult vector_extract_strided_slice_rule( - RewriteContext &ctx, Operation &op, const ArrayRef layouts_in, - const ArrayRef layouts_out) { - TPU_ASSERT_EQ_OP(layouts_in.size(), 1); - TPU_ASSERT_EQ_OP(layouts_out.size(), 1); - TPU_ASSERT_OP(layouts_in.front().has_value()); - TPU_ASSERT_OP(layouts_out.front().has_value()); - const VectorLayout &layout_in = *layouts_in.front(); - const VectorLayout &layout_out = *layouts_out.front(); - ImplicitLocOpBuilder builder(op.getLoc(), &op); - auto extract_strided_slice_op = cast(op); - - auto I64ArrayToSmallVector = [&](const ArrayAttr array_attr) { - return llvm::map_to_vector(array_attr, [](Attribute attr) { - return cast(attr).getValue().getSExtValue(); - }); - }; - - FAILUREOR_ASSIGN_OR_RETURN( - const xla::Array dst_vregs, - vector_extract_slice_impl( - ctx, *extract_strided_slice_op, - I64ArrayToSmallVector(extract_strided_slice_op.getSizes()), - I64ArrayToSmallVector(extract_strided_slice_op.getOffsets()), - layout_in, layout_out)); - op.replaceAllUsesWith(assemble(builder, - extract_strided_slice_op.getResult().getType(), - layout_out, dst_vregs, ctx.target_shape) - .getOperation()); - op.erase(); - return success(); -} - -LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - TPU_ASSERT_EQ_OP(layouts_in.size(), 2); - TPU_ASSERT_EQ_OP(layouts_out.size(), 1); - TPU_ASSERT_OP( - llvm::all_of(layouts_in, [&](const Layout &l) { return l.has_value(); })); - const Location loc = op.getLoc(); - const VectorLayout &src_layout = *layouts_in[0]; - const VectorLayout &acc_layout = *layouts_in[1]; - const VectorLayout &dst_layout = *layouts_out[0]; - ImplicitLocOpBuilder builder(op.getLoc(), &op); - auto multi_reduction_op = cast(op); - const VectorType src_ty = multi_reduction_op.getSourceVectorType(); - auto element_type = src_ty.getElementType(); - int64_t src_rank = src_ty.getRank(); - const auto res_ty = dyn_cast(multi_reduction_op.getDestType()); - if (res_ty == nullptr) { - return multi_reduction_op.emitOpError( - "Not implemented: Can only reduce into vectors"); - } - // Op definition enforces that accumulator type must match result type - auto acc = cast>(multi_reduction_op.getAcc()); - TPU_ASSERT_OP(layouts_out.front().has_value()); - - SmallVector dims(multi_reduction_op.getReductionDims()); - std::sort(dims.begin(), dims.end()); - - // Make sure that the accumulator is a splat of the neutral value - if (acc_layout.offsets() != LayoutOffsets{std::nullopt, std::nullopt}) { - return multi_reduction_op.emitOpError( - "Not implemented: Only replicated accumulator supported"); - } - FAILUREOR_ASSIGN_OR_RETURN( - const xla::Array acc_vregs, - disassemble(builder, acc_layout, acc, ctx.target_shape)); - auto acc_def = dyn_cast_if_present( - acc_vregs.begin()->getDefiningOp()); - if (acc_def == nullptr) { - return multi_reduction_op.emitOpError( - "Not implemented: Only constant accumulator supported"); - } - if (!element_type.isF32() && !element_type.isBF16() && - !element_type.isSignlessInteger((32))) { - return multi_reduction_op.emitOpError( - "Not implemented: unsupported element type"); - } - bool is_int = element_type.isSignlessInteger(32); - const auto acc_def_value = dyn_cast(acc_def.getValue()); - if (acc_def_value == nullptr || !acc_def_value.isSplat()) { - return multi_reduction_op.emitOpError("Expected a splat constant"); - } - TPU_ASSERT_OP(acc_def_value.getElementType() == element_type); - Attribute neutral; - switch (multi_reduction_op.getKind()) { - case vector::CombiningKind::ADD: - neutral = builder.getZeroAttr(element_type); - break; - case vector::CombiningKind::MAXIMUMF: { - // TODO(b/322836633): The semantics of maximumf don't match the lowering - // for older TPU versions because older TPU versions don't respect the - // -0.0 vs +0.0 ordering. - neutral = builder.getFloatAttr( - element_type, - APFloat::getInf(cast(element_type).getFloatSemantics(), - /*Negative=*/true)); - } break; - case vector::CombiningKind::MINIMUMF: { - neutral = builder.getFloatAttr( - element_type, - APFloat::getInf(cast(element_type).getFloatSemantics(), - /*Negative=*/false)); - } break; - case vector::CombiningKind::MAXSI: { - neutral = builder.getIntegerAttr( - element_type, - APInt::getSignedMinValue(element_type.getIntOrFloatBitWidth())); - } break; - case vector::CombiningKind::MINSI: { - neutral = builder.getIntegerAttr( - element_type, - APInt::getSignedMaxValue(element_type.getIntOrFloatBitWidth())); - } break; - default: - return multi_reduction_op.emitOpError( - "Not implemented: unsupported kind"); - } - if (auto val = acc_def_value.getSplatValue(); val != neutral) { - return multi_reduction_op.emitOpError( - "Not implemented: Only neutral accumulator supported for " - "float reduction. Expected ") - << neutral << ", but got " << val; - } - - std::array reduces; - switch (src_layout.implicit_dim()) { - case VectorLayout::ImplicitDim::kNone: - reduces = { - std::find(dims.begin(), dims.end(), src_rank - 2) != dims.end(), - std::find(dims.begin(), dims.end(), src_rank - 1) != dims.end()}; - break; - case VectorLayout::ImplicitDim::kSecondMinor: - reduces = {false, std::find(dims.begin(), dims.end(), src_rank - 1) != - dims.end()}; - break; - case VectorLayout::ImplicitDim::kMinor: - reduces = { - std::find(dims.begin(), dims.end(), src_rank - 1) != dims.end(), - false}; - break; - } - - if ((reduces[0] || reduces[1]) && - !src_layout.hasNativeTiling(ctx.target_shape)) { - return multi_reduction_op.emitOpError( - "Not implemented: Unsupported input layout: ") - << src_layout; - } - if (src_layout.tiling() != dst_layout.tiling()) { - return multi_reduction_op.emitOpError("Not implemented: Tiling change"); - } - for (int i = 0; i < 2; ++i) { - if (reduces[i] && src_layout.offsets()[i] == std::nullopt && - element_type.getIntOrFloatBitWidth() != 32) { - return multi_reduction_op.emitOpError( - "Not implemented: Non-32-bit reductions over replicated axes"); - } - // Offsets have to be equal, unless we're reducing over that dimension. - if (src_layout.offsets()[i] != dst_layout.offsets()[i] && !reduces[i]) { - return multi_reduction_op.emitOpError("Not implemented: Offset change"); - } - } - VectorLayout::ImplicitDim dst_implicit_dim; - if ((reduces[0] && reduces[1]) || - (src_layout.implicit_dim() != VectorLayout::ImplicitDim::kNone && - (reduces[0] || reduces[1]))) { - // This is difficult, because we'd like to make both tiling dims implicit, - // but there is no way to do that in VectorLayout right now. - // We use an equivalence between VectorLayouts when trailing dims are 1 - // to enable some special cases, but we should generalize this. - if (*(res_ty.getShape().end() - 1) != 1) { - return multi_reduction_op.emitOpError( - "Not implemented: reductions over both trailing dimensions are only " - "supported when the resulting value has a trailing axis of size 1"); - } - dst_implicit_dim = - VectorLayout::ImplicitDim::kSecondMinor; // Anything works. - } else if (reduces[0]) { - TPU_ASSERT_OP(src_layout.implicit_dim() == - VectorLayout::ImplicitDim::kNone); - dst_implicit_dim = VectorLayout::ImplicitDim::kSecondMinor; - } else if (reduces[1]) { - TPU_ASSERT_OP(src_layout.implicit_dim() == - VectorLayout::ImplicitDim::kNone); - dst_implicit_dim = VectorLayout::ImplicitDim::kMinor; - } else { - dst_implicit_dim = src_layout.implicit_dim(); - } - if (dst_layout.implicit_dim() != dst_implicit_dim) { - return multi_reduction_op.emitOpError( - "Not implemented: Unsupported output implicit dimension"); - } - - FAILUREOR_ASSIGN_OR_RETURN( - xla::Array src_vregs, - disassemble(builder, src_layout, multi_reduction_op.getSource(), - ctx.target_shape)); - xla::Array dst_vregs( - dst_layout.tileArrayShape(res_ty.getShape(), ctx.target_shape)); - tpu::ReductionKind tpu_kind; - switch (multi_reduction_op.getKind()) { - case vector::CombiningKind::ADD: - tpu_kind = tpu::ReductionKind::SUM; - break; - case vector::CombiningKind::MAXIMUMF: - case vector::CombiningKind::MAXSI: - tpu_kind = tpu::ReductionKind::MAX; - break; - case vector::CombiningKind::MINIMUMF: - case vector::CombiningKind::MINSI: - tpu_kind = tpu::ReductionKind::MIN; - break; - default: - return multi_reduction_op.emitOpError( - "Not implemented: unsupported reduction kind"); - } - const ArrayRef src_shape = src_ty.getShape(); - auto all_results_ok = dst_vregs.EachStatus( - [&](const absl::Span idx, Value *const dst_vreg) { - // Extract a subset of source vregs that reduce into this result vreg. - SmallVector src_slice_start; - src_slice_start.reserve(src_rank); - SmallVector src_slice_end; - src_slice_end.reserve(src_rank); - for (int64_t i : idx) { - src_slice_start.push_back(i); - src_slice_end.push_back(i + 1); - } - for (int64_t d : dims) { - int64_t d_size = src_vregs.dim(d); - src_slice_start.insert(src_slice_start.begin() + d, 0); - if (!src_layout.offsets()[0].has_value() && d == src_rank - 2) { - d_size = 1; - } - if (!src_layout.offsets()[1].has_value() && d == src_rank - 1) { - d_size = 1; - } - src_slice_end.insert(src_slice_end.begin() + d, d_size); - } - xla::Array reduced_vregs = - src_vregs.Slice(src_slice_start, src_slice_end); - std::optional acc_vreg; - auto reduce_elementwise = [&](Value lhs, Value rhs) -> Value { - Value result; - switch (tpu_kind) { - case tpu::ReductionKind::SUM: - result = - is_int - ? builder.create(loc, lhs, rhs).getResult() - : builder.create(loc, lhs, rhs) - .getResult(); - break; - case tpu::ReductionKind::MAX: - result = is_int ? builder.create(loc, lhs, rhs) - .getResult() - : builder.create(loc, lhs, rhs) - .getResult(); - break; - case tpu::ReductionKind::MIN: - result = is_int ? builder.create(loc, lhs, rhs) - .getResult() - : builder.create(loc, lhs, rhs) - .getResult(); - break; - } - return result; - }; - auto reduction_status = reduced_vregs.EachStatus( - [&](const absl::Span red_idx, Value *const src_vreg) { - SmallVector src_idx(red_idx.begin(), red_idx.end()); - for (int i = 0; i < src_idx.size(); ++i) { - src_idx[i] += src_slice_start[i]; - } - const std::unique_ptr data_bounds = - src_layout.tileDataBounds(builder.getContext(), src_shape, - src_idx, ctx.target_shape, - {true, true}); - if (data_bounds == nullptr) { - // Op error has already been emitted inside tileDataBounds(). - return absl::UnknownError("Unable to obtain data bounds"); - } - Value vreg = *src_vreg; - // If replicated, we don't need to mask. - if (src_layout.offsets()[0].has_value() || - src_layout.offsets()[1].has_value()) { - // TODO(tlongeri): Maybe assemble/disassemble should take - // TypedValue and we could save casts here and - // elsewhere - FailureOr failure_or_vreg = - maskOOB(ctx, builder, cast>(*src_vreg), - *data_bounds, neutral); - if (failed(failure_or_vreg)) { - op.emitOpError("Failed to mask vreg"); - return absl::UnknownError(""); - } - vreg = failure_or_vreg.value(); - } - if (!acc_vreg.has_value()) { - acc_vreg = vreg; - } else { - acc_vreg = reduce_elementwise(*acc_vreg, vreg); - } - return absl::OkStatus(); - }); - TF_RETURN_IF_ERROR(reduction_status); - TPU_ASSERT_OP(acc_vreg.has_value()); - const bool is_double_replicated_double_reduced = - reduces[0] && reduces[1] && !src_layout.offsets()[0].has_value() && - !src_layout.offsets()[1].has_value(); - if (reduces[1]) { - if (src_layout.offsets()[1].has_value()) { - acc_vreg = builder.create( - multi_reduction_op->getLoc(), *acc_vreg, /* dim= */ 1, tpu_kind); - } else { - int64_t size_dim1 = src_layout.getImplicitTiledDims(src_shape, 1)[1]; - if (is_double_replicated_double_reduced) { - size_dim1 *= src_layout.getImplicitTiledDims(src_shape, 1)[0]; - } - switch (tpu_kind) { - case tpu::ReductionKind::SUM: - if (is_int) { - IntegerAttr size_attr = builder.getI32IntegerAttr(size_dim1); - TypedValue source_value = getFullVector( - builder, - getNativeVregType(builder.getI32Type(), ctx.target_shape), - size_attr); - acc_vreg = - builder.create(loc, *acc_vreg, source_value); - } else { - FloatAttr size_attr = builder.getF32FloatAttr(size_dim1); - TypedValue source_value = getFullVector( - builder, - getNativeVregType(builder.getF32Type(), ctx.target_shape), - size_attr); - acc_vreg = - builder.create(loc, *acc_vreg, source_value); - } - break; - // We don't need to do anything for other reduction kinds. - case tpu::ReductionKind::MAX: - case tpu::ReductionKind::MIN: - break; - } - } - } - if (reduces[0]) { - // Packed types are compressed along rows, so we need to reduce them - // within each 32-bit word. There's no performance penalty for doing - // this in 32-bit precision, so we take advantage of it. - Type acc_vreg_ty = acc_vreg->getType(); - if (acc_layout.packing() > 1) { - Type vreg_ty_32 = nullptr; - if (acc.getType().getElementType().isBF16()) { - vreg_ty_32 = - getNativeVregType(builder.getF32Type(), ctx.target_shape); - } else { - multi_reduction_op.emitOpError( - "Not implemented: Unsupported reduction dtype"); - return absl::UnknownError(""); - } - Value acc_vreg_32 = builder.create( - loc, vreg_ty_32, *acc_vreg, 0, tpu::PackFormat::kInterleaved); - for (int i = 1; i < acc_layout.packing(); ++i) { - Value acc_vreg_part_32 = builder.create( - loc, vreg_ty_32, *acc_vreg, i, tpu::PackFormat::kInterleaved); - acc_vreg_32 = reduce_elementwise(acc_vreg_32, acc_vreg_part_32); - } - acc_vreg = acc_vreg_32; - } - // At this point acc_vreg is always 32-bit. - if (src_layout.offsets()[0].has_value()) { - acc_vreg = builder.create( - multi_reduction_op->getLoc(), *acc_vreg, 0, tpu_kind); - } else if (!is_double_replicated_double_reduced) { - int64_t size_dim0 = src_layout.getImplicitTiledDims(src_shape, 1)[0]; - switch (tpu_kind) { - case tpu::ReductionKind::SUM: - if (is_int) { - IntegerAttr size_attr = builder.getI32IntegerAttr(size_dim0); - TypedValue source_value = getFullVector( - builder, - getNativeVregType(builder.getI32Type(), ctx.target_shape), - size_attr); - acc_vreg = - builder.create(loc, *acc_vreg, source_value); - } else { - FloatAttr size_attr = builder.getF32FloatAttr(size_dim0); - TypedValue source_value = getFullVector( - builder, - getNativeVregType(builder.getF32Type(), ctx.target_shape), - size_attr); - acc_vreg = - builder.create(loc, *acc_vreg, source_value); - } - break; - case tpu::ReductionKind::MAX: - case tpu::ReductionKind::MIN: - break; - } - } - // We pack the final result back into the original type. - if (acc_layout.packing() > 1) { - SmallVector positions(acc_layout.packing()); - std::iota(positions.begin(), positions.end(), - static_cast(0)); - SmallVector parts(acc_layout.packing(), *acc_vreg); - acc_vreg = builder.create( - loc, acc_vreg_ty, parts, - builder.getDenseI32ArrayAttr(positions), - tpu::PackFormat::kInterleaved); - } - } - *dst_vreg = *acc_vreg; - return absl::OkStatus(); - }); - if (!all_results_ok.ok()) { - return failure(); - } - multi_reduction_op->replaceAllUsesWith( - assemble(builder, res_ty, dst_layout, dst_vregs, ctx.target_shape)); - multi_reduction_op->erase(); - return success(); -} - -LogicalResult vector_shape_cast_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - TPU_ASSERT_EQ_OP(layouts_in.size(), 1); - TPU_ASSERT_EQ_OP(layouts_out.size(), 1); - TPU_ASSERT_OP(layouts_in.front().has_value()); - TPU_ASSERT_OP(layouts_out.front().has_value()); - using Tiling = std::array; - const VectorLayout &layout_in = *layouts_in.front(); - const VectorLayout &layout_out = *layouts_out.front(); - TPU_ASSERT_EQ_OP( - layout_in.bitwidth(), - layout_out.bitwidth()); // This should be guaranteed through MLIR - // verifier plus our layoutIsValidForValue check - ImplicitLocOpBuilder builder(op.getLoc(), &op); - auto shape_cast_op = cast(op); - const VectorType src_ty = shape_cast_op.getSourceVectorType(); - const ArrayRef src_shape = src_ty.getShape(); - const VectorType dst_ty = shape_cast_op.getResultVectorType(); - const ArrayRef dst_shape = dst_ty.getShape(); - bool no_op = false; - const std::array src_tiled_dims = - layout_in.getImplicitTiledDims(src_shape, 1); - const std::array dst_tiled_dims = - layout_out.getImplicitTiledDims(dst_shape, 1); - const std::array src_vreg_slice = - layout_in.vregSlice(ctx.target_shape); - const std::array dst_vreg_slice = - layout_out.vregSlice(ctx.target_shape); - if (layout_in.tiling() == layout_out.tiling() && - layout_in.offsets() == layout_out.offsets() && - src_tiled_dims == dst_tiled_dims) { - no_op = true; - } else if ( // Fold or unfold sublane dim, but keeping a whole number of - // vregs. - layout_in.offsets()[0] == 0 && - layout_in.offsets() == layout_out.offsets() && - layout_in.tiling() == layout_out.tiling() && - dst_tiled_dims[1] == src_tiled_dims[1] && - dst_tiled_dims[0] % dst_vreg_slice[0] == 0 && - src_tiled_dims[0] % src_vreg_slice[0] == 0) { - no_op = true; - } else if (layout_in.offsets() == layout_out.offsets() && - layout_in.offsets() == LayoutOffsets{0, 0} && - layout_in.tiling()[0] == 1 && - layout_out.hasNativeTiling(ctx.target_shape) && - dst_tiled_dims[1] == dst_vreg_slice[1] && - dst_tiled_dims[0] % dst_vreg_slice[0] == 0 && - src_tiled_dims[1] % src_vreg_slice[1] == 0) { - // Shapecast (..., m * 128 * packing) -> (..., 128). - no_op = true; - } else if (layout_in.offsets() == LayoutOffsets{0, 0} && - layout_out.offsets() == LayoutOffsets{0, 0} && - layout_in.hasNativeTiling(ctx.target_shape) && - layout_out.tiling()[0] == 1 && - src_tiled_dims[1] == src_vreg_slice[1] && - src_tiled_dims[0] % src_vreg_slice[0] == 0 && - dst_tiled_dims[1] % dst_vreg_slice[1] == 0) { - // Shapecast (..., 128) -> (..., m * 128 * packing). - no_op = true; - } else if (layout_in.offsets() == LayoutOffsets{0, 0} && - layout_out.offsets() == LayoutOffsets{0, 0} && - layout_in.tiling()[0] == 1 && layout_out.tiling()[0] == 1 && - src_vreg_slice[1] == dst_vreg_slice[1] && - src_tiled_dims[1] % src_vreg_slice[1] == 0 && - dst_tiled_dims[1] % dst_vreg_slice[1] == 0) { - no_op = true; - } - FAILUREOR_ASSIGN_OR_RETURN( - xla::Array src_vregs, - disassemble(builder, layout_in, shape_cast_op.getSource(), - ctx.target_shape, /*use_implicit_shape=*/true)); - auto getDstVregs = [&]() -> FailureOr> { - if (no_op) { - xla::Array dst_vregs_local = src_vregs; - dst_vregs_local.Reshape( - layout_out.tileArrayImplicitShape(dst_shape, ctx.target_shape)); - return dst_vregs_local; - } else if (dst_tiled_dims == std::array{src_tiled_dims[1], 1} && - layout_in.bitwidth() == 32 && - layout_in.hasNativeTiling(ctx.target_shape) && - layout_in.tiling() == layout_out.tiling() && - (!layout_in.offsets()[1].has_value() || - *layout_in.offsets()[1] % ctx.target_shape[0] == - layout_out.offsets()[0] || - *layout_in.offsets()[1] + src_tiled_dims[1] <= - ctx.target_shape[1])) { - FAILUREOR_ASSIGN_OR_RETURN( - xla::Array dst_vregs_local, - insertImplicitMinorDimension(ctx, builder, op.getLoc(), src_vregs, - layout_in.implicitShape(src_shape), - layout_in, layout_out.offsets())); - // Now, reshape the major axes of the vreg array. - dst_vregs_local.Reshape( - layout_out.tileArrayImplicitShape(dst_shape, ctx.target_shape)); - return dst_vregs_local; - } else { - return shape_cast_op.emitOpError( - "Not implemented: Unsupported vector.shape_cast: ") - << *shape_cast_op; - } - }; - FAILUREOR_ASSIGN_OR_RETURN(const xla::Array dst_vregs, getDstVregs()); - shape_cast_op->replaceAllUsesWith(assemble(builder, dst_ty, layout_out, - dst_vregs, ctx.target_shape, - /*use_implicit_shape=*/true)); - shape_cast_op->erase(); - return success(); -} - -template -LogicalResult vector_store_impl(RewriteContext &ctx, Op store_op, - const VectorLayout &to_store_layout, - TypedValue store_mask = nullptr) { - Operation &op = *(store_op.getOperation()); - MLIRContext *const mlir_ctx = store_op.getContext(); - ImplicitLocOpBuilder builder(op.getLoc(), &op); - const VectorType ty = store_op.getValueToStore().getType(); - const auto memref_ty = getMemRefType(store_op.getBase()); - if (!ty.getRank()) { - return op.emitOpError("Not implemented: scalar stores to vmem"); - } - const bool is_1d = ty.getRank() == 1; - VectorLayout::ImplicitDim expected_dim = - is_1d ? VectorLayout::ImplicitDim::kSecondMinor - : VectorLayout::ImplicitDim::kNone; - if (to_store_layout.implicit_dim() != expected_dim) { - return op.emitOpError("Not implemented: unsupported layout"); - } - using Tiling = std::array; - FAILUREOR_ASSIGN_OR_RETURN( - const Tiling memref_tiling, - getMemRefTiling(store_op.getBase(), ctx.target_shape)); - if (memref_tiling != to_store_layout.tiling()) { - if (memref_tiling[0] == 1 && to_store_layout.tiling()[0] == 1 && - memref_tiling[1] % to_store_layout.tiling()[1] == 0) { - // In this case, it is valid to have to_store tiling (1, 128 * packing) - // when storing to a 1D memref. - } else if (to_store_layout.bitwidth() == 32 && - to_store_layout.tiling() == - std::array{1, ctx.target_shape[1]}) { - // In this case, it is valid to have to_store tiling (1, - // TARGET_SHAPE.lanes) because we strided-store one row to each tile of - // the memref. This can save us a bunch of stores! - // TODO(b/295393167): need to support strided store for bitwidth < 32. - } else if (to_store_layout.bitwidth() == 32 && - // We accept padding in the minormost dim, because - // apply_vector_layout will properly mask stores。 - canReinterpretToUntiledMemref( - store_op.getBase(), ctx.target_shape, - /*allow_minormost_padding=*/true)) { - // In this case, if the memref can be reinterpreted to untiled, it is - // valid to use any tiling for to_store. But using native tiling can save - // us a bunch of stores! - } else { - return op.emitOpError( - "Not implemented: dismatch in memref tiling and vector tiling in " - "store"); - } - } - - bool can_support_unaligned_dynamic_index = false; - bool must_support_unaligned_dynamic_index = false; - if (store_op.getIndices().size() > 1) { - auto second_minor_idx = store_op.getIndices().take_back(2)[0]; - if (failed(getIntConst(second_minor_idx, /*silent=*/true)) && - !isGuaranteedDivisible(second_minor_idx, memref_tiling[0])) { - must_support_unaligned_dynamic_index = true; - } - } - int64_t sublane_stride = 1; - // Handle special patterns that allow us to support more flexible loads. - if (to_store_layout.bitwidth() == 32 && - to_store_layout.tiling() == Tiling{1, ctx.target_shape[1]}) { - // Storing a single row on the 2nd minor dim from the (1, 128) layout. We - // can use sublane striding to perform the relayout as part of the store. - // The stride of store should be the number of sublanes in memref tile when - // store a single sublane. - sublane_stride = memref_tiling[0]; - can_support_unaligned_dynamic_index = true; - } else { - // Otherwise, if the memref has a short last dimension and is contiguous - // all the tiled layouts become equivalent, so we can handle unaligned - // dynamic indices without any special case. - auto mem_layout = dyn_cast(memref_ty.getLayout()); - if (!mem_layout) { - return op.emitOpError("Expected a tiled memref"); - } - auto tile_strides = mem_layout.getTileStrides(); - if (memref_ty.getShape().back() == ctx.target_shape[1] && - tile_strides.take_back(2) == ArrayRef{1, 1}) { - can_support_unaligned_dynamic_index = true; - } - } - - auto add_idx = [&](const Value &v, int64_t d) -> Value { - if (auto cst = getIntConst(v, /*silent=*/true); succeeded(cst)) { - return IdxConst(cst.value() + d, builder, op.getLoc()); - } - return builder.create(v, IdxConst(d, builder, op.getLoc())); - }; - - int tiled_dims = is_1d ? 1 : 2; - Value base_addr = store_op.getBase(); - SmallVector base_indices = store_op.getIndices(); - - if (must_support_unaligned_dynamic_index) { - if (!can_support_unaligned_dynamic_index) { - return op.emitOpError( - "Not implemented: dynamic store with unaligned indices"); - } - } else { - // Convert dynamic store to dynamic slice + static store. This saves us a - // bunch of scalar core work. - auto slice_result = sliceRef( - builder, store_op.getBase(), ty.getShape(), store_op.getIndices(), - ArrayRef(memref_tiling).take_back(tiled_dims)); - if (failed(slice_result)) { - return failure(); - } - base_addr = slice_result->first; - CHECK_EQ(slice_result->second.size(), base_indices.size()); - for (int i = 0; i < base_indices.size(); ++i) { - base_indices[i] = IdxConst(slice_result->second[i], builder, op.getLoc()); - } - } - - // TODO(jevinjiang): ideally we should update the base addr and use static - // indices even for the cases that can skip alignment check. This can save - // us a bunch of scalar core work. - auto tile_base_idxs = ArrayRef(base_indices).take_back(tiled_dims); - auto batch_base_idxs = ArrayRef(base_indices).drop_back(tiled_dims); - - FAILUREOR_ASSIGN_OR_RETURN( - xla::Array tiles, - disassemble(builder, to_store_layout, store_op.getValueToStore(), - ctx.target_shape, /*use_implicit_shape=*/true)); - std::optional> tile_masks; - if (store_mask) { - FAILUREOR_ASSIGN_OR_RETURN( - tile_masks, disassemble(builder, to_store_layout, store_mask, - ctx.target_shape, /*use_implicit_shape=*/true)); - TPU_ASSERT_EQ_OP(tile_masks->dimensions(), tiles.dimensions()); - } - const int64_t ndims = ty.getRank(); - const auto base_s = is_1d ? nullptr : tile_base_idxs.front(); - const auto base_l = tile_base_idxs.back(); - const LayoutOffset sublane_offset = to_store_layout.offsets()[0]; - const LayoutOffset lane_offset = to_store_layout.offsets()[1]; - if (!sublane_offset.has_value() || !lane_offset.has_value()) { - return store_op.emitOpError( - "Not implemented: Replicated layout disallowed in vector store"); - } - const SmallVector stored_shape = - to_store_layout.implicitShape(ty.getShape()); - const std::array vreg_slice = - to_store_layout.vregSlice(ctx.target_shape); - const absl::Status status = - tiles.EachStatus([&](const absl::Span idx, - const Value tile) -> absl::Status { - const auto tile_mask = store_mask ? (*tile_masks)(idx) : nullptr; - const std::unique_ptr bounds = - to_store_layout.tileDataBounds(mlir_ctx, stored_shape, - toArrayRef(idx), ctx.target_shape); - const int64_t sidx = *(idx.end() - 2); - const int64_t lidx = *(idx.end() - 1); - SmallVector indices(ndims); - for (int64_t i = 0; i < batch_base_idxs.size(); ++i) { - indices[i] = add_idx(batch_base_idxs[i], idx[i]); - } - if (!is_1d) { - *(indices.end() - 2) = - add_idx(base_s, sidx * vreg_slice[0] - *sublane_offset); - } - *(indices.end() - 1) = - add_idx(base_l, lidx * vreg_slice[1] - *lane_offset); - const DenseBoolArrayAttr sublane_mask = - bounds->getSublaneMask(store_op->getContext(), ctx.target_shape); - const bool masks_subelements = - bounds->maskVariesAlong(Direction::kSubelements, ctx.target_shape); - if (bounds->maskVariesAlong(Direction::kLanes, ctx.target_shape) || - masks_subelements) { - auto failure_or_mask = - bounds->getVectorMask(builder, store_op.getLoc(), - ctx.hardware_generation, ctx.target_shape); - if (failed(failure_or_mask)) { - return absl::UnimplementedError("Failed to get vector mask"); - } - TypedValue mask = failure_or_mask.value(); - // Vmem stores don't support masking below 32-bit granularity, so we - // need to load and blend explicitly if needed. - if (masks_subelements) { - auto data = builder.create(tile.getType(), base_addr, - indices, sublane_mask, - /*sublane_stride=*/nullptr); - const bool mask_is_a_bitmask = - cast(mask.getType().getElementType()).getWidth() == - 32; - Value updated; - if (mask_is_a_bitmask) { - auto ones = builder.create( - mask.getType(), - DenseElementsAttr::get( - mask.getType(), - builder.getIntegerAttr(builder.getI32Type(), - APInt(32, 0xFFFFFFFF)))); - auto masked_tile = builder.create( - store_op.getLoc(), mask, - builder.create(mask.getType(), tile)); - auto mask_neg = builder.create(ones, mask); - auto masked_data = builder.create( - mask_neg, - builder.create(mask.getType(), data)); - updated = builder.create( - tile.getType(), - builder.create(masked_data, masked_tile)); - } else { - updated = builder.create(mask, tile, data); - } - builder.create( - updated, base_addr, indices, sublane_mask, tile_mask, - /*sublane_stride=*/builder.getI32IntegerAttr(sublane_stride)); - } else { - builder.create( - tile, base_addr, indices, sublane_mask, - tile_mask - ? builder.create(mask, tile_mask).getResult() - : mask, - /*sublane_stride=*/builder.getI32IntegerAttr(sublane_stride)); - } - } else { - builder.create( - tile, base_addr, indices, sublane_mask, tile_mask, - /*sublane_stride=*/builder.getI32IntegerAttr(sublane_stride)); - } - return absl::OkStatus(); - }); - if (!status.ok()) { - return failure(); - } - store_op->erase(); - return success(); -} - -LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - auto store_op = cast(op); - TPU_ASSERT_EQ_OP(layouts_out.size(), 0); - TPU_ASSERT_OP(layouts_in.front().has_value()); - TPU_ASSERT_OP(llvm::none_of(layouts_in.drop_front(), - [&](const Layout &l) { return l.has_value(); })); - return vector_store_impl(ctx, store_op, *layouts_in.front()); -} - -LogicalResult tpu_vector_store_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - auto store_op = cast(op); - TPU_ASSERT_EQ_OP(layouts_out.size(), 0); - TPU_ASSERT_OP(layouts_in.front().has_value()); - auto other_layouts_in = layouts_in.drop_front(); - if (store_op.getMask()) { - TPU_ASSERT_EQ_OP(layouts_in.front(), layouts_in.back()); - other_layouts_in = other_layouts_in.drop_back(); - } - TPU_ASSERT_OP(llvm::none_of(other_layouts_in, - [&](const Layout &l) { return l.has_value(); })); - return vector_store_impl(ctx, store_op, *layouts_in.front(), - store_op.getMask()); -} - -LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - TPU_ASSERT_EQ_OP(layouts_in.size(), 1); - TPU_ASSERT_EQ_OP(layouts_out.size(), 1); - TPU_ASSERT_OP(layouts_in.front().has_value()); - TPU_ASSERT_OP(layouts_out.front().has_value()); - const VectorLayout &layout_in = *layouts_in.front(); - const VectorLayout &layout_out = *layouts_out.front(); - if (layout_in.implicit_dim() != VectorLayout::ImplicitDim::kNone || - layout_in != layout_out) { - return op.emitOpError("Not implemented: Unsupported 2D layouts"); - } - ImplicitLocOpBuilder builder(op.getLoc(), &op); - auto transpose_op = cast(op); - VectorType src_ty = transpose_op.getSourceVectorType(); - VectorType dst_ty = transpose_op.getResultVectorType(); - const int64_t rank = src_ty.getRank(); - FAILUREOR_ASSIGN_OR_RETURN( - xla::Array src_vregs, - disassemble(builder, layout_in, transpose_op.getVector(), - ctx.target_shape)); - ArrayRef permutation = transpose_op.getPermutation(); - const auto tile_perm = permutation.take_back(2); - if (tile_perm != ArrayRef{rank - 2, rank - 1} && - tile_perm != ArrayRef{rank - 1, rank - 2}) { - return transpose_op->emitOpError( - "Not implemented: Unsupported permutation"); - } - { - SmallVector p(permutation); - p[rank - 2] = rank - 2; - p[rank - 1] = rank - 1; - src_vregs.TransposeDimensions(p); - } - if (tile_perm == ArrayRef{rank - 2, rank - 1}) { - transpose_op->replaceAllUsesWith( - assemble(builder, dst_ty, layout_out, src_vregs, ctx.target_shape)); - transpose_op.erase(); - return success(); - } - if (layout_in.offsets() != LayoutOffsets{0, 0} || - !layout_in.hasNativeTiling(ctx.target_shape)) { - return transpose_op->emitOpError( - "Not implemented: Non-native or offset layout unsupported"); - } - const int64_t transpose_unit_size = ctx.target_shape[1]; - if (ctx.hardware_generation < 4 && layout_in.bitwidth() != 32) { - return transpose_op->emitOpError( - "Not implemented: TPUs before v4 only support 32-bit transposes"); - } - xla::Array dst_vregs( - layout_out.tileArrayShape(dst_ty.getShape(), ctx.target_shape)); - const int packing = layout_in.packing(); - // Note that we checked for native tiling above. - const int64_t vregs_per_tile = transpose_unit_size / layout_in.tiling()[0]; - const SmallVector minor_perm{1, 0}; - const auto tile_ty = VectorType::get( - {transpose_unit_size, transpose_unit_size}, src_ty.getElementType()); - const auto batch_tile_ty_in = - VectorType::get({transpose_unit_size, transpose_unit_size * packing}, - src_ty.getElementType()); - const auto batch_tile_ty_out = - VectorType::get({transpose_unit_size * packing, transpose_unit_size}, - src_ty.getElementType()); - // For packed types, we can increase the XLU throughput by batching together - // multiple tiles. At the moment we always batch along columns, with the - // reasoning being that if all the tiles are fed into the MXU, then it's - // better if we end up with results that contribute to the same contraction. - const bool can_batch = - layout_in.bitwidth() == 16 && ctx.hardware_generation < 6; - auto doTranspose = [&](const ArrayRef batch_idx, - const int64_t src_row, const int64_t src_col, - const int64_t src_col_end, const VectorType tile_ty_in, - const VectorType tile_ty_out) { - SmallVector src_slice_starts; - src_slice_starts.reserve(rank); - src_slice_starts.append(batch_idx.begin(), batch_idx.end()); - src_slice_starts.append({src_row * vregs_per_tile, src_col}); - SmallVector src_slice_ends; - src_slice_ends.reserve(rank); - auto incremented_batch_idx = - map_range(batch_idx, [](int64_t i) { return i + 1; }); - src_slice_ends.append(incremented_batch_idx.begin(), - incremented_batch_idx.end()); - src_slice_ends.append({(src_row + 1) * vregs_per_tile, src_col_end}); - xla::Array src_tile_vregs = src_vregs.Slice( - src_slice_starts, src_slice_ends, - builder.create( - op.getLoc(), builder.getZeroAttr(src_vregs.begin()->getType()))); - // Drop leading singleton (batch) dimensions to have a shape that conforms - // with the vreg array shape specified by layout_in, as expected by assemble - src_tile_vregs.Reshape( - ArrayRef{vregs_per_tile, src_col_end - src_col}); - const Value src_tile = assemble(builder, tile_ty_in, layout_in, - src_tile_vregs, ctx.target_shape); - auto new_transpose_op = - builder.create(tile_ty_out, src_tile, minor_perm); - new_transpose_op->setAttr("out_layout", - builder.getAttr(layout_out)); - auto unroll_vectors_op = builder.create( - llvm::map_to_vector(src_tile_vregs, - [](Value v) { return v.getType(); }), - new_transpose_op); - SmallVector dst_slice_starts; - dst_slice_starts.reserve(rank); - dst_slice_starts.append(batch_idx.begin(), batch_idx.end()); - dst_slice_starts.append({src_col * vregs_per_tile, src_row}); - SmallVector dst_slice_ends; - dst_slice_ends.reserve(rank); - dst_slice_ends.append(incremented_batch_idx.begin(), - incremented_batch_idx.end()); - dst_slice_ends.append({src_col_end * vregs_per_tile, src_row + 1}); - updateSliceFromRange(dst_vregs, unroll_vectors_op.getResults(), - dst_slice_starts, dst_slice_ends); - }; - const int num_batch_dims = rank - 2; - const ArrayRef batch_sizes = - dst_ty.getShape().take_front(num_batch_dims); - SmallVector batch_idx(num_batch_dims); - const int64_t tile_rows = - xla::CeilOfRatio(*(src_ty.getShape().end() - 2), transpose_unit_size); - const int64_t num_col_tiles = - xla::CeilOfRatio(*(src_ty.getShape().end() - 1), transpose_unit_size); - do { - for (int64_t src_row = 0; src_row < tile_rows; ++src_row) { - if (can_batch) { - const int64_t num_batch_tiles = num_col_tiles / 2; - for (int64_t src_col = 0; src_col < num_batch_tiles; ++src_col) { - doTranspose(batch_idx, src_row, src_col * 2, (src_col + 1) * 2, - batch_tile_ty_in, batch_tile_ty_out); - } - if (num_col_tiles % 2 == 1) { - doTranspose(batch_idx, src_row, num_col_tiles - 1, num_col_tiles, - tile_ty, tile_ty); - } - } else { - for (int64_t src_col = 0; src_col < num_col_tiles; ++src_col) { - doTranspose(batch_idx, src_row, src_col, src_col + 1, tile_ty, - tile_ty); - } - } - } - } while (incrementIndex(batch_idx, batch_sizes)); - for (const Value v : dst_vregs) { - TPU_ASSERT_OP(v != nullptr); - } - transpose_op->replaceAllUsesWith( - assemble(builder, dst_ty, layout_out, dst_vregs, ctx.target_shape)); - transpose_op->erase(); - return success(); -} - -LogicalResult tpu_prng_random_bits_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - TPU_ASSERT_EQ_OP(layouts_in.size(), 0); - TPU_ASSERT_EQ_OP(layouts_out.size(), 1); - TPU_ASSERT_OP(layouts_out.front().has_value()); - - const VectorLayout &layout_out = *layouts_out.front(); - tpu::PRNGRandomBitsOp rng_op = cast(op); - if (layout_out != VectorLayout(32, {0, 0}, ctx.target_shape, - VectorLayout::ImplicitDim::kNone)) { - return op.emitOpError( - "Unsupported output layout for ") << rng_op->getName(); - } - OpBuilder builder(op.getContext()); - builder.setInsertionPointAfter(&op); - - VectorType vty = rng_op.getResult().getType(); - TPU_ASSERT_OP(vty.getElementType().isInteger()); - // Only 32-bit output supported currently. - TPU_ASSERT_OP(vty.getElementType().getIntOrFloatBitWidth() == 32); - xla::Array tiles( - layout_out.tileArrayShape(vty.getShape(), ctx.target_shape)); - VectorType tile_ty = VectorType::get(ctx.target_shape, vty.getElementType()); - tiles.Each([&](absl::Span tile_idxs, Value * v) { - *v = builder.create(op.getLoc(), tile_ty); - }); - const RollVectorsOp roll_vectors_op = - assemble(builder, vty, layout_out, tiles, ctx.target_shape); - rng_op->replaceUsesWithIf(roll_vectors_op, [&](OpOperand &operand) { - return operand.getOwner() != roll_vectors_op; - }); - rng_op->erase(); - return success(); -} - -const llvm::StringMap &rules() { - static const llvm::StringMap *rules = [] { - static auto rules = new llvm::StringMap{ - {arith::ConstantOp::getOperationName(), arith_constant_rule}, - {arith::ExtFOp::getOperationName(), arith_extf_rule}, - {arith::ExtSIOp::getOperationName(), arith_extsi_rule}, - {arith::ExtUIOp::getOperationName(), arith_extui_rule}, - {arith::TruncFOp::getOperationName(), arith_truncf_rule}, - {arith::TruncIOp::getOperationName(), arith_trunci_rule}, - {func::ReturnOp::getOperationName(), func_return_rule}, - {scf::ForOp::getOperationName(), scf_for_rule}, - {scf::WhileOp::getOperationName(), scf_while_rule}, - {scf::ConditionOp::getOperationName(), scf_condition_rule}, - {scf::IfOp::getOperationName(), scf_if_rule}, - {scf::YieldOp::getOperationName(), yield_rule}, - {tpu::YieldOp::getOperationName(), yield_rule}, - {tpu::RotateOp::getOperationName(), tpu_rotate_rule}, - {tpu::DynamicRotateOp::getOperationName(), tpu_dynamic_rotate_rule}, - {tpu::ConcatenateOp::getOperationName(), tpu_concatenate_rule}, - {tpu::IotaOp::getOperationName(), tpu_iota_rule}, - {tpu::GatherOp::getOperationName(), tpu_gather_rule}, - {tpu::DynamicGatherOp::getOperationName(), tpu_dynamic_gather_rule}, - {tpu::LoadOp::getOperationName(), tpu_load_rule}, - {tpu::StoreOp::getOperationName(), tpu_store_rule}, - {tpu::StridedLoadOp::getOperationName(), tpu_strided_load_rule}, - {tpu::StridedStoreOp::getOperationName(), tpu_strided_store_rule}, - {tpu::VectorStoreOp::getOperationName(), tpu_vector_store_rule}, - {tpu::MatmulOp::getOperationName(), tpu_matmul_rule}, - {tpu::RegionOp::getOperationName(), tpu_region_rule}, - {tpu::BitcastOp::getOperationName(), tpu_bitcast_rule}, - {tpu::TraceOp::getOperationName(), tpu_trace_rule}, - {tpu::AssumeLayoutOp::getOperationName(), tpu_assume_layout_rule}, - {tpu::PRNGRandomBitsOp::getOperationName(), tpu_prng_random_bits_rule}, - {tpu::RelayoutOp::getOperationName(), tpu_relayout_rule}, - {tpu::FPToSIOp::getOperationName(), tpu_fptosi_rule}, - {vector::BroadcastOp::getOperationName(), vector_broadcast_rule}, - {vector::ExtractOp::getOperationName(), vector_extract_rule}, - {vector::LoadOp::getOperationName(), vector_load_rule}, - {vector::MultiDimReductionOp::getOperationName(), - vector_multi_reduction_rule}, - {vector::ExtractStridedSliceOp::getOperationName(), - vector_extract_strided_slice_rule}, - {vector::ShapeCastOp::getOperationName(), vector_shape_cast_rule}, - {vector::StoreOp::getOperationName(), vector_store_rule}, - {vector::TransposeOp::getOperationName(), vector_transpose_rule}}; - - for (const auto &[name, rule] : mlir::tpu::extensions::rules()) { - rules->insert({name, rule}); - } - return rules; - }(); - return *rules; -} - -// Determines whether we should handle bank conflict for the given stride and -// max_sublane_offset. -// -// See `handleBankConflict` for how this is done. -bool shouldHandleBankConflict(const ApplyVectorLayoutContext &ctx, - int32_t stride, int max_sublane_offset) { - return ctx.hardware_generation >= 4 && ctx.vmem_banks > 0 && - ctx.vmem_banks < stride * ctx.target_shape[0] && - ctx.max_shuffle_sublane_offset > 0 && - ctx.max_shuffle_sublane_offset >= max_sublane_offset; -} - -// Handles load/store bank conflict by adding one extra sublane to stride and -// adjusting sublane offsets accordingly. -// -// For example, when store stride is 4 and load sublane offsets are -// [0, 1, 2, 3, 4, 5, 6, 7], the store bank conflict can be avoided by changing -// stride to 5 and sublane offsets to [0, 1, 2, 3, 5, 6, 7, 8]. -void handleBankConflict(int32_t &stride, absl::Span sublane_offsets) { - // Add one extra sublane to stride to avoid bank conflict. - for (int i = 0; i < sublane_offsets.size(); ++i) { - // Adjust sublane offsets to match the stride. - sublane_offsets[i] += i / stride; - } - ++stride; -} - -} // namespace - -RollVectorsOp assemble(OpBuilder &builder, VectorType vty, - const VectorLayout &layout, - const xla::Array &vals, - const std::array target_shape, - const bool use_implicit_shape) { - // TODO(tlongeri): Maybe just add a parameter to tileArrayShape instead of - // having `tileArrayShape` and `tileArrayImplicitShape`. - SmallVector vreg_array_shape = - layout.tileArrayImplicitShape(vty.getShape(), target_shape); - if (!use_implicit_shape) { - layout.eraseImplicit(vreg_array_shape); - } - CHECK(vals.dimensions() == vreg_array_shape); - CHECK_GT(vals.num_elements(), 0); - Location loc = vals.begin()->getLoc(); - auto op = - builder.create(loc, vty, XlaArrayToFlatArrayRef(vals)); - op->setAttr("out_layout", builder.getAttr(ArrayRef{ - builder.getAttr(layout)})); - return op; -} - -// Disassemble an MLIR vector into an ndarray of native vectors. -// -// Args: -// layout: The layout of val. Used to determine the unrolling into -// native-shaped vectors. -// val: Value to disassemble. Must be of type VectorType. -// -// Returns: -// An ndarray of MLIR values representing the tiling of val given by layout. -FailureOr> disassemble( - OpBuilder &builder, const VectorLayout &layout, - const TypedValue val, const std::array target_shape, - const bool use_implicit_shape) { // TODO(tlongeri): Remove default - const auto vty = val.getType(); - const auto op_result = dyn_cast(val); - if (op_result == nullptr) { - return failure(); - } - Operation *const op = op_result.getOwner(); - const unsigned res_idx = op_result.getResultNumber(); - FAILUREOR_ASSIGN_OR_RETURN(const SmallVector def_layouts, - getOutLayouts(*op, target_shape)); - const Layout def_layout = def_layouts[res_idx]; - TPU_ASSERT_LOC(val.getLoc(), def_layout.has_value()); - TPU_ASSERT_LOC(val.getLoc(), - def_layout->generalizes(layout, vty.getShape(), target_shape)); - auto layout_product = - xla::Product(layout.tileArrayShape(vty.getShape(), target_shape)); - auto def_layout_product = - xla::Product(def_layout->tileArrayShape(vty.getShape(), target_shape)); - TPU_ASSERT_LOC(val.getLoc(), layout_product == def_layout_product); - // TODO(tlongeri): Maybe just add a parameter to tileArrayShape instead of - // having `tileArrayShape` and `tileArrayImplicitShape`. - SmallVector layout_shape = - layout.tileArrayImplicitShape(vty.getShape(), target_shape); - if (!use_implicit_shape) { - layout.eraseImplicit(layout_shape); - } - if (auto roll_vectors_op = dyn_cast(op)) { - return XlaArrayFromShapeAndValues(layout_shape, - roll_vectors_op->getOperands()); - } - return op->emitOpError("Not implemented: ") << val; -} - -// Assembles a destination tile using partial data from rotated vregs using a -// divide-and-conquer strategy. -// -// Arguments: -// rotated_row_vregs: A row of rotated vregs, from which destination tile(s) -// is/are to be selected to assemble a new vreg. -// src_layout: The source layout. -// start_src_col: The first rotated vreg in the row of rotated vregs to -// process. -// end_src_col: The last rotated vreg in the row of rotated vreg to process. -// first_dst_tile_sublane_offset: Sublane offset where the first dst tile to -// be -// selected starts. -// dst_layout: Destination layout, based on which retiling is being performed. -// hw_generation: The generation of a target hardware. -// -// Returns: -// A new vreg assembled from dst tiles stored in given rotated vregs. -Value selectTilesFromRotatedRowVregs( - OpBuilder &builder, const ArrayRef &rotated_row_vregs, - const int64_t start_src_col, const int64_t end_src_col, - const int64_t first_dst_tile_sublane_offset, const VectorLayout &dst_layout, - const std::array target_shape) { - CHECK_LE(start_src_col, end_src_col); - CHECK_LE(start_src_col, end_src_col); - if (start_src_col == end_src_col) { - return rotated_row_vregs[start_src_col]; - } - const int64_t mid_src_col = start_src_col + (end_src_col - start_src_col) / 2; - - Value left_partial_vreg = selectTilesFromRotatedRowVregs( - builder, rotated_row_vregs, start_src_col, mid_src_col, - first_dst_tile_sublane_offset, dst_layout, target_shape); - Location loc = left_partial_vreg.getLoc(); - - const int64_t left_tiles_count = mid_src_col - start_src_col + 1; - const int64_t right_first_dst_tile_sublane_offset = - (first_dst_tile_sublane_offset + - left_tiles_count * dst_layout.sublanesPerTile(target_shape)) % - target_shape[0]; - - Value right_partial_vreg = selectTilesFromRotatedRowVregs( - builder, rotated_row_vregs, mid_src_col + 1, end_src_col, - right_first_dst_tile_sublane_offset, dst_layout, target_shape); - - const IntegerType i1 = builder.getI1Type(); - // We never need to select partial sublanes, even for packed data. - const auto mask_vreg_ty = VectorType::get(target_shape, i1); - auto i32_vreg = VectorType::get(target_shape, builder.getI32Type()); - auto select_32bit = [&](Value sublane_mask, Value left, Value right) { - // Always do the selects on 32-bit granularity for maximum HW compatibility. - Type vreg_ty = left.getType(); - if (dst_layout.packing() != 1) { - left = builder.create(loc, i32_vreg, left); - right = builder.create(loc, i32_vreg, right); - } - Value result = - builder.create(loc, sublane_mask, left, right); - if (dst_layout.packing() != 1) { - result = builder.create(loc, vreg_ty, result); - } - return result; - }; - - auto boundIdxConst = std::bind(IdxConst, std::placeholders::_1, builder, - left_partial_vreg.getLoc()); - if (first_dst_tile_sublane_offset < right_first_dst_tile_sublane_offset) { - // The useful data sublanes in left vregs do not wrap around in vreg. - // For e.g. consider (2,128) destination tiling and we are trying to merge - // two vregs as follows: - // - // vreg 0: vreg 1: - // x x x x x dst_tile_2 - // x x x x x dst_tile_3 - // dst_tile_4 x x x x x - // dst_tile_5 x x x x x - // dst_tile_6 x x x x x - // dst_tile_7 x x x x x - // x x x x x dst_tile_0 - // x x x x x dst_tile_1 - // - // In the above case, the data we want to select from vreg 1 wraps around, - // whereas vreg 0 useful data is contiguous. It is easier to create '1' mask - // for vreg 0. - auto sublanes_mask = builder.create( - left_partial_vreg.getLoc(), mask_vreg_ty, - ArrayRef{boundIdxConst(first_dst_tile_sublane_offset), - boundIdxConst(0)}, - ArrayRef{boundIdxConst(right_first_dst_tile_sublane_offset), - boundIdxConst(target_shape[1])}); - return select_32bit(sublanes_mask, left_partial_vreg, right_partial_vreg); - } - - auto sublanes_mask = builder.create( - left_partial_vreg.getLoc(), mask_vreg_ty, - ArrayRef{boundIdxConst(right_first_dst_tile_sublane_offset), - boundIdxConst(0)}, - ArrayRef{boundIdxConst(first_dst_tile_sublane_offset), - boundIdxConst(target_shape[1])}); - return select_32bit(sublanes_mask, right_partial_vreg, left_partial_vreg); -} - -// Retiles across vregs to match the destination layout when the sublane tiling -// dimension is reduced. -// -// Arguments: -// value_shape: The shape of the value which needs to be retiled in vregs. -// src: The source layout. -// src_vreg_array: An array of vregs storing source tiles (with implicit -// shape). -// dst_layout: The destination layout, with reduced sublane dimension, based -// on -// which the retiling will be performed. -// hw_generation: The generation of a target hardware. -// -// Returns: -// A new array of vregs that store tiles based on the destination layout. -xla::Array retileToReducedSublanes( - OpBuilder &builder, const ArrayRef value_shape, - const VectorLayout &src_layout, const xla::Array &src_vreg_array, - const VectorLayout &dst_layout, const std::array target_shape) { - const int64_t dst_tiling_sublane = dst_layout.tiling()[0]; - CHECK_LT(0, dst_tiling_sublane); - CHECK_LT(dst_tiling_sublane, src_layout.tiling()[0]); - CHECK(llvm::isPowerOf2_64(dst_tiling_sublane)); - - xla::Array dst_vreg_array( - dst_layout.tileArrayImplicitShape(value_shape, target_shape)); - - // We need to rotate each src tile in each src vreg once so that they can - // be merged to form new vregs. If a src vreg contains more than one src tile, - // it will be rotated once per src tile. Consider (8,512) tensor stored with - // layout (8,128) in a vreg array of shape (1, 4). Each src vreg - // contains one src tile in this case. Given, the destination layout is - // (2,128), each src tile is divided into 4 destination tiles as shown below: - // - // src_vreg_0_0: src_vreg_0_1: src_vreg_0_2: src_vreg_0_3: - // dst_tile_0_0_0 dst_tile_0_0_1 dst_tile_0_0_2 dst_tile_0_0_3 - // dst_tile_1_0_0 dst_tile_1_0_1 dst_tile_1_0_2 dst_tile_1_0_3 - // dst_tile_2_0_0 dst_tile_2_0_1 dst_tile_2_0_2 dst_tile_2_0_3 - // dst_tile_3_0_0 dst_tile_3_0_1 dst_tile_3_0_2 dst_tile_3_0_3 - // - // In this example, each src tile in the src vreg is rotated by - // col * sublanes_per_tile to produce the following rotated src vregs: - // - // rot_src_vreg_0_0: rot_src_vreg_0_1: rot_src_vreg_0_2: rot_src_vreg_0_3: - // dst_tile_0_0_0 dst_tile_3_0_1 dst_tile_2_0_2 dst_tile_1_0_3 - // dst_tile_1_0_0 dst_tile_0_0_1 dst_tile_3_0_2 dst_tile_2_0_3 - // dst_tile_2_0_0 dst_tile_1_0_1 dst_tile_0_0_2 dst_tile_3_0_3 - // dst_tile_3_0_0 dst_tile_2_0_1 dst_tile_1_0_2 dst_tile_0_0_3 - - // If there were 2 src tiles in the src vreg, we would have rotated each src - // vreg twice, producing 2 rotated src vreg per src vreg. The rotation amount - // is calculated from the src and the dest tiling. - - const int64_t src_tiles_per_vreg = src_layout.tilesPerVreg(target_shape); - const int64_t dst_tiles_per_vreg = dst_layout.tilesPerVreg(target_shape); - const int64_t src_sublanes_per_tile = - src_layout.sublanesPerTile(target_shape); - const int64_t dst_sublanes_per_tile = - dst_layout.sublanesPerTile(target_shape); - // Each vreg may store more than one src tile. We may have to rotate a vreg, - // once for every src tile in the vreg. - SmallVector rotated_src_vreg_array_shape( - toArrayRef(src_vreg_array.dimensions())); - rotated_src_vreg_array_shape.back() *= src_tiles_per_vreg; - xla::Array rotated_src_vreg_array(rotated_src_vreg_array_shape); - - rotated_src_vreg_array.Each([&](const absl::Span rotated_idx, - Value *const rotated_src_vreg) { - const int64_t idx = rotated_idx.back(); - const int64_t tile_idx = idx % dst_tiles_per_vreg; - const int64_t dst_sublane = tile_idx * dst_sublanes_per_tile; - auto [src_col, src_tile_offset] = std::div(idx, src_tiles_per_vreg); - SmallVector src_vreg_idx(toArrayRef(rotated_idx)); - src_vreg_idx.back() = src_col; - Value src_vreg = src_vreg_array(src_vreg_idx); - const int64_t src_sublane = src_tile_offset * src_sublanes_per_tile; - int64_t rotate_amt = dst_sublane - src_sublane; - if (rotate_amt == 0) { - *rotated_src_vreg = src_vreg; - return; - } - if (rotate_amt < 0) { - rotate_amt += target_shape[0]; - } - *rotated_src_vreg = builder.create( - src_vreg.getLoc(), src_vreg, rotate_amt, - /*dimension=*/0, /*stride=*/nullptr, /*stride_dimension=*/nullptr); - }); - // Assemble output vregs using tiles from rotated vregs using select. - // Given, above example, destination vregs are then assembled as follows: - // dst_vreg_0_0: - // dst_tile_0_0_0 - // dst_tile_0_0_1 - // dst_tile_0_0_2 - // dst_tile_0_0_3 - - // dst_vreg_1_0: (Notice dst tiles are not in correct offset!) - // dst_tile_1_0_3 - // dst_tile_1_0_0 - // dst_tile_1_0_1 - // dst_tile_1_0_2 - - // dst_vreg_2_0: (Notice dst tiles are not in correct offset!) - // dst_tile_2_0_2 - // dst_tile_2_0_3 - // dst_tile_2_0_0 - // dst_tile_2_0_1 - - // dst_vreg_3_0: (Notice dst tiles are not in correct offset!) - // dst_tile_3_0_1 - // dst_tile_3_0_2 - // dst_tile_3_0_3 - // dst_tile_3_0_0 - - // Each destination vreg is assembled from destination tiles in multiple - // rotated src vregs. In the above example, if we wanted each destination tile - // to be in correct sublane offset in a rotated vreg, say rot_src_vreg_0_1, - // before assembling the destination tiles, we would have had to rotate - // src_vreg_0_1 four times, creating 4 rotated vregs (instead of 1) for each - // src vreg. In the above example, we instead rotated a src vreg src_vreg_0_1 - // only once to obtain rot_src_vreg_0_1 where the dst_tile_0_0_1 is in correct - // final sublane offset, i.e. 2. But notice the sublane offset of - // dst_tile_1_0_1 in the same rotated vreg. Its correct final destination - // sublane offset is 2, but in rot_src_vreg_0_1, its offset is 4. Its sublane - // offset is off by 2. We need to correct these sublane offsets in the final - // assembled dst vregs. A single rotation of each assembled dst vreg is needed - // to correct such sublane offsets. This strategy reduces the number of - // sublane rotations required. See comments below. - const int64_t tile_sublane_change_factor = - src_layout.tiling()[0] / dst_layout.tiling()[0]; - - dst_vreg_array.Each([&](absl::Span idx, - Value *const dst_vreg) { - const int64_t row = *(idx.end() - 2); - const int64_t col = *(idx.end() - 1); - auto [rotated_vreg_row, first_dst_tile_offset] = - std::div(row, tile_sublane_change_factor); - const int64_t first_dst_tile_sublane_offset = - first_dst_tile_offset * dst_sublanes_per_tile; - const int64_t src_vreg_array_col_start = col * dst_tiles_per_vreg; - const int64_t src_vreg_array_col_end = - std::min((col + 1) * dst_tiles_per_vreg, - rotated_src_vreg_array.dimensions().back()) - - 1; - - // TODO(tlongeri): Find a better way to slice that doesn't involve so - // copying so many index vectors and hopefully is more concise. Probably - // by expanding xla::Array (maybe could just expose calculate_index?). - SmallVector rotated_row_starts(toArrayRef(idx)); - *(rotated_row_starts.end() - 2) = rotated_vreg_row; - *(rotated_row_starts.end() - 1) = 0; - SmallVector rotated_row_ends(idx.size()); - for (size_t i = 0; i + 1 < rotated_row_ends.size(); ++i) { - rotated_row_ends[i] = rotated_row_starts[i] + 1; - } - *(rotated_row_ends.end() - 1) = rotated_src_vreg_array.dimensions().back(); - const xla::Array rotated_row_slice = - rotated_src_vreg_array.Slice(rotated_row_starts, rotated_row_ends); - const Value dst_tile = selectTilesFromRotatedRowVregs( - builder, /*rotated_row_vregs=*/ - ArrayRef(rotated_row_slice.begin(), rotated_row_slice.end()), - src_vreg_array_col_start, src_vreg_array_col_end, - first_dst_tile_sublane_offset, dst_layout, target_shape); - if (first_dst_tile_sublane_offset == 0) { - // No need to rotate. First dst tile is already at offset 0, which means - // rest of the dst tiles are also at correct sublane offset. - *dst_vreg = dst_tile; - } else { - // Fix the destination tile sublane offset by rotating assembled dest vreg - // once (See comments above). The dst vregs are fixed as follows: - // No rotation needed. - // dst_tile_0_0_0 - // dst_tile_0_0_1 - // dst_tile_0_0_2 - // dst_tile_0_0_3 - - // Rotated by -1 * (sublanes_per_tile=2) * (row=1): - // dst_tile_1_0_0 - // dst_tile_1_0_1 - // dst_tile_1_0_2 - // dst_tile_1_0_3 - - // Rotated by -1 * (sublanes_per_tile=2) * (row=2): - // dst_tile_2_0_0 - // dst_tile_2_0_1 - // dst_tile_2_0_2 - // dst_tile_2_0_3 - - // Rotated by -1 * (sublanes_per_tile=2) * (row=3): - // dst_tile_3_0_0 - // dst_tile_3_0_1 - // dst_tile_3_0_2 - // dst_tile_3_0_3 - *dst_vreg = builder.create( - dst_tile.getLoc(), dst_tile, - target_shape[0] - first_dst_tile_sublane_offset, /*dimension=*/0, - /*stride=*/nullptr, /*stride_dimension=*/nullptr); - } - }); - return dst_vreg_array; -} - - -// Copy one sublane from a vreg to another vreg. -// -// Arguments: -// src_vreg: The source vreg to copy a sublane from. -// src_sl_idx: The sublane index in src_vreg to copy from. -// dst_vreg: The base vreg to copy the sublane into. May be null. -// dst_sl_idx: The sublane index in the result. -// -// Returns: -// A new dst_vreg with the copied sublane. -Value copy_one_sublane(OpBuilder &builder, Value src_vreg, int src_sl_idx, - Value dst_vreg, int dst_sl_idx, - const std::array target_shape) { - src_vreg = builder.create( - src_vreg.getLoc(), src_vreg, - /*amount=*/(dst_sl_idx - src_sl_idx + target_shape[0]) % target_shape[0], - /*dimension=*/0, /*stride=*/nullptr, /*stride_dimension=*/nullptr); - if (dst_vreg) { - auto boundIdxConst = - std::bind(IdxConst, std::placeholders::_1, builder, src_vreg.getLoc()); - const int bitwidth = - cast(src_vreg.getType()).getElementTypeBitWidth(); - CHECK_EQ(bitwidth, - cast(dst_vreg.getType()).getElementTypeBitWidth()); - const VectorType vmask_ty = - getNativeVregOrVmaskType(builder.getI1Type(), bitwidth, target_shape); - auto sublanes_mask = builder.create( - src_vreg.getLoc(), vmask_ty, - ValueRange{boundIdxConst(dst_sl_idx), boundIdxConst(0)}, - ValueRange{boundIdxConst(dst_sl_idx + 1), - boundIdxConst(target_shape[1])}); - src_vreg = builder.create(src_vreg.getLoc(), sublanes_mask, - src_vreg, dst_vreg); - } - return src_vreg; -} - -// This function is based on tpu_rotate_rule. It applies a shift of amount to -// a given dim. A major difference is that it "overflows", i.e. if the shift -// amount is such that it pushes us into a new vreg, we create a new vreg and -// fill it in with the remaining rows. -// -// The shift is the difference between layout_in and layout_out, on the -// given dim. -FailureOr> tpu_rotate_with_overflow( - OpBuilder &builder, const std::array target_shape, - const Location loc, const VectorType vty, xla::Array in_tiles, - int64_t dim, const VectorLayout &layout_in, - const LayoutOffsets offsets_out) { - if (!layout_in.hasNativeTiling(target_shape)) { - return emitError(loc, "Not implemented: non-native tiling for layout"); - } - if (layout_in.bitwidth() != 32) { - return emitError(loc, - "Not implemented: multi-row shift with " - "bitwidth != 32"); - } - // TODO(apaszke,mvoz): Just use offsets_out instead of this. - VectorLayout layout_out(layout_in.bitwidth(), offsets_out, layout_in.tiling(), - layout_in.implicit_dim()); - - int64_t tiling_dim = dim - (in_tiles.num_dimensions() - 2); - if (tiling_dim != 0) { - return emitError(loc, - "Rotate with overflow untested for " - "dim != 0"); - } - auto amount = - *layout_out.offsets()[tiling_dim] - *layout_in.offsets()[tiling_dim]; - - SmallVector dst_tiles_shape = - layout_out.tileArrayImplicitShape(vty.getShape(), target_shape); - - const VectorType res_vreg_ty = - getNativeVregType(vty.getElementType(), target_shape); - - xla::Array out_tiles(dst_tiles_shape); - - // We update the result vregs in the following way: - // - If the offset is positive, write the first tile as is, if the offset - // is negative, blend it with the next tile. - // - Blend the rest of the tiles with the prior (positive offset) or next - // (negative offset) tile. - // - (In positive cases, we can get an extra vreg (overflow)) we write the - // remaining tiles. - // This only happens if the original input vreg size is smaller than the - // result vreg size (an offset) can "push" us into a new vreg. - // - // Ex: (30, 128), starting offset 0, shift by 6, native tiling (8, 128) - // The input is (4, 1), where the first 3 vregs are full (0-24) - // and the last vreg is filled in rows 0-6. When we offset it by 6, we - // need a 4th vreg, as now vreg 0 is filled in 6-8 (2 total), vreg 1, 2, 3 - // are filled in fully (8-16, 16-24, 24-32) (2 + 24 total), and vreg 4 is - // filled in 0-4. (2 + 24 + 4 = 30). - - // Negative offset amount means we: - // - // Ex 1: (30, 128), input offset 6, shift by -2, native tiling (8, 128) - // (The result of the last example, for simplicity). In this case, we have - // (5, 1) vregs as decribed above. Because the shift does not cause us to - // shift back from the 5th vreg, we still need it. In such a case, the result - // vreg is still (5, 1). - // - // - Write the first vreg as is. - // - The next vregs are blended with the prior one (except the last), - // where we blend by the shift amount. Ex: Vreg 1 goes from 6-8 to 4-8, - // pulling 2 rows from the next vreg. - // - The last tile is masked to only write the remaining rows. - // Ex: Vreg 4 goes from 0-4 to 0-2. - // - // Ex 2: (30, 128), starting offset 6, shift by -6, native tiling (8, 128) - // In this case, we have (5, 1) vregs as described above. Because the shift - // causes us to shift back from the 5th vreg, we don't need it anymore. - // In such a case, the result vreg is (4, 1). - // - // - All vregs are blended with the next one (except the last), - // where we blend by the shift amount. Ex: Vreg 1 goes from 6-8 to 0-8, - // pulling 6 rows from the next vreg. - // - The last tile is discarded - it was fully subsumed by the prior blends. - // - // Ex 3: (30, 128), starting offset 0, shift by -6, native tiling (8, 128) - // In this case, we have (4, 1) vregs as described above. - // In such a case, the result vreg is (4, 1), where the first vreg is filled - // in rows 2-8 (6), and vregs 1 and 2 are filled in fully (8-16, 16-24), and - // vreg 3 is filled in rows 0-6. - // - // NOTE - in such cases, where the abs(shift) in a negative shift > starting - // offset, we can actually implement this as a positive shift of the delta - // from the native tile size. - // in the example above, the delta is 8 - 6 + 0 = 2. The resulting vregs are - // the same as if we had shifted by 2, starting at offset 0. - // - // Another example to demonstrate the point: - // Ex 4: (30, 128), starting offset 2, shift by -4, native tiling (8, 128) - // In this case, we start with (4, 1) vregs as described above. - // (2-8)(8-16)(16-24)(0-4). Shifting by -4 is the same as 8 - 4 + 2 = 6. - // So we can just shift by 6, starting at offset 0. - // Vreg 0 is filled in 6-8 (2 total), vreg 1, 2 and 3 are filled in fully - // (8-16, 16-24, 24-32) (2 + 24 total = 26) vreg 4 is filled with the - // remainder, 0-4 (30 total). - // - // This means that no matter what the shift is, we should always - // rotate and compute the shift amount in such a way that the first input - // vreg is the first output vreg. - - // Compute the mask for the blend. - // Positive blends blend "forward" and negative blends blend "backward". - auto mask_val = amount; - auto vreg_rot_amount = amount; - if (amount < 0) { - mask_val = layout_in.tiling()[tiling_dim] - std::abs(amount); - vreg_rot_amount += target_shape[tiling_dim]; - } - auto boundIdxConst = std::bind(IdxConst, std::placeholders::_1, builder, loc); - auto mask = builder.create( - loc, VectorType::get(target_shape, builder.getI1Type()), - ValueRange{boundIdxConst(0), boundIdxConst(0)}, - ValueRange{boundIdxConst(mask_val), boundIdxConst(target_shape[1])}); - - // Actually do the rotation. - in_tiles.Each([&](absl::Span idxs, Value *v) { - if (dim >= in_tiles.num_dimensions() - 2) { - *v = builder.create(loc, res_vreg_ty, in_tiles(idxs), - vreg_rot_amount, tiling_dim, nullptr, - nullptr); - } - }); - - // Walk the result tiles. - // TODO(mvoz): There is a micro-optimization here where we can avoid - // allocating blend indices per vreg. - out_tiles.Each([&](absl::Span idxs, Value *v) { - if (idxs[dim] == 0) { - // A negative shift amount means we need to blend the first tile with the - // next one, but only if we're not at the end of the input. - if (amount < 0 && (idxs[dim] + 1 < in_tiles.dim(dim))) { - SmallVector next_idx = {idxs.begin(), idxs.end()}; - next_idx[dim] = idxs[dim] + 1; - *v = builder.create(loc, mask, in_tiles(idxs), - in_tiles(next_idx)); - } else { - // Positive shift, or negative shift at the end of the input. - *v = in_tiles(idxs); - } - } else if (idxs[dim] < in_tiles.dim(dim)) { - // write the rest as blended up to the end of the input - if (amount < 0) { - if (idxs[dim] + 1 < in_tiles.dim(dim)) { - SmallVector next_idx = {idxs.begin(), idxs.end()}; - next_idx[dim] = idxs[dim] + 1; - *v = builder.create(loc, mask, in_tiles(idxs), - in_tiles(next_idx)); - } else { - // Nothing to blend with, just write the last tile. - *v = in_tiles(idxs); - } - } else { - SmallVector prior_idx = {idxs.begin(), idxs.end()}; - prior_idx[dim] = idxs[dim] - 1; - *v = builder.create(loc, mask, in_tiles(prior_idx), - in_tiles(idxs)); - } - } else { - // write trailing if it's there (positive shift, increasing vreg count) - // Use the last prior - SmallVector prior_idx = {idxs.begin(), idxs.end()}; - prior_idx[dim] = idxs[dim] - 1; - *v = in_tiles(prior_idx); - } - }); - - return out_tiles; -} - -void rotateVregs(OpBuilder &builder, xla::Array &vregs, - const int64_t amount, const int dimension) { - if (amount != 0) { - vregs.Each([&](absl::Span idx, Value *vreg) { - CHECK(vreg); - *vreg = builder - .create(vreg->getLoc(), *vreg, - /*amount=*/amount, - /*dimension=*/dimension, - /*stride=*/nullptr, - /*stride_dimension=*/nullptr) - .getResult(); - }); - } -}; - -void rotateSublanes(OpBuilder &builder, xla::Array &vregs, - const int64_t amount) { - rotateVregs(builder, vregs, amount, 0); -} - -void rotateLanes(OpBuilder &builder, xla::Array &vregs, - const int64_t amount) { - rotateVregs(builder, vregs, amount, 1); -} - -// Relayout src_vregs from layout src to layout dst, where dst is the same as -// src except that the column offset is dst_col_offset. -FailureOr> doColumnShiftRelayout( - OpBuilder &builder, const ArrayRef shape, - xla::Array src_vregs, const VectorLayout &src, - const int64_t dst_col_offset, const std::array target_shape) { - CHECK(src.offsets()[1]); - const std::array tiled_ishape = - src.getImplicitTiledDims(shape, 1); - const Location loc = src_vregs.begin()->getLoc(); - const std::array tiling = src.tiling(); - const std::array vreg_slice = src.vregSlice(target_shape); - const int bitwidth = src.bitwidth(); - const int packing = src.packing(); - const int64_t col_diff = dst_col_offset - *src.offsets()[1]; - if (tiling[0] % packing != 0 || tiling[1] != target_shape[1]) { - return emitError(loc, - "Not implemented: Unsupported tiling for column shift"); - } - // When shifting columns with multiple tiles per vreg, the overflowing - // columns of a tile move to the next tile, and they have to be shifted - // down. For example, for a 32-bit layout with (2, 128 tiling), when shifting - // a vreg right by 138 (128 + 10): - // - // +---------------+---------+ +---------+---------------+ - // | 0:118 | 118:128 | |-138:-128| -128:-10 | - // +---------------+---------+ +---------+---------------+ - // | 128:246 | 246:256 | | -10:0 | 0:118 | - // +---------------+---------+ -> +---------+---------------+ - // | 256:382 | 382:392 | | 118:128 | 128:246 | - // +---------------+---------+ +---------+---------------+ - // | 392:502 | 502:512 | | 246:256 | 256:382 | - // +---------------+---------+ +---------+---------------+ - // - // The negative numbers above are used for column intervals coming from the - // previous vreg (if there is one). - // - // We can break the result vreg down into four parts: - // - // +---------+---------------+ - // | UL | UR | - // + +---------------+ - // | | LR | - // +---------+ + - // | LL | | - // + + + - // | | | - // +---------+---------------+ - // - // Our example shifts right, which causes the upper parts to come from the - // previous (along the minor dim) vreg of the array (if it exists) and the - // lower parts to come from the original "current" vreg. - // - // - LR (Lower Right) comes from the current vreg lane-rotated by 10, and - // sublane-rotated down by 2 (1 tile). - // - LL (Lower Left) comes from the current vreg lane-rotated by 10, and - // sublane-rotated down by 4 (2 tiles). - // - UR (Upper Right) comes from the previous vreg lane-shifted by 10, and - // sublane-rotated down by 2 (1 tile). - // - UL (Upper Left) comes from the previous vreg lane-shifted by 10, and - // sublane-rotated down by 4 (2 tiles). - // - // This partitioning also works similarly for left shifts, except that the - // upper parts come from the current vreg, and the lower parts come from the - // next vreg. - // - // In general, for any tiling and shift amount, we will partition the result - // vreg into four like we did here. However, for some tilings and shift - // amounts, some of the partitions may be empty. There are some notable cases: - // - // - Tile-aligned shifts result in empty left parts. - // - Native tiling (a single tile per vreg) results in empty upper right and - // lower left parts. - // - Shifts right by less than 1 tile result in empty upper right parts, and - // shifts left by less than 1 tile result in empty lower left parts. - - const int64_t sublanes_per_tile = src.sublanesPerTile(target_shape); - const int64_t tiles_per_vreg = src.tilesPerVreg(target_shape); - - int64_t split_offset = col_diff; - int64_t upper_idx_delta = -1; - int64_t lower_idx_delta = 0; - if (col_diff < 0) { - split_offset += vreg_slice[1]; - ++upper_idx_delta; - ++lower_idx_delta; - } - const int64_t left_tile_split = llvm::divideCeil(split_offset, tiling[1]); - const int64_t right_tile_split = split_offset / tiling[1]; - const int64_t left_right_split = split_offset % tiling[1]; - - rotateLanes(builder, src_vregs, left_right_split); - // TODO(tlongeri): Clean up. Some of these rotations may end up unused: - // - The left part of the first vreg and the right part of the last vreg - // may be entirely padding. - // - The entire left part may be unused if the shift is tile-aligned. - // They will be removed as dead code anyway, but it would be nicer to not - // generate them in the first place. - // Also, sometimes the rotation amount is 0, so we don't need to allocate - // another array (and we should steal the allocation for src_tiles, too). - xla::Array left_part = src_vregs; - xla::Array right_part = src_vregs; - rotateSublanes(builder, left_part, - left_tile_split * sublanes_per_tile % target_shape[0]); - rotateSublanes(builder, right_part, - right_tile_split * sublanes_per_tile % target_shape[0]); - // We assemble left and right, and then put them together. - // TODO(tlongeri): Lower and upper first is probably better, it can be - // reused for consecutive vregs. We can assemble lower_left+lower_right - // for one vreg and upper_left+upper_right for the next one in the same - // vselect. But the mask for assembling upper+lower is not as simple, so - // it might be a bit more expensive to generate. Worth it for large vreg - // arrays, I'm not sure about small ones (especially in older TPU gens). - const auto mask_vreg_ty = VectorType::get( - packing == 1 - ? target_shape - : ArrayRef{target_shape[0], target_shape[1], packing}, - builder.getI1Type()); - Value left_mask = nullptr; - Value right_mask = nullptr; - Value left_right_mask = nullptr; - auto get_left_mask = [&]() { - if (left_mask == nullptr) { - left_mask = builder.create( - loc, mask_vreg_ty, - ArrayRef{IdxConst(0, builder, loc), IdxConst(0, builder, loc)}, - ArrayRef{ - IdxConst(left_tile_split * sublanes_per_tile, builder, loc), - IdxConst(target_shape[1], builder, loc)}); - } - return left_mask; - }; - auto get_right_mask = [&]() { - if (right_mask == nullptr) { - right_mask = builder.create( - loc, mask_vreg_ty, - ArrayRef{IdxConst(0, builder, loc), IdxConst(0, builder, loc)}, - ArrayRef{ - IdxConst(right_tile_split * sublanes_per_tile, builder, loc), - IdxConst(target_shape[1], builder, loc)}); - } - return right_mask; - }; - auto get_left_right_mask = [&]() { - if (left_right_mask == nullptr) { - left_right_mask = builder.create( - loc, mask_vreg_ty, - ArrayRef{IdxConst(0, builder, loc), IdxConst(0, builder, loc)}, - ArrayRef{IdxConst(target_shape[0], builder, loc), - IdxConst(left_right_split, builder, loc)}); - } - return left_right_mask; - }; - xla::Array dst_vregs(VectorLayout(bitwidth, - {src.offsets()[0], dst_col_offset}, - tiling, src.implicit_dim()) - .tileArrayImplicitShape(shape, target_shape)); - dst_vregs.Each([&](absl::Span dst_idx, Value *dst_vreg) { - SmallVector dst_idx_local(toArrayRef(dst_idx)); - Value lower_left = nullptr; - Value lower_right = nullptr; - Value upper_left = nullptr; - Value upper_right = nullptr; - // Set parts if their size is non-empty and the source vreg exists. - *(dst_idx_local.end() - 1) += lower_idx_delta; - if (*(dst_idx_local.end() - 1) < *(src_vregs.dimensions().end() - 1)) { - if (left_tile_split < tiles_per_vreg && 0 < left_right_split) { - lower_left = left_part(dst_idx_local); - } - if (right_tile_split < tiles_per_vreg) { - lower_right = right_part(dst_idx_local); - } - } - *(dst_idx_local.end() - 1) -= lower_idx_delta; - *(dst_idx_local.end() - 1) += upper_idx_delta; - if (*(dst_idx_local.end() - 1) >= 0) { - if (0 < left_tile_split && 0 < left_right_split) { - upper_left = left_part(dst_idx_local); - } - if (0 < right_tile_split) { - upper_right = right_part(dst_idx_local); - } - } - *(dst_idx_local.end() - 1) -= upper_idx_delta; - - // For the first and last vregs, some parts may be all padding, so - // unset them if this is the case. Note that the first and last vreg - // are the same when there is only one. - if (*(dst_idx_local.end() - 1) == 0) { - // We check the final offset (note that this is different from the rotate - // amount) against the thresholds of the last columns of vreg parts. - if (right_tile_split * tiling[1] <= dst_col_offset) { - // Note: When shifting right, UR is always all-padding. - upper_right = nullptr; - } - if (split_offset <= dst_col_offset) { - // Note: When shifting right, UL is always all-padding. When shifting - // left, UL is never all-padding (unless this is also the last vreg, - // possibly). - upper_left = nullptr; - } - if (vreg_slice[1] - tiling[1] + left_right_split <= dst_col_offset) { - // Note: When shifting right, LL is only all-padding if the source - // offset is in the last tile. When shifting left, LL is never - // all-padding (unless this is also the last vreg, possibly). - lower_left = nullptr; - } - } - if (*(dst_idx_local.end() - 1) == *(dst_vregs.dimensions().end() - 1) - 1) { - // We check the final end offset against the thresholds of the first - // columns of vreg parts. - const uint64_t end_offset = - (dst_col_offset + tiled_ishape[1] - 1) % vreg_slice[1] + 1; - if (end_offset <= left_tile_split * tiling[1]) { - // Note: When shifting left, LL is always all-padding. - lower_left = nullptr; - } - if (end_offset <= split_offset) { - // Note: When shifting left, LR is always all-padding. When shifting - // right, LR is never all-padding (unless this is also the first vreg, - // possibly). - lower_right = nullptr; - } - if (end_offset <= left_right_split) { - // Note: When shifting left, UR is only all-padding if the original - // end offset is in the first tile. When shifting right, UR is never - // all-padding (unless this is also the last vreg, possibly). - upper_right = nullptr; - } - } - // Combine parts into the final vreg (see comment in mask definitions). - auto combine_parts = [&builder](Value part1, Value part2, - auto get_mask_fn) -> Value { - if (part1 && part2) { - return builder.create(part1.getLoc(), get_mask_fn(), - part1, part2); - } else if (part1) { - return part1; - } else { - return part2; - } - }; - Value left = combine_parts(upper_left, lower_left, get_left_mask); - Value right = combine_parts(upper_right, lower_right, get_right_mask); - *dst_vreg = combine_parts(left, right, get_left_right_mask); - CHECK(*dst_vreg); - }); - return dst_vregs; -} - -FailureOr>> changeOffsets( - RewriteContext &ctx, OpBuilder &builder, const Location loc, - const VectorType vty, const VectorLayout src, xla::Array vregs, - const LayoutOffsets dst_offsets) { - const auto &target_shape = ctx.target_shape; - const VectorLayout dst(src.bitwidth(), dst_offsets, src.tiling(), - src.implicit_dim()); - const int packing = src.packing(); - const int8_t bitwidth = src.bitwidth(); - - int row_diff; - if (!src.offsets()[0].has_value()) { - row_diff = 0; - } else if (!dst_offsets[0].has_value()) { - return emitError(loc, "Not implemented: Sublane broadcast"); - } else { - row_diff = *dst_offsets[0] - *src.offsets()[0]; - } - - int64_t col_diff; - if (!src.offsets()[1].has_value()) { - col_diff = 0; - } else if (!dst_offsets[1].has_value()) { - return emitError(loc, "Not implemented: Lane broadcast"); - } else { - col_diff = *dst_offsets[1] - *src.offsets()[1]; - } - - if (row_diff != 0) { - if (col_diff != 0) { - return emitError(loc, "Not implemented: Row and column offset changes"); - } - const SmallVector implicit_shape = - src.implicitShape(vty.getShape()); - if (implicit_shape[implicit_shape.size() - 2] != 1) { - // Multi row shift - // TODO(mvoz): This should take the vregs array, not the value. - FAILUREOR_ASSIGN_OR_RETURN( - vregs, tpu_rotate_with_overflow( - builder, target_shape, loc, vty, std::move(vregs), - /*dim*/ implicit_shape.size() - 2, src, dst_offsets)); - } else { - // Single row case - // TODO(mvoz): The single row case has a broader set of supported - // operations: non-native tiling, packed types, implicit dim. We should - // support these cases in tpu_rotate_with_overflow and remove this - // branch. - const int64_t src_sublane = *src.offsets()[0] / packing; - const int64_t dst_sublane = *dst_offsets[0] / packing; - if (int64_t sublane_diff = dst_sublane - src_sublane) { - if (sublane_diff < 0) { - sublane_diff += target_shape[0]; - } - rotateSublanes(builder, vregs, sublane_diff); - } - const int src_subelem = *src.offsets()[0] % packing; - const int dst_subelem = *dst.offsets()[0] % packing; - if (src_subelem != dst_subelem) { - const int subelem_diff = dst_subelem - src_subelem; - const int shift_bits = bitwidth * std::abs(subelem_diff); - VectorType bits_vreg_ty = - VectorType::get(target_shape, builder.getI32Type()); - auto shift_vreg = builder.create( - loc, bits_vreg_ty, - DenseElementsAttr::get(bits_vreg_ty, shift_bits)); - vregs.Each([&](absl::Span /*idx*/, Value *tile) { - auto bit_tile = - builder.create(loc, bits_vreg_ty, *tile); - Operation *shift_tile; - if (subelem_diff > 0) { - shift_tile = - builder.create(loc, bit_tile, shift_vreg); - } else { // subelem_diff < 0 - CHECK_LT(subelem_diff, 0); - shift_tile = - builder.create(loc, bit_tile, shift_vreg); - } - *tile = builder - .create(loc, tile->getType(), - shift_tile->getResult(0)) - .getResult(); - }); - } - } - } - - // Rows are now correctly aligned. Time to offset columns. - // TODO(apaszke, mvoz): Changing an offset might add or remove one vreg. - // Note - this is handled for row shifts via tpu_rotate_with_overflow - SmallVector dst_tiles_shape = - dst.tileArrayImplicitShape(vty.getShape(), target_shape); - CHECK_EQ(*(dst_tiles_shape.end() - 2), *(vregs.dimensions().end() - 2)); - - // TODO(tlongeri): Clean up col_diff and pass the dst offset directly. - if (col_diff != 0) { - FAILUREOR_ASSIGN_OR_RETURN( - vregs, doColumnShiftRelayout(builder, vty.getShape(), std::move(vregs), - src, *dst.offsets()[1], target_shape)); - } - return std::make_pair(dst, std::move(vregs)); -} - -LogicalResult retileToLargeTileWithScratch( - RewriteContext &ctx, OpBuilder &builder, const Location loc, - xla::Array &dst_tiles, const std::array &dst_tile, - const xla::Array &src_tiles, const std::array &src_tile, - TypedValue scratch_ref, const int64_t store_vreg_delay, - const int64_t load_vreg_skips) { - if (dst_tile[0] % src_tile[0] != 0) { - return emitError(loc, "dst_tile[0] must be a multiple of src_tile_size[0]"); - } - // Number of src vregs needed to assemble one dst vreg. - int vregs_per_group = dst_tile[0] / src_tile[0]; - // Number of sublanes needed per src vreg to assemble one dst vreg. - int sl_per_vreg = ctx.target_shape[0] / vregs_per_group; - int stride = vregs_per_group; - - xla::Array sublane_offsets( - {ctx.target_shape[0] / dst_tile[0], src_tile[0], vregs_per_group}, 0); - absl::c_iota(sublane_offsets, 0); - // The older hardware has limited support for shuffles so even if we have bank - // conflicts, we just accept them and will have the lowering unroll the - // loads/stores. - int64_t num_offsets = sublane_offsets.num_elements(); - // The max sublane offset before handling bank conflicts is always - // (num_offsets - 1). To avoid bank conflicts, we need to add one extra - // sublane to stride so (num_offsets - 1) / stride is the extra offset needed - // to pad sublanes. - // - // For example, if store stride = 4, sublane_count = 8, and - // load offsets = [0, 1, 2, 3, 4, 5, 6, 7], then the sublane offsets after - // handling bank conflicts will be [0, 1, 2, 3, 5, 6, 7, 8] and the max - // sublane offset will be 7 + (8 - 1) / 4 = 8. - // - // Before - // <-------- sublanes ---------> - // 0 1 ... 32 - // store: x---x---x---x---x---x---x---x - // load: xxxxxxxxx-------------------- - // - // After - // <-------- sublanes ---------> - // 0 5 ... 40 - // store: x----x----x----x----x----x----x----x - // load: xxxx-xxxx--------------------------- - // - // where "x" indicates a sublane that needs to be accessed and "-"" indicates - // a sublane that does not need to be accessed. - int max_sublane_offset = (num_offsets - 1) + (num_offsets - 1) / stride; - bool should_handle_bank_confict = - shouldHandleBankConflict(ctx, stride, max_sublane_offset); - if (should_handle_bank_confict) { - handleBankConflict(stride, absl::MakeSpan(sublane_offsets.data(), - sublane_offsets.num_elements())); - } - sublane_offsets.TransposeDimensions({0, 2, 1}); - - auto mlirIndexConst = [&](int d) { - return builder.create( - src_tiles.begin()->getLoc(), - builder.getIntegerAttr(builder.getIndexType(), d)); - }; - auto cst_0 = mlirIndexConst(0); - // Each group has exact number of src vregs needed to assemble one dst vreg. - // We can not use circular buffer here because we need to have enough space to - // strided load/store. - int64_t sublanes_per_group = stride * sl_per_vreg * vregs_per_group; - int64_t max_groups_in_scratch = - ctx.max_sublanes_in_scratch / sublanes_per_group; - if (max_groups_in_scratch < 1) { - return emitError(loc, - "scratch space is not enough for retiling to large tile"); - } - int64_t stored_group_cnt = 0; - auto dst_vreg_ty = src_tiles.begin()->getType(); - // Create a new vreg type that can be stored in scratch memref. - auto temp_vreg_ty = - VectorType::get(ctx.target_shape, scratch_ref.getType().getElementType()); - SmallVector sublane_mask(ctx.target_shape[0], true); - // (dst_vreg, load_offset) - std::vector> delayed_loads; - delayed_loads.reserve(max_groups_in_scratch * vregs_per_group); - // We only emit the loads when we run out of scratch space or we are at the - // last vreg of the batch to help bundle scheduling. - auto emit_all_delayed_loads = [&]() { - for (auto [dst_vreg, load_offset] : delayed_loads) { - Value load_op = builder.create( - loc, temp_vreg_ty, scratch_ref, ArrayRef({load_offset, cst_0}), - ArrayRef(sublane_mask), - ArrayRef(sublane_offsets.begin(), sublane_offsets.end())); - *dst_vreg = builder.create(loc, dst_vreg_ty, load_op); - } - delayed_loads.clear(); - }; - - int rank = src_tiles.dimensions().size(); - if (rank != dst_tiles.dimensions().size()) { - return emitError(loc, "src and dst tiles have different ranks"); - } - for (int i = 0; i < rank - 2; ++i) { - if (src_tiles.dim(i) != dst_tiles.dim(i)) { - return emitError(loc, - "Expected src and dst tiles have same dimension " - "sizes on dim") - << i << ", but got " << src_tiles.dim(i) << " vs " - << dst_tiles.dim(i); - } - } - SmallVector src_idx(rank); - dst_tiles.Each([&](absl::Span dst_idx, Value *dst_vreg) { - int64_t dst_row_idx = *(dst_idx.end() - 2); - int64_t dst_col_idx_with_skips = *(dst_idx.end() - 1) + load_vreg_skips; - int64_t vreg_idx_in_group = dst_col_idx_with_skips % vregs_per_group; - int64_t load_offset = sublanes_per_group * stored_group_cnt + - vreg_idx_in_group * sl_per_vreg * stride; - delayed_loads.push_back( - std::make_pair(dst_vreg, mlirIndexConst(load_offset))); - // When dst vreg is at the last vreg of the group or the current dst - // vregs' row, this indicates we have scheduled delayed loads for all - // the vregs from current group and now we need to store corresponding - // group of src vregs before actually emitting the loads. - if (vreg_idx_in_group == vregs_per_group - 1 || - dst_idx.back() == dst_tiles.dimensions().back() - 1) { - auto base_src_row_idx = dst_row_idx * vregs_per_group - store_vreg_delay; - auto src_col_idx = dst_col_idx_with_skips / vregs_per_group; - std::copy(dst_idx.begin(), dst_idx.end(), src_idx.begin()); - for (int vi = 0; vi < vregs_per_group; ++vi) { - const int64_t src_row_idx = base_src_row_idx + vi; - if (src_row_idx < 0) { - continue; - } - if (src_row_idx >= src_tiles.dim(rank - 2) || - src_col_idx >= src_tiles.dim(rank - 1)) { - break; - } - *(src_idx.end() - 2) = src_row_idx; - *(src_idx.end() - 1) = src_col_idx; - Value src_vreg = src_tiles(src_idx); - src_vreg = - builder.create(loc, temp_vreg_ty, src_vreg); - Value store_offset = - mlirIndexConst(sublanes_per_group * stored_group_cnt + vi); - builder.create( - loc, src_vreg, scratch_ref, ArrayRef({store_offset, cst_0}), - ArrayRef(sublane_mask), - /*mask=*/nullptr, builder.getI32IntegerAttr(stride)); - } - stored_group_cnt = (stored_group_cnt + 1) % max_groups_in_scratch; - // We emit loads when we run out of scratch space or we are at the - // last vreg of the batch. - if (stored_group_cnt == 0 || - (*(dst_idx.end() - 2) == dst_tiles.dim(rank - 2) - 1 && - *(dst_idx.end() - 1) == dst_tiles.dim(rank - 1) - 1)) { - emit_all_delayed_loads(); - } - } - }); - return success(); -} - -LogicalResult retileToSmallTileWithScratch( - RewriteContext &ctx, OpBuilder &builder, const Location loc, - xla::Array &dst_tiles, const std::array &dst_tile, - const xla::Array &src_tiles, const std::array &src_tile, - TypedValue scratch_ref, const int64_t store_vreg_delay, - const int64_t load_vreg_skips) { - if (src_tile[0] % dst_tile[0] != 0) { - return emitError(loc, "src tile size must be a multiple of dst tile size"); - } - // Number of src vregs needed to assemble one dst vreg. - int vregs_per_group = src_tile[0] / dst_tile[0]; - // Number of sublanes needed per src vreg to assemble one dst vreg. - int sl_per_vreg = ctx.target_shape[0] / vregs_per_group; - int stride = vregs_per_group; - - xla::Array sublane_offsets( - {ctx.target_shape[0] / src_tile[0], dst_tile[0], vregs_per_group}, 0); - absl::c_iota(sublane_offsets, 0); - // The older hardware has limited support for shuffles so even if we have - // bank conflicts, we just accept them and will have the lowering unroll the - // loads/stores. - int64_t num_offsets = sublane_offsets.num_elements(); - // The max sublane offset before handling bank conflicts is always - // (num_offsets - 1). To avoid bank conflicts, we need to add one extra - // sublane to stride so (num_offsets - 1) / stride is the extra offset needed - // to pad sublanes. - // - // For example, if store stride = 4, sublane_count = 8, and - // load offsets = [0, 1, 2, 3, 4, 5, 6, 7], then the sublane offsets after - // handling bank conflicts will be [0, 1, 2, 3, 5, 6, 7, 8] and the max - // sublane offset will be 7 + (8 - 1) / 4 = 8. - // - // Before - // <-------- sublanes ---------> - // 0 4 ... - // store: x---x---x---x---x---x---x---x - // load: xxxxxxxxx------------------- - // - // After - // <-------- sublanes ---------> - // 0 5 ... - // store: x----x----x----x----x----x----x----x - // load: xxxx-xxxx--------------------------- - // - // where "x" indicates a sublane that needs to be accessed and "-"" indicates - // a sublane that does not need to be accessed. - int max_sublane_offset = (num_offsets - 1) + (num_offsets - 1) / stride; - bool should_handle_bank_confict = - shouldHandleBankConflict(ctx, stride, max_sublane_offset); - bool use_shuffled_load = false; - if (ctx.hardware_generation <= 4) { - if (src_tile[0] == 8) { - // The older hardware does not support shuffled store. However, if the src - // tile is (8, 128), we can convert (shuffled store + strided load) to - // (strided store + shuffled load). - use_shuffled_load = true; - } else if (src_tile[0] == 4) { - // In this case, the trick of replacing a shuffled store with a shuffled - // load does not work. Handling bank conflicts will cause the sublane - // offsets to increase which might make emulation harder, so we avoid - // doing so. - should_handle_bank_confict = false; - } - } - - // Add one extra sublane to stride to avoid bank conflict. - if (should_handle_bank_confict) { - handleBankConflict(stride, absl::MakeSpan(sublane_offsets.data(), - sublane_offsets.num_elements())); - } - sublane_offsets.TransposeDimensions({0, 2, 1}); - auto mlirIndexConst = [&](int d) { - return builder.create( - src_tiles.begin()->getLoc(), - builder.getIntegerAttr(builder.getIndexType(), d)); - }; - auto cst_0 = mlirIndexConst(0); - // Each group has exact number of src vregs needed to assemble one dst vreg. - // We can not use circular buffer here because we need to have enough space - // to strided load/store. - int64_t sublanes_per_group = stride * sl_per_vreg * vregs_per_group; - int64_t max_groups_in_scratch = - ctx.max_sublanes_in_scratch / sublanes_per_group; - if (max_groups_in_scratch < 1) { - return emitError(loc, - "scratch space is not enough for retiling to small tile"); - } - int64_t stored_group_cnt = 0; - auto dst_vreg_ty = src_tiles.begin()->getType(); - // Create a new vreg type that can be stored in scratch memref. - auto temp_vreg_ty = - VectorType::get(ctx.target_shape, scratch_ref.getType().getElementType()); - SmallVector sublane_mask(ctx.target_shape[0], true); - // (dst_vreg, load_offset) - std::vector> delayed_loads; - delayed_loads.reserve(max_groups_in_scratch * vregs_per_group); - // We only emit the loads when we run out of scratch space or we are at the - // last vreg of the batch to help bundle scheduling. - auto emit_all_delayed_loads = [&]() { - for (auto [dst_vreg, load_offset] : delayed_loads) { - Value load_op; - if (use_shuffled_load) { - load_op = builder.create( - loc, temp_vreg_ty, scratch_ref, - ArrayRef({load_offset, cst_0}), ArrayRef(sublane_mask), - ArrayRef(sublane_offsets.begin(), sublane_offsets.end())); - } else { - load_op = builder.create( - loc, temp_vreg_ty, scratch_ref, - ArrayRef({load_offset, cst_0}), ArrayRef(sublane_mask), - builder.getI32IntegerAttr(stride)); - } - *dst_vreg = builder.create(loc, dst_vreg_ty, load_op); - } - delayed_loads.clear(); - }; - int rank = src_tiles.dimensions().size(); - if (rank != dst_tiles.dimensions().size()) { - return emitError(loc, "src and dst tiles have different ranks"); - } - for (int i = 0; i < rank - 2; ++i) { - if (src_tiles.dim(i) != dst_tiles.dim(i)) { - return emitError(loc, - "Expected src and dst tiles have same dimension " - "sizes on dim") - << i << ", but got " << src_tiles.dim(i) << " vs " - << dst_tiles.dim(i); - } - } - SmallVector dst_idx(rank); - src_tiles.Each([&](absl::Span src_idx, Value src_vreg) { - int64_t src_row_idx = *(src_idx.end() - 2); - int64_t src_col_idx_with_delays = *(src_idx.end() - 1) + store_vreg_delay; - int64_t vreg_idx_in_group = src_col_idx_with_delays % vregs_per_group; - src_vreg = builder.create(loc, temp_vreg_ty, src_vreg); - if (use_shuffled_load) { - Value store_offset = mlirIndexConst( - sublanes_per_group * stored_group_cnt + vreg_idx_in_group); - builder.create( - loc, src_vreg, scratch_ref, ArrayRef({store_offset, cst_0}), - ArrayRef(sublane_mask), - /*mask=*/nullptr, builder.getI32IntegerAttr(stride)); - } else { - Value store_offset = - mlirIndexConst(sublanes_per_group * stored_group_cnt + - vreg_idx_in_group * sl_per_vreg * stride); - builder.create( - loc, src_vreg, scratch_ref, ArrayRef({store_offset, cst_0}), - ArrayRef(sublane_mask), - ArrayRef(sublane_offsets.begin(), sublane_offsets.end())); - } - // When src vreg is at the last vreg of the group or the current src - // vregs' row, this indicates we have stored all the vregs needed to - // assemble a new group of dst vreg. - if (vreg_idx_in_group == vregs_per_group - 1 || - src_idx.back() == src_tiles.dimensions().back() - 1) { - auto base_dst_row_idx = src_row_idx * vregs_per_group - load_vreg_skips; - auto dst_col_idx = src_col_idx_with_delays / vregs_per_group; - std::copy(src_idx.begin(), src_idx.end(), dst_idx.begin()); - for (int vi = 0; vi < vregs_per_group; ++vi) { - const int64_t dst_row_idx = base_dst_row_idx + vi; - if (dst_row_idx < 0) { - continue; - } - if (dst_row_idx >= dst_tiles.dim(rank - 2) || - dst_col_idx >= dst_tiles.dim(rank - 1)) { - break; - } - *(dst_idx.end() - 2) = dst_row_idx; - *(dst_idx.end() - 1) = dst_col_idx; - Value *dst_vreg = &dst_tiles(dst_idx); - int64_t load_offset = - use_shuffled_load ? (sublanes_per_group * stored_group_cnt + - vi * sl_per_vreg * stride) - : (sublanes_per_group * stored_group_cnt + vi); - delayed_loads.push_back( - std::make_pair(dst_vreg, mlirIndexConst(load_offset))); - } - stored_group_cnt = (stored_group_cnt + 1) % max_groups_in_scratch; - // We emit loads when we run out of scratch space or we are at the - // last vreg of the batch. - if (stored_group_cnt == 0 || - (*(src_idx.end() - 2) == src_tiles.dim(rank - 2) - 1 && - *(src_idx.end() - 1) == src_tiles.dim(rank - 1) - 1)) { - emit_all_delayed_loads(); - } - } - }); - return success(); -} - -// go/mosaic-retiling-in-scratch is the full internal documentation that -// includes more details about the TPU generations. -// Arguments: -// - shape: The non-implicit shape of the operand -// - dst_tiling: The desired result tiling -// - dst_offsets_hint: Hints for the result offsets. They may be used or -// ignored. See comments in the body of the function for -// more details. -// - src_vregs: The source vregs to retile. -// - src: The source layout -// Returns a pair holding the result layout (potentially using the hints) and -// the retiled vregs. -// TODO(tlongeri): Clean up the function parameters/signatures. We are passing -// in more information than strictly needed. -FailureOr>> retileWithScratch( - RewriteContext &ctx, OpBuilder &builder, const Location loc, - const ArrayRef shape, const std::array dst_tiling, - const LayoutOffsets dst_offsets_hint, const xla::Array &src_vregs, - const VectorLayout &src) { - const int bitwidth = src.bitwidth(); - const int packing = src.packing(); - const std::array src_tiling = src.tiling(); - if (!(src_tiling[1] == ctx.target_shape[1] && - dst_tiling[1] == ctx.target_shape[1] && src_tiling[0] % packing == 0 && - dst_tiling[0] % packing == 0)) { - return emitError(loc, "Unsupported retiling with scratch"); - } - const std::array src_vreg_slice = - VectorLayout::vregSlice(ctx.target_shape, bitwidth, src_tiling); - const std::array dst_vreg_slice = - VectorLayout::vregSlice(ctx.target_shape, bitwidth, dst_tiling); - - // TODO(b/368088671): When sublane tiling changes, we should be able to - // preserve some replications from the source layout. But we need to - // make sure they are implemented efficiently and well-tested. For now, we - // just simply use 0 for the replicated offset after retiling. - const LayoutOffsets src_offsets = {src.offsets()[0].value_or(0), - src.offsets()[1].value_or(0)}; - // The provided offset hints are used only if they align with the source - // offsets, else we default to the smallest possible aligned offsets. - LayoutOffsets dst_offsets = {*src_offsets[0] % dst_vreg_slice[0], - *src_offsets[1] % dst_vreg_slice[1]}; - // On a given dimension, either the source vreg slice size divides the dest - // vreg slice size, or vice versa (depending on the dimension and whether it's - // small-to-large or large-to-small retiling). Offset changes are supported - // as long as they are aligned modulo the smaller of the two sizes. - const std::array alignment = { - std::min(src_vreg_slice[0], dst_vreg_slice[0]), - std::min(src_vreg_slice[1], dst_vreg_slice[1])}; - if (dst_offsets_hint[0].has_value() && - (*dst_offsets_hint[0] - *src_offsets[0]) % alignment[0] == 0) { - CHECK_LT(*dst_offsets_hint[0], dst_vreg_slice[0]); - dst_offsets[0] = *dst_offsets_hint[0]; - } - if (dst_offsets_hint[1].has_value() && - (*dst_offsets_hint[1] - *src_offsets[1]) % alignment[1] == 0) { - CHECK_LT(*dst_offsets_hint[1], dst_vreg_slice[1]); - dst_offsets[1] = *dst_offsets_hint[1]; - } - // The offsets of the source in units of the destination vreg slice: - const std::array src_offsets_in_dst_vreg_slices = { - *src_offsets[0] / dst_vreg_slice[0], *src_offsets[1] / dst_vreg_slice[1]}; - // The offsets of the destination in units of the source vreg slice: - const std::array dst_offsets_in_src_vreg_slices = { - *dst_offsets[0] / src_vreg_slice[0], *dst_offsets[1] / src_vreg_slice[1]}; - - // Try to get i32 vector scratch space. Because we will bitcast vregs to - // i32 vregs before using scratch for retiling. Through this way we can - // handle packed types as well. - auto vi32_scratch_ref = getInternalScratch( - ctx, builder, loc, {ctx.max_sublanes_in_scratch, ctx.target_shape[1]}, - builder.getI32Type(), /*sublane_tiling=*/1); - if (failed(vi32_scratch_ref)) { - return emitError(loc, "Failed to get scratch ref for retiling"); - } - auto ref = vi32_scratch_ref.value(); - std::array vi32_dst_tiling = {dst_tiling[0] / packing, - dst_tiling[1]}; - std::array vi32_src_tiling = {src_tiling[0] / packing, - src_tiling[1]}; - - const VectorLayout dst(bitwidth, dst_offsets, dst_tiling, src.implicit_dim()); - TPU_ASSERT_LOC(loc, dst.isValid(ctx.target_shape)); - xla::Array dst_vregs( - dst.tileArrayImplicitShape(shape, ctx.target_shape)); - // When differences in offsets exist, the source vregs may stored at an offset - // position in their group. For example, the 1st vreg in a row/column may be - // stored as if it was the 3rd, so that the parts corresponding to the 1st and - // 2nd in the destination are filled with padding. Likewise, loads to - // destination vregs may be skipped, when they would load only padding. - // store_vreg_delay is the position offset for stores, and load_vreg_skips is - // the position offset for loads. - // - // For example, suppose we are going from 32-bit {0, 128}(2, 128) to - // {4, 0}(8, 128). We form groups of 4 vregs that represent an (8, 512) slice - // of the padded implicit shape. For the given offsets, for the first group, - // the data is in (4:8, 128:512). But the first and second sources (stored - // vregs) of the group form the slices of data (0:2, 0:512) and (2:4, 0:512), - // which should be all padding. Likewise, the first dest vreg slice (which we - // load from) holds the data from slice (0:8, 0:128), which is all padding. - // We never load or store to slices that should contain only padding. - if (src_tiling[0] > dst_tiling[0]) { - DCHECK_EQ(src_offsets_in_dst_vreg_slices[1], 0); - DCHECK_EQ(dst_offsets_in_src_vreg_slices[0], 0); - const int64_t store_vreg_delay = dst_offsets_in_src_vreg_slices[1]; - const int64_t load_vreg_skips = src_offsets_in_dst_vreg_slices[0]; - if (failed(retileToSmallTileWithScratch( - ctx, builder, loc, dst_vregs, vi32_dst_tiling, src_vregs, - vi32_src_tiling, ref, store_vreg_delay, load_vreg_skips))) { - return failure(); - } - } - if (src_tiling[0] < dst_tiling[0]) { - DCHECK_EQ(src_offsets_in_dst_vreg_slices[0], 0); - DCHECK_EQ(dst_offsets_in_src_vreg_slices[1], 0); - const int64_t store_vreg_delay = dst_offsets_in_src_vreg_slices[0]; - const int64_t load_vreg_skips = src_offsets_in_dst_vreg_slices[1]; - if (failed(retileToLargeTileWithScratch( - ctx, builder, loc, dst_vregs, vi32_dst_tiling, src_vregs, - vi32_src_tiling, ref, store_vreg_delay, load_vreg_skips))) { - return failure(); - } - } - return std::make_pair(dst, dst_vregs); -} - -FailureOr>> changeTiling( - RewriteContext &ctx, OpBuilder &builder, const Location loc, VectorType vty, - const VectorLayout src, xla::Array vregs, - const std::array dst_tiling, - const LayoutOffsets dst_offsets_hint) { - bool has_enough_scratch = ctx.max_sublanes_in_scratch >= - ctx.target_shape[0] * (ctx.target_shape[0] + 1); - const auto &target_shape = ctx.target_shape; - const std::array src_tiling = src.tiling(); - if (src_tiling == dst_tiling) { - return std::pair(src, std::move(vregs)); - } - const LayoutOffsets src_offsets = - src.getCanonicalOffsets(vty.getShape(), ctx.target_shape); - const std::array tiled_ishape = - src.getImplicitTiledDims(vty.getShape(), 1); - const int packing = src.packing(); - const int8_t bitwidth = src.bitwidth(); - const std::array dst_vreg_slice = - VectorLayout::vregSlice(ctx.target_shape, bitwidth, dst_tiling); - // TODO(tlongeri): Using canonical vs non-canonical offsets can change the - // value of try_replicate rows, and it breaks some tests. It doesn't make - // sense that we have different behavior for equivalent layouts, though. We - // need better logic for picking the relayout strategy. - const bool try_replicate_rows = - src.offsets()[0].has_value() && !dst_offsets_hint[0].has_value(); - - // Fully replicated offsets are handled efficiently elsewhere (in relayout) - CHECK(src.offsets()[0].has_value() || src.offsets()[1].has_value()); - - // Handle replicating small-to-large retiling for (a) replicated 2nd minor or - // (b) 32-bit single-row. - // This retiling is one-to-many vregs. - // TODO(tlongeri): Large-to-small retiling with replicated minor is analogous - // to this. - if (src_tiling[1] == ctx.target_shape[1] && - dst_tiling[1] == ctx.target_shape[1] && - dst_tiling[0] % src_tiling[0] == 0 && - (!src_offsets[0].has_value() || (packing == 1 && tiled_ishape[0] == 1)) && - // This relayout relies on gathers, which are cheap on newer generations, - // so we always use it for them. - // TODO(tlongeri): Once we have it, probably also prefer the - // small-to-large rotate+blend relayout if we don't need replication. It's - // slightly cheaper for some dst vregs you rotate by 0. - // TODO(tlongeri): Using store + multiple replicated loads is good on - // older gens. I wonder if we can integrate this logic to scratch retiling - (try_replicate_rows || ctx.hardware_generation >= 5)) { - const LayoutOffset dst_minor_offset = - src.offsets()[1].has_value() ? *src.offsets()[1] % dst_vreg_slice[1] - : LayoutOffset(); - const VectorLayout dst(bitwidth, {std::nullopt, dst_minor_offset}, - dst_tiling, src.implicit_dim()); - const SmallVector dst_vreg_array_shape = - dst.tileArrayImplicitShape(vty.getShape(), target_shape); - const int64_t src_tiles_per_vreg = src.tilesPerVreg(ctx.target_shape); - const int64_t dst_tiles_per_vreg = dst.tilesPerVreg(ctx.target_shape); - const int64_t src_sublanes_per_tile = src.sublanesPerTile(ctx.target_shape); - const int64_t dst_sublanes_per_tile = dst.sublanesPerTile(ctx.target_shape); - xla::Array retiled(dst_vreg_array_shape); - SmallVector idxs; - retiled.Each([&](absl::Span dst_idx, Value *vreg) { - const int64_t dst_col_idx = *(dst_idx.end() - 1); - const int64_t base_dst_tile_idx = dst_col_idx * dst_tiles_per_vreg; - const int64_t base_src_tile_idx = - src_offsets[1].has_value() - ? base_dst_tile_idx + - (*src_offsets[1] - *dst_minor_offset) / src_tiling[1] - : 0; - // The following should be true from our choice of minor offset: - DCHECK_EQ(base_src_tile_idx % dst_tiles_per_vreg, 0); - const int64_t src_col_idx = base_src_tile_idx / src_tiles_per_vreg; - SmallVector gather_pattern; - // Iterate over the sublanes in the dst vreg: - for (int32_t sublane = 0; sublane < ctx.target_shape[0]; ++sublane) { - const int64_t dst_tile_idx_in_vreg = sublane / dst_sublanes_per_tile; - const int64_t src_tile_idx_in_vreg = - base_src_tile_idx % src_tiles_per_vreg + dst_tile_idx_in_vreg; - // Although replication may give us several sublanes to choose from, - // we always gather from the first sublane in the source tile. This - // degenerates to a broadcast when dst_tiling is native, which can - // be cheaper than an arbitrary gather (for some hardware gens). - const int64_t src_sublane_in_tile = - src_offsets[0].value_or(0) / packing; - const int64_t src_sublane = - src_tile_idx_in_vreg * src_sublanes_per_tile + src_sublane_in_tile; - gather_pattern.push_back(src_sublane); - } - idxs.assign(dst_idx.begin(), dst_idx.end()); - *(idxs.end() - 2) = 0; - *(idxs.end() - 1) = src_col_idx; - Value src_vreg = vregs(idxs); - *vreg = builder.create(loc, src_vreg.getType(), src_vreg, - gather_pattern, - /*dimension=*/0); - }); - return std::pair(dst, std::move(retiled)); - } - // (8,128) <-> (8 * packing,128) tiling change for packed type. - if (ctx.hardware_generation >= 4 && - src_offsets[0].value_or(0) < dst_vreg_slice[0] && - src_offsets[1].value_or(0) < dst_vreg_slice[1] && bitwidth < 32 && - 32 % bitwidth == 0 && - ((src_tiling == ctx.target_shape && - dst_tiling == std::array{ctx.target_shape[0] * packing, - ctx.target_shape[1]}) || - (dst_tiling == ctx.target_shape && - src_tiling == std::array{ctx.target_shape[0] * packing, - ctx.target_shape[1]}))) { - // TODO(tlongeri): This relayout is just ext + trunc. Refactor. - // Note: for int4, retiling with scratch is always faster. - if (bitwidth != 4 || !has_enough_scratch) { - // Note: The code below does not work when src is replicated and dst is - // not, since it relies on the src vreg array shape to know how many tiles - // to pack in dst, and vreg array shapes with materialized offsets are - // unfortunately not equal to vreg array shapes with replicated offsets. - VectorLayout dst(src.bitwidth(), src.offsets(), dst_tiling, - src.implicit_dim()); - xla::Array retiled( - dst.tileArrayImplicitShape(vty.getShape(), target_shape)); - VectorType vreg_x32 = - vty.getElementType().isSignlessInteger() - ? VectorType::get(target_shape, builder.getI32Type()) - : VectorType::get(target_shape, builder.getF32Type()); - // For each output vreg we collect `packing` registers from the moving dim - // (sublanes or lanes), while using the other vreg dim to determine which - // part of each register to use (the parts dim). - const int parts_dim = src_tiling[0] < dst_tiling[0] ? 1 : 2; - const int moving_dim = src_tiling[0] < dst_tiling[0] ? 2 : 1; - retiled.Each([&](absl::Span idx, Value *tile) { - const int vreg_part = *(idx.end() - parts_dim) % packing; - SmallVector parts; - parts.reserve(packing); - SmallVector src_idx(idx.begin(), idx.end()); - *(src_idx.end() - parts_dim) /= packing; - if (!dst.offsets()[2 - moving_dim].has_value()) { - *(src_idx.end() - moving_dim) = 0; - // Make sure we set all parts of the output vreg to make it replicated - parts.append(packing, builder.create( - loc, vreg_x32, vregs(src_idx), vreg_part, - tpu::PackFormat::kCompressed)); - } else { - *(src_idx.end() - moving_dim) *= packing; - for (int i = 0; i < packing; ++i) { - if (*(src_idx.end() - moving_dim) < - *(vregs.dimensions().end() - moving_dim)) { - parts.push_back(builder.create( - loc, vreg_x32, vregs(src_idx), vreg_part, - tpu::PackFormat::kCompressed)); - ++*(src_idx.end() - moving_dim); - } else { - parts.push_back(nullptr); - } - } - } - *tile = builder.create( - loc, cast(vregs.begin()->getType()), parts, - tpu::PackFormat::kCompressed); - }); - return std::pair(dst, std::move(retiled)); - } - } - // Handle retiling from (1, 128 * packing) to (packing, 128) for - // packed data. - // We do compressed unpacking followed by interleaved packing. - // TODO(tlongeri): This can be used as a first step before using - // a generalized retiling where we only move sublanes around - // (without packing/unpacking). - // TODO(tlongeri): Interleaved unpacking followed by interleaved - // packing (but with different pairings) might also be - // interesting if the next step is a retile, since we can also - // match corresponding elements without shifting. It's just that - // the tiles are not adjacent (no contiguous vreg slice). - if (src_offsets[0].value_or(0) < dst_vreg_slice[0] && - src_offsets[1].value_or(0) < dst_vreg_slice[1] && bitwidth < 32 && - 32 % bitwidth == 0 && - src_tiling == std::array{1, ctx.target_shape[1] * packing} && - dst_tiling == std::array{packing, ctx.target_shape[1]}) { - // TODO(tlongeri): This relayout is just ext + trunc. Refactor. - // To illustrate, consider a 2 x 16 16-bit shape laid out in vregs of - // 4 sublanes and 2 lanes (this is convenient for to keep the example small - // yet non-trivial) with (1, 4) tiling. We will relayout to (2, 2) tiling. - // - // The vreg slice is 1 x 16, that is, the vreg contains the data for a - // 1 x 16 window of the logical shape. - // - // [a b c d e f g h i j k l m n o p] -> vreg 1 - // [A B C D E F G H I J K L M N O P] -> vreg 2 - // - // Note: we support multiple vregs per row of the logical shape, but we use - // one here just to keep the example small. - // - // When we do a compressed unpack, the resulting vregs effectively have a - // tiling of (1, 2) and cover a vreg slice of 1 x 8 logical elements. - // - // [a b c d e f g h] -> vreg 1, part 1 [i j k l m n o p] -> vreg 1, part 2 - // [A B C D E F G H] -> vreg 2, part 1 [I J K L M N O P] -> vreg 2, part 2 - // - // It is clear that if combine vreg 1, part 1 and vreg 2, part 1 we get data - // that covers a 2 x 8 vreg slice. Note, however, that we will have to mind - // the internal ordering of the vreg. - // - // [a b c d e f g h [i j k l m n o p - // A B C D E F G H] -> new vreg 1 I J K L M N O P] -> new vreg 2 - // - // To see if we can get the right internal ordering that we need for (2, 2) - // tiling, let's break new vreg 1 into (1, 2) rows, which correspond to - // sublanes when unpacked and half-sublanes when packed. - // - // [(a b) (c d) (e f) (g h) - // (A B) (C D) (E F) (G H)] - // - // The sublane order for the vreg parts is [(a b) (c d) ...] for vreg 1, - // part 1 and [(A B) (C D) ...] for vreg 2, part 1. - // - // The desired half-sublane order, for packed (2, 2) tiling, is - // [(a b) (A B) (c d) (C D) ...]. That is, traverse down each column before - // moving to the next one. This is exactly an interleaving of the sublanes - // of the vreg parts. - - // Note: The code below does not work when src is replicated and dst is - // not, since it relies on the src vreg array shape to know how many tiles - // to pack in dst, and vreg array shapes with materialized offsets are - // unfortunately not equal to vreg array shapes with replicated offsets. - VectorLayout dst(src.bitwidth(), src.offsets(), dst_tiling, - src.implicit_dim()); - xla::Array retiled( - dst.tileArrayImplicitShape(vty.getShape(), target_shape)); - const VectorType vreg_x32 = - vty.getElementType().isSignlessInteger() - ? VectorType::get(target_shape, builder.getI32Type()) - : VectorType::get(target_shape, builder.getF32Type()); - retiled.Each([&](absl::Span idx, Value *tile) { - SmallVector parts; - parts.reserve(packing); - SmallVector src_idx(toArrayRef(idx)); - const int64_t vreg_part = *(src_idx.end() - 1) % packing; - *(src_idx.end() - 1) /= packing; - if (!dst.offsets()[0].has_value()) { - *(src_idx.end() - 2) = 0; - // Make sure we set all parts of the output vreg to make it replicated - parts.append(packing, builder.create( - loc, vreg_x32, vregs(src_idx), vreg_part, - tpu::PackFormat::kCompressed)); - } else { - *(src_idx.end() - 2) *= packing; - for (int i = 0; i < packing; ++i) { - if (*(src_idx.end() - 2) < *(vregs.dimensions().end() - 2)) { - parts.push_back(builder.create( - loc, vreg_x32, vregs(src_idx), vreg_part, - tpu::PackFormat::kCompressed)); - ++*(src_idx.end() - 2); - } else { - parts.push_back(nullptr); - } - } - } - *tile = builder.create( - loc, cast(vregs.begin()->getType()), parts, - tpu::PackFormat::kInterleaved); - }); - return std::pair(dst, std::move(retiled)); - } - if (src_tiling[1] == target_shape[1] && dst_tiling[1] == target_shape[1]) { - // All clauses in the and expression are based on performance benchmarking. - bool use_alu = !has_enough_scratch || - (ctx.hardware_generation >= 5 && src_tiling[0] != packing && - dst_tiling[0] != packing); - - if (use_alu) { - if (src_tiling[0] > dst_tiling[0] && - // retileToReducedSublanes does not support offset changes - src.offsets()[0].value_or(0) < dst_vreg_slice[0] && - src.offsets()[1].value_or(0) < dst_vreg_slice[1]) { - VectorLayout dst(src.bitwidth(), src.offsets(), dst_tiling, - src.implicit_dim()); - return std::pair(dst, retileToReducedSublanes( - builder, vty.getShape(), src, vregs, - VectorLayout(bitwidth, - {src.offsets()[0].value_or(0), - src.offsets()[1].value_or(0)}, - dst_tiling, dst.implicit_dim()), - target_shape)); - } else if (!has_enough_scratch) { - // TODO(b/357538782): Implement retileToIncreasedSublanes with ALU ops. - return emitError( - loc, - "Not implemented: retiling to increase sublane tiling with ALU"); - } - } - return retileWithScratch(ctx, builder, loc, vty.getShape(), dst_tiling, - dst_offsets_hint, vregs, src); - } - return emitError(loc, "Not implemented: Unsupported tiling change for ") - << vty << ": from " << src << " to (" << dst_tiling[0] << ", " - << dst_tiling[1] << ") tiling"; -} - -FailureOr>> changeImplicitDim( - RewriteContext &ctx, OpBuilder &builder, const Location loc, VectorType vty, - const VectorLayout src, xla::Array vregs, - const VectorLayout::ImplicitDim dst_implicit_dim, - const LayoutOffsets dst_offset_hints) { - const auto &target_shape = ctx.target_shape; - if (src.implicit_dim() == dst_implicit_dim) { - return std::make_pair(src, std::move(vregs)); - } - // It's possible that the implicit dim change is a no-op. - VectorLayout src_candidate(src.bitwidth(), src.offsets(), src.tiling(), - dst_implicit_dim); - if (src_candidate.equivalentTo(src, vty.getShape(), target_shape)) { - vregs.Reshape( - src_candidate.tileArrayImplicitShape(vty.getShape(), target_shape)); - return std::make_pair(src_candidate, vregs); - } - // Remove second minor implicit dim, for values that have (m, 128) tiling (for - // m that is a power of 2). - if (src.implicit_dim() == VectorLayout::ImplicitDim::kSecondMinor && - dst_implicit_dim == VectorLayout::ImplicitDim::kNone && - src.bitwidth() == 32 && src.tiling()[1] == target_shape[1] && - llvm::isPowerOf2_32(src.tiling()[0])) { - // We should never see a replicated offset here. We're removing the implicit - // dim so the only case when this can happen is when its size is 1 (or else - // we can't prove replication in the logical value). But in that case, the - // equivalentTo case above triggers and we never reach this branch. - CHECK(dst_offset_hints[0].has_value()); - int64_t dst_sublane_offset = *dst_offset_hints[0]; - VectorLayout dst(src.bitwidth(), {dst_sublane_offset, src.offsets()[1]}, - src.tiling(), dst_implicit_dim); - xla::Array new_vregs( - dst.tileArrayImplicitShape(vty.getShape(), target_shape)); - new_vregs.Each([&](const absl::Span idx, Value *tile) { - const int64_t dst_2nd_minor_idx = idx.size() - 2; - SmallVector src_idx(idx.begin(), idx.end()); - src.insertImplicit(src_idx, 0); - const int dst_sl_start = - idx[dst_2nd_minor_idx] == 0 ? dst_sublane_offset : 0; - // This could be optimized further to take offsets[1] into account. - // For example, extended offsets allow us to skip copies of low sublanes - // in tiles with idx.back() == 0. - const int tiles_per_vreg = src.tilesPerVreg(target_shape); - const int sublanes_per_tile = src.sublanesPerTile(target_shape); - src_idx[dst_2nd_minor_idx] = src.tiling()[0] * idx[dst_2nd_minor_idx] + - dst_sl_start - dst_sublane_offset; - for (int dst_sl_idx = dst_sl_start; - dst_sl_idx < src.tiling()[0] && - src_idx[dst_2nd_minor_idx] < vregs.dim(dst_2nd_minor_idx); - ++dst_sl_idx, ++src_idx[dst_2nd_minor_idx]) { - // This could be optimized further by copying multiple sublanes at once. - for (int tile_idx = 0; tile_idx < tiles_per_vreg; ++tile_idx) { - int tile_off = tile_idx * sublanes_per_tile; - *tile = - copy_one_sublane(builder, vregs(src_idx), - tile_off + src.offsets()[0].value_or(dst_sl_idx), - *tile, tile_off + dst_sl_idx, target_shape); - } - } - }); - return std::make_pair(dst, new_vregs); - } - if (src.implicit_dim() == VectorLayout::ImplicitDim::kNone && - dst_implicit_dim == VectorLayout::ImplicitDim::kMinor && - src.bitwidth() == 32 && src.hasNativeTiling(ctx.target_shape)) { - // TODO(tlongeri): Make insertImplicitMinorDimension more flexible about - // offsets, then we can pass dst_offset_hints directly. - const LayoutOffset dst_2nd_minor_offset = - !src.offsets()[1] || *src.offsets()[1] + *(vty.getShape().end() - 1) <= - ctx.target_shape[1] - ? dst_offset_hints[0] - : LayoutOffset(*src.offsets()[1] % ctx.target_shape[0]); - VectorLayout dst(src.bitwidth(), - {dst_2nd_minor_offset, dst_offset_hints[1]}, src.tiling(), - VectorLayout::ImplicitDim::kMinor); - FAILUREOR_ASSIGN_OR_RETURN( - xla::Array dst_vregs, - insertImplicitMinorDimension(ctx, builder, loc, vregs, - src.implicitShape(vty.getShape()), src, - dst.offsets())); - return std::make_pair(dst, std::move(dst_vregs)); - } - return emitError(loc, - "Not implemented: Unsupported implicit dim change: from ") - << src << " to " << dst_implicit_dim; -} - -// TODO(apaszke): Test this function properly -FailureOr> relayout(RewriteContext &ctx, - OpBuilder &builder, - TypedValue v, - VectorLayout src, - VectorLayout dst) { - const auto target_shape = ctx.target_shape; - const int8_t bitwidth = src.bitwidth(); - if (bitwidth != dst.bitwidth()) { - return emitError(v.getLoc(), "Can't change bitwidth during a relayout"); - } - VectorType vty = v.getType(); - const bool is_mask = vty.getElementTypeBitWidth() == 1; - { - // Replication imposes a replication constraint on the *logical* value of - // the vector: When moving along a replicated axis, all elements must be - // equal. Note that when the axis is a singleton, there is effectively no - // added *logical* constraint. - // For example, a vector<2x2xf32> v with no implicit dims and layout offsets - // {*, 0} is expected to satisfy v[0, 0] == v[1, 0] and v[0, 1] == v[1, 1]. - // Relayout does not change the logical value of the vector. Any replication - // constraints in the result must be guaranteed by the source layout. - SmallVector src_offsets(ArrayRef(src.offsets())); - SmallVector dst_offsets(ArrayRef(dst.offsets())); - // Remove implicit dims to get offsets for trailing logical dims. - src.eraseImplicit(src_offsets); - dst.eraseImplicit(dst_offsets); - for (int i = dst_offsets.size(); i > 0; --i) { - const int64_t dim_size = *(vty.getShape().end() - i); - const bool dim_replicated_in_dst = !*(dst_offsets.end() - i); - // If the dim is untiled in the src layout, then there is no guarantee of - // replication, because we don't track replication for untiled dims. - const bool dim_replicated_in_src = - i <= src_offsets.size() && !*(src_offsets.end() - i); - if (dim_replicated_in_dst && !dim_replicated_in_src && dim_size != 1) { - return emitError(v.getLoc(), - "Invalid relayout: Non-singleton logical dimension is " - "replicated in destination but not in source for ") - << vty << ": " << src << " -> " << dst; - } - } - } - - FAILUREOR_ASSIGN_OR_RETURN( - xla::Array src_tiles, - disassemble(builder, src, v, target_shape, /*use_implicit_shape=*/true)); - if (is_mask) { - auto new_tile_ty = getNativeVregOrVmaskType( - builder.getIntegerType(bitwidth), bitwidth, target_shape); - src_tiles.Each([&](const absl::Span idx, Value *tile) { - *tile = - builder.create(tile->getLoc(), new_tile_ty, *tile); - }); - vty = VectorType::get(vty.getShape(), builder.getIntegerType(bitwidth)); - } - auto assemble_with_mask_check = [&](xla::Array &tiles, - bool use_implicit_shape = false) { - if (is_mask) { - auto zeros_tile = builder.create( - tiles.begin()->getLoc(), - DenseElementsAttr::get( - cast(tiles.begin()->getType()), - builder.getIntegerAttr(builder.getIntegerType(bitwidth), 0))); - tiles.Each([&](const absl::Span idx, Value *tile) { - *tile = builder.create( - tile->getLoc(), arith::CmpIPredicate::ne, *tile, zeros_tile); - }); - vty = VectorType::get(vty.getShape(), builder.getI1Type()); - } - return assemble(builder, vty, dst, tiles, target_shape, use_implicit_shape) - .getResult(); - }; - // Two easy cases: source is more general, or is replicated. - if (src.generalizes(dst, vty.getShape(), target_shape)) { - // A value with a replicated offset might use fewer vregs than a value with - // a non-zero offset. - auto src_product = - xla::Product(src.tileArrayShape(vty.getShape(), target_shape)); - auto dst_product = - xla::Product(dst.tileArrayShape(vty.getShape(), target_shape)); - if (src_product != dst_product) { - TPU_ASSERT_LOC(v.getLoc(), dst_product > src_product); - auto src_offsets = src.offsets(); - - TPU_ASSERT_LOC(v.getLoc(), src_offsets != dst.offsets()); - TPU_ASSERT_LOC(v.getLoc(), src.bitwidth() == dst.bitwidth()); - - if (src.implicit_dim() != dst.implicit_dim()) { - return emitError(v.getLoc(), - "Not implemented: Source layout is more general, but " - "vreg count changes and implicit dims are mismatched"); - } - - if (src.tiling() != dst.tiling()) { - return emitError(v.getLoc(), - "Not implemented: Source layout is more general, but " - "vreg count changes and tiling are mismatched"); - } - - // This case is moving from a replicated to a non replicated layout. - // As such, we need to make a new destination shape that is the - // materialization of the src shape with replication. - FAILUREOR_ASSIGN_OR_RETURN(auto src_vregs, - disassemble(builder, src, v, target_shape, - /*use_implicit_shape=*/true)); - auto dst_vregs_shape = dst.tileArrayShape(vty.getShape(), target_shape); - xla::Array dst_vregs(dst_vregs_shape); - dst_vregs.Each([&](const absl::Span idx, Value *vreg) { - SmallVector local_idx(idx.begin(), idx.end()); - if (!src_offsets[0].has_value()) { - local_idx[local_idx.size() - 2] = 0; - } - if (!src_offsets[1].has_value()) { - local_idx[local_idx.size() - 1] = 0; - } - *vreg = src_vregs(local_idx); - }); - return assemble_with_mask_check(dst_vregs, /*use_implicit_shape=*/true); - } - src_tiles.Reshape(dst.tileArrayImplicitShape(vty.getShape(), target_shape)); - return assemble_with_mask_check(src_tiles, - /*use_implicit_shape=*/true); - } - - if (const LayoutOffsets src_offsets = - src.getCanonicalOffsets(vty.getShape(), ctx.target_shape); - src.layout_rank() >= dst.layout_rank() && !src_offsets[0].has_value() && - !src_offsets[1].has_value()) { - // A fully replicated value is always easy to relayout - xla::Array dst_tiles( - dst.tileArrayImplicitShape(vty.getShape(), target_shape)); - SmallVector idxs; - dst_tiles.Each([&](const absl::Span src_idx, Value *vreg) { - idxs.assign(src_idx.begin(), src_idx.end()); - dst.eraseImplicit(idxs); - src.insertImplicit(idxs, 0); - *(idxs.end() - 2) = 0; - *(idxs.end() - 1) = 0; - *vreg = src_tiles(idxs); - }); - return assemble_with_mask_check(dst_tiles, /*use_implicit_shape=*/true); - } - - // Consider (1,128),-2 -> (8,128). In this case we can change the implicit - // dim for free before we change the tiling, but not after. - // TODO(apaszke): In general the number of vregs necessary to represent a - // value for different implicit dims satisfies kNone < kSecondMinor < kMinor. - // We should use this property to decide if we should change the implicit dim - // before or after changing the tiling and offsets. - if (src.implicit_dim() != dst.implicit_dim()) { - VectorLayout src_candidate(src.bitwidth(), src.offsets(), src.tiling(), - dst.implicit_dim()); - if (src_candidate.equivalentTo(src, vty.getShape(), target_shape)) { - src = src_candidate; - src_tiles.Reshape( - src.tileArrayImplicitShape(vty.getShape(), target_shape)); - } - } - - FAILUREOR_ASSIGN_OR_RETURN( - std::tie(src, src_tiles), - changeTiling(ctx, builder, v.getLoc(), vty, src, std::move(src_tiles), - dst.tiling(), dst.offsets())); - - FAILUREOR_ASSIGN_OR_RETURN( - std::tie(src, src_tiles), - changeImplicitDim(ctx, builder, v.getLoc(), vty, src, - std::move(src_tiles), dst.implicit_dim(), - dst.offsets())); - - FAILUREOR_ASSIGN_OR_RETURN( - std::tie(src, src_tiles), - changeOffsets(ctx, builder, v.getLoc(), vty, src, std::move(src_tiles), - dst.offsets())); - - CHECK_EQ(src, dst); // At this point we've should be done. - return assemble_with_mask_check(src_tiles, - /*use_implicit_shape=*/true); -} - -// TODO(apaszke): Implement a debug mode that inserts additional assertions. -// For example, we should verify that ops that were supposed to generate -// replicated outputs satisfy that requirement. -LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) { - // When an operation does not have any operands, the layout_in tuple is empty. - // If one of the operands is not of vector type, the corresponding entry in - // the layout_in tuple will be None. The same applies to the results of the - // operation and the layout_out tuple. - FAILUREOR_ASSIGN_OR_RETURN(const SmallVector layouts_out, - getOutLayouts(op, ctx.target_shape)); - FAILUREOR_ASSIGN_OR_RETURN(const SmallVector layouts_in, - getInLayouts(op, ctx.target_shape)); - if (!layouts_in.empty() && !isa(op)) { - // Relayout the operands, if their requested input layouts don't match the - // layouts in which they were produced. - for (auto [idx, tup] : - llvm::enumerate(llvm::zip(op.getOperands(), layouts_in))) { - auto [operand, li] = tup; - auto vector_operand = dyn_cast>(operand); - TPU_ASSERT_EQ_OP(vector_operand != nullptr, li.has_value()); - if (vector_operand == nullptr) { - continue; - } - // The operand should always be an Operation (and not a BlockArgument) - // since we expect the FuncOp to have only memrefs and semaphores as - // arguments. - auto op_result = dyn_cast(vector_operand); - if (op_result == nullptr) { - return op.emitError( - "Expected vector operand to be an operation result"); - } - Operation *const def_op = op_result.getOwner(); - DCHECK(def_op); - const unsigned res_idx = op_result.getResultNumber(); - FAILUREOR_ASSIGN_OR_RETURN(const SmallVector def_layouts, - getOutLayouts(*def_op, ctx.target_shape)); - const Layout lo = def_layouts[res_idx]; - TPU_ASSERT_OP(lo.has_value()); - if (*lo == *li) { - continue; - } - OpBuilder builder(&op); - FAILUREOR_ASSIGN_OR_RETURN( - Value new_v, relayout(ctx, builder, vector_operand, /*src=*/*lo, - /*dst=*/*li)); - op.setOperand(idx, new_v); - } - } - - // TODO: b/342235360 - This check is temporary while we increase and test - // support for offsets outside of the first tile. When support is more broad, - // any op without support should check it within their own rule. - if (!isa(op)) { - for (const Layout &layout : layouts_in) { - if (layout && layout->offsets()[1].has_value() && - layout->offsets()[1].value() >= layout->tiling()[1]) { - return op.emitError( - "Not implemented: Input offsets outside of the first tile"); - } - } - } - const bool no_vector_args = - llvm::none_of(layouts_out, - [](Layout layout) { return layout.has_value(); }) && - llvm::none_of(layouts_in, - [](Layout layout) { return layout.has_value(); }); - if (no_vector_args && op.getRegions().empty()) { - // We don't need to do anything for scalar operations. - if (!op.getOperands().empty()) { - op.removeAttr("in_layout"); - } - if (!op.getResults().empty()) { - op.removeAttr("out_layout"); - } - return success(); - } - if (auto rule_it = rules().find(op.getName().getStringRef()); - rule_it != rules().end()) { - const rule_type &rule = rule_it->getValue(); - return rule(ctx, op, layouts_in, layouts_out); - } - if (OpTrait::hasElementwiseMappableTraits(&op)) { - return elementwise_op_rule(ctx, op, layouts_in, layouts_out); - } - return op.emitError("Not implemented: Unsupported operation: ") - << op.getName() << " in apply-vector-layout pass"; -} - -LogicalResult applyLayoutBlock(RewriteContext &ctx, Block &block) { - // We'll be modifying the block, so use early increment. - for (Operation &op : make_early_inc_range(block)) { - if (failed(applyLayoutOp(ctx, op))) { - return failure(); - } - } - return success(); -} - -// Rewrites the function according to layout annotations of its operations. -// -// Args: -// ctx: The context used for rewriting. -// f: An MLIR function to be rewritten. -LogicalResult applyLayoutFunc(RewriteContext &ctx, func::FuncOp f) { - if (f->getNumRegions() != 1) { - return f.emitError("Expected FuncOp to have a single region"); - } - if (!f.getBody().hasOneBlock()) { - return f.emitError("Expected FuncOp to have a single block"); - } - return applyLayoutBlock(ctx, f.getBody().front()); -} - -struct ApplyVectorLayoutPass - : public impl::ApplyVectorLayoutPassBase { - ApplyVectorLayoutPass(const RewriteContext &ctx) { - hardware_generation = ctx.hardware_generation; - sublane_count = ctx.target_shape[0]; - lane_count = ctx.target_shape[1]; - mxu_contracting_size = ctx.mxu_shape[0]; - mxu_noncontracting_size = ctx.mxu_shape[1]; - max_sublanes_in_scratch = ctx.max_sublanes_in_scratch; - vmem_banks = ctx.vmem_banks; - max_shuffle_sublane_offset = ctx.max_shuffle_sublane_offset; - } - void runOnOperation() override { - // Fail if hardware_generation has not been set from the default value. - if (hardware_generation < 0) { - signalPassFailure(); - return; - } - RewriteContext ctx{ - .hardware_generation = hardware_generation, - .target_shape = {sublane_count, lane_count}, - .mxu_shape = {mxu_contracting_size, mxu_noncontracting_size}, - .max_sublanes_in_scratch = max_sublanes_in_scratch, - .vmem_banks = vmem_banks, - .max_shuffle_sublane_offset = max_shuffle_sublane_offset, - }; - if (failed(applyLayoutFunc(ctx, getOperation()))) { - signalPassFailure(); - return; - } - } -}; - -std::unique_ptr> createApplyVectorLayoutPass( - const RewriteContext &ctx) { - return std::make_unique(ctx); -} -} // namespace mlir::tpu diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h deleted file mode 100644 index ed72a21028eb..000000000000 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h +++ /dev/null @@ -1,61 +0,0 @@ -#ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_APPLY_VECTOR_LAYOUT_H_ -#define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_APPLY_VECTOR_LAYOUT_H_ - -#include -#include - -#include "mlir/IR/Builders.h" -#include "mlir/IR/Value.h" -#include "mlir/Support/LogicalResult.h" -#include "jaxlib/mosaic/dialect/tpu/layout.h" -#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" -#include "xla/array.h" - -namespace mlir::tpu { - -// TODO(tlongeri): Remove default values for use_implicit_shape. -RollVectorsOp assemble(OpBuilder &builder, VectorType vty, - const VectorLayout &layout, - const xla::Array &vals, - std::array target_shape, - bool use_implicit_shape = false); -FailureOr> disassemble(OpBuilder &builder, - const VectorLayout &layout, - TypedValue val, - std::array target_shape, - bool use_implicit_shape = false); - -// Rewrites the operation according to its layout annotations. -// -// Args: -// ctx: The context used for rewriting. -// op: An MLIR operation to be rewritten. -// -// A valid op is expected to have a layout_in attribute unless it has no -// operands. The layout_in attribute must fulfill the following: -// - All vector operands originate from an operation (not a BlockArgument) -// and -// have a valid layout (Layout1D or Layout2D) -// - All non-vector operands must have NoLayout. -LogicalResult applyLayoutOp(ApplyVectorLayoutContext &ctx, Operation &op); - -// Changes the layout of a vector value. -// -// Arguments: -// ctx: The context used for rewriting. -// builder: The builder used for rewriting. -// v: The value to relayout. Must be of type VectorType. -// src: The current layout of v. -// dst: The target layout of v. -// -// Returns: -// A new MLIR vector value, laid out as requested by dst. -FailureOr> relayout(ApplyVectorLayoutContext &ctx, - OpBuilder &builder, - TypedValue v, - VectorLayout src, - VectorLayout dst); - -} // namespace mlir::tpu - -#endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_APPLY_VECTOR_LAYOUT_H_ diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h deleted file mode 100644 index 33c9e7421004..000000000000 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h +++ /dev/null @@ -1,21 +0,0 @@ -#ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANFORMS_APPLY_VECTOR_LAYOUT_EXTENSIONS_H_ -#define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANFORMS_APPLY_VECTOR_LAYOUT_EXTENSIONS_H_ - -#include - -#include "llvm/include/llvm/ADT/StringMap.h" -#include "mlir/include/mlir/IR/Operation.h" -#include "mlir/include/mlir/Support/LLVM.h" -#include "jaxlib/mosaic/dialect/tpu/layout.h" -#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" - -namespace mlir::tpu::extensions { - -const llvm::StringMap< - std::function, ArrayRef)>> & -rules(); - -} // namespace mlir::tpu::extensions - -#endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANFORMS_APPLY_VECTOR_LAYOUT_EXTENSIONS_H_ diff --git a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc deleted file mode 100644 index 5efbdb9cb437..000000000000 --- a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc +++ /dev/null @@ -1,771 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include - -#include "llvm/ADT/STLExtras.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -// It requires these headers, but does not include them. -// NOLINTNEXTLINE(misc-include-cleaner) -#include "mlir/Dialect/MemRef/IR/MemRef.h" -// NOLINTNEXTLINE(misc-include-cleaner) -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringMap.h" -#include "llvm/ADT/StringSet.h" -#include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LogicalResult.h" -#include "absl/log/check.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/Dialect/Math/IR/Math.h" -#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h" -#include "mlir/include/mlir/IR/AffineExpr.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/Block.h" -#include "mlir/include/mlir/IR/Builders.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/include/mlir/IR/OpDefinition.h" -#include "mlir/include/mlir/IR/Operation.h" -#include "mlir/include/mlir/IR/PatternMatch.h" -#include "mlir/include/mlir/IR/Region.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/Support/LLVM.h" -#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" -#include "jaxlib/mosaic/dialect/tpu/vreg_util.h" - -namespace mlir::tpu { - -#define GEN_PASS_DECL_CANONICALIZEMOSAICPASS -#define GEN_PASS_DEF_CANONICALIZEMOSAICPASS -#include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc" - -namespace { - -struct CanonicalizeContext { - // see Note: Compatibility mode - bool compatibility_mode; - - int hardware_generation; -}; - -LogicalResult tpu_matmul_rule(const CanonicalizeContext &ctx, - tpu::MatmulOp op) { - ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation()); - - auto transpose_lhs = op.getTransposeLhs(); - auto transpose_rhs = op.getTransposeRhs(); - - auto lhs = op.getLhs(); - auto rhs = op.getRhs(); - auto acc = op.getAcc(); - - const VectorType lhs_ty = lhs.getType(); - const VectorType rhs_ty = rhs.getType(); - const VectorType acc_ty = acc.getType(); - - auto lhs_element_type = lhs_ty.getElementType(); - auto rhs_element_type = rhs_ty.getElementType(); - auto acc_element_type = acc_ty.getElementType(); - - // there are a few primary paths for dimension_numbers in matmul - // 1) No dimension numbers provided -> set to default - // 2) defined and not default -> verify and apply - // 3) defined and matching defaultDimensionNumbers -> no-op for - // canonicalization of dims - std::optional batch_size = std::nullopt; - - // MKN matmul - no dims or transpositions set - if (!op.getDimensionNumbers().has_value()) { - // Legacy API - convert it to dimension numbers - op.setDimensionNumbersAttr( - defaultDimensionNumbers(builder, transpose_lhs, transpose_rhs)); - } else if ( - // Dot dim API - dimensions are provided and are not default - (op.getDimensionNumbers().value() != - defaultDimensionNumbers(builder, false, false))) { - auto dimension_numbers = op.getDimensionNumbers(); - auto lhs_contracting_dims = dimension_numbers->getLhsContractingDims(); - auto rhs_contracting_dims = dimension_numbers->getRhsContractingDims(); - - auto lhs_batch_dims = dimension_numbers->getLhsBatchDims(); - auto rhs_batch_dims = dimension_numbers->getRhsBatchDims(); - - // Invariant in matmul verifier: <= 1 batch dim atm, and that lhs and rhs - // are the same - // Invariant in matmul verifier: Exactly one contracting and non contracting - // dim in each of lhs and rhs for now. - batch_size = - lhs_batch_dims.empty() - ? std::nullopt - : std::optional(lhs_ty.getShape()[lhs_batch_dims[0]]); - // Lower each dim in contracting dims by size(batch_dims) - auto batch_adjusted_lhs_contracting_dim = - lhs_contracting_dims[0] - lhs_batch_dims.size(); - auto batch_adjusted_rhs_contracting_dim = - rhs_contracting_dims[0] - rhs_batch_dims.size(); - - if (batch_adjusted_lhs_contracting_dim != 1) { - transpose_lhs = true; - } - if (batch_adjusted_rhs_contracting_dim != 0) { - transpose_rhs = true; - } - } - - auto extsi_sitofp = [&builder, &op](TypedValue element) { - const VectorType ty = element.getType(); - auto shape = ty.getShape(); - CHECK(ty.getElementType().isInteger()); - TypedValue ext_ele; - if (ty.getElementType().getIntOrFloatBitWidth() == 32) { - ext_ele = element; - } else { - ext_ele = cast>( - builder - .create( - VectorType::get(shape, builder.getI32Type()), element) - .getResult()); - } - // TODO(mvoz): Go to bf16 when hardware supported, requires adding support - // for 16 bitwidth in extsiop in infer/apply. - auto ele_as_fp = builder.create( - op.getLoc(), VectorType::get(shape, builder.getF32Type()), ext_ele); - return ele_as_fp; - }; - - if (lhs_element_type != rhs_element_type) { - if (!ctx.compatibility_mode) { - return op->emitOpError( - "Mosaic matmul invoked with mixed element types, but compatibility " - "mode is disabled."); - } - if (lhs_element_type.isInteger() && rhs_element_type.isInteger()) { - // TODO(mvoz): Add support for mixed int/int matmul. - op->emitOpError("Mix int/int - NYI"); - return failure(); - } - if (acc_element_type.isInteger()) { - // TODO(mvoz): Add support for mixed int/float matmul with int acc. - // Should be pretty straightforward. - op->emitOpError("acc is int in mixed matmul. Expected float."); - return failure(); - } - if (lhs_element_type.isInteger()) { - auto float_lhs = extsi_sitofp(lhs); - op->setOperand(0, float_lhs); - lhs = cast>(float_lhs.getResult()); - } - if (rhs_element_type.isInteger()) { - auto float_rhs = extsi_sitofp(rhs); - op->setOperand(1, float_rhs); - rhs = cast>(float_rhs.getResult()); - } - } - // TODO(mvoz): Add more invariants. - if (acc_element_type.isInteger()) { - if (!op.getLhs().getType().getElementType().isInteger()) { - op->emitOpError("int acc with float lhs. Expected int lhs."); - return failure(); - } - if (!op.getRhs().getType().getElementType().isInteger()) { - op->emitOpError("int acc with float rhs. Expected int rhs."); - return failure(); - } - } else { - if (op.getLhs().getType().getElementType().isInteger()) { - op->emitOpError("float acc with int lhs. Expected float lhs."); - return failure(); - } - if (op.getRhs().getType().getElementType().isInteger()) { - op->emitOpError("float acc with int rhs. Expected float rhs."); - return failure(); - } - } - - auto dot_dim_matmul = [&](auto lhs, auto rhs, auto acc) { - auto precision_attr = op.getPrecisionAttr(); - - // If we are transposing the lhs, we need to transpose the lhs before - // matmul here, as we don't have lhs fusion implemented in apply. - if (transpose_lhs) { - auto lhs_ty = cast(lhs.getType()); - auto rank = lhs_ty.getShape().size(); - - // This transposition must run on vectors with rank >= 2 - CHECK_GE(rank, 2); - - std::vector perm(rank); - std::iota(perm.begin(), perm.end(), 0); - std::swap(perm[rank - 2], perm[rank - 1]); - - std::vector shape(lhs_ty.getShape()); - std::swap(shape[rank - 2], shape[rank - 1]); - - auto lhs_ty_transposed = VectorType::get(shape, lhs_ty.getElementType()); - - const SmallVector perm_vec = - SmallVector(perm.begin(), perm.end()); - lhs = builder.create( - lhs_ty_transposed, lhs, - DenseI64ArrayAttr::get(builder.getContext(), perm_vec)); - } - auto ddn = defaultDimensionNumbers(builder, /*transpose_lhs=*/false, - transpose_rhs); - // transpose flags are always false here, because ddn takes precedence - // after this pass. - auto matmul_res = builder.create( - op.getLoc(), acc.getType(), lhs, rhs, acc, - /*transpose_lhs=*/false, - /*transpose_rhs=*/false, precision_attr, ddn); - return matmul_res; - }; - - // If we have a batch_size, we want to slice rhs and lhs [:batch_size], - // and then do O[i] = A[i] @ B[i] - // Produce an output shape of [batch_size, m, n] - if (batch_size.has_value()) { - std::vector outputs; - - for (int64_t i = 0; i < batch_size; ++i) { - auto sliced_lhs = builder.create(op.getLoc(), lhs, - ArrayRef{i}); - auto sliced_rhs = builder.create(op.getLoc(), rhs, - ArrayRef{i}); - - auto sliced_acc = builder.create(op.getLoc(), acc, - ArrayRef{i}); - - auto matmul_res = - dot_dim_matmul(sliced_lhs.getResult(), sliced_rhs.getResult(), - sliced_acc.getResult()); - auto res_ty = matmul_res.getType().cast(); - auto res_shape = res_ty.getShape(); - // reshape to 1x[prior_shape] - auto reshape_shape = llvm::to_vector(res_shape); - reshape_shape.insert(reshape_shape.begin(), 1); - auto shape_cast = builder.create( - op.getLoc(), VectorType::get(reshape_shape, res_ty.getElementType()), - matmul_res); - outputs.push_back(shape_cast); - } - // Technically almost identical to the case where batch_size is 1, but - // we want to avoid the spurious concat here. - if (batch_size == 1) { - op.replaceAllUsesWith(outputs[0]); - op.erase(); - return success(); - } - auto output = builder - .create(op.getLoc(), acc_ty, outputs, - /*dimension=*/0) - .getResult(); - op.replaceAllUsesWith(output); - op.erase(); - } else { - auto matmul_res = dot_dim_matmul(lhs, rhs, acc).getResult(); - op.replaceAllUsesWith(matmul_res); - op.erase(); - } - return success(); -}; - -LogicalResult canonicalize_elementwise(const CanonicalizeContext &ctx, - Operation &op) { - OpBuilder builder(&op); - auto operands = op.getOperands(); - auto res_ty = dyn_cast(op.getResult(0).getType()); - if (op.getNumResults() != 1) { - op.emitOpError("Invariant violated: Unexpected number of results"); - return failure(); - } - if (!res_ty) { - // scalar - // TODO(mvoz): Add canonicalization and invariants for scalar elementwise - // ops. - return success(); - } - auto shape = res_ty.getShape(); - std::vector new_operands; - new_operands.reserve(operands.size()); - - bool should_rewrite_op = false; - auto target_f32_ty = VectorType::get(shape, builder.getF32Type()); - for (int i = 0; i < operands.size(); ++i) { - auto operand = operands[i]; - auto ty = dyn_cast(operand.getType()); - if (ty) { - if (ty.getShape() != shape) { - // Should already be checked my MLIR verification, but let's be safe. - op.emitOpError("Mismatched shapes in elementwise op."); - return failure(); - } - auto element_type = ty.getElementType(); - // There's an annoying hodgepodge of elementwise ops that need to be - // rewritten to f32 on later hardware. - // TODO(mvoz): Look into (1) what it would take to support these ops - // natively on later hardware, and (2) how to better organize this list. - bool needs_cast = ctx.hardware_generation <= 5 || isa(op) || - isa(op) || isa(op) || - isa(op); - if (needs_cast && element_type.isBF16()) { - if (ctx.compatibility_mode) { - auto target_f32 = - builder.create(op.getLoc(), target_f32_ty, operand) - .getResult(); - should_rewrite_op = true; - new_operands.push_back(target_f32); - } else { - op.emitOpError( - "Compatibility mode disabled. Unsupported element type in " - "elementwise op on hardware generation: ") - << ctx.hardware_generation - << ". Use hardware generation after 5 or cast to f32."; - return failure(); - } - } else { - new_operands.push_back(operand); - } - } else { - // Should already be checked my MLIR verification, but let's be safe. - op.emitOpError("MLIR unsupported - mix scalar and vec elementwise ops"); - return failure(); - } - } - if (should_rewrite_op) { - auto result_ty = dyn_cast(op.getResult(0).getType()); - if (!result_ty) { - op.emitOpError("Not implemented: Unexpected result type"); - return failure(); - } - auto result_element_type = result_ty.getElementType(); - if (!result_element_type.isF32() && !result_element_type.isBF16()) { - op.emitOpError("Not implemented: Unexpected result element type"); - return failure(); - } - // Do the new op in f32, then truncate to the original element type. - auto new_op = builder.create(op.getLoc(), op.getName().getIdentifier(), - new_operands, target_f32_ty); - new_op = builder.create(op.getLoc(), res_ty, - new_op->getResult(0)); - op.replaceAllUsesWith(new_op); - op.erase(); - } - return success(); -} - -LogicalResult canonicalize_multi_dim_reduction(const CanonicalizeContext &ctx, - Operation &operation) { - ImplicitLocOpBuilder builder(operation.getLoc(), &operation); - auto op = cast(operation); - auto source_ty = op.getSourceVectorType(); - auto result_ty = dyn_cast(op.getDestType()); - if (!result_ty) { - return op->emitOpError() << "Only vector reductions supported"; - } - - auto element_type = source_ty.getElementType(); - if (element_type.isF32()) { - return success(); - } else if (element_type.isBF16()) { - bool reduces_sublanes = false; - for (int64_t dim : op.getReductionDims()) { - if (dim == source_ty.getRank() - 2) { - reduces_sublanes = true; - } - } - if (ctx.hardware_generation <= 5) { - auto new_source = builder.create( - VectorType::get(source_ty.getShape(), builder.getF32Type()), - op.getSource()); - - auto result_ty_f32 = - VectorType::get(result_ty.getShape(), builder.getF32Type()); - auto acc_ext = builder.create(result_ty_f32, op.getAcc()); - Value new_acc = acc_ext.getResult(); - // Try to constant fold. - if (auto const_acc = op.getAcc().getDefiningOp()) { - auto result = - acc_ext.fold(arith::ExtFOp::FoldAdaptor(const_acc.getValue())); - if (!result.isNull() && result.is()) { - acc_ext->erase(); - new_acc = builder.create( - op.getLoc(), result_ty_f32, - cast(result.get())); - } - } - auto new_op = builder.create( - op.getLoc(), new_acc.getType(), op.getKindAttr(), new_source, new_acc, - DenseI64ArrayAttr::get(builder.getContext(), op.getReductionDims())); - auto new_result = builder.create(op.getLoc(), result_ty, - new_op.getResult()); - op.replaceAllUsesWith(new_result.getResult()); - op.erase(); - } - return success(); - } else if (element_type.isSignlessInteger(32) && - // TODO(b/384774084): Add support for u32 reductions. - (op.getKind() == vector::CombiningKind::ADD || - op.getKind() == vector::CombiningKind::MAXSI || - op.getKind() == vector::CombiningKind::MINSI)) { - return success(); - } - op.emitOpError("Unsupported element type for the selected reduction"); - return failure(); -} - -LogicalResult canonicalize_matmul(const CanonicalizeContext &ctx, - Operation &op) { - auto matmul_op = dyn_cast(op); - if (!matmul_op) { - op.emitOpError("Invariant violated: Not a matmul"); - return failure(); - } - return tpu_matmul_rule(ctx, matmul_op); -}; - -LogicalResult canonicalize_contraction(const CanonicalizeContext &ctx, - Operation &op) { - auto contraction_op = dyn_cast(op); - if (!contraction_op) { - op.emitOpError("Invariant violated: Not a contraction"); - return failure(); - } - // Rewrite the contraction as a matmul - auto lhs = contraction_op.getLhs(); - auto rhs = contraction_op.getRhs(); - auto acc = contraction_op.getAcc(); - VectorType acc_ty; - if (!(acc_ty = dyn_cast(acc.getType()))) { - contraction_op->emitOpError("Not implemented: acc must be a vector"); - return failure(); - } - - if (contraction_op.getKind() != vector::CombiningKind::ADD) { - contraction_op->emitOpError("Only ADD supported"); - return failure(); - } - - ImplicitLocOpBuilder builder(contraction_op->getLoc(), - contraction_op.getOperation()); - - MLIRContext *const mlir_ctx = contraction_op->getContext(); - - auto getMapAttr = [&](const unsigned first, const unsigned second) { - return AffineMapAttr::get(AffineMap::get( - 3, 0, - {getAffineDimExpr(first, mlir_ctx), getAffineDimExpr(second, mlir_ctx)}, - mlir_ctx)); - }; - - const ArrayAttr matmul_indexing_maps = builder.getArrayAttr( - {getMapAttr(0, 2), getMapAttr(2, 1), getMapAttr(0, 1)}); - const ArrayAttr matmul_indexing_maps_transposed = builder.getArrayAttr( - {getMapAttr(0, 2), getMapAttr(1, 2), getMapAttr(0, 1)}); - const auto indexing_maps = contraction_op.getIndexingMaps(); - if (indexing_maps != matmul_indexing_maps && - indexing_maps != matmul_indexing_maps_transposed) { - return contraction_op->emitOpError( - "Not implemented: Non-matmul or unsupported indexing_maps"); - } - const bool transpose_rhs = indexing_maps == matmul_indexing_maps_transposed; - - const ArrayAttr matmul_iterator_types = - builder.getArrayAttr({builder.getAttr( - vector::IteratorType::parallel), - builder.getAttr( - vector::IteratorType::parallel), - builder.getAttr( - vector::IteratorType::reduction)}); - if (contraction_op->getAttr("iterator_types") != matmul_iterator_types) { - return contraction_op->emitOpError( - "Not implemented: Non-matmul iterator_types"); - } - const tpu::ContractPrecisionAttr precision_attr = // May be null - contraction_op->getAttrOfType("precision"); - - const auto dot_dimension_numbers_attr = - defaultDimensionNumbers(builder, false, transpose_rhs); - - auto matmul_op = builder.create( - contraction_op->getLoc(), acc_ty, lhs, rhs, acc, - /*transpose_lhs=*/false, - /*transpose_rhs=*/false, precision_attr, dot_dimension_numbers_attr); - contraction_op.replaceAllUsesWith(matmul_op.getResult()); - contraction_op.erase(); - auto result = tpu_matmul_rule(ctx, matmul_op); - return result; -} - -LogicalResult canonicalize_extract(const CanonicalizeContext &ctx, - Operation &raw_op) { - auto op = dyn_cast(raw_op); - Type result_ty = op.getResult().getType(); - if (!isa(result_ty)) { - bool is_supported = result_ty.isSignlessIntOrFloat() && - result_ty.getIntOrFloatBitWidth() == 32; - if (!is_supported) { - return op.emitOpError( - "Only 32-bit scalar vector.extracts supported. Cast your input to a " - "32-bit type first."); - } - } - return success(); -} - -LogicalResult canonicalize_select(const CanonicalizeContext &ctx, - Operation &raw_op) { - auto op = dyn_cast(raw_op); - if (!isa(op.getType()) || - isa(op.getCondition().getType())) { - return success(); - } - // Canonicalize `i1 ? v1 : v2` -> `broadcast(i1) ? v1 : v2`. - ImplicitLocOpBuilder builder(op->getLoc(), op.getOperation()); - auto cond_ty = VectorType::get(cast(op.getType()).getShape(), - op.getCondition().getType()); - auto cond = builder.create(cond_ty, op.getCondition()); - auto new_op = builder.create( - op.getLoc(), cond, op.getTrueValue(), op.getFalseValue()); - op.replaceAllUsesWith(new_op.getResult()); - op.erase(); - return success(); -} - -// All conversions that change bitwidth must be canonicalized to tpu.fptosi. -LogicalResult canonicalize_fptosi(const CanonicalizeContext &ctx, - Operation &raw_op) { - auto op = cast(raw_op); - ImplicitLocOpBuilder builder(op->getLoc(), op.getOperation()); - auto src_vty = dyn_cast(op.getIn().getType()); - auto dst_vty = dyn_cast(op.getType()); - if (static_cast(src_vty) != static_cast(dst_vty)) { - return op.emitOpError("Vector/scalar mismatch between input and output"); - } - bool is_vector = static_cast(src_vty); - unsigned src_bitwidth, dst_bitwidth; - if (is_vector) { - src_bitwidth = src_vty.getElementTypeBitWidth(); - dst_bitwidth = dst_vty.getElementTypeBitWidth(); - } else { - src_bitwidth = op.getIn().getType().getIntOrFloatBitWidth(); - dst_bitwidth = op.getType().getIntOrFloatBitWidth(); - } - if (dst_bitwidth > 32) { - return op.emitOpError("Target bitwidth too large"); - } - // We have low-level optimized code for bf16->s8 and bf16->s4 casts on v6. - if (ctx.hardware_generation >= 6 && is_vector && - src_vty.getElementType().isBF16() && - (dst_vty.getElementType().isSignlessInteger(8) || - dst_vty.getElementType().isSignlessInteger(4))) { - auto new_op = builder.create( - op.getType(), op.getIn(), tpu::RoundingMode::kTowardsZero); - op.replaceAllUsesWith(new_op.getResult()); - op.erase(); - // We briefly trigger canonicalization here to potentially fuse the rounding - // ops into the newly created tpu.fptosi. - { - PatternRewriter rewriter(new_op.getContext()); - rewriter.setInsertionPoint(new_op); - // We don't care if the canonicalization pattern matched or not. - (void)tpu::FPToSIOp::canonicalize(new_op, rewriter); - new_op = nullptr; // Canonicalization may have erased the op! - } - return success(); - } - Value x = op.getIn(); - // Upcast the input to f32. - if (src_bitwidth < 32) { - if (is_vector) { - x = builder.create( - VectorType::get(src_vty.getShape(), builder.getF32Type()), x); - } else { - x = builder.create(builder.getF32Type(), x); - } - } - if (dst_bitwidth < 32) { - if (!ctx.compatibility_mode) { - return op.emitOpError( - "On this target only float-to-integer conversions can only happen on " - "32-bit values. Enable compatibility mode or upcast to float32."); - } - // Need to clip values to match XLA - auto clip = [&](Value x, Value low, Value high) { - x = builder.create(x, low); - x = builder.create(x, high); - return x; - }; - auto minval = builder.getF32FloatAttr( - APInt::getSignedMinValue(dst_bitwidth).getSExtValue()); - auto maxval = builder.getF32FloatAttr( - APInt::getSignedMaxValue(dst_bitwidth).getSExtValue()); - if (is_vector) { - auto x_vty = cast(x.getType()); - x = clip(x, getFullVector(builder, x_vty, minval), - getFullVector(builder, x_vty, maxval)); - } else { - auto f32 = builder.getF32Type(); - x = clip(x, builder.create(f32, minval), - builder.create(f32, maxval)); - } - } - if (is_vector) { - x = builder.create( - VectorType::get(src_vty.getShape(), builder.getI32Type()), x); - } else { - x = builder.create(builder.getI32Type(), x); - } - if (dst_bitwidth < 32) { - if (!ctx.compatibility_mode) { - return op.emitOpError( - "On this target only float-to-integer conversions can only happen on " - "32-bit values. Enable compatibility mode or cast to int32 and " - "truncate later."); - } - x = builder.create(op.getType(), x); - } - op.replaceAllUsesWith(x); - op.erase(); - return success(); -} - -LogicalResult canonicalize_repeat(const CanonicalizeContext &ctx, - Operation &raw_op) { - auto op = dyn_cast(raw_op); - if (!isa(op.getType())) { - return op.emitOpError("Only vector types supported"); - } - auto operand = op.getSource(); - auto times = op.getTimes(); - if (times == 1) { - // A true no op - kind of an odd edge case, but this does come up in - // flash_attention_backward tests. - op.replaceAllUsesWith(operand); - op.erase(); - return success(); - } - auto operands = std::vector(times, operand); - ImplicitLocOpBuilder builder(op->getLoc(), op.getOperation()); - auto concat = builder.create(op.getLoc(), op.getType(), - operands, op.getDimension()); - op.replaceAllUsesWith(concat.getResult()); - op.erase(); - return success(); -} - -using canonicalize_rule_type = - std::function; - -const llvm::StringMap &rules() { - static auto rules = new llvm::StringMap{ - {tpu::MatmulOp::getOperationName(), canonicalize_matmul}, - {vector::ContractionOp::getOperationName(), canonicalize_contraction}, - {vector::ExtractOp::getOperationName(), canonicalize_extract}, - {vector::MultiDimReductionOp::getOperationName(), - canonicalize_multi_dim_reduction}, - {arith::SelectOp::getOperationName(), canonicalize_select}, - {arith::FPToSIOp::getOperationName(), canonicalize_fptosi}, - {tpu::RepeatOp::getOperationName(), canonicalize_repeat}}; - return *rules; -} - -bool need_elementwise_canonicalization(CanonicalizeContext ctx, Operation &op) { - if (isa(op)) { - auto vec_ty = dyn_cast(op.getOperand(0).getType()); - if (vec_ty && vec_ty.getElementType().isBF16() && - ctx.hardware_generation >= 4) { - return false; - } - return true; - } - return isa(op); -} - -class MosaicCanonicalizer { - public: - MosaicCanonicalizer(int hardware_generation, bool compatibility_mode) - : hardware_generation_(hardware_generation), - compatibility_mode_(compatibility_mode) {} - - int hardware_generation_; - bool compatibility_mode_; - - LogicalResult canonicalize(func::FuncOp op) { - if (!op.getBody().hasOneBlock()) { - op.emitOpError("Only one block functions supported"); - return failure(); - } - return canonicalizeBlock(op.getBody().front()); - } - - LogicalResult canonicalizeBlock(Block &block) { - // make_early_inc_range is utilized due to op mutation. - for (Operation &any_op : make_early_inc_range(block)) { - if (canonicalizeOp(any_op).failed()) { - return failure(); - } - } - return success(); - } - - LogicalResult canonicalizeOp(Operation &any_op) { - CanonicalizeContext ctx({compatibility_mode_, hardware_generation_}); - // We must iterate over the op first, because canonicalization can cause - // us to .erase() an op, and accessing getRegions on it after is not sound. - // Invariant - top level ops with regions may never be invalidated. - for (Region ®ion : any_op.getRegions()) { - for (Block &block : region) { - if (canonicalizeBlock(block).failed()) { - return failure(); - } - } - } - if (need_elementwise_canonicalization(ctx, any_op)) { - return canonicalize_elementwise(ctx, any_op); - } - if (auto rule_it = rules().find(any_op.getName().getStringRef()); - rule_it != rules().end()) { - const canonicalize_rule_type &rule = rule_it->getValue(); - return rule(ctx, any_op); - } - return success(); - } -}; - -struct CanonicalizeMosaicPass - : public impl::CanonicalizeMosaicPassBase { - CanonicalizeMosaicPass(int hardware_generation_p, bool compatibility_mode_p) - : compatibility_mode_(compatibility_mode_p) { - this->hardware_generation = hardware_generation_p; - } - - void runOnOperation() override { - func::FuncOp func = getOperation(); - MosaicCanonicalizer vlc(hardware_generation, compatibility_mode_); - if (vlc.canonicalize(func).failed()) { - signalPassFailure(); - } - }; - - bool compatibility_mode_; -}; - -} // namespace - -std::unique_ptr> createCanonicalizeMosaicPass( - int hardware_generation, bool compatibility_mode) { - return std::make_unique(hardware_generation, - compatibility_mode); -} - -} // namespace mlir::tpu diff --git a/jaxlib/mosaic/dialect/tpu/transforms/communication.cc b/jaxlib/mosaic/dialect/tpu/transforms/communication.cc deleted file mode 100644 index 89e3a8bb9f70..000000000000 --- a/jaxlib/mosaic/dialect/tpu/transforms/communication.cc +++ /dev/null @@ -1,148 +0,0 @@ -/* Copyright 2023 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include - -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/IR/Visitors.h" -#include "mlir/Support/LogicalResult.h" -#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" -#include "xla/layout.h" - -namespace mlir::tpu { - -namespace { - -struct CommsAnalysisState { - bool has_communication = false; - bool has_custom_barrier = false; - - explicit operator bool() { return has_communication && has_custom_barrier; } -}; - -void analyzeCrossChipCommunication(mlir::Operation *op, - CommsAnalysisState *state) { - if (auto dma = dyn_cast(op)) { - state->has_communication |= dma.getDeviceId() != nullptr; - } else if (auto signal = dyn_cast(op)) { - state->has_communication |= signal.getDeviceId() != nullptr; - } else if (auto barrier = dyn_cast(op)) { - state->has_custom_barrier = true; - } - for (Region ®ion : op->getRegions()) { - for (Block &block : region.getBlocks()) { - for (Operation &op : block.getOperations()) { - analyzeCrossChipCommunication(&op, state); - if (*state) { - return; - } - } - } - } -} - -} // namespace - -std::pair mightCommunicateBetweenChips(mlir::Operation *op) { - CommsAnalysisState state; - analyzeCrossChipCommunication(op, &state); - return std::make_pair(state.has_communication, state.has_custom_barrier); -} - -#define GEN_PASS_DECL_LOGICALTOPHYSICALDEVICEIDPASS -#define GEN_PASS_DEF_LOGICALTOPHYSICALDEVICEIDPASS -#include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc" - -namespace { - -template -void logicalToPhysicalDeviceIds(Op op, Value device_assignment) { - auto device_id = op.getDeviceIdMutable(); - if (device_id.empty()) { - return; - } - CHECK_EQ(device_id.size(), 1); - mlir::OpBuilder builder(op); - auto logical_id = builder.create( - op.getLoc(), builder.getIndexType(), op.getDeviceId()); - auto physical_id = builder.create( - op.getLoc(), device_assignment, ValueRange{logical_id}); - device_id.assign(physical_id); -} - -} // namespace - -struct LogicalToPhysicalDeviceIdPass - : public impl::LogicalToPhysicalDeviceIdPassBase< - LogicalToPhysicalDeviceIdPass> { - explicit LogicalToPhysicalDeviceIdPass(int64_t total_devices_) { - total_devices = total_devices_; - } - - void runOnOperation() override { - if (total_devices <= 0) { - signalPassFailure(); - return; - } - func::FuncOp func = getOperation(); - if (func.getName() == "main") { - auto device_assignment_type = MemRefType::get( - {total_devices}, IntegerType::get(func.getContext(), 32), - TiledLayoutAttr::get(func.getContext(), {xla::Tile({128})}, {1}), - MemorySpaceAttr::get(func.getContext(), MemorySpace::smem)); - func.insertArgument(func.getNumArguments(), device_assignment_type, - nullptr, UnknownLoc::get(func.getContext())); - auto device_assignment_arg = func.getArgument(func.getNumArguments() - 1); - func.walk([device_assignment_arg](Operation *some_op) { - if (auto op = dyn_cast(some_op)) { - logicalToPhysicalDeviceIds(op, device_assignment_arg); - } else if (auto op = dyn_cast(some_op)) { - logicalToPhysicalDeviceIds(op, device_assignment_arg); - } - }); - } else { - auto result = func.walk([](Operation *some_op) { - auto fail = [some_op]() { - some_op->emitOpError( - "Communication ops are only allowed in the main function."); - return WalkResult::interrupt(); - }; - if (auto op = dyn_cast(some_op)) { - return fail(); - } - if (auto op = dyn_cast(some_op)) { - return fail(); - } - return WalkResult::advance(); - }); - if (result.wasInterrupted()) { - signalPassFailure(); - } - } - } -}; - -std::unique_ptr> -createLogicalToPhysicalDeviceIdPass(int64_t total_devices) { - return std::make_unique(total_devices); -} - -} // namespace mlir::tpu diff --git a/jaxlib/mosaic/dialect/tpu/transforms/debug_assert_insertion.cc b/jaxlib/mosaic/dialect/tpu/transforms/debug_assert_insertion.cc deleted file mode 100644 index 846e3bbb341f..000000000000 --- a/jaxlib/mosaic/dialect/tpu/transforms/debug_assert_insertion.cc +++ /dev/null @@ -1,171 +0,0 @@ -/* Copyright 2023 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include -#include - -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/StringMap.h" -#include "llvm/Support/raw_ostream.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/IR/Types.h" -#include "mlir/IR/Value.h" -#include "mlir/IR/ValueRange.h" -#include "mlir/IR/Visitors.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LLVM.h" -#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" - -namespace mlir::tpu { - -#define GEN_PASS_DECL_DEBUGASSERTINSERTIONPASS -#define GEN_PASS_DEF_DEBUGASSERTINSERTIONPASS -#include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc" - -namespace { - -using rule_type = std::function; - -template -rule_type as_generic_rule(void (*rule)(Op)) { - return [rule](const Operation *op) { return rule(cast(op)); }; -} - -void assertIsValidSubwindow(Operation *op, mlir::ValueRange base_indices, - ArrayRef window_shape, - ArrayRef full_shape, - ArrayRef strides = {}) { - if (base_indices.size() != window_shape.size() || - base_indices.size() != full_shape.size() || - (!strides.empty() && base_indices.size() != strides.size())) { - return; // Malformed op. - } - if (base_indices.empty()) { - return; - } - Type idx_type = base_indices.front().getType(); - ImplicitLocOpBuilder builder(op->getLoc(), op); - for (auto [dim, access] : - llvm::enumerate(llvm::zip(base_indices, window_shape, full_shape))) { - auto [idx, size, bound] = access; - int64_t stride = strides.empty() ? 1 : strides[dim]; - Value positive = builder.create( - arith::CmpIPredicate::sge, idx, - builder.create(builder.getIntegerAttr(idx_type, 0))); - Value in_bounds = builder.create( - arith::CmpIPredicate::slt, - builder.create( - idx, builder.create( - builder.getIntegerAttr(idx_type, (size - 1) * stride))), - builder.create( - builder.getIntegerAttr(idx_type, bound))); - std::string msg; - llvm::raw_string_ostream msg_builder(msg); - msg_builder << "Operation " << op->getName().getStringRef().str() - << " references out-of-bounds elements in dimension " - << std::to_string(dim) << " (source location: " << op->getLoc() - << ")"; - builder.create( - builder.create(positive, in_bounds), msg); - } -} - -void vector_load_rule(vector::LoadOp op) { - assertIsValidSubwindow(op, op.getIndices(), - /*window_shape=*/op.getVectorType().getShape(), - /*full_shape=*/op.getBase().getType().getShape()); -} - -void vector_store_rule(vector::StoreOp op) { - assertIsValidSubwindow(op, op.getIndices(), - /*window_shape=*/op.getVectorType().getShape(), - /*full_shape=*/op.getBase().getType().getShape()); -} - -void tpu_memref_slice_rule(tpu::MemRefSliceOp op) { - assertIsValidSubwindow(op, op.getBaseIdx(), - /*window_shape=*/op.getResult().getType().getShape(), - /*full_shape=*/op.getMemRef().getType().getShape()); -} - -void tpu_strided_load_rule(tpu::StridedLoadOp op) { - assertIsValidSubwindow(op, op.getIndices(), - /*window_shape=*/op.getResult().getType().getShape(), - /*full_shape=*/op.getBase().getType().getShape(), - /*strides=*/op.getStrides()); -} - -void tpu_strided_store_rule(tpu::StridedStoreOp op) { - assertIsValidSubwindow( - op, op.getIndices(), - /*window_shape=*/op.getValueToStore().getType().getShape(), - /*full_shape=*/op.getBase().getType().getShape(), - /*strides=*/op.getStrides()); -} - -void tpu_vector_store_rule(tpu::VectorStoreOp op) { - // TODO(b/379925823): Take strides into account. - assertIsValidSubwindow( - op, op.getIndices(), - /*window_shape=*/op.getValueToStore().getType().getShape(), - /*full_shape=*/op.getBase().getType().getShape()); -} - -const llvm::StringMap &rules() { - static auto rules = new llvm::StringMap{ - // TODO: tpu::LoadOp, tpu::StoreOp - {vector::LoadOp::getOperationName(), as_generic_rule(vector_load_rule)}, - {vector::StoreOp::getOperationName(), as_generic_rule(vector_store_rule)}, - {tpu::MemRefSliceOp::getOperationName(), - as_generic_rule(tpu_memref_slice_rule)}, - {tpu::StridedLoadOp::getOperationName(), - as_generic_rule(tpu_strided_load_rule)}, - {tpu::StridedStoreOp::getOperationName(), - as_generic_rule(tpu_strided_store_rule)}, - {tpu::VectorStoreOp::getOperationName(), - as_generic_rule(tpu_vector_store_rule)}, - }; - return *rules; -} - -struct DebugAssertInsertionPass - : public impl::DebugAssertInsertionPassBase { - void runOnOperation() override { - func::FuncOp func = getOperation(); - func.walk([](Operation *op) { - if (auto rule_it = rules().find(op->getName().getStringRef()); - rule_it != rules().end()) { - const rule_type &rule = rule_it->getValue(); - rule(op); - } - return WalkResult::advance(); - }); - } -}; - -} // namespace - -std::unique_ptr> createDebugAssertInsertionPass() { - return std::make_unique(); -} - -} // namespace mlir::tpu diff --git a/jaxlib/mosaic/dialect/tpu/transforms/extensions/apply_vector_layout_extensions.cc b/jaxlib/mosaic/dialect/tpu/transforms/extensions/apply_vector_layout_extensions.cc deleted file mode 100644 index e7528533938f..000000000000 --- a/jaxlib/mosaic/dialect/tpu/transforms/extensions/apply_vector_layout_extensions.cc +++ /dev/null @@ -1,19 +0,0 @@ -#include "jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h" - -#include "llvm/include/llvm/ADT/StringMap.h" -#include "mlir/include/mlir/IR/Operation.h" - -namespace mlir::tpu::extensions { - -using RewriteContext = ApplyVectorLayoutContext; - -using rule_type = std::function, ArrayRef)>; - -const llvm::StringMap &rules() { - static const llvm::StringMap *rules = - new llvm::StringMap{}; - return *rules; -} - -} // namespace mlir::tpu::extensions \ No newline at end of file diff --git a/jaxlib/mosaic/dialect/tpu/transforms/extensions/infer_vector_layout_extensions.cc b/jaxlib/mosaic/dialect/tpu/transforms/extensions/infer_vector_layout_extensions.cc deleted file mode 100644 index c9c4a97e6222..000000000000 --- a/jaxlib/mosaic/dialect/tpu/transforms/extensions/infer_vector_layout_extensions.cc +++ /dev/null @@ -1,19 +0,0 @@ -#include "jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h" - -#include -#include - -#include "mlir/include/mlir/IR/Operation.h" -#include "mlir/include/mlir/Support/LLVM.h" -#include "mlir/include/mlir/Support/LogicalResult.h" - -namespace mlir::tpu::extensions { - -bool canInferVectorLayout(const Operation &op) { return false; } - -LogicalResult inferVectorLayout(const Operation &op, - std::array target_shape) { - return failure(); -} - -} // namespace mlir::tpu::extensions diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc deleted file mode 100644 index 0926f8a3c7b5..000000000000 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc +++ /dev/null @@ -1,438 +0,0 @@ -#include "jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h" - -#include -#include -#include -#include -#include - -#include "llvm/ADT/bit.h" -#include "llvm/Support/MathExtras.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Block.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/Region.h" -#include "mlir/IR/Value.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" -#include "absl/log/check.h" -#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" -#include "jaxlib/mosaic/dialect/tpu/util.h" -#include "xla/layout.h" - -namespace mlir::tpu { - -#define GEN_PASS_DECL_INFERMEMREFLAYOUTPASS -#define GEN_PASS_DEF_INFERMEMREFLAYOUTPASS -#include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc" - -// Returns the number of lanes (usually 128) groups in a tile. -// -// Arguments: -// src_sublane: A number of lanes in the full operand. -// hardware_generation: An integer indicating the target TPU generation. -// target_sublane_count: The number of sublane in the target shape. -// tpu_tiling_flags: A struct of flags indicating which large tiling modes are -// enabled by XLA for memrefs. -// bitwidth: The bitwidth of the element type of the operand. -// is_kernel_argument: Whether the operand is a kernel argument. -int getTilingFactor(const int src_sublane, const int hardware_generation, - const int64_t target_sublane_count, - const TpuTilingFlags &tpu_tiling_flags, - const int8_t bitwidth, const bool is_kernel_argument) { - CHECK(llvm::isPowerOf2_32(bitwidth)); - CHECK_LE(2, bitwidth); - CHECK_LE(bitwidth, 32); - const int packing = 32 / bitwidth; - const int min_tiling = (1 + (hardware_generation < 4)) * packing; - // When packing is larger than the sublane count, we want its tiling to be at - // least as large as the packing to make sure we can fully pack values. For - // example, for int2 on the target with 8 sublanes, we want the tiling to be - // at least 16. - const int64_t tiling_sublane = - std::max(target_sublane_count, static_cast(packing)); - const int max_normal_tiling = tiling_sublane; - - int large_tiling = [&] { - if (bitwidth == 4 && tpu_tiling_flags.use_x4_large_second_minor) { - return tiling_sublane * 8; - } - if (bitwidth == 8 && tpu_tiling_flags.use_x8_large_second_minor) { - return tiling_sublane * 4; - } - // 16-bit values are generally always possible to relayout on the fly in v6, - // so we allow large 2nd minor tiling whenever possible. We can't do this - // for kernel arguments, because the layout of those is controlled by XLA. - if (bitwidth == 16 && (tpu_tiling_flags.use_x16_large_second_minor || - (!is_kernel_argument && hardware_generation >= 6))) { - return tiling_sublane * 2; - } - return tiling_sublane; - }(); - - bool is_divisible = src_sublane % large_tiling == 0; - large_tiling = is_divisible ? large_tiling : tiling_sublane; - - // Use large tiling if our operand is tall enough to fit at least one full - // tile. - if (large_tiling <= src_sublane) { - return large_tiling; - } - - int tiling = min_tiling; - while (tiling < std::min(src_sublane, max_normal_tiling)) { - tiling *= 2; - } - return tiling; -} - -FailureOr inferLayout(MemRefType memref_ty, - const int hardware_generation, - std::array target_shape, - const TpuTilingFlags &tpu_tiling_flags, - bool is_kernel_argument, - int64_t leading_tile_rows = 0) { - if (auto tiled_layout_attr = - dyn_cast(memref_ty.getLayout())) { - if (leading_tile_rows > 0 && !tiled_layout_attr.getTiles().empty() && - tiled_layout_attr.getTiles().front().dimensions().size() == 2 && - tiled_layout_attr.getTiles().front().dimensions()[0] != - leading_tile_rows) { - return emitError(UnknownLoc::get(memref_ty.getContext()), - "Trying to infer memref layout with sublane tiling ") - << leading_tile_rows - << ", but the memref already has sublane tiling " - << tiled_layout_attr.getTiles().front().dimensions()[0]; - } - return tiled_layout_attr; - } - if (auto affine_map_attr = dyn_cast(memref_ty.getLayout())) { - if (memref_ty.getRank() == 0) { - return emitError(UnknownLoc::get(memref_ty.getContext()), - "0-rank memref not supported"); - } - if (!affine_map_attr.isIdentity()) { - return emitError(UnknownLoc::get(memref_ty.getContext()), - "Non-identity affine layout"); - } - if (!memref_ty.getElementType().isIntOrFloat()) { - return emitError(UnknownLoc::get(memref_ty.getContext()), - "Invalid element type for memref"); - } - const int8_t bitwidth = memref_ty.getElementTypeBitWidth(); - const auto [sublane_count, lane_count] = target_shape; - // Infer the layout - if (memref_ty.getRank() == 1) { - auto src_sublane = - llvm::divideCeil(memref_ty.getShape().back(), lane_count); - const int64_t leading_tile = - getTilingFactor(src_sublane, hardware_generation, - sublane_count, tpu_tiling_flags, bitwidth, - is_kernel_argument) * - lane_count; - SmallVector tiles{xla::Tile({leading_tile})}; - if (bitwidth != 32) { - if (!llvm::has_single_bit(bitwidth) || bitwidth > 32) { - return emitError(UnknownLoc::get(memref_ty.getContext()), - "Unsupported bitwidth: ") - << bitwidth; - } - tiles.append({xla::Tile({lane_count}), xla::Tile({32 / bitwidth, 1})}); - } - return TiledLayoutAttr::get(memref_ty.getContext(), tiles, {1}); - } - - // memref.getRank() > 1 - const ArrayRef shape = memref_ty.getShape(); - - const int64_t src_sublane = shape[shape.size() - 2]; - if (leading_tile_rows == 0) { - leading_tile_rows = getTilingFactor( - src_sublane, hardware_generation, sublane_count, - tpu_tiling_flags, bitwidth, is_kernel_argument); - } - SmallVector tiles{xla::Tile({leading_tile_rows, lane_count})}; - if (bitwidth != 32) { - if (!llvm::has_single_bit(bitwidth) || bitwidth > 32) { - return emitError(UnknownLoc::get(memref_ty.getContext()), - "Unsupported bitwidth: ") - << bitwidth; - } - tiles.push_back(xla::Tile({32 / bitwidth, 1})); - } - auto tile_strides = - ComputeTileStrides(memref_ty, {leading_tile_rows, lane_count}); - return TiledLayoutAttr::get(memref_ty.getContext(), tiles, tile_strides); - } - return emitError(UnknownLoc::get(memref_ty.getContext()), - "Unrecognized layout annotation"); -} - -// Make sure only the first tile might introduce padding. -LogicalResult checkTiles(MLIRContext *mlir_ctx, - const ArrayRef &tiles) { - SmallVector tiled_dims(tiles.front().dimensions().begin(), - tiles.front().dimensions().end()); - for (const xla::Tile &t : tiles.drop_front()) { - const int64_t offset = tiled_dims.size() - t.dimensions().size(); - if (offset < 0) { - return emitError(UnknownLoc::get(mlir_ctx), - "Not implemented: layout too complicated"); - } - for (int i = 0; i < t.dimensions().size(); ++i) { - auto [d, m] = std::div(tiled_dims[offset + i], t.dimension(i)); - if (m != 0) { - return emitError(UnknownLoc::get(mlir_ctx), - "Not implemented: layout too complicated"); - } - tiled_dims[offset + i] = d; - } - tiled_dims.append(t.dimensions().begin(), t.dimensions().end()); - } - return success(); -} - -FailureOr inferMemref(MemRefType memref, - const int hardware_generation, - std::array target_shape, - const TpuTilingFlags &tpu_tiling_flags, - bool is_kernel_argument, - int64_t leading_tile_rows) { - if (isa(memref.getElementType())) { - const Attribute semaphore_mem = tpu::MemorySpaceAttr::get( - memref.getContext(), MemorySpace::kSemaphoreMem); - SmallVector tile_strides; - tile_strides.reserve(memref.getRank()); - int64_t stride = 1; - for (int i = memref.getRank() - 1; i >= 0; --i) { - tile_strides.push_back(stride); - stride *= memref.getDimSize(i); - } - std::reverse(tile_strides.begin(), tile_strides.end()); - auto layout = TiledLayoutAttr::get(memref.getContext(), {}, tile_strides); - return MemRefType::get(memref.getShape(), memref.getElementType(), layout, - semaphore_mem); - } - const Attribute vmem = - tpu::MemorySpaceAttr::get(memref.getContext(), MemorySpace::vmem); - const Attribute memory_space = - memref.getMemorySpace() == nullptr ? vmem : memref.getMemorySpace(); - FAILUREOR_ASSIGN_OR_RETURN( - const TiledLayoutAttr layout, - inferLayout(memref, hardware_generation, target_shape, tpu_tiling_flags, - is_kernel_argument, leading_tile_rows)); - - const ArrayRef tiles = layout.getTiles(); - if (failed(checkTiles(memref.getContext(), tiles))) { - return failure(); - } - const xla::Tile &first_tile = tiles.front(); - const int64_t untiled_dims = - memref.getShape().size() - first_tile.dimensions().size(); - if (untiled_dims < 0) { - return emitError(UnknownLoc::get(memref.getContext()), "Invalid tiling"); - } - SmallVector new_shape(memref.getShape()); - for (int i = 0; i < first_tile.dimensions().size(); ++i) { - new_shape[untiled_dims + i] = - llvm::alignTo(new_shape[untiled_dims + i], first_tile.dimension(i)); - } - return MemRefType::get(new_shape, memref.getElementType(), layout, - memory_space); -} - -LogicalResult inferOp(Operation &op, const int hardware_generation, - std::array target_shape, - const TpuTilingFlags &tpu_tiling_flags) { - if (auto alloca_op = dyn_cast(op)) { - TypedValue arg = alloca_op.getResult(); - const MemRefType memref_ty = alloca_op.getResult().getType(); - // If the memref can be reinterpreted to untiled, force to use tiling - // {1, target.lane_count} for 32 bit. - int64_t leading_tile_rows = 0; - // TODO(b/375038685): generalize untiled memref with packed type which - // needs to update load/store rules. - if (memref_ty.getElementTypeBitWidth() == 32 && memref_ty.getRank() > 1 && - *(memref_ty.getShape().end() - 1) <= target_shape[1]) { - leading_tile_rows = 1; - } - FAILUREOR_ASSIGN_OR_RETURN( - const MemRefType new_memref_ty, - inferMemref(memref_ty, hardware_generation, target_shape, - tpu_tiling_flags, /*is_kernel_argument=*/false, - leading_tile_rows)); - alloca_op.getResult().setType(new_memref_ty); - if (memref_ty != new_memref_ty) { - OpBuilder builder(alloca_op->getContext()); - builder.setInsertionPointAfter(alloca_op); - // TODO(b/376130272): add a canonicalizer for EraseLayoutOp so that if we - // have erase(erase(x)) then we rewrite it to erase(x). - auto erase_op = builder.create( - arg.getLoc(), - MemRefType::get(new_memref_ty.getShape(), memref_ty.getElementType(), - /*layout=*/nullptr, new_memref_ty.getMemorySpace()), - arg); - arg.replaceAllUsesExcept(erase_op.getResult(), erase_op); - } - } else if (auto alloca_op = dyn_cast(op)) { - TypedValue arg = alloca_op.getResult(); - const MemRefType memref_ty = alloca_op.getResult().getType(); - FAILUREOR_ASSIGN_OR_RETURN( - const MemRefType new_memref_ty, - inferMemref(memref_ty, hardware_generation, target_shape, - tpu_tiling_flags, /*is_kernel_argument=*/false)); - alloca_op.getResult().setType(new_memref_ty); - if (memref_ty != new_memref_ty) { - OpBuilder builder(alloca_op->getContext()); - builder.setInsertionPointAfter(alloca_op); - auto erase_op = builder.create( - arg.getLoc(), - MemRefType::get(new_memref_ty.getShape(), memref_ty.getElementType(), - /*layout=*/nullptr, new_memref_ty.getMemorySpace()), - arg); - arg.replaceAllUsesExcept(erase_op.getResult(), erase_op); - } - } - for (Region ®ion : op.getRegions()) { - for (Block &block : region) { - for (Operation& op : block) { - if (failed(inferOp(op, hardware_generation, target_shape, - tpu_tiling_flags))) { - return failure(); - } - } - } - } - return success(); -} - -LogicalResult inferFunc(func::FuncOp f, const int hardware_generation, - std::array target_shape, - const TpuTilingFlags &tpu_tiling_flags) { - if (!f.getBody().hasOneBlock()) { - return f.emitOpError("Functions should only have a single block"); - } - Block &entry = f.getBody().front(); - SmallVector new_arg_types; - auto builder = OpBuilder::atBlockBegin(&entry); - for (int i = 0; i < entry.getNumArguments(); ++i) { - BlockArgument arg = entry.getArgument(i); - const auto memref_ty = dyn_cast(arg.getType()); - if (memref_ty == nullptr) { - new_arg_types.push_back(arg.getType()); - continue; - } - int64_t leading_tile_rows = 0; - auto leading_tile_rows_attr = - f.getArgAttrOfType(i, kLeadingTileRows); - if (leading_tile_rows_attr != nullptr) { - leading_tile_rows = leading_tile_rows_attr.getInt(); - f.removeArgAttr(i, kLeadingTileRows); - } - - FAILUREOR_ASSIGN_OR_RETURN( - MemRefType new_memref_ty, - inferMemref(memref_ty, hardware_generation, target_shape, - tpu_tiling_flags, /*is_kernel_argument=*/true, - leading_tile_rows)); - arg.setType(new_memref_ty); - new_arg_types.push_back(arg.getType()); - if (memref_ty != new_memref_ty) { - Value val = arg; - Operation * arg_use_op = nullptr; - // If the arg memref can be reinterpreted to untiled, we can insert - // ReinterpretCastOp to use tiling {packing, target.lane_count} before - // EraseLayoutOp for only the arg memrefs and expect the rest memref - // layout inference is based on the casted layout automatically. This - // would help lift many restrictions in alignment check when consuming - // this memref. - if (canReinterpretToUntiledMemref(cast>(val), - target_shape, - /*allow_minormost_padding=*/true) && - // TODO(b/375038685): generalize untiled memref with packed type which - // needs to update load/store rules. - new_memref_ty.getElementTypeBitWidth() == 32) { - auto tiled_layout = - cast(new_memref_ty.getLayout()); - SmallVector tiles(tiled_layout.getTiles()); - SmallVector new_tile_strides(tiled_layout.getTileStrides()); - for (int i = 0; i < new_tile_strides.size() - 2; ++i) { - new_tile_strides[i] *= tiles[0].dimension(0); - } - tiles[0] = ::xla::Tile({1, target_shape[1]}); - new_memref_ty = MemRefType::get( - new_memref_ty.getShape(), new_memref_ty.getElementType(), - TiledLayoutAttr::get(new_memref_ty.getContext(), tiles, - new_tile_strides), - new_memref_ty.getMemorySpace()); - arg_use_op = builder.create(val.getLoc(), - new_memref_ty, val); - val = arg_use_op->getResult(0); - } - // Some standard MLIR ops have static checks that seems unreasonable, - // and we know they hold in the way they are used in Mosaic. Still, - // verification with layouts likes to fail, because it can't statically - // prove the properties. - auto erase_op = builder.create( - val.getLoc(), - MemRefType::get(new_memref_ty.getShape(), memref_ty.getElementType(), - /*layout=*/nullptr, new_memref_ty.getMemorySpace()), - val); - if (!arg_use_op) { - arg_use_op = erase_op; - } - arg.replaceAllUsesExcept(erase_op.getResult(), arg_use_op); - } - } - f.setFunctionType( - builder.getAttr(new_arg_types, f.getResultTypes())); - for (Operation &op : entry.getOperations()) { - if (failed( - inferOp(op, hardware_generation, target_shape, tpu_tiling_flags))) { - return failure(); - } - } - return success(); -} - -struct InferMemRefLayoutPass - : public impl::InferMemRefLayoutPassBase { - InferMemRefLayoutPass(int hardware_generation_, - std::array target_shape_, - const TpuTilingFlags &tpu_tiling_flags_) { - hardware_generation = hardware_generation_; - sublane_count = target_shape_[0]; - lane_count = target_shape_[1]; - tpu_tiling_flags = tpu_tiling_flags_; - } - void runOnOperation() override { - // Fail if hardware_generation has not been set from the default value. - if (hardware_generation < 0) { - signalPassFailure(); - return; - } - func::FuncOp func = getOperation(); - if (failed(inferFunc(func, hardware_generation, {sublane_count, lane_count}, - tpu_tiling_flags))) { - signalPassFailure(); - return; - } - } -}; - -std::unique_ptr> createInferMemRefLayoutPass( - int hardware_generation, std::array target_shape, - const TpuTilingFlags &tpu_tiling_flags_) { - return std::make_unique( - hardware_generation, target_shape, tpu_tiling_flags_); -} - -} // namespace mlir::tpu diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h deleted file mode 100644 index f2ab7c624eb1..000000000000 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h +++ /dev/null @@ -1,24 +0,0 @@ -#ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_MEMREF_LAYOUT_H_ -#define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_MEMREF_LAYOUT_H_ - -#include -#include -#include - -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/Support/LogicalResult.h" -#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" - -namespace mlir::tpu { - -FailureOr inferMemref(MemRefType memref, int hardware_generation, - std::array target_shape, - const TpuTilingFlags& tpu_tiling_flags, - bool is_kernel_argument, - int64_t leading_tile_rows = 0); - -const std::string_view kLeadingTileRows = "leading_tile_rows"; - -} // namespace mlir::tpu - -#endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_MEMREF_LAYOUT_H_ diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc deleted file mode 100644 index 0081feba985b..000000000000 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ /dev/null @@ -1,2144 +0,0 @@ -/* Copyright 2023 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVectorExtras.h" -#include "llvm/Support/raw_ostream.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Value.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/types/span.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/include/mlir/IR/OpDefinition.h" -#include "mlir/include/mlir/IR/Visitors.h" -#include "mlir/include/mlir/Pass/Pass.h" -#include "jaxlib/mosaic/dialect/tpu/layout.h" -#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" -#include "jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h" -#include "jaxlib/mosaic/dialect/tpu/util.h" -#include "xla/layout.h" - -namespace mlir::tpu { - -#define GEN_PASS_DECL_INFERVECTORLAYOUTPASS -#define GEN_PASS_DEF_INFERVECTORLAYOUTPASS -#include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc" - -namespace { - -using ImplicitDim = VectorLayout::ImplicitDim; - -static constexpr int kLayoutLog = 10; - - -bool is_fully_replicated(const Layout &layout) { - static LayoutOffsets replicated_offsets = {std::nullopt, std::nullopt}; - return layout.has_value() && layout->offsets() == replicated_offsets; -} - -TiledLayoutAttr getMemRefLayout(Value ref) { - if (auto erase_op = ref.getDefiningOp()) { - ref = erase_op.getOperand(); - } - return cast(cast(ref.getType()).getLayout()); -} - -LogicalResult verifyDivisibleIndex(Value tiled_index, int64_t tiling, int dim, - Operation *op) { - if (!isGuaranteedDivisible(tiled_index, tiling)) { - return op->emitOpError("cannot statically prove that index in dimension ") - << dim << " is a multiple of " << tiling; - } - return success(); -} - -// TODO(apaszke): Test that this pass fills in NoLayout for all operations that -// have corresponding native instructions. -class VectorLayoutInferer { - public: - explicit VectorLayoutInferer(int hardware_generation, - std::array target_shape, - const TpuTilingFlags &tpu_tiling_flags) - : hardware_generation_(hardware_generation), - target_shape_({target_shape[0], target_shape[1]}), - default_tiling_(target_shape), - tpu_tiling_flags_(tpu_tiling_flags) {} - -#define TPU_CHECK_OP(cond, msg) \ - if (!(cond)) { \ - op->emitOpError(msg); \ - return failure(); \ - } - -#define NYI(msg) \ - op->emitOpError("not implemented: " msg); \ - return failure(); - - LogicalResult inferBlock( - Block &block, - const std::function &match_terminator) { - for (Operation &any_op : block.without_terminator()) { - VLOG(kLayoutLog) << Print(&any_op); - if (any_op.hasAttr("in_layout") || any_op.hasAttr("out_layout")) { - if (auto op = dyn_cast(any_op)) { - TPU_CHECK_OP( - any_op.hasAttr("in_layout") && any_op.hasAttr("out_layout"), - "expect layout attributes in tpu::AssumeLayoutOp"); - continue; - } else { - any_op.emitOpError("layout attributes already attached"); - return failure(); - } - } - - // TODO: b/342235360 - This check is temporary while we increase and test - // support for offsets outside of the first tile. When support is more - // broad, any op without support should check it within their own rule. - if (!isa(any_op)) { - const SmallVector layouts_in = getLayoutFromOperands(&any_op); - for (const Layout &layout : layouts_in) { - if (layout && - layout->offsets()[1].value_or(0) >= layout->tiling()[1]) { - force_first_tile_offsets_ = true; - } - } - } - - bool has_vector_io = false; - for (auto op : any_op.getOperands()) { - has_vector_io |= op.getType().isa(); - } - for (auto r : any_op.getResults()) { - has_vector_io |= r.getType().isa(); - } - if (!has_vector_io && any_op.getRegions().empty()) { - SmallVector in_layout(any_op.getNumOperands(), kNoLayout); - if (any_op.getNumResults() == 0) { - setInLayout(&any_op, in_layout); - } else if (any_op.getNumResults() == 1) { - setLayout(&any_op, in_layout, kNoLayout); - } else { - any_op.emitOpError("Multi-result ops not supported"); - return failure(); - } - } else if (isa(any_op)) { - if (inferExt(&any_op).failed()) { - return failure(); - } - } else if (isa(any_op)) { - if (inferTrunc(&any_op).failed()) { - return failure(); - } - } else if (auto op = dyn_cast(any_op); - op && - cast(op.getOperand().getType()) - .getElementTypeBitWidth() > - cast(op.getType()).getElementTypeBitWidth()) { - if (inferTrunc(&any_op).failed()) { - return failure(); - } - } else if (auto op = dyn_cast(any_op)) { - auto true_ty = dyn_cast(op.getTrueValue().getType()); - auto false_ty = dyn_cast(op.getFalseValue().getType()); - TPU_CHECK_OP(static_cast(true_ty) == static_cast(false_ty), - "Only one side of arith is a vector?"); - if (inferElementwise(&any_op).failed()) { - return failure(); - } - } else if (auto op = dyn_cast(any_op)) { - auto in_ty = dyn_cast(op.getIn().getType()); - auto out_ty = dyn_cast(op.getType()); - TPU_CHECK_OP(static_cast(in_ty) == static_cast(out_ty), - "Input and output are not both vectors?"); - auto in_bitwidth = in_ty ? in_ty.getElementTypeBitWidth() - : op.getIn().getType().getIntOrFloatBitWidth(); - if (in_bitwidth == 1) { - if (inferElementwise(&any_op).failed()) { - return failure(); - } - } else { - if (inferExt(&any_op).failed()) { - return failure(); - } - } - } else if (isa(any_op) || isa(any_op)) { - Operation *op = &any_op; // For TPU_CHECK_OP macros, which use the `op` - // variable in scope - auto lhs_ty = dyn_cast(any_op.getOperand(0).getType()); - auto rhs_ty = dyn_cast(any_op.getOperand(1).getType()); - TPU_CHECK_OP(static_cast(lhs_ty) == static_cast(rhs_ty), - "Only one side of cmp is a vector?"); - // TODO(tlongeri): Check that TPU generation supports comparison. - if (inferElementwise(&any_op).failed()) { - return failure(); - } - } else if (auto op = dyn_cast(any_op)) { - if (infer(op).failed()) { - return failure(); - } - } else if (auto op = dyn_cast(any_op)) { - if (infer(op).failed()) { - return failure(); - } - } else if (auto op = dyn_cast(any_op)) { - if (infer(op).failed()) { - return failure(); - } - } else if (auto op = dyn_cast(any_op)) { - if (infer(op).failed()) { - return failure(); - } - } else if (auto op = dyn_cast(any_op)) { - if (infer(op).failed()) { - return failure(); - } - } else if (auto op = dyn_cast(any_op)) { - if (infer(op).failed()) { - return failure(); - } - } else if (auto op = dyn_cast(any_op)) { - if (infer(op).failed()) { - return failure(); - } - } else if (auto op = dyn_cast(any_op)) { - if (infer(op).failed()) { - return failure(); - } - } else if (auto op = dyn_cast(any_op)) { - if (infer(op).failed()) { - return failure(); - } - } else if (auto op = dyn_cast(any_op)) { - if (infer(op).failed()) { - return failure(); - } - } else if (auto op = dyn_cast(any_op)) { - if (infer(op).failed()) { - return failure(); - } - } else if (auto op = dyn_cast(any_op)) { - if (infer(op).failed()) { - return failure(); - } - } else if (auto op = dyn_cast(any_op)) { - if (infer(op).failed()) { - return failure(); - } - } else if (auto op = dyn_cast(any_op)) { - if (infer(op).failed()) { - return failure(); - } - } else if (auto op = dyn_cast(any_op)) { - if (infer(op).failed()) { - return failure(); - } - } else if (auto op = dyn_cast(any_op)) { - if (infer(op).failed()) { - return failure(); - } - } else if (auto op = dyn_cast(any_op)) { - if (infer(op).failed()) { - return failure(); - } - } else if (auto op = dyn_cast(any_op)) { - if (infer(op).failed()) { - return failure(); - } - } else if (auto op = dyn_cast(any_op)) { - if (infer(op).failed()) { - return failure(); - } - } else if (auto op = dyn_cast(any_op)) { - if (infer(op).failed()) { - return failure(); - } - } else if (auto op = dyn_cast(any_op)) { - if (infer(op).failed()) { - return failure(); - } - } else if (auto op = dyn_cast(any_op)) { - if (infer(op).failed()) { - return failure(); - } - } else if (auto op = dyn_cast(any_op)) { - if (infer(op).failed()) { - return failure(); - } - } else if (auto op = dyn_cast(any_op)) { - if (infer(op).failed()) { - return failure(); - } - } else if (auto op = dyn_cast(any_op)) { - if (infer(op).failed()) { - return failure(); - } - } else if (auto op = dyn_cast(any_op)) { - if (infer(op).failed()) { - return failure(); - } - } else if (auto op = dyn_cast(any_op)) { - if (infer(op).failed()) { - return failure(); - } - } else if (auto op = dyn_cast(any_op)) { - if (inferStore(op, - /*has_mask=*/op.getMask() != nullptr) - .failed()) { - return failure(); - } - } else if (auto op = dyn_cast(any_op)) { - if (inferStore(op).failed()) { - return failure(); - } - } else if (auto op = dyn_cast(any_op)) { - if (infer(op).failed()) { - return failure(); - } - } else if (auto op = dyn_cast(any_op)) { - if (infer(op).failed()) { - return failure(); - } - } else if (OpTrait::hasElementwiseMappableTraits(&any_op)) { - // We put elementwise rule to the end in case the overriding rule. - if (inferElementwise(&any_op).failed()) { - return failure(); - } - } else if (mlir::tpu::extensions::canInferVectorLayout(any_op)) { - if (mlir::tpu::extensions::inferVectorLayout(any_op, target_shape_) - .failed()) { - return failure(); - } - } else { - return any_op.emitError("Not implemented: Unsupported operation: ") - << any_op.getName() << " in infer-vector-layout pass"; - } - CHECK(any_op.getNumResults() == 0 || any_op.hasAttr("out_layout")); - CHECK(any_op.getNumOperands() == 0 || any_op.hasAttr("in_layout")); - force_first_tile_offsets_ = false; - } - return match_terminator(block.getTerminator()); - } - - LogicalResult infer(arith::ConstantOp op) { - if (op.getType().isSignlessIntOrIndexOrFloat()) { - setOutLayout(op, kNoLayout); - return success(); - } - if (auto ty = dyn_cast(op.getType())) { - auto elems = dyn_cast(op.getValue()); - TPU_CHECK_OP(ty.getElementType().isSignlessIntOrIndexOrFloat(), - "expected scalar element type in vector"); - TPU_CHECK_OP(ty.getRank() > 0, "rank 0 vectors unsupported"); - TPU_CHECK_OP(elems, "expected vector constants to use DenseElementsAttr"); - auto bitwidth = ty.getElementTypeBitWidth(); - if (bitwidth == 1) { - // i1 is a special case where the layout bitwidth can be different from - // the element bitwidth, see comment in VectorLayout class - bitwidth = kNativeBitwidth; - } - if (elems.isSplat()) { - if (ty.getRank() == 1) { - // Here, we choose to lay out along lanes arbitrarily. It would be - // equally valid to go with sublanes. Still, this value is so easy - // to relayout that it shouldn't really make a difference. - setOutLayout(op, VectorLayout(bitwidth, {std::nullopt, std::nullopt}, - nativeTiling(bitwidth), - ImplicitDim::kSecondMinor)); - } else { // ty.getRank() >= 2 - setOutLayout( - op, VectorLayout(bitwidth, {std::nullopt, std::nullopt}, - nativeTiling(bitwidth), ImplicitDim::kNone)); - } - } else { - TPU_CHECK_OP(ty.getElementTypeBitWidth() == kNativeBitwidth, - "Only 32-bit non-splat constants supported"); - if (ty.getRank() == 1) { - if (ty.getDimSize(0) <= target_shape_[0]) { - // Use 2D layout with replication. - NYI("small 1D constants"); - } else { // NOLINT(readability-else-after-return) - NYI("large 1D constants"); - } - } else { // ty.getRank() >= 2 - setOutLayout(op, VectorLayout(kNativeBitwidth, {0, 0}, - default_tiling_, ImplicitDim::kNone)); - } - } - return success(); - } - op.emitOpError("unsupported constant type"); - return failure(); - } - - LogicalResult infer(cf::AssertOp op) { - setInLayout(op, {kNoLayout}); - return success(); - } - - LogicalResult infer(func::FuncOp op) { - if (!op.getBody().hasOneBlock()) { - op.emitOpError("Only one block functions supported"); - return failure(); - } - return inferBlock( - op.getBody().front(), [this](Operation *op) -> LogicalResult { - TPU_CHECK_OP(isa(op), - "Expected func.return terminator"); - for (Value o : op->getOperands()) { - TPU_CHECK_OP(!isa(o.getType()), - "vector returns unsupported"); - } - SmallVector in_layout(op->getNumOperands(), {kNoLayout}); - setInLayout(op, in_layout); - return success(); - }); - } - - LogicalResult infer(memref::LoadOp op) { - TPU_CHECK_OP(op.getType().isSignlessIntOrIndexOrFloat(), - "memref.load with non-scalar result"); - SmallVector in_layout(op.getNumOperands(), {kNoLayout}); - setLayout(op, in_layout, kNoLayout); - return success(); - } - - LogicalResult infer(scf::IfOp op) { - static LogicalResult (*match_yield)(Operation *) = [](Operation *op) { - TPU_CHECK_OP(isa(op), "expected yield terminator"); - return success(); - }; - TPU_CHECK_OP(op->getNumOperands() == 1, "expected one operand"); - setInLayout(op, {kNoLayout}); - if (inferBlock(*op.thenBlock(), match_yield).failed()) { - op.emitOpError("failed to infer layout for then branch"); - return failure(); - } - auto then_yield = op.thenBlock()->getTerminator(); - TPU_CHECK_OP(then_yield->getOperandTypes() == op->getResultTypes(), - "scf if results and then branch yield operands do not match"); - auto then_yield_in_layouts = getLayoutFromOperands(then_yield); - if (auto else_block = op.elseBlock()) { - if (inferBlock(*else_block, match_yield).failed()) { - op.emitOpError("failed to infer layout for else branch"); - return failure(); - } - } - if (op->getNumResults() == 0) { - return success(); - } - // If the if op has results, it should have both then and else regions with - // yield op. - auto else_yield = op.elseBlock()->getTerminator(); - TPU_CHECK_OP(else_yield->getOperandTypes() == op->getResultTypes(), - "scf if results and else branch yield operands do not match"); - auto else_yield_in_layouts = getLayoutFromOperands(else_yield); - // Find a compatible layout from then and else branches for each reuslt. For - // example, if we yield offset (*, *) in then branch and offset (*, 0) in - // else branch, the result offset should be (*, 0). - SmallVector out_layouts; - out_layouts.reserve(op->getNumResults()); - int out_idx = 0; - for (auto [then_layout, else_layout, result] : llvm::zip_equal( - then_yield_in_layouts, else_yield_in_layouts, op.getResults())) { - if (auto vty = dyn_cast(result.getType())) { - if (!then_layout.has_value()) { - return op.emitOpError( - "expected a vector layout for then yield input ") - << out_idx; - } - if (!else_layout.has_value()) { - return op.emitOpError( - "expected a vector layout for else yield input ") - << out_idx; - } - auto compatible_layout = VectorLayout::join( - then_layout.value(), else_layout.value(), vty.getShape()); - // If no compatible layout is found in layouts for then and else - // branches, the output layout falls back to a normalized layout which - // has offsets 0 and the native tiling. - if (!compatible_layout.has_value()) { - compatible_layout = VectorLayout( - then_layout->bitwidth(), {0, 0}, - nativeTiling(then_layout->bitwidth()), ImplicitDim::kNone); - } - out_layouts.push_back(compatible_layout); - } else { - if (then_layout.has_value()) { - return op.emitOpError("expected no layout for then yield input ") - << out_idx; - } - if (else_layout.has_value()) { - return op.emitOpError("expected no layout for else yield input ") - << out_idx; - } - out_layouts.push_back(kNoLayout); - } - ++out_idx; - } - setInLayout(then_yield, out_layouts); - setInLayout(else_yield, out_layouts); - setOutLayout(op, out_layouts); - return success(); - } - - LogicalResult infer(scf::ForOp op) { - static LogicalResult (*match_yield)(Operation *) = [](Operation *op) { - TPU_CHECK_OP(isa(op), "expected yield terminator"); - return success(); - }; - TPU_CHECK_OP(op.getRegion().hasOneBlock(), - "expected one block for scf.for"); - TPU_CHECK_OP( - op.getNumRegionIterArgs() == op.getNumResults(), - "expected num_region_iter_args is equal to num_results in scf.for"); - TPU_CHECK_OP( - op->getNumOperands() == 3 + op.getNumResults(), - "expected num_operands is equal to 3 + num_results in scf.for"); - - auto in_layouts = getLayoutFromOperands(op); - // Drop the input layouts for lower bound, upper bound. But keep the layout - // for step because it matches with induction variable in arguments. - auto arg_layouts = ArrayRef(in_layouts).drop_front(2); - if (assumeLayoutsForBlockArgs(*op.getBody(), arg_layouts).failed() || - inferBlock(*op.getBody(), match_yield).failed()) { - return op.emitOpError( - "failed to infer layout with initial layouts for body in " - "scf.for op"); - } - auto yield_op = op.getBody()->getTerminator(); - auto yield_in_layouts = getLayoutFromOperands(yield_op); - - SmallVector out_layouts; - out_layouts.reserve(op->getNumResults()); - int out_idx = 0; - bool require_reinfer = false; - for (auto [in_layout, yield_layout, result] : - llvm::zip_equal(arg_layouts.drop_front( - 1), // Drop the layout for induction variable. - yield_in_layouts, op.getResults())) { - if (auto vty = dyn_cast(result.getType())) { - if (!in_layout.has_value()) { - return op.emitOpError("expected a vector layout for input ") - << out_idx; - } - if (!yield_layout.has_value()) { - return op.emitOpError("expected a vector layout for yield input ") - << out_idx; - } - auto compatible_layout = VectorLayout::join( - in_layout.value(), yield_layout.value(), vty.getShape()); - // If no compatible layout is found in layouts for input and - // yield, the output layout falls back to a normalized layout which - // has offsets 0 and the native tiling. - if (!compatible_layout.has_value()) { - compatible_layout = VectorLayout(in_layout->bitwidth(), {0, 0}, - nativeTiling(in_layout->bitwidth()), - ImplicitDim::kNone); - } - if (!require_reinfer && - (compatible_layout.value() != in_layout.value() || - compatible_layout.value() != yield_layout.value())) { - require_reinfer = true; - } - out_layouts.push_back(compatible_layout); - } else { - if (in_layout.has_value()) { - return op.emitOpError("expected no layout for input ") << out_idx; - } - if (yield_layout.has_value()) { - return op.emitOpError("expected no layout for yield input ") - << out_idx; - } - out_layouts.push_back(kNoLayout); - } - ++out_idx; - } - if (require_reinfer) { - // Force same layouts in input layout but skip the first 3 layouts for - // lower bound, upper bound and step. - std::copy(out_layouts.begin(), out_layouts.end(), in_layouts.begin() + 3); - - // Terminator in the loop will carry layouts to the next loop but - // the loop's block args' layouts are determined by the initial inputs. We - // need to force the same layouts for all in order to make layouts be - // consistent across all branches. To ensure that, we need to reprocess - // layout inference for the entire body with the final consolidated - // layout. - clearBlockLayouts(*op.getBody()); - if (assumeLayoutsForBlockArgs(*op.getBody(), - ArrayRef(in_layouts).drop_front(2)) - .failed() || - inferBlock(*op.getBody(), match_yield).failed()) { - return op.emitOpError( - "failed to infer layout with compatible layouts for body in " - "scf.for op"); - } - } - setInLayout(yield_op, out_layouts); - setLayout(op, in_layouts, out_layouts); - return success(); - } - - LogicalResult infer(scf::WhileOp op) { - static LogicalResult (*match_condition)(Operation *) = [](Operation *op) { - TPU_CHECK_OP(isa(op), "expected condition terminator"); - return success(); - }; - static LogicalResult (*match_yield)(Operation *) = [](Operation *op) { - TPU_CHECK_OP(isa(op), "expected yield terminator"); - return success(); - }; - TPU_CHECK_OP(op.getNumRegions() == 2, "expected two blocks for scf.while"); - - SmallVector in_layouts = getLayoutFromOperands(op); - - if (assumeLayoutsForBlockArgs(*op.getBeforeBody(), in_layouts).failed() || - inferBlock(*op.getBeforeBody(), match_condition).failed()) { - return op.emitOpError( - "failed to infer layout with initial layouts for before body in " - "scf.while op"); - } - - if (assumeLayoutsForBlockArgs(*op.getAfterBody(), in_layouts).failed() || - inferBlock(*op.getAfterBody(), match_yield).failed()) { - return op.emitOpError( - "failed to infer layout with initial layouts for after body in " - "scf.while op"); - } - - auto *cond_op = op.getBeforeBody()->getTerminator(); - auto cond_in_layouts = getLayoutFromOperands(cond_op); - auto *yield_op = op.getAfterBody()->getTerminator(); - auto yield_in_layouts = getLayoutFromOperands(yield_op); - - // Find a compatible layout from condition body and loop body for each - // reuslt. For example, if we yield offset (*, *) in condition body and - // offset (*, 0) in loop body, the result offset should be (*, 0). - SmallVector out_layouts; - out_layouts.reserve(op->getNumResults()); - int out_idx = 0; - bool require_reinfer = false; - for (auto [in_layout, cond_layout, yield_layout, result] : llvm::zip_equal( - in_layouts, ArrayRef(cond_in_layouts).drop_front(1), - yield_in_layouts, op.getResults())) { - if (auto vty = dyn_cast(result.getType())) { - if (!in_layout.has_value()) { - return op.emitOpError("expected a vector layout for whileOp input ") - << out_idx; - } - if (!cond_layout.has_value()) { - return op.emitOpError("expected a vector layout for condition input ") - << out_idx + 1; // ConditionOp's first input is 1 bit bool. - } - if (!yield_layout.has_value()) { - return op.emitOpError("expected a vector layout for yield input ") - << out_idx; - } - auto compatible_layout = VectorLayout::join( - cond_layout.value(), yield_layout.value(), vty.getShape()); - if (compatible_layout.has_value()) { - compatible_layout = VectorLayout::join( - in_layout.value(), compatible_layout.value(), vty.getShape()); - } - // If no compatible layout is found in layouts for input, condition and - // yield, the output layout falls back to a normalized layout which - // has offsets 0 and the native tiling. - if (!compatible_layout.has_value()) { - compatible_layout = VectorLayout(in_layout->bitwidth(), {0, 0}, - nativeTiling(in_layout->bitwidth()), - ImplicitDim::kNone); - } - if (!require_reinfer && - (compatible_layout.value() != in_layout.value() || - compatible_layout.value() != cond_layout.value() || - compatible_layout.value() != yield_layout.value())) { - require_reinfer = true; - } - out_layouts.push_back(compatible_layout); - } else { - if (in_layout.has_value()) { - return op.emitOpError("expected no layout for whileOp input ") - << out_idx; - } - if (cond_layout.has_value()) { - return op.emitOpError("expected no layout for condition input ") - << out_idx + 1; // ConditionOp's first input is 1 bit bool. - } - if (yield_layout.has_value()) { - return op.emitOpError("expected no layout for yield input ") - << out_idx; - } - out_layouts.push_back(kNoLayout); - } - ++out_idx; - } - if (require_reinfer) { - clearBlockLayouts(*op.getBeforeBody()); - clearBlockLayouts(*op.getAfterBody()); - // Terminator in the loop will carry layouts to the next loop but - // the loop's block args' layouts are determined by the initial inputs. We - // need to force the same layouts for all in order to make layouts be - // consistent across all branches. To ensure that, we need to reprocess - // layout inference for the entire body with the final consolidated - // layout. - if (assumeLayoutsForBlockArgs(*op.getBeforeBody(), out_layouts) - .failed() || - inferBlock(*op.getBeforeBody(), match_condition).failed()) { - return op.emitOpError( - "failed to infer layout with compatible layouts for before body in " - "scf.while op"); - } - if (assumeLayoutsForBlockArgs(*op.getAfterBody(), out_layouts).failed() || - inferBlock(*op.getAfterBody(), match_yield).failed()) { - return op.emitOpError( - "failed to infer layout with compatible layouts for after body in " - "scf.while op"); - } - } - std::copy(out_layouts.begin(), out_layouts.end(), - cond_in_layouts.begin() + 1); // Skip the first 1 bit bool. - setInLayout(cond_op, cond_in_layouts); - setInLayout(yield_op, out_layouts); - setLayout(op, out_layouts, out_layouts); - return success(); - } - - // TODO(b/347016737): deprecate the static rotate. - LogicalResult infer(tpu::RotateOp op) { - auto bitwidth = op.getType().getElementTypeBitWidth(); - if (bitwidth != 32) { - NYI("Rotate with non-32-bit data"); - } - if (op.getType().getRank() < 2) { - NYI("Unsupported 1D shape"); - } - auto layout = VectorLayout(bitwidth, {0, 0}, nativeTiling(bitwidth), - ImplicitDim::kNone); - setLayout(op, layout, layout); - return success(); - } - - LogicalResult infer(tpu::DynamicRotateOp op) { - auto bitwidth = op.getType().getElementTypeBitWidth(); - // TODO(b/347067057): Support dynamic rotate with packed dtype. - if (bitwidth != 32) { - NYI("Rotate with non-32-bit data"); - } - if (op.getType().getRank() < 2) { - NYI("Unsupported 1D shape"); - } - auto layout = VectorLayout(bitwidth, {0, 0}, nativeTiling(bitwidth), - ImplicitDim::kNone); - setLayout(op, {layout, kNoLayout}, layout); - return success(); - } - - LogicalResult infer(tpu::ConcatenateOp op) { - TPU_CHECK_OP(!op.getSources().empty(), - "Need at least one vector to concatenate"); - int64_t res_rank = op.getType().getRank(); - uint32_t dimension = op.getDimension(); - TPU_CHECK_OP(0 <= dimension && dimension < res_rank, - "Expect a valid concatenate dimension"); - VectorType res_ty = op.getResult().getType(); - - std::optional tiling_dim; - if (dimension == res_ty.getRank() - 1) { - tiling_dim = 1; - } else if (dimension == res_ty.getRank() - 2) { - tiling_dim = 0; - } - - if (tiling_dim.has_value()) { - int64_t starting_point = 0; - - Layout first_layout = getLayout(op.getSources().front()); - SmallVector op_layouts = getLayoutFromOperands(op); - SmallVector in_layouts; - in_layouts.reserve(op.getSources().size()); - int8_t bitwidth = first_layout->bitwidth(); - - // Set implicit dim to treat 1D as (1, N) and tile it as (1, 128) - std::array tiling = - res_rank == 1 ? std::array{1L, target_shape_[1]} - : nativeTiling(bitwidth); - ImplicitDim implicit_dim = - res_rank == 1 ? ImplicitDim::kSecondMinor : ImplicitDim::kNone; - std::array vreg_slice = - VectorLayout::vregSlice(target_shape_, bitwidth, tiling); - for (int i = 0; i < op.getSources().size(); ++i) { - // Compute the offset per source. - // Ex: for a cat of (10, 128), (10, 128) on dim 0, where the - // vreg_slice for that dim is 8, the first source starts at - // offset 0, and overflows the vreg - // by 2, so the offset for the second input is 2. - ArrayRef op_shape = - cast(op.getSources()[i].getType()).getShape(); - Layout op_layout = op_layouts[i]; - int64_t offset_amount = starting_point % vreg_slice[tiling_dim.value()]; - if (offset_amount >= tiling[tiling_dim.value()]) { - return op.emitError( - "Not implemented: Input offsets outside of the first tile"); - } - SmallVector in_idx{op_layout->offsets()[0].value_or(0), - op_layout->offsets()[1].value_or(0)}; - in_idx[tiling_dim.value()] = offset_amount; - starting_point += op_shape[dimension]; - in_layouts.push_back(VectorLayout(bitwidth, {in_idx[0], in_idx[1]}, - tiling, implicit_dim)); - } - SmallVector res_layout_offsets( - {first_layout->offsets()[0].value_or(0), - first_layout->offsets()[1].value_or(0)}); - res_layout_offsets[tiling_dim.value()] = 0; - // TODO(mvoz): A tiny optimization we could do here later is to - // no-op setting tiling when sublane dim size is aligned to sublane - // tiling. - VectorLayout res_layout = - VectorLayout(bitwidth, {res_layout_offsets[0], res_layout_offsets[1]}, - tiling, implicit_dim); - setLayout(op, in_layouts, res_layout); - return success(); - } else { - Layout layout = getLayout(op.getSources().front()); - // When concatenating vectors with replicated offsets, we want to reset - // the replicated offset to zero. Because we are not sure if the - // replicated value from each vector are same. - layout = VectorLayout( - layout->bitwidth(), - {layout->offsets()[0].value_or(0), layout->offsets()[1].value_or(0)}, - layout->tiling(), layout->implicit_dim()); - SmallVector in_layouts(op->getNumOperands(), layout); - setLayout(op, in_layouts, layout); - return success(); - } - } - - LogicalResult infer(tpu::LoadOp op) { - auto res_ty = op.getResult().getType(); - int8_t bitwidth = res_ty.getElementTypeBitWidth(); - - // We expect the result is already a native-sized vreg. - TPU_CHECK_OP(bitwidth == 32 && res_ty.getShape()[0] == target_shape_[0] && - res_ty.getShape()[1] == target_shape_[1], - "Only 32-bit loads supported"); - SmallVector in_layout(op->getNumOperands(), kNoLayout); - auto out_layout = VectorLayout(bitwidth, {0, 0}, nativeTiling(bitwidth), - ImplicitDim::kNone); - setLayout(op, in_layout, out_layout); - return success(); - } - - LogicalResult infer(tpu::StridedLoadOp op) { - auto vty = op.getResult().getType(); - int8_t bitwidth = vty.getElementTypeBitWidth(); - if (bitwidth != 32) { - NYI("Strided load with non 32-bit data"); - } - if (vty.getRank() < 2) { - NYI("Strided load with 1D vector"); - } - SmallVector in_layout(op->getNumOperands(), kNoLayout); - setLayout(op, in_layout, - VectorLayout(bitwidth, {0, 0}, nativeTiling(bitwidth), - ImplicitDim::kNone)); - return success(); - } - - LogicalResult infer(tpu::StridedStoreOp op) { - auto vty = op.getValueToStore().getType(); - int8_t bitwidth = vty.getElementTypeBitWidth(); - if (bitwidth != 32) { - NYI("Strided store with non 32-bit data"); - } - if (vty.getRank() < 2) { - NYI("Strided store with 1D vector"); - } - auto store_layout = VectorLayout(bitwidth, {0, 0}, nativeTiling(bitwidth), - ImplicitDim::kNone); - SmallVector in_layout{op->getNumOperands(), kNoLayout}; - in_layout[0] = store_layout; - setInLayout(op, in_layout); - return success(); - } - - LogicalResult infer(tpu::MatmulOp op) { - auto lhs_bitwidth = op.getLhs().getType().getElementTypeBitWidth(); - auto rhs_bitwidth = op.getRhs().getType().getElementTypeBitWidth(); - auto acc_bitwidth = op.getAcc().getType().getElementTypeBitWidth(); - auto res_bitwidth = op.getResult().getType().getElementTypeBitWidth(); - TPU_CHECK_OP(acc_bitwidth == kNativeBitwidth, - "Expected 32-bit acc in tpu::MatmulOp"); - TPU_CHECK_OP(res_bitwidth == kNativeBitwidth, - "Expected 32-bit result in tpu::MatmulOp"); - auto lhs_layout = VectorLayout( - lhs_bitwidth, {0, 0}, nativeTiling(lhs_bitwidth), ImplicitDim::kNone); - auto rhs_layout = VectorLayout( - rhs_bitwidth, {0, 0}, nativeTiling(rhs_bitwidth), ImplicitDim::kNone); - auto acc_layout = VectorLayout( - acc_bitwidth, {0, 0}, nativeTiling(acc_bitwidth), ImplicitDim::kNone); - setLayout(op, {lhs_layout, rhs_layout, acc_layout}, - VectorLayout(kNativeBitwidth, {0, 0}, default_tiling_, - ImplicitDim::kNone)); - return success(); - } - - LogicalResult infer(tpu::StoreOp op) { - auto store_ty = op.getValueToStore().getType(); - int8_t bitwidth = store_ty.getElementTypeBitWidth(); - - // We expect the value to store is already a native-sized vreg. - TPU_CHECK_OP(bitwidth == 32 && store_ty.getShape()[0] == target_shape_[0] && - store_ty.getShape()[1] == target_shape_[1], - "Only 32-bit stores supported"); - auto store_layout = VectorLayout(bitwidth, {0, 0}, nativeTiling(bitwidth), - ImplicitDim::kNone); - SmallVector in_layout{store_layout}; - in_layout.insert(in_layout.end(), op.getIndices().size() + 1, kNoLayout); - setInLayout(op, in_layout); - return success(); - } - - LogicalResult infer(tpu::EraseLayoutOp op) { - setLayout(op, kNoLayout, kNoLayout); - return success(); - } - - LogicalResult infer(tpu::GatherOp op) { - auto src_layout = getLayout(op.getSource()); - setLayout(op, src_layout, src_layout); - return success(); - } - - LogicalResult infer(tpu::DynamicGatherOp op) { - if (op.getType().getShape() != ArrayRef(target_shape_) && - op.getType().getElementTypeBitWidth() != 32) { - return op.emitOpError( - "Not implemented: DynamicGatherOp only supports 32-bit VREG shape"); - } - if (op.getDimension() != 0 && op.getDimension() != 1) { - return op.emitOpError( - "Not implemented: Only dimension 0 and 1 are supported"); - } - // TODO(jevinjiang): we could preserve some offsets such as replicated - // offset but since we are forcing all operands and result to be the same - // layout, we can set all offsets to zero for now. Also maybe we should - // consider adding this to elementwise rule. - auto layout = VectorLayout(kNativeBitwidth, {0, 0}, default_tiling_, - ImplicitDim::kNone); - setLayout(op, {layout, layout}, layout); - return success(); - } - - LogicalResult infer(tpu::BitcastOp op) { - // Note we have verified the shapes in verify(). - auto in_ty = cast(op.getInput().getType()); - auto out_ty = cast(op.getOutput().getType()); - auto in_bitwidth = in_ty.getElementTypeBitWidth(); - auto out_bitwidth = out_ty.getElementTypeBitWidth(); - auto src_layout = getLayout(op.getInput()); - LayoutOffsets src_offsets = src_layout->offsets(); - auto implicit_dim = src_layout->implicit_dim(); - if (src_offsets[0].value_or(0) * in_bitwidth % out_bitwidth != 0) { - // Force offset to zero if the input offset on the second minor dimension - // is not a multiple of the ratio of output and input bitwidth. - src_offsets[0] = 0; - } else if (!src_offsets[0].has_value() && in_bitwidth > out_bitwidth) { - // We can't preserve replicated offset for decreasing bitwidth. - src_offsets[0] = 0; - } - // Force implicit dim to None if the bitwidth changes. Because we expect 2nd - // minor dim size ratio matches the bitwidth ratio in input and output. - if (in_bitwidth != out_bitwidth) { - if (in_ty.getRank() < 2 || out_ty.getRank() < 2) { - return op.emitOpError( - "Not implemented: bitcast between different bitwidths on a 1D " - "vector."); - } - implicit_dim = ImplicitDim::kNone; - } - // TODO(b/348485035): Instead of forcing to native tiling, bitcast should - // keep the input tiling and infer bitcastable tiling for output. For - // example, it is valid to bitcast vector<8x128xi32> with tile (1, 128) to - // vector<8x128xbf16> with tile (2, 128). - setLayout( - op, - VectorLayout(in_bitwidth, src_offsets, nativeTiling(in_bitwidth), - implicit_dim), - VectorLayout(out_bitwidth, - {src_offsets[0].has_value() - ? src_offsets[0].value() * in_bitwidth / out_bitwidth - : src_offsets[0], - src_offsets[1]}, - nativeTiling(out_bitwidth), implicit_dim)); - return success(); - } - - LogicalResult infer(tpu::TraceOp op) { - static LogicalResult (*match_yield)(Operation *) = [](Operation *op) { - TPU_CHECK_OP(isa(op), "expected yield terminator"); - return success(); - }; - TPU_CHECK_OP(op->getNumOperands() == 0, "expected no operands"); - TPU_CHECK_OP(op->getNumResults() == 0, "results unsupported"); - return inferBlock(*op.getBody(), match_yield); - } - - LogicalResult infer(tpu::RegionOp op) { - static LogicalResult (*match_region)(Operation *) = [](Operation *op) { - TPU_CHECK_OP(isa(op), "expected yield terminator"); - return success(); - }; - TPU_CHECK_OP(op->getNumOperands() == 0, "expected no operands"); - auto body_result = - inferBlock((*op).getRegion(0).getBlocks().front(), match_region); - if (body_result.failed()) { - return op.emitOpError("failed to infer vector layout in region body"); - } - auto yield_op = op.getBody()->getTerminator(); - auto yield_in_layouts = getLayoutFromOperands(yield_op); - setInLayout(yield_op, yield_in_layouts); - setOutLayout(op, yield_in_layouts); - return success(); - } - - LogicalResult infer(tpu::IotaOp op) { - auto ty = op.getResult().getType(); - TPU_CHECK_OP(ty.getElementType().isSignlessInteger(32), - "Only 32-bit integer iota supported"); - TPU_CHECK_OP(ty.getRank() >= 2, "iota rank below 2D unsupported"); - LayoutOffsets offsets = {0, 0}; - if (op.getDimension() == ty.getRank() - 1) { - offsets[0] = std::nullopt; - } - if (op.getDimension() == ty.getRank() - 2) { - offsets[1] = std::nullopt; - } - setOutLayout(op, VectorLayout(kNativeBitwidth, offsets, default_tiling_, - ImplicitDim::kNone)); - return success(); - } - - LogicalResult infer(vector::BroadcastOp op) { - auto some_src_ty = op.getSourceType(); - auto res_ty = op.getResultVectorType(); - TPU_CHECK_OP(res_ty.getRank() > 0, "rank 0 vectors unsupported"); - if (some_src_ty.isSignlessIntOrIndexOrFloat()) { - auto bitwidth = some_src_ty.getIntOrFloatBitWidth(); - // TODO(b/320725357): We need a better design for mask layout. For now, we - // always set layout bitwidth of Vmask to 32bit. - if (bitwidth == 1) { - bitwidth = kNativeBitwidth; - } - if (res_ty.getRank() == 1) { - // We use a full vreg tile, because only then its layout can be changed - // for free. - setLayout( - op, kNoLayout, - VectorLayout(bitwidth, {std::nullopt, std::nullopt}, - nativeTiling(bitwidth), ImplicitDim::kSecondMinor)); - } else { // rank >= 2 // NOLINT(readability-else-after-return) - setLayout(op, kNoLayout, - VectorLayout(bitwidth, {std::nullopt, std::nullopt}, - nativeTiling(bitwidth), ImplicitDim::kNone)); - } - return success(); - } - if (auto src_ty = dyn_cast(some_src_ty)) { - auto some_layout = getLayout(op.getSource()); - TPU_CHECK_OP(some_layout.has_value(), "missing vector layout"); - auto &layout = *some_layout; - if (layout.implicit_dim() != ImplicitDim::kNone && src_ty.getRank() > 1) { - VectorLayout layout_2d(layout.bitwidth(), layout.offsets(), - layout.tiling(), ImplicitDim::kNone); - if (layout_2d.equivalentTo(layout, src_ty.getShape(), target_shape_)) { - // TODO(b/342237796): Stop preferring 2D layouts (if given the choice) - // and defer the work, if any, to relayout. - layout = layout_2d; - } - } - auto src_tiled_ishape = layout.getImplicitTiledDims(src_ty.getShape(), 1); - auto dst_tiled_ishape = layout.getImplicitTiledDims(res_ty.getShape(), 1); - // Since we can only do sublane broadcasts in the (8, 128) tiling, we - // should always use that when sublane broadcasting is required. - if (src_tiled_ishape[0] != dst_tiled_ishape[0] && - layout.offsets()[0] != std::nullopt) { - LayoutOffsets offsets = layout.offsets(); - // At the moment relayout can only produce replicated sublanes when - // converting to (8, 128) if the input was in (1, 128) tiling - if (layout.tiling()[0] == 1 && layout.bitwidth() == kNativeBitwidth) { - offsets[0] = std::nullopt; - } - layout = VectorLayout(layout.bitwidth(), offsets, - nativeTiling(layout.bitwidth()), - layout.implicit_dim()); - } - LayoutOffsets offsets = layout.offsets(); - for (int i = 0; i < 2; ++i) { - if (src_tiled_ishape[i] != dst_tiled_ishape[i]) { - offsets[i] = std::nullopt; - } - } - setLayout(op, layout, - VectorLayout(layout.bitwidth(), offsets, layout.tiling(), - layout.implicit_dim())); - return success(); - } - op.emitOpError("unsupported broadcast source type"); - return failure(); - } - - LogicalResult infer(vector::ExtractOp op) { - TPU_CHECK_OP(!op.hasDynamicPosition(), "dynamic indices not supported"); - TPU_CHECK_OP( - op.getSourceVectorType().getElementTypeBitWidth() == kNativeBitwidth, - "Only 32-bit types supported"); - auto layout = getLayout(op.getVector()); - TPU_CHECK_OP(layout.has_value(), "missing vector layout"); - if (VectorType res_vty = dyn_cast(op.getResult().getType()); - res_vty != nullptr) { - if (res_vty.getRank() == 1 && - layout->implicit_dim() == ImplicitDim::kNone) { - const int64_t second_minor_idx = op.getStaticPosition().back(); - const LayoutOffset second_minor_offset = layout->offsets()[0]; - const LayoutOffset res_second_minor_offset = - second_minor_offset.has_value() - ? (*second_minor_offset + second_minor_idx) % - layout->vregSlice(target_shape_)[0] - : LayoutOffset(); - // TODO: b/342235360 - We should already support this but it needs - // testing. - TPU_CHECK_OP(!res_second_minor_offset.has_value() || - *res_second_minor_offset < layout->tiling()[0], - "Not implemented: Slice does not start on the first tile " - "of a VReg"); - setLayout(op, layout, - VectorLayout(layout->bitwidth(), - {res_second_minor_offset, layout->offsets()[1]}, - layout->tiling(), ImplicitDim::kSecondMinor)); - } else { - TPU_CHECK_OP(layout->layout_rank() <= res_vty.getRank(), - "Internal error: Layout has too many dimensions for " - "vector type (invalid vector.extract?)") - setLayout(op, layout, layout); - } - } else { - setLayout(op, - VectorLayout(kNativeBitwidth, {0, 0}, layout->tiling(), - layout->implicit_dim()), - kNoLayout); - } - return success(); - } - - LogicalResult infer(vector::LoadOp op) { - auto src_ty = getMemRefType(op.getBase()); - auto res_ty = op.getVectorType(); - TPU_CHECK_OP(src_ty.getRank() == res_ty.getRank(), - "memref and vector rank mismatch"); - int64_t rank = res_ty.getRank(); - int8_t bitwidth = res_ty.getElementTypeBitWidth(); - if (kNativeBitwidth % bitwidth != 0) { - return op.emitOpError("Unsupported bitwidth"); - } - const int packing = kNativeBitwidth / bitwidth; - auto maybe_tiling = - verifyMemoryTiling(op, getMemRefLayout(op.getBase()).getTiles(), - src_ty.getRank(), src_ty.getElementTypeBitWidth()); - if (!maybe_tiling) { - return failure(); - } - auto tiling = *maybe_tiling; - - SmallVector in_layout(op->getNumOperands(), kNoLayout); - CHECK_EQ(op->getNumOperands(), op.getIndices().size() + 1); - // Infer the static offset on a given tiling dimension. - auto infer_offset = [&](int64_t &offset, - int64_t tiling_dim) -> LogicalResult { - int dim = rank - tiling.size() + tiling_dim; - Value tiled_index = op.getIndices()[dim]; - if (auto cst_op = tiled_index.getDefiningOp()) { - offset = - cast(cst_op.getValue()).getInt() % tiling[tiling_dim]; - return success(); - } - if (failed( - verifyDivisibleIndex(tiled_index, tiling[tiling_dim], dim, op))) { - return failure(); - } - offset = 0; - return success(); - }; - - if (rank == 0) { - op.emitOpError("rank 0 vectors unsupported"); - return failure(); - } - if (rank == 1) { - TPU_CHECK_OP(tiling.size() == 1, "Expected 1D tiling in 1D loads"); - const int64_t lane_tiling = packing * target_shape_[1]; - auto tile = tiling.front(); - TPU_CHECK_OP(tile % lane_tiling == 0, "Unsupported tiling for 1D load"); - int64_t offset; - if (failed(infer_offset(offset, 0))) { - return failure(); - } - // TODO(apaszke): We could generate replicated loads for short values. - setLayout(op, in_layout, - VectorLayout(bitwidth, {0, offset % lane_tiling}, - {1, lane_tiling}, ImplicitDim::kSecondMinor)); - } else { // rank >= 2 - TPU_CHECK_OP(tiling.size() == 2, "Expected 2D tiling in 2D+ loads"); - LayoutOffsets offsets = {0, 0}; - const auto tile_src_shape = src_ty.getShape().take_back(2); - const auto tile_res_shape = res_ty.getShape().take_back(2); - const int64_t num_sublanes = tile_res_shape[0]; - // For now, we focus on tilings that span full sublanes. - TPU_CHECK_OP(tiling[1] == target_shape_[1], - "Unsupported tiling for 2d load"); - // We can load starting from any row if the source has few columns, - // because the tiling structure degenerates to regular layout there. - // There is also no extra need for alignment if we load a single sublane. - // TODO(apaszke): Also no need to align if we don't exceed the base chunk! - if (bitwidth == 32 && - (tile_src_shape[1] <= target_shape_[1] || num_sublanes == 1)) { - offsets[0] = 0; - } else if (failed(infer_offset(*offsets[0], 0))) { - return failure(); - } - if (failed(infer_offset(*offsets[1], 1))) { - return failure(); - } - std::array layout_tiling{tiling[0], tiling[1]}; - if (num_sublanes == 1 && bitwidth == 32 && - tiling[1] == target_shape_[1] && - tile_res_shape[1] > target_shape_[1]) { - // We can strided load sublanes if we're loading a single sublane for - // multiple times. Enabling this helps load one entire row from memref - // more efficiently. - setLayout(op, in_layout, - VectorLayout(bitwidth, offsets, {1, layout_tiling[1]}, - ImplicitDim::kNone)); - } else if (num_sublanes == 1 && bitwidth == 32 && - tiling == target_shape_) { - // We can use replicated loads if we're only loading a single sublane. - setLayout(op, in_layout, - VectorLayout(bitwidth, {std::nullopt, offsets[1]}, - layout_tiling, ImplicitDim::kNone)); - } else if (bitwidth == 32 && - canReinterpretToUntiledMemref( - op.getBase(), target_shape_, - /*allow_minormost_padding=*/true) && - *(src_ty.getShape().end() - 2) > 1) { - // Since it is untiled, we can load from any arbitrary address which - // means we can always set the sublane offset to 0. - // Note: if the src_shape[-2] == 1, we can just use the tiling from ref. - setLayout(op, in_layout, - VectorLayout(bitwidth, {0, offsets[1].value_or(0)}, - nativeTiling(bitwidth), ImplicitDim::kNone)); - } else { - setLayout( - op, in_layout, - VectorLayout(bitwidth, offsets, layout_tiling, ImplicitDim::kNone)); - } - } - return success(); - } - - LogicalResult infer(vector::ExtractStridedSliceOp op) { - auto input_layout = getLayout(op.getVector()); - TPU_CHECK_OP(input_layout, "missing vector layout"); - auto offsets_attr = op.getOffsets().getValue(); - auto strides_attr = op.getStrides().getValue(); - auto offsets = llvm::map_to_vector(offsets_attr, [](auto attr) { - return cast(attr).getInt(); - }); - input_layout->insertImplicit(offsets, 0); - auto vreg_slice = input_layout->vregSlice(target_shape_); - LayoutOffsets new_layout_offsets; - if (input_layout->offsets()[0].has_value()) { - new_layout_offsets[0] = - (*(offsets.end() - 2) + *input_layout->offsets()[0]) % vreg_slice[0]; - } - if (input_layout->offsets()[1].has_value()) { - new_layout_offsets[1] = - (*(offsets.end() - 1) + *input_layout->offsets()[1]) % vreg_slice[1]; - } - for (auto stride : strides_attr) { - TPU_CHECK_OP(stride.cast().getInt() == 1, - "Only trivial strides supported."); - } - - setLayout( - op, input_layout, - VectorLayout(input_layout->bitwidth(), new_layout_offsets, - input_layout->tiling(), input_layout->implicit_dim())); - return success(); - } - - LogicalResult infer(vector::MultiDimReductionOp op) { - auto src_ty = op.getSourceVectorType(); - auto dst_ty = dyn_cast(op.getDestType()); - TPU_CHECK_OP(dst_ty, "only reductions with vector results supported"); - llvm::ArrayRef dims = op.getReductionDims(); - int64_t src_rank = src_ty.getRank(); - auto acc_layout = getLayout(op.getAcc()); - TPU_CHECK_OP(is_fully_replicated(acc_layout), - "only constant accumulators supported"); - TPU_CHECK_OP( - src_ty.getElementTypeBitWidth() == 32 || - src_ty.getElementTypeBitWidth() == 16, - "only 32-bit (and 16-bit only on some targets) reductions supported"); - auto some_src_layout = getLayout(op.getSource()); - TPU_CHECK_OP(some_src_layout, "missing vector layout"); - auto &src_layout = *some_src_layout; - std::array reduces; - switch (src_layout.implicit_dim()) { - case VectorLayout::ImplicitDim::kNone: - reduces = { - std::find(dims.begin(), dims.end(), src_rank - 2) != dims.end(), - std::find(dims.begin(), dims.end(), src_rank - 1) != dims.end()}; - break; - case VectorLayout::ImplicitDim::kSecondMinor: - reduces = {false, std::find(dims.begin(), dims.end(), src_rank - 1) != - dims.end()}; - break; - case VectorLayout::ImplicitDim::kMinor: - reduces = { - std::find(dims.begin(), dims.end(), src_rank - 1) != dims.end(), - false}; - break; - } - if ((reduces[0] || reduces[1]) && - !src_layout.hasNativeTiling(target_shape_)) { - src_layout = VectorLayout(src_layout.bitwidth(), src_layout.offsets(), - nativeTiling(src_layout.bitwidth()), - src_layout.implicit_dim()); - } - LayoutOffsets out_offsets = src_layout.offsets(); - for (int i = 0; i < out_offsets.size(); ++i) { - if (reduces[i]) { - out_offsets[i] = std::nullopt; - } - } - ImplicitDim out_implicit_dim = src_layout.implicit_dim(); - if ((reduces[0] && reduces[1]) || - (src_layout.implicit_dim() != ImplicitDim::kNone && - (reduces[0] || reduces[1]))) { - TPU_CHECK_OP( - dst_ty.getRank() > 0 && *(dst_ty.getShape().end() - 1) == 1, - "Not implemented: reductions over both trailing dimensions are only " - "supported when the resulting value has a trailing axis of size 1"); - out_implicit_dim = VectorLayout::ImplicitDim::kSecondMinor; - } else if (reduces[0]) { - out_implicit_dim = VectorLayout::ImplicitDim::kSecondMinor; - } else if (reduces[1]) { - out_implicit_dim = VectorLayout::ImplicitDim::kMinor; - } - setLayout(op, {src_layout, acc_layout}, - VectorLayout(src_layout.bitwidth(), out_offsets, - src_layout.tiling(), out_implicit_dim)); - return success(); - } - - LogicalResult infer(vector::ShapeCastOp op) { - auto src_ty = op.getSourceVectorType(); - auto src_shape = src_ty.getShape(); - auto res_ty = op.getResultVectorType(); - auto res_shape = res_ty.getShape(); - auto some_src_layout = getLayout(op.getSource()); - TPU_CHECK_OP(some_src_layout, "missing vector layout"); - auto layout = *some_src_layout; - const unsigned bitwidth = src_ty.getElementTypeBitWidth(); - const std::array native_tiling = nativeTiling(bitwidth); - const std::array src_tiled_ishape = - layout.getImplicitTiledDims(src_shape, 1); - const std::array vreg_slice = layout.vregSlice(target_shape_); - - // TODO(tlongeri): Be smarter about trying implicit dims. We should probably - // only add them when folding dimensions, and remove them when unfolding. - // The ordering of candidate implicit dims is important! Inserting an - // implicit second minor can make a reshape possible, but also very - // inefficient. We should always prefer to try with None first. - SmallVector candidate_implicit_dims; - if (res_shape.size() >= 2) { - candidate_implicit_dims.push_back(ImplicitDim::kNone); - } - if (!res_shape.empty()) { - candidate_implicit_dims.push_back(ImplicitDim::kSecondMinor); - candidate_implicit_dims.push_back(ImplicitDim::kMinor); - } - // TODO(b/340625465): Add case with both implicit dims once we support it. - - // See if we can get implicit tiled dimensions to match. This is always a - // no-op. - for (const ImplicitDim implicit_dim : candidate_implicit_dims) { - const std::array res_tiled_ishape = - VectorLayout::getImplicitTiledDims(implicit_dim, res_shape, 1); - if (src_tiled_ishape == res_tiled_ishape) { - // Nothing changes in the tiled dimensions - setLayout(op, layout, - VectorLayout(layout.bitwidth(), layout.offsets(), - layout.tiling(), implicit_dim)); - return success(); - } - } - - // See if we can do sublane or lane (un)folding. - for (const ImplicitDim implicit_dim : candidate_implicit_dims) { - const std::array res_tiled_ishape = - VectorLayout::getImplicitTiledDims(implicit_dim, res_shape, 1); - // Sublane (un)folding. We attempt to reduce the sublane tiling, which - // might make this reshape a no-op. We use do-while to handle the packed - // 1D tilings that use 1 in the sublane dimension. - int64_t sublane_tiling = vreg_slice[0]; - do { - auto src_res_tiled_equal = src_tiled_ishape[1] == res_tiled_ishape[1]; - auto vreg_num_elements = - target_shape_[0] * target_shape_[1] * layout.packing(); - auto single_subline_mod_1024 = - (sublane_tiling == 1 && - src_tiled_ishape[1] % vreg_num_elements == 0 && - res_tiled_ishape[1] % vreg_num_elements == 0); - if ((src_res_tiled_equal || single_subline_mod_1024) && - src_tiled_ishape[0] % sublane_tiling == 0 && - res_tiled_ishape[0] % sublane_tiling == 0) { - std::array tiling = {sublane_tiling, target_shape_[1]}; - // TODO(b/343808585): We shouldn't force second minor offset to 0 when - // unfolding, it's still a no-op, but we need to - // add support in apply-vector-layout. - LayoutOffsets offsets = {0, layout.offsets()[1]}; - setLayout( - op, - VectorLayout(layout.bitwidth(), offsets, tiling, - layout.implicit_dim()), - VectorLayout(layout.bitwidth(), offsets, tiling, implicit_dim)); - return success(); - } - sublane_tiling /= 2; - } while (sublane_tiling >= layout.packing()); - // Lane (un)folding. - if (src_tiled_ishape[1] != res_tiled_ishape[1] && - src_tiled_ishape[1] % layout.tiling()[1] == 0 && - res_tiled_ishape[1] % layout.tiling()[1] == 0) { - const int packing = kNativeBitwidth / bitwidth; - const auto elements_per_vreg = native_tiling[0] * native_tiling[1]; - // When we shapecast from input shape - // (..., m * target_shape_[1] * packing) to output shape - // (..., target_shape_[1]), the reshape becomes no-op when input is - // densely packed with tiling (1, target_shape_[1] * packing) and output - // has the native tiling. - if (res_tiled_ishape[1] == target_shape_[1] && - res_tiled_ishape[0] % native_tiling[0] == 0 && - src_tiled_ishape[1] % elements_per_vreg == 0) { - // Inferring in_layout to have tiling (1, 128 * packing) triggers any - // necessary relayout before shapecast. - setLayout(op, - VectorLayout(layout.bitwidth(), {0, 0}, - {1, target_shape_[1] * packing}, - layout.implicit_dim()), - VectorLayout(layout.bitwidth(), {0, 0}, native_tiling, - implicit_dim)); - return success(); - } - - // When we shapecast from input shape (..., target_shape_[1]) to output - // shape (..., m * target_shape_[1] * packing), the reshape becomes - // no-op when input has the native tiling and output is densely packed - // with tiling (1, target_shape_[1] * packing). - if (src_tiled_ishape[1] == target_shape_[1] && - src_tiled_ishape[0] % native_tiling[0] == 0 && - res_tiled_ishape[1] % elements_per_vreg == 0) { - setLayout( - op, - VectorLayout(layout.bitwidth(), {0, 0}, native_tiling, - layout.implicit_dim()), - VectorLayout(layout.bitwidth(), {0, 0}, - {1, target_shape_[1] * packing}, implicit_dim)); - return success(); - } - } - } - - // Try adding a singleton innermost dim to the actual *implicit* shape. - if (res_shape.size() >= 2 && - res_shape.take_back(2) == ArrayRef({src_tiled_ishape[1], 1})) { - TPU_CHECK_OP(bitwidth == kNativeBitwidth, - "Insertion of minor dim that is not a no-op only " - "supported for 32-bit types"); - setLayout(op, - VectorLayout(layout.bitwidth(), layout.offsets(), native_tiling, - layout.implicit_dim()), - VectorLayout(layout.bitwidth(), {0, std::nullopt}, - native_tiling, ImplicitDim::kNone)); - return success(); - } - op.emitOpError("unsupported shape cast"); - return failure(); - } - - template - LogicalResult inferStore(Op op, bool has_mask = false) { - auto ref_ty = getMemRefType(op.getBase()); - auto store_ty = op.getValueToStore().getType(); - TPU_CHECK_OP(ref_ty.getRank() == store_ty.getRank(), - "memref and vector rank mismatch"); - int64_t rank = ref_ty.getRank(); - int8_t bitwidth = store_ty.getElementTypeBitWidth(); - if (kNativeBitwidth % bitwidth != 0) { - return op.emitOpError("Unsupported bitwidth"); - } - const int packing = kNativeBitwidth / bitwidth; - auto maybe_tiling = - verifyMemoryTiling(op, getMemRefLayout(op.getBase()).getTiles(), - ref_ty.getRank(), ref_ty.getElementTypeBitWidth()); - if (!maybe_tiling) { - return failure(); - } - auto tiling = *maybe_tiling; - - // Infer the static offset on a given tiling dimension. - auto infer_offset = [&](int64_t &offset, - int64_t tiling_dim) -> LogicalResult { - int dim = rank - tiling.size() + tiling_dim; - Value tiled_index = op.getIndices()[dim]; - if (auto cst_op = tiled_index.getDefiningOp()) { - offset = - cast(cst_op.getValue()).getInt() % tiling[tiling_dim]; - return success(); - } - if (failed( - verifyDivisibleIndex(tiled_index, tiling[tiling_dim], dim, op))) { - return failure(); - } - offset = 0; - return success(); - }; - - Layout store_layout; - if (rank == 0) { - op.emitOpError("rank 0 vectors unsupported"); - return failure(); - } - if (rank == 1) { - TPU_CHECK_OP(tiling.size() == 1, "Expected 1D tiling in 1D store"); - const int64_t lane_tiling = packing * target_shape_[1]; - auto tile = tiling.front(); - TPU_CHECK_OP(tile % lane_tiling == 0, - "Unsupported 1D tiling for 1D store"); - int64_t offset; - if (failed(infer_offset(offset, 0))) { - return failure(); - } - store_layout = VectorLayout(bitwidth, {0, offset % lane_tiling}, - {1, lane_tiling}, ImplicitDim::kSecondMinor); - } else { // rank >= 2 // NOLINT(readability-else-after-return) - TPU_CHECK_OP(tiling.size() == 2, "Expected 2D tiling in 2D+ store"); - LayoutOffsets offsets = {0, 0}; - const auto tile_ref_shape = ref_ty.getShape().take_back(2); - const auto tile_store_shape = store_ty.getShape().take_back(2); - const int64_t num_sublanes = tile_store_shape[0]; - // For now, we focus on tilings that span full sublanes. - TPU_CHECK_OP(tiling[1] == target_shape_[1], - "Unsupported tiling for 2d store"); - // We can store starting from any row if the source has few columns, - // because the tiling structure degenerates to regular layout there. - // There is also no extra need for alignment if we store a single sublane. - // TODO(apaszke): Also no need to align if we don't exceed the base chunk! - if (bitwidth == 32 && - (tile_ref_shape[1] <= target_shape_[1] || num_sublanes == 1)) { - offsets[0] = 0; - } else if (failed(infer_offset(*offsets[0], 0))) { - return failure(); - } - if (failed(infer_offset(*offsets[1], 1))) { - return failure(); - } - if (num_sublanes == 1 && bitwidth == 32 && - tiling[1] == target_shape_[1] && - tile_store_shape[1] > target_shape_[1]) { - // We can strided store sublanes if we're storing a single sublane for - // multiple times. Enabling this helps store one entire row to memref - // more efficiently. - store_layout = - VectorLayout(bitwidth, offsets, {1, tiling[1]}, ImplicitDim::kNone); - } else if (bitwidth == 32 && - // We accept padding in the minormost dim, because - // apply_vector_layout will properly mask stores. - canReinterpretToUntiledMemref( - op.getBase(), target_shape_, - /*allow_minormost_padding=*/true)) { - // Since it is untiled, we can store to any arbitrary address which - // means the sublane offset can be any value and we can fold it to - // 2nd minor index. - auto prev_store_layout = getLayout(op.getValueToStore()); - TPU_CHECK_OP(prev_store_layout.has_value(), "missing vector layout"); - offsets[0] = prev_store_layout->offsets()[0].value_or(0); - if (offsets[1].value_or(0) >= tiling[1]) { - offsets[1] = 0; - } - store_layout = VectorLayout(bitwidth, offsets, nativeTiling(bitwidth), - ImplicitDim::kNone); - } else { - store_layout = VectorLayout(bitwidth, offsets, {tiling[0], tiling[1]}, - ImplicitDim::kNone); - } - } - SmallVector in_layout{store_layout}; - in_layout.insert(in_layout.end(), op.getIndices().size() + 1, kNoLayout); - if (has_mask) { - // Mask layout should be the same as the layout of value to store. - in_layout.push_back(store_layout); - } - setInLayout(op, in_layout); - return success(); - } - - LogicalResult infer(vector::TransposeOp op) { - auto permutation = op.getPermutation(); - TPU_CHECK_OP(permutation.size() > 1, - "Vector and scalar transpose should be a no-op and removed"); - - auto some_layout = getLayout(op.getVector()); - TPU_CHECK_OP(some_layout.has_value(), "missing vector layout"); - auto &layout = *some_layout; - auto src_ty = op.getSourceVectorType(); - TPU_CHECK_OP(permutation.size() == src_ty.getRank(), - "Transpose permutation has incorrect rank"); - for (auto dim : permutation.drop_back(2)) { - TPU_CHECK_OP(dim < src_ty.getRank() - 2, - "Unsupported transpose permutation - minor dims into major"); - } - for (auto dim : permutation.take_back(2)) { - TPU_CHECK_OP(dim >= src_ty.getRank() - 2, - "Unsupported transpose permutation - major dims into minor"); - } - Layout required_layout = some_layout; - // Require native tiling if we're going to use the XLU. - if (permutation[permutation.size() - 1] == permutation.size() - 2) { - auto native_tiling = nativeTiling(layout.bitwidth()); - required_layout = VectorLayout(layout.bitwidth(), LayoutOffsets{0, 0}, - native_tiling, ImplicitDim::kNone); - } - setLayout(op, required_layout, required_layout); - return success(); - } - - LogicalResult inferExt(Operation *op) { - TPU_CHECK_OP(op->getNumOperands() == 1, "expect 1 operand"); - TPU_CHECK_OP(op->getNumResults() == 1, "expect 1 result"); - auto src_ty = dyn_cast(op->getOperand(0).getType()); - if (!src_ty) { - setLayout(op, kNoLayout, kNoLayout); - return success(); - } - auto dst_ty = cast(op->getResult(0).getType()); - unsigned src_bitwidth = src_ty.getElementTypeBitWidth(); - unsigned dst_bitwidth = dst_ty.getElementTypeBitWidth(); - auto some_layout = getLayout(op->getOperand(0)); - TPU_CHECK_OP(some_layout.has_value(), "missing vector layout"); - if (dyn_cast(op)) { - TPU_CHECK_OP(dst_bitwidth == 32, "Only supported extensions to 32-bit"); - } - auto &layout = *some_layout; - Layout src_layout; - Layout dst_layout; - if (layout.tiling() == nativeTiling(src_bitwidth)) { - // If the source is already in native tiling, we can unpack it directly. - std::array dst_native_tiling = nativeTiling(dst_bitwidth); - LayoutOffsets offsets = {layout.offsets()[0] - ? *layout.offsets()[0] % dst_native_tiling[0] - : LayoutOffset(), - layout.offsets()[1]}; - DCHECK_LT(offsets[1].value_or(0), dst_native_tiling[1]); - src_layout = VectorLayout(src_bitwidth, offsets, layout.tiling(), - layout.implicit_dim()); - dst_layout = - VectorLayout(dst_bitwidth, offsets, dst_native_tiling, - layout.implicit_dim()); - } else if (dst_bitwidth == 32 && - default_tiling_[0] % layout.tiling()[0] == 0 && - default_tiling_[1] == layout.tiling()[1]) { - // All layouts that subdivide the rows of the result native tiling evenly - // can be handled uniformly with the default case, by preserving the - // tiling through the op. - // TODO(jevinjiang): we can relax this for non-32bit as well. - src_layout = layout; - dst_layout = VectorLayout(dst_bitwidth, layout.offsets(), - src_layout->tiling(), layout.implicit_dim()); - } else if (layout.packing() > target_shape_[0]) { - // When the input dtype has packing greater than the sublane count, we - // can't preserve its native tiling in the output (the tile would be too - // big to fit in a vreg). At the same time, we can't use the default - // tiling either, because the tile size in the input dtype is smaller than - // a sublane. - // For example, for int2 on the target with 8 sublanes, subelements are - // unpacked into 16 consecutive sublanes. - // TODO(b/401624977): Perhaps there is a better layout for this case, or - // if it's impossible, such layout should be used everywhere for int2, not - // just ExtOp. - std::array src_native_tiling = nativeTiling(src_bitwidth); - std::array dst_native_tiling = nativeTiling(dst_bitwidth); - LayoutOffsets src_offsets = { - layout.offsets()[0] ? *layout.offsets()[0] % src_native_tiling[0] - : LayoutOffset(), - layout.offsets()[1] ? *layout.offsets()[1] % src_native_tiling[1] - : LayoutOffset()}; - LayoutOffsets dst_offsets = { - layout.offsets()[0] ? *layout.offsets()[0] % dst_native_tiling[0] - : LayoutOffset(), - layout.offsets()[1] ? *layout.offsets()[1] % dst_native_tiling[1] - : LayoutOffset()}; - src_layout = VectorLayout(src_bitwidth, src_offsets, src_native_tiling, - layout.implicit_dim()); - dst_layout = VectorLayout(dst_bitwidth, dst_offsets, dst_native_tiling, - layout.implicit_dim()); - } else { - LayoutOffsets offsets = { - layout.offsets()[0] ? *layout.offsets()[0] % default_tiling_[0] - : LayoutOffset(), - layout.offsets()[1] ? *layout.offsets()[1] % default_tiling_[1] - : LayoutOffset()}; - src_layout = VectorLayout(src_bitwidth, offsets, default_tiling_, - layout.implicit_dim()); - dst_layout = VectorLayout(dst_bitwidth, offsets, default_tiling_, - layout.implicit_dim()); - } - setLayout(op, src_layout, dst_layout); - return success(); - } - - LogicalResult inferTrunc(Operation *op) { - TPU_CHECK_OP(op->getNumOperands() == 1, "expect 1 operand"); - TPU_CHECK_OP(op->getNumResults() == 1, "expect 1 result"); - auto src_ty = dyn_cast(op->getOperand(0).getType()); - if (!src_ty) { - setLayout(op, kNoLayout, kNoLayout); - return success(); - } - auto dst_ty = cast(op->getResult(0).getType()); - auto some_layout = getLayout(op->getOperand(0)); - TPU_CHECK_OP(some_layout.has_value(), "missing vector layout"); - auto &layout = *some_layout; - bool select_native = allUsersRequireNativeTiling(op->getResult(0)); - // We might want to reconsider enabling native this aggressively in cases - // when it would introduce a lot of padding (e.g. when the value only has - // a small second minor size, but large minor size). - if (dst_ty.getElementTypeBitWidth() == 16) { - // TPUv6 has good support for compute in 16-bit and cheap retiling between - // large 2nd minor and the default tiling, so we bias towards large tiles. - select_native |= hardware_generation_ >= 6 || - tpu_tiling_flags_.use_x16_large_second_minor; - } else if (dst_ty.getElementTypeBitWidth() == 8) { - select_native |= tpu_tiling_flags_.use_x8_large_second_minor; - } else if (dst_ty.getElementTypeBitWidth() == 4) { - select_native |= tpu_tiling_flags_.use_x4_large_second_minor; - } else if (dst_ty.getElementTypeBitWidth() == 2) { - // Force it to native tiling. See comments in `inferExt`. - select_native = true; - } else { - return op->emitOpError("Unsupported target bitwidth for truncation"); - } - auto src_layout = - VectorLayout(layout.bitwidth(), layout.offsets(), - nativeTiling(layout.bitwidth()), layout.implicit_dim()); - auto dst_layout = VectorLayout( - dst_ty.getElementTypeBitWidth(), layout.offsets(), - select_native ? nativeTiling(dst_ty.getElementTypeBitWidth()) - : src_layout.tiling(), - layout.implicit_dim()); - setLayout(op, src_layout, dst_layout); - return success(); - } - - LogicalResult inferElementwise(Operation *op) { - TPU_CHECK_OP(op->getNumResults() == 1, "only one result supported"); - TPU_CHECK_OP(op->getNumOperands() > 0, - "elementwise ops with no operands unsupported"); - // Elementwise operators can be parameterized by both scalars and shaped - // types, so make sure we infer layout based on a shaped-typed operand. - std::optional out_layout_candidate; - std::optional out_layout; - SmallVector, 4> in_layouts; - int64_t bitwidth = -1; - // Find the bitwidth of the operands/results. They must all be the same - // except for the case of i1s, which use a "fake" bitwidth for layouts. - // They can be relayouted (in principle) to any other fake bitwidth, so we - // don't commit to their bitwidth. See comments in VectorLayout class. - for (Value val : llvm::concat(op->getOperands(), op->getResults())) { - if (const VectorType vty = dyn_cast(val.getType())) { - const int64_t val_bitwidth = vty.getElementTypeBitWidth(); - if (val_bitwidth != 1) { - if (bitwidth == -1) { - bitwidth = val_bitwidth; - } else if (bitwidth != val_bitwidth) { - return op->emitOpError( - "Mismatched bitwidth in elementwise for non-i1 " - "operands/results"); - } - } - } - } - for (int64_t i = 0; i < op->getNumOperands(); ++i) { - if (auto vty = dyn_cast(op->getOperand(i).getType())) { - auto some_layout = getLayout(op->getOperand(i)); - TPU_CHECK_OP(some_layout.has_value(), "missing vector layout"); - auto &layout = *some_layout; - if (bitwidth == -1) { - // All operands/results are i1s, just commit to the first bitwidth - DCHECK(!out_layout.has_value()); - bitwidth = layout.bitwidth(); - out_layout = layout; - in_layouts.push_back(layout); - } else if (bitwidth != layout.bitwidth()) { - DCHECK_EQ(vty.getElementTypeBitWidth(), 1); - in_layouts.push_back(std::nullopt); - } else if (is_fully_replicated(some_layout)) { - // If the input is fully replicated, don't use it to commit to any - // layout. Replicated values are easy to relayout. - in_layouts.push_back(std::nullopt); - out_layout_candidate = layout; - } else if (!out_layout) { - // TODO(apaszke): There are probably smarter ways to choose layout. - out_layout = layout; - in_layouts.push_back(some_layout); - } else { - if (auto new_out = - VectorLayout::join(layout, *out_layout, vty.getShape())) { - out_layout = *new_out; - in_layouts.push_back(some_layout); - } else { - // When we detect a layout conflict we cannot reconcile, we remove - // any replication bits that might have been present in out_layout, - // since there is no guarantee that the conflicting inputs could - // even become replicated. - DCHECK_EQ(out_layout->bitwidth(), bitwidth); - out_layout = - VectorLayout(bitwidth, - {out_layout->offsets()[0].value_or(0), - out_layout->offsets()[1].value_or(0)}, - out_layout->tiling(), out_layout->implicit_dim()); - in_layouts.push_back(std::nullopt); - } - } - } else { - TPU_CHECK_OP(op->getOperand(i).getType().isSignlessIntOrIndexOrFloat(), - "expected only vector and scalar operands"); - in_layouts.push_back({kNoLayout}); - } - } - Layout final_out_layout = std::nullopt; - if (auto out_vty = dyn_cast(op->getResult(0).getType())) { - if (out_layout) { - final_out_layout = *out_layout; - } else if (out_layout_candidate) { - final_out_layout = *out_layout_candidate; - } else { - op->emitOpError( - "Elementwise op has no vector operands but returns a vector?"); - return failure(); - } - } - CHECK_EQ(in_layouts.size(), op->getNumOperands()) << Print(op); - SmallVector final_in_layouts; - for (int i = 0; i < in_layouts.size(); ++i) { - if (in_layouts[i]) { - final_in_layouts.push_back(*in_layouts[i]); - } else { - final_in_layouts.push_back(final_out_layout); - } - } - setLayout(op, final_in_layouts, final_out_layout); - return success(); - } - - LogicalResult infer(tpu::PRNGRandomBitsOp op) { - auto res_ty = dyn_cast(op->getResult(0).getType()); - TPU_CHECK_OP(res_ty.getElementTypeBitWidth() == kNativeBitwidth, - "only 32-bit random bit generation supported"); - // TODO: b/342054464 - Support implicit dims for PRNGRandomBitsOp. - LayoutOffsets offsets = {0, 0}; - setOutLayout( - op, VectorLayout(kNativeBitwidth, offsets, - nativeTiling(kNativeBitwidth), ImplicitDim::kNone)); - return success(); - } - - bool allUsersRequireNativeTiling(Value x) { - for (OpOperand &operand : x.getUses()) { - if (isa(operand.getOwner())) { - continue; - } - if (auto reduce = - dyn_cast(operand.getOwner())) { - bool reduces_tiled_dims = false; - for (int64_t dim : reduce.getReductionDims()) { - if (dim >= reduce.getSourceVectorType().getRank() - 2) { - reduces_tiled_dims = true; - break; - } - } - if (reduces_tiled_dims) { - continue; - } - } - if (auto transpose = dyn_cast(operand.getOwner())) { - auto perm = transpose.getPermutation(); - auto rank = perm.size(); - // Only permutations that actually swap the last two dims need it. - if (rank >= 2 && perm[rank - 1] == rank - 2 && - perm[rank - 2] == rank - 1) { - continue; - } - // Fall through. - } - if (auto store = dyn_cast(operand.getOwner())) { - auto maybe_tiling = verifyMemoryTiling( - store, getMemRefLayout(store.getBase()).getTiles(), - store.getMemRefType().getRank(), - store.getMemRefType().getElementTypeBitWidth()); - if (maybe_tiling) { - auto tiling = *maybe_tiling; - if (tiling == - nativeTiling(store.getMemRefType().getElementTypeBitWidth())) { - continue; - } - } - // Fall through. - } - return false; - } - return true; - } - - LogicalResult assumeLayoutsForBlockArgs(Block &block, - ArrayRef layouts) { - auto op = block.getParentOp(); - if (layouts.size() != block.getNumArguments()) { - return op->emitOpError( - "Block arguments must have the same number of layouts"); - } - // Use tpu.assume_layout to annotate every block argument with the layout of - // the corresponding operand and replace all uses of the block argument with - // the result of tpu.assume_layout. - ImplicitLocOpBuilder builder = - ImplicitLocOpBuilder::atBlockBegin(op->getLoc(), &block); - for (auto [iter_arg, layout] : - llvm::zip_equal(block.getArguments(), layouts)) { - if (!dyn_cast(iter_arg.getType())) { - continue; - } - if (llvm::any_of(iter_arg.getUsers(), [](Operation *user) { - return isa(user); - })) { - return op->emitOpError("Expected no assume layout for block arguments"); - } - auto assume_layout_op = - builder.create(iter_arg.getType(), iter_arg); - setLayout(assume_layout_op, layout, layout); - iter_arg.replaceUsesWithIf(assume_layout_op, [&](OpOperand &operand) { - return operand.getOwner() != assume_layout_op; - }); - } - return success(); - } - - void clearBlockLayouts(Block &block) { - block.walk([&](Operation *op) { - // We need to remove assume_layout ops in each block. Otherwise, we will - // create extra assume_layout ops for nested blocks. - if (auto assume_op = dyn_cast(op)) { - assume_op.getResult().replaceAllUsesWith(assume_op.getInput()); - assume_op->erase(); - return WalkResult::advance(); - } - op->removeAttr("in_layout"); - op->removeAttr("out_layout"); - return WalkResult::advance(); - }); - } - - Layout getLayout(Value v) { - auto op = v.getDefiningOp(); - CHECK(op); - auto op_result = dyn_cast(v); - CHECK(op_result); - auto result_index = op_result.getResultNumber(); - auto out_attrs = op->getAttrOfType("out_layout").getValue(); - CHECK(out_attrs.size() > result_index); - auto layout = cast(out_attrs[result_index]).getLayout(); - if (force_first_tile_offsets_ && - layout->offsets()[1].value_or(0) >= layout->tiling()[1]) { - // Force the out-of-first-tile offset to be zero. - layout = VectorLayout(layout->bitwidth(), {layout->offsets()[0], 0}, - layout->tiling(), layout->implicit_dim()); - } - return layout; - } - - SmallVector getLayoutFromOperands(Operation *op) { - SmallVector layouts; - layouts.reserve(op->getNumOperands()); - for (const auto &operand : op->getOperands()) { - if (isa(operand.getType())) { - layouts.push_back(getLayout(operand)); - } else { - layouts.push_back(kNoLayout); - } - } - return layouts; - } - - private: - std::optional> verifyMemoryTiling( - Operation *op, ArrayRef mem_tiling, int64_t rank, - int8_t bitwidth) { - const int packing = kNativeBitwidth / bitwidth; - if (bitwidth == 32) { - if (mem_tiling.size() != 1) { - op->emitOpError("Only one-level tiling supported for 32-bit loads"); - return std::nullopt; - } - } else if (bitwidth < 32) { - int64_t rows_per_tile; - if (rank == 1) { - if (mem_tiling.size() != 3) { - op->emitOpError( - "Only three-level tiling supported for 1D memory ops narrower " - "than 32-bit"); - return std::nullopt; - } - auto first = mem_tiling[0].dimensions(); - auto second = mem_tiling[1].dimensions(); - if (first.size() != 1 || first[0] % (packing * target_shape_[1]) != 0) { - op->emitOpError("Invalid first-level tile in 1D memory op"); - return std::nullopt; - } - rows_per_tile = first[0] / target_shape_[1]; - if (second.size() != 1 || second[0] != target_shape_[1]) { - op->emitOpError("Invalid second-level tile in 1D memory op"); - return std::nullopt; - } - } else { - if (mem_tiling.size() != 2) { - op->emitOpError( - "Only two-level tiling supported for 2D+ memory ops narrower " - "than 32-bit"); - return std::nullopt; - } - auto first = mem_tiling[0].dimensions(); - rows_per_tile = first[0]; - } - auto row_compressed = mem_tiling[mem_tiling.size() - 1].dimensions(); - if (row_compressed.size() != 2) { - op->emitOpError("Expected 2D tiling for packed layout"); - return std::nullopt; - } - if (row_compressed[0] != (32 / bitwidth) || row_compressed[1] != 1) { - op->emitOpError("Expected compressed packed layout"); - return std::nullopt; - } - if (row_compressed[0] > rows_per_tile) { - op->emitOpError("Packing cannot introduce padding"); - return std::nullopt; - } - } else { - op->emitOpError("Loads of types wider than 32-bit unsupported"); - return std::nullopt; - } - return mem_tiling[0].dimensions(); - } - - std::array nativeTiling(int8_t bitwidth) { - return {default_tiling_[0] * kNativeBitwidth / bitwidth, - default_tiling_[1]}; - } - - int hardware_generation_; - std::array target_shape_; - std::array default_tiling_; - TpuTilingFlags tpu_tiling_flags_; - - // TODO(b/342235360): Deprecate force_first_tile_offsets_ once we fully - // remove the restriction that offsets must fall within the first tile. - bool force_first_tile_offsets_ = false; - - // TODO(apaszke): This is not really native on newer generations of TPUs. - // Get rid of this temporary stopgap. - static constexpr int8_t kNativeBitwidth = 32; -}; - -struct InferVectorLayoutPass - : public impl::InferVectorLayoutPassBase { - InferVectorLayoutPass(int hardware_generation, - std::array target_shape, - TpuTilingFlags tpu_tiling_flags) { - this->hardware_generation = hardware_generation; - this->sublane_count = target_shape[0]; - this->lane_count = target_shape[1]; - this->tpu_tiling_flags = tpu_tiling_flags; - } - void runOnOperation() override { - // Fail if hardware_generation has not been set from the default value. - if (hardware_generation < 0) { - getOperation().emitError("hardware_generation must be set") << hardware_generation; - signalPassFailure(); - return; - } - func::FuncOp func = getOperation(); - VectorLayoutInferer run(hardware_generation, {sublane_count, lane_count}, - tpu_tiling_flags); - if (run.infer(func).failed()) { - signalPassFailure(); - } - } - - TpuTilingFlags tpu_tiling_flags; -}; - -} // namespace - -std::unique_ptr> createInferVectorLayoutPass( - int hardware_generation, std::array target_shape, - const TpuTilingFlags &tpu_tiling_flags) { - return std::make_unique( - hardware_generation, target_shape, tpu_tiling_flags); -} - -} // namespace mlir::tpu diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h deleted file mode 100644 index d240f27fd42d..000000000000 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h +++ /dev/null @@ -1,19 +0,0 @@ -#ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_VECTOR_LAYOUT_EXTENSIONS_H_ -#define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_VECTOR_LAYOUT_EXTENSIONS_H_ - -#include -#include - -#include "mlir/include/mlir/IR/Operation.h" -#include "mlir/include/mlir/Support/LLVM.h" - -namespace mlir::tpu::extensions { - -bool canInferVectorLayout(const Operation &op); - -LogicalResult inferVectorLayout(const Operation &op, - std::array target_shape); - -} // namespace mlir::tpu::extensions - -#endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_VECTOR_LAYOUT_EXTENSIONS_H_ diff --git a/jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.cc b/jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.cc index 949a26a4f593..6930bf7d7ceb 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.cc @@ -13,62 +13,67 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include "jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.h" + +#include +#include #include #include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" -#include "absl/strings/string_view.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h" -#include "mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/include/mlir/Dialect/Math/IR/Math.h" -#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/include/mlir/Dialect/Utils/StaticValueUtils.h" -#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h" -#include "mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" -#include "mlir/include/mlir/IR/AffineMap.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/DialectRegistry.h" -#include "mlir/include/mlir/IR/Matchers.h" -#include "mlir/include/mlir/IR/Operation.h" -#include "mlir/include/mlir/IR/OperationSupport.h" -#include "mlir/include/mlir/IR/PatternMatch.h" -#include "mlir/include/mlir/IR/Types.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/Pass/Pass.h" -#include "mlir/include/mlir/Support/LLVM.h" -#include "mlir/include/mlir/Support/LogicalResult.h" -#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/Transforms/Hoisting.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" +#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" namespace mlir::tpu { -#define GEN_PASS_DECL_LINALGVECTORIZATIONPASS -#define GEN_PASS_DEF_LINALGVECTORIZATIONPASS -#include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc" - namespace { + struct VectorizationPattern : public OpInterfaceRewritePattern { using OpInterfaceRewritePattern::OpInterfaceRewritePattern; LogicalResult matchAndRewrite(linalg::LinalgOp op, - PatternRewriter &rewriter) const override { - return vectorize(rewriter, op, - /*inputVectorSizes=*/{}, - /*inputScalableVecDims=*/{}, - /*vectorizeNDExtract=*/false); + PatternRewriter& rewriter) const override { + FailureOr vectorResults = + vectorize(rewriter, op, + /*inputVectorSizes=*/{}, + /*inputScalableVecDims=*/{}, + /*vectorizeNDExtract=*/false); + if (failed(vectorResults)) { + return failure(); + } + rewriter.replaceOp(op, vectorResults->replacements); + return success(); } }; // Check preconditions for `vector.transfer_read` rewrite patterns. LogicalResult checkPreconditions(vector::TransferReadOp op, - PatternRewriter &rewriter) { + PatternRewriter& rewriter) { if (op.hasOutOfBoundsDim()) { return rewriter.notifyMatchFailure(op, "out of bounds transfer dim"); } @@ -91,15 +96,18 @@ LogicalResult checkPreconditions(vector::TransferReadOp op, vector::TransferReadOp createTransferReadOp(vector::TransferReadOp op, Value source, RankedTensorType source_ty, - PatternRewriter &rewriter) { + PatternRewriter& rewriter) { // We know from preconditions that there are no out of bound dims. SmallVector in_bounds(source_ty.getRank(), true); - return rewriter.create( - op.getLoc(), + auto padding = mlir::arith::ConstantOp::create( + rewriter, op->getLoc(), source_ty.getElementType(), + rewriter.getZeroAttr(source_ty.getElementType())); + return vector::TransferReadOp::create( + rewriter, op.getLoc(), VectorType::get(source_ty.getShape(), source_ty.getElementType()), source, - SmallVector( - source_ty.getRank(), - rewriter.create(op.getLoc(), 0)), + SmallVector(source_ty.getRank(), arith::ConstantIndexOp::create( + rewriter, op.getLoc(), 0)), + padding, // Use padding with source_ty. AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(source_ty.getRank(), op->getContext())), rewriter.getBoolArrayAttr(in_bounds)); @@ -107,11 +115,11 @@ vector::TransferReadOp createTransferReadOp(vector::TransferReadOp op, template LogicalResult matchAndRewriteTransferOfExpandOrCollapseShape( - vector::TransferReadOp op, PatternRewriter &rewriter) { + vector::TransferReadOp op, PatternRewriter& rewriter) { if (failed(checkPreconditions(op, rewriter))) { return failure(); } - auto expand = op.getSource().template getDefiningOp(); + auto expand = op.getBase().template getDefiningOp(); if (!expand) { return rewriter.notifyMatchFailure( op, "not a tensor.expand_shape/collapse_shape"); @@ -137,7 +145,7 @@ struct TransferReadOfExpandShape using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::TransferReadOp op, - PatternRewriter &rewriter) const override { + PatternRewriter& rewriter) const override { return matchAndRewriteTransferOfExpandOrCollapseShape< tensor::ExpandShapeOp>(op, rewriter); } @@ -150,7 +158,7 @@ struct TransferReadOfCollapseShape using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::TransferReadOp op, - PatternRewriter &rewriter) const override { + PatternRewriter& rewriter) const override { return matchAndRewriteTransferOfExpandOrCollapseShape< tensor::CollapseShapeOp>(op, rewriter); } @@ -163,10 +171,10 @@ struct TransferReadOfConstant using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::TransferReadOp op, - PatternRewriter &rewriter) const override { + PatternRewriter& rewriter) const override { DenseElementsAttr constant_elements; Attribute constant_value; - if (matchPattern(op.getSource(), m_Constant(&constant_elements)) && + if (matchPattern(op.getBase(), m_Constant(&constant_elements)) && constant_elements.isSplat()) { constant_value = constant_elements.getSplatValue(); } else { @@ -185,11 +193,11 @@ struct TransferReadOfSelect : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::TransferReadOp op, - PatternRewriter &rewriter) const override { + PatternRewriter& rewriter) const override { if (failed(checkPreconditions(op, rewriter))) { return failure(); } - auto select = op.getSource().getDefiningOp(); + auto select = op.getBase().getDefiningOp(); if (!select) { return rewriter.notifyMatchFailure(op, "source not an arith.select"); } @@ -226,11 +234,11 @@ struct TransferReadOfCmpI : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::TransferReadOp op, - PatternRewriter &rewriter) const override { + PatternRewriter& rewriter) const override { if (failed(checkPreconditions(op, rewriter))) { return failure(); } - auto cmp = op.getSource().getDefiningOp(); + auto cmp = op.getBase().getDefiningOp(); if (!cmp) { return rewriter.notifyMatchFailure(op, "source not an arith.cmpi"); } @@ -257,11 +265,11 @@ struct TransferReadOfSplat : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::TransferReadOp op, - PatternRewriter &rewriter) const override { + PatternRewriter& rewriter) const override { if (failed(checkPreconditions(op, rewriter))) { return failure(); } - auto splat = op.getSource().getDefiningOp(); + auto splat = op.getBase().getDefiningOp(); if (!splat) { return rewriter.notifyMatchFailure(op, "source not a tensor.splat"); } @@ -275,7 +283,7 @@ struct TransferReadOfSplat : public OpRewritePattern { }; // List of operations that are covered by the supports_bf16_alu_instructions. -const auto kSupportedBf16Ops = absl::flat_hash_set( +const auto kSupportedBf16Ops = absl::flat_hash_set( {arith::AddFOp::getOperationName(), arith::SubFOp::getOperationName(), arith::MulFOp::getOperationName(), arith::MaximumFOp::getOperationName(), arith::MinimumFOp::getOperationName()}); @@ -287,12 +295,12 @@ const auto kSupportedBf16Ops = absl::flat_hash_set( class GenericBitwidthConvert : public RewritePattern { public: explicit GenericBitwidthConvert(llvm::StringRef operation_name, - MLIRContext *ctx, + MLIRContext* ctx, bool supports_bf16_alu_instructions) : RewritePattern(operation_name, 0, ctx), supports_bf16_alu_instructions_(supports_bf16_alu_instructions) {} - LogicalResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(Operation* op, + PatternRewriter& rewriter) const override { if (supports_bf16_alu_instructions_ && kSupportedBf16Ops.contains(op->getName().getStringRef())) { return rewriter.notifyMatchFailure(op, "target supports bf16 operands"); @@ -313,8 +321,9 @@ class GenericBitwidthConvert : public RewritePattern { continue; } has_bf16_operand = true; - extended_operands.push_back(rewriter.create( - loc, VectorType::get(operand_type.getShape(), rewriter.getF32Type()), + extended_operands.push_back(arith::ExtFOp::create( + rewriter, loc, + VectorType::get(operand_type.getShape(), rewriter.getF32Type()), operand)); } // If there are no bf16 operands, then we do not need to rewrite the op. @@ -337,7 +346,7 @@ class GenericBitwidthConvert : public RewritePattern { } OperationState state(loc, op->getName().getStringRef(), extended_operands, new_results, op->getAttrs(), op->getSuccessors()); - Operation *new_op = rewriter.create(state); + Operation* new_op = rewriter.create(state); rewriter.replaceOpWithNewOp(op, op->getResultTypes(), new_op->getResults()); return success(); @@ -356,11 +365,11 @@ struct ContractionBitwidthConvert : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - ContractionBitwidthConvert(bool supports_bf16_matmul, MLIRContext *ctx) + ContractionBitwidthConvert(bool supports_bf16_matmul, MLIRContext* ctx) : OpRewritePattern(ctx), supports_bf16_matmul_(supports_bf16_matmul) {} LogicalResult matchAndRewrite(vector::ContractionOp op, - PatternRewriter &rewriter) const override { + PatternRewriter& rewriter) const override { // The ContractionOp contract is that (1) lhs and rhs have same element // type, and (2) the accumulator and result have the same element type. @@ -384,27 +393,27 @@ struct ContractionBitwidthConvert Value lhs = op.getLhs(); Value rhs = op.getRhs(); if (extend_operands) { - lhs = rewriter.create( - op.getLoc(), + lhs = arith::ExtFOp::create( + rewriter, op.getLoc(), VectorType::get(op.getLhsType().getShape(), rewriter.getF32Type()), lhs); - rhs = rewriter.create( - op.getLoc(), + rhs = arith::ExtFOp::create( + rewriter, op.getLoc(), VectorType::get(op.getRhsType().getShape(), rewriter.getF32Type()), rhs); } Value acc = op.getAcc(); if (extend_acc) { - acc = rewriter.create( - op.getLoc(), + acc = arith::ExtFOp::create( + rewriter, op.getLoc(), VectorType::get(acc_ty.getShape(), rewriter.getF32Type()), op.getAcc()); } - vector::ContractionOp contraction = rewriter.create( - op.getLoc(), lhs, rhs, acc, op.getIndexingMaps(), op.getIteratorTypes(), - op.getKind()); + vector::ContractionOp contraction = vector::ContractionOp::create( + rewriter, op.getLoc(), lhs, rhs, acc, op.getIndexingMaps(), + op.getIteratorTypes(), op.getKind()); if (extend_acc) { rewriter.replaceOpWithNewOp( @@ -429,7 +438,7 @@ struct MultiDimReductionBitwidthConvert using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::MultiDimReductionOp op, - PatternRewriter &rewriter) const override { + PatternRewriter& rewriter) const override { // Below we rely on the contract that the source operand, accumulator, and // result have the same element type. auto src_ty = op.getSourceVectorType(); @@ -442,14 +451,14 @@ struct MultiDimReductionBitwidthConvert return rewriter.notifyMatchFailure(op, "not vector reduction"); } - auto reduction = rewriter.create( - op.getLoc(), - rewriter.create( - op.getLoc(), + auto reduction = vector::MultiDimReductionOp::create( + rewriter, op.getLoc(), + arith::ExtFOp::create( + rewriter, op.getLoc(), VectorType::get(src_ty.getShape(), rewriter.getF32Type()), op.getSource()), - rewriter.create( - op.getLoc(), + arith::ExtFOp::create( + rewriter, op.getLoc(), VectorType::get(res_ty.getShape(), rewriter.getF32Type()), op.getAcc()), op.getReductionMask(), op.getKind()); @@ -458,88 +467,76 @@ struct MultiDimReductionBitwidthConvert } }; -struct LinalgVectorizationPass - : public impl::LinalgVectorizationPassBase { - explicit LinalgVectorizationPass( - const LinalgVectorizationPassOptions &options) - : impl::LinalgVectorizationPassBase(options) {} - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); +} // namespace + +void LinalgVectorizationPass::getDependentDialects( + DialectRegistry& registry) const { + registry.insert(); +} + +void LinalgVectorizationPass::runOnOperation() { + auto func = getOperation(); + MLIRContext* ctx = func.getContext(); + + RewritePatternSet patterns(ctx); + patterns.add(ctx); + // Pull in patterns to shuffle broadcast/transpose ops around in order to + // cancel them or embed into contract ops. Embedding in the flexible + // contract ops will help to sustain the structure through various + // transformations. + vector::populateVectorReductionToContractPatterns(patterns); + vector::populateSinkVectorOpsPatterns(patterns); + // Pull in patterns to canonicalize transfer ops. + vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); + vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx); + vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx); + patterns.add(ctx); + // Pull in patterns to convert bf16 ops to f32 ops. + for (::llvm::StringLiteral unary_op_name : + {arith::NegFOp::getOperationName(), math::TanhOp::getOperationName(), + math::ExpOp::getOperationName(), math::AbsFOp::getOperationName(), + math::SinOp::getOperationName(), math::CosOp::getOperationName(), + math::SqrtOp::getOperationName(), math::RsqrtOp::getOperationName(), + math::LogOp::getOperationName(), math::Log1pOp::getOperationName(), + math::RoundOp::getOperationName(), + math::RoundEvenOp::getOperationName()}) { + patterns.add(unary_op_name, ctx, + supports_bf16_alu_instructions); } - void runOnOperation() override { - auto func = getOperation(); - MLIRContext *ctx = func.getContext(); - - RewritePatternSet patterns(ctx); - patterns.add(ctx); - // Pull in patterns to shuffle broadcast/transpose ops around in order to - // cancel them or embed into contract ops. Embedding in the flexible - // contract ops will help to sustain the structure through various - // transformations. - vector::populateVectorReductionToContractPatterns(patterns); - vector::populateSinkVectorOpsPatterns(patterns); - // Pull in patterns to canonicalize transfer ops. - vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); - vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx); - vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx); - patterns.add(ctx); - // Pull in patterns to convert bf16 ops to f32 ops. - for (::llvm::StringLiteral unary_op_name : - {arith::NegFOp::getOperationName(), math::TanhOp::getOperationName(), - math::ExpOp::getOperationName(), math::AbsFOp::getOperationName(), - math::SinOp::getOperationName(), math::CosOp::getOperationName(), - math::SqrtOp::getOperationName(), math::RsqrtOp::getOperationName(), - math::LogOp::getOperationName(), math::Log1pOp::getOperationName(), - math::RoundOp::getOperationName(), - math::RoundEvenOp::getOperationName()}) { - patterns.add(unary_op_name, ctx, - supports_bf16_alu_instructions); - } - for (::llvm::StringLiteral binary_op_name : - {arith::MulFOp::getOperationName(), arith::DivFOp::getOperationName(), - arith::AddFOp::getOperationName(), arith::SubFOp::getOperationName(), - arith::MaximumFOp::getOperationName(), - arith::MinimumFOp::getOperationName(), - math::PowFOp::getOperationName()}) { - patterns.add(binary_op_name, ctx, - supports_bf16_alu_instructions); - } - for (::llvm::StringLiteral ternary_op_name : - {arith::SelectOp::getOperationName()}) { - patterns.add(ternary_op_name, ctx, - supports_bf16_alu_instructions); - } - patterns.add(supports_bf16_matmul, ctx); - patterns.add(ctx); - - // We do not want to apply the vector patterns above to the ops that are - // unrelated to the original linalg op. - SmallVector linalgOps; - func.walk([&](Operation *op) { - if (dyn_cast(op) || dyn_cast(op) || - dyn_cast(op) || - dyn_cast(op) || - dyn_cast(op) || - dyn_cast(op)) { - linalgOps.push_back(op); - } - }); - if (failed(applyOpPatternsAndFold(linalgOps, std::move(patterns)))) { - return signalPassFailure(); + for (::llvm::StringLiteral binary_op_name : + {arith::MulFOp::getOperationName(), arith::DivFOp::getOperationName(), + arith::AddFOp::getOperationName(), arith::SubFOp::getOperationName(), + arith::MaximumFOp::getOperationName(), + arith::MinimumFOp::getOperationName(), + math::PowFOp::getOperationName()}) { + patterns.add(binary_op_name, ctx, + supports_bf16_alu_instructions); + } + for (::llvm::StringLiteral ternary_op_name : + {arith::SelectOp::getOperationName()}) { + patterns.add(ternary_op_name, ctx, + supports_bf16_alu_instructions); + } + patterns.add(supports_bf16_matmul, ctx); + patterns.add(ctx); + + // We do not want to apply the vector patterns above to the ops that are + // unrelated to the original linalg op. + SmallVector linalgOps; + func.walk([&](Operation* op) { + if (dyn_cast(op) || dyn_cast(op) || + dyn_cast(op) || + dyn_cast(op) || + dyn_cast(op) || + dyn_cast(op)) { + linalgOps.push_back(op); } + }); + if (failed(applyOpPatternsAndFold(linalgOps, std::move(patterns)))) { + return signalPassFailure(); } -}; - -} // namespace - -std::unique_ptr> createLinalgVectorizationPass( - bool supports_bf16_alu_instructions, bool supports_bf16_matmul) { - LinalgVectorizationPassOptions options; - options.supports_bf16_alu_instructions = supports_bf16_alu_instructions; - options.supports_bf16_matmul = supports_bf16_matmul; - return std::make_unique(options); } } // namespace mlir::tpu diff --git a/jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.h b/jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.h new file mode 100644 index 000000000000..2a392f37e76c --- /dev/null +++ b/jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.h @@ -0,0 +1,91 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_LINALG_VECTORIZATION_H_ +#define JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_LINALG_VECTORIZATION_H_ + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include "jaxlib/mosaic/pass_boilerplate.h" + +namespace mlir::tpu { + +struct LinalgVectorizationPassOptions { + bool supports_bf16_alu_instructions = false; + bool supports_bf16_matmul = false; +}; + +struct LinalgVectorizationPass + : public jaxlib::mlir::Pass { + using jaxlib::mlir::Pass::Pass; + + static constexpr llvm::StringLiteral kArgumentName = "linalg-vectorization"; + static constexpr llvm::StringLiteral kPassName = "LinalgVectorizationPass"; + + LinalgVectorizationPass() = default; + + explicit LinalgVectorizationPass(LinalgVectorizationPassOptions options) { + supports_bf16_alu_instructions = options.supports_bf16_alu_instructions; + supports_bf16_matmul = options.supports_bf16_matmul; + } + + LinalgVectorizationPass(const LinalgVectorizationPass& other) { + supports_bf16_alu_instructions = other.supports_bf16_alu_instructions; + supports_bf16_matmul = other.supports_bf16_matmul; + } + + LinalgVectorizationPass& operator=(const LinalgVectorizationPass& other) { + supports_bf16_alu_instructions = other.supports_bf16_alu_instructions; + supports_bf16_matmul = other.supports_bf16_matmul; + return *this; + } + + void getDependentDialects(DialectRegistry& registry) const override; + void runOnOperation() override; + + protected: + ::mlir::Pass::Option supports_bf16_alu_instructions{ + *this, "supports-bf16-alu-instructions", llvm::cl::desc("")}; + ::mlir::Pass::Option supports_bf16_matmul{*this, "supports-bf16-matmul", + llvm::cl::desc("")}; +}; + +inline std::unique_ptr<::mlir::Pass> createLinalgVectorizationPass( + bool supports_bf16_alu_instructions = false, + bool supports_bf16_matmul = false) { + return std::make_unique( + LinalgVectorizationPassOptions{ + .supports_bf16_alu_instructions = supports_bf16_alu_instructions, + .supports_bf16_matmul = supports_bf16_matmul, + }); +} + +inline std::unique_ptr<::mlir::Pass> createLinalgVectorizationPass( + LinalgVectorizationPassOptions options) { + return std::make_unique(std::move(options)); +} + +inline void registerLinalgVectorizationPass() { + registerPass([]() -> std::unique_ptr<::mlir::Pass> { + return createLinalgVectorizationPass(); + }); +} + +} // namespace mlir::tpu + +#endif // JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_LINALG_VECTORIZATION_H_ diff --git a/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc b/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc deleted file mode 100644 index b73ea0f1250f..000000000000 --- a/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc +++ /dev/null @@ -1,114 +0,0 @@ -/* Copyright 2023 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "absl/log/check.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/Support/LLVM.h" -#include "mlir/include/mlir/Support/LogicalResult.h" -#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" - -namespace mlir { -namespace tpu { - -namespace { - -MemRefType updateMemorySpace(MemRefType ty, Attribute memory_space) { - return MemRefType::get(ty.getShape(), ty.getElementType(), ty.getLayout(), - memory_space); -} - -MemRefType updateMemorySpace(MemRefType ty, MemorySpace memory_space) { - return updateMemorySpace(ty, - MemorySpaceAttr::get(ty.getContext(), memory_space)); -} - -} // namespace - -LogicalResult specializeMemorySpace(TypedValue value, - MemorySpace memory_space) { - MemorySpaceAttr attr = - dyn_cast_if_present(value.getType().getMemorySpace()); - if (!attr) { - return failure(); - } - MemorySpace current_memory_space = attr.getValue(); - if (current_memory_space == memory_space) { - return success(); // Nothing to do here. - } else if (current_memory_space != MemorySpace::kAny) { - return failure(); // Memory space mismatch! - } - value.setType(updateMemorySpace(value.getType(), memory_space)); - std::vector to_update(value.getUsers().begin(), - value.getUsers().end()); - auto updateResultFrom = [&](Operation* op, MemRefType ty) { - Attribute source_memory_space = ty.getMemorySpace(); - CHECK_EQ(op->getNumResults(), 1); - Value result = op->getResult(0); - MemRefType result_type = cast(result.getType()); - if (result_type.getMemorySpace() != source_memory_space) { - result.setType(updateMemorySpace(result_type, source_memory_space)); - to_update.insert(to_update.end(), result.getUsers().begin(), - result.getUsers().end()); - } - }; - while (!to_update.empty()) { - Operation* some_op = to_update.back(); - to_update.pop_back(); - // Here we only have to handle the operations allowed on refs with - // unspecified memory space. - if (auto op = dyn_cast(some_op)) { - updateResultFrom(op, op.getInput().getType()); - continue; - } - if (auto op = dyn_cast(some_op)) { - updateResultFrom(op, op.getMemRef().getType()); - continue; - } - if (auto op = dyn_cast(some_op)) { - updateResultFrom(op, op.getInput().getType()); - continue; - } - if (auto op = dyn_cast(some_op)) { - updateResultFrom(op, op.getInput().getType()); - continue; - } - if (auto op = dyn_cast(some_op)) { - updateResultFrom(op, op.getInput().getType()); - continue; - } - if (auto op = dyn_cast(some_op)) { - updateResultFrom(op, op.getOperand().getType()); - continue; - } - if (auto op = dyn_cast(some_op)) { - continue; // Nothing to do. - } - if (auto op = dyn_cast(some_op)) { - continue; // Nothing to do. - } - if (auto op = dyn_cast(some_op)) { - continue; // Nothing to do. - } - some_op->emitOpError( - "Failed to propagate memory space update through this operation"); - return failure(); - } - return success(); -} - -} // namespace tpu -} // namespace mlir diff --git a/jaxlib/mosaic/dialect/tpu/transforms/relayout_insertion.cc b/jaxlib/mosaic/dialect/tpu/transforms/relayout_insertion.cc deleted file mode 100644 index b88504e35068..000000000000 --- a/jaxlib/mosaic/dialect/tpu/transforms/relayout_insertion.cc +++ /dev/null @@ -1,196 +0,0 @@ -#include -#include -#include -#include - -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/Types.h" -#include "mlir/IR/Value.h" -#include "mlir/IR/Visitors.h" -#include "mlir/Pass/Pass.h" -#include "absl/log/check.h" -#include "llvm/include/llvm/ADT/STLExtras.h" -#include "llvm/include/llvm/Support/MathExtras.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/IR/Builders.h" -#include "mlir/include/mlir/IR/Diagnostics.h" -#include "mlir/include/mlir/Support/LLVM.h" -#include "jaxlib/mosaic/dialect/tpu/layout.h" -#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" -#include "jaxlib/mosaic/dialect/tpu/util.h" - -namespace mlir::tpu { - -#define GEN_PASS_DECL_RELAYOUTINSERTIONPASS -#define GEN_PASS_DEF_RELAYOUTINSERTIONPASS -#include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc" - -namespace { - -FailureOr> relayout( - OpBuilder &builder, TypedValue v, VectorLayout src, - VectorLayout dst, int hardware_generation, - const std::array target_shape) { - // change bitwidth - if (v.getType().getElementType() == builder.getI1Type() && - // TODO(jevinjiang): for other relayout changes (tiling, offsets, implicit - // dim), we currently rely on apply-vector-layout pass to do the relayout. - src.bitwidth() != dst.bitwidth()) { - auto vreg_slice = src.vregSlice(target_shape, dst.bitwidth(), src.tiling()); - auto dst_bitwidth_layout = VectorLayout( - dst.bitwidth(), - { - src.offsets()[0].has_value() ? *src.offsets()[0] % vreg_slice[0] - : LayoutOffset(), - src.offsets()[1].has_value() ? *src.offsets()[1] % vreg_slice[1] - : LayoutOffset(), - }, - src.tiling(), src.implicit_dim()); - if (!dst_bitwidth_layout.isValid(target_shape)) { - return emitError(v.getLoc(), - "Not implemented: failed to infer valid layout during " - "relayout, got ") - << dst_bitwidth_layout; - } - // We might be able to pack mask directly. - // TODO(jevinjiang): Add support for 16bit -> 8bit mask packing. - if (src.bitwidth() == 32 && dst.bitwidth() == 16 && - // TODO(jevinjiang): support mask packing for non-native source tiling. - src.tiling()[0] == src.packing() * target_shape[0] && - src.tiling()[1] == target_shape[1]) { - auto relayout_op = - builder.create(v.getLoc(), v.getType(), v); - setLayout(relayout_op, src, dst_bitwidth_layout); - return cast>(relayout_op.getResult()); - } - CHECK(llvm::isPowerOf2_32(src.bitwidth())); - CHECK(llvm::isPowerOf2_32(dst.bitwidth())); - auto make_vty = [&](int bitwidth) { - return VectorType::get(v.getType().getShape(), - builder.getIntegerType(bitwidth)); - }; - auto make_constant = [&](int val, VectorLayout layout) { - auto vty = make_vty(layout.bitwidth()); - auto constant_op = builder.create( - v.getLoc(), - DenseElementsAttr::get( - vty, builder.getIntegerAttr(vty.getElementType(), val))); - setOutLayout(constant_op, - VectorLayout(layout.bitwidth(), {std::nullopt, std::nullopt}, - layout.tiling(), layout.implicit_dim())); - return constant_op; - }; - auto src_int_vty = make_vty(src.bitwidth()); - auto dst_int_vty = make_vty(dst.bitwidth()); - // TODO(jevinjiang): Since dst_bitwidth_layout will be firstly used in the - // extSI or truncI below, we can reuse the inferExt and inferTrunc from - // infer-vector-layout pass. - auto ext_op = builder.create(v.getLoc(), src_int_vty, v); - setLayout(ext_op, src, src); - - // TODO(jevinjiang): some conversion might not be supported in HW. - Operation *cast_op = - dst.bitwidth() > src.bitwidth() - ? builder.create(v.getLoc(), dst_int_vty, ext_op) - // TODO(jevinjiang): HW may support pack vmask directly. - : builder.create(v.getLoc(), dst_int_vty, ext_op); - setLayout(cast_op, src, dst_bitwidth_layout); - - auto cmp_op = builder.create( - v.getLoc(), v.getType(), arith::CmpIPredicate::ne, - cast_op->getResult(0), make_constant(0, dst_bitwidth_layout)); - setLayout(cmp_op, {dst_bitwidth_layout, dst_bitwidth_layout}, - dst_bitwidth_layout); - return cast>(cmp_op.getResult()); - } - return v; -} - -// TODO(jevinjiang): make relayout to an op so we don't need decide when to -// relayout in apply-vector-layout pass. -LogicalResult insertRelayout(Operation &op, int hardware_generation, - const std::array target_shape) { - FAILUREOR_ASSIGN_OR_RETURN(const SmallVector in_layouts, - getInLayouts(op, target_shape)); - if (in_layouts.size() != op.getNumOperands()) { - return op.emitError("Expected the same number of operands as in_layouts"); - } - if (isa(op)) { - return success(); - } - // Relayout the operands, if their requested input layouts don't match the - // layouts in which they were produced. - for (auto [idx, tup] : - llvm::enumerate(llvm::zip(op.getOperands(), in_layouts))) { - auto [operand, li] = tup; - auto vector_operand = dyn_cast>(operand); - TPU_ASSERT_EQ_OP(vector_operand != nullptr, li.has_value()); - if (vector_operand == nullptr) { - continue; - } - // The operand should always be an Operation (and not a BlockArgument) - // since we expect the FuncOp to have only memrefs and semaphores as - // arguments. - auto op_result = dyn_cast(vector_operand); - if (op_result == nullptr) { - return op.emitError("Expected vector operand to be an operation result"); - } - Operation *const def_op = op_result.getOwner(); - DCHECK(def_op); - const unsigned res_idx = op_result.getResultNumber(); - FAILUREOR_ASSIGN_OR_RETURN(const SmallVector def_layouts, - getOutLayouts(*def_op, target_shape)); - const Layout lo = def_layouts[res_idx]; - TPU_ASSERT_OP(lo.has_value()); - if (*lo == *li) { - continue; - } - OpBuilder builder(&op); - FAILUREOR_ASSIGN_OR_RETURN( - Value new_v, relayout(builder, vector_operand, /*src=*/*lo, - /*dst=*/*li, hardware_generation, target_shape)); - op.setOperand(idx, new_v); - } - return success(); -} - -struct RelayoutInsertionPass - : public impl::RelayoutInsertionPassBase { - RelayoutInsertionPass(int generation, std::array target_shape) { - this->hardware_generation = generation; - this->sublane_count = target_shape[0]; - this->lane_count = target_shape[1]; - } - void runOnOperation() override { - // Fail if hardware_generation has not been set from the default value. - if (hardware_generation < 0) { - getOperation().emitError("hardware_generation must be set"); - signalPassFailure(); - return; - } - func::FuncOp func = getOperation(); - auto result = func.walk([&](Operation *op) { - if (insertRelayout(*op, hardware_generation, {sublane_count, lane_count}) - .failed()) { - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - if (result.wasInterrupted()) { - signalPassFailure(); - return; - } - } -}; - -} // namespace - -std::unique_ptr> createRelayoutInsertionPass( - int hardware_generation, std::array target_shape) { - return std::make_unique(hardware_generation, - target_shape); -} - -} // namespace mlir::tpu \ No newline at end of file diff --git a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc index 0981c263d252..2eec29ee1cec 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc @@ -18,19 +18,17 @@ limitations under the License. #include #include +#include "llvm/ADT/StringMap.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/Value.h" #include "mlir/IR/Visitors.h" #include "mlir/Support/LLVM.h" -#include "llvm/include/llvm/ADT/StringMap.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/OpDefinition.h" -#include "mlir/include/mlir/IR/OperationSupport.h" -#include "mlir/include/mlir/Support/LogicalResult.h" +#include "mlir/Support/LogicalResult.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/serde.h" @@ -42,11 +40,46 @@ constexpr StringRef kMangledDialect = "stable_mosaic."; constexpr StringRef kVersionAttrName = "stable_mosaic.version"; // When this is bumped, we should file a TODO to update the forward-compatible // version in tpu_custom_call.py in a month! -constexpr int kVersion = 3; +constexpr int kVersion = 9; using SerdeRuleType = jaxlib::mosaic::SerdeRuleType; -LogicalResult enqueue_dma_upgrade(Operation* op, int version) { +LogicalResult dynamic_gather_upgrade(Operation* op, int version, bool&) { + if (version < 5) { + auto dimension_attr = op->getAttrOfType("dimension"); + if (!dimension_attr || dimension_attr.getValue().getBitWidth() != 32) { + return op->emitError("Missing or invalid dimension attribute"); + } + const int32_t dimension = dimension_attr.getInt(); + op->removeAttr("dimension"); + op->setAttr("dimensions", + DenseI32ArrayAttr::get(op->getContext(), {dimension})); + } + return success(); +} + +LogicalResult dynamic_gather_downgrade(Operation* op, int version, bool&) { + if (version < 5) { + auto dimensions_attr = op->getAttrOfType("dimensions"); + if (!dimensions_attr) { + return op->emitError("Missing or invalid dimensions attribute"); + } + const ArrayRef dimensions = dimensions_attr.asArrayRef(); + if (dimensions.size() != 1) { + return op->emitError( + "Can only downgrade below version 5 when a single dimension is " + "specified."); + } + const int32_t dimension = dimensions.front(); + op->removeAttr("dimensions"); + op->setAttr("dimension", + mlir::IntegerAttr::get( + mlir::IntegerType::get(op->getContext(), 32), dimension)); + } + return success(); +} + +LogicalResult enqueue_dma_upgrade(Operation* op, int version, bool&) { // Added AttrSizedOperandSegments and core_id in version 2. if (version < 2) { if (op->getNumOperands() == 3) { // Local DMA. @@ -64,17 +97,106 @@ LogicalResult enqueue_dma_upgrade(Operation* op, int version) { << op->getNumOperands(); } } + if (version < 4) { + op->setAttr("priority", + mlir::IntegerAttr::get( + mlir::IntegerType::get(op->getContext(), 32), 0)); + } return success(); } -LogicalResult enqueue_dma_downgrade(Operation* op, int version) { +LogicalResult enqueue_dma_downgrade(Operation* op, int version, bool&) { + if (version < 8) { + auto ordering_attr = op->getAttrOfType("strict_ordering"); + if (ordering_attr != nullptr) { + if (ordering_attr.getValue()) { + return op->emitError( + "Can only downgrade below version 8 when strict ordering is not " + "set to True"); + } + op->removeAttr("strict_ordering"); + } + } + if (version < 4) { + op->removeAttr("priority"); + } if (version < 2) { return op->emitError("Downgrade to version ") << version << " unsupported"; } return success(); } -LogicalResult semaphore_signal_upgrade(Operation* op, int version) { +LogicalResult iota_upgrade(Operation* op, int version, bool&) { + if (version < 6) { + auto dimension_attr = op->getAttrOfType("dimension"); + if (!dimension_attr || dimension_attr.getValue().getBitWidth() != 32) { + return op->emitError("Missing or invalid dimension attribute"); + } + const int32_t dimension = dimension_attr.getInt(); + op->removeAttr("dimension"); + op->setAttr("dimensions", + DenseI32ArrayAttr::get(op->getContext(), {dimension})); + } + return success(); +} + +LogicalResult iota_downgrade(Operation* op, int version, bool&) { + if (version < 6) { + auto dimensions_attr = op->getAttrOfType("dimensions"); + if (!dimensions_attr) { + return op->emitError("Missing or invalid dimensions attribute"); + } + const ArrayRef dimensions = dimensions_attr.asArrayRef(); + if (dimensions.size() != 1) { + return op->emitError( + "Can only downgrade below version 5 when a single dimension is " + "specified."); + } + const int32_t dimension = dimensions.front(); + op->removeAttr("dimensions"); + op->setAttr("dimension", + mlir::IntegerAttr::get( + mlir::IntegerType::get(op->getContext(), 32), dimension)); + } + return success(); +} + +LogicalResult wait_dma2_upgrade(Operation* op, int version, bool&) { + if (version < 7) { + if (op->getNumOperands() != 3) { + return op->emitError("Unexpected operand count in tpu.wait_dma2: ") + << op->getNumOperands(); + } + op->setAttr( + OpTrait::AttrSizedOperandSegments< + EnqueueDMAOp>::getOperandSegmentSizeAttr(), + mlir::DenseI32ArrayAttr::get(op->getContext(), {1, 1, 1, 0, 0})); + } + return success(); +} + +LogicalResult wait_dma2_downgrade(Operation* op, int version, bool&) { + if (version < 7) { + auto operands = op->getAttrOfType( + OpTrait::AttrSizedOperandSegments< + EnqueueDMAOp>::getOperandSegmentSizeAttr()); + if (!operands || operands.size() != 5) { + return op->emitError("Missing or invalid AttrSizedOperandSegments"); + } + if (operands[3] || operands[4]) { + return op->emitError("Downgrade to version ") + << version << " impossible: device_id and/or core_id is set"; + } + op->removeAttr(OpTrait::AttrSizedOperandSegments< + EnqueueDMAOp>::getOperandSegmentSizeAttr()); + } + if (version < 3) { + return op->emitError("Downgrade to version ") << version << " unsupported"; + } + return success(); +} + +LogicalResult semaphore_signal_upgrade(Operation* op, int version, bool&) { // Added AttrSizedOperandSegments and core_id in version 2. if (version < 2) { if (op->getNumOperands() == 2) { // Local signal. @@ -92,7 +214,7 @@ LogicalResult semaphore_signal_upgrade(Operation* op, int version) { return success(); } -LogicalResult semaphore_signal_downgrade(Operation* op, int version) { +LogicalResult semaphore_signal_downgrade(Operation* op, int version, bool&) { if (version < 2) { auto operands = op->getAttrOfType( OpTrait::AttrSizedOperandSegments< @@ -110,7 +232,8 @@ LogicalResult semaphore_signal_downgrade(Operation* op, int version) { return success(); } -LogicalResult vector_multi_dim_reduce_upgrade(Operation* op, int version) { +LogicalResult vector_multi_dim_reduce_upgrade(Operation* op, int version, + bool&) { // Changed reductions_dims from ArrayAttr of IntegerAttrs to DenseI64ArrayAttr // in version 3. if (version < 3) { @@ -138,7 +261,8 @@ LogicalResult vector_multi_dim_reduce_upgrade(Operation* op, int version) { return success(); } -LogicalResult vector_multi_dim_reduce_downgrade(Operation* op, int version) { +LogicalResult vector_multi_dim_reduce_downgrade(Operation* op, int version, + bool&) { if (version < 3) { return op->emitError("Downgrade to version ") << version << " unsupported"; } @@ -148,15 +272,22 @@ LogicalResult vector_multi_dim_reduce_downgrade(Operation* op, int version) { const llvm::StringMap& upgrade_rules() { static auto rules = new llvm::StringMap{ {EnqueueDMAOp::getOperationName(), enqueue_dma_upgrade}, + {WaitDMA2Op::getOperationName(), wait_dma2_upgrade}, + {DynamicGatherOp::getOperationName(), dynamic_gather_upgrade}, + {IotaOp::getOperationName(), iota_upgrade}, {SemaphoreSignalOp::getOperationName(), semaphore_signal_upgrade}, {vector::MultiDimReductionOp::getOperationName(), - vector_multi_dim_reduce_upgrade}}; + vector_multi_dim_reduce_upgrade}, + }; return *rules; } const llvm::StringMap& downgrade_rules() { static auto rules = new llvm::StringMap{ {EnqueueDMAOp::getOperationName(), enqueue_dma_downgrade}, + {WaitDMA2Op::getOperationName(), wait_dma2_downgrade}, + {DynamicGatherOp::getOperationName(), dynamic_gather_downgrade}, + {IotaOp::getOperationName(), iota_downgrade}, {SemaphoreSignalOp::getOperationName(), semaphore_signal_downgrade}, {vector::MultiDimReductionOp::getOperationName(), vector_multi_dim_reduce_downgrade}}; @@ -180,7 +311,8 @@ void MosaicSerdePass::runOnOperation() { {.dialect_prefix = kMangledDialect, .highest_version = kVersion, .version_attr_name = kVersionAttrName, - .serialize_version = serialize_version}))) { + .serialize_version = serialize_version}, + /*keep_version_attr=*/keep_version_attr))) { signalPassFailure(); } } diff --git a/jaxlib/mosaic/dialect/tpu/transforms/serde.h b/jaxlib/mosaic/dialect/tpu/transforms/serde.h index 8685918d3b39..574443e310b2 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/serde.h +++ b/jaxlib/mosaic/dialect/tpu/transforms/serde.h @@ -1,15 +1,31 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_SERDE_H_ #define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_SERDE_H_ #include #include -#include "llvm/include/llvm/ADT/StringRef.h" -#include "llvm/include/llvm/Support/CommandLine.h" -#include "mlir/include/mlir/Interfaces/DataLayoutInterfaces.h" -#include "mlir/include/mlir/Pass/Pass.h" -#include "mlir/include/mlir/Pass/PassRegistry.h" -#include "jaxlib/pass_boilerplate.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/CommandLine.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "jaxlib/mosaic/pass_boilerplate.h" namespace mlir::tpu { @@ -46,6 +62,8 @@ struct MosaicSerdePass : public jaxlib::mlir::Pass { protected: ::mlir::Pass::Option serialize{*this, "serialize", llvm::cl::desc("")}; + ::mlir::Pass::Option keep_version_attr{ + *this, "keep-version-attr", llvm::cl::desc(""), llvm::cl::init(true)}; ::mlir::Pass::Option target_version{*this, "target-version", llvm::cl::desc("")}; }; diff --git a/jaxlib/mosaic/dialect/tpu/util.cc b/jaxlib/mosaic/dialect/tpu/util.cc index 651cef85f740..86f51a8ae170 100644 --- a/jaxlib/mosaic/dialect/tpu/util.cc +++ b/jaxlib/mosaic/dialect/tpu/util.cc @@ -15,23 +15,30 @@ limitations under the License. #include "jaxlib/mosaic/dialect/tpu/util.h" +#include #include +#include #include #include #include #include #include -#include "llvm/Support/MathExtras.h" #include "absl/log/check.h" #include "absl/types/span.h" -#include "llvm/include/llvm/Support/raw_ostream.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/IR/ValueRange.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/MathExtras.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Support/LLVM.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" @@ -45,23 +52,65 @@ std::ostream &operator<<(std::ostream &os, Print p) { return os; } -SmallVector ComputeTileStrides(MemRefType memref_ty, +SmallVector ComputeTileStrides(absl::Span shape, absl::Span tiling) { - SmallVector tile_strides(memref_ty.getRank()); + CHECK_LE(tiling.size(), shape.size()); + SmallVector tile_strides(shape.size()); int64_t stride = 1; - for (int64_t i = 0; i < memref_ty.getRank(); ++i) { - int64_t idx = memref_ty.getRank() - 1 - i; - int64_t tiling_idx = tiling.size() - 1 - i; + for (size_t i = 0; i < shape.size(); ++i) { + const size_t idx = shape.size() - 1 - i; tile_strides[idx] = stride; - if (tiling_idx >= 0) { - stride *= llvm::divideCeil(memref_ty.getShape()[idx], tiling[tiling_idx]); + if (i < tiling.size()) { + const size_t tiling_idx = tiling.size() - 1 - i; + stride *= llvm::divideCeil(shape[idx], tiling[tiling_idx]); } else { - stride *= memref_ty.getShape()[idx]; + stride *= shape[idx]; } } return tile_strides; } +FailureOr> computeSqueezedDimsChecked( + Operation *op, ArrayRef source_shape, + ArrayRef target_shape) { + SmallVector squeezed; + int source_index = source_shape.size() - 1; + int target_index = target_shape.size() - 1; + + while (source_index >= 0 || target_index >= 0) { + int64_t target_dim = (target_index >= 0) ? target_shape[target_index] : -1; + if (source_index < 0) { + op->emitError() << llvm::formatv( + "Target shape is not valid. Source: {0}, Target: {1}.", + shapeToString(source_shape), shapeToString(target_shape)); + return failure(); + } + int64_t source_dim = source_shape[source_index]; + if (source_dim == target_dim) { + source_index--; + target_index--; + } else { + if (source_dim != 1) { + op->emitError() << llvm::formatv( + "Target shape is not valid. Source: {0}, Target: {1}.", + shapeToString(source_shape), shapeToString(target_shape)); + return failure(); + } + squeezed.push_back(source_index); + source_index--; + } + } + + if (source_index != -1 || target_index != -1) { + op->emitError() << "Shape mismatch after traversal. Source shape: " + << shapeToString(source_shape) + << ", target shape: " << shapeToString(target_shape); + return failure(); + } + std::reverse(squeezed.begin(), squeezed.end()); + return squeezed; +} + std::optional> isTransposedMatmul( DotDimensionNumbersAttr dim_numbers) { auto lhs_contracting_dims = dim_numbers.getLhsContractingDims(); @@ -130,7 +179,7 @@ bool canReinterpretToUntiledMemref(TypedValue tiled_memref, return false; } auto rank = tiled_memref_ty.getRank(); - auto packing = 32 / tiled_memref_ty.getElementTypeBitWidth(); + auto packing = 32 / getElementTypeBitwidth(tiled_memref_ty); if (tiled_memref_ty.isDynamicDim(rank - 1)) { // TODO(jevinjiang): we can still allow the minormost padding if we know the // max bound of the dynamic size is not larger than the target_shape[1]. @@ -158,6 +207,17 @@ bool canReinterpretToUntiledMemref(TypedValue tiled_memref, *(tiled_layout.getTileStrides().end() - 2) == 1; } +bool isContiguousMemref(TypedValue memref) { + auto memref_ty = getMemRefType(memref); + if (auto tiled_layout = + dyn_cast(memref_ty.getLayout())) { + auto contiguous_tile_strides = ComputeTileStrides( + memref_ty, tiled_layout.getTiles().front().dimensions()); + return contiguous_tile_strides == tiled_layout.getTileStrides(); + } + return true; +} + bool HasMemorySpace(MemRefType ty, tpu::MemorySpace space) { auto memory_space = dyn_cast_or_null(ty.getMemorySpace()); @@ -177,7 +237,7 @@ bool layoutIsValidForValue(const Layout &l, const Value v, if (!vty.getElementType().isIntOrFloat()) { return false; } - const int8_t bitwidth = vty.getElementTypeBitWidth(); + const int8_t bitwidth = getElementTypeBitwidth(vty); if (bitwidth != l->bitwidth() && bitwidth != 1) { return false; } @@ -209,7 +269,9 @@ FailureOr> getOutLayouts( FAILUREOR_ASSIGN_OR_RETURN(const SmallVector out_layouts, getLayoutArrayFromAttr(op.getAttr("out_layout"))); if (out_layouts.size() != op.getNumResults()) { - return op.emitOpError("out_layout size does not match number of results"); + return op.emitOpError("out_layout size (") + << out_layouts.size() << ") does not match number of results (" + << op.getNumResults() << ")"; } for (const auto [l, res] : llvm::zip_equal(out_layouts, op.getResults())) { if (!layoutIsValidForValue(l, res, target_shape)) { @@ -224,7 +286,9 @@ FailureOr> getInLayouts( FAILUREOR_ASSIGN_OR_RETURN(const SmallVector in_layouts, getLayoutArrayFromAttr(op.getAttr("in_layout"))); if (in_layouts.size() != op.getNumOperands()) { - return op.emitOpError("in_layout size does not match number of operands"); + return op.emitOpError("in_layout size (") + << in_layouts.size() << ") does not match number of operands (" + << op.getNumOperands() << ")"; } for (const auto [l, operand] : llvm::zip_equal(in_layouts, op.getOperands())) { @@ -274,4 +338,46 @@ void setLayout(Operation *op, ArrayRef in, ArrayRef out) { setInLayout(op, in); setOutLayout(op, out); } + +std::optional getIntConst(Value v) { + if (auto const_op = v.getDefiningOp()) { + if (auto cst_attr = dyn_cast(const_op.getValue())) { + return cst_attr.getValue().getSExtValue(); + } + } + return std::nullopt; +} + +SmallVector getNontrivialTransitiveUsers(Value v) { + auto isUnaryElementwise = [](Operation *op) { + if (!op->hasTrait()) { + return false; + } + return op->getNumOperands() == 1 && op->getNumResults() == 1; + }; + SmallVector users; + SmallVector candidates; + candidates.push_back(v); + while (!candidates.empty()) { + Value candidate = candidates.back(); + candidates.pop_back(); + for (const auto &user : candidate.getUsers()) { + if (isa(user) || isUnaryElementwise(user)) + candidates.push_back(user->getResult(0)); + else + users.push_back(user); + } + } + return users; +} + +bool hasVectorOperandsOrResults(Operation& op) { + for (Value value : llvm::concat(op.getOperands(), op.getResults())) { + if (isa(value.getType())) { + return true; + } + } + return false; +} + } // namespace mlir::tpu diff --git a/jaxlib/mosaic/dialect/tpu/util.h b/jaxlib/mosaic/dialect/tpu/util.h index 2e19cb820b5b..8a09667b237b 100644 --- a/jaxlib/mosaic/dialect/tpu/util.h +++ b/jaxlib/mosaic/dialect/tpu/util.h @@ -1,3 +1,18 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_UTIL_H_ #define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_UTIL_H_ @@ -10,22 +25,22 @@ #include #include +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/types/span.h" +#include "llvm/Support/Compiler.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Location.h" +#include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "absl/status/status.h" -#include "absl/types/span.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/Diagnostics.h" -#include "mlir/include/mlir/IR/Value.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/platform/statusor.h" // TODO: Instead of CHECK_EQs, can we do something like TF_RET_CHECK but with // MLIR diagnostics? @@ -147,7 +162,7 @@ class Print { std::ostream &operator<<(std::ostream &os, Print p); template -FailureOr getTypeBitwidth(Type ty) { +int8_t getTypeBitwidth(Type ty) { if (auto integer_ty = dyn_cast(ty)) { const unsigned width = integer_ty.getWidth(); if constexpr (adjust_bool) { @@ -157,13 +172,25 @@ FailureOr getTypeBitwidth(Type ty) { return width; } } - if (isa(ty)) { - return ty.getIntOrFloatBitWidth(); + if (isa(ty)) { + return 8; } - return emitError(UnknownLoc::get(ty.getContext()), - "Unsupported type in mosaic dialect: ") - << ty; + return ty.getIntOrFloatBitWidth(); +} + +// Returns the bitwidth of the element type. The function works for both +// scalar and vector types. +template +inline int8_t getElementTypeBitwidth(Type ty) { + if (auto vty = dyn_cast(ty)) { + return getTypeBitwidth(vty.getElementType()); + } + return getTypeBitwidth(ty); +} + +template +inline int8_t getElementTypeBitwidth(MemRefType ty) { + return getElementTypeBitwidth(ty.getElementType()); } template @@ -171,11 +198,6 @@ ArrayRef> toArrayRef(absl::Span span) { return ArrayRef>(span.data(), span.size()); } -inline arith::ConstantOp IdxConst(int64_t idx, OpBuilder &builder, - Location loc) { - return builder.create(loc, builder.getIndexType(), - builder.getIndexAttr(idx)); -} // Debug only util. template @@ -192,8 +214,22 @@ std::string shapeToString(const T &shape) { return os.str(); } -SmallVector ComputeTileStrides(MemRefType memref_ty, +SmallVector ComputeTileStrides(absl::Span shape, absl::Span tiling); + +inline SmallVector ComputeTileStrides( + MemRefType memref_ty, absl::Span tiling) { + absl::Span shape(memref_ty.getShape().data(), + memref_ty.getShape().size()); + return ComputeTileStrides(shape, tiling); +} + +// Computes the dimensions that were squeezed from the source shape to match the +// target shape. Returns the dimensions in increasing order. +FailureOr> computeSqueezedDimsChecked( + Operation *op, ArrayRef source_shape, + ArrayRef target_shape); + // Assuming MKN matmul - This function must only be called after // canonicalization passes. // @@ -211,6 +247,8 @@ bool canReinterpretToUntiledMemref(TypedValue tiled_memref, const std::array &target_shape, bool allow_minormost_padding = false); +bool isContiguousMemref(TypedValue memref); + // Determines whether the given MemRefType has the given memory space. bool HasMemorySpace(MemRefType ty, tpu::MemorySpace space); @@ -233,6 +271,48 @@ void setLayout(Operation *op, Layout in, Layout out); void setLayout(Operation *op, ArrayRef in, Layout out); void setLayout(Operation *op, Layout in, ArrayRef out); void setLayout(Operation *op, ArrayRef in, ArrayRef out); + +// Helper functions to create constants. +inline arith::ConstantOp IdxConst(int64_t idx, OpBuilder &builder, + Location loc) { + return arith::ConstantOp::create(builder, loc, builder.getIndexType(), + builder.getIndexAttr(idx)); +} + +inline arith::ConstantOp I32Const(int32_t value, OpBuilder &builder, + Location loc) { + return arith::ConstantOp::create(builder, loc, builder.getI32Type(), + builder.getI32IntegerAttr(value)); +} + +inline arith::ConstantOp I32Const(int32_t value, ArrayRef shape, + OpBuilder &builder, Location loc) { + return arith::ConstantOp::create( + builder, loc, + DenseElementsAttr::get( + VectorType::get(shape, builder.getI32Type()), + builder.getIntegerAttr(builder.getI32Type(), value))); +} + +std::optional getIntConst(Value v); + +// Recursively finds all non-trivial users of a given value, including those +// accessed via `tpu.bitcast` or unary elementwise operations. However, +// `tpu.bitcast` and unary element-wise operations are excluded from the +// results. +SmallVector getNontrivialTransitiveUsers(Value v); + +bool hasVectorOperandsOrResults(Operation& op); + +// Return a mod b for a, b > 0, but adjusted to return b when a mod b == 0 such +// that the result is strictly positive. +template +auto positiveMod(U a, V b) { + DCHECK_GT(a, 0); + DCHECK_GT(b, 0); + return (a - 1) % b + 1; +} + } // namespace mlir::tpu #endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_UTIL_H_ diff --git a/jaxlib/mosaic/dialect/tpu/vreg_util.cc b/jaxlib/mosaic/dialect/tpu/vreg_util.cc index 1f59ee13a311..073c1a1f02f4 100644 --- a/jaxlib/mosaic/dialect/tpu/vreg_util.cc +++ b/jaxlib/mosaic/dialect/tpu/vreg_util.cc @@ -19,16 +19,16 @@ limitations under the License. #include #include "absl/log/check.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/Diagnostics.h" -#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/include/mlir/IR/Types.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/IR/ValueRange.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Support/LLVM.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/util.h" #include "xla/array.h" @@ -51,7 +51,7 @@ VectorType getNativeVregOrVmaskTypeImpl( VectorType getNativeVregOrVmaskType(Type elem_ty, const int8_t layout_bitwidth, const std::array target_shape) { - int8_t bitwidth = elem_ty.getIntOrFloatBitWidth(); + int8_t bitwidth = getTypeBitwidth(elem_ty); if (bitwidth == 1) { bitwidth = layout_bitwidth; } else { @@ -62,14 +62,14 @@ VectorType getNativeVregOrVmaskType(Type elem_ty, const int8_t layout_bitwidth, VectorType getNativeVregType(Type elem_ty, const std::array target_shape) { - return getNativeVregOrVmaskTypeImpl(elem_ty, elem_ty.getIntOrFloatBitWidth(), + return getNativeVregOrVmaskTypeImpl(elem_ty, getTypeBitwidth(elem_ty), target_shape); } TypedValue getFullVector(ImplicitLocOpBuilder &builder, VectorType vty, Attribute value) { return cast>( - builder.create(DenseElementsAttr::get(vty, value)) + arith::ConstantOp::create(builder, DenseElementsAttr::get(vty, value)) .getResult()); } @@ -79,6 +79,20 @@ TypedValue getFullLikeVector(ImplicitLocOpBuilder &builder, return getFullVector(builder, vec.getType(), value); } +TypedValue getFullVector(OpBuilder &builder, Location loc, + VectorType vty, Attribute value) { + return cast>( + arith::ConstantOp::create(builder, loc, + DenseElementsAttr::get(vty, value)) + .getResult()); +} + +TypedValue getFullLikeVector(OpBuilder &builder, Location loc, + TypedValue vec, + Attribute value) { + return getFullVector(builder, loc, vec.getType(), value); +} + TypedValue getZerosVector(ImplicitLocOpBuilder &builder, VectorType vty) { return getFullVector(builder, vty, builder.getZeroAttr(vty.getElementType())); @@ -111,13 +125,13 @@ FailureOr> getX32VmaskByPaddingEnd( const VectorType vmask_ty = getNativeVregOrVmaskType( builder.getI1Type(), /*layout_bitwidth=*/32, target_shape); if (dim == 0) { - mask_op = builder.create( - vmask_ty, ValueRange{idx_const(0), idx_const(0)}, + mask_op = tpu::CreateMaskOp::create( + builder, vmask_ty, ValueRange{idx_const(0), idx_const(0)}, ValueRange{idx_const(target_shape[0] - padding), idx_const(target_shape[1])}); } else { - mask_op = builder.create( - vmask_ty, ValueRange{idx_const(0), idx_const(0)}, + mask_op = tpu::CreateMaskOp::create( + builder, vmask_ty, ValueRange{idx_const(0), idx_const(0)}, ValueRange{idx_const(target_shape[0]), idx_const(target_shape[1] - padding)}); } @@ -174,23 +188,23 @@ LogicalResult maskNativeTilingVregs(ImplicitLocOpBuilder &builder, Value partial_sublane_mask = getFullVector( builder, i32_vreg_ty, builder.getI32IntegerAttr( - 0xffffffff >> (sub_padding * vreg_ty.getElementTypeBitWidth()))); + 0xffffffff >> (sub_padding * getElementTypeBitwidth(vreg_ty)))); // Insert 0xffffffff above the blended sublane. - Value sublane_mask = builder.create(mask_top, i32_max_vreg, - partial_sublane_mask); + Value sublane_mask = arith::SelectOp::create( + builder, mask_top, i32_max_vreg, partial_sublane_mask); // Insert 0 below the blended sublane. - sublane_mask = builder.create(mask_bottom, sublane_mask, - i32_zeros_vreg); + sublane_mask = arith::SelectOp::create(builder, mask_bottom, sublane_mask, + i32_zeros_vreg); for (int64_t i = 0; i < vregs.dim(1); ++i) { Value &vreg = vregs({vregs.dim(0) - 1, i}); - Value i32_vreg = builder.create(i32_vreg_ty, vreg); + Value i32_vreg = tpu::BitcastVregOp::create(builder, i32_vreg_ty, vreg); if (sub_padding > 0) { - i32_vreg = builder.create(i32_vreg, sublane_mask); + i32_vreg = arith::AndIOp::create(builder, i32_vreg, sublane_mask); } else { - i32_vreg = builder.create(mask_bottom, i32_vreg, - i32_zeros_vreg); + i32_vreg = arith::SelectOp::create(builder, mask_bottom, i32_vreg, + i32_zeros_vreg); } - vreg = builder.create(vreg_ty, i32_vreg); + vreg = tpu::BitcastVregOp::create(builder, vreg_ty, i32_vreg); } } // Mask out the right. @@ -200,10 +214,10 @@ LogicalResult maskNativeTilingVregs(ImplicitLocOpBuilder &builder, target_shape, /*dim=*/1)); for (int64_t i = 0; i < vregs.dim(0); ++i) { Value &vreg = vregs({i, vregs.dim(1) - 1}); - Value i32_vreg = builder.create(i32_vreg_ty, vreg); - i32_vreg = - builder.create(mask_right, i32_vreg, i32_zeros_vreg); - vreg = builder.create(vreg_ty, i32_vreg); + Value i32_vreg = tpu::BitcastVregOp::create(builder, i32_vreg_ty, vreg); + i32_vreg = arith::SelectOp::create(builder, mask_right, i32_vreg, + i32_zeros_vreg); + vreg = tpu::BitcastVregOp::create(builder, vreg_ty, i32_vreg); } } return success(); @@ -211,9 +225,8 @@ LogicalResult maskNativeTilingVregs(ImplicitLocOpBuilder &builder, FailureOr> broadcastSubelements( ImplicitLocOpBuilder &builder, TypedValue vec, - int subelement_idx, std::array target_shape, - int hardware_generation) { - int bitwidth = vec.getType().getElementTypeBitWidth(); + int subelement_idx, std::array target_shape) { + int bitwidth = getElementTypeBitwidth(vec.getType()); int packing = 32 / bitwidth; if (subelement_idx < 0 || subelement_idx >= packing) { return builder.emitError() @@ -229,24 +242,16 @@ FailureOr> broadcastSubelements( getNativeVregType(builder.getIntegerType(bitwidth), target_shape); // The chosen subelements must be in the low bits. High bits are unspecified. Value src_vreg_int = - builder.create(vreg_native_int_ty, vec); - Value vreg_subelement_low = builder.create( - src_vreg_int, + tpu::BitcastVregOp::create(builder, vreg_native_int_ty, vec); + Value vreg_subelement_low = arith::ShRUIOp::create( + builder, src_vreg_int, getFullVector(builder, vreg_native_int_ty, builder.getI32IntegerAttr(subelement_idx * bitwidth))); - Value vreg_result_int; - if (hardware_generation >= 5) { - SmallVector packed_vregs(packing, vreg_subelement_low); - vreg_result_int = builder.create( - vreg_packed_int_ty, packed_vregs, tpu::PackFormat::kInterleaved); - } else { - // This can be virtualized as a tree of shifts and ORs. - return builder.emitError() - << "broadcastSubelements not implemented for hardware generation " - << hardware_generation; - } + SmallVector packed_vregs(packing, vreg_subelement_low); + Value vreg_result_int = tpu::PackSubelementsOp::create( + builder, vreg_packed_int_ty, packed_vregs, tpu::PackFormat::kInterleaved); return cast>( - builder.create(vec.getType(), vreg_result_int) + tpu::BitcastVregOp::create(builder, vec.getType(), vreg_result_int) .getResult()); } diff --git a/jaxlib/mosaic/dialect/tpu/vreg_util.h b/jaxlib/mosaic/dialect/tpu/vreg_util.h index 86955e128f59..90e802fcb8fc 100644 --- a/jaxlib/mosaic/dialect/tpu/vreg_util.h +++ b/jaxlib/mosaic/dialect/tpu/vreg_util.h @@ -19,12 +19,12 @@ limitations under the License. #include #include -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/Builders.h" -#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/include/mlir/IR/Types.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" #include "xla/array.h" namespace mlir::tpu { @@ -50,6 +50,15 @@ TypedValue getFullLikeVector(ImplicitLocOpBuilder &builder, TypedValue vec, Attribute value); +// Same as above, but takes a `loc` as input, in case of an OpBuilder. +TypedValue getFullVector(OpBuilder &builder, Location loc, + VectorType vty, Attribute value); + +// Same as above, but takes a `vec` as input. +TypedValue getFullLikeVector(OpBuilder &builder, Location loc, + TypedValue vec, + Attribute value); + // Creates a vmask with false flags to bottom (dim = 0) // or right (dim = 1) where the flag count corresponds to the (dim_size - // padding). @@ -81,8 +90,7 @@ LogicalResult maskNativeTilingVregs(ImplicitLocOpBuilder &builder, // subelement_idx must be between 0 and packing. FailureOr> broadcastSubelements( ImplicitLocOpBuilder &builder, TypedValue vec, - int subelement_idx, std::array target_shape, - int hardware_generation); + int subelement_idx, std::array target_shape); } // namespace mlir::tpu diff --git a/jaxlib/mosaic/dialect/tpu/vreg_util_test.cc b/jaxlib/mosaic/dialect/tpu/vreg_util_test.cc index ea3063361e1a..9c1ba87aca0a 100644 --- a/jaxlib/mosaic/dialect/tpu/vreg_util_test.cc +++ b/jaxlib/mosaic/dialect/tpu/vreg_util_test.cc @@ -21,20 +21,20 @@ limitations under the License. #include #include -#include "llvm/include/llvm/ADT/TypeSwitch.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/Builders.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/include/mlir/IR/MLIRContext.h" -#include "mlir/include/mlir/IR/OwningOpRef.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/Support/DebugStringHelper.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "llvm/ADT/TypeSwitch.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/DebugStringHelper.h" +#include "mlir/Support/LLVM.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" namespace mlir::tpu { @@ -121,7 +121,7 @@ class VregUtilTest : public ::testing::Test { tpu::TPUDialect>(); mlir::Location loc = mlir::UnknownLoc::get(&context_); mlir::OpBuilder b(&context_); - module_ = b.create(loc); + module_ = ModuleOp::create(b, loc); builder_ = std::make_unique( module_->getLoc(), module_->getBodyRegion()); } @@ -175,7 +175,7 @@ TEST_F(VregUtilTest, GetFullVector) { TEST_F(VregUtilTest, GetFullLikeVector) { VectorType vty = VectorType::get({2, 4}, Builder().getF32Type()); - TypedValue in_vec = Builder().create( + TypedValue in_vec = Builder().create( vty, Builder().create( vty.getElementType(), Builder().getF32FloatAttr(1.0f))); TypedValue vec = @@ -193,7 +193,7 @@ TEST_F(VregUtilTest, GetZerosVector) { TEST_F(VregUtilTest, GetZerosLikeVector) { VectorType vty = VectorType::get({2, 4}, Builder().getF32Type()); - TypedValue in_vec = Builder().create( + TypedValue in_vec = Builder().create( vty, Builder().create( vty.getElementType(), Builder().getF32FloatAttr(1.0f))); TypedValue vec = getZerosLikeVector(Builder(), in_vec); diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD index 9249ae256901..3bc22b8c1091 100644 --- a/jaxlib/mosaic/gpu/BUILD +++ b/jaxlib/mosaic/gpu/BUILD @@ -12,12 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("@rules_cc//cc:cc_binary.bzl", "cc_binary") +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") load("@rules_python//python:defs.bzl", "py_library") -load("//jaxlib:jax.bzl", "nanobind_extension") +load("//jaxlib:jax.bzl", "if_oss", "nanobind_extension") package( default_applicable_licenses = [], - default_visibility = ["//jax:mosaic_gpu_users"], + default_visibility = ["//jax/experimental:mosaic_gpu_users"], ) py_library( @@ -26,6 +29,14 @@ py_library( deps = [":_mosaic_gpu_ext"], ) +cc_library( + name = "mosaic_gpu_support", + deps = [ + ":custom_call", + ":runtime", + ], +) + cc_library( name = "target", srcs = ["target.cc"], @@ -36,25 +47,122 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@llvm-project//llvm:MC", + "@llvm-project//llvm:TargetParser", + ], +) + +cc_library( + name = "dump", + srcs = ["dump.cc"], + hdrs = ["dump.h"], + deps = [ + ":library_paths", # buildcleaner: keep + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/numeric:bits", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@tsl//tsl/platform:path", + ], +) + +cc_test( + name = "dump_test", + srcs = ["dump_test.cc"], + deps = [ + ":dump", + "//testing/base/public:gunit_main", + "@llvm-project//mlir:IR", + ], +) + +cc_test( + name = "gpu_module_to_assembly_test", + srcs = ["gpu_module_to_assembly_test.cc"], + deps = [ + ":passes", + "//jaxlib/mosaic/dialect/gpu:mosaic_gpu", + "//testing/base/public:gunit_main", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithToLLVM", + "@llvm-project//mlir:BufferizationInterfaces", + "@llvm-project//mlir:BuiltinToLLVMIRTranslation", + "@llvm-project//mlir:ComplexToLLVM", + "@llvm-project//mlir:ControlFlowToLLVM", + "@llvm-project//mlir:DataLayoutInterfaces", + "@llvm-project//mlir:FuncToLLVM", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:GPUToLLVMIRTranslation", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexToLLVM", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMToLLVMIRTranslation", + "@llvm-project//mlir:MathToLLVM", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:MemRefToLLVM", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:NVVMTarget", + "@llvm-project//mlir:NVVMToLLVM", + "@llvm-project//mlir:NVVMToLLVMIRTranslation", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ROCDLDialect", + "@llvm-project//mlir:SCFUtils", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:UBToLLVM", + "@llvm-project//mlir:VectorDialect", + "@xla//xla/service/gpu/llvm_gpu_backend:nvptx_libdevice_path", + ], +) + +cc_library( + name = "serde", + srcs = ["serde.cc"], + hdrs = ["serde.h"], + deps = [ + "//jaxlib/mosaic:pass_boilerplate", + "//jaxlib/mosaic:serde", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:DataLayoutInterfaces", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:VectorDialect", ], ) cc_library( name = "passes", srcs = [ + "assembly_to_binary.cc", + "gpu_module_to_assembly.cc", "launch_lowering.cc", "passes.cc", - "serde.cc", ], hdrs = [ + "assembly_to_binary.h", + "gpu_module_to_assembly.h", "launch_lowering.h", "passes.h", - "serde.h", ], deps = [ - "//jaxlib:pass_boilerplate", - "//jaxlib/mosaic:serde", + ":dump", + "//jaxlib/mosaic:pass_boilerplate", + "@com_google_absl//absl/base", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@llvm-project//llvm:IRReader", "@llvm-project//llvm:Support", + "@llvm-project//llvm:ir_headers", "@llvm-project//mlir:DataLayoutInterfaces", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:GPUDialect", @@ -62,10 +170,16 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:LLVMCommonConversion", "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:NVVMDialect", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TargetLLVM", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:VectorDialect", + "@xla//xla/service/gpu/llvm_gpu_backend:load_ir_module", + "@xla//xla/stream_executor/cuda:compilation_options", + "@xla//xla/stream_executor/cuda:compilation_provider", + "@xla//xla/stream_executor/cuda:cuda_compute_capability", ], ) @@ -77,12 +191,14 @@ CAPI_HEADERS = [ "integrations/c/passes.h", ] +# `mlir_capi` should be used only to wrap hardware-agnostic passes that need to be called directly +# from Python. cc_library( name = "mlir_capi", srcs = CAPI_SOURCES, hdrs = CAPI_HEADERS, deps = [ - ":passes", + ":serde", "@llvm-project//mlir:CAPIIRHeaders", ], ) @@ -111,22 +227,58 @@ cc_library( cc_library( name = "runtime", srcs = ["runtime.cc"], + # Linker may prune these symbols if they are not explicitly exported. + linkopts = [ + "-Wl,--export-dynamic-symbol='mosaic_gpu_*'", + "-Wl,--export-dynamic-symbol='nvshmem_my_pe'", + "-Wl,--export-dynamic-symbol='nvshmem_ptr'", + "-Wl,--export-dynamic-symbol='nvshmemx_barrier_all_on_stream'", + "-Wl,--export-dynamic-symbol='nvshmemx_cumodule_finalize'", + "-Wl,--export-dynamic-symbol='nvshmemx_cumodule_init'", + "-Wl,--export-dynamic-symbol='nvshmemx_init_status'", + "-Wl,--export-dynamic-symbol='nvshmemx_mc_ptr'", + ], + deps = [ + ":nvshmem", + "@local_config_cuda//cuda:cuda_headers", + ], + alwayslink = True, +) + +cc_library( + name = "nvshmem", + hdrs = ["nvshmem.h"], deps = [ "@local_config_cuda//cuda:cuda_headers", + "@xla//xla/tsl/cuda:cudart", ], ) cc_library( name = "custom_call", srcs = ["custom_call.cc"], + linkopts = if_oss( + [], # We use a version script to ensure symbol visibility in the OSS build. + [ + "-Wl,--export-dynamic-symbol='MosaicGpuCompile'", + "-Wl,--export-dynamic-symbol='MosaicGpuUnload'", + ], + ), deps = [ + ":dump", + ":library_paths", # buildcleaner: keep + ":nvshmem", ":passes", + ":serde", ":target", "//jaxlib/cuda:cuda_vendor", "//jaxlib/mosaic/dialect/gpu:mosaic_gpu", + "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -139,6 +291,7 @@ cc_library( "@llvm-project//mlir:ArithTransforms", "@llvm-project//mlir:BuiltinToLLVMIRTranslation", "@llvm-project//mlir:ComplexToLLVM", + "@llvm-project//mlir:ControlFlowDialect", "@llvm-project//mlir:ControlFlowToLLVM", "@llvm-project//mlir:ConversionPasses", "@llvm-project//mlir:ExecutionEngine", @@ -151,6 +304,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:IndexToLLVM", "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMIRTransforms", "@llvm-project//mlir:LLVMToLLVMIRTranslation", "@llvm-project//mlir:MathDialect", "@llvm-project//mlir:MathToLLVM", @@ -170,12 +324,54 @@ cc_library( "@llvm-project//mlir:UBToLLVM", "@llvm-project//mlir:VectorDialect", "@llvm-project//mlir:VectorToLLVM", + "@local_config_cuda//cuda:cuda_headers", + "@tsl//tsl/platform:path", + "@tsl//tsl/profiler/lib:traceme", + "@xla//xla:executable_run_options", + "@xla//xla/backends/gpu:ffi", + "@xla//xla/ffi", + "@xla//xla/ffi:ffi_api", "@xla//xla/service:custom_call_status", "@xla//xla/service:custom_call_target_registry", + "@xla//xla/service/gpu/llvm_gpu_backend:nvptx_libdevice_path", + "@xla//xla/service/llvm_ir:llvm_command_line_options", + "@xla//xla/stream_executor/cuda:assemble_compilation_provider", + "@xla//xla/stream_executor/cuda:compilation_provider", + "@xla//xla/stream_executor/cuda:compilation_provider_options", + "@xla//xla/stream_executor/cuda:cuda_compute_capability", + "@xla//xla/stream_executor/cuda:ptx_compiler_support", + "@xla//xla/tsl/platform:statusor", ], alwayslink = True, ) +cc_test( + name = "custom_call_test", + srcs = ["custom_call_test.cc"], + tags = ["requires-gpu-sm90"], + deps = [ + ":mosaic_gpu_support", + "//testing/base/public:gunit_main", + "@com_google_absl//absl/base:log_severity", + "@com_google_absl//absl/log:globals", + "@com_google_absl//absl/log:scoped_mock_log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@xla//xla/hlo/builder:xla_computation", + "@xla//xla/hlo/parser:hlo_parser", + "@xla//xla/pjrt:pjrt_client", + "@xla//xla/pjrt:pjrt_executable", + "@xla//xla/pjrt/plugin/xla_gpu:xla_gpu_pjrt_client", + "@xla//xla/service:gpu_plugin", + "@xla//xla/stream_executor/cuda:cuda_platform", + "@xla//xla/tsl/platform:env", + "@xla//xla/tsl/platform:status", + "@xla//xla/tsl/platform:statusor", + ], +) + nanobind_extension( name = "_mosaic_gpu_ext", srcs = ["mosaic_gpu_ext.cc"], @@ -189,8 +385,6 @@ nanobind_extension( "@com_google_absl//absl/cleanup", "@com_google_absl//absl/strings", "@nanobind", - "@xla//xla/ffi/api:c_api", - "@xla//xla/ffi/api:ffi", "@xla//xla/tsl/cuda:cudart", ], ) @@ -205,7 +399,36 @@ cc_binary( "notap", ], deps = [ + ":nvshmem", "@local_config_cuda//cuda:cuda_headers", "@xla//xla/tsl/cuda:cudart", ], ) + +cc_library( + name = "library_paths", + hdrs = ["library_paths.h"], +) + +cc_library( + name = "tiled_layout", + srcs = ["tiled_layout.cc"], + hdrs = ["tiled_layout.h"], + deps = [ + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@xla//xla/tsl/platform:statusor", + ], +) + +cc_test( + name = "tiled_layout_test", + srcs = ["tiled_layout_test.cc"], + deps = [ + ":tiled_layout", + "//testing/base/public:gunit_main", + "@com_google_absl//absl/status", + ], +) diff --git a/jaxlib/mosaic/gpu/assembly_to_binary.cc b/jaxlib/mosaic/gpu/assembly_to_binary.cc new file mode 100644 index 000000000000..11381899217d --- /dev/null +++ b/jaxlib/mosaic/gpu/assembly_to_binary.cc @@ -0,0 +1,155 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This pass lowers existing PTX into a `gpu.binary` op using stream executor +// compilation providers. The stock MLIR pipeline uses `ptxas` in a subprocess +// to compile PTX by default. This does not work reliably in all environments, +// and stream executor's compilation providers are meant to remedy this problem. + +#include "jaxlib/mosaic/gpu/assembly_to_binary.h" + +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "jaxlib/mosaic/gpu/dump.h" +#include "jaxlib/mosaic/pass_boilerplate.h" +#include "xla/stream_executor/cuda/compilation_options.h" +#include "xla/stream_executor/cuda/compilation_provider.h" +#include "xla/stream_executor/cuda/cuda_compute_capability.h" + +namespace mosaic { +namespace gpu { + +namespace { + +namespace se = ::stream_executor; + +class AssemblyToBinaryPass + : public jaxlib::mlir::Pass { + public: + using jaxlib::mlir::Pass::Pass; + + AssemblyToBinaryPass( + const se::cuda::CompilationProvider* compilation_provider, + se::CudaComputeCapability cc) + : compilation_provider_(std::move(compilation_provider)), + cc_(std::move(cc)) {} + + static constexpr llvm::StringLiteral kArgumentName = + "mosaic-gpu-assembly-to-binary"; + static constexpr llvm::StringLiteral kPassName = "GpuAssemblyToBinaryPass"; + + void runOnOperation() override { + mlir::ModuleOp module = getOperation(); + mlir::OpBuilder b(module); + mlir::MLIRContext* ctx = module.getContext(); + DumpOptions dump_opts = GetOrSetDumpOptionsForModule(module); + + se::cuda::CompilationOptions compilation_options; + compilation_options.dump_compilation_log = dump_opts.ptxas; + compilation_options.generate_line_info = true; + + mlir::WalkResult result = module.walk([&](mlir::gpu::BinaryOp binary) { + if (binary.getObjects().size() != 1) { + binary.emitOpError("Expected exactly one object in the binary."); + return mlir::WalkResult::interrupt(); + } + + mlir::gpu::ObjectAttr object = + mlir::cast(*binary.getObjects().begin()); + if (object.getFormat() != mlir::gpu::CompilationTarget::Assembly) { + binary.emitOpError("Expected an assembly object."); + return mlir::WalkResult::interrupt(); + } + + llvm::StringRef ptx_str = object.getObject().getValue(); + if (dump_opts.ptx) { + DumpToFileOrStdout(ptx_str, dump_opts.module_basename + ".ptx", + dump_opts.dump_path); + } + absl::StatusOr sass_or = + compilation_provider_->Compile(cc_, ptx_str, compilation_options); + if (!sass_or.ok()) { + binary.emitOpError(sass_or.status().message()); + return mlir::WalkResult::interrupt(); + } + + if (dump_opts.ptxas) { + if (!sass_or->compilation_log.has_value()) { + binary.emitOpError("Expected a compilation log to be available."); + return mlir::WalkResult::interrupt(); + } + DumpToFileOrStdout(*sass_or->compilation_log, + dump_opts.module_basename + ".ptxas", + dump_opts.dump_path); + } + + mlir::StringAttr sass = mlir::StringAttr::get( + ctx, std::string(sass_or->cubin.begin(), sass_or->cubin.end())); + b.setInsertionPointAfter(binary); + + mlir::gpu::ObjectAttr new_object = mlir::gpu::ObjectAttr::get( + ctx, object.getTarget(), mlir::gpu::CompilationTarget::Binary, sass, + object.getProperties(), object.getKernels()); + binary.setObjectsAttr(mlir::ArrayAttr::get(ctx, {new_object})); + + if (dump_opts.sass || dump_opts.sass_ctrl) { + DumpSass(binary, dump_opts.dump_path, dump_opts.module_basename, + dump_opts.sass_ctrl); + } + + return mlir::WalkResult::advance(); + }); + + if (result.wasInterrupted()) { + signalPassFailure(); + } + } + + private: + const se::cuda::CompilationProvider* compilation_provider_; + se::CudaComputeCapability cc_; +}; + +} // namespace + +void registerAssemblyToBinaryPass( + const se::cuda::CompilationProvider* compilation_provider, + const se::CudaComputeCapability& cc) { + ::mlir::registerPass( + [compilation_provider, cc]() -> std::unique_ptr<::mlir::Pass> { + return std::make_unique(compilation_provider, + std::move(cc)); + }); +} + +} // namespace gpu +} // namespace mosaic diff --git a/jaxlib/mosaic/gpu/assembly_to_binary.h b/jaxlib/mosaic/gpu/assembly_to_binary.h new file mode 100644 index 000000000000..53e582b0c3f3 --- /dev/null +++ b/jaxlib/mosaic/gpu/assembly_to_binary.h @@ -0,0 +1,34 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_MOSAIC_GPU_ASSEMBLY_TO_BINARY_H_ +#define JAXLIB_MOSAIC_GPU_ASSEMBLY_TO_BINARY_H_ + +#include "xla/stream_executor/cuda/compilation_provider.h" +#include "xla/stream_executor/cuda/cuda_compute_capability.h" + +namespace mosaic { +namespace gpu { + +// Registers a pass that converts `gpu.binary` ops wrapping PTX assembly into +// `gpu.binary` ops wrapping a CUBIN binary. +void registerAssemblyToBinaryPass( + const stream_executor::cuda::CompilationProvider* compilation_provider, + const stream_executor::CudaComputeCapability& cc); + +} // namespace gpu +} // namespace mosaic + +#endif // JAXLIB_MOSAIC_GPU_ASSEMBLY_TO_BINARY_H_ diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc index 402e099c8d6b..b95c0254051c 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -18,199 +18,274 @@ limitations under the License. #include #include +#include #include #include #include #include -#include -#include +#include +#include #include +#include #include #include +#include // NOLINT #include #include #include +#include "jaxlib/mosaic/gpu/library_paths.h" +#include "absl/base/call_once.h" +#include "absl/base/no_destructor.h" #include "absl/base/optimization.h" -#include "absl/cleanup/cleanup.h" +#include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" #include "absl/synchronization/mutex.h" -#include "llvm/include/llvm/ADT/SmallVector.h" -#include "llvm/include/llvm/Support/CodeGen.h" -#include "llvm/include/llvm/Support/TargetSelect.h" -#include "mlir/include/mlir/Conversion/ArithToLLVM/ArithToLLVM.h" -#include "mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" -#include "mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" -#include "mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" -#include "mlir/include/mlir/Conversion/IndexToLLVM/IndexToLLVM.h" -#include "mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h" -#include "mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" -#include "mlir/include/mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" -#include "mlir/include/mlir/Conversion/Passes.h" -#include "mlir/include/mlir/Conversion/UBToLLVM/UBToLLVM.h" -#include "mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/Dialect/Arith/Transforms/Passes.h" -#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/include/mlir/Dialect/GPU/Transforms/Passes.h" -#include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h" -#include "mlir/include/mlir/Dialect/Math/IR/Math.h" -#include "mlir/include/mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h" -#include "mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h" -#include "mlir/include/mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/include/mlir/ExecutionEngine/ExecutionEngine.h" -#include "mlir/include/mlir/ExecutionEngine/OptUtils.h" -#include "mlir/include/mlir/IR/AsmState.h" -#include "mlir/include/mlir/IR/DialectRegistry.h" -#include "mlir/include/mlir/IR/MLIRContext.h" -#include "mlir/include/mlir/Parser/Parser.h" -#include "mlir/include/mlir/Pass/PassManager.h" -#include "mlir/include/mlir/Pass/PassRegistry.h" -#include "mlir/include/mlir/Support/LLVM.h" -#include "mlir/include/mlir/Support/LogicalResult.h" -#include "mlir/include/mlir/Target/LLVM/NVVM/Target.h" -#include "mlir/include/mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" -#include "mlir/include/mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" -#include "mlir/include/mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" -#include "mlir/include/mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" -#include "mlir/include/mlir/Transforms/Passes.h" +#include "third_party/gpus/cuda/include/cuda.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/CodeGen.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" +#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" +#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" +#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" +#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" +#include "mlir/Conversion/Passes.h" +#include "mlir/Conversion/UBToLLVM/UBToLLVM.h" +#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Transforms/Passes.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/GPU/Transforms/Passes.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/LLVMIR/Transforms/Passes.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/ExecutionEngine/ExecutionEngine.h" +#include "mlir/ExecutionEngine/OptUtils.h" +#include "mlir/IR/AsmState.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Target/LLVM/NVVM/Target.h" +#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" +#include "mlir/Transforms/Passes.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/mosaic/dialect/gpu/mosaic_gpu.h" +#include "jaxlib/mosaic/gpu/assembly_to_binary.h" +#include "jaxlib/mosaic/gpu/dump.h" +#include "jaxlib/mosaic/gpu/gpu_module_to_assembly.h" #include "jaxlib/mosaic/gpu/launch_lowering.h" +#include "jaxlib/mosaic/gpu/nvshmem.h" #include "jaxlib/mosaic/gpu/passes.h" #include "jaxlib/mosaic/gpu/serde.h" #include "jaxlib/mosaic/gpu/target.h" +#include "xla/backends/gpu/ffi.h" +#include "xla/executable_run_options.h" +#include "xla/ffi/ffi.h" +#include "xla/ffi/ffi_api.h" #include "xla/service/custom_call_status.h" #include "xla/service/custom_call_target_registry.h" +#include "xla/service/gpu/llvm_gpu_backend/nvptx_libdevice_path.h" +#include "xla/service/llvm_ir/llvm_command_line_options.h" +#include "xla/stream_executor/cuda/assemble_compilation_provider.h" +#include "xla/stream_executor/cuda/compilation_provider.h" +#include "xla/stream_executor/cuda/compilation_provider_options.h" +#include "xla/stream_executor/cuda/cuda_compute_capability.h" +#include "xla/stream_executor/cuda/ptx_compiler_support.h" +#include "xla/tsl/platform/statusor.h" +#include "tsl/platform/path.h" +#include "tsl/profiler/lib/traceme.h" namespace { +namespace ffi = xla::ffi; +namespace se = stream_executor; + using MosaicInitFunc = void(void****); using MosaicHostFunc = void(void**); -absl::StatusOr> GetSmAndPtxIsaVersion() { - // Assumes driver has been initialized and a context exists. XLA already has - // some utilities to query this, but we try to stay runtime-agnostic, so we - // build our own here. - CUdevice device; - if (cuCtxGetDevice(&device) != CUDA_SUCCESS) { - return absl::InternalError("Failed to get device for current context"); - } - int major = 0; - if (cuDeviceGetAttribute(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, - device) != CUDA_SUCCESS) { - return absl::InternalError("Failed to get major compute capability"); - } - int minor = 0; - if (cuDeviceGetAttribute(&minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, - device) != CUDA_SUCCESS) { - return absl::InternalError("Failed to get minor compute capability"); - } - return mosaic::gpu::GetSmAndPtxIsaVersion(major, minor); +// Mirrors `--xla_gpu_cuda_data_dir`'s default value. +constexpr std::string_view kDefaultCudaDataDir = "./cuda_sdk_lib"; + +absl::StatusOr GetPtxIsaVersion( + const se::cuda::CompilationProvider& compilation_provider) { + TF_ASSIGN_OR_RETURN(int ptxas_latest_version, + compilation_provider.GetLatestPtxIsaVersion()); + // We'd like to target the latest PTX ISA version supported by + // ptxas. However, it doesn't make sense to ask LLVM to target a PTX + // ISA that it isn't aware of yet. Find the latest version supported + // by LLVM and return the minimum of the two versions, one from + // ptxas and the other from LLVM. + TF_ASSIGN_OR_RETURN(int llvm_latest_version, + mosaic::gpu::GetLatestLlvmPtxIsaVersion()); + int final_version = std::min(ptxas_latest_version, llvm_latest_version); + return absl::StrFormat("ptx%d", final_version); } mlir::FailureOr GetPassPipeline( - mlir::MLIRContext* ctx, mlir::gpu::CompilationTarget target, - const std::string& sm, const std::string& ptx_isa) { - static bool register_once = []() { - llvm::InitializeNativeTarget(); - llvm::InitializeNativeTarget(); - llvm::InitializeNativeTargetAsmPrinter(); - mlir::registerCanonicalizer(); - mlir::registerCSE(); - mlir::registerStripDebugInfo(); - mlir::registerConvertNVGPUToNVVMPass(); - mlir::registerConvertVectorToSCF(); - mlir::registerSCFToControlFlowPass(); - mlir::registerConvertNVVMToLLVMPass(); - mlir::registerArithToLLVMConversionPass(); - mlir::registerConvertIndexToLLVMPass(); - mlir::registerConvertGpuOpsToNVVMOps(); - mlir::registerConvertMathToLLVMPass(); - mlir::registerConvertFuncToLLVMPass(); - mlir::registerLowerAffinePass(); - mlir::registerReconcileUnrealizedCastsPass(); - // TODO(apaszke): Only register the passes we actually use. - mlir::memref::registerMemRefPasses(); - mlir::registerConvertToLLVMPass(); - mlir::registerGPUPasses(); - mlir::registerGpuLaunchSinkIndexComputationsPass(); - mosaic::gpu::registerGpuLaunchLoweringPass(); - mosaic::gpu::registerConvertGpuToLLVMPass(); - mosaic::gpu::registerByvalInsertionPass(); - mlir::arith::registerArithExpandOpsPass(); - return true; - }(); - (void)register_once; - return mlir::parsePassPipeline(absl::StrCat( - R"( - builtin.module( - arith-expand, - canonicalize, - gpu-launch-sink-index-computations, - convert-nvgpu-to-nvvm, - gpu-kernel-outlining{data-layout-str=}, - convert-vector-to-scf{full-unroll=false lower-tensors=false target-rank=1}, - convert-scf-to-cf, - convert-nvvm-to-llvm, - expand-strided-metadata, - nvvm-attach-target{O=3 chip=)", - sm, R"( fast=false features=+)", ptx_isa, - R"( ftz=false module= triple=nvptx64-nvidia-cuda}, - lower-affine, - convert-arith-to-llvm{index-bitwidth=0}, - convert-index-to-llvm{index-bitwidth=64}, - canonicalize{max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, - cse, - gpu.module(strip-debuginfo), - gpu.module(convert-gpu-to-nvvm{has-redux=false index-bitwidth=64 use-bare-ptr-memref-call-conv=false}), - gpu.module(canonicalize{max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}), - gpu.module(cse), - gpu.module(mosaic-byval-insertion), - gpu.module(reconcile-unrealized-casts), - mosaic-convert-gpu-to-llvm, - gpu-module-to-binary{format=)", - mlir::gpu::stringifyCompilationTarget(target).str(), R"(}, - convert-math-to-llvm{approximate-log1p=true}, - canonicalize{max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, - cse, - )", - (target != mlir::gpu::CompilationTarget::Assembly ? "gpu-launch-lowering," - : ""), - R"( - convert-to-llvm, - reconcile-unrealized-casts - ) - )")); + mlir::MLIRContext* ctx, + const se::cuda::CompilationProvider* compilation_provider, + const se::CudaComputeCapability& cc, const std::string& sm, + const std::string& ptx_isa, const std::string& nvshmem_path) { + static absl::once_flag register_passes_flag; + absl::call_once( + register_passes_flag, [&compilation_provider, &cc]() { + mosaic::gpu::EnsureLLVMNVPTXTargetIsRegistered(); + + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + mlir::registerCanonicalizer(); + mlir::registerCSE(); + mlir::registerStripDebugInfo(); + mlir::registerConvertNVGPUToNVVMPass(); + mlir::registerConvertVectorToSCF(); + mlir::registerSCFToControlFlowPass(); + mlir::registerConvertNVVMToLLVMPass(); + mlir::registerArithToLLVMConversionPass(); + mlir::registerConvertIndexToLLVMPass(); + mlir::registerConvertGpuOpsToNVVMOps(); + mlir::registerConvertMathToLLVMPass(); + mlir::registerConvertFuncToLLVMPass(); + mlir::registerLowerAffinePass(); + mlir::registerReconcileUnrealizedCastsPass(); + // TODO(apaszke): Only register the passes we actually use. + mlir::memref::registerMemRefPasses(); + mlir::registerConvertToLLVMPass(); + mlir::registerGPUPasses(); + mlir::registerGpuLaunchSinkIndexComputationsPass(); + mosaic::gpu::registerGpuModuleToAssemblyPass(); + mosaic::gpu::registerAssemblyToBinaryPass(compilation_provider, cc); + mosaic::gpu::registerGpuLaunchLoweringPass(); + mosaic::gpu::registerConvertGpuToLLVMPass(); + mosaic::gpu::registerByvalInsertionPass(); + mosaic::gpu::registerLLVMAttrInsertionPass(); + mosaic::gpu::registerResolveTrivialLocationsPass(); + mlir::arith::registerArithExpandOpsPass(); + mlir::LLVM::registerDIScopeForLLVMFuncOpPass(); + return true; + }); + const char* cuda_root = mosaic::gpu::GetCUDARoot(); + if (!cuda_root) { + return mlir::failure(); + } + std::vector libraries_to_link{ + ::xla::gpu::nvptx::LibDevicePath(kDefaultCudaDataDir)}; + if (!nvshmem_path.empty()) { + libraries_to_link.push_back(nvshmem_path); + } + return mlir::parsePassPipeline( + absl::StrFormat(R"( + builtin.module( + mosaic-gpu-resolve-trivial-locations, + arith-expand, + canonicalize, + gpu-launch-sink-index-computations, + convert-nvgpu-to-nvvm, + gpu-kernel-outlining{data-layout-str=}, + convert-vector-to-scf{full-unroll=false lower-tensors=false target-rank=1}, + convert-scf-to-cf, + convert-nvvm-to-llvm, + expand-strided-metadata, + nvvm-attach-target{O=3 chip=%1$s fast=false features=+%2$s ftz=false module= triple=nvptx64-nvidia-cuda}, + lower-affine, + convert-arith-to-llvm{index-bitwidth=0}, + convert-index-to-llvm{index-bitwidth=64}, + canonicalize{max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, + cse, + gpu.module(convert-gpu-to-nvvm{has-redux=false index-bitwidth=64 use-bare-ptr-memref-call-conv=false}), + gpu.module(canonicalize{max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}), + gpu.module(cse), + gpu.module(mosaic-byval-insertion), + gpu.module(mosaic-llvm-attr-insertion), + gpu.module(reconcile-unrealized-casts), + mosaic-convert-gpu-to-llvm, + ensure-debug-info-scope-on-llvm-func{emission-kind=DebugDirectivesOnly}, + mosaic-gpu-module-to-assembly{libraries-to-link=%3$s}, + convert-math-to-llvm{approximate-log1p=true}, + canonicalize{max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, + cse, + mosaic-gpu-assembly-to-binary, + gpu-launch-lowering, + convert-to-llvm, + reconcile-unrealized-casts + ) + )", + sm, ptx_isa, absl::StrJoin(libraries_to_link, ","))); } mlir::LogicalResult RunPasses(mlir::OpPassManager&& passes, - mlir::ModuleOp module) { + mlir::ModuleOp module, + const mosaic::gpu::DumpOptions& dump_opts) { mlir::PassManager pm(module.getContext()); *static_cast(&pm) = std::move(passes); - if (getenv("MOSAIC_GPU_DUMP_MLIR_PASSES") != nullptr) { - pm.enableIRPrinting(); + std::optional dump_stream; + if (dump_opts.mlir_passes) { + if (!dump_opts.dump_path.empty()) { + std::string path = tsl::io::JoinPath( + dump_opts.dump_path, dump_opts.module_basename + ".mlir-passes.log"); + std::error_code error; + dump_stream.emplace(path, error, llvm::sys::fs::OF_None); + if (error) { + dump_stream.reset(); + LOG(ERROR) << error.message(); + LOG(ERROR) << "Output will be written to stdout instead."; + dump_stream = std::nullopt; + } + } + pm.getContext()->disableMultithreading(); + auto print_always = [](mlir::Pass*, mlir::Operation*) { return true; }; + pm.enableIRPrinting(/*shouldPrintBeforePass=*/print_always, + /*shouldPrintAfterPass=*/print_always, + /*printModuleScope=*/true, + /*printAfterOnlyOnChange=*/false, + /*printAfterOnlyOnFailure=*/true, + dump_stream.has_value() ? *dump_stream : llvm::outs(), + mlir::OpPrintingFlags().enableDebugInfo()); } return pm.run(module); } void InitContext(mlir::MLIRContext* context) { mlir::DialectRegistry registry; - registry.insert(); + registry.insert(); mlir::registerConvertNVVMToLLVMInterface(registry); mlir::registerConvertComplexToLLVMInterface(registry); mlir::registerConvertMemRefToLLVMInterface(registry); @@ -232,195 +307,223 @@ void InitContext(mlir::MLIRContext* context) { context->loadAllAvailableDialects(); } -absl::Status RunCUDATool(const char* tool, - const std::vector& args, - bool stderr_to_stdout = false) { - CHECK(!args.empty() && args.back() == nullptr); - const char * cuda_path_ptr = getenv("CUDA_ROOT"); - if (!cuda_path_ptr) return absl::InternalError("Failed to get CUDA_ROOT"); - std::string tool_path(cuda_path_ptr); - tool_path += "/bin/"; - tool_path += tool; - pid_t child_pid; - posix_spawn_file_actions_t file_actions; - if (posix_spawn_file_actions_init(&file_actions)) { - return absl::InternalError("Failed to initialize spawn file actions"); - } - if (posix_spawn_file_actions_adddup2(&file_actions, STDOUT_FILENO, - STDERR_FILENO)) { - return absl::InternalError("Failed to set up spawn file actions"); - } - // execv is guaranteed by POSIX to not modify the args (other than - // replacing the whole process image), so the const_cast is valid. - if (posix_spawn(&child_pid, tool_path.c_str(), &file_actions, nullptr, - const_cast(args.data()), environ)) { - return absl::InternalError("Process spawn failed"); - } - int status; - if (waitpid(child_pid, &status, 0) == -1) { - return absl::InternalError("Failed to wait for CUDA tool invocation"); - } - if (status != 0) return absl::InternalError("CUDA tool failed"); - if (posix_spawn_file_actions_destroy(&file_actions) != 0) { - return absl::InternalError("Failed to clean up after posix_spawn"); +bool is_nvshmem_used(mlir::ModuleOp module) { + constexpr std::string_view prefix1 = "nvshmem_"; + constexpr std::string_view prefix2 = "nvshmemx_"; + for (mlir::LLVM::LLVMFuncOp llvm_func : + module.getOps()) { + const auto& func_name = llvm_func.getName(); + if (!func_name.starts_with(prefix1) && !func_name.starts_with(prefix2)) { + continue; + } + auto uses = + mlir::SymbolTable::getSymbolUses(llvm_func, module.getOperation()); + if (uses && !uses->empty()) { + return true; + } } - return absl::OkStatus(); + return false; } -class TemporaryDirectory { - private: - TemporaryDirectory(std::string path) : path(std::move(path)) {} - // TODO(apaszke): Unlink in destructor. +absl::StatusOr get_nvshmem_llvm_lib_path() { + const char* nvshmem_path_ptr = getenv("MOSAIC_GPU_NVSHMEM_BC_PATH"); + if (!nvshmem_path_ptr) + return absl::InternalError("Failed to get MOSAIC_GPU_NVSHMEM_BC_PATH"); + return nvshmem_path_ptr; +} - public: - static absl::StatusOr Create() { - std::string pattern = "/tmp/mosaic-gpu-XXXXXX"; - if (mkdtemp(pattern.data()) == NULL) { - return absl::InternalError("Failed to create temporary directory"); - } - return TemporaryDirectory(std::move(pattern)); - } +absl::StatusOr +GetAssemblyToBinaryCompilationProvider() { + auto create_provider = []() { + // Defaults mirror those used in `xla/debug_options_flags.cc`. + constexpr se::cuda::CompilationProviderOptions::NvJitLinkMode + nvjitlink_mode = + se::cuda::CompilationProviderOptions::NvJitLinkMode::kAuto; + constexpr bool enable_llvm_module_compilation_parallelism = false; + constexpr bool enable_driver_compilation = false; + bool enable_libnvptxcompiler = se::IsLibNvPtxCompilerSupported(); + + se::cuda::CompilationProviderOptions opts( + nvjitlink_mode, enable_libnvptxcompiler, + enable_llvm_module_compilation_parallelism, enable_driver_compilation, + std::string(kDefaultCudaDataDir)); + + return absl::NoDestructor(se::cuda::AssembleCompilationProvider(opts)); + }; + static absl::NoDestructor< + absl::StatusOr>> + compilation_provider = create_provider(); + + if (!compilation_provider->ok()) { + return compilation_provider->status(); + } + return (*compilation_provider)->get(); +} - std::string_view GetPath() { return path; } +std::string CUDAErrorString(CUresult result) { + const char* error; + cuGetErrorString(result, &error); + return error; +} +// Returns if the CUDA expression returns an error. +#define CUDA_RETURN_IF_ERROR(stmt) \ + do { \ + if (CUresult result = stmt; result != CUDA_SUCCESS) { \ + return absl::InternalError(CUDAErrorString(result)); \ + } \ + } while (0) + +absl::StatusOr GetCudaComputeCapability() { + // Assumes driver has been initialized and a context exists. XLA already has + // some utilities to query this, but we try to stay runtime-agnostic, so we + // build our own here. + CUdevice device; + CUDA_RETURN_IF_ERROR(cuCtxGetDevice(&device)); + int major = 0; + CUDA_RETURN_IF_ERROR(cuDeviceGetAttribute( + &major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device)); + int minor = 0; + CUDA_RETURN_IF_ERROR(cuDeviceGetAttribute( + &minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device)); - private: - std::string path; -}; + TF_ASSIGN_OR_RETURN(std::string sm, mosaic::gpu::GetSmVersion(major, minor)); + bool has_accelerated_features = absl::EndsWith(sm, "a"); -void DumpCompilationOutput(mlir::ModuleOp module, const std::string& sm, - const std::string& ptx_isa) { - bool dump_ptx = getenv("MOSAIC_GPU_DUMP_PTX") != nullptr; - bool dump_ptxas = getenv("MOSAIC_GPU_DUMP_PTXAS") != nullptr; - bool dump_sass = getenv("MOSAIC_GPU_DUMP_SASS") != nullptr; - if (!dump_ptx && !dump_ptxas && !dump_sass) { - return; - } + using FeatureExtension = se::CudaComputeCapability::FeatureExtension; + return se::CudaComputeCapability(major, minor, + has_accelerated_features + ? FeatureExtension::kAcceleratedFeatures + : FeatureExtension::kNone); +} - module = module.clone(); // Prevent accidental modification. - absl::Cleanup module_destroyer = [module] { module->erase(); }; - auto passes = GetPassPipeline( - module.getContext(), mlir::gpu::CompilationTarget::Assembly, sm, ptx_isa); - if (mlir::failed(passes) || - mlir::failed(RunPasses(std::move(*passes), module))) { - return; - } - for (mlir::Operation& op : module.getBody()->getOperations()) { - auto binary = mlir::dyn_cast(&op); - if (!binary) { continue; } - auto objects = binary.getObjects(); - if (objects.size() != 1) { - std::cerr << "Multiple objects per gpu.binary unsupported" << std::endl; - continue; - } - auto object = mlir::cast(*objects.begin()); - std::string ptx = object.getObject().getValue().str(); - if (dump_ptx) { - std::cout << ptx << std::endl; - } - if (!dump_ptxas && !dump_sass) { continue; } // We're done. - auto tmpdir = TemporaryDirectory::Create(); - if (!tmpdir.ok()) { - std::cerr << "Failed to create a temporary directory" << std::endl; - continue; - } - std::string ptx_path = std::string(tmpdir->GetPath()) + "/kernel.ptx"; - std::string elf_path = std::string(tmpdir->GetPath()) + "/kernel.o"; - // Dump PTX into a file. - std::ofstream ptx_out(ptx_path.c_str()); - if (!ptx_out) { - std::cerr << "Failed to write PTX to a file" << std::endl; - continue; - } - ptx_out << ptx << std::endl; - // Run ptxas to generate SASS. - std::vector ptxas_args = { - "ptxas", "--opt-level", "3", - "--gpu-name", sm.c_str(), "--output-file", - elf_path.c_str(), ptx_path.c_str()}; - if (dump_ptxas) { - ptxas_args.push_back("-v"); - } - ptxas_args.push_back(nullptr); - if (auto status = RunCUDATool("ptxas", ptxas_args); !status.ok()) { - std::cerr << "ptxas invocation failed: " << status.message() << std::endl; - continue; +absl::StatusOr, bool>> Compile( + mlir::ModuleOp module) { + tsl::profiler::TraceMe trace("Compile"); + mosaic::gpu::EnsureLLVMNVPTXTargetIsRegistered(); + TF_ASSIGN_OR_RETURN(se::cuda::CompilationProvider * compilation_provider, + GetAssemblyToBinaryCompilationProvider()); + TF_ASSIGN_OR_RETURN(se::CudaComputeCapability cc, GetCudaComputeCapability()); + TF_ASSIGN_OR_RETURN(std::string sm, + mosaic::gpu::GetSmVersion(cc.major, cc.minor)); + TF_ASSIGN_OR_RETURN(std::string ptx_isa, + GetPtxIsaVersion(*compilation_provider)); + bool is_comm_used = is_nvshmem_used(module); + std::string nvshmem_path = ""; + if (is_comm_used) { + TF_ASSIGN_OR_RETURN(nvshmem_path, get_nvshmem_llvm_lib_path()); + if (!mosaic::gpu::NvshmemApi::Default(/*assert_ok=*/false).is_loaded()) { + return absl::InternalError( + "Failed to load the NVSHMEM library. Make sure it is installed (e.g. " + "`pip install nvidia-nvshmem-cu12`)."); } - if (!dump_sass) { continue; } // We're done. - // Call nvdisasm to pretty-print SASS. - if (auto status = RunCUDATool( - "nvdisasm", {"nvdisasm", "-ndf", "-c", elf_path.c_str(), nullptr}); - !status.ok()) { - std::cerr << "nvdisasm invocation failed: " << status.message() - << std::endl; - continue; + } + const char* dump_llvm = getenv("MOSAIC_GPU_DUMP_LLVM"); + const char* llvm_debug_only = getenv("MOSAIC_GPU_LLVM_DEBUG_ONLY"); +#ifndef NDEBUG + bool old_debug_state = false; + std::vector debug_only_types; + if (llvm_debug_only) { + debug_only_types = absl::StrSplit(llvm_debug_only, ','); + } + if (dump_llvm) { + debug_only_types.push_back("serialize-to-llvm"); + } + if (!debug_only_types.empty()) { + old_debug_state = llvm::DebugFlag; + std::vector debug_only_types_ptrs; + debug_only_types_ptrs.reserve(debug_only_types.size()); + for (std::string_view debug_only_type : debug_only_types) { + debug_only_types_ptrs.push_back(debug_only_type.data()); } + llvm::setCurrentDebugTypes(debug_only_types_ptrs.data(), + debug_only_types_ptrs.size()); + llvm::DebugFlag = true; + } +#else + if (llvm_debug_only || dump_llvm) { + fprintf( + stderr, + "MOSAIC_GPU_LLVM_DEBUG_ONLY or MOSAIC_GPU_DUMP_LLVM is set but LLVM " + "was built with NDEBUG\n"); + abort(); } -} - -absl::StatusOr> Compile( - mlir::ModuleOp module) { - auto sm_and_ptx_isa = GetSmAndPtxIsaVersion(); - if (!sm_and_ptx_isa.ok()) { - return sm_and_ptx_isa.status(); - } - const std::string sm = sm_and_ptx_isa.value().first; - const std::string ptx_isa = sm_and_ptx_isa.value().second; - DumpCompilationOutput(module, sm, ptx_isa); - auto passes = GetPassPipeline( - module.getContext(), mlir::gpu::CompilationTarget::Binary, sm, ptx_isa); +#endif + // Use `div.full` for float32 division---this generates better SASS. + const std::vector llvm_cl_options{"-nvptx-prec-divf32=1"}; + // Acquire a lock over the LLVM command line options here. XLA uses this + // lock to override the default LLVM command line options on a per-client + // basis. This means that failing to acquire this lock and explicitly + // setting our own command line options makes compilation dependent on + // outside state/non-deterministic. + xla::llvm_ir::LLVMCommandLineOptionsLock llvm_lock(llvm_cl_options); + auto passes = GetPassPipeline(module.getContext(), compilation_provider, cc, + sm, ptx_isa, nvshmem_path); if (mlir::failed(passes)) { return absl::InternalError("Failed to construct pass pipeline"); } - if (mlir::failed(RunPasses(std::move(*passes), module))) { + mosaic::gpu::DumpOptions dump_opts = + mosaic::gpu::GetOrSetDumpOptionsForModule(module); + if (mlir::failed(RunPasses(std::move(*passes), module, dump_opts))) { return absl::InternalError("Pass pipeline failed"); } - - llvm::SmallVector runtime_lib; - if (const char* lib_path = getenv("MOSAIC_GPU_RUNTIME_LIB_PATH")) { - runtime_lib.emplace_back(lib_path); + llvm::SmallVector runtime_libs; + if (const char* runtime_lib_path = getenv("MOSAIC_GPU_RUNTIME_LIB_PATH")) { + runtime_libs.emplace_back(runtime_lib_path); + } + if (const char* nvshmem_path = getenv("MOSAIC_GPU_NVSHMEM_SO_PATH")) { + runtime_libs.emplace_back(nvshmem_path); } // Create a transformer to run all LLVM optimization passes at the // specified optimization level. - auto transformer = mlir::makeOptimizingTransformer( - /*optLevel=*/3, /*sizeLevel=*/0, /*targetMachine=*/nullptr); + std::function transformer = + [dump_opts](llvm::Module* module) { + if (getenv("MOSAIC_GPU_DUMP_HOST_LLVM")) { + std::string ll_str; + llvm::raw_string_ostream os(ll_str); + module->print(os, nullptr); + os.flush(); + mosaic::gpu::DumpToFileOrStdout( + ll_str, dump_opts.module_basename + ".ll", dump_opts.dump_path); + } + return mlir::makeOptimizingTransformer( + /*optLevel=*/3, /*sizeLevel=*/0, /*targetMachine=*/nullptr)(module); + }; mlir::ExecutionEngineOptions options; options.transformer = transformer; options.jitCodeGenOptLevel = llvm::CodeGenOptLevel::Aggressive; - options.sharedLibPaths = runtime_lib; + options.sharedLibPaths = runtime_libs; auto maybe_execution_engine = mlir::ExecutionEngine::create(module, options); +#ifndef NDEBUG + if (llvm_debug_only || dump_llvm) { + llvm::DebugFlag = old_debug_state; + } +#endif if (!maybe_execution_engine) { return absl::InternalError("Failed to compile kernel"); } - return std::move(*maybe_execution_engine); + return std::make_pair(std::move(*maybe_execution_engine), is_comm_used); } class CompiledKernel { public: CompiledKernel(std::unique_ptr engine, void* ctx, - MosaicHostFunc* host_launch) - : engine_(std::move(engine)), ctx_(ctx), host_launch_(host_launch) {} + MosaicHostFunc* host_launch, bool is_comm_used) + : engine_(std::move(engine)), + ctx_(ctx), + host_launch_(host_launch), + is_comm_used_(is_comm_used) {} - std::tuple GetHostLaunch() { - return std::make_tuple(ctx_, host_launch_); + std::tuple GetHostLaunch() { + return std::make_tuple(ctx_, host_launch_, is_comm_used_); } private: std::unique_ptr engine_; void* ctx_; // TODO(apaszke): Destroy this properly MosaicHostFunc* host_launch_; + bool is_comm_used_; }; -using KernelHash = std::array; -using CacheKey = std::pair; - -std::pair*, absl::Mutex*> -GetKernelCache() { - static absl::Mutex mutex; - static auto& context_cache = - *new absl::flat_hash_map; - return std::make_pair(&context_cache, &mutex); -} - absl::StatusOr> GetHostAndInitFuncNames( mlir::ModuleOp module_op) { // We look for two top level C-interface functions: @@ -456,7 +559,7 @@ absl::StatusOr> GetHostAndInitFuncNames( return std::make_pair(host_func_name, init_func_name); } -absl::StatusOr CompileAndInit(const char* module) { +absl::StatusOr CompileAndInit(llvm::StringRef module) { mlir::MLIRContext context(mlir::MLIRContext::Threading::DISABLED); context.allowUnregisteredDialects(true); InitContext(&context); @@ -476,7 +579,8 @@ absl::StatusOr CompileAndInit(const char* module) { if (!maybe_engine.ok()) { return maybe_engine.status(); } - mlir::ExecutionEngine* execution_engine = maybe_engine->get(); + mlir::ExecutionEngine* execution_engine = maybe_engine.value().first.get(); + bool is_comm_used = maybe_engine.value().second; auto host_and_init_func_names = GetHostAndInitFuncNames(*module_op); if (!host_and_init_func_names.ok()) { @@ -495,77 +599,153 @@ absl::StatusOr CompileAndInit(const char* module) { void** kernel_ptr_ptr = &kernel_ptr; void*** init_args[2] = {&module_ptr_ptr, &kernel_ptr_ptr}; reinterpret_cast(*init)(init_args); - return CompiledKernel(std::move(*maybe_engine), kernel_ptr, - reinterpret_cast(*host)); + VLOG(5) << "Successfully compiled and initialized Mosaic GPU kernel"; + return CompiledKernel(std::move(maybe_engine.value().first), kernel_ptr, + reinterpret_cast(*host), is_comm_used); } +using KernelHash = std::array; + // Each compiled kernel has a unique init func, and each kernel is used from // a single HLO module. So it should be safe to not include the CUDA context // in the key. -absl::StatusOr> CachedCompileAndInit( - CacheKey key, const char* module) { - auto cache_and_mutex = GetKernelCache(); - auto* cache = cache_and_mutex.first; - auto* mutex = cache_and_mutex.second; +absl::StatusOr CachedCompileAndInit( + const KernelHash& kernel_hash, llvm::StringRef module) { + using CacheKey = std::pair; + struct Cache { + absl::Mutex mutex; + absl::flat_hash_map kernels + ABSL_GUARDED_BY(mutex); + }; + static absl::NoDestructor cache; + + CUcontext ctx; + CUDA_RETURN_IF_ERROR(cuCtxGetCurrent(&ctx)); + CacheKey key(kernel_hash, reinterpret_cast(ctx)); { // Fast path uses reader lock (as hash map look-up is relatively slow). - absl::ReaderMutexLock lock(mutex); - auto it = cache->find(key); - if (ABSL_PREDICT_TRUE(it != cache->end())) - return it->second.GetHostLaunch(); + absl::ReaderMutexLock lock(cache->mutex); + auto it = cache->kernels.find(key); + if (ABSL_PREDICT_TRUE(it != cache->kernels.end())) return &it->second; } - absl::MutexLock lock(mutex); + absl::MutexLock lock(cache->mutex); // We released the reader lock, another thread might have initialized it. - if (cache->find(key) == cache->end()) { - auto compiled = CompileAndInit(module); - if (!compiled.ok()) { - return compiled.status(); - } - cache->insert_or_assign(key, std::move(*compiled)); + if (cache->kernels.find(key) == cache->kernels.end()) { + tsl::profiler::TraceMe trace("Compilation cache miss"); + TF_ASSIGN_OR_RETURN(auto compiled, CompileAndInit(module)); + cache->kernels.insert_or_assign(key, std::move(compiled)); } - return cache->at(key).GetHostLaunch(); + return &cache->kernels.at(key); } +// TODO(b/464203195): Backward-compatible version using the legacy FFI +// API. Remove once backward compatibility window has passed. void MosaicGPUCustomCall(void* stream, void** buffers, char* opaque, size_t opaque_len, XlaCustomCallStatus* status) { - if (reinterpret_cast(opaque) % alignof(KernelHash)) { - fprintf(stderr, "Misaligned opaque pointer\n"); - abort(); - } - auto hash = *reinterpret_cast(opaque); - CUcontext ctx; - if (cuCtxGetCurrent(&ctx) != CUDA_SUCCESS) { - fprintf(stderr, "Failed to get current CUDA context\n"); - abort(); - } - CacheKey key(hash, reinterpret_cast(ctx)); - auto ctx_and_kernel = CachedCompileAndInit(key, opaque + sizeof(KernelHash)); - if (!ctx_and_kernel.ok()) { + KernelHash hash; + std::memcpy(hash.data(), opaque, sizeof(KernelHash)); + auto compiled_kernel = + CachedCompileAndInit(hash, opaque + sizeof(KernelHash)); + if (!compiled_kernel.ok()) { XlaCustomCallStatusSetFailure(status, - ctx_and_kernel.status().message().data(), - ctx_and_kernel.status().message().size()); + compiled_kernel.status().message().data(), + compiled_kernel.status().message().size()); return; } - void* args[4] = {&std::get<0>(*ctx_and_kernel), &stream, &buffers}; - std::get<1>(*ctx_and_kernel)(args); + auto ctx_kernel_comm = (*compiled_kernel)->GetHostLaunch(); + bool is_comm_used = std::get<2>(ctx_kernel_comm); + void* args[4] = {&std::get<0>(ctx_kernel_comm), &stream, &buffers}; + if (is_comm_used) { + mosaic::gpu::NvshmemApi::Default().barrier_all_on_stream( + reinterpret_cast(stream)); + } + std::get<1>(ctx_kernel_comm)(args); } XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("mosaic_gpu", &MosaicGPUCustomCall, "CUDA"); +absl::Status MosaicGpuExecute(cudaStream_t stream, ffi::RemainingArgs inputs, + ffi::RemainingRets results, + std::string_view kernel_hash, + std::string_view module, + bool use_custom_barrier) { + if (use_custom_barrier) { + return absl::UnimplementedError("Custom barrier is not supported on GPUs."); + } + if (kernel_hash.size() != sizeof(KernelHash)) { + return absl::InvalidArgumentError( + absl::StrFormat("Kernel hash size is %d bytes, expected %d bytes", + kernel_hash.size(), sizeof(KernelHash))); + } + KernelHash hash; + std::memcpy(hash.data(), kernel_hash.data(), sizeof(KernelHash)); + TF_ASSIGN_OR_RETURN(auto compiled_kernel, CachedCompileAndInit(hash, module)); + auto ctx_kernel_comm = compiled_kernel->GetHostLaunch(); + bool is_comm_used = std::get<2>(ctx_kernel_comm); + + std::vector buffers; + buffers.reserve(inputs.size() + results.size()); + for (int i = 0; i < inputs.size(); ++i) { + buffers.push_back(inputs.get(i)->untyped_data()); + if (reinterpret_cast(buffers.back()) % + mosaic::gpu::kExpectedHbmAlignment) { + return absl::InvalidArgumentError( + absl::StrFormat("Input buffer %d is not %d-byte aligned", i, + mosaic::gpu::kExpectedHbmAlignment)); + } + } + for (int i = 0; i < results.size(); ++i) { + buffers.push_back((*results.get(i))->untyped_data()); + if (reinterpret_cast(buffers.back()) % + mosaic::gpu::kExpectedHbmAlignment) { + return absl::InvalidArgumentError( + absl::StrFormat("Output buffer %d is not %d-byte aligned", i, + mosaic::gpu::kExpectedHbmAlignment)); + } + } + void** buffers_ptr = buffers.data(); + void* args[4] = {&std::get<0>(ctx_kernel_comm), &stream, &buffers_ptr}; + + if (is_comm_used) { + mosaic::gpu::NvshmemApi::Default().barrier_all_on_stream(stream); + } + std::get<1>(ctx_kernel_comm)(args); + return absl::OkStatus(); +} + +XLA_FFI_DEFINE_HANDLER(kMosaicGpuExecute, MosaicGpuExecute, + ffi::Ffi::Bind() + .Ctx>() + .RemainingArgs() + .RemainingRets() + .Attr("kernel_hash") + .Attr("module") + .Attr("use_custom_barrier"), + {ffi::Traits::kCmdBufferCompatible}); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "mosaic_gpu_v2", "CUDA", + { + /*instantiate=*/nullptr, + /*prepare=*/nullptr, + /*initialize=*/nullptr, + /*execute=*/kMosaicGpuExecute, + }); + } // namespace extern "C" { -__attribute__((visibility("default"))) -void** MosaicGpuCompile(const char* module) { - auto compiled = CompileAndInit(module); +__attribute__((visibility("default"))) void** MosaicGpuCompile( + const char* module, int num_module_bytes) { + std::string module_str(module, num_module_bytes); + auto compiled = CompileAndInit(module_str); if (!compiled.ok()) { return nullptr; } - auto [ctx, launch] = compiled->GetHostLaunch(); + auto [ctx, launch, is_comm_used] = compiled->GetHostLaunch(); auto tuple_ptr = std::unique_ptr(new void*[3]); if (!tuple_ptr) { return nullptr; @@ -579,8 +759,7 @@ void** MosaicGpuCompile(const char* module) { return tuple_ptr.release(); } -__attribute__((visibility("default"))) -void MosaicGpuUnload(void** tuple_ptr) { +__attribute__((visibility("default"))) void MosaicGpuUnload(void** tuple_ptr) { delete reinterpret_cast(tuple_ptr[2]); delete[] tuple_ptr; } diff --git a/jaxlib/mosaic/gpu/custom_call_test.cc b/jaxlib/mosaic/gpu/custom_call_test.cc new file mode 100644 index 000000000000..15efa7460dbf --- /dev/null +++ b/jaxlib/mosaic/gpu/custom_call_test.cc @@ -0,0 +1,196 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include +#include +#include "absl/base/log_severity.h" +#include "absl/log/globals.h" +#include "absl/log/scoped_mock_log.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/plugin/xla_gpu/xla_gpu_pjrt_client.h" +#include "xla/stream_executor/cuda/cuda_platform.h" // IWYU pragma: keep +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/statusor.h" + +namespace { + +using ::absl_testing::IsOk; +using ::testing::_; + +absl::Status ExecuteSync(xla::PjRtLoadedExecutable* executable) { + std::vector no_buffers; + TF_ASSIGN_OR_RETURN(auto result, + executable->Execute({no_buffers}, /*options=*/{})); + return result[0][0]->GetReadyFuture().Await(); +} + +TEST(CustomCallTest, MosaicGpuUsesCommandBuffers) { + constexpr absl::string_view kHloModule = R"( +HloModule mosaic_gpu_uses_command_buffers + +ENTRY main { + c0 = f32[] constant(0.0) + // Use several custom calls to make sure that XLA decides to wrap them inside + // a command buffer thunk. At the time of writing, the minimum number of + // thunks necessary to trigger the behavior is 5. + cc0 = f32[] custom-call(c0), + custom_call_target="mosaic_gpu_v2", api_version=API_VERSION_TYPED_FFI + cc1 = f32[] custom-call(cc0), + custom_call_target="mosaic_gpu_v2", api_version=API_VERSION_TYPED_FFI + cc2 = f32[] custom-call(cc1), + custom_call_target="mosaic_gpu_v2", api_version=API_VERSION_TYPED_FFI + cc3 = f32[] custom-call(cc2), + custom_call_target="mosaic_gpu_v2", api_version=API_VERSION_TYPED_FFI + ROOT cc4 = f32[] custom-call(cc3), + custom_call_target="mosaic_gpu_v2", api_version=API_VERSION_TYPED_FFI +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + xla::ParseAndReturnUnverifiedModule(kHloModule)); + + std::string tmp_path = testing::TempDir(); + tsl::setenv("XLA_FLAGS", absl::StrCat("--xla_dump_to=", tmp_path).c_str(), + /*overwrite=*/true); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr client, + xla::GetXlaPjrtGpuClient(/*options=*/{})); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr executable, + client->CompileAndLoad(xla::XlaComputation(module->ToProto()), + /*options=*/{})); + + // Ignore return value. Execution will fail because the custom calls don't + // wrap any valid Mosaic code, but we only care that the chosen execution + // plan uses a command buffer thunk. + ExecuteSync(executable.get()).IgnoreError(); + + // Matching the name exactly is vulnerable to renaming changes, and is not + // ideal. With that said, this seems like the most reasonable thing to do, and + // the naming scheme is relatively stable, so this is unlikely to produce + // churn. + constexpr absl::string_view kBeforeThunkPassesFilename = + "module_0001.mosaic_gpu_uses_command_buffers.thunk_sequence.txt"; + constexpr absl::string_view kAfterThunkPassesFilename = + "module_0001.mosaic_gpu_uses_command_buffers.thunk_sequence_after_thunk_" + "passes.txt"; + + // Ensure that before the thunk passes have run, the first thunk is a custom + // call thunk as expected. + std::string before_contents; + TF_CHECK_OK(tsl::ReadFileToString( + ::tsl::Env::Default(), + absl::StrCat(tmp_path, "/", kBeforeThunkPassesFilename), + &before_contents)); + EXPECT_THAT(before_contents, testing::StartsWith("001: kCustomCall")); + + // Ensure that after the thunk passes have run, the first thunk is a command + // buffer thunk (which therefore wraps the custom call thunk identified in + // the previous step). + std::string after_contents; + TF_CHECK_OK(tsl::ReadFileToString( + ::tsl::Env::Default(), + absl::StrCat(tmp_path, "/", kAfterThunkPassesFilename), &after_contents)); + + // There should be only command buffer thunks. + EXPECT_THAT(after_contents, testing::StartsWith("000: kCommandBuffer")); +} + +TEST(CustomCallTest, LegacyCustomCall) { + absl::string_view hlo_string = R"hlo( + HloModule test + + ENTRY main { + ROOT result = s32[] custom-call(), custom_call_target="mosaic_gpu", api_version=API_VERSION_STATUS_RETURNING, backend_config="\220\307\037$\222=c\235\344\250\025\261Y\233.\002\264\260\013\026\305Ol\324\355\315dA-\311\3277\"builtin.module\"() <{sym_name = \"kernel\"}> ({\n \"stable_mosaic_gpu.func.func\"() ({\n }) {function_type = (!llvm.ptr, !llvm.ptr, i64, i64, !llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> (), sym_name = \"mosaic_gpu_init_tma_desc\", sym_visibility = \"private\"} : () -> ()\n \"stable_mosaic_gpu.llvm.mlir.global\"() ({\n }) {addr_space = 4 : i32, global_type = !llvm.array<0 x i8>, linkage = #llvm.linkage, sym_name = \"global_scratch\", unnamed_addr = 0 : i64, visibility_ = 0 : i64} : () -> ()\n \"stable_mosaic_gpu.func.func\"() ({\n ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):\n %0 = \"stable_mosaic_gpu.arith.constant\"() {value = 42 : i32} : () -> i32\n %1 = \"stable_mosaic_gpu.arith.constant\"() {value = 0 : i32} : () -> i32\n %2 = \"stable_mosaic_gpu.arith.constant\"() {value = 128 : index} : () -> index\n %3 = \"stable_mosaic_gpu.arith.constant\"() {value = 1 : index} : () -> index\n %4 = \"stable_mosaic_gpu.llvm.mlir.constant\"() {value = 0 : i64} : () -> i64\n %5 = \"stable_mosaic_gpu.llvm.mlir.undef\"() : () -> !llvm.struct<(ptr, ptr, i64)>\n %6 = \"stable_mosaic_gpu.builtin.unrealized_conversion_cast\"(%arg0) : (!llvm.ptr) -> !gpu.async.token\n %7 = \"stable_mosaic_gpu.llvm.load\"(%arg1) {ordering = 0 : i64} : (!llvm.ptr) -> !llvm.ptr\n %8 = \"stable_mosaic_gpu.llvm.insertvalue\"(%5, %7) {position = array} : (!llvm.struct<(ptr, ptr, i64)>, !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64)>\n %9 = \"stable_mosaic_gpu.llvm.insertvalue\"(%8, %7) {position = array} : (!llvm.struct<(ptr, ptr, i64)>, !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64)>\n %10 = \"stable_mosaic_gpu.llvm.insertvalue\"(%9, %4) {position = array} : (!llvm.struct<(ptr, ptr, i64)>, i64) -> !llvm.struct<(ptr, ptr, i64)>\n %11 = \"stable_mosaic_gpu.builtin.unrealized_conversion_cast\"(%10) : (!llvm.struct<(ptr, ptr, i64)>) -> memref\n %12 = \"stable_mosaic_gpu.gpu.launch\"(%6, %3, %3, %3, %2, %3, %3, %1) ({\n ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index, %arg7: index, %arg8: index, %arg9: index, %arg10: index, %arg11: index, %arg12: index, %arg13: index):\n %13 = \"stable_mosaic_gpu.nvvm.elect.sync\"() : () -> i1\n %14 = \"stable_mosaic_gpu.gpu.thread_id\"() {dimension = #gpu} : () -> index\n %15 = \"stable_mosaic_gpu.arith.index_cast\"(%14) : (index) -> i32\n %16 = \"stable_mosaic_gpu.gpu.block_dim\"() {dimension = #gpu} : () -> index\n %17 = \"stable_mosaic_gpu.arith.index_cast\"(%16) : (index) -> i32\n %18 = \"stable_mosaic_gpu.gpu.thread_id\"() {dimension = #gpu} : () -> index\n %19 = \"stable_mosaic_gpu.arith.index_cast\"(%18) : (index) -> i32\n %20 = \"stable_mosaic_gpu.arith.muli\"(%19, %17) {overflowFlags = #arith.overflow} : (i32, i32) -> i32\n %21 = \"stable_mosaic_gpu.arith.addi\"(%15, %20) {overflowFlags = #arith.overflow} : (i32, i32) -> i32\n %22 = \"stable_mosaic_gpu.gpu.block_dim\"() {dimension = #gpu} : () -> index\n %23 = \"stable_mosaic_gpu.arith.index_cast\"(%22) : (index) -> i32\n %24 = \"stable_mosaic_gpu.arith.muli\"(%17, %23) {overflowFlags = #arith.overflow} : (i32, i32) -> i32\n %25 = \"stable_mosaic_gpu.gpu.thread_id\"() {dimension = #gpu} : () -> index\n %26 = \"stable_mosaic_gpu.arith.index_cast\"(%25) : (index) -> i32\n %27 = \"stable_mosaic_gpu.arith.muli\"(%26, %24) {overflowFlags = #arith.overflow} : (i32, i32) -> i32\n %28 = \"stable_mosaic_gpu.arith.addi\"(%21, %27) {overflowFlags = #arith.overflow} : (i32, i32) -> i32\n %29 = \"stable_mosaic_gpu.gpu.block_dim\"() {dimension = #gpu} : () -> index\n %30 = \"stable_mosaic_gpu.arith.index_cast\"(%29) : (index) -> i32\n %31 = \"stable_mosaic_gpu.arith.muli\"(%24, %30) {overflowFlags = #arith.overflow} : (i32, i32) -> i32\n %32 = \"stable_mosaic_gpu.arith.constant\"() {value = 5 : i32} : () -> i32\n %33 = \"stable_mosaic_gpu.arith.shrui\"(%28, %32) : (i32, i32) -> i32\n %34 = \"stable_mosaic_gpu.arith.constant\"() {value = -1 : i32} : () -> i32\n %35 = \"stable_mosaic_gpu.arith.constant\"() {value = 0 : i32} : () -> i32\n %36 = \"stable_mosaic_gpu.arith.constant\"() {value = 31 : i32} : () -> i32\n %37 = \"stable_mosaic_gpu.nvvm.shfl.sync\"(%34, %33, %35, %36) {kind = #nvvm} : (i32, i32, i32, i32) -> i32\n %38 = \"stable_mosaic_gpu.arith.constant\"() {value = 0 : i32} : () -> i32\n %39 = \"stable_mosaic_gpu.arith.cmpi\"(%37, %38) {predicate = 0 : i64} : (i32, i32) -> i1\n %40 = \"stable_mosaic_gpu.arith.andi\"(%39, %13) : (i1, i1) -> i1\n %41 = \"stable_mosaic_gpu.nvvm.elect.sync\"() : () -> i1\n %42 = \"stable_mosaic_gpu.gpu.thread_id\"() {dimension = #gpu} : () -> index\n %43 = \"stable_mosaic_gpu.arith.index_cast\"(%42) : (index) -> i32\n %44 = \"stable_mosaic_gpu.gpu.block_dim\"() {dimension = #gpu} : () -> index\n %45 = \"stable_mosaic_gpu.arith.index_cast\"(%44) : (index) -> i32\n %46 = \"stable_mosaic_gpu.gpu.thread_id\"() {dimension = #gpu} : () -> index\n %47 = \"stable_mosaic_gpu.arith.index_cast\"(%46) : (index) -> i32\n %48 = \"stable_mosaic_gpu.arith.muli\"(%47, %45) {overflowFlags = #arith.overflow} : (i32, i32) -> i32\n %49 = \"stable_mosaic_gpu.arith.addi\"(%43, %48) {overflowFlags = #arith.overflow} : (i32, i32) -> i32\n %50 = \"stable_mosaic_gpu.gpu.block_dim\"() {dimension = #gpu} : () -> index\n %51 = \"stable_mosaic_gpu.arith.index_cast\"(%50) : (index) -> i32\n %52 = \"stable_mosaic_gpu.arith.muli\"(%45, %51) {overflowFlags = #arith.overflow} : (i32, i32) -> i32\n %53 = \"stable_mosaic_gpu.gpu.thread_id\"() {dimension = #gpu} : () -> index\n %54 = \"stable_mosaic_gpu.arith.index_cast\"(%53) : (index) -> i32\n %55 = \"stable_mosaic_gpu.arith.muli\"(%54, %52) {overflowFlags = #arith.overflow} : (i32, i32) -> i32\n %56 = \"stable_mosaic_gpu.arith.addi\"(%49, %55) {overflowFlags = #arith.overflow} : (i32, i32) -> i32\n %57 = \"stable_mosaic_gpu.gpu.block_dim\"() {dimension = #gpu} : () -> index\n %58 = \"stable_mosaic_gpu.arith.index_cast\"(%57) : (index) -> i32\n %59 = \"stable_mosaic_gpu.arith.muli\"(%52, %58) {overflowFlags = #arith.overflow} : (i32, i32) -> i32\n %60 = \"stable_mosaic_gpu.arith.constant\"() {value = 5 : i32} : () -> i32\n %61 = \"stable_mosaic_gpu.arith.shrui\"(%56, %60) : (i32, i32) -> i32\n %62 = \"stable_mosaic_gpu.arith.constant\"() {value = -1 : i32} : () -> i32\n %63 = \"stable_mosaic_gpu.arith.constant\"() {value = 0 : i32} : () -> i32\n %64 = \"stable_mosaic_gpu.arith.constant\"() {value = 31 : i32} : () -> i32\n %65 = \"stable_mosaic_gpu.nvvm.shfl.sync\"(%62, %61, %63, %64) {kind = #nvvm} : (i32, i32, i32, i32) -> i32\n %66 = \"stable_mosaic_gpu.arith.constant\"() {value = 4 : i32} : () -> i32\n %67 = \"stable_mosaic_gpu.arith.remui\"(%65, %66) : (i32, i32) -> i32\n %68 = \"stable_mosaic_gpu.arith.constant\"() {value = 0 : i32} : () -> i32\n %69 = \"stable_mosaic_gpu.arith.cmpi\"(%67, %68) {predicate = 0 : i64} : (i32, i32) -> i1\n %70 = \"stable_mosaic_gpu.arith.andi\"(%69, %41) : (i1, i1) -> i1\n %71 = \"stable_mosaic_gpu.gpu.thread_id\"() {dimension = #gpu} : () -> index\n %72 = \"stable_mosaic_gpu.arith.index_cast\"(%71) : (index) -> i32\n %73 = \"stable_mosaic_gpu.gpu.block_dim\"() {dimension = #gpu} : () -> index\n %74 = \"stable_mosaic_gpu.arith.index_cast\"(%73) : (index) -> i32\n %75 = \"stable_mosaic_gpu.gpu.thread_id\"() {dimension = #gpu} : () -> index\n %76 = \"stable_mosaic_gpu.arith.index_cast\"(%75) : (index) -> i32\n %77 = \"stable_mosaic_gpu.arith.muli\"(%76, %74) {overflowFlags = #arith.overflow} : (i32, i32) -> i32\n %78 = \"stable_mosaic_gpu.arith.addi\"(%72, %77) {overflowFlags = #arith.overflow} : (i32, i32) -> i32\n %79 = \"stable_mosaic_gpu.gpu.block_dim\"() {dimension = #gpu} : () -> index\n %80 = \"stable_mosaic_gpu.arith.index_cast\"(%79) : (index) -> i32\n %81 = \"stable_mosaic_gpu.arith.muli\"(%74, %80) {overflowFlags = #arith.overflow} : (i32, i32) -> i32\n %82 = \"stable_mosaic_gpu.gpu.thread_id\"() {dimension = #gpu} : () -> index\n %83 = \"stable_mosaic_gpu.arith.index_cast\"(%82) : (index) -> i32\n %84 = \"stable_mosaic_gpu.arith.muli\"(%83, %81) {overflowFlags = #arith.overflow} : (i32, i32) -> i32\n %85 = \"stable_mosaic_gpu.arith.addi\"(%78, %84) {overflowFlags = #arith.overflow} : (i32, i32) -> i32\n %86 = \"stable_mosaic_gpu.gpu.block_dim\"() {dimension = #gpu} : () -> index\n %87 = \"stable_mosaic_gpu.arith.index_cast\"(%86) : (index) -> i32\n %88 = \"stable_mosaic_gpu.arith.muli\"(%81, %87) {overflowFlags = #arith.overflow} : (i32, i32) -> i32\n %89 = \"stable_mosaic_gpu.arith.constant\"() {value = 5 : i32} : () -> i32\n %90 = \"stable_mosaic_gpu.arith.shrui\"(%85, %89) : (i32, i32) -> i32\n %91 = \"stable_mosaic_gpu.arith.constant\"() {value = 0 : i32} : () -> i32\n %92 = \"stable_mosaic_gpu.arith.cmpi\"(%90, %91) {predicate = 0 : i64} : (i32, i32) -> i1\n %93 = \"stable_mosaic_gpu.gpu.dynamic_shared_memory\"() : () -> memref>\n %94 = \"stable_mosaic_gpu.arith.index_cast\"(%1) : (i32) -> index\n %95 = \"stable_mosaic_gpu.memref.view\"(%93, %94) : (memref>, index) -> memref<0xi8, #gpu.address_space>\n %96 = \"stable_mosaic_gpu.builtin.unrealized_conversion_cast\"(%95) {transforms = []} : (memref<0xi8, #gpu.address_space>) -> memref<0xi8, #gpu.address_space>\n \"stable_mosaic_gpu.nvvm.fence.mbarrier.init\"() : () -> ()\n \"stable_mosaic_gpu.gpu.barrier\"() : () -> ()\n \"stable_mosaic_gpu.memref.store\"(%0, %11) : (i32, memref) -> ()\n \"stable_mosaic_gpu.gpu.terminator\"() : () -> ()\n }) {operandSegmentSizes = array, workgroup_attributions = 0 : i64} : (!gpu.async.token, index, index, index, index, index, index, i32) -> !gpu.async.token\n \"stable_mosaic_gpu.func.return\"() : () -> ()\n }) {function_type = (!llvm.ptr, !llvm.ptr) -> (), llvm.emit_c_interface, sym_name = \"kernel_mosaic_gpu\"} : () -> ()\n}) {stable_mosaic_gpu.version = 6 : i64} : () -> ()\n" + } + )hlo"; + ASSERT_OK_AND_ASSIGN(auto module, + xla::ParseAndReturnUnverifiedModule(hlo_string)); + ASSERT_OK_AND_ASSIGN(std::unique_ptr client, + xla::GetXlaPjrtGpuClient(/*options=*/{})); + ASSERT_OK_AND_ASSIGN( + std::unique_ptr executable, + client->CompileAndLoad(xla::XlaComputation(module->ToProto()), + /*options=*/{})); + EXPECT_THAT(ExecuteSync(executable.get()), IsOk()); +} + +absl::string_view TestMGPUHloModule() { + // Dumped from the following JAX program: + // + // ``` + // @functools.partial( + // plgpu.pallas_call, + // out_shape=jax.ShapeDtypeStruct((), jnp.int32), + // out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + // ) + // def kernel(o_ref): + // o_ref[...] = jnp.array(42) + // ``` + return R"hlo( + HloModule test + + ENTRY main { + ROOT result = s32[] custom-call(), custom_call_target="mosaic_gpu_v2", api_version=API_VERSION_TYPED_FFI, backend_config={kernel_hash = "6f8a2b1d5e9c0f4a3b7d8e2c1a6b0f9e", module = "ML\EFR\01MLIR\00\01O\0D\01\03\05\07\09\0B\01\03\0D\037\0F\11\13\15\17\19\1B\1D\1F!#%')+-/13579;=?AC\03\12\02\C9\1D\01\BB\0F\13\0B\0B\0F\13\13\13\13\0B\07\0B\0B\13\13\0B\0F\13\13\13e\1B\0B\0F\0B\0B#\0B\0B\0B\0B;\0B\0B\0B\0B\0B\0B\0B#\0B\0B\07\0B\13\0F\0F\13\13\13\0F\13\13\0B\133\133\133U\1B\0B\C3\0B\13\13\13\13\13\13\13\13\13\17\17\17\0B\0F\1F\0F\0B\0B\13\13\0B\0B\0F\0B\0F\0B\17\0B\05\03a\07\09y111\09\03Y\0B\03U\01\15\0F\07\0F\0B\0B\1B/\17\13;\05\07)yQ\07\03E\02\AE\0A\1D3\15\03\03\9B\C5\05E\05G\11\05\01\03\03\07]\03\03\19\BF\03\03\19\C1\03\03\19\C3\05I\1F\05K\05M\03\03\07\9D\03\03\A5\09\05O\11\01\11\03\03\07\9F\03\03\07\A1\03\03\A3\C7affine_map<(d0) -> (d0)>\00\03\05-/\131\05Q\11\05\19\05S\05U\03\07\1F7\139;=\0D\0D\05W\05Y\05[\03\0DA!CEG\BB\13IK\09M\09\05]\05_\0D\19\05a\05c\05e\05g\03\07\1FQSU\13W\0D\0F\05i\0F\05k\03\03\07[\11\01\A9\11\01\01\03\03\07a\11\03\02\04\03\03\07e\11\03\05\03\03\07\09\03\03k\09\05m\03\03\17o#\05\03\11\00\00\00\00\00\00\00\00\03\03\17s#\05\03\11\01\00\00\00\00\00\00\00\03\03\17w#\05\03\11\02\00\00\00\00\00\00\00affine_map<() -> ()>\00\03\05}\7F\81\09\05o#\01\17Y\01\00\00\00\01\00\00\00\01\00\00\00\01\00\00\00\01\00\00\00\01\00\00\00\01\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\01\00\00\00\05q\17\05%O\17\05%]\17\05%k\17\05%\E1\17\05%\EF\17\05%\FD\17\05%\81\17\05%\9B\17\05%\B5\17\05%&\02\17\05%f\02\17\05%\9E\02\05s\11\01\15\11\01\D0\FF\FF\FF?\11\01}\05u\05w\03\03\07!\03\03\AB\AD\05y\01\01\1D\B1\B3\05{\1D\B5\B7\05}\17\B9\06\03\0D\05\7F#llvm.linkage\00#gpu.address_space\00#gpu\00#gpu\00#gpu\00#arith.overflow\00#nvvm\00\01\02\02\03\01\02\04\01\09\01A\17\BD\03\01\09)\05\11\15\15\05\05\15\15\05\15\01\05\05\15\15\01\15\01\01y\17\BD\03\00\FF\FF\FF\FF\FF\FF\FF\FF\09)!llvm.ptr\00!llvm.struct<(ptr, ptr, i64)>\00!llvm.array<0 x i8>\00!gpu.async.token\00\04Z\0C\05\01\11\01+\07\03\01\0D\17\11\015\07\01\1F\11\01?\07\01\17\11\01O\07\03\1F;\05\15\01\15\01\05\03\15Y\03\01\05\03\15\0B\03\01\05\03\01_\03\03\05\03\15c\03\03!\03\01g\03\05#\02\01\03\17\0F\06\01\03\1B\03\01%\07\01i\03\15\03\03\11\07\01m\03\17\05\0F\13\11\07\01q\03\17\05\15\13\11\07\01u\03\17\05\17\0D\0F\06\01\03\11\03\19'\17\01{\03\1B\11\11\0B\0B\0B\09\0B\0B\07\05\03\C1\C6\02\19\03\83\03\85\03\87\03\89\03\8B\03\8D\03\8F\03\91\03\93\03\95\03\97\03\99\19\02\01\03\07\09\03\01\0D\03\03\03\06\01\03\01\039\0B\03\01\0D\03\03\03\06\01\03\01\03=\09\03\01\0F\03\03\03\06\01\03\01\03A\07\07\01\03\03\01\05C?\0D\07\01\03\03\01\05;E\0B\03\01\0F\03\03\03\06\01\03\01\03I\07\07\01\03\03\01\05?K\09\03\01\11\03\03\03\06\01\03\01\03O\07\07\01\03\03\01\05QM\0D\07\01\03\03\01\05GS\0B\03\01\11\03\03\03\06\01\03\01\03W\07\07\01\03\03\01\05MY\05\03\01\1B\03\01\13\06\01\03\01\05U]\05\03\01#\03\01\05\03\01\0B\03\01\05\03\01%\03\01\1B\07\01'\03\01\09a_ce\05\03\01\0B\03\01\15\07\01\1D\03\07\05gi\1D\06\01\03\07\05k7\19\02\01\03\07\09\03\01\0D\03\03\03\06\01\03\01\03q\0B\03\01\0D\03\03\03\06\01\03\01\03u\09\03\01\0F\03\03\03\06\01\03\01\03y\07\07\01\03\03\01\05{w\0D\07\01\03\03\01\05s}\0B\03\01\0F\03\03\03\06\01\03\01\03\81\07\07\01\03\03\01\05w\83\09\03\01\11\03\03\03\06\01\03\01\03\87\07\07\01\03\03\01\05\89\85\0D\07\01\03\03\01\05\7F\8B\0B\03\01\11\03\03\03\06\01\03\01\03\8F\07\07\01\03\03\01\05\85\91\05\03\01\1B\03\01\13\06\01\03\01\05\8D\95\05\03\01#\03\01\05\03\01\0B\03\01\05\03\01%\03\01\1B\07\01'\03\01\09\99\97\9B\9D\05\03\01\A7\03\01+\06\01\03\01\05\9F\A1\05\03\01\0B\03\01\15\07\01\1D\03\07\05\A3\A5\1D\06\01\03\07\05\A7o\09\03\01\0D\03\03\03\06\01\03\01\03\AB\0B\03\01\0D\03\03\03\06\01\03\01\03\AF\09\03\01\0F\03\03\03\06\01\03\01\03\B3\07\07\01\03\03\01\05\B5\B1\0D\07\01\03\03\01\05\AD\B7\0B\03\01\0F\03\03\03\06\01\03\01\03\BB\07\07\01\03\03\01\05\B1\BD\09\03\01\11\03\03\03\06\01\03\01\03\C1\07\07\01\03\03\01\05\C3\BF\0D\07\01\03\03\01\05\B9\C5\0B\03\01\11\03\03\03\06\01\03\01\03\C9\07\07\01\03\03\01\05\BF\CB\05\03\01\1B\03\01\13\06\01\03\01\05\C7\CF\05\03\01\0B\03\01\15\07\01\1D\03\07\05\D1\D3-\02\01\03\13\03\06\01\03\03\03\07/\06\01\03\0B\05\D7\D9\0F\07\01\A9\03\0B\03\DB1\00\013\00\015\04\AF\05\05\1B7\00\01)\00\01\06\03\01\05\01\00\9E\0E\81g\0B\0D\17\15\0B\1D/)\13%-\19\1B\1F\11\19\17\11\1F3\19\0F5\1D\15\13\13\0D\05\1F\1B\193\195\19\19\17\15!'#\17\1F!\15\17\19#G\17\1D\1D\17\1F#\0F\0B\0D\09\0B%\11builtin\00stable_mosaic_gpu\00llvm\00gpu\00arith\00nvvm\00module\00arith.index_cast\00arith.constant\00arith.muli\00gpu.thread_id\00gpu.block_dim\00arith.addi\00builtin.unrealized_conversion_cast\00llvm.insertvalue\00arith.shrui\00arith.cmpi\00func.func\00nvvm.elect.sync\00nvvm.shfl.sync\00arith.andi\00llvm.mlir.global\00llvm.mlir.constant\00llvm.mlir.undef\00llvm.load\00gpu.launch\00func.return\00arith.remui\00gpu.dynamic_shared_memory\00memref.view\00nvvm.fence.mbarrier.init\00gpu.barrier\00memref.store\00gpu.terminator\00-\00value\00sym_name\00position\00dimension\00function_type\00stable_mosaic_gpu.version\00kernel\00pallas_call\00mosaic_gpu_init_tma_desc\00sym_visibility\00private\00addr_space\00global_type\00linkage\00global_scratch\00unnamed_addr\00visibility_\00llvm.emit_c_interface\00kernel_mosaic_gpu\00ordering\00operandSegmentSizes\00workgroup_attributions\00overflowFlags\00kind\00predicate\00transforms\00swap:\00swap\00third_party/py/jax/tests/pallas/mosaic_gpu_test.py\00", use_custom_barrier = false} + } + )hlo"; +} + +TEST(CustomCallTest, KernelCompilationIsCached) { + ASSERT_OK_AND_ASSIGN( + auto module, xla::ParseAndReturnUnverifiedModule(TestMGPUHloModule())); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr client, + xla::GetXlaPjrtGpuClient(/*options=*/{})); + ASSERT_OK_AND_ASSIGN( + std::unique_ptr executable, + client->CompileAndLoad(xla::XlaComputation(module->ToProto()), + /*options=*/{})); + + absl::SetVLogLevel("custom_call", 5); + { + absl::ScopedMockLog log; + EXPECT_CALL(log, + Log(absl::LogSeverity::kInfo, _, + "Successfully compiled and initialized Mosaic GPU kernel")) + .Times(1); + log.StartCapturingLogs(); + EXPECT_THAT(ExecuteSync(executable.get()), IsOk()); + } + + { + // The second execution the compilation should be cached. + absl::ScopedMockLog log; + EXPECT_CALL(log, + Log(absl::LogSeverity::kInfo, _, + "Successfully compiled and initialized Mosaic GPU kernel")) + .Times(0); + log.StartCapturingLogs(); + EXPECT_THAT(ExecuteSync(executable.get()), IsOk()); + } +} + +} // namespace diff --git a/jaxlib/mosaic/gpu/dump.cc b/jaxlib/mosaic/gpu/dump.cc new file mode 100644 index 000000000000..4474595ec607 --- /dev/null +++ b/jaxlib/mosaic/gpu/dump.cc @@ -0,0 +1,366 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/mosaic/gpu/dump.h" + +#if defined(__APPLE__) +// This is the fix recommended by +// https://www.gnu.org/software/gnulib/manual/html_node/environ.html to make +// sure accessing `environ` works on Apple platforms. +#include +#define environ (*_NSGetEnviron()) +#endif +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // NOLINT +#include +#include + +#include "jaxlib/mosaic/gpu/library_paths.h" +#include "absl/cleanup/cleanup.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/numeric/bits.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_split.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Support/LLVM.h" +#include "tsl/platform/path.h" + +namespace mosaic { +namespace gpu { + +namespace { + +class TemporaryDirectory { + private: + TemporaryDirectory(std::string path) : path(std::move(path)) {} + // TODO(apaszke): Unlink in destructor. + + public: + static absl::StatusOr Create() { + std::string pattern = "/tmp/mosaic-gpu-XXXXXX"; + if (mkdtemp(pattern.data()) == NULL) { + return absl::InternalError("Failed to create temporary directory"); + } + return TemporaryDirectory(std::move(pattern)); + } + + std::string_view GetPath() { return path; } + + private: + std::string path; +}; + +absl::StatusOr RunCUDATool(const char* tool, + const std::vector& args, + bool stderr_to_stdout = true) { + CHECK(!args.empty() && args.back() == nullptr); + const char* cuda_path_ptr = mosaic::gpu::GetCUDARoot(); + if (!cuda_path_ptr) + return absl::InternalError("Failed to get the CUDA toolkit path"); + std::string tool_path(cuda_path_ptr); + tool_path += "/bin/"; + tool_path += tool; + int stdout_pipe[2] = {-1, -1}; + pid_t child_pid; + posix_spawn_file_actions_t file_actions; + if (posix_spawn_file_actions_init(&file_actions)) { + return absl::InternalError("Failed to initialize spawn file actions"); + } + absl::Cleanup file_actions_destroyer = [&file_actions] { + posix_spawn_file_actions_destroy(&file_actions); + }; + if (pipe(stdout_pipe) == -1) { + return absl::InternalError("Failed to set up pipe"); + } + absl::Cleanup pipe_closer = [&stdout_pipe] { + if (stdout_pipe[0] != -1) close(stdout_pipe[0]); + if (stdout_pipe[1] != -1) close(stdout_pipe[1]); + }; + // close read end in child + if (posix_spawn_file_actions_addclose(&file_actions, stdout_pipe[0])) { + return absl::InternalError("Failed to close read end of the pipe in child"); + } + if (posix_spawn_file_actions_adddup2(&file_actions, stdout_pipe[1], + STDOUT_FILENO)) { + return absl::InternalError("Failed to redirect stdout to pipe"); + } + if (stderr_to_stdout && posix_spawn_file_actions_adddup2( + &file_actions, STDOUT_FILENO, STDERR_FILENO)) { + return absl::InternalError("Failed to redirect stderr to stdout"); + } + // execv is guaranteed by POSIX to not modify the args (other than + // replacing the whole process image), so the const_cast is valid. + if (int status = + posix_spawn(&child_pid, tool_path.c_str(), &file_actions, nullptr, + const_cast(args.data()), environ)) { + return absl::InternalError( + absl::StrCat("Process spawn failed: ", strerror(status))); + } + // Proactively close write end in parent. If we don't do this, read + // will block since the pipe will have an open write end in the + // parent process. + if (close(stdout_pipe[1]) == -1) { + return absl::InternalError( + absl::StrCat("Failed to close write end of pipe in parent process: ", + strerror(errno))); + } + // Mark the write end as successfully closed, so it doesn't get + // closed a second time by the deferred pipe_closer. + stdout_pipe[1] = -1; + std::string stdout; + char buf[1024]; + while (int bytes_read = read(stdout_pipe[0], buf, sizeof buf)) { + if (bytes_read == -1) { + return absl::InternalError( + absl::StrCat("Failed to read from pipe: ", strerror(errno))); + } + stdout.append(buf, bytes_read); + } + int status; + if (waitpid(child_pid, &status, 0) == -1) { + return absl::InternalError("Failed to wait for CUDA tool invocation"); + } + if (status != 0) { + std::string error_message = "CUDA tool failed"; + if (!stdout.empty()) { + error_message += ": "; + error_message += stdout; + } + return absl::InternalError(error_message); + } + return stdout; +} + +// Parse the SASS and reformat control codes following NervanaSystems/maxas. +std::string FormatSassCtrl(const std::string& sass) { + std::string result; + result.reserve(sass.size()); + std::vector lines = absl::StrSplit(sass, '\n'); + for (int i = 0; i < lines.size(); ++i) { + std::string_view line = lines[i]; + if (i + 1 < lines.size()) { + const std::string& next_line = lines[i + 1]; + size_t first_hex_start = line.rfind("/* 0x"); + size_t first_instr_end = line.rfind(';'); + size_t second_hex_start = next_line.rfind("/* 0x"); + bool second_line_empty = true; + if (second_hex_start != std::string::npos) { + for (size_t i = 0; i < second_hex_start; ++i) { + second_line_empty &= next_line[i] == ' '; + } + } + if (first_hex_start != std::string::npos && + first_instr_end != std::string::npos && + second_hex_start != std::string::npos && second_line_empty) { + line = line.substr(0, first_instr_end); + std::string hex_str = next_line.substr(second_hex_start + 5, 16); + uint64_t ctrl; + if (absl::SimpleHexAtoi(hex_str, &ctrl)) { + uint64_t stall = (ctrl >> 41) & 0xf; + uint64_t yield = (ctrl >> 45) & 0x1; + uint64_t write_barrier = (ctrl >> 46) & 0x7; + uint64_t read_barrier = (ctrl >> 49) & 0x7; + uint64_t wait_barrier = (ctrl >> 52) & 0x3f; + std::string wait_barrier_str; + if (wait_barrier == 0) { + result += " -"; + } else if (absl::has_single_bit(wait_barrier)) { + absl::StrAppendFormat(&result, " %d", + absl::countr_zero(wait_barrier)); + } else { + int first_set = absl::countr_zero(wait_barrier); + uint64_t without_first_set = wait_barrier ^ (1 << first_set); + if (absl::has_single_bit(without_first_set)) { + absl::StrAppendFormat(&result, " %d&%d", + absl::countr_zero(without_first_set), + first_set); + } else { + absl::StrAppendFormat(&result, "0x%02x", wait_barrier); + } + } + absl::StrAppendFormat( + &result, ":%c:%c:%c:%02llu", + read_barrier == 7 ? '-' : ('0' + read_barrier), + write_barrier == 7 ? '-' : ('0' + write_barrier), + yield ? 'Y' : '-', stall); + } + i++; // Skip the hex line. + } + } + result += line; + result.append("\n"); + } + return result; +} + +// The name of the attribute wrapping the module basename for dumping. See +// `GetDumpOptionsForModule` for more details. +constexpr std::string_view kDumpBasenameAttr = "mosaic_gpu.dump_basename"; + +} // namespace + +DumpOptions GetOrSetDumpOptionsForModule(mlir::ModuleOp module) { + // Use a static variable in order to ensure that subsequent compilations of + // modules that share the same name will result in distinct dumps. + static std::atomic dumped_module_count = 0; + DumpOptions opts; + // In order to make sure that we use a consistent module basename for the same + // module even if we end up calling this function multiple times, we set an + // attribute on the module that records its basename whenever we first + // generate it. Subsequent calls will just return the value from the + // attribute. + if (auto attr = module->getAttrOfType(kDumpBasenameAttr)) { + opts.module_basename = attr.getValue().str(); + } else { + int current_count = dumped_module_count.fetch_add(1); + if (std::optional name = module.getName(); + name.has_value()) { + opts.module_basename = absl::StrCat(name->str(), "_", current_count); + } else { + opts.module_basename = absl::StrCat("mosaic_gpu_module_", current_count); + } + module->setAttr( + kDumpBasenameAttr, + mlir::StringAttr::get(module.getContext(), opts.module_basename)); + } + + if (char* dump_to = getenv("MOSAIC_GPU_DUMP_TO"); dump_to != nullptr) { + // "sponge" is a special value, which if set, will result in the files being + // dumped to a directory path specified in the `TEST_UNDECLARED_OUTPUTS_DIR` + // environment variable. + if (std::string_view(dump_to) == "sponge") { + if (char* dump_dir = getenv("TEST_UNDECLARED_OUTPUTS_DIR"); + dump_dir != nullptr) { + opts.dump_path = dump_dir; + } else { + LOG(WARNING) << "\"sponge\" specified as dump directory but " + "TEST_UNDECLARED_OUTPUTS_DIR is not set! " + "Will dump to stdout instead."; + } + } else if (std::string_view(dump_to) == "-") { + // Dump to stdout. + opts.dump_path = ""; + } else { + opts.dump_path = dump_to; + } + + opts.mlir_passes = true; + opts.ptx = true; + opts.ptxas = true; + opts.sass = true; + opts.sass_ctrl = true; + return opts; + } + + opts.mlir_passes = getenv("MOSAIC_GPU_DUMP_MLIR_PASSES") != nullptr; + opts.ptx = getenv("MOSAIC_GPU_DUMP_PTX") != nullptr; + opts.ptxas = getenv("MOSAIC_GPU_DUMP_PTXAS") != nullptr; + opts.sass_ctrl = getenv("MOSAIC_GPU_DUMP_SASS_CTRL") != nullptr; + opts.sass = getenv("MOSAIC_GPU_DUMP_SASS") != nullptr || opts.sass_ctrl; + return opts; +} + +void DumpToFileOrStdout(std::string_view content, std::string_view name, + std::string_view path) { + if (path.empty()) { + std::cout << content << std::endl; + return; + } + std::error_code error; + llvm::raw_fd_ostream out_file(tsl::io::JoinPath(path, name), error, + llvm::sys::fs::OF_None); + if (error) { + LOG(ERROR) << error.message(); + LOG(ERROR) << "Output will be written to stdout instead."; + std::cout << content << std::endl; + return; + } + out_file << content << "\n"; +} + +void DumpSass(mlir::gpu::BinaryOp binary, std::string_view path, + std::string_view basename, bool include_sass_ctrl) { + auto objects = binary.getObjects(); + if (objects.size() != 1) { + std::cerr << "Multiple objects per gpu.binary unsupported" << std::endl; + return; + } + auto object = mlir::cast(*objects.begin()); + if (object.getFormat() != mlir::gpu::CompilationTarget::Binary) { + std::cerr << "gpu.binary object is not in binary format" << std::endl; + return; + } + std::string elf = object.getObject().getValue().str(); + auto tmpdir = TemporaryDirectory::Create(); + if (!tmpdir.ok()) { + std::cerr << "Failed to create a temporary directory" << std::endl; + return; + } + std::string elf_path = std::string(tmpdir->GetPath()) + "/kernel.bin"; + // Dump ELF into a file. + std::ofstream elf_out(elf_path.c_str()); + if (!elf_out) { + std::cerr << "Failed to write binary to a file" << std::endl; + return; + } + elf_out << elf << std::endl; + // Call nvdisasm to pretty-print SASS. + std::vector nvdisasm_args = {"nvdisasm", "-ndf", "-c", + elf_path.c_str()}; + if (include_sass_ctrl) { + nvdisasm_args.push_back("-hex"); + } + nvdisasm_args.push_back(nullptr); + auto result = RunCUDATool("nvdisasm", nvdisasm_args); + if (!result.ok()) { + std::cerr << "nvdisasm invocation failed: " << result.status() << std::endl; + return; + } + + if (include_sass_ctrl) { + mosaic::gpu::DumpToFileOrStdout(FormatSassCtrl(*result), + absl::StrCat(basename, ".sass_ctrl"), path); + } else { + // Dump SASS. + mosaic::gpu::DumpToFileOrStdout(*result, absl::StrCat(basename, ".sass"), + path); + } +} + +} // namespace gpu +} // namespace mosaic diff --git a/jaxlib/mosaic/gpu/dump.h b/jaxlib/mosaic/gpu/dump.h new file mode 100644 index 000000000000..d1db4a0333fe --- /dev/null +++ b/jaxlib/mosaic/gpu/dump.h @@ -0,0 +1,77 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_MOSAIC_GPU_DUMP_H_ +#define JAXLIB_MOSAIC_GPU_DUMP_H_ + +#include +#include + +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/IR/BuiltinOps.h" + +namespace mosaic { +namespace gpu { + +struct DumpOptions { + // Whether to dump the MLIR module before and after each pass. + bool mlir_passes = false; + // Whether to dump the PTX resulting from the compilation. + bool ptx = false; + // Whether to run ptxas in verbose mode. + bool ptxas = false; + // Whether to dump the SASS resulting from the compilation. If both `sass` + // and `sass_ctrl` are true, a single dump containing both will be + // generated. + bool sass = false; + // Whether to dump the SASS control codes following NervanaSystems/maxas. If + // both `sass` and `sass_ctrl` are true, a single dump containing both will be + // generated. + bool sass_ctrl = false; + // Where to dump the output files. If empty, dump to stdout. + std::string dump_path = ""; + // The basename to use when dumping files. + std::string module_basename; +}; + +// Extracts the dump options for the given module from environment variables. +// +// This function takes in a module in order to ensure that subsequent +// compilations of modules that share the same name will result in distinct +// dumps. The module is annotated with an attribute that records the basename +// used for dumps, to ensure that we use a consistent module basename for the +// same module even if we end up calling this function multiple times. +DumpOptions GetOrSetDumpOptionsForModule(mlir::ModuleOp module); + +// Dumps `content` to `path`/`name` if `path` is non-empty, otherwise to +// stdout. +void DumpToFileOrStdout(std::string_view content, std::string_view name, + std::string_view path); + +// Dumps the SASS for the given binary op. +// +// The dump will be written to `path`/`basename`.sass if `include_sass_ctrl` is +// false, or `path`/`basename`.sass_ctrl if it is true. In this latter case, +// SASS control codes will be included in the dump, following +// NervanaSystems/maxas. +// +// If `path` is empty, the dump will be written to stdout instead. +void DumpSass(mlir::gpu::BinaryOp binary, std::string_view path, + std::string_view basename, bool include_sass_ctrl); + +} // namespace gpu +} // namespace mosaic + +#endif // JAXLIB_MOSAIC_GPU_DUMP_H_ diff --git a/jaxlib/mosaic/gpu/dump_test.cc b/jaxlib/mosaic/gpu/dump_test.cc new file mode 100644 index 000000000000..627572354fd1 --- /dev/null +++ b/jaxlib/mosaic/gpu/dump_test.cc @@ -0,0 +1,55 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/mosaic/gpu/dump.h" + +#include +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" + +namespace { + +TEST(DumpTest, GetOrSetDumpOptionsForModuleReturnsConsistentBasenameForModule) { + mlir::MLIRContext ctx; + mlir::OwningOpRef module = mlir::ModuleOp::create( + mlir::UnknownLoc::get(&ctx), /*name=*/"test_module"); + mosaic::gpu::DumpOptions opts1 = + mosaic::gpu::GetOrSetDumpOptionsForModule(*module); + mosaic::gpu::DumpOptions opts2 = + mosaic::gpu::GetOrSetDumpOptionsForModule(*module); + // The module basename should be consistent across calls for the same module. + EXPECT_EQ(opts1.module_basename, opts2.module_basename); +} + +TEST( + DumpTest, + GetOrSetDumpOptionsForModuleReturnsConsistentBasenameForDifferentModulesWithTheSameName) { // NOLINT(whitespace/line_length) + mlir::MLIRContext ctx; + mlir::OwningOpRef module1 = mlir::ModuleOp::create( + mlir::UnknownLoc::get(&ctx), /*name=*/"test_module"); + mlir::OwningOpRef module2 = mlir::ModuleOp::create( + mlir::UnknownLoc::get(&ctx), /*name=*/"test_module"); + mosaic::gpu::DumpOptions opts1 = + mosaic::gpu::GetOrSetDumpOptionsForModule(*module1); + mosaic::gpu::DumpOptions opts2 = + mosaic::gpu::GetOrSetDumpOptionsForModule(*module2); + // The module basename should be different for different modules, even though + // they have the same name. + EXPECT_NE(opts1.module_basename, opts2.module_basename); +} + +} // namespace diff --git a/jaxlib/mosaic/gpu/gpu_module_to_assembly.cc b/jaxlib/mosaic/gpu/gpu_module_to_assembly.cc new file mode 100644 index 000000000000..321b974e0bdf --- /dev/null +++ b/jaxlib/mosaic/gpu/gpu_module_to_assembly.cc @@ -0,0 +1,239 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/mosaic/gpu/gpu_module_to_assembly.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/base/call_once.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/IR/Module.h" +#include "llvm/IRReader/IRReader.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/LogicalResult.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/TargetSelect.h" +#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Target/LLVM/ModuleToObject.h" +#include "jaxlib/mosaic/pass_boilerplate.h" +#include "xla/service/gpu/llvm_gpu_backend/load_ir_module.h" + +namespace mosaic { +namespace gpu { + +namespace { + +using ::llvm::failure; +using ::llvm::FailureOr; +using ::llvm::LogicalResult; +using ::llvm::SmallVector; +using ::mlir::Attribute; +using ::mlir::gpu::GPUModuleOp; + +// A replacement class for the upstream `NVPTXSerializer` and +// `SerializeGpuModuleBase` classes. +class ModuleToAssembly : public mlir::LLVM::ModuleToObject { + public: + ModuleToAssembly(gpu::GPUModuleOp gpu_module, + ::mlir::NVVM::NVVMTargetAttr target, + std::vector libraries_to_link) + : ModuleToObject(*gpu_module, target.getTriple(), target.getChip(), + target.getFeatures(), target.getO()), + libraries_to_link_(std::move(libraries_to_link)) {}; + + // Serializes the LLVM module to PTX. + FailureOr> moduleToObject( + llvm::Module& llvm_module) override; + + // Loads the bitcode files in `libraries_to_link_`. + std::optional>> loadBitcodeFiles( + llvm::Module& module) override; + + private: + std::vector libraries_to_link_; +}; + +FailureOr> ModuleToAssembly::moduleToObject( + llvm::Module& llvm_module) { + // Use a debug type compatible with upstream. +#define DEBUG_TYPE "serialize-to-llvm" + LLVM_DEBUG({ llvm::dbgs() << llvm_module; }); +#undef DEBUG_TYPE + std::optional machine = getOrCreateTargetMachine(); + if (!machine) { + return getOperation().emitError() + << "Target Machine unavailable for " + "triple " + << triple << ", can't optimize with LLVM\n"; + } + llvm::FailureOr ptx = translateModuleToISA( + llvm_module, **machine, [&]() { return getOperation().emitError(); }); + if (failed(ptx)) { + return getOperation().emitError() << "Failed translating the module" + "to PTX."; + } + + return SmallVector(ptx->begin(), ptx->end()); +} + +std::optional>> +ModuleToAssembly::loadBitcodeFiles(llvm::Module& llvm_module) { + llvm::LLVMContext& ctx = llvm_module.getContext(); + llvm::SMDiagnostic err; + SmallVector> loaded_modules; + loaded_modules.reserve(libraries_to_link_.size()); + for (const std::string& library_path : libraries_to_link_) { + std::unique_ptr library_module = + xla::gpu::LoadIRModule(library_path, &ctx); + if (!library_module) { + getOperation().emitError() << "Failed loading file from " << library_path + << ", error: " << err.getMessage(); + return std::nullopt; + } + loaded_modules.push_back(std::move(library_module)); + } + return loaded_modules; +} + +} // namespace + +namespace internal { + +// A simplified version of the logic implemented by `NVVMTargetAttrImpl`'s +// `serializeToObject` and `createObject` to deal better with different +// environments. +LogicalResult LowerGpuModuleToAssembly( + GPUModuleOp gpu_module, const std::vector& libraries_to_link) { + EnsureLLVMNVPTXTargetIsRegistered(); + mlir::gpu::OffloadingLLVMTranslationAttrInterface handler(nullptr); + mlir::OpBuilder builder(gpu_module->getContext()); + SmallVector objects; + // Fail if there are no target attributes + if (gpu_module.getTargetsAttr().size() != 1) { + return gpu_module.emitError( + "Expected exactly one target attribute, but got ") + << gpu_module.getTargetsAttr().size(); + } + + auto target_attr = llvm::dyn_cast( + gpu_module.getTargetsAttr()[0]); + if (!target_attr) { + return gpu_module.emitError( + "Target attribute is not of type NVVMTargetAttr"); + } + + ModuleToAssembly serializer(gpu_module, target_attr, libraries_to_link); + std::optional> assembly = serializer.run(); + if (!assembly) { + gpu_module.emitError("An error happened while serializing the module."); + return mlir::failure(); + } + + SmallVector properties{ + builder.getNamedAttr("O", builder.getI32IntegerAttr(target_attr.getO()))}; + + Attribute object = builder.getAttr( + target_attr, mlir::gpu::CompilationTarget::Assembly, + builder.getStringAttr( + llvm::StringRef(assembly->data(), assembly->size())), + builder.getDictionaryAttr(properties), /*kernels=*/nullptr); + + if (!object) { + gpu_module.emitError("An error happened while creating the object."); + return mlir::failure(); + } + + builder.setInsertionPointAfter(gpu_module); + mlir::gpu::BinaryOp::create( + builder, gpu_module.getLoc(), gpu_module.getName(), /*handler=*/nullptr, + builder.getArrayAttr(SmallVector{object})); + gpu_module->erase(); + return mlir::success(); +} + +} // namespace internal + +namespace { + +class GpuModuleToAssemblyPass + : public jaxlib::mlir::Pass { + public: + using jaxlib::mlir::Pass::Pass; + + GpuModuleToAssemblyPass() = default; + GpuModuleToAssemblyPass(const GpuModuleToAssemblyPass&) {}; + + static constexpr llvm::StringLiteral kArgumentName = + "mosaic-gpu-module-to-assembly"; + static constexpr llvm::StringLiteral kPassName = "GpuModuleToAssemblyPass"; + + void runOnOperation() override { + mlir::ModuleOp module = getOperation(); + module.walk([&](mlir::gpu::GPUModuleOp gpu_module) { + if (mlir::failed(internal::LowerGpuModuleToAssembly( + gpu_module, libraries_to_link_))) { + gpu_module.emitError("Failed to lower GPU module to assembly."); + return mlir::WalkResult::interrupt(); + } + return mlir::WalkResult::advance(); + }); + } + + private: + ListOption libraries_to_link_{ + *this, "libraries-to-link", + llvm::cl::desc("A comma-separated list of bitcode files to link into the " + "resulting assembly.")}; +}; + +} // namespace + +void registerGpuModuleToAssemblyPass() { + ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { + return std::make_unique(); + }); +} + +void EnsureLLVMNVPTXTargetIsRegistered() { + static absl::once_flag register_nvptx_target_flag; + absl::call_once(register_nvptx_target_flag, []() { + LLVMInitializeNVPTXTarget(); + LLVMInitializeNVPTXTargetInfo(); + LLVMInitializeNVPTXTargetMC(); + LLVMInitializeNVPTXAsmPrinter(); + }); +} + +} // namespace gpu +} // namespace mosaic diff --git a/jaxlib/mosaic/gpu/gpu_module_to_assembly.h b/jaxlib/mosaic/gpu/gpu_module_to_assembly.h new file mode 100644 index 000000000000..d89ed5a9fc55 --- /dev/null +++ b/jaxlib/mosaic/gpu/gpu_module_to_assembly.h @@ -0,0 +1,46 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_MOSAIC_GPU_MODULE_TO_ASSEMBLY_H_ +#define JAXLIB_MOSAIC_GPU_MODULE_TO_ASSEMBLY_H_ + +#include +#include + +#include "llvm/Support/LogicalResult.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" + +namespace mosaic { +namespace gpu { + +// Registers a pass that converts `gpu.module` ops into `gpu.binary` ops +// wrapping a PTX assembly. +void registerGpuModuleToAssemblyPass(); + +// Initializes the NVPTX target on first call. +void EnsureLLVMNVPTXTargetIsRegistered(); + +namespace internal { + +// Implements the main logic of the pass. This is exposed for testing only. +llvm::LogicalResult LowerGpuModuleToAssembly( + mlir::gpu::GPUModuleOp gpu_module, + const std::vector& libraries_to_link); +} // namespace internal + +} // namespace gpu +} // namespace mosaic + +#endif // JAXLIB_MOSAIC_GPU_MODULE_TO_ASSEMBLY_H_ diff --git a/jaxlib/mosaic/gpu/gpu_module_to_assembly_test.cc b/jaxlib/mosaic/gpu/gpu_module_to_assembly_test.cc new file mode 100644 index 000000000000..7ac23ae6b07c --- /dev/null +++ b/jaxlib/mosaic/gpu/gpu_module_to_assembly_test.cc @@ -0,0 +1,236 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/mosaic/gpu/gpu_module_to_assembly.h" + +#include +#include +#include +#include + +#include +#include +#include "absl/strings/str_cat.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" +#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" +#include "mlir/Conversion/LLVMCommon/StructBuilder.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" +#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" +#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" +#include "mlir/Conversion/UBToLLVM/UBToLLVM.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/WalkResult.h" +#include "mlir/Target/LLVM/NVVM/Target.h" +#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" +#include "jaxlib/mosaic/dialect/gpu/mosaic_gpu.h" +#include "xla/service/gpu/llvm_gpu_backend/nvptx_libdevice_path.h" + +namespace { + +using ::testing::HasSubstr; + +class GpuModuleToAssemblyTest : public ::testing::Test { + public: + GpuModuleToAssemblyTest() + : builder_(&context_), + module_(mlir::OwningOpRef( + mlir::ModuleOp::create(builder_.getUnknownLoc(), "module"))) { + RegisterErrorRecordingHandler(); + mlir::DialectRegistry registry; + registry.insert(); + mlir::registerGPUDialectTranslation(registry); + mlir::registerLLVMDialectTranslation(registry); + mlir::registerNVVMDialectTranslation(registry); + builder_.setInsertionPointToEnd(module_->getBody()); + context_.appendDialectRegistry(registry); + context_.loadAllAvailableDialects(); + + mosaic::gpu::registerGpuModuleToAssemblyPass(); + } + + void ExpectLastErrorContains(std::string_view substring) { + EXPECT_THAT(last_error_message_, HasSubstr(substring)); + } + + protected: + mlir::MLIRContext context_; + mlir::OpBuilder builder_; + mlir::OwningOpRef module_; + std::string last_error_message_; + + private: + void RegisterErrorRecordingHandler() { + // Make sure to make the context single-threaded to avoid race conditions + // when recording the last error message. + context_.disableMultithreading(); + mlir::DiagnosticEngine& diagnostic_engine = context_.getDiagEngine(); + diagnostic_engine.registerHandler([&](mlir::Diagnostic& diagnostic) { + last_error_message_ = diagnostic.str(); + }); + } +}; + +mlir::gpu::GPUModuleOp CreateGpuModuleWithEmptyFunc(mlir::OpBuilder& b, + mlir::ArrayAttr targets) { + mlir::gpu::GPUModuleOp gpu_module = mlir::gpu::GPUModuleOp::create( + b, b.getUnknownLoc(), "gpu_module", targets); + b.setInsertionPointToEnd(gpu_module.getBody()); + + mlir::LLVM::LLVMFunctionType func_ty = mlir::LLVM::LLVMFunctionType::get( + mlir::LLVM::LLVMVoidType::get(b.getContext()), {}); + mlir::LLVM::LLVMFuncOp func = + mlir::LLVM::LLVMFuncOp::create(b, b.getUnknownLoc(), "gpu_func", func_ty); + b.setInsertionPointToEnd(func.addEntryBlock(b)); + + mlir::LLVM::ReturnOp::create(b, b.getUnknownLoc(), mlir::ValueRange()); + + b.setInsertionPointAfter(gpu_module); + return gpu_module; +} + +template +mlir::SmallVector GetOpsOfType(mlir::ModuleOp module) { + mlir::SmallVector ops; + module.walk([&](T binary) { ops.push_back(binary); }); + return ops; +} + +mlir::NVVM::NVVMTargetAttr GetNVVMTargetAttr(mlir::MLIRContext* ctx) { + return mlir::NVVM::NVVMTargetAttr::get( + ctx, /*optLevel=*/3, /*triple=*/"nvptx64-nvidia-cuda", + /*chip=*/"sm_90a", /*features=*/"+ptx87"); +} + +TEST_F(GpuModuleToAssemblyTest, ConvertGpuModuleWithNVVMAttributeToAssembly) { + mlir::gpu::GPUModuleOp gpu_module = CreateGpuModuleWithEmptyFunc( + builder_, + mlir::ArrayAttr::get(&context_, {GetNVVMTargetAttr(&context_)})); + EXPECT_TRUE(mlir::succeeded(mosaic::gpu::internal::LowerGpuModuleToAssembly( + gpu_module, /*libraries_to_link=*/{}))); + + EXPECT_EQ(GetOpsOfType(*module_).size(), 0); + mlir::SmallVector binary_ops = + GetOpsOfType(*module_); + ASSERT_EQ(binary_ops.size(), 1); + + auto binary = binary_ops.front(); + ASSERT_EQ(binary.getObjects().size(), 1); + mlir::gpu::ObjectAttr object = + mlir::cast(*binary.getObjects().begin()); + EXPECT_EQ(object.getFormat(), mlir::gpu::CompilationTarget::Assembly); +} + +TEST_F(GpuModuleToAssemblyTest, + EncounteringAGpuModuleWithoutNVVMTargetIsAnError) { + mlir::gpu::GPUModuleOp gpu_module = CreateGpuModuleWithEmptyFunc( + builder_, mlir::ArrayAttr::get( + &context_, {mlir::ROCDL::ROCDLTargetAttr::get(&context_)})); + ASSERT_TRUE(mlir::succeeded(mlir::verify(gpu_module))); + EXPECT_TRUE(mlir::failed(mosaic::gpu::internal::LowerGpuModuleToAssembly( + gpu_module, /*libraries_to_link=*/{}))); + ExpectLastErrorContains("Target attribute is not of type NVVMTargetAttr"); +} + +TEST_F(GpuModuleToAssemblyTest, + EncounteringAGpuModuleWithMultipleTargetsIsAnError) { + mlir::Attribute target_attr = GetNVVMTargetAttr(&context_); + mlir::gpu::GPUModuleOp gpu_module = CreateGpuModuleWithEmptyFunc( + builder_, mlir::ArrayAttr::get(&context_, {target_attr, target_attr})); + ASSERT_TRUE(mlir::succeeded(mlir::verify(gpu_module))); + EXPECT_TRUE(mlir::failed(mosaic::gpu::internal::LowerGpuModuleToAssembly( + gpu_module, /*libraries_to_link=*/{}))); + ExpectLastErrorContains("Expected exactly one target attribute"); +} + +TEST_F(GpuModuleToAssemblyTest, + LoweringGpuModuleToAssemblyLinksToLibrariesCorrectly) { + mlir::Location loc = builder_.getUnknownLoc(); + mlir::Type f32 = builder_.getF32Type(); + mlir::gpu::GPUModuleOp gpu_module = CreateGpuModuleWithEmptyFunc( + builder_, + mlir::ArrayAttr::get(&context_, {GetNVVMTargetAttr(&context_)})); + auto gpu_func = GetOpsOfType(*module_).front(); + + // Insert a declaration for `__nv_exp2f` defined in libdevice here. + builder_.setInsertionPointToStart(gpu_module.getBody()); + auto exp2f = mlir::LLVM::LLVMFuncOp::create( + builder_, loc, "__nv_exp2f", + mlir::LLVM::LLVMFunctionType::get({f32}, f32)); + // Call the function in the entry block of `gpu_func`. + builder_.setInsertionPointToStart(&gpu_func.getBlocks().front()); + auto constant = mlir::LLVM::ConstantOp::create(builder_, loc, f32, + builder_.getF32FloatAttr(1.0)); + mlir::LLVM::CallOp::create(builder_, loc, exp2f, mlir::ValueRange{constant}); + + // Clone the module so that we can check that linking braries behaves + // differently than not linking them. + mlir::OwningOpRef module2 = module_->clone(); + + // Without linking to the libraries, we should not be able to resolve the + // function, and are left with an `extern .func` declaration. Here, we call + // the pass itself in order to make sure that the pass options are propagated + // as expected. + mlir::PassManager pm(module_->getContext()); + auto pass_without_libdevice = + mlir::parsePassPipeline("builtin.module(mosaic-gpu-module-to-assembly)"); + ASSERT_TRUE(mlir::succeeded(pass_without_libdevice)); + *static_cast(&pm) = std::move(*pass_without_libdevice); + EXPECT_TRUE(mlir::succeeded(pm.run(*module_))); + EXPECT_THAT(mosaic_gpu::MlirToString(*module_), HasSubstr("extern .func")); + + // When linking the libraries, the `extern .func` declaration should + // disappear. + std::string libdevice_path = + ::xla::gpu::nvptx::LibDevicePath("./cuda_sdk_lib"); + auto pass_with_libdevice = mlir::parsePassPipeline(absl::StrCat( + "builtin.module(mosaic-gpu-module-to-assembly{libraries-to-link=", + libdevice_path, "})")); + ASSERT_TRUE(mlir::succeeded(pass_with_libdevice)); + *static_cast(&pm) = std::move(*pass_with_libdevice); + EXPECT_TRUE(mlir::succeeded(pm.run(*module2))); + EXPECT_THAT(mosaic_gpu::MlirToString(*module2), + Not(HasSubstr("extern .func"))); +} + +} // anonymous namespace diff --git a/jaxlib/mosaic/gpu/integrations/c/passes.cc b/jaxlib/mosaic/gpu/integrations/c/passes.cc index 524b4443cd8e..ed4b0be94835 100644 --- a/jaxlib/mosaic/gpu/integrations/c/passes.cc +++ b/jaxlib/mosaic/gpu/integrations/c/passes.cc @@ -15,13 +15,11 @@ limitations under the License. #include "jaxlib/mosaic/gpu/integrations/c/passes.h" -#include "jaxlib/mosaic/gpu/launch_lowering.h" #include "jaxlib/mosaic/gpu/serde.h" extern "C" { -void mlirMosaicGpuRegisterPasses() { - mosaic::gpu::registerGpuLaunchLoweringPass(); +void mlirMosaicGpuRegisterSerdePass() { mosaic::gpu::registerSerdePass(); } diff --git a/jaxlib/mosaic/gpu/integrations/c/passes.h b/jaxlib/mosaic/gpu/integrations/c/passes.h index 901c39d68b77..7f6135c6657c 100644 --- a/jaxlib/mosaic/gpu/integrations/c/passes.h +++ b/jaxlib/mosaic/gpu/integrations/c/passes.h @@ -22,7 +22,7 @@ limitations under the License. extern "C" { #endif -MLIR_CAPI_EXPORTED void mlirMosaicGpuRegisterPasses(); +MLIR_CAPI_EXPORTED void mlirMosaicGpuRegisterSerdePass(); #ifdef __cplusplus } diff --git a/jaxlib/mosaic/gpu/launch_lowering.cc b/jaxlib/mosaic/gpu/launch_lowering.cc index 0331d800ec50..b4ded4c6eade 100644 --- a/jaxlib/mosaic/gpu/launch_lowering.cc +++ b/jaxlib/mosaic/gpu/launch_lowering.cc @@ -31,29 +31,30 @@ limitations under the License. #include #include -#include "llvm/include/llvm/ADT/STLExtras.h" -#include "llvm/include/llvm/ADT/StringRef.h" -#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h" -#include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h" -#include "mlir/include/mlir/IR/Builders.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/IR/DialectRegistry.h" -#include "mlir/include/mlir/IR/Location.h" -#include "mlir/include/mlir/IR/SymbolTable.h" -#include "mlir/include/mlir/IR/TypeRange.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/IR/ValueRange.h" -#include "mlir/include/mlir/IR/Visitors.h" -#include "mlir/include/mlir/Interfaces/DataLayoutInterfaces.h" -#include "mlir/include/mlir/Pass/Pass.h" -#include "mlir/include/mlir/Pass/PassRegistry.h" -#include "mlir/include/mlir/Support/LLVM.h" -#include "mlir/include/mlir/Support/LogicalResult.h" -#include "mlir/include/mlir/Support/TypeID.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/TypeRange.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Support/TypeID.h" +#include "jaxlib/mosaic/pass_boilerplate.h" namespace mosaic { namespace gpu { @@ -70,28 +71,30 @@ mlir::Value packKernelArgs(mlir::OpBuilder &builder, auto kernel_args_struct_ty = mlir::LLVM::LLVMStructType::getLiteral( builder.getContext(), kernel_operand_types); auto ptr_ty = mlir::LLVM::LLVMPointerType::get(builder.getContext()); - mlir::Value c1 = builder.create( - launch.getLoc(), builder.getI32Type(), builder.getI32IntegerAttr(1)); - mlir::Value kernel_args_struct = builder.create( - launch.getLoc(), ptr_ty, kernel_args_struct_ty, c1); - mlir::Value kernel_args_array = builder.create( - launch.getLoc(), ptr_ty, + mlir::Value c1 = mlir::LLVM::ConstantOp::create(builder, launch.getLoc(), + builder.getI32Type(), + builder.getI32IntegerAttr(1)); + mlir::Value kernel_args_struct = mlir::LLVM::AllocaOp::create( + builder, launch.getLoc(), ptr_ty, kernel_args_struct_ty, c1); + mlir::Value kernel_args_array = mlir::LLVM::AllocaOp::create( + builder, launch.getLoc(), ptr_ty, mlir::LLVM::LLVMArrayType::get(builder.getI64Type(), launch.getNumKernelOperands()), c1); for (auto [i, operand] : llvm::enumerate(launch.getKernelOperands())) { - mlir::Value storage_ptr = builder.create( - launch.getLoc(), ptr_ty, kernel_args_struct_ty, kernel_args_struct, + mlir::Value storage_ptr = mlir::LLVM::GEPOp::create( + builder, launch.getLoc(), ptr_ty, kernel_args_struct_ty, + kernel_args_struct, mlir::ArrayRef{mlir::LLVM::GEPArg(0), mlir::LLVM::GEPArg(i)}); - builder.create(launch.getLoc(), operand, storage_ptr); + mlir::LLVM::StoreOp::create(builder, launch.getLoc(), operand, storage_ptr); mlir::LLVM::GEPArg arr_gep_arg(i); - mlir::Value array_slot_ptr = builder.create( - launch.getLoc(), ptr_ty, builder.getI64Type(), kernel_args_array, - mlir::LLVM::GEPArg(i)); - builder.create(launch.getLoc(), storage_ptr, - array_slot_ptr); + mlir::Value array_slot_ptr = mlir::LLVM::GEPOp::create( + builder, launch.getLoc(), ptr_ty, builder.getI64Type(), + kernel_args_array, mlir::LLVM::GEPArg(i)); + mlir::LLVM::StoreOp::create(builder, launch.getLoc(), storage_ptr, + array_slot_ptr); } return kernel_args_array; } @@ -100,21 +103,24 @@ void emitRuntimeDecls(mlir::ModuleOp module) { auto ptr_ty = mlir::LLVM::LLVMPointerType::get(module.getContext()); auto i32 = mlir::IntegerType::get(module.getContext(), 32); auto decl_builder = mlir::OpBuilder::atBlockBegin(module.getBody()); - decl_builder.create( - module.getLoc(), decl_builder.getStringAttr("mosaic_gpu_launch_kernel"), + mlir::func::FuncOp::create( + decl_builder, module.getLoc(), + decl_builder.getStringAttr("mosaic_gpu_launch_kernel"), mlir::FunctionType::get(module.getContext(), {ptr_ty, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, ptr_ty, ptr_ty}, {}), decl_builder.getStringAttr("private"), /*arg_attr=*/nullptr, /*res_attrs=*/nullptr); - decl_builder.create( - module.getLoc(), decl_builder.getStringAttr("mosaic_gpu_module_load"), + mlir::func::FuncOp::create( + decl_builder, module.getLoc(), + decl_builder.getStringAttr("mosaic_gpu_module_load"), mlir::FunctionType::get(module.getContext(), {ptr_ty}, {ptr_ty}), decl_builder.getStringAttr("private"), /*arg_attr=*/nullptr, /*res_attrs=*/nullptr); - decl_builder.create( - module.getLoc(), decl_builder.getStringAttr("mosaic_gpu_get_function"), + mlir::func::FuncOp::create( + decl_builder, module.getLoc(), + decl_builder.getStringAttr("mosaic_gpu_get_function"), mlir::FunctionType::get(module.getContext(), {ptr_ty, ptr_ty, i32, i32}, {ptr_ty}), decl_builder.getStringAttr("private"), /*arg_attr=*/nullptr, @@ -131,8 +137,8 @@ void buildInitFunction(mlir::OpBuilder &module_builder, auto ptr_ty = mlir::LLVM::LLVMPointerType::get(init_func.getContext()); mlir::Location loc = init_func.getLoc(); auto builder = mlir::OpBuilder::atBlockBegin(init_func.addEntryBlock()); - auto binary_global_decl = module_builder.create( - loc, + auto binary_global_decl = mlir::LLVM::GlobalOp::create( + module_builder, loc, mlir::LLVM::LLVMArrayType::get(builder.getI8Type(), object.getObject().size()), /*is_constant=*/true, @@ -140,20 +146,19 @@ void buildInitFunction(mlir::OpBuilder &module_builder, /*name=*/ builder.getStringAttr(kernel_name.str() + "_kernel_binary"), /*value=*/object.getObject()); - mlir::Value binary_addr = builder.create( - init_func.getLoc(), binary_global_decl); + mlir::Value binary_addr = mlir::LLVM::AddressOfOp::create( + builder, init_func.getLoc(), binary_global_decl); mlir::Value module_handle = - builder - .create(loc, "mosaic_gpu_module_load", ptr_ty, - binary_addr) + mlir::func::CallOp::create(builder, loc, "mosaic_gpu_module_load", ptr_ty, + binary_addr) .getResult(0); // TODO(apaszke): This will create duplicate globals if the kernel // is called from multiple functions! mlir::StringAttr kernel_name_global_name = builder.getStringAttr(kernel_name.str() + "_name"); - auto kernel_name_global = module_builder.create( - loc, + auto kernel_name_global = mlir::LLVM::GlobalOp::create( + module_builder, loc, mlir::LLVM::LLVMArrayType::get(builder.getI8Type(), kernel_name.size() + 1), /*is_constant=*/true, @@ -163,14 +168,14 @@ void buildInitFunction(mlir::OpBuilder &module_builder, builder.getStringAttr( llvm::Twine(kernel_name).concat(llvm::Twine('\0')))); mlir::Value kernel_name_ptr = - builder.create(loc, kernel_name_global); - mlir::Value used_smem = builder.create( - loc, i32, builder.getI32IntegerAttr(0)); + mlir::LLVM::AddressOfOp::create(builder, loc, kernel_name_global); + mlir::Value used_smem = mlir::LLVM::ConstantOp::create( + builder, loc, i32, builder.getI32IntegerAttr(0)); if (dynamic_smem_size) { if (auto const_smem = dynamic_smem_size.getDefiningOp()) { - used_smem = builder.create( - loc, i32, + used_smem = mlir::LLVM::ConstantOp::create( + builder, loc, i32, builder.getI32IntegerAttr( mlir::cast(const_smem.getValue()).getInt())); } @@ -182,33 +187,32 @@ void buildInitFunction(mlir::OpBuilder &module_builder, auto const_y = cluster_shape.y.getDefiningOp(); auto const_z = cluster_shape.z.getDefiningOp(); if (const_x && const_y && const_z) { - cluster_size = builder.create( - loc, i32, + cluster_size = mlir::LLVM::ConstantOp::create( + builder, loc, i32, builder.getI32IntegerAttr( mlir::cast(const_x.getValue()).getInt() * mlir::cast(const_y.getValue()).getInt() * mlir::cast(const_z.getValue()).getInt())); } else { - cluster_size = builder.create( - loc, i32, builder.getI32IntegerAttr(-1)); + cluster_size = mlir::LLVM::ConstantOp::create( + builder, loc, i32, builder.getI32IntegerAttr(-1)); } } else { assert(!cluster_shape.y && !cluster_shape.z); - cluster_size = builder.create( - loc, i32, builder.getI32IntegerAttr(1)); + cluster_size = mlir::LLVM::ConstantOp::create(builder, loc, i32, + builder.getI32IntegerAttr(1)); } mlir::Value kernel_handle = - builder - .create( - loc, "mosaic_gpu_get_function", ptr_ty, - mlir::ValueRange{module_handle, kernel_name_ptr, used_smem, - cluster_size}) + mlir::func::CallOp::create( + builder, loc, "mosaic_gpu_get_function", ptr_ty, + mlir::ValueRange{module_handle, kernel_name_ptr, used_smem, + cluster_size}) .getResult(0); - builder.create(loc, module_handle, - init_func.getArgument(0)); - builder.create(loc, kernel_handle, - init_func.getArgument(1)); - builder.create(loc); + mlir::LLVM::StoreOp::create(builder, loc, module_handle, + init_func.getArgument(0)); + mlir::LLVM::StoreOp::create(builder, loc, kernel_handle, + init_func.getArgument(1)); + mlir::func::ReturnOp::create(builder, loc); } mlir::LogicalResult launchPreloadedKernel(mlir::func::FuncOp func, @@ -218,17 +222,18 @@ mlir::LogicalResult launchPreloadedKernel(mlir::func::FuncOp func, mlir::OpBuilder builder(launch); mlir::Value dynamic_smem = launch.getDynamicSharedMemorySize(); if (!dynamic_smem) { - dynamic_smem = builder.create( - launch.getLoc(), builder.getI32Type(), builder.getI32IntegerAttr(0)); + dynamic_smem = mlir::LLVM::ConstantOp::create(builder, launch.getLoc(), + builder.getI32Type(), + builder.getI32IntegerAttr(0)); } mlir::Value arg_ptr_array = packKernelArgs(builder, launch); auto as_32bit = [&](mlir::gpu::KernelDim3 dim) { - dim.x = builder.create(launch.getLoc(), - builder.getI32Type(), dim.x); - dim.y = builder.create(launch.getLoc(), - builder.getI32Type(), dim.y); - dim.z = builder.create(launch.getLoc(), - builder.getI32Type(), dim.z); + dim.x = mlir::LLVM::TruncOp::create(builder, launch.getLoc(), + builder.getI32Type(), dim.x); + dim.y = mlir::LLVM::TruncOp::create(builder, launch.getLoc(), + builder.getI32Type(), dim.y); + dim.z = mlir::LLVM::TruncOp::create(builder, launch.getLoc(), + builder.getI32Type(), dim.z); return dim; }; mlir::gpu::KernelDim3 grid = as_32bit(launch.getGridSizeOperandValues()); @@ -237,49 +242,26 @@ mlir::LogicalResult launchPreloadedKernel(mlir::func::FuncOp func, if (launch.hasClusterSize()) { cluster = as_32bit(launch.getClusterSizeOperandValues()); } else { - cluster.x = cluster.y = cluster.z = builder.create( - launch.getLoc(), builder.getI32Type(), builder.getI32IntegerAttr(0)); + cluster.x = cluster.y = cluster.z = mlir::LLVM::ConstantOp::create( + builder, launch.getLoc(), builder.getI32Type(), + builder.getI32IntegerAttr(0)); } mlir::Value stream = launch.getAsyncObject(); - builder.create( - launch.getLoc(), "mosaic_gpu_launch_kernel", mlir::TypeRange{}, + mlir::func::CallOp::create( + builder, launch.getLoc(), "mosaic_gpu_launch_kernel", mlir::TypeRange{}, mlir::ValueRange{kernel_handle, grid.x, grid.y, grid.z, cluster.x, cluster.y, cluster.z, block.x, block.y, block.z, dynamic_smem, stream, arg_ptr_array}); return mlir::success(); } -class GpuLaunchLoweringPass : public ::mlir::OperationPass { +class GpuLaunchLoweringPass + : public jaxlib::mlir::Pass { public: - GpuLaunchLoweringPass() - : ::mlir::OperationPass( - ::mlir::TypeID::get()) {} - GpuLaunchLoweringPass(const GpuLaunchLoweringPass &other) - : ::mlir::OperationPass(other) {} - GpuLaunchLoweringPass &operator=(const GpuLaunchLoweringPass &) = delete; - GpuLaunchLoweringPass(GpuLaunchLoweringPass &&) = delete; - GpuLaunchLoweringPass &operator=(GpuLaunchLoweringPass &&) = delete; - ~GpuLaunchLoweringPass() = default; + using jaxlib::mlir::Pass::Pass; - // Pass boilerplate... - static constexpr ::llvm::StringLiteral getArgumentName() { - return ::llvm::StringLiteral("gpu-launch-lowering"); - } - ::llvm::StringRef getArgument() const override { return getArgumentName(); } - ::llvm::StringRef getDescription() const override { return ""; } - static constexpr ::llvm::StringLiteral getPassName() { - return ::llvm::StringLiteral("GpuLaunchLoweringPass"); - } - ::llvm::StringRef getName() const override { return getPassName(); } - static bool classof(const ::mlir::Pass *pass) { - return pass->getTypeID() == ::mlir::TypeID::get(); - } - std::unique_ptr<::mlir::Pass> clonePass() const override { - return std::make_unique( - *static_cast(this)); - } - void getDependentDialects(::mlir::DialectRegistry ®istry) const override {} - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GpuLaunchLoweringPass) + static constexpr ::llvm::StringLiteral kArgumentName = "gpu-launch-lowering"; + static constexpr ::llvm::StringLiteral kPassName = "GpuLaunchLoweringPass"; void runOnOperation() override { mlir::ModuleOp module = getOperation(); @@ -293,12 +275,13 @@ class GpuLaunchLoweringPass : public ::mlir::OperationPass { continue; } auto module_builder = mlir::OpBuilder::atBlockBegin(module.getBody()); - auto init_func = module_builder.create( - op.getLoc(), func.getName().str() + "_init", + auto init_func = mlir::func::FuncOp::create( + module_builder, op.getLoc(), func.getName().str() + "_init", mlir::FunctionType::get(func->getContext(), {ptr_ty, ptr_ty}, {})); init_func->setAttr(mlir::LLVM::LLVMDialect::getEmitCWrapperAttrName(), mlir::UnitAttr::get(func->getContext())); bool had_launch = false; + mlir::Operation *gpu_binary = nullptr; auto result = getOperation()->walk([&](mlir::gpu::LaunchFuncOp launch) -> mlir::WalkResult { if (had_launch) { @@ -314,6 +297,7 @@ class GpuLaunchLoweringPass : public ::mlir::OperationPass { << launch.getKernelModuleName(); return mlir::WalkResult::interrupt(); } + gpu_binary = binary.getOperation(); if (binary.getObjects().size() != 1) { binary.emitOpError("Expected exactly one object in the binary."); return mlir::WalkResult::interrupt(); @@ -335,15 +319,16 @@ class GpuLaunchLoweringPass : public ::mlir::OperationPass { launch.getDynamicSharedMemorySize(), cluster_shape); // Add a new function argument for the kernel handle. - func.insertArgument(0, ptr_ty, - mlir::DictionaryAttr::get(func.getContext()), - mlir::UnknownLoc::get(func.getContext())); + if (failed(func.insertArgument( + 0, ptr_ty, mlir::DictionaryAttr::get(func.getContext()), + mlir::UnknownLoc::get(func.getContext())))) { + return mlir::WalkResult::interrupt(); + } mlir::Value kernel_handle = func.getArgument(0); if (launchPreloadedKernel(func, launch, kernel_handle).failed()) { return mlir::WalkResult::interrupt(); } launch.erase(); - // TODO(apaszke): Generate a destructor function. // builder.CreateCall(getModuleUnloadFn(), {moduleObject}); @@ -352,6 +337,13 @@ class GpuLaunchLoweringPass : public ::mlir::OperationPass { if (!had_launch) { init_func.erase(); } + if (gpu_binary) { + // This deletion is load-bearing: the conversion of `gpu.binary` to + // LLVM is side-effecting, as it creates module constructors and + // destructors which create an assumption that symbols from the MLIR + // runtime are available. + gpu_binary->erase(); + } if (result == mlir::WalkResult::interrupt()) { signalPassFailure(); } diff --git a/jaxlib/mosaic/gpu/library_paths.h b/jaxlib/mosaic/gpu/library_paths.h new file mode 100644 index 000000000000..83d523ac3ccc --- /dev/null +++ b/jaxlib/mosaic/gpu/library_paths.h @@ -0,0 +1,31 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_MOSAIC_GPU_LIBRARY_PATHS_H_ +#define JAXLIB_MOSAIC_GPU_LIBRARY_PATHS_H_ + +#include + +namespace mosaic { +namespace gpu { + +inline const char *GetCUDARoot() { + return getenv("CUDA_ROOT"); +} + +} // namespace gpu +} // namespace mosaic + +#endif // JAXLIB_MOSAIC_GPU_LIBRARY_PATHS_H_ diff --git a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc index 4f804c9e2116..2b1b044d3478 100644 --- a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc +++ b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc @@ -15,116 +15,23 @@ limitations under the License. #include #include -#include #include #include -#include #include #include -#include "nanobind/nanobind.h" -#include "nanobind/stl/tuple.h" -#include "nanobind/stl/vector.h" #include "absl/cleanup/cleanup.h" #include "absl/strings/str_cat.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/tuple.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep #include "jaxlib/gpu/vendor.h" -#include "jaxlib/kernel_nanobind_helpers.h" -#include "xla/ffi/api/c_api.h" -#include "xla/ffi/api/ffi.h" namespace jax::cuda { namespace { -namespace ffi = xla::ffi; namespace nb = nanobind; -static std::string ToString(CUresult result) { - const char* error_name; - if (cuGetErrorName(result, &error_name)) { - return absl::StrCat("UNKNOWN ERROR (", static_cast(result), ")"); - } - const char* error_string; - if (cuGetErrorString(result, &error_string)) { - return error_name; - } - return absl::StrCat(error_name, ": ", error_string); -} - -// Ensure it is safe to store gpuEvent_t in a uint64_t buffer. -static_assert(sizeof(gpuEvent_t) <= sizeof(uint64_t)); - -static const auto* kEventRecord = - ffi::Ffi::Bind() - .Ctx>() - .Attr("copy_before") - .RemainingArgs() - .Ret>() // event - .RemainingRets() - .To([](gpuStream_t stream, bool copy_before, - auto remaining_args, auto ret, auto remaining_rets) { - static auto* event = new gpuEvent_t; - if (auto res = gpuEventCreate(event, GPU_EVENT_DEFAULT); - res) { - return ffi::Error::Internal( - absl::StrCat("Failed to create event: ", ToString(res))); - } - auto do_copy = [&]() { - gpuMemcpyAsync(ret->untyped_data(), event, - sizeof(gpuEvent_t), gpuMemcpyHostToDevice, stream); - }; - if (copy_before) { - do_copy(); - } - if (auto res = gpuEventRecord(*event, stream); res) { - return ffi::Error::Internal( - absl::StrCat("Failed to record event: ", ToString(res))); - } - if (!copy_before) { - do_copy(); - } - return ffi::Error::Success(); - }) - .release(); - -XLA_FFI_Error* EventRecord(XLA_FFI_CallFrame* call_frame) { - return kEventRecord->Call(call_frame); -} - -static const auto* kEventElapsed = - ffi::Ffi::Bind() - .Ctx>() - .Arg>() // start_event - .Arg>() // end_event - .Ret>() // elapsed_ms - .To([](gpuStream_t stream, auto start, auto end, auto out) { - gpuStreamSynchronize(stream); - auto start_event = std::make_unique(); - auto end_event = std::make_unique(); - absl::MakeCleanup([&]() { - gpuEventDestroy(*start_event); - gpuEventDestroy(*end_event); - }); - gpuMemcpy(start_event.get(), start.untyped_data(), sizeof(gpuEvent_t), - gpuMemcpyDeviceToHost); - gpuMemcpy(end_event.get(), end.untyped_data(), sizeof(gpuEvent_t), - gpuMemcpyDeviceToHost); - float elapsed; - if (auto res = - gpuEventElapsedTime(&elapsed, *start_event, *end_event); - res) { - return ffi::Error::Internal(absl::StrCat( - "Failed to get elapsed time between events: ", ToString(res))); - } - gpuMemcpy(out->untyped_data(), &elapsed, sizeof(float), - gpuMemcpyHostToDevice); - return ffi::Error::Success(); - }) - .release(); - -XLA_FFI_Error* EventElapsed(XLA_FFI_CallFrame* call_frame) { - return kEventElapsed->Call(call_frame); -} - #define THROW(...) \ do { \ throw std::runtime_error( \ @@ -193,15 +100,15 @@ void callback_complete(CUcontext context, uint32_t streamId, THROW_IF_CUPTI_ERROR(status); } } + + size_t num_dropped; + THROW_IF_CUPTI_ERROR( + cuptiActivityGetNumDroppedRecords(context, streamId, &num_dropped), + "failed to get number of dropped activity records"); + THROW_IF(num_dropped > 0, "activity records were dropped"); } NB_MODULE(_mosaic_gpu_ext, m) { - m.def("registrations", []() { - return nb::make_tuple( - nb::make_tuple("mgpu_event_record", EncapsulateFunction(EventRecord)), - nb::make_tuple("mgpu_event_elapsed", EncapsulateFunction(EventElapsed)) - ); - }); m.def("_sync_all_devices", []() { int devices = 0; if (cudaGetDeviceCount(&devices) != gpuSuccess) { @@ -237,15 +144,23 @@ NB_MODULE(_mosaic_gpu_ext, m) { cuptiActivityEnable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL), "failed to enable tracking of kernel activity by CUPTI"); }); - m.def("_cupti_get_timings", []() { - THROW_IF_CUPTI_ERROR( - cuptiActivityFlushAll(CUPTI_ACTIVITY_FLAG_FLUSH_FORCED), - "failed to flush CUPTI activity buffers"); - THROW_IF_CUPTI_ERROR(cuptiFinalize(), "failed to detach CUPTI"); - THROW_IF_CUPTI_ERROR(cuptiUnsubscribe(profiler_state.subscriber), - "failed to unsubscribe from CUPTI"); - return profiler_state.timings; - }); + m.def( + "_cupti_get_timings", + [](bool finalize) { + THROW_IF_CUPTI_ERROR( + cuptiActivityDisable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL), + "failed to disable tracking of kernel activity by CUPTI"); + THROW_IF_CUPTI_ERROR( + cuptiActivityFlushAll(CUPTI_ACTIVITY_FLAG_FLUSH_FORCED), + "failed to flush CUPTI activity buffers"); + if (finalize) { + THROW_IF_CUPTI_ERROR(cuptiFinalize(), "failed to detach CUPTI"); + } + THROW_IF_CUPTI_ERROR(cuptiUnsubscribe(profiler_state.subscriber), + "failed to unsubscribe from CUPTI"); + return profiler_state.timings; + }, + nb::arg("finalize") = true); } } // namespace diff --git a/jaxlib/mosaic/gpu/nvshmem.h b/jaxlib/mosaic/gpu/nvshmem.h new file mode 100644 index 000000000000..7869b55b7a31 --- /dev/null +++ b/jaxlib/mosaic/gpu/nvshmem.h @@ -0,0 +1,101 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_MOSAIC_GPU_COMM_H_ +#define JAXLIB_MOSAIC_GPU_COMM_H_ + +#include + +#include +#include +#include + +#include "third_party/gpus/cuda/include/cuda.h" +#include "cuda_runtime_api.h" + +#define NVSHMEM_SUCCESS 0 + +namespace mosaic { +namespace gpu { + +#define NVSHMEM_SET_FN(FnName) \ + FnName = reinterpret_cast(dlsym(library, #FnName)); \ + if (!FnName) { \ + fprintf(stderr, #FnName " not available in this library."); \ + } + +class NvshmemApi { + public: + // Returns a default NvshmemApi for a current process. + // NvshmemApi follows the Singleton design pattern + static NvshmemApi& Default(bool assert_ok = true) { + static NvshmemApi instance; + if (assert_ok && !instance.is_loaded()) { + fprintf(stderr, "Failed to load the NVSHMEM library.\n"); + abort(); + } + return instance; + } + + int cumodule_init(CUmodule module) { + std::lock_guard lock(mutex_); + return nvshmemx_cumodule_init(module); + } + + int cumodule_finalize(CUmodule module) { + std::lock_guard lock(mutex_); + return nvshmemx_cumodule_finalize(module); + } + + void barrier_all_on_stream(cudaStream_t stream) { + nvshmemx_barrier_all_on_stream(stream); + } + + bool is_loaded() { + return nvshmemx_init_status != nullptr && nvshmemx_init_status() == 2; + } + + NvshmemApi(NvshmemApi const&) = delete; + void operator=(NvshmemApi const&) = delete; + + private: + NvshmemApi() { + const char* env_value = getenv("MOSAIC_GPU_NVSHMEM_SO_PATH"); + const char* libnvshmem_path = + env_value && *env_value != 0 ? env_value : nullptr; + void* library = dlopen(libnvshmem_path, RTLD_LAZY); + if (library == nullptr) { + fprintf(stderr, "Failed to open library (from %s): %s", + libnvshmem_path ? libnvshmem_path : "", dlerror()); + } + + NVSHMEM_SET_FN(nvshmemx_barrier_all_on_stream) + NVSHMEM_SET_FN(nvshmemx_cumodule_finalize) + NVSHMEM_SET_FN(nvshmemx_cumodule_init) + NVSHMEM_SET_FN(nvshmemx_init_status) + } + + int (*nvshmemx_barrier_all_on_stream)(cudaStream_t); + int (*nvshmemx_cumodule_finalize)(CUmodule); + int (*nvshmemx_cumodule_init)(CUmodule); + int (*nvshmemx_init_status)(); + + std::mutex mutex_; +}; + +} // namespace gpu +} // namespace mosaic + +#endif // JAXLIB_MOSAIC_GPU_COMM_H_ diff --git a/jaxlib/mosaic/gpu/passes.cc b/jaxlib/mosaic/gpu/passes.cc index b8c3fbb74c81..0d025b80c34b 100644 --- a/jaxlib/mosaic/gpu/passes.cc +++ b/jaxlib/mosaic/gpu/passes.cc @@ -14,24 +14,31 @@ limitations under the License. ==============================================================================*/ #include "jaxlib/mosaic/gpu/passes.h" + #include #include +#include #include #include -#include "llvm/include/llvm/ADT/StringRef.h" -#include "mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h" -#include "mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h" -#include "mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/IR/SymbolTable.h" -#include "mlir/include/mlir/Pass/PassRegistry.h" -#include "mlir/include/mlir/Support/LLVM.h" -#include "mlir/include/mlir/Transforms/DialectConversion.h" -#include "jaxlib/pass_boilerplate.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/DialectConversion.h" +#include "jaxlib/mosaic/pass_boilerplate.h" namespace mosaic { namespace gpu { @@ -50,31 +57,32 @@ struct ConvertExtractStridedSlicePattern final return rewriter.notifyMatchFailure(op, "only 1-D vectors are supported"); } int64_t size = - (*op.getSizes().getAsRange().begin()).getSInt(); + (*op.getSizes().getAsRange().begin()).getInt(); if (size < 0) { return rewriter.notifyMatchFailure(op, "size is negative"); } int64_t start = - (*op.getOffsets().getAsRange().begin()).getSInt(); + (*op.getOffsets().getAsRange().begin()).getInt(); int64_t stride = - (*op.getStrides().getAsRange().begin()).getSInt(); + (*op.getStrides().getAsRange().begin()).getInt(); if (stride != 1) { return rewriter.notifyMatchFailure(op, "only stride 1 is supported"); } if (start < 0 || start + size > vty.getShape()[0]) { return rewriter.notifyMatchFailure(op, "slice is out of bounds"); } - mlir::Value result = rewriter.create( - op.getLoc(), op.getResult().getType()); + mlir::Value result = mlir::LLVM::UndefOp::create(rewriter, op.getLoc(), + op.getResult().getType()); for (int64_t i = 0; i < size; ++i) { - result = rewriter.create( - op.getLoc(), result, - rewriter.create( - op.getLoc(), subst.getVector(), - rewriter.create( - op.getLoc(), rewriter.getI32IntegerAttr(i + start))), - rewriter.create( - op.getLoc(), rewriter.getI32IntegerAttr(i))); + result = mlir::LLVM::InsertElementOp::create( + rewriter, op.getLoc(), result, + mlir::LLVM::ExtractElementOp::create( + rewriter, op.getLoc(), subst.getSource(), + mlir::LLVM::ConstantOp::create( + rewriter, op.getLoc(), + rewriter.getI32IntegerAttr(i + start))), + mlir::LLVM::ConstantOp::create(rewriter, op.getLoc(), + rewriter.getI32IntegerAttr(i))); } rewriter.replaceOp(op, result); return mlir::success(); @@ -163,6 +171,91 @@ class ByvalInsertionPass } }; +// Insert the nvvm.minctasm attribute, which is sometimes required for ptxas +// to recognize setmaxnreg instructions. +class LLVMAttrInsertionPass + : public jaxlib::mlir::Pass { + public: + using jaxlib::mlir::Pass::Pass; + static constexpr llvm::StringLiteral kArgumentName = "mosaic-llvm-attr-insertion"; + static constexpr llvm::StringLiteral kPassName = "LLVMAttrInsertionPass"; + + void runOnOperation() override { + auto result = getOperation().walk([](mlir::LLVM::LLVMFuncOp op) { + // TODO(apaszke): op.isDeclaration() always returns false... + if (op.getFunctionBody().empty()) { // Skip over declarations. + return mlir::WalkResult::advance(); + } + op.getOperation()->setAttr( + "nvvm.minctasm", mlir::IntegerAttr::get( + mlir::IntegerType::get(op.getContext(), 32), 1)); + for (unsigned i = 0; i < op.getNumArguments(); ++i) { + mlir::BlockArgument arg = op.getArgument(i); + if (!mlir::isa(arg.getType())) { + continue; + } + if (!op.getArgAttr(i, "llvm.align")) { + op.setArgAttr(i, "llvm.align", + mlir::IntegerAttr::get( + mlir::IntegerType::get(op.getContext(), 32), + kExpectedHbmAlignment)); + } + } + return mlir::WalkResult::advance(); + }); + if (result.wasInterrupted()) { + signalPassFailure(); + } + } +}; + +// Replaces all "pallas_call" locations within a FuncOp with the location +// of the first operation in the function that has a different location. +// This provides more specific source information for debugging. +class ResolveTrivialLocationsPass + : public jaxlib::mlir::Pass { + public: + using jaxlib::mlir::Pass::Pass; + static constexpr llvm::StringLiteral kArgumentName = + "mosaic-gpu-resolve-trivial-locations"; + static constexpr llvm::StringLiteral kPassName = + "ResolveTrivialLocationsPass"; + + void runOnOperation() override { + const auto trivial_loc = + mlir::NameLoc::get(mlir::StringAttr::get(&getContext(), "pallas_call")); + getOperation()->walk([&](mlir::func::FuncOp func_op) { + if (func_op->getLoc() != trivial_loc) { + return mlir::WalkResult::advance(); + } + std::optional replacement_loc; + func_op.getBody().walk([&](mlir::Operation* op) { + if (op->getLoc() == trivial_loc) { + return mlir::WalkResult::advance(); + } + auto candidate_loc = op->getLoc(); + while (mlir::isa(candidate_loc)) { + candidate_loc = + mlir::cast(candidate_loc).getChildLoc(); + } + replacement_loc = candidate_loc; + return mlir::WalkResult::interrupt(); + }); + if (!replacement_loc) { + return mlir::WalkResult::advance(); + } + func_op.walk([&](mlir::Operation* op) { + // We use the same replacement for all ops with the trivial location, + // because that what the lowering of pallas_call would have done. + if (op->getLoc() == trivial_loc) { + op->setLoc(*replacement_loc); + } + }); + return mlir::WalkResult::advance(); + }); + } +}; + } // namespace void registerConvertGpuToLLVMPass() { @@ -177,5 +270,17 @@ void registerByvalInsertionPass() { }); } +void registerLLVMAttrInsertionPass() { + ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { + return std::make_unique(); + }); +} + +void registerResolveTrivialLocationsPass() { + ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { + return std::make_unique(); + }); +} + } // namespace gpu } // namespace mosaic diff --git a/jaxlib/mosaic/gpu/passes.h b/jaxlib/mosaic/gpu/passes.h index 21c142bc7692..eeec87e49887 100644 --- a/jaxlib/mosaic/gpu/passes.h +++ b/jaxlib/mosaic/gpu/passes.h @@ -21,6 +21,12 @@ namespace gpu { void registerByvalInsertionPass(); void registerConvertGpuToLLVMPass(); +void registerLLVMAttrInsertionPass(); +void registerResolveTrivialLocationsPass(); + +// This is the default of cudaMalloc and is also upheld by the XLA:GPU runtime. +// We annotate all GMEM pointers with this alignment in LLVMAttrInsertionPass. +inline constexpr int kExpectedHbmAlignment = 256; } // namespace gpu } // namespace mosaic diff --git a/jaxlib/mosaic/gpu/runtime.cc b/jaxlib/mosaic/gpu/runtime.cc index ad3cd0e19644..4d94120aa8c0 100644 --- a/jaxlib/mosaic/gpu/runtime.cc +++ b/jaxlib/mosaic/gpu/runtime.cc @@ -18,11 +18,24 @@ limitations under the License. #include #include "third_party/gpus/cuda/include/cuda.h" +#include "jaxlib/mosaic/gpu/nvshmem.h" + +namespace { +template +void abort_on_error(CUresult result, const char* fmt, Args&&... args) { + if (result != CUDA_SUCCESS) { + const char *ptr = nullptr; + cuGetErrorString(result, &ptr); + fprintf(stderr, fmt, std::forward(args)..., ptr); + abort(); + } +} +} extern "C" { void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr, - int64_t elem_bitwidth, int64_t rank, + int64_t elem_type, int64_t rank, int64_t *sizes, int64_t *strides, int64_t swizzle_bytes, int64_t *window_shape) { if (((uintptr_t)tma_desc) % 64 != 0) { @@ -32,7 +45,50 @@ void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr, abort(); } - // Pack 4 bit types in 8 bit pairs. + CUtensorMapDataType data_type; + int64_t elem_bitwidth; + // types are defined in: launch_context._tma_dma_type() + if (elem_type == 8){ + // this is for int2s + data_type = CU_TENSOR_MAP_DATA_TYPE_UINT8; + elem_bitwidth = 2; + } else if (elem_type == 0){ + // this is for int4s + data_type = CU_TENSOR_MAP_DATA_TYPE_UINT8; + elem_bitwidth = 4; + } else if (elem_type == 1){ + data_type = CU_TENSOR_MAP_DATA_TYPE_UINT8; + elem_bitwidth = 8; + } else if (elem_type == 2){ + data_type = CU_TENSOR_MAP_DATA_TYPE_UINT16; + elem_bitwidth = 16; + } else if (elem_type == 3){ + data_type = CU_TENSOR_MAP_DATA_TYPE_UINT32; + elem_bitwidth = 32; + } else if (elem_type == 4){ + data_type = CU_TENSOR_MAP_DATA_TYPE_UINT64; + elem_bitwidth = 64; + } else if (elem_type == 5){ + data_type = CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + elem_bitwidth = 16; + } else if (elem_type == 6){ + data_type = CU_TENSOR_MAP_DATA_TYPE_FLOAT32; + elem_bitwidth = 32; + } else if (elem_type == 7){ + data_type = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; + elem_bitwidth = 16; + } else if (elem_type == 9){ + data_type = CU_TENSOR_MAP_DATA_TYPE_INT32; + elem_bitwidth = 32; + } else if (elem_type == 10){ + data_type = CU_TENSOR_MAP_DATA_TYPE_INT64; + elem_bitwidth = 64; + } else{ + fprintf(stderr, "Unsupported element type: %ld \n", elem_type); + abort(); + } + + // Pack sub byte types in 8 bit pairs. int64_t elem_bytewidth; if (elem_bitwidth < 8) { // Check that it's a power of 2. @@ -54,19 +110,6 @@ void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr, elem_bytewidth = elem_bitwidth / 8; } - CUtensorMapDataType data_type; - if (elem_bytewidth == 1) { - data_type = CU_TENSOR_MAP_DATA_TYPE_UINT8; - } else if (elem_bytewidth == 2) { - data_type = CU_TENSOR_MAP_DATA_TYPE_UINT16; - } else if (elem_bytewidth == 4) { - data_type = CU_TENSOR_MAP_DATA_TYPE_UINT32; - } else if (elem_bytewidth == 8) { - data_type = CU_TENSOR_MAP_DATA_TYPE_UINT64; - } else { - fprintf(stderr, "Unsupported element size: %ld\n", elem_bytewidth); - abort(); - } if (rank < 1 || rank > 5) { fprintf(stderr, "Rank must be in [1, 5], but got %ld\n", rank); abort(); @@ -94,7 +137,7 @@ void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr, if (tma_stride_i % 16 != 0 || tma_stride_i >= static_cast(1) << 40) { fprintf(stderr, - "Byte strides must be divisble by 16 and less than 2**40, but " + "Byte strides must be divisible by 16 and less than 2**40, but " "got %ld (item stride = %ld, item size = %ld) at index %ld\n", tma_stride_i, strides[rank - 1], elem_bytewidth, rank - i - 2); abort(); @@ -134,26 +177,31 @@ void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr, fprintf(stderr, "Unsupported swizzle: %ld\n", swizzle_bytes); abort(); } - CUresult result = cuTensorMapEncodeTiled( + abort_on_error( + cuTensorMapEncodeTiled( tma_desc, data_type, rank, base_addr, tma_sizes, tma_strides, tma_window_shape, element_strides, CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle, - CU_TENSOR_MAP_L2_PROMOTION_NONE, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); - if (result != CUDA_SUCCESS) { - const char *ptr = nullptr; - cuGetErrorString(result, &ptr); - fprintf(stderr, "cuTensorMapEncodeTiled failed: %s\n", ptr); - abort(); - } + CU_TENSOR_MAP_L2_PROMOTION_NONE, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE), + "cuTensorMapEncodeTiled failed: %s\n"); } void* mosaic_gpu_module_load(void *data) { CUmodule module = nullptr; - if (auto result = cuModuleLoadData(&module, data); result != CUDA_SUCCESS) { - const char *ptr = nullptr; - cuGetErrorString(result, &ptr); - fprintf(stderr, "cuModuleLoadData failed: %s\n", ptr); - abort(); + abort_on_error(cuModuleLoadData(&module, data), + "cuModuleLoadData failed: %s\n"); + { // Set the NVSHMEM state if it's used by the module. + CUdeviceptr ptr = 0; + size_t size = 0; + if (cuModuleGetGlobal(&ptr, &size, module, + "nvshmemi_device_lib_version_d") == CUDA_SUCCESS) { + if (mosaic::gpu::NvshmemApi::Default().cumodule_init(module) != + NVSHMEM_SUCCESS) { + fprintf(stderr, "nvshmemx_cumodule_init failed.\n"); + abort(); + } + } } + return module; } @@ -161,32 +209,23 @@ void* mosaic_gpu_module_load(void *data) { void *mosaic_gpu_get_function(CUmodule module, const char *name, int32_t smem_bytes, int32_t cluster_size) { CUfunction function = nullptr; - CUresult result = cuModuleGetFunction(&function, module, name); - if (result != CUDA_SUCCESS) { - const char *ptr = nullptr; - cuGetErrorString(result, &ptr); - fprintf(stderr, "cuModuleGetFunction failed: %s\n", ptr); - abort(); - } + abort_on_error( + cuModuleGetFunction(&function, module, name), + "Failed to retrieve function pointer to kernel \"%s\", " + "cuModuleGetFunction failed: %s\n", name); if (smem_bytes) { - result = cuFuncSetAttribute( - function, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_bytes); - if (result != CUDA_SUCCESS) { - const char *ptr = nullptr; - cuGetErrorString(result, &ptr); - fprintf(stderr, "cuFuncSetAttribute failed: %s\n", ptr); - abort(); - } + abort_on_error( + cuFuncSetAttribute( + function, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_bytes), + "Failed to set maximum dynamic shared memory size for kernel \"%s\" " + "to %d bytes, cuFuncSetAttribute failed: %s\n", name, smem_bytes); } if (cluster_size > 8) { - result = cuFuncSetAttribute( - function, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1); - if (result != CUDA_SUCCESS) { - const char *ptr = nullptr; - cuGetErrorString(result, &ptr); - fprintf(stderr, "cuFuncSetAttribute failed: %s\n", ptr); - abort(); - } + abort_on_error( + cuFuncSetAttribute( + function, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1), + "Failed to set allowed cluster size for kernel \"%s\" to %d, " + "cuFuncSetAttribute failed: %s\n", name, cluster_size); } return function; } @@ -222,11 +261,18 @@ void mosaic_gpu_launch_kernel(CUfunction function, uint32_t grid_x, config.numAttrs = 1; } CUresult result = cuLaunchKernelEx(&config, function, params, nullptr); - if (result != CUDA_SUCCESS) { - const char *ptr = nullptr; - cuGetErrorString(result, &ptr); - fprintf(stderr, "cuLaunchKernel failed: %s\n", ptr); + if (result == CUDA_ERROR_INVALID_CLUSTER_SIZE) { + int max_cluster_size; + abort_on_error(cuOccupancyMaxPotentialClusterSize(&max_cluster_size, + function, &config), + "cuOccupancyMaxPotentialClusterSize failed: %s\n"); + fprintf(stderr, + "cuLaunchKernel failed with invalid cluster size (%d, %d, %d)" + ": maximum is %d\n", cluster_x, cluster_y, cluster_z, + max_cluster_size); abort(); + } else { + abort_on_error(result, "cuLaunchKernelEx: %s\n"); } } } diff --git a/jaxlib/mosaic/gpu/serde.cc b/jaxlib/mosaic/gpu/serde.cc index f4cf846acc11..3330199a986f 100644 --- a/jaxlib/mosaic/gpu/serde.cc +++ b/jaxlib/mosaic/gpu/serde.cc @@ -15,31 +15,225 @@ limitations under the License. #include "jaxlib/mosaic/gpu/serde.h" -#include "llvm/include/llvm/ADT/StringMap.h" -#include "llvm/include/llvm/ADT/StringRef.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/LogicalResult.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" #include "jaxlib/mosaic/serde.h" namespace mosaic::gpu { namespace { +using ::llvm::ArrayRef; +using ::llvm::LogicalResult; +using ::llvm::success; +using ::mlir::Operation; +using ::mlir::Value; + constexpr llvm::StringRef kMangledDialect = "stable_mosaic_gpu."; constexpr llvm::StringRef kVersionAttrName = "stable_mosaic_gpu.version"; // When this is bumped, we should file a TODO to update the forward-compatible // version in Mosaic GPU lowering in a month! -constexpr int kVersion = 1; +// +// TODO(apaszke): Update the forward-compatible version to 3 in Mosaic GPU +// lowering after 2025-10-08. +// TODO(apaszke): Update the forward-compatible version to 4 in Mosaic GPU +// lowering after 2025-11-13. +// TODO(apaszke): Update the forward-compatible version to 5 in Mosaic GPU +// lowering after 2025-12-07. +// TODO(apaszke): Update the forward-compatible version to 6 in Mosaic GPU +// lowering after 2025-12-18. +constexpr int kVersion = 6; using SerdeRuleType = jaxlib::mosaic::SerdeRuleType; +LogicalResult vector_extractelement_upgrade(Operation* op, int version, + bool& erased) { + if (version < 2) { + // vector.extractelement was removed in + // https://github.com/llvm/llvm-project/commit/33465bb2bb75f26b7ad42ab87ccb2464c0245476. + // We replace it with a vector.extract. + mlir::OpBuilder b(op->getParentRegion()); + b.setInsertionPointAfter(op); + Value vec = op->getOperand(0); + Value position = op->getOperand(1); + Value extracted_value = mlir::vector::ExtractOp::create( + b, op->getLoc(), vec, ArrayRef{position}); + + op->replaceAllUsesWith(llvm::SmallVector{extracted_value}); + op->erase(); + erased = true; + } + return success(); +} + +LogicalResult vector_insertelement_upgrade(Operation* op, int version, + bool& erased) { + if (version < 2) { + // vector.insertelement was removed in + // https://github.com/llvm/llvm-project/commit/33465bb2bb75f26b7ad42ab87ccb2464c0245476. + // We replace it with a vector.insert. + mlir::OpBuilder b(op->getParentRegion()); + b.setInsertionPointAfter(op); + Value source = op->getOperand(0); + Value destination = op->getOperand(1); + Value position = op->getOperand(2); + + Value inserted_value = + mlir::vector::InsertOp::create(b, op->getLoc(), source, destination, + ArrayRef{position}); + op->replaceAllUsesWith(llvm::SmallVector{inserted_value}); + op->erase(); + erased = true; + } + return success(); +} + +LogicalResult nvvm_cp_async_bulk_tensor_global_shared_cta_upgrade( + Operation* op, int version, bool& erased) { + // A new operand was added in + // https://github.com/llvm/llvm-project/pull/155435/commits/216550ca2169677dd6fc33bc47c3e1ba6d93fc20 + if (version < 3) { + auto sizes_attr = + op->getAttrOfType("operandSegmentSizes"); + if (!sizes_attr) { + return op->emitOpError( + "Missing or invalid operandSegmentSizes attribute"); + } + if (sizes_attr.getSize() != 4) { + return op->emitOpError("operandSegmentSizes attribute has wrong size"); + } + auto new_sizes = sizes_attr.asArrayRef().vec(); + new_sizes.insert(new_sizes.end() - 1, 0); + op->setAttr("operandSegmentSizes", + mlir::DenseI32ArrayAttr::get(op->getContext(), new_sizes)); + } + return success(); +} + +LogicalResult nvvm_cp_async_bulk_tensor_global_shared_cta_downgrade( + Operation* op, int version, bool& erased) { + // A new operand was added in + // https://github.com/llvm/llvm-project/pull/155435/commits/216550ca2169677dd6fc33bc47c3e1ba6d93fc20 + if (version < 3) { + auto sizes_attr = + op->getAttrOfType("operandSegmentSizes"); + if (!sizes_attr) { + return op->emitOpError( + "Missing or invalid operandSegmentSizes attribute"); + } + if (sizes_attr.getSize() != 5) { + return op->emitOpError("operandSegmentSizes attribute has wrong size"); + } + auto new_sizes = sizes_attr.asArrayRef().vec(); + if (*(new_sizes.end() - 2) != 0) { + return op->emitOpError("Can't downgrade: l2 hint operand is present"); + } + new_sizes.erase(new_sizes.end() - 2); + op->setAttr("operandSegmentSizes", + mlir::DenseI32ArrayAttr::get(op->getContext(), new_sizes)); + } + return success(); +} + +LogicalResult vector_splat_upgrade(Operation* op, int version, bool& erased) { + if (version < 4) { + // vector.splat was removed in + // https://github.com/llvm/llvm-project/commit/ea291d0e8c93d47d7953eff5ca1048891a5fcc55. + // We replace it with a vector.broadcast. + mlir::OpBuilder b(op->getParentRegion()); + b.setInsertionPointAfter(op); + Value inserted_value = mlir::vector::BroadcastOp::create( + b, op->getLoc(), op->getResult(0).getType(), op->getOperand(0)); + op->replaceAllUsesWith(llvm::SmallVector{inserted_value}); + op->erase(); + erased = true; + } + return success(); +} + +LogicalResult nvvm_mbarrier_init_shared_upgrade(Operation* op, int version, + bool& erased) { + // https://github.com/llvm/llvm-project/commit/523706f2cd6a06bd9557bf0dca9986d867eddd79 + if (version < 5) { + mlir::OpBuilder b(op->getParentRegion()); + b.setInsertionPointAfter(op); + mlir::NVVM::MBarrierInitOp::create( + b, op->getLoc(), op->getOperand(0), op->getOperand(1), + op->getNumOperands() < 3 ? Value{} : op->getOperand(2)); + op->erase(); + erased = true; + } + return success(); +} + +LogicalResult nvvm_mbarrier_try_wait_parity_shared_upgrade(Operation* op, + int version, + bool& erased) { + // https://github.com/llvm/llvm-project/commit/7eeae8e41d7827d84de12df7b5ecfab3058900cb + if (version < 6) { + mlir::OpBuilder b(op->getParentRegion()); + b.setInsertionPointAfter(op); + mlir::NVVM::MBarrierTryWaitParityOp::create( + b, op->getLoc(), op->getOperand(0), op->getOperand(1), + op->getOperand(2)); + op->erase(); + erased = true; + } + return success(); +} + +LogicalResult nvvm_mbarrier_arrive_expect_tx_shared_upgrade(Operation* op, + int version, + bool& erased) { + // https://github.com/llvm/llvm-project/commit/fddf7b0510e5df7a08c512a177ea9c1ec4307718 + if (version < 6) { + mlir::ImplicitLocOpBuilder b(op->getLoc(), op->getParentRegion()); + b.setInsertionPointAfter(op); + auto new_op = mlir::NVVM::MBarrierArriveExpectTxOp::create( + b, op->getResultTypes(), op->getOperand(0), op->getOperand(1), + mlir::NVVM::MemScopeKind::CTA, + /*relaxed=*/false, + op->getNumOperands() < 3 ? mlir::Value{} : op->getOperand(2)); + op->replaceAllUsesWith(new_op); + op->erase(); + erased = true; + } + return success(); +} + const llvm::StringMap& upgrade_rules() { - static auto rules = new llvm::StringMap{}; + static auto rules = new llvm::StringMap{ + {::llvm::StringLiteral("vector.extractelement"), + vector_extractelement_upgrade}, + {::llvm::StringLiteral("vector.insertelement"), + vector_insertelement_upgrade}, + {::llvm::StringLiteral("nvvm.cp.async.bulk.tensor.global.shared.cta"), + nvvm_cp_async_bulk_tensor_global_shared_cta_upgrade}, + {::llvm::StringLiteral("vector.splat"), vector_splat_upgrade}, + {::llvm::StringLiteral("nvvm.mbarrier.init.shared"), + nvvm_mbarrier_init_shared_upgrade}, + {::llvm::StringLiteral("nvvm.mbarrier.try_wait.parity.shared"), + nvvm_mbarrier_try_wait_parity_shared_upgrade}, + {::llvm::StringLiteral("nvvm.mbarrier.arrive.expect_tx.shared"), + nvvm_mbarrier_arrive_expect_tx_shared_upgrade}, + }; return *rules; } const llvm::StringMap& downgrade_rules() { - static auto rules = new llvm::StringMap{}; + static auto rules = new llvm::StringMap{ + {::llvm::StringLiteral("nvvm.cp.async.bulk.tensor.global.shared.cta"), + nvvm_cp_async_bulk_tensor_global_shared_cta_downgrade}}; return *rules; } @@ -53,7 +247,7 @@ void SerdePass::runOnOperation() { } int serialize_version = -1; if (serialize) { - serialize_version = target_version.hasValue() ? target_version : kVersion; + serialize_version = target_version.hasValue() ? target_version : kVersion; } if (mlir::failed(jaxlib::mosaic::RunSerde( module, upgrade_rules(), downgrade_rules(), serialize, diff --git a/jaxlib/mosaic/gpu/serde.h b/jaxlib/mosaic/gpu/serde.h index 6187d72b4cd5..29dda33d0c5a 100644 --- a/jaxlib/mosaic/gpu/serde.h +++ b/jaxlib/mosaic/gpu/serde.h @@ -19,13 +19,13 @@ limitations under the License. #include #include -#include "llvm/include/llvm/ADT/StringRef.h" -#include "llvm/include/llvm/Support/CommandLine.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/Interfaces/DataLayoutInterfaces.h" -#include "mlir/include/mlir/Pass/Pass.h" -#include "mlir/include/mlir/Pass/PassRegistry.h" -#include "jaxlib/pass_boilerplate.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/CommandLine.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "jaxlib/mosaic/pass_boilerplate.h" namespace mosaic::gpu { diff --git a/jaxlib/mosaic/gpu/target.cc b/jaxlib/mosaic/gpu/target.cc index a1a66a709cbe..8e08dc125a49 100644 --- a/jaxlib/mosaic/gpu/target.cc +++ b/jaxlib/mosaic/gpu/target.cc @@ -16,29 +16,32 @@ limitations under the License. #include #include -#include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" +#include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" -#include "llvm/include/llvm/MC/MCSubtargetInfo.h" -#include "llvm/include/llvm/MC/TargetRegistry.h" +#include "absl/strings/strip.h" +#include "llvm/MC/MCSubtargetInfo.h" +#include "llvm/MC/TargetRegistry.h" +#include "llvm/TargetParser/Triple.h" namespace mosaic::gpu { -absl::StatusOr> GetSmAndPtxIsaVersion( - int major, int minor) { +absl::StatusOr GetSmVersion(int major, int minor) { // "base" compute capability as reported by the driver. // For example for a Hopper H200 GPU this would return sm_90, and never // sm_90a. std::string sm_base = absl::StrCat("sm_", major, minor); const std::string triple = "nvptx64-nvidia-cuda"; + const llvm::Triple target_triple(triple); std::string error; const llvm::Target* target = - llvm::TargetRegistry::lookupTarget(triple, error); + llvm::TargetRegistry::lookupTarget(target_triple, error); if (target == nullptr) { return absl::InternalError(absl::StrFormat( "Failed to lookup LLVM target based on triple %s: %s", triple, error)); @@ -50,7 +53,7 @@ absl::StatusOr> GetSmAndPtxIsaVersion( { // generic subtarget std::unique_ptr subtarget_info{ - target->createMCSubtargetInfo(triple, "", "")}; + target->createMCSubtargetInfo(target_triple, "", "")}; if (subtarget_info == nullptr) { return absl::InternalError(absl::StrFormat( "Failed to get generic LLVM subtarget info for triple %s", triple)); @@ -64,25 +67,42 @@ absl::StatusOr> GetSmAndPtxIsaVersion( } } } + return sm_arch_specific ? sm_arch_specific : sm_base; +} - const std::string sm = sm_arch_specific ? sm_arch_specific : sm_base; - +absl::StatusOr GetLatestLlvmPtxIsaVersion() { + const std::string triple = "nvptx64-nvidia-cuda"; + const llvm::Triple target_triple(triple); + std::string error; + const llvm::Target* target = + llvm::TargetRegistry::lookupTarget(target_triple, error); + if (target == nullptr) { + return absl::InternalError(absl::StrFormat( + "Failed to lookup LLVM target based on triple %s: %s", triple, error)); + } + // generic subtarget std::unique_ptr subtarget_info{ - target->createMCSubtargetInfo(triple, sm, "")}; + target->createMCSubtargetInfo(target_triple, "", "")}; if (subtarget_info == nullptr) { - return absl::InternalError( - absl::StrFormat("Failed to get LLVM subtarget info for sm %s", sm)); + return absl::InternalError(absl::StrFormat( + "Failed to get generic LLVM subtarget info for triple %s", triple)); } - + int llvm_latest_version = 0; for (const llvm::SubtargetFeatureKV& feature : - subtarget_info->getEnabledProcessorFeatures()) { - if (absl::StartsWith(feature.Key, "ptx")) { - std::string ptx_isa = feature.Key; - return std::make_pair(sm, ptx_isa); + subtarget_info->getAllProcessorFeatures()) { + std::string_view version_string = feature.Key; + if (absl::ConsumePrefix(&version_string, "ptx")) { + int version; + if (!absl::SimpleAtoi(version_string, &version)) { + return absl::InternalError( + absl::StrFormat("Failed to convert PTX ISA version to integer: %s", + version_string)); + } + llvm_latest_version = + version > llvm_latest_version ? version : llvm_latest_version; } } - return absl::InternalError(absl::StrFormat( - "Failed to find a PTX ISA LLVM subtarget feature for %s", sm)); + return llvm_latest_version; } } // namespace mosaic::gpu diff --git a/jaxlib/mosaic/gpu/target.h b/jaxlib/mosaic/gpu/target.h index 070ecedebd01..5a2a240d8db1 100644 --- a/jaxlib/mosaic/gpu/target.h +++ b/jaxlib/mosaic/gpu/target.h @@ -22,8 +22,8 @@ limitations under the License. namespace mosaic::gpu { -absl::StatusOr> GetSmAndPtxIsaVersion( - int major, int minor); +absl::StatusOr GetSmVersion(int major, int minor); +absl::StatusOr GetLatestLlvmPtxIsaVersion(); } // namespace mosaic::gpu diff --git a/jaxlib/mosaic/gpu/tiled_layout.cc b/jaxlib/mosaic/gpu/tiled_layout.cc new file mode 100644 index 000000000000..9a7912e8a106 --- /dev/null +++ b/jaxlib/mosaic/gpu/tiled_layout.cc @@ -0,0 +1,641 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/mosaic/gpu/tiled_layout.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "xla/tsl/platform/statusor.h" + +namespace jax::mosaic::gpu { +namespace { + +constexpr int64_t WARP_SIZE = 32; +constexpr int64_t WARPGROUP_SIZE = 128; +constexpr int64_t WARPS_IN_WARPGROUP = WARPGROUP_SIZE / WARP_SIZE; + +int64_t DimSize(const TiledLayout::Dim& d, + const std::vector& tiled_shape) { + if (std::holds_alternative(d)) { + return std::get(d).times; + } + int64_t idx = std::get(d); + CHECK(idx < 0) << "Dimension index must be negative"; + idx += tiled_shape.size(); + CHECK(idx >= 0 && idx < tiled_shape.size()) << "Dimension index out of range"; + return tiled_shape[idx]; +} + +std::vector PartitionedDims( + const std::vector& dims) { + std::vector result; + for (const auto& d : dims) { + if (std::holds_alternative(d)) { + result.push_back(std::get(d)); + } + } + return result; +} + +} // namespace + +Tiling::Tiling(std::vector tiles) : tiles_(std::move(tiles)) {} + +absl::StatusOr Tiling::Create(std::vector tiles) { + size_t last_tile_rank = std::numeric_limits::max(); + for (const Tile& tile : tiles) { + if (tile.size() > last_tile_rank) { + return absl::InvalidArgumentError("Tiles must have a decreasing rank"); + } + if (tile.empty()) { + return absl::InvalidArgumentError("Tiles must not be empty"); + } + if (absl::c_any_of(tile, [](int64_t d) { return d <= 0; })) { + return absl::InvalidArgumentError( + "Tile shape must only have positive sizes"); + } + last_tile_rank = tile.size(); + } + + return Tiling(std::move(tiles)); +} + +absl::StatusOr> Tiling::TileShape( + const std::vector& shape) const { + std::vector current_shape = shape; + for (const Tile& tile : tiles_) { + if (tile.size() > current_shape.size()) { + return absl::InvalidArgumentError("Tiling does not apply to shape"); + } + size_t untiled_rank = current_shape.size() - tile.size(); + std::vector next_shape; + next_shape.reserve(untiled_rank + 2 * tile.size()); + for (size_t i = 0; i < untiled_rank; ++i) { + next_shape.push_back(current_shape[i]); + } + for (size_t i = 0; i < tile.size(); ++i) { + int64_t dim = current_shape[untiled_rank + i]; + int64_t t = tile[i]; + if (dim % t != 0) { + return absl::InvalidArgumentError( + "Dimension not divisible by tile size"); + } + next_shape.push_back(dim / t); + } + for (int64_t t : tile) { + next_shape.push_back(t); + } + current_shape = std::move(next_shape); + } + return current_shape; +} + +absl::StatusOr> Tiling::UntileShape( + const std::vector& shape) const { + std::vector current_shape = shape; + for (auto it = tiles_.rbegin(); it != tiles_.rend(); ++it) { + const Tile& tile = *it; + if (tile.size() * 2 > current_shape.size()) { + return absl::InvalidArgumentError("Invalid tiled shape"); + } + size_t untiled_rank = current_shape.size() - 2 * tile.size(); + std::vector next_shape; + next_shape.reserve(untiled_rank + tile.size()); + for (size_t i = 0; i < untiled_rank; ++i) { + next_shape.push_back(current_shape[i]); + } + for (size_t i = 0; i < tile.size(); ++i) { + int64_t outer = current_shape[untiled_rank + i]; + int64_t inner = current_shape[untiled_rank + tile.size() + i]; + if (inner != tile[i]) { + return absl::InvalidArgumentError("Tiling dimension mismatch"); + } + next_shape.push_back(outer * inner); + } + current_shape = std::move(next_shape); + } + return current_shape; +} + +std::vector Tiling::TileStrides( + const std::vector& strides) const { + std::vector current_strides = strides; + for (const Tile& tile : tiles_) { + size_t untiled_rank = current_strides.size() - tile.size(); + std::vector next_strides; + next_strides.reserve(untiled_rank + 2 * tile.size()); + for (size_t i = 0; i < untiled_rank; ++i) { + next_strides.push_back(current_strides[i]); + } + for (size_t i = 0; i < tile.size(); ++i) { + next_strides.push_back(current_strides[untiled_rank + i] * tile[i]); + } + for (size_t i = 0; i < tile.size(); ++i) { + next_strides.push_back(current_strides[untiled_rank + i]); + } + current_strides = std::move(next_strides); + } + return current_strides; +} + +std::vector Tiling::TileIndices( + const std::vector& indices) const { + std::vector current_indices = indices; + for (const Tile& tile : tiles_) { + size_t untiled_rank = current_indices.size() - tile.size(); + std::vector next_indices; + next_indices.reserve(untiled_rank + 2 * tile.size()); + for (size_t i = 0; i < untiled_rank; ++i) { + next_indices.push_back(current_indices[i]); + } + for (size_t i = 0; i < tile.size(); ++i) { + next_indices.push_back(current_indices[untiled_rank + i] / tile[i]); + } + for (size_t i = 0; i < tile.size(); ++i) { + next_indices.push_back(current_indices[untiled_rank + i] % tile[i]); + } + current_indices = std::move(next_indices); + } + return current_indices; +} + +std::vector Tiling::UntileIndices( + const std::vector& indices) const { + std::vector current_indices = indices; + for (auto it = tiles_.rbegin(); it != tiles_.rend(); ++it) { + const Tile& tile = *it; + size_t untiled_rank = current_indices.size() - 2 * tile.size(); + std::vector next_indices; + next_indices.reserve(untiled_rank + tile.size()); + for (size_t i = 0; i < untiled_rank; ++i) { + next_indices.push_back(current_indices[i]); + } + for (size_t i = 0; i < tile.size(); ++i) { + int64_t outer = current_indices[untiled_rank + i]; + int64_t inner = current_indices[untiled_rank + tile.size() + i]; + next_indices.push_back(outer * tile[i] + inner); + } + current_indices = std::move(next_indices); + } + return current_indices; +} + +absl::StatusOr>, + std::vector>>> +Tiling::TileNestedShapeStrides( + const std::vector>& shape, + const std::vector>& strides) const { + if (shape.size() != strides.size()) { + return absl::InvalidArgumentError( + "Shape and strides must have the same length"); + } + std::vector> current_shape = shape; + std::vector> current_strides = strides; + + for (const Tile& tile : tiles_) { + if (tile.size() > current_shape.size()) { + return absl::InvalidArgumentError("Tiling does not apply to shape"); + } + size_t untiled_rank = current_shape.size() - tile.size(); + std::vector> next_shape; + std::vector> next_strides; + next_shape.reserve(untiled_rank + 2 * tile.size()); + next_strides.reserve(untiled_rank + 2 * tile.size()); + + for (size_t i = 0; i < untiled_rank; ++i) { + next_shape.push_back(current_shape[i]); + next_strides.push_back(current_strides[i]); + } + + std::vector> major_dim_shapes; + std::vector> minor_dim_shapes; + std::vector> major_dim_strides; + std::vector> minor_dim_strides; + + for (size_t i = 0; i < tile.size(); ++i) { + int64_t t = tile[i]; + const std::vector& dim_shape = current_shape[untiled_rank + i]; + const std::vector& dim_strides = + current_strides[untiled_rank + i]; + + std::vector major_dim_shape_rev, major_dim_stride_rev; + std::vector minor_dim_shape_rev, minor_dim_stride_rev; + + for (size_t j = 0; j < dim_shape.size(); ++j) { + size_t idx = dim_shape.size() - 1 - j; + int64_t d = dim_shape[idx]; + int64_t s = dim_strides[idx]; + + if (d < t) { + if (t % d != 0) { + return absl::InvalidArgumentError( + "Dimension not divisible by tile size"); + } + t /= d; + minor_dim_shape_rev.push_back(d); + minor_dim_stride_rev.push_back(s); + } else if (t != 1) { + if (d % t != 0) { + return absl::InvalidArgumentError( + "Dimension not divisible by tile size"); + } + minor_dim_shape_rev.push_back(t); + minor_dim_stride_rev.push_back(s); + if (d != t) { + major_dim_shape_rev.push_back(d / t); + major_dim_stride_rev.push_back(s * t); + } + t = 1; + } else { + major_dim_shape_rev.push_back(d); + major_dim_stride_rev.push_back(s); + } + } + if (t != 1) { + return absl::InvalidArgumentError("Tile size too large for dimension"); + } + + major_dim_shapes.push_back(std::vector( + major_dim_shape_rev.rbegin(), major_dim_shape_rev.rend())); + major_dim_strides.push_back(std::vector( + major_dim_stride_rev.rbegin(), major_dim_stride_rev.rend())); + minor_dim_shapes.push_back(std::vector( + minor_dim_shape_rev.rbegin(), minor_dim_shape_rev.rend())); + minor_dim_strides.push_back(std::vector( + minor_dim_stride_rev.rbegin(), minor_dim_stride_rev.rend())); + } + next_shape.insert(next_shape.end(), major_dim_shapes.begin(), + major_dim_shapes.end()); + next_shape.insert(next_shape.end(), minor_dim_shapes.begin(), + minor_dim_shapes.end()); + next_strides.insert(next_strides.end(), major_dim_strides.begin(), + major_dim_strides.end()); + next_strides.insert(next_strides.end(), minor_dim_strides.begin(), + minor_dim_strides.end()); + current_shape = std::move(next_shape); + current_strides = std::move(next_strides); + } + + auto normalize = [](std::vector>& v) { + for (std::vector& d : v) { + if (d.empty()) { + d.push_back(1); + } + } + }; + normalize(current_shape); + normalize(current_strides); + + return std::make_pair(std::move(current_shape), std::move(current_strides)); +} + +absl::StatusOr> Tiling::TileDimension(int dim) const { + size_t tiling_rank = tiles_[0].size(); + if (dim < 0 || dim >= tiling_rank) { + return absl::InvalidArgumentError("Invalid dimension"); + } + std::vector strides(tiling_rank, 1); + strides[dim] = 0; + std::vector tiled_strides = TileStrides(strides); + std::vector result; + result.reserve(tiled_strides.size()); + for (int64_t s : tiled_strides) { + result.push_back(s == 0); + } + return result; +} + +absl::StatusOr Tiling::RemoveDimension(int dim) const { + size_t tiling_rank = tiles_[0].size(); + if (dim < 0 || dim >= tiling_rank) { + return absl::InvalidArgumentError("Invalid dimension"); + } + int dim_in_tile = dim; + std::vector new_tiles; + size_t last_tile_rank = tiling_rank; + for (Tile t : tiles_) { + if (last_tile_rank < t.size()) { + return absl::InvalidArgumentError("Rank invariant violated"); + } + dim_in_tile -= (last_tile_rank - t.size()); + last_tile_rank = t.size(); + if (dim_in_tile >= 0) { + t.erase(t.begin() + dim_in_tile); + } + if (t.empty()) break; + new_tiles.push_back(std::move(t)); + } + return Tiling(std::move(new_tiles)); +} + +Tiling Tiling::Canonicalize() const { + if (tiles_.size() <= 1) return *this; + std::vector new_tiles; + new_tiles.push_back(tiles_[0]); + Tile shape = tiles_[0]; + for (size_t i = 1; i < tiles_.size(); ++i) { + const Tile& tile = tiles_[i]; + Tile canonical_tile; + bool found_non_one = false; + for (size_t j = 0; j < tile.size(); ++j) { + if (tile[j] != 1) { + canonical_tile.assign(tile.begin() + j, tile.end()); + found_non_one = true; + break; + } + } + if (!found_non_one) { + canonical_tile = {1}; + } + + bool redundant = true; + if (shape.size() < canonical_tile.size()) { + redundant = false; + } else { + for (size_t k = 0; k < canonical_tile.size(); ++k) { + if (shape[shape.size() - canonical_tile.size() + k] != + canonical_tile[k]) { + redundant = false; + break; + } + } + } + + if (redundant) continue; + shape = canonical_tile; + new_tiles.push_back(std::move(canonical_tile)); + } + return Tiling(std::move(new_tiles)); +} + +std::string Tiling::ToString() const { + std::stringstream ss; + ss << "Tiling("; + for (const Tile& tile : tiles_) { + ss << "("; + for (size_t i = 0; i < tile.size(); ++i) { + if (i > 0) ss << ", "; + ss << tile[i]; + } + if (tile.size() == 1) ss << ","; + ss << ")"; + } + ss << ")"; + return ss.str(); +} + +bool Tiling::operator==(const Tiling& other) const { + return tiles_ == other.tiles_; +} + +std::ostream& operator<<(std::ostream& os, const Tiling& tiling) { + return os << tiling.ToString(); +} + +std::string Replicated::ToString() const { + std::stringstream ss; + ss << "Replicated(" << times << ")"; + return ss.str(); +} + +absl::StatusOr TiledLayout::Create(Tiling tiling, + std::vector warp_dims, + std::vector lane_dims, + int64_t vector_dim, + bool check_canonical) { + if (tiling.tiles().empty()) { + return absl::InvalidArgumentError("Tiling must have at least one tile"); + } + const Tiling::Tile& min_shape = tiling.tiles()[0]; + TF_ASSIGN_OR_RETURN(std::vector min_tiled_shape, + tiling.TileShape(min_shape)); + + std::vector partitioned_warp_dims = PartitionedDims(warp_dims); + std::vector partitioned_lane_dims = PartitionedDims(lane_dims); + + // Keeping the dimensions in a std::vector as the size is small and the extra + // overhead of a set would be larger. + std::vector dims_set; + dims_set.insert(dims_set.end(), partitioned_warp_dims.begin(), + partitioned_warp_dims.end()); + dims_set.insert(dims_set.end(), partitioned_lane_dims.begin(), + partitioned_lane_dims.end()); + dims_set.push_back(vector_dim); + + for (int64_t d : dims_set) { + if (d >= 0) { + return absl::InvalidArgumentError("All dimensions must be negative"); + } + + if (d < -static_cast(min_tiled_shape.size() - min_shape.size())) { + return absl::InvalidArgumentError("Dimension out of range"); + } + } + + std::sort(dims_set.begin(), dims_set.end()); + for (size_t i = 1; i < dims_set.size(); ++i) { + if (dims_set[i] == dims_set[i - 1]) { + return absl::InvalidArgumentError("Duplicate partitioning dimensions"); + } + } + + int64_t warp_dims_prod = 1; + for (const Dim& d : warp_dims) { + warp_dims_prod *= DimSize(d, min_tiled_shape); + } + if (warp_dims_prod != WARPS_IN_WARPGROUP) { + return absl::InvalidArgumentError( + "The product of warp dims does not equal the number of warps in a " + "warpgroup"); + } + + int64_t lane_dims_prod = 1; + for (const auto& d : lane_dims) { + lane_dims_prod *= DimSize(d, min_tiled_shape); + } + if (lane_dims_prod != WARP_SIZE) { + return absl::InvalidArgumentError( + "The product of lane dims does not equal the warp size"); + } + + TiledLayout layout(std::move(tiling), std::move(warp_dims), + std::move(lane_dims), vector_dim); + if (check_canonical) { + TF_ASSIGN_OR_RETURN(TiledLayout canonical, layout.Canonicalize()); + if (canonical != layout) { + return absl::InvalidArgumentError("TiledLayout is not canonical"); + } + } + return layout; +} + +absl::StatusOr> TiledLayout::TiledTilingShape() const { + const Tiling::Tile& min_shape = tiling_.tiles()[0]; + TF_ASSIGN_OR_RETURN(std::vector min_tiled_shape, + tiling_.TileShape(min_shape)); + return std::vector(min_tiled_shape.begin() + min_shape.size(), + min_tiled_shape.end()); +} + +absl::StatusOr TiledLayout::Canonicalize() const { + Tiling canonical_tiling = tiling_.Canonicalize(); + const std::vector& s = tiling_.tiles()[0]; + TF_ASSIGN_OR_RETURN(std::vector tiled_tiling_shape, + TiledTilingShape()); + + TF_ASSIGN_OR_RETURN(std::vector canonical_tiled_tiling_shape, + canonical_tiling.TileShape(s)); + canonical_tiled_tiling_shape.erase( + canonical_tiled_tiling_shape.begin(), + canonical_tiled_tiling_shape.begin() + s.size()); + + int64_t offset = + static_cast(canonical_tiled_tiling_shape.size()) - 1; + std::vector rev_removed_dims; + // Iterate starting from the end in order to eliminate leading dimensions, + // whenever possible. For instance, say we have + // + // shape=(4, 32, 1, 1, 1, 1, 1) + // warp_dims=(-7,), + // lane_dims=(-6,) + // vector_dim=-1 + // + // and we want to canonicalize this to + // + // shape=(4, 32, 1) + // warp_dims=(-3,), + // lane_dims=(-2,) + // vector_dim=-1. + // + // After the loop below, we end up with + // + // rev_removed_dims=[False, True, True, True, True, False, False] + // + // which will yield offsets `4` for `warp_dims[0]`, `4` for `lane_dims[0]`, + // and `0` for `vector_dim`. + for (auto it = tiled_tiling_shape.rbegin(); it != tiled_tiling_shape.rend(); + ++it) { + if (offset >= 0 && *it == canonical_tiled_tiling_shape[offset]) { + rev_removed_dims.push_back(false); + offset--; + } else { + rev_removed_dims.push_back(true); + } + } + CHECK_EQ(offset, -1); + + std::vector dim_offsets(rev_removed_dims.size()); + int64_t current_sum = 0; + for (size_t i = 0; i < rev_removed_dims.size(); ++i) { + if (rev_removed_dims[i]) { + current_sum++; + } + dim_offsets[i] = current_sum; + } + std::reverse(dim_offsets.begin(), dim_offsets.end()); + + auto replace_tiled_dim = [&](Dim d) -> Dim { + if (std::holds_alternative(d)) { + return d; + } + int64_t idx = std::get(d); + CHECK(idx < 0) << "Expected negative index"; + return idx + dim_offsets[idx + tiled_tiling_shape.size()]; + }; + + auto is_nontrivial = [&](Dim d) -> bool { + if (std::holds_alternative(d)) { + return true; + } + int64_t idx = std::get(d); + CHECK(idx < 0) << "Expected negative index"; + return tiled_tiling_shape[idx + tiled_tiling_shape.size()] != 1; + }; + + std::vector new_warp_dims; + for (const auto& d : warp_dims_) { + if (is_nontrivial(d)) { + new_warp_dims.push_back(replace_tiled_dim(d)); + } + } + std::vector new_lane_dims; + for (const auto& d : lane_dims_) { + if (is_nontrivial(d)) { + new_lane_dims.push_back(replace_tiled_dim(d)); + } + } + Dim new_vector_dim_val = replace_tiled_dim(vector_dim_); + int64_t new_vector_dim = std::get(new_vector_dim_val); + + return TiledLayout(canonical_tiling, new_warp_dims, new_lane_dims, + new_vector_dim); +} + +std::vector TiledLayout::PartitionedWarpDims() const { + return PartitionedDims(warp_dims_); +} + +std::vector TiledLayout::PartitionedLaneDims() const { + return PartitionedDims(lane_dims_); +} + +absl::StatusOr TiledLayout::VectorLength() const { + TF_ASSIGN_OR_RETURN(std::vector tiled_tiling_shape, + TiledTilingShape()); + return DimSize(vector_dim_, tiled_tiling_shape); +} + +bool TiledLayout::operator==(const TiledLayout& other) const { + return tiling_ == other.tiling_ && warp_dims_ == other.warp_dims_ && + lane_dims_ == other.lane_dims_ && vector_dim_ == other.vector_dim_; +} + +std::string TiledLayout::ToString() const { + std::stringstream ss; + ss << "TiledLayout(tiling=" << tiling_.ToString() << ", warp_dims=("; + for (size_t i = 0; i < warp_dims_.size(); ++i) { + if (i > 0) ss << ", "; + if (std::holds_alternative(warp_dims_[i])) { + ss << "Replicated(" << std::get(warp_dims_[i]).times << ")"; + } else { + ss << std::get(warp_dims_[i]); + } + } + ss << "), lane_dims=("; + for (size_t i = 0; i < lane_dims_.size(); ++i) { + if (i > 0) ss << ", "; + if (std::holds_alternative(lane_dims_[i])) { + ss << "Replicated(" << std::get(lane_dims_[i]).times << ")"; + } else { + ss << std::get(lane_dims_[i]); + } + } + ss << "), vector_dim=" << vector_dim_ << ")"; + return ss.str(); +} + +} // namespace jax::mosaic::gpu diff --git a/jaxlib/mosaic/gpu/tiled_layout.h b/jaxlib/mosaic/gpu/tiled_layout.h new file mode 100644 index 000000000000..4bdd55c3cebd --- /dev/null +++ b/jaxlib/mosaic/gpu/tiled_layout.h @@ -0,0 +1,237 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_PY_JAX_EXPERIMENTAL_MOSAIC_GPU_CC_TILED_LAYOUT_H_ +#define THIRD_PARTY_PY_JAX_EXPERIMENTAL_MOSAIC_GPU_CC_TILED_LAYOUT_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/status/statusor.h" + +namespace jax::mosaic::gpu { + +// A tiling expression describing a permutation of elements of an nd-array. +// +// To apply one level of tiling to an array, each of the trailing dimensions (up +// to the rank of the tile) is unfolded into two dimensions: first equal to the +// ratio of the dimension size and the tile size, and second equal to the tile +// size. Then, all newly unfolded minor dimensions are transposed to appear at +// the end. +// +// This expression describes multi-level tiling, by applying each element of +// `tiles` in sequence to the array. +// +// See https://openxla.org/xla/tiled_layout for a more detailed explanation. +class Tiling { + public: + using Tile = std::vector; + + static absl::StatusOr Create(std::vector tiles); + + bool operator==(const Tiling& other) const; + bool operator!=(const Tiling& other) const { return !(*this == other); } + + const std::vector& tiles() const { return tiles_; } + + // Compute the shape of an array after tiling. + absl::StatusOr> TileShape( + const std::vector& shape) const; + + // Compute the shape of an array before tiling from its tiled shape. + absl::StatusOr> UntileShape( + const std::vector& shape) const; + + // Compute the strides of an array after tiling. + std::vector TileStrides(const std::vector& strides) const; + + // Compute the indices of an array after tiling. + std::vector TileIndices(const std::vector& indices) const; + + // Compute the indices of an array before tiling from its tiled indices. + std::vector UntileIndices(const std::vector& indices) const; + + // A fused version of `TileShape` and `TileStrides` for nested shapes. + // + // By nested shape we mean that each logical dimension (i.e. each element of + // shape/strides) is actually composed out of multiple physical dimensions. + // For example, a row-major array of logical shape (128, 128) that is tiled + // into (64, 64) tiles would have a nested shape ((2, 64), (2, 64)) (i.e. each + // dim is split into two sub-dims) and nested strides of + // ((2 * 64 * 64, 64), (64 * 64, 1)). + absl::StatusOr, std::vector>> + TileNestedShapeStrides( + const std::vector>& shape, + const std::vector>& strides) const; + + // Returns true if the tiled dim originated from the given input dim. + absl::StatusOr> TileDimension(int dim) const; + + // Returns a tiling with the given dimension removed. + absl::StatusOr RemoveDimension(int dim) const; + + // We define a tiling to be canonical if, at each step (except the first one, + // which defines the base tile shape): + + // 1. The tiling partitions at least one dimension in more than 1 tile. For + // example, the tiling `(8, 8)(8, 8)` is not canonical, as applying it + // yields a shape `(1, 1, 8, 8)`. We canonicalize it to `(8, 8)`, which + // allows getting rid of the unnecessary `1` dimensions. + // 2. The leading dimensions of each tile are not `1`. If canonicalizing a + // tile in this way leads to an empty tile, then the tile is given shape + // `(1,)`---which is still a meaningful (final) tile. For example, the + // tiling `(8, 8)(1, 4)` is not canonical, as applying it yields a shape + // `(8, 2, 1, 4)`. We canonicalize it to `(8, 8)(4,)`, which allows + // getting rid of the unnecessary `1` dimension, and yields a shape + // `(8, 2, 4)`. + Tiling Canonicalize() const; + + std::string ToString() const; + + template + friend H AbslHashValue(H h, const Tiling& tiling) { + return H::combine(std::move(h), tiling.tiles_); + } + + private: + explicit Tiling(std::vector tiles); + + std::vector tiles_; +}; + +// Type wrapper for the number of times a dimension is replicated. +struct Replicated { + int64_t times; + + std::string ToString() const; + + bool operator==(const Replicated& other) const { + return times == other.times; + } + + template + friend H AbslHashValue(H h, const Replicated& rep) { + return H::combine(std::move(h), rep.times); + } +}; + +// A FragmentedArray layout derived from a tiling expression. + +// A logical array is transformed according to the tiling expression, and then +// split across warps (within a warpgroup), lanes, and vectorized according to +// the dimension indices. All dimension indices must be negative and should +// refer to the dimensions after tiling is applied. +// +// To better understand this layout, consider the example of WGMMA-related +// tiling from +// https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-d as +// applied to a 128x128 array. The corresponding TiledLayout has a tiling of: +// +// (64, 8)(16, 8)(8, 8)(1, 2) +// +// and warp_dims=(-8,), lane_dims=(-4, -3), vector_dim=-1. +// +// We begin by applying the tiling (note that it always applies to a suffix): +// +// Tiled shape Remaining tiling actions +// =========================================================================== +// 128 128 (64, 8)(16, 8)(8, 8)(1, 2) +// 2 16 64 8 (16, 8)(8, 8)(1, 2) +// 2 16 4 1 16 8 (8, 8)(1, 2) +// 2 16 4 1 2 1 8 8 (1, 2) +// 2 16 4 1 2 1 8 4 1 2 +// +// The last expression is our final shape. At this stage, we're ready to +// partition the dimensions: warp_dims=(-8,) means that the 8-th dimension from +// the end is partitioned over 4 warps in a warpgroup (and so it must be of +// size 4). lane_dims=(-4, -3) indicate that those two dimensions are +// partitioned over the lanes within a warp (their product must be equal to 32, +// i.e. warp size). Finally, vector_dim=-1 indicates that each (logical) +// register is a vector containing 2 elements (there are no shape restrictions +// here). +// +// Given the above, the shape of the (logical) register array used to represent +// the array in each thread is: (2, 16, 1, 1, 2, 1, 1, 1, 1, 1). We have set +// all the dimensions above to 1, since each thread is a member of a single +// warp, a single lane, and the elements along the vectorized dimension are +// represented by a single (logical) register. +// +class TiledLayout { + public: + using Dim = std::variant; + + static absl::StatusOr Create(Tiling tiling, + std::vector warp_dims, + std::vector lane_dims, + int64_t vector_dim, + bool check_canonical = true); + virtual ~TiledLayout() = default; + + bool operator==(const TiledLayout& other) const; + bool operator!=(const TiledLayout& other) const { return !(*this == other); } + + const Tiling& tiling() const { return tiling_; } + const std::vector& warp_dims() const { return warp_dims_; } + const std::vector& lane_dims() const { return lane_dims_; } + int64_t vector_dim() const { return vector_dim_; } + + // Returns the shape of the tiled tiling (without the base tile shape part). + absl::StatusOr> TiledTilingShape() const; + + // Canonicalizes the layout. E.g. If the tiling suffix is + // (4, 32, 1, 1, 1), vector_dim = -1, warp_dims = {-5}, lane_dims = {-4} + // then the canonicalized layout is + // (4, 32, 1), vector_dim = -1, warp_dims = {-3}, lane_dims = {-2} + absl::StatusOr Canonicalize() const; + + // Returns the partitioned warp dimensions verbatim. + std::vector PartitionedWarpDims() const; + + // Returns the partitioned lane dimensions verbatim. + std::vector PartitionedLaneDims() const; + + // Returns the size of the vector dimension. E.g. if the tiling suffix is + // (..., 4), and vector_dims = {-1}, then the vector length is 4. + absl::StatusOr VectorLength() const; + + template + friend H AbslHashValue(H h, const TiledLayout& layout) { + return H::combine(std::move(h), layout.tiling_, layout.warp_dims_, + layout.lane_dims_, layout.vector_dim_); + } + + std::string ToString() const; + + private: + TiledLayout(Tiling tiling, std::vector warp_dims, + std::vector lane_dims, int64_t vector_dim) + : tiling_(std::move(tiling)), + warp_dims_(std::move(warp_dims)), + lane_dims_(std::move(lane_dims)), + vector_dim_(vector_dim) {}; + + Tiling tiling_; + std::vector warp_dims_; + std::vector lane_dims_; + int64_t vector_dim_; +}; + +} // namespace jax::mosaic::gpu + +#endif // THIRD_PARTY_PY_JAX_EXPERIMENTAL_MOSAIC_GPU_CC_TILED_LAYOUT_H_ diff --git a/jaxlib/mosaic/gpu/tiled_layout_test.cc b/jaxlib/mosaic/gpu/tiled_layout_test.cc new file mode 100644 index 000000000000..817799bb47c4 --- /dev/null +++ b/jaxlib/mosaic/gpu/tiled_layout_test.cc @@ -0,0 +1,241 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/mosaic/gpu/tiled_layout.h" + +#include +#include +#include + +#include +#include +#include "absl/status/status.h" + +namespace jax::mosaic::gpu { +namespace { + +using ::testing::ElementsAre; +using ::testing::status::IsOkAndHolds; +using ::testing::status::StatusIs; + +TEST(TilingTest, TileNestedShapeStrides) { + ASSERT_OK_AND_ASSIGN(Tiling tiling, Tiling::Create({{64, 64}})); + std::vector> shape = {{128}, {128}}; + std::vector> strides = {{128}, {1}}; + + ASSERT_OK_AND_ASSIGN((auto [tiled_shape, tiled_strides]), + tiling.TileNestedShapeStrides(shape, strides)); + + std::vector> expected_shape = {{2}, {2}, {64}, {64}}; + std::vector> expected_strides = { + {64 * 128}, {64}, {128}, {1}}; + EXPECT_EQ(tiled_shape, expected_shape); + EXPECT_EQ(tiled_strides, expected_strides); +} + +TEST(TilingTest, TileNestedShapeStridesAlreadySplit) { + ASSERT_OK_AND_ASSIGN(Tiling tiling, Tiling::Create({{64, 64}})); + std::vector> shape = {{2, 64}, {2, 64}}; + std::vector> strides = {{64 * 128, 128}, {64, 1}}; + + ASSERT_OK_AND_ASSIGN((auto [tiled_shape, tiled_strides]), + tiling.TileNestedShapeStrides(shape, strides)); + + std::vector> expected_shape = {{2}, {2}, {64}, {64}}; + std::vector> expected_strides = { + {64 * 128}, {64}, {128}, {1}}; + EXPECT_EQ(tiled_shape, expected_shape); + EXPECT_EQ(tiled_strides, expected_strides); +} + +TEST(TilingTest, TileNestedShapeStridesMultiLevel) { + ASSERT_OK_AND_ASSIGN(Tiling tiling, Tiling::Create({{64, 64}, {8}})); + std::vector> shape = {{128}, {128}}; + std::vector> strides = {{128}, {1}}; + + ASSERT_OK_AND_ASSIGN((auto [tiled_shape, tiled_strides]), + tiling.TileNestedShapeStrides(shape, strides)); + + std::vector> expected_shape = {{2}, {2}, {64}, {8}, {8}}; + std::vector> expected_strides = { + {8192}, {64}, {128}, {8}, {1}}; + EXPECT_EQ(tiled_shape, expected_shape); + EXPECT_EQ(tiled_strides, expected_strides); +} + +TEST(TilingTest, TileIndices) { + ASSERT_OK_AND_ASSIGN(Tiling tiling, Tiling::Create({{64, 64}})); + std::vector indices = {70, 80}; + + std::vector tiled_indices = tiling.TileIndices(indices); + + EXPECT_THAT(tiled_indices, ElementsAre(1, 1, 6, 16)); +} + +TEST(TilingTest, UntileIndices) { + ASSERT_OK_AND_ASSIGN(Tiling tiling, Tiling::Create({{64, 64}})); + std::vector indices = {1, 1, 6, 16}; + + std::vector untiled_indices = tiling.UntileIndices(indices); + + EXPECT_THAT(untiled_indices, ElementsAre(70, 80)); +} + +TEST(TilingTest, TileIndicesMultiLevel) { + ASSERT_OK_AND_ASSIGN(Tiling tiling, Tiling::Create({{64, 64}, {8}})); + std::vector indices = {70, 80}; + + std::vector tiled_indices = tiling.TileIndices(indices); + + EXPECT_THAT(tiled_indices, ElementsAre(1, 1, 6, 2, 0)); +} + +TEST(TilingTest, UntileIndicesMultiLevel) { + ASSERT_OK_AND_ASSIGN(Tiling tiling, Tiling::Create({{64, 64}, {8}})); + std::vector indices = {1, 1, 6, 2, 0}; + + auto untiled_indices = tiling.UntileIndices(indices); + + EXPECT_THAT(untiled_indices, ElementsAre(70, 80)); +} + +TEST(TiledLayoutTest, Create) { + ASSERT_OK_AND_ASSIGN(Tiling tiling, + Tiling::Create({{64, 8}, {16, 8}, {8, 8}, {1, 2}})); + + ASSERT_OK_AND_ASSIGN( + TiledLayout layout, + TiledLayout::Create(std::move(tiling), + /*warp_dims=*/{-8}, + /*lane_dims=*/{-4, -3}, + /*vector_dim=*/-1, /*check_canonical=*/false)); + + EXPECT_THAT(layout.warp_dims(), ElementsAre(-8)); + EXPECT_THAT(layout.lane_dims(), ElementsAre(-4, -3)); + EXPECT_EQ(layout.vector_dim(), -1); +} + +TEST(TiledLayoutTest, CreateFailsWithDuplicateDims) { + ASSERT_OK_AND_ASSIGN(Tiling tiling, Tiling::Create({{10, 2}})); + + EXPECT_THAT(TiledLayout::Create(std::move(tiling), + /*warp_dims=*/{-1}, + /*lane_dims=*/{-1}, + /*vector_dim=*/-2), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(TiledLayoutTest, CreateFailsWithEmptyTiling) { + ASSERT_OK_AND_ASSIGN(Tiling tiling, Tiling::Create({})); + EXPECT_THAT(TiledLayout::Create(std::move(tiling), + /*warp_dims=*/{}, + /*lane_dims=*/{}, + /*vector_dim=*/-1), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(TiledLayoutTest, CreateFailsWithPositiveDim) { + ASSERT_OK_AND_ASSIGN(Tiling tiling, Tiling::Create({{32, 4}})); + EXPECT_THAT(TiledLayout::Create(std::move(tiling), + /*warp_dims=*/{1}, + /*lane_dims=*/{-2}, + /*vector_dim=*/-1), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(TiledLayoutTest, CreateFailsWithOutOfRangeDim) { + ASSERT_OK_AND_ASSIGN(Tiling tiling, Tiling::Create({{32, 4}})); + EXPECT_THAT(TiledLayout::Create(std::move(tiling), + /*warp_dims=*/{-3}, + /*lane_dims=*/{-2}, + /*vector_dim=*/-1), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(TiledLayoutTest, CreateFailsWithInvalidWarpDimsProduct) { + ASSERT_OK_AND_ASSIGN(Tiling tiling, Tiling::Create({{32, 4}})); + EXPECT_THAT(TiledLayout::Create(std::move(tiling), + /*warp_dims=*/{-2}, + /*lane_dims=*/{Replicated(32)}, + /*vector_dim=*/-1), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(TiledLayoutTest, CreateFailsWithInvalidLaneDimsProduct) { + ASSERT_OK_AND_ASSIGN(Tiling tiling, Tiling::Create({{8, 4, 32}})); + EXPECT_THAT(TiledLayout::Create(std::move(tiling), + /*warp_dims=*/{-2}, + /*lane_dims=*/{-3}, + /*vector_dim=*/-1), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(TiledLayoutTest, CreateFailsWithNonCanonicalLayout) { + ASSERT_OK_AND_ASSIGN(Tiling tiling, Tiling::Create({{4, 32}, {1, 1}})); + EXPECT_THAT(TiledLayout::Create(std::move(tiling), + /*warp_dims=*/{-4}, + /*lane_dims=*/{-3}, + /*vector_dim=*/-1, + /*check_canonical=*/true), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(TiledLayoutTest, Canonicalize) { + ASSERT_OK_AND_ASSIGN(Tiling tiling, + Tiling::Create({{4, 32, 1, 1}, {1, 1, 1, 1}})); + + ASSERT_OK_AND_ASSIGN( + TiledLayout layout, + TiledLayout::Create(std::move(tiling), + /*warp_dims=*/{-8}, + /*lane_dims=*/{-7}, + /*vector_dim=*/-1, /*check_canonical=*/false)); + + ASSERT_OK_AND_ASSIGN(TiledLayout canonical, layout.Canonicalize()); + + EXPECT_THAT(canonical.warp_dims(), ElementsAre(-4)); + EXPECT_THAT(canonical.lane_dims(), ElementsAre(-3)); + EXPECT_EQ(canonical.vector_dim(), -1); +} + +TEST(TiledLayoutTest, PartitionedDimsReturnAllPartitionedDims) { + ASSERT_OK_AND_ASSIGN(Tiling tiling, + Tiling::Create({{64, 8}, {32, 8}, {8, 8}, {1, 4}})); + ASSERT_OK_AND_ASSIGN( + TiledLayout layout, + TiledLayout::Create(std::move(tiling), + /*warp_dims=*/{-8, Replicated(2)}, + /*lane_dims=*/{-4, -3, Replicated(2)}, + /*vector_dim=*/-1, /*check_canonical=*/false)); + + EXPECT_THAT(layout.PartitionedWarpDims(), ElementsAre(-8)); + EXPECT_THAT(layout.PartitionedLaneDims(), ElementsAre(-4, -3)); +} + +TEST(TiledLayoutTest, VectorLengthReturnsTheSizeOfTheVectorDim) { + ASSERT_OK_AND_ASSIGN(Tiling tiling, + Tiling::Create({{64, 8}, {16, 8}, {8, 8}, {1, 2}})); + ASSERT_OK_AND_ASSIGN( + TiledLayout layout, + TiledLayout::Create(std::move(tiling), + /*warp_dims=*/{-8}, + /*lane_dims=*/{-4, -3}, + /*vector_dim=*/-1, /*check_canonical=*/false)); + + EXPECT_THAT(layout.VectorLength(), IsOkAndHolds(2)); +} + +} // namespace +} // namespace jax::mosaic::gpu diff --git a/jaxlib/mosaic/gpu/wheel/BUILD.bazel b/jaxlib/mosaic/gpu/wheel/BUILD.bazel new file mode 100644 index 000000000000..dc41a99de84e --- /dev/null +++ b/jaxlib/mosaic/gpu/wheel/BUILD.bazel @@ -0,0 +1,29 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +exports_files(["setup.py", "__init__.py"]) + +cc_binary( + name = "mosaic_gpu.so", + linkopts = [ + "-Wl,--version-script,$(location :mosaic_symbols.lds)", + "-Wl,--no-undefined", + ], + linkshared = True, + deps = [ + ":mosaic_symbols.lds", + "//jaxlib/mosaic/gpu:custom_call", + ], + visibility = ["//visibility:public"], +) diff --git a/jaxlib/mosaic/gpu/wheel/__init__.py b/jaxlib/mosaic/gpu/wheel/__init__.py new file mode 100644 index 000000000000..1559e73015fe --- /dev/null +++ b/jaxlib/mosaic/gpu/wheel/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This is currently just here so that users can access the version file. + +from .version import __version__ as __version__ diff --git a/jaxlib/mosaic/gpu/wheel/mosaic_symbols.lds b/jaxlib/mosaic/gpu/wheel/mosaic_symbols.lds new file mode 100644 index 000000000000..fbff5dc87d7d --- /dev/null +++ b/jaxlib/mosaic/gpu/wheel/mosaic_symbols.lds @@ -0,0 +1,10 @@ +VERS_1.0 { + global: + extern "C" { + MosaicGpuCompile; + MosaicGpuUnload; + }; + + local: + *; +}; diff --git a/jaxlib/mosaic/gpu/wheel/setup.py b/jaxlib/mosaic/gpu/wheel/setup.py new file mode 100644 index 000000000000..366acda1564d --- /dev/null +++ b/jaxlib/mosaic/gpu/wheel/setup.py @@ -0,0 +1,99 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import os +from setuptools import setup, find_namespace_packages + +__version__ = None +cuda_version = 0 # placeholder +project_name = f"mosaic_gpu-cuda{cuda_version}" +package_name = f"mosaic_gpu.mosaic_gpu_cuda{cuda_version}" + +cuda_wheel_suffix = '' # placeholder + +nvidia_cublas_version = '' # placeholder +nvidia_cuda_cupti_version = '' # placeholder +nvidia_cuda_nvcc_version = '' # placeholder +nvidia_cuda_runtime_version = '' # placeholder +nvidia_cudnn_version = '' # placeholder +nvidia_cufft_version = '' # placeholder +nvidia_cusolver_version = '' # placeholder +nvidia_cusparse_version = '' # placeholder +nvidia_nccl_version = '' # placeholder +nvidia_nvjitlink_version = '' # placeholder +nvidia_cuda_nvrtc_version = '' # placeholder +nvidia_nvshmem_version = '' # placeholder + +def load_version_module(pkg_path): + spec = importlib.util.spec_from_file_location( + 'version', os.path.join(pkg_path, 'version.py')) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + +_version_module = load_version_module(f"mosaic_gpu/mosaic_gpu_cuda{cuda_version}") +__version__ = _version_module._get_version_for_build() + +packages = find_namespace_packages( + include=[ + package_name, + f"{package_name}.*", + ] +) + +setup( + name=project_name, + version=__version__, + description="Mosaic GPU Support Plugin", + long_description="", + long_description_content_type="text/markdown", + author="JAX team", + author_email="jax-dev@google.com", + packages=packages, + install_requires=[], + extras_require={ + 'with-cuda': [ + # Using the same deps as JAX for now - can likely be trimmed down. + f"nvidia-cublas{cuda_wheel_suffix}{nvidia_cublas_version}", + f"nvidia-cuda-cupti{cuda_wheel_suffix}{nvidia_cuda_cupti_version}", + f"nvidia-cuda-nvcc{cuda_wheel_suffix}{nvidia_cuda_nvcc_version}", + f"nvidia-cuda-runtime{cuda_wheel_suffix}{nvidia_cuda_runtime_version}", + f"nvidia-cudnn-cu{cuda_version}{nvidia_cudnn_version}", + f"nvidia-cufft{cuda_wheel_suffix}{nvidia_cufft_version}", + f"nvidia-cusolver{cuda_wheel_suffix}{nvidia_cusolver_version}", + f"nvidia-cusparse{cuda_wheel_suffix}{nvidia_cusparse_version}", + f"nvidia-nccl-cu{cuda_version}{nvidia_nccl_version}", + f"nvidia-nvjitlink{cuda_wheel_suffix}{nvidia_nvjitlink_version}", + f"nvidia-cuda-nvrtc{cuda_wheel_suffix}{nvidia_cuda_nvrtc_version}", + f"nvidia-nvshmem-cu{cuda_version}{nvidia_nvshmem_version}", + ] + (["nvidia-nvvm"] if cuda_version == 13 else []), + }, + url="https://github.com/jax-ml/jax", + license="Apache-2.0", + classifiers=[ + "Development Status :: 5 - Production/Stable", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: Free Threading :: 3 - Stable", + ], + package_data={ + package_name: ["*.so"], + }, + zip_safe=False, + entry_points={ + "mosaic_gpu": [ + f"mosaic_gpu_cuda{cuda_version} = {package_name}", + ], + }, +) diff --git a/jaxlib/pass_boilerplate.h b/jaxlib/mosaic/pass_boilerplate.h similarity index 87% rename from jaxlib/pass_boilerplate.h rename to jaxlib/mosaic/pass_boilerplate.h index b9754a8738ee..96d9e85a1d2d 100644 --- a/jaxlib/pass_boilerplate.h +++ b/jaxlib/mosaic/pass_boilerplate.h @@ -13,15 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef JAXLIB_PASS_BOILERPLATE_H_ -#define JAXLIB_PASS_BOILERPLATE_H_ +#ifndef JAXLIB_MOSAIC_PASS_BOILERPLATE_H_ +#define JAXLIB_MOSAIC_PASS_BOILERPLATE_H_ #include -#include "mlir/include/mlir/IR/DialectRegistry.h" -#include "mlir/include/mlir/Pass/Pass.h" -#include "mlir/include/mlir/Support/LLVM.h" -#include "mlir/include/mlir/Support/TypeID.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/TypeID.h" namespace jaxlib { namespace mlir { @@ -64,4 +64,4 @@ class Pass : public ::mlir::OperationPass { } // namespace mlir } // namespace jaxlib -#endif // JAXLIB_PASS_BOILERPLATE_H_ +#endif // JAXLIB_MOSAIC_PASS_BOILERPLATE_H_ diff --git a/jaxlib/mosaic/python/BUILD b/jaxlib/mosaic/python/BUILD index 6e575fb3092a..d13b5efe5fbe 100644 --- a/jaxlib/mosaic/python/BUILD +++ b/jaxlib/mosaic/python/BUILD @@ -45,7 +45,7 @@ gentbl_filegroup( tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = ":tpu_python.td", deps = [ - "//jaxlib/mosaic:tpu_td_files", + "//jaxlib/mosaic:tpu_ops_td_files", "@llvm-project//mlir:OpBaseTdFiles", ], ) diff --git a/jaxlib/mosaic/python/layout_defs.py b/jaxlib/mosaic/python/layout_defs.py index 813c01d0ae5c..2a252beb83ee 100644 --- a/jaxlib/mosaic/python/layout_defs.py +++ b/jaxlib/mosaic/python/layout_defs.py @@ -52,6 +52,7 @@ def __bool__(self): class ImplicitDim(enum.IntEnum): MINOR = -1 SECOND_MINOR = -2 + MINOR_AND_SECOND_MINOR = -3 def __repr__(self) -> str: return str(int(self)) diff --git a/jaxlib/mosaic/python/tpu.py b/jaxlib/mosaic/python/tpu.py index a1c7f79ba769..8083b9759f1b 100644 --- a/jaxlib/mosaic/python/tpu.py +++ b/jaxlib/mosaic/python/tpu.py @@ -19,6 +19,7 @@ # pylint: disable=g-bad-import-order +from . import _tpu_gen from ._tpu_gen import * # pylint: disable=wildcard-import from ._tpu_gen import _Dialect from jaxlib.mlir._mlir_libs._tpu_ext import * # pylint: disable=wildcard-import @@ -32,7 +33,7 @@ @_cext.register_operation(_Dialect, replace=True) -class TraceOp(TraceOp): # noqa: F405 +class TraceOp(_tpu_gen.TraceOp): # noqa: F405 """An extension to the automatically generated TraceOp bindings.""" def __init__(self, results, message, level, *, loc=None, ip=None): @@ -45,7 +46,7 @@ def body(self): @_cext.register_operation(_Dialect, replace=True) -class RegionOp(RegionOp): # noqa: F405 +class RegionOp(_tpu_gen.RegionOp): # noqa: F405 """An extension to the automatically generated RegionOp bindings.""" def __init__(self, results, *, loc=None, ip=None): diff --git a/jaxlib/mosaic/python/tpu_python.td b/jaxlib/mosaic/python/tpu_python.td index 56abaadd7f36..a6abf92116b3 100644 --- a/jaxlib/mosaic/python/tpu_python.td +++ b/jaxlib/mosaic/python/tpu_python.td @@ -13,4 +13,4 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -include "jaxlib/mosaic/dialect/tpu/tpu.td" +include "jaxlib/mosaic/dialect/tpu/tpu_ops.td" diff --git a/jaxlib/mosaic/serde.cc b/jaxlib/mosaic/serde.cc index 88bca44bf181..263ff71c3ce0 100644 --- a/jaxlib/mosaic/serde.cc +++ b/jaxlib/mosaic/serde.cc @@ -18,15 +18,16 @@ limitations under the License. #include #include -#include "llvm/include/llvm/ADT/StringMap.h" -#include "llvm/include/llvm/ADT/StringRef.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/IR/Operation.h" -#include "mlir/include/mlir/IR/OperationSupport.h" -#include "mlir/include/mlir/IR/Visitors.h" -#include "mlir/include/mlir/Interfaces/DataLayoutInterfaces.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/WalkResult.h" namespace jaxlib::mosaic { @@ -54,7 +55,7 @@ std::optional demangle(llvm::StringRef name, mlir::LogicalResult RunSerde( mlir::ModuleOp module, const llvm::StringMap& upgrade_rules, const llvm::StringMap& downgrade_rules, bool serialize, - SerdeOptions options) { + SerdeOptions options, bool keep_version_attr) { int version = options.highest_version; int serialize_version = options.serialize_version; if (!serialize && serialize_version != -1) { @@ -90,14 +91,20 @@ mlir::LogicalResult RunSerde( return mlir::failure(); } version = version_attr.getInt(); - module->removeAttr(options.version_attr_name); + if (!keep_version_attr) { + module->removeAttr(options.version_attr_name); + } } std::string storage; - auto result = module.walk([&](mlir::Operation* op) { + // Explicitly use a post-order walk to allow for deleting operations on the + // fly. + auto result = module.walk([&](mlir::Operation* + op) { if (mlir::isa(op)) { // Don't mangle the ModuleOp itself. return mlir::WalkResult::advance(); } std::optional new_name; + bool was_erased = false; if (serialize) { auto new_name_str = mangle(op->getName().getStringRef(), options.dialect_prefix, &storage); @@ -119,28 +126,39 @@ mlir::LogicalResult RunSerde( // Upgrade the op to the current version, if needed. if (const auto rule = upgrade_rules.find(new_name->getStringRef()); rule != upgrade_rules.end()) { - if (rule->second(op, version).failed()) { + if (rule->second(op, version, was_erased).failed()) { return mlir::WalkResult::interrupt(); } } } + + // In this case, the op is no longer accessible, and can't be processed + // further. + if (was_erased) { + return mlir::WalkResult::advance(); + } + auto new_op = mlir::Operation::create( op->getLoc(), *new_name, op->getResultTypes(), op->getOperands(), op->getAttrs(), nullptr, op->getSuccessors(), op->getRegions()); + op->getBlock()->getOperations().insertAfter(mlir::Block::iterator(op), + new_op); // Downgrade the op to the target version, if needed. + bool downgrade_failed = false; if (serialize && version != serialize_version) { if (const auto rule = downgrade_rules.find(op->getName().getStringRef()); rule != downgrade_rules.end()) { - if (rule->second(new_op, serialize_version).failed()) { - return mlir::WalkResult::interrupt(); - } + downgrade_failed = + rule->second(new_op, serialize_version, was_erased).failed(); } } - op->getBlock()->getOperations().insertAfter(mlir::Block::iterator(op), - new_op); + if (was_erased) { + return mlir::WalkResult::advance(); + } op->replaceAllUsesWith(new_op->getResults()); op->erase(); - return mlir::WalkResult::advance(); + return downgrade_failed ? mlir::WalkResult::interrupt() + : mlir::WalkResult::advance(); }); return result.wasInterrupted() ? mlir::failure() : mlir::success(); } diff --git a/jaxlib/mosaic/serde.h b/jaxlib/mosaic/serde.h index 762d9e5dad73..af7f936f9055 100644 --- a/jaxlib/mosaic/serde.h +++ b/jaxlib/mosaic/serde.h @@ -18,11 +18,11 @@ limitations under the License. #include -#include "llvm/include/llvm/ADT/StringMap.h" -#include "llvm/include/llvm/ADT/StringRef.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Support/LLVM.h" namespace jaxlib::mosaic { @@ -38,18 +38,20 @@ struct SerdeOptions { // // The first argument is the operation to upgrade/downgrade. // The second argument is the target version. +// The third argument is a boolean that the serde rule will set to true if it +// happens to erase the operation. // // The function should return success if the upgrade/downgrade was successful, // or an error otherwise. using SerdeRuleType = - std::function<::mlir::LogicalResult(::mlir::Operation *, int)>; + std::function<::mlir::LogicalResult(::mlir::Operation *, int, bool &)>; // Run serialization or deserialization on the given module. ::mlir::LogicalResult RunSerde( ::mlir::ModuleOp module, - const llvm::StringMap &upgrade_rules, - const llvm::StringMap &downgrade_rules, bool serialize, - SerdeOptions options); + const llvm::StringMap& upgrade_rules, + const llvm::StringMap& downgrade_rules, bool serialize, + SerdeOptions options, bool keep_version_attr = false); } // namespace jaxlib::mosaic diff --git a/jaxlib/nb_class_ptr.h b/jaxlib/nb_class_ptr.h new file mode 100644 index 000000000000..93e3cd9c0258 --- /dev/null +++ b/jaxlib/nb_class_ptr.h @@ -0,0 +1,68 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_NB_CLASS_PTR_H_ +#define JAXLIB_NB_CLASS_PTR_H_ + +#include + +#include "nanobind/nanobind.h" + +namespace jax { + +// A reference-counting smart pointer to a nanobind-wrapped class on the Python +// heap. Type T must be a class known to nanobind via a nanobind::class_ +// declaration. nb_class_ptr is useful for managing C++ classes that may be +// allocated inline in Python objects on the Python heap. +template +class nb_class_ptr : public nanobind::object { + public: + static constexpr auto Name = nanobind::detail::make_caster::Name; + + inline nb_class_ptr() : nanobind::object() {} + inline nb_class_ptr(nanobind::handle h, ::nanobind::detail::borrow_t) + : nanobind::object(h, ::nanobind::detail::borrow_t{}) {} + inline nb_class_ptr(nanobind::handle h, ::nanobind::detail::steal_t) + : nanobind::object(h, ::nanobind::detail::steal_t{}) {} + inline static bool check_(nanobind::handle h) { + nanobind::handle type = nanobind::type(); + return nanobind::isinstance(h, type); + }; + + template >> + inline nb_class_ptr(nb_class_ptr&& other) + : nanobind::object(other.release(), ::nanobind::detail::steal_t{}) {} + + T* operator->() const { return nanobind::inst_ptr(ptr()); } + T& operator*() const { return *nanobind::inst_ptr(ptr()); } + T* get() const { return ptr() ? nanobind::inst_ptr(ptr()) : nullptr; } +}; + +// This function is analogous to std::make_unique(...), but instead it +// allocates the object on the Python heap +template +nb_class_ptr make_nb_class(Args&&... args) { + nanobind::handle type = nanobind::type(); + nanobind::object instance = nanobind::inst_alloc(type); + T* ptr = nanobind::inst_ptr(instance); + new (ptr) T(std::forward(args)...); + nanobind::inst_mark_ready(instance); + return nb_class_ptr(instance.release(), ::nanobind::detail::steal_t{}); +} + +} // namespace jax + +#endif // JAXLIB_NB_CLASS_PTR_H_ diff --git a/jaxlib/partition_spec.cc b/jaxlib/partition_spec.cc new file mode 100644 index 000000000000..4a687dac35bd --- /dev/null +++ b/jaxlib/partition_spec.cc @@ -0,0 +1,239 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/partition_spec.h" + +#include +#include +#include + +#include "absl/base/casts.h" +#include "absl/hash/hash.h" +#include "absl/strings/str_format.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep + +namespace nb = nanobind; + +namespace jax { + +namespace { + +bool IsTrue(nb::handle x) { + int ret = PyObject_IsTrue(x.ptr()); + if (ret == -1) { + throw nb::python_error(); + } + return static_cast(ret); +} + +nb::object CanonicalizePartition(nb::object unconstrained_singleton, + nb::object partition) { + if (!IsTrue(partition)) { + return nb::none(); + } + if (partition.is(unconstrained_singleton)) { + return unconstrained_singleton; + } + bool is_tuple = nb::isinstance(partition); + if (is_tuple || nb::isinstance(partition)) { + for (nb::handle p : partition) { + if (nb::isinstance(p) || nb::isinstance(p)) { + throw nb::value_error( + absl::StrFormat( + "A tuple inside PartitionSpec cannot contain a " + "nested tuple. Got partition: %s and the nested tuple: %s", + nb::cast(nb::str(partition)), + nb::cast(nb::str(p))) + .c_str()); + } + } + if (nb::len(partition) == 1) { + return partition[0]; + } + if (!is_tuple) { + return nb::tuple(partition); + } + return partition; + } + return partition; +} + +void CheckPartitionSpec(nb::tuple partitions, nb::frozenset unreduced, + nb::frozenset reduced) { + if (unreduced.contains(nb::none())) { + throw nb::value_error( + "unreduced cannot contain None. All elements in unreduced should " + "refer to the mesh axes."); + } + if (reduced.contains(nb::none())) { + throw nb::value_error( + "reduced cannot contain None. All elements in reduced should " + "refer to the mesh axes."); + } + auto check_overlap = [&](nb::handle partition) { + if (unreduced.contains(partition)) { + throw nb::value_error( + absl::StrFormat( + "partitions cannot overlap with unreduced axes passed to " + "PartitionSpec. Got partitions: %s and unreduced axes: %s", + nb::cast(nb::str(partitions)), + nb::cast(nb::str(unreduced))) + .c_str()); + } + if (reduced.contains(partition)) { + throw nb::value_error( + absl::StrFormat( + "partitions cannot overlap with reduced axes passed to " + "PartitionSpec. Got partitions: %s and reduced axes: %s", + nb::cast(nb::str(partitions)), + nb::cast(nb::str(reduced))) + .c_str()); + } + }; + for (nb::handle partition : partitions) { + if (nb::isinstance(partition)) { + for (nb::handle p : partition) { + check_overlap(p); + } + } else { + check_overlap(partition); + } + } + if (nb::len((unreduced & reduced)) != 0) { + throw nb::value_error( + absl::StrFormat("`unreduced` and `reduced` argument to PartitionSpec " + "cannot overlap. " + "Got unreduced: %s and reduced: %s", + nb::cast(nb::str(unreduced)), + nb::cast(nb::str(reduced))) + .c_str()); + } +} + +} // namespace + +PartitionSpec::PartitionSpec(nb::tuple partitions, nb::frozenset unreduced, + nb::frozenset reduced) + : partitions_(std::move(partitions)), + unreduced_(std::move(unreduced)), + reduced_(std::move(reduced)) {} + +Py_hash_t PartitionSpec::Hash() const { + size_t h = absl::HashOf(nb::hash(partitions_), nb::hash(unreduced_), + nb::hash(reduced_)); + Py_hash_t s = absl::bit_cast(h); // Python hashes are signed. + return s == -1 ? -2 : s; // -1 must not be used as a Python hash value. +} + +bool PartitionSpec::operator==(const PartitionSpec& other) const { + return partitions().equal(other.partitions()) && + unreduced().equal(other.unreduced()) && + reduced().equal(other.reduced()); +} + +bool PartitionSpec::Eq(const nb::object& other) const { + if (!other.ptr() || other.is_none()) { + return false; + } + PartitionSpec* other_spec; + if (nb::try_cast(other, other_spec)) { + return *this == *other_spec; + } + nb::tuple other_tuple; + if (nb::try_cast(other, other_tuple)) { + if (unreduced().size() > 0 || reduced().size() > 0 || + partitions().size() != other_tuple.size()) { + return false; + } + for (size_t i = 0; i < partitions().size(); ++i) { + if (!partitions()[i].equal(CanonicalizePartition( + *unconstrained_singleton_, other_tuple[i]))) { + return false; + } + } + return true; + } + return false; +} + +nb::object* PartitionSpec::unconstrained_singleton_ = nullptr; + +void PartitionSpec::Register(nb::module_& m) { + nb::class_(m, "UnconstrainedSingleton") + .def("__repr__", [](nb::handle self) { return nb::str("UNCONSTRAINED"); }) + .def("__reduce__", + [](nb::handle self) { return nb::str("UNCONSTRAINED_PARTITION"); }); + + unconstrained_singleton_ = new nb::object(nb::cast(UnconstrainedSingleton())); + m.attr("UNCONSTRAINED_PARTITION") = *unconstrained_singleton_; + + m.def("canonicalize_partition", [](nb::object partition) { + return CanonicalizePartition(*unconstrained_singleton_, partition); + }); + + auto partition_spec = + nb::class_(m, "PartitionSpec", + nb::sig("class PartitionSpec(typing.Any)")) + .def( + "__init__", + [](PartitionSpec* self, nb::args partition_args, + nb::object unreduced_arg, nb::object reduced_arg) { + nb::tuple partitions = + nb::steal(PyTuple_New(partition_args.size())); + for (size_t i = 0; i < partition_args.size(); ++i) { + PyTuple_SET_ITEM(partitions.ptr(), i, + CanonicalizePartition( + *PartitionSpec::unconstrained_singleton_, + partition_args[i]) + .release() + .ptr()); + } + nb::frozenset unreduced; + nb::frozenset reduced; + if (!PyAnySet_Check(unreduced_arg.ptr())) { + throw nb::type_error( + absl::StrFormat( + "unreduced argument of PartitionSpec should " + "of type `frozenset` or `set`. Got type %s", + nb::cast(nb::repr(unreduced_arg.type()))) + .c_str()); + } + if (!PyAnySet_Check(reduced_arg.ptr())) { + throw nb::type_error( + absl::StrFormat( + "reduced argument of PartitionSpec should " + "of type `frozenset` or `set`. Got type %s", + nb::cast(nb::repr(reduced_arg.type()))) + .c_str()); + } + unreduced = nb::frozenset(unreduced_arg); + reduced = nb::frozenset(reduced_arg); + CheckPartitionSpec(partitions, unreduced, reduced); + new (self) + PartitionSpec(std::move(partitions), std::move(unreduced), + std::move(reduced)); + }, + nb::arg("partitions"), nb::arg("unreduced") = nb::frozenset(), + nb::arg("reduced") = nb::frozenset()) + .def_prop_ro("_partitions", &PartitionSpec::partitions) + .def_prop_ro("unreduced", &PartitionSpec::unreduced) + .def_prop_ro("reduced", &PartitionSpec::reduced) + .def("__eq__", &PartitionSpec::Eq, nb::arg(), nb::is_operator()) + .def("__hash__", &PartitionSpec::Hash); +} + +} // namespace jax diff --git a/jaxlib/partition_spec.h b/jaxlib/partition_spec.h new file mode 100644 index 000000000000..0ca00fff814b --- /dev/null +++ b/jaxlib/partition_spec.h @@ -0,0 +1,53 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAX_JAXLIB_PARTITION_SPEC_H_ +#define JAX_JAXLIB_PARTITION_SPEC_H_ + +#include + +#include "nanobind/nanobind.h" + +namespace jax { + +struct UnconstrainedSingleton {}; + +class PartitionSpec { + public: + PartitionSpec(nanobind::tuple partitions, nanobind::frozenset unreduced, + nanobind::frozenset reduced); + + nanobind::tuple partitions() const { return partitions_; } + nanobind::frozenset unreduced() const { return unreduced_; } + nanobind::frozenset reduced() const { return reduced_; } + + bool operator==(const PartitionSpec& other) const; + + bool Eq(const nanobind::object& other) const; // Python __eq__ + Py_hash_t Hash() const; // Python __hash__ + + static void Register(nanobind::module_& m); + + private: + nanobind::tuple partitions_; + nanobind::frozenset unreduced_; + nanobind::frozenset reduced_; + + static nanobind::object* unconstrained_singleton_; +}; + +} // namespace jax + +#endif // JAX_JAXLIB_PARTITION_SPEC_H_ diff --git a/jaxlib/pathways.cc b/jaxlib/pathways.cc new file mode 100644 index 000000000000..5db3b6cce0db --- /dev/null +++ b/jaxlib/pathways.cc @@ -0,0 +1,377 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_array.h" +#include "jaxlib/py_client.h" +#include "jaxlib/py_user_context.h" +#include "jaxlib/to_ifrt_sharding.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/array_spec.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/remap_plan.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/ifrt/user_context.h" +#include "xla/python/nb_absl_span.h" // IWYU pragma: keep +#include "xla/python/types.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" + +namespace jax { + +namespace { + +// Returns strides for the given `axis_sizes`. +absl::StatusOr> GetStrides(absl::Span axis_sizes) { + if (axis_sizes.empty()) { + return absl::InvalidArgumentError("`axis_sizes` must not be empty"); + } + std::vector strides; + strides.reserve(axis_sizes.size()); + strides.push_back(1); + for (int i = axis_sizes.size() - 1; i > 0; --i) { + strides.push_back(axis_sizes[i] * strides.back()); + } + std::reverse(strides.begin(), strides.end()); + return strides; +} + +// Populates `offsets` with the offsets to use to create continuous intervals +// for `RemapPlan`. +// `axis_sizes` represents the mesh axis sizes up to the mesh axis on which +// the arrays are split/concatenated. +// `current_entry` iterates over the mesh axis sizes to generate the offsets. +// `strides` are the strides of the mesh axes. +absl::Status PopulateSubmeshOffsets(absl::Span axis_sizes, + absl::Span current_entry, + absl::Span strides, + std::vector& offsets) { + int offset = 0; + for (int idx = 0; idx < axis_sizes.size(); ++idx) { + offset += strides[idx] * current_entry[idx]; + } + offsets.push_back(offset); + current_entry[current_entry.size() - 1] += 1; + for (int idx = current_entry.size() - 1; + idx > 0 && current_entry[idx] >= axis_sizes[idx]; --idx) { + current_entry[idx] = 0; + current_entry[idx - 1] += 1; + } + if (current_entry[0] < axis_sizes[0]) { + return PopulateSubmeshOffsets(axis_sizes, current_entry, strides, offsets); + } else { + return absl::OkStatus(); + } +} + +// If `backend` is nullptr, sets it to `array.py_client()`; otherwise checks +// that `backend` equals `array.py_client()`. +absl::Status PyClientFromPyArray(const PyArray& array, + nb_class_ptr& backend) { + if (array.py_client().get() == nullptr) { + return absl::InternalError("Unexpected array with py_client as nullptr."); + } + if (backend.get() == nullptr) { + backend = array.py_client(); + } else if (backend.get() != array.py_client().get()) { + std::string old_description = + absl::StrFormat("%p/%s/%s/%s/%s", backend.get(), + backend->platform_name(), backend->platform_version(), + backend->runtime_type(), backend->raw_platform_name()); + std::string new_description = + absl::StrFormat("%p/%s/%s/%s/%s", array.py_client().get(), + array.py_client()->platform_name(), + array.py_client()->platform_version(), + array.py_client()->runtime_type(), + array.py_client()->raw_platform_name()); + return absl::InvalidArgumentError(absl::StrCat( + "py_client mismatch: ", old_description, " vs ", new_description)); + } + return absl::OkStatus(); +} + +namespace nb = ::nanobind; + +// Runs `xla::ifrt::Client::ReshardArrays`. +absl::StatusOr ExperimentalReshardArrays(nb::sequence py_arrays, + nb::sequence out_shardings, + bool donate_input) { + const int num_arrays = nb::len(py_arrays); + + if (nb::len(out_shardings) != num_arrays) { + return absl::InvalidArgumentError( + absl::StrCat("Number of out_shardings must match number of arrays: ", + nb::len(out_shardings), num_arrays)); + } + + if (num_arrays == 0) { + return nb::list(); + } + + PyUserContextScope user_context_scope; + nb_class_ptr backend; + std::vector ifrt_arrays; + std::vector ifrt_specs; + ifrt_arrays.reserve(num_arrays); + ifrt_specs.reserve(num_arrays); + + for (int i = 0; i < num_arrays; ++i) { + PyArray array = nb::cast(py_arrays[i]); + if (array.ifrt_array() == nullptr) { + return absl::InvalidArgumentError( + absl::StrCat("Input array ", i, " has been donated or deleted.")); + } + TF_RETURN_IF_ERROR(PyClientFromPyArray(array, backend)); + ifrt_arrays.push_back(tsl::FormRef(array.ifrt_array())); + + TF_ASSIGN_OR_RETURN(xla::ifrt::DType ifrt_dtype, + xla::DtypeToIfRtDType(array.dtype())); + xla::ifrt::Shape ifrt_shape(array.shape()); + TF_ASSIGN_OR_RETURN(xla::ifrt::ShardingRef ifrt_sharding, + GetIfrtHloSharding(out_shardings[i], ifrt_shape)); + ifrt_specs.push_back(xla::ifrt::ArraySpec{ + /*dtype=*/std::move(ifrt_dtype), + /*shape=*/std::move(ifrt_shape), + /*sharding=*/std::move(ifrt_sharding), + }); + } + + const xla::ifrt::ArrayCopySemantics copy_semantics = + donate_input ? xla::ifrt::ArrayCopySemantics::kDonateInput + : xla::ifrt::ArrayCopySemantics::kAlwaysCopy; + + std::vector outputs; + { + nb::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN( + outputs, backend->ifrt_client()->ReshardArrays( + absl::MakeSpan(ifrt_arrays), ifrt_specs, copy_semantics)); + } + + nb::list result; + for (int i = 0; i < num_arrays; ++i) { + PyArray new_py_array = PyArray::MakeFromIfrtArrayAndSharding( + backend, std::move(outputs[i]), out_shardings[i], + /*weak_type=*/false, + /*committed=*/true, + /*skip_checks=*/true); + result.append(std::move(new_py_array)); + } + + return result; +} + +absl::StatusOr>> +ExperimentalSplitByMeshAxis( + nb::object py_arrays_py, absl::Span sharded_dim_idxs, + absl::Span mesh_axis_sizes, int mesh_axis_idx, + absl::Span mesh_axis_sections, + absl::Span> submesh_shardings, bool donate) { + // Using `nb_class_ptr` requires GIL. + DCHECK(PyGILState_Check()); + + auto py_arrays = nb::cast>(py_arrays_py); + if (py_arrays.empty()) { + return std::vector>(); + } + + if (sharded_dim_idxs.size() != py_arrays.size()) { + return absl::InvalidArgumentError( + absl::StrCat("Number of sharded_dim_idxs must match number of arrays: ", + sharded_dim_idxs.size(), " vs ", py_arrays.size())); + } + if (submesh_shardings.size() != py_arrays.size()) { + return absl::InvalidArgumentError(absl::StrCat( + "Number of submesh_shardings must match number of arrays: ", + submesh_shardings.size(), " vs ", py_arrays.size())); + } + + int num_submeshes = submesh_shardings[0].size(); + if (mesh_axis_sections.size() != num_submeshes) { + return absl::InvalidArgumentError(absl::StrCat( + "Number of mesh_axis_sections must match number of submeshes: ", + mesh_axis_sections.size(), " vs ", num_submeshes)); + } + + PyUserContextScope user_context_scope; + // All input arrays are expected to use the same mesh. + TF_ASSIGN_OR_RETURN(xla::ifrt::DeviceListRef device_list, + GetIfrtDeviceList(py_arrays[0].sharding())); + int num_devices = device_list->size(); + // The last entry in `mexh_axis_sections` contains the mesh axis size. + int mesh_axis_size = mesh_axis_sections.back(); + if (num_devices % mesh_axis_size != 0) { + return absl::InvalidArgumentError(absl::StrCat( + "Number of devices must be divisible by the mesh axis size: ", + num_devices, " vs ", mesh_axis_size)); + } + + xla::ifrt::RemapPlan remap_plan; + remap_plan.mappings = + std::make_shared>(); + auto& mappings = *remap_plan.mappings; + + TF_ASSIGN_OR_RETURN(std::vector strides, GetStrides(mesh_axis_sizes)); + std::vector submesh_offsets; + if (mesh_axis_idx == 0) { + submesh_offsets.push_back(0); + } else { + std::vector current_entry(mesh_axis_idx, 0); + TF_RETURN_IF_ERROR(PopulateSubmeshOffsets( + mesh_axis_sizes.subspan(0, mesh_axis_idx), + absl::MakeSpan(current_entry), strides, submesh_offsets)); + } + + nb_class_ptr backend; + std::vector input_ifrt_arrays; + input_ifrt_arrays.reserve(py_arrays.size()); + for (int array_idx = 0; array_idx < py_arrays.size(); ++array_idx) { + TF_RETURN_IF_ERROR(PyClientFromPyArray(py_arrays[array_idx], backend)); + xla::ifrt::Array* array = py_arrays[array_idx].ifrt_array(); + if (array == nullptr) { + return xla::InvalidArgument("Input array #%d has been donated or deleted", + array_idx); + } + + remap_plan.input_specs.push_back( + xla::ifrt::ArraySpec{/*dtype=*/array->dtype(), + /*shape=*/array->shape(), + /*sharding=*/array->shared_ptr_sharding()}); + + for (int submesh_idx = 0; submesh_idx < num_submeshes; ++submesh_idx) { + auto& mapping = mappings.emplace_back(); + mapping.in_array = array_idx; + mapping.out_array = remap_plan.output_specs.size(); + int submesh_axis_size = mesh_axis_sections[submesh_idx]; + int submesh_axis_start = 0; + if (submesh_idx > 0) { + submesh_axis_size -= mesh_axis_sections[submesh_idx - 1]; + submesh_axis_start = mesh_axis_sections[submesh_idx - 1]; + } + int offset_to_array = 0; + for (const auto& submesh_offset : submesh_offsets) { + int num_contiguous_shards = submesh_axis_size * strides[mesh_axis_idx]; + int offset_from_array = + submesh_offset + submesh_axis_start * strides[mesh_axis_idx]; + mapping.from.push_back(xla::ifrt::RemapPlan::Interval{ + offset_from_array, offset_from_array + num_contiguous_shards, 1}); + mapping.to.push_back(xla::ifrt::RemapPlan::Interval{ + offset_to_array, offset_to_array + num_contiguous_shards, 1}); + offset_to_array += num_contiguous_shards; + } + if (sharded_dim_idxs[array_idx] >= 0) { + std::vector dims(array->shape().dims().begin(), + array->shape().dims().end()); + dims[sharded_dim_idxs[array_idx]] = dims[sharded_dim_idxs[array_idx]] / + mesh_axis_size * submesh_axis_size; + xla::ifrt::Shape subshape = xla::ifrt::Shape(dims); + TF_ASSIGN_OR_RETURN( + auto ifrt_submesh_sharding, + GetIfrtHloSharding(submesh_shardings[array_idx][submesh_idx], + subshape)); + remap_plan.output_specs.push_back(xla::ifrt::ArraySpec{ + /*dtype=*/array->dtype(), + /*shape=*/std::move(subshape), + /*sharding=*/std::move(ifrt_submesh_sharding)}); + } else { + // The arrays is replicated, so its shape does not change. + TF_ASSIGN_OR_RETURN( + auto ifrt_submesh_sharding, + GetIfrtHloSharding(submesh_shardings[array_idx][submesh_idx], + array->shape())); + remap_plan.output_specs.push_back(xla::ifrt::ArraySpec{ + /*dtype=*/array->dtype(), + /*shape=*/array->shape(), + /*sharding=*/std::move(ifrt_submesh_sharding)}); + } + } + + input_ifrt_arrays.push_back(FormRef(array)); + } + + DCHECK_OK(remap_plan.Validate()); + + std::vector result_ifrt_arrays; + { + nb::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN( + result_ifrt_arrays, + backend->ifrt_client()->RemapArrays( + remap_plan, absl::MakeSpan(input_ifrt_arrays), + donate ? xla::ifrt::ArrayCopySemantics::kDonateInput + : xla::ifrt::ArrayCopySemantics::kReuseInput)); + } + + DCHECK_EQ(result_ifrt_arrays.size(), py_arrays.size() * num_submeshes); + + // Wrap IFRT arrays as JAX arrays. + std::vector> py_results; + int offset_in_results = 0; + for (int array_idx = 0; array_idx < py_arrays.size(); ++array_idx) { + auto& py_submesh_results = py_results.emplace_back(); + for (int submesh_idx = 0; submesh_idx < num_submeshes; ++submesh_idx) { + PyArray new_py_array = PyArray::MakeFromIfrtArrayAndSharding( + backend, + std::move(result_ifrt_arrays[offset_in_results + submesh_idx]), + submesh_shardings[array_idx][submesh_idx], + py_arrays[array_idx].weak_type(), + /*committed=*/true, + /*skip_checks=*/true); + py_submesh_results.push_back(new_py_array); + } + offset_in_results += num_submeshes; + } + + return py_results; +} + +} // namespace + +NB_MODULE(_pathways, m) { + m.def("_transfer_to_shardings", + xla::ValueOrThrowWrapper(ExperimentalReshardArrays), nb::arg("arrays"), + nb::arg("out_shardings"), nb::arg("donate") = false); + m.def("_split_by_mesh_axis", + xla::ValueOrThrowWrapper(ExperimentalSplitByMeshAxis), + nb::arg("arrays"), nb::arg("sharded_dim_idxs"), + nb::arg("mesh_axis_sizes"), nb::arg("mesh_axis_idx"), + nb::arg("mesh_axis_sections"), nb::arg("submesh_shardings"), + nb::arg("donate")); +} + +} // namespace jax diff --git a/jaxlib/pjit.cc b/jaxlib/pjit.cc new file mode 100644 index 000000000000..899a1e51f084 --- /dev/null +++ b/jaxlib/pjit.cc @@ -0,0 +1,1465 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/pjit.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // NOLINT +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/cleanup/cleanup.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/synchronization/mutex.h" +#include "absl/synchronization/notification.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/call_location.h" +#include "jaxlib/config.h" +#include "jaxlib/guard_lib.h" +#include "jaxlib/jax_jit.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_array.h" +#include "jaxlib/py_client.h" +#include "jaxlib/py_executable.h" +#include "jaxlib/py_user_context.h" +#include "jaxlib/py_values.h" +#include "jaxlib/python_ref_manager.h" +#include "jaxlib/pytree.h" +#include "jaxlib/sharding.h" +#include "jaxlib/to_ifrt_sharding.h" +#include "xla/layout.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/lru_cache.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/ifrt/user_context.h" +#include "xla/python/nb_helpers.h" +#include "xla/python/nb_numpy.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" +#include "tsl/profiler/lib/traceme.h" + +namespace jax { +namespace { + +namespace nb = nanobind; + +struct PjitCacheEntry { + explicit PjitCacheEntry(PyTreeRegistry* registry) + : out_pytree_def(registry) {} + std::shared_ptr executable; + std::vector in_shardings; + std::vector out_avals; + std::vector out_dtypes; + std::vector> out_shapes; + std::vector out_weak_types; + std::vector out_shardings; + std::vector out_committed; + PyTreeDef out_pytree_def; + // Bitvector of kept arguments from Jaxpr DCE pass. Used to drop some `args` + // in PjitFunction::Call before calling into compiled computation. + std::vector kept_var_bitvec; + std::vector in_device_local_layouts; + std::vector const_args; + + // Ensures a single thread performs the compilation for a given executable. + // + // The first thread (holding the GIL) will create the CacheEntry associated to + // a signature and if the object has been inserted already, other threads + // will wait for the notification. + absl::Notification compilation_complete; + + std::thread::id thread_id = std::this_thread::get_id(); + + bool fall_back_to_python = false; +}; + +// A PjitFunctionCache represents a cache of compiled functions that can be +// shared between one or more PjitFunction objects. It serves two goals: +// - reduce the number of lru caches (hash map) across multiple JITs. +// - make the cache global to increase cache hits (e.g. calling jit(f)(3) twice) +// keeping entries alive as long as the underlying function f is alive. +// Assume the cache is protected by the GIL. +class PjitFunctionCache { + public: + static constexpr int kDefaultCapacity = 4096; + explicit PjitFunctionCache(int capacity); + + // Cache entries are shared_ptr<>s because it's possible the cache entry + // might be evicted before we finish tracing/compiling. + typedef xla::LRUCache, + CallSignature::Hash> + Cache; + + // We include as part of the cache key `global_cache_key` (and any other + // fields that aren't subsumed by the CallSignature we compute for each call). + static std::shared_ptr Lookup(nb_class_ptr self, + nb::handle function, + nb::object global_cache_key); + std::shared_ptr DefaultCache(); + + // These methods require the GIL or the object's lock in no-GIL mode. + int Size() const { return lru_list_.Size(); } + int Capacity() const { return lru_list_.Capacity(); } + void Clear() { + lru_list_.Clear(); + functions_.clear(); + } + + private: + struct Key { + nb::handle function; // Does not hold a reference. + + // Other fields that are part of the arguments to `jit`, but are not + // otherwise part of CallSignature. + nb::object global_cache_key; + + size_t cached_hash; + + bool operator==(const Key& other) const { + bool global_cache_eq; + try { + global_cache_eq = global_cache_key.equal(other.global_cache_key); + } catch (const nanobind::python_error& e) { + throw std::invalid_argument( + absl::StrCat("Equality of global cache key lead to an exception. " + "The error was:\n", + e.what(), "\n")); + } + return function.ptr() == other.function.ptr() && global_cache_eq; + } + + struct Hash { + size_t operator()(const Key& key) const { return key.cached_hash; } + }; + }; + + template + friend H AbslHashValue(H h, const Key& key) { + h = H::combine(std::move(h), key.function.ptr()); + Py_hash_t hash; + try { + hash = nb::hash(key.global_cache_key); + } catch (const nanobind::python_error& e) { + if (!e.matches(PyExc_TypeError)) throw; + throw std::invalid_argument(absl::StrCat( + "Hashing global cache key lead to an exception. The error was:\n", + e.what(), "\n")); + } + h = H::combine(std::move(h), hash); + return h; + } + + struct Value { + explicit Value(std::shared_ptr cache) : cache(std::move(cache)) {} + std::shared_ptr cache; + + // A weak reference to the key function. We use the weak reference to + // register a callback that is triggered when the key function is destroyed. + // We use a weak pointer because we want to allow caching across multiple + // calls to `pjit(f)` if `f` remains alive, but we do not want the cache + // to keep `f` alive if all other references are dropped. + std::optional weakref; + }; + + // lru_list_ and functions_ are protected by the GIL in GIL mode, and by the + // self object lock in freethreading mode. + Cache::LRUList lru_list_; + // We use std::unordered_map because ABSL containers are not exception safe: + std::unordered_map, Key::Hash> functions_; + // mu_ prevents concurrent insertions into functions_ if the gil or critical + // section lock is released during insertion. + absl::Mutex mu_; +}; + +PjitFunctionCache::PjitFunctionCache(int capacity) : lru_list_(capacity) {} + +std::shared_ptr PjitFunctionCache::DefaultCache() { + return std::make_shared(&lru_list_); +} + +/*static*/ std::shared_ptr PjitFunctionCache::Lookup( + nb_class_ptr self, nb::handle function, + nb::object global_cache_key) ABSL_NO_THREAD_SAFETY_ANALYSIS { + // In no-GIL mode, a critical section on self plays the same role that + // the GIL plays in GIL mode. + nb::ft_object_guard lock(self); + { + // Because the gil (or the critical section lock) can be released during + // cache insertion, this forces the lock order to be mu_ then gil so we + // must release the gil first. + nb::gil_scoped_release release; + // Acquire a mutex to avoid problems where the gil is released during + // cache insertion and then a second thread invalidates the cache order. + self->mu_.lock(); + } + absl::Cleanup unlock = [&self]() ABSL_UNLOCK_FUNCTION(self->mu_) { + self->mu_.unlock(); + }; + Key key; + key.function = function; + key.global_cache_key = global_cache_key; + key.cached_hash = absl::HashOf(key); + auto insert = self->functions_.emplace(key, nullptr); + if (!insert.second) { + return insert.first->second->cache; + } + std::shared_ptr cache = std::make_shared(&self->lru_list_); + auto callback = + nb::cpp_function([self, key{std::move(key)}](nb::handle weakref) { + nb::ft_object_guard lock(self); + auto it = self->functions_.find(key); + if (it == self->functions_.end()) { + return; + } + // Remove the value from the map before destroying it. Destroying + // the value may release `lock` since it may call arbitrary Python + // code. + std::unique_ptr value = std::move(it->second); + self->functions_.erase(it); + value.reset(); + }); + PyObject* weakref = PyWeakref_NewRef(function.ptr(), callback.ptr()); + if (weakref) { + std::unique_ptr& entry = insert.first->second; + entry = std::make_unique(cache); + entry->weakref = nb::steal(weakref); + } else { + PyErr_Clear(); + // `function` is not weak-referenceable. Don't bother adding it to the + // shared cache in that case; the `jit` object will hold the only shared + // reference to the cache entry. + self->functions_.erase(insert.first); + } + return cache; +} + +class PjitFunction { + public: + PjitFunction(std::string function_name, std::optional fun, + nb::callable cache_miss, std::vector static_argnums, + std::vector static_argnames, + nb::object global_cache_key, + nb_class_ptr pytree_registry, + nb::callable shard_arg_fallback, + nb_class_ptr cache); + ~PjitFunction(); + + PjitFunction(const PjitFunction&) = delete; + PjitFunction& operator=(const PjitFunction&) = delete; + PjitFunction(PjitFunction&&) = default; + PjitFunction& operator=(PjitFunction&&) = default; + + // nb::object typed subclass for PjitFunction objects. + class pyobject : public nb::object { + public: + NB_OBJECT(pyobject, nb::object, "PjitFunction", + PjitFunction::IsPjitFunction); + pyobject() = default; + PjitFunction* func() const { + return PjitFunction::AsPjitFunctionUnchecked(*this); + } + }; + // Alias as ::object; outside the scope above we won't confuse nanobind's + // macros. + using object = pyobject; + + // Returns true if `h` is a PjitFunction. + static bool IsPjitFunction(nb::handle handle); + // Converts `handle` to a PjitFunction*. Does not do any checking. + static PjitFunction* AsPjitFunctionUnchecked(nb::handle handle); + + absl::StatusOr Call(nb::handle callable, PyObject* const* args, + size_t nargs, PyObject* kwnames); + + void InitExecutables(); + + void ClearPythonReferences(); + + const std::string& function_name() const { return function_name_; } + const std::optional& fun() const { return fun_; } + const nb::callable& cache_miss() const { return cache_miss_; } + const nb_class_ptr& pytree_registry() const { + return pytree_registry_; + } + const nb::callable& shard_arg_fallback() const { return shard_arg_fallback_; } + + const std::vector& static_argnums() const { return static_argnums_; } + const std::vector& static_argnames() const { + return static_argnames_; + } + const nb::object& global_cache_key() const { return global_cache_key_; } + const nb_class_ptr& cache() const { return cache_; } + + int cache_capacity() const { + nb::ft_object_guard lock(cache_); + return executables_->Size(); + } + + void ClearCache() { + nb::ft_object_guard lock(cache_); + executables_->Clear(); + } + + std::shared_ptr executables() { + nb::ft_object_guard lock(cache_); + return executables_; + } + + nb::object PythonSignature() { + if (!fun_.has_value()) { + throw nb::value_error( + absl::StrFormat( + "Calling __signature__ on PjitFunction(%s) not supported.", + function_name_) + .c_str()); + } + static const auto* inspect = + new nb::module_(nb::module_::import_("inspect")); + return inspect->attr("signature")(*fun_); + } + + private: + absl::Status ComputeCallSignature( + absl::Span flat_dynamic_args, bool enable_x64, + CallSignature& call_signature); + + void PopulateCacheEntry(PjitCacheEntry& cache_entry, + const nb::tuple& out_and_fastpath_data); + + std::string function_name_; + std::optional fun_; + nb::callable cache_miss_; + std::vector static_argnums_; + std::vector static_argnames_; + nb::object global_cache_key_; + + nb_class_ptr pytree_registry_; + nb::callable shard_arg_fallback_; + nb_class_ptr cache_; + + // In no-GIL mode executables_ is protected by the object lock on cache_, + // because it shared an LRU list with cache_. + std::shared_ptr executables_; +}; + +PjitFunction::PjitFunction( + std::string function_name, std::optional fun, + nb::callable cache_miss, std::vector static_argnums, + std::vector static_argnames, nb::object global_cache_key, + nb_class_ptr pytree_registry, + nb::callable shard_arg_fallback, nb_class_ptr cache) + : function_name_(std::move(function_name)), + fun_(std::move(fun)), + cache_miss_(std::move(cache_miss)), + static_argnums_(std::move(static_argnums)), + global_cache_key_(std::move(global_cache_key)), + pytree_registry_(std::move(pytree_registry)), + shard_arg_fallback_(std::move(shard_arg_fallback)), + cache_(std::move(cache)) { + std::sort(static_argnums_.begin(), static_argnums_.end()); + static_argnames_.reserve(static_argnames.size()); + for (nb::str& name : static_argnames) { + PyObject* s = name.inc_ref().ptr(); + PyUnicode_InternInPlace(&s); + static_argnames_.push_back(nb::steal(s)); + } +} + +void PjitFunction::InitExecutables() { + // Construction of the object hasn't completed yet, so we don't need to hold + // the cache lock to mutate executables_. + if (!fun_.has_value()) { + executables_ = cache_->DefaultCache(); + } else { + executables_ = cache_->Lookup(cache_, fun_.value(), global_cache_key_); + } +} + +PjitFunction::~PjitFunction() { + nb::ft_object_guard lock(cache_); + executables_ = nullptr; +} + +void CallShardArgFallback(nb::handle arg, nb::handle sharding, + nb::handle layout, const nb::callable& fallback, + std::vector& num_args_arrays, + std::vector& keep_alive_objects) { + tsl::profiler::TraceMe traceme("cpp_pjit_shard_arg_fallback"); + auto py_array_or_bufs = fallback(arg, sharding, layout); + auto py_array = nb::cast(py_array_or_bufs); + num_args_arrays.push_back(tsl::FormRef(py_array.ifrt_array())); + keep_alive_objects.push_back(std::move(py_array_or_bufs)); +} + +// Prepares the input PjRtBuffers from the python arguments. This is equivalent +// to shard_args() in pxla.py but for only a few supported cases. +absl::StatusOr> PrepareIfrtInputs( + const PyLoadedExecutable& executable, + absl::Span flat_dynamic_args, + absl::Span flat_dynamic_arg_signatures, + bool enable_x64, const std::vector& kept_args, + const std::vector& in_shardings, + const std::vector& in_device_local_layouts, + const nb::callable& shard_arg_fallback, + std::vector& keep_alive_objects) { + const auto& addressable_devices = + executable.ifrt_loaded_executable()->addressable_devices(); + const auto& num_global_devices = + executable.ifrt_loaded_executable()->num_devices(); + int num_args = flat_dynamic_args.size(); + + std::vector num_args_arrays; + num_args_arrays.reserve(num_args); + + struct CopyGroup { + std::vector indices; + std::vector arrays; + }; + absl::flat_hash_map, + CopyGroup> + copy_groups; + + DevicePutOptions options; + options.squash_64bit_types = !enable_x64; + options.allow_zero_copy = true; + xla::ifrt::Device* data_device = nullptr; + if (executable.ifrt_loaded_executable()->num_devices() == 1 && + !addressable_devices.empty()) { + data_device = addressable_devices[0]; + } + int dce_i = 0; + for (int i = 0; i < num_args; ++i) { + if (!kept_args[i]) { + continue; + } + int dce_index = dce_i; + ++dce_i; + + const nb::object& arg = flat_dynamic_args[i]; + const nb::object& in_device_local_layout = + in_device_local_layouts[dce_index]; + + auto transfer_guard_formatter = [] { return std::string(""); }; + + if (arg.type().ptr() != PyArray::type().ptr()) { + if (data_device != nullptr && in_device_local_layout.is_none()) { + TF_RETURN_IF_ERROR( + ApplyTransferGuardToHostToDevice(transfer_guard_formatter)); + TF_ASSIGN_OR_RETURN( + auto device_put_result, + DevicePutWithDevice(arg, + executable.ifrt_loaded_executable()->client(), + data_device, xla::ifrt::MemoryKind(), options)); + num_args_arrays.push_back(std::move(device_put_result.ifrt_array)); + continue; + } else { + CallShardArgFallback(arg, in_shardings[dce_index], + in_device_local_layout, shard_arg_fallback, + num_args_arrays, keep_alive_objects); + continue; + } + } + + PyArray py_array = nb::borrow(arg); + const auto& sharding = py_array.sharding(); + int sharding_num_devices = + nb::cast(sharding)->num_devices(); + + // Currently only committed PyArray inputs or uncommitted PyArray on a + // single device inputs are allowed. This is checked previously in the entry + // point of PjitFunction::Call(). + DCHECK(py_array.committed() || + (!py_array.committed() && sharding_num_devices == 1)); + + if (!in_device_local_layout.is_none()) { + xla::ifrt::Array* ifrt_array = py_array.ifrt_array(); + TF_ASSIGN_OR_RETURN(auto arr_layout, ifrt_array->pjrt_layout()); + if (arr_layout == nullptr) { + TF_ASSIGN_OR_RETURN( + xla::ifrt::Shape shard_shape, + ifrt_array->sharding().GetShardShape(ifrt_array->shape())); + TF_ASSIGN_OR_RETURN( + arr_layout, + executable.ifrt_loaded_executable()->client()->GetDefaultPjRtLayout( + ifrt_array->dtype(), shard_shape.dims(), + ifrt_array->sharding().devices()->devices().front(), + ifrt_array->sharding().memory_kind())); + } + xla::Layout in_xc_layout = nb::cast( + in_device_local_layout.attr("_to_xla_layout")(py_array.dtype())); + if (in_xc_layout != arr_layout->xla_layout()) { + CallShardArgFallback(arg, in_shardings[dce_index], + in_device_local_layout, shard_arg_fallback, + num_args_arrays, keep_alive_objects); + continue; + } + } + + if (sharding.type().ptr() == PmapSharding::type().ptr()) { + CallShardArgFallback(arg, in_shardings[dce_index], in_device_local_layout, + shard_arg_fallback, num_args_arrays, + keep_alive_objects); + continue; + } + + if (sharding_num_devices != num_global_devices) { + CallShardArgFallback(arg, in_shardings[dce_index], in_device_local_layout, + shard_arg_fallback, num_args_arrays, + keep_alive_objects); + continue; + } + + xla::ifrt::Array* ifrt_array = py_array.ifrt_array(); + // PyArray inputs should have already been checked in + // `PyArgSignatureOfValue()` called by + // `PjitFunction::ComputeCallSignature()`. + DCHECK(ifrt_array != nullptr) << "PyArray has been unexpectedly deleted."; + + const auto& ifrt_sharding = ifrt_array->sharding(); + if (sharding_num_devices == 1 && !addressable_devices.empty() && + ifrt_sharding.devices()->devices().front() != addressable_devices[0]) { + auto& copy_group = copy_groups[std::make_tuple( + ifrt_sharding.devices()->devices().front(), + ifrt_sharding.memory_kind(), GetMemoryKind(in_shardings[dce_index]))]; + copy_group.indices.push_back(num_args_arrays.size()); + copy_group.arrays.push_back(tsl::FormRef(ifrt_array)); + num_args_arrays.push_back({}); + } else { + num_args_arrays.push_back(tsl::FormRef(ifrt_array)); + } + + keep_alive_objects.push_back(arg); + } + + if (!copy_groups.empty() && !addressable_devices.empty()) { + xla::ifrt::Client* const ifrt_client = + executable.ifrt_loaded_executable()->client(); + TF_ASSIGN_OR_RETURN(xla::ifrt::DeviceListRef ifrt_devices, + ifrt_client->MakeDeviceList({addressable_devices[0]})); + for (auto& [key, group] : copy_groups) { + TF_ASSIGN_OR_RETURN( + auto copied_ifrt_arrays, + ifrt_client->CopyArrays(absl::MakeSpan(group.arrays), ifrt_devices, + std::get<2>(key), + xla::ifrt::ArrayCopySemantics::kReuseInput)); + for (int i = 0; i < copied_ifrt_arrays.size(); ++i) { + num_args_arrays[group.indices[i]] = std::move(copied_ifrt_arrays[i]); + } + } + } + + return num_args_arrays; +} + +absl::StatusOr PjitFunction::Call(nb::handle callable, + PyObject* const* args, + size_t nargs, PyObject* kwnames) { + tsl::profiler::TraceMe traceme( + [&] { return absl::StrCat("PjitFunction(", function_name_, ")"); }); + + // Make sure we trigger a garbage collection on JIT function calls. Otherwise + // code like + // f = jit(...) + // while True: + // f(x) + // may never free temporary buffers for copies of arguments. + GlobalPyRefManager()->MaybeCollectGarbage(); + InitializeThreadLocalState(); + + if (GetDisableJit()) { + if (!fun_.has_value()) { + throw nb::value_error( + absl::StrFormat("Disable jit is not supported in the AOT path since " + "the function is not available for (%s)", + function_name_) + .c_str()); + } + return nb::steal( + PyObject_Vectorcall(fun_.value().ptr(), args, nargs, kwnames)); + } + + // Calls the cache_miss_ function. This just calls the Python function; it may + // return nullptr value if a Python exception is thrown. + auto cache_miss = [&]() -> nb::tuple { + return nb::steal( + PyObject_Vectorcall(cache_miss_.ptr(), args, nargs, kwnames)); + }; + + // Call the cache_miss() function, extracting the output data and ignoring + // the fastpath data. If the cache miss returns a Python error, returns + // nullptr and leaves the Python error set. + auto fallback_to_cache_miss = [&]() { + nb::tuple cache_miss_output = cache_miss(); + if (!cache_miss_output.ptr()) { + return nb::object(); + } + return nb::object(cache_miss_output[0]); + }; + + size_t num_positional_args = PyVectorcall_NARGS(nargs); + size_t num_keyword_args = kwnames ? PyTuple_GET_SIZE(kwnames) : 0; + absl::Span positional_args(args, num_positional_args); + absl::Span keyword_args(args + num_positional_args, + num_keyword_args); + + CallSignature call_signature; + std::vector keep_alive_objects; + absl::InlinedVector flat_dynamic_args; + auto status = ParseArguments( + positional_args, keyword_args, kwnames, static_argnums_, static_argnames_, + pytree_registry_.get(), call_signature.arg_signature, flat_dynamic_args); + if (!status.ok()) { + VLOG(2) << "ParseArguments failed: " << status; + return fallback_to_cache_miss(); + } + + // Perform a few checks for the arguments. Currently we are only allowing + // committed PyArray inputs. For other cases, e.g. Tracers or ShapedArray, it + // will fallback to python. For jit, numpy arrays and scalars are also + // allowed, which we will check later. + for (const auto& arg : flat_dynamic_args) { + if (arg.type().ptr() != PyArray::type().ptr()) { + continue; + } + + PyArray py_array = nb::borrow(arg); + + // Only allow committed PyArray in cpp pjit for now as the logic on handling + // sharding for uncommitted PyArray is complicated and still under + // development. + // + // TODO(chky): Consider support uncommitted PyArray in cpp when the python + // side stabilizes. + int sharding_num_devices = + nb::cast(py_array.sharding())->num_devices(); + if (!py_array.committed() && sharding_num_devices > 1) { + VLOG(2) << "PyArray argument is not committed and number of global " + "devices is more than 1; fallback to python."; + return fallback_to_cache_miss(); + } + } + + bool enable_x64 = GetEnableX64(); + status = ComputeCallSignature(flat_dynamic_args, enable_x64, call_signature); + if (!status.ok()) { + VLOG(2) << "ComputeCallSignature failed: " << status; + return fallback_to_cache_miss(); + } + + VLOG(2) << "CallSignature:\n" << call_signature.DebugString(); + bool inserted = false; + std::shared_ptr cache_entry; + { + nb::ft_object_guard lock(cache_); + cache_entry = executables_->GetOrCreateIfAbsent( + call_signature, [this, &inserted](const CallSignature& unused) { + inserted = true; + return std::make_shared(pytree_registry_.get()); + }); + } + + if (!cache_entry->compilation_complete.HasBeenNotified()) { + // In case of several threads attempting to compile the executable, only + // the one that inserted the item will perform the compilation. + if (inserted) { + nb::object out_and_fastpath_data; + nb::tuple out_tuple; + VLOG(2) << "Cache miss for " << call_signature.DebugString(); + bool remove_cache = false; + try { + // Calls Python and may release the GIL. May also throw if + // compilation/tracing fails. + out_and_fastpath_data = cache_miss(); + if (!out_and_fastpath_data.ptr()) { + throw nb::python_error(); + } + out_tuple = nb::cast(out_and_fastpath_data); + + PopulateCacheEntry(*cache_entry, out_tuple); + + if (out_tuple.size() > 2 && out_tuple[2].is_valid()) { + remove_cache = nb::cast(out_tuple[2]); + } + } catch (const std::exception& e) { + VLOG(2) << "cache miss fail: " << e.what(); + cache_entry->fall_back_to_python = true; + cache_entry->compilation_complete.Notify(); + throw; + } + cache_entry->compilation_complete.Notify(); + + if (remove_cache) { + nb::ft_object_guard lock(cache_); + executables_->Remove(call_signature); + } + + // We have already computed the result in the miss path so we can return + // it. We are even *required* to do so if there are donated arguments, + // because any donated buffers will now be invalid. + return nb::object(out_tuple[0]); + } else { + if (cache_entry->thread_id == std::this_thread::get_id()) { + auto error_string = absl::StrCat("Recursively calling jit: ", + call_signature.DebugString()); + PyErr_SetString(PyExc_RecursionError, error_string.c_str()); + throw nb::python_error(); + } + // Release the GIL while we wait, making sure the compile thread can + // lock it. + nb::gil_scoped_release release; + cache_entry->compilation_complete.WaitForNotification(); + } + } + + if (cache_entry->fall_back_to_python) { + VLOG(2) << "cpp pjit fallback to python."; + return fallback_to_cache_miss(); + } + + absl::InlinedVector dynamic_arg_signatures; + dynamic_arg_signatures.reserve(cache_entry->const_args.size() + + flat_dynamic_args.size()); + if (!cache_entry->const_args.empty()) { + flat_dynamic_args.reserve(cache_entry->const_args.size() + + flat_dynamic_args.size()); + flat_dynamic_args.insert(flat_dynamic_args.begin(), + cache_entry->const_args.begin(), + cache_entry->const_args.end()); + + for (nb::handle const_arg : cache_entry->const_args) { + TF_ASSIGN_OR_RETURN(auto const_arg_signature, + PyArgSignatureOfValue(const_arg, enable_x64)); + dynamic_arg_signatures.push_back(std::move(const_arg_signature)); + } + } + for (const auto& arg : call_signature.dynamic_arg_signatures) { + dynamic_arg_signatures.push_back(std::move(arg)); + } + + PyUserContextScope user_context_scope; + // A vector of [num_inputs]. + auto num_args_arrays = PrepareIfrtInputs( + *cache_entry->executable, flat_dynamic_args, dynamic_arg_signatures, + enable_x64, cache_entry->kept_var_bitvec, cache_entry->in_shardings, + cache_entry->in_device_local_layouts, shard_arg_fallback_, + keep_alive_objects); + + if (!num_args_arrays.ok()) { + VLOG(2) << "Failed to prepare IFRT inputs: " << num_args_arrays.status(); + return fallback_to_cache_miss(); + } + + xla::ifrt::ExecuteOptions execute_options = + cache_entry->executable->options(); + execute_options.launch_id = cache_entry->executable->GetNextLaunchId(); + execute_options.execution_stream_id = GetExecutionStreamId(); + if (execute_options.execution_stream_id == 0) { + execute_options.execution_stream_id = + tsl::Env::Default()->GetCurrentThreadId(); + } + PopulateCallLocation(execute_options, + xla::ifrt::UserContextScope::current().get()); + + // Check if the thread guard is active and should prevent execution. + // Skipped for portable executables. + if (cache_entry->executable->ifrt_executable()->devices().has_value()) { + TF_RETURN_IF_ERROR(CheckThreadGuard( + *cache_entry->executable->ifrt_executable()->devices())); + } + + // A vector of [num_outputs]. + std::vector output_arrays; + { + nb::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN(auto result, + cache_entry->executable->ifrt_executable()->Execute( + absl::MakeSpan(*num_args_arrays), execute_options, + /*devices=*/std::nullopt)); + output_arrays = std::move(result.outputs); + } + + // Convert the ifrt::Array objects to PyArray. + int num_outputs = output_arrays.size(); + absl::InlinedVector outputs; + outputs.reserve(num_outputs); + for (int i = 0; i < num_outputs; ++i) { + // Creating the PyArray result. In addition to the IFRT arrays, the metadata + // like `aval` and `sharding` are retrieved from the cache for this + // function, which are produced by the python path in `cache_miss`. + PyArray py_array( + cache_entry->out_avals[i], cache_entry->out_weak_types[i], + cache_entry->out_dtypes[i], cache_entry->out_shapes[i], + cache_entry->out_shardings[i], cache_entry->executable->client(), + std::move(output_arrays[i]), + /*committed=*/cache_entry->out_committed.at(i), /*skip_checks=*/true); + + outputs.push_back(std::move(py_array)); + } + + nb::object out = nb::steal( + cache_entry->out_pytree_def.Unflatten(outputs).release().ptr()); + + // If there is a post-hook function, call it with the inputs and the outputs. + std::optional post_hook = GetPostHook(); + if (post_hook) { + nb::tuple args_tuple = + nb::steal(PyTuple_New(num_positional_args)); + for (size_t i = 0; i < num_positional_args; ++i) { + Py_INCREF(args[i]); + PyTuple_SET_ITEM(args_tuple.ptr(), i, args[i]); + } + nb::dict kwargs; + if (kwnames) { + for (size_t i = 0; i < num_keyword_args; ++i) { + kwargs[nb::handle(PyTuple_GET_ITEM(kwnames, i))] = + nb::borrow(args[num_positional_args + i]); + } + } + (*post_hook)(nb::handle(callable.ptr()), args_tuple, kwargs, + nb::handle(out.ptr())); + } + + return out; +} + +absl::Status PjitFunction::ComputeCallSignature( + absl::Span flat_dynamic_args, bool enable_x64, + CallSignature& signature) { + signature.function_name = function_name_; + + // Get dynamic argument signatures. + auto& dynamic_arg_signatures = signature.dynamic_arg_signatures; + dynamic_arg_signatures.reserve(flat_dynamic_args.size()); + auto& dynamic_arg_shardings = signature.dynamic_arg_shardings; + dynamic_arg_shardings.reserve(flat_dynamic_args.size()); + auto& dynamic_arg_layouts = signature.dynamic_arg_layouts; + dynamic_arg_layouts.reserve(flat_dynamic_args.size()); + + for (nb::handle arg : flat_dynamic_args) { + TF_ASSIGN_OR_RETURN(auto arg_signature, + PyArgSignatureOfValue(arg, enable_x64)); + signature.dynamic_arg_signatures.push_back(std::move(arg_signature)); + + // It should be already checked previously in the entry point of + // PjitFunction::Call(). + if (arg.type().ptr() == PyArray::type().ptr()) { + auto py_array = nb::borrow(arg); + signature.dynamic_arg_shardings.push_back(py_array.sharding()); + auto layout = py_array.layout(); + if (absl::IsUnimplemented(layout.status())) { + signature.dynamic_arg_layouts.push_back(nullptr); + } else { + signature.dynamic_arg_layouts.push_back(*std::move(layout)); + } + signature.committed_args.push_back(py_array.committed()); + } else { + signature.dynamic_arg_shardings.push_back(nb::none()); + signature.dynamic_arg_layouts.push_back(nullptr); + signature.committed_args.push_back(false); + } + } + + signature.configs = JitConfigs(); + signature.cached_hash = absl::HashOf(signature); + + return absl::OkStatus(); +} + +void PjitFunction::PopulateCacheEntry(PjitCacheEntry& cache_entry, + const nb::tuple& out_and_fastpath_data) { + DCHECK_GE(out_and_fastpath_data.size(), 2); + + if (out_and_fastpath_data[1].is_none()) { + VLOG(2) << "fastpath_data is none"; + cache_entry.fall_back_to_python = true; + return; + } + + nb::tuple fastpath_data = nb::cast(out_and_fastpath_data[1]); + + cache_entry.executable = nb::cast>( + fastpath_data.attr("xla_executable")); + + nb::sequence in_shardings = fastpath_data.attr("in_shardings"); + cache_entry.in_shardings.reserve(nb::len(in_shardings)); + for (nb::handle sharding : in_shardings) { + cache_entry.in_shardings.push_back(nb::borrow(sharding)); + } + + nb::sequence out_shardings = fastpath_data.attr("out_shardings"); + cache_entry.out_shardings.reserve(nb::len(out_shardings)); + for (nb::handle sharding : out_shardings) { + cache_entry.out_shardings.push_back(nb::borrow(sharding)); + } + + nb::sequence out_committed = fastpath_data.attr("out_committed"); + cache_entry.out_committed.reserve(nb::len(out_committed)); + for (nb::handle c : out_committed) { + cache_entry.out_committed.push_back(nb::cast(c)); + } + + nb::sequence out_avals = fastpath_data.attr("out_avals"); + cache_entry.out_avals.reserve(nb::len(out_avals)); + cache_entry.out_dtypes.reserve(nb::len(out_avals)); + cache_entry.out_shapes.reserve(nb::len(out_avals)); + cache_entry.out_weak_types.reserve(nb::len(out_avals)); + for (nb::handle aval : out_avals) { + cache_entry.out_avals.push_back(nb::borrow(aval)); + cache_entry.out_dtypes.push_back(aval.attr("dtype")); + cache_entry.out_shapes.push_back( + nb::cast>(aval.attr("shape"))); + cache_entry.out_weak_types.push_back( + nb::cast(aval.attr("weak_type"))); + } + + cache_entry.out_pytree_def = nb::cast( + nb::handle(fastpath_data.attr("out_pytree_def").ptr())); + + nb::sequence kept_var_bitvec = fastpath_data.attr("kept_var_bitvec"); + cache_entry.kept_var_bitvec.reserve(nb::len(kept_var_bitvec)); + for (nb::handle k : kept_var_bitvec) { + cache_entry.kept_var_bitvec.push_back(nb::cast(k)); + } + + nb::sequence in_device_local_layouts = + fastpath_data.attr("in_device_local_layouts"); + cache_entry.in_device_local_layouts.reserve(nb::len(in_device_local_layouts)); + for (nb::handle dll : in_device_local_layouts) { + cache_entry.in_device_local_layouts.push_back(nb::borrow(dll)); + } + + nb::sequence const_args = fastpath_data.attr("const_args"); + cache_entry.const_args.reserve(nb::len(const_args)); + for (nb::handle ca : const_args) { + cache_entry.const_args.push_back(nb::borrow(ca)); + } +} + +// Helper function used by the tp_clear GC method. +void PjitFunction::ClearPythonReferences() { + // TODO(mattjj): phawkins@ observed that the PyTreeRegistry + // pytree_registry_ attribute of PjitFunction could in principle also have + // python references to clear + nb::callable cache_miss; + std::optional fun; + nb::callable shard_arg_fallback; + // Swap values for nulls before they are destroyed. See the Python + // Py_CLEAR() documentation for a discussion of this topic. + std::swap(cache_miss_, cache_miss); + std::swap(fun_, fun); + std::swap(shard_arg_fallback_, shard_arg_fallback); +} + +struct PjitFunctionObject { + PyObject_HEAD; +#if PY_VERSION_HEX < 0x030C0000 + PyObject* dict; // Dictionary for __dict__ + PyObject* weakrefs; // Weak references; for use by the Python interpreter. +#endif // PY_VERSION_HEX < 0x030C0000 + vectorcallfunc vectorcall; + PjitFunction fun; + + // Doubly-linked list of PjitFunctionObjects, protected by + // PjitFunctionStore::mu_ or the GIL in GIL mode. + PjitFunctionObject* next; + PjitFunctionObject* prev; +}; + +// Contains a list of all PjitFunctionObjects. +// Thread-safe. +class PjitFunctionStore { + public: + void Insert(PjitFunctionObject* o) { + nb::ft_lock_guard lock(mu_); + o->next = compiled_functions_; + o->prev = nullptr; + if (o->next) { + o->next->prev = o; + } + compiled_functions_ = o; + } + + void Remove(PjitFunctionObject* o) { + nb::ft_lock_guard lock(mu_); + if (o->next) { + o->next->prev = o->prev; + } + if (o->prev) { + o->prev->next = o->next; + } else { + compiled_functions_ = o->next; + } + } + + void ClearCaches() { + std::vector< + std::pair>> + caches; + { + nb::ft_lock_guard lock(mu_); + for (PjitFunctionObject* fn = compiled_functions_; fn != nullptr; + fn = fn->next) { + caches.emplace_back(fn->fun.cache(), fn->fun.executables()); + } + } + for (auto& [cache, executables] : caches) { + nb::ft_object_guard lock(cache); + executables->Clear(); + } + }; + + private: + // Protected by the GIL in GIL mode, and by mu_ in freethreading mode. + nb::ft_mutex mu_; + PjitFunctionObject* compiled_functions_; +}; + +PjitFunctionStore pjit_function_store; + +PyObject* PjitFunction_Type = nullptr; + +bool PjitFunction::IsPjitFunction(nb::handle handle) { + return handle.type().ptr() == PjitFunction_Type; +} + +PjitFunction* PjitFunction::AsPjitFunctionUnchecked(nb::handle handle) { + return &(reinterpret_cast(handle.ptr())->fun); +} + +PjitFunction* AsPjitFunction(nb::handle handle) { + if (!PjitFunction::IsPjitFunction(handle)) { + throw xla::XlaRuntimeError(xla::InvalidArgument("Expected a PjitFunction")); + } + return PjitFunction::AsPjitFunctionUnchecked(handle); +} + +extern "C" { + +PyObject* PjitFunction_tp_vectorcall(PyObject* callable, PyObject* const* args, + size_t nargs, PyObject* kwnames) { + PjitFunctionObject* o = reinterpret_cast(callable); + tsl::profiler::TraceMe traceme([&] { + return absl::StrCat("PjitFunction(", o->fun.function_name(), ")"); + }); + try { + absl::StatusOr out = + o->fun.Call(callable, args, nargs, kwnames); + if (!out.ok()) { + PyErr_SetString(PyExc_ValueError, out.status().ToString().c_str()); + return nullptr; + } + return out.value().release().ptr(); + } catch (nb::python_error& e) { + e.restore(); + return nullptr; + } catch (nb::cast_error& e) { + PyErr_SetString(PyExc_ValueError, e.what()); + return nullptr; + } catch (std::invalid_argument& e) { + PyErr_SetString(PyExc_ValueError, e.what()); + return nullptr; + } catch (std::runtime_error& e) { + PyErr_SetString(PyExc_ValueError, e.what()); + return nullptr; + } +} + +PyObject* PjitFunction_tp_new(PyTypeObject* subtype, PyObject* args, + PyObject* kwds) { + PjitFunctionObject* self = + reinterpret_cast(subtype->tp_alloc(subtype, 0)); + if (!self) return nullptr; +#if PY_VERSION_HEX < 0x030C0000 + self->dict = nullptr; + self->weakrefs = nullptr; +#endif // PY_VERSION_HEX < 0x030C0000 + self->vectorcall = PjitFunction_tp_vectorcall; + return reinterpret_cast(self); +} + +void PjitFunction_tp_dealloc(PyObject* self) { + PyObject_GC_UnTrack(self); + PyTypeObject* tp = Py_TYPE(self); + PjitFunctionObject* o = reinterpret_cast(self); + pjit_function_store.Remove(o); + PyObject_ClearWeakRefs(self); +#if PY_VERSION_HEX < 0x030C0000 + Py_CLEAR(o->dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_ClearManagedDict(self); +#else + PyObject_ClearManagedDict(self); +#endif // PY_VERSION_HEX < 0x030C0000 + o->fun.~PjitFunction(); + tp->tp_free(self); + Py_DECREF(tp); +} + +int PjitFunction_tp_traverse(PyObject* self, visitproc visit, void* arg) { + // TODO(mattjj): phawkins@ observed that the PyTreeRegistry + // pytree_registry_ attribute of PjitFunction could in principle also have + // python references to visit + PjitFunctionObject* o = reinterpret_cast(self); + // https://docs.python.org/3/c-api/typeobj.html#c.PyTypeObject.tp_traverse + Py_VISIT(Py_TYPE(self)); +#if PY_VERSION_HEX < 0x030C0000 + Py_VISIT(o->dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_VisitManagedDict(self, visit, arg); +#else + PyObject_VisitManagedDict(self, visit, arg); +#endif // PY_VERSION_HEX < 0x030C0000 + Py_VISIT(o->fun.cache_miss().ptr()); + Py_VISIT(o->fun.shard_arg_fallback().ptr()); + if (o->fun.fun()) { + Py_VISIT(o->fun.fun()->ptr()); + } + return 0; +} + +int PjitFunction_tp_clear(PyObject* self) { + PjitFunctionObject* o = reinterpret_cast(self); +#if PY_VERSION_HEX < 0x030C0000 + Py_CLEAR(o->dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_ClearManagedDict(self); +#else + PyObject_ClearManagedDict(self); +#endif // PY_VERSION_HEX < 0x030C0000 + o->fun.ClearPythonReferences(); + return 0; +} + +// Implements the Python descriptor protocol so JIT-compiled functions can be +// used as bound methods. See: +// https://docs.python.org/3/howto/descriptor.html#functions-and-methods +PyObject* PjitFunction_tp_descr_get(PyObject* self, PyObject* obj, + PyObject* type) { + if (obj == nullptr || obj == Py_None) { + Py_INCREF(self); + return self; + } + return PyMethod_New(self, obj); +} + +static PyGetSetDef PjitFunction_tp_getset[] = { + // Having a __dict__ seems necessary to allow !functool.wraps to override + // __doc__. + {const_cast("__dict__"), PyObject_GenericGetDict, + PyObject_GenericSetDict, nullptr, nullptr}, + {nullptr, nullptr, nullptr, nullptr, nullptr}}; + +PyObject* PjitFunction_tp_repr(PyObject* self) { + try { + const std::string& repr = absl::StrFormat( + "", + nb::cast(nb::repr(nb::getattr(self, "__wrapped__")))); + return PyUnicode_FromString(repr.c_str()); + } catch (...) { + // Ignore all errors when accessing a repr. + return PyUnicode_FromString(""); + } +} + +} // extern "C" + +void InitializePjitFunction( + PjitFunctionObject* fn_obj, std::string function_name, + std::optional fun, nb::callable cache_miss, + std::vector static_argnums, std::vector static_argnames, + nb::object global_cache_key, nb_class_ptr pytree_registry, + nb::callable shard_arg_fallback, nb_class_ptr cache) { + fn_obj->next = fn_obj->prev = nullptr; + if (nb::isinstance(global_cache_key)) { + global_cache_key = nb::tuple(global_cache_key); + } + new (&fn_obj->fun) PjitFunction( + std::move(function_name), std::move(fun), std::move(cache_miss), + std::move(static_argnums), std::move(static_argnames), + std::move(global_cache_key), std::move(pytree_registry), + std::move(shard_arg_fallback), std::move(cache)); + // Handled separately because it is not exception safe to call this + // in the constructor because it leaves the object improperly constructed. + fn_obj->fun.InitExecutables(); + + // Only add the executable to the store after executables_ has been + // initialized. We want only fully constructed executables in the store. + pjit_function_store.Insert(fn_obj); +} + +nb::object MakePjitFunction( + std::string function_name, std::optional fun, + nb::callable cache_miss, std::vector static_argnums, + std::vector static_argnames, nb::object global_cache_key, + nb_class_ptr pytree_registry, + nb::callable shard_arg_fallback, + std::optional> cache) { + nb::object obj = nb::steal(PjitFunction_tp_new( + reinterpret_cast(PjitFunction_Type), nullptr, nullptr)); + PjitFunctionObject* fn_obj = reinterpret_cast(obj.ptr()); + if (!cache) { + cache = + make_nb_class(PjitFunctionCache::kDefaultCapacity); + } + InitializePjitFunction( + fn_obj, std::move(function_name), std::move(fun), std::move(cache_miss), + std::move(static_argnums), std::move(static_argnames), + std::move(global_cache_key), std::move(pytree_registry), + std::move(shard_arg_fallback), std::move(*cache)); + return obj; +} + +// Version numbers for the pickled representations of +// PjitFunction. Increment these if changing them. +const int kPjitFunctionPickleVersion = 1; + +PyMemberDef PjitFunction_members[] = { + {"__vectorcalloffset__", T_PYSSIZET, + static_cast(offsetof(PjitFunctionObject, vectorcall)), + READONLY, nullptr}, +#if PY_VERSION_HEX < 0x030C0000 + {"__dictoffset__", T_PYSSIZET, + static_cast(offsetof(PjitFunctionObject, dict)), READONLY, + nullptr}, + {"__weaklistoffset__", T_PYSSIZET, + static_cast(offsetof(PjitFunctionObject, weakrefs)), READONLY, + nullptr}, +#endif // PY_VERSION_HEX < 0x030C0000 + {nullptr, 0, 0, 0, nullptr}, +}; + +PyType_Slot PjitFunction_slots[] = { + {Py_tp_new, reinterpret_cast(PjitFunction_tp_new)}, + {Py_tp_dealloc, reinterpret_cast(PjitFunction_tp_dealloc)}, + {Py_tp_traverse, reinterpret_cast(PjitFunction_tp_traverse)}, + {Py_tp_clear, reinterpret_cast(PjitFunction_tp_clear)}, + {Py_tp_getset, reinterpret_cast(PjitFunction_tp_getset)}, + {Py_tp_descr_get, reinterpret_cast(PjitFunction_tp_descr_get)}, + {Py_tp_call, reinterpret_cast(PyVectorcall_Call)}, + {Py_tp_repr, reinterpret_cast(PjitFunction_tp_repr)}, + {Py_tp_members, reinterpret_cast(PjitFunction_members)}, + {0, nullptr}, +}; + +} // namespace + +void BuildPjitSubmodule(nb::module_& m) { + m.attr("_PyTreeRegistry") = m.attr("pytree").attr("PyTreeRegistry"); + + nb::class_ cache(m, "PjitFunctionCache"); + cache.def(nb::init(), + nb::arg("capacity") = PjitFunctionCache::kDefaultCapacity); + cache.def("size", &PjitFunctionCache::Size, nb::lock_self()); + cache.def("capacity", &PjitFunctionCache::Capacity, nb::lock_self()); + cache.def("clear", &PjitFunctionCache::Clear, nb::lock_self()); + cache.def_static("clear_all", []() { pjit_function_store.ClearCaches(); }); + cache.def( + "__getstate__", + // Pickles as an empty cache; the client can repopulate as needed. + [](const PjitFunctionCache& cache) { + nb::dict pickle; + pickle["version"] = kPjitFunctionPickleVersion; + pickle["capacity"] = cache.Capacity(); + return pickle; + }, + nb::lock_self()); + cache.def("__setstate__", + [](PjitFunctionCache* cache, const nb::dict& pickle) { + int version = nb::cast(pickle["version"]); + if (version != kPjitFunctionPickleVersion) { + throw std::invalid_argument(absl::StrFormat( + "Invalid PjitFunction pickle version, got %d, expected %d", + version, kPjitFunctionPickleVersion)); + } + int capacity = nb::cast(pickle["capacity"]); + new (cache) PjitFunctionCache(capacity); + }); + + // We need to use heap-allocated type objects because we want to add + // additional methods dynamically. + std::string name = + absl::StrCat(nb::cast(m.attr("__name__")), ".PjitFunction"); + PyType_Spec PjitFunction_spec = { + /*.name=*/name.c_str(), + /*.basicsize=*/static_cast(sizeof(PjitFunctionObject)), + /*.itemsize=*/0, +#if PY_VERSION_HEX < 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_HAVE_VECTORCALL, +#else // PY_VERSION_HEX < 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_HAVE_VECTORCALL | Py_TPFLAGS_MANAGED_DICT | + Py_TPFLAGS_MANAGED_WEAKREF, +#endif // PY_VERSION_HEX < 0x030C0000 + /*.slots=*/PjitFunction_slots, + }; + PjitFunction_Type = PyType_FromSpec(&PjitFunction_spec); + if (!PjitFunction_Type) { + throw nb::python_error(); + } + nb::object cfun = nb::borrow(PjitFunction_Type); + + // Add PjitFunction to the _jax module so it can be pickled. + m.attr("PjitFunction") = cfun; + cfun.attr("__getstate__") = nb::cpp_function( + [](const PjitFunction::object& self) { + PjitFunction* fn = self.func(); + nb::dict pickle; + pickle["version"] = kPjitFunctionPickleVersion; + pickle["function_name"] = fn->function_name(); + if (fn->fun().has_value()) { + pickle["fun"] = *fn->fun(); + } + pickle["cache_miss"] = fn->cache_miss(); + pickle["static_argnums"] = fn->static_argnums(); + pickle["static_argnames"] = nb::cast(fn->static_argnames()); + pickle["global_cache_key"] = fn->global_cache_key(); + pickle["pytree_registry"] = nb::cast(fn->pytree_registry()); + pickle["shard_arg_fallback"] = fn->shard_arg_fallback(); + pickle["cache"] = fn->cache(); + return pickle; + }, + nb::is_method()); + cfun.attr("__setstate__") = nb::cpp_function( + [](nb::object& self, const nb::dict& pickle) { + int version = nb::cast(pickle["version"]); + if (version != kPjitFunctionPickleVersion) { + throw std::invalid_argument(absl::StrFormat( + "Invalid PjitFunction pickle version, got %d, expected %d. " + "Pickling/Unpickling jitted functions using different JAX " + "versions is not supported.", + version, kPjitFunctionPickleVersion)); + } + std::string function_name = + nb::cast(pickle["function_name"]); + std::optional fun; + if (pickle.contains("fun")) { + fun = nb::cast(pickle["fun"]); + } + nb::callable cache_miss = nb::cast(pickle["cache_miss"]); + std::vector static_argnums = + nb::cast>(pickle["static_argnums"]); + std::vector static_argnames = + nb::cast>(pickle["static_argnames"]); + nb::object global_cache_key = pickle["global_cache_key"]; + nb_class_ptr pytree_registry = + nb::cast>( + nb::handle(pickle["pytree_registry"].ptr())); + nb::callable shard_arg_fallback = + nb::cast(pickle["shard_arg_fallback"]); + nb_class_ptr cache = + nb::cast>(pickle["cache"]); + InitializePjitFunction( + reinterpret_cast(self.ptr()), + std::move(function_name), std::move(fun), std::move(cache_miss), + std::move(static_argnums), std::move(static_argnames), + std::move(global_cache_key), std::move(pytree_registry), + std::move(shard_arg_fallback), std::move(cache)); + }, + nb::is_method()); + cfun.attr("__signature__") = xla::nb_property_readonly( + [](nb::handle self) { + return AsPjitFunction(self)->PythonSignature(); + }, + nb::sig("def __signature__(self) -> inspect.Signature")); + cfun.attr("_cache_miss") = + xla::nb_property_readonly([](nb::handle self) { + return AsPjitFunction(self)->cache_miss(); + }); + // All private members are only for testing/debugging purposes + cfun.attr("_cache_size") = nb::cpp_function( + [](nb::handle self) -> int { + return AsPjitFunction(self)->cache_capacity(); + }, + nb::is_method()); + cfun.attr("_clear_cache") = nb::cpp_function( + [](nb::handle self) { AsPjitFunction(self)->ClearCache(); }, + nb::is_method()); + + m.def( + "pjit", + [](std::string function_name, std::optional fun, + nb::callable cache_miss, std::vector static_argnums, + std::vector static_argnames, nb::object global_cache_key, + nb::object pytree_registry, nb::callable shard_arg_fallback, + std::optional> cache) { + nb_class_ptr registry = + nb::cast>( + nb::handle(pytree_registry.ptr())); + return MakePjitFunction( + std::move(function_name), std::move(fun), std::move(cache_miss), + std::move(static_argnums), std::move(static_argnames), + std::move(global_cache_key), std::move(registry), + std::move(shard_arg_fallback), std::move(cache)); + }, + nb::arg("function_name"), nb::arg("fun").none(), nb::arg("cache_miss"), + nb::arg("static_argnums"), nb::arg("static_argnames"), + nb::arg("global_cache_key"), nb::arg("pytree_registry"), + nb::arg("shard_arg_fallback"), nb::arg("cache").none() = nb::none(), + nb::sig( + // clang-format off + "def pjit(" + "function_name: str, " + "fun: Callable[..., Any] | None, " + "cache_miss: Callable[..., Any], " + "static_argnums: Sequence[int], " + "static_argnames: Sequence[str], " + "global_cache_key: Any, " + "pytree_registry: _PyTreeRegistry, " + "shard_arg_fallback: Callable[..., Any], " + "cache: PjitFunctionCache | None = ..." + ") -> PjitFunction" + // clang-format on + )); +} + +} // namespace jax diff --git a/jaxlib/pjit.h b/jaxlib/pjit.h new file mode 100644 index 000000000000..d86fa6bddc3c --- /dev/null +++ b/jaxlib/pjit.h @@ -0,0 +1,27 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_PJIT_H_ +#define JAXLIB_PJIT_H_ + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" + +namespace jax { + +void BuildPjitSubmodule(nanobind::module_& m); +} + +#endif // JAXLIB_PJIT_H_ diff --git a/jaxlib/plugin_support.py b/jaxlib/plugin_support.py index ea24dc181be0..3bae30d7903d 100644 --- a/jaxlib/plugin_support.py +++ b/jaxlib/plugin_support.py @@ -21,9 +21,9 @@ from .version import __version__ as jaxlib_version -_PLUGIN_MODULE_NAME = { - "cuda": "jax_cuda12_plugin", - "rocm": "jax_rocm60_plugin", +_PLUGIN_MODULE_NAMES = { + "cuda": ["jax_cuda13_plugin", "jax_cuda12_plugin"], + "rocm": ["jax_rocm7_plugin", "jax_rocm60_plugin"], } @@ -44,10 +44,10 @@ def import_from_plugin( The imported submodule, or None if the plugin is not installed or if the versions are incompatible. """ - if plugin_name not in _PLUGIN_MODULE_NAME: + if plugin_name not in _PLUGIN_MODULE_NAMES: raise ValueError(f"Unknown plugin: {plugin_name}") return maybe_import_plugin_submodule( - [f".{plugin_name}", _PLUGIN_MODULE_NAME[plugin_name]], + [f".{plugin_name}"] + _PLUGIN_MODULE_NAMES[plugin_name], submodule_name, check_version=check_version, ) diff --git a/jaxlib/pmap_lib.cc b/jaxlib/pmap_lib.cc new file mode 100644 index 000000000000..97b9de74e5ca --- /dev/null +++ b/jaxlib/pmap_lib.cc @@ -0,0 +1,1160 @@ +/* Copyright 2021 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/pmap_lib.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/synchronization/notification.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/call_location.h" +#include "jaxlib/config.h" +#include "jaxlib/jax_jit.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_array.h" +#include "jaxlib/py_client.h" +#include "jaxlib/py_device.h" +#include "jaxlib/py_executable.h" +#include "jaxlib/py_user_context.h" +#include "jaxlib/py_values.h" +#include "jaxlib/python_ref_manager.h" +#include "jaxlib/pytree.h" +#include "jaxlib/sharded_device_array.h" +#include "jaxlib/sharding.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/ifrt/user_context.h" +#include "xla/python/nb_helpers.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/safe_static_init.h" +#include "xla/python/types.h" +#include "xla/status_macros.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/python/lib/core/numpy.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/profiler/lib/traceme.h" + +namespace jax { + +namespace nb = nanobind; + +namespace { + +// Specifies how to shard the inputs. Even though everything could be computed +// from `sharding_specs` and the argument shape, we cache derived computations +// for performance. +struct InputSpec { + InputSpec(nb::object indices, nb::object array_sharding) + : indices(std::move(indices)), + array_sharding(std::move(array_sharding)) {} + nb::object indices; + nb::object array_sharding; +}; + +// An object containing the arguments to create Array from the +// output buffers. +struct ResultSpec { + public: + explicit ResultSpec(nb::object aval) + : out_aval(std::move(aval)), + weak_type(nb::cast(out_aval.attr("weak_type"))) {} + nb::object out_aval; + bool weak_type; +}; + +// The result of `ShardArg`. +struct ShardArgResult { + // Points to the on-device array. + // ifrt_array->sharding().num_shards() == `num_devices`. + xla::ifrt::ArrayRef ifrt_array; + // The Python argument will be always be copied to `owning_sda`. + nb::object owning_sda; +}; + +// Shards a single argument over devices. +// +// We currently only support fully in C++, C++ Array. For all +// other usages, we call a Python function returning C++ Array +// that will be casted back to the C++ objects. +// +// This function is not usable for JAX extensions that do not comply with the +// PjRt interfaces. +// +// Arguments: +// `arg`: The object to shard across `devices`. If a `Array`, +// a fast-path will be executed if it's already correctly sharded. +// +// Returns a failure absl::Status when an unrecoverable error occurred, so we +// don't need to fallback to Python. +// +// Both `devices` and `sharding_spec` has the same length. +absl::StatusOr ShardArg( + nb::handle arg, absl::Span devices, + const InputSpec& input_spec, nb::handle py_devices, + const nb::callable& python_fallback) { + if (arg.type().ptr() == PyArray::type().ptr()) { + auto py_array = nb::borrow(arg); + if (py_array.sharding().type().ptr() == + input_spec.array_sharding.type().ptr()) { + auto* pmap_sharding = nb::cast(py_array.sharding()); + auto* cached_pmap_sharding = + nb::cast(input_spec.array_sharding); + + if (pmap_sharding->sharding_spec() == + cached_pmap_sharding->sharding_spec()) { + ShardArgResult result; + result.owning_sda = nb::borrow(arg); + result.ifrt_array = tsl::FormRef(py_array.ifrt_array()); + if (result.ifrt_array == nullptr) { + return xla::InvalidArgument("Array has been deleted."); + } + if (result.ifrt_array->sharding().devices()->devices() != devices) { + absl::InlinedVector ifrt_devices; + ifrt_devices.reserve(devices.size()); + ifrt_devices.insert(ifrt_devices.end(), devices.begin(), + devices.end()); + // pmap does not support memory_kind for now. + auto* ifrt_client = result.ifrt_array->client(); + TF_ASSIGN_OR_RETURN(xla::ifrt::DeviceListRef device_list, + ifrt_client->MakeDeviceList(ifrt_devices)); + TF_ASSIGN_OR_RETURN( + auto copied_ifrt_arrays, + ifrt_client->CopyArrays( + absl::MakeSpan(&result.ifrt_array, 1), std::move(device_list), + xla::ifrt::MemoryKind(), + xla::ifrt::ArrayCopySemantics::kReuseInput)); + result.ifrt_array = std::move(copied_ifrt_arrays.front()); + } + return result; + } + } + } + + auto ndarray = xla::nb_numpy_ndarray::ensure(arg); + if (ndarray && PyArray_CheckExact(arg.ptr()) && + xla::DtypeToPrimitiveType(ndarray.dtype()).status().ok()) { + tsl::profiler::TraceMe traceme("ndarray pmap ShardArg"); + nb::list indices = nb::list(input_spec.indices); + nb::list py_devices_list = nb::cast(py_devices); + auto n_devices = py_devices_list.size(); + if (indices.size() != n_devices) { + return xla::InvalidArgument("indices vs devices mismatch: %d vs %d", + indices.size(), n_devices); + } + + ShardArgResult result; + const bool jax_enable_x64 = GetEnableX64(); + + std::vector owning_args; + std::vector args; + owning_args.reserve(n_devices); + args.reserve(n_devices); + DevicePutOptions options; + options.squash_64bit_types = !jax_enable_x64; + options.allow_zero_copy = true; + xla::ifrt::Client* ifrt_client = nullptr; + for (size_t i = 0; i < n_devices; ++i) { + auto to_device = nb::cast(py_devices_list[i]); + if (to_device->client().get() == nullptr) { + return xla::InvalidArgument("Cannot copy to unattached devices."); + } + if (i == 0) { + ifrt_client = to_device->client()->ifrt_client(); + } + owning_args.push_back(arg[indices[i]]); + args.push_back(owning_args.back()); + } + CHECK(ifrt_client != nullptr); + TF_ASSIGN_OR_RETURN( + DevicePutResult device_put_result, + DevicePutWithSharding( + args, ifrt_client, ndarray.dtype(), + nb::cast>(ndarray.attr("shape")), + input_spec.array_sharding, options)); + result.ifrt_array = std::move(device_put_result.ifrt_array); + return result; + } + tsl::profiler::TraceMe traceme("pmap_lib_shard_arg_python_fallback"); + auto py_array_or_bufs = python_fallback(arg, input_spec.array_sharding); + + auto py_array = nb::cast(py_array_or_bufs); + ShardArgResult result; + result.owning_sda = nb::borrow(py_array_or_bufs); + result.ifrt_array = tsl::FormRef(py_array.ifrt_array()); + return result; +} + +struct PmapCacheEntry { + explicit PmapCacheEntry(PyTreeRegistry* registry) + : out_pytree_def(registry) {} + std::shared_ptr executable; + // The value `backend.local_devices()`. + nb::object py_devices; // To pass back to Python. + std::vector devices; + std::vector input_specs; + PyTreeDef out_pytree_def; + // Objects necessary to build the out Array objects. + std::vector out_result_specs; + + std::vector out_array_shardings; + std::vector out_dtypes; + std::vector> out_shapes; + std::vector out_committed; + + // Ensures a single thread performs the compilation for a given executable. + // + // The first thread (holding the GIL) will create the CacheEntry associated to + // a signature and if the object has been inserted already, other threads + // will wait for the notification. + absl::Notification compilation_complete; + + bool fall_back_to_python = false; +}; + +} // namespace + +// A `PmapFunction` is associated to a `jax.pmap(f)` and takes care of the +// bookkeeping of the different signatures used and the dispatch of calls to +// the correct underlying `PyLoadedExecutable`. This class is thread-safe. +class PmapFunction { + public: + PmapFunction(nb::callable fun, nb::callable cache_miss, + std::vector static_argnums, + nb::callable python_shard_arg_fallback, + nb_class_ptr pytree_registry) + : fun_(std::move(fun)), + cache_miss_(std::move(cache_miss)), + static_argnums_(std::move(static_argnums)), + pytree_registry_(std::move(pytree_registry)), + python_shard_arg_fallback_(std::move(python_shard_arg_fallback)) { + std::sort(static_argnums_.begin(), static_argnums_.end()); + + function_name_ = + nb::cast(nb::str(nb::getattr(fun_, "__name__", fun_))); + } + PmapFunction(const PmapFunction&) = delete; + PmapFunction& operator=(const PmapFunction& other) = delete; + PmapFunction(PmapFunction&&) = default; + PmapFunction& operator=(PmapFunction&&) = default; + + // This function will: + // (a) flatten the inputs using pytree + // (b) get buffer objects from the arguments + // (c) call the executable + // (d) construct `Array` objects from the outputs + // (e) reconstruct the `PyTree`. + absl::StatusOr Call(nb::handle callable, PyObject* const* args, + size_t nargs, PyObject* kwnames); + + nb::object PythonSignature() { + const nb::module_& inspect = xla::SafeStaticInit([]() { + return std::make_unique(nb::module_::import_("inspect")); + }); + return inspect.attr("signature")(fun_); + } + + int cache_size() { + nb::ft_lock_guard lock(mu_); + return executables_.size(); + } + void cache_clear() { + nb::ft_lock_guard lock(mu_); + return executables_.clear(); + } + const nb::callable& fun() const { return fun_; } + const nb::callable& cache_miss() const { return cache_miss_; } + const std::string& function_name() const { return function_name_; } + const nb_class_ptr& pytree_registry() const { + return pytree_registry_; + } + const nb::callable& python_shard_arg_fallback() const { + return python_shard_arg_fallback_; + } + const std::vector& static_argnums() const { return static_argnums_; } + + // nb::object typed subclass for PmapFunction objects. + class pyobject : public nb::object { + public: + NB_OBJECT(pyobject, nb::object, "PmapFunction", + PmapFunction::IsPmapFunction); + pyobject() = default; + PmapFunction* func() const { + return PmapFunction::AsPmapFunctionUnchecked(*this); + } + }; + // Alias as ::object; outside the scope above we won't confuse nanobind's + // macros. + using object = pyobject; + + // Returns true if `h` is a PmapFunction. + static bool IsPmapFunction(nb::handle handle); + // Converts `handle` to a PmapFunction*. Does not do any checking. + static PmapFunction* AsPmapFunctionUnchecked(nb::handle handle); + + // Helper function used by the tp_clear GC method. + void ClearPythonReferences() { + nb::callable fun, cache_miss, python_shard_arg_fallback; + // Swap values for nulls before they are destroyed. See the Python + // Py_CLEAR() documentation for a discussion of this topic. + std::swap(fun_, fun); + std::swap(cache_miss_, cache_miss); + std::swap(python_shard_arg_fallback_, python_shard_arg_fallback); + } + + // Updates the signature of arguments for a pmapped function. + // + // It deals with the arguments signatures and also of the global and + // thread-local jit context. + absl::Status ComputeCallSignature( + absl::Span flat_dynamic_args, + CallSignature& signature) { + signature.function_name = function_name_; + + // Get dynamic argument signatures. + const bool jax_enable_x64 = GetEnableX64(); + for (nb::handle arg : flat_dynamic_args) { + auto signature_or_error = PyArgSignatureOfValue(arg, jax_enable_x64); + if (!signature_or_error.ok()) { + VLOG(2) << "PyArgSignatureOfValue failed: " + << signature_or_error.status(); + return signature_or_error.status(); + } + signature.dynamic_arg_signatures.push_back( + std::move(signature_or_error).value()); + } + signature.configs = JitConfigs(); + signature.cached_hash = absl::HashOf(signature); + return absl::Status(); + } + + // Returns, for debugging purposes (e.g. finding why some call misses the + // cache and recompiles), the list of the string representations of the keys. + // + // The format can change at any time. + std::string DebugCacheKeys() { + nb::ft_lock_guard lock(mu_); + std::vector key_strings = { + absl::StrCat("The cache contains ", executables_.size(), " elements:")}; + // We will be able to use auto& [key, _] when TF uses C++ 17. + for (auto& pair : executables_) { + key_strings.push_back(pair.first.DebugString()); + } + return absl::StrJoin(key_strings, "\n\n"); + } + + private: + // Mutates `cache_entry` in place. + void PopulateCacheEntry(PmapCacheEntry& cache_entry, + const nb::tuple& out_and_fastpath_data); + + bool always_fallback_to_python_ = false; + + nb::callable fun_; // The Python function to pmap. + std::string function_name_; + // See JAX _cpp_pmap in api.py for documentation. + nb::callable cache_miss_; + + // We need to know the static arguments to remove them from the arguments + // passed to the underlying PyLoadedExecutable. In sorted order. + std::vector static_argnums_; + nb_class_ptr pytree_registry_; + // We need a `shared_ptr` here to ensure value pointer stability, and to + // ensure that the cache entry remains alive in the presence of concurrent + // removals. + absl::flat_hash_map, + CallSignature::Hash> + executables_; + + // The fallback function to use with `ShardArgs`. + // TODO(jblespiau): Add support for more types from C++. + nb::callable python_shard_arg_fallback_; + + // Protect methods in FT: + nb::ft_mutex mu_; +}; + +void PmapFunction::PopulateCacheEntry(PmapCacheEntry& cache_entry, + const nb::tuple& out_and_fastpath_data) { + CHECK_EQ(out_and_fastpath_data.size(), 2); + if (out_and_fastpath_data[1].is_none()) { + cache_entry.fall_back_to_python = true; + return; + } + + nb::tuple pmap_data = nb::cast(out_and_fastpath_data[1]); + if (nb::cast(pmap_data.attr("version")) != 1) { + throw xla::XlaRuntimeError(absl::StrCat( + "The versions of jaxlib and Jax are incompatible (pmap cpp version 1 " + "expected, but got ", + nb::cast(pmap_data.attr("version")), + "Upgrade jaxlib and jax. Provided data was:", + nb::cast(nb::str(nb::repr(pmap_data))))); + } + // See api.nb::_PmapFastpathData in the JAX code base for the expected + // namedtuple. + std::shared_ptr executable; + try { + executable = nb::cast>( + pmap_data.attr("xla_executable")); + } catch (const nb::cast_error& e) { + // Backends that don't implement the C++ PjRt APIs + cache_entry.fall_back_to_python = true; + always_fallback_to_python_ = true; + return; + } + cache_entry.executable = std::move(executable); + const std::vector>& devices = + cache_entry.executable->AddressableDevices(); + cache_entry.devices.reserve(devices.size()); + for (auto& device : devices) { + cache_entry.devices.push_back(device->device()); + } + + // Inputs shard args details. + nb::list input_indices = pmap_data.attr("input_indices"); + + cache_entry.py_devices = pmap_data.attr("input_devices"); + auto input_devices = nb::cast>>( + pmap_data.attr("input_devices")); + + nb::list input_array_shardings = pmap_data.attr("input_array_shardings"); + + cache_entry.input_specs.reserve(input_array_shardings.size()); + + for (int i = 0; i < input_array_shardings.size(); ++i) { + cache_entry.input_specs.emplace_back(input_indices[i], + input_array_shardings[i]); + } + + // Outputs specs. + auto out_tree = nb::cast(pmap_data.attr("out_pytree_def")); + cache_entry.out_pytree_def = std::move(out_tree); + nb::list out_avals = pmap_data.attr("out_avals"); + + cache_entry.out_result_specs.reserve(out_avals.size()); + cache_entry.out_dtypes.reserve(out_avals.size()); + cache_entry.out_shapes.reserve(out_avals.size()); + + for (int i = 0; i < out_avals.size(); ++i) { + cache_entry.out_dtypes.push_back(out_avals[i].attr("dtype")); + cache_entry.out_shapes.push_back( + nb::cast>(out_avals[i].attr("shape"))); + cache_entry.out_result_specs.emplace_back(out_avals[i]); + } + + nb::list out_array_shardings = pmap_data.attr("out_array_shardings"); + + DCHECK(out_array_shardings.size() == 0 || + out_avals.size() == out_array_shardings.size()); + + cache_entry.out_array_shardings.reserve(out_array_shardings.size()); + for (nb::handle out_array_sharding : out_array_shardings) { + cache_entry.out_array_shardings.push_back( + nb::borrow(out_array_sharding)); + } + + nb::list out_committed = pmap_data.attr("out_committed"); + + DCHECK(out_committed.size() == 0 || out_avals.size() == out_committed.size()); + + cache_entry.out_committed.reserve(out_committed.size()); + for (nb::handle c : out_committed) { + cache_entry.out_committed.push_back(nb::cast(c)); + } +} + +absl::StatusOr PmapFunction::Call(nb::handle callable, + PyObject* const* args, + size_t nargs, PyObject* kwnames) { + GlobalPyRefManager()->MaybeCollectGarbage(); + InitializeThreadLocalState(); + + // Calls the cache_miss_ function. This just calls the Python function; it may + // return nullptr value if a Python exception is thrown. + auto cache_miss = [&]() -> nb::tuple { + return nb::steal( + PyObject_Vectorcall(cache_miss_.ptr(), args, nargs, kwnames)); + }; + + // Call the cache_miss() function, extracting the output data and ignoring + // the fastpath data. If the cache miss returns a Python error, returns + // nullptr and leaves the Python error set. + auto fallback_to_cache_miss = [&]() { + nb::tuple cache_miss_output = cache_miss(); + if (!cache_miss_output.ptr()) { + return nb::object(); + } + return nb::object(cache_miss_output[0]); + }; + + if (always_fallback_to_python_) { + return fallback_to_cache_miss(); + } + + size_t num_positional_args = PyVectorcall_NARGS(nargs); + size_t num_keyword_args = kwnames ? PyTuple_GET_SIZE(kwnames) : 0; + absl::Span positional_args(args, num_positional_args); + absl::Span keyword_args(args + num_positional_args, + num_keyword_args); + CallSignature call_signature; + absl::InlinedVector flat_dynamic_args; + std::vector keep_alive_objects; + absl::Status status = + ParseArguments(positional_args, keyword_args, kwnames, static_argnums_, + /*static_argnames=*/{}, pytree_registry_.get(), + call_signature.arg_signature, flat_dynamic_args); + if (!status.ok()) { + VLOG(2) << "ParseArguments failed: " << status; + return fallback_to_cache_miss(); + } + + status = ComputeCallSignature(flat_dynamic_args, call_signature); + if (!status.ok()) { + return fallback_to_cache_miss(); + } + + // Retrieve/Maybe add the executable to the cache. + bool inserted = false; + std::shared_ptr cache_entry_ptr; + { + nb::ft_lock_guard lock(mu_); + std::shared_ptr& entry_ref = executables_[call_signature]; + if (!entry_ref) { + inserted = true; + entry_ref = std::make_shared(pytree_registry_.get()); + } + cache_entry_ptr = entry_ref; + } + PmapCacheEntry& cache_entry = *cache_entry_ptr; + + if (!cache_entry.compilation_complete.HasBeenNotified()) { + // In case of several threads attempting to compile the executable, only + // the one that inserted the item will perform the compilation. + if (inserted) { + nb::object out_and_fastpath_data; + nb::tuple out_tuple; + VLOG(2) << "Cache miss for " << call_signature.DebugString(); + try { + // Calls Python and may release the GIL. May also throw if + // compilation/tracing fails. + out_and_fastpath_data = cache_miss(); + if (!out_and_fastpath_data.ptr()) { + throw nb::python_error(); + } + out_tuple = nb::cast(out_and_fastpath_data); + + PopulateCacheEntry(cache_entry, out_tuple); + } catch (const std::exception& e) { + cache_entry.fall_back_to_python = true; + cache_entry.compilation_complete.Notify(); + throw; + } + cache_entry.compilation_complete.Notify(); + + // We have already computed the result in the miss path so we can return + // it. We are even *required* to do so if there are donated arguments, + // because any donated buffers will now be invalid. + return nb::object(out_tuple[0]); + } else { + // Release the GIL while we wait, making sure the compile thread can + // lock it. + nb::gil_scoped_release release; + cache_entry.compilation_complete.WaitForNotification(); + } + } + if (cache_entry.fall_back_to_python) { + return fallback_to_cache_miss(); + } + + PyUserContextScope user_context_scope; + + // 1. Parse arguments. + std::vector& input_devices = cache_entry.devices; + std::vector& input_specs = cache_entry.input_specs; + const int num_args = flat_dynamic_args.size(); + + // We need [num_args] for the `Execute` call below. + std::vector num_args_arrays(num_args); + for (int i = 0; i < num_args; ++i) { + TF_ASSIGN_OR_RETURN( + ShardArgResult sharded_arg, + ShardArg(flat_dynamic_args[i], input_devices, input_specs[i], + cache_entry.py_devices, python_shard_arg_fallback_)); + + num_args_arrays[i] = std::move(sharded_arg.ifrt_array); + if (sharded_arg.owning_sda) { + keep_alive_objects.push_back(std::move(sharded_arg.owning_sda)); + } + } + + xla::ifrt::ExecuteOptions execute_options = cache_entry.executable->options(); + execute_options.launch_id = cache_entry.executable->GetNextLaunchId(); + execute_options.execution_stream_id = GetExecutionStreamId(); + if (execute_options.execution_stream_id == 0) { + execute_options.execution_stream_id = + tsl::Env::Default()->GetCurrentThreadId(); + } + PopulateCallLocation(execute_options, + xla::ifrt::UserContextScope::current().get()); + + // A vector of [num_outputs]. + std::vector output_arrays; + { + nb::gil_scoped_release gil_release; + auto ifrt_executable = cache_entry.executable->ifrt_executable(); + TF_ASSIGN_OR_RETURN( + auto result, ifrt_executable->Execute(absl::MakeSpan(num_args_arrays), + execute_options, + /*devices=*/std::nullopt)); + output_arrays = std::move(result.outputs); + } + + // TODO(jblespiau): We don't need to create the PyBuffer objects. + // Having a C++ `Array`, keeping internally the PjRtBuffer + // objects is sufficient, and we can lazily create the `PyBuffer` only if + // we access them from Python. + // TODO(jblespiau): Change the `client` function to return a reference. + nb_class_ptr client = cache_entry.executable->client(); + + // Convert the PjRtBuffer objects to PyBuffer, and invert the order from + // [num_devices, num_args] to [num_args, num_devices]. + const int num_outputs = output_arrays.size(); + std::vector flat_sharded_device_arrays; + flat_sharded_device_arrays.reserve(num_outputs); + + const auto& output_specs = cache_entry.out_result_specs; + + TF_RET_CHECK(cache_entry.out_array_shardings.size() == num_outputs); + for (int i = 0; i < num_outputs; ++i) { + const ResultSpec& result_spec = output_specs[i]; + PyArray py_array(result_spec.out_aval, result_spec.weak_type, + cache_entry.out_dtypes[i], cache_entry.out_shapes[i], + cache_entry.out_array_shardings[i], client, + std::move(output_arrays[i]), cache_entry.out_committed[i], + /*skip_checks=*/true); + + flat_sharded_device_arrays.push_back(std::move(py_array)); + } + + nb::object out = + cache_entry.out_pytree_def.Unflatten(flat_sharded_device_arrays); + + // If there is a post-hook function, call it with the inputs and the outputs. + std::optional post_hook = GetPostHook(); + if (post_hook) { + nb::tuple args_tuple = + nb::steal(PyTuple_New(num_positional_args)); + for (size_t i = 0; i < num_positional_args; ++i) { + Py_INCREF(args[i]); + PyTuple_SET_ITEM(args_tuple.ptr(), i, args[i]); + } + nb::dict kwargs; + if (kwnames) { + for (size_t i = 0; i < num_keyword_args; ++i) { + kwargs[nb::handle(PyTuple_GET_ITEM(kwnames, i))] = + nb::borrow(args[num_positional_args + i]); + } + } + + (*post_hook)(callable, args_tuple, kwargs, out); + } + + return out; +} + +struct JaxPmapFunctionObject { + PyObject_HEAD; +#if PY_VERSION_HEX < 0x030C0000 + PyObject* dict; // Dictionary for __dict__ + PyObject* weakrefs; // Weak references; for use by the Python interpreter. +#endif // PY_VERSION_HEX < 0x030C0000 + vectorcallfunc vectorcall; + PmapFunction fun; +}; + +PyObject* JaxPmapFunction_Type = nullptr; + +bool PmapFunction::IsPmapFunction(nb::handle handle) { + return handle.type().ptr() == JaxPmapFunction_Type; +} + +PmapFunction* PmapFunction::AsPmapFunctionUnchecked(nb::handle handle) { + return &(reinterpret_cast(handle.ptr())->fun); +} + +absl::StatusOr AsPmapFunction(nb::handle handle) { + if (!PmapFunction::IsPmapFunction(handle)) { + return xla::InvalidArgument("Expected a PmapFunction"); + } + return PmapFunction::AsPmapFunctionUnchecked(handle); +} + +namespace { + +extern "C" { + +PyObject* JaxPmapFunction_tp_vectorcall(PyObject* callable, + PyObject* const* args, size_t nargs, + PyObject* kwnames) { + JaxPmapFunctionObject* o = reinterpret_cast(callable); + tsl::profiler::TraceMe traceme([&] { + return absl::StrCat("JaxPmapFunction(", o->fun.function_name(), ")"); + }); + try { + absl::StatusOr out = + o->fun.Call(callable, args, nargs, kwnames); + if (!out.ok()) { + PyErr_SetString(PyExc_ValueError, out.status().ToString().c_str()); + return nullptr; + } + return out.value().release().ptr(); + } catch (nb::python_error& e) { + e.restore(); + return nullptr; + } catch (nb::cast_error& e) { + PyErr_SetString(PyExc_ValueError, e.what()); + return nullptr; + } catch (std::invalid_argument& e) { + PyErr_SetString(PyExc_ValueError, e.what()); + return nullptr; + } +} + +PyObject* JaxPmapFunction_tp_new(PyTypeObject* subtype, PyObject* args, + PyObject* kwds) { + JaxPmapFunctionObject* self = + reinterpret_cast(subtype->tp_alloc(subtype, 0)); + if (!self) return nullptr; +#if PY_VERSION_HEX < 0x030C0000 + self->dict = nullptr; + self->weakrefs = nullptr; +#endif // PY_VERSION_HEX < 0x030C0000 + self->vectorcall = JaxPmapFunction_tp_vectorcall; + return reinterpret_cast(self); +} + +void JaxPmapFunction_tp_dealloc(PyObject* self) { + PyObject_GC_UnTrack(self); + PyTypeObject* tp = Py_TYPE(self); + JaxPmapFunctionObject* o = reinterpret_cast(self); + PyObject_ClearWeakRefs(self); +#if PY_VERSION_HEX < 0x030C0000 + Py_CLEAR(o->dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_ClearManagedDict(self); +#else + PyObject_ClearManagedDict(self); +#endif // PY_VERSION_HEX < 0x030C0000 + o->fun.~PmapFunction(); + tp->tp_free(self); + Py_DECREF(tp); +} + +int JaxPmapFunction_tp_traverse(PyObject* self, visitproc visit, void* arg) { + JaxPmapFunctionObject* o = reinterpret_cast(self); + // https://docs.python.org/3/c-api/typeobj.html#c.PyTypeObject.tp_traverse + Py_VISIT(Py_TYPE(self)); +#if PY_VERSION_HEX < 0x030C0000 + Py_VISIT(o->dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_VisitManagedDict(self, visit, arg); +#else + PyObject_VisitManagedDict(self, visit, arg); +#endif // PY_VERSION_HEX < 0x030C0000 + Py_VISIT(o->fun.fun().ptr()); + Py_VISIT(o->fun.cache_miss().ptr()); + return 0; +} + +int JaxPmapFunction_tp_clear(PyObject* self) { + JaxPmapFunctionObject* o = reinterpret_cast(self); +#if PY_VERSION_HEX < 0x030C0000 + Py_CLEAR(o->dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_ClearManagedDict(self); +#else + PyObject_ClearManagedDict(self); +#endif // PY_VERSION_HEX < 0x030C0000 + o->fun.ClearPythonReferences(); + return 0; +} + +// Implements the Python descriptor protocol so PMAP-compiled functions can be +// used as bound methods. See: +// https://docs.python.org/3/howto/descriptor.html#functions-and-methods +PyObject* JaxPmapFunction_tp_descr_get(PyObject* self, PyObject* obj, + PyObject* type) { + if (obj == nullptr || obj == Py_None) { + Py_INCREF(self); + return self; + } + return PyMethod_New(self, obj); +} + +static PyGetSetDef JaxPmapFunction_tp_getset[] = { + // Having a __dict__ seems necessary to allow !functool.wraps to override + // __doc__. + {const_cast("__dict__"), PyObject_GenericGetDict, + PyObject_GenericSetDict, nullptr, nullptr}, + {nullptr, nullptr, nullptr, nullptr, nullptr}}; + +PyMemberDef JaxPmapFunction_members[] = { + {"__vectorcalloffset__", T_PYSSIZET, + static_cast(offsetof(JaxPmapFunctionObject, vectorcall)), + READONLY, nullptr}, +#if PY_VERSION_HEX < 0x030C0000 + {"__dictoffset__", T_PYSSIZET, + static_cast(offsetof(JaxPmapFunctionObject, dict)), READONLY, + nullptr}, + {"__weaklistoffset__", T_PYSSIZET, + static_cast(offsetof(JaxPmapFunctionObject, weakrefs)), + READONLY, nullptr}, +#endif // PY_VERSION_HEX < 0x030C0000 + {nullptr, 0, 0, 0, nullptr}, +}; + +PyType_Slot JaxPmapFunction_slots[] = { + {Py_tp_new, reinterpret_cast(JaxPmapFunction_tp_new)}, + {Py_tp_dealloc, reinterpret_cast(JaxPmapFunction_tp_dealloc)}, + {Py_tp_traverse, reinterpret_cast(JaxPmapFunction_tp_traverse)}, + {Py_tp_clear, reinterpret_cast(JaxPmapFunction_tp_clear)}, + {Py_tp_getset, reinterpret_cast(JaxPmapFunction_tp_getset)}, + {Py_tp_descr_get, reinterpret_cast(JaxPmapFunction_tp_descr_get)}, + {Py_tp_call, reinterpret_cast(PyVectorcall_Call)}, + {Py_tp_members, reinterpret_cast(JaxPmapFunction_members)}, + {0, nullptr}, +}; + +} // extern "C" + +nb::object MakePmapFunction(nb::callable fun, nb::callable cache_miss, + std::vector static_argnums, + nb::callable python_shard_arg_fallback, + nb_class_ptr pytree_registry) { + nb::object obj = nb::steal(JaxPmapFunction_tp_new( + reinterpret_cast(JaxPmapFunction_Type), nullptr, nullptr)); + JaxPmapFunctionObject* buf = + reinterpret_cast(obj.ptr()); + new (&buf->fun) PmapFunction( + std::move(fun), std::move(cache_miss), std::move(static_argnums), + std::move(python_shard_arg_fallback), std::move(pytree_registry)); + return obj; +} + +// Version numbers for the pickled representations. +// Increment these if changing them. +const int kPmapFunctionPickleVersion = 1; + +struct Descriptor {}; + +} // namespace + +void BuildPmapSubmodule(nb::module_& m) { + nb::module_ pmap_lib = m.def_submodule("pmap_lib", "Jax C++ pmap library"); + + nb::class_ no_sharding(pmap_lib, "NoSharding"); + no_sharding.def(nb::init<>()) + .def("__getstate__", + [](const NoSharding& self) { return nb::make_tuple(); }) + .def("__setstate__", + [](NoSharding& self, nb::tuple t) { new (&self) NoSharding(); }) + .def("__repr__", [](const NoSharding& self) { return "NoSharding()"; }) + .def("__eq__", + [](const NoSharding& self, nb::object obj) { + return nb::isinstance(obj); + }) + .def("__hash__", [](const NoSharding& self) { + const size_t hash = absl::HashOf(self); + return nb::int_(hash); + }); + + nb::class_ chunked(pmap_lib, "Chunked"); + chunked.def(nb::init>()) + .def("__getstate__", + [](const Chunked& self) { return nb::make_tuple(self.chunks); }) + .def("__setstate__", + [](Chunked& self, nb::tuple t) { + new (&self) Chunked{nb::cast>(t[0])}; + }) + .def_ro("chunks", &Chunked::chunks) + .def("__repr__", + [](const Chunked& self) { + return absl::StrCat("Chunked(", absl::StrJoin(self.chunks, ","), + ")"); + }) + .def("__eq__", [](const Chunked& self, nb::object other) { + return nb::isinstance(other) && + self == nb::cast(other); + }); + + nb::class_ unstacked(pmap_lib, "Unstacked"); + unstacked.def(nb::init()) + .def("__getstate__", + [](const Unstacked& self) { return nb::make_tuple(self.size); }) + .def("__setstate__", + [](Unstacked& self, nb::tuple t) { + new (&self) Unstacked{nb::cast(t[0])}; + }) + .def_ro("size", &Unstacked::size) + .def("__repr__", + [](const Unstacked& x) { + return absl::StrCat("Unstacked(", x.size, ")"); + }) + .def("__eq__", [](const Unstacked& self, nb::object other) { + return nb::isinstance(other) && + self == nb::cast(other); + }); + + nb::class_ sharded_axis(pmap_lib, "ShardedAxis"); + sharded_axis.def(nb::init()) + .def("__getstate__", + [](const ShardedAxis& self) { return nb::make_tuple(self.axis); }) + .def("__setstate__", + [](ShardedAxis& self, nb::tuple t) { + new (&self) ShardedAxis{nb::cast(t[0])}; + }) + .def_ro("axis", &ShardedAxis::axis) + .def("__repr__", + [](const ShardedAxis& x) { + return absl::StrCat("ShardedAxis(axis=", x.axis, ")"); + }) + .def("__eq__", [](const ShardedAxis& self, nb::object other) { + return nb::isinstance(other) && + self == nb::cast(other); + }); + + nb::class_ replicated(pmap_lib, "Replicated"); + replicated.def(nb::init()) + .def("__getstate__", + [](const Replicated& self) { return nb::make_tuple(self.replicas); }) + .def("__setstate__", + [](Replicated& self, nb::tuple t) { + new (&self) Replicated{nb::cast(t[0])}; + }) + .def_ro("replicas", &Replicated::replicas) + .def("__repr__", + [](const Replicated& x) { + return absl::StrCat("Replicated(replicas=", x.replicas, ")"); + }) + .def("__eq__", [](const Replicated& self, nb::object other) { + return nb::isinstance(other) && + self == nb::cast(other); + }); + + nb::class_ sharding_spec( + pmap_lib, "ShardingSpec", nb::sig("class ShardingSpec(typing.Any)")); + sharding_spec + .def(nb::init(), nb::arg("sharding"), + nb::arg("mesh_mapping")) + .def("__getstate__", + [](const ShardingSpec& self) { + auto sharding = + xla::SpanToNbTuple(absl::MakeConstSpan(self.GetSharding())); + auto mesh_mapping = + xla::SpanToNbTuple(absl::MakeConstSpan(self.GetMeshMapping())); + return nb::make_tuple(sharding, mesh_mapping); + }) + .def("__setstate__", + [](ShardingSpec& self, nb::tuple t) { + new (&self) + ShardingSpec{nb::cast>(t[0]), + nb::cast>(t[1])}; + }) + .def_prop_ro( + "sharding", + [](const ShardingSpec& self) { + return xla::SpanToNbTuple(absl::MakeConstSpan(self.GetSharding())); + }) + .def_prop_ro("mesh_mapping", + [](const ShardingSpec& self) { + return xla::SpanToNbTuple( + absl::MakeConstSpan(self.GetMeshMapping())); + }) + .def("__eq__", + [](const ShardingSpec& self, nb::object other) { + return nb::isinstance(other) && + self == nb::cast(other); + }) + .def("__hash__", [](const ShardingSpec& self) { + const size_t hash = absl::HashOf(self); + return nb::int_(hash); + }); + + // We need to use heap-allocated type objects because we want to add + // additional methods dynamically. + + std::string name = + absl::StrCat(nb::cast(m.attr("__name__")), ".PmapFunction"); + PyType_Spec pmap_function_spec = { + /*.name=*/name.c_str(), + /*.basicsize=*/static_cast(sizeof(JaxPmapFunctionObject)), + /*.itemsize=*/0, +#if PY_VERSION_HEX < 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_HAVE_VECTORCALL, +#else // PY_VERSION_HEX >= 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_HAVE_VECTORCALL | Py_TPFLAGS_MANAGED_DICT | + Py_TPFLAGS_MANAGED_WEAKREF, +#endif // PY_VERSION_HEX >= 0x030C0000 + /*.slots=*/JaxPmapFunction_slots, + }; + + JaxPmapFunction_Type = PyType_FromSpec(&pmap_function_spec); + if (!JaxPmapFunction_Type) { + throw nb::python_error(); + } + nb::object cfun = nb::borrow(JaxPmapFunction_Type); + cfun.attr("__module__") = pmap_lib.attr("__name__"); + pmap_lib.attr("PmapFunction") = cfun; + + // Add PmapFunction to the _jax module so it can be pickled. + m.attr("PmapFunction") = cfun; + + cfun.attr("__signature__") = xla::nb_property_readonly( + [](nb::handle self) -> nb::object { + PmapFunction* fun = xla::ValueOrThrow(AsPmapFunction(self)); + return fun->PythonSignature(); + }, + nb::sig("def __signature__(self) -> inspect.Signature")); + // Required by `post_hook`. + cfun.attr("_cache_miss") = xla::nb_property_readonly([](nb::handle self) { + PmapFunction* fun = xla::ValueOrThrow(AsPmapFunction(self)); + return fun->cache_miss(); + }); + cfun.attr("__getstate__") = nb::cpp_function( + [](const PmapFunction::object& self) { + PmapFunction* fn = self.func(); + nb::dict pickle; + pickle["version"] = kPmapFunctionPickleVersion; + pickle["fun"] = fn->fun(); + pickle["cache_miss"] = fn->cache_miss(); + pickle["static_argnums"] = fn->static_argnums(); + pickle["python_shard_arg_fallback"] = fn->python_shard_arg_fallback(); + pickle["pytree_registry"] = nb::cast(fn->pytree_registry()); + return pickle; + }, + nb::is_method()); + cfun.attr("__setstate__") = nb::cpp_function( + [](PmapFunction::object& self, const nb::dict& pickle) { + int version = nb::cast(pickle["version"]); + if (version != kPmapFunctionPickleVersion) { + throw std::invalid_argument(absl::StrFormat( + "Invalid PmapFunction pickle version, got %d, expected %d. " + "Pickling/Unpickling jitted functions using different JAX " + "versions is not supported.", + version, kPmapFunctionPickleVersion)); + } + nb::callable fun = nb::cast(pickle["fun"]); + nb::callable cache_miss = nb::cast(pickle["cache_miss"]); + std::vector static_argnums = + nb::cast>(pickle["static_argnums"]); + nb::callable python_shard_arg_fallback = + nb::cast(pickle["python_shard_arg_fallback"]); + nb_class_ptr pytree_registry = + nb::cast>(pickle["pytree_registry"]); + new (&(reinterpret_cast(self.ptr())->fun)) + PmapFunction(std::move(fun), std::move(cache_miss), + std::move(static_argnums), + std::move(python_shard_arg_fallback), + std::move(pytree_registry)); + }, + nb::is_method()); + + // This is only for testing/debugging purposes. + cfun.attr("_cache_size") = xla::nb_property_readonly( + [](nb::handle self) { + PmapFunction* fun = xla::ValueOrThrow(AsPmapFunction(self)); + return nb::cast(fun->cache_size()); + }, + nb::sig("def _cache_size(self) -> int")); + + cfun.attr("_cache_clear") = nb::cpp_function( + [](nb::handle self) { + PmapFunction* fun = xla::ValueOrThrow(AsPmapFunction(self)); + fun->cache_clear(); + }, + nb::is_method()); + + cfun.attr("_debug_cache_keys") = nb::cpp_function( + [](nb::handle self) -> std::string { + PmapFunction* fun = xla::ValueOrThrow(AsPmapFunction(self)); + return fun->DebugCacheKeys(); + }, + nb::is_method()); + + pmap_lib.attr("_PyTreeRegistry") = m.attr("pytree").attr("PyTreeRegistry"); + pmap_lib.def( + "pmap", + [](nb::callable fun, nb::callable cache_miss, + std::vector static_argnums, nb::callable shard_arg_fallback, + nb::object pytree_registry) -> nb::object { + nb_class_ptr registry = + nb::cast>(pytree_registry); + return MakePmapFunction( + std::move(fun), std::move(cache_miss), std::move(static_argnums), + std::move(shard_arg_fallback), std::move(registry)); + }, + nb::arg("fun"), nb::arg("cache_miss"), nb::arg("static_argnums"), + nb::arg("shard_arg_fallback"), nb::arg("pytree_registry"), + nb::sig( + // clang-format off + "def pmap(" + "fun: typing.Callable[..., typing.Any], " + "cache_miss: Callable[..., Any], " + "static_argnums: typing.Sequence[int], " + "shard_arg_fallback: Callable[..., Any], " + "pytree_registry: _PyTreeRegistry" + ") -> PmapFunction" + // clang-format on + )); +} + +} // namespace jax diff --git a/jaxlib/pmap_lib.h b/jaxlib/pmap_lib.h new file mode 100644 index 000000000000..b7cc2cc13f36 --- /dev/null +++ b/jaxlib/pmap_lib.h @@ -0,0 +1,34 @@ +/* Copyright 2021 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_PMAP_LIB_H_ +#define JAXLIB_PMAP_LIB_H_ + + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" + +// TODO(jblespiau): The current implementation moves the Python logic to C++, +// as a preliminary step to executing the `pmap` execution path from C++. +// It implements the current Python behavior (thus, it may not be optimal, and +// we will be able to modify it later). + +namespace jax { + +void BuildPmapSubmodule(nanobind::module_& m); + +} // namespace jax + +#endif // JAXLIB_PMAP_LIB_H_ diff --git a/jaxlib/pprof_profile_builder.cc b/jaxlib/pprof_profile_builder.cc new file mode 100644 index 000000000000..dd1e52102752 --- /dev/null +++ b/jaxlib/pprof_profile_builder.cc @@ -0,0 +1,105 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/pprof_profile_builder.h" + +#include // IWYU pragma: keep + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "xla/tsl/platform/logging.h" +#include "xla/util.h" +#include "tsl/platform/protobuf.h" +#include "tsl/profiler/protobuf/profile.pb.h" + +namespace xla { + +namespace nb = nanobind; + +PprofProfileBuilder::PprofProfileBuilder() { CHECK_EQ(0, StringId("")); } + +int PprofProfileBuilder::StringId(std::string_view s) { + auto ret = strings_.emplace(s, profile_.string_table_size()); + if (ret.second) { + profile_.add_string_table(s.data(), s.size()); + } + return ret.first->second; +} + +int PprofProfileBuilder::FunctionId(PyCodeObject* code) { + // +1 because id 0 is reserved. + auto ret = functions_.emplace(code, profile_.function_size() + 1); + if (ret.second) { + auto* function = profile_.add_function(); + function->set_id(ret.first->second); + int name = StringId(nb::cast(nb::str(code->co_name))); + function->set_name(name); + function->set_system_name(name); + function->set_filename( + StringId(nb::cast(nb::str(code->co_filename)))); + function->set_start_line(code->co_firstlineno); + } + return ret.first->second; +} + +int PprofProfileBuilder::LocationId(PyCodeObject* code, int instruction) { + // +1 because id 0 is reserved. + auto ret = locations_.emplace(std::make_pair(code, instruction), + profile_.location_size() + 1); + if (ret.second) { + auto* location = profile_.add_location(); + location->set_id(ret.first->second); + auto* line = location->add_line(); + line->set_function_id(FunctionId(code)); + line->set_line(PyCode_Addr2Line(code, instruction)); + } + return ret.first->second; +} + +absl::StatusOr JsonToPprofProfile(std::string json) { + tensorflow::tfprof::pprof::Profile profile; + auto status = tsl::protobuf::util::JsonStringToMessage(json, &profile); + if (!status.ok()) { + // TODO(phawkins): the explicit `std::string` cast here is to work around + // https://github.com/google/jax/issues/9534 which appears to be an ABSL and + // protobuf version compatibility problem. + return InvalidArgument("JSON parsing failed: %s", + std::string{status.message()}); + } + std::string s = profile.SerializeAsString(); + return nb::bytes(s.data(), s.size()); +} + +absl::StatusOr PprofProfileToJson(nb::bytes binary_proto) { + tensorflow::tfprof::pprof::Profile profile; + profile.ParseFromArray(binary_proto.c_str(), binary_proto.size()); + std::string output; + auto status = tsl::protobuf::util::MessageToJsonString(profile, &output); + if (!status.ok()) { + // TODO(phawkins): the explicit `std::string` cast here is to work around + // https://github.com/google/jax/issues/9534 which appears to be an ABSL and + // protobuf version compatibility problem. + return InvalidArgument("JSON printing failed: %s", + std::string{status.message()}); + } + return output; +} + +} // namespace xla diff --git a/jaxlib/pprof_profile_builder.h b/jaxlib/pprof_profile_builder.h new file mode 100644 index 000000000000..79caebe64cf5 --- /dev/null +++ b/jaxlib/pprof_profile_builder.h @@ -0,0 +1,69 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_PPROF_PROFILE_BUILDER_H_ +#define JAXLIB_PPROF_PROFILE_BUILDER_H_ + +#include + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "nanobind/nanobind.h" +#include "tsl/profiler/protobuf/profile.pb.h" + +namespace xla { + +// Helper class for building pprof::Profile profiles. +class PprofProfileBuilder { + public: + PprofProfileBuilder(); + tensorflow::tfprof::pprof::Profile& profile() { return profile_; } + + // Adds or returns the ID of `s` in the table. + int StringId(std::string_view s); + + // Adds or returns the ID of a function. + int FunctionId(PyCodeObject* code); + + // Adds or returns the ID of a code location. + int LocationId(PyCodeObject* code, int instruction); + + private: + tensorflow::tfprof::pprof::Profile profile_; + + absl::flat_hash_map strings_; + absl::flat_hash_map functions_; + absl::flat_hash_map, int> locations_; +}; + +// Converts the JSON representation of a pprof profile protocol buffer into +// a serialized protocol buffer. We want to allow Python code to construct pprof +// protocol buffers, but we don't want to export the generated protocol buffer +// bindings for Python because they cause conflicts between multiple Python +// extensions that contain the same protocol buffer message. Instead, we accept +// a JSON representation from Python and use this function to serialize it to +// a uncompressed binary protocol buffer. +absl::StatusOr JsonToPprofProfile(std::string json); + +// The reverse, useful for testing. +absl::StatusOr PprofProfileToJson(nanobind::bytes binary_proto); + +} // namespace xla + +#endif // JAXLIB_PPROF_PROFILE_BUILDER_H_ diff --git a/jaxlib/py_array.cc b/jaxlib/py_array.cc new file mode 100644 index 000000000000..bccfe7189fc6 --- /dev/null +++ b/jaxlib/py_array.cc @@ -0,0 +1,2402 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/py_array.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // NOLINT +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/casts.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/guard_lib.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" +#include "jaxlib/py_device.h" +#include "jaxlib/py_device_list.h" +#include "jaxlib/py_user_context.h" +#include "jaxlib/py_values.h" +#include "jaxlib/python_ref_manager.h" +#include "jaxlib/sharding.h" +#include "jaxlib/to_ifrt_sharding.h" +#include "jaxlib/traceback.h" +#include "jaxlib/util.h" +#include "xla/future.h" +#include "xla/layout.h" +#include "xla/layout_util.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/lru_cache.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/pjrt/status_casters.h" +#include "xla/primitive_util.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/array_spec.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/remap_plan.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/ifrt/user_context.h" +#include "xla/python/ifrt/user_context_status_util.h" +#include "xla/python/nb_absl_span.h" // IWYU pragma: keep +#include "xla/python/nb_helpers.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/python/pjrt_ifrt/pjrt_device.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" +#include "xla/python/safe_static_init.h" +#include "xla/python/types.h" +#include "xla/python/version.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status_macros.h" +#include "xla/tsl/concurrency/future.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/python/lib/core/numpy.h" // IWYU pragma: keep +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/profiler/lib/traceme.h" + +namespace ifrt = ::xla::ifrt; +namespace nb = nanobind; + +namespace jax { +namespace { + +nb::object& tracer_class = *new nb::object(); + +xla::PjRtBuffer* GetPjrtBuffer(ifrt::Array* ifrt_array) { + auto* arr = llvm::dyn_cast_or_null(ifrt_array); + if (arr == nullptr) { + throw xla::XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + return arr->pjrt_buffers().front().get(); +} + +absl::StatusOr XlaDynamicShape( + ifrt::Array* ifrt_array, std::optional& scratch) { + auto* pjrt_buffer = GetPjrtBuffer(ifrt_array); + + if (!scratch) { + absl::Span dims; + std::optional> logical_dims_storage; + if (pjrt_buffer->has_dynamic_dimensions()) { + { + nb::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN(std::vector logical_dims, + pjrt_buffer->logical_dimensions()); + logical_dims_storage.emplace(std::move(logical_dims)); + } + dims = *logical_dims_storage; + } else { + dims = pjrt_buffer->dimensions(); + } + xla::Shape shape = + xla::ShapeUtil::MakeShape(pjrt_buffer->element_type(), dims); + // TODO(b/327524065): fix this + *shape.mutable_layout() = pjrt_buffer->layout()->xla_layout(); + scratch = std::move(shape); + } + return &scratch.value(); +} + +ifrt::ArrayRef CreateIfRtArrayFromSingleDeviceShardedPyArrays( + xla::nb_dtype dtype, absl::Span shape, + absl::Span py_arrays, const nb::object& sharding) { + const ifrt::MemoryKind dst_memory_kind = GetMemoryKind(sharding); + + std::vector ifrt_arrays; + ifrt_arrays.reserve(py_arrays.size()); + absl::InlinedVector devices; + devices.reserve(py_arrays.size()); + absl::flat_hash_set device_set; + device_set.reserve(py_arrays.size()); + std::vector shapes; + shapes.reserve(py_arrays.size()); + + auto sharding_device_list = GetIfrtDeviceList(sharding); + if (!sharding_device_list.ok()) { + // TODO(hyeontaek): Return a absl::Status. + throw nb::value_error(sharding_device_list.status().ToString().c_str()); + } + ifrt::Device* device = sharding_device_list.value()->devices().front(); + + // TODO(hyeontaek): Canonicalize every `ifrt::MemoryKind` at creation time to + // skip canonicalization here once JAX begins to do it for JAX shardings. + const ifrt::MemoryKind canonical_dst_memory_kind = + ifrt::CanonicalizeMemoryKind(dst_memory_kind, device); + for (const auto& py_array : py_arrays) { + if (py_array.num_shards() != 1) { + throw nb::value_error( + absl::StrFormat( + "When making an array from single-device arrays the input arrays " + "must have one shard each. An argument array had %d shard(s).", + py_array.num_shards()) + .c_str()); + } + ifrt_arrays.push_back(tsl::FormRef(py_array.ifrt_array())); + ifrt::Device* const device = + ifrt_arrays.back()->sharding().devices()->devices().front(); + devices.push_back(device); + device_set.insert(device); + shapes.push_back(ifrt_arrays.back()->shape()); + if (canonical_dst_memory_kind != + ifrt::CanonicalizeMemoryKind( + ifrt_arrays.back()->sharding().memory_kind(), device)) { + throw nb::value_error( + absl::StrFormat( + "Memory kind mismatch with xla::PjRtBuffers. Got sharding with " + "memory kind '%v' and a buffer with memory_kind '%v'", + dst_memory_kind, ifrt_arrays.back()->sharding().memory_kind()) + .c_str()); + } + } + ifrt::DeviceListRef device_list = + xla::ValueOrThrow(device->client()->MakeDeviceList(devices)); + if (device_set.size() != device_list->size()) { + throw nb::value_error( + absl::StrFormat( + "When making an array from single-device arrays, the input arrays " + "must be from distinct devices, but got %v", + *device_list) + .c_str()); + } + + auto ifrt_dtype = DtypeToIfRtDType(dtype); + if (!ifrt_dtype.ok()) { + // TODO(hyeontaek): Return a absl::Status. + throw nb::value_error(ifrt_dtype.status().ToString().c_str()); + } + + absl::StatusOr ifrt_sharding = + sharding.type().is(PmapSharding::type()) + ? GetIfrtConcreteSharding(sharding, ifrt::Shape(shape), + std::move(shapes)) + : GetIfrtHloSharding(sharding, ifrt::Shape(shape)); + if (!ifrt_sharding.ok()) { + // TODO(hyeontaek): Return a absl::Status. + throw nb::value_error(ifrt_sharding.status().ToString().c_str()); + } + // TODO(emilyaf): Always use `ifrt_dtype` once tokens are handled correctly. + ifrt::DType array_dtype = + ifrt_arrays.empty() ? ifrt_dtype.value() : ifrt_arrays[0]->dtype(); + absl::StatusOr ifrt_array = + device->client()->AssembleArrayFromSingleDeviceArrays( + array_dtype, ifrt::Shape(shape), *std::move(ifrt_sharding), + absl::MakeSpan(ifrt_arrays), ifrt::ArrayCopySemantics::kReuseInput, + ifrt::SingleDeviceShardSemantics::kAddressableShards); + if (!ifrt_array.ok()) { + // TODO(hyeontaek): Return a absl::Status. + throw nb::value_error(ifrt_array.status().ToString().c_str()); + } + return *std::move(ifrt_array); +} + +struct PyBaseArrayObject { + PyObject_HEAD; +#if PY_VERSION_HEX < 0x030C0000 + PyObject* weakrefs; +#endif // PY_VERSION_HEX < 0x030C0000 +}; + +extern "C" void PyBaseArray_tp_dealloc(PyObject* self) { + PyObject_GC_UnTrack(self); + PyTypeObject* tp = Py_TYPE(self); + tp->tp_free(self); + Py_DECREF(tp); +} + +extern "C" int PyBaseArray_tp_traverse(PyObject* self, visitproc visit, + void* arg) { + Py_VISIT(Py_TYPE(self)); + return 0; +} + +struct PyArrayObject { + PyObject_HEAD; +#if PY_VERSION_HEX < 0x030C0000 + PyObject* weakrefs; + PyObject* dict; +#endif // PY_VERSION_HEX < 0x030C0000 + bool initialized; + alignas(PyArray::Storage) char array_storage[sizeof(PyArray::Storage)]; +}; +static_assert(std::is_standard_layout::value); + +PyArray::Storage* GetPyArrayStorageFromObject(PyArrayObject* py_array_object) { + return std::launder( + reinterpret_cast(py_array_object->array_storage)); +} + +extern "C" PyObject* PyArray_tp_new(PyTypeObject* type, PyObject*, PyObject*) { + PyObject* self = type->tp_alloc(type, 0); + auto* obj = reinterpret_cast(self); + obj->initialized = false; + return self; +} + +extern "C" void PyArray_tp_dealloc(PyObject* self) { + PyObject_GC_UnTrack(self); + PyTypeObject* tp = Py_TYPE(self); + auto* obj = reinterpret_cast(self); + + if (obj->initialized) { + GetPyArrayStorageFromObject(obj)->~PyArray_Storage(); + } + + PyObject_ClearWeakRefs(self); +#if PY_VERSION_HEX < 0x030C0000 + PyObject*& dict = *_PyObject_GetDictPtr(self); + Py_CLEAR(dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_ClearManagedDict(self); +#else + PyObject_ClearManagedDict(self); +#endif // PY_VERSION_HEX < 0x030C0000 + + tp->tp_free(self); + Py_DECREF(tp); +} + +// dynamic_attr: Allow the garbage collector to traverse the internal instance +// `__dict__`. +extern "C" int PyArray_tp_traverse(PyObject* self, visitproc visit, void* arg) { +#if PY_VERSION_HEX < 0x030C0000 + PyObject*& dict = *_PyObject_GetDictPtr(self); + Py_VISIT(dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_VisitManagedDict(self, visit, arg); +#else + PyObject_VisitManagedDict(self, visit, arg); +#endif // PY_VERSION_HEX < 0x030C0000 + // https://docs.python.org/3/c-api/typeobj.html#c.PyTypeObject.tp_traverse + Py_VISIT(Py_TYPE(self)); + return 0; +} + +// dynamic_attr: Allow the GC to clear the dictionary. +extern "C" int PyArray_tp_clear(PyObject* self) { + switch (auto guard_level = GetGarbageCollectArrayGuard(); guard_level) { + case GarbageCollectionGuardLevel::kAllow: + break; + case GarbageCollectionGuardLevel::kLog: + case GarbageCollectionGuardLevel::kFatal: { + auto* obj = reinterpret_cast(self); + std::string traceback_str; + if (obj->initialized) { + xla::ifrt::Array* ifrt_array_ptr = + GetPyArrayStorageFromObject(obj)->ifrt_array.get(); + if (ifrt_array_ptr != nullptr) { + std::optional traceback = + GetTraceback(ifrt_array_ptr->user_context().get()); + if (traceback.has_value()) { + traceback_str = traceback->ToString(); + } + } + } + auto error_msg = absl::StrCat( + "`jax.Array` was deleted by the Python garbage collector " + "instead of reference counting. Break the reference cycle " + "that delays the deletion of this `jax.Array` to avoid hogging " + "memory. Traceback: \n", + traceback_str.empty() ? "not available" : traceback_str); + if (guard_level == GarbageCollectionGuardLevel::kFatal) { + Py_FatalError(error_msg.c_str()); + } else { + PyErr_SetString(PyExc_RuntimeError, error_msg.c_str()); + PyErr_Print(); + PyErr_Clear(); + } + break; + } + } +#if PY_VERSION_HEX < 0x030C0000 + PyObject*& dict = *_PyObject_GetDictPtr(self); + Py_CLEAR(dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_ClearManagedDict(self); +#else + PyObject_ClearManagedDict(self); +#endif // PY_VERSION_HEX < 0x030C0000 + return 0; +} + +template +PyArray::Storage* Construct(PyArrayObject* self, Args&&... args) { + PyArray::Storage* out = + new (self->array_storage) PyArray::Storage(std::forward(args)...); + self->initialized = true; + return out; +} + +struct ShapedArrayCacheKey { + std::vector dims; + ifrt::DType dtype{ifrt::DType::kInvalid}; + bool weak_type; + + template + friend H AbslHashValue(H h, const ShapedArrayCacheKey& value) { + return H::combine(std::move(h), value.dims, value.dtype, value.weak_type); + } + bool operator==(const ShapedArrayCacheKey& other) const { + return dims == other.dims && dtype == other.dtype && + weak_type == other.weak_type; + } +}; + +// Constructing ShapedArrays has gotten slow. Cache it. +nb::object MakeShapedArrayCached(const ShapedArrayCacheKey& key) { + using CacheT = xla::LRUCache>>; + static nb::ft_mutex mu; + static auto* lru_list = new CacheT::LRUList(4096); + static auto* cache = new CacheT(lru_list); + + const nb::object& shaped_array = xla::SafeStaticInit([]() { + nb::object jax_core; + try { + jax_core = nb::module_::import_("jax.core"); + } catch (nb::python_error& e) { + return std::make_unique(); + } + return std::make_unique(jax_core.attr("ShapedArray")); + }); + if (!shaped_array.ptr()) { + return nb::none(); + } + + nb::ft_lock_guard lock(mu); + auto value = + cache->GetOrCreateIfAbsent(key, [](const ShapedArrayCacheKey& key) { + return std::make_shared>(); + }); + + if (!value->has_value()) { + xla::nb_dtype dtype = + xla::IfrtDtypeToDtypeWithTokenCanonicalization(key.dtype).value(); + nb::object aval = shaped_array( + xla::SpanToNbTuple(absl::Span( + key.dtype.kind() == ifrt::DType::kToken ? std::vector{0} + : key.dims)), + dtype, key.weak_type); + *value = aval; + return aval; + } + return **value; +} + +// Grouping key used by BatchedCopyToDeviceWithSharding. +// Defined outside of the function as required by templatized function +// `AbslHashValue`. +struct BatchedCopyToDeviceWithShardingKey { + ifrt::DeviceListRef src_devices; + ifrt::MemoryKind src_memory_kind; + ifrt::DeviceListRef dst_devices; + ifrt::MemoryKind dst_memory_kind; + ifrt::ArrayCopySemantics array_copy_semantics; + + bool operator==(const BatchedCopyToDeviceWithShardingKey& other) const { + return *src_devices == *other.src_devices && + src_memory_kind == other.src_memory_kind && + *dst_devices == *other.dst_devices && + dst_memory_kind == other.dst_memory_kind && + array_copy_semantics == other.array_copy_semantics; + } + + template + friend H AbslHashValue(H h, const BatchedCopyToDeviceWithShardingKey& key) { + return H::combine(std::move(h), key.src_devices, key.src_memory_kind, + key.dst_devices, key.dst_memory_kind, + key.array_copy_semantics); + } +}; + +} // namespace + +PyArray_Storage::PyArray_Storage(nb::object aval, bool weak_type, + xla::nb_dtype dtype, + std::vector shape, + nb::object sharding, bool committed, + nb_class_ptr py_client, + ifrt::ArrayRef ifrt_array, + xla::Future<> result_status) + : aval(std::move(aval)), + weak_type(weak_type), + dtype(std::move(dtype)), + shape(std::move(shape)), + sharding(std::move(sharding)), + committed(committed), + py_client(std::move(py_client)), + ifrt_array(std::move(ifrt_array)), + result_status(std::move(result_status)) { + static_assert(PyClient::kNumArraysShards < + std::numeric_limits::max()); + thread_id_bucket = std::hash()(std::this_thread::get_id()) % + PyClient::kNumArraysShards; + + PyClient::ArraysShard& shard = this->py_client->arrays_[thread_id_bucket]; + nanobind::ft_lock_guard lock(shard.mutex); + next = shard.arrays; + shard.arrays = this; + if (next) { + next->prev = this; + } + prev = nullptr; +} + +void PyInit_helper(PyArray self, nb::object aval, nb::object sharding, + absl::Span py_arrays, bool committed) { + auto dtype = nb::cast(aval.attr("dtype")); + auto shape = nb::cast>(aval.attr("shape")); + auto py_device_list = + nb::cast(sharding.attr("_internal_device_list")); + nb_class_ptr py_client = py_device_list->py_client(); + auto ifrt_array = CreateIfRtArrayFromSingleDeviceShardedPyArrays( + dtype, shape, py_arrays, sharding); + Construct(reinterpret_cast(self.ptr()), aval, + nb::cast(aval.attr("weak_type")), std::move(dtype), + std::move(shape), std::move(sharding), committed, py_client, + std::move(ifrt_array), xla::Future<>()); +} + +void PyArray::PyInit(PyArray self, nb::object aval, nb::object sharding, + absl::Span py_arrays, bool committed, + bool skip_checks) { + PyUserContextScope user_context_scope; + if (skip_checks) { + PyInit_helper(self, aval, sharding, py_arrays, committed); + } else { + nb::object rearranged_arrays = + self.CheckAndRearrange(py_arrays, sharding, aval); + auto rearranged_py_arrays = + nb::cast>(rearranged_arrays); + PyInit_helper(self, aval, sharding, rearranged_py_arrays, committed); + } +} + +PyArray PyArray::MakeFromSingleDeviceArray(nb_class_ptr py_client, + ifrt::ArrayRef ifrt_array, + bool weak_type, bool committed, + xla::Future<> result_status) { + if (!llvm::isa(ifrt_array->sharding())) { + throw xla::XlaRuntimeError(xla::InvalidArgument( + "Constructing single device jax.Array from non-single " + "device ifrt array.")); + } + auto shape_span = ifrt_array->shape().dims(); + ShapedArrayCacheKey key; + key.dtype = ifrt_array->dtype(); + key.dims = key.dtype.kind() == ifrt::DType::kToken + ? std::vector{0} + : std::vector(shape_span.begin(), shape_span.end()); + key.weak_type = weak_type; + auto aval = MakeShapedArrayCached(key); + auto dtype = + xla::IfrtDtypeToDtypeWithTokenCanonicalization(key.dtype).value(); + const ifrt::MemoryKind memory_kind = ifrt_array->sharding().memory_kind(); + nb::object py_memory_kind = + (memory_kind.memory_kind().has_value()) + ? nb::object(nb::str(memory_kind.memory_kind()->data(), + memory_kind.memory_kind()->size())) + : nb::none(); + nb::object sharding = make_nb_class( + py_client, ifrt_array->sharding().devices(), std::move(py_memory_kind)); + return PyArray(std::move(aval), weak_type, dtype, std::move(key.dims), + std::move(sharding), std::move(py_client), + std::move(ifrt_array), committed, + /*skip_checks=*/true, std::move(result_status)); +} + +PyArray PyArray::MakeFromIfrtArrayAndSharding(nb_class_ptr py_client, + ifrt::ArrayRef ifrt_array, + nb::object sharding, + bool weak_type, bool committed, + bool skip_checks) { + auto shape_span = ifrt_array->shape().dims(); + ShapedArrayCacheKey key; + key.dtype = ifrt_array->dtype(); + key.dims = key.dtype.kind() == ifrt::DType::kToken + ? std::vector{0} + : std::vector(shape_span.begin(), shape_span.end()); + key.weak_type = weak_type; + auto aval = MakeShapedArrayCached(key); + auto dtype = + xla::IfrtDtypeToDtypeWithTokenCanonicalization(key.dtype).value(); + return PyArray(std::move(aval), weak_type, dtype, std::move(key.dims), + std::move(sharding), std::move(py_client), + std::move(ifrt_array), committed, skip_checks); +} + +PyArrayResultHandler::PyArrayResultHandler( + nb::object aval, nb::object sharding, bool committed, bool skip_checks, + std::vector wrappers) + : aval_(std::move(aval)), + sharding_(std::move(sharding)), + committed_(committed), + skip_checks_(skip_checks), + wrappers_(std::move(wrappers)) { + weak_type_ = nb::cast(aval_.attr("weak_type")); + dtype_ = nb::cast(aval_.attr("dtype")); + shape_ = nb::cast>(aval_.attr("shape")); +} + +nanobind::object PyArrayResultHandler::Call( + absl::Span py_arrays) const { + auto py_device_list = GetPyDeviceList(sharding_); + if (!py_device_list.ok()) { + throw nb::value_error( + absl::StrCat("Failed to get py device list from sharding: ", + py_device_list.status().ToString()) + .c_str()); + } + PyUserContextScope user_context_scope; + return Call(py_device_list.value()->py_client(), + CreateIfRtArrayFromSingleDeviceShardedPyArrays( + dtype_, shape_, py_arrays, sharding_), + xla::Future<>()); +} + +nanobind::object PyArrayResultHandler::Call(nb_class_ptr py_client, + ifrt::ArrayRef ifrt_array, + xla::Future<> result_status) const { + nanobind::object result = + PyArray(aval_, weak_type_, dtype_, shape_, sharding_, + std::move(py_client), std::move(ifrt_array), committed_, + skip_checks_, std::move(result_status)); + for (auto& cb : wrappers_) { + result = cb(std::move(result)); + } + return result; +} + +nanobind::object PyArrayResultHandler::Call(PyArray py_array) const { + return Call(py_array.py_client(), tsl::FormRef(py_array.ifrt_array()), + xla::Future<>()); +} + +PyArray::PyArray(nb::object aval, bool weak_type, xla::nb_dtype dtype, + std::vector shape, nb::object sharding, + nb_class_ptr py_client, ifrt::ArrayRef ifrt_array, + bool committed, bool skip_checks, + xla::Future<> result_status) { + if (ifrt_array->user_context() == nullptr && Traceback::IsEnabled()) { + throw nb::value_error( + "Expecting an IFRT `Array` to have a user context, but got a null " + "user context. Use `jax::PyUserContextScope` to set a user context for " + "operations producing IFRT `Array`s."); + } + auto* self = + PyArray_tp_new(reinterpret_cast(type_), nullptr, nullptr); + m_ptr = self; + Construct(reinterpret_cast(self), std::move(aval), weak_type, + std::move(dtype), std::move(shape), std::move(sharding), committed, + std::move(py_client), std::move(ifrt_array), + std::move(result_status)); + + if (!skip_checks) { + this->attr("_arrays") = this->attr("_check_and_rearrange")( + this->attr("_arrays"), this->attr("_sharding"), this->attr("aval")); + } +} + +PyArray::Storage& PyArray::GetStorage() { + return *GetPyArrayStorageFromObject(reinterpret_cast(ptr())); +} + +const PyArray::Storage& PyArray::GetStorage() const { + return *GetPyArrayStorageFromObject(reinterpret_cast(ptr())); +} + +nb::object PyArray::CheckAndRearrange(const absl::Span py_arrays, + const nb::object sharding, + const nb::object aval) { + return this->attr("_check_and_rearrange")(py_arrays, sharding, aval); +} + +void PyArray::SetIfrtArray(ifrt::ArrayRef ifrt_array) { + if (ifrt_array != nullptr && ifrt_array->user_context() == nullptr && + Traceback::IsEnabled()) { + throw nb::value_error( + "Expecting an IFRT `Array` to have a user context, but got a null " + "user context. Use `jax::PyUserContextScope` to set a user context for " + "operations producing IFRT `Array`s."); + } + GetStorage().ifrt_array = std::move(ifrt_array); +} + +const std::vector& PyArray::py_arrays_cached() { + auto& py_arrays = this->py_arrays(); + + if (py_arrays.empty()) { + // Use the user context of this array. + xla::ifrt::UserContextScope user_context_scope( + ifrt_array()->user_context()); + auto ifrt_arrays = ifrt_array()->DisassembleIntoSingleDeviceArrays( + ifrt::ArrayCopySemantics::kReuseInput, + ifrt::SingleDeviceShardSemantics::kAddressableShards); + if (!ifrt_arrays.ok()) { + throw nb::value_error( + absl::StrCat("Failed to disassemble into single-device arrays: ", + ifrt_arrays.status().ToString()) + .c_str()); + } + py_arrays.reserve(ifrt_arrays->size()); + for (auto& ifrt_array : *ifrt_arrays) { + py_arrays.push_back(PyArray::MakeFromSingleDeviceArray( + py_client(), std::move(ifrt_array), weak_type(), committed(), + result_status())); + } + } + + return py_arrays; +} + +nb::object PyArray::arrays() { + // For performance, we only keep pjrt buffers by default. But on python side + // "_arrays" returns PyArrays instead, and subsequent calls to "_arrays" + // should return the same PyArrays (to avoid duplicate device to host + // transfers). So we create PyArrays the first time it is called and reuse + // them later. + if (ifrt_array() == nullptr || ifrt_array()->IsDeleted()) return nb::none(); + + if (llvm::isa(&ifrt_array()->sharding())) { + std::vector py_arrays; + py_arrays.push_back(*this); + return nb::cast(py_arrays); + } + + return nb::cast(py_arrays_cached()); +} + +absl::Status PyArray::set_arrays(nb::object obj) { + if (obj.is_none()) { + SetIfrtArray(ifrt::ArrayRef()); + py_arrays().clear(); + return absl::OkStatus(); + } + + if (!nb::isinstance(obj)) { + return xla::InvalidArgument( + "Unsupported arg when setting Array._arrays: %s", + nb::cast(nb::str(obj.type()))); + } + + nb::list list(obj); + + if (list.size() == 0) return absl::OkStatus(); + + SetIfrtArray(ifrt::ArrayRef()); + py_arrays().clear(); + std::vector ifrt_arrays; + ifrt_arrays.reserve(list.size()); + absl::InlinedVector devices; + devices.reserve(list.size()); + std::vector shapes; + shapes.reserve(list.size()); + for (nb::handle obj : list) { + if (obj.type().is(PyArray::type())) { + auto py_array = nb::borrow(obj); + if (py_array.py_client().get() != py_client().get()) { + return xla::InvalidArgument( + "Client mismatch when assigning to _arrays."); + } + if (py_array.num_shards() != 1) { + return xla::InvalidArgument("Wrong number of shards: %d", + py_array.num_shards()); + } + ifrt_arrays.push_back(tsl::FormRef(py_array.ifrt_array())); + devices.push_back( + ifrt_arrays.back()->sharding().devices()->devices().front()); + shapes.push_back(ifrt_arrays.back()->shape()); + } else { + return xla::InvalidArgument( + "Unsupported arg when setting Array._arrays: %s", + nb::cast(nb::str(obj.type()))); + } + } + const ifrt::MemoryKind first_memory_kind = + ifrt_arrays.front()->sharding().memory_kind(); + // TODO(hyeontaek): Canonicalize every `ifrt::MemoryKind` at creation time to + // skip canonicalization here once JAX begins to do it for JAX shardings. + const ifrt::MemoryKind canonical_first_memory_kind = + ifrt::CanonicalizeMemoryKind( + first_memory_kind, + ifrt_arrays.front()->sharding().devices()->devices().front()); + for (const auto& ifrt_array : ifrt_arrays) { + if (canonical_first_memory_kind != + ifrt::CanonicalizeMemoryKind( + ifrt_array->sharding().memory_kind(), + ifrt_array->sharding().devices()->devices().front())) { + throw nb::value_error( + absl::StrFormat( + "Memory kind mismatch between single-device arrays. Got one " + "array with memory kind '%v' and another with memory_kind '%v'", + first_memory_kind, ifrt_array->sharding().memory_kind()) + .c_str()); + } + } + + TF_ASSIGN_OR_RETURN( + auto ifrt_sharding, + sharding().type().is(PmapSharding::type()) + ? GetIfrtConcreteSharding(sharding(), ifrt::Shape(shape()), + std::move(shapes)) + : GetIfrtHloSharding(sharding(), ifrt::Shape(shape()))); + TF_ASSIGN_OR_RETURN( + auto array, + py_client()->ifrt_client()->AssembleArrayFromSingleDeviceArrays( + ifrt_arrays[0]->dtype(), ifrt::Shape(shape()), + std::move(ifrt_sharding), absl::MakeSpan(ifrt_arrays), + ifrt::ArrayCopySemantics::kReuseInput, + ifrt::SingleDeviceShardSemantics::kAddressableShards)); + SetIfrtArray(std::move(array)); + return absl::OkStatus(); +} + +absl::StatusOr PyArray::FullyReplicatedShard() { + auto& cached = GetStorage().fully_replicated_array; + if (!cached.is_none()) { + return nb::cast(cached); + } + + if (ifrt_array() == nullptr) { + return xla::InvalidArgument( + "FullyReplicatedShard() called on deleted or donated buffer"); + } + + // Use the user context of this array. + xla::ifrt::UserContextScope user_context_scope(ifrt_array()->user_context()); + TF_ASSIGN_OR_RETURN(auto fully_replicated_ifrt_shard, + ifrt_array()->FullyReplicatedShard( + ifrt::ArrayCopySemantics::kReuseInput)); + auto array = MakeFromSingleDeviceArray( + py_client(), std::move(fully_replicated_ifrt_shard), weak_type(), + committed(), result_status()); + cached = array; + return nb::cast(cached); +} + +absl::Status PyArray::BlockUntilReady() const { + PyUserContextScope user_context_scope; + absl::Status status; + { + nb::gil_scoped_release gil_release; + if (ifrt_array() == nullptr) { + return xla::InvalidArgument( + "BlockHostUntilReady() called on deleted or donated buffer"); + } + ifrt::Array* ifrt_array = this->ifrt_array(); + status = AwaitBuffersReady(absl::MakeConstSpan(&ifrt_array, 1)); + } + // The array ready future can reference an asynchronously propagated + // `ifrt::UserContext` representing the context of an error. We expand this + // future result right before returning it to Python (outside of + // `nb::gil_scoped_release`) so that any attached user context is appended to + // the status message. + return xla::ifrt::ExpandUserContexts(std::move(status)); +} + +absl::StatusOr PyArray::GetOnDeviceSizeInBytes() { + const ifrt::Array* ifrt_array = this->ifrt_array(); + if (ifrt_array == nullptr) { + return xla::InvalidArgument( + "GetOnDeviceSizeInBytes() called on deleted or donated buffer"); + } + // TODO(emilyaf): Support this method for non-addressable arrays by calling + // py_client()->pjrt_client()->GetOnDeviceBytesCount once all clients + // implement it. + if (ifrt_array->sharding().devices()->AddressableDeviceList()->empty()) { + return xla::Unimplemented( + "GetOnDeviceSizeInBytes() is not yet supported for arrays with no " + "addressable devices"); + } + TF_ASSIGN_OR_RETURN(std::shared_ptr pjrt_layout, + layout()); + TF_ASSIGN_OR_RETURN(xla::PrimitiveType element_type, + ifrt::ToPrimitiveType(ifrt_array->dtype())); + if (sharding().type().is(SingleDeviceSharding::type())) { + // An array with `SingleDeviceSharding` takes a fast path. We do not need to + // compute a shard shape separately, and the array aval is often `None`. + xla::Shape shard_shape = xla::ShapeUtil::MakeShape(element_type, shape()); + *shard_shape.mutable_layout() = pjrt_layout->xla_layout(); + return xla::ShapeUtil::ArraySize(shard_shape); + } + auto shard_shape_dims = nb::cast>( + sharding().attr("shard_shape")(aval().attr("shape"))); + xla::Shape shard_shape = + xla::ShapeUtil::MakeShape(element_type, shard_shape_dims); + *shard_shape.mutable_layout() = pjrt_layout->xla_layout(); + return xla::ShapeUtil::ArraySize(shard_shape) * + nb::len(nb::object(sharding().attr("device_set"))); +} + +absl::Status PyArray::BlockUntilResultStatusIsReady() { + auto& result_status = GetStorage().result_status; + // If the result_status future is not valid, this result did not come directly + // from a computation that returns tokens, so we don't wait for the status. + if (!result_status.IsValid()) { + return absl::OkStatus(); + } + absl::Status status; + if (!result_status.IsReady()) { + // Only release the gil if we need to Await(). + nb::gil_scoped_release release_gil; + BlockUntilReadyWithCancel(result_status); + status = result_status.Await(); + } else { + status = result_status.Await(); + } + // `result_status` originates from `ifrt::ExecuteResult::status`, which can + // reference an asynchronously propagated `ifrt::UserContext` representing the + // context of an error. We expand this future result right before returning it + // to Python (outside of `nb::gil_scoped_release`) so that any attached user + // context is appended to the status message. + return xla::ifrt::ExpandUserContexts(std::move(status)); +} + +absl::StatusOr> +PyArray::SingleDeviceArrayToNumpyArrayDidCopy() { + TF_ASSIGN_OR_RETURN(auto arr, FullyReplicatedShard()); + auto result = arr.GetStorage().host_value.AsNumPyArray( + arr.GetStorage().dynamic_shape, arr.ifrt_array()); + TF_RETURN_IF_ERROR(arr.BlockUntilResultStatusIsReady()); + return result; +} + +absl::StatusOr PyArray::SingleDeviceArrayToNumpyArray() { + TF_ASSIGN_OR_RETURN(auto result, SingleDeviceArrayToNumpyArrayDidCopy()); + return result.first; +} + +absl::Status PyArray::CopySingleDeviceArrayToHostAsync() { + TF_ASSIGN_OR_RETURN(auto arr, FullyReplicatedShard()); + return arr.GetStorage().host_value.CopyToHostAsync( + arr.GetStorage().dynamic_shape, arr.ifrt_array()); +} + +absl::StatusOr PyArray::AssertUnsharded(std::string_view api) { + if (ifrt_array() == nullptr) { + return xla::InvalidArgument("%s( called on deleted or donated buffer", api); + } + + if (llvm::isa(&ifrt_array()->sharding())) { + return *this; + } + + auto& py_arrays = py_arrays_cached(); + if (py_arrays.size() != 1) { + return xla::InvalidArgument("%s() is supported only for unsharded arrays.", + api); + } + return py_arrays[0]; +} + +absl::StatusOr PyArray::UnsafeBufferPointer() { + TF_ASSIGN_OR_RETURN(auto arr, AssertUnsharded("UnsafeBufferPointer")); + + return py_client()->pjrt_client()->UnsafeBufferPointer( + GetPjrtBuffer(arr.ifrt_array())); +} + +nb::dict PyArray::CudaArrayInterface() { + auto arr_or_error = AssertUnsharded("UnsafeBufferPointer"); + if (!arr_or_error.ok()) { + throw nb::attribute_error( + "__cuda_array_interface__ is only supported for unsharded arrays."); + } + auto arr = *arr_or_error; + + ifrt::Array* ifrt_array = arr.ifrt_array(); + std::optional& scratch = arr.GetStorage().dynamic_shape; + auto* pjrt_buffer = GetPjrtBuffer(ifrt_array); + if (pjrt_buffer->client()->platform_id() != xla::CudaId() && + pjrt_buffer->client()->platform_id() != xla::RocmId()) { + throw nb::attribute_error( + "__cuda_array_interface__ is only defined for GPU buffers."); + } + if (pjrt_buffer->IsTuple()) { + throw nb::attribute_error( + "__cuda_array_interface__ is only defined for array buffers."); + } + + switch (pjrt_buffer->element_type()) { + case xla::PrimitiveType::PRED: + case xla::PrimitiveType::S8: + case xla::PrimitiveType::S16: + case xla::PrimitiveType::S32: + case xla::PrimitiveType::S64: + case xla::PrimitiveType::U8: + case xla::PrimitiveType::U16: + case xla::PrimitiveType::U32: + case xla::PrimitiveType::U64: + case xla::PrimitiveType::F16: + case xla::PrimitiveType::F32: + case xla::PrimitiveType::F64: + case xla::PrimitiveType::C64: + case xla::PrimitiveType::C128: + break; + + default: + throw nb::attribute_error( + absl::StrFormat( + "__cuda_array_interface__ is not supported for %s buffers.", + PrimitiveType_Name(pjrt_buffer->element_type())) + .c_str()); + } + + nb::str typestr = xla::ValueOrThrow( + TypeDescriptorForPrimitiveType(pjrt_buffer->element_type())); + + // TODO(b/327524065): use xla::PjRtLayout directly instead of xla::Layout + xla::Layout xla_layout = pjrt_buffer->layout()->xla_layout(); + if (!xla::LayoutUtil::IsMonotonicWithDim0Major(xla_layout)) { + throw nb::attribute_error( + "__cuda_array_interface__ is only currently supported for " + "buffers in row-major order."); + } + + nb::dict result; + const auto* dynamic_shape = + xla::ValueOrThrow(XlaDynamicShape(ifrt_array, scratch)); + result["shape"] = xla::SpanToNbTuple(dynamic_shape->dimensions()); + result["typestr"] = std::move(typestr); + std::unique_ptr external_reference_hold = + xla::ValueOrThrow(pjrt_buffer->AcquireExternalReference()); + const void* root_ptr = + external_reference_hold->OpaqueDeviceMemoryDataPointer(); + nb::tuple data = + nb::make_tuple(nb::int_(absl::bit_cast(root_ptr)), + nb::bool_(true) /* read-only */ + ); + result["data"] = std::move(data); + result["version"] = nb::int_(2); + return result; +} + +absl::StatusOr CudaArrayInterfaceToBuffer( + const nb::dict& cai, nb_class_ptr client, + std::optional device_id) { + if (!cai.contains("data")) { + return absl::InvalidArgumentError( + "CUDA Array Interface does not define `data`"); + } + if (!cai.contains("shape")) { + return absl::InvalidArgumentError( + "CUDA Array Interface does not define `shape`"); + } + if (!cai.contains("typestr")) { + return absl::InvalidArgumentError( + "CUDA Array Interface does not define `typestr`"); + } + if (!cai.contains("version")) { + return absl::InvalidArgumentError( + "CUDA Array Interface does not define `version`"); + } + auto version = nb::cast(cai["version"]); + if (version < 2 || version > 3) { + LOG(WARNING) << "CUDA Array Interface version " << version + << " support is undefined"; + } + auto data = nb::cast(cai["data"]); + auto data_value = nb::cast(data[0]); + void* data_ptr = reinterpret_cast(data_value); + auto dimensions = nb::cast>(cai["shape"]); + if (data_value == 0 && absl::c_find(dimensions, 0) == dimensions.end()) { + return absl::InvalidArgumentError( + "CUDA Array Interface `data`(=NULL) and `shape`(no zero-valued " + "dimensions) are inconsistent"); + } + auto ndim = dimensions.size(); + TF_ASSIGN_OR_RETURN( + xla::PrimitiveType element_type, + DtypeToPrimitiveType(xla::nb_dtype::from_args(cai["typestr"]))); + + if (!device_id.has_value()) { + throw xla::XlaRuntimeError( + "This operation requires CUDA support from jaxlib or jax cuda plugin."); + } + TF_ASSIGN_OR_RETURN(auto device, + client->DeviceFromLocalHardwareId(*device_id)); + bool is_default_stream = + data_value == 0 || version == 2 || + (version == 3 && (!cai.contains("stream") || cai["stream"].is_none())); + TF_ASSIGN_OR_RETURN( + std::intptr_t stream, + ([is_default_stream, cai, device]() -> absl::StatusOr { + if (is_default_stream) { + return device->GetStreamForExternalReadyEvents(); + } else { + auto stream_ = nb::cast(cai["stream"]); + if (stream_ == 0) { + return absl::InvalidArgumentError( + "CUDA Array Interface does not allow zero stream value"); + } + return stream_; + } + }())); + + bool has_custom_layout; + std::vector minor_to_major(ndim); + if (cai.contains("strides") && !cai["strides"].is_none() && data_value != 0) { + std::iota(minor_to_major.begin(), minor_to_major.end(), 0); + auto strides = nb::cast>(cai["strides"]); + if (strides.size() != ndim) { + return absl::InvalidArgumentError( + "CUDA Array Interface `shape` and `strides` dimensionalities are " + "inconsistent"); + } + has_custom_layout = true; + absl::c_sort(minor_to_major, [&](int a, int b) { + // If two dimensions have the same stride, prefer the major-to-minor + // interpretation of the ordering, since that's what JAX wants. + return (strides[a] == strides[b] ? b < a : strides[a] < strides[b]); + }); + int64_t stride = xla::ShapeUtil::ByteSizeOfPrimitiveType(element_type); + for (int64_t d : minor_to_major) { + if (dimensions[d] > 1 && strides[d] != stride) { + return absl::UnimplementedError(absl::StrCat( + "Only arrays with trivial (compact) striding are supported; " + "i.e., arrays whose striding represents a transposition of the " + "underlying buffer but not broadcasting. Dimensions were: [%s], " + "strides were [%s].", + absl::StrJoin(dimensions, ","), absl::StrJoin(strides, ","))); + } + stride *= dimensions[d]; + } + } else { + has_custom_layout = false; + std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0); + } + xla::Shape shape = xla::ShapeUtil::MakeShapeWithDenseLayout( + element_type, dimensions, minor_to_major); + std::function on_delete_callback = []() {}; + auto* pjrt_device = + llvm::dyn_cast_or_null(device->device()); + if (pjrt_device == nullptr) { + return xla::InvalidArgument( + "This operation is implemented for a PjRt-compatible backend only."); + } + TF_RET_CHECK(pjrt_device->IsAddressable()); + TF_ASSIGN_OR_RETURN( + auto pjrt_buffer, + device->client()->pjrt_client()->CreateViewOfDeviceBuffer( + static_cast(data_ptr), shape, + *pjrt_device->pjrt_device()->default_memory_space(), + on_delete_callback, + stream <= 2 ? std::nullopt : std::make_optional(stream))); + auto* ifrt_client = + llvm::dyn_cast_or_null(client->ifrt_client()); + if (ifrt_client == nullptr) { + throw xla::XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + PyUserContextScope user_context_scope; + TF_ASSIGN_OR_RETURN( + auto ifrt_array, + ifrt_client->CreatePjRtArray(std::move(pjrt_buffer), has_custom_layout)); + return PyArray::MakeFromSingleDeviceArray(std::move(client), + std::move(ifrt_array), false, true); +} + +absl::Status PyArray::Delete() { + for (auto& arr : py_arrays()) { + TF_RETURN_IF_ERROR(arr.Delete()); + } + py_arrays().clear(); + if (ifrt_array() != nullptr) { + // We do not wait for the deletion to complete here. + // + // (1) Skipping blocking does not affect the correctness of deletion as long + // as the runtime preserves dispatch ordering of deletion w.r.t. other + // operations. + // + // (2) Synchronously waiting for the deletion to complete is very expensive + // when the deletion can return a status only after the underlying physical + // buffer has been deleted or a request must be processed via RPC, + // especially as this deletion is done per array. + ifrt_array()->Delete(); + SetIfrtArray(ifrt::ArrayRef()); + } + return absl::OkStatus(); +} + +bool PyArray::IsDeleted() const { + if (ifrt_array() == nullptr) { + return true; + } + + return ifrt_array()->IsDeleted(); +} + +PyArray PyArray::Clone() const { + auto array = tsl::FormRef(ifrt_array()); + auto* ifrt_client = py_client()->ifrt_client(); + // Use the user context of this array. + xla::ifrt::UserContextScope user_context_scope(array->user_context()); + ifrt::ArrayRef out = + ifrt_client + ->CopyArrays(absl::MakeSpan(&array, 1), /*devices=*/std::nullopt, + /*memory_kind=*/std::nullopt, + ifrt::ArrayCopySemantics::kReuseInput) + .value() + .front(); + return PyArray(aval(), weak_type(), dtype(), + std::vector(shape().begin(), shape().end()), + sharding(), py_client(), std::move(out), committed(), + /*skip_checks=*/true, result_status()); +} + +nb::handle PyArray::Storage::AsHandle() { + return reinterpret_cast(reinterpret_cast(this) - + offsetof(PyArrayObject, array_storage)); +} + +PyArray::Storage::~PyArray_Storage() { + CHECK(PyGILState_Check()); + if (py_client) { + PyClient::ArraysShard& shard = py_client->arrays_[thread_id_bucket]; + nanobind::ft_lock_guard lock(shard.mutex); + if (shard.arrays == this) { + shard.arrays = next; + } + if (prev) { + prev->next = next; + } + if (next) { + next->prev = prev; + } + } + // Release GIL and then explicitly destroy `ifrt_array` to prevent deadlock on + // CPU backend caused by interactions between argument donations and host + // callbacks. + nb::gil_scoped_release gil_release; + ifrt_array.reset(); +} + +absl::StatusOr> PyArray::BatchedCopyToDeviceWithSharding( + absl::Span py_arrays, + absl::Span dst_device_lists, + absl::Span dst_shardings, + absl::Span array_copy_semantics) { + if (py_arrays.empty()) { + return std::vector(); + } + + TF_RET_CHECK(py_arrays.size() == dst_device_lists.size()); + TF_RET_CHECK(py_arrays.size() == dst_shardings.size()); + + ifrt::Client* const client = py_arrays.front().ifrt_array()->client(); + std::vector results(py_arrays.size()); + + // Arrays to be copied, grouped by source/destination devices and memory + // kinds. The grouping is enforced by `ifrt::Client::CopyArrays()`. + struct Batch { + std::vector indexes; + std::vector ifrt_arrays; + }; + absl::flat_hash_map batches; + + PyUserContextScope user_context_scope; + { + tsl::profiler::TraceMe results_traceme( + "BatchedCopyToDeviceWithSharding create batch"); + for (int i = 0; i < py_arrays.size(); ++i) { + const auto& py_array = py_arrays[i]; + const auto& dst_sharding = dst_shardings[i]; + const auto& array_cs = array_copy_semantics[i]; + + auto* ifrt_array_ptr = py_array.ifrt_array(); + const ifrt::DeviceListRef& src_devices = + ifrt_array_ptr->sharding().devices(); + const ifrt::DeviceListRef& dst_devices = dst_device_lists[i]; + + ifrt::MemoryKind src_memory_kind = + ifrt::CanonicalizeMemoryKind(ifrt_array_ptr->sharding().memory_kind(), + src_devices->devices().front()); + ifrt::MemoryKind dst_memory_kind = ifrt::CanonicalizeMemoryKind( + GetMemoryKind(dst_sharding), dst_devices->devices().front()); + + if (*src_devices == *dst_devices && src_memory_kind == dst_memory_kind && + array_cs == ifrt::ArrayCopySemantics::kReuseInput) { + if (py_array.sharding().equal(dst_sharding)) { + results[i] = py_arrays[i]; + } else { + absl::Span shape_span = py_array.shape(); + // We can reuse the input array despite the sharding being different. + // This is because this code expects no resharding is necessary, which + // has been verified by the code invoking this method. + results[i] = PyArray( + py_array.aval(), py_array.weak_type(), py_array.dtype(), + std::vector(shape_span.begin(), shape_span.end()), + dst_sharding, py_array.py_client(), tsl::FormRef(ifrt_array_ptr), + py_array.committed(), + /*skip_checks=*/true, py_array.result_status()); + } + continue; + } + + auto transfer_guard_formatter = [&py_array, &dst_sharding] { + return absl::StrCat( + "aval=", nb::cast(nb::repr(py_array.aval())), + ", sharding=", + nb::cast(nb::repr(py_array.sharding())), + ", dst_sharding=", + nb::cast(nb::repr(dst_sharding))); + }; + TF_RETURN_IF_ERROR( + ApplyTransferGuardToDeviceToDevice(transfer_guard_formatter)); + + Batch& batch = batches[BatchedCopyToDeviceWithShardingKey{ + src_devices, src_memory_kind, dst_devices, dst_memory_kind, + array_cs}]; + batch.indexes.push_back(i); + batch.ifrt_arrays.push_back(tsl::FormRef(ifrt_array_ptr)); + } + } + + std::vector> ifrt_arrays; + { + GlobalPyRefManager()->CollectGarbage(); + nb::gil_scoped_release gil_release; + + tsl::profiler::TraceMe copy_traceme( + "BatchedCopyToDeviceWithSharding: dispatch"); + for (auto& [key, batch] : batches) { + TF_ASSIGN_OR_RETURN( + auto copied, + client->CopyArrays( + absl::MakeSpan(batch.ifrt_arrays), + // All arrays in `batch` have the same `key.dst_devices` and + // `key.dst_memory_kind` due to the grouping above. + key.dst_devices, key.dst_memory_kind, key.array_copy_semantics)); + for (int i = 0; i < batch.indexes.size(); ++i) { + ifrt_arrays.push_back( + std::make_pair(batch.indexes[i], std::move(copied[i]))); + } + } + } + + tsl::profiler::TraceMe results_traceme( + "BatchedCopyToDeviceWithSharding create results"); + for (auto& [i, ifrt_array] : ifrt_arrays) { + TF_ASSIGN_OR_RETURN(nb_class_ptr dst_device_list, + GetPyDeviceList(dst_shardings[i])); + nb_class_ptr py_client = dst_device_list->py_client(); + const auto& py_array = py_arrays[i]; + absl::Span shape_span = py_array.shape(); + results[i] = + PyArray(py_array.aval(), py_array.weak_type(), py_array.dtype(), + std::vector(shape_span.begin(), shape_span.end()), + dst_shardings[i], py_client, std::move(ifrt_array), + py_array.committed(), + /*skip_checks=*/true, py_array.result_status()); + } + return results; +} + +absl::StatusOr PyArray::BatchedDevicePut( + nb::object aval, nb::object sharding, std::vector xs, + absl::Span dst_devices, bool committed, + bool force_copy, xla::PjRtClient::HostBufferSemantics host_buffer_semantics, + bool jax_enable_x64) { + if (dst_devices.size() != xs.size()) { + throw nb::value_error( + absl::StrCat("Argument sizes (xs and devices) must match %zu vs %zu", + dst_devices.size(), xs.size()) + .c_str()); + } + for (const PyDevice* device : dst_devices) { + if (device->client().get() == nullptr) { + return xla::InvalidArgument("Cannot copy to unattached devices."); + } + } + auto transfer_guard_formatter = [&aval, &sharding] { + return absl::StrCat( + "aval=", nb::cast(nb::repr(aval)), + ", dst_sharding=", nb::cast(nb::repr(sharding))); + }; + + GlobalPyRefManager()->CollectGarbage(); + + PyUserContextScope user_context_scope; + auto n_devices = dst_devices.size(); + + DevicePutOptions options; + options.squash_64bit_types = !jax_enable_x64; + options.allow_zero_copy = + (!force_copy && (host_buffer_semantics == + ifrt::Client::HostBufferSemantics::kImmutableZeroCopy)); + + std::vector ifrt_arrays; + + absl::InlinedVector devices; + devices.reserve(n_devices); + std::vector shapes; + shapes.reserve(n_devices); + + std::vector args; + args.reserve(xs.size()); + for (const nb::object& x : xs) { + if (PyArray::IsPyArray(x)) { + TF_RETURN_IF_ERROR( + ApplyTransferGuardToDeviceToDevice(transfer_guard_formatter)); + } else { + TF_RETURN_IF_ERROR( + ApplyTransferGuardToHostToDevice(transfer_guard_formatter)); + } + args.push_back(x); + } + auto weak_type = nb::cast(aval.attr("weak_type")); + auto dtype = aval.attr("dtype"); + auto shape = nb::cast>(aval.attr("shape")); + TF_ASSIGN_OR_RETURN(nb_class_ptr py_device_list, + GetPyDeviceList(sharding)); + + TF_ASSIGN_OR_RETURN( + DevicePutResult device_put_result, + DevicePutWithSharding(args, py_device_list->py_client()->ifrt_client(), + dtype, shape, sharding, options)); + + return PyArray(aval, weak_type, dtype, std::move(shape), std::move(sharding), + py_device_list->py_client(), + std::move(device_put_result.ifrt_array), committed, + /*skip_checks=*/true); +} + +absl::StatusOr PyArray::ReorderShards( + PyArray x, nanobind::object dst_sharding, + ifrt::ArrayCopySemantics array_copy_semantics) { + xla::ifrt::Array* ifrt_array_ptr = x.ifrt_array(); + if (ifrt_array_ptr == nullptr) { + return absl::InvalidArgumentError( + "Reorder() called on deleted or donated buffer"); + } + + ifrt::Client* const client = ifrt_array_ptr->client(); + + const auto& device_list = ifrt_array_ptr->sharding().devices(); + TF_ASSIGN_OR_RETURN(auto dst_device_list, GetIfrtDeviceList(dst_sharding)); + if (device_list->AddressableDeviceList()->size() != + dst_device_list->AddressableDeviceList()->size()) { + return absl::InvalidArgumentError(absl::StrCat( + "Array is expected to have ", + dst_device_list->AddressableDeviceList()->size(), + " addressable shards, but has ", + device_list->AddressableDeviceList()->size(), " addressable shards")); + } + + TF_ASSIGN_OR_RETURN( + xla::ifrt::ShardingRef dst_ifrt_sharding, + GetIfrtConcreteEvenSharding(dst_sharding, ifrt_array_ptr->dtype(), + ifrt_array_ptr->shape())); + + // Use the user context of this array. + xla::ifrt::UserContextScope user_context_scope( + ifrt_array_ptr->user_context()); + + xla::ifrt::ArrayRef new_ifrt_array; + { + nb::gil_scoped_release gil_release; + + const absl::Span addressable_devices = + device_list->AddressableDeviceList()->devices(); + const absl::Span dst_addressable_devices = + dst_device_list->AddressableDeviceList()->devices(); + + absl::flat_hash_map device_id_to_array_shard_index; + device_id_to_array_shard_index.reserve(dst_addressable_devices.size()); + for (int i = 0; i < dst_addressable_devices.size(); ++i) { + const int device_id = dst_addressable_devices[i]->Id().value(); + const bool inserted = + device_id_to_array_shard_index.insert({device_id, i}).second; + if (!inserted) { + return absl::InvalidArgumentError( + absl::StrCat("Sharding contains duplicate device id=", device_id)); + } + } + + std::vector from_shard_indices; + from_shard_indices.reserve(addressable_devices.size()); + std::vector to_shard_indices; + to_shard_indices.reserve(dst_addressable_devices.size()); + for (int i = 0; i < dst_addressable_devices.size(); ++i) { + from_shard_indices.push_back(i); + const int shard_device_id = addressable_devices[i]->Id().value(); + const auto it = device_id_to_array_shard_index.find(shard_device_id); + if (it == device_id_to_array_shard_index.end()) { + return absl::InvalidArgumentError(absl::StrCat( + "Array shard ", i, " is on device id=", shard_device_id, + ", but sharding does not have a shard on that device.")); + } + to_shard_indices.push_back(it->second); + } + + auto mappings = + std::make_shared>(); + { + auto& mapping = mappings->emplace_back(); + mapping.in_array = 0; + mapping.out_array = 0; + mapping.from.reserve(dst_addressable_devices.size()); + mapping.to.reserve(dst_addressable_devices.size()); + for (int64_t i = 0; i < dst_addressable_devices.size(); ++i) { + mapping.from.push_back(xla::ifrt::RemapPlan::Interval{ + from_shard_indices[i], from_shard_indices[i] + 1, 1}); + mapping.to.push_back(xla::ifrt::RemapPlan::Interval{ + to_shard_indices[i], to_shard_indices[i] + 1, 1}); + } + } + + xla::ifrt::RemapPlan plan = { + /*input_specs=*/{xla::ifrt::ArraySpec{ + /*dtype=*/ifrt_array_ptr->dtype(), + /*shape=*/ifrt_array_ptr->shape(), + /*sharding=*/ifrt_array_ptr->shared_ptr_sharding()}}, + /*output_specs=*/ + {xla::ifrt::ArraySpec{/*dtype=*/ifrt_array_ptr->dtype(), + /*shape=*/ifrt_array_ptr->shape(), + /*sharding=*/std::move(dst_ifrt_sharding)}}, + /*mappings=*/std::move(mappings), + }; + DCHECK_OK(plan.Validate()); + std::vector input; + input.push_back(tsl::FormRef(ifrt_array_ptr)); + TF_ASSIGN_OR_RETURN( + auto remapped, + client->RemapArrays(plan, absl::MakeSpan(input), array_copy_semantics)); + + TF_RET_CHECK(remapped.size() == 1); + new_ifrt_array = std::move(remapped.front()); + } + + return PyArray(nb::borrow(x.aval().ptr()), x.weak_type(), + nb::borrow(x.dtype().ptr()), + std::vector(x.shape().begin(), x.shape().end()), + std::move(dst_sharding), x.py_client(), + std::move(new_ifrt_array), + /*committed=*/true, + /*skip_checks=*/true); +} + +absl::Status PyArray::BatchedBlockUntilReady(std::vector objs) { + // Create ready futures for all arrays before blocking on their readiness. + // This helps reduce the latency in some backend implementations where + // querying readiness of an array is not free. + + std::vector ifrt_arrays; + ifrt_arrays.reserve(objs.size()); + for (nb::handle obj : objs) { + if (obj.type().is(PyArray::type())) { + auto py_array = nb::borrow(obj); + ifrt::Array* const ifrt_array = py_array.ifrt_array(); + if (ifrt_array == nullptr) { + return absl::InvalidArgumentError( + "BlockHostUntilReady() called on deleted or donated buffer"); + } + ifrt_arrays.push_back(ifrt_array); + } else { + return absl::InvalidArgumentError( + "PyArray::BatchedBlockUntilReady can take PyArray only"); + } + } + + GlobalPyRefManager()->CollectGarbage(); + PyUserContextScope user_context_scope; + absl::Status status; + { + nb::gil_scoped_release gil_release; + status = AwaitBuffersReady(absl::MakeConstSpan(ifrt_arrays)); + } + // `status` can reference an asynchronously propagated `ifrt::UserContext` + // representing the context of an error. We expand this future result right + // before returning it to Python (outside of `nb::gil_scoped_release`) so that + // any attached user context is appended to the status message. + return xla::ifrt::ExpandUserContexts(std::move(status)); +} + +absl::Status PyArray::ReplaceWithAlias(PyArray o) { + auto& storage = GetStorage(); + auto& o_storage = o.GetStorage(); + if (storage.py_client.get() != o_storage.py_client.get()) { + return absl::InvalidArgumentError( + "Unable to replace a PyArray with a PyArray from a different client."); + } + storage.aval = o_storage.aval; + storage.weak_type = o_storage.weak_type; + storage.dtype = o_storage.dtype; + storage.shape = o_storage.shape; + storage.sharding = o_storage.sharding; + storage.npy_value = o_storage.npy_value; + storage.committed = o_storage.committed; + storage.ifrt_array = o_storage.ifrt_array; + storage.fully_replicated_array = o_storage.fully_replicated_array; + storage.py_arrays = o_storage.py_arrays; + storage.host_value.Clear(); + storage.dynamic_shape = o_storage.dynamic_shape; + storage.result_status = o_storage.result_status; + + return absl::OkStatus(); +} + +std::vector PyClient::LiveArrays() const { + std::vector result; + for (auto& shard : arrays_) { + nb::ft_lock_guard lock(shard.mutex); + for (PyArray::Storage* array = shard.arrays; array; array = array->next) { + bool all_deleted = + (array->ifrt_array == nullptr || array->ifrt_array->IsDeleted()); + if (!all_deleted) { + result.push_back(nb::borrow(array->AsHandle())); + } + } + } + return result; +} + +// PEP 3118 buffer protocol implementation. + +namespace { + +// Extra data to be kept alive by the consumer of the buffer protocol. +struct ExtraBufferInfo { + explicit ExtraBufferInfo(std::shared_ptr buffer, + std::unique_ptr + external_reference_hold) + : buffer(std::move(buffer)), + external_reference_hold(std::move(external_reference_hold)) {} + + std::vector strides; + // We keep an external reference hold to the xla::PjRtBuffer. This prevents a + // use-after-free in the event that Delete() is called on a buffer with an + // live buffer protocol view. It does however mean that Delete() sometimes + // won't actually delete immediately. + std::shared_ptr buffer; + std::unique_ptr external_reference_hold; +}; + +// The default layout of a non-tuple array should have major-to-minor layout +// and no tiles. +bool HasDefaultLayout(const xla::Layout& layout) { + return xla::LayoutUtil::IsMonotonicWithDim0Major(layout) && + layout.tiles().empty(); +} + +int PyArray_bf_getbuffer(PyObject* exporter, Py_buffer* view, int flags) { + absl::Status status = [&]() -> absl::Status { + PyArray py_array = nb::borrow(exporter); + if (py_array.ifrt_array() == nullptr) { + // TODO(phawkins): why is this happening? + return xla::InvalidArgument("Array is null"); + } + if (!llvm::isa(py_array.ifrt_array())) { + return xla::InvalidArgument("Only local arrays are supported, got %s", + py_array.ifrt_array()->DebugString()); + } + auto* array = + static_cast(py_array.ifrt_array()); + absl::Span> buffers = + array->pjrt_buffers(); + + if (buffers.empty()) { + return xla::InvalidArgument("Array has no buffers."); + } + xla::PjRtBuffer& buffer = *buffers.front(); + if (!buffer.IsOnCpu()) { + return xla::InvalidArgument( + "Python buffer protocol is only defined for CPU buffers."); + } + + if (buffers.size() != 1) { + return xla::InvalidArgument( + "Python buffer protocol is only defined for buffers with a single " + "shard."); + } + if (!py_array.sharding().type().is(SingleDeviceSharding::type())) { + return xla::InvalidArgument( + "Python buffer protocol is only defined for single-device sharded " + "buffers."); + } + + const char* format = + PEP3118FormatDescriptorForPrimitiveType(buffer.element_type()); + // It isn't an option for us to export unknown types as, say, bytes. When + // converting an object to an ndarray, NumPy tries the buffer protocol + // first. We very much want NumPy to fail and fall back to using + // __array__, which allows us to handle custom dtypes correctly. + if (!format) { + return xla::InvalidArgument( + "Buffers of type %s are not supported by the Python buffer protocol.", + PrimitiveType_Name(buffer.element_type())); + } + + std::unique_ptr external_reference_hold; + { + // We call BlockHostUntilReady() below, which may block. + nb::gil_scoped_release gil_release; + + if (buffer.IsTuple()) { + return xla::InvalidArgument( + "Python buffer protocol is only defined for array buffers."); + } + if ((flags & PyBUF_WRITEABLE) == PyBUF_WRITEABLE) { + return xla::InvalidArgument("XLA buffers are read-only."); + } + TF_ASSIGN_OR_RETURN(external_reference_hold, + buffer.AcquireExternalReference()); + if (buffer.IsDeleted()) { + return xla::InvalidArgument("Deleted buffer used in buffer protocol."); + } + + // TODO(b/327524065): use xla::PjRtLayout directly instead of xla::Layout + xla::Layout xla_layout = buffer.layout()->xla_layout(); + + if (((flags & PyBUF_C_CONTIGUOUS) == PyBUF_C_CONTIGUOUS || + (flags & PyBUF_STRIDES) == PyBUF_ND) && + !xla::LayoutUtil::IsMonotonicWithDim0Major(xla_layout)) { + return xla::InvalidArgument("Buffer is not in C-contiguous layout."); + } else if ((flags & PyBUF_F_CONTIGUOUS) == PyBUF_F_CONTIGUOUS && + !xla::LayoutUtil::IsMonotonicWithDim0Minor(xla_layout)) { + return xla::InvalidArgument("Buffer is not in F-contiguous layout."); + } else if ((flags & PyBUF_ANY_CONTIGUOUS) == PyBUF_ANY_CONTIGUOUS && + !xla::LayoutUtil::IsMonotonicWithDim0Major(xla_layout) && + !xla::LayoutUtil::IsMonotonicWithDim0Minor(xla_layout)) { + return xla::InvalidArgument("Buffer is not in contiguous layout."); + } else if (!HasDefaultLayout(xla_layout)) { + // Fail and fall back to using __array__ if the CPU buffer has a device + // specific layout. For instance, this happens for host buffers in + // pinned memories of the TPU device. + return xla::InvalidArgument( + "Buffer is potentially a device buffer with non default layout."); + } + TF_RETURN_IF_ERROR(buffer.GetReadyFuture().Await()); + } + + // We must hold the GIL (or at least prevent Python GC) while writing to the + // view object, see https://github.com/python/cpython/issues/130409. + std::memset(view, 0, sizeof(Py_buffer)); + const void* root_ptr = + external_reference_hold->OpaqueDeviceMemoryDataPointer(); + view->buf = const_cast(root_ptr); + auto extra = std::make_unique( + buffers.front(), std::move(external_reference_hold)); + view->itemsize = + xla::ShapeUtil::ByteSizeOfPrimitiveType(buffer.element_type()); + TF_ASSIGN_OR_RETURN(view->len, buffer.GetOnDeviceSizeInBytes()); + view->readonly = 1; + if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) { + view->format = const_cast(format); + } + if ((flags & PyBUF_ND) == PyBUF_ND) { + view->ndim = buffer.dimensions().size(); + static_assert(sizeof(int64_t) == sizeof(Py_ssize_t), + "Py_ssize_t must be 64 bits"); + if (view->ndim != 0) { + view->shape = reinterpret_cast( + const_cast(buffer.dimensions().data())); + if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) { + extra->strides = + ByteStridesForShape(buffer.element_type(), buffer.dimensions(), + buffer.layout()->xla_layout()); + view->strides = reinterpret_cast( + const_cast(extra->strides.data())); + } + } + } + view->internal = extra.release(); + return absl::OkStatus(); + }(); + if (!status.ok()) { + // numpy.asarray(...) eats the PyExc_BufferError. Adding a log here helps + // debugging when the error really occurs. + VLOG(1) << "Buffer Protocol Error: " << status; + PyErr_SetString(PyExc_BufferError, status.ToString().c_str()); + return -1; + } + view->obj = exporter; + Py_INCREF(view->obj); + return 0; +} + +void PyArray_bf_releasebuffer(PyObject*, Py_buffer* buffer) { + auto extra = static_cast(buffer->internal); + delete extra; +} + +// Returns if shape has a major-to-minor layout. +bool HasMajorToMinorLayout(const xla::Shape& shape) { + if (shape.has_layout()) { + for (int i = 0; i < shape.layout().minor_to_major().size(); ++i) { + if (shape.layout().minor_to_major(i) != + shape.layout().minor_to_major().size() - 1 - i) { + return false; + } + } + } + return true; +} + +// Returns byte_strides if shape has a non-major-to-minor layout. +std::optional> ByteStridesOrDefaultForShapeInt64( + const xla::Shape& shape) { + if (!shape.has_layout() || HasMajorToMinorLayout(shape)) { + return std::nullopt; + } + return ByteStridesForShape(shape); +} + +bool IsZeroCopyableCpuBuffer(const xla::PjRtBuffer* buf) { + // For CPU buffers with device-specific layouts, we must delinearize + // to unpack the array. This could happen for the host buffer + // pre-mapped to the TPU device, a.k.a., pinned host buffers for the + // device. + bool has_default_layout = + buf->layout() == nullptr || HasDefaultLayout(buf->layout()->xla_layout()); + // On CPU for values >= 8 bits, we can return the value in a zero-copy way. + // For sub-byte values, we must copy in order to unpack the array. + return buf->IsOnCpu() && + !xla::primitive_util::IsSubByteNonPredType(buf->element_type()) && + has_default_layout; +} +} // namespace + +PyHostValue::PyHostValue() = default; +PyHostValue::~PyHostValue() = default; + +absl::StatusOr> PyHostValue::AsNumPyArray( + std::optional& dynamic_shape_holder, ifrt::Array* ifrt_array) { + if (ifrt_array->IsDeleted()) { + return xla::InvalidArgument("DeviceArray has been deleted."); + } + // The only `jax.Array` with token-shape buffer is the one wrapped by + // `jax.core.Token`. Since it is an internal implementation detail, we + // don't support converting it to a numpy array. + if (ifrt_array->dtype().kind() == ifrt::DType::kToken) { + return xla::InvalidArgument( + "Cannot convert a token-shape buffer to a numpy array."); + } + auto* arr = llvm::dyn_cast_or_null(ifrt_array); + if (arr != nullptr) { + auto* pjrt_buffer = arr->pjrt_buffers().front().get(); + TF_RET_CHECK(!pjrt_buffer->IsTuple()); + // On CPU for values >= 8 bits, we can return the value in a zero-copy way. + // For sub-byte values, we must copy in order to unpack the array. + if (IsZeroCopyableCpuBuffer(pjrt_buffer)) { + TF_ASSIGN_OR_RETURN(const auto* shape, + XlaDynamicShape(ifrt_array, dynamic_shape_holder)); + TF_ASSIGN_OR_RETURN(xla::nb_dtype dtype, + PrimitiveTypeToNbDtype(shape->element_type())); + // Objects that must be kept alive while the array is alive. + struct Hold { + ifrt::ArrayRef buffer; + std::unique_ptr + external_reference_hold; + }; + auto hold = std::make_unique(); + hold->buffer = tsl::FormRef(ifrt_array); + auto* hold_ptr = hold.release(); + nb::capsule hold_capsule( + hold_ptr, [](void* h) noexcept { delete static_cast(h); }); + { + // Release the GIL as `AcquireExternalReference` may block. + nb::gil_scoped_release gil; + TF_ASSIGN_OR_RETURN(hold_ptr->external_reference_hold, + pjrt_buffer->AcquireExternalReference()); + auto fut = ifrt_array->GetReadyFuture(); + BlockUntilReadyWithCancel(fut); + TF_RETURN_IF_ERROR(fut.Await()); + } + void* data = + hold_ptr->external_reference_hold->OpaqueDeviceMemoryDataPointer(); + xla::nb_numpy_ndarray array(dtype, shape->dimensions(), + ByteStridesForShape(*shape), data, + hold_capsule); + array.attr("flags").attr("writeable") = nb::bool_(false); + return std::make_pair(array, false); + } + } + + PyUserContextScope user_context_scope; + TF_RETURN_IF_ERROR(CopyToHostAsync(dynamic_shape_holder, ifrt_array)); + absl::Status status; + if (!ready_.IsReady()) { + nb::gil_scoped_release gil; + BlockUntilReadyWithCancel(ready_); + status = ready_.Await(); + } else { + status = ready_.Await(); + } + if (!status.ok()) { + // `ready_` is the returned future of `ifrt::Array::CopyToHostBuffer`, which + // can reference an asynchronously propagated `ifrt::UserContext` + // representing the context of an error. We expand this future result right + // before returning it to Python (outside of `nb::gil_scoped_release`) so + // that any attached user context is appended to the status message. + return xla::ifrt::ExpandUserContexts(std::move(status)); + } + if (string_array_contents_ != nullptr) { + TF_RETURN_IF_ERROR(ConvertStringArrayContentsToNumpyArray(ifrt_array)); + } + return std::make_pair(value_, true); +} + +absl::Status PyHostValue::ConvertStringArrayContentsToNumpyArray( + ifrt::Array* ifrt_array) { +#ifdef NPY_2_0_API_VERSION + if (PyArray_RUNTIME_VERSION < NPY_2_0_API_VERSION) { + return absl::FailedPreconditionError( + absl::StrCat("String arrays are not supported in NumPy version: ", + PyArray_RUNTIME_VERSION)); + } + auto numpy_dtype = nb::steal( + reinterpret_cast(PyArray_DescrFromType(NPY_VSTRING))); + value_ = xla::nb_numpy_ndarray(numpy_dtype, ifrt_array->shape().dims(), + /*strides=*/std::nullopt); + + auto dst_py_array_obj = reinterpret_cast<::PyArrayObject*>(value_.ptr()); + auto iter = + nb::steal(PyArray_IterNew(reinterpret_cast(dst_py_array_obj))); + for (auto& cord : *string_array_contents_) { + std::string_view input_str_view = cord.Flatten(); + auto py_unicode = nb::steal(PyUnicode_FromStringAndSize( + input_str_view.data(), input_str_view.size())); + if (py_unicode.ptr() == nullptr) { + return absl::InternalError("PyUnicode_FromStringAndSize failed"); + } + if (PyArray_SETITEM(dst_py_array_obj, + static_cast(PyArray_ITER_DATA(iter.ptr())), + py_unicode.ptr()) != 0) { + return absl::InternalError("PyArray_SETITEM failed"); + } + PyArray_ITER_NEXT(iter.ptr()); + } + + value_.attr("flags").attr("writeable") = nb::bool_(false); + + string_array_contents_.reset(); + + return absl::OkStatus(); +#else + return absl::FailedPreconditionError( + "String arrays are not supported in this NumPy version."); +#endif +} + +absl::Status PyHostValue::CopyStringArrayToHostAsync( + std::optional& dynamic_shape_holder, ifrt::Array* ifrt_array) { + auto transfer_guard_formatter = [ifrt_array] { + return absl::StrCat( + "shape=(", absl::StrJoin(ifrt_array->shape().dims(), ","), + "), dtype=", ifrt_array->dtype().DebugString(), ", device=", + ifrt_array->sharding().devices()->devices().front()->DebugString()); + }; + TF_RETURN_IF_ERROR( + ApplyTransferGuardToDeviceToHost(transfer_guard_formatter)); + + TF_ASSIGN_OR_RETURN(xla::nb_dtype dtype, + xla::IfrtDtypeToNbDtype(ifrt_array->dtype())); + auto shape = ifrt_array->shape(); + + // Allocate a vector of cords to hold the contents of the array until + // they are until they are ultimately converted to a numpy array as part + // of the `AsNumPyArray` call. + string_array_contents_ = + std::make_shared>(shape.num_elements()); + PyUserContextScope user_context_scope; + ready_ = ifrt_array->CopyToHostBuffer(string_array_contents_->data(), + /*byte_strides=*/std::nullopt, + ifrt::ArrayCopySemantics::kAlwaysCopy); + + ready_.OnReady( + [string_array_contents = string_array_contents_](absl::Status) { + }); // Keeps the cords alive until the copy is done. + + return absl::OkStatus(); +} + +absl::Status PyHostValue::CopyToHostAsync( + std::optional& dynamic_shape_holder, ifrt::Array* ifrt_array) { + if (ready_.IsValid()) { + // The array value has been populated, so CopyToHostAsync has been called. + return absl::OkStatus(); + } + + // Copying in Arrays of type kString requires some special handling + if (ifrt_array->dtype().kind() == ifrt::DType::kString) { + return CopyStringArrayToHostAsync(dynamic_shape_holder, ifrt_array); + } + + auto* arr = llvm::dyn_cast_or_null(ifrt_array); + if (arr != nullptr && !arr->pjrt_buffers().front()->IsTuple() && + IsZeroCopyableCpuBuffer(arr->pjrt_buffers().front().get())) { + return absl::OkStatus(); + } + auto transfer_guard_formatter = [ifrt_array] { + return absl::StrCat( + "shape=(", absl::StrJoin(ifrt_array->shape().dims(), ","), + "), dtype=", ifrt_array->dtype().DebugString(), ", device=", + ifrt_array->sharding().devices()->devices().front()->DebugString()); + }; + TF_RETURN_IF_ERROR( + ApplyTransferGuardToDeviceToHost(transfer_guard_formatter)); + + // TODO(b/182461453): This is a blocking call. If we further implemented + // populating dynamic shape metadata while fetching the literal, we wouldn't + // need this static approach. + const xla::Shape* dynamic_shape; + std::optional shape_holder; + if (llvm::isa(ifrt_array)) { + TF_ASSIGN_OR_RETURN(dynamic_shape, + XlaDynamicShape(ifrt_array, dynamic_shape_holder)); + } else { + // Skip querying the dynamic shape for a non-PjRt Array. + TF_ASSIGN_OR_RETURN(xla::PrimitiveType type, + ifrt::ToPrimitiveType(ifrt_array->dtype())); + shape_holder = xla::ShapeUtil::MakeShapeWithDescendingLayout( + type, ifrt_array->shape().dims()); + dynamic_shape = &*shape_holder; + } + + xla::Shape host_shape = + xla::ShapeUtil::DeviceShapeToHostShape(*dynamic_shape); + + auto strides = ByteStridesOrDefaultForShapeInt64(host_shape); + TF_ASSIGN_OR_RETURN(xla::nb_dtype dtype, + PrimitiveTypeToNbDtype(host_shape.element_type())); + value_ = xla::nb_numpy_ndarray(dtype, host_shape.dimensions(), strides); + // TODO(hyeontaek): Several PjRt runtimes assume that the host buffer uses + // the same transposition as the device buffer. This is different from + // xla::PjRtBuffer::ToLiteral()'s semantics that the runtime respects the + // layout of the host buffer literal. On the other hand, the runtime often + // knows better about an efficient layout for the host buffer. It will be + // useful to revisit the semantics of xla::PjRtBuffer::ToLiteral() to see if + // it is desirable for the runtime to choose the layout. + PyUserContextScope user_context_scope; + ready_ = ifrt_array->CopyToHostBuffer(value_.mutable_data(), strides, + ifrt::ArrayCopySemantics::kAlwaysCopy); + // Make sure the destination of the copy remains alive until the copy is done. + value_.inc_ref(); + ready_.OnReady([array{value_.ptr()}](absl::Status status) { + GlobalPyRefManager()->AddGarbage(nb::steal(array)); + }); + value_.attr("flags").attr("writeable") = nb::bool_(false); + return absl::OkStatus(); +} + +void PyHostValue::Clear() { + ready_ = {}; + value_ = {}; + string_array_contents_ = {}; +} + +namespace { + +PyType_Slot array_meta_slots[] = { + {Py_tp_base, &PyType_Type}, + {0, nullptr}, +}; + +PyType_Slot array_slots[] = { + {Py_tp_dealloc, reinterpret_cast(PyBaseArray_tp_dealloc)}, + {Py_tp_traverse, reinterpret_cast(PyBaseArray_tp_traverse)}, + {Py_tp_hash, reinterpret_cast(PyObject_HashNotImplemented)}, + {0, nullptr}, +}; + +PyGetSetDef array_impl_tp_getset[] = { + {"__dict__", PyObject_GenericGetDict, PyObject_GenericSetDict, nullptr, + nullptr}, + {nullptr, nullptr, nullptr, nullptr, nullptr}, +}; + +PyMemberDef array_impl_members[] = { +#if PY_VERSION_HEX < 0x030C0000 + {"__weaklistoffset__", T_PYSSIZET, + static_cast(offsetof(PyArrayObject, weakrefs)), READONLY, + nullptr}, + {"__dictoffset__", T_PYSSIZET, + static_cast(offsetof(PyArrayObject, dict)), READONLY, nullptr}, +#endif // PY_VERSION_HEX < 0x030C0000 + {nullptr, 0, 0, 0, nullptr}, +}; // namespace jax + +PyType_Slot array_impl_slots[] = { + {Py_tp_new, reinterpret_cast(PyArray_tp_new)}, + {Py_tp_dealloc, reinterpret_cast(PyArray_tp_dealloc)}, + {Py_tp_members, reinterpret_cast(array_impl_members)}, + {Py_tp_traverse, reinterpret_cast(PyArray_tp_traverse)}, + {Py_tp_clear, reinterpret_cast(PyArray_tp_clear)}, + {Py_tp_getset, reinterpret_cast(array_impl_tp_getset)}, + {Py_bf_getbuffer, reinterpret_cast(PyArray_bf_getbuffer)}, + {Py_bf_releasebuffer, reinterpret_cast(PyArray_bf_releasebuffer)}, + {0, nullptr}, +}; + +// TODO(phawkins): remove this code when we drop support for Python < 3.12 +PyObject* MakeArrayTypeFromMetaclass(PyTypeObject* meta, PyObject* module, + PyType_Spec* spec) { +#if PY_VERSION_HEX >= 0x030C0000 + return PyType_FromMetaclass(meta, module, spec, nullptr); +#else + nb::str name = nb::steal(PyUnicode_InternFromString(spec->name)); + const char* name_cstr = PyUnicode_AsUTF8AndSize(name.ptr(), nullptr); + if (!name_cstr) { + return nullptr; + } + + PyHeapTypeObject* ht = + reinterpret_cast(PyType_GenericAlloc(meta, 0)); + if (!ht) { + return nullptr; + } + ht->ht_name = name.inc_ref().ptr(); + ht->ht_qualname = name.inc_ref().ptr(); + Py_INCREF(module); + ht->ht_module = module; + + PyTypeObject* tp = &ht->ht_type; + tp->tp_name = name_cstr; + tp->tp_basicsize = spec->basicsize; + tp->tp_itemsize = spec->itemsize; + tp->tp_flags = spec->flags | Py_TPFLAGS_HEAPTYPE; + tp->tp_as_async = &ht->as_async; + tp->tp_as_number = &ht->as_number; + tp->tp_as_sequence = &ht->as_sequence; + tp->tp_as_mapping = &ht->as_mapping; + tp->tp_as_buffer = &ht->as_buffer; + + for (PyType_Slot* slot = spec->slots; slot->slot; slot++) { + switch (slot->slot) { + case Py_tp_dealloc: + tp->tp_dealloc = reinterpret_cast(slot->pfunc); + break; + case Py_tp_traverse: + tp->tp_traverse = reinterpret_cast(slot->pfunc); + break; + case Py_tp_hash: + tp->tp_hash = reinterpret_cast(slot->pfunc); + break; + default: + // TODO(phawkins): support other slots as needed. + LOG(FATAL) << "Unsupported slot: " << slot->slot; + } + } + + if (PyType_Ready(tp) != 0) { + Py_DECREF(tp); + return nullptr; + } + + return reinterpret_cast(tp); +#endif +} + +} // namespace + +absl::Status PyArray::Register(nb::module_& m) { + std::string metaclass_name = + absl::StrCat(nb::cast(m.attr("__name__")), ".ArrayMeta"); + PyType_Spec array_meta_spec = { + /*.name=*/metaclass_name.c_str(), + /*.basicsize=*/0, + /*.itemsize=*/0, + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, + /*.slots=*/array_meta_slots}; + nb::object array_meta_type = + nb::steal(PyType_FromSpec(&array_meta_spec)); + if (!array_meta_type) { + throw nb::python_error(); + } + m.attr("ArrayMeta") = array_meta_type; + + // We are not using nanobind to avoid having a non-standard metaclass, which + // would make Array incompatible with abc.ABCMeta. + std::string base_name = + absl::StrCat(nb::cast(m.attr("__name__")), ".Array"); + PyType_Spec array_spec = { + /*.name=*/base_name.c_str(), + /*.basicsize=*/static_cast(sizeof(PyBaseArrayObject)), + /*.itemsize=*/0, + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, + /*.slots=*/array_slots}; + nb::object base_type = nb::steal(MakeArrayTypeFromMetaclass( + reinterpret_cast(array_meta_type.ptr()), m.ptr(), + &array_spec)); + if (!base_type) { + throw nb::python_error(); + } + m.attr("Array") = base_type; + + m.def("set_tracer_class", [](nb::object f) { tracer_class = f; }); + + nb::object type_instancecheck = + nb::borrow(reinterpret_cast(&PyType_Type)) + .attr("__instancecheck__"); + array_meta_type.attr("__instancecheck__") = nb::cpp_function( + [base_type, type_instancecheck](nb::object self, nb::object x) { + // We are calling type's instancecheck method rather than + // PyObject_TypeCheck to avoid breaking users who use wrapt.ObjectProxy, + // such as TFP's NumpyVariable. + if (nb::cast(type_instancecheck(self, x))) { + return true; + } + // Instances of Tracer that have array avals are considered instances of + // Array. + if (tracer_class.ptr() && self.ptr() == base_type.ptr() && + PyObject_TypeCheck(x.ptr(), reinterpret_cast( + tracer_class.ptr())) != 0) { + // TODO(phawkins): we would like to change this to use the logic below + // but it is a somewhat breaking change. Let us defer it to a future + // PR. + return true; + // auto is_traced_array_fn = + // nb::getattr(x, "_is_traced_array", nb::none()); + // if (!is_traced_array_fn.is_none()) { + // try { + // return nb::cast(is_traced_array_fn()); + // } catch (...) { + // } + // } + } + return false; + }, + nb::is_method(), nb::arg("x").none()); + + std::string name = + absl::StrCat(nb::cast(m.attr("__name__")), ".ArrayImpl"); + + PyType_Spec array_impl_spec = { + /*.name=*/name.c_str(), + /*.basicsize=*/static_cast(sizeof(PyArrayObject)), + /*.itemsize=*/0, +#if PY_VERSION_HEX < 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, +#else // PY_VERSION_HEX >= 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_MANAGED_DICT | Py_TPFLAGS_MANAGED_WEAKREF, +#endif // PY_VERSION_HEX >= 0x030C0000 + /*.slots=*/array_impl_slots, + }; + + type_ = PyType_FromSpecWithBases(&array_impl_spec, base_type.ptr()); + if (!type_) { + throw nb::python_error(); + } + auto type = nb::borrow(type_); + m.attr("ArrayImpl") = type; + + type.attr("__init__") = nb::cpp_function( + [](PyArray self, nb::object aval, nb::object sharding, nb::list arrays, + bool committed, bool skip_checks) { + if (!(arrays.size() == 0 || arrays[0].type().is(PyArray::type()))) { + throw nb::type_error( + absl::StrCat( + "Unsupported type for elements in `arrays`: ", + nb::cast(nb::str(arrays[0].type()))) + .c_str()); + } + auto py_arrays = nb::cast>(arrays); + PyArray::PyInit(self, std::move(aval), std::move(sharding), py_arrays, + committed, skip_checks); + }, + nb::is_method(), nb::arg("aval"), nb::arg("sharding"), nb::arg("arrays"), + nb::arg("committed"), nb::arg("_skip_checks") = false); + type.attr("delete") = nb::cpp_function( + [](PyArray& self) { xla::ThrowIfError(self.Delete()); }, nb::is_method()); + type.attr("_sharding") = xla::nb_property_readonly(&PyArray::sharding); + type.attr("aval") = xla::nb_property(&PyArray::aval, &PyArray::set_aval); + type.attr("_arrays") = + xla::nb_property(&PyArray::arrays, [](PyArray& self, nb::object obj) { + xla::ThrowIfError(self.set_arrays(obj)); + }); + type.attr("_fully_replicated_shard") = nb::cpp_function( + [](PyArray self) { + return xla::ValueOrThrow(self.FullyReplicatedShard()); + }, + nb::is_method()); + type.attr("_npy_value") = + xla::nb_property(&PyArray::npy_value, &PyArray::set_npy_value); + type.attr("_committed") = xla::nb_property_readonly(&PyArray::committed); + type.attr("unsafe_buffer_pointer") = nb::cpp_function( + [](PyArray self) { + return xla::ValueOrThrow(self.UnsafeBufferPointer()); + }, + nb::is_method()); + type.attr("__cuda_array_interface__") = xla::nb_property_readonly( + [](PyArray self) { return self.CudaArrayInterface(); }); + type.attr("_pjrt_layout") = + xla::nb_property_readonly(xla::ValueOrThrowWrapper(&PyArray::layout)); + type.attr("on_device_size_in_bytes") = nb::cpp_function( + xla::ValueOrThrowWrapper(&PyArray::GetOnDeviceSizeInBytes), + nb::is_method()); + type.attr("_single_device_array_to_np_array_did_copy") = nb::cpp_function( + xla::ValueOrThrowWrapper(&PyArray::SingleDeviceArrayToNumpyArrayDidCopy), + nb::is_method()); + type.attr("_copy_single_device_array_to_host_async") = nb::cpp_function( + [](PyArray& self) { + xla::ThrowIfError(self.CopySingleDeviceArrayToHostAsync()); + }, + nb::is_method()); + type.attr("_replace_with") = nb::cpp_function( + [](PyArray& self, PyArray& o) { + xla::ThrowIfError(self.ReplaceWithAlias(o)); + }, + nb::is_method()); + type.attr("block_until_ready") = nb::cpp_function( + [](PyArray self) -> nb::object { + xla::ThrowIfError(self.BlockUntilReady()); + return self; + }, + nb::is_method()); + type.attr("platform") = nb::cpp_function( + [](PyArray self) { + const xla::ifrt::DeviceListRef& devices = + self.ifrt_array()->sharding().devices(); + absl::string_view platform_name = + devices->devices().front()->PlatformName(); + if (platform_name == "cuda" || platform_name == "rocm") { + return std::string_view("gpu"); + } else { + return platform_name; + } + }, + nb::is_method()); + type.attr("is_ready") = nb::cpp_function( + [](PyArray self) { return xla::ValueOrThrow(self.IsReady()); }, + nb::is_method()); + type.attr("is_deleted") = + nb::cpp_function(&PyArray::IsDeleted, nb::is_method()); + type.attr("traceback") = xla::nb_property_readonly(&PyArray::traceback); + type.attr("clone") = nb::cpp_function(&PyArray::Clone, nb::is_method()); + type.attr("__module__") = m.attr("__name__"); + + m.attr("batched_copy_array_to_devices_with_sharding") = nb::cpp_function( + [](absl::Span arrays, + absl::Span> dst_device_lists, + absl::Span shardings, + absl::Span array_copy_semantics) { + if (arrays.empty()) { + return std::vector(); + } + tsl::profiler::TraceMe traceme( + "batched_copy_array_to_devices_with_sharding"); + std::vector device_lists; + { + tsl::profiler::TraceMe device_list_traceme( + "batched_copy_array_to_devices_with_sharding: assemble device " + "lists"); + device_lists.reserve(dst_device_lists.size()); + for (const auto& dst_devices : dst_device_lists) { + device_lists.push_back( + xla::ValueOrThrow(dst_devices->ifrt_device_list())); + } + } + return xla::ValueOrThrow(PyArray::BatchedCopyToDeviceWithSharding( + arrays, device_lists, shardings, array_copy_semantics)); + }); + m.attr("array_result_handler") = nb::cpp_function( + [](nb::object aval, nb::object sharding, bool committed, + bool skip_checks) -> nb_class_ptr { + return make_nb_class( + std::move(aval), std::move(sharding), committed, skip_checks); + }, + nb::arg("aval"), nb::arg("sharding"), nb::arg("committed"), + nb::arg("_skip_checks") = false); + + nb::class_(m, "ResultHandler") + .def( + "__call__", + [](const PyArrayResultHandler& self, nb::object arg) { + if (PyArray py_array; nb::try_cast(arg, py_array)) { + return self.Call(py_array); + } + if (std::vector py_arrays; + nb::try_cast>(arg, py_arrays)) { + return self.Call(py_arrays); + } + throw nb::type_error( + absl::StrCat( + "Expected a single PyArray or a sequence of PyArrays, got ", + nb::cast(nb::str(arg.type()))) + .c_str()); + }, + nb::sig( + "def __call__(self, arg: Array | Sequence[Array], /) -> Array")) + .def("wrap", + [](const PyArrayResultHandler& self, nb::callable wrapper) { + auto wrappers = self.wrappers(); + wrappers.push_back(std::move(wrapper)); + return make_nb_class( + self.aval(), self.sharding(), self.committed(), + self.skip_checks(), std::move(wrappers)); + }) + .def("pre_wrap", + [](const PyArrayResultHandler& self, nb::callable wrapper) { + auto wrappers = self.wrappers(); + wrappers.insert(wrappers.begin(), std::move(wrapper)); + return make_nb_class( + self.aval(), self.sharding(), self.committed(), + self.skip_checks(), std::move(wrappers)); + }); + + return absl::OkStatus(); +} + +} // namespace jax diff --git a/jaxlib/py_array.h b/jaxlib/py_array.h new file mode 100644 index 000000000000..f4d1d59b99e9 --- /dev/null +++ b/jaxlib/py_array.h @@ -0,0 +1,397 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_PY_ARRAY_H_ +#define JAXLIB_PY_ARRAY_H_ + +#include + +#include +#include +#include +#include +#include +#include +#include + +// placeholder for index annotation headers +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "nanobind/nanobind.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" +#include "jaxlib/py_user_context.h" +#include "jaxlib/traceback.h" +#include "xla/future.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/user_context.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/shape.h" +#include "xla/tsl/concurrency/future.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" + +namespace jax { + +// Private to PyArray, but you cannot forward declare member classes. +// Not thread safe; assumes the GIL is held. +class PyHostValue { + public: + PyHostValue(); + ~PyHostValue(); + + PyHostValue(const PyHostValue&) = delete; + PyHostValue(PyHostValue&&) = delete; + PyHostValue& operator=(const PyHostValue&) = delete; + PyHostValue& operator=(PyHostValue&&) = delete; + + absl::Status CopyToHostAsync(std::optional& dynamic_shape_holder, + xla::ifrt::Array* ifrt_array); + + absl::StatusOr> AsNumPyArray( + std::optional& dynamic_shape_holder, + xla::ifrt::Array* ifrt_array); + + void Clear(); + + private: + absl::Status CopyStringArrayToHostAsync( + std::optional& dynamic_shape_holder, + xla::ifrt::Array* ifrt_array); + + absl::Status ConvertStringArrayContentsToNumpyArray( + xla::ifrt::Array* ifrt_array); + + tsl::Future<> ready_; + xla::nb_numpy_ndarray value_; + + // Optional field, only used for arrays of type kString. This vector of cords + // serves as input buffer for the CopyToHostBuffer call. It holds these + // contents until it is lazily converted it to a numpy array when the user + // calls `AsNumPyArray`. + std::shared_ptr> string_array_contents_; +}; + +// Private to PyArray, but you cannot forward declare member classes. +struct PyArray_Storage { + PyArray_Storage(nanobind::object aval, bool weak_type, xla::nb_dtype dtype, + std::vector shape, nanobind::object sharding, + bool committed, nb_class_ptr py_client, + xla::ifrt::ArrayRef ifrt_array, xla::Future<> result_status); + + ~PyArray_Storage(); + nanobind::handle AsHandle(); + + nanobind::object aval; + bool weak_type = false; + xla::nb_dtype dtype; + std::vector shape; + + nanobind::object sharding; + nanobind::object npy_value = nanobind::none(); + bool committed = false; + + nb_class_ptr py_client; + xla::ifrt::ArrayRef ifrt_array; + nanobind::object fully_replicated_array = nanobind::none(); + + // optional field, used only in python + std::vector py_arrays; + PyHostValue host_value; // Protected by the GIL. + std::optional dynamic_shape = std::nullopt; + // Only set if this Array was generated by a computation that has effects. + // This is the result status of the XLA computation that generated this + // array. + xla::Future<> result_status; + + // Doubly-linked list of all PyArrays known to the client. Protected by the + // GIL. Since multiple PyArrays may share the same PjRtBuffer, there may be + // duplicate PjRtBuffers in this list. + PyArray_Storage* next; + PyArray_Storage* prev; + + uint8_t thread_id_bucket; +}; + +// The C++ implementation of jax.Array. A few key methods and data members are +// implemented in C++ for performance, while most of the functionalities are +// still implemented in python. +class PyArray : public nanobind::object { + public: + NB_OBJECT(PyArray, nanobind::object, "Array", PyArray::IsPyArray); + PyArray() = default; + + // "__init__" methods. Only used in python + static void PyInit(PyArray self, nanobind::object aval, + nanobind::object sharding, + absl::Span py_arrays, bool committed, + bool skip_checks); + + // Only used in C++. `skip_checks` should only be set for Arrays created by + // jax that cannot possibly have consistency issues (e.g. `sharding` devices + // different than `ifrt_array` devices). Arrays created by users should be + // checked. + PyArray(nanobind::object aval, bool weak_type, xla::nb_dtype dtype, + std::vector shape, nanobind::object sharding, + nb_class_ptr py_client, xla::ifrt::ArrayRef ifrt_array, + bool committed, bool skip_checks, + xla::Future<> result_status = xla::Future<>()); + + static PyArray MakeFromSingleDeviceArray( + nb_class_ptr py_client, xla::ifrt::ArrayRef ifrt_array, + bool weak_type, bool committed, + xla::Future<> result_status = xla::Future<>()); + + static PyArray MakeFromIfrtArrayAndSharding(nb_class_ptr py_client, + xla::ifrt::ArrayRef ifrt_array, + nanobind::object sharding, + bool weak_type, bool committed, + bool skip_checks); + + // Registers Array and related types in module m. + static absl::Status Register(nanobind::module_& m); + + static PyArray borrow(PyObject* ptr) { + return nanobind::borrow(ptr); + } + + using Storage = PyArray_Storage; + + const nanobind::object& aval() const { return GetStorage().aval; } + void set_aval(nanobind::object aval) { GetStorage().aval = std::move(aval); } + + bool weak_type() const { return GetStorage().weak_type; } + + const xla::nb_dtype& dtype() const { return GetStorage().dtype; } + absl::Span shape() const { return GetStorage().shape; } + + const nanobind::object& sharding() const { return GetStorage().sharding; } + + absl::StatusOr> layout() { + xla::ifrt::Array* ifrt_array_ptr = ifrt_array(); + TF_ASSIGN_OR_RETURN(std::shared_ptr layout, + ifrt_array_ptr->pjrt_layout()); + if (layout == nullptr) { + TF_ASSIGN_OR_RETURN( + xla::ifrt::Shape shard_shape, + ifrt_array_ptr->sharding().GetShardShape(ifrt_array_ptr->shape())); + TF_ASSIGN_OR_RETURN( + layout, ifrt_array_ptr->client()->GetDefaultPjRtLayout( + ifrt_array_ptr->dtype(), shard_shape.dims(), + ifrt_array_ptr->sharding().devices()->devices().front(), + ifrt_array_ptr->sharding().memory_kind())); + } + return layout; + } + + bool committed() const { return GetStorage().committed; } + + const nanobind::object& npy_value() const { return GetStorage().npy_value; } + void set_npy_value(nanobind::object v) { + GetStorage().npy_value = std::move(v); + } + + const nb_class_ptr& py_client() const { + return GetStorage().py_client; + } + + std::optional traceback() const { + xla::ifrt::Array* ifrt_array_ptr = ifrt_array(); + if (ifrt_array_ptr == nullptr) { + return std::nullopt; + } + return GetTraceback(ifrt_array_ptr->user_context().get()); + } + + // Returns xla::InvalidArgument if the buffer has been deleted. + // See `Future` for the semantics of `IsReady` and `IsKnownReady`. + absl::StatusOr IsReady() { + xla::ifrt::Array* ifrt_array_ptr = ifrt_array(); + if (ifrt_array_ptr->IsDeleted()) { + return xla::InvalidArgument("Array has been deleted."); + } + xla::ifrt::UserContextScope user_context_scope( + jax::PyUserContext::Create()); + return ifrt_array_ptr->GetReadyFuture().IsReady(); + } + + const xla::Future<>& result_status() const { + return GetStorage().result_status; + } + + xla::ifrt::Array* ifrt_array() const { return GetStorage().ifrt_array.get(); } + + // Short-term escape hatch to get PjRtBuffers from PyArray. + // TODO(hyeontaek): Migrate all users of this method to be agnostic of PjRt. + absl::Span> pjrt_buffers() const { + xla::ifrt::Array* ifrt_array_ptr = ifrt_array(); + if (ifrt_array_ptr == nullptr) { + return {}; + } + auto* arr = + llvm::dyn_cast_or_null(ifrt_array_ptr); + if (arr == nullptr) { + throw xla::XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + return arr->pjrt_buffers(); + } + + int num_addressable_shards() const { + xla::ifrt::Array* ifrt_array_ptr = ifrt_array(); + if (ifrt_array_ptr == nullptr) { + return 0; + } + auto* arr = + llvm::dyn_cast_or_null(ifrt_array_ptr); + if (arr == nullptr) { + // TODO(hyeontaek): Add num_addressable_shards to ifrt. + return num_shards(); + } + return arr->pjrt_buffers().size(); + } + + std::vector& py_arrays() { return GetStorage().py_arrays; } + const std::vector& py_arrays() const { + return GetStorage().py_arrays; + } + const std::vector& py_arrays_cached(); + + nanobind::object arrays(); + absl::Status set_arrays(nanobind::object obj); + absl::StatusOr FullyReplicatedShard(); + + int num_shards() const { + xla::ifrt::Array* ifrt_array_ptr = ifrt_array(); + if (ifrt_array_ptr == nullptr) { + return 0; + } + return ifrt_array_ptr->sharding().devices()->size(); + } + + static nanobind::handle type() { + DCHECK(type_); + return nanobind::handle(type_); + } + + static bool IsPyArray(nanobind::handle arg) { + return arg.type().is(PyArray::type()); + } + + absl::Status BlockUntilReady() const; + + absl::Status BlockUntilResultStatusIsReady(); + + absl::StatusOr GetOnDeviceSizeInBytes(); + absl::StatusOr> + SingleDeviceArrayToNumpyArrayDidCopy(); + absl::StatusOr SingleDeviceArrayToNumpyArray(); + absl::Status CopySingleDeviceArrayToHostAsync(); + nanobind::dict CudaArrayInterface(); + absl::StatusOr UnsafeBufferPointer(); + + absl::Status Delete(); + + bool IsDeleted() const; + + PyArray Clone() const; + + static absl::StatusOr> BatchedCopyToDeviceWithSharding( + absl::Span py_arrays, + absl::Span dst_device_lists, + absl::Span dst_shardings, + absl::Span array_copy_semantics); + + static absl::StatusOr BatchedDevicePut( + nanobind::object aval, nanobind::object sharding, + std::vector xs, + absl::Span dst_devices, bool committed, + bool force_copy, + xla::PjRtClient::HostBufferSemantics host_buffer_semantics, + bool jax_enable_x64); + + static absl::StatusOr ReorderShards( + PyArray x, nanobind::object dst_sharding, + xla::ifrt::ArrayCopySemantics array_copy_semantics); + + static absl::Status BatchedBlockUntilReady( + std::vector objs); + + absl::Status ReplaceWithAlias(PyArray o); + + private: + absl::StatusOr AssertUnsharded(std::string_view api); + + nanobind::object CheckAndRearrange(absl::Span py_arrays, + nanobind::object sharding, + nanobind::object aval); + + void SetIfrtArray(xla::ifrt::ArrayRef ifrt_array); + + Storage& GetStorage(); + const Storage& GetStorage() const; + + inline static PyObject* type_ = nullptr; +}; + +class PyArrayResultHandler { + public: + PyArrayResultHandler(nanobind::object aval, nanobind::object sharding, + bool committed, bool skip_checks, + std::vector wrappers = {}); + + nanobind::object Call(absl::Span py_arrays) const; + nanobind::object Call(PyArray py_array) const; + + nanobind::object Call(nb_class_ptr py_client, + xla::ifrt::ArrayRef ifrt_array, + xla::Future<> result_status = xla::Future<>()) const; + + const std::vector& wrappers() const { return wrappers_; } + + nanobind::object aval() const { return aval_; } + nanobind::object sharding() const { return sharding_; } + bool committed() const { return committed_; } + bool skip_checks() const { return skip_checks_; } + + private: + nanobind::object aval_; + nanobind::object sharding_; + bool weak_type_; + bool committed_; + bool skip_checks_; + + xla::nb_dtype dtype_; + std::vector shape_; + std::vector wrappers_; +}; + +absl::StatusOr CudaArrayInterfaceToBuffer( + const nanobind::dict& cai, nb_class_ptr cuda_client, + std::optional device_id); + +} // namespace jax + +#endif // JAXLIB_PY_ARRAY_H_ diff --git a/jaxlib/py_client.cc b/jaxlib/py_client.cc new file mode 100644 index 000000000000..439a945d06ae --- /dev/null +++ b/jaxlib/py_client.cc @@ -0,0 +1,1040 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/py_client.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "mlir-c/IR.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" // IWYU pragma: keep +#include "mlir/CAPI/IR.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "mlir/Pass/PassManager.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/guard_lib.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/pprof_profile_builder.h" +#include "jaxlib/py_array.h" +#include "jaxlib/py_device.h" +#include "jaxlib/py_device_list.h" +#include "jaxlib/py_executable.h" +#include "jaxlib/py_host_callback.h" +#include "jaxlib/py_memory_space.h" +#include "jaxlib/py_user_context.h" +#include "jaxlib/py_values.h" +#include "jaxlib/python_ref_manager.h" +#include "jaxlib/sharding.h" +#include "jaxlib/traceback.h" +#include "xla/literal.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/mlir_to_hlo.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/attribute_map.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/hlo/hlo_program.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/program.h" +#include "xla/python/ifrt/user_context.h" +#include "xla/python/ifrt/user_context_status_util.h" +#include "xla/python/nb_absl_span.h" // IWYU pragma: keep +#include "xla/python/nb_numpy.h" +#include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/python/pjrt_ifrt/pjrt_executable.h" +#include "xla/python/pjrt_ifrt/xla_compiler.h" +#include "xla/python/types.h" +#include "xla/python/version.h" +#include "xla/service/platform_util.h" // IWYU pragma: keep +#include "xla/service/spmd/shardy/utils.h" // IWYU pragma: keep +#include "xla/shape.h" +#include "xla/status_macros.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" + +namespace ifrt = xla::ifrt; +namespace nb = nanobind; + +namespace jax { + +/*static*/ nb_class_ptr PyClient::Make( + std::shared_ptr ifrt_client) { + auto client = make_nb_class(std::move(ifrt_client)); + Initialize(client); + return client; +} + +PyClient::PyClient(std::shared_ptr ifrt_client) + : ifrt_client_(std::move(ifrt_client)), + client_attributes_(ifrt_client_->Attributes()) { + CHECK(ifrt_client_); +} + +/* static */ void PyClient::Initialize(nb_class_ptr client) { + for (ifrt::Device* device : client->ifrt_client()->devices()) { + client->devices_[device] = make_nb_class(client, device); + + for (ifrt::Memory* memory : device->Memories()) { + auto& py_memory = client->memory_spaces_[memory]; + if (py_memory.get() == nullptr) { + py_memory = make_nb_class(client, memory); + } + } + } +} + +PyClient::~PyClient() { + nb::gil_scoped_release gil; + ifrt_client_ = nullptr; +} + +nb_class_ptr PyClient::GetPyDevice(ifrt::Device* device) { + auto& py_device = devices_[device]; + if (py_device.get() == nullptr) { + py_device = make_nb_class( + nb::borrow>(nb::find(this)), device); + } + return py_device; +} + +nb_class_ptr PyClient::GetPyMemorySpace( + ifrt::Memory* memory_space) { + auto& py_memory = memory_spaces_[memory_space]; + if (py_memory.get() == nullptr) { + py_memory = make_nb_class( + nb::borrow>(nb::find(this)), memory_space); + } + return py_memory; +} + +std::vector> PyClient::Devices() { + std::vector> devices; + auto span = ifrt_client_->devices(); + devices.reserve(span.size()); + for (ifrt::Device* device : span) { + devices.push_back(GetPyDevice(device)); + } + return devices; +} + +std::vector> PyClient::LocalDevices() { + std::vector> devices; + devices.reserve(ifrt_client_->addressable_devices().size()); + for (ifrt::Device* device : ifrt_client_->addressable_devices()) { + devices.push_back(GetPyDevice(device)); + } + return devices; +} + +std::vector> PyClient::GetAllDevices() { + std::vector> devices; + devices.reserve(ifrt_client_->GetAllDevices().size()); + for (ifrt::Device* device : ifrt_client_->GetAllDevices()) { + devices.push_back(GetPyDevice(device)); + } + return devices; +} + +absl::StatusOr> PyClient::DeviceFromLocalHardwareId( + int local_hardware_id) { + TF_ASSIGN_OR_RETURN(ifrt::Device * device, + ifrt_client_->LookupAddressableDevice(local_hardware_id)); + return GetPyDevice(device); +} + +nb::typed PyClient::LiveExecutables() { + CHECK(PyGILState_Check()); + nb::ft_lock_guard lock(executables_mutex_); + nb::list executables; + for (PyLoadedExecutable* exec = executables_; exec; exec = exec->next_) { + executables.append(nb::find(exec)); + } + return executables; +} + +absl::Status PyClient::Defragment() { + CHECK(PyGILState_Check()); + if (!llvm::isa(ifrt_client_.get())) { + return absl::UnimplementedError( + "Defragmentation is not supported on this runtime."); + } + ifrt::PlatformId platform_id = ifrt_client_->platform_id(); + bool is_gpu_client = platform_id == xla::CudaId() || + platform_id == xla::RocmId() || + platform_id == xla::SyclId(); + + if (!is_gpu_client) { + return absl::UnimplementedError( + "Defragmentation is not supported on this runtime."); + } + + // TODO(b/399879011): This is a GPU-specific implementation of `Defragment`. + // Ideally, this would be replaced with some kind of auto-defrag-on-OOM, or at + // least would not live in this file. + + struct TmpBuffer { + // Non-empty for buffers found in a PyArray_Storage. Multiple Arrays + // can reference the same xla::PjRtBuffer. + std::vector*> pjrt_buffer_ptrs; + // TODO(skyewm): maybe use py_buffer's HostValue + std::shared_ptr host_copy; + }; + + // Synchronously copy all buffers to host + absl::flat_hash_map pjrt_buf_to_tmp_buffer; + + std::vector arrays = LiveArrays(); + for (const PyArray& array : arrays) { + // TODO(hyeontaek): Support non-PjRt Arrays. + // TODO(hyeontaek): Re-construct ifrt::Array with new xla::PjRtBuffer so + // that std::shared_ptr does not need to be updated + // in-place. + if (array.ifrt_array() == nullptr) { + continue; + } + auto* arr = + llvm::dyn_cast_or_null(array.ifrt_array()); + if (arr == nullptr) { + throw xla::XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend " + "only."); + } + TF_ASSIGN_OR_RETURN( + absl::Span> pjrt_buffers, + arr->mutable_pjrt_buffers()); + for (int i = 0; i < pjrt_buffers.size(); ++i) { + std::shared_ptr& pjrt_buf_ptr = pjrt_buffers[i]; + if (pjrt_buf_ptr->IsDeleted()) { + continue; + } + auto [iter, inserted] = + pjrt_buf_to_tmp_buffer.insert({pjrt_buf_ptr.get(), TmpBuffer()}); + if (inserted) { + TF_ASSIGN_OR_RETURN(iter->second.host_copy, + pjrt_buf_ptr->ToLiteralSync()); + } + iter->second.pjrt_buffer_ptrs.push_back(&pjrt_buf_ptr); + } + } + + // All buffers successfully copied to host, delete on-device copies. + // + // Use blocking delete operation to ensure all memory is actually cleared + // before we start rewriting buffers. + // + // Die instead of returning a bad status because program presumably can't + // continue if we fail to reconstitute device buffers. + for (const auto& it : pjrt_buf_to_tmp_buffer) { + xla::PjRtBuffer* pjrt_buf = it.first; + TF_CHECK_OK(pjrt_buf + ->ReleaseDeviceMemoryOwnership( + /*wait_for_operations_to_complete=*/true) + .status()); + } + + // Copy host copies back to device and update PyArrays in-place. + for (auto& it : pjrt_buf_to_tmp_buffer) { + xla::PjRtBuffer* pjrt_buf = it.first; + TmpBuffer& tmp_buffer = it.second; + std::unique_ptr new_copy = + pjrt_client() + ->BufferFromHostLiteral(*tmp_buffer.host_copy, + pjrt_buf->memory_space()) + .value(); + TF_CHECK_OK(new_copy->GetReadyFuture().Await()); + + std::shared_ptr new_pjrt_buf_ptr(new_copy.release()); + for (std::shared_ptr* pjrt_buffer_ptr : + tmp_buffer.pjrt_buffer_ptrs) { + *pjrt_buffer_ptr = new_pjrt_buf_ptr; + } + } + + // TODO(skyewm): delete executables? + return absl::OkStatus(); +} + +/* static */ absl::StatusOr PyClient::BufferFromPyval( + nb_class_ptr client, nb::handle argument, ifrt::Device* device, + bool force_copy, ifrt::Client::HostBufferSemantics host_buffer_semantics) { + if (device == nullptr) { + TF_RET_CHECK(!client->ifrt_client_->addressable_devices().empty()); + device = client->ifrt_client_->addressable_devices().front(); + } + CHECK(device != nullptr); + + auto transfer_guard_formatter = [&argument, dst_device = device] { + auto type = nb::cast(nb::str(argument.type())); + // Catch exceptions because shape and dtype properties convertible to str + // are not guaranteed to present in an arbitrary argument. + std::string shape; + std::string dtype; + try { + shape = + nb::cast(nb::str(nb::object(argument.attr("shape")))); + } catch (const std::exception& e) { + shape = ""; + } + try { + dtype = + nb::cast(nb::str(nb::object(argument.attr("dtype")))); + } catch (const std::exception& e) { + dtype = ""; + } + return absl::StrCat("type=", type, ", shape=", shape, ", dtype=", dtype, + ", dst_device=", dst_device->DebugString()); + }; + TF_RETURN_IF_ERROR( + ApplyTransferGuardToHostToDevice(transfer_guard_formatter)); + + TF_ASSIGN_OR_RETURN(ifrt::Device * found_device, + client->ifrt_client_->LookupDevice(device->Id())); + if (found_device != device) { + return xla::InvalidArgument( + "Cannot copy value to device '%s' with '%s' backend", + device->DebugString(), client->ifrt_client_->platform_name()); + } + GlobalPyRefManager()->CollectGarbage(); + + PyUserContextScope user_context_scope; + DevicePutOptions options; + options.squash_64bit_types = false; + options.allow_zero_copy = + (!force_copy && (host_buffer_semantics == + ifrt::Client::HostBufferSemantics::kImmutableZeroCopy)); + TF_ASSIGN_OR_RETURN(DevicePutResult device_put_result, + DevicePutWithDevice(argument, client->ifrt_client_.get(), + device, ifrt::MemoryKind(), options)); + TF_ASSIGN_OR_RETURN(ifrt::DeviceListRef device_list, + client->ifrt_client()->MakeDeviceList({device})); + auto sharding = + make_nb_class(client, std::move(device_list), + /*memory_kind=*/nb::none()); + return PyArray::MakeFromIfrtArrayAndSharding( + std::move(client), std::move(device_put_result.ifrt_array), + std::move(sharding), + /*weak_type=*/false, /*committed=*/false, + /*skip_checks=*/true); +} + +namespace { + +// Makes IFRT `CompileOptions` from XLA `CompileOptions` and optional host +// callbacks. +std::unique_ptr MakeIfrtCompileOptions( + xla::CompileOptions options, ifrt::DeviceListRef executable_devices, + std::vector host_callbacks) { + std::vector> + ifrt_loaded_host_callbacks; + ifrt_loaded_host_callbacks.reserve(host_callbacks.size()); + // Extract `ifrt::LoadedHostCallback`s from host callback capsules that + // were created by `PyClient::MakePythonCallbackUsingHostSendAndRecv()`. + for (auto& host_callback : host_callbacks) { + ifrt_loaded_host_callbacks.push_back(tsl::FormRef( + static_cast(host_callback.data()))); + } + return std::make_unique( + std::move(options), std::move(executable_devices), + std::move(ifrt_loaded_host_callbacks)); +} + +// Makes IFRT `DeserializeExecutableOptions` from `xla::CompileOptions` and +// optional host callbacks. +std::unique_ptr +MakeIfrtDeserializeExecutableOptions(std::optional options, + ifrt::DeviceListRef executable_devices, + std::vector host_callbacks) { + std::vector> + ifrt_loaded_host_callbacks; + ifrt_loaded_host_callbacks.reserve(host_callbacks.size()); + // Extract `ifrt::LoadedHostCallback`s from host callback capsules that + // were created by `PyClient::MakePythonCallbackUsingHostSendAndRecv()`. + for (auto& host_callback : host_callbacks) { + ifrt_loaded_host_callbacks.push_back(tsl::FormRef( + static_cast(host_callback.data()))); + } + return std::make_unique( + std::move(options), std::move(executable_devices), + std::move(ifrt_loaded_host_callbacks)); +} + +std::unique_ptr +MakeIfrtDeserializeExecutableOptions(std::optional options, + ifrt::DeviceListRef executable_devices, + std::vector host_callbacks, + ifrt::Client* ifrt_client) { + std::vector> + ifrt_loaded_host_callbacks; + ifrt_loaded_host_callbacks.reserve(host_callbacks.size()); + for (auto& host_callback : host_callbacks) { + auto callback = tsl::MakeRef( + ifrt_client, std::move(host_callback)); + ifrt_loaded_host_callbacks.push_back(callback); + } + return std::make_unique( + std::move(options), std::move(executable_devices), + std::move(ifrt_loaded_host_callbacks)); +} + +} // namespace + +/* static */ absl::StatusOr> +PyClient::CompileAndLoadIfrtProgram( + nb_class_ptr client, std::unique_ptr ifrt_program, + std::unique_ptr ifrt_options) { + auto* pjrt_compatible_client = + llvm::dyn_cast_or_null( + client->ifrt_client_.get()); + auto* ifrt_xla_options = + llvm::dyn_cast_or_null(ifrt_options.get()); + // For XLA programs, pass allocated device memory size to compile options for + // pjrt compatible backends. + if (pjrt_compatible_client != nullptr && ifrt_xla_options != nullptr) { + xla::CompileOptions& options = ifrt_xla_options->compile_options; + auto addressable_devices = + pjrt_compatible_client->pjrt_client()->addressable_devices(); + if (!addressable_devices.empty()) { + int device_ordinal = options.executable_build_options.device_ordinal(); + if (device_ordinal < 0) { + device_ordinal = 0; + } + CHECK_LT(device_ordinal, addressable_devices.size()); + auto stats = addressable_devices[device_ordinal]->GetAllocatorStats(); + if (stats.ok() && stats->bytes_limit) { + options.executable_build_options.set_device_memory_size( + *stats->bytes_limit); + } + } + + if (pjrt_compatible_client->pjrt_client()->key_value_store().has_value()) { + options.executable_build_options.set_key_value_store( + *pjrt_compatible_client->pjrt_client()->key_value_store()); + } + } + + PyUserContextScope user_context_scope; + ifrt::LoadedExecutableRef ifrt_loaded_executable; + std::optional fingerprint; + absl::Status compile_status; + { + nb::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN( + ifrt_loaded_executable, + client->ifrt_client_->GetDefaultCompiler()->CompileAndLoad( + std::move(ifrt_program), std::move(ifrt_options))); + compile_status = ifrt_loaded_executable->GetReadyFuture().Await(); + if (compile_status.ok()) { + TF_ASSIGN_OR_RETURN(fingerprint, ifrt_loaded_executable->Fingerprint()); + } + } + if (!compile_status.ok()) { + // `compile_status.status()` can reference an asynchronously propagated + // `ifrt::UserContext` representing the context of an error. We expand this + // future result right before returning it to Python (outside of + // `nb::gil_scoped_release`) so that any attached user context is appended + // to the status message. + return xla::ifrt::ExpandUserContexts(std::move(compile_status)); + } + return make_nb_class(std::move(client), + std::move(ifrt_loaded_executable), + std::move(fingerprint)); +} + +/* static */ absl::StatusOr> PyClient::Compile( + nb_class_ptr client, mlir::ModuleOp module, + ifrt::DeviceListRef executable_devices, xla::CompileOptions options) { + mlir::OwningOpRef clone(module.clone()); + module = *clone; + ifrt::ExecutableRef ifrt_executable; + { + TF_ASSIGN_OR_RETURN( + auto topology, + client->ifrt_client()->GetTopologyForDevices(executable_devices)); + auto xla_options = std::make_unique( + options, std::move(executable_devices)); + TF_ASSIGN_OR_RETURN( + ifrt_executable, + client->ifrt_client()->GetDefaultCompiler()->Compile( + std::make_unique(std::move(module)), + *topology, std::move(xla_options))); + } + return make_nb_class(ifrt_executable); +} + +/* static */ absl::StatusOr> +PyClient::CompileAndLoad(nb_class_ptr client, mlir::ModuleOp module, + ifrt::DeviceListRef executable_devices, + xla::CompileOptions options, + std::vector host_callbacks) { + mlir::OwningOpRef clone(module.clone()); + module = *clone; + // TODO(b/420837831): Remove this once we don't need to fall back to GSPMD. + if (options.executable_build_options.use_shardy_partitioner() && + xla::sdy::hasGspmdAttrsOrOps(module)) { + LOG(WARNING) + << "Module has GSPMD attrs or ops, but Shardy is enabled. Disabling " + "Shardy and falling back to using GSPMD propagation."; + options.executable_build_options.set_use_shardy_partitioner(false); + if (xla::sdy::hasShardyMesh(module)) { + // Shardy is not enabled, but the module has shardy ops. Likely due to + // export loading a GSPMD checkpoint. Fall back to GSPMD. + TF_RETURN_IF_ERROR(xla::ExportShardyForGSPMD(module)); + } + } + options.allow_in_place_mlir_modification = true; // We just cloned the module + return CompileAndLoadIfrtProgram( + client, std::make_unique(std::move(module)), + MakeIfrtCompileOptions(std::move(options), std::move(executable_devices), + std::move(host_callbacks))); +} + +/* static */ absl::StatusOr> +PyClient::CompileAndLoad(nb_class_ptr client, mlir::ModuleOp module, + ifrt::DeviceListRef executable_devices, + xla::CompileOptions options, + std::vector host_callbacks) { + mlir::OwningOpRef clone(module.clone()); + module = *clone; + std::vector> + ifrt_loaded_host_callbacks; + ifrt_loaded_host_callbacks.reserve(host_callbacks.size()); + // Extract `ifrt::LoadedHostCallback`s from host callback capsules that + // were created by `PyClient::MakePythonCallbackUsingHostSendAndRecv()`. + for (auto& host_callback : host_callbacks) { + auto callback = tsl::MakeRef( + client->ifrt_client(), std::move(host_callback)); + ifrt_loaded_host_callbacks.push_back(callback); + } + auto compile_options = std::make_unique( + std::move(options), std::move(executable_devices), + std::move(ifrt_loaded_host_callbacks)); + return CompileAndLoadIfrtProgram( + client, std::make_unique(module), + std::move(compile_options)); +} + +absl::StatusOr PyClient::SerializeExecutable( + const PyLoadedExecutable& executable) const { + TF_ASSIGN_OR_RETURN(auto serialized, + executable.ifrt_loaded_executable()->Serialize()); + return nb::bytes(serialized.data(), serialized.size()); +} + +/* static */ absl::StatusOr> +PyClient::DeserializeExecutable(nb_class_ptr client, + nb::bytes serialized, + ifrt::DeviceListRef executable_devices, + std::optional options, + std::vector host_callbacks) { + ifrt::LoadedExecutableRef ifrt_loaded_executable; + std::optional fingerprint; + auto ifrt_deserialize_options = MakeIfrtDeserializeExecutableOptions( + std::move(options), std::move(executable_devices), + std::move(host_callbacks)); + PyUserContextScope user_context_scope; + { + nb::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN( + ifrt_loaded_executable, + client->ifrt_client_->GetDefaultCompiler()->DeserializeLoadedExecutable( + std::string_view(serialized.c_str(), serialized.size()), + std::move(ifrt_deserialize_options))); + } + TF_ASSIGN_OR_RETURN(fingerprint, ifrt_loaded_executable->Fingerprint()); + return make_nb_class(std::move(client), + std::move(ifrt_loaded_executable), + std::move(fingerprint)); +} + +/* static */ absl::StatusOr> +PyClient::DeserializeExecutable(nb_class_ptr client, + nb::bytes serialized, + ifrt::DeviceListRef executable_devices, + std::optional options, + std::vector host_callbacks) { + ifrt::LoadedExecutableRef ifrt_loaded_executable; + std::optional fingerprint; + auto ifrt_deserialize_options = MakeIfrtDeserializeExecutableOptions( + std::move(options), std::move(executable_devices), + std::move(host_callbacks), client->ifrt_client()); + PyUserContextScope user_context_scope; + { + nb::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN( + ifrt_loaded_executable, + client->ifrt_client_->GetDefaultCompiler()->DeserializeLoadedExecutable( + std::string_view(serialized.c_str(), serialized.size()), + std::move(ifrt_deserialize_options))); + } + TF_ASSIGN_OR_RETURN(fingerprint, ifrt_loaded_executable->Fingerprint()); + return make_nb_class(std::move(client), + std::move(ifrt_loaded_executable), + std::move(fingerprint)); +} + +namespace { + +struct HeapProfileKey { + std::optional traceback; + int64_t size; + xla::PjRtDevice* device; + bool operator==(const HeapProfileKey& other) const; +}; + +bool HeapProfileKey::operator==(const HeapProfileKey& other) const { + if (size != other.size || device != other.device) { + return false; + } + if ((traceback.has_value()) != (other.traceback.has_value())) { + return false; + } + if (traceback.has_value() && traceback->not_equal(*other.traceback)) { + return false; + } + return true; +} + +template +H AbslHashValue(H h, const HeapProfileKey& key) { + if (key.traceback) { + h = H::combine(std::move(h), nb::hash(*key.traceback)); + } + h = H::combine(std::move(h), key.size, key.device); + return h; +} + +} // namespace + +absl::StatusOr PyClient::HeapProfile() { + CHECK(PyGILState_Check()); + absl::flat_hash_set buffer_set; + absl::flat_hash_map entries; + + auto add_buffer_to_profile = [&](xla::PjRtBuffer* buffer, + std::optional traceback) { + // We only wish to count each xla::PjRtBuffer once, even though they may be + // shared by multiple PyArrays. + if (!buffer->IsDeleted() && buffer_set.insert(buffer).second) { + TF_ASSIGN_OR_RETURN(size_t size, buffer->GetOnDeviceSizeInBytes()); + HeapProfileKey key{traceback, static_cast(size), + buffer->device()}; + ++entries[key]; + } + return absl::OkStatus(); + }; + + std::vector arrays = LiveArrays(); + for (const PyArray& array : arrays) { + if (array.ifrt_array() == nullptr) { + continue; + } + auto* arr = + llvm::dyn_cast_or_null(array.ifrt_array()); + // TODO(hyeontaek): Support non-PjRt Arrays. + if (arr == nullptr) { + throw xla::XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend " + "only."); + } + for (const auto& buffer : arr->pjrt_buffers()) { + TF_RETURN_IF_ERROR( + add_buffer_to_profile(buffer.get(), array.traceback())); + } + } + + for (PyLoadedExecutable* executable = executables_; executable; + executable = executable->next_) { + HeapProfileKey key{executable->traceback(), + executable->SizeOfGeneratedCodeInBytes(), nullptr}; + ++entries[key]; + } + + xla::PprofProfileBuilder builder; + auto* allocations = builder.profile().add_sample_type(); + allocations->set_type(builder.StringId("allocations")); + allocations->set_unit(builder.StringId("count")); + auto* space = builder.profile().add_sample_type(); + space->set_type(builder.StringId("space")); + space->set_unit(builder.StringId("bytes")); + + const int kind_string_id = builder.StringId("kind"); + const int buffer_string_id = builder.StringId("buffer"); + const int executable_string_id = builder.StringId("executable"); + const int device_string_id = builder.StringId("device"); + for (const auto& entry : entries) { + auto* sample = builder.profile().add_sample(); + if (entry.first.traceback) { + for (const auto& frame : entry.first.traceback->RawFrames()) { + sample->add_location_id(builder.LocationId(frame.code, frame.lasti)); + } + } + sample->add_value(entry.second); + sample->add_value(entry.first.size * entry.second); + + auto* kind_label = sample->add_label(); + kind_label->set_key(kind_string_id); + if (entry.first.device) { + kind_label->set_str(buffer_string_id); + auto* device_label = sample->add_label(); + device_label->set_key(device_string_id); + std::string device_label_str(entry.first.device->DebugString()); + device_label->set_str(builder.StringId(device_label_str)); + } else { + kind_label->set_str(executable_string_id); + } + } + std::string serialized = builder.profile().SerializeAsString(); + return nb::bytes(serialized.data(), serialized.size()); +} + +absl::StatusOr PyClient::MakePythonCallbackUsingHostSendAndRecv( + nb::callable callable, absl::Span operand_shapes, + absl::Span result_shapes, + absl::Span send_channel_ids, + absl::Span recv_channel_ids, nb::callable serializer) { + TF_ASSIGN_OR_RETURN( + auto loaded_host_callback, + PyHostSendAndRecvLoadedHostCallback::Create( + ifrt_client(), std::move(callable), operand_shapes, result_shapes, + send_channel_ids, recv_channel_ids, std::move(serializer))); + nb::capsule callback_capsule( + loaded_host_callback.release(), [](void* ptr) noexcept { + static_cast(ptr)->DropRef(); + }); + return callback_capsule; +} + +/* static */ int PyClient::tp_traverse(PyObject* self, visitproc visit, + void* arg) { + Py_VISIT(Py_TYPE(self)); + if (!nb::inst_ready(self)) { + return 0; + } + PyClient* c = nb::inst_ptr(self); + for (const auto& [ifrt_device, py_device] : c->devices_) { + Py_VISIT(py_device.ptr()); + } + for (const auto& [ifrt_memory, py_memory] : c->memory_spaces_) { + Py_VISIT(py_memory.ptr()); + } + return 0; +} + +/* static */ int PyClient::tp_clear(PyObject* self) { + PyClient* c = nb::inst_ptr(self); + absl::flat_hash_map> devices; + std::swap(devices, c->devices_); + absl::flat_hash_map> memory_spaces; + std::swap(memory_spaces, c->memory_spaces_); + return 0; +} + +PyType_Slot PyClient::slots_[] = { + {Py_tp_traverse, (void*)PyClient::tp_traverse}, + {Py_tp_clear, (void*)PyClient::tp_clear}, + {0, nullptr}, +}; + +/* static */ void PyClient::Register(nb::module_& m) { + nb::enum_(m, "HostBufferSemantics") + .value("IMMUTABLE_ONLY_DURING_CALL", + xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall) + .value("IMMUTABLE_UNTIL_TRANSFER_COMPLETES", + xla::PjRtClient::HostBufferSemantics:: + kImmutableUntilTransferCompletes) + .value("ZERO_COPY", + xla::PjRtClient::HostBufferSemantics::kImmutableZeroCopy); + + nb::class_ py_local_client(m, "Client", nb::is_weak_referenceable(), + nb::type_slots(PyClient::slots_)); + py_local_client.def_prop_ro("platform", &PyClient::platform_name) + .def_prop_ro("_raw_platform", &PyClient::raw_platform_name) + .def_prop_ro("platform_version", &PyClient::platform_version) + .def_prop_ro("runtime_type", &PyClient::runtime_type) + .def("device_count", &PyClient::device_count) + .def("local_device_count", &PyClient::addressable_device_count) + .def("devices", &PyClient::Devices) + .def("local_devices", &PyClient::LocalDevices) + // TODO(hyeontaek): Remove this method once we have a unified API for + // enumerating devices with different criteria. + .def("_get_all_devices", &PyClient::GetAllDevices) + .def("device_from_local_hardware_id", + xla::ValueOrThrowWrapper(&PyClient::DeviceFromLocalHardwareId)) + .def("live_executables", &PyClient::LiveExecutables) + .def("live_arrays", &PyClient::LiveArrays) + .def("live_buffers", &PyClient::LiveArrays) + .def("process_index", &PyClient::process_index) + .def("host_id", &PyClient::process_index) + .def("task_id", &PyClient::process_index) + .def( + "buffer_from_pyval", + [](nb_class_ptr client, nb::handle argument, + PyDevice* device, bool force_copy, + xla::PjRtClient::HostBufferSemantics host_buffer_semantics) { + return xla::ValueOrThrow( + PyClient::BufferFromPyval(std::move(client), argument, + device ? device->device() : nullptr, + force_copy, host_buffer_semantics)); + }, + nb::arg("argument"), nb::arg("device").none() = nullptr, + nb::arg("force_copy") = false, + nb::arg("host_buffer_semantics") = + xla::PjRtClient::HostBufferSemantics::kImmutableZeroCopy) + .def( + "compile", + [](nb_class_ptr client, MlirModule mlir_module, + PyDeviceList& py_executable_devices, xla::CompileOptions options) { + ifrt::DeviceListRef executable_devices = + xla::ValueOrThrow(py_executable_devices.ifrt_device_list()); + return xla::ValueOrThrow(PyClient::Compile( + std::move(client), unwrap(mlir_module), + std::move(executable_devices), std::move(options))); + }, + nb::arg("computation"), nb::arg("executable_devices"), + nb::arg("compile_options") = xla::CompileOptions(), + nb::sig( + // clang-format off + "def compile(" + "self, " + "computation: object, " + "executable_devices: DeviceList, " + "compile_options: CompileOptions = ..." + ") -> Executable" + // clang-format on + )) + .def( + "compile_and_load", + [](nb_class_ptr client, MlirModule mlir_module, + PyDeviceList& py_executable_devices, xla::CompileOptions options, + std::vector host_callbacks) { + ifrt::DeviceListRef executable_devices = + xla::ValueOrThrow(py_executable_devices.ifrt_device_list()); + return xla::ValueOrThrow(PyClient::CompileAndLoad( + std::move(client), unwrap(mlir_module), + std::move(executable_devices), std::move(options), + std::move(host_callbacks))); + }, + nb::arg("computation"), nb::arg("executable_devices"), + nb::arg("compile_options") = xla::CompileOptions(), + nb::arg("host_callbacks") = std::vector(), + nb::sig( + // clang-format off + "def compile_and_load(" + "self, " + "computation: object, " + "executable_devices: DeviceList, " + "compile_options: CompileOptions = ..., " + "host_callbacks: Sequence[typing_extensions.CapsuleType] = ..." + ") -> LoadedExecutable" + // clang-format on + )) + .def( + "compile_and_load", + [](nb_class_ptr client, MlirModule mlir_module, + PyDeviceList& py_executable_devices, xla::CompileOptions options, + std::vector host_callbacks) { + ifrt::DeviceListRef executable_devices = + xla::ValueOrThrow(py_executable_devices.ifrt_device_list()); + return xla::ValueOrThrow(PyClient::CompileAndLoad( + std::move(client), unwrap(mlir_module), + std::move(executable_devices), std::move(options), + std::move(host_callbacks))); + }, + nb::arg("computation"), nb::arg("executable_devices"), + nb::arg("compile_options") = xla::CompileOptions(), + nb::arg("host_callbacks") = std::vector(), + nb::sig( + // clang-format off + "def compile_and_load(" + "self, " + "computation: object, " + "executable_devices: DeviceList, " + "compile_options: CompileOptions = ..., " + "host_callbacks: Sequence[Callable[..., typing.Any]] = ..." + ") -> LoadedExecutable" + // clang-format on + )) + // The following two overloads are for users of deprecated APIs who call + // `backend.compile` but do not have visibility to `DeviceList`. + .def( + "compile_and_load", + [](nb_class_ptr client, nb::bytes module_str, + nb::sequence& py_executable_devices, xla::CompileOptions options) { + mlir::MLIRContext context; + mlir::OwningOpRef module = + xla::ValueOrThrow(xla::ParseMlirModuleString( + std::string_view(module_str.c_str(), module_str.size()), + context)); + ifrt::DeviceListRef executable_devices = + xla::ValueOrThrow(PyDeviceList(nb::tuple(py_executable_devices)) + .ifrt_device_list()); + return xla::ValueOrThrow(PyClient::CompileAndLoad( + std::move(client), *module, std::move(executable_devices), + std::move(options), std::vector())); + }, + nb::arg("computation"), nb::arg("executable_devices"), + nb::arg("compile_options") = xla::CompileOptions()) + .def( + "compile_and_load", + [](nb_class_ptr client, std::string module_str, + nb::sequence& py_executable_devices, xla::CompileOptions options) { + mlir::MLIRContext context; + mlir::OwningOpRef module = xla::ValueOrThrow( + xla::ParseMlirModuleString(module_str, context)); + + ifrt::DeviceListRef executable_devices = + xla::ValueOrThrow(PyDeviceList(nb::tuple(py_executable_devices)) + .ifrt_device_list()); + return xla::ValueOrThrow(PyClient::CompileAndLoad( + std::move(client), *module, std::move(executable_devices), + std::move(options), std::vector())); + }, + nb::arg("computation"), nb::arg("executable_devices"), + nb::arg("compile_options") = xla::CompileOptions()) + .def("compile_ifrt_program", + xla::ValueOrThrowWrapper(PyClient::CompileAndLoadIfrtProgram)) + .def("compile_and_load_ifrt_program", + xla::ValueOrThrowWrapper(PyClient::CompileAndLoadIfrtProgram)) + .def("serialize_executable", + xla::ValueOrThrowWrapper(&PyClient::SerializeExecutable)) + .def( + "deserialize_executable", + [](nb_class_ptr client, nb::bytes serialized, + PyDeviceList& py_executable_devices, + std::optional options, + std::vector host_callbacks) { + ifrt::DeviceListRef executable_devices = + xla::ValueOrThrow(py_executable_devices.ifrt_device_list()); + return xla::ValueOrThrow(PyClient::DeserializeExecutable( + std::move(client), std::move(serialized), + std::move(executable_devices), std::move(options), + std::move(host_callbacks))); + }, + nb::arg("serialized"), nb::arg("executable_devices"), + nb::arg("compile_options").none() = nb::none(), + nb::arg("host_callbacks") = std::vector()) + .def( + "deserialize_executable", + [](nb_class_ptr client, nb::bytes serialized, + jax::PyDeviceList& py_executable_devices, + std::optional options, + std::vector host_callbacks) { + ifrt::DeviceListRef executable_devices = + xla::ValueOrThrow(py_executable_devices.ifrt_device_list()); + return xla::ValueOrThrow(PyClient::DeserializeExecutable( + std::move(client), std::move(serialized), + std::move(executable_devices), std::move(options), + std::move(host_callbacks))); + }, + nb::arg("serialized"), nb::arg("executable_devices"), + nb::arg("compile_options").none() = nb::none(), + nb::arg("host_callbacks") = std::vector()) + // The following overload is for users of deprecated APIs who call + // `deserialize_executable` but do not have visibility to `DeviceList`. + .def( + "deserialize_executable", + [](nb_class_ptr client, nb::bytes serialized, + nb::sequence& py_executable_devices, + std::optional options) { + ifrt::DeviceListRef executable_devices = + xla::ValueOrThrow(PyDeviceList(nb::tuple(py_executable_devices)) + .ifrt_device_list()); + return xla::ValueOrThrow(PyClient::DeserializeExecutable( + std::move(client), std::move(serialized), + std::move(executable_devices), std::move(options), + std::vector())); + }, + nb::arg("serialized"), nb::arg("executable_devices"), + nb::arg("compile_options").none() = nb::none()) + .def("heap_profile", xla::ValueOrThrowWrapper(&PyClient::HeapProfile)) + // TODO(zhangqiaorjc): Experimental. + .def("defragment", + [](PyClient& self) { xla::ThrowIfError(self.Defragment()); }) + .def("make_python_callback_from_host_send_and_recv", + xla::ValueOrThrowWrapper( + &PyClient::MakePythonCallbackUsingHostSendAndRecv), + nb::arg("callable"), nb::arg("operand_shapes"), + nb::arg("result_shapes"), nb::arg("send_channel_ids"), + nb::arg("recv_channel_ids"), + nb::arg("serializer").none() = nb::none()) + .def( + "get_default_layout", + [](PyClient& self, xla::nb_dtype dtype, nb::sequence shard_shape, + nb_class_ptr device) + -> std::shared_ptr { + ifrt::DType ifrt_type = xla::ValueOrThrow(DtypeToIfRtDType(dtype)); + std::vector dims = + xla::SequenceToVector(shard_shape); + return xla::ValueOrThrow(self.ifrt_client()->GetDefaultPjRtLayout( + ifrt_type, dims, device->device(), xla::ifrt::MemoryKind())); + }, + nb::arg("dtype"), nb::arg("shard_shape"), nb::arg("device")) + .def("__getattr__", + [](PyClient& client, std::string_view name) -> nb::object { + auto value = + client.Attributes().Get( + std::string(name)); + if (value.ok()) { + return std::visit([](auto&& v) { return nb::cast(v.value); }, + *value); + } + throw nb::attribute_error( + absl::StrCat("Unknown attribute ", name).c_str()); + }); +} + +} // namespace jax diff --git a/jaxlib/py_client.h b/jaxlib/py_client.h new file mode 100644 index 000000000000..b28c6077262a --- /dev/null +++ b/jaxlib/py_client.h @@ -0,0 +1,273 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_PY_CLIENT_H_ +#define JAXLIB_PY_CLIENT_H_ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/BuiltinOps.h" +#include "nanobind/nanobind.h" +#include "jaxlib/nb_class_ptr.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/python/ifrt/attribute_map.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/program.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/shape.h" + +namespace jax { + +class PyClient; +class PyLoadedExecutable; +class PyExecutable; +class PyArray; +class PyDevice; +class PyMemorySpace; +struct PyArray_Storage; + +// Python wrapper around xla::PjRtClient. +// We use a wrapper class to add Python-specific functionality. +class PyClient { + public: + static nb_class_ptr Make( + std::shared_ptr ifrt_client); + + // Do not call the constructor directly. Use `PyClient::Make` instead. + explicit PyClient(std::shared_ptr ifrt_client); + virtual ~PyClient(); + + xla::ifrt::Client* ifrt_client() const { return ifrt_client_.get(); } + const std::shared_ptr& shared_ptr_ifrt_client() const { + return ifrt_client_; + } + + // Short-term escape hatch to get xla::PjRtClient from PyClient. + // TODO(hyeontaek): Migrate all users of this method to be agnostic of PjRt. + xla::PjRtClient* pjrt_client() const { + auto* pjrt_client = llvm::dyn_cast_or_null( + ifrt_client_.get()); + if (pjrt_client == nullptr) { + throw xla::XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + return pjrt_client->pjrt_client(); + } + std::shared_ptr shared_ptr_pjrt_client() { + auto* pjrt_client = llvm::dyn_cast_or_null( + ifrt_client_.get()); + if (pjrt_client == nullptr) { + throw xla::XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + return pjrt_client->shared_ptr_pjrt_client(); + } + + // Legacy aliases. + std::shared_ptr shared_pjrt_client() { + return shared_ptr_pjrt_client(); + } + + std::string_view platform_name() const { + // TODO(phawkins): this is a temporary backwards compatibility shim. We + // changed the name PJRT reports for GPU platforms to "cuda" or "rocm", but + // we haven't yet updated JAX clients that expect "gpu". Migrate users and + // remove this code. + if (ifrt_client_->platform_name() == "cuda" || + ifrt_client_->platform_name() == "rocm") { + return "gpu"; + } else { + return ifrt_client_->platform_name(); + } + } + std::string_view raw_platform_name() const { + // TODO(parkers): Once platform_name() is the same, remove this. + return ifrt_client_->platform_name(); + } + std::string_view platform_version() const { + return ifrt_client_->platform_version(); + } + std::string_view runtime_type() const { return ifrt_client_->runtime_type(); } + + // Returns implementation-specific attributes about this client, e.g. the PJRT + // C API version if applicable. + const xla::ifrt::AttributeMap& Attributes() const { + return client_attributes_; + } + + int addressable_device_count() const { + return ifrt_client_->addressable_device_count(); + } + int device_count() const { return ifrt_client_->device_count(); } + int process_index() const { return ifrt_client_->process_index(); } + + std::vector> Devices(); + std::vector> LocalDevices(); + // Returns all devices in the client. Private API; only use this method for + // implementing backend._get_all_devices(). + // TODO(hyeontaek): Remove this method once we have a unified API for + // enumerating devices with different criteria. + std::vector> GetAllDevices(); + absl::StatusOr> DeviceFromLocalHardwareId( + int local_hardware_id); + + // Returns the PyDevice associated with the given xla::ifrt::Device. + nb_class_ptr GetPyDevice(xla::ifrt::Device* device); + + // Returns the PyMemorySpace associated with the given xla::ifrt::Memory. + nb_class_ptr GetPyMemorySpace(xla::ifrt::Memory* memory_space); + + // Returns a vector of live PyArray objects. PyArray objects may share + // PjRtBuffers, so there may be duplicates of the same underlying device + // buffer. + std::vector LiveBuffersOnDevice(xla::ifrt::Device* device); + + nanobind::typed LiveExecutables(); + + // TODO(zhangqiaorjc): Remove when we have transparent defragmentation. + absl::Status Defragment(); + + static absl::StatusOr BufferFromPyval( + nb_class_ptr client, nanobind::handle argument, + xla::ifrt::Device* device, bool force_copy, + xla::ifrt::Client::HostBufferSemantics host_buffer_semantics); + + static absl::StatusOr> + CompileAndLoadIfrtProgram( + nb_class_ptr client, + std::unique_ptr ifrt_program, + std::unique_ptr ifrt_options); + + static absl::StatusOr> Compile( + nb_class_ptr client, mlir::ModuleOp mlir_module, + xla::ifrt::DeviceListRef executable_devices, xla::CompileOptions options); + + static absl::StatusOr> CompileAndLoad( + nb_class_ptr client, mlir::ModuleOp mlir_module, + xla::ifrt::DeviceListRef executable_devices, xla::CompileOptions options, + std::vector host_callbacks); + + static absl::StatusOr> CompileAndLoad( + nb_class_ptr client, mlir::ModuleOp mlir_module, + xla::ifrt::DeviceListRef executable_devices, xla::CompileOptions options, + std::vector host_callbacks); + + absl::StatusOr SerializeExecutable( + const PyLoadedExecutable& executable) const; + static absl::StatusOr> DeserializeExecutable( + nb_class_ptr client, nanobind::bytes serialized, + xla::ifrt::DeviceListRef executable_devices, + std::optional options, + std::vector host_callbacks); + static absl::StatusOr> DeserializeExecutable( + nb_class_ptr client, nanobind::bytes serialized, + xla::ifrt::DeviceListRef executable_devices, + std::optional options, + std::vector host_callbacks); + + absl::StatusOr HeapProfile(); + + // `MakePythonCallbackUsingHostSendAndRecv` takes in an input Python callable + // that takes in arguments of shapes `operand_shapes` and returns results of + // shapes `result_shapes`. The arguments correspond to Send ops in the HLO + // program through `send_channel_ids` and the results correspond to Recv ops + // through `recv_channel_ids`. It returns the host callback as an opaque + // object whose reference will keep the Python callback alive. The host + // callback can be passed to `PyClient::CompileAndLoad` or + // `PyClient::DeserializeExecutable`. The corresponding Send/Recv ops in the + // XLA computation can trigger the execution of this host callback. + // `serializer` is a function that takes `callable` as an argument and returns + // a serialized callable as a string. + // + // The callable receives as arguments NumPy arrays for arguments with array + // types, and None for Token argument. The callable must return a tuple of + // either arrays or None values. + absl::StatusOr MakePythonCallbackUsingHostSendAndRecv( + nanobind::callable callable, absl::Span operand_shapes, + absl::Span result_shapes, + absl::Span send_channel_ids, + absl::Span recv_channel_ids, + nanobind::callable serializer); + + std::vector LiveArrays() const; + + static void Register(nanobind::module_& m); + + protected: + static void Initialize(nb_class_ptr client); + + private: + friend class PyLoadedExecutable; + friend class PyArray; + friend struct PyArray_Storage; + + static int tp_traverse(PyObject* self, visitproc visit, void* arg); + static int tp_clear(PyObject* self); + static PyType_Slot slots_[]; + + std::shared_ptr ifrt_client_; + xla::ifrt::AttributeMap client_attributes_; + // Pointers to intrusive doubly-linked lists of arrays and executables, used + // to iterate over all known objects when heap profiling. The list structure + // is protected by the GIL. + + nanobind::ft_mutex executables_mutex_; + // List guarded by executables_mutex_. + PyLoadedExecutable* executables_ = nullptr; + +#ifdef NB_FREE_THREADING + static constexpr size_t kNumArraysShards = 16; +#else + static constexpr size_t kNumArraysShards = 1; +#endif + struct ArraysShard { + mutable nanobind::ft_mutex mutex; + PyArray_Storage* arrays; + }; + std::array arrays_; + + absl::flat_hash_map> devices_; + absl::flat_hash_map> + memory_spaces_; +}; + +// Returns the execution stream id set for the current thread. +inline int64_t& GetExecutionStreamId() { + thread_local int64_t execution_stream_id = 0; + return execution_stream_id; +} + +} // namespace jax + +#endif // JAXLIB_PY_CLIENT_H_ diff --git a/jaxlib/py_client_cpu.cc b/jaxlib/py_client_cpu.cc new file mode 100644 index 000000000000..aaf7fd2faab0 --- /dev/null +++ b/jaxlib/py_client_cpu.cc @@ -0,0 +1,247 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/py_client_cpu.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/types/span.h" +#include "include/dlpack/dlpack.h" +#include "nanobind/nanobind.h" +#include "jaxlib/ffi.h" +#include "xla/ffi/api/ffi.h" +#include "xla/ffi/ffi_api.h" +#include "xla/pjrt/host_callback.h" +#include "xla/pjrt/transpose.h" +#include "xla/primitive_util.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/types.h" +#include "xla/shape_util.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace jax { + +struct CpuTransposePlanCache { + static ffi::TypeId id; + static ffi::TypeInfo info; + + explicit CpuTransposePlanCache(int capacity) : cache(capacity) {} + xla::TransposePlanCache cache; +}; + +ffi::TypeId CpuTransposePlanCache::id = {}; +ffi::TypeInfo CpuTransposePlanCache::info = + ffi::MakeTypeInfo(); + +XLA_FFI_REGISTER_TYPE(ffi::GetXlaFfiApi(), "CpuTransposePlanCache", + &CpuTransposePlanCache::id, &CpuTransposePlanCache::info); + +static ffi::ErrorOr> +CpuTransposePlanCacheInstantiate(uint64_t index) { + return std::make_unique(/*capacity=*/16); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + kCpuTransposePlanCacheInstantiate, CpuTransposePlanCacheInstantiate, + ffi::Ffi::BindInstantiate().Attr("index")); + +ffi::Error XlaFfiPythonCpuCallback(xla::FfiLoadedHostCallbacks* callbacks, + CpuTransposePlanCache* transpose_cache, + uint64_t index, ffi::RemainingArgs args, + ffi::RemainingRets rets) { + nb::gil_scoped_acquire gil; + auto callback = nb::borrow( + static_cast(callbacks->callbacks[index])); + auto nb_args = nb::steal(PyTuple_New(args.size())); + for (size_t i = 0; i < args.size(); ++i) { + auto arg = args.get(i); + auto ptype = static_cast(arg->element_type()); + // TODO(b/395428868): Remove this check once we support subbyte types. + if (ptype == xla::S1 || ptype == xla::U1) { + return ffi::Error(ffi::ErrorCode::kUnimplemented, + absl::StrFormat("Unsupported primitive type: %s", + xla::PrimitiveType_Name(ptype))); + } + if (ptype == xla::TOKEN) { + PyTuple_SET_ITEM(nb_args.ptr(), i, nb::none().release().ptr()); + continue; + } + auto maybe_dtype = xla::PrimitiveTypeToNbDtype(ptype); + if (!maybe_dtype.ok()) { + return ffi::Error::Internal(maybe_dtype.status().ToString()); + } + auto dtype = maybe_dtype.value(); + auto dims = absl::Span(arg->dimensions().begin(), + arg->dimensions().size()); + // TODO(b/402422886): Remove this once we form Jax arrays directly instead + std::unique_ptr buffer; + const void* data = arg->untyped_data(); + size_t bits_per_element = xla::primitive_util::BitWidth(ptype); + if (bits_per_element == 2 || bits_per_element == 4) { + // NOTE(dsuo): FFI arguments and return buffers are sized assuming + // minimum 1-byte element sizes, even if the data itself is packed. We + // assume that 2-bit and 4-bit types are packed. + size_t size_bytes = arg->element_count() * bits_per_element / 8; + buffer = xla::UnpackIntN(bits_per_element, static_cast(data), + size_bytes); + data = buffer.get(); + } + // We pass in data using default numpy layout i.e., std::nullopt. + auto array = xla::nb_numpy_ndarray(dtype, dims, std::nullopt, data); + array.attr("flags").attr("writeable") = nb::bool_(false); + PyTuple_SET_ITEM(nb_args.ptr(), i, array.release().ptr()); + } + + // TODO(dsuo): Change this to use the Python vectorcall protocol, which allows + // you to avoid constructing a tuple for the arguments. + nb::tuple result_tuple; + { + xla::HostCallbackScope scope; + try { + auto result_object = callback(*nb::borrow(nb_args)); + result_tuple = nb::cast(result_object); + } catch (nb::python_error& e) { + return ffi::Error::Internal( + absl::StrFormat("CpuCallback error calling callback: %s", e.what())); + } + } + + for (size_t i = 0; i < rets.size(); ++i) { + auto ret = rets.get(i).value(); + auto ptype = static_cast(ret->element_type()); + // TODO(b/402422886): Remove this once we form Jax arrays directly instead + if (ptype == xla::S1 || ptype == xla::U1) { + return ffi::Error(ffi::ErrorCode::kUnimplemented, + absl::StrFormat("Unsupported primitive type: %s", + xla::PrimitiveType_Name(ptype))); + } + if (ptype == xla::TOKEN) continue; + nb::object output = + nb::borrow(PyTuple_GetItem(result_tuple.ptr(), i)); + xla::nb_numpy_ndarray array = + xla::nb_numpy_ndarray::ensure(std::move(output)); + absl::Span strides( + reinterpret_cast(array.strides()), array.ndim()); + // We expect the output to be in default numpy layout. + auto dims = absl::Span(ret->dimensions().begin(), + ret->dimensions().size()); + auto maybe_expected_shape = xla::ShapeUtil::MakeValidatedShape(ptype, dims); + if (!maybe_expected_shape.ok()) { + return ffi::Error::Internal(maybe_expected_shape.status().ToString()); + } + auto expected_shape = maybe_expected_shape.value(); + auto expected_strides = ByteStridesForShape(expected_shape); + + const void* data = array.data(); + std::unique_ptr buffer; + size_t bits_per_element = xla::primitive_util::BitWidth(ptype); + size_t size_bytes = array.size() * array.itemsize(); + if (strides != expected_strides) { + xla::TransposePlan::Options options; + options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); + options.dims = absl::Span( + reinterpret_cast(array.shape()), array.ndim()); + absl::InlinedVector reversed_layout; + reversed_layout.resize(expected_shape.dimensions().size()); + absl::c_reverse_copy(expected_shape.layout().minor_to_major(), + reversed_layout.begin()); + options.permutation = reversed_layout; + options.input_striding = xla::TransposePlan::Striding{strides}; + auto maybe_plan = transpose_cache->cache.GetOrCreate(options); + if (!maybe_plan.ok()) { + return ffi::Error::Internal(maybe_plan.status().ToString()); + } + auto plan = maybe_plan.value(); + if (bits_per_element == 2 || bits_per_element == 4) { + // NOTE(dsuo): If the data needs to be unpacked, don't use return buffer + // supplied by FFI directly. + buffer = std::make_unique(size_bytes); + plan->Execute(data, buffer.get()); + data = buffer.get(); + } else { + plan->Execute(data, ret->untyped_data()); + data = ret->untyped_data(); + } + } + + // TODO(b/402422886): Remove this once we form Jax arrays directly instead + // of packing/unpacking to/from numpy arrays. + if (bits_per_element == 2 || bits_per_element == 4) { + // NOTE(dsuo): FFI arguments and return buffers are sized assuming + // minimum 1-byte element sizes, even if the data itself is packed. We + // assume that 2-bit and 4-bit types are packed. + buffer = xla::PackIntN(bits_per_element, static_cast(data), + size_bytes); + data = buffer.get(); + size_bytes = (size_bytes * bits_per_element) / 8; + } + + // Copy data to output buffer if haven't already or modified the data to + // write back. + if (data != ret->untyped_data()) { + std::memcpy(ret->untyped_data(), data, size_bytes); + } + } + + return ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + kXlaFfiPythonCpuCallback, XlaFfiPythonCpuCallback, + ffi::Ffi::Bind() + .Ctx>() + .Ctx>() + .Attr("index") + .RemainingArgs() + .RemainingRets()); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "xla_ffi_python_cpu_callback", + "HOST", + {kCpuTransposePlanCacheInstantiate, nullptr, nullptr, + kXlaFfiPythonCpuCallback}); +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), + "xla_ffi_partitioned_python_cpu_callback", "HOST", + {kCpuTransposePlanCacheInstantiate, nullptr, nullptr, + kXlaFfiPythonCpuCallback}); + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + kXlaBufferPythonCpuCallback, (XlaBufferCallback), + ffi::Ffi::Bind() + .Ctx() + .Ctx() + .Ctx>() + .Attr("index") + .RemainingArgs() + .RemainingRets()); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "xla_buffer_python_cpu_callback", + "HOST", kXlaBufferPythonCpuCallback); + +} // namespace jax diff --git a/jaxlib/py_client_cpu.h b/jaxlib/py_client_cpu.h new file mode 100644 index 000000000000..93a2770b35e6 --- /dev/null +++ b/jaxlib/py_client_cpu.h @@ -0,0 +1,28 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_PY_CLIENT_CPU_H_ +#define JAXLIB_PY_CLIENT_CPU_H_ + +#include "xla/ffi/api/ffi.h" + +namespace jax { + +XLA_FFI_DECLARE_HANDLER_SYMBOL(kCpuTransposePlanCacheInstantiate); +XLA_FFI_DECLARE_HANDLER_SYMBOL(kXlaFfiPythonCpuCallback); + +} // namespace jax + +#endif // JAXLIB_PY_CLIENT_CPU_H_ diff --git a/jaxlib/py_compile_only_client.cc b/jaxlib/py_compile_only_client.cc new file mode 100644 index 000000000000..8c6bb192ed56 --- /dev/null +++ b/jaxlib/py_compile_only_client.cc @@ -0,0 +1,125 @@ +/* Copyright 2023 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/py_compile_only_client.h" + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "llvm/Support/Casting.h" +#include "mlir-c/IR.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" // IWYU pragma: keep +#include "mlir/CAPI/IR.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/OwningOpRef.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" +#include "jaxlib/py_device_list.h" +#include "jaxlib/py_executable.h" +#include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/compile_only_ifrt/client.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/hlo/hlo_program.h" +#include "xla/python/pjrt_ifrt/pjrt_executable.h" +#include "xla/python/pjrt_ifrt/pjrt_topology.h" +#include "xla/python/pjrt_ifrt/xla_compiler.h" +#include "xla/python/version.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/python/lib/core/numpy.h" +#include "xla/xla_data.pb.h" + +namespace ifrt = xla::ifrt; +namespace nb = nanobind; + +namespace jax { + +nb_class_ptr CompileOnlyPyClient::Make( + std::shared_ptr topology) { + auto client = + nb::borrow>(make_nb_class( + std::make_unique(std::move(topology)))); + CompileOnlyPyClient::Initialize(client); + return client; +} + +absl::StatusOr> CompileOnlyPyClient::CompileUnloaded( + MlirModule mlir_module, xla::ifrt::DeviceListRef executable_devices, + xla::CompileOptions options) { + mlir::ModuleOp module = unwrap(mlir_module); + mlir::OwningOpRef clone(module.clone()); + module = *clone; + ifrt::ExecutableRef ifrt_executable; + { + nb::gil_scoped_release gil_release; + auto* ifrt_client = + llvm::dyn_cast_or_null(this->ifrt_client()); + CHECK(ifrt_client) << "CompileOnlyPyClient requires ifrt_client be a " + "xla::CompileOnlyIfRtClient"; + + auto xla_options = std::make_unique( + options, std::move(executable_devices)); + TF_ASSIGN_OR_RETURN( + ifrt_executable, + ifrt_client->GetDefaultCompiler()->Compile( + std::make_unique(std::move(module)), + ifrt_client->topology(), std::move(xla_options))); + } + return make_nb_class(ifrt_executable); +} + +void CompileOnlyPyClient::Initialize(nb_class_ptr client) { + PyClient::Initialize(client); +} + +void CompileOnlyPyClient::Register(nb::module_& m) { + nb::class_(m, "CompileOnlyPyClient") + .def( + "compile", + [](CompileOnlyPyClient& self, MlirModule mlir_module, + PyDeviceList& py_executable_devices, xla::CompileOptions options, + std::vector host_callbacks) { + ifrt::DeviceListRef executable_devices = + xla::ValueOrThrow(py_executable_devices.ifrt_device_list()); + return xla::ValueOrThrow( + self.CompileUnloaded(mlir_module, std::move(executable_devices), + std::move(options))); + }, + nb::arg("computation"), nb::arg("executable_devices"), + nb::arg("compile_options") = xla::CompileOptions(), + nb::arg("host_callbacks") = std::vector(), + nb::sig( + // clang-format off + "def compile(" + "self, " + "computation: object, " + "executable_devices: DeviceList, " + "compile_options: CompileOptions = ..., " + "host_callbacks: Sequence[typing_extensions.CapsuleType] = ..." + ") -> Executable" + // clang-format on + )); +} + +} // namespace jax diff --git a/jaxlib/py_compile_only_client.h b/jaxlib/py_compile_only_client.h new file mode 100644 index 000000000000..297830348e25 --- /dev/null +++ b/jaxlib/py_compile_only_client.h @@ -0,0 +1,63 @@ +/* Copyright 2023 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_PY_COMPILE_ONLY_CLIENT_H_ +#define JAXLIB_PY_COMPILE_ONLY_CLIENT_H_ + +#include + +// placeholder for index annotation headers +#include "absl/status/statusor.h" +#include "mlir-c/IR.h" +#include "nanobind/nanobind.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/pjrt_ifrt/pjrt_topology.h" +#include "xla/xla_data.pb.h" + +namespace jax { +class PyExecutable; + +// This is a workaround for AOT compilation until topologies and device +// descriptions are better integrated into jax's Python code. It returns a +// PyClient that will return errors for all non-AOT methods. It also exposes a +// different compile method that returns an unloaded executable (vs. PyClient +// usually returns a loaded executable). RegisterCompileOnlyClient() overloads +// the Python "compile" method to return the unloaded executable, and we rely on +// Python duck typing to treat the unloaded executable like a loaded executable +// (except it will raise errors if you try to run it, which is what we want for +// AOT environments). +class CompileOnlyPyClient : public PyClient { + public: + using PyClient::PyClient; + + static nb_class_ptr Make( + std::shared_ptr topology); + + absl::StatusOr> CompileUnloaded( + MlirModule mlir_module, xla::ifrt::DeviceListRef executable_devices, + xla::CompileOptions options); + + static void Register(nanobind::module_& m); + + private: + static void Initialize(nb_class_ptr client); +}; + +} // namespace jax + +#endif // JAXLIB_PY_COMPILE_ONLY_CLIENT_H_ diff --git a/jaxlib/py_device.cc b/jaxlib/py_device.cc new file mode 100644 index 000000000000..c1641fce57e4 --- /dev/null +++ b/jaxlib/py_device.cc @@ -0,0 +1,311 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/py_device.h" + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "llvm/Support/Casting.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" +#include "jaxlib/py_memory_space.h" +#include "jaxlib/python_ref_manager.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/attribute_map.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/nb_helpers.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/python/pjrt_ifrt/pjrt_device.h" +#include "xla/python/version.h" +#include "xla/tsl/framework/allocator.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" + +namespace ifrt = ::xla::ifrt; +namespace nb = ::nanobind; + +namespace jax { + +PyDevice::PyDevice(nb_class_ptr client, ifrt::Device* device) + : client_(std::move(client)), device_(device) {} + +int PyDevice::id() const { return device_->Id().value(); } + +int PyDevice::process_index() const { return device_->ProcessIndex(); } + +std::string_view PyDevice::platform() const { + // TODO(phawkins): this is a temporary backwards + // compatibility shim. We changed the name PJRT + // reports for GPU platforms to "cuda" or "rocm", + // but we haven't yet updated JAX clients that + // expect "gpu". Migrate users and remove this + // code. + absl::string_view platform_name = device_->PlatformName(); + if (platform_name == "cuda" || platform_name == "rocm") { + return std::string_view("gpu"); + } else { + return platform_name; + } +} + +std::string_view PyDevice::device_kind() const { return device_->Kind(); } + +std::optional PyDevice::local_hardware_id() const { + // TODO(phawkins): consider supporting this for non-PJRT devices. + ifrt::PjRtDevice* device = llvm::dyn_cast(device_); + if (device == nullptr || !device->IsAddressable()) { + return std::nullopt; + } + int local_hardware_id = device->pjrt_device()->local_hardware_id().value(); + if (local_hardware_id == -1) { + return std::nullopt; + } + return local_hardware_id; +} + +std::string_view PyDevice::Str() const { return device_->DebugString(); } + +std::string_view PyDevice::Repr() const { return device_->ToString(); } + +absl::StatusOr> PyDevice::Memory( + std::string_view kind) const { + ifrt::Memory* result_memory_space = nullptr; + for (auto* memory_space : device_->Memories()) { + if (memory_space->Kind().memory_kind() == kind) { + if (result_memory_space != nullptr) { + std::string memories = absl::StrJoin( + device_->Memories(), ", ", + [](std::string* out, const auto& memory_space) { + absl::StrAppend(out, *memory_space->Kind().memory_kind()); + }); + auto device_kind = device_->Kind(); + return xla::InvalidArgument( + "Found more than one addressable memory for " + "kind %s which is not allowed. There can only " + "be one memory for each " + "kind. Device %s can address the following " + "memory kinds: %s", + kind, device_kind, memories); + } + result_memory_space = memory_space; + } + } + if (result_memory_space == nullptr) { + std::string memories = absl::StrJoin( + device_->Memories(), ", ", + [](std::string* out, const auto& memory_space) { + absl::StrAppend(out, *memory_space->Kind().memory_kind()); + }); + auto device_kind = device_->Kind(); + return xla::InvalidArgument( + "Could not find memory addressable by device %s. Device %s " + "can address the following memory kinds: %s. " + "Got memory kind: %s", + device_kind, device_kind, memories, kind); + } + return client_->GetPyMemorySpace(result_memory_space); +} + +absl::StatusOr> PyDevice::DefaultMemory() const { + TF_ASSIGN_OR_RETURN(auto* memory_space, device_->DefaultMemory()); + return client_->GetPyMemorySpace(memory_space); +} + +nb::typed PyDevice::AddressableMemories() const { + nb::list memory_spaces; + for (auto* memory_space : device_->Memories()) { + memory_spaces.append(client_->GetPyMemorySpace(memory_space)); + } + return memory_spaces; +} + +absl::StatusOr>> +PyDevice::MemoryStats() const { + GlobalPyRefManager()->CollectGarbage(); + ifrt::PjRtDevice* device = llvm::dyn_cast(device_); + if (device == nullptr || !device->IsAddressable()) { + return xla::InvalidArgument( + "MemoryStats is only supported for addressable PjRt devices."); + } + absl::StatusOr maybe_stats = + device->pjrt_device()->GetAllocatorStats(); + if (absl::IsUnimplemented(maybe_stats.status())) { + return std::nullopt; + } + // Raise error if any status other than Unimplemented is returned. + xla::ThrowIfError(maybe_stats.status()); + + nb::dict result; + result["num_allocs"] = maybe_stats->num_allocs; + result["bytes_in_use"] = maybe_stats->bytes_in_use; + result["peak_bytes_in_use"] = maybe_stats->peak_bytes_in_use; + result["largest_alloc_size"] = maybe_stats->largest_alloc_size; + if (maybe_stats->bytes_limit) { + result["bytes_limit"] = *maybe_stats->bytes_limit; + } + result["bytes_reserved"] = maybe_stats->bytes_reserved; + result["peak_bytes_reserved"] = maybe_stats->peak_bytes_reserved; + if (maybe_stats->bytes_reservable_limit) { + result["bytes_reservable_limit"] = *maybe_stats->bytes_reservable_limit; + } + result["largest_free_block_bytes"] = maybe_stats->largest_free_block_bytes; + if (maybe_stats->pool_bytes) { + result["pool_bytes"] = *maybe_stats->pool_bytes; + } + if (maybe_stats->peak_pool_bytes) { + result["peak_pool_bytes"] = *maybe_stats->peak_pool_bytes; + } + return result; +} + +absl::StatusOr PyDevice::GetStreamForExternalReadyEvents() + const { + ifrt::PjRtDevice* device = llvm::dyn_cast(device_); + if (device == nullptr || !device->IsAddressable()) { + return xla::InvalidArgument( + "GetStreamForExternalReadyEvents is only supported for addressable " + "PjRt devices."); + } + return device->pjrt_device()->GetStreamForExternalReadyEvents(); +} + +/* static */ int PyDevice::tp_traverse(PyObject* self, visitproc visit, + void* arg) { + PyDevice* d = nb::inst_ptr(self); + Py_VISIT(d->client().ptr()); + return 0; +} + +/* static */ int PyDevice::tp_clear(PyObject* self) { + PyDevice* d = nb::inst_ptr(self); + nb_class_ptr client; + std::swap(client, d->client_); + return 0; +} + +PyType_Slot PyDevice::slots_[] = { + {Py_tp_traverse, (void*)PyDevice::tp_traverse}, + {Py_tp_clear, (void*)PyDevice::tp_clear}, + {0, nullptr}, +}; + +/* static */ void PyDevice::Register(nb::module_& m) { + nb::class_ device( + m, "Device", nb::type_slots(PyDevice::slots_), + "A descriptor of an available device.\n\nSubclasses are used to " + "represent specific types of devices, e.g. CPUs, GPUs. Subclasses may " + "have additional properties specific to that device type."); + device + .def_prop_ro( + "id", &PyDevice::id, + "Integer ID of this device.\n\nUnique across all available devices " + "of this type, including remote devices on multi-host platforms.") + .def_prop_ro("process_index", &PyDevice::process_index, + "Integer index of this device's process.\n\n" + "This is always 0 except on multi-process platforms.") + .def_prop_ro("host_id", &PyDevice::process_index, + "Deprecated; please use process_index") + .def_prop_ro("task_id", &PyDevice::process_index, + "Deprecated; please use process_index") + .def_prop_ro("platform", &PyDevice::platform) + .def_prop_ro("device_kind", &PyDevice::device_kind) + .def_prop_ro("client", &PyDevice::client) + .def_prop_ro( + "local_hardware_id", &PyDevice::local_hardware_id, + "Opaque hardware ID, e.g., the CUDA device number. In general, not " + "guaranteed to be dense, and not guaranteed to be defined on all " + "platforms.") + .def("__str__", &PyDevice::Str) + .def("__repr__", &PyDevice::Repr) + .def("memory", xla::ValueOrThrowWrapper(&PyDevice::Memory), + nb::arg("kind")) + .def("default_memory", xla::ValueOrThrowWrapper(&PyDevice::DefaultMemory), + "Returns the default memory of a device.") + .def("addressable_memories", &PyDevice::AddressableMemories, + "Returns all the memories that a device can address.") + + .def("live_buffers", + [](nb::handle device) { + xla::PythonDeprecationWarning( + /*stacklevel=*/1, + "Per device live_buffers() is deprecated. Please " + "use the jax.live_arrays() for jax.Arrays instead."); + return nb::list(); + }) + .def( + "memory_stats", xla::ValueOrThrowWrapper(&PyDevice::MemoryStats), + "Returns memory statistics for this device keyed by name. May not " + "be implemented on all platforms, and different platforms may return " + "different stats, or -1 for unavailable stats. 'bytes_in_use' is " + "usually available. Intended for diagnostic use.") + .def( + "get_stream_for_external_ready_events", + xla::ValueOrThrowWrapper(&PyDevice::GetStreamForExternalReadyEvents)); + static PyMethodDef get_attr_method = { + "__getattr__", + +[](PyObject* self, PyObject* args) -> PyObject* { + PyObject* key; + if (!PyArg_ParseTuple(args, "O", &key)) { + PyErr_SetString(PyExc_TypeError, "__getattr__ must take 1 argument."); + return nullptr; + } + try { + auto device = nb::cast(nb::handle(self)); + auto name = nb::cast(nb::handle(key)); + auto value = + device->device_->Attributes().Get( + name); + if (value.ok()) { + auto result = + std::visit([](auto&& v) { return nb::cast(v.value); }, *value); + return result.release().ptr(); + } + PyErr_SetNone(PyExc_AttributeError); + return nullptr; + } catch (std::exception& e) { + PyErr_Format(PyExc_SystemError, "Unhandled nanobind exception: %s", + e.what()); + return nullptr; + } catch (...) { + PyErr_SetString(PyExc_SystemError, "Unhandled nanobind exception."); + return nullptr; + } + }, + METH_VARARGS, + nullptr, + }; + device.attr("__getattr__") = nb::steal(PyDescr_NewMethod( + reinterpret_cast(device.ptr()), &get_attr_method)); +} + +} // namespace jax diff --git a/jaxlib/py_device.h b/jaxlib/py_device.h new file mode 100644 index 000000000000..70d16cd64e15 --- /dev/null +++ b/jaxlib/py_device.h @@ -0,0 +1,82 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_PY_DEVICE_H_ +#define JAXLIB_PY_DEVICE_H_ + +#include + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "nanobind/nanobind.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" +#include "xla/literal.h" +#include "xla/python/ifrt/device.h" +#include "xla/shape.h" + +namespace jax { + +class PyDevice { + public: + PyDevice(nb_class_ptr client, xla::ifrt::Device* device); + + // Devices are compared using Python object identity, so we don't allow them + // to be copied or moved. + PyDevice(const PyDevice&) = delete; + PyDevice(PyDevice&&) = delete; + PyDevice& operator=(const PyDevice&) = delete; + PyDevice& operator=(PyDevice&&) = delete; + + const nb_class_ptr& client() const { return client_; } + xla::ifrt::Device* device() const { return device_; } + + int id() const; + int process_index() const; + std::string_view platform() const; + std::string_view device_kind() const; + std::optional local_hardware_id() const; + + std::string_view Str() const; + std::string_view Repr() const; + + absl::StatusOr> Memory( + std::string_view kind) const; + absl::StatusOr> DefaultMemory() const; + nanobind::typed AddressableMemories() const; + absl::StatusOr< + std::optional>> + MemoryStats() const; + + absl::StatusOr GetStreamForExternalReadyEvents() const; + + static void Register(nanobind::module_& m); + + private: + static int tp_traverse(PyObject* self, visitproc visit, void* arg); + static int tp_clear(PyObject* self); + static PyType_Slot slots_[]; + + nb_class_ptr client_; + xla::ifrt::Device* device_; +}; + +} // namespace jax + +#endif // JAXLIB_PY_DEVICE_H_ diff --git a/jaxlib/py_device_list.cc b/jaxlib/py_device_list.cc new file mode 100644 index 000000000000..9ef786f72bc8 --- /dev/null +++ b/jaxlib/py_device_list.cc @@ -0,0 +1,506 @@ +/* Copyright 2023 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/py_device_list.h" + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/log/check.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "nanobind/make_iterator.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/set.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" +#include "jaxlib/py_device.h" +#include "jaxlib/python_ref_manager.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/types.h" +#include "xla/util.h" + +namespace jax { + +namespace nb = ::nanobind; + +PyDeviceList::PyDeviceList(nb_class_ptr py_client, + xla::ifrt::DeviceListRef device_list) + : py_client_(std::move(py_client)), device_list_(std::move(device_list)) {} + +PyDeviceList::PyDeviceList(nb::tuple py_device_assignment) + : device_list_(py_device_assignment) { + // Attempt to convert to Python devices into `ifrt::DeviceList`. + if (py_device_assignment.size() == 0) { + return; + } + absl::InlinedVector devices; + devices.reserve(py_device_assignment.size()); + for (nb::handle obj : py_device_assignment) { + if (!nb::isinstance(obj.ptr())) { + // Non-`PyDevice` is used on an alternative JAX backend with device + // duck typing. Use Python device objects already set in `device_list_`. + return; + } + auto py_device = nb::cast(obj); + if (py_client_.get() == nullptr) { + py_client_ = py_device->client(); + } else if (py_device->client().get() != py_client_.get()) { + // If the list contains multiple clients, fall back to device duck typing. + return; + } + devices.push_back(py_device->device()); + } + device_list_ = + xla::ValueOrThrow(py_client_->ifrt_client()->MakeDeviceList(devices)); +} + +PyDeviceList::~PyDeviceList() { + if (device_list_.index() == 1) { + GlobalPyRefManager()->AddGarbage( + std::move(std::get<1>(std::move(device_list_)))); + } +} + +absl::StatusOr PyDeviceList::ifrt_device_list() + const { + switch (device_list_.index()) { + case 0: + return std::get<0>(device_list_); + case 1: + return xla::InvalidArgument("DeviceList contains non-IFRT devices"); + default: + return xla::InvalidArgument("Unrecognized DeviceList type"); + } +} + +int64_t PyDeviceList::Hash() { + if (!hash_.has_value()) { + switch (device_list_.index()) { + case 0: + hash_ = absl::HashOf(std::get<0>(device_list_)); + break; + case 1: + hash_ = nb::hash(std::get<1>(device_list_)); + break; + default: + throw nb::value_error("Unrecognized DeviceList type"); + } + } + return *hash_; +} + +/*static*/ bool PyDeviceList::Equal(nb_class_ptr self, + nb::handle other) { + if (!nb::isinstance(other)) { + return false; + } + auto o = nb::cast(other); + // Fast-path using a pointer equality check. + if (self.get() == o) { + return true; + } + int64_t h1, h2; + { + nb::ft_object_guard lock(self); + h1 = self->Hash(); + } + { + nb::ft_object_guard lock(other); + h2 = o->Hash(); + } + if (h1 != h2) { + return false; + } + if (self->device_list_.index() == 0 && o->device_list_.index() == 0) { + nb::gil_scoped_release gil_release; + return *std::get<0>(self->device_list_) == *std::get<0>(o->device_list_); + } else { + return self->AsTuple().equal(o->AsTuple()); + } +} + +/*static*/ bool PyDeviceList::NotEqual(nb_class_ptr self, + nb::handle other) { + return !Equal(std::move(self), other); +} + +int PyDeviceList::Len() const { + switch (device_list_.index()) { + case 0: + return std::get<0>(device_list_)->size(); + case 1: + return nb::len(std::get<1>(device_list_)); + default: + throw nb::value_error("Unrecognized DeviceList type"); + } +} + +nb::object PyDeviceList::GetItem(int index) { + switch (device_list_.index()) { + case 0: { + const xla::ifrt::DeviceListRef& device_list = std::get<0>(device_list_); + if (index < -device_list->size() || index >= device_list->size()) { + throw nb::index_error(); + } else if (index < 0) { + index += device_list->size(); + } + return py_client_->GetPyDevice(device_list->devices()[index]); + } + case 1: + return std::get<1>(device_list_).attr("__getitem__")(index); + default: + throw nb::value_error("Unrecognized DeviceList type"); + } +} + +nb::object PyDeviceList::GetSlice(nb::slice slice) { + switch (device_list_.index()) { + case 0: { + const xla::ifrt::DeviceListRef& device_list = std::get<0>(device_list_); + const absl::Span devices = + device_list->devices(); + Py_ssize_t start, stop, step, slicelength; + if (PySlice_GetIndicesEx(slice.ptr(), devices.size(), &start, &stop, + &step, &slicelength) != 0) { + throw nb::python_error(); + } + nb::tuple out = nb::steal(PyTuple_New(slicelength)); + for (size_t i = 0; i < slicelength; ++i) { + nb::object d = py_client_->GetPyDevice(devices[start]); + PyTuple_SET_ITEM(out.ptr(), i, d.release().ptr()); + start += step; + } + return std::move(out); + } + case 1: + return std::get<1>(device_list_).attr("__getitem__")(slice); + default: + throw nb::value_error("Unrecognized DeviceList type"); + } +} + +nb::tuple PyDeviceList::AsTuple() const { + switch (device_list_.index()) { + case 0: { + const xla::ifrt::DeviceListRef& device_list = std::get<0>(device_list_); + nb::tuple out = nb::steal(PyTuple_New(device_list->size())); + int i = 0; + for (xla::ifrt::Device* device : device_list->devices()) { + nb::object d = py_client_->GetPyDevice(device); + PyTuple_SET_ITEM(out.ptr(), i, d.release().ptr()); + ++i; + } + return out; + } + case 1: + return std::get<1>(device_list_); + default: + throw nb::value_error("Unrecognized DeviceList type"); + } +} + +nb::iterator PyDeviceList::Iter() { + switch (device_list_.index()) { + case 0: { + // Iterator whose deference converts `xla::ifrt::Device*` into JAX + // `PjRtDevice`. + struct Iterator { + void operator++() { ++it; } + bool operator==(const Iterator& other) const { return it == other.it; } + nb_class_ptr operator*() const { + return py_client->GetPyDevice(*it); + } + nb_class_ptr py_client; + absl::Span::const_iterator it; + }; + return nb::make_iterator( + nb::type(), "ifrt_device_iterator", + Iterator{py_client_, std::get<0>(device_list_)->devices().cbegin()}, + Iterator{py_client_, std::get<0>(device_list_)->devices().cend()}); + } + case 1: + return nb::make_iterator( + nb::type(), "python_device_iterator", + std::get<1>(device_list_).begin(), std::get<1>(device_list_).end()); + default: + throw nb::value_error("Unrecognized DeviceList type"); + } +} + +std::string PyDeviceList::Str() { + return nb::cast(nb::str(AsTuple())); +} + +nb::tuple PyDeviceList::Dump() const { return AsTuple(); } + +bool PyDeviceList::IsFullyAddressable() { + if (!is_fully_addressable_.has_value()) { + ProcessIndices(); + CHECK(process_indices_.has_value()); + if (process_indices_->size() > 1) { + is_fully_addressable_ = false; + } else { + CHECK_EQ(process_indices_->size(), 1); + int process_index; + switch (device_list_.index()) { + case 0: { + process_index = py_client_ ? py_client_->process_index() : 0; + break; + } + case 1: { + process_index = + nb::cast(std::get<1>(device_list_)[0].attr("client").attr( + "process_index")()); + break; + } + default: + throw nb::value_error("Unrecognized DeviceList type"); + } + is_fully_addressable_ = *process_indices_->begin() == process_index; + } + } + return *is_fully_addressable_; +} + +/*static*/ nb_class_ptr PyDeviceList::AddressableDeviceList( + nb_class_ptr self) { + nb::ft_object_guard lock(self); + if (self->IsFullyAddressable()) { + // Do not cache this result in `addressable_device_list_`. Otherwise, it + // will create a cycle that prevents deletion of this object. + return self; + } + if (!self->addressable_device_list_.has_value()) { + switch (self->device_list_.index()) { + case 0: { + absl::InlinedVector addressable_devices; + const int process_index = + self->py_client_ ? self->py_client_->process_index() : 0; + for (xla::ifrt::Device* device : + std::get<0>(self->device_list_)->devices()) { + if (device->ProcessIndex() == process_index) { + addressable_devices.push_back(device); + } + } + self->addressable_device_list_ = make_nb_class( + self->py_client_, + xla::ValueOrThrow(self->py_client_->ifrt_client()->MakeDeviceList( + addressable_devices))); + break; + } + case 1: { + auto device_list = std::get<1>(self->device_list_); + std::vector addressable_devices; + for (size_t i = 0; i < device_list.size(); ++i) { + nb::object device = device_list[i]; + if (nb::cast(device.attr("process_index")) == + nb::cast(device.attr("client").attr("process_index")())) { + addressable_devices.push_back(std::move(device)); + } + } + self->addressable_device_list_ = make_nb_class( + xla::MutableSpanToNbTuple(absl::MakeSpan(addressable_devices))); + break; + } + default: + throw nb::value_error("Unrecognized DeviceList type"); + } + } + return nb::cast>(*self->addressable_device_list_); +} + +const std::set& PyDeviceList::ProcessIndices() { + if (!process_indices_.has_value()) { + process_indices_ = std::set{}; + switch (device_list_.index()) { + case 0: { + for (const xla::ifrt::Device* device : + std::get<0>(device_list_)->devices()) { + process_indices_->insert(device->ProcessIndex()); + } + break; + } + case 1: { + for (nb::handle device : std::get<1>(device_list_)) { + process_indices_->insert(nb::cast(device.attr("process_index"))); + } + break; + } + default: + throw nb::value_error("Unrecognized DeviceList type"); + } + } + return *process_indices_; +} + +const std::string& PyDeviceList::DeviceKind() { + if (!device_kind_.has_value()) { + auto device_list = ifrt_device_list(); + if (!device_list.ok()) { + throw nb::value_error(device_list.status().ToString().c_str()); + } + if (Len() == 0) { + throw nb::value_error("DeviceList is empty"); + } + device_kind_ = (*device_list)->devices()[0]->Kind(); + } + return *device_kind_; +} + +void PyDeviceList::PopulateMemoryKindInfo() { + if (device_list_.index() == 1) { + // Handle Python duck-type devices in a separate function for readability. + PopulateMemoryKindInfoForDuckTypedDevices(); + return; + } + if (device_list_.index() != 0) { + throw nb::value_error("Unrecognized DeviceList type"); + } + MemoryKindInfo info; + if (std::get<0>(device_list_)->size() == 0) { + info.default_memory_kind = nb::none(); + memory_kind_info_ = std::move(info); + return; + } + xla::ifrt::Device* device = std::get<0>(device_list_)->devices()[0]; + + auto default_memory = device->DefaultMemory(); + if (!default_memory.ok()) { + // Cache the error. + memory_kind_info_ = default_memory.status(); + return; + } + info.default_memory_kind = nb::cast(*(*default_memory)->Kind().memory_kind()); + nb::tuple memory_kinds = + nb::steal(PyTuple_New(device->Memories().size())); + for (size_t i = 0; i < device->Memories().size(); ++i) { + auto* memory = device->Memories()[i]; + nb::str s = nb::str(memory->Kind().memory_kind()->data(), + memory->Kind().memory_kind()->size()); + PyTuple_SET_ITEM(memory_kinds.ptr(), i, s.release().ptr()); + } + info.memory_kinds = std::move(memory_kinds); + memory_kind_info_ = std::move(info); +} + +void PyDeviceList::PopulateMemoryKindInfoForDuckTypedDevices() { + MemoryKindInfo info; + try { + if (std::get<1>(device_list_).size() == 0) { + info.default_memory_kind = nb::none(); + // info.memory_kinds is default-initialized to an empty tuple. + memory_kind_info_ = std::move(info); + return; + } + nb::handle device = std::get<1>(device_list_)[0]; + auto default_memory = device.attr("default_memory")(); + info.default_memory_kind = default_memory.attr("kind"); + info.memory_kinds = + nb::tuple(nb::object(device.attr("addressable_memories")())); + memory_kind_info_ = std::move(info); + } catch (nb::python_error& e) { + // Cache the error. + memory_kind_info_ = xla::InvalidArgument("%s", e.what()); + } +} + +/*static*/ absl::StatusOr> +PyDeviceList::MemoryKinds(nb_class_ptr self) { + nb::ft_object_guard lock(self); + if (!self->memory_kind_info_.has_value()) { + self->PopulateMemoryKindInfo(); + } + if (!self->memory_kind_info_->ok()) { + return self->memory_kind_info_->status(); + } + return (*self->memory_kind_info_)->memory_kinds; +} + +/*static*/ absl::StatusOr PyDeviceList::DefaultMemoryKind( + nb_class_ptr self) { + nb::ft_object_guard lock(self); + if (!self->memory_kind_info_.has_value()) { + self->PopulateMemoryKindInfo(); + } + if (!self->memory_kind_info_->ok()) { + return self->memory_kind_info_->status(); + } + return (*self->memory_kind_info_)->default_memory_kind; +} + +/*static*/ void PyDeviceList::Register(nb::module_& m) { + nb::class_(m, "DeviceList") + .def(nb::init>()) + .def("__hash__", &PyDeviceList::Hash, nb::lock_self()) + .def("__eq__", &PyDeviceList::Equal) + .def("__ne__", &PyDeviceList::NotEqual) + .def("__len__", &PyDeviceList::Len) + .def("__getitem__", &PyDeviceList::GetItem, + nb::sig("def __getitem__(self, index: int, /) -> Device")) + .def( + "__getitem__", &PyDeviceList::GetSlice, + nb::sig("def __getitem__(self, slice: slice, /) -> Sequence[Device]")) + .def("__iter__", &PyDeviceList::Iter, nb::keep_alive<0, 1>()) + .def("__str__", &PyDeviceList::Str) + .def("__repr__", &PyDeviceList::Str) + .def("__getstate__", [](const PyDeviceList& l) { return l.Dump(); }) + .def("__setstate__", + [](PyDeviceList& self, nb::tuple t) { + new (&self) PyDeviceList(std::move(t)); + }) + .def_prop_ro("is_fully_addressable", &PyDeviceList::IsFullyAddressable, + nb::lock_self()) + .def_prop_ro("addressable_device_list", + &PyDeviceList::AddressableDeviceList) + .def_prop_ro("process_indices", &PyDeviceList::ProcessIndices, + nb::lock_self()) + // `xla::ValueOrThrowWrapper` does not work with + // `def_prop_ro()`. Manually convert an error into an exception. + .def_prop_ro( + "default_memory_kind", + [](nb_class_ptr l) { + auto kind = DefaultMemoryKind(l); + if (!kind.ok()) { + throw nb::value_error(kind.status().ToString().c_str()); + } + return *kind; + }, + nb::sig("def default_memory_kind(self) -> str | None")) + .def_prop_ro("memory_kinds", + [](nb_class_ptr l) { + auto kinds = MemoryKinds(l); + if (!kinds.ok()) { + throw nb::value_error(kinds.status().ToString().c_str()); + } + return *kinds; + }) + .def_prop_ro("device_kind", &PyDeviceList::DeviceKind, nb::lock_self()); +} + +} // namespace jax diff --git a/jaxlib/py_device_list.h b/jaxlib/py_device_list.h new file mode 100644 index 000000000000..c732c6418387 --- /dev/null +++ b/jaxlib/py_device_list.h @@ -0,0 +1,146 @@ +/* Copyright 2023 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_PY_DEVICE_LIST_H_ +#define JAXLIB_PY_DEVICE_LIST_H_ + +#include +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "nanobind/nanobind.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" +#include "xla/python/ifrt/device_list.h" + +namespace jax { + +// Device list with various caching and direct access to IFRT DeviceList. +class PyDeviceList { + public: + PyDeviceList(nb_class_ptr py_client, + xla::ifrt::DeviceListRef device_list); + explicit PyDeviceList(nanobind::tuple py_device_assignment); + ~PyDeviceList(); + + PyDeviceList(const PyDeviceList&) = delete; + PyDeviceList(PyDeviceList&&) = delete; + PyDeviceList& operator=(const PyDeviceList&) = delete; + PyDeviceList& operator=(PyDeviceList&&) = delete; + + static nanobind::handle type() { + static auto type = nanobind::type(); + return type; + } + + // These two methods are safe to call from C++ without GIL. + nb_class_ptr py_client() const { return py_client_; } + absl::StatusOr ifrt_device_list() const; + + int Len() const; // Requires the GIL in GIL mode. + nanobind::object GetItem(int index); // Requires the GIL in GIL mode. + + // Requires the GIL in GIL mode. Acquires the self lock in non-GIL mode. + static nb_class_ptr AddressableDeviceList( + nb_class_ptr self); + + // Requires the GIL in GIL mode. Acquires the self lock in non-GIL mode. + static absl::StatusOr DefaultMemoryKind( + nb_class_ptr self); + + // Requires the GIL in GIL mode. Acquires the self lock in non-GIL mode. + static absl::StatusOr< + nanobind::typed> + MemoryKinds(nb_class_ptr self); + + // go/pywald-pybind-annotation BEGIN + // refs { + // module_path: "third_party/py/jax/jaxlib/py_device_list.cc" + // module_arg {} + // } + // go/pywald-pybind-annotation END + static void Register(nanobind::module_& m); + + private: + nanobind::tuple AsTuple() const; + + // Methods below require GIL. + nanobind::object GetSlice(nanobind::slice slice); + nanobind::iterator Iter(); + + std::string Str(); + + nanobind::tuple Dump() const; + + int64_t Hash(); // Mutates hash_, needs self lock. + + static bool Equal(nb_class_ptr self, nanobind::handle other); + static bool NotEqual(nb_class_ptr self, nanobind::handle other); + + // Finds the memory kind info from an addressable device. Requires the GIL + // or self lock. + void PopulateMemoryKindInfo(); + // Same as `PopulateMemoryKindInfo()`, but uses `py_device_assignment_` + // instead of `ifrt_device_list_` to support duck-typed device objects. + // Requires the GIL or self lock. + void PopulateMemoryKindInfoForDuckTypedDevices(); + + // Requires the self lock or GIL is held. + bool IsFullyAddressable(); + + // Requires the self lock or GIL. + const std::set& ProcessIndices(); + + // Requires the self lock or GIL. + const std::string& DeviceKind(); + + // Valid only if `device_list_` contains `xla::ifrt::DeviceList` and + // non-empty. + nb_class_ptr py_client_; + + // Either C++ `ifrt::DeviceList` or Python duck-type devices. + // TODO(hyeontaek): Remove support for Python duck-type devices once all + // JAX backends and tests are migrated to use an `xla::ifrt::Device` type + // for JAX devices. + // Immutable after constructor; no locking needed. + std::variant device_list_; + + // Populated on demand. Guarded by the object's self lock. + std::optional hash_; + // TODO(hyeontaek): Make the following property cached within + // `xla::ifrt::DeviceList`. + // Populated on demand. Guarded by the object's self lock. + std::optional is_fully_addressable_; + // Populated on demand. Guarded by the object's self lock. + std::optional addressable_device_list_; + // Populated on demand. Guarded by the object's self lock. + std::optional> process_indices_; + // Populated on demand. Guarded by the object's self lock. + std::optional device_kind_; + + struct MemoryKindInfo { + nanobind::object default_memory_kind; + nanobind::tuple memory_kinds; + }; + // Populated on demand. Guarded by the object's self lock. + std::optional> memory_kind_info_; +}; + +} // namespace jax + +#endif // JAXLIB_PY_DEVICE_LIST_H_ diff --git a/jaxlib/py_executable.cc b/jaxlib/py_executable.cc new file mode 100644 index 000000000000..e458183b55ab --- /dev/null +++ b/jaxlib/py_executable.cc @@ -0,0 +1,610 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/py_executable.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/casts.h" +#include "absl/base/const_init.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/call_location.h" +#include "jaxlib/guard_lib.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_array.h" +#include "jaxlib/py_client.h" +#include "jaxlib/py_device.h" +#include "jaxlib/py_user_context.h" +#include "jaxlib/traceback.h" +#include "xla/future.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/ifrt/user_context.h" +#include "xla/python/ifrt/user_context_status_util.h" +#include "xla/python/nb_absl_flat_hash_map.h" // IWYU pragma: keep +#include "xla/python/pjrt_ifrt/pjrt_attribute_map_util.h" +#include "xla/python/version.h" +#include "xla/tsl/concurrency/future.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/fingerprint.h" +#include "tsl/profiler/lib/traceme.h" + +namespace ifrt = xla::ifrt; + +namespace { + +uint64_t GetBaseLaunchId(std::optional fingerprint, + ifrt::LoadedExecutableRef executable) { + uint64_t ret = 0; + if (fingerprint.has_value()) { + ret = tsl::Fingerprint64(*fingerprint); + } + // Don't use the device fingerprint for executables running on single process. + // Pmap and replicated executables for example will only populate the local + // device to the loaded executable and all devices will have different devices + // fingerprints. + if (std::optional device_list = executable->devices(); + device_list.has_value() && !(*device_list)->IsFullyAddressable()) { + ret += (*device_list)->fingerprint(); + } + VLOG(1) << "Get base launch id: " << ret << " from fingerprint: " + << (fingerprint.has_value() + ? absl::StrCat(tsl::Fingerprint64(*fingerprint)) + : ""); + return ret; +} + +} // namespace + +namespace nb = nanobind; + +namespace jax { + +// PyToken + +absl::Status PyToken::Await() { + CHECK(future_.IsValid()); + absl::Status status; + { + nb::gil_scoped_release gil_release; + status = future_.Await(); + } + // `status` originates from `ifrt::ExecuteResult::status`, which can reference + // an asynchronously propagated `ifrt::UserContext` representing the context + // of an error. We expand this future result right before returning it to + // Python (outside of `nb::gil_scoped_release`) so that any attached user + // context is appended to the status message. + return xla::ifrt::ExpandUserContexts(std::move(status)); +} + +void PyToken::Register(nb::module_& m) { + nb::class_ token(m, "Token"); + token.def("block_until_ready", + [](PyToken& self) { xla::ThrowIfError(self.Await()); }); +} + +// PyShardedToken + +absl::Status PyShardedToken::Await() { + absl::Status status = absl::OkStatus(); + { + nb::gil_scoped_release gil_release; + for (auto& future : futures_) { + auto s = future.Await(); + if (!s.ok()) status = std::move(s); + } + } + // `status` combines the statuses originating from + // `ifrt::ExecuteResult::status`, which can reference an asynchronously + // propagated `ifrt::UserContext` representing the context of an error. We + // expand this future result right before returning it to Python (outside of + // `nb::gil_scoped_release`) so that any attached user context is appended to + // the status message. + return xla::ifrt::ExpandUserContexts(std::move(status)); +} + +void PyShardedToken::Register(nb::module_& m) { + nb::class_ sharded_token(m, "ShardedToken"); + sharded_token.def("block_until_ready", [](PyShardedToken& self) { + xla::ThrowIfError(self.Await()); + }); + sharded_token.def("get_token", &PyShardedToken::GetPyToken); +} + +// PyExecuteResults + +namespace { + +void PopulateExecuteShardedResults(const nb_class_ptr& client, + std::vector ifrt_arrays, + const xla::Future<>& result_status, + int num_computations, + std::vector>& outputs) { + DCHECK_GT(num_computations, 0); + int num_output_buffers = ifrt_arrays.size(); + outputs.resize(num_output_buffers); + for (int buffer_id = 0; buffer_id < num_output_buffers; ++buffer_id) { + xla::ifrt::UserContextScope user_context_scope( + ifrt_arrays[buffer_id]->user_context()); + outputs[buffer_id].reserve(num_computations); + auto exploded_arrays = + ifrt_arrays[buffer_id]->DisassembleIntoSingleDeviceArrays( + ifrt::ArrayCopySemantics::kReuseInput, + ifrt::SingleDeviceShardSemantics::kAddressableShards); + TF_CHECK_OK(exploded_arrays.status()); + for (auto& exploded_array : *exploded_arrays) { + outputs[buffer_id].push_back(PyArray::MakeFromSingleDeviceArray( + client, std::move(exploded_array), false, true, result_status)); + } + } +} + +} // namespace + +PyExecuteResults::PyExecuteResults(const nb_class_ptr& client, + std::vector ifrt_arrays, + int num_computations, PyShardedToken token, + xla::Future<> result_status) + : client_(client), + ifrt_arrays_(std::move(ifrt_arrays)), + num_computations_(num_computations), + token_(std::move(token)), + result_status_(std::move(result_status)) {} + +void PyExecuteResults::CheckNotDisassembled() const { + if (is_exploded_) { + throw nb::value_error("ExecuteResults already exploded."); + } +} + +std::vector PyExecuteResults::Consume() { + CheckNotDisassembled(); + is_exploded_ = true; + return std::move(ifrt_arrays_); +} + +PyShardedToken PyExecuteResults::ConsumeToken() { + if (token_consumed_) { + throw nb::value_error("ExecuteResults token already consumed."); + } + token_consumed_ = true; + return std::move(token_); +} + +std::vector> +PyExecuteResults::DisassembleIntoSingleDeviceArrays() { + std::vector> outputs; + PopulateExecuteShardedResults( + client_, Consume(), + result_status_.IsValid() ? result_status_ : xla::Future<>(), + num_computations_, outputs); + return outputs; +} + +std::vector> +PyExecuteResults::DisassemblePrefixIntoSingleDeviceArrays(size_t n) { + CheckNotDisassembled(); + if (n > ifrt_arrays_.size()) { + throw nb::value_error( + absl::StrCat("In DisassemblePrefixIntoSingleDeviceArrays: ", n, " > ", + ifrt_arrays_.size()) + .c_str()); + } + std::vector ifrt_arrays; + ifrt_arrays.reserve(ifrt_arrays_.size() - n); + for (size_t i = n; i < ifrt_arrays_.size(); ++i) { + ifrt_arrays.push_back(std::move(ifrt_arrays_[i])); + } + ifrt_arrays_.erase(ifrt_arrays_.begin() + n, ifrt_arrays_.end()); + std::swap(ifrt_arrays_, ifrt_arrays); + std::vector> outputs; + PopulateExecuteShardedResults( + client_, std::move(ifrt_arrays), + result_status_.IsValid() ? result_status_ : xla::Future<>(), + num_computations_, outputs); + return outputs; +} + +std::vector PyExecuteResults::ConsumeWithHandlers( + std::vector> + out_handlers, + bool strict) { + std::vector outputs; + int num_output_buffers = out_handlers.size(); + std::vector ifrt_arrays; + if (strict) { + if (out_handlers.size() != ifrt_arrays_.size()) { + throw nb::value_error( + absl::StrCat("Mismatch between out_handlers and num_results: ", + out_handlers.size(), " vs ", ifrt_arrays_.size()) + .c_str()); + } + ifrt_arrays = Consume(); + } else { + if (out_handlers.size() > ifrt_arrays_.size()) { + throw nb::value_error( + absl::StrCat("Mismatch between out_handlers and num_results: ", + out_handlers.size(), " > ", ifrt_arrays_.size()) + .c_str()); + } + CheckNotDisassembled(); + ifrt_arrays.reserve(ifrt_arrays_.size() - num_output_buffers); + for (size_t i = num_output_buffers; i < ifrt_arrays_.size(); ++i) { + ifrt_arrays.push_back(std::move(ifrt_arrays_[i])); + } + ifrt_arrays_.erase(ifrt_arrays_.begin() + ifrt_arrays_.size(), + ifrt_arrays_.end()); + std::swap(ifrt_arrays_, ifrt_arrays); + } + for (int buffer_id = 0; buffer_id < num_output_buffers; ++buffer_id) { + auto& handler = out_handlers[buffer_id]; + xla::ifrt::UserContextScope user_context_scope( + ifrt_arrays[buffer_id]->user_context()); + if (std::holds_alternative(handler)) { + outputs.push_back(std::get(handler)->Call( + client_, std::move(ifrt_arrays[buffer_id]), + result_status_.IsValid() ? result_status_ : xla::Future<>())); + } else { + tsl::profiler::TraceMe traceme("ConsumeWithHandlers fallback."); + auto disassembled_arrays = + ifrt_arrays[buffer_id]->DisassembleIntoSingleDeviceArrays( + ifrt::ArrayCopySemantics::kReuseInput, + ifrt::SingleDeviceShardSemantics::kAddressableShards); + TF_CHECK_OK(disassembled_arrays.status()); + nb::list bufs = + nb::steal(PyList_New(disassembled_arrays->size())); + int i = 0; + for (auto& disassembled_array : *disassembled_arrays) { + nb::object array = PyArray::MakeFromSingleDeviceArray( + client_, std::move(disassembled_array), false, true, + result_status_.IsValid() ? result_status_ : xla::Future<>()); + PyList_SET_ITEM(bufs.ptr(), i, array.release().ptr()); + ++i; + } + outputs.push_back(std::get(handler)(std::move(bufs))); + } + } + return outputs; +} + +void PyExecuteResults::Register(nb::module_& m) { + nb::class_(m, "ExecuteResults") + .def("__len__", [](PyExecuteResults& results) { return results.Size(); }) + .def("disassemble_into_single_device_arrays", + &PyExecuteResults::DisassembleIntoSingleDeviceArrays) + .def("disassemble_prefix_into_single_device_arrays", + &PyExecuteResults::DisassemblePrefixIntoSingleDeviceArrays) + .def("consume_with_handlers", &PyExecuteResults::ConsumeWithHandlers, + nb::arg("out_handlers"), nb::arg("strict") = true) + .def("consume_token", &PyExecuteResults::ConsumeToken); +} + +// PyExecutable + +void PyExecutable::Register(nb::module_& m) { + nb::class_(m, "Executable") + .def("hlo_modules", + xla::ValueOrThrowWrapper(&PyExecutable::GetHloModules)) + .def("get_output_memory_kinds", + xla::ValueOrThrowWrapper(&PyExecutable::GetOutputMemoryKinds)) + .def("get_output_shardings", &PyExecutable::GetOutputShardings) + .def("get_parameter_layouts", + xla::ValueOrThrowWrapper(&PyExecutable::GetParameterLayouts)) + .def("get_output_layouts", + xla::ValueOrThrowWrapper(&PyExecutable::GetOutputLayouts)) + .def("get_parameter_shardings", &PyExecutable::GetParameterShardings) + .def("get_compiled_memory_stats", + xla::ValueOrThrowWrapper(&PyExecutable::GetCompiledMemoryStats)) + .def("serialize", + [](const PyExecutable& exec) -> nb::bytes { + std::string serialized = xla::ValueOrThrow(exec.Serialize()); + return nb::bytes(serialized.data(), serialized.size()); + }) + .def("cost_analysis", [](const PyExecutable& exec) { + auto attrs = xla::ValueOrThrow(exec.GetCostAnalysis()); + return xla::ifrt::ToPjRtAttributeMap(std::move(attrs)); + }); +} + +// PyLoadedExecutable + +PyLoadedExecutable::PyLoadedExecutable( + nb_class_ptr client, + ifrt::LoadedExecutableRef ifrt_loaded_executable, + std::optional fingerprint) + : client_(std::move(client)), + ifrt_loaded_executable_(std::move(ifrt_loaded_executable)), + fingerprint_(std::move(fingerprint)), + launch_id_key_(GetBaseLaunchId(fingerprint_, ifrt_loaded_executable_)) { + CHECK(PyGILState_Check()); + if (ifrt_loaded_executable_->user_context() == nullptr && + Traceback::IsEnabled()) { + throw nb::value_error( + "Expecting an IFRT `LoadedExecutable` to have a user context, but got " + "a null user context. Use `jax::PyUserContextScope` to set a user " + "context for operations producing IFRT `LoadedExecutable`s."); + } + if (fingerprint_) { + VLOG(1) << "Fingerprint for executable " << ifrt_loaded_executable_->name() + << ": " << *fingerprint_; + } + nb::ft_lock_guard lock(client_->executables_mutex_); + next_ = client_->executables_; + client_->executables_ = this; + prev_ = nullptr; + if (next_) { + next_->prev_ = this; + } +} + +PyLoadedExecutable::~PyLoadedExecutable() { + CHECK(PyGILState_Check()); + nb::ft_lock_guard lock(client_->executables_mutex_); + if (client_->executables_ == this) { + client_->executables_ = next_; + } + if (prev_) { + prev_->next_ = next_; + } + if (next_) { + next_->prev_ = prev_; + } +} + +std::vector> PyLoadedExecutable::AddressableDevices() + const { + std::vector> devices; + devices.reserve(ifrt_loaded_executable_->addressable_devices().size()); + for (ifrt::Device* device : ifrt_loaded_executable_->addressable_devices()) { + devices.push_back(client_->GetPyDevice(device)); + } + return devices; +} + +namespace { + +absl::StatusOr ExecuteShardedOnLocalDevicesInternal( + const ifrt::ExecuteOptions& options, const nb_class_ptr& client, + ifrt::LoadedExecutable* ifrt_loaded_executable, + absl::Span args, + std::optional>>& returned_futures) { + std::vector output_arrays; + std::unique_ptr> returned_future; + int num_computations = ifrt_loaded_executable->addressable_devices().size(); + xla::Future<> result_status; + { + nb::gil_scoped_release gil_release; + for (const auto& arg : args) { + if (arg.num_addressable_shards() != num_computations) { + return xla::InvalidArgument( + "Expected args to execute_sharded_on_local_devices to have %d " + "shards, got: [%s]", + num_computations, + absl::StrJoin(args, ", ", [](std::string* out, const PyArray& arg) { + out->append(std::to_string(arg.num_addressable_shards())); + })); + } + } + std::vector arg_arrays(args.size()); + absl::c_transform(args, arg_arrays.begin(), + [&](const PyArray& arg) mutable { + return tsl::FormRef(arg.ifrt_array()); + }); + TF_ASSIGN_OR_RETURN(auto result, ifrt_loaded_executable->Execute( + absl::MakeSpan(arg_arrays), options, + /*devices=*/std::nullopt)); + output_arrays = std::move(result.outputs); + // options.fill_status is only supposed to be true when the computation has + // tokens. + if (options.fill_status) { + result_status = result.status; + if (returned_futures.has_value()) { + returned_futures->resize(num_computations, std::move(result.status)); + } + } + } + + // TODO(b/240696624): Although the PjRt interface require `returned_futures` + // to be resized correctly if it is not nullopt, some implementation does not + // implement this. So we have to check whether returned_futures is empty. + // Remove this check once the implementation is fixed. + auto py_sharded_token = returned_futures.has_value() + ? PyShardedToken(std::move(*returned_futures)) + : PyShardedToken(); + + return PyExecuteResults(client, std::move(output_arrays), num_computations, + std::move(py_sharded_token), result_status); +} + +} // namespace + +absl::Mutex PyLoadedExecutable::next_launch_id_mutex_(absl::kConstInit); +absl::flat_hash_map* PyLoadedExecutable::next_launch_id_ = + new absl::flat_hash_map(); + +absl::StatusOr PyLoadedExecutable::ExecuteSharded( + std::vector args, bool with_tokens) { + // Check if the thread guard is active and should prevent execution. + // Skipped for portable executables. + if (ifrt_loaded_executable_->devices().has_value()) { + TF_RETURN_IF_ERROR(CheckThreadGuard(*ifrt_loaded_executable_->devices())); + } + + xla::ifrt::ExecuteOptions options = options_; + options.launch_id = GetNextLaunchId(); + options.fill_status = with_tokens; + options.execution_stream_id = GetExecutionStreamId(); + if (options.execution_stream_id == 0) { + options.execution_stream_id = tsl::Env::Default()->GetCurrentThreadId(); + } + PyUserContextScope user_context_scope; + PopulateCallLocation(options, xla::ifrt::UserContextScope::current().get()); + std::optional>> returned_futures; + if (with_tokens) { + returned_futures.emplace(); + } + absl::Span span_args = args; + return ExecuteShardedOnLocalDevicesInternal(options, client_, + ifrt_loaded_executable_.get(), + span_args, returned_futures); +} + +absl::StatusOr>> +PyLoadedExecutable::HloModules() const { + nb::gil_scoped_release gil_release; + return ifrt_loaded_executable_->GetHloModules(); +} + +absl::StatusOr>> +PyLoadedExecutable::GetOutputMemoryKinds() const { + nb::gil_scoped_release gil_release; + return ifrt_loaded_executable_->GetOutputMemoryKinds(); +} + +absl::StatusOr>> +PyLoadedExecutable::GetParameterLayouts() const { + nb::gil_scoped_release gil_release; + return ifrt_loaded_executable_->GetParameterLayouts(); +} + +absl::StatusOr>> +PyLoadedExecutable::GetOutputLayouts() const { + nb::gil_scoped_release gil_release; + return ifrt_loaded_executable_->GetOutputLayouts(); +} + +std::optional> +PyLoadedExecutable::GetParameterShardings() const { + nb::gil_scoped_release gil_release; + return ifrt_loaded_executable_->GetParameterShardings(); +} + +std::optional> +PyLoadedExecutable::GetOutputShardings() const { + nb::gil_scoped_release gil_release; + return ifrt_loaded_executable_->GetOutputShardings(); +} + +int32_t PyLoadedExecutable::GetNextLaunchId() { + int32_t launch_id; + { + absl::MutexLock lock(next_launch_id_mutex_); + auto it = next_launch_id_->find(launch_id_key_); + if (it == next_launch_id_->end()) { + uint32_t initial_value = static_cast(launch_id_key_); + it = next_launch_id_->emplace(launch_id_key_, initial_value).first; + } + launch_id = absl::bit_cast(it->second++); + } + VLOG(1) << "Launching executable " << ifrt_loaded_executable_->name() + << " with launch ID: " << launch_id << " key: " << launch_id_key_; + VLOG(2) << "Executable devices for launch ID " << launch_id << ": " + << (ifrt_loaded_executable_->devices().has_value() + ? (*ifrt_loaded_executable_->devices())->DebugString() + : ""); + return launch_id; +} + +void PyLoadedExecutable::KeepAlive(nb::object obj) { + keepalives_.push_back(std::move(obj)); +} + +void PyLoadedExecutable::Register(nb::module_& m) { + nb::class_(m, "LoadedExecutable") + .def_prop_ro("client", &PyLoadedExecutable::client) + .def("local_devices", &PyLoadedExecutable::AddressableDevices) + .def("get_hlo_text", + xla::ValueOrThrowWrapper( + &PyLoadedExecutable::GetHumanReadableProgramText)) + .def("serialize", + [](const PyLoadedExecutable& exec) -> nb::bytes { + std::string serialized = + xla::ValueOrThrow(exec.ifrt_loaded_executable()->Serialize()); + return nb::bytes(serialized.data(), serialized.size()); + }) + .def("size_of_generated_code_in_bytes", + &PyLoadedExecutable::SizeOfGeneratedCodeInBytes) + .def( + "get_compiled_memory_stats", + xla::ValueOrThrowWrapper(&PyLoadedExecutable::GetCompiledMemoryStats)) + .def("execute_sharded", + xla::ValueOrThrowWrapper(&PyLoadedExecutable::ExecuteSharded), + nb::arg("arguments"), nb::arg("with_tokens") = false) + .def("hlo_modules", + xla::ValueOrThrowWrapper(&PyLoadedExecutable::HloModules)) + .def("get_output_memory_kinds", + xla::ValueOrThrowWrapper(&PyLoadedExecutable::GetOutputMemoryKinds)) + .def("get_output_shardings", &PyLoadedExecutable::GetOutputShardings) + .def("get_parameter_layouts", + xla::ValueOrThrowWrapper(&PyLoadedExecutable::GetParameterLayouts)) + .def("get_output_layouts", + xla::ValueOrThrowWrapper(&PyLoadedExecutable::GetOutputLayouts)) + .def("get_parameter_shardings", + &PyLoadedExecutable::GetParameterShardings) + .def("keep_alive", &PyLoadedExecutable::KeepAlive) + .def("cost_analysis", + [](const PyLoadedExecutable& self) { + auto map = xla::ValueOrThrow(self.GetCostAnalysis()); + return xla::ifrt::ToPjRtAttributeMap(std::move(map)); + }) + .def_prop_ro("traceback", &PyLoadedExecutable::traceback) + .def_prop_ro("fingerprint", [](PyLoadedExecutable* exec) -> nb::object { + if (exec->fingerprint().has_value()) { + return nb::bytes(exec->fingerprint()->data(), + exec->fingerprint()->size()); + } else { + return nb::none(); + } + }); +} + +} // namespace jax diff --git a/jaxlib/py_executable.h b/jaxlib/py_executable.h new file mode 100644 index 000000000000..2e3f33a1222c --- /dev/null +++ b/jaxlib/py_executable.h @@ -0,0 +1,310 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_PY_EXECUTABLE_H_ +#define JAXLIB_PY_EXECUTABLE_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" +#include "llvm/Support/Casting.h" +#include "nanobind/nanobind.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_array.h" +#include "jaxlib/py_client.h" +#include "jaxlib/py_user_context.h" +#include "jaxlib/traceback.h" +#include "xla/future.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/attribute_map.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/pjrt_ifrt/pjrt_executable.h" +#include "xla/xla_data.pb.h" + +namespace jax { + +class PyToken { + public: + PyToken() = default; + explicit PyToken(xla::Future<> future) : future_(std::move(future)) {} + + static PyToken ReadyPyToken() { + return PyToken(xla::Future<>(absl::OkStatus())); + } + + absl::Status Await(); + + static void Register(nanobind::module_& m); + + private: + xla::Future<> future_; +}; + +// PyShardedToken contains a PyToken for each device's execution. +class PyShardedToken { + public: + // Default construction creates a always-ready token. + PyShardedToken() = default; + explicit PyShardedToken(std::vector> futures) + : futures_(std::move(futures)) {} + + PyToken GetPyToken(int device_id) const { + if (futures_.empty()) return PyToken::ReadyPyToken(); + return PyToken(futures_.at(device_id)); + } + + absl::Status Await(); + + static void Register(nanobind::module_& m); + + private: + std::vector> futures_; +}; + +class PyExecuteResults { + public: + PyExecuteResults(const nb_class_ptr& client, + std::vector ifrt_arrays, + int num_computations, PyShardedToken token, + xla::Future<> result_status = xla::Future<>()); + + std::vector> DisassembleIntoSingleDeviceArrays(); + + std::vector> DisassemblePrefixIntoSingleDeviceArrays( + size_t n); + + std::vector ConsumeWithHandlers( + std::vector> + out_handlers, + bool strict); + + std::vector Consume(); + + PyShardedToken ConsumeToken(); + + size_t Size() const { + CheckNotDisassembled(); + return ifrt_arrays_.size(); + } + + void CheckNotDisassembled() const; + + static void Register(nanobind::module_& m); + + private: + bool is_exploded_ = false; + bool token_consumed_ = false; + nb_class_ptr client_; + std::vector ifrt_arrays_; + int num_computations_; + PyShardedToken token_; + // Only set if the computation has tokens. + xla::Future<> result_status_; +}; + +// Thin Python wrapper around xla::ifrt::ExecutableRef. We use a wrapper class: +// a) Standardize around xla::ifrt::ExecutableRef, which is +// std::shared_ptr. +// b) Concrete subclasses of xla::ifrt::Executable have protected constructors. +class PyExecutable { + public: + PyExecutable(xla::ifrt::ExecutableRef ifrt_executable) + : ifrt_executable_(std::move(ifrt_executable)) {}; + ~PyExecutable() = default; + + // NOTE(dsuo): For now, we only expose the xla::ifrt::Executable members + // required by the Python bindings. + absl::StatusOr>> GetHloModules() + const { + return ifrt_executable_->GetHloModules(); + } + absl::StatusOr>> + GetOutputMemoryKinds() const { + return ifrt_executable_->GetOutputMemoryKinds(); + } + std::optional> GetOutputShardings() const { + return ifrt_executable_->GetOutputShardings(); + } + absl::StatusOr>> + GetParameterLayouts() const { + return ifrt_executable_->GetParameterLayouts(); + } + absl::StatusOr>> + GetOutputLayouts() const { + return ifrt_executable_->GetOutputLayouts(); + } + std::optional> GetParameterShardings() const { + return ifrt_executable_->GetParameterShardings(); + } + absl::StatusOr GetCompiledMemoryStats() const { + return ifrt_executable_->GetCompiledMemoryStats(); + } + absl::StatusOr Serialize() const { + return ifrt_executable_->Serialize(); + } + absl::StatusOr GetCostAnalysis() const { + return ifrt_executable_->GetCostAnalysis(); + } + + static void Register(nanobind::module_& m); + + private: + xla::ifrt::ExecutableRef ifrt_executable_; +}; + +// Python wrapper around xla::ifrt::LoadedExecutableRef. We use a wrapper class: +// a) to keep the PyClient alive via a std::shared_ptr<> +// b) to add Python-specific functionality. +class PyLoadedExecutable { + public: + PyLoadedExecutable(nb_class_ptr client, + xla::ifrt::LoadedExecutableRef ifrt_loaded_executable, + std::optional fingerprint); + ~PyLoadedExecutable(); + + nb_class_ptr client() const { return client_; } + xla::ifrt::LoadedExecutable* ifrt_loaded_executable() const { + return ifrt_loaded_executable_.get(); + } + + xla::ifrt::LoadedExecutableRef shared_ifrt_loaded_executable() { + return ifrt_loaded_executable_; + } + + std::vector> AddressableDevices() const; + + absl::StatusOr GetHumanReadableProgramText() const { + return ifrt_loaded_executable_->GetHumanReadableProgramText(); + } + + int64_t SizeOfGeneratedCodeInBytes() const { + return ifrt_loaded_executable_->SizeOfGeneratedCodeInBytes(); + } + + absl::StatusOr GetCompiledMemoryStats() const { + nanobind::gil_scoped_release scope; + return ifrt_loaded_executable_->GetCompiledMemoryStats(); + } + + absl::StatusOr GetCostAnalysis() const { + return ifrt_loaded_executable_->GetCostAnalysis(); + } + + // Takes args indexed by argid then deviceid, transposes them, and passes to + // xla::ifrt::LoadedExecutable::Execute. The result is similarly transposed + // back into the argid,deviceid format. args is [num_args x num_devices]. + absl::StatusOr ExecuteSharded(std::vector args, + bool with_tokens); + + absl::StatusOr>> HloModules() + const; + + absl::StatusOr>> + GetOutputMemoryKinds() const; + + absl::StatusOr>> + GetParameterLayouts() const; + + absl::StatusOr>> + GetOutputLayouts() const; + + std::optional> GetParameterShardings() const; + + std::optional> GetOutputShardings() const; + + std::optional traceback() { + return GetTraceback(ifrt_loaded_executable_->user_context().get()); + } + + xla::ifrt::LoadedExecutable* ifrt_executable() const { + return ifrt_loaded_executable_.get(); + } + + // Short-term escape hatch to get xla::PjRtLoadedExecutable from PyExecutable. + // TODO(hyeontaek): Migrate all users of this method to be agnostic of PjRt. + std::shared_ptr shared_ptr_pjrt_executable() { + auto* exec = + llvm::dyn_cast_or_null( + ifrt_loaded_executable_.get()); + if (exec == nullptr) { + throw xla::XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + return exec->shared_ptr_pjrt_loaded_executable(); + } + + // Returns a template of execute options to pass to + // `ifrt_executable()->Execute()`. Note that the caller may need to override + // some options such as `launch_id` that change at each execution. + const xla::ifrt::ExecuteOptions& options() const { return options_; } + + // Returns a unique launch ID to use for the next execution. + int32_t GetNextLaunchId(); + + const std::optional& fingerprint() const { return fingerprint_; } + + // Keep `obj` alive as long as PyLoadedExecutable. + void KeepAlive(nanobind::object obj); + + static void Register(nanobind::module_& m); + + private: + friend class PyClient; + + nb_class_ptr client_; + xla::ifrt::LoadedExecutableRef ifrt_loaded_executable_; + + // Identical executables (i.e. representing the same program) will have the + // same fingerprint. nullopt on platforms or executables where fingerprints + // aren't implemented. + std::optional fingerprint_; + + // Launch ID to use for the next execution. + const uint64_t launch_id_key_; + + static absl::Mutex next_launch_id_mutex_; + static absl::flat_hash_map* next_launch_id_; + + // The options to pass to `executable_.Execute`. + xla::ifrt::ExecuteOptions options_; + + // Python objects to keep alive as requested by user. + std::vector keepalives_; + + // Doubly-linked list of all executables known to the client. Protected by the + // GIL. + PyLoadedExecutable* next_; + PyLoadedExecutable* prev_; +}; + +} // namespace jax + +#endif // JAXLIB_PY_EXECUTABLE_H_ diff --git a/jaxlib/py_host_callback.cc b/jaxlib/py_host_callback.cc new file mode 100644 index 000000000000..ef511b69b202 --- /dev/null +++ b/jaxlib/py_host_callback.cc @@ -0,0 +1,260 @@ +/* Copyright 2023 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/py_host_callback.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "llvm/Support/ExtensibleRTTI.h" +#include "nanobind/nanobind.h" +#include "jaxlib/callback.h" +#include "jaxlib/py_host_callback.pb.h" +#include "jaxlib/python_ref_manager.h" +#include "xla/layout_util.h" +#include "xla/pjrt/host_callback.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/python/pjrt_ifrt/pjrt_host_callback.h" +#include "xla/python/pjrt_ifrt/xla_host_callback.pb.h" +#include "xla/python/types.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status_macros.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace ifrt = ::xla::ifrt; +namespace nb = nanobind; + +namespace jax { + +char PyFfiLoadedHostCallback::ID = 0; +char PyHostSendAndRecvLoadedHostCallback::ID = 0; + +namespace { + +absl::StatusOr> CreateCallbackArgs( + absl::Span operand_shapes) { + std::vector callback_args(operand_shapes.size()); + for (int i = 0; i < operand_shapes.size(); ++i) { + xla::Shape shape = operand_shapes[i]; + + if (shape.IsArray()) { + xla::Shape layout = + (shape.has_layout() ? shape + : xla::LayoutUtil::GetWithDefaultLayout(shape)); + callback_args[i].dims.resize(shape.dimensions().size()); + absl::c_copy(shape.dimensions(), callback_args[i].dims.begin()); + callback_args[i].strides = ByteStridesForShape(layout); + callback_args[i].type = shape.element_type(); + callback_args[i].size_in_bytes = xla::ShapeUtil::ByteSizeOf(layout); + TF_ASSIGN_OR_RETURN(callback_args[i].dtype, + PrimitiveTypeToNbDtype(shape.element_type())); + } else if (shape.IsToken()) { + callback_args[i].type = xla::TOKEN; + } else { + return xla::InvalidArgument( + "Only array and token arguments to Python callbacks are supported, " + "got %s", + shape.ToString()); + } + } + return callback_args; +} + +absl::StatusOr> CreateCallbackResults( + absl::Span result_shapes) { + std::vector callback_results(result_shapes.size()); + for (int i = 0; i < result_shapes.size(); ++i) { + if (result_shapes[i].IsArray()) { + const xla::Shape& shape = + result_shapes[i].has_layout() + ? result_shapes[i] + : xla::LayoutUtil::GetWithDefaultLayout(result_shapes[i]); + callback_results[i].expected_dims.resize(shape.dimensions().size()); + absl::c_copy(shape.dimensions(), + callback_results[i].expected_dims.begin()); + callback_results[i].expected_strides = ByteStridesForShape(shape); + callback_results[i].type = shape.element_type(); + callback_results[i].size_in_bytes = xla::ShapeUtil::ByteSizeOf(shape); + callback_results[i].reversed_layout.resize(shape.dimensions().size()); + absl::c_reverse_copy(shape.layout().minor_to_major(), + callback_results[i].reversed_layout.begin()); + } else if (result_shapes[i].IsToken()) { + callback_results[i].type = xla::TOKEN; + } else { + return xla::InvalidArgument( + "Only array and token return values from Python callbacks are " + "supported, got %s", + result_shapes[i].ToString()); + } + } + return callback_results; +} + +} // namespace + +PyFfiLoadedHostCallback::~PyFfiLoadedHostCallback() { + // The destructor may be called without GIL held. In that case, we defer it + // to GlobalPyRefManager. + std::vector objects; + objects.push_back(std::move(callable_)); + GlobalPyRefManager()->AddGarbage(absl::MakeSpan(objects)); +} + +absl::StatusOr> +PyHostSendAndRecvLoadedHostCallback::Create( + ifrt::Client* ifrt_client, nb::callable callable, + absl::Span operand_shapes, + absl::Span result_shapes, + absl::Span send_channel_ids, + absl::Span recv_channel_ids, nb::callable serializer) { + TF_ASSIGN_OR_RETURN(auto callback_args, CreateCallbackArgs(operand_shapes)); + TF_ASSIGN_OR_RETURN(auto callback_results, + CreateCallbackResults(result_shapes)); + + // `callable` will be destroyed safely with `PythonRefManager` when + // `CpuCallback` is destroyed. + auto cpu_callback = + std::make_shared(callable, callback_args, callback_results); + + auto host_callback = std::make_unique(); + + auto assign_arg_info = [](absl::Span shapes, + absl::Span channel_ids, + std::vector& arg_infos) { + DCHECK_EQ(shapes.size(), channel_ids.size()); + arg_infos.reserve(shapes.size()); + for (int i = 0; i < shapes.size(); ++i) { + xla::HostCallbackArgInfo host_callback_arg_info; + host_callback_arg_info.channel_id = channel_ids[i]; + const auto& shape = shapes[i]; + xla::Shape layout = + (shape.has_layout() ? shape + : xla::LayoutUtil::GetWithDefaultLayout(shape)); + host_callback_arg_info.shape = layout; + arg_infos.push_back(std::move(host_callback_arg_info)); + } + }; + + assign_arg_info(operand_shapes, send_channel_ids, host_callback->operands); + assign_arg_info(result_shapes, recv_channel_ids, host_callback->results); + + host_callback->callback = [cpu_callback = std::move(cpu_callback)]( + void** outputs, void** inputs) { + return cpu_callback->PrepareAndCall(outputs, inputs); + }; + return tsl::RCReference( + tsl::MakeRef( + ifrt_client, std::move(host_callback), callable, operand_shapes, + result_shapes, send_channel_ids, recv_channel_ids, + std::move(serializer))); +} + +PyHostSendAndRecvLoadedHostCallback::PyHostSendAndRecvLoadedHostCallback( + ifrt::Client* ifrt_client, + std::unique_ptr xla_host_callback, nb::callable callable, + absl::Span operand_shapes, + absl::Span result_shapes, + absl::Span send_channel_ids, + absl::Span recv_channel_ids, nb::callable serializer) + : llvm::RTTIExtends( + ifrt_client, std::move(xla_host_callback)), + callable_(std::move(callable)), + operand_shapes_(operand_shapes.begin(), operand_shapes.end()), + result_shapes_(result_shapes.begin(), result_shapes.end()), + send_channel_ids_(send_channel_ids.begin(), send_channel_ids.end()), + recv_channel_ids_(recv_channel_ids.begin(), recv_channel_ids.end()), + serializer_(serializer) {} + +PyHostSendAndRecvLoadedHostCallback::~PyHostSendAndRecvLoadedHostCallback() { + GlobalPyRefManager()->AddGarbage( + absl::MakeSpan(static_cast(&callable_), 1)); + GlobalPyRefManager()->AddGarbage( + absl::MakeSpan(static_cast(&serializer_), 1)); +} + +absl::StatusOr PyHostSendAndRecvLoadedHostCallback::Serialize() + const { + if (serializer_.is_none()) { + return xla::InvalidArgument( + "Host callback cannot be serialized because serializer was not " + "provided by JAX"); + } + ifrt::XlaHostCallbackProto xla_host_callback_proto; + + TF_RET_CHECK(operand_shapes_.size() == send_channel_ids_.size()); + for (int i = 0; i < operand_shapes_.size(); ++i) { + ifrt::XlaHostCallbackProto::ArgInfo* const operand = + xla_host_callback_proto.add_operands(); + operand->set_channel_id(send_channel_ids_[i]); + *operand->mutable_shape() = operand_shapes_[i].ToProto(); + } + + TF_RET_CHECK(result_shapes_.size() == recv_channel_ids_.size()); + for (int i = 0; i < result_shapes_.size(); ++i) { + ifrt::XlaHostCallbackProto::ArgInfo* const result = + xla_host_callback_proto.add_results(); + result->set_channel_id(recv_channel_ids_[i]); + *result->mutable_shape() = result_shapes_[i].ToProto(); + } + + std::string callable; + { + nb::gil_scoped_acquire gil_acquire; + try { + nb::bytes bytes = nb::cast(serializer_(callable_)); + callable = std::string(bytes.c_str(), bytes.size()); + } catch (const nb::python_error& e) { + return absl::InternalError(absl::StrCat( + "Unable to pickle the host_callback callable: ", e.what())); + } catch (const std::exception& e) { + std::exception_ptr p = std::current_exception(); + return absl::InternalError(absl::StrCat( + "Exception while pickling the host_callback callable: ", e.what())); + } catch (...) { + // Ensure to avoid leaking any exception because this method could have + // been called outside of a Python context where C++ exceptions are not + // necessarily enabled. + return absl::InternalError( + "Unknown exception while pickling the host_callback callable."); + } + } + PyHostCallbackProto py_host_callback_proto; + py_host_callback_proto.set_callable(std::move(callable)); + if (!xla_host_callback_proto.mutable_serialized_callback()->PackFrom( + py_host_callback_proto)) { + return absl::InternalError("Could not serialize a Python host callback"); + } + xla_host_callback_proto.set_use_major_to_minor_data_layout_for_callbacks( + true); + return xla_host_callback_proto.SerializeAsString(); +} + +} // namespace jax diff --git a/jaxlib/py_host_callback.h b/jaxlib/py_host_callback.h new file mode 100644 index 000000000000..4313e17472dc --- /dev/null +++ b/jaxlib/py_host_callback.h @@ -0,0 +1,120 @@ +/* Copyright 2023 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_PY_HOST_CALLBACK_H_ +#define JAXLIB_PY_HOST_CALLBACK_H_ + +#include +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "llvm/Support/ExtensibleRTTI.h" +#include "nanobind/nanobind.h" +#include "xla/pjrt/host_callback.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/python/pjrt_ifrt/pjrt_host_callback.h" +#include "xla/shape.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/util.h" + +namespace jax { + +using PyLoadedHostCallback = ::xla::ifrt::LoadedHostCallback; + +class PyFfiLoadedHostCallback final + : public llvm::RTTIExtends { + public: + PyFfiLoadedHostCallback(xla::ifrt::Client* ifrt_client, + nanobind::callable callable) + : llvm::RTTIExtends(ifrt_client, + callable.ptr()), + callable_(std::move(callable)) {} + ~PyFfiLoadedHostCallback() override; + + xla::ifrt::Client* client() const override { return ifrt_client_; } + absl::StatusOr Serialize() const override { + return xla::Unimplemented( + "PyFfiLoadedHostCallback::Serialize() is not supported"); + }; + + static char ID; // NOLINT + + private: + xla::ifrt::Client* ifrt_client_; + nanobind::callable callable_; +}; + +// `PyHostSendAndRecvLoadedHostCallback` implements a Python host callback that +// uses XLA host send and recv. This object should be passed to the compiler +// when creating `xla::ifrt::LoadedExecutable`. +// +// Serialization is supported if the Python host callback using the +// `cloudpickle` third-party library. +// +// TODO(hyeontaek): Update the comment ("compiler" to "client") after splitting +// compilation and loading. +class PyHostSendAndRecvLoadedHostCallback final + : public llvm::RTTIExtends< + PyHostSendAndRecvLoadedHostCallback, + xla::ifrt::PjRtHostSendAndRecvLoadedHostCallback> { + public: + static absl::StatusOr> + Create(xla::ifrt::Client* ifrt_client, nanobind::callable callable, + absl::Span operand_shapes, + absl::Span result_shapes, + absl::Span send_channel_ids, + absl::Span recv_channel_ids, + nanobind::callable serializer); + + // PjRtLoadedHostCallback implementation. + + ~PyHostSendAndRecvLoadedHostCallback() override; + + absl::StatusOr Serialize() const override; + + static char ID; // NOLINT + + private: + PyHostSendAndRecvLoadedHostCallback( + xla::ifrt::Client* ifrt_client, + std::unique_ptr xla_host_callback, + nanobind::callable callable, absl::Span operand_shapes, + absl::Span result_shapes, + absl::Span send_channel_ids, + absl::Span recv_channel_ids, + nanobind::callable serializer); + + template + friend tsl::RCReference tsl::MakeRef(Args&&... args); + + // Retained arguments for host callback serialization. + nanobind::callable callable_; + std::vector operand_shapes_; + std::vector result_shapes_; + std::vector send_channel_ids_; + std::vector recv_channel_ids_; + nanobind::callable serializer_; +}; + +} // namespace jax + +#endif // JAXLIB_PY_HOST_CALLBACK_H_ diff --git a/jaxlib/py_host_callback.proto b/jaxlib/py_host_callback.proto new file mode 100644 index 000000000000..beda3d341e90 --- /dev/null +++ b/jaxlib/py_host_callback.proto @@ -0,0 +1,25 @@ +/* Copyright 2023 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package jax; + +// Represents a JAX host callback that is serialized using the 'cloudpickle' +// Python library. Typically used for +// `xla.ifrt.XlaHostCallbackProto.serialized_callback`. +message PyHostCallbackProto { + bytes callable = 1; +} diff --git a/jaxlib/py_memory_space.cc b/jaxlib/py_memory_space.cc new file mode 100644 index 000000000000..a4958d6036b6 --- /dev/null +++ b/jaxlib/py_memory_space.cc @@ -0,0 +1,104 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/py_memory_space.h" + +#include + +#include +#include + +#include "nanobind/nanobind.h" +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" +#include "jaxlib/py_device.h" // IWYU pragma: keep +#include "xla/python/ifrt/device.h" + +namespace ifrt = ::xla::ifrt; +namespace nb = ::nanobind; + +namespace jax { + +PyMemorySpace::PyMemorySpace(nb_class_ptr client, + ifrt::Memory* memory) + : client_(std::move(client)), memory_(memory) {} + +int PyMemorySpace::process_index() const { return client_->process_index(); } + +std::string_view PyMemorySpace::platform() const { + // TODO(phawkins): this is a temporary backwards + // compatibility shim. We changed the name PJRT + // reports for GPU platforms to "cuda" or "rocm", + // but we haven't yet updated JAX clients that + // expect "gpu". Migrate users and remove this + // code. + if (client_->platform_name() == "cuda" || + client_->platform_name() == "rocm") { + return std::string_view("gpu"); + } else { + return client_->platform_name(); + } +} + +std::string_view PyMemorySpace::kind() const { + return *memory_->Kind().memory_kind(); +} + +std::string_view PyMemorySpace::Str() const { return memory_->DebugString(); } + +std::string_view PyMemorySpace::Repr() const { return memory_->ToString(); } + +nb::typed PyMemorySpace::AddressableByDevices() const { + nb::list devices; + for (ifrt::Device* device : memory_->Devices()) { + devices.append(client_->GetPyDevice(device)); + } + return devices; +} + +/* static */ int PyMemorySpace::tp_traverse(PyObject* self, visitproc visit, + void* arg) { + PyMemorySpace* d = nb::inst_ptr(self); + Py_VISIT(d->client().ptr()); + return 0; +} + +/* static */ int PyMemorySpace::tp_clear(PyObject* self) { + PyMemorySpace* d = nb::inst_ptr(self); + nb_class_ptr client; + std::swap(client, d->client_); + return 0; +} + +PyType_Slot PyMemorySpace::slots_[] = { + {Py_tp_traverse, (void*)PyMemorySpace::tp_traverse}, + {Py_tp_clear, (void*)PyMemorySpace::tp_clear}, + {0, nullptr}, +}; + +/* static */ void PyMemorySpace::Register(nb::module_& m) { + nb::class_ device(m, "Memory", + nb::type_slots(PyMemorySpace::slots_)); + device.def_prop_ro("process_index", &PyMemorySpace::process_index) + .def_prop_ro("platform", &PyMemorySpace::platform) + .def_prop_ro("kind", &PyMemorySpace::kind) + .def("__str__", &PyMemorySpace::Str) + .def("__repr__", &PyMemorySpace::Repr) + .def("addressable_by_devices", &PyMemorySpace::AddressableByDevices, + "Returns devices that can address this memory."); +} + +} // namespace jax diff --git a/jaxlib/py_memory_space.h b/jaxlib/py_memory_space.h new file mode 100644 index 000000000000..cccdbea71a19 --- /dev/null +++ b/jaxlib/py_memory_space.h @@ -0,0 +1,66 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_PY_MEMORY_SPACE_H_ +#define JAXLIB_PY_MEMORY_SPACE_H_ + +#include + +#include + +#include "nanobind/nanobind.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" +#include "xla/python/ifrt/memory.h" + +namespace jax { + +class PyMemorySpace { + public: + PyMemorySpace(nb_class_ptr client, xla::ifrt::Memory* memory_space); + + // Memory spaces are compared using Python object identity, so we don't allow + // them to be copied or moved. + PyMemorySpace(const PyMemorySpace&) = delete; + PyMemorySpace(PyMemorySpace&&) = delete; + PyMemorySpace& operator=(const PyMemorySpace&) = delete; + PyMemorySpace& operator=(PyMemorySpace&&) = delete; + + const nb_class_ptr& client() const { return client_; } + xla::ifrt::Memory* memory_space() const { return memory_; } + + int process_index() const; + std::string_view platform() const; + std::string_view kind() const; + + std::string_view Str() const; + std::string_view Repr() const; + + nanobind::typed AddressableByDevices() const; + + static void Register(nanobind::module_& m); + + private: + static int tp_traverse(PyObject* self, visitproc visit, void* arg); + static int tp_clear(PyObject* self); + static PyType_Slot slots_[]; + + nb_class_ptr client_; + xla::ifrt::Memory* memory_; +}; + +} // namespace jax + +#endif // JAXLIB_PY_MEMORY_SPACE_H_ diff --git a/jaxlib/py_program.cc b/jaxlib/py_program.cc new file mode 100644 index 000000000000..b4cfd41cf873 --- /dev/null +++ b/jaxlib/py_program.cc @@ -0,0 +1,324 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/py_program.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_device.h" +#include "jaxlib/py_device_list.h" +#include "jaxlib/python_ref_manager.h" +#include "jaxlib/sharding.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/pjrt/mlir_to_hlo.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/array_spec.h" +#include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt/custom_call_program.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/hlo/hlo_program.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/plugin_program.h" +#include "xla/python/ifrt/program.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/pjrt_ifrt/xla_compiler.h" +#include "xla/python/pjrt_ifrt/xla_sharding.h" +#include "xla/python/types.h" +#include "xla/python/version.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/statusor.h" + +namespace ifrt = ::xla::ifrt; +namespace nb = ::nanobind; + +namespace jax { + +namespace { + +// Gets `ifrt::DeviceList` from a sequence of JAX devices. +absl::StatusOr GetDeviceList(nb::sequence devices) { + ifrt::DeviceListRef ifrt_device_list; + if (devices.type().is(PyDeviceList::type())) { + return nb::cast(devices)->ifrt_device_list(); + } else { + auto py_devices = nb::cast>>(devices); + if (py_devices.empty()) { + return absl::InvalidArgumentError( + "Colocated Python program requires at least one device"); + } + absl::InlinedVector ifrt_devices; + ifrt_devices.reserve(py_devices.size()); + for (const nb_class_ptr& py_device : py_devices) { + ifrt_devices.push_back(py_device->device()); + } + return py_devices.front()->client()->ifrt_client()->MakeDeviceList( + ifrt_devices); + } +} + +// Gets `xla::HloSharding` from a JAX Sharding. +xla::HloSharding GetXlaHloSharding(nb::handle sharding, + int64_t num_dimensions) { + if (sharding.type().is(GSPMDSharding::type())) { + return nb::cast(sharding)->hlo_sharding(); + } else { + return nb::cast( + sharding.attr("_to_xla_hlo_sharding")(num_dimensions)); + } +} + +// Gets `ifrt::DeviceList` from a JAX Sharding. +absl::StatusOr GetIfrtDeviceList(nb::handle sharding) { + if (sharding.type().is(NamedSharding::type())) { + TF_ASSIGN_OR_RETURN( + auto ns_device_list, + nb::cast(sharding)->internal_device_list()); + return ns_device_list->ifrt_device_list(); + } else if (sharding.type().is(SingleDeviceSharding::type())) { + return nb::cast(sharding) + ->internal_device_list() + ->ifrt_device_list(); + } else if (sharding.type().is(PmapSharding::type())) { + return nb::cast(sharding) + ->internal_device_list() + ->ifrt_device_list(); + } else if (sharding.type().is(GSPMDSharding::type())) { + return nb::cast(sharding) + ->internal_device_list() + ->ifrt_device_list(); + } else { + return nb::cast(sharding.attr("_internal_device_list")) + ->ifrt_device_list(); + } +} + +// Gets `ifrt::MemoryKind` from a JAX Sharding. +ifrt::MemoryKind GetIfrtMemoryKind(nb::handle sharding) { + auto memory_kind = sharding.attr("memory_kind"); + if (memory_kind.is_none()) { + return ifrt::MemoryKind(); + } else { + return ifrt::MemoryKind(nb::cast(memory_kind)); + } +} + +// Makes `ifrt::Sharding` from a JAX Sharding. It requires the number of shape +// dimensions, which may become necessary when building an HLO sharding. +absl::StatusOr GetIfrtSharding(nb::handle sharding, + int64_t num_dimensions) { + auto ifrt_memory_kind = GetIfrtMemoryKind(sharding); + ifrt::ShardingRef ifrt_sharding; + if (sharding.type().is(SingleDeviceSharding::type())) { + TF_ASSIGN_OR_RETURN(auto ifrt_device_list, + nb::cast(sharding) + ->internal_device_list() + ->ifrt_device_list()); + return ifrt::SingleDeviceSharding::Create( + ifrt_device_list->devices().front(), ifrt_memory_kind); + } else { + TF_ASSIGN_OR_RETURN(auto ifrt_device_list, GetIfrtDeviceList(sharding)); + auto xla_hlo_sharding = GetXlaHloSharding(sharding, num_dimensions); + return ifrt::HloSharding::Create(std::move(ifrt_device_list), + ifrt_memory_kind, + std::move(xla_hlo_sharding)); + } +} + +// Gets `ifrt::ArraySpec`s from a sequence of JAX avals (e.g., +// `jax.ShapeDtypeStruct`). +absl::StatusOr> GetIfrtArraySpecs( + nb::sequence avals) { + std::vector ifrt_array_specs; + ifrt_array_specs.reserve(nb::len(avals)); + for (nb::handle aval : avals) { + ifrt::Shape ifrt_shape(nb::cast>(aval.attr("shape"))); + TF_ASSIGN_OR_RETURN( + auto ifrt_dtype, + DtypeToIfRtDType(nb::cast(aval.attr("dtype")))); + TF_ASSIGN_OR_RETURN( + auto ifrt_sharding, + GetIfrtSharding(aval.attr("sharding"), ifrt_shape.dims().size())); + ifrt_array_specs.push_back(ifrt::ArraySpec{ + ifrt_dtype, std::move(ifrt_shape), std::move(ifrt_sharding)}); + } + return ifrt_array_specs; +} + +absl::StatusOr> MakePluginProgramFromString( + std::string data) { + auto plugin_program = std::make_unique(); + plugin_program->data = std::move(data); + return plugin_program; +} + +absl::StatusOr> MakePluginProgramFromBytes( + nb::bytes data) { + auto plugin_program = std::make_unique(); + plugin_program->data = std::string(data.c_str(), data.size()); + return plugin_program; +} + +absl::StatusOr> +MakeColocatedPythonCompileOptions() { + return std::make_unique(); +} + +absl::StatusOr> +MakePluginCompileOptions() { + return std::make_unique(); +} + +absl::StatusOr> MakeHloProgram( + std::string_view mlir_module) { + auto context = std::make_unique(); + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + xla::ParseMlirModuleString(mlir_module, *context)); + return std::make_unique(std::move(context), + std::move(module)); +} + +absl::StatusOr> MakeHloProgramFromString( + std::string mlir_module) { + return MakeHloProgram(mlir_module); +} + +absl::StatusOr> MakeHloProgramFromBytes( + nb::bytes mlir_module) { + return MakeHloProgram( + std::string_view(mlir_module.c_str(), mlir_module.size())); +} + +absl::StatusOr> MakeXlaCompileOptions( + xla::CompileOptions options, PyDeviceList& py_executable_devices, + std::vector host_callbacks) { + std::vector> + ifrt_loaded_host_callbacks; + ifrt_loaded_host_callbacks.reserve(host_callbacks.size()); + // Extract `ifrt::LoadedHostCallback`s from host callback capsules that were + // created by `PyClient::MakePythonCallbackUsingHostSendAndRecv()` or + // `PyClient::GetEmitPythonCallbackDescriptor()`. + for (auto& host_callback : host_callbacks) { + ifrt_loaded_host_callbacks.push_back(tsl::FormRef( + static_cast(host_callback.data()))); + } + TF_ASSIGN_OR_RETURN(ifrt::DeviceListRef executable_devices, + py_executable_devices.ifrt_device_list()); + return std::make_unique( + std::move(options), std::move(executable_devices), + std::move(ifrt_loaded_host_callbacks)); +} + +constexpr std::string_view kColocatedPythonProgramType = + "jax_colocated_python_v0.0.1"; + +absl::StatusOr> MakeColocatedPythonProgram( + std::string name, nb::bytes picked_function, nb::sequence devices, + nb::sequence input_avals, nb::sequence output_avals) { + auto ifrt_serialized_program_text = absl::MakeCordFromExternal( + std::string_view(reinterpret_cast(picked_function.data()), + picked_function.size()), + /*releaser=*/[picked_function](std::string_view) mutable { + GlobalPyRefManager()->AddGarbage(std::move(picked_function)); + }); + TF_ASSIGN_OR_RETURN(auto ifrt_device_list, GetDeviceList(devices)); + TF_ASSIGN_OR_RETURN(auto ifrt_input_specs, GetIfrtArraySpecs(input_avals)); + TF_ASSIGN_OR_RETURN(auto ifrt_output_specs, GetIfrtArraySpecs(output_avals)); + return std::make_unique( + std::string(kColocatedPythonProgramType), std::move(name), + std::move(ifrt_serialized_program_text), std::move(ifrt_device_list), + std::move(ifrt_input_specs), std::move(ifrt_output_specs)); +} + +} // namespace + +void BuildIfrtProgramsSubmodule(nanobind::module_& m) { + auto sub_module = m.def_submodule("ifrt_programs"); + sub_module.attr("_CompileOptions") = m.attr("CompileOptions"); + sub_module.attr("_Device") = m.attr("Device"); + sub_module.attr("_DeviceList") = m.attr("DeviceList"); + + nb::class_ ifrt_program_base_class(sub_module, "Program"); + nb::class_ ifrt_compile_options_base_class( + sub_module, "CompileOptions"); + sub_module + .def("make_hlo_program", + xla::ValueOrThrowWrapper(MakeHloProgramFromString), + nb::arg("mlir_module")) + .def("make_hlo_program", + xla::ValueOrThrowWrapper(MakeHloProgramFromBytes), + nb::arg("mlir_module")) + .def("make_colocated_python_program", + xla::ValueOrThrowWrapper(MakeColocatedPythonProgram), + nb::arg("name"), nb::arg("pickled_function"), nb::arg("devices"), + nb::arg("input_avals"), nb::arg("output_avals"), + nb::sig( + // clang-format off + "def make_colocated_python_program(" + "name: str, " + "picked_function: bytes, " + "devices: typing.Sequence[_Device] | _DeviceList, " + "input_avals: Sequence[typing.Any], " + "output_avals: Sequence[Any]" + ") -> Program" + // clang-format on + )) + .def("make_plugin_program", + xla::ValueOrThrowWrapper(MakePluginProgramFromString), + nb::arg("data")) + .def("make_plugin_program", + xla::ValueOrThrowWrapper(MakePluginProgramFromBytes), + nb::arg("data")) + .def("make_xla_compile_options", + xla::ValueOrThrowWrapper(MakeXlaCompileOptions), nb::arg("options"), + nb::arg("executable_devices"), nb::arg("host_callbacks"), + nb::sig( + // clang-format off + "def make_xla_compile_options(" + "options: _CompileOptions, " + "executable_devices: Sequence[_Device], " + "host_callbacks: Sequence[typing_extensions.CapsuleType]" + ") -> CompileOptions" + // clang-format on + )) + .def("make_colocated_python_compile_options", + xla::ValueOrThrowWrapper(MakeColocatedPythonCompileOptions)) + .def("make_plugin_compile_options", + xla::ValueOrThrowWrapper(MakePluginCompileOptions)); +} + +} // namespace jax diff --git a/jaxlib/py_program.h b/jaxlib/py_program.h new file mode 100644 index 000000000000..fccd21aa16f9 --- /dev/null +++ b/jaxlib/py_program.h @@ -0,0 +1,27 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_PY_PROGRAM_H_ +#define JAXLIB_PY_PROGRAM_H_ + +#include "nanobind/nanobind.h" + +namespace jax { + +void BuildIfrtProgramsSubmodule(nanobind::module_& m); + +} // namespace jax + +#endif // JAXLIB_PY_PROGRAM_H_ diff --git a/jaxlib/py_socket_transfer.cc b/jaxlib/py_socket_transfer.cc new file mode 100644 index 000000000000..7349db2a0b6a --- /dev/null +++ b/jaxlib/py_socket_transfer.cc @@ -0,0 +1,566 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "jaxlib/py_socket_transfer.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/time/time.h" +#include "llvm/Support/Casting.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/array.h" // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_array.h" +#include "jaxlib/py_client.h" +#include "jaxlib/py_executable.h" +#include "jaxlib/py_user_context.h" +#include "jaxlib/to_ifrt_sharding.h" +#include "xla/future.h" +#include "xla/pjrt/distributed/client.h" +#include "xla/pjrt/distributed/key_value_store_interface.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/array_spec.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/python/pjrt_ifrt/pjrt_device.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" +#include "xla/python/pjrt_ifrt/pjrt_memory.h" +#include "xla/python/pjrt_ifrt/transfer_server_interface.h" +#include "xla/python/transfer/event_loop.h" +#include "xla/python/transfer/pjrt_transfer_server.h" +#include "xla/python/transfer/socket-server.h" +#include "xla/python/transfer/socket_bulk_transport.h" +#include "xla/python/transfer/streaming.h" +#include "xla/python/transfer/streaming_ifrt.h" +#include "xla/python/transfer/transfer_socket.pb.h" +#include "xla/python/types.h" +#include "xla/python/version.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" +#include "tsl/platform/casts.h" + +namespace aux { + +namespace nb = nanobind; + +absl::StatusOr MemorySpaceFromSharding( + const xla::ifrt::Sharding& sharding) { + if (sharding.devices()->devices().size() != 1) { + return xla::InvalidArgument( + "Can only convert SingleDeviceSharding to MemorySpace not %s", + sharding.DebugString()); + } + auto* device = sharding.devices()->devices()[0]; + if (sharding.memory_kind().memory_kind().has_value()) { + // Find `PjRtMemorySpace` that is associated with the sharding's device + // and matches the sharding's memory_kind. + xla::ifrt::Memory* memory = nullptr; + for (xla::ifrt::Memory* ms : device->Memories()) { + if (ms->Kind() == sharding.memory_kind()) { + memory = ms; + break; + } + } + if (memory == nullptr) { + return xla::InvalidArgument( + "Invalid memory kind: %s; available memory kinds: %s", + *sharding.memory_kind().memory_kind(), + absl::StrJoin(sharding.devices()->devices().front()->Memories(), ", ", + [](std::string* out, xla::ifrt::Memory* ms) { + absl::StrAppend(out, *ms->Kind().memory_kind()); + })); + } + return tensorflow::down_cast(memory)->pjrt_memory(); + } else { + if (!device->IsAddressable()) { + return xla::InvalidArgument( + "Cannot copy array to non-addressable device %s", + device->DebugString()); + } + return tensorflow::down_cast(device) + ->pjrt_device() + ->default_memory_space(); + } +} + +absl::StatusOr> CreatePullEntry( + const std::vector& arrs, + std::shared_ptr state, size_t xfer_size, + bool use_raw_buffers) { + if (use_raw_buffers) { + std::vector refs; + for (auto& arr : arrs) { + auto* pjrt_arr = llvm::dyn_cast_or_null(arr.get()); + if (pjrt_arr == nullptr) { + return absl::InvalidArgumentError( + "Cannot remote transfer non-pjrt arrays."); + } + for (auto& pjrt_buf : pjrt_arr->pjrt_buffers()) { + TF_ASSIGN_OR_RETURN(size_t buf_size, + pjrt_buf->GetOnDeviceSizeInBytes()); + TF_ASSIGN_OR_RETURN( + auto raw_buffer, + xla::PjRtRawBuffer::CreateRawAliasOfBuffer(pjrt_buf.get())); + refs.push_back( + {pjrt_buf->GetReadyFuture(), std::move(raw_buffer), buf_size}); + } + } + return tsl::MakeRef(std::move(refs), state, xfer_size); + } + + std::vector refs; + for (auto& arr : arrs) { + auto* pjrt_arr = llvm::dyn_cast_or_null(arr.get()); + if (pjrt_arr == nullptr) { + return absl::InvalidArgumentError( + "Cannot remote transfer non-pjrt arrays."); + } + for (auto& pjrt_buf : pjrt_arr->pjrt_buffers()) { + TF_ASSIGN_OR_RETURN(size_t buf_size, pjrt_buf->GetOnDeviceSizeInBytes()); + refs.push_back({pjrt_buf, buf_size}); + } + } + return tsl::MakeRef(std::move(refs), state, xfer_size); +} + +class PyTransferServerConnection { + public: + explicit PyTransferServerConnection( + tsl::RCReference conn) + : conn_(std::move(conn)) {} + + void Pull(uint64_t uuid, std::vector buffer_ids, + std::vector> pull_dests) { + if (buffer_ids.size() < 64) { + conn_->Pull(uuid, buffer_ids, std::move(pull_dests)); + } else { + for (size_t i = 0; i < buffer_ids.size(); ++i) { + conn_->Pull(uuid, buffer_ids[i], std::move(pull_dests[i])); + } + } + } + + SocketServer::Connection& conn() { return *conn_; } + + private: + tsl::RCReference conn_; +}; + +class PyTransferServer { + public: + PyTransferServer() = default; + absl::Status Start(xla::ifrt::Client* client, size_t max_num_parallel_copies, + size_t xfer_size, const SocketAddress& addr, + const std::vector& transport_addresses, + bool supports_pinned_allocator, bool use_raw_buffers) { + use_raw_buffers_ = use_raw_buffers; + std::shared_ptr factory; + std::shared_ptr pjrt_client = + tensorflow::down_cast(client) + ->shared_ptr_pjrt_client(); + if (transport_addresses.empty()) { + factory = BulkTransportFactory::CreateLocal(); + } else { + auto tmp = xla::ValueOrThrow( + AllocateAlignedMemory(xfer_size * max_num_parallel_copies)); + SlabAllocator uallocator(xla::ValueOrThrow(MapPjrtMemory( + pjrt_client, tmp->data(), tmp->size(), tmp)), + xfer_size); + std::optional pinned_allocator; + if (supports_pinned_allocator) { + auto tmp = xla::ValueOrThrow( + AllocateNetworkPinnedMemory(xfer_size * max_num_parallel_copies)); + pinned_allocator.emplace( + xla::ValueOrThrow( + MapPjrtMemory(pjrt_client, tmp->data(), tmp->size(), tmp)), + xfer_size); + } + factory = xla::ValueOrThrow(CreateSocketBulkTransportFactory( + transport_addresses, pinned_allocator, uallocator)); + } + + server_ = std::make_shared(); + + TF_ASSIGN_OR_RETURN( + auto mem, AllocateAndMapPjrtMemory( + pjrt_client, max_num_parallel_copies * xfer_size * 2)); + premapped_copier_ = std::make_shared( + mem, max_num_parallel_copies, xfer_size); + xfer_size_ = xfer_size; + return server_->Start(addr, factory); + } + std::string address() { return server_->addr().ToString(); } + + PyTransferServerConnection Connect(const std::string& saddr) { + return PyTransferServerConnection( + server_->Connect(xla::ValueOrThrow(SocketAddress::Parse(saddr)))); + } + + void AwaitPull(uint64_t uuid, const std::vector& arrs) { + server_->AwaitPull( + uuid, xla::ValueOrThrow(CreatePullEntry(arrs, premapped_copier_, + xfer_size_, use_raw_buffers_))); + } + + void Reset() { server_->Reset(); } + + size_t xfer_size() { return xfer_size_; } + + std::shared_ptr premapped_copier() { + return premapped_copier_; + } + + private: + std::shared_ptr server_; + std::shared_ptr premapped_copier_; + size_t xfer_size_; + bool use_raw_buffers_ = false; +}; + +absl::StatusOr ArraySpecFromShapeDtypeStruct( + nb::handle aval) { + TF_ASSIGN_OR_RETURN(xla::ifrt::DType dtype, + xla::DtypeToIfRtDType( + nb::borrow(aval.attr("dtype").ptr()))); + auto shape_dims = nb::cast>(aval.attr("shape")); + auto shape = xla::ifrt::Shape( + xla::ifrt::Shape::Dimensions(shape_dims.begin(), shape_dims.end())); + TF_ASSIGN_OR_RETURN(auto sharding, + jax::GetIfrtHloSharding(aval.attr("sharding"), shape)); + return xla::ifrt::ArraySpec{dtype, std::move(shape), std::move(sharding)}; +} + +struct BufferSource { + xla::ifrt::ArrayRef arr; + xla::PjRtBuffer* buffer; +}; + +struct CopyDests { + std::vector shape_specs; + xla::PjRtMemorySpace* memory_space; +}; + +void RegisterTransferServerTypes(nanobind::module_& m) { + nb::class_(m, "TransferConnection") + .def( + "_testonly_inject_failure", + [](PyTransferServerConnection& self) { self.conn().InjectFailure(); }) + .def("_poison_connection", + [](PyTransferServerConnection& self) { + self.conn().InjectFailure(aux::SocketServer::Connection::kPoison); + }) + .def("_pull_flat", + [](PyTransferServerConnection& self, nb::int_ uuid, + jax::nb_class_ptr py_client, + std::vector py_avals) { + auto* ifrt_client = llvm::dyn_cast_or_null( + py_client->ifrt_client()); + if (ifrt_client == nullptr) { + xla::ThrowIfError(absl::InvalidArgumentError( + "_pull_flat only supported on pjrt-ifrt clients.")); + } + + jax::PyUserContextScope user_context_scope; + std::vector avals; + std::vector shardings; + shardings.reserve(py_avals.size()); + avals.reserve(py_avals.size()); + for (const auto& py_aval : py_avals) { + avals.push_back( + xla::ValueOrThrow(ArraySpecFromShapeDtypeStruct(py_aval))); + shardings.push_back(py_aval.attr("sharding")); + } + + std::vector dests; + std::vector> fetch_idxs; + absl::flat_hash_map mapping; + std::vector>> buffer_list; + + for (auto& aval : avals) { + std::vector> buf_list; + auto prim_type = + xla::ValueOrThrow(xla::ifrt::ToPrimitiveType(aval.dtype)); + auto shards = xla::ValueOrThrow(aval.sharding->Disassemble( + aval.shape, + xla::ifrt::SingleDeviceShardSemantics::kAddressableShards)); + buf_list.reserve(shards.size()); + for (auto& shard : shards) { + auto* mem_space = + xla::ValueOrThrow(MemorySpaceFromSharding(*shard.second)); + int dest_idx = + mapping.emplace(mem_space, static_cast(dests.size())) + .first->second; + if (dest_idx == dests.size()) { + dests.emplace_back(); + dests.back().memory_space = mem_space; + } + fetch_idxs.push_back( + {dest_idx, + static_cast(dests[dest_idx].shape_specs.size())}); + buf_list.push_back(fetch_idxs.back()); + dests[dest_idx].shape_specs.push_back( + {prim_type, + xla::DimensionVector(shard.first.dims().begin(), + shard.first.dims().end())}); + } + buffer_list.push_back(std::move(buf_list)); + } + + std::vector> + atms; + atms.reserve(dests.size()); + + for (auto& dest : dests) { + atms.push_back(xla::ValueOrThrow( + py_client->pjrt_client()->CreateBuffersForAsyncHostToDevice( + dest.shape_specs, std::nullopt, dest.memory_space))); + } + + std::vector> pull_dests; + std::vector buffer_ids; + pull_dests.reserve(fetch_idxs.size()); + buffer_ids.reserve(fetch_idxs.size()); + for (auto& fetch_idx : fetch_idxs) { + auto& atm = atms[fetch_idx.first]; + pull_dests.push_back(MakeDmaDestination( + atm, fetch_idx.second, atm->buffer_size(fetch_idx.second))); + buffer_ids.push_back(static_cast(buffer_ids.size())); + } + + uint64_t uuid_cpp; + try { + uuid_cpp = static_cast(uuid); + } catch (std::out_of_range& e) { + throw nb::value_error( + "_pull_flat requires uuid to fit in a uint64_t"); + } + self.Pull(uuid_cpp, buffer_ids, std::move(pull_dests)); + + std::vector out; + for (size_t i = 0; i < buffer_list.size(); ++i) { + xla::ifrt::PjRtArray::PjRtBuffers buffers; + buffers.reserve(buffer_list[i].size()); + for (auto& v : buffer_list[i]) { + buffers.push_back(atms[v.first]->RetrieveBuffer(v.second)); + } + auto arr = xla::ValueOrThrow(xla::ifrt::PjRtArray::Create( + ifrt_client, avals[i].dtype, avals[i].shape, + avals[i].sharding, std::move(buffers), avals[i].layout)); + out.push_back(jax::PyArray::MakeFromIfrtArrayAndSharding( + py_client, std::move(arr), shardings[i], false, true, + /*skip_checks=*/false)); + } + + return out; + }) + .def("_pull_into_flat", [](PyTransferServerConnection& self, + nb::int_ uuid, std::vector dests, + std::vector slices_per_array) { + if (dests.size() != slices_per_array.size()) { + throw nb::value_error( + absl::StrFormat("Expected dests and slices to have the same " + "size, got: %d vs %d", + dests.size(), slices_per_array.size()) + .c_str()); + } + std::vector> arrs; + arrs.reserve(dests.size()); + for (const jax::PyArray& dest : dests) { + arrs.push_back(tsl::FormRef( + tensorflow::down_cast( + dest.ifrt_array()))); + } + uint64_t uuid_cpp; + try { + uuid_cpp = static_cast(uuid); + } catch (std::out_of_range& e) { + throw nb::value_error( + "_await_pull_flat requires uuid to fit in a uint64_t"); + } + size_t i = 0; + std::vector futures; + std::vector> pull_dests; + std::vector buffer_ids; + for (auto& slice : slices_per_array) { + auto device_size = xla::ValueOrThrow( + arrs[i]->pjrt_buffers()[0]->GetOnDeviceSizeInBytes()); + auto [start, limit, step, total_size] = slice.compute(device_size); + if (step != 1 || start + total_size != limit || limit > device_size) { + throw nb::value_error( + absl::StrFormat("Invalid slice (strides are not supported): %s " + "for buffer of size: %d", + nb::repr(slice).c_str(), device_size) + .c_str()); + } + std::vector> futures_per_array; + for (auto& buffer : arrs[i]->pjrt_buffers()) { + auto raw_buffer = xla::ValueOrThrow( + xla::PjRtRawBuffer::CreateRawAliasOfBuffer(buffer.get())); + tsl::RCReference dest; + xla::Future<> future; + std::tie(dest, future) = xla::ValueOrThrow( + CreateSlicedRawBufferDest(raw_buffer, start, total_size)); + futures_per_array.push_back(std::move(future)); + pull_dests.push_back(std::move(dest)); + buffer_ids.push_back(static_cast(buffer_ids.size())); + } + futures.emplace_back(xla::JoinFutures(futures_per_array)); + ++i; + } + + self.Pull(uuid_cpp, buffer_ids, std::move(pull_dests)); + + return futures; + }); + + nb::class_(m, "TransferServer") + .def("address", [](PyTransferServer& self) { return self.address(); }) + .def("_await_pull_flat", + [](PyTransferServer& self, nb::int_ uuid, + std::vector inputs) { + std::vector arrs; + arrs.reserve(inputs.size()); + for (const jax::PyArray& input : inputs) { + arrs.push_back(tsl::FormRef(input.ifrt_array())); + } + uint64_t uuid_cpp; + try { + uuid_cpp = static_cast(uuid); + } catch (std::out_of_range& e) { + throw nb::value_error( + "_await_pull_flat requires uuid to fit in a uint64_t"); + } + self.AwaitPull(uuid_cpp, arrs); + }) + .def("_reset_rendevous_table", + [](PyTransferServer& self) { self.Reset(); }) + .def("connect", [](PyTransferServer& self, const std::string& address) { + return self.Connect(address); + }); + m.def("_make_error_array", [](jax::nb_class_ptr py_client, + nb::object py_aval, std::string message) { + auto* ifrt_client = + llvm::dyn_cast_or_null(py_client->ifrt_client()); + if (ifrt_client == nullptr) { + xla::ThrowIfError(absl::InvalidArgumentError( + "_pull_flat only supported on pjrt-ifrt clients.")); + } + jax::PyUserContextScope user_context_scope; + auto aval = xla::ValueOrThrow(ArraySpecFromShapeDtypeStruct(py_aval)); + xla::ifrt::PjRtArray::PjRtBuffers buffers; + auto prim_type = xla::ValueOrThrow(xla::ifrt::ToPrimitiveType(aval.dtype)); + auto shards = xla::ValueOrThrow(aval.sharding->Disassemble( + aval.shape, xla::ifrt::SingleDeviceShardSemantics::kAddressableShards)); + buffers.reserve(shards.size()); + for (auto& shard : shards) { + auto* mem_space = + xla::ValueOrThrow(MemorySpaceFromSharding(*shard.second)); + xla::PjRtClient::ShapeSpec shape_spec = { + prim_type, xla::DimensionVector(shard.first.dims().begin(), + shard.first.dims().end())}; + auto atm = xla::ValueOrThrow( + py_client->pjrt_client()->CreateBuffersForAsyncHostToDevice( + {shape_spec}, std::nullopt, mem_space)); + + atm->SetBufferError(0, absl::InternalError(message)); + buffers.push_back(atm->RetrieveBuffer(0)); + } + auto arr = xla::ValueOrThrow(xla::ifrt::PjRtArray::Create( + ifrt_client, aval.dtype, aval.shape, aval.sharding, std::move(buffers), + aval.layout)); + return jax::PyArray::MakeFromIfrtArrayAndSharding( + py_client, std::move(arr), py_aval.attr("sharding"), false, true, + /*skip_checks=*/false); + }); + + m.def( + "start_transfer_server", + [](jax::nb_class_ptr py_client, std::string address, + std::vector transport_addresses_str, + size_t max_num_parallel_copies, size_t transfer_size, + bool supports_pinned_allocator, + bool use_raw_buffers) -> PyTransferServer { + PyTransferServer result; + std::vector transport_addresses; + transport_addresses.reserve(transport_addresses_str.size()); + for (const std::string& addr : transport_addresses_str) { + transport_addresses.push_back( + xla::ValueOrThrow(SocketAddress::Parse(addr))); + } + xla::ThrowIfError(result.Start( + py_client->ifrt_client(), max_num_parallel_copies, transfer_size, + xla::ValueOrThrow(SocketAddress::Parse(address)), + transport_addresses, supports_pinned_allocator, use_raw_buffers)); + return result; + }, + nb::arg("client"), nb::arg("address") = SocketAddress().ToString(), + nb::arg("transport_addresses") = std::vector(), + nb::arg("max_num_parallel_copies") = 8, + nb::arg("transfer_size") = 256 * 1024 * 1024, + // Dual pinning not confirmed to be supported. + nb::arg("supports_pinned_allocator") = false, + // Technically unsafe (because a future donation won't wait for the + // transfer to complete). + nb::arg("use_raw_buffers") = false); + m.def( + "make_transfer_server_interface_factory", + [](size_t transfer_size, int cross_host_transfer_timeout_seconds, + std::shared_ptr distributed_client, + const std::string& socket_address, + const std::vector& transport_addresses) + -> xla::ifrt::TransferServerInterfaceFactory { + std::shared_ptr kv_store = + xla::GetDistributedKeyValueStore(distributed_client, + "transfer_server:"); + auto factory_fn = xla::ValueOrThrow( + xla::ifrt::PjRtTransferServer::MakePjRtTransferServerFactory( + transfer_size, + absl::Seconds(cross_host_transfer_timeout_seconds), kv_store, + socket_address, transport_addresses)); + return xla::ifrt::TransferServerInterfaceFactory{std::move(factory_fn)}; + }, + nb::arg("transfer_size") = 256 * 1024 * 1024, + nb::arg("cross_host_transfer_timeout_seconds") = 60, + nb::arg("distributed_client").none() = nullptr, + nb::arg("socket_address") = SocketAddress().ToString(), + nb::arg("transport_addresses") = std::vector()); +} + +} // namespace aux diff --git a/jaxlib/py_socket_transfer.h b/jaxlib/py_socket_transfer.h new file mode 100644 index 000000000000..1b0236b56889 --- /dev/null +++ b/jaxlib/py_socket_transfer.h @@ -0,0 +1,26 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef JAXLIB_TRANSFER_PY_SOCKET_TRANSFER_H_ +#define JAXLIB_TRANSFER_PY_SOCKET_TRANSFER_H_ + +#include "nanobind/nanobind.h" + +namespace aux { + +void RegisterTransferServerTypes(nanobind::module_& m); + +} // namespace aux + +#endif // JAXLIB_TRANSFER_PY_SOCKET_TRANSFER_H_ diff --git a/jaxlib/py_user_context.cc b/jaxlib/py_user_context.cc new file mode 100644 index 000000000000..b959d551a848 --- /dev/null +++ b/jaxlib/py_user_context.cc @@ -0,0 +1,107 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/py_user_context.h" + +#include + +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/strings/str_format.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/time.h" +#include "llvm/Support/Casting.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "jaxlib/python_ref_manager.h" +#include "jaxlib/traceback.h" +#include "xla/python/ifrt/user_context.h" +#include "xla/service/slow_operation_alarm.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "tsl/platform/random.h" + +namespace jax { + +namespace nb = ::nanobind; + +// For LLVM RTTI. +char PyUserContext::ID = 0; + +xla::ifrt::UserContextRef PyUserContext::Create( + std::optional traceback) { + if (traceback.has_value()) { + return tsl::TakeRef( + new PyUserContext(*std::move(traceback))); + } + return {}; +} + +xla::ifrt::UserContextRef PyUserContext::Create() { + return Create(Traceback::Get()); +} + +PyUserContext::PyUserContext(Traceback traceback) + : id_(tsl::random::ThreadLocalNew64()), traceback_(std::move(traceback)) {} + +PyUserContext::~PyUserContext() { + // The traceback must be destroyed under the GIL. + GlobalPyRefManager()->AddGarbage(std::move(traceback_)); +} + +Traceback PyUserContext::traceback() const { + CHECK(PyGILState_Check()); + return traceback_; +} + +xla::ifrt::UserContextId PyUserContext::Id() const { return id_; } + +std::string PyUserContext::DebugString() const { + absl::MutexLock lock(mu_); + + if (debug_str_.has_value()) { + return *debug_str_; + } + + xla::SlowOperationAlarm slow_gil_alarm( + absl::Seconds(20), + "Acquiring the GIL in PyUserContext::DebugString took longer than 20s. " + "This can occur when an operation blocks while holding the GIL."); + nb::gil_scoped_acquire gil_acquire; + slow_gil_alarm.cancel(); + + try { + debug_str_ = traceback_.ToString(); + } catch (std::exception& e) { + debug_str_ = absl::StrFormat( + "(traceback could not be converted to a string: %s)", e.what()); + } + return *debug_str_; +} + +std::optional GetTraceback( + const xla::ifrt::UserContext* user_context) { + if (const auto* py_user_context = + llvm::dyn_cast_or_null(user_context)) { + return py_user_context->traceback(); + } + return std::nullopt; +} + +} // namespace jax diff --git a/jaxlib/py_user_context.h b/jaxlib/py_user_context.h new file mode 100644 index 000000000000..3c3da2c5a577 --- /dev/null +++ b/jaxlib/py_user_context.h @@ -0,0 +1,101 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_PY_USER_CONTEXT_H_ +#define JAXLIB_PY_USER_CONTEXT_H_ + +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/synchronization/mutex.h" +#include "llvm/Support/ExtensibleRTTI.h" +#include "jaxlib/traceback.h" +#include "xla/python/ifrt/user_context.h" + +namespace jax { + +// IFRT `UserContext` implementation for JAX that captures a Python traceback. +// Can be associated with an IFRT runtime objects such as `xla::ifrt::Array` and +// `xla::ifrt::LoadedExecutable` to track their creation. +// +// All methods are thread-safe. +class PyUserContext + : public llvm::RTTIExtends { + public: + // Creates a `PyUserContext` from a given Python traceback. If `traceback` is + // `nullopt`, returns `nullptr`. + static xla::ifrt::UserContextRef Create(std::optional traceback); + + // Creates a `PyUserContext` with a new `Traceback`. If JAX `Traceback` is not + // enabled, returns `nullptr`. + static xla::ifrt::UserContextRef Create(); + + PyUserContext(const PyUserContext&) = delete; + PyUserContext& operator=(const PyUserContext&) = delete; + + // Destructor. Does not require GIL. + ~PyUserContext() override; + + // Returns the traceback captured by this `PyUserContext`. + // Requires GIL. + Traceback traceback() const; + + // UserContext implementation. + + xla::ifrt::UserContextId Id() const override; + + // Returns a string representation of the traceback captured by this + // `PyUserContext`. + // + // While GIL is not required to call this method, calling `DebugString()` when + // the caller already holds GIL is strongly recommended to reduce the overhead + // of (re)acquiring GIL. + std::string DebugString() const override; + + // For LLVM RTTI. + static char ID; // NOLINT + + private: + explicit PyUserContext(Traceback traceback); + + xla::ifrt::UserContextId id_; + Traceback traceback_; + + // Debug string generation can be expensive. Maintain a cache for them. + mutable absl::Mutex mu_; + mutable std::optional debug_str_ ABSL_GUARDED_BY(mu_); +}; + +// Retrieves a `Traceback` object from an IFRT `UserContext`. Returns `nullopt` +// if no `Traceback` was captured for a `PyUserContext` or `user_context` is not +// a `PyUserContext`. +// +// Requires GIL. +std::optional GetTraceback( + const xla::ifrt::UserContext* user_context); + +// Shorthand for `xla::ifrt::UserContextScope(PyUserContext::Create())`. +class PyUserContextScope { + public: + PyUserContextScope() : user_context_scope_(PyUserContext::Create()) {} + + private: + xla::ifrt::UserContextScope user_context_scope_; +}; + +} // namespace jax + +#endif // JAXLIB_PY_USER_CONTEXT_H_ diff --git a/jaxlib/py_values.cc b/jaxlib/py_values.cc new file mode 100644 index 000000000000..58900145455a --- /dev/null +++ b/jaxlib/py_values.cc @@ -0,0 +1,1233 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/py_values.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/functional/any_invocable.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/complex.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "jaxlib/py_array.h" +#include "jaxlib/python_ref_manager.h" +#include "jaxlib/sharding.h" +#include "jaxlib/to_ifrt_sharding.h" +#include "xla/primitive_util.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/array_spec.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" +#include "xla/python/safe_static_init.h" +#include "xla/python/types.h" +#include "xla/shape.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/python/lib/core/numpy.h" +#include "xla/types.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/ml_dtypes.h" +#include "tsl/profiler/lib/traceme.h" + +namespace nb = nanobind; +namespace ifrt = xla::ifrt; + +namespace jax { + +namespace { + +// The TypedInt, TypedFloat, TypedComplex, and TypedNdArray types. +nb::object& typed_int_type = *new nb::object(); +nb::object& typed_float_type = *new nb::object(); +nb::object& typed_complex_type = *new nb::object(); +nb::object& typed_ndarray_type = *new nb::object(); + +// For xla::S64/U64/F64/C128 types, returns the largest 32-bit equivalent. +xla::PrimitiveType Squash64BitType(xla::PrimitiveType type) { + switch (type) { + case xla::S64: + return xla::S32; + case xla::U64: + return xla::U32; + case xla::F64: + return xla::F32; + case xla::C128: + return xla::C64; + default: + return type; + } +} + +// Gets the thread-local instance. +static DevicePutInfo& GetDevicePutInfo() { + thread_local DevicePutInfo device_put_info; + return device_put_info; +} + +// Prepared data for creating a single shard of an array. Holds a single-device +// IFRT array or a host buffer. +struct Shard { + explicit Shard(ifrt::ArrayRef ifrt_array, bool weak_type) + : ifrt_array_or_host_buffer(std::move(ifrt_array)), + weak_type(weak_type), + // host_buffer_semantics is not meaningful when + // `ifrt_array_or_host_buffer` is an IFRT Array. + host_buffer_semantics( + ifrt::Client::HostBufferSemantics::kImmutableOnlyDuringCall) {} + + Shard(ifrt::Client::HostBuffer ifrt_host_buffer, bool weak_type, + ifrt::Client::HostBufferSemantics host_buffer_semantics) + : ifrt_array_or_host_buffer(std::move(ifrt_host_buffer)), + weak_type(weak_type), + host_buffer_semantics(host_buffer_semantics) {} + + Shard(const Shard&) = delete; + Shard& operator=(const Shard&) = delete; + Shard(Shard&&) noexcept = default; + Shard& operator=(Shard&&) noexcept = default; + + bool is_ifrt_array() const { + return std::holds_alternative(ifrt_array_or_host_buffer); + } + ifrt::DType ifrt_dtype() const; + const ifrt::Shape& ifrt_shape() const; + + // Points to the on-device array or on-host buffer. + std::variant + ifrt_array_or_host_buffer; + bool weak_type; + ifrt::Client::HostBufferSemantics host_buffer_semantics; +}; + +// A function that creates a `Shard` from a Python object when called. +using ShardFn = absl::AnyInvocable() &&>; + +absl::StatusOr> StringDTypeArrayToCords( + PyArrayObject* py_array_obj) { + if (PyArray_SIZE(py_array_obj) == 0) { + return absl::InvalidArgumentError("empty numpy array"); + } + + std::vector cords; + cords.reserve(PyArray_SIZE(py_array_obj)); + + auto iter = + nb::steal(PyArray_IterNew(reinterpret_cast(py_array_obj))); + while (PyArray_ITER_NOTDONE(iter.ptr())) { + auto* iter_data = PyArray_ITER_DATA(iter.ptr()); + auto* item = PyArray_GETITEM(py_array_obj, static_cast(iter_data)); + if (!item) { + return absl::InternalError( + "Failed to get elements out of the ndarray iter."); + } + Py_ssize_t len; + auto str = PyUnicode_AsUTF8AndSize(item, &len); + cords.push_back(absl::Cord(std::string_view(str, len))); + PyArray_ITER_NEXT(iter.ptr()); + } + return cords; +} + +// Handler that creates a `Shard` from a Python object. +using DevicePutHandler = std::function( + nb::handle obj, ifrt::Client* client, ifrt::Device* to_device, + ifrt::MemoryKind to_memory_kind, const DevicePutOptions& options)>; + +// Shared logic that makes an IFRT array (either single-device or multi-device) +// from a fully-replicated `shard` that is created from a host buffer (not from +// an existing IFRT array). `shard` will be consumed. +// +// Expected to be called without holding GIL. +absl::StatusOr> +MakeIfrtArrayFromFullyReplicatedShard(ifrt::Client* ifrt_client, + ifrt::ShardingRef ifrt_sharding, + Shard& shard) { + auto host_buffer_shard = std::get( + std::move(shard.ifrt_array_or_host_buffer)); + return ifrt_client->MakeArrayFromHostBuffer( + host_buffer_shard.data, host_buffer_shard.dtype, + std::move(host_buffer_shard.shape), + std::move(host_buffer_shard.byte_strides), std::move(ifrt_sharding), + shard.host_buffer_semantics, std::move(host_buffer_shard.on_done)); +} + +// Shared logic that makes a single-device IFRT array from a `shard`. `shard` +// will be consumed. +// +// Expected to be called without holding GIL. +absl::StatusOr MakeSingleDeviceIfrtArrayFromShard( + xla::ifrt::Client* ifrt_client, xla::ifrt::Device* ifrt_device, + xla::ifrt::MemoryKind ifrt_memory_kind, Shard& shard) { + if (auto* ifrt_array = + std::get_if(&shard.ifrt_array_or_host_buffer)) { + return std::move(*ifrt_array); + } + ifrt::ShardingRef ifrt_sharding = + ifrt::SingleDeviceSharding::Create(ifrt_device, ifrt_memory_kind); + return MakeIfrtArrayFromFullyReplicatedShard(ifrt_client, + std::move(ifrt_sharding), shard); +} + +using HostBuffer = ifrt::Client::HostBuffer; + +struct HostBufferHash { + // We hash on everything except for on_done. + size_t operator()(const HostBuffer* h) const { + return absl::Hash< + std::tuple&>>()( + {h->data, h->dtype, h->shape, h->byte_strides}); + } +}; + +struct HostBufferEq { + bool operator()(const HostBuffer* lhs, const HostBuffer* rhs) const { + return lhs->data == rhs->data && lhs->dtype == rhs->dtype && + lhs->shape == rhs->shape && lhs->byte_strides == rhs->byte_strides; + } +}; + +// Makes an IFRT Array from `shards` using a batched array creation API (fast +// path). `shards` will be consumed. +// +// Expected to be called without holding GIL. +absl::StatusOr MakeIfrtArrayFromShardsInBatch( + ifrt::Client* ifrt_client, ifrt::DType ifrt_dtype, ifrt::Shape ifrt_shape, + ifrt::ShardingRef ifrt_sharding, absl::Span shards) { + absl::InlinedVector< + std::pair, ifrt::Client::HostBuffer>, 1> + host_buffers; + // Note: Dedup map relies on this reserve to give pointer stability. + host_buffers.reserve(shards.size()); + ifrt::Client::HostBufferSemantics safe_host_semantics = + ifrt::Client::HostBufferSemantics::kImmutableZeroCopy; + + // TODO(hyeontaek): Consider performing this deduplication earlier to avoid + // even constructing the HostBuffers. + absl::flat_hash_map + host_buffer_dedup_map; + host_buffer_dedup_map.reserve(shards.size()); + + for (int64_t i = 0; i < shards.size(); ++i) { + auto insert_idx = host_buffers.size(); + host_buffers.push_back( + {{}, + std::move(std::get( + std::move(shards[i].ifrt_array_or_host_buffer)))}); + auto [it, insert_happened] = host_buffer_dedup_map.emplace( + std::make_pair(&host_buffers.back().second, insert_idx)); + if (!insert_happened) { + std::move(host_buffers.back().second.on_done)(); + host_buffers.pop_back(); + } + host_buffers[it->second].first.push_back(i); + // The minimum host buffer semantics is a safe semantics that can be used + // for all shards when they are created in a single batch. + safe_host_semantics = + std::min(safe_host_semantics, shards[i].host_buffer_semantics); + } + + std::vector specs; + specs.push_back(ifrt::Client::MakeArraysFromHostBufferShardsSpec{ + std::move(host_buffers), + ifrt::ArraySpec{/*dtype=*/ifrt_dtype, + /*shape=*/std::move(ifrt_shape), + /*sharding=*/std::move(ifrt_sharding), + /*layout=*/nullptr}}); + TF_ASSIGN_OR_RETURN(auto arrays, + ifrt_client->MakeArraysFromHostBufferShards( + absl::MakeSpan(specs), safe_host_semantics)); + return std::move(arrays.front()); +} + +// Makes an IFRT Array from `shards` using an array assembly API (slow path). +// `shards` will be consumed. +// +// Expected to be called without holding GIL. +absl::StatusOr MakeIfrtArrayFromShardsWithAssembly( + ifrt::Client* ifrt_client, ifrt::DType ifrt_dtype, ifrt::Shape ifrt_shape, + ifrt::ShardingRef ifrt_sharding, + ifrt::DeviceList* ifrt_addressable_device_list, + ifrt::MemoryKind ifrt_memory_kind, absl::Span shards) { + absl::Span ifrt_addressable_devices = + ifrt_addressable_device_list->devices(); + std::vector ifrt_array_shards; + ifrt_array_shards.reserve(shards.size()); + for (int64_t i = 0; i < shards.size(); ++i) { + TF_ASSIGN_OR_RETURN(ifrt::ArrayRef ifrt_array_shard, + MakeSingleDeviceIfrtArrayFromShard( + ifrt_client, ifrt_addressable_devices[i], + ifrt_memory_kind, shards[i])); + ifrt_array_shards.push_back(std::move(ifrt_array_shard)); + } + return ifrt_client->AssembleArrayFromSingleDeviceArrays( + ifrt_dtype, std::move(ifrt_shape), std::move(ifrt_sharding), + absl::MakeSpan(ifrt_array_shards), ifrt::ArrayCopySemantics::kReuseInput, + ifrt::SingleDeviceShardSemantics::kAddressableShards); +} + +template +absl::StatusOr HandlePythonScalar(nb::handle obj, ifrt::Client* client, + ifrt::Device* to_device, + ifrt::MemoryKind to_memory_kind, + const DevicePutOptions& options) { + T value; + try { + value = nb::cast(obj); + } catch (const std::exception& e) { + return xla::InvalidArgument( + "Unable to convert Python scalar to %s. This most likely means the " + "value (%s) overflows the range of the type.", + xla::PrimitiveType_Name( + xla::primitive_util::NativeToPrimitiveType()), + nb::cast(nb::repr(obj))); + } + + std::variant data; + xla::Shape shape; + xla::PrimitiveType type; + if (std::is_same() || !options.squash_64bit_types) { + data.template emplace<0>(value); + type = xla::primitive_util::NativeToPrimitiveType(); + } else { + // TODO(phawkins): we should check for overflow here, e.g., because of bugs + // like https://github.com/google/jax/issues/2006 + data.template emplace<1>(static_cast(value)); + type = xla::primitive_util::NativeToPrimitiveType(); + } + TF_ASSIGN_OR_RETURN(ifrt::DType ifrt_dtype, ifrt::ToDType(type)); + + return [data, ifrt_dtype]() -> absl::StatusOr { + const void* ptr = std::visit( + [](const auto& v) { return static_cast(&v); }, data); + ifrt::Client::HostBuffer ifrt_host_buffer{ + ptr, ifrt_dtype, ifrt::Shape({}), + /*byte_strides=*/std::nullopt, + /*on_done_with_host_buffer=*/nullptr}; + return Shard(std::move(ifrt_host_buffer), /*weak_type=*/true, + ifrt::Client::HostBufferSemantics::kImmutableOnlyDuringCall); + }; +} + +absl::StatusOr HandlePythonInt(nb::handle obj, ifrt::Client* client, + ifrt::Device* to_device, + ifrt::MemoryKind to_memory_kind, + const DevicePutOptions& options) { + xla::PrimitiveType type; + std::variant data; + + if (options.squash_64bit_types) { + try { + data.emplace<1>(nb::cast(obj)); + } catch (const std::exception& e) { + return xla::InvalidArgument( + "Unable to convert Python scalar to %s. This most likely means the " + "value (%s) overflows the range of the type.", + PrimitiveType_Name( + xla::primitive_util::NativeToPrimitiveType()), + nb::cast(nb::repr(obj))); + } + type = xla::S32; + } else { + try { + data.emplace<0>(nb::cast(obj)); + } catch (const std::exception& e) { + return xla::InvalidArgument( + "Unable to convert Python scalar to %s. This most likely means the " + "value (%s) overflows the range of the type.", + xla::PrimitiveType_Name( + xla::primitive_util::NativeToPrimitiveType()), + nb::cast(nb::repr(obj))); + } + type = xla::S64; + } + TF_ASSIGN_OR_RETURN(ifrt::DType ifrt_dtype, ifrt::ToDType(type)); + return [data, ifrt_dtype]() -> absl::StatusOr { + const void* ptr = std::visit( + [](const auto& v) { return static_cast(&v); }, data); + ifrt::Client::HostBuffer ifrt_host_buffer{ + ptr, ifrt_dtype, ifrt::Shape({}), + /*byte_strides=*/std::nullopt, + /*on_done_with_host_buffer=*/nullptr}; + return Shard(std::move(ifrt_host_buffer), /*weak_type=*/true, + ifrt::Client::HostBufferSemantics::kImmutableOnlyDuringCall); + }; +} + +template +absl::StatusOr HandleNumpyScalar(nb::handle h, ifrt::Client* client, + ifrt::Device* to_device, + ifrt::MemoryKind to_memory_kind, + const DevicePutOptions& options) { + std::variant data; + xla::PrimitiveType type; + // For extension types, ScalarAsCtype returns a pointer to the data. + if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = xla::S2; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = xla::S4; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = xla::U2; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = xla::U4; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = xla::BF16; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = xla::F4E2M1FN; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = xla::F8E3M4; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = xla::F8E4M3; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = xla::F8E4M3FN; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = xla::F8E4M3B11FNUZ; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = xla::F8E5M2; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = xla::F8E4M3FNUZ; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = xla::F8E5M2FNUZ; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = xla::F8E8M0FNU; + } else if (std::is_same() || !options.squash_64bit_types) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<0>()); + type = xla::primitive_util::NativeToPrimitiveType(); + } else { + T value; + PyArray_ScalarAsCtype(h.ptr(), &value); + data.template emplace<1>(static_cast(value)); + type = xla::primitive_util::NativeToPrimitiveType(); + } + std::shared_ptr py_buffer_ref; + if (data.index() == 2) { + py_buffer_ref = + GlobalPyRefManager()->ManageReference(nb::cast(h)); + } + TF_ASSIGN_OR_RETURN(ifrt::DType ifrt_dtype, ifrt::ToDType(type)); + return [data, py_buffer_ref = std::move(py_buffer_ref), + ifrt_dtype]() mutable -> absl::StatusOr { + const void* ptr = std::visit( + [](const auto& v) -> const void* { + if constexpr (std::is_same_v, void*>) { + return v; + } else { + return static_cast(&v); + } + }, + data); + ifrt::Client::HostBuffer ifrt_host_buffer{ + ptr, ifrt_dtype, ifrt::Shape({}), + /*byte_strides=*/std::nullopt, + /*on_done_with_host_buffer=*/ + [py_buffer_ref = + std::move(py_buffer_ref)]() { /* keeps py_buffer_ref alive */ }}; + return Shard(std::move(ifrt_host_buffer), /*weak_type=*/false, + ifrt::Client::HostBufferSemantics::kImmutableOnlyDuringCall); + }; +} + +absl::StatusOr HandleStringNumpyArray( + nb::handle h, ifrt::Client* client, ifrt::Device* to_device, + ifrt::MemoryKind to_memory_kind, const DevicePutOptions& options) { + xla::nb_numpy_ndarray array = nb::cast(h); + auto py_array_obj = reinterpret_cast(array.ptr()); + TF_ASSIGN_OR_RETURN(auto cords, StringDTypeArrayToCords(py_array_obj)); + + // Assemble all the parameters of MakeArrayFromHostBuffer + const void* data = cords.data(); + + // Make an explicit copy of the shape elements so we won't run into complex + // endianness and precision issues that might arise if we reinterpret-casted + // from npy_intp, that can be just 32 bits-wide in some environments + // such as macos_arm64 to const int64_t* that must be 64 bits-wide. + ifrt::Shape::Dimensions dims; + dims.reserve(array.ndim()); + for (int i = 0; i < array.ndim(); ++i) { + dims.push_back(array.shape(i)); + } + ifrt::Shape shape(std::move(dims)); + + auto on_done_with_host_buffer = [cords = std::move(cords)] {}; + + return [data, shape = std::move(shape), + on_done_with_host_buffer = std::move( + on_done_with_host_buffer)]() mutable -> absl::StatusOr { + ifrt::Client::HostBuffer ifrt_host_buffer{ + data, ifrt::DType(ifrt::DType::kString), std::move(shape), + /*byte_strides=*/std::nullopt, std::move(on_done_with_host_buffer)}; + return Shard( + std::move(ifrt_host_buffer), /*weak_type=*/false, + ifrt::Client::HostBufferSemantics::kImmutableUntilTransferCompletes); + }; +} + +absl::StatusOr HandleNumpyArray(nb::handle h, ifrt::Client* client, + ifrt::Device* to_device, + ifrt::MemoryKind to_memory_kind, + const DevicePutOptions& options) { + xla::nb_numpy_ndarray array = nb::cast(h); + + // String numpy arrays require substantially different processing. + if (array.dtype().char_() == (int)'T' || array.dtype().kind() == 'T') { + return HandleStringNumpyArray(h, client, to_device, to_memory_kind, + options); + } + + TF_ASSIGN_OR_RETURN(xla::PrimitiveType type, + DtypeToPrimitiveType(array.dtype())); + + xla::PrimitiveType squashed_type; + if (options.squash_64bit_types) { + squashed_type = Squash64BitType(type); + if (squashed_type != type) { + TF_ASSIGN_OR_RETURN(xla::nb_dtype squashed_dtype, + PrimitiveTypeToNbDtype(squashed_type)); + array = nb::steal(PyArray_CastToType( + reinterpret_cast(array.ptr()), + reinterpret_cast(squashed_dtype.release().ptr()), + /*fortran=*/0)); + } + } else { + squashed_type = type; + } + + absl::InlinedVector dims(array.ndim()); + ifrt::Client::HostBuffer::ByteStrides byte_strides(array.ndim()); + for (int i = 0; i < array.ndim(); ++i) { + dims[i] = array.shape(i); + byte_strides[i] = array.strides(i); + } + const void* data = array.data(); + std::shared_ptr py_buffer_ref = + GlobalPyRefManager()->ManageReference(std::move(array)); + TF_ASSIGN_OR_RETURN(ifrt::DType ifrt_dtype, ifrt::ToDType(squashed_type)); + return [data, ifrt_dtype, dims = std::move(dims), + byte_strides = std::move(byte_strides), + py_buffer_ref = std::move(py_buffer_ref), + allow_zero_copy = + options.allow_zero_copy]() mutable -> absl::StatusOr { + ifrt::Client::HostBufferSemantics host_buffer_semantics = + ifrt::Client::HostBufferSemantics::kImmutableOnlyDuringCall; + std::function on_done_with_host_buffer; + if (allow_zero_copy) { + on_done_with_host_buffer = + [py_buffer_ref{ + std::move(py_buffer_ref)}]() { /* keeps py_buffer_ref alive */ }; + host_buffer_semantics = + ifrt::Client::HostBufferSemantics::kImmutableZeroCopy; + } + + ifrt::Client::HostBuffer ifrt_host_buffer{ + data, ifrt_dtype, ifrt::Shape(dims), std::move(byte_strides), + std::move(on_done_with_host_buffer)}; + return Shard(std::move(ifrt_host_buffer), /*weak_type=*/false, + host_buffer_semantics); + }; +} + +absl::StatusOr HandleTypedInt(nb::handle h, ifrt::Client* client, + ifrt::Device* to_device, + ifrt::MemoryKind to_memory_kind, + const DevicePutOptions& options) { + xla::nb_dtype dtype = nb::cast(h.attr("dtype")); + TF_ASSIGN_OR_RETURN(xla::PrimitiveType type, + xla::DtypeToPrimitiveType(dtype)); + switch (type) { + case xla::S64: + return HandlePythonScalar(h, client, to_device, + to_memory_kind, options); + case xla::S32: + return HandlePythonScalar(h, client, to_device, + to_memory_kind, options); + default: + return xla::InvalidArgument("Unsupported type: %s", + xla::PrimitiveType_Name(type)); + } +} + +absl::StatusOr HandleTypedFloat(nb::handle h, ifrt::Client* client, + ifrt::Device* to_device, + ifrt::MemoryKind to_memory_kind, + const DevicePutOptions& options) { + xla::nb_dtype dtype = nb::cast(h.attr("dtype")); + TF_ASSIGN_OR_RETURN(xla::PrimitiveType type, + xla::DtypeToPrimitiveType(dtype)); + switch (type) { + case xla::F64: + return HandlePythonScalar(h, client, to_device, + to_memory_kind, options); + case xla::F32: + return HandlePythonScalar(h, client, to_device, + to_memory_kind, options); + default: + return xla::InvalidArgument("Unsupported type: %s", + xla::PrimitiveType_Name(type)); + } +} + +absl::StatusOr HandleTypedComplex(nb::handle h, ifrt::Client* client, + ifrt::Device* to_device, + ifrt::MemoryKind to_memory_kind, + const DevicePutOptions& options) { + xla::nb_dtype dtype = nb::cast(h.attr("dtype")); + TF_ASSIGN_OR_RETURN(xla::PrimitiveType type, + xla::DtypeToPrimitiveType(dtype)); + switch (type) { + case xla::C128: + return HandlePythonScalar( + h, client, to_device, to_memory_kind, options); + case xla::C64: + return HandlePythonScalar( + h, client, to_device, to_memory_kind, options); + default: + return xla::InvalidArgument("Unsupported type: %s", + xla::PrimitiveType_Name(type)); + } +} + +absl::StatusOr HandleTypedNdArray(nb::handle h, ifrt::Client* client, + ifrt::Device* to_device, + ifrt::MemoryKind to_memory_kind, + const DevicePutOptions& options) { + DevicePutOptions o = options; + o.squash_64bit_types = false; + return HandleNumpyArray(h.attr("val"), client, to_device, to_memory_kind, o); +} + +absl::StatusOr HandlePyArray(nb::handle obj, ifrt::Client* client, + ifrt::Device* to_device, + ifrt::MemoryKind to_memory_kind, + const DevicePutOptions& options) { + auto py_array = nb::borrow(obj); + + // We only allow single device case for PyArray in device put. + if (py_array.num_shards() != 1) { + return xla::InvalidArgument( + "device_put expects an array with exactly one shard, got an array with " + "with %d shards.", + py_array.num_shards()); + } + + ifrt::Array* ifrt_array = py_array.ifrt_array(); + if (ifrt_array == nullptr) { + return xla::InvalidArgument("Array has been deleted."); + } + + // Fallback to python for non-matching clients or pmap sharding. + if (py_array.sharding().type().ptr() == PmapSharding::type().ptr() || + ifrt_array->sharding().devices()->devices().front()->client() != + to_device->client()) { + return HandleNumpyArray(obj.attr("_value"), client, to_device, + to_memory_kind, options); + } + + if (ifrt_array->sharding().devices()->devices().front() == to_device && + options.allow_zero_copy && + (!to_memory_kind.memory_kind().has_value() || + !ifrt_array->sharding().memory_kind().memory_kind().has_value() || + ifrt_array->sharding().memory_kind() == to_memory_kind)) { + Shard result(tsl::FormRef(ifrt_array), py_array.weak_type()); + return [result = std::move(result)]() mutable { return std::move(result); }; + } else { + return [ifrt_array = tsl::FormRef(ifrt_array), to_device, to_memory_kind, + weak_type = py_array.weak_type(), + allow_zero_copy = + options.allow_zero_copy]() mutable -> absl::StatusOr { + auto* ifrt_client = ifrt_array->client(); + TF_ASSIGN_OR_RETURN(ifrt::DeviceListRef device_list, + ifrt_client->MakeDeviceList({to_device})); + TF_ASSIGN_OR_RETURN( + auto copied_ifrt_arrays, + ifrt_client->CopyArrays(absl::MakeSpan(&ifrt_array, 1), + std::move(device_list), to_memory_kind, + allow_zero_copy + ? ifrt::ArrayCopySemantics::kReuseInput + : ifrt::ArrayCopySemantics::kAlwaysCopy)); + return Shard(std::move(copied_ifrt_arrays.front()), weak_type); + }; + } +} + +ifrt::DType Shard::ifrt_dtype() const { + if (is_ifrt_array()) { + return std::get(ifrt_array_or_host_buffer)->dtype(); + } else { + return std::get(ifrt_array_or_host_buffer).dtype; + } +} + +const ifrt::Shape& Shard::ifrt_shape() const { + if (is_ifrt_array()) { + return std::get(ifrt_array_or_host_buffer)->shape(); + } else { + return std::get(ifrt_array_or_host_buffer).shape; + } +} + +// Creates a `ShardFn` that copies `arg` to `to_device` and `to_memory_kind`. +// +// Requires GIL. The returned `ShardFn` should be called without GIL held. +absl::StatusOr MakeShardFn(nb::handle arg, ifrt::Client* client, + ifrt::Device* to_device, + ifrt::MemoryKind to_memory_kind, + const DevicePutOptions& options) { + using PyObjectDeviceHandlerMap = + absl::flat_hash_map; + + auto init_fn = []() { + std::unique_ptr p = + std::make_unique(); + + const xla::NumpyScalarTypes& dtypes = xla::GetNumpyScalarTypes(); + // Python scalar types. + static_assert(sizeof(bool) == 1, "Conversion code assumes bool is 1 byte"); + (*p)[reinterpret_cast(&PyBool_Type)] = + HandlePythonScalar; + (*p)[reinterpret_cast(&PyLong_Type)] = HandlePythonInt; + (*p)[reinterpret_cast(&PyFloat_Type)] = + HandlePythonScalar; + (*p)[reinterpret_cast(&PyComplex_Type)] = + HandlePythonScalar; + + if (typed_int_type.ptr() != nullptr) { + (*p)[typed_int_type.ptr()] = HandleTypedInt; + } + if (typed_float_type.ptr() != nullptr) { + (*p)[typed_float_type.ptr()] = HandleTypedFloat; + } + if (typed_complex_type.ptr() != nullptr) { + (*p)[typed_complex_type.ptr()] = HandleTypedComplex; + } + (*p)[reinterpret_cast(&PyArray_Type)] = HandleNumpyArray; + + if (typed_ndarray_type.ptr() != nullptr) { + (*p)[typed_ndarray_type.ptr()] = HandleTypedNdArray; + } + // Numpy scalar types. For some of them, we share the handler with + // Python types (np_int64, np_float64, np_complex128). + (*p)[dtypes.np_bool.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_int4.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_int2.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_int8.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_int16.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_int32.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_int64.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_uint2.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_uint4.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_uint8.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_uint16.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_uint32.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_uint64.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_float4_e2m1fn.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_float8_e3m4.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_float8_e4m3.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_float8_e4m3fn.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_float8_e4m3b11fnuz.ptr()] = + HandleNumpyScalar; + (*p)[dtypes.np_float8_e5m2.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_float8_e4m3fnuz.ptr()] = + HandleNumpyScalar; + (*p)[dtypes.np_float8_e5m2fnuz.ptr()] = + HandleNumpyScalar; + (*p)[dtypes.np_float8_e8m0fnu.ptr()] = + HandleNumpyScalar; + (*p)[dtypes.np_bfloat16.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_float16.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_float32.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_float64.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_complex64.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_complex128.ptr()] = + HandleNumpyScalar; + static_assert(sizeof(long long) == sizeof(int64_t), // NOLINT + "long long must be the same size as int64_t"); + (*p)[dtypes.np_longlong.ptr()] = HandleNumpyScalar; + static_assert(sizeof(int) == sizeof(int32_t), + "int must be the same size as int32_t"); + (*p)[dtypes.np_intc.ptr()] = HandleNumpyScalar; + return p; + }; + const PyObjectDeviceHandlerMap& handlers = + xla::SafeStaticInit(init_fn); + + if (arg.type().ptr() == PyArray::type().ptr()) { + auto array = nb::borrow(arg); + return HandlePyArray(arg, client, to_device, to_memory_kind, options); + } + + auto res = handlers.find(arg.type().ptr()); + if (res == handlers.end()) { + for (auto base_class : arg.type().attr("__mro__")) { + res = handlers.find(base_class.ptr()); + if (res != handlers.end()) { + return res->second(arg, client, to_device, to_memory_kind, options); + } + } + return xla::InvalidArgument( + "%s", absl::StrCat( + "Not supported: The C++ jax jit execution path, only accepts " + "DeviceArray, Numpy arrays scalars of supported types " + "(see implementation), or Python scalars. Got type ", + nb::cast(nb::str(arg.type())))); + } + return res->second(arg, client, to_device, to_memory_kind, options); +} + +} // namespace + +void SetTypedIntType(nb::object t) { typed_int_type = t; } +void SetTypedFloatType(nb::object t) { typed_float_type = t; } +void SetTypedComplexType(nb::object t) { typed_complex_type = t; } +void SetTypedNdArrayType(nb::object t) { typed_ndarray_type = t; } + +std::string PyArgSignature::DebugString() const { + std::string result = ""; + if (weak_type) { + absl::StrAppend(&result, "weak_"); + } + absl::StrAppend(&result, xla::PrimitiveType_Name(dtype)); + absl::StrAppend(&result, "[", absl::StrJoin(shape, ","), "]"); + return result; +} + +using ToPyArgSignatureHandler = + std::function(nb::handle, bool)>; + +absl::StatusOr PyArgSignatureOfValue(nb::handle arg, + bool jax_enable_x64) { + const absl::flat_hash_map& handlers = + xla::SafeStaticInit< + absl::flat_hash_map>([] { + auto p = std::make_unique< + absl::flat_hash_map>(); + + const xla::NumpyScalarTypes& dtypes = xla::GetNumpyScalarTypes(); + + // The 4 Python native types. + ToPyArgSignatureHandler bool_handler = + [](nb::handle, bool) -> absl::StatusOr { + return PyArgSignature(xla::PrimitiveType::PRED, {}, true); + }; + ToPyArgSignatureHandler int_handler = + [](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { + // TODO(phawkins): we should consider checking for integer + // overflow. + if (jax_enable_x64) { + return PyArgSignature(xla::PrimitiveType::S64, {}, true); + } else { + return PyArgSignature(xla::PrimitiveType::S32, {}, true); + } + }; + ToPyArgSignatureHandler float_handler = + [&dtypes](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { + // Only Python native types has a True weak_type. + bool weak_type = !nb::isinstance(h, dtypes.np_float64); + if (jax_enable_x64) { + return PyArgSignature(xla::PrimitiveType::F64, {}, weak_type); + } else { + return PyArgSignature(xla::PrimitiveType::F32, {}, weak_type); + } + }; + ToPyArgSignatureHandler complex_handler = + [&dtypes](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { + // Note that this branch is also taken for np.complex128: + // isinstance(np.complex128(3), complex) returns True + // isinstance(np.complex64(3), complex) returns False + bool weak_type = !nb::isinstance(h, dtypes.np_complex128); + if (jax_enable_x64) { + return PyArgSignature(xla::PrimitiveType::C128, {}, weak_type); + } else { + return PyArgSignature(xla::PrimitiveType::C64, {}, weak_type); + } + }; + + (*p)[reinterpret_cast(&PyBool_Type)] = bool_handler; + (*p)[reinterpret_cast(&PyLong_Type)] = int_handler; + (*p)[reinterpret_cast(&PyFloat_Type)] = float_handler; + (*p)[reinterpret_cast(&PyComplex_Type)] = complex_handler; + + ToPyArgSignatureHandler typed_scalar_handler = + [](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { + TF_ASSIGN_OR_RETURN( + xla::PrimitiveType dtype, + DtypeToPrimitiveType(nb::cast(h.attr("dtype")))); + return PyArgSignature(dtype, {}, true); + }; + if (typed_int_type.ptr() != nullptr) { + (*p)[typed_int_type.ptr()] = typed_scalar_handler; + } + if (typed_float_type.ptr() != nullptr) { + (*p)[typed_float_type.ptr()] = typed_scalar_handler; + } + if (typed_complex_type.ptr() != nullptr) { + (*p)[typed_complex_type.ptr()] = typed_scalar_handler; + } + + ToPyArgSignatureHandler numpy_handler = + [](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { + xla::nb_numpy_ndarray numpy_array = + nb::cast(h); + TF_ASSIGN_OR_RETURN(xla::PrimitiveType dtype, + DtypeToPrimitiveType(numpy_array.dtype())); + if (!jax_enable_x64) { + dtype = Squash64BitType(dtype); + } + // We use reinterpret_cast<> to defend against environments where + // ssize_t may not be precisely the same type as int64_t, even if + // it is the same size (long vs long long). + static_assert(sizeof(int64_t) == sizeof(ssize_t), + "Code assumes ssize_t is the same as int64_t"); + return PyArgSignature( + dtype, + absl::MakeConstSpan( + reinterpret_cast(numpy_array.shape()), + numpy_array.ndim()), + /*weak_type=*/false); + }; + (*p)[reinterpret_cast(&PyArray_Type)] = numpy_handler; + + ToPyArgSignatureHandler typed_ndarray_handler = + [numpy_handler]( + nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { + return numpy_handler(h.attr("val"), /*jax_enable_x64=*/true); + }; + if (typed_ndarray_type.ptr() != nullptr) { + (*p)[typed_ndarray_type.ptr()] = typed_ndarray_handler; + } + + ToPyArgSignatureHandler np_uint64_handler = + [](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { + if (jax_enable_x64) { + return PyArgSignature(xla::PrimitiveType::U64, {}, + /*weak_type=*/false); + } else { + return PyArgSignature(xla::PrimitiveType::U32, {}, + /*weak_type=*/false); + } + }; + ToPyArgSignatureHandler np_int_handler = + [](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { + if (jax_enable_x64) { + return PyArgSignature(xla::PrimitiveType::S64, {}, + /*weak_type=*/false); + } else { + return PyArgSignature(xla::PrimitiveType::S32, {}, + /*weak_type=*/false); + } + }; + ToPyArgSignatureHandler numpy_array_handler = + [](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { + // This block deals with all numpy scalar types, except for + // int64_dt, float64_dt and complex128_dt which are taken care of + // in previous if blocks. + TF_ASSIGN_OR_RETURN(auto dtype, + xla::DtypeToPrimitiveType(h.attr("dtype"))); + return PyArgSignature(dtype, {}, /*weak_type=*/false); + }; + + // This block deals with all numpy scalar types, except for + // int64_dt, float64_dt and complex128_dt which are taken care of in + // previous if blocks. + (*p)[dtypes.np_bool.ptr()] = numpy_array_handler; + (*p)[dtypes.np_int4.ptr()] = numpy_array_handler; + (*p)[dtypes.np_int8.ptr()] = numpy_array_handler; + (*p)[dtypes.np_int16.ptr()] = numpy_array_handler; + (*p)[dtypes.np_int32.ptr()] = numpy_array_handler; + (*p)[dtypes.np_int64.ptr()] = np_int_handler; + (*p)[dtypes.np_uint4.ptr()] = numpy_array_handler; + (*p)[dtypes.np_uint8.ptr()] = numpy_array_handler; + (*p)[dtypes.np_uint16.ptr()] = numpy_array_handler; + (*p)[dtypes.np_uint32.ptr()] = numpy_array_handler; + (*p)[dtypes.np_uint64.ptr()] = np_uint64_handler; + (*p)[dtypes.np_float4_e2m1fn.ptr()] = numpy_array_handler; + (*p)[dtypes.np_float8_e3m4.ptr()] = numpy_array_handler; + (*p)[dtypes.np_float8_e4m3.ptr()] = numpy_array_handler; + (*p)[dtypes.np_float8_e4m3fn.ptr()] = numpy_array_handler; + (*p)[dtypes.np_float8_e4m3b11fnuz.ptr()] = numpy_array_handler; + (*p)[dtypes.np_float8_e4m3fnuz.ptr()] = numpy_array_handler; + (*p)[dtypes.np_float8_e5m2.ptr()] = numpy_array_handler; + (*p)[dtypes.np_float8_e5m2fnuz.ptr()] = numpy_array_handler; + (*p)[dtypes.np_float8_e8m0fnu.ptr()] = numpy_array_handler; + (*p)[dtypes.np_float16.ptr()] = numpy_array_handler; + (*p)[dtypes.np_bfloat16.ptr()] = numpy_array_handler; + (*p)[dtypes.np_float32.ptr()] = numpy_array_handler; + (*p)[dtypes.np_float64.ptr()] = float_handler; + (*p)[dtypes.np_complex64.ptr()] = numpy_array_handler; + (*p)[dtypes.np_complex128.ptr()] = complex_handler; + (*p)[dtypes.np_longlong.ptr()] = np_int_handler; + (*p)[dtypes.np_intc.ptr()] = numpy_array_handler; + + return p; + }); + + if (arg.type().ptr() == PyArray::type().ptr()) { + auto array = nb::borrow(arg); + ifrt::Array* ifrt_array = array.ifrt_array(); + if (ifrt_array == nullptr) { + return xla::InvalidArgument("Array has been deleted."); + } + TF_ASSIGN_OR_RETURN(auto primitive_type, + ifrt::ToPrimitiveType(ifrt_array->dtype())); + return PyArgSignature(primitive_type, array.shape(), array.weak_type()); + } + + auto res = handlers.find(arg.type().ptr()); + if (res == handlers.end()) { + // We attempt to look at the MRO classes + for (auto base_class : arg.type().attr("__mro__")) { + res = handlers.find(base_class.ptr()); + if (res != handlers.end()) { + return res->second(arg, jax_enable_x64); + } + } + return xla::InvalidArgument( + "%s", + absl::StrCat("Not supported: The C++ ToPyArgSignature only accepts " + "JAX Arrays, Numpy arrays and scalars of supported types " + "(see implementation), or Python scalars. Got type ", + nb::cast(nb::str(arg.type())))); + } + return res->second(arg, jax_enable_x64); +} + +absl::StatusOr DevicePutWithDevice( + nanobind::handle addressable_shard, ifrt::Client* ifrt_client, + ifrt::Device* ifrt_device, ifrt::MemoryKind ifrt_memory_kind, + const DevicePutOptions& options) { + tsl::profiler::TraceMe traceme("DevicePut"); + ++GetDevicePutInfo().device_put_with_device; + + if (!ifrt_device->IsAddressable()) { + return xla::InvalidArgument( + "Cannot copy array to non-addressable device: %s", + ifrt_device->DebugString()); + } + + TF_ASSIGN_OR_RETURN(ShardFn shard_fn, + MakeShardFn(addressable_shard, ifrt_client, ifrt_device, + ifrt_memory_kind, options)); + + nb::gil_scoped_release gil_release; + + TF_ASSIGN_OR_RETURN(Shard shard, std::move(shard_fn)()); + TF_ASSIGN_OR_RETURN(ifrt::ArrayRef ifrt_array, + MakeSingleDeviceIfrtArrayFromShard( + ifrt_client, ifrt_device, ifrt_memory_kind, shard)); + return DevicePutResult(std::move(ifrt_array), shard.weak_type); +} + +absl::StatusOr DevicePutWithSharding( + absl::Span addressable_shards, + ifrt::Client* ifrt_client, const xla::nb_dtype& dtype, + absl::Span shape, nanobind::handle sharding, + const DevicePutOptions& options) { + tsl::profiler::TraceMe traceme("DevicePutWithSharding"); + ++GetDevicePutInfo().device_put_with_sharding; + + TF_ASSIGN_OR_RETURN(ifrt::DeviceListRef ifrt_device_list, + GetIfrtDeviceList(sharding)); + ifrt::DeviceList* ifrt_addressable_device_list = + ifrt_device_list->AddressableDeviceList(); + absl::Span ifrt_addressable_devices = + ifrt_addressable_device_list->devices(); + // Pmap sharding requires special handling because it needs a shard shape + // upfront. + const bool is_pmap_sharding = sharding.type().is(PmapSharding::type()); + + if (addressable_shards.size() != ifrt_addressable_devices.size()) { + // Try to generate a friendly error message if the user attempted to copy to + // a non-addressable device. + if (addressable_shards.size() > ifrt_addressable_devices.size()) { + for (ifrt::Device* device : ifrt_device_list->devices()) { + if (!device->IsAddressable()) { + return xla::InvalidArgument( + "Cannot copy array to non-addressable device: %s", + device->DebugString()); + } + } + } + // Otherwise, generate a generic error message. + return xla::InvalidArgument( + "Number of addressable shard data does not match the number " + "of addressable devices in the sharding: %d vs. %d", + addressable_shards.size(), ifrt_addressable_devices.size()); + } + if (is_pmap_sharding && addressable_shards.empty()) { + return xla::InvalidArgument( + "Pmap sharding requires at least one addressable shard."); + } + + TF_ASSIGN_OR_RETURN(ifrt::DType ifrt_dtype, DtypeToIfRtDType(dtype)); + ifrt::Shape ifrt_shape(shape); + ifrt::MemoryKind ifrt_memory_kind = GetMemoryKind(sharding); + + std::vector shard_fns; + shard_fns.reserve(addressable_shards.size()); + for (int i = 0; i < addressable_shards.size(); ++i) { + TF_ASSIGN_OR_RETURN( + ShardFn shard, + MakeShardFn(addressable_shards[i], ifrt_client, + ifrt_addressable_devices[i], ifrt_memory_kind, options)); + shard_fns.push_back(std::move(shard)); + } + + ifrt::ShardingRef ifrt_sharding; + bool is_fully_replicated; + if (is_pmap_sharding) { + CHECK(!shard_fns.empty()); + // IFRT Sharding will be determined once we discover the shard shape. + is_fully_replicated = false; + } else { + TF_ASSIGN_OR_RETURN(ifrt_sharding, + GetIfrtHloSharding(sharding, ifrt_shape)); + // Fully-replicated shardings enable additional optimizations of using a + // single host buffer. + // TODO(hyeontaek): Enable a similar optimization for partially replicated + // cases to reduce the number of host buffers to obtain. + is_fully_replicated = ifrt_sharding->IsFullyReplicated(); + } + + nb::gil_scoped_release gil_release; + + // Whether to build an IFRT array from host buffers as a single batch. We do + // not batch any shard is already an IFRT array. + bool should_batch = true; + + std::vector shards; + shards.reserve(shard_fns.size()); + for (int64_t i = 0; i < shard_fns.size(); ++i) { + TF_ASSIGN_OR_RETURN(Shard shard, std::move(shard_fns[i])()); + if (shard.is_ifrt_array()) { + // If any shard is an IFRT array, we should assemble shards. + should_batch = false; + } + shards.push_back(std::move(shard)); + if (should_batch && is_fully_replicated) { + // We need only one host buffer for a fully-replicated array. + break; + } + } + // While we have finished calling `shard_fns`, we cannot destroy them until we + // make a call to IFRT array creation. Destroying `shard_fns` would release + // host buffers prematurely and can cause the array creation API to see + // garbage data. + + // TODO(emilyaf): Remove the following and just use ifrt_dtype when tokens are + // supported. + if (!shards.empty()) { + ifrt_dtype = shards.front().ifrt_dtype(); + } + if (is_pmap_sharding) { + ifrt_sharding = ifrt::ConcreteEvenSharding::Create( + ifrt::DeviceListRef(tsl::FormRef(ifrt_addressable_device_list)), + ifrt_memory_kind, ifrt_shape, + /*shard_shape=*/shards.front().ifrt_shape(), + /*is_fully_replicated=*/false); + } + + ifrt::ArrayRef ifrt_array; + if (should_batch) { + if (is_fully_replicated && shards.size() == 1) { + ++GetDevicePutInfo().device_put_fully_replicated; + TF_ASSIGN_OR_RETURN(ifrt_array, MakeIfrtArrayFromFullyReplicatedShard( + ifrt_client, std::move(ifrt_sharding), + shards.front())); + } else { + ++GetDevicePutInfo().device_put_batched; + TF_ASSIGN_OR_RETURN( + ifrt_array, MakeIfrtArrayFromShardsInBatch( + ifrt_client, ifrt_dtype, std::move(ifrt_shape), + std::move(ifrt_sharding), absl::MakeSpan(shards))); + } + } else { + ++GetDevicePutInfo().device_put_assembled; + TF_ASSIGN_OR_RETURN( + ifrt_array, MakeIfrtArrayFromShardsWithAssembly( + ifrt_client, ifrt_dtype, std::move(ifrt_shape), + std::move(ifrt_sharding), ifrt_addressable_device_list, + ifrt_memory_kind, absl::MakeSpan(shards))); + } + const bool weak_type = shards.empty() ? false : shards.front().weak_type; + return DevicePutResult(std::move(ifrt_array), weak_type); +} + +std::unordered_map DevicePutInfo::GetInfo() { + const DevicePutInfo& info = GetDevicePutInfo(); + return std::unordered_map({ + {"device_put_with_device", info.device_put_with_device}, + {"device_put_with_sharding", info.device_put_with_sharding}, + {"device_put_fully_replicated", info.device_put_fully_replicated}, + {"device_put_batched", info.device_put_batched}, + {"device_put_assembled", info.device_put_assembled}, + }); +} + +} // namespace jax diff --git a/jaxlib/py_values.h b/jaxlib/py_values.h new file mode 100644 index 000000000000..f5ef4535d60c --- /dev/null +++ b/jaxlib/py_values.h @@ -0,0 +1,165 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Helpers for converting Python values into buffers. + +#ifndef JAXLIB_PY_VALUES_H_ +#define JAXLIB_PY_VALUES_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/nb_numpy.h" +#include "xla/xla_data.pb.h" + +namespace jax { + +struct DevicePutResult { + DevicePutResult(xla::ifrt::ArrayRef ifrt_array, bool weak_type) + : ifrt_array(std::move(ifrt_array)), weak_type(weak_type) {} + + // Disallow copy. `DevicePutResult` is expected to be consumed by one user. + DevicePutResult(const DevicePutResult&) = delete; + DevicePutResult& operator=(const DevicePutResult&) = delete; + DevicePutResult(DevicePutResult&&) noexcept = default; + DevicePutResult& operator=(DevicePutResult&&) noexcept = default; + + // Points to the on-device array. + xla::ifrt::ArrayRef ifrt_array; + bool weak_type; +}; + +// Options for `DevicePut`. +struct DevicePutOptions { + bool squash_64bit_types = false; + bool allow_zero_copy = true; +}; + +// Copies a buffer-like object to be on device. This version is designed for +// creating a single-device array. +// +// If `addressable_shard` is not convertible to a `PjRtBuffer` from C++, an +// error will be returned; float0s are not supported yet. +// +// If the value is known to be a PyBuffer object, py_buffer can be passed as an +// optimization to avoid a Python->C++ cast. +// +// Requires GIL. This function performs Python work inline, and runs expensive +// C++ work with GIL temporarily released. +// +// May throw exceptions from nanobind in addition to failing via an error +// absl::Status. (We could catch these if needed, but there seems little point.) +absl::StatusOr DevicePutWithDevice( + nanobind::handle addressable_shard, xla::ifrt::Client* ifrt_client, + xla::ifrt::Device* ifrt_device, xla::ifrt::MemoryKind ifrt_memory_kind, + const DevicePutOptions& options); + +// Copies a buffer-like object to be on device. This version is optimized for +// creating a multi-device array. +// +// `addressable_shards` is a list of buffer-like objects to be copied to +// addressable devices specified in `sharding`. +// +// `shape` and `sharding` determine the shape and sharding of the returned IFRT +// Array. +// +// The size of `addressable_shards` must match the number of addressable devices +// in `sharding`. For a Pmap sharding, there must be at least one addressable +// device. +// +// Requires GIL. This function performs Python work inline, and runs expensive +// C++ work with GIL temporarily released. +// +// See the above `DevicePutWithDevice` for other details. +absl::StatusOr DevicePutWithSharding( + absl::Span addressable_shards, + xla::ifrt::Client* ifrt_client, const xla::nb_dtype& dtype, + absl::Span shape, nanobind::handle sharding, + const DevicePutOptions& options); + +// Describes the abstract shape and dtype of an argument. +struct PyArgSignature { + PyArgSignature(xla::PrimitiveType dtype, absl::Span shape, + bool weak_type) + : dtype(dtype), shape(shape.begin(), shape.end()), weak_type(weak_type) {} + // This is the XLA dtype of the object. + const xla::PrimitiveType dtype; + const absl::InlinedVector shape; + // JAX arguments can be of weak type, if and only if they are Python scalars + // or `DeviceArray` values such that `aval.weak_type` is true. + const bool weak_type; + bool operator==(const PyArgSignature& other) const { + return std::tie(dtype, weak_type, shape) == + std::tie(other.dtype, other.weak_type, other.shape); + } + bool operator!=(const PyArgSignature& other) const { + return !(*this == other); + } + std::string DebugString() const; +}; + +// Returns the PyArgSignature associated with an argument. Returns an error if +// the argument is not supported. +absl::StatusOr PyArgSignatureOfValue(nanobind::handle arg, + bool jax_enable_x64); + +template +H AbslHashValue(H h, const PyArgSignature& s) { + h = H::combine(std::move(h), s.dtype); + h = H::combine_contiguous(std::move(h), s.shape.data(), s.shape.size()); + return h; +} + +// Tracks the number of DevicePut calls and subcases. For testing. +struct DevicePutInfo { + // DevicePutWithDevice call count. + int device_put_with_device = 0; + + // DevicePutWithSharding call count. + int device_put_with_sharding = 0; + + // DevicePutWithSharding with a fully replicated sharding. + int device_put_fully_replicated = 0; + // DevicePutWithSharding that made a batched array creation call. + int device_put_batched = 0; + // DevicePutWithSharding that made per-shard creation calls followed by an + // assembly call. + int device_put_assembled = 0; + + // Returns a map of the counters for the current thread. + static std::unordered_map GetInfo(); +}; + +// Tells the C++ code about the Python types TypedInt, TypedFloat, +// TypedComplex, and TypedNdArray. +void SetTypedIntType(nanobind::object t); +void SetTypedFloatType(nanobind::object t); +void SetTypedComplexType(nanobind::object t); +void SetTypedNdArrayType(nanobind::object t); + +} // namespace jax + +#endif // JAXLIB_PY_VALUES_H_ diff --git a/jaxlib/pyinit_stub.c b/jaxlib/pyinit_stub.c new file mode 100644 index 000000000000..7fc873d9ae0e --- /dev/null +++ b/jaxlib/pyinit_stub.c @@ -0,0 +1,28 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Stub that reexports Wrapped_PyInit_module as PyInit_module. + +extern void* Wrapped_PyInit_@MODULE_NAME@(); + +#if defined(WIN32) || defined(_WIN32) +#define EXPORT_SYMBOL __declspec(dllexport) +#else +#define EXPORT_SYMBOL __attribute__ ((visibility("default"))) +#endif + +EXPORT_SYMBOL void* PyInit_@MODULE_NAME@() { + return Wrapped_PyInit_@MODULE_NAME@(); +} diff --git a/jaxlib/python_ref_manager.cc b/jaxlib/python_ref_manager.cc new file mode 100644 index 000000000000..4589b595be3f --- /dev/null +++ b/jaxlib/python_ref_manager.cc @@ -0,0 +1,108 @@ +/* Copyright 2019 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/python_ref_manager.h" + +#include + +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "tsl/profiler/lib/traceme.h" + +namespace jax { + +namespace nb = nanobind; + +PythonRefManager::ManagedPyObjects::ManagedPyObjects( + PythonRefManager* manager, absl::Span objects) + : manager_(manager) { + objects_.reserve(objects.size()); + for (nb::object& object : objects) { + objects_.push_back(std::move(object)); + } +} + +PythonRefManager::ManagedPyObjects::~ManagedPyObjects() { + if (manager_ && !objects_.empty()) { + manager_->AddGarbage(absl::MakeSpan(objects_)); + } +} + +std::shared_ptr +PythonRefManager::ManageReference(nb::object object) { + return std::make_shared(this, + absl::Span(&object, 1)); +} + +std::shared_ptr +PythonRefManager::ManageReferences(absl::Span objects) { + return std::make_shared(this, objects); +} + +void PythonRefManager::AddGarbage(nb::object garbage) { + absl::MutexLock lock(mu_); + // We want to collect arbitrary python garbage (e.g., buffers) aggressively. + garbage_count_.fetch_add(100, std::memory_order_relaxed); + python_garbage_.push_back(std::move(garbage)); +} + +void PythonRefManager::AddGarbage(absl::Span garbage) { + absl::MutexLock lock(mu_); + // We want to collect arbitrary python garbage (e.g., buffers) aggressively. + garbage_count_.fetch_add(100, std::memory_order_relaxed); + for (nb::object& o : garbage) { + python_garbage_.push_back(std::move(o)); + } +} + +void PythonRefManager::AddGarbage( + absl::Span const> garbage) { + absl::MutexLock lock(mu_); + // We don't care about collecting stack frame objects often. We grab a lot of + // tracebacks and the code objects are most likely live for the entire + // process. + garbage_count_.fetch_add(1, std::memory_order_relaxed); + for (const auto& o : garbage) { + python_garbage_.push_back(nb::steal(reinterpret_cast(o.first))); + } +} + +void PythonRefManager::CollectGarbage() { + // TODO(phawkins): we should CHECK(PyGILState_Check()); + tsl::profiler::TraceMe traceme("PythonRefManager::CollectGarbage"); + std::deque garbage; + { + absl::MutexLock lock(mu_); + garbage_count_ = 0; + garbage.swap(python_garbage_); + } + // We defer deleting garbage until the lock is released. It's possible that + // deleting garbage will lead to more Python garbage being added; if we held + // the lock we would deadlock because absl::Mutex is not reentrant. +} + +PythonRefManager* GlobalPyRefManager() { + static PythonRefManager* static_ref_manager = new PythonRefManager(); + return static_ref_manager; +} + +} // namespace jax diff --git a/jaxlib/python_ref_manager.h b/jaxlib/python_ref_manager.h new file mode 100644 index 000000000000..ea0c6ad7bade --- /dev/null +++ b/jaxlib/python_ref_manager.h @@ -0,0 +1,108 @@ +/* Copyright 2019 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_PYTHON_REF_MANAGER_H_ +#define JAXLIB_PYTHON_REF_MANAGER_H_ + +#include + +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/inlined_vector.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" + +namespace jax { + +// Class that manages destruction of Python objects. +// +// We must not destroy Python objects without holding the GIL. However, we +// frequently want to hold references to Python objects for the duration of +// an asynchronous transfer on a Stream, and release our reference when the +// transfer completes. +// +// This class holds references to Python objects outside a GIL scope, that can +// be collected later when the GIL is held by calling CollectGarbage(). +class PythonRefManager { + public: + PythonRefManager() = default; + + // Holds references to a set of nanobind::objects, adding the references to + // the PythonRefManager on destruction. + class ManagedPyObjects { + public: + ManagedPyObjects() = default; + ManagedPyObjects(PythonRefManager* manager, + absl::Span objects); + + ~ManagedPyObjects(); + + ManagedPyObjects(const ManagedPyObjects& other) = delete; + ManagedPyObjects(ManagedPyObjects&& other) = default; + ManagedPyObjects& operator=(const ManagedPyObjects& other) = delete; + ManagedPyObjects& operator=(ManagedPyObjects&& other) noexcept = default; + + private: + PythonRefManager* manager_ = nullptr; + absl::InlinedVector objects_; + }; + + // Creates a managed std::shared_ptr to an object. When the shared_ptr is + // destroyed, the reference to 'object' will be added to python_garbage_, + // and collected next time CollectGarbage() is called. + std::shared_ptr ManageReference(nanobind::object object); + std::shared_ptr ManageReferences( + absl::Span objects); + + // Adds garbage objects to the manager. + void AddGarbage(nanobind::object garbage); + void AddGarbage(absl::Span garbage); + void AddGarbage(absl::Span const> garbage); + + // Releases the contents of python_garbage_. Requires that the GIL is held. + // The client calls this method during API entry points where the GIL is held + // to free any garbage that has accumulated. + void CollectGarbage(); + + // Cheaper version of CollectGarbage() with relaxed consistency and frequency. + // The purpose of this function is to amortize lock acquisition costs over + // a larger number of API calls. + void MaybeCollectGarbage() { + if (garbage_count_.load(std::memory_order_relaxed) >= 100) { + CollectGarbage(); + } + } + + private: + absl::Mutex mu_; + std::deque python_garbage_ ABSL_GUARDED_BY(mu_); + + // Writes to garbage_count_ are protected by mu_, reads are not protected. + std::atomic garbage_count_{0}; +}; + +// A global PythonRefManager. Unless `CollectGarbage()` is called before +// shutdown, this container will hold on to Python objects and thus cause a +// leak. This behavior is similar to `tensorflow::ClearDecRefCache()`. +PythonRefManager* GlobalPyRefManager(); + +} // namespace jax + +#endif // JAXLIB_PYTHON_REF_MANAGER_H_ diff --git a/jaxlib/pytree.cc b/jaxlib/pytree.cc new file mode 100644 index 000000000000..b0a2d7788cb4 --- /dev/null +++ b/jaxlib/pytree.cc @@ -0,0 +1,1959 @@ +/* Copyright 2019 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Caution: this code uses exceptions. The exception use is local to the +// binding code and the idiomatic way to emit Python exceptions. + +#include "jaxlib/pytree.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/tuple.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "nanobind/typing.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/pytree.pb.h" +#include "xla/pjrt/exceptions.h" +#include "xla/tsl/platform/logging.h" + +namespace jax { + +namespace nb = nanobind; + +constexpr int kSequenceKeyHashSalt = 1; +constexpr int kFlattenedIndexKeyHashSalt = 42; + +PyTreeRegistry::PyTreeRegistry(bool enable_none, bool enable_tuple, + bool enable_namedtuple, bool enable_list, + bool enable_dict) { + auto add_builtin_type = [&](PyTypeObject* type_obj, PyTreeKind kind) { + nb::object type = + nb::borrow(reinterpret_cast(type_obj)); + auto registration = std::make_unique(); + registration->kind = kind; + registration->type = type; + CHECK(registrations_.emplace(type, std::move(registration)).second); + }; + if (enable_none) { + add_builtin_type(Py_TYPE(Py_None), PyTreeKind::kNone); + } + if (enable_tuple) { + add_builtin_type(&PyTuple_Type, PyTreeKind::kTuple); + } + enable_namedtuple_ = enable_namedtuple; + if (enable_list) { + add_builtin_type(&PyList_Type, PyTreeKind::kList); + } + if (enable_dict) { + add_builtin_type(&PyDict_Type, PyTreeKind::kDict); + } +} + +void PyTreeRegistry::Register( + nb::object type, nb::callable to_iterable, nb::callable from_iterable, + std::optional to_iterable_with_keys) { + auto registration = std::make_unique(); + registration->kind = PyTreeKind::kCustom; + registration->type = type; + registration->to_iterable = std::move(to_iterable); + registration->from_iterable = std::move(from_iterable); + registration->to_iterable_with_keys = std::move(to_iterable_with_keys); + nb::ft_lock_guard lock(mu_); + auto it = registrations_.emplace(type, std::move(registration)); + if (!it.second) { + throw std::invalid_argument( + absl::StrFormat("Duplicate custom PyTreeDef type registration for %s.", + nb::cast(nb::repr(type)))); + } +} + +void PyTreeRegistry::RegisterDataclass(nb::object type, + std::vector data_fields, + std::vector meta_fields) { + auto registration = std::make_unique(); + registration->kind = PyTreeKind::kDataclass; + registration->type = type; + registration->data_fields = std::move(data_fields); + registration->meta_fields = std::move(meta_fields); + nb::ft_lock_guard lock(mu_); + auto it = registrations_.emplace(type, std::move(registration)); + if (!it.second) { + throw std::invalid_argument(absl::StrFormat( + "Duplicate custom dataclass PyTreeDef type registration for %s.", + nb::cast(nb::repr(std::move(type))))); + } +} + +std::pair +PyTreeRegistry::Registration::ToIterable(nanobind::handle o) const { + nb::object out = to_iterable(o); + nb::tuple leaves_and_aux_data; + if (!nb::try_cast(out, leaves_and_aux_data) || + leaves_and_aux_data.size() != 2) { + throw std::invalid_argument(absl::StrCat( + "The to_iterable function for a custom PyTree node should return " + "a (children, aux_data) tuple, got ", + nb::cast(nb::repr(out)))); + } + nb::iterable leaves; + if (!nb::try_cast(leaves_and_aux_data[0], leaves)) { + throw std::invalid_argument(absl::StrCat( + "The to_iterable function for a custom PyTree node should return " + "a (children, aux_data) tuple where 'children' is iterable, " + "got ", + nb::cast(nb::repr(out)))); + } + return std::make_pair(std::move(leaves), nb::object(leaves_and_aux_data[1])); +} + +std::pair>, nb::object> +PyTreeRegistry::Registration::ToIterableWithKeys(nb::handle o) const { + // Backwards compatibility case: return dummy FlattenedIndexKey for each leaf. + std::vector> result; + if (!to_iterable_with_keys.has_value()) { + auto [leaves, aux_data] = ToIterable(o); + for (nb::handle leaf : leaves) { + result.push_back(std::make_pair( + make_nb_class(result.size()), nb::borrow(leaf))); + } + return std::make_pair(std::move(result), std::move(aux_data)); + } + nb::object out = to_iterable_with_keys.value()(o); + nb::tuple leaves_and_aux_data; + if (!nb::try_cast(out, leaves_and_aux_data) || + leaves_and_aux_data.size() != 2) { + throw std::invalid_argument(absl::StrCat( + "The to_iterable_with_keys function for a custom PyTree " + "node should return a (key_leaf_pairs, aux_data) tuple, got ", + nb::cast(nb::repr(out)))); + } + nb::iterable key_leaf_pairs; + if (!nb::try_cast(leaves_and_aux_data[0], key_leaf_pairs)) { + throw std::invalid_argument(absl::StrCat( + "The to_iterable_with_keys function for a custom PyTree node should " + "return a (key_leaf_pairs, aux_data) tuple where 'key_leaf_pairs' is " + "iterable, got ", + nb::cast(nb::repr(leaves_and_aux_data)))); + } + for (nb::handle key_leaf_pair : key_leaf_pairs) { + nb::tuple key_leaf_pair_tuple; + if (!nb::try_cast(key_leaf_pair, key_leaf_pair_tuple) || + key_leaf_pair_tuple.size() != 2) { + throw std::invalid_argument(absl::StrCat( + "The to_iterable_with_keys function for a custom PyTree node should " + "return a (key_leaf_pairs, aux_data) tuple where 'child", + nb::cast(nb::repr(key_leaf_pair)))); + } + result.push_back(std::make_pair(nb::borrow(key_leaf_pair_tuple[0]), + nb::borrow(key_leaf_pair_tuple[1]))); + } + return std::make_pair(std::move(result), nb::object(leaves_and_aux_data[1])); +} + +int PyTreeRegistry::Registration::tp_traverse(visitproc visit, void* arg) { + Py_VISIT(type.ptr()); + Py_VISIT(to_iterable.ptr()); + Py_VISIT(from_iterable.ptr()); + for (const auto& field : data_fields) { + Py_VISIT(field.ptr()); + } + for (const auto& field : meta_fields) { + Py_VISIT(field.ptr()); + } + return 0; +} + +// Computes the node kind of a given Python object. +PyTreeKind PyTreeRegistry::KindOfObject( + nb::handle obj, PyTreeRegistry::Registration const** custom) const { + const PyTreeRegistry::Registration* registration = Lookup(obj.type()); + if (registration) { + if (registration->kind == PyTreeKind::kCustom || + registration->kind == PyTreeKind::kDataclass) { + *custom = registration; + } else { + *custom = nullptr; + } + return registration->kind; + } else if (nb::isinstance(obj) && nb::hasattr(obj, "_fields")) { + // We can only identify namedtuples heuristically, here by the presence of + // a _fields attribute. + return PyTreeKind::kNamedTuple; + } else { + return PyTreeKind::kLeaf; + } +} + +/*static*/ const PyTreeRegistry::Registration* PyTreeRegistry::Lookup( + nb::handle type) const { + nb::ft_lock_guard lock(mu_); + auto it = registrations_.find(type); + return it == registrations_.end() ? nullptr : it->second.get(); +} + +/*static*/ std::vector GetSortedPyDictKeys(PyObject* py_dict) { + std::vector keys; + keys.reserve(PyDict_Size(py_dict)); + PyObject* key; + Py_ssize_t pos = 0; + while (PyDict_Next(py_dict, &pos, &key, /*value=*/nullptr)) { + keys.push_back(nb::borrow(key)); + } + + try { + std::stable_sort( + keys.begin(), keys.end(), [](const nb::object& a, const nb::object& b) { + int cmp = PyObject_RichCompareBool(a.ptr(), b.ptr(), Py_LT); + if (cmp == -1) { + throw nb::python_error(); + } + return cmp; + }); + } catch (nb::python_error& e) { + nb::raise_from(e, PyExc_ValueError, + "Comparator raised exception while sorting pytree " + "dictionary keys."); + } + return keys; +} + +/*static*/ bool IsSortedPyDictKeysEqual(absl::Span lhs, + absl::Span rhs) { + if (lhs.size() != rhs.size()) { + return false; + } + for (int i = 0; i < lhs.size(); ++i) { + if (lhs[i].not_equal(rhs[i])) { + return false; + } + } + return true; +} + +bool PyTreeDef::operator==(const PyTreeDef& other) const { + if (traversal_.size() != other.traversal_.size()) { + return false; + } + for (size_t i = 0; i < traversal_.size(); ++i) { + const Node& a = traversal_[i]; + const Node& b = other.traversal_[i]; + if (a.kind != b.kind || a.arity != b.arity || + (a.node_data.ptr() == nullptr) != (b.node_data.ptr() == nullptr) || + (a.sorted_dict_keys.size() != b.sorted_dict_keys.size()) || + a.custom != b.custom) { + return false; + } + try { + if (a.node_data && a.node_data.not_equal(b.node_data)) { + return false; + } + } catch (nb::python_error& e) { + nb::raise_from(e, PyExc_ValueError, + "Exception raised while checking equality of metadata " + "fields of pytree. Make sure that metadata fields are " + "hashable and have simple equality semantics. (Note: " + "arrays cannot be passed as metadata fields!)"); + } + if (!IsSortedPyDictKeysEqual(a.sorted_dict_keys, b.sorted_dict_keys)) { + return false; + } + // We don't need to test equality of num_leaves and num_nodes since they + // are derivable from the other node data. + } + return true; +} + +nb::object PyTreeRegistry::FlattenOneLevel(nb::handle x) const { + return FlattenOneLevelImpl(x, /*with_keys=*/false); +} + +nb::object PyTreeRegistry::FlattenOneLevelWithKeys(nb::handle x) const { + return FlattenOneLevelImpl(x, /*with_keys=*/true); +} + +nb::object PyTreeRegistry::FlattenOneLevelImpl(nb::handle x, + bool with_keys) const { + PyTreeRegistry::Registration const* custom; + PyTreeKind kind = KindOfObject(x, &custom); + switch (kind) { + case PyTreeKind::kNone: + return nb::make_tuple(nb::make_tuple(), nb::none()); + case PyTreeKind::kTuple: { + if (with_keys) { + auto size = PyTuple_GET_SIZE(x.ptr()); + nb::object key_leaves = nb::steal(PyTuple_New(size)); + for (int i = 0; i < size; ++i) { + nb::object key = make_nb_class(i); + nb::object value = + nb::borrow(PyTuple_GET_ITEM(x.ptr(), i)); + PyTuple_SET_ITEM(key_leaves.ptr(), i, + nb::make_tuple(key, value).release().ptr()); + } + return nb::make_tuple(std::move(key_leaves), nb::none()); + } + return nb::make_tuple(nb::borrow(x), nb::none()); + } + case PyTreeKind::kList: { + if (with_keys) { + auto size = PyList_GET_SIZE(x.ptr()); + nb::object key_leaves = nb::steal(PyTuple_New(size)); + for (int i = 0; i < size; ++i) { + nb::object key = make_nb_class(i); + nb::object value = + nb::borrow(PyList_GET_ITEM(x.ptr(), i)); + PyTuple_SET_ITEM(key_leaves.ptr(), i, + nb::make_tuple(key, value).release().ptr()); + } + return nb::make_tuple(std::move(key_leaves), nb::none()); + } + return nb::make_tuple(nb::borrow(x), nb::none()); + } + case PyTreeKind::kDict: { + nb::dict dict = nb::borrow(x); + std::vector sorted_keys = GetSortedPyDictKeys(dict.ptr()); + nb::tuple keys = nb::steal(PyTuple_New(sorted_keys.size())); + nb::tuple values = nb::steal(PyTuple_New(sorted_keys.size())); + for (size_t i = 0; i < sorted_keys.size(); ++i) { + nb::object& key = sorted_keys[i]; + nb::object value = nb::object(dict[key]); + if (with_keys) { + value = nb::make_tuple(make_nb_class(key), value); + } + PyTuple_SET_ITEM(values.ptr(), i, value.release().ptr()); + PyTuple_SET_ITEM(keys.ptr(), i, sorted_keys[i].release().ptr()); + } + return nb::make_tuple(std::move(values), std::move(keys)); + } + case PyTreeKind::kNamedTuple: { + nb::tuple in = nb::borrow(x); + nb::list out; + if (with_keys) { + // Get key names from NamedTuple fields. + nb::tuple fields; + if (!nb::try_cast(nb::getattr(in, "_fields"), fields) || + in.size() != fields.size()) { + throw std::invalid_argument( + "A namedtuple's _fields attribute should have the same size as " + "the tuple."); + } + auto field_iter = fields.begin(); + for (nb::handle entry : in) { + out.append(nb::make_tuple( + make_nb_class(nb::str(*field_iter)), entry)); + } + return nb::make_tuple(std::move(out), x.type()); + } + for (size_t i = 0; i < in.size(); ++i) { + out.append(in[i]); + } + return nb::make_tuple(std::move(out), x.type()); + } + case PyTreeKind::kCustom: { + if (with_keys) { + auto [leaves, aux_data] = custom->ToIterableWithKeys(x); + return nb::make_tuple(std::move(leaves), std::move(aux_data)); + } + auto [leaves, aux_data] = custom->ToIterable(x); + return nb::make_tuple(std::move(leaves), std::move(aux_data)); + } + case PyTreeKind::kDataclass: { + auto data_size = custom->data_fields.size(); + nb::list leaves = nb::steal(PyList_New(data_size)); + for (int leaf = 0; leaf < data_size; ++leaf) { + nb::object value = nb::getattr(x, custom->data_fields[leaf]); + if (with_keys) { + value = nb::make_tuple( + make_nb_class(custom->data_fields[leaf]), value); + } + PyList_SET_ITEM(leaves.ptr(), leaf, value.release().ptr()); + } + auto meta_size = custom->meta_fields.size(); + nb::object aux_data = nb::steal(PyTuple_New(meta_size)); + for (int meta_leaf = 0; meta_leaf < meta_size; ++meta_leaf) { + PyTuple_SET_ITEM( + aux_data.ptr(), meta_leaf, + nb::getattr(x, custom->meta_fields[meta_leaf]).release().ptr()); + } + return nb::make_tuple(std::move(leaves), std::move(aux_data)); + } + default: + DCHECK(kind == PyTreeKind::kLeaf); + return nb::none(); + } +} + +/* static */ PyType_Slot PyTreeRegistry::slots_[] = { + {Py_tp_traverse, (void*)PyTreeRegistry::tp_traverse}, + {Py_tp_clear, (void*)PyTreeRegistry::tp_clear}, + {0, nullptr}, +}; + +/* static */ int PyTreeRegistry::tp_traverse(PyObject* self, visitproc visit, + void* arg) { + Py_VISIT(Py_TYPE(self)); + if (!nb::inst_ready(self)) { + return 0; + } + PyTreeRegistry* registry = nb::inst_ptr(self); + nb::ft_lock_guard lock(registry->mu_); + for (const auto& [key, value] : registry->registrations_) { + Py_VISIT(key.ptr()); + int rval = value->tp_traverse(visit, arg); + if (rval != 0) { + return rval; + } + } + return 0; +} + +/* static */ int PyTreeRegistry::tp_clear(PyObject* self) { + PyTreeRegistry* registry = nb::inst_ptr(self); + nb::ft_lock_guard lock(registry->mu_); + registry->registrations_.clear(); + return 0; +} + +/* static */ PyType_Slot DictKey::slots_[] = { + {Py_tp_traverse, (void*)DictKey::tp_traverse}, + {Py_tp_clear, (void*)DictKey::tp_clear}, + {0, nullptr}, +}; + +/* static */ int DictKey::tp_traverse(PyObject* self, visitproc visit, + void* arg) { + DictKey* key = nb::inst_ptr(self); + Py_VISIT(key->key_.ptr()); + return 0; +} + +/* static */ int DictKey::tp_clear(PyObject* self) { + DictKey* dictkey = nb::inst_ptr(self); + nb::object tmp; + std::swap(tmp, dictkey->key_); + return 0; +} + +std::string SequenceKey::ToString() const { + return absl::StrFormat("[%d]", idx_); +} + +std::string SequenceKey::ToReprString() const { + return absl::StrFormat("SequenceKey(idx=%d)", idx_); +} + +std::string DictKey::ToString() const { + return absl::StrFormat("[%s]", nb::cast(nb::repr(key_))); +} + +std::string DictKey::ToReprString() const { + return absl::StrFormat("DictKey(key=%s)", + nb::cast(nb::repr(key_))); +} + +std::string GetAttrKey::ToString() const { + return absl::StrFormat(".%s", nb::cast(name_)); +} + +std::string GetAttrKey::ToReprString() const { + return absl::StrFormat("GetAttrKey(name='%s')", + nb::cast(name_)); +} + +std::string FlattenedIndexKey::ToString() const { + return absl::StrFormat("[]", key_); +} + +std::string FlattenedIndexKey::ToReprString() const { + return absl::StrFormat("FlattenedIndexKey(key=%d)", key_); +} + +bool SequenceKey::Equals(const nb::object& other) { + SequenceKey other_key(0); + if (!nb::try_cast(other, other_key)) return false; + return idx_ == other_key.idx(); +} + +bool DictKey::Equals(const nb::object& other) { + DictKey other_key(nb::none()); + if (!nb::try_cast(other, other_key)) return false; + return key_.equal(other_key.key()); +} + +bool GetAttrKey::Equals(const nb::object& other) { + GetAttrKey other_key(nb::str("")); + if (!nb::try_cast(other, other_key)) return false; + return name_.equal(other_key.name()); +} + +bool FlattenedIndexKey::Equals(const nb::object& other) { + FlattenedIndexKey other_key(0); + if (!nb::try_cast(other, other_key)) return false; + return key_ == other_key.key(); +} + +nanobind::tuple SequenceKey::MatchArgs(nanobind::handle unused) { + return nanobind::make_tuple("idx"); +}; + +nanobind::tuple DictKey::MatchArgs(nanobind::handle unused) { + return nanobind::make_tuple("key"); +}; + +nanobind::tuple GetAttrKey::MatchArgs(nanobind::handle unused) { + return nanobind::make_tuple("name"); +}; + +nanobind::tuple FlattenedIndexKey::MatchArgs(nanobind::handle unused) { + return nanobind::make_tuple("key"); +}; + +/* static */ nb::object MakeKeyPathTuple(std::vector& keypath) { + const std::vector& frozen_keypath = keypath; + nb::object kp_tuple = nb::steal(PyTuple_New(frozen_keypath.size())); + for (int i = 0; i < frozen_keypath.size(); ++i) { + PyTuple_SET_ITEM(kp_tuple.ptr(), i, + nb::object(frozen_keypath[i]).release().ptr()); + } + return kp_tuple; +} + +template +void PyTreeDef::FlattenImpl(nb::handle handle, T& leaves, + std::optional>& keypath, + const std::optional& leaf_predicate) { + Node node; + const int start_num_nodes = traversal_.size(); + const int start_num_leaves = leaves.size(); + bool is_known_leaf = false; + if (leaf_predicate) { + nb::object o; + if (keypath.has_value()) { + auto kp_tuple = MakeKeyPathTuple(keypath.value()); + o = (*leaf_predicate)(kp_tuple, handle); + } else { + o = (*leaf_predicate)(handle); + } + // Historically we accepted "truthy" values from leaf predicates. Accept + // None here to keep existing clients happy. + if (o.is_none()) { + is_known_leaf = false; + } else if (!nb::try_cast(o, is_known_leaf)) { + throw std::invalid_argument(absl::StrCat( + "is_leaf predicate returned a non-boolean value ", + nb::cast(nb::repr(o)), "; expected a boolean")); + } + } + if (is_known_leaf) { + nb::object value = nb::borrow(handle); + if (keypath.has_value()) { + auto kp_tuple = MakeKeyPathTuple(keypath.value()); + value = nb::make_tuple(std::move(kp_tuple), std::move(value)); + } + if constexpr (std::is_same_v) { + leaves.append(std::move(value)); + } else { + leaves.push_back(std::move(value)); + } + } else { + node.kind = registry_->KindOfObject(handle, &node.custom); + auto recurse = [this, &leaf_predicate, &leaves]( + nb::handle child, + std::optional>& keypath) { + if (Py_EnterRecursiveCall( + " in flatten; PyTree may have cyclical node references.")) { + return; + } + FlattenImpl(child, leaves, keypath, leaf_predicate); + Py_LeaveRecursiveCall(); + }; + switch (node.kind) { + case PyTreeKind::kNone: + // Nothing to do. + break; + case PyTreeKind::kTuple: { + node.arity = PyTuple_GET_SIZE(handle.ptr()); + for (int i = 0; i < node.arity; ++i) { + if (keypath.has_value()) { + keypath->push_back(make_nb_class(i)); + } + recurse(PyTuple_GET_ITEM(handle.ptr(), i), keypath); + if (keypath.has_value()) { + keypath->pop_back(); + } + } + break; + } + case PyTreeKind::kList: { + node.arity = PyList_GET_SIZE(handle.ptr()); + for (int i = 0; i < node.arity; ++i) { + if (keypath.has_value()) { + keypath->push_back(make_nb_class(i)); + } + recurse(PyList_GET_ITEM(handle.ptr(), i), keypath); + if (keypath.has_value()) { + keypath->pop_back(); + } + } + break; + } + case PyTreeKind::kDict: { + nb::dict dict = nb::borrow(handle); + + std::vector keys = GetSortedPyDictKeys(dict.ptr()); + for (nb::object& key : keys) { + if (keypath.has_value()) { + keypath->push_back(make_nb_class(key)); + } + recurse(dict[key], keypath); + if (keypath.has_value()) { + keypath->pop_back(); + } + } + node.arity = dict.size(); + node.sorted_dict_keys = std::move(keys); + break; + } + case PyTreeKind::kCustom: { + if (keypath.has_value()) { + auto [leaves, aux_data] = node.custom->ToIterableWithKeys(handle); + node.node_data = std::move(aux_data); + node.arity = 0; + for (auto& [key, leaf] : leaves) { + keypath->push_back(key); + ++node.arity; + recurse(leaf, keypath); + keypath->pop_back(); + } + } else { + auto [leaves, aux_data] = node.custom->ToIterable(handle); + node.node_data = std::move(aux_data); + node.arity = 0; + for (nb::handle entry : leaves) { + ++node.arity; + recurse(entry, keypath); + } + } + break; + } + case PyTreeKind::kDataclass: { + auto meta_size = node.custom->meta_fields.size(); + nb::object aux_data = nb::steal(PyTuple_New(meta_size)); + for (int meta_leaf = 0; meta_leaf < meta_size; ++meta_leaf) { + PyTuple_SET_ITEM( + aux_data.ptr(), meta_leaf, + nb::getattr(handle, node.custom->meta_fields[meta_leaf]) + .release() + .ptr()); + } + node.node_data = std::move(aux_data); + auto data_size = node.custom->data_fields.size(); + node.arity = data_size; + for (int leaf = 0; leaf < data_size; ++leaf) { + if (keypath.has_value()) { + keypath->push_back( + make_nb_class(node.custom->data_fields[leaf])); + } + recurse(nb::getattr(handle, node.custom->data_fields[leaf]), keypath); + if (keypath.has_value()) { + keypath->pop_back(); + } + } + break; + } + case PyTreeKind::kNamedTuple: { + nb::tuple tuple = nb::borrow(handle); + node.arity = tuple.size(); + node.node_data = nb::borrow(tuple.type()); + if (keypath.has_value()) { + // Get key names from NamedTuple fields. + nb::tuple fields; + if (!nb::try_cast(nb::getattr(tuple, "_fields"), fields) || + tuple.size() != fields.size()) { + throw std::invalid_argument( + "A namedtuple's _fields attribute should have the same size as " + "the tuple."); + } + auto field_iter = fields.begin(); + for (nb::handle entry : tuple) { + keypath->push_back(make_nb_class(nb::str(*field_iter))); + field_iter++; + recurse(entry, keypath); + keypath->pop_back(); + } + } else { + for (nb::handle entry : tuple) { + recurse(entry, keypath); + } + } + break; + } + default: + DCHECK(node.kind == PyTreeKind::kLeaf); + auto value = nb::borrow(handle); + if (keypath.has_value()) { + auto kp_tuple = MakeKeyPathTuple(keypath.value()); + value = nb::make_tuple(std::move(kp_tuple), std::move(value)); + } + if constexpr (std::is_same_v) { + leaves.append(std::move(value)); + } else { + leaves.push_back(std::move(value)); + } + } + } + node.num_nodes = traversal_.size() - start_num_nodes + 1; + node.num_leaves = leaves.size() - start_num_leaves; + traversal_.push_back(std::move(node)); +} + +void PyTreeDef::Flatten(nb::handle handle, + absl::InlinedVector& leaves, + std::optional leaf_predicate) { + std::optional> keypath = std::nullopt; + FlattenImpl(handle, leaves, keypath, leaf_predicate); +} + +void PyTreeDef::Flatten(nb::handle handle, std::vector& leaves, + std::optional leaf_predicate) { + std::optional> keypath = std::nullopt; + FlattenImpl(handle, leaves, keypath, leaf_predicate); +} + +void PyTreeDef::Flatten(nb::handle handle, nb::list& leaves, + std::optional leaf_predicate) { + std::optional> keypath = std::nullopt; + FlattenImpl(handle, leaves, keypath, leaf_predicate); +} + +/*static*/ std::pair, nb_class_ptr> +PyTreeDef::Flatten(nb::handle x, nb_class_ptr registry, + std::optional leaf_predicate) { + auto def = make_nb_class(registry); + std::vector leaves; + def->Flatten(x, leaves, leaf_predicate); + return std::make_pair(std::move(leaves), std::move(def)); +} + +void PyTreeDef::FlattenWithPath(nb::handle handle, nanobind::list& leaves, + std::optional leaf_predicate) { + std::optional> keypath = std::vector(); + FlattenImpl(handle, leaves, keypath, leaf_predicate); +} + +/*static*/ bool PyTreeDef::AllLeaves(PyTreeRegistry* registry, + const nb::iterable& x) { + const PyTreeRegistry::Registration* custom; + for (const nb::handle& h : x) { + if (registry->KindOfObject(h, &custom) != PyTreeKind::kLeaf) return false; + } + return true; +} + +template +nb::object PyTreeDef::UnflattenImpl(T leaves) const { + absl::InlinedVector agenda; + auto it = leaves.begin(); + int leaf_count = 0; + for (const Node& node : traversal_) { + if (agenda.size() < node.arity) { + throw std::logic_error("Too few elements for TreeDef node."); + } + switch (node.kind) { + case PyTreeKind::kLeaf: + if (it == leaves.end()) { + throw std::invalid_argument(absl::StrFormat( + "Too few leaves for PyTreeDef; expected %d, got %d", num_leaves(), + leaf_count)); + } + agenda.push_back(nb::borrow(*it)); + ++it; + ++leaf_count; + break; + + case PyTreeKind::kNone: + case PyTreeKind::kTuple: + case PyTreeKind::kNamedTuple: + case PyTreeKind::kList: + case PyTreeKind::kDict: + case PyTreeKind::kCustom: + case PyTreeKind::kDataclass: { + const int size = agenda.size(); + absl::Span span; + if (node.arity > 0) { + span = absl::Span(&agenda[size - node.arity], node.arity); + } + nb::object o = MakeNode(node, span); + agenda.resize(size - node.arity); + agenda.push_back(o); + break; + } + } + } + if (it != leaves.end()) { + throw std::invalid_argument(absl::StrFormat( + "Too many leaves for PyTreeDef; expected %d.", num_leaves())); + } + if (agenda.size() != 1) { + throw std::logic_error("PyTreeDef traversal did not yield a singleton."); + } + return std::move(agenda.back()); +} + +nb::object PyTreeDef::Unflatten(nb::iterable leaves) const { + return UnflattenImpl(leaves); +} + +nb::object PyTreeDef::Unflatten(absl::Span leaves) const { + return UnflattenImpl(leaves); +} + +/*static*/ nb::object PyTreeDef::MakeNode(const PyTreeDef::Node& node, + absl::Span children) { + if (children.size() != node.arity) { + throw std::logic_error("Node arity mismatch."); + } + switch (node.kind) { + case PyTreeKind::kLeaf: + throw std::logic_error("MakeNode not implemented for leaves."); + + case PyTreeKind::kNone: + return nb::none(); + + case PyTreeKind::kTuple: + case PyTreeKind::kNamedTuple: { + nb::object tuple = nb::steal(PyTuple_New(node.arity)); + for (int i = 0; i < node.arity; ++i) { + PyTuple_SET_ITEM(tuple.ptr(), i, children[i].release().ptr()); + } + if (node.kind == PyTreeKind::kNamedTuple) { + return node.node_data(*tuple); + } else { + return tuple; + } + } + + case PyTreeKind::kList: { + nb::object list = nb::steal(PyList_New(node.arity)); + for (int i = 0; i < node.arity; ++i) { + PyList_SET_ITEM(list.ptr(), i, children[i].release().ptr()); + } + return list; + } + + case PyTreeKind::kDict: { + nb::dict dict; + for (int i = 0; i < node.arity; ++i) { + dict[node.sorted_dict_keys[i]] = std::move(children[i]); + } + return std::move(dict); + break; + } + case PyTreeKind::kCustom: { + nb::object tuple = nb::steal(PyTuple_New(node.arity)); + for (int i = 0; i < node.arity; ++i) { + PyTuple_SET_ITEM(tuple.ptr(), i, children[i].release().ptr()); + } + return node.custom->from_iterable(node.node_data, tuple); + } + + case PyTreeKind::kDataclass: { + nb::kwargs kwargs; + auto meta_size = node.custom->meta_fields.size(); + for (int i = 0; i < meta_size; ++i) { + kwargs[node.custom->meta_fields[i]] = + nb::borrow(nb::tuple(node.node_data)[i]); + } + auto data_size = node.custom->data_fields.size(); + for (int i = 0; i < data_size; ++i) { + kwargs[node.custom->data_fields[i]] = std::move(children[i]); + } + return node.custom->type(**kwargs); + } + } + throw std::logic_error("Unreachable code."); +} + +nb::list PyTreeDef::FlattenUpTo(nb::handle xs) const { + nb::list leaves = nb::steal(PyList_New(num_leaves())); + std::vector agenda; + agenda.push_back(nb::borrow(xs)); + auto it = traversal_.rbegin(); + int leaf = num_leaves() - 1; + while (!agenda.empty()) { + if (it == traversal_.rend()) { + throw std::invalid_argument(absl::StrFormat( + "Tree structures did not match: %s vs %s", + nb::cast(nb::repr(xs)), ToString())); + } + const Node& node = *it; + nb::object object = agenda.back(); + agenda.pop_back(); + ++it; + + switch (node.kind) { + case PyTreeKind::kLeaf: + if (leaf < 0) { + throw std::logic_error("Leaf count mismatch."); + } + PyList_SET_ITEM(leaves.ptr(), leaf, object.release().ptr()); + --leaf; + break; + + case PyTreeKind::kNone: + if (!object.is_none()) { + throw std::invalid_argument(absl::StrFormat( + "Expected None, got %s.\n\n" + "In previous releases of JAX, flatten-up-to used to " + "consider None to be a tree-prefix of non-None values. To obtain " + "the previous behavior, you can usually write:\n" + " jax.tree.map(lambda x, y: None if x is None else f(x, y), a, " + "b, is_leaf=lambda x: x is None)", + nb::cast(nb::repr(object)))); + } + break; + + case PyTreeKind::kTuple: { + if (!PyTuple_CheckExact(object.ptr())) { + throw std::invalid_argument( + absl::StrFormat("Expected tuple, got %s.", + nb::cast(nb::repr(object)))); + } + nb::tuple tuple = nb::borrow(object); + if (tuple.size() != node.arity) { + throw std::invalid_argument(absl::StrFormat( + "Tuple arity mismatch: %d != %d; tuple: %s.", tuple.size(), + node.arity, nb::cast(nb::repr(object)))); + } + for (nb::handle entry : tuple) { + agenda.push_back(nb::borrow(entry)); + } + break; + } + + case PyTreeKind::kList: { + if (!PyList_CheckExact(object.ptr())) { + throw std::invalid_argument( + absl::StrFormat("Expected list, got %s.", + nb::cast(nb::repr(object)))); + } + nb::list list = nb::borrow(object); + if (list.size() != node.arity) { + throw std::invalid_argument(absl::StrFormat( + "List arity mismatch: %d != %d; list: %s.", list.size(), + node.arity, nb::cast(nb::repr(object)))); + } + for (nb::handle entry : list) { + agenda.push_back(nb::borrow(entry)); + } + break; + } + + case PyTreeKind::kDict: { + if (!PyDict_CheckExact(object.ptr())) { + throw std::invalid_argument( + absl::StrFormat("Expected dict, got %s.", + nb::cast(nb::repr(object)))); + } + nb::dict dict = nb::borrow(object); + std::vector keys = GetSortedPyDictKeys(dict.ptr()); + if (!IsSortedPyDictKeysEqual(keys, node.sorted_dict_keys)) { + // Convert to a nb::list for nb::repr to avoid having to stringify a + // vector. This is error path so it is fine to pay conversion cost. + throw std::invalid_argument(absl::StrFormat( + "Dict key mismatch; expected keys: %s; present keys: %s.", + nb::cast( + nb::repr(nb::cast(node.sorted_dict_keys))), + nb::cast(nb::repr(nb::cast(keys))))); + } + for (nb::handle key : keys) { + agenda.push_back(dict[key]); + } + break; + } + + case PyTreeKind::kNamedTuple: { + if (!nb::isinstance(object) || + !nb::hasattr(object, "_fields")) { + throw std::invalid_argument( + absl::StrFormat("Expected named tuple, got %s.", + nb::cast(nb::repr(object)))); + } + nb::tuple tuple = nb::borrow(object); + if (tuple.size() != node.arity) { + throw std::invalid_argument(absl::StrFormat( + "Named tuple arity mismatch: %d != %d; tuple: %s.", tuple.size(), + node.arity, nb::cast(nb::repr(object)))); + } + if (tuple.type().not_equal(node.node_data)) { + throw std::invalid_argument(absl::StrFormat( + "Named tuple type mismatch: expected type: %s, tuple: %s.", + nb::cast(nb::repr(node.node_data)), + nb::cast(nb::repr(object)))); + } + for (nb::handle entry : tuple) { + agenda.push_back(nb::borrow(entry)); + } + break; + } + + case PyTreeKind::kCustom: { + auto* registration = registry_->Lookup(object.type()); + if (registration != node.custom) { + throw std::invalid_argument(absl::StrFormat( + "Custom node type mismatch: expected type: %s, value: %s.", + nb::cast(nb::repr(node.custom->type)), + nb::cast(nb::repr(object)))); + } + auto [leaves, aux_data] = node.custom->ToIterable(object); + if (node.node_data.not_equal(aux_data)) { + throw std::invalid_argument(absl::StrFormat( + "Mismatch custom node data: %s != %s; value: %s.", + nb::cast(nb::repr(node.node_data)), + nb::cast(nb::repr(aux_data)), + nb::cast(nb::repr(object)))); + } + int arity = 0; + for (nb::handle entry : leaves) { + ++arity; + agenda.push_back(nb::borrow(entry)); + } + if (arity != node.arity) { + throw std::invalid_argument(absl::StrFormat( + "Custom type arity mismatch: %d != %d; value: %s.", arity, + node.arity, nb::cast(nb::repr(object)))); + } + break; + } + + case PyTreeKind::kDataclass: { + auto* registration = registry_->Lookup(object.type()); + if (registration != node.custom) { + throw std::invalid_argument(absl::StrFormat( + "Custom dataclass node type mismatch: expected type: %s, value: " + "%s.", + nb::cast(nb::repr(node.custom->type)), + nb::cast(nb::repr(std::move(object))))); + } + auto meta_size = node.custom->meta_fields.size(); + nb::object aux_data = nb::steal(PyTuple_New(meta_size)); + for (int meta_leaf = 0; meta_leaf < meta_size; ++meta_leaf) { + PyTuple_SET_ITEM( + aux_data.ptr(), meta_leaf, + nb::getattr(object, node.custom->meta_fields[meta_leaf]) + .release() + .ptr()); + } + if (node.node_data.not_equal(aux_data)) { + throw std::invalid_argument(absl::StrFormat( + "Mismatch custom dataclass node data: %s != %s; value: %s.", + nb::cast(nb::repr(node.node_data)), + nb::cast(nb::repr(aux_data)), + nb::cast(nb::repr(object)))); + } + auto data_size = node.custom->data_fields.size(); + if (data_size != node.arity) { + throw std::invalid_argument(absl::StrFormat( + "Custom type arity mismatch: %d != %d; value: %s.", data_size, + node.arity, nb::cast(nb::repr(object)))); + } + for (int leaf = 0; leaf < data_size; ++leaf) { + agenda.push_back(nb::borrow( + nb::getattr(object, node.custom->data_fields[leaf]))); + } + break; + } + } + } + if (it != traversal_.rend() || leaf != -1) { + throw std::invalid_argument( + absl::StrFormat("Tree structures did not match: %s vs %s", + nb::cast(nb::repr(xs)), ToString())); + } + return leaves; +} + +nb::object PyTreeDef::Walk(const nb::callable& f_node, nb::handle f_leaf, + nb::iterable leaves) const { + std::vector agenda; + auto it = leaves.begin(); + for (const Node& node : traversal_) { + switch (node.kind) { + case PyTreeKind::kLeaf: { + if (it == leaves.end()) { + throw std::invalid_argument("Too few leaves for PyTreeDef"); + } + + nb::object leaf = nb::borrow(*it); + agenda.push_back(f_leaf.is_none() ? std::move(leaf) + : f_leaf(std::move(leaf))); + ++it; + break; + } + + case PyTreeKind::kNone: + case PyTreeKind::kTuple: + case PyTreeKind::kNamedTuple: + case PyTreeKind::kList: + case PyTreeKind::kDict: + case PyTreeKind::kCustom: + case PyTreeKind::kDataclass: { + if (agenda.size() < node.arity) { + throw std::logic_error("Too few elements for custom type."); + } + nb::object tuple = nb::steal(PyTuple_New(node.arity)); + for (int i = node.arity - 1; i >= 0; --i) { + PyTuple_SET_ITEM(tuple.ptr(), i, agenda.back().release().ptr()); + agenda.pop_back(); + } + nb::object node_data = node.node_data; + if (node.kind == PyTreeKind::kDict) { + // Convert to a nb::list for f_node invocation. + node_data = nb::cast(node.sorted_dict_keys); + } + agenda.push_back(f_node(tuple, node_data ? node_data : nb::none())); + } + } + } + if (it != leaves.end()) { + throw std::invalid_argument("Too many leaves for PyTreeDef"); + } + if (agenda.size() != 1) { + throw std::logic_error("PyTreeDef traversal did not yield a singleton."); + } + return std::move(agenda.back()); +} + +nb::object PyTreeDef::FromIterableTreeHelper( + nb::handle xs, + absl::InlinedVector::const_reverse_iterator* it) const { + if (*it == traversal_.rend()) { + throw std::invalid_argument("Tree structures did not match."); + } + const Node& node = **it; + ++*it; + if (node.kind == PyTreeKind::kLeaf) { + return nb::borrow(xs); + } + nb::iterable iterable = nb::borrow(xs); + std::vector ys; + ys.reserve(node.arity); + for (nb::handle x : iterable) { + ys.push_back(nb::borrow(x)); + } + if (ys.size() != node.arity) { + throw std::invalid_argument("Arity mismatch between trees"); + } + for (int j = node.arity - 1; j >= 0; --j) { + ys[j] = FromIterableTreeHelper(ys[j], it); + } + + return MakeNode(node, absl::MakeSpan(ys)); +} + +nb::object PyTreeDef::FromIterableTree(nb::handle xs) const { + auto it = traversal_.rbegin(); + nb::object out = FromIterableTreeHelper(xs, &it); + if (it != traversal_.rend()) { + throw std::invalid_argument("Tree structures did not match."); + } + return out; +} + +nb_class_ptr PyTreeDef::Compose(const PyTreeDef& inner) const { + if (inner.registry_ != registry_) { + throw std::invalid_argument( + "PyTree registries of PyTreeDefs passed to Compose() must match."); + } + auto out = make_nb_class(registry_ref_); + out->traversal_.reserve(static_cast(num_leaves()) * + inner.num_nodes() + + num_nodes() - num_leaves()); + for (const Node& n : traversal_) { + if (n.kind == PyTreeKind::kLeaf) { + absl::c_copy(inner.traversal_, std::back_inserter(out->traversal_)); + } else { + out->traversal_.push_back(n); + } + } + out->SetNumLeavesAndNumNodes(); + return out; +} + +/*static*/ nb_class_ptr PyTreeDef::Tuple( + nb_class_ptr registry, nb::list defs) { + auto out = make_nb_class(std::move(registry)); + int num_leaves = 0; + for (nb::handle def_handle : defs) { + const PyTreeDef* def = nb::cast(def_handle); + if (def->registry() != out->registry()) { + throw std::invalid_argument( + "PyTree registries of PyTreeDefs passed to Tuple() must match."); + } + absl::c_copy(def->traversal_, std::back_inserter(out->traversal_)); + num_leaves += def->num_leaves(); + } + Node node; + node.kind = PyTreeKind::kTuple; + node.arity = defs.size(); + node.num_leaves = num_leaves; + node.num_nodes = out->traversal_.size() + 1; + out->traversal_.push_back(node); + return out; +} + +std::vector> PyTreeDef::Children() const { + std::vector> children; + if (traversal_.empty()) { + return children; + } + Node const& root = traversal_.back(); + children.resize(root.arity); + int pos = traversal_.size() - 1; + for (int i = root.arity - 1; i >= 0; --i) { + children[i] = make_nb_class(registry_ref_); + const Node& node = traversal_.at(pos - 1); + if (pos < node.num_nodes) { + throw std::logic_error("children() walked off start of array"); + } + std::copy(traversal_.begin() + pos - node.num_nodes, + traversal_.begin() + pos, + std::back_inserter(children[i]->traversal_)); + pos -= node.num_nodes; + } + if (pos != 0) { + throw std::logic_error("pos != 0 at end of PyTreeDef::Children"); + } + return children; +} + +std::string PyTreeDef::ToString() const { + std::vector agenda; + for (const Node& node : traversal_) { + if (agenda.size() < node.arity) { + throw std::logic_error("Too few elements for container."); + } + + std::string children = + absl::StrJoin(agenda.end() - node.arity, agenda.end(), ", "); + std::string representation; + switch (node.kind) { + case PyTreeKind::kLeaf: + agenda.push_back("*"); + continue; + case PyTreeKind::kNone: + representation = "None"; + break; + case PyTreeKind::kTuple: + // Tuples with only one element must have a trailing comma. + if (node.arity == 1) children += ","; + representation = absl::StrCat("(", children, ")"); + break; + case PyTreeKind::kList: + representation = absl::StrCat("[", children, "]"); + break; + case PyTreeKind::kDict: { + if (node.sorted_dict_keys.size() != node.arity) { + throw std::logic_error("Number of keys and entries does not match."); + } + representation = "{"; + std::string separator; + auto child_iter = agenda.end() - node.arity; + for (const nb::handle& key : node.sorted_dict_keys) { + absl::StrAppendFormat(&representation, "%s%s: %s", separator, + nb::cast(nb::repr(key)), + *child_iter); + child_iter++; + separator = ", "; + } + representation += "}"; + break; + } + + case PyTreeKind::kNamedTuple: + case PyTreeKind::kCustom: + case PyTreeKind::kDataclass: { + std::string kind; + std::string data; + if (node.kind == PyTreeKind::kNamedTuple) { + kind = "namedtuple"; + if (node.node_data) { + // Node data for named tuples is the type. + data = absl::StrFormat( + "[%s]", nb::cast( + nb::str(nb::getattr(node.node_data, "__name__")))); + } + } else { + kind = nb::cast( + nb::str(nb::getattr(node.custom->type, "__name__"))); + if (node.node_data) { + data = absl::StrFormat( + "[%s]", nb::cast(nb::str(node.node_data))); + } + } + + representation = + absl::StrFormat("CustomNode(%s%s, [%s])", kind, data, children); + break; + } + } + agenda.erase(agenda.end() - node.arity, agenda.end()); + agenda.push_back(std::move(representation)); + } + if (agenda.size() != 1) { + throw std::logic_error("PyTreeDef traversal did not yield a singleton."); + } + return absl::StrCat("PyTreeDef(", agenda.back(), ")"); +} + +nb::object PyTreeDef::ToPickle() const { + nb::list traversal; + for (const auto& node : traversal_) { + nb::object node_data = node.node_data; + if (node.kind == PyTreeKind::kDict) { + // Convert to a nb::list for pickling to avoid having to pickle a vector. + // Pickle should be a rare operation so this conversion cost is hopefully + // on non-critical path. + node_data = nb::cast(node.sorted_dict_keys); + } + traversal.append( + nb::make_tuple(static_cast(node.kind), node.arity, + node_data ? node_data : nb::none(), + node.custom != nullptr ? node.custom->type : nb::none(), + node.num_leaves, node.num_nodes)); + } + return nb::make_tuple(nb::cast(registry_ref_), traversal); +} + +void PyTreeDef::FromPickle(nb::object pickle) { + for (const auto& item : nb::cast(pickle)) { + auto t = nb::cast(item); + if (t.size() != 6) { + throw xla::XlaRuntimeError("Malformed pickled PyTreeDef"); + } + Node& node = traversal_.emplace_back(); + node.kind = static_cast(nb::cast(t[0])); + node.arity = nb::cast(t[1]); + switch (node.kind) { + case PyTreeKind::kNamedTuple: + node.node_data = t[2]; + break; + case PyTreeKind::kDict: + node.sorted_dict_keys = nb::cast>(t[2]); + break; + case PyTreeKind::kCustom: + case PyTreeKind::kDataclass: + node.node_data = t[2]; + break; + default: + if (!t[2].is_none()) { + throw xla::XlaRuntimeError("Malformed pickled PyTreeDef"); + } + break; + } + if (node.kind == PyTreeKind::kCustom || + node.kind == PyTreeKind::kDataclass) { + node.custom = t[3].is_none() ? nullptr : registry()->Lookup(t[3]); + if (node.custom == nullptr) { + throw xla::XlaRuntimeError( + absl::StrCat("Unknown custom type in pickled PyTreeDef: ", + nb::cast(nb::repr(t[3])))); + } + } else { + if (!t[3].is_none()) { + throw xla::XlaRuntimeError("Malformed pickled PyTreeDef"); + } + } + node.num_leaves = nb::cast(t[4]); + node.num_nodes = nb::cast(t[5]); + } +} + +void PyTreeDef::SetNumLeavesAndNumNodes() { + // num_leaves and num_nodes are fully determined by arity. + std::vector> starts; + int num_leaves = 0; + for (int i = 0; i < traversal_.size(); ++i) { + std::pair start = {num_leaves, i}; + if (traversal_[i].kind == PyTreeKind::kLeaf) { + num_leaves += 1; + } + if (traversal_[i].arity == 0) { + starts.push_back(start); + } else { + starts.resize(starts.size() - (traversal_[i].arity - 1)); + } + traversal_[i].num_leaves = num_leaves - starts.back().first; + traversal_[i].num_nodes = i + 1 - starts.back().second; + } +} + +void PyTreeDef::SerializeTo(PyTreeDefProto& result) const { + absl::flat_hash_map interned_strings; + auto intern_str = [&](const std::string& key) { + auto [it, added] = + interned_strings.emplace(key, result.interned_strings_size()); + if (added) { + result.add_interned_strings(key); + } + return it->second; + }; + for (const auto& node : traversal_) { + auto* node_data = result.add_nodes(); + node_data->set_arity(node.arity); + switch (node.kind) { + case PyTreeKind::kLeaf: + node_data->set_type(PyTreeNodeType::PY_TREE_KIND_LEAF); + break; + case PyTreeKind::kList: + node_data->set_type(PyTreeNodeType::PY_TREE_KIND_LIST); + break; + case PyTreeKind::kNone: + node_data->set_type(PyTreeNodeType::PY_TREE_KIND_NONE); + break; + case PyTreeKind::kTuple: + node_data->set_type(PyTreeNodeType::PY_TREE_KIND_TUPLE); + break; + case PyTreeKind::kDict: + node_data->set_type(PyTreeNodeType::PY_TREE_KIND_DICT); + for (auto& key : node.sorted_dict_keys) { + if (!nb::isinstance(key)) { + throw std::invalid_argument( + "Only string keys are supported in proto pytree " + "serialization."); + } + node_data->mutable_dict_keys()->add_str_id( + intern_str(nb::cast(key))); + } + break; + default: + throw std::invalid_argument( + "User-defined nodes are not supported when serializing pytrees as " + "protocol buffers. You should either convert the user-defined " + "nodes to another type or use pickle instead."); + break; + } + } +} + +nb_class_ptr PyTreeDef::DeserializeFrom( + nb_class_ptr registry, const PyTreeDefProto& input) { + std::vector interned_strings; + interned_strings.reserve(input.interned_strings().size()); + for (auto& s : input.interned_strings()) { + interned_strings.push_back(nb::cast(s)); + } + nb_class_ptr result = + make_nb_class(std::move(registry)); + for (auto& node_proto : input.nodes()) { + result->traversal_.emplace_back(); + auto& node = result->traversal_.back(); + node.arity = node_proto.arity(); + node.custom = nullptr; + switch (node_proto.type()) { + case PyTreeNodeType::PY_TREE_KIND_LEAF: + node.kind = PyTreeKind::kLeaf; + break; + case PyTreeNodeType::PY_TREE_KIND_LIST: + node.kind = PyTreeKind::kList; + break; + case PyTreeNodeType::PY_TREE_KIND_NONE: + node.kind = PyTreeKind::kNone; + break; + case PyTreeNodeType::PY_TREE_KIND_TUPLE: + node.kind = PyTreeKind::kTuple; + break; + case PyTreeNodeType::PY_TREE_KIND_DICT: + node.kind = PyTreeKind::kDict; + for (uint32_t str_id : node_proto.dict_keys().str_id()) { + if (str_id >= interned_strings.size()) { + throw std::invalid_argument( + "Malformed pytree proto (dict_key out of range)."); + } + node.sorted_dict_keys.push_back(interned_strings.at(str_id)); + } + break; + default: + throw std::invalid_argument( + "Malformed pytree proto (invalid node type)"); + break; + } + } + result->SetNumLeavesAndNumNodes(); + return result; +} + +std::optional> PyTreeDef::GetNodeData() + const { + if (traversal_.empty()) { + throw std::logic_error("empty PyTreeDef traversal."); + } + auto builtin_type = [](PyTypeObject* type_obj) { + return nb::borrow(reinterpret_cast(type_obj)); + }; + const auto& node = traversal_.back(); + switch (node.kind) { + case PyTreeKind::kLeaf: + return std::nullopt; + case PyTreeKind::kNone: + return std::make_pair(builtin_type(Py_TYPE(Py_None)), nb::none()); + case PyTreeKind::kTuple: + return std::make_pair(builtin_type(&PyTuple_Type), nb::none()); + case PyTreeKind::kList: + return std::make_pair(builtin_type(&PyList_Type), nb::none()); + case PyTreeKind::kDict: + return std::make_pair(builtin_type(&PyDict_Type), + nb::cast(node.sorted_dict_keys)); + case PyTreeKind::kNamedTuple: + return std::make_pair(node.node_data, nb::none()); + case PyTreeKind::kCustom: + case PyTreeKind::kDataclass: + return std::make_pair(node.custom->type, node.node_data); + } +} + +nb_class_ptr PyTreeDef::FromNodeDataAndChildren( + nb_class_ptr registry, + std::optional> node_data, + nb::iterable children) { + nb_class_ptr result = + make_nb_class(std::move(registry)); + int num_leaves = 0; + int arity = 0; + for (nb::handle pchild : children) { + const PyTreeDef& child = nb::cast(pchild); + absl::c_copy(child.traversal_, std::back_inserter(result->traversal_)); + num_leaves += child.num_leaves(); + ++arity; + } + result->traversal_.emplace_back(); + auto& node = result->traversal_.back(); + node.arity = arity; + node.custom = nullptr; + node.num_leaves = num_leaves; + node.num_nodes = result->traversal_.size(); + if (node_data == std::nullopt) { + node.kind = PyTreeKind::kLeaf; + ++node.num_leaves; + return result; + } + int is_nt = PyObject_IsSubclass(node_data->first.ptr(), + reinterpret_cast(&PyTuple_Type)); + if (is_nt == -1) { + throw nb::python_error(); + } + if (is_nt != 0 && nb::hasattr(node_data->first, "_fields")) { + node.kind = PyTreeKind::kNamedTuple; + node.node_data = node_data->first; + return result; + } + auto* registration = result->registry()->Lookup(node_data->first); + if (registration == nullptr) { + throw std::logic_error(absl::StrFormat( + "Could not find type: %s.", + nb::cast(nb::repr(node_data->first)))); + } + node.kind = registration->kind; + if (node.kind == PyTreeKind::kCustom || node.kind == PyTreeKind::kDataclass) { + node.custom = registration; + node.node_data = node_data->second; + } else if (node.kind == PyTreeKind::kNamedTuple) { + node.node_data = node_data->first; + } else if (node.kind == PyTreeKind::kDict) { + node.sorted_dict_keys = + nb::cast>(node_data->second); + } + return result; +} + +int PyTreeDef::Node::tp_traverse(visitproc visit, void* arg) const { + Py_VISIT(node_data.ptr()); + for (const auto& key : sorted_dict_keys) { + Py_VISIT(key.ptr()); + } + return 0; +} + +/* static */ int PyTreeDef::tp_traverse(PyObject* self, visitproc visit, + void* arg) { + Py_VISIT(Py_TYPE(self)); + if (!nb::inst_ready(self)) { + return 0; + } + PyTreeDef* treedef = nb::inst_ptr(self); + Py_VISIT(treedef->registry_ref_.ptr()); + for (const auto& node : treedef->traversal_) { + node.tp_traverse(visit, arg); + } + return 0; +} + +/* static */ int PyTreeDef::tp_clear(PyObject* self) { + PyTreeDef* treedef = nb::inst_ptr(self); + treedef->registry_ref_.reset(); + treedef->traversal_.clear(); + return 0; +} + +/* static */ PyType_Slot PyTreeDef::slots_[] = { + {Py_tp_traverse, (void*)PyTreeDef::tp_traverse}, + {Py_tp_clear, (void*)PyTreeDef::tp_clear}, + {0, nullptr}, +}; + +void BuildPytreeSubmodule(nb::module_& m) { + auto iterable_type = nb::typing().attr("Iterable"); + auto tuple_type = nb::typing().attr("Tuple"); + nb::module_ pytree = m.def_submodule("pytree", "Python tree library"); + pytree.attr("version") = nb::int_(3); + pytree.attr("_T") = nb::type_var("_T"); + + pytree.attr("_Children") = nb::type_var( + "_Children", nb::arg("bound") = iterable_type[nb::any_type()]); + nb::object key_leaf_pair_type = + tuple_type[nb::make_tuple(nb::any_type(), nb::any_type())]; + pytree.attr("_KeyLeafPair") = + nb::type_var("_KeyLeafPair", nb::arg("bound") = key_leaf_pair_type); + nb::object key_leaf_pairs_type = + iterable_type[tuple_type[nb::make_tuple(nb::any_type(), nb::any_type())]]; + pytree.attr("_KeyLeafPairs") = + nb::type_var("_KeyLeafPairs", nb::arg("bound") = key_leaf_pairs_type); + nb::object key_path_type = + tuple_type[nb::make_tuple(nb::any_type(), nb::ellipsis())]; + pytree.attr("_KeyPath") = + nb::type_var("_KeyPath", nb::arg("bound") = key_path_type); + nb::object aux_data_type = nb::typing().attr("Hashable"); + pytree.attr("_AuxData") = + nb::type_var("_AuxData", nb::arg("bound") = aux_data_type); + + nb::class_ registry(pytree, "PyTreeRegistry", + nb::dynamic_attr(), + nb::type_slots(PyTreeRegistry::slots_)); + + registry.def(nb::init(), + nb::arg("enable_none") = true, nb::arg("enable_tuple") = true, + nb::arg("enable_namedtuple") = true, + nb::arg("enable_list") = true, nb::arg("enable_dict") = true); + registry.def( + "flatten", + [](nb_class_ptr registry, nb::object x, + std::optional leaf_predicate) { + nb::list leaves; + nb_class_ptr def = + make_nb_class(std::move(registry)); + def->Flatten(x, leaves, leaf_predicate); + return nb::make_tuple(std::move(leaves), std::move(def)); + }, + nb::arg("tree").none(), nb::arg("leaf_predicate").none() = std::nullopt, + nb::sig( + // clang-format off + "def flatten(" + "self, " + "tree: object | None, " + "leaf_predicate: Callable[[Any], bool] | None = None" + ") -> tuple[list[Any], PyTreeDef]" + // clang-format on + )); + registry.def("flatten_one_level", &PyTreeRegistry::FlattenOneLevel, + nb::arg("tree").none(), + // clang-format off + nb::sig("def flatten_one_level(" + "self, " + "tree: object | None" + ") -> tuple[Iterable[Any], Any] | None") + // clang-format on + ); + registry.def("flatten_one_level_with_keys", + &PyTreeRegistry::FlattenOneLevelWithKeys, nb::arg("tree").none(), + nb::sig( + // clang-format off + "def flatten_one_level_with_keys(" + "self, " + "tree: object | None" + ") -> tuple[Iterable[_KeyLeafPair], Any] | None" + // clang-format on + )); + registry.def( + "flatten_with_path", + [](nb_class_ptr registry, nb::object x, + std::optional leaf_predicate) { + nb::list leaves; + nb_class_ptr def = + make_nb_class(std::move(registry)); + def->FlattenWithPath(x, leaves, leaf_predicate); + return nb::make_tuple(std::move(leaves), std::move(def)); + }, + nb::arg("tree").none(), nb::arg("leaf_predicate").none() = std::nullopt, + nb::sig( + // clang-format off + "def flatten_with_path(" + "self, " + "tree: object | None, " + "leaf_predicate: typing.Callable[[Any, Any], bool] | None = None" + ") -> tuple[list[tuple[_KeyPath, Any]], PyTreeDef]" + // clang-format on + )); + registry.def("register_node", &PyTreeRegistry::Register, + nb::arg("type").none(), nb::arg("to_iterable").none(), + nb::arg("from_iterable").none(), + nb::arg("to_iterable_with_keys").none() = std::nullopt, + nb::sig( + // clang-format off + "def register_node(" + "self, " + "type: type[_T], " + "to_iterable: Callable[[_T], tuple[_Children, _AuxData]], " + "from_iterable: Callable[[_AuxData, _Children], _T], " + "to_iterable_with_keys: Callable[[_T], tuple[_KeyLeafPairs, _AuxData]] | None = None" + ") -> Any" + // clang-format on + )); + registry.def("register_dataclass_node", &PyTreeRegistry::RegisterDataclass, + nb::arg("type").none(), nb::arg("data_fields").none(), + nb::arg("meta_fields").none(), + nb::sig( + // clang-format off + "def register_dataclass_node(" + "self, " + "type: type, " + "data_fields: typing.Sequence[str], " + "meta_fields: Sequence[str], /" + ") -> Any" + // clang-format on + )); + registry.def("__reduce__", [](nb::object self) { + return nb::cast(self.attr("__name__")); + }); + + pytree.attr("_default_registry") = make_nb_class( + /*enable_none=*/true, /*enable_tuple=*/true, /*enable_namedtuple=*/true, + /*enable_list=*/true, /*enable_dict*/ true); + pytree.def("default_registry", + [registry = nb::cast>( + pytree.attr("_default_registry"))]() { return registry; }); + + pytree.def("treedef_tuple", &PyTreeDef::Tuple, + nb::sig( + // clang-format off + "def treedef_tuple(" + "registry: PyTreeRegistry, " + "arg0: Sequence[PyTreeDef], /" + ") -> PyTreeDef" + // clang-format on + )); + pytree.def("all_leaves", &PyTreeDef::AllLeaves); + + nb::class_ treedef(pytree, "PyTreeDef", + nb::type_slots(PyTreeDef::slots_)); + m.attr("PyTreeDef") = treedef; // For backwards compatibility. + treedef.def("unflatten", + static_cast( + &PyTreeDef::Unflatten), + nb::sig("def unflatten(self, arg: Iterable[Any], /) -> Any")); + treedef.def("flatten_up_to", &PyTreeDef::FlattenUpTo, nb::arg("tree").none()); + treedef.def("compose", &PyTreeDef::Compose); + treedef.def( + "walk", &PyTreeDef::Walk, + "Walk pytree, calling f_node(node, node_data) at nodes, and f_leaf " + "at leaves", + nb::arg("f_node"), nb::arg("f_leaf"), nb::arg("leaves"), + nb::sig( + // clang-format off + "def walk(" + "self, " + "__f_node: Callable[[Any, Any], Any], " + "__f_leaf: Callable[[_T], Any] | None, " + "leaves: Iterable[Any], /" + ") -> Any" + // clang-format on + )); + treedef.def("from_iterable_tree", &PyTreeDef::FromIterableTree); + treedef.def("children", &PyTreeDef::Children); + treedef.def_prop_ro("num_leaves", &PyTreeDef::num_leaves); + treedef.def_prop_ro("num_nodes", &PyTreeDef::num_nodes); + treedef.def("__repr__", &PyTreeDef::ToString); + treedef.def("__eq__", [](const PyTreeDef& a, nb::object b) { + return nb::isinstance(b) && a == nb::cast(b); + }); + treedef.def("__ne__", [](const PyTreeDef& a, nb::object b) { + return nb::isinstance(b) && a != nb::cast(b); + }); + treedef.def("__hash__", [](const PyTreeDef& t) { return absl::HashOf(t); }); + treedef.def("serialize_using_proto", [](const PyTreeDef& a) { + PyTreeDefProto result; + a.SerializeTo(result); + std::string serialized = result.SerializeAsString(); + return nb::bytes(serialized.data(), serialized.size()); + }); + treedef.def_static( + "deserialize_using_proto", + [](nb_class_ptr registry, nb::bytes data) { + PyTreeDefProto input; + std::string_view serialized(data.c_str(), data.size()); + if (serialized.size() > std::numeric_limits::max()) { + throw xla::XlaRuntimeError( + "Pytree serialization too large to deserialize."); + } + if (!input.ParseFromArray(serialized.data(), serialized.size())) { + throw xla::XlaRuntimeError("Could not deserialize PyTreeDefProto."); + } + return PyTreeDef::DeserializeFrom(std::move(registry), input); + }, + nb::arg("registry"), nb::arg("data")); + treedef.def("node_data", &PyTreeDef::GetNodeData, + "Returns None if a leaf-pytree, else (type, node_data)", + nb::sig("def node_data(self) -> tuple[type, Any] | None")); + treedef.def_static( + "from_node_data_and_children", + &PyTreeDef::FromNodeDataAndChildren, nb::arg("registry"), + nb::arg("node_data").none(), nb::arg("children"), + "Reconstructs a pytree from `node_data()` and `children()`.", + nb::sig( + // clang-format off + "def from_node_data_and_children(" + "self, " + "registry: PyTreeRegistry, " + "node_data: tuple[type, Any] | None, " + "children: typing.Iterable[PyTreeDef]" + ") -> PyTreeDef" + // clang-format on + )); + treedef.def("__getstate__", &PyTreeDef::ToPickle); + treedef.def("__setstate__", [](PyTreeDef& t, nb::object o) { + nb::tuple pickle = nb::cast(o); + if (pickle.size() != 2) { + throw xla::XlaRuntimeError( + "Malformed pickled PyTreeDef, expected 2-tuple"); + } + auto registry = nb::cast>(pickle[0]); + new (&t) PyTreeDef(registry); + t.FromPickle(pickle[1]); + }); + + nb::class_ sequence_key(pytree, "SequenceKey", + nb::sig("class SequenceKey(Hashable)")); + sequence_key.def(nb::init(), nb::arg("idx")); + sequence_key.def("__str__", &SequenceKey::ToString); + sequence_key.def("__repr__", &SequenceKey::ToReprString); + sequence_key.def("__eq__", &SequenceKey::Equals); + sequence_key.def("__hash__", [](const SequenceKey& key) { + return key.idx() + kSequenceKeyHashSalt; + }); + sequence_key.def_prop_ro("idx", &SequenceKey::idx); + sequence_key.def_prop_ro_static("__match_args__", &SequenceKey::MatchArgs); + sequence_key.def("__getstate__", + [](SequenceKey& key) { return nb::make_tuple(key.idx()); }); + sequence_key.def("__setstate__", + [](SequenceKey& key, const nb::tuple& state) { + if (state.size() != 1) { + throw xla::XlaRuntimeError( + "Malformed pickled SequenceKey, expected 1-tuple"); + } + new (&key) SequenceKey(nb::cast(state[0])); + }); + + nb::class_ dict_key(pytree, "DictKey", + nb::type_slots(DictKey::slots_), + nb::sig("class DictKey(Hashable)")); + dict_key.def(nb::init(), nb::arg("key")); + dict_key.def("__str__", &DictKey::ToString); + dict_key.def("__repr__", &DictKey::ToReprString); + dict_key.def("__eq__", &DictKey::Equals); + dict_key.def("__hash__", + [](const DictKey& key) { return nanobind::hash(key.key()); }); + dict_key.def_prop_ro("key", &DictKey::key); + dict_key.def_prop_ro_static("__match_args__", &DictKey::MatchArgs); + dict_key.def("__getstate__", + [](DictKey& key) { return nb::make_tuple(key.key()); }); + dict_key.def("__setstate__", [](DictKey& key, const nb::tuple& state) { + if (state.size() != 1) { + throw xla::XlaRuntimeError("Malformed pickled DictKey, expected 1-tuple"); + } + new (&key) DictKey(nb::cast(state[0])); + }); + + nb::class_ get_attr_key(pytree, "GetAttrKey", + nb::sig("class GetAttrKey(Hashable)")); + get_attr_key.def(nb::init(), nb::arg("name")); + get_attr_key.def("__str__", &GetAttrKey::ToString); + get_attr_key.def("__repr__", &GetAttrKey::ToReprString); + get_attr_key.def("__eq__", &GetAttrKey::Equals); + get_attr_key.def("__hash__", + [](const GetAttrKey& key) { return nb::hash(key.name()); }); + get_attr_key.def_prop_ro("name", &GetAttrKey::name); + get_attr_key.def_prop_ro_static("__match_args__", &GetAttrKey::MatchArgs); + get_attr_key.def("__getstate__", + [](GetAttrKey& key) { return nb::make_tuple(key.name()); }); + get_attr_key.def("__setstate__", [](GetAttrKey& key, const nb::tuple& state) { + if (state.size() != 1) { + throw xla::XlaRuntimeError( + "Malformed pickled GetAttrKey, expected 1-tuple"); + } + new (&key) GetAttrKey(nb::str(state[0])); + }); + + nb::class_ flattened_index_key( + pytree, "FlattenedIndexKey", + nb::sig("class FlattenedIndexKey(Hashable)")); + flattened_index_key.def(nb::init(), nb::arg("key")); + flattened_index_key.def("__str__", &FlattenedIndexKey::ToString); + flattened_index_key.def("__repr__", &FlattenedIndexKey::ToReprString); + flattened_index_key.def("__eq__", &FlattenedIndexKey::Equals); + flattened_index_key.def("__hash__", [](const FlattenedIndexKey& key) { + return key.key() + kFlattenedIndexKeyHashSalt; + }); + flattened_index_key.def_prop_ro("key", &FlattenedIndexKey::key); + flattened_index_key.def_prop_ro_static("__match_args__", + &FlattenedIndexKey::MatchArgs); + flattened_index_key.def("__getstate__", [](FlattenedIndexKey& key) { + return nb::make_tuple(key.key()); + }); + flattened_index_key.def( + "__setstate__", [](FlattenedIndexKey& key, const nb::tuple& state) { + if (state.size() != 1) { + throw xla::XlaRuntimeError( + "Malformed pickled FlattenedIndexKey, expected 1-tuple"); + } + new (&key) FlattenedIndexKey(nb::cast(state[0])); + }); +} + +} // namespace jax diff --git a/jaxlib/pytree.h b/jaxlib/pytree.h new file mode 100644 index 000000000000..7787a26acf7b --- /dev/null +++ b/jaxlib/pytree.h @@ -0,0 +1,408 @@ +/* Copyright 2019 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_PYTREE_H_ +#define JAXLIB_PYTREE_H_ + +// See https://docs.jax.dev/en/latest/pytrees.html for the documentation +// about pytree. + +#include + +#include +#include +#include +#include +#include +#include + +// placeholder for index annotation headers +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/pytree.pb.h" + +namespace jax { + +enum class PyTreeKind { + kLeaf, // An opaque leaf node + kNone, // None. + kTuple, // A tuple + kNamedTuple, // A collections.namedtuple + kList, // A list + kDict, // A dict + kCustom, // A custom type. + kDataclass, // A dataclass. +}; + +// Registry of custom node types. +class PyTreeRegistry { + public: + PyTreeRegistry(bool enable_none, bool enable_tuple, bool enable_namedtuple, + bool enable_list, bool enable_dict); + + PyTreeRegistry(const PyTreeRegistry&) = delete; + PyTreeRegistry(PyTreeRegistry&&) = delete; + PyTreeRegistry& operator=(const PyTreeRegistry&) = delete; + PyTreeRegistry& operator=(PyTreeRegistry&&) = delete; + + struct Registration { + PyTreeKind kind; + + // The following values are populated for custom types. + // The Python type object, used to identify the type. + nanobind::object type; + // A function with signature: object -> (iterable, aux_data) + nanobind::callable to_iterable; + // A function with signature: (aux_data, iterable) -> object + nanobind::callable from_iterable; + // A function with signature: (aux_data, iterable(keypath, leaf)) -> object + std::optional to_iterable_with_keys; + + // Helper that calls to_iterable and validates that it returns a pair + // of an iterable and an aux_data object + std::pair ToIterable( + nanobind::handle o) const; + // Helper that calls to_iterable_with_keys and validates that it returns a + // pair of an iterable of key-leaf pairs and an aux_data object. If + // to_iterable_with_keys is not available, return a dummy key for each leaf, + // similar to the current jax.tree_util.FlattenedIndexKey. + std::pair>, + nanobind::object> + ToIterableWithKeys(nanobind::handle o) const; + + // For dataclasses. + std::vector data_fields; + std::vector meta_fields; + + int tp_traverse(visitproc visit, void* arg); + }; + + // Registers a new custom type. Objects of `type` will be treated as container + // node types in PyTrees. + void Register( + nanobind::object type, nanobind::callable to_iterable, + nanobind::callable from_iterable, + std::optional to_iterable_with_keys = std::nullopt); + // Same, but for dataclasses. + void RegisterDataclass(nanobind::object type, + std::vector data_fields, + std::vector meta_fields); + + // Finds the custom type registration for `type`. Returns nullptr if none + // exists. + const Registration* Lookup(nanobind::handle type) const; + + PyTreeKind KindOfObject(nanobind::handle obj, + PyTreeRegistry::Registration const** custom) const; + + // Flattens a pytree one level, returning either a tuple of the leaves and + // the node data, or None, if the entry is a leaf. + nanobind::object FlattenOneLevel(nanobind::handle x) const; + // Similar to above but returns a key-leaf pair for each leaf. + nanobind::object FlattenOneLevelWithKeys(nanobind::handle x) const; + // Underlying implementation of FlattenOneLevel and FlattenOneLevelWithKeys. + nanobind::object FlattenOneLevelImpl(nanobind::handle x, + bool with_keys) const; + + static PyType_Slot slots_[]; + + private: + struct TypeHash { + using is_transparent = void; + size_t operator()(const nanobind::object& t) const { + return absl::HashOf(t.ptr()); + } + size_t operator()(const nanobind::handle& t) const { + return absl::HashOf(t.ptr()); + } + }; + struct TypeEq { + using is_transparent = void; + bool operator()(const nanobind::object& a, + const nanobind::object& b) const { + return a.ptr() == b.ptr(); + } + bool operator()(const nanobind::object& a, + const nanobind::handle& b) const { + return a.ptr() == b.ptr(); + } + }; + mutable nanobind::ft_mutex mu_; + absl::flat_hash_map, TypeHash, + TypeEq> + registrations_; // Guarded by mu_ + bool enable_namedtuple_; + + static int tp_traverse(PyObject* self, visitproc visit, void* arg); + static int tp_clear(PyObject* self); +}; + +class SequenceKey { + public: + explicit SequenceKey(int idx) : idx_(idx) {}; + std::string ToReprString() const; + std::string ToString() const; + bool Equals(const nanobind::object& other); + int idx() const { return idx_; } + static nanobind::tuple MatchArgs(nanobind::handle unused); + + private: + int idx_; +}; + +class DictKey { + public: + explicit DictKey(nanobind::object key) : key_(key) {}; + std::string ToReprString() const; + std::string ToString() const; + bool Equals(const nanobind::object& other); + nanobind::object key() const { return key_; } + static nanobind::tuple MatchArgs(nanobind::handle unused); + static PyType_Slot slots_[]; + + private: + nanobind::object key_; + static int tp_traverse(PyObject* self, visitproc visit, void* arg); + static int tp_clear(PyObject* self); +}; + +class GetAttrKey { + public: + explicit GetAttrKey(nanobind::str name) : name_(name) {}; + std::string ToReprString() const; + std::string ToString() const; + bool Equals(const nanobind::object& other); + nanobind::str name() const { return name_; } + static nanobind::tuple MatchArgs(nanobind::handle unused); + + private: + nanobind::str name_; +}; + +class FlattenedIndexKey { + public: + explicit FlattenedIndexKey(int key) : key_(key) {}; + std::string ToReprString() const; + std::string ToString() const; + bool Equals(const nanobind::object& other); + int key() const { return key_; } + static nanobind::tuple MatchArgs(nanobind::handle unused); + + private: + int key_; +}; + +// A PyTreeDef describes the tree structure of a PyTree. A PyTree is a tree of +// Python values, where the interior nodes are tuples, lists, dictionaries, or +// user-defined containers, and the leaves are other objects. +class PyTreeDef { + public: + // Unowned registry: the registry must remain live at least as long as the + // PyTreeDef. It is the caller's responsibility to enforce this. + explicit PyTreeDef(PyTreeRegistry* registry) : registry_(registry) {} + + explicit PyTreeDef(nb_class_ptr registry) + : registry_(registry.get()), registry_ref_(std::move(registry)) {} + + // Flattens a Pytree into a list of leaves and a PyTreeDef. + // Returns references to the flattened objects, which might be temporary + // objects in the case of custom pytype handlers. + static std::pair, nb_class_ptr> + Flatten(nanobind::handle x, nb_class_ptr registry, + std::optional leaf_predicate = std::nullopt); + + // Flattens a Pytree into a list of `leaves` and a PyTreeDef (this). + // `leaves` owns references to the flattened objects, which might be + // temporary objects in the case of custom pytype handlers. + void Flatten(nanobind::handle handle, std::vector& leaves, + std::optional leaf_predicate = std::nullopt); + void Flatten(nanobind::handle handle, + absl::InlinedVector& leaves, + std::optional leaf_predicate = std::nullopt); + void Flatten(nanobind::handle handle, nanobind::list& leaves, + std::optional leaf_predicate = std::nullopt); + + void FlattenWithPath( + nanobind::handle handle, nanobind::list& leaves, + std::optional leaf_predicate = std::nullopt); + + // Tests whether the given list is a flat list of leaves. + static bool AllLeaves(PyTreeRegistry* registry, const nanobind::iterable& x); + + // Flattens a Pytree up to this PyTreeDef. 'this' must be a tree prefix of + // the tree-structure of 'x'. For example, if we flatten a value + // [(1, (2, 3)), {"foo": 4}] with a treedef [(*, *), *], the result is the + // list of leaves [1, (2, 3), {"foo": 4}]. + nanobind::list FlattenUpTo(nanobind::handle x) const; + + // Returns an unflattened PyTree given an iterable of leaves and a PyTreeDef. + nanobind::object Unflatten(nanobind::iterable leaves) const; + nanobind::object Unflatten(absl::Span leaves) const; + + // Composes two PyTreeDefs, replacing the leaves of this tree with copies of + // `inner`. The returned PyTreeDef holds a reference to its registry. + nb_class_ptr Compose(const PyTreeDef& inner) const; + + // Makes a Tuple PyTreeDef out of a vector of PyTreeDefs. + static nb_class_ptr Tuple(nb_class_ptr registry, + nanobind::list defs); + + // The returned PyTreeDefs hold a reference to the registry. + std::vector> Children() const; + + // Maps a function over a PyTree structure, applying f_leaf to each leaf, and + // f_node(node, node_data) to each container node. + nanobind::object Walk(const nanobind::callable& f_node, + nanobind::handle f_leaf, + nanobind::iterable leaves) const; + + // Given a tree of iterables with the same node/leaf structure as this PyTree, + // build the corresponding PyTree. + // TODO(phawkins): use flattening everywhere instead and delete this method. + nanobind::object FromIterableTree(nanobind::handle xs) const; + + int num_leaves() const { + if (traversal_.empty()) { + return 0; + } + return traversal_.back().num_leaves; + } + + int num_nodes() const { return traversal_.size(); } + + PyTreeRegistry* registry() const { return registry_; } + + size_t Hash() const; + + bool operator==(const PyTreeDef& other) const; + bool operator!=(const PyTreeDef& other) const { return !(*this == other); } + + std::string ToString() const; + + // Transforms the PyTreeDef into a pickleable object. Used to implement + // `PyTreeDef.__getstate__`. + nanobind::object ToPickle() const; + + // Transforms the object returned by `ToPickleable()` back to PyTreeDef. Used + // to implement `PyTreeDef.__setstate__`. + void FromPickle(nanobind::object pickleable); + + void SerializeTo(PyTreeDefProto& result) const; + + static nb_class_ptr DeserializeFrom( + nb_class_ptr registry, const PyTreeDefProto& input); + + std::optional> GetNodeData() + const; + + static nb_class_ptr FromNodeDataAndChildren( + nb_class_ptr registry, + std::optional> node_data, + nanobind::iterable children); + + static PyType_Slot slots_[]; + + private: + void SetNumLeavesAndNumNodes(); + + struct Node { + PyTreeKind kind = PyTreeKind::kLeaf; + + // Arity for non-kLeaf types. + int arity = 0; + + // Kind-specific auxiliary data. For a kNamedTuple, contains the tuple type + // object. For a kDict, use `sorted_dict_keys` field below. For a kCustom + // type, contains the auxiliary data returned by the `to_iterable` function. + nanobind::object node_data; + + // Kind-specific auxiliary data specialized for kDict. Use a c++ vector + // to hold the sorted dict keys instead of a py::list to avoid creating + // a new python list object when flattening kDict. For deeply nested dict, + // using c++ vector instead of py::list avoids creating too many python + // objects that make python gc sweep slow. + std::vector sorted_dict_keys; + + // Custom type registration. Must be null for non-custom types. + const PyTreeRegistry::Registration* custom = nullptr; + + // Number of leaf nodes in the subtree rooted at this node. + int num_leaves = 0; + + // Number of leaf and interior nodes in the subtree rooted at this node. + int num_nodes = 0; + + int tp_traverse(visitproc visit, void* arg) const; + }; + template + friend H AbslHashValue(H h, const Node& n); + + template + friend H AbslHashValue(H h, const PyTreeDef& t); + + // Helper that manufactures an instance of a node given its children. + static nanobind::object MakeNode(const Node& node, + absl::Span children); + + // Recursive helper used to implement FromIterableTree() + nanobind::object FromIterableTreeHelper( + nanobind::handle xs, + absl::InlinedVector::const_reverse_iterator* it) + const; + + template + void FlattenImpl(nanobind::handle handle, T& leaves, + std::optional>& keypath, + const std::optional& leaf_predicate); + + template + nanobind::object UnflattenImpl(T leaves) const; + + static int tp_traverse(PyObject* self, visitproc visit, void* arg); + static int tp_clear(PyObject* self); + + // Pytree registry. Not owned. + PyTreeRegistry* registry_; + // If this class holds a reference to `registry`, it is held by + // `registry_ref_`. + nb_class_ptr registry_ref_; + + // Nodes, in a post-order traversal. We use an ordered traversal to minimize + // allocations, and post-order corresponds to the order we need to rebuild the + // tree structure. + absl::InlinedVector traversal_; +}; + +template +H AbslHashValue(H h, const PyTreeDef::Node& n) { + h = H::combine(std::move(h), n.kind, n.arity, n.custom); + return h; +} + +template +H AbslHashValue(H h, const PyTreeDef& t) { + h = H::combine(std::move(h), t.traversal_); + return h; +} + +void BuildPytreeSubmodule(nanobind::module_& m); + +} // namespace jax + +#endif // JAXLIB_PYTREE_H_ diff --git a/jaxlib/pytree.proto b/jaxlib/pytree.proto new file mode 100644 index 000000000000..73c087ef55ab --- /dev/null +++ b/jaxlib/pytree.proto @@ -0,0 +1,32 @@ +syntax = "proto3"; + +package jax; + +enum PyTreeNodeType { + PY_TREE_KIND_INVALID = 0; + PY_TREE_KIND_LEAF = 1; + PY_TREE_KIND_LIST = 2; + PY_TREE_KIND_NONE = 3; + PY_TREE_KIND_TUPLE = 4; + PY_TREE_KIND_DICT = 5; +} + +message DictKeysProto { + repeated uint32 str_id = 1; +} + +message PyTreeNodeDefProto { + // Recovers the tree structure. + uint32 arity = 1; + // Node type. + PyTreeNodeType type = 2; + // Only set when type == DICT. + DictKeysProto dict_keys = 3; +} + +// A Pytree. +message PyTreeDefProto { + repeated PyTreeNodeDefProto nodes = 1; + // Extra strings. + repeated string interned_strings = 2; +} diff --git a/jaxlib/pytree_test.py b/jaxlib/pytree_test.py new file mode 100644 index 000000000000..0c2eeca130e9 --- /dev/null +++ b/jaxlib/pytree_test.py @@ -0,0 +1,144 @@ +# Copyright 2023 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import collections +import dataclasses +import gc + +from absl.testing import absltest + +from jax.jaxlib import xla_client + +pytree = xla_client._xla.pytree + + +ExampleType = collections.namedtuple("ExampleType", "field0 field1") + +registry = pytree.PyTreeRegistry() + + +class ExampleType2: + + def __init__(self, field0, field1): + self.field0 = field0 + self.field1 = field1 + + def to_iterable(self): + return [self.field0, self.field1], (None,) + + +def from_iterable(state, values): + del state + return ExampleType2(field0=values[0], field1=values[1]) + + +registry.register_node(ExampleType2, ExampleType2.to_iterable, from_iterable) + + +@dataclasses.dataclass +class Custom: + a: int + b: str + + +registry.register_dataclass_node(Custom, ["a"], ["b"]) + + +class PyTreeTest(absltest.TestCase): + + def roundtrip_proto(self, example): + original = registry.flatten(example)[1] + self.assertEqual( + pytree.PyTreeDef.deserialize_using_proto( + registry, original.serialize_using_proto() + ), + original, + ) + + def testSerializeDeserializeNoPickle(self): + o = object() + self.roundtrip_proto(({"a": o, "b": o}, [o, (o, o), None])) + + def testSerializeWithFallback(self): + o = object() + with self.assertRaises(ValueError): + self.roundtrip_proto({"a": ExampleType(field0=o, field1=o)}) + + def testRegisteredType(self): + o = object() + with self.assertRaises(ValueError): + self.roundtrip_proto({"a": ExampleType2(field0=o, field1=o)}) + + def roundtrip_node_data(self, example): + original = registry.flatten(example)[1] + restored = pytree.PyTreeDef.from_node_data_and_children( + registry, original.node_data(), original.children() + ) + self.assertEqual(restored, original) + + def testRoundtripNodeData(self): + o = object() + self.roundtrip_node_data([o, o, o]) + self.roundtrip_node_data((o, o, o)) + self.roundtrip_node_data({"a": o, "b": o}) + self.roundtrip_node_data({22: o, 88: o}) + self.roundtrip_node_data(None) + self.roundtrip_node_data(o) + self.roundtrip_node_data(ExampleType(field0=o, field1=o)) + self.roundtrip_node_data(ExampleType2(field0=o, field1=o)) + + def testCompose(self): + x = registry.flatten(0)[1] + y = registry.flatten((0, 0))[1] + self.assertEqual((x.compose(y)).num_leaves, 2) + + def testDataclassMakeFromNodeData(self): + c = Custom(1, "a") + c_leafs, c_tree = registry.flatten(c) + c_tree2 = pytree.PyTreeDef.from_node_data_and_children( + registry, c_tree.node_data(), c_tree.children() + ) + self.assertEqual(c_tree2.unflatten(c_leafs), c) + self.assertEqual(str(c_tree2), str(c_tree)) + + def testTpTraverse(self): + self.assertContainsSubset( + [ + pytree.PyTreeRegistry, + ExampleType2, + ExampleType2.to_iterable, + from_iterable, + ], + gc.get_referents(registry), + ) + k1 = "k1" + k2 = "k2" + + t = ExampleType("a", "b") + _, treedef = registry.flatten([1, {k1: 2, k2: t}, 5, t]) + + self.assertContainsSubset( + [ + pytree.PyTreeDef, + registry, + k1, + k2, + ExampleType, + ], + gc.get_referents(treedef), + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/jaxlib/pywrap.bzl b/jaxlib/pywrap.bzl new file mode 100644 index 000000000000..002f51230995 --- /dev/null +++ b/jaxlib/pywrap.bzl @@ -0,0 +1,85 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Wrappers around pywrap rules for JAX.""" + +load( + "@xla//third_party/py/rules_pywrap:pywrap.impl.bzl", + "pybind_extension", + _pywrap_binaries = "pywrap_binaries", + _pywrap_library = "pywrap_library", +) +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@bazel_skylib//rules:expand_template.bzl", "expand_template") + +pywrap_library = _pywrap_library +pywrap_binaries = _pywrap_binaries + +def nanobind_pywrap_extension( + name, + srcs = [], + deps = [], + pytype_srcs = [], + pytype_deps = [], + copts = [], + linkopts = [], + visibility = None, + **kwargs): # @unused + # buildifier: disable=function-docstring-args + "Python extension rule using nanobind and the pywrap rules." + module_name = name + lib_name = name + "_pywrap_library" + src_cc_name = name + "_pywrap_stub.c" + + # We put the entire contents of the extension in a single cc_library, which will become part of + # the common pywrap library. All the contents of all extensions will end up in the common + # library. + cc_library( + name = lib_name, + srcs = srcs, + copts = copts, + deps = deps, + local_defines = [ + "PyInit_{}=Wrapped_PyInit_{}".format(module_name, module_name), + ], + visibility = ["//visibility:private"], + ) + + # We build a small stub library as the extension that forwards to the PyInit_... symbol from the + # common pywrap library. + expand_template( + name = name + "_pywrap_stub", + testonly = True, + out = src_cc_name, + substitutions = { + "@MODULE_NAME@": module_name, + }, + template = "//jaxlib:pyinit_stub.c", + visibility = ["//visibility:private"], + ) + + # Despite its name "pybind_extension" has nothing to do with pybind. It is the Python extension + # rule from the pywrap rules. + pybind_extension( + name = name, + srcs = [src_cc_name], + deps = [":" + lib_name], + data = pytype_srcs, + linkopts = linkopts, + visibility = visibility, + default_deps = [], + common_lib_packages = [ + "jaxlib", + ], + ) diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index 9a25a795fd14..1b8e8dd1e64b 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -14,6 +14,7 @@ # AMD HIP kernels +load("@rules_cc//cc:cc_library.bzl", "cc_library") load("@rules_python//python:defs.bzl", "py_library") load( "//jaxlib:jax.bzl", @@ -26,7 +27,7 @@ licenses(["notice"]) package( default_applicable_licenses = [], - default_visibility = ["//:__subpackages__"], + default_visibility = ["//visibility:public"], ) cc_library( @@ -79,7 +80,7 @@ cc_library( deps = [ ":hip_gpu_kernel_helpers", ":hip_vendor", - "//jaxlib:handle_pool", + "//jaxlib/gpu:handle_pool", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", "@local_config_rocm//rocm:hipblas", @@ -87,54 +88,6 @@ cc_library( ], ) -cc_library( - name = "hipblas_kernels", - srcs = ["//jaxlib/gpu:blas_kernels.cc"], - hdrs = ["//jaxlib/gpu:blas_kernels.h"], - deps = [ - ":hip_blas_handle_pool", - ":hip_gpu_kernel_helpers", - ":hip_make_batch_pointers", - ":hip_vendor", - "//jaxlib:kernel_helpers", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/synchronization", - "@local_config_rocm//rocm:hipblas", - "@local_config_rocm//rocm:rocm_headers", - "@xla//xla/service:custom_call_status", - ], -) - -nanobind_extension( - name = "_blas", - srcs = ["//jaxlib/gpu:blas.cc"], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = ["-use_header_modules"], - module_name = "_blas", - deps = [ - ":hip_vendor", - ":hipblas_kernels", - "//jaxlib:kernel_nanobind_helpers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings:str_format", - "@local_config_rocm//rocm:hipblas", - "@local_config_rocm//rocm:rocm_headers", - "@nanobind", - "@xla//xla/tsl/python/lib/core:numpy", - ], -) - cc_library( name = "miopen_rnn_kernels", srcs = ["//jaxlib/gpu:rnn_kernels.cc"], @@ -143,15 +96,15 @@ cc_library( ":ffi_wrapper", ":hip_gpu_kernel_helpers", ":hip_vendor", - "//jaxlib:handle_pool", "//jaxlib:kernel_helpers", + "//jaxlib/gpu:handle_pool", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", "@local_config_rocm//rocm:miopen", "@local_config_rocm//rocm:rocm_headers", "@xla//xla/ffi/api:ffi", - "@xla//xla/service:custom_call_status", ], ) @@ -182,29 +135,11 @@ cc_library( deps = [ ":hip_gpu_kernel_helpers", ":hip_vendor", - "//jaxlib:handle_pool", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/synchronization", - "@local_config_rocm//rocm:hipsolver", - "@local_config_rocm//rocm:rocm_headers", - ], -) - -cc_library( - name = "hipsolver_kernels", - srcs = ["//jaxlib/gpu:solver_kernels.cc"], - hdrs = ["//jaxlib/gpu:solver_kernels.h"], - deps = [ - ":hip_gpu_kernel_helpers", - ":hip_solver_handle_pool", - ":hip_vendor", - "//jaxlib:kernel_helpers", - "@com_google_absl//absl/status", + "//jaxlib/gpu:handle_pool", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", "@local_config_rocm//rocm:hipsolver", "@local_config_rocm//rocm:rocm_headers", - "@xla//xla/service:custom_call_status", ], ) @@ -242,7 +177,6 @@ cc_library( "@local_config_rocm//rocm:hipsolver", "@local_config_rocm//rocm:rocm_headers", "@xla//xla/ffi/api:ffi", - "@xla//xla/service:custom_call_status", ], ) @@ -256,20 +190,13 @@ nanobind_extension( features = ["-use_header_modules"], module_name = "_solver", deps = [ - ":hip_gpu_kernel_helpers", - ":hip_solver_handle_pool", ":hip_vendor", - ":hipsolver_kernels", ":hipsolver_kernels_ffi", "//jaxlib:kernel_nanobind_helpers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:str_format", "@local_config_rocm//rocm:hipblas", "@local_config_rocm//rocm:hipsolver", "@local_config_rocm//rocm:rocm_headers", "@nanobind", - "@xla//xla/tsl/python/lib/core:numpy", ], ) @@ -291,16 +218,17 @@ cc_library( ":ffi_wrapper", ":hip_gpu_kernel_helpers", ":hip_vendor", - "//jaxlib:handle_pool", + "//jaxlib:ffi_helpers", "//jaxlib:kernel_helpers", + "//jaxlib/gpu:handle_pool", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@local_config_rocm//rocm:hipsparse", "@local_config_rocm//rocm:rocm_headers", "@xla//xla/ffi/api:ffi", - "@xla//xla/service:custom_call_status", ], ) @@ -398,7 +326,6 @@ cc_library( "@local_config_rocm//rocm:rocm_headers", "@xla//xla/ffi/api:c_api", "@xla//xla/ffi/api:ffi", - "@xla//xla/service:custom_call_status", ], ) @@ -412,7 +339,6 @@ rocm_library( "//jaxlib:kernel_helpers", "@local_config_rocm//rocm:rocm_headers", "@xla//xla/ffi/api:ffi", - "@xla//xla/service:custom_call_status", ], ) @@ -430,6 +356,7 @@ nanobind_extension( ":hip_prng_kernels", ":hip_vendor", "//jaxlib:kernel_nanobind_helpers", + "@local_config_rocm//rocm:hip_runtime", "@local_config_rocm//rocm:rocm_headers", "@nanobind", ], @@ -472,10 +399,15 @@ nanobind_extension( "//jaxlib:kernel_nanobind_helpers", "//jaxlib/cpu:lapack_kernels", "@com_google_absl//absl/base", + "@local_config_rocm//rocm:hip_runtime", "@local_config_rocm//rocm:rocm_headers", "@nanobind", "@xla//xla/ffi/api:ffi", ], + linkopts = [ + "-L/opt/rocm/lib", + "-lamdhip64", + ], ) cc_library( @@ -497,8 +429,9 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", - "@tsl//tsl/platform:env", "@xla//xla/service:custom_call_status", + "@xla//xla/tsl/platform:env", + "@xla//xla/tsl/platform:errors", "@xla//xla/tsl/util:env_var", ], ) @@ -536,7 +469,9 @@ nanobind_extension( "//jaxlib:absl_status_casters", "//jaxlib:kernel_nanobind_helpers", "//jaxlib/gpu:triton_cc_proto", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@local_config_rocm//rocm:hip_runtime", "@nanobind", ], ) @@ -544,7 +479,6 @@ nanobind_extension( py_library( name = "rocm_gpu_support", deps = [ - ":_blas", ":_hybrid", ":_linalg", ":_prng", @@ -555,15 +489,57 @@ py_library( ], ) +cc_library( + name = "py_client_gpu", + srcs = ["//jaxlib/gpu:py_client_gpu.cc"], + hdrs = ["//jaxlib/gpu:py_client_gpu.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":hip_vendor", + "//jaxlib:ffi", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@dlpack", + "@nanobind", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:comparison_util", + "@xla//xla:shape_util", + "@xla//xla:util", + "@xla//xla:xla_data_proto_cc", + "@xla//xla/ffi:ffi_api", + "@xla//xla/ffi/api:c_api", + "@xla//xla/ffi/api:ffi", + "@xla//xla/pjrt:host_callback", + "@xla//xla/pjrt:transpose", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:types", + "@xla//xla/service:platform_util", + ], +) + nanobind_extension( name = "rocm_plugin_extension", srcs = ["rocm_plugin_extension.cc"], module_name = "rocm_plugin_extension", deps = [ + ":hip_gpu_kernel_helpers", + ":py_client_gpu", + "//jaxlib:kernel_nanobind_helpers", "//jaxlib/gpu:gpu_plugin_extension", "@com_google_absl//absl/log", "@com_google_absl//absl/strings", - "@local_config_rocm//rocm:hip", + "@local_config_rocm//rocm:hip_runtime", "@local_config_rocm//rocm:rocm_headers", "@nanobind", ], diff --git a/jaxlib/rocm/rocm_plugin_extension.cc b/jaxlib/rocm/rocm_plugin_extension.cc index 1dd1f1943fc8..e28e4927b81f 100644 --- a/jaxlib/rocm/rocm_plugin_extension.cc +++ b/jaxlib/rocm/rocm_plugin_extension.cc @@ -16,16 +16,20 @@ limitations under the License. #include #include -#include "nanobind/nanobind.h" #include "absl/log/log.h" #include "absl/strings/str_cat.h" #include "rocm/include/hip/hip_runtime.h" +#include "nanobind/nanobind.h" #include "jaxlib/gpu/gpu_plugin_extension.h" +#include "jaxlib/gpu/py_client_gpu.h" +#include "jaxlib/kernel_nanobind_helpers.h" +#include "jaxlib/gpu/gpu_kernel_helpers.h" namespace nb = nanobind; -namespace xla { +namespace jax { namespace { + std::string ToString(hipError_t result) { #define OSTREAM_ROCM_ERROR(__name) \ case hipError##__name: \ @@ -62,10 +66,51 @@ std::string ToString(hipError_t result) { return absl::StrCat("hipError_t(", static_cast(result), ")"); } } + +static nb::dict GpuTransposePlanCacheType() { + auto [type_id, type_info] = hip::GpuTransposePlanCacheTypeInfo(); + nb::dict d; + d["type_id"] = nb::capsule(type_id); + d["type_info"] = nb::capsule(type_info); + return d; +} + +nb::dict FfiTypes() { + nb::dict dict; + dict["GpuTransposePlanCache"] = GpuTransposePlanCacheType(); + return dict; +} + +nb::dict FfiHandlers() { + nb::dict dict; + nb::dict gpu_callback_dict; + gpu_callback_dict["instantiate"] = + EncapsulateFfiHandler(hip::kGpuTransposePlanCacheInstantiate); + gpu_callback_dict["execute"] = + EncapsulateFfiHandler(hip::kXlaFfiPythonGpuCallback); + dict["xla_ffi_python_gpu_callback"] = gpu_callback_dict; + dict["xla_ffi_partitioned_python_gpu_callback"] = gpu_callback_dict; + dict["xla_buffer_python_gpu_callback"] = + EncapsulateFfiHandler(hip::kXlaBufferPythonGpuCallback); + dict["xla_buffer_python_gpu_callback_cmd_buffer"] = + EncapsulateFfiHandler(hip::kXlaBufferPythonGpuCallbackCmdBuffer); + return dict; +} + +int ROCmDeviceCount() { + int device_count = -1; + JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipInit(0))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipGetDeviceCount(&device_count))); + return device_count; +} + } // namespace NB_MODULE(rocm_plugin_extension, m) { BuildGpuPluginExtension(m); + m.def("ffi_types", &FfiTypes); + m.def("ffi_handlers", &FfiHandlers); + m.def( "get_device_ordinal", [](std::intptr_t data_value) { @@ -85,5 +130,6 @@ NB_MODULE(rocm_plugin_extension, m) { return device_ordinal; }, nb::arg("data_value")); + m.def("get_device_count", &ROCmDeviceCount); } -} // namespace xla +} // namespace jax diff --git a/jaxlib/sdy_mpmd.cc b/jaxlib/sdy_mpmd.cc new file mode 100644 index 000000000000..d21729dbe350 --- /dev/null +++ b/jaxlib/sdy_mpmd.cc @@ -0,0 +1,258 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "mlir-c/IR.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" // IWYU pragma: keep; Needed to allow MlirModule -> ModuleOp. +#include "mlir/CAPI/IR.h" // IWYU pragma: keep; Needed to allow MlirModule -> ModuleOp. +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/OperationSupport.h" +#include "nanobind/nanobind.h" +// IWYU pragma: begin_keep; Nanobind conversions for std types. +#include "nanobind/stl/map.h" +#include "nanobind/stl/optional.h" +#include "nanobind/stl/pair.h" +#include "nanobind/stl/string.h" +#include "nanobind/stl/tuple.h" +#include "nanobind/stl/variant.h" +#include "nanobind/stl/vector.h" +// IWYU pragma: end_keep +#include "shardy/dialect/mpmd/ir/fragment_execution_rules.h" +#include "shardy/dialect/mpmd/ir/utils.h" +#include "shardy/integrations/python/jax/mpmd/jaxlib/mpmd_program.h" +#include "xla/pjrt/status_casters.h" // IWYU pragma: keep; Needed for ValueOrThrow +#include "xla/python/ifrt/ir/conversions/mpmd/lower_to_ifrt.h" +#include "xla/python/nb_absl_flat_hash_map.h" // IWYU pragma: keep + +namespace nb = nanobind; + +namespace jax::mpmd { +namespace { + +using ::mlir::Builder; +using ::mlir::ModuleOp; +using ::mlir::mpmd::FlatMesh; +using ::mlir::mpmd::FragmentInfo; +using ::mlir::mpmd::FragmentMergeRule; +using ::mlir::mpmd::FragmentMergeRules; +using ::mlir::mpmd::FragmentOrigin; +using ::mlir::mpmd::FragmentScheduleRule; +using ::mlir::mpmd::FunctionIOShardingSpecsAndMeshes; +using ::mlir::mpmd::MpmdProgram; +using ::mlir::mpmd::NamedSpmdShardingSpec; +using ::mlir::mpmd::PartitioningOptions; +using ::mlir::mpmd::PartitioningPhase; +using ::mlir::mpmd::PartitioningResult; +using ::mlir::mpmd::SplitFragmentType; +using ::mlir::mpmd::SpmdTensorPartitionSpec; +using ::mlir::mpmd::UserAssignmentMap; +using ::xla::ifrt::mpmd::EnvOptionsOverride; +using ::xla::ifrt::mpmd::GetCompileOptions; +using ::xla::ifrt::mpmd::LowerToIfrt; + +// Wrapper of PartitioningResult, which stores MlirModules instead of ModuleOps. +struct PartitioningResultWrapper { + MlirModule mpmd_module; + mlir::mpmd::FunctionIOShardingSpecsAndMeshes + module_io_sharding_specs_and_meshes; +}; + +// name -> [mesh | (mesh, stage)] +using PyUserAssignmentMap = + std::map>>; + +UserAssignmentMap GetCppUserAssignmentMap(const PyUserAssignmentMap& py_map) { + UserAssignmentMap cpp_map; + for (const auto& [name, py_value] : py_map) { + if (const auto* mesh = std::get_if(&py_value)) { + cpp_map[name] = std::make_pair(*mesh, std::nullopt); + } else if (const auto* mesh_stage = + std::get_if>(&py_value)) { + cpp_map[name] = *mesh_stage; + } + } + return cpp_map; +} + +NB_MODULE(_sdy_mpmd, m) { + nb::enum_(m, "PartitioningPhase", nb::is_flag()) + .value("NONE", PartitioningPhase::kNone) + .value("IMPORT", PartitioningPhase::kImport) + .value("OPTIMIZE", PartitioningPhase::kOptimize) + .value("PARTITION", PartitioningPhase::kPartition) + .value("ALL", PartitioningPhase::kAll) + .export_values(); + + nb::enum_(m, "SplitFragmentType") + .value("KEEP_TRANSFERRED", SplitFragmentType::kKeepTransferred) + .value("DROP_TRANSFERRED", SplitFragmentType::kDropTransferred) + .export_values(); + + nb::class_(m, "FragmentOrigin") + .def(nb::init(), nb::arg("computation_name"), + nb::arg("transpose_count")) + .def_ro("computation_name", &FragmentOrigin::computation_name) + .def_ro("transpose_count", &FragmentOrigin::transpose_count); + + nb::class_(m, "FragmentInfo") + .def(nb::init&, std::optional, + std::optional, + std::optional, + const std::string&>(), + nb::arg("origins"), nb::arg("stage_id").none() = std::nullopt, + nb::arg("call_counter").none() = std::nullopt, + nb::arg("split_type").none() = std::nullopt, nb::arg("mesh_name")) + .def_ro("origins", &FragmentInfo::origins) + .def_ro("stage_id", &FragmentInfo::stage_id) + .def_ro("call_counter", &FragmentInfo::call_counter) + .def_ro("split_type", &FragmentInfo::split_type) + .def_ro("mesh_name", &FragmentInfo::mesh_name); + + nb::class_(m, "FragmentScheduleRule") + .def(nb::init&>(), + nb::arg("ordered_fragments")) + .def_ro("ordered_fragments", &FragmentScheduleRule::ordered_fragments); + + nb::class_(m, "FragmentMergeRule") + .def(nb::init&, FragmentInfo>(), + nb::arg("sources"), nb::arg("target")) + .def_ro("sources", &FragmentMergeRule::sources) + .def_ro("target", &FragmentMergeRule::target); + + nb::class_(m, "PartitioningResult") + .def_ro("mpmd_module", &PartitioningResultWrapper::mpmd_module) + .def_ro("module_io_sharding_specs_and_meshes", + &PartitioningResultWrapper::module_io_sharding_specs_and_meshes); + + m.def( + "apply_mpmd_partitioning", + [](MlirModule c_module, std::string func_name, + const std::vector>& named_meshes, + const mpmd::PyUserAssignmentMap& assignment, + const std::vector>& input_meshes, + const std::vector>& output_meshes, + const std::vector& donate_argnums, + const std::optional< + std::map>>& + partitioning_options, + const FragmentMergeRules& fragment_merge_rules, + PartitioningPhase phases) -> PartitioningResultWrapper { + PartitioningOptions options; + if (partitioning_options) { + options = mlir::mpmd::ParsePartitioningOptions(*partitioning_options); + } + MpmdProgram program{.module = unwrap(c_module), + .func_name = func_name, + .options = std::move(options), + .named_meshes = named_meshes, + .assignment = GetCppUserAssignmentMap(assignment), + .input_meshes = input_meshes, + .output_meshes = output_meshes, + .donate_argnums = donate_argnums, + .fragment_merge_rules = fragment_merge_rules}; + + PartitioningResult partitioning_result = + program.ApplyPartitioning(phases); + + return PartitioningResultWrapper{ + wrap(partitioning_result.mpmd_module), + std::move(partitioning_result.module_io_sharding_specs_and_meshes), + }; + }, + nb::arg("module"), nb::arg("func_name"), nb::arg("named_meshes"), + nb::arg("assignment"), nb::arg("input_meshes"), nb::arg("output_meshes"), + nb::arg("donate_argnums"), + nb::arg("partitioning_options").none() = std::nullopt, + nb::arg("fragment_merge_rules"), nb::arg("phases")); + + m.def("get_fragment_info", + [](MlirModule c_module) -> std::vector { + std::vector fragment_info; + auto module = unwrap(c_module); + // Walk module and get info for each fragment + module.walk([&fragment_info](mlir::mpmd::FragmentOp fragment) { + fragment_info.push_back(mlir::mpmd::GetFragmentInfo(fragment)); + }); + return fragment_info; + }); + + nb::class_(m, "NamedSpmdShardingSpec") + .def(nb::init>(), + nb::arg("mesh_name"), nb::arg("tensor_spec"), + nb::arg("memory_kind").none() = std::nullopt) + .def_ro("mesh_name", &NamedSpmdShardingSpec::mesh_name) + .def_ro("tensor_spec", &NamedSpmdShardingSpec::tensor_spec) + .def_ro("memory_kind", &NamedSpmdShardingSpec::memory_kind); + nb::class_( + m, "FunctionIOShardingSpecsAndMeshes") + .def(nb::init, + std::vector>()) + .def_ro("input_specs", &FunctionIOShardingSpecsAndMeshes::input_specs) + .def_ro("output_specs", &FunctionIOShardingSpecsAndMeshes::output_specs); + + m.def( + "clone_mlir_module", + [](MlirModule c_module, const std::vector& unit_attributes) { + MlirOperation op = mlirModuleGetOperation(c_module); + MlirModule module = mlirModuleFromOperation(mlirOperationClone(op)); + if (unit_attributes.empty()) { + return module; + } + + ModuleOp module_op = unwrap(module); + for (const std::string& attr_name : unit_attributes) { + module_op->setAttr(attr_name, Builder(module_op).getUnitAttr()); + } + return wrap(module_op); + }, + nb::arg("c_module"), + nb::arg("unit_attributes") = std::vector()); + + m.def( + "lower_to_ifrt", + [](MlirModule module) -> void { + return xla::ThrowIfError(LowerToIfrt(unwrap(module))); + }, + nb::arg("module")); + + m.def("get_compile_options", + [](MlirModule c_module, + const absl::flat_hash_map& + compile_options_overrides) -> absl::StatusOr { + auto module = unwrap(c_module); + auto compile_options_map = ValueOrThrow( + GetCompileOptions(module, compile_options_overrides)); + nb::dict out; + for (const auto& [name, options] : compile_options_map) { + out[nb::cast(name)] = + nb::steal(nanobind::cast(options).release().ptr()); + } + return out; + }); +} + +} // namespace +} // namespace jax::mpmd diff --git a/jaxlib/setup.py b/jaxlib/setup.py index b3a37a25f1b2..ebe433c62e42 100644 --- a/jaxlib/setup.py +++ b/jaxlib/setup.py @@ -58,24 +58,28 @@ def has_ext_modules(self): long_description_content_type='text/markdown', author='JAX team', author_email='jax-dev@google.com', - packages=['jaxlib', 'jaxlib.xla_extension'], - python_requires='>=3.10', + packages=['jaxlib'], + python_requires='>=3.11', install_requires=[ - 'scipy>=1.11.1', - 'numpy>=1.25', - 'ml_dtypes>=0.2.0', + 'scipy>=1.13', + 'numpy>=2.0', + 'ml_dtypes>=0.5.0', ], url='https://github.com/jax-ml/jax', license='Apache-2.0', classifiers=[ - "Programming Language :: Python :: 3.10", + "Development Status :: 5 - Production/Stable", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", + "Programming Language :: Python :: Free Threading :: 3 - Stable", ], package_data={ 'jaxlib': [ '*.so', + '*.dylib', + '*.dll', '*.pyd*', 'py.typed', 'cpu/*', @@ -105,7 +109,6 @@ def has_ext_modules(self): 'triton/*.so', 'include/xla/ffi/api/*.h', ], - 'jaxlib.xla_extension': ['*.pyi'], }, zip_safe=False, distclass=BinaryDistribution, diff --git a/jaxlib/sharded_device_array.h b/jaxlib/sharded_device_array.h new file mode 100644 index 000000000000..97fb8702cae5 --- /dev/null +++ b/jaxlib/sharded_device_array.h @@ -0,0 +1,216 @@ +/* Copyright 2021 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_SHARDED_DEVICE_ARRAY_H_ +#define JAXLIB_SHARDED_DEVICE_ARRAY_H_ + +#include +#include +#include + +#include "nanobind/nanobind.h" +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "xla/python/types.h" + +// TODO(jblespiau): The current implementation moves the Python logic to C++, +// as a preliminary step to executing the `pmap` execution path from C++. +// It implements the current Python behavior (thus, it may not be optimal, and +// we will be able to modify it later). + +namespace jax { + +// High level introduction. +// +// pmap and other parallel computation functions distribute some computation on +// several devices. On December 2020, the devices mesh (i.e. N-dimensional array +// of devices on which we map the computation) is defined by the user. +// +// We describe how to shard the inputs, and how to map it to the mesh of devices +// using `ShardingSpec`. It's mainly based on 2 components: +// - `sharding`, which specifies how to shard the inputs. +// - `mesh_mapping`, which specifies how to map shards to devices. +// +// The 3 following structs define how to shard one dimension of an ndarry. +// +// `NoSharding` (`None` in Python) means no sharding. +struct NoSharding { + bool operator==(const NoSharding& other) const { return true; } + bool operator!=(const NoSharding& other) const { return false; } +}; + +template +H AbslHashValue(H h, const NoSharding& key) { + return h; +} + +// `Chunked` means that the dimension is split into np.prod(chunks) chunks +// and the split dimension itself is preserved inside the map. +// Those chunks are distributed over `len(chunks)` ShardedAxes axes +// (major-to-minor). +// For example, for a tensor `t` of shape [N] sharded using [Chunked([p])] (with +// p dividing N, let S = N // p) the tensor will be split into p chunks of +// shape [S], such sharded_t[k] = t[k * S: (k+1)*S] (left included, right +// excluded) for k in {0, ... p-1}. +struct Chunked { + public: + explicit Chunked(std::vector chunks_) : chunks(std::move(chunks_)) {} + // The number of chunks per axis. + std::vector chunks; + + bool operator==(const Chunked& other) const { return chunks == other.chunks; } + bool operator!=(const Chunked& other) const { return chunks != other.chunks; } +}; + +template +H AbslHashValue(H h, const Chunked& key) { + h = H::combine(std::move(h), key.chunks); + return h; +} + +// `Unstacked` means that the dimension is split into chunks of size 1, and +// doesn't appear inside the map. `size` is always the dimension size. +// For example, a Tensor t of shape [N] will be sharded into N tensors of shape +// [], when using `Unstacked(N)`. +struct Unstacked { + public: + explicit Unstacked(int sz) : size(sz) {} + int size; + + bool operator==(const Unstacked& other) const { return size == other.size; } + bool operator!=(const Unstacked& other) const { return size != other.size; } +}; + +template +H AbslHashValue(H h, const Unstacked& key) { + h = H::combine(std::move(h), key.size); + return h; +} + +using AvalDimSharding = std::variant; + +// Assigns sharded axes to mesh dimensions. +// +// The devices will be for each dimension which has a sharded `AvalDimSharding` +// When no axis is assigned, the data is replicated. +// As indices are 0-indexed, `ShardedAxis(1)` refers to the second actually +// sharded axis (i.e. counting as if the None dimensions of sharding were +// filtered out). +// For example, given the sharding `[Unstacked(n), None, Chunked(m)]`, an entry +// of `ShardedAxis(1)` refers to the `Chunked(m)` axis, not the `None`. + +struct ShardedAxis { + int axis; + bool operator==(const ShardedAxis& other) const { return axis == other.axis; } + bool operator!=(const ShardedAxis& other) const { return axis != other.axis; } +}; + +template +H AbslHashValue(H h, const ShardedAxis& key) { + h = H::combine(std::move(h), key.axis); + return h; +} + +struct Replicated { + int replicas; + bool operator==(const Replicated& other) const { + return replicas == other.replicas; + } + bool operator!=(const Replicated& other) const { + return replicas != other.replicas; + } +}; + +template +H AbslHashValue(H h, const Replicated& key) { + h = H::combine(std::move(h), key.replicas); + return h; +} + +using MeshDimAssignment = std::variant; + +// Describes how each axis is sharded (if it is), and how it's mapped to the +// devices mesh. See Jax pxla.py for the documentation. +// +// ShardingSpec is shared across pmap, pjit and xpmap. For pmap, an input +// `sharding` is composed of `NoSharding` and at most one `Unstacked`. +// If `axis_size=None`, at least one the inputs has a dimension associated to +// `Unstacked`. +// +// Examples: +// +// 1. For pmap, with a tensor of shape [8, 2, 2], to unstack along the first +// dimension into [8] devices: +// +// sharding = [Unstacked(8), NoSharding, NoSharding] +// mesh_mapping = [ShardedAxis(0)] +// +// 2. With an input array of shape [6], that we want to chunk into [2, 3] +// Assuming a device mesh [3, 4, 2] of devices, we will have: +// +// sharding = [Chunked([2, 3])] +// mesh_mapping = [ShardedAxis(1), Replicated, ShardedAxis(0)] +// +// In particular, in the above example, the ShardedAxis refers to indices +// of the sharded shape [2, 3]. (only the `Chunked` sharding can produce more +// than one dimension). +class ShardingSpec { + public: + ShardingSpec(std::vector sharding, + std::vector mesh_mapping) + : sharding_(std::move(sharding)), + mesh_mapping_(std::move(mesh_mapping)) {} + ShardingSpec(nanobind::iterable py_sharding, + nanobind::iterable py_mesh_mapping) + : sharding_(xla::IterableToVector(py_sharding)), + mesh_mapping_( + xla::IterableToVector(py_mesh_mapping)) {} + + const std::vector& GetSharding() const { return sharding_; } + const std::vector& GetMeshMapping() const { + return mesh_mapping_; + } + + bool operator==(const ShardingSpec& other) const { + return sharding_ == other.sharding_ && mesh_mapping_ == other.mesh_mapping_; + } + + bool operator!=(const ShardingSpec& other) const { return !(*this == other); } + + template + friend H AbslHashValue(H h, const ShardingSpec& key); + + private: + // `sharding` specifies how the array is supposed to get partitioned into + // chunks. Its length matches the rank of the array. See the docstring + // of `AvalDimSharding` for the supported partitioning schemes. + std::vector sharding_; + // `mesh_mapping` describes an assignments of the array chunks created by + // `sharding` to a logical device mesh. The length of the tuple is equal to + // the rank of the mesh. Each mesh dimension can either get partitions of + // data varying along one of the sharded dimensions, or the data can be + // replicated. + std::vector mesh_mapping_; +}; + +template +H AbslHashValue(H h, const ShardingSpec& key) { + h = H::combine(std::move(h), key.sharding_); + h = H::combine(std::move(h), key.mesh_mapping_); + return h; +} + +} // namespace jax + +#endif // JAXLIB_SHARDED_DEVICE_ARRAY_H_ diff --git a/jaxlib/sharding.cc b/jaxlib/sharding.cc new file mode 100644 index 000000000000..a599c94e0c2d --- /dev/null +++ b/jaxlib/sharding.cc @@ -0,0 +1,362 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/sharding.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/casts.h" +#include "absl/hash/hash.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/partition_spec.h" +#include "jaxlib/py_client.h" +#include "jaxlib/py_device.h" // IWYU pragma: keep +#include "jaxlib/py_device_list.h" +#include "jaxlib/sharded_device_array.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/safe_static_init.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/xla_data.pb.h" + +namespace jax { + +namespace nb = nanobind; + +// Gets `PyDeviceList` from a JAX Sharding. +absl::StatusOr> GetPyDeviceList( + nb::handle sharding) { + if (sharding.type().is(NamedSharding::type())) { + TF_ASSIGN_OR_RETURN( + auto ns_device_list, + nb::cast(sharding)->internal_device_list()); + return ns_device_list; + } else if (sharding.type().is(SingleDeviceSharding::type())) { + return nb::cast(sharding) + ->internal_device_list(); + } else if (sharding.type().is(PmapSharding::type())) { + return nb::cast(sharding)->internal_device_list(); + } else if (sharding.type().is(GSPMDSharding::type())) { + return nb::cast(sharding)->internal_device_list(); + } else { + return nb::cast>( + sharding.attr("_internal_device_list")); + } +} + +nb::object CheckAndCanonicalizeMemoryKind( + nb::object memory_kind, const nb_class_ptr& device_list) { + if (!memory_kind.is_none()) { + // If memory kind is not None, check if it's supported by the devices + // mentioned in the Sharding. + auto supported_memory_kinds = PyDeviceList::MemoryKinds(device_list); + if (absl::IsUnimplemented(supported_memory_kinds.status())) { + // TODO(b/473586037): Implement + // PjRtDeviceDescription::default_memory_space() for all backends so this + // fallback isn't necessary. + return nb::none(); + } + if (!supported_memory_kinds.ok()) { + supported_memory_kinds = nb::tuple(); + } + for (nb::handle supported_memory_kind : *supported_memory_kinds) { + if (supported_memory_kind.equal(memory_kind)) { + return memory_kind; + } + } + auto addressable_device_list = + PyDeviceList::AddressableDeviceList(device_list); + if (addressable_device_list->Len() == 0) { + // If the device list is not addressable, we can't check if the memory + // kind is supported, so we assume it is. + return memory_kind; + } + nb::object device_kind = + addressable_device_list->GetItem(0).attr("device_kind"); + std::string_view device_kind_str = nb::cast(device_kind); + auto py_str_formatter = [](std::string* out, nb::handle h) { + *out += nb::cast(nb::str(h)); + }; + throw nb::value_error( + absl::StrCat( + "Could not find memory addressable by device ", device_kind_str, + ". Device ", device_kind_str, + " can address the following memory kinds: ", + absl::StrJoin(*supported_memory_kinds, ", ", py_str_formatter), + ". Got memory kind: ", nb::cast(memory_kind)) + .c_str()); + } + // If memory kind is None, canonicalize to default memory. + absl::StatusOr default_memory_kind = + PyDeviceList::DefaultMemoryKind(device_list); + if (!default_memory_kind.ok()) { + return nb::none(); + } + return *std::move(default_memory_kind); +} + +// This list is to check for valid memory kinds when an AbstractMesh is passed +// to NamedSharding. +static const std::array valid_memory_kinds = { + "device", + "pinned_host", + "unpinned_host", +}; + +NamedSharding::NamedSharding(nb::object mesh, nb_class_ptr spec, + nb::object memory_kind, + nb::object logical_device_ids) + : Sharding(/*num_devices=*/[&mesh]() { + return nb::cast(mesh.attr("size")); + }()), + mesh_(std::move(mesh)), + spec_(std::move(spec)), + memory_kind_(std::move(memory_kind)), + logical_device_ids_(std::move(logical_device_ids)) { + nb::object idl = nb::object(mesh_.attr("_internal_device_list")); + if (idl.is_none()) { + internal_device_list_ = std::nullopt; + } else { + internal_device_list_ = nb::cast>(idl); + } + if (internal_device_list_) { + memory_kind_ = + CheckAndCanonicalizeMemoryKind(memory_kind_, *internal_device_list_); + } else { + if (!memory_kind_.is_none() && + (std::find(valid_memory_kinds.begin(), valid_memory_kinds.end(), + nb::cast(memory_kind_)) == + valid_memory_kinds.end())) { + throw nb::value_error( + absl::StrCat("Got invalid memory kind: ", + nb::cast(memory_kind_), + ". Valid memory kinds are: ", + absl::StrJoin(valid_memory_kinds, ", ")) + .c_str()); + } + } + + // TODO(phawkins): this leaks a reference to the check_pspec function. + // A better way to fix this would be to move PartitionSpec and this check into + // C++. + auto init_fn = []() { + nb::module_ si = nb::module_::import_("jax._src.named_sharding"); + return std::make_unique(si.attr("check_pspec")); + }; + nb::object& check_pspec = xla::SafeStaticInit(init_fn); + check_pspec(mesh_, spec_); +} + +/*static*/ PyObject* NamedSharding::type_ = nullptr; + +/*static*/ void NamedSharding::InitializeType() { + // Intentionally leaks a reference. + type_ = nanobind::type().inc_ref().ptr(); +} + +bool NamedSharding::operator==(const NamedSharding& other) const { + // Caution: you may need to update EqualShardingsForJit in jax_jit.cc as well. + return mesh().equal(other.mesh()) && *spec() == *other.spec() && + memory_kind().equal(other.memory_kind()) && + logical_device_ids().equal(other.logical_device_ids()); +} + +bool NamedSharding::Eq(const nanobind::object& other) const { + if (!other.ptr() || other.is_none()) { + return false; + } + const NamedSharding* other_sharding; + if (!nb::try_cast(other, other_sharding)) { + return false; + } + return this == other_sharding || *this == *other_sharding; +} + +nb::int_ NamedSharding::Hash() const { + // Caution: you may need to update HashShardingForJit in jax_jit.cc as well. + return nb::cast(hash_.Get([&]() { + size_t h = + absl::HashOf(nb::hash(mesh_), spec_->Hash(), nb::hash(memory_kind_), + nb::hash(logical_device_ids_)); + Py_hash_t s = absl::bit_cast(h); // Python hashes are signed. + return nb::cast( + s == -1 ? -2 : s); // -1 must not be used as a Python hash value. + })); +} + +SingleDeviceSharding::SingleDeviceSharding(nb::object device, + nb::object memory_kind) + : Sharding(/*num_devices=*/1), + device_(device), + memory_kind_(std::move(memory_kind)), + internal_device_list_( + make_nb_class(nb::make_tuple(std::move(device)))) { + memory_kind_ = + CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_); +} + +/*static*/ PyObject* SingleDeviceSharding::type_ = nullptr; + +/*static*/ void SingleDeviceSharding::InitializeType() { + // Intentionally leaks a reference. + type_ = nanobind::type().inc_ref().ptr(); +} + +SingleDeviceSharding::SingleDeviceSharding(nb_class_ptr client, + xla::ifrt::DeviceListRef device_list, + nb::object memory_kind) + : Sharding(/*num_devices=*/1), + device_(client->GetPyDevice(device_list->devices().front())), + memory_kind_(std::move(memory_kind)), + internal_device_list_(make_nb_class( + std::move(client), std::move(device_list))) { + memory_kind_ = + CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_); +} + +PmapSharding::PmapSharding(xla::nb_numpy_ndarray devices, + ShardingSpec sharding_spec) + : Sharding(/*num_devices=*/devices.size()), + devices_(std::move(devices)), + sharding_spec_(std::move(sharding_spec)) { + nb::object flat_devices = devices_.attr("flat"); + internal_device_list_ = make_nb_class(nb::tuple(flat_devices)); +} + +/*static*/ PyObject* PmapSharding::type_ = nullptr; + +// /*static*/ nanobind::handle PmapSharding::type() { return type_; } + +/*static*/ void PmapSharding::InitializeType() { + // Intentionally leaks a reference. + type_ = nanobind::type().inc_ref().ptr(); +} + +GSPMDSharding::GSPMDSharding(nb_class_ptr devices, + xla::HloSharding op_sharding, + nb::object memory_kind) + : Sharding(/*num_devices=*/nb::len(devices.ptr())), + devices_(std::move(devices)), + hlo_sharding_(std::move(op_sharding)), + memory_kind_(std::move(memory_kind)) { + internal_device_list_ = devices_; + // This checks in python if the memory kind is correct for the given + // devices. Currently in python this check is optimized but we want to + // move that check to C++ after which we can remove this call. + CHECK(devices_->Len() != 0) + << "Devices given to GSPMDSharding must not be empty"; + memory_kind_ = + CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_); +} + +/*static*/ PyObject* GSPMDSharding::type_ = nullptr; + +/*static*/ void GSPMDSharding::InitializeType() { + // Intentionally leaks a reference. + type_ = nanobind::type().inc_ref().ptr(); +} + +void RegisterSharding(nb::module_& m) { + nb::class_(m, "Sharding").def(nb::init<>()); + + nb::class_(m, "NamedSharding", nb::dynamic_attr()) + .def(nb::init, nb::object, + nb::object>(), + nb::arg("mesh"), nb::arg("spec"), + nb::arg("memory_kind").none() = nb::none(), + nb::arg("_logical_device_ids").none() = nb::none()) + .def_prop_ro("mesh", &NamedSharding::mesh) + .def_prop_ro("spec", &NamedSharding::spec) + .def_prop_ro("_memory_kind", &NamedSharding::memory_kind) + .def_prop_ro("_logical_device_ids", &NamedSharding::logical_device_ids) + .def_prop_ro("_internal_device_list", + [](const NamedSharding& s) { + return xla::ValueOrThrow(s.internal_device_list()); + }) + .def("__eq__", &NamedSharding::Eq, nb::arg(), nb::is_operator()) + .def("__hash__", &NamedSharding::Hash); + NamedSharding::InitializeType(); + + nb::class_(m, "SingleDeviceSharding", + nb::dynamic_attr()) + .def(nb::init(), nb::arg("device"), + nb::arg("memory_kind").none() = nb::none()) + .def_prop_ro("_device", &SingleDeviceSharding::device) + .def_prop_ro("_memory_kind", &SingleDeviceSharding::memory_kind) + .def_prop_ro("_internal_device_list", + &SingleDeviceSharding::internal_device_list); + SingleDeviceSharding::InitializeType(); + + nb::class_(m, "PmapSharding", nb::dynamic_attr()) + .def( + "__init__", + [](PmapSharding* self, nb::object devices, + ShardingSpec sharding_spec) { + new (self) PmapSharding(xla::nb_numpy_ndarray::ensure(devices), + std::move(sharding_spec)); + }, + nb::arg("devices"), nb::arg("sharding_spec")) + .def_prop_ro("devices", &PmapSharding::devices) + .def_prop_ro("sharding_spec", &PmapSharding::sharding_spec) + .def_prop_ro("_internal_device_list", + &PmapSharding::internal_device_list); + PmapSharding::InitializeType(); + + nb::class_(m, "GSPMDSharding", nb::dynamic_attr()) + // NOTE: We explicitly list the two PyDeviceList ctors first since they + // are the fast path and PyDeviceList conforms to `nb::sequence` so we + // can silently fall back to the slow sequence ctor(s). + .def(nb::init, xla::OpSharding, nb::object>(), + nb::arg("devices"), nb::arg("op_sharding"), + nb::arg("memory_kind").none() = nb::none()) + .def(nb::init, xla::HloSharding, nb::object>(), + nb::arg("devices"), nb::arg("op_sharding"), + nb::arg("memory_kind").none() = nb::none()) + .def(nb::init, xla::OpSharding, + nb::object>(), + nb::arg("devices"), nb::arg("op_sharding"), + nb::arg("memory_kind").none() = nb::none()) + .def(nb::init, xla::HloSharding, + nb::object>(), + nb::arg("devices"), nb::arg("op_sharding"), + nb::arg("memory_kind").none() = nb::none()) + .def_prop_ro("_devices", &GSPMDSharding::devices) + .def_prop_ro("_hlo_sharding", &GSPMDSharding::hlo_sharding) + .def_prop_ro("_memory_kind", &GSPMDSharding::memory_kind) + .def_prop_ro("_internal_device_list", + &GSPMDSharding::internal_device_list); + GSPMDSharding::InitializeType(); +} + +} // namespace jax diff --git a/jaxlib/sharding.h b/jaxlib/sharding.h new file mode 100644 index 000000000000..ffc5c90b0acf --- /dev/null +++ b/jaxlib/sharding.h @@ -0,0 +1,254 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_SHARDING_H_ +#define JAXLIB_SHARDING_H_ + +#include + +#include +#include +#include + +// placeholder for index annotation headers +#include "absl/hash/hash.h" +#include "absl/status/statusor.h" +#include "nanobind/nanobind.h" +#include "jaxlib/cached_py_object.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/partition_spec.h" +#include "jaxlib/py_client.h" +#include "jaxlib/py_device_list.h" +#include "jaxlib/sharded_device_array.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/nb_numpy.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace jax { + +class Sharding { + public: + Sharding() = default; + + // This constructor is used in the fast path to retrieve the number of devices + // without falling back to python. This is only used in the cpp path. + explicit Sharding(int num_devices) : num_devices_(num_devices) {} + + virtual ~Sharding() = default; + + int num_devices() const { return num_devices_; } + + private: + int num_devices_; +}; + +// Gets `PyDeviceList` from a JAX Sharding. +absl::StatusOr> GetPyDeviceList( + nanobind::handle sharding); + +// Checks if the memory kind is valid, and canonicalizes the +// memory kind to default memory on backends that support memories. +nanobind::object CheckAndCanonicalizeMemoryKind( + nanobind::object memory_kind, + const nb_class_ptr& device_list); + +class NamedSharding : public Sharding { + public: + NamedSharding(nanobind::object mesh, nb_class_ptr spec, + nanobind::object memory_kind, + nanobind::object logical_device_ids); + + const nanobind::object& mesh() const { return mesh_; } + const nb_class_ptr& spec() const { return spec_; } + const nanobind::object& memory_kind() const { return memory_kind_; } + const nanobind::object& logical_device_ids() const { + return logical_device_ids_; + } + + static nanobind::handle type() { return type_; } + static void InitializeType(); + + absl::StatusOr> internal_device_list() const { + if (internal_device_list_) { + return *internal_device_list_; + } + return xla::InvalidArgument( + "internal_device_list is not implemented for " + "`jax.sharding.AbstractMesh`"); + } + + bool operator==(const NamedSharding& other) const; + + bool Eq(const nanobind::object& other) const; // Python __eq__ + nanobind::int_ Hash() const; // Python __hash__ + + private: + nanobind::object mesh_; + nb_class_ptr spec_; + nanobind::object memory_kind_; + nanobind::object logical_device_ids_; + std::optional> internal_device_list_; + mutable CachedPyObject hash_; + static PyObject* type_; +}; + +class SingleDeviceSharding : public Sharding { + public: + explicit SingleDeviceSharding( + nanobind::object device, nanobind::object memory_kind = nanobind::none()); + + // Used only in C++ to accelerate `PyArray::MakeFromSingleDeviceArray()`. + SingleDeviceSharding(nb_class_ptr client, + xla::ifrt::DeviceListRef device_list, + nanobind::object memory_kind); + + const nanobind::object& device() const { return device_; } + const nanobind::object& memory_kind() const { return memory_kind_; } + + static nanobind::handle type() { return type_; } + static void InitializeType(); + + nb_class_ptr internal_device_list() const { + return internal_device_list_; + } + + private: + nanobind::object device_; + nanobind::object memory_kind_; + nb_class_ptr internal_device_list_; + + static PyObject* type_; +}; + +// The C++ implementation of jax.PmapSharding in python. It contains a few key +// data members and methods that are performance-critical. +class PmapSharding : public Sharding { + public: + PmapSharding(xla::nb_numpy_ndarray devices, ShardingSpec sharding_spec); + + ~PmapSharding() override = default; + + xla::nb_numpy_ndarray devices() const { return devices_; } + + const ShardingSpec& sharding_spec() const { return sharding_spec_; } + + static nanobind::handle type() { return type_; } + static void InitializeType(); + + nb_class_ptr internal_device_list() const { + return internal_device_list_; + } + + private: + xla::nb_numpy_ndarray devices_; + ShardingSpec sharding_spec_; + nb_class_ptr internal_device_list_; + static PyObject* type_; +}; + +class GSPMDSharding : public Sharding { + public: + GSPMDSharding(nanobind::sequence devices, xla::OpSharding op_sharding, + nanobind::object memory_kind) + : GSPMDSharding( + make_nb_class(nanobind::tuple(devices)), + xla::ValueOrThrow(xla::HloSharding::FromProto(op_sharding)), + std::move(memory_kind)) {} + + GSPMDSharding(nanobind::sequence devices, xla::HloSharding op_sharding, + nanobind::object memory_kind) + : GSPMDSharding(make_nb_class(nanobind::tuple(devices)), + std::move(op_sharding), std::move(memory_kind)) {} + + GSPMDSharding(nb_class_ptr devices, xla::OpSharding op_sharding, + nanobind::object memory_kind) + : GSPMDSharding( + std::move(devices), + xla::ValueOrThrow(xla::HloSharding::FromProto(op_sharding)), + std::move(memory_kind)) {} + + GSPMDSharding(nb_class_ptr devices, + xla::HloSharding op_sharding, nanobind::object memory_kind); + + nb_class_ptr devices() const { return devices_; } + const nanobind::object& memory_kind() const { return memory_kind_; } + + size_t Hash() { + if (!hash_.has_value()) { + hash_ = CalculateHash(); + } + return *hash_; + } + + static nanobind::handle type() { return type_; } + static void InitializeType(); + + const xla::HloSharding& hlo_sharding() const { return hlo_sharding_; } + + bool operator==(const GSPMDSharding& other) const { + return AreOpShardingsEqual(*this, other) && + this->devices().equal(other.devices()) && + this->memory_kind().equal(other.memory_kind()); + } + + nb_class_ptr internal_device_list() const { + return internal_device_list_; + } + + private: + size_t CalculateHash() const { + // We only hash `hlo_sharding_` here for performance. + return absl::Hash()(hlo_sharding_); + } + + static bool AreOpShardingsEqual(const GSPMDSharding& a, + const GSPMDSharding& b) { + // If the OpSharding object is the same, return true + if (&a.hlo_sharding() == &b.hlo_sharding()) { + return true; + } + // If both OpShardings are replicated, return true + if (a.IsOpShardingReplicated() && b.IsOpShardingReplicated()) { + return true; + } + return a.hlo_sharding() == b.hlo_sharding(); + } + + bool IsOpShardingReplicated() const { + // For JAX, shardings with 1 device are considered as replicated in its + // semantics so that downstream things continue to work. + if (hlo_sharding_.tile_assignment().num_elements() == 1) { + return true; + } + return hlo_sharding().IsReplicated(); + } + + nb_class_ptr devices_; + xla::HloSharding hlo_sharding_; + nanobind::object memory_kind_; + std::optional hash_; + nb_class_ptr internal_device_list_; + + static PyObject* type_; +}; + +void RegisterSharding(nanobind::module_& m); + +} // namespace jax + +#endif // JAXLIB_SHARDING_H_ diff --git a/jaxlib/symlink_files.bzl b/jaxlib/symlink_files.bzl index 203b66022926..65f910181f7f 100644 --- a/jaxlib/symlink_files.bzl +++ b/jaxlib/symlink_files.bzl @@ -84,6 +84,8 @@ symlink_inputs( ) """ +visibility(["//jaxlib/..."]) + def _symlink_files_impl(ctx): flatten = ctx.attr.flatten strip_prefix = ctx.attr.strip_prefix diff --git a/jaxlib/to_ifrt_sharding.cc b/jaxlib/to_ifrt_sharding.cc new file mode 100644 index 000000000000..9739937755e1 --- /dev/null +++ b/jaxlib/to_ifrt_sharding.cc @@ -0,0 +1,137 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/to_ifrt_sharding.h" + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_device_list.h" +#include "jaxlib/sharding.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" +#include "xla/python/pjrt_ifrt/xla_sharding.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/xla_data.pb.h" + +namespace jax { + +namespace nb = ::nanobind; + +// Gets `xla::HloSharding` from a JAX Sharding. +xla::HloSharding GetXlaHloSharding(nb::handle sharding, + int64_t num_dimensions) { + if (sharding.type().is(nb::handle(GSPMDSharding::type().ptr()))) { + return nb::cast(nb::handle(sharding.ptr()))->hlo_sharding(); + } else { + return nb::cast( + sharding.attr("_to_xla_hlo_sharding")(num_dimensions)); + } +} + +// Gets `xla::ifrt::DeviceList` from a JAX Sharding. +absl::StatusOr GetIfrtDeviceList( + nb::handle sharding_py) { + TF_ASSIGN_OR_RETURN(auto py_device_list, GetPyDeviceList(sharding_py)); + return py_device_list->ifrt_device_list(); +} + +// Gets `xla::ifrt::MemoryKind` from a JAX Sharding. +xla::ifrt::MemoryKind GetMemoryKind(nb::handle sharding) { + nb::object py_memory_kind = nb::none(); + + // sharding.attr("memory_kind") can crash if sharding was originally created + // from C++ and casted into a Python Sharding object. Thus, we cast sharding + // to a C++ type and use C++ `memory_kind()` method, which bypasses any Python + // attribute access. + nb::handle type = sharding.type(); + if (type.is(NamedSharding::type())) { + py_memory_kind = nb::cast(sharding)->memory_kind(); + } else if (type.is(SingleDeviceSharding::type())) { + py_memory_kind = + nb::cast(sharding)->memory_kind(); + } else if (type.is(GSPMDSharding::type())) { + py_memory_kind = nb::cast(sharding)->memory_kind(); + } else { + py_memory_kind = sharding.attr("memory_kind"); + } + + if (py_memory_kind.is_none()) { + return xla::ifrt::MemoryKind(); + } + return xla::ifrt::MemoryKind(nb::cast(py_memory_kind)); +} + +// Converts a JAX Sharding into `xla::ifrt::HloSharding`. +absl::StatusOr GetIfrtHloSharding( + nb::handle sharding, const xla::ifrt::Shape& shape) { + TF_ASSIGN_OR_RETURN(xla::ifrt::DeviceListRef device_list, + GetIfrtDeviceList(sharding)); + xla::ifrt::MemoryKind memory_kind = GetMemoryKind(sharding.ptr()); + xla::HloSharding hlo_sharding = + GetXlaHloSharding(sharding, shape.dims().size()); + return xla::ifrt::HloSharding::Create( + std::move(device_list), std::move(memory_kind), std::move(hlo_sharding)); +} + +// Converts a JAX Sharding into `xla::ifrt::ConcreteEvenSharding`. +absl::StatusOr GetIfrtConcreteEvenSharding( + nb::handle sharding, xla::ifrt::DType dtype, + const xla::ifrt::Shape& shape) { + TF_ASSIGN_OR_RETURN(xla::ifrt::DeviceListRef device_list, + GetIfrtDeviceList(sharding)); + xla::ifrt::MemoryKind memory_kind = GetMemoryKind(sharding.ptr()); + TF_ASSIGN_OR_RETURN(xla::PrimitiveType xla_primitive_type, + xla::ifrt::ToPrimitiveType(dtype)); + // The XLA shape's layout is irrelevant because we only need to know the + // tile shape, which is independent from the layout. + xla::Shape xla_shape = xla::ShapeUtil::MakeShapeWithDescendingLayout( + xla_primitive_type, shape.dims()); + xla::HloSharding hlo_sharding = + GetXlaHloSharding(sharding, shape.dims().size()); + xla::Shape tile_shape = hlo_sharding.TileShape(xla_shape); + xla::ifrt::Shape shard_shape(xla::ifrt::Shape::Dimensions( + tile_shape.dimensions().begin(), tile_shape.dimensions().end())); + return xla::ifrt::ConcreteEvenSharding::Create( + std::move(device_list), std::move(memory_kind), shape, + /*shard_shape=*/std::move(shard_shape)); +} + +// Converts a JAX Sharding into `xla::ifrt::ConcreteSharding`. +absl::StatusOr GetIfrtConcreteSharding( + nb::handle sharding, const xla::ifrt::Shape& shape, + std::vector shard_shapes) { + TF_ASSIGN_OR_RETURN(xla::ifrt::DeviceListRef device_list, + GetIfrtDeviceList(sharding)); + xla::ifrt::MemoryKind memory_kind = GetMemoryKind(sharding.ptr()); + return xla::ifrt::ConcreteSharding::Create( + std::move(device_list), std::move(memory_kind), shape, + /*shard_shapes=*/std::move(shard_shapes)); +} + +} // namespace jax diff --git a/jaxlib/to_ifrt_sharding.h b/jaxlib/to_ifrt_sharding.h new file mode 100644 index 000000000000..31470748998d --- /dev/null +++ b/jaxlib/to_ifrt_sharding.h @@ -0,0 +1,61 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_TO_IFRT_SHARDING_H_ +#define JAXLIB_TO_IFRT_SHARDING_H_ + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "nanobind/nanobind.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" + +namespace jax { + +// Gets `xla::HloSharding` from a JAX Sharding. +xla::HloSharding GetXlaHloSharding(nanobind::handle sharding, + int64_t num_dimensions); + +// Gets `xla::ifrt::DeviceList` from a JAX Sharding. +absl::StatusOr GetIfrtDeviceList( + nanobind::handle sharding_py); + +// Gets `xla::ifrt::MemoryKind` from a JAX Sharding. +xla::ifrt::MemoryKind GetMemoryKind(nanobind::handle sharding); + +// Converts a JAX Sharding into `xla::ifrt::HloSharding`. +absl::StatusOr GetIfrtHloSharding( + nanobind::handle sharding, const xla::ifrt::Shape& shape); + +// Converts a JAX Sharding into `xla::ifrt::ConcreteEvenSharding`. +absl::StatusOr GetIfrtConcreteEvenSharding( + nanobind::handle sharding, xla::ifrt::DType dtype, + const xla::ifrt::Shape& shape); + +// Converts a JAX Sharding into `xla::ifrt::ConcreteSharding`. +absl::StatusOr GetIfrtConcreteSharding( + nanobind::handle sharding, const xla::ifrt::Shape& shape, + std::vector shard_shapes); + +} // namespace jax + +#endif // JAXLIB_TO_IFRT_SHARDING_H_ diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index afa5866e286d..7bacc30ccd7e 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -15,9 +15,11 @@ # JAX is Autograd and XLA load("@bazel_skylib//lib:selects.bzl", "selects") -load("@bazel_skylib//rules:common_settings.bzl", "string_flag") +load("@bazel_skylib//rules:common_settings.bzl", "bool_flag", "string_flag") +load("@cuda_cudart//:version.bzl", cuda_major_version = "VERSION") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") +load("@rules_python//python:py_binary.bzl", "py_binary") load( "@xla//third_party/py:py_import.bzl", "py_import", @@ -29,16 +31,22 @@ load( load( "//jaxlib:jax.bzl", "PLATFORM_TAGS_DICT", - "if_windows", + "compare_srcs_and_test_deps_test", + "get_test_suite_list", + "if_pypi_cuda_wheel_deps", "jax_py_test", "jax_wheel", "pytype_strict_library", + "pytype_test", + "wheel_sources", ) licenses(["notice"]) # Apache 2 package(default_visibility = ["//visibility:public"]) +exports_files(["wheel_size_test.py"]) + genrule( name = "platform_tags_py", srcs = [], @@ -52,119 +60,10 @@ pytype_strict_library( "build_utils.py", ":platform_tags_py", ], + deps = ["@xla//third_party/py:setup_py_nvidia_dependencies_util"], ) -py_binary( - name = "build_wheel", - srcs = ["build_wheel.py"], - data = [ - "LICENSE.txt", - "//jaxlib", - "//jaxlib:README.md", - "//jaxlib:setup.py", - "@xla//xla/ffi/api:api.h", - "@xla//xla/ffi/api:c_api.h", - "@xla//xla/ffi/api:ffi.h", - "@xla//xla/python:xla_client.py", - "@xla//xla/python:xla_extension", - ] + if_windows([ - "//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll", - ]), - deps = [ - ":build_utils", - "@bazel_tools//tools/python/runfiles", - "@pypi_build//:pkg", - "@pypi_setuptools//:pkg", - "@pypi_wheel//:pkg", - ], -) - -jax_py_test( - name = "build_wheel_test", - srcs = ["build_wheel_test.py"], - data = [":build_wheel"], - deps = [ - "@bazel_tools//tools/python/runfiles", - ], -) - -cc_binary( - name = "pjrt_c_api_gpu_plugin.so", - linkopts = [ - "-Wl,--version-script,$(location :gpu_version_script.lds)", - "-Wl,--no-undefined", - ], - linkshared = True, - deps = [ - ":gpu_version_script.lds", - "@xla//xla/pjrt/c:pjrt_c_api_gpu", - "@xla//xla/pjrt/c:pjrt_c_api_gpu_version_script.lds", - "@xla//xla/service:gpu_plugin", - ] + if_cuda([ - "//jaxlib/mosaic/gpu:custom_call", - "@xla//xla/stream_executor:cuda_platform", - ]) + if_rocm([ - "@xla//xla/stream_executor:rocm_platform", - ]), -) - -py_binary( - name = "build_gpu_plugin_wheel", - srcs = ["build_gpu_plugin_wheel.py"], - data = [ - "LICENSE.txt", - ":pjrt_c_api_gpu_plugin.so", - ] + if_cuda([ - "//jaxlib:version", - "//jaxlib/cuda:cuda_gpu_support", - "//jax_plugins/cuda:pyproject.toml", - "//jax_plugins/cuda:setup.py", - "//jax_plugins/cuda:__init__.py", - "@local_config_cuda//cuda:cuda-nvvm", - ]) + if_rocm([ - "//jaxlib:version", - "//jaxlib/rocm:rocm_gpu_support", - "//jax_plugins/rocm:pyproject.toml", - "//jax_plugins/rocm:setup.py", - "//jax_plugins/rocm:__init__.py", - ]), - deps = [ - ":build_utils", - "@bazel_tools//tools/python/runfiles", - "@pypi_build//:pkg", - "@pypi_setuptools//:pkg", - "@pypi_wheel//:pkg", - ], -) - -py_binary( - name = "build_gpu_kernels_wheel", - srcs = ["build_gpu_kernels_wheel.py"], - data = [ - "LICENSE.txt", - ] + if_cuda([ - "//jaxlib:version", - "//jaxlib/mosaic/gpu:mosaic_gpu", - "//jaxlib/cuda:cuda_plugin_extension", - "//jaxlib/cuda:cuda_gpu_support", - "//jax_plugins/cuda:plugin_pyproject.toml", - "//jax_plugins/cuda:plugin_setup.py", - "@local_config_cuda//cuda:cuda-nvvm", - ]) + if_rocm([ - "//jaxlib:version", - "//jaxlib/rocm:rocm_plugin_extension", - "//jaxlib/rocm:rocm_gpu_support", - "//jax_plugins/rocm:plugin_pyproject.toml", - "//jax_plugins/rocm:plugin_setup.py", - ]), - deps = [ - ":build_utils", - "@bazel_tools//tools/python/runfiles", - "@pypi_build//:pkg", - "@pypi_setuptools//:pkg", - "@pypi_wheel//:pkg", - ], -) +# Platform configurations. selects.config_setting_group( name = "macos", @@ -222,6 +121,8 @@ selects.config_setting_group( ], ) +# Flags for the new wheel build rules. + string_flag( name = "jaxlib_git_hash", build_setting_default = "", @@ -232,69 +133,161 @@ string_flag( build_setting_default = "dist", ) -NVIDIA_WHEELS_DEPS = [ - "@pypi_nvidia_cublas_cu12//:whl", - "@pypi_nvidia_cuda_cupti_cu12//:whl", - "@pypi_nvidia_cuda_runtime_cu12//:whl", - "@pypi_nvidia_cudnn_cu12//:whl", - "@pypi_nvidia_cufft_cu12//:whl", - "@pypi_nvidia_cusolver_cu12//:whl", - "@pypi_nvidia_cusparse_cu12//:whl", - "@pypi_nvidia_nccl_cu12//:whl", - "@pypi_nvidia_nvjitlink_cu12//:whl", -] +# Wheel targets. + +# Jaxlib wheel targets. +py_binary( + name = "build_wheel_tool", + srcs = ["build_wheel.py"], + main = "build_wheel.py", + deps = [ + ":build_utils", + "@bazel_tools//tools/python/runfiles", + "@pypi//build", + "@pypi//setuptools", + "@pypi//wheel", + ], +) + +wheel_sources( + name = "jaxlib_sources", + data_srcs = [ + "//jaxlib", + "//jaxlib:jaxlib_binaries", + "//jaxlib:_jax", + ], + hdr_srcs = [ + "@xla//xla/ffi/api:ffi", + ], + py_srcs = [ + "//jaxlib", + ], + static_srcs = [ + "//jaxlib:README.md", + "LICENSE.txt", + "//jaxlib:setup.py", + "//jaxlib:xla_client.py", + ], + symlink_data_srcs = [ + "//jaxlib", + ], +) jax_wheel( name = "jaxlib_wheel", no_abi = False, - wheel_binary = ":build_wheel", + source_files = [":jaxlib_sources"], + wheel_binary = ":build_wheel_tool", wheel_name = "jaxlib", ) -py_import( - name = "jaxlib_py_import", - wheel = ":jaxlib_wheel", -) - jax_wheel( name = "jaxlib_wheel_editable", editable = True, - wheel_binary = ":build_wheel", + source_files = [":jaxlib_sources"], + wheel_binary = ":build_wheel_tool", wheel_name = "jaxlib", ) +# JAX plugin wheel targets. +pytype_strict_library( + name = "version", + srcs = ["//jaxlib:version"], +) + +py_binary( + name = "build_gpu_kernels_wheel_tool", + srcs = ["build_gpu_kernels_wheel.py"], + main = "build_gpu_kernels_wheel.py", + deps = [ + ":build_utils", + "@bazel_tools//tools/python/runfiles", + "@pypi//build", + "@pypi//setuptools", + "@pypi//wheel", + ], +) + +wheel_sources( + name = "jax_plugin_sources", + data_srcs = [ + ] + if_cuda([ + "//jaxlib/cuda:cuda_gpu_support", + "@local_config_cuda//cuda:cuda-nvvm", + "//jaxlib/cuda:cuda_plugin_extension", + "//jaxlib/mosaic/gpu:mosaic_gpu", + ]) + if_rocm([ + "//jaxlib/rocm:rocm_gpu_support", + "//jaxlib/rocm:rocm_plugin_extension", + ]), + py_srcs = [":version"] + if_cuda([ + "//jaxlib/cuda:cuda_gpu_support", + "//jaxlib/mosaic/gpu:mosaic_gpu", + ]) + if_rocm([ + "//jaxlib/rocm:rocm_gpu_support", + ]), + static_srcs = [ + "LICENSE.txt", + ] + if_cuda([ + "//jax_plugins/cuda:plugin_pyproject.toml", + "//jax_plugins/cuda:plugin_setup.py", + ]) + if_rocm([ + "//jax_plugins/rocm:plugin_pyproject.toml", + "//jax_plugins/rocm:plugin_setup.py", + ]), +) + jax_wheel( - name = "jax_cuda_plugin_wheel", + name = "jax_cuda12_plugin_wheel", enable_cuda = True, no_abi = False, # TODO(b/371217563) May use hermetic cuda version here. platform_version = "12", - wheel_binary = ":build_gpu_kernels_wheel", + source_files = [":jax_plugin_sources"], + wheel_binary = ":build_gpu_kernels_wheel_tool", wheel_name = "jax_cuda12_plugin", ) -py_import( - name = "jax_cuda_plugin_py_import", - wheel = ":jax_cuda_plugin_wheel", - wheel_deps = if_cuda(NVIDIA_WHEELS_DEPS), +jax_wheel( + name = "jax_cuda13_plugin_wheel", + enable_cuda = True, + no_abi = False, + # TODO(b/371217563) May use hermetic cuda version here. + platform_version = "13", + source_files = [":jax_plugin_sources"], + wheel_binary = ":build_gpu_kernels_wheel_tool", + wheel_name = "jax_cuda13_plugin", ) jax_wheel( - name = "jax_cuda_plugin_wheel_editable", + name = "jax_cuda12_plugin_wheel_editable", editable = True, enable_cuda = True, # TODO(b/371217563) May use hermetic cuda version here. platform_version = "12", - wheel_binary = ":build_gpu_kernels_wheel", + source_files = [":jax_plugin_sources"], + wheel_binary = ":build_gpu_kernels_wheel_tool", wheel_name = "jax_cuda12_plugin", ) +jax_wheel( + name = "jax_cuda13_plugin_wheel_editable", + editable = True, + enable_cuda = True, + # TODO(b/371217563) May use hermetic cuda version here. + platform_version = "13", + source_files = [":jax_plugin_sources"], + wheel_binary = ":build_gpu_kernels_wheel_tool", + wheel_name = "jax_cuda13_plugin", +) + jax_wheel( name = "jax_rocm_plugin_wheel", enable_rocm = True, no_abi = False, platform_version = "60", - wheel_binary = ":build_gpu_kernels_wheel", + source_files = [":jax_plugin_sources"], + wheel_binary = ":build_gpu_kernels_wheel_tool", wheel_name = "jax_rocm60_plugin", ) @@ -303,42 +296,107 @@ jax_wheel( editable = True, enable_rocm = True, platform_version = "60", - wheel_binary = ":build_gpu_kernels_wheel", + source_files = [":jax_plugin_sources"], + wheel_binary = ":build_gpu_kernels_wheel_tool", wheel_name = "jax_rocm60_plugin", ) +# JAX PJRT wheel targets. + +py_binary( + name = "build_gpu_plugin_wheel_tool", + srcs = ["build_gpu_plugin_wheel.py"], + main = "build_gpu_plugin_wheel.py", + deps = [ + ":build_utils", + "@bazel_tools//tools/python/runfiles", + "@pypi//build", + "@pypi//setuptools", + "@pypi//wheel", + ], +) + +wheel_sources( + name = "jax_pjrt_sources", + data_srcs = if_cuda([ + "//jax_plugins/cuda:cuda_plugin", + "//jaxlib/cuda:cuda_gpu_support", + "@local_config_cuda//cuda:cuda-nvvm", + ]) + if_rocm([ + "//jax_plugins/rocm:rocm_plugin", + "//jaxlib/rocm:rocm_gpu_support", + ]), + py_srcs = [ + ":version", + ] + if_cuda([ + "//jaxlib/cuda:cuda_gpu_support", + ]) + if_rocm([ + "//jaxlib/rocm:rocm_gpu_support", + ]), + static_srcs = [ + "LICENSE.txt", + ] + if_cuda([ + "//jax_plugins/cuda:pyproject.toml", + "//jax_plugins/cuda:setup.py", + "//jax_plugins/cuda:__init__.py", + ]) + if_rocm([ + "//jax_plugins/rocm:pyproject.toml", + "//jax_plugins/rocm:setup.py", + "//jax_plugins/rocm:__init__.py", + ]), +) + jax_wheel( - name = "jax_cuda_pjrt_wheel", + name = "jax_cuda12_pjrt_wheel", enable_cuda = True, no_abi = True, # TODO(b/371217563) May use hermetic cuda version here. platform_version = "12", - wheel_binary = ":build_gpu_plugin_wheel", + source_files = [":jax_pjrt_sources"], + wheel_binary = ":build_gpu_plugin_wheel_tool", wheel_name = "jax_cuda12_pjrt", ) -py_import( - name = "jax_cuda_pjrt_py_import", - wheel = ":jax_cuda_pjrt_wheel", - wheel_deps = if_cuda(NVIDIA_WHEELS_DEPS), +jax_wheel( + name = "jax_cuda13_pjrt_wheel", + enable_cuda = True, + no_abi = True, + # TODO(b/371217563) May use hermetic cuda version here. + platform_version = "13", + source_files = [":jax_pjrt_sources"], + wheel_binary = ":build_gpu_plugin_wheel_tool", + wheel_name = "jax_cuda13_pjrt", ) jax_wheel( - name = "jax_cuda_pjrt_wheel_editable", + name = "jax_cuda12_pjrt_wheel_editable", editable = True, enable_cuda = True, # TODO(b/371217563) May use hermetic cuda version here. platform_version = "12", - wheel_binary = ":build_gpu_plugin_wheel", + source_files = [":jax_pjrt_sources"], + wheel_binary = ":build_gpu_plugin_wheel_tool", wheel_name = "jax_cuda12_pjrt", ) +jax_wheel( + name = "jax_cuda13_pjrt_wheel_editable", + editable = True, + enable_cuda = True, + # TODO(b/371217563) May use hermetic cuda version here. + platform_version = "13", + source_files = [":jax_pjrt_sources"], + wheel_binary = ":build_gpu_plugin_wheel_tool", + wheel_name = "jax_cuda13_pjrt", +) + jax_wheel( name = "jax_rocm_pjrt_wheel", enable_rocm = True, no_abi = True, platform_version = "60", - wheel_binary = ":build_gpu_plugin_wheel", + source_files = [":jax_pjrt_sources"], + wheel_binary = ":build_gpu_plugin_wheel_tool", wheel_name = "jax_rocm60_pjrt", ) @@ -347,10 +405,127 @@ jax_wheel( editable = True, enable_rocm = True, platform_version = "60", - wheel_binary = ":build_gpu_plugin_wheel", + source_files = [":jax_pjrt_sources"], + wheel_binary = ":build_gpu_plugin_wheel_tool", wheel_name = "jax_rocm60_pjrt", ) +# Py_import targets. +cuda_suffix = "_cu12" if cuda_major_version == "12" else "" + +filegroup( + name = "nvidia_wheel_deps", + srcs = [ + "@pypi_nvidia_cublas{cuda}//:pkg".format(cuda = cuda_suffix), + "@pypi_nvidia_cuda_cupti{cuda}//:pkg".format(cuda = cuda_suffix), + "@pypi_nvidia_cuda_nvcc{cuda}//:pkg".format(cuda = cuda_suffix), + "@pypi_nvidia_cuda_nvrtc{cuda}//:pkg".format(cuda = cuda_suffix), + "@pypi_nvidia_cuda_runtime{cuda}//:pkg".format(cuda = cuda_suffix), + "@pypi_nvidia_cudnn_cu{cuda}//:pkg".format(cuda = cuda_major_version), + "@pypi_nvidia_cufft{cuda}//:pkg".format(cuda = cuda_suffix), + "@pypi_nvidia_cusolver{cuda}//:pkg".format(cuda = cuda_suffix), + "@pypi_nvidia_cusparse{cuda}//:pkg".format(cuda = cuda_suffix), + "@pypi_nvidia_nccl_cu{cuda}//:pkg".format(cuda = cuda_major_version), + "@pypi_nvidia_nvjitlink{cuda}//:pkg".format(cuda = cuda_suffix), + "@pypi_nvidia_nvshmem_cu{cuda}//:pkg".format(cuda = cuda_major_version), + ] + ([ + "@pypi_nvidia_nvvm{cuda}//:pkg".format(cuda = cuda_suffix), + "@pypi_nvidia_cuda_crt{cuda}//:pkg".format(cuda = cuda_suffix), + ] if cuda_major_version == "13" else []), +) + +# The flag configures whether to add the pypi NVIDIA CUDA deps to py_import. +bool_flag( + name = "add_pypi_cuda_wheel_deps", + build_setting_default = True, +) + +config_setting( + name = "pypi_cuda_wheel_deps", + flag_values = { + ":add_pypi_cuda_wheel_deps": "True", + "@local_config_cuda//:enable_cuda": "True", + }, +) + +py_import( + name = "jaxlib_py_import", + wheel = ":jaxlib_wheel", +) + +py_import( + name = "jax_cuda_plugin_py_import", + wheel = ":jax_cuda{cuda}_plugin_wheel".format(cuda = cuda_major_version), + wheel_deps = if_pypi_cuda_wheel_deps([":nvidia_wheel_deps"]), +) + +py_import( + name = "jax_cuda_pjrt_py_import", + wheel = ":jax_cuda{cuda}_pjrt_wheel".format(cuda = cuda_major_version), + wheel_deps = if_pypi_cuda_wheel_deps([":nvidia_wheel_deps"]), +) + +# The targets below are used for GPU tests with `--//jax:build_jaxlib=false`. +py_import( + name = "pypi_jax_cuda_plugin_with_cuda_deps", + wheel = "@pypi_jax_cuda{cuda}_plugin//:whl".format(cuda = cuda_major_version), + wheel_deps = if_pypi_cuda_wheel_deps([":nvidia_wheel_deps"]), +) + +py_import( + name = "pypi_jax_cuda_pjrt_with_cuda_deps", + wheel = "@pypi_jax_cuda{cuda}_pjrt//:whl".format(cuda = cuda_major_version), + wheel_deps = if_pypi_cuda_wheel_deps([":nvidia_wheel_deps"]), +) + +# Mosaic GPU + +py_binary( + name = "build_mosaic_wheel_tool", + srcs = ["build_mosaic_wheel.py"], + main = "build_mosaic_wheel.py", + deps = [ + ":build_utils", + "@bazel_tools//tools/python/runfiles", + "@pypi//build", + "@pypi//setuptools", + "@pypi//wheel", + ], +) + +wheel_sources( + name = "mosaic_sources", + static_srcs = [ + "LICENSE.txt", + "//jaxlib/mosaic/gpu/wheel:mosaic_gpu.so", + "//jaxlib/mosaic/gpu/wheel:setup.py", + "//jaxlib/mosaic/gpu/wheel:__init__.py", + "//jaxlib:version", + ], +) + +jax_wheel( + name = "mosaic_gpu_wheel_cuda12", + enable_cuda = True, + no_abi = True, + platform_version = "12", + source_files = [":mosaic_sources"], + wheel_binary = ":build_mosaic_wheel_tool", + wheel_name = "mosaic_gpu_cuda12", +) + +jax_wheel( + name = "mosaic_gpu_wheel_cuda13", + enable_cuda = True, + no_abi = True, + platform_version = "13", + source_files = [":mosaic_sources"], + wheel_binary = ":build_mosaic_wheel_tool", + wheel_name = "mosaic_gpu_cuda13", +) + +# Wheel tests. + AARCH64_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "aarch64")]) PPC64LE_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "ppc64le")]) @@ -375,7 +550,7 @@ verify_manylinux_compliance_test( test_tags = [ "manual", ], - wheel = ":jax_cuda_plugin_wheel", + wheel = ":jax_cuda{cuda}_plugin_wheel".format(cuda = cuda_major_version), x86_64_compliance_tag = X86_64_MANYLINUX_TAG, ) @@ -386,6 +561,175 @@ verify_manylinux_compliance_test( test_tags = [ "manual", ], - wheel = ":jax_cuda_pjrt_wheel", + wheel = ":jax_cuda{cuda}_pjrt_wheel".format(cuda = cuda_major_version), + x86_64_compliance_tag = X86_64_MANYLINUX_TAG, +) + +verify_manylinux_compliance_test( + name = "mosaic_gpu_manylinux_compliance_test", + aarch64_compliance_tag = AARCH64_MANYLINUX_TAG, + ppc64le_compliance_tag = PPC64LE_MANYLINUX_TAG, + test_tags = [ + "manual", + ], + wheel = ":mosaic_gpu_wheel_cuda{cuda}".format(cuda = cuda_major_version), x86_64_compliance_tag = X86_64_MANYLINUX_TAG, ) + +pytype_test( + name = "jaxlib_wheel_size_test", + srcs = [":wheel_size_test.py"], + args = [ + "--wheel-path=$(location :jaxlib_wheel)", + "--max-size-mib=110", + ], + data = [":jaxlib_wheel"], + main = "wheel_size_test.py", + tags = [ + "manual", + "notap", + ], +) + +pytype_test( + name = "jax_cuda_plugin_wheel_size_test", + srcs = [":wheel_size_test.py"], + args = [ + "--wheel-path=$(location :jax_cuda{cuda}_plugin_wheel)".format(cuda = cuda_major_version), + "--max-size-mib=20", + ], + data = [":jax_cuda{cuda}_plugin_wheel".format(cuda = cuda_major_version)], + main = "wheel_size_test.py", + tags = [ + "manual", + "notap", + ], +) + +pytype_test( + name = "jax_cuda_pjrt_wheel_size_test", + srcs = [":wheel_size_test.py"], + args = [ + "--wheel-path=$(location :jax_cuda{cuda}_pjrt_wheel)".format(cuda = cuda_major_version), + "--max-size-mib=150", + ], + data = [":jax_cuda{cuda}_pjrt_wheel".format(cuda = cuda_major_version)], + main = "wheel_size_test.py", + tags = [ + "manual", + "notap", + ], +) + +pytype_test( + name = "mosaic_gpu_wheel_size_test", + srcs = [":wheel_size_test.py"], + args = [ + "--wheel-path=$(location :mosaic_gpu_wheel_cuda{cuda})".format(cuda = cuda_major_version), + "--max-size-mib=40", + ], + data = [":mosaic_gpu_wheel_cuda{cuda}".format(cuda = cuda_major_version)], + main = "wheel_size_test.py", + tags = [ + "manual", + "notap", + ], +) + +_IGNORED_INIT_PY_FILES = [ + # This file is copied as `__init__.py` in the root wheel folder. + "jaxlib/__init__.py", + # These files do not exist in LLVM/MLIR sources. + "jaxlib/mlir/__init__.py", + "jaxlib/mlir/dialects/__init__.py", + "jaxlib/mlir/extras/__init__.py", + # These files do not exist in JAX sources. + "jaxlib/mosaic/dialect/gpu/__init__.py", + "jaxlib/mosaic/python/__init__.py", +] + +compare_srcs_and_test_deps_test( + name = "check_cpu_wheel_sources_test", + srcs = [ + ":jaxlib_sources", + "//:jax_sources", + "//:wheel_additives", + ], + ignored_init_py_files = _IGNORED_INIT_PY_FILES, + root_package_names = [ + "jax", + "jaxlib", + ], + tags = [ + "manual", + "notap", + ], + tests = [ + "//jax/experimental/jax2tf/tests:jax2tf_test_cpu", + ] + get_test_suite_list( + backends = ["cpu"], + paths = [ + "jax/experimental/jax2tf/tests/multiprocess", + "tests", + "tests/mosaic", + "tests/multiprocess", + "tests/pallas", + ], + ), +) + +compare_srcs_and_test_deps_test( + name = "check_gpu_wheel_sources_test", + srcs = [ + ":jax_pjrt_sources", + ":jax_plugin_sources", + ":jaxlib_sources", + "//:jax_sources", + "//:wheel_additives", + ], + ignored_init_py_files = _IGNORED_INIT_PY_FILES, + root_package_names = [ + "jax", + "jaxlib", + ], + tags = [ + "manual", + "notap", + ], + tests = get_test_suite_list( + backends = ["gpu"], + paths = [ + "jax/experimental/jax2tf/tests/multiprocess", + "tests", + "tests/mosaic", + "tests/multiprocess", + "tests/pallas", + ], + ), +) + +compare_srcs_and_test_deps_test( + name = "check_tpu_wheel_sources_test", + srcs = [ + ":jaxlib_sources", + "//:jax_sources", + "//:wheel_additives", + ], + ignored_init_py_files = _IGNORED_INIT_PY_FILES, + root_package_names = [ + "jax", + "jaxlib", + ], + tags = [ + "manual", + "notap", + ], + tests = get_test_suite_list( + backends = ["tpu"], + paths = [ + "tests", + "tests/multiprocess", + "tests/pallas", + ], + ), +) diff --git a/jaxlib/tools/build_gpu_kernels_wheel.py b/jaxlib/tools/build_gpu_kernels_wheel.py index 2f81eacbdde4..20babb2f38fd 100644 --- a/jaxlib/tools/build_gpu_kernels_wheel.py +++ b/jaxlib/tools/build_gpu_kernels_wheel.py @@ -26,7 +26,7 @@ from bazel_tools.tools.python.runfiles import runfiles from jaxlib.tools import build_utils -parser = argparse.ArgumentParser() +parser = argparse.ArgumentParser(fromfile_prefix_chars="@") parser.add_argument( "--output_path", default=None, @@ -61,6 +61,15 @@ "--enable-rocm", default=False, help="Should we build with ROCM enabled?") +parser.add_argument( + "--srcs", help="source files for the wheel", action="append" +) +parser.add_argument( + "--nvidia_wheel_versions_data", + default=None, + required=False, + help="NVIDIA wheel versions data", +) args = parser.parse_args() r = runfiles.Create() @@ -79,80 +88,113 @@ def write_setup_cfg(sources_path, cpu): def prepare_wheel_cuda( - sources_path: pathlib.Path, *, cpu, cuda_version + wheel_sources_path: pathlib.Path, + *, + cpu, + cuda_version, + wheel_sources, + nvidia_wheel_versions_data, ): - """Assembles a source tree for the cuda kernel wheel in `sources_path`.""" - copy_runfiles = functools.partial(build_utils.copy_file, runfiles=r) + """Assembles a source tree for the cuda kernel wheel in `wheel_sources_path`.""" + source_file_prefix = build_utils.get_source_file_prefix(wheel_sources) + wheel_sources_map = build_utils.create_wheel_sources_map( + wheel_sources, + root_packages=[ + "jax_plugins", + f"jax_cuda{cuda_version}_plugin", + "jaxlib", + ], + ) + copy_files = functools.partial( + build_utils.copy_file, + runfiles=r, + wheel_sources_map=wheel_sources_map, + ) - copy_runfiles( - "__main__/jax_plugins/cuda/plugin_pyproject.toml", - dst_dir=sources_path, + copy_files( + f"{source_file_prefix}jax_plugins/cuda/plugin_pyproject.toml", + dst_dir=wheel_sources_path, dst_filename="pyproject.toml", ) - copy_runfiles( - "__main__/jax_plugins/cuda/plugin_setup.py", - dst_dir=sources_path, + copy_files( + f"{source_file_prefix}jax_plugins/cuda/plugin_setup.py", + dst_dir=wheel_sources_path, dst_filename="setup.py", ) - build_utils.update_setup_with_cuda_version(sources_path, cuda_version) - write_setup_cfg(sources_path, cpu) + build_utils.update_setup_with_cuda_and_nvidia_wheel_versions( + wheel_sources_path, cuda_version, nvidia_wheel_versions_data + ) + write_setup_cfg(wheel_sources_path, cpu) - plugin_dir = sources_path / f"jax_cuda{cuda_version}_plugin" - copy_runfiles( + plugin_dir = wheel_sources_path / f"jax_cuda{cuda_version}_plugin" + copy_files( dst_dir=plugin_dir, src_files=[ - f"__main__/jaxlib/cuda/_solver.{pyext}", - f"__main__/jaxlib/cuda/_blas.{pyext}", - f"__main__/jaxlib/cuda/_linalg.{pyext}", - f"__main__/jaxlib/cuda/_prng.{pyext}", - f"__main__/jaxlib/cuda/_rnn.{pyext}", - f"__main__/jaxlib/cuda/_sparse.{pyext}", - f"__main__/jaxlib/cuda/_triton.{pyext}", - f"__main__/jaxlib/cuda/_hybrid.{pyext}", - f"__main__/jaxlib/cuda/_versions.{pyext}", - f"__main__/jaxlib/cuda/cuda_plugin_extension.{pyext}", - f"__main__/jaxlib/mosaic/gpu/_mosaic_gpu_ext.{pyext}", - "__main__/jaxlib/mosaic/gpu/libmosaic_gpu_runtime.so", - "__main__/jaxlib/version.py", + f"{source_file_prefix}jaxlib/cuda/_solver.{pyext}", + f"{source_file_prefix}jaxlib/cuda/_linalg.{pyext}", + f"{source_file_prefix}jaxlib/cuda/_prng.{pyext}", + f"{source_file_prefix}jaxlib/cuda/_rnn.{pyext}", + f"{source_file_prefix}jaxlib/cuda/_sparse.{pyext}", + f"{source_file_prefix}jaxlib/cuda/_triton.{pyext}", + f"{source_file_prefix}jaxlib/cuda/_hybrid.{pyext}", + f"{source_file_prefix}jaxlib/cuda/_versions.{pyext}", + f"{source_file_prefix}jaxlib/cuda/cuda_plugin_extension.{pyext}", + f"{source_file_prefix}jaxlib/mosaic/gpu/_mosaic_gpu_ext.{pyext}", + f"{source_file_prefix}jaxlib/mosaic/gpu/libmosaic_gpu_runtime.so", + f"{source_file_prefix}jaxlib/version.py", ], ) + def prepare_wheel_rocm( - sources_path: pathlib.Path, *, cpu, rocm_version + wheel_sources_path: pathlib.Path, *, cpu, rocm_version, wheel_sources ): - """Assembles a source tree for the rocm kernel wheel in `sources_path`.""" - copy_runfiles = functools.partial(build_utils.copy_file, runfiles=r) + """Assembles a source tree for the rocm kernel wheel in `wheel_sources_path`.""" + source_file_prefix = build_utils.get_source_file_prefix(wheel_sources) + wheel_sources_map = build_utils.create_wheel_sources_map( + wheel_sources, + root_packages=[ + "jax_plugins", + f"jax_rocm{rocm_version}_plugin", + "jaxlib", + ], + ) + copy_files = functools.partial( + build_utils.copy_file, + runfiles=r, + wheel_sources_map=wheel_sources_map, + ) - copy_runfiles( - "__main__/jax_plugins/rocm/plugin_pyproject.toml", - dst_dir=sources_path, + copy_files( + f"{source_file_prefix}jax_plugins/rocm/plugin_pyproject.toml", + dst_dir=wheel_sources_path, dst_filename="pyproject.toml", ) - copy_runfiles( - "__main__/jax_plugins/rocm/plugin_setup.py", - dst_dir=sources_path, + copy_files( + f"{source_file_prefix}jax_plugins/rocm/plugin_setup.py", + dst_dir=wheel_sources_path, dst_filename="setup.py", ) - build_utils.update_setup_with_rocm_version(sources_path, rocm_version) - write_setup_cfg(sources_path, cpu) + build_utils.update_setup_with_rocm_version(wheel_sources_path, rocm_version) + write_setup_cfg(wheel_sources_path, cpu) - plugin_dir = sources_path / f"jax_rocm{rocm_version}_plugin" - copy_runfiles( + plugin_dir = wheel_sources_path / f"jax_rocm{rocm_version}_plugin" + copy_files( dst_dir=plugin_dir, src_files=[ - f"__main__/jaxlib/rocm/_blas.{pyext}", - f"__main__/jaxlib/rocm/_linalg.{pyext}", - f"__main__/jaxlib/rocm/_prng.{pyext}", - f"__main__/jaxlib/rocm/_solver.{pyext}", - f"__main__/jaxlib/rocm/_sparse.{pyext}", - f"__main__/jaxlib/rocm/_hybrid.{pyext}", - f"__main__/jaxlib/rocm/_rnn.{pyext}", - f"__main__/jaxlib/rocm/_triton.{pyext}", - f"__main__/jaxlib/rocm/rocm_plugin_extension.{pyext}", - "__main__/jaxlib/version.py", + f"{source_file_prefix}jaxlib/rocm/_linalg.{pyext}", + f"{source_file_prefix}jaxlib/rocm/_prng.{pyext}", + f"{source_file_prefix}jaxlib/rocm/_solver.{pyext}", + f"{source_file_prefix}jaxlib/rocm/_sparse.{pyext}", + f"{source_file_prefix}jaxlib/rocm/_hybrid.{pyext}", + f"{source_file_prefix}jaxlib/rocm/_rnn.{pyext}", + f"{source_file_prefix}jaxlib/rocm/_triton.{pyext}", + f"{source_file_prefix}jaxlib/rocm/rocm_plugin_extension.{pyext}", + f"{source_file_prefix}jaxlib/version.py", ], ) + # Build wheel for cuda kernels if args.enable_rocm: tmpdir = tempfile.TemporaryDirectory(prefix="jax_rocm_plugin") @@ -163,12 +205,19 @@ def prepare_wheel_rocm( os.makedirs(args.output_path, exist_ok=True) if args.enable_cuda: prepare_wheel_cuda( - pathlib.Path(sources_path), cpu=args.cpu, cuda_version=args.platform_version + pathlib.Path(sources_path), + cpu=args.cpu, + cuda_version=args.platform_version, + wheel_sources=args.srcs, + nvidia_wheel_versions_data=args.nvidia_wheel_versions_data, ) package_name = f"jax cuda{args.platform_version} plugin" elif args.enable_rocm: prepare_wheel_rocm( - pathlib.Path(sources_path), cpu=args.cpu, rocm_version=args.platform_version + pathlib.Path(sources_path), + cpu=args.cpu, + rocm_version=args.platform_version, + wheel_sources=args.srcs, ) package_name = f"jax rocm{args.platform_version} plugin" if args.editable: diff --git a/jaxlib/tools/build_gpu_plugin_wheel.py b/jaxlib/tools/build_gpu_plugin_wheel.py index 667807b51197..8256a89a91b3 100644 --- a/jaxlib/tools/build_gpu_plugin_wheel.py +++ b/jaxlib/tools/build_gpu_plugin_wheel.py @@ -26,7 +26,7 @@ from bazel_tools.tools.python.runfiles import runfiles from jaxlib.tools import build_utils -parser = argparse.ArgumentParser() +parser = argparse.ArgumentParser(fromfile_prefix_chars="@") parser.add_argument( "--sources_path", default=None, @@ -67,6 +67,15 @@ "--enable-rocm", default=False, help="Should we build with ROCM enabled?") +parser.add_argument( + "--srcs", help="source files for the wheel", action="append" +) +parser.add_argument( + "--nvidia_wheel_versions_data", + default=None, + required=False, + help="NVIDIA wheel versions data", +) args = parser.parse_args() r = runfiles.Create() @@ -81,62 +90,89 @@ def write_setup_cfg(sources_path, cpu): [bdist_wheel] plat_name={tag} -python-tag=py3 +python_tag=py3 """ ) -def prepare_cuda_plugin_wheel(sources_path: pathlib.Path, *, cpu, cuda_version): - """Assembles a source tree for the wheel in `sources_path`.""" - copy_runfiles = functools.partial(build_utils.copy_file, runfiles=r) +def prepare_cuda_plugin_wheel( + wheel_sources_path: pathlib.Path, + *, + cpu, + cuda_version, + wheel_sources, + nvidia_wheel_versions_data, +): + """Assembles a source tree for the wheel in `wheel_sources_path`""" + source_file_prefix = build_utils.get_source_file_prefix(wheel_sources) + wheel_sources_map = build_utils.create_wheel_sources_map( + wheel_sources, root_packages=["jax_plugins", "jaxlib"] + ) + copy_files = functools.partial( + build_utils.copy_file, + runfiles=r, + wheel_sources_map=wheel_sources_map, + ) - plugin_dir = sources_path / "jax_plugins" / f"xla_cuda{cuda_version}" - copy_runfiles( - dst_dir=sources_path, + plugin_dir = wheel_sources_path / "jax_plugins" / f"xla_cuda{cuda_version}" + copy_files( + dst_dir=wheel_sources_path, src_files=[ - "__main__/jax_plugins/cuda/pyproject.toml", - "__main__/jax_plugins/cuda/setup.py", + f"{source_file_prefix}jax_plugins/cuda/pyproject.toml", + f"{source_file_prefix}jax_plugins/cuda/setup.py", ], ) - build_utils.update_setup_with_cuda_version(sources_path, cuda_version) - write_setup_cfg(sources_path, cpu) - copy_runfiles( + build_utils.update_setup_with_cuda_and_nvidia_wheel_versions( + wheel_sources_path, cuda_version, nvidia_wheel_versions_data + ) + write_setup_cfg(wheel_sources_path, cpu) + copy_files( dst_dir=plugin_dir, src_files=[ - "__main__/jax_plugins/cuda/__init__.py", - "__main__/jaxlib/version.py", + f"{source_file_prefix}jax_plugins/cuda/__init__.py", + f"{source_file_prefix}jaxlib/version.py", ], ) - copy_runfiles( - "__main__/jaxlib/tools/pjrt_c_api_gpu_plugin.so", + copy_files( + f"{source_file_prefix}jax_plugins/cuda/pjrt_c_api_gpu_plugin.so", dst_dir=plugin_dir, dst_filename="xla_cuda_plugin.so", ) -def prepare_rocm_plugin_wheel(sources_path: pathlib.Path, *, cpu, rocm_version): - """Assembles a source tree for the ROCm wheel in `sources_path`.""" - copy_runfiles = functools.partial(build_utils.copy_file, runfiles=r) +def prepare_rocm_plugin_wheel( + wheel_sources_path: pathlib.Path, *, cpu, rocm_version, wheel_sources +): + """Assembles a source tree for the ROCm wheel in `wheel_sources_path`.""" + source_file_prefix = build_utils.get_source_file_prefix(wheel_sources) + wheel_sources_map = build_utils.create_wheel_sources_map( + wheel_sources, root_packages=["jax_plugins", "jaxlib"] + ) + copy_files = functools.partial( + build_utils.copy_file, + runfiles=r, + wheel_sources_map=wheel_sources_map, + ) - plugin_dir = sources_path / "jax_plugins" / f"xla_rocm{rocm_version}" - copy_runfiles( - dst_dir=sources_path, - src_files=[ - "__main__/jax_plugins/rocm/pyproject.toml", - "__main__/jax_plugins/rocm/setup.py", + plugin_dir = wheel_sources_path / "jax_plugins" / f"xla_rocm{rocm_version}" + copy_files( + dst_dir=wheel_sources_path, + src_files=[ + f"{source_file_prefix}jax_plugins/rocm/pyproject.toml", + f"{source_file_prefix}jax_plugins/rocm/setup.py", ], ) - build_utils.update_setup_with_rocm_version(sources_path, rocm_version) - write_setup_cfg(sources_path, cpu) - copy_runfiles( + build_utils.update_setup_with_rocm_version(wheel_sources_path, rocm_version) + write_setup_cfg(wheel_sources_path, cpu) + copy_files( dst_dir=plugin_dir, src_files=[ - "__main__/jax_plugins/rocm/__init__.py", - "__main__/jaxlib/version.py", + f"{source_file_prefix}jax_plugins/rocm/__init__.py", + f"{source_file_prefix}jaxlib/version.py", ], ) - copy_runfiles( - "__main__/jaxlib/tools/pjrt_c_api_gpu_plugin.so", + copy_files( + f"{source_file_prefix}jax_plugins/rocm/pjrt_c_api_gpu_plugin.so", dst_dir=plugin_dir, dst_filename="xla_rocm_plugin.so", ) @@ -153,12 +189,19 @@ def prepare_rocm_plugin_wheel(sources_path: pathlib.Path, *, cpu, rocm_version): if args.enable_cuda: prepare_cuda_plugin_wheel( - pathlib.Path(sources_path), cpu=args.cpu, cuda_version=args.platform_version + pathlib.Path(sources_path), + cpu=args.cpu, + cuda_version=args.platform_version, + wheel_sources=args.srcs, + nvidia_wheel_versions_data=args.nvidia_wheel_versions_data, ) package_name = "jax cuda plugin" elif args.enable_rocm: prepare_rocm_plugin_wheel( - pathlib.Path(sources_path), cpu=args.cpu, rocm_version=args.platform_version + pathlib.Path(sources_path), + cpu=args.cpu, + rocm_version=args.platform_version, + wheel_sources=args.srcs, ) package_name = "jax rocm plugin" else: diff --git a/jaxlib/tools/build_mosaic_wheel.py b/jaxlib/tools/build_mosaic_wheel.py new file mode 100644 index 000000000000..b5b1e0831117 --- /dev/null +++ b/jaxlib/tools/build_mosaic_wheel.py @@ -0,0 +1,159 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Script that builds a jax cuda/rocm plugin wheel, intended to be run via bazel run +# as part of the jax cuda/rocm plugin build process. + +# Most users should not run this script directly; use build.py instead. + +import argparse +import functools +import os +import pathlib +import tempfile + +from bazel_tools.tools.python.runfiles import runfiles +from jaxlib.tools import build_utils + +parser = argparse.ArgumentParser(fromfile_prefix_chars="@") +parser.add_argument( + "--sources_path", + default=None, + help="Path in which the wheel's sources should be prepared. Optional. If " + "omitted, a temporary directory will be used.", +) +parser.add_argument( + "--output_path", + required=True, + help="Path to which the output wheel should be written. Required.", +) +parser.add_argument( + "--jaxlib_git_hash", + required=True, + help="Git hash. Required.", +) +parser.add_argument( + "--cpu", required=True, help="Target CPU architecture. Required." +) +parser.add_argument( + "--platform_version", + required=True, + help="Target CUDA version. Required.", +) +parser.add_argument( + "--editable", + action="store_true", + help="Create an 'editable' mosaic build instead of a wheel.", +) +parser.add_argument( + "--srcs", help="source files for the wheel", action="append" +) +parser.add_argument( + "--nvidia_wheel_versions_data", + default=None, + required=True, + help="NVIDIA wheel versions data", +) + +# The jax_wheel target passes in some extra params, which we ignore +args, _ = parser.parse_known_args() + +r = runfiles.Create() + + +def assemble_sources( + wheel_sources_path: pathlib.Path, + *, + cpu, + cuda_version, + wheel_sources, + nvidia_wheel_versions_data, +): + """Assembles a source tree for the wheel in `wheel_sources_path`""" + source_file_prefix = build_utils.get_source_file_prefix(wheel_sources) + wheel_sources_map = build_utils.create_wheel_sources_map( + wheel_sources, root_packages=["jaxlib"] + ) + mgpudir = wheel_sources_path / "mosaic_gpu" + + copy_files = functools.partial( + build_utils.copy_file, + runfiles=r, + wheel_sources_map=wheel_sources_map, + ) + + copy_files( + dst_dir=wheel_sources_path, + src_files=[ + f"{source_file_prefix}jaxlib/tools/LICENSE.txt", + f"{source_file_prefix}jaxlib/mosaic/gpu/wheel/setup.py", + ], + ) + + copy_files( + dst_dir=mgpudir / f"mosaic_gpu_cuda{cuda_version}", + src_files=[ + f"{source_file_prefix}jaxlib/mosaic/gpu/wheel/mosaic_gpu.so", + f"{source_file_prefix}jaxlib/mosaic/gpu/wheel/__init__.py", + f"{source_file_prefix}jaxlib/version.py", + ], + ) + + # This sets the cuda version in setup.py + build_utils.update_setup_with_cuda_and_nvidia_wheel_versions( + wheel_sources_path, cuda_version, nvidia_wheel_versions_data + ) + + tag = build_utils.platform_tag(cpu) + with open(wheel_sources_path / "setup.cfg", "w") as f: + f.write( + f"""[metadata] +license_files = LICENSE.txt + +[bdist_wheel] +plat_name={tag} +python_tag=py3 +""" + ) + + +tmpdir = None +sources_path = args.sources_path +if sources_path is None: + tmpdir = tempfile.TemporaryDirectory(prefix="mosaic_gpu") + sources_path = tmpdir.name + +try: + os.makedirs(args.output_path, exist_ok=True) + package_name = "mosaic_gpu" + + assemble_sources( + pathlib.Path(sources_path), + cpu=args.cpu, + cuda_version=args.platform_version, + wheel_sources=args.srcs, + nvidia_wheel_versions_data=args.nvidia_wheel_versions_data, + ) + if args.editable: + build_utils.build_editable(sources_path, args.output_path, package_name) + else: + build_utils.build_wheel( + sources_path, + args.output_path, + package_name, + git_hash=args.jaxlib_git_hash, + ) +finally: + if tmpdir: + tmpdir.cleanup() diff --git a/jaxlib/tools/build_utils.py b/jaxlib/tools/build_utils.py index 4c50cff16743..a4cc19ee8758 100644 --- a/jaxlib/tools/build_utils.py +++ b/jaxlib/tools/build_utils.py @@ -24,32 +24,85 @@ import subprocess import glob from collections.abc import Sequence + from jaxlib.tools import platform_tags +import third_party.py.setup_py_nvidia_dependencies_util as util + +MAIN_RUNFILES_DIR = "__main__/" def is_windows() -> bool: return sys.platform.startswith("win32") +def create_wheel_sources_map(wheel_sources, root_packages): + """Returns a map of paths relative to the root package to the full paths.""" + wheel_sources_map = {} + if not wheel_sources: + return wheel_sources_map + for source in wheel_sources: + for package in root_packages: + # Dealing with source files from the main repo + if source.startswith("{}/".format(package)): + wheel_sources_map[source] = source + continue + # Dealing with source files from external repos + # e.g. external/xla/xla/ffi/api/c_api.h + # which should map to xla/ffi/api/c_api.h + if source.startswith("external/"): + parts = source.split("/", 2) + if len(parts) == 3: + wheel_sources_map[parts[2]] = source + continue + else: + raise RuntimeError( + "Unexpected external source format: {}".format(source) + ) + # Dealing with generated source + # e.g. bazel-out/k8-opt/bin/jaxlib/mlir/_mlir_libs/_jax_mlir_ext.py + # which should map to jaxlib/mlir/_mlir_libs/_jax_mlir_ext.py + root_package_ind = source.find("/{}/".format(package)) + if root_package_ind >= 0: + wheel_sources_map[source[root_package_ind + 1:]] = source + return wheel_sources_map + + +# TODO(ybaturina): remove the method when we switch to the new wheel build rules +# and the runfiles are not needed. +def get_source_file_prefix(wheel_sources): + return "" if wheel_sources else MAIN_RUNFILES_DIR + + def copy_file( src_files: str | Sequence[str], dst_dir: pathlib.Path, - dst_filename = None, - runfiles = None, + dst_filename=None, + runfiles=None, + wheel_sources_map=None, ) -> None: dst_dir.mkdir(parents=True, exist_ok=True) if isinstance(src_files, str): src_files = [src_files] for src_file in src_files: - src_file_rloc = runfiles.Rlocation(src_file) - if src_file_rloc is None: + if wheel_sources_map: + src_file_loc = wheel_sources_map.get(src_file, None) + # TODO(ybaturina): remove the runfiles part when we switch to the new wheel + # build rules and the runfiles are not needed. + elif runfiles: + src_file_loc = runfiles.Rlocation(src_file) + else: + raise RuntimeError( + "Either runfiles or wheel_sources_map should be provided!" + ) + if src_file_loc is None: raise ValueError(f"Unable to find wheel source file {src_file}") - src_filename = os.path.basename(src_file_rloc) + + src_filename = os.path.basename(src_file_loc) dst_file = os.path.join(dst_dir, dst_filename or src_filename) if is_windows(): - shutil.copyfile(src_file_rloc, dst_file) + shutil.copyfile(src_file_loc, dst_file) else: - shutil.copy(src_file_rloc, dst_file) + shutil.copy(src_file_loc, dst_file) def platform_tag(cpu: str) -> str: @@ -65,6 +118,7 @@ def build_wheel( package_name: str, git_hash: str = "", build_wheel_only: bool = True, + build_source_package_only: bool = False, ) -> None: """Builds a wheel in `output_path` using the source tree in `sources_path`.""" env = dict(os.environ) @@ -78,7 +132,8 @@ def build_wheel( env["USERPROFILE"] = env.get("SYSTEMDRIVE", "C:") subprocess.run( [sys.executable, "-m", "build", "-n"] - + (["-w"] if build_wheel_only else []), + + (["-w"] if build_wheel_only else []) + + (["-s"] if build_source_package_only else []), check=True, cwd=sources_path, env=env, @@ -97,10 +152,10 @@ def build_wheel( sys.stderr.write(" bazel run //build:requirements.update" + f" --repo_env=HERMETIC_PYTHON_VERSION={py_version}\n\n") shutil.copy(wheel, output_path) - if not build_wheel_only: + if build_source_package_only: for dist in glob.glob(os.path.join(sources_path, "dist", "*.tar.gz")): output_file = os.path.join(output_path, os.path.basename(dist)) - sys.stderr.write(f"Output source distribution: {output_file}\n\n") + sys.stderr.write(f"Output source package: {output_file}\n\n") shutil.copy(dist, output_path) @@ -115,13 +170,16 @@ def build_editable( shutil.copytree(sources_path, output_path) -def update_setup_with_cuda_version(file_dir: pathlib.Path, cuda_version: str): +def update_setup_with_cuda_and_nvidia_wheel_versions( + file_dir: pathlib.Path, cuda_version: str, nvidia_wheel_versions_data: str +): src_file = file_dir / "setup.py" with open(src_file) as f: content = f.read() - content = content.replace( - "cuda_version = 0 # placeholder", f"cuda_version = {cuda_version}" + content = util.get_setup_py_content_with_nvidia_wheel_versions( + content, cuda_version, nvidia_wheel_versions_data ) + with open(src_file, "w") as f: f.write(content) diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 8632468acb97..593dfcdaa708 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -29,7 +29,7 @@ from bazel_tools.tools.python.runfiles import runfiles from jaxlib.tools import build_utils -parser = argparse.ArgumentParser() +parser = argparse.ArgumentParser(fromfile_prefix_chars="@") parser.add_argument( "--sources_path", default=None, @@ -56,32 +56,43 @@ action="store_true", help="Create an 'editable' jaxlib build instead of a wheel.", ) +parser.add_argument( + "--srcs", help="source files for the wheel", action="append" +) args = parser.parse_args() r = runfiles.Create() - def _is_mac(): return platform.system() == "Darwin" +soext = "dll" if build_utils.is_windows() else ("dylib" if _is_mac() else "so") pyext = "pyd" if build_utils.is_windows() else "so" -def exists(src_file): - path = r.Rlocation(src_file) - if path is None: - return False - return os.path.exists(path) +def _get_file_path(src_file, runfiles=None, wheel_sources_map=None): + if wheel_sources_map: + return wheel_sources_map.get( + src_file.replace(build_utils.MAIN_RUNFILES_DIR, ""), None + ) + # TODO(ybaturina): remove the runfiles part when we switch to the new wheel + # build rules and the runfiles are not needed. + elif runfiles: + return runfiles.Rlocation(src_file) + else: + raise RuntimeError("Either runfiles or wheel_sources should be provided!") -def patch_copy_mlir_import(src_file, dst_dir): - src_file = r.Rlocation(src_file) +def patch_copy_mlir_import( + src_file, dst_dir, runfiles=None, wheel_sources_map=None +): + src_file = _get_file_path(src_file, runfiles, wheel_sources_map) src_filename = os.path.basename(src_file) - with open(src_file) as f: + with open(src_file, encoding="utf-8") as f: src = f.read() - with open(dst_dir / src_filename, "w") as f: + with open(dst_dir / src_filename, "w", encoding="utf-8") as f: replaced = re.sub( r"^from mlir(\..*)? import (.*)", r"from jaxlib.mlir\1 import \2", @@ -91,40 +102,10 @@ def patch_copy_mlir_import(src_file, dst_dir): f.write(replaced) -_XLA_EXTENSION_STUBS = [ - "__init__.pyi", - "guard_lib.pyi", - "ifrt_programs.pyi", - "ifrt_proxy.pyi", - "jax_jit.pyi", - "ops.pyi", - "pmap_lib.pyi", - "profiler.pyi", - "pytree.pyi", - "transfer_guard_lib.pyi", -] -_OPTIONAL_XLA_EXTENSION_STUBS = [] - - -def patch_copy_xla_extension_stubs(dst_dir): - xla_extension_dir = os.path.join(dst_dir, "xla_extension") - os.makedirs(xla_extension_dir) - for stub_name in _XLA_EXTENSION_STUBS: - stub_path = r.Rlocation("xla/xla/python/xla_extension/" + stub_name) - stub_path = str(stub_path) # Make pytype accept os.path.exists(stub_path). - if stub_name in _OPTIONAL_XLA_EXTENSION_STUBS and not os.path.exists(stub_path): - continue - with open(stub_path) as f: - src = f.read() - src = src.replace( - "from xla.python import xla_extension", "from .. import xla_extension" - ) - with open(os.path.join(xla_extension_dir, stub_name), "w") as f: - f.write(src) - - -def verify_mac_libraries_dont_reference_chkstack(): - """Verifies that xla_extension.so doesn't depend on ____chkstk_darwin. +def verify_mac_libraries_dont_reference_chkstack( + runfiles=None, wheel_sources_map=None +): + """Verifies that _jax.so doesn't depend on ____chkstk_darwin. We don't entirely know why this happens, but in some build environments we seem to target the wrong Mac OS version. @@ -134,8 +115,11 @@ def verify_mac_libraries_dont_reference_chkstack(): """ if not _is_mac(): return + file_path = _get_file_path( + f"__main__/jaxlib/_jax.{pyext}", runfiles, wheel_sources_map + ) nm = subprocess.run( - ["nm", "-g", r.Rlocation("xla/xla/python/xla_extension.so")], + ["nm", "-g", file_path], capture_output=True, text=True, check=False, @@ -162,214 +146,256 @@ def write_setup_cfg(sources_path, cpu): ) -def prepare_wheel(sources_path: pathlib.Path, *, cpu): - """Assembles a source tree for the wheel in `sources_path`.""" - copy_runfiles = functools.partial(build_utils.copy_file, runfiles=r) +def prepare_wheel(wheel_sources_path: pathlib.Path, *, cpu, wheel_sources): + """Assembles a source tree for the wheel in `wheel_sources_path`.""" + source_file_prefix = build_utils.get_source_file_prefix(wheel_sources) + # The wheel sources provided by the transitive rules might have different path + # prefixes, so we need to create a map of paths relative to the root package + # to the full paths. + # E.g. if we have the wheel sources paths like + # bazel-out/k8-opt/bin/jaxlib/mlir/_mlir_libs/_jax_mlir_ext.py and + # external/xla/xla/ffi/api/c_api.h, the resulting map will be + # {'jaxlib/mlir/_mlir_libs/_jax_mlir_ext.py': + # 'bazel-out/k8-opt/bin/jaxlib/mlir/_mlir_libs/_jax_mlir_ext.py', + # 'xla/ffi/api/c_api.h': 'external/xla/xla/ffi/api/c_api.h'} + wheel_sources_map = build_utils.create_wheel_sources_map( + wheel_sources, root_packages=["jaxlib", "xla"] + ) + copy_files = functools.partial( + build_utils.copy_file, + runfiles=r, + wheel_sources_map=wheel_sources_map, + ) - verify_mac_libraries_dont_reference_chkstack() - copy_runfiles( - dst_dir=sources_path, + verify_mac_libraries_dont_reference_chkstack( + runfiles=r, wheel_sources_map=wheel_sources_map + ) + copy_files( + dst_dir=wheel_sources_path, src_files=[ - "__main__/jaxlib/tools/LICENSE.txt", - "__main__/jaxlib/README.md", - "__main__/jaxlib/setup.py", + f"{source_file_prefix}jaxlib/tools/LICENSE.txt", + f"{source_file_prefix}jaxlib/README.md", + f"{source_file_prefix}jaxlib/setup.py", ], ) - write_setup_cfg(sources_path, cpu) + write_setup_cfg(wheel_sources_path, cpu) - jaxlib_dir = sources_path / "jaxlib" - copy_runfiles( - "__main__/jaxlib/init.py", dst_dir=jaxlib_dir, dst_filename="__init__.py" + jaxlib_dir = wheel_sources_path / "jaxlib" + copy_files( + f"{source_file_prefix}jaxlib/init.py", + dst_dir=jaxlib_dir, + dst_filename="__init__.py", ) - copy_runfiles( + copy_files( dst_dir=jaxlib_dir, src_files=[ - f"__main__/jaxlib/cpu_feature_guard.{pyext}", - f"__main__/jaxlib/utils.{pyext}", - "__main__/jaxlib/lapack.py", - "__main__/jaxlib/hlo_helpers.py", - "__main__/jaxlib/gpu_prng.py", - "__main__/jaxlib/gpu_linalg.py", - "__main__/jaxlib/gpu_rnn.py", - "__main__/jaxlib/gpu_triton.py", - "__main__/jaxlib/gpu_common_utils.py", - "__main__/jaxlib/gpu_solver.py", - "__main__/jaxlib/gpu_sparse.py", - "__main__/jaxlib/plugin_support.py", - "__main__/jaxlib/version.py", - "__main__/jaxlib/xla_client.py", - f"xla/xla/python/xla_extension.{pyext}", + f"{source_file_prefix}jaxlib/cpu_feature_guard.{pyext}", + f"{source_file_prefix}jaxlib/cpu_sparse.py", + f"{source_file_prefix}jaxlib/utils.{pyext}", + f"{source_file_prefix}jaxlib/jax_common.dll" + if build_utils.is_windows() + else f"{source_file_prefix}jaxlib/libjax_common.{soext}", + f"{source_file_prefix}jaxlib/lapack.py", + f"{source_file_prefix}jaxlib/gpu_prng.py", + f"{source_file_prefix}jaxlib/gpu_linalg.py", + f"{source_file_prefix}jaxlib/gpu_rnn.py", + f"{source_file_prefix}jaxlib/gpu_triton.py", + f"{source_file_prefix}jaxlib/gpu_common_utils.py", + f"{source_file_prefix}jaxlib/gpu_solver.py", + f"{source_file_prefix}jaxlib/gpu_sparse.py", + f"{source_file_prefix}jaxlib/plugin_support.py", + f"{source_file_prefix}jaxlib/_pretty_printer.{pyext}", + f"{source_file_prefix}jaxlib/version.py", + f"{source_file_prefix}jaxlib/xla_client.py", + f"{source_file_prefix}jaxlib/weakref_lru_cache.{pyext}", + f"{source_file_prefix}jaxlib/weakref_lru_cache.pyi", + f"{source_file_prefix}jaxlib/_ifrt_proxy.{pyext}", + f"{source_file_prefix}jaxlib/_jax.{pyext}", + f"{source_file_prefix}jaxlib/_sdy_mpmd.{pyext}", + f"{source_file_prefix}jaxlib/_pathways.{pyext}", + f"{source_file_prefix}jaxlib/_profiler.{pyext}", + f"{source_file_prefix}jaxlib/_profile_data.{pyext}", ], ) # This file is required by PEP-561. It marks jaxlib as package containing # type stubs. with open(jaxlib_dir / "py.typed", "w"): pass - patch_copy_xla_extension_stubs(jaxlib_dir) - copy_runfiles( + copy_files( dst_dir=jaxlib_dir / "cpu", src_files=[ - f"__main__/jaxlib/cpu/_lapack.{pyext}", + f"{source_file_prefix}jaxlib/cpu/_lapack.{pyext}", + f"{source_file_prefix}jaxlib/cpu/_sparse.{pyext}", ], ) mosaic_python_dir = jaxlib_dir / "mosaic" / "python" - copy_runfiles( + copy_files( dst_dir=mosaic_python_dir, src_files=[ - "__main__/jaxlib/mosaic/python/layout_defs.py", - "__main__/jaxlib/mosaic/python/mosaic_gpu.py", - "__main__/jaxlib/mosaic/python/tpu.py", + f"{source_file_prefix}jaxlib/mosaic/python/layout_defs.py", + f"{source_file_prefix}jaxlib/mosaic/python/mosaic_gpu.py", + f"{source_file_prefix}jaxlib/mosaic/python/tpu.py", ], ) # TODO (sharadmv,skyewm): can we avoid patching this file? patch_copy_mlir_import( - "__main__/jaxlib/mosaic/python/_tpu_gen.py", dst_dir=mosaic_python_dir + f"{source_file_prefix}jaxlib/mosaic/python/_tpu_gen.py", + dst_dir=mosaic_python_dir, + runfiles=r, + wheel_sources_map=wheel_sources_map, ) mosaic_gpu_dir = jaxlib_dir / "mosaic" / "dialect" / "gpu" os.makedirs(mosaic_gpu_dir) patch_copy_mlir_import( - "__main__/jaxlib/mosaic/dialect/gpu/_mosaic_gpu_gen_ops.py", + f"{source_file_prefix}jaxlib/mosaic/dialect/gpu/_mosaic_gpu_gen_ops.py", dst_dir=mosaic_gpu_dir, + runfiles=r, + wheel_sources_map=wheel_sources_map, ) patch_copy_mlir_import( - "__main__/jaxlib/mosaic/dialect/gpu/_mosaic_gpu_gen_enums.py", + f"{source_file_prefix}jaxlib/mosaic/dialect/gpu/_mosaic_gpu_gen_enums.py", dst_dir=mosaic_gpu_dir, + runfiles=r, + wheel_sources_map=wheel_sources_map, ) - copy_runfiles( + copy_files( dst_dir=jaxlib_dir / "mlir", src_files=[ - "__main__/jaxlib/mlir/ir.py", - "__main__/jaxlib/mlir/ir.pyi", - "__main__/jaxlib/mlir/passmanager.py", - "__main__/jaxlib/mlir/passmanager.pyi", + f"{source_file_prefix}jaxlib/mlir/ir.py", + f"{source_file_prefix}jaxlib/mlir/passmanager.py", ], ) - copy_runfiles( + copy_files( dst_dir=jaxlib_dir / "mlir" / "dialects", src_files=[ - "__main__/jaxlib/mlir/dialects/_arith_enum_gen.py", - "__main__/jaxlib/mlir/dialects/_arith_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_builtin_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_chlo_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_func_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_math_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_memref_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_mhlo_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_ods_common.py", - "__main__/jaxlib/mlir/dialects/_scf_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_sdy_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_sparse_tensor_enum_gen.py", - "__main__/jaxlib/mlir/dialects/_sparse_tensor_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_stablehlo_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_vector_enum_gen.py", - "__main__/jaxlib/mlir/dialects/_vector_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_gpu_enum_gen.py", - "__main__/jaxlib/mlir/dialects/_gpu_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_nvgpu_enum_gen.py", - "__main__/jaxlib/mlir/dialects/_nvgpu_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_nvvm_enum_gen.py", - "__main__/jaxlib/mlir/dialects/_nvvm_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_llvm_enum_gen.py", - "__main__/jaxlib/mlir/dialects/_llvm_ops_gen.py", - "__main__/jaxlib/mlir/dialects/arith.py", - "__main__/jaxlib/mlir/dialects/builtin.py", - "__main__/jaxlib/mlir/dialects/chlo.py", - "__main__/jaxlib/mlir/dialects/func.py", - "__main__/jaxlib/mlir/dialects/math.py", - "__main__/jaxlib/mlir/dialects/memref.py", - "__main__/jaxlib/mlir/dialects/mhlo.py", - "__main__/jaxlib/mlir/dialects/scf.py", - "__main__/jaxlib/mlir/dialects/sdy.py", - "__main__/jaxlib/mlir/dialects/sparse_tensor.py", - "__main__/jaxlib/mlir/dialects/stablehlo.py", - "__main__/jaxlib/mlir/dialects/vector.py", - "__main__/jaxlib/mlir/dialects/nvgpu.py", - "__main__/jaxlib/mlir/dialects/nvvm.py", - "__main__/jaxlib/mlir/dialects/llvm.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_arith_enum_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_arith_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_builtin_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_cf_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_chlo_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_func_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_math_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_memref_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_mhlo_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_mpmd_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_ods_common.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_scf_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_sdy_enums_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_sdy_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_sparse_tensor_enum_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_sparse_tensor_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_stablehlo_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_vector_enum_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_vector_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_gpu_enum_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_gpu_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_nvgpu_enum_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_nvgpu_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_nvvm_enum_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_nvvm_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_llvm_enum_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_llvm_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/arith.py", + f"{source_file_prefix}jaxlib/mlir/dialects/builtin.py", + f"{source_file_prefix}jaxlib/mlir/dialects/cf.py", + f"{source_file_prefix}jaxlib/mlir/dialects/chlo.py", + f"{source_file_prefix}jaxlib/mlir/dialects/func.py", + f"{source_file_prefix}jaxlib/mlir/dialects/math.py", + f"{source_file_prefix}jaxlib/mlir/dialects/memref.py", + f"{source_file_prefix}jaxlib/mlir/dialects/mhlo.py", + f"{source_file_prefix}jaxlib/mlir/dialects/mpmd.py", + f"{source_file_prefix}jaxlib/mlir/dialects/scf.py", + f"{source_file_prefix}jaxlib/mlir/dialects/sdy.py", + f"{source_file_prefix}jaxlib/mlir/dialects/sparse_tensor.py", + f"{source_file_prefix}jaxlib/mlir/dialects/stablehlo.py", + f"{source_file_prefix}jaxlib/mlir/dialects/vector.py", + f"{source_file_prefix}jaxlib/mlir/dialects/nvgpu.py", + f"{source_file_prefix}jaxlib/mlir/dialects/nvvm.py", + f"{source_file_prefix}jaxlib/mlir/dialects/llvm.py", ], ) - copy_runfiles( + copy_files( dst_dir=jaxlib_dir / "mlir" / "extras", src_files=[ - "__main__/jaxlib/mlir/extras/meta.py", + f"{source_file_prefix}jaxlib/mlir/extras/meta.py", + f"{source_file_prefix}jaxlib/mlir/extras/types.py", ], ) - copy_runfiles( + copy_files( dst_dir=jaxlib_dir / "mlir" / "dialects" / "gpu", src_files=[ - "__main__/jaxlib/mlir/dialects/gpu/__init__.py", + f"{source_file_prefix}jaxlib/mlir/dialects/gpu/__init__.py", ], ) - copy_runfiles( + copy_files( dst_dir=jaxlib_dir / "mlir" / "dialects" / "gpu" / "passes", src_files=[ - "__main__/jaxlib/mlir/dialects/gpu/passes/__init__.py", + f"{source_file_prefix}jaxlib/mlir/dialects/gpu/passes/__init__.py", ], ) - - if build_utils.is_windows(): - capi_so = "__main__/jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi.dll" - else: - so_ext = "dylib" if _is_mac() else "so" - capi_so = f"__main__/jaxlib/mlir/_mlir_libs/libjaxlib_mlir_capi.{so_ext}" - mlir_libs_dir = jaxlib_dir / "mlir" / "_mlir_libs" - copy_runfiles( + copy_files( dst_dir=mlir_libs_dir, src_files=[ - capi_so, - "__main__/jaxlib/mlir/_mlir_libs/__init__.py", - f"__main__/jaxlib/mlir/_mlir_libs/_mlir.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_chlo.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_mlirHlo.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_mosaic_gpu_ext.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_tpu_ext.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_sdy.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_stablehlo.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/register_jax_dialects.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsGPU.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsLLVM.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsNVGPU.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_mlirGPUPasses.{pyext}", + f"{source_file_prefix}jaxlib/mlir/_mlir_libs/__init__.py", + f"{source_file_prefix}jaxlib/_mlir.{pyext}", + f"{source_file_prefix}jaxlib/_chlo.{pyext}", + f"{source_file_prefix}jaxlib/_mlirHlo.{pyext}", + f"{source_file_prefix}jaxlib/_mlirDialectsSparseTensor.{pyext}", + f"{source_file_prefix}jaxlib/_mlirSparseTensorPasses.{pyext}", + f"{source_file_prefix}jaxlib/_mosaic_gpu_ext.{pyext}", + f"{source_file_prefix}jaxlib/_tpu_ext.{pyext}", + f"{source_file_prefix}jaxlib/_sdy.{pyext}", + f"{source_file_prefix}jaxlib/_sdyMpmd.{pyext}", + f"{source_file_prefix}jaxlib/_stablehlo.{pyext}", + f"{source_file_prefix}jaxlib/_jax_mlir_ext.{pyext}", + f"{source_file_prefix}jaxlib/_mlirDialectsGPU.{pyext}", + f"{source_file_prefix}jaxlib/_mlirDialectsLLVM.{pyext}", + f"{source_file_prefix}jaxlib/_mlirDialectsNVGPU.{pyext}", + f"{source_file_prefix}jaxlib/_mlirGPUPasses.{pyext}", ] + ( [] if build_utils.is_windows() else [ - f"__main__/jaxlib/mlir/_mlir_libs/_triton_ext.{pyext}", - "__main__/jaxlib/mlir/_mlir_libs/_triton_ext.pyi", + f"{source_file_prefix}jaxlib/_triton_ext.{pyext}", + f"{source_file_prefix}jaxlib/mlir/_mlir_libs/_triton_ext.pyi", ] ), ) triton_dir = jaxlib_dir / "triton" - copy_runfiles( + copy_files( dst_dir=triton_dir, src_files=[ - "__main__/jaxlib/triton/__init__.py", - "__main__/jaxlib/triton/dialect.py", + f"{source_file_prefix}jaxlib/triton/__init__.py", + f"{source_file_prefix}jaxlib/triton/dialect.py", ], ) patch_copy_mlir_import( - "__main__/jaxlib/triton/_triton_enum_gen.py", dst_dir=triton_dir + f"{source_file_prefix}jaxlib/triton/_triton_enum_gen.py", + dst_dir=triton_dir, + runfiles=r, + wheel_sources_map=wheel_sources_map, ) patch_copy_mlir_import( - "__main__/jaxlib/triton/_triton_ops_gen.py", dst_dir=triton_dir + f"{source_file_prefix}jaxlib/triton/_triton_ops_gen.py", + dst_dir=triton_dir, + runfiles=r, + wheel_sources_map=wheel_sources_map, ) - copy_runfiles( - dst_dir=jaxlib_dir / "include" / "xla" / "ffi" / "api", - src_files=[ - "xla/xla/ffi/api/c_api.h", - "xla/xla/ffi/api/api.h", - "xla/xla/ffi/api/ffi.h", - ], + copy_files( + dst_dir=jaxlib_dir / "include" / "xla" / "ffi" / "api", + src_files=[ + "xla/ffi/api/c_api.h", + "xla/ffi/api/api.h", + "xla/ffi/api/ffi.h", + ], ) tmpdir = None @@ -383,6 +409,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu): prepare_wheel( pathlib.Path(sources_path), cpu=args.cpu, + wheel_sources=args.srcs, ) package_name = "jaxlib" if args.editable: diff --git a/jaxlib/tools/wheel_size_test.py b/jaxlib/tools/wheel_size_test.py new file mode 100644 index 000000000000..26b1ee2f89ed --- /dev/null +++ b/jaxlib/tools/wheel_size_test.py @@ -0,0 +1,56 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import os + + +def parse_args(): + """Arguments parser.""" + parser = argparse.ArgumentParser( + description="Helper for the wheel size verification", + fromfile_prefix_chars="@", + ) + parser.add_argument( + "--wheel-path", required=True, help="Path of the wheel, mandatory" + ) + parser.add_argument( + "--max-size-mib", + required=True, + help="Maximum size of the wheel in MiB", + ) + return parser.parse_args() + + +def verify_wheel_size(args): + wheel_size_mib = os.path.getsize(args.wheel_path) >> 20 + wheel_name = os.path.basename(args.wheel_path) + if wheel_size_mib > int(args.max_size_mib): + raise RuntimeError( + "The {name} size is {size} MiB, which is larger than the maximum size" + " {max_size} MiB".format( + name=wheel_name, + size=wheel_size_mib, + max_size=args.max_size_mib, + ) + ) + else: + logging.info( + "The %s size is %s MiB, which is less than the maximum size" + " %s MB", wheel_name, wheel_size_mib, args.max_size_mib) + + +if __name__ == "__main__": + verify_wheel_size(parse_args()) diff --git a/jaxlib/traceback.cc b/jaxlib/traceback.cc new file mode 100644 index 000000000000..8ab90a3b4cc3 --- /dev/null +++ b/jaxlib/traceback.cc @@ -0,0 +1,413 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/traceback.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/casts.h" +#include "absl/hash/hash.h" +#include "absl/log/check.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "xla/pjrt/exceptions.h" +#include "xla/python/nb_helpers.h" +#include "tsl/platform/platform.h" + +#ifdef PLATFORM_GOOGLE +#define Py_BUILD_CORE +#include "internal/pycore_frame.h" +#undef Py_BUILD_CORE +#endif // PLATFORM_GOOGLE + +namespace nb = nanobind; + +namespace jax { + +namespace { + +std::atomic traceback_enabled_ = true; + +static constexpr int kMaxFrames = 512; + +PyTypeObject* traceback_type_ = nullptr; + +static_assert(std::is_trivial_v == true); + +struct TracebackObject { + PyObject_VAR_HEAD; + TracebackEntry frames[]; +}; + +template +H AbslHashValue(H h, const TracebackObject& tb) { + h = H::combine_contiguous(std::move(h), &tb.frames[0], Py_SIZE(&tb)); + return h; +} + +static_assert(sizeof(TracebackObject) % alignof(PyObject) == 0); +static_assert(sizeof(TracebackEntry) % alignof(void*) == 0); + +bool traceback_check(nb::handle o) { + return Py_TYPE(o.ptr()) == traceback_type_; +} + +Py_hash_t traceback_tp_hash(PyObject* o) { + TracebackObject* tb = reinterpret_cast(o); + size_t h = absl::HashOf(*tb); + Py_hash_t s = absl::bit_cast(h); // Python hashes are signed. + return s == -1 ? -2 : s; // -1 must not be used as a Python hash value. +} + +PyObject* traceback_tp_richcompare(PyObject* self, PyObject* other, int op) { + if (op != Py_EQ && op != Py_NE) { + return Py_NewRef(Py_NotImplemented); + } + + if (!traceback_check(other)) { + return Py_NewRef(Py_False); + } + TracebackObject* tb_self = reinterpret_cast(self); + TracebackObject* tb_other = reinterpret_cast(other); + if (Py_SIZE(tb_self) != Py_SIZE(tb_other)) { + return Py_NewRef(op == Py_EQ ? Py_False : Py_True); + } + for (Py_ssize_t i = 0; i < Py_SIZE(tb_self); ++i) { + if ((tb_self->frames[i] != tb_other->frames[i])) { + return Py_NewRef(op == Py_EQ ? Py_False : Py_True); + } + } + return Py_NewRef(op == Py_EQ ? Py_True : Py_False); +} + +static void traceback_tp_dealloc(PyObject* self) { + TracebackObject* tb = reinterpret_cast(self); + for (Py_ssize_t i = 0; i < Py_SIZE(tb); ++i) { + Py_XDECREF(tb->frames[i].code); + } + PyTypeObject* tp = Py_TYPE(self); + tp->tp_free((PyObject*)self); + Py_DECREF(tp); +} + +Traceback::Frame DecodeFrame(const TracebackEntry& frame) { + return Traceback::Frame{ + .file_name = nb::borrow(frame.code->co_filename), + .function_name = nb::borrow(frame.code->co_qualname), + .function_start_line = frame.code->co_firstlineno, + .line_num = PyCode_Addr2Line(frame.code, frame.lasti), + }; +} + +std::string traceback_to_string(const TracebackObject* tb) { + std::vector frame_strs; + frame_strs.reserve(Py_SIZE(tb)); + for (Py_ssize_t i = 0; i < Py_SIZE(tb); ++i) { + frame_strs.push_back(DecodeFrame(tb->frames[i]).ToString()); + } + return absl::StrJoin(frame_strs, "\n"); +} + +PyObject* traceback_tp_str(PyObject* self) { + TracebackObject* tb = reinterpret_cast(self); + return nb::cast(traceback_to_string(tb)).release().ptr(); +} + +// It turns out to be slightly faster to define a tp_hash slot rather than +// defining __hash__ and __eq__ on the class. +PyType_Slot traceback_slots_[] = { + {Py_tp_hash, reinterpret_cast(traceback_tp_hash)}, + {Py_tp_richcompare, reinterpret_cast(traceback_tp_richcompare)}, + {Py_tp_dealloc, reinterpret_cast(traceback_tp_dealloc)}, + {Py_tp_str, reinterpret_cast(traceback_tp_str)}, + {0, nullptr}, +}; + +nb::object AsPythonTraceback(const Traceback& tb) { + nb::object traceback = nb::none(); + nb::dict globals; + nb::handle traceback_type(reinterpret_cast(&PyTraceBack_Type)); + TracebackObject* tb_obj = reinterpret_cast(tb.ptr()); + for (Py_ssize_t i = 0; i < Py_SIZE(tb_obj); ++i) { + const TracebackEntry& frame = tb_obj->frames[i]; + int lineno = PyCode_Addr2Line(frame.code, frame.lasti); + // Under Python 3.11 we observed crashes when using a fake PyFrameObject + // with a real PyCodeObject (https://github.com/google/jax/issues/16027). + // because the frame does not have fields necessary to compute the locals, + // notably the closure object, leading to crashes in CPython in + // _PyFrame_FastToLocalsWithError + // https://github.com/python/cpython/blob/deaf509e8fc6e0363bd6f26d52ad42f976ec42f2/Objects/frameobject.c#LL1116C2-L1116C2 + // We therefore always build a fake code object to go along with our fake + // frame. + PyCodeObject* py_code = + PyCode_NewEmpty(PyUnicode_AsUTF8(frame.code->co_filename), + PyUnicode_AsUTF8(frame.code->co_name), lineno); + PyFrameObject* py_frame = PyFrame_New(PyThreadState_Get(), py_code, + globals.ptr(), /*locals=*/nullptr); + Py_DECREF(py_code); + + traceback = traceback_type( + /*tb_next=*/std::move(traceback), + /*tb_frame=*/ + nb::steal(reinterpret_cast(py_frame)), + /*tb_lasti=*/0, + /*tb_lineno=*/lineno); + } + return traceback; +} + +} // namespace + +std::vector Traceback::Frames() const { + // We require the GIL because we manipulate Python strings. + CHECK(PyGILState_Check()); + std::vector frames; + TracebackObject* tb = reinterpret_cast(ptr()); + frames.reserve(Py_SIZE(tb)); + for (Py_ssize_t i = 0; i < Py_SIZE(tb); ++i) { + const TracebackEntry& frame = tb->frames[i]; + frames.push_back(DecodeFrame(frame)); + } + return frames; +} + +std::string Traceback::Frame::ToString() const { + return absl::StrFormat("%s:%d (%s)", nb::cast(file_name), + line_num, nb::cast(function_name)); +} + +std::string Traceback::ToString() const { + return traceback_to_string(reinterpret_cast(ptr())); +} + +absl::Span Traceback::RawFrames() const { + const TracebackObject* tb = reinterpret_cast(ptr()); + return absl::MakeConstSpan(tb->frames, Py_SIZE(tb)); +} + +/*static*/ bool Traceback::Check(PyObject* o) { return traceback_check(o); } + +/*static*/ std::optional Traceback::Get() { + // We use a thread_local here mostly to avoid requiring a large amount of + // space. + thread_local std::array frames; + int count = 0; + + DCHECK(PyGILState_Check()); + + if (!traceback_enabled_.load()) { + return std::nullopt; + } + + PyThreadState* thread_state = PyThreadState_GET(); + +#if defined(PLATFORM_GOOGLE) && PY_VERSION_HEX < 0x030e0000 +// This code is equivalent to the version using public APIs, but it saves us +// an allocation of one object per stack frame. However, this is definitely +// violating the API contract of CPython, so we only use this where we can be +// confident we know exactly which CPython we are using (internal to Google). +// Feel free to turn this on if you like, but it might break at any time! +#if PY_VERSION_HEX < 0x030d0000 + for (_PyInterpreterFrame* f = thread_state->cframe->current_frame; + f != nullptr && count < kMaxFrames; f = f->previous) { + if (_PyFrame_IsIncomplete(f)) continue; + Py_INCREF(f->f_code); + frames[count] = {f->f_code, static_cast(_PyInterpreterFrame_LASTI(f) * + sizeof(_Py_CODEUNIT))}; + ++count; + } +#else // PY_VERSION_HEX < 0x030d0000 + for (_PyInterpreterFrame* f = thread_state->current_frame; + f != nullptr && count < kMaxFrames; f = f->previous) { + if (_PyFrame_IsIncomplete(f)) continue; + Py_INCREF(f->f_executable); + frames[count] = { + reinterpret_cast(f->f_executable), + static_cast(_PyInterpreterFrame_LASTI(f) * sizeof(_Py_CODEUNIT))}; + ++count; + } +#endif // PY_VERSION_HEX < 0x030d0000 + +#else // PLATFORM_GOOGLE + PyFrameObject* py_frame = PyThreadState_GetFrame(thread_state); + while (py_frame != nullptr && count < kMaxFrames) { + frames[count] = {PyFrame_GetCode(py_frame), PyFrame_GetLasti(py_frame)}; + ++count; + PyFrameObject* next = PyFrame_GetBack(py_frame); + Py_DECREF(py_frame); + py_frame = next; + } + Py_XDECREF(py_frame); +#endif // PLATFORM_GOOGLE + + Traceback traceback = + nb::steal(PyObject_NewVar(PyObject, traceback_type_, count)); + TracebackObject* tb = reinterpret_cast(traceback.ptr()); + std::memcpy(tb->frames, frames.data(), sizeof(TracebackEntry) * count); + return traceback; +} + +bool Traceback::IsEnabled() { return traceback_enabled_.load(); } + +void Traceback::Register(nb::module_& m) { + nb::class_(m, "Frame") + .def(nb::init()) + .def_ro("file_name", &Traceback::Frame::file_name) + .def_ro("function_name", &Traceback::Frame::function_name) + .def_ro("function_start_line", &Traceback::Frame::function_start_line) + .def_ro("line_num", &Traceback::Frame::line_num) + .def("__repr__", [](const Traceback::Frame& frame) { + return absl::StrFormat( + "%s;%s:%d", nb::cast(frame.function_name), + nb::cast(frame.file_name), frame.line_num); + }); + + std::string name = + absl::StrCat(nb::cast(m.attr("__name__")), ".Traceback"); + + PyType_Spec traceback_spec = { + /*.name=*/name.c_str(), + /*.basicsize=*/static_cast(sizeof(TracebackObject)), + /*.itemsize=*/static_cast(sizeof(TracebackEntry)), + /*.flags=*/Py_TPFLAGS_DEFAULT, + /*.slots=*/traceback_slots_, + }; + + traceback_type_ = + reinterpret_cast(PyType_FromSpec(&traceback_spec)); + if (!traceback_type_) { + throw nb::python_error(); + } + + auto type = nb::borrow(traceback_type_); + m.attr("Traceback") = type; + + m.def("tracebacks_enabled", []() { return Traceback::IsEnabled(); }); + m.def("set_tracebacks_enabled", + [](bool value) { traceback_enabled_.store(value); }); + + type.attr("get_traceback") = nb::cpp_function(Traceback::Get, + R"doc( + Returns a :class:`Traceback` for the current thread. + + If ``Traceback.enabled`` is ``True``, returns a :class:`Traceback` + object that describes the Python stack of the calling thread. Stack + trace collection has a small overhead, so it is disabled by default. If + traceback collection is disabled, returns ``None``. )doc"); + type.attr("frames") = xla::nb_property_readonly(&Traceback::Frames); + type.attr("raw_frames") = nb::cpp_function( + [](const Traceback& tb) -> nb::tuple { + // We return a tuple of lists, rather than a list of tuples, because it + // is cheaper to allocate only three Python objects for everything + // rather than one per frame. + absl::Span frames = tb.RawFrames(); + nb::list out_code = nb::steal(PyList_New(frames.size())); + nb::list out_lasti = nb::steal(PyList_New(frames.size())); + for (size_t i = 0; i < frames.size(); ++i) { + const auto& frame = frames[i]; + PyObject* code = reinterpret_cast(frame.code); + Py_INCREF(code); + PyList_SET_ITEM(out_code.ptr(), i, code); + PyList_SET_ITEM(out_lasti.ptr(), i, + nb::int_(frame.lasti).release().ptr()); + } + return nb::make_tuple(out_code, out_lasti); + }, + nb::is_method(), + nb::sig( + "def raw_frames(self) -> tuple[list[types.CodeType], list[int]]")); + type.attr("as_python_traceback") = nb::cpp_function( + AsPythonTraceback, nb::is_method(), + nb::sig("def as_python_traceback(self) -> traceback.TracebackType")); + + type.attr("traceback_from_frames") = nb::cpp_function( + [](std::vector frames) { + nb::object traceback = nb::none(); + nb::dict globals; + nb::handle traceback_type( + reinterpret_cast(&PyTraceBack_Type)); + for (const Traceback::Frame& frame : frames) { + PyCodeObject* py_code = + PyCode_NewEmpty(frame.file_name.c_str(), + frame.function_name.c_str(), frame.line_num); + PyFrameObject* py_frame = PyFrame_New(PyThreadState_Get(), py_code, + globals.ptr(), /*locals=*/ + nullptr); + Py_DECREF(py_code); + traceback = traceback_type( + /*tb_next=*/std::move(traceback), + /*tb_frame=*/ + nb::steal(reinterpret_cast(py_frame)), + /*tb_lasti=*/0, + /*tb_lineno=*/ + frame.line_num); + } + return traceback; + }, + "Creates a traceback from a list of frames.", + nb::sig( + // clang-format off + "def traceback_from_frames(frames: list[Frame]) -> traceback.TracebackType" + // clang-format on + )); + + type.attr("code_addr2line") = nb::cpp_function( + [](nb::handle code, int lasti) { + if (!PyCode_Check(code.ptr())) { + throw xla::XlaRuntimeError("code argument must be a code object"); + } + return PyCode_Addr2Line(reinterpret_cast(code.ptr()), + lasti); + }, + "Python wrapper around the Python C API function PyCode_Addr2Line", + nb::sig("def code_addr2line(code: types.CodeType, lasti: int) -> int")); + + type.attr("code_addr2location") = nb::cpp_function( + [](nb::handle code, int lasti) { + if (!PyCode_Check(code.ptr())) { + throw xla::XlaRuntimeError("code argument must be a code object"); + } + int start_line, start_column, end_line, end_column; + if (!PyCode_Addr2Location(reinterpret_cast(code.ptr()), + lasti, &start_line, &start_column, &end_line, + &end_column)) { + throw nb::python_error(); + } + return nb::make_tuple(start_line, start_column, end_line, end_column); + }, + "Python wrapper around the Python C API function PyCode_Addr2Location", + nb::sig("def code_addr2location(code: types.CodeType, lasti: int) -> " + "tuple[int, int, int, int]")); +} + +} // namespace jax diff --git a/jaxlib/traceback.h b/jaxlib/traceback.h new file mode 100644 index 000000000000..97a43a400a34 --- /dev/null +++ b/jaxlib/traceback.h @@ -0,0 +1,90 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_TRACEBACK_H_ +#define JAXLIB_TRACEBACK_H_ + +#include + +#include +#include +#include +#include + +// placeholder for index annotation headers +#include "absl/types/span.h" +#include "nanobind/nanobind.h" + +namespace jax { + +// Entry in a traceback. Must be POD. +struct TracebackEntry { + TracebackEntry() = default; + TracebackEntry(PyCodeObject* code, int lasti) : code(code), lasti(lasti) {} + PyCodeObject* code; + int lasti; + + bool operator==(const TracebackEntry& other) const { + return code == other.code && lasti == other.lasti; + } + bool operator!=(const TracebackEntry& other) const { + return !operator==(other); + } +}; +static_assert(std::is_trivial_v == true); + +template +H AbslHashValue(H h, const TracebackEntry& entry) { + h = H::combine(std::move(h), entry.code, entry.lasti); + return h; +} + +class Traceback : public nanobind::object { + public: + NB_OBJECT(Traceback, nanobind::object, "Traceback", Traceback::Check); + + // Returns a traceback if it is enabled, otherwise returns nullopt. + static std::optional Get(); + + // Returns true if traceback collection is enabled. + static bool IsEnabled(); + + // Returns a string representation of the traceback. + std::string ToString() const; + + // Returns a list of (code, lasti) pairs for each frame in the traceback. + // Frames are from innermost to outermost. + absl::Span RawFrames() const; + + struct Frame { + nanobind::str file_name; + nanobind::str function_name; + int function_start_line; + int line_num; + + std::string ToString() const; + }; + // Returns a list of Frames for the traceback. + std::vector Frames() const; + + static void Register(nanobind::module_& m); + + private: + static bool Check(PyObject* o); +}; + +} // namespace jax + +#endif // JAXLIB_TRACEBACK_H_ diff --git a/jaxlib/triton/BUILD b/jaxlib/triton/BUILD index 99cddd9e6381..aea1f03a2555 100644 --- a/jaxlib/triton/BUILD +++ b/jaxlib/triton/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("@llvm-project//mlir:tblgen.bzl", "gentbl_filegroup") +load("@rules_cc//cc:cc_library.bzl", "cc_library") load("//jaxlib:jax.bzl", "if_windows", "pytype_strict_library") licenses(["notice"]) @@ -35,7 +36,9 @@ pytype_strict_library( "//jaxlib/mlir:ir", ] + if_windows( [], - ["//jaxlib/mlir/_mlir_libs:_triton_ext"], + [ + "//jaxlib/mlir/_mlir_libs:_triton_ext", + ], ), ) diff --git a/jaxlib/triton/triton_dialect_capi.cc b/jaxlib/triton/triton_dialect_capi.cc index 6a46d2914f57..8359a802d922 100644 --- a/jaxlib/triton/triton_dialect_capi.cc +++ b/jaxlib/triton/triton_dialect_capi.cc @@ -15,12 +15,16 @@ limitations under the License. #include "jaxlib/triton/triton_dialect_capi.h" -#include "llvm/include/llvm/Support/Casting.h" -#include "mlir/include/mlir-c/IR.h" -#include "mlir/include/mlir/CAPI/IR.h" -#include "mlir/include/mlir/CAPI/Registration.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/Dialect.h" +#include + +#include "llvm/Support/Casting.h" +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" +#include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/CAPI/Support.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" @@ -56,8 +60,12 @@ MlirAttribute mlirTritonInferReduceOpEncoding(MlirAttribute operandEncoding, llvm::dyn_cast(&dialect); mlir::Attribute retEncoding; (void)inferLayoutInterface->inferReduceOpEncoding(opEncoding, axis, - retEncoding); + retEncoding, std::nullopt); return wrap(retEncoding); } +MlirTypeID mlirTritonPointerTypeGetTypeID(void) { + return wrap(mlir::triton::PointerType::getTypeID()); +} + } // extern "C" diff --git a/jaxlib/triton/triton_dialect_capi.h b/jaxlib/triton/triton_dialect_capi.h index 8c27b5b82500..556197988d71 100644 --- a/jaxlib/triton/triton_dialect_capi.h +++ b/jaxlib/triton/triton_dialect_capi.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef JAXLIB_TRITON_TRITON_DIALECT_CAPI_H_ #define JAXLIB_TRITON_TRITON_DIALECT_CAPI_H_ -#include "mlir/include/mlir-c/IR.h" -#include "mlir/include/mlir-c/Support.h" +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" #ifdef __cplusplus extern "C" { @@ -25,6 +25,7 @@ extern "C" { MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Triton, triton); +MLIR_CAPI_EXPORTED MlirTypeID mlirTritonPointerTypeGetTypeID(void); MLIR_CAPI_EXPORTED MlirType mlirTritonPointerTypeGet(MlirType pointeeType, int addressSpace); MLIR_CAPI_EXPORTED bool mlirTritonIsAPointer(MlirType type); diff --git a/jaxlib/util.cc b/jaxlib/util.cc new file mode 100644 index 000000000000..a17e5b640264 --- /dev/null +++ b/jaxlib/util.cc @@ -0,0 +1,85 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/util.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/synchronization/notification.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "xla/future.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/value.h" +#include "xla/python/version.h" +#include "xla/tsl/concurrency/async_value.h" +#include "xla/tsl/concurrency/future.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/util.h" + +namespace ifrt = xla::ifrt; + +namespace jax { + +void BlockUntilReadyWithCancel(xla::Future<>& future) { + future.BlockUntilReady([](tsl::AsyncValue* value) { + auto state = std::make_shared(); + value->AndThen([state]() { state->Notify(); }); + while (true) { + if (state->WaitForNotificationWithTimeout(absl::Milliseconds(200))) { + break; + } + nanobind::gil_scoped_acquire gil_acquire; + if (PyErr_CheckSignals() != 0) { + throw nanobind::python_error(); + } + } + }); +} + +absl::Status AwaitBuffersReady(absl::Span ifrt_arrays) { + if (ifrt_arrays.empty()) { + return absl::OkStatus(); + } + + tsl::Future<> future; + if (ifrt_arrays.size() == 1) { + future = ifrt_arrays[0]->GetReadyFuture(); + } else { + std::vector values; + values.reserve(ifrt_arrays.size()); + for (ifrt::Array* const ifrt_array : ifrt_arrays) { + values.push_back(tsl::FormRef(ifrt_array)); + } + ifrt::Client* const client = ifrt_arrays.front()->client(); + future = client->GetReadyFuture(values); + } + BlockUntilReadyWithCancel(future); + absl::Status s = future.Await(); + if (!s.ok()) { + // Fix up error string because some clients rely on it. + if (s.message() == "GetReadyFuture() called on deleted or donated buffer") { + s = xla::InvalidArgument( + "BlockHostUntilReady() called on deleted or donated buffer"); + } + } + return s; +} + +} // namespace jax diff --git a/jaxlib/util.h b/jaxlib/util.h new file mode 100644 index 000000000000..dde600868fdf --- /dev/null +++ b/jaxlib/util.h @@ -0,0 +1,35 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_UTIL_H_ +#define JAXLIB_UTIL_H_ + +#include "absl/status/status.h" +#include "absl/types/span.h" +#include "xla/future.h" +#include "xla/python/ifrt/array.h" + +namespace jax { + +// Waits until future is ready but will cancel if ctrl-c is pressed. +void BlockUntilReadyWithCancel(xla::Future<>& future); + +// Requests if given buffers are ready, awaits for results and returns OK if +// all of the buffers are ready or the last non-ok status. +absl::Status AwaitBuffersReady(absl::Span ifrt_arrays); + +} // namespace jax + +#endif // JAXLIB_UTIL_H_ diff --git a/jaxlib/utils.cc b/jaxlib/utils.cc index bf50b3a5254d..3aaac01bc023 100644 --- a/jaxlib/utils.cc +++ b/jaxlib/utils.cc @@ -19,12 +19,17 @@ limitations under the License. #include #include -#include "nanobind/nanobind.h" +#include "absl/base/log_severity.h" #include "absl/cleanup/cleanup.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" +#include "absl/debugging/failure_signal_handler.h" +#include "absl/log/globals.h" #include "absl/synchronization/mutex.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "tsl/platform/platform.h" namespace nb = nanobind; @@ -360,6 +365,22 @@ nb::list TopologicalSort(nb::str parents_attr, return sorted_nodes; } +void InstallFailureSignalHandler(bool call_previous_handler) { +#ifndef PLATFORM_GOOGLE + absl::FailureSignalHandlerOptions options; + options.call_previous_handler = call_previous_handler; + absl::InstallFailureSignalHandler(options); +#endif // PLATFORM_GOOGLE +} + +void SetMinLogLevel(int severity) { + absl::SetMinLogLevel(static_cast(severity)); +} + +void SetStderrThreshold(int severity) { + absl::SetStderrThreshold(static_cast(severity)); +} + } // namespace NB_MODULE(utils, m) { @@ -378,6 +399,12 @@ NB_MODULE(utils, m) { "parent objects. end_nodes is an iterable of objects from which we " "should start a backwards search."); + // Abseil C++ logging functions. + m.def("absl_set_min_log_level", &SetMinLogLevel); + m.def("absl_set_vlog_level", &absl::SetVLogLevel); + m.def("absl_set_global_vlog_level", &absl::SetGlobalVLogLevel); + m.def("absl_set_stderr_threshold", &SetStderrThreshold); + // Python has no reader-writer lock in its standard library, so we expose // bindings around absl::Mutex. nb::class_(m, "Mutex") @@ -392,4 +419,7 @@ NB_MODULE(utils, m) { .def("writer_lock", &absl::Mutex::WriterLock, nb::call_guard()) .def("writer_unlock", &absl::Mutex::WriterUnlock); + + m.def("install_failure_signal_handler", &InstallFailureSignalHandler, + nb::arg("call_previous_handler") = true); } \ No newline at end of file diff --git a/jaxlib/weakref_lru_cache.cc b/jaxlib/weakref_lru_cache.cc new file mode 100644 index 000000000000..6773aa97980f --- /dev/null +++ b/jaxlib/weakref_lru_cache.cc @@ -0,0 +1,494 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include // NOLINT +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/cleanup/cleanup.h" +#include "absl/hash/hash.h" +#include "absl/strings/str_cat.h" +#include "absl/synchronization/mutex.h" +#include "absl/synchronization/notification.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "xla/pjrt/lru_cache.h" +#include "xla/tsl/platform/logging.h" + +namespace nb = nanobind; + +namespace jax { +namespace { + +// Minimal wrapper to expose a nb::dict_iterator's value as something +// hashable with Abseil. +class HashablePyDictEntry { + public: + explicit HashablePyDictEntry(std::pair entry) + : entry_(entry) {} + + template + friend H AbslHashValue(H h, const HashablePyDictEntry& v) { + return H::combine(std::move(h), nb::hash(v.entry_.first), + nb::hash(v.entry_.second)); + } + + std::pair entry_; +}; + +// Similarly, a minimalist adaptor around the nb::detail::dict_iterator +// itself. Note that the iterator "is" also a Value. Does not meet the full +// standard iterator requirements, only enough to support H::combine_unordered. +class HashablePyDictIter { + public: + using iterator_category = std::input_iterator_tag; + + explicit HashablePyDictIter(nb::detail::dict_iterator& iter) : iter_(iter) {} + + // Minimal set of iterator operations. + HashablePyDictEntry operator*() const { return HashablePyDictEntry(*iter_); } + bool operator!=(const HashablePyDictIter& rhs) const { + return iter_ != rhs.iter_; + } + void operator++() { ++iter_; } + + private: + nb::detail::dict_iterator& iter_; +}; + +struct HashableKey { + nb::object context; + nb::args args; + nb::kwargs kwargs; + + template + friend H AbslHashValue(H h, const HashableKey& key) { + // Note: Despite the fact this is an ABSL hash function, it's safe to call + // functions that may throw exceptions such as nb::hash(), because it is + // used by an LRUCache, which uses a std::unordered_map, which is + // exception-safe. + h = H::combine(std::move(h), nb::hash(key.context), nb::hash(key.args)); + nb::detail::dict_iterator begin = key.kwargs.begin(); + nb::detail::dict_iterator end = key.kwargs.end(); + h = H::combine_unordered(std::move(h), HashablePyDictIter(begin), + HashablePyDictIter(end)); + h = H::combine(std::move(h), key.kwargs.size()); + return h; + } +}; + +} // namespace + +class WeakrefLRUCache : public std::enable_shared_from_this { + public: + WeakrefLRUCache(nb::callable cache_context_fn, nb::callable fn, + int64_t maxsize, std::optional explain) + : cache_context_fn_(cache_context_fn), + fn_(fn), + lru_list_(std::make_shared(maxsize)), + explain_(explain) {} + + nb::object Call(nb::object weakref_key, nb::args args, nb::kwargs kwargs); + + void EvictWeakref(nb::object weakref_key); + + std::vector GetKeys(); + std::vector GetKeysLocked(); + + struct CacheInfo { + int64_t hits; + int64_t misses; + int64_t maxsize; + int64_t currsize; + }; + CacheInfo GetCacheInfo() const; + + void Clear(); + + static PyType_Slot slots_[]; + + private: + class Key { + public: + Key(nb::object context, nb::args args, nb::kwargs kwargs) + : context_(std::move(context)), + args_(std::move(args)), + kwargs_(std::move(kwargs)), + cached_hash_(absl::HashOf(HashableKey{context_, args_, kwargs_})) {} + + bool operator==(const Key& other) const { + return context_.equal(other.context_) && args_.equal(other.args_) && + kwargs_.equal(other.kwargs_); + } + + template + friend H AbslHashValue(H h, const Key& key) { + return H::combine(std::move(h), key.cached_hash_); + } + + nb::object context() const { return context_; } + nb::args args() const { return args_; } + nb::kwargs kwargs() const { return kwargs_; } + + int tp_traverse(visitproc visit, void* arg) const { + Py_VISIT(context_.ptr()); + Py_VISIT(args_.ptr()); + Py_VISIT(kwargs_.ptr()); + return 0; + } + + private: + nb::object context_; + nb::args args_; + nb::kwargs kwargs_; + size_t cached_hash_; + }; + + struct CacheEntry { + bool has_result = false; + nb::object result; + absl::Notification completed; + std::thread::id thread_id = std::this_thread::get_id(); + + int tp_traverse(visitproc visit, void* arg) const { + Py_VISIT(result.ptr()); + return 0; + } + }; + + struct WeakrefCacheKey { + nb::weakref ref; + size_t cached_hash; + }; + + using Cache = xla::LRUCache>; + + struct WeakrefCacheValue { + std::shared_ptr lru_list; + std::shared_ptr cache; + }; + + struct WeakrefKeyHash { + size_t operator()(const WeakrefCacheKey& v) const { return v.cached_hash; } + }; + + struct WeakrefKeyEq { + bool operator()(const WeakrefCacheKey& lhs, + const WeakrefCacheKey& rhs) const { + return lhs.ref.equal(rhs.ref); + } + }; + + std::shared_ptr GetCache(WeakrefCacheKey key) { + WeakrefCacheValue& value = entries_[key]; + if (!value.cache) { + value.lru_list = lru_list_; + value.cache = std::make_shared(lru_list_.get()); + } + return value.cache; + } + + WeakrefCacheKey MakeWeakrefKey(const nb::object& weakref_key); + + nb::callable cache_context_fn_; + nb::callable fn_; + std::shared_ptr lru_list_; + std::optional explain_; + std::unordered_map + entries_; + int64_t misses_ = 0; + int64_t total_queries_ = 0; + absl::Mutex mu_; + + // The thread ID of the thread that currently holds mu_. This is used to + // detect reentrant calls. + std::atomic mu_holder_thread_id_; + + static int tp_traverse(PyObject* self, visitproc visit, void* arg); + static int tp_clear(PyObject* self); +}; + +WeakrefLRUCache::WeakrefCacheKey WeakrefLRUCache::MakeWeakrefKey( + const nb::object& weakref_key) { + size_t wrcache_hash = static_cast(nb::hash(weakref_key)); + + // No hash computations after this point. + + auto weakref_gc_callback = nb::cpp_function( + [this_weak = weak_from_this(), wrcache_hash](nb::handle weakref) { + auto cache = this_weak.lock(); + if (cache == nullptr) { + return; + } + // Set up PyCriticalSection for cache python associated object; + auto py_cache = nb::find(cache); + // This should never happen as python cache should always be found + CHECK(py_cache.ptr() != nullptr); + nb::ft_object_guard lock(py_cache); + + // The object the reference referred to is now in the process of being + // destroyed, so we cannot refer to its contents. Python weakref + // objects compare based on identity if the object they refer to is + // gone, so the hash lookup will work fine. + auto it = cache->entries_.find( + WeakrefCacheKey{nb::borrow(weakref), wrcache_hash}); + if (it == cache->entries_.end()) { + return; + } + // Create temp-var to avoid re-entrant erase. + auto tmp = std::move(it->second); + cache->entries_.erase(it); + }); + nb::weakref weakref = nb::weakref(weakref_key, weakref_gc_callback); + return WeakrefCacheKey{std::move(weakref), wrcache_hash}; +} + +void WeakrefLRUCache::EvictWeakref(nb::object weakref_key) { + auto it = entries_.find(MakeWeakrefKey(weakref_key)); + if (it == entries_.end()) { + return; + } + // Create temp-var to avoid re-entrant erase. + auto tmp = std::move(it->second); + entries_.erase(it); +} + +nb::object WeakrefLRUCache::Call(nb::object weakref_key, nb::args args, + nb::kwargs kwargs) + ABSL_NO_THREAD_SAFETY_ANALYSIS { + nb::object context = cache_context_fn_(); + // We precompute all of the hash values needed by the various maps rather + // than computing them during the std::unordered_map insertions. At the very + // least, MSVC's std::unordered_map has undefined behavior if the hash + // function throws an exception + // (https://learn.microsoft.com/en-us/cpp/standard-library/unordered-map-class?view=msvc-170#emplace). + Key key(context, args, kwargs); + auto wrcache_key = MakeWeakrefKey(weakref_key); + + std::shared_ptr cache_ptr = GetCache(wrcache_key); + Cache& cache = *cache_ptr; + ++total_queries_; + + bool inserted = false; + std::shared_ptr entry; + if (mu_holder_thread_id_.load() == std::this_thread::get_id()) { + auto error_string = + absl::StrCat("Reentrant call to weakref_lru_cache. Key: ", + nb::cast(nb::repr(weakref_key)), + nb::cast(nb::repr(args))); + PyErr_SetString(PyExc_RecursionError, error_string.c_str()); + throw nb::python_error(); + } + { + // Because the gil can be released during cache insertion, this forces + // the lock order to be mu_ then gil so we must release the gil first. + nb::gil_scoped_release release; + + // Acquire a mutex to avoid problems where the gil is released during + // cache insertion and then a second thread invalidates the cache order. + mu_.lock(); + mu_holder_thread_id_.store(std::this_thread::get_id()); + } + std::vector miss_keys; + nb::object explainer; + { + // GetOrCreateIfAbsent calls into Python hash and equality functions, + // which may throw exceptions. The use of absl::Cleanup ensures mu_ is + // released if that happens. + absl::Cleanup unlock = [this]() ABSL_UNLOCK_FUNCTION(mu_) { + mu_holder_thread_id_.store(std::thread::id()); + mu_.unlock(); + }; + entry = cache.GetOrCreateIfAbsent( + key, [this, &miss_keys, &inserted, &explainer](const Key& key) { + inserted = true; + if (explain_.has_value()) { + explainer = (*explain_)(); + if (!explainer.is_none()) { + miss_keys = GetKeysLocked(); + } else { + explainer = nb::object(); + } + } + return std::make_shared(); + }); + } + if (!entry->completed.HasBeenNotified()) { + if (inserted) { + ++misses_; + absl::Cleanup notify = [&] { entry->completed.Notify(); }; + if (explainer) { + explainer(miss_keys, weakref_key, *args, **kwargs); + } + entry->result = fn_(weakref_key, *args, **kwargs); + entry->has_result = true; + } else { + if (entry->thread_id == std::this_thread::get_id()) { + auto error_string = + absl::StrCat("Recursively calling ", + nb::cast(nb::repr(weakref_key)), + nb::cast(nb::repr(args))); + PyErr_SetString(PyExc_RecursionError, error_string.c_str()); + throw nb::python_error(); + } + nb::gil_scoped_release release; + entry->completed.WaitForNotification(); + } + } + + if (entry->has_result) { + return entry->result; + } else { + ++misses_; + return fn_(weakref_key, *args, **kwargs); + } +} + +std::vector WeakrefLRUCache::GetKeys() { + absl::MutexLock l(mu_); + return GetKeysLocked(); +} +std::vector WeakrefLRUCache::GetKeysLocked() { + std::vector results; + for (const auto& [wr_key, wr_value] : entries_) { + wr_value.cache->ForEach([&results, &wr_key]( + const Key& key, + const std::shared_ptr& value) { + if (!value->completed.HasBeenNotified()) { return; } + nb::tuple result = + nb::make_tuple(*wr_key.ref, key.context(), key.args(), key.kwargs()); + results.push_back(std::move(result)); + }); + } + return results; +} + +WeakrefLRUCache::CacheInfo WeakrefLRUCache::GetCacheInfo() const { + CacheInfo result; + result.hits = total_queries_ - misses_; + result.misses = misses_; + result.maxsize = lru_list_->Capacity(); + result.currsize = lru_list_->Size(); + return result; +} + +void WeakrefLRUCache::Clear() { + total_queries_ = misses_ = 0; + std::vector> deferred_deletes; + deferred_deletes.reserve(entries_.size()); + for (auto& entry : entries_) { + deferred_deletes.emplace_back(entry.first, std::move(entry.second)); + } + entries_.clear(); + deferred_deletes.clear(); +} + +/*static*/ int WeakrefLRUCache::tp_traverse(PyObject* self, visitproc visit, + void* arg) { + Py_VISIT(Py_TYPE(self)); + if (!nb::inst_ready(self)) { + return 0; + } + WeakrefLRUCache* cache = nb::inst_ptr(self); + Py_VISIT(cache->cache_context_fn_.ptr()); + Py_VISIT(cache->fn_.ptr()); + if (cache->explain_) { Py_VISIT(cache->explain_->ptr()); } + for (const auto& [wr_key, wr_value] : cache->entries_) { + Py_VISIT(wr_key.ref.ptr()); + int rval = 0; + wr_value.cache->ForEach( + [&visit, &arg, &rval](const Key& key, + const std::shared_ptr& value) { + if (rval != 0) { + return; + } + rval = key.tp_traverse(visit, arg); + if (rval != 0) { + return; + } + value->tp_traverse(visit, arg); + }); + if (rval != 0) { + return rval; + } + } + return 0; +} + +/*static*/ int WeakrefLRUCache::tp_clear(PyObject* self) { + WeakrefLRUCache* cache = nb::inst_ptr(self); + cache->Clear(); + cache->cache_context_fn_.reset(); + cache->fn_.reset(); + cache->explain_ = std::nullopt; + return 0; +} + +/* static */ PyType_Slot WeakrefLRUCache::slots_[] = { + {Py_tp_traverse, (void*)WeakrefLRUCache::tp_traverse}, + {Py_tp_clear, (void*)WeakrefLRUCache::tp_clear}, + {0, nullptr}, +}; + +NB_MODULE(weakref_lru_cache, m) { + auto weakref_lru_cache = + nb::class_(m, "WeakrefLRUCache", + nb::is_weak_referenceable(), + nb::type_slots(WeakrefLRUCache::slots_)) + .def("__call__", &WeakrefLRUCache::Call, nb::lock_self()) + .def("evict_weakref", &WeakrefLRUCache::EvictWeakref, nb::lock_self()) + .def("cache_keys", &WeakrefLRUCache::GetKeys, nb::lock_self()) + .def("cache_info", &WeakrefLRUCache::GetCacheInfo, nb::lock_self()) + .def("cache_clear", &WeakrefLRUCache::Clear, nb::lock_self()); + nb::class_(weakref_lru_cache, + "WeakrefLRUCacheInfo") + .def_ro("hits", &WeakrefLRUCache::CacheInfo::hits) + .def_ro("misses", &WeakrefLRUCache::CacheInfo::misses) + .def_ro("maxsize", &WeakrefLRUCache::CacheInfo::maxsize) + .def_ro("currsize", &WeakrefLRUCache::CacheInfo::currsize) + .def("__repr__", [](WeakrefLRUCache::CacheInfo& info) { + return absl::StrCat( + "WeakrefLRUCache(hits=", info.hits, ", misses=", info.misses, + ", maxsize=", info.maxsize, ", currsize=", info.currsize, ")"); + }); + m.def( + "weakref_lru_cache", + [](nb::callable cache_context_fn, nb::callable fn, + std::optional maxsize, std::optional explain) { + return std::make_shared( + cache_context_fn, fn, + maxsize.value_or(std::numeric_limits::max()), explain); + }, + nb::arg("cache_context_fn"), nb::arg("fn"), + nb::arg("maxsize").none() = 2048, + nb::arg("explain") = std::optional()); +} + +} // namespace jax diff --git a/jaxlib/weakref_lru_cache.pyi b/jaxlib/weakref_lru_cache.pyi new file mode 100644 index 000000000000..fcd30a8d0cdc --- /dev/null +++ b/jaxlib/weakref_lru_cache.pyi @@ -0,0 +1,40 @@ +# Copyright 2025 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from collections.abc import Callable +from typing import Any + +class WeakrefLRUCache: + def __call__(self, arg0: Any, /, *args, **kwargs) -> Any: ... + def evict_weakref(self, arg0: Any) -> None: ... + def cache_keys(self) -> list[Any]: ... + def cache_info(self) -> WeakrefLRUCache.WeakrefLRUCacheInfo: ... + def cache_clear(self) -> None: ... + + class WeakrefLRUCacheInfo: + @property + def hits(self) -> int: ... + @property + def misses(self) -> int: ... + @property + def maxsize(self) -> int: ... + @property + def currsize(self) -> int: ... + def __repr__(self) -> str: ... + +def weakref_lru_cache( + cache_context_fn: Callable, fn: Callable, maxsize: int | None = 2048, + explain: Callable | None = None +) -> WeakrefLRUCache: ... diff --git a/jaxlib/weakref_lru_cache_test.py b/jaxlib/weakref_lru_cache_test.py new file mode 100644 index 000000000000..f333e1a66a12 --- /dev/null +++ b/jaxlib/weakref_lru_cache_test.py @@ -0,0 +1,336 @@ +# Copyright 2023 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import gc +import random +import threading +import time +import weakref + +from absl.testing import absltest +from jax.jaxlib import weakref_lru_cache + + +class WeakrefLRUCacheTest(absltest.TestCase): + + def testMultiThreaded(self): + insert_evs = [threading.Event() for _ in range(2)] + insert_evs_i = 0 + + class WRKey: + pass + + class ClashingKey: + + def __eq__(self, other): + return False + + def __hash__(self): + return 333 # induce maximal caching problems. + + class GilReleasingCacheKey: + + def __eq__(self, other): + nonlocal insert_evs_i + if isinstance(other, GilReleasingCacheKey) and insert_evs_i < len( + insert_evs + ): + insert_evs[insert_evs_i].set() + insert_evs_i += 1 + time.sleep(0.01) + return False + + def __hash__(self): + return 333 # induce maximal caching problems. + + def CacheFn(obj, gil_releasing_cache_key): + del obj + del gil_releasing_cache_key + return None + + cache = weakref_lru_cache.weakref_lru_cache(lambda: None, CacheFn, 2048) + + wrkey = WRKey() + + def Body(): + for insert_ev in insert_evs: + insert_ev.wait() + for _ in range(20): + cache(wrkey, ClashingKey()) + + t = threading.Thread(target=Body) + t.start() + for _ in range(3): + cache(wrkey, GilReleasingCacheKey()) + t.join() + + def testAnotherMultiThreaded(self): + num_workers = 5 + barrier = threading.Barrier(num_workers) + cache = weakref_lru_cache.weakref_lru_cache( + lambda: None, lambda x, y: y, 2048 + ) + + class WRKey: + pass + + def WorkerAddToCache(): + barrier.wait() + wrkey = WRKey() + for i in range(10): + cache(wrkey, i) + + def WorkerCleanCache(): + barrier.wait() + for _ in range(10): + cache.cache_clear() + + workers = [ + threading.Thread(target=WorkerAddToCache) + for _ in range(num_workers - 1) + ] + [threading.Thread(target=WorkerCleanCache)] + + for t in workers: + t.start() + + for t in workers: + t.join() + + def testKwargsDictOrder(self): + miss_id = 0 + + class WRKey: + pass + + def CacheFn(obj, kwkey1, kwkey2): + del obj, kwkey1, kwkey2 + nonlocal miss_id + miss_id += 1 + return miss_id + + cache = weakref_lru_cache.weakref_lru_cache(lambda: None, CacheFn, 4) + + wrkey = WRKey() + + self.assertEqual(cache(wrkey, kwkey1="a", kwkey2="b"), 1) + self.assertEqual(cache(wrkey, kwkey1="b", kwkey2="a"), 2) + self.assertEqual(cache(wrkey, kwkey2="b", kwkey1="a"), 1) + + def testGetKeys(self): + def CacheFn(obj, arg): + del obj + return arg + "extra" + + cache = weakref_lru_cache.weakref_lru_cache(lambda: None, CacheFn, 4) + + class WRKey: + pass + + wrkey = WRKey() + + self.assertEmpty(cache.cache_keys()) + cache(wrkey, "arg1") + cache(wrkey, "arg2") + self.assertLen(cache.cache_keys(), 2) + + def testNonWeakreferenceableKey(self): + class NonWRKey: + __slots__ = () + + non_wr_key = NonWRKey() + with self.assertRaises(TypeError): + weakref.ref(non_wr_key) + + cache = weakref_lru_cache.weakref_lru_cache(lambda: None, lambda x: 2048) + for _ in range(100): + with self.assertRaises(TypeError): + cache(non_wr_key) + + def testCrashingKey(self): + class WRKey: + pass + + class CrashingKey: + # A key that raises exceptions if eq or hash is called. + + def __eq__(self, other): + raise ValueError("eq") + + def __hash__(self): + raise ValueError("hash") + + cache = weakref_lru_cache.weakref_lru_cache( + lambda: None, lambda x, y: y, 2048 + ) + wrkey = WRKey() + with self.assertRaises(ValueError): + for _ in range(100): + cache(wrkey, CrashingKey()) + + def testPrintingStats(self): + class WRKey: + pass + + cache = weakref_lru_cache.weakref_lru_cache( + lambda: None, lambda x, y: y, 2048 + ) + wrkey = WRKey() + for i in range(10): + cache(wrkey, i) + for i in range(5): + cache(wrkey, i) + + self.assertEqual( + repr(cache.cache_info()), + "WeakrefLRUCache(hits=5, misses=10, maxsize=2048, currsize=10)", + ) + + def testGCKeys(self): + class WRKey: + + def __init__(self, x): + self.x = x + + def __eq__(self, other): + return self.x == other.x + + def __hash__(self): + return hash(self.x) + + cache = weakref_lru_cache.weakref_lru_cache( + lambda: None, lambda x, y: y, 2048 + ) + keys = [WRKey(i) for i in range(10)] + for i in range(10): + cache(keys[i], i) + + # Delete some keys, to exercise the weakref callback behavior. + del keys[::2] + + for key in keys: + cache(key, 7) + + def testTpTraverse(self): + class WRKey: + pass + + def CacheContextFn(): + return None + + def CallFn(x, y, *args, **kwargs): + del x, args, kwargs + return y + + cache = weakref_lru_cache.weakref_lru_cache(CacheContextFn, CallFn, 2048) + + keys = [WRKey() for _ in range(10)] + values = [str(i) for i in range(10)] + args = [str(i) for i in range(10)] + kwargs = {"a": "b"} + + for key, value in zip(keys, values): + cache(key, value, *args, **kwargs) + + expected_refs = ( + [ + CacheContextFn, + CallFn, + weakref_lru_cache.WeakrefLRUCache, + kwargs, + ] + + [weakref.getweakrefs(key)[0] for key in keys] + + values + + args + ) + + # Can't use assertContainsSubset because it doesn't support kwargs since + # dicts aren't hashable. + for ref in expected_refs: + self.assertIn(ref, gc.get_referents(cache)) + + def testReentrantKey(self): + cache = weakref_lru_cache.weakref_lru_cache( + lambda: None, lambda x, y: y, 2048 + ) + + class WRKey: + pass + + class ReentrantKey: + def __eq__(self, other): + cache(WRKey(), None) + return False + + def __hash__(self): + return 42 + + wrkey = WRKey() + with self.assertRaisesRegex(RecursionError, "Reentrant call"): + for _ in range(100): + cache(wrkey, ReentrantKey()) + + def testEvictWeakref(self): + dtor_list = [] + + class NoisyDestructor: + + def __init__(self, v): + self.v = v + + def __del__(self): + dtor_list.append(self.v) + + cache = weakref_lru_cache.weakref_lru_cache( + lambda: None, lambda x, y: NoisyDestructor(y) + ) + + class WRKey: + pass + + N = 100 + expected_deletes = [] + plan = list(range(N)) * 2 + random.shuffle(plan) + keys = [None] * N + for i in plan: + if keys[i] is None: + keys[i] = WRKey() + cache(keys[i], i) + else: + cache.evict_weakref(keys[i]) + expected_deletes.append(i) + self.assertEqual(dtor_list, expected_deletes) + + def testExplain(self): + + def explain(keys, x): + self.assertLen(keys, num_keys_should_be) + + cache = weakref_lru_cache.weakref_lru_cache( + lambda: None, lambda x: None, explain=lambda: explain) + + class A: ... + a = A() + + num_keys_should_be = 0 + cache(a) + + num_keys_should_be = 1 + b = A() + cache(b) + + +if __name__ == "__main__": + absltest.main() diff --git a/jaxlib/xla_client.py b/jaxlib/xla_client.py new file mode 100644 index 000000000000..780693137c39 --- /dev/null +++ b/jaxlib/xla_client.py @@ -0,0 +1,556 @@ +# Copyright 2017 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""An XLA client in Python.""" + +from __future__ import annotations + +import atexit +from collections.abc import Mapping +import contextlib +import enum +import logging +import os +import threading +from typing import Any, Protocol, Union + +from jaxlib import _jax as _xla + +# Note this module does *not* depend on any Python protocol buffers. The XLA +# Python bindings are currently packaged both as part of jaxlib and as part +# of TensorFlow. If we use protocol buffers here, then importing both jaxlib +# and TensorFlow may fail with duplicate protocol buffer message definitions. + +# Most functions are snake_case for consistency with other modules, some +# method names are CamelCase for consistency with XLA. +# pylint: disable=invalid-name + +# Pylint has false positives for type annotations. +# pylint: disable=invalid-sequence-index + +ifrt_programs = _xla.ifrt_programs + +# Just an internal arbitrary increasing number to help with backward-compatible +# changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version. +# +# Please suffix the version number with a brief description of your change +# in a comment. The goal here is to force a merge conflict if two changes +# attempt to grab the same version number. +_version = 397 # Re-enable DCN cross-host transfers on accelerators. + +# An internal increasing version number for protecting jaxlib code against +# ifrt changes. +# lives in xla/python/version.h. +# In JAX, reference this via jax._src.lib.ifrt_version. +_ifrt_version = _xla.ifrt_version_number + +xla_platform_names = { + 'cpu': 'Host', + 'gpu': 'CUDA', +} + +logger = logging.getLogger(__name__) + +_NameValueMapping = Mapping[str, Union[str, int, list[int], float, bool]] + + +def make_cpu_client( + asynchronous=True, + distributed_client=None, + node_id=0, + num_nodes=1, + collectives=None, + num_devices=None, + get_local_topology_timeout_minutes=None, + get_global_topology_timeout_minutes=None, + transfer_server_factory=None, +) -> Client: + register_custom_call_handler('cpu', _xla.register_custom_call_target) + register_custom_type_handler('cpu', _xla.register_custom_type) + return _xla.get_tfrt_cpu_client( + asynchronous=asynchronous, + distributed_client=distributed_client, + node_id=node_id, + num_nodes=num_nodes, + collectives=collectives, + num_devices=num_devices, + get_local_topology_timeout_minutes=get_local_topology_timeout_minutes, + get_global_topology_timeout_minutes=get_global_topology_timeout_minutes, + transfer_server_factory=transfer_server_factory, + ) + + +DeviceTopology = _xla.DeviceTopology +get_topology_for_devices = _xla.get_topology_for_devices + + +def make_tfrt_tpu_c_api_device_topology( + topology_name: str = '', **kwargs +) -> DeviceTopology: + """Creates a PJRT C API TopologyDescription.""" + return _xla.get_default_c_api_topology('tpu', topology_name, dict(**kwargs)) + + +def make_c_api_device_topology( + c_api: Any, topology_name: str = '', **kwargs +) -> DeviceTopology: + """Creates a PJRT C API TopologyDescription.""" + return _xla.get_c_api_topology(c_api, topology_name, dict(**kwargs)) + + +def pjrt_plugin_loaded(plugin_name: str) -> bool: + return _xla.pjrt_plugin_loaded(plugin_name) + + +def load_pjrt_plugin_dynamically(plugin_name: str, library_path: str) -> Any: + return _xla.load_pjrt_plugin(plugin_name, library_path, c_api=None) + + +def load_pjrt_plugin_with_c_api(plugin_name: str, c_api: Any) -> None: + _xla.load_pjrt_plugin(plugin_name, None, c_api) + + +def pjrt_plugin_initialized(plugin_name: str) -> bool: + return _xla.pjrt_plugin_initialized(plugin_name) + + +def initialize_pjrt_plugin(plugin_name: str) -> None: + """Initializes a PJRT plugin. + + The plugin needs to be loaded first (through load_pjrt_plugin_dynamically or + static linking) before this method is called. + Args: + plugin_name: the name of the PJRT plugin. + """ + _xla.initialize_pjrt_plugin(plugin_name) + + +def make_c_api_client( + plugin_name: str, + options: _NameValueMapping | None = None, + distributed_client: _xla.DistributedRuntimeClient | None = None, + transfer_server_factory: _xla.TransferServerInterfaceFactory | None = None, + force_dcn_cross_host_transfers: bool = False, +): + """Creates a PJRT C API client for a PJRT plugin. + + It is required that load_pjrt_plugin_dynamically is called once with the same + plugin_name before this method is called. + + Args: + plugin_name: the name of the PJRT plugin. + options: extra platform-specific options. + distributed_client: distributed client. + + Returns: + A PJRT C API client for plugin_name. + """ + if options is None: + options = {} + return _xla.get_c_api_client( + plugin_name, + options, + distributed_client, + transfer_server_factory, + force_dcn_cross_host_transfers, + ) + + +def generate_pjrt_gpu_plugin_options() -> _NameValueMapping: + """Generates the PjRt GPU plugin options. + + Returns: + A dictionary of plugin options. + """ + + options: dict[str, Any] = {} + options['platform_name'] = 'cuda' + allocator = os.getenv('XLA_PYTHON_CLIENT_ALLOCATOR', 'default').lower() + memory_fraction = os.getenv('XLA_CLIENT_MEM_FRACTION', '') + deprecated_memory_fraction = os.getenv('XLA_PYTHON_CLIENT_MEM_FRACTION', '') + if deprecated_memory_fraction: + if memory_fraction: + raise ValueError( + 'XLA_CLIENT_MEM_FRACTION is specified together ' + 'with XLA_PYTHON_CLIENT_MEM_FRACTION. ' + 'Remove the latter one, it is deprecated.' + ) + else: + memory_fraction = deprecated_memory_fraction + preallocate = os.getenv('XLA_PYTHON_CLIENT_PREALLOCATE', '') + collective_memory_size = os.getenv( + 'XLA_PYTHON_CLIENT_COLLECTIVE_MEM_SIZE_MB', '' + ) + if allocator not in ('default', 'platform', 'bfc', 'cuda_async'): + raise ValueError( + 'XLA_PYTHON_CLIENT_ALLOCATOR env var must be "default", "platform", ' + '"bfc", or "cuda_async", got "%s"' % allocator + ) + options['allocator'] = allocator + if memory_fraction: + options['memory_fraction'] = float(memory_fraction) + if preallocate: + options['preallocate'] = preallocate not in ('false', 'False', '0') + if collective_memory_size: + options['collective_memory_size'] = int(collective_memory_size) * (1 << 20) + abort = os.getenv('XLA_PYTHON_CLIENT_ABORT_COLLECTIVES_ON_FAILURE', '0') + options['abort_collectives_on_failure'] = bool(int(abort)) + use_trft_gpu_client = os.getenv('XLA_PYTHON_CLIENT_USE_TFRT_GPU_CLIENT', '0') + options['use_tfrt_gpu_client'] = bool(int(use_trft_gpu_client)) + return options + + +PrimitiveType = _xla.PrimitiveType + +Shape = _xla.Shape +Shape.__doc__ = """ +A Shape is an object defined in C++ that duck types like the following class: + +class Shape: + '''Represents an XLA shape. + + A shape is either an array shape, having rank-many integer + dimensions and an element type (represented by a Numpy dtype), or it + is a tuple shape, having a shape for every tuple component: + + type shape = + TupleShape of shape list + | ArrayShape of { dimensions: int list; element_type: dtype } + ''' + + @staticmethod + def tuple_shape(tuple_shapes) -> Shape: + "Construct a tuple shape." + + @staticmethod + def array_shape(element_type, dimensions, minor_to_major=None) -> Shape: + + @staticmethod + def from_pyval(pyval) -> Shape: + "Returns a Shape that describes a tuple-tree of Numpy arrays." + + def __init__(self, str) -> Shape: + "Parses a shape string." + def __eq__(self, other: Shape) -> bool: + def __ne__(self, other: Shape) -> bool: + def __hash__(self): + def __repr__(self): + def is_tuple(self) -> bool: + def is_array(self) -> bool: + def tuple_shapes(self) -> [Shape]: + def numpy_dtype(self) -> np.dtype: + "Like element_type(), but returns dtype('O') for a tuple shape." + def xla_element_type(self) -> PrimitiveType: + def element_type(self) -> np.dtype: + def dimensions(self) -> (int, int, ...): + def rank(self) -> int: + def with_major_to_minor_layout_if_absent(self) -> Shape: + "Returns a copy with missing layouts set to major-to-minor." + + def to_serialized_proto(self) -> bytes: + "Returns 'shape' as a serialized proto." +""" + +ProgramShape = _xla.ProgramShape +ProgramShape.__doc__ = """ +A ProgramShape is a C++ object that duck types like the following class. + +class ProgramShape: + def __init__(self, parameter_shapes, result_shape): + def parameter_shapes(self) -> [Shape]: + def result_shape(self) -> Shape: + def __repr__(self): +""" + +DeviceAssignment = _xla.DeviceAssignment +DeviceAssignment.__doc__ = """ +A DeviceAssignment is a C++ object with the following signature. + +def create(assignment): + '''Builds a device assignment. + + Args: + assignment: a 2D numpy array of device ordinal integers, indexed by + [replica][computation_in_replica]. + Returns: + A device assignment. + ''' + +def replica_count(): + '''Returns the number of replicas.''' +def computation_count(): + '''Returns the number of computations per replica.''' +""" + +Device = _xla.Device +CompileOptions = _xla.CompileOptions + +HostBufferSemantics = _xla.HostBufferSemantics + +# An Executable is a C++ class that duck types with the following API: +# class Executable: +# def local_devices(self) -> [Device]: +# def execute(self, arguments : [Buffer]) -> Buffer: +# """Execute on one replica with Buffer arguments and return value.""" +# +# def size_of_generated_code_in_bytes(self) -> int: +# """Return generated binary size, or -1 if not known.""" +# +# def execute_sharded_on_local_devices(self, arguments: [[Buffer]]) +# -> [Buffer]: +# """Execute on many replicas with Buffer arguments and return value. +# +# Args: +# arguments: A sequence of sequences of Buffers. The i'th element of each +# sequence comprises the arguments for execution on the i'th local +# device. +# +# Returns: +# A list of the computation's outputs as a list of Buffers for each +# device. +# """ +# +# There are different implementations of Executable for different backends. + + +XlaComputation = _xla.XlaComputation +Client = _xla.Client +Memory = _xla.Memory +Array = _xla.Array +ArrayImpl = _xla.ArrayImpl +LoadedExecutable = _xla.LoadedExecutable +Executable = _xla.Executable +DeviceList = _xla.DeviceList +OpSharding = _xla.OpSharding +HloSharding = _xla.HloSharding +Sharding = _xla.Sharding +NamedSharding = _xla.NamedSharding +SingleDeviceSharding = _xla.SingleDeviceSharding +PmapSharding = _xla.PmapSharding +GSPMDSharding = _xla.GSPMDSharding +PjRtLayout = _xla.PjRtLayout +AutotuneCacheMode = _xla.AutotuneCacheMode + + +def LoadedExecutable_execute(self, arguments, device=None): + del device + results = self.execute_sharded(arguments) + return [x[0] for x in results.disassemble_into_single_device_arrays()] + + +def LoadedExecutable_execute_with_token(self, arguments, device=None): + del device + results = self.execute_sharded(arguments, with_tokens=True) + return ( + [x[0] for x in results.disassemble_into_single_device_arrays()], + results.consume_token().get_token(0), + ) + + +LoadedExecutable.execute = LoadedExecutable_execute # type: ignore[method-assign] +LoadedExecutable.execute_with_token = LoadedExecutable_execute_with_token # type: ignore[method-assign] + + +class CustomCallTargetTraits(enum.IntFlag): + DEFAULT = 0 + # Calls to custom call are safe to trace into the command buffer. It means + # that calls to custom call always launch exactly the same device operations + # (can depend on attribute values) that can be captured and then replayed. + # + # Supported only for custom calls implemented with XLA FFI. + COMMAND_BUFFER_COMPATIBLE = 1 + + +class CustomCallHandler(Protocol): + + def __call__( + self, + name: str, + fn: Any, + platform: str, + /, + api_version: int = ..., + traits: CustomCallTargetTraits = ..., + ) -> None: + ... + + +_custom_callback_handler: dict[str, CustomCallHandler] = {} +# Key is xla_platform_name, value is (function_name, function, api_version) +_custom_callback: dict[ + str, list[tuple[str, Any, int, CustomCallTargetTraits]] +] = {} +_custom_callback_lock = threading.Lock() + + +def register_custom_call_target( + name: str, + fn: Any, + platform: str = 'cpu', + api_version: int = 0, + traits: CustomCallTargetTraits = CustomCallTargetTraits.DEFAULT, +) -> None: + """Registers a custom call target. + + Args: + name: bytes containing the name of the function. + fn: a PyCapsule object containing the function pointer. + platform: the target platform. + api_version: the XLA FFI version to use. Supported versions are: 0 for the + untyped FFI and 1 for the typed FFI. + traits: custom call traits corresponding to XLA FFI handler traits. + """ + # To support AMD GPUs, we need to have xla_platform_names["gpu"] == "ROCM" + # Since that is hardcoded to CUDA, we are using the following as workaround. + xla_platform_name = xla_platform_names.get(platform, platform) + with _custom_callback_lock: + if xla_platform_name in _custom_callback_handler: + _custom_callback_handler[xla_platform_name]( + name, fn, xla_platform_name, api_version, traits + ) + else: + _custom_callback.setdefault(xla_platform_name, []).append( + (name, fn, api_version, traits) + ) + + +def register_custom_call_handler( + platform: str, handler: CustomCallHandler +) -> None: + """Registers a custom handler and use it to register existing custom calls. + + If a custom call handler for the platform already exist, calling this method + is a no-op and it will not register a new handler. + + Args: + platform: the target platform. + handler: the function to register a custom call. + """ + xla_platform_name = xla_platform_names.get(platform, platform) + with _custom_callback_lock: + if xla_platform_name in _custom_callback_handler: + logger.debug( + 'Custom call handler for %s is already register. Will not register a' + ' new one', + xla_platform_name, + ) + return + _custom_callback_handler[xla_platform_name] = handler + if xla_platform_name in _custom_callback: + for name, fn, api_version, traits in _custom_callback[xla_platform_name]: + handler(name, fn, xla_platform_name, api_version, traits) + del _custom_callback[xla_platform_name] + + +class CustomTypeIdHandler(Protocol): + + def __call__(self, type_name: str, type_id: Any, /) -> None: + ... + + +_custom_type_id_handler: dict[str, CustomTypeIdHandler] = {} +_custom_type_id: dict[str, Any] = {} +_custom_type_id_lock = threading.Lock() + + +def register_custom_type( + type_name: str, + type_id: Any, + platform: str = 'cpu', +) -> None: + """Register a custom type id for use with the FFI. + + Args: + type_name: a unique name for the type. + type_id: a PyCapsule object containing a pointer to the ``ffi::TypeId``. + platform: the target platform. + """ + xla_platform_name = xla_platform_names.get(platform, platform) + with _custom_type_id_lock: + if xla_platform_name in _custom_type_id_handler: + _custom_type_id_handler[xla_platform_name](type_name, type_id) + else: + _custom_type_id.setdefault(xla_platform_name, []).append( + (type_name, type_id) + ) + + +def register_custom_type_handler( + platform: str, handler: CustomTypeIdHandler +) -> None: + """Register a custom type id handler and use it to register existing type ids. + + If a custom type id handler for the platform already exist, calling this + method is a no-op and it will not register a new handler. + + Args: + platform: the target platform. + handler: the function to register a custom type id. + """ + xla_platform_name = xla_platform_names.get(platform, platform) + with _custom_callback_lock: + if xla_platform_name in _custom_type_id_handler: + logger.debug( + 'Custom type id handler for %s is already register. Will not ' + 'register a new one', + xla_platform_name, + ) + return + _custom_type_id_handler[xla_platform_name] = handler + if xla_platform_name in _custom_type_id: + for name, capsule in _custom_type_id[xla_platform_name]: + handler(name, capsule) + del _custom_type_id[xla_platform_name] + + +register_custom_call_partitioner = _xla.register_custom_call_partitioner +encode_inspect_sharding_callback = _xla.encode_inspect_sharding_callback +hlo_sharding_util = _xla.hlo_sharding_util +register_custom_call_as_batch_partitionable = ( + _xla.register_custom_call_as_batch_partitionable +) + + +Traceback = _xla.Traceback +Frame = _xla.Frame + + +@contextlib.contextmanager +def execution_stream_id(new_id: int): + """Context manager that overwrites and restores the current thread's execution_stream_id.""" + saved = _xla.get_execution_stream_id() + _xla.set_execution_stream_id(new_id) + try: + yield + finally: + _xla.set_execution_stream_id(saved) + + +XlaRuntimeError = _xla.JaxRuntimeError + +# Perform one last garbage collection of deferred Python references. This is +# mostly to keep ASAN happy. +atexit.register(_xla.collect_garbage) + +array_result_handler = _xla.array_result_handler +batched_copy_array_to_devices_with_sharding = ( + _xla.batched_copy_array_to_devices_with_sharding +) +batched_device_put = _xla.batched_device_put +reorder_shards = _xla.reorder_shards +batched_block_until_ready = _xla.batched_block_until_ready +check_and_canonicalize_memory_kind = _xla.check_and_canonicalize_memory_kind +Layout = _xla.Layout +custom_call_targets = _xla.custom_call_targets +ArrayCopySemantics = _xla.ArrayCopySemantics diff --git a/jaxlib/xla_compiler.cc b/jaxlib/xla_compiler.cc new file mode 100644 index 000000000000..eaed15c07487 --- /dev/null +++ b/jaxlib/xla_compiler.cc @@ -0,0 +1,1380 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla_compiler.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "absl/types/span.h" +#include "mlir/Support/LLVM.h" +#include "nanobind/nanobind.h" +#include "nanobind/ndarray.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/dlpack.h" +#include "jaxlib/py_client.h" +#include "xla/array.h" +#include "xla/client/executable_build_options.h" +#include "xla/debug_options_flags.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_print_options.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/layout.h" +#include "xla/layout_util.h" +#include "xla/literal.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/proto/compile_options.pb.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/nb_absl_span.h" // IWYU pragma: keep +#include "xla/python/nb_numpy.h" +#include "xla/python/types.h" +#include "xla/service/computation_placer.h" +#include "xla/service/hlo.pb.h" +#include "xla/service/hlo_graph_dumper.h" +#include "xla/service/hlo_module_config.h" +#include "xla/service/spmd/shardy/stablehlo_round_trip/stablehlo_import.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/tsl/lib/strings/proto_serialization.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" +#include "xla/xla.pb.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace { + +namespace nb = nanobind; + +// Converts a computation to a serialized HloModuleProto. +absl::StatusOr GetComputationSerializedProto( + const XlaComputation& computation) { + std::string result; + if (!tsl::SerializeToStringDeterministic(computation.proto(), &result)) { + return Unknown("Failed to serialize the HloModuleProto."); + } + return nb::bytes(result.data(), result.size()); +} + +// Converts a hlo module to a serialized HloModuleProto. +absl::StatusOr GetHloModuleSerializedProto(const HloModule& module) { + std::string result; + if (!tsl::SerializeToStringDeterministic(module.ToProto(), &result)) { + return Unknown("Failed to serialize the HloModuleProto."); + } + return nb::bytes(result.data(), result.size()); +} + +// Converts a serialized HloModuleProto into a HloModule. +absl::StatusOr> HloModuleFromSerializedProto( + const nb::bytes& bytes) { + HloModuleProto proto; + proto.ParseFromArray(bytes.c_str(), bytes.size()); + TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config, + HloModule::CreateModuleConfigFromProto( + proto, GetDebugOptionsFromFlags())); + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + HloModule::CreateFromProto(proto, module_config)); + return std::shared_ptr(std::move(module)); +} + +absl::StatusOr> GetHloModule( + const XlaComputation& computation) { + TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config, + HloModule::CreateModuleConfigFromProto( + computation.proto(), GetDebugOptionsFromFlags())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr module, + HloModule::CreateFromProto(computation.proto(), module_config)); + return std::shared_ptr(std::move(module)); +} + +// Converts a computation to textual HLO form. +absl::StatusOr GetComputationHloText( + const XlaComputation& computation, bool print_large_constants = false) { + TF_ASSIGN_OR_RETURN(std::shared_ptr hlo_module, + GetHloModule(computation)); + HloPrintOptions options; + options = HloPrintOptions::ShortParsable(); + options.set_print_large_constants(print_large_constants); + return hlo_module->ToString(options); +} + +// Converts a computation to HLO dot graph form. +absl::StatusOr GetComputationHloDotGraph( + const XlaComputation& computation) { + TF_ASSIGN_OR_RETURN(std::shared_ptr hlo_module, + GetHloModule(computation)); + return RenderGraph(*hlo_module->entry_computation(), /*label=*/"", + hlo_module->config().debug_options(), + RenderedGraphFormat::kDot); +} + +// Hashes the HLO module. +absl::StatusOr HashComputation(const XlaComputation& computation) { + TF_ASSIGN_OR_RETURN(std::shared_ptr hlo_module, + GetHloModule(computation)); + return absl::HashOf(*hlo_module); +} +// Safe version of ShapeUtil::MakeShapeWithDenseLayout that fails gracefully on +// invalid input. +absl::StatusOr MakeShapeWithDenseLayout( + PrimitiveType element_type, absl::Span dims, + std::optional> minor_to_major, + std::optional> dynamic_dimensions) { + Shape shape; + if (dynamic_dimensions) { + TF_ASSIGN_OR_RETURN( + shape, ShapeUtil::MakeValidatedShape(element_type, dims, + dynamic_dimensions.value())); + } else { + TF_ASSIGN_OR_RETURN(shape, + ShapeUtil::MakeValidatedShape(element_type, dims)); + } + if (minor_to_major) { + *shape.mutable_layout() = LayoutUtil::MakeLayout(*minor_to_major); + TF_RETURN_IF_ERROR( + LayoutUtil::ValidateLayoutForShape(shape.layout(), shape)); + } + + return shape; +} + +// Pybind function for HloSharding.iota_tile, which is a non-crashing factory +// that produces a HloSharding instance backed by tile assignment of a +// transposed and reshaped iota array of device ids. More specifically the tile +// assignment array is as if it is produced by the following numpy code: +// numpy.arange(math.prod(dims)).reshape(reshape_dims) +// .transpose(transpose_perm).reshape(math.prod(dims)) +// where: +// `dims`: is the dimensions of the tile assignment array, which corresponds to +// OpSharding.tile_assignment_dimensions. +// `reshape_dims`: is the dimensions the 1D iota array is reshaped to. +// `transpose_perm`: is the dimension permutation to transpose `reshape_dims`. +// `subgroup_types`: indicates the subgroups of the last `subgroup_types.size()` +// dimensions in `dims`. +// +// In practice, `reshape_dims` often maps to the axes of user defined device +// mesh, and `transpose_perm` often maps to the user specification of how a +// tensor is partitioned based on the axes defined in the mesh, e.g. for a mesh +// of size 4x2x2 as AxBxC: +// PartitionSpec('A', 'B', 'C') corresponds to reshape_dims=[4,2,2], +// transpose_perm=[0,1,2] (no transpose) +// PartitionSpec('B', 'A', 'C') corresponds to reshape_dims=[4,2,2], +// transpose_perm=[1,0,2] (swap A and B) +absl::StatusOr IotaTileHelper( + absl::Span dims, absl::Span reshape_dims, + absl::Span transpose_perm, + absl::Span subgroup_types) { + if (dims.empty()) { + return InvalidArgument("`dims` should not be empty."); + } + if (reshape_dims.size() != transpose_perm.size()) { + return InvalidArgument( + "`reshape_dims` and `transpose_perm` should have the same size, saw " + "[%s] v.s. [%s]", + absl::StrJoin(reshape_dims, ","), absl::StrJoin(transpose_perm, ",")); + } + if (!reshape_dims.empty() && Product(dims) != Product(reshape_dims)) { + return InvalidArgument( + "Cannot reshape from `dims` [%s] to `reshape_dims` [%s].", + absl::StrJoin(dims, ","), absl::StrJoin(reshape_dims, ",")); + } + if (subgroup_types.size() > dims.size()) { + return InvalidArgument( + "`subgroup_types`(%lld) should not have more dimensions than " + "`dims`(%lld).", + subgroup_types.size(), dims.size()); + } + if (reshape_dims.empty()) { + return subgroup_types.empty() + ? HloSharding::IotaTile(dims) + : HloSharding::Subgroup(TileAssignment(dims), subgroup_types); + } + return subgroup_types.empty() + ? HloSharding::IotaTile(dims, reshape_dims, transpose_perm) + : HloSharding::Subgroup( + TileAssignment(dims, reshape_dims, transpose_perm), + subgroup_types); +} + +template +void DefRepeatedProperty(nb::class_& cls, const char* name, + Container* (T::*getter)()) { + cls.def_prop_rw( + name, + [getter](T& obj) { + Container* elems = (obj.*getter)(); + std::vector result; + result.reserve(elems->size()); + std::copy(elems->begin(), elems->end(), std::back_inserter(result)); + return result; + }, + [getter](T& obj, std::vector new_elems) { + Container* elems = (obj.*getter)(); + elems->Clear(); + elems->Reserve(new_elems.size()); + for (typename Container::value_type& e : new_elems) { + elems->Add(std::move(e)); + } + }); +} + +template +void DefRepeatedEnumProperty(nb::class_& cls, const char* name, + Container* (T::*getter)()) { + cls.def_prop_rw( + name, + [getter](T& obj) { + Container* elems = (obj.*getter)(); + std::vector result; + result.reserve(elems->size()); + std::copy(elems->begin(), elems->end(), std::back_inserter(result)); + return result; + }, + [getter]( + T& obj, + nb::typed new_elems) { + Container* elems = (obj.*getter)(); + elems->Clear(); + for (nb::handle e : new_elems) { + elems->Add(nb::cast(e.attr("value"))); + } + }); +} + +template +Array NDArrayToArray(nb::ndarray ndarray) { + std::vector shapes; + shapes.reserve(ndarray.ndim()); + for (int i = 0; i < ndarray.ndim(); ++i) { + shapes.push_back(ndarray.shape(i)); + } + xla::Array array(shapes); + array.Each([&](absl::Span indices, int64_t* val) { + int64_t offset = indices.back(); + int64_t multiplier = 1; + for (int i = ndarray.ndim() - 1; i > 0; --i) { + multiplier *= ndarray.shape(i); + offset += indices[i - 1] * multiplier; + } + *val = *(ndarray.data() + offset); + }); + return array; +} + +absl::StatusOr SubgroupWithTileAssignmentHelper( + nb::ndarray tile_assignment, + absl::Span subgroup_types) { + return HloSharding::Subgroup(NDArrayToArray(tile_assignment), subgroup_types); +} + +nb::ndarray<> LiteralToNdarray(Literal& obj) { + const Shape& shape = obj.shape(); + + if (!shape.has_layout()) { + throw XlaRuntimeError( + "Creating an array is only supported for Literals with a layout."); + } + + const Layout& layout = shape.layout(); + + if (!layout.tiles().empty()) { + throw XlaRuntimeError( + "Creating an array from a tiled Literal is not supported."); + } + + if (!shape.IsArray()) { + throw XlaRuntimeError( + "Creating an array is only supported for dense Literals."); + } + + xla::PrimitiveType primitive_type = shape.element_type(); + nb::dlpack::dtype dtype = + ValueOrThrow(jax::PrimitiveTypeToNbDLDataType(primitive_type)); + + absl::Span dimensions = shape.dimensions(); + std::vector unsigned_dimensions(dimensions.begin(), dimensions.end()); + auto strides = StridesForShape(primitive_type, dimensions, layout); + + return nb::ndarray<>(obj.untyped_data(), unsigned_dimensions.size(), + unsigned_dimensions.data(), {}, strides.data(), dtype, + nb::device::cpu::value, 0); +} + +struct Descriptor {}; + +} // namespace + +void BuildXlaCompilerSubmodule(nb::module_& m) { + // Types + nb::enum_(m, "PrimitiveType", nb::is_arithmetic()) + .value("PRIMITIVE_TYPE_INVALID", PRIMITIVE_TYPE_INVALID) + .value("PRED", PRED) + .value("S4", S4) + .value("S8", S8) + .value("S16", S16) + .value("S32", S32) + .value("S64", S64) + .value("U4", U4) + .value("U8", U8) + .value("U16", U16) + .value("U32", U32) + .value("U64", U64) + .value("F16", F16) + .value("F4E2M1FN", F4E2M1FN) + .value("F8E3M4", F8E3M4) + .value("F8E4M3", F8E4M3) + .value("F8E4M3FN", F8E4M3FN) + .value("F8E4M3B11FNUZ", F8E4M3B11FNUZ) + .value("F8E4M3FNUZ", F8E4M3FNUZ) + .value("F8E5M2", F8E5M2) + .value("F8E5M2FNUZ", F8E5M2FNUZ) + .value("F8E8M0FNU", F8E8M0FNU) + .value("BF16", BF16) + .value("F32", F32) + .value("F64", F64) + .value("C64", C64) + .value("C128", C128) + .value("TUPLE", TUPLE) + .value("OPAQUE_TYPE", OPAQUE_TYPE) + .value("TOKEN", TOKEN); + + // Shapes + nb::class_ layout_class(m, "Layout"); + layout_class.def(nb::init>()) + .def("__init__", + [](Layout* self, nb::typed minor_to_major, + nb::typed> + tiling, + int64_t element_size_in_bits) { + std::vector xla_tiles; + xla_tiles.reserve(nb::len(tiling.ptr())); + for (auto tile : tiling) { + xla_tiles.push_back(Tile( + SequenceToVector(nb::cast(tile)))); + } + std::vector xla_minor_to_major = + SequenceToVector(minor_to_major); + new (self) + Layout(xla_minor_to_major, xla_tiles, element_size_in_bits); + }) + .def("minor_to_major", + [](Layout layout) { return SpanToNbTuple(layout.minor_to_major()); }) + .def("element_size_in_bits", &Layout::element_size_in_bits) + .def("tiling", + [](Layout layout) { + std::vector> result; + result.reserve(layout.tiles().size()); + for (auto& t : layout.tiles()) { + result.push_back(SpanToNbTuple(t.dimensions())); + } + return result; + }) + .def( + "__eq__", + [](const Layout& layout, const Layout& other) { + return layout == other; + }, + nb::is_operator(), + nb::sig("def __eq__(self, other: object, /) -> bool")) + .def( + "__ne__", + [](const Layout& layout, const Layout& other) { + return layout != other; + }, + nb::is_operator(), + nb::sig("def __ne__(self, other: object, /) -> bool")) + .def("__str__", &Layout::ToString) + .def("__hash__", + [](const Layout& layout) { return absl::HashOf(layout); }) + .def("to_string", &Layout::ToString) + .def("__getstate__", + [](const Layout& self) -> nb::tuple { + auto proto = self.ToProto(); + std::string result; + if (!tsl::SerializeToStringDeterministic(proto, &result)) { + // throw converted by PyBind to a Python RuntimeError. + throw XlaRuntimeError( + absl::StrCat("Layout.py_pickle: ", + "SerializeToStringDeterministic failed")); + } + return nb::make_tuple(nb::bytes(result.data(), result.size())); + }) + .def("__setstate__", [](Layout* self, nb::tuple t) { + LayoutProto result; + nb::bytes serialized = nb::cast(t[0]); + result.ParseFromArray(serialized.c_str(), serialized.size()); + new (self) Layout(ValueOrThrow(Layout::FromProto(result))); + }); + + nb::class_ shape_class(m, "Shape"); + shape_class + .def("__init__", + [](Shape* self, const std::string& s) { + new (self) Shape(ValueOrThrow(ParseShape(s))); + }) + .def_static( + "tuple_shape", + [](std::vector shapes) -> Shape { + return ShapeUtil::MakeTupleShape(shapes); + }, + "Constructs a tuple shape.") + .def_static( + "array_shape", + xla::ValueOrThrowWrapper( + [](PrimitiveType type, nb::typed dims_seq, + std::optional> layout_seq, + std::optional> dynamic_dimensions) + -> absl::StatusOr { + std::vector dims = SequenceToVector(dims_seq); + if (layout_seq) { + std::vector layout = + SequenceToVector(*layout_seq); + return MakeShapeWithDenseLayout(type, dims, layout, + dynamic_dimensions); + } else { + return MakeShapeWithDenseLayout(type, dims, std::nullopt, + dynamic_dimensions); + } + }), + "Constructs an array shape.", nb::arg("type"), nb::arg("dims"), + nb::arg("layout").none() = std::nullopt, + nb::arg("dynamic_dimensions").none() = std::nullopt) + .def_static( + "array_shape", + xla::ValueOrThrowWrapper( + [](nb_dtype dtype, nb::typed dims_seq, + std::optional> layout_seq, + std::optional> dynamic_dimensions) + -> absl::StatusOr { + PrimitiveType type = ValueOrThrow(DtypeToPrimitiveType(dtype)); + std::vector dims = SequenceToVector(dims_seq); + if (layout_seq) { + std::vector layout = + SequenceToVector(*layout_seq); + return MakeShapeWithDenseLayout(type, dims, layout, + dynamic_dimensions); + } else { + return MakeShapeWithDenseLayout(type, dims, std::nullopt, + dynamic_dimensions); + } + }), + "Constructs an array shape.", nb::arg("type"), nb::arg("dims"), + nb::arg("layout").none() = std::nullopt, + nb::arg("dynamic_dimensions").none() = std::nullopt) + .def_static("token_shape", []() { return ShapeUtil::MakeTokenShape(); }) + .def_static( + "scalar_shape", + [](PrimitiveType type) -> Shape { + return ShapeUtil::MakeScalarShape(type); + }, + "Constructs a scalar shape.", nb::arg("type")) + .def_static( + "scalar_shape", + [](nb_dtype dtype) -> Shape { + PrimitiveType type = xla::ValueOrThrow(DtypeToPrimitiveType(dtype)); + return ShapeUtil::MakeScalarShape(type); + }, + "Constructs a scalar shape.", nb::arg("type")) + .def("dimensions", + [](const Shape& shape) { return SpanToNbTuple(shape.dimensions()); }) + .def("layout", + [](const Shape& shape) -> Layout { return shape.layout(); }) + .def("xla_element_type", &Shape::element_type) + .def("element_type", + [](const Shape& shape) { + return xla::ValueOrThrow( + PrimitiveTypeToNbDtype(shape.element_type())); + }) + .def("numpy_dtype", + [](const Shape& shape) { + if (shape.IsTuple()) { + return nb_dtype("O"); + } + return xla::ValueOrThrow( + PrimitiveTypeToNbDtype(shape.element_type())); + }) + .def("is_tuple", &Shape::IsTuple) + .def("is_array", &Shape::IsArray) + .def("is_token", &Shape::IsToken) + .def("is_static", &Shape::is_static) + .def("is_dynamic", &Shape::is_dynamic) + .def("is_dynamic_dimension", &Shape::is_dynamic_dimension, + nb::arg("dimension")) + .def("set_dynamic_dimension", &Shape::set_dynamic_dimension, + nb::arg("dimension"), nb::arg("is_dynamic")) + .def("rank", &Shape::dimensions_size) + .def("to_serialized_proto", + [](const Shape& shape) { + ShapeProto proto = shape.ToProto(); + std::string s = proto.SerializeAsString(); + return nb::bytes(s.data(), s.size()); + }) + .def("tuple_shapes", + [](const Shape& shape) { + return std::vector(shape.tuple_shapes()); + }) + .def("leaf_count", + [](const Shape& shape) { return ShapeUtil::GetLeafCount(shape); }) + .def( + "with_major_to_minor_layout_if_absent", + [](const Shape& shape) { + Shape out = shape; + ShapeUtil::ForEachMutableSubshape( + &out, [](Shape* subshape, const ShapeIndex&) { + if (!subshape->has_layout()) { + LayoutUtil::SetToDefaultLayout(subshape); + } + }); + return out; + }, + "Returns a copy of a shape with missing layouts set to " + "major-to-minor.") + .def( + "__eq__", + [](const Shape& shape, const Shape& other) { return shape == other; }, + nb::is_operator(), + nb::sig("def __eq__(self, other: object, /) -> bool")) + .def( + "__ne__", + [](const Shape& shape, const Shape& other) { return shape != other; }, + nb::is_operator(), + nb::sig("def __ne__(self, other: object, /) -> bool")) + .def("__hash__", [](const Shape& shape) { return absl::HashOf(shape); }) + .def("__repr__", [](const Shape& shape) { + return shape.ToString(/*print_layout=*/true); + }); + + nb::class_(m, "ProgramShape") + .def( + "__init__", + [](ProgramShape* self, absl::Span params, Shape result) { + new (self) ProgramShape(); + for (const Shape& param : params) { + self->AddParameter(param, ""); + } + *self->mutable_result() = result; + }) + .def("parameter_shapes", + static_cast& (ProgramShape::*)() const>( + &ProgramShape::parameters)) + .def("result_shape", &ProgramShape::result) + .def("__repr__", &ProgramShape::ToString); + + // Literals + nb::class_(m, "Literal") + .def(nb::init()) + .def("__repr__", &Literal::ToString) + .def( + "__array__", + [](std::shared_ptr obj, std::optional dtype, + std::optional copy) { + // Provides the interface required by numpy to create a np.ndarray. + // Currently don't support the __dl_pack__ interface but can be + // added with very little effort it if needed. + + nb::ndarray np_array(LiteralToNdarray(*obj)); + + if (dtype.has_value()) { + throw XlaRuntimeError( + "Passing of dtype to __array__ not currently supported."); + } + + if (copy.has_value() && *copy) { + // when a copy is requested we _must_ return a copy: + // https://numpy.org/doc/2.1/reference/generated/numpy.ndarray.__array__.html + return np_array.cast(nb::rv_policy::copy); + } + + return np_array.cast(nb::rv_policy::reference_internal, + nb::cast(obj)); + }, + nb::arg("dtype").none() = nb::none(), + nb::arg("copy").none() = nb::none()) + .def("shape", &Literal::shape); + + nb::class_(m, "XlaComputation") + .def("__init__", + [](XlaComputation* self, + const nb::bytes& serialized_hlo_module_proto) { + HloModuleProto proto; + proto.ParseFromArray(serialized_hlo_module_proto.c_str(), + serialized_hlo_module_proto.size()); + new (self) XlaComputation(proto); + }) + .def("get_hlo_module", xla::ValueOrThrowWrapper(GetHloModule)) + .def("program_shape", + xla::ValueOrThrowWrapper(&XlaComputation::GetProgramShape)) + .def("name", &XlaComputation::name) + .def("as_serialized_hlo_module_proto", + xla::ValueOrThrowWrapper(GetComputationSerializedProto)) + .def("as_hlo_text", xla::ValueOrThrowWrapper(GetComputationHloText), + nb::arg("print_large_constants") = false) + .def("as_hlo_dot_graph", + xla::ValueOrThrowWrapper(GetComputationHloDotGraph)) + .def("hash", xla::ValueOrThrowWrapper(HashComputation)) + .def("as_hlo_module", xla::ValueOrThrowWrapper(GetHloModule)); + + nb::class_ hlo_print_options_class(m, "HloPrintOptions"); + hlo_print_options_class.def(nb::init<>()) + .def_static("short_parsable", &HloPrintOptions::ShortParsable) + .def_static("canonical", &HloPrintOptions::Canonical) + .def_static("fingerprint", &HloPrintOptions::Fingerprint) + .def_prop_rw("print_large_constants", + &HloPrintOptions::print_large_constants, + &HloPrintOptions::set_print_large_constants) + .def_prop_rw("print_metadata", &HloPrintOptions::print_metadata, + &HloPrintOptions::set_print_metadata) + .def_prop_rw("print_backend_config", + &HloPrintOptions::print_backend_config, + &HloPrintOptions::set_print_backend_config) + .def_prop_rw("print_result_shape", &HloPrintOptions::print_result_shape, + &HloPrintOptions::set_print_result_shape) + .def_prop_rw("print_operand_shape", &HloPrintOptions::print_operand_shape, + &HloPrintOptions::set_print_operand_shape) + .def_prop_rw("print_operand_names", &HloPrintOptions::print_operand_names, + &HloPrintOptions::set_print_operand_names) + .def_prop_rw("print_ids", &HloPrintOptions::print_ids, + &HloPrintOptions::set_print_ids) + .def_prop_rw("print_extra_attributes", + &HloPrintOptions::print_extra_attributes, + &HloPrintOptions::set_print_extra_attributes) + .def_prop_rw("print_program_shape", &HloPrintOptions::print_program_shape, + &HloPrintOptions::set_print_program_shape) + .def_prop_rw("print_percent", &HloPrintOptions::print_percent, + &HloPrintOptions::set_print_percent) + .def_prop_rw("print_control_dependencies", + &HloPrintOptions::print_control_dependencies, + &HloPrintOptions::set_print_control_dependencies) + .def_prop_rw("compact_operands", &HloPrintOptions::compact_operands, + &HloPrintOptions::set_compact_operands) + .def_prop_rw("include_layout_in_shapes", + &HloPrintOptions::include_layout_in_shapes, + &HloPrintOptions::set_include_layout_in_shapes) + .def_prop_rw("canonicalize_instruction_names", + &HloPrintOptions::canonicalize_instruction_names, + &HloPrintOptions::set_canonicalize_instruction_names) + .def_prop_rw("canonicalize_computations", + &HloPrintOptions::canonicalize_computations, + &HloPrintOptions::set_canonicalize_computations) + .def_prop_rw("indent_amount", &HloPrintOptions::indent_amount, + &HloPrintOptions::set_indent_amount) + .def_prop_rw("is_in_nested_computation", + &HloPrintOptions::is_in_nested_computation, + &HloPrintOptions::set_is_in_nested_computation); + + // HloModule.computations() returns raw pointers. + // pybind seems to prefer smart pointers. + // We give pybind a smart pointer to a wrapper around a raw pointer to satisfy + // pybind and avoid double frees. + class ComputationWrapper { + public: + ComputationWrapper(const HloComputation* comp, + const std::shared_ptr module) + : comp_(comp), module_(module) {} + std::string_view name() const { return comp_->name(); } + void render_html(const std::string& filename) { + std::string html = xla::ValueOrThrow(RenderGraph( + *comp_, /*label=*/"", comp_->parent()->config().debug_options(), + RenderedGraphFormat::kHtml, HloRenderOptions())); + xla::ThrowIfError(tsl::WriteStringToFile( + tsl::Env::Default(), absl::StrCat(filename, ".html"), html)); + } + + private: + const HloComputation* comp_; + // The module owns the computations: if its destructor is called, the + // computations are freed. To prevent that from happening in cases where the + // module Python object goes out of scope and gets garbage collected before + // the computations, we keep a shared_ptr to the module that originated the + // computation. + const std::shared_ptr module_; + }; + + nb::class_ hlo_computation_class(m, "HloComputation"); + + hlo_computation_class.def_prop_ro("name", &ComputationWrapper::name) + .def("render_html", &ComputationWrapper::render_html); + + nb::class_ hlo_module_class(m, "HloModule"); + hlo_module_class.def_prop_ro("name", &HloModule::name) + .def( + "to_string", + static_cast( + &HloModule::ToString), + nb::arg("options") = HloPrintOptions()) + .def("as_serialized_hlo_module_proto", + xla::ValueOrThrowWrapper(GetHloModuleSerializedProto)) + .def("from_serialized_hlo_module_proto", + xla::ValueOrThrowWrapper(HloModuleFromSerializedProto)) + .def("computations", + [](const std::shared_ptr m) + -> std::vector> { + std::vector> computations; + for (HloComputation* comp : m->computations()) + computations.push_back( + std::make_shared(comp, m)); + return computations; + }) + .def_prop_ro("spmd_output_sharding", + [](const HloModule& m) -> std::optional { + if (!m.has_spmd_output_sharding()) return std::nullopt; + return m.spmd_output_sharding().ToProto(); + }) + .def_prop_ro("spmd_parameters_shardings", + [](const HloModule& m) + -> std::optional> { + if (!m.has_spmd_parameters_shardings()) + return std::nullopt; + std::vector param_shardings; + for (const auto& parameter_sharding : + m.spmd_parameters_shardings()) { + param_shardings.push_back(parameter_sharding.ToProto()); + } + return param_shardings; + }); + + m.def("hlo_module_to_dot_graph", + [](const HloModule& hlo_module) -> std::string { + return xla::ValueOrThrow(RenderGraph( + *hlo_module.entry_computation(), /*label=*/"", + hlo_module.config().debug_options(), RenderedGraphFormat::kDot)); + }); + m.def( + "hlo_module_cost_analysis", + xla::ValueOrThrowWrapper([](jax::PyClient* client, + const HloModule& module) + -> absl::StatusOr { + TF_ASSIGN_OR_RETURN(auto analysis, + client->pjrt_client()->GetHloCostAnalysis()); + TF_RETURN_IF_ERROR(module.entry_computation()->Accept(analysis.get())); + + // Convert from HloCostAnalysis::Properties to a standard map. + nb::dict ret; + analysis->properties().ForEach([&](std::string_view key, float val) { + ret[nb::str(key.data(), key.size())] = nb::cast(val); + }); + return ret; + })); + m.def("hlo_module_from_text", + xla::ValueOrThrowWrapper( + [](const std::string& hlo_module_text) + -> absl::StatusOr> { + auto hlo_module = + xla::ParseAndReturnUnverifiedModule(hlo_module_text); + TF_RETURN_IF_ERROR(hlo_module.status()); + std::shared_ptr result(std::move(*hlo_module)); + return result; + })); + + // Device assignments + nb::class_(m, "DeviceAssignment") + .def_static( + "create", + xla::ValueOrThrowWrapper([](nb::ndarray> array) + -> absl::StatusOr { + if (array.ndim() != 2) { + return InvalidArgument( + "Argument to DeviceAssignment constructor must be a " + "2D array, received an %dD array.", + array.ndim()); + } + DeviceAssignment result(array.shape(0), array.shape(1)); + for (int i = 0; i < array.shape(0); ++i) { + for (int j = 0; j < array.shape(1); ++j) { + result(i, j) = array(i, j); + } + } + return result; + })) + .def("replica_count", &DeviceAssignment::replica_count) + .def("computation_count", &DeviceAssignment::computation_count) + .def("__repr__", &DeviceAssignment::ToString) + .def("serialize", + xla::ValueOrThrowWrapper( + [](const DeviceAssignment& da) -> absl::StatusOr { + DeviceAssignmentProto proto; + da.Serialize(&proto); + std::string result; + if (!tsl::SerializeToStringDeterministic(proto, &result)) { + return Unknown( + "Failed to serialize the DeviceAssignmentProto."); + } + return nb::bytes(result.data(), result.size()); + })); + + nb::class_ compile_options(m, "CompileOptions"); + compile_options + .def("__init__", + [](CompileOptions* self) { + new (self) CompileOptions(); + DebugOptions* debug_options = + self->executable_build_options.mutable_debug_options(); + // Sets fast-math-disabling default options expected by JAX. + debug_options->set_xla_cpu_enable_fast_min_max(false); + debug_options->set_xla_gpu_enable_fast_min_max(false); + }) + .def("__getstate__", + [](const CompileOptions& self) -> nb::tuple { + auto proto = ValueOrThrow(self.ToProto()); + std::string result; + if (!tsl::SerializeToStringDeterministic(proto, &result)) { + // throw converted by PyBind to a Python RuntimeError. + throw XlaRuntimeError( + absl::StrCat("CompileOptions.py_pickle: ", + "SerializeToStringDeterministic failed")); + } + return nb::make_tuple(nb::bytes(result.data(), result.size())); + }) + .def("__setstate__", + [](CompileOptions* self, nb::tuple t) { + CompileOptionsProto result; + nb::bytes serialized = nb::cast(t[0]); + result.ParseFromArray(serialized.c_str(), serialized.size()); + new (self) CompileOptions( + ValueOrThrow(CompileOptions::FromProto(result))); + }) + .def("SerializeAsString", + [](const CompileOptions& self) -> nb::bytes { + auto proto = ValueOrThrow(self.ToProto()); + std::string result; + if (!tsl::SerializeToStringDeterministic(proto, &result)) { + // throw converted by PyBind to a Python RuntimeError. + throw XlaRuntimeError( + absl::StrCat("CompileOptions.SerializeAsString: ", + "SerializeToStringDeterministic failed")); + } + return nb::bytes(result.data(), result.size()); + }) + .def_static("ParseFromString", + [](nb::bytes s) { + CompileOptionsProto result; + result.ParseFromArray(s.c_str(), s.size()); + return ValueOrThrow(CompileOptions::FromProto(result)); + }) + .def_rw("argument_layouts", &CompileOptions::argument_layouts) + .def_rw("parameter_is_tupled_arguments", + &CompileOptions::parameter_is_tupled_arguments) + .def_rw("compile_portable_executable", + &CompileOptions::compile_portable_executable) + .def_ro("executable_build_options", + &CompileOptions::executable_build_options) + .def_rw("env_option_overrides", &CompileOptions::env_option_overrides) + .def_prop_rw( + "num_replicas", + [](const CompileOptions& options) { + return options.executable_build_options.num_replicas(); + }, + [](CompileOptions& options, int num_replicas) { + options.executable_build_options.set_num_replicas(num_replicas); + }) + .def_prop_rw( + "num_partitions", + [](const CompileOptions& options) { + return options.executable_build_options.num_partitions(); + }, + [](CompileOptions& options, int num_partitions) { + options.executable_build_options.set_num_partitions(num_partitions); + }) + .def_prop_rw( + "profile_version", + [](const CompileOptions& options) { return options.profile_version; }, + [](CompileOptions& options, int64_t profile_version) { + options.profile_version = profile_version; + }) + .def_prop_rw( + "device_assignment", + [](const CompileOptions& options) -> std::optional { + return options.executable_build_options.has_device_assignment() + ? std::optional( + options.executable_build_options + .device_assignment()) + : std::nullopt; + }, + [](CompileOptions& options, + const DeviceAssignment& device_assignment) { + options.executable_build_options.set_device_assignment( + device_assignment); + }); + + nb::enum_(m, "AutotuneCacheMode") + .value("UNSPECIFIED", DebugOptions::AUTOTUNE_CACHE_MODE_UNSPECIFIED) + .value("UPDATE", DebugOptions::AUTOTUNE_CACHE_MODE_UPDATE) + .value("READ", DebugOptions::AUTOTUNE_CACHE_MODE_READ); + + nb::class_(m, "DebugOptions") + .def("__repr__", &DebugOptions::DebugString) + .def_prop_rw("xla_backend_optimization_level", + &DebugOptions::xla_backend_optimization_level, + &DebugOptions::set_xla_backend_optimization_level) + .def_prop_rw("xla_cpu_enable_fast_math", + &DebugOptions::xla_cpu_enable_fast_math, + &DebugOptions::set_xla_cpu_enable_fast_math) + .def_prop_rw("xla_cpu_enable_xprof_traceme", + &DebugOptions::xla_cpu_enable_xprof_traceme, + &DebugOptions::set_xla_cpu_enable_xprof_traceme) + .def_prop_rw("xla_cpu_fast_math_honor_infs", + &DebugOptions::xla_cpu_fast_math_honor_infs, + &DebugOptions::set_xla_cpu_fast_math_honor_infs) + .def_prop_rw("xla_cpu_fast_math_honor_nans", + &DebugOptions::xla_cpu_fast_math_honor_nans, + &DebugOptions::set_xla_cpu_fast_math_honor_nans) + .def_prop_rw("xla_cpu_fast_math_honor_division", + &DebugOptions::xla_cpu_fast_math_honor_division, + &DebugOptions::set_xla_cpu_fast_math_honor_division) + .def_prop_rw("xla_cpu_fast_math_honor_functions", + &DebugOptions::xla_cpu_fast_math_honor_functions, + &DebugOptions::set_xla_cpu_fast_math_honor_functions) + .def_prop_rw("xla_detailed_logging", &DebugOptions::xla_detailed_logging, + &DebugOptions::set_xla_detailed_logging) + .def_prop_rw("xla_enable_dumping", &DebugOptions::xla_enable_dumping, + &DebugOptions::set_xla_enable_dumping) + .def_prop_rw("xla_gpu_enable_fast_min_max", + &DebugOptions::xla_gpu_enable_fast_min_max, + &DebugOptions::set_xla_gpu_enable_fast_min_max) + .def_prop_rw("xla_gpu_dump_autotune_results_to", + &DebugOptions::xla_gpu_dump_autotune_results_to, + [](DebugOptions* self, std::string value) { + self->set_xla_gpu_dump_autotune_results_to(value); + }) + .def_prop_rw("xla_gpu_load_autotune_results_from", + &DebugOptions::xla_gpu_load_autotune_results_from, + [](DebugOptions* self, std::string value) { + self->set_xla_gpu_load_autotune_results_from(value); + }) + .def_prop_rw("xla_gpu_cuda_data_dir", + &DebugOptions::xla_gpu_cuda_data_dir, + [](DebugOptions* self, std::string value) { + self->set_xla_gpu_cuda_data_dir(value); + }) + .def_prop_rw("xla_llvm_disable_expensive_passes", + &DebugOptions::xla_llvm_disable_expensive_passes, + &DebugOptions::set_xla_llvm_disable_expensive_passes) + .def_prop_rw( + "xla_disable_hlo_passes", + [](DebugOptions* self) { + return absl::StrJoin(self->xla_disable_hlo_passes(), ","); + }, + [](DebugOptions* self, std::string value) { + self->clear_xla_disable_hlo_passes(); + for (const auto& passname : + std::vector(absl::StrSplit(value, ','))) { + self->add_xla_disable_hlo_passes(passname); + } + }) + .def_prop_rw( + "xla_enable_hlo_passes_only", + [](DebugOptions* self) { + return absl::StrJoin(self->xla_enable_hlo_passes_only(), ","); + }, + [](DebugOptions* self, std::string value) { + self->clear_xla_enable_hlo_passes_only(); + for (const auto& passname : + std::vector(absl::StrSplit(value, ','))) { + self->add_xla_enable_hlo_passes_only(passname); + } + }) + .def_prop_rw("xla_test_all_input_layouts", + &DebugOptions::xla_test_all_input_layouts, + &DebugOptions::set_xla_test_all_input_layouts) + .def_prop_rw("xla_force_host_platform_device_count", + &DebugOptions::xla_force_host_platform_device_count, + &DebugOptions::set_xla_force_host_platform_device_count) + .def_prop_rw("xla_dump_to", &DebugOptions::xla_dump_to, + [](DebugOptions* self, std::string value) { + self->set_xla_dump_to(value); + }) + .def_prop_rw("xla_dump_hlo_module_re", + &DebugOptions::xla_dump_hlo_module_re, + [](DebugOptions* self, std::string value) { + self->set_xla_dump_hlo_module_re(value); + }) + .def_prop_rw("xla_dump_hlo_pass_re", &DebugOptions::xla_dump_hlo_pass_re, + [](DebugOptions* self, std::string value) { + self->set_xla_dump_hlo_pass_re(value); + }) + .def_prop_rw("xla_dump_hlo_as_text", &DebugOptions::xla_dump_hlo_as_text, + &DebugOptions::set_xla_dump_hlo_as_text) + .def_prop_rw("xla_dump_hlo_as_proto", + &DebugOptions::xla_dump_hlo_as_proto, + &DebugOptions::set_xla_dump_hlo_as_proto) + .def_prop_rw("xla_dump_hlo_as_dot", &DebugOptions::xla_dump_hlo_as_dot, + &DebugOptions::set_xla_dump_hlo_as_dot) + .def_prop_rw("xla_dump_hlo_as_url", &DebugOptions::xla_dump_hlo_as_url, + &DebugOptions::set_xla_dump_hlo_as_url) + .def_prop_rw("xla_dump_hlo_as_html", &DebugOptions::xla_dump_hlo_as_html, + &DebugOptions::set_xla_dump_hlo_as_html) + .def_prop_rw("xla_dump_fusion_visualization", + &DebugOptions::xla_dump_fusion_visualization, + &DebugOptions::set_xla_dump_fusion_visualization) + .def_prop_rw("xla_dump_hlo_snapshots", + &DebugOptions::xla_dump_hlo_snapshots, + &DebugOptions::set_xla_dump_hlo_snapshots) + .def_prop_rw("xla_dump_max_hlo_modules", + &DebugOptions::xla_dump_max_hlo_modules, + &DebugOptions::set_xla_dump_max_hlo_modules) + .def_prop_rw("xla_dump_module_metadata", + &DebugOptions::xla_dump_module_metadata, + &DebugOptions::set_xla_dump_module_metadata) + .def_prop_rw("xla_dump_compress_protos", + &DebugOptions::xla_dump_compress_protos, + &DebugOptions::set_xla_dump_compress_protos) + .def_prop_rw("xla_dump_hlo_as_long_text", + &DebugOptions::xla_dump_hlo_as_long_text, + &DebugOptions::set_xla_dump_hlo_as_long_text) + .def_prop_rw("xla_dump_disable_metadata", + &DebugOptions::xla_dump_disable_metadata, + &DebugOptions::set_xla_dump_disable_metadata) + .def_prop_rw("xla_dump_hlo_pipeline_re", + &DebugOptions::xla_dump_hlo_pipeline_re, + [](DebugOptions* self, std::string value) { + self->set_xla_dump_hlo_pipeline_re(value); + }) + .def_prop_rw("xla_gpu_dump_autotune_logs_to", + &DebugOptions::xla_gpu_dump_autotune_logs_to, + [](DebugOptions* self, std::string value) { + self->set_xla_gpu_dump_autotune_logs_to(value); + }) + .def_prop_rw("xla_gpu_kernel_cache_file", + &DebugOptions::xla_gpu_kernel_cache_file, + [](DebugOptions* self, std::string value) { + self->set_xla_gpu_kernel_cache_file(value); + }) + .def_prop_rw( + "xla_gpu_enable_llvm_module_compilation_parallelism", + &DebugOptions::xla_gpu_enable_llvm_module_compilation_parallelism, + &DebugOptions::set_xla_gpu_enable_llvm_module_compilation_parallelism) + .def_prop_rw("xla_gpu_per_fusion_autotune_cache_dir", + &DebugOptions::xla_gpu_per_fusion_autotune_cache_dir, + [](DebugOptions* self, std::string value) { + self->set_xla_gpu_per_fusion_autotune_cache_dir(value); + }) + .def_prop_rw("xla_gpu_experimental_autotune_cache_mode", + &DebugOptions::xla_gpu_experimental_autotune_cache_mode, + &DebugOptions::set_xla_gpu_experimental_autotune_cache_mode); + + nb::class_(m, "ExecutableBuildOptions") + .def(nb::init<>()) + .def("__repr__", &ExecutableBuildOptions::ToString) + .def_prop_rw( + "fdo_profile", + [](const ExecutableBuildOptions& options) { + return nb::bytes(options.fdo_profile().data(), + options.fdo_profile().size()); + }, + [](ExecutableBuildOptions& options, nb::bytes fdo_profile) { + options.set_fdo_profile( + std::string(fdo_profile.c_str(), fdo_profile.size())); + }) + .def_prop_rw( + "result_layout", + [](const ExecutableBuildOptions& options) -> std::optional { + return options.result_layout() + ? std::optional(*options.result_layout()) + : std::nullopt; + }, + &ExecutableBuildOptions::set_result_layout) + .def_prop_rw("num_replicas", &ExecutableBuildOptions::num_replicas, + &ExecutableBuildOptions::set_num_replicas) + .def_prop_rw("num_partitions", &ExecutableBuildOptions::num_partitions, + &ExecutableBuildOptions::set_num_partitions) + .def_prop_ro("debug_options", + &ExecutableBuildOptions::mutable_debug_options, + nb::rv_policy::reference, nb::keep_alive<1, 0>()) + .def_prop_rw( + "device_assignment", + [](const ExecutableBuildOptions& options) + -> std::optional { + return options.has_device_assignment() + ? std::optional( + options.device_assignment()) + : std::nullopt; + }, + &ExecutableBuildOptions::set_device_assignment) + .def("compilation_environments_from_serialized_proto", + [](ExecutableBuildOptions& options, + const nb::bytes& serialized_proto) { + xla::CompilationEnvironmentsProto env_proto; + env_proto.ParseFromArray(serialized_proto.c_str(), + serialized_proto.size()); + auto comp_envs = xla::ValueOrThrow( + xla::CompilationEnvironments::CreateFromProto(env_proto)); + *options.mutable_comp_envs() = std::move(*comp_envs); + }) + .def_prop_rw("exec_time_optimization_effort", + &ExecutableBuildOptions::exec_time_optimization_effort, + &ExecutableBuildOptions::set_exec_time_optimization_effort) + .def_prop_rw("memory_fitting_effort", + &ExecutableBuildOptions::memory_fitting_effort, + &ExecutableBuildOptions::set_memory_fitting_effort) + .def_prop_rw( + "optimization_level", + [](ExecutableBuildOptions& options) { + return static_cast(options.optimization_level()); + }, + [](ExecutableBuildOptions& options, int value) { + options.set_optimization_level( + static_cast(value)); + }) + .def_prop_rw( + "memory_fitting_level", + [](ExecutableBuildOptions& options) { + return static_cast(options.memory_fitting_level()); + }, + [](ExecutableBuildOptions& options, int value) { + options.set_memory_fitting_level( + static_cast(value)); + }) + .def_prop_rw("use_spmd_partitioning", + &ExecutableBuildOptions::use_spmd_partitioning, + &ExecutableBuildOptions::set_use_spmd_partitioning) + .def_prop_rw("use_auto_spmd_partitioning", + &ExecutableBuildOptions::use_auto_spmd_partitioning, + &ExecutableBuildOptions::set_use_auto_spmd_partitioning) + .def_prop_rw( + "auto_spmd_partitioning_mesh_shape", + &ExecutableBuildOptions::auto_spmd_partitioning_mesh_shape, + &ExecutableBuildOptions::set_auto_spmd_partitioning_mesh_shape) + .def_prop_rw("auto_spmd_partitioning_mesh_ids", + &ExecutableBuildOptions::auto_spmd_partitioning_mesh_ids, + &ExecutableBuildOptions::set_auto_spmd_partitioning_mesh_ids) + .def_prop_rw( + "allow_spmd_sharding_propagation_to_parameters", + [](const ExecutableBuildOptions& options) -> std::vector { + return std::vector( + options.allow_spmd_sharding_propagation_to_parameters().begin(), + options.allow_spmd_sharding_propagation_to_parameters().end()); + }, + [](ExecutableBuildOptions& options, std::vector values) { + absl::InlinedVector v(values.begin(), values.end()); + options.set_allow_spmd_sharding_propagation_to_parameters(v); + }) + .def_prop_rw( + "allow_spmd_sharding_propagation_to_output", + [](const ExecutableBuildOptions& options) -> std::vector { + return std::vector( + options.allow_spmd_sharding_propagation_to_output().begin(), + options.allow_spmd_sharding_propagation_to_output().end()); + }, + [](ExecutableBuildOptions& options, std::vector values) { + absl::InlinedVector v(values.begin(), values.end()); + options.set_allow_spmd_sharding_propagation_to_output(v); + }) + .def_prop_rw("use_shardy_partitioner", + &ExecutableBuildOptions::use_shardy_partitioner, + &ExecutableBuildOptions::set_use_shardy_partitioner); + + nb::enum_ op_sharding_type(m, "OpSharding_Type", + nb::is_arithmetic()); + op_sharding_type.value("REPLICATED", OpSharding::REPLICATED) + .value("MAXIMAL", OpSharding::MAXIMAL) + .value("MANUAL", OpSharding::MANUAL) + .value("UNREDUCED", OpSharding::UNREDUCED) + .value("TUPLE", OpSharding::TUPLE) + .value("OTHER", OpSharding::OTHER) + .value("UNKNOWN", OpSharding::UNKNOWN); + + nb::enum_ op_sharding_shard_group_type( + m, "OpSharding_ShardGroupType"); + op_sharding_shard_group_type.value("AS", OpSharding::AS) + .value("LIKE", OpSharding::LIKE); + + nb::class_ op_sharding(m, "OpSharding"); + op_sharding.attr("Type") = op_sharding_type; + op_sharding.attr("ShardGroupType") = op_sharding_shard_group_type; + op_sharding.def(nb::init<>()) + .def("__getstate__", + [](const OpSharding& self) { + std::string serialized = self.SerializeAsString(); + return nb::make_tuple( + nb::bytes(serialized.data(), serialized.size())); + }) + .def("__setstate__", + [](OpSharding* self, nb::tuple t) { + new (self) OpSharding(); + nb::bytes serialized = nb::cast(t[0]); + self->ParseFromArray(serialized.c_str(), serialized.size()); + }) + .def_prop_rw("type", &xla::OpSharding::type, &xla::OpSharding::set_type) + .def_prop_rw("replicate_on_last_tile_dim", + &xla::OpSharding::replicate_on_last_tile_dim, + &xla::OpSharding::set_replicate_on_last_tile_dim) + .def_prop_rw("is_shard_group", &xla::OpSharding::is_shard_group, + &xla::OpSharding::set_is_shard_group) + .def_prop_rw("shard_group_id", &xla::OpSharding::shard_group_id, + &xla::OpSharding::set_shard_group_id) + .def_prop_rw("shard_group_type", &xla::OpSharding::shard_group_type, + &xla::OpSharding::set_shard_group_type) + .def("__repr__", + [](const xla::OpSharding& self) { return self.DebugString(); }) + .def("ParseFromString", + [](OpSharding& sharding, const nb::bytes& s) { + sharding.ParseFromArray(s.c_str(), s.size()); + }) + .def("SerializeToString", + [](const OpSharding& sharding) { + std::string serialized = sharding.SerializeAsString(); + return nb::bytes(serialized.data(), serialized.size()); + }) + .def("clone", + [](const OpSharding& sharding) { return OpSharding(sharding); }); + DefRepeatedProperty(op_sharding, "tile_assignment_dimensions", + &xla::OpSharding::mutable_tile_assignment_dimensions); + DefRepeatedProperty(op_sharding, "tile_assignment_devices", + &xla::OpSharding::mutable_tile_assignment_devices); + DefRepeatedProperty(op_sharding, "iota_reshape_dims", + &xla::OpSharding::mutable_iota_reshape_dims); + DefRepeatedProperty(op_sharding, "iota_transpose_perm", + &xla::OpSharding::mutable_iota_transpose_perm); + DefRepeatedProperty(op_sharding, "tuple_shardings", + &xla::OpSharding::mutable_tuple_shardings); + DefRepeatedEnumProperty(op_sharding, "last_tile_dims", + &xla::OpSharding::mutable_last_tile_dims); + + nb::class_ hlo_sharding(m, "HloSharding"); + hlo_sharding + .def_static("from_proto", + xla::ValueOrThrowWrapper(xla::HloSharding::FromProto)) + .def_static("from_string", xla::ValueOrThrowWrapper(xla::ParseSharding)) + .def_static( + "tuple_sharding", + [](xla::Shape shape, + std::vector shardings) -> xla::HloSharding { + return HloSharding::Tuple(shape, shardings); + }, + "Constructs a tuple sharding.") + .def_static( + "iota_tile", xla::ValueOrThrowWrapper(IotaTileHelper), + nb::arg("dims"), + nb::arg("reshape_dims") = absl::Span(), + nb::arg("transpose_perm") = absl::Span(), + nb::arg("subgroup_types") = absl::Span()) + .def_static("manual", [] { return HloSharding::Manual(); }) + .def_static("replicate", [] { return HloSharding::Replicate(); }) + .def_static("unreduced", [] { return HloSharding::Unreduced(); }) + .def_static("unknown", [] { return HloSharding::Unknown(); }) + .def_static( + "subgroup_with_device_ordering", + xla::ValueOrThrowWrapper(SubgroupWithTileAssignmentHelper), + nb::arg("tile_assignment"), + nb::arg("subgroup_types") = absl::Span()) + .def( + "__eq__", + [](const xla::HloSharding& a, const xla::HloSharding& b) { + return a == b; + }, + nb::is_operator(), + nb::sig("def __eq__(self, other: object, /) -> bool")) + .def( + "__ne__", + [](const xla::HloSharding& a, const xla::HloSharding& b) { + return a != b; + }, + nb::is_operator(), + nb::sig("def __ne__(self, other: object, /) -> bool")) + .def("__hash__", + [](const xla::HloSharding& self) { return absl::HashOf(self); }) + .def("is_replicated", &xla::HloSharding::IsReplicated) + .def("is_manual", &xla::HloSharding::IsManual) + .def("is_unreduced", &xla::HloSharding::IsUnreduced) + .def("is_unknown", &xla::HloSharding::IsUnknown) + .def("is_tiled", &xla::HloSharding::IsTiled) + .def("is_maximal", &xla::HloSharding::IsTileMaximal) + .def("tile", [](const xla::HloSharding& self, + xla::Shape shape) { return self.TileShape(shape); }) + // tile_assignment.array() is computed using an internal cache, + // which is why nb::lock_self() is required. It may be preferable to move + // this locking into the TileAssignment class if we find it to race with + // non-Python users of that class. + .def( + "tuple_elements", + [](const xla::HloSharding& self) { return self.tuple_elements(); }, + nb::lock_self()) + .def( + "num_devices", + [](const xla::HloSharding& self) { + return self.tile_assignment().num_elements(); + }, + nb::lock_self()) + .def( + "num_dimensions", + [](const xla::HloSharding& self) { + return self.tile_assignment().num_dimensions(); + }, + nb::lock_self()) + .def("is_tile_assignment_iota", + [](const xla::HloSharding& self) { + return self.tile_assignment().iota().has_value(); + }) + .def( + "tile_assignment_dimensions", + [](const xla::HloSharding& self) { + absl::Span span = + self.tile_assignment().dimensions(); + CHECK(span.data()); + return span; + }, + nb::lock_self()) + .def( + "tile_assignment_devices", + [](const xla::HloSharding& self) { + auto span = + absl::MakeConstSpan(self.tile_assignment().array().data(), + self.tile_assignment().num_elements()); + CHECK(span.data()); + return span; + }, + nb::lock_self()) + .def("replicate_on_last_tile_dim", + &xla::HloSharding::ReplicateOnLastTileDim) + .def("subgroup_types", &xla::HloSharding::subgroup_types) + .def("__repr__", + [](const xla::HloSharding& self) { return self.ToString(); }) + .def("to_proto", &xla::HloSharding::ToProto) + .def("get_axis_sizes", [](const xla::HloSharding& self) { + // If returning the SmallVector, we encounter the error "unable to + // convert function return value to a Python type!". + mlir::SmallVector mesh_shape = + xla::sdy::getAxisSizes(self.tile_assignment()); + return std::vector(mesh_shape.begin(), mesh_shape.end()); + }); +} // NOLINT(readability/fn_size) +} // namespace xla diff --git a/jaxlib/xla_compiler.h b/jaxlib/xla_compiler.h new file mode 100644 index 000000000000..261f630d1cd3 --- /dev/null +++ b/jaxlib/xla_compiler.h @@ -0,0 +1,28 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_COMPILER_H_ +#define JAXLIB_XLA_COMPILER_H_ + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" + +namespace xla { + +void BuildXlaCompilerSubmodule(nanobind::module_& m); + +} // namespace xla + +#endif // JAXLIB_XLA_COMPILER_H_ diff --git a/pyproject.toml b/pyproject.toml index a1b9e7dd446a..3077529caa21 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ allow_redefinition = true module = [ "IPython.*", "absl.*", - "colorama.*", + "compression.*", "etils.*", "filelock.*", "flatbuffers.*", @@ -23,21 +23,32 @@ module = [ "jax.experimental.jax2tf.tests.back_compat_testdata", "jax.experimental.jax2tf.tests.flax_models", "jax_cuda12_plugin.*", - "jaxlib.*", + "jax_cuda13_plugin.*", + "jaxlib.cpu_feature_guard", + "jaxlib.cuda.*", "jaxlib.mlir.*", + "jaxlib.mosaic.dialect.gpu.*", + "jaxlib.mosaic.python._tpu_gen", + "jaxlib.triton.*", + "jaxlib.utils", + "jaxlib.version", + "jaxlib._jax.utils", + "jaxlib._pretty_printer", "jraph.*", "libtpu.*", "matplotlib.*", + "mlir.*", + "ml_dtypes.*", "nvidia.*", "numpy.*", "opt_einsum.*", "optax.*", + "portpicker.*", "pygments.*", "pytest.*", "rich.*", - "scipy.*", "setuptools.*", - "tensorboard_plugin_profile.convert.*", + "xprof.convert.*", "tensorflow.*", "tensorflow.io.*", "tensorflowjs.*", @@ -66,9 +77,15 @@ filterwarnings = [ # https://github.com/protocolbuffers/protobuf/issues/12186#issuecomment-1745679358 "ignore:Type google\\._upb\\._message\\.(Scalar|Message)MapContainer uses PyType_Spec with a metaclass that has custom tp_new\\. This is deprecated and will no longer be allowed in Python 3\\.14\\.:DeprecationWarning", + # Protobuf version in TF releases may lag behind the ones in code. + "ignore:Protobuf gencode version \\d+\\.\\d+\\.\\d is exactly one major version older than the runtime version \\d+\\.\\d+\\.\\d at tensorflow/fore/framework/attr_value\\.proto\\. Please update the gencode to avoid compatibility violations in the next runtime release\\.:UserWarning", + # TODO(b/401588349): Remove this once transparent hugepages are enabled. "ignore:Transparent hugepages", + # TODO(phawkins): Remove this once cloud_tpu_init is removed. + "ignore:jax.cloud_tpu_init was deprecated", + # NOTE: this is probably not where you want to add code to suppress a # warning. Only pytest tests look at this list, whereas Bazel tests also # check for warnings and do not check this list. Most likely, you should @@ -78,7 +95,7 @@ doctest_optionflags = [ "NUMBER", "NORMALIZE_WHITESPACE" ] -addopts = "--doctest-glob='*.rst' --ignore='examples/ffi'" +addopts = "--doctest-glob='*.rst' --ignore='examples/ffi' --import-mode=importlib" [tool.ruff] preview = true @@ -87,9 +104,9 @@ exclude = [ "build", "__pycache__", ] -line-length = 88 +line-length = 80 indent-width = 2 -target-version = "py310" +target-version = "py311" [tool.ruff.lint] ignore = [ @@ -103,6 +120,8 @@ ignore = [ "C901", # Local variable is assigned to but never used "F841", + # Class could be dataclass or namedtuple + "B903", # Raise with from clause inside except block "B904", # Zip without explicit strict parameter diff --git a/setup.py b/setup.py index 80f45285ba61..8e65648466bc 100644 --- a/setup.py +++ b/setup.py @@ -19,11 +19,11 @@ project_name = 'jax' -_current_jaxlib_version = '0.5.1' +_current_jaxlib_version = '0.9.0' # The following should be updated after each new jaxlib release. -_latest_jaxlib_version_on_pypi = '0.5.1' +_latest_jaxlib_version_on_pypi = '0.8.2' -_libtpu_version = '0.0.10.*' +_libtpu_version = '0.0.34.*' def load_version_module(pkg_path): spec = importlib.util.spec_from_file_location( @@ -38,6 +38,13 @@ def load_version_module(pkg_path): _cmdclass = _version_module._get_cmdclass(project_name) _minimum_jaxlib_version = _version_module._minimum_jaxlib_version +# If this is a pre-release ("rc" wheels), append "rc0" to +# _minimum_jaxlib_version and _current_jaxlib_version so that we are able to +# install the rc wheels. +if _version_module._is_prerelease(): + _minimum_jaxlib_version += "rc0" + _current_jaxlib_version += "rc0" + with open('README.md', encoding='utf-8') as f: _long_description = f.read() @@ -50,16 +57,15 @@ def load_version_module(pkg_path): long_description_content_type='text/markdown', author='JAX team', author_email='jax-dev@google.com', - packages=find_packages(exclude=["*examples*", "*internal_test_util*"]), + packages=find_packages(exclude=["examples"]), package_data={'jax': ['py.typed', "*.pyi", "**/*.pyi"]}, - python_requires='>=3.10', + python_requires='>=3.11', install_requires=[ f'jaxlib >={_minimum_jaxlib_version}, <={_jax_version}', - 'ml_dtypes>=0.4.0', - 'numpy>=1.25', - "numpy>=1.26.0; python_version>='3.12'", + 'ml_dtypes>=0.5.0', + 'numpy>=2.0', 'opt_einsum', - 'scipy>=1.11.1', + 'scipy>=1.13', ], extras_require={ # Minimum jaxlib version; used in testing. @@ -81,47 +87,57 @@ def load_version_module(pkg_path): ], 'cuda': [ - f"jaxlib=={_current_jaxlib_version}", - f"jax-cuda12-plugin[with_cuda]>={_current_jaxlib_version},<={_jax_version}", + f"jaxlib>={_current_jaxlib_version},<={_jax_version}", + f"jax-cuda12-plugin[with-cuda]>={_current_jaxlib_version},<={_jax_version}", ], 'cuda12': [ - f"jaxlib=={_current_jaxlib_version}", - f"jax-cuda12-plugin[with_cuda]>={_current_jaxlib_version},<={_jax_version}", + f"jaxlib>={_current_jaxlib_version},<={_jax_version}", + f"jax-cuda12-plugin[with-cuda]>={_current_jaxlib_version},<={_jax_version}", ], - # Deprecated alias for cuda12, kept to avoid breaking users who wrote - # cuda12_pip in their CI. - 'cuda12_pip': [ - f"jaxlib=={_current_jaxlib_version}", - f"jax-cuda12-plugin[with_cuda]>={_current_jaxlib_version},<={_jax_version}", + 'cuda13': [ + f"jaxlib>={_current_jaxlib_version},<={_jax_version}", + f"jax-cuda13-plugin[with-cuda]>={_current_jaxlib_version},<={_jax_version}", ], # Target that does not depend on the CUDA pip wheels, for those who want # to use a preinstalled CUDA. - 'cuda12_local': [ - f"jaxlib=={_current_jaxlib_version}", - f"jax-cuda12-plugin=={_current_jaxlib_version}", + 'cuda12-local': [ + f"jaxlib>={_current_jaxlib_version},<={_jax_version}", + f"jax-cuda12-plugin>={_current_jaxlib_version},<={_jax_version}", + ], + + 'cuda13-local': [ + f"jaxlib>={_current_jaxlib_version},<={_jax_version}", + f"jax-cuda13-plugin>={_current_jaxlib_version},<={_jax_version}", ], - # ROCm support for ROCm 6.0 and above. + # ROCm support for ROCm 7.0 and above. 'rocm': [ - f"jaxlib=={_current_jaxlib_version}", - f"jax-rocm60-plugin>={_current_jaxlib_version},<={_jax_version}", + f"jaxlib>={_current_jaxlib_version},<={_jax_version}", + f"jax-rocm7-plugin>={_current_jaxlib_version},<={_jax_version}", ], # For automatic bootstrapping distributed jobs in Kubernetes 'k8s': [ 'kubernetes', ], + + # For including XProf server + 'xprof': [ + 'xprof', + ], }, url='https://github.com/jax-ml/jax', license='Apache-2.0', classifiers=[ - "Programming Language :: Python :: 3.10", + "Development Status :: 5 - Production/Stable", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", + "Programming Language :: Python :: Free Threading :: 3 - Stable", ], zip_safe=False, ) diff --git a/test_shard_count.bzl b/test_shard_count.bzl new file mode 100644 index 000000000000..d56344c0787e --- /dev/null +++ b/test_shard_count.bzl @@ -0,0 +1,27 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" Repository rule to generate a file with USE_MINIMAL_SHARD_COUNT. """ + +def _test_shard_count_repository_impl(repository_ctx): + USE_MINIMAL_SHARD_COUNT = repository_ctx.getenv("USE_MINIMAL_SHARD_COUNT", "False") + repository_ctx.file( + "test_shard_count.bzl", + "USE_MINIMAL_SHARD_COUNT = %s" % USE_MINIMAL_SHARD_COUNT, + ) + repository_ctx.file("BUILD", "") + +test_shard_count_repository = repository_rule( + implementation = _test_shard_count_repository_impl, +) diff --git a/tests/BUILD b/tests/BUILD index 0ffa68ed8eb3..11e9fbe869e9 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -34,41 +34,69 @@ jax_generate_backend_suites() jax_multiplatform_test( name = "api_test", srcs = ["api_test.py"], - enable_configs = ["tpu_v3_2x2"], - shard_count = 10, + enable_configs = ["tpu_v3_x4"], + shard_count = { + "cpu": select({ + "//tests/config:tsan_freethreading_rbe": 10, + "//conditions:default": 5, + }), + "gpu": 5, + "tpu": 5, + }, + deps = py_deps([ + "absl/testing", + "numpy", + ]), +) + +jax_multiplatform_test( + name = "custom_api_test", + srcs = ["custom_api_test.py"], + shard_count = { + "cpu": select({ + "//tests/config:tsan_freethreading_rbe": 2, + "//conditions:default": 1, + }), + }, deps = [ - "//jax:experimental", - ], + "//jax/_src:custom_derivatives", + "//jax/experimental:custom_dce", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "debug_info_test", srcs = ["debug_info_test.py"], - enable_configs = ["tpu_v3_2x2"], + enable_configs = ["tpu_v3_x4"], deps = [ - "//jax:experimental", - "//jax:pallas", - "//jax:pallas_gpu", - "//jax:pallas_gpu_ops", - "//jax:pallas_tpu", - "//jax:pallas_tpu_ops", - ] + py_deps("numpy"), + "//jax/_src:custom_transpose", + "//jax/_src:shard_map", + "//jax/experimental:checkify", + "//jax/experimental:custom_dce", + "//jax/experimental:pallas", + "//jax/experimental:pallas_gpu", + "//jax/experimental:pallas_gpu_ops", + "//jax/experimental:pallas_tpu", + "//jax/experimental:pallas_tpu_ops", + ] + py_deps([ + "numpy", + "absl/testing", + ]), ) jax_multiplatform_test( name = "device_test", srcs = ["device_test.py"], -) - -jax_multiplatform_test( - name = "dynamic_api_test", - srcs = ["dynamic_api_test.py"], - shard_count = 2, + deps = py_deps("absl/testing"), ) jax_multiplatform_test( name = "api_util_test", srcs = ["api_util_test.py"], + deps = py_deps("absl/testing"), ) jax_py_test( @@ -76,33 +104,84 @@ jax_py_test( srcs = ["array_api_test.py"], deps = [ "//jax", - "//jax:test_util", + "//jax/_src:test_util", ] + py_deps("absl/testing"), ) +jax_py_test( + name = "array_extensibility_test", + srcs = ["array_extensibility_test.py"], + shard_count = 3, + tags = [ + "notsan", # Times out + ], + deps = [ + "//jax", + "//jax/_src:test_util", + ] + py_deps([ + "absl/testing", + "numpy", + ]), +) + jax_multiplatform_test( name = "array_interoperability_test", srcs = ["array_interoperability_test.py"], + backend_tags = { + "gpu": [ + "noasan", # Forge OOM + ], + }, enable_backends = [ "cpu", "gpu", ], enable_configs = [ - "gpu_p100x2", + "gpu_h100x2", ], - env = { - "PYTHONWARNINGS": "default", # TODO(b/394123878): protobuf, via TensorFlow, issues a Python warning under Python 3.12+ sometimes. - }, - tags = ["multiaccelerator"], - deps = py_deps("tensorflow_core"), + tags = [ + "multiaccelerator", + "notsan", # Times out + ], + deps = py_deps([ + "absl/testing", + "numpy", + "tensorflow_core", + ]), ) jax_multiplatform_test( name = "batching_test", srcs = ["batching_test.py"], + backend_tags = { + "tpu": ["noasan"], # Times out. + }, shard_count = { "gpu": 5, + "cpu": select({ + "//tests/config:tsan_freethreading_rbe": 4, + "//conditions:default": 2, + }), }, + deps = py_deps([ + "absl/testing", + "numpy", + ]), +) + +jax_multiplatform_test( + name = "buffer_callback_test", + srcs = ["buffer_callback_test.py"], + enable_backends = [ + "cpu", + "gpu", + ], + deps = [ + "//jax/experimental:buffer_callback", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_py_test( @@ -110,27 +189,52 @@ jax_py_test( srcs = ["config_test.py"], deps = [ "//jax", - "//jax:test_util", + "//jax/_src:test_util", ] + py_deps("absl/testing"), ) jax_multiplatform_test( name = "core_test", srcs = ["core_test.py"], - shard_count = { - "cpu": 5, - "gpu": 10, + backend_tags = { + "tpu": ["noasan"], # Times out. }, + shard_count = 2, + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "debug_nans_test", srcs = ["debug_nans_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), +) + +jax_py_test( + name = "distributed_initialize_test", + srcs = ["distributed_initialize_test.py"], + deps = [ + "//jax", + "//jax/_src:test_util", + ] + py_deps([ + "portpicker", + "absl/testing", + ]), ) jax_multiplatform_test( name = "distributed_test", srcs = ["distributed_test.py"], + enable_backends = ["gpu"], + deps = py_deps([ + "portpicker", + "absl/testing", + ]), ) jax_py_test( @@ -142,13 +246,20 @@ jax_py_test( tags = ["manual"], deps = [ "//jax", - "//jax:test_util", - ] + py_deps("portpicker"), + "//jax/_src:test_util", + ] + py_deps([ + "portpicker", + "absl/testing", + ]), ) jax_multiplatform_test( name = "dtypes_test", srcs = ["dtypes_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -158,22 +269,28 @@ jax_multiplatform_test( enable_configs = [ "cpu", ], + deps = py_deps("absl/testing"), ) jax_multiplatform_test( name = "extend_test", srcs = ["extend_test.py"], - deps = ["//jax:extend"], + deps = ["//jax/extend"] + py_deps("absl/testing"), ) jax_multiplatform_test( name = "ffi_test", srcs = ["ffi_test.py"], enable_configs = [ - "gpu_p100x2", + "gpu_h100x2", ], # TODO(dfm): Remove after removal of jex.ffi imports. - deps = ["//jax:extend"], + deps = [ + "//jax/_src:ffi", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -186,15 +303,23 @@ jax_multiplatform_test( ], # Times out on TPU with asan/tsan. }, shard_count = { - "tpu": 20, - "cpu": 20, - "gpu": 10, + "tpu": 5, + "cpu": 5, + "gpu": 5, }, + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "generated_fun_test", srcs = ["generated_fun_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -205,6 +330,7 @@ jax_multiplatform_test( "XLA_PYTHON_CLIENT_PREALLOCATE": "0", }, main = "gpu_memory_flags_test.py", + deps = py_deps("absl/testing"), ) jax_multiplatform_test( @@ -214,20 +340,24 @@ jax_multiplatform_test( env = { "XLA_PYTHON_CLIENT_PREALLOCATE": "1", }, + deps = py_deps("absl/testing"), ) jax_multiplatform_test( name = "lobpcg_test", srcs = ["lobpcg_test.py"], - env = {"LOBPCG_EMIT_DEBUG_PLOTS": "1"}, - shard_count = { - "cpu": 48, - "gpu": 48, - "tpu": 48, - }, + # Set LOBPCG_EMIT_DEBUG_PLOTS=1 to debug + # checkLobpcgMonotonicity and checkApproxEigs tests + # using matplotlib plots + # env = {"LOBPCG_EMIT_DEBUG_PLOTS": "1"}, deps = [ - "//jax:experimental_sparse", - ] + py_deps("matplotlib"), + "//jax/experimental:sparse", + ] + py_deps([ + "matplotlib", + "absl/testing", + "numpy", + "scipy", + ]), ) jax_multiplatform_test( @@ -236,8 +366,13 @@ jax_multiplatform_test( shard_count = { "cpu": 10, "gpu": 10, - "tpu": 40, + "tpu": 15, }, + deps = py_deps([ + "absl/testing", + "numpy", + "scipy", + ]), ) jax_py_test( @@ -245,53 +380,69 @@ jax_py_test( srcs = ["xla_interpreter_test.py"], deps = [ "//jax", - "//jax:test_util", - ], + "//jax/_src:test_util", + ] + py_deps("absl/testing"), ) jax_multiplatform_test( name = "memories_test", srcs = ["memories_test.py"], + tags = ["multiaccelerator"], + deps = [ + "//jax/experimental:compute_on", + ] + py_deps([ + "absl/testing", + "numpy", + ]), +) + +jax_multiplatform_test( + name = "custom_partitioning_test", + srcs = ["custom_partitioning_test.py"], + backend_tags = { + "tpu": ["notsan"], # Times out under tsan. + "gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143 + }, enable_configs = [ - "cpu", - "gpu_p100x2", - "tpu_v3_2x2", - "tpu_v4_2x2", - "tpu_v5p_2x2", - "tpu_v5e_4x2", - "gpu_p100x2_shardy", - "tpu_v5e_4x2_shardy", + "tpu_v3_x4", + "gpu_h100x2", ], - shard_count = { - "tpu": 5, - }, + tags = ["multiaccelerator"], deps = [ - "//jax:experimental", - ], + "//jax/experimental:custom_partitioning", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "pjit_test", srcs = ["pjit_test.py"], backend_tags = { - "tpu": ["notsan"], # Times out under tsan. + "tpu": [ + "notsan", # Times out. + "noasan", # Times out. + ], "gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143 }, enable_configs = [ - "gpu_p100x2_shardy", - "tpu_v3_2x2_shardy", - "tpu_v3_2x2", - "gpu_p100x2", + "tpu_v3_x4", + "gpu_h100x2", ], shard_count = { "cpu": 5, - "gpu": 5, + "gpu": 2, "tpu": 5, }, tags = ["multiaccelerator"], deps = [ - "//jax:experimental", - ], + "//jax/experimental", + "//jax/experimental:multihost_utils", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -301,26 +452,27 @@ jax_multiplatform_test( "tpu": ["requires-mem:16g"], # Under tsan on 2x2 this test exceeds the default 12G memory limit. }, enable_configs = [ - "tpu_v3_2x2_shardy", + "tpu_v3_x4", ], tags = ["multiaccelerator"], deps = [ - "//jax:experimental", - ], + "//jax/experimental:layout", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "shard_alike_test", srcs = ["shard_alike_test.py"], - enable_configs = [ - "tpu_v3_2x2", - "tpu_v5e_4x2", - "tpu_v4_2x2", - "tpu_v3_2x2_shardy", - ], + tags = ["multiaccelerator"], deps = [ - "//jax:experimental", - ], + "//jax/experimental:shard_alike", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -329,29 +481,34 @@ jax_multiplatform_test( backend_tags = { "gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143 }, + disable_configs = [ + "gpu_h100x2_tfrt", # TODO(b/419192167): Doesn't work + ], enable_backends = ["gpu"], tags = [ "config-cuda-only", "multiaccelerator", ], deps = [ - "//jax:experimental", - ], + "//jax/experimental:profiler", + "//jax/experimental:serialize_executable", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "mock_gpu_test", srcs = ["mock_gpu_test.py"], enable_backends = ["gpu"], - enable_configs = [ - "gpu_p100x2_shardy", - ], tags = [ "config-cuda-only", ], - deps = [ - "//jax:experimental", - ], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -364,9 +521,7 @@ jax_multiplatform_test( tags = [ "config-cuda-only", ], - deps = [ - "//jax:experimental", - ], + deps = py_deps("absl/testing"), ) jax_multiplatform_test( @@ -376,13 +531,16 @@ jax_multiplatform_test( "tpu": ["requires-mem:16g"], # Under tsan on 2x2 this test exceeds the default 12G memory limit. }, enable_configs = [ - "tpu_v3_2x2", + "tpu_v3_x4", ], tags = ["multiaccelerator"], deps = [ - "//jax:experimental", - "//jax:internal_test_util", - ], + "//jax/_src:internal_test_util", + "//jax/experimental:multihost_utils", + ] + py_deps([ + "numpy", + "absl/testing", + ]), ) jax_multiplatform_test( @@ -390,8 +548,13 @@ jax_multiplatform_test( srcs = ["aot_test.py"], tags = ["multiaccelerator"], deps = [ - "//jax:experimental", - ] + py_deps("numpy"), + "//jax/experimental:pjit", + "//jax/experimental:serialize_executable", + "//jax/experimental:topologies", + ] + py_deps([ + "numpy", + "absl/testing", + ]), ) jax_multiplatform_test( @@ -399,34 +562,40 @@ jax_multiplatform_test( srcs = ["image_test.py"], shard_count = { "cpu": 10, - "gpu": 20, - "tpu": 10, + "gpu": 10, + "tpu": 8, }, tags = ["noasan"], # Linking TF causes a linker OOM. - deps = py_deps("pil") + py_deps("tensorflow_core"), -) - -jax_multiplatform_test( - name = "infeed_test", - srcs = ["infeed_test.py"], - deps = [ - ], + deps = py_deps([ + "pil", + "tensorflow_core", + "numpy", + "absl/testing", + ]), ) jax_multiplatform_test( name = "jax_jit_test", srcs = ["jax_jit_test.py"], + enable_backends = ["cpu"], main = "jax_jit_test.py", + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_py_test( name = "jax_to_ir_test", srcs = ["jax_to_ir_test.py"], deps = [ - "//jax:test_util", + "//jax/_src:test_util", "//jax/experimental/jax2tf", "//jax/tools:jax_to_ir", - ] + py_deps("tensorflow_core"), + ] + py_deps([ + "tensorflow_core", + "absl/testing", + ]), ) jax_py_test( @@ -434,59 +603,83 @@ jax_py_test( srcs = ["jaxpr_util_test.py"], deps = [ "//jax", - "//jax:jaxpr_util", - "//jax:test_util", - ], + "//jax/_src:jaxpr_util", + "//jax/_src:test_util", + ] + py_deps("absl/testing"), ) jax_multiplatform_test( name = "jet_test", srcs = ["jet_test.py"], shard_count = { - "cpu": 10, - "gpu": 10, + "gpu": 2, + "cpu": select({ + "//tests/config:tsan_freethreading_rbe": 8, + "//conditions:default": 1, + }), }, deps = [ - "//jax:jet", - "//jax:stax", - ], + "//jax/example_libraries:stax", + "//jax/experimental:jet", + "//jax/extend:core", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "lax_control_flow_test", srcs = ["lax_control_flow_test.py"], shard_count = { - "cpu": 30, - "gpu": 40, - "tpu": 30, + "cpu": 15, + "gpu": 15, + "tpu": 10, }, + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "custom_root_test", srcs = ["custom_root_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "custom_linear_solve_test", srcs = ["custom_linear_solve_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "lax_numpy_test", srcs = ["lax_numpy_test.py"], backend_tags = { + "tpu": ["notsan"], # Test times out. "cpu": ["notsan"], # Test times out. }, shard_count = { - "cpu": 40, - "gpu": 40, - "tpu": 50, + "cpu": 45, + "gpu": 35, + "tpu": 45, }, tags = [ "noasan", # Test times out on all backends "test_cpu_thunks", ], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -497,6 +690,10 @@ jax_multiplatform_test( "gpu": 30, "tpu": 40, }, + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -507,52 +704,98 @@ jax_multiplatform_test( "gpu": 20, "tpu": 20, }, + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "lax_numpy_indexing_test", srcs = ["lax_numpy_indexing_test.py"], + backend_tags = { + "tpu": ["noasan"], # Times out. + }, shard_count = { "cpu": 10, "gpu": 10, "tpu": 10, }, + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "lax_numpy_einsum_test", srcs = ["lax_numpy_einsum_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), +) + +jax_multiplatform_test( + name = "lax_numpy_setops_test", + srcs = ["lax_numpy_setops_test.py"], + backend_tags = { + "tpu": ["notsan"], # Test times out. + "cpu": ["notsan"], # Test times out. + }, shard_count = { - "cpu": 10, - "gpu": 10, - "tpu": 10, + "cpu": 5, + "gpu": 5, + "tpu": 5, }, + tags = [ + "noasan", # Test times out on all backends + ], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "lax_numpy_ufuncs_test", srcs = ["lax_numpy_ufuncs_test.py"], shard_count = { - "cpu": 10, - "gpu": 10, - "tpu": 10, + "cpu": 5, + "gpu": 4, + "tpu": 2, }, + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "lax_numpy_vectorize_test", srcs = ["lax_numpy_vectorize_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "lax_scipy_test", srcs = ["lax_scipy_test.py"], + backend_tags = { + "tpu": ["noasan"], # Times out. + }, shard_count = { - "cpu": 20, - "gpu": 20, - "tpu": 20, + "cpu": 10, + "gpu": 10, + "tpu": 3, }, - deps = py_deps("numpy") + py_deps("scipy") + py_deps("absl/testing"), + deps = py_deps([ + "numpy", + "scipy", + "absl/testing", + ]), ) jax_multiplatform_test( @@ -563,37 +806,60 @@ jax_multiplatform_test( }, shard_count = { "cpu": 10, - "gpu": 10, - "tpu": 10, + "gpu": 5, + "tpu": 5, }, + deps = py_deps([ + "numpy", + "scipy", + "absl/testing", + ]), ) jax_multiplatform_test( name = "lax_scipy_special_functions_test", srcs = ["lax_scipy_special_functions_test.py"], backend_tags = { - "gpu": ["noasan"], # Times out. - "cpu": ["noasan"], # Times out. + "cpu": [ + "nomsan", # Times out. + "notsan", # Times out. + ], + "tpu": [ + "nomsan", # Times out. + "notsan", # Times out. + ], }, shard_count = { - "cpu": 20, - "gpu": 20, + "cpu": 30, + "gpu": 30, "tpu": 20, }, - deps = py_deps("numpy") + py_deps("scipy") + py_deps("absl/testing"), + tags = ["noasan"], # Times out under asan. + deps = py_deps([ + "numpy", + "scipy", + "absl/testing", + ]), ) jax_multiplatform_test( name = "lax_scipy_spectral_dac_test", srcs = ["lax_scipy_spectral_dac_test.py"], + backend_tags = { + "tpu": ["noasan"], # Times out. + "gpu": ["noasan"], # Times out. + }, shard_count = { - "cpu": 40, - "gpu": 40, - "tpu": 40, + "cpu": 2, + "gpu": 3, + "tpu": 3, }, deps = [ - "//jax:internal_test_util", - ] + py_deps("numpy") + py_deps("scipy") + py_deps("absl/testing"), + "//jax/_src:internal_test_util", + ] + py_deps([ + "numpy", + "absl/testing", + ]), ) jax_multiplatform_test( @@ -604,57 +870,60 @@ jax_multiplatform_test( "tpu": ["noasan"], # Times out. }, shard_count = { - "cpu": 40, - "gpu": 40, - "tpu": 40, + "cpu": 30, + "gpu": 30, + "tpu": 20, }, deps = [ - "//jax:internal_test_util", - "//jax:lax_reference", - ] + py_deps("numpy") + py_deps("mpmath"), -) - -jax_multiplatform_test( - name = "lax_metal_test", - srcs = ["lax_metal_test.py"], - enable_backends = ["metal"], - tags = ["notap"], - deps = [ - "//jax:internal_test_util", - "//jax:lax_reference", - ] + py_deps("numpy"), + "//jax/_src:internal_test_util", + "//jax/_src:lax_reference", + ] + py_deps([ + "numpy", + "absl/testing", + "mpmath", + ]), ) jax_multiplatform_test( name = "lax_autodiff_test", srcs = ["lax_autodiff_test.py"], shard_count = { - "cpu": 40, - "gpu": 40, + "cpu": 20, + "gpu": 20, "tpu": 20, }, + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "lax_vmap_test", srcs = ["lax_vmap_test.py"], shard_count = { - "cpu": 40, - "gpu": 40, - "tpu": 40, + "cpu": 20, + "gpu": 20, + "tpu": 20, }, - deps = ["//jax:internal_test_util"] + py_deps("numpy") + py_deps("absl/testing"), + deps = ["//jax/_src:internal_test_util"] + py_deps([ + "numpy", + "absl/testing", + ]), ) jax_multiplatform_test( name = "lax_vmap_op_test", srcs = ["lax_vmap_op_test.py"], shard_count = { - "cpu": 40, - "gpu": 40, - "tpu": 40, + "cpu": 20, + "gpu": 20, + "tpu": 20, }, - deps = ["//jax:internal_test_util"] + py_deps("numpy") + py_deps("absl/testing"), + deps = ["//jax/_src:internal_test_util"] + py_deps([ + "numpy", + "absl/testing", + ]), ) jax_py_test( @@ -663,9 +932,9 @@ jax_py_test( "lazy_loader_test.py", ], deps = [ - "//jax:internal_test_util", - "//jax:test_util", - ], + "//jax/_src:internal_test_util", + "//jax/_src:test_util", + ] + py_deps("absl/testing"), ) jax_py_test( @@ -674,9 +943,23 @@ jax_py_test( "deprecation_test.py", ], deps = [ - "//jax:internal_test_util", - "//jax:test_util", + "//jax/_src:internal_test_util", + "//jax/_src:test_util", + ] + py_deps("absl/testing"), +) + +jax_py_test( + name = "documentation_coverage_test", + srcs = [ + "documentation_coverage_test.py", ], + deps = [ + "//jax", + "//jax/_src:config", + "//jax/_src:internal_test_util", + "//jax/_src:test_util", + # "//jax/docs", + ] + py_deps("absl/testing"), ) jax_multiplatform_test( @@ -690,12 +973,22 @@ jax_multiplatform_test( "nodebug", # Times out. "notsan", # Times out. ], + "cpu": [ + "noasan", + "nomsan", + "notsan", # Times out. + ], # TODO(phawkins): Latest SciPy leaks memory. }, shard_count = { "cpu": 40, "gpu": 40, "tpu": 40, }, + deps = py_deps([ + "absl/testing", + "numpy", + "scipy", + ]), ) jax_multiplatform_test( @@ -705,31 +998,42 @@ jax_multiplatform_test( "cpu", ], enable_configs = [ - "gpu_p100x2", - "gpu_p100x2_shardy", - "gpu_p100x2_pjrt_c_api", + "gpu_h100x2", ], + shard_count = 2, tags = [ "multiaccelerator", ], + deps = py_deps([ + "absl/testing", + ]), ) jax_multiplatform_test( name = "magma_linalg_test", srcs = ["magma_linalg_test.py"], enable_backends = ["gpu"], - deps = py_deps("magma"), + deps = py_deps([ + "magma", + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "cholesky_update_test", srcs = ["cholesky_update_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "metadata_test", srcs = ["metadata_test.py"], enable_backends = ["cpu"], + deps = py_deps("absl/testing"), ) jax_py_test( @@ -737,23 +1041,28 @@ jax_py_test( srcs = ["monitoring_test.py"], deps = [ "//jax", - "//jax:test_util", - ], + "//jax/_src:test_util", + ] + py_deps("absl/testing"), ) jax_multiplatform_test( name = "multibackend_test", srcs = ["multibackend_test.py"], enable_configs = [ - "tpu_v3_2x2", - "gpu_p100x2", + "tpu_v3_x4", + "gpu_h100x2", ], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "multi_device_test", srcs = ["multi_device_test.py"], enable_backends = ["cpu"], + deps = py_deps("absl/testing"), ) jax_multiplatform_test( @@ -772,20 +1081,68 @@ jax_multiplatform_test( "tpu": 10, "gpu": 10, }, + deps = py_deps([ + "absl/testing", + "numpy", + "scipy", + ]), ) jax_multiplatform_test( name = "optimizers_test", srcs = ["optimizers_test.py"], - deps = ["//jax:optimizers"], + deps = ["//jax/example_libraries:optimizers"] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "pickle_test", srcs = ["pickle_test.py"], + enable_backends = ["cpu"], + deps = py_deps([ + "cloudpickle", + "numpy", + "absl/testing", + ]), +) + +jax_multiplatform_test( + name = "pmap_without_shmap_test", + srcs = ["pmap_test.py"], + args = ["--pmap_shmap_merge=false"], + backend_tags = { + "tpu": [ + "noasan", # Times out under asan. + "requires-mem:16g", # Under tsan on 2x2 this test exceeds the default 12G memory limit. + ], + "gpu": [ + "noasan", # Times out under asan. + ], + }, + disable_configs = [ + "tpu_v6e", # Forge OOM. + ], + enable_configs = [ + "gpu_v100", + "tpu_v3_x4", + ], + minimal_shard_count = { + "tpu": 8, + }, + shard_count = { + "cpu": 20, + "gpu": 10, + "tpu": 20, + }, + tags = ["multiaccelerator"], deps = [ - "//jax:experimental", - ] + py_deps("cloudpickle") + py_deps("numpy"), + "//jax/_src:internal_test_util", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -796,20 +1153,48 @@ jax_multiplatform_test( "noasan", # Times out under asan. "requires-mem:16g", # Under tsan on 2x2 this test exceeds the default 12G memory limit. ], + "gpu": [ + "noasan", # Times out under asan. + ], }, enable_configs = [ "gpu_v100", - "tpu_v3_2x2", + "tpu_v3_x4", ], + minimal_shard_count = { + "tpu": 8, + }, shard_count = { - "cpu": 30, - "gpu": 30, - "tpu": 30, + "cpu": 20, + "gpu": 10, + "tpu": 20, }, tags = ["multiaccelerator"], deps = [ - "//jax:internal_test_util", + "//jax/_src:internal_test_util", + ] + py_deps([ + "absl/testing", + "numpy", + ]), +) + +jax_multiplatform_test( + name = "pmap_shmap_merge_test", + srcs = ["pmap_shmap_merge_test.py"], + enable_configs = [ + "gpu_v100", + "tpu_v3_x4", ], + shard_count = { + "cpu": 20, + "gpu": 10, + "tpu": 20, + }, + tags = ["multiaccelerator"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -817,9 +1202,6 @@ jax_multiplatform_test( srcs = ["polynomial_test.py"], # No implementation of nonsymmetric Eigendecomposition. enable_backends = ["cpu"], - shard_count = { - "cpu": 10, - }, # This test ends up calling Fortran code that initializes some memory and # passes it to C code. MSan is not able to detect that the memory was # initialized by Fortran, and it makes the test fail. This can usually be @@ -827,12 +1209,21 @@ jax_multiplatform_test( # in this case there's not a good place to do it, see b/197635968#comment19 # for details. tags = ["nomsan"], + deps = py_deps([ + "absl/testing", + "numpy", + "scipy", + ]), ) jax_multiplatform_test( name = "heap_profiler_test", srcs = ["heap_profiler_test.py"], enable_backends = ["cpu"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -848,10 +1239,25 @@ jax_multiplatform_test( enable_backends = [ "cpu", "gpu", + "tpu", ], + tags = ["multiaccelerator"], deps = [ - "//jax:profiler", - ], + "//jax/_src:profiler", + ] + py_deps([ + "absl/testing", + "portpicker", + ]), +) + +jax_py_test( + name = "profiler_session_test", + srcs = ["profiler_session_test.py"], + deps = [ + "//jax", + "//jax/_src:profiler", + "//jax/_src:test_util", + ] + py_deps("absl/testing"), ) jax_multiplatform_test( @@ -866,7 +1272,10 @@ jax_multiplatform_test( "nomsan", # TODO(b/355237462): msan false-positives in torch? "not_build:arm", ], - deps = py_deps("torch"), + deps = py_deps([ + "torch", + "absl/testing", + ]), ) jax_multiplatform_test( @@ -879,29 +1288,20 @@ jax_multiplatform_test( "notsan", # Times out ], }, - shard_count = 10, + shard_count = 4, + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "random_test", srcs = ["random_test.py"], - backend_tags = { - "cpu": [ - "notsan", # Times out - "nomsan", # Times out - ], - "tpu": [ - "optonly", - "nomsan", # Times out - "notsan", # Times out - ], - }, - shard_count = { - "cpu": 30, - "gpu": 30, - "tpu": 40, - }, - tags = ["noasan"], # Times out + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -922,11 +1322,16 @@ jax_multiplatform_test( "gpu": ["--jax_num_generated_cases=40"], }, shard_count = { - "cpu": 40, - "gpu": 40, + "cpu": 20, + "gpu": 50, "tpu": 40, }, tags = ["noasan"], # Times out + deps = py_deps([ + "absl/testing", + "numpy", + "scipy", + ]), ) # TODO(b/199564969): remove once we always enable_custom_prng @@ -934,25 +1339,8 @@ jax_multiplatform_test( name = "random_test_with_custom_prng", srcs = ["random_test.py"], args = ["--jax_enable_custom_prng=true"], - backend_tags = { - "cpu": [ - "noasan", # Times out under asan/msan/tsan. - "nomsan", - "notsan", - ], - "tpu": [ - "noasan", # Times out under asan/msan/tsan. - "nomsan", - "notsan", - "optonly", - ], - }, main = "random_test.py", - shard_count = { - "cpu": 40, - "gpu": 40, - "tpu": 40, - }, + deps = py_deps("absl/testing"), ) jax_multiplatform_test( @@ -966,21 +1354,41 @@ jax_multiplatform_test( ], # Times out on TPU with asan/tsan/msan. }, shard_count = 12, + deps = py_deps([ + "absl/testing", + "numpy", + "scipy", + ]), ) jax_multiplatform_test( name = "scipy_interpolate_test", srcs = ["scipy_interpolate_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + "scipy", + ]), ) jax_multiplatform_test( name = "scipy_ndimage_test", srcs = ["scipy_ndimage_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + "scipy", + ]), ) jax_multiplatform_test( name = "scipy_optimize_test", srcs = ["scipy_optimize_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + "scipy", + ]), ) jax_multiplatform_test( @@ -1002,33 +1410,52 @@ jax_multiplatform_test( "gpu_h100", # TODO(phawkins): numerical failure on h100 ], shard_count = { - "cpu": 40, - "gpu": 40, - "tpu": 50, + "cpu": 20, + "gpu": 20, + "tpu": 30, }, + deps = py_deps([ + "absl/testing", + "numpy", + "scipy", + ]), ) jax_multiplatform_test( name = "scipy_spatial_test", srcs = ["scipy_spatial_test.py"], - deps = py_deps("scipy"), + shard_count = { + "cpu": 2, + "gpu": 2, + }, + deps = py_deps([ + "absl/testing", + "numpy", + "scipy", + ]), ) jax_multiplatform_test( name = "scipy_stats_test", srcs = ["scipy_stats_test.py"], backend_tags = { + "cpu": ["nomsan"], # Times out "tpu": ["nomsan"], # Times out }, shard_count = { - "cpu": 40, - "gpu": 30, - "tpu": 40, + "cpu": 50, + "gpu": 50, + "tpu": 50, }, tags = [ "noasan", "notsan", ], # Times out + deps = py_deps([ + "absl/testing", + "numpy", + "scipy", + ]), ) jax_multiplatform_test( @@ -1049,9 +1476,9 @@ jax_multiplatform_test( "gpu": ["--jax_num_generated_cases=40"], }, shard_count = { - "cpu": 50, - "gpu": 50, - "tpu": 50, + "cpu": 15, + "gpu": 20, + "tpu": 15, }, tags = [ "noasan", @@ -1059,9 +1486,13 @@ jax_multiplatform_test( "notsan", ], # Test times out under asan/msan/tsan. deps = [ - "//jax:experimental_sparse", - "//jax:sparse_test_util", - ] + py_deps("scipy"), + "//jax/experimental:sparse", + "//jax/experimental:sparse_test_util", + ] + py_deps([ + "scipy", + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -1077,14 +1508,21 @@ jax_multiplatform_test( }, # Use fewer cases to prevent timeouts. backend_variant_args = { - "cpu": ["--jax_num_generated_cases=40"], - "cpu_x32": ["--jax_num_generated_cases=40"], - "gpu": ["--jax_num_generated_cases=40"], - "tpu": ["--jax_num_generated_cases=40"], + "cpu": ["--jax_num_generated_cases=20"], + "cpu_x32": ["--jax_num_generated_cases=30"], + "gpu_a100": ["--jax_num_generated_cases=40"], + "gpu_p100": ["--jax_num_generated_cases=40"], + "gpu_v100": ["--jax_num_generated_cases=40"], + "tpu_v3": ["--jax_num_generated_cases=40"], + "tpu_v5e": ["--jax_num_generated_cases=40"], }, disable_configs = [ - "cpu_shardy", # TODO(b/376475853): array values mismatch, need to fix and re-enable. + "gpu_b200", # Times out + "tpu_7x", # Times out ], + minimal_shard_count = { + "tpu": 8, + }, shard_count = { "cpu": 50, "gpu": 50, @@ -1096,23 +1534,12 @@ jax_multiplatform_test( "notsan", ], # Test times out under asan/msan/tsan. deps = [ - "//jax:experimental_sparse", - "//jax:sparse_test_util", - ] + py_deps("scipy"), -) - -jax_multiplatform_test( - name = "sparse_nm_test", - srcs = ["sparse_nm_test.py"], - enable_backends = [], - enable_configs = [ - "gpu_a100", - "gpu_h100", - ], - deps = [ - "//jax:experimental_sparse", - "//jax:pallas_gpu", - ], + "//jax/experimental:sparse", + "//jax/experimental:sparse_test_util", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -1122,10 +1549,11 @@ jax_multiplatform_test( backend_tags = { "cpu": [ "noasan", # Times out under asan - "notsan", # Times out under asan + "notsan", # Times out under tsan ], "tpu": [ - "noasan", # Times out under asan. + "noasan", # Times out under asan + "notsan", # Times out under tsan ], }, shard_count = { @@ -1134,50 +1562,88 @@ jax_multiplatform_test( "tpu": 10, }, deps = [ - "//jax:experimental_sparse", - "//jax:sparse_test_util", - ], + "//jax/experimental:sparse", + "//jax/experimental:sparse_test_util", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "stack_test", srcs = ["stack_test.py"], + deps = py_deps("absl/testing"), ) jax_multiplatform_test( name = "checkify_test", srcs = ["checkify_test.py"], - enable_configs = ["tpu_v3_2x2"], + enable_configs = ["tpu_v3_x4"], shard_count = { "gpu": 2, "tpu": 4, }, + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "error_check_test", srcs = ["error_check_test.py"], + deps = py_deps("absl/testing"), +) + +jax_multiplatform_test( + name = "jax_numpy_error_test", + srcs = ["jax_numpy_error_test.py"], + deps = py_deps("absl/testing"), +) + +jax_multiplatform_test( + name = "scheduling_groups_test", + srcs = ["scheduling_groups_test.py"], + deps = py_deps("absl/testing"), +) + +jax_multiplatform_test( + name = "fused_test", + srcs = ["fused_test.py"], + deps = [ + "//jax/experimental:fused", + ] + py_deps([ + "absl/testing", + ]), ) jax_multiplatform_test( name = "stax_test", srcs = ["stax_test.py"], - shard_count = { - "cpu": 5, - "gpu": 5, - }, - deps = ["//jax:stax"], + deps = ["//jax/example_libraries:stax"] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "linear_search_test", srcs = ["third_party/scipy/line_search_test.py"], main = "third_party/scipy/line_search_test.py", + deps = py_deps([ + "absl/testing", + "scipy", + ]), ) jax_multiplatform_test( name = "blocked_sampler_test", srcs = ["blocked_sampler_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_py_test( @@ -1185,8 +1651,12 @@ jax_py_test( srcs = ["tree_util_test.py"], deps = [ "//jax", - "//jax:test_util", - ], + "//jax/_src:test_util", + ] + py_deps([ + "absl/testing", + "numpy", + "cloudpickle", + ]), ) pytype_test( @@ -1194,8 +1664,12 @@ pytype_test( srcs = ["typing_test.py"], deps = [ "//jax", - "//jax:test_util", - ], + "//jax/_src:test_util", + "//jax/_src:typing", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_py_test( @@ -1203,8 +1677,8 @@ jax_py_test( srcs = ["util_test.py"], deps = [ "//jax", - "//jax:test_util", - ], + "//jax/_src:test_util", + ] + py_deps("absl/testing"), ) jax_py_test( @@ -1212,15 +1686,15 @@ jax_py_test( srcs = ["version_test.py"], deps = [ "//jax", - "//jax:test_util", - ], + "//jax/_src:test_util", + ] + py_deps("absl/testing"), ) jax_py_test( name = "warnings_util_test", srcs = ["warnings_util_test.py"], deps = [ - "//jax:test_util", + "//jax/_src:test_util", ] + py_deps("absl/testing"), ) @@ -1230,9 +1704,11 @@ jax_py_test( data = ["testdata/example_pjrt_plugin_config.json"], deps = [ "//jax", - "//jax:compiler", - "//jax:test_util", - ] + py_deps("absl/logging"), + "//jax/_src:compiler", + "//jax/_src:test_util", + ] + py_deps([ + "absl/logging", + ]), ) jax_py_test( @@ -1240,61 +1716,87 @@ jax_py_test( srcs = ["lru_cache_test.py"], deps = [ "//jax", - "//jax:lru_cache", - "//jax:test_util", - ] + py_deps("filelock"), + "//jax/_src:lru_cache", + "//jax/_src:test_util", + ] + py_deps([ + "filelock", + "absl/logging", + ]), ) jax_multiplatform_test( name = "compilation_cache_test", srcs = ["compilation_cache_test.py"], deps = [ - "//jax:compilation_cache_internal", - "//jax:compiler", - ], + "//jax/_src:compilation_cache_internal", + "//jax/_src:compiler", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "cache_key_test", srcs = ["cache_key_test.py"], deps = [ - "//jax:cache_key", - "//jax:compiler", - ], + "//jax/_src:cache_key", + "//jax/_src:compiler", + "//jax/_src:custom_partitioning", + ] + py_deps("absl/testing"), ) jax_multiplatform_test( name = "ode_test", srcs = ["ode_test.py"], - shard_count = { - "cpu": 10, - }, - deps = ["//jax:ode"], + deps = ["//jax/experimental:ode"] + py_deps([ + "absl/testing", + "numpy", + "scipy", + ]), ) jax_multiplatform_test( name = "key_reuse_test", srcs = ["key_reuse_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "roofline_test", srcs = ["roofline_test.py"], enable_backends = ["cpu"], + deps = py_deps("absl/testing"), ) jax_multiplatform_test( name = "x64_context_test", srcs = ["x64_context_test.py"], + enable_backends = ["cpu"], deps = [ - "//jax:experimental", - ], + "//jax/experimental", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "ann_test", srcs = ["ann_test.py"], - shard_count = 10, + backend_tags = { + "tpu": [ + "noasan", # Times out. + "notsan", # Times out. + ], + }, + deps = py_deps([ + "numpy", + "absl/testing", + ]), ) jax_py_test( @@ -1302,54 +1804,57 @@ jax_py_test( srcs = ["mesh_utils_test.py"], deps = [ "//jax", - "//jax:mesh_utils", - "//jax:test_util", - ], + "//jax/_src:test_util", + "//jax/experimental:mesh_utils", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "transfer_guard_test", srcs = ["transfer_guard_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + "cloudpickle", + ]), ) jax_multiplatform_test( name = "garbage_collection_guard_test", srcs = ["garbage_collection_guard_test.py"], + deps = py_deps("absl/testing"), ) -jax_multiplatform_test( +jax_py_test( name = "name_stack_test", srcs = ["name_stack_test.py"], + deps = [ + "//jax", + "//jax/_src:test_util", + ] + py_deps("absl/testing"), ) jax_multiplatform_test( name = "jaxpr_effects_test", srcs = ["jaxpr_effects_test.py"], - backend_tags = { - "gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143 - }, - enable_configs = [ - "cpu", - "gpu_h100", - "tpu_v2_1x1", - "tpu_v3_2x2", - "tpu_v4_2x2", - ], - tags = ["multiaccelerator"], + enable_backends = ["cpu"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "debugging_primitives_test", srcs = ["debugging_primitives_test.py"], - enable_configs = [ - "cpu", - "gpu_h100", - "tpu_v2_1x1", - "tpu_v3_2x2", - "tpu_v4_2x2", - "gpu_a100_shardy", - "tpu_v3_2x2_shardy", - ], + tags = ["multiaccelerator"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -1358,32 +1863,23 @@ jax_multiplatform_test( backend_tags = { "gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143 }, - enable_configs = [ - "tpu_v2_1x1", - "tpu_v3_2x2", - "tpu_v4_2x2", - "tpu_v3_2x2_shardy", - "gpu_p100x2_shardy", - ], tags = ["multiaccelerator"], deps = [ - "//jax:experimental", - ], + "//jax/experimental", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "debugger_test", srcs = ["debugger_test.py"], - disable_configs = [ - "cpu_shardy", # TODO(b/364547005): enable once pure callbacks are supported. - ], - enable_configs = [ - "cpu", - "gpu_h100", - "tpu_v2_1x1", - "tpu_v3_2x2", - "tpu_v4_2x2", - ], + tags = ["multiaccelerator"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -1393,71 +1889,65 @@ jax_multiplatform_test( args = [ "--jax_num_generated_cases=5", ], - backend_variant_args = { - "tpu_pjrt_c_api": ["--jax_num_generated_cases=1"], - }, enable_configs = [ "gpu_h100", "cpu", ], - shard_count = { - "cpu": 2, - "gpu": 2, - "tpu": 2, - }, - deps = py_deps("hypothesis"), + deps = py_deps([ + "absl/testing", + "numpy", + "hypothesis", + ]), ) jax_multiplatform_test( name = "mutable_array_test", srcs = ["mutable_array_test.py"], + shard_count = { + "cpu": 10, + }, + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( - name = "for_loop_test", - srcs = ["for_loop_test.py"], - shard_count = { - "cpu": 20, - "gpu": 10, - "tpu": 20, - }, + name = "stateful_rng_test", + srcs = ["stateful_rng_test.py"], + deps = [ + "//jax/experimental:random", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "ragged_collective_test", srcs = ["ragged_collective_test.py"], - disable_configs = [ - "tpu_pjrt_c_api", - ], enable_backends = [ "gpu", "tpu", ], - enable_configs = [ - "gpu_p100x2_shardy", - ], - shard_count = { - "gpu": 10, - "tpu": 10, - }, tags = [ "multiaccelerator", ], - deps = [ - "//jax:experimental", - ], + deps = py_deps("absl/testing"), ) jax_multiplatform_test( name = "shard_map_test", srcs = ["shard_map_test.py"], - enable_configs = [ - "gpu_p100x2_shardy", - "tpu_v3_2x2_shardy", + disable_configs = [ + "gpu_h100x2_tfrt", # TODO(b/419192167): Doesn't work ], + minimal_shard_count = { + "tpu": 50, + }, shard_count = { "cpu": 50, - "gpu": 10, + "gpu": 20, "tpu": 50, }, tags = [ @@ -1467,44 +1957,60 @@ jax_multiplatform_test( "notsan", ], # Times out under *SAN. deps = [ - "//jax:experimental", - "//jax:tree_util", - ], + "//jax/_src:tree_util", + "//jax/experimental:custom_partitioning", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "clear_backends_test", srcs = ["clear_backends_test.py"], + deps = py_deps("absl/testing"), ) jax_multiplatform_test( - name = "attrs_test", - srcs = ["attrs_test.py"], + name = "hijax_test", + srcs = ["hijax_test.py"], deps = [ - "//jax:experimental", - ], + "//jax/experimental:hijax", + ] + py_deps([ + "numpy", + "absl/testing", + ]), ) jax_multiplatform_test( name = "colocated_python_test", srcs = ["colocated_python_test.py"], deps = [ - "//jax:experimental_colocated_python", + "//jax/experimental:colocated_python", "//jax/extend:ifrt_programs", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "experimental_rnn_test", srcs = ["experimental_rnn_test.py"], + backend_tags = { + "gpu": ["noasan"], # Times out. + }, disable_configs = [ "gpu_a100", # Numerical precision problems. ], enable_backends = ["gpu"], shard_count = 15, deps = [ - "//jax:rnn", - ], + "//jax/experimental:rnn", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_py_test( @@ -1512,9 +2018,9 @@ jax_py_test( srcs = ["mosaic_test.py"], deps = [ "//jax", - "//jax:mosaic", - "//jax:test_util", - ], + "//jax/_src:test_util", + "//jax/experimental:mosaic", + ] + py_deps("absl/testing"), ) jax_py_test( @@ -1522,8 +2028,8 @@ jax_py_test( srcs = ["source_info_test.py"], deps = [ "//jax", - "//jax:test_util", - ], + "//jax/_src:test_util", + ] + py_deps("absl/testing"), ) jax_py_test( @@ -1531,75 +2037,80 @@ jax_py_test( srcs = ["package_structure_test.py"], deps = [ "//jax", - "//jax:test_util", - ], + "//jax/_src:test_util", + ] + py_deps("absl/testing"), ) jax_multiplatform_test( name = "logging_test", srcs = ["logging_test.py"], + deps = py_deps("absl/testing"), +) + +jax_py_test( + name = "absl_cpp_logging_test", + srcs = ["absl_cpp_logging_test.py"], + deps = [ + "//jax", + "//jax/_src:test_util", + ] + py_deps("absl/testing"), ) jax_multiplatform_test( name = "export_test", srcs = ["export_test.py"], - disable_configs = [ - "cpu_shardy", # TODO(b/355263220): enable once export is supported. - ], - enable_configs = [ - "cpu_shardy", - "gpu_p100x2_shardy", - "tpu_v3_2x2_shardy", - "tpu_v3_2x2", - ], - tags = [], + tags = ["multiaccelerator"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "shape_poly_test", srcs = ["shape_poly_test.py"], - disable_configs = [ - "gpu_a100", # TODO(b/269593297): matmul precision issues - ], enable_configs = [ "cpu", "cpu_x32", ], - shard_count = { - "cpu": 4, - "gpu": 6, - "tpu": 4, - }, + shard_count = 15, tags = [ "noasan", # Times out "nomsan", # Times out "notsan", # Times out ], deps = [ - "//jax:internal_test_harnesses", - ], + "//jax/_src:internal_test_harnesses", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "export_harnesses_multi_platform_test", srcs = ["export_harnesses_multi_platform_test.py"], disable_configs = [ - "gpu_a100", # TODO(b/269593297): matmul precision issues "gpu_h100", # Scarce resources. - "cpu_shardy", # TODO(b/355263220): enable once export is supported. ], + minimal_shard_count = { + "tpu": 8, + }, shard_count = { - "cpu": 40, - "gpu": 20, - "tpu": 20, + "cpu": 20, + "gpu": 30, + "tpu": 15, }, tags = [ "noasan", # Times out "nodebug", # Times out. ], deps = [ - "//jax:internal_test_harnesses", - ], + "//jax/_src:internal_test_harnesses", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -1607,25 +2118,63 @@ jax_multiplatform_test( srcs = ["export_back_compat_test.py"], tags = [], deps = [ - "//jax:internal_export_back_compat_test_data", - "//jax:internal_export_back_compat_test_util", + "//jax/_src:internal_export_back_compat_test_data", + "//jax/_src:internal_export_back_compat_test_util", + ] + py_deps([ + "absl/testing", + "numpy", + ]), +) + +jax_multiplatform_test( + name = "export_serialization_back_compat_test", + srcs = ["export_serialization_back_compat_test.py"], + enable_backends = [ + "cpu", + "gpu", + "tpu", ], + enable_configs = [ + "tpu_v3_x4", + "gpu_h100x2", + ], + tags = [], + deps = [ + "//jax/_src:internal_export_back_compat_test_data", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "fused_attention_stablehlo_test", srcs = ["fused_attention_stablehlo_test.py"], enable_backends = ["gpu"], - shard_count = { - "gpu": 4, - }, + shard_count = 8, tags = ["multiaccelerator"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "xla_metadata_test", srcs = ["xla_metadata_test.py"], - deps = ["//jax:experimental"], + deps = ["//jax/experimental:xla_metadata"] + py_deps("absl/testing"), +) + +jax_multiplatform_test( + name = "unary_ops_accuracy_test", + srcs = ["unary_ops_accuracy_test.py"], + enable_backends = [ + "tpu", + ], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_py_test( @@ -1633,8 +2182,8 @@ jax_py_test( srcs = ["pretty_printer_test.py"], deps = [ "//jax", - "//jax:test_util", - ], + "//jax/_src:test_util", + ] + py_deps("absl/testing"), ) jax_py_test( @@ -1642,9 +2191,9 @@ jax_py_test( srcs = ["source_mapper_test.py"], deps = [ "//jax", - "//jax:source_mapper", - "//jax:test_util", - ], + "//jax/_src:test_util", + "//jax/experimental:source_mapper", + ] + py_deps("absl/testing"), ) jax_py_test( @@ -1652,13 +2201,21 @@ jax_py_test( srcs = ["sourcemap_test.py"], deps = [ "//jax", - "//jax:test_util", - ], + "//jax/_src:sourcemap", + "//jax/_src:test_util", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "string_array_test", srcs = ["string_array_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -1670,6 +2227,7 @@ jax_multiplatform_test( "gpu_h100", ], tags = ["multiaccelerator"], + deps = py_deps("absl/testing"), ) jax_multiplatform_test( @@ -1679,6 +2237,13 @@ jax_multiplatform_test( shard_count = { "gpu": 4, }, + tags = [ + "multiaccelerator", + ], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_py_test( @@ -1686,14 +2251,27 @@ jax_py_test( srcs = ["custom_partitioning_sharding_rule_test.py"], deps = [ "//jax", - "//jax:experimental", - "//jax:test_util", - ], + "//jax/_src:custom_partitioning_sharding_rule", + "//jax/_src:test_util", + ] + py_deps("absl/testing"), +) + +jax_py_test( + name = "traceback_test", + srcs = ["traceback_test.py"], + deps = [ + "//jax", + "//jax/_src:test_util", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) exports_files( [ "api_test.py", + "custom_api_test.py", "array_test.py", "cache_key_test.py", "colocated_python_test.py", diff --git a/tests/absl_cpp_logging_test.py b/tests/absl_cpp_logging_test.py new file mode 100644 index 000000000000..ea641639ed23 --- /dev/null +++ b/tests/absl_cpp_logging_test.py @@ -0,0 +1,38 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest +import jax +from jax._src import test_util as jtu +from jax._src.lib import utils + + +# Note: This test modifies global logging library configuration knobs. +# We isolate it from other tests by running it as a separate test target. +@jtu.skip_under_pytest("Test must run in an isolated process") +class AbslCppLoggingTest(jtu.JaxTestCase): + + def test_vlogging(self): + utils.absl_set_min_log_level(0) # INFO + with jtu.capture_stderr() as stderr: + jax.jit(lambda x: x + 1)(1) + self.assertNotIn("hlo_pass_pipeline.cc", stderr()) + with jtu.capture_stderr() as stderr: + utils.absl_set_vlog_level("hlo_pass_pipeline", 1) + jax.jit(lambda x: x + 2)(1) + self.assertIn("hlo_pass_pipeline.cc", stderr()) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/ann_test.py b/tests/ann_test.py index 1d704c725c61..18bb51bec93b 100644 --- a/tests/ann_test.py +++ b/tests/ann_test.py @@ -179,7 +179,7 @@ def approx_max_k(qy, db): def test_vmap_after(self): - batch = 4 + batch = 8 qy_size = 128 db_size = 1024 feature_dim = 32 diff --git a/tests/aot_test.py b/tests/aot_test.py index daaeb8417d33..2b39e0e67b7e 100644 --- a/tests/aot_test.py +++ b/tests/aot_test.py @@ -16,12 +16,13 @@ import unittest from absl.testing import absltest import jax +from jax import lax +from jax._src import config from jax._src import core from jax._src import test_util as jtu -import jax._src.lib +from jax._src.lib import jaxlib_extension_version from jax._src.lib import xla_client as xc from jax.experimental import topologies -from jax.experimental.pjit import pjit from jax.experimental.serialize_executable import ( deserialize_and_load, serialize, @@ -42,12 +43,12 @@ class JaxAotTest(jtu.JaxTestCase): @jtu.run_on_devices('tpu', 'gpu') - def test_pickle_pjit_lower(self): + def test_pickle_jit_lower(self): def fun(x): return x * x - with jax.sharding.Mesh(np.array(jax.devices()), ('data',)): - lowered = pjit( + with jax.set_mesh(jax.sharding.Mesh(np.array(jax.devices()), ('data',))): + lowered = jax.jit( fun, in_shardings=P('data'), out_shardings=P(None, 'data') ).lower(core.ShapedArray(shape=(8, 8), dtype=np.float32)) @@ -62,16 +63,21 @@ def verify_serialization(lowered): jax.pmap(lambda x: x * x).lower( np.zeros((len(jax.devices()), 4), dtype=np.float32))) - def test_topology_pjit_serialize(self): + @jtu.skip_on_devices("tpu") # TODO(phawkins): This test is segfaulting on TPU + def test_topology_jit_serialize(self): try: aot_topo = topologies.get_topology_desc( platform=jax.devices()[0].platform ) - except NotImplementedError: + except (ValueError, NotImplementedError) as e: + assert ('topology_name is not specified' in str(e) or + 'topology not implemented' in str(e)) raise unittest.SkipTest('PJRT Topology not supported') if jtu.TEST_WITH_PERSISTENT_COMPILATION_CACHE.value: raise unittest.SkipTest('Compilation caching not yet supported.') + if jtu.is_device_cuda(): + raise unittest.SkipTest('Broken on GPU: b/442353988') @jax.jit def fn(x): @@ -103,7 +109,9 @@ def test_get_topology_from_devices(self): aot_topo = topologies.get_topology_desc( platform=jax.devices()[0].platform ) - except NotImplementedError: + except (ValueError, NotImplementedError) as e: + assert ('topology_name is not specified' in str(e) or + 'topology not implemented' in str(e)) raise unittest.SkipTest('PJRT Topology not supported') topo = xc.get_topology_for_devices(aot_topo.devices) @@ -122,10 +130,173 @@ def my_function(x): self.assertNotRegex(stablehlo, r"sine.* loc") hlo = lowered.as_text("hlo", debug_info=True) - self.assertRegex(hlo, r"sine.*metadata=.*source_file=.*") + self.assertRegex(hlo, r'sine.*metadata=.*[stack_frame_id|source_file]=.*') hlo = lowered.as_text("hlo") - self.assertNotRegex(hlo, r"sine.*metadata=.*source_file=.*") + self.assertNotRegex( + hlo, r'sine.*metadata=.*[stack_frame_id|source_file]=.*' + ) + + def test_constants_in_lowering_in_aot(self): + const_size = 100 + const = jax.random.uniform(jax.random.key(0), (const_size,), + dtype=np.float32) + + def my_function(x): + return jnp.sin(x) + const + + lowered = jax.jit(my_function).lower(np.full_like(const, 42., dtype=const.dtype)) + stablehlo = lowered.as_text("stablehlo") + if config.use_simplified_jaxpr_constants.value: + self.assertNotRegex(stablehlo, rf"stablehlo.constant dense.*tensor<{const_size}x") + self.assertLen(lowered._lowering.const_args, 1) + self.assertIs(lowered._lowering.const_args[0], const) + else: + self.assertRegex(stablehlo, rf"stablehlo.constant dense.*tensor<{const_size}x") + self.assertLen(lowered._lowering.const_args, 0) + + def test_with_constants(self): + const = jnp.arange(16.) + 42. # A distinctive shape and value + + @jax.jit + def f(x): + return const[0:8] + x + + inp = jnp.arange(8.) + compiled = f.lower(inp).compile() + self.assertLen(compiled.args_info[0], 1) # Not including const_args + self.assertLen(compiled.in_avals[0], 1) + if config.use_simplified_jaxpr_constants.value: + self.assertLen(compiled._params.const_args, 1) + self.assertIs(compiled._params.const_args[0], const) + else: + self.assertLen(compiled._params.const_args, 0) + self.assertArraysEqual(compiled(inp), const[0:8] + inp) + self.assertCacheMisses(lambda: compiled(inp), cpp=0, aot_call=0) + + @jtu.parameterized_filterable( + kwargs=[ + dict(use_np=use_np, lower=lower, compile=compile, exec=exec) + for use_np in (False, True) + for lower in (False, True) + for compile in (False, True) + for exec in (False, True) + ]) + def test_with_constants_enable_x64(self, *, use_np, lower, compile, exec): + # Closed-over constant is 64-bit. Each of lowering, compilation, and + # execution can be run in 64-bit or 32-bit mode. + with config.enable_x64(True): + arange = np.arange if use_np else jnp.arange + const = arange(8, dtype=np.int64) + 42 + + @jax.jit + def f(x): + return lax.convert_element_type(const, np.float32) + x + + inp = np.arange(8., dtype=np.float32) + with config.enable_x64(True) if lower else contextlib.nullcontext(): + lowered = f.lower(inp) + with config.enable_x64(True) if compile else contextlib.nullcontext(): + compiled = lowered.compile() + + def run(): + with config.enable_x64(True) if exec else contextlib.nullcontext(): + return compiled(inp) + + self.assertLen(compiled.args_info[0], 1) # Not including const_args + self.assertLen(compiled.in_avals[0], 1) + if config.use_simplified_jaxpr_constants.value: + self.assertLen(compiled._params.const_args, 1) + self.assertLen(compiled._executable.in_avals, 2) + expected_dtype = np.int64 + if not config.enable_x64.value and use_np and not lower: + expected_dtype = np.int32 + self.assertEqual(compiled._executable.in_avals[0].dtype, expected_dtype) + if expected_dtype is np.int64: # Otherwise, we made a copy of the const + if use_np: + self.assertIs(np.asarray(compiled._params.const_args[0]), const) + else: + self.assertIs(compiled._params.const_args[0], const) + else: + self.assertLen(compiled._params.const_args, 0) + self.assertLen(compiled._executable.in_avals, 1) + + # In some cases we expect errors: in 32-bit mode, lowered with 64-bit mode + # and execute in 32-bit mode. + if (config.use_simplified_jaxpr_constants.value and + not config.enable_x64.value and + use_np and lower and not exec): + with self.assertRaisesRegex( + xc.XlaRuntimeError, + "got buffer with incompatible size"): + run() + return + + self.assertArraysEqual(run(), + lax.convert_element_type(const, inp.dtype) + inp) + # Trigger cache hit + self.assertCacheMisses(run, cpp=0, aot_call=0) + + def test_with_ref_constants(self): + x_ref = core.new_ref(0) + + @jax.jit + def f(x): + x_ref[...] += x + + f_lowered = f.lower(1) + with self.assertRaisesRegex(ValueError, 'serialize with a closed-over'): + serialized, in_tree, out_tree = serialize(f_lowered.compile()) + + @jtu.run_on_devices('gpu', 'tpu') + def test_mismatched_backends_raises(self): + @jax.jit + def f(x): + return x * 2 + + x = jnp.arange(1) + f_lowered = f.lower(x) + serialized, in_tree, out_tree = serialize(f_lowered.compile()) + with self.assertRaisesRegex( + ValueError, + 'Execution devices belong to a client other than `backend`'): + deserialize_and_load(serialized, in_tree, out_tree, backend='cpu', + execution_devices=jax.devices()[:1]) + + @jtu.run_on_devices('gpu') + def test_deviceless_aot_compile(self): + if jaxlib_extension_version < 393: + raise unittest.SkipTest('Test requires jaxlib extension version 393 or higher') + target_config = xc.get_topology_for_devices(jax.devices()).target_config + gpu_platform = jax.devices()[0].platform # Capture before switching to cpu + with jtu.global_config_context(jax_platforms="cpu"): + topology = topologies.get_topology_desc( + platform=gpu_platform, + target_config=target_config, + topology="1x1x1", + ) + assert topology.devices[0].client.runtime_type == "compile_only_runtime" + mesh = topologies.make_mesh(topo=topology, mesh_shape=(1,), axis_names=("x",)) + x = jax.ShapeDtypeStruct( + shape=(2, 2), + dtype=jnp.float32, + sharding=jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("x")) + ) + compiled = jax.jit(lambda x: jnp.sum(x * x)).lower(x).compile() + serialized_executable, _, _ = serialize(compiled) + + _, in_tree = jax.tree.flatten(((0,), {})) + _, out_tree = jax.tree.flatten(0) + compiled = deserialize_and_load( + serialized_executable, + in_tree, + out_tree, + backend=gpu_platform, + execution_devices=jax.devices()[:1] + ) + input = jnp.array([[0., 1.], [2., 3.]], dtype=jnp.float32, device=jax.devices()[0]) + result = compiled(input) + self.assertEqual(result, 14.) if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/api_test.py b/tests/api_test.py index aece7b19fdfb..15d8b541c945 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -16,7 +16,6 @@ import collections import collections.abc -from collections.abc import Callable import concurrent.futures from contextlib import contextmanager import copy @@ -34,6 +33,7 @@ import re import subprocess import sys +import threading import traceback import types from typing import NamedTuple @@ -43,7 +43,6 @@ from absl import logging from absl.testing import absltest, parameterized import jax -from jax import custom_derivatives as custom_derivatives_public from jax import device_put, float0, grad, hessian, jacfwd, jacrev, jit from jax import lax from jax import tree_util @@ -51,30 +50,24 @@ from jax._src import array from jax._src import config from jax._src import core -from jax._src import custom_derivatives +from jax._src import dispatch from jax._src import linear_util as lu +from jax._src import literals from jax._src import test_util as jtu from jax._src import xla_bridge from jax._src import debugging -from jax._src import pjit as pjit_lib +from jax._src import sharding_impls from jax._src.ad_checkpoint import saved_residuals from jax._src.interpreters import ad as ad_internal from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.compilation_cache import is_persistent_cache_enabled -from jax._src.lib import xla_extension import jax._src.util as jax_util -from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint -import jax.custom_batching -import jax.custom_derivatives -import jax.custom_transpose -import jax.experimental.custom_dce +from jax.ad_checkpoint import checkpoint_name from jax.errors import (UnexpectedTracerError, TracerIntegerConversionError, ConcretizationTypeError, TracerBoolConversionError) -from jax.experimental import pjit from jax.interpreters import ad from jax.interpreters import batching -from jax.interpreters import xla import jax.numpy as jnp from jax.sharding import PartitionSpec as P import numpy as np @@ -99,6 +92,14 @@ def my_function(): jitted = jit(my_function) self.assertEqual(repr(jitted), f"") + def test_jit_decorator_factory(self): + @jit(static_argnames=['flag']) + def func(x, flag): + return x if flag else -x + + self.assertEqual(func(1, True), 1) + self.assertEqual(func(1, False), -1) + def test_fun_name(self): def my_function(): return @@ -272,12 +273,8 @@ def test_jit_device(self): _check_instance(self, x) self.assertEqual(x.devices(), {device}) - @parameterized.named_parameters( - ('jit', jax.jit), - ('pjit', pjit.pjit), - ) @jtu.skip_on_devices("cpu") - def test_jit_default_device(self, module): + def test_jit_default_device(self): if jax.device_count() == 1: raise unittest.SkipTest("Test requires multiple devices") @@ -287,7 +284,7 @@ def test_jit_default_device(self, module): test_device = jax.devices()[-1] self.assertNotEqual(system_default_device, test_device) - f = module(lambda x: x + 1) + f = jax.jit(lambda x: x + 1) self.assertEqual(f(1).devices(), system_default_devices) with jax.default_device(test_device): @@ -302,9 +299,9 @@ def test_jit_default_device(self, module): with jtu.ignore_warning(category=DeprecationWarning, message="backend and device argument"): self.assertEqual( - module(f, device=system_default_device)(1).devices(), + jax.jit(f, device=system_default_device)(1).devices(), system_default_devices) - out = module(f, backend="cpu")(1) + out = jax.jit(f, backend="cpu")(1) self.assertEqual(next(iter(out.devices())).platform, "cpu") # Sticky input device overrides default_device @@ -418,12 +415,20 @@ def f(args_list): # Jit and Donate arguments def test_donate_argnames_signature_fail(self): + class NoSignature: + @property + def __signature__(self): + raise TypeError("no signature") + def __call__(self, *args, **kwargs): + return None + fun = NoSignature() + inp = np.arange(4) with self.assertRaisesRegex( ValueError, "Getting the signature of function.*failed. Pass donate_argnums " "instead of donate_argnames."): - jax.jit(np.dot, donate_argnames='a')(inp, inp) + jax.jit(fun, donate_argnames='a')(inp, inp) @parameterized.named_parameters( ("argnums", "donate_argnums", (0, 1)), @@ -510,6 +515,27 @@ def test_device_put_aliasing(self): may_alias=False, donate=False) self.assertNotEqual(id(arr), id(out)) + def test_device_put_aliasing_with_diff_compatible_sharding(self): + if jax.device_count() < 2: + raise unittest.SkipTest("Test requires >= 2 devices") + + mesh = jax.sharding.Mesh( + np.array(jax.devices()[:2]).reshape((2, 1)), ("x", "y") + ) + x = jax.device_put( + np.arange(16).reshape((4, 4)), + jax.NamedSharding(mesh, P("x", None)), + ) + expanded_mesh = jax.sharding.Mesh( + np.array(jax.devices()[:2]).reshape((1, 2, 1)), ("replicas", "x", "y") + ) + dst_sharding = jax.NamedSharding(expanded_mesh, P("x", None)) + # No transfer should happen because the array is aliased to compatible + # sharding that only has a mesh with an additional dimension of size 1. + with jax.transfer_guard_device_to_device("disallow_explicit"): + res = jax.device_put(x, dst_sharding, may_alias=True) + self.assertEqual(dst_sharding, res.sharding) + @parameterized.named_parameters( ("argnums", "donate_argnums", 0), ("argnames", "donate_argnames", 'x'), @@ -529,7 +555,7 @@ def test_jit_donate_weak_type(self, argnum_type, argnum_val): def test_jnp_array_copy(self, argnum_type, argnum_val): # https://github.com/jax-ml/jax/issues/3412 - @partial(jit, **{argnum_type: argnum_val}) + @jit(**{argnum_type: argnum_val}) def _test(array): return array.at[0].set(77) @@ -543,7 +569,7 @@ def _test(array): @jtu.device_supports_buffer_donation() def test_specify_donate_argnums_and_argnames(self): - @partial(jax.jit, donate_argnums=0, donate_argnames=('inp2', 'inp3')) + @jax.jit(donate_argnums=0, donate_argnames=('inp2', 'inp3')) def f(inp1, inp2, inp3): return inp1 * 2, inp2 * 2, inp3 * 2 @@ -561,7 +587,7 @@ def test_resolve_argnums_signature_fail(self): @jtu.device_supports_buffer_donation() def test_donate_argnames_with_args(self): - @partial(jax.jit, donate_argnames='inp1') + @jax.jit(donate_argnames='inp1') def f(inp1): return inp1 * 2 @@ -571,7 +597,7 @@ def f(inp1): @jtu.device_supports_buffer_donation() def test_donate_argnums_with_kwargs(self): - @partial(jax.jit, donate_argnums=0) + @jax.jit(donate_argnums=0) def f(inp1): return inp1 * 2 @@ -621,7 +647,7 @@ def add(x, y): ('argnums', {'donate_argnums': (0, 1)}) ) def test_dict_donation(self, jit_kwargs): - @partial(jax.jit, **jit_kwargs) + @jax.jit(**jit_kwargs) def f(z, y, x): return z, y, x @@ -639,7 +665,7 @@ def f(z, y, x): ('argnums', {'donate_argnums': (0, 1)}) ) def test_dict_donation_args_kwargs(self, jit_kwargs): - @partial(jax.jit, **jit_kwargs) + @jax.jit(**jit_kwargs) def f(z, y, x): return z, y, x @@ -683,7 +709,28 @@ def f(x, y): f.clear_cache() gc.collect() num_live = len(client.live_executables()) - self.assertEqual(num_live_initial, num_live) + # You would hope that these would be equal, but in practice we sometimes + # observe *fewer* live executables after this code runs in threaded tests. + # I suspect this is an artifact of other caches and garbage collection. + self.assertGreaterEqual(num_live_initial, num_live) + + @jtu.thread_unsafe_test() # close_jaxpr cache is shared across threads + def test_pe_close_jaxpr_cache_leak(self): + @jax.jit + def f(x): + return lax.cond(x, lambda: x, lambda: ~ x) + + jaxpr = f.trace(True).jaxpr + jax_util.clear_all_caches() + + res1 = pe.close_jaxpr(jaxpr.jaxpr) + res2 = pe.close_jaxpr(jaxpr.jaxpr) + self.assertIs(res1, res2) + keys_1 = pe.close_jaxpr.cache_keys() + self.assertGreater(len(keys_1), 0) + del jaxpr, res1, res2, keys_1 + keys_2 = pe.close_jaxpr.cache_keys() + self.assertEmpty(keys_2, 0) def test_jit_shallow_copy(self): def f(x): @@ -913,35 +960,27 @@ def test_cpp_jitted_function_returns_PyBuffer(self): self.assertIsInstance(out.sharding, jax.sharding.SingleDeviceSharding) self.assertIsInstance(out, array.ArrayImpl) - @parameterized.named_parameters( - ('jit', jax.jit), - ('pjit', pjit.pjit) - ) @jtu.skip_on_devices("cpu") - def test_explicit_backend(self, module): + def test_explicit_backend(self): f = lambda x: x + 1 with jtu.ignore_warning(category=DeprecationWarning, message="backend and device argument"): - jitted_f = module(f, backend=jtu.device_under_test()) - jitted_f_cpu = module(f, backend="cpu") + jitted_f = jax.jit(f, backend=jtu.device_under_test()) + jitted_f_cpu = jax.jit(f, backend="cpu") result = jitted_f(1.) result_cpu = jitted_f_cpu(1.) self.assertEqual(list(result.devices())[0].platform, jtu.device_under_test()) self.assertEqual(list(result_cpu.devices())[0].platform, "cpu") - @parameterized.named_parameters( - ('jit', jax.jit), - ('pjit', pjit.pjit) - ) @jtu.skip_on_devices("cpu") - def test_device_to_device_copy_between_backends(self, module): + def test_device_to_device_copy_between_backends(self): # b/186624243 f = lambda x: x + 1 with jtu.ignore_warning(category=DeprecationWarning, message="backend and device argument"): - jitted_f = module(f, backend=jtu.device_under_test()) - jitted_f_cpu = module(f, backend="cpu") + jitted_f = jax.jit(f, backend=jtu.device_under_test()) + jitted_f_cpu = jax.jit(f, backend="cpu") x = np.arange(30).reshape(1, 10, 3) result = jitted_f(x) @@ -955,7 +994,7 @@ def test_device_to_device_copy_between_backends(self, module): @jtu.ignore_warning(category=DeprecationWarning, message="backend and device argument") def test_mismatched_nested_backends(self): - @partial(jax.jit, backend=jtu.device_under_test()) + @jax.jit(backend=jtu.device_under_test()) def f(x): return jax.jit(lambda x: x + 1, backend="cpu")(x) @@ -1361,7 +1400,7 @@ def f(x): "exec_time_optimization_effort": 0.0, })(1.0) # doesn't crash. - with self.assertRaisesRegex(xla_extension.XlaRuntimeError, "No such"): + with self.assertRaisesRegex(jax.errors.JaxRuntimeError, "No such"): f_jit = jit( f, compiler_options={ @@ -1402,12 +1441,12 @@ def f(x): lowered = f_jit.lower(1.) self.assertRaisesRegex( - xla_extension.XlaRuntimeError, "No such compile option: 'invalid_key'", + jax.errors.JaxRuntimeError, "No such compile option: 'invalid_key'", lambda: lowered.compile( compiler_options={"invalid_key": "invalid_value"})) self.assertRaisesRegex( - xla_extension.XlaRuntimeError, "is not a valid bool value.", + jax.errors.JaxRuntimeError, "is not a valid bool value.", lambda: lowered.compile( compiler_options={"xla_embed_ir_in_executable": "invalid_value"})) @@ -1422,7 +1461,7 @@ def f(x): # We should still error on invalid options after some valid compiles with self.assertRaisesRegex( - xla_extension.XlaRuntimeError, "No such compile option: 'invalid_key'"): + jax.errors.JaxRuntimeError, "No such compile option: 'invalid_key'"): jit(f, compiler_options={"invalid_key": "invalid_value"})(1.) def test_lower_compile_with_compiler_options_multiple(self): @@ -1447,7 +1486,7 @@ def f(x): # We should still error on invalid options after some valid compiles self.assertRaisesRegex( - xla_extension.XlaRuntimeError, "No such compile option: 'invalid_key'", + jax.errors.JaxRuntimeError, "No such compile option: 'invalid_key'", lambda: lowered.compile( compiler_options={"invalid_key": "invalid_value"})) @@ -1467,7 +1506,7 @@ def f(d) -> float: def test_jit_static_argnums_requires_type_equality(self): # See: https://github.com/jax-ml/jax/pull/9311 - @partial(jit, static_argnums=(0,)) + @jit(static_argnums=(0,)) def f(k): assert python_should_be_executing return k @@ -1482,7 +1521,7 @@ def f(k): def test_caches_depend_on_axis_env(self): # https://github.com/jax-ml/jax/issues/9187 - f = lambda: lax.psum(1, "i") + f = lambda: lax.axis_size("i") g = jax.jit(f) expected = jax.vmap(f, axis_name="i", axis_size=2, out_axes=None)() ans = jax.vmap(g, axis_name="i", axis_size=2, out_axes=None)() @@ -1534,6 +1573,17 @@ def f(x): with jax.no_tracing(): _ = f(y) # crash! + def test_no_execution(self): + @jax.jit + def f(): + return jnp.ones(3) + + f() # no crash + with self.assertRaisesRegex(RuntimeError, 'no_execution'): + with jax.no_execution(): + f() # crash + f() # no crash + class APITest(jtu.JaxTestCase): @@ -1562,7 +1612,7 @@ def f(x): def test_grad_wrap(self, transform): # Ensures that transforms wrap transformed functions with the correct signature. - @partial(jit, static_argnames=['flag']) + @jit(static_argnames=['flag']) @transform def my_function(x, flag): return x if flag else jnp.zeros_like(x) @@ -1626,6 +1676,27 @@ def f(x): assert g(2.0) == 4.0 assert len(side) == 1 + @jtu.thread_unsafe_test() # Concurrent ache eviction means we may retrace. + def test_fwd_and_bwd(self): + def f(x, W): + return x @ W + + x = W = cot_out = jnp.ones((4,4)) + expected_y, f_vjp = api.vjp(f, x, W) + expected_cot_x, expected_cot_W = f_vjp(cot_out) + + fwd, bwd = api.fwd_and_bwd(f, argnums=(0,1)) + y, residuals = fwd(x, W) + cot_x, cot_W = bwd(residuals, cot_out) + + self.assertArraysAllClose(y, expected_y) + self.assertArraysAllClose(cot_x, expected_cot_x) + self.assertArraysAllClose(cot_W, expected_cot_W) + + with jax.no_tracing(): + y, residuals = fwd(x, W) + cot_x, cot_W = bwd(residuals, cot_out) # no recompilation + @parameterized.named_parameters( {"testcase_name": f"_{transform.__name__}", "transform": transform} for transform in [grad, jacfwd, jacrev]) @@ -1816,19 +1887,19 @@ def test_device_put_and_get(self): x2 = api.device_get(dx) self.assertNotIsInstance(x2, jax.Array) self.assertIsInstance(x2, np.ndarray) - assert np.all(x == x2) + self.assertArraysEqual(x2, x) y = [x, (2 * x, 3 * x)] dy = api.device_put(y) y2 = api.device_get(dy) self.assertIsInstance(y2, list) self.assertIsInstance(y2[0], np.ndarray) - assert np.all(y2[0] == x) + self.assertArraysEqual(y2[0], x) self.assertIsInstance(y2[1], tuple) self.assertIsInstance(y2[1][0], np.ndarray) - assert np.all(y2[1][0] == 2 * x) + self.assertArraysEqual(y2[1][0], 2 * x) self.assertIsInstance(y2[1][1], np.ndarray) - assert np.all(y2[1][1] == 3 * x) + self.assertArraysEqual(y2[1][1], 3 * x) def test_device_put_sharding(self): mesh = jax.sharding.Mesh(jax.devices(), ('x',)) @@ -1937,6 +2008,75 @@ def test_device_put_sharding_mismatched_tree_different_leaf_count(self): ): jax.device_put((x, y, z), device=(s1, s2)) + def test_internal_device_put_with_device(self): + # Hitting the cache for a single-device jitted execution while using a numpy + # array calls internal `DevicePutWithDevice`. + f = jax.jit(lambda x: x + 1) + f(np.arange(8)) + + with jtu.count_internal_device_puts() as counts: + f(np.arange(8)) + self.assertEqual(counts(), {"device_put_with_device": 1}) + + def test_internal_device_put_fully_replicated(self): + if jax.device_count() < 2: + raise unittest.SkipTest("Test requires >= 2 devices") + + # Creating an array from a numpy array with a fully-replicated sharding + # calls internal `DevicePutWithSharding`, taking the fully-replicated sub + # case. + mesh = jax.sharding.Mesh(np.array(jax.devices()[:2]), "x") + sharding = jax.NamedSharding(mesh, P()) + + with jtu.count_internal_device_puts() as counts: + jax.device_put(np.arange(8), sharding) + self.assertEqual( + counts(), + {"device_put_with_sharding": 1, "device_put_fully_replicated": 1}, + ) + + def test_internal_device_put_batched(self): + if jax.device_count() < 2: + raise unittest.SkipTest("Test requires >= 2 devices") + + # Creating an array from a numpy array with a non-fully-replicated sharding + # calls internal `DevicePutWithSharding`, performing batched creation of a + # multi-shard array. + mesh = jax.sharding.Mesh(np.array(jax.devices()[:2]), "x") + sharding = jax.NamedSharding(mesh, P("x")) + + with jtu.count_internal_device_puts() as counts: + jax.device_put(np.arange(8), sharding) + self.assertEqual( + counts(), {"device_put_with_sharding": 1, "device_put_batched": 1} + ) + + def test_internal_device_put_assembled(self): + if jax.device_count() < 2: + raise unittest.SkipTest("Test requires >= 2 devices") + + # Creating an array from per-device JAX arrays calls internal + # `DevicePutWithSharding`, performing per-shard array adoption followed by + # assembly. + mesh = jax.sharding.Mesh(np.array(jax.devices()[:2]), "x") + sharding = jax.NamedSharding(mesh, P("x")) + + arr = np.arange(8) + per_device_arrs = { + # Use uncommitted arrays that are not aligned with the destination + # sharding so that we trigger `BatchedDevicePut`. + sharding_impls.hashed_index(index): jnp.array(arr[index]) + for _, index in sharding.devices_indices_map(arr.shape).items() + } + data_callback = lambda index: per_device_arrs[ + sharding_impls.hashed_index(index) + ] + with jtu.count_internal_device_puts() as counts: + jax.make_array_from_callback(arr.shape, sharding, data_callback) + self.assertEqual( + counts(), {"device_put_with_sharding": 1, "device_put_assembled": 1} + ) + def test_device_put_custom_type_not_accepting_none_leaves(self): class CustomNode(list): @@ -1949,6 +2089,39 @@ def unflatten(unused_aux_data, children): tree_util.register_pytree_node(CustomNode, lambda x: (x, None), unflatten) jax.device_put(CustomNode([0.1])) + def test_device_put_literals(self): + self.assertEqual( + np.dtype(np.int32), + jax.device_put(literals.TypedInt(1, np.dtype(np.int32))).dtype) + self.assertEqual( + np.dtype(np.int64), + jax.device_put(literals.TypedInt(1, np.dtype(np.int64))).dtype) + self.assertEqual( + np.dtype(np.float32), + jax.device_put(literals.TypedFloat(1, np.dtype(np.float32))).dtype) + self.assertEqual( + np.dtype(np.float64), + jax.device_put(literals.TypedFloat(1, np.dtype(np.float64))).dtype) + self.assertEqual( + np.dtype(np.complex64), + jax.device_put(literals.TypedComplex( + 1,np.dtype(np.complex64))).dtype) + if jtu.device_under_test() != "tpu": + # The TPU compiler does not support complex128. + self.assertEqual( + np.dtype(np.complex128), + jax.device_put(literals.TypedComplex( + 1, np.dtype(np.complex128))).dtype) + self.assertEqual( + np.dtype(np.int32), + jax.device_put(literals.TypedNdArray(np.array([1], dtype=np.int32), + weak_type=False)).dtype) + self.assertEqual( + np.dtype(np.int64), + jax.device_put(literals.TypedNdArray(np.array([1], dtype=np.int64), + weak_type=False)).dtype) + + def test_vmap_inconsistent_sizes_constructs_proper_error_message(self): def f(x1, x2, g): return g(x1, x2) @@ -1979,7 +2152,7 @@ def f(x1, x2, a3): def test_vmap_inconsistent_sizes_constructs_proper_error_message_starargs(self): # regression test for https://github.com/jax-ml/jax/issues/26908 def f(x, *args): - return x - functools.reduce(jnp.add, args) + return x - sum(args) with self.assertRaisesRegex( ValueError, @@ -1987,37 +2160,6 @@ def f(x, *args): ): jax.vmap(f)(jnp.ones(4), jnp.ones(2), jnp.ones(2)) - def test_vmap_sentinel(self): - - @jax.tree_util.register_dataclass - @dataclasses.dataclass - class Foo: - x: jax.Array - - def __init__(self, x): - nonlocal saw_sentinel - if x is jax._src.api_util.SENTINEL: - saw_sentinel += 1 - self.x = x - - x = jnp.arange(10) - - # assert that sentinel is seen once for vmap in_axes - saw_sentinel = 0 - jax.vmap(lambda f: f.x)(Foo(x)) - self.assertEqual(saw_sentinel, 1) - - # assert that sentinel is seen once for vmap out_axes - saw_sentinel = 0 - jax.vmap(Foo)(x) - self.assertEqual(saw_sentinel, 1) - - # assert that sentinel is seen twice with vmap in_axes and out_axes - saw_sentinel = 0 - jax.vmap(lambda f: Foo(f.x + 1))(Foo(x)) - self.assertEqual(saw_sentinel, 2) - - def test_device_get_scalar(self): x = np.arange(12.).reshape((3, 4)).astype("float32") x = api.device_put(x) @@ -2031,7 +2173,7 @@ def test_device_get_scalar(self): y2 = api.device_get(y) self.assertIsInstance(y2, list) self.assertIsInstance(y2[0], np.ndarray) - assert np.all(y2[0] == x) + self.assertArraysEqual(y2[0], x) self.assertIsInstance(y2[1], int) self.assertEqual(y2[1], 2) @@ -2067,7 +2209,7 @@ def test_device_put_across_platforms(self): self.assertEqual(x.devices(), {cpu_device}) def test_device_put_on_single_device_donated_buffer_fails(self): - @partial(jax.jit, donate_argnums=0) + @jax.jit(donate_argnums=0) def f(inp1): return inp1 * 2 @@ -2083,7 +2225,7 @@ def f(inp1): result.block_until_ready() def test_device_put_on_multi_device_donated_buffer_fails(self): - @partial(jax.jit, donate_argnums=0) + @jax.jit(donate_argnums=0) def f(inp1): return inp1 * 2 @@ -2112,11 +2254,11 @@ def test_jacobian(self): x = R(3) f = lambda x: jnp.dot(A, x) - assert np.allclose(jacfwd(f)(x), A) - assert np.allclose(jacrev(f)(x), A) + self.assertAllClose(jacfwd(f)(x), A) + self.assertAllClose(jacrev(f)(x), A) f = lambda x: jnp.tanh(jnp.dot(A, x)) - assert np.allclose(jacfwd(f)(x), jacrev(f)(x)) + self.assertAllClose(jacfwd(f)(x), jacrev(f)(x)) @jax.default_matmul_precision("float32") def test_hessian(self): @@ -2125,7 +2267,7 @@ def test_hessian(self): x = R(4) f = lambda x: jnp.dot(x, jnp.dot(A, x)) - assert np.allclose(hessian(f)(x), A + A.T) + self.assertAllClose(hessian(f)(x), A + A.T) @jax.default_matmul_precision("float32") def test_hessian_holomorphic(self): @@ -2134,7 +2276,8 @@ def test_hessian_holomorphic(self): x = R(4).astype('complex64') * (1 + 2j) f = lambda x: jnp.dot(x, jnp.dot(A.astype(x.dtype), x)) - assert np.allclose(hessian(f, holomorphic=True)(x), A + A.T) + self.assertAllClose( + hessian(f, holomorphic=True)(x), (A + A.T).astype(x.dtype)) @jax.default_matmul_precision("float32") def test_hessian_aux(self): @@ -2144,17 +2287,17 @@ def test_hessian_aux(self): f = lambda x: (jnp.dot(x, jnp.dot(A, x)), x) h, aux = hessian(f, has_aux=True)(x) - assert np.allclose(h, A + A.T) - assert np.allclose(aux, x) + self.assertAllClose(h, A + A.T) + self.assertAllClose(aux, x) def test_std_basis(self): basis = api._std_basis(jnp.zeros(3)) assert getattr(basis, "shape", None) == (3, 3) - assert np.allclose(basis, np.eye(3)) + self.assertAllClose(basis, np.eye(3)) basis = api._std_basis(jnp.zeros((3, 3))) assert getattr(basis, "shape", None) == (9, 3, 3) - assert np.allclose(basis, np.eye(9).reshape(9, 3, 3)) + self.assertAllClose(basis, np.eye(9).reshape(9, 3, 3)) basis = api._std_basis([0., (jnp.zeros(3), jnp.zeros((3, 4)))]) assert isinstance(basis, list) and len(basis) == 2 @@ -2798,35 +2941,10 @@ def f(x): self.assertEqual(count(), 1) - @jtu.thread_unsafe_test() # jit cache misses aren't thread safe - def test_jit_infer_params_cache(self): - def f(x): - return x - - f_jit = jax.jit(f) - - def g(x): - x = f_jit(x) # noqa: F821 - x = f_jit(x) # noqa: F821 - return x - - g_jit = jax.jit(g) - - inp = np.arange(8) - with jtu.count_jit_infer_params_cache_miss() as count: - g_jit(inp) - - self.assertDictEqual(count, {f: 1, g: 1}) - cache_size = pjit_lib._infer_params_cached.cache_info().currsize - del count, f, f_jit, g, g_jit - # Cache should only keep a weak reference to f and g. - self.assertLess(pjit_lib._infer_params_cached.cache_info().currsize, - cache_size, msg=pjit_lib._infer_params_cached.cache_keys()) - def test_eval_shape_out_shardings(self): s = jax.sharding.SingleDeviceSharding(jax.devices()[0]) - @partial(jax.jit, out_shardings=s) + @jax.jit(out_shardings=s) def f(x): return x * 2 @@ -2954,6 +3072,14 @@ def cond(pred): self.assertEqual(value, 1.) self.assertEqual(grd, np.zeros(shape=(), dtype=float0)) + def test_grad_of_bool_vjp3(self): + def cond(pred): + return lax.cond(pred, lambda _: 1., lambda _: 2., 1.) + value, f_vjp = api.vjp(cond, True) + grd, = f_vjp(1.) + self.assertEqual(value, 1.) + self.assertEqual(grd, np.zeros(shape=(), dtype=float0)) + def test_grad_of_int_index(self): grad_x, grad_i = api.grad(lambda x, i: x[i], argnums=(0, 1), allow_int=True)(np.ones(2), 1) @@ -2979,10 +3105,11 @@ def test_float0_reshape(self): def test_float0_error(self): # float0 is incompatible with other dtypes float0_array = jax.grad(lambda x: x+0., allow_int=True)(1) + self.assertEqual(float0_array.dtype, dtypes.float0) error_text = "float0s do not support any operations by design" with self.assertRaisesRegex(TypeError, error_text): - # dispatch via Array + # dispatch via Array.__add__ and hence jax.numpy _ = float0_array + jnp.zeros(()) with self.assertRaisesRegex(TypeError, error_text): @@ -3040,6 +3167,75 @@ def e(x): self.assertIn("stablehlo.cosine", stablehlo) self.assertIn("stablehlo.sine", stablehlo) + def test_constants_not_in_lowering_jit(self): + if not config.use_simplified_jaxpr_constants.value: + self.skipTest("Works only with simplified Jaxpr consts") + const_size = 100 + const = jax.random.uniform(jax.random.key(0), (const_size,), + dtype=np.float32) + + @jax.jit + def f(): + return jax.jit(lambda: const + 1.)() + + with jtu.collect_lowered_jaxprs() as collection: + res = f() + res = f() + self.assertAllClose(const + 1., res) + + for j, j_module in collection: + self.assertNotRegex(str(j_module), + f"stablehlo.constant dense.*tensor<{const_size}x") + + def test_basic_vjp3(self): + f = jax.jit(lambda x: jnp.sin(jnp.sin(x))) + _, f_vjp = jax.vjp(f, 1.) + g, = f_vjp(1.0) + self.assertAllClose(g, jnp.cos(jnp.sin(1.)) * jnp.cos(1.), check_dtypes=False) + + def test_constants_not_in_lowering_scan(self): + if not config.use_simplified_jaxpr_constants.value: + self.skipTest("Works only with simplified Jaxpr consts") + const_size = 100 + const = jax.random.uniform(jax.random.key(0), (const_size,), + dtype=np.float32) + def f(): + def scan_body(carry, x): + return const, None # Closed over and return + return lax.scan(jax.jit(scan_body), + jnp.zeros((const_size,), dtype=np.float32), # ignored + jnp.zeros((8, const_size), dtype=np.float32)) + + with jtu.collect_lowered_jaxprs() as collection: + res, _ = f() + res, _ = f() + self.assertAllClose(const, res) + + for j, j_module in collection: + self.assertNotRegex(str(j_module), + f"stablehlo.constant dense.*tensor<{const_size}x") + + def test_constants_not_in_lowering_cond(self): + if not config.use_simplified_jaxpr_constants.value: + self.skipTest("Works only with simplified Jaxpr consts") + const_size = 100 + const = jax.random.uniform(jax.random.key(0), (const_size,), + dtype=np.float32) + + def f(x): + return lax.cond(x >= 0., jax.jit(lambda: const), + lambda: const) + + with jtu.collect_lowered_jaxprs() as collection: + res = f(42.) + f(43.) + self.assertAllClose(const, res) + + for j, j_module in collection: + self.assertNotRegex(str(j_module), + f"stablehlo.constant dense.*tensor<{const_size}x") + + def test_concurrent_device_get_and_put(self): def f(x): for _ in range(100): @@ -3063,47 +3259,66 @@ def test_dtype_from_builtin_types(self): x = jnp.array(0, dtype=dtype) self.assertEqual(x.dtype, dtypes.canonicalize_dtype(dtype)) - def test_dtype_warning(self): + @jtu.sample_product( + explicit_x64_dtypes=[ + config.ExplicitX64Mode.WARN, + config.ExplicitX64Mode.ERROR, + config.ExplicitX64Mode.ALLOW, + ], + enable_x64=[True, False], + ) + def test_dtype_warning(self, explicit_x64_dtypes, enable_x64): # cf. issue #1230 - if config.enable_x64.value: - raise unittest.SkipTest("test only applies when x64 is disabled") + @config.explicit_x64_dtypes(explicit_x64_dtypes) + @config.enable_x64(enable_x64) + def check(warn, nowarn): + if ( + config.enable_x64.value + or config.explicit_x64_dtypes.value == config.ExplicitX64Mode.ALLOW + ): + if config.enable_x64.value: + with self.assertNoWarnings(): + warn() + elif config.explicit_x64_dtypes.value == config.ExplicitX64Mode.WARN: + with self.assertWarnsRegex(UserWarning, "Explicitly requested dtype"): + warn() + else: + with self.assertRaisesRegex(ValueError, "Explicitly requested dtype"): + warn() - def check_warning(warn, nowarn): - with self.assertWarnsRegex(UserWarning, "Explicitly requested dtype"): - warn() with self.assertNoWarnings(): nowarn() - check_warning(lambda: jnp.array([1, 2, 3], dtype="float64"), - lambda: jnp.array([1, 2, 3], dtype="float32")) - check_warning(lambda: jnp.array([1, 2, 3], dtype="float64"), - lambda: jnp.array([1, 2, 3], dtype=float)) - check_warning(lambda: jnp.ones(3, dtype=np.float64), - lambda: jnp.ones(3)) - check_warning(lambda: jnp.ones(3, dtype=np.float64), - lambda: jnp.ones(3, dtype=float)) - check_warning(lambda: jnp.ones_like(3, dtype=np.int64), - lambda: jnp.ones_like(3, dtype=np.int32)) - check_warning(lambda: jnp.zeros(3, dtype="int64"), - lambda: jnp.zeros(3, dtype="int32")) - check_warning(lambda: jnp.zeros_like(3, dtype="float64"), - lambda: jnp.zeros_like(3, dtype="float32")) - check_warning(lambda: jnp.full((2, 3), 1, dtype="int64"), - lambda: jnp.full((2, 3), 1)) - check_warning(lambda: jnp.ones(3).astype("float64"), - lambda: jnp.ones(3).astype("float32")) - check_warning(lambda: jnp.eye(3, dtype=np.float64), - lambda: jnp.eye(3)) - check_warning(lambda: jnp.arange(3, dtype=np.float64), - lambda: jnp.arange(3, dtype=np.float32)) - check_warning(lambda: jnp.linspace(0, 3, dtype=np.float64), - lambda: jnp.linspace(0, 3, dtype=np.float32)) - check_warning(lambda: jnp.tri(2, dtype="float64"), - lambda: jnp.tri(2, dtype="float32")) - check_warning(lambda: jnp.arange(1).astype("float64"), - lambda: jnp.arange(1).astype(float)) - check_warning(lambda: jnp.arange(1.0).astype("int64"), - lambda: jnp.arange(1.0).astype(int)) + check(lambda: jnp.array([1, 2, 3], dtype="float64"), + lambda: jnp.array([1, 2, 3], dtype="float32")) + check(lambda: jnp.array([1, 2, 3], dtype="float64"), + lambda: jnp.array([1, 2, 3], dtype=float)) + check(lambda: jnp.ones(3, dtype=np.float64), + lambda: jnp.ones(3)) + check(lambda: jnp.ones(3, dtype=np.float64), + lambda: jnp.ones(3, dtype=float)) + check(lambda: jnp.ones_like(3, dtype=np.int64), + lambda: jnp.ones_like(3, dtype=np.int32)) + check(lambda: jnp.zeros(3, dtype="int64"), + lambda: jnp.zeros(3, dtype="int32")) + check(lambda: jnp.zeros_like(3, dtype="float64"), + lambda: jnp.zeros_like(3, dtype="float32")) + check(lambda: jnp.full((2, 3), 1, dtype="int64"), + lambda: jnp.full((2, 3), 1)) + check(lambda: jnp.ones(3).astype("float64"), + lambda: jnp.ones(3).astype("float32")) + check(lambda: jnp.eye(3, dtype=np.float64), + lambda: jnp.eye(3)) + check(lambda: jnp.arange(3, dtype=np.float64), + lambda: jnp.arange(3, dtype=np.float32)) + check(lambda: jnp.linspace(0, 3, dtype=np.float64), + lambda: jnp.linspace(0, 3, dtype=np.float32)) + check(lambda: jnp.tri(2, dtype="float64"), + lambda: jnp.tri(2, dtype="float32")) + check(lambda: jnp.arange(1).astype("float64"), + lambda: jnp.arange(1).astype(float)) + check(lambda: jnp.arange(1.0).astype("int64"), + lambda: jnp.arange(1.0).astype(int)) def test_error_for_invalid_dtype(self): err_str = ("Error interpreting argument to .* as an abstract array. The problematic " @@ -3120,7 +3335,6 @@ def test_error_for_invalid_dtype(self): def test_vmap_preserves_docstr(self): def superfun(a): """Does things with stuff.""" - pass self.assertRegex(api.vmap(superfun).__doc__, "\n".join([ "Vectorized version of superfun.*", @@ -3330,11 +3544,12 @@ def test_pmap_empty_arguments(self): r"containing an array, got empty \*args=\(\{\},\) and \*\*kwargs=\{\}"): api.pmap(lambda x: x)({}) + @jtu.thread_unsafe_test() # counting compilations isn't thread-safe def test_pmap_global_cache(self): def f(x, y): return x, y - x = np.ones((1, 1, 1)) + x = np.ones((1, 1, 1), dtype=np.float32) # All defaults with jtu.assert_num_jit_and_pmap_compilations(1): @@ -3407,7 +3622,7 @@ def f(x): logging.set_verbosity(prev_level) self.assertGreaterEqual(len(l.output), 3) # 3 lines self.assertTrue(any('Finished tracing' in line for line in l.output)) - self.assertTrue(any('Compiling f' in line for line in l.output)) + self.assertTrue(any('Compiling jit(' in line for line in l.output)) self.assertTrue(any('Finished XLA compilation' in line for line in l.output)) def test_grad_of_jit_compilation_caching(self): @@ -3507,7 +3722,7 @@ def mlir_jaxpr_subcomp_and_collect(c, jaxpr, *args, **kwargs): outer_jaxpr, inner_jaxpr = jaxprs self.assertLen(outer_jaxpr.eqns, 1) - prim_name = 'pjit' + prim_name = 'jit' jaxpr_param = 'jaxpr' self.assertEqual(outer_jaxpr.eqns[0].primitive.name, f'{prim_name}') subjaxpr_1 = outer_jaxpr.eqns[0].params[f"{jaxpr_param}"] @@ -3702,7 +3917,7 @@ def test_linearize_aval_error(self): def test_grad_of_token_consuming_primitive(self): # https://github.com/jax-ml/jax/issues/5463 tokentest_p = core.Primitive("tokentest") - tokentest_p.def_impl(partial(xla.apply_primitive, tokentest_p)) + tokentest_p.def_impl(partial(dispatch.apply_primitive, tokentest_p)) tokentest_p.def_abstract_eval(lambda x, y: x) mlir.register_lowering(tokentest_p, lambda ctx, x, y: [x]) ad.defjvp(tokentest_p, (lambda g, x, token: x), None) @@ -3746,6 +3961,7 @@ def f(x): with self.assertRaisesRegex(Exception, r"Leaked"): f(np.ones(1)) + @unittest.skip('TODO(dougalm): re-enable once we fix tests that were showing tracer leaks') def test_leak_checker_catches_a_grad_leak(self): with jax.checking_leaks(): lst = [] @@ -3917,6 +4133,9 @@ def test_default_device(self): def test_dunder_jax_array(self): # https://github.com/jax-ml/jax/pull/4725 + @partial(jax.tree_util.register_dataclass, + data_fields=['jax_val'], + meta_fields=[]) class AlexArray: def __init__(self, jax_val): self.jax_val = jax_val @@ -3926,10 +4145,16 @@ def __jax_array__(self): shape = property(lambda self: self.jax_val.shape) x = AlexArray(jnp.array([1., 2., 3.])) + + y = jax.jit(lambda x: x)(x) + self.assertIsInstance(x, AlexArray) + self.assertArraysEqual(jnp.asarray(x), jnp.asarray(y)) + y = jnp.sin(x) self.assertAllClose(y, jnp.sin(jnp.array([1., 2., 3.]))) y = api.grad(api.jit(lambda x: jnp.sin(x).sum()))(x) - self.assertAllClose(y, jnp.cos(jnp.array([1., 2., 3.]))) + self.assertIsInstance(y, AlexArray) + self.assertAllClose(jnp.asarray(y), jnp.cos(jnp.array([1., 2., 3.]))) x = AlexArray(jnp.array([[1., 2., 3.]])) y = api.pmap(jnp.sin)(x) @@ -3937,7 +4162,7 @@ def __jax_array__(self): x = jnp.array(1) a = AlexArray(x) - for f in [jnp.isscalar, jnp.size, jnp.shape, jnp.dtype]: + for f in [jnp.isscalar, jnp.size, jnp.shape, jnp.result_type]: self.assertEqual(f(x), f(a)) x = AlexArray(jnp.array(1)) @@ -3947,6 +4172,22 @@ def __jax_array__(self): a2 = jnp.array(((x, x), [x, x])) self.assertAllClose(np.array(((1, 1), (1, 1))), a2) + def test_dunder_jax_array_warnings(self): + class AlexArray: + def __init__(self, jax_val): + self.jax_val = jax_val + def __jax_array__(self): + return self.jax_val + + f = jax.jit(lambda x: x) + a = AlexArray(jnp.arange(4)) + msg = ( + r"Triggering __jax_array__\(\) during abstractification is no longer" + r" supported." + ) + with self.assertRaisesRegex(ValueError, msg): + f(a) + @jtu.thread_unsafe_test() # count_jit_tracing_cache_miss() isn't thread-safe def test_eval_shape_weak_type(self): # https://github.com/jax-ml/jax/issues/23302 @@ -4135,19 +4376,19 @@ def f_rev(_, g): api.grad(lambda x: f(f(f(x))))(1.) def test_jit_inline(self): - @partial(api.jit, inline=False) + @api.jit(inline=False) def f(x): return x * 2 jaxpr = api.make_jaxpr(f)(3) - self.assertIn('pjit', str(jaxpr)) + self.assertIn('jit', str(jaxpr)) - @partial(api.jit, inline=True) + @api.jit(inline=True) def f(x): return x * 2 jaxpr = api.make_jaxpr(f)(3) - self.assertNotIn('pjit', str(jaxpr)) + self.assertNotIn('jit', str(jaxpr)) # Repro for https://github.com/jax-ml/jax/issues/7229. def test_compute_with_large_transfer(self): @@ -4167,7 +4408,7 @@ def test_vjp_fun_jit(self): # from and passed to jitted functions f = lambda x: 2. * x - @partial(jit, static_argnums=0) + @jit(static_argnums=0) def linearize_vjp(f, x): _, vjp_fun = api.vjp(f, x) return vjp_fun @@ -4182,7 +4423,7 @@ def test_linearize_fun_jit(self): # from and passed to jitted functions f = lambda x: 2. * x - @partial(jit, static_argnums=0) + @jit(static_argnums=0) def linearize(f, x): _, jvp_fun = api.linearize(f, x) return jvp_fun @@ -4197,7 +4438,7 @@ def test_linear_transpose_fun_jit(self): # from and passed to jitted functions f = lambda x: 2. * x - @partial(jit, static_argnums=0) + @jit(static_argnums=0) def transpose(f, x): return api.linear_transpose(f, x) @@ -4206,13 +4447,27 @@ def transpose(f, x): expected = (6.,) self.assertEqual(actual, expected) + def test_lax_real_empty(self): + out = jax.lax.empty((2, 2), dtype=jnp.float32) + self.assertEqual(out.shape, (2, 2)) + self.assertEqual(out.dtype, jnp.float32) + + @jtu.run_on_devices('gpu', 'tpu') + def test_lax_empty_vmap(self): + inp = np.arange(8, dtype=jnp.int32).reshape(4, 2) + + def f(x): + return jax.lax.empty(x.shape, x.dtype) + + f = jax.jit(jax.vmap(f)) + f(inp) # doesn't crash + lowered_text = f.lower(inp).as_text() + self.assertIn('@AllocateBuffer() : () -> tensor<4x2xi32>', lowered_text) + def test_leaked_tracer_issue_7613(self): # from https://github.com/jax-ml/jax/issues/7613 import numpy.random as npr - def sigmoid(x): - return 1. / (1. + jnp.exp(-x)) - x = jnp.ones((1, 50)) A = jnp.array(npr.randn(50, 50), dtype=x.dtype) @@ -4296,6 +4551,11 @@ def f(x): finally: config.update("jax_numpy_rank_promotion", allow_promotion) + def test_frexp_sharded(self): + mesh = jtu.create_mesh((1,), 'x') + x = jax.device_put(np.ones(8), jax.NamedSharding(mesh, jax.P('x'))) + jax.jacrev(lambda x: jnp.frexp(x)[0])(x) # doesn't crash + def test_grad_negative_argnums(self): def f(x, y): return x.sum() * y.sum() @@ -4306,13 +4566,28 @@ def f(x, y): g(x, y) # doesn't crash def test_jit_negative_static_argnums(self): - @partial(jax.jit, static_argnums=-1) + @jax.jit(static_argnums=-1) def g(x, y): assert isinstance(y, int) return x * y for i in range(3): # Loop verifies we exercise both Python and C++ dispatch self.assertEqual(2 * i, g(2, i), msg=i) + def test_make_jaxpr_static_argnums_order(self): + # https://github.com/jax-ml/jax/issues/28065 + def f(a, b, c): + x = a + c + y = b * c + z = x - y + return z + + for static_argnums in [(1, 0), (0, 1)]: + val = jax.jit(f, static_argnums=static_argnums)(1, 2, 3) + self.assertEqual(val, -2) + jaxpr = jax.make_jaxpr(f, static_argnums=static_argnums)(1, 2, 3) + self.assertEqual(jaxpr.eqns[0].invars[0].val, 1) + self.assertEqual(jaxpr.eqns[1].invars[0].val, 2) + def test_fastpath_cache_confusion(self): # https://github.com/jax-ml/jax/issues/12542 @jax.jit @@ -4366,13 +4641,6 @@ def foo(x): with self.assertRaisesRegex(TypeError, "applied to foo"): f_vjp(1.0, 1.0) - def test_shapedtypestruct_sharding_error(self): - with self.assertRaisesRegex( - ValueError, - "sharding should be an instance of `jax.sharding.Sharding`."): - jax.ShapeDtypeStruct((8, 2), np.float32, - sharding=jax.sharding.PartitionSpec('x')) - def test_make_jaxpr_weakref(self): class Foo(NamedTuple): x: int @@ -4412,19 +4680,26 @@ def outer_fn(x): return x state = jnp.arange(5, dtype=jnp.uint32) - inner_fn(state) - outer_fn(state) + outer_fn(state) + outer_fn(state) self.assertEqual(inner_count, 1) self.assertEqual(outer_count, 1) + inner_fn(state) + self.assertEqual(inner_count, 1) # not retraced when top-level + def test_grad_conj_symbolic_zeros(self): # https://github.com/jax-ml/jax/issues/15400 f = lambda x: jax.jit(lambda x, y: (x, y))(x, jax.lax.conj(x))[0] out = jax.grad(f)(3.0) # doesn't crash self.assertAllClose(out, 1., check_dtypes=False) + @jtu.thread_unsafe_test() def test_cache_clear_pmap(self): + if config.pmap_shmap_merge.value: + self.skipTest("Already tested by pjit tests under pmap_shmap_merge=True.") + @jax.pmap def f(i): return i * 2 @@ -4465,66 +4740,229 @@ def add(x): tracing_add_count += 1 self.assertEqual(tracing_add_count, 2) + @unittest.skipIf(lib.jaxlib_extension_version < 396, "jaxlib version") @jtu.thread_unsafe_test() # logging is not thread-safe - def test_cache_miss_explanations(self): - @jax.jit - def f(x, y): - return jnp.sin(x) * y['hi'] + def test_cache_miss_explanations_skip_internals(self): + if is_persistent_cache_enabled(): + self.skipTest('With persistent cache, we see the cache misses') + + with config.explain_cache_misses(True): + with self.assertNoLogs(level='WARNING'): + for i in range(2): + jnp.sin(jnp.arange(i + 1, dtype=np.float32)) + @unittest.skipIf(lib.jaxlib_extension_version < 396, "jaxlib version") + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_first_miss(self): + @jax.jit + def f(x): return x x = jnp.float32(1.) - y = {'hi': jnp.arange(3., dtype='float32')} expected_log_len = 1 if not is_persistent_cache_enabled() else 3 - # print on first miss, not on hit + with config.explain_cache_misses(True): + with self.assertLogs(level="WARNING") as cm: + f(x) + f(x) + self.assertLen(cm.output, expected_log_len) + msg = cm.output[0] + self.assertIn("TRACING CACHE MISS", msg) + self.assertIn("never seen function", msg) + self.assertNotIn("explanation unavailable!", msg) + + @unittest.skipIf(lib.jaxlib_extension_version < 396, "jaxlib version") + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_other_in_tree(self): + @jax.jit + def f(*args, **kwargs): return args[0] + + f(0., 1., y=(2., 2.1)) + + with config.explain_cache_misses(True): + with self.assertLogs(level="WARNING") as cm: + # Same number of leaves but different trees + f(0., (1., 1.1), y=2.) + self.assertLen(cm.output, 1) + msg = cm.output[0] + self.assertIn("different input pytree", msg) + self.assertNotIn("explanation unavailable!", msg) + + @unittest.skip('TODO(mattjj): re-enable after updating cache miss explainer') + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_other_arg_passed_as_kwarg(self): + @jax.jit + def f(x, y): return jnp.sin(x) + y + + f(0., 1.) + + # kwarg change + with config.explain_cache_misses(True): + with self.assertLogs(level="WARNING") as cm: + f(0., y=1.) + + self.assertLen(cm.output, 1) + msg = cm.output[0] + self.assertIn("different number of args and kwargs, but same total number", msg) + self.assertIn("now 1 args and kwargs with keys ['y']", msg) + self.assertIn("before 1 args and kwargs with keys []", msg) + self.assertNotIn("explanation unavailable!", msg) + + @unittest.skip('TODO(mattjj): re-enable after updating cache miss explainer') + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_other_static_argnums(self): + @jax.jit(static_argnums=(0, 2)) + def f(x, y, z): + return y + + f(1., 2., "foo") + + with config.explain_cache_misses(True): + with self.assertLogs(level="WARNING") as cm: + f(1., 2., "bar") + self.assertLen(cm.output, 1) + msg = cm.output[0] + self.assertIn("different value of static args", msg) + self.assertIn("now 1.0, 'bar' and before 1.0, 'foo'", msg) + self.assertNotIn("explanation unavailable!", msg) + + @unittest.skip('TODO(mattjj): re-enable after updating cache miss explainer') + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_other_static_argnames(self): + @jax.jit(static_argnames="foo") + def f(*, foo): + return 1 + + f(foo="foo") + + with config.explain_cache_misses(True): + with self.assertLogs(level="WARNING") as cm: + f(foo="bar") + self.assertLen(cm.output, 1) + msg = cm.output[0] + self.assertIn("different value of static kwargs", msg) + self.assertIn("now {foo: 'bar'} and before {foo: 'foo'}", msg) + self.assertNotIn('explanation unavailable!', msg) + + @unittest.skipIf(lib.jaxlib_extension_version < 396, "jaxlib version") + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_other_dtype(self): + @jax.jit + def f(x, y): return x + f(np.float32(0), np.float32(1)) + with config.explain_cache_misses(True): with self.assertLogs(level='WARNING') as cm: - f(x, y) - f(x, y) + f(np.float32(0), np.int32(1)) + self.assertLen(cm.output, 1) + msg = cm.output[0] + self.assertIn("different input types", msg) + self.assertIn("at y, now i32[] and before f32[]", msg) + self.assertNotIn("explanation unavailable!", msg) + + @unittest.skipIf(lib.jaxlib_extension_version < 396, "jaxlib version") + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_other_weak_type(self): + @jax.jit + def f(x, y): return jnp.sin(x) + y + + y = jnp.arange(4, dtype="float32") + f(jnp.float32(0.), y) + # weak type change (assuming no x64) + if config.enable_x64.value: + self.skipTest("Work only for 32 bit mode") + with config.explain_cache_misses(True): + with self.assertLogs(level="WARNING") as cm: + f(0., y) + expected_log_len = 1 if not is_persistent_cache_enabled() else 3 self.assertLen(cm.output, expected_log_len) msg = cm.output[0] - self.assertIn('TRACING CACHE MISS', msg) - self.assertIn('never seen function', msg) + self.assertIn("different input types", msg) + self.assertIn("at x, now f32[]{weak_type=True} and before f32[]{weak_type=False}", msg) + self.assertIn("https://docs.jax.dev/en/latest/type_promotion.html#weak-types", msg) + self.assertNotIn("explanation unavailable!", msg) + + @unittest.skipIf(lib.jaxlib_extension_version < 396, "jaxlib version") + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_other_shape(self): + @jax.jit + def f(x, y): return jnp.sin(x) + y + f(np.float32(0), np.arange(1, dtype=np.float32)) - # shape change - y_ = {'hi': jnp.arange(4, dtype='float32')} with config.explain_cache_misses(True): with self.assertLogs(level='WARNING') as cm: - f(x, y_) + f(np.float32(0), np.arange(2, dtype=np.float32)) + expected_log_len = 1 if not is_persistent_cache_enabled() else 3 self.assertLen(cm.output, expected_log_len) msg = cm.output[0] - self.assertIn('never seen input type signature', msg) - self.assertIn('closest seen input type signature has 1 mismatches', msg) - self.assertIn('seen f32[3], but now given f32[4]', msg) + self.assertIn("different input types", msg) + self.assertIn("at y, now f32[2] and before f32[1]", msg) + self.assertNotIn("explanation unavailable!", msg) - # weak type change (assuming no x64) - if not config.enable_x64.value: - with config.explain_cache_misses(True): - with self.assertLogs(level='WARNING') as cm: - f(1., y) - self.assertLen(cm.output, expected_log_len) - msg = cm.output[0] - self.assertIn('weak_type=True', msg) - self.assertIn('https://jax.readthedocs.io/en/latest/type_promotion.html#weak-types', msg) + @unittest.skip('TODO(mattjj): re-enable after updating cache miss explainer') + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_other_shape_explain_closest(self): + @jax.jit + def f(x): return x + f(np.ones((1, 2), dtype=np.float32)) + f(np.ones((10, 20, 30), dtype=np.float32)) + f(np.ones((1, 2, 3), dtype=np.float32)) - # kwarg change with config.explain_cache_misses(True): with self.assertLogs(level='WARNING') as cm: - f(1, y=y) + f(np.ones((10, 2, 30), dtype=np.float32)) + expected_log_len = 1 if not is_persistent_cache_enabled() else 3 self.assertLen(cm.output, expected_log_len) msg = cm.output[0] - self.assertIn('never seen passing 1 positional args and 1 keyword args', msg) + self.assertIn("key with different input types", msg) + self.assertIn("at x, now f32[10,2,30] and before f32[10,20,30]", msg) + self.assertNotIn("explanation unavailable!", msg) + + @unittest.skipIf(lib.jaxlib_extension_version < 396, "jaxlib version") + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_other_tracing_config(self): + @jax.jit + def f(x, y): return jnp.sin(x) + y + f(0., 1.) # tracing config change with config.explain_cache_misses(True): - with self.assertLogs(level='WARNING') as cm: - with jax.numpy_rank_promotion('warn'): - f(x, y) - # depending on the backend, we may or may not get persistent cache warnings + with self.assertLogs(level="WARNING") as cm: + with jax.numpy_rank_promotion("warn"): + with jax.default_matmul_precision("high"): + f(0., 1.) + + expected_log_len = 1 if not is_persistent_cache_enabled() else 3 self.assertTrue(1 <= len(cm.output) <= expected_log_len) msg = cm.output[0] - self.assertIn("tracing context doesn't match", msg) + self.assertIn("racing context", msg) + # self.assertIn("now warn and before", msg) + # self.assertIn("now high and before", msg) + self.assertNotIn("explanation unavailable!", msg) + + @unittest.skip('TODO(mattjj): re-enable after updating cache miss explainer') + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_multiple_changes(self): + @jax.jit + def f(x): return jnp.sin(x) + + call_1 = f(np.arange(4, dtype=np.float32)) + with jax.numpy_rank_promotion("warn"): + call_2 = f(np.arange(8, dtype=np.float32)) + + with config.explain_cache_misses(True): + with self.assertLogs(level='WARNING') as cm: + # Matches call_2 in shape but not context, and call_1 in context but + # not in shape. + f(np.arange(8, dtype=np.float32)) + + self.assertLen(cm.output, 1) + msg = cm.output[0] + self.assertIn("key with different input types", msg) + self.assertIn("at x, now f32[8] and before f32[4]", msg) + self.assertIn("key with different tracing context", msg) + self.assertNotIn("explanation unavailable!", msg) + @unittest.skipIf(lib.jaxlib_extension_version < 396, "jaxlib version") @jtu.thread_unsafe_test() # logging is not thread-safe def test_cache_miss_explanations_new_function_in_loop(self): @jax.jit @@ -4547,32 +4985,38 @@ def f(x, y): _, msg = cm.output self.assertIn('another function defined on the same line', msg) - @jtu.thread_unsafe_test() # logging is not thread-safe - def test_cache_miss_explanations_unpacks_transforms(self): - # Tests that the explain_tracing_cache_miss() function does not throw an - # error when unpacking `transforms` with a length greater than 3. - @jax.jit - def f(key): - return jax.random.truncated_normal(key, 1, 1, dtype=jax.numpy.float32) - - with config.explain_cache_misses(True): - with self.assertLogs(level="WARNING") as cm: - f(jax.random.key(seed=123)) - - if is_persistent_cache_enabled(): - # 5 warnings from tracing cache, 5-10 from persistent cache depending on - # the backend - self.assertTrue(10 <= len(cm.output) <= 15) - self.assertTrue(any("TRACING CACHE MISS" in msg for msg in cm.output)) - else: - self.assertLen(cm.output, 5) - for msg in cm.output: - self.assertIn("TRACING CACHE MISS", msg) - + @unittest.skipIf(lib.jaxlib_extension_version < 396, "jaxlib version") def test_cache_miss_explanations_no_source_info(self): # ``operator.add`` is a built-in function and does not have source info. with config.explain_cache_misses(True): - jax.jit(operator.add)(42, 24) + jax.jit(operator.add)(42, 24) # doesn't crash + + @unittest.skipIf(lib.jaxlib_extension_version < 396, "jaxlib version") + def test_cache_miss_explanations_are_thread_safe(self): + @jax.jit + def f(i): + return jnp.sum(i) + + saw_exception = False + + def thread(i0): + nonlocal saw_exception + try: + for i in range(i0, 100, 10): + if saw_exception: + break + with config.explain_cache_misses(True): + f(jnp.zeros(i)) + except Exception: + saw_exception = True + raise + + t = [threading.Thread(target=thread, args=(i,)) for i in range(10)] + for i in t: + i.start() + for i in t: + i.join() + self.assertFalse(saw_exception) @parameterized.named_parameters([ {"testcase_name": f"{np.dtype(dtype)}", "dtype": dtype} @@ -4685,33 +5129,6 @@ def f(inputs): jtu.check_grads(f, (list(jnp.arange(float(num_args))),), order=1, modes=['rev'], atol=1e-3, rtol=1e-3) - @jtu.run_on_devices("cpu") - def test_inner_jit_forwarding_happens(self): - jaxpr = jax.make_jaxpr(lambda: jax.jit(lambda x: x)(3))() - self.assertLen(jaxpr.jaxpr.outvars, 1) - self.assertIsInstance(jaxpr.jaxpr.outvars[0], core.Literal) - self.assertEqual(jaxpr.jaxpr.outvars[0].val, 3) - - @parameterized.parameters(range(8)) - @jtu.run_on_devices("cpu") - def test_inner_jit_forwarding_correctness(self, num_input_fwd): - num_args = 8 - rng = np.random.RandomState(0) - - @jax.jit - def f(inputs): - inputs = [inputs[i] for i in rng.permutation(num_args)] - outputs = (inputs[:num_input_fwd] + - [jnp.sin(inputs[i]) for i in range(num_args - num_input_fwd)]) - return [outputs[i] for i in rng.permutation(num_args)] - - f2 = jax.jit(f) - inputs = list(jnp.arange(float(num_args))) - expected = f(inputs) - ans = f2(inputs) - for a, b in zip(ans, expected): - self.assertAllClose(a, b) - @unittest.skip # TODO(dougalm): figure out with Matt what to do with this feature def test_inner_jit_forwarded_consts_stay_const(self): out = jax.jit(lambda: int(jax.jit(lambda x: x)(3)))() # don't crash @@ -4735,7 +5152,7 @@ def f(x): def test_inlined_literals_with_error(self): @jax.jit def f(): - @partial(jax.jit, inline=True) + @jax.jit(inline=True) def g(): return jnp.sin(1.) if g() > 0: @@ -4776,7 +5193,7 @@ def sin_of_sin(x): def test_deferred_primal_with_direct_linearize(self): def my_sin_lin(nzs, x): nz, = nzs - return (my_sin_p.bind(x), nz, x, lambda x, t: lax.mul(t, lax.cos(x))) + return (my_sin_p.bind(x, accuracy=None), nz, x, lambda x, t: lax.mul(t, lax.cos(x))) my_sin_p = core.Primitive("my_sin_p") my_sin_p.def_impl(lax.sin) @@ -4786,15 +5203,166 @@ def my_sin_lin(nzs, x): with config.use_direct_linearize(True): jax.grad(my_sin_p.bind)(1.0) # doesn't crash + def test_ensure_compile_time_eval_no_leaks(self): + # https://github.com/jax-ml/jax/issues/25847 + with jax.ensure_compile_time_eval(): + jnp.linalg.solve(jnp.eye(3), jnp.ones(3)) # doesn't crash -class RematTest(jtu.JaxTestCase): + def test_returned_non_jaxtype(self): - @parameterized.named_parameters( + class TestEnum(enum.Enum): + A = enum.auto() + + @jax.tree_util.register_dataclass + @dataclasses.dataclass + class TestClass3: + test_enum_field: TestEnum = dataclasses.field(metadata=dict(static=True)) + test_data_field: int + + def test_jax_function(test_class: TestClass3) -> TestEnum: + return test_class.test_enum_field + + jitted_test_function = jax.jit(test_jax_function) + with self.assertRaisesRegex(TypeError, "returned a value of type"): + jitted_test_function( + TestClass3( + test_data_field=1, + test_enum_field=TestEnum.A, + ) + ) + + def test_make_jaxpr_deduplicates_consts(self): + # We don't promise this behavior in the public API, but we've had it for a + # long time. This test checks we don't *unintentionally* break it. + + # We are careful to choose a type that would not be canonicalized here, + # otherwise the jnp.array(...) calls will induce constant duplication. + c = np.ones(3).astype(np.float32) + + def find_constants(jaxpr: core.ClosedJaxpr): + for j in it.chain([jaxpr], core.subjaxprs(jaxpr)): + for eq in j.eqns: + for inv in eq.invars: + if isinstance(inv, core.Literal) and np.shape(inv.val): + yield inv.val + + def uniq(lst): + def key(a): + if isinstance(a, literals.TypedNdArray): + return np.asarray(a) + else: + return a + return {id(key(v)): v for v in lst}.values() + + @jax.make_jaxpr + def f(): + return jnp.array(c), jnp.sum(c), c, jnp.array(c), jnp.sum(c), c + + if config.use_simplified_jaxpr_constants.value: + consts = uniq(find_constants(f())) + else: + consts = f().consts + + self.assertLen(consts, 1) + + d = np.zeros(3) + + # TODO(mattjj,phawkins): we broke this on purpose, as it probably isn't + # load-bearing (see above comment). If we wanted to fix it, we might share + # the constid cache across jaxpr traces, or we might hash on const value. + # @jax.make_jaxpr + # def g(): + # return jax.lax.cond(True, + # lambda: (c, jnp.sum(c), c), + # lambda: (c, jnp.sum(d), d)) + # if config.use_simplified_jaxpr_constants.value: + # consts = uniq(find_constants(g())) + # else: + # consts = g().consts + # self.assertLen(consts, 2) + + # TODO(mattjj,dougalm): this test was flakey on CI; figure out how to enable? + # @jtu.run_on_devices('cpu') + # def test_implicit_dce_linearize(self): + # def foo(x): + # const = np.zeros((300,)) + # x * const + # r = weakref.ref(const) + # del const + # assert r() is None, "oops, the constant wasn't DCE'd" + # return x + # with config.use_direct_linearize(True): + # _ = jax.grad(foo)(3.) + + @jtu.run_on_devices('cpu') + def test_implicit_dce_linearize_jaxpr(self): + def foo(x): + const = np.zeros((300,)) + x * const + r = weakref.ref(const) + del const + return x + + with config.use_direct_linearize(True): + _, f_vjp = jax.vjp(foo, 3.) + + self.assertNotIn('mul', str(f_vjp)) + + @jtu.thread_unsafe_test() # make_user_context() is not thread-safe at the moment + def test_user_trace_context_hooks(self): + my_config = jax.make_user_context() + + @jax.jit + def f(x): + return x + + with jtu.count_jit_tracing_cache_miss() as tracing_count: + f(1.) + with my_config(2): + f(1.) + with my_config(3): + f(1.) + with my_config(4): + f(1.) + self.assertEqual(tracing_count(), 4) + + # TODO(mattjj,dougalm): re-enable if we set auto_dce=True by default + # @jtu.run_on_devices('cpu') + # def test_implicit_dce(self): + # @api.jit + # def foo(x): + # const = np.zeros((300,)) + # r = weakref.ref(const) + # jnp.sin(const) + const + # del const + # assert r() is None, "oops, the constant wasn't DCE'd" + # return x + x + # foo(1.0) + + def test_dce_sink_vmap(self): + def f(x): + jax.lax.dce_sink(x) + return x + + jax.vmap(f)(jnp.arange(3.)) # don't crash + + def test_sharding_attr_on_tracer_error(self): + @jax.jit + def f(x): + with self.assertRaisesRegex(AttributeError, 'typeof'): + x.sharding + + f(jnp.arange(2.)) + + +class RematTest(jtu.JaxTestCase): + + @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} for suffix, remat in [ ('', jax.remat), ('_policy', partial(jax.remat, policy=lambda *_, **__: False)), - ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)), + ('_new', partial(jax.checkpoint, policy=lambda *_, **__: False)), ]) @jtu.thread_unsafe_test() # monkey patches sin_p and cos_p def test_remat_basic(self, remat): @@ -4823,8 +5391,8 @@ def f(x): sin_impl = lax.sin_p.impl cos_impl = lax.cos_p.impl try: - lax.sin_p.def_impl(lambda x: sin_calls.append(1) or sin_impl(x)) - lax.cos_p.def_impl(lambda x: cos_calls.append(1) or cos_impl(x)) + lax.sin_p.def_impl(lambda x, **kwargs: sin_calls.append(1) or sin_impl(x, **kwargs)) + lax.cos_p.def_impl(lambda x, **kwargs: cos_calls.append(1) or cos_impl(x, **kwargs)) f_lin(3.) finally: lax.sin_p.def_impl(sin_impl) @@ -4837,7 +5405,7 @@ def f(x): for suffix, remat in [ ('', jax.remat), ('_policy', partial(jax.remat, policy=lambda *_, **__: False)), - ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)), + ('_new', partial(jax.checkpoint, policy=lambda *_, **__: False)), ]) def test_remat_freevars(self, remat): def f1(x): @@ -4968,7 +5536,7 @@ def f(x): for suffix, remat in [ ('', jax.remat), ('_policy', partial(jax.remat, policy=lambda *_, **__: False)), - ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)), + ('_new', partial(jax.checkpoint, policy=lambda *_, **__: False)), ]) def test_remat_jit(self, remat): @remat @@ -4996,7 +5564,7 @@ def f_(x): for suffix, remat in [ ('', jax.remat), ('_policy', partial(jax.remat, policy=lambda *_, **__: False)), - ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)), + ('_new', partial(jax.checkpoint, policy=lambda *_, **__: False)), ]) def test_remat_vmap(self, remat): @remat @@ -5019,7 +5587,7 @@ def g(x): # Make sure that introducing constants in vmap works. constant_introducing_p = core.Primitive('introduce_constant') - constant_introducing_p.def_abstract_eval(core.raise_to_shaped) + constant_introducing_p.def_abstract_eval(lambda x: x) def _constant_introducing_batcher(xs, ds): (x,), (d,) = xs, ds return (x + np.arange(x.size, dtype=x.dtype).reshape(x.shape)), d @@ -5032,7 +5600,7 @@ def _constant_introducing_batcher(xs, ds): for suffix, remat in [ ('', jax.remat), ('_policy', partial(jax.remat, policy=lambda *_, **__: False)), - ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)), + ('_new', partial(jax.checkpoint, policy=lambda *_, **__: False)), ]) def test_remat_vmap_not_leading_dim(self, remat): @remat @@ -5050,7 +5618,7 @@ def g(x): for suffix, remat in [ ('', jax.remat), ('_policy', partial(jax.remat, policy=lambda *_, **__: False)), - ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)), + ('_new', partial(jax.checkpoint, policy=lambda *_, **__: False)), ]) def test_remat_higher_order_autodiff(self, remat): def f(x): @@ -5065,7 +5633,7 @@ def f(x): {"testcase_name": f"{suffix}", "remat": remat} for suffix, remat in [ ('', jax.remat), - ('_new', new_checkpoint), + ('_new', jax.checkpoint), ]) def test_remat_scan(self, remat): to_scan = lambda c, x: (jnp.sin(c), None) @@ -5099,7 +5667,7 @@ def f_yesremat(x): for suffix, remat in [ ('', jax.remat), ('_policy', partial(jax.remat, policy=lambda *_, **__: False)), - ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)), + ('_new', partial(jax.checkpoint, policy=lambda *_, **__: False)), ]) @jtu.thread_unsafe_test() # monkey patches sin_p def test_remat_no_redundant_flops(self, remat): @@ -5117,7 +5685,7 @@ def f(x, y): called = [] sin_impl = lax.sin_p.impl try: - lax.sin_p.def_impl(lambda x: called.append(1) or sin_impl(x)) + lax.sin_p.def_impl(lambda x, **kwargs: called.append(1) or sin_impl(x, **kwargs)) api.grad(g)(3.) finally: lax.sin_p.def_impl(sin_impl) @@ -5129,7 +5697,7 @@ def f(x, y): for suffix, remat in [ ('', jax.remat), ('_policy', partial(jax.remat, policy=lambda *_, **__: False)), - ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)), + ('_new', partial(jax.checkpoint, policy=lambda *_, **__: False)), ]) def test_remat_binomial_checkpointing(self, remat): def binom_checkpoint(funs): @@ -5150,7 +5718,7 @@ def binom_checkpoint(funs): {"testcase_name": f"{suffix}", "remat": remat} for suffix, remat in [ ('', jax.remat), - ('_new', new_checkpoint), + ('_new', jax.checkpoint), ]) def test_remat_symbolic_zeros(self, remat): # code from https://github.com/jax-ml/jax/issues/1907 @@ -5184,7 +5752,7 @@ def move(R,i): for suffix, remat in [ ('', jax.remat), ('_policy', partial(jax.remat, policy=lambda *_, **__: False)), - ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)), + ('_new', partial(jax.checkpoint, policy=lambda *_, **__: False)), ]) def test_remat_jit2(self, remat): @api.jit @@ -5203,7 +5771,7 @@ def g(): {"testcase_name": f"{suffix}", "remat": remat} for suffix, remat in [ ('', jax.remat), - ('_new', new_checkpoint), + ('_new', jax.checkpoint), ]) def test_remat_nontrivial_env(self, remat): # simplified from https://github.com/jax-ml/jax/issues/2030 @@ -5215,7 +5783,7 @@ def foo(state, dt=0.5, c=1): u_t = u_t + u_tt * dt return (u, u_t) - @partial(api.jit, static_argnums=(1,)) + @api.jit(static_argnums=(1,)) def _multi_step(state, count, dt, c): f = lambda s, _: (foo(s, dt, c), _) return lax.scan(f, state, None, count) @@ -5237,7 +5805,7 @@ def loss(u0, target, steps, dt=1/jnp.sqrt(2), c=1): for suffix, remat in [ ('', jax.remat), ('_policy', partial(jax.remat, policy=lambda *_, **__: False)), - ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)), + ('_new', partial(jax.checkpoint, policy=lambda *_, **__: False)), ]) def test_remat_jit3(self, remat): # https://github.com/jax-ml/jax/issues/2180 @@ -5270,7 +5838,7 @@ def f(w, x): {"testcase_name": f"{suffix}", "remat": remat} for suffix, remat in [ ('', jax.remat), - ('_new', new_checkpoint), + ('_new', jax.checkpoint), ]) def test_remat_scan2(self, remat): # https://github.com/jax-ml/jax/issues/1963 @@ -5291,7 +5859,7 @@ def named_call(f): def named_f(*args): my_f = lambda: (f(*args),) f_ = lu.wrap_init( - my_f, debug_info=api_util.debug_info("test_remat", my_f, args, {})) + my_f, debug_info=api_util.debug_info("test_remat", my_f, (), {})) out, = core.call_p.bind(f_) return out return named_f @@ -5309,7 +5877,7 @@ def f(a_bool, y): for suffix, remat in [ ('', jax.remat), ('_policy', partial(jax.remat, policy=lambda *_, **__: False)), - ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)), + ('_new', partial(jax.checkpoint, policy=lambda *_, **__: False)), ]) def test_remat_eval_counter(self, remat): # https://github.com/jax-ml/jax/issues/2737 @@ -5373,7 +5941,7 @@ def call(f, *args): for suffix, remat in [ ('', jax.remat), ('_policy', partial(jax.remat, policy=lambda *_, **__: False)), - ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)), + ('_new', partial(jax.checkpoint, policy=lambda *_, **__: False)), ]) def test_escaped_tracer_remat(self, remat): # b/169779185 @@ -5394,7 +5962,7 @@ def g(): for suffix, remat in [ ('', jax.remat), ('_policy', partial(jax.remat, policy=lambda *_, **__: False)), - ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)), + ('_new', partial(jax.checkpoint, policy=lambda *_, **__: False)), ]) def test_no_cse_widget_on_primals(self, remat): @remat @@ -5418,7 +5986,7 @@ def f(x): {"testcase_name": f"{suffix}", "remat": remat} for suffix, remat in [ ('', jax.remat), - ('_new', new_checkpoint), + ('_new', jax.checkpoint), ]) def test_no_cse_widget_with_prevent_cse_false(self, remat): @partial(remat, prevent_cse=False) @@ -5442,7 +6010,7 @@ def f(x): "policy": policy, "in_jaxpr2": in_jaxpr2, "not_in_jaxpr2": not_in_jaxpr2} for remat_name, remat in [ ('old_remat', jax.remat), - ('new_remat', new_checkpoint), + ('new_remat', jax.checkpoint), ] for policy_name, policy, in_jaxpr2, not_in_jaxpr2 in [ ('save_anything', lambda *_, **__: True, [], [' sin ', ' cos ']), @@ -5469,7 +6037,7 @@ def test_remat_custom_policy(self, remat, policy, in_jaxpr2, not_in_jaxpr2): {"testcase_name": f"_{remat_name}", "remat": remat} for remat_name, remat in [ ('old_remat', jax.remat), - ('new_remat', new_checkpoint), + ('new_remat', jax.checkpoint), ]) def test_remat_custom_policy_save_cos(self, remat): save_cos = lambda prim, *_, **__: str(prim) == 'cos' @@ -5485,7 +6053,7 @@ def test_remat_custom_policy_save_cos(self, remat): {"testcase_name": f"_{remat_name}", "remat": remat} for remat_name, remat in [ ('old_remat', jax.remat), - ('new_remat', new_checkpoint), + ('new_remat', jax.checkpoint), ]) def test_remat_checkpoint_dots(self, remat): @partial(remat, policy=jax.checkpoint_policies.checkpoint_dots) @@ -5508,7 +6076,7 @@ def f(x): {"testcase_name": f"_{remat_name}", "remat": remat} for remat_name, remat in [ ('old_remat', jax.remat), - ('new_remat', new_checkpoint), + ('new_remat', jax.checkpoint), ]) def test_remat_checkpoint_dots_with_no_batch_dims(self, remat): @partial(remat, policy=jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims) @@ -5531,7 +6099,7 @@ def f(x): {"testcase_name": f"_{remat_name}", "remat": remat} for remat_name, remat in [ ('old_remat', jax.remat), - ('new_remat', new_checkpoint), + ('new_remat', jax.checkpoint), ]) def test_remat_checkpoint_dots_with_no_batch_dims2(self, remat): @partial(remat, policy=jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims) @@ -5554,7 +6122,7 @@ def f(x): {"testcase_name": f"_{remat_name}", "remat": remat} for remat_name, remat in [ ('old_remat', jax.remat), - ('new_remat', new_checkpoint), + ('new_remat', jax.checkpoint), ]) def test_remat_checkpoint_dots_jit(self, remat): @api.jit @@ -5589,7 +6157,7 @@ def body(x, _): return f(x), None return lax.scan(body, x, None, length=2)[0] _, f_vjp = api.vjp(f, jnp.ones((5, 5))) - jaxpr_text = str(f_vjp.args[0].func.args[1]) + jaxpr_text = str(f_vjp.jaxpr) # Two sine calls in the backward pass because while we don't save sines # within the (rematted) body function, we can save the scan carry, which @@ -5664,7 +6232,7 @@ def g(x): {"testcase_name": f"_{remat_name}", "remat": remat} for remat_name, remat in [ ('old_remat', jax.remat), - ('new_remat', new_checkpoint), + ('new_remat', jax.checkpoint), ]) def test_remat_dropvar_policy(self, remat): def f(x): @@ -5702,21 +6270,21 @@ def test_constants_not_hoisted(self): # implementation avoids that. See https://github.com/jax-ml/jax/pull/8191. # no residuals from constants created inside jnp.einsum - @partial(new_checkpoint, policy=lambda *_, **__: False) + @partial(jax.checkpoint, policy=lambda *_, **__: False) def f(x): return jnp.einsum('ii->i', x) res_avals = saved_residuals(f, jnp.ones((2, 2))) self.assertLen(res_avals, 0) # no residuals from jnp.zeros - @partial(new_checkpoint, policy=lambda *_, **__: False) + @partial(jax.checkpoint, policy=lambda *_, **__: False) def f(x): return jnp.zeros_like(x) * x res_avals = saved_residuals(f, jnp.ones((2, 2))) self.assertLen(res_avals, 0) # no residuals from jnp.zeros, but input must be saved - @partial(new_checkpoint, policy=lambda *_, **__: False) + @partial(jax.checkpoint, policy=lambda *_, **__: False) def f(x): return jnp.zeros_like(x) * jnp.sin(x) res_avals = saved_residuals(f, jnp.ones((2, 2))) @@ -5732,6 +6300,16 @@ def f(x): res = saved_residuals(f, 3.) self.assertStartsWith(res[1][1], "named 'foo'") + def test_name_pytree(self): + @partial(jax.remat, policy=lambda p, *_, **__: 'mul' in str(p)) + def f(x): + x = checkpoint_name({'a': x * x}, 'foo')['a'] + x = x * x + return x + + res = saved_residuals(f, 3.) + self.assertStartsWith(res[1][1], "named 'foo'") + def test_name_denylist(self): def f(x): y = checkpoint_name(jnp.multiply(2., 2.), 'y') @@ -5741,19 +6319,19 @@ def f(x): return (((x * y) * z) * w) * u policy = jax.checkpoint_policies.save_any_names_but_these('y', 'z', 'w') - res = saved_residuals(new_checkpoint(f, policy=policy), 1.) + res = saved_residuals(jax.checkpoint(f, policy=policy), 1.) self.assertLen(res, 0) # can't save anything policy = jax.checkpoint_policies.save_any_names_but_these('z', 'w') - res = saved_residuals(new_checkpoint(f, policy=policy), 1.) + res = saved_residuals(jax.checkpoint(f, policy=policy), 1.) self.assertLen(res, 1) # can save only y policy = jax.checkpoint_policies.save_any_names_but_these('w') - res = saved_residuals(new_checkpoint(f, policy=policy), 1.) + res = saved_residuals(jax.checkpoint(f, policy=policy), 1.) self.assertLen(res, 2) # can save y and z policy = jax.checkpoint_policies.save_any_names_but_these() - res = saved_residuals(new_checkpoint(f, policy=policy), 1.) + res = saved_residuals(jax.checkpoint(f, policy=policy), 1.) self.assertLen(res, 3) # can save y, z, and w def test_name_allowlist(self): @@ -5765,19 +6343,19 @@ def f(x): return (((x * y) * z) * w) * u policy = jax.checkpoint_policies.save_only_these_names('y', 'z', 'w') - res = saved_residuals(new_checkpoint(f, policy=policy), 1.) + res = saved_residuals(jax.checkpoint(f, policy=policy), 1.) self.assertLen(res, 3) # can save y, z, and w policy = jax.checkpoint_policies.save_only_these_names('z', 'w') - res = saved_residuals(new_checkpoint(f, policy=policy), 1.) + res = saved_residuals(jax.checkpoint(f, policy=policy), 1.) self.assertLen(res, 2) # can save z and w policy = jax.checkpoint_policies.save_only_these_names('w') - res = saved_residuals(new_checkpoint(f, policy=policy), 1.) + res = saved_residuals(jax.checkpoint(f, policy=policy), 1.) self.assertLen(res, 1) # can save w policy = jax.checkpoint_policies.save_only_these_names() - res = saved_residuals(new_checkpoint(f, policy=policy), 1.) + res = saved_residuals(jax.checkpoint(f, policy=policy), 1.) self.assertLen(res, 0) # can't save anything! def test_saved_residuals_utility(self): @@ -5787,18 +6365,24 @@ def f(x, y): return z * ((x1 * x2) * y) * np.array([3.]) res = saved_residuals(f, (2., 3.), y=4.) - self.assertLen(res, 6) - self.assertEqual(res[0][0].shape, (1,)) - self.assertEqual(res[0][1], "from a constant") - self.assertEqual(res[1][0].shape, ()) - self.assertEqual(res[1][1], "from the argument x[0]") - self.assertEqual(res[2][0].shape, ()) - self.assertEqual(res[2][1], "from the argument x[1]") - self.assertEqual(res[3][0].shape, ()) - self.assertEqual(res[3][1], "from the argument y") - self.assertEqual(res[4][0].shape, ()) - self.assertStartsWith(res[4][1], "named 'z'") - self.assertEqual(res[5][0].shape, ()) + if config.use_simplified_jaxpr_constants.value: + self.assertLen(res, 5) + start_idx = 0 + else: + self.assertLen(res, 6) + self.assertEqual(res[0][0].shape, (1,)) + self.assertEqual(res[0][1], "from a constant") + start_idx = 1 + + self.assertEqual(res[start_idx][0].shape, ()) + self.assertEqual(res[start_idx][1], "from the argument x[0]") + self.assertEqual(res[start_idx + 1][0].shape, ()) + self.assertEqual(res[start_idx + 1][1], "from the argument x[1]") + self.assertEqual(res[start_idx + 2][0].shape, ()) + self.assertEqual(res[start_idx + 2][1], "from the argument y") + self.assertEqual(res[start_idx + 3][0].shape, ()) + self.assertStartsWith(res[start_idx + 3][1], "named 'z'") + self.assertEqual(res[start_idx + 4][0].shape, ()) def test_saved_residuals_utility_jit(self): @jax.jit @@ -5808,7 +6392,13 @@ def f(x, y): return z * ((x1 * x2) * y) * np.array([3.]) res = saved_residuals(f, (2., 3.), y=4.) - self.assertLen(res, 6) + if config.use_simplified_jaxpr_constants.value: + base_res_idx = 0 + else: + self.assertEqual(res[0][1], "from a constant") + self.assertEqual(res[0][0].shape, (1,)) + res.pop(0) + self.assertLen(res, 5) self.assertEqual(res[0][0].shape, ()) self.assertEqual(res[0][1], "from the argument x[0]") self.assertEqual(res[1][0].shape, ()) @@ -5818,15 +6408,13 @@ def f(x, y): self.assertEqual(res[3][0].shape, ()) self.assertStartsWith(res[3][1], "output of jitted function 'f'") self.assertEqual(res[4][0].shape, ()) - self.assertEqual(res[5][0].shape, (1,)) - self.assertStartsWith(res[5][1], "output of jitted function 'f'") @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} for suffix, remat in [ ('', jax.remat), ('_policy', partial(jax.remat, policy=lambda *_, **__: False)), - ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)), + ('_new', partial(jax.checkpoint, policy=lambda *_, **__: False)), ]) def test_checkpoint_dropvars(self, remat): @remat @@ -5837,7 +6425,7 @@ def f(x): _ = api.grad(f)(3.) # doesn't crash def test_dce_keeps_eqns_with_used_outputs_but_no_used_inputs(self): - @new_checkpoint + @jax.checkpoint def f(x): c = jax.jit(lambda: 3.)() return c * x @@ -5860,7 +6448,7 @@ def test_vjp_caching(self): with jtu.count_pjit_cpp_cache_miss() as count: # noqa: F841 for _ in range(20): f_vjp(1.)[0].block_until_ready() - self.assertEqual(count(), 2) # fwd execute_trivial, backward_pass on bwd + self.assertLessEqual(count(), 2) def test_vjp_caching_static_argnums(self): identity = jax.remat(lambda x, y: jax.jit(lambda x: 2 * x if y else x)(x), @@ -5869,7 +6457,7 @@ def test_vjp_caching_static_argnums(self): with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841 for _ in range(20): f_vjp(1.)[0].block_until_ready() - self.assertEqual(count(), 2) # fwd execute_trivial, backward_pass on bwd + self.assertLessEqual(count(), 2) def test_fwd_caching(self): # see above test also @@ -5893,7 +6481,7 @@ def test_fwd_caching_static_argnums(self): {"testcase_name": f"{suffix}", "remat": remat} for suffix, remat in [ ('', jax.remat), - ('_new', new_checkpoint), + ('_new', jax.checkpoint), ]) def test_remat_of_scan(self, remat): to_scan = lambda c, _: (jnp.sin(c), jnp.sin(c)) @@ -5901,6 +6489,7 @@ def test_remat_of_scan(self, remat): jtu.check_grads(remat(f), (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(remat(f), 4.)[1])(1.) + print("debug jaxpr: ", str(jaxpr)) self.assertIn(' sin ', str(jaxpr)) self.assertIn(' cos ', str(jaxpr)) @@ -5908,7 +6497,7 @@ def test_remat_of_scan(self, remat): {"testcase_name": f"{suffix}", "remat": remat} for suffix, remat in [ ('', jax.remat), - ('_new', new_checkpoint), + ('_new', jax.checkpoint), ]) def test_const_in_jvp_scan(self, remat): @jax.custom_jvp @@ -5932,7 +6521,7 @@ def test_remat_checkpoint_dots_outside_scan(self): # see also above test test_remat_checkpoint_dots_inside_scan x = jnp.ones((5,)) - @partial(new_checkpoint, policy=jax.checkpoint_policies.checkpoint_dots) + @partial(jax.checkpoint, policy=jax.checkpoint_policies.checkpoint_dots) def f(W): def f(x): x = jnp.sin(jnp.dot(x, W, precision=lax.Precision.HIGHEST)) @@ -5944,7 +6533,7 @@ def body(x, _): return f(x), None return lax.scan(body, x, None, length=2)[0] _, f_vjp = api.vjp(f, jnp.ones((5, 5))) - jaxpr = f_vjp.args[0].func.args[1] + jaxpr = f_vjp.jaxpr jaxpr_text = str(jaxpr) self.assertEqual(jaxpr_text.count(' sin '), 3) @@ -5959,7 +6548,7 @@ def body(x, _): return f(x), None def test_remat_of_scan_policy(self): save_cos = lambda prim, *_, **__: str(prim) == 'cos' to_scan = lambda c, _: (jnp.sin(c), jnp.sin(c)) - f = new_checkpoint(lambda x: lax.scan(to_scan, x, None, length=3), + f = jax.checkpoint(lambda x: lax.scan(to_scan, x, None, length=3), policy=save_cos) jtu.check_grads(f, (3.,), order=2, modes=['rev']) @@ -5985,7 +6574,7 @@ def sin_jvp(primals, tangents): sin.defjvp(sin_jvp) save_cos = lambda prim, *_, **__: str(prim) == 'cos' - f = new_checkpoint(partial(scan_apply, sin), policy=save_cos) + f = jax.checkpoint(partial(scan_apply, sin), policy=save_cos) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) @@ -5993,14 +6582,14 @@ def sin_jvp(primals, tangents): self.assertEqual(jaxpr_text.count(' cos '), 0) save_sin = lambda prim, *_, **__: str(prim) == 'sin' - f = new_checkpoint(partial(scan_apply, sin), policy=save_sin) + f = jax.checkpoint(partial(scan_apply, sin), policy=save_sin) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) self.assertEqual(jaxpr_text.count(' sin '), 0) self.assertEqual(jaxpr_text.count(' cos '), 1) - f = new_checkpoint(partial(scan_apply, sin), + f = jax.checkpoint(partial(scan_apply, sin), policy=jax.checkpoint_policies.everything_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) @@ -6008,7 +6597,7 @@ def sin_jvp(primals, tangents): self.assertEqual(jaxpr_text.count(' sin '), 0) self.assertEqual(jaxpr_text.count(' cos '), 0) - f = new_checkpoint(partial(scan_apply, sin), + f = jax.checkpoint(partial(scan_apply, sin), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) @@ -6016,7 +6605,7 @@ def sin_jvp(primals, tangents): self.assertEqual(jaxpr_text.count(' sin '), 1) # +1 b/c dce fixed point self.assertEqual(jaxpr_text.count(' cos '), 1) - f = new_checkpoint(lambda x: scan_apply(sin, scan_apply(sin, x)), + f = jax.checkpoint(lambda x: scan_apply(sin, scan_apply(sin, x)), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) @@ -6043,7 +6632,7 @@ def sin_jvp(primals, tangents): sin.defjvp(sin_jvp) save_cos = lambda prim, *_, **__: str(prim) == 'cos' - f = new_checkpoint(partial(scan_apply, sin), policy=save_cos) + f = jax.checkpoint(partial(scan_apply, sin), policy=save_cos) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) @@ -6051,14 +6640,14 @@ def sin_jvp(primals, tangents): self.assertEqual(jaxpr_text.count(' cos '), 0) save_sin = lambda prim, *_, **__: str(prim) == 'sin' - f = new_checkpoint(partial(scan_apply, sin), policy=save_sin) + f = jax.checkpoint(partial(scan_apply, sin), policy=save_sin) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) self.assertEqual(jaxpr_text.count(' sin '), 0) self.assertEqual(jaxpr_text.count(' cos '), 1) - f = new_checkpoint(partial(scan_apply, sin), + f = jax.checkpoint(partial(scan_apply, sin), policy=jax.checkpoint_policies.everything_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) @@ -6066,7 +6655,7 @@ def sin_jvp(primals, tangents): self.assertEqual(jaxpr_text.count(' sin '), 0) self.assertEqual(jaxpr_text.count(' cos '), 0) - f = new_checkpoint(partial(scan_apply, sin), + f = jax.checkpoint(partial(scan_apply, sin), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) @@ -6074,7 +6663,7 @@ def sin_jvp(primals, tangents): self.assertEqual(jaxpr_text.count(' sin '), 1) # +1 b/c dce fixed point self.assertEqual(jaxpr_text.count(' cos '), 1) - f = new_checkpoint(lambda x: scan_apply(sin, scan_apply(sin, x)), + f = jax.checkpoint(lambda x: scan_apply(sin, scan_apply(sin, x)), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) @@ -6086,7 +6675,7 @@ def sin_jvp(primals, tangents): {"testcase_name": f"{suffix}", "remat": remat} for suffix, remat in [ ('', jax.remat), - ('_new', new_checkpoint), + ('_new', jax.checkpoint), ]) def test_remat_of_cond(self, remat): true_fn = lambda c: (jnp.sin(c), jnp.sin(c)) @@ -6111,7 +6700,7 @@ def test_remat_of_cond(self, remat): {"testcase_name": f"{suffix}", "remat": remat} for suffix, remat in [ ('', jax.remat), - ('_new', new_checkpoint), + ('_new', jax.checkpoint), ]) def test_const_in_jvp_cond(self, remat): @jax.custom_jvp @@ -6143,8 +6732,7 @@ def f(x): return lax.cond(x.sum() > 0, f, lambda x: x, x) _, f_vjp = api.vjp(f, jnp.ones((5, 5))) - jaxpr_text = str(f_vjp.args[0].func.args[1]) - + jaxpr_text = str(f_vjp.jaxpr) self.assertEqual(jaxpr_text.count(' sin '), 2) self.assertEqual(jaxpr_text.count(' cos '), 3) # Five calls to dot_general in the backward pass because we have two for @@ -6162,7 +6750,7 @@ def test_remat_checkpoint_dots_outside_cond(self): # placement (because of the carry). x = jnp.ones((5,)) - @partial(new_checkpoint, policy=jax.checkpoint_policies.checkpoint_dots) + @partial(jax.checkpoint, policy=jax.checkpoint_policies.checkpoint_dots) def f(W): def f(x): x = jnp.sin(jnp.dot(x, W, precision=lax.Precision.HIGHEST)) @@ -6173,7 +6761,7 @@ def f(x): return lax.cond(x.sum() > 0, f, lambda x: x, x) _, f_vjp = api.vjp(f, jnp.ones((5, 5))) - jaxpr = f_vjp.args[0].func.args[1] + jaxpr = f_vjp.jaxpr jaxpr_text = str(jaxpr) self.assertEqual(jaxpr_text.count(' sin '), 2) @@ -6185,7 +6773,7 @@ def f(x): def test_remat_of_cond_policy(self): save_cos = lambda prim, *_, **__: str(prim) == 'cos' - f = new_checkpoint(lambda x: lax.cond(x > 0, jnp.sin, lambda x: x, x), + f = jax.checkpoint(lambda x: lax.cond(x > 0, jnp.sin, lambda x: x, x), policy=save_cos) jtu.check_grads(f, (3.,), order=2, modes=['rev']) @@ -6210,7 +6798,7 @@ def sin_jvp(primals, tangents): sin.defjvp(sin_jvp) save_cos = lambda prim, *_, **__: str(prim) == 'cos' - f = new_checkpoint(partial(cond_apply, sin), policy=save_cos) + f = jax.checkpoint(partial(cond_apply, sin), policy=save_cos) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) @@ -6218,14 +6806,14 @@ def sin_jvp(primals, tangents): self.assertEqual(jaxpr_text.count(' cos '), 0) save_sin = lambda prim, *_, **__: str(prim) == 'sin' - f = new_checkpoint(partial(cond_apply, sin), policy=save_sin) + f = jax.checkpoint(partial(cond_apply, sin), policy=save_sin) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) self.assertEqual(jaxpr_text.count(' sin '), 0) self.assertEqual(jaxpr_text.count(' cos '), 1) - f = new_checkpoint(partial(cond_apply, sin), + f = jax.checkpoint(partial(cond_apply, sin), policy=jax.checkpoint_policies.everything_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) @@ -6233,7 +6821,7 @@ def sin_jvp(primals, tangents): self.assertEqual(jaxpr_text.count(' sin '), 0) self.assertEqual(jaxpr_text.count(' cos '), 0) - f = new_checkpoint(partial(cond_apply, sin), + f = jax.checkpoint(partial(cond_apply, sin), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) @@ -6241,7 +6829,7 @@ def sin_jvp(primals, tangents): self.assertEqual(jaxpr_text.count(' sin '), 0) self.assertEqual(jaxpr_text.count(' cos '), 1) - f = new_checkpoint(lambda x: cond_apply(sin, cond_apply(sin, x)), + f = jax.checkpoint(lambda x: cond_apply(sin, cond_apply(sin, x)), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) @@ -6267,7 +6855,7 @@ def sin_jvp(primals, tangents): sin.defjvp(sin_jvp) save_cos = lambda prim, *_, **__: str(prim) == 'cos' - f = new_checkpoint(partial(cond_apply, sin), policy=save_cos) + f = jax.checkpoint(partial(cond_apply, sin), policy=save_cos) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) @@ -6275,14 +6863,14 @@ def sin_jvp(primals, tangents): self.assertEqual(jaxpr_text.count(' cos '), 0) save_sin = lambda prim, *_, **__: str(prim) == 'sin' - f = new_checkpoint(partial(cond_apply, sin), policy=save_sin) + f = jax.checkpoint(partial(cond_apply, sin), policy=save_sin) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) self.assertEqual(jaxpr_text.count(' sin '), 0) self.assertEqual(jaxpr_text.count(' cos '), 1) - f = new_checkpoint(partial(cond_apply, sin), + f = jax.checkpoint(partial(cond_apply, sin), policy=jax.checkpoint_policies.everything_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) @@ -6290,7 +6878,7 @@ def sin_jvp(primals, tangents): self.assertEqual(jaxpr_text.count(' sin '), 0) self.assertEqual(jaxpr_text.count(' cos '), 0) - f = new_checkpoint(partial(cond_apply, sin), + f = jax.checkpoint(partial(cond_apply, sin), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) @@ -6298,7 +6886,7 @@ def sin_jvp(primals, tangents): self.assertEqual(jaxpr_text.count(' sin '), 0) self.assertEqual(jaxpr_text.count(' cos '), 1) - f = new_checkpoint(lambda x: cond_apply(sin, cond_apply(sin, x)), + f = jax.checkpoint(lambda x: cond_apply(sin, cond_apply(sin, x)), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) @@ -6310,7 +6898,7 @@ def sin_jvp(primals, tangents): {"testcase_name": f"{suffix}", "remat": remat} for suffix, remat in [ ('', jax.remat), - ('_new', new_checkpoint), + ('_new', jax.checkpoint), ]) def test_remat_of_while_loop(self, remat): def cond_fn(carry): @@ -6345,7 +6933,7 @@ def f(x): # even with a policy, we can't save residuals (w/o dynamic shapes)! save_cos = lambda prim, *_, **__: str(prim) == 'cos' - g = new_checkpoint(f, policy=save_cos) + g = jax.checkpoint(f, policy=save_cos) jaxpr = api.make_jaxpr(jax.linearize(g, 4.)[1])(1.) self.assertIn(' sin ', str(jaxpr)) self.assertIn(' cos ', str(jaxpr)) @@ -6464,6 +7052,99 @@ def f(x, _): else: assert False + def test_name_stack_annotation(self): + def g(x, y): + with jax.named_scope("g"): + return x @ y + + def g1(x, y): + with jax.named_scope("g1"): + t = checkpoint_name(x @ y, "save_me") + return t + + def f2(x, ws): + for i, w in enumerate(ws): + if i % 8 == 1: + x = g1(x, w) + else: + x = g(x, w) + x = jnp.tanh(x) + return jnp.sum(x) + + @jax.remat + def f(x, ws): + return jax.named_call(jax.jit(f2), name='run_per_expert_shard')(x, ws) + + def make_weight(i): + if i % 2 == 0: + return jnp.ones([64, 128]) + else: + return jnp.ones([128, 64]) + + x = jnp.ones([64, 64]) + ws = [make_weight(i) for i in range(2)] + + out = f(x, ws) + vjp = jax.make_jaxpr(jax.vjp(functools.partial(f, x), ws)[1])(out) + + s = vjp.jaxpr.pretty_print(name_stack=True) + self.assertEqual(s.count('rematted_computation'), 1) + + def test_remat_partial_cse_prevention(self): + @partial(jax.remat, prevent_cse=(False, True)) + def layer(W, x): + res = x @ W + res += jnp.array([1.0, 2.0, 3.0]) # ensure the jaxpr also contains a const + return res + + def net(Ws, x): + for W in Ws: + x = layer(W, x) + return x + + def loss(Ws, x): + return jnp.sum(net(Ws, x)**2) + + Ws = [jnp.ones((3, 3)) for _ in range(2)] + x = jnp.ones(3) + txt = jax.jit(jax.grad(loss, (0, 1))).lower(Ws, x).as_text() + self.assertRegex(txt, r'optimization_barrier %[a-z0-9]+, %[a-z0-9]+ :') + + def test_remat_partial_cse_prevention_all_false(self): + @partial(jax.remat, prevent_cse=(False, False)) + def layer(W, x): + return x @ W + + def net(Ws, x): + for W in Ws: + x = layer(W, x) + return x + + def loss(Ws, x): + return jnp.sum(net(Ws, x)**2) + + Ws = [jnp.ones((3, 3)) for _ in range(2)] + x = jnp.ones(3) + txt = jax.jit(jax.grad(loss, (0, 1))).lower(Ws, x).as_text() # don't crash + + def test_remat_partial_cse_prevention_pytree(self): + @partial(jax.remat, prevent_cse=({'W': False, 'x': True},)) + def layer(dct): + return dct['x'] @ dct['W'] + + def net(Ws, x): + for W in Ws: + x = layer(dict(W=W, x=x)) + return x + + def loss(Ws, x): + return jnp.sum(net(Ws, x)**2) + + Ws = [jnp.ones((3, 3)) for _ in range(2)] + x = jnp.ones(3) + txt = jax.jit(jax.grad(loss, (0, 1))).lower(Ws, x).as_text() + self.assertRegex(txt, r'optimization_barrier %[a-z0-9]+, %[a-z0-9]+ :') + @jtu.with_config(jax_pprint_use_color=False) class JaxprTest(jtu.JaxTestCase): @@ -6482,21 +7163,34 @@ def test_const(self): def fun(x): return (x, 1., np.zeros(1, dtype=jnp.float32)) - expected = "{ lambda a:f32[1]; b:f32[]. let in (b, 1.0, a) }" + dtype = "f64" if config.enable_x64.value else "f32" + if config.use_simplified_jaxpr_constants.value: + expected = f"{{ lambda ; a:f32[]. let in (a, 1.0:{dtype}[], [...]:f32[1]) }}" + else: + expected = f"{{ lambda a:f32[1]; b:f32[]. let in (b, 1.0:{dtype}[], a) }}" + jaxpr = api.make_jaxpr(fun)(jnp.float32(0.)) + self.assertMultiLineStrippedEqual(expected, str(jaxpr)) + + @config.use_simplified_jaxpr_constants(True) + def test_non_scalar_const(self): + def fun(x): + return (x, np.zeros(3, dtype=jnp.float32)) + + expected = "{ lambda ; a:f32[]. let in (a, [...]:f32[3]) }" jaxpr = api.make_jaxpr(fun)(jnp.float32(0.)) self.assertMultiLineStrippedEqual(expected, str(jaxpr)) def test_cond(self): def f(x): return lax.cond(x >= 0., + lambda xt, _: xt + x, + lambda _, xf: xf - x, x + 1., - lambda xt: xt + x, - x + 2., - lambda xf: xf - x) + x + 2.) expected = """{ lambda ; a:f32[]. let - b:bool[] = ge a 0.0 - c:f32[] = add a 1.0 - d:f32[] = add a 2.0 + b:bool[] = ge a 0.0:f32[] + c:f32[] = add a 1.0:f32[] + d:f32[] = add a 2.0:f32[] e:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b f:f32[] = cond[ branches=( @@ -6567,8 +7261,8 @@ def f(x): jax.debug.print("{}", x) return x jaxpr = jax.make_jaxpr(f)(np.int32(0)) - self.assertEqual(jaxpr.eqns[0].primitive, debugging.debug_callback_p) - self.assertStartsWith(str(jaxpr.eqns[0]), "debug_callback[", ) + self.assertEqual(jaxpr.eqns[0].primitive, debugging.debug_print_p) + self.assertStartsWith(str(jaxpr.eqns[0]), "debug_print[") class DCETest(jtu.JaxTestCase): @@ -6678,13 +7372,13 @@ def body(c, _): self.assert_dce_result( jaxpr, used_outputs=used_outputs, expected_used_inputs=expected_used_inputs, - expected_num_eqns=1) # 1 b/c scan doesn't have fwding rule + expected_num_eqns=0) used_outputs[7] = expected_used_inputs[7] = True used_outputs[6] = expected_used_inputs[6] = True self.assert_dce_result( jaxpr, used_outputs=used_outputs, expected_used_inputs=expected_used_inputs, - expected_num_eqns=1) + expected_num_eqns=0) # If we use the value at index 3 only, some of the hidden sequence must be # kept but the rest pruned. @@ -6864,4577 +7558,131 @@ def f(x1, x2): self.assert_dce_result(jaxpr, [True, False], [True, True], 5) -class CustomJVPTest(jtu.JaxTestCase): - - def test_basic(self): - @jax.custom_jvp - def f(x): - return jnp.sin(x) - def f_jvp(primals, tangents): - x, = primals - g, = tangents - return f(x), 2 * jnp.cos(x) * g - f.defjvp(f_jvp) - - x = 3. - self.assertAllClose(f(x), jnp.sin(x)) - self.assertAllClose(api.jvp(f, (x,), (1.,)), - (jnp.sin(x), 2 * jnp.cos(x))) - self.assertAllClose(api.grad(f)(x), 2 * jnp.cos(x)) - - def test_invariance(self): - @jax.custom_jvp - def f(x): - return jnp.cos(2 * x) / 2. - def f_jvp(primals, tangents): - x, = primals - g, = tangents - return (f(x), 3 * g) - f.defjvp(f_jvp) - def f2(x): - y, _ = api.jvp(f, (x,), (x,)) - return y - def f3(x): - y, _ = api.jvp(f2, (x,), (x,)) - return y - x = 1. - self.assertAllClose(api.jvp(f, (x,), (x,)), - api.jvp(f2, (x,), (x,)), - check_dtypes=False) - self.assertAllClose(api.jvp(f, (x,), (x,)), - api.jvp(f3, (x,), (x,)), - check_dtypes=False) - - def test_python_control_flow(self): - @jax.custom_jvp - def f(x): - if x > 0: - return jnp.sin(x) - else: - return jnp.cos(x) - def f_jvp(primals, tangents): - x, = primals - g, = tangents - if x > 0: - return f(x), 2 * g - else: - return f(x), 3 * g - f.defjvp(f_jvp) - x = 2. - self.assertAllClose(f(x), jnp.sin(x)) - self.assertAllClose(f(-x), jnp.cos(-x)) - self.assertAllClose(api.jvp(f, (x,), (1.,)), - (jnp.sin(x), 2.), - check_dtypes=False) - self.assertAllClose(api.jvp(f, (-x,), (1.,)), - (jnp.cos(-x), 3.), - check_dtypes=False) - self.assertAllClose(api.grad(f)(x), 2., check_dtypes=False) - self.assertAllClose(api.grad(f)(-x), 3., check_dtypes=False) - - def test_vmap(self): - @jax.custom_jvp - def f(x): - assert jnp.ndim(x) == 0 - return jnp.sin(x) - def f_jvp(primals, tangents): - x, = primals - g, = tangents - assert jnp.ndim(x) == jnp.ndim(g) == 0 - return f(x), 2 * jnp.cos(x) * g - f.defjvp(f_jvp) +class BufferDonationTest(jtu.BufferDonationTestCase): - x = jnp.arange(3.) - xx = jnp.arange(6.).reshape(2, 3) + @jtu.device_supports_buffer_donation() + def test_pmap_donate_argnums_invalidates_input(self): + move = api.pmap(lambda x: x + x - x, donate_argnums=0) + n = jax.local_device_count() + x = api.pmap(lambda x: x)(jnp.ones([n])) + y = move(x) + self.assertDeleted(x) + np.testing.assert_allclose(y, [1.] * n) - # vmap of f - self.assertAllClose(api.vmap(f)(x), jnp.sin(x)) - self.assertAllClose(api.vmap(api.vmap(f))(xx), jnp.sin(xx)) + @jtu.device_supports_buffer_donation() + def test_pmap_nested_donate_ignored(self): + pmap_fun = jit(lambda x: api.pmap(lambda y: y ** 2, donate_argnums=0)(x)) + a = api.pmap(lambda x: x)(jnp.array([1])) - # vmap of jvp of f - self.assertAllClose(api.vmap(lambda x: api.jvp(f, (x,), (x,)))(x), - (jnp.sin(x), 2 * jnp.cos(x) * x)) - self.assertAllClose(api.vmap(api.vmap(lambda x: api.jvp(f, (x,), (x,))))(xx), - (jnp.sin(xx), 2 * jnp.cos(xx) * xx)) + # NOTE(mattjj): stopped raising error here and instead just ignored + # with self.assertRaisesRegex(ValueError, "nested.*not supported"): + # pmap_fun(a) - # jvp of vmap of f - self.assertAllClose(api.jvp(api.vmap(f), (x,), (x,)), - (jnp.sin(x), 2 * jnp.cos(x) * x)) - self.assertAllClose(api.jvp(api.vmap(api.vmap(f)), (xx,), (xx,)), - (jnp.sin(xx), 2 * jnp.cos(xx) * xx)) + pmap_fun(a) # doesn't crash - # vmap of jvp of vmap of f - self.assertAllClose(api.vmap(lambda x: api.jvp(api.vmap(f), (x,), (x,)))(xx), - (jnp.sin(xx), 2 * jnp.cos(xx) * xx)) - def test_jit(self): - @jax.custom_jvp - def f(x): - return jnp.sin(x) - def f_jvp(primals, tangents): - x, = primals - g, = tangents - return f(x), 2 * jnp.cos(x) * g - f.defjvp(f_jvp) +class NamedCallTest(jtu.JaxTestCase): - x = 3. + def test_non_jaxtype_arg(self): + # For the test to fail without the invalid JaxType filter we need to pass + # in a valid JaxType that forces the invalid Jaxtype to be raised to an + # abstract value. + def f(not_a_jaxtype, a_jaxtype): + # then Jax needs to try and evaluate the abstractified non-JaxType + if not_a_jaxtype: + return a_jaxtype + return 0 - # jit - self.assertAllClose(api.jit(f)(x), jnp.sin(x)) - self.assertAllClose(api.jit(api.jit(f))(x), jnp.sin(x)) + f = api.named_call(f, name="test") + out = jax.jit(f, static_argnums=(0,))("not a Jaxtype", 1) + self.assertEqual(out, 1) - # jit of jvp - self.assertAllClose(api.jit(lambda x: api.jvp(f, (x,), (x,)))(x), - (jnp.sin(x), 2 * jnp.cos(x) * x), - check_dtypes=False) + @parameterized.parameters(jax.jit, jax.grad, jax.vmap, jax.remat) + def test_jax_transforms(self, transform): + f = jnp.sum + x = jnp.array([1.]) - # jvp of jit - self.assertAllClose(api.jvp(api.jit(f), (x,), (x,)), - (jnp.sin(x), 2 * jnp.cos(x) * x), - check_dtypes=False) + unnamed_out = transform(f)(x) + named_out = transform(api.named_call(f, name="test"))(x) - def test_pytrees(self): - @jax.custom_jvp - def f(x): - return {'b': jnp.sin(x['a'])} - def f_jvp(primals, tangents): - x, = primals - g, = tangents - return f(x), {'b': 2 * jnp.cos(x['a']) * g['a']} - f.defjvp(f_jvp) - x = {'a': 3.} - self.assertAllClose(f(x)['b'], jnp.sin(x['a'])) - self.assertAllClose(api.jvp(f, (x,), (x,)), - ({'b': jnp.sin(x['a'])}, - {'b': 2 * jnp.cos(x['a']) * x['a']}), - check_dtypes=False) - - def test_kwargs(self): - # from https://github.com/jax-ml/jax/issues/1938 - @jax.custom_jvp - def my_fun(x, y, c=1.): - return c * (x + y) - def my_jvp(primals, tangents): - x, y, c = primals - t_x, t_y, t_c = tangents - return my_fun(x, y, c), t_c - my_fun.defjvp(my_jvp) - f = lambda x, y: jnp.square(my_fun(x, y, c=2.)).sum() - f(10., 5.) # doesn't crash - api.jvp(f, (10., 5.), (1., 1.)) # doesn't crash - - def test_initial_style(self): - @jax.custom_jvp - def f(x): - return 3 * x - def f_jvp(primals, tangents): - x, = primals - g, = tangents - return f(x), 2 * g - f.defjvp(f_jvp) + self.assertEqual(unnamed_out, named_out) - def foo(x): - out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1) - return out + def test_static_argnums(self): + f = api.named_call(lambda x, y: y if x else None, name="test") + f = jax.jit(f, static_argnums=(0,)) + out = f(True, 5) + self.assertEqual(out, 5) - ans = api.grad(foo)(3.) - expected = 2. - self.assertAllClose(ans, expected, check_dtypes=False) + def test_partial_eval(self): + f = api.named_call(lambda x, y: y if x else None, name="test") + f = jax.jit(functools.partial(f, True)) + out = f(5) + self.assertEqual(out, 5) - ans = api.grad(api.jit(foo))(3.) - expected = 2. - self.assertAllClose(ans, expected, check_dtypes=False) + @parameterized.parameters( + [dict(func=func, jit=jit) + for func in ['identity_trivial', 'identity', 'closure_trivial', 'closure', + 'asarray', 'device_put'] + for jit in jtu.JIT_IMPLEMENTATION + if not (jit._name == "noop" and func in ('identity', 'identity_trivial')) + ], + ) + def test_integer_overflow(self, jit, func): + funcdict = { + 'identity_trivial': lambda x: x, # may hit trivial dispatch path + 'identity': lambda x: x + 0, + 'closure_trivial': lambda x: jax.jit(lambda: x)(), + 'closure': lambda x: jax.jit(lambda: x + 0)(), + 'asarray': lambda x: jnp.asarray(x), # add lambdas so no cross-test cache + 'device_put': lambda x: api.device_put(x), + } - ans = api.jit(api.grad(foo))(3.) - expected = 2. - self.assertAllClose(ans, expected, check_dtypes=False) + f = jit(funcdict[func]) - ans = api.grad(api.grad(foo))(3.) - expected = 0. - self.assertAllClose(ans, expected, check_dtypes=False) + int_dtype = dtypes.default_int_dtype() + int_max = np.iinfo(int_dtype).max + int_min = np.iinfo(int_dtype).min - ans = api.grad(api.grad(api.jit(foo)))(3.) - expected = 0. - self.assertAllClose(ans, expected, check_dtypes=False) + # check before any jit cache entries + self.assertRaises(OverflowError, f, int_max + 1) + self.assertRaises(OverflowError, f, int_min - 1) - ans = api.grad(api.jit(api.grad(foo)))(3.) - expected = 0. - self.assertAllClose(ans, expected, check_dtypes=False) + self.assertEqual(f(int_max).dtype, int_dtype) + self.assertEqual(f(int_min).dtype, int_dtype) + self.assertAllClose(f(int_max), int_max) + self.assertAllClose(f(int_min), int_min) - ans = api.jit(api.grad(api.grad(foo)))(3.) - expected = 0. - self.assertAllClose(ans, expected, check_dtypes=False) + # check after any cache entries + self.assertRaises(OverflowError, f, int_max + 1) + self.assertRaises(OverflowError, f, int_min - 1) + if func in ('trivial', 'identity'): + self.assertRaisesRegex( + OverflowError, 'An overflow.*whose argument path is x.', f, + int_max + 1) - def test_initial_style_vmap(self): - @jax.custom_jvp - def f(x): - assert jnp.ndim(x) == 0 - return 3 * x - def f_jvp(primals, tangents): - x, = primals - g, = tangents - return f(x), 2 * g - f.defjvp(f_jvp) - def foo(x): - out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1) - return out +class BackendsTest(jtu.JaxTestCase): - ans = api.vmap(foo)(jnp.ones(3)) - expected = 3. * jnp.ones(3) - self.assertAllClose(ans, expected, check_dtypes=False) + @unittest.skipIf(not sys.executable, "test requires sys.executable") + @jtu.run_on_devices("cpu") + def test_no_backend_warning_on_cpu_if_platform_specified(self): + warning_not_expected = ( + "import jax; " + "jax.config.update('jax_platform_name', 'cpu'); " + "jax.numpy.arange(10)") - ans = api.vmap(api.jit(foo))(jnp.ones(3)) - expected = 3. * jnp.ones(3) - self.assertAllClose(ans, expected, check_dtypes=False) + result = subprocess.run([sys.executable, '-c', warning_not_expected], + check=True, capture_output=True) + assert "may be present" not in result.stderr.decode() - ans = api.jit(api.vmap(foo))(jnp.ones(3)) - expected = 3. * jnp.ones(3) - self.assertAllClose(ans, expected, check_dtypes=False) - ans = api.grad(lambda x: api.vmap(foo)(x).sum())(jnp.ones(3)) - expected = 2. * jnp.ones(3) - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(lambda x: api.vmap(api.jit(foo))(x).sum())(jnp.ones(3)) - expected = 2. * jnp.ones(3) - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(lambda x: api.jit(api.vmap(foo))(x).sum())(jnp.ones(3)) - expected = 2. * jnp.ones(3) - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(api.jit(lambda x: api.vmap(foo)(x).sum()))(jnp.ones(3)) - expected = 2. * jnp.ones(3) - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.jit(api.grad(lambda x: api.vmap(foo)(x).sum()))(jnp.ones(3)) - expected = 2. * jnp.ones(3) - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_initial_style_vmap_with_collective(self): - - @jax.custom_jvp - def f(x): - return lax.psum(x, 'foo') - - @f.defjvp - def f_jvp(xs, ts): - x, = xs - t, = ts - return lax.psum(x, 'foo'), t - - def g(x): - jaxpr = api.make_jaxpr(f)(x) - return core.eval_jaxpr(jaxpr.jaxpr, [], x)[0] - - v = api.vmap(lambda _, x: g(x), axis_name='foo', in_axes=(0, None), - out_axes=None)(jnp.arange(4.), 2.) - self.assertAllClose(v, 8.) - - def test_closed_over_tracers_error_message(self): - def f(x): - @jax.custom_jvp - def g(y): - return x + y - def g_jvp(primals, tangents): - return g(x), 2 * primals[0] - g.defjvp(g_jvp) - return g(1.) - - self.assertRaises(UnexpectedTracerError, lambda: api.jvp(f, (3.,), (1.,))) - self.assertRaises(UnexpectedTracerError, lambda: api.grad(f)(3.)) - - def test_nondiff_arg(self): - @partial(jax.custom_jvp, nondiff_argnums=(0,)) - def app(f, x): - return f(x) - def app_jvp(f, primals, tangents): - (x,), (t,) = primals, tangents - return app(f, x), 3 * t - app.defjvp(app_jvp) - - ans = app(lambda x: 2 * x, 1) - expected = 2 - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.jvp(lambda x: app(lambda y: 2 * y, x), (1.,), (1.,)) - expected = (2., 3.) - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_nondiff_arg_jit_tracer(self): - # This test would pass with "final-style" JIT tracing, but that was - # misleading: it doesn't work with "initial-style" staging, i.e. control - # flow primitives like jax.lax.scan or even pjit. The behavior isn't very - # useful either: instead of using nondiff_argnums here, a user can just pass - # such inputs as ordinary arguments, and ignore the corresponding tangents. - # Then nondiff_argnums can be reserved for (1) non jaxtype data (like a - # string- or callable-valued argument which parameterizes the function or - # rule) or (2) static data (e.g. integers which parameterize shapes). - raise unittest.SkipTest("behavior no longer supported") - - @partial(jax.custom_jvp, nondiff_argnums=(0,)) - def f(x, y): - return x * y - def f_jvp(x, primals, tangents): - (y,), (t_y,) = primals, tangents - return f(x, y), 5 * t_y - f.defjvp(f_jvp) - - @jit - def g(x, y): - return f(x, y) - - ans = api.jvp(lambda y: g(2., y), (3.,), (1.,)) - expected = (6., 5.) - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_nondiff_arg_vmap_tracer(self): - @partial(jax.custom_jvp, nondiff_argnums=(0,)) - def f(x, y): - return x * y - def f_jvp(x, primals, tangents): - (y,), (t_y,) = primals, tangents - return f(x, y), 5 * t_y - f.defjvp(f_jvp) - - g = jax.vmap(f) - - ans = api.jvp(lambda y: g(jnp.array([2.]), y), - (jnp.array([3.]),), (jnp.array([1.]),)) - expected = (jnp.array([6.]), jnp.array([5.])) - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_nondiff_arg_hiding_jvp_tracer(self): - def f(x): - @partial(jax.custom_jvp, nondiff_argnums=(0,)) - def g(h, x): - return h(x) - @g.defjvp - def g_jvp(h, primals, tangents): - x, = primals - t, = tangents - return g(h, x), 2. * t - h = lambda y: x + y # capture x - return g(h, x) - - with self.assertRaises(UnexpectedTracerError): - api.jvp(f, (2.,), (1.,)) - - def test_vmap_axes(self): - raise unittest.SkipTest("TODO") # TODO(mattjj): write test - - def test_pmap(self): - raise unittest.SkipTest("TODO") # TODO(mattjj): write test - - def test_missing_jvp_rule_error_message(self): - @jax.custom_jvp - def foo(x): - return x ** 2 - - self.assertRaisesRegex( - AttributeError, - r"No JVP defined for custom_jvp function foo using defjvp.", - lambda: foo(2)) - self.assertRaisesRegex( - AttributeError, - r"No JVP defined for custom_jvp function foo using defjvp.", - lambda: api.jvp(foo, (2.,), (1.,))) - self.assertRaisesRegex( - AttributeError, - r"No JVP defined for custom_jvp function foo using defjvp.", - lambda: api.grad(foo)(2.)) - - def test_jvp_rule_inconsistent_pytree_structures_error_message(self): - @jax.custom_jvp - def f(x): - return (x**2,) - - @f.defjvp - def foo_jvp(primals, tangents): - x, = primals - t, = tangents - return f(x), [2 * x * t, x] - - f(2.) # doesn't crash - self.assertRaisesRegex( - TypeError, - re.escape( - "Custom JVP rule foo_jvp for function f " - "must produce primal and tangent outputs " - "with equal container (pytree) structures, but got " - "{} and {} respectively.".format( - jax.tree.structure((1,)), - jax.tree.structure([1, 2])) - ), - lambda: api.jvp(f, (2.,), (1.,))) - - def test_primal_tangent_aval_disagreement_error_message(self): - @jax.custom_jvp - def f(x): - return x ** 2 - - @f.defjvp - def foo_jvp(primals, tangents): - x, = primals - t, = tangents - return f(x), jnp.reshape(t, (1,)) - - f(2.) # doesn't crash - self.assertRaisesRegex( - TypeError, - re.escape( - "Custom JVP rule must produce primal and tangent outputs " - "with corresponding shapes and dtypes. " - "Expected float32[] (tangent type of float32[]) but got float32[1]."), - lambda: api.jvp(f, (jnp.float32(2.),), (jnp.float32(1.),))) - - - def test_jvp_rule_doesnt_return_pair_error_message(self): - # https://github.com/jax-ml/jax/issues/2516 - - @jax.custom_jvp - def f(x): - return x ** 2 - - @f.defjvp - def foo_jvp(primals, tangents): - x, = primals - t, = tangents - return t - - f(2.) # doesn't crash - self.assertRaisesRegex( - TypeError, - re.escape( - "Custom JVP rule foo_jvp for function f " - "must produce a pair (list or tuple of length two) " - "representing primal and tangent outputs, but got 1.0"), - lambda: api.jvp(f, (2.,), (1.,))) - - def test_jvp_rule_primal_out_type_doesnt_match_primal_error_message(self): - # https://github.com/lucidrains/flash-attention-jax/issues/7 - - def scan_apply(f, x): - y, _ = jax.lax.scan(lambda x, _: (f(x), None), x, None, length=1) - return y - - @jax.custom_jvp - def f(x): - return x - - @f.defjvp - def f_jvp(primals, tangents): - (x,), (xdot,) = primals, tangents - return (x, x), (xdot, xdot) - - x = jnp.float32(1.) - self.assertRaisesRegex( - TypeError, - re.escape( - "Custom JVP rule f_jvp for function f must produce a pair " - "(list or tuple of length two) where the first element represents " - "the primal output (equal in value to the output of the " - "custom_jvp-decorated function f, and in particular of the " - "same container/pytree structure), but instead the JVP rule " - "output's first element had container/pytree structure:\n" - " (float32[], float32[])\n" - "while the custom_jvp-decorated function f had output " - "container/pytree structure:\n" - " float32[]." - ), - lambda: jax.jvp(lambda x: scan_apply(f, x), (x,), (x,))) - - @f.defjvp - def f_jvp2(primals, tangents): - (x,), (xdot,) = primals, tangents - return jnp.zeros((3, *x.shape), x.dtype), xdot - - self.assertRaisesRegex( - TypeError, - re.escape( - "Custom JVP rule f_jvp2 for function f must produce a pair " - "(list or tuple of length two) where the first element represents " - "the primal output (equal in value to the output of the " - "custom_jvp-decorated function f, and in particular " - "with leaves of the same shape/dtype), but instead the JVP rule " - "output's first element had shapes/dtypes of:\n" - " float32[3]\n" - "while the custom_jvp-decorated function f had output shapes/dtypes" - " of:\n" - " float32[]" - ), - lambda: jax.jvp(lambda x: scan_apply(f, x), (x,), (x,))) - - def test_multiple_rule_invocations(self): - @jax.custom_jvp - def expit(x): - return 1 / (1 + lax.exp(-x)) - - @expit.defjvp - def _expit_jvp(primals, tangents): - (x,), (t,) = primals, tangents - ans = expit(x) - t_out = t * ans * (1 - ans) - return ans, t_out - - def scanned_fun(c, _): - return [expit(c[0])] + [c[i-1] + c[i] for i in range(1, len(c))], None - - def foo(x): - zero = jnp.zeros_like(x) - c, _ = lax.scan(scanned_fun, [x, zero, zero, zero, zero], None, length=10) - return c[-1] - - # just make sure these don't crash - foo(3.) - grad(foo)(3.) - grad(lambda x: jax.vmap(foo)(x).sum())(jnp.arange(3.)) - - def test_hard_stuff(self): - arr = jnp.ones((5, 2, 2)) - api.jit(jax.vmap(jnp.linalg.det))(arr) # doesn't crash - - def test_hard_stuff2(self): - @jax.custom_jvp - def f(x): - return np.zeros(x.shape, x.dtype) - - @f.defjvp - def f_jvp(primals, tangents): - x, = primals - t, = tangents - return f(x), t - - # don't crash - jax.jit(jax.vmap(f))(jnp.arange(3.)) - jax.jit(jax.vmap(jax.grad(f)))(jnp.arange(3.)) - jax.jit(jax.grad(lambda x: jax.vmap(f)(x).sum()))(jnp.arange(3.)) - jax.grad(lambda x: jax.vmap(f)(x).sum())(jnp.arange(3.)) - jax.jvp(jax.vmap(f), (jnp.arange(3.),), (jnp.ones(3),)) - - def test_hard_stuff3(self): - @jax.custom_jvp - def relu(x): - return jnp.maximum(x, 0) - - @relu.defjvp - def _relu_jvp(primals, tangents): - x, = primals - t, = tangents - return relu(x), lax.select(x > 0, t, lax.full_like(t, 0)) - - def scanned_fun(c, _): - return [relu(c[0])] + [c[i-1] + c[i] for i in range(1, len(c))], None - - def f(x): - zero = jnp.zeros_like(x) - c, _ = lax.scan(scanned_fun, [x, zero, zero, zero, zero], None, length=10) - return c[-1] - - # don't crash - jax.jit(jax.vmap(f))(jnp.arange(3.)) - jax.jit(jax.vmap(jax.grad(f)))(jnp.arange(3.)) - jax.jit(jax.grad(lambda x: jax.vmap(f)(x).sum()))(jnp.arange(3.)) - jax.grad(lambda x: jax.vmap(f)(x).sum())(jnp.arange(3.)) - jax.jvp(jax.jit(jax.vmap(f)), (jnp.arange(3.),), (jnp.ones(3),)) - - def test_eval_shape(self): - @jax.custom_jvp - def expit(x): - return 1 / (1 + lax.exp(-x)) - - @expit.defjvp - def _expit_jvp(primals, tangents): - (x,), (t,) = primals, tangents - ans = expit(x) - t_out = t * ans * (1 - ans) - return ans, t_out - - # don't crash - api.eval_shape(expit, jnp.ones((2, 3))) - api.eval_shape(api.grad(lambda x: expit(x).sum()), jnp.ones((2, 3))) - - def test_jaxpr_zeros(self): - # from https://github.com/jax-ml/jax/issues/2657 - @jax.custom_jvp - def f(A, b): - return A @ b - - def f_jvp(primals, tangents): - A, b = primals - dA, db = tangents - z = f(A, b) - dz = A @ db + dA @ b - return z, dz - - f.defjvp(f_jvp) - - def experiment(theta): - def step(q, _): - z = f(jnp.eye(3), jnp.ones(3) * theta) - q += z[0] - return q, q - - q = 0. - q, _ = lax.scan(step, q, None, 4) - return q - - grad(experiment)(1.) # doesn't crash - - def test_linear_in_scan(self): - @jax.custom_jvp - def f(x): - return -x - - @f.defjvp - def f_jvp(primals, tangents): - x, = primals - x_dot, = tangents - return f(x), f(x_dot) - - def foo(x): - out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1) - return out - - ans = api.grad(foo)(3.) - expected = -1. - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_custom_jvps_first_rule_is_none(self): - # https://github.com/jax-ml/jax/issues/3389 - @jax.custom_jvp - def f(x, y): - return x ** 2 * y - - f.defjvps(None, lambda x_dot, primal_out, x, y: 2 * x * y * x_dot) - ans = grad(f, 1)(2., 3.) # doesn't crash - expected = 12. - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_concurrent_initial_style(self): - # https://github.com/jax-ml/jax/issues/3843 - def unroll(param, sequence): - def scan_f(prev_state, inputs): - return prev_state, jax.nn.sigmoid(param * inputs) - return jnp.sum(jax.lax.scan(scan_f, None, sequence)[1]) - - def run(): - return jax.grad(unroll)(jnp.array(1.0), jnp.array([1.0])) - - expected = run() - - # we just don't want this to crash - n_workers = 2 - with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as e: - futures = [] - for _ in range(n_workers): - futures.append(e.submit(run)) - results = [f.result() for f in futures] - for ans in results: - self.assertAllClose(ans, expected) - - def test_nondiff_argnums_vmap_tracer(self): - # https://github.com/jax-ml/jax/issues/3964 - @partial(jax.custom_jvp, nondiff_argnums=(0, 2)) - def sample(shape, param, seed): - return jax.random.uniform(key=seed, shape=shape, minval=param) - - @sample.defjvp - def sample_jvp(shape, seed, primals, tangents): - param, = primals - dparam, = tangents - dparam = jnp.broadcast_to(dparam, shape) - samples = sample(shape, param, seed) - return samples, samples * dparam # dummy jvp for proof of concept - - # check these don't crash - jax.vmap(lambda seed: sample((2,3), 1., seed))( - jax.random.split(jax.random.key(1), 10)) - jax.jvp(lambda x: sample((2, 3), x, jax.random.key(1)), - (1.,), (1.,)) - - def test_fun_with_nested_calls_2(self): - def call(f, *args): - f = jax.custom_jvp(f) - f.defjvp(lambda primals, tangents: (f(*primals), sum(tangents))) - return f(*args) - - def fun_with_nested_calls_2(x): - def bar(y): - def baz(w): - q = call(lambda x: y, x) - q = q + call(lambda: y) - q = q + call(lambda y: w + y, y) - q = call(lambda w: call(jnp.sin, x) * y, 1.0) + q - return q - return api.jit(baz)(x) - return call(bar, x) - - # test these don't crash - self.assertAllClose(api.jit(fun_with_nested_calls_2)(3.), - fun_with_nested_calls_2(3.)) - api.vmap(fun_with_nested_calls_2)(jnp.arange(3.)) - - def test_closure_with_vmap(self): - # https://github.com/jax-ml/jax/issues/3822 - alpha = np.float32(2.) - - def sample(seed): - @jax.custom_jvp - def f(alpha): - return jax.random.gamma(seed, alpha, shape=[]) - - @f.defjvp - def f_jvp(primal, tangent): - alpha = primal - dalpha = tangent - sample = f(alpha) - partial_alpha = lax.random_gamma_grad(alpha, sample) - return sample, partial_alpha * dalpha - return f(alpha) - - api.vmap(sample)(jax.random.split(jax.random.key(1), 3)) # don't crash - - def test_closure_with_vmap2(self): - # https://github.com/jax-ml/jax/issues/8783 - def h(z): - def f(x): - @jax.custom_jvp - def g(y): - return x * y - - # NOTE: rule closes over vmap tracer - @g.defjvp - def g_jvp(primals, tangents): - (y,), (ydot,) = primals, tangents - return x * y, x * ydot - - return g(z) # NOTE: no vmapped arg - - return jax.vmap(f)(jnp.arange(3., dtype='float32')) - - primals, tangents = jax.jvp(h, (jnp.float32(1.),), (jnp.float32(2.),)) - self.assertAllClose(primals , jnp.arange(3., dtype='float32')) - self.assertAllClose(tangents, 2 * jnp.arange(3., dtype='float32')) - - def test_float0(self): - scalar_float0 = jnp.zeros((), dtype=float0) - @jax.custom_jvp - def f(x, y): - return x, y - def f_jvp(primals, _): - x, y = primals - return (x, y), (2., custom_derivatives_public.zero_from_primal(y)) - f.defjvp(f_jvp) - - primals = (2., 3) - tangents = (np.ones(()), scalar_float0) - expected_tangents = (2., scalar_float0) - self.assertAllClose(api.jvp(f, primals, tangents), - (primals, expected_tangents)) - - def test_float0_initial_style(self): - scalar_float0 = jnp.zeros((), dtype=float0) - @jax.custom_jvp - def f(x, y): - return x, y - def f_jvp(primals, _): - x, y = primals - return (x, y), (2., custom_derivatives_public.zero_from_primal(y)) - f.defjvp(f_jvp) - - def foo(x, y): - out, _ = lax.scan(lambda c, _: (f(*c), None), (x, y), None, length=1) - return out - - primals = (2., 3) - tangents = (np.ones(()), scalar_float0) - expected_tangents = (2., scalar_float0) - - self.assertAllClose(api.jvp(foo, primals, tangents), - (primals, expected_tangents)) - - def test_remat(self): - @jax.custom_jvp - def f(x): - return jnp.sin(x) - def f_jvp(primals, tangents): - x, = primals - g, = tangents - return f(x), 2 * jnp.cos(x) * g - f.defjvp(f_jvp) - - @jax.remat - def g(x): - return f(f(x)) - - ans = g(2.) - expected = np.sin(np.sin(2.)) - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(g)(2.) - expected = 4. * api.grad(lambda x: jnp.sin(jnp.sin(x)))(2.) - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_remat_higher_order(self): - @jax.custom_jvp - def f(x): - return jnp.sin(x) - def f_jvp(primals, tangents): - x, = primals - g, = tangents - return f(x), 2 * jnp.cos(x) * g - f.defjvp(f_jvp) - - def g(x): - return f(f(x)) - - ans = api.grad(api.grad(new_checkpoint(g)))(2.) - expected = api.grad(api.grad(g))(2.) - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(new_checkpoint(api.grad(g)))(2.) - expected = api.grad(api.grad(g))(2.) - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(api.grad(api.grad(new_checkpoint(g))))(2.) - expected = api.grad(api.grad(api.grad(g)))(2.) - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_initial_style_vmap_2(self): - # This is like test_initial_style_vmap except the primal function closes - # over an array constant. - y = jnp.arange(1., 4.) - - @jax.custom_jvp - def f(x): - assert jnp.ndim(x) == 0 - return 3 * x * jnp.sum(y) - def f_jvp(primals, tangents): - x, = primals - g, = tangents - return f(x), 2 * g - f.defjvp(f_jvp) - - def foo(x): - out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1) - return out - - ans = api.grad(lambda x: api.vmap(foo)(x).sum())(jnp.ones(3)) - expected = 2. * jnp.ones(3) - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(lambda x: api.vmap(api.jit(foo))(x).sum())(jnp.ones(3)) - expected = 2. * jnp.ones(3) - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(lambda x: api.jit(api.vmap(foo))(x).sum())(jnp.ones(3)) - expected = 2. * jnp.ones(3) - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(api.jit(lambda x: api.vmap(foo)(x).sum()))(jnp.ones(3)) - expected = 2. * jnp.ones(3) - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.jit(api.grad(lambda x: api.vmap(foo)(x).sum()))(jnp.ones(3)) - expected = 2. * jnp.ones(3) - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_custom_jvp_vmap_broadcasting_interaction(self): - # https://github.com/jax-ml/jax/issues/6452 - def f2(y, z): - v1 = z - v2 = jnp.sum(y) + z - return jnp.logaddexp(v1, v2) - - def f1(y, z): - v = api.vmap(lambda _y: f2(_y, z))(y) - return jnp.sum(v) - - y = jnp.ones((3, 2)) - f = lambda z: f1(y, z) - z = 0.1 - val, g = api.value_and_grad(f)(z) - self.assertEqual(val.shape, ()) - self.assertEqual(g.shape, ()) - - def test_custom_jvp_vmap_broadcasting_interaction_2(self): - # https://github.com/jax-ml/jax/issues/5849 - @jax.custom_jvp - def transform(box, R): - if jnp.isscalar(box) or box.size == 1: - return R * box - elif box.ndim == 2: - return jnp.einsum('ij,j->i', box, R) - raise ValueError() - - @transform.defjvp - def transform_jvp(primals, tangents): - box, R = primals - dbox, dR = tangents - return (transform(box, R), dR + transform(dbox, R)) - - def periodic_general(box): - def displacement_fn(Ra, Rb, **kwargs): - _box = kwargs.get('box', box) - return transform(_box, Ra - Rb) - - return displacement_fn - - N = 250 - - scalar_box = 1.0 - displacement = periodic_general(scalar_box) - - key = jax.random.key(0) - R = jax.random.uniform(key, (N, 2)) - - def energy_fn(box): - d = partial(displacement, box=box) - d = api.vmap(api.vmap(d, (None, 0)), (0, None)) - return jnp.sum(d(R, R) ** 2) - - self.assertEqual(grad(energy_fn)(scalar_box).shape, ()) - - def test_custom_jvp_implicit_broadcasting(self): - # https://github.com/jax-ml/jax/issues/6357 - if config.enable_x64.value: - raise unittest.SkipTest("test only applies when x64 is disabled") - - @jax.custom_jvp - def projection_unit_simplex(x: jax.Array) -> jax.Array: - """Projection onto the unit simplex.""" - s = 1.0 - n_features = x.shape[0] - u = jnp.sort(x)[::-1] - cssv = jnp.cumsum(u) - s - ind = jnp.arange(n_features, dtype=x.dtype) + 1 - cond = u - cssv / ind > 0 - idx = jnp.count_nonzero(cond) - threshold = cssv[idx - 1] / idx.astype(x.dtype) - return jax.nn.relu(x - threshold) - - - @projection_unit_simplex.defjvp - def projection_unit_simplex_jvp(primals, tangents): - x, = primals - x_dot, = tangents - primal_out = projection_unit_simplex(x) - supp = (primal_out > 0).astype(x_dot.dtype) - card = jnp.count_nonzero(supp).astype(x_dot.dtype) - tangent_out = supp * x_dot - (jnp.dot(supp, x_dot) / card) * supp - return primal_out, tangent_out - - rng = self.rng() - x = rng.rand(5).astype(np.float32) - - J_rev = jax.jacrev(projection_unit_simplex)(x) - J_fwd = jax.jacfwd(projection_unit_simplex)(x) - - p = projection_unit_simplex(x) - support = (p > 0).astype(jnp.float32) - cardinality = jnp.count_nonzero(support).astype(support.dtype) - J_true = jnp.diag(support) - jnp.outer(support, support) / cardinality - self.assertAllClose(J_true, J_fwd) - self.assertAllClose(J_true, J_rev) - - proj = jax.vmap(projection_unit_simplex) - - def fun(X): - return jnp.sum(proj(X) ** 2) - - rng = self.rng() - X = rng.rand(4, 5).astype(np.float32) - U = rng.rand(4, 5) - U /= np.sqrt(np.sum(U ** 2)) - U = U.astype(np.float32) - - eps = 1e-3 - dir_deriv_num = (fun(X + eps * U) - fun(X - eps * U)) / (2 * eps) - dir_deriv = jnp.vdot(jax.grad(fun)(X), U) - self.assertAllClose(dir_deriv, dir_deriv_num, atol=1e-3) - - def test_vmap_inside_defjvp(self): - # https://github.com/jax-ml/jax/issues/3201 - seed = 47 - key = jax.random.key(seed) - mat = jax.random.normal(key, (2, 3)) - - @jax.custom_jvp - def f(mat, aux): - num_rows, num_cols = mat.shape - return jnp.ones((num_rows, 1)) / num_cols - - @f.defjvp - def f_jvp(primals, tangents): - mat, aux = primals - vec, _ = tangents - output = f(*primals) - num_rows, num_cols = mat.shape - size = num_rows * num_cols - # ----- - bd_mat = mat.reshape(1, 1, num_rows, num_cols) - bd_mat = jnp.tile(bd_mat, reps=(num_rows, num_cols)) - bd_mat = bd_mat.reshape(size, num_rows, num_cols) - # ----- - rowsum = jnp.sum(mat, axis=1, keepdims=True) - colsum = jnp.sum(mat, axis=0, keepdims=True) - bd_rowsum = jnp.tile(rowsum, reps=(1, num_rows)) - bd_colsum = jnp.tile(colsum, reps=(num_cols, 1)) - # ----- - bd_vec = vec.reshape(size, 1) - # ----- - def operate(mx, val): - buf = 0 - for i in range(2): - buf = buf + jnp.matmul(mx, bd_colsum) / jnp.power(aux, i) - buf = jnp.matmul(bd_rowsum, buf) - return buf * val[None, :] - # ----- - # Vertorizing will raise shape error - bd_buf = jax.vmap(operate, in_axes=(0, 0), out_axes=0)(bd_mat, bd_vec) - # ----- - bd_buf = bd_buf / aux - jvp = jnp.sum(bd_buf, axis=0) - jvp = jnp.mean(jvp, axis=1, keepdims=True) - # ----- - # JVP ends successfully, but still raise an error - return (output, jvp) - - jax.grad(lambda mat, aux: jnp.sum(f(mat, aux)))(mat, 0.5) # doesn't crash - - def test_custom_jvp_unbroadcasting(self): - # https://github.com/jax-ml/jax/issues/3056 - a = jnp.array([1., 1.]) - - @jax.custom_jvp - def f(x): - return a * x - - @f.defjvp - def f_jvp(primals, tangents): - x, = primals - dx, = tangents - return a * x, a * dx - - shape = grad(lambda x: jnp.sum(f(x)))(jnp.array(1.)).shape - self.assertEqual(shape, ()) - - def test_maybe_perturbed_internal_helper_function(self): - # This is a unit test for an internal API. We include it so as not to - # regress https://github.com/jax-ml/jax/issues/9567. For an explanation of - # this helper function, see https://github.com/jax-ml/jax/issues/6415. - def f(x): - def g(y, _): - z = y * x - self.assertTrue(custom_derivatives._maybe_perturbed(z)) - return y, None - g(1, None) - return lax.scan(g, 1, xs=None, length=1)[0] - - jax.jvp(f, (1.0,), (1.0,)) # assertions inside f - - def test_maybe_perturbed_int_regression(self): - # see https://github.com/jax-ml/jax/discussions/9951 - - @jax.jit - def f(): - x = jnp.array(1) - _, aux_args = custom_derivatives.closure_convert(lambda: x) - self.assertEmpty(aux_args) - f() - - def test_sinc_constant_function_batching(self): - # https://github.com/jax-ml/jax/pull/10756 - batch_data = jnp.arange(15.).reshape(5, 3) - - @jax.vmap - def f(x): - return jax.lax.map(jnp.sinc, x) - g = lambda param: f(param * batch_data).sum() - - @jax.vmap - def f_ref(x): - return jnp.stack([jnp.sinc(x_) for x_ in x]) - g_ref = lambda param: f_ref(param * batch_data).sum() - - grad = jax.grad(g )(0.1) # doesn't crash - grad_ref = jax.grad(g_ref)(0.1) - self.assertAllClose(grad, grad_ref, check_dtypes=False) - - @parameterized.named_parameters( - ('jit_vmap', True, True), - ('jit', True, False), - ('vmap', False, True), - ('', False, False), - ) - def test_symbolic_zero_custom_jvp(self, maybe_jit, maybe_vmap): - def f(static_scalar, static_array, dyn_scalar, dyn_array): - out1 = static_scalar + dyn_scalar - out2 = static_array + dyn_array - return out1, out2 - - def _pack(x): - return lax.broadcast(x, (1,)) - - def _unpack(x): - (x,) = x - return x - - def _vmap(fun): - def _fun(*args): - args = jax.tree.map(_pack, args) - out = jax.vmap(fun)(*args) - out = jax.tree.map(_unpack, out) - return out - return _fun - - f = jax.custom_jvp(f) - - @partial(f.defjvp, symbolic_zeros=True) - def f_jvp(primals, tangents): - static_scalar, *_ = primals - t_static, t_static_arr, t_dyn_scalar, t_dyn_array = tangents - self.assertIs(type(t_static) , custom_derivatives_public.SymbolicZero) - self.assertIs(type(t_static_arr), custom_derivatives_public.SymbolicZero) - self.assertEqual(t_static.shape, ()) - self.assertEqual(t_static_arr.shape, (2,)) - return f(*primals), (static_scalar + 90, t_dyn_array + 91) - - def g(dyn_scalar, dyn_array): - if maybe_vmap: - f_ = _vmap(f) - else: - f_ = f - return f_(1., jnp.array([2., 3.]), dyn_scalar, dyn_array) - - def run(primal_ins, tangent_ins): - return jax.jvp(g, primal_ins, tangent_ins) - - if maybe_jit: - run = jax.jit(run) - - primal_ins = (4., jnp.array([5., 6.])) - tangent_ins = (7., jnp.array([8., 9.])) - primal_outs, tangent_outs = run(primal_ins, tangent_ins) - primal_out1, primal_out2 = primal_outs - tangent_out1, tangent_out2 = tangent_outs - scalar_type = jax.Array if maybe_jit or maybe_vmap else float - self.assertIsInstance(primal_out1, scalar_type) - self.assertAllClose(primal_out1, 5.) - self.assertIsInstance(tangent_out1, scalar_type) - self.assertAllClose(tangent_out1, 91.) - self.assertIsInstance(primal_out2, jax.Array) - self.assertArraysAllClose(primal_out2, jnp.array([7., 9.])) - self.assertIsInstance(tangent_out2, jax.Array) - self.assertArraysAllClose(tangent_out2, jnp.array([99., 100.])) - - def test_symbolic_zero_custom_jvp_vmap_output(self): - @jax.custom_jvp - def f(x, y): - return x * y - - @partial(f.defjvp, symbolic_zeros=True) - def f_jvp(primals, tangents): - x, y = primals - x_dot, y_dot = tangents - self.assertIs(type(y_dot), custom_derivatives_public.SymbolicZero) - return f(x, y), y_dot - - jax.grad(lambda x, y: jax.vmap(f)(x, y).sum())(jnp.ones(3), jnp.ones(3)) - - def test_symbolic_zeros_memoization_caching(self): - # Tests multiple zero patterns for partial_eval._memoize, and also tests - # that we're okay with stores being occupied with equal values. - - @jax.custom_jvp - def f(x, y): - return x * y - - @partial(f.defjvp, symbolic_zeros=True) - def f_jvp(primals, tangents): - x, y = primals - x_dot, y_dot = tangents - return f(x, y), y_dot - - f_ = core.jaxpr_as_fun(jax.make_jaxpr(f)(2., 3.)) - _ = jax.linearize(f_, 2., 3.) - _ = jax.linearize(lambda x: f_(x, 3.), 2.) # don't crash! - - def test_symbolic_zeros_under_jit(self): - # https://github.com/jax-ml/jax/issues/14833 - Zero = jax.custom_derivatives.SymbolicZero - - @jax.custom_jvp - def f(x, y): - return x * y - - @partial(f.defjvp, symbolic_zeros=True) - def fjvp(primals, tangents): - x, y = primals - tx, ty = tangents - assert type(tx) is not Zero or type(ty) is not Zero - return f(x, y), ( - ty if type(tx) is Zero else - tx if type(ty) is Zero else - tx + ty) - - jax.jacfwd(jax.jit(f))(0.1, 0.2) # don't crash - - def test_custom_jvp_functools_partial(self): - def fun(x, y, a): - return x + y * a - - fun_wrapped = functools.partial(fun, a = 0.1) - - def jvp_fn(primals, tangents): - return jax.jvp(fun_wrapped, primals, tangents) - - fn = jax.custom_jvp(fun_wrapped) - fn.defjvp(jvp_fn) - - self.assertEqual((1.0, 0.1), jax.grad(lambda args: fn(*args))((1.0, 2.0))) - - def test_run_rules_more_than_once(self): - # https://github.com/jax-ml/jax/issues/16614 - - @jax.custom_jvp - def f(x, y): - return x - - @partial(f.defjvp, symbolic_zeros=True) - def f_jvp(primals, tangents): - x, _ = primals - x_dot, _ = tangents - return x, x_dot - - def body(x_y, _): - x, y = x_y - return (f(x, y), x), None - - @jax.grad - def g(x): - (out, _), _ = lax.scan(body, (x, 1.), xs=None, length=2) - return out - - g(1.) # doesn't crash - - def test_dce(self): - @jax.custom_jvp - def f(x, y): - return jnp.sin(x), x + jnp.cos(y) - - @f.defjvp - def f_jvp(primals, tangents): - x, y = primals - dx, dy = tangents - return f(x, y), (2.0 * jnp.cos(x) * dx, 1.5 * dx - 0.5 * jnp.sin(y) * dy) - - def check_jaxpr(jaxpr, used_outs, includes, excludes): - dce_jaxpr, _ = pe.dce_jaxpr(jaxpr, used_outs) - if not dce_jaxpr.eqns: - assert not includes - return - call_jaxpr = dce_jaxpr.eqns[0].params["call_jaxpr"] - for prim in includes: - assert any(eqn.primitive == prim for eqn in call_jaxpr.eqns) - for prim in excludes: - assert all(eqn.primitive != prim for eqn in call_jaxpr.eqns) - - x, y = 0.1, -1.3 - jaxpr = jax.make_jaxpr(f)(x, y).jaxpr - check_jaxpr(jaxpr, [True, True], [lax.sin_p, lax.cos_p], []) - check_jaxpr(jaxpr, [True, False], [lax.sin_p], [lax.cos_p]) - check_jaxpr(jaxpr, [False, True], [lax.cos_p], [lax.sin_p]) - check_jaxpr(jaxpr, [False, False], [], [lax.sin_p, lax.cos_p]) - - def dce_jaxpr_as_fun(jaxpr, used_outs): - jaxpr_, _ = pe.dce_jaxpr(jaxpr, used_outs) - fun = core.jaxpr_as_fun(pe.close_jaxpr(jaxpr_)) - return lambda *args: fun(*args)[0] - - f0 = dce_jaxpr_as_fun(jaxpr, [True, False]) - f1 = dce_jaxpr_as_fun(jaxpr, [False, True]) - self.assertAllClose( - api.jvp(f0, (x, y), (1.0, 0.0)), (f0(x, y), 2.0 * jnp.cos(x))) - self.assertAllClose( - api.jvp(f0, (x, y), (0.0, 1.0)), (f0(x, y), 0.0)) - self.assertAllClose( - api.jvp(f1, (x, y), (1.0, 0.0)), (f1(x, y), 1.5)) - self.assertAllClose( - api.jvp(f1, (x, y), (0.0, 1.0)), (f1(x, y), -0.5 * jnp.sin(y))) - - def test_resolve_kwargs_error_message(self): - @jax.custom_jvp - def f(x, y, *, z=None): - return jnp.sin(x), x + jnp.cos(y) - - @f.defjvp - def f_jvp(primals, tangents): - self.fail("should not be executed") - - with self.assertRaisesRegex( - TypeError, - r"The input arguments to the custom_jvp-decorated function f(.*)\n" - r"missing a required argument: 'y'" - ): - f(0.5) - - with self.assertRaisesRegex( - TypeError, - r"The input arguments to the custom_jvp-decorated function f(.*)\n" - "The following keyword arguments could not be resolved to positions: z" - ): - f(0.5, 0.1, z=1.0) - - -class CustomVJPTest(jtu.JaxTestCase): - - def test_basic(self): - @jax.custom_vjp - def f(x): - return jnp.sin(x) - def f_fwd(x): - return f(x), jnp.cos(x) - def f_rev(cos_x, g): - return (2 * cos_x * g,) - f.defvjp(f_fwd, f_rev) - - x = 3. - self.assertAllClose(f(x), jnp.sin(x)) - self.assertAllClose(api.grad(f)(x), 2 * jnp.cos(x)) - self.assertAllClose(api.value_and_grad(f)(x), - (jnp.sin(x), 2 * jnp.cos(x))) - - def test_invariance(self): - @jax.custom_vjp - def f(x): - return jnp.cos(2 * x) / 2. - def f_fwd(x): - return (f(x), x) - def f_rev(x, g): - return (g * 3,) - f.defvjp(f_fwd, f_rev) - def f2(x): - y, _ = api.value_and_grad(f)(x) - return y - def f3(x): - y, _ = api.value_and_grad(f2)(x) - return y - x = 1. - self.assertAllClose(f(x), f2(x), check_dtypes=False) - self.assertAllClose(f(x), f3(x), check_dtypes=False) - self.assertAllClose(api.grad(f)(x), api.grad(f2)(x), - check_dtypes=False) - self.assertAllClose(api.grad(f)(x), api.grad(f3)(x), - check_dtypes=False) - - def test_python_control_flow(self): - @jax.custom_vjp - def f(x): - if x > 0: - return jnp.sin(x) - else: - return jnp.cos(x) - def f_fwd(x): - if x > 0: - return f(x), x - else: - return f(x), x - def f_rev(x, g): - if x > 0: - return (2 * g,) - else: - return (3 * g,) - f.defvjp(f_fwd, f_rev) - x = 2. - self.assertAllClose(f(x), jnp.sin(x)) - self.assertAllClose(f(-x), jnp.cos(-x)) - self.assertAllClose(api.value_and_grad(f)(x), (jnp.sin(x), 2.), - check_dtypes=False) - self.assertAllClose(api.value_and_grad(f)(-x), (jnp.cos(-x), 3.), - check_dtypes=False) - - def test_vmap(self): - @jax.custom_vjp - def f(x): - assert jnp.ndim(x) == 0 - return jnp.sin(x) - def f_fwd(x): - assert jnp.ndim(x) == 0 - return f(x), jnp.cos(x) - def f_rev(cos_x, g): - return (2 * cos_x * g,) - f.defvjp(f_fwd, f_rev) - - x = jnp.arange(3.) - xx = jnp.arange(6.).reshape(2, 3) - - # vmap of f - self.assertAllClose(api.vmap(f)(x), jnp.sin(x)) - self.assertAllClose(api.vmap(api.vmap(f))(xx), jnp.sin(xx)) - - # vmap of grad of f - self.assertAllClose(api.vmap(api.grad(f))(x), 2 * jnp.cos(x)) - self.assertAllClose(api.vmap(api.value_and_grad(f))(x), - (jnp.sin(x), 2 * jnp.cos(x))) - self.assertAllClose(api.vmap(api.vmap(api.grad(f)))(xx), 2 * jnp.cos(xx)) - self.assertAllClose(api.vmap(api.vmap(api.value_and_grad(f)))(xx), - (jnp.sin(xx), 2 * jnp.cos(xx))) - - # grad of vmap of f - self.assertAllClose(api.grad(lambda x: api.vmap(f)(x).sum())(x), - 2 * jnp.cos(x)) - self.assertAllClose(api.grad(lambda x: api.vmap(api.vmap(f))(x).sum())(xx), - 2 * jnp.cos(xx)) - - # vmap of grad of vmap of f - self.assertAllClose(api.vmap(api.grad(lambda x: api.vmap(f)(x).sum()))(xx), - 2 * jnp.cos(xx)) - - def test_jit(self): - @jax.custom_vjp - def f(x): - return jnp.sin(x) - def f_fwd(x): - return f(x), jnp.cos(x) - def f_rev(cos_x, g): - return (2 * cos_x * g,) - f.defvjp(f_fwd, f_rev) - - x = 3. - - # jit - self.assertAllClose(api.jit(f)(x), jnp.sin(x)) - self.assertAllClose(api.jit(api.jit(f))(x), jnp.sin(x)) - - # jit of grad - self.assertAllClose(api.jit(api.grad(f))(x), 2 * jnp.cos(x), - check_dtypes=False) - - # grad of jit - self.assertAllClose(api.grad(api.jit(f))(x), 2 * jnp.cos(x), - check_dtypes=False) - - def test_pytrees(self): - @jax.custom_vjp - def f(x): - return {'b': jnp.sin(x['a'])} - def f_fwd(x): - return f(x), {'r': jnp.cos(x['a'])} - def f_bwd(res, g): - cos_x = res['r'] - return ({'a': 2 * cos_x * g['b']},) - f.defvjp(f_fwd, f_bwd) - x = {'a': 3.} - self.assertAllClose(f(x)['b'], jnp.sin(x['a'])) - self.assertAllClose(api.grad(lambda x: f(x)['b'])(x), - {'a': 2 * jnp.cos(x['a'])}) - - def test_jvp_error(self): - @jax.custom_vjp - def f(x): - return jnp.sin(x) - def f_fwd(x): - return f(x), jnp.cos(x) - def f_rev(cos_x, g): - return (2 * cos_x * g,) - f.defvjp(f_fwd, f_rev) - - self.assertRaisesRegex( - TypeError, - r"can't apply forward-mode autodiff \(jvp\) to a custom_vjp function.", - lambda: api.jvp(f, (3.,), (1.,))) - self.assertRaisesRegex( - TypeError, - r"can't apply forward-mode autodiff \(jvp\) to a custom_vjp function.", - lambda: api.jvp(api.vmap(f), (jnp.arange(3.),), (jnp.ones(3),))) - self.assertRaisesRegex( - TypeError, - r"can't apply forward-mode autodiff \(jvp\) to a custom_vjp function.", - lambda: api.jvp(jit(f), (3.,), (1.,))) - - def test_kwargs(self): - # from https://github.com/jax-ml/jax/issues/1938 - @jax.custom_vjp - def my_fun(x, y, c=1.): - return c * (x + y) - my_fun.defvjp(lambda x, y, c=1.: (my_fun(c, y, c), None), - lambda _, g: (g, g, g)) - f = lambda x, y: jnp.square(my_fun(x, y, c=2.)).sum() - f(10., 5.) # doesn't crash - api.grad(f)(10., 5.) # doesn't crash - - def test_initial_style(self): - @jax.custom_vjp - def f(x): - return jnp.sin(x) - def f_fwd(x): - return f(x), jnp.cos(x) - def f_rev(cos_x, g): - return (2 * cos_x * g,) - f.defvjp(f_fwd, f_rev) - - def foo(x): - out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1) - return out - - ans = api.grad(foo)(3.) - expected = 2. * jnp.cos(3.) - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(api.grad(foo))(3.) - expected = -2. * jnp.sin(3.) - self.assertAllClose(ans, expected) - - def test_initial_style_vmap(self): - @jax.custom_vjp - def f(x): - assert jnp.ndim(x) == 0 - return 3 * x - def f_fwd(x): - return f(x), jnp.cos(x) - def f_rev(cos_x, g): - return (2 * cos_x * g,) - f.defvjp(f_fwd, f_rev) - - def foo(x): - out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1) - return out - - ans = api.vmap(foo)(jnp.arange(3.)) - expected = 3. * jnp.arange(3.) - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(lambda x: api.vmap(foo)(x).sum())(jnp.arange(3.)) - expected = 2. * jnp.cos(jnp.arange(3.)) - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_nondiff_arg(self): - @partial(jax.custom_vjp, nondiff_argnums=(0,)) - def app(f, x): - return f(x) - def app_fwd(f, x): - return app(f, x), jnp.cos(x) - def app_rev(f, cos_x, g): - return (cos_x * g,) - app.defvjp(app_fwd, app_rev) - - ans = app(lambda x: 2 * x, 1) - expected = 2 - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.value_and_grad(lambda x: app(lambda y: 2 * y, x))(1.) - expected = (2., jnp.cos(1.)) - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_closed_over_jit_tracer(self): - # See the comment in CustomJVPTest.test_nondiff_arg_jit_tracer. - raise unittest.SkipTest("behavior no longer supported") - - # This test is similar to test_nondiff_arg_tracer except it uses lexical - # closure rather than the nondiff_argnums mechanism. We decided to disallow - # tracers in nondiff_argnums to greatly simplify bookkeeping while still - # supporting the cases for which it is necessary. - def outer(x): - @jax.custom_vjp - def f(y): - return x * y - def f_fwd(y): - return f(y), jnp.cos(y) - def f_rev(cos_y, g): - return (cos_y * g,) - f.defvjp(f_fwd, f_rev) - return f - - @jit - def g(x, y): - return outer(x)(y) - - ans = g(2, 3.) - expected = 6. - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(g, 1)(2., 3.) - expected = jnp.cos(3.) - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_closed_over_vmap_tracer(self): - def outer(x): - @jax.custom_vjp - def f(y): - return x * y - def f_fwd(y): - return f(y), jnp.cos(y) - def f_rev(cos_y, g): - return (cos_y * g,) - f.defvjp(f_fwd, f_rev) - return f - - @api.vmap - def g(x): - return outer(x)(3.) - - ans = g(np.arange(3.)) - expected = np.arange(3.) * 3 - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_closed_over_tracer3(self): - def outer(x): - @jax.custom_vjp - def f(y): - return x * y - def f_fwd(y): - return f(y), (x, jnp.cos(y)) - def f_rev(res, g): - x, cos_y = res - return (cos_y * g * x,) - f.defvjp(f_fwd, f_rev) - return api.grad(f) - - @api.vmap - def g(x): - return outer(x)(3.) - - ans = g(np.arange(3.)) - expected = np.cos(3.) * np.arange(3.) - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_nondiff_arg_tracer_error(self): - # This is similar to the old (now skipped) test_nondiff_arg_tracer, except - # we're testing for the error message that usage pattern now raises. - - @partial(jax.custom_vjp, nondiff_argnums=(0,)) - def f(x, y): - return x * y - def f_fwd(x, y): - return f(x, y), jnp.cos(y) - def f_rev(x, cos_y, g): - return (cos_y * g,) - f.defvjp(f_fwd, f_rev) - - @jit - def g(x, y): - return f(x, y) - - with self.assertRaisesRegex(UnexpectedTracerError, "custom_vjp"): - _ = g(2, 3.) - with self.assertRaisesRegex(UnexpectedTracerError, "custom_vjp"): - _ = api.grad(g, 1)(2., 3.) - - def test_vmap_axes(self): - raise unittest.SkipTest("TODO") # TODO(mattjj): write test - - def test_pmap(self): - raise unittest.SkipTest("TODO") # TODO(mattjj): write test - - def test_missing_vjp_rule_error(self): - @jax.custom_vjp - def foo(x): - return x ** 2 - - self.assertRaisesRegex( - AttributeError, - r"No VJP defined for custom_vjp function foo using defvjp.", - lambda: foo(2)) - self.assertRaisesRegex( - AttributeError, - r"No VJP defined for custom_vjp function foo using defvjp.", - lambda: api.grad(foo)(2.)) - - def test_vjp_rule_inconsistent_pytree_structures_error(self): - @jax.custom_vjp - def f(x): - return x - - def foo_fwd(x): - return x, None - - def foo_bwd(_, g): - return (g, g) - - f.defvjp(foo_fwd, foo_bwd) - - f(2) # doesn't crash - self.assertRaisesRegex( - TypeError, - re.escape( - "Custom VJP bwd rule must produce an output with the same container " - "(pytree) structure as the args tuple of the primal function, " - "and in particular must produce a tuple of length equal to the " - "number of arguments to the primal function, but got bwd output " - "structure {} for primal input structure {}.".format( - jax.tree.structure((1, 1)), - jax.tree.structure((1,))) - ), - lambda: api.grad(f)(2.)) - - def test_vjp_bwd_returns_non_tuple_error(self): - @jax.custom_vjp - def f(x): - return x - - def foo_fwd(x): - return x, None - - def foo_bwd(_, g): - return 2. * g # Should be a tuple - - f.defvjp(foo_fwd, foo_bwd) - with self.assertRaisesRegex(TypeError, "Custom VJP bwd rule .* must produce a tuple"): - api.grad(f)(3.) - - def test_fwd_rule_primal_out_type_doesnt_match_primal_error_message(self): - # https://github.com/lucidrains/flash-attention-jax/issues/7 - - def scan_apply(f, x): - y, _ = jax.lax.scan(lambda x, _: (f(x), None), x, None, length=1) - return y - - @jax.custom_vjp - def f(x): - return x - - def f_fwd(x): - return (x, x), None - - def f_bwd(_, y_bar): - return (y_bar,) - - f.defvjp(f_fwd, f_bwd) - - self.assertRaisesRegex( - TypeError, - re.escape( - "Custom VJP fwd rule f_fwd for function f must produce a pair " - "(list or tuple of length two) where the first element represents " - "the primal output (equal to the output of the " - "custom_vjp-decorated function f) and the second element " - "represents residuals (i.e. values stored from the forward " - "pass for use on the backward pass), but instead the fwd rule " - "output's first element had container/pytree structure:\n" - " (float32[], float32[])\n" - "while the custom_vjp-decorated function f had output " - "container/pytree structure:\n" - " float32[]." - ), - lambda: jax.grad(lambda x: scan_apply(f, x))(jnp.float32(1.))) - - def f_fwd2(x): - return jnp.zeros((3, *x.shape), x.dtype), None - - def f_bwd2(_, y_bar): - return (y_bar,) - - f.defvjp(f_fwd2, f_bwd2) - - self.assertRaisesRegex( - TypeError, - re.escape( - "Custom VJP fwd rule f_fwd2 for function f must produce a pair " - "(list or tuple of length two) where the first element represents " - "the primal output (equal to the output of the " - "custom_vjp-decorated function f) and the second element " - "represents residuals (i.e. values stored from the forward " - "pass for use on the backward pass), but instead the fwd rule " - "output's first element had shapes/dtypes of:\n" - " float32[3]\n" - "while the custom_vjp-decorated function f had output " - "shapes/dtypes of:\n" - " float32[]" - ), - lambda: jax.grad(lambda x: scan_apply(f, x))(jnp.float32(1.))) - - def test_issue2511(self): - arr = jnp.ones((5, 2, 2)) - foo = lambda x: api.vmap(jnp.linalg.det, (0,))(x) - api.jit(foo)(arr) # doesn't crash - - def test_lowering_out_of_traces(self): - # https://github.com/jax-ml/jax/issues/2578 - - class F(collections.namedtuple("F", ["a"])): - def __call__(self, x): - return jax.nn.relu(self.a) * x - - @jax.jit - def g(f, x): - return f(x) - - jax.grad(g, argnums=(1,))(F(2.0), 0.) # doesn't crash - - def test_clip_gradient(self): - # https://github.com/jax-ml/jax/issues/2784 - @jax.custom_vjp - def _clip_gradient(lo, hi, x): - return x # identity function when not differentiating - - def clip_gradient_fwd(lo, hi, x): - return x, (lo, hi,) - - def clip_gradient_bwd(res, g): - lo, hi = res - return (None, None, jnp.clip(g, lo, hi),) - - _clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd) - - def clip_gradient(x): - lo = -0.1 - hi = x + 0.1 - return _clip_gradient(lo, hi, x) - - g = jax.grad(clip_gradient)(0.1) # doesn't crash - self.assertAllClose(g, jnp.array(0.2)) - - def test_nestable_vjp(self): - # Verify that https://github.com/jax-ml/jax/issues/3667 is resolved. - def f(x): - return x ** 2 - - @jax.custom_vjp - def g(x): - return f(x) - - def g_fwd(x): - y, f_vjp = api.vjp(f, x) - return y, f_vjp - - def g_bwd(f_vjp, y_bar): - return f_vjp(y_bar) - - g.defvjp(g_fwd, g_bwd) - - # Check that VJP can be nested in simple situations. For this to pass, - # vjp has to return a PyTree. - _, g_vjp = api.vjp(g, 1.0) - y, = g_vjp(1.0) - self.assertAllClose(y, jnp.array(2.0)) - - # Check that VJP can be nested in complex situations. For this to pass, - # vjp can't treat the closed-over tracer x as a static argument. - @jit - def z(x): - _, g_vjp = api.vjp(g, x) - return g_vjp - y, = z(1.0)(3.0) - self.assertAllClose(y, jnp.array(6.0)) - - def test_initial_style_vmap_2(self): - # https://github.com/jax-ml/jax/issues/4173 - x = jnp.ones((10, 3)) - - # Create the custom function - @jax.custom_vjp - def custom_fun(x): - return x.sum() - - def forward(x): - return x.sum(), (jnp.ones_like(x),) - - def backward(res, g): - return g * res[0], - - custom_fun.defvjp(forward, backward) - - def train_fun(x): - - def summed_fun(x): - return api.vmap(custom_fun)(x).sum() - - return api.grad(summed_fun)(x) - - def scan_body(carry, inputs): - x = carry - return carry, train_fun(x) - - scan_range = jnp.arange(4) - lax.scan(scan_body, x, scan_range) # don't crash - - def test_initial_style_vmap_3(self): - # This is like test_initial_style_vmap except the primal function closes - # over an array constant. - y = jnp.arange(1., 4.) - - @jax.custom_vjp - def f(x): - assert jnp.ndim(x) == 0 - return 3 * x * jnp.sum(y) - def f_fwd(x): - return f(x), jnp.cos(x) - def f_rev(cos_x, g): - return (2 * cos_x * g,) - f.defvjp(f_fwd, f_rev) - - def foo(x): - out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1) - return out - - ans = api.vmap(foo)(jnp.arange(3.)) - expected = 3. * jnp.arange(3.) * 6 - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(lambda x: api.vmap(foo)(x).sum())(jnp.arange(3.)) - expected = 2. * jnp.cos(jnp.arange(3.)) - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_initial_style_vmap_with_collective(self): - - @jax.custom_vjp - def f(x): - return lax.psum(x, 'foo') - - def f_fwd(x): - return lax.psum(x, 'foo'), None - - def f_bwd(res, dx): - return dx - f.defvjp(f_fwd, f_bwd) - - def g(x): - jaxpr = api.make_jaxpr(f)(x) - return core.eval_jaxpr(jaxpr.jaxpr, [], x)[0] - - out = api.vmap(lambda _, x: g(x), axis_name='foo', in_axes=(0, None), - out_axes=None)(jnp.arange(4.), 2.) - self.assertAllClose(out, 8.) - - def test_bwd_closes_over_tracer(self): - def f(y): - @jax.custom_vjp - def f(x): - return 2. * jnp.sin(x) - - def fwd(x): - return f(x), () - - def bwd(_, g): - return (2. * jnp.cos(y) * g,) # capture! - - f.defvjp(fwd, bwd) - - return jax.grad(f)(1.) - - ans = jax.jit(f)(2.) - self.assertAllClose(ans, 2. * jnp.cos(2.)) - - ans = jax.vmap(f)(jnp.arange(3.)) - self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.))) - - ans = jax.jit(jax.vmap(f))(jnp.arange(3.)) - self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.))) - - ans = jax.vmap(jax.jit(f))(jnp.arange(3.)) - self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.))) - - ans = jax.grad(f)(4.) - self.assertAllClose(ans, -2. * jnp.sin(4.)) - - def test_fwd_closes_over_tracer(self): - def f(y): - @jax.custom_vjp - def f(x): - return 2. * jnp.sin(x) - - def fwd(x): - return f(x), y - - def bwd(y, g): - return (2. * jnp.cos(y) * g,) # capture! - - f.defvjp(fwd, bwd) - - return jax.grad(f)(1.) - - ans = jax.jit(f)(2.) - self.assertAllClose(ans, 2. * jnp.cos(2.)) - - ans = jax.vmap(f)(jnp.arange(3.)) - self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.))) - - ans = jax.jit(jax.vmap(f))(jnp.arange(3.)) - self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.))) - - ans = jax.vmap(jax.jit(f))(jnp.arange(3.)) - self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.))) - - ans = jax.grad(f)(4.) - self.assertAllClose(ans, -2. * jnp.sin(4.)) - - def test_float0(self): - @jax.custom_vjp - def f(x, _): - return x - def f_fwd(x, _): - # we need a defined (non-float0) tangent to trigger the rule - return x, (2., 1) - def f_rev(*_): - return (2., 1) - f.defvjp(f_fwd, f_rev) - - x = 2. - y = 3 - self.assertEqual(api.grad(f, allow_int=True, argnums=(0, 1))(x, y), - (2., np.zeros(shape=(), dtype=float0))) - - def test_float0_initial_style(self): - @jax.custom_vjp - def f(x): - return x - def f_fwd(x): - return x, (2., x) - def f_rev(*_): - return ((2., jnp.zeros(shape=(), dtype=float0)),) - f.defvjp(f_fwd, f_rev) - - def foo(x, y): - out, _ = lax.scan(lambda c, _: (f(c), None), (x, y), None, length=1) - return out[0] - - x = 2. - y = 3 - self.assertEqual(api.grad(foo, allow_int=True, argnums=(0, 1))(x, y), - (2., np.zeros(shape=(), dtype=float0))) - - def test_remat(self): - @jax.custom_vjp - def f(x): - return jnp.sin(x) - def f_fwd(x): - return f(x), jnp.cos(x) - def f_rev(cos_x, g): - return (2 * cos_x * g,) - f.defvjp(f_fwd, f_rev) - - @jax.remat - def g(x): - return f(f(x)) - - ans = g(2.) - expected = np.sin(np.sin(2.)) - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(g)(2.) - expected = 4. * api.grad(lambda x: jnp.sin(jnp.sin(x)))(2.) - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_remat_higher_order(self): - @jax.custom_vjp - def f(x): - return jnp.sin(x) - def f_fwd(x): - return f(x), jnp.cos(x) - def f_rev(cos_x, g): - return (2 * cos_x * g,) - f.defvjp(f_fwd, f_rev) - - def g(x): - return f(f(x)) - - ans = api.grad(api.grad(jax.remat(g)))(2.) - expected = api.grad(api.grad(g))(2.) - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(jax.remat(api.grad(g)))(2.) - expected = api.grad(api.grad(g))(2.) - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(api.grad(api.grad(jax.remat(g))))(2.) - expected = api.grad(api.grad(api.grad(g)))(2.) - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_bwd_nones(self): - @jax.custom_vjp - def f(x, y): - return x * jnp.sin(y) - def f_fwd(x, y): - return f(x, y), jnp.cos(y) - def f_rev(cos, g): - return (None, 2 * cos * g) - f.defvjp(f_fwd, f_rev) - - ans = api.grad(lambda x: f(x, x))(3.) - expected = 2 * jnp.cos(3.) - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_bwd_nones_vmap(self): - @jax.custom_vjp - def f(x, y): - return x * jnp.sin(y) - def f_fwd(x, y): - return f(x, y), jnp.cos(y) - def f_rev(cos, g): - return (None, 2 * cos * g) - f.defvjp(f_fwd, f_rev) - - ans = api.grad(lambda x: api.vmap(f)(x, x).sum())(jnp.arange(3.)) - expected = 2 * jnp.cos(jnp.arange(3.)) - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_bwd_nones_pytree(self): - @jax.custom_vjp - def f(xs, y): - x1, x2 = xs - return x1 * x2 * jnp.sin(y) - def f_fwd(xs, y): - return f(xs, y), jnp.cos(y) - def f_rev(cos, g): - return (None, 2 * cos * g) - f.defvjp(f_fwd, f_rev) - - ans = api.grad(lambda x: f((x, x), x))(3.) - expected = 2 * jnp.cos(3.) - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_custom_vjp_closure_4521(self): - # https://github.com/jax-ml/jax/issues/4521 - @jax.custom_vjp - def g(x, y): - return None - def g_fwd(x, y): - return None, y - def g_bwd(residuals, z_bar): - assert False - - g.defvjp(g_fwd, g_bwd) - - def f(xs, y): - v_g = api.vmap(g, in_axes=(0, None), out_axes=None) - v_g(xs, y) - - def scan_body(xs, _): - y = jnp.zeros(1) - _, vjp_f = api.vjp(f, xs, y) - vjp_f(None) - return xs, None - - lax.scan(scan_body, jnp.ones(5), None, 100) # doesn't crash - - def test_float0_bwd_none(self): - @jax.custom_vjp - def f(i, x): - return jnp.sin(x) - def f_fwd(i, x): - return f(i, x), jnp.cos(x) - def f_rev(cos_x, g): - return (None, 2 * cos_x * g) - f.defvjp(f_fwd, f_rev) - - ans = api.grad(f, 1)(jnp.array([1, 2]), 3.) # doesn't crash - expected = 2 * jnp.cos(3.) - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_custom_gradient(self): - @jax.custom_gradient - def f(x): - return x ** 2, lambda g: (g * x,) - - self.assertAllClose(f(3.), 9., check_dtypes=False) - self.assertAllClose(api.grad(f)(3.), 3., check_dtypes=False) - self.assertAllClose(api.grad(api.grad(f))(3.), 1., check_dtypes=False) - - def test_custom_gradient_2(self): - @jax.custom_gradient - def f(x, y): - return x * y, lambda g: (y, x) - - self.assertAllClose(f(3., 4.), 12., check_dtypes=False) - self.assertAllClose(api.grad(f, argnums=(0, 1))(3., 4.), (4., 3.), - check_dtypes=False) - - def test_custom_gradient_3(self): - @jax.custom_gradient - def f(x): - vjp = lambda g: (jnp.cos(x) * jnp.arange(3., 6.),) - return jnp.sum(jnp.sin(x)), vjp - - self.assertAllClose(f(jnp.arange(3)), jnp.sum(jnp.sin(jnp.arange(3.))), - check_dtypes=False) - self.assertAllClose( - api.grad(f)(jnp.arange(3.)), - api.grad(lambda x: jnp.sum(jnp.sin(x)))(jnp.arange(3.)) * jnp.arange(3., 6.), - check_dtypes=False) - - def test_custom_gradient_can_return_singleton_value_in_vjp(self): - @jax.custom_gradient - def f(x): - return x ** 2, lambda g: g * x - - self.assertAllClose(f(3.), 9., check_dtypes=False) - self.assertAllClose(api.grad(f)(3.), 3., check_dtypes=False) - self.assertAllClose(api.grad(api.grad(f))(3.), 1., check_dtypes=False) - - def test_closure_convert(self): - def cos_after(fn, x): - converted_fn, aux_args = jax.closure_convert(fn, x) - self.assertLessEqual(len(aux_args), 1) - return _cos_after(converted_fn, x, *aux_args) - - @partial(jax.custom_vjp, nondiff_argnums=(0,)) - def _cos_after(fn, x, *args): - return jnp.cos(fn(x, *args)) - - def fwd(fn, x, *args): - y = _cos_after(fn, x, *args) - return y, (x, args) - - def rev(fn, res, g): - x, args = res - x_bar = 17. * x - args_bars = [42. * a for a in args] - return (x_bar, *args_bars) - - _cos_after.defvjp(fwd, rev) - - def dist(c, x): - return jnp.sum((x - c) ** 2.) - - def solve(c, x): - def closure(x): - return dist(c, x) - return cos_after(closure, x) - - c, x = 2. * jnp.ones(2), jnp.ones(2) - expected = jnp.cos(dist(c, x)) - self.assertAllClose(solve(c, x), expected, check_dtypes=False) - g_c, g_x = api.grad(solve, argnums=(0, 1))(c, x) - self.assertAllClose(g_c, 42. * c, check_dtypes=False) - self.assertAllClose(g_x, 17. * x, check_dtypes=False) - - def test_closure_convert_mixed_consts(self): - # Like test_closure_convert, but close over values that - # participate in AD as well as values that do not. - # See https://github.com/jax-ml/jax/issues/6415 - - def cos_after(fn, x): - converted_fn, aux_args = jax.closure_convert(fn, x) - self.assertLessEqual(len(aux_args), 1) - return _cos_after(converted_fn, x, *aux_args) - - @partial(jax.custom_vjp, nondiff_argnums=(0,)) - def _cos_after(fn, x, *args): - return jnp.cos(fn(x, *args)) - - def fwd(fn, x, *args): - y = _cos_after(fn, x, *args) - return y, (x, args) - - def rev(fn, res, g): - x, args = res - x_bar = 17. * x - args_bars = [42. * a for a in args] - return (x_bar, *args_bars) - - _cos_after.defvjp(fwd, rev) - - def dist(c, s, x): - return jnp.sum(s * (x - c) ** 2.) - - def solve(c, s, x): - def closure(x): - return dist(c, s, x) - return cos_after(closure, x) - - c, s, x = 2. * jnp.ones(2), 3. * jnp.ones(2), jnp.ones(2) - expected = jnp.cos(dist(c, s, x)) - self.assertAllClose(solve(c, s, x), expected, check_dtypes=False) - g_c, g_x = api.grad(solve, argnums=(0, 2))(c, s, x) - self.assertAllClose(g_c, 42. * c, check_dtypes=False) - self.assertAllClose(g_x, 17. * x, check_dtypes=False) - - def test_closure_convert_pytree_mismatch(self): - # See https://github.com/jax-ml/jax/issues/23588 - def f(x, z): - return z * x - - x, z = 2.0, 3.0 - _, vjp = api.vjp(f, x, z) - vjp_pure, vjp_aux_args = jax.closure_convert(vjp, x) - vjp_pure(x, *vjp_aux_args) - with self.assertRaisesRegex( - TypeError, "The inputs to the closure produced by closure_convert"): - vjp_pure(x, vjp_aux_args) - - def test_float0_cotangents_automatically_handled(self): - @jax.custom_vjp - def f(x, y): - return x - - def f_fwd(x, y): - return x, None - - def f_bwd(_, zbar): - return (0., 1) - - f.defvjp(f_fwd, f_bwd) - - jax.jit(lambda x: jax.vjp(f, 0., x)[1](1.))(1) # doesn't crash - - def test_custom_vjp_scan_batching_edge_case(self): - # https://github.com/jax-ml/jax/issues/5832 - @jax.custom_vjp - def mul(x, coeff): return x * coeff - def mul_fwd(x, coeff): return mul(x, coeff), (x, coeff) - def mul_bwd(res, g): - x, coeff = res - g_x = g * coeff - g_coeff = (x * g).sum() - return g_x, g_coeff - mul.defvjp(mul_fwd, mul_bwd) - - def scan_over_mul(x, coeff): - def f_(x, t): - return mul(x, coeff), None - y, _ = jax.lax.scan(f_, x, jnp.arange(3)) - return y - - key = jax.random.key(0) - key1, key2 = jax.random.split(key, 2) - x_batch = jax.random.normal(key1, (3, 2)) - covector_batch = jax.random.normal(key2, (3, 2)) - coeff = jnp.array(1., dtype=x_batch.dtype) - - batched_scan_over_mul = jax.vmap(scan_over_mul, in_axes=(0, None), out_axes=0) - res, vjp_fun = jax.vjp(batched_scan_over_mul, x_batch, coeff) - vjp_fun(covector_batch) # doesn't crash - - jtu.check_grads(batched_scan_over_mul, (x_batch, coeff), order=2, - modes=['rev']) - - def test_closure_with_vmap2(self): - # https://github.com/jax-ml/jax/issues/8783 - def h(z): - def f(x): - @jax.custom_vjp - def g(y): - return x * y - - def g_fwd(y): - return x * y, (x, x * y, y) - def g_rev(res, w_bar): - x, *_ = res - return (x * w_bar,) - g.defvjp(g_fwd, g_rev) - - return g(z) - - return jax.vmap(f)(jnp.arange(3., dtype='float32')).sum() - - jtu.check_grads(h, (jnp.float32(3.14),), order=1, modes=['rev']) - - def test_pytrees_not_required_to_contain_nones(self): - class A(list): - pass - - def unflatten(_, children): - assert children[0] is not None - return A(children) - - tree_util.register_pytree_node(A, lambda x: (x, None), unflatten) - - @jax.custom_vjp - def f(x): - return x[0] - def f_fwd(x): - return x[0], None - def f_bwd(_, g): - return A([g]), - f.defvjp(f_fwd, f_bwd) - - jax.grad(f)(A([1.])) # doesn't crash - - def test_vmap_vjp_called_twice(self): - # https://github.com/jax-ml/jax/pull/14728 - @jax.custom_vjp - def f(x): - return x - f.defvjp(lambda x: (x, None), lambda _, y_bar: (y_bar,)) - - _, f_vjp = jax.vjp(jax.vmap(f), jnp.array([3.])) - f_vjp(jnp.array([3.])) - f_vjp(jnp.array([3.])) # doesn't crash - - def test_symbolic_zero_custom_vjp_basic(self): - ZERO = custom_derivatives_public.SymbolicZero - - @jax.custom_vjp - def f(x, y, z): - return x, x - - def fwd(x, y, z): - self.assertIsInstance(x, jax.custom_derivatives.CustomVJPPrimal) - self.assertIsInstance(y, jax.custom_derivatives.CustomVJPPrimal) - self.assertIsInstance(z, jax.custom_derivatives.CustomVJPPrimal) - self.assertTrue(x.perturbed) - self.assertFalse(y.perturbed) - self.assertFalse(z.perturbed) - return (x.value, x.value), None - - def fwd_all(x, y, z): - self.assertIsInstance(x, jax.custom_derivatives.CustomVJPPrimal) - self.assertIsInstance(y, jax.custom_derivatives.CustomVJPPrimal) - self.assertIsInstance(z, jax.custom_derivatives.CustomVJPPrimal) - self.assertTrue(x.perturbed) - self.assertTrue(y.perturbed) - self.assertTrue(z.perturbed) - return (x.value, x.value), None - - def bwd_all(_, g): - x1, x2 = g - self.assertFalse(type(x1) is ZERO) - self.assertFalse(type(x2) is ZERO) - return x1, x1, x2 - - def bwd_fst(_, g): - x1, x2 = g - self.assertFalse(type(x1) is ZERO) - self.assertIs(type(x2), ZERO) - return x1, x1, x2 - - def bwd_snd(_, g): - x1, x2 = g - self.assertIs(type(x1), ZERO) - self.assertFalse(type(x2) is ZERO) - return x1, x1, x2 - - x, y, z = 4., 5., 6. - i = np.array(7, np.int32) - zero = np.array(0.) - - f.defvjp(fwd, bwd_all, symbolic_zeros=True) - h = jax.jit(f) - jax.jacrev(h)(x, y, z) - jax.jacrev(lambda x: h(x, y, z))(x) - jax.jacrev(h, argnums=(0, 1, 2), allow_int=True)(x, i, i) - - f.defvjp(fwd_all, bwd_fst, symbolic_zeros=True) - fst_f = lambda *xs: f(*xs)[0] - _, vjp = jax.vjp(fst_f, x, y, z) - _, _, gz = vjp(x) - self.assertArraysAllClose(gz, zero) - - f.defvjp(fwd_all, bwd_snd, symbolic_zeros=True) - snd_f = lambda *xs: f(*xs)[1] - _, vjp = jax.vjp(snd_f, x, y, z) - gx, gy, _ = vjp(x) - self.assertArraysAllClose(gx, zero) - self.assertArraysAllClose(gy, zero) - - f.defvjp(fwd, bwd_snd, symbolic_zeros=True) - _, vjp = jax.vjp(lambda x: snd_f(x, y, z), x) - gx, = vjp(x) - self.assertArraysAllClose(gx, zero) - - def test_symbolic_zero_custom_vjp_bwd_shape_error(self): - @jax.custom_vjp - def f(x, y, z): - return x, y, z - - def fwd(x, y, z): - return f(x.value, y.value, z.value), None - - def bwd(_, gs): - x_bar, y_bar, z_bar = gs - return y_bar, x_bar, z_bar # swapped! - - f.defvjp(fwd, bwd, symbolic_zeros=True) - - with self.assertRaisesRegex( - ValueError, - r'Consider just returning a None here'): - jax.grad(lambda x, y, z: f(x, y, z)[2].sum())( - jnp.ones(1), jnp.ones(2), jnp.ones(3)) - - @parameterized.named_parameters( - ('jit_vmap', True, True), - ('jit', True, False), - ('vmap', False, True), - ('', False, False), - ) - def test_symbolic_zero_custom_vjp(self, maybe_jit, maybe_vmap): - # below: - # * static_scalar will be static in and out - # * static_array will be static in, but dynamic out - # * dyn_scalar and dyn_array will be dynamic in and out - - ZERO = custom_derivatives_public.SymbolicZero - - def f(static_scalar, static_array, dyn_scalar, dyn_array): - out1 = static_scalar + dyn_scalar - out2 = static_array + dyn_array - return static_scalar, static_array, out1, out2 - - def _pack(x): - return lax.broadcast(x, (1,)) - - def _unpack(x): - (x,) = x - return x - - def _vmap(fun): - def _fun(*args): - args = jax.tree.map(_pack, args) - out = jax.vmap(fun)(*args) - out = jax.tree.map(_unpack, out) - return out - return _fun - - f = jax.custom_vjp(f) - - def fwd(*args): - xs, pert = [x.value for x in args], [x.perturbed for x in args] - self.assertFalse(pert[0]) - self.assertFalse(pert[1]) - self.assertTrue(pert[2]) - self.assertTrue(pert[3]) - return f(*xs), xs - - def bwd(res, g): - static_scalar, *_ = res - t_static, t_static_arr, t_dyn_scalar, t_dyn_array = g - self.assertIs(type(t_static), ZERO) - self.assertFalse(type(t_static_arr) is ZERO) - self.assertFalse(type(t_dyn_scalar) is ZERO) - self.assertFalse(type(t_dyn_array) is ZERO) - self.assertEqual(t_static.shape, ()) - self.assertEqual(t_static_arr.shape, (2,)) - return (static_scalar + 90, - t_static_arr + 91, - t_dyn_scalar + 92, - t_dyn_array + 93) - - f.defvjp(fwd, bwd, symbolic_zeros=True) - - def g(dyn_scalar, dyn_array): - if maybe_vmap: - f_ = _vmap(f) - else: - f_ = f - outs = f_(1., jnp.array([2., 3.]), dyn_scalar, dyn_array) - return outs[1:] - - def run(primal_ins, cotangent_outs): - primal_outs, vjp = jax.vjp(g, *primal_ins) - cotangent_ins = vjp(cotangent_outs) - return primal_outs, cotangent_ins - - if maybe_jit: - run = jax.jit(run) - - scalar_type = jax.Array if maybe_jit or maybe_vmap else float - primal_ins = (4., jnp.array([5., 6.])) - cotangent_outs = (jnp.array([10., 11.]), 7., jnp.array([8., 9.])) - primal_outs, cotangent_ins = run(primal_ins, cotangent_outs) - - primal_out1, primal_out2, primal_out3 = primal_outs - self.assertIsInstance(primal_out1, jax.Array) - self.assertAllClose(primal_out1, jnp.array([2., 3.])) - self.assertIsInstance(primal_out2, scalar_type) - self.assertAllClose(primal_out2, 5.) - self.assertIsInstance(primal_out3, jax.Array) - self.assertAllClose(primal_out3, jnp.array([7., 9.])) - - ct_in1, ct_in2 = cotangent_ins - self.assertIsInstance(ct_in1, scalar_type) - self.assertAllClose(ct_in1, 99.) - self.assertIsInstance(ct_in2, jax.Array) - self.assertArraysAllClose(ct_in2, jnp.array([101., 102.])) - - def test_symbolic_zero_custom_vjp_vmap_output(self): - @jax.custom_vjp - def f(x, y): - return x, y - - def fwd(x, y): - self.assertTrue(x.perturbed) - self.assertFalse(y.perturbed) - return f(x.value, y.value), None - - def bwd(_, g): - _, ct_y = g - self.assertIs(type(ct_y), custom_derivatives_public.SymbolicZero) - return g - - f.defvjp(fwd, bwd, symbolic_zeros=True) - jax.grad(lambda x, y: jax.vmap(f)(x, y)[0].sum())(jnp.ones(3), jnp.ones(3)) - - def test_symbolic_zero_custom_vjp_custom_pytree(self): - tree_values = custom_derivatives_public.custom_vjp_primal_tree_values - - @tree_util.register_pytree_node_class - class Box: - def __init__(self_, strict, val): - if strict: - # make sure we aren't getting special arguments that should only - # come up when symbolic_zeros is True - self.assertFalse(hasattr(val, 'perturbed')) - self_.strict = strict - self_.x = val - - def tree_flatten(self_): - return [self_.x], self_.strict - - @classmethod - def tree_unflatten(cls, strict, xs): - x, = xs - return cls(strict, x) - - x, y = Box(False, jnp.array(72.)), jnp.array(73.) - - @jax.custom_vjp - def f(box, y): - return box.x * y - - def fwd0(box, y): - self.assertTrue(box.x.perturbed) - self.assertFalse(y.perturbed) - box, y = map(tree_values, [box, y]) - return f(box, y), (box, y) - - def bwd0(res, g): - box, y = res - return y * g, box.x * g - - def fwd1(box, y): - self.assertFalse(box.x.perturbed) - self.assertTrue(y.perturbed) - box, y = map(tree_values, [box, y]) - return f(box, y), (box, y) - - def bwd1(res, g): - box, y = res - return y * g, box.x * g - - f.defvjp(fwd0, bwd0, symbolic_zeros=True) - jax.grad(f, argnums=0)(x, y) - f.defvjp(fwd1, bwd1, symbolic_zeros=True) - jax.grad(f, argnums=1)(x, y) - - def fwd_strict(box, y): - return f(box, y), (box, y) - - def bwd_strict(res, g): - box, y = res - return y * g, box.x * g - - f.defvjp(fwd_strict, bwd_strict) - jax.grad(f)(x, y) - - def test_symbolic_zeros_memoization_caching(self): - # Tests multiple zero patterns for partial_eval._memoize, and also tests - # that we're okay with stores being occupied with equal values. - @jax.custom_vjp - def f(x, y): - return x * y - - def f_fwd(x, y): - return x.value, None - - def f_bwd(_, z_bar): - return z_bar, None - - f.defvjp(f_fwd, f_bwd, symbolic_zeros=True) - - f_ = core.jaxpr_as_fun(jax.make_jaxpr(f)(2., 3.)) - _ = jax.linearize(f_, 2., 3.) - _ = jax.linearize(lambda x: f_(x, 3.), 2.) # don't crash! - - def test_run_rules_more_than_once(self): - # https://github.com/jax-ml/jax/issues/16614 - - @jax.custom_vjp - def f(x, y): - return x + y - - def f_fwd(x, y): - if y.perturbed: - res = None - else: - res = [] - return x.value + y.value, res - - def f_bwd(res, ct): - return ct, ct - - f.defvjp(f_fwd, f_bwd, symbolic_zeros=True) - - def body(x_y, _): - x, y = x_y - return (f(x, y), x), None - - @jax.grad - def g(x): - (out, _), _ = lax.scan(body, (x, 1.), xs=None, length=2) - return out - - g(1.) # doesn't crash - - def test_nones_representing_zeros_in_subtrees_returned_by_bwd(self): - # https://github.com/jax-ml/jax/issues/8356 - @jax.custom_vjp - def f(x): - return x[0] - - def f_fwd(x): - return f(x), None - - def f_bwd(_, z_bar): - return (z_bar, (None, None)), - - f.defvjp(f_fwd, f_bwd) - - jax.grad(f)((1.0, (2.0, 3.0))) # don't crash - - def test_pytree_nones_returned_by_bwd(self): - @jax.custom_vjp - def f(x): - return x[0] - - def f_fwd(x): - return f(x), None - - def f_bwd(_, z_bar): - return (z_bar, (None, None)), - - f.defvjp(f_fwd, f_bwd) - - jax.grad(f)((1.0, (2.0, None))) # don't crash - - def test_bwd_rule_shape_mismatch(self): - @jax.custom_vjp - def foo(x, y): - return x - - def foo_fwd(x, y): - return x, None - - def foo_bwd(_, g): - return jnp.zeros(3), jnp.zeros(3) - - foo.defvjp(foo_fwd, foo_bwd) - - with self.assertRaisesRegex( - ValueError, - r'output\[1\] the bwd rule produced an output of shape/dtype float..\[3\]'): - jax.grad(lambda x, y: foo(x, y * y).sum(), 1)(jnp.ones(3), jnp.ones(4)) - - def test_bwd_rule_shape_mismatch_disable(self): - # TODO(mattjj): remove this test when the config option is removed - @jax.custom_vjp - def foo(x, y): - return x - - def foo_fwd(x, y): - return x, None - - def foo_bwd(_, g): - return jnp.zeros(3), jnp.zeros(3) - - foo.defvjp(foo_fwd, foo_bwd) - - with config.custom_vjp_disable_shape_check(True): - jax.grad(lambda x, y: foo(x, y).sum(), 1)(jnp.ones(3), jnp.ones(4)) - - def test_bwd_rule_can_produce_list_or_tuple(self): - @jax.custom_vjp - def f(x, y): - return x * y - - def f_fwd(x, y): - return f(x, y), (x, y) - - def f_bwd(xy, g): - x, y = xy - return [g * y, x * g] # list, not tuple - - f.defvjp(f_fwd, f_bwd) - - jax.grad(f)(1., 2.) # don't crash - - def test_optimize_remat(self): - def fun(x): - # This array is included to make sure that we handle consts appropriately - return np.array([1.0])*x - - def fwd(x): - return np.array([2.0])*x*x/np.array([1.0]), (x,) - - x = jnp.linspace(0, 5.0, 10) - fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd( - fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}), - fwd, api_util.debug_info("custom_vjp fwd", fwd, (x,), {})) - - self.assertAllClose(jax.jit(fwd)(x)[0], 2*x*x) # Shouldn't hit custom DCE - self.assertAllClose(jax.jit(lambda x: fwd(x)[0])(x), x) # Should be DCEed - - def test_optimize_remat_vmap(self): - def fun(x): - return (np.array([1.0])*x)[0] - def fwd(x): - return (np.array([2.0])*x*x/np.array([1.0]))[0], (x,) - x = jnp.linspace(0, 5.0, 10) - fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd( - fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}), - fwd, api_util.debug_info("custom_vjp fwd", fwd, (x,), {})) - self.assertAllClose(jax.jit(jax.vmap(fwd))(x)[0], 2*x*x) - self.assertAllClose(jax.jit(lambda x: jax.vmap(fwd)(x)[0])(x), x) - - def test_optimize_remat_cond(self): - def fun(x): - return x - def fwd(x): - return x*x, (x,) - - x = jnp.linspace(0, 5.0, 10) - fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd( - fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}), - fwd, api_util.debug_info("custom_vjp fwd", fwd, (x,), {})) - - def g(x): - return jax.lax.cond(True, fwd, lambda x: (2.0 * x, (x,)), x) - - self.assertAllClose(jax.jit(g)(x)[0], x*x) - self.assertAllClose(jax.jit(lambda x: g(x)[0])(x), x) - - def test_optimize_remat_jvp(self): - def fun(x): - return x**2 - def fwd_(x): - return x*x, (x,) - - fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd( - fun, api_util.debug_info("custom_vjp fun", fun, (3.2,), {}), - fwd_, api_util.debug_info("custom_vjp fwd", fwd_, (3.2,), {})) - calc = jax.jvp(fwd, (3.2,), (1.0,)) - expected = jax.jvp(fwd_, (3.2,), (1.0,)) - self.assertAllClose(calc, expected) - - @jax.jit - def g(x, t): - (y, r), (y_dot, r_dot) = jax.jvp(fwd, (x,), (t,)) - return y, y_dot - calc = g(3.2, 1.0) - expected = jax.jvp(fun, (3.2,), (1.0,)) - self.assertAllClose(calc, expected) - - def test_optimize_remat_gh21303(self): - @jax.custom_vjp - def f(x): - return jnp.tan(x) - - def f_fwd(x): - return jnp.sin(x), (x,) - - def f_bwd(res, g): - x, = res - cos_x = jnp.cos(x) - return (cos_x * g,) - - f.defvjp(f_fwd, f_bwd, optimize_remat=True) - - def temp(x): - out = jax.remat(f)(x) - out = out ** 2 - return out - - v, g = jax.value_and_grad(temp)(3.2) - self.assertAllClose(v, jnp.tan(3.2)**2) - - def test_optimize_remat_multiple_args(self): - def f_(x, y): - return jnp.sin(x) * y - - @jax.custom_vjp - def f(x, y): - return f_(x, y) - - def f_fwd(x, y): - return f(x, y), (jnp.cos(x), jnp.sin(x), y) - - def f_bwd(res, g): - cos_x, sin_x, y = res - return (cos_x * g * y, sin_x * g) - - f.defvjp(f_fwd, f_bwd, optimize_remat=True) - x, y = 3.2, 1.0 - self.assertAllClose(jax.grad(f)(x, y), jax.grad(f_)(x, y)) - - def test_optimize_remat_kwargs(self): - @jax.custom_vjp - def f(x, y): - return jnp.sin(x) * y - - def f_fwd(x, y, *, keyword=False): - del keyword - return f(x, y), (jnp.cos(x), jnp.sin(x), y) - - def f_bwd(res, g): - cos_x, sin_x, y = res - return (cos_x * g * y, sin_x * g) - - f.defvjp(f_fwd, f_bwd, optimize_remat=True) - x, y = 3.2, 1.0 - jax.grad(f)(x, y) # Doesn't error - - def test_optimize_remat_custom_vmap(self): - # See https://github.com/jax-ml/jax/pull/23000 - @jax.custom_vjp - def f(x, y): - return jnp.sin(x) * y - - @jax.custom_batching.custom_vmap - def f_fwd(x, y): - return f(x, y), (jnp.cos(x), jnp.sin(x), y) - - @f_fwd.def_vmap - def f_fwd_vmap(_, in_batched, x, y): - # Insert a new const here to test the optimize_remat batching rule. - out = np.array([2.0])*f(x, y) - out_batched = (True, (True, True, True)) - return (out, (jnp.cos(x), jnp.sin(x), y)), out_batched - - def f_bwd(res, g): - cos_x, sin_x, y = res - return (cos_x * g * y, sin_x * g) - - f.defvjp(f_fwd, f_bwd, optimize_remat=True) - x, y = jnp.linspace(0.0, 1.0, 5), jnp.linspace(2.0, 5.0, 5) - jax.jit(jax.vmap(jax.grad(f)))(x, y) # Doesn't error - - def test_dce(self): - @jax.custom_vjp - def f(x, y): - return jnp.sin(x), x + jnp.cos(y) - - def f_fwd(x, y): - return f(x, y), (jnp.cos(x), jnp.sin(y)) - - def f_bwd(res, cts): - cos_x, sin_y = res - ct_a, ct_b = cts - return 2.0 * cos_x * ct_a + 1.5 * ct_b, -0.5 * sin_y * ct_b - - f.defvjp(f_fwd, f_bwd) - - def check_jaxpr(jaxpr, used_outs, includes, excludes): - dce_jaxpr, _ = pe.dce_jaxpr(jaxpr, used_outs) - if not dce_jaxpr.eqns: - assert not includes - return - call_jaxpr = dce_jaxpr.eqns[0].params["fun_jaxpr"] - for prim in includes: - assert any(eqn.primitive == prim for eqn in call_jaxpr.eqns) - for prim in excludes: - assert all(eqn.primitive != prim for eqn in call_jaxpr.eqns) - - x, y = 0.1, -1.3 - jaxpr = jax.make_jaxpr(f)(x, y).jaxpr - check_jaxpr(jaxpr, [True, True], [lax.sin_p, lax.cos_p], []) - check_jaxpr(jaxpr, [True, False], [lax.sin_p], [lax.cos_p]) - check_jaxpr(jaxpr, [False, True], [lax.cos_p], [lax.sin_p]) - check_jaxpr(jaxpr, [False, False], [], [lax.sin_p, lax.cos_p]) - - def dce_jaxpr_as_fun(jaxpr, used_outs): - jaxpr_, _ = pe.dce_jaxpr(jaxpr, used_outs) - fun = core.jaxpr_as_fun(pe.close_jaxpr(jaxpr_)) - return lambda *args: fun(*args)[0] - - f0 = dce_jaxpr_as_fun(jaxpr, [True, False]) - f1 = dce_jaxpr_as_fun(jaxpr, [False, True]) - self.assertAllClose( - api.grad(f0, argnums=(0, 1))(x, y), (2.0 * jnp.cos(x), 0.0)) - self.assertAllClose( - api.grad(f1, argnums=(0, 1))(x, y), (1.5, -0.5 * jnp.sin(y))) - - def test_resolve_kwargs_error_message(self): - @jax.custom_vjp - def f(x, y, *, z=None): - return jnp.sin(x), x + jnp.cos(y) - - def f_fwd(x, y): - self.fail("should not be executed") - - def f_bwd(res, cts): - self.fail("should not be executed") - - f.defvjp(f_fwd, f_bwd) - - with self.assertRaisesRegex( - TypeError, - r"The input arguments to the custom_vjp-decorated function f(.*)\n" - r"missing a required argument: 'y'" - ): - f(0.5) - - with self.assertRaisesRegex( - TypeError, - r"The input arguments to the custom_vjp-decorated function f(.*)\n" - "The following keyword arguments could not be resolved to positions: z" - ): - f(0.5, 0.1, z=1.0) - - -def transpose_unary(f, x_example): - def transposed(y): - x, = api.linear_transpose(f, x_example)(y) - return x - return transposed - - -# This class wraps jax.custom_transpose.custom_transpose in order to pass in a -# particular tree of output type on each call. Otherwise it forwards -# all attribute access. -class _custom_transpose: - def __init__(self, out_types, fun): - self.out_types = out_types - self.fun = jax.custom_transpose.custom_transpose(fun) - - def __getattr__(self, name): - return getattr(self.fun, name) - - def __call__(self, *args): - return self.fun(self.out_types, *args) - - -# This function is meant to be used as a decorator that delegates to -# custom_transpose but makes it easy to specify output argument types -# by example. If used directly a decorator (i.e. not invoked with -# example arguments), assumes a scalar-valued function. -# -# TODO(frostig): remove this (and its uses) once custom_transpose offers -# an option of inferring output types. -def custom_transpose(example_out): - if isinstance(example_out, Callable): - out_type = core.get_aval(0.).to_tangent_aval() - return _custom_transpose(out_type, example_out) - return partial( - _custom_transpose, - jax.tree.map( - lambda x: core.get_aval(x).to_tangent_aval(), example_out)) - - -class CustomTransposeTest(jtu.JaxTestCase): - - def test_linear_call(self): - def f(x, y): - def fn(r, x): return x / r - def tp(r, t): return t / r - return x + jax.custom_derivatives.linear_call(fn, tp, y, x) - - def f_ref(x, y): - return x + x / y - - x = jnp.ones(2) * 6. - y = jnp.ones(2) * 3. - self.assertAllClose(f(x, y), f_ref(x, y)) - - f1 = lambda x: f(x, y) - f1_ref = lambda x: f_ref(x, y) - self.assertAllClose(transpose_unary(f1, x)(x), - transpose_unary(f1_ref, x)(x)) - - def test_linear_call_incorrect_transpose(self): - def f(x, y): - def fn(r, x): return x / r - def tp(r, t): return t / (2. * r) # nb: not the true transpose - return x + jax.custom_derivatives.linear_call(fn, tp, y, x) - - def f_ref(x, y): - return x + x / y - - x = jnp.ones(2) * 6. - y = jnp.ones(2) * 3. - self.assertAllClose(f(x, y), f_ref(x, y)) - - f1 = lambda x: f(x, y) - f1_ref = lambda x: f_ref(x, 2. * y) # nb: double the reference divisor - self.assertAllClose(transpose_unary(f1, x)(x), - transpose_unary(f1_ref, x)(x)) - - def test_linear_call_transpose_transpose_transpose(self): - def fn(r, x): return x / r - def tp(r, t): return t / (2. * r) # nb: untrue transpose - def f_(x, y): - return x + jax.custom_derivatives.linear_call(fn, tp, y, x) - - x = jnp.ones(2) * 6. - y = jnp.ones(2) * 3. - f = lambda x: f_(x, y) - ft = transpose_unary(f, x) - ftt = transpose_unary(ft, x) - fttt = transpose_unary(ftt, x) - self.assertAllClose(ft(x), x + tp(y, x)) - self.assertAllClose(f(x), ftt(x)) - self.assertAllClose(ft(x), fttt(x)) - - def test_linear_call_scalar_to_vector(self): - def f(c, x): - def fn(_, x): - return [x, x] - - def tp(_, t): - t1, t2 = t - return t1 + t2 - - return jax.custom_derivatives.linear_call(fn, tp, (), c * x) - - def f_ref(c, x): - return [c * x, c * x] - - c, x = 2., 3. - t = [4., 5.] - self.assertAllClose(f(c, x), f_ref(c, x)) - self.assertAllClose(transpose_unary(partial(f, c), x)(t), - transpose_unary(partial(f_ref, c), x)(t)) - - def test_linear_call_nested(self): - # identity function with an untrue transpose of 0 - def id_(x): - def f(_, x): return x - def t(_, t): return 0. - return jax.custom_derivatives.linear_call(f, t, (), x) - - # identity function with an untrue transpose of 7, and where both - # forward and transpose have custom transpositions that should - # never end up invoked. - def f(x): - def f_(_, x): return id_(x) - def t_(_, t): return id_(7.) - return jax.custom_derivatives.linear_call(f_, t_, (), x) - - x = 5. - id_t = transpose_unary(id_, x) - id_tt = transpose_unary(id_t, x) - ft = transpose_unary(f, x) - ftt = transpose_unary(ft, x) - fttt = transpose_unary(ftt, x) - - self.assertAllClose(id_(x), x) - self.assertAllClose(id_t(x), 0.) - self.assertAllClose(id_tt(x), x) - - self.assertAllClose(f(x), x) - self.assertAllClose(ft(x), 7.) - self.assertAllClose(ftt(x), x) - self.assertAllClose(fttt(x), 7.) - - def test_linear_call_jit(self): - def f(x, y): - def fn(r, x): return x / r - def tp(r, t): return t / r - return x + jax.custom_derivatives.linear_call(fn, tp, y, x) - - x = jnp.ones(2) * 6. - y = jnp.ones(2) * 3. - self.assertAllClose(f(x, y), jax.jit(f)(x, y)) - - f1 = lambda x: f(x, y) - self.assertAllClose(transpose_unary(f1, x)(x), - jax.jit(transpose_unary(f1, x))(x)) - - def test_basic(self): - def f(x, y): - @custom_transpose(jnp.ones(2)) - def fn(r, x): return x / r - @fn.def_transpose - def tp(r, t): return t / r - - return x + fn(y, x) - - def f_ref(x, y): - return x + x / y - - x = jnp.ones(2) * 6. - y = jnp.ones(2) * 3. - self.assertAllClose(f(x, y), f_ref(x, y)) - - f1 = lambda x: f(x, y) - f1_ref = lambda x: f_ref(x, y) - self.assertAllClose(transpose_unary(f1, x)(x), - transpose_unary(f1_ref, x)(x)) - - def test_incorrect_transpose(self): - def f(x, y): - @custom_transpose(jnp.ones(2)) - def fn(r, x): return x / r - @fn.def_transpose - def tp(r, t): return t / (2. * r) # nb: not the true transpose - - return x + fn(y, x) - - def f_ref(x, y): - return x + x / y - - x = jnp.ones(2) * 6. - y = jnp.ones(2) * 3. - self.assertAllClose(f(x, y), f_ref(x, y)) - - f1 = lambda x: f(x, y) - f1_ref = lambda x: f_ref(x, 2. * y) # nb: double the reference divisor - self.assertAllClose(transpose_unary(f1, x)(x), - transpose_unary(f1_ref, x)(x)) - - def test_transpose_transpose_transpose(self): - @custom_transpose(jnp.ones(2)) - def fn(r, x): return x / r - @custom_transpose(jnp.ones(2)) - def tp(r, t): return t / (2. * r) # nb: untrue transpose - - fn.def_transpose(tp) - tp.def_transpose(fn) - - def f_(x, y): - return x + fn(y, x) - - x = jnp.ones(2) * 6. - y = jnp.ones(2) * 3. - f = lambda x: f_(x, y) - ft = transpose_unary(f, x) - ftt = transpose_unary(ft, x) - fttt = transpose_unary(ftt, x) - self.assertAllClose(ft(x), x + tp(y, x)) - self.assertAllClose(f(x), ftt(x)) - self.assertAllClose(ft(x), fttt(x)) - - def test_scalar_to_vector(self): - def f(c, x): - @custom_transpose([0., 0.]) - def fn(_, x): - return [x, x] - - @fn.def_transpose - def tp(_, t): - t1, t2 = t - return t1 + t2 - - return fn((), c * x) - - def f_ref(c, x): - return [c * x, c * x] - - c, x = 2., 3. - t = [4., 5.] - self.assertAllClose(f(c, x), f_ref(c, x)) - self.assertAllClose(transpose_unary(partial(f, c), x)(t), - transpose_unary(partial(f_ref, c), x)(t)) - - def test_nested(self): - # identity function with an untrue transpose of 0 - def id_(x): - f = custom_transpose(lambda _, x: x) - t = custom_transpose(lambda _, t: 0.) - f.def_transpose(t) - t.def_transpose(f) - return f((), x) - - # identity function with an untrue transpose of 7, and where both - # forward and transpose have custom transpositions that should - # never end up invoked. - def f(x): - f_ = custom_transpose(lambda _, x: id_(x)) - t_ = custom_transpose(lambda _, t: id_(7.)) - f_.def_transpose(t_) - t_.def_transpose(f_) - return f_((), x) - - x = 5. - id_t = transpose_unary(id_, x) - id_tt = transpose_unary(id_t, x) - ft = transpose_unary(f, x) - ftt = transpose_unary(ft, x) - fttt = transpose_unary(ftt, x) - - self.assertAllClose(id_(x), x) - self.assertAllClose(id_t(x), 0.) - self.assertAllClose(id_tt(x), x) - - self.assertAllClose(f(x), x) - self.assertAllClose(ft(x), 7.) - self.assertAllClose(ftt(x), x) - self.assertAllClose(fttt(x), 7.) - - def test_one_degree(self): - T = lambda f: transpose_unary(f, 0.) - - @custom_transpose - def f(_, z): return 2. * z - @f.def_transpose - def ft(_, z): return 3. * z - - f = partial(f, ()) - self.assertAllClose(2., f(1.)) - self.assertAllClose(3., T(f)(1.)) - self.assertAllClose(3., T(T(f))(1.)) - self.assertAllClose(3., T(T(T(f)))(1.)) - self.assertAllClose(3., T(T(T(T(f))))(1.)) # ... - - def test_two_degrees(self): - T = lambda f: transpose_unary(f, 0.) - - @custom_transpose - def f(_, z): return 2. * z - - @f.def_transpose - @custom_transpose - def ft(_, z): return 3. * z - - @ft.def_transpose - def ftt(_, z): return 7. * z - - f = partial(f, ()) - self.assertAllClose(2., f(1.)) - self.assertAllClose(3., T(f)(1.)) - self.assertAllClose(7., T(T(f))(1.)) - self.assertAllClose(7., T(T(T(f)))(1.)) - self.assertAllClose(7., T(T(T(T(f))))(1.)) # ... - - def test_symmetric(self): - T = lambda f: transpose_unary(f, 0.) - - @custom_transpose - def f(_, z): return 2. * z - @custom_transpose - def g(_, z): return 3. * z - - f.def_transpose(g) - g.def_transpose(f) - - f = partial(f, ()) - self.assertAllClose(2., f(1.)) - self.assertAllClose(3., T(f)(1.)) - self.assertAllClose(2., T(T(f))(1.)) - self.assertAllClose(3., T(T(T(f)))(1.)) - self.assertAllClose(2., T(T(T(T(f))))(1.)) # ... - - def test_recursive(self): - T = lambda f: transpose_unary(f, 0.) - - @custom_transpose - def f(c, z): return c * z - - @f.def_transpose - def ft(c, z): return f(c + 1., z) - - g = partial(f, 1.) - self.assertAllClose(1., g(1.)) - self.assertAllClose(2., T(g)(1.)) - self.assertAllClose(3., T(T(g))(1.)) - self.assertAllClose(4., T(T(T(g)))(1.)) - self.assertAllClose(5., T(T(T(T(g))))(1.)) # ... - - def test_jvp_lin(self): - def f(x, y): - @custom_transpose(jnp.ones(2)) - def fn(r, x): return x / r - @fn.def_transpose - def tp(r, t): return t / r - return x + fn(y, x) - - def f_ref(x, y): return x + x / y - - x, y, tx = 6., 3., 1. - g = lambda x: f(x, y) - g_ref = lambda x: f_ref(x, y) - self.assertAllClose(api.jvp(g, [x], [tx]), api.jvp(g_ref, [x], [tx])) - - def test_jvp_res(self): - raise unittest.SkipTest('unimplemented') # TODO(frostig) - - def f(x, y): - @custom_transpose(jnp.ones(2)) - def fn(r, x): return x / r - @fn.def_transpose - def tp(r, t): return t / r - return x + fn(y, x) - - def f_ref(x, y): return x + x / y - - x, y, ty = 6., 3., 1. - g = lambda y: f(x, y) - g_ref = lambda y: f_ref(x, y) - self.assertAllClose(api.jvp(g, [y], [ty]), api.jvp(g_ref, [y], [ty])) - - def test_jvp_both(self): - raise unittest.SkipTest('unimplemented') # TODO(frostig) - - def f(x, y): - @custom_transpose(jnp.ones(2)) - def fn(r, x): return x / r - @fn.def_transpose - def tp(r, t): return t / r - return x + fn(y, x) - - def f_ref(x, y): return x + x / y - - x, y, tx, ty = 6., 3., 1., 1. - self.assertAllClose(api.jvp(f, [x, y], [tx, ty]), - api.jvp(f_ref, [x, y], [tx, ty])) - - def test_make_jaxpr(self): - def f(x, y): - @custom_transpose(jnp.ones(2)) - def fn(r, x): return x / r - @fn.def_transpose - def tp(r, t): return 2 * t / r - - return x + fn(y, x) - - x = jnp.ones(2) * 6. - y = jnp.ones(2) * 3. - f_ = lambda x: f(x, y) - f_t = transpose_unary(f_, x) - - jaxpr = api.make_jaxpr(f_)(x) - self.assertIn('custom_transpose_call', str(jaxpr)) - - jaxpr_t = api.make_jaxpr(f_t)(x) - self.assertNotIn('custom_transpose_call', str(jaxpr_t)) - - def test_jit(self): - def f(x, y): - @custom_transpose(jnp.ones(2)) - def fn(r, x): return x / r - @fn.def_transpose - def tp(r, t): return 2 * t / r - - return x + fn(y, x) - - x = jnp.ones(2) * 6. - y = jnp.ones(2) * 3. - self.assertAllClose(f(x, y), jax.jit(f)(x, y)) - - f_ = lambda x: f(x, y) - f_t = transpose_unary(f_, x) - g_ = jax.jit(f_) - g_t = transpose_unary(g_, x) - self.assertAllClose(f_(x), jax.jit(f_)(x)) - self.assertAllClose(f_t(x), jax.jit(f_t)(x)) - self.assertAllClose(f_(x), g_(x)) - self.assertAllClose(f_t(x), g_t(x)) - - def test_jit_recursive(self): - def f(x, y): - @custom_transpose(jnp.ones(2)) - def fn(r, x): return x / r - @fn.def_transpose - def tp(r, t): return 2 * fn(r, t) - - return x + fn(y, x) - - x = jnp.ones(2) * 6. - y = jnp.ones(2) * 3. - self.assertAllClose(f(x, y), jax.jit(f)(x, y)) - - f_ = lambda x: f(x, y) - f_t = transpose_unary(f_, x) - g_ = jax.jit(f_) - g_t = transpose_unary(g_, x) - self.assertAllClose(f_(x), jax.jit(f_)(x)) - self.assertAllClose(f_t(x), jax.jit(f_t)(x)) - self.assertAllClose(f_(x), g_(x)) - self.assertAllClose(f_t(x), g_t(x)) - - def test_cond(self): - def f(x, y): - @custom_transpose(jnp.ones(2)) - def fn(r, x): return x / r - @fn.def_transpose - def tp(r, t): return 2 * t / r - - return x + fn(y, x) - - def cond_wrap(f): - return lambda i, x: lax.cond(i > 0, f, lambda x: x, x) - - i = 7. - x = jnp.ones(2) * 6. - y = jnp.ones(2) * 3. - - f_ = lambda x: f(x, y) - f_t = transpose_unary(f_, x) - g_ = partial(cond_wrap(f_), i) - g_t = transpose_unary(g_, x) - - self.assertAllClose(f_(x), g_(x)) - self.assertAllClose(f_t(x), g_t(x)) - - def test_cond_recursive(self): - def f(x, y): - @custom_transpose(jnp.ones(2)) - def fn(r, x): return x / r - @fn.def_transpose - def tp(r, t): return 2 * fn(r, t) - - return x + fn(y, x) - - def cond_wrap(f): - return lambda i, x: lax.cond(i > 0, f, lambda x: x, x) - - i = 7. - x = jnp.ones(2) * 6. - y = jnp.ones(2) * 3. - - f_ = lambda x: f(x, y) - f_t = transpose_unary(f_, x) - g_ = partial(cond_wrap(f_), i) - g_t = transpose_unary(g_, x) - - self.assertAllClose(f_(x), g_(x)) - self.assertAllClose(f_t(x), g_t(x)) - - -class CustomDceTest(jtu.JaxTestCase): - - def test_basic(self): - @jax.experimental.custom_dce.custom_dce - def f(x): - return jnp.sin(x), jnp.cos(x) - - @f.def_dce - def rule(used_outs, x): - return ( - jnp.exp(x) if used_outs[0] else None, - jnp.sqrt(x) if used_outs[1] else None, - ) - - x = jnp.array(1.1234) - self.assertAllClose(jax.jit(lambda x: f(x)[0])(x), jnp.exp(x)) - self.assertAllClose(jax.jit(lambda x: f(x)[1])(x), jnp.sqrt(x)) - - def test_recursive(self): - @jax.experimental.custom_dce.custom_dce - def f(x): - return jnp.exp(x), 10 * jnp.sqrt(x) - - @f.def_dce - def f_dce(used_outs, x): - return [2 * v if used else None for used, v in zip(used_outs, f(x))] - - x = 1.1234 - expected = f(x) - self.assertAllClose(jax.jit(lambda x: f(x)[0])(x), 2 * expected[0]) - self.assertAllClose(jax.jit(lambda x: f(x)[1])(x), 2 * expected[1]) - - def test_multiple_rounds(self): - @jax.experimental.custom_dce.custom_dce - def f(x, y, z): - return jnp.sin(x), jnp.sin(y), jnp.sin(z) - - @f.def_dce - def rule(used_outs, x, y, z): - patterns.append(used_outs) - outs = [ - jnp.cos(v) if used else None for used, v in zip(used_outs, (x, y, z)) - ] - return outs - - patterns = [] - x, y, z = jnp.array(1.), jnp.array(2.), jnp.array(3.) - jaxpr = jax.make_jaxpr(f)(x, y, z).jaxpr - new_jaxpr, used_ins = pe.dce_jaxpr(jaxpr, [True, False, True]) - assert used_ins == [True, False, True] - new_jaxpr, used_ins = pe.dce_jaxpr(new_jaxpr, [True, False]) - assert used_ins == [True, False] - assert patterns == [(True, False, True), (True, False, False)], patterns - - def test_batching(self): - @jax.experimental.custom_dce.custom_dce - def f(x, y): - return jnp.sin(x), jnp.sin(y) - - @f.def_dce - def rule(used_outs, x, y): - return ( - jnp.cos(x) if used_outs[0] else None, - jnp.cos(y) if used_outs[1] else None, - ) - - x = jnp.linspace(-0.1, 0.2, 5) - y = jnp.linspace(3.0, 4.0, 5) - self.assertAllClose(jax.vmap(f)(x, y), f(x, y)) - self.assertAllClose( - jax.jit(lambda *args: jax.vmap(f)(*args)[0])(x, y), jnp.cos(x) - ) - self.assertAllClose( - jax.vmap(jax.jit(lambda *args: f(*args)[0]))(x, y), jnp.cos(x) - ) - self.assertAllClose( - jax.jit(lambda *args: jax.vmap(f)(*args)[1])(x, y), jnp.cos(y) - ) - self.assertAllClose( - jax.vmap(jax.jit(lambda *args: f(*args)[1]))(x, y), jnp.cos(y) - ) - - def test_composes_with_custom_vjp(self): - # custom_dce must be the "outer" decorator (for now!) because custom_vjp - # doesn't pass through DCE. - @jax.experimental.custom_dce.custom_dce - @jax.custom_vjp - def f(x, y): - return jnp.sin(x) * y, x * jnp.sin(y) - - @f.def_dce - def f_dce_rule(used_outs, x, y): - return ( - jnp.cos(x) * y if used_outs[0] else None, - x * jnp.cos(y) if used_outs[1] else None, - ) - - def f_fwd(x, y): - return f(x, y), (x, jnp.cos(x), jnp.sin(x), y, jnp.cos(y), jnp.sin(y)) - - def f_bwd(res, g): - ga, gb = g - x, cos_x, sin_x, y, cos_y, sin_y = res - return (cos_x * ga * y + sin_y * gb, sin_x * ga + x * cos_y * gb) - - f.defvjp(f_fwd, f_bwd) - - x, y = jnp.array(1.), jnp.array(2.) - self.assertAllClose(jax.jit(lambda *args: f(*args)[0])(x, y), - jnp.cos(x) * y) - jax.grad(lambda *args: f(*args)[0])(x, y) # Doesn't crash. - - def test_can_optimize_remat(self): - @jax.custom_vjp - def f(x): - return jnp.tan(x) - - @jax.experimental.custom_dce.custom_dce - def f_fwd(x): - return jnp.sin(x), (x,) - - @f_fwd.def_dce - def f_dce_rule(used_outs, x): - used_prim, used_res = used_outs - used_res, = used_res - if not used_res: - return f(x), None - prim, res = f_fwd(x) - return prim if used_prim else None, res - - def f_bwd(res, g): - x, = res - cos_x = jnp.cos(x) - return (cos_x * g,) - - f.defvjp(f_fwd, f_bwd) - - def temp(x): - out = jax.remat(f)(x) - out = out ** 2 - return out - - v, g = jax.value_and_grad(temp)(3.2) - self.assertAllClose(v, jnp.tan(3.2)**2) - - def test_static_argnums(self): - @partial(jax.experimental.custom_dce.custom_dce, static_argnums=(0,)) - def g(f, x): - return f(x), 10 * f(x) - - @g.def_dce - def g_dce(f, used_outs, x): # note: static_argnums are always passes first - self.assertTrue(callable(f)) - return [2 * v if used else None for used, v in zip(used_outs, g(f, x))] - - x = 1.1234 - f = lambda x: jnp.exp(x) - expected = g(f, x) - self.assertAllClose(jax.jit(lambda x: g(f, x)[0])(x), 2 * expected[0]) - self.assertAllClose(jax.jit(lambda x: g(f, x)[1])(x), 2 * expected[1]) - - def test_shape_mismatch_error(self): - @jax.experimental.custom_dce.custom_dce - def f(x): - return jnp.stack((x, x)), jnp.cos(x) - - @f.def_dce - def rule(used_outs, x): - return ( - jnp.exp(x) if used_outs[0] else None, - x.astype(jnp.int32) if used_outs[1] else None, - ) - - x = jnp.array(1.1234) - with self.assertRaisesRegex( - ValueError, - r'Custom DCE rule .* same shapes/dtypes .* output\[0\]', - ): - jax.jit(lambda x: f(x)[0])(x) - with self.assertRaisesRegex( - ValueError, - r'Custom DCE rule .* same shapes/dtypes .* output\[1\]', - ): - jax.jit(lambda x: f(x)[1])(x) - - def test_missing_output_error(self): - @jax.experimental.custom_dce.custom_dce - def f(x): - return jnp.sin(x), jnp.cos(x) - - @f.def_dce - def rule(used_outs, x): - return None, None - - x = jnp.array(1.1234) - with self.assertRaisesRegex( - ValueError, - r'Custom DCE rule .* produce values for all .* output\[0\]', - ): - jax.jit(lambda x: f(x)[0])(x) - - def test_consts(self): - @jax.experimental.custom_dce.custom_dce - def f(x): - return np.eye(1) * jnp.sin(x), jnp.cos(x) - - @f.def_dce - def rule(used_outs, x): - return ( - np.full((1, 1), 2.0) * jnp.exp(x) if used_outs[0] else None, - jnp.sqrt(x) if used_outs[1] else None, - ) - - x = jnp.array(1.1234) - expected = rule([True, True], x) - self.assertAllClose(jax.jit(lambda x: f(x)[0])(x), expected[0]) - self.assertAllClose(jax.jit(lambda x: f(x)[1])(x), expected[1]) - - def test_resolve_kwargs_error_message(self): - @jax.experimental.custom_dce.custom_dce - def f(x, y, *, z=None): - return jnp.sin(x) * y, x * jnp.sin(y) - - @f.def_dce - def f_dce_rule(used_outs, x, y): - self.fail("should not be executed") - - with self.assertRaisesRegex( - TypeError, - r"The input arguments to the custom_dce-decorated function f(.*)\n" - r"missing a required argument: 'y'" - ): - f(0.5) - - with self.assertRaisesRegex( - TypeError, - r"The input arguments to the custom_dce-decorated function f(.*)\n" - "The following keyword arguments could not be resolved to positions: z" - ): - f(0.5, 0.1, z=1.0) - - -class CustomVmapTest(jtu.JaxTestCase): - - def test_basic(self): - @jax.custom_batching.custom_vmap - def f(x): return jnp.sin(x) - - @f.def_vmap - def rule(axis_size, in_batched, xs): - xs_batched, = in_batched - self.assertEqual(xs_batched, True) - self.assertEqual(axis_size, xs.shape[0]) - return jnp.cos(xs), xs_batched - - x, xs = jnp.array(1.), jnp.arange(3) - y = f(x) - self.assertAllClose(y, jnp.sin(x)) - ys = api.vmap(f)(xs) - self.assertAllClose(ys, jnp.cos(xs)) - - @jax.numpy_dtype_promotion('standard') - def test_closure(self): - z = jnp.array([2., 1., 3.]) - - @jax.custom_batching.custom_vmap - def f(x): return z + jnp.sin(x) - - @f.def_vmap - def rule(axis_size, in_batched, *args): - self.assertEqual(len(in_batched), 1) - self.assertEqual(len(args), 1) - xs, = args - xs_batched, = in_batched - self.assertEqual(xs_batched, True) - self.assertEqual(axis_size, xs.shape[0]) - return z + jnp.cos(xs), xs_batched - - x, xs = jnp.array(1.), jnp.arange(3) - y = f(x) - self.assertAllClose(y, z + jnp.sin(x)) - ys = api.vmap(f)(xs) - self.assertAllClose(ys, z + jnp.cos(xs)) - - def test_rule_multi_output(self): - @jax.custom_batching.custom_vmap - def f(x): return jnp.sin(x), jnp.cos(x) - - @f.def_vmap - def rule(axis_size, in_batched, xs): - return (jnp.cos(xs), jnp.sin(xs)), tuple(in_batched * 2) - - x, xs = jnp.array(1.), jnp.arange(3) - y1, y2 = f(x) - self.assertAllClose(y1, jnp.sin(x)) - self.assertAllClose(y2, jnp.cos(x)) - ys1, ys2 = api.vmap(f)(xs) - self.assertAllClose(ys1, jnp.cos(xs)) - self.assertAllClose(ys2, jnp.sin(xs)) - - def test_nary(self): - @jax.custom_batching.custom_vmap - def f(x, y): return jnp.sin(x) + y ** 2. - - @f.def_vmap - def rule(axis_size, in_batched, xs, ys): - self.assertEqual(in_batched, [True, True]) - self.assertEqual(axis_size, 3) - self.assertEqual(axis_size, xs.shape[0]) - self.assertEqual(axis_size, ys.shape[0]) - return jnp.cos(xs) + ys ** 2., True - - xs, ys = jnp.arange(3.0), jnp.arange(3.0) - zs = api.vmap(f)(xs, ys) - self.assertAllClose(zs, jnp.cos(xs) + ys ** 2.) - - def test_nary_mixed_batching(self): - @jax.custom_batching.custom_vmap - def vector_dot(u, v): - self.assertEqual(u.ndim, 1) - self.assertEqual(v.ndim, 1) - return u @ v - - size = 4 - vlen = 3 - in_batched_log = [] - - @vector_dot.def_vmap - def vector_dot_vmap_rule(axis_size, in_batched, u, v): - in_batched_log.append(in_batched) - self.assertEqual(axis_size, size) - u_batched, v_batched = in_batched - if u_batched: - self.assertEqual(u.ndim, 2) - self.assertEqual(u.shape[0], size) - else: - self.assertEqual(u.ndim, 1) - self.assertEqual(u.shape[0], vlen) - if v_batched: - self.assertEqual(v.ndim, 2) - self.assertEqual(v.shape[0], size) - else: - self.assertEqual(v.ndim, 1) - self.assertEqual(v.shape[0], vlen) - if u_batched and v_batched: - out = jnp.sum(u * v, axis=1) - else: - out = u @ v if u_batched else v @ u - return out, u_batched or v_batched - - f = vector_dot - v = lambda *shape: jnp.ones(shape) - - y = api.vmap(f, in_axes=(0, None))(v(4, 3), v(3)) - self.assertAllClose(y, v(4, 3) @ v(3)) - y = api.vmap(f, in_axes=(1, None))(v(3, 4), v(3)) - self.assertAllClose(y, v(3, 4).T @ v(3)) - y = api.vmap(f, in_axes=(None, 0))(v(3), v(4, 3)) - self.assertAllClose(y, v(3) @ v(4, 3).T) - y = api.vmap(f, in_axes=(0, 0))(v(4, 3), v(4, 3)) - self.assertAllClose(y, jnp.sum(v(4, 3) * v(4, 3), axis=1)) - self.assertEqual(in_batched_log[0], [True, False]) - self.assertEqual(in_batched_log[1], [True, False]) - self.assertEqual(in_batched_log[2], [False, True]) - self.assertEqual(in_batched_log[3], [True, True]) - - def test_rule_input_signature(self): - @jax.custom_batching.custom_vmap - def f(x): return jnp.sin(x) - - rule_args = [] - - @f.def_vmap - def rule(axis_size, in_batched, xs): - rule_args.append((axis_size, in_batched)) - return jnp.cos(xs), in_batched[0] - - xs = jnp.arange(3) - _ = api.vmap(f)(xs) - (axis_size, in_batched), = rule_args - self.assertIs(type(axis_size), int) - self.assertIs(type(in_batched), list) - self.assertEqual(len(in_batched), 1) - - def test_rule_output_vs_batching_output_mismatch(self): - @jax.custom_batching.custom_vmap - def f(x): return jnp.sin(x) - - @f.def_vmap - def test_rule_abc(axis_size, in_batched, xs): - return [jnp.sin(xs), jnp.cos(xs)], in_batched - - xs = jnp.arange(3) - self.assertRaisesRegex( - ValueError, - 'structure of output value and output batching specification ' - r'returned by custom vmap rule \(test_rule_abc\) do not match.*', - lambda: api.vmap(f)(xs)) - - def test_rule_vs_call_output_mismatch(self): - @jax.custom_batching.custom_vmap - def f(x): return jnp.sin(x) - - @f.def_vmap - def test_rule_abc2(axis_size, in_batched, xs): - return [jnp.sin(xs)], in_batched - - xs = jnp.arange(3) - self.assertRaisesRegex( - ValueError, - r'structure of output returned by custom vmap rule \(test_rule_abc2\) ' - r'does not match that of original custom-vmapped function.*', - lambda: api.vmap(f)(xs)) - - def test_jvp_basic(self): - @jax.custom_batching.custom_vmap - def f(x): return jnp.sin(x) - - @f.def_vmap - def rule(axis_size, in_batched, xs): - self.assertEqual(axis_size, 3) - self.assertEqual(in_batched, [True]) - return jnp.cos(xs), in_batched[0] - - f_jvp = lambda x, tx: api.jvp(f, [x], [tx]) - - x, tx = jnp.array(1.), jnp.array(2.) - xs, txs = jnp.arange(3.), jnp.arange(3.) * 2. - - y, ty = f_jvp(x, tx) - self.assertAllClose(y, jnp.sin(x)) - self.assertAllClose(ty, jnp.cos(x) * tx) - - ys, tys = api.vmap(f_jvp)(xs, txs) - self.assertAllClose(ys, jnp.cos(xs)) - self.assertAllClose(tys, -jnp.sin(xs) * txs) - - ys, tys = api.jvp(api.vmap(f), [xs], [txs]) - self.assertAllClose(ys, jnp.cos(xs)) - self.assertAllClose(tys, -jnp.sin(xs) * txs) - - @jax.numpy_dtype_promotion('standard') - def test_jvp_closure(self): - z = jnp.array([2., 1., 3.]) - def bcast(x): return z + x - z - - @jax.custom_batching.custom_vmap - def f(x): return z + jnp.sin(x) - - @f.def_vmap - def rule(axis_size, in_batched, xs): - self.assertEqual(axis_size, 3) - self.assertEqual(in_batched, [True]) - return z + jnp.cos(xs), in_batched[0] - - f_jvp = lambda x, tx: api.jvp(f, [x], [tx]) - - x, tx = jnp.array(1.), jnp.array(2.) - xs, txs = jnp.arange(3.), jnp.arange(3.) * 2. - - y, ty = f_jvp(x, tx) - self.assertAllClose(y, z + jnp.sin(x)) - self.assertAllClose(ty, bcast(jnp.cos(x)) * tx) - - ys, tys = api.vmap(f_jvp)(xs, txs) - self.assertAllClose(ys, z + jnp.cos(xs)) - self.assertAllClose(tys, bcast(-jnp.sin(xs)) * txs) - - ys, tys = api.jvp(api.vmap(f), [xs], [txs]) - self.assertAllClose(ys, z + jnp.cos(xs)) - self.assertAllClose(tys, bcast(-jnp.sin(xs)) * txs) - - def test_jvp_nary(self): - @jax.custom_batching.custom_vmap - def f(x, y): return jnp.sin(x) + y - - @f.def_vmap - def rule(axis_size, in_batched, xs, ys): - self.assertEqual(axis_size, 3) - self.assertEqual(in_batched, [True, True]) - return jnp.cos(xs) + ys, True - - f_jvp = lambda x, y, tx, ty: api.jvp(f, [x, y], [tx, ty]) - - x, y, tx, ty = jnp.arange(4.) - xs, ys, txs, tys = 4. + jnp.arange(3. * 4).reshape((4, 3)) - - zs, tzs = api.vmap(f_jvp)(xs, ys, txs, tys) - self.assertAllClose(zs, jnp.cos(xs) + ys) - self.assertAllClose(tzs, -jnp.sin(xs) * txs + tys) - - zs, tzs = api.jvp(api.vmap(f), [xs, ys], [txs, tys]) - self.assertAllClose(zs, jnp.cos(xs) + ys) - self.assertAllClose(tzs, -jnp.sin(xs) * txs + tys) - - def test_jvp_extra_batched_tangents(self): - @jax.custom_batching.custom_vmap - def f(x): return jnp.sin(x) - - @f.def_vmap - def rule(axis_size, in_batched, xs): - self.assertEqual(axis_size, 3) - self.assertEqual(in_batched, [False]) - return jnp.cos(xs), in_batched[0] - - f_jvp = lambda x, tx: api.jvp(f, [x], [tx]) - - txs = 2. + jnp.arange(3.) - x = jnp.array(1, dtype=txs.dtype) - y, tys = api.vmap(f_jvp, in_axes=(None, 0), out_axes=(None, 0))(x, txs) - self.assertAllClose(y, jnp.cos(x)) - self.assertAllClose(tys, -jnp.sin(x) * txs) - - def test_jacfwd(self): - # jacfwd is another way to exercise extra-batched tangents - - @jax.custom_batching.custom_vmap - def f(x): return jnp.sin(x) - - @f.def_vmap - def rule(axis_size, in_batched, xs): - self.assertEqual(axis_size, 3) - self.assertEqual(in_batched, [False]) - return jnp.cos(xs), in_batched[0] - - x = jnp.arange(3.) + .72 - j = api.jacfwd(f)(x) - self.assertAllClose(j, -jnp.diag(jnp.sin(x))) - - def test_jvp_extra_batched_primals(self): - @jax.custom_batching.custom_vmap - def f(x): return jnp.sin(x) - - @f.def_vmap - def rule(axis_size, in_batched, xs): - self.assertEqual(axis_size, 3) - self.assertEqual(in_batched, [False]) - return jnp.cos(xs), in_batched[0] - - f_jvp = lambda x, tx: api.jvp(f, [x], [tx]) - - xs = jnp.arange(3.) - tx = jnp.array(4, dtype=xs.dtype) - ys, tys = api.vmap(f_jvp, in_axes=(0, None))(xs, tx) - self.assertAllClose(ys, jnp.cos(xs)) - self.assertAllClose(tys, -jnp.sin(xs) * tx) - - def test_jvp_extra_batched_primals_with_linear_vmap_rule(self): - # When a function is linear, its Jacobian is constant. JAX's JVP - # of linear functions takes advantage of this: when mapping over a - # batch of primals relative to a fixed (i.e. symbolically - # replicated) tangent, output tangents remain replicated as well - # (i.e. JAX will not broadcast them). This is true in general, and - # this test checks that vmapped JVPs continue to behave this way - # when custom_vmap is involved and the custom vmap rule is linear. - - @jax.custom_batching.custom_vmap - def f_linear(x): return 7. * x - - @f_linear.def_vmap - def linear_rule(axis_size, in_batched, xs): - return 11. * xs, in_batched[0] - - @jax.custom_batching.custom_vmap - def f_nonlinear(x): return jnp.sin(x) - - @f_nonlinear.def_vmap - def nonlinear_rule(axis_size, in_batched, xs): - return jnp.cos(xs), in_batched[0] - - f_lin_jvp = lambda x, tx: api.jvp(f_linear, [x], [tx]) - f_non_jvp = lambda x, tx: api.jvp(f_nonlinear, [x], [tx]) - xs = jnp.arange(3.) - tx = jnp.array(4., dtype=xs.dtype) - - # doesn't err - _ = api.vmap(f_lin_jvp, in_axes=(0, None), out_axes=(0, None))(xs, tx) - - # does err - self.assertRaisesRegex( - ValueError, "at vmap out_axes", - lambda: api.vmap( - f_non_jvp, in_axes=(0, None), out_axes=(0, None))(xs, tx)) - - def test_jvp_dataflow_violation(self): - # The jvp-of-custom-vmap machinery should not assume the standard - # dataflow constraint on the JVP of the custom vmap rule (primal - # outputs independent of tangent inputs). Both jvp and vmap are - # "forward" transformations under which, at present, we don't - # enforce the JVP dependence diagram. Because output primals can - # depend on input tangents, extra-batched input tangents can - # create batched output primals, as this test checks. - - @jax.custom_jvp - def cos_with_invalid_dataflow_jvp(x): return jnp.cos(x) - - @cos_with_invalid_dataflow_jvp.defjvp - def invalid_dataflow_jvp(x, tx): - [x], [tx] = x, tx - return jnp.cos(x * tx), tx - - @jax.custom_batching.custom_vmap - def f(x): return jnp.sin(x) - - @f.def_vmap - def rule(axis_size, in_batched, xs): - return cos_with_invalid_dataflow_jvp(xs), in_batched[0] - - f_jvp = lambda x, tx: api.jvp(f, [x], [tx]) - txs = 2. + jnp.arange(3.) - x = jnp.array(1, dtype=txs.dtype) - - # doesn't err - ys, tys = api.vmap(f_jvp, in_axes=(None, 0))(x, txs) - self.assertAllClose(ys, jnp.cos(x * txs)) - self.assertAllClose(tys, txs) - - # does err - self.assertRaisesRegex( - ValueError, "at vmap out_axes", - lambda: api.vmap( - f_jvp, in_axes=(None, 0), out_axes=(None, 0))(x, txs)) - - def test_tree(self): - tree_sin = partial(jax.tree.map, jnp.sin) - tree_cos = partial(jax.tree.map, jnp.cos) - - x, xs = jnp.array(1.), jnp.arange(3) - x = (x, [x + 1, x + 2], [x + 3], x + 4) - xs = (xs, [xs + 1, xs + 2], [xs + 3], xs + 4) - in_batched_ref = jax.tree.map(lambda _: True, x) - - @jax.custom_batching.custom_vmap - def f(xs): return tree_sin(xs) - - @f.def_vmap - def rule(axis_size, in_batched, xs): - self.assertEqual(in_batched, [in_batched_ref]) - sz, = {z.shape[0] for z in jax.tree.leaves(xs)} - self.assertEqual(axis_size, sz) - return tree_cos(xs), in_batched[0] - - y = f(x) - self.assertAllClose(y, tree_sin(x)) - ys = api.vmap(f)(xs) - self.assertAllClose(ys, tree_cos(xs)) - - def test_tree_with_nones(self): - tree_sin = partial(jax.tree.map, jnp.sin) - tree_cos = partial(jax.tree.map, jnp.cos) - - x, xs = jnp.array(1.), jnp.arange(3) - x = (x, [x + 1, None], [x + 3], None) - xs = (xs, [xs + 1, None], [xs + 3], None) - in_batched_ref = jax.tree.map(lambda _: True, x) - - @jax.custom_batching.custom_vmap - def f(xs): return tree_sin(xs) - - @f.def_vmap - def rule(axis_size, in_batched, xs): - self.assertEqual(in_batched, [in_batched_ref]) - sz, = {z.shape[0] for z in jax.tree.leaves(xs)} - self.assertEqual(axis_size, sz) - return tree_cos(xs), in_batched[0] - - y = f(x) - self.assertAllClose(y, tree_sin(x)) - ys = api.vmap(f)(xs) - self.assertAllClose(ys, tree_cos(xs)) - - def test_jit(self): - @jax.custom_batching.custom_vmap - def f(x): return jnp.sin(x) - - @f.def_vmap - def rule(axis_size, in_batched, xs): - self.assertEqual(in_batched, [True]) - self.assertEqual(axis_size, xs.shape[0]) - return jnp.cos(xs), in_batched[0] - - x, xs = jnp.array(1.), jnp.arange(3) - self.assertAllClose(f(x), jit(f)(x)) - self.assertAllClose(jit(api.vmap(f))(xs), api.vmap(f)(xs)) - self.assertAllClose(api.vmap(jit(f))(xs), api.vmap(f)(xs)) - - def test_sequential_vmap_basic(self): - @jax.custom_batching.sequential_vmap - def f(x): - return x + 1. - - def vmap_ref(xs): - return lax.map(f, xs) - - xs = jnp.arange(3.) - jaxpr = api.make_jaxpr(api.vmap(f))(xs) - jaxpr_ref = api.make_jaxpr(vmap_ref)(xs) - - self.assertEqual(str(jaxpr), str(jaxpr_ref)) - - def test_sequential_vmap_nary_same_batching(self): - @jax.custom_batching.sequential_vmap - def f(x, y): - return x + y - - def vmap_ref(xs, ys): - return lax.map(lambda args: f(*args), (xs, ys)) - - xs, ys = jnp.arange(3.), 4. + jnp.arange(3.) - jaxpr = api.make_jaxpr(api.vmap(f))(xs, ys) - jaxpr_ref = api.make_jaxpr(vmap_ref)(xs, ys) - - self.assertEqual(str(jaxpr), str(jaxpr_ref)) - - def test_sequential_vmap_nary_mixed_batching(self): - @jax.custom_batching.sequential_vmap - def f(x, y): - return x + y - - def vmap_ref(xs, y): - return lax.map(lambda x: f(x, y), xs) - - xs, y = jnp.arange(3.), 4. - jaxpr = api.make_jaxpr(api.vmap(f, in_axes=(0, None)))(xs, y) - jaxpr_ref = api.make_jaxpr(vmap_ref)(xs, y) - - self.assertEqual(str(jaxpr), str(jaxpr_ref)) - - @parameterized.named_parameters( - ("1", 1), - ("8", 4), - ("12", 8), - ("16", 16), - ) - def test_batch_map_basic(self, batch_size: int): - def f(x): - self.assertEqual(x.shape, ()) - return x**2 - - x = np.arange(16) - y = jax.lax.map(f, x, batch_size=batch_size) - - np.testing.assert_array_equal(y, x**2) - - @parameterized.named_parameters( - ("1", 1), - ("8", 4), - ("12", 8), - ("16", 16), - ) - def test_batch_map_pytrees(self, batch_size: int): - f = lambda x: {'b': x['a'] ** 2} - inputs = {'a': np.arange(16)} - expected = np.arange(16) ** 2 - - outputs = jax.lax.map(f, inputs, batch_size=batch_size) - self.assertAllClose(outputs['b'], expected) - - outputs = jax.lax.map( - f, inputs, batch_size=batch_size - ) - self.assertAllClose(outputs['b'], expected) - - def test_batch_divides_axis(self): - def f(t): - x, a = t - self.assertEqual(x.shape, (4,)) - return (x + a)**2 - - x = jax.random.randint(jax.random.key(0), (16, 4), -10, 10) - a = jax.random.randint(jax.random.key(1), (16, 4), -10, 10) - - @jax.jit - def g(x, a): - return jax.lax.map(f, (x, a), batch_size=8) - - y = g(x, a) - - self.assertAllClose(y, (x + a)**2) - - def test_undefined_rule(self): - @jax.custom_batching.custom_vmap - def f(x): return jnp.sin(x) - - with self.assertRaisesRegex( - AttributeError, "No batching rule defined for custom_vmap function f"): - f(0.5) - - def test_kwargs(self): - @jax.custom_batching.custom_vmap - def f(x): return jnp.sin(x) - - @f.def_vmap - def rule(axis_size, in_batched, xs): - xs_batched, = in_batched - self.assertEqual(xs_batched, True) - self.assertEqual(axis_size, xs.shape[0]) - return jnp.cos(xs), xs_batched - - x, xs = jnp.array(1.), jnp.arange(3) - y = f(x=x) - self.assertAllClose(y, jnp.sin(x)) - ys = api.vmap(f)(x=xs) - self.assertAllClose(ys, jnp.cos(xs)) - - def test_partial_eval_raises(self): - @jax.custom_batching.custom_vmap - def f(x): - return jnp.sin(x) - - @f.def_vmap - def rule(axis_size, in_batched, xs): - del axis_size # unused - return jnp.cos(xs), in_batched[0] - - with self.assertRaisesRegex( - ValueError, - "Linearization failed to produce known values for all output primals", - ): - jax.grad(f)(0.5) - - def test_compose_custom_vjp(self): - @jax.custom_vjp - @jax.custom_batching.custom_vmap - def f(x, y): - return jnp.sin(x) * y - - @f.def_vmap - def f_vmap_rule(axis_size, in_batched, xs, ys): - return jnp.cos(xs) * ys, True - - def f_fwd(x, y): - return f(x, y), (jnp.cos(x), jnp.sin(x), y) - - def f_bwd(res, g): - cos_x, sin_x, y = res - return (cos_x * g * y, sin_x * g) - - f.defvjp(f_fwd, f_bwd) - - xs = jnp.linspace(0, 1, 5) - ys = jnp.linspace(-0.1, 0.1, 5) - self.assertAllClose(jax.vmap(f)(xs, ys), jnp.cos(xs) * ys) - jax.grad(f)(xs[0], ys[0]) # Doesn't crash. - - def test_compose_custom_vjp_bwd_rule(self): - # This tests the case where both the forward and backward rules are wrapped - # in custom_vmap. - @jax.custom_batching.sequential_vmap - def fun_fwd(x, y): - return jnp.sin(x) * y, (x, y) - - @jax.custom_batching.sequential_vmap - def fun_bwd(res, ct): - x, y = res - return x * ct, y * ct - - fun = jax.custom_vjp(lambda *args: fun_fwd(*args)[0]) - fun.defvjp(fun_fwd, fun_bwd) - - xs = jnp.linspace(0, 1, 5) - y = jnp.array(0.5, dtype=xs.dtype) - f = jax.vmap(jax.jit(fun), in_axes=(0, None)) - out, f_vjp = jax.vjp(f, xs, y) - f_vjp(out) # Doesn't crash. - - def test_resolve_kwargs_error_message(self): - @jax.custom_batching.custom_vmap - def f(x, y, *, z=None): - return jnp.sin(x) * y - - @f.def_vmap - def f_vmap_rule(axis_size, in_batched, xs, ys): - self.fail("should not be executed") - - with self.assertRaisesRegex( - TypeError, - r"The input arguments to the custom_vmap-decorated function f(.*)\n" - r"missing a required argument: 'y'" - ): - f(0.5) - - with self.assertRaisesRegex( - TypeError, - r"The input arguments to the custom_vmap-decorated function f(.*)\n" - "The following keyword arguments could not be resolved to positions: z" - ): - f(0.5, 0.1, z=1.0) - - -class CustomApiTest(jtu.JaxTestCase): - """Test interactions among the custom_{vmap,jvp,vjp,transpose,*} APIs""" - - def test_method_forwarding(self): - @jax.custom_batching.custom_vmap - @jax.custom_jvp - @jax.custom_transpose.custom_transpose - def f(x): return 2. * x - - # none of these err: - @f.def_vmap - def f_batch(sz, b, xs): return 2. * xs - @f.defjvp - def f_jvp(x, tx): return 2. * x, 2. * tx - @f.def_transpose - def f_transpose(x): return 2. * x - - def test_def_method_forwarding_all_permutations(self): - for wraps in it.permutations([ - jax.custom_jvp, jax.custom_transpose.custom_transpose, jax.custom_batching.custom_vmap]): - f = lambda x: x + 1. - for wrap in wraps: - f = wrap(f) - for methods in it.permutations(['defjvp', 'def_vmap', 'def_transpose']): - for method in methods: - self.assertIsInstance(getattr(f, method), Callable) - - for decorators in it.permutations([ - jax.custom_vjp, jax.custom_transpose.custom_transpose, jax.custom_batching.custom_vmap]): - f = lambda x: x + 1. - for decorator in decorators: - f = decorator(f) - for methods in it.permutations(['defvjp', 'def_vmap', 'def_transpose']): - for method in methods: - self.assertIsInstance(getattr(f, method), Callable) - - -class BufferDonationTest(jtu.BufferDonationTestCase): - - @jtu.device_supports_buffer_donation() - def test_pmap_donate_argnums_invalidates_input(self): - move = api.pmap(lambda x: x + x - x, donate_argnums=0) - n = jax.local_device_count() - x = api.pmap(lambda x: x)(jnp.ones([n])) - y = move(x) - self.assertDeleted(x) - np.testing.assert_allclose(y, [1.] * n) - - @jtu.device_supports_buffer_donation() - def test_pmap_nested_donate_ignored(self): - pmap_fun = jit(lambda x: api.pmap(lambda y: y ** 2, donate_argnums=0)(x)) - a = api.pmap(lambda x: x)(jnp.array([1])) - - # NOTE(mattjj): stopped raising error here and instead just ignored - # with self.assertRaisesRegex(ValueError, "nested.*not supported"): - # pmap_fun(a) - - pmap_fun(a) # doesn't crash - - -class NamedCallTest(jtu.JaxTestCase): - - def test_non_jaxtype_arg(self): - # For the test to fail without the invalid JaxType filter we need to pass - # in a valid JaxType that forces the invalid Jaxtype to be raised to an - # abstract value. - def f(not_a_jaxtype, a_jaxtype): - # then Jax needs to try and evaluate the abstractified non-JaxType - if not_a_jaxtype: - return a_jaxtype - return 0 - - f = api.named_call(f, name="test") - out = jax.jit(f, static_argnums=(0,))("not a Jaxtype", 1) - self.assertEqual(out, 1) - - @parameterized.parameters(jax.jit, jax.grad, jax.vmap, jax.remat) - def test_jax_transforms(self, transform): - f = jnp.sum - x = jnp.array([1.]) - - unnamed_out = transform(f)(x) - named_out = transform(api.named_call(f, name="test"))(x) - - self.assertEqual(unnamed_out, named_out) - - def test_static_argnums(self): - f = api.named_call(lambda x, y: y if x else None, name="test") - f = jax.jit(f, static_argnums=(0,)) - out = f(True, 5) - self.assertEqual(out, 5) - - def test_partial_eval(self): - f = api.named_call(lambda x, y: y if x else None, name="test") - f = jax.jit(functools.partial(f, True)) - out = f(5) - self.assertEqual(out, 5) - - @parameterized.parameters( - [dict(func=func, jit=jit) - for func in ['identity_trivial', 'identity', 'closure_trivial', 'closure', - 'asarray', 'device_put'] - for jit in jtu.JIT_IMPLEMENTATION - if not (jit._name == "noop" and func in ('identity', 'identity_trivial')) - ], - ) - def test_integer_overflow(self, jit, func): - funcdict = { - 'identity_trivial': lambda x: x, # may hit trivial dispatch path - 'identity': lambda x: x + 0, - 'closure_trivial': lambda x: jax.jit(lambda: x)(), - 'closure': lambda x: jax.jit(lambda: x + 0)(), - 'asarray': lambda x: jnp.asarray(x), # add lambdas so no cross-test cache - 'device_put': lambda x: api.device_put(x), - } - - f = jit(funcdict[func]) - - int_dtype = dtypes.canonicalize_dtype(jnp.int64) - int_max = np.iinfo(int_dtype).max - int_min = np.iinfo(int_dtype).min - - # check before any jit cache entries - self.assertRaises(OverflowError, f, int_max + 1) - self.assertRaises(OverflowError, f, int_min - 1) - - self.assertEqual(f(int_max).dtype, int_dtype) - self.assertEqual(f(int_min).dtype, int_dtype) - self.assertAllClose(f(int_max), int_max) - self.assertAllClose(f(int_min), int_min) - - # check after any cache entries - self.assertRaises(OverflowError, f, int_max + 1) - self.assertRaises(OverflowError, f, int_min - 1) - if func in ('trivial', 'identity'): - self.assertRaisesRegex( - OverflowError, 'An overflow.*whose argument path is x.', f, - int_max + 1) - - -class BackendsTest(jtu.JaxTestCase): - - @unittest.skipIf(not sys.executable, "test requires sys.executable") - @jtu.run_on_devices("cpu") - def test_no_backend_warning_on_cpu_if_platform_specified(self): - warning_not_expected = ( - "import jax; " - "jax.config.update('jax_platform_name', 'cpu'); " - "jax.numpy.arange(10)") - - result = subprocess.run([sys.executable, '-c', warning_not_expected], - check=True, capture_output=True) - assert "may be present" not in result.stderr.decode() - - -class CleanupTest(jtu.JaxTestCase): - def test_call_wrapped_second_phase_cleanup(self): - try: - jax.vmap(lambda x: x, out_axes=None)(jnp.arange(3)) - except: - assert core.trace_state_clean() # this is the hard one - assert core.trace_state_clean() +class CleanupTest(jtu.JaxTestCase): + def test_call_wrapped_second_phase_cleanup(self): + try: + jax.vmap(lambda x: x, out_axes=None)(jnp.arange(3)) + except: + assert core.trace_state_clean() # this is the hard one + assert core.trace_state_clean() class EnvironmentInfoTest(jtu.JaxTestCase): @@ -11504,5 +7752,216 @@ def wsc_as_noop(ctx, operand, *args, **kwargs): self.assertNotIn("stablehlo.custom_call @Sharding", lowered_ir) +class InputSavedVJPTest(jtu.JaxTestCase): + + def test_basic(self): + def f(x, y): + return x * y + + primals = [2., 3.] + y, f_vjp = jax.vjp(f, *primals) + f_vjp.args_res = [None, None] + y_grad = 1. + f_vjp.args_res = primals + arg_cts = f_vjp(1.) + self.assertAllClose(y, 6.) + self.assertAllClose(arg_cts, (3., 2.)) + + def test_basic_pass_through_jit(self): + def f(x, y): + return x * y + + @jax.jit + def g(): + primals = 2., 3. + y, f_vjp = jax.vjp(f, *primals) + f_vjp.args_res = [None, None] + return y, f_vjp + + @jax.jit + def h(f_vjp): + f_vjp.args_res = [2., 3.] + return f_vjp(1.) + + y, f_vjp = g() + arg_cts = h(f_vjp) + self.assertAllClose(y, 6.) + self.assertAllClose(arg_cts, (3., 2.)) + + def test_basic_unused_vjp3(self): + f = jnp.sin + primals = 3., + y, f_vjp = api.vjp(f, *primals) + x_ct, = f_vjp(1.) + self.assertAllClose(y, jnp.sin(3.)) + self.assertAllClose(x_ct, jnp.cos(3.)) + self.assertIsInstance(f_vjp.args_res[0], api.NotNeeded) # can check if unused + + def test_basic_opaque_vjp3(self): + f = jnp.sin + primals = 3., + _, f_vjp = api.vjp(f, *primals) + self.assertTrue(f_vjp.opaque_residuals) # can detect if opaque res are used + + def test_basic_pytree_error(self): + def f(x): + return [x['hi'] * x['bye']] + + y, f_vjp = jax.vjp(f, {'hi': 2., 'bye': 3.}) + f_vjp.args_res = [None] + y_grad = [1.] + f_vjp.args_res = [{'hi': 2., 'bye': 3.}] + arg_ct, = f_vjp(y_grad) + self.assertAllClose(y, [6.]) + self.assertAllClose(arg_ct, {'hi': 3., 'bye': 2.}) + + # TODO(mattjj): Raise an error message. + # with self.assertRaisesRegex(ValueError, "but the structures differ"): + # f_vjp.args_res = [{'hi': 2.}] + # f_vjp([1.]) + + def test_fsdp_error(self): + # see https://github.com/jax-ml/jax/pull/27017 for why this is called "fsdp" + def f2(x, w): + x = 1. * x + x = x @ w + x = 2. * x + return x + + x = jnp.ones((3, 4)) + w = jnp.ones((4, 4)) + y, f2_vjp = jax.vjp(f2, x, w) + f2_vjp.args_res[1] = None + y_grad = jnp.ones((2, 4)) + f2_vjp.args_res[1] = w + with self.assertRaisesRegex(ValueError, "unexpected JAX type"): + f2_vjp(y_grad) + + def test_fsdp_vjp3(self): + # see https://github.com/jax-ml/jax/pull/27017 for why this is called "fsdp" + def f2(x, w): + x = 1. * x + x = x @ w + x = 2. * x + return x + + x = jnp.ones((3, 4)) + w = jnp.ones((4, 4)) + y, f2_vjp = api.vjp(f2, x, w) + f2_vjp.args_res[1] = None + y_grad = jnp.ones_like(y) + f2_vjp.args_res[1] = w + x_grad, w_grad = f2_vjp(y_grad) + self.assertAllClose(x_grad, 2. * y_grad @ w.T) + self.assertAllClose(w_grad, 2. * x.T @ y_grad) + self.assertAllClose(w_grad, 2. * x.T @ y_grad) + + def test_doesnt_leak_symbolic_zeros(self): + _, vjp = jax.vjp(lambda x: 1., 3.14) + ans, = vjp(1.0) + self.assertIsInstance(ans, jax.Array) + + +class TracebackTest(jtu.JaxTestCase): + # These tests are to catch regressions in Python traceback sizes. Our + # second-order APIs can be nested arbitrarily and if each one adds a dozen + # stack frames then we can end up with very deep tracebacks. We expect the + # particular `expected_depth` constants in these tests to change from time to + # time. We just want to know when it happens and what caused it. + + def cur_depth(self): + return len(inspect.stack()) + + def test_traceback_test(self): + expected_depth_foo = 1 + expected_depth_bar = 2 + init_depth = self.cur_depth() + def foo(): + self.assertExpectedDepth(init_depth, expected_depth_foo) + def bar(): + self.assertExpectedDepth(init_depth, expected_depth_bar) + bar() + + foo() + + def assertExpectedDepth(self, init_depth, expected_depth): + # `- 1` is for the `assertExpectedDepth` stack frame itself + self.assertEqual(self.cur_depth() - init_depth - 1, expected_depth) + + def test_scan_traceback(self): + expected_depth = 5 + init_depth = self.cur_depth() + + def f(c, x): + self.assertExpectedDepth(init_depth, expected_depth) + return (c, ()) + + jax.lax.scan(f, 0, jnp.arange(4)) + + def test_cond_traceback(self): + if sys.version_info < (3, 13): + # Fails because 3.11 adds an extra stack frame due to a list comprehension + self.skipTest("Expected failure.") + expected_depth = 4 + init_depth = self.cur_depth() + + def f(): + self.assertExpectedDepth(init_depth, expected_depth) + + lax.cond(True, f, lambda: None) + + def test_jit_traceback(self): + # TODO(dougalm): shoud be able to get this down to 2 or 3 + expected_depth = 5 + init_depth = self.cur_depth() + @jit + def foo(x): + self.assertExpectedDepth(init_depth, expected_depth) + return x + foo(1) + + def test_grad_traceback(self): + # TODO(dougalm): improve this + expected_depth = 11 + init_depth = self.cur_depth() + + def foo(x): + self.assertExpectedDepth(init_depth, expected_depth) + return x + + grad(foo)(1.0) + + def test_vmap_traceback(self): + # TODO(dougalm): improve this + expected_depth = 7 + init_depth = self.cur_depth() + + def foo(x): + self.assertExpectedDepth(init_depth, expected_depth) + return x + + jax.vmap(foo)(np.arange(3)) + + def test_custom_vjp_traceback(self): + # TODO(dougalm): improve this + expected_depth_f = 10 + expected_depth_f_fwd = 19 + expected_depth_f_rev = 12 + init_depth = self.cur_depth() + @jax.custom_vjp + def f(x): + self.assertExpectedDepth(init_depth, expected_depth_f) + return x + def f_fwd(x): + self.assertExpectedDepth(init_depth, expected_depth_f_fwd) + return x, None + def f_rev(_, g): + self.assertExpectedDepth(init_depth, expected_depth_f_rev) + return (g,) + f.defvjp(f_fwd, f_rev) + + f(1.0) + grad(f)(1.0) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/array_api_skips.txt b/tests/array_api_skips.txt index 2f8d4d1c666f..98636ec9d582 100644 --- a/tests/array_api_skips.txt +++ b/tests/array_api_skips.txt @@ -2,6 +2,7 @@ # finfo return type misalignment (https://github.com/data-apis/array-api/issues/405) array_api_tests/test_data_type_functions.py::test_finfo[float32] +array_api_tests/test_data_type_functions.py::test_finfo[complex64] # Test suite attempts in-place mutation: array_api_tests/test_array_object.py::test_setitem @@ -10,6 +11,26 @@ array_api_tests/test_creation_functions.py::test_asarray_arrays # Returns wrong zero sign array_api_tests/test_special_cases.py::test_unary[sign((x_i is -0 or x_i == +0)) -> 0] +array_api_tests/test_special_cases.py::test_iop[__imod__(x1_i is +0 and x2_i < 0) -> -0] +array_api_tests/test_special_cases.py::test_iop[__imod__(x1_i is -0 and x2_i > 0) -> +0] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -0 and x2_i > 0) -> -0] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -0 and x2_i < 0) -> +0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -0 and x2_i > 0) -> -0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -0 and x2_i < 0) -> +0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is -0 and x2_i > 0) -> +0] +array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is -0 and x2_i > 0) -> +0] +array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> +0] +array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is +0 and x2_i < 0) -> -0] +array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is +0 and x2_i < 0) -> -0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> +0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -0 and x2_i > 0) -> -0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -0 and x2_i < 0) -> +0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> +0] + +# Array API expects default value for axis argument. +array_api_tests/test_indexing_functions.py::test_take_along_axis # Returns int32 when int64 is expected array_api_tests/test_searching_functions.py::test_searchsorted @@ -19,3 +40,44 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_clip # JAX raises a ValueError rather than the expected IndexError for out-of-bound axis array_api_tests/test_manipulation_functions.py::test_expand_dims + +# Doesn't promote to uint64 +array_api_tests/test_statistical_functions.py::test_cumulative_prod + +# TODO(jakevdp): fix the following failures: + +# Returns NaN rather than inf +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i < 0 and x2_i is +0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i > 0 and x2_i is +0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i < 0 and x2_i is -0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i > 0 and x2_i is -0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i < 0 and x2_i is +0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i < 0 and x2_i is -0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i > 0 and x2_i is +0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i > 0 and x2_i is -0) -> -infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i > 0 and x2_i is +0) -> +infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i > 0 and x2_i is -0) -> -infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i < 0 and x2_i is -0) -> +infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i < 0 and x2_i is +0) -> -infinity] + +# Returns -1.0 rather than 0.0 +array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] + +# TODO(b/440163737): tanh(inf) not returning 1.0 +array_api_tests/test_special_cases.py::test_unary[tanh(x_i is -infinity) -> -1] +array_api_tests/test_special_cases.py::test_unary[tanh(x_i is +infinity) -> +1] diff --git a/tests/array_api_test.py b/tests/array_api_test.py index 250eeb810872..1c624f78a6bc 100644 --- a/tests/array_api_test.py +++ b/tests/array_api_test.py @@ -25,7 +25,8 @@ import jax import jax.numpy as jnp from jax._src import config, test_util as jtu -from jax._src.dtypes import _default_types, canonicalize_dtype +from jax._src.dtypes import default_types +from jax._src import xla_bridge as xb ARRAY_API_NAMESPACE = jnp @@ -270,19 +271,23 @@ def setUp(self): def build_dtype_dict(self, dtypes): out = {} for name in dtypes: - out[name] = jnp.dtype(name) + out[name] = jnp.dtype(name) return out def test_capabilities_info(self): capabilities = self.info.capabilities() - assert capabilities["boolean indexing"] + assert not capabilities["boolean indexing"] assert not capabilities["data-dependent shapes"] + assert capabilities["max dimensions"] == 64 def test_default_device_info(self): assert self.info.default_device() is None def test_devices_info(self): - assert self.info.devices() == jax.devices() + devices = set(self.info.devices()) + assert None in devices + for backend in xb.backends(): + assert devices.issuperset(jax.devices(backend)) def test_default_dtypes_info(self): _default_dtypes = { @@ -292,9 +297,8 @@ def test_default_dtypes_info(self): "indexing": "i", } target_dict = { - dtype_name: canonicalize_dtype( - _default_types.get(kind) - ) for dtype_name, kind in _default_dtypes.items() + dtype_name: default_types.get(kind)() + for dtype_name, kind in _default_dtypes.items() } assert self.info.default_dtypes() == target_dict diff --git a/tests/array_extensibility_test.py b/tests/array_extensibility_test.py new file mode 100644 index 000000000000..76e221efe64e --- /dev/null +++ b/tests/array_extensibility_test.py @@ -0,0 +1,624 @@ +# Copyright 2018 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +from typing import Any, NamedTuple +from collections.abc import Callable +import dataclasses + +from absl.testing import absltest +from absl.testing import parameterized +import numpy as np + +import jax +import jax.numpy as jnp +from jax.typing import ArrayLike +from jax._src import config +from jax._src import test_util as jtu + + +config.parse_flags_with_absl() + + +@jax.tree_util.register_dataclass +@dataclasses.dataclass +class JaxArrayWrapper: + """Class that provides a __jax_array__ method.""" + x: ArrayLike + + def __jax_array__(self) -> jax.Array: + return jnp.asarray(self.x) + + +@jax.tree_util.register_dataclass +@dataclasses.dataclass +class NumpyArrayWrapper: + """Pytree that provides an __array__ method.""" + x: ArrayLike + + def __array__(self, dtype=None, copy=None) -> jax.Array: + return np.asarray(self.x, dtype=dtype, copy=copy) + + +@jax.tree_util.register_dataclass +@dataclasses.dataclass +class JaxArrayWrapperWithErroringNumpyArray: + """Pytree that provides an __array__ method which fails.""" + x: ArrayLike + + def __jax_array__(self) -> jax.Array: + return jnp.asarray(self.x) + + def __array__(self, dtype=None, copy=None) -> jax.Array: + raise ValueError("__array__ method should not be called.") + + +class DuckTypedArrayWithErroringJaxArray: + """Duck-typed array that provides a __jax_array__ method which fails.""" + shape = (2, 3) + dtype = np.dtype('float32') + + def __jax_array__(self): + raise ValueError("jax array was called.") + + +class NumPyAPI(NamedTuple): + fun: Callable[..., Any] + args: list[jax.ShapeDtypeStruct] + kwargs: dict[str, Any] + skip_on_devices: list[str] | None + + def name(self): + return self.fun.__name__ + + def make_args(self, rng): + rng = jtu.rand_default(rng) + return jax.tree.map(lambda arg: rng(arg.shape, arg.dtype), self.args) + + def with_skip_on_devices(self, disabled_devices: list[str]) -> 'NumPyAPI': + return self._replace(skip_on_devices=disabled_devices) + + @classmethod + def sig(cls, fun: Callable[..., Any], *args: Any, **kwargs: Any) -> 'NumPyAPI': + return cls(fun, args, kwargs, None) + + +class ShapeDtype: + """Shortcut for specifying ShapeDtypeStruct.""" + def __init__(self, dtype): + self.dtype = jax.dtypes.canonicalize_dtype(dtype) + def __getitem__(self, shape) -> jax.ShapeDtypeStruct: + if isinstance(shape, int): + shape = (shape,) + return jax.ShapeDtypeStruct(shape, self.dtype) + +Bool = ShapeDtype(bool) +Int = ShapeDtype(int) +UInt = ShapeDtype('uint32') +Uint8 = ShapeDtype('uint8') +Float = ShapeDtype(float) +Complex = ShapeDtype(complex) + + +# NumPy namespace objects skipped in the enumeration below, mainly because +# they are not functions or do not take arrays as positional arguments. +SKIPPED_APIS = [ + 'apply_along_axis', + 'apply_over_axes', + 'arange', + 'array_str', + 'array_repr', + 'astype', + 'bartlett', + 'bfloat16', + 'blackman', + 'block', + 'bool', + 'bool_', + 'broadcast_shapes', + 'c_', + 'can_cast', + 'cdouble', + 'character', + 'complex128', + 'complex64', + 'complex_', + 'complexfloating', + 'csingle', + 'diag_indices', + 'double', + 'dtype', + 'e', + 'einsum', + 'einsum_path', + 'euler_gamma', + 'empty', + 'eye', + 'finfo', + 'flexible', + 'float_', + 'float16', + 'float32', + 'float4_e2m1fn', + 'float64', + 'float8_e3m4', + 'float8_e4m3', + 'float8_e4m3b11fnuz', + 'float8_e4m3fn', + 'float8_e4m3fnuz', + 'float8_e5m2', + 'float8_e5m2fnuz', + 'float8_e8m0fnu', + 'floating', + 'from_dlpack', + 'frombuffer', + 'fromfile', + 'fromfunction', + 'fromiter', + 'frompyfunc', + 'fromstring', + 'full', + 'generic', + 'geomspace', + 'get_printoptions', + 'gradient', + 'hamming', + 'hanning', + 'identity', + 'iinfo', + 'index_exp', + 'indices', + 'inexact', + 'inf', + 'int16', + 'int2', + 'int32', + 'int4', + 'int64', + 'int8', + 'int_', + 'integer', + 'isdtype', + 'issubdtype' + 'iterable' + 'kaiser' + 'kron' + 'ix_', + 'linalg', + 'linspace', + 'load', + 'logspace', + 'mask_indices', + 'mgrid', + 'nan', + 'ndarray', + 'newaxis', + 'number', + 'object_', + 'ogrid', + 'ones', + 'pi', + 'printoptions', + 'promote_types' + 'r_', + 'result_type', + 's_', + 'save', + 'savez', + 'set_printoptions', + 'signedinteger', + 'single', + 'tri', + 'tril_indices', + 'triu_indices', + 'ufunc', + 'uint', + 'uint16', + 'uint2', + 'uint32', + 'uint4', + 'uint64', + 'uint8', + 'unsignedinteger', + 'vectorize', + 'zeros', +] + +# TODO(jakevdp): commented APIs are ones which do not yet support +# __jax_array__ on inputs. We should fix these! +NUMPY_APIS = [ + NumPyAPI.sig(jnp.abs, Float[5]), + NumPyAPI.sig(jnp.absolute, Float[5]), + NumPyAPI.sig(jnp.acos, Float[5]), + NumPyAPI.sig(jnp.acosh, Float[5]), + NumPyAPI.sig(jnp.add, Float[5], Float[5]), + NumPyAPI.sig(jnp.all, Bool[5]), + NumPyAPI.sig(jnp.allclose, Float[5], Float[5]), + NumPyAPI.sig(jnp.amax, Float[5]), + NumPyAPI.sig(jnp.amin, Float[5]), + NumPyAPI.sig(jnp.angle, Float[5]), + NumPyAPI.sig(jnp.any, Float[5]), + NumPyAPI.sig(jnp.append, Float[10], Float[()]), + NumPyAPI.sig(jnp.arccos, Float[5]), + NumPyAPI.sig(jnp.arccosh, Float[5]), + NumPyAPI.sig(jnp.arcsin, Float[5]), + NumPyAPI.sig(jnp.arcsinh, Float[5]), + NumPyAPI.sig(jnp.arctan, Float[5]), + NumPyAPI.sig(jnp.arctan2, Float[5], Float[5]), + NumPyAPI.sig(jnp.arctanh, Float[5]), + NumPyAPI.sig(jnp.argmax, Float[10]), + NumPyAPI.sig(jnp.argmin, Float[10]), + NumPyAPI.sig(jnp.argpartition, Float[10], kth=5), + NumPyAPI.sig(jnp.argsort, Float[10]), + NumPyAPI.sig(jnp.argwhere, Float[10]), + NumPyAPI.sig(jnp.around, Float[5]), + NumPyAPI.sig(jnp.array, Float[5]), + NumPyAPI.sig(jnp.array_equal, Float[5], Float[5]), + NumPyAPI.sig(jnp.array_equiv, Float[5], Float[5]), + NumPyAPI.sig(jnp.array_split, Float[9], indices_or_sections=3), + NumPyAPI.sig(jnp.asarray, Float[5]), + NumPyAPI.sig(jnp.asin, Float[5]), + NumPyAPI.sig(jnp.asinh, Float[5]), + NumPyAPI.sig(jnp.atan, Float[5]), + NumPyAPI.sig(jnp.atan2, Float[5], Float[5]), + NumPyAPI.sig(jnp.atanh, Float[5]), + NumPyAPI.sig(jnp.atleast_1d, Float[5]), + NumPyAPI.sig(jnp.atleast_2d, Float[5]), + NumPyAPI.sig(jnp.atleast_3d, Float[5]), + NumPyAPI.sig(jnp.average, Float[10]), + NumPyAPI.sig(jnp.bincount, Int[10]), + NumPyAPI.sig(jnp.bitwise_and, Int[5], Int[5]), + NumPyAPI.sig(jnp.bitwise_count, Int[5]), + NumPyAPI.sig(jnp.bitwise_invert, Int[5]), + NumPyAPI.sig(jnp.bitwise_left_shift, Int[5], Int[5]), + NumPyAPI.sig(jnp.bitwise_not, Int[5]), + NumPyAPI.sig(jnp.bitwise_or, Int[5], Int[5]), + NumPyAPI.sig(jnp.bitwise_right_shift, Int[5], Int[5]), + NumPyAPI.sig(jnp.bitwise_xor, Int[5], Int[5]), + NumPyAPI.sig(jnp.broadcast_arrays, Float[5]), + NumPyAPI.sig(jnp.broadcast_to, Float[()], shape=(10,)), + NumPyAPI.sig(jnp.cbrt, Float[5]), + NumPyAPI.sig(jnp.ceil, Float[5]), + NumPyAPI.sig(jnp.choose, Int[3], [Float[3], Float[3], Float[3]], mode='clip'), + NumPyAPI.sig(jnp.clip, Float[5]), + NumPyAPI.sig(jnp.column_stack, [Float[5], Float[5], Float[5]]), + NumPyAPI.sig(jnp.compress, Float[10], Bool[10]), + NumPyAPI.sig(jnp.concat, [Float[5], Float[5]]), + NumPyAPI.sig(jnp.concatenate, [Float[5], Float[5]]), + NumPyAPI.sig(jnp.conj, Float[5]), + NumPyAPI.sig(jnp.conjugate, Float[5]), + NumPyAPI.sig(jnp.convolve, Float[7], Float[3]), + NumPyAPI.sig(jnp.copy, Float[5]), + NumPyAPI.sig(jnp.copysign, Float[5], Float[5]), + NumPyAPI.sig(jnp.corrcoef, Float[7], Float[7]), + NumPyAPI.sig(jnp.correlate, Float[7], Float[3]), + NumPyAPI.sig(jnp.cos, Float[5]), + NumPyAPI.sig(jnp.cosh, Float[5]), + NumPyAPI.sig(jnp.count_nonzero, Float[10]), + NumPyAPI.sig(jnp.cov, Float[10]), + NumPyAPI.sig(jnp.cross, Float[3], Float[3]), + NumPyAPI.sig(jnp.cumprod, Float[5]), + NumPyAPI.sig(jnp.cumsum, Float[5]), + NumPyAPI.sig(jnp.cumulative_prod, Float[5]), + NumPyAPI.sig(jnp.cumulative_sum, Float[5]), + NumPyAPI.sig(jnp.deg2rad, Float[5]), + NumPyAPI.sig(jnp.degrees, Float[5]), + NumPyAPI.sig(jnp.delete, Float[5], Int[()]), + NumPyAPI.sig(jnp.diag, Float[5]), + NumPyAPI.sig(jnp.diag_indices_from, Float[5, 5]), + NumPyAPI.sig(jnp.diagflat, Float[5]), + NumPyAPI.sig(jnp.diagonal, Float[5, 5]), + NumPyAPI.sig(jnp.diff, Float[5]), + NumPyAPI.sig(jnp.digitize, Float[5], Float[5]), + NumPyAPI.sig(jnp.divide, Float[5], Float[5]), + NumPyAPI.sig(jnp.divmod, Float[5], Float[5]), + NumPyAPI.sig(jnp.dot, Float[5], Float[5]), + NumPyAPI.sig(jnp.dsplit, Float[3, 5, 6], indices_or_sections=2), + NumPyAPI.sig(jnp.dstack, [Float[3, 5, 1], Float[3, 5, 3]]), + NumPyAPI.sig(jnp.ediff1d, Float[5]), + NumPyAPI.sig(jnp.empty_like, Float[5]), + NumPyAPI.sig(jnp.equal, Float[5], Float[5]), + NumPyAPI.sig(jnp.exp, Float[5]), + NumPyAPI.sig(jnp.exp2, Float[5]), + NumPyAPI.sig(jnp.expand_dims, Float[5], axis=0), + NumPyAPI.sig(jnp.expm1, Float[5]), + NumPyAPI.sig(jnp.extract, Bool[5], Float[5]), + NumPyAPI.sig(jnp.fabs, Float[5]), + NumPyAPI.sig(jnp.fft.fft, Float[5]), + NumPyAPI.sig(jnp.fft.fft2, Float[5, 5]), + NumPyAPI.sig(jnp.fft.ifft, Float[5]), + NumPyAPI.sig(jnp.fft.ifft2, Float[5, 5]), + NumPyAPI.sig(jnp.fill_diagonal, Float[5, 5], Float[()], inplace=False), + NumPyAPI.sig(jnp.flatnonzero, Float[5]), + NumPyAPI.sig(jnp.flip, Float[5]), + NumPyAPI.sig(jnp.fliplr, Float[5, 5]), + NumPyAPI.sig(jnp.flipud, Float[5, 5]), + NumPyAPI.sig(jnp.float_power, Float[5], Float[5]), + NumPyAPI.sig(jnp.floor, Float[5]), + NumPyAPI.sig(jnp.floor_divide, Float[5], Float[5]), + NumPyAPI.sig(jnp.fmax, Float[5], Float[5]), + NumPyAPI.sig(jnp.fmin, Float[5], Float[5]), + NumPyAPI.sig(jnp.fmod, Float[5], Float[5]), + NumPyAPI.sig(jnp.frexp, Float[5]), + NumPyAPI.sig(jnp.full_like, Float[5], Float[()]), + NumPyAPI.sig(jnp.gcd, Int[5], Int[5]), + NumPyAPI.sig(jnp.greater, Float[5], Float[5]), + NumPyAPI.sig(jnp.greater_equal, Float[5], Float[5]), + NumPyAPI.sig(jnp.heaviside, Float[5], Float[5]), + NumPyAPI.sig(jnp.histogram, Float[5]), + NumPyAPI.sig(jnp.histogram2d, Float[5], Float[5]), + NumPyAPI.sig(jnp.histogram_bin_edges, Float[5]), + NumPyAPI.sig(jnp.histogramdd, Float[5, 3]), + NumPyAPI.sig(jnp.hsplit, Float[3, 6], indices_or_sections=2), + NumPyAPI.sig(jnp.hstack, (Float[5], Float[5])), + NumPyAPI.sig(jnp.hypot, Float[5], Float[5]), + NumPyAPI.sig(jnp.i0, Float[5]), + NumPyAPI.sig(jnp.imag, Complex[5]), + NumPyAPI.sig(jnp.inner, Float[5], Float[5]), + NumPyAPI.sig(jnp.insert, Float[5], Int[()], Float[2]), + NumPyAPI.sig(jnp.interp, Float[10], Float[5], Float[5]), + NumPyAPI.sig(jnp.intersect1d, Int[5], Int[5]), + NumPyAPI.sig(jnp.invert, Int[5]), + NumPyAPI.sig(jnp.isclose, Float[5], Float[5]), + NumPyAPI.sig(jnp.iscomplex, Float[5]), + NumPyAPI.sig(jnp.iscomplexobj, Complex[5]), + NumPyAPI.sig(jnp.isfinite, Float[5]), + NumPyAPI.sig(jnp.isin, Int[5], Int[10]), + NumPyAPI.sig(jnp.isinf, Float[5]), + NumPyAPI.sig(jnp.isnan, Float[5]), + NumPyAPI.sig(jnp.isneginf, Float[5]), + NumPyAPI.sig(jnp.isposinf, Float[5]), + NumPyAPI.sig(jnp.isreal, Float[5]), + NumPyAPI.sig(jnp.isrealobj, Float[5]), + NumPyAPI.sig(jnp.isscalar, Float[()]), + NumPyAPI.sig(jnp.lcm, Int[5], Int[5]), + NumPyAPI.sig(jnp.ldexp, Float[5], Int[5]), + NumPyAPI.sig(jnp.left_shift, Int[5], Int[5]), + NumPyAPI.sig(jnp.less, Float[5], Float[5]), + NumPyAPI.sig(jnp.less_equal, Float[5], Float[5]), + NumPyAPI.sig(jnp.lexsort, [Float[5], Float[5]]), + NumPyAPI.sig(jnp.log, Float[5]), + NumPyAPI.sig(jnp.log10, Float[5]), + NumPyAPI.sig(jnp.log1p, Float[5]), + NumPyAPI.sig(jnp.log2, Float[5]), + NumPyAPI.sig(jnp.logaddexp, Float[5], Float[5]), + NumPyAPI.sig(jnp.logaddexp2, Float[5], Float[5]), + NumPyAPI.sig(jnp.logical_and, Int[5], Int[5]), + NumPyAPI.sig(jnp.logical_not, Int[5]), + NumPyAPI.sig(jnp.logical_or, Int[5], Int[5]), + NumPyAPI.sig(jnp.logical_xor, Int[5], Int[5]), + NumPyAPI.sig(jnp.matmul, Float[5, 5], Float[5]), + NumPyAPI.sig(jnp.matrix_transpose, Float[5, 6]), + NumPyAPI.sig(jnp.matvec, Float[5, 5], Float[5]), + NumPyAPI.sig(jnp.max, Float[5]), + NumPyAPI.sig(jnp.maximum, Float[5], Float[5]), + NumPyAPI.sig(jnp.mean, Float[5]), + NumPyAPI.sig(jnp.median, Float[5]), + NumPyAPI.sig(jnp.meshgrid, Float[5], Float[5]), + NumPyAPI.sig(jnp.min, Float[5]), + NumPyAPI.sig(jnp.minimum, Float[5], Float[5]), + NumPyAPI.sig(jnp.mod, Float[5], Float[5]), + NumPyAPI.sig(jnp.modf, Float[5]), + NumPyAPI.sig(jnp.moveaxis, Float[5, 3], source=0, destination=1), + NumPyAPI.sig(jnp.multiply, Float[5], Float[5]), + NumPyAPI.sig(jnp.nan_to_num, Float[5]), + NumPyAPI.sig(jnp.nanargmax, Float[5]), + NumPyAPI.sig(jnp.nanargmin, Float[5]), + NumPyAPI.sig(jnp.nancumprod, Float[5]), + NumPyAPI.sig(jnp.nancumsum, Float[5]), + NumPyAPI.sig(jnp.nanmax, Float[5]), + NumPyAPI.sig(jnp.nanmean, Float[5]), + NumPyAPI.sig(jnp.nanmedian, Float[5]), + NumPyAPI.sig(jnp.nanmin, Float[5]), + NumPyAPI.sig(jnp.nanpercentile, Float[5], q=75), + NumPyAPI.sig(jnp.nanprod, Float[5]), + NumPyAPI.sig(jnp.nanquantile, Float[5], q=0.75), + NumPyAPI.sig(jnp.nanstd, Float[5]), + NumPyAPI.sig(jnp.nansum, Float[5]), + NumPyAPI.sig(jnp.nanvar, Float[5]), + NumPyAPI.sig(jnp.ndim, Float[5]), + NumPyAPI.sig(jnp.negative, Float[5]), + NumPyAPI.sig(jnp.nextafter, Float[5], Float[5]), + NumPyAPI.sig(jnp.nonzero, Float[5]), + NumPyAPI.sig(jnp.not_equal, Float[5], Float[5]), + NumPyAPI.sig(jnp.ones_like, Float[5]), + NumPyAPI.sig(jnp.outer, Float[5], Float[5]), + NumPyAPI.sig(jnp.packbits, Int[5]), + NumPyAPI.sig(jnp.pad, Float[5], pad_width=2), + NumPyAPI.sig(jnp.partition, Float[5], kth=3), + NumPyAPI.sig(jnp.percentile, Float[5], q=75), + NumPyAPI.sig(jnp.permute_dims, Float[3, 5], axes=(1, 0)), + NumPyAPI.sig(jnp.piecewise, Float[5], [Bool[5], Bool[5]], funclist=[jnp.sin, jnp.cos]), + NumPyAPI.sig(jnp.place, Float[5], Bool[5], Float[3], inplace=False), + NumPyAPI.sig(jnp.poly, Float[5]), + NumPyAPI.sig(jnp.polyadd, Float[5], Float[5]), + NumPyAPI.sig(jnp.polyder, Float[5]), + NumPyAPI.sig(jnp.polydiv, Float[5], Float[5]), + NumPyAPI.sig(jnp.polyfit, Float[5], Float[5], deg=2), + NumPyAPI.sig(jnp.polyint, Float[5]), + NumPyAPI.sig(jnp.polymul, Float[5], Float[5]), + NumPyAPI.sig(jnp.polysub, Float[5], Float[5]), + NumPyAPI.sig(jnp.polyval, Float[5], Float[10]), + NumPyAPI.sig(jnp.positive, Float[5]), + NumPyAPI.sig(jnp.pow, Float[5], Float[5]), + NumPyAPI.sig(jnp.power, Float[5], Float[5]), + NumPyAPI.sig(jnp.prod, Float[5]), + NumPyAPI.sig(jnp.ptp, Float[5]), + NumPyAPI.sig(jnp.put, Float[5], Int[()], Float[()], inplace=False), + NumPyAPI.sig(jnp.put_along_axis, Float[5], Int[1], Float[1], axis=0, inplace=False), + NumPyAPI.sig(jnp.quantile, Float[5], q=0.75), + NumPyAPI.sig(jnp.rad2deg, Float[5]), + NumPyAPI.sig(jnp.radians, Float[5]), + NumPyAPI.sig(jnp.ravel, Float[5]), + NumPyAPI.sig(jnp.ravel_multi_index, [Uint8[5], Uint8[5]], dims=(8, 9)), + NumPyAPI.sig(jnp.real, Complex[5]), + NumPyAPI.sig(jnp.reciprocal, Float[5]), + NumPyAPI.sig(jnp.remainder, Float[5], Float[5]), + NumPyAPI.sig(jnp.repeat, Float[5], repeats=np.array([2, 3, 1, 5, 4])), + NumPyAPI.sig(jnp.reshape, Float[6], shape=(2, 3)), + NumPyAPI.sig(jnp.resize, Float[6], new_shape=(2, 3)), + NumPyAPI.sig(jnp.right_shift, Int[5], Int[5]), + NumPyAPI.sig(jnp.rint, Float[5]), + NumPyAPI.sig(jnp.roll, Float[5], Int[1]), + NumPyAPI.sig(jnp.rollaxis, Float[5, 4], axis=1), + NumPyAPI.sig(jnp.roots, Float[5]).with_skip_on_devices(['tpu']), + NumPyAPI.sig(jnp.rot90, Float[5, 3]), + NumPyAPI.sig(jnp.round, Float[5]), + NumPyAPI.sig(jnp.searchsorted, Float[5], Float[5]), + NumPyAPI.sig(jnp.select, [Bool[5], Bool[5]], [Float[5], Float[5]], Float[()]), + NumPyAPI.sig(jnp.setdiff1d, Int[5], Int[5]), + NumPyAPI.sig(jnp.setxor1d, Int[5], Int[5]), + NumPyAPI.sig(jnp.shape, Float[5, 3]), + NumPyAPI.sig(jnp.sign, Float[5]), + NumPyAPI.sig(jnp.signbit, Float[5]), + NumPyAPI.sig(jnp.sin, Float[5]), + NumPyAPI.sig(jnp.sinc, Float[5]), + NumPyAPI.sig(jnp.sinh, Float[5]), + NumPyAPI.sig(jnp.size, Float[5]), + NumPyAPI.sig(jnp.sort, Float[5]), + NumPyAPI.sig(jnp.sort_complex, Complex[5]), + NumPyAPI.sig(jnp.spacing, Float[5]), + NumPyAPI.sig(jnp.split, Float[6], indices_or_sections=2), + NumPyAPI.sig(jnp.sqrt, Float[5]), + NumPyAPI.sig(jnp.square, Float[5]), + NumPyAPI.sig(jnp.squeeze, Float[5]), + NumPyAPI.sig(jnp.stack, [Float[2, 3], Float[2, 3]], axis=1), + NumPyAPI.sig(jnp.std, Float[5]), + NumPyAPI.sig(jnp.subtract, Float[5], Float[5]), + NumPyAPI.sig(jnp.sum, Float[5]), + NumPyAPI.sig(jnp.swapaxes, Float[3, 5], axis1=1, axis2=0), + NumPyAPI.sig(jnp.take, Float[5], Int[2]), + NumPyAPI.sig(jnp.take_along_axis, Float[5], Int[2], axis=0), + NumPyAPI.sig(jnp.tan, Float[5]), + NumPyAPI.sig(jnp.tanh, Float[5]), + NumPyAPI.sig(jnp.tensordot, Float[2, 3, 4], Float[3, 4, 5]), + NumPyAPI.sig(jnp.tile, Float[5], reps=(2,)), + NumPyAPI.sig(jnp.trace, Float[5, 5]), + NumPyAPI.sig(jnp.transpose, Float[5, 6]), + NumPyAPI.sig(jnp.trapezoid, Float[5]), + NumPyAPI.sig(jnp.tril, Float[5, 6]), + NumPyAPI.sig(jnp.tril_indices_from, Float[5, 6]), + NumPyAPI.sig(jnp.trim_zeros, Float[5]), + NumPyAPI.sig(jnp.triu, Float[5, 6]), + NumPyAPI.sig(jnp.triu_indices_from, Float[5, 6]), + NumPyAPI.sig(jnp.true_divide, Float[5], Float[5]), + NumPyAPI.sig(jnp.trunc, Float[5]), + NumPyAPI.sig(jnp.union1d, Int[5], Int[5]), + NumPyAPI.sig(jnp.unique, Int[10]), + NumPyAPI.sig(jnp.unique_all, Int[10]), + NumPyAPI.sig(jnp.unique_counts, Int[10]), + NumPyAPI.sig(jnp.unique_inverse, Int[10]), + NumPyAPI.sig(jnp.unique_values, Int[10]), + NumPyAPI.sig(jnp.unpackbits, Uint8[8]), + NumPyAPI.sig(jnp.unravel_index, Int[5], shape=(2, 3)), + NumPyAPI.sig(jnp.unstack, Float[5]), + NumPyAPI.sig(jnp.unwrap, Float[5]), + NumPyAPI.sig(jnp.vander, Float[5]), + NumPyAPI.sig(jnp.var, Float[5]), + NumPyAPI.sig(jnp.vdot, Float[5], Float[5]), + NumPyAPI.sig(jnp.vecdot, Float[5], Float[5]), + NumPyAPI.sig(jnp.vecmat, Float[5], Float[5, 3]), + NumPyAPI.sig(jnp.vsplit, Float[6], indices_or_sections=2), + NumPyAPI.sig(jnp.vstack, [Float[5], Float[2, 5]]), + NumPyAPI.sig(jnp.where, Bool[5], Float[5], Float[5]), + NumPyAPI.sig(jnp.zeros_like, Float[5]), +] + + +class JaxArrayTests(jtu.JaxTestCase): + @parameterized.named_parameters( + {'testcase_name': api.name(), 'api': api} for api in NUMPY_APIS) + def test_numpy_api_supports_jax_array(self, api): + if api.skip_on_devices and jtu.test_device_matches(api.skip_on_devices): + self.skipTest(f'{api.name()} not supported on {api.skip_on_devices}') + fun = api.fun + args = api.make_args(self.rng()) + wrapped_args = jax.tree.map(JaxArrayWrapper, args) + kwargs = api.kwargs + + expected = fun(*args, **kwargs) + wrapped = fun(*wrapped_args, **kwargs) + + self.assertAllClose(wrapped, expected, atol=0, rtol=0) + + @jtu.sample_product( + api=['array', 'asarray'], + test_class=[ + JaxArrayWrapper, + NumpyArrayWrapper, + JaxArrayWrapperWithErroringNumpyArray, + ], + ) + def test_array_creation(self, api, test_class): + """Test pytrees with __jax_array__ and/or __array__ methods.""" + fun = getattr(jnp, api) + x = np.arange(5, dtype='float32') + + expected = fun(x) + actual = fun(test_class(x)) + + self.assertIsInstance(actual, jax.Array) + self.assertAllClose(actual, expected, atol=0, rtol=0) + + @parameterized.named_parameters( + {'testcase_name': func.__name__, 'func': func} + for func in [jnp.zeros_like, jnp.ones_like, jnp.empty_like, jnp.full_like] + ) + def test_array_creation_from_duck_typed_array(self, func): + # Ensure that jnp.*_like prefers shape/dtype over __jax_array__ when + # both methods are available. + if func is jnp.full_like: + func = functools.partial(func, fill_value=2.0) + obj = DuckTypedArrayWithErroringJaxArray() + + # The test relies on this failing + with self.assertRaises(ValueError): + jnp.asarray(obj) + + result = func(obj) + self.assertIsInstance(result, jax.Array) + self.assertEqual(result.shape, obj.shape) + self.assertEqual(result.dtype, obj.dtype) + + @parameterized.named_parameters( + {"testcase_name": "subscript-form", "args": ("jk,k->j", Float[5, 3], Float[3])}, + {"testcase_name": "index-form", "args": (Float[5, 3], (0, 1), Float[3], (1,), (0,))}, + ) + def test_einsum(self, args): + rng = jtu.rand_default(self.rng()) + def make_arg(arg): + if isinstance(arg, jax.ShapeDtypeStruct): + return rng(arg.shape, arg.dtype) + return arg + args = jax.tree.map(make_arg, args) + + def wrap_array(arg): + if isinstance(arg, (jax.Array, np.ndarray)): + return JaxArrayWrapper(arg) + return arg + wrapped_args = jax.tree.map(wrap_array, args) + + expected = jnp.einsum(*args) + actual = jnp.einsum(*wrapped_args) + + self.assertAllClose(actual, expected, atol=0, rtol=0) + + +@jtu.with_config(jax_disable_jit=True) +class JaxArrayTestsNoJit(JaxArrayTests): + pass + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/array_interoperability_test.py b/tests/array_interoperability_test.py index 80a4d8ef5a25..b14e9171f952 100644 --- a/tests/array_interoperability_test.py +++ b/tests/array_interoperability_test.py @@ -13,18 +13,19 @@ # limitations under the License. import unittest +import warnings from absl.testing import absltest +import numpy as np import jax import jax.dlpack import jax.numpy as jnp from jax.sharding import PartitionSpec as P from jax._src import config +from jax._src import dlpack as dlpack_src from jax._src import test_util as jtu -from jax._src.lib import version as jaxlib_version - -import numpy as np +from jax._src.util import cache config.parse_flags_with_absl() @@ -34,42 +35,43 @@ cupy = None try: - import tensorflow as tf + # TODO(b/470156950): Remove this once a proper fix is in place + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", + category=FutureWarning, + message=".*np.object.*") + import tensorflow as tf + tf_version = tuple( int(x) for x in tf.version.VERSION.split("-")[0].split(".")) except ImportError: tf = None -dlpack_dtypes = sorted(jax.dlpack.SUPPORTED_DTYPES, key=lambda x: x.__name__) +dlpack_dtypes = sorted([dt.type for dt in dlpack_src.SUPPORTED_DTYPES_SET], + key=lambda x: x.__name__) # These dtypes are not supported by neither NumPy nor TensorFlow, therefore # we list them separately from ``jax.dlpack.SUPPORTED_DTYPES``. -extra_dlpack_dtypes = [] -if jaxlib_version >= (0, 5, 3): - extra_dlpack_dtypes = [ - jnp.float8_e4m3b11fnuz, - jnp.float8_e4m3fn, - jnp.float8_e4m3fnuz, - jnp.float8_e5m2, - jnp.float8_e5m2fnuz, - ] + [ - dtype - for name in [ - "float4_e2m1fn", - "float8_e3m4", - "float8_e4m3", - "float8_e8m0fnu", - ] - if (dtype := getattr(jnp, name, None)) - ] - -numpy_dtypes = sorted( - [dt for dt in jax.dlpack.SUPPORTED_DTYPES if dt != jnp.bfloat16], - key=lambda x: x.__name__) - +extra_dlpack_dtypes = [ + jnp.float8_e4m3b11fnuz, + jnp.float8_e4m3fn, + jnp.float8_e4m3fnuz, + jnp.float8_e5m2, + jnp.float8_e5m2fnuz, +] + [ + dtype + for name in [ + "float4_e2m1fn", + "float8_e3m4", + "float8_e4m3", + "float8_e8m0fnu", + ] + if (dtype := getattr(jnp, name, None)) +] + +numpy_dtypes = [dt for dt in dlpack_dtypes if dt != jnp.bfloat16] cuda_array_interface_dtypes = [dt for dt in dlpack_dtypes if dt != jnp.bfloat16] - nonempty_nonscalar_array_shapes = [(4,), (3, 4), (2, 3, 4)] empty_array_shapes = [] empty_array_shapes += [(0,), (0, 4), (3, 0),] @@ -78,6 +80,89 @@ nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes all_shapes = nonempty_array_shapes + empty_array_shapes + +def _get_alignment(x: int): + """Return alignment of x. + """ + return x & ((~x) + 1) + + +def _get_alignment_offset(ptr: int, alignment: int): + """Return minimal positive offset such that + _get_alignment(ptr + offset) == alignment + + Note that 0 <= offset < 2 * alignment. + """ + if _get_alignment(ptr) == alignment: + return 0 + offset = alignment - (ptr & (alignment - 1)) + if _get_alignment(ptr + offset) == alignment: + return offset + return offset + alignment + + +def _ensure_alignment(arr, desired_alignment): + """Return a copy of numpy array such that its data pointer has the + desired alignment exactly. The desired alignment must be power of + two. + """ + assert desired_alignment > 1, desired_alignment + buf = np.empty(2 * desired_alignment + arr.nbytes, dtype=np.int8) + ptr = buf.__array_interface__['data'][0] + start = _get_alignment_offset(ptr, desired_alignment) + # if arr.nbytes == 0 and start > 0 then buf[start:start+arr.nbytes] + # incorrectly returns the original buffer, so we must use + # buf[start:][:arr.nbytes]: + new = buf[start:][:arr.nbytes].view(arr.dtype).reshape(arr.shape) + np.copyto(new, arr, casting='unsafe') + new_ptr = new.__array_interface__['data'][0] + assert new_ptr & (desired_alignment - 1) == 0 # sanity check + assert new_ptr & (desired_alignment * 2 - 1) != 0 # sanity check + return new + + +def test_ensure_alignment(): + + def reference(ptr, alignment): + start = 0 + while _get_alignment(ptr + start) != alignment: + start += 1 + return start + + for alignment in [2, 4, 8, 16, 32, 64, 128]: + max_start = 1 + for ptr in range(1000): + start = _get_alignment_offset(ptr, alignment) + expected = reference(ptr, alignment) + max_start = max(max_start, start) + assert start == expected + assert max_start == alignment * 2 - 1 + + +@cache() +def _get_max_align_bits(dtype, device): + max_align_bits = 64 + if device.platform == "cpu": + from jax._src.lib import _jax + + # We determine the max_align_bits value from the error that is + # raised by dlpack_managed_tensor_to_buffer when using a buffer + # with a very small data alignment (=2). + x_np = _ensure_alignment(np.zeros(5, dtype=dtype), desired_alignment=2) + try: + _jax.dlpack_managed_tensor_to_buffer(x_np.__dlpack__(), device, None, False) + raise RuntimeError("unexpected success") + except Exception as e: + msg = str(e) + m = "is not aligned to" + if m in msg: + i = msg.index(m) + len(m) + max_align_bits = int(msg[i:].split(None, 1)[0]) + else: + raise + return max_align_bits + + class DLPackTest(jtu.JaxTestCase): def setUp(self): super().setUp() @@ -91,51 +176,20 @@ def setUp(self): use_stream=[False, True], ) @jtu.run_on_devices("gpu") - @jtu.ignore_warning( - message="Calling from_dlpack with a DLPack tensor", - category=DeprecationWarning, - ) def testJaxRoundTrip(self, shape, dtype, copy, use_stream): rng = jtu.rand_default(self.rng()) np = rng(shape, dtype) - def _check_copy(x: jax.Array, y: jax.Array, expect_copy): - copied = x.unsafe_buffer_pointer() != y.unsafe_buffer_pointer() - assert copied == expect_copy, f"Expected {'a' if expect_copy else 'no'} copy" - # Check if the source device is preserved x = jax.device_put(np, jax.devices("cpu")[0]) device = jax.devices("gpu")[0] y = jax.device_put(x, device) - dl_device = y.__dlpack_device__() - if use_stream: - stream = tuple(y.devices())[0].get_stream_for_external_ready_events() - dlpack = jax.dlpack.to_dlpack(y, copy=copy, stream=stream) - else: - dlpack = jax.dlpack.to_dlpack(y, copy=copy) - z = jax.dlpack.from_dlpack(dlpack) + # TODO(parkers): Remove after setting 'stream' properly below. + jax.block_until_ready(y) + z = jax.dlpack.from_dlpack(y) self.assertEqual(z.devices(), {device}) self.assertAllClose(np.astype(x.dtype), z) - self.assertRaisesRegex(RuntimeError, - "DLPack tensor may be consumed at most once", - lambda: jax.dlpack.from_dlpack(dlpack)) - - if shape in nonempty_array_shapes: - _check_copy(y, z, bool(copy)) - - # Check if the destination device can be specified - make_dlpack = lambda: x.__dlpack__(dl_device=dl_device, copy=copy) - if copy == False: - self.assertRaisesRegex(ValueError, "copy=False", make_dlpack) - return - - z = jax.dlpack.from_dlpack(make_dlpack()) - self.assertEqual(z.devices(), {device}) - self.assertAllClose(x, z) - - if shape in nonempty_array_shapes: - _check_copy(x, z, True) @jtu.sample_product( shape=all_shapes, @@ -149,6 +203,8 @@ def testJaxArrayRoundTrip(self, shape, dtype, gpu): raise unittest.SkipTest("Skipping GPU test case on CPU") device = jax.devices("gpu" if gpu else "cpu")[0] x = jax.device_put(np, device) + # TODO(parkers): Remove after setting 'stream' properly. + jax.block_until_ready(x) y = jax.dlpack.from_dlpack(x) self.assertEqual(y.devices(), {device}) self.assertAllClose(np.astype(x.dtype), y) @@ -157,13 +213,8 @@ def testJaxArrayRoundTrip(self, shape, dtype, gpu): self.assertEqual(z.devices(), {device}) self.assertAllClose(np.astype(x.dtype), z) - @jtu.sample_product( - shape=all_shapes, - dtype=dlpack_dtypes, - ) + @jtu.sample_product(shape=all_shapes, dtype=dlpack_dtypes) @unittest.skipIf(not tf, "Test requires TensorFlow") - @jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor", - category=DeprecationWarning) def testTensorFlowToJax(self, shape, dtype): if (not config.enable_x64.value and dtype in [jnp.int64, jnp.uint64, jnp.float64]): @@ -179,13 +230,12 @@ def testTensorFlowToJax(self, shape, dtype): np = rng(shape, dtype) with tf.device("/GPU:0" if jtu.test_device_matches(["gpu"]) else "/CPU:0"): x = tf.identity(tf.constant(np)) - dlpack = tf.experimental.dlpack.to_dlpack(x) - y = jax.dlpack.from_dlpack(dlpack) + y = jax.dlpack.from_dlpack(x) self.assertAllClose(np, y) @jtu.sample_product( - shape=all_shapes, - dtype=dlpack_dtypes, + shape=all_shapes, + dtype=dlpack_dtypes, ) @unittest.skipIf(not tf, "Test requires TensorFlow") def testJaxToTensorFlow(self, shape, dtype): @@ -198,79 +248,75 @@ def testJaxToTensorFlow(self, shape, dtype): rng = jtu.rand_default(self.rng()) np = rng(shape, dtype) x = jnp.array(np) + # TODO(parkers): Remove after setting 'stream' properly. + jax.block_until_ready(x) # TODO(b/171320191): this line works around a missing context initialization # bug in TensorFlow. _ = tf.add(1, 1) - dlpack = jax.dlpack.to_dlpack(x) - y = tf.experimental.dlpack.from_dlpack(dlpack) + y = tf.experimental.dlpack.from_dlpack(x.__dlpack__()) self.assertAllClose(np, y.numpy()) @unittest.skipIf(not tf, "Test requires TensorFlow") - @jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor", - category=DeprecationWarning) def testTensorFlowToJaxInt64(self): # See https://github.com/jax-ml/jax/issues/11895 - x = jax.dlpack.from_dlpack( - tf.experimental.dlpack.to_dlpack(tf.ones((2, 3), tf.int64))) + x = jax.dlpack.from_dlpack(tf.ones((2, 3), tf.int64)) dtype_expected = jnp.int64 if config.enable_x64.value else jnp.int32 self.assertEqual(x.dtype, dtype_expected) + @unittest.skipIf(not tf, "Test requires TensorFlow") + def testTensorFlowToJaxNondefaultLayout(self): + x = tf.transpose(np.arange(4).reshape(2, 2)) + self.assertAllClose(x.numpy(), jax.dlpack.from_dlpack(x)) + @jtu.sample_product( shape=all_shapes, dtype=numpy_dtypes, - copy=[False, True], + copy=[False, True, None], + aligned=[False, True], ) - def testNumpyToJax(self, shape, dtype, copy): + def testNumpyToJax(self, shape, dtype, copy, aligned): rng = jtu.rand_default(self.rng()) x_np = rng(shape, dtype) device = jax.devices()[0] + + alignment = _get_max_align_bits(dtype, device) if aligned else 2 + x_np = _ensure_alignment(x_np, desired_alignment=alignment) + _from_dlpack = lambda: jnp.from_dlpack(x_np, device=device, copy=copy) - if jax.default_backend() == 'gpu' and not copy: + if copy is not None and not copy and (jax.default_backend() != "cpu" + or not aligned): self.assertRaisesRegex( - ValueError, - r"Specified .* which requires a copy", - _from_dlpack + ValueError, "Specified .* which requires a copy", _from_dlpack ) else: self.assertAllClose(x_np, _from_dlpack()) - @jtu.sample_product( - shape=all_shapes, - dtype=numpy_dtypes, - ) - @jtu.run_on_devices("cpu") # NumPy only accepts cpu DLPacks + def testNumpyToJaxNondefaultLayout(self): + x = np.arange(4).reshape(2, 2).T + self.assertAllClose(x, jax.dlpack.from_dlpack(x)) + + @jtu.sample_product(shape=all_shapes, dtype=numpy_dtypes) + @jtu.run_on_devices("cpu") # NumPy only accepts cpu DLPacks def testJaxToNumpy(self, shape, dtype): rng = jtu.rand_default(self.rng()) x_jax = jnp.array(rng(shape, dtype)) x_np = np.from_dlpack(x_jax) self.assertAllClose(x_np, x_jax) - @jtu.ignore_warning(message="Calling from_dlpack.*", - category=DeprecationWarning) - def testNondefaultLayout(self): - # Generate numpy array with nonstandard layout - a = np.arange(4).reshape(2, 2) - b = a.T - with self.assertRaisesRegex( - RuntimeError, - r"from_dlpack got array with non-default layout with minor-to-major " - r"dimensions \(0,1\), expected \(1,0\)"): - b_jax = jax.dlpack.from_dlpack(b.__dlpack__()) - class CudaArrayInterfaceTest(jtu.JaxTestCase): - @jtu.skip_on_devices("cuda") + @jtu.skip_on_devices("cuda", "rocm") def testCudaArrayInterfaceOnNonCudaFails(self): x = jnp.arange(5) self.assertFalse(hasattr(x, "__cuda_array_interface__")) with self.assertRaisesRegex( AttributeError, - "__cuda_array_interface__ is only defined for NVidia GPU buffers.", + "__cuda_array_interface__ is only defined for GPU buffers.", ): _ = x.__cuda_array_interface__ - @jtu.run_on_devices("cuda") + @jtu.run_on_devices("gpu") def testCudaArrayInterfaceOnShardedArrayFails(self): devices = jax.local_devices() if len(devices) <= 1: @@ -291,7 +337,7 @@ def testCudaArrayInterfaceOnShardedArrayFails(self): shape=all_shapes, dtype=cuda_array_interface_dtypes, ) - @jtu.run_on_devices("cuda") + @jtu.run_on_devices("gpu") def testCudaArrayInterfaceWorks(self, shape, dtype): rng = jtu.rand_default(self.rng()) x = rng(shape, dtype) @@ -301,7 +347,7 @@ def testCudaArrayInterfaceWorks(self, shape, dtype): self.assertEqual(shape, a["shape"]) self.assertEqual(z.__array_interface__["typestr"], a["typestr"]) - @jtu.run_on_devices("cuda") + @jtu.run_on_devices("gpu") def testCudaArrayInterfaceBfloat16Fails(self): rng = jtu.rand_default(self.rng()) x = rng((2, 2), jnp.bfloat16) @@ -319,6 +365,8 @@ def testJaxToCuPy(self, shape, dtype): rng = jtu.rand_default(self.rng()) x = rng(shape, dtype) y = jnp.array(x) + # TODO(parkers): Remove after setting 'stream' properly. + jax.block_until_ready(y) z = cupy.asarray(y) self.assertEqual(y.__cuda_array_interface__["data"][0], z.__cuda_array_interface__["data"][0]) @@ -344,16 +392,20 @@ def testCuPyToJax(self, shape, dtype): shape=all_shapes, dtype=jtu.dtypes.supported(cuda_array_interface_dtypes), ) - @jtu.run_on_devices("cuda") + @jtu.run_on_devices("gpu") def testCaiToJax(self, shape, dtype): + dtype = np.dtype(dtype) + rng = jtu.rand_default(self.rng()) x = rng(shape, dtype) # using device with highest device_id for testing the correctness # of detecting the device id from a pointer value - device = jax.devices('cuda')[-1] + device = jax.devices('gpu')[-1] with jax.default_device(device): y = jnp.array(x, dtype=dtype) + # TODO(parkers): Remove after setting 'stream' properly below. + jax.block_until_ready(y) self.assertEqual(y.dtype, dtype) # Using a jax array CAI provider support to construct an object @@ -377,7 +429,7 @@ class CAIWithoutStrides: class CAIWithStrides: __cuda_array_interface__ = cai.copy() __cuda_array_interface__["version"] = 3 - strides = (dtype.dtype.itemsize,) if shape else () + strides = (dtype.itemsize,) if shape else () for s in reversed(shape[1:]): strides = (strides[0] * s, *strides) __cuda_array_interface__['strides'] = strides diff --git a/tests/array_test.py b/tests/array_test.py index cc8990828ded..216ef68e3901 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -29,15 +29,11 @@ from jax._src import test_util as jtu from jax._src import xla_bridge as xb from jax._src.lib import xla_client as xc -from jax._src.lib.mlir import dialects, ir from jax._src.util import safe_zip -from jax._src.mesh import AxisType +from jax._src.mesh import AxisType, AbstractMesh, Mesh from jax._src.sharding import common_devices_indices_map from jax._src.sharding_impls import ( - _op_sharding_to_pos_sharding, pmap_sharding_devices_indices_map, - NamedSharding, GSPMDSharding, PositionalSharding, SdyDimSharding, - SdyArraySharding) -from jax.experimental.pjit import pjit + pmap_sharding_devices_indices_map, NamedSharding, GSPMDSharding) from jax.experimental import multihost_utils from jax.sharding import PartitionSpec as P from jax._src import array @@ -166,6 +162,34 @@ def test_single_device_array_usage_after_delete(self): with self.assertRaisesRegex(RuntimeError, 'Array has been deleted.'): _ = x + 1 + @parameterized.named_parameters( + ('no_global_shape', np.arange(10).reshape(2, 5), None), + ('global_shape_prefix', {'a': np.arange(10).reshape(2, 5)}, (2, 5)), + ('global_shape_full', {'a': np.arange(10).reshape(2, 5)}, {'a': (2, 5)}), + ) + def test_array_from_local_data_single_host(self, data, global_shape): + jnp_data = jax.make_array_from_process_local_data( + jax.devices()[0], data, global_shape + ) + jax.tree.map(self.assertArraysEqual, data, jnp_data) + + @parameterized.named_parameters( + ('global_shape_prefix', {'a': np.arange(10).reshape(2, 5)}, (2, 8)), + ('global_shape_full', {'a': np.arange(10).reshape(2, 5)}, {'a': (2, 6)}), + ( + 'global_shape_extra', + {'a': np.arange(10).reshape(2, 5)}, + {'a': (2, 5), 'b': (3, 5)}, + ), + ) + def test_array_from_local_data_single_host_invalid_global_shape( + self, data, global_shape + ): + with self.assertRaises(ValueError): + jax.make_array_from_process_local_data( + jax.devices()[0], data, global_shape + ) + def test_multi_device_array_usage_after_delete(self): global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) shape = (8, 2) @@ -197,6 +221,15 @@ def test_device_put_array_delete(self): self.assertIsNone(arr._npy_value) self.assertIsNone(arr._arrays) + def test_device_put_to_cpu(self): + mesh = Mesh(jax.devices(), 'x') + mesh_cpu = Mesh(jax.devices('cpu'), 'x') + x = np.zeros(16) + y = jax.device_put(x, NamedSharding(mesh, P('x'))) + z = jax.device_put(y, NamedSharding(mesh_cpu, P('x'))) + for z_s in z.addressable_shards: + self.assertArraysEqual(z_s.data, x[z_s.index]) + def test_array_device_get(self): global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) input_shape = (8, 2) @@ -368,8 +401,6 @@ def test_different_devices_in_arrays_than_sharding(self): array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs, committed=True) def test_duplicated_devices_in_arrays(self): - if xc._version <= 274: - self.skipTest('Test requires jaxlib version 275') shape = (8, 2) mesh = jtu.create_mesh((1, 2), ('x', 'y')) # Sharding device ids = {0, 1} @@ -422,6 +453,11 @@ def test_mismatch_dtype(self): def test_array_iter_pmap_sharding(self): if jax.device_count() < 2: self.skipTest('Test requires >= 2 devices.') + if config.pmap_shmap_merge.value: + self.skipTest( + 'Under `pmap_shmap_merge=True`, `y[0]` of sharded `y` will replicate' + ' because of the indexing operation. ' + ) x = jnp.array([[1., 0., 0.], [0., 2., 3.]]) y = jax.pmap(jnp.sin)(x) @@ -465,7 +501,7 @@ def test_array_iter_replicated_multi_device(self): self.assertArraysEqual(i, j) self.assertLen(i.sharding.device_set, 8) self.assertTrue( - op_shardings.are_op_shardings_equal( + op_shardings.are_hlo_shardings_equal( arr.sharding._to_xla_hlo_sharding(arr.ndim), i.sharding._to_xla_hlo_sharding(i.ndim))) @@ -526,7 +562,7 @@ def test_array_getitem_replicated_multi_device(self): self.assertArraysEqual(s, np.array([[4], [6]])) self.assertLen(s.sharding.device_set, 8) self.assertTrue( - op_shardings.are_op_shardings_equal( + op_shardings.are_hlo_shardings_equal( arr.sharding._to_xla_hlo_sharding(arr.ndim), s.sharding._to_xla_hlo_sharding(s.ndim))) @@ -535,7 +571,7 @@ def test_array_getitem_replicated_multi_device(self): self.assertArraysEqual(p, input_data[:2]) self.assertLen(s.sharding.device_set, 8) self.assertTrue( - op_shardings.are_op_shardings_equal( + op_shardings.are_hlo_shardings_equal( arr.sharding._to_xla_hlo_sharding(arr.ndim), s.sharding._to_xla_hlo_sharding(s.ndim))) @@ -626,43 +662,23 @@ def f(x): self.assertEqual(input_shardings[1], {}) self.assertTrue( - op_shardings.are_op_shardings_equal( - input_shardings[0][0]._to_xla_hlo_sharding(x_dummy.ndim), - s._to_xla_hlo_sharding(x_dummy.ndim))) - self.assertTrue( - op_shardings.are_op_shardings_equal( - output_shardings._to_xla_hlo_sharding(x_dummy.ndim), - s._to_xla_hlo_sharding(x_dummy.ndim))) - - def test_shape_dtype_struct_sharding_pjit(self): - mesh = jtu.create_mesh((4, 2), ('x', 'y')) - s = jax.sharding.NamedSharding(mesh, P('x', 'y')) - - def f(x): - return x * 2. - - x_dummy = jax.ShapeDtypeStruct( - shape=(8, 2), - dtype=jnp.dtype('float32'), - sharding=s) - - c = pjit(f).lower(x_dummy).compile() - input_shardings, output_shardings = c.input_shardings, c.output_shardings - self.assertTrue( - op_shardings.are_op_shardings_equal( + op_shardings.are_hlo_shardings_equal( input_shardings[0][0]._to_xla_hlo_sharding(x_dummy.ndim), s._to_xla_hlo_sharding(x_dummy.ndim))) self.assertTrue( - op_shardings.are_op_shardings_equal( + op_shardings.are_hlo_shardings_equal( output_shardings._to_xla_hlo_sharding(x_dummy.ndim), s._to_xla_hlo_sharding(x_dummy.ndim))) - # TODO(skyewm): remove this test when we can remove the workaround manual - # defragment API - @jtu.skip_on_devices('cpu') # defragment not implemented for TFRT CPU + # TODO(b/399879011): GPU is the only platform that has an implementation for + # this, which exists in py_client.cc. Ideally, this would be replaced with + # some kind of auto-defrag-on-OOM. + @jtu.run_on_devices('gpu') def test_defragment(self): + # Since the GPU implementation is in py_client.cc, it cannot be exposed via + # the PjRt C API. if xb.using_pjrt_c_api(): - self.skipTest("Manual defragment not exposed via PJRT C API") + self.skipTest('Manual defragment not exposed via PJRT C API') # Create a few arrays global_mesh = jtu.create_mesh((jax.local_device_count(),), ('x',)) @@ -675,7 +691,7 @@ def test_defragment(self): # Delete one of them arr2.delete() - # Defragment + # Defragment. xb.get_backend().defragment() # Sanity check remaining arrays @@ -710,6 +726,13 @@ def test_process_allgather_single_host(self): self.assertEqual(out.shape, (1, x.shape[0])) self.assertArraysEqual(out, np.expand_dims(x, axis=0)) + def test_broadcast_one_to_all_single_host(self): + x = jnp.arange(8, dtype=jnp.uint8) + out = multihost_utils.broadcast_one_to_all(x) + self.assertEqual(out.shape, x.shape) + self.assertEqual(out.dtype, x.dtype) + self.assertArraysEqual(out, x) + @jtu.sample_product( dtype=jtu.dtypes.all, shape=[(), (10), (2, 3)], @@ -732,6 +755,21 @@ def test_buffer_protocol(self, dtype, shape): y_bytes = memoryview(y).tobytes() self.assertEqual(x_bytes, y_bytes) + @jtu.run_on_devices("cpu") + def test_buffer_protocol_donation(self): + + @jax.jit(donate_argnums=(0,)) + def add_one(x): + return x + 1; + + rng = jtu.rand_default(self.rng()) + x = rng((64, 64), np.float32) + y = jax.device_put(x) + # holds ref. + y_bytes = memoryview(y) + # doesn't crash + self.assertArraysEqual(add_one(y), x + 1) + @jtu.run_on_devices("cpu") def test_buffer_protocol_deletion(self): rng = jtu.rand_default(self.rng()) @@ -749,8 +787,8 @@ def test_buffer_protocol_deletion(self): def test_array_copy_to_host_async(self): global_mesh = jtu.create_mesh((2, 2), ('x', 'y')) - x = pjit(lambda: jnp.arange(8.), - out_shardings=jax.sharding.NamedSharding(global_mesh, P(None)))() + x = jax.jit(lambda: jnp.arange(8.), + out_shardings=jax.NamedSharding(global_mesh, P(None)))() self.assertLen(x.sharding.device_set, 4) x.copy_to_host_async() # doesn't crash self.assertArraysEqual(np.arange(8.), x) @@ -821,6 +859,11 @@ def test_make_array_from_process_data_single_host_data_sharding(self): self.assertArraysEqual(result, data) self.assertEqual(result.sharding, s) + with jax.set_mesh(mesh): + result = jax.make_array_from_process_local_data(P('x'), data) + self.assertArraysEqual(result, data) + self.assertEqual(result.sharding, s) + @parameterized.product(dtype=jtu.dtypes.all + jtu.dtypes.custom_floats) @jtu.run_on_devices("gpu") def test_pinned_host_npy_value_doesnt_cache(self, dtype): @@ -832,7 +875,6 @@ def test_pinned_host_npy_value_doesnt_cache(self, dtype): np.array(h_tensor) self.assertIsNone(h_tensor._npy_value) - @config.enable_empty_arrays(True) def test_make_array_from_single_device_arrays_no_dtype_error(self): mesh = jtu.create_mesh((4, 2), ('x', 'y')) s = jax.sharding.NamedSharding(mesh, P('x', 'y')) @@ -854,6 +896,13 @@ def test_make_array_from_single_device_arrays_bad_dtype_error(self): jax.make_array_from_single_device_arrays( shape, s, [arr], dtype=jnp.float32) + @jtu.with_explicit_mesh((2,), ('x',)) + def test_unreduced_printing(self, mesh): + x = jax.device_put(jnp.arange(8., dtype='float32'), P('x')) + x = jax.lax.reduce_sum(x, [0], out_sharding=P(unreduced={'x'})) + self.assertIn('nreduced', str(x.sharding)) + self.assertIn('Array(shape=(), dtype=float32, sharding=', str(x)) + class ShardingTest(jtu.JaxTestCase): @@ -897,7 +946,7 @@ def test_op_sharding_indices(self, pspec): shape = (8, 4) mesh = jtu.create_mesh((4, 2), ('x', 'y')) mps = jax.sharding.NamedSharding(mesh, pspec) - ops = jax.sharding.GSPMDSharding( + ops = GSPMDSharding( list(mesh.devices.flat), mps._to_xla_hlo_sharding(len(shape))) self.assertDictEqual( ops.devices_indices_map(shape), mps.devices_indices_map(shape)) @@ -927,10 +976,15 @@ def test_uneven_shard_error(self): r"factors: \[4, 2\] should evenly divide the shape\)"): mps.shard_shape((8, 3)) + @jtu.ignore_warning(category=DeprecationWarning) @jtu.thread_unsafe_test() # cache_info isn't thread-safe def test_pmap_sharding_hash_eq(self): if jax.device_count() < 2: self.skipTest('Test needs >= 2 devices.') + if config.pmap_shmap_merge.value: + self.skipTest( + 'There is not an equivalent cache to test when pmap_shmap_merge=True.' + ) shape = (2, 2) num_elements = math.prod(shape) @@ -973,7 +1027,7 @@ def test_gspmd_sharding_repr(self): op.tile_assignment_dimensions = [4, 1, 2] op.tile_assignment_devices = [0, 1, 2, 3, 4, 5, 6, 7] op.replicate_on_last_tile_dim = True - s = jax.sharding.GSPMDSharding(jax.devices(), op) + s = GSPMDSharding(jax.devices(), op) # memory kind also appears in the repr but only for TPU. self.assertIn( 'GSPMDSharding({devices=[4,1,2]0,1,2,3,4,5,6,7 ' @@ -981,93 +1035,10 @@ def test_gspmd_sharding_repr(self): op2 = xc.OpSharding() op2.type = xc.OpSharding.Type.REPLICATED - s2 = jax.sharding.GSPMDSharding(jax.devices(), op2) + s2 = GSPMDSharding(jax.devices(), op2) # memory kind also appears in the repr but only for TPU. self.assertIn('GSPMDSharding({replicated}', repr(s2)) - def test_positional_sharding_fully_replicated(self): - sharding = PositionalSharding(jax.devices()) - jax.device_put(jnp.array(1), sharding.replicate()) # doesn't crash - - @parameterized.named_parameters( - ("mesh_x_y", P("x", "y"), (4, 2), (), False), - ("mesh_x", P("x"), (4, 2), (1,), False), - ("mesh_y", P("y"), (4, 2), (0,), True), - ("mesh_none_y", P(None, "y"), (4, 2), (0,), False), - ("mesh_none_x", P(None, "x"), (4, 2), (1,), True), - ("mesh_xy", P(("x", "y")), (8, 1), (), False), - ("mesh_fully_replicated", P(), (4, 2), None, False), - ) - def test_positional_sharding_op_sharding_lowering( - self, pspec, shape, axes, transpose): - value_shape = (8, 4) - - mesh = jtu.create_mesh((4, 2), ('x', 'y')) - mps = jax.sharding.NamedSharding(mesh, pspec) - devices = jax.local_devices()[:8] # Taking up to 8 devices - - devices_sharding = jax.sharding.PositionalSharding(devices) - devices_sharding = devices_sharding.reshape(shape).replicate(axes) - if transpose: - devices_sharding = devices_sharding.T - - op1 = mps._to_xla_hlo_sharding(len(value_shape)) - op2 = devices_sharding._to_xla_hlo_sharding(len(value_shape)) - - self.assertEqual(mps.shard_shape(value_shape), - devices_sharding.shard_shape(value_shape)) - self.assertTrue(op_shardings.are_op_shardings_equal(op1, op2)) - - def test_positional_sharding_aval_compatible(self): - if jax.device_count() < 2: - self.skipTest('Requires >=2 devices') - sharding = PositionalSharding(jax.devices()).reshape(1, jax.device_count()) - x = jax.random.uniform(jax.random.key(42), (256, 20, 1000)) - with self.assertRaisesRegex( - ValueError, - 'Sharding PositionalSharding.*is only valid for values of rank 2, but' - ' was applied to a value of rank 3'): - jax.lax.with_sharding_constraint(x, sharding) - - @parameterized.named_parameters( - ("2d_mesh_x_y", (4, 2), P("x", "y")), - ("2d_mesh_x", (4, 2), P("x")), - ("2d_mesh_y", (4, 2), P("y")), - ("2d_mesh_none_y", (4, 2), P(None, "y")), - ("2d_mesh_none_x", (4, 2), P(None, "x")), - ("2d_mesh_xy", (4, 2), P(("x", "y"))), - ("2d_mesh_none_xy", (4, 2), P(None, ("x", "y"))), - ("2d_mesh_x_none", (2, 1), P(('x',), None)), - ("2d_mesh_fully_replicated", (4, 2), P()), - ("3d_mesh_none_none_z", (2, 2, 2), P(None, None, 'z')), - ("3d_mesh_none_y_none", (2, 2, 2), P(None, 'y', None)), - ("3d_mesh_x_y_none", (2, 2, 2), P('x', 'y', None)), - ("3d_mesh_none_yz", (2, 2, 2), P(None, ('y', 'z'))), - ("3d_mesh_x_none_yz", (2, 2, 2), P('x', None, ('y', 'z'))), - ("3d_mesh_none_x_yz", (2, 2, 2), P(None, 'x', ('y', 'z'))), - ("3d_mesh_xy_z", (2, 2, 2), P(('x', 'y'), 'z')), - ("3d_mesh_xy_none_z", (2, 2, 2), P(('x', 'y'), None, 'z')), - ("3d_mesh_x_y_z", (2, 2, 2), P('x', 'y', 'z')), - ("3d_mesh_xz_y", (2, 2, 2), P(('x', 'z'), 'y')), - ("3d_mesh_xz_none_y", (2, 2, 2), P(('x', 'z'), None, 'y')), - ("3d_mesh_y_none_xz", (2, 2, 2), P('y', None, ('x', 'z'))), - ("3d_mesh_none_y_xz", (2, 2, 2), P(None, 'y', ('x', 'z'))), - ("3d_mesh2_none_none_z", (1, 2, 4), P(None, None, 'z')), - ("3d_mesh2_x_none_none", (1, 2, 4), P('x', None, None)), - ("3d_mesh_x_none_none", (2, 1, 1), P('x', None, None)), - ) - def test_positional_sharding_from_op_sharding(self, mesh_shape, pspec): - ndim = len(mesh_shape) - mesh = jtu.create_mesh( - mesh_shape, ('x', 'y') if ndim == 2 else ('x', 'y', 'z')) - mps = jax.sharding.NamedSharding(mesh, pspec) - original_op_sharding = mps._to_xla_hlo_sharding(ndim) - ps = _op_sharding_to_pos_sharding(original_op_sharding, - mps._device_assignment) - out_op_sharding = ps._to_xla_hlo_sharding(ndim) - self.assertTrue(op_shardings.are_op_shardings_equal( - original_op_sharding, out_op_sharding)) - @parameterized.named_parameters( ("2d_mesh_x", (1, 1), P("x", "y")), ("2d_mesh_x_y", (4, 2), P("x", "y")), @@ -1094,29 +1065,9 @@ def test_is_fully_replicated_named_sharding(self, mesh_shape, pspec): mps = jax.sharding.NamedSharding(mesh, pspec) shape = (8, 2, 4) mps_op_sharding = mps._to_xla_hlo_sharding(len(shape)) - ops_ifr = op_shardings.is_op_sharding_replicated(mps_op_sharding) + ops_ifr = op_shardings.is_hlo_sharding_replicated(mps_op_sharding) self.assertEqual(mps.is_fully_replicated, ops_ifr) - ps = _op_sharding_to_pos_sharding(mps_op_sharding, mps._device_assignment) - self.assertEqual(ps.is_fully_replicated, - op_shardings.is_op_sharding_replicated( - ps._to_xla_hlo_sharding(len(shape)))) - - def test_devices_sharding_respects_init_mesh_shape(self): - value_shape = (8, 4) - - mesh = jtu.create_mesh((4, 2), ('x', 'y')) - mps = jax.sharding.NamedSharding(mesh, P('x', 'y')) - - devices_sharding = jax.sharding.PositionalSharding(mesh.devices) - - op1 = mps._to_xla_hlo_sharding(len(value_shape)) - op2 = devices_sharding._to_xla_hlo_sharding(len(value_shape)) - - self.assertEqual(mps.shard_shape(value_shape), - devices_sharding.shard_shape(value_shape)) - self.assertTrue(op_shardings.are_op_shardings_equal(op1, op2)) - def test_pmap_sharding_repr(self): if jax.device_count() < 2: self.skipTest('Test needs >= 2 devices.') @@ -1124,13 +1075,6 @@ def test_pmap_sharding_repr(self): str(out.sharding) # doesn't crash repr(out.sharding) # doesn't crash - def test_positional_sharding_repr(self): - if jax.device_count() < 2: - self.skipTest('Test needs >= 2 devices.') - s = jax.sharding.PositionalSharding(jax.devices()).reshape(jax.device_count(), 1) - repr(s) # doesn't crash - str(s) # doesn't crash - def test_pspec_tuple(self): pspec = P('x', 'y', 'z') self.assertEqual(pspec, ('x', 'y', 'z')) @@ -1143,18 +1087,33 @@ def test_pspec_tuple(self): ('sharded_dim_2', (4, 2, 4), 2), ('sharded_dim_1_1', (2, 4), 1) ) + @jtu.ignore_warning(category=DeprecationWarning) def test_default_pmap_sharding(self, shape, sharded_dim): if jax.device_count() < 4: self.skipTest('Test needs >= 4 devices.') - ps = jax.sharding.PmapSharding.default(shape, sharded_dim) inp = jnp.arange(math.prod(shape)).reshape(shape) - compiled = jax.pmap(lambda x: x, in_axes=sharded_dim).lower(inp).compile() - pmap_in_sharding, = compiled._executable.unsafe_call.in_handler.in_shardings - - self.assertEqual(ps._device_assignment, pmap_in_sharding._device_assignment) - self.assertEqual(ps.sharding_spec, pmap_in_sharding.sharding_spec) - + if config.pmap_shmap_merge.value: + out = jax.pmap(lambda x: x, in_axes=sharded_dim, axis_name='x')(inp) + actual_sharding = out.sharding + expected_sharding = jax.sharding.NamedSharding( + jax.sharding.Mesh(jax.devices()[: shape[sharded_dim]], 'x'), + jax.P('x'), + ) + self.assertEqual(actual_sharding.spec, expected_sharding.spec) + self.assertEqual(actual_sharding._device_assignment, expected_sharding._device_assignment) + else: + compiled = jax.pmap(lambda x: x, in_axes=sharded_dim).lower(inp).compile() + # TOOD(dsuo): Investigate why + # `compiled._executable.unsafe_call.in_handler.in_shardings` is of type + # `GSPMDSharding` when `pmap_shmap_merge=True`. It should be + # `NamedSharding`. + actual_sharding, = compiled._executable.unsafe_call.in_handler.in_shardings + expected_sharding = jax.sharding.PmapSharding.default(shape, sharded_dim) + self.assertEqual(actual_sharding.sharding_spec, expected_sharding.sharding_spec) + self.assertEqual(actual_sharding._device_assignment, expected_sharding._device_assignment) + + @jtu.ignore_warning(category=DeprecationWarning) def test_default_pmap_sharding_with_devices(self): if jax.device_count() < 4: self.skipTest('Test needs >= 4 devices.') @@ -1164,18 +1123,26 @@ def test_default_pmap_sharding_with_devices(self): ps = jax.sharding.PmapSharding.default((4, 2), devices=new_order) self.assertEqual(ps._device_assignment, new_order) + @jtu.ignore_warning(category=DeprecationWarning) def test_default_pmap_sharding_replicated(self): x = np.zeros((len(jax.local_devices()), 8), dtype=np.float32) - x = jax.pmap(lambda x: x, in_axes=0, out_axes=None)(x) - ps = jax.sharding.PmapSharding.default( - shape=(8,), sharded_dim=None, - devices=jax.local_devices()) - self.assertEqual(x.sharding, ps) + x = jax.pmap(lambda x: x, in_axes=0, out_axes=None, axis_name='x')(x) + if config.pmap_shmap_merge.value: + expected_sharding = jax.sharding.NamedSharding( + mesh=jax.sharding.Mesh(jax.local_devices(), 'x'), + spec=jax.P(), + ) + self.assertEqual(x.sharding, expected_sharding) + else: + ps = jax.sharding.PmapSharding.default( + shape=(8,), sharded_dim=None, + devices=jax.local_devices()) + self.assertEqual(x.sharding, ps) def test_mesh_repr(self): mesh = jtu.create_mesh((1, 1), ('x', 'y')) mesh_repr = repr(mesh) - self.assertIn('device_ids', mesh_repr) + self.assertIn('axis_sizes', mesh_repr) self.assertIn('axis_names', mesh_repr) def test_are_shardings_equivalent(self): @@ -1198,9 +1165,9 @@ def test_are_shardings_equivalent(self): op1 = xc.OpSharding() op1.type = xc.OpSharding.Type.REPLICATED - s6 = jax.sharding.GSPMDSharding([jax.devices()[0]], op1) + s6 = GSPMDSharding([jax.devices()[0]], op1) - s7 = jax.sharding.GSPMDSharding(jax.devices(), op1) + s7 = GSPMDSharding(jax.devices(), op1) # The OpSharding is replicated but the Sharding itself are on different # devices. @@ -1210,7 +1177,7 @@ def test_are_shardings_equivalent(self): op2.type = xc.OpSharding.Type.OTHER op2.tile_assignment_devices = [0, 1] op2.tile_assignment_dimensions = [2, 1] - s8 = jax.sharding.GSPMDSharding(list(mesh2.devices.flat), op2) + s8 = GSPMDSharding(list(mesh2.devices.flat), op2) self.assertTrue(s1.is_equivalent_to(s6, 2)) self.assertTrue(s5.is_equivalent_to(s8, 2)) @@ -1223,7 +1190,7 @@ def test_are_shardings_equivalent(self): op3.tile_assignment_devices = [0, 1] op3.tile_assignment_dimensions = [1, 1, 2] op3.replicate_on_last_tile_dim = True - s10 = jax.sharding.GSPMDSharding(list(mesh2.devices.flat), op3) + s10 = GSPMDSharding(list(mesh2.devices.flat), op3) self.assertTrue(s9.is_equivalent_to(s10, 2)) @@ -1301,6 +1268,18 @@ def f(x): with self.assertRaisesRegex(TypeError, msg): jax.jit(f)(x) + def test_make_array_from_single_device_arrays_tuple(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + shape = (8, 8) + s = jax.sharding.NamedSharding(mesh, P('x', 'y')) + inp_data = np.arange(math.prod(shape)).reshape(shape) + + arrays = tuple( + jax.device_put(inp_data[index], d) + for d, index in s.addressable_devices_indices_map(shape).items()) + + jax.make_array_from_single_device_arrays(shape, s, arrays) # doesn't crash + def test_make_array_from_single_device_arrays_bad_inputs(self): x = jnp.arange(10) mesh = jtu.create_mesh((2,), ('x',)) @@ -1363,6 +1342,16 @@ def test_mesh_axis_types_mismatch(self): jax.sharding.AbstractMesh((2, 1), ('x', 'y'), axis_types=jax.sharding.AxisType.Auto) + with self.assertRaisesRegex(TypeError, "axis_types.*must be of type"): + AbstractMesh((2,), ('x',), axis_types=("explicit",)) + + with self.assertRaisesRegex(TypeError, "axis_types.*must be of type"): + AbstractMesh((2,), ('x',), axis_types="explicit") + + with self.assertRaisesRegex(TypeError, "axis_types.*must be of type"): + AbstractMesh((2, 2), ('x', 'y'), + axis_types=("explicit", AxisType.Explicit)) + def test_make_mesh_axis_types(self): Auto, Explicit, Manual = AxisType.Auto, AxisType.Explicit, AxisType.Manual @@ -1371,24 +1360,14 @@ def test_make_mesh_axis_types(self): self.assertEqual(mesh1, mesh2) mesh = jax.make_mesh((1, 1), ('x', 'y')) - self.assertDictEqual(mesh._axis_types_dict, {AxisType.Auto: ('x', 'y')}) + self.assertTupleEqual(mesh.axis_types, (AxisType.Explicit,) * 2) mesh = jax.make_mesh((1, 1, 1), ('x', 'y', 'z'), axis_types=(Explicit, Auto, Manual)) - self.assertDictEqual( - mesh._axis_types_dict, {AxisType.Auto: ('y',), AxisType.Explicit: ('x',), - AxisType.Manual: ('z',)}) - - mesh = jax.make_mesh((1, 1, 1), ('x', 'y', 'z'), - axis_types=(Explicit, Explicit, Manual)) - self.assertDictEqual(mesh._axis_types_dict, {AxisType.Explicit: ('x', 'y'), - AxisType.Manual: ('z',)}) - mesh = jax.make_mesh((1, 1), ('x', 'y'), axis_types=(Explicit, Explicit)) - self.assertDictEqual(mesh._axis_types_dict, {AxisType.Explicit: ('x', 'y')}) - - mesh = jax.make_mesh((1,), 'model', axis_types=Manual) - self.assertDictEqual(mesh._axis_types_dict, {AxisType.Manual: ('model',)}) + self.assertEqual(mesh.explicit_axes, ('x',)) + self.assertEqual(mesh.auto_axes, ('y',)) + self.assertEqual(mesh.manual_axes, ('z',)) with self.assertRaisesRegex( ValueError, @@ -1402,53 +1381,216 @@ def test_make_mesh_axis_types(self): self.assertNotEqual(mesh1, mesh2) self.assertNotEqual(hash(mesh1), hash(mesh2)) + def test_memory_kind_with_abstract_mesh(self): + abstract_mesh = AbstractMesh((2,), ('x',)) + ns = NamedSharding(abstract_mesh, P(), memory_kind='pinned_host') + self.assertEqual(ns.memory_kind, 'pinned_host') -@jtu.with_config(jax_use_shardy_partitioner=True) -class ShardyShardingTest(jtu.JaxTestCase): + ns = NamedSharding(abstract_mesh, P()) + self.assertIsNone(ns.memory_kind) - def test_long_axis_names(self): - mesh = jtu.create_mesh((2, 2, 2), ('sequence', 'data', 'model')) - s = jax.sharding.NamedSharding(mesh, P(('sequence', 'data'), 'model')) - sdy_sharding = s._to_sdy_sharding(3) - self.assertEqual( - sdy_sharding, - SdyArraySharding( - mesh.shape_tuple, - [SdyDimSharding( - ('sequence', 'data'), True), - SdyDimSharding(('model',), True), - SdyDimSharding([], True)])) - with ir.Context() as ctx: - dialects.sdy.register_dialect(ctx) - self.assertEqual( - str(sdy_sharding.build()), - '#sdy.sharding,' - ' [{"sequence", "data"}, {"model"}, {}]>', - ) + with self.assertRaisesRegex( + ValueError, 'Got invalid memory kind'): + NamedSharding(abstract_mesh, P(), memory_kind='weird_device') + + def test_pspec_mix_axis_types(self): + mesh = AbstractMesh( + (2, 2, 2, 2), ('a', 'b', 'c', 'd'), + axis_types=(AxisType.Explicit, AxisType.Explicit, AxisType.Auto, + AxisType.Manual)) + aval = jax.core.ShapedArray((16, 8, 4, 2), np.float32) + + out = aval.update(sharding=NamedSharding(mesh, P(('a', 'b', 'c'), 'd'))) + self.assertEqual(out.sharding.spec, P(('a', 'b'), None, None, None)) + + out = aval.update(sharding=NamedSharding(mesh, P(('a', 'c'), 'b', 'd'))) + self.assertEqual(out.sharding.spec, P('a', 'b', None, None)) + + out = aval.update(sharding=NamedSharding(mesh, P(('a', 'b'), 'c', 'd'))) + self.assertEqual(out.sharding.spec, P(('a', 'b'), None, None, None)) + + out = aval.update(sharding=NamedSharding(mesh, P(('a', 'd'), 'b', 'c'))) + self.assertEqual(out.sharding.spec, P('a', 'b', None, None)) - def test_unconstrained(self): - mesh = jtu.create_mesh((8,), ('x',)) - s = jax.sharding.NamedSharding(mesh, P(None, P.UNCONSTRAINED, 'x')) - sdy_sharding = s._to_sdy_sharding(3) + def test_aval_str_short(self): + mesh = AbstractMesh( + (2, 2, 2), ('a', 'b', 'c'), + axis_types=(AxisType.Explicit, AxisType.Explicit, AxisType.Manual)) + + s = NamedSharding(mesh, P(unreduced={'a'}, reduced={'b'})) + aval = jax.core.ShapedArray((1, 1, 1, 1), np.float32, sharding=s, + vma=frozenset('c')) + self.assertEqual(aval.str_short(True), 'f32[1,1,1,1]{V:c, U:a, R:b}') + + s = NamedSharding(mesh, P(unreduced={'a'})) + aval = jax.core.ShapedArray((1, 1, 1, 1), np.float32, sharding=s, + vma=frozenset('c')) + self.assertEqual(aval.str_short(True), 'f32[1,1,1,1]{V:c, U:a}') + + s = NamedSharding(mesh, P(unreduced={'a'})) + aval = jax.core.ShapedArray((1, 1, 1, 1), np.float32, sharding=s) + self.assertEqual(aval.str_short(True), 'f32[1,1,1,1]{U:a}') + + s = NamedSharding(mesh, P()) + aval = jax.core.ShapedArray((1, 1, 1, 1), np.float32, sharding=s, + vma=frozenset('c')) + self.assertEqual(aval.str_short(True), 'f32[1,1,1,1]{V:c}') + + aval = jax.core.ShapedArray((1, 1, 1, 1), np.float32) + self.assertEqual(aval.str_short(True), 'f32[1,1,1,1]') + + def test_modify_spec_auto_unreduced(self): + mesh = AbstractMesh( + (2, 2, 2), ('a', 'b', 'c'), + axis_types=(AxisType.Explicit, AxisType.Explicit, AxisType.Auto)) + spec = P(unreduced={'a', 'b', 'c'}) + out = core.modify_spec_for_auto_manual(spec, mesh) + self.assertEqual(out, P(unreduced={'a', 'b'})) + + spec = P(reduced={'a', 'b', 'c'}) + out = core.modify_spec_for_auto_manual(spec, mesh) + self.assertEqual(out, P(reduced={'a', 'b'})) + + spec = P(unreduced={'a', 'b'}, reduced={'c'}) + out = core.modify_spec_for_auto_manual(spec, mesh) + self.assertEqual(out, P(unreduced={'a', 'b'})) + + spec = P(unreduced={'a', 'c'}, reduced={'b'}) + out = core.modify_spec_for_auto_manual(spec, mesh) + self.assertEqual(out, P(unreduced={'a'}, reduced={'b'})) + + spec = P(unreduced={'c'}, reduced={'a', 'b'}) + out = core.modify_spec_for_auto_manual(spec, mesh) + self.assertEqual(out, P(reduced={'a', 'b'})) + + def test_pspec_unreduced(self): + pspec = P('a', 'b', None, unreduced={'c'}, reduced={'d'}) self.assertEqual( - sdy_sharding, - SdyArraySharding( - mesh.shape_tuple, - [SdyDimSharding([], True), - SdyDimSharding([], False), - SdyDimSharding(('x',), True)])) - with ir.Context() as ctx: - dialects.sdy.register_dialect(ctx) - self.assertEqual( - str(sdy_sharding.build()), - '#sdy.sharding, [{}, {?}, {"x"}]>') + repr(pspec), + "PartitionSpec('a', 'b', None, unreduced={'c'}, reduced={'d'})") + + pspec1 = P('a', 'b', None, unreduced={'c'}) + self.assertEqual(repr(pspec1), + "PartitionSpec('a', 'b', None, unreduced={'c'})") + + pspec2 = P('a', 'b', None, unreduced={'c'}) + self.assertEqual(pspec1, pspec2) + + pspec3 = P('a', 'b', None, unreduced={'d'}) + self.assertNotEqual(pspec1, pspec3) + + out = P('x', unreduced={'z'}) + P('a', unreduced={'b'}) + self.assertEqual(out, P('x', 'a', unreduced={'z', 'b'})) + + pspec4 = P('x', unreduced={'y'}) + self.assertEqual(repr(pspec4), + "PartitionSpec('x', unreduced={'y'})") + + pspec5 = P(None, None, unreduced={'x'}) + self.assertEqual(repr(pspec5), + "PartitionSpec(None, None, unreduced={'x'})") + + pspec6 = P(None, unreduced={'x'}) + self.assertEqual(repr(pspec6), "PartitionSpec(None, unreduced={'x'})") + + pspec7 = P(unreduced={'x'}) + self.assertEqual(repr(pspec7), "PartitionSpec(unreduced={'x'})") + + with self.assertRaisesRegex( + TypeError, 'unreduced in `__add__` of PartitionSpec'): + P('x', unreduced={'z'}) + (None,) * 2 + + with self.assertRaisesRegex( + TypeError, "unreduced in `__radd__` of PartitionSpec"): + (None,) * 2 + P('x', unreduced={'y'}) + + with self.assertRaisesRegex( + ValueError, "partitions cannot overlap with unreduced"): + P('x', 'y', unreduced={'x'}) + + with self.assertRaisesRegex( + ValueError, "partitions cannot overlap with unreduced"): + P('x', None, 'y', unreduced={'z', 'y'}) + + def test_named_sharding_unreduced_error(self): + mesh = jtu.create_mesh((1, 1, 1), ('x', 'y', 'z')) + + with self.assertRaisesRegex( + ValueError, "Unreduced axes.*not found in mesh.*"): + NamedSharding(mesh, P('x', unreduced={'a'})) + + with self.assertRaisesRegex( + ValueError, "Unreduced axes can only refer to mesh axes.*Explicit"): + NamedSharding(mesh, P('x', unreduced={'y', 'z'})) + + with self.assertRaisesRegex( + ValueError, "unreduced cannot contain None.*"): + NamedSharding(mesh, P('x', unreduced={'y', None})) + + def test_hlo_sharding_get_axis_sizes(self): + op = xc.OpSharding() + op.type = xc.OpSharding.Type.OTHER + op.tile_assignment_dimensions = [6, 35] + op.iota_reshape_dims = [7, 10, 3] + op.iota_transpose_perm = [2, 1, 0] + s = GSPMDSharding(jax.devices(), op) + self.assertIn('{devices=[6,35]<=[7,10,3]T(2,1,0)}', repr(s)) + self.assertEqual(s._to_xla_hlo_sharding(2).get_axis_sizes(), [7, 2, 5, 3]) + + @parameterized.named_parameters( + ('2d_mesh_x_y', (4, 2), P('x', 'y')), + ('2d_mesh_x', (4, 2), P('x')), + ('2d_mesh_y', (4, 2), P('y')), + ('2d_mesh_none_y', (4, 2), P(None, 'y')), + ('2d_mesh_none_x', (4, 2), P(None, 'x')), + ('2d_mesh_xy', (4, 2), P(('x', 'y'))), + ('2d_mesh_none_xy', (4, 2), P(None, ('x', 'y'))), + ('2d_mesh_fully_replicated', (4, 2), P()), + ('2d_mesh_x_none', (2, 1), P(('x',), None)), + ('3d_mesh_none_none_z', (2, 2, 2), P(None, None, 'z')), + ('3d_mesh_none_y_none', (2, 2, 2), P(None, 'y', None)), + ('3d_mesh_x_y_none', (2, 2, 2), P('x', 'y', None)), + ('3d_mesh_none_yz', (2, 2, 2), P(None, ('y', 'z'))), + ('3d_mesh_x_none_yz', (2, 2, 2), P('x', None, ('y', 'z'))), + ('3d_mesh_none_x_yz', (2, 2, 2), P(None, 'x', ('y', 'z'))), + ('3d_mesh_xy_z', (2, 2, 2), P(('x', 'y'), 'z')), + ('3d_mesh_xy_none_z', (2, 2, 2), P(('x', 'y'), None, 'z')), + ('3d_mesh_x_y_z', (2, 2, 2), P('x', 'y', 'z')), + ('3d_mesh_xz_y', (2, 2, 2), P(('x', 'z'), 'y')), + ('3d_mesh_xz_none_y', (2, 2, 2), P(('x', 'z'), None, 'y')), + ('3d_mesh_y_none_xz', (2, 2, 2), P('y', None, ('x', 'z'))), + ('3d_mesh_none_y_xz', (2, 2, 2), P(None, 'y', ('x', 'z'))), + ('3d_mesh2_none_none_z', (1, 2, 4), P(None, None, 'z')), + ('3d_mesh2_x_none_none', (1, 2, 4), P('x', None, None)), + ('3d_mesh_x_none_none', (2, 1, 1), P('x', None, None)), + ) + def test_gspmd_sharding_shardy_lowering(self, mesh_shape, pspec): + ndim = len(mesh_shape) + mesh = jtu.create_mesh( + mesh_shape, ('x', 'y') if ndim == 2 else ('x', 'y', 'z') + ) + ns = jax.sharding.NamedSharding(mesh, pspec) + gs = GSPMDSharding(ns._device_assignment, ns._to_xla_hlo_sharding(ndim)) + out_sdy_sharding = gs._to_sdy_sharding(ndim) + self.assertTrue(out_sdy_sharding, ns._to_sdy_sharding(ndim)) + + def test_nested_tuple_pspec_error(self): + with self.assertRaisesRegex( + ValueError, + "A tuple inside PartitionSpec cannot contain a nested tuple"): + jax.P('x', 'y', ('z', ('a',))) + + with self.assertRaisesRegex( + ValueError, + "A tuple inside PartitionSpec cannot contain a nested tuple"): + jax.P((('a', 'b'), 'c')) class RngShardingTest(jtu.JaxTestCase): # tests that the PRNGs are automatically sharded as expected @parameterized.named_parameters(("3", 3), ("4", 4), ("5", 5)) - @jtu.skip_on_devices("gpu") + @jtu.skip_on_devices("cuda") def test_random_bits_is_pure_map_1d(self, num_devices): @jax.jit def f(x): @@ -1482,7 +1624,7 @@ def f(x): "mesh_shape": mesh_shape, "pspec": pspec} for mesh_shape in [(3, 2), (4, 2), (2, 3)] for pspec in [P('x', None), P(None, 'y'), P('x', 'y')]) - @jtu.skip_on_devices("gpu") + @jtu.skip_on_devices("cuda") def test_random_bits_is_pure_map_2d(self, mesh_shape, pspec): @jax.jit def f(x): diff --git a/tests/attrs_test.py b/tests/attrs_test.py deleted file mode 100644 index 2334a7b98f91..000000000000 --- a/tests/attrs_test.py +++ /dev/null @@ -1,670 +0,0 @@ -# Copyright 2024 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from dataclasses import dataclass - -from absl.testing import absltest -from absl.testing import parameterized -import numpy as np - -import jax -import jax.numpy as jnp - -from jax._src import config -from jax._src import test_util as jtu -from jax._src.util import safe_zip, safe_map - -from jax.experimental import attrs -from jax.experimental.attrs import jax_setattr, jax_getattr - -config.parse_flags_with_absl() - -map, unsafe_map = safe_map, map -zip, unsafe_zip = safe_zip, zip - -@dataclass -class Thing: - x: float - __hash__ = object.__hash__ - __eq__ = object.__eq__ - -attrs.register(Thing) # enables passing as arg into jitted function - -class AttrsTest(jtu.JaxTestCase): - - @parameterized.parameters([True, False]) - def test_jit_basic(self, jit: bool): - thing = Thing(1.0) - - def double_it() -> None: - cur_x = jax_getattr(thing, "x") - jax_setattr(thing, "x", cur_x * 2) - - if jit: - double_it = jax.jit(double_it) - - self.assertEqual(thing.x, 1.0) - double_it() - self.assertEqual(thing.x, 2.0) - double_it() - self.assertEqual(thing.x, 4.0) - double_it() - self.assertEqual(thing.x, 8.0) - double_it() - self.assertEqual(thing.x, 16.0) - - @parameterized.parameters([True, False]) - def test_jit_basic_tree(self, jit: bool): - thing = Thing((1.0, 2.0)) - - def double_it() -> None: - (cur_x, cur_y) = jax_getattr(thing, "x") - jax_setattr(thing, "x", (cur_x * 2, cur_y * 2)) - - if jit: - double_it = jax.jit(double_it) - - self.assertEqual(thing.x, (1.0, 2.0)) - double_it() - self.assertEqual(thing.x, (2.0, 4.0)) - double_it() - self.assertEqual(thing.x, (4.0, 8.0)) - double_it() - self.assertEqual(thing.x, (8.0, 16.0)) - double_it() - self.assertEqual(thing.x, (16.0, 32.0)) - - @parameterized.parameters([True, False]) - def test_jit_basic_tree_changes(self, jit: bool): - thing = Thing(None) - count = 0 - - def double_it() -> None: - nonlocal count - count += 1 - maybe_x = jax_getattr(thing, "x") - x = 1.0 if maybe_x is None else maybe_x - jax_setattr(thing, "x", 2 * x) - - if jit: - double_it = jax.jit(double_it) - - self.assertEqual(thing.x, None) - double_it() - self.assertEqual(thing.x, 2.0) - self.assertEqual(count, 1) - double_it() - self.assertEqual(thing.x, 4.0) - self.assertEqual(count, 2) - double_it() - self.assertEqual(thing.x, 8.0) - self.assertEqual(count, 2 + (not jit)) - - def test_jit_basic_tree_changes_multiple(self): - thing1 = Thing(None) - thing2 = Thing(0) - count = 0 - - @jax.jit - def double_it() -> None: - nonlocal count - count += 1 - - x1 = jax_getattr(thing1, "x") - if x1 is None: - jax_setattr(thing1, 'x', (None,)) - elif isinstance(x1, tuple): - # depend on a new value - jax_setattr(thing1, 'x', jax_getattr(thing2, 'x') + 1) - else: - jax_setattr(thing2, 'x', jax_getattr(thing1, 'x')) - jax_setattr(thing1, 'x', None) - - self.assertEqual(thing1.x, None) - self.assertEqual(thing2.x, 0) - double_it() - self.assertEqual(thing1.x, (None,)) - self.assertEqual(thing2.x, 0) - self.assertEqual(count, 1) - double_it() - self.assertEqual(thing1.x, 1) - self.assertEqual(thing2.x, 0) - self.assertEqual(count, 2) - double_it() - self.assertEqual(thing1.x, None) - self.assertEqual(thing2.x, 1) - self.assertEqual(count, 3) - double_it() - self.assertEqual(thing1.x, (None,)) - self.assertEqual(thing2.x, 1) - self.assertEqual(count, 3) - double_it() - self.assertEqual(thing1.x, 2) - self.assertEqual(thing2.x, 1) - self.assertEqual(count, 3) - double_it() - self.assertEqual(thing1.x, None) - self.assertEqual(thing2.x, 2) - self.assertEqual(count, 3) - - def test_jit_nesting_basic(self): - thing = Thing(1.0) - - @jax.jit - @jax.jit - def double_it() -> None: - cur_x = jax_getattr(thing, "x") - jax_setattr(thing, "x", cur_x * 2) - - self.assertEqual(thing.x, 1.0) - double_it() - self.assertEqual(thing.x, 2.0) - double_it() - self.assertEqual(thing.x, 4.0) - double_it() - self.assertEqual(thing.x, 8.0) - double_it() - self.assertEqual(thing.x, 16.0) - - def test_jit_consts_and_args(self): - thing = Thing(1.0) - - @jax.jit - def double_it(y) -> None: - cur_x = jax_getattr(thing, "x") - jax_setattr(thing, "x", cur_x * 2) - return jnp.cos(np.arange(3.) * cur_x * y) - - self.assertEqual(thing.x, 1.0) - double_it(2.) - self.assertEqual(thing.x, 2.0) - double_it(2.) - self.assertEqual(thing.x, 4.0) - double_it(2.) - self.assertEqual(thing.x, 8.0) - double_it(2.) - self.assertEqual(thing.x, 16.0) - - def test_jit_transpose_basic(self): - thing = Thing(jnp.array(2.0)) - - @jax.custom_vjp - def foo(x): - return x - - def foo_fwd(x): - return x, None - - def foo_bwd(x, g): - jax_setattr(thing, 'x', g) - return g, - - foo.defvjp(foo_fwd, foo_bwd) - - foo(3.14) - self.assertEqual(thing.x, 2.0) - - jax.grad(foo)(3.14) - self.assertEqual(thing.x, 1.0) - - thing.x = jnp.array(3.14) - self.assertEqual(thing.x, 3.14) - - jax.jit(jax.grad(foo))(3.14) - self.assertEqual(thing.x, 1.0) - - thing.x = jnp.array(2.718) - self.assertEqual(thing.x, 2.718) - - jax.grad(jax.jit(lambda x: jnp.sin(foo(x))))(3.0) - self.assertAllClose(thing.x, -0.9899925, atol=1e-5, rtol=1e-5, check_dtypes=False) - - thing.x = jnp.array(3.14) - self.assertEqual(thing.x, 3.14) - - def bar(x): - out = jnp.sin(foo(x)) - jax_setattr(thing, 'x', 5.0) - return out - - jax.grad(jax.jit(bar))(3.0) - self.assertAllClose(thing.x, -0.9899925, atol=1e-5, rtol=1e-5, check_dtypes=False) - - @parameterized.parameters([True, False]) - def test_scan_basic(self, jit: bool): - thing = Thing(1.0) - - def double_it_10(): - def body(_, __): - cur_x = jax_getattr(thing ,"x") - jax_setattr(thing, "x", cur_x * 2.0) - return None, None - _, _ = jax.lax.scan(body, None, None, length=10) - - if jit: - double_it_10 = jax.jit(double_it_10) - - double_it_10() - self.assertAllClose(thing.x, 1024., check_dtypes=False) - - def test_scan_basic_consts_and_args(self): - thing = Thing(1.0) - - def double_it_10(y): - def body(i, x): - cur_x = jax_getattr(thing ,"x") - jax_setattr(thing, "x", cur_x * 2.0) - return i + 1, (y, y) - _, _ = jax.lax.scan(body, 0, jnp.arange(10)) - - jax.jit(double_it_10)(jnp.arange(3.)) - self.assertAllClose(thing.x, 1024., check_dtypes=False) - - @parameterized.parameters([True, False]) - def test_scan_transpose_basic(self, jit: bool): - thing = Thing(1.0) - - @jax.custom_vjp - def foo(x): - return x - - def foo_fwd(x): - return x, None - - def foo_bwd(x, g): - jax_setattr(thing, 'x', 2 * jax_getattr(thing, 'x') * g) - return g, - - foo.defvjp(foo_fwd, foo_bwd) - - - def double_it_10(x): - def body(x, __): - return foo(x), None - x, _ = jax.lax.scan(body, x, None, length=10) - return x - - if jit: - double_it_10 = jax.jit(double_it_10) - - double_it_10(1.0) - self.assertAllClose(thing.x, 1., check_dtypes=False) - - jax.grad(double_it_10)(1.0) - self.assertAllClose(thing.x, 1024., check_dtypes=False) - - def test_arg_to_jit(self): - self.skipTest("regressed this experimental feature") # TODO(mattjj) - thing = Thing(1.0) - count = 0 - - @jax.jit - def f(obj, x): - nonlocal count - count += 1 - jax_setattr(obj, 'x', x) - - f(thing, 2.0) # don't crash! - self.assertAllClose(thing.x, 2.0, check_dtypes=False) - f(thing, 3.0) - self.assertAllClose(thing.x, 3.0, check_dtypes=False) - self.assertEqual(count, 1) - - def test_tracer_lifetime_bug(self): - # regression test for https://github.com/jax-ml/jax/issues/20082 - class StatefulRNG: - key: jax.Array - - def __init__(self, key: jax.Array): - self.key = key - - def split(self) -> jax.Array: - key = jax_getattr(self, "key") - new_key, returned_key = jax.random.split(key) - jax_setattr(self, "key", new_key) - return returned_key - - rng = StatefulRNG(jax.random.key(0)) - - def jitted(): - rng.split() - rng.split() - - jax.jit(jitted)() # don't crash - - def test_scan_carry(self): - class A: - ... - - a = A() - - jax_setattr(a, 'x', jnp.zeros(3)) - - def body(i, _): - x = jax_getattr(a, 'x') - x = x.at[i].set(x[i] + 1) - jax_setattr(a, 'x', x) - return i + 1, None - _, _ = jax.lax.scan(body, 0, None, length=3) # don't crash - - -class AttrsJVPTest(jtu.JaxTestCase): - - @parameterized.parameters([True, False]) - def test_jvp_basic(self, jit): - thing = Thing(2.0) - - def f(): - x = jax_getattr(thing, 'x') - x = jnp.sin(x) - jax_setattr(thing, 'x', x) - - if jit: - f = jax.jit(f) - - _, _, attr_tangents = attrs.jvp(f, (), (), [(thing, 'x', 1.0)]) - self.assertAllClose(thing.x, jnp.sin(2.0), check_dtypes=False) - (thing_, attr_, tangent_), = attr_tangents - self.assertIs(thing, thing_) - self.assertEqual(attr_, 'x') - self.assertAllClose(tangent_, jnp.cos(2.0), check_dtypes=False) - - @parameterized.parameters([True, False]) - def test_jvp_clobber(self, jit): - thing = Thing(2.0) - - def f(): - x = jax_getattr(thing, 'x') - x = jnp.sin(2.0) - jax_setattr(thing, 'x', x) - - if jit: - f = jax.jit(f) - - _, _, attr_tangents = attrs.jvp(f, (), (), [(thing, 'x', 1.0)]) - self.assertAllClose(thing.x, jnp.sin(2.0), check_dtypes=False) - self.assertEmpty(attr_tangents) - - @parameterized.parameters([True, False]) - def test_jvp_nowrite(self, jit): - thing = Thing(2.0) - - def f(): - x = jax_getattr(thing, 'x') - - if jit: - f = jax.jit(f) - - _, _, attr_tangents = attrs.jvp(f, (), (), [(thing, 'x', 1.0)]) - self.assertAllClose(thing.x, 2.0, check_dtypes=False) - (thing_, attr_, tangent_), = attr_tangents - self.assertIs(thing, thing_) - self.assertEqual(attr_, 'x') - self.assertAllClose(tangent_, 1.0, check_dtypes=False) - - def test_jit_of_jvp(self): - thing = Thing(2.0) - - def f(): - x = jax_getattr(thing, 'x') - x = jnp.sin(x) - jax_setattr(thing, 'x', x) - - @jax.jit - def g(): - _, _, attr_tangents = attrs.jvp(f, (), (), [(thing, 'x', 1.0)]) - (thing_, attr_, tangent_), = attr_tangents - self.assertIs(thing, thing_) - self.assertEqual(attr_, 'x') - return jax_getattr(thing, 'x'), tangent_ - - x, tangent = g() - self.assertAllClose(x, jnp.sin(2.0), check_dtypes=False) - self.assertAllClose(tangent, jnp.cos(2.0), check_dtypes=False) - - @parameterized.parameters([True, False]) - def test_jvp_higher_order(self, jit): - thing = Thing(2.0) - - def f(y): - x = jax_getattr(thing, 'x') - w = jnp.tan(jnp.sin(y) * jnp.cos(x)) - z = jnp.tan(jnp.cos(y) * jnp.sin(x)) - jax_setattr(thing, 'x', z) - return w - if jit: - f = jax.jit(f) - - def f_ref(x, y): - w = jnp.tan(jnp.sin(y) * jnp.cos(x)) - z = jnp.tan(jnp.cos(y) * jnp.sin(x)) - return w, z - - x = jax.random.normal(jax.random.key(0), (3,)) - x_dot = jax.random.normal(jax.random.key(1), (3,)) - y = jax.random.normal(jax.random.key(2), (3,)) - y_dot = jax.random.normal(jax.random.key(3), (3,)) - - setattr(thing, 'x', x) - w, w_dot, [(_, _, z_dot)] = attrs.jvp(f, (y,), (y_dot,), [(thing, 'x', x_dot)]) - z = getattr(thing, 'x') - - (w_, z_), (w_dot_, z_dot_) = jax.jvp(f_ref, (x, y), (x_dot, y_dot)) - - self.assertAllClose(w, w_, check_dtypes=False) - self.assertAllClose(z, z_, check_dtypes=False) - self.assertAllClose(w_dot, w_dot_, check_dtypes=False) - self.assertAllClose(z_dot, z_dot_, check_dtypes=False) - - def g(x_dot, y, y_dot): - w, w_dot, [(_, _, z_dot)] = attrs.jvp(f, (y,), (y_dot,), [(thing, 'x', x_dot)]) - return w, w_dot, z_dot - - def g_ref(x, x_dot, y, y_dot): - (w, z), (w_dot, z_dot) = jax.jvp(f_ref, (x, y), (x_dot, y_dot)) - return w, w_dot, z, z_dot - - x_dot2 = jax.random.normal(jax.random.key(3), (3,)) - x_ddot = jax.random.normal(jax.random.key(4), (3,)) - y_dot2 = jax.random.normal(jax.random.key(5), (3,)) - y_ddot = jax.random.normal(jax.random.key(6), (3,)) - - setattr(thing, 'x', x) - (w, w_dot, z_dot), (w_dot2, w_ddot, z_ddot), [(_, _, z_dot2)] = \ - attrs.jvp(g, (x_dot, y, y_dot), (x_ddot, y_dot2, y_ddot), - [(thing, 'x', x_dot2)]) - z = getattr(thing, 'x') - - (w_, w_dot_, z_, z_dot_), (w_dot2_, w_ddot_, z_dot2_, z_ddot_) = \ - jax.jvp(g_ref, (x, x_dot, y, y_dot), (x_dot2, x_ddot, y_dot2, y_ddot)) - - self.assertAllClose( w, w_, check_dtypes=False) - self.assertAllClose( z, z_, check_dtypes=False) - self.assertAllClose( w_dot, w_dot_, check_dtypes=False) - self.assertAllClose( z_dot, z_dot_, check_dtypes=False) - self.assertAllClose(w_dot2, w_dot2_, check_dtypes=False) - self.assertAllClose(z_dot2, z_dot2_, check_dtypes=False) - self.assertAllClose(w_ddot, w_ddot_, check_dtypes=False) - self.assertAllClose(z_ddot, z_ddot_, check_dtypes=False) - -class AttrsLinTest(jtu.JaxTestCase): - - @parameterized.parameters([True, False]) - def test_attr_output(self, jit): - thing = Thing(1.0) - - def f(x, _): - y = jnp.sin(x) - jax_setattr(thing, 'x', y) - - if jit: - f = jax.jit(f) - - out, f_lin = attrs.linearize(f, 3.0, 4.0) - self.assertIsNone(out) - self.assertAllClose(thing.x, jnp.sin(3.0), check_dtypes=False) - - out_dot, attr_tangents = f_lin(1.0, 2.0, attr_tangents={}) - self.assertIsNone(out_dot) - self.assertAllClose(thing.x, jnp.sin(3.0)) # didn't change - self.assertLen(attr_tangents, 1) - self.assertAllClose(attr_tangents[(thing, 'x')], jnp.cos(3.0), - check_dtypes=False) - - @parameterized.parameters([True, False]) - def test_attr_input(self, jit): - thing = Thing(1.0) - - def f(): - x = jax_getattr(thing, 'x') - return jnp.sin(x) - - if jit: - f = jax.jit(f) - - out, f_lin = attrs.linearize(f, attrs=[(thing, 'x')]) - self.assertAllClose(out, jnp.sin(1.0), check_dtypes=False) - - out_dot, attr_tangents = f_lin(attr_tangents={(thing, 'x'): 2.0}) - self.assertAllClose(out_dot, 2. * jnp.cos(1.0), check_dtypes=False) - self.assertLen(attr_tangents, 1) - self.assertAllClose(attr_tangents[(thing, 'x')], 2.0, check_dtypes=False) - - @parameterized.parameters([True, False]) - def test_attr_inout(self, jit): - thing1 = Thing(1.0) - thing2 = Thing(2.0) - - def f(x, y): - z = jax_getattr(thing1, 'x') - w = jax_getattr(thing2, 'x') - out = jnp.sin(x * y * z * w) - jax_setattr(thing1, 'x', out) - jax_setattr(thing2, 'x', 2 * out) - return 3 * out, 4 * out - - if jit: - f = jax.jit(f) - - def f_ref(x, y, z, w): - out = jnp.sin(x * y * z * w) - return (3 * out, 4 * out), (out, 2 * out) - - out, f_lin = attrs.linearize(f, 3., 4., attrs=[(thing1, 'x'), (thing2, 'x')]) - expected = (3 * jnp.sin(1. * 2. * 3. * 4.), - 4 * jnp.sin(1. * 2. * 3. * 4.)) - self.assertAllClose(out, expected, check_dtypes=False) - self.assertAllClose(thing1.x, jnp.sin(1. * 2. * 3. * 4.)) - self.assertAllClose(thing2.x, 2 * jnp.sin(1. * 2. * 3. * 4.)) - - (out_ref, state_out_ref), f_lin_ref = jax.linearize(f_ref, 3., 4., 1., 2.) - self.assertAllClose(out, out_ref, check_dtypes=False) - self.assertAllClose((thing1.x, thing2.x), state_out_ref, check_dtypes=False) - - out_dot, attr_tangents = f_lin(1., 2., - attr_tangents={(thing1, 'x'): 5., - (thing2, 'x'): 6.}) - self.assertAllClose(thing1.x, jnp.sin(1. * 2. * 3. * 4.)) - self.assertAllClose(thing2.x, 2 * jnp.sin(1. * 2. * 3. * 4.)) - (out_dot_ref, state_dot_ref) = f_lin_ref(1., 2., 5., 6.) - self.assertAllClose(out_dot, out_dot_ref, check_dtypes=False) - self.assertLen(attr_tangents, 2) - self.assertAllClose(attr_tangents[(thing1, 'x')], state_dot_ref[0], - check_dtypes=False) - self.assertAllClose(attr_tangents[(thing2, 'x')], state_dot_ref[1], - check_dtypes=False) - -class AttrsVJPTest(jtu.JaxTestCase): - - @parameterized.parameters([True, False]) - def test_attr_input(self, jit): - thing = Thing(1.0) - - def f(): - x = jax_getattr(thing, 'x') - return jnp.sin(x) - - if jit: - f = jax.jit(f) - - out, f_vjp = attrs.vjp(f, attrs=[(thing, 'x')]) - self.assertAllClose(out, jnp.sin(1.0), check_dtypes=False) - - arg_cts, attr_cotangents = f_vjp(1.0) - self.assertEqual(arg_cts, ()) - self.assertLen(attr_cotangents, 1) - self.assertAllClose(attr_cotangents[(thing, 'x')], jnp.cos(1.0), - check_dtypes=False) - - @parameterized.parameters([True, False]) - def test_attr_output(self, jit): - thing = Thing(1.0) - - def f(x, _): - y = jnp.sin(x) - jax_setattr(thing, 'x', y) - - if jit: - f = jax.jit(f) - - out, f_vjp = attrs.vjp(f, 3.0, 4.0) - self.assertIsNone(out) - self.assertAllClose(thing.x, jnp.sin(3.0), check_dtypes=False) - - arg_cts, attr_cotangents = f_vjp(None, attr_cotangents={(thing, 'x'): 2.0}) - self.assertAllClose(arg_cts, (2 * jnp.cos(3.0), 0.), check_dtypes=False) - self.assertLen(attr_cotangents, 0) - - @parameterized.parameters([True, False]) - def test_attr_inout(self, jit): - thing1 = Thing(1.0) - thing2 = Thing(2.0) - - def f(x, y): - z = jax_getattr(thing1, 'x') - w = jax_getattr(thing2, 'x') - out = jnp.sin(x * y * z * w) - jax_setattr(thing1, 'x', out) - jax_setattr(thing2, 'x', 2 * out) - return 3 * out, 4 * out - - if jit: - f = jax.jit(f) - - def f_ref(x, y, z, w): - out = jnp.sin(x * y * z * w) - return (3 * out, 4 * out), (out, 2 * out) - - out, f_vjp = attrs.vjp(f, 3., 4., attrs=[(thing1, 'x'), (thing2, 'x')]) - (out_ref, state_out_ref), f_vjp_ref = jax.vjp(f_ref, 3., 4., 1., 2.) - self.assertAllClose(out, out_ref, check_dtypes=False) - self.assertAllClose((thing1.x, thing2.x), state_out_ref, check_dtypes=False) - - in_bar, attr_cotangents = f_vjp((1., 2.), - attr_cotangents={(thing1, 'x'): 5., - (thing2, 'x'): 6.}) - in_bar_ref_ = f_vjp_ref(((1., 2.), (5., 6.))) - in_bar_ref, attr_cotangents_ref = in_bar_ref_[:2], in_bar_ref_[2:] - self.assertAllClose(in_bar, in_bar_ref, check_dtypes=False) - self.assertLen(attr_cotangents, 2) - self.assertAllClose(attr_cotangents[(thing1, 'x')], attr_cotangents_ref[0], - check_dtypes=False) - self.assertAllClose(attr_cotangents[(thing2, 'x')], attr_cotangents_ref[1], - check_dtypes=False) - - -if __name__ == '__main__': - absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/batching_test.py b/tests/batching_test.py index f2a4e8c34fe3..393317bcbe77 100644 --- a/tests/batching_test.py +++ b/tests/batching_test.py @@ -1328,33 +1328,70 @@ def list_insert(lst: list[a], idx: int, val: a) -> list[a]: @jtu.thread_unsafe_test_class() # temporary registration isn't thread-safe class VmappableTest(jtu.JaxTestCase): - def test_basic(self): + @parameterized.parameters([False, True]) + def test_basic(self, jit): with temporarily_register_named_array_vmappable(): def f(x): return named_mul(x, x) + if jit: + f = jax.jit(f) x = NamedArray(['i', 'j'], jnp.arange(12.).reshape(3, 4)) g = jax.vmap(f, - in_axes=NamedMapSpec('i', 0), - out_axes=NamedMapSpec('i', 1), - axis_size=3) + in_axes=NamedMapSpec('i', 0), + out_axes=NamedMapSpec('i', 1), + axis_size=3) ans = g(x) expected = NamedArray(['j', 'i'], jnp.arange(12.).reshape(3, 4).T ** 2) self.assertEqual(ans.names, expected.names) self.assertAllClose(ans.data, expected.data) - def test_basic_jit(self): - with temporarily_register_named_array_vmappable(): - def f(x): - return named_mul(x, x) - - x = NamedArray(['i', 'j'], jnp.arange(12.).reshape(3, 4)) - ans = jax.jit(f)(x) - expected = NamedArray(['i', 'j'], jnp.arange(12.).reshape(3, 4) ** 2) - - self.assertEqual(ans.names, expected.names) - self.assertAllClose(ans.data, expected.data) + def test_to_elt_that_binds_primitives(self): + class A: + data: Array + def __init__(self, data): + self.data = data + def to_elt(cont, _, val, spec): + return cont(val.data + 1, spec) + def from_elt(cont, size, elt, spec): + assert False + + @jax.jit + def f(): + a = A(jnp.arange(3.)) + return jax.vmap(lambda x: x - 1, axis_size=3)(a) + + try: + batching.register_vmappable(A, int, int, to_elt, from_elt, None) + ans = f() + finally: + batching.unregister_vmappable(A) + + self.assertAllClose(ans, jnp.arange(3.)) + + def test_from_elt_that_binds_primitives(self): + class A: + data: Array + def __init__(self, data): + self.data = data + def to_elt(cont, _, val, spec): + return A(cont(val.data, spec)) + def from_elt(cont, size, elt, spec): + return A(cont(size, elt.data + 1, spec)) + + @jax.jit + def f(): + a = A(jnp.arange(3.)) + return jax.vmap(lambda x: x, axis_size=3)(a).data + + try: + batching.register_vmappable(A, int, int, to_elt, from_elt, None) + ans = f() + finally: + batching.unregister_vmappable(A) + + self.assertAllClose(ans, jnp.arange(3.) + 1) def test_types_with_same_spec(self): # We register NamedArray. diff --git a/tests/buffer_callback_test.py b/tests/buffer_callback_test.py new file mode 100644 index 000000000000..138d170a152e --- /dev/null +++ b/tests/buffer_callback_test.py @@ -0,0 +1,211 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest +from absl.testing import parameterized +import numpy as np + +import jax +import jax.numpy as jnp +from jax._src import test_util as jtu +from jax._src.lib import jaxlib_extension_version +from jax.experimental import buffer_callback + +jax.config.parse_flags_with_absl() + + +class BufferCallbackTest(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not supported on TPU.") + + @parameterized.parameters(jtu.dtypes.all) + @jtu.run_on_devices("cpu") + def test_numpy(self, dtype): + def callback(ctx, out, arg): + with self.assertRaisesRegex( + jax.errors.JaxRuntimeError, "XLA FFI GPU context is not available" + ): + ctx.stream + + self.assertEqual(ctx.stage, buffer_callback.ExecutionStage.EXECUTE) + self.assertEqual(arg.shape, shape) + self.assertEqual(arg.dtype, dtype) + self.assertEqual(out.shape, shape) + self.assertEqual(out.dtype, dtype) + + self.assertFalse(arg.writeable) + self.assertTrue(out.writeable) + + x = np.asarray(arg) + self.assertArraysEqual(x, data) + + y = np.asarray(out) + self.assertEqual(x.dtype, y.dtype) + self.assertEqual(x.shape, y.shape) + y[...] = x + + rng = jtu.rand_default(self.rng()) + shape = (3, 4) + data = rng(shape, dtype) + fun = buffer_callback.buffer_callback( + callback, jax.ShapeDtypeStruct(data.shape, data.dtype) + ) + self.assertArraysEqual(fun(data), data) + + @parameterized.parameters(jtu.dtypes.all) + @jtu.run_on_devices("cpu") + def test_dlpack(self, dtype): + if dtype == jnp.bfloat16: + self.skipTest("Numpy's DLPack implementation does not support bfloat16") + + def callback(ctx, out, arg): + del ctx # unused + + x = np.from_dlpack(arg) + self.assertArraysEqual(x, data) + + y = np.from_dlpack(out) + self.assertEqual(x.dtype, y.dtype) + self.assertEqual(x.shape, y.shape) + + rng = jtu.rand_default(self.rng()) + shape = (3, 4) + data = rng(shape, dtype) + fun = buffer_callback.buffer_callback( + callback, jax.ShapeDtypeStruct(data.shape, data.dtype) + ) + + # We can't actually test the output because numpy doesn't support writable + # DLPack tensors. + jax.block_until_ready(fun(data)) + + @parameterized.product( + dtype=jtu.dtypes.all, command_buffer_compatible=[True, False] + ) + @jtu.run_on_devices("gpu") + def test_cuda_array_interface(self, dtype, command_buffer_compatible): + if command_buffer_compatible and jaxlib_extension_version < 337: + self.skipTest("Requires jaxlib extension version of at least 337.") + + def callback(ctx, out, arg): + ctx.stream # doesn't crash + + self.assertEqual(ctx.stage, buffer_callback.ExecutionStage.EXECUTE) + self.assertEqual(arg.shape, shape) + self.assertEqual(arg.dtype, dtype) + self.assertEqual(out.shape, shape) + self.assertEqual(out.dtype, dtype) + + obj = arg.__cuda_array_interface__ + self.assertEqual(obj["shape"], data.shape) + self.assertEqual(obj["typestr"], data.dtype.str) + + obj = out.__cuda_array_interface__ + self.assertEqual(obj["shape"], data.shape) + self.assertEqual(obj["typestr"], data.dtype.str) + + rng = jtu.rand_default(self.rng()) + shape = (3, 4) + data = rng(shape, dtype) + fun = buffer_callback.buffer_callback( + callback, jax.ShapeDtypeStruct(data.shape, data.dtype), + command_buffer_compatible=command_buffer_compatible, + ) + + # TODO: There's an XLA:GPU/CUDA bug that causes a segfault when + # instantiating an empty CUDA graph. Once that bug is fixed or worked + # around, add a test that checks that the Python callback is only executed + # once. + jax.block_until_ready(fun(data)) + + @parameterized.parameters([ + "sequential", "sequential_unrolled", "expand_dims", "broadcast_all" + ]) + @jtu.run_on_devices("cpu") + def test_batching(self, vmap_method): + def callback(ctx, out, *args): + del ctx # unused + x = np.asarray(args[0]) + y = np.asarray(args[1]) + z = np.asarray(out) + z[...] = x + z[...] += y + + rng = jtu.rand_default(self.rng()) + shape = (3, 4) + x = rng(shape, jnp.float32) + y = rng(shape, jnp.float32) + fun = buffer_callback.buffer_callback( + callback, + jax.ShapeDtypeStruct(x.shape[1:], x.dtype), + vmap_method=vmap_method, + ) + self.assertArraysEqual(jax.vmap(fun)(x, y), x + y) + + @jtu.run_on_devices("cpu") + def test_input_output_aliases(self): + def callback(ctx, out, arg): + del ctx # unused + x = np.asarray(arg) + y = np.asarray(out) + self.assertEqual(x.ctypes.data, y.ctypes.data) + + rng = jtu.rand_default(self.rng()) + shape = (3, 4) + data = rng(shape, jnp.float32) + fun = buffer_callback.buffer_callback( + callback, jax.ShapeDtypeStruct(data.shape, data.dtype), + input_output_aliases={0: 0}, + ) + jax.block_until_ready(fun(data)) + + @jtu.run_on_devices("cpu") + def test_buffer_callback_multi_mesh(self): + def no_op(*args, **kwargs): + pass + + @jax.jit + def f(x, y): + z = x * y + output_shape = jax.ShapeDtypeStruct(x.shape, x.dtype) + buffer_call = buffer_callback.buffer_callback( + no_op, output_shape, command_buffer_compatible=True) + return buffer_call((z,)) + + mesh1 = jtu.create_mesh((1, 1), ('a', 'b')) + mesh2 = jtu.create_mesh((1, 1), ('x', 'y')) + + x = jax.device_put( + jnp.ones((32, 32)), jax.NamedSharding(mesh1, jax.P('a', 'b'))) + y = jax.device_put( + jnp.ones((32, 32)), jax.NamedSharding(mesh2, jax.P('x', 'y'))) + f(x, y) # doesn't crash + + def test_side_effect(self): + def callback(*_): + nonlocal called + called = True + + called = False + fun = buffer_callback.buffer_callback( + callback, jax.ShapeDtypeStruct((), jnp.float32), has_side_effect=True) + jax.block_until_ready(fun()) + self.assertTrue(called) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/cache_key_test.py b/tests/cache_key_test.py index 2faa4dbaf9d4..35ac03011a97 100644 --- a/tests/cache_key_test.py +++ b/tests/cache_key_test.py @@ -83,9 +83,9 @@ def test_hash_accelerator_devices(self): self.assertEqual(dev_hash1, dev_hash2) acc_hash1 = self.get_hashed_value( - cache_key._hash_accelerator_config, devices, xla_bridge.get_backend()) + cache_key._hash_accelerator_config, devices) acc_hash2 = self.get_hashed_value( - cache_key._hash_accelerator_config, devices, xla_bridge.get_backend()) + cache_key._hash_accelerator_config, devices) self.assertEqual(acc_hash1, acc_hash2) def test_hash_platform(self): @@ -163,6 +163,8 @@ def test_different_computations(self): cache_key.get(computation2, devices, compile_options, backend), ) + # TODO(phawkins): this test flakes if test concurrency is enabled. + @jtu.thread_unsafe_test() def test_custom_partitioning_ptr_removal(self): def _partition(mesh, arg_shapes, result_shape): arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes) @@ -178,7 +180,8 @@ def _cp_add(x, y): _cp_add.def_partition( infer_sharding_from_operands=_infer_sharding_from_operands, - partition=_partition) + partition=_partition, + sharding_rule='..., ... -> ...') devices = np.asarray(jax.devices()) with Mesh(devices, ('x',)) as m: diff --git a/tests/checkify_test.py b/tests/checkify_test.py index 6a1660b28578..9cc1ab2bc515 100644 --- a/tests/checkify_test.py +++ b/tests/checkify_test.py @@ -22,15 +22,13 @@ import jax from jax import lax from jax.experimental import checkify -from jax.experimental import pjit -from jax.experimental import shard_map -from jax.sharding import NamedSharding +from jax._src import shard_map +from jax.sharding import NamedSharding, PartitionSpec as P from jax._src import array from jax._src import config from jax._src import core from jax._src import test_util as jtu from jax._src.checkify import JaxRuntimeError, FailedCheckError, ErrorEffect, OOBError -from jax._src.lib import xla_extension import jax.numpy as jnp config.parse_flags_with_absl() @@ -250,7 +248,7 @@ def f(x): xs = jnp.array([3., 0.]) err, _ = checked_f(xs) self.assertIsNotNone(err.get()) - self.assertStartsWith(err.get(), "nan generated by primitive: sin") + self.assertIn("nan generated by primitive: sin", err.get()) def test_pmap_collectives(self): if len(jax.devices()) < 4: @@ -426,7 +424,7 @@ def f(init_val): def test_while_loop_cond_error(self): def while_cond(val): _ = jnp.sin(1./val) - return val < 2. + return 0. * _ + val < 2. def while_body(val): return val+1. @@ -475,12 +473,25 @@ def f(init_val): self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), "division by zero") + def test_checify_donation_no_forwarding(self): + mesh = jtu.create_mesh((2,), ('x',)) + + @checkify.checkify + @jax.jit(donate_argnums=(0,)) + def f(x: jax.Array) -> jax.Array: + checkify.check(jnp.all(x > 0), "a") + return x + + x = jax.device_put(jnp.zeros(64, dtype="int32"), NamedSharding(mesh, P())) + err, y = f(x) + err, z = f(y) # doesn't crash + @jtu.skip_on_devices("tpu") def test_while_loop_body_and_cond_error(self): def while_cond(val): i, cond_val, _ = val - _ = jnp.sin(cond_val) - return i < 2 + j = jnp.sin(cond_val) + return i + (0. * j) < 2 # don't let the sin value be dead code def while_body(val): i, cond_val, body_val = val @@ -512,7 +523,7 @@ def f(cond_val, body_val): # first error which occurs is in cond self.assertStartsWith(err.get(), "nan generated by primitive: sin") - def test_pjit(self): + def test_checkify_jit(self): def f(x): # unary func return x / x @@ -527,11 +538,11 @@ def g(x, y): inp = np.arange(8) x = array.make_array_from_callback(inp.shape, ps, lambda idx: inp[idx]) - f = pjit.pjit(f, in_shardings=ps, out_shardings=ps) + f = jax.jit(f, in_shardings=ps, out_shardings=ps) f = checkify.checkify(f, errors=checkify.float_checks) - g = pjit.pjit(g, in_shardings=ps, out_shardings=ps) + g = jax.jit(g, in_shardings=ps, out_shardings=ps) g = checkify.checkify(g, errors=checkify.float_checks) - with mesh: + with jax.set_mesh(mesh): u_err, _ = f(x) b_err, _ = g(x, x) @@ -541,7 +552,7 @@ def g(x, y): self.assertStartsWith(b_err.get(), "division by zero") @parameterized.parameters(True, False) - def test_shard_map(self, check_rep): + def test_shard_map(self, check_vma): def f(x): # unary func return jax.lax.axis_index("dev") * x / x @@ -558,12 +569,12 @@ def g(x, y): x = array.make_array_from_callback(inp.shape, ps, lambda idx: inp[idx]) f = shard_map.shard_map( - f, mesh, in_specs=pspec, out_specs=pspec, check_rep=check_rep + f, mesh=mesh, in_specs=pspec, out_specs=pspec, check_vma=check_vma ) f = jax.jit(f, in_shardings=ps, out_shardings=ps) f = checkify.checkify(f, errors=checkify.float_checks) g = shard_map.shard_map( - g, mesh, in_specs=(pspec, pspec), out_specs=pspec, check_rep=check_rep + g, mesh=mesh, in_specs=(pspec, pspec), out_specs=pspec, check_vma=check_vma ) g = jax.jit(g, in_shardings=(ps, ps), out_shardings=ps) g = checkify.checkify(g, errors=checkify.float_checks) @@ -764,8 +775,9 @@ def while_body(carry): def test_multiple_payloads(self): def f(x): - _ = x[5] - _ = x[6] + a = x[5] + b = x[6] + return a + b err, _ = checkify.checkify(f, errors=checkify.index_checks)(jnp.ones((2,))) self.assertIsNotNone(err.get()) @@ -1215,7 +1227,7 @@ def while_body(s): with self.assertRaisesRegex(ValueError, "checkify-of-vmap-of-while"): checked_f(jnp.asarray([1., 2., 3.]), jnp.asarray([5., 2., 4.])) - # TODO(lenamartens): reenable assertions below. + # TODO(lenamartens): re-enable assertions below. # self.assertIsNotNone(err.get()) # self.assertStartsWith(err.get(), "division by zero") @@ -1244,7 +1256,7 @@ def fun(x): with self.assertRaisesRegex(ValueError, "checkify-of-vmap-of-while"): checked_f(jnp.arange(5)) - # TODO(lenamartens): reenable assertions below. + # TODO(lenamartens): re-enable assertions below. # self.assertIsNone(err.get()) def test_assert_cond_no_data_dependence(self): @@ -1374,9 +1386,9 @@ def f(x): checkify.check(x > 0, "x needs to be positive") return x - with self.assertRaisesRegex(xla_extension.XlaRuntimeError, + with self.assertRaisesRegex(jax.errors.JaxRuntimeError, "x needs to be positive"): - f(-1.) + f(-1.).block_until_ready() if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/colocated_python_test.py b/tests/colocated_python_test.py index 52d494904fe6..c89b9a36958f 100644 --- a/tests/colocated_python_test.py +++ b/tests/colocated_python_test.py @@ -16,9 +16,6 @@ import struct import tempfile import threading -import time -from typing import Sequence -import unittest from absl.testing import absltest from absl.testing import parameterized @@ -36,27 +33,9 @@ try: import cloudpickle # noqa + HAS_CLOUDPICKLE = True except (ModuleNotFoundError, ImportError): - raise unittest.SkipTest("tests depend on cloudpickle library") - - -def _colocated_cpu_devices( - devices: Sequence[jax.Device], -) -> Sequence[jax.Device]: - """Returns CPU devices colocated with the given devices.""" - try: - return colocated_python.colocated_cpu_devices(devices) - except (ValueError, AttributeError): - # PjRt-IFRT prepares CPU devices by its own. - # TODO(hyeontaek): Remove this fallback path once PjRt-IFRT prepares CPU - # devices by its own. - cpu_backend_devices = jax.local_devices(backend="cpu") - device_index_map = {device.id: i for i, device in enumerate(jax.devices())} - - available_devices = devices[: min(len(cpu_backend_devices), len(devices))] - return [ - cpu_backend_devices[device_index_map[d.id]] for d in available_devices - ] + HAS_CLOUDPICKLE = False _count_colocated_python_specialization_cache_miss = jtu.count_events( @@ -68,32 +47,113 @@ class ColocatedPythonTest(jtu.JaxTestCase): def setUp(self): super().setUp() + if not HAS_CLOUDPICKLE: + self.skipTest( + "ColocatedPythonTest depends on cloudpickle library" + ) if np.lib.NumpyVersion(np.__version__) < "2.0.0": self.skipTest( - "Serialization in Colocated Python needs StringDType, and thus" - " requires NumPy 2.0.0 or later" + "Serialization in Colocated Python needs StringDType, and thus" + " requires NumPy 2.0.0 or later" ) - def testMakeColocatedPythonProgram(self): + def test_colocated_cpu_devices(self): + mesh = jax.sharding.Mesh( + np.array(jax.local_devices()[:1]).reshape((1, 1)), ("x", "y") + ) + cpu_mesh1 = colocated_python.colocated_cpu_devices(mesh) + + cpu_devices = colocated_python.colocated_cpu_devices( + jax.local_devices()[:1] + ) + cpu_mesh2 = jax.sharding.Mesh( + np.array(cpu_devices).reshape((1, 1)), ("x", "y") + ) + self.assertEqual(cpu_mesh1, cpu_mesh2) + + def test_serialization_roundtrip(self): + cpu_devices = colocated_python.colocated_cpu_devices( + jax.local_devices()[:1]) + + mesh = jax.sharding.Mesh(np.array(cpu_devices).reshape((1, 1)), ("x", "y")) + self.assertEqual( + serialization._deserialize(serialization._serialize(mesh)), mesh) + + sharding1 = jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec("x")) + self.assertEqual( + serialization._deserialize(serialization._serialize([sharding1])), + [sharding1]) + + sharding2 = jax.sharding.SingleDeviceSharding( + cpu_devices[0], memory_kind="pinned_host") + self.assertEqual( + serialization._deserialize(serialization._serialize((sharding2,))), + (sharding2,)) + + def func(x): + return x + 1 + + self.assertEqual( + serialization._deserialize(serialization._serialize(func))(1), func(1)) + + def test_make_colocated_python_program(self): def add_one(x): return x + 1 - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) sharding = jax.sharding.SingleDeviceSharding(cpu_devices[0]) sds = jax.ShapeDtypeStruct((), jnp.int32, sharding=sharding) - pickled_function = serialization._serialize(add_one) + fun_and_specialization = ( + add_one, + None, # dummy in_specs_treedef + None, # dummy in_specs_leaves + None, # dummy out_specs_treedef + None, # dummy out_specs_leaves + None, # dummy devices + ) + pickled_function = serialization._serialize(fun_and_specialization) program = ifrt_programs.make_colocated_python_program( "add_one", pickled_function, [cpu_devices[0]], [sds], [sds] ) del program - def testSimpleFunction(self): + def test_serialize_with_shared_obj(self): + cpu_devices = colocated_python.colocated_cpu_devices( + jax.local_devices()[:1]) + mesh = jax.sharding.Mesh( + np.array(cpu_devices).reshape((1, 1)), + ("long_axis_name_1", "long_axis_name_2")) + sharding1 = jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec("long_axis_name_1")) + sharding2 = jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec("long_axis_name_2")) + + serialized1 = serialization._serialize([sharding1]) + serialized2 = serialization._serialize([sharding1, sharding2]) + serialized3 = serialization._serialize([sharding1, sharding1]) + + # The total serialized size of two shardings of a shared mesh should be less + # than twice the serialized size of a single sharding. + self.assertLess(len(serialized2), len(serialized1) * 2) + + # The total serialized size of two identical shardings should be less than + # that of two shardings that only share the mesh. + self.assertLess(len(serialized3), len(serialized2)) + + self.assertEqual(serialization._deserialize(serialized1), [sharding1]) + self.assertEqual( + serialization._deserialize(serialized2), [sharding1, sharding2]) + self.assertEqual( + serialization._deserialize(serialized3), [sharding1, sharding1]) + + def test_simple_function(self): @colocated_python.colocated_python def add_one(x): return x + 1 - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) x = np.array(1) x = jax.device_put(x, cpu_devices[0]) @@ -108,12 +168,12 @@ def add_one(x): self.assertEqual(out, np.array(2)) self.assertEqual(count(), 1) - def testSimpleFunctionWithTree(self): + def test_simple_function_with_tree(self): @colocated_python.colocated_python def add_one(x): return jax.tree.map(lambda x: x + 1, x) - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) x = [np.array(1), (np.array(2), {"v": np.array(3)})] x = jax.device_put(x, jax.sharding.SingleDeviceSharding(cpu_devices[0])) @@ -128,7 +188,7 @@ def add_one(x): self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})]) self.assertEqual(count(), 1) - def testEmptyInputFailsWithoutSpecialization(self): + def test_empty_input_fails_without_specialization(self): @colocated_python.colocated_python def make_zero(): return jnp.array(0) @@ -136,16 +196,15 @@ def make_zero(): with self.assertRaisesRegex( ValueError, "No devices found. colocated_python function without input arguments" - " must be first specialized with devices.", - ): + " must be first specialized with devices."): _ = make_zero() - def testEmptyInputWithDevicesSpecialization(self): + def test_empty_input_with_devices_specialization(self): @colocated_python.colocated_python def make_zero(): return jnp.array(0) - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) with _count_colocated_python_specialization_cache_miss() as count: make_zero = make_zero.specialize(devices=cpu_devices[:1]) @@ -159,12 +218,12 @@ def make_zero(): self.assertEqual(out, np.array(0)) self.assertEqual(count(), 1) - def testInputPolymorphismWithoutOutSpecsFn(self): + def test_input_polymorphism_without_out_specs_fn(self): @colocated_python.colocated_python def add_one(x): return jax.tree.map(lambda x: x + 1, x) - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) x = np.array(1) x = jax.device_put(x, cpu_devices[0]) @@ -193,12 +252,12 @@ def add_one(x): self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})]) self.assertEqual(count(), 2) - def testInputPolymorphismAllowedWithOutSpecsFn(self): + def test_input_polymorphism_allowed_with_out_specs_fn(self): @colocated_python.colocated_python def add_one(x): return jax.tree.map(lambda x: x + 1, x) - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) x = np.array(1) x = jax.device_put(x, cpu_devices[0]) @@ -232,82 +291,108 @@ def add_one(x): ("on_main_thread", True), ("on_non_main_thread", False), ) - def testSequentialExecution(self, on_main_thread: bool): - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + # Cannot run concurrently with other tests using `colocated_python._testing_global_state`. + @jtu.thread_unsafe_test() + def test_sequential_execution(self, on_main_thread: bool): + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) x = np.array(1) x = jax.device_put(x, cpu_devices[0]) - # Make sure that this input array is ready for use by the colocated Python - # function and does not disrupt elapsed time measurement. - jax.block_until_ready(x) @colocated_python.colocated_python - def sleep(x: jax.Array) -> jax.Array: - time.sleep(5) + def func0(x: jax.Array) -> jax.Array: + colocated_python._testing_global_state = 100 return x - # Specify out_specs_fn so that all executions are asynchronously dispatched. - sleep = sleep.specialize(out_specs_fn=lambda x: x) + @colocated_python.colocated_python + def func1(x: jax.Array) -> jax.Array: + assert "_testing_global_state" in colocated_python.__dict__ + assert colocated_python._testing_global_state == 100 + colocated_python._testing_global_state += 1 + return x - def sleep_twice_and_wait(x: jax.Array) -> None: - _ = sleep(x) - jax.block_until_ready(sleep(x)) + @colocated_python.colocated_python + def func2(x: jax.Array) -> jax.Array: + assert "_testing_global_state" in colocated_python.__dict__ + assert colocated_python._testing_global_state == 101 + return x - start_time = time.time() + @colocated_python.colocated_python + def cleanup(x: jax.Array) -> jax.Array: + if "_testing_global_state" in colocated_python.__dict__: + del colocated_python._testing_global_state + return x - # Two executions of `sleep` within `sleep_twice_and_wait` should run - # sequentially. - if on_main_thread: - sleep_twice_and_wait(x) - else: - t = threading.Thread(target=sleep_twice_and_wait, args=(x,)) - t.start() - t.join() + # Specify out_specs_fn so that their executions are asynchronously + # dispatched. + func0 = func0.specialize(out_specs_fn=lambda x: x) + func1 = func1.specialize(out_specs_fn=lambda x: x) + func2 = func2.specialize(out_specs_fn=lambda x: x) - elapsed_time = time.time() - start_time + def calls(x: jax.Array) -> None: + # No explicit blocking before making the next call. + func0(x) + func1(x) + jax.block_until_ready(func2(x)) - # If sequential execution did not happen, elapsed time typically will be - # around 5 seconds. - self.assertGreaterEqual(elapsed_time, 10) + try: + # Executions in `calls` should run sequentially. + if on_main_thread: + calls(x) + else: + t = threading.Thread(target=calls, args=(x,)) + t.start() + t.join() + # Executions should succeed without an error. + finally: + jax.block_until_ready(cleanup(x)) - def testConcurrentExecution(self): - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + # Cannot run concurrently with other tests using `colocated_python._testing_global_state`. + @jtu.thread_unsafe_test() + def test_concurrent_execution(self): + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) x = np.array(1) x = jax.device_put(x, cpu_devices[0]) - # Make sure that this input array is ready for use by the colocated Python - # function and does not disrupt elapsed time measurement. - jax.block_until_ready(x) @colocated_python.colocated_python - def sleep(x: jax.Array) -> jax.Array: - time.sleep(5) + def init(x: jax.Array) -> jax.Array: + colocated_python._testing_global_state = threading.Barrier(3) return x - # Specify out_specs_fn so that all executions are asynchronously dispatched. - sleep = sleep.specialize(out_specs_fn=lambda x: x) - - def sleep_and_wait(x: jax.Array) -> None: - jax.block_until_ready(sleep(x)) - - start_time = time.time() + @colocated_python.colocated_python + def func(x: jax.Array) -> jax.Array: + assert "_testing_global_state" in colocated_python.__dict__ + colocated_python._testing_global_state.wait(timeout=5) + return x - # All three executions of `sleep_and_wait` should run concurrently. - t1 = threading.Thread(target=sleep_and_wait, args=(x,)) - t2 = threading.Thread(target=sleep_and_wait, args=(x,)) - t1.start() - t2.start() - sleep_and_wait(x) - t1.join() - t2.join() + @colocated_python.colocated_python + def cleanup(x: jax.Array) -> jax.Array: + if "_testing_global_state" in colocated_python.__dict__: + del colocated_python._testing_global_state + return x - elapsed_time = time.time() - start_time + # Specify out_specs_fn so that their executions are asynchronously + # dispatched. + func = func.specialize(out_specs_fn=lambda x: x) - self.assertGreaterEqual(elapsed_time, 5) - # If concurrent execution did not happen, elapsed time typically will be - # around 15 seconds. - self.assertLess(elapsed_time, 10) + try: + jax.block_until_ready(init(x)) + + # All func calls should run concurrently and enter/exit the barrier. + t1 = threading.Thread(target=func, args=(x,)) + t2 = threading.Thread(target=func, args=(x,)) + t3 = threading.Thread(target=func, args=(x,)) + t1.start() + t2.start() + t3.start() + t1.join() + t2.join() + t3.join() + # Executions should succeed without a deadlock. + finally: + jax.block_until_ready(cleanup(x)) - def testInputsWithDifferentDeviceOrders(self): - cpu_devices = _colocated_cpu_devices(jax.local_devices())[:2] + def test_inputs_with_different_device_orders(self): + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices())[:2] if len(cpu_devices) < 2: self.skipTest("Not enough CPU devices") @@ -348,7 +433,7 @@ def add(x: jax.Array, y: jax.Array) -> jax.Array: out = jax.device_get(out) np.testing.assert_equal(out, np.array([2 + 4, 0 + 8])) - def testModuleVariableAccess(self): + def test_module_variable_access(self): try: # The following pattern of storing and accessing non-serialized state in # the Python module is discouraged for storing user-defined state. @@ -372,7 +457,7 @@ def get_global_state(x: jax.Array) -> jax.Array: del x return colocated_python._testing_global_state - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) x = np.array(1) x = jax.device_put(x, cpu_devices[0]) y = np.array(2) @@ -389,8 +474,8 @@ def get_global_state(x: jax.Array) -> jax.Array: if "_testing_global_state" in colocated_python.__dict__: del colocated_python._testing_global_state - def testStringProcessing(self): - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + def test_string_processing(self): + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) if len(cpu_devices) < 2: self.skipTest(f"Need at least two CPU devices, got: {len(cpu_devices)}") @@ -430,8 +515,8 @@ def f(x): ), ) - def testBinaryDataProcessing(self): - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + def test_binary_data_processing(self): + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) if len(cpu_devices) < 1: self.skipTest("Need at least one CPU devices") @@ -472,8 +557,8 @@ def f(x): self.assertEqual(out_ints[0], 1002) self.assertEqual(out_ints[1], 1003) - def testDetectInvalidMeshDevice(self): - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + def test_detect_invalid_mesh_device(self): + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) if jax.local_devices()[0].id == cpu_devices[0].id: self.skipTest( "This test only works in a setup where accelerator and CPU devices" @@ -493,9 +578,12 @@ def make_zero() -> jax.Array: make_zero = make_zero.specialize(devices=cpu_devices) jax.block_until_ready(make_zero()) - def testObjectLifecycle(self): - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + # Cannot run concurrently with other tests using `colocated_python._testing_global_state`. + @jtu.thread_unsafe_test() + def test_object_lifecycle(self): + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) sharding = jax.sharding.SingleDeviceSharding(cpu_devices[0]) + x = jax.device_put(np.array(0), sharding) @colocated_python.colocated_python_class class Object: @@ -506,8 +594,6 @@ def __init__(self) -> None: def __del__(self) -> None: colocated_python._testing_destroyed = True - # TODO(hyeontaek): Support method calls with no arguments and remove - # `x` parameter. def echo(self, x: jax.Array) -> jax.Array: return x @@ -522,15 +608,15 @@ def check_destroyed() -> jax.Array: return jax.device_put(np.array(destroyed), sharding) @colocated_python.colocated_python - def cleanup(): + def cleanup(x: jax.Array) -> jax.Array: if "_testing_initialized" in colocated_python.__dict__: del colocated_python._testing_initialized if "_testing_destroyed" in colocated_python.__dict__: del colocated_python._testing_destroyed + return x check_initialized = check_initialized.specialize(devices=cpu_devices[:1]) check_destroyed = check_destroyed.specialize(devices=cpu_devices[:1]) - cleanup = cleanup.specialize(devices=cpu_devices[:1]) try: # Object initialization is deferred until the first method call. @@ -544,7 +630,7 @@ def cleanup(): self.assertEqual(jax.device_get(check_initialized()), False) self.assertEqual(jax.device_get(check_destroyed()), False) finally: - cleanup() + jax.block_until_ready(cleanup(x)) try: # Object initialization is deferred until the first method call. @@ -555,7 +641,7 @@ def cleanup(): # The first method call on a process triggers object initialization there. x = np.array(1) x = jax.device_put(x, sharding) - obj.echo(x) + jax.block_until_ready(obj.echo(x)) self.assertEqual(jax.device_get(check_initialized()), True) self.assertEqual(jax.device_get(check_destroyed()), False) @@ -563,10 +649,10 @@ def cleanup(): self.assertEqual(jax.device_get(check_initialized()), True) self.assertEqual(jax.device_get(check_destroyed()), True) finally: - cleanup() + jax.block_until_ready(cleanup(x)) - def testStatefulObject(self): - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + def test_stateful_object(self): + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) @colocated_python.colocated_python_class class Value: @@ -578,9 +664,7 @@ def add(self, x: jax.Array) -> jax.Array: self.value += np.asarray(x) return jax.device_put(self.value, x.sharding) - # TODO(hyeontaek): Support method calls with no arguments and remove - # `x` parameter. - def fetch(self, x: jax.Array) -> jax.Array: + def fetch_like(self, x: jax.Array) -> jax.Array: return jax.device_put(self.value, x.sharding) value = Value(np.array(5)) @@ -594,11 +678,11 @@ def fetch(self, x: jax.Array) -> jax.Array: out = jax.device_get(value.add(x)) self.assertEqual(out, np.array(7)) - out = jax.device_get(value.fetch(x)) + out = jax.device_get(value.fetch_like(x)) self.assertEqual(out, np.array(7)) - def testObjectWithCapturedSharding(self): - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + def test_object_with_captured_sharding(self): + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) if len(cpu_devices) < 2: self.skipTest(f"Need at least two CPU devices, got: {len(cpu_devices)}") @@ -642,6 +726,51 @@ def add_sharding2(self, x: jax.Array) -> jax.Array: out = jax.device_get(out) self.assertArraysEqual(out, np.array([7, 17])) + def test_object_method_specialization(self): + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) + cpu_devices = cpu_devices[:1] + sharding = jax.sharding.SingleDeviceSharding(cpu_devices[0]) + + @colocated_python.colocated_python_class + class Object: + + def __init__(self, sharding: jax.sharding.Sharding) -> None: + self.sharding = sharding + + def fetch_with_devices(self) -> jax.Array: + return jax.device_put(np.array(1, dtype=np.int32), self.sharding) + + def fetch_with_output_spec(self) -> np.ndarray: + return jax.device_put(np.array(1, dtype=np.int32), self.sharding) + + obj = Object(sharding) + + with self.assertRaisesRegex( + ValueError, + "No devices found. colocated_python function without input arguments" + " must be first specialized with devices."): + jax.block_until_ready(obj.fetch_with_devices()) + + with self.assertRaisesRegex( + ValueError, + "No devices found. colocated_python function without input arguments" + " must be first specialized with devices."): + jax.block_until_ready(obj.fetch_with_output_spec()) + + obj.fetch_with_devices = ( + obj.fetch_with_devices.specialize(devices=cpu_devices)) + out = obj.fetch_with_devices() + self.assertArraysEqual(out, np.array(1, dtype=np.int32)) + + # TODO(hyeontaek): Infer `devices` from the output spec computed using the + # output spec function. + obj.fetch_with_output_spec = obj.fetch_with_output_spec.specialize( + devices=cpu_devices, + out_specs_fn=lambda: jax.ShapeDtypeStruct( + shape=(), dtype=np.int32, sharding=sharding)) + out = obj.fetch_with_output_spec() + self.assertArraysEqual(out, np.array(1, dtype=np.int32)) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py index 3fcc0ab476bf..fa7ec7d81c14 100644 --- a/tests/compilation_cache_test.py +++ b/tests/compilation_cache_test.py @@ -15,7 +15,7 @@ from __future__ import annotations from collections import Counter -from functools import partial +import glob import logging import math import os @@ -43,7 +43,6 @@ from jax._src import xla_bridge from jax._src.compilation_cache_interface import CacheInterface from jax._src.lib import xla_client as xc -from jax.experimental.pjit import pjit from jax.sharding import PartitionSpec as P import numpy as np @@ -134,7 +133,7 @@ def test_get_no_executable(self): backend = xla_bridge.get_backend() key = cc.get_cache_key(computation, devices, compile_options, backend) executable, compile_time = cc.get_executable_and_time( - key, compile_options, backend) + key, compile_options, backend, xc.DeviceList(tuple(devices.flat))) self.assertIsNone(executable) self.assertIsNone(compile_time) @@ -145,15 +144,20 @@ def test_diff_executables(self): num_replicas=1, num_partitions=1 ) backend = xla_bridge.get_backend() - executable1 = backend.compile(computation1, compile_options) - executable2 = backend.compile(computation2, compile_options) + executable_devices = xc.DeviceList(tuple(backend.local_devices())) + executable1 = backend.compile_and_load( + computation1, executable_devices, compile_options) + executable2 = backend.compile_and_load( + computation2, executable_devices, compile_options) cc.put_executable_and_time( "key1", "computation1", executable1, backend, FAKE_COMPILE_TIME) cc.put_executable_and_time( "key2", "computation2", executable2, backend, FAKE_COMPILE_TIME) self.assertNotEqual( - cc.get_executable_and_time("key1", compile_options, backend)[0], - cc.get_executable_and_time("key2", compile_options, backend)[0] + cc.get_executable_and_time( + "key1", compile_options, backend, executable_devices)[0], + cc.get_executable_and_time( + "key2", compile_options, backend, executable_devices)[0] ) def test_put_executable(self): @@ -167,12 +171,14 @@ def test_put_executable(self): num_replicas=1, num_partitions=1 ) backend = xla_bridge.get_backend() - executable = backend.compile(str(computation), compile_options) + executable_devices = xc.DeviceList(tuple(devices.flat)) + executable = backend.compile_and_load( + str(computation), executable_devices, compile_options) key = cc.get_cache_key(computation, devices, compile_options, backend) cc.put_executable_and_time( key, "alambda", executable, backend, FAKE_COMPILE_TIME) executable_retrieved, compile_time_retrieved = cc.get_executable_and_time( - key, compile_options, backend) + key, compile_options, backend, executable_devices) inputs_to_executable = ( jnp.array(1, dtype=np.int32), jnp.array(2, dtype=np.int32), @@ -184,7 +190,7 @@ def test_put_executable(self): def test_pmap(self): f = pmap(lambda x: x - lax.psum(x, "i"), axis_name="i") - x = np.arange(jax.device_count(), dtype=np.int64) + x = np.arange(jax.device_count(), dtype=np.int32) f(x) self.assertEqual(count_cache_items(), 1) x = np.arange(jax.device_count(), dtype=np.float32) @@ -192,13 +198,65 @@ def test_pmap(self): self.assertEqual(count_cache_items(), 2) # TODO: create a test for calling pmap with the same input more than once + def test_pmap_with_consts(self): + const = jnp.array([42, 43], dtype=np.int32) + clear_cache() + f = pmap(lambda x: x - lax.psum(x, "i") + const[0], axis_name="i") + x = np.arange(jax.device_count(), dtype=np.int32) + self.assertAllClose(f(x), x - np.sum(x, dtype=np.int32) + np.int32(42)) + self.assertEqual(count_cache_items(), 1) + + const1 = jnp.array([142, 143], dtype=np.int32) # another const + f1 = pmap(lambda x: x - lax.psum(x, "i") + const1[0], axis_name="i") + expected_compilations = 0 if config.use_simplified_jaxpr_constants.value else 1 + self.assertCacheMisses(lambda: f1(x), + lowering=1, + compilation_after_persistent_cache_miss=expected_compilations) + self.assertAllClose(f1(x), x - np.sum(x, dtype=np.int32) + np.int32(142)) + self.assertEqual(count_cache_items(), 1 + expected_compilations) + def test_jit(self): f = jit(lambda x: x * x) - f(1) + self.assertCacheMisses(lambda: f(1), lowering=1, + compilation_after_persistent_cache_miss=1) self.assertEqual(count_cache_items(), 1) + f1 = jit(lambda x: x * x) + self.assertCacheMisses(lambda: f1(2), lowering=1, + compilation_after_persistent_cache_miss=0) f(1.0) self.assertEqual(count_cache_items(), 2) + def test_jit_sharded(self): + mesh = jtu.create_mesh((2,), 'x') + with jax.set_mesh(mesh): + @jax.jit(in_shardings=(P("x"), P("x")), out_shardings=None) + def f(x, y): + return x + y + + shape = (8, 8) + x = np.arange(math.prod(shape), dtype=np.int64).reshape(shape) + f(x, x + 1) + self.assertEqual(count_cache_items(), 1) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) + f(x, x + 1) + self.assertEqual(count_cache_items(), 2) + + def test_jit_with_constants(self): + const = jnp.array([42, 43]) # A distinctive shape + clear_cache() + f = jit(lambda x: x * const[0]) + self.assertAllClose(f(2), 2 * 42) + self.assertEqual(count_cache_items(), 1) + + const1 = jnp.array([142, 143]) # The closed over const can be different + f1 = jit(lambda x: x * const1[0]) + expected_compilations = 0 if config.use_simplified_jaxpr_constants.value else 1 + self.assertCacheMisses( + lambda: f1(3), lowering=1, + compilation_after_persistent_cache_miss=expected_compilations) + self.assertAllClose(f1(3), 3 * 142) + self.assertEqual(count_cache_items(), 1 + expected_compilations) + def test_set_cache_dir_after_backends_init(self): # This a regression test for #25768 with config.compilation_cache_dir(None): @@ -237,7 +295,7 @@ def test_enable_compilation_cache(self): g = jit(lambda x: x * 3) g(2) cache = cc._get_cache(backend) - self.assertIsNotNone(cache) # Cache should be initalized + self.assertIsNotNone(cache) # Cache should be initialized def test_xla_autofdo_profile_version(self): original_profile_version = config.jax_xla_profile_version.value @@ -253,20 +311,6 @@ def test_xla_autofdo_profile_version(self): f(1) self.assertEqual(count_cache_items(), 1) - @jtu.with_mesh([("x", 2)]) - def test_pjit(self): - @partial(pjit, in_shardings=(P("x"), P("x")), out_shardings=None) - def f(x, y): - return x + y - - shape = (8, 8) - x = np.arange(math.prod(shape), dtype=np.int64).reshape(shape) - f(x, x + 1) - self.assertEqual(count_cache_items(), 1) - x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) - f(x, x + 1) - self.assertEqual(count_cache_items(), 2) - def test_cache_write_warning(self): f = jit(lambda x: x * x) @@ -284,7 +328,7 @@ def test_cache_write_warning(self): self.assertIn( ( "Error writing persistent compilation cache entry " - "for 'jit__lambda_': RuntimeError: test error" + "for 'jit__lambda': RuntimeError: test error" ), str(w[0].message), ) @@ -299,7 +343,7 @@ def test_cache_read_warning(self): test_warning_util.record_warnings() as w, ): mock_get.side_effect = RuntimeError("test error") - # Calling assertEqual with the jitted f will generate two PJIT + # Calling assertEqual with the jitted f will generate two JIT # executables: Equal and the lambda function itself. self.assertEqual(f(2).item(), 4) if len(w) != 1: @@ -308,7 +352,7 @@ def test_cache_read_warning(self): self.assertIn( ( "Error reading persistent compilation cache entry " - "for 'jit__lambda_': RuntimeError: test error" + "for 'jit__lambda': RuntimeError: test error" ), str(w[0].message), ) @@ -344,7 +388,8 @@ def test_cache_saving_metric(self): config.persistent_cache_min_entry_size_bytes(0), ): durations = Counter() # Map metric name to time duration. - def append_metric_duration(metric, duration): + def append_metric_duration(metric, duration, **kwargs): + del kwargs durations[metric] += duration with jtu.register_event_duration_listener(append_metric_duration): @@ -562,8 +607,9 @@ def test_backend_serialization_deserialization(self): .runtime_executable() ) serialized_executable = backend.serialize_executable(executable) - deserialized_executable = backend.deserialize_executable( - serialized_executable, None) + deserialized_executable = backend.deserialize_executable( # type: ignore + serialized_executable, + xc.DeviceList(tuple(jax.local_devices(backend=backend))), None) self.assertEqual( executable.fingerprint, deserialized_executable.fingerprint) @@ -603,6 +649,35 @@ def test_persistent_cache_enable_xla_caches(self): self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_per_fusion_autotune_cache_dir, f"jax-cache{s}xla_gpu_per_fusion_autotune_cache_dir") self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_experimental_autotune_cache_mode, xc.AutotuneCacheMode.UPDATE) + @jtu.skip_on_devices("tpu") # TPU backend does not dump on deserialize + def test_dump_on_cache_hit(self): + previous_counts = Counter(_counts) + with ( + config.persistent_cache_min_compile_time_secs(0), + config.persistent_cache_min_entry_size_bytes(0), + tempfile.TemporaryDirectory() as dump_dir1, + tempfile.TemporaryDirectory() as dump_dir2 + ): + jit(lambda x: x + 1, compiler_options={"xla_dump_to": dump_dir1})(1) + self.assertEqual( + _counts["/jax/compilation_cache/cache_hits"], + previous_counts["/jax/compilation_cache/cache_hits"], + ) + jit(lambda x: x + 1, compiler_options={"xla_dump_to": dump_dir2, "xla_dump_hlo_as_proto": True, "xla_dump_hlo_as_text": True})(1) + self.assertEqual( + _counts["/jax/compilation_cache/cache_hits"], + previous_counts["/jax/compilation_cache/cache_hits"] + 1, + 1) + dump1_files = glob.glob(os.path.join(dump_dir1, "*after_optimizations.txt")) + dump2_files = glob.glob(os.path.join(dump_dir2, "*after_optimizations.txt")) + self.assertEqual(len(dump1_files), 1) + self.assertEqual(len(dump2_files), 1) + with (open(dump1_files[0]) as file1, open(dump2_files[0]) as file2): + self.assertEqual(file1.read(), file2.read()) + dump2_pbs = glob.glob(os.path.join(dump_dir2, "*after_optimizations.hlo.pb")) + self.assertEqual(len(dump2_pbs), 1) + + @jtu.with_config( jax_enable_compilation_cache=False, jax_persistent_cache_min_compile_time_secs=0, diff --git a/tests/config/BUILD b/tests/config/BUILD new file mode 100644 index 000000000000..4b3105450c98 --- /dev/null +++ b/tests/config/BUILD @@ -0,0 +1,84 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@bazel_skylib//lib:selects.bzl", "selects") + +package( + default_applicable_licenses = [], + default_visibility = ["//visibility:private"], +) + +# ----- Flags, Config settings ----- + +config_setting( + name = "is_tsan", + values = {"copt": "-fsanitize=thread"}, +) + +# RBE +config_setting( + name = "is_exec_env_rbe", + define_values = {"exec_env": "rbe"}, +) + +config_setting( + name = "is_executor_remote", + define_values = {"EXECUTOR": "remote"}, +) + +selects.config_setting_group( + name = "is_rbe", + match_any = [ + ":is_exec_env_rbe", + ":is_executor_remote", + ], +) + +# A composite setting to use for certain tests +# that use a lot of RAM, which RBE cannot currently afford. +selects.config_setting_group( + name = "tsan_freethreading_rbe", + match_all = [ + ":is_tsan", + "@rules_python//python/config_settings:is_py_freethreaded", + ":is_rbe", + ], +) + +# ----- TSAN suppression files ----- + +filegroup( + name = "tsan_suppressions_txts", + srcs = [ + "tsan-suppressions_3.13.txt", + "tsan-suppressions_3.14.txt", + ], +) + +# ----- TSAN wrapper script ----- + +# Increases the stack size in case the underlying environment's, like Ubuntu 22.04, +# which has it set to 8192, is not enough for JAX tests under TSAN+RBE. +# Not entirely clear if RBE does something with the stack, as the local build ran fine with 8192. +# +# Conveniently, includes the TSAN suppressions' files. +sh_binary( + name = "oss_tsan_wrapper_sh", + srcs = [":oss_tsan_wrapper.sh"], + data = [":tsan_suppressions_txts"], + tags = [ + "manual", + "notap", + ], +) diff --git a/tests/config/oss_tsan_wrapper.sh b/tests/config/oss_tsan_wrapper.sh new file mode 100755 index 000000000000..945380df1048 --- /dev/null +++ b/tests/config/oss_tsan_wrapper.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +# Set stack size to 64MB (KB value) +ulimit -s 65536 + +# Run the actual command passed by Bazel +exec "$@" + diff --git a/.github/workflows/tsan-suppressions.txt b/tests/config/tsan-suppressions_3.13.txt similarity index 50% rename from .github/workflows/tsan-suppressions.txt rename to tests/config/tsan-suppressions_3.13.txt index 7b713b2da194..8b6ddb19bca6 100644 --- a/.github/workflows/tsan-suppressions.txt +++ b/tests/config/tsan-suppressions_3.13.txt @@ -2,34 +2,29 @@ # are racing on a call to __register_frame_info(), but that function appears to be correctly locked internally. race:llvm::RuntimeDyldELF::registerEHFrames -# https://github.com/python/cpython/issues/128050 -race:partial_vectorcall_fallback - # https://github.com/openxla/xla/issues/20686 race:dnnl_sgemm -# https://github.com/python/cpython/issues/128130 -race_top:run_eval_code_obj - -# Likely only happens when the process is crashing. -race:dump_traceback +# https://github.com/python/cpython/issues/128050 +race:partial_vectorcall_fallback # https://github.com/python/cpython/issues/128137 # Fixed in Python 3.14, but not backported to 3.13. race:immortalize_interned race:_PyUnicode_InternMortal +race:_PyUnicode_InternImmortal # https://github.com/python/cpython/issues/128144 # Fixed in Python 3.14, but not backported to 3.13. race_top:PyMember_GetOne -# https://github.com/python/cpython/issues/129547 -race:type_get_annotations - - -# https://github.com/python/cpython/issues/129748 -race:mi_block_set_nextx +# https://github.com/python/cpython/issues/131680 +# Fixed in Python 3.14, but not backported to 3.13. +race_top:new_reference +race:_Py_IsOwnedByCurrentThread +# https://github.com/python/cpython/issues/128130 +race_top:run_eval_code_obj # Races because the LAPACK and BLAS in our scipy isn't TSAN instrumented. race:heevd_ffi @@ -39,29 +34,10 @@ race:scal_k_ race:gemm_beta race:gemm_oncopy +# https://github.com/python/cpython/issues/132214 +# Fixed in Python 3.15, but not backported to 3.13, 3.14. +race:type_update_dict - -# Races below this point are likely fixed. -# TODO(phawkins): remove these if they don't show up in CI again. - -# https://github.com/python/cpython/issues/128100 -# race:ensure_nonmanaged_dict - -# https://github.com/python/cpython/issues/128657 -# race:py_digest_by_name - -# https://github.com/python/cpython/issues/128714 -# race:func_get_annotations - -# https://github.com/python/cpython/issues/129533 -# race:PyGC_Disable -# race:PyGC_Enable - -# https://github.com/python/cpython/issues/128133 -# race:bytes_hash - -# https://github.com/python/cpython/issues/130571 -# race:_PyObject_GetMethod - -# https://github.com/python/cpython/issues/130547 -# race:split_keys_entry_added +# https://github.com/python/cpython/issues/126907 +# Fixed in Python 3.14, but not backported to 3.13 +race:atexit_register diff --git a/tests/config/tsan-suppressions_3.14.txt b/tests/config/tsan-suppressions_3.14.txt new file mode 100644 index 000000000000..d987879cab58 --- /dev/null +++ b/tests/config/tsan-suppressions_3.14.txt @@ -0,0 +1,21 @@ +# false-positive caused because we haven't tsan-instrumented libgcc_s. Multiple threads +# are racing on a call to __register_frame_info(), but that function appears to be correctly locked internally. +race:llvm::RuntimeDyldELF::registerEHFrames + +# https://github.com/openxla/xla/issues/20686 +race:dnnl_sgemm + +# https://github.com/python/cpython/issues/128050 +race:partial_vectorcall_fallback + +# Races because the LAPACK and BLAS in our scipy isn't TSAN instrumented. +race:heevd_ffi +race:gesdd_ffi +race:dscal_k_ +race:scal_k_ +race:gemm_beta +race:gemm_oncopy + +# https://github.com/python/cpython/issues/132214 +# Fixed in Python 3.15, but not backported to 3.13, 3.14. +race:type_update_dict diff --git a/tests/config_test.py b/tests/config_test.py index 5a6da36705a6..a8578ff2e2f5 100644 --- a/tests/config_test.py +++ b/tests/config_test.py @@ -13,6 +13,7 @@ # limitations under the License. from absl.testing import absltest +from absl.testing import parameterized import jax from jax._src import config from jax._src import test_util as jtu @@ -21,6 +22,12 @@ jax.config.parse_flags_with_absl() +jax_test_bool_config = config.bool_state( + name='jax_test_bool_config', + default=True, + help='Configuration only used for tests.', +) + jax_test_enum_config = config.enum_state( name='jax_test_enum_config', enum_values=['default', 'xxx', 'yyy'], @@ -29,31 +36,50 @@ ) +class InvalidBool: + def __bool__(self): + raise ValueError("invalid bool") + + class ConfigTest(jtu.JaxTestCase): - def test_config_setting_via_update(self): - self.assertEqual(jax_test_enum_config.value, 'default') + @parameterized.named_parameters( + {"testcase_name": "_enum", "config_name": "jax_test_enum_config", + "config_obj": jax_test_enum_config, "default": "default", "val1": "xxx", + "val2": "yyy"}, + {"testcase_name": "_bool", "config_name": "jax_test_bool_config", + "config_obj": jax_test_bool_config, "default": True, "val1": False, + "val2": True}, + ) + def test_config_setting_via_update(self, config_name, config_obj, default, val1, val2): + self.assertEqual(config_obj.value, default) - jax.config.update('jax_test_enum_config', 'xxx') - self.assertEqual(jax_test_enum_config.value, 'xxx') + jax.config.update(config_name, val1) + self.assertEqual(config_obj.value, val1) - jax.config.update('jax_test_enum_config', 'yyy') - self.assertEqual(jax_test_enum_config.value, 'yyy') + jax.config.update(config_name, val2) + self.assertEqual(config_obj.value, val2) - jax.config.update('jax_test_enum_config', 'default') - self.assertEqual(jax_test_enum_config.value, 'default') + jax.config.update(config_name, default) + self.assertEqual(config_obj.value, default) - def test_config_setting_via_context(self): - self.assertEqual(jax_test_enum_config.value, 'default') + @parameterized.named_parameters( + {"testcase_name": "_enum", "config_obj": jax_test_enum_config, + "default": "default", "val1": "xxx", "val2": "yyy"}, + {"testcase_name": "_bool", "config_obj": jax_test_bool_config, + "default": True, "val1": False, "val2": True}, + ) + def test_config_setting_via_context(self, config_obj, default, val1, val2): + self.assertEqual(config_obj.value, default) - with jax_test_enum_config('xxx'): - self.assertEqual(jax_test_enum_config.value, 'xxx') + with config_obj(val1): + self.assertEqual(config_obj.value, val1) - with jax_test_enum_config('yyy'): - self.assertEqual(jax_test_enum_config.value, 'yyy') + with config_obj(val2): + self.assertEqual(config_obj.value, val2) - self.assertEqual(jax_test_enum_config.value, 'xxx') + self.assertEqual(config_obj.value, val1) - self.assertEqual(jax_test_enum_config.value, 'default') + self.assertEqual(config_obj.value, default) def test_config_update_validation(self): self.assertEqual(jax_test_enum_config.value, 'default') @@ -69,6 +95,20 @@ def test_config_context_validation(self): pass self.assertEqual(jax_test_enum_config.value, 'default') + def test_bool_config_update_validation(self): + self.assertEqual(jax_test_bool_config.value, True) + with self.assertRaisesRegex(ValueError, "invalid bool"): + jax.config.update('jax_test_bool_config', InvalidBool()) + # Error should raise before changing the value + self.assertEqual(jax_test_bool_config.value, True) + + def test_bool_config_context_validation(self): + self.assertEqual(jax_test_bool_config.value, True) + with self.assertRaisesRegex(ValueError, "invalid bool"): + with jax_test_bool_config(InvalidBool()): + pass + self.assertEqual(jax_test_bool_config.value, True) + def test_cloud_tpu_init(self): if not jtu.is_cloud_tpu(): self.skipTest('Not running on a Cloud TPU VM.') diff --git a/tests/core_test.py b/tests/core_test.py index c46d493bda54..86430593c50a 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest from collections import namedtuple from functools import partial import gc @@ -32,7 +31,7 @@ from jax._src import linear_util as lu from jax._src import util from jax._src import test_util as jtu -from jax._src.core import ShapedArray, DBIdx +from jax._src.core import ShapedArray from jax._src.interpreters import partial_eval as pe from jax._src.lax import control_flow as lax_control_flow @@ -203,6 +202,13 @@ def test_is_valid_jaxtype(self, dtype): else: self.assertFalse(core.valid_jaxtype(arr)) + def test_str_aval(self): + aval = ShapedArray((8, 2), np.int32) + self.assertEqual(str(aval), "int32[8,2]") + + aval = ShapedArray((8, 2), np.int32, weak_type=True) + self.assertEqual(str(aval), "~int32[8,2]") + @parameterized.named_parameters( (str(i), *spec) for i, spec in enumerate(test_specs)) def test_jit(self, f, args): @@ -354,18 +360,27 @@ def g_vmap(x): 'This BatchTracer with object id'): g_vmap(jnp.ones((1, ))) + def test_aval_str_short_mem_space(self): + aval = core.ShapedArray((8,), jnp.float32, + memory_space=jax.memory.Space.Host) + self.assertEqual(aval.str_short(True), "f32[8]") + + aval = core.ShapedArray((8,), jnp.float32, + memory_space=jax.memory.Space.Device) + self.assertEqual(aval.str_short(True), "f32[8]") + def test_dropvar_avals(self): def f(x): def body(c, _): - return c, None + x1, x2 = c + return (2 * x1, 2 * x2), None (x1, x2), _ = jax.lax.scan(body, (x, x), None, length=1) return [x2] aval = core.ShapedArray((), jnp.dtype('int32')) pval = pe.PartialVal.unknown(aval) jaxpr, _, _ = pe.trace_to_jaxpr_nounits( - lu.wrap_init(f, - debug_info=debug_info("test", f, (0,), {})), + lu.wrap_init(f, debug_info=debug_info("test", f, (0,), {})), [pval], False) dropvar, b = jaxpr.eqns[0].outvars self.assertEqual(dropvar.aval, aval) @@ -387,15 +402,58 @@ def f(y): self.assertLen(e1.outvars, 1) # only primal out, no residuals self.assertEqual(e1.outvars[0].aval.shape, (3, 3)) # only primal out shape + def test_tracer_reprs(self): + def f(x): + nonlocal x_repr + x_repr = repr(x) + return x.sum() + x_repr = "" + + jax.jit(f)(jnp.arange(10.0, dtype='float32')) + self.assertEqual(x_repr, "JitTracer(float32[10])") + + jax.vmap(f)(jnp.arange(20, dtype='int32')) + self.assertEqual(x_repr, "VmapTracer(aval=int32[], batched=int32[20])") + + jax.grad(f)(jnp.float16(1.0)) + self.assertEqual(x_repr, "GradTracer(primal=1.0, typeof(tangent)=f16[])") + + jax.jacrev(f)(jnp.arange(4, dtype='float32')) + self.assertEqual(x_repr, "GradTracer(primal=[0. 1. 2. 3.], typeof(tangent)=f32[4])") + + jax.jacfwd(f)(jnp.arange(3, dtype='float32')) + self.assertEqual(x_repr, "JVPTracer(primal=[0. 1. 2.], tangent=VmapTracer(aval=float32[3], batched=float32[3,3]))") + + def test_verbose_tracer_reprs(self): + # Verbose reprs, avaiable via tracer._pretty_print() + def f(x): + nonlocal x_repr + x_repr = x._pretty_print(verbose=True).format() + return x.sum() + x_repr = "" + + jax.jit(f)(jnp.arange(10.0, dtype='float32')) + self.assertRegex(x_repr, r"^Tracedwith") + + jax.vmap(f)(jnp.arange(20, dtype='int32')) + self.assertRegex(x_repr, r"^Tracedwith") + + jax.grad(f)(jnp.float16(1.0)) + self.assertRegex(x_repr, r"^Tracedwith<(JVP)|(Linearize)Trace>") + @jtu.with_config(jax_pprint_use_color=False) class JaxprTypeChecks(jtu.JaxTestCase): def setUp(self): super().setUp() - lax_control_flow._initial_style_open_jaxpr.cache_clear() - lax_control_flow._initial_style_jaxpr.cache_clear() - lax_control_flow.common._pad_jaxpr_constvars.cache_clear() + lax_control_flow.common._dedup_consts.cache_clear() + lax_control_flow.common._pad_constvars.cache_clear() + + def tearDown(self): + super().tearDown() + lax_control_flow.common._dedup_consts.cache_clear() + lax_control_flow.common._pad_constvars.cache_clear() def test_check_jaxpr_correct(self): jaxpr = make_jaxpr(lambda x: jnp.sin(x) + jnp.cos(x))(1.).jaxpr @@ -405,6 +463,7 @@ def test_check_jaxpr_cond_correct(self): jaxpr = make_jaxpr(lambda x: lax.switch(0, [jnp.sin, jnp.cos], x))(1.).jaxpr core.check_jaxpr(jaxpr) + @jtu.thread_unsafe_test() # in-place mutation of possibly-cached jaxpr def test_check_jaxpr_jit_invalid(self): jaxpr = make_jaxpr(jax.jit(lambda x, y: x + 1))(1., 2.).jaxpr pjit_eqn, = jaxpr.eqns @@ -414,6 +473,7 @@ def test_check_jaxpr_jit_invalid(self): '0 operands cannot call jaxpr with 2 inputs', lambda: core.check_jaxpr(jaxpr)) + @jtu.thread_unsafe_test() # in-place mutation of possibly-cached jaxpr def test_check_jaxpr_cond_invalid(self): jaxpr = make_jaxpr(lambda x: lax.switch(0, [jnp.sin, jnp.cos], x))(1.).jaxpr cond = next(eqn for eqn in jaxpr.eqns if eqn.primitive.name == 'cond') @@ -433,6 +493,7 @@ def f(c, x): jaxpr = make_jaxpr(partial(lax.scan, f))(c, xs).jaxpr core.check_jaxpr(jaxpr) + @jtu.thread_unsafe_test() # in-place mutation of possibly-cached jaxpr def test_check_jaxpr_invalid_long(self): # jaxprs can be large, and this tests that when large ones are printed for # context in jaxpr typechecking errors, they're not printed entirely @@ -464,6 +525,7 @@ def g(x): self.assertIn('while checking jaxpr:', msg) self.assertLess(msg.count('\n'), 200) + @jtu.thread_unsafe_test() # in-place mutation of possibly-cached jaxpr def test_check_jaxpr_eqn_mismatch(self): def f(x): return jnp.sin(x) + jnp.cos(x) @@ -487,7 +549,7 @@ def new_jaxpr(): self.assertRaisesRegex( core.JaxprTypeError, r"Value for variable 'b' inconsistently typed as f32\[\] " - r"for let-binder of type i32\[\]\n\nin equation:\n\nb:i32\[\] = sin a", + r"for let-binder of type i32\[\]\n\nin equation:\n\nb:i32\[\] = sin\ a", lambda: core.check_jaxpr(jaxpr)) jaxpr = new_jaxpr() @@ -496,7 +558,7 @@ def new_jaxpr(): self.assertRaisesRegex( core.JaxprTypeError, r"Value for variable 'b' inconsistently typed as f32\[\] " - r"for let-binder of type f32\[2,3\]\n\nin equation:\n\nb:f32\[2,3\] = sin a", + r"for let-binder of type f32\[2,3\]\n\nin equation:\n\nb:f32\[2,3\] = sin\ a", lambda: core.check_jaxpr(jaxpr)) def test_jaxpr_dropvar_from_jit_call(self): @@ -534,202 +596,6 @@ def f(x): assert isinstance(jaxpr.eqns[-1].outvars[0], core.DropVar) core.check_jaxpr(jaxpr) - def test_jaxpr_undefined_eqn_invar(self): - jaxpr = make_jaxpr(lambda x: jnp.sin(x) + jnp.cos(x))(1.).jaxpr - cos = next(eqn for eqn in jaxpr.eqns if eqn.primitive.name == 'cos') - cos.invars[0] = core.gensym(suffix='_test')(cos.invars[0].aval) - self.assertRaisesRegex( - core.JaxprTypeError, - r"Variable '.+_test' not defined\n\nin equation:", - lambda: core.check_jaxpr(jaxpr)) - - -@jtu.with_config(jax_dynamic_shapes=True) -class DynamicShapesTest(jtu.JaxTestCase): - - def test_staging_basic(self): - n = core.ShapedArray((), jnp.dtype('int32'), weak_type=False) - a = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False) - b = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False) - - def f(x, y): - return x, y - - jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(f, - debug_info=debug_info("test", f, (1, 2), {})), - [n, a, b], keep_inputs=[False, True, True]) - - self.assertLen(jaxpr.invars, 3) - self.assertEqual((jaxpr.invars[0],), jaxpr.invars[1].aval.shape) - self.assertEqual((jaxpr.invars[0],), jaxpr.invars[2].aval.shape) - - self.assertLen(jaxpr.outvars, 2) - self.assertEqual((jaxpr.invars[0],), jaxpr.outvars[0].aval.shape) - self.assertEqual((jaxpr.invars[0],), jaxpr.outvars[1].aval.shape) - - @unittest.skip('This test does not work with nested pjit and DShapedArray') - def test_staging_nested(self): - n = core.ShapedArray((), jnp.dtype('int32'), weak_type=False) - a = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False) - b = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False) - - def f(x, y): - @jax.jit - def g(x, y, z, w): - return (x, w) - return g(x, y, x, y) - - jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(f, - debug_info=debug_info("test", f, (0, 1), {})), - [n, a, b], keep_inputs=[False, True, True]) - - self.assertLen(jaxpr.invars, 1 + 2) # one axis size var, two other inputs - self.assertEqual((jaxpr.invars[0],), jaxpr.invars[1].aval.shape) - self.assertEqual((jaxpr.invars[0],), jaxpr.invars[2].aval.shape) - - self.assertLen(jaxpr.outvars, 2) - self.assertEqual((jaxpr.invars[0],), jaxpr.outvars[0].aval.shape) - self.assertEqual((jaxpr.invars[0],), jaxpr.outvars[1].aval.shape) - - self.assertLen(jaxpr.eqns, 1) - eqn = jaxpr.eqns[0] - self.assertIsInstance(eqn.primitive, core.CallPrimitive) - inner_jaxpr = eqn.params['call_jaxpr'] - self.assertIsInstance(inner_jaxpr, core.Jaxpr) - - self.assertLen(inner_jaxpr.invars, 1 + 4) # one axis size var - self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[1].aval.shape) - self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[2].aval.shape) - self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[3].aval.shape) - self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[4].aval.shape) - - @unittest.skip('This test does not work with nested pjit and DShapedArray') - def test_staging_nested_including_shape_arg(self): - n = core.ShapedArray((), jnp.dtype('int32'), weak_type=False) - a = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False) - b = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False) - - def f(x, y): - @jax.jit - def g(_, x, y, z, w): - return (x, w) - return g(x.shape[0], x, y, x, y) - - jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(f, - debug_info=debug_info("test", f, (1, 2), {})), - [n, a, b], keep_inputs=[False, True, True]) - - # { lambda ; a:i32[] b:f32[a] c:f32[a]. let - # d:f32[a] e:f32[a] = xla_call[ - # call_jaxpr={ lambda ; f:i32[] g:i32[] h:f32[f] i:f32[f] j:f32[f] k:f32[f]. let - # - # in (h, k) } - # name=g - # ] a a b c b c - # in (d, e) } - - self.assertLen(jaxpr.eqns, 1) - eqn = jaxpr.eqns[0] - self.assertIsInstance(eqn.primitive, core.CallPrimitive) - inner_jaxpr = eqn.params['call_jaxpr'] - self.assertIsInstance(inner_jaxpr, core.Jaxpr) - - self.assertLen(inner_jaxpr.invars, 1 + 4) # one axis size var - self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[1].aval.shape) - self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[2].aval.shape) - self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[3].aval.shape) - self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[4].aval.shape) - - def test_staging_primitive_applications(self): - n = core.ShapedArray((), jnp.dtype('int32'), weak_type=False) - a = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False) - b = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False) - - def f(x, y): - z = lax.mul(x, y) - w = lax.sin(z) - u = lax.reduce_sum(w, [0]) - return (u,) - - jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(f, - debug_info=debug_info("test", f, (1, 2), {})), - [n, a, b], keep_inputs=[False, True, True]) - - self.assertLen(jaxpr.invars, 1 + 2) # one axis size var, two other inputs - self.assertLen(jaxpr.eqns, 3) - self.assertLen(jaxpr.eqns[0].outvars, 1) - self.assertEqual(jaxpr.eqns[0].outvars[0].aval.shape, - jaxpr.invars[1].aval.shape) - - self.assertLen(jaxpr.outvars, 1) - self.assertEqual(jaxpr.outvars[0].aval.shape, ()) - - @unittest.skip('This test does not work with nested pjit and DShapedArray') - def test_typecheck_staging_nested(self): - n = core.ShapedArray((), jnp.dtype('int32'), weak_type=False) - m = core.ShapedArray((), jnp.dtype('int32'), weak_type=False) - a = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False) - b = core.DShapedArray((DBIdx(1),), jnp.dtype('float32'), weak_type=False) - - def f(a, b): - @jax.jit - def g(x): return x - return g(a), - - jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(f, - debug_info=debug_info("test", f, (1, 2), {})), - [n, m, a, b], keep_inputs=[False, False, True, True]) - # { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let - # e:f32[a] = xla_call[ - # call_jaxpr={ lambda ; f:i32[] g:f32[f]. let in (g,) } - # name=g - # ] a c - # in (e,) } - core.check_jaxpr(jaxpr) # no problems here... - - # Let's introduce a type error by applying the called jaxpr to arguments - # with types which aren't consistent with its input binders: - _, _, c, d = jaxpr.invars - jaxpr.eqns[0].invars[1] = d - # { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let - # e:f32[a] = xla_call[ - # call_jaxpr={ lambda ; f:i32[] g:f32[f]. let in (g,) } - # name=g - # ] a d !!! type error here !!! - # in (e,) } - with self.assertRaisesRegex(TypeError, "passes operand"): - core.check_jaxpr(jaxpr) - - # Restore the original jaxpr: - jaxpr.eqns[0].invars[1] = c - core.check_jaxpr(jaxpr) # no problems here... - - # Let's introduce another type error by setting the call result let binders - # to have the wrong type: - jaxpr.eqns[0].outvars[0] = core.Var('', d.aval) - # { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let - # e:f32[b] = xla_call[ !!! type error here !!! - # call_jaxpr={ lambda ; f:i32[] g:f32[f]. let in (g,) } - # name=g - # ] a c - # in (h,) } - with self.assertRaisesRegex(TypeError, "inconsistently typed as"): - core.check_jaxpr(jaxpr) - - def test_check_jaxpr_key_reuse(self): - with config.debug_key_reuse(True): - def f(seed): - key = jax.random.key(seed) - return jax.random.uniform(key) + jax.random.normal(key) - with jax.enable_checks(True): - with self.assertRaises(jax.errors.KeyReuseError): - jax.jit(f)(0) - if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/cudnn_fusion_test.py b/tests/cudnn_fusion_test.py index 7dc0571bc172..5e04d69889fc 100644 --- a/tests/cudnn_fusion_test.py +++ b/tests/cudnn_fusion_test.py @@ -12,36 +12,39 @@ # See the License for the specific language governing permissions and # limitations under the License. -from absl.testing import absltest, parameterized from unittest import SkipTest -from jax._src import test_util as jtu -from jax._src.lib import cuda_versions +from absl.testing import absltest, parameterized import jax -import jax.numpy as jnp +from jax._src import test_util as jtu from jax._src.cudnn import cudnn_fusion +import jax.numpy as jnp jax.config.parse_flags_with_absl() class CudnnFusionTest(jtu.JaxTestCase): + def setUp(self): - if (not jtu.test_device_matches(["cuda"]) or - not jtu.is_cuda_compute_capability_at_least("8.0") or - cuda_versions.cudnn_get_version() < 90110): - self.skipTest("Only works on >= sm80 GPUs with cuDNN 9.1.1+") + if not jtu.test_device_matches( + ["cuda"] + ) or not jtu.is_cuda_compute_capability_at_least("8.0"): + self.skipTest("Only works on >= sm80 GPUs") super().setUp() @parameterized.parameters(["", "pmap"]) @jtu.run_on_devices("cuda") def test_cudnn_fusion(self, mode): + if jtu.is_cuda_version_at_least(13, 0): + self.skipTest("cuDNN creates no execution plans on CUDA 13.0.") + batch_size = 2 if mode == "pmap" and jax.device_count() < batch_size: - raise SkipTest("pmap test requires 2 GPUs") + raise SkipTest("pmap test requires 2 GPUs") @cudnn_fusion def comp1(x, y, z): - return jnp.float32(jax.lax.batch_matmul(jnp.bfloat16(x), y)) + z + return jnp.float32(jax.lax.batch_matmul(jnp.bfloat16(x), y)) + z k = jax.random.key(0) s = batch_size, 16, 16 @@ -60,7 +63,14 @@ def comp1(x, y, z): self.assertIn('custom_call_target="__cudnn$fusion"', hlo) self.assertIn("called_computations=", hlo) - compiled = lowered.compile({"xla_gpu_cublas_fallback": False}) + compiled = lowered.compile({ + # Disable Cublas to make sure CuDNN is used. + "xla_gpu_cublas_fallback": False, + # Enable CuDNN fusions. + "xla_gpu_cudnn_gemm_fusion_level": 2, + # Disable autotuning to pick first config to ensure CuDNN is always used. + "xla_gpu_autotune_level": 0, + }) hlo_after_opt = compiled.as_text() self.assertIn("kind=kCustom", hlo_after_opt) @@ -69,5 +79,5 @@ def comp1(x, y, z): self.assertAllClose(compiled(x, y, z), fn(x, y, z)) -if __name__ == '__main__': +if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/custom_api_test.py b/tests/custom_api_test.py new file mode 100644 index 000000000000..bde7c1762cc5 --- /dev/null +++ b/tests/custom_api_test.py @@ -0,0 +1,4879 @@ +# Copyright 2018 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +from collections.abc import Callable +import concurrent.futures +import functools +from functools import partial +import itertools as it +import re +import unittest +import textwrap + +from absl.testing import absltest, parameterized +import numpy as np + +import jax +import jax.numpy as jnp +from jax import float0, grad, jit +from jax import lax +from jax import tree_util +import jax.custom_batching +import jax.custom_derivatives +import jax.custom_transpose +import jax.experimental.custom_dce +from jax.errors import UnexpectedTracerError + +from jax._src import api +from jax._src import api_util +from jax._src import config +from jax._src import core +from jax._src import custom_derivatives +from jax._src import hijax +from jax._src import test_util as jtu +from jax._src.interpreters import partial_eval as pe + +config.parse_flags_with_absl() + + +class CustomJVPTest(jtu.JaxTestCase): + + def test_basic(self): + @jax.custom_jvp + def f(x): + return jnp.sin(x) + def f_jvp(primals, tangents): + x, = primals + g, = tangents + return f(x), 2 * jnp.cos(x) * g + f.defjvp(f_jvp) + + x = 3. + self.assertAllClose(f(x), jnp.sin(x)) + self.assertAllClose(api.jvp(f, (x,), (1.,)), + (jnp.sin(x), 2 * jnp.cos(x))) + self.assertAllClose(api.grad(f)(x), 2 * jnp.cos(x)) + + def test_invariance(self): + @jax.custom_jvp + def f(x): + return jnp.cos(2 * x) / 2. + def f_jvp(primals, tangents): + x, = primals + g, = tangents + return (f(x), 3 * g) + f.defjvp(f_jvp) + def f2(x): + y, _ = api.jvp(f, (x,), (x,)) + return y + def f3(x): + y, _ = api.jvp(f2, (x,), (x,)) + return y + x = 1. + self.assertAllClose(api.jvp(f, (x,), (x,)), + api.jvp(f2, (x,), (x,)), + check_dtypes=False) + self.assertAllClose(api.jvp(f, (x,), (x,)), + api.jvp(f3, (x,), (x,)), + check_dtypes=False) + + def test_python_control_flow(self): + @jax.custom_jvp + def f(x): + if x > 0: + return jnp.sin(x) + else: + return jnp.cos(x) + def f_jvp(primals, tangents): + x, = primals + g, = tangents + if x > 0: + return f(x), 2 * g + else: + return f(x), 3 * g + f.defjvp(f_jvp) + x = 2. + self.assertAllClose(f(x), jnp.sin(x)) + self.assertAllClose(f(-x), jnp.cos(-x)) + self.assertAllClose(api.jvp(f, (x,), (1.,)), + (jnp.sin(x), 2.), + check_dtypes=False) + self.assertAllClose(api.jvp(f, (-x,), (1.,)), + (jnp.cos(-x), 3.), + check_dtypes=False) + self.assertAllClose(api.grad(f)(x), 2., check_dtypes=False) + self.assertAllClose(api.grad(f)(-x), 3., check_dtypes=False) + + def test_vmap(self): + @jax.custom_jvp + def f(x): + assert jnp.ndim(x) == 0 + return jnp.sin(x) + def f_jvp(primals, tangents): + x, = primals + g, = tangents + assert jnp.ndim(x) == jnp.ndim(g) == 0 + return f(x), 2 * jnp.cos(x) * g + f.defjvp(f_jvp) + + x = jnp.arange(3.) + xx = jnp.arange(6.).reshape(2, 3) + + # vmap of f + self.assertAllClose(api.vmap(f)(x), jnp.sin(x)) + self.assertAllClose(api.vmap(api.vmap(f))(xx), jnp.sin(xx)) + + # vmap of jvp of f + self.assertAllClose(api.vmap(lambda x: api.jvp(f, (x,), (x,)))(x), + (jnp.sin(x), 2 * jnp.cos(x) * x)) + self.assertAllClose(api.vmap(api.vmap(lambda x: api.jvp(f, (x,), (x,))))(xx), + (jnp.sin(xx), 2 * jnp.cos(xx) * xx)) + + # jvp of vmap of f + self.assertAllClose(api.jvp(api.vmap(f), (x,), (x,)), + (jnp.sin(x), 2 * jnp.cos(x) * x)) + self.assertAllClose(api.jvp(api.vmap(api.vmap(f)), (xx,), (xx,)), + (jnp.sin(xx), 2 * jnp.cos(xx) * xx)) + + # vmap of jvp of vmap of f + self.assertAllClose(api.vmap(lambda x: api.jvp(api.vmap(f), (x,), (x,)))(xx), + (jnp.sin(xx), 2 * jnp.cos(xx) * xx)) + + def test_jit(self): + @jax.custom_jvp + def f(x): + return jnp.sin(x) + def f_jvp(primals, tangents): + x, = primals + g, = tangents + return f(x), 2 * jnp.cos(x) * g + f.defjvp(f_jvp) + + x = 3. + + # jit + self.assertAllClose(api.jit(f)(x), jnp.sin(x)) + self.assertAllClose(api.jit(api.jit(f))(x), jnp.sin(x)) + + # jit of jvp + self.assertAllClose(api.jit(lambda x: api.jvp(f, (x,), (x,)))(x), + (jnp.sin(x), 2 * jnp.cos(x) * x), + check_dtypes=False) + + # jvp of jit + self.assertAllClose(api.jvp(api.jit(f), (x,), (x,)), + (jnp.sin(x), 2 * jnp.cos(x) * x), + check_dtypes=False) + + def test_pytrees(self): + @jax.custom_jvp + def f(x): + return {'b': jnp.sin(x['a'])} + def f_jvp(primals, tangents): + x, = primals + g, = tangents + return f(x), {'b': 2 * jnp.cos(x['a']) * g['a']} + f.defjvp(f_jvp) + x = {'a': 3.} + self.assertAllClose(f(x)['b'], jnp.sin(x['a'])) + self.assertAllClose(api.jvp(f, (x,), (x,)), + ({'b': jnp.sin(x['a'])}, + {'b': 2 * jnp.cos(x['a']) * x['a']}), + check_dtypes=False) + + def test_kwargs(self): + # from https://github.com/jax-ml/jax/issues/1938 + @jax.custom_jvp + def my_fun(x, y, c=1.): + return c * (x + y) + def my_jvp(primals, tangents): + x, y, c = primals + t_x, t_y, t_c = tangents + return my_fun(x, y, c), t_c + my_fun.defjvp(my_jvp) + f = lambda x, y: jnp.square(my_fun(x, y, c=2.)).sum() + f(10., 5.) # doesn't crash + api.jvp(f, (10., 5.), (1., 1.)) # doesn't crash + + def test_initial_style(self): + @jax.custom_jvp + def f(x): + return 3 * x + def f_jvp(primals, tangents): + x, = primals + g, = tangents + return f(x), 2 * g + f.defjvp(f_jvp) + + def foo(x): + out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1) + return out + + ans = api.grad(foo)(3.) + expected = 2. + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(api.jit(foo))(3.) + expected = 2. + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.jit(api.grad(foo))(3.) + expected = 2. + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(api.grad(foo))(3.) + expected = 0. + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(api.grad(api.jit(foo)))(3.) + expected = 0. + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(api.jit(api.grad(foo)))(3.) + expected = 0. + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.jit(api.grad(api.grad(foo)))(3.) + expected = 0. + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_initial_style_vmap(self): + @jax.custom_jvp + def f(x): + assert jnp.ndim(x) == 0 + return 3 * x + def f_jvp(primals, tangents): + x, = primals + g, = tangents + return f(x), 2 * g + f.defjvp(f_jvp) + + def foo(x): + out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1) + return out + + ans = api.vmap(foo)(jnp.ones(3)) + expected = 3. * jnp.ones(3) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.vmap(api.jit(foo))(jnp.ones(3)) + expected = 3. * jnp.ones(3) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.jit(api.vmap(foo))(jnp.ones(3)) + expected = 3. * jnp.ones(3) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(lambda x: api.vmap(foo)(x).sum())(jnp.ones(3)) + expected = 2. * jnp.ones(3) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(lambda x: api.vmap(api.jit(foo))(x).sum())(jnp.ones(3)) + expected = 2. * jnp.ones(3) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(lambda x: api.jit(api.vmap(foo))(x).sum())(jnp.ones(3)) + expected = 2. * jnp.ones(3) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(api.jit(lambda x: api.vmap(foo)(x).sum()))(jnp.ones(3)) + expected = 2. * jnp.ones(3) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.jit(api.grad(lambda x: api.vmap(foo)(x).sum()))(jnp.ones(3)) + expected = 2. * jnp.ones(3) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_initial_style_vmap_with_collective(self): + + @jax.custom_jvp + def f(x): + return lax.psum(x, 'foo') + + @f.defjvp + def f_jvp(xs, ts): + x, = xs + t, = ts + return lax.psum(x, 'foo'), t + + def g(x): + jaxpr = api.make_jaxpr(f)(x) + return core.eval_jaxpr(jaxpr.jaxpr, [], x)[0] + + v = api.vmap(lambda _, x: g(x), axis_name='foo', in_axes=(0, None), + out_axes=None)(jnp.arange(4.), 2.) + self.assertAllClose(v, 8.) + + def test_closed_over_tracers_error_message(self): + def f(x): + @jax.custom_jvp + def g(y): + return x + y + def g_jvp(primals, tangents): + return g(x), 2 * primals[0] + g.defjvp(g_jvp) + return g(1.) + + self.assertRaises(UnexpectedTracerError, lambda: api.jvp(f, (3.,), (1.,))) + self.assertRaises(UnexpectedTracerError, lambda: api.grad(f)(3.)) + + def test_nondiff_argnums(self): + @partial(jax.custom_jvp, nondiff_argnums=(0,)) + def app(f, x): + return f(x) + def app_jvp(f, primals, tangents): + (x,), (t,) = primals, tangents + return app(f, x), 3 * t + app.defjvp(app_jvp) + + ans = app(lambda x: 2 * x, 1) + expected = 2 + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.jvp(lambda x: app(lambda y: 2 * y, x), (1.,), (1.,)) + expected = (2., 3.) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_nondiff_argnames(self): + @partial(jax.custom_jvp, nondiff_argnames=('f',)) + def app(f, x): + return f(x) + + def app_jvp(f, primals, tangents): + (x,), (t,) = primals, tangents + return app(f, x), 3 * t + + app.defjvp(app_jvp) + + ans = app(lambda x: 2 * x, 1) + expected = 2 + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_nondiff_arg_jit_tracer(self): + # This test would pass with "final-style" JIT tracing, but that was + # misleading: it doesn't work with "initial-style" staging, i.e. control + # flow primitives like jax.lax.scan or even pjit. The behavior isn't very + # useful either: instead of using nondiff_argnums here, a user can just pass + # such inputs as ordinary arguments, and ignore the corresponding tangents. + # Then nondiff_argnums can be reserved for (1) non jaxtype data (like a + # string- or callable-valued argument which parameterizes the function or + # rule) or (2) static data (e.g. integers which parameterize shapes). + raise unittest.SkipTest("behavior no longer supported") + + @partial(jax.custom_jvp, nondiff_argnums=(0,)) + def f(x, y): + return x * y + def f_jvp(x, primals, tangents): + (y,), (t_y,) = primals, tangents + return f(x, y), 5 * t_y + f.defjvp(f_jvp) + + @jit + def g(x, y): + return f(x, y) + + ans = api.jvp(lambda y: g(2., y), (3.,), (1.,)) + expected = (6., 5.) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_nondiff_arg_vmap_tracer(self): + @partial(jax.custom_jvp, nondiff_argnums=(0,)) + def f(x, y): + return x * y + def f_jvp(x, primals, tangents): + (y,), (t_y,) = primals, tangents + return f(x, y), 5 * t_y + f.defjvp(f_jvp) + + g = jax.vmap(f) + + ans = api.jvp(lambda y: g(jnp.array([2.]), y), + (jnp.array([3.]),), (jnp.array([1.]),)) + expected = (jnp.array([6.]), jnp.array([5.])) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_nondiff_arg_hiding_jvp_tracer(self): + def f(x): + @partial(jax.custom_jvp, nondiff_argnums=(0,)) + def g(h, x): + return h(x) + @g.defjvp + def g_jvp(h, primals, tangents): + x, = primals + t, = tangents + return g(h, x), 2. * t + h = lambda y: x + y # capture x + return g(h, x) + + with self.assertRaises(UnexpectedTracerError): + api.jvp(f, (2.,), (1.,)) + + def test_vmap_axes(self): + raise unittest.SkipTest("TODO") # TODO(mattjj): write test + + def test_pmap(self): + raise unittest.SkipTest("TODO") # TODO(mattjj): write test + + def test_missing_jvp_rule_error_message(self): + @jax.custom_jvp + def foo(x): + return x ** 2 + + self.assertRaisesRegex( + AttributeError, + r"No JVP defined for custom_jvp function foo using defjvp.", + lambda: foo(2)) + self.assertRaisesRegex( + AttributeError, + r"No JVP defined for custom_jvp function foo using defjvp.", + lambda: api.jvp(foo, (2.,), (1.,))) + self.assertRaisesRegex( + AttributeError, + r"No JVP defined for custom_jvp function foo using defjvp.", + lambda: api.grad(foo)(2.)) + + def test_jvp_rule_inconsistent_pytree_structures_error_message(self): + @jax.custom_jvp + def f(x): + return (x**2,) + + @f.defjvp + def foo_jvp(primals, tangents): + x, = primals + t, = tangents + return f(x), [2 * x * t, x] + + f(2.) # doesn't crash + self.assertRaisesRegex( + TypeError, + re.escape( + "Custom JVP rule foo_jvp for function f " + "must produce primal and tangent outputs " + "with equal container (pytree) structures, but got " + "{} and {} respectively.".format( + jax.tree.structure((1,)), + jax.tree.structure([1, 2])) + ), + lambda: api.jvp(f, (2.,), (1.,))) + + def test_primal_tangent_aval_disagreement_error_message(self): + @jax.custom_jvp + def f(x): + return x ** 2 + + @f.defjvp + def foo_jvp(primals, tangents): + x, = primals + t, = tangents + return f(x), jnp.reshape(t, (1,)) + + f(2.) # doesn't crash + self.assertRaisesRegex( + TypeError, + re.escape( + "Custom JVP rule must produce primal and tangent outputs " + "with corresponding shapes and dtypes. " + "Expected float32[] (tangent type of float32[]) but got float32[1]."), + lambda: api.jvp(f, (jnp.float32(2.),), (jnp.float32(1.),))) + + + def test_jvp_rule_doesnt_return_pair_error_message(self): + # https://github.com/jax-ml/jax/issues/2516 + + @jax.custom_jvp + def f(x): + return x ** 2 + + @f.defjvp + def foo_jvp(primals, tangents): + x, = primals + t, = tangents + return t + + f(2.) # doesn't crash + self.assertRaisesRegex( + TypeError, + re.escape( + "Custom JVP rule foo_jvp for function f " + "must produce a pair (list or tuple of length two) " + "representing primal and tangent outputs, but got 1.0"), + lambda: api.jvp(f, (2.,), (1.,))) + + def test_jvp_rule_primal_out_type_doesnt_match_primal_error_message(self): + # https://github.com/lucidrains/flash-attention-jax/issues/7 + + def scan_apply(f, x): + y, _ = jax.lax.scan(lambda x, _: (f(x), None), x, None, length=1) + return y + + @jax.custom_jvp + def f(x): + return x + + @f.defjvp + def f_jvp(primals, tangents): + (x,), (xdot,) = primals, tangents + return (x, x), (xdot, xdot) + + x = jnp.float32(1.) + self.assertRaisesRegex( + TypeError, + re.escape( + "Custom JVP rule f_jvp for function f must produce a pair " + "(list or tuple of length two) where the first element represents " + "the primal output (equal in value to the output of the " + "custom_jvp-decorated function f, and in particular of the " + "same container/pytree structure), but instead the JVP rule " + "output's first element had container/pytree structure:\n" + " (float32[], float32[])\n" + "while the custom_jvp-decorated function f had output " + "container/pytree structure:\n" + " float32[]." + ), + lambda: jax.jvp(lambda x: scan_apply(f, x), (x,), (x,))) + + @f.defjvp + def f_jvp2(primals, tangents): + (x,), (xdot,) = primals, tangents + return jnp.zeros((3, *x.shape), x.dtype), xdot + + self.assertRaisesRegex( + TypeError, + re.escape( + "Custom JVP rule f_jvp2 for function f must produce a pair " + "(list or tuple of length two) where the first element represents " + "the primal output (equal in value to the output of the " + "custom_jvp-decorated function f, and in particular " + "with leaves of the same shape/dtype), but instead the JVP rule " + "output's first element had shapes/dtypes of:\n" + " float32[3]\n" + "while the custom_jvp-decorated function f had output shapes/dtypes" + " of:\n" + " float32[]" + ), + lambda: jax.jvp(lambda x: scan_apply(f, x), (x,), (x,))) + + def test_multiple_rule_invocations(self): + @jax.custom_jvp + def expit(x): + return 1 / (1 + lax.exp(-x)) + + @expit.defjvp + def _expit_jvp(primals, tangents): + (x,), (t,) = primals, tangents + ans = expit(x) + t_out = t * ans * (1 - ans) + return ans, t_out + + def scanned_fun(c, _): + return [expit(c[0])] + [c[i-1] + c[i] for i in range(1, len(c))], None + + def foo(x): + zero = jnp.zeros_like(x) + c, _ = lax.scan(scanned_fun, [x, zero, zero, zero, zero], None, length=10) + return c[-1] + + # just make sure these don't crash + foo(3.) + grad(foo)(3.) + grad(lambda x: jax.vmap(foo)(x).sum())(jnp.arange(3.)) + + def test_hard_stuff(self): + arr = jnp.ones((5, 2, 2)) + api.jit(jax.vmap(jnp.linalg.det))(arr) # doesn't crash + + def test_hard_stuff2(self): + @jax.custom_jvp + def f(x): + return np.zeros(x.shape, x.dtype) + + @f.defjvp + def f_jvp(primals, tangents): + x, = primals + t, = tangents + return f(x), t + + # don't crash + jax.jit(jax.vmap(f))(jnp.arange(3.)) + jax.jit(jax.vmap(jax.grad(f)))(jnp.arange(3.)) + jax.jit(jax.grad(lambda x: jax.vmap(f)(x).sum()))(jnp.arange(3.)) + jax.grad(lambda x: jax.vmap(f)(x).sum())(jnp.arange(3.)) + jax.jvp(jax.vmap(f), (jnp.arange(3.),), (jnp.ones(3),)) + + def test_hard_stuff3(self): + @jax.custom_jvp + def relu(x): + return jnp.maximum(x, 0) + + @relu.defjvp + def _relu_jvp(primals, tangents): + x, = primals + t, = tangents + return relu(x), lax.select(x > 0, t, lax.full_like(t, 0)) + + def scanned_fun(c, _): + return [relu(c[0])] + [c[i-1] + c[i] for i in range(1, len(c))], None + + def f(x): + zero = jnp.zeros_like(x) + c, _ = lax.scan(scanned_fun, [x, zero, zero, zero, zero], None, length=10) + return c[-1] + + # don't crash + jax.jit(jax.vmap(f))(jnp.arange(3.)) + jax.jit(jax.vmap(jax.grad(f)))(jnp.arange(3.)) + jax.jit(jax.grad(lambda x: jax.vmap(f)(x).sum()))(jnp.arange(3.)) + jax.grad(lambda x: jax.vmap(f)(x).sum())(jnp.arange(3.)) + jax.jvp(jax.jit(jax.vmap(f)), (jnp.arange(3.),), (jnp.ones(3),)) + + def test_eval_shape(self): + @jax.custom_jvp + def expit(x): + return 1 / (1 + lax.exp(-x)) + + @expit.defjvp + def _expit_jvp(primals, tangents): + (x,), (t,) = primals, tangents + ans = expit(x) + t_out = t * ans * (1 - ans) + return ans, t_out + + # don't crash + api.eval_shape(expit, jnp.ones((2, 3))) + api.eval_shape(api.grad(lambda x: expit(x).sum()), jnp.ones((2, 3))) + + def test_jaxpr_zeros(self): + # from https://github.com/jax-ml/jax/issues/2657 + @jax.custom_jvp + def f(A, b): + return A @ b + + def f_jvp(primals, tangents): + A, b = primals + dA, db = tangents + z = f(A, b) + dz = A @ db + dA @ b + return z, dz + + f.defjvp(f_jvp) + + def experiment(theta): + def step(q, _): + z = f(jnp.eye(3), jnp.ones(3) * theta) + q += z[0] + return q, q + + q = 0. + q, _ = lax.scan(step, q, None, 4) + return q + + grad(experiment)(1.) # doesn't crash + + def test_linear_in_scan(self): + @jax.custom_jvp + def f(x): + return -x + + @f.defjvp + def f_jvp(primals, tangents): + x, = primals + x_dot, = tangents + return f(x), f(x_dot) + + def foo(x): + out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1) + return out + + ans = api.grad(foo)(3.) + expected = -1. + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_custom_jvps_first_rule_is_none(self): + # https://github.com/jax-ml/jax/issues/3389 + @jax.custom_jvp + def f(x, y): + return x ** 2 * y + + f.defjvps(None, lambda x_dot, primal_out, x, y: 2 * x * y * x_dot) + ans = grad(f, 1)(2., 3.) # doesn't crash + expected = 12. + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_concurrent_initial_style(self): + # https://github.com/jax-ml/jax/issues/3843 + def unroll(param, sequence): + def scan_f(prev_state, inputs): + return prev_state, jax.nn.sigmoid(param * inputs) + return jnp.sum(jax.lax.scan(scan_f, None, sequence)[1]) + + def run(): + return jax.grad(unroll)(jnp.array(1.0), jnp.array([1.0])) + + expected = run() + + # we just don't want this to crash + n_workers = 2 + with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as e: + futures = [] + for _ in range(n_workers): + futures.append(e.submit(run)) + results = [f.result() for f in futures] + for ans in results: + self.assertAllClose(ans, expected) + + def test_nondiff_argnums_vmap_tracer(self): + # https://github.com/jax-ml/jax/issues/3964 + @partial(jax.custom_jvp, nondiff_argnums=(0, 2)) + def sample(shape, param, seed): + return jax.random.uniform(key=seed, shape=shape, minval=param) + + @sample.defjvp + def sample_jvp(shape, seed, primals, tangents): + param, = primals + dparam, = tangents + dparam = jnp.broadcast_to(dparam, shape) + samples = sample(shape, param, seed) + return samples, samples * dparam # dummy jvp for proof of concept + + # check these don't crash + jax.vmap(lambda seed: sample((2,3), 1., seed))( + jax.random.split(jax.random.key(1), 10)) + jax.jvp(lambda x: sample((2, 3), x, jax.random.key(1)), + (1.,), (1.,)) + + def test_fun_with_nested_calls_2(self): + def call(f, *args): + f = jax.custom_jvp(f) + f.defjvp(lambda primals, tangents: (f(*primals), sum(tangents))) + return f(*args) + + def fun_with_nested_calls_2(x): + def bar(y): + def baz(w): + q = call(lambda x: y, x) + q = q + call(lambda: y) + q = q + call(lambda y: w + y, y) + q = call(lambda w: call(jnp.sin, x) * y, 1.0) + q + return q + return api.jit(baz)(x) + return call(bar, x) + + # test these don't crash + self.assertAllClose(api.jit(fun_with_nested_calls_2)(3.), + fun_with_nested_calls_2(3.)) + api.vmap(fun_with_nested_calls_2)(jnp.arange(3.)) + + def test_closure_with_vmap(self): + # https://github.com/jax-ml/jax/issues/3822 + alpha = np.float32(2.) + + def sample(seed): + @jax.custom_jvp + def f(alpha): + return jax.random.gamma(seed, alpha, shape=[]) + + @f.defjvp + def f_jvp(primal, tangent): + alpha = primal + dalpha = tangent + sample = f(alpha) + partial_alpha = lax.random_gamma_grad(alpha, sample) + return sample, partial_alpha * dalpha + return f(alpha) + + api.vmap(sample)(jax.random.split(jax.random.key(1), 3)) # don't crash + + def test_closure_with_vmap2(self): + # https://github.com/jax-ml/jax/issues/8783 + def h(z): + def f(x): + @jax.custom_jvp + def g(y): + return x * y + + # NOTE: rule closes over vmap tracer + @g.defjvp + def g_jvp(primals, tangents): + (y,), (ydot,) = primals, tangents + return x * y, x * ydot + + return g(z) # NOTE: no vmapped arg + + return jax.vmap(f)(jnp.arange(3., dtype='float32')) + + primals, tangents = jax.jvp(h, (jnp.float32(1.),), (jnp.float32(2.),)) + self.assertAllClose(primals , jnp.arange(3., dtype='float32')) + self.assertAllClose(tangents, 2 * jnp.arange(3., dtype='float32')) + + def test_float0(self): + scalar_float0 = jnp.zeros((), dtype=float0) + @jax.custom_jvp + def f(x, y): + return x, y + def f_jvp(primals, _): + x, y = primals + return (x, y), (2., jax.custom_derivatives.zero_from_primal(y)) + f.defjvp(f_jvp) + + primals = (2., 3) + tangents = (np.ones(()), scalar_float0) + expected_tangents = (2., scalar_float0) + self.assertAllClose(api.jvp(f, primals, tangents), + (primals, expected_tangents)) + + def test_float0_initial_style(self): + scalar_float0 = jnp.zeros((), dtype=float0) + @jax.custom_jvp + def f(x, y): + return x, y + def f_jvp(primals, _): + x, y = primals + return (x, y), (2., jax.custom_derivatives.zero_from_primal(y)) + f.defjvp(f_jvp) + + def foo(x, y): + out, _ = lax.scan(lambda c, _: (f(*c), None), (x, y), None, length=1) + return out + + primals = (2., 3) + tangents = (np.ones(()), scalar_float0) + expected_tangents = (2., scalar_float0) + + self.assertAllClose(api.jvp(foo, primals, tangents), + (primals, expected_tangents)) + + def test_remat(self): + @jax.custom_jvp + def f(x): + return jnp.sin(x) + def f_jvp(primals, tangents): + x, = primals + g, = tangents + return f(x), 2 * jnp.cos(x) * g + f.defjvp(f_jvp) + + @jax.remat + def g(x): + return f(f(x)) + + ans = g(2.) + expected = np.sin(np.sin(2.)) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(g)(2.) + expected = 4. * api.grad(lambda x: jnp.sin(jnp.sin(x)))(2.) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_remat_higher_order(self): + @jax.custom_jvp + def f(x): + return jnp.sin(x) + def f_jvp(primals, tangents): + x, = primals + g, = tangents + return f(x), 2 * jnp.cos(x) * g + f.defjvp(f_jvp) + + def g(x): + return f(f(x)) + + ans = api.grad(api.grad(jax.checkpoint(g)))(2.) + expected = api.grad(api.grad(g))(2.) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(jax.checkpoint(api.grad(g)))(2.) + expected = api.grad(api.grad(g))(2.) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(api.grad(api.grad(jax.checkpoint(g))))(2.) + expected = api.grad(api.grad(api.grad(g)))(2.) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_initial_style_vmap_2(self): + # This is like test_initial_style_vmap except the primal function closes + # over an array constant. + y = jnp.arange(1., 4.) + + @jax.custom_jvp + def f(x): + assert jnp.ndim(x) == 0 + return 3 * x * jnp.sum(y) + def f_jvp(primals, tangents): + x, = primals + g, = tangents + return f(x), 2 * g + f.defjvp(f_jvp) + + def foo(x): + out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1) + return out + + ans = api.grad(lambda x: api.vmap(foo)(x).sum())(jnp.ones(3)) + expected = 2. * jnp.ones(3) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(lambda x: api.vmap(api.jit(foo))(x).sum())(jnp.ones(3)) + expected = 2. * jnp.ones(3) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(lambda x: api.jit(api.vmap(foo))(x).sum())(jnp.ones(3)) + expected = 2. * jnp.ones(3) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(api.jit(lambda x: api.vmap(foo)(x).sum()))(jnp.ones(3)) + expected = 2. * jnp.ones(3) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.jit(api.grad(lambda x: api.vmap(foo)(x).sum()))(jnp.ones(3)) + expected = 2. * jnp.ones(3) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_custom_jvp_vmap_broadcasting_interaction(self): + # https://github.com/jax-ml/jax/issues/6452 + def f2(y, z): + v1 = z + v2 = jnp.sum(y) + z + return jnp.logaddexp(v1, v2) + + def f1(y, z): + v = api.vmap(lambda _y: f2(_y, z))(y) + return jnp.sum(v) + + y = jnp.ones((3, 2)) + f = lambda z: f1(y, z) + z = 0.1 + val, g = api.value_and_grad(f)(z) + self.assertEqual(val.shape, ()) + self.assertEqual(g.shape, ()) + + def test_custom_jvp_vmap_broadcasting_interaction_2(self): + # https://github.com/jax-ml/jax/issues/5849 + @jax.custom_jvp + def transform(box, R): + if jnp.isscalar(box) or box.size == 1: + return R * box + elif box.ndim == 2: + return jnp.einsum('ij,j->i', box, R) + raise ValueError() + + @transform.defjvp + def transform_jvp(primals, tangents): + box, R = primals + dbox, dR = tangents + return (transform(box, R), dR + transform(dbox, R)) + + def periodic_general(box): + def displacement_fn(Ra, Rb, **kwargs): + _box = kwargs.get('box', box) + return transform(_box, Ra - Rb) + + return displacement_fn + + N = 250 + + scalar_box = 1.0 + displacement = periodic_general(scalar_box) + + key = jax.random.key(0) + R = jax.random.uniform(key, (N, 2)) + + def energy_fn(box): + d = partial(displacement, box=box) + d = api.vmap(api.vmap(d, (None, 0)), (0, None)) + return jnp.sum(d(R, R) ** 2) + + self.assertEqual(grad(energy_fn)(scalar_box).shape, ()) + + def test_custom_jvp_implicit_broadcasting(self): + # https://github.com/jax-ml/jax/issues/6357 + if config.enable_x64.value: + raise unittest.SkipTest("test only applies when x64 is disabled") + + @jax.custom_jvp + def projection_unit_simplex(x: jax.Array) -> jax.Array: + """Projection onto the unit simplex.""" + s = 1.0 + n_features = x.shape[0] + u = jnp.sort(x)[::-1] + cssv = jnp.cumsum(u) - s + ind = jnp.arange(n_features, dtype=x.dtype) + 1 + cond = u - cssv / ind > 0 + idx = jnp.count_nonzero(cond) + threshold = cssv[idx - 1] / idx.astype(x.dtype) + return jax.nn.relu(x - threshold) + + + @projection_unit_simplex.defjvp + def projection_unit_simplex_jvp(primals, tangents): + x, = primals + x_dot, = tangents + primal_out = projection_unit_simplex(x) + supp = (primal_out > 0).astype(x_dot.dtype) + card = jnp.count_nonzero(supp).astype(x_dot.dtype) + tangent_out = supp * x_dot - (jnp.dot(supp, x_dot) / card) * supp + return primal_out, tangent_out + + rng = self.rng() + x = rng.rand(5).astype(np.float32) + + J_rev = jax.jacrev(projection_unit_simplex)(x) + J_fwd = jax.jacfwd(projection_unit_simplex)(x) + + p = projection_unit_simplex(x) + support = (p > 0).astype(jnp.float32) + cardinality = jnp.count_nonzero(support).astype(support.dtype) + J_true = jnp.diag(support) - jnp.outer(support, support) / cardinality + self.assertAllClose(J_true, J_fwd) + self.assertAllClose(J_true, J_rev) + + proj = jax.vmap(projection_unit_simplex) + + def fun(X): + return jnp.sum(proj(X) ** 2) + + rng = self.rng() + X = rng.rand(4, 5).astype(np.float32) + U = rng.rand(4, 5) + U /= np.sqrt(np.sum(U ** 2)) + U = U.astype(np.float32) + + eps = 1e-3 + dir_deriv_num = (fun(X + eps * U) - fun(X - eps * U)) / (2 * eps) + dir_deriv = jnp.vdot(jax.grad(fun)(X), U) + self.assertAllClose(dir_deriv, dir_deriv_num, atol=1e-3) + + def test_vmap_inside_defjvp(self): + # https://github.com/jax-ml/jax/issues/3201 + seed = 47 + key = jax.random.key(seed) + mat = jax.random.normal(key, (2, 3)) + + @jax.custom_jvp + def f(mat, aux): + num_rows, num_cols = mat.shape + return jnp.ones((num_rows, 1)) / num_cols + + @f.defjvp + def f_jvp(primals, tangents): + mat, aux = primals + vec, _ = tangents + output = f(*primals) + num_rows, num_cols = mat.shape + size = num_rows * num_cols + # ----- + bd_mat = mat.reshape(1, 1, num_rows, num_cols) + bd_mat = jnp.tile(bd_mat, reps=(num_rows, num_cols)) + bd_mat = bd_mat.reshape(size, num_rows, num_cols) + # ----- + rowsum = jnp.sum(mat, axis=1, keepdims=True) + colsum = jnp.sum(mat, axis=0, keepdims=True) + bd_rowsum = jnp.tile(rowsum, reps=(1, num_rows)) + bd_colsum = jnp.tile(colsum, reps=(num_cols, 1)) + # ----- + bd_vec = vec.reshape(size, 1) + # ----- + def operate(mx, val): + buf = 0 + for i in range(2): + buf = buf + jnp.matmul(mx, bd_colsum) / jnp.power(aux, i) + buf = jnp.matmul(bd_rowsum, buf) + return buf * val[None, :] + # ----- + # Vertorizing will raise shape error + bd_buf = jax.vmap(operate, in_axes=(0, 0), out_axes=0)(bd_mat, bd_vec) + # ----- + bd_buf = bd_buf / aux + jvp = jnp.sum(bd_buf, axis=0) + jvp = jnp.mean(jvp, axis=1, keepdims=True) + # ----- + # JVP ends successfully, but still raise an error + return (output, jvp) + + jax.grad(lambda mat, aux: jnp.sum(f(mat, aux)))(mat, 0.5) # doesn't crash + + def test_custom_jvp_unbroadcasting(self): + # https://github.com/jax-ml/jax/issues/3056 + a = jnp.array([1., 1.]) + + @jax.custom_jvp + def f(x): + return a * x + + @f.defjvp + def f_jvp(primals, tangents): + x, = primals + dx, = tangents + return a * x, a * dx + + shape = grad(lambda x: jnp.sum(f(x)))(jnp.array(1.)).shape + self.assertEqual(shape, ()) + + def test_maybe_perturbed_internal_helper_function(self): + # This is a unit test for an internal API. We include it so as not to + # regress https://github.com/jax-ml/jax/issues/9567. For an explanation of + # this helper function, see https://github.com/jax-ml/jax/issues/6415. + def f(x): + def g(y, _): + z = y * x + self.assertTrue(custom_derivatives._maybe_perturbed(z)) + return y, None + g(1, None) + return lax.scan(g, 1, xs=None, length=1)[0] + + jax.jvp(f, (1.0,), (1.0,)) # assertions inside f + + def test_maybe_perturbed_int_regression(self): + # see https://github.com/jax-ml/jax/discussions/9951 + + @jax.jit + def f(): + x = jnp.array(1) + _, aux_args = custom_derivatives.closure_convert(lambda: x) + self.assertEmpty(aux_args) + f() + + def test_sinc_constant_function_batching(self): + # https://github.com/jax-ml/jax/pull/10756 + batch_data = jnp.arange(15.).reshape(5, 3) + + @jax.vmap + def f(x): + return jax.lax.map(jnp.sinc, x) + g = lambda param: f(param * batch_data).sum() + + @jax.vmap + def f_ref(x): + return jnp.stack([jnp.sinc(x_) for x_ in x]) + g_ref = lambda param: f_ref(param * batch_data).sum() + + grad = jax.grad(g )(0.1) # doesn't crash + grad_ref = jax.grad(g_ref)(0.1) + self.assertAllClose(grad, grad_ref, check_dtypes=False) + + @parameterized.named_parameters( + ('jit_vmap', True, True), + ('jit', True, False), + ('vmap', False, True), + ('', False, False), + ) + def test_symbolic_zero_custom_jvp(self, maybe_jit, maybe_vmap): + def f(static_scalar, static_array, dyn_scalar, dyn_array): + out1 = static_scalar + dyn_scalar + out2 = static_array + dyn_array + return out1, out2 + + def _pack(x): + return lax.broadcast(x, (1,)) + + def _unpack(x): + (x,) = x + return x + + def _vmap(fun): + def _fun(*args): + args = jax.tree.map(_pack, args) + out = jax.vmap(fun)(*args) + out = jax.tree.map(_unpack, out) + return out + return _fun + + f = jax.custom_jvp(f) + + @partial(f.defjvp, symbolic_zeros=True) + def f_jvp(primals, tangents): + static_scalar, *_ = primals + t_static, t_static_arr, t_dyn_scalar, t_dyn_array = tangents + self.assertIs(type(t_static) , jax.custom_derivatives.SymbolicZero) + self.assertIs(type(t_static_arr), jax.custom_derivatives.SymbolicZero) + self.assertEqual(t_static.shape, ()) + self.assertEqual(t_static_arr.shape, (2,)) + return f(*primals), (static_scalar + 90, t_dyn_array + 91) + + def g(dyn_scalar, dyn_array): + if maybe_vmap: + f_ = _vmap(f) + else: + f_ = f + return f_(1., jnp.array([2., 3.]), dyn_scalar, dyn_array) + + def run(primal_ins, tangent_ins): + return jax.jvp(g, primal_ins, tangent_ins) + + if maybe_jit: + run = jax.jit(run) + + primal_ins = (4., jnp.array([5., 6.])) + tangent_ins = (7., jnp.array([8., 9.])) + primal_outs, tangent_outs = run(primal_ins, tangent_ins) + primal_out1, primal_out2 = primal_outs + tangent_out1, tangent_out2 = tangent_outs + scalar_type = jax.Array if maybe_jit or maybe_vmap else float + self.assertIsInstance(primal_out1, scalar_type) + self.assertAllClose(primal_out1, 5.) + self.assertIsInstance(tangent_out1, scalar_type) + self.assertAllClose(tangent_out1, 91.) + self.assertIsInstance(primal_out2, jax.Array) + self.assertArraysAllClose(primal_out2, jnp.array([7., 9.])) + self.assertIsInstance(tangent_out2, jax.Array) + self.assertArraysAllClose(tangent_out2, jnp.array([99., 100.])) + + def test_symbolic_zero_custom_jvp_vmap_output(self): + @jax.custom_jvp + def f(x, y): + return x * y + + @partial(f.defjvp, symbolic_zeros=True) + def f_jvp(primals, tangents): + x, y = primals + x_dot, y_dot = tangents + self.assertIs(type(y_dot), jax.custom_derivatives.SymbolicZero) + return f(x, y), y_dot + + jax.grad(lambda x, y: jax.vmap(f)(x, y).sum())(jnp.ones(3), jnp.ones(3)) + + def test_symbolic_zeros_memoization_caching(self): + # Tests multiple zero patterns for partial_eval._memoize, and also tests + # that we're okay with stores being occupied with equal values. + + @jax.custom_jvp + def f(x, y): + return x * y + + @partial(f.defjvp, symbolic_zeros=True) + def f_jvp(primals, tangents): + x, y = primals + x_dot, y_dot = tangents + return f(x, y), y_dot + + f_ = core.jaxpr_as_fun(jax.make_jaxpr(f)(2., 3.)) + _ = jax.linearize(f_, 2., 3.) + _ = jax.linearize(lambda x: f_(x, 3.), 2.) # don't crash! + + def test_symbolic_zeros_under_jit(self): + # https://github.com/jax-ml/jax/issues/14833 + Zero = jax.custom_derivatives.SymbolicZero + + @jax.custom_jvp + def f(x, y): + return x * y + + @partial(f.defjvp, symbolic_zeros=True) + def fjvp(primals, tangents): + x, y = primals + tx, ty = tangents + assert type(tx) is not Zero or type(ty) is not Zero + return f(x, y), ( + ty if type(tx) is Zero else + tx if type(ty) is Zero else + tx + ty) + + jax.jacfwd(jax.jit(f))(0.1, 0.2) # don't crash + + def test_custom_jvp_functools_partial(self): + def fun(x, y, a): + return x + y * a + + fun_wrapped = functools.partial(fun, a = 0.1) + + def jvp_fn(primals, tangents): + return jax.jvp(fun_wrapped, primals, tangents) + + fn = jax.custom_jvp(fun_wrapped) + fn.defjvp(jvp_fn) + + self.assertEqual((1.0, 0.1), jax.grad(lambda args: fn(*args))((1.0, 2.0))) + + def test_run_rules_more_than_once(self): + # https://github.com/jax-ml/jax/issues/16614 + + @jax.custom_jvp + def f(x, y): + return x + + @partial(f.defjvp, symbolic_zeros=True) + def f_jvp(primals, tangents): + x, _ = primals + x_dot, _ = tangents + return x, x_dot + + def body(x_y, _): + x, y = x_y + return (f(x, y), x), None + + @jax.grad + def g(x): + (out, _), _ = lax.scan(body, (x, 1.), xs=None, length=2) + return out + + g(1.) # doesn't crash + + def test_dce(self): + @jax.custom_jvp + def f(x, y): + return jnp.sin(x), x + jnp.cos(y) + + @f.defjvp + def f_jvp(primals, tangents): + x, y = primals + dx, dy = tangents + return f(x, y), (2.0 * jnp.cos(x) * dx, 1.5 * dx - 0.5 * jnp.sin(y) * dy) + + def check_jaxpr(jaxpr, used_outs, includes, excludes): + dce_jaxpr, _ = pe.dce_jaxpr(jaxpr, used_outs) + if not dce_jaxpr.eqns: + assert not includes + return + call_jaxpr = dce_jaxpr.eqns[0].params["call_jaxpr"] + for prim in includes: + assert any(eqn.primitive == prim for eqn in call_jaxpr.eqns) + for prim in excludes: + assert all(eqn.primitive != prim for eqn in call_jaxpr.eqns) + + x, y = 0.1, -1.3 + jaxpr = jax.make_jaxpr(f)(x, y).jaxpr + check_jaxpr(jaxpr, [True, True], [lax.sin_p, lax.cos_p], []) + check_jaxpr(jaxpr, [True, False], [lax.sin_p], [lax.cos_p]) + check_jaxpr(jaxpr, [False, True], [lax.cos_p], [lax.sin_p]) + check_jaxpr(jaxpr, [False, False], [], [lax.sin_p, lax.cos_p]) + + def dce_jaxpr_as_fun(jaxpr, used_outs): + jaxpr_, _ = pe.dce_jaxpr(jaxpr, used_outs) + fun = core.jaxpr_as_fun(pe.close_jaxpr(jaxpr_)) + return lambda *args: fun(*args)[0] + + f0 = dce_jaxpr_as_fun(jaxpr, [True, False]) + f1 = dce_jaxpr_as_fun(jaxpr, [False, True]) + self.assertAllClose( + api.jvp(f0, (x, y), (1.0, 0.0)), (f0(x, y), 2.0 * jnp.cos(x))) + self.assertAllClose( + api.jvp(f0, (x, y), (0.0, 1.0)), (f0(x, y), 0.0)) + self.assertAllClose( + api.jvp(f1, (x, y), (1.0, 0.0)), (f1(x, y), 1.5)) + self.assertAllClose( + api.jvp(f1, (x, y), (0.0, 1.0)), (f1(x, y), -0.5 * jnp.sin(y))) + + def test_dce_symbolic_zeros(self): + # https://github.com/jax-ml/jax/issues/31448 + @jax.custom_jvp + def f(x): + return x + + @partial(f.defjvp, symbolic_zeros=True) + def f_jvp(primals, tangents): + x, = primals + tx, = tangents + return f(x), tx + + @jax.jacfwd + @jax.jacrev + def f_wrapped(x): + return jax.jit(f)((x, 3.)) + + f_wrapped(jnp.zeros(2)) # doesn't crash + + def test_resolve_kwargs_error_message(self): + @jax.custom_jvp + def f(x, y, *, z=None): + return jnp.sin(x), x + jnp.cos(y) + + @f.defjvp + def f_jvp(primals, tangents): + self.fail("should not be executed") + + with self.assertRaisesRegex( + TypeError, + r"The input arguments to the custom_jvp-decorated function f(.*)\n" + r"missing a required argument: 'y'" + ): + f(0.5) + + with self.assertRaisesRegex( + TypeError, + r"The input arguments to the custom_jvp-decorated function f(.*)\n" + "The following keyword arguments could not be resolved to positions: z" + ): + f(0.5, 0.1, z=1.0) + + def test_symbolic_zero_custom_jvp_vmap_doesnt_instantiate(self): + @jax.custom_jvp + def f(x, y): + return y + + def f_jvp(primals, tangents): + (x, y), (x_dot, y_dot) = primals, tangents + assert type(y_dot) is jax.custom_derivatives.SymbolicZero + return y, y_dot + + f.defjvp(f_jvp, symbolic_zeros=True) + + def g(x): + return f(x, f(x, 1.)) + + jax.jvp(jax.vmap(g), (jnp.ones(3),), (jnp.ones(3),)) # don't crash + + def test_symbolic_zero_under_vmap_of_jit(self): + # https://github.com/jax-ml/jax/issues/28144 + @jax.custom_jvp + def f(x): + return x + 1 + + @f.defjvp + def f_jvp(x, t): + (x,) = x + (t,) = t + z = jax.custom_derivatives.zero_from_primal(x, symbolic_zeros=True) + return f(x), z + + x = jnp.arange(3.0) + jax.jvp(jax.vmap(jax.jit(f)), (x,), (x,)) # doesn't crash + + def test_pretty_print(self): + @jax.custom_jvp + def f(x): + return x + 1 + + @f.defjvp + def f_jvp(primals, tangents): + return f(*primals), tangents[0] + + x = jnp.array([4.2], dtype=jnp.float32) + jaxpr = jax.make_jaxpr(f)(x) + actual = jaxpr.pretty_print(use_color=False) + expected = textwrap.dedent( + """ + { lambda ; a:f32[1]. let + b:f32[1] = custom_jvp_call[ + name=f + call_jaxpr={ lambda ; c:f32[1]. let d:f32[1] = add c 1.0:f32[] in (d,) } + jvp=f_jvp + symbolic_zeros=False + ] a + in (b,) } + """).strip() + self.assertEqual(actual, expected) + + def test_custom_jvp_transpose_vjp3(self): + @jax.custom_jvp + def div(x, y): + return x / y + @div.defjvp + def sin_jvp(primals, tangents): + (x, y), (x_dot, y_dot) = primals, tangents + del y_dot # ignore lol + return div(x, y), div(x_dot, y) + _, f_vjp = api.vjp(lambda x: div(x, 2.), 1.) + ans, = f_vjp(1.) + self.assertAllClose(ans, 1./2, check_dtypes=False) + + def test_ensure_compile_time_eval(self): + @jax.custom_jvp + def f(x): + assert x == 0. # concrete! + return x + @f.defjvp + def f_jvp(primals, tangents): + (x,), (x_dot,) = primals, tangents + assert x == 0. # concrete! + + @jax.jit + def g(): + with jax.ensure_compile_time_eval(): + return f(0.) + + g() # don't crash + + # TODO(mattjj): do we want to support autodiff here too? + # def h(x): + # @jax.jit + # def hh(): + # with jax.ensure_compile_time_eval(): + # return f(x) + # return hh() + + # jax.grad(h)(0.) # don't crash + + +class CustomVJPTest(jtu.JaxTestCase): + + def test_basic(self): + @jax.custom_vjp + def f(x): + return jnp.sin(x) + def f_fwd(x): + return f(x), jnp.cos(x) + def f_rev(cos_x, g): + return (2 * cos_x * g,) + f.defvjp(f_fwd, f_rev) + + x = 3. + self.assertAllClose(f(x), jnp.sin(x)) + self.assertAllClose(api.grad(f)(x), 2 * jnp.cos(x)) + self.assertAllClose(api.value_and_grad(f)(x), + (jnp.sin(x), 2 * jnp.cos(x))) + + def test_invariance(self): + @jax.custom_vjp + def f(x): + return jnp.cos(2 * x) / 2. + def f_fwd(x): + return (f(x), x) + def f_rev(x, g): + return (g * 3,) + f.defvjp(f_fwd, f_rev) + def f2(x): + y, _ = api.value_and_grad(f)(x) + return y + def f3(x): + y, _ = api.value_and_grad(f2)(x) + return y + x = 1. + self.assertAllClose(f(x), f2(x), check_dtypes=False) + self.assertAllClose(f(x), f3(x), check_dtypes=False) + self.assertAllClose(api.grad(f)(x), api.grad(f2)(x), + check_dtypes=False) + self.assertAllClose(api.grad(f)(x), api.grad(f3)(x), + check_dtypes=False) + + def test_python_control_flow(self): + @jax.custom_vjp + def f(x): + if x > 0: + return jnp.sin(x) + else: + return jnp.cos(x) + def f_fwd(x): + if x > 0: + return f(x), x + else: + return f(x), x + def f_rev(x, g): + if x > 0: + return (2 * g,) + else: + return (3 * g,) + f.defvjp(f_fwd, f_rev) + x = 2. + self.assertAllClose(f(x), jnp.sin(x)) + self.assertAllClose(f(-x), jnp.cos(-x)) + self.assertAllClose(api.value_and_grad(f)(x), (jnp.sin(x), 2.), + check_dtypes=False) + self.assertAllClose(api.value_and_grad(f)(-x), (jnp.cos(-x), 3.), + check_dtypes=False) + + def test_python_control_flow_bwd(self): + @jax.custom_vjp + def f(x): + return jax.lax.cond(x > 0, jnp.sin, jnp.cos, x) # no primal control flow + def f_fwd(x): + if x > 0: + return jnp.sin(x), x + else: + return jnp.cos(x), x + def f_rev(x, g): + if x > 0: + return (2 * g,) + else: + return (3 * g,) + f.defvjp(f_fwd, f_rev) + x = 2. + self.assertAllClose(f(x), jnp.sin(x)) + self.assertAllClose(f(-x), jnp.cos(-x)) + self.assertAllClose(api.value_and_grad(f)(x), (jnp.sin(x), 2.), + check_dtypes=False) + self.assertAllClose(api.value_and_grad(f)(-x), (jnp.cos(-x), 3.), + check_dtypes=False) + + def test_vmap(self): + @jax.custom_vjp + def f(x): + assert jnp.ndim(x) == 0 + return jnp.sin(x) + def f_fwd(x): + assert jnp.ndim(x) == 0 + return f(x), jnp.cos(x) + def f_rev(cos_x, g): + return (2 * cos_x * g,) + f.defvjp(f_fwd, f_rev) + + x = jnp.arange(3.) + xx = jnp.arange(6.).reshape(2, 3) + + # vmap of f + self.assertAllClose(api.vmap(f)(x), jnp.sin(x)) + self.assertAllClose(api.vmap(api.vmap(f))(xx), jnp.sin(xx)) + + # vmap of grad of f + self.assertAllClose(api.vmap(api.grad(f))(x), 2 * jnp.cos(x)) + self.assertAllClose(api.vmap(api.value_and_grad(f))(x), + (jnp.sin(x), 2 * jnp.cos(x))) + self.assertAllClose(api.vmap(api.vmap(api.grad(f)))(xx), 2 * jnp.cos(xx)) + self.assertAllClose(api.vmap(api.vmap(api.value_and_grad(f)))(xx), + (jnp.sin(xx), 2 * jnp.cos(xx))) + + # grad of vmap of f + self.assertAllClose(api.grad(lambda x: api.vmap(f)(x).sum())(x), + 2 * jnp.cos(x)) + self.assertAllClose(api.grad(lambda x: api.vmap(api.vmap(f))(x).sum())(xx), + 2 * jnp.cos(xx)) + + # vmap of grad of vmap of f + self.assertAllClose(api.vmap(api.grad(lambda x: api.vmap(f)(x).sum()))(xx), + 2 * jnp.cos(xx)) + + def test_jit(self): + @jax.custom_vjp + def f(x): + return jnp.sin(x) + def f_fwd(x): + return f(x), jnp.cos(x) + def f_rev(cos_x, g): + return (2 * cos_x * g,) + f.defvjp(f_fwd, f_rev) + + x = 3. + + # jit + self.assertAllClose(api.jit(f)(x), jnp.sin(x)) + self.assertAllClose(api.jit(api.jit(f))(x), jnp.sin(x)) + + # jit of grad + self.assertAllClose(api.jit(api.grad(f))(x), 2 * jnp.cos(x), + check_dtypes=False) + + # grad of jit + self.assertAllClose(api.grad(api.jit(f))(x), 2 * jnp.cos(x), + check_dtypes=False) + + def test_pytrees(self): + @jax.custom_vjp + def f(x): + return {'b': jnp.sin(x['a'])} + def f_fwd(x): + return f(x), {'r': jnp.cos(x['a'])} + def f_bwd(res, g): + cos_x = res['r'] + return ({'a': 2 * cos_x * g['b']},) + f.defvjp(f_fwd, f_bwd) + x = {'a': 3.} + self.assertAllClose(f(x)['b'], jnp.sin(x['a'])) + self.assertAllClose(api.grad(lambda x: f(x)['b'])(x), + {'a': 2 * jnp.cos(x['a'])}) + + def test_jvp_error(self): + @jax.custom_vjp + def f(x): + return jnp.sin(x) + def f_fwd(x): + return f(x), jnp.cos(x) + def f_rev(cos_x, g): + return (2 * cos_x * g,) + f.defvjp(f_fwd, f_rev) + + self.assertRaisesRegex( + TypeError, + r"can't apply forward-mode autodiff \(jvp\) to a custom_vjp function.", + lambda: api.jvp(f, (3.,), (1.,))) + self.assertRaisesRegex( + TypeError, + r"can't apply forward-mode autodiff \(jvp\) to a custom_vjp function.", + lambda: api.jvp(api.vmap(f), (jnp.arange(3.),), (jnp.ones(3),))) + self.assertRaisesRegex( + TypeError, + r"can't apply forward-mode autodiff \(jvp\) to a custom_vjp function.", + lambda: api.jvp(jit(f), (3.,), (1.,))) + + def test_kwargs(self): + # from https://github.com/jax-ml/jax/issues/1938 + @jax.custom_vjp + def my_fun(x, y, c=1.): + return c * (x + y) + my_fun.defvjp(lambda x, y, c=1.: (my_fun(c, y, c), None), + lambda _, g: (g, g, g)) + f = lambda x, y: jnp.square(my_fun(x, y, c=2.)).sum() + f(10., 5.) # doesn't crash + api.grad(f)(10., 5.) # doesn't crash + + def test_initial_style(self): + @jax.custom_vjp + def f(x): + return jnp.sin(x) + def f_fwd(x): + return f(x), jnp.cos(x) + def f_rev(cos_x, g): + return (2 * cos_x * g,) + f.defvjp(f_fwd, f_rev) + + def foo(x): + out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1) + return out + + ans = api.grad(foo)(3.) + expected = 2. * jnp.cos(3.) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(api.grad(foo))(3.) + expected = -2. * jnp.sin(3.) + self.assertAllClose(ans, expected) + + def test_initial_style_vmap(self): + @jax.custom_vjp + def f(x): + assert jnp.ndim(x) == 0 + return 3 * x + def f_fwd(x): + return f(x), jnp.cos(x) + def f_rev(cos_x, g): + return (2 * cos_x * g,) + f.defvjp(f_fwd, f_rev) + + def foo(x): + out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1) + return out + + ans = api.vmap(foo)(jnp.arange(3.)) + expected = 3. * jnp.arange(3.) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(lambda x: api.vmap(foo)(x).sum())(jnp.arange(3.)) + expected = 2. * jnp.cos(jnp.arange(3.)) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_nondiff_argnums(self): + @partial(jax.custom_vjp, nondiff_argnums=(0,)) + def app(f, x): + return f(x) + def app_fwd(f, x): + return app(f, x), jnp.cos(x) + def app_rev(f, cos_x, g): + return (cos_x * g,) + app.defvjp(app_fwd, app_rev) + + ans = app(lambda x: 2 * x, 1) + expected = 2 + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.value_and_grad(lambda x: app(lambda y: 2 * y, x))(1.) + expected = (2., jnp.cos(1.)) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_nondiff_argnames(self): + @partial(jax.custom_vjp, nondiff_argnames=('f',)) + def app(f, x): + return f(x) + def app_fwd(f, x): + return app(f, x), jnp.cos(x) + def app_rev(f, cos_x, g): + return (cos_x * g,) + app.defvjp(app_fwd, app_rev) + + ans = app(lambda x: 2 * x, 1) + expected = 2 + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.value_and_grad(lambda x: app(lambda y: 2 * y, x))(1.) + expected = (2., jnp.cos(1.)) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_nondiff_argnums_argnames(self): + @partial(jax.custom_vjp, nondiff_argnums=(0,), nondiff_argnames=('g',)) + def app(f, g, x): + return f(x) + g(x) + def app_fwd(f, g, x): + return app(f, g, x), jnp.cos(x) + def app_rev(f, g, cos_x, v): + return (cos_x * v,) + app.defvjp(app_fwd, app_rev) + + f = lambda x: 2 * x + g = lambda x: 2 * x + ans = app(f, g, 1) + expected = 4 + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.value_and_grad(lambda x: app(f, g, x))(1.) + expected = (4., jnp.cos(1.)) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_closed_over_jit_tracer(self): + # See the comment in CustomJVPTest.test_nondiff_arg_jit_tracer. + raise unittest.SkipTest("behavior no longer supported") + + # This test is similar to test_nondiff_arg_tracer except it uses lexical + # closure rather than the nondiff_argnums mechanism. We decided to disallow + # tracers in nondiff_argnums to greatly simplify bookkeeping while still + # supporting the cases for which it is necessary. + def outer(x): + @jax.custom_vjp + def f(y): + return x * y + def f_fwd(y): + return f(y), jnp.cos(y) + def f_rev(cos_y, g): + return (cos_y * g,) + f.defvjp(f_fwd, f_rev) + return f + + @jit + def g(x, y): + return outer(x)(y) + + ans = g(2, 3.) + expected = 6. + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(g, 1)(2., 3.) + expected = jnp.cos(3.) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_closed_over_vmap_tracer(self): + def outer(x): + @jax.custom_vjp + def f(y): + return x * y + def f_fwd(y): + return f(y), jnp.cos(y) + def f_rev(cos_y, g): + return (cos_y * g,) + f.defvjp(f_fwd, f_rev) + return f + + @api.vmap + def g(x): + return outer(x)(3.) + + ans = g(np.arange(3.)) + expected = np.arange(3.) * 3 + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_closed_over_tracer3(self): + def outer(x): + @jax.custom_vjp + def f(y): + return x * y + def f_fwd(y): + return f(y), (x, jnp.cos(y)) + def f_rev(res, g): + x, cos_y = res + return (cos_y * g * x,) + f.defvjp(f_fwd, f_rev) + return api.grad(f) + + @api.vmap + def g(x): + return outer(x)(3.) + + ans = g(np.arange(3.)) + expected = np.cos(3.) * np.arange(3.) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_nondiff_arg_tracer_error(self): + # This is similar to the old (now skipped) test_nondiff_arg_tracer, except + # we're testing for the error message that usage pattern now raises. + + @partial(jax.custom_vjp, nondiff_argnums=(0,)) + def f(x, y): + return x * y + def f_fwd(x, y): + return f(x, y), jnp.cos(y) + def f_rev(x, cos_y, g): + return (cos_y * g,) + f.defvjp(f_fwd, f_rev) + + @jit + def g(x, y): + return f(x, y) + + with self.assertRaisesRegex(UnexpectedTracerError, "custom_vjp"): + _ = g(2, 3.) + with self.assertRaisesRegex(UnexpectedTracerError, "custom_vjp"): + _ = api.grad(g, 1)(2., 3.) + + def test_vmap_axes(self): + raise unittest.SkipTest("TODO") # TODO(mattjj): write test + + def test_pmap(self): + raise unittest.SkipTest("TODO") # TODO(mattjj): write test + + def test_missing_vjp_rule_error(self): + @jax.custom_vjp + def foo(x): + return x ** 2 + + self.assertRaisesRegex( + AttributeError, + r"No VJP defined for custom_vjp function foo using defvjp.", + lambda: foo(2)) + self.assertRaisesRegex( + AttributeError, + r"No VJP defined for custom_vjp function foo using defvjp.", + lambda: api.grad(foo)(2.)) + + def test_vjp_rule_inconsistent_pytree_structures_error(self): + @jax.custom_vjp + def f(x): + return x + + def foo_fwd(x): + return x, None + + def foo_bwd(_, g): + return (g, g) + + f.defvjp(foo_fwd, foo_bwd) + + f(2) # doesn't crash + with self.assertRaisesRegex(Exception, "Custom VJP bwd rule .*must produce"): + api.grad(f)(2.) + + def test_vjp_bwd_returns_non_tuple_error(self): + @jax.custom_vjp + def f(x): + return x + + def foo_fwd(x): + return x, None + + def foo_bwd(_, g): + return 2. * g # Should be a tuple + + f.defvjp(foo_fwd, foo_bwd) + with self.assertRaisesRegex(TypeError, "Custom VJP bwd rule .* must produce a tuple"): + api.grad(f)(3.) + + def test_fwd_rule_primal_out_type_doesnt_match_primal_error_message(self): + # https://github.com/lucidrains/flash-attention-jax/issues/7 + + def scan_apply(f, x): + y, _ = jax.lax.scan(lambda x, _: (f(x), None), x, None, length=1) + return y + + @jax.custom_vjp + def f(x): + return x + + def f_fwd(x): + return (x, x), None + + def f_bwd(_, y_bar): + return (y_bar,) + + f.defvjp(f_fwd, f_bwd) + + self.assertRaisesRegex( + TypeError, + "Custom VJP fwd rule f_fwd for function f must produce a pair ", + lambda: jax.grad(lambda x: scan_apply(f, x))(jnp.float32(1.))) + + def f_fwd2(x): + return jnp.zeros((3, *x.shape), x.dtype), None + + def f_bwd2(_, y_bar): + return (y_bar,) + + f.defvjp(f_fwd2, f_bwd2) + + self.assertRaisesRegex( + TypeError, + re.escape( + "Custom VJP fwd rule f_fwd2 for function f must produce a pair " + "(list or tuple of length two) where the first element represents " + "the primal output (equal to the output of the " + "custom_vjp-decorated function f) and the second element " + "represents residuals (i.e. values stored from the forward " + "pass for use on the backward pass), but instead the fwd rule " + "output's first element had shapes/dtypes of:\n" + " float32[3]\n" + "while the custom_vjp-decorated function f had output " + "shapes/dtypes of:\n" + " float32[]" + ), + lambda: jax.grad(lambda x: scan_apply(f, x))(jnp.float32(1.))) + + def test_issue2511(self): + arr = jnp.ones((5, 2, 2)) + foo = lambda x: api.vmap(jnp.linalg.det, (0,))(x) + api.jit(foo)(arr) # doesn't crash + + def test_lowering_out_of_traces(self): + # https://github.com/jax-ml/jax/issues/2578 + + class F(collections.namedtuple("F", ["a"])): + def __call__(self, x): + return jax.nn.relu(self.a) * x + + @jax.jit + def g(f, x): + return f(x) + + jax.grad(g, argnums=(1,))(F(2.0), 0.) # doesn't crash + + def test_clip_gradient(self): + # https://github.com/jax-ml/jax/issues/2784 + @jax.custom_vjp + def _clip_gradient(lo, hi, x): + return x # identity function when not differentiating + + def clip_gradient_fwd(lo, hi, x): + return x, (lo, hi,) + + def clip_gradient_bwd(res, g): + lo, hi = res + return (None, None, jnp.clip(g, lo, hi),) + + _clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd) + + def clip_gradient(x): + lo = -0.1 + hi = x + 0.1 + return _clip_gradient(lo, hi, x) + + g = jax.grad(clip_gradient)(0.1) # doesn't crash + self.assertAllClose(g, jnp.array(0.2)) + + def test_nestable_vjp(self): + # Verify that https://github.com/jax-ml/jax/issues/3667 is resolved. + def f(x): + return x ** 2 + + @jax.custom_vjp + def g(x): + return f(x) + + def g_fwd(x): + y, f_vjp = api.vjp(f, x) + return y, f_vjp + + def g_bwd(f_vjp, y_bar): + return f_vjp(y_bar) + + g.defvjp(g_fwd, g_bwd) + + # Check that VJP can be nested in simple situations. For this to pass, + # vjp has to return a PyTree. + _, g_vjp = api.vjp(g, 1.0) + y, = g_vjp(1.0) + self.assertAllClose(y, jnp.array(2.0)) + + # Check that VJP can be nested in complex situations. For this to pass, + # vjp can't treat the closed-over tracer x as a static argument. + @jit + def z(x): + _, g_vjp = api.vjp(g, x) + return g_vjp + y, = z(1.0)(3.0) + self.assertAllClose(y, jnp.array(6.0)) + + def test_initial_style_vmap_2(self): + # https://github.com/jax-ml/jax/issues/4173 + x = jnp.ones((10, 3)) + + # Create the custom function + @jax.custom_vjp + def custom_fun(x): + return x.sum() + + def forward(x): + return x.sum(), (jnp.ones_like(x),) + + def backward(res, g): + return g * res[0], + + custom_fun.defvjp(forward, backward) + + def train_fun(x): + + def summed_fun(x): + return api.vmap(custom_fun)(x).sum() + + return api.grad(summed_fun)(x) + + def scan_body(carry, inputs): + x = carry + return carry, train_fun(x) + + scan_range = jnp.arange(4) + lax.scan(scan_body, x, scan_range) # don't crash + + def test_initial_style_vmap_3(self): + # This is like test_initial_style_vmap except the primal function closes + # over an array constant. + y = jnp.arange(1., 4.) + + @jax.custom_vjp + def f(x): + assert jnp.ndim(x) == 0 + return 3 * x * jnp.sum(y) + def f_fwd(x): + return f(x), jnp.cos(x) + def f_rev(cos_x, g): + return (2 * cos_x * g,) + f.defvjp(f_fwd, f_rev) + + def foo(x): + out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1) + return out + + ans = api.vmap(foo)(jnp.arange(3.)) + expected = 3. * jnp.arange(3.) * 6 + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(lambda x: api.vmap(foo)(x).sum())(jnp.arange(3.)) + expected = 2. * jnp.cos(jnp.arange(3.)) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_initial_style_vmap_with_collective(self): + + @jax.custom_vjp + def f(x): + return lax.psum(x, 'foo') + + def f_fwd(x): + return lax.psum(x, 'foo'), None + + def f_bwd(res, dx): + return dx + f.defvjp(f_fwd, f_bwd) + + def g(x): + jaxpr = api.make_jaxpr(f)(x) + return core.eval_jaxpr(jaxpr.jaxpr, [], x)[0] + + out = api.vmap(lambda _, x: g(x), axis_name='foo', in_axes=(0, None), + out_axes=None)(jnp.arange(4.), 2.) + self.assertAllClose(out, 8.) + + def test_bwd_closes_over_tracer(self): + def f(y): + @jax.custom_vjp + def f(x): + return 2. * jnp.sin(x) + + def fwd(x): + return f(x), () + + def bwd(_, g): + return (2. * jnp.cos(y) * g,) # capture! + + f.defvjp(fwd, bwd) + + return jax.grad(f)(1.) + + ans = jax.jit(f)(2.) + self.assertAllClose(ans, 2. * jnp.cos(2.)) + + ans = jax.vmap(f)(jnp.arange(3.)) + self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.))) + + ans = jax.jit(jax.vmap(f))(jnp.arange(3.)) + self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.))) + + ans = jax.vmap(jax.jit(f))(jnp.arange(3.)) + self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.))) + + ans = jax.grad(f)(4.) + self.assertAllClose(ans, -2. * jnp.sin(4.)) + + def test_fwd_closes_over_tracer(self): + def f(y): + @jax.custom_vjp + def f(x): + return 2. * jnp.sin(x) + + def fwd(x): + return f(x), y + + def bwd(y, g): + return (2. * jnp.cos(y) * g,) # capture! + + f.defvjp(fwd, bwd) + + return jax.grad(f)(1.) + + ans = jax.jit(f)(2.) + self.assertAllClose(ans, 2. * jnp.cos(2.)) + + ans = jax.vmap(f)(jnp.arange(3.)) + self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.))) + + ans = jax.jit(jax.vmap(f))(jnp.arange(3.)) + self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.))) + + ans = jax.vmap(jax.jit(f))(jnp.arange(3.)) + self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.))) + + ans = jax.grad(f)(4.) + self.assertAllClose(ans, -2. * jnp.sin(4.)) + + def test_float0(self): + @jax.custom_vjp + def f(x, _): + return x + def f_fwd(x, _): + # we need a defined (non-float0) tangent to trigger the rule + return x, (2., 1) + def f_rev(*_): + return (2., 1) + f.defvjp(f_fwd, f_rev) + + x = 2. + y = 3 + self.assertEqual(api.grad(f, allow_int=True, argnums=(0, 1))(x, y), + (2., np.zeros(shape=(), dtype=float0))) + + def test_float0_initial_style(self): + @jax.custom_vjp + def f(x): + return x + def f_fwd(x): + return x, (2., x) + def f_rev(*_): + return ((2., jnp.zeros(shape=(), dtype=float0)),) + f.defvjp(f_fwd, f_rev) + + def foo(x, y): + out, _ = lax.scan(lambda c, _: (f(c), None), (x, y), None, length=1) + return out[0] + + x = 2. + y = 3 + self.assertEqual(api.grad(foo, allow_int=True, argnums=(0, 1))(x, y), + (2., np.zeros(shape=(), dtype=float0))) + + def test_remat(self): + @jax.custom_vjp + def f(x): + return jnp.sin(x) + def f_fwd(x): + return f(x), jnp.cos(x) + def f_rev(cos_x, g): + return (2 * cos_x * g,) + f.defvjp(f_fwd, f_rev) + + @jax.remat + def g(x): + return f(f(x)) + + ans = g(2.) + expected = np.sin(np.sin(2.)) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(g)(2.) + expected = 4. * api.grad(lambda x: jnp.sin(jnp.sin(x)))(2.) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_remat_higher_order(self): + @jax.custom_vjp + def f(x): + return jnp.sin(x) + def f_fwd(x): + return f(x), jnp.cos(x) + def f_rev(cos_x, g): + return (2 * cos_x * g,) + f.defvjp(f_fwd, f_rev) + + def g(x): + return f(f(x)) + + ans = api.grad(api.grad(jax.remat(g)))(2.) + expected = api.grad(api.grad(g))(2.) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(jax.remat(api.grad(g)))(2.) + expected = api.grad(api.grad(g))(2.) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(api.grad(api.grad(jax.remat(g))))(2.) + expected = api.grad(api.grad(api.grad(g)))(2.) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_bwd_nones(self): + @jax.custom_vjp + def f(x, y): + return x * jnp.sin(y) + def f_fwd(x, y): + return f(x, y), jnp.cos(y) + def f_rev(cos, g): + return (None, 2 * cos * g) + f.defvjp(f_fwd, f_rev) + + ans = api.grad(lambda x: f(x, x))(3.) + expected = 2 * jnp.cos(3.) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_bwd_nones_vmap(self): + @jax.custom_vjp + def f(x, y): + return x * jnp.sin(y) + def f_fwd(x, y): + return f(x, y), jnp.cos(y) + def f_rev(cos, g): + return (None, 2 * cos * g) + f.defvjp(f_fwd, f_rev) + + ans = api.grad(lambda x: api.vmap(f)(x, x).sum())(jnp.arange(3.)) + expected = 2 * jnp.cos(jnp.arange(3.)) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_bwd_nones_pytree(self): + @jax.custom_vjp + def f(xs, y): + x1, x2 = xs + return x1 * x2 * jnp.sin(y) + def f_fwd(xs, y): + return f(xs, y), jnp.cos(y) + def f_rev(cos, g): + return (None, 2 * cos * g) + f.defvjp(f_fwd, f_rev) + + ans = api.grad(lambda x: f((x, x), x))(3.) + expected = 2 * jnp.cos(3.) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_custom_vjp_closure_4521(self): + # https://github.com/jax-ml/jax/issues/4521 + @jax.custom_vjp + def g(x, y): + return None + def g_fwd(x, y): + return None, y + def g_bwd(residuals, z_bar): + assert False + + g.defvjp(g_fwd, g_bwd) + + def f(xs, y): + v_g = api.vmap(g, in_axes=(0, None), out_axes=None) + v_g(xs, y) + + def scan_body(xs, _): + y = jnp.zeros(1) + _, vjp_f = api.vjp(f, xs, y) + vjp_f(None) + return xs, None + + lax.scan(scan_body, jnp.ones(5), None, 100) # doesn't crash + + def test_float0_bwd_none(self): + @jax.custom_vjp + def f(i, x): + return jnp.sin(x) + def f_fwd(i, x): + return f(i, x), jnp.cos(x) + def f_rev(cos_x, g): + return (None, 2 * cos_x * g) + f.defvjp(f_fwd, f_rev) + + ans = api.grad(f, 1)(jnp.array([1, 2]), 3.) # doesn't crash + expected = 2 * jnp.cos(3.) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_custom_gradient(self): + @jax.custom_gradient + def f(x): + return x ** 2, lambda g: (g * x,) + + self.assertAllClose(f(3.), 9., check_dtypes=False) + self.assertAllClose(api.grad(f)(3.), 3., check_dtypes=False) + self.assertAllClose(api.grad(api.grad(f))(3.), 1., check_dtypes=False) + + def test_custom_gradient_2(self): + @jax.custom_gradient + def f(x, y): + return x * y, lambda g: (y, x) + + self.assertAllClose(f(3., 4.), 12., check_dtypes=False) + self.assertAllClose(api.grad(f, argnums=(0, 1))(3., 4.), (4., 3.), + check_dtypes=False) + + def test_custom_gradient_3(self): + @jax.custom_gradient + def f(x): + vjp = lambda g: (jnp.cos(x) * jnp.arange(3., 6.),) + return jnp.sum(jnp.sin(x)), vjp + + self.assertAllClose(f(jnp.arange(3)), jnp.sum(jnp.sin(jnp.arange(3.))), + check_dtypes=False) + self.assertAllClose( + api.grad(f)(jnp.arange(3.)), + api.grad(lambda x: jnp.sum(jnp.sin(x)))(jnp.arange(3.)) * jnp.arange(3., 6.), + check_dtypes=False) + + def test_custom_gradient_can_return_singleton_value_in_vjp(self): + @jax.custom_gradient + def f(x): + return x ** 2, lambda g: g * x + + self.assertAllClose(f(3.), 9., check_dtypes=False) + self.assertAllClose(api.grad(f)(3.), 3., check_dtypes=False) + self.assertAllClose(api.grad(api.grad(f))(3.), 1., check_dtypes=False) + + def test_closure_convert(self): + def cos_after(fn, x): + converted_fn, aux_args = jax.closure_convert(fn, x) + self.assertLessEqual(len(aux_args), 1) + return _cos_after(converted_fn, x, *aux_args) + + @partial(jax.custom_vjp, nondiff_argnums=(0,)) + def _cos_after(fn, x, *args): + return jnp.cos(fn(x, *args)) + + def fwd(fn, x, *args): + y = _cos_after(fn, x, *args) + return y, (x, args) + + def rev(fn, res, g): + x, args = res + x_bar = 17. * x + args_bars = [42. * a for a in args] + return (x_bar, *args_bars) + + _cos_after.defvjp(fwd, rev) + + def dist(c, x): + return jnp.sum((x - c) ** 2.) + + def solve(c, x): + def closure(x): + return dist(c, x) + return cos_after(closure, x) + + c, x = 2. * jnp.ones(2), jnp.ones(2) + expected = jnp.cos(dist(c, x)) + self.assertAllClose(solve(c, x), expected, check_dtypes=False) + g_c, g_x = api.grad(solve, argnums=(0, 1))(c, x) + self.assertAllClose(g_c, 42. * c, check_dtypes=False) + self.assertAllClose(g_x, 17. * x, check_dtypes=False) + + def test_closure_convert_mixed_consts(self): + # Like test_closure_convert, but close over values that + # participate in AD as well as values that do not. + # See https://github.com/jax-ml/jax/issues/6415 + + def cos_after(fn, x): + converted_fn, aux_args = jax.closure_convert(fn, x) + self.assertLessEqual(len(aux_args), 1) + return _cos_after(converted_fn, x, *aux_args) + + @partial(jax.custom_vjp, nondiff_argnums=(0,)) + def _cos_after(fn, x, *args): + return jnp.cos(fn(x, *args)) + + def fwd(fn, x, *args): + y = _cos_after(fn, x, *args) + return y, (x, args) + + def rev(fn, res, g): + x, args = res + x_bar = 17. * x + args_bars = [42. * a for a in args] + return (x_bar, *args_bars) + + _cos_after.defvjp(fwd, rev) + + def dist(c, s, x): + return jnp.sum(s * (x - c) ** 2.) + + def solve(c, s, x): + def closure(x): + return dist(c, s, x) + return cos_after(closure, x) + + c, s, x = 2. * jnp.ones(2), 3. * jnp.ones(2), jnp.ones(2) + expected = jnp.cos(dist(c, s, x)) + self.assertAllClose(solve(c, s, x), expected, check_dtypes=False) + g_c, g_x = api.grad(solve, argnums=(0, 2))(c, s, x) + self.assertAllClose(g_c, 42. * c, check_dtypes=False) + self.assertAllClose(g_x, 17. * x, check_dtypes=False) + + def test_closure_convert_pytree_mismatch(self): + # See https://github.com/jax-ml/jax/issues/23588 + def f(x, z): + return z * x + + x, z = 2.0, 3.0 + _, vjp = api.vjp(f, x, z) + vjp_pure, vjp_aux_args = jax.closure_convert(vjp, x) + vjp_pure(x, *vjp_aux_args) + with self.assertRaisesRegex( + TypeError, "The inputs to the closure produced by closure_convert"): + vjp_pure(x, vjp_aux_args) + + def test_float0_cotangents_automatically_handled(self): + @jax.custom_vjp + def f(x, y): + return x + + def f_fwd(x, y): + return x, None + + def f_bwd(_, zbar): + return (0., 1) + + f.defvjp(f_fwd, f_bwd) + + jax.jit(lambda x: jax.vjp(f, 0., x)[1](1.))(1) # doesn't crash + + def test_custom_vjp_scan_batching_edge_case(self): + # https://github.com/jax-ml/jax/issues/5832 + @jax.custom_vjp + def mul(x, coeff): return x * coeff + def mul_fwd(x, coeff): return mul(x, coeff), (x, coeff) + def mul_bwd(res, g): + x, coeff = res + g_x = g * coeff + g_coeff = (x * g).sum() + return g_x, g_coeff + mul.defvjp(mul_fwd, mul_bwd) + + def scan_over_mul(x, coeff): + def f_(x, t): + return mul(x, coeff), None + y, _ = jax.lax.scan(f_, x, jnp.arange(3)) + return y + + key = jax.random.key(0) + key1, key2 = jax.random.split(key, 2) + x_batch = jax.random.normal(key1, (3, 2)) + covector_batch = jax.random.normal(key2, (3, 2)) + coeff = jnp.array(1., dtype=x_batch.dtype) + + batched_scan_over_mul = jax.vmap(scan_over_mul, in_axes=(0, None), out_axes=0) + res, vjp_fun = jax.vjp(batched_scan_over_mul, x_batch, coeff) + vjp_fun(covector_batch) # doesn't crash + + jtu.check_grads(batched_scan_over_mul, (x_batch, coeff), order=2, + modes=['rev']) + + def test_closure_with_vmap2(self): + # https://github.com/jax-ml/jax/issues/8783 + def h(z): + def f(x): + @jax.custom_vjp + def g(y): + return x * y + + def g_fwd(y): + return x * y, (x, x * y, y) + def g_rev(res, w_bar): + x, *_ = res + return (x * w_bar,) + g.defvjp(g_fwd, g_rev) + + return g(z) + + return jax.vmap(f)(jnp.arange(3., dtype='float32')).sum() + + jtu.check_grads(h, (jnp.float32(3.14),), order=1, modes=['rev']) + + def test_pytrees_not_required_to_contain_nones(self): + class A(list): + pass + + def unflatten(_, children): + assert children[0] is not None + return A(children) + + tree_util.register_pytree_node(A, lambda x: (x, None), unflatten) + + @jax.custom_vjp + def f(x): + return x[0] + def f_fwd(x): + return x[0], None + def f_bwd(_, g): + return A([g]), + f.defvjp(f_fwd, f_bwd) + + jax.grad(f)(A([1.])) # doesn't crash + + def test_vmap_vjp_called_twice(self): + # https://github.com/jax-ml/jax/pull/14728 + @jax.custom_vjp + def f(x): + return x + f.defvjp(lambda x: (x, None), lambda _, y_bar: (y_bar,)) + + _, f_vjp = jax.vjp(jax.vmap(f), jnp.array([3.])) + f_vjp(jnp.array([3.])) + f_vjp(jnp.array([3.])) # doesn't crash + + def test_symbolic_zero_custom_vjp_basic(self): + ZERO = jax.custom_derivatives.SymbolicZero + + @jax.custom_vjp + def f(x, y, z): + return x, x + + def fwd(x, y, z): + self.assertIsInstance(x, jax.custom_derivatives.CustomVJPPrimal) + self.assertIsInstance(y, jax.custom_derivatives.CustomVJPPrimal) + self.assertIsInstance(z, jax.custom_derivatives.CustomVJPPrimal) + self.assertTrue(x.perturbed) + self.assertFalse(y.perturbed) + self.assertFalse(z.perturbed) + return (x.value, x.value), None + + def fwd_all(x, y, z): + self.assertIsInstance(x, jax.custom_derivatives.CustomVJPPrimal) + self.assertIsInstance(y, jax.custom_derivatives.CustomVJPPrimal) + self.assertIsInstance(z, jax.custom_derivatives.CustomVJPPrimal) + self.assertTrue(x.perturbed) + self.assertTrue(y.perturbed) + self.assertTrue(z.perturbed) + return (x.value, x.value), None + + def bwd_all(_, g): + x1, x2 = g + self.assertFalse(type(x1) is ZERO) + self.assertFalse(type(x2) is ZERO) + return x1, x1, x2 + + def bwd_fst(_, g): + x1, x2 = g + self.assertFalse(type(x1) is ZERO) + self.assertIs(type(x2), ZERO) + return x1, x1, x2 + + def bwd_snd(_, g): + x1, x2 = g + self.assertIs(type(x1), ZERO) + self.assertFalse(type(x2) is ZERO) + return x1, x1, x2 + + x, y, z = 4., 5., 6. + i = np.array(7, np.int32) + zero = np.array(0.) + + f.defvjp(fwd, bwd_all, symbolic_zeros=True) + h = jax.jit(f) + jax.jacrev(h)(x, y, z) + jax.jacrev(lambda x: h(x, y, z))(x) + jax.jacrev(h, argnums=(0, 1, 2), allow_int=True)(x, i, i) + + f.defvjp(fwd_all, bwd_fst, symbolic_zeros=True) + fst_f = lambda *xs: f(*xs)[0] + _, vjp = jax.vjp(fst_f, x, y, z) + _, _, gz = vjp(x) + self.assertArraysAllClose(gz, zero) + + f.defvjp(fwd_all, bwd_snd, symbolic_zeros=True) + snd_f = lambda *xs: f(*xs)[1] + _, vjp = jax.vjp(snd_f, x, y, z) + gx, gy, _ = vjp(x) + self.assertArraysAllClose(gx, zero) + self.assertArraysAllClose(gy, zero) + + f.defvjp(fwd, bwd_snd, symbolic_zeros=True) + _, vjp = jax.vjp(lambda x: snd_f(x, y, z), x) + gx, = vjp(x) + self.assertArraysAllClose(gx, zero) + + def test_symbolic_zero_custom_vjp_bwd_shape_error(self): + @jax.custom_vjp + def f(x, y, z): + return x, y, z + + def fwd(x, y, z): + return f(x.value, y.value, z.value), None + + def bwd(_, gs): + x_bar, y_bar, z_bar = gs + return y_bar, x_bar, z_bar # swapped! + + f.defvjp(fwd, bwd, symbolic_zeros=True) + + with self.assertRaisesRegex( + ValueError, + r'Consider just returning a None here'): + jax.grad(lambda x, y, z: f(x, y, z)[2].sum())( + jnp.ones(1), jnp.ones(2), jnp.ones(3)) + + @parameterized.named_parameters( + ('jit_vmap', True, True), + ('jit', True, False), + ('vmap', False, True), + ('', False, False), + ) + def test_symbolic_zero_custom_vjp(self, maybe_jit, maybe_vmap): + # below: + # * static_scalar will be static in and out + # * static_array will be static in, but dynamic out + # * dyn_scalar and dyn_array will be dynamic in and out + + ZERO = jax.custom_derivatives.SymbolicZero + + def f(static_scalar, static_array, dyn_scalar, dyn_array): + out1 = static_scalar + dyn_scalar + out2 = static_array + dyn_array + return static_scalar, static_array, out1, out2 + + def _pack(x): + return lax.broadcast(x, (1,)) + + def _unpack(x): + (x,) = x + return x + + def _vmap(fun): + def _fun(*args): + args = jax.tree.map(_pack, args) + out = jax.vmap(fun)(*args) + out = jax.tree.map(_unpack, out) + return out + return _fun + + f = jax.custom_vjp(f) + + def fwd(*args): + xs, pert = [x.value for x in args], [x.perturbed for x in args] + self.assertFalse(pert[0]) + self.assertFalse(pert[1]) + self.assertTrue(pert[2]) + self.assertTrue(pert[3]) + return f(*xs), xs + + def bwd(res, g): + static_scalar, *_ = res + t_static, t_static_arr, t_dyn_scalar, t_dyn_array = g + self.assertIs(type(t_static), ZERO) + self.assertFalse(type(t_static_arr) is ZERO) + self.assertFalse(type(t_dyn_scalar) is ZERO) + self.assertFalse(type(t_dyn_array) is ZERO) + self.assertEqual(t_static.shape, ()) + self.assertEqual(t_static_arr.shape, (2,)) + return (static_scalar + 90, + t_static_arr + 91, + t_dyn_scalar + 92, + t_dyn_array + 93) + + f.defvjp(fwd, bwd, symbolic_zeros=True) + + def g(dyn_scalar, dyn_array): + if maybe_vmap: + f_ = _vmap(f) + else: + f_ = f + outs = f_(1., jnp.array([2., 3.]), dyn_scalar, dyn_array) + return outs[1:] + + def run(primal_ins, cotangent_outs): + primal_outs, vjp = jax.vjp(g, *primal_ins) + cotangent_ins = vjp(cotangent_outs) + return primal_outs, cotangent_ins + + if maybe_jit: + run = jax.jit(run) + + scalar_type = jax.Array if maybe_jit or maybe_vmap else float + primal_ins = (4., jnp.array([5., 6.])) + cotangent_outs = (jnp.array([10., 11.]), 7., jnp.array([8., 9.])) + primal_outs, cotangent_ins = run(primal_ins, cotangent_outs) + + primal_out1, primal_out2, primal_out3 = primal_outs + self.assertIsInstance(primal_out1, jax.Array) + self.assertAllClose(primal_out1, jnp.array([2., 3.])) + if self.__class__ is CustomVJPTest: + # TODO(mattjj): we don't yet support this behavior for CustomVJPTraced + self.assertIsInstance(primal_out2, scalar_type) + self.assertAllClose(primal_out2, 5.) + self.assertIsInstance(primal_out3, jax.Array) + self.assertAllClose(primal_out3, jnp.array([7., 9.])) + + ct_in1, ct_in2 = cotangent_ins + self.assertIsInstance(ct_in1, scalar_type) + self.assertAllClose(ct_in1, 99.) + self.assertIsInstance(ct_in2, jax.Array) + self.assertArraysAllClose(ct_in2, jnp.array([101., 102.])) + + def test_symbolic_zero_custom_vjp_vmap_output(self): + @jax.custom_vjp + def f(x, y): + return x, y + + def fwd(x, y): + self.assertTrue(x.perturbed) + self.assertFalse(y.perturbed) + return f(x.value, y.value), None + + def bwd(_, g): + _, ct_y = g + self.assertIs(type(ct_y), jax.custom_derivatives.SymbolicZero) + return g + + f.defvjp(fwd, bwd, symbolic_zeros=True) + jax.grad(lambda x, y: jax.vmap(f)(x, y)[0].sum())(jnp.ones(3), jnp.ones(3)) + + def test_symbolic_zero_custom_vjp_custom_pytree(self): + tree_values = jax.custom_derivatives.custom_vjp_primal_tree_values + + @tree_util.register_pytree_node_class + class Box: + def __init__(self_, strict, val): + if strict: + # make sure we aren't getting special arguments that should only + # come up when symbolic_zeros is True + self.assertFalse(hasattr(val, 'perturbed')) + self_.strict = strict + self_.x = val + + def tree_flatten(self_): + return [self_.x], self_.strict + + @classmethod + def tree_unflatten(cls, strict, xs): + x, = xs + return cls(strict, x) + + x, y = Box(False, jnp.array(72.)), jnp.array(73.) + + @jax.custom_vjp + def f(box, y): + return box.x * y + + def fwd0(box, y): + self.assertTrue(box.x.perturbed) + self.assertFalse(y.perturbed) + box, y = map(tree_values, [box, y]) + return f(box, y), (box, y) + + def bwd0(res, g): + box, y = res + return y * g, box.x * g + + def fwd1(box, y): + self.assertFalse(box.x.perturbed) + self.assertTrue(y.perturbed) + box, y = map(tree_values, [box, y]) + return f(box, y), (box, y) + + def bwd1(res, g): + box, y = res + return y * g, box.x * g + + f.defvjp(fwd0, bwd0, symbolic_zeros=True) + jax.grad(f, argnums=0)(x, y) + f.defvjp(fwd1, bwd1, symbolic_zeros=True) + jax.grad(f, argnums=1)(x, y) + + def fwd_strict(box, y): + return f(box, y), (box, y) + + def bwd_strict(res, g): + box, y = res + return y * g, box.x * g + + f.defvjp(fwd_strict, bwd_strict) + jax.grad(f)(x, y) + + def test_symbolic_zeros_memoization_caching(self): + # Tests multiple zero patterns for partial_eval._memoize, and also tests + # that we're okay with stores being occupied with equal values. + @jax.custom_vjp + def f(x, y): + return x * y + + def f_fwd(x, y): + return x.value, None + + def f_bwd(_, z_bar): + return z_bar, None + + f.defvjp(f_fwd, f_bwd, symbolic_zeros=True) + + f_ = core.jaxpr_as_fun(jax.make_jaxpr(f)(2., 3.)) + _ = jax.linearize(f_, 2., 3.) + _ = jax.linearize(lambda x: f_(x, 3.), 2.) # don't crash! + + def test_run_rules_more_than_once(self): + # https://github.com/jax-ml/jax/issues/16614 + + @jax.custom_vjp + def f(x, y): + return x + y + + def f_fwd(x, y): + if y.perturbed: + res = None + else: + res = [] + return x.value + y.value, res + + def f_bwd(res, ct): + return ct, ct + + f.defvjp(f_fwd, f_bwd, symbolic_zeros=True) + + def body(x_y, _): + x, y = x_y + return (f(x, y), x), None + + @jax.grad + def g(x): + (out, _), _ = lax.scan(body, (x, 1.), xs=None, length=2) + return out + + g(1.) # doesn't crash + + def test_symbolic_zeros_remat(self): + @jax.custom_vjp + def f(x): + return x + def f_fwd(x): + return f(x.value), None + def f_bwd(_, g): + return g, + f.defvjp(f_fwd, f_bwd, symbolic_zeros=True) + + @jax.remat + def foo(x): + return f(f(x)) + + jax.grad(foo)(3.) + + def test_nones_representing_zeros_in_subtrees_returned_by_bwd(self): + # https://github.com/jax-ml/jax/issues/8356 + @jax.custom_vjp + def f(x): + return x[0] + + def f_fwd(x): + return f(x), None + + def f_bwd(_, z_bar): + return (z_bar, (None, None)), + + f.defvjp(f_fwd, f_bwd) + + jax.grad(f)((1.0, (2.0, 3.0))) # don't crash + + def test_pytree_nones_returned_by_bwd(self): + @jax.custom_vjp + def f(x): + return x[0] + + def f_fwd(x): + return f(x), None + + def f_bwd(_, z_bar): + return (z_bar, (None, None)), + + f.defvjp(f_fwd, f_bwd) + + jax.grad(f)((1.0, (2.0, None))) # don't crash + + def test_bwd_rule_shape_mismatch(self): + @jax.custom_vjp + def foo(x, y): + return x + + def foo_fwd(x, y): + return x, None + + def foo_bwd(_, g): + return jnp.zeros(3), jnp.zeros(3) + + foo.defvjp(foo_fwd, foo_bwd) + + with self.assertRaisesRegex( + ValueError, + r'output\[1\] the bwd rule produced an output of type f.*\[3\]'): + jax.grad(lambda x, y: foo(x, y * y).sum(), 1)(jnp.ones(3), jnp.ones(4)) + + def test_bwd_rule_shape_mismatch_disable(self): + # TODO(mattjj): remove this test when the config option is removed + @jax.custom_vjp + def foo(x, y): + return x + + def foo_fwd(x, y): + return x, None + + def foo_bwd(_, g): + return jnp.zeros(3), jnp.zeros(3) + + foo.defvjp(foo_fwd, foo_bwd) + + with config.disable_bwd_checks(True): + jax.grad(lambda x, y: foo(x, y).sum(), 1)(jnp.ones(3), jnp.ones(4)) + + def test_bwd_rule_can_produce_list_or_tuple(self): + @jax.custom_vjp + def f(x, y): + return x * y + + def f_fwd(x, y): + return f(x, y), (x, y) + + def f_bwd(xy, g): + x, y = xy + return [g * y, x * g] # list, not tuple + + f.defvjp(f_fwd, f_bwd) + + jax.grad(f)(1., 2.) # don't crash + + def test_optimize_remat(self): + def fun(x): + # This array is included to make sure that we handle consts appropriately + return np.array([1.0])*x + + def fwd(x): + return np.array([2.0])*x*x/np.array([1.0]), (2 * x,) + + x = jnp.linspace(0, 5.0, 10) + fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd( + fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}), + fwd, api_util.debug_info("custom_vjp fwd", fwd, (x,), {})) + + self.assertAllClose(jax.jit(fwd)(x)[0], 2*x*x) # Shouldn't hit custom DCE + self.assertAllClose(jax.jit(lambda x: fwd(x)[0])(x), x) # Should be DCEed + + def test_optimize_remat_vmap(self): + def fun(x): + return (np.array([1.0])*x)[0] + def fwd(x): + return (np.array([2.0])*x*x/np.array([1.0]))[0], (2 * x,) + x = jnp.linspace(0, 5.0, 10) + fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd( + fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}), + fwd, api_util.debug_info("custom_vjp fwd", fwd, (x,), {})) + self.assertAllClose(jax.jit(jax.vmap(fwd))(x)[0], 2*x*x) + self.assertAllClose(jax.jit(lambda x: jax.vmap(fwd)(x)[0])(x), x) + + def test_optimize_remat_cond(self): + def fun(x): + return x + def fwd(x): + return x*x, (2 * x,) + + x = jnp.linspace(0, 5.0, 10) + fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd( + fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}), + fwd, api_util.debug_info("custom_vjp fwd", fwd, (x,), {})) + + def g(x): + return jax.lax.cond(True, fwd, lambda x: (2.0 * x, (x,)), x) + + self.assertAllClose(jax.jit(g)(x)[0], x*x) + self.assertAllClose(jax.jit(lambda x: g(x)[0])(x), x) + + def test_optimize_remat_jvp(self): + def fun(x): + return x**2 + def fwd_(x): + return x*x, (2 * x,) + + fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd( + fun, api_util.debug_info("custom_vjp fun", fun, (3.2,), {}), + fwd_, api_util.debug_info("custom_vjp fwd", fwd_, (3.2,), {})) + calc = jax.jvp(fwd, (3.2,), (1.0,)) + expected = jax.jvp(fwd_, (3.2,), (1.0,)) + self.assertAllClose(calc, expected) + + @jax.jit + def g(x, t): + (y, r), (y_dot, r_dot) = jax.jvp(fwd, (x,), (t,)) + return y, y_dot + calc = g(3.2, 1.0) + expected = jax.jvp(fun, (3.2,), (1.0,)) + self.assertAllClose(calc, expected) + + def test_optimize_remat_gh21303(self): + @jax.custom_vjp + def f(x): + return jnp.tan(x) + + def f_fwd(x): + return jnp.sin(x), (x,) + + def f_bwd(res, g): + x, = res + cos_x = jnp.cos(x) + return (cos_x * g,) + + f.defvjp(f_fwd, f_bwd, optimize_remat=True) + + def temp(x): + out = jax.remat(f)(x) + out = out ** 2 + return out + + v, g = jax.value_and_grad(temp)(3.2) + self.assertAllClose(v, jnp.tan(3.2)**2) + + def test_optimize_remat_multiple_args(self): + def f_(x, y): + return jnp.sin(x) * y + + @jax.custom_vjp + def f(x, y): + return f_(x, y) + + def f_fwd(x, y): + return f(x, y), (jnp.cos(x), jnp.sin(x), y) + + def f_bwd(res, g): + cos_x, sin_x, y = res + return (cos_x * g * y, sin_x * g) + + f.defvjp(f_fwd, f_bwd, optimize_remat=True) + x, y = 3.2, 1.0 + self.assertAllClose(jax.grad(f)(x, y), jax.grad(f_)(x, y)) + + def test_optimize_remat_kwargs(self): + @jax.custom_vjp + def f(x, y): + return jnp.sin(x) * y + + def f_fwd(x, y, *, keyword=False): + del keyword + return f(x, y), (jnp.cos(x), jnp.sin(x), y) + + def f_bwd(res, g): + cos_x, sin_x, y = res + return (cos_x * g * y, sin_x * g) + + f.defvjp(f_fwd, f_bwd, optimize_remat=True) + x, y = 3.2, 1.0 + jax.grad(f)(x, y) # Doesn't error + + def test_optimize_remat_custom_vmap(self): + # See https://github.com/jax-ml/jax/pull/23000 + @jax.custom_vjp + def f(x, y): + return jnp.sin(x) * y + + @jax.custom_batching.custom_vmap + def f_fwd(x, y): + return f(x, y), (jnp.cos(x), jnp.sin(x), y) + + @f_fwd.def_vmap + def f_fwd_vmap(_, in_batched, x, y): + # Insert a new const here to test the optimize_remat batching rule. + out = np.array([2.0])*f(x, y) + out_batched = (True, (True, True, True)) + return (out, (jnp.cos(x), jnp.sin(x), y)), out_batched + + def f_bwd(res, g): + cos_x, sin_x, y = res + return (cos_x * g * y, sin_x * g) + + f.defvjp(f_fwd, f_bwd, optimize_remat=True) + x, y = jnp.linspace(0.0, 1.0, 5), jnp.linspace(2.0, 5.0, 5) + jax.jit(jax.vmap(jax.grad(f)))(x, y) # Doesn't error + + def test_dce(self): + @jax.custom_vjp + def f(x, y): + return jnp.sin(x), x + jnp.cos(y) + + def f_fwd(x, y): + return f(x, y), (jnp.cos(x), jnp.sin(y)) + + def f_bwd(res, cts): + cos_x, sin_y = res + ct_a, ct_b = cts + return 2.0 * cos_x * ct_a + 1.5 * ct_b, -0.5 * sin_y * ct_b + + f.defvjp(f_fwd, f_bwd) + + def check_jaxpr(jaxpr, used_outs, includes, excludes): + dce_jaxpr, _ = pe.dce_jaxpr(jaxpr, used_outs) + if not dce_jaxpr.eqns: + assert not includes + return + call_jaxpr = dce_jaxpr.eqns[0].params["call_jaxpr"] + for prim in includes: + assert any(eqn.primitive == prim for eqn in call_jaxpr.eqns) + for prim in excludes: + assert all(eqn.primitive != prim for eqn in call_jaxpr.eqns) + + x, y = 0.1, -1.3 + jaxpr = jax.make_jaxpr(f)(x, y).jaxpr + check_jaxpr(jaxpr, [True, True], [lax.sin_p, lax.cos_p], []) + check_jaxpr(jaxpr, [True, False], [lax.sin_p], [lax.cos_p]) + check_jaxpr(jaxpr, [False, True], [lax.cos_p], [lax.sin_p]) + check_jaxpr(jaxpr, [False, False], [], [lax.sin_p, lax.cos_p]) + + def dce_jaxpr_as_fun(jaxpr, used_outs): + jaxpr_, _ = pe.dce_jaxpr(jaxpr, used_outs) + fun = core.jaxpr_as_fun(pe.close_jaxpr(jaxpr_)) + return lambda *args: fun(*args)[0] + + f0 = dce_jaxpr_as_fun(jaxpr, [True, False]) + f1 = dce_jaxpr_as_fun(jaxpr, [False, True]) + self.assertAllClose( + api.grad(f0, argnums=(0, 1))(x, y), (2.0 * jnp.cos(x), 0.0)) + self.assertAllClose( + api.grad(f1, argnums=(0, 1))(x, y), (1.5, -0.5 * jnp.sin(y))) + + def test_resolve_kwargs_error_message(self): + @jax.custom_vjp + def f(x, y, *, z=None): + return jnp.sin(x), x + jnp.cos(y) + + def f_fwd(x, y): + self.fail("should not be executed") + + def f_bwd(res, cts): + self.fail("should not be executed") + + f.defvjp(f_fwd, f_bwd) + + with self.assertRaisesRegex( + TypeError, + r"missing a required argument: 'y'" + ): + f(0.5) + + with self.assertRaisesRegex( + TypeError, + "The following keyword arguments could not be resolved to positions: z" + ): + f(0.5, 0.1, z=1.0) + + def test_pretty_print(self): + @jax.custom_vjp + def f(x): + return x + 1 + + def f_fwd(x): + return f(x), () + + def f_bwd(_, g): + return g + f.defvjp(f_fwd, f_bwd) + + x = jnp.array([4.2], dtype=jnp.float32) + jaxpr = jax.make_jaxpr(f)(x) + actual = jaxpr.pretty_print(use_color=False) + expected = textwrap.dedent( + """ + { lambda ; a:f32[1]. let + b:f32[1] = custom_vjp_call[ + name=f + bwd=f_bwd + call_jaxpr={ lambda ; c:f32[1]. let d:f32[1] = add c 1.0:f32[] in (d,) } + fwd=f_fwd + symbolic_zeros=False + ] a + in (b,) } + """).strip() + self.assertEqual(actual, expected) + + def test_custom_lin_pretty_print(self): + @jax.custom_vjp + def f(x): + return x + 1 + + def f_fwd(x): + return f(x), () + + def f_bwd(_, g): + return g + f.defvjp(f_fwd, f_bwd) + + x = jnp.array([4.2], dtype=jnp.float32) + jaxpr = jax.make_jaxpr(lambda x: jax.jvp(f, (x,), (x,)))(x) + jaxpr, _ = pe.dce_jaxpr(jaxpr.jaxpr, [False, True]) + actual = jaxpr.pretty_print(use_color=False) + expected = textwrap.dedent( + """ + { lambda ; a:f32[1]. let + b:f32[1] = custom_lin[ + bwd=f_bwd + in_zeros=[False] + num_res=0 + symbolic_zeros=False + ] a + in (b,) } + """).strip() + self.assertEqual(actual, expected) + +@unittest.skip("delete this when running manually, doesn't work in CI") +class CustomVJP3Test(CustomVJPTest): + def setUp(self): + self.prev, jax.custom_vjp = jax.custom_vjp, hijax.custom_vjp3 + + def tearDown(self): + jax.custom_vjp = self.prev + + # closure + def test_closed_over_vmap_tracer(self): pass + def test_bwd_closes_over_tracer(self): pass + def test_closed_over_tracer3(self): pass + def test_closure_with_vmap2(self): pass + def test_fwd_closes_over_tracer(self): pass + + # eager (ie dont always trace, unless under a jit) + def test_python_control_flow(self): pass + + # regress these, hope no one cares + def test_pytrees_not_required_to_contain_nones(self): pass + def test_symbolic_zero_custom_vjp_bwd_shape_error(self): pass + + def test_fwd_rule_primal_out_type_doesnt_match_primal_error_message(self): + def scan_apply(f, x): + y, _ = jax.lax.scan(lambda x, _: (f(x), None), x, None, length=1) + return y + + @jax.custom_vjp + def f(x): + return x + + def f_fwd(x): + return (x, x), None + + def f_bwd(_, y_bar): + return (y_bar,) + + f.defvjp(f_fwd, f_bwd) + + self.assertRaisesRegex( + TypeError, + "Custom VJP fwd rule f_fwd for function f must produce a pair ", + lambda: jax.grad(lambda x: scan_apply(f, x))(jnp.float32(1.))) + + def f_fwd2(x): + return jnp.zeros((3, *x.shape), x.dtype), None + + def f_bwd2(_, y_bar): + return (y_bar,) + + f.defvjp(f_fwd2, f_bwd2) + + self.assertRaisesRegex( + TypeError, + r"got fwd output type float32\[3\] which doesn't match", + lambda: jax.grad(lambda x: scan_apply(f, x))(jnp.float32(1.))) + + # bad tests + def test_dce(self): pass # TODO (test jaxpr) + + # pretty-printing + def test_pretty_print(self): pass + def test_custom_lin_pretty_print(self): pass + + # maybe we don't need to support? + def test_symbolic_zeros_remat(self): pass + +def transpose_unary(f, x_example): + def transposed(y): + x, = api.linear_transpose(f, x_example)(y) + return x + return transposed + + +# This class wraps jax.custom_transpose.custom_transpose in order to pass in a +# particular tree of output type on each call. Otherwise it forwards +# all attribute access. +class _custom_transpose: + def __init__(self, out_types, fun): + self.out_types = out_types + self.fun = jax.custom_transpose.custom_transpose(fun) + + def __getattr__(self, name): + return getattr(self.fun, name) + + def __call__(self, *args): + return self.fun(self.out_types, *args) + + +# This function is meant to be used as a decorator that delegates to +# custom_transpose but makes it easy to specify output argument types +# by example. If used directly a decorator (i.e. not invoked with +# example arguments), assumes a scalar-valued function. +# +# TODO(frostig): remove this (and its uses) once custom_transpose offers +# an option of inferring output types. +def custom_transpose(example_out): + if isinstance(example_out, Callable): + out_type = core.get_aval(0.).to_tangent_aval() + return _custom_transpose(out_type, example_out) + return partial( + _custom_transpose, + jax.tree.map( + lambda x: core.get_aval(x).to_tangent_aval(), example_out)) + + +class CustomTransposeTest(jtu.JaxTestCase): + + def test_linear_call(self): + def f(x, y): + def fn(r, x): return x / r + def tp(r, t): return t / r + return x + jax.custom_derivatives.linear_call(fn, tp, y, x) + + def f_ref(x, y): + return x + x / y + + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. + self.assertAllClose(f(x, y), f_ref(x, y)) + + f1 = lambda x: f(x, y) + f1_ref = lambda x: f_ref(x, y) + self.assertAllClose(transpose_unary(f1, x)(x), + transpose_unary(f1_ref, x)(x)) + + def test_linear_call_incorrect_transpose(self): + def f(x, y): + def fn(r, x): return x / r + def tp(r, t): return t / (2. * r) # nb: not the true transpose + return x + jax.custom_derivatives.linear_call(fn, tp, y, x) + + def f_ref(x, y): + return x + x / y + + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. + self.assertAllClose(f(x, y), f_ref(x, y)) + + f1 = lambda x: f(x, y) + f1_ref = lambda x: f_ref(x, 2. * y) # nb: double the reference divisor + self.assertAllClose(transpose_unary(f1, x)(x), + transpose_unary(f1_ref, x)(x)) + + def test_linear_call_transpose_transpose_transpose(self): + def fn(r, x): return x / r + def tp(r, t): return t / (2. * r) # nb: untrue transpose + def f_(x, y): + return x + jax.custom_derivatives.linear_call(fn, tp, y, x) + + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. + f = lambda x: f_(x, y) + ft = transpose_unary(f, x) + ftt = transpose_unary(ft, x) + fttt = transpose_unary(ftt, x) + self.assertAllClose(ft(x), x + tp(y, x)) + self.assertAllClose(f(x), ftt(x)) + self.assertAllClose(ft(x), fttt(x)) + + def test_linear_call_scalar_to_vector(self): + def f(c, x): + def fn(_, x): + return [x, x] + + def tp(_, t): + t1, t2 = t + return t1 + t2 + + return jax.custom_derivatives.linear_call(fn, tp, (), c * x) + + def f_ref(c, x): + return [c * x, c * x] + + c, x = 2., 3. + t = [4., 5.] + self.assertAllClose(f(c, x), f_ref(c, x)) + self.assertAllClose(transpose_unary(partial(f, c), x)(t), + transpose_unary(partial(f_ref, c), x)(t)) + + def test_linear_call_nested(self): + # identity function with an untrue transpose of 0 + def id_(x): + def f(_, x): return x + def t(_, t): return 0. + return jax.custom_derivatives.linear_call(f, t, (), x) + + # identity function with an untrue transpose of 7, and where both + # forward and transpose have custom transpositions that should + # never end up invoked. + def f(x): + def f_(_, x): return id_(x) + def t_(_, t): return id_(7.) + return jax.custom_derivatives.linear_call(f_, t_, (), x) + + x = 5. + id_t = transpose_unary(id_, x) + id_tt = transpose_unary(id_t, x) + ft = transpose_unary(f, x) + ftt = transpose_unary(ft, x) + fttt = transpose_unary(ftt, x) + + self.assertAllClose(id_(x), x) + self.assertAllClose(id_t(x), 0.) + self.assertAllClose(id_tt(x), x) + + self.assertAllClose(f(x), x) + self.assertAllClose(ft(x), 7.) + self.assertAllClose(ftt(x), x) + self.assertAllClose(fttt(x), 7.) + + def test_linear_call_jit(self): + def f(x, y): + def fn(r, x): return x / r + def tp(r, t): return t / r + return x + jax.custom_derivatives.linear_call(fn, tp, y, x) + + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. + self.assertAllClose(f(x, y), jax.jit(f)(x, y)) + + f1 = lambda x: f(x, y) + self.assertAllClose(transpose_unary(f1, x)(x), + jax.jit(transpose_unary(f1, x))(x)) + + def test_linear_call_type_mismatch(self): + def f(x, y): + def fn(r, x): return x / r + def tp(r, t): return None + return x + jax.custom_derivatives.linear_call(fn, tp, y, x) + + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. + f1 = lambda x: f(x, y) + with self.assertRaisesRegex(TypeError, "transpose output pytree"): + transpose_unary(f1, x)(x) + + def test_linear_call_recursion(self): + def f(x): + def fn(_, x): return x + def tp(_, t): return f(t) + return jax.custom_derivatives.linear_call(fn, tp, None, x) + jax.jit(f)(0.1) + + def test_linear_call_grad(self): + def f(x, y): + def fn(r, x): return x / r + def tp(r, t): return t / r + return x + jax.custom_derivatives.linear_call(fn, tp, y, x) + + def f_ref(x, y): + return x + x / y + + x = jnp.array(6.) + y = jnp.array(3.) + self.assertAllClose(jax.grad(f)(x, y), jax.grad(f_ref)(x, y)) + + def test_basic(self): + def f(x, y): + @custom_transpose(jnp.ones(2)) + def fn(r, x): return x / r + @fn.def_transpose + def tp(r, t): return t / r + + return x + fn(y, x) + + def f_ref(x, y): + return x + x / y + + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. + self.assertAllClose(f(x, y), f_ref(x, y)) + + f1 = lambda x: f(x, y) + f1_ref = lambda x: f_ref(x, y) + self.assertAllClose(transpose_unary(f1, x)(x), + transpose_unary(f1_ref, x)(x)) + + def test_incorrect_transpose(self): + def f(x, y): + @custom_transpose(jnp.ones(2)) + def fn(r, x): return x / r + @fn.def_transpose + def tp(r, t): return t / (2. * r) # nb: not the true transpose + + return x + fn(y, x) + + def f_ref(x, y): + return x + x / y + + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. + self.assertAllClose(f(x, y), f_ref(x, y)) + + f1 = lambda x: f(x, y) + f1_ref = lambda x: f_ref(x, 2. * y) # nb: double the reference divisor + self.assertAllClose(transpose_unary(f1, x)(x), + transpose_unary(f1_ref, x)(x)) + + def test_transpose_transpose_transpose(self): + @custom_transpose(jnp.ones(2)) + def fn(r, x): return x / r + @custom_transpose(jnp.ones(2)) + def tp(r, t): return t / (2. * r) # nb: untrue transpose + + fn.def_transpose(tp) + tp.def_transpose(fn) + + def f_(x, y): + return x + fn(y, x) + + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. + f = lambda x: f_(x, y) + ft = transpose_unary(f, x) + ftt = transpose_unary(ft, x) + fttt = transpose_unary(ftt, x) + self.assertAllClose(ft(x), x + tp(y, x)) + self.assertAllClose(f(x), ftt(x)) + self.assertAllClose(ft(x), fttt(x)) + + def test_scalar_to_vector(self): + def f(c, x): + @custom_transpose([0., 0.]) + def fn(_, x): + return [x, x] + + @fn.def_transpose + def tp(_, t): + t1, t2 = t + return t1 + t2 + + return fn((), c * x) + + def f_ref(c, x): + return [c * x, c * x] + + c, x = 2., 3. + t = [4., 5.] + self.assertAllClose(f(c, x), f_ref(c, x)) + self.assertAllClose(transpose_unary(partial(f, c), x)(t), + transpose_unary(partial(f_ref, c), x)(t)) + + def test_nested(self): + # identity function with an untrue transpose of 0 + def id_(x): + f = custom_transpose(lambda _, x: x) + t = custom_transpose(lambda _, t: 0.) + f.def_transpose(t) + t.def_transpose(f) + return f((), x) + + # identity function with an untrue transpose of 7, and where both + # forward and transpose have custom transpositions that should + # never end up invoked. + def f(x): + f_ = custom_transpose(lambda _, x: id_(x)) + t_ = custom_transpose(lambda _, t: id_(7.)) + f_.def_transpose(t_) + t_.def_transpose(f_) + return f_((), x) + + x = 5. + id_t = transpose_unary(id_, x) + id_tt = transpose_unary(id_t, x) + ft = transpose_unary(f, x) + ftt = transpose_unary(ft, x) + fttt = transpose_unary(ftt, x) + + self.assertAllClose(id_(x), x) + self.assertAllClose(id_t(x), 0.) + self.assertAllClose(id_tt(x), x) + + self.assertAllClose(f(x), x) + self.assertAllClose(ft(x), 7.) + self.assertAllClose(ftt(x), x) + self.assertAllClose(fttt(x), 7.) + + def test_one_degree(self): + T = lambda f: transpose_unary(f, 0.) + + @custom_transpose + def f(_, z): return 2. * z + @f.def_transpose + def ft(_, z): return 3. * z + + f = partial(f, ()) + self.assertAllClose(2., f(1.)) + self.assertAllClose(3., T(f)(1.)) + self.assertAllClose(3., T(T(f))(1.)) + self.assertAllClose(3., T(T(T(f)))(1.)) + self.assertAllClose(3., T(T(T(T(f))))(1.)) # ... + + def test_two_degrees(self): + T = lambda f: transpose_unary(f, 0.) + + @custom_transpose + def f(_, z): return 2. * z + + @f.def_transpose + @custom_transpose + def ft(_, z): return 3. * z + + @ft.def_transpose + def ftt(_, z): return 7. * z + + f = partial(f, ()) + self.assertAllClose(2., f(1.)) + self.assertAllClose(3., T(f)(1.)) + self.assertAllClose(7., T(T(f))(1.)) + self.assertAllClose(7., T(T(T(f)))(1.)) + self.assertAllClose(7., T(T(T(T(f))))(1.)) # ... + + def test_symmetric(self): + T = lambda f: transpose_unary(f, 0.) + + @custom_transpose + def f(_, z): return 2. * z + @custom_transpose + def g(_, z): return 3. * z + + f.def_transpose(g) + g.def_transpose(f) + + f = partial(f, ()) + self.assertAllClose(2., f(1.)) + self.assertAllClose(3., T(f)(1.)) + self.assertAllClose(2., T(T(f))(1.)) + self.assertAllClose(3., T(T(T(f)))(1.)) + self.assertAllClose(2., T(T(T(T(f))))(1.)) # ... + + def test_recursive(self): + T = lambda f: transpose_unary(f, 0.) + + @custom_transpose + def f(c, z): return c * z + + @f.def_transpose + def ft(c, z): return f(c + 1., z) + + g = partial(f, 1.) + self.assertAllClose(1., g(1.)) + self.assertAllClose(2., T(g)(1.)) + self.assertAllClose(3., T(T(g))(1.)) + self.assertAllClose(4., T(T(T(g)))(1.)) + self.assertAllClose(5., T(T(T(T(g))))(1.)) # ... + + def test_jvp_lin(self): + def f(x, y): + @custom_transpose(jnp.ones(2)) + def fn(r, x): return x / r + @fn.def_transpose + def tp(r, t): return t / r + return x + fn(y, x) + + def f_ref(x, y): return x + x / y + + x, y, tx = 6., 3., 1. + g = lambda x: f(x, y) + g_ref = lambda x: f_ref(x, y) + self.assertAllClose(api.jvp(g, [x], [tx]), api.jvp(g_ref, [x], [tx])) + + def test_jvp_res(self): + raise unittest.SkipTest('unimplemented') # TODO(frostig) + + def f(x, y): + @custom_transpose(jnp.ones(2)) + def fn(r, x): return x / r + @fn.def_transpose + def tp(r, t): return t / r + return x + fn(y, x) + + def f_ref(x, y): return x + x / y + + x, y, ty = 6., 3., 1. + g = lambda y: f(x, y) + g_ref = lambda y: f_ref(x, y) + self.assertAllClose(api.jvp(g, [y], [ty]), api.jvp(g_ref, [y], [ty])) + + def test_jvp_both(self): + raise unittest.SkipTest('unimplemented') # TODO(frostig) + + def f(x, y): + @custom_transpose(jnp.ones(2)) + def fn(r, x): return x / r + @fn.def_transpose + def tp(r, t): return t / r + return x + fn(y, x) + + def f_ref(x, y): return x + x / y + + x, y, tx, ty = 6., 3., 1., 1. + self.assertAllClose(api.jvp(f, [x, y], [tx, ty]), + api.jvp(f_ref, [x, y], [tx, ty])) + + def test_make_jaxpr(self): + def f(x, y): + @custom_transpose(jnp.ones(2)) + def fn(r, x): return x / r + @fn.def_transpose + def tp(r, t): return 2 * t / r + + return x + fn(y, x) + + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. + f_ = lambda x: f(x, y) + f_t = transpose_unary(f_, x) + + jaxpr = api.make_jaxpr(f_)(x) + self.assertIn('custom_transpose_call', str(jaxpr)) + + jaxpr_t = api.make_jaxpr(f_t)(x) + self.assertNotIn('custom_transpose_call', str(jaxpr_t)) + + def test_jit(self): + def f(x, y): + @custom_transpose(jnp.ones(2)) + def fn(r, x): return x / r + @fn.def_transpose + def tp(r, t): return 2 * t / r + + return x + fn(y, x) + + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. + self.assertAllClose(f(x, y), jax.jit(f)(x, y)) + + f_ = lambda x: f(x, y) + f_t = transpose_unary(f_, x) + g_ = jax.jit(f_) + g_t = transpose_unary(g_, x) + self.assertAllClose(f_(x), jax.jit(f_)(x)) + self.assertAllClose(f_t(x), jax.jit(f_t)(x)) + self.assertAllClose(f_(x), g_(x)) + self.assertAllClose(f_t(x), g_t(x)) + + def test_jit_recursive(self): + def f(x, y): + @custom_transpose(jnp.ones(2)) + def fn(r, x): return x / r + @fn.def_transpose + def tp(r, t): return 2 * fn(r, t) + + return x + fn(y, x) + + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. + self.assertAllClose(f(x, y), jax.jit(f)(x, y)) + + f_ = lambda x: f(x, y) + f_t = transpose_unary(f_, x) + g_ = jax.jit(f_) + g_t = transpose_unary(g_, x) + self.assertAllClose(f_(x), jax.jit(f_)(x)) + self.assertAllClose(f_t(x), jax.jit(f_t)(x)) + self.assertAllClose(f_(x), g_(x)) + self.assertAllClose(f_t(x), g_t(x)) + + def test_cond(self): + def f(x, y): + @custom_transpose(jnp.ones(2)) + def fn(r, x): return x / r + @fn.def_transpose + def tp(r, t): return 2 * t / r + + return x + fn(y, x) + + def cond_wrap(f): + return lambda i, x: lax.cond(i > 0, f, lambda x: x, x) + + i = 7. + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. + + f_ = lambda x: f(x, y) + f_t = transpose_unary(f_, x) + g_ = partial(cond_wrap(f_), i) + g_t = transpose_unary(g_, x) + + self.assertAllClose(f_(x), g_(x)) + self.assertAllClose(f_t(x), g_t(x)) + + def test_cond_recursive(self): + def f(x, y): + @custom_transpose(jnp.ones(2)) + def fn(r, x): return x / r + @fn.def_transpose + def tp(r, t): return 2 * fn(r, t) + + return x + fn(y, x) + + def cond_wrap(f): + return lambda i, x: lax.cond(i > 0, f, lambda x: x, x) + + i = 7. + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. + + f_ = lambda x: f(x, y) + f_t = transpose_unary(f_, x) + g_ = partial(cond_wrap(f_), i) + g_t = transpose_unary(g_, x) + + self.assertAllClose(f_(x), g_(x)) + self.assertAllClose(f_t(x), g_t(x)) + + def test_compose_custom_jvp(self): + @jax.custom_jvp + def f(x): + return jnp.sin(x) + + @f.defjvp + def f_jvp(primals, tangents): + x, = primals + dx, = tangents + return f(x), g(x, dx) + + @custom_transpose + def g(x, dx): + return jnp.cos(x) * dx + + @g.def_transpose + def gt(x, t): + return jnp.cos(x) * t + + with config.use_direct_linearize(True): + self.assertAllClose(jax.grad(f)(0.5), jnp.cos(0.5)) + + def test_input_none(self): + # ref: https://github.com/jax-ml/jax/issues/29009 + @jax.custom_jvp + def f(x, y): return y + @f.defjvp + def f_jvp(p, t): return f(*p), g(p, t) + + @custom_transpose(jnp.float32(0)) + def g(r, x): return x[1] + @g.def_transpose + def gt(r, t): return None, jnp.zeros_like(r[1]) + + jax.grad(f, argnums=(1,))(None, jnp.float32(2)) # doesn't crash + + +class CustomDceTest(jtu.JaxTestCase): + + def test_basic(self): + @jax.experimental.custom_dce.custom_dce + def f(x): + return jnp.sin(x), jnp.cos(x) + + @f.def_dce + def rule(used_outs, x): + return ( + jnp.exp(x) if used_outs[0] else None, + jnp.sqrt(x) if used_outs[1] else None, + ) + + x = jnp.array(1.1234) + self.assertAllClose(jax.jit(lambda x: f(x)[0])(x), jnp.exp(x)) + self.assertAllClose(jax.jit(lambda x: f(x)[1])(x), jnp.sqrt(x)) + + def test_recursive(self): + @jax.experimental.custom_dce.custom_dce + def f(x): + return jnp.exp(x), 10 * jnp.sqrt(x) + + @f.def_dce + def f_dce(used_outs, x): + return [2 * v if used else None for used, v in zip(used_outs, f(x))] + + x = 1.1234 + expected = f(x) + self.assertAllClose(jax.jit(lambda x: f(x)[0])(x), 2 * expected[0]) + self.assertAllClose(jax.jit(lambda x: f(x)[1])(x), 2 * expected[1]) + + def test_multiple_rounds(self): + @jax.experimental.custom_dce.custom_dce + def f(x, y, z): + return jnp.sin(x), jnp.sin(y), jnp.sin(z) + + @f.def_dce + def rule(used_outs, x, y, z): + patterns.append(used_outs) + outs = [ + jnp.cos(v) if used else None for used, v in zip(used_outs, (x, y, z)) + ] + return outs + + patterns = [] + x, y, z = jnp.array(1.), jnp.array(2.), jnp.array(3.) + jaxpr = jax.make_jaxpr(f)(x, y, z).jaxpr + new_jaxpr, used_ins = pe.dce_jaxpr(jaxpr, [True, False, True]) + assert used_ins == [True, False, True] + new_jaxpr, used_ins = pe.dce_jaxpr(new_jaxpr, [True, False]) + assert used_ins == [True, False] + assert patterns == [(True, False, True), (True, False, False)], patterns + + def test_batching(self): + @jax.experimental.custom_dce.custom_dce + def f(x, y): + return jnp.sin(x), jnp.sin(y) + + @f.def_dce + def rule(used_outs, x, y): + return ( + jnp.cos(x) if used_outs[0] else None, + jnp.cos(y) if used_outs[1] else None, + ) + + x = jnp.linspace(-0.1, 0.2, 5) + y = jnp.linspace(3.0, 4.0, 5) + self.assertAllClose(jax.vmap(f)(x, y), f(x, y)) + self.assertAllClose( + jax.jit(lambda *args: jax.vmap(f)(*args)[0])(x, y), jnp.cos(x) + ) + self.assertAllClose( + jax.vmap(jax.jit(lambda *args: f(*args)[0]))(x, y), jnp.cos(x) + ) + self.assertAllClose( + jax.jit(lambda *args: jax.vmap(f)(*args)[1])(x, y), jnp.cos(y) + ) + self.assertAllClose( + jax.vmap(jax.jit(lambda *args: f(*args)[1]))(x, y), jnp.cos(y) + ) + + def test_composes_with_custom_vjp(self): + # custom_dce must be the "outer" decorator (for now!) because custom_vjp + # doesn't pass through DCE. + @jax.experimental.custom_dce.custom_dce + @jax.custom_vjp + def f(x, y): + return jnp.sin(x) * y, x * jnp.sin(y) + + @f.def_dce + def f_dce_rule(used_outs, x, y): + return ( + jnp.cos(x) * y if used_outs[0] else None, + x * jnp.cos(y) if used_outs[1] else None, + ) + + def f_fwd(x, y): + return f(x, y), (x, jnp.cos(x), jnp.sin(x), y, jnp.cos(y), jnp.sin(y)) + + def f_bwd(res, g): + ga, gb = g + x, cos_x, sin_x, y, cos_y, sin_y = res + return (cos_x * ga * y + sin_y * gb, sin_x * ga + x * cos_y * gb) + + f.defvjp(f_fwd, f_bwd) + + x, y = jnp.array(1.), jnp.array(2.) + self.assertAllClose(jax.jit(lambda *args: f(*args)[0])(x, y), + jnp.cos(x) * y) + jax.grad(lambda *args: f(*args)[0])(x, y) # Doesn't crash. + + def test_can_optimize_remat(self): + @jax.custom_vjp + def f(x): + return jnp.tan(x) + + @jax.experimental.custom_dce.custom_dce + def f_fwd(x): + return jnp.sin(x), (x,) + + @f_fwd.def_dce + def f_dce_rule(used_outs, x): + used_prim, used_res = used_outs + used_res, = used_res + if not used_res: + return f(x), None + prim, res = f_fwd(x) + return prim if used_prim else None, res + + def f_bwd(res, g): + x, = res + cos_x = jnp.cos(x) + return (cos_x * g,) + + f.defvjp(f_fwd, f_bwd) + + def temp(x): + out = jax.remat(f)(x) + out = out ** 2 + return out + + v, g = jax.value_and_grad(temp)(3.2) + self.assertAllClose(v, jnp.tan(3.2)**2) + + def test_static_argnums(self): + @partial(jax.experimental.custom_dce.custom_dce, static_argnums=(0,)) + def g(f, x): + return f(x), 10 * f(x) + + @g.def_dce + def g_dce(f, used_outs, x): # note: static_argnums are always passes first + self.assertTrue(callable(f)) + return [2 * v if used else None for used, v in zip(used_outs, g(f, x))] + + x = 1.1234 + f = lambda x: jnp.exp(x) + expected = g(f, x) + self.assertAllClose(jax.jit(lambda x: g(f, x)[0])(x), 2 * expected[0]) + self.assertAllClose(jax.jit(lambda x: g(f, x)[1])(x), 2 * expected[1]) + + def test_shape_mismatch_error(self): + @jax.experimental.custom_dce.custom_dce + def f(x): + return jnp.stack((x, x)), jnp.cos(x) + + @f.def_dce + def rule(used_outs, x): + return ( + jnp.exp(x) if used_outs[0] else None, + x.astype(jnp.int32) if used_outs[1] else None, + ) + + x = jnp.array(1.1234) + with self.assertRaisesRegex( + ValueError, + r'Custom DCE rule .* same shapes/dtypes .* output\[0\]', + ): + jax.jit(lambda x: f(x)[0])(x) + with self.assertRaisesRegex( + ValueError, + r'Custom DCE rule .* same shapes/dtypes .* output\[1\]', + ): + jax.jit(lambda x: f(x)[1])(x) + + def test_missing_output_error(self): + @jax.experimental.custom_dce.custom_dce + def f(x): + return jnp.sin(x), jnp.cos(x) + + @f.def_dce + def rule(used_outs, x): + return None, None + + x = jnp.array(1.1234) + with self.assertRaisesRegex( + ValueError, + r'Custom DCE rule .* produce values for all .* output\[0\]', + ): + jax.jit(lambda x: f(x)[0])(x) + + def test_consts(self): + @jax.experimental.custom_dce.custom_dce + def f(x): + return np.eye(1) * jnp.sin(x), jnp.cos(x) + + @f.def_dce + def rule(used_outs, x): + return ( + np.full((1, 1), 2.0) * jnp.exp(x) if used_outs[0] else None, + jnp.sqrt(x) if used_outs[1] else None, + ) + + x = jnp.array(1.1234) + expected = rule([True, True], x) + self.assertAllClose(jax.jit(lambda x: f(x)[0])(x), expected[0]) + self.assertAllClose(jax.jit(lambda x: f(x)[1])(x), expected[1]) + + def test_resolve_kwargs_error_message(self): + @jax.experimental.custom_dce.custom_dce + def f(x, y, *, z=None): + return jnp.sin(x) * y, x * jnp.sin(y) + + @f.def_dce + def f_dce_rule(used_outs, x, y): + self.fail("should not be executed") + + with self.assertRaisesRegex( + TypeError, + r"The input arguments to the custom_dce-decorated function f(.*)\n" + r"missing a required argument: 'y'" + ): + f(0.5) + + with self.assertRaisesRegex( + TypeError, + r"The input arguments to the custom_dce-decorated function f(.*)\n" + "The following keyword arguments could not be resolved to positions: z" + ): + f(0.5, 0.1, z=1.0) + + +class CustomVmapTest(jtu.JaxTestCase): + + def test_basic(self): + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + xs_batched, = in_batched + self.assertEqual(xs_batched, True) + self.assertEqual(axis_size, xs.shape[0]) + return jnp.cos(xs), xs_batched + + x, xs = jnp.array(1.), jnp.arange(3) + y = f(x) + self.assertAllClose(y, jnp.sin(x)) + ys = api.vmap(f)(xs) + self.assertAllClose(ys, jnp.cos(xs)) + + @jax.numpy_dtype_promotion('standard') + def test_closure(self): + z = jnp.array([2., 1., 3.]) + + @jax.custom_batching.custom_vmap + def f(x): return z + jnp.sin(x) + + @f.def_vmap + def rule(axis_size, in_batched, *args): + self.assertEqual(len(in_batched), 1) + self.assertEqual(len(args), 1) + xs, = args + xs_batched, = in_batched + self.assertEqual(xs_batched, True) + self.assertEqual(axis_size, xs.shape[0]) + return z + jnp.cos(xs), xs_batched + + x, xs = jnp.array(1.), jnp.arange(3) + y = f(x) + self.assertAllClose(y, z + jnp.sin(x)) + ys = api.vmap(f)(xs) + self.assertAllClose(ys, z + jnp.cos(xs)) + + def test_rule_multi_output(self): + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x), jnp.cos(x) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + return (jnp.cos(xs), jnp.sin(xs)), tuple(in_batched * 2) + + x, xs = jnp.array(1.), jnp.arange(3) + y1, y2 = f(x) + self.assertAllClose(y1, jnp.sin(x)) + self.assertAllClose(y2, jnp.cos(x)) + ys1, ys2 = api.vmap(f)(xs) + self.assertAllClose(ys1, jnp.cos(xs)) + self.assertAllClose(ys2, jnp.sin(xs)) + + def test_nary(self): + @jax.custom_batching.custom_vmap + def f(x, y): return jnp.sin(x) + y ** 2. + + @f.def_vmap + def rule(axis_size, in_batched, xs, ys): + self.assertEqual(in_batched, [True, True]) + self.assertEqual(axis_size, 3) + self.assertEqual(axis_size, xs.shape[0]) + self.assertEqual(axis_size, ys.shape[0]) + return jnp.cos(xs) + ys ** 2., True + + xs, ys = jnp.arange(3.0), jnp.arange(3.0) + zs = api.vmap(f)(xs, ys) + self.assertAllClose(zs, jnp.cos(xs) + ys ** 2.) + + def test_nary_mixed_batching(self): + @jax.custom_batching.custom_vmap + def vector_dot(u, v): + self.assertEqual(u.ndim, 1) + self.assertEqual(v.ndim, 1) + return u @ v + + size = 4 + vlen = 3 + in_batched_log = [] + + @vector_dot.def_vmap + def vector_dot_vmap_rule(axis_size, in_batched, u, v): + in_batched_log.append(in_batched) + self.assertEqual(axis_size, size) + u_batched, v_batched = in_batched + if u_batched: + self.assertEqual(u.ndim, 2) + self.assertEqual(u.shape[0], size) + else: + self.assertEqual(u.ndim, 1) + self.assertEqual(u.shape[0], vlen) + if v_batched: + self.assertEqual(v.ndim, 2) + self.assertEqual(v.shape[0], size) + else: + self.assertEqual(v.ndim, 1) + self.assertEqual(v.shape[0], vlen) + if u_batched and v_batched: + out = jnp.sum(u * v, axis=1) + else: + out = u @ v if u_batched else v @ u + return out, u_batched or v_batched + + f = vector_dot + v = lambda *shape: jnp.ones(shape) + + y = api.vmap(f, in_axes=(0, None))(v(4, 3), v(3)) + self.assertAllClose(y, v(4, 3) @ v(3)) + y = api.vmap(f, in_axes=(1, None))(v(3, 4), v(3)) + self.assertAllClose(y, v(3, 4).T @ v(3)) + y = api.vmap(f, in_axes=(None, 0))(v(3), v(4, 3)) + self.assertAllClose(y, v(3) @ v(4, 3).T) + y = api.vmap(f, in_axes=(0, 0))(v(4, 3), v(4, 3)) + self.assertAllClose(y, jnp.sum(v(4, 3) * v(4, 3), axis=1)) + self.assertEqual(in_batched_log[0], [True, False]) + self.assertEqual(in_batched_log[1], [True, False]) + self.assertEqual(in_batched_log[2], [False, True]) + self.assertEqual(in_batched_log[3], [True, True]) + + def test_rule_input_signature(self): + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x) + + rule_args = [] + + @f.def_vmap + def rule(axis_size, in_batched, xs): + rule_args.append((axis_size, in_batched)) + return jnp.cos(xs), in_batched[0] + + xs = jnp.arange(3) + _ = api.vmap(f)(xs) + (axis_size, in_batched), = rule_args + self.assertIs(type(axis_size), int) + self.assertIs(type(in_batched), list) + self.assertEqual(len(in_batched), 1) + + def test_rule_output_vs_batching_output_mismatch(self): + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x) + + @f.def_vmap + def test_rule_abc(axis_size, in_batched, xs): + return [jnp.sin(xs), jnp.cos(xs)], in_batched + + xs = jnp.arange(3) + self.assertRaisesRegex( + ValueError, + 'structure of output value and output batching specification ' + r'returned by custom vmap rule \(test_rule_abc\) do not match.*', + lambda: api.vmap(f)(xs)) + + def test_rule_vs_call_output_mismatch(self): + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x) + + @f.def_vmap + def test_rule_abc2(axis_size, in_batched, xs): + return [jnp.sin(xs)], in_batched + + xs = jnp.arange(3) + self.assertRaisesRegex( + ValueError, + r'structure of output returned by custom vmap rule \(test_rule_abc2\) ' + r'does not match that of original custom-vmapped function.*', + lambda: api.vmap(f)(xs)) + + def test_jvp_basic(self): + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + self.assertEqual(axis_size, 3) + self.assertEqual(in_batched, [True]) + return jnp.cos(xs), in_batched[0] + + f_jvp = lambda x, tx: api.jvp(f, [x], [tx]) + + x, tx = jnp.array(1.), jnp.array(2.) + xs, txs = jnp.arange(3.), jnp.arange(3.) * 2. + + y, ty = f_jvp(x, tx) + self.assertAllClose(y, jnp.sin(x)) + self.assertAllClose(ty, jnp.cos(x) * tx) + + ys, tys = api.vmap(f_jvp)(xs, txs) + self.assertAllClose(ys, jnp.cos(xs)) + self.assertAllClose(tys, -jnp.sin(xs) * txs) + + ys, tys = api.jvp(api.vmap(f), [xs], [txs]) + self.assertAllClose(ys, jnp.cos(xs)) + self.assertAllClose(tys, -jnp.sin(xs) * txs) + + @jax.numpy_dtype_promotion('standard') + def test_jvp_closure(self): + z = jnp.array([2., 1., 3.]) + def bcast(x): return z + x - z + + @jax.custom_batching.custom_vmap + def f(x): return z + jnp.sin(x) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + self.assertEqual(axis_size, 3) + self.assertEqual(in_batched, [True]) + return z + jnp.cos(xs), in_batched[0] + + f_jvp = lambda x, tx: api.jvp(f, [x], [tx]) + + x, tx = jnp.array(1.), jnp.array(2.) + xs, txs = jnp.arange(3.), jnp.arange(3.) * 2. + + y, ty = f_jvp(x, tx) + self.assertAllClose(y, z + jnp.sin(x)) + self.assertAllClose(ty, bcast(jnp.cos(x)) * tx) + + ys, tys = api.vmap(f_jvp)(xs, txs) + self.assertAllClose(ys, z + jnp.cos(xs)) + self.assertAllClose(tys, bcast(-jnp.sin(xs)) * txs) + + ys, tys = api.jvp(api.vmap(f), [xs], [txs]) + self.assertAllClose(ys, z + jnp.cos(xs)) + self.assertAllClose(tys, bcast(-jnp.sin(xs)) * txs) + + def test_jvp_nary(self): + @jax.custom_batching.custom_vmap + def f(x, y): return jnp.sin(x) + y + + @f.def_vmap + def rule(axis_size, in_batched, xs, ys): + self.assertEqual(axis_size, 3) + self.assertEqual(in_batched, [True, True]) + return jnp.cos(xs) + ys, True + + f_jvp = lambda x, y, tx, ty: api.jvp(f, [x, y], [tx, ty]) + + x, y, tx, ty = jnp.arange(4.) + xs, ys, txs, tys = 4. + jnp.arange(3. * 4).reshape((4, 3)) + + zs, tzs = api.vmap(f_jvp)(xs, ys, txs, tys) + self.assertAllClose(zs, jnp.cos(xs) + ys) + self.assertAllClose(tzs, -jnp.sin(xs) * txs + tys) + + zs, tzs = api.jvp(api.vmap(f), [xs, ys], [txs, tys]) + self.assertAllClose(zs, jnp.cos(xs) + ys) + self.assertAllClose(tzs, -jnp.sin(xs) * txs + tys) + + def test_jvp_extra_batched_tangents(self): + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + self.assertEqual(axis_size, 3) + self.assertEqual(in_batched, [False]) + return jnp.cos(xs), in_batched[0] + + f_jvp = lambda x, tx: api.jvp(f, [x], [tx]) + + txs = 2. + jnp.arange(3.) + x = jnp.array(1, dtype=txs.dtype) + y, tys = api.vmap(f_jvp, in_axes=(None, 0), out_axes=(None, 0))(x, txs) + self.assertAllClose(y, jnp.cos(x)) + self.assertAllClose(tys, -jnp.sin(x) * txs) + + def test_jacfwd(self): + # jacfwd is another way to exercise extra-batched tangents + + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + self.assertEqual(axis_size, 3) + self.assertEqual(in_batched, [False]) + return jnp.cos(xs), in_batched[0] + + x = jnp.arange(3.) + .72 + j = api.jacfwd(f)(x) + self.assertAllClose(j, -jnp.diag(jnp.sin(x))) + + def test_jvp_extra_batched_primals(self): + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + self.assertEqual(axis_size, 3) + self.assertEqual(in_batched, [False]) + return jnp.cos(xs), in_batched[0] + + f_jvp = lambda x, tx: api.jvp(f, [x], [tx]) + + xs = jnp.arange(3.) + tx = jnp.array(4, dtype=xs.dtype) + ys, tys = api.vmap(f_jvp, in_axes=(0, None))(xs, tx) + self.assertAllClose(ys, jnp.cos(xs)) + self.assertAllClose(tys, -jnp.sin(xs) * tx) + + def test_jvp_extra_batched_primals_with_linear_vmap_rule(self): + # When a function is linear, its Jacobian is constant. JAX's JVP + # of linear functions takes advantage of this: when mapping over a + # batch of primals relative to a fixed (i.e. symbolically + # replicated) tangent, output tangents remain replicated as well + # (i.e. JAX will not broadcast them). This is true in general, and + # this test checks that vmapped JVPs continue to behave this way + # when custom_vmap is involved and the custom vmap rule is linear. + + @jax.custom_batching.custom_vmap + def f_linear(x): return 7. * x + + @f_linear.def_vmap + def linear_rule(axis_size, in_batched, xs): + return 11. * xs, in_batched[0] + + @jax.custom_batching.custom_vmap + def f_nonlinear(x): return jnp.sin(x) + + @f_nonlinear.def_vmap + def nonlinear_rule(axis_size, in_batched, xs): + return jnp.cos(xs), in_batched[0] + + f_lin_jvp = lambda x, tx: api.jvp(f_linear, [x], [tx]) + f_non_jvp = lambda x, tx: api.jvp(f_nonlinear, [x], [tx]) + xs = jnp.arange(3.) + tx = jnp.array(4., dtype=xs.dtype) + + # doesn't err + _ = api.vmap(f_lin_jvp, in_axes=(0, None), out_axes=(0, None))(xs, tx) + + # does err + self.assertRaisesRegex( + ValueError, "at vmap out_axes", + lambda: api.vmap( + f_non_jvp, in_axes=(0, None), out_axes=(0, None))(xs, tx)) + + def test_jvp_dataflow_violation(self): + # The jvp-of-custom-vmap machinery should not assume the standard + # dataflow constraint on the JVP of the custom vmap rule (primal + # outputs independent of tangent inputs). Both jvp and vmap are + # "forward" transformations under which, at present, we don't + # enforce the JVP dependence diagram. Because output primals can + # depend on input tangents, extra-batched input tangents can + # create batched output primals, as this test checks. + + @jax.custom_jvp + def cos_with_invalid_dataflow_jvp(x): return jnp.cos(x) + + @cos_with_invalid_dataflow_jvp.defjvp + def invalid_dataflow_jvp(x, tx): + [x], [tx] = x, tx + return jnp.cos(x * tx), tx + + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + return cos_with_invalid_dataflow_jvp(xs), in_batched[0] + + f_jvp = lambda x, tx: api.jvp(f, [x], [tx]) + txs = 2. + jnp.arange(3.) + x = jnp.array(1, dtype=txs.dtype) + + # doesn't err + ys, tys = api.vmap(f_jvp, in_axes=(None, 0))(x, txs) + self.assertAllClose(ys, jnp.cos(x * txs)) + self.assertAllClose(tys, txs) + + # does err + self.assertRaisesRegex( + ValueError, "at vmap out_axes", + lambda: api.vmap( + f_jvp, in_axes=(None, 0), out_axes=(None, 0))(x, txs)) + + def test_tree(self): + tree_sin = partial(jax.tree.map, jnp.sin) + tree_cos = partial(jax.tree.map, jnp.cos) + + x, xs = jnp.array(1.), jnp.arange(3) + x = (x, [x + 1, x + 2], [x + 3], x + 4) + xs = (xs, [xs + 1, xs + 2], [xs + 3], xs + 4) + in_batched_ref = jax.tree.map(lambda _: True, x) + + @jax.custom_batching.custom_vmap + def f(xs): return tree_sin(xs) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + self.assertEqual(in_batched, [in_batched_ref]) + sz, = {z.shape[0] for z in jax.tree.leaves(xs)} + self.assertEqual(axis_size, sz) + return tree_cos(xs), in_batched[0] + + y = f(x) + self.assertAllClose(y, tree_sin(x)) + ys = api.vmap(f)(xs) + self.assertAllClose(ys, tree_cos(xs)) + + def test_tree_with_nones(self): + tree_sin = partial(jax.tree.map, jnp.sin) + tree_cos = partial(jax.tree.map, jnp.cos) + + x, xs = jnp.array(1.), jnp.arange(3) + x = (x, [x + 1, None], [x + 3], None) + xs = (xs, [xs + 1, None], [xs + 3], None) + in_batched_ref = jax.tree.map(lambda _: True, x) + + @jax.custom_batching.custom_vmap + def f(xs): return tree_sin(xs) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + self.assertEqual(in_batched, [in_batched_ref]) + sz, = {z.shape[0] for z in jax.tree.leaves(xs)} + self.assertEqual(axis_size, sz) + return tree_cos(xs), in_batched[0] + + y = f(x) + self.assertAllClose(y, tree_sin(x)) + ys = api.vmap(f)(xs) + self.assertAllClose(ys, tree_cos(xs)) + + def test_jit(self): + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + self.assertEqual(in_batched, [True]) + self.assertEqual(axis_size, xs.shape[0]) + return jnp.cos(xs), in_batched[0] + + x, xs = jnp.array(1.), jnp.arange(3) + self.assertAllClose(f(x), jit(f)(x)) + self.assertAllClose(jit(api.vmap(f))(xs), api.vmap(f)(xs)) + self.assertAllClose(api.vmap(jit(f))(xs), api.vmap(f)(xs)) + + def test_sequential_vmap_basic(self): + @jax.custom_batching.sequential_vmap + def f(x): + return x + 1. + + def vmap_ref(xs): + return lax.map(f, xs) + + xs = jnp.arange(3.) + jaxpr = api.make_jaxpr(api.vmap(f))(xs) + jaxpr_ref = api.make_jaxpr(vmap_ref)(xs) + + self.assertEqual(str(jaxpr), str(jaxpr_ref)) + + def test_sequential_vmap_nary_same_batching(self): + @jax.custom_batching.sequential_vmap + def f(x, y): + return x + y + + def vmap_ref(xs, ys): + return lax.map(lambda args: f(*args), (xs, ys)) + + xs, ys = jnp.arange(3.), 4. + jnp.arange(3.) + jaxpr = api.make_jaxpr(api.vmap(f))(xs, ys) + jaxpr_ref = api.make_jaxpr(vmap_ref)(xs, ys) + + self.assertEqual(str(jaxpr), str(jaxpr_ref)) + + def test_sequential_vmap_nary_mixed_batching(self): + @jax.custom_batching.sequential_vmap + def f(x, y): + return x + y + + def vmap_ref(xs, y): + return lax.map(lambda x: f(x, y), xs) + + xs, y = jnp.arange(3.), 4. + jaxpr = api.make_jaxpr(api.vmap(f, in_axes=(0, None)))(xs, y) + jaxpr_ref = api.make_jaxpr(vmap_ref)(xs, y) + + self.assertEqual(str(jaxpr), str(jaxpr_ref)) + + @parameterized.named_parameters( + ("0", 0), + ("1", 1), + ("8", 4), + ("12", 8), + ("16", 16), + ) + def test_batch_map_basic(self, batch_size: int): + def f(x): + self.assertEqual(x.shape, ()) + return x**2 + + x = np.arange(16) + y = jax.lax.map(f, x, batch_size=batch_size) + + np.testing.assert_array_equal(y, x**2) + + @parameterized.named_parameters( + ("0", 0), + ("1", 1), + ("8", 4), + ("12", 8), + ("16", 16), + ) + def test_batch_map_pytrees(self, batch_size: int): + f = lambda x: {'b': x['a'] ** 2} + inputs = {'a': np.arange(16)} + expected = np.arange(16) ** 2 + + outputs = jax.lax.map(f, inputs, batch_size=batch_size) + self.assertAllClose(outputs['b'], expected) + + outputs = jax.lax.map( + f, inputs, batch_size=batch_size + ) + self.assertAllClose(outputs['b'], expected) + + def test_batch_divides_axis(self): + def f(t): + x, a = t + self.assertEqual(x.shape, (4,)) + return (x + a)**2 + + x = jax.random.randint(jax.random.key(0), (16, 4), -10, 10) + a = jax.random.randint(jax.random.key(1), (16, 4), -10, 10) + + @jax.jit + def g(x, a): + return jax.lax.map(f, (x, a), batch_size=8) + + y = g(x, a) + + self.assertAllClose(y, (x + a)**2) + + def test_undefined_rule(self): + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x) + + with self.assertRaisesRegex( + AttributeError, "No batching rule defined for custom_vmap function f"): + f(0.5) + + def test_kwargs(self): + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + xs_batched, = in_batched + self.assertEqual(xs_batched, True) + self.assertEqual(axis_size, xs.shape[0]) + return jnp.cos(xs), xs_batched + + x, xs = jnp.array(1.), jnp.arange(3) + y = f(x=x) + self.assertAllClose(y, jnp.sin(x)) + ys = api.vmap(f)(x=xs) + self.assertAllClose(ys, jnp.cos(xs)) + + def test_partial_eval_raises(self): + @jax.custom_batching.custom_vmap + def f(x): + return jnp.sin(x) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + del axis_size # unused + return jnp.cos(xs), in_batched[0] + + with self.assertRaisesRegex( + ValueError, + "Linearization failed to produce known values for all output primals", + ): + jax.grad(f)(0.5) + + def test_compose_custom_vjp(self): + @jax.custom_vjp + @jax.custom_batching.custom_vmap + def f(x, y): + return jnp.sin(x) * y + + @f.def_vmap + def f_vmap_rule(axis_size, in_batched, xs, ys): + return jnp.cos(xs) * ys, True + + def f_fwd(x, y): + return f(x, y), (jnp.cos(x), jnp.sin(x), y) + + def f_bwd(res, g): + cos_x, sin_x, y = res + return (cos_x * g * y, sin_x * g) + + f.defvjp(f_fwd, f_bwd) + + xs = jnp.linspace(0, 1, 5) + ys = jnp.linspace(-0.1, 0.1, 5) + self.assertAllClose(jax.vmap(f)(xs, ys), jnp.cos(xs) * ys) + jax.grad(f)(xs[0], ys[0]) # Doesn't crash. + + def test_compose_custom_vjp_bwd_rule(self): + # This tests the case where both the forward and backward rules are wrapped + # in custom_vmap. + @jax.custom_batching.sequential_vmap + def fun_fwd(x, y): + return jnp.sin(x) * y, (x, y) + + @jax.custom_batching.sequential_vmap + def fun_bwd(res, ct): + x, y = res + return x * ct, y * ct + + fun = jax.custom_vjp(lambda *args: fun_fwd(*args)[0]) + fun.defvjp(fun_fwd, fun_bwd) + + xs = jnp.linspace(0, 1, 5) + y = jnp.array(0.5, dtype=xs.dtype) + f = jax.vmap(jax.jit(fun), in_axes=(0, None)) + out, f_vjp = jax.vjp(f, xs, y) + f_vjp(out) # Doesn't crash. + + def test_resolve_kwargs_error_message(self): + @jax.custom_batching.custom_vmap + def f(x, y, *, z=None): + return jnp.sin(x) * y + + @f.def_vmap + def f_vmap_rule(axis_size, in_batched, xs, ys): + self.fail("should not be executed") + + with self.assertRaisesRegex( + TypeError, + r"The input arguments to the custom_vmap-decorated function f(.*)\n" + r"missing a required argument: 'y'" + ): + f(0.5) + + with self.assertRaisesRegex( + TypeError, + r"The input arguments to the custom_vmap-decorated function f(.*)\n" + "The following keyword arguments could not be resolved to positions: z" + ): + f(0.5, 0.1, z=1.0) + + +class CustomApiTest(jtu.JaxTestCase): + """Test interactions among the custom_{vmap,jvp,vjp,transpose,*} APIs""" + + def test_method_forwarding(self): + @jax.custom_batching.custom_vmap + @jax.custom_jvp + @jax.custom_transpose.custom_transpose + def f(x): return 2. * x + + # none of these err: + @f.def_vmap + def f_batch(sz, b, xs): return 2. * xs + @f.defjvp + def f_jvp(x, tx): return 2. * x, 2. * tx + @f.def_transpose + def f_transpose(x): return 2. * x + + def test_def_method_forwarding_all_permutations(self): + for wraps in it.permutations([ + jax.custom_jvp, jax.custom_transpose.custom_transpose, jax.custom_batching.custom_vmap]): + f = lambda x: x + 1. + for wrap in wraps: + f = wrap(f) + for methods in it.permutations(['defjvp', 'def_vmap', 'def_transpose']): + for method in methods: + self.assertIsInstance(getattr(f, method), Callable) + + for decorators in it.permutations([ + jax.custom_vjp, jax.custom_transpose.custom_transpose, jax.custom_batching.custom_vmap]): + f = lambda x: x + 1. + for decorator in decorators: + f = decorator(f) + for methods in it.permutations(['defvjp', 'def_vmap', 'def_transpose']): + for method in methods: + self.assertIsInstance(getattr(f, method), Callable) + + +if __name__ == '__main__': + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/custom_linear_solve_test.py b/tests/custom_linear_solve_test.py index 8dfb338f3be1..00d5396962b7 100644 --- a/tests/custom_linear_solve_test.py +++ b/tests/custom_linear_solve_test.py @@ -23,7 +23,7 @@ import jax from jax import lax -from jax.ad_checkpoint import checkpoint +from jax import checkpoint from jax._src import test_util as jtu import jax.numpy as jnp # scan tests use numpy import jax.scipy as jsp @@ -502,6 +502,21 @@ def solve(b): self.assertEqual(output(), "mat_vec\n") self.assertAllClose(computed, expected) + def test_symbolic_zero_cotangents(self): + # https://github.com/jax-ml/jax/issues/29342 + def g(x): + def p(z): + return jnp.linalg.solve(z.sum()*jnp.eye(3), jnp.array([1., 0., 0.]))[0] + h = lambda y: jax.jvp(jax.vmap(p), (x,), (y,))[1] + return h + + def f(x): + return jax.vjp(g(x), jnp.ones_like(x))[1](x)[0] + + x = jnp.array([200.0]) + f_x = f(x) + grad_f_x = jax.jacrev(f)(x) # don't crash + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/custom_partitioning_sharding_rule_test.py b/tests/custom_partitioning_sharding_rule_test.py index f22721910408..920b68b5bc2e 100644 --- a/tests/custom_partitioning_sharding_rule_test.py +++ b/tests/custom_partitioning_sharding_rule_test.py @@ -148,6 +148,32 @@ def test_sharding_rule_ellipsis_inside_compound_dim(self): ValueError, "Ellipsis can only be used at the beginning of a dimension"): str_to_sdy_sharding_rule("i, (..., j) -> j") + def test_sharding_rule_redcution_factors_is_not_used(self): + with self.assertRaisesRegex( + ValueError, "Factor k in reduction_factors is not used"): + str_to_sdy_sharding_rule("i -> j", reduction_factors=("k",)) + + def test_sharding_rule_need_replication_factors_is_not_used(self): + with self.assertRaisesRegex( + ValueError, "Factor k in need_replication_factors is not used"): + str_to_sdy_sharding_rule("(i, j) -> (j, i)", need_replication_factors=("k",), i=10, j=20) + + def test_sharding_rule_permutation_factors_must_be_a_tuple_of_factors(self): + with self.assertRaisesRegex( + ValueError, "permutation_factors must be a tuple of factors"): + str_to_sdy_sharding_rule("i j -> j", permutation_factors=3) + + def test_sharding_rule_factor_used_in_multiple_special_factors(self): + with self.assertRaisesRegex( + ValueError, "Factor i can only be in one of the reduction, need " + "replication, or permutation factor sets"): + str_to_sdy_sharding_rule("i -> j", reduction_factors=("i",), need_replication_factors=("i",)) + + def test_sharding_rule_duplicated_factors_in_special_factors(self): + with self.assertRaisesRegex( + ValueError, "reduction_factors contains duplicated factors"): + str_to_sdy_sharding_rule("i -> j", reduction_factors=("i", "j", "i")) + def test_sharding_rule_scalar_operand_scalar_result(self): rule = str_to_sdy_sharding_rule("->") self.assertEqual(str(rule), "SdyShardingRule(((),), ((),), {})") @@ -202,6 +228,18 @@ def test_sharding_rule_factor_infer_k(self): "SdyShardingRule((('i_', ('j', 'k')),), (('j', 'foo', ('m', 'bar_24'))" ",), {'k': 10, 'm': 10, 'bar_24': 20})") + def test_sharding_rule_with_special_factors(self): + rule = str_to_sdy_sharding_rule("i_ (j k)-> j foo (m bar_24)", k=10, m=10, bar_24=20, + need_replication_factors=("m",), + permutation_factors=("j",), + reduction_factors=("k", "bar_24")) + self.assertEqual( + str(rule), + "SdyShardingRule((('i_', ('j', 'k')),), (('j', 'foo', ('m', 'bar_24'))" + ",), {'k': 10, 'm': 10, 'bar_24': 20} " + "reduction_factors=('k', 'bar_24') " + "need_replication_factors=('m',) permutation_factors=('j',))") + class SdyShardingRuleConversionTest(jtu.JaxTestCase): @@ -290,21 +328,21 @@ def test_conversion_rule_op_mismatch_in_results_dim(self): [result.operands[0].type, result.operands[1].type], [result.result.type,]) - def test_conversion_factor_has_two_sizes(self): - opnd0 = self.create_tensor_value((16, 32)) + def test_conversion_factor_with_multiple_sizes_use_smallest_size(self): + opnd0 = self.create_tensor_value((16, 16)) opnd1 = self.create_tensor_value((16, 32)) result = ir.Operation.create( "stablehlo.custom_call", - results=[self.get_tensor_type((16, 64))], + results=[self.get_tensor_type((16, 8))], operands=[opnd0, opnd1,], attributes=dict(call_target_name=ir.StringAttr.get("foo"))) rule = str_to_sdy_sharding_rule("i j, i j -> i j") - with self.assertRaisesRegex( - ValueError, - "Factor j corresponds to two sizes: 32 and 64"): - sdy_sharding_rule_to_mlir(rule, + mlir_rule = sdy_sharding_rule_to_mlir(rule, [result.operands[0].type, result.operands[1].type], [result.result.type,]) + self.assertEqual( + str(mlir_rule), + "#sdy.op_sharding_rule<([i, j], [i, j])->([i, j]) {i=16, j=8}, custom>") def test_conversion_batching_dim_has_two_sizes(self): opnd0 = self.create_tensor_value((16, 32)) @@ -383,7 +421,7 @@ def test_conversion_compound_then_individual(self): [result.result.type,]) self.assertEqual( str(mlir_rule), - "#sdy.op_sharding_rule<([ij])->([i, j]) {i=2, j=4}>") + "#sdy.op_sharding_rule<([ij])->([i, j]) {i=2, j=4}, custom>") def test_conversion_elementwise_rule_scalar_instance(self): opnd0 = self.create_tensor_value(()) @@ -399,7 +437,7 @@ def test_conversion_elementwise_rule_scalar_instance(self): [result.result.type,]) self.assertEqual( str(mlir_rule), - "#sdy.op_sharding_rule<([], [])->([])>") + "#sdy.op_sharding_rule<([], [])->([]), custom>") def test_conversion_elementwise_rule_2D_instance(self): opnd0 = self.create_tensor_value((16, 32)) @@ -415,7 +453,7 @@ def test_conversion_elementwise_rule_2D_instance(self): [result.result.type,]) self.assertEqual( str(mlir_rule), - "#sdy.op_sharding_rule<([i, j], [i, j])->([i, j]) {i=16, j=32}>") + "#sdy.op_sharding_rule<([i, j], [i, j])->([i, j]) {i=16, j=32}, custom>") def test_conversion_vector_scalar_add_2D_instance(self): opnd0 = self.create_tensor_value((16, 32)) @@ -431,7 +469,7 @@ def test_conversion_vector_scalar_add_2D_instance(self): [result.result.type,]) self.assertEqual( str(mlir_rule), - "#sdy.op_sharding_rule<([i, j], [])->([i, j]) {i=16, j=32}>") + "#sdy.op_sharding_rule<([i, j], [])->([i, j]) {i=16, j=32}, custom>") def test_conversion_reshape_rule(self): opnd0 = self.create_tensor_value((2, 4)) @@ -446,7 +484,7 @@ def test_conversion_reshape_rule(self): [result.result.type,]) self.assertEqual( str(mlir_rule), - "#sdy.op_sharding_rule<([i, j])->([ij]) {i=2, j=4}>") + "#sdy.op_sharding_rule<([i, j])->([ij]) {i=2, j=4}, custom>") def test_conversion_contracting_dim_matmul(self): opnd0 = self.create_tensor_value((16, 32)) @@ -456,14 +494,14 @@ def test_conversion_contracting_dim_matmul(self): results=[self.get_tensor_type((16, 8))], operands=[opnd0, opnd1,], attributes=dict(call_target_name=ir.StringAttr.get("foo"))) - rule = str_to_sdy_sharding_rule("... contracting_dim, contracting_dim k -> ... k") + rule = str_to_sdy_sharding_rule("... contracting_dim, contracting_dim k -> ... k", + reduction_factors=("contracting_dim",)) mlir_rule = sdy_sharding_rule_to_mlir(rule, [result.operands[0].type, result.operands[1].type], [result.result.type,]) self.assertEqual( str(mlir_rule), - "#sdy.op_sharding_rule<([i, j], [j, k])->([i, k]) {i=16, j=32, k=8}>") - + "#sdy.op_sharding_rule<([i, j], [j, k])->([i, k]) {i=16, j=32, k=8} reduction={j}, custom>") def test_conversion_multiple_batching_groups(self): opnd0 = self.create_tensor_value((4, 5, 16, 32)) @@ -479,8 +517,7 @@ def test_conversion_multiple_batching_groups(self): [result.result.type,]) self.assertEqual( str(mlir_rule), - "#sdy.op_sharding_rule<([i, j, k, l], [m, n, o, l, k])->([i, j, l, k]) {i=4, j=5, k=16, l=32, m=6, n=7, o=8}>") - + "#sdy.op_sharding_rule<([i, j, k, l], [m, n, o, l, k])->([i, j, l, k]) {i=4, j=5, k=16, l=32, m=6, n=7, o=8}, custom>") if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/custom_partitioning_test.py b/tests/custom_partitioning_test.py new file mode 100644 index 000000000000..a03aff84c012 --- /dev/null +++ b/tests/custom_partitioning_test.py @@ -0,0 +1,481 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +from functools import partial +from absl.testing import absltest + +import jax +import jax.numpy as jnp +from jax import P +from jax._src import test_util as jtu +from jax._src import config +from jax._src.named_sharding import NamedSharding +from jax.experimental.custom_partitioning import ( + custom_partitioning, SdyShardingRule, BATCHING) + +config.parse_flags_with_absl() +jtu.request_cpu_devices(8) + + +@jtu.pytest_mark_if_available('multiaccelerator') +class CustomPartitionerTest(jtu.JaxTestCase): + + def skip_if_custom_partitioning_not_supported(self): + if jtu.is_cloud_tpu(): + raise unittest.SkipTest("Custom partitioning is not supported on libtpu.") + + @jtu.skip_on_devices('cpu') # Collectives don't seem to work on CPU. + def test_custom_partitioner(self): + self.skip_if_custom_partitioning_not_supported() + + def partition(precision, mesh, arg_shapes, result_shape): + arg_shardings = jax.tree.map(lambda s: s.sharding, arg_shapes) + result_sharding = result_shape[0].sharding + self.assertEqual(arg_shardings[0], result_sharding) + self.assertEqual(P('x', None), result_sharding.spec) + self.assertEqual(P('y', None), arg_shardings[1].spec) + + def lower_fn(x, y): + axis_name = arg_shardings[1].spec[0][0] + i = jax.lax.axis_index(axis_name) + # Use offset i * 0 instead of 0 to ensure that the two offsets have the + # same dtype regardless the value of config.enable_x64. + z = jax.lax.psum( + jax.lax.dynamic_slice(x, (i * 0, i * 8), (8, 8)) @ y, (axis_name) + ) + return z, z * z + + return mesh, lower_fn, (result_sharding, result_sharding), arg_shardings + + def infer_sharding_from_operands(precision, mesh, arg_shapes, result_shape): + arg_shardings = jax.tree.map(lambda s: s.sharding, arg_shapes) + x_shard, y_shard = arg_shardings + x_shape, y_shape = arg_shapes + x_names = tuple(x_shard.spec) + tuple( + None for _ in range(len(x_shape.shape) - len(x_shard.spec))) + y_names = tuple(y_shard.spec) + tuple( + None for _ in range(len(y_shape.shape) - len(y_shard.spec))) + z_shard = NamedSharding(y_shard.mesh, P(*(x_names[:-1] + y_names[1:]))) + return z_shard, z_shard + + @partial(custom_partitioning, static_argnums=(2,)) + def f(x, y, precision=None): + z = jnp.matmul(x, y, precision=precision) + return z, z * z + + f.def_partition( + infer_sharding_from_operands=infer_sharding_from_operands, + partition=partition, + sharding_rule=SdyShardingRule(operand_mappings=(('i', 'j'), ('j', 'k')), result_mappings=(('i', 'k'), ('i', 'k')))) + + with jax.set_mesh(jtu.create_mesh((4, 2), ('x', 'y'))): + jit_f = jax.jit(f, in_shardings=(P('x'), P('y')), out_shardings=P('x')) + x = np.asarray(np.random.randint(0, 20, (32, 16)), dtype=np.float32) + y = np.asarray(np.random.randint(0, 20, (16, 32)), dtype=np.float32) + x_sharded = jax.device_put(x, P('x')) + y_sharded = jax.device_put(y, P('y')) + result1 = jax.jit(f)(x_sharded, y_sharded) + result2 = f(x, y) + result0 = jit_f(x_sharded, y_sharded) + self.assertArraysEqual(result0, result1) + self.assertArraysEqual(result1, result2) + + def test_custom_partitioner_propagate_user_sharding(self): + self.skip_if_custom_partitioning_not_supported() + + def partition(mesh, arg_shapes, result_shape): + def lower_fn(x): + return x + + return ( + mesh, + lower_fn, + arg_shapes[0].sharding, + (arg_shapes[0].sharding,), + ) + + def infer_sharding_from_operands(mesh, arg_shapes, result_shape): + return arg_shapes[0].sharding + + def propagate_user_sharding(mesh, user_shape): + return user_shape.sharding + + @custom_partitioning + def f(x): + return x + + f.def_partition( + infer_sharding_from_operands=infer_sharding_from_operands, + partition=partition, + propagate_user_sharding=propagate_user_sharding, + sharding_rule='i j -> i j', + ) + + def f2(a): + return a + f(a) + + with jax.set_mesh(jtu.create_mesh((4, 2), ('x', 'y'))): + jit_f = jax.jit(f2, in_shardings=(P(None, 'x')), out_shardings=P('x')) + x = np.asarray(np.random.randint(0, 20, (32, 16)), dtype=np.float32) + self.assertArraysEqual(x + x, jit_f(jax.device_put(x, P(None, 'x')))) + + def test_custom_partitioner_sharding_override(self): + self.skip_if_custom_partitioning_not_supported() + + def partition(mesh, arg_shapes, result_shape): + def lower_fn(x): + return x + + y_shard = arg_shapes[0].sharding + return ( + mesh, + lower_fn, + NamedSharding(y_shard.mesh, P(None)), + (NamedSharding(y_shard.mesh, P(None)),), + ) + + def infer_sharding_from_operands(mesh, arg_shapes, result_shape): + y_shard = arg_shapes[0].sharding + return NamedSharding(y_shard.mesh, P('x')) + + @custom_partitioning + def f(x): + return x + + f.def_partition( + infer_sharding_from_operands=infer_sharding_from_operands, + partition=partition, + sharding_rule=SdyShardingRule(operand_mappings=((BATCHING, 'i'),), result_mappings=((BATCHING, 'i'),))) + + with jax.set_mesh(jtu.create_mesh((4, 2), ('x', 'y'))): + jit_f = jax.jit(f, in_shardings=(P(None, 'x')), out_shardings=P('x')) + x = np.asarray(np.random.randint(0, 20, (32, 16)), dtype=np.float32) + self.assertArraysEqual(x, jit_f(jax.device_put(x, P(None, 'x')))) + + def test_custom_partitioner_invalid_sharding(self): + self.skip_if_custom_partitioning_not_supported() + def partition(mesh, arg_shapes, result_shape): + def lower_fn(x): + return x + + y_shard = arg_shapes[0].sharding + return ( + mesh, + lower_fn, + NamedSharding(y_shard.mesh, P(None)), + (NamedSharding(y_shard.mesh, P(None, 'x')),), + ) + + def infer_sharding_from_operands(mesh, arg_shapes, result_shape): + y_shard = arg_shapes[0].sharding + return NamedSharding(y_shard.mesh, P('x')) + + @custom_partitioning + def f(x): + return x + + f.def_partition( + infer_sharding_from_operands=infer_sharding_from_operands, + partition=partition, + sharding_rule='i j -> i j', + ) + + with jax.set_mesh(jtu.create_mesh((4, 2), ('x', 'y'))): + jit_f = jax.jit(f, in_shardings=(P(None, 'x')), out_shardings=P('x')) + x = np.asarray(np.random.randint(0, 20, (32, 16)), dtype=np.float32) + + with self.assertRaisesRegex(Exception, 'Mismatch in result shapes.'): + jit_f(jax.device_put(x, P(None, 'x'))).block_until_ready() + + def test_custom_partitioner_jit_annotated_function(self): + """Test correct lowering of function with a @jax.jit annotated callee. + + Annotating a callee with @jax.jit results in a module with a HLO CallOp. + This test is makes sure that the custom partitioner lowering supports + CallOps. + """ + + self.skip_if_custom_partitioning_not_supported() + + @custom_partitioning + def f(x): + return x + + def partition(mesh, arg_shapes, result_shape): + def lower_fn(x): + @jax.jit + def g(y): + return y + + return g(x) + + x_shard = arg_shapes[0].sharding + return ( + mesh, + lower_fn, + NamedSharding(x_shard.mesh, P('x')), + (NamedSharding(x_shard.mesh, P('x')),), + ) + + def infer_sharding_from_operands(mesh, arg_shapes, result_shape): + x_shard = arg_shapes[0].sharding + return NamedSharding(x_shard.mesh, P('x')) + + f.def_partition( + infer_sharding_from_operands=infer_sharding_from_operands, + partition=partition, + sharding_rule='i -> i', + ) + + with jax.set_mesh(jtu.create_mesh((4,), ('x',))): + jit_f = jax.jit(f) + x = np.asarray(np.random.randint(0, 20, (32,)), dtype=np.float32) + jit_f = jax.jit(jit_f, in_shardings=(P('x')), out_shardings=P('x')) + self.assertArraysEqual(x, jit_f(jax.device_put(x, P('x')))) + + def test_custom_partitioner_with_scan(self): + self.skip_if_custom_partitioning_not_supported() + + # This is a reproducer from https://github.com/jax-ml/jax/issues/20864. + + @custom_partitioning + def f(x): + return jnp.sum(x) + + def partition(mesh, arg_shapes, result_shape): + def lower_fn(xs): + def f(carry, x): + return carry + jax.lax.psum(jnp.sum(x), axis_name='x'), None + + carry, _ = jax.lax.scan(f, 0, xs) + return carry + + result_shardings = jax.tree.map(lambda x: x.sharding, result_shape) + arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes) + return mesh, lower_fn, result_shardings, arg_shardings + + f.def_partition( + partition, + infer_sharding_from_operands=lambda mesh, *_: NamedSharding(mesh, P()), + propagate_user_sharding=lambda _, user_shape: user_shape.sharding, + sharding_rule='i j -> ') # Result is a scalar. + + with jax.set_mesh(jtu.create_mesh((4,), ('x',))): + jit_f = jax.jit(f, in_shardings=P(None, 'x')) + xs = jax.device_put(jnp.ones([32, 16]), P(None, 'x')) + self.assertEqual(jit_f(xs), xs.sum()) + + def test_custom_partitioning_no_mesh_context(self): + self.skip_if_custom_partitioning_not_supported() + + @custom_partitioning + def f(x): + return x + + def partition(mesh, arg_shapes, result_shape): + def lower_fn(x): + @jax.jit + def g(y): + return y + + return g(x) + + x_shard = arg_shapes[0].sharding + return ( + mesh, + lower_fn, + NamedSharding(x_shard.mesh, P('x')), + (NamedSharding(x_shard.mesh, P('x')),), + ) + + def infer_sharding_from_operands(mesh, arg_shapes, result_shape): + x_shard = arg_shapes[0].sharding + return NamedSharding(x_shard.mesh, P('x')) + + f.def_partition( + infer_sharding_from_operands=infer_sharding_from_operands, + partition=partition, + sharding_rule='i -> i', + ) + + mesh = jtu.create_mesh((4,), ('x',)) + x = np.asarray(np.random.randint(0, 20, (32,)), dtype=np.float32) + s = NamedSharding(mesh, P('x')) + + jit_f = jax.jit(f, in_shardings=s, out_shardings=s) + self.assertArraysEqual(x, jit_f(x)) + + def test_custom_partitioner_pytree_inputs(self): + self.skip_if_custom_partitioning_not_supported() + + def partition(mesh, arg_shapes, result_shape): + def lower_fn(xs): + x, y, z = xs + return x + y + z + + return ( + mesh, + lower_fn, + arg_shapes[0][0].sharding, + jax.tree.map(lambda x: x.sharding, arg_shapes), + ) + + def infer_sharding_from_operands(mesh, arg_shapes, result_shape): + return arg_shapes[0][0].sharding + + def propagate_user_sharding(mesh, user_shape): + return user_shape.sharding + + @custom_partitioning + def f(xs): + x, y, z = xs + return x + y + z + + f.def_partition( + infer_sharding_from_operands=infer_sharding_from_operands, + partition=partition, + propagate_user_sharding=propagate_user_sharding, + sharding_rule='i j, i j, i j -> i j', + ) + + def f2(a): + return a + f((a, a, a)) + + with jax.set_mesh(jtu.create_mesh((4, 2), ('x', 'y'))): + jit_f = jax.jit(f2, in_shardings=(P(None, 'x')), out_shardings=P('x')) + x = np.asarray(np.random.randint(0, 20, (32, 16)), dtype=np.float32) + self.assertArraysEqual(x * 4, jit_f(jax.device_put(x, P(None, 'x')))) + + @jtu.skip_on_devices('cpu') + def test_custom_partition_with_sharding_rule_callback(self): + self.skip_if_custom_partitioning_not_supported() + + def partition(static_arg0, static_arg1, mesh, arg_shapes, result_shape): + arg_shardings = jax.tree.map(lambda s: s.sharding, arg_shapes) + result_sharding = result_shape.sharding + rank = len(arg_shapes[0].shape) + + self.assertEqual(static_arg0, 1) + self.assertEqual(static_arg1, 2) + def lower_fn(x, y): + axis_name = arg_shardings[1].spec[rank-2][0] + i = jax.lax.axis_index(axis_name) + z = jax.lax.psum(jax.lax.dynamic_slice_in_dim( + jax.lax.dynamic_slice_in_dim(x, i * 0, 8, axis=rank-2), + i * 8, 8, axis=rank-1) @ y, (axis_name)) + return z + + return mesh, lower_fn, (result_sharding), arg_shardings + + def produce_sharding_rule(static_arg0, static_arg1, mesh, arg_shapes, result_shape): + self.assertEqual(static_arg0, 1) + self.assertEqual(static_arg1, 2) + rank = len(arg_shapes[0].shape) + leading_axes = "" + for i in range(rank - 2): + leading_axes += f" b{i}" + return f"{leading_axes} i j, {leading_axes} j k -> {leading_axes} i k" , dict(reduction_factors=("j",)) + + @partial(custom_partitioning, static_argnums=(2,3)) + def f(x, y, static_arg0=1, static_arg1=2): + return jnp.matmul(x, y) + + f.def_partition( + infer_sharding_from_operands=None, + partition=partition, + sharding_rule=produce_sharding_rule) + + mesh = jtu.create_mesh((4, 2), ('x', 'y')) + x = jax.device_put(np.arange(2 * 3 * 32 * 16).reshape(2, 3, 32, 16), + NamedSharding(mesh, P(None, None, 'x'))) + y = jax.device_put(np.arange(2 * 3 * 16 * 32).reshape(2, 3, 16, 32), + NamedSharding(mesh, P(None, None,'y'))) + result = jax.jit(f)(x, y) + expected_result = f(x, y) + self.assertArraysEqual(result, expected_result) + self.assertEqual(result.sharding, NamedSharding(mesh, P(None, None, 'x'))) + + def test_custom_partition_shardy_migration(self): + self.skip_if_custom_partitioning_not_supported() + + def partition(mesh, arg_shapes, result_shape): + def lower_fn(x): + return x + + return ( + mesh, + lower_fn, + arg_shapes[0].sharding, + (arg_shapes[0].sharding,), + ) + + def infer_sharding_from_operands(mesh, arg_shapes, result_shape): + return arg_shapes[0].sharding + + def propagate_user_sharding(mesh, user_shape): + return user_shape.sharding + + @custom_partitioning + def f(x): + return x + + f.def_partition( + infer_sharding_from_operands=infer_sharding_from_operands, + partition=partition, + propagate_user_sharding=propagate_user_sharding, + ) + + mesh = jtu.create_mesh((4, 2), ('x', 'y')) + x = jax.device_put(np.arange(32 * 16).reshape(32, 16), + NamedSharding(mesh, P(None, 'x'))) + with self.assertRaisesRegex( + NotImplementedError, 'provide sharding_rule to migrate to Shardy'): + jax.jit(f)(x) + + def test_custom_partitioner_reshape(self): + self.skip_if_custom_partitioning_not_supported() + + def partition(mesh, arg_shapes, result_shape): + arg_shardings = jax.tree.map(lambda s: s.sharding, arg_shapes) + result_sharding = result_shape.sharding + + def lower_fn(x, y): + return x.reshape((4,)) + y + return mesh, lower_fn, (result_sharding), arg_shardings + + @partial(custom_partitioning) + def f(x, y): + x = x.reshape((8,)) + return x + y + + f.def_partition( + infer_sharding_from_operands=None, + propagate_user_sharding=None, + partition=partition, + sharding_rule='(i k) j, (i k j) -> (i k j)', i=2, k=2, need_replication_factors=('k',)) + + mesh = jtu.create_mesh((2, 4), ('x', 'y')) + x = jax.device_put(np.arange(8).reshape(4, 2), + NamedSharding(mesh, P('x', None))) + y = jax.device_put(np.arange(8), + NamedSharding(mesh, P('x'))) + jitted_result = jax.jit(f)(x, y) + unjitted_result = f(x, y) + self.assertArraysEqual(jitted_result, unjitted_result) + self.assertEqual(jitted_result.sharding, NamedSharding(mesh, P('x'))) + +if __name__ == '__main__': + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/debug_info_test.py b/tests/debug_info_test.py index a39b53c3ad16..1f3181f67759 100644 --- a/tests/debug_info_test.py +++ b/tests/debug_info_test.py @@ -30,7 +30,7 @@ from jax.experimental import checkify import jax.experimental.custom_dce from jax.experimental import pallas as pl -from jax.experimental.shard_map import shard_map +from jax._src.shard_map import shard_map import jax.numpy as jnp import jax.scipy as jsp @@ -44,8 +44,8 @@ from jax._src import custom_transpose from jax._src import test_util as jtu from jax._src.compilation_cache import is_persistent_cache_enabled -from jax._src.lax.control_flow import for_loop from jax._src.interpreters import mlir +from jax._src import util as util import numpy as np @@ -75,12 +75,16 @@ def _debug_info_to_string(dbg: core.DebugInfo) -> list[str]: # Strip the absolute path and the line number but check that it references # this file (to catch errors when the source info points in JAX internals) func_src_info = re.sub(r"^(\S+)( at .*.debug_info_test.py:\d+)?", "\\1", dbg.func_src_info) - arg_names_str = ",".join([str(a) for a in dbg.arg_names]) + if dbg.arg_names is None: + arg_names_str = "None" + else: + arg_names_str = ",".join([str(a) for a in dbg.arg_names]) res = f"traced_for={dbg.traced_for}, fun={func_src_info}, arg_names={arg_names_str}" if isinstance(dbg.result_paths, tuple): res += f", result_paths={','.join(dbg.result_paths)}" elif dbg.result_paths is None: - res += ", result_paths=" + res += ", result_paths=None" + # Do not show the thunk return res @@ -111,6 +115,7 @@ def append(self, t: Any) -> None: @jtu.with_config(jax_mutable_array_checks=True) +@unittest.skip("WIP") class DebugInfoTest(jtu.JaxTestCase): def _check_tracers_and_jaxprs(self, traceable: Any, @@ -130,7 +135,7 @@ def _check_tracers_and_jaxprs(self, traceable: Any, mode. The debug infos in the nested Jaxprs are first converted to strings using `_debug_info_to_string` and then compared against `expected_jaxpr_debug_infos`. During this conversion, - we strip occurences of this test file name and a line number + we strip occurrences of this test file name and a line number (e.g., .*/debug_info_test.py:56) An element of `expected_jaxpr_debug_infos` can be a string, in which case it is compared by equality, or a `re.Pattern` (the result of `re.compile`) @@ -241,8 +246,8 @@ def my_f(x, y, z, w): dbg = api_util.debug_info("jit", my_f, (1, 2), dict(z=3, w=4)) self.assertRegex(dbg.func_src_info, r"^my_f at .*debug_info_test.py:\d+") self.assertEqual(dbg.func_name, "my_f") - self.assertEqual(dbg.arg_names, ("x", "y", "z", "w")) - self.assertIsNone(dbg.result_paths) + self.assertEqual(dbg.arg_names, ("x", "y", "w", "z")) + self.assertIs(dbg.result_paths, core.initial_result_paths) def test_debug_info_arg_passed_as_kwarg(self): def my_f(x, y, z): @@ -261,23 +266,29 @@ def my_f(x_tree, *, y_tree): "y_tree['w']", "y_tree['z']")) def test_debug_info_with_statics(self): - def my_f(x, y, *, z, w): + def my_f(x, z, *, w, y): pass - dbg = api_util.debug_info("jit", my_f, (1, 2), dict(z=3, w=4), + dbg = api_util.debug_info("jit", my_f, (1,), dict(y=2, z=3, w=4), static_argnums=(1,), static_argnames=("w",)) - self.assertEqual(dbg.arg_names, ("x", "z")) + self.assertEqual(dbg.arg_names, ("x", "y", "z")) def test_debug_info_with_pytrees_and_statics(self): - def my_f(x, y, *, z, w): + def my_f(x, y, *, z, w, t): pass dbg = api_util.debug_info("jit", my_f, ((1, 2), (2, 3)), - dict(z=(3, 4), w=(5, 6)), + dict(z=(3, 4), w=(5, 6), t=7), + static_argnums=(1,), + static_argnames=("w",)) + self.assertEqual(dbg.arg_names, ("x[0]", "x[1]", "t", "z[0]", "z[1]")) + + dbg = api_util.debug_info("jit", my_f, ((1, 2),), + dict(z=(3, 4), w=(5, 6), t=7, y=3), static_argnums=(1,), static_argnames=("w",)) - self.assertEqual(dbg.arg_names, ("x[0]", "x[1]", "z[0]", "z[1]")) + self.assertEqual(dbg.arg_names, ("x[0]", "x[1]", "t", "y", "z[0]", "z[1]")) def test_debug_info_too_many_args(self): def my_f(x): @@ -287,35 +298,35 @@ def my_f(x): self.assertEqual(dbg.arg_names, ('args[0]', 'args[1]', 'args[2]', "kwargs['z']")) def test_debug_info_no_source_info_built_in(self): - # built-in function "int" does not have an inspect.Signature + # built-in function "max" does not have an inspect.Signature dbg = api_util.debug_info("jit", max, (1,), {}) self.assertEqual(dbg.func_src_info, "max") + self.assertEqual(dbg.func_name, "max") + self.assertEqual(dbg.func_filename, None) + self.assertEqual(dbg.func_lineno, None) self.assertEqual(dbg.arg_names, ("args[0]",)) def test_debug_info_lambda(self): # built-in function "int" does not have an inspect.Signature dbg = api_util.debug_info("jit", lambda my_arg: False, (1,), {}) self.assertRegex(dbg.func_src_info, r"^ at .*debug_info_test.py:\d+") + self.assertEndsWith(dbg.func_filename, "debug_info_test.py") + self.assertIsNotNone(dbg.func_lineno) self.assertEqual(dbg.arg_names, ("my_arg",)) - def test_debug_info_save_wrapped_fun_source_info(self): + def test_debug_info_save_wrapped_fun_debug_info(self): def wrapper(x, y): return x - dbg = api_util.debug_info("test", wrapper, (1, 2), {}) - self.assertEqual("wrapper", dbg.func_name) - - api_util.save_wrapped_fun_sourceinfo(wrapper, lambda x, y: x) - dbg = api_util.debug_info("test", wrapper, (1, 2), {}) - self.assertEqual("", dbg.func_name) - def other_f(): pass - dbg_other = api_util.debug_info("test other", other_f, (), {}) - api_util.save_wrapped_fun_sourceinfo(wrapper, dbg_other) + dbg = api_util.debug_info("test", wrapper, (1, 2), {}) - self.assertEqual("other_f", dbg.func_name) - self.assertEqual("test", dbg.traced_for) + self.assertEqual("wrapper", dbg.func_name) + + api_util.save_wrapped_fun_debug_info(other_f, dbg) + dbg = api_util.debug_info("other", other_f, (1, 2), {}) + self.assertEqual("wrapper", dbg.func_name) def test_debug_info_no_source_info_not_callable(self): # built-in function "int" does not have an inspect.Signature @@ -380,66 +391,6 @@ def f(x): with self.assertRaisesRegex(TypeError, err_str): jax.jit(f)(jnp.int32) - @jtu.thread_unsafe_test() # logging is not thread-safe - def test_arg_names_cache_miss_explanations(self): - @jax.jit - def f(x, y): - return jnp.sin(x) * y['hi'] - - x = jnp.float32(1.) - y = {'hi': jnp.arange(3., dtype='float32')} - - expected_log_len = 1 if not is_persistent_cache_enabled() else 3 - - # print on first miss, not on hit - with config.explain_cache_misses(True): - with self.assertLogs(level='WARNING') as cm: - f(x, y) - f(x, y) - self.assertLen(cm.output, expected_log_len) - msg = cm.output[0] - self.assertIn('TRACING CACHE MISS', msg) - self.assertIn('never seen function', msg) - - # shape change - y_ = {'hi': jnp.arange(4, dtype='float32')} - with config.explain_cache_misses(True): - with self.assertLogs(level='WARNING') as cm: - f(x, y_) - self.assertLen(cm.output, expected_log_len) - msg = cm.output[0] - self.assertIn('never seen input type signature', msg) - self.assertIn('closest seen input type signature has 1 mismatches', msg) - self.assertIn('seen f32[3], but now given f32[4]', msg) - - # weak type change (assuming no x64) - if not config.enable_x64.value: - with config.explain_cache_misses(True): - with self.assertLogs(level='WARNING') as cm: - f(1., y) - self.assertLen(cm.output, expected_log_len) - msg = cm.output[0] - self.assertIn('weak_type=True', msg) - self.assertIn('https://jax.readthedocs.io/en/latest/type_promotion.html#weak-types', msg) - - # kwarg change - with config.explain_cache_misses(True): - with self.assertLogs(level='WARNING') as cm: - f(1, y=y) - self.assertLen(cm.output, expected_log_len) - msg = cm.output[0] - self.assertIn('never seen passing 1 positional args and 1 keyword args', msg) - - # tracing config change - with config.explain_cache_misses(True): - with self.assertLogs(level='WARNING') as cm: - with jax.numpy_rank_promotion('warn'): - f(x, y) - # depending on the backend, we may or may not get persistent cache warnings - self.assertTrue(1 <= len(cm.output) <= expected_log_len) - msg = cm.output[0] - self.assertIn("tracing context doesn't match", msg) - @jtu.thread_unsafe_test() # logging is not thread-safe def test_arg_names_cache_miss_explanations_new_function_in_loop(self): @jax.jit @@ -671,7 +622,7 @@ def my_g(b, d=1): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ # TODO(necula): result_paths? - "traced_for=jit, fun=my_f, arg_names=a, result_paths=", + "traced_for=jit, fun=my_f, arg_names=a, result_paths=result", "traced_for=jit, fun=my_g, arg_names=b, result_paths=result", ], expected_tracer_debug_infos=[ @@ -761,6 +712,122 @@ def f(x, y, *args, **kwargs): re.compile(r".*func.func public @main\(.*\{jax.result_info = \"result\"\}"), ]) + def test_jit_arg_names_with_out_of_order_kwargs(self): + tracer_spy = TracerSpy() + + # The shapes are different, to differentiate them easily + a1 = (np.float32(0),) # a hashable tuple, can be static + b2 = np.arange(2, dtype=np.float32) # b2 + z3 = np.arange(3, dtype=np.float32) + y4 = (np.float32(0.), np.float32(1.), np.float32(2.), np.float32(3.)) + x5 = np.arange(5, dtype=np.float32) + u6 = np.arange(6, dtype=np.float32) + t7 = np.arange(7, dtype=np.float32) + + def my_f(a1, b2, z3, y4, x5, *, u6, t7): + assert np.shape(a1[0]) == () + assert np.shape(b2) == (2,) + assert np.shape(z3) == (3,) + assert np.shape(y4) == (4,) + assert np.shape(x5) == (5,) + assert np.shape(u6) == (6,) + assert np.shape(t7) == (7,) + tracer_spy.append(b2) + tracer_spy.append(x5) + return a1[0] + b2[0] + z3[0] + y4[0] + x5[0] + u6[0] + t7[0] + + self._check_tracers_and_jaxprs( + jax.jit(my_f, static_argnums=(0,), static_argnames=("y4",)), + # Some positional args passed as keyword + a1, b2, x5=x5, y4=y4, z3=z3, t7=t7, u6=u6, + expected_jaxpr_debug_infos=[ + "traced_for=jit, fun=my_f, arg_names=b2,t7,u6,x5,z3, result_paths=result", + ], + tracer_spy=tracer_spy, + expected_tracer_debug_infos=[ + "traced_for=jit, fun=my_f, arg_names=b2,t7,u6,x5,z3, from b2", + "traced_for=jit, fun=my_f, arg_names=b2,t7,u6,x5,z3, from x5", + ], + expected_lowering_lines=[ + re.compile(r".*func.func public @main\(%arg0: tensor<2xf..> loc\(\"b2\"\)"), + re.compile(r".*func.func public @main\(.*%arg1: tensor<7xf..> loc\(\"t7\"\)"), + re.compile(r".*func.func public @main\(.*%arg2: tensor<6xf..> loc\(\"u6\"\)"), + re.compile(r".*func.func public @main\(.*%arg3: tensor<5xf..> loc\(\"x5\"\)"), + ] + ) + + tracer_spy.tracers = [] + util.clear_all_caches() + self._check_tracers_and_jaxprs( + jax.jit(my_f, static_argnames=("y4",)), + # Positional argument y4 is static and passed by kwarg + a1, b2, z3, x5=x5, y4=y4, t7=t7, u6=u6, + expected_jaxpr_debug_infos=[ + "traced_for=jit, fun=my_f, arg_names=a1[0],b2,z3,t7,u6,x5, result_paths=result", + ], + tracer_spy=tracer_spy, + expected_tracer_debug_infos=[ + "traced_for=jit, fun=my_f, arg_names=a1[0],b2,z3,t7,u6,x5, from b2", + "traced_for=jit, fun=my_f, arg_names=a1[0],b2,z3,t7,u6,x5, from x5", + ], + expected_lowering_lines=[ + re.compile(r".*func.func public @main\(%arg0: tensor loc\(\"a1\[0\]\"\)"), + re.compile(r".*func.func public @main\(.*%arg1: tensor<2xf..> loc\(\"b2\"\)"), + re.compile(r".*func.func public @main\(.*%arg2: tensor<3xf..> loc\(\"z3\"\)"), + re.compile(r".*func.func public @main\(.*%arg3: tensor<7xf..> loc\(\"t7\"\)"), + re.compile(r".*func.func public @main\(.*%arg4: tensor<6xf..> loc\(\"u6\"\)"), + re.compile(r".*func.func public @main\(.*%arg5: tensor<5xf..> loc\(\"x5\"\)"), + ] + ) + + tracer_spy.tracers = [] + util.clear_all_caches() + self._check_tracers_and_jaxprs( + jax.jit(my_f, static_argnames=("y4",)), + # Positional argument y4 is static (declared as static_argnames) + a1, b2, z3, y4, x5=x5, t7=t7, u6=u6, + expected_jaxpr_debug_infos=[ + "traced_for=jit, fun=my_f, arg_names=a1[0],b2,z3,t7,u6,x5, result_paths=result", + ], + tracer_spy=tracer_spy, + expected_tracer_debug_infos=[ + "traced_for=jit, fun=my_f, arg_names=a1[0],b2,z3,t7,u6,x5, from b2", + "traced_for=jit, fun=my_f, arg_names=a1[0],b2,z3,t7,u6,x5, from x5", + ], + expected_lowering_lines=[ + re.compile(r".*func.func public @main\(%arg0: tensor loc\(\"a1\[0\]\"\)"), + re.compile(r".*func.func public @main\(.*%arg1: tensor<2xf..> loc\(\"b2\"\)"), + re.compile(r".*func.func public @main\(.*%arg2: tensor<3xf..> loc\(\"z3\"\)"), + re.compile(r".*func.func public @main\(.*%arg3: tensor<7xf..> loc\(\"t7\"\)"), + re.compile(r".*func.func public @main\(.*%arg4: tensor<6xf..> loc\(\"u6\"\)"), + re.compile(r".*func.func public @main\(.*%arg5: tensor<5xf..> loc\(\"x5\"\)"), + ] + ) + + tracer_spy.tracers = [] + util.clear_all_caches() + self._check_tracers_and_jaxprs( + jax.jit(my_f, static_argnums=(3,)), + # Positional argument y4 is static (declared as static_argnums) + a1, b2, z3, y4, x5=x5, t7=t7, u6=u6, + expected_jaxpr_debug_infos=[ + "traced_for=jit, fun=my_f, arg_names=a1[0],b2,z3,t7,u6,x5, result_paths=result", + ], + tracer_spy=tracer_spy, + expected_tracer_debug_infos=[ + "traced_for=jit, fun=my_f, arg_names=a1[0],b2,z3,t7,u6,x5, from b2", + "traced_for=jit, fun=my_f, arg_names=a1[0],b2,z3,t7,u6,x5, from x5", + ], + expected_lowering_lines=[ + re.compile(r".*func.func public @main\(%arg0: tensor loc\(\"a1\[0\]\"\)"), + re.compile(r".*func.func public @main\(.*%arg1: tensor<2xf..> loc\(\"b2\"\)"), + re.compile(r".*func.func public @main\(.*%arg2: tensor<3xf..> loc\(\"z3\"\)"), + re.compile(r".*func.func public @main\(.*%arg3: tensor<7xf..> loc\(\"t7\"\)"), + re.compile(r".*func.func public @main\(.*%arg4: tensor<6xf..> loc\(\"u6\"\)"), + re.compile(r".*func.func public @main\(.*%arg5: tensor<5xf..> loc\(\"x5\"\)"), + ] + ) + def test_jit_result_info(self): def f(x, y, z): return {'a': x, 'b': [y]} @@ -793,9 +860,9 @@ def my_g(u, v): 2, 3, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=my_f, arg_names=x,y, result_paths=result", - "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=result['c']" - ], + "traced_for=jit, fun=my_f, arg_names=x,y, result_paths=result", + "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=result['c'],result['d']", + ], expected_tracer_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x,y, from x", "traced_for=jit, fun=my_g, arg_names=u,v, from u" @@ -803,23 +870,36 @@ def my_g(u, v): def test_nested_jit_with_const_and_unused_args(self): def my_f(x, y): # y is unused - def my_g(u, v): # v is unused + def my_g(u, v): # u is unused return v + np.ones(v.shape, v.dtype) return x + jax.jit(my_g)(y, x) x = y = np.ones((8,), dtype=np.float32) + expected_jaxpr_debug_infos = [ + "traced_for=jit, fun=my_f, arg_names=x,y, result_paths=result", + ] + if config.use_simplified_jaxpr_constants.value: + # TODO(necula): remove the conditional + expected_jaxpr_debug_infos.extend([ + "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=result", + ]) + expected_lowering_lines = [ + re.compile(r".*func.func public @main\(%arg0: tensor<8xf..> {jax.const = true} loc\(unknown\)"), + re.compile(r".*func.func public @main\(.*%arg1: tensor<8xf..> loc\(\"x\"\)"), + ] + else: + expected_jaxpr_debug_infos.extend([ + "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=result", + ]) + expected_lowering_lines = [ + re.compile(r".*func.func public @main\(%arg0: tensor<8xf..> loc\(\"x\"\)\)"), + ] self._check_tracers_and_jaxprs( jax.jit(my_f), x, y, - expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=my_f, arg_names=x,y, result_paths=result", - "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=result", - ], - expected_lowering_lines=[ - re.compile(r".*func.func public @main\(%arg0: tensor<8xf..> loc\(\"x\"\)\)"), - re.compile(r".*call @my_g\(%arg.\) : \(tensor<8xf..>\)"), - ] + expected_jaxpr_debug_infos=expected_jaxpr_debug_infos, + expected_lowering_lines=expected_lowering_lines ) def test_jvp_of_jit(self): @@ -831,8 +911,7 @@ def f(x, y, z): lambda x, y, z: jax.jvp(jax.jit(f), (x, y, z), (x, y, z)), jnp.float32(1.), (jnp.float32(2.),), [jnp.float32(3.)], expected_jaxpr_debug_infos=[ - # TODO(necula): arg_names, result_paths - "traced_for=jit, fun=f, arg_names=,,,, result_paths=,,,", + "traced_for=jit, fun=f, arg_names=None, result_paths=None", ], tracer_spy=tracer_spy, expected_tracer_debug_infos=[ @@ -876,6 +955,7 @@ def my_f(x, y, z): re.compile(r".*func.func public @main\(.*-> \(tensor {jax.result_info = \"\"}"), ]) + @unittest.skip("testing for incorrect debug info (pjit transpose)") def test_vjp_of_nested_jit(self): tracer_spy = TracerSpy() def my_f(x, y): @@ -886,20 +966,16 @@ def my_g(u, v): return dict(c=u * v, d=v) return jax.jit(my_g)(y, x)["c"] + expected_jaxpr_debug_infos = [ + "traced_for=jit, fun=, arg_names=x,y,res_ct, result_paths=result[0],result[1]", + "traced_for=jit, fun=my_g, arg_names=None, result_paths=None", + ] self._check_tracers_and_jaxprs( jax.jit(lambda x, y, res_ct: jax.vjp(my_f, x, y)[1](res_ct)), 2., 3., 0.3, tracer_spy=tracer_spy, - expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=, arg_names=x,y,res_ct, result_paths=result[0],result[1]", - # TODO(necula): result_paths - "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=", - # TODO(necula): arg_names - "traced_for=jit, fun=my_g, arg_names=u,v,,, result_paths=," - if config.use_direct_linearize.value else - "traced_for=jit, fun=my_g, arg_names=,,u,v, result_paths=result['c'],result['d']", - ], + expected_jaxpr_debug_infos=expected_jaxpr_debug_infos, expected_tracer_debug_infos=[ # TODO(necula): missing debug info "None", @@ -913,6 +989,7 @@ def my_g(u, v): re.compile(r".*func.func public @main\(.*jax.result_info = \"result\[1\]\"}"), ]) + @unittest.skip("test fails despite looking like it matches...") def test_vjp_remat(self): tracer_spy = TracerSpy() def apply_fn(inp): @@ -1029,10 +1106,10 @@ def to_diff(x): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=to_diff, arg_names=x['a'], result_paths=result['a']", - "traced_for=custom_vjp fun, fun=my_f, arg_names=x['a'], result_paths=result['b']", + "traced_for=custom_vjp fun, fun=my_f, arg_names=None, result_paths=None", ], expected_tracer_debug_infos=[ - "traced_for=custom_vjp fun, fun=my_f, arg_names=x['a'], from x['a']", + "traced_for=custom_vjp fun, fun=my_f, arg_names=None, result_paths=None, from unknown", # TODO(necula): from None? "traced_for=jit, fun=to_diff, arg_names=x['a'], from None", "traced_for=jit, fun=to_diff, arg_names=x['a'], from x['a']", @@ -1059,11 +1136,11 @@ def app_rev(f, cos_x0, g): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=, arg_names=xy[0],xy[1], result_paths=result[0],result[1]", - "traced_for=custom_vjp fun, fun=app, arg_names=xy[0],xy[1], result_paths=result", + "traced_for=custom_vjp fun, fun=app, arg_names=None, result_paths=None", ], expected_tracer_debug_infos=[ "traced_for=jit, fun=, arg_names=xy[0],xy[1], from xy[0]", - "traced_for=custom_vjp fun, fun=app, arg_names=xy[0],xy[1], from xy[0]", + "traced_for=custom_vjp fun, fun=app, arg_names=None, result_paths=None, from unknown", # TODO(necula): from None "traced_for=jit, fun=, arg_names=xy[0],xy[1], from None", ]) @@ -1113,8 +1190,8 @@ def fn_tp(r, t): x, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=cond, fun=my_f, arg_names=x['c'], result_paths=result", - "traced_for=cond, fun=, arg_names=x['c'], result_paths=result", + "traced_for=cond, fun=my_f, arg_names=None, result_paths=None", + "traced_for=cond, fun=, arg_names=None, result_paths=None", "traced_for=jit, fun=, arg_names=x, result_paths=result[0][0][0],result[0][0][1]", ], expected_tracer_debug_infos=[ @@ -1145,15 +1222,14 @@ def fn_tp(r, t): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=, arg_names=x, result_paths=result[0]['c']", - "traced_for=linear_call fun, fun=fn, arg_names=r,x['c'], result_paths=result['b']", - "traced_for=linear_call fun_transpose, fun=fn_tp, arg_names=r,t['c'], result_paths=result['c']", + "traced_for=linear_call fun_transpose, fun=fn_tp, arg_names=None, result_paths=None", ], expected_tracer_debug_infos=[ # TODO(necula): from None? "traced_for=jit, fun=, arg_names=x, from None", "traced_for=linear_call fun, fun=fn, arg_names=r,x['c'], from r", "traced_for=linear_call fun, fun=fn, arg_names=r,x['c'], from x['c']", - "traced_for=linear_call fun_transpose, fun=fn_tp, arg_names=r,t['c'], from t['c']", + "traced_for=linear_call fun_transpose, fun=fn_tp, arg_names=None, result_paths=None, from unknown", ]), @@ -1273,12 +1349,9 @@ def my_g(x, y): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x,y, result_paths=result", - # TODO(necula): arg_names? result_paths? - "traced_for=cond, fun=my_true_branch, arg_names=, result_paths=,", - "traced_for=cond, fun=my_false_branch, arg_names=, result_paths=,", - "traced_for=cond, fun=my_true_branch, arg_names=a,b, result_paths=result[0],result[1]", - "traced_for=cond, fun=my_false_branch, arg_names=c,d, result_paths=result[0],result[1]", - "traced_for=checkpoint / remat, fun=my_g, arg_names=,, result_paths=,", + "traced_for=cond, fun=my_true_branch, arg_names=None, result_paths=None", + "traced_for=cond, fun=my_false_branch, arg_names=None, result_paths=None", + "traced_for=checkpoint / remat, fun=my_g, arg_names=None, result_paths=None", ], expected_tracer_debug_infos=[ "traced_for=cond, fun=my_true_branch, arg_names=a,b, from a", @@ -1288,10 +1361,11 @@ def my_g(x, y): ]) def test_grad_scan(self): - # Based on control_flow_test:testScanHigherOrderDifferentiation + # Based on control_flow_test:testScanHigherOrderDifferentiation tracer_spy = TracerSpy() - def f(c, a): + def f(c, a): # c: f32, a: f32[2] tracer_spy.append(c) + tracer_spy.append(a) d = 0.75 b = jnp.sin(c * jnp.sum(jnp.cos(d * a))) c = 0.9 * jnp.cos(d * jnp.sum(jnp.sin(c * a))) @@ -1301,50 +1375,55 @@ def f(c, a): c = jnp.array(1, dtype=as_.dtype) @jax.jit - def my_f(x, as_): + def my_f(x, as_): # x: f32, as_: f32[3, 2] tracer_spy.append(x) def to_remat(a, b): - return for_loop.scan(f, a, b) - return jax.remat(to_remat)(c, as_) + return lax.scan(f, a, b) + return jax.remat(to_remat)(c, as_) # c is closed-over - def the_grad(c, as_): + def the_grad(c, as_): # c: f32[], as_: f32[3, 2], tracer_spy.append(c) _, pullback = jax.vjp(my_f, c, as_) return pullback((c, np.arange(3, dtype=c.dtype))) + expected_jaxpr_debug_infos = [ + "traced_for=jit, fun=the_grad, arg_names=c,as_, result_paths=result[0],result[1]", + "traced_for=checkpoint / remat, fun=to_remat, arg_names=None, result_paths=None", + "traced_for=scan, fun=f, arg_names=None, result_paths=None", + "traced_for=jit, fun=my_f, arg_names=None, result_paths=None", + ] + if config.use_simplified_jaxpr_constants.value: + expected_jaxpr_debug_infos.extend([ + "traced_for=jit, fun=my_f, arg_names=None, result_paths=None", + ]) + + if config.use_simplified_jaxpr_constants.value: + expected_lowering_lines = [ + re.compile(r".*func.func public @main\(%arg0: tensor<3xf..> {jax.const = true} loc\(unknown\)"), + re.compile(r".*func.func public @main\(.*, %arg1: tensor loc\(\"c\"\)"), + re.compile(r".*func.func public @main\(.*, %arg2: tensor<3x2xf..> loc\(\"as_\"\)"), + re.compile(r".*func.func public @main\(.* -> .*tensor {jax.result_info = \"result\[0\]\""), + re.compile(r".*func.func public @main\(.* -> .*tensor<3x2xf..> {jax.result_info = \"result\[1\]\""), + ] + else: + expected_lowering_lines = [ + re.compile(r".*func.func public @main\(%arg0: tensor loc\(\"c\"\)"), + re.compile(r".*func.func public @main\(.*, %arg1: tensor<3x2xf..> loc\(\"as_\"\)"), + re.compile(r".*func.func public @main\(.* -> .*tensor {jax.result_info = \"result\[0\]\""), + re.compile(r".*func.func public @main\(.* -> .*tensor<3x2xf..> {jax.result_info = \"result\[1\]\""), + ] self._check_tracers_and_jaxprs( jax.jit(the_grad), c, as_, tracer_spy=tracer_spy, - expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=the_grad, arg_names=c,as_, result_paths=result[0],result[1]", - # TODO(necula): arg names, bad result paths - "traced_for=jit, fun=my_f, arg_names=x,as_, result_paths=,,", - "traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2], result_paths=", - "traced_for=for_loop, fun=f, arg_names=,,, result_paths=,", - "traced_for=for_loop, fun=f, arg_names=,,,,,, result_paths=,", - "traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,, result_paths=", - "traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,,,,,, result_paths=,", - "traced_for=checkpoint / remat, fun=to_remat, arg_names=,,, result_paths=,", - "traced_for=jit, fun=my_f, arg_names=as_,,, result_paths=" - if config.use_direct_linearize.value else - "traced_for=jit, fun=my_f, arg_names=,,x,as_, result_paths=", - ], + expected_jaxpr_debug_infos=expected_jaxpr_debug_infos, expected_tracer_debug_infos=[ "traced_for=jit, fun=the_grad, arg_names=c,as_, from c", "traced_for=scan, fun=f, arg_names=c,a, from c", + "traced_for=scan, fun=f, arg_names=c,a, from a", "traced_for=jit, fun=my_f, arg_names=x,as_, from x", - # TODO(necula): arg_names, and "from x" - "traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2], from refs[0]", ], - expected_lowering_lines=[ - re.compile(r".*func.func public @main\(%arg0: tensor loc\(\"c\"\)"), - re.compile(r".*func.func public @main\(.*, %arg1: tensor<3x2xf..> loc\(\"as_\"\)"), - re.compile(r".*func.func public @main\(.* -> .*tensor {jax.result_info = \"result\[0\]\""), - re.compile(r".*func.func public @main\(.* -> .*tensor<3x2xf..> {jax.result_info = \"result\[1\]\""), - # TODO(necula): unnamed function? - re.compile(r".*func.func private @None"), - ]) + expected_lowering_lines=expected_lowering_lines) def test_while_loop(self): tracer_spy = TracerSpy() @@ -1386,13 +1465,11 @@ def my_body(_, c): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=, arg_names=x, result_paths=result", - # TODO(necula): bad arg_names, result_paths - "traced_for=scan, fun=my_body, arg_names=loop_carry[0],loop_carry[1], result_paths=result[0][0],result[0][1]", + "traced_for=fori_loop, fun=my_body, arg_names=_,c, result_paths=result[0][0],result[0][1]", ], expected_tracer_debug_infos=[ - # TODO(necula): the arg_names are not right - "traced_for=scan, fun=my_body, arg_names=loop_carry[0],loop_carry[1], from loop_carry[1]", + "traced_for=fori_loop, fun=my_body, arg_names=_,c, result_paths=None, from c", ] ) @@ -1405,13 +1482,12 @@ def my_body(_, c): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=, arg_names=ub,x, result_paths=result", + # The fori_cond fun is entire manufactured internally re.compile(r"traced_for=while_cond, fun=_fori_cond_fun at .*loops.py:.*, arg_names=loop_carry\[0\],loop_carry\[1\],loop_carry\[2\], result_paths="), - # TODO(necula): arg_names and result_paths are not right - "traced_for=while_body, fun=my_body, arg_names=loop_carry[0],loop_carry[1],loop_carry[2], result_paths=result[0],result[1],result[2]", + "traced_for=fori_loop, fun=my_body, arg_names=_,,c, result_paths=result[0],result[1],result[2]", ], expected_tracer_debug_infos=[ - # TODO(necula): the arg_names are not right - "traced_for=while_body, fun=my_body, arg_names=loop_carry[0],loop_carry[1],loop_carry[2], from loop_carry[2]", + "traced_for=fori_loop, fun=my_body, arg_names=_,,c, from c", ]) def test_scan(self): @@ -1467,7 +1543,7 @@ def my_g(u, v): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x,y, result_paths=result", - "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=result['c']", + "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=result['c'],result['d']", ], expected_tracer_debug_infos=[ # TODO(necula): missing debug info @@ -1481,49 +1557,91 @@ def my_f(x): tracer_spy.append(x) return jnp.sin(x) + if config.pmap_shmap_merge.value: + expected_jaxpr_debug_infos = [ + # TODO(necula): We should not include `call_wrapped` in the debug info. + re.compile(r"traced_for=jit, fun=call_wrapped at .*jax._src.linear_util.py:.*, arg_names=args\[0\], result_paths=result\[0\]"), + re.compile(r"traced_for=shard_map, fun=call_wrapped at .*jax._src.linear_util.py:.*, arg_names=args\[0\], result_paths=result\[0\]"), + ] + expected_tracer_debug_infos = [ + re.compile(r"traced_for=shard_map, fun=call_wrapped at .*jax._src.linear_util.py:.*, arg_names=args\[0\], from args\[0\]"), + ] + else: + expected_jaxpr_debug_infos = [ + "traced_for=pmap, fun=my_f, arg_names=None, result_paths=None", + ] + expected_tracer_debug_infos = [ + "traced_for=pmap, fun=my_f, arg_names=None, result_paths=None, from unknown" + ] + self._check_tracers_and_jaxprs( jax.pmap(my_f), np.ones((jax.device_count(),), dtype=np.float32), - expected_jaxpr_debug_infos=[ - "traced_for=pmap, fun=my_f, arg_names=x, result_paths=result" - ], + expected_jaxpr_debug_infos=expected_jaxpr_debug_infos, tracer_spy=tracer_spy, - expected_tracer_debug_infos=[ - "traced_for=pmap, fun=my_f, arg_names=x, from x" - ], + expected_tracer_debug_infos=expected_tracer_debug_infos, ) def test_pmap_with_arg_and_result_names(self): tracer_spy = TracerSpy() - x = np.ones((jax.device_count(),), dtype=np.float32) - def my_f(x, y, *args, a, **kwargs): - # y and kwargs[c] is dead - tracer_spy.append(args[1]) - s = x + a + args[1] + kwargs["d"] - return dict(u=s, v=x) + # Use different shapes arguments to distinguish them in the HLO + def my_f(x0, y1, *args, b4, **kwargs): + assert np.shape(x0) == () + assert np.shape(y1) == (1,) + assert np.shape(args[0]) == (2,) + assert np.shape(args[1]) == (3,) + assert np.shape(b4) == (4,) + assert np.shape(kwargs["a5"]) == (5,) + assert np.shape(kwargs["c6"]) == (6,) + # kwargs[b5] is dead + tracer_spy.append(args[1]) + tracer_spy.append(b4) + tracer_spy.append(kwargs["c6"]) + s0 = x0 + y1[0] + b4[0] + args[1][0] + kwargs["c6"][0] + return dict(v1=jnp.broadcast_to(s0, (1,)), u0=s0) + + if config.pmap_shmap_merge.value: + expected_jaxpr_debug_infos = [ + # TODO(necula): We should not include `call_wrapped` in the debug info. + re.compile(r"traced_for=jit, fun=call_wrapped at .*jax._src.linear_util.py:.*, arg_names=(args\[\d+\],)+ result_paths=result\[0\],result\[1\]"), + re.compile(r"traced_for=shard_map, fun=call_wrapped at .*jax._src.linear_util.py:.*, arg_names=(args\[\d+\],)+ result_paths=result\[0\],result\[1\]"), + ] + expected_tracer_debug_infos = [ + re.compile(r"traced_for=shard_map, fun=call_wrapped at .*jax._src.linear_util.py:.*, arg_names=(args\[\d+\],)+ from args\[\d+\]"), + ] + expected_lowering_lines = [] + else: + expected_jaxpr_debug_infos = [ + "traced_for=pmap, fun=my_f, arg_names=None, result_paths=None", + ] + expected_tracer_debug_infos = [ + "traced_for=pmap, fun=my_f, arg_names=None, result_paths=None, from unknown", + ] + expected_lowering_lines = [ + re.compile(r".*func.func public @main\(.*%arg0: tensor<1x1xf..> loc\(unknown\)"), + re.compile(r".*func.func public @main\(.*%arg1: tensor<1x2xf..> loc\(unknown\)"), + re.compile(r".*func.func public @main\(.*%arg2: tensor<1x3xf..> loc\(unknown\)"), + re.compile(r".*func.func public @main\(.*%arg3: tensor<1x5xf..> loc\(unknown\)"), + re.compile(r".*func.func public @main\(.*%arg4: tensor<1x4xf..> loc\(unknown\)"), + re.compile(r".*func.func public @main\(.*%arg5: tensor<1x6xf..> loc\(unknown\)"), + re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"\"\}"), + re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"\"\}"), + ] self._check_tracers_and_jaxprs( jax.pmap(my_f, static_broadcasted_argnums=(0,)), - 1., x, x, x, # x, y, args[0], args[1] - d=x, a=x, b=x, # kwargs - expected_jaxpr_debug_infos=[ - "traced_for=pmap, fun=my_f, arg_names=y,args[0],args[1],a,kwargs['b'],kwargs['d'], result_paths=result['u'],result['v']", - ], + 1., # x0 + np.ones((jax.device_count(), 1), dtype=np.float32), # y1 + np.ones((jax.device_count(), 2), dtype=np.float32), # args[0] + np.ones((jax.device_count(), 3), dtype=np.float32), # args[1] + b4=np.ones((jax.device_count(), 4), dtype=np.float32), + a5=np.ones((jax.device_count(), 5), dtype=np.float32), + c6=np.ones((jax.device_count(), 6), dtype=np.float32), + + expected_jaxpr_debug_infos=expected_jaxpr_debug_infos, tracer_spy=tracer_spy, - expected_tracer_debug_infos=[ - "traced_for=pmap, fun=my_f, arg_names=y,args[0],args[1],a,kwargs['b'],kwargs['d'], from args[1]", - ], - expected_lowering_lines=[ - # TODO(necula): we did not DCE y? - re.compile(r".*func.func public @main\(.*%arg0: tensor<1xf..> loc\(\"y\"\)"), - re.compile(r".*func.func public @main\(.*%arg1: tensor<1xf..> loc\(\"args\[0\]\"\)"), - re.compile(r".*func.func public @main\(.*%arg2: tensor<1xf..> loc\(\"args\[1\]\"\)"), - re.compile(r".*func.func public @main\(.*%arg3: tensor<1xf..> loc\(\"a\"\)"), - re.compile(r".*func.func public @main\(.*%arg4: tensor<1xf..> loc\(\"kwargs\['b'\]\"\)"), - re.compile(r".*func.func public @main\(.*%arg5: tensor<1xf..> loc\(\"kwargs\['d'\]\"\)"), - re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"result\['u'\]\"\}"), - re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"result\['v'\]\"\}"), - ] + expected_tracer_debug_infos=expected_tracer_debug_infos, + expected_lowering_lines=expected_lowering_lines, ) def test_pmap_of_grad(self): @@ -1532,20 +1650,31 @@ def my_f(x): tracer_spy.append(x) return jnp.sin(x) + if config.pmap_shmap_merge.value: + expected_jaxpr_debug_infos = [ + # TODO(necula): We should not include `call_wrapped` in the debug info. + re.compile(r"traced_for=jit, fun=call_wrapped at .*jax._src.linear_util.py:.*, arg_names=args\[0\], result_paths=result\[0\]"), + re.compile(r"traced_for=shard_map, fun=call_wrapped at .*jax._src.linear_util.py:.*, arg_names=args\[0\], result_paths=result\[0\]"), + ] + else: + expected_jaxpr_debug_infos = [ + "traced_for=pmap, fun=my_f, arg_names=None, result_paths=None", + ] self._check_tracers_and_jaxprs( jax.pmap(jax.grad(my_f)), np.ones((jax.device_count(),), dtype=np.float32), - expected_jaxpr_debug_infos=[ - "traced_for=pmap, fun=my_f, arg_names=x, result_paths=result", - ], + expected_jaxpr_debug_infos=expected_jaxpr_debug_infos, tracer_spy=tracer_spy, expected_tracer_debug_infos=[ # TODO(necula): missing debug_info - 'None' + "None" ], ) def test_jvp_pmap_eager(self): + if config.pmap_shmap_merge.value: + self.skipTest("TODO(dsuo): How to check for tpu version?") + tracer_spy = TracerSpy() def my_f(x, y, *args): # y is dead, x is static broadcasted @@ -1556,24 +1685,42 @@ def my_f(x, y, *args): x = jnp.ones((jax.device_count(), 1), dtype=np.float32) x_tan = jnp.full_like(x, .1) + if config.pmap_shmap_merge.value: + expected_jaxpr_debug_infos = [ + # TODO(necula): We should not include `call_wrapped` in the debug info. + re.compile(r"traced_for=jit, fun=call_wrapped at .*jax._src.linear_util.py:.*, arg_names=,,,, result_paths=,,,"), + re.compile(r"traced_for=shard_map, fun=call_wrapped at .*jax._src.linear_util.py:.*, arg_names=,,,, result_paths=,,,"), + ] + # TODO(dsuo): Need to add tpu_pjrt_c_api and tpu_v3 to this conditional. + if jtu.test_device_matches(["cpu"]): + expected_jaxpr_debug_infos.extend([ + re.compile(r"traced_for=jit, fun=dynamic_slice at .*jax._src.dispatch.py:.*, arg_names=(args\[\d+\],)+ result_paths=result"), + ]) + expected_tracer_debug_infos = [ + re.compile(r"traced_for=shard_map, fun=call_wrapped at .*jax._src.linear_util.py:.*, arg_names=(args\[\d+\],)+ from args\[\d+\]"), + ] + else: + expected_jaxpr_debug_infos = [ + # TODO(necula): why this? + re.compile(r"traced_for=jit, fun=_multi_slice at .*array_methods.py:.*, arg_names=self, result_paths=.*"), + "traced_for=pmap, fun=my_f, arg_names=None, result_paths=None", + ] + expected_tracer_debug_infos = [ + # TODO(necula): missing debug_info + "None" + ] + self._check_tracers_and_jaxprs( lambda x, x_tan: jax.jvp(jax.pmap(my_f), (x, x, x, x), (x_tan, x_tan, x_tan, x_tan)), x, x_tan, - expected_jaxpr_debug_infos=[ - # TODO(necula): why this? - re.compile(r'traced_for=jit, fun=_multi_slice at .*array_methods.py:.*, arg_names=self, result_paths=.*'), - "traced_for=pmap, fun=my_f, arg_names=x,y,args[0],args[1], result_paths=result['u'],result['v']", - ], + expected_jaxpr_debug_infos=expected_jaxpr_debug_infos, tracer_spy=tracer_spy, - expected_tracer_debug_infos=[ - # TODO(necula): missing debug_info - "None" - ], + expected_tracer_debug_infos=expected_tracer_debug_infos, ) @jtu.ignore_warning(category=UserWarning, - message=".* jitted function .* includes a pmap") + message=".* function .* includes a pmap") def test_jvp_pmap(self): tracer_spy = TracerSpy() def my_f(x, y): @@ -1583,20 +1730,35 @@ def my_f(x, y): x = np.ones((jax.device_count(), 1), dtype=np.float32) x_tan = np.full_like(x, .1) + if config.pmap_shmap_merge.value: + expected_jaxpr_debug_infos = [ + # TODO(necula): We should not include `call_wrapped` in the debug info. + re.compile(r"traced_for=jit, fun=call_wrapped at .*jax._src.linear_util.py:.*, arg_names=None, result_paths=None"), + r"traced_for=jit, fun=, arg_names=x,x_tan, result_paths=result[0],result[1]", + re.compile(r"traced_for=shard_map, fun=call_wrapped at .*jax._src.linear_util.py:.*, arg_names=None, result_paths=None"), + ] + expected_tracer_debug_infos = [ + re.compile(r"traced_for=shard_map, fun=call_wrapped at .*jax._src.linear_util.py:.*, arg_names=args\[0\],args\[1\], from args\[0\]"), + ] + else: + expected_jaxpr_debug_infos = [ + "traced_for=jit, fun=, arg_names=x,x_tan, result_paths=result[0],result[1]", + "traced_for=pmap, fun=my_f, arg_names=None, result_paths=None", + ] + expected_tracer_debug_infos = [ + # TODO(necula): missing debug_info + "None" + ] + self._check_tracers_and_jaxprs( jax.jit(lambda x, x_tan: jax.jvp(jax.pmap(my_f), (x, x), (x_tan, x_tan))), x, x_tan, - expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=, arg_names=x,x_tan, result_paths=result[0],result[1]", - "traced_for=pmap, fun=my_f, arg_names=x,y, result_paths=result", - ], + expected_jaxpr_debug_infos=expected_jaxpr_debug_infos, tracer_spy=tracer_spy, - expected_tracer_debug_infos=[ - # TODO(necula): missing debug_info - "None" - ], + expected_tracer_debug_infos=expected_tracer_debug_infos, ) + @unittest.skip("testing for incorrect debug info (pjit transpose)") def test_hessian(self): tracer_spy = TracerSpy() @@ -1606,24 +1768,22 @@ def my_f(x): x = jax.random.uniform(jax.random.key(0), shape=(8, 4)) + expected_jaxpr_debug_infos = [ + "traced_for=jit, fun=my_f, arg_names=x, result_paths=result", + "traced_for=jit, fun=my_f, arg_names=None, result_paths=None", + ] + self._check_tracers_and_jaxprs( jax.jit(jax.hessian(jax.jit(my_f))), x, - expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=my_f, arg_names=x, result_paths=result", - # TODO(necula): arg_names and result_paths? - "traced_for=jit, fun=my_f, arg_names=x, result_paths=,,,", - "traced_for=jit, fun=my_f, arg_names=x,, result_paths=," - if config.use_direct_linearize.value else - "traced_for=jit, fun=my_f, arg_names=,x, result_paths=,", - ], + expected_jaxpr_debug_infos=expected_jaxpr_debug_infos, tracer_spy=tracer_spy, expected_tracer_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x, from x", ], ) - (x).block_until_ready() + x.block_until_ready() def test_remat(self): tracer_spy = TracerSpy() @@ -1664,8 +1824,7 @@ def my_g(y): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x, result_paths=result", - # TODO(necula): arg_names? result_paths? - "traced_for=checkpoint / remat, fun=my_g, arg_names=,, result_paths=", + "traced_for=checkpoint / remat, fun=my_g, arg_names=None, result_paths=None", ], expected_tracer_debug_infos=[ "traced_for=checkpoint / remat, fun=my_g, arg_names=y, from y", @@ -1690,14 +1849,12 @@ def my_f(x): jnp.arange(2, dtype=np.float32), tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - # TODO(necula): arg_names, result_paths "traced_for=jit, fun=, arg_names=x, result_paths=result", - "traced_for=checkpoint / remat, fun=my_f, arg_names=,, result_paths=", - "traced_for=shard_map, fun=my_f, arg_names=x, result_paths=result", - "traced_for=shard_map, fun=my_f, arg_names=,, result_paths=", + "traced_for=checkpoint / remat, fun=my_f, arg_names=None, result_paths=None", + "traced_for=shard_map, fun=my_f, arg_names=None, result_paths=None", ], expected_tracer_debug_infos=[ - "None" # TODO(necula): missing + "traced_for=shard_map, fun=my_f, arg_names=x, from x" ]) def test_remat_saved_residuals(self): @@ -1713,6 +1870,7 @@ def my_f(x, y): self.assertEqual(res[0][1], "from the argument x") self.assertRegex(res[1][1], r"named 'foo' from .*debug_info_test.py:.*my_f") + @unittest.skip("Test fails during no-thunks rewrite") def test_checkify_pmap_basic(self): if len(jax.devices()) < 2: self.skipTest("requires at least 2 devices") @@ -1724,19 +1882,33 @@ def my_f(my_x): y2 = jnp.sin(my_x) return (y1 + y2,) + if config.pmap_shmap_merge.value: + expected_jaxpr_debug_infos = [ + # TODO(necula): We should not include `call_wrapped` in the debug info. + re.compile(r"traced_for=jit, fun=call_wrapped at .*jax._src.linear_util.py:.*, arg_names=None, result_paths=None"), + re.compile(r"traced_for=shard_map, fun=call_wrapped at .*jax._src.linear_util.py:.*, arg_names=None, result_paths=None"), + re.compile(r"traced_for=jit, fun=checked_fun at .*jax._src.checkify.py:.*, arg_names=args\[0\], result_paths=(result\[\d+\]\[\d+\]\[ErrorEffect\(error_type=, shape_dtypes=.*\)\])*"), + ] + expected_tracer_debug_infos = [ + re.compile(r"traced_for=shard_map, fun=call_wrapped at .*jax._src.linear_util.py:.*, arg_names=args\[0\], from args\[0\]"), + ] + else: + expected_jaxpr_debug_infos = [ + # TODO(necula): this should not be pointing into the JAX internals + re.compile(r"traced_for=jit, fun=checked_fun at .*jax._src.checkify.py:.*, arg_names=args\[0\]"), + re.compile(r"traced_for=jit, fun=argsort at .*numpy.sorting.py:.*, arg_names=a, result_paths=result"), + "traced_for=pmap, fun=my_f, arg_names=None, result_paths=None", + ] + expected_tracer_debug_infos = [ + "traced_for=pmap, fun=my_f, arg_names=None, result_paths=None, from unknown", + ] + self._check_tracers_and_jaxprs( jax.jit(checkify.checkify(my_f, errors=checkify.nan_checks)), np.arange(len(jax.devices()), dtype=np.float32), tracer_spy=tracer_spy, - expected_jaxpr_debug_infos=[ - # TODO(necula): this should not be pointing into the JAX internals - re.compile(r"traced_for=jit, fun=checked_fun at .*jax._src.checkify.py:.*, arg_names=args\[0\]"), - re.compile(r"traced_for=jit, fun=argsort at .*numpy.sorting.py:.*, arg_names=a, result_paths=result"), - "traced_for=pmap, fun=my_f, arg_names=my_x, result_paths=result[0]", - ], - expected_tracer_debug_infos=[ - "traced_for=pmap, fun=my_f, arg_names=my_x, from my_x", - ], + expected_jaxpr_debug_infos=expected_jaxpr_debug_infos, + expected_tracer_debug_infos=expected_tracer_debug_infos, check_lowering=False, # TODO(necula): warning during lowering ) @@ -1786,15 +1958,22 @@ def my_rule(used_outs, y): jnp.sqrt(y) if used_outs[1] else None, ) + expected_jaxpr_debug_infos = [ + "traced_for=jit, fun=, arg_names=x, result_paths=result", + ] + if config.use_simplified_jaxpr_constants.value: + expected_jaxpr_debug_infos.extend([ + "traced_for=custom_dce, fun=my_f, arg_names=x, result_paths=result[0],result[1]", + ]) + else: + expected_jaxpr_debug_infos.extend([ + "traced_for=custom_dce, fun=my_f, arg_names=,x, result_paths=result[0],result[1]", + ]) self._check_tracers_and_jaxprs( jax.jit(lambda x: my_f(x)[0]), np.array(1.1234), tracer_spy=tracer_spy, - expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=, arg_names=x, result_paths=result", - # TODO(necula): bad arg_names (why None), bad result_paths - 'traced_for=custom_dce, fun=my_f, arg_names=,x, result_paths=result[0],result[1]', - ], + expected_jaxpr_debug_infos=expected_jaxpr_debug_infos, expected_tracer_debug_infos=[ # TODO(necula): no leaked tracer from my_rule? "traced_for=custom_dce, fun=my_f, arg_names=x, from x", @@ -1873,8 +2052,7 @@ def my_transpose_solve(f, x): expected_tracer_debug_infos=[ "traced_for=custom_root, fun=my_f, arg_names=x, from x", "traced_for=custom_root solve, fun=my_solve, arg_names=x, from x", - # TODO(necula): from None - "traced_for=custom_root tangent_solve, fun=my_transpose_solve, arg_names=x, from None", + "traced_for=custom_root tangent_solve, fun=my_transpose_solve, arg_names=f,x, from x", "None", # TODO(necula): there are missing debug info ]) @@ -1955,16 +2133,23 @@ def my_consts(x): tracer_spy.append(x) return x / scale - x = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32) + expected_jaxpr_debug_infos = [ + "traced_for=jit, fun=my_consts, arg_names=x, result_paths=result", + ] + if config.use_simplified_jaxpr_constants.value: + expected_jaxpr_debug_infos += [ + "traced_for=composite, fun=my_consts, arg_names=x, result_paths=result", + ] + else: + expected_jaxpr_debug_infos += [ + "traced_for=composite, fun=my_consts, arg_names=,x, result_paths=result", + ] self._check_tracers_and_jaxprs( jax.jit(my_consts), x, tracer_spy=tracer_spy, - expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=my_consts, arg_names=x, result_paths=result", - "traced_for=composite, fun=my_consts, arg_names=x, result_paths=result", - ], + expected_jaxpr_debug_infos=expected_jaxpr_debug_infos, expected_tracer_debug_infos=[ "traced_for=composite, fun=my_consts, arg_names=x, from x"]) diff --git a/tests/debug_nans_test.py b/tests/debug_nans_test.py index c80d23c416df..e00446235ad0 100644 --- a/tests/debug_nans_test.py +++ b/tests/debug_nans_test.py @@ -19,10 +19,10 @@ from unittest import SkipTest from jax._src import api +from jax._src import config from jax._src import test_util as jtu from jax import numpy as jnp -from jax.experimental import pjit -from jax.experimental.shard_map import shard_map +from jax._src.shard_map import shard_map from jax.sharding import PartitionSpec as P jax.config.parse_flags_with_absl() @@ -89,7 +89,7 @@ def f(x): f(1) def testShardMap(self): - mesh = jax.make_mesh((1,), ('x',)) + mesh = jtu.create_mesh((1,), ('x',)) f = shard_map(lambda x: 0. / x, mesh=mesh, in_specs=(P('x')), out_specs=P('x')) # For the Cpp pmap, the first execution always goes through Python. f(jnp.array([1.])) @@ -136,9 +136,13 @@ def f(x): _, f_vjp = jax.vjp(jax.pmap(f), jnp.zeros([1])) + if config.pmap_shmap_merge.value: + expected_regex = r"Invalid value \(nan\) encountered in sharded computation." + else: + expected_regex = r"invalid value \(nan\) encountered in mul\nWhen differentiating" + with self.assertRaisesRegex( - FloatingPointError, - r"invalid value \(nan\) encountered in mul\nWhen differentiating"): + FloatingPointError, expected_regex): ans, = f_vjp(jnp.ones([1])) ans.block_until_ready() @@ -148,7 +152,7 @@ def f(x): y = x**2 return jnp.log(y) - mesh = jax.make_mesh((1,), ('x',)) + mesh = jtu.create_mesh((1,), ('x',)) shmap_f = shard_map(f, mesh=mesh, in_specs=(P('x')), out_specs=P('x')) _, f_vjp = jax.vjp(shmap_f, jnp.zeros([1])) @@ -162,16 +166,18 @@ def testPmapNoNaN(self): ans.block_until_ready() @jtu.ignore_warning(message=".*is an experimental.*") - def testPjit(self): + def test_jit(self): if jax.device_count() < 2: raise SkipTest("test requires >=2 devices") p = jax.sharding.PartitionSpec('x') - f = pjit.pjit(lambda x: 0. / x, in_shardings=p, out_shardings=p) + f = jax.jit(lambda x: 0. / x, in_shardings=p, out_shardings=p) + inp = jnp.array([0., 1.]) - with jax.sharding.Mesh(np.array(jax.local_devices()[:2]), ('x',)): + with jax.set_mesh( + jax.sharding.Mesh(np.array(jax.local_devices()[:2]), ('x',))): with self.assertRaises(FloatingPointError): - ans = f(jnp.array([0., 1.])) + ans = f(inp) ans.block_until_ready() def testDebugNansJitWithDonation(self): @@ -187,20 +193,18 @@ def testDebugNansPmapWithDonation(self): ans = jax.pmap(lambda x: 0. / x, donate_argnums=(0,))(a) ans.block_until_ready() - @jtu.ignore_warning(message=".*is an experimental.*") - def testDebugNansPjitWithDonation(self): + def testDebugNansJitWithDonationSharded(self): if jax.device_count() < 2: raise SkipTest("test requires >=2 devices") - p = jax.sharding.PartitionSpec('x') - f = pjit.pjit(lambda x: 0. / x, - in_shardings=p, - out_shardings=p, - donate_argnums=(0,)) + inp = jnp.array([0., 1.]) + f = jax.jit(lambda x: 0. / x, in_shardings=jax.P('x'), + out_shardings=jax.P('x'), donate_argnums=(0,)) - with jax.sharding.Mesh(np.array(jax.local_devices()[:2]), ('x',)): + with jax.set_mesh( + jax.sharding.Mesh(np.array(jax.local_devices()[:2]), ('x',))): with self.assertRaises(FloatingPointError): - ans = f(jnp.array([0., 1.])) + ans = f(inp) ans.block_until_ready() def testDebugNansZeroDiv(self): diff --git a/tests/debugger_test.py b/tests/debugger_test.py index 419e7b18dfed..084a3ede116d 100644 --- a/tests/debugger_test.py +++ b/tests/debugger_test.py @@ -21,7 +21,6 @@ from absl.testing import absltest import jax -from jax.experimental import pjit from jax._src import debugger from jax._src import test_util as jtu import jax.numpy as jnp @@ -43,6 +42,10 @@ def _format_multiline(text): foo = 2 +# This test is thread-unsafe because jax.effects_barrier() is global. This means +# that we can create a deadlock if running tests in multiple threads because we +# can introduce false dependencies via the effects barrier. +@jtu.thread_unsafe_test_class() class CliDebuggerTest(jtu.JaxTestCase): def setUp(self): @@ -272,7 +275,7 @@ def g(x): jax.effects_barrier() self.assertRegex(stdout.getvalue(), expected) - def test_debugger_works_with_pjit(self): + def test_debugger_works_with_jit(self): if jax.default_backend() != "tpu": raise unittest.SkipTest("`pjit` doesn't work with CustomCall.") @@ -286,18 +289,19 @@ def f(x): def g(x): y = f(x) return jnp.exp(y) - g = pjit.pjit( + g = jax.jit( g, in_shardings=jax.sharding.PartitionSpec("dev"), out_shardings=jax.sharding.PartitionSpec("dev"), ) - with jax.sharding.Mesh(np.array(jax.devices()), ["dev"]): - arr = (1 + jnp.arange(8)).astype(np.int32) + arr = (1 + jnp.arange(8)).astype(np.int32) + arr2 = jnp.arange(8, dtype=jnp.int32) + with jax.set_mesh(jax.sharding.Mesh(np.array(jax.devices()), ["dev"])): expected = _format_multiline(r""" Entering jdb: \(jdb\) {} \(jdb\) """.format(re.escape(repr(arr)))) - g(jnp.arange(8, dtype=jnp.int32)) + g(arr2) jax.effects_barrier() self.assertRegex(stdout.getvalue(), expected) diff --git a/tests/debugging_primitives_test.py b/tests/debugging_primitives_test.py index a8d59bc39e36..5c1ee7f25574 100644 --- a/tests/debugging_primitives_test.py +++ b/tests/debugging_primitives_test.py @@ -13,18 +13,20 @@ # limitations under the License. import collections import functools +import logging import textwrap import unittest -from absl.testing import absltest +from absl.testing import absltest, parameterized import jax from jax import lax -from jax.experimental import pjit -from jax.interpreters import pxla from jax._src import ad_checkpoint +from jax._src import config from jax._src import debugging from jax._src import dispatch from jax._src import test_util as jtu +from jax._src.interpreters import pxla +from jax.sharding import PartitionSpec as P import jax.numpy as jnp import numpy as np @@ -89,9 +91,7 @@ def step(v): def mean(forest): norm = 1.0 / len(forest) - add = lambda a, b: a + b - m = norm * functools.reduce(add, forest) - return m + return norm * sum(forest) post_mean = mean(tuple(run(x) for x in inputs)) jax.block_until_ready(post_mean) # This shouldn't deadlock. @@ -274,6 +274,28 @@ def f(x): jax.effects_barrier() self.assertEqual(output(), "[1.23 2.35 0. ]\n") + @parameterized.parameters([False, True]) + def test_debug_print_in_unrolled_loop(self, use_jit): + def body(i, _): + jax.debug.print("{}", i) + if use_jit: + body = jax.jit(body) + @jax.jit + def f(): + return jax.lax.fori_loop(0, 4, body, None, unroll=2) + with jtu.capture_stdout() as output: + f() + jax.effects_barrier() + actual = tuple(sorted(map(int, output().splitlines()))) + self.assertEqual(actual, tuple(range(4))) + + def test_debug_print_extended_dtype(self): + def f(k): + jax.debug.print("{}", k) + with jtu.capture_stdout(): + f(jax.random.key(0)) # doesn't crash + jax.effects_barrier() + @jtu.thread_unsafe_test_class() # printing isn't thread-safe class DebugPrintTransformationTest(jtu.JaxTestCase): @@ -412,16 +434,15 @@ def f(x): expected = jnp.array(2., jnp.float32) self.assertEqual(output(), f"x: 1.0\nx_grad: {expected}\n") - def test_debug_print_transpose_rule(self): - def f(x): - debug_print('should never be called: {}', x) - return x - with jtu.capture_stdout() as output: - jax.linear_transpose(f, 1.)(1.) - jax.effects_barrier() - # `debug_print` should be dropped by `partial_eval` because of no - # output data-dependence. - self.assertEqual(output(), "") + # mattjj was here + # def test_debug_print_transpose_rule(self): + # def f(x): + # debug_print('should never be called: {}', x) + # return x + # with jtu.capture_stdout() as output: + # jax.linear_transpose(f, 1.)(1.) + # jax.effects_barrier() + # self.assertEqual(output(), "") @jtu.sample_product(ordered=[False, True]) def test_remat_of_debug_print(self, ordered): @@ -432,7 +453,7 @@ def f_(x): return ad_checkpoint.checkpoint_name(jnp.exp(z), "w") # Policy that saves everything so the debug callback will be saved - f = ad_checkpoint.checkpoint(f_, policy=ad_checkpoint.everything_saveable) + f = jax.checkpoint(f_, policy=ad_checkpoint.everything_saveable) with jtu.capture_stdout() as output: jax.grad(f)(2.) @@ -443,7 +464,7 @@ def f_(x): # Policy that saves nothing so everything gets rematerialized, including the # debug callback - f = ad_checkpoint.checkpoint(f_, policy=ad_checkpoint.nothing_saveable) + f = jax.checkpoint(f_, policy=ad_checkpoint.nothing_saveable) with jtu.capture_stdout() as output: jax.grad(f)(2.) @@ -452,7 +473,7 @@ def f_(x): self.assertEqual(output(), "y: 3.0, z: 6.0\n" * 2) # Policy that does not save `z` so we will need to rematerialize the print - f = ad_checkpoint.checkpoint( + f = jax.checkpoint( f_, policy=ad_checkpoint.save_any_names_but_these("z")) with jtu.capture_stdout() as output: @@ -470,7 +491,7 @@ def policy(prim, *_, **params): return policy # Policy that saves everything but `y` - f = ad_checkpoint.checkpoint( + f = jax.checkpoint( f_, policy=save_everything_but_these_names("y")) with jtu.capture_stdout() as output: @@ -481,7 +502,7 @@ def policy(prim, *_, **params): self.assertEqual(output(), "y: 3.0, z: 6.0\n") # Policy that saves everything but `y` and `z` - f = ad_checkpoint.checkpoint( + f = jax.checkpoint( f_, policy=save_everything_but_these_names("y", "z")) with jtu.capture_stdout() as output: @@ -763,6 +784,7 @@ def b3(x): b3: 2 """)) + @jtu.thread_unsafe_test_class() # printing isn't thread-safe class DebugPrintParallelTest(jtu.JaxTestCase): @@ -774,12 +796,20 @@ def _count(lines): self.assertDictEqual(_count(text1.split("\n")), _count(text2.split("\n"))) def test_ordered_print_not_supported_in_pmap(self): - @jax.pmap def f(x): debug_print("{}", x, ordered=True) - with self.assertRaisesRegex( - ValueError, "Ordered effects not supported in `pmap`."): + if config.pmap_shmap_merge.value: + if jax.device_count() == 1: + self.skipTest("This test won't raise with 1 device.") + if jtu.device_under_test() == "gpu": + self.skipTest("Test does not raise under GPU.") + if jtu.device_under_test() == "tpu" and jtu.get_tpu_version() > 3: + self.skipTest("Test does not raise under TPU v4+.") + regex = "The following ordered effects are not supported for more than 1 device:*" + else: + regex = "Ordered effects not supported in `pmap`." + with self.assertRaisesRegex(ValueError, regex): f(jnp.arange(jax.local_device_count())) def test_unordered_print_works_in_pmap(self): @@ -804,15 +834,15 @@ def f2(x): jax.effects_barrier() self._assertLinesEqual(output(), "hello: 0\nhello: 1\nhello: 2\nhello: 3\n") - def test_unordered_print_with_pjit(self): + def test_unordered_print_with_jit(self): def f(x): debug_print("{}", x, ordered=False) return x mesh = jax.sharding.Mesh(np.array(jax.devices()), ['dev']) spec = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('dev')) out_spec = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) - f = pjit.pjit(f, in_shardings=spec, out_shardings=out_spec) - with mesh: + f = jax.jit(f, in_shardings=spec, out_shardings=out_spec) + with jax.set_mesh(mesh): with jtu.capture_stdout() as output: f(np.arange(8, dtype=jnp.int32)) jax.effects_barrier() @@ -822,24 +852,24 @@ def f2(x): y = x.dot(x) debug_print("{}", y, ordered=False) return y - f2 = pjit.pjit(f2, in_shardings=spec, out_shardings=out_spec) - with jax.sharding.Mesh(np.array(jax.devices()), ['dev']): + f2 = jax.jit(f2, in_shardings=spec, out_shardings=out_spec) + with jax.set_mesh(mesh): with jtu.capture_stdout() as output: f2(np.arange(8, dtype=jnp.int32)) jax.effects_barrier() self.assertEqual(output(), "140\n") - def test_nested_pjit_debug_print(self): + def test_nested_jit_debug_print(self): def f(x): debug_print("{}", x) return x with jtu.capture_stdout() as output: - pjit.pjit(pjit.pjit(f))(jnp.arange(8)) + jax.jit(jax.jit(f))(jnp.arange(8)) jax.effects_barrier() self.assertEqual(output(), "[0 1 2 3 4 5 6 7]\n") - def test_unordered_print_of_pjit_of_while(self): + def test_unordered_print_of_jit_of_while(self): def f(x): def cond(carry): i, *_ = carry @@ -853,8 +883,8 @@ def body(carry): mesh = jax.sharding.Mesh(np.array(jax.devices()), ['dev']) spec = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('dev')) - f = pjit.pjit(f, in_shardings=spec, out_shardings=spec) - with mesh: + f = jax.jit(f, in_shardings=spec, out_shardings=spec) + with jax.set_mesh(mesh): with jtu.capture_stdout() as output: f(np.arange(8, dtype=jnp.int32)) jax.effects_barrier() @@ -1077,6 +1107,8 @@ def test_visualize_wide_array(self): """) self.assertEqual(output(), expected) + @jtu.ignore_warning(category=DeprecationWarning, + message='jax.sharding.PmapSharding is deprecated') def test_visualize_pmap_sharding(self): ss = pxla.ShardingSpec( sharding=(pxla.Unstacked(8),), @@ -1120,9 +1152,31 @@ def test_visualize_pmap_sharding(self): """) self.assertEqual(output(), expected) + def test_visualize_sharding_shard_map(self): + mesh = jtu.create_mesh((2,), 'x') + + def f(): + a = jnp.zeros(1000) + debugging.visualize_array_sharding(a) + return a + + with jtu.capture_stdout() as output: + f() # doesn't crash + + with jtu.capture_stdout() as output: + jax.jit(f, out_shardings=jax.NamedSharding(mesh, P('x')))() # doesn't crash + + with jtu.capture_stdout() as output: + jax.shard_map(f, mesh=mesh, in_specs=P(None), out_specs=P("x"))() # doesn't crash + + with jtu.capture_stdout() as output: + jax.shard_map(f, mesh=mesh, in_specs=P(None), out_specs=P("x"), + check_vma=False)() # doesn't crash + + class InspectShardingTest(jtu.JaxTestCase): - def test_inspect_sharding_is_called_in_pjit(self): + def test_inspect_sharding_is_called_in_jit_sharded(self): if jtu.is_cloud_tpu(): raise unittest.SkipTest("Inspect sharding is not supported on libtpu.") @@ -1141,8 +1195,8 @@ def f(x): mesh = jax.sharding.Mesh(np.array(jax.devices()), ['dev']) spec = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('dev')) out_spec = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) - f = pjit.pjit(f, in_shardings=spec, out_shardings=out_spec) - with mesh: + f = jax.jit(f, in_shardings=spec, out_shardings=out_spec) + with jax.set_mesh(mesh): f(np.arange(8, dtype=jnp.int32)) self.assertTrue(is_called) @@ -1185,22 +1239,197 @@ def f_(x): f(arr) - def test_inspect_sharding_3d_pjit(self): - def _cb(sd): - self.assertIsInstance(sd, jax.sharding.NamedSharding) - self.assertLen(sd.device_set, 2) +def _get_output_set(output, num_lines): + """Return a set of strings where each string is num_lines.""" + output = output().strip().split("\n") + return { + "\n".join(output[i : i + num_lines]) + for i in range(0, len(output), num_lines) + } + + +@jtu.thread_unsafe_test_class() # printing isn't thread-safe +class PartitionedDebugCallbackTest(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + if (jtu.device_under_test() not in ("cpu", "gpu")): + raise unittest.SkipTest( + f"Test requires CPU or GPU devices. Got {jtu.device_under_test()}" + ) + if len(jax.devices()) < 2: + raise unittest.SkipTest("Test requires >= 2 devices.") + + def tearDown(self): + super().tearDown() + dispatch.runtime_tokens.clear() + + def test_partitioned_debug_callback(self): def f_(x): - debugging.inspect_array_sharding(x, callback=_cb) - return jnp.square(x) + debug_print("hello: {x}", x=x, partitioned=True) - f = pjit.pjit(f_) - mesh = jtu.create_mesh((2,), ('x')) - s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x')) - arr = jax.device_put(np.arange(8).reshape(2, 2, 2), s) + f = jax.jit(f_) + mesh = jtu.create_mesh((1, 1, 2,), ("x", "y", "z")) + s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("x", "y", "z")) + arr = jax.device_put(np.arange(24).reshape(2, 3, 4), s) - with mesh: - f(arr) + with jtu.capture_stdout() as output: + with jax.set_mesh(mesh): + f(arr) + jax.effects_barrier() + + expected = { + _format_multiline(""" + hello: [[[ 0 1] + [ 4 5] + [ 8 9]] + + [[12 13] + [16 17] + [20 21]]]"""), + _format_multiline(""" + hello: [[[ 2 3] + [ 6 7] + [10 11]] + + [[14 15] + [18 19] + [22 23]]]"""), + } + self.assertEqual(_get_output_set(output, 7), expected) + + def test_partitioned_debug_callback_compute(self): + def f(x): + debug_print("hello: {x}", x=x.sum(), partitioned=True) + + mesh = jtu.create_mesh((2,), ("x",)) + arr = jax.device_put(np.arange(8), jax.NamedSharding(mesh, jax.P("x"))) + + with jtu.capture_stdout() as output: + with jax.set_mesh(mesh): + f(arr) + jax.effects_barrier() + + def test_debug_print_batching(self): + @jax.vmap + def f_(x): + debug_print("hello: {}", x, partitioned=True) + + f = jax.jit(f_) + mesh = jtu.create_mesh((1, 1, 2), ("x", "y", "z")) + s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("x", "y", "z")) + arr = np.arange(24).reshape(2, 3, 4) + arr = jax.device_put(arr, s) + + with jtu.capture_stdout() as output: + with jax.set_mesh(mesh): + f(arr) + jax.effects_barrier() + + expected = { + _format_multiline(""" + hello: [[0 1] + [4 5] + [8 9]]"""), + _format_multiline(""" + hello: [[ 2 3] + [ 6 7] + [10 11]]"""), + _format_multiline(""" + hello: [[14 15] + [18 19] + [22 23]]"""), + _format_multiline(""" + hello: [[12 13] + [16 17] + [20 21]]"""), + } + + self.assertEqual(_get_output_set(output, 3), expected) + + def test_debug_print_batching_with_diff_axes(self): + @functools.partial(jax.vmap, in_axes=(0, 1)) + def f_(x, y): + debug_print("hello: {} {}", x, y, partitioned=True) + + f = jax.jit(f_) + mesh = jtu.create_mesh((2,), ("x")) + s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("x")) + x = np.arange(4).reshape(2, 2) + x = jax.device_put(x, s) + y = np.arange(4).reshape(2, 2) + 6 + y = jax.device_put(y, s) + + with jtu.capture_stdout() as output: + with jax.set_mesh(mesh): + f(x, y) + jax.effects_barrier() + + expected = { + "hello: [2 3] [9]", + "hello: [0 1] [6]", + "hello: [0 1] [8]", + "hello: [2 3] [7]", + } + + self.assertEqual(_get_output_set(output, 1), expected) + + def test_debug_print_with_logging(self): + logger_name = "jax._src.debugging" + jax_logger = logging.getLogger(logger_name) + class RecordHandler(logging.Handler): + def __init__(self): + logging.Handler.__init__(self) + self.records = [] + def emit(self, record): + self.records.append(record) + record_handler = RecordHandler() + jax_logger.handlers.append(record_handler) + + def log_fn(x): + x = x * x + jax.debug.log("x={}", x) + return x * x + + self.assertEqual(jax.jit(log_fn)(2), 16) + jax_logger.removeHandler(record_handler) + self.assertEqual(record_handler.records[0].msg, "x=4") + + + def test_debug_print_with_nested_vmap(self): + @jax.vmap + @jax.vmap + def f_(x): + debug_print("hello: {}", x, partitioned=True) + + f = jax.jit(f_) + mesh = jtu.create_mesh((1, 1, 2), ("x", "y", "z")) + s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("x", "y", "z")) + arr = np.arange(24).reshape(2, 3, 4) + arr = jax.device_put(arr, s) + + with jtu.capture_stdout() as output: + with jax.set_mesh(mesh): + f(arr) + jax.effects_barrier() + + expected = { + "hello: [14 15]", + "hello: [12 13]", + "hello: [18 19]", + "hello: [16 17]", + "hello: [22 23]", + "hello: [20 21]", + "hello: [2 3]", + "hello: [0 1]", + "hello: [6 7]", + "hello: [10 11]", + "hello: [4 5]", + "hello: [8 9]", + } + + self.assertEqual(_get_output_set(output, 1), expected) if not rich: diff --git a/tests/device_test.py b/tests/device_test.py index d8f2ae65bbac..f93f72182af7 100644 --- a/tests/device_test.py +++ b/tests/device_test.py @@ -28,6 +28,9 @@ def test_repr(self): if jtu.is_device_cuda(): self.assertEqual(device.platform, 'gpu') self.assertEqual(repr(device), 'CudaDevice(id=0)') + elif jtu.is_device_rocm(): + self.assertEqual(device.platform, 'gpu') + self.assertEqual(repr(device), 'RocmDevice(id=0)') elif jtu.test_device_matches(['tpu']): self.assertEqual(device.platform, 'tpu') self.assertEqual( @@ -44,6 +47,8 @@ def test_str(self): # TODO(pobudzey): Add a test for rocm devices when available. if jtu.is_device_cuda(): self.assertEqual(str(device), 'cuda:0') + elif jtu.is_device_rocm(): + self.assertEqual(str(device), 'rocm:0') elif jtu.test_device_matches(['tpu']): self.assertEqual(str(device), 'TPU_0(process=0,(0,0,0,0))') elif jtu.test_device_matches(['cpu']): diff --git a/tests/distributed_initialize_test.py b/tests/distributed_initialize_test.py new file mode 100644 index 000000000000..33242a41a68e --- /dev/null +++ b/tests/distributed_initialize_test.py @@ -0,0 +1,44 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from absl.testing import absltest +import jax +from jax._src import test_util as jtu + +try: + import portpicker +except ImportError: + portpicker = None + +jax.config.parse_flags_with_absl() + + +@unittest.skipIf(not portpicker, "Test requires portpicker") +class DistributedInitializeTest(jtu.JaxTestCase): + + @jtu.skip_under_pytest( + """Side effects from jax.distributed.initialize conflict with other tests + in the same process. pytest runs multiple tests in the same process.""" + ) + def test_is_distributed_initialized(self): + port = portpicker.pick_unused_port() # type: ignore + self.assertFalse(jax.distributed.is_initialized()) + jax.distributed.initialize(f"localhost:{port}", 1, 0) + self.assertTrue(jax.distributed.is_initialized()) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/distributed_test.py b/tests/distributed_test.py index 3961932dfad0..ae72143fbe7d 100644 --- a/tests/distributed_test.py +++ b/tests/distributed_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import subprocess -import sys import threading import unittest @@ -43,7 +41,10 @@ def testInitializeAndShutdown(self): # concurrency to simulate multiple tasks. port = portpicker.pick_unused_port() jax.distributed.initialize( - coordinator_address=f"localhost:{port}", num_processes=1, process_id=0 + coordinator_address=f"localhost:{port}", + num_processes=1, + process_id=0, + cluster_detection_method="deactivate", ) jax.distributed.shutdown() @@ -57,7 +58,10 @@ def task(i): # We can't call the public APIs directly because they use global state. state = distributed.State() state.initialize( - coordinator_address=f"localhost:{port}", num_processes=n, process_id=i + coordinator_address=f"localhost:{port}", + num_processes=n, + process_id=i, + cluster_detection_method="deactivate", ) state.shutdown() @@ -67,22 +71,6 @@ def task(i): for thread in threads: thread.join() - def test_is_distributed_initialized(self): - # Run in subprocess to isolate side effects from jax.distributed.initialize which conflict with other - # tests. Unfortunately this can't be avoided by calling jax.distributed.shutdown, as the XLA backend - # will be warmed up, which yields a RuntimeError on subsequent calls to initialize. - port = portpicker.pick_unused_port() # type: ignore - cmd = f"""import jax; - assert not jax.distributed.is_initialized(); - jax.distributed.initialize('localhost:{port}', 1, 0); - assert jax.distributed.is_initialized(); - """.replace("\n", ' ') - - result = subprocess.run([sys.executable, "-c", cmd], capture_output=True) - self.assertEqual( - result.returncode, 0, msg=f"Test failed with:\n{result.stdout}\n{result.stderr}" - ) - if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/documentation_coverage_test.py b/tests/documentation_coverage_test.py new file mode 100644 index 000000000000..6cef08e26f62 --- /dev/null +++ b/tests/documentation_coverage_test.py @@ -0,0 +1,291 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test that public APIs are correctly documented.""" + +import collections +from collections.abc import Iterator, Mapping, Sequence +import importlib +import functools +import os +import pkgutil +import warnings + +from absl.testing import absltest +from absl.testing import parameterized + +import jax +import jax._src.test_util as jtu +from jax._src import config + +config.parse_flags_with_absl() + + +CURRENTMODULE_TAG = '.. currentmodule::' +AUTOMODULE_TAG = '.. automodule::' +AUTOSUMMARY_TAG = '.. autosummary::' +AUTOCLASS_TAG = '.. autoclass::' + + +@functools.lru_cache() +def jax_docs_dir() -> str: + """Return the string or path object pointing to the JAX docs.""" + try: + # In bazel, access docs files via data dependencies of a jax.docs package. + return importlib.resources.files('jax.docs') + except ImportError: + # Outside of bazel, assume code is layed out as in the github repository, where + # the docs and tests subdirectories are both within the same top-level directory. + return os.path.abspath(os.path.join(__file__, os.pardir, os.pardir, "docs")) + + +UNDOCUMENTED_APIS = { + 'jax': ['NamedSharding', 'P', 'Ref', 'Shard', 'reshard', 'ad_checkpoint', 'api_util', 'checkpoint_policies', 'core', 'custom_derivatives', 'custom_transpose', 'debug_key_reuse', 'device_put_replicated', 'device_put_sharded', 'effects_barrier', 'example_libraries', 'explain_cache_misses', 'experimental', 'extend', 'float0', 'freeze', 'fwd_and_bwd', 'host_count', 'host_id', 'host_ids', 'interpreters', 'jax', 'jax2tf_associative_scan_reductions', 'legacy_prng_key', 'lib', 'make_user_context', 'new_ref', 'no_execution', 'numpy_dtype_promotion', 'remat', 'remove_size_one_mesh_axis_from_type', 'softmax_custom_jvp', 'threefry_partitionable', 'thread_guard', 'tools', 'transfer_guard_device_to_device', 'transfer_guard_device_to_host', 'transfer_guard_host_to_device', 'version'], + 'jax.ad_checkpoint': ['checkpoint', 'checkpoint_policies', 'print_saved_residuals', 'remat', 'Offloadable', 'Recompute', 'Saveable'], + 'jax.custom_batching': ['custom_vmap', 'sequential_vmap'], + 'jax.custom_derivatives': ['CustomVJPPrimal', 'SymbolicZero', 'closure_convert', 'custom_gradient', 'custom_jvp', 'custom_jvp_call_p', 'custom_vjp', 'custom_vjp_call_p', 'custom_vjp_primal_tree_values', 'linear_call', 'remat_opt_p', 'zero_from_primal'], + 'jax.custom_transpose': ['custom_transpose'], + 'jax.debug': ['DebugEffect', 'log'], + 'jax.distributed': ['is_initialized'], + 'jax.dtypes': ['extended', 'finfo', 'iinfo'], + 'jax.ffi': ['build_ffi_lowering_function', 'include_dir', 'register_ffi_target_as_batch_partitionable', 'register_ffi_type_id'], + 'jax.lax': ['pcast', 'unreduced_psum', 'dce_sink', 'conv_transpose_shape_tuple', 'reduce_window_shape_tuple', 'conv_general_permutations', 'conv_general_shape_tuple', 'pbroadcast', 'padtype_to_pads', 'conv_shape_tuple', 'unreduced_psum_scatter', 'create_token', 'dtype', 'shape_as_value', 'all_gather_reduced', 'pvary', *(name for name in dir(jax.lax) if name.endswith('_p'))], + 'jax.lax.linalg': [api for api in dir(jax.lax.linalg) if api.endswith('_p')], + 'jax.memory': ['Space'], + 'jax.monitoring': ['clear_event_listeners', 'record_event', 'record_event_duration_secs', 'record_event_time_span', 'record_scalar', 'register_event_duration_secs_listener', 'register_event_listener', 'register_event_time_span_listener', 'register_scalar_listener', 'unregister_event_duration_listener', 'unregister_event_listener', 'unregister_event_time_span_listener', 'unregister_scalar_listener'], + 'jax.numpy': ['bfloat16', 'bool', 'e', 'euler_gamma', 'float4_e2m1fn', 'float8_e3m4', 'float8_e4m3', 'float8_e4m3b11fnuz', 'float8_e4m3fn', 'float8_e4m3fnuz', 'float8_e5m2', 'float8_e5m2fnuz', 'float8_e8m0fnu', 'inf', 'int2', 'int4', 'nan', 'newaxis', 'pi', 'uint2', 'uint4'], + 'jax.profiler': ['ProfileData', 'ProfileEvent', 'ProfileOptions', 'ProfilePlane', 'stop_server'], + 'jax.random': ['key_impl', 'random_gamma_p'], + 'jax.scipy.special': ['bessel_jn', 'sph_harm_y'], + 'jax.sharding': ['AbstractDevice', 'AbstractMesh', 'AxisType', 'auto_axes', 'explicit_axes', 'get_abstract_mesh', 'reshard', 'set_mesh', 'use_abstract_mesh', 'get_mesh'], + 'jax.stages': ['ArgInfo', 'CompilerOptions'], + 'jax.tree_util': ['DictKey', 'FlattenedIndexKey', 'GetAttrKey', 'PyTreeDef', 'SequenceKey', 'default_registry'], +} + +# A list of modules to skip entirely, either because they cannot be imported +# or because they are not expected to be documented. +MODULES_TO_SKIP = [ + "jax.api_util", # internal tools, not documented. + "jax.cloud_tpu_init", # deprecated in JAX v0.8.1 + "jax.collect_profile", # fails when xprof is not available. + "jax.core", # internal tools, not documented. + "jax.example_libraries", # TODO(jakevdp): un-skip these. + "jax.extend.core.primitives", + "jax.extend.ifrt_programs", + "jax.extend.mlir.dialects", + "jax.extend.mlir.ir", + "jax.extend.mlir.passmanager", + "jax.extend.sharding", + "jax.extend.source_info_util", + "jax.experimental", # Many non-public submodules. + "jax.interpreters", # internal tools, not documented. + "jax.jaxlib", # internal tools, not documented. + "jax.lib", # deprecated in JAX v0.8.0 + "jax.tools", # internal tools, not documented. + "jax.version", # no public APIs. +] + + +def extract_apis_from_rst_file(path: str) -> dict[str, list[str]]: + """Extract documented APIs from an RST file.""" + # We could do this more robustly by adding a docutils dependency, but that is + # pretty heavy. Instead we use simple string-based file parsing, recognizing the + # particular patterns used within the JAX documentation. + currentmodule: str = '' + in_autosummary_block = False + apis = collections.defaultdict(list) + with open(path, 'r') as f: + for line in f: + stripped_line = line.strip() + if not stripped_line: + continue + if line.startswith(CURRENTMODULE_TAG): + currentmodule = line.removeprefix(CURRENTMODULE_TAG).strip() + continue + if line.startswith(AUTOMODULE_TAG): + currentmodule = line.removeprefix(AUTOMODULE_TAG).strip() + continue + if line.startswith(AUTOCLASS_TAG): + in_autosummary_block = False + apis[currentmodule].append(line.removeprefix(AUTOCLASS_TAG).strip()) + continue + if line.startswith(AUTOSUMMARY_TAG): + in_autosummary_block = True + continue + if not in_autosummary_block: + continue + if not line.startswith(' '): + in_autosummary_block = False + continue + if stripped_line.startswith(':'): + continue + apis[currentmodule].append(stripped_line) + return dict(apis) + + +@functools.lru_cache() +def get_all_documented_jax_apis() -> Mapping[str, list[str]]: + """Get the list of APIs documented in all files in a directory (recursive).""" + path = jax_docs_dir() + + apis = collections.defaultdict(list) + for root, _, files in os.walk(path): + if (root.startswith(os.path.join(path, 'build')) + or root.startswith(os.path.join(path, '_autosummary'))): + continue + for filename in files: + if filename.endswith('.rst'): + new_apis = extract_apis_from_rst_file(os.path.join(root, filename)) + for key, val in new_apis.items(): + apis[key].extend(val) + return {key: sorted(vals) for key, vals in apis.items()} + + +@functools.lru_cache() +def list_public_jax_modules() -> Sequence[str]: + """Return a list of the public modules defined in jax.""" + # We could use pkgutil.walk_packages, but we want to avoid traversing modules + # like `jax._src`, `jax.example_libraries`, etc. so we implement it manually. + def walk_public_modules(paths: list[str], parent_package: str) -> Iterator[str]: + for info in pkgutil.iter_modules(paths): + pkg_name = f"{parent_package}.{info.name}" + if pkg_name in MODULES_TO_SKIP or info.name == 'tests' or info.name.startswith('_'): + continue + yield pkg_name + if not info.ispkg: + continue + try: + submodule = importlib.import_module(pkg_name) + except ImportError as e: + warnings.warn(f"failed to import {pkg_name}: {e!r}") + else: + if path := getattr(submodule, '__path__', None): + yield from walk_public_modules(path, pkg_name) + return [jax.__name__, *walk_public_modules(jax.__path__, jax.__name__)] + + +@functools.lru_cache() +def list_public_apis(module_name: str) -> Sequence[str]: + """Return a list of public APIs within a specified module. + + This will import the module as a side-effect. + """ + module = importlib.import_module(module_name) + return [api for api in dir(module) + if not api.startswith('_') # skip private members + and not api.startswith('@') # skip injected pytest-related symbols + ] + + +@functools.lru_cache() +def get_all_public_jax_apis() -> Mapping[str, list[str]]: + """Return a dictionary mapping jax submodules to their list of public APIs.""" + apis = {} + for module in list_public_jax_modules(): + try: + apis[module] = list_public_apis(module) + except ImportError as e: + warnings.warn(f"failed to import {module}: {e}") + return apis + + +class DocumentationCoverageTest(jtu.JaxTestCase): + + def setUp(self): + if jtu.runtime_environment() == 'bazel': + self.skipTest("Skipping test in bazel, because rst docs aren't accessible.") + + def test_list_public_jax_modules(self): + """Simple smoke test for list_public_jax_modules()""" + apis = list_public_jax_modules() + + # A few submodules which should be included + self.assertIn("jax", apis) + self.assertIn("jax.numpy", apis) + self.assertIn("jax.numpy.linalg", apis) + + # A few submodules which should not be included + self.assertNotIn("jax._src", apis) + self.assertNotIn("jax._src.numpy", apis) + self.assertNotIn("jax.example_libraries", apis) + self.assertNotIn("jax.experimental.jax2tf", apis) + + def test_list_public_apis(self): + """Simple smoketest for list_public_apis()""" + jnp_apis = list_public_apis('jax.numpy') + self.assertIn("array", jnp_apis) + self.assertIn("zeros", jnp_apis) + self.assertNotIn("jax.numpy.array", jnp_apis) + self.assertNotIn("np", jnp_apis) + self.assertNotIn("jax", jnp_apis) + + def test_get_all_public_jax_apis(self): + """Simple smoketest for get_all_public_jax_apis()""" + apis = get_all_public_jax_apis() + self.assertIn("Array", apis["jax"]) + self.assertIn("array", apis["jax.numpy"]) + self.assertIn("eigh", apis["jax.numpy.linalg"]) + + def test_extract_apis_from_rst_file(self): + """Simple smoketest for extract_apis_from_rst_file()""" + numpy_docs = os.path.join(jax_docs_dir(), "jax.numpy.rst") + apis = extract_apis_from_rst_file(numpy_docs) + + self.assertIn("jax.numpy", apis.keys()) + self.assertIn("jax.numpy.linalg", apis.keys()) + + self.assertIn("array", apis["jax.numpy"]) + self.assertIn("asarray", apis["jax.numpy"]) + self.assertIn("eigh", apis["jax.numpy.linalg"]) + self.assertNotIn("jax", apis["jax.numpy"]) + self.assertNotIn("jax.numpy", apis["jax.numpy"]) + + def test_get_all_documented_jax_apis(self): + """Simple smoketest of get_all_documented_jax_apis()""" + apis = get_all_documented_jax_apis() + self.assertIn("Array", apis["jax"]) + self.assertIn("arange", apis["jax.numpy"]) + self.assertIn("eigh", apis["jax.lax.linalg"]) + + @parameterized.parameters(list_public_jax_modules()) + def test_module_apis_documented(self, module): + """Test that the APIs in each module are appropriately documented.""" + public_apis = get_all_public_jax_apis() + documented_apis = get_all_documented_jax_apis() + + pub_apis = {f"{module}.{api}" for api in public_apis.get(module, ())} + doc_apis = {f"{module}.{api}" for api in documented_apis.get(module, ())} + undoc_apis = {f"{module}.{api}" for api in UNDOCUMENTED_APIS.get(module, ())} + + # Remove submodules from list. + pub_apis -= public_apis.keys() + pub_apis -= set(MODULES_TO_SKIP) + + if (notempty := undoc_apis & doc_apis): + raise ValueError( + f"Found stale values in the UNDOCUMENTED_APIS list: {notempty}." + " If this fails, the fix is typically to remove the offending entries" + " from the UNDOCUMENTED_APIS mapping.") + + if (notempty := pub_apis - doc_apis - undoc_apis): + raise ValueError( + f"Found public APIs that are not listed within docs: {notempty}." + " If this fails, it likely means a new public API has been added to the" + " jax package without an associated entry in docs/*.rst. To fix this," + " either add the missing documentation entries, or add these names to the" + " UNDOCUMENTED_APIS mapping to indicate it is deliberately undocumented.") + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index 87380443f4cb..0c22530e6150 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -31,6 +31,7 @@ from jax._src import earray from jax._src import config from jax._src import dtypes +from jax._src import literals from jax._src import test_util as jtu from jax._src.lax import lax as lax_internal @@ -46,30 +47,19 @@ np.dtype('uint64')] unsigned_dtypes = list(np_unsigned_dtypes) -intn_dtypes = [np.dtype('int4'), np.dtype('uint4')] -signed_dtypes += [np.dtype('int4')] -unsigned_dtypes += [np.dtype('uint4')] -if dtypes.int2 is not None: - assert dtypes.uint2 is not None - intn_dtypes[:0] = [np.dtype('int2'), np.dtype('uint2')] - signed_dtypes[:0] = [np.dtype('int2')] - unsigned_dtypes[:0] = [np.dtype('uint2')] +intn_dtypes = [np.dtype('int2'), np.dtype('uint2'), np.dtype('int4'), np.dtype('uint4')] +signed_dtypes += [np.dtype('int2'), np.dtype('int4')] +unsigned_dtypes += [np.dtype('uint2'), np.dtype('uint4')] -np_float_dtypes = [np.dtype('float16'), np.dtype('float32'), - np.dtype('float64')] +np_float_dtypes = [np.dtype('float16'), np.dtype('float32'), np.dtype('float64')] float_dtypes = [np.dtype(dtypes.bfloat16)] + np_float_dtypes custom_float_dtypes = [np.dtype(dtypes.bfloat16)] fp8_dtypes = [np.dtype(dtypes.float8_e4m3b11fnuz), np.dtype(dtypes.float8_e4m3fn), np.dtype(dtypes.float8_e4m3fnuz), np.dtype(dtypes.float8_e5m2), - np.dtype(dtypes.float8_e5m2fnuz)] -if dtypes.float8_e3m4 is not None: - fp8_dtypes += [np.dtype(dtypes.float8_e3m4)] -if dtypes.float8_e4m3 is not None: - fp8_dtypes += [np.dtype(dtypes.float8_e4m3)] -if dtypes.float8_e8m0fnu is not None: - fp8_dtypes += [np.dtype(dtypes.float8_e8m0fnu)] + np.dtype(dtypes.float8_e5m2fnuz), np.dtype(dtypes.float8_e3m4), + np.dtype(dtypes.float8_e4m3), np.dtype(dtypes.float8_e8m0fnu)] float_dtypes += fp8_dtypes custom_float_dtypes += fp8_dtypes @@ -79,6 +69,9 @@ float_dtypes += fp4_dtypes custom_float_dtypes += fp4_dtypes +x64_dtypes = [np.dtype('int64'), np.dtype('uint64'), np.dtype('float64'), + np.dtype('complex128')] + complex_dtypes = [np.dtype('complex64'), np.dtype('complex128')] @@ -123,6 +116,10 @@ def identity(x): """A named identity function for use in tests""" return x +TypedInt = literals.TypedInt +TypedFloat = literals.TypedFloat +TypedComplex = literals.TypedComplex +TypedNdArray = literals.TypedNdArray class DtypesTest(jtu.JaxTestCase): @@ -134,11 +131,41 @@ def test_canonicalize_type(self): for in_dtype, expected_dtype in expected[config.enable_x64.value].items(): self.assertEqual(dtypes.canonicalize_dtype(in_dtype), expected_dtype) + def test_canonicalize_value_preserves_literal_dtypes(self): + self.assertEqual(np.dtype(np.int32), dtypes.canonicalize_value( + TypedInt(6, dtype=np.dtype(np.int32))).dtype) + self.assertEqual(np.dtype(np.int64), dtypes.canonicalize_value( + TypedInt(6, dtype=np.dtype(np.int64))).dtype) + self.assertEqual(np.dtype(np.float32), dtypes.canonicalize_value( + TypedFloat(6, dtype=np.dtype(np.float32))).dtype) + self.assertEqual(np.dtype(np.float64), dtypes.canonicalize_value( + TypedFloat(6, dtype=np.dtype(np.float64))).dtype) + self.assertEqual(np.dtype(np.complex64), dtypes.canonicalize_value( + TypedComplex(6, dtype=np.dtype(np.complex64))).dtype) + self.assertEqual(np.dtype(np.complex128), dtypes.canonicalize_value( + TypedComplex(6, dtype=np.dtype(np.complex128))).dtype) + self.assertEqual( + np.dtype(np.int32), + dtypes.canonicalize_value( + TypedNdArray(np.array([6], dtype=np.dtype(np.int32)), + weak_type=False) + ).dtype, + ) + self.assertEqual( + np.dtype(np.int64), + dtypes.canonicalize_value( + TypedNdArray(np.array([6], dtype=np.dtype(np.int64)), + weak_type=False) + ).dtype, + ) + @parameterized.named_parameters( {"testcase_name": f"_type={type_.__name__}", "type_": type_} for type_ in python_scalar_types) def testDefaultTypes(self, type_): - expected_dtype = dtypes.canonicalize_dtype(dtypes.python_scalar_dtypes[type_]) + expected_dtype = dtypes.canonicalize_dtype( + dtypes.python_scalar_types_to_dtypes[type_] + ) for f in [jnp.array, jax.jit(jnp.array), jax.jit(lambda x: x)]: y = f(type_(0)) self.assertTrue(isinstance(y, jax.Array), msg=(f, y)) @@ -156,9 +183,12 @@ def testUnsupportedType(self): message="Explicitly requested dtype.*") @jax.numpy_dtype_promotion('standard') def testBinaryPromotion(self, swap, jit): + if config.explicit_x64_dtypes.value == config.ExplicitX64Mode.ERROR: + self.skipTest("Test uses explicit x64 dtypes") + dfloat = dtypes.canonicalize_dtype(float) testcases = [ - (jnp.array(1.), 0., jnp.float64), - (jnp.array(1.), jnp.array(0.), jnp.float64), + (jnp.array(1.), 0., dfloat), + (jnp.array(1.), jnp.array(0.), dfloat), (jnp.array(1.), jnp.array(0., dtype=jnp.float16), jnp.float16), (jnp.array(1.), jnp.array(0., dtype=jnp.float32), jnp.float32), (jnp.array(1.), jnp.array(0., dtype=jnp.float64), jnp.float64), @@ -171,10 +201,10 @@ def testBinaryPromotion(self, swap, jit): (jnp.array(1., dtype=jnp.float32), jnp.array(0., dtype=jnp.float32), jnp.float32), (jnp.array(1., dtype=jnp.float32), jnp.array(0., dtype=jnp.float64), jnp.float64), (jnp.array(1., dtype=jnp.float64), jnp.array(0., dtype=jnp.float64), jnp.float64), - (jnp.array([1.]), 0., jnp.float64), - (jnp.array([1.]), jnp.array(0.), jnp.float64), - (jnp.array([1.]), jnp.array(0., dtype=jnp.float16), jnp.float64), - (jnp.array([1.]), jnp.array(0., dtype=jnp.float32), jnp.float64), + (jnp.array([1.]), 0., dfloat), + (jnp.array([1.]), jnp.array(0.), dfloat), + (jnp.array([1.]), jnp.array(0., dtype=jnp.float16), dfloat), + (jnp.array([1.]), jnp.array(0., dtype=jnp.float32), dfloat), (jnp.array([1.]), jnp.array(0., dtype=jnp.float64), jnp.float64), (jnp.array([1.], dtype=jnp.float32), jnp.array(0., dtype=jnp.float16), jnp.float32), (jnp.array([1.], dtype=jnp.float16), jnp.array(0., dtype=jnp.float32), jnp.float32), @@ -185,7 +215,10 @@ def testBinaryPromotion(self, swap, jit): x, y = (y, x) if swap else (x, y) z = op(x, y) self.assertTrue(isinstance(z, jax.Array), msg=(x, y, z)) - self.assertEqual(z.dtype, dtypes.canonicalize_dtype(dtype), msg=(x, y, z)) + if config.explicit_x64_dtypes.value == config.ExplicitX64Mode.ALLOW: + self.assertEqual(z.dtype, dtype, msg=(x, y, z)) + else: + self.assertEqual(z.dtype, dtypes.canonicalize_dtype(dtype), msg=(x, y, z)) @jax.numpy_dtype_promotion('strict') def testPromoteDtypesStrict(self): @@ -211,9 +244,10 @@ def testPromoteDtypesStrict(self): # np.dtype(int) is int32 on Windows and int64 on Linux/Mac. py_result_dtype = (np.dtype(np.int64) if py_result is int else np.dtype(py_result)) - lattice_dtype, lattice_weak_type = dtypes._lattice_result_type(t1, t2) + lattice_dtype, lattice_weak_type = dtypes.lattice_result_type(t1, t2) self.assertTrue(lattice_weak_type) - self.assertEqual(lattice_dtype, py_result_dtype) + self.assertEqual(lattice_dtype, + dtypes.canonicalize_dtype(py_result_dtype)) # Check that weak promotion only works if strong value is not cast: for t1 in bool_dtypes: @@ -235,54 +269,72 @@ def testPromoteDtypesStrict(self): @jax.numpy_dtype_promotion('standard') def testPromoteDtypesStandard(self): + assertTypePromotionError = functools.partial( + self.assertRaisesRegex, + dtypes.TypePromotionError, + 'Input dtypes .* have no available implicit dtype promotion path.', + dtypes.promote_types, + ) + + small_fp_dtypes = set(fp8_dtypes + fp4_dtypes) + implicit_int_dtypes = set(signed_dtypes + unsigned_dtypes) - set(intn_dtypes) + for t1 in all_dtypes: self.assertEqual(t1, dtypes.promote_types(t1, t1)) - self.assertEqual(t1, dtypes.promote_types(t1, np.bool_)) # TODO(zhangqiaorjc): Consider more dtype promotion rules for fp8. - if t1 in fp8_dtypes: - continue - if t1 in intn_dtypes: - continue - if t1 in fp4_dtypes: - continue - self.assertEqual(np.dtype(np.complex128), - dtypes.promote_types(t1, np.complex128)) + if t1 in small_fp_dtypes or t1 in intn_dtypes: + assertTypePromotionError(t1, np.complex128) + else: + self.assertEqual( + np.dtype(np.complex128), dtypes.promote_types(t1, np.complex128) + ) for t2 in all_dtypes: # TODO(zhangqiaorjc): Consider more dtype promotion rules for fp8. - if t2 in fp8_dtypes: - continue - if t2 in intn_dtypes: - continue - if t2 in fp4_dtypes: - continue - # Symmetry - self.assertEqual(dtypes.promote_types(t1, t2), - dtypes.promote_types(t2, t1)) + if ( + (t1 != t2) + and (t1 != np.bool_) + and (t2 != np.bool_) + and ( + t1 in intn_dtypes or + t2 in intn_dtypes or + (t1 in small_fp_dtypes and t2 not in implicit_int_dtypes) or + (t2 in small_fp_dtypes and t1 not in implicit_int_dtypes) + ) + ): + assertTypePromotionError(t1, t2) + assertTypePromotionError(t2, t1) + else: + self.assertEqual( + dtypes.promote_types(t1, t2), dtypes.promote_types(t2, t1) + ) self.assertEqual(np.dtype(np.float32), dtypes.promote_types(np.float16, dtypes.bfloat16)) - # Promotions of non-inexact types against inexact types always prefer - # the inexact types. + # Promotions of exact types against inexact types always prefer the + # inexact types. for t in float_dtypes + complex_dtypes: for i in bool_dtypes + signed_dtypes + unsigned_dtypes: - # TODO(zhangqiaorjc): Consider more dtype promotion rules for fp8. - if t in fp8_dtypes: - continue - if t in fp4_dtypes: - continue - if t in intn_dtypes or i in intn_dtypes: - continue - self.assertEqual(t, dtypes.promote_types(t, i)) + if i in intn_dtypes: + assertTypePromotionError(t, i) + else: + self.assertEqual(t, dtypes.promote_types(t, i)) # Promotions between exact types, or between inexact types, match NumPy. for groups in [bool_dtypes + np_signed_dtypes + np_unsigned_dtypes, np_float_dtypes + complex_dtypes]: for t1, t2 in itertools.combinations(groups, 2): - self.assertEqual(np.promote_types(t1, t2), - dtypes.promote_types(t1, t2)) + expected = np.promote_types(t1, t2) + if ( + not config.enable_x64.value + and np.issubdtype(t1, np.signedinteger) + and t1 != np.int64 + and t2 == np.uint32 + ): + expected = np.dtype(np.int32) + self.assertEqual(expected, dtypes.promote_types(t1, t2)) # Promotion between weak types matches numpy promotion for t1 in [int, float, complex]: @@ -291,9 +343,10 @@ def testPromoteDtypesStandard(self): # np.dtype(int) is int32 on Windows and int64 on Linux/Mac. py_result_dtype = (np.dtype(np.int64) if py_result is int else np.dtype(py_result)) - lattice_dtype, lattice_weak_type = dtypes._lattice_result_type(t1, t2) + lattice_dtype, lattice_weak_type = dtypes.lattice_result_type(t1, t2) self.assertTrue(lattice_weak_type) - self.assertEqual(lattice_dtype, py_result_dtype) + self.assertEqual(lattice_dtype, + dtypes.canonicalize_dtype(py_result_dtype)) @parameterized.parameters([jnp.bool_, jnp.int32, jnp.bfloat16, jnp.float32, jnp.complex64]) def testScalarInstantiation(self, scalar_type): @@ -384,53 +437,113 @@ def testScalarCastInsideJitWorks(self): self.assertEqual(jnp.int32(101), jax.jit(lambda x: jnp.int32(x))(jnp.float32(101.4))) - @parameterized.parameters(python_scalar_types) - def testDtypeFromScalarType(self, typ): - self.assertEqual(dtypes.dtype(typ), dtypes.python_scalar_dtypes[typ]) + def testDtypeFromScalarType(self): + self.assertEqual(dtypes.dtype(bool), np.dtype(np.bool_)) + if config.enable_x64.value: + self.assertEqual(dtypes.dtype(int), np.dtype(np.int64)) + self.assertEqual(dtypes.dtype(float), np.dtype(np.float64)) + self.assertEqual(dtypes.dtype(complex), np.dtype(np.complex128)) + else: + self.assertEqual(dtypes.dtype(int), np.dtype(np.int32)) + self.assertEqual(dtypes.dtype(float), np.dtype(np.float32)) + self.assertEqual(dtypes.dtype(complex), np.dtype(np.complex64)) - @parameterized.parameters(python_scalar_types) - def testDtypeFromScalarValue(self, typ): - self.assertEqual(dtypes.dtype(typ(0)), dtypes.python_scalar_dtypes[typ]) + def testDtypeFromScalarValue(self): + self.assertEqual(dtypes.dtype(bool(0)), np.dtype(np.bool_)) + if config.enable_x64.value: + self.assertEqual(dtypes.dtype(int(0)), np.dtype(np.int64)) + self.assertEqual(dtypes.dtype(float(0)), np.dtype(np.float64)) + self.assertEqual(dtypes.dtype(complex(0)), np.dtype(np.complex128)) + else: + self.assertEqual(dtypes.dtype(int(0)), np.dtype(np.int32)) + self.assertEqual(dtypes.dtype(float(0)), np.dtype(np.float32)) + self.assertEqual(dtypes.dtype(complex(0)), np.dtype(np.complex64)) + + def testDtypeFromLiteralValue(self): + self.assertEqual(dtypes.dtype(TypedInt(0, np.dtype(np.int64))), np.dtype(np.int64)) + self.assertEqual(dtypes.dtype(TypedFloat(0, np.dtype(np.float64))), np.dtype(np.float64)) + self.assertEqual(dtypes.dtype(TypedComplex(0, np.dtype(np.complex128))), np.dtype(np.complex128)) + self.assertEqual(dtypes.dtype(TypedInt(0, np.dtype(np.int32))), np.dtype(np.int32)) + self.assertEqual(dtypes.dtype(TypedFloat(0, np.dtype(np.float32))), np.dtype(np.float32)) + self.assertEqual(dtypes.dtype(TypedComplex(0, np.dtype(np.complex64))), np.dtype(np.complex64)) + self.assertEqual(dtypes.dtype(TypedNdArray(np.array([0], dtype=np.int32), weak_type=False)), np.dtype(np.int32)) + self.assertEqual(dtypes.dtype(TypedNdArray(np.array([0], dtype=np.int64), weak_type=False)), np.dtype(np.int64)) @parameterized.parameters(all_dtypes) def testDtypeFromValue(self, dtype): - self.assertEqual(dtypes.dtype(dtype.type(0)), dtype) + self.assertEqual(dtypes.dtype(dtype.type(0)), + dtypes.canonicalize_dtype(dtype)) - @parameterized.parameters(all_dtypes) - def testDtypeFromDtype(self, dtype): - self.assertEqual(dtypes.dtype(dtype), dtype) + @parameterized.product( + dtype=all_dtypes, + explicit_x64_dtypes=tuple(config.ExplicitX64Mode.__members__.values()), + ) + def testDtypeFromDtype(self, dtype, explicit_x64_dtypes): + with config.explicit_x64_dtypes(explicit_x64_dtypes): + if explicit_x64_dtypes == config.ExplicitX64Mode.ALLOW: + self.assertEqual(dtypes.dtype(dtype), dtype) + elif explicit_x64_dtypes == config.ExplicitX64Mode.WARN: + with jtu.ignore_warning(category=UserWarning, + message="Explicitly requested dtype.*"): + self.assertEqual(dtypes.dtype(dtype), dtypes.canonicalize_dtype(dtype)) + else: + if config.enable_x64.value or dtype not in x64_dtypes: + self.assertEqual(dtypes.dtype(dtype), dtypes.canonicalize_dtype(dtype)) + else: + with self.assertRaisesRegex(ValueError, "Explicitly requested dtype"): + dtypes.dtype(dtype) @parameterized.parameters(all_dtypes) def testDtypeFromString(self, dtype): - self.assertEqual(dtypes.dtype(str(dtype)), dtype) + if config.explicit_x64_dtypes.value != config.ExplicitX64Mode.ERROR and dtype not in x64_dtypes: + self.assertEqual(dtypes.dtype(str(dtype)), dtypes.canonicalize_dtype(dtype)) def testDtypeFromNone(self): with self.assertRaisesRegex(ValueError, "Invalid argument to dtype"): dtypes.dtype(None) def testDefaultDtypes(self): - precision = config.default_dtype_bits.value - assert precision in ['32', '64'] self.assertEqual(dtypes.bool_, np.bool_) - self.assertEqual(dtypes.int_, np.int32 if precision == '32' else np.int64) - self.assertEqual(dtypes.uint, np.uint32 if precision == '32' else np.uint64) - self.assertEqual(dtypes.float_, np.float32 if precision == '32' else np.float64) - self.assertEqual(dtypes.complex_, np.complex64 if precision == '32' else np.complex128) + self.assertEqual(dtypes.int_, np.int64) + self.assertEqual(dtypes.uint, np.uint64) + self.assertEqual(dtypes.float_, np.float64) + self.assertEqual(dtypes.complex_, np.complex128) def test_check_dtype_non_hashable(self): # regression test for issue with checking non-hashable custom dtype class MyDtype: __hash__ = None dtype = np.dtype('float32') - dtypes.check_user_dtype_supported(MyDtype()) + dtypes.check_and_canonicalize_user_dtype(MyDtype()) def test_check_dtype_array(self): x = jnp.arange(4) - msg = "Passing an array as a dtype argument is deprecated" - with self.assertWarnsRegex(DeprecationWarning, msg): - dtypes.check_user_dtype_supported(x) - with self.assertWarnsRegex(DeprecationWarning, msg): - jax.jit(dtypes.check_user_dtype_supported)(x) + msg = "Passing an array as a dtype argument is no longer supported" + with self.assertRaisesRegex(ValueError, msg): + dtypes.check_and_canonicalize_user_dtype(x) + with self.assertRaisesRegex(ValueError, msg): + def f(x): + dtypes.check_and_canonicalize_user_dtype(x) + jax.jit(f)(x) + + @parameterized.parameters( + (jnp.int2, 2), + (jnp.int4, 4), + (jnp.int8, 8), + (jnp.int16, 16), + (jnp.int32, 32), + *[(fp4_dtype, 4) for fp4_dtype in fp4_dtypes], + *[(fp8_dtype, 8) for fp8_dtype in fp8_dtypes], + (jnp.float16, 16), + (jnp.float32, 32), + (jnp.float64, 64), + ) + def test_itemsize_bits(self, dtype, expected_bitwidth): + self.assertEqual(dtypes.itemsize_bits(dtype), expected_bitwidth) + + def test_itemsize_none_raises(self): + with self.assertRaisesRegex(ValueError, 'dtype cannot be None'): + dtypes.itemsize_bits(None) class ExtendedDTypeTest(jtu.JaxTestCase): @@ -757,7 +870,7 @@ def global_sharded_result_handler(aval, out_sharding, committed): phys_aval = core.physical_aval(aval) phys_handler_maker = pxla.global_result_handlers[core.ShapedArray] phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed) - return lambda bufs: earray.EArray(aval, phys_handler(bufs)) + return phys_handler.wrap(lambda arr: earray.EArray(aval, arr)) @dataclasses.dataclass(frozen=True) class FooTy(dtypes.ExtendedDType): @@ -789,12 +902,16 @@ class TestPromotionTables(jtu.JaxTestCase): {"testcase_name": f"_{jaxtype=}", "jaxtype": jaxtype} for jaxtype in dtypes._jax_types + dtypes._weak_types) def testJaxTypeFromType(self, jaxtype): + if isinstance(jaxtype, np.dtype): + jaxtype = dtypes.canonicalize_dtype(jaxtype) self.assertIs(dtypes._jax_type(*dtypes._dtype_and_weaktype(jaxtype)), jaxtype) @parameterized.named_parameters( {"testcase_name": f"_{jaxtype=}", "jaxtype": jaxtype} for jaxtype in dtypes._jax_types + dtypes._weak_types) def testJaxTypeFromVal(self, jaxtype): + if isinstance(jaxtype, np.dtype): + jaxtype = dtypes.canonicalize_dtype(jaxtype) try: val = jaxtype(0) except TypeError: @@ -834,10 +951,10 @@ def testScalarWeakTypes(self, typ): def testResultTypeNone(self): # This matches the behavior of np.result_type(None) => np.float64 - self.assertEqual(dtypes.result_type(None), dtypes.canonicalize_dtype(dtypes.float_)) + self.assertEqual(dtypes.result_type(None), dtypes.default_float_dtype()) def testResultTypeWeakFlag(self): - float_ = dtypes.canonicalize_dtype(dtypes.float_) + float_ = dtypes.default_float_dtype() x_weak = jnp.array(1.) x_strong = x_weak.astype(float_) self.assertEqual(dtypes.result_type(x_weak), float_) @@ -874,6 +991,28 @@ def testObservedPromotionTable(self): ['f*','f*','f*','f*','f*','f*','f*','f*','f*','bf','f2','f4','f8','c4','c8','f*','f*','c*'], ['c*','c*','c*','c*','c*','c*','c*','c*','c*','c4','c4','c4','c8','c4','c8','c*','c*','c*'], ] + elif config.explicit_x64_dtypes.value == config.ExplicitX64Mode.ALLOW: + # This differs from enable_x64=True only because i4xu4 -> i4 instead of s8. + expected = [ + ['b1','u1','u2','u4','u8','i1','i2','i4','i8','bf','f2','f4','f8','c4','c8','i*','f*','c*'], + ['u1','u1','u2','u4','u8','i2','i2','i4','i8','bf','f2','f4','f8','c4','c8','u1','f*','c*'], + ['u2','u2','u2','u4','u8','i4','i4','i4','i8','bf','f2','f4','f8','c4','c8','u2','f*','c*'], + ['u4','u4','u4','u4','u8','i4','i4','i4','i8','bf','f2','f4','f8','c4','c8','u4','f*','c*'], + ['u8','u8','u8','u8','u8','f*','f*','f*','f*','bf','f2','f4','f8','c4','c8','u8','f*','c*'], + ['i1','i2','i4','i4','f*','i1','i2','i4','i8','bf','f2','f4','f8','c4','c8','i1','f*','c*'], + ['i2','i2','i4','i4','f*','i2','i2','i4','i8','bf','f2','f4','f8','c4','c8','i2','f*','c*'], + ['i4','i4','i4','i4','f*','i4','i4','i4','i8','bf','f2','f4','f8','c4','c8','i4','f*','c*'], + ['i8','i8','i8','i8','f*','i8','i8','i8','i8','bf','f2','f4','f8','c4','c8','i8','f*','c*'], + ['bf','bf','bf','bf','bf','bf','bf','bf','bf','bf','f4','f4','f8','c4','c8','bf','bf','c4'], + ['f2','f2','f2','f2','f2','f2','f2','f2','f2','f4','f2','f4','f8','c4','c8','f2','f2','c4'], + ['f4','f4','f4','f4','f4','f4','f4','f4','f4','f4','f4','f4','f8','c4','c8','f4','f4','c4'], + ['f8','f8','f8','f8','f8','f8','f8','f8','f8','f8','f8','f8','f8','c8','c8','f8','f8','c8'], + ['c4','c4','c4','c4','c4','c4','c4','c4','c4','c4','c4','c4','c8','c4','c8','c4','c4','c4'], + ['c8','c8','c8','c8','c8','c8','c8','c8','c8','c8','c8','c8','c8','c8','c8','c8','c8','c8'], + ['i*','u1','u2','u4','u8','i1','i2','i4','i8','bf','f2','f4','f8','c4','c8','i*','f*','c*'], + ['f*','f*','f*','f*','f*','f*','f*','f*','f*','bf','f2','f4','f8','c4','c8','f*','f*','c*'], + ['c*','c*','c*','c*','c*','c*','c*','c*','c*','c4','c4','c4','c8','c4','c8','c*','c*','c*'], + ] else: expected = [ ['b1','u1','u2','u4','u4','i1','i2','i4','i4','bf','f2','f4','f4','c4','c4','i*','f*','c*'], @@ -922,6 +1061,9 @@ def val_to_typecode(val): typecode = typecode[:-1] + '*' return typecode + if config.explicit_x64_dtypes.value == config.ExplicitX64Mode.ERROR: + self.skipTest("Test uses x64 types") + vals = [typecode_to_val(t) for t in typecodes] table = [[val_to_typecode(v1 + v2) for v1 in vals] for v2 in vals] @@ -965,10 +1107,14 @@ def testUnaryPromotion(self, dtype, weak_type): self.skipTest("TPU does not support float8_e8m0fnu.") if dtype == dtypes.float4_e2m1fn and jtu.test_device_matches(['tpu']): self.skipTest("TPU does not support float4_e2m1fn.") + dtype = dtypes.canonicalize_dtype(dtype) x = lax_internal._convert_element_type(0, dtype, weak_type=weak_type) if weak_type: - expected = dtypes.canonicalize_dtype( - dtypes._default_types['f' if x.dtype in ["bfloat16", *fp8_dtypes, *fp4_dtypes] else x.dtype.kind]) + expected = dtypes.default_types[ + 'f' + if x.dtype in ['bfloat16', *fp8_dtypes, *fp4_dtypes] + else x.dtype.kind + ]() else: expected = x.dtype self.assertEqual(dtypes.result_type(x), expected) @@ -1026,6 +1172,7 @@ def testBinaryNonPromotion(self, dtype, weak_type, promotion): if dtype in intn_dtypes: self.skipTest("XLA support for int2 and int4 is incomplete.") # Regression test for https://github.com/jax-ml/jax/issues/6051 + dtype = dtypes.canonicalize_dtype(dtype) x = lax_internal._convert_element_type(0, dtype, weak_type=weak_type) with jax.numpy_dtype_promotion(promotion): y = (x + x) @@ -1057,6 +1204,7 @@ def testArrayRepr(self, dtype, weak_type): self.skipTest('TPU does not support float8_e8m0fnu.') if dtype == dtypes.float4_e2m1fn and jtu.test_device_matches(['tpu']): self.skipTest('TPU does not support float4_e2m1fn.') + dtype = dtypes.canonicalize_dtype(dtype) val = lax_internal._convert_element_type(0, dtype, weak_type=weak_type) rep = repr(val) self.assertStartsWith(rep, 'Array(') diff --git a/tests/dynamic_api_test.py b/tests/dynamic_api_test.py deleted file mode 100644 index df79e6aaf6df..000000000000 --- a/tests/dynamic_api_test.py +++ /dev/null @@ -1,1770 +0,0 @@ -# Copyright 2018 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from functools import partial -import re -import unittest -import numpy as np - -from absl.testing import absltest -from absl.testing import parameterized - -import jax -import jax.numpy as jnp -from jax import lax -from jax.interpreters import batching - -import jax._src.lib -import jax._src.util -from jax._src import core -from jax._src import test_util as jtu - -jax.config.parse_flags_with_absl() - - -@jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow") -class DynamicShapeStagingTest(jtu.JaxTestCase): - def test_basic_staging(self): - def f(x, _): - return x - - x = jnp.arange(3) - y = jnp.ones((3, 4)) - jaxpr = jax.make_jaxpr(f, abstracted_axes={0: 'n'})(x, y) - - # { lambda ; a:i32[] b:i32[a] c:f32[a,4]. let in (b,) } - self.assertLen(jaxpr.in_avals, 3) - self.assertLen(jaxpr.in_avals[0].shape, 0) - self.assertLen(jaxpr.in_avals[1].shape, 1) - self.assertLen(jaxpr.in_avals[2].shape, 2) - self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.in_avals[1].shape[0]) - self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.in_avals[2].shape[0]) - self.assertLen(jaxpr.out_avals, 1) - self.assertLen(jaxpr.out_avals[0].shape, 1) - self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.out_avals[0].shape[0]) - - def test_basic_staging_repeated(self): - def f(x, _): - return x - - x = jnp.arange(3) - y = jnp.ones((3, 3)) - jaxpr = jax.make_jaxpr(f, abstracted_axes=(('n',), ('n', 'n')))(x, y) - - # { lambda ; a:i32[] b:i32[a] c:f32[a,a]. let in (b,) } - self.assertLen(jaxpr.in_avals, 3) - self.assertLen(jaxpr.in_avals[0].shape, 0) - self.assertLen(jaxpr.in_avals[1].shape, 1) - self.assertLen(jaxpr.in_avals[2].shape, 2) - self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.in_avals[1].shape[0]) - self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.in_avals[2].shape[0]) - self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.in_avals[2].shape[1]) - self.assertLen(jaxpr.out_avals, 1) - self.assertLen(jaxpr.out_avals[0].shape, 1) - self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.out_avals[0].shape[0]) - - def test_basic_staging_multiple_shape_vars(self): - def f(x, _): - return x - - x = jnp.arange(3) - y = jnp.ones((4, 3)) - jaxpr = jax.make_jaxpr(f, abstracted_axes=(('n',), ('m', 'n')))(x, y) - - # { lambda ; a:i32[] b: i32[] c:i32[a] d:f32[b,a]. let in (c,) } - self.assertLen(jaxpr.in_avals, 4) - self.assertLen(jaxpr.in_avals[0].shape, 0) - self.assertLen(jaxpr.in_avals[1].shape, 0) - self.assertLen(jaxpr.in_avals[2].shape, 1) - self.assertLen(jaxpr.in_avals[3].shape, 2) - self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.in_avals[2].shape[0]) - self.assertIs(jaxpr.jaxpr.invars[1], jaxpr.in_avals[3].shape[0]) - self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.in_avals[3].shape[1]) - self.assertLen(jaxpr.out_avals, 1) - self.assertLen(jaxpr.out_avals[0].shape, 1) - self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.out_avals[0].shape[0]) - - def test_basic_add(self): - def f(x, y): - return x + y - - x = jnp.arange(3) - y = jnp.arange(1, 4) - jaxpr = jax.make_jaxpr(f, abstracted_axes={0: 'n'})(x, y) - - # { lambda ; a:i32[] b:i32[a] c:i32[a]. let d:i32[a] = add b c in (d,) } - self.assertLen(jaxpr.eqns, 1) - self.assertLen(jaxpr.out_avals, 1) - self.assertLen(jaxpr.out_avals[0].shape, 1) - self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.out_avals[0].shape[0]) - - def test_basic_jnp(self): - def f(x): - y = x + jnp.sin(x) - return y.sum() - - x = jnp.ones((3, 4)) - jaxpr = jax.make_jaxpr(f, abstracted_axes={0: 'n'})(x) - - # { lambda ; a:i32[] b:f32[a,4]. let - # c:f32[a,4] = sin b - # d:f32[a,4] = add b c - # e:f32[] = reduce_sum[axes=(0, 1)] d - # in (e,) } - self.assertLen(jaxpr.in_avals, 2) - self.assertLen(jaxpr.eqns, 3) # sin, add, and reduce_sum - self.assertLen(jaxpr.out_avals, 1) - self.assertLen(jaxpr.out_avals[0].shape, 0) - - def test_shape_errors_var_and_lit(self): - def f(x, y): - return jnp.sin(x) + y - - x = np.ones(3) - y = np.ones(3) - with self.assertRaisesRegex( - Exception, '[Ii]ncompatible shapes for broadcasting'): - _ = jax.make_jaxpr(f, abstracted_axes=({0: 'n'}, {}))(x, y) - - def test_shape_errors_distinct_vars(self): - def f(x, y): - return jnp.sin(x) + y - - x = np.ones(3) - y = np.ones(3) - with self.assertRaisesRegex( - Exception, '[Ii]ncompatible shapes for broadcasting'): - _ = jax.make_jaxpr(f, abstracted_axes=({0: 'n'}, {0: 'm'}))(x, y) - - def test_basic_dot(self): - A = jnp.ones((3, 4)) - x = jnp.ones(4) - jaxpr = jax.make_jaxpr(jnp.dot, abstracted_axes=(('m', 'n'), ('n',)))(A, x) - - # { lambda ; a:i32[] b:i32[] c:f32[a,b] d:f32[b]. let - # e:f32[a] = dot_general[dimension_numbers=(((1,), (0,)), ((), ()))] c d - # in (e,) } - self.assertLen(jaxpr.in_avals, 4) - self.assertLen(jaxpr.in_avals[0].shape, 0) # two shape vars - self.assertLen(jaxpr.in_avals[1].shape, 0) - self.assertLen(jaxpr.in_avals[2].shape, 2) # one matrix - self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.in_avals[2].shape[0]) - self.assertIs(jaxpr.jaxpr.invars[1], jaxpr.in_avals[2].shape[1]) - self.assertLen(jaxpr.in_avals[3].shape, 1) # one vector - self.assertIs(jaxpr.jaxpr.invars[1], jaxpr.in_avals[3].shape[0]) - self.assertLen(jaxpr.eqns, 1) - self.assertLen(jaxpr.out_avals, 1) - self.assertLen(jaxpr.out_avals[0].shape, 1) # output vector - self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.out_avals[0].shape[0]) - - def test_basic_broadcast(self): - def f(x, n): - return lax.broadcast(x, (n,)) - - jaxpr = jax.make_jaxpr(f)(jnp.ones(4), 3) - - # { lambda ; a:f32[4] b:i32[]. let - # c:f32[b,4] = broadcast_in_dim[bcast_dims=(1,) shape=(None, 4)] a b - # in (c,) } - self.assertLen(jaxpr.in_avals, 2) - self.assertLen(jaxpr.eqns, 1) - self.assertLen(jaxpr.out_avals, 1) - self.assertLen(jaxpr.out_avals[0].shape, 2) - self.assertIs(jaxpr.jaxpr.invars[1], jaxpr.out_avals[0].shape[0]) - self.assertEqual(4, jaxpr.out_avals[0].shape[1]) - - def test_basic_batchpoly_neuralnet(self): - def predict(params, inputs): - for W, b in params: - outputs = jnp.dot(inputs, W) + b - inputs = jnp.tanh(outputs) - return outputs - - def loss(params, batch): - inputs, targets = batch - preds = predict(params, inputs) - return jnp.sum((preds - targets) ** 2) - - sizes = [784, 128, 128, 10] - params = [(jnp.ones((input_dim, output_dim)), jnp.ones(output_dim)) - for input_dim, output_dim in zip(sizes[:-1], sizes[1:])] - batch = (jnp.ones((32, 784)), jnp.ones((32, 10))) - - # Mainly we want to test that make_jaxpr doesn't crash here. - jaxpr = jax.make_jaxpr(loss, abstracted_axes=({}, {0: 'n'}))(params, batch) - self.assertLen(jaxpr.in_avals, 9) - self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.in_avals[-2].shape[0]) - self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.in_avals[-1].shape[0]) - self.assertLen(jaxpr.out_avals, 1) - self.assertLen(jaxpr.out_avals[0].shape, 0) - - def test_closing_over_polymorphic_shape(self): - def f(n): - x = jnp.zeros(n) - return jax.jit(lambda: x)() - - jaxpr = jax.make_jaxpr(f)(3) - - # { lambda ; a:i32[]. let - # b:f32[a] = bcast[dims=() shape=(None,)] 0.0 a - # c:f32[a] = pjit[ - # jaxpr={ lambda ; d:i32[] e:f32[d]. let in (e,) } - # name= - # ] a b - # in (c,) } - a, = jaxpr.jaxpr.invars - c, = jaxpr.jaxpr.outvars - self.assertLen(c.aval.shape, 1) - self.assertIs(a, c.aval.shape[0]) - - def test_closing_over_dynamic_shape(self): - def f(n): - m = 2 * n - x = jnp.zeros(m) - return jax.jit(jnp.sin)(x) - - # { lambda ; a:i32[]. let - # b:i32[] = mul a 2 - # c:f32[b] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 0.0 b - # d:f32[b] = pjit[ - # jaxpr={ lambda ; e:i32[] f:f32[e]. let in (f,) } - # name= - # ] b c - # in (d,) } - jaxpr = jax.make_jaxpr(f)(3) - b, = jaxpr.jaxpr.eqns[0].outvars - c, = jaxpr.jaxpr.eqns[1].outvars - d, = jaxpr.jaxpr.eqns[2].outvars - self.assertLen(c.aval.shape, 1) - self.assertIs(b, c.aval.shape[0]) - self.assertLen(d.aval.shape, 1) - self.assertIs(b, d.aval.shape[0]) - - def test_closing_over_polymorphic_shape_and_adding(self): - def f(n): - x = jnp.zeros(n) - y = jnp.zeros(n) - - @jax.jit - def g(): - return x + y - return g() - - # { lambda ; a:i32[]. let - # b:f32[a] = bcast[broadcast_dimensions=() shape=(None,)] 0.0 a - # c:f32[a] = bcast[broadcast_dimensions=() shape=(None,)] 0.0 a - # d:f32[a] = pjit[ - # jaxpr={ lambda ; e:i32[] f:f32[e] g:f32[e]. let - # h:f32[e] = add f g - # in (h,) } - # name=g - # ] a b c - # in (d,) } - jaxpr = jax.make_jaxpr(f)(3) # doesn't fail on the addition! - a, = jaxpr.jaxpr.invars - b, = jaxpr.jaxpr.eqns[0].outvars - c, = jaxpr.jaxpr.eqns[1].outvars - d, = jaxpr.jaxpr.eqns[2].outvars - self.assertIs(a, b.aval.shape[0]) - self.assertIs(a, c.aval.shape[0]) - self.assertIs(a, d.aval.shape[0]) - - def test_passing_in_equal_polymorphic_shapes_and_adding(self): - def f(n): - x = jnp.zeros(n) - - @jax.jit - def g(x, y): - return x + y - return g(x, x) - - # { lambda ; a:i32[]. let - # b:f32[a] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 0.0 a - # c:f32[a] = pjit[ - # jaxpr={ lambda ; d:i32[] e:f32[d] f:f32[d]. let - # g:f32[d] = add e f - # in (g,) } - # name=g - # ] a b b - # in (c,) } - jaxpr = jax.make_jaxpr(f)(3) - a, = jaxpr.jaxpr.invars - c, = jaxpr.jaxpr.outvars - self.assertLen(c.aval.shape, 1) - self.assertIs(a, c.aval.shape[0]) - - @unittest.skip("doesn't work yet: shape error b/c we don't notice x and y same") - def test_closing_over_and_passing_arg_addition(self): - # TODO(mattjj,dougalm): currently fails to notice equal shapes, fix! - def f(n): - x = jnp.zeros(n) - - @jax.jit - def g(y): - return x + y - return g(x) - - _ = jax.make_jaxpr(f)(3) - - @unittest.skip("doesn't work yet: shape error b/c we don't notice x and jnp.zeros(m) same") - def test_closing_over_and_passing_size_addition(self): - # TODO(mattjj,dougalm): currently fails to notice equal shapes, fix! - def f(n): - x = jnp.zeros(n) - - @jax.jit - def g(m): - return jnp.zeros(m) + x - return g(n) - - _ = jax.make_jaxpr(f)(3) - - def test_closing_over_and_broadcasting_polymorphic_shape(self): - def f(n): - x = jnp.zeros(n) - @jax.jit - def g(): - return jnp.zeros(n) + x - return g() - - # { lambda ; a:i32[]. let - # b:f32[a] = bcast[broadcast_dimensions=() shape=(None,)] 0.0 a - # c:f32[a] = pjit[ - # jaxpr={ lambda ; d:i32[] e:f32[d]. let - # f:f32[d] = bcast[broadcast_dimensions=() shape=(None,)] 0.0 d - # g:f32[d] = add f e - # in (g,) } - # name=g - # ] a b - # in (c,) } - jaxpr = jax.make_jaxpr(f)(3) - - a, = jaxpr.jaxpr.invars - c, = jaxpr.jaxpr.outvars - self.assertLen(c.aval.shape, 1) - self.assertIs(a, c.aval.shape[0]) - - def test_closing_over_repeated_shapes(self): - def zeros(shape): - if not isinstance(shape, (tuple, list)): - shape = shape, - return lax.broadcast(0., shape) - - def f(n): - m = 2 * n - x = zeros((m, m)) - return jax.jit(lambda: x.sum(0))() - - # { lambda ; a:i32[]. let - # b:i32[] = mul a 2 - # c:f32[b,b] = broadcast_in_dim[broadcast_dimensions=() shape=(None, None)] 0.0 - # b b - # d:f32[b] = pjit[ - # jaxpr={ lambda ; e:i32[] f:f32[e,e]. let - # g:f32[e] = reduce_sum[axes=(0,)] f - # in (g,) } - # name= - # ] b c - # in (d,) } - jaxpr = jax.make_jaxpr(f)(3) - a, = jaxpr.jaxpr.invars - b, = jaxpr.jaxpr.eqns[0].outvars - c, = jaxpr.jaxpr.eqns[1].outvars - d, = jaxpr.jaxpr.eqns[2].outvars - b_, c_ = jaxpr.jaxpr.eqns[2].invars - self.assertLen(c.aval.shape, 2) - self.assertIs(c.aval.shape[0], b) - self.assertIs(c.aval.shape[1], b) - self.assertIs(b, b_) - self.assertIs(c, c_) - self.assertLen(d.aval.shape, 1) - self.assertIs(d.aval.shape[0], b) - - def test_staging_repeated_nested(self): - def zeros(shape): - if not isinstance(shape, (tuple, list)): - shape = shape, - return lax.broadcast(jnp.float32(0.), shape) - - def f(n): - m = 2 * n - x = zeros((m, n)) - y = zeros(m) - return jax.jit(lambda x, y: x.sum(1) + y)(x, y) - - # { lambda ; a:i32[]. let - # b:i32[] = mul a 2 - # c:f32[b,a] = broadcast_in_dim[broadcast_dimensions=() shape=(None, None)] 0.0 - # b a - # d:f32[b] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 1.0 b - # e:f32[b] = pjit[ - # jaxpr={ lambda ; f:i32[] g:i32[] h:f32[f,g] i:f32[f]. let - # j:f32[f] = reduce_sum[axes=(1,)] h - # k:f32[f] = add j i - # in (k,) } - # name= - # ] b a c d - # in (e,) } - jaxpr = jax.make_jaxpr(f)(jnp.int32(3)) - a, = jaxpr.jaxpr.invars - b, = jaxpr.jaxpr.eqns[0].outvars - c, = jaxpr.jaxpr.eqns[1].outvars - d, = jaxpr.jaxpr.eqns[2].outvars - e, = jaxpr.jaxpr.eqns[3].outvars - b_, a_, c_, d_ = jaxpr.jaxpr.eqns[3].invars - self.assertLen(c.aval.shape, 2) - self.assertIs(c.aval.shape[0], b) - self.assertIs(c.aval.shape[1], a) - self.assertLen(e.aval.shape, 1) - self.assertIs(e.aval.shape[0], b) - self.assertIs(a, a_) - self.assertIs(b, b_) - self.assertIs(c, c_) - self.assertIs(d, d_) - - def test_jit_abstracted_axes_staging(self): - # We just test make_jaxpr-of-jit because dynamic shape compilation/execution - # may not be supported. - @partial(jax.jit, abstracted_axes=('n',)) - def f(x): - return jnp.sum(x) - jaxpr = jax.make_jaxpr(f)(jnp.ones(3, jnp.dtype('float32'))) - # { lambda ; a:f32[3]. let - # b:f32[] = pjit[ - # jaxpr={ lambda ; c:i32[] d:f32[c]. let - # e:f32[] = reduce_sum[axes=(0,)] d - # in (e,) } - # name=f - # ] 3 a - # in (b,) } - a, = jaxpr.jaxpr.invars - e, = jaxpr.jaxpr.eqns - self.assertLen(e.invars, 2) - self.assertIsInstance(e.invars[0], core.Literal) - self.assertIs(e.invars[1], a) - b, = e.outvars - self.assertLen(b.aval.shape, 0) - - subjaxpr = e.params['jaxpr'] - c, d = subjaxpr.jaxpr.invars - self.assertLen(c.aval.shape, 0) - self.assertLen(d.aval.shape, 1) - self.assertIs(d.aval.shape[0], c) - - def test_jit_abstracted_axes_staging2(self): - @partial(jax.jit, abstracted_axes=('n',)) - def fun(x): - return jnp.sum(x) - jaxpr = jax.make_jaxpr(lambda n: fun(jnp.ones(n + n, jnp.dtype('float32'))) - )(3) - # { lambda ; a:i32[]. let - # b:i32[] = add a a - # c:f32[b] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 1.0 b - # d:f32[] = pjit[ - # jaxpr={ lambda ; e:i32[] f:f32[e]. let - # g:f32[] = reduce_sum[axes=(0,)] f - # in (g,) } - # name=f - # ] b c - # in (d,) } - a, = jaxpr.jaxpr.invars - e1, e2, e3 = jaxpr.jaxpr.eqns - b, = e1.outvars - c, = e2.outvars - b_, c_ = e3.invars - self.assertIs(b, b_) - self.assertIs(c, c_) - - subjaxpr = e3.params['jaxpr'] - e, f = subjaxpr.jaxpr.invars - self.assertLen(e.aval.shape, 0) - self.assertLen(f.aval.shape, 1) - self.assertIs(f.aval.shape[0], e) - - def test_jit_abstracted_axes_staging3(self): - f = jax.jit(jnp.sum, abstracted_axes=('n',)) - jaxpr = jax.make_jaxpr(f, abstracted_axes=('n',))(jnp.arange(3.)) - # { lambda ; a:i32[] b:f32[a]. let - # c:f32[] = pjit[ - # jaxpr={ lambda ; d:i32[] e:f32[d]. let - # f:f32[] = reduce_sum[axes=(0,)] e - # in (f,) } - # name=sum - # ] a b - # in (c,) } - a, b = jaxpr.jaxpr.invars - e, = jaxpr.jaxpr.eqns - self.assertIs(e.invars[0], a) - self.assertIs(e.invars[1], b) - c, = e.outvars - self.assertLen(c.aval.shape, 0) - - subjaxpr = e.params['jaxpr'] - d, e = subjaxpr.jaxpr.invars - self.assertLen(d.aval.shape, 0) - self.assertLen(e.aval.shape, 1) - self.assertIs(e.aval.shape[0], d) - - def test_jit_abstracted_axes_return_polymorphic_shape(self): - f = jax.jit(lambda x: jnp.sin(x), abstracted_axes=('n',)) - jaxpr = jax.make_jaxpr(f)(jnp.arange(3)) # doesn't crash - # { lambda ; a:i32[3]. let - # b:i32[3] = pjit[ - # jaxpr={ lambda ; c:i32[] d:i32[c]. let in (d,) } - # name= - # ] 3 a - # in (b,) } - a, = jaxpr.jaxpr.invars - e, = jaxpr.jaxpr.eqns - three, a_ = e.invars - b, = e.outvars - self.assertIsInstance(three, core.Literal) - self.assertEqual(three.val, 3) - self.assertIs(a_, a) - self.assertLen(b.aval.shape, 1) - self.assertEqual(b.aval.shape[0], 3) - - def test_jit_abstracted_axes_return_polymorphic_shape2(self): - f = jax.jit(lambda n: jnp.ones(n)) - # TODO(mattjj,dougalm): support dynamic shapes in type checker - with jax.enable_checks(False): - jaxpr = jax.make_jaxpr(f)(3) - # { lambda ; a:i32[]. let - # b:f32[a] = pjit[ - # jaxpr={ lambda ; c:i32[]. let - # d:f32[c] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 1.0 - # c - # in (d,) } - # name= - # ] a - # in (b,) } - a, = jaxpr.jaxpr.invars - e, = jaxpr.jaxpr.eqns - a_, = e.invars - self.assertIs(a, a_) - b, = e.outvars - a__, = b.aval.shape - self.assertIs(a, a__) - - with jax.enable_checks(False): - jaxpr = jax.make_jaxpr(lambda: f(3))() - # { lambda ; . let - # a:f32[3] = pjit[ - # jaxpr={ lambda ; b:i32[]. let - # c:f32[b] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 1.0 - # b - # in (c,) } - # name= - # ] 3 - # in (a,) } - () = jaxpr.jaxpr.invars - e, = jaxpr.jaxpr.eqns - three, = e.invars - self.assertIsInstance(three, core.Literal) - self.assertEqual(three.val, 3) - b, = e.outvars - three_, = b.aval.shape - self.assertIsInstance(three_, int) - self.assertEqual(three_, 3) - - def test_zero_size_checking(self): - def f(x): - if core.definitely_equal(x.size, 0): - return x - else: - return -x - - x = jnp.zeros(1) - jaxpr = jax.make_jaxpr(f, abstracted_axes={0: 'n'})(x) # doesn't crash - self.assertGreaterEqual(len(jaxpr.jaxpr.eqns), 1) - - y = jnp.zeros((2, 0)) - jaxpr = jax.make_jaxpr(f, abstracted_axes={0: 'n'})(y) # doesn't crash - self.assertLen(jaxpr.jaxpr.eqns, 0) - - def test_flattening_basic(self): - x = jnp.zeros((2, 3, 4, 5)) - - # don't need to divide or multiply any dynamic axis sizes - jaxpr = jax.make_jaxpr(lambda x: x.reshape(x.shape[0], -1), - abstracted_axes={0: 'n'})(x) - self.assertLen(jaxpr.jaxpr.eqns, 1) - jaxpr = jax.make_jaxpr(lambda x: x.reshape(3, x.shape[0], -1), - abstracted_axes={0: 'n'})(x) - self.assertLen(jaxpr.jaxpr.eqns, 1) - jaxpr = jax.make_jaxpr(lambda x: x.reshape(-1, x.shape[0]), - abstracted_axes={0: 'n'})(x) - self.assertLen(jaxpr.jaxpr.eqns, 1) - - # don't need to divide but do need a dynamic axis size in multiplication - # (so to typecheck we'd need nontrivial reductions) - jaxpr = jax.make_jaxpr(lambda x: x.reshape(-1), - abstracted_axes={0: 'n'})(x) - self.assertLessEqual(len(jaxpr.jaxpr.eqns), 3) # may have mul with 1 - self.assertEqual(str(jaxpr.jaxpr.eqns[-2].primitive), 'mul') - self.assertEqual(str(jaxpr.jaxpr.eqns[-1].primitive), 'reshape') - jaxpr = jax.make_jaxpr(lambda x: x.reshape(2, -1), - abstracted_axes={0: 'n'})(x) - self.assertLessEqual(len(jaxpr.jaxpr.eqns), 3) - jaxpr = jax.make_jaxpr(lambda x: x.reshape(-1, 12), abstracted_axes={0: 'n'})(x) - self.assertLessEqual(len(jaxpr.jaxpr.eqns), 3) - - def test_shape_validation(self): - # Regression test for https://github.com/jax-ml/jax/issues/18937 - msg = r"Shapes must be 1D sequences of integer scalars, got .+" - with self.assertRaisesRegex(TypeError, msg): - jax.make_jaxpr(jnp.ones)(5.0) - with self.assertRaisesRegex(TypeError, msg): - jax.make_jaxpr(jnp.ones)(jnp.ones((2, 2))) - - def test_matmul_two_arg(self): - def f(x, y): - return jnp.matmul(x, y) - - jaxpr = jax.make_jaxpr(f, abstracted_axes=({0: 'a_0', 1: 'a_1'}, {0: 'a_1', 1: 'a_2'}),)(jnp.ones((8, 4)), jnp.ones((4, 8))) - - def test_matmul_two_arg_size_mismatch_name_validation(self): - def f(x, y): - return jnp.matmul(x, y) - - with self.assertRaisesRegex(TypeError, - re.escape("Provided size 4 for a_1 does not match prior associated name for a_1 : 8")): - jaxpr = jax.make_jaxpr(f, abstracted_axes=({0: 'a_0', 1: 'a_1'}, {0: 'a_1', 1: 'a_2'}),)(jnp.ones((8, 4)), jnp.ones((8, 4))) - -@unittest.skip("Test does not work with jax.Array") -@jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow") -class DynamicShapeAutodiffTest(jtu.JaxTestCase): - def test_jvp_broadcast(self): - @jax.jit - def fn(n, x): - return lax.broadcast_in_dim(x, (n,), ()) - - outer_jaxpr = jax.make_jaxpr( - lambda x, t: jax.jvp(lambda y: fn(3, y), (x,), (t,)) - )(3., 4.) - # { lambda ; a:f32[] b:f32[]. let - # c:f32[3] d:f32[3] = pjit[ - # jaxpr={ lambda ; e:i32[] f:f32[] g:f32[]. let - # h:f32[e] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] f e - # i:f32[e] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] g e - # in (h, i) } - # name=f - # ] 3 a b - # in (c, d) } - self.assertLen(outer_jaxpr.jaxpr.eqns, 1) - eqn, = outer_jaxpr.jaxpr.eqns - self.assertIn('jaxpr', eqn.params) - jaxpr = eqn.params['jaxpr'].jaxpr - self.assertLen(jaxpr.invars, 3) - e, f, g = jaxpr.invars - self.assertEqual(e.aval.shape, ()) - self.assertEqual(f.aval.shape, ()) - self.assertEqual(g.aval.shape, ()) - self.assertLen(jaxpr.outvars, 2) - h, i = jaxpr.outvars - self.assertEqual(h.aval.shape, (e,)) - self.assertEqual(i.aval.shape, (e,)) - self.assertLen(eqn.outvars, 2) - c, d = eqn.outvars - self.assertEqual(c.aval.shape, (3,)) - self.assertEqual(d.aval.shape, (3,)) - - def test_jvp_basic(self): - @partial(jax.jit, abstracted_axes=('n',)) - def foo(x): - return jnp.sin(x) - - x = t = jnp.arange(3.) - outer_jaxpr = jax.make_jaxpr(lambda x, t: jax.jvp(foo, (x,), (t,)))(x, t) - # { lambda ; a:f32[3] b:f32[3]. let - # c:f32[3] d:f32[3] = pjit[ - # jaxpr={ lambda ; e:i32[] f:f32[e] g:f32[e]. let - # h:f32[e] = sin f - # i:f32[e] = cos f - # j:f32[e] = mul g i - # in (h, j) } - # name=f - # ] 3 a b - # in (c, d) } - self.assertLen(outer_jaxpr.jaxpr.eqns, 1) - eqn, = outer_jaxpr.eqns - self.assertIn('jaxpr', eqn.params) - jaxpr = eqn.params['jaxpr'].jaxpr - self.assertLen(jaxpr.invars, 3) - e, f, g = jaxpr.invars - self.assertEqual(e.aval.shape, ()) - self.assertEqual(f.aval.shape, (e,)) - self.assertEqual(g.aval.shape, (e,)) - self.assertLen(jaxpr.outvars, 2) - self.assertLen(eqn.outvars, 2) - c, d = eqn.outvars - self.assertEqual(c.aval.shape, (3,)) - self.assertEqual(d.aval.shape, (3,)) - - def test_linearize_basic(self): - @partial(jax.jit, abstracted_axes=('n',)) - def foo(x): - return jax.lax.sin(x) - - x = jnp.arange(3.) - - # primal computation - outer_jaxpr = jax.make_jaxpr(lambda x: jax.linearize(foo, x))(x) - # { lambda ; a:f32[3]. let - # b:f32[3] c:f32[3] = pjit[ - # jaxpr={ lambda ; d:i32[] e:f32[d]. let - # f:f32[d] = sin e - # g:f32[d] = cos e - # in (f, g) } - # name=foo - # ] 3 a - # in (b, c) } - self.assertLen(outer_jaxpr.jaxpr.eqns, 1) - eqn, = outer_jaxpr.jaxpr.eqns - self.assertIn('jaxpr', eqn.params) - jaxpr = eqn.params['jaxpr'].jaxpr - self.assertLen(jaxpr.invars, 2) - d, e = jaxpr.invars - self.assertEqual(d.aval.shape, ()) - self.assertEqual(e.aval.shape, (d,)) - self.assertLen(jaxpr.eqns, 2) - self.assertLen(jaxpr.outvars, 2) - f, g = jaxpr.outvars - self.assertEqual(jaxpr.eqns[0].outvars, [f]) - self.assertEqual(jaxpr.eqns[1].outvars, [g]) - self.assertLen(eqn.outvars, 2) - b, c = eqn.outvars - self.assertEqual(b.aval.shape, (3,)) - self.assertEqual(c.aval.shape, (3,)) - - # primal and tangent computation - outer_jaxpr = jax.make_jaxpr( - lambda x, xdot: jax.linearize(foo, x)[1](xdot))(x, x) - # { lambda ; a:f32[3] b:f32[3]. let - # _:f32[3] c:f32[3] = pjit[ - # jaxpr={ lambda ; d:i32[] e:f32[d]. let - # f:f32[d] = sin e - # g:f32[d] = cos e - # in (f, g) } - # name=foo - # ] 3 a - # h:f32[3] = pjit[ - # jaxpr={ lambda ; i:i32[] j:f32[i] k:f32[i]. let - # l:f32[i] = mul k j - # in (l,) } - # name=foo - # ] 3 c b - # in (h,) } - self.assertLen(outer_jaxpr.jaxpr.eqns, 2) - _, eqn = outer_jaxpr.jaxpr.eqns - self.assertIn('jaxpr', eqn.params) - jaxpr = eqn.params['jaxpr'].jaxpr - self.assertLen(jaxpr.invars, 3) - i, j, k = jaxpr.invars - self.assertEqual(i.aval.shape, ()) - self.assertEqual(j.aval.shape, (i,)) - self.assertEqual(k.aval.shape, (i,)) - self.assertLen(eqn.outvars, 1) - h, = eqn.outvars - self.assertEqual(h.aval.shape, (3,)) - - def test_linearize_basic2(self): - @partial(jax.jit, abstracted_axes=('n',)) - def foo(x): - return jax.jit(jax.lax.sin)(x) - - x = jnp.arange(3.) - outer_jaxpr = jax.make_jaxpr(lambda x: jax.linearize(foo, x))(x) - # { lambda ; a:f32[3]. let - # b:f32[3] c:f32[3] = pjit[ - # jaxpr={ lambda ; d:i32[] e:f32[d]. let - # f:f32[d] g:f32[d] = pjit[ - # jaxpr={ lambda ; h:i32[] i:f32[h]. let - # j:f32[h] = sin i - # k:f32[h] = cos i - # in (j, k) } - # name=sin - # ] d e - # in (f, g) } - # name=foo - # ] 3 a - # in (b, c) } - self.assertLen(outer_jaxpr.jaxpr.eqns, 1) - eqn, = outer_jaxpr.jaxpr.eqns - self.assertLen(eqn.outvars, 2) - b, c = eqn.outvars - self.assertEqual(b.aval.shape, (3,)) - self.assertEqual(c.aval.shape, (3,)) - - def test_grad_basic(self): - @partial(jax.jit, abstracted_axes=('n',)) - def foo(x): - y = jax.lax.sin(x) - return y.sum() - - x = jnp.arange(3.) - outer_jaxpr = jax.make_jaxpr(jax.grad(foo))(x) - # { lambda ; a:f32[3]. let - # _:f32[] b:f32[3] = pjit[ - # jaxpr={ lambda ; c:i32[] d:f32[c]. let - # e:f32[c] = sin d - # f:f32[c] = cos d - # g:f32[] = reduce_sum[axes=(0,)] e - # in (g, f) } - # name=foo - # ] 3 a - # h:f32[3] = pjit[ - # jaxpr={ lambda ; i:i32[] j:f32[i] k:f32[]. let - # l:f32[i] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] k i - # m:f32[i] = mul l j - # in (m,) } - # name=foo - # ] 3 b 1.0 - # in (h,) } - self.assertLen(outer_jaxpr.jaxpr.eqns, 2) - fwd_eqn, bwd_eqn = outer_jaxpr.jaxpr.eqns - self.assertIn('jaxpr', fwd_eqn.params) - fwd_jaxpr = fwd_eqn.params['jaxpr'].jaxpr - self.assertLen(fwd_jaxpr.invars, 2) - c, d = fwd_jaxpr.invars - self.assertEqual(c.aval.shape, ()) - self.assertEqual(d.aval.shape, (c,)) - self.assertLen(fwd_jaxpr.outvars, 2) - g, f = fwd_jaxpr.outvars - self.assertEqual(g.aval.shape, ()) - self.assertEqual(f.aval.shape, (c,)) - self.assertLen(fwd_eqn.outvars, 2) - _, b = fwd_eqn.outvars - self.assertEqual(b.aval.shape, (3,)) - self.assertIn('jaxpr', bwd_eqn.params) - bwd_jaxpr = bwd_eqn.params['jaxpr'].jaxpr - self.assertLen(bwd_jaxpr.invars, 3) - i, j, k = bwd_jaxpr.invars - self.assertEqual(i.aval.shape, ()) - self.assertEqual(j.aval.shape, (i,)) - self.assertEqual(k.aval.shape, ()) - self.assertLen(bwd_jaxpr.outvars, 1) - m, = bwd_jaxpr.outvars - self.assertEqual(m.aval.shape, (i,)) - self.assertLen(bwd_eqn.outvars, 1) - h, = bwd_eqn.outvars - self.assertEqual(h.aval.shape, (3,)) - - def test_mlp_autodiff_dynamic_batch_toplevel(self): - def predict(params, inputs): - for W, b in params: - outputs = jnp.dot(inputs, W) + b - inputs = jnp.maximum(0, outputs) - return outputs - - def loss(params, batch): - inputs, targets = batch - predictions = predict(params, inputs) - return jnp.sum((predictions - targets) ** 2) - - batch = (inputs, targets) = (jnp.ones((128, 784)), jnp.ones((128, 10))) - params = [(jnp.ones((784, 256)), jnp.ones(256)), - (jnp.ones((256, 256)), jnp.ones(256)), - (jnp.ones((256, 10)), jnp.ones( 10))] - - # jvp - def loss_jvp(params, batch): - return jax.jvp(loss, (params, batch), (params, batch)) - jaxpr = jax.make_jaxpr(loss_jvp, abstracted_axes=({}, {0: 'n'}))(params, batch) - core.check_jaxpr(jaxpr.jaxpr) - - # linearize - def loss_lin(params, batch): - y, f_lin = jax.linearize(loss, params, batch) - y_dot = f_lin(params, batch) - return y, y_dot - jaxpr = jax.make_jaxpr(loss_lin, abstracted_axes=({}, {0: 'n'}))(params, batch) - core.check_jaxpr(jaxpr.jaxpr) - - # grad - jaxpr = jax.make_jaxpr(jax.grad(loss), abstracted_axes=({}, {0: 'n'}))(params, batch) - core.check_jaxpr(jaxpr.jaxpr) - - def test_mlp_autodiff_dynamic_batch_inner(self): - # This is like the above 'toplevel' test, but instead of introducing - # abstracted axes on the make_jaxpr call, we do it on a jit. - - @partial(jax.jit, abstracted_axes=({}, {0: 'n'})) - def predict(params, inputs): - for W, b in params: - outputs = jnp.dot(inputs, W) + b - inputs = jnp.maximum(0, outputs) - return outputs - - def loss(params, batch): - inputs, targets = batch - predictions = predict(params, inputs) - return jnp.sum((predictions - targets) ** 2) - - batch = (inputs, targets) = (jnp.ones((128, 784)), jnp.ones((128, 10))) - params = [(jnp.ones((784, 256)), jnp.ones(256)), - (jnp.ones((256, 256)), jnp.ones(256)), - (jnp.ones((256, 10)), jnp.ones( 10))] - - # jvp - def loss_jvp(params, batch): - return jax.jvp(loss, (params, batch), (params, batch)) - jaxpr = jax.make_jaxpr(loss_jvp)(params, batch) - core.check_jaxpr(jaxpr.jaxpr) - - # linearize - def loss_lin(params, batch): - y, f_lin = jax.linearize(loss, params, batch) - y_dot = f_lin(params, batch) - return y, y_dot - jaxpr = jax.make_jaxpr(loss_lin)(params, batch) - core.check_jaxpr(jaxpr.jaxpr) - - # grad - jaxpr = jax.make_jaxpr(jax.grad(loss))(params, batch) - core.check_jaxpr(jaxpr.jaxpr) - - def test_bint_broadcast(self): - d = lax.convert_element_type(3, core.bint(5)) - bint = lambda x, b: lax.convert_element_type(x, core.bint(b)) - - x = lax.broadcast_in_dim(0, (d,), ()) # doesn't crash - self.assertIsInstance(x, core.DArray) - self.assertAllClose(x._data, np.zeros(5, dtype='int32'), check_dtypes=False) - self.assertEqual( - x._aval, core.DShapedArray((bint(3, 5),), x._data.dtype, True)) - - def f(n): - return jnp.zeros(n) - x = jax.jit(f)(d) - self.assertIsInstance(x, core.DArray) - self.assertAllClose(x._data, np.zeros(5, dtype='int32'), check_dtypes=False) - self.assertEqual( - x._aval, core.DShapedArray((bint(3, 5),), x._data.dtype, False)) - - jaxpr = jax.make_jaxpr(f)(d).jaxpr - # { lambda ; a:bint{≤5}[]. let - # b:f32[a] = broadcast_in_dim[...] 0.0 a - # in (b,) } - self.assertLen(jaxpr.invars, 1) - a, = jaxpr.invars - self.assertEqual(a.aval, core.DShapedArray((), core.bint(5))) - self.assertLen(jaxpr.eqns, 1) - eqn, = jaxpr.eqns - self.assertLen(eqn.outvars, 1) - b, = eqn.outvars - self.assertEqual(b.aval.shape, (a,)) - - def test_vmap_abstracted_axis(self): - def foo(x, y): - z = jax.vmap(jnp.sin)(x) * y - return jax.vmap(jnp.add)(x, z) - - x = jnp.arange(3.) - jaxpr = jax.make_jaxpr(foo, abstracted_axes=('n',))(x, x).jaxpr - self.assertLen(jaxpr.invars, 3) - a, b, c = jaxpr.invars - self.assertEqual(a.aval.shape, ()) - self.assertEqual(b.aval.shape, (a,)) - self.assertEqual(c.aval.shape, (a,)) - self.assertLen(jaxpr.eqns, 3) - self.assertLen(jaxpr.outvars, 1) - f, = jaxpr.outvars - self.assertEqual(f.aval.shape, (a,)) - - def test_vmap_abstracted_axes_2d(self): - def foo(x, y): - z = jax.vmap(jax.vmap(jnp.sin))(x) * y - return jax.vmap(jax.vmap(jnp.add))(x, z) - - x = jnp.arange(12.).reshape(3, 4) - jaxpr = jax.make_jaxpr(foo, abstracted_axes=('n', 'm'))(x, x).jaxpr - self.assertLen(jaxpr.invars, 4) - a, b, c, d = jaxpr.invars - self.assertEqual(a.aval.shape, ()) - self.assertEqual(b.aval.shape, ()) - self.assertEqual(c.aval.shape, (a, b)) - self.assertEqual(c.aval.shape, (a, b)) - self.assertLen(jaxpr.eqns, 3) - self.assertLen(jaxpr.outvars, 1) - f, = jaxpr.outvars - self.assertEqual(f.aval.shape, (a, b)) - - def test_vmap_of_indexing_basic(self): - x = jnp.arange(3.) - - def f(idxs): - return jax.vmap(lambda i: x[i])(idxs) - - idxs = jnp.arange(3) - jaxpr = jax.make_jaxpr(f, abstracted_axes=('n',))(idxs).jaxpr - # { lambda a:f32[3]; b:i32[] c:i32[b]. let - # d:bool[b] = lt c 0 - # e:i32[b] = add c 3 - # f:i32[b] = select_n d c e - # g:i32[b,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(None, 1)] f b - # h:f32[b,1] = gather[ - # dimension_numbers=GatherDimensionNumbers(offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)) - # fill_value=None - # indices_are_sorted=False - # mode=GatherScatterMode.PROMISE_IN_BOUNDS - # slice_sizes=(1,) - # unique_indices=False - # ] a g - # i:f32[b] = squeeze[dimensions=(1,)] h - # in (i,) } - b, _ = jaxpr.invars - e, = (e for e in jaxpr.eqns if str(e.primitive) == 'gather') - h, = e.outvars - self.assertEqual(h.aval.shape, (b, 1)) - - def test_einsum_basic(self): - x = jnp.arange(20.).reshape(4, 5) - - def f(x): - return jnp.einsum('ij,kj->ik', x, x) - - jaxpr = jax.make_jaxpr(f, abstracted_axes=('n', 'm'))(x).jaxpr - # { lambda ; a:i32[] b:i32[] c:f32[a,b]. let - # d:f32[a,a] = pjit[ - # jaxpr={ lambda ; e:i32[] f:i32[] g:f32[e,f] h:f32[e,f]. let - # i:f32[e,e] = dot_general[ - # dimension_numbers=(((1,), (1,)), ((), ())) - # precision=None - # preferred_element_type=None - # ] g h - # in (i,) } - # name=_einsum - # ] a b c c - # in (d,) } - self.assertLen(jaxpr.invars, 3) - a, b, c = jaxpr.invars - self.assertEqual(c.aval.shape[0], a) - self.assertLen(jaxpr.eqns, 1) - self.assertLen(jaxpr.eqns[0].outvars, 1) - d, = jaxpr.eqns[0].outvars - self.assertEqual(d.aval.shape, (a, a)) - - def test_inferring_valid_subjaxpr_type_add(self): - def f(x): - return x + x.shape[0] - - jax.make_jaxpr(f, abstracted_axes=('n',))(jnp.arange(3)) # doesn't crash - - def test_slicing_basic_jaxpr(self): - def f(x): - return x[0] - - jaxpr = jax.make_jaxpr(f, abstracted_axes=(None, 'n'))(jnp.zeros((3, 4))) - # { lambda ; a:i32[] b:f32[3,a]. let - # c:f32[1,a] = dynamic_slice[slice_sizes=(1, None)] b 0 0 a - # d:f32[a] = squeeze[dimensions=(0,)] c - # in (d,) } - self.assertLen(jaxpr.jaxpr.invars, 2) - a, _ = jaxpr.jaxpr.invars - self.assertLen(jaxpr.jaxpr.outvars, 1) - d, = jaxpr.jaxpr.outvars - self.assertLen(d.aval.shape, 1) - self.assertEqual(d.aval.shape, (a,)) - - def test_shape_tuple_argument_to_zeros(self): - @partial(jax.jit, abstracted_axes=(('n',), ('n',))) - def f(x, y): - zero = jnp.zeros(jnp.shape(x)) - return zero * y - - x = jnp.arange(3.0) - y = jnp.arange(3.0) + 1 - jax.make_jaxpr(f)(x, y) # doesn't crash - -@unittest.skip("Test does not work with jax.Array") -@jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow") -class DynamicShapeExecutionTest(jtu.JaxTestCase): - def test_jit_basic(self): - @jax.jit - def f(i): - return jnp.sum(jnp.ones(i, dtype='float32')) - self.assertAllClose(f(3), jnp.array(3., dtype='float32'), check_dtypes=True) - - def test_jit_basic_2(self): - count = 0 - - @partial(jax.jit, abstracted_axes=('n',)) - def f(x): - nonlocal count - count += 1 - return jnp.sum(x) - - x = f(np.arange(3)) - y = f(np.arange(4)) - self.assertAllClose(x, 3., check_dtypes=False) - self.assertAllClose(y, 6., check_dtypes=False) - self.assertEqual(count, 1) - - def test_jit_polymorphic_output(self): - # like test_jit_basic, but without the jnp.sum! - count = 0 - - @jax.jit - def f(i): - nonlocal count - count += 1 - return jnp.ones(i, dtype='float32') - - self.assertAllClose(f(3), np.ones(3, dtype='float32'), check_dtypes=True) - self.assertAllClose(f(4), np.ones(4, dtype='float32'), check_dtypes=True) - self.assertEqual(count, 1) - - @unittest.skip('TODO: need typechecking rule for concatenate') - def test_concatenate(self): - @partial(jax.jit, abstracted_axes=({0: 'n'},)) - def f(x): # x: f32[n, 4] - return jnp.concatenate([x, x, x], axis=0) - - f(np.ones((5, 4), dtype=np.float32)) - # TODO: add assertions - - def test_reshape(self): - @partial(jax.jit, abstracted_axes=({0: 'n'},)) - def f(x): # x: f32[n, 4] - return jnp.reshape(x, (2, -1)) - - f(np.ones((5, 4), dtype=np.float32)) - # TODO: add assertions - - def test_nested(self): - @jax.jit - def nested_f(x): # f32[h, v] -> f32[h, v] - # A nested call that needs shape variables - return jnp.sin(x) - - @partial(jax.jit, abstracted_axes=({0: 'h', 1: 'v'},)) - def f(x): # f32[h, w] -> f32[h, w] - return jnp.sin(x) + jax.jit(nested_f)(x) - f(np.ones((3, 5), dtype=np.float32)) - # TODO: add assertions - - def test_nested_arange(self): - def nested_f(x): # f32[h, v] -> f32[h, v] - # A nested call that needs to compute with shapes - return jnp.arange(x.shape[0] * x.shape[1], dtype=x.dtype).reshape(x.shape) - - @partial(jax.jit, abstracted_axes=({0: 'h', 1: 'w'},)) - def f(x): # f32[h, w] -> f32[h, w] - return x + jax.jit(nested_f)(x) - f(np.ones((3, 5), dtype=np.float32)) - # TODO: add assertions - - def test_transpose(self): - # see also https://github.com/iree-org/iree-jax/issues/57 - @partial(jax.jit, abstracted_axes=({0: 'h', 1: 'w'},)) - def f(x): # f32[h, w] -> f32[w, h] - return x.T - - f(np.ones((3, 5), dtype=np.float32)) # doesn't crash - # TODO: add assertions - - def test_matmul(self): - @partial(jax.jit, abstracted_axes=({0: 'w', 1: 'w'},)) - def f(x): # f32[w, w] -> f32[w, w] - return jnp.matmul(x, x) - - f(np.ones((5, 5), dtype=np.float32)) - # TODO: add assertions - - def test_matmul_shape_error(self): - @partial(jax.jit, abstracted_axes=({0: 'h', 1: 'w'},)) - def f(x): # f32[h, w] -> error - return jnp.matmul(x, x) - - # TODO(necula): improve error message, print actual shapes - with self.assertRaisesRegex(TypeError, - re.escape("dot_general requires contracting dimensions to have the same shape, got")): - f(np.ones((5, 5), dtype=np.float32)) - - @unittest.skip("TODO: investigate failure") - def test_cond(self): - @partial(jax.jit, abstracted_axes=({0: 'w', 1: 'w'},)) - def f(x): # f32[w, w] -> f32[w, w] - return lax.cond(True, - lambda x: jnp.sin(x), - lambda x: jnp.matmul(x, x), x) - f(np.ones((5, 5), dtype=np.float32)) - # TODO: add assertions - - def test_arange(self): - @partial(jax.jit, abstracted_axes=({0: 'w'},)) - def f(x): # f32[w] -> f32[w] - return jnp.arange(x.shape[0], dtype=x.dtype) + x - f(np.ones((5,), dtype=np.float32)) - # TODO: add assertions - - def test_broadcast(self): - @partial(jax.jit, abstracted_axes=({0: 'w'},)) - def f(x): # f32[w] -> f32[w, w] - return jnp.broadcast_to(x, (x.shape[0], x.shape[0])) - f(np.ones((5,), dtype=np.float32)) - # TODO: add assertions - - def test_zeros(self): - @partial(jax.jit, abstracted_axes=({0: 'w'},)) - def f(x): # f32[w] -> f32[w] - return jnp.zeros(x.shape[0], dtype=x.dtype) + x - f(np.ones((5,), dtype=np.float32)) - # TODO: add assertions - - def test_stack(self): - @partial(jax.jit, abstracted_axes=({0: 'w'},)) - def f(x): - return jnp.stack([jnp.sin(x), jnp.cos(x)]) - - f(np.ones((5,), dtype=np.float32)) - # TODO: add assertions - - def test_jit_dependent_pair_output(self): - # Like the above 'polymorhpic output' test, but now with a `2 * n`! - count = 0 - - @jax.jit - def f(n): - nonlocal count - count += 1 - return jnp.arange(2 * n) - - x = f(3) - y = f(4) - self.assertAllClose(x, jnp.arange(2 * 3), check_dtypes=False) - self.assertAllClose(y, jnp.arange(2 * 4), check_dtypes=False) - self.assertEqual(count, 1) - - @unittest.skip("revising slicing logic") - def test_slicing_basic(self): - f = jax.jit(lambda x, n: jnp.sum(x[:n])) - # TODO(mattjj): revise getslice, add typecheck rule for it, enable checks - with jax.enable_checks(False): - ans = f(jnp.arange(10), 3) - expected = jnp.sum(jnp.arange(10)[:3]) - self.assertAllClose(ans, expected, check_dtypes=True) - - # TODO(mattjj,dougalm,phawkins): debug iree failure, "failed to legalize - # operation 'while' that was explicitly marked illegal" - @unittest.skip("revising slicing logic") - def test_scan_basic(self): - def cumsum(x): - def body(i, _): - return i + 1, jnp.sum(x[:i+1]) - _, ans = lax.scan(body, 0, None, length=len(x)) - return ans - x = jnp.array([3, 1, 4, 1, 5, 9]) - with jax.enable_checks(False): - ans = cumsum(x) - expected = jnp.cumsum(x) - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_jit_of_broadcast(self): - x = jax.jit(jnp.ones)(3) - self.assertAllClose(x, jnp.ones(3)) - - def test_jit_of_broadcast2(self): - x = jax.jit(lambda n: jnp.ones(2 * n))(3) - self.assertAllClose(x, jnp.ones(2 * 3)) - - def test_mlp_autodiff_dynamic_batch(self): - count = 0 - - def predict(params, inputs): - for W, b in params: - outputs = jnp.dot(inputs, W) + b - inputs = jnp.maximum(0, outputs) - return outputs - - def loss_ref(params, batch): - nonlocal count - count += 1 # count retraces - inputs, targets = batch - predictions = predict(params, inputs) - return jnp.sum((predictions - targets) ** 2) - - loss = jax.jit(loss_ref, abstracted_axes=({}, {0: 'n'})) - - params = [(jnp.ones((784, 256)), jnp.ones(256)), - (jnp.ones((256, 10)), jnp.ones( 10))] - - # two different size batches - batch1 = (inputs, targets) = (jnp.ones((128, 784)), jnp.ones((128, 10))) - batch2 = (inputs, targets) = (jnp.ones((32, 784)), jnp.ones((32, 10))) - - _ = loss(params, batch1) - _ = loss(params, batch2) - self.assertEqual(count, 1) - - _ = jax.grad(loss)(params, batch1) - _ = jax.grad(loss)(params, batch2) - self.assertEqual(count, 2) - - ans = loss( params, batch1) - expected = loss_ref(params, batch1) - self.assertAllClose(ans, expected) - - ans = jax.grad(loss )(params, batch1) - expected = jax.grad(loss_ref)(params, batch1) - self.assertAllClose(ans, expected) - - @jax.enable_checks(False) # TODO(mattjj): upgrade typecompat to handle bints - def test_mlp_autodiff_dynamic_batch_bint(self): - count = 0 - - def predict(params, inputs): - for W, b in params: - outputs = jnp.dot(inputs, W) + b - inputs = jnp.maximum(0, outputs) - return outputs - - def loss_ref(params, batch): - nonlocal count - count += 1 # count traces - inputs, targets = batch - predictions = predict(params, inputs) - return jnp.sum((predictions - targets) ** 2) - - loss = jax.jit(loss_ref, abstracted_axes=({}, {0: 'n'})) - - params = [(jnp.ones((784, 256)), jnp.ones(256)), - (jnp.ones((256, 10)), jnp.ones( 10))] - - # two different batch sizes *with bints* - bs1 = jax.lax.convert_element_type(128, core.bint(128)) - batch1 = (jnp.ones((bs1, 784)), jnp.ones((bs1, 10))) - - bs2 = jax.lax.convert_element_type(32, core.bint(128)) - batch2 = (jnp.ones((bs2, 784)), jnp.ones((bs2, 10))) - - # count retraces (and don't crash) - self.assertEqual(count, 0) - _ = jax.grad(loss)(params, batch1) - self.assertEqual(count, 1) - g2 = jax.grad(loss)(params, batch2) - self.assertEqual(count, 1) # cache hit! - - # check the numbers make sense - batch = (jnp.ones((32, 784)), jnp.ones((32, 10))) - g2_expected = jax.grad(loss_ref)(params, batch) - self.assertAllClose(g2, g2_expected, check_dtypes=False, - atol=1e-3, rtol=1e-3) - - def test_bint_basic(self): - d = lax.convert_element_type(3, core.bint(5)) - self.assertEqual(str(d), '3{≤5}') - - @jax.jit - def f(d): - jnp.sin(3.) # don't have an empty jaxpr - return d - f(d) # doesn't crash - - def test_bint_iota(self): - def f(d): - return jnp.arange(d, dtype='int32') - - y = f(lax.convert_element_type(3, core.bint(5))) - self.assertIsInstance(y, core.DArray) - self.assertAllClose(y._data, np.arange(5), check_dtypes=False) - - d = lax.convert_element_type(3, core.bint(5)) - y = jax.jit(f)(d) - self.assertIsInstance(y, core.DArray) - self.assertAllClose(y._data, np.arange(5), check_dtypes=False) - - def test_bint_compilation_cache(self): - count = 0 - - @jax.jit - def f(n): - nonlocal count - count += 1 - return jnp.zeros(n) - f(lax.convert_element_type(3, core.bint(5))) - f(lax.convert_element_type(4, core.bint(5))) - self.assertEqual(count, 1) - - def test_bint_compilation_cache2(self): - count = 0 - - @partial(jax.jit, abstracted_axes=('n',)) - def f(x): - nonlocal count - count += 1 - return x.sum() - - d = lax.convert_element_type(3, core.bint(5)) - x = jnp.arange(d) - y = f(x) - self.assertEqual(y, 3) - self.assertEqual(count, 1) - - d = lax.convert_element_type(4, core.bint(5)) - x = jnp.arange(d) - y = f(x) - self.assertEqual(y, 6) - self.assertEqual(count, 1) - - d = lax.convert_element_type(4, core.bint(6)) - x = jnp.arange(d) - y = f(x) - self.assertEqual(y, 6) - self.assertEqual(count, 2) - - @unittest.skip('do we want to support this?') - def test_bint_add(self): - d = lax.convert_element_type(4, core.bint(6)) - x = jnp.arange(d) - - @jax.jit - def f(x): - return x + x - - f(x) # doesn't crash - - def test_lower_abstracted_axes(self): - @partial(jax.jit, abstracted_axes=('n',)) - def f(x): - return x.sum() - - f_lowered = f.lower(np.arange(3, dtype='int32')) - mlir_str = f_lowered.compiler_ir() - self.assertIn('tensor', str(mlir_str)) - - def test_lower_abstracted_axes_shapedtypestruct(self): - @partial(jax.jit, abstracted_axes=('n',)) - def f(x): - return x.sum() - - f_lowered = f.lower(jax.ShapeDtypeStruct((3,), np.int32)) - mlir_str = f_lowered.compiler_ir() - self.assertIn('tensor', str(mlir_str)) - - def test_slicing_basic_lower(self): - @partial(jax.jit, abstracted_axes=(None, 'n')) - def f(x): - return x[0] - f.lower(jnp.zeros((3, 4))).compiler_ir() # doesn't crash - - def test_slicing_basic_execute(self): - @partial(jax.jit, abstracted_axes=(None, 'n')) - def f(x): - return x[0] - - y = f(jnp.arange(3 * 4).reshape(3, 4)) - self.assertAllClose(y, jnp.array([0, 1, 2, 3])) - - def test_gather_basic_bounded(self): - x = jnp.arange(3. * 4.).reshape(3, 4) - - def f(i): - return x[i] - - sz = jax.lax.convert_element_type(2, core.bint(3)) - idx = jnp.arange(sz) - y = jax.jit(jax.vmap(f), abstracted_axes=('n',))(idx) - - self.assertIsInstance(y, core.DArray) - self.assertEqual(y.shape, (sz, 4)) - self.assertAllClose(y._data, x) - -@jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow", - jax_traceback_filtering='off') -class JumbleTest(jtu.JaxTestCase): - - def setUp(self): - super().setUp() - if jax.config.x64_enabled: raise unittest.SkipTest() - - @parameterized.parameters((True,), (False,)) - def test_internal_jumble(self, disable_jit): - with jax.disable_jit(disable_jit): - ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5)) - xs = jax.vmap(lambda n: jax.lax.iota('int32', n).sum())(ins) - self.assertAllClose(xs, jnp.array([3, 0, 6]), check_dtypes=False) - - def test_jumble_escapes(self): - ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5)) - xs = jax.vmap(jax.jit(lambda n: jax.lax.iota('int32', n)), - out_axes=batching.jumble_axis)(ins) - self.assertIsInstance(xs, batching.Jumble) - data = jax.lax.broadcasted_iota('int32', (3, 5), 1) - self.assertAllClose(xs.data, data, check_dtypes=False) - - def test_make_jumble_from_dynamic_shape(self): - # We may not want to support returning jumbles from vmapped functions - # (instead preferring to have a separate API which allows jumbles). But for - # now it makes for a convenient way to construct jumbles for the other - # tests! - ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5)) - p = jax.vmap(partial(jnp.arange, dtype='int32'), - out_axes=batching.jumble_axis)(ins) - self.assertIsInstance(p, batching.Jumble) - self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[bint\{≤5\}\[3\] with value: \[3 1 4\]\.Var[0-9]+\]') - data = jax.lax.broadcasted_iota('int32', (3, 5), 1) - self.assertAllClose(p.data, data, check_dtypes=False) - - def test_jumble_map_eltwise(self): - ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5)) - p = jax.vmap(partial(jnp.arange, dtype='int32'), - out_axes=batching.jumble_axis)(ins) - p = jumble_map(jax.jit(lambda x: x * 3))(p) - self.assertIsInstance(p, batching.Jumble) - self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[bint\{≤5\}\[3\] with value: \[3 1 4\]\.Var[0-9]+\]') - data = jax.lax.broadcasted_iota('int32', (3, 5), 1) * 3 - self.assertAllClose(p.data, data, check_dtypes=False) - - def test_jumble_map_vector_dot(self): - ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5)) - p = jax.vmap(partial(jnp.arange, dtype='int32'), - out_axes=batching.jumble_axis)(ins) - y = jumble_map(jnp.dot)(p, p) - self.assertIsInstance(y, batching.Jumble) - self.assertAllClose(y.data, jnp.array([5, 0, 14], dtype='int32')) - - @parameterized.parameters((True,), (False,)) - def test_jumble_map_matrix_dot_ragged_contract(self, disable_jit): - with jax.disable_jit(disable_jit): - sizes = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5)) - p1 = jax.vmap(lambda n: jnp.ones((7, n)), out_axes=batching.jumble_axis - )(sizes) - p2 = jax.vmap(lambda n: jnp.ones((n, 7)), out_axes=batching.jumble_axis - )(sizes) - y = jax.vmap(jnp.dot, in_axes=batching.jumble_axis, out_axes=0, - axis_size=3)(p1, p2) - self.assertAllClose(y, np.tile(np.array([3, 1, 4])[:, None, None], (7, 7)), - check_dtypes=False) - - @parameterized.parameters((True,), (False,)) - def test_jumble_map_matrix_dot_ragged_tensor(self, disable_jit): - with jax.disable_jit(disable_jit): - sizes = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5)) - def func(size): - lhs_one_d = jnp.arange(size, dtype='int32') + 1 - lhs_two_d = jax.lax.broadcast_in_dim(lhs_one_d, (size, 2), (0,)) - rhs = jax.lax.broadcasted_iota('int32', (2, 4), 0) + 1 - return jnp.dot(lhs_two_d, rhs) - p = jax.vmap(func, out_axes=batching.jumble_axis)(sizes) - self.assertIsInstance(p, batching.Jumble) - self.assertEqual(p.data.shape, (3, 5, 4)) - - def test_broadcast_in_dim_while_ragged(self): - ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5)) - def func(size): - one_d = jnp.arange(size, dtype='int32') - two_d = jax.lax.broadcast_in_dim(one_d, (size, 7), (0,)) - return two_d - p = jax.vmap(func, out_axes=batching.jumble_axis)(ins) - self.assertIsInstance(p, batching.Jumble) - data = jax.lax.broadcasted_iota('int32', (3, 5, 7), 1) - self.assertAllClose(p.data, data) - - def test_broadcast_in_dim_to_ragged(self): - ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5)) - def func(size): - one_d = jnp.arange(12, dtype='int32') - two_d = jax.lax.broadcast_in_dim(one_d, (size, 12), (1,)) - return two_d - p = jax.vmap(func, out_axes=batching.jumble_axis)(ins) - self.assertIsInstance(p, batching.Jumble) - data = jax.lax.broadcasted_iota('int32', (3, 5, 12), 2) - self.assertAllClose(p.data, data) - - def test_broadcast_in_dim_ragged_to_static_error(self): - ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5)) - def func(size): - one_d = jnp.arange(size, dtype='int32') - # Broadcast should error even if the target shape is the same as the - # underlying data shape, because the semantic size doesn't match. - two_d = jax.lax.broadcast_in_dim(one_d, (4, 5), (1,)) - return two_d - msg = r"got operand of shape \(\[dynamic\],\), target broadcast shape \(4, 5\)" - with self.assertRaisesRegex(TypeError, msg): - jax.vmap(func, out_axes=batching.jumble_axis)(ins) - - def test_broadcast_in_dim_to_doubly_ragged(self): - ins1 = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5)) - ins2 = lax.convert_element_type(jnp.array([2, 5, 1]), core.bint(6)) - def func(size1, size2): - one_d = jnp.arange(size1, dtype='int32') - two_d = jax.lax.broadcast_in_dim(one_d, (size1, size2), (0,)) - return two_d - p = jax.vmap(func, out_axes=batching.jumble_axis)(ins1, ins2) - self.assertIsInstance(p, batching.Jumble) - data = jax.lax.broadcasted_iota('int32', (3, 5, 6), 1) - self.assertAllClose(p.data, data) - - def test_squeeze_ragged(self): - ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5)) - def func(size): - one_d = jnp.arange(size, dtype='int32') - two_d = jax.lax.broadcast_in_dim(one_d, (size, 1), (0,)) - one_again = jax.lax.squeeze(two_d, dimensions=[1]) - return one_again - p = jax.vmap(func, out_axes=batching.jumble_axis)(ins) - self.assertIsInstance(p, batching.Jumble) - data = jax.lax.broadcasted_iota('int32', (3, 5), 1) - self.assertAllClose(p.data, data) - - def test_broadcast_to_while_ragged(self): - ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5)) - def func(size): - one_d = jnp.arange(size, dtype='int32') - two_d = jnp.broadcast_to(one_d, (4, size)) - return two_d - p = jax.vmap(func, out_axes=batching.jumble_axis)(ins) - self.assertIsInstance(p, batching.Jumble) - data = jax.lax.broadcasted_iota('int32', (3, 4, 5), 2) - self.assertAllClose(p.data, data) - - def test_broadcast_to_doubly_ragged(self): - ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5)) - def func(size): - one_d = jnp.arange(size, dtype='int32') - two_d = jnp.broadcast_to(one_d, (size, size)) - return two_d - p = jax.vmap(func, out_axes=batching.jumble_axis)(ins) - self.assertIsInstance(p, batching.Jumble) - data = jax.lax.broadcasted_iota('int32', (3, 5, 5), 2) - self.assertAllClose(p.data, data) - - def test_transpose_ragged(self): - ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5)) - def func(size): - one_d = jnp.arange(size, dtype='int32') - two_d = jnp.broadcast_to(one_d, (7, size)) - return jnp.transpose(two_d, [1, 0]) - p = jax.vmap(func, out_axes=batching.jumble_axis)(ins) - self.assertIsInstance(p, batching.Jumble) - data = jax.lax.broadcasted_iota('int32', (3, 5, 7), 1) - self.assertAllClose(p.data, data) - - def test_einsum_with_ragged_tensor_dimension(self): - x_sizes = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5)) - def fprop_layer(x_size): - one_d = jnp.arange(x_size, dtype='int32') - x = jax.lax.broadcast_in_dim(one_d, (x_size, 11), [0]) - wqkv = jax.lax.broadcasted_iota('int32', (3, 2, 7, 11), 1) - qkv = jnp.einsum('te,ihqe->ithq', x, wqkv) - return qkv - p = jax.vmap(fprop_layer, out_axes=batching.jumble_axis)(x_sizes) - self.assertIsInstance(p, batching.Jumble) - self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[3,bint\{≤5\}\[3\] with value: \[3 1 4\]\.Var[0-9]+,2,7\]') - self.assertEqual(p.data.shape, (3, 3, 5, 2, 7)) - - @parameterized.parameters((True,), (False,)) - def test_einsum_with_ragged_tensor_and_contract_dimensions(self, disable_jit): - with jax.disable_jit(disable_jit): - ragged_sizes = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5)) - def fprop_layer(ragged_size): - one_d = jnp.arange(ragged_size, dtype='int32') - alpha = jax.lax.broadcast_in_dim(one_d, (ragged_size, ragged_size, 2), [1]) - v = jax.lax.broadcast_in_dim(one_d, (ragged_size, 2, 7), [0]) - inner = jnp.einsum('tsh,shq->thq', alpha, v) - return inner - p = jax.vmap(fprop_layer, out_axes=batching.jumble_axis)(ragged_sizes) - self.assertIsInstance(p, batching.Jumble) - self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[bint\{≤5\}\[3\] with value: \[3 1 4\]\.Var[0-9]+,2,7\]') - self.assertEqual(p.data.shape, (3, 5, 2, 7)) - - def test_split_while_ragged(self): - ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5)) - def func(size): - one_d = jnp.arange(size, dtype='int32') - two_d = jnp.broadcast_to(one_d, (2, size)) - part_1, part_2 = two_d - return part_1 - p = jax.vmap(func, out_axes=batching.jumble_axis)(ins) - self.assertIsInstance(p, batching.Jumble) - self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[bint\{≤5\}\[3\] with value: \[3 1 4\]\.Var[0-9]+\]') - data = jax.lax.broadcasted_iota('int32', (3, 5), 1) - self.assertAllClose(p.data, data) - - @parameterized.parameters((True,), (False,)) - @unittest.skip("test fails at head") - def test_jumble_map_end_to_end_fprop_layer(self, disable_jit): - - def fprop_layer(params, x): - ((xnorm_scale, xnorm_bias), (wqkv, wqkv_bias), (wo, wo_bias), - (ynorm_scale, ynorm_bias), (w_i, w_i_bias), (w_o, w_o_bias)) = params - xnorm = jax.nn.standardize(x) * xnorm_scale + xnorm_bias - qkv = jnp.einsum('te,ihqe->ithq', xnorm, wqkv) + wqkv_bias[:, None] - q, k, v = qkv - outer = jnp.einsum('thq,shq->tsh', q, k) / jnp.asarray( - jnp.sqrt(v.shape[-1]), dtype=x.dtype) - - alpha = jax.nn.softmax(outer, 2) - inner = jnp.einsum('tsh,shq->thq', alpha, v) - y = jnp.einsum('thq,hqe->te', inner, wo) + wo_bias + x - ynorm = jax.nn.standardize(y) * ynorm_scale + ynorm_bias - act = jax.nn.gelu(jnp.einsum('te,ef->tf', ynorm, w_i) + w_i_bias) - z = jnp.einsum('tf,fe->te', act, w_o) + w_o_bias + y - return z - - params = [ - (jnp.ones(128), jnp.zeros(128)), # xnorm_scale, xnorm_bias - (jnp.ones((3, 16, 64, 128)), jnp.zeros((3, 16, 64))), # wqkv, wqkv_bias - (jnp.ones((16, 64, 128)), jnp.zeros(128)), # wo, wo_bias - (jnp.ones(128), jnp.zeros(128)), # ynorm_scale, ynorm_bias - (jnp.ones((128, 4096)), jnp.zeros(4096)), # w_i, w_i_bias - (jnp.ones((4096, 128)), jnp.zeros(128)), # w_o, w_o_bias - ] - - xs = [ - jnp.zeros((512, 128)), - jnp.zeros((386, 128)), - jnp.zeros((420, 128)), - ] - - def jumble_stack(xs: list[jax.Array]) -> batching.Jumble: - max_length = max(len(x) for x in xs) - lengths = jnp.array([len(x) for x in xs]) - lengths = jax.lax.convert_element_type(lengths, core.bint(max_length)) - xs_padded = jnp.stack([jnp.zeros((max_length, 128), dtype=x.dtype - ).at[:x.shape[0]].set(x) for x in xs]) - - # binder = i - binder = core.Var('', core.ShapedArray((), np.dtype('int32'))) - # elt_ty = f32[[3, 1, 4].i, 128] - elt_ty = core.DShapedArray((batching.IndexedAxisSize(binder, lengths), 128), - xs_padded.dtype) - # aval = i:(Fin 3) => f32[[3, 1, 4].i, 128] - aval = batching.JumbleTy(binder, len(xs), elt_ty) - xs_jumble = batching.Jumble(aval, xs_padded) - return xs_jumble - - with jax.disable_jit(disable_jit): - xs_jumble = jumble_stack(xs) - - fprop_batched = jax.vmap(fprop_layer, - in_axes=(None, batching.jumble_axis), - out_axes=batching.jumble_axis, - axis_size=3) - result_jumble = fprop_batched(params, xs_jumble) - self.assertIsInstance(result_jumble, batching.Jumble) - regex = r'Var[0-9]+:3 => (f32|f64)\[bint\{≤512\}\[3\] with value: \[512 386 420\]\.Var[0-9]+,128\]' - self.assertRegex(str(result_jumble.aval), regex) - self.assertAllClose(result_jumble.data.shape, (3, 512, 128)) - -def jumble_map(f): - def mapped(*jumbles): - return jax.vmap(f, in_axes=batching.jumble_axis, out_axes=batching.jumble_axis, - axis_size=jumbles[0].aval.length)(*jumbles) - return mapped - -if __name__ == '__main__': - absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/error_check_test.py b/tests/error_check_test.py index b96c6281411f..e20017a39a9b 100644 --- a/tests/error_check_test.py +++ b/tests/error_check_test.py @@ -13,12 +13,16 @@ # limitations under the License. +import traceback + from absl.testing import absltest from absl.testing import parameterized import jax from jax._src import config from jax._src import error_check +from jax._src import mesh as mesh_lib from jax._src import test_util as jtu +import jax.export import jax.numpy as jnp from jax.sharding import NamedSharding, PartitionSpec as P @@ -30,7 +34,9 @@ jtu.request_cpu_devices(4) -@jtu.with_config(jax_check_tracer_leaks=True) +# TODO: AOT tests fails with the tracer leak checker. +# Re-enable once https://github.com/jax-ml/jax/issues/27315 is fixed. +# @jtu.with_config(jax_check_tracer_leaks=True) class ErrorCheckTests(jtu.JaxTestCase): @parameterized.product(jit=[True, False]) @@ -107,6 +113,32 @@ def g(x): with self.assertRaisesRegex(JaxValueError, "x must be greater than 0 in g"): error_check.raise_if_error() + @parameterized.product(jit=[True, False]) + def test_error_includes_traceback(self, jit): + def function_that_triggers_error_for_traceback_test(x): + error_check.set_error_if( # This line must be included in the traceback. + x <= 0, "x must be greater than 0" + ) + return x + 1 + + if jit: + function_that_triggers_error_for_traceback_test = jax.jit( + function_that_triggers_error_for_traceback_test + ) + + x = jnp.zeros((4,), dtype=jnp.int32) + function_that_triggers_error_for_traceback_test(x) + + tb_string = "" + try: + error_check.raise_if_error() + except JaxValueError as e: + tb_string = traceback.format_tb(e.__traceback__) + tb_string = "".join(tb_string) + + self.assertIn("function_that_triggers_error_for_traceback_test", tb_string) + self.assertIn("This line must be included in the traceback", tb_string) + @parameterized.product(jit=[True, False]) def test_error_check_works_with_cond(self, jit): def f(x): @@ -193,7 +225,7 @@ def f(x): jax.jit(error_check.raise_if_error)() @parameterized.product(jit=[True, False]) - @jtu.with_user_mesh((2, 2), ("x", "y")) + @jtu.with_explicit_mesh((2, 2), ("x", "y")) def test_error_check_explicit_mode(self, mesh, jit): def f(x): error_check.set_error_if(x <= 0, "x must be greater than 0") @@ -202,13 +234,144 @@ def f(x): if jit: f = jax.jit(f) - sharding = NamedSharding(mesh, P("x", "y")) - x = jnp.full((4, 4), -1, dtype=jnp.int32, device=sharding) with error_check.error_checking_context(): + x = jnp.full((4, 4), -1, dtype=jnp.int32) + f(x) + with self.assertRaisesRegex(JaxValueError, "x must be greater than 0"): + error_check.raise_if_error() + + sharding = NamedSharding(mesh, P("x", "y")) + with error_check.error_checking_context(): + y = jnp.full((4, 4), -1, dtype=jnp.int32, device=sharding) + f(y) + with self.assertRaisesRegex(JaxValueError, "x must be greater than 0"): + error_check.raise_if_error() + + # The unsharded version of `f` should still be able to check errors after + # exiting the error checking context. + f(x) + with self.assertRaisesRegex(JaxValueError, "x must be greater than 0"): + error_check.raise_if_error() + + @parameterized.product(jit=[True, False]) + @jtu.with_explicit_mesh( + (2, 2), + ("x", "y"), + axis_types=(mesh_lib.AxisType.Auto, mesh_lib.AxisType.Auto), + ) + @jtu.ignore_warning( + message=( + "When at least one mesh axis of `pred` is in auto mode, calling" + " `set_error_if` will cause implicit communication between devices." + " To avoid this, consider converting the mesh axis in auto mode to" + " explicit mode." + ), + category=RuntimeWarning, + ) + def test_error_check_auto_mode(self, jit, mesh): + def f(x): + error_check.set_error_if(x <= 0, "x must be greater than 0") + return x + 1 + + if jit: + f = jax.jit(f) + + with error_check.error_checking_context(): + sharding = NamedSharding(mesh, P("x", "y")) + x = jnp.full((4, 4), -1, dtype=jnp.int32, device=sharding) f(x) with self.assertRaisesRegex(JaxValueError, "x must be greater than 0"): error_check.raise_if_error() + def test_error_check_aot(self): + def run_export(): + def f(x): + error_check.set_error_if(x <= 0, "x must be greater than 0") + return x + 1 + + f = jax.jit(error_check.wrap_for_export(jax.jit(f))) + x = jax.ShapeDtypeStruct((), jnp.float32) + serialized = jax.export.export(f)(x).serialize() + return serialized + + def run_import(serialized): + f = jax.export.deserialize(serialized).call + f = jax.jit(error_check.unwrap_from_import(jax.jit(f))) + x = jnp.float32(-3.) + _ = f(x) + with self.assertRaisesRegex(JaxValueError, "x must be greater than 0"): + error_check.raise_if_error() + + serialized = run_export() + run_import(serialized) + + def test_error_check_aot_includes_traceback(self): + def run_export(): + def function_that_triggers_error_for_traceback_test(x): + error_check.set_error_if( # This line must be included in the traceback + x <= 0, "x must be greater than 0" + ) + return x + 1 + + f = jax.jit( + error_check.wrap_for_export( + jax.jit(function_that_triggers_error_for_traceback_test) + ) + ) + x = jax.ShapeDtypeStruct((), jnp.float32) + serialized = jax.export.export(f)(x).serialize() + return serialized + + def run_import(serialized): + f = jax.export.deserialize(serialized).call + f = jax.jit(error_check.unwrap_from_import(jax.jit(f))) + x = jnp.float32(-3.0) + _ = f(x) + + msg = "" + try: + error_check.raise_if_error() + except JaxValueError as e: + msg = str(e) + + self.assertIn("function_that_triggers_error_for_traceback_test", msg) + self.assertIn("This line must be included in the traceback", msg) + + serialized = run_export() + run_import(serialized) + + def test_error_check_aot_should_not_override_existing_error(self): + def f1(x): + error_check.set_error_if(x <= 0, "x must be greater than 0 in f1") + return x + 1 + + def run_export(): + def f2(x): + error_check.set_error_if(x <= 0, "x must be greater than 0 in f2") + return x + 1 + + f2 = jax.jit(error_check.wrap_for_export(jax.jit(f2))) + x = jax.ShapeDtypeStruct((), jnp.float32) + serialized = jax.export.export(f2)(x).serialize() + return serialized + + def run_import(serialized): + f2 = jax.export.deserialize(serialized).call + f2 = jax.jit(error_check.unwrap_from_import(jax.jit(f2))) + return f2 + + x = jnp.float32(-3.) + _ = f1(x) # check fails. so it should set error + + serialized = run_export() + f2 = run_import(serialized) + _ = f2(x) # check fails, but should not override the error + + with self.assertRaisesRegex( + JaxValueError, "x must be greater than 0 in f1" + ): + error_check.raise_if_error() + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/errors_test.py b/tests/errors_test.py index 25f29cfee224..4b4ea2cfd487 100644 --- a/tests/errors_test.py +++ b/tests/errors_test.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib +import gc import re -import sys import traceback +import weakref from absl.testing import absltest from absl.testing import parameterized @@ -46,10 +48,7 @@ def check_filtered_stack_trace(test, etype, f, frame_patterns=(), e = get_exception(etype, f) c = e.__cause__ if filter_mode == "quiet_remove_frames": - if sys.version_info >= (3, 11): - assert any("For simplicity" in x for x in e.__notes__) - else: - test.assertIsInstance(c, jax.errors.SimplifiedTraceback) + assert any("For simplicity" in x for x in e.__notes__) elif filter_mode == "remove_frames": test.assertIsInstance(c, traceback_util.UnfilteredStackTrace) else: @@ -393,12 +392,8 @@ def outer(x): ('', 'f = lambda: outer'), ('outer', 'raise TypeError')], filter_mode=filter_mode) e = get_exception(TypeError, f) # Uses the default JAX_TRACEBACK_FILTERING=auto - if sys.version_info >= (3, 11): - assert any("For simplicity" in x for x in e.__notes__) - self.assertIsInstance(e.__cause__, ValueError) - else: - self.assertIsInstance(e.__cause__, jax.errors.SimplifiedTraceback) - self.assertIsInstance(e.__cause__.__cause__, ValueError) + assert any("For simplicity" in x for x in e.__notes__) + self.assertIsInstance(e.__cause__, ValueError) def test_null_traceback(self, filter_mode): class TestA: pass @@ -411,6 +406,39 @@ def err(): check_filtered_stack_trace(self, TypeError, err, [ ('err', 'return jit(f)(a)')], filter_mode=filter_mode) + def test_api_boundary_does_not_add_to_garbage(self, filter_mode): + self.enter_context(config.traceback_filtering(filter_mode)) + self.enter_context(disable_gc()) + + class MyObject: + def __call__(self): + f() + + @traceback_util.api_boundary + def f(): + g() + + @traceback_util.api_boundary + def g(): + raise ValueError('f') + + o = MyObject() + weak_o = weakref.ref(o) + try: + o() + except ValueError: + pass + del o + self.assertIsNone(weak_o()) + +@contextlib.contextmanager +def disable_gc(): + gc.disable() + gc.collect() + try: + yield + finally: + gc.enable() @jtu.with_config(jax_traceback_filtering='auto') # JaxTestCase defaults to off. class UserContextTracebackTest(jtu.JaxTestCase): @@ -424,14 +452,9 @@ def test_grad_norm(self): e = exc self.assertIsNot(e, None) self.assertIn("invalid value", str(e)) - if sys.version_info >= (3, 11): - self.assertIsInstance( - e.__cause__, - source_info_util.JaxStackTraceBeforeTransformation) - else: - self.assertIsInstance( - e.__cause__.__cause__, - source_info_util.JaxStackTraceBeforeTransformation) + self.assertIsInstance( + e.__cause__, + source_info_util.JaxStackTraceBeforeTransformation) class CustomErrorsTest(jtu.JaxTestCase): @@ -455,7 +478,7 @@ class FakeTracer(core.Tracer): ErrorClass = getattr(jax.errors, errorclass) err = ErrorClass(FakeTracer(None)) - self.assertIn(f'https://jax.readthedocs.io/en/latest/errors.html#jax.errors.{errorclass}', str(err)) + self.assertIn(f'https://docs.jax.dev/en/latest/errors.html#jax.errors.{errorclass}', str(err)) if __name__ == '__main__': diff --git a/tests/experimental_rnn_test.py b/tests/experimental_rnn_test.py index 7fa3b93f3c42..a86e3bd78da6 100644 --- a/tests/experimental_rnn_test.py +++ b/tests/experimental_rnn_test.py @@ -179,7 +179,7 @@ def f(weights, x, h_0, c_0): y_padded = y_ref[i, seq_lengths[i]:] np.testing.assert_allclose(y_padded, jnp.zeros_like(y_padded)) - @jtu.run_on_devices("cuda") + @jtu.run_on_devices("gpu") def test_struct_encoding_determinism(self): def f(k1, k2, k3, k4): batch_size = 1 @@ -213,18 +213,18 @@ def f(k1, k2, k3, k4): k = jax.random.split(jax.random.PRNGKey(1), 4) stablehlo = jax.jit(f).lower(*k).as_text("stablehlo") - if jtu.jaxlib_version() <= (0, 5, 2): - self.assertIn('"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00@\\01\\00\\00"', - stablehlo) - else: - self.assertIn('"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00\\00\\00\\00\\00@\\01\\00\\00\\00\\00\\00\\00"', - stablehlo) - - @jtu.run_on_devices("cuda") - def test_no_workspace_overflow(self): - if jtu.jaxlib_version() <= (0, 5, 2): - self.skipTest("Older versions fail because of integer overflow.") + # Platform-specific binary encodings for RnnDescriptor + cuda_encoding = '"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00\\00\\00\\00\\00@\\01\\00\\00\\00\\00\\00\\00"' + rocm_encoding = '"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\008\\00\\00\\00\\00\\00\\00\\00\\1C\\00\\00\\00\\00\\00\\00\\00"' + + # Check that one of the expected encodings is present + if jtu.test_device_matches(["cuda"]): + self.assertIn(cuda_encoding, stablehlo) + elif jtu.test_device_matches(["rocm"]): + self.assertIn(rocm_encoding, stablehlo) + @jtu.run_on_devices("cuda", "rocm") + def test_no_workspace_overflow(self): # Problem sizes known to cause overflows on older versions. batch_size, max_seq_length, input_size = 256, 500, 512 num_layers, hidden_size = 1, 256 diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index 9b457b8f27a5..478e63121bda 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -31,14 +31,15 @@ from jax._src.internal_test_util import export_back_compat_test_util as bctu +from jax._src.internal_test_util.export_back_compat_test_data import annotate_data_placement from jax._src.internal_test_util.export_back_compat_test_data import cpu_cholesky_lapack_potrf +from jax._src.internal_test_util.export_back_compat_test_data import cuda_cholesky_solver_potrf from jax._src.internal_test_util.export_back_compat_test_data import cpu_eig_lapack_geev from jax._src.internal_test_util.export_back_compat_test_data import cuda_eigh_cusolver_syev from jax._src.internal_test_util.export_back_compat_test_data import rocm_eigh_hipsolver_syev from jax._src.internal_test_util.export_back_compat_test_data import cpu_eigh_lapack_syev from jax._src.internal_test_util.export_back_compat_test_data import cpu_lu_lapack_getrf from jax._src.internal_test_util.export_back_compat_test_data import cuda_qr_cusolver_geqrf -from jax._src.internal_test_util.export_back_compat_test_data import rocm_qr_hipsolver_geqrf from jax._src.internal_test_util.export_back_compat_test_data import cpu_qr_lapack_geqrf from jax._src.internal_test_util.export_back_compat_test_data import cpu_schur_lapack_gees from jax._src.internal_test_util.export_back_compat_test_data import cpu_svd_lapack_gesdd @@ -51,6 +52,7 @@ from jax._src.internal_test_util.export_back_compat_test_data import cuda_lu_cusolver_getrf from jax._src.internal_test_util.export_back_compat_test_data import cuda_svd_cusolver_gesvd from jax._src.internal_test_util.export_back_compat_test_data import cuda_tridiagonal_cusolver_sytrd +from jax._src.internal_test_util.export_back_compat_test_data import cuda_tridiagonal_solve from jax._src.internal_test_util.export_back_compat_test_data import tpu_Eigh from jax._src.internal_test_util.export_back_compat_test_data import tpu_Lu from jax._src.internal_test_util.export_back_compat_test_data import tpu_ApproxTopK @@ -62,8 +64,7 @@ from jax._src.internal_test_util.export_back_compat_test_data import stablehlo_dynamic_top_k from jax._src.internal_test_util.export_back_compat_test_data import stablehlo_dynamic_approx_top_k -from jax.experimental import pjit -from jax.experimental.shard_map import shard_map +from jax._src.shard_map import shard_map import jax.numpy as jnp from jax.sharding import Mesh @@ -72,17 +73,10 @@ from jax._src import config from jax._src import test_util as jtu -from jax._src.lib import cuda_versions config.parse_flags_with_absl() -def _is_required_cusolver_version_satisfied(required_version): - if cuda_versions is None: - return False - return cuda_versions.cusolver_get_version() >= required_version - - @jtu.with_config(jax_legacy_prng_key="allow", jax_debug_key_reuse=False, jax_include_full_tracebacks_in_locations=False, @@ -120,12 +114,12 @@ def test_custom_call_coverage(self): targets_to_cover = set(_export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE) cpu_ffi_testdatas = [ cpu_cholesky_lapack_potrf.data_2024_05_31, - cpu_qr_lapack_geqrf.data_2024_08_22, + cpu_qr_lapack_geqrf.data_2025_04_02, cpu_eig_lapack_geev.data_2024_08_19, cpu_eigh_lapack_syev.data_2024_08_19, cpu_lu_lapack_getrf.data_2024_05_31, cpu_schur_lapack_gees.data_2024_11_29, - cpu_triangular_solve_blas_trsm.data_2024_12_02, + cpu_triangular_solve_blas_trsm.data_2025_10_20, cpu_svd_lapack_gesdd.data_2024_08_13, cpu_hessenberg_lapack_gehrd.data_2024_08_31, cpu_tridiagonal_lapack_sytrd_hetrd.data_2024_12_01, @@ -134,28 +128,19 @@ def test_custom_call_coverage(self): # stable covering_testdatas = [ *cpu_ffi_testdatas, - cpu_cholesky_lapack_potrf.data_2023_06_19, - cpu_eig_lapack_geev.data_2023_06_19, - cpu_eigh_lapack_syev.data_2023_03_17, - cpu_qr_lapack_geqrf.data_2023_03_17, + cuda_cholesky_solver_potrf.data_2025_10_15, cuda_threefry2x32.data_2024_07_30, - cpu_lu_lapack_getrf.data_2023_06_14, - cuda_lu_pivots_to_permutation.data_2024_08_08, + cuda_lu_pivots_to_permutation.data_2025_04_01, cuda_lu_cusolver_getrf.data_2024_08_19, cuda_qr_cusolver_geqrf.data_2024_09_26, cuda_eigh_cusolver_syev.data_2024_09_30, cuda_svd_cusolver_gesvd.data_2024_10_08, cpu_tridiagonal_solve_lapack_gtsv.data_2025_01_09, cuda_tridiagonal_cusolver_sytrd.data_2025_01_09, - rocm_qr_hipsolver_geqrf.data_2024_08_05, + cuda_tridiagonal_solve.data_2025_06_16, rocm_eigh_hipsolver_syev.data_2024_08_05, - cpu_schur_lapack_gees.data_2023_07_16, - cpu_svd_lapack_gesdd.data_2023_06_19, - cpu_triangular_solve_blas_trsm.data_2023_07_16, - cpu_hessenberg_lapack_gehrd.data_2024_08_30, - cpu_tridiagonal_lapack_sytrd_hetrd.data_2024_09_03, tpu_Eigh.data, tpu_Lu.data_2023_03_21, tpu_Qr.data_2023_03_17, - tpu_Sharding.data_2023_03_16, tpu_ApproxTopK.data_2023_04_17, + tpu_Sharding.data_2025_06_30, tpu_ApproxTopK.data_2023_04_17, tpu_ApproxTopK.data_2023_05_16, tpu_stablehlo_dynamic_reduce_window.data_unary_2023_06_17, tpu_stablehlo_dynamic_reduce_window.data_variadic_2023_06_17, @@ -163,6 +148,8 @@ def test_custom_call_coverage(self): stablehlo_dynamic_top_k.data_2023_07_16, stablehlo_dynamic_top_k.data_2023_08_11, # with shape_assertion stablehlo_dynamic_approx_top_k.data_2024_05_30, + annotate_data_placement.data_2025_04_07_tpu, + annotate_data_placement.data_2025_04_07_cuda, ] # Some of the above are nested structures. covering_testdatas = itertools.chain( @@ -172,14 +159,20 @@ def test_custom_call_coverage(self): self.assertIsInstance(data, bctu.CompatTestData) covered_targets = covered_targets.union(data.custom_call_targets) + # Note: add names of custom calls here only if you are sure that they are + # covered by tests that are somewhere else, or if have a good reason + # to believe that they are not going to be broken by JAX changes. covered_targets = covered_targets.union({ "tf.call_tf_function", # tested in jax2tf/tests/back_compat_tf_test.py "tpu_custom_call", # tested separately + "mosaic_gpu_v2", # tested in pallas/export_back_compat_pallas_test.py + "AllocateBuffer", # tested in pallas/export_back_compat_pallas_test.py "__gpu$xla.gpu.triton", # tested in pallas/export_back_compat_pallas_test.py # The following require ROCm to test "hip_lu_pivots_to_permutation", "hipsolver_getrf_ffi", "hipsolver_geqrf_ffi", "hipsolver_orgqr_ffi", "hipsolver_syevd_ffi", "hipsolver_gesvd_ffi", "hipsolver_gesvdj_ffi", + "hipsolver_potrf_ffi", }) not_covered = targets_to_cover.difference(covered_targets) self.assertEmpty(not_covered, @@ -212,9 +205,26 @@ def test_cpu_cholesky_lapack_potrf(self, dtype_name="f32"): data = self.load_testdata(info) self.run_one_test(func, data, rtol=rtol, atol=atol) - data = self.load_testdata(cpu_cholesky_lapack_potrf.data_2023_06_19[dtype_name]) - self.run_one_test(func, data, rtol=rtol, atol=atol, - expect_current_custom_calls=info["custom_call_targets"]) + @parameterized.named_parameters( + dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) + for dtype_name in ("f32", "f64", "c64", "c128")) + def test_gpu_cholesky_solver_potrf(self, dtype_name="f32"): + if not config.enable_x64.value and dtype_name in ["f64", "c128"]: + self.skipTest("Test disabled for x32 mode") + + dtype = dict(f32=np.float32, f64=np.float64, + c64=np.complex64, c128=np.complex128)[dtype_name] + shape = (4, 4) + input = self.cholesky_input(shape, dtype) + del input # Input is in the testdata, here for readability + func = lax.linalg.cholesky + + rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name] + atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name] + + info = cuda_cholesky_solver_potrf.data_2025_10_15[dtype_name] + data = self.load_testdata(info) + self.run_one_test(func, data, rtol=rtol, atol=atol) @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) @@ -276,10 +286,6 @@ def check_eigenvalue_is_in_array(eigenvalue, eigenvalues_array): data = self.load_testdata(info) self.run_one_test(func, data, rtol=rtol, atol=atol, check_results=check_eig_results) - data = self.load_testdata(cpu_eig_lapack_geev.data_2023_06_19[dtype_name]) - self.run_one_test(func, data, rtol=rtol, atol=atol, - check_results=check_eig_results, - expect_current_custom_calls=info["custom_call_targets"]) @staticmethod def eigh_input(shape, dtype): @@ -333,44 +339,6 @@ def test_cpu_eigh_lapack_syevd(self, dtype_name="f32"): self.run_one_test(func, data, rtol=rtol, atol=atol, check_results=partial(self.check_eigh_results, operand)) - # Legacy custom call test - data = self.load_testdata(cpu_eigh_lapack_syev.data_2023_03_17[dtype_name]) - self.run_one_test(func, data, rtol=rtol, atol=atol, - check_results=partial(self.check_eigh_results, operand), - expect_current_custom_calls=info["custom_call_targets"]) - - @parameterized.named_parameters( - dict(testcase_name=f"_dtype={dtype_name}_{variant}", - dtype_name=dtype_name, variant=variant) - for dtype_name in ("f32", "f64") - # We use different custom calls for sizes <= 32 - for variant in ["syevj", "syevd"]) - def test_gpu_eigh_solver_syev_legacy(self, dtype_name="f32", variant="syevj"): - if not config.enable_x64.value and dtype_name == "f64": - self.skipTest("Test disabled for x32 mode") - if jtu.test_device_matches(["rocm"]): - data = self.load_testdata(rocm_eigh_hipsolver_syev.data_2024_08_05[f"{dtype_name}_{variant}"]) - prefix = "hip" - elif jtu.test_device_matches(["cuda"]): - if _is_required_cusolver_version_satisfied(11600): - # The underlying problem is that this test assumes the workspace size can be - # queried from an older version of cuSOLVER and then be used in a newer one. - self.skipTest("Newer cuSOLVER expects a larger workspace than was serialized") - data = self.load_testdata(cuda_eigh_cusolver_syev.data_2023_03_17[f"{dtype_name}_{variant}"]) - prefix = "cu" - else: - self.skipTest("Unsupported platform") - # For lax.linalg.eigh - dtype = dict(f32=np.float32, f64=np.float64)[dtype_name] - size = dict(syevj=8, syevd=36)[variant] - rtol = dict(f32=1e-3, f64=1e-5)[dtype_name] - atol = dict(f32=1e-2, f64=1e-10)[dtype_name] - operand = CompatTest.eigh_input((size, size), dtype) - func = lambda: CompatTest.eigh_harness((size, size), dtype) - self.run_one_test(func, data, rtol=rtol, atol=atol, - check_results=partial(self.check_eigh_results, operand), - expect_current_custom_calls=[f"{prefix}solver_syevd_ffi"]) - @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) for dtype_name in ("f32", "f64", "c64", "c128")) @@ -411,14 +379,14 @@ def lu_pivots_to_permutation_harness(shape): def test_cuda_lu_pivots_to_permutation(self): shape = (2, 3, 4) func = lambda: CompatTest.lu_pivots_to_permutation_harness(shape) - data = self.load_testdata(cuda_lu_pivots_to_permutation.data_2024_08_08) + data = self.load_testdata(cuda_lu_pivots_to_permutation.data_2025_04_01) self.run_one_test(func, data) @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) for dtype_name in ("f32", "f64", "c64", "c128")) - def test_cuda_lu_lapack_getrf(self, dtype_name:str): + def test_cuda_lu_cusolver_getrf(self, dtype_name:str): if not config.enable_x64.value and dtype_name in ["f64", "c128"]: self.skipTest("Test disabled for x32 mode") dtype = dict(f32=np.float32, f64=np.float64, @@ -445,38 +413,10 @@ def test_cpu_qr_lapack_geqrf(self, dtype_name="f32"): c64=np.complex64, c128=np.complex128)[dtype_name] func = lambda: CompatTest.qr_harness((3, 3), dtype) - info = cpu_qr_lapack_geqrf.data_2024_08_22[dtype_name] + info = cpu_qr_lapack_geqrf.data_2025_04_02[dtype_name] data = self.load_testdata(info) self.run_one_test(func, data, rtol=rtol) - # TODO(b/369826500): Remove legacy custom call test after mid March 2025. - data = self.load_testdata(cpu_qr_lapack_geqrf.data_2023_03_17[dtype_name]) - self.run_one_test(func, data, rtol=rtol, - expect_current_custom_calls=info["custom_call_targets"]) - - # TODO(b/369826500): Remove legacy custom call test after mid March 2025. - @parameterized.named_parameters( - dict(testcase_name=f"_dtype={dtype_name}_{batched}", - dtype_name=dtype_name, batched=batched) - for dtype_name in ("f32",) - # For batched qr we use cublas_geqrf_batched/hipblas_geqrf_batched. - for batched in ("batched", "unbatched")) - def test_gpu_qr_solver_geqrf_legacy(self, dtype_name, batched): - if jtu.test_device_matches(["rocm"]): - data = self.load_testdata(rocm_qr_hipsolver_geqrf.data_2024_08_05[batched]) - prefix = "hip" - elif jtu.test_device_matches(["cuda"]): - data = self.load_testdata(cuda_qr_cusolver_geqrf.data_2023_03_18[batched]) - prefix = "cu" - else: - self.skipTest("Unsupported platform") - dtype = dict(f32=np.float32)[dtype_name] - rtol = dict(f32=1e-3)[dtype_name] - shape = dict(batched=(2, 3, 3), unbatched=(3, 3))[batched] - func = lambda: CompatTest.qr_harness(shape, dtype) - self.run_one_test(func, data, rtol=rtol, expect_current_custom_calls=[ - f"{prefix}solver_geqrf_ffi", f"{prefix}solver_orgqr_ffi"]) - @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) for dtype_name in ("f32", "f64", "c64", "c128")) @@ -551,14 +491,6 @@ def test_cpu_lu_lapack_getrf(self, dtype_name:str): check_results=partial(self.check_lu_results, operand, dtype=dtype)) - # TODO(b/357034884): Remove legacy custom call test after mid March 2025. - legacy_data = self.load_testdata( - cpu_lu_lapack_getrf.data_2023_06_14[dtype_name]) - self.run_one_test(func, legacy_data, rtol=rtol, atol=atol, - check_results=partial(self.check_lu_results, operand, - dtype=dtype), - expect_current_custom_calls=info["custom_call_targets"]) - def check_svd_results(self, input, res_run, res_exp, rtol=None, atol=None): # Following linalg_test.testSVD @@ -652,10 +584,6 @@ def check_schur_results(res_run, res_expected, *, rtol, atol): data = self.load_testdata(info) self.run_one_test(func, data, rtol=rtol, atol=atol, check_results=check_schur_results) - data = self.load_testdata(cpu_schur_lapack_gees.data_2023_07_16[dtype_name]) - self.run_one_test(func, data, rtol=rtol, atol=atol, - check_results=check_schur_results, - expect_current_custom_calls=info["custom_call_targets"]) @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) @@ -677,12 +605,6 @@ def func(operand): check_results=partial(self.check_svd_results, *data.inputs)) - data = self.load_testdata(cpu_svd_lapack_gesdd.data_2023_06_19[dtype_name]) - self.run_one_test(func, data, rtol=rtol, atol=atol, - check_results=partial(self.check_svd_results, - *data.inputs), - expect_current_custom_calls=info["custom_call_targets"]) - @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}_algorithm={algorithm_name}", dtype_name=dtype_name, algorithm_name=algorithm_name) @@ -740,16 +662,11 @@ def check_triangular_solve_results(res_run, res_expected, *, rtol, atol): y = matmul(a, x) if left_side else matmul(x, a) self.assertArraysAllClose(y, jnp.broadcast_to(b, y.shape), rtol=rtol, atol=atol) - info = cpu_triangular_solve_blas_trsm.data_2024_12_02[dtype_name] + info = cpu_triangular_solve_blas_trsm.data_2025_10_20[dtype_name] data = self.load_testdata(info) self.run_one_test(func, data, rtol=rtol, atol=atol, check_results=check_triangular_solve_results) - data = self.load_testdata(cpu_triangular_solve_blas_trsm.data_2023_07_16[dtype_name]) - self.run_one_test(func, data, rtol=rtol, atol=atol, - check_results=check_triangular_solve_results, - expect_current_custom_calls=info["custom_call_targets"]) - @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) for dtype_name in ("f32", "f64", "c64", "c128")) @@ -773,12 +690,6 @@ def func(): data = self.load_testdata(info) self.run_one_test(func, data, rtol=rtol, atol=atol) - data = self.load_testdata( - cpu_hessenberg_lapack_gehrd.data_2024_08_30[dtype_name] - ) - self.run_one_test(func, data, rtol=rtol, atol=atol, - expect_current_custom_calls=info["custom_call_targets"]) - @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) for dtype_name in ("f32", "f64", "c64", "c128")) @@ -802,12 +713,6 @@ def func(): data = self.load_testdata(info) self.run_one_test(func, data, rtol=rtol, atol=atol) - data = self.load_testdata( - cpu_tridiagonal_lapack_sytrd_hetrd.data_2024_09_03[dtype_name] - ) - self.run_one_test(func, data, rtol=rtol, atol=atol, - expect_current_custom_calls=info["custom_call_targets"]) - @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) for dtype_name in ("f32", "f64", "c64", "c128")) @@ -827,7 +732,7 @@ def test_cpu_tridiagonal_solve_lapack_gtsv(self, dtype_name): dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) for dtype_name in ("f32", "f64", "c64", "c128")) @jax.default_matmul_precision("float32") - def test_gpu_tridiagonal_solver_sytrd(self, dtype_name): + def test_gpu_tridiagonal_sytrd(self, dtype_name): if not config.enable_x64.value and dtype_name in ["f64", "c128"]: self.skipTest("Test disabled for x32 mode") @@ -842,7 +747,27 @@ def func(x): ) self.run_one_test(func, data, rtol=rtol, atol=atol) - def test_approx_top_k(self): + @parameterized.named_parameters( + dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) + for dtype_name in ("f32", "f64")) + @jax.default_matmul_precision("float32") + def test_gpu_tridiagonal_solve(self, dtype_name): + if not config.enable_x64.value and dtype_name == "f64": + self.skipTest("Test disabled for x32 mode") + + dtype = dict(f32=np.float32, f64=np.float64)[dtype_name] + def func(dl, d, du, b): + return lax.linalg.tridiagonal_solve(dl, d, du, b) + + rtol = dict(f32=1e-3, f64=1e-5)[dtype_name] + atol = dict(f32=1e-4, f64=1e-12)[dtype_name] + + data = self.load_testdata( + cuda_tridiagonal_solve.data_2025_06_16[dtype_name] + ) + self.run_one_test(func, data, atol=atol, rtol=rtol) + + def test_tpu_approx_top_k(self): def func(): x = np.array([3.0, 1.0, 4.0, 2.0, 5.0, 6.0, 7.0]) y = lax.approx_max_k(x, 3) @@ -859,27 +784,81 @@ def func(x): data = self.load_testdata(cuda_threefry2x32.data_2024_07_30) self.run_one_test(func, data) - def test_sharding(self): + def test_tpu_sharding(self): # Tests "Sharding", "SPMDShardToFullShape", "SPMDFullToShardShape" on TPU if not jtu.test_device_matches(["tpu"]) or len(jax.devices()) < 2: self.skipTest("Test runs only on TPU with at least 2 devices") # Must use exactly 2 devices for expected outputs from ppermute devices = jax.devices()[:2] - mesh = Mesh(devices, axis_names=('a')) + mesh = Mesh(devices, axis_names=("a")) - @partial(pjit.pjit, - in_shardings=(P('a', None),), out_shardings=P('a', None)) + @partial(jax.jit, + in_shardings=(NS(mesh, P("a", None)),), + out_shardings=NS(mesh, P("a", None))) @partial(shard_map, mesh=mesh, - in_specs=(P('a', None),), out_specs=P('a', None)) + in_specs=(P("a", None),), + out_specs=P("a", None)) def func(x): # b: f32[2, 4] - axis_size = lax.psum(1, 'a') + axis_size = lax.axis_size("a") perm = [(j, (j + 1) % axis_size) for j in range(axis_size)] - return lax.ppermute(x, 'a', perm=perm) + return lax.ppermute(x, "a", perm=perm) - data = self.load_testdata(tpu_Sharding.data_2023_03_16) - with mesh: - self.run_one_test(func, data) + data = [ + (tpu_Sharding.data_2025_06_30, None), + ] + # Due to changes in how Shardy is serialized, from using custom calls to + # natively serializing Shardy with StableHLO, we may need to override + # the expected custom call targets for old test data that was serialized + # with custom calls. + for data, custom_call_targets_override in data: + with mesh: + if jax.config.jax_use_shardy_partitioner: + self.run_one_test( + func, self.load_testdata(data["shardy"]), + expect_current_custom_calls=custom_call_targets_override) + else: + self.run_one_test(func, self.load_testdata(data["gspmd"])) + + + @parameterized.named_parameters( + dict(testcase_name=f"_platform={platform}", platform=platform) + for platform in ("tpu", "gpu")) + def test_annotate_device_placement(self, platform): + if not jtu.test_device_matches([platform]): + self.skipTest(f"Test enabled only for {platform}") + + mesh = Mesh(jax.local_devices()[0:1], axis_names=("a")) + + dev_sharding = NS(mesh, P("a")) + host_sharding = NS(mesh, P("a"), memory_kind="pinned_host") + + @partial(jax.jit, + in_shardings=(dev_sharding, host_sharding), + out_shardings=host_sharding) + def func(x, y): + return x + y + + if platform == "tpu": + data = [(annotate_data_placement.data_2025_04_07_tpu, + ["annotate_device_placement"]), + (annotate_data_placement.data_2025_06_30_tpu, None)] + else: + data = [(annotate_data_placement.data_2025_04_07_cuda, + ["annotate_device_placement"]), + (annotate_data_placement.data_2025_06_30_cuda, None)] + + # Due to changes in how Shardy is serialized, from using custom calls to + # natively serializing Shardy with StableHLO, we may need to override + # the expected custom call targets for old test data that was serialized + # with custom calls. + for data, custom_call_targets_override in data: + if jax.config.jax_use_shardy_partitioner: + self.run_one_test( + func, self.load_testdata(data["shardy"]), + expect_current_custom_calls=custom_call_targets_override) + else: + self.run_one_test(func, self.load_testdata(data["gspmd"])) def test_tpu_stablehlo_dynamic_reduce_window_unary(self): # stablehlo.dynamic_reduce_window is used temporarily on TPU for a @@ -1026,8 +1005,8 @@ def check_top_k_results(res_run, res_expected, *, rtol, atol): ) -@jtu.with_config(jax_use_shardy_partitioner=True) class ShardyCompatTest(bctu.CompatTestBase): + def test_shardy_sharding_ops_with_different_meshes(self): # Tests whether we can save and load a module with meshes that have the # same axis sizes (and same order) but different axis names. @@ -1044,15 +1023,25 @@ def func(x): # x: f32[4, 4] @partial(shard_map, mesh=old_mesh, in_specs=(P('a', None),), out_specs=P('a', None)) def shard_map_func(x): # b: f32[2, 4] - axis_size = lax.psum(1, 'a') + axis_size = lax.axis_size('a') perm = [(j, (j + 1) % axis_size) for j in range(axis_size)] return lax.ppermute(x, 'a', perm=perm) x = jax.lax.with_sharding_constraint(x, NS(old_mesh, P('a', None))) return shard_map_func(x) - data = self.load_testdata(shardy_sharding_ops_with_different_meshes.data_2025_02_12) - with Mesh(devices, axis_names=('x')): - self.run_one_test(func, data) + data = [ + (shardy_sharding_ops_with_different_meshes.data_2025_06_30, None), + ] + + # Due to changes in how Shardy is serialized, from using custom calls to + # natively serializing Shardy with StableHLO, we may need to override + # the expected custom call targets for old test data that was serialized + # with custom calls. + for data, custom_call_targets_override in data: + with Mesh(devices, axis_names=('x')): + self.run_one_test( + func, self.load_testdata(data), + expect_current_custom_calls=custom_call_targets_override) if __name__ == "__main__": diff --git a/tests/export_harnesses_multi_platform_test.py b/tests/export_harnesses_multi_platform_test.py index ef9d1e04c796..2eec0a1c59be 100644 --- a/tests/export_harnesses_multi_platform_test.py +++ b/tests/export_harnesses_multi_platform_test.py @@ -23,27 +23,17 @@ from collections.abc import Callable import math -import re from absl import logging from absl.testing import absltest - -import numpy as np - import jax from jax import export from jax import lax +from jax import random from jax._src import config from jax._src import test_util as jtu from jax._src.internal_test_util import test_harnesses -from jax import random - - -def make_disjunction_regexp(*parts: str) -> re.Pattern[str]: - if not parts: - return re.compile("matches_no_test") - else: - return re.compile("(" + "|".join(parts) + ")") +import numpy as np class PrimitiveTest(jtu.JaxTestCase): @@ -84,9 +74,13 @@ def test_prim(self, harness: test_harnesses.Harness): self.skipTest("Eigenvalues are sorted and it is not correct to compare " "decompositions for equality.") - if (jtu.device_under_test() == "gpu" - and "tridiagonal_solve_" in harness.fullname): - self.skipTest("tridiagonal_solve_ is not yet guaranteed stable.") + # Tridiagonal Solve (gtsv2) on ROCm is implemented but produces + # numerical errors as of at least ROCm 7.2 so gtsv2 cannot be marked + # as stable yet. + # TODO Re-enable this test when ROCm numerical errors in gtsv2 are fixed. + if "tridiagonal_solve" in harness.fullname and jtu.is_device_rocm(): + self.skipTest("Tridiagonal Solve (gtsv2) is currently" + "unsupported on ROCm") if harness.params.get("enable_xla", False): self.skipTest("enable_xla=False is not relevant") @@ -98,11 +92,6 @@ def test_prim(self, harness: test_harnesses.Harness): for l in harness.jax_unimplemented: if l.filter(dtype=harness.dtype): unimplemented_platforms = unimplemented_platforms.union(l.devices) - # Some primitive lowering rules need the GPU backend to be able to create - # CUDA lowering. - if ("tridiagonal_solve_" in harness.fullname - and all(d.platform != "gpu" for d in self.devices)): - unimplemented_platforms.add("gpu") if unimplemented_platforms: logging.info("Harness is not implemented on %s", unimplemented_platforms) @@ -121,6 +110,7 @@ def test_prim(self, harness: test_harnesses.Harness): def export_and_compare_to_native( self, func_jax: Callable, *args: jax.Array, + is_pmap: bool = False, unimplemented_platforms: set[str] = set(), skip_run_on_platforms: set[str] = set(), tol: float | None = None): @@ -133,7 +123,7 @@ def export_and_compare_to_native( # lowering_platforms uses "cuda" or "rocm" instead of "gpu" gpu_platform = "cuda" if jtu.is_device_rocm(): - gpu_platform = "rocm" + gpu_platform = "rocm" lowering_platforms: list[str] = [ p if p != "gpu" else gpu_platform for p in ("cpu", "gpu", "tpu") @@ -145,11 +135,23 @@ def export_and_compare_to_native( "Harness is uninteresting with fewer than 2 platforms" ) + fn = func_jax + # NOTE(dsuo): There are two issues with `shard_map.pmap` under `jit`: + # 1. `shard_map.pmap`'s default devices may be different from the devices + # where args live. + # 2. `shard_map.pmap` sets its mesh when the pmap is constructed. In this + # test, we subsequently change which devices the args are put on. + # In both cases, we get a `stages.DeviceAssignmentMismatchError`. + if is_pmap: + fn = jax.pmap(func_jax, axis_name="i") logging.info("Exporting harness for %s", lowering_platforms) - exp = export.export(jax.jit(func_jax), + exp = export.export(jax.jit(fn), platforms=lowering_platforms)(*args) for device in devices: + logging.info("Exporting harness for %s", device) + if is_pmap: + fn = jax.pmap(func_jax, axis_name="i", devices=[device]) if device.platform in skip_run_on_platforms: logging.info("Skipping running on %s", device) continue @@ -157,7 +159,7 @@ def export_and_compare_to_native( lambda x: jax.device_put(x, device), args ) logging.info("Running harness natively on %s", device) - native_res = jax.jit(func_jax)(*device_args) + native_res = jax.jit(fn)(*device_args) logging.info("Running exported harness on %s", device) exported_res = exp.call(*device_args) if tol is not None: @@ -166,13 +168,10 @@ def export_and_compare_to_native( # TODO(necula): Check HLO equivalence for the ultimate test. def test_psum_scatter(self): - f = jax.jit(jax.pmap(lambda x: lax.psum_scatter(x, 'i'), - axis_name='i', - devices=jax.devices()[:1])) - shape = (1, 1, 8) x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) - self.export_and_compare_to_native(f, x) + f = lambda x: lax.psum_scatter(x, "i") + self.export_and_compare_to_native(f, x, is_pmap=True) # The lowering rule for all_gather has special cases for bool. @jtu.parameterized_filterable( @@ -181,15 +180,12 @@ def test_psum_scatter(self): for dtype in [np.bool_, np.float32]], ) def test_all_gather(self, *, dtype): - f = jax.jit(jax.pmap(lambda x: lax.all_gather(x, 'i'), - axis_name='i', - devices=jax.devices()[:1])) - + f = lambda x: lax.all_gather(x, axis_name='i') shape = (1, 4) x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) if dtype == np.bool_: x = (x % 2).astype(np.bool_) - self.export_and_compare_to_native(f, x) + self.export_and_compare_to_native(f, x, is_pmap=True) def test_random_with_threefry_gpu_kernel_lowering(self): # On GPU we use a custom call for threefry2x32 diff --git a/tests/export_serialization_back_compat_test.py b/tests/export_serialization_back_compat_test.py new file mode 100644 index 000000000000..650b42e23d29 --- /dev/null +++ b/tests/export_serialization_back_compat_test.py @@ -0,0 +1,219 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for backwards compatibility of serialization of JAX exports. + +Whenever we change the serialization format for jax.export.Exported +(see file jax.export.serialization), we should first save a serialization +of the current format and add a test that it can be deserialized and it has +the expected behavior. + +To add a new test: + + * Create a new test method, with a function to be serialized that exercises + the feature you want to test, and a call to self.export_and_serialize. + You can follow the model of the tests below, which are parameterized by + the testdata. Use only `None` for the testdata parameter to signal that + you want to use a current serialization and not a saved one. + * Run the test. This will save the serialized data in + TEST_UNDECLARED_OUTPUTS_DIR (or "/tmp/back_compat_testdata" if not set). + * Copy the test data defined in the output file, to the file + jax._src.internal_test_util.export_back_compat_test_data.export_{name}.py. + * Add a new import statement to this file to import that module + +This process will ensure that the saved serialized export can be read by +future code version (backward compatibility of the deserializer). To check +forward compatibility you'd have to check out an older version of the code +and cherry pick a new version of the directory +`jax._src.internal_test_util.export_back_compat_test_data`. +""" + +import logging +import os +import re +from typing import Any + +from absl.testing import absltest +import numpy as np + +# ruff: noqa: F401 +try: + import flatbuffers + CAN_SERIALIZE = True +except (ModuleNotFoundError, ImportError): + CAN_SERIALIZE = False + +import jax +from jax._src import config +from jax._src import core +from jax._src.export import _export +from jax._src.export.serialization import _SERIALIZATION_VERSION +from jax.sharding import PartitionSpec as P +from jax._src import test_util as jtu + +from jax._src.internal_test_util.export_back_compat_test_data import export_with_specified_sharding +from jax._src.internal_test_util.export_back_compat_test_data import export_with_unspecified_sharding +from jax._src.internal_test_util.export_back_compat_test_data import export_with_memory_space + +config.parse_flags_with_absl() +jtu.request_cpu_devices(8) + + +class CompatTest(jtu.JaxTestCase): + + def setUp(self): + if not CAN_SERIALIZE: + self.skipTest("Serialization not available") + + def export_and_serialize(self, fun, *args, + vjp_order=0, + platforms=None, + **kwargs) -> bytearray: + """Export and serialize a function. + + The test data is saved in TEST_UNDECLARED_OUTPUTS_DIR (or + "/tmp/back_compat_testdata" if not set) and should be copied as explained + in the module docstring. + """ + exp = _export.export(fun, platforms=platforms)(*args, **kwargs) + serialized = exp.serialize(vjp_order=vjp_order) + updated_testdata = f""" + # Paste to the test data file (see export_serialization_back_compat_test.py module docstring) + dict( + serialization_version={_SERIALIZATION_VERSION}, + exported_serialized={serialized!r}, + ), + +""" + # Replace the word that should not appear. + updated_testdata = re.sub(r"google.", "googlex", updated_testdata) + output_dir = os.getenv("TEST_UNDECLARED_OUTPUTS_DIR", + "/tmp/back_compat_testdata") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + output_file_basename = f"export_{self._testMethodName.replace('test_', '')}.py" + output_file = os.path.join(output_dir, output_file_basename) + logging.info("Writing the updated serialized Exported at %s", output_file) + with open(output_file, "w") as f: + f.write(updated_testdata) + return serialized + + @jtu.parameterized_filterable( + kwargs=[ + dict(testdata=testdata, + testcase_name=("current" if testdata is None + else f"v{testdata['serialization_version']}")) + for testdata in [None, *export_with_specified_sharding.serializations] + ] + ) + def test_with_specified_sharding(self, testdata: dict[str, Any] | None): + if jtu.device_under_test() != "cpu": + self.skipTest("Testing only the CPU serialization") + a = np.arange(16 * 4, dtype=np.float32).reshape((16, 4)) + mesh = jtu.create_mesh((2,), "x") + with jax.set_mesh(mesh): + @jax.jit(in_shardings=(jax.sharding.NamedSharding(mesh, P("x", None),),), + out_shardings=jax.sharding.NamedSharding(mesh, P(None, "x"))) + def f(b): + return b * 2. + + a = jax.device_put(a, jax.sharding.NamedSharding(mesh, P("x", None))) + if testdata is None: + serialized = self.export_and_serialize(f, a) + else: + serialized = testdata["exported_serialized"] + + out = _export.deserialize(serialized).call(a) + self.assertAllClose(out, a * 2.) + self.assertEqual(out.addressable_shards[0].index, (slice(None), slice(0, 2))) + self.assertEqual(out.addressable_shards[1].index, (slice(None), slice(2, 4))) + + + @jtu.parameterized_filterable( + kwargs=[ + dict(testdata=testdata, + testcase_name=("current" if testdata is None + else f"v{testdata['serialization_version']}")) + for testdata in [None, *export_with_unspecified_sharding.serializations] + ] + ) + def test_with_unspecified_sharding(self, testdata: dict[str, Any] | None): + if jtu.device_under_test() != "cpu": + self.skipTest("Testing only the CPU serialization") + a = np.arange(16 * 4, dtype=np.float32).reshape((16, 4)) + + # Output sharding is not specified + mesh = jtu.create_mesh((2,), "x") + with jax.set_mesh(mesh): + @jax.jit(in_shardings=(jax.sharding.NamedSharding(mesh, P("x", None),),)) + def f(b): + return b * 2. + + a = jax.device_put(a, jax.sharding.NamedSharding(mesh, P("x", None))) + if testdata is None: + serialized = self.export_and_serialize(f, a) + else: + serialized = testdata["exported_serialized"] + + out = _export.deserialize(serialized).call(a) + self.assertAllClose(out, a * 2.) + self.assertEqual(out.addressable_shards[0].index, (slice(0, 8), slice(None))) + self.assertEqual(out.addressable_shards[1].index, (slice(8, 16), slice(None))) + + + @jtu.parameterized_filterable( + kwargs=[ + dict(testdata=testdata, + testcase_name=("current" if testdata is None + else f"v{testdata['serialization_version']}")) + for testdata in [None, *export_with_memory_space.serializations] + ] + ) + def test_with_memory_space(self, testdata: dict[str, Any] | None): + # This test is based on export_test.py::test_memory_space_from_arg + mesh = jtu.create_mesh((2,), "x") + with jax.set_mesh(mesh): + shd = jax.sharding.NamedSharding(mesh, P("x", None), + memory_kind="pinned_host") + a = jax.device_put(np.ones((2, 3), dtype=np.float32), shd) + f = jax.jit(lambda x: x) + + if testdata is None: + serialized = self.export_and_serialize( + f, a, platforms=("tpu", "cuda", "rocm")) + else: + # The testdata for the serialization formats has been generated and + # is awaiting review and merging upstream. It is not included + # downstream to avoid import errors. For this reason, the downstream + # version of this test is currently skipped. + # + # TODO: Remove this skip once the testdata and modified unit test + # are merged upstream. + if jtu.is_device_rocm(): + self.skipTest("Serialized export testdata for serialization format " + f"{testdata['serialization_version']} is not " + "currently available for ROCm.") + serialized = testdata["exported_serialized"] + + exported = _export.deserialize(serialized) + self.assertEqual(exported.in_avals[0].memory_space, core.MemorySpace.Host) + self.assertEqual(exported.out_avals[0].memory_space, core.MemorySpace.Host) + + if jtu.device_under_test() in ("tpu", "gpu"): + b = exported.call(a) + self.assertEqual(b.aval.memory_space, core.MemorySpace.Host) + self.assertEqual(b.sharding.memory_kind, a.sharding.memory_kind) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/export_test.py b/tests/export_test.py index 2b083f3121f4..4996e3258f18 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -19,7 +19,6 @@ import dataclasses import functools import logging -import json import math import re import unittest @@ -29,21 +28,22 @@ from jax import lax from jax import numpy as jnp from jax import export -from jax.experimental import pjit -from jax.experimental.shard_map import shard_map -from jax.sharding import NamedSharding -from jax.sharding import Mesh -from jax.sharding import PartitionSpec as P +from jax._src.shard_map import shard_map +from jax.sharding import (NamedSharding, Mesh, PartitionSpec as P, + reshard) +from jax._src.sharding_impls import GSPMDSharding from jax import tree_util from jax._src import config +from jax._src import compute_on from jax._src import core +from jax._src import dispatch from jax._src import dtypes from jax._src import effects from jax._src import test_util as jtu from jax._src import xla_bridge as xb from jax._src.interpreters import mlir - +from jax._src import lib from jax._src.lib.mlir.dialects import hlo import numpy as np @@ -204,7 +204,14 @@ def test_basic(self): f = jnp.sin x = np.arange(4, dtype=np.float32) exp_f = get_exported(f)(x) + self.assertAllClose(f(x), exp_f.call(x)) + def test_basic_single_device_sharding(self): + device = jax.local_devices()[0] + s = jax.sharding.SingleDeviceSharding(device) + x = np.arange(16, dtype=np.float32).reshape(4, -1) + f = jax.jit(lambda x: x * 2., in_shardings=s, out_shardings=s) + exp_f = get_exported(f)(x) self.assertAllClose(f(x), exp_f.call(x)) def test_jit_static_arg(self): @@ -281,6 +288,39 @@ def test_unused_args(self): self.assertAllClose(f(x, y), exp_f.call(x, y)) + def test_dict_non_string_key(self): + @jax.jit + def f(x_dict): + return x_dict[(0, 1)] + x_dict[(1, 2)] + + x_dict = {(0, 1): np.float32(42.), (1, 2): np.float32(43.)} + with self.assertRaisesRegex( + TypeError, "Serialization is supported only for dictionaries with string keys"): + get_exported(f)(x_dict) + + def test_closed_over_constant(self): + const_size = 100 + const = jax.random.uniform(jax.random.key(0), (const_size,), + dtype=np.float32) + + f = jax.jit(lambda x: x + const) + x = np.zeros((const_size,), dtype=np.float32) + exp_f = get_exported(f)(x) + + self.assertAllClose(f(x), exp_f.call(x)) + + def test_override_lowering_rules(self): + @jax.jit + def f(x): + return jnp.sin(x) + + def my_lowering_rule(ctx, arg, **_): + return mlir.hlo.CosineOp(arg).results + + exp = get_exported(f, _override_lowering_rules=( + (lax.sin_p, my_lowering_rule),))(42.) + self.assertIn("stablehlo.cosine", exp.mlir_module()) + def test_pytree(self): a = np.arange(4, dtype=np.float32) b = np.arange(6, dtype=np.float32) @@ -397,7 +437,7 @@ def f(x1, x2): exp = export.export(jax.jit(f))(x1, x2) res = exp.call(x1, x2) self.assertEqual(tree_util.tree_structure(res), - tree_util.tree_structure(((x1, x2, x1, x2)))) + tree_util.tree_structure((x1, x2, x1, x2))) self.assertEqual(type(res[0]), type(x1)) self.assertEqual(type(res[1]), type(x2)) self.assertEqual(type(res[2]), type(x1)) @@ -410,6 +450,18 @@ def f(x1, x2): self.assertEqual(tree_util.tree_structure(res2), tree_util.tree_structure(res)) + @jtu.parameterized_filterable( + kwargs=[dict(impl=p) + for p in ("rbg", "unsafe_rbg", "threefry2x32")]) + def test_prng_keys(self, *, impl): + + key = jax.random.key(42, impl=impl) + @jax.jit + def f(key): + return key + exp_f = get_exported(jax.jit(f))(key) + self.assertEqual(f(key), exp_f.call(key)) + def test_error_wrong_intree(self): def f(a_b_pair, *, c): return jnp.sin(a_b_pair[0]) + jnp.cos(a_b_pair[1]) + c @@ -511,33 +563,47 @@ def test_lowering_parameters_for_export(self): context = {} def test_primitive_lowering(ctx, arg): context["for_export"] = ctx.module_context.lowering_parameters.for_export + context["hoist_constants_as_args"] = ctx.module_context.lowering_parameters.hoist_constants_as_args context["export_ignore_forward_compatibility"] = ctx.module_context.lowering_parameters.export_ignore_forward_compatibility return mlir.hlo.AddOp(arg, arg).results mlir.register_lowering(test_primitive, test_primitive_lowering) + test_primitive.def_impl(functools.partial(dispatch.apply_primitive, + test_primitive)) self.addCleanup(lambda: mlir.register_lowering(test_primitive, None)) f = jax.jit(test_primitive.bind) a = np.arange(3, dtype=np.float32) context.clear() + + res = test_primitive.bind(a) # eager mode + self.assertAllClose(res, a + a) + self.assertEqual(context, + dict(for_export=False, + hoist_constants_as_args=config.use_simplified_jaxpr_constants.value, + export_ignore_forward_compatibility=False)) + res = f(a) # Works with JIT self.assertAllClose(res, a + a) self.assertEqual(context, dict(for_export=False, + hoist_constants_as_args=config.use_simplified_jaxpr_constants.value, export_ignore_forward_compatibility=False)) context.clear() - f.lower(a) # Works with most AOT - # The above was cached - self.assertEqual(context, {}) + if config.use_simplified_jaxpr_constants.value: + f.lower(a) # Works with most AOT + self.assertEqual(context, {}) # hit the cache _ = export.export(f)(a) self.assertEqual(context, dict(for_export=True, + hoist_constants_as_args=False, export_ignore_forward_compatibility=False)) context.clear() with config.export_ignore_forward_compatibility(True): _ = export.export(f)(a) self.assertEqual(context, dict(for_export=True, + hoist_constants_as_args=False, export_ignore_forward_compatibility=True)) def test_grad(self): @@ -941,7 +1007,7 @@ def outer(x): # x: outer_poly_spec "Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, c + b + a). " "Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), " "'b' = 0 from specification '2*b + a' for dimension args[0].shape[0] (= 2), . " - "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details." + "Please see https://docs.jax.dev/en/latest/export/shape_poly.html#shape-assertion-errors for more details." )), dict(shape=(3, 2, 6), # a = 2, b = 0.5, c = 4 - b is not integer poly_spec="(a + 2*b, a, a + b + c)", @@ -950,7 +1016,7 @@ def outer(x): # x: outer_poly_spec "Division had remainder 1 when computing the value of 'b'. " "Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, c + b + a). " "Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), . " - "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details." + "Please see https://docs.jax.dev/en/latest/export/shape_poly.html#shape-assertion-errors for more details." )), dict(shape=(8, 2, 6), # a = 2, b = 3 - inconsistency poly_spec="(a + 2*b, a, a + b)", @@ -960,7 +1026,7 @@ def outer(x): # x: outer_poly_spec "Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, b + a). " "Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), " "'b' = 4 from specification 'b + a' for dimension args[0].shape[2] (= 6), . " - "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details." + "Please see https://docs.jax.dev/en/latest/export/shape_poly.html#shape-assertion-errors for more details." )), dict(shape=(7, 2, 36), # a = 2, b = 3, c = 6 - cannot solve c poly_spec="(2 * a + b, a, c * c)", @@ -969,7 +1035,7 @@ def outer(x): # x: outer_poly_spec "We can only solve linear uni-variate constraints. " "Using the following polymorphic shapes specifications: args[0].shape = (b + 2*a, a, c^2). " "Unprocessed specifications: 'c^2' for dimension size args[0].shape[2]. " - "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details." + "Please see https://docs.jax.dev/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details." )), ]) def test_shape_constraints_errors(self, *, @@ -1161,7 +1227,7 @@ def test_input_shardings_unused_args(self): # We can use other devices and other meshes for running run_devices = devices[::-1] - run_mesh = Mesh(run_devices, "a") + run_mesh = Mesh(run_devices, "x") run_input_shardings = exp.in_shardings_jax(run_mesh) a_run = jax.device_put(a, run_input_shardings[0]) b_run = jax.device_put(a, run_input_shardings[1]) @@ -1169,11 +1235,54 @@ def test_input_shardings_unused_args(self): self.assertEqual(res.addressable_shards[0].device, run_devices[0]) self.assertEqual(res.addressable_shards[1].device, run_devices[1]) + @jtu.parameterized_filterable( + kwargs=[ + dict(testcase_name=f"_in={has_in_shardings}_out={has_out_shardings}", + has_in_shardings=has_in_shardings, + has_out_shardings=has_out_shardings) + for has_in_shardings in (False, True) + for has_out_shardings in (False, True) + ]) + def test_gspmdshardings(self, has_in_shardings, has_out_shardings): + if len(jax.devices()) < 2: + self.skipTest("Need at least 2 devices") + run_devices = jax.devices()[:2] + mesh = jtu.create_mesh((2, 1), ("x", "y")) + ns1 = NamedSharding(mesh, P(None)) + ns2 = NamedSharding(mesh, P("x")) + gs1 = GSPMDSharding(run_devices, ns1._to_xla_hlo_sharding(2)) + gs2 = GSPMDSharding(run_devices, ns2._to_xla_hlo_sharding(2)) + + jit_kwargs = {} + if has_in_shardings: + jit_kwargs["in_shardings"] = (gs1, gs2) + if has_out_shardings: + jit_kwargs["out_shardings"] = (gs1, gs2) + + f = jax.jit(lambda x1, x2: (x1, x2), **jit_kwargs) + x = jnp.arange(16 * 4, dtype=np.float32).reshape((16, 4)) + x1_dev = jax.device_put(x, gs1) + x2_dev = jax.device_put(x, gs2) + res = f(x1_dev, x2_dev) + + if not has_in_shardings: + exp = get_exported(f)(x1_dev, x2_dev) + else: + exp = get_exported(f)(x, x) + + with jax.set_mesh(mesh): + call_res = exp.call(x, x) # args are on default device + for r, cr in zip(res, call_res): + self.assertAllClose(r, cr) + self.assertEqual(len(r.addressable_shards), len(cr.addressable_shards)) + for rs, crs in zip(r.addressable_shards, cr.addressable_shards): + self.assertArraysEqual(rs.data, crs.data) + def test_export_abstract_mesh(self): if jax.local_device_count() < 2: self.skipTest("Need at least 2 devices") - abs_mesh = jax.sharding.AbstractMesh((2,), 'x') + abs_mesh = jax.sharding.AbstractMesh((2,), "x") input_sharding = jax.sharding.NamedSharding(abs_mesh, P("x", None)) output_sharding = jax.sharding.NamedSharding(abs_mesh, P(None, "x")) @jax.jit @@ -1219,12 +1328,12 @@ def f_without_shardings(x): res_exported = exp.call(b) self.assertAllClose(res_native, res_exported) - def test_call_with_different_no_of_devices_error_has_in_shardings(self): + def test_call_with_different_no_of_devices_in_shardings_success(self): if jax.local_device_count() < 2: self.skipTest("Need at least 2 devices") mesh_1 = Mesh(jax.local_devices()[:1], "i") - @functools.partial(pjit.pjit, + @functools.partial(jax.jit, in_shardings=NamedSharding(mesh_1, P("i"))) def f_with_sharding(x): return jnp.sum(x ** 2, axis=0) @@ -1232,6 +1341,7 @@ def f_with_sharding(x): a = jnp.arange(jax.device_count() * 10, dtype=np.float32).reshape( (jax.device_count(), 10) ) + res_native = f_with_sharding(a) exp = get_exported(f_with_sharding)(a) self.assertEqual(exp.nr_devices, 1) @@ -1239,11 +1349,33 @@ def f_with_sharding(x): run_mesh = Mesh(run_devices, "i") b = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, P("i"))) + res_exported = exp.call(b) + self.assertAllClose(res_native, res_exported) + + def test_call_with_different_no_of_devices_in_shardings_error(self): + if jax.local_device_count() < 3: + self.skipTest("Need at least 3 devices") + + mesh_1 = Mesh(jax.local_devices()[:2], "i") + @functools.partial(jax.jit, + in_shardings=NamedSharding(mesh_1, P("i"))) + def f_with_sharding(x): + return jnp.sum(x ** 2, axis=0) + + a = jnp.arange(jax.device_count() * 10, dtype=np.float32).reshape( + (jax.device_count(), 10) + ) + exp = get_exported(f_with_sharding)(a) + self.assertEqual(exp.nr_devices, 2) + + run_devices = jax.local_devices() + run_mesh = Mesh(run_devices, "i") + b = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, P("i"))) + with self.assertRaisesRegex( ValueError, - "Function .* was exported for 1 devices and is called in a " - f"context with {jax.local_device_count()} devices.* function contains " - "non-replicated sharding annotations"): + "Function .* was exported for 2 devices and is called in a " + f"context with {jax.local_device_count()} devices"): exp.call(b) def test_call_with_different_no_of_devices_pmap(self): @@ -1265,7 +1397,7 @@ def f_jax(x): res_exported = jax.pmap(exp.call)(b) self.assertAllClose(res_native, res_exported[0]) - def test_call_with_different_no_of_devices_error_has_sharding_constraint(self): + def test_call_with_different_no_of_devices_sharding_constraint_success(self): if jax.device_count() < 2: self.skipTest("Need at least 2 devices") @@ -1278,6 +1410,7 @@ def f_with_sharding(x): a = jnp.arange(jax.device_count() * 10, dtype=np.float32).reshape( (jax.device_count(), 10) ) + res_native = f_with_sharding(a) exp = get_exported(f_with_sharding)(a) self.assertEqual(exp.nr_devices, 1) @@ -1285,13 +1418,82 @@ def f_with_sharding(x): run_mesh = Mesh(run_devices, "i") b = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, P("i"))) + res_exported = exp.call(b) + self.assertAllClose(res_native, res_exported) + + def test_call_with_different_no_of_devices_sharding_constraint_error(self): + if jax.device_count() < 3: + self.skipTest("Need at least 3 devices") + + # We export for 2 devices, but call with >=3 devices. + mesh_1 = Mesh(jax.local_devices()[:2], "i") + @jax.jit + def f_with_sharding(x): + x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh_1, P("i"))) + return jnp.sum(x ** 2, axis=0) + + a = jnp.arange(jax.device_count() * 10, dtype=np.float32).reshape( + (jax.device_count(), 10) + ) + exp = get_exported(f_with_sharding)(a) + self.assertEqual(exp.nr_devices, 2) + + run_devices = jax.local_devices() + run_mesh = Mesh(run_devices, "i") + b = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, P("i"))) + with self.assertRaisesRegex( ValueError, - "Function .* was exported for 1 devices and is called in a " - f"context with {jax.local_device_count()} devices.* function contains " - "non-replicated sharding annotations"): + "Function .* was exported for 2 devices and is called in a " + f"context with {jax.local_device_count()} devices"): exp.call(b) + def test_memory_space_from_arg(self): + shd = jax.sharding.SingleDeviceSharding( + jax.devices()[0], memory_kind="pinned_host") + a = jax.device_put(np.ones((2, 3), dtype=np.float32), shd) + f = jax.jit(lambda x: x) + + exported = get_exported(f, platforms=("tpu", "cuda", "rocm"))(a) + self.assertEqual(exported.in_avals[0].memory_space, core.MemorySpace.Host) + self.assertEqual(exported.out_avals[0].memory_space, core.MemorySpace.Host) + + empty_mesh = jax.sharding.AbstractMesh((), ()) + shd_ns = jax.sharding.NamedSharding(empty_mesh, P(None, None), + memory_kind="pinned_host") + + self.assertEqual(exported.in_avals[0].sharding, + jax.sharding.NamedSharding(empty_mesh, P(None, None))) + self.assertEqual(exported.in_shardings_jax(empty_mesh)[0], + jax.sharding.NamedSharding(empty_mesh, P(None, None), + memory_kind="pinned_host")) + # TODO(necula): a situation when the out_shardings is not the SOT, because + # they are unspecified, so the memory space can only be in aval. If we + # want the shardings to be SOT, then maybe the unspecified shardings should + # contain a memory_kind. + self.assertEqual(exported.out_shardings_jax(empty_mesh)[0], + None) # sharding.UnspecifiedValue + if jtu.device_under_test() in ("tpu", "gpu"): + b = exported.call(a) + self.assertEqual(b.sharding, a.sharding) + + def test_memory_space_from_out_shardings(self): + shd = jax.sharding.SingleDeviceSharding(jax.devices()[0], + memory_kind="pinned_host") + f = jax.jit(lambda: jnp.ones((2, 2), dtype=np.float32), + out_shardings=shd) + + exported = get_exported(f, platforms=("tpu", "cuda", "rocm"))() + self.assertEqual(exported.out_avals[0].memory_space, core.MemorySpace.Host) + empty_mesh = jax.sharding.AbstractMesh((), ()) + shd_ns = jax.sharding.NamedSharding(empty_mesh, P(None, None), + memory_kind="pinned_host") + self.assertEqual(exported.out_shardings_jax(empty_mesh)[0], shd_ns) + # TODO(necula): this test should work on TPU also + if jtu.device_under_test() == "gpu": + b = exported.call() + self.assertEqual(b.sharding, shd) + @jtu.parameterized_filterable( kwargs=[ dict(testcase_name=f"_poly={poly}", poly=poly) @@ -1305,14 +1507,14 @@ def test_shard_map_collective_permute(self, poly=None): a = np.arange(4 * 4, dtype=np.float32).reshape((4, 4)) @functools.partial( - pjit.pjit, + jax.jit, in_shardings=NamedSharding(mesh, P("x", None),), out_shardings=NamedSharding(mesh, P("x", None))) @functools.partial( shard_map, mesh=mesh, in_specs=(P("x", None),), out_specs=P("x", None)) def f_jax(b): # b: f32[2, 4] - axis_size = lax.psum(1, "x") + axis_size = lax.axis_size("x") perm = [(j, (j + 1) % axis_size) for j in range(axis_size)] return lax.ppermute(b, "x", perm=perm) @@ -1349,16 +1551,105 @@ def f_jax(b): # b: f32[2, 4] self.assertAllClose(res_jax.addressable_shards[i].data, res_r.addressable_shards[i].data) + @jtu.with_explicit_mesh((2,), 'x') + def test_unreduced_einsum_basic(self, mesh): + np_inp = np.arange(4).reshape(2, 2) + x = jax.device_put(np_inp, P(None, 'x')) + y = jax.device_put(np_inp, P('x', None)) + + @jax.jit + def f(x, y): + out = jnp.einsum('ab,bc->ac', x, y, + out_sharding=P(None, None, unreduced={'x'})) + self.assertEqual(out.aval.sharding.spec, P(None, None, unreduced={'x'})) + return out + + exported = get_exported(f)(x, y) + out = exported.call(x, y) + self.assertEqual(out.sharding, + NamedSharding(mesh, P(None, None, unreduced={'x'}))) + self.assertEqual(out.shape, (2, 2)) + self.assertEqual(out.sharding.shard_shape(out.shape), (2, 2)) + + expected_shards = [np.array([[0, 0], [0, 2]]), np.array([[2, 3], [6, 9]])] + for s, es in zip(out.addressable_shards, expected_shards): + self.assertEqual(s.data.shape, (2, 2)) + self.assertArraysEqual(s.data, es) + + reshard_out = reshard(out, P(None, None)) + self.assertArraysEqual(reshard_out, np_inp @ np_inp) + @jtu.parameterized_filterable( kwargs=[ - dict(in_shardings=in_shardings, out_shardings=out_shardings, - with_mesh_context=with_mesh_context) + dict(testcase_name=name, spec1=spec1, spec2=spec2, + out_spec=out_spec, collective_name=collective_name) + for name, spec1, spec2, out_spec, collective_name in [ + ("x_y", P("x", None), P(None, "y"), P("x", "y"), None), + ("x_None", P("x", None), P(None, None), P("x", None), None), + ("contracting2", P("x", "y"), P(None, None), P("x", None), "all-gather"), + ("fsdp", P("x", None), P("x", None), P("x", None), "all-gather"), + ("half_tp", P(None, "y"), P(None, "y"), P(None, "y"), "all-gather") + ] + ]) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_explicit_sharding_dot_general(self, spec1, spec2, out_spec, + collective_name, mesh): + # Based on pjit_test::test_dot_general + def check_wsc_in_lowered(text): + assert config.use_shardy_partitioner.value # TODO(necula): is shardy always on? + if config.use_shardy_partitioner.value: + self.assertIn('sdy.sharding_constraint', text) + else: + self.assertIn('@Sharding', text) + + np_inp1 = np.arange(16.).reshape(8, 2) + arr1 = jax.device_put(np_inp1, NamedSharding(mesh, spec1)) + arr2 = jax.device_put(np_inp1.T, NamedSharding(mesh, spec2)) + + def f(x, y): + out = x @ y + self.assertEqual(out.aval.sharding.spec, out_spec) + return out + + exported = get_exported(jax.jit(f))(arr1, arr2) + amesh = jax.sharding.AbstractMesh(mesh.axis_sizes, mesh.axis_names, + axis_types=mesh.axis_types) + self.assertEqual(exported.in_avals[0].sharding, + NamedSharding(amesh, spec1)) + self.assertEqual(exported.in_avals[1].sharding, + NamedSharding(amesh, spec2)) + self.assertEqual(exported.in_shardings_jax(amesh)[0], + NamedSharding(amesh, spec1, memory_kind="device")) + out = exported.call(arr1, arr2) + self.assertArraysEqual(out, np_inp1 @ np_inp1.T) + self.assertEqual(out.sharding, NamedSharding(mesh, out_spec)) + + lowered = jax.jit(exported.call).lower(arr1, arr2) + check_wsc_in_lowered(lowered.as_text()) + + compiled_text = lowered.compile().as_text() + if collective_name is not None: + self.assertIn(collective_name, compiled_text) + + @jax.jit + def g(x, y): + out = f(x, y) + return jnp.sum(out) + + g = jax.jit(jax.grad(g, argnums=(0, 1))) + exported_g = get_exported(g)(arr1, arr2) + out = exported_g.call(arr1, arr2) + out_g = jax.jit(g)(arr1, arr2) + self.assertEqual(out_g[0].sharding, arr1.sharding) + self.assertEqual(out_g[1].sharding, arr2.sharding) + + @jtu.parameterized_filterable( + kwargs=[ + dict(in_shardings=in_shardings, out_shardings=out_shardings) for in_shardings in ("missing", None, "P") for out_shardings in ("missing", None, "P") - for with_mesh_context in (True, False) ]) - def test_grad_with_sharding(self, in_shardings="P", out_shardings=None, - with_mesh_context=False): + def test_grad_with_sharding(self, in_shardings="P", out_shardings=None): if len(jax.devices()) < 2: self.skipTest("Test requires at least 2 devices") x_shape = (10, 20) @@ -1369,28 +1660,21 @@ def f_jax(x): # x: f32[10,20] -> f32[20,10] return jnp.sin(x.T) mesh = Mesh(jax.devices()[:2], "d") - pjit_kwargs = {} - # Use NamedShardings if we don't have a mesh_context - if with_mesh_context: - sharding_None_d = P(None, "d") - sharding_d_None = P("d", None) - else: - sharding_None_d = NamedSharding(mesh, P(None, "d")) - sharding_d_None = NamedSharding(mesh, P("d", None)) + jit_kwargs = {} + sharding_None_d = NamedSharding(mesh, P(None, "d")) + sharding_d_None = NamedSharding(mesh, P("d", None)) if in_shardings != "missing": - pjit_kwargs["in_shardings"] = ( + jit_kwargs["in_shardings"] = ( sharding_None_d if in_shardings == "P" else None) if out_shardings != "missing": - pjit_kwargs["out_shardings"] = ( + jit_kwargs["out_shardings"] = ( sharding_d_None if out_shardings == "P" else None) - f_jax_pjit = pjit.pjit(f_jax, **pjit_kwargs) + f_jax_jit = jax.jit(f_jax, **jit_kwargs) with contextlib.ExitStack() as stack: - if with_mesh_context: - stack.enter_context(mesh) # Serialize higher-order gradiends - exp = get_exported(f_jax_pjit, vjp_order=2)(x) + exp = get_exported(f_jax_jit, vjp_order=2)(x) exp_vjp = exp.vjp() # Try 2nd order grad as well exp_vjp2 = exp_vjp.vjp() @@ -1405,30 +1689,40 @@ def f_jax(x): # x: f32[10,20] -> f32[20,10] r"\) -> \(tensor<10x20xf32> (.*)", # the result vjp_module_str).groups() + if config.use_shardy_partitioner.value: + attr_name = "sdy.sharding" + else: + attr_name = "mhlo.sharding" + if in_shardings == "P": - self.assertRegex(arg0_attrs, re.escape("{devices=[1,2]<=[2]}")) - self.assertRegex(res_attrs, re.escape("{devices=[1,2]<=[2]}")) - primal_in_sharding = "{devices=[1,2]<=[2]}" + if config.use_shardy_partitioner.value: + sharding = r'#sdy.sharding<@mesh, \[{}, {"d"}\]>' + primal_in_sharding = '#sdy.sharding<@mesh, [{}, {"d"}]>' + else: + sharding = re.escape("{devices=[1,2]<=[2]}") + primal_in_sharding = "{devices=[1,2]<=[2]}" + self.assertRegex(arg0_attrs, sharding) + self.assertRegex(res_attrs, sharding) else: primal_in_sharding = "{replicated}" - if with_mesh_context: - self.assertRegex(arg0_attrs, re.escape("replicated")) - self.assertRegex(res_attrs, re.escape("replicated")) - else: - # If there is no mesh context, we have used NamedSharding(None) - # and then the sharding is unspecified! - self.assertNotIn("mhlo.sharding", arg0_attrs) - self.assertNotIn("mhlo.sharding", res_attrs) + self.assertNotIn(attr_name, arg0_attrs) + self.assertNotIn(attr_name, res_attrs) if out_shardings == "P": - self.assertRegex(arg1_attrs, re.escape("{devices=[2,1]<=[2]}")) - primal_out_sharding = "{devices=[2,1]<=[2]}" + if config.use_shardy_partitioner.value: + self.assertRegex(arg1_attrs, + re.escape('#sdy.sharding<@mesh, [{"d"}, {}]>')) + primal_out_sharding = '#sdy.sharding<@mesh, [{"d"}, {}]>' + else: + self.assertRegex(arg1_attrs, re.escape("{devices=[2,1]<=[2]}")) + primal_out_sharding = "{devices=[2,1]<=[2]}" else: - primal_out_sharding = "{replicated}" - if with_mesh_context: - self.assertRegex(arg1_attrs, re.escape("replicated")) + if config.use_shardy_partitioner.value: + primal_out_sharding = '#sdy.sharding<@mesh, [{}, {}]>' else: - self.assertNotIn("mhlo.sharding", arg1_attrs) + primal_out_sharding = "{replicated}" + + self.assertNotIn(attr_name, arg1_attrs) # Sharding custom calls for the primal input shape all match primal_in_sharding primal_in_sharding_calls = re.findall( @@ -1452,7 +1746,7 @@ def f_jax(x): # x: f32[10,20] -> f32[20,10] # we replicate the inputs. If we don't use a mesh context and there are # no shardings on inputs or outputs, then we have serialized for one # device. - if in_shardings != "P" and out_shardings != "P" and not with_mesh_context: + if in_shardings != "P" and out_shardings != "P": self.assertEqual(exp_vjp.nr_devices, 1) self.assertEqual(exp_vjp2.nr_devices, 1) call_mesh = Mesh(jax.devices()[:1], "e") @@ -1461,17 +1755,17 @@ def f_jax(x): # x: f32[10,20] -> f32[20,10] self.assertEqual(exp_vjp2.nr_devices, 2) call_mesh = Mesh(jax.devices()[:2], "e") - g1 = pjit.pjit(exp_vjp.call, - in_shardings=(NamedSharding(call_mesh, P()), - NamedSharding(call_mesh, P())))(x, x.T) + g1 = jax.jit(exp_vjp.call, + in_shardings=(NamedSharding(call_mesh, P()), + NamedSharding(call_mesh, P())))(x, x.T) _, f_jax_vjp = jax.vjp(f_jax, x) xbar = f_jax_vjp(x.T) self.assertAllClose(xbar, g1) - g2 = pjit.pjit(exp_vjp2.call, - in_shardings=(NamedSharding(call_mesh, P()), - NamedSharding(call_mesh, P()), - NamedSharding(call_mesh, P())))(x, x.T, x) + g2 = jax.jit(exp_vjp2.call, + in_shardings=(NamedSharding(call_mesh, P()), + NamedSharding(call_mesh, P()), + NamedSharding(call_mesh, P())))(x, x.T, x) _, f_jax_vjp2 = jax.vjp(f_jax_vjp, x.T) xbar2, = f_jax_vjp2((x,)) self.assertAllClose(xbar2, g2[1]) @@ -1489,10 +1783,10 @@ def f(x): shardings_rev = NamedSharding(mesh_rev, jax.sharding.PartitionSpec(("i",))) input_no_shards = jnp.ones(shape=(jax.local_device_count(),)) input = jnp.ones(shape=(jax.local_device_count(),), device=shardings) - input_rev = jax.device_put(input_no_shards, device=shardings_rev) + input_rev = jnp.ones(shape=(jax.local_device_count(),), device=shardings_rev) - exp = export.export(pjit.pjit(f, in_shardings=shardings))(input) - exp_rev = export.export(pjit.pjit(f, in_shardings=shardings_rev))(input_no_shards) + exp = export.export(jax.jit(f, in_shardings=shardings))(input) + exp_rev = export.export(jax.jit(f, in_shardings=shardings_rev))(input_no_shards) if CAN_SERIALIZE: _ = exp.serialize(vjp_order=1) @@ -1517,7 +1811,7 @@ def test_multi_platform(self): self.assertIn("jax.uses_shape_polymorphism = true", module_str) - # Call with argument placed on different plaforms + # Call with argument placed on different platforms for platform in self.platforms: x_device = jax.device_put(x, jax.devices(platform)[0]) res_exp = exp.call(x_device) @@ -1542,7 +1836,7 @@ def test_multi_platform_nested(self): count_sine = len(re.findall("stablehlo.sine", exp2_module_str)) self.assertEqual(1, count_sine) - # Call with argument placed on different plaforms + # Call with argument placed on different platforms for platform in self.platforms: if platform == "tpu": continue x_device = jax.device_put(x, jax.devices(platform)[0]) @@ -1685,7 +1979,7 @@ def f_jax(b): # b: f32[16 // DEVICES, 4] res_native = f_jax(a) exp = get_exported(f_jax, platforms=("cpu", "tpu", "cuda", "rocm"))(a) - # Call with argument placed on different plaforms + # Call with argument placed on different platforms for platform in self.platforms: run_devices = jax.devices(platform)[0:len(export_devices)] if len(run_devices) != len(export_devices): @@ -1695,6 +1989,22 @@ def f_jax(b): # b: f32[16 // DEVICES, 4] res_exp = exp.call(a_device) self.assertArraysAllClose(res_native, res_exp) + def test_compute_on_host(self): + operand = np.float32(0.) + + @jax.jit + @compute_on.compute_on("device_host") + def f_host(x): + # Adds 1 on CPU, which should be the result on all platforms because + # this code should always run on the host. + return jax.lax.platform_dependent(x, + cpu=lambda x: x + np.float32(1.), + default=lambda x: x + np.float32(2.)) + + self.assertAllClose(np.float32(1.), f_host(operand)) + exp = get_exported(f_host, platforms=("cpu", "tpu", "cuda", "rocm"))(operand) + self.assertAllClose(np.float32(1.), exp.call(operand)) + @jtu.parameterized_filterable( kwargs=[ dict(v=v) @@ -1901,10 +2211,14 @@ def f_jax(x): with self.assertRaisesRegex(Exception, expect_error): _ = get_exported(jax.jit(f_jax))(jax.ShapeDtypeStruct((3, 4), x.dtype)) + # On tpu ragged_dot does not zero out the output, hence in the case `sum + # (group_sizes) < m`, the rows that are not routed to any experts remain + # uninitialized. For that reason `group_sizes` sum up to `m` for all test + # cases below. @jtu.parameterized_filterable( kwargs=[ - {"m": 5, "k": 4, "n": 3, "group_sizes": [5]}, - {"m": 10, "k": 9, "n": 8, "group_sizes": [3, 7]}, + {"m": 64, "k": 4, "n": 3, "group_sizes": [64]}, + {"m": 64, "k": 9, "n": 8, "group_sizes": [30, 34]}, ]) def test_ragged_dot(self, m, k, n, group_sizes): def f_jax(x, y, gs): @@ -1961,6 +2275,69 @@ def f(x, y): r = jax.jit(exp.call, out_shardings=NamedSharding(old_mesh_0, P("old_b")))(a, b) self.assertAllClose(a + b, r) + def test_lower_and_load_with_different_meshes_axis_names(self): + self.skipTest("TODO(necula): different axis names are not supported yet") + mesh1 = jtu.create_mesh((8,), ("a",)) + mesh2 = jtu.create_mesh((8,), ("b",)) + mesh3 = jtu.create_mesh((8,), ("c",)) + mesh4 = jtu.create_mesh((8,), ("d",)) + + a = jnp.arange(32 * 32, dtype=np.float32).reshape((32, 32)) + b = jnp.arange(32 * 32, dtype=np.float32).reshape((32, 32)) + + @jax.jit + def f(x, y): + return x + lax.with_sharding_constraint(y, NamedSharding(mesh2, P(None, "b"))) + + a_put = jax.device_put(a, NamedSharding(mesh1, P(None, "a"))) + b_put = jax.device_put(b, NamedSharding(mesh2, P(None, "b"))) + exp = get_exported(f)(a_put, b_put) + + jax.jit(exp.call, in_shardings=(NamedSharding(mesh3, P("c")), + NamedSharding(mesh4, P("d"))))(a, b) + + @jtu.parameterized_filterable( + kwargs=[ + {"use_shardy_on_save": True, "error_msg": "Please enable Shardy", + "poly_shape": False}, + {"use_shardy_on_save": False, "error_msg": "", "poly_shape": False}, + {"use_shardy_on_save": False, "error_msg": "", "poly_shape": True}, + ]) + def test_lower_load_with_different_partitioners(self, use_shardy_on_save, + error_msg, poly_shape): + with config.use_shardy_partitioner(use_shardy_on_save): + mesh = jtu.create_mesh((8,), ("a",)) + @jax.jit + def f(x, y): + z = x + y + return jax.lax.with_sharding_constraint( + z, NamedSharding(mesh, P("a"))) + + args = ( + jax.ShapeDtypeStruct( + (32, 32), dtype=np.float32, + sharding=NamedSharding(mesh, P(None, "a"))), + jax.ShapeDtypeStruct( + (32, 32), dtype=np.float32, + sharding=NamedSharding(mesh, P("a")))) + + if poly_shape: + args = export.symbolic_args_specs(args, shapes_specs=["32, a", "32, a"]) + + exp = get_exported(f)(*args) + + with config.use_shardy_partitioner(not use_shardy_on_save): + a = jnp.arange(32 * 32, dtype=np.float32).reshape((32, 32)) + a = jax.device_put(a, NamedSharding(mesh, P(None, "a"))) + b = jnp.arange(32 * 32, dtype=np.float32).reshape((32, 32)) + b = jax.device_put(b, NamedSharding(mesh, P("a"))) + + if use_shardy_on_save: + with self.assertRaisesRegex(ValueError, error_msg): + jax.jit(exp.call, out_shardings=NamedSharding(mesh, P("a")))(a, b) + else: + jax.jit(exp.call, out_shardings=NamedSharding(mesh, P("a")))(a, b) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/extend_test.py b/tests/extend_test.py index fcf9d3b54c6d..d7572dc2b1fe 100644 --- a/tests/extend_test.py +++ b/tests/extend_test.py @@ -20,11 +20,13 @@ from jax._src import abstract_arrays from jax._src import api +from jax._src import core from jax._src import linear_util from jax._src import prng from jax._src import test_util as jtu from jax._src import xla_bridge from jax._src.interpreters import mlir +from jax._src.lib import xla_client jax.config.parse_flags_with_absl() @@ -47,6 +49,8 @@ def test_symbols(self): self.assertIs(jex.backend.get_backend, xla_bridge.get_backend) self.assertIs(jex.backend.register_backend_factory, xla_bridge.register_backend_factory) self.assertIs(jex.core.array_types, abstract_arrays.array_types) + self.assertIs(jex.core.mapped_aval, core.mapped_aval) + self.assertIs(jex.core.unmapped_aval, core.unmapped_aval) self.assertIs(jex.linear_util.StoreException, linear_util.StoreException) self.assertIs(jex.linear_util.WrappedFun, linear_util.WrappedFun) self.assertIs(jex.linear_util.cache, linear_util.cache) @@ -105,5 +109,20 @@ def test_unknown_platform_error(self): mlir.register_lowering(prim=None, rule=None, platform="foo") +class ShardingTest(jtu.JaxTestCase): + def test_hlo_sharding_roundtrip(self): + proto = xla_client.OpSharding() + hlo_sharding = xla_client.HloSharding.from_proto(proto) + serialized_proto = jex.sharding.get_serialized_proto_from_hlo_sharding( + hlo_sharding + ) + self.assertIsInstance(serialized_proto, bytes) + deserialized_hlo_sharding = jex.sharding.get_hlo_sharding_from_serialized_proto( + serialized_proto + ) + self.assertIsInstance(deserialized_hlo_sharding, xla_client.HloSharding) + self.assertEqual(hlo_sharding, deserialized_hlo_sharding) + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/ffi_test.py b/tests/ffi_test.py index 46aaefa8f521..035d355b3149 100644 --- a/tests/ffi_test.py +++ b/tests/ffi_test.py @@ -22,20 +22,19 @@ import jax from jax import lax -import jax.extend as jex import jax.numpy as jnp from jax.sharding import PartitionSpec as P -from jax._src import config from jax._src import core from jax._src import dispatch +from jax._src import dtypes from jax._src import test_util as jtu from jax._src.interpreters import mlir -from jax._src.layout import DeviceLocalLayout +from jax._src.layout import Layout from jax._src.lib import lapack from jax._src.lib.mlir.dialects import hlo from jax._src.lax import linalg as lax_linalg_internal -from jax.experimental.shard_map import shard_map +from jax._src.shard_map import shard_map jax.config.parse_flags_with_absl() jtu.request_cpu_devices(8) @@ -59,7 +58,7 @@ def test_headers_exist(self): @parameterized.parameters([ (tuple(range(3)), tuple(range(3))), (None, tuple(reversed(range(3)))), - (DeviceLocalLayout(tuple(range(3))), tuple(reversed(range(3)))), + (Layout(tuple(range(3))), tuple(reversed(range(3)))), ]) def test_lowering_layouts(self, layout_spec, expected_layout): # Regression test to ensure that the lowering rule properly captures @@ -86,15 +85,56 @@ def lowering_rule(ctx, x): pattern = rf"result_layouts = \[dense<\[{expected}\]>" self.assertRegex(text, pattern) - @parameterized.parameters([ - (True, mlir.ir.BoolAttr.get), - (1, mlir.i64_attr), - (5.0, lambda x: mlir.ir.FloatAttr.get(mlir.ir.F64Type.get(), x)), - ("param", mlir.ir.StringAttr.get), - (np.float32(0.5), - lambda x: mlir.ir.FloatAttr.get(mlir.ir.F32Type.get(), x)), - ]) - def test_params(self, param, expected_builder): + # Concise helpers to every test instance below in one line. + _arr = lambda value, dtype=None: np.array(value, dtype=dtype) + _ftens1 = lambda et: f"dense<1.000000e+00> : tensor<{et}>" + _itens1 = lambda et: f"dense<1> : tensor<{et}>" + + @parameterized.parameters( + (_arr(1, dtypes.int2), _itens1("i2")), + (_arr(1, dtypes.int4), _itens1("i4")), + (_arr(1, dtypes.uint2), _itens1("ui2")), + (_arr(1, dtypes.uint4), _itens1("ui4")), + (_arr(1, np.int16), _itens1("i16")), + (_arr(1, np.int32), _itens1("i32")), + (_arr(1, np.int64), _itens1("i64")), + (_arr(1, np.int8), _itens1("i8")), + (_arr(1, np.uint16), _itens1("ui16")), + (_arr(1, np.uint32), _itens1("ui32")), + (_arr(1, np.uint64), _itens1("ui64")), + (_arr(1, np.uint8), _itens1("ui8")), + (_arr(1.0, dtypes.bfloat16), _ftens1("bf16")), + (_arr(1.0, dtypes.float4_e2m1fn), _ftens1("f4E2M1FN")), + (_arr(1.0, dtypes.float8_e3m4), _ftens1("f8E3M4")), + (_arr(1.0, dtypes.float8_e4m3), _ftens1("f8E4M3")), + (_arr(1.0, dtypes.float8_e4m3b11fnuz), _ftens1("f8E4M3B11FNUZ")), + (_arr(1.0, dtypes.float8_e4m3fn), _ftens1("f8E4M3FN")), + (_arr(1.0, dtypes.float8_e4m3fnuz), _ftens1("f8E4M3FNUZ")), + (_arr(1.0, dtypes.float8_e5m2), _ftens1("f8E5M2")), + (_arr(1.0, dtypes.float8_e5m2fnuz), _ftens1("f8E5M2FNUZ")), + (_arr(1.0, dtypes.float8_e8m0fnu), _ftens1("f8E8M0FNU")), + (_arr(1.0, np.bool), "dense : tensor"), + (_arr(1.0, np.float16), _ftens1("f16")), + (_arr(1.0, np.float32), _ftens1("f32")), + (_arr(1.0, np.float64), _ftens1("f64")), + (dtypes.bfloat16(1.0), "1.000000e+00 : bf16"), + (np.bool(False), "false"), + (np.bool(True), "true"), + (np.float16(1.0), "1.000000e+00 : f16"), + (np.float32(1.0), "1.000000e+00 : f32"), + (np.float64(1.0), "1.000000e+00 : f64"), + (np.int16(1), "1 : i16"), + (np.int32(1), "1 : i32"), + (np.int64(1), "1 : i64"), + (np.int8(1), "1 : i8"), + (np.uint16(1), "1 : ui16"), + (np.uint32(1), "1 : ui32"), + (np.uint64(1), "1 : ui64"), + (np.uint8(1), "1 : ui8"), + (np.zeros((), dtype=dtypes.float0), "dense : tensor"), + ("param", '"param"'), + ) + def test_params(self, param, expected_str): def fun(x): return jax.ffi.ffi_call("test_ffi", x)(x, param=param) @@ -102,13 +142,10 @@ def fun(x): # serialized with the appropriate type. module = jax.jit(fun).lower(0.5).compiler_ir("stablehlo") op = self.find_custom_call_in_module(module) - config = op.attributes["mhlo.backend_config"] - self.assertIsInstance(config, mlir.ir.DictAttr) - self.assertIn("param", config) - with mlir.make_ir_context(), mlir.ir.Location.unknown(): - expected = expected_builder(param) - self.assertEqual(type(config["param"]), type(expected)) - self.assertTrue(expected.type.isinstance(config["param"].type)) + conf = op.attributes["mhlo.backend_config"] + self.assertIsInstance(conf, mlir.ir.DictAttr) + self.assertIn("param", conf) + self.assertEqual(str(conf["param"]), expected_str) def test_token(self): def fun(): @@ -151,7 +188,7 @@ def test_non_hashable_attributes(self): def fun(x): return jax.ffi.ffi_call("test_ffi", x)(x, non_hashable_arg={"a": 1}) - self.assertIn("HashableDict", str(jax.make_jaxpr(fun)(jnp.ones(5)))) + self.assertIn("FrozenDict", str(jax.make_jaxpr(fun)(jnp.ones(5)))) hlo = jax.jit(fun).lower(jnp.ones(5)).as_text() self.assertIn("non_hashable_arg = {a = 1", hlo) @@ -200,21 +237,6 @@ def test_ffi_call_batching(self, shape, vmap_method): else: self.assertArraysEqual(a, b) - @jtu.run_on_devices("gpu", "cpu") - def test_vectorized_deprecation(self): - x = self.rng().randn(3, 5, 4).astype(np.float32) - with self.assertWarns(DeprecationWarning): - ffi_call_geqrf(x, vectorized=True) - with self.assertWarns(DeprecationWarning): - jax.vmap(ffi_call_geqrf)(x) - - def test_backward_compat_syntax(self): - def fun(x): - return jax.ffi.ffi_call("test_ffi", x, x, param=0.5) - msg = "Calling ffi_call directly with input arguments is deprecated" - with self.assertDeprecationWarnsOrRaises("jax-ffi-call-args", msg): - jax.jit(fun).lower(jnp.ones(5)) - def test_input_output_aliases(self): def fun(x): return jax.ffi.ffi_call("test", x, input_output_aliases={0: 0})(x) @@ -269,8 +291,6 @@ def fun(x): jax.jit(fun).lower(jnp.ones(5)).as_text() def test_allow_x64(self): - if config.enable_x64.value: - self.skipTest("Requires enable_x64=False") def fun(): return jax.ffi.ffi_call("test", jax.ShapeDtypeStruct((), np.int64))() self.assertIn("tensor", jax.jit(fun).lower().as_text()) @@ -296,13 +316,13 @@ def f(x): jax.jit(f)(x) # neither does JIT self.assertNotIn("all-gather", jax.jit(f).lower(x).compile().as_text()) - @jtu.run_on_devices("gpu", "cpu") - @jtu.ignore_warning(category=DeprecationWarning) - def test_extend_import_shim(self): - ffi_call_geqrf(jnp.ones((4, 5), dtype=np.float32), _use_extend=True) + def test_extended_dtype_lowering(self): + def f(x): + return jax.ffi.ffi_call("edtype", (), has_side_effect=True)(x) + jax.jit(f).lower(jax.random.key(0)) # doesn't crash -def ffi_call_geqrf(x, _use_extend=False, **kwargs): +def ffi_call_geqrf(x, **kwargs): if jtu.test_device_matches(["cpu"]): lapack._lapack.initialize() @@ -318,8 +338,7 @@ def call(platform, x): rocm="hipsolver_geqrf_ffi", cuda="cusolver_geqrf_ffi", )[platform] - f = jex.ffi.ffi_call if _use_extend else jax.ffi.ffi_call - return f( + return jax.ffi.ffi_call( target_name, output_types, input_output_aliases={0: 0}, input_layouts=[x_major_to_minor], output_layouts=[x_major_to_minor, None], @@ -349,7 +368,7 @@ def test_shard_map(self): x = self.rng().randn(8, 4, 5).astype(np.float32) @partial(shard_map, mesh=mesh, in_specs=P("i"), out_specs=P("i"), - check_rep=False) + check_vma=False) def f(x): return batch_partitionable_ffi_call(x) diff --git a/tests/fft_test.py b/tests/fft_test.py index 26b69f5beca8..230aa54a8b37 100644 --- a/tests/fft_test.py +++ b/tests/fft_test.py @@ -336,7 +336,7 @@ def testFft2Errors(self, inverse, real): ValueError, lambda: func(rng([2, 3], dtype=np.float64), axes=[-3, -4])) @jtu.sample_product( - dtype=all_dtypes, + dtype=jtu.dtypes.floating + jtu.dtypes.complex, size=[9, 10, 101, 102], d=[0.1, 2.], device=[None, -1], @@ -344,15 +344,12 @@ def testFft2Errors(self, inverse, real): def testFftfreq(self, size, d, dtype, device): rng = jtu.rand_default(self.rng()) args_maker = lambda: (rng([size], dtype),) - jnp_op = jnp.fft.fftfreq - np_op = np.fft.fftfreq if device is not None: device = jax.devices()[device] - jnp_fn = lambda a: jnp_op(size, d=d, device=device) - np_fn = lambda a: np_op(size, d=d) + jnp_fn = lambda a: jnp.fft.fftfreq(size, d=d, device=device, dtype=dtype) + np_fn = lambda a: np.fft.fftfreq(size, d=d).astype(dtype) # Numpy promotes to complex128 aggressively. - self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False, - tol=1e-4) + self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=1e-4) self._CompileAndCheck(jnp_fn, args_maker) # Test gradient for differentiable types. if dtype in inexact_dtypes: diff --git a/tests/filecheck/custom_call.filecheck.py b/tests/filecheck/custom_call.filecheck.py index c6af4235ebb4..27cc904e59d8 100644 --- a/tests/filecheck/custom_call.filecheck.py +++ b/tests/filecheck/custom_call.filecheck.py @@ -19,7 +19,7 @@ from absl import app import jax -from jax.interpreters import mlir +from jax._src.interpreters import mlir from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import func as func_dialect import numpy as np diff --git a/tests/filecheck/jax_mlir_ext.filecheck.py b/tests/filecheck/jax_mlir_ext.filecheck.py new file mode 100644 index 000000000000..dd0b8be2f6fd --- /dev/null +++ b/tests/filecheck/jax_mlir_ext.filecheck.py @@ -0,0 +1,139 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# RUN: %PYTHON %s | FileCheck %s -dump-input=always + +from absl import app +import jax +from jax._src.interpreters import mlir +from jax._src.lib.mlir import ir +from jax._src.lib.mlir import passmanager +from jax._src.lib.mlir.dialects import func as func_dialect +from jax._src.lib.mlir.dialects import hlo +from jaxlib.mlir._mlir_libs import _jax_mlir_ext +import numpy as np + + +jax.config.parse_flags_with_absl() + + +def test_inlined_func_call(): + # CHECK: #loc = loc(unknown) + # CHECK: module { + # CHECK-NEXT: func.func public @caller(%arg0: tensor<2x3xf32> loc(unknown), %arg1: tensor<2x3xf32> loc(unknown)) -> (tensor<2x3xf32>, tensor<2x3xf32>) { + # CHECK-NEXT: %0 = stablehlo.add %arg1, %arg0 : tensor<2x3xf32> loc(#loc5) + # CHECK-NEXT: %1 = stablehlo.multiply %0, %arg1 : tensor<2x3xf32> loc(#loc6) + # CHECK-NEXT: return %0, %1 : tensor<2x3xf32>, tensor<2x3xf32> loc(#loc) + # CHECK-NEXT: } loc(#loc) + # CHECK-NEXT: } loc(#loc) + # CHECK-NEXT: #loc1 = loc("caller_file":3:4) + # CHECK-NEXT: #loc2 = loc("caller_stack"(#loc1)) + # CHECK-NEXT: #loc3 = loc("caller_name/callee_name1"(#loc2)) + # CHECK-NEXT: #loc4 = loc("caller_name/callee_name2"(#loc2)) + # CHECK-NEXT: #loc5 = loc("caller_type:"(#loc3)) + # CHECK-NEXT: #loc6 = loc("caller_type:"(#loc4)) + ctx = mlir.make_ir_context() + loc = ir.Location.unknown(context=ctx) + aval = jax.core.ShapedArray((2, 3), np.dtype(np.float32)) + arg_avals = [aval, aval] + result_avals = [aval, aval] + with ctx, loc: + callee_stack_loc = ir.Location.name( + "callee_stack", ir.Location.file("callee_file", 1, 2) + ) + callee_loc1 = ir.Location.name( + "callee_name1", ir.Location.name("callee_type1:", callee_stack_loc) + ) + callee_loc2 = ir.Location.name( + "callee_name2", ir.Location.name("callee_type2:", callee_stack_loc) + ) + caller_stack_loc = ir.Location.name( + "caller_stack", ir.Location.file("caller_file", 3, 4) + ) + caller_loc = ir.Location.name( + "caller_name", ir.Location.name("caller_type:", caller_stack_loc) + ) + + module = ir.Module.create(loc=ir.Location.unknown()) + ip = ir.InsertionPoint(module.body) + arg_types = [mlir.aval_to_ir_type(aval) for aval in arg_avals] + result_types = [mlir.aval_to_ir_type(aval) for aval in result_avals] + ftype = ir.FunctionType.get(arg_types, result_types) + callee = func_dialect.FuncOp("callee", ftype, ip=ip) + callee.attributes["sym_visibility"] = ir.StringAttr.get("private") + entry_block = callee.add_entry_block() + with ir.InsertionPoint(entry_block): + with callee_loc1: + x = hlo.add(entry_block.arguments[0], entry_block.arguments[1]) + with callee_loc2: + y = hlo.multiply(x, entry_block.arguments[0]) + func_dialect.ReturnOp([x, y]) + + caller = func_dialect.FuncOp("caller", ftype, ip=ip) + caller.attributes["sym_visibility"] = ir.StringAttr.get("public") + entry_block = caller.add_entry_block() + with ir.InsertionPoint(entry_block): + x, y = entry_block.arguments + with caller_loc: + x, y = _jax_mlir_ext.inlined_func_call(callee, [y, x], entry_block) + func_dialect.ReturnOp([x, y]) + module.operation.verify() + pipeline = passmanager.PassManager.parse("builtin.module(symbol-dce)") + pipeline.run(module.operation) + print(module.operation.print(enable_debug_info=True)) + + +def test_traceback_to_location(): + def f(): + return g() + + def g(): + return h() + + def h(): + return jax._src.lib._jax.Traceback.get_traceback() + + tb = f() + + def _code_to_filename(code): + return ( + "jax_mlir_ext_test.filecheck.py" + if "jax_mlir_ext.filecheck" in code.co_filename + else None + ) + + # CHECK: --- test_traceback_to_location + print("--- test_traceback_to_location") + ctx = mlir.make_ir_context() + with ctx: + # CHECK: loc(callsite("test_traceback_to_location..h"("jax_mlir_ext_test.filecheck.py":{{[0-9]+}}:{{[0-9]+}} to :{{[0-9]+}}) at callsite("test_traceback_to_location..g"("jax_mlir_ext_test.filecheck.py":{{[0-9]+}}:{{[0-9]+}} to :{{[0-9]+}}) at callsite("test_traceback_to_location..f"("jax_mlir_ext_test.filecheck.py":{{[0-9]+}}:{{[0-9]+}} to :{{[0-9]+}}) at callsite("test_traceback_to_location"("jax_mlir_ext_test.filecheck.py":{{[0-9]+}}:{{[0-9]+}} to :{{[0-9]+}}) at callsite("main"("jax_mlir_ext_test.filecheck.py":{{[0-9]+}}:{{[0-9]+}} to :{{[0-9]+}}) at ""("jax_mlir_ext_test.filecheck.py":{{[0-9]+}}:{{[0-9]+}} to :{{[0-9]+}}))))))) + cache = _jax_mlir_ext.TracebackToLocationCache( + code_to_filename=_code_to_filename, frame_limit=1000) + loc = cache.get(tb) + print(loc) + + # CHECK: loc(callsite("test_traceback_to_location..h"("jax_mlir_ext_test.filecheck.py":{{[0-9]+}}:{{[0-9]+}} to :{{[0-9]+}}) at "test_traceback_to_location..g"("jax_mlir_ext_test.filecheck.py":{{[0-9]+}}:{{[0-9]+}} to :{{[0-9]+}}))) + cache = _jax_mlir_ext.TracebackToLocationCache( + code_to_filename=_code_to_filename, frame_limit=2) + loc = cache.get(tb) + print(loc) + + +def main(_): + test_inlined_func_call() + test_traceback_to_location() + + +if __name__ == "__main__": + app.run(main) diff --git a/tests/filecheck/subcomputations.filecheck.py b/tests/filecheck/subcomputations.filecheck.py index b3c3191ca416..77cbd16a8a74 100644 --- a/tests/filecheck/subcomputations.filecheck.py +++ b/tests/filecheck/subcomputations.filecheck.py @@ -30,7 +30,7 @@ def main(_): - # The lowering of cumsum is annotated with @cache_lowering, which means we + # The lowering of cumsum is annotated with inline=False, which means we # should lower it as an out-of-line function once for any given shape. # CHECK-LABEL: TEST: cumsum_only_once int32[2,7] int32[2,7] diff --git a/tests/for_loop_test.py b/tests/for_loop_test.py deleted file mode 100644 index 9e0ebd4ff922..000000000000 --- a/tests/for_loop_test.py +++ /dev/null @@ -1,409 +0,0 @@ -# Copyright 2022 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from functools import partial - -from absl.testing import absltest -from absl.testing import parameterized - -import numpy as np - -import jax -from jax import random -from jax._src import test_util as jtu -from jax._src.lax.control_flow import for_loop -import jax.numpy as jnp - -jax.config.parse_flags_with_absl() - -def remat_of_for_loop(nsteps, body, state, **kwargs): - return jax.remat(lambda state: for_loop.for_loop(nsteps, body, state, - **kwargs))(state) - -def nested_for_loop(nsteps, body, state, **kwargs): - def outer_body(_, refs): - def inner_body(i, _): - body(i, refs) - return - for_loop.for_loop(nsteps, inner_body, ()) - return for_loop.for_loop(1, outer_body, state) - -FOR_LOOP_IMPLS = [ - (for_loop.for_loop, 'for_loop'), - (jax.jit(for_loop.for_loop, static_argnums=(0, 1)), 'jit_for_loop'), - (remat_of_for_loop, 'remat_for_loop'), - (nested_for_loop, 'nested_for_loop'), - (partial(for_loop.for_loop, unroll=3), 'unrolled_for_loop'), -] - - -def _for_loop_impls(f): - return parameterized.named_parameters( - dict(testcase_name=impl_name, for_impl=for_impl) - for for_impl, impl_name in FOR_LOOP_IMPLS - )(f) - - -class ForLoopTest(jtu.JaxTestCase): - - @_for_loop_impls - def test_for_loop_impl_trivial(self, for_impl): - out = for_impl(5, lambda i, _: None, None) - self.assertIsNone(out) - - @_for_loop_impls - def test_for_loop_can_write_to_ref(self, for_impl): - def body(_, x_ref): - x_ref[()] = jnp.float32(1.) - out = for_impl(1, body, jnp.float32(0.)) - self.assertEqual(out, 1.) - - def body2(i, x_ref): - x_ref[()] = jnp.float32(i) - out = for_impl(2, body2, jnp.float32(0.)) - self.assertEqual(out, 1.) - - def body3(i, x_ref): - x_ref[()] = jnp.float32(i) * 2. - out = for_impl(2, body3, jnp.float32(0.)) - self.assertEqual(out, 2.) - - @_for_loop_impls - def test_for_loop_can_write_to_multiple_refs(self, for_impl): - def body(_, refs): - x_ref, y_ref = refs - x_ref[()] = jnp.float32(1.) - y_ref[()] = jnp.float32(2.) - x, y = for_impl(1, body, (jnp.float32(0.), jnp.float32(0.))) - self.assertEqual(x, 1.) - self.assertEqual(y, 2.) - - @_for_loop_impls - def test_for_loop_can_read_from_ref(self, for_impl): - def body(_, x_ref): - x_ref[()] # pylint: disable=pointless-statement - x = for_impl(1, body, jnp.float32(0.)) - self.assertEqual(x, 0.) - - @_for_loop_impls - def test_for_loop_can_read_from_and_write_to_ref(self, for_impl): - def body(_, x_ref): - x = x_ref[()] - x_ref[()] = x + jnp.float32(1.) - x = for_impl(5, body, jnp.float32(0.)) - self.assertEqual(x, 5.) - - @_for_loop_impls - def test_for_loop_can_read_from_and_write_to_refs(self, for_impl): - def body2(_, refs): - x_ref, y_ref = refs - x = x_ref[()] - y_ref[()] = x + 1. - x_ref[()] = x + 1. - x, y = for_impl(5, body2, (0., 0.)) - self.assertEqual(x, 5.) - self.assertEqual(y, 5.) - - @_for_loop_impls - def test_for_loop_can_read_from_and_write_to_ref_slice(self, for_impl): - def body(i, x_ref): - x = x_ref[i] - x_ref[i] = x + jnp.float32(1.) - x = for_impl(4, body, jnp.ones(4, jnp.float32)) - np.testing.assert_allclose(x, 2 * jnp.ones(4, jnp.float32)) - - def body2(i, x_ref): - x = x_ref[i, 0] - x_ref[i, 1] = x + x_ref[i, 1] - x = for_impl(4, body2, jnp.arange(8.).reshape((4, 2))) - np.testing.assert_allclose( - x, jnp.array([[0., 1.], [2., 5.], [4., 9.], [6., 13.]])) - - @_for_loop_impls - @jax.legacy_prng_key('allow') - def test_for_loop_can_implement_cumsum(self, for_impl): - def cumsum(x): - def body(i, refs): - x_ref, accum_ref = refs - accum_ref[i + 1] = accum_ref[i] + x_ref[i] - accum = jnp.zeros(x.shape[0] + 1, x.dtype) - _, accum_out = for_impl(x.shape[0], body, (x, accum)) - return accum_out[1:] - - key = jax.random.PRNGKey(0) - x = jax.random.normal(key, (8,)) - np.testing.assert_allclose(cumsum(x), jnp.cumsum(x), rtol=1e-6) - -def for_body_swap(i, refs): - a_ref, b_ref = refs - a, b = a_ref[i], b_ref[i] - b_ref[i] = a - a_ref[i] = b - -def swap_ref(a, b): - return b, a - -def for_body_swap_swap(i, refs): - for_body_swap(i, refs) - for_body_swap(i, refs) - -swap_swap_ref = lambda a, b: (a, b) - -def for_body_sincos(i, refs): - a_ref, b_ref = refs - a = a_ref[i] - b_ref[i] = jnp.sin(jnp.cos(a)) - -sincos_ref = lambda x, y: (x, jnp.sin(jnp.cos(x))) - -def for_body_sincostan(i, refs): - a_ref, b_ref = refs - a = a_ref[i] - b_ref[i] = jnp.tan(jnp.sin(jnp.cos(a))) - -sincostan_ref = lambda x, y: (x, jnp.tan(jnp.sin(jnp.cos(x)))) - -def for_body_accum(i, refs): - x_ref, accum_ref = refs - accum_ref[i + 1] = accum_ref[i] + x_ref[i] - -def accum_ref(x, accum): - for i in range(x.shape[0] - 1): - accum = accum.at[i + 1].set(accum[i] + x[i]) - return x, accum - -def for_body_sin_sq(i, refs): - x_ref, y_ref = refs - x = x_ref[i] - y = x - y_ref[i] = y - y = y_ref[i] - y_ref[i] = jnp.sin(y * y) - -sin_sq_ref = lambda x, y: (x, jnp.sin(x * x)) - -def for_body_reverse(i, refs): - x_ref, y_ref = refs - j = y_ref.shape[0] - i - 1 - y_ref[i] = x_ref[j] - -reverse_ref = lambda x, y: (x, x[::-1]) - -def for_body_noop(i, refs): - pass -noop_ref = lambda x, y: (x, y) -for_reference = for_loop.discharged_for_loop - - -class ForLoopTransformationTest(jtu.JaxTestCase): - - @jtu.sample_product( - [dict(for_body_name=for_body_name, f=for_body, ref=ref, - body_shapes=body_shapes, n=nsteps) - for for_body_name, for_body, ref, body_shapes, nsteps in [ - ("swap", for_body_swap, swap_ref, [(4,), (4,)], 4), - ("swap_swap", for_body_swap_swap, swap_swap_ref, [(4,), (4,)], 4), - ("sincos", for_body_sincos, sincos_ref, [(4,), (4,)], 4), - ("sincostan", for_body_sincostan, sincostan_ref, [(4,), (4,)], 4), - ("accum", for_body_accum, accum_ref, [(4,), (4,)], 3), - ("sin_sq", for_body_sin_sq, sin_sq_ref, [(4,), (4,)], 4), - ("reverse", for_body_reverse, reverse_ref, [(4,), (4,)], 4), - ] - ], - [dict(for_impl=for_impl, impl_name=impl_name) - for for_impl, impl_name in FOR_LOOP_IMPLS], - ) - @jtu.skip_on_devices("gpu", "cpu") # TODO(mattjj,sharadmv, dougalm): timeouts? - def test_for_jvp(self, f, ref, body_shapes, n, for_impl, for_body_name, - impl_name): - for_ = for_impl - rng = self.rng() - - args = [rng.randn(*s) for s in body_shapes] - - tol = {np.float64: 1e-12, np.float32: 1e-4} - ans = jax.jvp( lambda *args: for_( n, f, args), args, args) - ans_discharged = jax.jvp(lambda *args: for_reference(n, f, args), args, args) - expected = jax.jvp(ref, args, args) - self.assertAllClose(ans, ans_discharged, check_dtypes=True, rtol=tol, atol=tol) - self.assertAllClose(ans, expected, check_dtypes=True, rtol=tol, atol=tol) - jtu.check_grads(partial(for_, n, f), (args,), order=2, modes=["fwd"]) - - @jtu.sample_product( - [dict(for_body_name=for_body_name, f=for_body, ref=ref, - body_shapes=body_shapes, n=nsteps) - for for_body_name, for_body, ref, body_shapes, nsteps in [ - ("swap", for_body_swap, swap_ref, [(4,), (4,)], 4), - ("swap_swap", for_body_swap_swap, swap_swap_ref, [(4,), (4,)], 4), - ("sincos", for_body_sincos, sincos_ref, [(4,), (4,)], 4), - ("sincostan", for_body_sincostan, sincostan_ref, [(4,), (4,)], 4), - ("accum", for_body_accum, accum_ref, [(4,), (4,)], 3), - ("sin_sq", for_body_sin_sq, sin_sq_ref, [(4,), (4,)], 4), - ("reverse", for_body_reverse, reverse_ref, [(4,), (4,)], 4), - ] - ], - [dict(for_impl=for_impl, impl_name=impl_name) - for for_impl, impl_name in FOR_LOOP_IMPLS], - ) - @jtu.skip_on_devices("gpu", "cpu") # TODO(mattjj,sharadmv, dougalm): timeouts? - def test_for_linearize(self, f, ref, body_shapes, n, for_impl, for_body_name, - impl_name): - for_ = for_impl - rng = self.rng() - - args = [rng.randn(*s) for s in body_shapes] - - tol = {np.float64: 1e-12, np.float32: 1e-4} - ans = jax.linearize(lambda *args: for_( n, f, args), *args)[1](*args) - ans_discharged = jax.linearize(lambda *args: for_reference(n, f, args), - *args)[1](*args) - expected = jax.linearize(ref, *args)[1](*args) - self.assertAllClose(ans, ans_discharged, check_dtypes=True, rtol=tol, atol=tol) - self.assertAllClose(ans, expected, check_dtypes=True, rtol=tol, atol=tol) - - def test_for_loop_invar(self): - def f(x): - s = jnp.ones((2, 32), x.dtype) - def body(i, refs): - x_ref, y_ref = refs - y_ref[i] = s * x_ref[i] * jnp.cos(s) - # We should save `s` and `jnp.cos(s)` as residuals and not broadcast - # them. - return for_loop.for_loop(x.shape[0], body, (x, jnp.zeros_like(x))) - _, f_vjp = jax.linearize(f, jnp.ones((5, 2, 32))) - jaxpr = jax.make_jaxpr(f_vjp)(jnp.ones((5, 2, 32))) - consts = [v.aval for v in jaxpr.jaxpr.constvars - if v.aval.shape == (2, 32)] - self.assertLen(consts, 2) - - def loss(A): - def step(x, _): - return jnp.matmul(A, x), None - init_x = jnp.zeros(A.shape[-1:]) - last_x, _ = for_loop.scan(step, init_x, jnp.arange(10)) - return jnp.sum(last_x) - - A = jnp.zeros((3, 3)) - # The second DUS was unnecessarily replicating A across time. - # We check XLA because _scan_impl is "underneath" the jaxpr language. - s = jax.jit(jax.grad(loss)).lower(A).as_text('hlo') - assert s.count("dynamic-update-slice(") < 2 - - @_for_loop_impls - def test_for_loop_fixpoint_correctly_identifies_loop_varying_residuals( - self, for_impl): - - def body(i, refs): - a_ref, b_ref, c_ref = refs - a = a_ref[i] - b = b_ref[()] - x = jnp.sin(a) - b_ref[()] = jnp.sin(b * x) - c_ref[i] = x * b - def f(a, b): - c = jnp.zeros_like(a) - _, b, c = for_impl(5, body, (a, b, c)) - return b, c - a = jnp.arange(5.) + 1. - b = jnp.ones_like(a[0]) - _, f_lin = jax.linearize(f, a, b) - expected_tangents = f_lin(a, b) - _, actual_tangents = jax.jvp(f, (a, b), (a, b)) - np.testing.assert_allclose(actual_tangents[0], expected_tangents[0], - rtol=1e-6, atol=1e-6) - np.testing.assert_allclose(actual_tangents[1], expected_tangents[1], - rtol=1e-6, atol=1e-6) - - def body2(_, refs): - # Here we use `i_ref` as a loop counter - a_ref, b_ref, c_ref, i_ref = refs - i = i_ref[()] - a = a_ref[i] - b = b_ref[()] - x = jnp.sin(a) - b_ref[()] = jnp.sin(b * x) - c_ref[i] = x * b - i_ref[()] = i + 1 - - def g(a, b): - c = jnp.zeros_like(a) - _, b, c, _ = for_impl(5, body2, (a, b, c, 0)) - return b, c - a = jnp.arange(5.) + 1. - b = jnp.ones_like(a[0]) - _, g_lin = jax.linearize(f, a, b) - expected_tangents = g_lin(a, b) - _, actual_tangents = jax.jvp(g, (a, b), (a, b)) - np.testing.assert_allclose(actual_tangents[0], expected_tangents[0]) - np.testing.assert_allclose(actual_tangents[1], expected_tangents[1], - rtol=1e-6) - - @jtu.sample_product( - [dict(for_body_name=for_body_name, f=for_body, ref=ref, - body_shapes=body_shapes, n=nsteps) - for for_body_name, for_body, ref, body_shapes, nsteps in [ - ("noop", for_body_noop, noop_ref, [(4,), (4,)], 4), - ("swap", for_body_swap, swap_ref, [(4,), (4,)], 4), - ("swap_swap", for_body_swap_swap, swap_swap_ref, [(4,), (4,)], 4), - ("sincos", for_body_sincos, sincos_ref, [(4,), (4,)], 4), - ("sincostan", for_body_sincostan, sincostan_ref, [(4,), (4,)], 4), - ("accum", for_body_accum, accum_ref, [(4,), (4,)], 3), - ("sin_sq", for_body_sin_sq, sin_sq_ref, [(4,), (4,)], 4), - ("reverse", for_body_reverse, reverse_ref, [(4,), (4,)], 4), - ] - ], - [dict(for_impl=for_impl, impl_name=impl_name) - for for_impl, impl_name in FOR_LOOP_IMPLS], - ) - @jtu.skip_on_devices("gpu", "cpu") # TODO(mattjj,sharadmv, dougalm): timeouts? - @jtu.skip_on_flag("jax_skip_slow_tests", True) - def test_for_grad(self, f, ref, body_shapes, n, for_impl, for_body_name, - impl_name): - for_ = for_impl - rng = self.rng() - - args = [rng.randn(*s) for s in body_shapes] - - tol = {np.float64: 1e-12, np.float32: 1e-4} - ans = jax.grad(lambda args: for_( n, f, args)[1].sum())(args) - ans_discharged = jax.grad( - lambda args: for_reference(n, f, args)[1].sum())(args) - expected = jax.grad(lambda args: ref(*args)[1].sum())(args) - self.assertAllClose(ans, ans_discharged, check_dtypes=True, rtol=tol, - atol=tol) - self.assertAllClose(ans, expected, check_dtypes=True, rtol=tol, atol=tol) - jtu.check_grads(lambda *args: for_(n, f, args)[1].sum(), args, order=2, - rtol=7e-3, atol=1e-2) - - @jtu.skip_on_devices("gpu", "cpu") # TODO(mattjj,sharadmv, dougalm): timeouts? - @jax.legacy_prng_key('allow') - def test_grad_of_triple_nested_for_loop(self): - - func = lambda x: jnp.sin(x) + 1. - - @jax.jit - def f(x): - out = jnp.zeros_like(x) - def body(i, j, k, refs): - x_ref, out_ref = refs - y = func(x_ref[i, j, k]) - out_ref[i, j, k] += y - return for_loop.for_loop(x.shape, body, (x, out))[1].sum() - - x = random.normal(random.PRNGKey(0), (5, 4, 3)) - ref = lambda x: jax.vmap(jax.vmap(jax.vmap(func)))(x).sum() - self.assertAllClose(f(x), ref(x)) - jtu.check_grads(f, (x,), order=2, atol=0.1, rtol=0.1) - -if __name__ == '__main__': - absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/fused_attention_stablehlo_test.py b/tests/fused_attention_stablehlo_test.py index af0b18b02f37..3a96633fb101 100644 --- a/tests/fused_attention_stablehlo_test.py +++ b/tests/fused_attention_stablehlo_test.py @@ -24,10 +24,9 @@ from jax._src import test_util as jtu from jax._src.cudnn.fused_attention_stablehlo import ( dot_product_attention, - check_is_flash_attention, + paged_attention, check_cudnn_version, MaskType, - AttentionLayout, ) config.parse_flags_with_absl() @@ -116,7 +115,7 @@ def sdpa_train(query: Array, sliding_window_length=sliding_window_length), query, key, value, bias, mask, q_seqlen, kv_seqlen, q_offsets, kv_offsets) query_grad, key_grad, value_grad, bias_grad = sdpa_vjp(grad)[:4] - if bias is not None and len(bias.shape) == 3: + if bias is not None: # has dbias return out, (query_grad, key_grad, value_grad, bias_grad) return out, (query_grad, key_grad, value_grad) @@ -128,6 +127,7 @@ def sdpa_ref(query: Array, mask: Array | None = None, scale: float = 0.5, mask_type: MaskType = MaskType.NO_MASK, + is_bnth: bool = False, dropout_rate: float = 0.1, sliding_window_length: int | None = None) -> Array: @@ -148,11 +148,12 @@ def get_padding_mask(logits): (q_padding + kv_padding).astype(logits.dtype) * large_negative_number return jax.lax.broadcast(combined_padding, logits.shape[:-2]) - def get_encoded_padding_mask(encoded): - S = encoded.shape[1] - encoded_padding = (jax.lax.iota(np.int32, S) < S // 2).astype(encoded.dtype) + def get_encoded_padding_mask(encoded, is_bnth): + dim = 2 if is_bnth else 1 + T = encoded.shape[dim] + encoded_padding = (jax.lax.iota(np.int32, T) < T // 2).astype(encoded.dtype) return jax.lax.broadcast_in_dim( - encoded_padding, encoded.shape, broadcast_dimensions=[1]) + encoded_padding, encoded.shape, broadcast_dimensions=[dim]) def get_sliding_window_mask(logits, window_length): large_negative_number = get_large_negative_number(logits.dtype) @@ -164,9 +165,14 @@ def get_sliding_window_mask(logits, window_length): col_idx <= row_idx - window_length).astype(logits.dtype) * large_negative_number return mask[(*([jnp.newaxis]*(len(logits.shape) - 2)), ...)] - B, T, qN, H = query.shape - _, _, kN, _ = key.shape - logits = jnp.einsum("bqhd,bkhd->bhqk", query, key, preferred_element_type=jnp.float32) + if is_bnth: + B, qN, T, H = query.shape + _, kN, _, _ = key.shape + logits = jnp.einsum("bhqd,bhkd->bhqk", query, key, preferred_element_type=jnp.float32) + else: + B, T, qN, H = query.shape + _, _, kN, _ = key.shape + logits = jnp.einsum("bqhd,bkhd->bhqk", query, key, preferred_element_type=jnp.float32) if scale != 1.0: logits = logits * scale if mask_type == MaskType.CAUSAL: @@ -198,11 +204,14 @@ def get_sliding_window_mask(logits, window_length): dropout_rng = jax.random.key(0) keep = jax.random.bernoulli(dropout_rng, keep_prob, probs.shape) probs = jax.lax.select(keep, probs / keep_prob, jnp.zeros_like(probs)) - encoded = jnp.einsum("bhqk,bkhd->bqhd", probs, value, preferred_element_type=jnp.float32) + if is_bnth: + encoded = jnp.einsum("bhqk,bhkd->bhqd", probs, value, preferred_element_type=jnp.float32) + else: + encoded = jnp.einsum("bhqk,bkhd->bqhd", probs, value, preferred_element_type=jnp.float32) if mask_type == MaskType.PADDING: # cuDNN padding mask generation will mask out output accordingly # make sure the behavior is the same - encoded_mask = get_encoded_padding_mask(encoded) + encoded_mask = get_encoded_padding_mask(encoded, is_bnth) encoded = encoded * encoded_mask return encoded.astype(query.dtype) @@ -214,15 +223,16 @@ def sdpa_train_ref(query: Array, mask: Array | None = None, scale: float = 0.5, mask_type: MaskType = MaskType.NO_MASK, + is_bnth: bool = False, dropout_rate: float = 0.1, sliding_window_length: int | None = None) -> Array: out_ref, sdpa_vjp_ref = jax.vjp( partial( sdpa_ref, scale=scale, mask_type=mask_type, dropout_rate=dropout_rate, - sliding_window_length=sliding_window_length), + sliding_window_length=sliding_window_length, is_bnth=is_bnth), query, key, value, bias, mask) query_grad_ref, key_grad_ref, value_grad_ref, bias_grad_ref, _ = sdpa_vjp_ref(grad) - if bias is not None and len(bias.shape) == 3: + if bias is not None: return out_ref, (query_grad_ref, key_grad_ref, value_grad_ref, bias_grad_ref) return out_ref, (query_grad_ref, key_grad_ref, value_grad_ref) @@ -257,15 +267,10 @@ def dot_product_attention_fp8(query, key, value, fp8_metas): class DotProductAttentionTest(jtu.JaxTestCase): def setUp(self): super().setUp() - try: - cudnn_version = check_cudnn_version() - except RuntimeError as e: - self.skipTest(str(e)) - return - if cudnn_version < 8904: - self.skipTest("Requires >= cuDNN 8.9.4") if not jtu.is_cuda_compute_capability_at_least("8.0"): self.skipTest("Requires at least Ampere arch") + if jtu.is_cuda_version_at_least(13, 0): + self.skipTest("cuDNN creates no execution plans on CUDA 13.0.") @jtu.sample_product( batch_size=[4], @@ -275,6 +280,7 @@ def setUp(self): use_mask=[False, True], use_bias=[False, True], mask_type=[MaskType.NO_MASK], + is_bnth=[False, True], dropout_rate=[0], scale=[0.5], dtype=[jnp.float16, jnp.bfloat16] @@ -282,20 +288,20 @@ def setUp(self): @jtu.run_on_devices("cuda") def test_sdpa(self, batch_size: int, seq_len: int, num_heads: int, head_dim: int, use_mask: bool, use_bias: bool, mask_type: MaskType, - dropout_rate: float, scale: float, dtype: jnp.dtype): + is_bnth: bool, dropout_rate: float, scale: float, dtype: jnp.dtype): if len(jax.local_devices()) < 4: self.skipTest("Require at least 4 devices to run sharding tests.") if use_mask and mask_type != MaskType.NO_MASK: self.skipTest("Either pass in mask or generate mask directly in cuDNN.") k1, k2, k3, k4, k5, k6 = jax.random.split(jax.random.key(0), 6) - query = jax.random.normal( - k1, (batch_size, seq_len, num_heads, head_dim), dtype=dtype) - key = jax.random.normal( - k2, (batch_size, seq_len, num_heads, head_dim), dtype=dtype) - value = jax.random.normal( - k3, (batch_size, seq_len, num_heads, head_dim), dtype=dtype) - grad = jax.random.normal( - k4, (batch_size, seq_len, num_heads, head_dim), dtype=dtype) + if is_bnth: + qkv_shape = (batch_size, num_heads, seq_len, head_dim) + else: + qkv_shape = (batch_size, seq_len, num_heads, head_dim) + query = jax.random.normal(k1, qkv_shape, dtype=dtype) + key = jax.random.normal(k2, qkv_shape, dtype=dtype) + value = jax.random.normal(k3, qkv_shape, dtype=dtype) + grad = jax.random.normal(k4, qkv_shape, dtype=dtype) if use_bias: bias = jax.random.normal( k5, (batch_size, num_heads, seq_len, seq_len), dtype=dtype) @@ -309,7 +315,10 @@ def test_sdpa(self, batch_size: int, seq_len: int, num_heads: int, devices = np.array(jax.local_devices()[:4]) devices = devices.reshape((2, 2)) with Mesh(devices, ("dp", "tp")) as mesh: - qkv_spec = PartitionSpec("dp", None, "tp", None) + if is_bnth: + qkv_spec = PartitionSpec("dp", "tp", None, None) + else: + qkv_spec = PartitionSpec("dp", None, "tp", None) qkv_sharding = NamedSharding(mesh, qkv_spec) if bias is not None: bias_spec = PartitionSpec("dp", "tp", None, None) @@ -331,11 +340,14 @@ def test_sdpa(self, batch_size: int, seq_len: int, num_heads: int, grad = jax.device_put(grad, qkv_sharding) in_shardings = (qkv_sharding, qkv_sharding, qkv_sharding, qkv_sharding, bias_sharding, mask_sharding) - out_shardings = (qkv_sharding, (qkv_sharding, qkv_sharding, qkv_sharding)) + if use_bias: + out_shardings = (qkv_sharding, (qkv_sharding, qkv_sharding, qkv_sharding, bias_sharding)) + else: + out_shardings = (qkv_sharding, (qkv_sharding, qkv_sharding, qkv_sharding)) jitted_sdpa_train = jax.jit( partial( sdpa_train, scale=scale, mask_type=mask_type, - dropout_rate=dropout_rate), + dropout_rate=dropout_rate, is_bnth=is_bnth), in_shardings=in_shardings, out_shardings=out_shardings ) @@ -343,15 +355,21 @@ def test_sdpa(self, batch_size: int, seq_len: int, num_heads: int, jitted_sdpa_train_ref = jax.jit( partial( sdpa_train_ref, scale=scale, mask_type=mask_type, - dropout_rate=dropout_rate), + dropout_rate=dropout_rate, is_bnth=is_bnth), in_shardings=in_shardings, out_shardings=out_shardings ) - out, (query_grad, key_grad, value_grad) = \ - jitted_sdpa_train(query, key, value, grad, bias, mask) - out_ref, (query_grad_ref, key_grad_ref, value_grad_ref) = \ - jitted_sdpa_train_ref(query, key, value, grad, bias, mask) + if use_bias: + out, (query_grad, key_grad, value_grad, _) = \ + jitted_sdpa_train(query, key, value, grad, bias, mask) + out_ref, (query_grad_ref, key_grad_ref, value_grad_ref, _) = \ + jitted_sdpa_train_ref(query, key, value, grad, bias, mask) + else: + out, (query_grad, key_grad, value_grad) = \ + jitted_sdpa_train(query, key, value, grad, bias, mask) + out_ref, (query_grad_ref, key_grad_ref, value_grad_ref) = \ + jitted_sdpa_train_ref(query, key, value, grad, bias, mask) self.assertArraysAllClose(out_ref, out, rtol=2e-2, atol=2e-2) self.assertArraysAllClose( query_grad_ref, query_grad, rtol=2e-1, atol=2e-1) @@ -434,17 +452,13 @@ def test_sdpa_var_seq(self): self.assertArraysAllClose(key_grad_ref, key_grad, rtol=2e-1, atol=2e-1) self.assertArraysAllClose(value_grad_ref, value_grad, rtol=2e-1, atol=2e-1) + @jtu.sample_product( + broadcast_dims=[(), (0,), (1,), (0,1)], + ) @jtu.run_on_devices("cuda") - def test_sdpa_broadcast_bias_and_dbias(self): + def test_sdpa_broadcast_bias_and_dbias(self, broadcast_dims): if jax.device_count() < 4: self.skipTest("Requires more than 4 devices.") - try: - cudnn_version = check_cudnn_version() - except RuntimeError as e: - self.skipTest(str(e)) - return - if cudnn_version < 8906: - self.skipTest("Requires >= cuDNN 8.9.6") if not jtu.is_cuda_compute_capability_at_least("9.0"): self.skipTest("Requires at least Hopper arch") @@ -457,14 +471,29 @@ def test_sdpa_broadcast_bias_and_dbias(self): k3, (4, 1024, 4, 64), dtype=jnp.bfloat16) grad = jax.random.normal( k4, (4, 1024, 4, 64), dtype=jnp.bfloat16) - bias = jax.random.normal( - k5, (4, 1024, 1024), dtype=jnp.bfloat16) + + if broadcast_dims == (): + bias = jax.random.normal( + k5, (4, 4, 1024, 1024), dtype=jnp.bfloat16) + bias_spec = PartitionSpec("dp", "tp", None, None) + elif broadcast_dims == (0,): + bias = jax.random.normal( + k5, (1, 4, 1024, 1024), dtype=jnp.bfloat16) + bias_spec = PartitionSpec(None, "tp", None, None) + elif broadcast_dims == (1,): + bias = jax.random.normal( + k5, (4, 1, 1024, 1024), dtype=jnp.bfloat16) + bias_spec = PartitionSpec("dp", None, None, None) + else: + bias = jax.random.normal( + k5, (1, 1, 1024, 1024), dtype=jnp.bfloat16) + bias_spec = PartitionSpec(None, None, None, None) + devices = np.array(jax.local_devices()[:4]) devices = devices.reshape((2, 2)) with Mesh(devices, ("dp", "tp")) as mesh: qkv_spec = PartitionSpec("dp", None, "tp", None) qkv_sharding = NamedSharding(mesh, qkv_spec) - bias_spec = PartitionSpec("tp", None, None) bias_sharding = NamedSharding(mesh, bias_spec) in_shardings = (qkv_sharding, qkv_sharding, qkv_sharding, qkv_sharding, bias_sharding) @@ -536,8 +565,6 @@ def attn_vjp(x, bias, mask, target_fn): _, dbias_ans, _ = attn_ans(x, bias, mask) dbias_ans = jnp.squeeze(dbias_ans, axis=1) self.assertArraysAllClose(dbias_ans, dbias_ref) - if batch_size != 1: - self.assertTrue(not jnp.any(dbias_ans)) @jtu.run_on_devices("cuda") def test_sdpa_sliding_window_length(self): @@ -576,13 +603,6 @@ def test_sdpa_sliding_window_length(self): @jtu.run_on_devices("cuda") def test_sdpa_large_head_size(self): - try: - cudnn_version = check_cudnn_version() - except RuntimeError as e: - self.skipTest(str(e)) - return - if cudnn_version < 90500: - self.skipTest("Requires >= cuDNN 9.5.0") if not jtu.is_cuda_compute_capability_equal("9.0"): self.skipTest("Requires Hopper arch") @@ -611,13 +631,8 @@ def test_sdpa_large_head_size(self): def test_sdpa_packed_layout(self): if jax.device_count() < 4: self.skipTest("Requires more than 4 devices.") - try: - cudnn_version = check_cudnn_version() - except RuntimeError as e: - self.skipTest(str(e)) - return - if cudnn_version < 90600: - self.skipTest("Requires >= cuDNN 9.6.0") + if not jtu.is_cuda_compute_capability_at_least("9.0"): + self.skipTest("Requires at least Hopper arch") k1, k2, k3, k4 = jax.random.split(jax.random.key(0), 4) query = jax.random.normal( k1, (4, 512, 4, 64), dtype=jnp.bfloat16) @@ -707,7 +722,7 @@ def generate_segment_mask(segment_ids, dtype): sdpa_train_ref, scale=0.1, mask_type=MaskType.NO_MASK, dropout_rate=0), in_shardings=(qkv_sharding, qkv_sharding, qkv_sharding, qkv_sharding, bias_sharding), - out_shardings=(qkv_sharding, (qkv_sharding, qkv_sharding, qkv_sharding)) + out_shardings=(qkv_sharding, (qkv_sharding, qkv_sharding, qkv_sharding, bias_sharding)) ) query = query * mask @@ -717,7 +732,7 @@ def generate_segment_mask(segment_ids, dtype): out, (query_grad, key_grad, value_grad) = \ jitted_sdpa_train(query, key, value, grad, None, None, q_seqlen, kv_seqlen, q_offsets, kv_offsets) - out_ref, (query_grad_ref, key_grad_ref, value_grad_ref) = \ + out_ref, (query_grad_ref, key_grad_ref, value_grad_ref, _) = \ jitted_sdpa_train_ref(query, key, value, grad, bias) out = out * mask @@ -738,60 +753,161 @@ def generate_segment_mask(segment_ids, dtype): self.assertArraysAllClose(value_grad_ref, value_grad, rtol=1e-2, atol=1e-2) @jtu.run_on_devices("cuda") - def test_layouts(self): - if jax.device_count() < 4: - self.skipTest("Requires more than 4 devices.") - dtype = "bfloat16" - B, T, N, H = 4, 1024, 8, 128 - S = T - k0, k1, k2, k3 = jax.random.split(jax.random.key(123), 4) - query = jax.random.normal(k0, (B, T, N, H), dtype=dtype) - key = jax.random.normal(k1, (B, S, N, H), dtype=dtype) - value = jax.random.normal(k2, (B, S, N, H), dtype=dtype) - grad = jax.random.normal(k3, (B, T, N, H), dtype=dtype) - - btnh_fn = jax.jit(partial(sdpa_train, scale=.5, - mask_type=MaskType.CAUSAL, is_bnth=False, dropout_rate=0.0)) - out_ref, (dq_ref, dk_ref, dv_ref) = btnh_fn(query, key, value, grad) - - def _cvt(x): - return jnp.einsum("BTNH->BNTH", x) - def _cvt_back(x): - return jnp.einsum("BNTH->BTNH", x) - bnth_fn = jax.jit(partial(sdpa_train, scale=.5, mask_type=MaskType.CAUSAL, - is_bnth=True, dropout_rate=0.0)) - out, (dq, dk, dv) = bnth_fn(_cvt(query), _cvt(key), _cvt(value), _cvt(grad)) - - self.assertArraysAllClose(out_ref, _cvt_back(out)) - self.assertArraysAllClose(dq_ref, _cvt_back(dq)) - self.assertArraysAllClose(dk_ref, _cvt_back(dk)) - self.assertArraysAllClose(dv_ref, _cvt_back(dv)) - - def test_sdpa_utils(self): + def test_sdpa_residual(self): + k1, k2, k3, k4, k5 = jax.random.split(jax.random.key(0), 5) + query = jax.random.normal( + k1, (4, 1024, 4, 64), dtype=jnp.bfloat16) + key = jax.random.normal( + k2, (4, 1024, 4, 64), dtype=jnp.bfloat16) + value = jax.random.normal( + k3, (4, 1024, 4, 64), dtype=jnp.bfloat16) + grad = jax.random.normal( + k4, (4, 1024, 4, 64), dtype=jnp.bfloat16) + grad_stat = jax.random.normal( + k5, (4, 4, 1024), dtype=jnp.float32) + + devices = np.array(jax.local_devices()[:2]) + with Mesh(devices, ("dp")) as mesh: + qkv_spec = PartitionSpec("dp", None, None, None) + stat_spec = PartitionSpec("dp", None, None) + qkv_sharding = NamedSharding(mesh, qkv_spec) + stat_sharding = NamedSharding(mesh, stat_spec) + + query = jax.device_put(query, qkv_sharding) + key = jax.device_put(key, qkv_sharding) + value = jax.device_put(value, qkv_sharding) + grad = jax.device_put(grad, qkv_sharding) + grad_stat = jax.device_put(grad_stat, stat_sharding) + + jitted_sdpa_inference = jax.jit( + partial( + dot_product_attention, scale=1.0, mask_type=MaskType.NO_MASK, + dropout_rate=0, return_residual=True), + in_shardings=(qkv_sharding, qkv_sharding, qkv_sharding), + out_shardings=(qkv_sharding, stat_sharding) + ) + + outs = jitted_sdpa_inference(query, key, value) + assert len(outs) == 2 + + def train(query, key, value, grads): + outs, grad_fn = jax.vjp(partial( + dot_product_attention, scale=1.0, mask_type=MaskType.NO_MASK, + dropout_rate=0, return_residual=True), query, key, value) + return outs, grad_fn(grads) + jitted_sdpa_train = jax.jit(train, + in_shardings=(qkv_sharding, qkv_sharding, qkv_sharding, (qkv_sharding, stat_sharding)), + out_shardings=((qkv_sharding, stat_sharding), (qkv_sharding, qkv_sharding, qkv_sharding))) + outs = jitted_sdpa_train(query, key, value, (grad, grad_stat)) + assert len(outs) == 2 + + @jtu.sample_product( + batch_size=[4], + q_seq_len=[1, 1024], + kv_seq_len=[1024], + num_heads=[8], + head_dim=[64, 128], + block_size=[64, 128], + dtype=[jnp.float16, jnp.bfloat16] + ) + @jtu.run_on_devices("cuda") + def test_sdpa_paged_attention(self, batch_size, q_seq_len, kv_seq_len, + num_heads, head_dim, block_size, dtype): + + keys = jax.random.split(jax.random.key(0), 5) + blocks_per_batch = kv_seq_len // block_size + num_blocks = batch_size * blocks_per_batch + + # different q_seq_len for prefill and decode + q = jax.random.normal( + keys[0], (batch_size, q_seq_len, num_heads, head_dim), dtype=dtype) + k_container = jax.random.normal( + keys[1], (num_blocks, block_size, num_heads, head_dim), dtype=dtype) + v_container = jax.random.normal( + keys[2], (num_blocks, block_size, num_heads, head_dim), dtype=dtype) + page_table_k = jax.random.randint( + keys[3], (batch_size, 1, blocks_per_batch, 1), 0, num_blocks-1, dtype=jnp.int32) + page_table_v = jax.random.randint( + keys[4], (batch_size, 1, blocks_per_batch, 1), 0, num_blocks-1, dtype=jnp.int32) + # full page table + q_seqlen = jnp.full((batch_size,), q_seq_len, jnp.int32) + kv_seqlen = jnp.full((batch_size,), kv_seq_len, jnp.int32) + + def unpaged(paged, page_table): + output = jnp.zeros((batch_size, kv_seq_len, num_heads, head_dim), dtype=dtype) + for b in range(batch_size): + for block in range(blocks_per_batch): + block_idx = page_table[b, 0, block, 0] + output = output.at[ + b, block * block_size : (block + 1) * block_size, :, : + ].set(paged[block_idx, :, :, :]) + return output + + k = unpaged(k_container, page_table_k) + v = unpaged(v_container, page_table_v) + + sdpa_infer = jax.jit(partial( + paged_attention, scale=1.0, mask_type=MaskType.NO_MASK) + ) + sdpa_infer_ref = jax.jit(partial( + sdpa_ref, scale=1.0, mask_type=MaskType.NO_MASK, dropout_rate=0) + ) + + out = sdpa_infer(q, k_container, v_container, q_seqlen=q_seqlen, + kv_seqlen=kv_seqlen, page_table_k=page_table_k, page_table_v=page_table_v) + out_ref = sdpa_infer_ref(q, k, v) + self.assertArraysAllClose(out_ref, out_ref, rtol=1e-2, atol=1e-2) + + @jtu.run_on_devices("cuda") + def test_sdpa_mla(self): if jax.device_count() < 4: self.skipTest("Requires more than 4 devices.") - test_cases = [ - (1, 257, 64, 8905, False, True, True), - (1, 1024, 64, 8905, False, False, True), - (1024, 1024, 64, 8905, False, False, True), - (1024, 1024, 128, 8905, False, False, True), - (1024, 1024, 127, 8905, False, False, False), - ] - - for k in test_cases: - sql_q, sql_v, head_dim, cudnn_version, has_bias, is_training, \ - expected_pass = k - query = jnp.empty((4, sql_q, 4, head_dim)) - key = jnp.empty((4, sql_v, 4, head_dim)) - if expected_pass: - check_is_flash_attention( - query, key, AttentionLayout.BNTH.value, cudnn_version, has_bias, - is_training) - else: - with self.assertRaises(NotImplementedError): - check_is_flash_attention( - query, key, AttentionLayout.BNTH.value, cudnn_version, has_bias, - is_training) + try: + cudnn_version = check_cudnn_version() + except RuntimeError as e: + self.skipTest(str(e)) + if cudnn_version < 91000: + self.skipTest("Requires >= cuDNN 9.10.0") + if not jtu.is_cuda_compute_capability_at_least("9.0"): + self.skipTest("Requires at least Hopper arch") + k1, k2, k3 = jax.random.split(jax.random.key(0), 3) + query = jax.random.normal( + k1, (4, 1024, 4, 128), dtype=jnp.bfloat16) + key = jax.random.normal( + k2, (4, 1024, 4, 128), dtype=jnp.bfloat16) + value = jax.random.normal( + k3, (4, 1024, 4, 64), dtype=jnp.bfloat16) + + devices = np.array(jax.local_devices()[:4]) + devices = devices.reshape((2, 2)) + with Mesh(devices, ("dp", "tp")) as mesh: + qkv_spec = PartitionSpec("dp", None, "tp", None) + qkv_sharding = NamedSharding(mesh, qkv_spec) + in_shardings = ( + qkv_sharding, qkv_sharding, qkv_sharding) + out_shardings = qkv_sharding + query = jax.device_put(query, qkv_sharding) + key = jax.device_put(key, qkv_sharding) + value = jax.device_put(value, qkv_sharding) + + jitted_sdpa_inference = jax.jit( + partial( + dot_product_attention, scale=1.0, mask_type=MaskType.NO_MASK, + dropout_rate=0), + in_shardings=in_shardings, + out_shardings=out_shardings + ) + + jitted_sdpa_inference_ref = jax.jit( + partial( + sdpa_ref, scale=1.0, mask_type=MaskType.NO_MASK, dropout_rate=0), + in_shardings=in_shardings, + out_shardings=out_shardings + ) + + out = jitted_sdpa_inference(query, key, value) + out_ref = jitted_sdpa_inference_ref(query, key, value) + self.assertArraysAllClose(out_ref, out, rtol=2e-2, atol=2e-2) @jtu.with_config(jax_numpy_dtype_promotion="standard") @@ -804,10 +920,14 @@ def setUp(self): except RuntimeError as e: self.skipTest(str(e)) return - if cudnn_version < 90100: - self.skipTest("Requires >= cuDNN 9.1.0") + if cudnn_version == 91000: + self.skipTest("cuDNN 9.10.0 does not support SDPA FP8") if not jtu.is_cuda_compute_capability_at_least("9.0"): self.skipTest("Requires at least Hopper arch") + if jtu.is_cuda_compute_capability_equal("12.0"): + self.skipTest("cuDNN does not support FP8 with compute capability 12.0") + if jtu.is_cuda_version_at_least(13, 0): + self.skipTest("cuDNN creates no execution plans on CUDA 13.0.") @jtu.sample_product( batch_size=[2, 4], @@ -953,7 +1073,7 @@ def dot_product_attention_fp8(query, key, value, fp8_metas): query_quantized, key_quantized, value_quantized, fp8_metas ) out_ref = jitted_sdpa_inference_ref(query, key, value) - self.assertArraysAllClose(out_ref, out.astype(dtype), rtol=5e-2, atol=5e-2) + self.assertArraysAllClose(out_ref, out.astype(dtype), rtol=8e-2, atol=8e-2) if __name__ == "__main__": diff --git a/tests/fused_test.py b/tests/fused_test.py new file mode 100644 index 000000000000..96cf306ffbfa --- /dev/null +++ b/tests/fused_test.py @@ -0,0 +1,67 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest + +import jax +import jax.numpy as jnp +from jax._src import test_util as jtu + +from jax.experimental.fused import fused + +jax.config.parse_flags_with_absl() + +@fused(out_spaces=(jax.memory.Space.Host, jax.memory.Space.Device)) +def f(x, y): + z = x + y + w = x * y + return z, w + +class FusedTest(jtu.JaxTestCase): + + def test_basic(self): + x = jnp.arange(3.) + x_host = jax.device_put(x, jax.memory.Space.Host) + y_device = jnp.arange(3.) + low = jax.jit(f).trace(x_host, y_device).lower(lowering_platforms=('cuda',)) + txt = low._lowering.hlo().as_hlo_module().to_string() + self.assertIn('custom_call', txt) + self.assertIn('inlineable', txt) + self.assertIn('MUST_FUSE', txt) + self.assertIn('out_spaces', txt) + + def test_vmap_basic(self): + x = jnp.arange(3.) + x_host = jax.device_put(x, jax.memory.Space.Host) + y_device = jnp.arange(3.) + f_ = jax.jit(jax.vmap(f)) + f_.trace(x_host, y_device).lower(lowering_platforms=('cuda',)) # don't crash + + def test_jvp_basic(self): + x = jnp.arange(3.) + x_host = jax.device_put(x, jax.memory.Space.Host) + y_device = jnp.arange(3.) + f_ = jax.jit(lambda x, y: jax.jvp(f, (x, y), (x, y))) + f_.trace(x_host, y_device).lower(lowering_platforms=('cuda',)) # don't crash + + def test_grad_basic(self): + x = jnp.arange(3.) + x_host = jax.device_put(x, jax.memory.Space.Host) + y_device = jnp.arange(3.) + f_ = jax.jit(jax.grad(lambda x, y: f(x, y)[1].sum())) + f_.trace(x_host, y_device).lower(lowering_platforms=('cuda',)) # don't crash + + +if __name__ == '__main__': + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/generated_fun_test.py b/tests/generated_fun_test.py index cdfeeba6275b..67c19179bb8b 100644 --- a/tests/generated_fun_test.py +++ b/tests/generated_fun_test.py @@ -218,7 +218,7 @@ def check_all_close(xs, ys, tol=1e-3): def check_close(x, y, tol=1e-3): assert jnp.shape(x) == jnp.shape(y) - # TODO(dougalm): re-enable once we've tackled the less pendantic bugs + # TODO(dougalm): re-enable once we've tackled the less pedantic bugs # assert x.dtype == y.dtype assert jnp.allclose(x, y, rtol=tol, atol=tol), \ f"Value mismatch:\n{x}\n vs\n{y}\n" diff --git a/tests/gpu_memory_flags_test.py b/tests/gpu_memory_flags_test.py index 308fff257348..bada2bebc74e 100644 --- a/tests/gpu_memory_flags_test.py +++ b/tests/gpu_memory_flags_test.py @@ -29,7 +29,7 @@ class GpuMemoryAllocationTest(absltest.TestCase): @jtu.skip_under_pytest("Test must run in an isolated process") @unittest.skipIf( "XLA_PYTHON_CLIENT_ALLOCATOR" in os.environ, - "Test does not work if the python client allocator has been overriden", + "Test does not work if the python client allocator has been overridden", ) def test_gpu_memory_allocation(self): falsey_values = ("0", "False", "false") @@ -40,7 +40,7 @@ def test_gpu_memory_allocation(self): device = jax.devices()[0] mem_stats = device.memory_stats() self.assertEqual(mem_stats["pool_bytes"], 0) - x = jax.lax.add(1, 2) + x = jax.lax.add(1, 2).block_until_ready() mem_stats = device.memory_stats() if preallocate: diff --git a/tests/hijax_test.py b/tests/hijax_test.py new file mode 100644 index 000000000000..ac4307f7920b --- /dev/null +++ b/tests/hijax_test.py @@ -0,0 +1,1696 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass +from functools import partial +import itertools as it +from typing import Any +import unittest + +from absl.testing import absltest, parameterized + +import jax +import jax.numpy as jnp +from jax import typeof + +from jax._src import config +from jax._src import core +from jax._src import dtypes +from jax._src import state +from jax._src.state import indexing +from jax._src.state import primitives as state_primitives +from jax._src.interpreters import ad +from jax._src import test_util as jtu +from jax._src.util import safe_zip, safe_map +from jax._src.state.discharge import run_state + +from jax._src.hijax import ( + HiPrimitive, HiType, Box, new_box, box_set, box_get, box_effect, + register_hitype, ShapedArray, Ty, custom_vjp3) +from jax.experimental.hijax import VJPHiPrimitive + +jtu.request_cpu_devices(2) + +config.parse_flags_with_absl() + +map, unsafe_map = safe_map, map +zip, unsafe_zip = safe_zip, zip + + +@dataclass(frozen=True) +class QArray: + arr: jax.Array # int8[m, k] + scale: jax.Array # f32[m] + +# Define a type +@dataclass(frozen=True) +class QArrayTy(HiType): + shape: tuple[int, int] + + # how to lower to (lo)jax types + def lo_ty(self) -> list[ShapedArray]: + m, k = self.shape + return [ShapedArray((m, k), jnp.dtype('int8')), + ShapedArray((m, ), jnp.dtype('float32'))] + # these next two are essentially the pytree interface + def lower_val(self, hi_val: QArray) -> list[jax.Array]: + return [hi_val.arr, hi_val.scale] + def raise_val(self, arr, scale) -> QArray: + return QArray(arr, scale) # alternative: LowerTrace + + def ref_get_abstract_eval(self, ref_aval, *args, tree): + arr_aval = core.ShapedArray(self.shape, jnp.dtype('float32')) + updated_ref = ref_aval.update(inner_aval=arr_aval) + out, effects = state_primitives.get_p.abstract_eval( + updated_ref, *args, tree=tree + ) + assert isinstance(out, core.ShapedArray) + return QArrayTy(out.shape), effects + + def ref_swap_abstract_eval(self, ref_aval, val_aval, *args, tree): + arr_aval = core.ShapedArray(self.shape, jnp.dtype('float32')) + val_arr_aval = core.ShapedArray(val_aval.shape, jnp.dtype('float32')) + updated_ref = ref_aval.update(inner_aval=arr_aval) + out_aval, effects = state_primitives.swap_p.abstract_eval( + updated_ref, val_arr_aval,*args, tree=tree + ) + assert isinstance(out_aval, core.ShapedArray) + return QArrayTy(out_aval.shape), effects + + def ref_get_to_lojax(self, ref: state.TransformedRef | jax.Ref, + idx: indexing.NDIndexer): + if isinstance(ref, state.TransformedRef): + if ref.transforms: raise NotImplementedError(ref) + ref = ref.ref + # Unpack Ref type + ref = ref._refs + if not all(i.start == 0 and i.size == s + for i, s in zip(idx.indices, ref.arr.shape)): + raise NotImplementedError + outs = [out.get() for out in self.lower_val(ref)] + return self.raise_val(*outs) + + def ref_swap_to_lojax(self, ref: state.TransformedRef | jax.Ref, + val: jax.Array, idx: indexing.NDIndexer): + if isinstance(ref, state.TransformedRef): + if ref.transforms: raise NotImplementedError(ref) + ref = ref.ref + # Unpack Ref type + ref = ref._refs + if not all(i.start == 0 and i.size == s + for i, s in zip(idx.indices, ref.arr.shape)): + raise NotImplementedError + outs = [out.swap(val) for out, val + in zip(self.lower_val(ref), self.lower_val(val))] + return self.raise_val(*outs) + + # autodiff + def to_tangent_aval(self): + return self # different from what a pytree would do! + def vspace_zero(self): + m, k = self.shape + return QArray(jnp.zeros((m, k), jnp.dtype('int8')), + jnp.ones ((m, ), jnp.dtype('float32'))) + +register_hitype(QArray, lambda q: QArrayTy(q.arr.shape)) + +def to_qarray(x): + return to_qarray_p.bind(x) + +def from_qarray(x): + return from_qarray_p.bind(x) + +class ToQ(HiPrimitive): + def abstract_eval(_, lo_aval): + return QArrayTy(lo_aval.shape), set() + + def to_lojax(_, lo_val): + m, _ = lo_val.shape + scale = lo_val.max(1) / 32. + return QArray((lo_val / scale[:, None]).astype('int8'), scale) + + def jvp(_, primals, tangents): + (x,), (xdot,) = primals, tangents + return to_qarray(x), to_qarray(xdot) + + def transpose(_, out_bar, __): + return [from_qarray(out_bar)] +to_qarray_p = ToQ('to_q') + +class FromQ(HiPrimitive): + def abstract_eval(_, hi_aval): + return ShapedArray(hi_aval.shape, jnp.dtype('float32')), set() + + def to_lojax(_, hi_val): + return hi_val.arr.astype('float32') * hi_val.scale[:, None] + + def jvp(_, primals, tangents): + (x,), (xdot,) = primals, tangents + return from_qarray(x), from_qarray(xdot) + + def transpose(_, out_bar, __): + return [to_qarray(out_bar)] +from_qarray_p = FromQ('from_q') + + +@dataclass +class HiTup: + elts: tuple + def __repr__(self): + return 'Tup{' + ','.join(map(repr, self.elts)) + '}' + +@dataclass(frozen=True) +class TupTy(HiType): + tys: tuple[Ty] + + def __repr__(self): + return 'Tup{' + ','.join(a.str_short() for a in self.tys) + '}' + + def __hash__(self): + return hash(self.tys) + + def __eq__(self, other): + return self.tys == other.tys + + def lo_ty(self): + return list(self.tys) + + def lower_val(self, hi_val: HiTup): + return [lo for ty, elt in zip(self.tys, hi_val.elts) + for lo in ty.lower_val(elt)] + + def raise_val(self, *elts_flat): + elts_iter = iter(elts_flat) + return HiTup(tuple(ty.raise_val(*it.islice(elts_iter, len(ty.lo_ty()))) + for ty in self.tys)) + + def to_tangent_aval(self): + return TupTy(tuple(ty.to_tangent_aval() for ty in self.tys)) + + def normalize(self): + return TupTy(tuple(ty.normalize() for ty in self.tys)) + +register_hitype(HiTup, lambda t: TupTy(tuple(map(typeof, t.elts)))) + +class MakeTup(HiPrimitive): + def abstract_eval(_, *in_avals): + return TupTy(in_avals), set() + + def to_lojax(self, *elts): + return HiTup(elts) +make_tup_p = MakeTup('make_tup') + +class GetTupElt(HiPrimitive): + def abstract_eval(_, tup, *, idx): + return tup.tys[idx], set() + + def to_lojax(self, tup, *, idx): + return tup.elts[idx] + + def jvp(self, primals, tangents, *, idx): + (tup,), (tup_dot,) = primals, tangents + return tup.elts[idx], get_tuple_element(tup_dot, idx) + + def transpose(self, out_bar, tup, *, idx): + if ad.is_undefined_primal(tup): + tup_ty = tup.aval + else: + tup_ty = tup + out_elts = [ + jnp.zeros(elt_ty.shape, elt_ty.dtype) for elt_ty in tup_ty.tys + ] + out_elts[idx] = out_bar + return [make_tup(*out_elts)] + +get_tup_elt_p = GetTupElt('get_tup_elt') + +def make_tup(*elts): + return make_tup_p.bind(*elts) + +def get_tuple_element(tup, idx): + return get_tup_elt_p.bind(tup, idx=idx) + +@dataclass(frozen=True) +class ImmutBox: + _val: Any + + @property + def shape(self): + if hasattr(self._val, 'shape'): + return self._val.shape + leaves = jax.tree.leaves(self._val) + if leaves and hasattr(leaves[0], 'shape'): + return leaves[0].shape + raise AttributeError(f"ImmutBox with value {self._val} has no shape") + + @property + def ndim(self): + return len(self.shape) + +def _is_zero(x): + return isinstance(x, ad.Zero) + +def _get_aval(x): + return x.aval if _is_zero(x) else core.typeof(x) + +def immutbox_to_aval(box: ImmutBox) -> 'ImmutBoxTy': + leaves, treedef = jax.tree.flatten(box._val, is_leaf=_is_zero) + leaf_avals = tuple(map(_get_aval, leaves)) + return ImmutBoxTy(leaf_avals, treedef) + +@dataclass(frozen=True) +class ImmutBoxTy(HiType): + leaf_avals: tuple[core.AbstractValue, ...] + treedef: Any + has_qdd = False + + @property + def shape(self): + reconstructed = jax.tree.unflatten(self.treedef, self.leaf_avals) + if hasattr(reconstructed, 'shape'): + return reconstructed.shape + if self.leaf_avals and hasattr(self.leaf_avals[0], 'shape'): + return self.leaf_avals[0].shape + raise AttributeError(f"ImmutBoxTy with treedef {self.treedef} has no shape") + + @property + def ndim(self): + return len(self.shape) + + @property + def sharding(self): + reconstructed = jax.tree.unflatten(self.treedef, self.leaf_avals) + if hasattr(reconstructed, 'sharding'): + return reconstructed.sharding + if self.leaf_avals and hasattr(self.leaf_avals[0], 'sharding'): + return self.leaf_avals[0].sharding + return None + + def lo_ty(self): + return list(self.leaf_avals) + + def lower_val(self, hi_val: ImmutBox): + leaves, treedef = jax.tree.flatten(hi_val._val, is_leaf=_is_zero) + assert treedef == self.treedef + return leaves + + def raise_val(self, *lo_vals): + return ImmutBox(jax.tree.unflatten(self.treedef, lo_vals)) + + def to_tangent_aval(self): + tangent_leaf_avals = tuple(aval.to_tangent_aval() for aval in self.leaf_avals) + return ImmutBoxTy(tangent_leaf_avals, self.treedef) + +def _map_immutbox_ty(size: int, axis: int | None, aval: ImmutBoxTy) -> ImmutBoxTy: + if axis is None: + return aval + mapped_leaf_avals = tuple(core.mapped_aval(size, axis, leaf_aval) + for leaf_aval in aval.leaf_avals) + return ImmutBoxTy(mapped_leaf_avals, aval.treedef) + +def _unmap_immutbox_ty(size: int, axis: int | None, explicit_mesh_axis, + aval: ImmutBoxTy) -> ImmutBoxTy: + if axis is None: + return aval + elif isinstance(axis, int): + unmapped_leaf_avals = tuple(core.unmapped_aval(size, axis, explicit_mesh_axis, leaf_aval) + for leaf_aval in aval.leaf_avals) + return ImmutBoxTy(unmapped_leaf_avals, aval.treedef) + else: + raise TypeError(axis) + +core.aval_mapping_handlers[ImmutBoxTy] = (_map_immutbox_ty, _unmap_immutbox_ty) + +class ImmutBoxNew(HiPrimitive): + def is_high(self, *leaves, leaf_avals, treedef) -> bool: + return True + + def abstract_eval(self, *leaves, leaf_avals, treedef): + return ImmutBoxTy(leaf_avals, treedef), set() + + def to_lojax(self, *leaves, leaf_avals, treedef): + val = jax.tree.unflatten(treedef, leaves) + return ImmutBox(val) + + def jvp(self, primals, tangents, *, leaf_avals, treedef): + return (immutbox_new_p.bind(*primals, leaf_avals=leaf_avals, treedef=treedef), + immutbox_new_p.bind(*tangents, leaf_avals=leaf_avals, treedef=treedef)) + + def transpose(self, out_bar, *leaves, leaf_avals, treedef): + val = out_bar._val + leaves, _ = jax.tree.flatten(val, is_leaf=_is_zero) + return leaves + +immutbox_new_p = ImmutBoxNew('immutbox_new') + +def immutbox_new(val): + leaves, treedef = jax.tree.flatten(val, is_leaf=_is_zero) + leaf_avals = tuple(map(_get_aval, leaves)) + leaves = [ad.instantiate_zeros(leaf) for leaf in leaves] + return immutbox_new_p.bind(*leaves, leaf_avals=leaf_avals, treedef=treedef) + +class ImmutBoxGet(HiPrimitive): + multiple_results = True + + def is_high(self, box_aval) -> bool: + return True + + def abstract_eval(self, box_aval): + leaf_avals = box_aval.leaf_avals + return list(leaf_avals), set() + + def to_lojax(self, box): + leaves, _ = jax.tree.flatten(box._val, is_leaf=_is_zero) + return tuple(leaves) + + def jvp(self, primals, tangents): + (box,), (box_dot,) = primals, tangents + return immutbox_get(box), immutbox_get(box_dot) + + def transpose(self, out_bars, box): + box_aval = core.typeof(box) if not ad.is_undefined_primal(box) else box.aval + treedef = box_aval.treedef + reconstructed_cotangent = jax.tree.unflatten(treedef, out_bars) + return (immutbox_new(reconstructed_cotangent),) + +immutbox_get_p = ImmutBoxGet('immutbox_get') + +def immutbox_get(box): + leaves = immutbox_get_p.bind(box) + box_ty = core.typeof(box) + return jax.tree.unflatten(box_ty.treedef, leaves) + +register_hitype(ImmutBox, immutbox_to_aval) + + +class HijaxTest(jtu.JaxTestCase): + def test_basic_register(self): + # older test that defines a slightly different QArray internally + @dataclass(frozen=True) + class QArray: + arr: jax.Array + scale: jax.Array + axis: int + + @dataclass(frozen=True) + class QArrayTy(HiType): + shape: tuple[int, int] + axis: int + + ndim = property(lambda self: len(self.shape)) + + # how to lower to (lo)jax types + def lo_ty(self) -> list[ShapedArray]: + m, k = self.shape + return [ShapedArray((m, k), jnp.dtype('int8')), + ShapedArray((m, ), jnp.dtype('float32'))] + + # these next two are essentially the pytree interface + def lower_val(self, hi_val: QArray) -> list[jax.Array]: + return [hi_val.arr, hi_val.scale] + def raise_val(self, arr, scale) -> QArray: + return QArray(arr, scale, self.axis) + + register_hitype(QArray, lambda q: QArrayTy(q.arr.shape, q.axis)) + + q = QArray(jnp.zeros((4, 4), 'int8'), jnp.ones(4, 'float32'), axis=1) + jax.jit(lambda x: x)(q) # don't crash + + def test_custom_types_and_primitive(self): + if config.enable_x64.value: raise unittest.SkipTest("no x64") + + @dataclass(frozen=True) + class MyArray: + arr: jax.Array # always f32 + + @dataclass(frozen=True) + class MyTy(HiType): + def to_tangent_aval(self): + return MyTy() + def str_short(self, short_dtypes=False): + return 'MyTy' + def lo_ty(self): + return [core.ShapedArray((), jnp.dtype('float32'))] + def lower_val(self, hi_val: MyArray) -> list[jax.Array]: + return [hi_val.arr] + def raise_val(self, val) -> MyArray: + return MyArray(val) + + def __eq__(self, other): return isinstance(other, MyTy) + + def vspace_zero(self): + return MyArray(jnp.zeros((), 'float32')) + def vspace_add(self, x, y): + return add(x, y) + core.pytype_aval_mappings[MyArray] = lambda _: MyTy() + dtypes.canonicalize_value_handlers[MyArray] = lambda x: x + + class ToMy(HiPrimitive): + def is_high(self, _): return True + + def abstract_eval(_, lo_aval): + return MyTy(), set() + + def to_lojax(_, lo): + return MyArray(lo) + + def jvp(_, primals, tangents): + x, x_dot = *primals, *tangents + return to(x), to(x_dot) + + def transpose(self, out_bar, _): + return from_(out_bar), + + class FromMy(HiPrimitive): + def is_high(self, _): return True + + def abstract_eval(_, hi_aval): + return hi_aval.lo_ty()[0], set() + + def to_lojax(_, hi): + return hi.arr + + def jvp(_, primals, tangents): + x, x_dot = *primals, *tangents + return from_(x), from_(x_dot) + + def transpose(self, out_bar, _): + return to(out_bar), + + def to(x): return to_p.bind(x) + to_p = ToMy('to_my') + + def from_(x): return from_p.bind(x) + from_p = FromMy('from_my') + + def mul(x, y): return mul_p.bind(x, y) + def add(x, y): return add_p.bind(x, y) + + class MyMul(HiPrimitive): + def is_high(self, *_): return True + + def abstract_eval(_, hi_x, hi_y): + if hi_x != hi_y: raise Exception + return hi_x, set() + + def to_lojax(_, hi_x, hi_y): + return MyArray(hi_x.arr * hi_y.arr) + + def jvp(_, primals, tangents): + (x, y), (x_dot, y_dot) = primals, tangents + return mul(x, y), add(mul(x, y_dot), mul(x_dot, y)) + + def transpose(self, out_bar, x, y): + assert ad.is_undefined_primal(x) ^ ad.is_undefined_primal(y) + if ad.is_undefined_primal(x): + return mul(out_bar, y), None + else: + return None, mul(x, out_bar) + + class MyAdd(HiPrimitive): + def is_high(self, *_): return True + + def abstract_eval(_, hi_x, hi_y): + if hi_x != hi_y: raise Exception + return hi_x, set() + + def to_lojax(_, hi_x, hi_y): + return MyArray(hi_x.arr + hi_y.arr) + + def jvp(_, primals, tangents): + assert False # TODO + + def transpose(self, out_bar, x, y): + return out_bar, out_bar + + mul_p = MyMul('my_mul') + add_p = MyAdd('my_add') + + + @jax.jit + def f(x): + return to(from_(x)) + + # test basic to/from jit + a = MyArray(jnp.ones(())) + b = f(a) # don't crash + self.assertIsInstance(b, MyArray) + self.assertAllClose(b.arr, jnp.ones(())) + + # test basic to/from autodiff + b, b_dot = jax.jvp(f, (a,), (a,)) + self.assertIsInstance(b, MyArray) + self.assertIsInstance(b_dot, MyArray) + + # test mul jit and backward pass + + @jax.jit + def f(x): + return mul(x, x) + + b, f_vjp = jax.vjp(f, a) + self.assertIn('MyTy', str(f_vjp)) + a_grad, = f_vjp(b) + self.assertIsInstance(a_grad, MyArray) + self.assertAllClose(a_grad.arr, 2.0, check_dtypes=False) + + def test_stages(self): + @dataclass(frozen=True) + class ArrayTuple: + x0: jax.Array + x1: jax.Array + + @dataclass(frozen=True) + class ShapedArrayTuple(HiType): + x0: ShapedArray + x1: ShapedArray + # sharding=None + + # how to lower to (lo)jax types + def lo_ty(self) -> list[ShapedArray]: + return [self.x0, self.x1] + + # these next two are essentially the pytree interface + def lower_val(self, hi_val: ArrayTuple) -> list[jax.Array]: + return [hi_val.x0, hi_val.x1] + def raise_val(self, x0, x1) -> ArrayTuple: + return ArrayTuple(x0, x1) + + register_hitype(ArrayTuple, lambda q: ShapedArrayTuple( + jax.typeof(q.x0), jax.typeof(q.x1))) + + q = ArrayTuple(jnp.zeros((4, 4), 'int8'), jnp.ones(4, 'float32')) + jax.jit(lambda x: x).lower(q).as_text() # don't crash + + compiled = jax.jit(lambda x: x).lower(q).compile() + compiled(q) # don't crash + + @parameterized.parameters([False, True]) + def test_while_loop(self, jit): + q = to_qarray(jnp.ones((2, 2), 'float32')) + + def f(q1, q2): + def cond_fun(i_carry): + i, _, __ = i_carry + return i < 1 + def body_fun(i_carry): + i, q_carry, _ = i_carry + q_carry = to_qarray(from_qarray(q_carry)) + return i + 1, q_carry, q + n, q_out, _ = jax.lax.while_loop(cond_fun, body_fun, (0, q1, q2)) + return n, q_out + + if jit: + f = jax.jit(f) + + jax.make_jaxpr(f)(q, q) # doesn't crash + n, q_out = f(q, q) + self.assertEqual(n, 1) + expected = from_qarray(to_qarray(from_qarray(q))) + self.assertAllClose(from_qarray(q_out), expected, check_dtypes=False) + + @parameterized.parameters([False, True]) + def test_tuple_basic(self, jit): + def f(): + tup = make_tup(1, 2) + return get_tuple_element(tup, 1) + + if jit: + f = jax.jit(f) + + self.assertEqual(f(), 2) + + @parameterized.parameters([False, True]) + def test_ref_to_tuple(self, jit): + def f(): + tup = make_tup(1, 2) + ref = jax.new_ref(tup) + tup_ = ref[...] + return get_tuple_element(tup_, 1) + + if jit: + f = jax.jit(f) + + self.assertEqual(f(), 2) + + @parameterized.parameters([False, True]) + def test_run_state(self, jit): + def f(): + @run_state + def g(ref_args): + tup_ref, x_ref = ref_args + tup = tup_ref[...] + x_ref[...] = get_tuple_element(tup, 1) + + tup = make_tup(1, 2) + _, ans = g((tup, 3)) + return ans + + if jit: + f = jax.jit(f) + + ans = f() + self.assertEqual(ans, 2) + + @parameterized.parameters([False, True]) + def test_newstyle_hiprimitive(self, jit): + + class RaiseToStaticPower(VJPHiPrimitive): + def __init__(self, in_aval, *, power): + self.in_avals = (in_aval,) + self.out_aval = in_aval + self.params = dict(power=power) + super().__init__() + + def expand(self, x): + return x ** self.power + + def vjp_fwd(self, nzs_in, x): + ans = self(x) + return (ans, x) + + def vjp_bwd(self, res, t, xbar_accum): + xbar = t * self.power * raise_to_static_power(res, self.power-1) + xbar_accum.accum(xbar) + + def batch(self, _axis_data, args, in_dims): + in_dim, = in_dims + x, = args + return raise_to_static_power(x, self.power), in_dim + + def jvp(self, primals, tangents): + (x,), (t,) = primals, tangents + return self(x), t * self.power * raise_to_static_power(x, self.power-1) + + def raise_to_static_power(x, power): + x_aval = jax.typeof(x) + return RaiseToStaticPower(x_aval, power=power)(x) + + def f(x): + return raise_to_static_power(x, power=3) + + if jit: + f = jax.jit(f) + self.assertEqual(f.lower(2.0).compile()(2.0), 8.0) + + self.assertEqual(f(2.0), 8.0) + xs = jnp.arange(3.0) + self.assertAllClose(jax.vmap(f)(xs), xs**3) + self.assertEqual(jax.grad(f)(2.0), 12.0) + self.assertEqual(jax.jvp(f, (2.0,), (1.0,)), + (8.0, 12.0)) + + @parameterized.parameters([False, True]) + def test_newstyle_hiprimitive_retval(self, jit): + + class RaiseToStaticPower(VJPHiPrimitive): + def __init__(self, in_aval, *, power): + self.in_avals = (in_aval,) + self.out_aval = in_aval + self.params = dict(power=power) + super().__init__() + + def expand(self, x): + return x ** self.power + + def vjp_fwd(self, nzs_in, x): + ans = self(x) + return (ans, x) + + def vjp_bwd_retval(self, res, t): + return (t * self.power * raise_to_static_power(res, self.power-1),) + + def batch(self, _axis_data, args, in_dims): + in_dim, = in_dims + x, = args + return raise_to_static_power(x, self.power), in_dim + + def raise_to_static_power(x, power): + x_aval = jax.typeof(x) + return RaiseToStaticPower(x_aval, power=power)(x) + + def f(x): + return raise_to_static_power(x, power=3) + + if jit: + f = jax.jit(f) + + self.assertEqual(f(2.0), 8.0) + xs = jnp.arange(3.0) + self.assertAllClose(jax.vmap(f)(xs), xs**3) + self.assertEqual(jax.grad(f)(2.0), 12.0) + + def test_newstyle_hiprimitive_defines_both_types_of_vjp_error(self): + class RaiseToStaticPower(VJPHiPrimitive): + def __init__(self, in_aval, *, power): + self.in_avals = (in_aval,) + self.out_aval = in_aval + self.params = dict(power=power) + super().__init__() + + def expand(self, x): + return x ** self.power + + def vjp_fwd(self, x): + ans = self(x) + return (ans, x) + + def vjp_bwd(self, res, t, xbar_accum): + xbar = t * self.power * raise_to_static_power(res, self.power-1) + xbar_accum.accum(xbar) + + def vjp_bwd_retval(self, res, t): + return (t * self.power * raise_to_static_power(res, self.power-1),) + + def batch(self, _axis_data, args, in_dims): + in_dim, = in_dims + x, = args + return raise_to_static_power(x, self.power), in_dim + + def raise_to_static_power(x, power): + x_aval = jax.typeof(x) + return RaiseToStaticPower(x_aval, power=power)(x) + + def f(x): + return raise_to_static_power(x, power=3) + + with self.assertRaises(AttributeError): + f(2.0) + + @config.numpy_dtype_promotion('standard') + def test_newstyle_hiprimitive_qarray(self): + + @dataclass(frozen=True) # not NamedTuple, which is a pytree + class QArray: + qvalue: jax.Array + scale: jax.Array + + @dataclass(frozen=True) + class QArrayTy(HiType): + shape: tuple[int, int] + + def to_tangent_aval(self): + return ShapedArray(self.shape, jnp.dtype('float32')) + + register_hitype(QArray, lambda q: QArrayTy(q.qvalue.shape)) + + def q(x): + return Q(jax.typeof(x))(x) + + def dq(qx): + return DQ(jax.typeof(qx))(qx) + + class Q(VJPHiPrimitive): + def __init__(self, unquantized_aval): + if unquantized_aval.dtype != jnp.dtype('float32'): raise TypeError + quantized_aval = QArrayTy(unquantized_aval.shape) + self.in_avals = (unquantized_aval,) + self.out_aval = quantized_aval + self.params = {} + super().__init__() + + def expand(self, x): + scale = jnp.max(jnp.abs(x)) / 127 + qvalue = jnp.round(x / scale).astype(jnp.int8) + return QArray(qvalue, scale) + + def vjp_fwd(self, nzs_in, x): + return self(x), None + + def vjp_bwd_retval(self, _, g): + return g, + + class DQ(VJPHiPrimitive): + def __init__(self, quantized_aval): + unquantized_aval = ShapedArray(quantized_aval.shape, jnp.dtype('float32')) + self.in_avals = (quantized_aval,) + self.out_aval = unquantized_aval + self.params = {} + super().__init__() + + def expand(self, qx): + return qx.qvalue * qx.scale + + def vjp_fwd(self, nzs_in, qx): + return self(qx), None + + def vjp_bwd_retval(self, _, g): + return g, + + def f(x): + return jnp.sum(dq(q(x))) + + x = jax.random.normal(jax.random.key(0), (3, 3), dtype='float32') + g = jax.grad(f)(x) + + def test_symbolic_zeros(self): + + class Mul(VJPHiPrimitive): + def __init__(self, aval): + self.in_avals = (aval, aval) + self.out_aval = aval + self.params = {} + super().__init__() + + def expand(self, x, y): + return x * y + + def vjp_fwd(self, nzs_in, x, y): + assert list(nzs_in) == list(nzs_in_) # defined below + ans = self(x, y) + return ans, (x, y) + + def vjp_bwd(self, res, g, x_acc, y_acc): + assert list(nzs_in_) == [not isinstance(x_acc, ad.NullAccum), + not isinstance(y_acc, ad.NullAccum)] + x, y = res + x_acc.accum(g * y) + y_acc.accum(x * g) + + def mul(x, y): + return Mul(typeof(x))(x, y) + + nzs_in_ = (True, False) + self.assertAllClose(jax.grad(mul)(2., 3.), 3., check_dtypes=False) + + nzs_in_ = (False, True) + self.assertAllClose(jax.grad(mul, 1)(2., 3.), 2., check_dtypes=False) + + def test_symbolic_zeros_retval(self): + + class Mul(VJPHiPrimitive): + def __init__(self, aval): + self.in_avals = (aval, aval) + self.out_aval = aval + self.params = {} + super().__init__() + + def expand(self, x, y): + return x * y + + def vjp_fwd(self, nzs_in, x, y): + assert list(nzs_in) == list(nzs_in_) # defined below + ans = self(x, y) + return ans, (x, y) + + def vjp_bwd_retval(self, res, g): + x, y = res + return (g * y, x * g) + + def mul(x, y): + return Mul(typeof(x))(x, y) + + nzs_in_ = (True, False) + self.assertAllClose(jax.grad(mul)(2., 3.), 3., check_dtypes=False) + + nzs_in_ = (False, True) + self.assertAllClose(jax.grad(mul, 1)(2., 3.), 2., check_dtypes=False) + + @jtu.with_explicit_mesh((2,), ('data',)) + def test_hijax_primitive_under_shard_map(self, mesh): + class Square(VJPHiPrimitive): + def __init__(self, in_aval): + self.in_avals = (in_aval,) + self.out_aval = in_aval + self.params = {} + super().__init__() + + def expand(self, x): + return x ** 2 + + def square(x): + return Square(jax.typeof(x))(x) + g = jax.shard_map(square, in_specs=(jax.P('data'),), out_specs=jax.P('data')) + x = jnp.arange(10) + g(x) + jax.jit(g)(x) + + +class BoxTest(jtu.JaxTestCase): + + @parameterized.parameters([False, True]) + def test_qdd(self, jit): + + val1 = 1.0 + val2 = jnp.arange(3) + + box1 = Box(val1) + + def f(box2): + assert core.cur_qdd(box2).leaf_avals == (core.typeof(val1),) + box2.set(val2) + assert core.cur_qdd(box2).leaf_avals == (core.typeof(val2),) + + box3 = new_box() + box3.set(val2) + assert core.cur_qdd(box3).leaf_avals == (core.typeof(val2),) + box3.set(val1) + assert core.cur_qdd(box3).leaf_avals == (core.typeof(val1),) + + assert core.cur_qdd(box1).leaf_avals == (core.typeof(val1),) + box1.set(val2) + assert core.cur_qdd(box1).leaf_avals == (core.typeof(val2),) + + return + + if jit: + f = jax.jit(f) + + f(Box(val1)) + + def test_jit_internal(self): + @jax.jit + def f(x): + box = new_box() # TODO not Box + box.set(x) + box.set(box.get() + box.get()) + return box.get() + + f(1) + + def test_jit_internal_box_constructor(self): + @jax.jit + def f(x): + box = Box(x) + box.set(box.get() + box.get()) + return box.get() + + f(1) + + @parameterized.parameters([False, True]) + def test_isinstance(self, jit): + def f(): + box = Box() + self.assertIsInstance(box, Box) + if jit: + f = jax.jit(f) + f() + + def test_jit_arg(self): + @jax.jit + def f(box, x): + assert tracing_ok + box.set(box.get() + x) + + tracing_ok = True + box1 = Box(1.0) + f(box1, 1.) + self.assertAllClose(box1.get(), 2.0) + + tracing_ok = False + box2 = Box(2.0) + f(box2, 2.) + self.assertAllClose(box2.get(), 4.0) + + def test_jit_arg2(self): + # set without get + + @jax.jit + def f(box, x): + box_set(box, x) + + box = Box(0.0) + f(box, 1.) + self.assertAllClose(box_get(box), 1.0, check_dtypes=False) + + def test_jit_arg_in_pytree(self): + @jax.jit + def f(dct, x): + assert tracing_ok + box = dct['box'] + box.set(box.get() + x) + + tracing_ok = True + box1 = Box(1.0) + f({'box': box1, 'a': 1.0}, 1.) + self.assertAllClose(box1.get(), 2.0) + + tracing_ok = False + box2 = Box(2.0) + f({'box': box2, 'a': 2.0}, 2.) + self.assertAllClose(box2.get(), 4.0) + + tracing_ok = True + box3 = Box(3) # int, dtype changed + f({'box': box3, 'a': 2.0}, 2.) + self.assertAllClose(box3.get(), 5.0) + + def test_jit_closure(self): + box = Box(1.0) + + @jax.jit + def f(x): + assert tracing_ok + box.set(box.get() + x) + + tracing_ok = True + f(2.0) + self.assertAllClose(box.get(), 3.0) + tracing_ok = False + f(5.0) + self.assertAllClose(box.get(), 8.0) + + def test_jit_closure_nested(self): + box = Box(5.0) + + @jax.jit + def f(x): + box.set(box.get() + x) + + @jax.jit + def g(x): + f(x) + + g(3.0) + self.assertAllClose(box.get(), 8.0) + + def test_jit_closure_nested2(self): + @jax.jit + def h(x): + box = new_box() + box.set(x) + + @jax.jit + def k(x): + box.set(box.get() + x) + + k(1.0) + k(1.0) + return box.get() + + ans = h(2.0) + self.assertAllClose(ans, 4.0) + + def test_jit_closure_nested3(self): + box = new_box() + + @jax.jit + def h(x): + box.set(x) + + @jax.jit + def k(x): + box.set(box.get() + x) + + k(1.0) + k(1.0) + return box.get() + + ans = h(2.0) + self.assertAllClose(ans, 4.0) + + @parameterized.parameters([False, True]) + def test_jvp_closure_stop_gradient(self, jit): + box = Box(1.0) + + def f(x): + y = 2 * x + box.set(box.get() + jax.lax.stop_gradient(y)) + return y + + if jit: + f = jax.jit(f) + + y, y_dot = jax.jvp(f, (1.0,), (1.0,)) + self.assertAllClose(y, 2.0) + self.assertAllClose(y_dot, 2.0) + self.assertAllClose(box.get(), 3.0) + + @parameterized.parameters([False, True]) + def test_jvp_arg(self, jit): + def f(box, x): + box.set(box.get() + x) + return x + + if jit: + f = jax.jit(f) + + box = Box(5.0) + box_dot = Box(1.0) + y, y_dot = jax.jvp(f, (box, 2.), (box_dot, 1.)) + self.assertAllClose(y, 2.0) + self.assertAllClose(y_dot, 1.0) + self.assertAllClose(box.get(), 7.0) + self.assertAllClose(box_dot.get(), 2.0) + + @parameterized.parameters([False, True]) + def test_custom_vjp_plumbing(self, jit): + box = Box(0.0) + + @jax.custom_vjp + def foo(x): + return x + def foo_fwd(x): + return foo(x), None + def foo_bwd(_, g): + box.set(g) + return g, + foo.defvjp(foo_fwd, foo_bwd) + + def f(x): + x = 2 * x + x = foo(x) + x = 2 * x + return x + + if jit: + f = jax.jit(f) + + jax.grad(f)(1.0) + + self.assertAllClose(box.get(), 2.0) + + @parameterized.parameters([False, True]) + def test_custom_vjp_plumbing_abstracted(self, jit): + box = Box(0.0) + + @jax.custom_vjp + def foo(box, x): + return x + def foo_fwd(box, x): + return x, box + def foo_bwd(box, g): + box.set(g) + return None, g + foo.defvjp(foo_fwd, foo_bwd) + + def f(box, x): + x = 2 * x + x = foo(box, x) + x = 2 * x + return x + + if jit: + f = jax.jit(f) + + jax.grad(partial(f, box))(1.0) + self.assertAllClose(box.get(), 2.0) + + @parameterized.parameters([False, True]) + def test_custom_vjp_primal(self, jit): + box = Box(0.0) + + @custom_vjp3 + def foo(box, x): + box.set(x) + return x + def foo_fwd(box, x): + assert False # doesn't run + def foo_bwd(box, g): + assert False # doesn't run + foo.defvjp(foo_fwd, foo_bwd) + + def f(box, x): + x = 2 * x + x = foo(box, x) + x = 2 * x + return x + + if jit: + f = jax.jit(f) + + f(box, 1.0) + self.assertAllClose(box.get(), 2.0) + + @parameterized.parameters([False, True]) + def test_grad_closure_stop_gradient(self, jit): + box = Box(0.0) + + def f(x): + y = x * 2 + box.set(box.get() + jax.lax.stop_gradient(y)) + return y + + if jit: + f = jax.jit(f) + + g = jax.grad(f)(1.0) + self.assertAllClose(g, 2.0) + self.assertAllClose(box.get(), 2.0) + + @parameterized.parameters([False, True]) + def test_scan_basic(self, jit): + box = Box(1.0) + + def double_it_10(): + def body(_, __): + box.set(box.get() * 2) + return None, None + _, _ = jax.lax.scan(body, None, None, length=10) + + if jit: + double_it_10 = jax.jit(double_it_10) + + double_it_10() + + self.assertAllClose(box.get(), 1024., check_dtypes=False) + + def test_cond_box_internally_pure(self): + @jax.jit + def doubleit(x): + b = new_box() + b.set(x) + b.set(b.get() + b.get()) + return b.get() + + def identity(x): return x + + @jax.jit + def f(x): + return jax.lax.cond(x > 0, doubleit, identity, x) + + self.assertAllClose(f(1.0), 2.0) + + def test_cond_box_arg(self): + @jax.jit + def f(x): + b = new_box() + b.set(x) + jax.lax.cond(x > 0, lambda box: box.set(box.get() + 1), lambda _: None, b) + return b.get() + + self.assertAllClose(f(1.0), 2.0) + + def test_cond_closed_over_box(self): + # TODO: good error messages in the case that qdd changes differently in each branch + def f(x): + b = new_box() + b.set(1.0) + jax.lax.cond(x > 0., lambda _: b.set(b.get() + 1.0), lambda _: None, 1.0) + return b.get() + + self.assertAllClose(f(1.0), 2.0) + + + # TODO error-checking tests from attrs_test.py + + ### + + def test_box_autodiff(self): + if config.enable_x64.value: raise unittest.SkipTest("no x64") + + class StashTangents(HiPrimitive): + def is_high(self, *_): + return True + + def abstract_eval(_, box_aval, x_aval): + del box_aval + return x_aval, {box_effect} + + def to_lojax(_, box, x): + return x + + def jvp(_, primals, tangents): + box, x = primals + _, x_dot = tangents + box_set(box, x_dot) + return x, x_dot + + def transpose(self, *args): + assert False # TODO + stash_tangents_p = StashTangents('stash_tangents') + + def stash_tangents(box, x): + return stash_tangents_p.bind(box, x) + + @jax.jit + def f(box, x): + x = stash_tangents(box, x) + return x + + box = Box(0.0) + jax.jvp(partial(f, box), (3.,), (5.,)) + self.assertAllClose(box_get(box), 5.0, check_dtypes=False) + + def test_type_changing_box(self): + box = Box(jnp.arange(1)) + box_set(box, jnp.arange(2)) + self.assertLen(box._val, 2) + + @jax.jit + def f(box, x): + box_set(box, x) + + f(box, jnp.arange(3)) + self.assertLen(box._val, 3) + f(box, jnp.arange(4)) + self.assertLen(box._val, 4) + + def test_pytree_box(self): + box = Box(None) + + @jax.jit + def f(box, x): + assert tracing_ok + val = box_get(box) + if val is None: + box_set(box, x) + else: + box_set(box, [x, x]) + + tracing_ok = True + f(box, 1.0) + self.assertAllClose(box_get(box), 1.0, check_dtypes=False) + f(box, 2.0) + self.assertAllClose(box_get(box), [2.0, 2.0], check_dtypes=False) + f(box, 3.0) + self.assertAllClose(box_get(box), [3.0, 3.0], check_dtypes=False) + tracing_ok = False + f(box, 4.0) + self.assertAllClose(box_get(box), [4.0, 4.0], check_dtypes=False) + + def test_pytree_of_hijaxtypes_box(self): + + @dataclass(frozen=True) + class MyArray: + arr: jax.Array # always f32 + + @dataclass(frozen=True) + class MyTy(HiType): + has_qdd = False + + def to_tangent_aval(self): + return MyTy() + def str_short(self, short_dtypes=False): + return 'MyTy' + def lo_ty(self): + return [core.ShapedArray((), jnp.dtype('float32'))] + def lower_val(self, hi_val: MyArray) -> list[jax.Array]: + return [hi_val.arr] + def raise_val(self, val) -> MyArray: + return MyArray(val) + + def __eq__(self, other): return isinstance(other, MyTy) + + core.pytype_aval_mappings[MyArray] = lambda _: MyTy() + + box = Box([MyArray(jnp.float32(1)), + MyArray(jnp.float32(2))]) + + @jax.jit + def f(box): + a, b = box_get(box) + box_set(box, [b, a]) + + f(box) + val = box_get(box) + self.assertIsInstance(val, list) + self.assertLen(val, 2) + b_, a_ = val + self.assertIsInstance(a_, MyArray) + self.assertIsInstance(b_, MyArray) + self.assertAllClose(a_.arr, 1, check_dtypes=False) + self.assertAllClose(b_.arr, 2, check_dtypes=False) + + def test_closed_over_type_changing_box(self): + + box = Box(None) + box2 = Box(None) + + @jax.jit + def f(): + assert tracing_ok + x = box.get() + if x is None: + box.set(0) + elif type(x) is dict: + box.set(dict(x, a=5)) + box2.set(3) + else: + box.set(x + 1) + + tracing_ok = True + f() # tracing okay because first time + f() # tracing okay because first time with box as not None + tracing_ok = False + f() + self.assertEqual(box.get(), 2) + self.assertEqual(box2.get(), None) + box.set(None) + f() + f() + f() + f() + self.assertEqual(box.get(), 3) + self.assertEqual(box2.get(), None) + box.set({'b': 3}) + tracing_ok = True + f() + self.assertEqual(box.get(), dict(a=5, b=3)) + self.assertEqual(box2.get(), 3) + + @parameterized.parameters([False, True]) + def test_while_loop(self, jit): + box = Box(1.) + + def f(): + zero = jnp.zeros((), 'int32') + + def cond_fun(i): + return i + zero < 5 + def body_fun(i): + box.set(box.get() * 2.) + return i + 1 + _ = jax.lax.while_loop(cond_fun, body_fun, 0) + + if jit: + f = jax.jit(f) + + f() + self.assertAllClose(box.get(), 32, check_dtypes=False) + + def test_while_loop_typechange_error(self): + box = Box([1.]) + def cond_fun(i): + return i < 5 + def body_fun(i): + box.set(box.get() * 2) + return i + 1 + with self.assertRaisesRegex(TypeError, "type-changing mutations not allowed"): + _ = jax.lax.while_loop(cond_fun, body_fun, 0) + + def test_eval_shape(self): + qarray = QArray(jnp.ones((2, 2)), jnp.ones(2)) + + @jax.jit + def f(): + return qarray + + out_type = jax.eval_shape(f) + self.assertEqual(out_type, QArrayTy((2, 2))) + + def test_stages_mutable(self): + box = Box(1.0) + + @jax.jit + def f(box): + box.set(box.get() + 1.) + + f.lower(box).as_text() # don't crash + compiled = f.lower(box).compile() + compiled(box) + compiled(box) + compiled(box) + self.assertAllClose(box.get(), 4.) + + +class RefTest(jtu.JaxTestCase): + + def test_get_ref_hitype(self): + + @jax.jit + def f(q): + ref = jax.new_ref(q) + return ref[:, 0:2] + + qarray = QArray(jnp.ones((2, 2), dtype='int8'), jnp.ones(2, 'float32')) + o = f(qarray) + self.assertArraysEqual(o.arr, qarray.arr) + self.assertArraysEqual(o.scale, qarray.scale) + + def test_swap_ref_hitype(self): + + @jax.jit + def f(q1, q2): + ref = jax.new_ref(q1) + ref[:, :] = q2 + return ref.get() + + q1 = QArray(jnp.zeros((2, 2), dtype='int8'), jnp.zeros(2, 'float32')) + q2 = QArray(jnp.ones((2, 2), dtype='int8'), jnp.ones(2, 'float32')) + o = f(q1, q2) + self.assertArraysEqual(o.arr, q2.arr) + self.assertArraysEqual(o.scale, q2.scale) + +class HijaxTransformCoverageTest(jtu.JaxTestCase): + # ------------ + # grad + # ------------ + # with differentiable hijax arguments + def test_hitypes_as_grad_args(self): + box = immutbox_new((jnp.array(2.0), jnp.array(3.0))) + + def loss_fn(tup): + x = immutbox_get(tup)[0] + return x ** 2 + + grads = jax.grad(loss_fn)(box) + self.assertAllClose(immutbox_get(grads)[0], 4.0) + + # with non-differentiable hijax arguments + def test_hitypes_as_nondiff_grad_args(self): + box = immutbox_new((jnp.array(2.0), jnp.array(3.0))) + x = jnp.array(3.0) + + def loss_fn(x, box): + y = immutbox_get(box)[1] + return x ** 2 + y + + grad = jax.grad(loss_fn)(x, box) + self.assertAllClose(grad, 6.0, check_dtypes=False) + + # with hijax captured arguments + def test_hitypes_as_captured_args(self): + box = immutbox_new((jnp.array(2.0), jnp.array(3.0))) + + def loss_fn(x): + y = immutbox_get(box)[1] + return x ** 2 + y + + grad = jax.grad(loss_fn)(jnp.array(4.0)) + self.assertAllClose(grad, 8.0, check_dtypes=False) + + # with differentiable mutable hijax arguments + @absltest.skip("Not yet implemented") + def test_mutable_hitypes_as_grad_args(self): + box = Box(jnp.array(2.0)) + + def loss_fn(box): + return box.get() ** 2 + + grads = jax.grad(loss_fn)(box) + # NOTE: unclear what the tangent type will be here + + # with non-differentiable mutable hijax arguments + def test_mutable_hitypes_as_nondiff_grad_args(self): + box = Box(jnp.array(2.0)) + x = jnp.array(3.0) + + def loss_fn(x, box): + box.set(jax.lax.stop_gradient(x * 2)) + return x ** 2 + box.get() + + grad = jax.grad(loss_fn)(x, box) + self.assertAllClose(box.get(), 6.0, check_dtypes=False) + self.assertAllClose(grad, 6.0, check_dtypes=False) + + # with mutable hijax captured arguments + def test_mutable_hitypes_as_captured_args(self): + box = Box(jnp.array(2.0)) + + def loss_fn(x): + box.set(jax.lax.stop_gradient(x * 3)) + return x ** 2 + box.get() + + grad = jax.grad(loss_fn)(jnp.array(4.0)) + self.assertAllClose(box.get(), 12.0, check_dtypes=False) + self.assertAllClose(grad, 8.0, check_dtypes=False) + + #------------ + # scan + #------------ + # with hijax carry arguments + def test_hitypes_as_scan_carry(self): + box = immutbox_new((jnp.array(1.0), jnp.array(2.0))) + + def body(box, _): + x, y = immutbox_get(box) + return immutbox_new((x + 1.0, y + 2.0)), None + + box, _ = jax.lax.scan(body, box, None, length=5) + x, y = immutbox_get(box) + self.assertAllClose(x, 6.0, check_dtypes=False) + self.assertAllClose(y, 12.0, check_dtypes=False) + + # with hijax extensive arguments + def test_hitypes_as_scan_extensive(self): + box = immutbox_new((jnp.arange(5), -jnp.arange(5))) + + def body(_, box_i): + x, y = immutbox_get(box_i) + box_i = immutbox_new((x * 2, y * 2)) + return None, box_i + _, box = jax.lax.scan(body, None, box) + x, y = immutbox_get(box) + self.assertAllClose(x, jnp.arange(5) * 2, check_dtypes=False) + self.assertAllClose(y, -jnp.arange(5) * 2, check_dtypes=False) + + # with hijax captured arguments + def test_hitypes_as_scan_captured(self): + box = immutbox_new((jnp.array(3.0), jnp.array(4.0))) + carry0 = jnp.array(1.0) + xs = jnp.arange(5, dtype=jnp.float32) + + def body(carry, x): + a, b = immutbox_get(box) + carry = a * carry + b + y = a * x + b + return carry, immutbox_new(y) + + carry, ys_box = jax.lax.scan(body, carry0, xs) + ys = immutbox_get(ys_box) + self.assertAllClose(carry, 727.0, check_dtypes=False) + self.assertAllClose(ys, 3.0 * xs + 4.0, check_dtypes=False) + + # with mutable hijax carry arguments + @absltest.skip("has_qdd not yet supported for Box in scan carry") + def test_mutable_hitypes_as_scan_carry(self): + box = Box(jnp.array(1.0)) + + def body(box, _): + box.set(box.get() * 2) + return box, None + + box, _ = jax.lax.scan(body, box, None, length=5) + self.assertAllClose(box.get(), 32.0, check_dtypes=False) + + # with mutable hijax extensive arguments + @absltest.skip("Box doesn't have shape attribute needed for scan extensive") + def test_mutable_hitypes_as_scan_extensive(self): + boxes = [Box(jnp.float32(i)) for i in range(5)] + + def body(_, box_i): + val = box_i.get() + box_i.set(val * 2) + return None, box_i + + _, boxes_out = jax.lax.scan(body, None, boxes) + for i, box in enumerate(boxes_out): + self.assertAllClose(box.get(), i * 2, check_dtypes=False) + + # with mutable hijax captured arguments + def test_mutable_hitypes_as_scan_captured(self): + box = Box(jnp.array(3.0)) + + def body(_, __): + box.set(box.get() + 1.0) + return None, None + + jax.lax.scan(body, None, None, length=5) + self.assertAllClose(box.get(), 8.0, check_dtypes=False) + +if __name__ == '__main__': + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/image_test.py b/tests/image_test.py index 0f6341086d19..4d13b7dab15a 100644 --- a/tests/image_test.py +++ b/tests/image_test.py @@ -52,6 +52,7 @@ class ImageTest(jtu.JaxTestCase): antialias=[False, True], ) @unittest.skipIf(not tf, "Test requires TensorFlow") + @jtu.thread_unsafe_test() # TensorFlow isn't thread-safe without the GIL. def testResizeAgainstTensorFlow(self, dtype, image_shape, target_shape, method, antialias): # TODO(phawkins): debug this. There is a small mismatch between TF and JAX diff --git a/tests/infeed_test.py b/tests/infeed_test.py deleted file mode 100644 index 060502ae68cd..000000000000 --- a/tests/infeed_test.py +++ /dev/null @@ -1,129 +0,0 @@ -# Copyright 2019 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import threading -from unittest import SkipTest - -from absl.testing import absltest -import jax -from jax import lax, numpy as jnp -from jax._src import core -from jax._src import xla_bridge -from jax._src.lib import xla_client -import jax._src.test_util as jtu -import numpy as np - -jax.config.parse_flags_with_absl() - - -@jtu.thread_unsafe_test_class() # infeed isn't thread-safe -class InfeedTest(jtu.JaxTestCase): - - def setUp(self): - if xla_bridge.using_pjrt_c_api(): - raise SkipTest("infeed not implemented in PJRT C API") - super().setUp() - - @jax.numpy_rank_promotion("allow") # Test explicitly exercises implicit rank promotion. - def testInfeed(self): - raise SkipTest("skipping temporarily for stackless") - - @jax.jit - def f(x): - token = lax.create_token(x) - (y,), token = lax.infeed( - token, shape=(core.ShapedArray((3, 4), jnp.float32),)) - (z,), _ = lax.infeed( - token, shape=(core.ShapedArray((3, 1, 1), jnp.float32),)) - return x + y + z - - x = np.float32(1.5) - y = np.reshape(np.arange(12, dtype=np.float32), (3, 4)) # self.rng().randn(3, 4).astype(np.float32) - z = self.rng().randn(3, 1, 1).astype(np.float32) - device = jax.local_devices()[0] - device.transfer_to_infeed((y,)) - device.transfer_to_infeed((z,)) - self.assertAllClose(f(x), x + y + z) - - def testInfeedPytree(self): - raise SkipTest("skipping temporarily for stackless") - - x = np.float32(1.5) - y = np.reshape(np.arange(12, dtype=np.int16), (3, 4)) - to_infeed = dict(a=x, b=y) - to_infeed_shape = dict(a=core.ShapedArray((), dtype=np.float32), - b=core.ShapedArray((3, 4), dtype=np.int16)) - @jax.jit - def f(x): - token = lax.create_token(x) - res, token = lax.infeed(token, shape=to_infeed_shape) - return res - - device = jax.local_devices()[0] - # We must transfer the flattened data, as a tuple!!! - flat_to_infeed, _ = jax.tree.flatten(to_infeed) - device.transfer_to_infeed(tuple(flat_to_infeed)) - self.assertAllClose(f(x), to_infeed) - - @jax.numpy_rank_promotion("allow") # Test explicitly exercises implicit rank promotion. - def testInfeedThenOutfeed(self): - - @jax.jit - def f(x): - token = lax.create_token(x) - y, token = lax.infeed( - token, shape=core.ShapedArray((3, 4), jnp.float32)) - token = lax.outfeed(token, y + np.float32(1)) - return x - 1 - - x = np.float32(7.5) - y = self.rng().randn(3, 4).astype(np.float32) - execution = threading.Thread(target=lambda: f(x)) - execution.start() - device = jax.local_devices()[0] - device.transfer_to_infeed((y,)) - out, = device.transfer_from_outfeed( - xla_client.shape_from_pyval((y,)).with_major_to_minor_layout_if_absent()) - execution.join() - self.assertAllClose(out, y + np.float32(1)) - - def testInfeedThenOutfeedInALoop(self): - - def doubler(_, token): - y, token = lax.infeed( - token, shape=core.ShapedArray((3, 4), jnp.float32)) - return lax.outfeed(token, y * np.float32(2)) - - @jax.jit - def f(n): - token = lax.create_token(n) - token = lax.fori_loop(0, n, doubler, token) - return n - - device = jax.local_devices()[0] - n = 10 - execution = threading.Thread(target=lambda: f(n)) - execution.start() - for _ in range(n): - x = self.rng().randn(3, 4).astype(np.float32) - device.transfer_to_infeed((x,)) - y, = device.transfer_from_outfeed(xla_client.shape_from_pyval((x,)) - .with_major_to_minor_layout_if_absent()) - self.assertAllClose(y, x * np.float32(2)) - execution.join() - - -if __name__ == '__main__': - absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/jax_jit_test.py b/tests/jax_jit_test.py index 5946d557d4ba..e52d1bcfa7ea 100644 --- a/tests/jax_jit_test.py +++ b/tests/jax_jit_test.py @@ -17,30 +17,27 @@ from absl.testing import absltest from absl.testing import parameterized import jax -from jax import dtypes from jax import numpy as jnp from jax._src import config from jax._src import core +from jax._src import dtypes from jax._src import lib as jaxlib from jax._src import test_util as jtu from jax._src.interpreters import pxla import numpy as np + config.parse_flags_with_absl() -def _cpp_device_put(value, device): +def _cpp_device_put(value, device, enable_x64: bool | None = None): aval = core.shaped_abstractify(value) return pxla.batched_device_put( - aval, jax.sharding.SingleDeviceSharding(device), [value], [device]) + aval, jax.sharding.SingleDeviceSharding(device), [value], [device], + enable_x64=enable_x64) class JaxJitTest(jtu.JaxTestCase): - def test_is_float_0(self): - self.assertTrue( - jaxlib.jax_jit._is_float0(np.zeros((5, 5), dtype=jax.float0))) - self.assertFalse(jaxlib.jax_jit._is_float0(np.zeros((5, 5)))) - @parameterized.parameters([jax.device_put, _cpp_device_put]) def test_device_put_on_numpy_masked_array(self, device_put_function): # TODO(jakevdp): add appropriate logic to jaxlib device_put and update this test. @@ -108,8 +105,8 @@ def test_device_put_on_sharded_device_array(self, device_put_function): def test_device_put_on_python_scalars(self): device = jax.devices()[0] - int_type = dtypes.canonicalize_dtype(np.int64) - float_type = dtypes.canonicalize_dtype(np.float64) + int_type = dtypes.default_int_dtype() + float_type = dtypes.default_float_dtype() complex_type = dtypes.canonicalize_dtype(np.complex128) # int @@ -164,8 +161,8 @@ def test_arg_signature_of_value(self): self.assertEqual(signature.shape, (3, 4)) self.assertFalse(signature.weak_type) - int_type = dtypes.canonicalize_dtype(np.int64) - float_type = dtypes.canonicalize_dtype(np.float64) + int_type = dtypes.default_int_dtype() + float_type = dtypes.default_float_dtype() complex_type = dtypes.canonicalize_dtype(np.complex128) # 3. Python scalar types @@ -198,6 +195,18 @@ def test_arg_signature_of_value(self): self.assertEqual(signature.shape, ()) self.assertTrue(signature.weak_type) + def test_device_put_on_numpy_arrays_x64_enabled(self): + device = jax.devices()[0] + for dtype in jtu.supported_dtypes(): + value = np.zeros((3, 4), dtype=dtype) + output_buffer = _cpp_device_put(value, device=device, enable_x64=True) + self.assertFalse(output_buffer.aval.weak_type) + self.assertEqual(output_buffer.aval, core.ShapedArray((3, 4), dtype)) + self.assertEqual(output_buffer.dtype, dtype) # NB: no canonicalization + np.testing.assert_array_equal(output_buffer, np.zeros((3, 4), + dtype=dtype)) + + def test_signature_support(self): def f(a, b, c): return a + b + c @@ -227,6 +236,44 @@ def fn(x): self.assertArraysEqual(v1, v1_expected) self.assertArraysEqual(v2, v2_expected) + @jtu.skip_on_flag("jax_use_simplified_jaxpr_constants", True) + def test_check_for_large_number_of_constants(self): + y = jnp.ones((128, 128)) + x = jnp.zeros((128,)) + + def jit_maker(): # need to ensure we lower at each test + def func(x): + return x @ y + return jax.jit(func) + + with self.assertWarnsRegex(UserWarning, "A large amount of constants were captured during lowering"): + with config.captured_constants_warn_bytes(y.nbytes): + jit_maker()(x) + + with self.assertNoWarnings(): + with config.captured_constants_warn_bytes(y.nbytes + 1): + jit_maker()(x) + + with config.captured_constants_warn_bytes(-1): + jit_maker()(x) + + def testParseArguments(self): + pytree_registry = jaxlib.pytree.default_registry() + sig, args = jaxlib.jax_jit.parse_arguments( + positional_args=[1, 2, 3], + keyword_args=[4, 5], + kwnames=("a", "b"), + static_argnums=[0, 2], + static_argnames=["a"], + pytree_registry=pytree_registry, + ) + self.assertEqual(args, [2, 5]) + self.assertEqual(sig.static_args, [1, 3, 4]) + self.assertEqual(sig.static_arg_names, ["a"]) + _, leaf = pytree_registry.flatten(0) + self.assertEqual(sig.dynamic_arg_names, ["b"]) + self.assertEqual(sig.dynamic_arg_treedefs, [leaf, leaf]) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/jax_numpy_error_test.py b/tests/jax_numpy_error_test.py new file mode 100644 index 000000000000..6a2f151cec1f --- /dev/null +++ b/tests/jax_numpy_error_test.py @@ -0,0 +1,282 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import operator + +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax._src import config +from jax._src import error_check +from jax._src import test_util as jtu +from jax._src.numpy import error as jnp_error +import jax.numpy as jnp + +config.parse_flags_with_absl() + + +JaxValueError = error_check.JaxValueError + + +class JaxNumpyErrorTests(jtu.JaxTestCase): + def setUp(self): + # TODO(b/408148001): Fix thread safety issue. + if jtu.TEST_NUM_THREADS.value > 1: + self.skipTest("Test does not work with multiple threads") + super().setUp() + + @parameterized.product(jit=[True, False]) + def test_set_error_if_nan(self, jit): + def f(x): + jnp_error._set_error_if_nan(x) + return x + + if jit: + f = jax.jit(f) + + x = jnp.full((4,), jnp.nan, dtype=jnp.float32) + + with jnp_error.error_checking_behavior(nan="ignore"): + _ = f(x) + error_check.raise_if_error() # should not raise error + + with jnp_error.error_checking_behavior(nan="raise"): + _ = f(x) + with self.assertRaisesRegex(JaxValueError, "NaN"): + error_check.raise_if_error() + + @parameterized.product(jit=[True, False]) + def test_set_error_if_divide_by_zero(self, jit): + def f(x, y): + jnp_error._set_error_if_divide_by_zero(y) + return x / y + + if jit: + f = jax.jit(f) + + x = jnp.arange(4, dtype=jnp.float32) + 1 + y = jnp.arange(4, dtype=jnp.float32) + + with jnp_error.error_checking_behavior(divide="ignore"): + _ = f(x, y) + error_check.raise_if_error() # should not raise error + + with jnp_error.error_checking_behavior(divide="raise"): + _ = f(x, y) + with self.assertRaisesRegex(JaxValueError, "Division by zero"): + error_check.raise_if_error() + + @parameterized.product(jit=[True, False]) + def test_error_category_oob_check(self, jit): + def f(x, start_indices, slice_sizes): + jnp_error._set_error_if_with_category( + jnp.logical_or( + start_indices < 0, + start_indices + jnp.array(slice_sizes, dtype=jnp.int32) + >= jnp.array(x.shape, dtype=jnp.int32), + ), + "Out of bounds in dynamic_slice", + category="oob", + ) + y = jax.lax.dynamic_slice( + x, start_indices, slice_sizes, allow_negative_indices=False + ) + return y + + if jit: + f = jax.jit(f, static_argnums=(2,)) + + x = jnp.arange(12).reshape(3, 4) + start_indices = jnp.array([0, -1], dtype=jnp.int32) + slice_sizes = (3, 4) + + with jnp_error.error_checking_behavior(oob="ignore"): + _ = f(x, start_indices, slice_sizes) + error_check.raise_if_error() # should not raise error + + with jnp_error.error_checking_behavior(oob="raise"): + _ = f(x, start_indices, slice_sizes) + with self.assertRaisesRegex( + JaxValueError, "Out of bounds in dynamic_slice", + ): + error_check.raise_if_error() + + def test_error_category_invalid_category(self): + with self.assertRaisesRegex(ValueError, "Invalid category"): + jnp_error._set_error_if_with_category( + jnp.isnan(jnp.float32(1.0)), "x is NaN", category="invalid" + ) + + @staticmethod + def nan_cases(cases): + for jit in (True, False): + for func, args_error, args_no_err in cases: + if not isinstance(args_error, tuple): + args_error = (args_error,) + if not isinstance(args_no_err, tuple): + args_no_err = (args_no_err,) + + jit_str = "jit" if jit else "nojit" + func_str = f"{func.__module__}.{func.__name__}" + name = f"_{jit_str}_{func_str}" + + yield name, jit, func, args_error, args_no_err + + @parameterized.named_parameters( + nan_cases(( + # List of all NaN-producing jax.numpy functions. + # The first group of numbers is the input that will produce a NaN, and + # the second group is the input that will not produce a NaN. + # go/keep-sorted start + (jnp.acos, 2.0, 0.5), + (jnp.acosh, 0.5, 2.0), + (jnp.add, (jnp.inf, -jnp.inf), (0.0, 0.0)), + (jnp.arccos, 2.0, 0.5), + (jnp.arccosh, 0.5, 2.0), + (jnp.arcsin, -2.0, 0.5), + (jnp.arctanh, -2.0, 0.5), + (jnp.asin, -2.0, 0.5), + (jnp.atanh, -2.0, 0.5), + (jnp.cos, jnp.inf, 1.0), + (jnp.divide, (0.0, 0.0), (1.0, 1.0)), + (jnp.divmod, (1.0, 0.0), (1.0, 1.0)), + (jnp.float_power, (-1.0, 0.5), (1.0, 1.0)), + (jnp.fmod, (1.0, 0.0), (1.0, 1.0)), + (jnp.log, -1.0, 1.0), + (jnp.log10, -1.0, 1.0), + (jnp.log1p, -1.5, 1.0), + (jnp.log2, -1.0, 1.0), + (jnp.mod, (1.0, 0.0), (1.0, 1.0)), + (jnp.pow, (-1.0, 0.5), (1.0, 1.0)), + (jnp.power, (-1.0, 0.5), (1.0, 1.0)), + (jnp.remainder, (1.0, 0.0), (1.0, 1.0)), + (jnp.sin, jnp.inf, 1.0), + # TODO(https://github.com/jax-ml/jax/issues/27470): Not yet supported. + # (jnp.sinc, jnp.inf, 1.0), + (jnp.sqrt, -4.0, 4.0), + (jnp.subtract, (jnp.inf, jnp.inf), (0.0, 0.0)), + (jnp.tan, jnp.inf, 1.0), + (jnp.true_divide, (0.0, 0.0), (1.0, 1.0)), + (operator.add, (jnp.inf, -jnp.inf), (0.0, 0.0)), + (operator.mod, (1.0, 0.0), (1.0, 1.0)), + (operator.pow, (-1.0, 0.5), (1.0, 1.0)), + (operator.sub, (jnp.inf, jnp.inf), (0.0, 0.0)), + (operator.truediv, (0.0, 0.0), (1.0, 1.0)), + # go/keep-sorted end + )) + ) + def test_can_raise_nan_error(self, jit, f, args_err, args_no_err): + args_err = [jnp.float32(x) for x in args_err] + args_no_err = [jnp.float32(x) for x in args_no_err] + + if jit: + f = jax.jit(f) + + with jnp_error.error_checking_behavior(nan="raise"): + f(*args_no_err) + error_check.raise_if_error() # should not raise error + + f(*args_err) + with self.assertRaisesRegex(JaxValueError, "NaN"): + error_check.raise_if_error() + + INT_TYPES = jtu.dtypes.supported( + (jnp.int32, jnp.uint32, jnp.int64, jnp.uint64, + jnp.int16, jnp.uint16, jnp.int8, jnp.uint8)) + FLOAT_TYPES = jtu.dtypes.supported( + (jnp.float32, jnp.float64, jnp.float16, jnp.bfloat16)) + + @staticmethod + def divide_cases(cases): + for jit in (True, False): + for func, dtypes in cases: + for dtype in dtypes: + jit_str = "jit" if jit else "nojit" + func_str = f"{func.__module__}.{func.__name__}" + dtype_str = dtype.__name__ + name = f"_{jit_str}_{func_str}_{dtype_str}" + yield name, jit, func, dtype + + @parameterized.named_parameters( + divide_cases(( + # go/keep-sorted start + (jnp.divmod, FLOAT_TYPES + INT_TYPES), + (jnp.floor_divide, INT_TYPES), + (jnp.mod, FLOAT_TYPES + INT_TYPES), + (jnp.remainder, FLOAT_TYPES + INT_TYPES), + (jnp.true_divide, FLOAT_TYPES), + (operator.mod, FLOAT_TYPES + INT_TYPES), + (operator.truediv, FLOAT_TYPES), + # go/keep-sorted end + )) + ) + def test_can_raise_divide_by_zero_error(self, jit, div_func, dtype): + args_err = (dtype(1), dtype(0)) + args_no_err = (dtype(1), dtype(1)) + + if jit: + div_func = jax.jit(div_func) + + with jnp_error.error_checking_behavior(divide="raise"): + div_func(*args_no_err) + error_check.raise_if_error() # should not raise error + + div_func(*args_err) + with self.assertRaisesRegex(JaxValueError, "Division by zero"): + error_check.raise_if_error() + + @parameterized.product(jit=[True, False]) + def test_can_raise_oob_error_take(self, jit): + def f(x, a): + return x[a] + + if jit: + f = jax.jit(f) + + x = jnp.arange(10) + a = jnp.int32(10) + + with jnp_error.error_checking_behavior(oob="ignore"): + f(x, a) + error_check.raise_if_error() # should not raise error + + with jnp_error.error_checking_behavior(oob="raise"): + f(x, a) + with self.assertRaisesRegex(JaxValueError, "Out of bounds"): + error_check.raise_if_error() + + def test_can_raise_oob_error_dynamic_slice(self): + def f(x, a): + return x[:, a:a+4] # dynamic indices are non-jittable + + x = jnp.arange(10).reshape(2, 5) + a = jnp.array(3, dtype=jnp.int32) + + with jnp_error.error_checking_behavior(oob="ignore"): + f(x, a) + error_check.raise_if_error() # should not raise error + + with jnp_error.error_checking_behavior(oob="raise"): + f(x, a) + with self.assertRaisesRegex(JaxValueError, "Out of bounds"): + error_check.raise_if_error() + + def test_empty_indices(self): + # Regression test for https://github.com/jax-ml/jax/issues/32070 + with jnp_error.error_checking_behavior(oob="raise"): + jnp.zeros(1)[None] # should not error. + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/jax_to_ir_test.py b/tests/jax_to_ir_test.py index f600a08f5dc4..46131bd22b64 100644 --- a/tests/jax_to_ir_test.py +++ b/tests/jax_to_ir_test.py @@ -15,15 +15,21 @@ import unittest from absl.testing import absltest +from jax._src import test_util as jtu import jax.numpy as jnp from jax.tools import jax_to_ir -from jax._src import test_util as jtu + try: import tensorflow as tf except ImportError: tf = None # type: ignore +try: + from tensorflow.compiler.tf2xla.python import xla as tfxla +except ImportError: + tfxla = None # type: ignore + def axpy(a, x, y): return a * x + y[:, jnp.newaxis] @@ -81,6 +87,11 @@ def test_parse_shape_str_invalid(self): jax_to_ir.parse_shape_str('foo[]') @unittest.skipIf(tf is None, 'TensorFlow not installed.') + # TODO(dsuo): Remove this once we bump tensorflow version. + @unittest.skipIf( + tfxla is None or tfxla.call_module_maximum_supported_version() < 10, + 'TensorFlow version too old.', + ) def test_jax_to_tf_axpy(self): tf_proto, tf_text = jax_to_ir.jax_to_tf(axpy, [ ('y', jax_to_ir.parse_shape_str('f32[128]')), @@ -114,15 +125,13 @@ def test_parse_shape_str(self): self.assertParsedShape('f32[]', [], jnp.float32) self.assertParsedShape('f32[1,2,3]', [1, 2, 3], jnp.float32) self.assertParsedShape('pred[1]', [1], jnp.bool_) - if hasattr(jnp, 'int2'): - self.assertParsedShape('s2[1]', [1], jnp.int2) + self.assertParsedShape('s2[1]', [1], jnp.int2) self.assertParsedShape('s4[1]', [1], jnp.int4) self.assertParsedShape('s8[1]', [1], jnp.int8) self.assertParsedShape('s16[1]', [1], jnp.int16) self.assertParsedShape('s32[1]', [1], jnp.int32) self.assertParsedShape('s64[1]', [1], jnp.int64) - if hasattr(jnp, 'uint2'): - self.assertParsedShape('u2[1]', [1], jnp.uint2) + self.assertParsedShape('u2[1]', [1], jnp.uint2) self.assertParsedShape('u4[1]', [1], jnp.uint4) self.assertParsedShape('u8[1]', [1], jnp.uint8) self.assertParsedShape('u16[1]', [1], jnp.uint16) diff --git a/tests/jaxpr_effects_test.py b/tests/jaxpr_effects_test.py index c331bfaf438a..bea0c01b78af 100644 --- a/tests/jaxpr_effects_test.py +++ b/tests/jaxpr_effects_test.py @@ -20,8 +20,6 @@ from jax import api_util import jax.numpy as jnp from jax import lax -from jax.experimental import pjit -from jax._src import ad_checkpoint from jax._src import callback as cb from jax._src import dispatch from jax._src import config @@ -29,7 +27,6 @@ from jax._src import effects from jax._src import linear_util as lu from jax._src import test_util as jtu -from jax._src import util from jax._src.interpreters import ad from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe @@ -86,23 +83,8 @@ def trivial_effect_lowering(ctx, *, effect): mlir.register_lowering(effect_p, trivial_effect_lowering) def function_effect_lowering(ctx, *, effect): - def _f(ctx): - ctx.set_tokens_out(ctx.tokens_in) - return [] - func = mlir._emit_lowering_rule_as_fun(_f, ctx) - - output_types = map(mlir.aval_to_ir_type, ctx.avals_out) - effs = list(ctx.tokens_in.effects()) - in_tokens = [ctx.tokens_in.get(eff) for eff in effs] - token_types = [mlir.token_type() for _ in effs] - output_types = [*token_types, *output_types] - flat_output_types = mlir.flatten_ir_types(output_types) - call = mlir.func_dialect.CallOp(flat_output_types, - mlir.ir.FlatSymbolRefAttr.get(func.name.value), - mlir.flatten_ir_values(in_tokens)) - tokens, out = util.split_list(call.results, [len(ctx.tokens_in)]) - ctx.set_tokens_out(mlir.TokenSet(zip(effs, tokens))) - return out + ctx.set_tokens_out(ctx.tokens_in) + return [] callback_p = core.Primitive('callback') callback_p.multiple_results = True @@ -126,7 +108,7 @@ def callback_effect_lowering(ctx: mlir.LoweringRuleContext, *args, callback, out out_op, token_out, _ = cb.emit_python_callback( ctx, callback, token_in, list(args), list(ctx.avals_in), - list(ctx.avals_out), has_side_effect=True) + list(ctx.avals_out), has_side_effect=True, returns_token=True) if token_out: ctx.set_tokens_out(ctx.tokens_in.update_tokens(mlir.TokenSet({effect: token_out}))) @@ -216,7 +198,7 @@ def f(x): def test_new_remat_allows_certain_effects(self): remat_effect = RematEffect() - @ad_checkpoint.checkpoint + @jax.checkpoint def f(x): x, = effect_p.bind(x, effect=remat_effect) return x @@ -248,6 +230,8 @@ def f(x): jax.make_jaxpr(f)(2.) def test_pmap_inherits_effects(self): + if config.pmap_shmap_merge.value: + self.skipTest("Test does not raise under `pmap_shmap_merge=True`.") @jax.pmap def f(x): @@ -259,18 +243,89 @@ def f(x): r"Ordered effects not supported for map primitives: \[.*\]"): jax.make_jaxpr(f)(jnp.arange(jax.local_device_count())) - def test_pjit_inherits_effects(self): + def test_jit_inherits_effects(self): def f(x): effect_p.bind(effect=foo_effect) effect_p.bind(effect=bar_effect) return x mesh = jax.sharding.Mesh(np.array(jax.devices()), ['x']) spec = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x')) - f = pjit.pjit(f, in_shardings=spec, out_shardings=spec) - with mesh: + f = jax.jit(f, in_shardings=spec, out_shardings=spec) + with jax.set_mesh(mesh): jaxpr = jax.make_jaxpr(f)(np.arange(jax.local_device_count())) self.assertSetEqual(jaxpr.effects, {foo_effect, bar_effect}) + def test_pjit_const_input_effect_indexing(self): + # https://github.com/jax-ml/jax/issues/32399 + @jax.jit + def bar(x, w): + def scan_fn(x, _): + c = jnp.array([]) + o = w[...] @ x + x = jnp.concatenate([x, c], axis=-1) + return x, None + + x, _ = jax.lax.scan(scan_fn, x, None, length=10) + return x + + + @jax.jit + def foo(w): + return bar(jnp.zeros((1,)), w) + + foo(jax.new_ref(jnp.eye(1))) # don't crash + + def test_jit_const_input_effect_indexing(self): + @jax.jit + def bar(w): + x = jnp.zeros((1,)) + jnp.array([0.]) + x = jax.jit(lambda x: x + w[...])(x) + return x + + @jax.jit + def foo(w): + return bar(w) + + foo(jax.new_ref(jnp.ones((1,)))) + jax.grad(jax.remat(lambda x: foo(jax.new_ref(x)).sum()))(jnp.ones((1,))) + + def test_cond_const_input_effect_indexing(self): + @jax.custom_jvp + def weird(x): + return x + + @weird.defjvp + def weird_jvp(primals, tangents): + (x,), (xdot,) = primals, tangents + return jnp.sum(np.ones(3)) * x, xdot + + @jax.jit + def f(x): + x_ref = jax.new_ref(0.) + return jax.lax.cond(x < 0, lambda: x_ref[...], lambda: weird(x[...])) + + jax.jvp(f, (1.,), (1.,)) + + def test_scan_const_input_effect_indexing(self): + @jax.custom_jvp + def weird(x): + return x + + @weird.defjvp + def weird_jvp(primals, tangents): + (x,), (xdot,) = primals, tangents + return jnp.sum(np.ones(3)) * x, xdot + + @jax.jit + def f(x): + x_ref = jax.new_ref(0.) + y, () = jax.lax.scan(lambda _, __: (weird(x_ref[...]), ()), + x_ref[...], length=1) + return y + + jax.jvp(f, (1.,), (1.,)) + jax.grad(jax.remat(f))(1.) + @jtu.thread_unsafe_test_class() # because of mlir.register_lowering calls class EffectfulJaxprLoweringTest(jtu.JaxTestCase): @@ -294,7 +349,7 @@ def _effect_lowering(ctx, *, effect): def tearDown(self): super().tearDown() dispatch.runtime_tokens.clear() - mlir.register_lowering(effect_p, self._old_lowering) + mlir._lowerings[effect_p] = self._old_lowering def test_can_lower_lowerable_effect(self): @jax.jit @@ -355,8 +410,7 @@ def f(x): f.lower(2.) def test_nontrivial_lowering_with_ordered_effect_should_consume_token(self): - - mlir.register_lowering(effect_p, function_effect_lowering) + mlir.register_lowering(effect_p, function_effect_lowering, inline=False) @jax.jit def f(x): @@ -375,8 +429,7 @@ def f(x): self.assertIn('hlo.token', str(func.type.results[0])) def test_nontrivial_lowering_with_unordered_effect_should_consume_token(self): - - mlir.register_lowering(effect_p, function_effect_lowering) + mlir.register_lowering(effect_p, function_effect_lowering, inline=False) @jax.jit def f(x): @@ -476,9 +529,17 @@ def test_cant_jit_and_pmap_function_with_ordered_effects(self): def f(x): effect_p.bind(effect=foo_effect) return x + 1 - with self.assertRaisesRegex( - ValueError, - r"Ordered effects not supported for map primitives: \[foo\]"): + if config.pmap_shmap_merge.value: + if jax.device_count() == 1: + self.skipTest("This test won't raise with 1 device.") + if jtu.device_under_test() == "gpu": + self.skipTest("Test does not raise under GPU.") + if jtu.device_under_test() == "tpu" and jtu.get_tpu_version() > 3: + self.skipTest("Test does not raise under TPU v4+.") + regex = r"The following ordered effects are not supported for more than 1 device: \[foo\]" + else: + regex = r"Ordered effects not supported for map primitives: \[foo\]" + with self.assertRaisesRegex(ValueError, regex): f(jnp.arange(jax.device_count())) def test_runtime_tokens_should_update_after_running_effectful_function(self): @@ -527,7 +588,7 @@ def log_value(x): @jax.jit def f(x): - return callback_p.bind(x, callback=log_value, effect=log_effect, out_avals=[]) + return callback_p.bind(x, callback=log_value, effect=log_effect, out_avals=()) f(2.) jax.effects_barrier() @@ -552,11 +613,11 @@ def f(x): # Expensive computation x = x.dot(x) x = jnp.log(x.sum()) - return callback_p.bind(x, callback=log_value, effect=log_effect, out_avals=[]) + return callback_p.bind(x, callback=log_value, effect=log_effect, out_avals=()) @jax.jit def g(x): - return callback_p.bind(x, callback=log_value, effect=log_effect, out_avals=[]) + return callback_p.bind(x, callback=log_value, effect=log_effect, out_avals=()) x = jax.device_put(jnp.ones((500, 500)), jax.devices()[0]) y = jax.device_put(3., jax.devices()[1]) @@ -579,7 +640,7 @@ def f(x): # Runs in a thread. res = jax.jit( lambda x: callback_p.bind( - x, callback=_noop, effect=log_effect, out_avals=[]) + x, callback=_noop, effect=log_effect, out_avals=()) )(x) tokens.append(dispatch.runtime_tokens.current_tokens[log_effect]) return res @@ -606,13 +667,22 @@ def f(x): jax.pmap(f)(jnp.arange(jax.local_device_count())) def test_cannot_pmap_ordered_effect(self): - def f(x): # foo is lowerable and ordered effect_p.bind(effect=foo_effect) return x + if config.pmap_shmap_merge.value: + if jax.device_count() == 1: + self.skipTest("This test won't raise with 1 device.") + if jtu.device_under_test() == "gpu": + self.skipTest("Test does not raise under GPU.") + if jtu.device_under_test() == "tpu" and jtu.get_tpu_version() > 3: + self.skipTest("Test does not raise under TPU v4+.") + regex = r"The following ordered effects are not supported for more than 1 device: \[foo\]" + else: + regex = "Ordered effects not supported in `pmap`." with self.assertRaisesRegex( - ValueError, "Ordered effects not supported in `pmap`."): + ValueError, regex): jax.pmap(f)(jnp.arange(jax.local_device_count())) def test_can_pmap_unordered_effect(self): @@ -635,7 +705,7 @@ def log_value(x): @jax.pmap def f(x): callback_p.bind( - x, callback=log_value, effect=unordered_log_effect, out_avals=[]) + x, callback=log_value, effect=unordered_log_effect, out_avals=()) return x + 1 f(jnp.arange(2)).block_until_ready() jax.effects_barrier() @@ -893,7 +963,10 @@ def f(x, y): def f(y): return input_effect(x, y, index=0) jaxpr = jax.make_jaxpr(f)(0) - self.assertIn(InputEffect(0), jaxpr.effects) + if config.use_simplified_jaxpr_constants.value: + self.assertEmpty(jaxpr.effects) + else: + self.assertIn(InputEffect(0), jaxpr.effects) def test_jaxpr_input_effect_is_tracked_through_partial_eval_custom(self): def f(_, y): @@ -935,9 +1008,15 @@ def f(_, y): def f(_): input_effect(x, index=0) jaxpr = jax.make_jaxpr(f)(0) - self.assertIn(InputEffect(0), jaxpr.effects) + if config.use_simplified_jaxpr_constants.value: + self.assertEmpty(jaxpr.effects) + else: + self.assertIn(InputEffect(0), jaxpr.effects) jaxpr3, _ = pe.dce_jaxpr(jaxpr.jaxpr, [], instantiate=[False]) - self.assertIn(InputEffect(0), jaxpr3.effects) + if config.use_simplified_jaxpr_constants.value: + self.assertEmpty(jaxpr3.effects) + else: + self.assertIn(InputEffect(0), jaxpr3.effects) def test_jaxpr_input_effect_is_tracked_through_while_loop(self): @@ -947,22 +1026,31 @@ def make_fun(index): def f(x): def body(y): input_effect(x, y, index=index) - return y + return 2 * y lax.while_loop(lambda _: True, body, y) return f jaxpr = jax.make_jaxpr(make_fun(0))(0) - self.assertIn(InputEffect(1), jaxpr.effects) + if config.use_simplified_jaxpr_constants.value: + self.assertIn(InputEffect(0), jaxpr.effects) + else: + self.assertIn(InputEffect(1), jaxpr.effects) jaxpr = jax.make_jaxpr(make_fun(1))(0) - self.assertIn(InputEffect(0), jaxpr.effects) + if config.use_simplified_jaxpr_constants.value: + self.assertEmpty(jaxpr.effects) + else: + self.assertIn(InputEffect(0), jaxpr.effects) def f(x): def body(y): input_effect(x, y, index=1) - return y + return 2 * y lax.while_loop(lambda _: (x > 0).all(), body, y) jaxpr = jax.make_jaxpr(f)(0) - self.assertIn(InputEffect(0), jaxpr.effects) + if config.use_simplified_jaxpr_constants.value: + self.assertEmpty(jaxpr.effects) + else: + self.assertIn(InputEffect(0), jaxpr.effects) def test_jaxpr_input_effect_is_tracked_through_scan(self): c = np.ones(2) @@ -974,13 +1062,22 @@ def body(z, x): lax.scan(body, z, xs) return f jaxpr = jax.make_jaxpr(make_fun(0))(jnp.arange(8), 0) - self.assertIn(InputEffect(1), jaxpr.effects) + if config.use_simplified_jaxpr_constants.value: + self.assertIn(InputEffect(0), jaxpr.effects) + else: + self.assertIn(InputEffect(1), jaxpr.effects) jaxpr = jax.make_jaxpr(make_fun(1))(jnp.arange(8), 0) - self.assertIn(InputEffect(2), jaxpr.effects) + if config.use_simplified_jaxpr_constants.value: + self.assertIn(InputEffect(1), jaxpr.effects) + else: + self.assertIn(InputEffect(2), jaxpr.effects) jaxpr = jax.make_jaxpr(make_fun(2))(jnp.arange(8), 0) - self.assertIn(InputEffect(0), jaxpr.effects) + if config.use_simplified_jaxpr_constants.value: + self.assertEmpty(jaxpr.effects) + else: + self.assertIn(InputEffect(0), jaxpr.effects) def test_jaxpr_input_effect_is_tracked_through_scan_with_dce(self): c = np.ones(2) @@ -993,15 +1090,24 @@ def body(z, x): return f jaxpr = jax.make_jaxpr(make_fun(0))(jnp.arange(8), 0) jaxpr, _ = pe.dce_jaxpr(jaxpr.jaxpr, []) - self.assertIn(InputEffect(1), jaxpr.effects) + if config.use_simplified_jaxpr_constants.value: + self.assertIn(InputEffect(0), jaxpr.effects) + else: + self.assertIn(InputEffect(1), jaxpr.effects) jaxpr = jax.make_jaxpr(make_fun(1))(jnp.arange(8), 0) jaxpr, _ = pe.dce_jaxpr(jaxpr.jaxpr, []) - self.assertIn(InputEffect(2), jaxpr.effects) + if config.use_simplified_jaxpr_constants.value: + self.assertIn(InputEffect(1), jaxpr.effects) + else: + self.assertIn(InputEffect(2), jaxpr.effects) jaxpr = jax.make_jaxpr(make_fun(2))(jnp.arange(8), 0) jaxpr, _ = pe.dce_jaxpr(jaxpr.jaxpr, []) - self.assertIn(InputEffect(0), jaxpr.effects) + if config.use_simplified_jaxpr_constants.value: + self.assertEmpty(jaxpr.effects) + else: + self.assertIn(InputEffect(0), jaxpr.effects) def test_jaxpr_input_effect_is_tracked_through_cond(self): @@ -1018,10 +1124,16 @@ def false_fun(x): return f # [c, pred, x] jaxpr = jax.make_jaxpr(make_fun(0))(0) - self.assertIn(InputEffect(1), jaxpr.effects) + if config.use_simplified_jaxpr_constants.value: + self.assertIn(InputEffect(0), jaxpr.effects) + else: + self.assertIn(InputEffect(1), jaxpr.effects) jaxpr = jax.make_jaxpr(make_fun(1))(0) - self.assertIn(InputEffect(0), jaxpr.effects) + if config.use_simplified_jaxpr_constants.value: + self.assertEmpty(jaxpr.effects) + else: + self.assertIn(InputEffect(0), jaxpr.effects) if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/jaxpr_util_test.py b/tests/jaxpr_util_test.py index 4597ce6bd7d5..7bf49799feff 100644 --- a/tests/jaxpr_util_test.py +++ b/tests/jaxpr_util_test.py @@ -38,8 +38,7 @@ def f(x, y): hist = jaxpr_util.primitives(make_jaxpr(f)(1., 1.).jaxpr) - primitives = ['add', 'sin', 'cos'] - primitives.append('pjit') + primitives = ['add', 'sin', 'cos', 'jit'] for k in primitives: assert k in hist, k self.assertEqual(hist['sin'], 2) @@ -74,7 +73,7 @@ def sub(x, y): f'cos :: float{t}[]', f'reduce_sum :: float{t}[]', f'concatenate :: float{t}[2]', - f'pjit :: float{t}[]', + f'jit :: float{t}[]', ] for k in shapes: self.assertEqual(hist[k], 1) diff --git a/tests/jet_test.py b/tests/jet_test.py index 7c2c71e9bbfa..2db446cc4735 100644 --- a/tests/jet_test.py +++ b/tests/jet_test.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - from functools import reduce, partial from absl.testing import absltest @@ -289,6 +288,8 @@ def test_tanh(self): self.unary_check(jnp.tanh, lims=[-500, 500], order=5, atol=5e-3) @jtu.skip_on_devices("tpu") def test_logistic(self): self.unary_check(lax.logistic, lims=[-100, 100], order=5) + @unittest.skipIf(jtu.is_test_rbe() and jtu.is_gil_disabled() and jtu.is_tsan(), + "Consumes too much RAM under FT TSAN: b/456211935") @jtu.skip_on_devices("tpu") def test_expit2(self): self.expit_check(lims=[-500, 500], order=5) @jtu.skip_on_devices("tpu") @@ -413,6 +414,8 @@ def g(eps): return jax.grad(f)(x, eps) jet(g, (1.,), ([1.],)) # doesn't crash + @unittest.skipIf(jtu.is_test_rbe() and jtu.is_gil_disabled() and jtu.is_tsan(), + "Consumes too much RAM under FT TSAN: b/456211935") def test_scatter_add(self): # very basic test from https://github.com/jax-ml/jax/issues/5365 def f(x): diff --git a/tests/key_reuse_test.py b/tests/key_reuse_test.py index 0d7ac9c18827..26599ef726b0 100644 --- a/tests/key_reuse_test.py +++ b/tests/key_reuse_test.py @@ -257,8 +257,9 @@ def f(key1, key2): def test_jit_can_consume_input(self): def f(key): assert_unconsumed(key) - jax.jit(jax.random.bits)(key) + ans = jax.jit(jax.random.bits)(key) assert_consumed(key) + return ans self.check_key_reuse(f, jax.random.key(0)) def test_jit_can_return_consumed_output(self): @@ -306,28 +307,31 @@ def g(key): assert_unconsumed(key) assert_unconsumed(key1) assert_unconsumed(key2) - _ = jax.random.bits(key1) + other = jax.random.bits(key1) assert_consumed(key) assert_consumed(key1) assert_consumed(key2) + return (key1, key2, other) self.check_key_reuse(f, jax.random.key(0)) def test_cond_both_consumed(self): @jax.jit def f(flag, key): assert_unconsumed(key) - _ = jax.lax.cond( + ans = jax.lax.cond( flag, jax.random.uniform, jax.random.normal, key) assert_consumed(key) + return ans self.check_key_reuse(f, True, jax.random.key(0)) def test_cond_one_consumed(self): @jax.jit def f(flag, key): assert_unconsumed(key) - _ = jax.lax.cond( + ans = jax.lax.cond( flag, jax.random.uniform, lambda k: 1.0, key) assert_consumed(key) + return ans self.check_key_reuse(f, True, jax.random.key(0)) def test_cond_neither_consumed(self): @@ -369,7 +373,7 @@ class KeyReuseIntegrationTest(jtu.JaxTestCase): random_bits_error = "In random_bits, argument [0-9]+ is already consumed.*" random_split_error = "In random_split, argument [0-9]+ is already consumed.*" generic_error = ".*argument [0-9]+ is already consumed.*" - pjit_error = "In pjit, argument 0 is already consumed." + pjit_error = "In jit, argument 0 is already consumed." def check_key_reuse(self, f, *args): return _core.check_key_reuse(f, *args) @@ -391,17 +395,17 @@ def f_good(): def f_bad(): key = jax.random.key(0) - _ = jax.random.split(key) - return jax.random.uniform(key) + other = jax.random.split(key) + return (jax.random.uniform(key), other) with self.assertRaisesRegex(KeyReuseError, self.pjit_error): self.check_key_reuse(f_bad) def f_bad_2(): key = jax.random.key(0) - _ = jax.random.split(key) - key1, _ = jax.random.split(key) - return jax.random.uniform(key1) + other1 = jax.random.split(key) + key1, other2 = jax.random.split(key) + return (jax.random.uniform(key1), other1, other2) with self.assertRaisesRegex(KeyReuseError, self.random_split_error): self.check_key_reuse(f_bad_2) @@ -612,7 +616,7 @@ def f_good(x, key): self.check_key_reuse(f_bad, x, key) with self.assertRaisesRegex(KeyReuseError, self.random_bits_error): - self.check_key_reuse(jax.grad(f_bad), x, key) + self.check_key_reuse(jax.value_and_grad(f_bad), x, key) self.check_key_reuse(f_good, x, key) self.check_key_reuse(jax.grad(f_good), x, key) diff --git a/tests/lax_autodiff_test.py b/tests/lax_autodiff_test.py index a69f44f37754..e3ba552140cf 100644 --- a/tests/lax_autodiff_test.py +++ b/tests/lax_autodiff_test.py @@ -28,7 +28,6 @@ from jax import dtypes from jax import lax from jax._src import test_util as jtu -from jax._src.util import NumpyComplexWarning from jax.test_util import check_grads jax.config.parse_flags_with_absl() @@ -205,14 +204,16 @@ class LaxAutodiffTest(jtu.JaxTestCase): )) def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol): rng = rng_factory(self.rng()) - if jtu.test_device_matches(["cpu"]): + if jtu.test_device_matches(["cpu", "tpu"]): if op is lax.cosh and dtype == np.complex64: - tol = 3e-1 # 2nd-order gradients are noisy on CPU + tol = 3e-1 # 2nd-order gradients are noisy on CPU and TPU if jtu.test_device_matches(["tpu"]): if op is lax.pow: raise SkipTest("pow grad imprecise on tpu") if op is lax.cos: order = 1 # 2nd-order gradient is imprecise on TPU. + if op is lax.sin: + order = 1 # 2nd-order gradient is imprecise on TPUv5p. if op is lax.log: order = 1 # 2nd-order gradient is imprecise on TPU. @@ -242,7 +243,7 @@ def testConvertElementTypeGrad(self, from_dtype, to_dtype): jtu.tolerance(from_dtype, jtu.default_gradient_tolerance)) args = (rng((2, 3), from_dtype),) convert_element_type = lambda x: lax.convert_element_type(x, to_dtype) - convert_element_type = jtu.ignore_warning(category=NumpyComplexWarning)( + convert_element_type = jtu.ignore_warning(category=np.exceptions.ComplexWarning)( convert_element_type) check_grads(convert_element_type, args, 2, ["fwd", "rev"], tol, tol, eps=1.) @@ -685,6 +686,18 @@ def f(x, y): result2, _ = jax.value_and_grad(f, 0)(x, y) self.assertAllClose(result1, result2) + def testGradOfVmapOfDynamicSlice(self): + # Regression test for https://github.com/jax-ml/jax/issues/34228. + def f(x, i): + return jax.lax.dynamic_index_in_dim(x, i, axis=0, keepdims=False) + + x = jax.numpy.array([1.0]) + i = jax.numpy.array([1]) # out-of-bound index + expected = jax.numpy.array([[1.0]]) + + self.assertArraysEqual(jax.jacrev(f)(x, i[0]), expected[0, 0]) + self.assertArraysEqual(jax.jacrev(jax.vmap(f, (None, 0)))(x, i), expected) + @jtu.sample_product( [dict(shape=shape, perm=perm) for shape, perm in [ @@ -852,7 +865,7 @@ def testCumulativeReduceGrad(self, op, shape, dtype, axis, reverse): # TODO(b/205052657): enable more tests when supported @jtu.sample_product( [dict(shape=shape, axis=axis) - for shape in [(5,), (5, 7), (4, 9, 3)] + for shape in [(0,), (5,), (5, 7), (4, 9, 3)] for axis in [len(shape) - 1] ], dtype=[np.float32], @@ -891,13 +904,14 @@ def args_maker(): @jtu.sample_product( dtype=[np.float32,], - shape=[(4,), (5, 5), (2, 1, 4)], + shape=[(4,), (5, 5), (3, 1, 4)], k=[1, 3], + axis=[0, -1] ) - def testTopKGrad(self, shape, dtype, k): + def testTopKGrad(self, shape, dtype, k, axis): flat_values = np.arange(math.prod(shape), dtype=dtype) values = self.rng().permutation(flat_values).reshape(shape) - fun = lambda vs: lax.top_k(vs, k=k)[0] + fun = lambda vs: lax.top_k(vs, k=k, axis=axis)[0] check_grads(fun, (values,), 2, ["fwd", "rev"], eps=1e-2) @jtu.sample_product( @@ -1178,6 +1192,23 @@ def testPowShapeMismatch(self): expected = jax.numpy.diag(y * x ** (y - 1)) self.assertArraysEqual(actual, expected) + @jtu.sample_product( + [ + dict(arg_shape=arg_shape, reps=reps) + for arg_shape, reps in [ + [(3,), (2,)], + [(2, 3), (1, 2)], + [(1, 1, 4), (1, 3, 1)], + ] + ], + dtype=grad_float_dtypes, + ) + def testTileAutodiff(self, arg_shape, reps, dtype): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(arg_shape, dtype)] + op = lambda x: lax.tile(x, reps) + check_grads(op, args_maker(), order=3, modes=["fwd", "rev"], eps=1.) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 3871a87a7a3e..b76d29362913 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -15,8 +15,10 @@ import collections import contextlib +import gc from functools import partial import itertools +import math import operator import re import unittest @@ -28,17 +30,19 @@ import jax from jax._src import core -from jax import dtypes +from jax._src import config +from jax._src import dtypes from jax import lax from jax import random from jax._src import test_util as jtu from jax import tree_util -from jax._src.util import unzip2 -from jax.ad_checkpoint import checkpoint as new_checkpoint, checkpoint_policies +from jax._src.util import unzip2, split_list +from jax import checkpoint_policies import jax.numpy as jnp # scan tests use numpy import jax.scipy as jsp +from jax._src import dispatch from jax._src.lax import control_flow as lax_control_flow -from jax._src.lax.control_flow import for_loop +from jax._src.interpreters import batching from jax._src.interpreters import mlir jax.config.parse_flags_with_absl() @@ -48,25 +52,14 @@ # provides a lax.cond-compatible interface to a two-branch lax.switch. Several # tests in this file are parameterized such that they either call into lax.cond # or into this function. -def cond_via_switch(pred, true_fun, false_fun, op, *args): - if len(args) > 0: - assert len(args) == 1 - true_op, _true_fun, false_op, _false_fun = true_fun, false_fun, op, args[0] - op = (false_op, true_op) - false_fun = lambda op: _false_fun(op[0]) - true_fun = lambda op: _true_fun(op[1]) +def cond_via_switch(pred, true_fun, false_fun, *args): index = lax.convert_element_type(pred, np.int32) - return lax.switch(index, [false_fun, true_fun], op) - -def cond_with_new_checkpoint(pred, true_fun, false_fun, op, *args): - if args: - true_op, _true_fun, false_op, _false_fun = true_fun, false_fun, op, args[0] - op = (false_op, true_op) - false_fun = lambda op: _false_fun(op[0]) - true_fun = lambda op: _true_fun(op[1]) + return lax.switch(index, [false_fun, true_fun], *args) + +def cond_with_new_checkpoint(pred, true_fun, false_fun, *args): index = lax.convert_element_type(pred, np.int32) - fn = lambda index, op: lax.switch(index, [false_fun, true_fun], op) - return new_checkpoint(fn)(index, op) + fn = lambda index, *args: lax.switch(index, [false_fun, true_fun], *args) + return jax.checkpoint(fn)(index, *args) COND_IMPLS = [ (lax.cond, 'cond'), @@ -76,33 +69,26 @@ def cond_with_new_checkpoint(pred, true_fun, false_fun, op, *args): # We wanted to try all scan tests with the scan partial evaluation rule that -# happens under ad_checkpoint.checkpoint, so we make a scan wrapper which -# wraps a ad_checkpoint.checkpoint around the computation. +# happens under jax.checkpoint, so we make a scan wrapper which +# wraps a jax.checkpoint around the computation. def scan_with_new_checkpoint(f, *args, **kwargs): - return new_checkpoint(partial(lax.scan, f, **kwargs), + return jax.checkpoint(partial(lax.scan, f, **kwargs), policy=checkpoint_policies.nothing_saveable)(*args) def scan_with_new_checkpoint2(f, *args, **kwargs): - return new_checkpoint(partial(lax.scan, f, **kwargs), + return jax.checkpoint(partial(lax.scan, f, **kwargs), policy=checkpoint_policies.everything_saveable)(*args) -def scan_with_for(f, *args, **kwargs): - return for_loop.scan(f, *args, **kwargs) - -def scan_with_remat_for(f, *args, **kwargs): - return jax.remat(lambda *args: for_loop.scan(f, *args, **kwargs))(*args) - SCAN_IMPLS_WITH_FOR = [ (lax.scan, 'unroll1'), + (partial(lax.scan, unroll=0), 'unroll0'), (partial(lax.scan, unroll=2), 'unroll2'), (partial(lax.scan, _split_transpose=True), 'split_transpose'), (scan_with_new_checkpoint , 'new_checkpoint'), (scan_with_new_checkpoint2, 'new_checkpoint2'), - (scan_with_for, 'for_loop'), - (scan_with_remat_for, 'for_loop_remat'), ] def while_loop_new_checkpoint(cond_fun, body_fun, init_val): - return new_checkpoint(partial(lax.while_loop, cond_fun, body_fun))(init_val) + return jax.checkpoint(partial(lax.while_loop, cond_fun, body_fun))(init_val) WHILE_LOOP_IMPLS = [ (lax.while_loop, 'while_loop'), @@ -137,15 +123,44 @@ def scan_reference(f, init, xs): lambda ctx, x: mlir.hlo.CustomCallOp( [x.type], [x], call_target_name=mlir.ir.StringAttr.get("__testing_non_existent_custom_call")).results) +batching.primitive_batchers[prim_non_existent_custom_call] = ( + lambda batched_args, batch_dims: (prim_non_existent_custom_call.bind(batched_args[0]), + batch_dims[0])) + +# A JAX primitive that triggers error when lowering on unintended platforms +prim_with_lowering_error = core.Primitive("__testing_prim_with_lowering_error") +prim_with_lowering_error.def_abstract_eval(lambda x_aval, **_: x_aval) +def prim_with_lowering_error_lowering(platform: str, + ctx: mlir.LoweringRuleContext, x, *, + only_on: str): + if platform != only_on: + raise ValueError(f"prim_with_lowering_error with only_on={only_on} lowered for {platform}") + return mlir.hlo.SineOp(x).results +def prim_with_lowering_error_batch_rule(batched_args, batch_dims, **params): + xs, = batched_args + xs_bdim, = batch_dims + return prim_with_lowering_error.bind(xs, **params), xs_bdim + +batching.primitive_batchers[prim_with_lowering_error] = prim_with_lowering_error_batch_rule + +mlir.register_lowering( + prim_with_lowering_error, + partial(prim_with_lowering_error_lowering, "cpu"), + platform="cpu") +mlir.register_lowering( + prim_with_lowering_error, + partial(prim_with_lowering_error_lowering, "tpu"), + platform="tpu") +prim_with_lowering_error.def_impl(partial(dispatch.apply_primitive, + prim_with_lowering_error)) class LaxControlFlowTest(jtu.JaxTestCase): def setUp(self): super().setUp() - lax_control_flow._initial_style_open_jaxpr.cache_clear() - lax_control_flow._initial_style_jaxpr.cache_clear() - lax_control_flow.common._pad_jaxpr_constvars.cache_clear() + lax_control_flow.common._dedup_consts.cache_clear() + lax_control_flow.common._pad_constvars.cache_clear() def testCallableErrors(self): not_callable = 42 @@ -588,7 +603,6 @@ def test_fori_loop_returns_init_with_nonpositive_length( init = jnp.float32(10) self.assertEqual(fori_loop_with_static_upper_and_lower(init), init) - def testForiLoopBatched(self): def body_fun(i, loop_carry): x, y = loop_carry @@ -972,8 +986,8 @@ def cfun(x): lax.lt(x, 2), lambda x: lax.mul(2, x), lambda x: cond(lax.lt(x, 5), - x, lambda x: lax.mul(3, x), - 4, lambda y: lax.mul(y, x)), + lambda x, _: lax.mul(3, x), + lambda _, y: lax.mul(y, x), x, 4), x) self.assertEqual(cfun(1), 2) @@ -994,16 +1008,24 @@ def testCondTypeErrors(self): with self.assertRaisesRegex(TypeError, re.escape("Pred must be a scalar, got (1.0, 1.0) of type ")): lax.cond((1., 1.), lambda top: 2., lambda fop: 3., 1.) - with self.assertRaisesRegex(TypeError, - re.compile("true_fun output must have same type structure " - "as false_fun output, but there are differences:.*" - r"at output\['a'\], true_fun output has pytree leaf", re.DOTALL)): + + with self.assertRaisesRegex( + TypeError, + re.compile( + r"cond branch outputs must have the same pytree structure, but they" + r" differ:.*true_fun output at path \['a'\] is a pytree leaf but" + r" false_fun output at path \['a'\] is a ", + re.DOTALL)): lax.cond(True, lambda top: dict(a=2.), lambda fop: dict(a=(3., 3.)), 1.) + with self.assertRaisesRegex( TypeError, - "true_fun output and false_fun output must have identical types, got\n" - r"DIFFERENT ShapedArray\(float32\[1\]\) vs. " - r"ShapedArray\(float32\[\].*\)."): + re.compile( + r"cond branches must have equal output types but they differ.*The" + r" output of true_fun has type float32\[1\] but the corresponding" + r" output of false_fun has type float32\[\], so the shapes do not" + r" match", + re.DOTALL)): lax.cond(True, lambda top: jnp.array([1.], jnp.float32), lambda fop: jnp.float32(1.), @@ -1023,16 +1045,26 @@ def testSwitchErrors(self): with self.assertRaisesRegex(ValueError, re.escape("Empty branch sequence")): lax.switch(0, [], 1.) - with self.assertRaisesRegex(TypeError, - re.compile("branch 0 output must have same type structure " - "as branch 1 output, but there are differences:.*" - r"at output\['a'\], branch 0 output has pytree leaf", re.DOTALL)): + + with self.assertRaisesRegex( + TypeError, + re.compile( + "switch branch outputs must have the same pytree structure, but" + r" they differ.*branch 0 output at path \['a'\] is a pytree leaf" + r" but branch1 output at path \['a'\] is a , so" + r" their" + " Python types differ.", + re.DOTALL)): lax.switch(1, [lambda _: dict(a=2.), lambda _: dict(a=(3., 3.))], 1.) + with self.assertRaisesRegex( TypeError, - "branch 0 output and branch 1 output must have identical types, got\n" - r"{'a': 'DIFFERENT ShapedArray\(float32\[1\]\) " - r"vs. ShapedArray\(float32\[\].*\)'}."): + re.compile( + "switch branches must have equal output types but they differ.*The" + r" output of branch 0 at path \['a'\] has type float32\[1\] but the" + r" corresponding output of branch1 has type float32\[\], so the" + " shapes do not match", + re.DOTALL)): lax.switch(1, [lambda _: dict(a=jnp.array([1.], jnp.float32)), lambda _: dict(a=jnp.float32(1.))], 1.) @@ -1075,9 +1107,9 @@ def cfun(x): def testCondBatched(self): def fun(x, y, z): pred = lax.lt(x, 3) - true_fun = lambda y: y - false_fun = lambda z: lax.neg(z) - return lax.cond(pred, y, true_fun, z, false_fun) + true_fun = lambda y, _: y + false_fun = lambda _, z: lax.neg(z) + return lax.cond(pred, true_fun, false_fun, y, z) # these cases stay as cond x = jnp.array(2) @@ -1241,7 +1273,7 @@ def fun_ref(x): return 2. * x def fun(x): - return cond(x < 3, None, lambda _: 2., x, lambda x: 2. * x) + return cond(x < 3, lambda _: 2., lambda x: 2. * x, x) x = 3.14 ans = jax.jvp(fun, (x,), (x,)) @@ -1309,8 +1341,36 @@ def f(x): self.assertAllClose(ans, expected, check_dtypes=False) jtu.check_grads(f, (x,), order=2, modes=["fwd", "rev"]) + @parameterized.parameters(itertools.product(range(4), repeat=3)) + @jtu.run_on_devices("cpu") + def testSwitchGradWithForwarding(self, seed, num_input_fwd, num_output_fwd): + num_args = 3 + num_branches = 4 + rng = np.random.RandomState(seed) + in_perm = rng.permutation(num_args) + out_perm = rng.permutation(num_args) + + def branch(s, inputs): + inputs = [inputs[i] for i in in_perm] + outputs = inputs[:num_input_fwd] + [ + s * jnp.exp(inputs[i]) if i < num_output_fwd else jnp.sin(inputs[i]) + for i in range(num_args - num_input_fwd)] + return [outputs[i] for i in out_perm] + + branches = [partial(branch, i) for i in range(num_branches)] + + @jax.jit + def f_(idx, inputs): + idx = lax.convert_element_type(idx // 1, np.int32) + return lax.switch(idx, branches, inputs) + + for idx in range(num_branches): + f = partial(f_, idx) + jtu.check_grads(f, (jnp.arange(float(num_args)),), + order=1, modes=['fwd', 'rev'], atol=1e-2, rtol=1e-2) + def testSwitchGradWithWeakTypeMismatch(self): # issue #4696, PR #4896 - dtype = dtypes.canonicalize_dtype(np.float64) + dtype = dtypes.default_float_dtype() dtype = jnp.float32 if dtype == jnp.float32 else jnp.float64 branches = [ @@ -1333,7 +1393,7 @@ def f(x): @parameterized.named_parameters( {"testcase_name": f"_{name}", "cond": cond} for cond, name in COND_IMPLS) - def testCondGrad2(self, cond): + def testCondGrad2(self, cond=cond_with_new_checkpoint): def f_ref(x): z = jnp.array([1., 2.], x.dtype) * x if x[0] < 2 else jnp.sin(x) return z.sum() @@ -1371,7 +1431,7 @@ def fun_ref(x): return 2. * x def fun(x): - return cond(x < 3, None, lambda _: 2., x, lambda x: 2. * x) + return cond(x < 3, lambda _: 2., lambda x: 2. * x, x) x = 3.14 ans = jax.grad(fun)(x) @@ -1401,8 +1461,9 @@ def fun_ref(x, y): def fun(x, y): return cond( x < 3, - None, lambda _: 2. * jnp.sin(y), - x, lambda x: 2. * x) + lambda _: 2. * jnp.sin(y), + lambda x: 2. * x, + x) y = 5.8 x = 3.14 @@ -1591,7 +1652,7 @@ def g(x): return jnp.where(x > 0, f_1(x), f_2(x)) def testIssue1263(self): def f(rng, x): cond = random.bernoulli(rng) - return lax.cond(cond, x, lambda x: x, jnp.abs(x) - 1., lambda x: x) + return lax.cond(cond, lambda x, _: x, lambda _, x: x, x, jnp.abs(x) - 1.) def body_fn(i, state): rng, x = state @@ -1606,8 +1667,9 @@ def g(rng, x): def testIssue514(self): # just check this doesn't crash lax.cond(True, - (0, 0), lambda x: (x[0], 0), - (1, 1), lambda x: x) + lambda x, _: (x[0], 0), + lambda _, x: x, + (0, 0), (1, 1)) def testIssue649(self): from jax import lax @@ -1725,15 +1787,17 @@ def f(c, a): c = rng.randn(4) if scan is scan_with_new_checkpoint2: - rtol = {np.float64: 1e-12, np.float32: 1e-4} - elif scan is scan_with_for: + atol = {} rtol = {np.float64: 1e-12, np.float32: 1e-4} else: + atol = {np.float64: 1e-14} rtol = {np.float64: 1e-14, np.float32: 1e-4} ans = jax.linearize(lambda c, as_: scan(f, c, as_), c, as_)[1](c, as_) expected = jax.linearize(lambda c, as_: scan_reference(f, c, as_), c, as_)[1](c, as_) - self.assertAllClose(ans, expected, check_dtypes=False, rtol=rtol) + self.assertAllClose( + ans, expected, check_dtypes=False, atol=atol, rtol=rtol + ) @parameterized.named_parameters( {"testcase_name": f"_{jit_scan=}_{jit_f=}_impl={scan_name}", @@ -1754,15 +1818,13 @@ def f(c, a): assert b.shape == () return c, b + on_gpu = jtu.device_under_test() == "gpu" if scan is scan_with_new_checkpoint: - rtol = {np.float32: 5e-5, np.float64: 1e-13} + rtol = {np.float32: 5e-5, np.float64: 1e-10 if on_gpu else 1e-13} atol = 1e-5 - elif scan is scan_with_for: - rtol = {np.float32: 2e-5, np.float64: 1e-13} - atol = {np.float32: 6e-2, np.float64: 1e-13} else: - rtol = {np.float32: 2e-4, np.float64: 1e-13} - atol = {np.float32: 8e-5, np.float64: 1e-13} + rtol = {np.float32: 2e-4, np.float64: 1e-10 if on_gpu else 1e-13} + atol = {np.float32: 8e-5, np.float64: 1e-10 if on_gpu else 1e-13} if jit_f: f = jax.jit(f) @@ -1776,7 +1838,7 @@ def f(c, a): expected = jax.grad(lambda c, as_: list(scan_reference(f, c, as_))[0].sum())(c, as_) self.assertAllClose(ans, expected, check_dtypes=False, rtol=rtol, atol=atol) - rtol = 5e-3 if scan is not scan_with_new_checkpoint2 else 5e-2 + rtol = 5e-1 if scan is not scan_with_new_checkpoint2 else 5e-2 atol = 5e-2 if jtu.test_device_matches(["tpu"]) else 1e-3 jtu.check_grads(partial(scan, f), (c, as_), order=2, modes=["rev"], atol=atol, rtol=rtol) @@ -1896,7 +1958,7 @@ def plus_one(p, iter_idx): def testScanBodyOutputError(self): with self.assertRaisesRegex( TypeError, - re.escape("scan body output must be a pair, got ShapedArray(float32[]).")): + re.escape("scan body output must be a pair, got float32[].")): lax.scan(lambda c, x: np.float32(0.), 0, jnp.arange(5.)) def testScanMetadataError(self): @@ -1955,7 +2017,7 @@ def testScanBodyCarryTypeMismatchErrors(self): with self.assertRaisesRegex( TypeError, re.escape("function carry input and carry output must have equal " - "types (e.g. shapes and dtypes of arrays), but they differ:\n\n" + "types, but they differ:\n\n" "The input carry x has type int32[] but the corresponding " "output carry component has type float32[], so the dtypes do " "not match" @@ -1966,7 +2028,7 @@ def testScanBodyCarryTypeMismatchErrors(self): with self.assertRaisesRegex( TypeError, re.escape("function carry input and carry output must have equal " - "types (e.g. shapes and dtypes of arrays), but they differ:\n\n" + "types, but they differ:\n\n" "The input carry component x[1] has type int32[] but the " "corresponding output carry component has type float32[], " "so the dtypes do not match" @@ -1977,13 +2039,13 @@ def testScanBodyCarryTypeMismatchErrors(self): with self.assertRaisesRegex( TypeError, re.escape("function carry input and carry output must have equal " - "types (e.g. shapes and dtypes of arrays), but they differ:\n\n" + "types, but they differ:\n\n" " * the input carry component x[0] has type int32[] but the " "corresponding output carry component has type float32[], " "so the dtypes do not match;\n" " * the input carry component x[1] has type int32[] but the " "corresponding output carry component has type float32[1,1], " - "so the dtypes do not match and also the shapes do not match." + "so the dtypes do not match, and the shapes do not match." )): jax.lax.scan(lambda x, _: ((x[0].astype('float32'), x[1].astype('float32').reshape(1, 1), @@ -1994,8 +2056,6 @@ def testScanBodyCarryTypeMismatchErrors(self): def testScanInvalidUnrollRaises(self): with self.assertRaisesRegex(ValueError, "`unroll` must be"): jax.lax.scan(lambda x, _: (x, x), 0, jnp.arange(5), unroll=-1) - with self.assertRaisesRegex(ValueError, "`unroll` must be"): - jax.lax.scan(lambda x, _: (x, x), 0, jnp.arange(5), unroll=0) @parameterized.named_parameters( {"testcase_name": f"_{scan_name}", @@ -2192,7 +2252,7 @@ def body(x): def test_caches_depend_on_axis_env(self): # https://github.com/jax-ml/jax/issues/9187 - scanned_f = lambda _, __: (lax.psum(1, 'i'), None) + scanned_f = lambda _, __: (lax.axis_size('i'), None) f = lambda: lax.scan(scanned_f, 0, None, length=1)[0] ans = jax.vmap(f, axis_name='i', axis_size=2, out_axes=None)() self.assertEqual(ans, 2) @@ -2317,8 +2377,9 @@ def testWhileGradError(self, loop: str = "fori_inside_scan"): elif loop == "fori_inside_cond": func = lambda x: lax.cond( True, - x, lambda x: lax.fori_loop(x, x + 2., lambda i, c: c, x), - 1., lambda x: x) + lambda x, _: lax.fori_loop(x, x + 2., lambda i, c: c * 2., x), + lambda _, x: x, + x, 1.) elif loop == "fori_inside_scan": func = lambda x: lax.scan( lambda c, x: (lax.fori_loop(x, x + 2., lambda i, c1: c1 * c, x), None), @@ -2410,12 +2471,19 @@ def body(i, x): too_big = 2 * jax.device_count() + if config.pmap_shmap_merge.value: + expected_regex = re.compile( + "cannot select an axis to squeeze out which has size not equal to " + r"one, got shape=\(\d,\) and dimensions=\(\d,\)" + ) + else: + expected_regex = re.escape( + "compiling computation `jit(scan)` that requires {} " + "replicas, but only {} XLA devices are available." + .format(too_big, jax.device_count())) + self.assertRaisesRegex( - ValueError, - re.escape( - "compiling computation `scan` that requires {} " - "replicas, but only {} XLA devices are available." - .format(too_big, jax.device_count())), + ValueError, expected_regex, lambda: f_loop(jnp.ones(too_big))) @parameterized.named_parameters( @@ -2467,7 +2535,7 @@ def f(c, a): self.assertLess(len(scan_unrolled_hlo), len(scan_fully_unrolled_hlo)) # and the lowering should contain a while loop, unless the scan is fully - # unrolled + # unrolled self.assertIn("while(", scan_hlo) self.assertIn("while(", scan_unrolled_hlo) self.assertNotIn("while(", scan_fully_unrolled_hlo) @@ -2483,7 +2551,7 @@ def f(h, _): def test_disable_jit_cond_with_vmap(self): # https://github.com/jax-ml/jax/issues/3093 def fn(t): - return lax.cond(t > 0, 0, lambda x: 0, 0, lambda x: 1) + return lax.cond(t > 0, lambda x, _: 0, lambda _, x: 1, 0, 0) fn = jax.vmap(fn) with jax.disable_jit(): @@ -2690,7 +2758,7 @@ def body_fun(val): {"testcase_name": f"{suffix}", "remat": remat} for suffix, remat in [ ('', None), - ('new_remat', new_checkpoint), + ('new_remat', jax.checkpoint), ]) def test_scan_vjp_forwards_extensive_residuals(self, remat): # https://github.com/jax-ml/jax/issues/4510 @@ -2704,10 +2772,13 @@ def cumprod(x): x = jnp.asarray(rng.randn(32, 2, 32).astype('float32')) _, vjp_fun = jax.vjp(cumprod, x) - # Need to spelunk into vjp_fun. This is fragile, and if it causes problems - # just skip this test and make an issue for mattjj. - *_, ext_res = vjp_fun.args[0].args[0] - self.assertIs(ext_res, x) + # TODO(mattjj): should we re-enable this check? The constants are now + # inlined in the Jaxprs, not easy to find them. + # ==> Yes, we don't want to change autodiff const behavior. We must make + # these tessts pass under use_simplified_jaxpr_constants. + if not config.use_simplified_jaxpr_constants.value: + ext_res, = vjp_fun.args_res + self.assertIs(ext_res, x) if remat is not None: # TODO(mattjj): make the numpy.ndarray test pass w/ remat @@ -2715,8 +2786,9 @@ def cumprod(x): x = rng.randn(32, 2, 32).astype('float32') # numpy.ndarray, not Array _, vjp_fun = jax.vjp(cumprod, x) - *_, ext_res = vjp_fun.args[0].args[0] - self.assertIsInstance(ext_res, jax.Array) + if not config.use_simplified_jaxpr_constants.value: + ext_res, *_ = vjp_fun.opaque_residuals + self.assertIsInstance(ext_res, jax.Array) def test_scan_vmap_collectives(self): def scan_f(state, x): @@ -2758,7 +2830,6 @@ def cond_fun(val): self.assertAllClose(deriv(my_pow)(3.0, 1), 1.0, check_dtypes=False) - def test_while_loop_fixed_point_with_batched_pred_and_consts(self): def f(i, x): def cond(carry): @@ -2861,18 +2932,13 @@ def f(x): x = np.arange(3, dtype=np.float32) lowered = jax.jit(f).lower(x) stablehlo = lowered.as_text() - self.assertIn("stablehlo.case", stablehlo) - self.assertIn("stablehlo.sine", stablehlo) - self.assertIn("stablehlo.cosine", stablehlo) - - # The HLO has been canonicalized and contains only the branch we need - hlo = lowered.as_text("hlo") + # The StableHLO contains only the branch we need if jtu.device_under_test() == "cpu": - self.assertIn(" sine", hlo) - self.assertNotIn(" cosine", hlo) + self.assertIn("stablehlo.sine", stablehlo) + self.assertNotIn("stablehlo.cosine", stablehlo) else: - self.assertNotIn(" sine", hlo) - self.assertIn(" cosine", hlo) + self.assertNotIn("stablehlo.sine", stablehlo) + self.assertIn("stablehlo.cosine", stablehlo) def test_platform_dependent_with_non_existent_custom_call(self): if not jtu.test_device_matches(["cpu"]): @@ -2895,8 +2961,7 @@ def f(x): x = np.arange(3, dtype=np.float32) hlo = str(jax.jit(f).lower(x).compiler_ir()) - occurrences = re.findall(prim_non_existent_custom_call.name, hlo) - self.assertLen(occurrences, 3) + self.assertNotIn(prim_non_existent_custom_call.name, hlo) res_eager = f(x) self.assertAllClose(res_eager, 3. * np.sin(x)) @@ -2912,6 +2977,26 @@ def f(x): res_grad = jax.grad(f)(1.) self.assertAllClose(res_grad, 3. * np.cos(1.)) + def test_platform_dependent_with_primitive_with_lowering_error(self): + if not jtu.test_device_matches(["cpu", "tpu"]): + self.skipTest("Only for CPU and TPU") + + def f(x): + return lax.platform_dependent( + x, + # Check that we only lower on the intended platform + cpu=lambda x: prim_with_lowering_error.bind(x, only_on="cpu"), + tpu=lambda x: prim_with_lowering_error.bind(x, only_on="tpu")) + + self.assertAllClose(np.sin(1.), f(1.)) # Eager + self.assertAllClose(np.sin(1.), jax.jit(f)(1.)) + self.assertAllClose(np.sin(1.), lax.cond(True, f, lambda x: x, 1.)) + self.assertAllClose(1., lax.cond(False, f, lambda x: x, 1.)) + self.assertAllClose((0., np.sin(np.arange(8.))), + lax.scan(lambda carry, x: (carry, f(x)), + 0., np.arange(8.))) + self.assertAllClose(np.sin(np.arange(8.)), jax.vmap(f)(np.arange(8.))) + def test_platform_dependent_multiple_identical_branches(self): x = np.arange(3, dtype=np.float32) def f(x): @@ -2921,13 +3006,14 @@ def f(x): tpu=jnp.sin, default=lambda x: x) res = f(x) + on_cpu_tpu = jtu.device_under_test() in ["cpu", "tpu"] self.assertAllClose( res, - np.sin(x) if jtu.device_under_test() in ["cpu", "tpu"] else x) - # We only lower the common branches once + np.sin(x) if on_cpu_tpu else x) + stablehlo = jax.jit(f).lower(x).as_text() sines = re.findall(r"stablehlo.sine", stablehlo) - self.assertEqual(1, len(sines)) + self.assertEqual(1 if on_cpu_tpu else 0, len(sines)) def test_platform_dependent_no_default(self): ctx = contextlib.ExitStack() @@ -2981,6 +3067,26 @@ def f(x): self.assertEqual(expect_a_dot, " dot(" in hlo) self.assertEqual(not expect_a_dot, " while(" in hlo) + def test_issue_29329(self): + + def outer_fn(x): + def inner_fn(x): + return jax.jit( + lambda x: lax.platform_dependent(x, + default=jnp.sin, + other=jnp.cos))(x) + + _, lin_fn = jax.linearize(inner_fn, x) + + def with_transpose(x): + grad = jax.linear_transpose(lin_fn, x)(x) + del grad + return x + + return jax.lax.cond(x[0][0] > 0., with_transpose, lambda x: x, x) + + jax.vmap(outer_fn)(jnp.ones((5, 10, 10))) + def test_scan_lowering_doesnt_introduce_singleton(self): b = 4 i = 2 @@ -3044,27 +3150,358 @@ def test_cond_casting(self): @jtu.thread_unsafe_test() # live_arrays count isn't thread-safe def test_cond_memory_leak(self): # https://github.com/jax-ml/jax/issues/12719 - def leak(): data = jax.device_put(np.zeros((1024), dtype=np.float32) + 1) def g(): - return jax.lax.cond( + return jax.lax.cond( True, - lambda: data[0], # noqa: F821 + jax.jit(lambda: data[0]), # noqa: F821 lambda: data[1], # noqa: F821 ) + # _ = g() # TODO(necula): enable this, requires fixing leaks in the + # caching of dispatch.xla_primitive_callable. jg = jax.jit(g) _ = jg().block_until_ready() + jg.clear_cache() del g, jg, data, _ + gc.collect() nbufs = lambda: len(jax.live_arrays()) + gc.collect() base = nbufs() leak() - self.assertEqual(base, nbufs()) + # You would hope for exact equality here, but you cannot entirely trust + # gc.collect() to collect everything immediately under a free threaded + # build. + self.assertGreaterEqual(base, nbufs()) leak() - self.assertEqual(base, nbufs()) + self.assertGreaterEqual(base, nbufs()) leak() - self.assertEqual(base, nbufs()) + self.assertGreaterEqual(base, nbufs()) + + def test_grad_remat_while_fixpoint(self): + @jax.remat + def f(x, y): + def cond(_): + return False + def body(c): + x, y = c + return (y, x) + x, y = jax.lax.while_loop(cond, body, (x, y)) + return x + y + jax.linearize(f, 1., 2.) # don't crash + + def test_while_readonly_carry_optimization(self): + # https://github.com/google/flax/issues/4700 + def foo(w, x, c_max): + def while_cond(val): + c, x, w = val + return c < c_max + + def while_body(val): + c, x, w = val + return c + 1, x @ w, w + + _, x, w = jax.lax.while_loop(while_cond, while_body, (0, x, w)) + return w, x + + w = jnp.ones((2, 2)) + xs = jnp.ones((4, 2)) + c_maxs = jnp.arange(4) + w_, _ = jax.vmap(foo, in_axes=(None, 0, 0), out_axes=(None, 0) + )(w, xs, c_maxs) # doesn't crash + self.assertAllClose(w, w_, check_dtypes=False) + + @parameterized.parameters(itertools.product(range(3), repeat=5)) + @jtu.run_on_devices("cpu") + def test_while_constification_correctness( + self, + seed, + num_body_consts, + num_inplace_fwds_cond_uses, + num_inplace_fwds_cond_doesnt_use, + num_noninplace_fwds): + + num_fwds = (num_inplace_fwds_cond_uses + num_inplace_fwds_cond_doesnt_use + + num_noninplace_fwds) + num_carry = num_fwds + 4 + + rng = np.random.RandomState(seed) + perm = rng.permutation(num_carry) + iperm = np.argsort(perm) + + body_consts = [rng.randn(3) for _ in range(num_body_consts)] + init_vals = list(rng.uniform(size=num_carry)) + + def cond_fun(c): + i, c = c + c = [c[i] for i in iperm] + c, _ = split_list(c, [num_inplace_fwds_cond_uses]) + return (i < 2) + (0. * jnp.array(sum(c))).astype(bool) + + def body_fun(c): + i, c = c + c = [c[i] for i in iperm] + inplace_fwds, noninplace_fwds, dont_fwd = split_list( + c, [num_inplace_fwds_cond_uses + num_inplace_fwds_cond_doesnt_use, + num_noninplace_fwds]) + dont_fwd = [jnp.sin(x) * sum(jnp.sum(c) for c in body_consts) + for x in dont_fwd] + new_c_perm = [*inplace_fwds, *dont_fwd, *noninplace_fwds] + new_c = [new_c_perm[i] for i in perm] + return (i + 1, new_c) + + i, outs = jax.lax.while_loop(cond_fun, body_fun, (0, init_vals)) + self.assertEqual(i, 2) + _, outs_ref = body_fun(body_fun((0, init_vals))) + self.assertAllClose(outs, outs_ref, check_dtypes=False) + + def test_while_constification_correctness_manually(self): + # regression test for a particular index-offset logic bug + + def cond_fun(c): + # cond doesn't use first or third element of the carry + _, i, _ = c + return i == 0 + + def body_fun(c): + # two body consts + for _ in range(2): jnp.sin(np.zeros(3)) + # first element of the carry is forwarded to third element of the carry + return 0., 1., c[0] + + outs = jax.lax.while_loop(cond_fun, body_fun, (5., 0., 3.14)) + self.assertAllClose(outs, (0., 1., 5.)) + + def test_scan_readonly_carry_optimization(self): + # https://github.com/google/flax/issues/4709 + def f(x, y): + def g(_, y): + y, _ = jax.lax.scan(lambda y, _: (y, None), y, None, length=1) + return y + return jax.lax.cond(x < 0, g, g, x, y) + xs = jnp.arange(3.) + y = 3. + jax.vmap(f, (0, None), None)(xs, y) # don't crash + + @parameterized.parameters(itertools.product(range(3), repeat=4)) + @jtu.run_on_devices("cpu") + def test_scan_constification_correctness( + self, + seed, + num_body_consts, + num_inplace_fwds, + num_noninplace_fwds): + + num_fwds = num_inplace_fwds + num_noninplace_fwds + num_carry = num_fwds + 4 + num_xs = 2 + num_ys = 3 + + rng = np.random.RandomState(seed) + perm = rng.permutation(num_carry) + iperm = np.argsort(perm) + + body_consts = [rng.randn(3) for _ in range(num_body_consts)] + init_vals = list(rng.uniform(size=num_carry)) + + def body_fun(c, _): + c = [c[i] for i in iperm] + inplace_fwds, noninplace_fwds, dont_fwd = split_list( + c, [num_inplace_fwds, num_noninplace_fwds]) + dont_fwd = [jnp.sin(x) * sum(jnp.sum(c) for c in body_consts) + for x in dont_fwd] + new_c_perm = [*inplace_fwds, *dont_fwd, *noninplace_fwds] + new_c = [new_c_perm[i] for i in perm] + return new_c, [0 for _ in range(num_ys)] + + xs = [jnp.arange(2.) for _ in range(num_xs)] + outs = jax.lax.scan(body_fun, init_vals, xs)[0] + outs_ref = body_fun(body_fun(init_vals, [x[0] for x in xs])[0], [x[1] for x in xs])[0] + self.assertAllClose(outs, outs_ref, check_dtypes=False) + + @parameterized.parameters(itertools.product(range(3), repeat=4)) + @jtu.run_on_devices("cpu") + def test_scan_forwarding_correctness( + self, + seed, + num_body_consts, + num_const_fwds, + num_input_fwds): + + num_carry = num_const_fwds + 4 + num_xs = num_input_fwds + 2 + num_ys = num_xs + 1 + + rng = np.random.RandomState(seed) + carry_perm = rng.permutation(num_carry) + carry_iperm = np.argsort(carry_perm) + + xs_perm = rng.permutation(num_xs) + ys_perm = rng.permutation(num_ys) + f = np.arange(num_xs) + f = [f[i] if idx < num_input_fwds else None for idx, i in enumerate(xs_perm)] + f += [None] + in_fwd = [f[i] for i in ys_perm] + + body_consts = [rng.randn(3) for _ in range(num_body_consts)] + init_vals = list(rng.uniform(size=num_carry)) + + def body_fun(c, x): + c = [c[i] for i in carry_iperm] + carry_fwds, carry_dont_fwd = split_list(c, [num_const_fwds]) + carry_dont_fwd = [jnp.sin(x) * sum(jnp.sum(c) for c in body_consts) + for x in carry_dont_fwd] + new_c_perm = [*carry_fwds, *carry_dont_fwd] + new_c = [new_c_perm[i] for i in carry_perm] + + x = [x[i] for i in xs_perm] + x_fwd, x_dont_fwd = split_list(x, [num_input_fwds]) + x_dont_fwd = [jnp.cos(x) * sum(jnp.sum(c) for c in body_consts) + for x in x_dont_fwd] + y = [*x_fwd, *x_dont_fwd, 0] + y = [y[i] for i in ys_perm] + + return new_c, y + + xs = list(rng.uniform(size=(num_xs, 2))) + final, outs = jax.lax.scan(body_fun, init_vals, xs) + for f, y in zip(in_fwd, outs): + if f is not None: + self.assertAllClose(y, xs[f]) + + final_ref = body_fun(body_fun(init_vals, [x[0] for x in xs])[0], [x[1] for x in xs])[0] + self.assertAllClose(final, final_ref, check_dtypes=False) + + def test_scan_diff_of_print(self): + # ref: https://github.com/jax-ml/jax/issues/28738 + def f(c, _): + jax.debug.print("c = {c}", c=c, ordered=True) + return c + 1, None + def g(x): + return jax.lax.scan(f, x, length=2)[0] + jaxpr = jax.make_jaxpr(jax.value_and_grad(g))(1.0) + eqn_jaxpr = jaxpr.eqns[0].params["jaxpr"] + self.assertIn("debug_print", [e.primitive.name for e in eqn_jaxpr.eqns]) + + def test_scan_input_to_output_forwarding(self): + def f(c, x): + return c + 1, x + def g(x): + return jax.lax.scan(f, 0, x) + jaxpr = jax.make_jaxpr(g)(jnp.arange(3.)) + self.assertLen(jaxpr.eqns[0].params["jaxpr"].jaxpr.outvars, 1) + + @jtu.sample_product( + seed=range(6), + num_rule_consts=range(6), + num_const_fwds=range(6), + num_carry_fwds=range(6), + num_input_fwds=range(6), + ) + @jtu.run_on_devices("cpu") + def test_scan_vjp_forwarding_correctness( + self, + seed, + num_rule_consts, + num_const_fwds, + num_carry_fwds, + num_input_fwds): + # Unlike test_scan_forwarding_correctness, which tests forwarding in the + # scan traceable, this test covers forwarding logic related to residuals in + # the scan partial eval / vjp rule. So 'forwards' refer to residuals that + # will be forwarded. + + # We use a custom_jvp where the jvp rule introduces consts to populate + # jaxpr.consts in _scan_partial_eval's input. + @jax.custom_jvp + def foo(x): + return 3. * x + @foo.defjvp + def foo_jvp(primals, tangents): + (x,), (x_dot,) = primals, tangents + if num_rule_consts: + coeff = sum([jnp.array(np.ones(3) / num_rule_consts) for _ in range(num_rule_consts)]) # noqa: C419 + else: + coeff = 1. + return foo(x), jnp.prod(coeff) * x_dot + + num_const = num_const_fwds + 2 + num_carry = num_carry_fwds + 4 + num_xs = num_input_fwds + 2 + num_ys = num_xs + 1 + + rng = np.random.RandomState(seed) + carry_perm = rng.permutation(num_carry) + carry_iperm = np.argsort(carry_perm) + + xs_perm = rng.permutation(num_xs) + ys_perm = rng.permutation(num_ys) + f = np.arange(num_xs) + f = [f[i] if idx < num_input_fwds else None for idx, i in enumerate(xs_perm)] + f += [None] + in_fwd = [f[i] for i in ys_perm] + + body_consts = [jnp.array(rng.randn(3)) for _ in range(num_const)] + init_vals = list(map(jnp.array, rng.uniform(size=(num_carry, 3)))) + + def body_fun(c, x): + c = [c[i] for i in carry_iperm] + + const_fwds, const_dont_fwd = split_list(body_consts, [num_const_fwds]) + z = sum(const_dont_fwd) + + carry_fwds, carry_dont_fwd = split_list(c, [num_const_fwds]) + carry_fwds = [math.prod([x, x, *const_fwds, z]) for x in carry_fwds] + carry_dont_fwd = [jnp.sin(x) * sum(jnp.sum(c) for c in body_consts) + for x in carry_dont_fwd] + new_c_perm = [*carry_fwds, *carry_dont_fwd] + new_c = [new_c_perm[i] for i in carry_perm] + new_c = [foo(new_c[0]), *new_c[1:]] + + x = [x[i] for i in xs_perm] + x_fwd, x_dont_fwd = split_list(x, [num_input_fwds]) + x_fwd = [x * x for x in x_fwd] + x_dont_fwd = [jnp.cos(x) * sum(jnp.sum(c) for c in body_consts) + for x in x_dont_fwd] + y = [*x_fwd, *x_dont_fwd, 0] + y = [y[i] for i in ys_perm] + + return new_c, y + + xs = list(map(jnp.array, rng.uniform(size=(num_xs, 2)))) + + (final, outs), vjp = jax.vjp(partial(jax.lax.scan, body_fun), init_vals, xs) + init_vals_bar, xs_bar = vjp((final, outs)) + + with jax.disable_jit(): + (final_ref, outs_ref), vjp = jax.vjp(partial(jax.lax.scan, body_fun), init_vals, xs) + init_vals_bar_ref, xs_bar_ref = vjp((final, outs)) + + self.assertAllClose(final, final_ref, check_dtypes=False, rtol=1e-5) + self.assertAllClose(outs, outs_ref, check_dtypes=False) + self.assertAllClose(xs_bar, xs_bar_ref, check_dtypes=False) + + def test_scan_fixpoint_instantiate(self): + def f(x): + c, () = jax.lax.scan(lambda c, _: ((0., 0.), ()), (x, 0.), (), length=5) + return sum(c) + jax.grad(f)(1.) # doesn't crash + + def test_cond_basic_vjp3(self): + def f(x): + return jax.lax.cond(True, jnp.sin, lambda x: x, x) + + _, f_vjp = jax.vjp(f, 1.) + g, = f_vjp(1.0) + self.assertAllClose(g, jnp.cos(1.), check_dtypes=False) + + def h(x): + return jax.lax.cond(True, jnp.sin, lambda x: 1., x) + + _, h_vjp = jax.vjp(h, 1.) + g, = h_vjp(1.0) + self.assertAllClose(g, jnp.cos(1.), check_dtypes=False) if __name__ == '__main__': diff --git a/tests/lax_metal_test.py b/tests/lax_metal_test.py deleted file mode 100644 index 5f1781c3be06..000000000000 --- a/tests/lax_metal_test.py +++ /dev/null @@ -1,5778 +0,0 @@ -# Copyright 2018 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from array import array as make_python_array -import collections -import copy -from functools import partial -import io -import itertools -import math -import platform -from typing import Union, cast -import unittest -from unittest import SkipTest - -from absl.testing import absltest -from absl.testing import parameterized - -import numpy as np -try: - import numpy_dispatch -except ImportError: - numpy_dispatch = None - -import jax -import jax.ops -from jax import lax -from jax import numpy as jnp -from jax.sharding import SingleDeviceSharding - -from jax._src import array -from jax._src import config -from jax._src import core -from jax._src import dtypes -from jax._src import test_util as jtu -from jax._src.lax import lax as lax_internal - -from jax._src.util import safe_zip, NumpyComplexWarning - -try: - from jax_plugins import metal_plugin -except ImportError: - metal_plugin = None - -config.parse_flags_with_absl() - -numpy_version = jtu.numpy_version() - -nonempty_nonscalar_array_shapes = [(4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)] -nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes -one_dim_array_shapes = [(1,), (6,), (12,)] -empty_array_shapes = [(0,), (0, 4), (3, 0),] -broadcast_compatible_shapes = [(), (1,), (3,), (1, 3), (4, 1), (4, 3)] - -scalar_shapes = [jtu.NUMPY_SCALAR_SHAPE, jtu.PYTHON_SCALAR_SHAPE] -array_shapes = nonempty_array_shapes + empty_array_shapes -nonzerodim_shapes = nonempty_nonscalar_array_shapes + empty_array_shapes -nonempty_shapes = scalar_shapes + nonempty_array_shapes -all_shapes = scalar_shapes + array_shapes - -float_dtypes = jtu.dtypes.all_floating -complex_dtypes = jtu.dtypes.complex -int_dtypes = jtu.dtypes.all_integer -unsigned_dtypes = jtu.dtypes.all_unsigned -bool_dtypes = jtu.dtypes.boolean -default_dtypes = float_dtypes + int_dtypes -inexact_dtypes = float_dtypes + complex_dtypes -number_dtypes = float_dtypes + complex_dtypes + int_dtypes + unsigned_dtypes -all_dtypes = number_dtypes + bool_dtypes - -NO_VALUE = object() - -python_scalar_dtypes = [jnp.bool_, jnp.int_, jnp.float_, jnp.complex_] - -# uint64 is problematic because with any uint type it promotes to float: -int_dtypes_no_uint64 = [d for d in int_dtypes + unsigned_dtypes if d != np.uint64] - -def np_unique_backport(ar, return_index=False, return_inverse=False, return_counts=False, - axis=None, **kwds): - # Wrapper for np.unique, handling the change to inverse_indices in numpy 2.0 - result = np.unique(ar, return_index=return_index, return_inverse=return_inverse, - return_counts=return_counts, axis=axis, **kwds) - if jtu.numpy_version() >= (2, 0, 0) or np.ndim(ar) == 1 or not return_inverse: - return result - - idx = 2 if return_index else 1 - inverse_indices = result[idx] - if axis is None: - inverse_indices = inverse_indices.reshape(np.shape(ar)) - else: - inverse_indices = np.expand_dims(inverse_indices, [i for i in range(np.ndim(ar)) if i != axis]) - return (*result[:idx], inverse_indices, *result[idx + 1:]) - - -def _indexer_with_default_outputs(indexer, use_defaults=True): - """Like jtu.with_jax_dtype_defaults, but for __getitem__ APIs""" - class Indexer: - @partial(jtu.with_jax_dtype_defaults, use_defaults=use_defaults) - def __getitem__(self, *args): - return indexer.__getitem__(*args) - return Indexer() - -def _valid_dtypes_for_shape(shape, dtypes): - # Not all (shape, dtype) pairs are valid. In particular, Python scalars only - # have one type in each category (float, bool, etc.) - if shape is jtu.PYTHON_SCALAR_SHAPE: - return [t for t in dtypes if t in python_scalar_dtypes] - return dtypes - -def _shape_and_dtypes(shapes, dtypes): - for shape in shapes: - for dtype in _valid_dtypes_for_shape(shape, dtypes): - yield (shape, dtype) - -def _compatible_shapes(shape): - if np.ndim(shape) == 0 or shape in scalar_shapes: - return [shape] - return (shape[n:] for n in range(len(shape) + 1)) - -OpRecord = collections.namedtuple( - "OpRecord", - ["name", "nargs", "dtypes", "shapes", "rng_factory", "diff_modes", - "test_name", "check_dtypes", "tolerance", "inexact", "kwargs"]) - -def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes, - test_name=None, check_dtypes=True, - tolerance=None, inexact=False, kwargs=None): - test_name = test_name or name - return OpRecord(name, nargs, dtypes, shapes, rng_factory, diff_modes, - test_name, check_dtypes, tolerance, inexact, kwargs) - - -JAX_ARGMINMAX_RECORDS = [ - op_record("argmin", 1, default_dtypes, nonempty_shapes, jtu.rand_some_equal, []), - op_record("argmax", 1, default_dtypes, nonempty_shapes, jtu.rand_some_equal, []), - op_record("nanargmin", 1, default_dtypes, nonempty_shapes, jtu.rand_some_nan, []), - op_record("nanargmax", 1, default_dtypes, nonempty_shapes, jtu.rand_some_nan, []), -] - -def _shapes_are_broadcast_compatible(shapes): - try: - lax.broadcast_shapes(*(() if s in scalar_shapes else s for s in shapes)) - except ValueError: - return False - else: - return True - -def _shapes_are_equal_length(shapes): - return all(len(shape) == len(shapes[0]) for shape in shapes[1:]) - -@unittest.skipIf(metal_plugin == None, "Tests require jax-metal plugin.") -class LaxBackedNumpyTests(jtu.JaxTestCase): - """Tests for LAX-backed Numpy implementation.""" - - def _GetArgsMaker(self, rng, shapes, dtypes, np_arrays=True): - def f(): - out = [rng(shape, dtype or jnp.float_) - for shape, dtype in zip(shapes, dtypes)] - if np_arrays: - return out - return [jnp.asarray(a) if isinstance(a, (np.ndarray, np.generic)) else a - for a in out] - return f - - @parameterized.parameters( - [dtype for dtype in [jnp.bool, jnp.uint8, jnp.uint16, jnp.uint32, - jnp.int8, jnp.int16, jnp.int32, jnp.int64, - jnp.float16, jnp.float32] - if dtype == dtypes.canonicalize_dtype(dtype)]) - def testDtypeWrappers(self, dtype): - arr = dtype(0) - self.assertIsInstance(arr, jax.Array) - self.assertEqual(arr.dtype, np.dtype(dtype)) - self.assertArraysEqual(arr, 0, check_dtypes=False) - - # No copy primitive is generated - jaxpr = jax.make_jaxpr(dtype)(0) - prims = [eqn.primitive for eqn in jaxpr.eqns] - self.assertEqual(prims, [lax.convert_element_type_p]) # No copy generated. - - def testBoolDtypeAlias(self): - self.assertIs(jnp.bool, jnp.bool_) - - @jtu.sample_product( - dtype=float_dtypes + [object], - allow_pickle=[True, False], - ) - def testLoad(self, dtype, allow_pickle): - if dtype == object and not allow_pickle: - self.skipTest("dtype=object requires allow_pickle=True") - rng = jtu.rand_default(self.rng()) - arr = rng((10), dtype) - with io.BytesIO() as f: - jnp.save(f, arr) - f.seek(0) - arr_out = jnp.load(f, allow_pickle=allow_pickle) - self.assertArraysEqual(arr, arr_out, allow_object_dtype=True) - - @unittest.skip("Jax-metal fail.") - def testArrayEqualExamples(self): - # examples from the array_equal() docstring. - self.assertTrue(jnp.array_equal([1, 2], [1, 2])) - self.assertTrue(jnp.array_equal(np.array([1, 2]), np.array([1, 2]))) - self.assertFalse(jnp.array_equal([1, 2], [1, 2, 3])) - self.assertFalse(jnp.array_equal([1, 2], [1, 4])) - - a = np.array([1, np.nan]) - self.assertFalse(jnp.array_equal(a, a)) - self.assertTrue(jnp.array_equal(a, a, equal_nan=True)) - - a = np.array([1 + 1j]) - b = a.copy() - a.real = np.nan - b.imag = np.nan - self.assertTrue(jnp.array_equal(a, b, equal_nan=True)) - - def testArrayEquivExamples(self): - # examples from the array_equiv() docstring. - self.assertTrue(jnp.array_equiv([1, 2], [1, 2])) - self.assertFalse(jnp.array_equiv([1, 2], [1, 3])) - with jax.numpy_rank_promotion('allow'): - self.assertTrue(jnp.array_equiv([1, 2], [[1, 2], [1, 2]])) - self.assertFalse(jnp.array_equiv([1, 2], [[1, 2, 1, 2], [1, 2, 1, 2]])) - self.assertFalse(jnp.array_equiv([1, 2], [[1, 2], [1, 3]])) - - def testArrayModule(self): - if numpy_dispatch is None: - raise SkipTest('requires https://github.com/seberg/numpy-dispatch') - - jnp_array = jnp.array(1.0) - np_array = np.array(1.0) - - module = numpy_dispatch.get_array_module(jnp_array) - self.assertIs(module, jnp) - - module = numpy_dispatch.get_array_module(jnp_array, np_array) - self.assertIs(module, jnp) - - def f(x): - module = numpy_dispatch.get_array_module(x) - self.assertIs(module, jnp) - return x - jax.jit(f)(jnp_array) - jax.grad(f)(jnp_array) - - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in all_shapes - for axis in list(range(-len(shape), len(shape)))], - discont=[None, "pi", 2], - period=["2pi", "pi"], - dtype=default_dtypes, - ) - def testUnwrap(self, shape, dtype, axis, discont, period): - special_vals = {"pi": np.pi, "2pi": 2 * np.pi} - period = special_vals.get(period, period) - discont = special_vals.get(discont, discont) - - rng = jtu.rand_default(self.rng()) - - def np_fun(x): - dtype = None - if x.dtype == dtypes.bfloat16: - dtype = x.dtype - x = x.astype(np.float32) - out = np.unwrap(x, axis=axis, discont=discont, period=period) - return out if dtype is None else out.astype(dtype) - - jnp_fun = partial(jnp.unwrap, axis=axis, discont=discont, period=period) - if not dtypes.issubdtype(dtype, np.inexact): - # This case requires implicit dtype promotion - jnp_fun = jax.numpy_dtype_promotion('standard')(jnp_fun) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, - atol={dtypes.bfloat16: 1e-1, np.float16: 1e-2}) - self._CompileAndCheck(jnp_fun, args_maker, atol={dtypes.bfloat16: 1e-1}) - - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in all_shapes - for axis in list(range(-len(shape), len(shape))) + [None]], - dtype=all_dtypes, - ) - def testCountNonzero(self, shape, dtype, axis): - rng = jtu.rand_some_zero(self.rng()) - np_fun = lambda x: np.count_nonzero(x, axis) - jnp_fun = lambda x: jnp.count_nonzero(x, axis) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product(shape=nonzerodim_shapes, dtype=all_dtypes) - def testNonzero(self, shape, dtype): - rng = jtu.rand_some_zero(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np.nonzero, jnp.nonzero, args_maker, check_dtypes=False) - - @jtu.sample_product( - [dict(shape=shape, fill_value=fill_value) - for shape in nonempty_array_shapes - for fill_value in [None, -1, shape or (1,)] - ], - dtype=all_dtypes, - size=[1, 5, 10], - ) - def testNonzeroSize(self, shape, dtype, size, fill_value): - rng = jtu.rand_some_zero(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - def np_fun(x): - result = np.nonzero(x) - if size <= len(result[0]): - return tuple(arg[:size] for arg in result) - else: - fillvals = fill_value if np.ndim(fill_value) else len(result) * [fill_value or 0] - return tuple(np.concatenate([arg, np.full(size - len(arg), fval, arg.dtype)]) - for fval, arg in safe_zip(fillvals, result)) - jnp_fun = lambda x: jnp.nonzero(x, size=size, fill_value=fill_value) - with jtu.ignore_warning(category=DeprecationWarning, - message="Calling nonzero on 0d arrays.*"): - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product(shape=all_shapes, dtype=all_dtypes) - def testFlatNonzero(self, shape, dtype): - rng = jtu.rand_some_zero(self.rng()) - np_fun = jtu.ignore_warning( - category=DeprecationWarning, - message="Calling nonzero on 0d arrays.*")(np.flatnonzero) - jnp_fun = jnp.flatnonzero - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) - - # JIT compilation requires specifying the size statically: - jnp_fun = lambda x: jnp.flatnonzero(x, size=np.size(x) // 2) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - shape=nonempty_array_shapes, - dtype=all_dtypes, - fill_value=[None, -1, 10, (-1,), (10,)], - size=[1, 5, 10], - ) - def testFlatNonzeroSize(self, shape, dtype, size, fill_value): - rng = jtu.rand_some_zero(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - @jtu.ignore_warning(category=DeprecationWarning, message="Calling nonzero on 0d arrays.*") - def np_fun(x): - result = np.flatnonzero(x) - if size <= len(result): - return result[:size] - else: - fill_val = fill_value or 0 - return np.concatenate([result, np.full(size - len(result), fill_val, result.dtype)]) - jnp_fun = lambda x: jnp.flatnonzero(x, size=size, fill_value=fill_value) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product(shape=nonzerodim_shapes, dtype=all_dtypes) - def testArgWhere(self, shape, dtype): - rng = jtu.rand_some_zero(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np.argwhere, jnp.argwhere, args_maker, check_dtypes=False) - - # JIT compilation requires specifying a size statically. Full test of this - # behavior is in testNonzeroSize(). - jnp_fun = lambda x: jnp.argwhere(x, size=np.size(x) // 2) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(shape=shape, fill_value=fill_value) - for shape in nonempty_array_shapes - for fill_value in [None, -1, shape or (1,)] - ], - dtype=all_dtypes, - size=[1, 5, 10], - ) - def testArgWhereSize(self, shape, dtype, size, fill_value): - rng = jtu.rand_some_zero(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - def np_fun(x): - result = np.argwhere(x) - if size <= len(result): - return result[:size] - else: - fillvals = fill_value if np.ndim(fill_value) else result.shape[-1] * [fill_value or 0] - return np.empty((size, 0), dtype=int) if np.ndim(x) == 0 else np.stack([np.concatenate([arg, np.full(size - len(arg), fval, arg.dtype)]) - for fval, arg in safe_zip(fillvals, result.T)]).T - jnp_fun = lambda x: jnp.argwhere(x, size=size, fill_value=fill_value) - - with jtu.ignore_warning(category=DeprecationWarning, - message="Calling nonzero on 0d arrays.*"): - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(np_op=getattr(np, rec.name), jnp_op=getattr(jnp, rec.name), - shape=shape, dtype=dtype, axis=axis, rng_factory=rec.rng_factory) - for rec in JAX_ARGMINMAX_RECORDS - for shape, dtype in _shape_and_dtypes(rec.shapes, rec.dtypes) - for axis in range(-len(shape), len(shape))], - keepdims=[False, True], - ) - def testArgMinMax(self, np_op, jnp_op, rng_factory, shape, dtype, axis, keepdims): - rng = rng_factory(self.rng()) - if dtype == np.complex128 and jtu.test_device_matches(["gpu"]): - raise unittest.SkipTest("complex128 reductions not supported on GPU") - if "nan" in np_op.__name__ and dtype == jnp.bfloat16: - raise unittest.SkipTest("NumPy doesn't correctly handle bfloat16 arrays") - kwds = {"keepdims": True} if keepdims else {} - - np_fun = jtu.with_jax_dtype_defaults(partial(np_op, axis=axis, **kwds)) - jnp_fun = partial(jnp_op, axis=axis, **kwds) - - args_maker = lambda: [rng(shape, dtype)] - try: - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - except ValueError as e: - if str(e) == "All-NaN slice encountered": - self.skipTest("JAX doesn't support checking for all-NaN slices") - else: - raise - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(name=rec.name, np_op=getattr(np, rec.name), - jnp_op=getattr(jnp, rec.name)) - for rec in JAX_ARGMINMAX_RECORDS], - ) - def testArgMinMaxEmpty(self, name, np_op, jnp_op): - name = name[3:] if name.startswith("nan") else name - msg = f"attempt to get {name} of an empty sequence" - with self.assertRaisesRegex(ValueError, msg): - jnp_op(np.array([])) - with self.assertRaisesRegex(ValueError, msg): - jnp_op(np.zeros((2, 0)), axis=1) - np_fun = jtu.with_jax_dtype_defaults(partial(np_op, axis=0)) - jnp_fun = partial(jnp_op, axis=0) - args_maker = lambda: [np.zeros((2, 0))] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, axes=axes) - for lhs_shape, rhs_shape, axes in [ - [(2,), (2,), (-1, -1, -1, None)], # scalar output - [(2, 4), (2, 4), (-1, -1, -1, 0)], # 2D vectors - [(3, 4), (3, 4), (-1, -1, -1, 0)], # 3D vectors - [(3, 4), (3, 6, 5, 4), (-1, -1, -1, 0)], # broadcasting - [(4, 3), (3, 6, 5, 4), (1, 0, -1, None)], # different axes - [(6, 1, 3), (5, 3), (-1, -1, -1, None)], # more broadcasting - [(6, 1, 2), (5, 3), (-1, -1, -1, None)], # mixed 2D and 3D vectors - [(10, 5, 2, 8), (1, 5, 1, 3), (-2, -1, -3, None)], # axes/broadcasting - [(4, 5, 2), (4, 5, 2), (-1, -1, 0, None)], # axisc should do nothing - [(4, 5, 2), (4, 5, 2), (-1, -1, -1, None)] # same as before - ]], - lhs_dtype=number_dtypes, - rhs_dtype=number_dtypes, - ) - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - def testCross(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, axes): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] - axisa, axisb, axisc, axis = axes - jnp_fun = lambda a, b: jnp.cross(a, b, axisa, axisb, axisc, axis) - # Note: 2D inputs to jnp.cross are deprecated in numpy 2.0. - @jtu.ignore_warning(category=DeprecationWarning, - message="Arrays of 2-dimensional vectors are deprecated.") - def np_fun(a, b): - a = a.astype(np.float32) if lhs_dtype == jnp.bfloat16 else a - b = b.astype(np.float32) if rhs_dtype == jnp.bfloat16 else b - out = np.cross(a, b, axisa, axisb, axisc, axis) - return out.astype(jnp.promote_types(lhs_dtype, rhs_dtype)) - tol_spec = {dtypes.bfloat16: 3e-1, np.float16: 0.15} - tol = max(jtu.tolerance(lhs_dtype, tol_spec), - jtu.tolerance(rhs_dtype, tol_spec)) - with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol) - self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol) - - @jtu.sample_product( - [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape) - for lhs_shape, rhs_shape in [ - ((3, 3), ()), - ((), (3, 3)), - ((4, 5), (5,)), - ((6,), (6, 4)), - ((3, 4), (4, 5)), - ((4, 3, 2), (2,)), - ((2,), (3, 2, 4)), - ((4, 3, 2), (2, 5)), - ((5, 2), (3, 2, 4)), - ((2, 3, 4), (5, 4, 1))]], - lhs_dtype=float_dtypes,#number_dtypes, - rhs_dtype=float_dtypes,#number_dtypes, - ) - @jax.default_matmul_precision("float32") - def testDot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] - tol = {np.float16: 1e-2, np.float32: 2e-5, np.float64: 1e-14, - np.complex128: 1e-14} - if (lhs_dtype in [np.float16, jnp.bfloat16] and - rhs_dtype in [np.float16, jnp.bfloat16]): - tol = 1e-2 - def np_dot(x, y): - x = x.astype(np.float32) if lhs_dtype == jnp.bfloat16 else x - y = y.astype(np.float32) if rhs_dtype == jnp.bfloat16 else y - return np.dot(x, y).astype(jnp.promote_types(lhs_dtype, rhs_dtype)) - with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): - self._CheckAgainstNumpy(np_dot, jnp.dot, args_maker, tol=tol) - self._CompileAndCheck(jnp.dot, args_maker, atol=tol, rtol=tol) - - @jtu.sample_product( - lhs_dtype=number_dtypes, - rhs_dtype=number_dtypes, - ) - @jax.numpy_dtype_promotion('standard') - def testMixedPrecisionDot(self, lhs_dtype, rhs_dtype): - # This test confirms that jnp.dot lowers to a single dot_general call, - # avoiding explicit type casting of inputs and outputs. - lhs = jax.ShapeDtypeStruct((5,), lhs_dtype) - rhs = jax.ShapeDtypeStruct((5,), rhs_dtype) - jaxpr = jax.make_jaxpr(jnp.dot)(lhs, rhs) - prims = [eqn.primitive for eqn in jaxpr.eqns] - self.assertIn(prims, [ - [lax.dot_general_p], - [lax.dot_general_p, lax.convert_element_type_p] - ]) - - @jtu.sample_product( - [dict(name=name, lhs_shape=lhs_shape, rhs_shape=rhs_shape) - for name, lhs_shape, rhs_shape in [ - ("vector-vector", (3,), (3,)), - ("matrix-vector", (3, 3), (3,)), - ("vector-matrix", (3,), (3, 3)), - ("matrix-matrix", (3, 3), (3, 3)), - ("vector-tensor", (3,), (5, 3, 2)), - ("tensor-vector", (5, 3, 2), (2,)), - ("matrix-tensor", (5, 2), (3, 2, 4)), - ("tensor-matrix", (5, 2, 3), (3, 2)), - ("tensor-tensor", (5, 3, 4), (5, 4, 1)), - ("tensor-tensor-broadcast", (3, 1, 3, 4), (5, 4, 1))]], - lhs_dtype=float_dtypes, #number_dtypes, - rhs_dtype=float_dtypes, #number_dtypes, - ) - @jax.default_matmul_precision("float32") - def testMatmul(self, name, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype): - rng = jtu.rand_default(self.rng()) - def np_fun(x, y): - dtype = jnp.promote_types(lhs_dtype, rhs_dtype) - return np.matmul(x, y).astype(dtype) - args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] - tol = {np.float16: 1e-2, np.float32: 2e-2, np.float64: 1e-12, - np.complex128: 1e-12} - - with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): - self._CheckAgainstNumpy(np_fun, jnp.matmul, args_maker, tol=tol) - self._CompileAndCheck(jnp.matmul, args_maker, atol=tol, rtol=tol) - - @jtu.sample_product( - lhs_batch=broadcast_compatible_shapes, - rhs_batch=broadcast_compatible_shapes, - axis_size=[2, 4], - axis=range(-2, 2), - dtype=float_dtypes,#number_dtypes, - ) - @jax.default_matmul_precision("float32") - @jax.numpy_rank_promotion('allow') # adopt PR#22316 - def testVecdot(self, lhs_batch, rhs_batch, axis_size, axis, dtype): - # Construct vecdot-compatible shapes. - size = min(len(lhs_batch), len(rhs_batch)) - axis = int(np.clip(axis, -size - 1, size)) - if axis >= 0: - lhs_shape = (*lhs_batch[:axis], axis_size, *lhs_batch[axis:]) - rhs_shape = (*rhs_batch[:axis], axis_size, *rhs_batch[axis:]) - else: - laxis = axis + len(lhs_batch) + 1 - lhs_shape = (*lhs_batch[:laxis], axis_size, *lhs_batch[laxis:]) - raxis = axis + len(rhs_batch) + 1 - rhs_shape = (*rhs_batch[:raxis], axis_size, *rhs_batch[raxis:]) - - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] - @jtu.promote_like_jnp - def np_fn(x, y, axis=axis): - f = jtu.numpy_vecdot if jtu.numpy_version() < (2, 0, 0) else np.vecdot - return f(x, y, axis=axis).astype(x.dtype) - jnp_fn = partial(jnp.vecdot, axis=axis) - tol = {np.float16: 1e-2, np.float32: 1E-3, np.float64: 1e-12, - np.complex64: 1E-3, np.complex128: 1e-12} - self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol) - self._CompileAndCheck(jnp_fn, args_maker, tol=tol) - - @jtu.sample_product( - [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, axes=axes) - for lhs_shape, rhs_shape, axes in [ - [(3,), (), 0], - [(2, 3, 4), (5, 6, 7), 0], # from issue #740 - [(2, 3, 4), (3, 4, 5, 6), 2], - [(2, 3, 4), (5, 4, 3, 6), [1, 2]], - [(2, 3, 4), (5, 4, 3, 6), [[1, 2], [2, 1]]], - [(1, 2, 3, 4), (4, 5, 3, 6), [[2, 3], [2, 0]]], - ]], - lhs_dtype=float_dtypes,#number_dtypes, - rhs_dtype=float_dtypes,#number_dtypes, - ) - @jax.default_matmul_precision("float32") - def testTensordot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, axes): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] - jnp_fun = lambda a, b: jnp.tensordot(a, b, axes) - def np_fun(a, b): - a = a if lhs_dtype != jnp.bfloat16 else a.astype(np.float32) - b = b if rhs_dtype != jnp.bfloat16 else b.astype(np.float32) - dtype = jnp.promote_types(lhs_dtype, rhs_dtype) - return np.tensordot(a, b, axes).astype(dtype) - tol = {np.float16: 1e-1, np.float32: 1e-3, np.float64: 1e-12, - np.complex64: 1e-3, np.complex128: 1e-12} - - with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol) - self._CompileAndCheck(jnp_fun, args_maker, tol=tol) - - def testTensordotErrors(self): - a = self.rng().random((3, 2, 2)) - b = self.rng().random((2,)) - self.assertRaisesRegex( - TypeError, "Number of tensordot axes.*exceeds input ranks.*", - lambda: jnp.tensordot(a, b, axes=2)) - - self.assertRaisesRegex( - TypeError, "tensordot requires axes lists to have equal length.*", - lambda: jnp.tensordot(a, b, axes=([0], [0, 1]))) - - self.assertRaisesRegex( - TypeError, "tensordot requires both axes lists to be either ints, tuples or lists.*", - lambda: jnp.tensordot(a, b, axes=('bad', 'axes'))) - - self.assertRaisesRegex( - TypeError, "tensordot axes argument must be an int, a pair of ints, or a pair of lists.*", - lambda: jnp.tensordot(a, b, axes='badaxes')) - - @jtu.sample_product( - element_shape=all_shapes, - test_shape=all_shapes, - dtype=default_dtypes, - invert=[False, True], - ) - def testIsin(self, element_shape, test_shape, dtype, invert): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(element_shape, dtype), rng(test_shape, dtype)] - jnp_fun = lambda e, t: jnp.isin(e, t, invert=invert) - np_fun = lambda e, t: np.isin(e, t, invert=invert) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - dtype1=[s for s in default_dtypes if s != jnp.bfloat16], - dtype2=[s for s in default_dtypes if s != jnp.bfloat16], - shape1=all_shapes, - shape2=all_shapes, - ) - def testSetdiff1d(self, shape1, shape2, dtype1, dtype2): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] - - with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): - self._CheckAgainstNumpy(np.setdiff1d, jnp.setdiff1d, args_maker) - - @unittest.skip("JAx-metal fail.") - @jtu.sample_product( - dtype1=[s for s in default_dtypes if s != jnp.bfloat16], - dtype2=[s for s in default_dtypes if s != jnp.bfloat16], - shape1=all_shapes, - shape2=all_shapes, - size=[1, 5, 10], - fill_value=[None, -1], - ) - def testSetdiff1dSize(self, shape1, shape2, dtype1, dtype2, size, fill_value): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] - def np_fun(arg1, arg2): - result = np.setdiff1d(arg1, arg2) - if size <= len(result): - return result[:size] - else: - return np.pad(result, (0, size-len(result)), constant_values=fill_value or 0) - def jnp_fun(arg1, arg2): - return jnp.setdiff1d(arg1, arg2, size=size, fill_value=fill_value) - with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - dtype1=[s for s in default_dtypes if s != jnp.bfloat16], - dtype2=[s for s in default_dtypes if s != jnp.bfloat16], - shape1=nonempty_nonscalar_array_shapes, - shape2=nonempty_nonscalar_array_shapes, - ) - def testUnion1d(self, shape1, shape2, dtype1, dtype2): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] - def np_fun(arg1, arg2): - dtype = jnp.promote_types(arg1.dtype, arg2.dtype) - return np.union1d(arg1, arg2).astype(dtype) - with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): - self._CheckAgainstNumpy(np_fun, jnp.union1d, args_maker) - - @unittest.skip("Jax-metal fail.") - @jtu.sample_product( - dtype1=[s for s in default_dtypes if s != jnp.bfloat16], - dtype2=[s for s in default_dtypes if s != jnp.bfloat16], - shape1=nonempty_nonscalar_array_shapes, - shape2=nonempty_nonscalar_array_shapes, - size=[1, 5, 10], - fill_value=[None, -1], - ) - def testUnion1dSize(self, shape1, shape2, dtype1, dtype2, size, fill_value): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] - def np_fun(arg1, arg2): - dtype = jnp.promote_types(arg1.dtype, arg2.dtype) - result = np.union1d(arg1, arg2).astype(dtype) - fv = result.min() if fill_value is None else fill_value - if size <= len(result): - return result[:size] - else: - return np.concatenate([result, np.full(size - len(result), fv, result.dtype)]) - def jnp_fun(arg1, arg2): - return jnp.union1d(arg1, arg2, size=size, fill_value=fill_value) - with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - dtype1=[s for s in default_dtypes if s != jnp.bfloat16], - dtype2=[s for s in default_dtypes if s != jnp.bfloat16], - shape1=all_shapes, - shape2=all_shapes, - assume_unique=[False, True], - ) - def testSetxor1d(self, shape1, dtype1, shape2, dtype2, assume_unique): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] - jnp_fun = lambda ar1, ar2: jnp.setxor1d(ar1, ar2, assume_unique=assume_unique) - def np_fun(ar1, ar2): - if assume_unique: - # pre-flatten the arrays to match with jax implementation - ar1 = np.ravel(ar1) - ar2 = np.ravel(ar2) - return np.setxor1d(ar1, ar2, assume_unique) - with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) - - @jtu.sample_product( - dtype1=[s for s in default_dtypes if s != jnp.bfloat16], - dtype2=[s for s in default_dtypes if s != jnp.bfloat16], - shape1=all_shapes, - shape2=all_shapes, - assume_unique=[False, True], - return_indices=[False, True], - ) - def testIntersect1d(self, shape1, dtype1, shape2, dtype2, assume_unique, - return_indices): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] - jnp_fun = lambda ar1, ar2: jnp.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices) - np_fun = lambda ar1, ar2: np.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices) - with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) - - @jtu.sample_product( - [dict(lhs_shape=lhs_shape, lhs_dtype=lhs_dtype, - rhs_shape=rhs_shape, rhs_dtype=rhs_dtype) - # TODO(phawkins): support integer dtypes too. - for lhs_shape, lhs_dtype in _shape_and_dtypes(all_shapes, inexact_dtypes) - for rhs_shape, rhs_dtype in _shape_and_dtypes(all_shapes, inexact_dtypes) - if len(jtu._dims_of_shape(lhs_shape)) == 0 - or len(jtu._dims_of_shape(rhs_shape)) == 0 - or lhs_shape[-1] == rhs_shape[-1]], - ) - @jax.default_matmul_precision("float32") - def testInner(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] - def np_fun(lhs, rhs): - lhs = lhs if lhs_dtype != jnp.bfloat16 else lhs.astype(np.float32) - rhs = rhs if rhs_dtype != jnp.bfloat16 else rhs.astype(np.float32) - dtype = jnp.promote_types(lhs_dtype, rhs_dtype) - return np.inner(lhs, rhs).astype(dtype) - jnp_fun = lambda lhs, rhs: jnp.inner(lhs, rhs) - tol_spec = {np.float16: 1e-2, np.float32: 1e-5, np.float64: 1e-13, - np.complex64: 1e-5} - tol = max(jtu.tolerance(lhs_dtype, tol_spec), - jtu.tolerance(rhs_dtype, tol_spec)) - # TODO(phawkins): there are float32/float64 disagreements for some inputs. - with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, tol=tol) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=False, atol=tol, rtol=tol) - - @unittest.skip("MLIR translation rule for primitive 'eigh' not found for platform METAL.") - @jtu.sample_product( - dtype=[dt for dt in float_dtypes if dt not in [jnp.float16, jnp.bfloat16]], - shape=[shape for shape in one_dim_array_shapes if shape != (1,)], - deg=[1, 2, 3], - rcond=[None, -1, 10e-3, 10e-5, 10e-10], - full=[False, True], - w=[False, True], - cov=[False, True, "unscaled"], - ) - @jax.default_matmul_precision("float32") - def testPolyfit(self, shape, dtype, deg, rcond, full, w, cov): - rng = jtu.rand_default(self.rng()) - tol_spec = {np.float32: 1e-3, np.float64: 1e-13, np.complex64: 1e-5} - tol = jtu.tolerance(dtype, tol_spec) - _w = lambda a: abs(a) if w else None - args_maker = lambda: [rng(shape, dtype), rng(shape, dtype), rng(shape, dtype)] - jnp_fun = lambda x, y, a: jnp.polyfit(x, y, deg=deg, rcond=rcond, full=full, w=_w(a), cov=cov) - np_fun = jtu.ignore_warning( - message="Polyfit may be poorly conditioned*")(lambda x, y, a: np.polyfit(x, y, deg=deg, rcond=rcond, full=full, w=_w(a), cov=cov)) - - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=False, atol=tol, rtol=tol) - - args = args_maker() - if not full: - args = args_maker() - try: - np_out = np_fun(*args) - except ValueError: - return # https://github.com/numpy/numpy/issues/22380 - jnp_out = jnp_fun(*args) - self.assertAllClose(np_out, jnp_out, atol=tol, rtol=tol, - check_dtypes=False) - else: - # Don't compare the residuals because jnp.linalg.lstsq acts slightly - # differently to remain `jit`-compatible. - np_p, _, nrank, nsingular_values, nrcond = np_fun(*args) - jp_p, _, jrank, jsingular_values, jrcond = jnp_fun(*args) - self.assertAllClose( - (np_p, nrank, nsingular_values, nrcond), - (jp_p, jrank, jsingular_values, jrcond), - atol=tol, rtol=tol, check_dtypes=False) - - @jtu.sample_product( - [dict(a_min=a_min, a_max=a_max) - for a_min, a_max in [(-1, None), (None, 1), (-0.9, 1), - (-np.ones(1), None), - (None, np.ones(1)), - (np.full(1, -0.9), np.ones(1))] - ], - shape=all_shapes, - dtype=number_dtypes, - ) - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - @jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion - def testClipStaticBounds(self, shape, dtype, a_min, a_max): - if np.issubdtype(dtype, np.unsignedinteger): - a_min = None if a_min is None else abs(a_min) - a_max = None if a_max is None else abs(a_max) - rng = jtu.rand_default(self.rng()) - np_fun = lambda x: np.clip(x, a_min=a_min, a_max=a_max) - jnp_fun = lambda x: jnp.clip(x, min=a_min, max=a_max) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(shape=shape, dtype=dtype) - for shape, dtype in _shape_and_dtypes(all_shapes, number_dtypes)], - decimals=[0, 1, -2], - ) - def testRoundStaticDecimals(self, shape, dtype, decimals): - rng = jtu.rand_default(self.rng()) - if jnp.issubdtype(dtype, np.integer) and decimals < 0: - self.skipTest("Integer rounding with decimals < 0 not implemented") - np_fun = lambda x: np.round(x, decimals=decimals) - jnp_fun = lambda x: jnp.round(x, decimals=decimals) - args_maker = lambda: [rng(shape, dtype)] - tol = {jnp.bfloat16: 5e-2, np.float16: 1e-2} - check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, - check_dtypes=check_dtypes, tol=tol) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=check_dtypes, - atol=tol, rtol=tol) - - @jtu.sample_product(jit=[False, True]) - def testOperatorRound(self, jit): - jround = jax.jit(round, static_argnums=1) if jit else round - self.assertAllClose(round(np.float32(7.532), 1), - jround(jnp.float32(7.5), 1)) - self.assertAllClose(round(np.float32(1.234), 2), - jround(jnp.float32(1.234), 2)) - self.assertAllClose(round(np.float32(1.234)), - jround(jnp.float32(1.234)), check_dtypes=False) - self.assertAllClose(round(np.float32(7.532), 1), - jround(jnp.array(7.5, jnp.float32), 1)) - self.assertAllClose(round(np.float32(1.234), 2), - jround(jnp.array(1.234, jnp.float32), 2)) - self.assertAllClose(round(np.float32(1.234)), - jround(jnp.array(1.234, jnp.float32)), - check_dtypes=False) - - def testRoundMethod(self): - # https://github.com/jax-ml/jax/issues/15190 - (jnp.arange(3.) / 5.).round() # doesn't crash - - @jtu.sample_product(shape=[(5,), (5, 2)]) - def testOperatorReversed(self, shape): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, 'float32')] - np_fun = lambda x: np.array(list(reversed(x))) - jnp_fun = lambda x: jnp.array(list(reversed(x))) - - self._CompileAndCheck(jnp_fun, args_maker) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - - @jtu.sample_product( - [dict(mode=mode, shape=shape, dtype=dtype, - pad_width=pad_width, constant_values=constant_values) - for mode, shapes in [ - ('constant', all_shapes), - ('wrap', nonempty_shapes), - ('edge', nonempty_shapes), - ] - for shape, dtype in _shape_and_dtypes(shapes, all_dtypes) - for constant_values in [ - # None is used for modes other than 'constant' - None, - # constant - 0, 1, - # (constant,) - (0,), (2.718,), - # ((before_const, after_const),) - ((0, 2),), ((-1, 3.14),), - # ((before_1, after_1), ..., (before_N, after_N)) - tuple((i / 2, -3.14 * i) for i in range(len(shape))), - ] - for pad_width in [ - # ((before_1, after_1), ..., (before_N, after_N)) - tuple((i % 3, (i + 1) % 3) for i in range(len(shape))), - # ((before, after),) - ((1, 2),), ((2, 0),), - # (before, after) (not in the docstring but works in numpy) - (2, 0), (0, 0), - # (pad,) - (1,), (2,), - # pad - 0, 1, - ] - if (pad_width != () and constant_values != () and - ((mode == 'constant' and constant_values is not None) or - (mode != 'constant' and constant_values is None)))], - ) - def testPad(self, shape, dtype, mode, pad_width, constant_values): - if np.issubdtype(dtype, np.unsignedinteger): - constant_values = jax.tree.map(abs, constant_values) - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - if constant_values is None: - np_fun = partial(np.pad, pad_width=pad_width, mode=mode) - jnp_fun = partial(jnp.pad, pad_width=pad_width, mode=mode) - else: - np_fun = partial(np.pad, pad_width=pad_width, mode=mode, - constant_values=constant_values) - jnp_fun = partial(jnp.pad, pad_width=pad_width, mode=mode, - constant_values=constant_values) - - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, - check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(mode=mode, shape=shape, dtype=dtype, - pad_width=pad_width, stat_length=stat_length) - for mode in ['maximum', 'minimum', 'mean', 'median'] - for shape, dtype in _shape_and_dtypes(nonempty_shapes, all_dtypes) - for pad_width in [ - # ((before_1, after_1), ..., (before_N, after_N)) - tuple((i % 3, (i + 1) % 3) for i in range(len(shape))), - # ((before, after),) - ((1, 2),), ((2, 0),), - # (before, after) (not in the docstring but works in numpy) - (2, 0), (0, 0), - # (pad,) - (1,), (2,), - # pad - 0, 1, - ] - for stat_length in [ - None, - # ((before_1, after_1), ..., (before_N, after_N)) - tuple(((i % 3 + 1), ((i + 1) % 3) + 1) for i in range(len(shape))), - # ((before, after),) - ((1, 2),), ((2, 2),), - # (before, after) (not in the docstring but works in numpy) - (1, 1), (3, 4), - # (pad,) - (1,), (2,), - # pad - 1, 2 - ] - if (pad_width != () and stat_length != () and - not (dtype in bool_dtypes and mode == 'mean'))], - ) - def testPadStatValues(self, shape, dtype, mode, pad_width, stat_length): - if mode == 'median' and np.issubdtype(dtype, np.complexfloating): - self.skipTest("median statistic is not supported for dtype=complex.") - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - - np_fun = partial(np.pad, pad_width=pad_width, mode=mode, stat_length=stat_length) - jnp_fun = partial(jnp.pad, pad_width=pad_width, mode=mode, stat_length=stat_length) - - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, - check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(shape=shape, dtype=dtype, - pad_width=pad_width, reflect_type=reflect_type) - for shape, dtype in _shape_and_dtypes(nonempty_shapes, all_dtypes) - for pad_width in [ - # ((before_1, after_1), ..., (before_N, after_N)) - tuple((i % 3, (i + 1) % 3) for i in range(len(shape))), - # ((before, after),) - ((1, 2),), ((2, 3),), - # (before, after) (not in the docstring but works in numpy) - (2, 1), (1, 2), - # (pad,) - (1,), (2,), (3,), - # pad - 0, 5, 7, 10 - ] - for reflect_type in ['even', 'odd'] - if (pad_width != () and - # following types lack precision when calculating odd values - (reflect_type != 'odd' or dtype not in [np.bool_, np.float16, jnp.bfloat16]))], - mode=['symmetric', 'reflect'] - ) - def testPadSymmetricAndReflect(self, shape, dtype, mode, pad_width, reflect_type): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - - np_fun = partial(np.pad, pad_width=pad_width, mode=mode, reflect_type=reflect_type) - jnp_fun = partial(jnp.pad, pad_width=pad_width, mode=mode, reflect_type=reflect_type) - - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, - check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE, - tol={np.float32: 1e-3, np.complex64: 1e-3}) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(shape=shape, dtype=dtype, pad_width=pad_width, end_values=end_values) - for shape, dtype in _shape_and_dtypes(nonempty_shapes, default_dtypes + complex_dtypes) - for pad_width in [ - # ((before_1, after_1), ..., (before_N, after_N)) - tuple((i % 3, (i + 1) % 3) for i in range(len(shape))), - # ((before, after),) - ((1, 2),), ((2, 0),), - # (before, after) (not in the docstring but works in numpy) - (2, 0), (0, 0), - # (pad,) - (1,), (2,), - # pad - 0, 1, - ] - for end_values in [ - # ((before_1, after_1), ..., (before_N, after_N)) - tuple((i % 3, (i + 1) % 3) for i in range(len(shape))), - # ((before, after),) - ((1, 2),), ((2.0, 3.14),), - # (before, after) (not in the docstring but works in numpy) - (0, 0), (-8.0, 2.0), - # (end_values,) - (1,), (2,), - # end_values - 0, 1, 100, 10.0, 3.5, 4.2, -5, -3 - ] - if (pad_width != () and end_values != () and - # following types lack precision - dtype not in [np.int8, np.int16, np.float16, jnp.bfloat16])], - ) - def testPadLinearRamp(self, shape, dtype, pad_width, end_values): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - - np_fun = partial(np.pad, pad_width=pad_width, mode="linear_ramp", - end_values=end_values) - jnp_fun = partial(jnp.pad, pad_width=pad_width, mode="linear_ramp", - end_values=end_values) - - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, - check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE) - self._CompileAndCheck(jnp_fun, args_maker) - - def testPadEmpty(self): - arr = np.arange(6).reshape(2, 3) - - pad_width = ((2, 3), (3, 1)) - np_res = np.pad(arr, pad_width=pad_width, mode="empty") - jnp_res = jnp.pad(arr, pad_width=pad_width, mode="empty") - - np.testing.assert_equal(np_res.shape, jnp_res.shape) - np.testing.assert_equal(arr, np_res[2:-3, 3:-1]) - np.testing.assert_equal(arr, jnp_res[2:-3, 3:-1]) - np.testing.assert_equal(np_res[2:-3, 3:-1], jnp_res[2:-3, 3:-1]) - - def testPadKwargs(self): - modes = { - 'constant': {'constant_values': 0}, - 'edge': {}, - 'linear_ramp': {'end_values': 0}, - 'maximum': {'stat_length': None}, - 'mean': {'stat_length': None}, - 'median': {'stat_length': None}, - 'minimum': {'stat_length': None}, - 'reflect': {'reflect_type': 'even'}, - 'symmetric': {'reflect_type': 'even'}, - 'wrap': {}, - 'empty': {} - } - arr = jnp.array([1, 2, 3]) - pad_width = 1 - - for mode in modes.keys(): - allowed = modes[mode] - not_allowed = {} - for kwargs in modes.values(): - if kwargs != allowed: - not_allowed.update(kwargs) - - # Test if allowed keyword arguments pass - jnp.pad(arr, pad_width, mode, **allowed) - # Test if prohibited keyword arguments of other modes raise an error - match = f"unsupported keyword arguments for mode '{mode}'" - for key, value in not_allowed.items(): - with self.assertRaisesRegex(ValueError, match): - jnp.pad(arr, pad_width, mode, **{key: value}) - - # Test if unsupported mode raise error. - unsupported_modes = [1, None, "foo"] - for mode in unsupported_modes: - match = f"Unimplemented padding mode '{mode}' for np.pad." - with self.assertRaisesRegex(NotImplementedError, match): - jnp.pad(arr, pad_width, mode) - - def testPadFunction(self): - def np_pad_with(vector, pad_width, iaxis, kwargs): - pad_value = kwargs.get('padder', 10) - vector[:pad_width[0]] = pad_value - vector[-pad_width[1]:] = pad_value - - def jnp_pad_with(vector, pad_width, iaxis, kwargs): - pad_value = kwargs.get('padder', 10) - vector = vector.at[:pad_width[0]].set(pad_value) - vector = vector.at[-pad_width[1]:].set(pad_value) - return vector - - arr = np.arange(6).reshape(2, 3) - np_res = np.pad(arr, 2, np_pad_with) - jnp_res = jnp.pad(arr, 2, jnp_pad_with) - np.testing.assert_equal(np_res, jnp_res) - - arr = np.arange(24).reshape(2, 3, 4) - np_res = np.pad(arr, 1, np_pad_with, padder=100) - jnp_res = jnp.pad(arr, 1, jnp_pad_with, padder=100) - np.testing.assert_equal(np_res, jnp_res) - - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(arr.shape, arr.dtype)] - jnp_fun = partial(jnp.pad, pad_width=1, mode=jnp_pad_with) - self._CompileAndCheck(jnp_fun, args_maker) - - def testPadWithNumpyPadWidth(self): - a = jnp.array([1, 2, 3, 4, 5]) - f = jax.jit( - partial( - jnp.pad, - pad_width=np.asarray((2, 3)), - mode="constant", - constant_values=(4, 6))) - - np.testing.assert_array_equal( - f(a), - np.pad( - a, - pad_width=np.asarray((2, 3)), - mode="constant", - constant_values=(4, 6))) - - def testPadWeakType(self): - x = jnp.array(1.0)[None] - for mode in ['constant', 'edge', 'linear_ramp', 'maximum', 'mean', 'median', - 'minimum', 'reflect', 'symmetric', 'wrap', 'empty']: - y = jnp.pad(x, 0, mode=mode) - self.assertTrue(dtypes.is_weakly_typed(y)) - - @jtu.sample_product( - [dict(shape=shape, dtype=dtype) - for shape, dtype in _shape_and_dtypes(all_shapes, default_dtypes)], - reps=[(), (2,), (3, 4), (2, 3, 4), (1, 0, 2)], - ) - def testTile(self, shape, dtype, reps): - rng = jtu.rand_default(self.rng()) - np_fun = lambda arg: np.tile(arg, reps) - jnp_fun = lambda arg: jnp.tile(arg, reps) - - args_maker = lambda: [rng(shape, dtype)] - - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, - check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product(shape=all_shapes, dtype=all_dtypes) - def testExtract(self, shape, dtype): - rng = jtu.rand_some_zero(self.rng()) - args_maker = lambda: [rng(shape, jnp.float32), rng(shape, dtype)] - self._CheckAgainstNumpy(np.extract, jnp.extract, args_maker) - - @jtu.sample_product( - [dict(ncond=ncond, nfunc=nfunc) - for ncond in [1, 2, 3] - for nfunc in [ncond, ncond + 1] - ], - shape=all_shapes, - dtype=all_dtypes) - def testPiecewise(self, shape, dtype, ncond, nfunc): - rng = jtu.rand_default(self.rng()) - rng_bool = jtu.rand_int(self.rng(), 0, 2) - funclist = [lambda x: x - 1, 1, lambda x: x, 0][:nfunc] - args_maker = lambda: (rng(shape, dtype), [rng_bool(shape, bool) for i in range(ncond)]) - np_fun = partial(np.piecewise, funclist=funclist) - jnp_fun = partial(jnp.piecewise, funclist=funclist) - - if dtype == np.bool_: - # The `x - 1` above uses type promotion. - jnp_fun = jax.numpy_dtype_promotion('standard')(jnp_fun) - - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - # This is a higher-order function, so the cache miss check will fail. - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, check_cache_misses=False) - - def testPiecewiseRecompile(self): - def g(x): - g.num_traces += 1 - return x - g.num_traces = 0 - x = jnp.arange(10.0) - for i in range(5): - jnp.piecewise(x, [x < 0], [g, 0.]) - self.assertEqual(g.num_traces, 1) - - @jtu.sample_product( - [dict(shape=shape, perm=perm) - for shape in array_shapes - for perm in [ - None, - tuple(np.random.RandomState(0).permutation(np.zeros(shape).ndim)), - tuple(np.random.RandomState(0).permutation( - np.zeros(shape).ndim) - np.zeros(shape).ndim) - ] - ], - dtype=default_dtypes, - arg_type=["splat", "value"], - ) - def testTransposeTuple(self, shape, dtype, perm, arg_type): - rng = jtu.rand_some_zero(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - if arg_type == "value": - np_fun = lambda x: x.transpose(perm) - jnp_fun = lambda x: jnp.array(x).transpose(perm) - else: - np_fun = lambda x: x.transpose(*(perm or ())) - jnp_fun = lambda x: jnp.array(x).transpose(*(perm or ())) - - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) - - @jtu.sample_product( - shape=array_shapes, - dtype=default_dtypes, - ) - def testPermuteDims(self, shape, dtype): - rng = jtu.rand_some_zero(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - axes = self.rng().permutation(len(shape)) - np_fun = partial(getattr(np, "permute_dims", np.transpose), axes=axes) - jnp_fun = partial(jnp.permute_dims, axes=axes) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) - - @jtu.sample_product( - shape=[s for s in array_shapes if len(s) >= 2], - dtype=default_dtypes, - use_property=[True, False] - ) - def testMatrixTranspose(self, shape, dtype, use_property): - if use_property: - jnp_fun = lambda x: jnp.asarray(x).mT - else: - jnp_fun = jnp.matrix_transpose - if hasattr(np, 'matrix_transpose'): - np_fun = np.matrix_transpose - else: - np_fun = lambda x: np.swapaxes(x, -1, -2) - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - dtype=default_dtypes, - a_shape=one_dim_array_shapes, - trim=["f", "b", "fb"], - ) - def testTrimZeros(self, a_shape, dtype, trim): - rng = jtu.rand_some_zero(self.rng()) - args_maker = lambda: [rng(a_shape, dtype)] - np_fun = lambda arg1: np.trim_zeros(arg1, trim) - jnp_fun = lambda arg1: jnp.trim_zeros(arg1, trim) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - - @unittest.skip("Jax-metal don't support map op.") - @jtu.sample_product( - rank=(1, 2), - dtype=default_dtypes, - a_shape=one_dim_array_shapes, - ) - @jax.default_matmul_precision("float32") - def testPoly(self, a_shape, dtype, rank): - if dtype in (np.float16, jnp.bfloat16, np.int16): - self.skipTest(f"{dtype} gets promoted to {np.float16}, which is not supported.") - elif rank == 2 and not jtu.test_device_matches(["cpu"]): - self.skipTest("Nonsymmetric eigendecomposition is only implemented on the CPU backend.") - rng = jtu.rand_default(self.rng()) - tol = { np.int8: 2e-3, np.int32: 1e-3, np.float32: 1e-3, np.float64: 1e-6 } - if jtu.test_device_matches(["tpu"]): - tol[np.int32] = tol[np.float32] = 1e-1 - tol = jtu.tolerance(dtype, tol) - args_maker = lambda: [rng(a_shape * rank, dtype)] - self._CheckAgainstNumpy(np.poly, jnp.poly, args_maker, check_dtypes=False, tol=tol) - self._CompileAndCheck(jnp.poly, args_maker, check_dtypes=True, rtol=tol, atol=tol) - - @unittest.skip("Jax-metal don't support map op.") - @jtu.sample_product( - dtype=default_dtypes, - a_shape=one_dim_array_shapes, - b_shape=one_dim_array_shapes, - ) - def testPolyAdd(self, a_shape, b_shape, dtype): - rng = jtu.rand_default(self.rng()) - np_fun = lambda arg1, arg2: np.polyadd(arg1, arg2) - jnp_fun = lambda arg1, arg2: jnp.polyadd(arg1, arg2) - args_maker = lambda: [rng(a_shape, dtype), rng(b_shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) - - @unittest.skip("Jax-metal don't support map op.") - @jtu.sample_product( - dtype=default_dtypes, - a_shape=one_dim_array_shapes, - b_shape=one_dim_array_shapes, - ) - def testPolySub(self, a_shape, b_shape, dtype): - rng = jtu.rand_default(self.rng()) - np_fun = lambda arg1, arg2: np.polysub(arg1, arg2) - jnp_fun = lambda arg1, arg2: jnp.polysub(arg1, arg2) - args_maker = lambda: [rng(a_shape, dtype), rng(b_shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) - - @unittest.skip("Jax-metal don't support map op.") - @jtu.sample_product( - [dict(order=order, k=k, dtype=dtype) - for dtype in default_dtypes - for order in range(5) - for k in [np.arange(order, dtype=dtype), np.ones(1, dtype), None]], - a_shape=one_dim_array_shapes, - ) - def testPolyInt(self, a_shape, order, k, dtype): - rng = jtu.rand_default(self.rng()) - np_fun = lambda arg1: np.polyint(arg1, m=order, k=k) - jnp_fun = lambda arg1: jnp.polyint(arg1, m=order, k=k) - args_maker = lambda: [rng(a_shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) - - @unittest.skip("Jax-metal don't support map op.") - @jtu.sample_product( - dtype=default_dtypes, - a_shape=one_dim_array_shapes, - order=list(range(5)), - ) - def testPolyDer(self, a_shape, order, dtype): - rng = jtu.rand_default(self.rng()) - np_fun = lambda arg1: np.polyder(arg1, m=order) - jnp_fun = lambda arg1: jnp.polyder(arg1, m=order) - args_maker = lambda: [rng(a_shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) - - @parameterized.parameters(['int', 'np.int', 'jnp.int']) - def testIntegerPower(self, ptype): - p = {'int': 2, 'np.int': np.int32(2), 'jnp.int': jnp.int32(2)}[ptype] - jaxpr = jax.make_jaxpr(lambda x1: jnp.power(x1, p))(1) - eqns = jaxpr.jaxpr.eqns - self.assertLen(eqns, 1) - self.assertEqual(eqns[0].primitive, lax.integer_pow_p) - - @jtu.sample_product( - x=[-1, 0, 1], - y=[0, 32, 64, 128], - ) - def testIntegerPowerOverflow(self, x, y): - # Regression test for https://github.com/jax-ml/jax/issues/5987 - args_maker = lambda: [x, y] - self._CheckAgainstNumpy(np.power, jnp.power, args_maker) - self._CompileAndCheck(jnp.power, args_maker) - - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in all_shapes - for axis in [None] + list(range(len(shape))) - ], - dtype=all_dtypes, - ) - def testCompress(self, shape, dtype, axis): - rng = jtu.rand_some_zero(self.rng()) - if shape in scalar_shapes or len(shape) == 0: - cond_shape = (0,) - elif axis is None: - cond_shape = (math.prod(shape),) - else: - cond_shape = (shape[axis],) - - args_maker = lambda: [rng(cond_shape, jnp.float32), rng(shape, dtype)] - - np_fun = partial(np.compress, axis=axis) - jnp_fun = partial(jnp.compress, axis=axis) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - - @jtu.sample_product( - shape=[(2, 3)], - dtype=int_dtypes, - # condition entries beyond axis size must be zero. - condition=[[1], [1, 0, 0, 0, 0, 0, 0]], - axis=[None, 0, 1], - ) - def testCompressMismatchedShapes(self, shape, dtype, condition, axis): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [np.array(condition), rng(shape, dtype)] - np_fun = partial(np.compress, axis=axis) - jnp_fun = partial(jnp.compress, axis=axis) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in array_shapes - for axis in [None] + list(range(len(shape))) - ], - dtype=all_dtypes, - ) - def testCompressMethod(self, shape, dtype, axis): - rng = jtu.rand_some_zero(self.rng()) - if shape in scalar_shapes or len(shape) == 0: - cond_shape = (0,) - elif axis is None: - cond_shape = (math.prod(shape),) - else: - cond_shape = (shape[axis],) - - args_maker = lambda: [rng(cond_shape, jnp.float32), rng(shape, dtype)] - - np_fun = lambda condition, x: np.compress(condition, x, axis=axis) - jnp_fun = lambda condition, x: x.compress(condition, axis=axis) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - - @jtu.sample_product( - [dict(base_shape=base_shape, axis=axis) - for base_shape in [(4,), (3, 4), (2, 3, 4)] - for axis in (None, *range(-len(base_shape)+1, len(base_shape))) - ], - arg_dtypes=[ - arg_dtypes - for num_arrs in [3] - for arg_dtypes in itertools.combinations_with_replacement(default_dtypes, num_arrs) - ], - dtype=[None] + default_dtypes, - ) - def testConcatenate(self, axis, dtype, base_shape, arg_dtypes): - rng = jtu.rand_default(self.rng()) - wrapped_axis = 0 if axis is None else axis % len(base_shape) - shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis+1:] - for size, _ in zip(itertools.cycle([3, 1, 4]), arg_dtypes)] - @jtu.promote_like_jnp - def np_fun(*args, dtype=dtype): - dtype = dtype or args[0].dtype - args = [x if x.dtype != jnp.bfloat16 else x.astype(np.float32) - for x in args] - return np.concatenate(args, axis=axis, dtype=dtype, casting='unsafe') - jnp_fun = lambda *args: jnp.concatenate(args, axis=axis, dtype=dtype) - - def args_maker(): - return [rng(shape, dtype) for shape, dtype in zip(shapes, arg_dtypes)] - - with jtu.strict_promotion_if_dtypes_match(arg_dtypes): - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in [(4, 1), (4, 3), (4, 5, 6)] - for axis in [None] + list(range(1 - len(shape), len(shape) - 1)) - ], - dtype=all_dtypes, - ) - def testConcatenateArray(self, shape, dtype, axis): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - np_fun = lambda x: np.concatenate(x, axis=axis) - jnp_fun = lambda x: jnp.concatenate(x, axis=axis) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - def testConcatenateAxisNone(self): - # https://github.com/jax-ml/jax/issues/3419 - a = jnp.array([[1, 2], [3, 4]]) - b = jnp.array([[5]]) - jnp.concatenate((a, b), axis=None) - - def testConcatenateScalarAxisNone(self): - arrays = [np.int32(0), np.int32(1)] - self.assertArraysEqual(jnp.concatenate(arrays, axis=None), - np.concatenate(arrays, axis=None)) - - @jtu.sample_product( - [dict(base_shape=base_shape, axis=axis) - for base_shape in [(), (4,), (3, 4), (2, 3, 4)] - for axis in (None, *range(-len(base_shape)+1, len(base_shape))) - ], - dtype=default_dtypes, - ) - def testConcat(self, axis, base_shape, dtype): - rng = jtu.rand_default(self.rng()) - wrapped_axis = 0 if axis is None else axis % len(base_shape) - shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis+1:] - for size in [3, 1, 4]] - @jtu.promote_like_jnp - def np_fun(*args): - if jtu.numpy_version() >= (2, 0, 0): - return np.concat(args, axis=axis) - else: - return np.concatenate(args, axis=axis) - jnp_fun = lambda *args: jnp.concat(args, axis=axis) - args_maker = lambda: [rng(shape, dtype) for shape in shapes] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(base_shape=base_shape, axis=axis) - for base_shape in [(4,), (3, 4), (2, 3, 4)] - for axis in range(-len(base_shape)+1, len(base_shape))], - arg_dtypes=itertools.combinations_with_replacement(default_dtypes, 2) - ) - def testAppend(self, axis, base_shape, arg_dtypes): - rng = jtu.rand_default(self.rng()) - wrapped_axis = axis % len(base_shape) - shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis+1:] - for size, _ in zip(itertools.cycle([3, 1, 4]), arg_dtypes)] - def np_fun(arr, values): - arr = arr.astype(np.float32) if arr.dtype == jnp.bfloat16 else arr - values = (values.astype(np.float32) if values.dtype == jnp.bfloat16 - else values) - out = np.append(arr, values, axis=axis) - return out.astype(jnp.promote_types(*arg_dtypes)) - jnp_fun = lambda arr, values: jnp.append(arr, values, axis=axis) - - def args_maker(): - return [rng(shape, dtype) for shape, dtype in zip(shapes, arg_dtypes)] - - with jtu.strict_promotion_if_dtypes_match(arg_dtypes): - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(shape=shape, axis=axis, idx=idx) - for shape in nonempty_nonscalar_array_shapes - for axis in [None] + list(range(-len(shape), len(shape))) - for idx in (range(-math.prod(shape), math.prod(shape)) - if axis is None else - range(-shape[axis], shape[axis]))], - dtype=all_dtypes, - ) - def testDeleteInteger(self, shape, dtype, idx, axis): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - np_fun = lambda arg: np.delete(arg, idx, axis=axis) - jnp_fun = lambda arg: jnp.delete(arg, idx, axis=axis) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in nonempty_nonscalar_array_shapes - for axis in [None] + list(range(-len(shape), len(shape))) - ], - dtype=all_dtypes, - slc=[slice(None), slice(1, 3), slice(1, 5, 2)], - ) - def testDeleteSlice(self, shape, dtype, axis, slc): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - np_fun = lambda arg: np.delete(arg, slc, axis=axis) - jnp_fun = lambda arg: jnp.delete(arg, slc, axis=axis) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in nonempty_nonscalar_array_shapes - for axis in [None] + list(range(-len(shape), len(shape))) - ], - dtype=all_dtypes, - idx_shape=all_shapes, - ) - def testDeleteIndexArray(self, shape, dtype, axis, idx_shape): - rng = jtu.rand_default(self.rng()) - max_idx = np.zeros(shape).size if axis is None else np.zeros(shape).shape[axis] - idx = jtu.rand_int(self.rng(), low=-max_idx, high=max_idx)(idx_shape, int) - args_maker = lambda: [rng(shape, dtype)] - np_fun = lambda arg: np.delete(arg, idx, axis=axis) - jnp_fun = lambda arg: jnp.delete(arg, idx, axis=axis) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in nonempty_nonscalar_array_shapes - for axis in [None] + list(range(-len(shape), len(shape))) - ], - dtype=all_dtypes, - idx_shape=all_shapes, - ) - def testDeleteUniqueIndices(self, shape, dtype, axis, idx_shape): - rng = jtu.rand_default(self.rng()) - max_idx = np.zeros(shape).size if axis is None else np.zeros(shape).shape[axis] - idx_size = np.zeros(idx_shape).size - if idx_size > max_idx: - self.skipTest("Too many indices to be unique") - def args_maker(): - x = rng(shape, dtype) - idx = self.rng().choice(max_idx, idx_shape, replace=False) - return x, idx - np_fun = partial(np.delete, axis=axis) - jnp_fun = partial(jnp.delete, axis=axis, assume_unique_indices=True) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in nonempty_nonscalar_array_shapes - for axis in [None] + list(range(-len(shape), len(shape))) - ], - dtype=all_dtypes, - ) - def testDeleteMaskArray(self, shape, dtype, axis): - rng = jtu.rand_default(self.rng()) - mask_size = np.zeros(shape).size if axis is None else np.zeros(shape).shape[axis] - mask = jtu.rand_int(self.rng(), low=0, high=2)(mask_size, bool) - args_maker = lambda: [rng(shape, dtype)] - np_fun = lambda arg: np.delete(arg, mask, axis=axis) - jnp_fun = lambda arg: jnp.delete(arg, mask, axis=axis) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @unittest.skip("JAX-metal fail.") - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in nonempty_nonscalar_array_shapes - for axis in [None] + list(range(-len(shape), len(shape))) - ], - dtype=all_dtypes, - ) - def testInsertInteger(self, shape, dtype, axis): - x = jnp.empty(shape) - max_ind = x.size if axis is None else x.shape[axis] - rng = jtu.rand_default(self.rng()) - i_rng = jtu.rand_int(self.rng(), -max_ind, max_ind) - args_maker = lambda: [rng(shape, dtype), i_rng((), np.int32), rng((), dtype)] - np_fun = lambda *args: np.insert(*args, axis=axis) - jnp_fun = lambda *args: jnp.insert(*args, axis=axis) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @unittest.skip("Jax-metal fail.") - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in nonempty_nonscalar_array_shapes - for axis in [None] + list(range(-len(shape), len(shape))) - ], - dtype=all_dtypes, - ) - def testInsertSlice(self, shape, dtype, axis): - x = jnp.empty(shape) - max_ind = x.size if axis is None else x.shape[axis] - rng = jtu.rand_default(self.rng()) - i_rng = jtu.rand_int(self.rng(), -max_ind, max_ind) - slc = slice(i_rng((), jnp.int32).item(), i_rng((), jnp.int32).item()) - args_maker = lambda: [rng(shape, dtype), rng((), dtype)] - np_fun = lambda x, val: np.insert(x, slc, val, axis=axis) - jnp_fun = lambda x, val: jnp.insert(x, slc, val, axis=axis) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @parameterized.parameters([ - [[[1, 1], [2, 2], [3, 3]], 1, 5, None], - [[[1, 1], [2, 2], [3, 3]], 1, 5, 1], - [[[1, 1], [2, 2], [3, 3]], 1, [1, 2, 3], 1], - [[[1, 1], [2, 2], [3, 3]], [1], [[1],[2],[3]], 1], - [[1, 1, 2, 2, 3, 3], [2, 2], [5, 6], None], - [[1, 1, 2, 2, 3, 3], slice(2, 4), [5, 6], None], - [[1, 1, 2, 2, 3, 3], [2, 2], [7.13, False], None], - [[[0, 1, 2, 3], [4, 5, 6, 7]], (1, 3), 999, 1] - ]) - def testInsertExamples(self, arr, index, values, axis): - # Test examples from the np.insert docstring - args_maker = lambda: ( - np.asarray(arr), index if isinstance(index, slice) else np.array(index), - np.asarray(values), axis) - self._CheckAgainstNumpy(np.insert, jnp.insert, args_maker) - - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in nonempty_array_shapes - for axis in range(-len(shape), len(shape)) - ], - dtype=default_dtypes, - out_dims=[0, 1, 2], - ) - def testApplyAlongAxis(self, shape, dtype, axis, out_dims): - def func(x, out_dims): - if out_dims == 0: - return x.sum(dtype=x.dtype) - elif out_dims == 1: - return x * x[0] - elif out_dims == 2: - return x[:, None] + x[None, :] - else: - raise NotImplementedError(f"{out_dims=}") - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - np_fun = lambda arr: np.apply_along_axis(func, axis, arr, out_dims=out_dims) - jnp_fun = lambda arr: jnp.apply_along_axis(func, axis, arr, out_dims=out_dims) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, - atol={dtypes.bfloat16: 2e-2}) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(shape=shape, axes=axes) - for shape in nonempty_shapes - for axes in itertools.combinations(range(len(shape)), 2) - ], - func=["sum"], - keepdims=[True, False], - # Avoid low-precision types in sum() - dtype=[dtype for dtype in default_dtypes - if dtype not in [np.float16, jnp.bfloat16]], - ) - def testApplyOverAxes(self, shape, dtype, func, keepdims, axes): - f = lambda x, axis: getattr(x, func)(axis=axis, keepdims=keepdims, dtype=dtype) - rng = jtu.rand_default(self.rng()) - args_maker = lambda: (rng(shape, dtype),) - np_fun = lambda a: np.apply_over_axes(f, a, axes) - jnp_fun = lambda a: jnp.apply_over_axes(f, a, axes) - self._CompileAndCheck(jnp_fun, args_maker) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - - @jtu.sample_product( - [dict(shape=shape, dtype=dtype, axis=axis) - for shape, dtype in _shape_and_dtypes(all_shapes, default_dtypes) - for axis in [None] + list(range(-len(shape), max(1, len(shape)))) - ], - repeats=[0, 1, 2], - fixed_size=[False, True], - ) - def testRepeat(self, axis, shape, dtype, repeats, fixed_size): - rng = jtu.rand_default(self.rng()) - np_fun = lambda arg: np.repeat(arg, repeats=repeats, axis=axis) - np_fun = jtu.promote_like_jnp(np_fun) - if fixed_size: - total_repeat_length = np.repeat(np.zeros(shape), repeats, axis).shape[axis or 0] - jnp_fun = lambda arg, rep: jnp.repeat(arg, repeats=rep, axis=axis, - total_repeat_length=total_repeat_length) - jnp_args_maker = lambda: [rng(shape, dtype), repeats] - clo_fun = lambda arg: jnp.repeat(arg, repeats=repeats, axis=axis, - total_repeat_length=total_repeat_length) - clo_fun_args_maker = lambda: [rng(shape, dtype)] - self._CompileAndCheck(jnp_fun, jnp_args_maker) - self._CheckAgainstNumpy(np_fun, clo_fun, clo_fun_args_maker) - else: - # Now repeats is in a closure, so a constant. - jnp_fun = lambda arg: jnp.repeat(arg, repeats=repeats, axis=axis) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - def testRepeatScalarFastPath(self): - a = jnp.array([1,2,3,4]) - f = lambda a: jnp.repeat(a, repeats=2) - jaxpr = jax.make_jaxpr(f)(a) - self.assertLessEqual(len(jaxpr.jaxpr.eqns), 6) - - @unittest.skip("jax-metal fail to convert sort op.") - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in all_shapes - for axis in [None] + list(range(len(shape)))], - dtype=number_dtypes, - return_index=[False, True], - return_inverse=[False, True], - return_counts=[False, True], - ) - def testUnique(self, shape, dtype, axis, return_index, return_inverse, return_counts): - rng = jtu.rand_some_equal(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - extra_args = (return_index, return_inverse, return_counts) - use_defaults = (False, *(True for arg in extra_args if arg)) if any(extra_args) else False - np_fun = jtu.with_jax_dtype_defaults(lambda x: np_unique_backport(x, *extra_args, axis=axis), use_defaults) - jnp_fun = lambda x: jnp.unique(x, *extra_args, axis=axis) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - - @jtu.sample_product(shape=all_shapes, dtype=number_dtypes) - def testUniqueAll(self, shape, dtype): - rng = jtu.rand_some_equal(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - if jtu.numpy_version() < (2, 0, 0): - np_fun = partial(np_unique_backport, return_index=True, return_inverse=True, return_counts=True) - else: - np_fun = np.unique_all - self._CheckAgainstNumpy(jnp.unique_all, np_fun, args_maker) - - @jtu.sample_product(shape=all_shapes, dtype=number_dtypes) - def testUniqueCounts(self, shape, dtype): - rng = jtu.rand_some_equal(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - if jtu.numpy_version() < (2, 0, 0): - np_fun = lambda x: np.unique(x, return_counts=True) - else: - np_fun = np.unique_counts - self._CheckAgainstNumpy(jnp.unique_counts, np_fun, args_maker) - - @jtu.sample_product(shape=all_shapes, dtype=number_dtypes) - def testUniqueInverse(self, shape, dtype): - rng = jtu.rand_some_equal(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - if jtu.numpy_version() < (2, 0, 0): - np_fun = partial(np_unique_backport, return_inverse=True) - else: - np_fun = np.unique_inverse - self._CheckAgainstNumpy(jnp.unique_inverse, np_fun, args_maker) - - @jtu.sample_product(shape=all_shapes, dtype=number_dtypes) - def testUniqueValues(self, shape, dtype): - rng = jtu.rand_some_equal(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - if jtu.numpy_version() < (2, 0, 0): - np_fun = np.unique - else: - np_fun = np.unique_values - self._CheckAgainstNumpy(jnp.unique_values, np_fun, args_maker) - - @unittest.skip("jax-metal fail.") - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in nonempty_array_shapes - for axis in [None] + list(range(len(shape)))], - dtype=number_dtypes, - size=[1, 5, 10], - fill_value=[None, 0, "slice"], - ) - def testUniqueSize(self, shape, dtype, axis, size, fill_value): - rng = jtu.rand_some_equal(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - kwds = dict(axis=axis, return_index=True, return_inverse=True, return_counts=True) - - if fill_value == "slice": - if axis is None: - fill_value = rng((), dtype) - else: - fill_value = rng(shape[:axis] + shape[axis + 1:], dtype) - elif fill_value is not None: - fill_value = np.array(fill_value).astype(dtype) - - @partial(jtu.with_jax_dtype_defaults, use_defaults=(False, True, True, True)) - def np_fun(x, fill_value=fill_value): - u, ind, inv, counts = np_unique_backport(x, **kwds) - axis = kwds['axis'] - if axis is None: - x = x.ravel() - axis = 0 - - n_unique = u.shape[axis] - if size <= u.shape[axis]: - slc = (slice(None),) * axis + (slice(size),) - u, ind, counts = u[slc], ind[:size], counts[:size] - else: - extra = (0, size - n_unique) - pads = [(0, 0)] * u.ndim - pads[axis] = extra - u = np.pad(u, pads, constant_values=0) - slices = [slice(None)] * u.ndim - slices[axis] = slice(1) - if fill_value is None: - fill_value = u[tuple(slices)] - elif np.ndim(fill_value): - fill_value = lax.expand_dims(fill_value, (axis,)) - slices[axis] = slice(n_unique, None) - u[tuple(slices)] = fill_value - ind = np.pad(ind, extra, constant_values=ind[0]) - counts = np.pad(counts, extra, constant_values=0) - return u, ind, inv, counts - - jnp_fun = lambda x: jnp.unique(x, size=size, fill_value=fill_value, **kwds) - - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @unittest.skip("jax-metal fail.") - @jtu.sample_product(dtype=inexact_dtypes) - def testUniqueNans(self, dtype): - def args_maker(): - x = [-0.0, 0.0, 1.0, 1.0, np.nan, -np.nan] - if np.issubdtype(dtype, np.complexfloating): - x = [complex(i, j) for i, j in itertools.product(x, repeat=2)] - return [np.array(x, dtype=dtype)] - - kwds = dict(return_index=True, return_inverse=True, return_counts=True) - jnp_fun = partial(jnp.unique, **kwds) - def np_fun(x): - dtype = x.dtype - # numpy unique fails for bfloat16 NaNs, so we cast to float64 - if x.dtype == jnp.bfloat16: - x = x.astype('float64') - u, *rest = np.unique(x, **kwds) - return (u.astype(dtype), *rest) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - - @unittest.skip("jax-metal fail.") - @jtu.sample_product(dtype=inexact_dtypes, equal_nan=[True, False]) - def testUniqueEqualNan(self, dtype, equal_nan): - shape = (20,) - rng = jtu.rand_some_nan(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - def np_fun(x): - dtype = x.dtype - # numpy unique fails for bfloat16 NaNs, so we cast to float64 - if x.dtype == jnp.bfloat16: - x = x.astype('float64') - return np.unique(x, equal_nan=equal_nan).astype(dtype) - jnp_fun = partial(jnp.unique, equal_nan=equal_nan) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - - @jtu.sample_product(fixed_size=[False, True]) - def testNonScalarRepeats(self, fixed_size): - ''' - Following numpy test suite from `test_repeat` at - https://github.com/numpy/numpy/blob/main/numpy/core/tests/test_multiarray.py - ''' - tol = 1e-5 - - def test_single(m, args_maker, repeats, axis): - lax_ans = jnp.repeat(m, repeats, axis) - numpy_ans = np.repeat(m, repeats, axis) - - self.assertAllClose(lax_ans, numpy_ans, rtol=tol, atol=tol) - if fixed_size: - - # Calculate expected size of the repeated axis. - rep_length = np.repeat(np.zeros_like(m), repeats, axis).shape[axis or 0] - jnp_fun = lambda arg, rep: jnp.repeat( - arg, repeats=rep, axis=axis, total_repeat_length=rep_length) - else: - jnp_fun = lambda arg: jnp.repeat(arg, repeats = repeats, axis=axis) - self._CompileAndCheck(jnp_fun, args_maker) - - m = jnp.array([1,2,3,4,5,6]) - if fixed_size: - args_maker = lambda: [m, repeats] - else: - args_maker = lambda: [m] - - for repeats in [2, jnp.array([1,3,0,1,1,2]), jnp.array([1,3,2,1,1,2]), jnp.array([2])]: - test_single(m, args_maker, repeats, axis=None) - test_single(m, args_maker, repeats, axis=0) - - m_rect = m.reshape((2,3)) - if fixed_size: - args_maker = lambda: [m_rect, repeats] - else: - args_maker = lambda: [m_rect] - - for repeats in [2, jnp.array([2,1]), jnp.array([2])]: - test_single(m_rect, args_maker, repeats, axis=0) - - for repeats in [2, jnp.array([1,3,2]), jnp.array([2])]: - test_single(m_rect, args_maker, repeats, axis=1) - - def testIssue2330(self): - ''' - Make sure return value of jnp.concatenate is a jax.ndarray and is side-effect save - ''' - def attempt_sideeffect(x): - x = [x] - x = jnp.concatenate(x) - x -= 1. - return x - - np_input = np.ones(1) - jnp_input = jnp.ones(1) - expected_np_input_after_call = np.ones(1) - expected_jnp_input_after_call = jnp.ones(1) - - out = jnp.concatenate([np_input]) - self.assertIs(type(out), array.ArrayImpl) - - attempt_sideeffect(np_input) - attempt_sideeffect(jnp_input) - - self.assertAllClose(np_input, expected_np_input_after_call) - self.assertAllClose(jnp_input, expected_jnp_input_after_call) - - @jtu.sample_product( - mode=['full', 'same', 'valid'], - op=['convolve', 'correlate'], - dtype= float_dtypes, #number_dtypes, - xshape=one_dim_array_shapes, - yshape=one_dim_array_shapes, - ) - def testConvolutions(self, xshape, yshape, dtype, mode, op): - jnp_op = getattr(jnp, op) - np_op = getattr(np, op) - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)] - precision = lax.Precision.HIGHEST if jtu.test_device_matches(["tpu"]) else None - jnp_fun = partial(jnp_op, mode=mode, precision=precision) - def np_fun(x, y): - return np_op(x, y, mode=mode).astype(dtypes.to_inexact_dtype(dtype)) - tol = {np.float16: 2e-1, np.float32: 1e-2, np.float64: 1e-14, - np.complex128: 1e-14} - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True, tol=tol) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - mode=['full', 'same', 'valid'], - op=['convolve', 'correlate'], - dtype=float_dtypes, #number_dtypes, - xshape=one_dim_array_shapes, - yshape=one_dim_array_shapes, - ) - @jtu.skip_on_devices("cuda", "rocm") # backends don't support all dtypes. - def testConvolutionsPreferredElementType(self, xshape, yshape, dtype, mode, op): - jnp_op = getattr(jnp, op) - np_op = getattr(np, op) - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)] - precision = lax.Precision.HIGHEST if jtu.test_device_matches(["tpu"]) else None - jnp_fun = partial(jnp_op, mode=mode, precision=precision, - preferred_element_type=dtype) - def np_fun(x, y): - return np_op(x, y, mode=mode).astype(dtype) - tol = {np.float16: 2e-1, np.float32: 1e-2, np.float64: 1e-14, - np.complex128: 1e-14} - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True, tol=tol) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in all_shapes - for axis in [None] + list(range(-len(shape), len(shape)))], - op=["cumsum", "cumprod"], - dtype=all_dtypes, - out_dtype=[dtype for dtype in default_dtypes if dtype != np.float16], - ) - def testCumSumProd(self, axis, shape, dtype, out_dtype, op): - jnp_op = getattr(jnp, op) - np_op = getattr(np, op) - rng = jtu.rand_default(self.rng()) - np_fun = lambda arg: np_op(arg, axis=axis, dtype=out_dtype) - np_fun = jtu.ignore_warning(category=NumpyComplexWarning)(np_fun) - np_fun = jtu.ignore_warning(category=RuntimeWarning, - message="overflow encountered.*")(np_fun) - jnp_fun = lambda arg: jnp_op(arg, axis=axis, dtype=out_dtype) - jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) - - args_maker = lambda: [rng(shape, dtype)] - - tol_thresholds = {dtypes.bfloat16: 4e-2} - tol = max(jtu.tolerance(dtype, tol_thresholds), - jtu.tolerance(out_dtype, tol_thresholds)) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, - tol=tol) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in all_shapes - for axis in [None] + list(range(-len(shape), len(shape)))], - op=["nancumsum", "nancumprod"], - dtype=all_dtypes, - out_dtype=default_dtypes, - ) - def testNanCumSumProd(self, axis, shape, dtype, out_dtype, op): - jnp_op = getattr(jnp, op) - np_op = getattr(np, op) - rng = jtu.rand_some_nan(self.rng()) - np_fun = partial(np_op, axis=axis, dtype=out_dtype) - np_fun = jtu.ignore_warning(category=NumpyComplexWarning)(np_fun) - np_fun = jtu.ignore_warning(category=RuntimeWarning, - message="overflow encountered.*")(np_fun) - jnp_fun = partial(jnp_op, axis=axis, dtype=out_dtype) - jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) - - args_maker = lambda: [rng(shape, dtype)] - - tol_thresholds = {dtypes.bfloat16: 4e-2, np.float16: 3e-3} - tol = max(jtu.tolerance(dtype, tol_thresholds), - jtu.tolerance(out_dtype, tol_thresholds)) - if dtype != jnp.bfloat16: - # numpy functions do not properly handle bfloat16 - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True, - tol=tol) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) - - @unittest.skip("Jax-metal fail on testEye2") - @jtu.sample_product( - dtype=default_dtypes, - n=[0, 4], - m=[None, 0, 1, 3, 4], - k=[*range(-4, 4), -2**100, 2**100], - ) - def testEye(self, n, m, k, dtype): - np_fun = lambda: np.eye(n, M=m, k=k, dtype=dtype) - jnp_fun = lambda: jnp.eye(n, M=m, k=k, dtype=dtype) - args_maker = lambda: [] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - dtype=default_dtypes, - n=[0, 4], - m=[None, 0, 1, 3, 4], - k=range(-4, 4), - ) - def testTri(self, m, n, k, dtype): - np_fun = lambda: np.tri(n, M=m, k=k, dtype=dtype) - jnp_fun = lambda: jnp.tri(n, M=m, k=k, dtype=dtype) - args_maker = lambda: [] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - dtype=default_dtypes, - shape=[shape for shape in all_shapes if len(shape) >= 2], - op=["tril", "triu"], - k=list(range(-3, 3)), - ) - def testTriLU(self, dtype, shape, op, k): - rng = jtu.rand_default(self.rng()) - np_fun = lambda arg: getattr(np, op)(arg, k=k) - jnp_fun = lambda arg: getattr(jnp, op)(arg, k=k) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - n=range(5), - k=range(-3, 3), - m=[None, *range(5)], - ) - def testTrilIndices(self, n, k, m): - np_fun = lambda n, k, m: np.tril_indices(n, k=k, m=m) - jnp_fun = lambda n, k, m: jnp.tril_indices(n, k=k, m=m) - args_maker = lambda: [n, k, m] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - - @jtu.sample_product( - n=range(5), - k=range(-3, 3), - m=[None, *range(5)], - ) - def testTriuIndices(self, n, k, m): - np_fun = lambda n, k, m: np.triu_indices(n, k=k, m=m) - jnp_fun = lambda n, k, m: jnp.triu_indices(n, k=k, m=m) - args_maker = lambda: [n, k, m] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - - @jtu.sample_product( - dtype=default_dtypes, - shape=[(1,1), (1,2), (2,2), (2,3), (3,2), (3,3), (4,4)], - k=[-1, 0, 1], - ) - def testTriuIndicesFrom(self, shape, dtype, k): - rng = jtu.rand_default(self.rng()) - np_fun = lambda arr, k: np.triu_indices_from(arr, k=k) - jnp_fun = lambda arr, k: jnp.triu_indices_from(arr, k=k) - args_maker = lambda: [rng(shape, dtype), k] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - - @jtu.sample_product( - dtype=default_dtypes, - shape=[(1,1), (1,2), (2,2), (2,3), (3,2), (3,3), (4,4)], - k=[-1, 0, 1], - ) - def testTrilIndicesFrom(self, shape, dtype, k): - rng = jtu.rand_default(self.rng()) - np_fun = lambda arr, k: np.tril_indices_from(arr, k=k) - jnp_fun = lambda arr, k: jnp.tril_indices_from(arr, k=k) - args_maker = lambda: [rng(shape, dtype), k] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - - @jtu.sample_product( - dtype=default_dtypes, - a_shape=[(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1), (2, 2), (1, 2), (0, 2), (2, 3), (2, 2, 2), (2, 2, 2, 2)], - val_shape=[(), (1,), (2,), (1, 2), (3, 2)], - ) - def testFillDiagonal(self, dtype, a_shape, val_shape): - rng = jtu.rand_default(self.rng()) - - def np_fun(a, val): - a_copy = a.copy() - np.fill_diagonal(a_copy, val) - return a_copy - - jnp_fun = partial(jnp.fill_diagonal, inplace=False) - args_maker = lambda : [rng(a_shape, dtype), rng(val_shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - ndim=[0, 1, 4], - n=[0, 1, 7], - ) - def testDiagIndices(self, ndim, n): - np.testing.assert_equal(jtu.with_jax_dtype_defaults(np.diag_indices)(n, ndim), - jnp.diag_indices(n, ndim)) - - @jtu.sample_product( - dtype=default_dtypes, - shape=[(1,1), (2,2), (3,3), (4,4), (5,5)], - ) - def testDiagIndicesFrom(self, dtype, shape): - rng = jtu.rand_default(self.rng()) - np_fun = jtu.with_jax_dtype_defaults(np.diag_indices_from) - jnp_fun = jnp.diag_indices_from - args_maker = lambda : [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - dtype=default_dtypes, - shape=[shape for shape in all_shapes if len(shape) in (1, 2)], - k=list(range(-4, 4)), - ) - def testDiag(self, shape, dtype, k): - rng = jtu.rand_default(self.rng()) - np_fun = lambda arg: np.diag(arg, k) - jnp_fun = lambda arg: jnp.diag(arg, k) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - dtype=default_dtypes, - shape=all_shapes, - k=list(range(-4, 4)), - ) - def testDiagFlat(self, shape, dtype, k): - rng = jtu.rand_default(self.rng()) - # numpy has inconsistencies for scalar values - # https://github.com/numpy/numpy/issues/16477 - # jax differs in that it treats scalars values as length-1 arrays - np_fun = lambda arg: np.diagflat(np.atleast_1d(arg), k) - jnp_fun = lambda arg: jnp.diagflat(arg, k) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) - - @unittest.skip("jax-metal fail.") - @jtu.sample_product( - dtype=default_dtypes, - a1_shape=one_dim_array_shapes, - a2_shape=one_dim_array_shapes, - ) - def testPolyMul(self, a1_shape, a2_shape, dtype): - rng = jtu.rand_default(self.rng()) - np_fun = lambda arg1, arg2: np.polymul(arg1, arg2) - jnp_fun_np = lambda arg1, arg2: jnp.polymul(arg1, arg2, trim_leading_zeros=True) - jnp_fun_co = lambda arg1, arg2: jnp.polymul(arg1, arg2) - args_maker = lambda: [rng(a1_shape, dtype), rng(a2_shape, dtype)] - tol = {np.float16: 2e-1, np.float32: 5e-2, np.float64: 1e-13} - self._CheckAgainstNumpy(np_fun, jnp_fun_np, args_maker, check_dtypes=False, tol=tol) - self._CompileAndCheck(jnp_fun_co, args_maker, check_dtypes=False) - - @unittest.skip("jax-metal fail.") - @jtu.sample_product( - dtype=[dtype for dtype in default_dtypes - if dtype not in (np.float16, jnp.bfloat16)], - a_shape=one_dim_array_shapes, - b_shape=one_dim_array_shapes, - ) - def testPolyDiv(self, a_shape, b_shape, dtype): - rng = jtu.rand_default(self.rng()) - - @jtu.ignore_warning(category=RuntimeWarning, message="divide by zero.*") - @jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*") - @jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered.*") - def np_fun(arg1, arg2): - q, r = np.polydiv(arg1, arg2) - while r.size < max(arg1.size, arg2.size): # Pad residual to same size - r = np.pad(r, (1, 0), 'constant') - return q, r - - def jnp_fun(arg1, arg2): - q, r = jnp.polydiv(arg1, arg2, trim_leading_zeros=True) - while r.size < max(arg1.size, arg2.size): # Pad residual to same size - r = jnp.pad(r, (1, 0), 'constant') - return q, r - - args_maker = lambda: [rng(a_shape, dtype), rng(b_shape, dtype)] - tol = { - dtypes.bfloat16: 2e-1, - np.float16: 2e-1, - np.float32: 5e-2, - np.float64: 5e-7 - } - - jnp_compile = jnp.polydiv # Without trim_leading_zeros (trim_zeros make it unable to be compiled by XLA) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, tol=tol) - self._CompileAndCheck(jnp_compile, args_maker, check_dtypes=True, atol=tol, rtol=tol) - - @jtu.sample_product( - [dict(shape=shape, axis1=axis1, axis2=axis2) - for shape in [shape for shape in all_shapes if len(shape) >= 2] - for axis1 in range(-len(shape), len(shape)) - for axis2 in [a for a in range(-len(shape), len(shape)) - if a % len(shape) != axis1 % len(shape)] - ], - dtype=default_dtypes, - offset=list(range(-4, 4)), - ) - def testDiagonal(self, shape, dtype, offset, axis1, axis2): - rng = jtu.rand_default(self.rng()) - np_fun = lambda arg: np.diagonal(arg, offset, axis1, axis2) - jnp_fun = lambda arg: jnp.diagonal(arg, offset, axis1, axis2) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - dtype=default_dtypes, - n=list(range(4)), - ) - def testIdentity(self, n, dtype): - np_fun = lambda: np.identity(n, dtype) - jnp_fun = lambda: jnp.identity(n, dtype) - args_maker = lambda: [] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @unittest.skip("jax-metal crash.") - @jtu.sample_product( - shape=nonempty_shapes, - period=[None, 0.59], - left=[None, 0], - right=[None, 1], - # Note: skip 8-bit and 16-bit types due to insufficient precision. - dtype=jtu.dtypes.integer + jtu.dtypes.floating, - target_dtype=jtu.dtypes.inexact, - ) - def testInterp(self, shape, dtype, period, left, right, target_dtype): - rng = jtu.rand_default(self.rng(), scale=10) - kwds = dict(period=period, left=left, right=right) - np_fun = partial(np.interp, **kwds) - jnp_fun = partial(jnp.interp, **kwds) - - args_maker = lambda: [rng(shape, dtype), np.unique(rng((100,), dtype))[:20], - rng((20,), target_dtype)] - - with jtu.strict_promotion_if_dtypes_match([dtype, target_dtype]): - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, - rtol=3e-3, atol=1e-3) - self._CompileAndCheck(jnp_fun, args_maker) - - @unittest.skip("jax-metal crash.") - @jtu.sample_product([ - dict(x=0.5, left='extrapolate', expected=5), - dict(x=1.5, left='extrapolate', expected=15), - dict(x=3.5, left='extrapolate', expected=30), - dict(x=3.9, right='extrapolate', expected=39), - ]) - def testInterpExtrapoate(self, x, expected, **kwargs): - xp = jnp.array([1.0, 2.0, 3.0]) - fp = jnp.array([10.0, 20.0, 30.0]) - actual = jnp.interp(x, xp, fp, **kwargs) - self.assertAlmostEqual(actual, expected) - - def testInterpErrors(self): - with self.assertRaisesWithLiteralMatch( - ValueError, - 'xp and fp must be one-dimensional arrays of equal size' - ): - jnp.interp(0.0, jnp.arange(2.0), jnp.arange(3.0)) - with self.assertRaisesWithLiteralMatch( - ValueError, - "the only valid string value of `left` is 'extrapolate', but got: 'interpolate'" - ): - jnp.interp(0.0, jnp.arange(3.0), jnp.arange(3.0), left='interpolate') - with self.assertRaisesWithLiteralMatch( - ValueError, - "the only valid string value of `right` is 'extrapolate', but got: 'interpolate'" - ): - jnp.interp(0.0, jnp.arange(3.0), jnp.arange(3.0), right='interpolate') - with self.assertRaisesWithLiteralMatch( - ValueError, - "jnp.interp: complex x values not supported." - ): - jnp.interp(1j, 1j * np.arange(3.0), np.arange(3.0)) - with self.assertRaisesRegex( - ValueError, - "period must be a scalar; got" - ): - jnp.interp(0.0, jnp.arange(3.0), jnp.arange(3.0), period=np.array([1.0])) - - @jtu.sample_product( - period=[None, 0.59], - left=[None, 0], - right=[None, 1], - dtype=jtu.dtypes.floating, - ) - def testInterpGradNan(self, dtype, period, left, right): - kwds = dict(period=period, left=left, right=right) - jnp_fun = partial(jnp.interp, **kwds) - # Probe values of x and xp that are close to zero and close together. - x = dtype(np.exp(np.linspace(-90, -20, 1000))) - g = jax.grad(lambda z: jnp.sum(jnp_fun(z, z, jnp.ones_like(z))))(x) - np.testing.assert_equal(np.all(np.isfinite(g)), True) - - @jtu.sample_product( - [dict(x1_shape=x1_shape, x2_shape=x2_shape) - for x1_shape, x2_shape in filter(_shapes_are_broadcast_compatible, - itertools.combinations_with_replacement(array_shapes, 2)) - ], - x1_rng_factory=[jtu.rand_some_inf_and_nan, jtu.rand_some_zero], - x2_rng_factory=[partial(jtu.rand_int, low=-1075, high=1024)], - x1_dtype=default_dtypes, - ) - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - def testLdexp(self, x1_shape, x1_dtype, x2_shape, x1_rng_factory, x2_rng_factory): - x1_rng = x1_rng_factory(self.rng()) - x2_rng = x2_rng_factory(self.rng()) - - @jtu.ignore_warning(category=RuntimeWarning, message="overflow.*") - def np_fun(x1, x2): - out_dtype = dtypes.to_inexact_dtype(x1.dtype) - return np.ldexp(x1.astype(out_dtype), x2) - - jnp_fun = jnp.ldexp - args_maker = lambda: [x1_rng(x1_shape, x1_dtype), - x2_rng(x2_shape, np.int32)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - rng_factory=[ - jtu.rand_some_inf_and_nan, - jtu.rand_some_zero, - partial(jtu.rand_not_small, offset=1e8), - ], - shape=all_shapes, - dtype=default_dtypes, - ) - def testFrexp(self, shape, dtype, rng_factory): - # integer types are converted to float64 in numpy's implementation - if (dtype not in [jnp.bfloat16, np.float16, np.float32] - and not config.enable_x64.value): - self.skipTest("Only run float64 testcase when float64 is enabled.") - rng = rng_factory(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - def np_frexp(x): - mantissa, exponent = np.frexp(x) - # NumPy is inconsistent between Windows and Linux/Mac on what the - # value of exponent is if the input is infinite. Normalize to the Linux - # behavior. - exponent = np.where(np.isinf(mantissa), np.zeros_like(exponent), exponent) - return mantissa, exponent - self._CheckAgainstNumpy(np_frexp, jnp.frexp, args_maker, - check_dtypes=np.issubdtype(dtype, np.inexact)) - self._CompileAndCheck(jnp.frexp, args_maker) - - @jtu.sample_product( - [dict(shape=shape, axis1=axis1, axis2=axis2) - for shape in [shape for shape in all_shapes if len(shape) >= 2] - for axis1 in range(-len(shape), len(shape)) - for axis2 in range(-len(shape), len(shape)) - if (axis1 % len(shape)) != (axis2 % len(shape)) - ], - dtype=default_dtypes, - out_dtype=[None] + number_dtypes, - offset=list(range(-4, 4)), - ) - def testTrace(self, shape, dtype, out_dtype, offset, axis1, axis2): - rng = jtu.rand_default(self.rng()) - def np_fun(arg): - if out_dtype == jnp.bfloat16: - return np.trace(arg, offset, axis1, axis2, np.float32).astype(jnp.bfloat16) - else: - return np.trace(arg, offset, axis1, axis2, out_dtype) - jnp_fun = lambda arg: jnp.trace(arg, offset, axis1, axis2, out_dtype) - args_maker = lambda: [rng(shape, dtype)] - # TODO: Fails with uint8/uint16 output dtypes (integer overflow?) - if out_dtype not in (np.uint8, np.uint16, np.uint32): - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - #unittest.skip("jax-metal fail with empty vshape.") - @jtu.sample_product( - ashape=[(15,), (16,), (17,)], - vshape= [(5,), (5, 5)],#[(), (5,), (5, 5)], - side=['left', 'right'], - dtype= number_dtypes, - method=['sort', 'scan', 'scan_unrolled', 'compare_all'], - ) - def testSearchsorted(self, ashape, vshape, side, dtype, method): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [np.sort(rng(ashape, dtype)), rng(vshape, dtype)] - def np_fun(a, v): - return np.searchsorted(a, v, side=side).astype('int32') - jnp_fun = lambda a, v: jnp.searchsorted(a, v, side=side, method=method) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @unittest.skipIf( - platform.system() == "Windows", - "Under Windows, NumPy throws if 2**32 is converted to an int32" - ) - def testSearchsortedDtype(self): - # Test that for large arrays, int64 indices are used. We test this - # via abstract evaluation to avoid allocating a large array in tests. - a_int32 = core.ShapedArray((np.iinfo(np.int32).max,), np.float32) - a_int64 = core.ShapedArray((np.iinfo(np.int32).max + 1,), np.float32) - v = core.ShapedArray((), np.float32) - - out_int32 = jax.eval_shape(jnp.searchsorted, a_int32, v) - self.assertEqual(out_int32.dtype, np.int32) - - if config.enable_x64.value: - out_int64 = jax.eval_shape(jnp.searchsorted, a_int64, v) - self.assertEqual(out_int64.dtype, np.int64) - elif jtu.numpy_version() < (2, 0, 0): - with self.assertWarnsRegex(UserWarning, "Explicitly requested dtype int64"): - with jtu.ignore_warning(category=DeprecationWarning, - message="NumPy will stop allowing conversion.*"): - out_int64 = jax.eval_shape(jnp.searchsorted, a_int64, v) - else: - with self.assertWarnsRegex(UserWarning, "Explicitly requested dtype int64"): - with self.assertRaisesRegex(OverflowError, "Python integer 2147483648 out of bounds.*"): - out_int64 = jax.eval_shape(jnp.searchsorted, a_int64, v) - - @unittest.skip("Jax-metal fail.") - @jtu.sample_product( - dtype=inexact_dtypes, - side=['left', 'right'], - method=['sort', 'scan', 'compare_all'], - ) - def testSearchsortedNans(self, dtype, side, method): - if np.issubdtype(dtype, np.complexfloating): - raise SkipTest("Known failure for complex inputs; see #9107") - x = np.array([-np.inf, -1.0, 0.0, -0.0, 1.0, np.inf, np.nan, -np.nan], dtype=dtype) - # The sign bit should not matter for 0.0 or NaN, so argsorting the above should be - # equivalent to argsorting the following: - x_equiv = np.array([0, 1, 2, 2, 3, 4, 5, 5]) - - if jnp.issubdtype(dtype, jnp.complexfloating): - x = np.array([complex(r, c) for r, c in itertools.product(x, repeat=2)]) - x_equiv = np.array([complex(r, c) for r, c in itertools.product(x_equiv, repeat=2)]) - - fun = partial(jnp.searchsorted, side=side, method=method) - self.assertArraysEqual(fun(x, x), fun(x_equiv, x_equiv)) - self.assertArraysEqual(jax.jit(fun)(x, x), fun(x_equiv, x_equiv)) - - @jtu.sample_product( - xshape=[(20,), (5, 4)], - binshape=[(1,), (5,)], - right=[True, False], - reverse=[True, False], - dtype=default_dtypes, - ) - def testDigitize(self, xshape, binshape, right, reverse, dtype): - order = jnp.index_exp[::-1] if reverse else jnp.index_exp[:] - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(xshape, dtype), jnp.sort(rng(binshape, dtype))[order]] - np_fun = lambda x, bins: np.digitize(x, bins, right=right).astype('int32') - jnp_fun = lambda x, bins: jnp.digitize(x, bins, right=right) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - dtypes=[ - [np.float32], - [np.float32, np.float32], - [np.float32, np.int32, np.float32], - [np.float32, np.int64, np.float32], - [np.float32, np.int32, np.float64], - ], - shape=[(), (2,), (3, 4), (1, 5)], - array_input=[True, False], - ) - def testColumnStack(self, shape, dtypes, array_input): - rng = jtu.rand_default(self.rng()) - if array_input: - args_maker = lambda: [np.array([rng(shape, dtype) for dtype in dtypes])] - else: - args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]] - np_fun = jtu.promote_like_jnp(np.column_stack) - jnp_fun = jnp.column_stack - with jtu.strict_promotion_if_dtypes_match(dtypes): - self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in [(), (2,), (3, 4), (1, 100)] - for axis in range(-len(shape), len(shape) + 1) - ], - dtypes=[ - [np.float32], - [np.float32, np.float32], - [np.float32, np.int32, np.float32], - [np.float32, np.int64, np.float32], - [np.float32, np.int32, np.float64], - ], - array_input=[True, False], - out_dtype=[np.float32, np.int32], - ) - def testStack(self, shape, axis, dtypes, array_input, out_dtype): - rng = jtu.rand_default(self.rng()) - if array_input: - args_maker = lambda: [np.array([rng(shape, dtype) for dtype in dtypes])] - else: - args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]] - - np_fun = jtu.promote_like_jnp(partial(np.stack, axis=axis, dtype=out_dtype, casting='unsafe')) - - jnp_fun = partial(jnp.stack, axis=axis, dtype=out_dtype) - with jtu.strict_promotion_if_dtypes_match(dtypes): - self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - op=["hstack", "vstack", "dstack"], - dtypes=[ - [np.float32], - [np.float32, np.float32], - [np.float32, np.int32, np.float32], - [np.float32, np.int64, np.float32], - [np.float32, np.int32, np.float64], - ], - shape=[(), (2,), (3, 4), (1, 100), (2, 3, 4)], - array_input=[True, False], - out_dtype=[np.float32, np.int32], - ) - def testHVDStack(self, shape, op, dtypes, array_input, out_dtype): - rng = jtu.rand_default(self.rng()) - if array_input: - args_maker = lambda: [np.array([rng(shape, dtype) for dtype in dtypes])] - else: - args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]] - - if op == "dstack": - np_fun = jtu.promote_like_jnp(lambda *args: getattr(np, op)(*args).astype(out_dtype)) - else: - np_fun = partial(jtu.promote_like_jnp(getattr(np, op)), dtype=out_dtype, - casting='unsafe') - - jnp_fun = partial(getattr(jnp, op), dtype=out_dtype) - with jtu.strict_promotion_if_dtypes_match(dtypes): - self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(name=name, **kwds) - for name in ['blackman', 'bartlett', 'hamming', 'hanning', 'kaiser'] - for kwds in ([dict(beta=1), dict(beta=0.5)] if name == 'kaiser' else [{}]) - ], - size = [0, 1, 5, 10], - ) - def testWindowFunction(self, name, size, **kwds): - jnp_fun = partial(getattr(jnp, name), size, **kwds) - np_fun = jtu.with_jax_dtype_defaults(partial(getattr(np, name), size, **kwds)) - args_maker = lambda: [] - tol = ( - 5e-6 if jtu.test_device_matches(['tpu']) and name == 'kaiser' else None - ) - self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker, atol=tol, rtol=tol) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(shape=shape, fill_value_shape=fill_value_shape) - for shape in array_shapes + [3, np.array(7, dtype=np.int32)] - for fill_value_shape in _compatible_shapes(shape)], - fill_value_dtype=default_dtypes, - out_dtype=[None] + default_dtypes, - ) - def testFull(self, shape, fill_value_dtype, fill_value_shape, out_dtype): - rng = jtu.rand_default(self.rng()) - np_fun = lambda fill_value: np.full(shape, fill_value, dtype=out_dtype) - jnp_fun = lambda fill_value: jnp.full(shape, fill_value, dtype=out_dtype) - args_maker = lambda: [rng(fill_value_shape, fill_value_dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(shape=shape, dtype=dtype, axis=axis) - for shape, dtype in _shape_and_dtypes(nonempty_nonscalar_array_shapes, default_dtypes) - for axis in list(range(-len(shape), max(1, len(shape)))) - ], - prepend=[None, 1, 0], - append=[None, 1, 0], - n=[0, 1, 2], - ) - def testDiff(self, shape, dtype, n, axis, prepend, append): - prepend = np.zeros(shape, dtype=dtype) if prepend == 0 else prepend - append = np.zeros(shape, dtype=dtype) if append == 0 else append - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - - def np_fun(x, n=n, axis=axis, prepend=prepend, append=append): - if prepend is None: - prepend = np._NoValue - elif not np.isscalar(prepend) and prepend.dtype == jnp.bfloat16: - prepend = prepend.astype(np.float32) - - if append is None: - append = np._NoValue - elif not np.isscalar(append) and append.dtype == jnp.bfloat16: - append = append.astype(np.float32) - - if x.dtype == jnp.bfloat16: - return np.diff(x.astype(np.float32), n=n, axis=axis, prepend=prepend, append=append).astype(jnp.bfloat16) - else: - return np.diff(x, n=n, axis=axis, prepend=prepend, append=append) - - jnp_fun = lambda x: jnp.diff(x, n=n, axis=axis, prepend=prepend, append=append) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) - self._CompileAndCheck(jnp_fun, args_maker) - - def testDiffPrepoendScalar(self): - # Regression test for https://github.com/jax-ml/jax/issues/19362 - x = jnp.arange(10) - result_jax = jnp.diff(x, prepend=x[0], append=x[-1]) - - x = np.array(x) - result_numpy = np.diff(x, prepend=x[0], append=x[-1]) - - self.assertArraysEqual(result_jax, result_numpy) - - @jtu.sample_product( - op=["zeros", "ones"], - shape=[2, (), (2,), (3, 0), np.array((4, 5, 6), dtype=np.int32), - np.array(4, dtype=np.int32)], - dtype=all_dtypes, - ) - def testZerosOnes(self, op, shape, dtype): - np_op = getattr(np, op) - jnp_op = getattr(jnp, op) - args_maker = lambda: [] - np_op = partial(np_op, shape, dtype) - jnp_op = partial(jnp_op, shape, dtype) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - def testOnesWithInvalidShape(self): - with self.assertRaises(TypeError): - jnp.ones((-1, 1)) - - def test_full_like_commited(self): - x = jnp.array((1, 2, 3), dtype=np.int32) - self.assertFalse(x._committed) - self.assertFalse(lax.full_like(x, 1.1)._committed) - x = jax.device_put(x, jax.devices()[-1]) - self.assertTrue(x._committed) - y = lax.full_like(x, 1.1) - self.assertTrue(y._committed) - self.assertEqual(x.sharding, y.sharding) - - def test_zeros_like_with_explicit_device_and_jitted(self): - x = jnp.array((1, 2, 3), dtype=np.int32) - x = jax.device_put(x, jax.devices()[0]) - zeros_like_with_device = partial(jnp.zeros_like, device=jax.devices()[0]) - y = jax.jit(zeros_like_with_device)(x) - self.assertEqual(x.shape, y.shape) - self.assertEqual(y.sharding, SingleDeviceSharding(jax.devices()[0])) - - @jtu.sample_product( - [dict(shape=shape, out_shape=out_shape, fill_value_shape=fill_value_shape) - for shape in array_shapes - for out_shape in [None] + array_shapes - for fill_value_shape in _compatible_shapes(shape if out_shape is None else out_shape) - ], - in_dtype=default_dtypes, - fill_value_dtype=default_dtypes, - out_dtype=default_dtypes, - ) - def testFullLike(self, shape, in_dtype, fill_value_dtype, fill_value_shape, out_dtype, out_shape): - rng = jtu.rand_default(self.rng()) - np_fun = lambda x, fill_value: np.full_like( - x, fill_value, dtype=out_dtype, shape=out_shape) - jnp_fun = lambda x, fill_value: jnp.full_like( - x, fill_value, dtype=out_dtype, shape=out_shape) - args_maker = lambda: [rng(shape, in_dtype), rng(fill_value_shape, fill_value_dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - shape=array_shapes, - out_shape=[None] + array_shapes, - in_dtype=default_dtypes, - func=["ones_like", "zeros_like"], - out_dtype=default_dtypes, - ) - def testZerosOnesLike(self, func, shape, in_dtype, out_shape, out_dtype): - rng = jtu.rand_default(self.rng()) - np_fun = lambda x: getattr(np, func)(x, dtype=out_dtype, shape=out_shape) - jnp_fun = lambda x: getattr(jnp, func)(x, dtype=out_dtype, shape=out_shape) - args_maker = lambda: [rng(shape, in_dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - func=[jnp.empty, jnp.zeros, jnp.ones, jnp.full], - shape=array_shapes, - dtype=default_dtypes, - ) - def testArrayCreationWithDevice(self, func, shape, dtype): - device = jax.devices()[-1] - kwds = {'fill_value': 1} if func is jnp.full else {} - out = func(**kwds, shape=shape, dtype=dtype, device=device) - self.assertEqual(out.devices(), {device}) - - @jtu.sample_product( - func=[jnp.empty, jnp.zeros, jnp.ones, jnp.full], - shape=array_shapes, - dtype=default_dtypes, - ) - def testArrayCreationWithSharding(self, func, shape, dtype): - sharding = SingleDeviceSharding(jax.devices()[-1]) - kwds = {'fill_value': 1} if func is jnp.full else {} - out = func(**kwds, shape=shape, dtype=dtype, device=sharding) - self.assertEqual(out.sharding, sharding) - - @jtu.sample_product( - func=[jnp.empty_like, jnp.zeros_like, jnp.ones_like, jnp.full_like], - shape=array_shapes, - dtype=default_dtypes, - ) - def testFullLikeWithDevice(self, func, shape, dtype): - device = jax.devices()[-1] - rng = jtu.rand_default(self.rng()) - x = rng(shape, dtype) - kwds = {'fill_value': 1} if func is jnp.full_like else {} - - with self.subTest('device from keyword'): - out = func(x, **kwds, device=device) - self.assertEqual(out.devices(), {device}) - - with self.subTest('device from input array'): - out2 = func(out, **kwds) - self.assertEqual(out2.devices(), out.devices()) - - @jtu.sample_product( - func=[jnp.empty_like, jnp.zeros_like, jnp.ones_like, jnp.full_like], - shape=array_shapes, - dtype=default_dtypes, - ) - def testFullLikeWithSharding(self, func, shape, dtype): - sharding = SingleDeviceSharding(jax.devices()[-1]) - rng = jtu.rand_default(self.rng()) - x = rng(shape, dtype) - kwds = {'fill_value': 1} if func is jnp.full_like else {} - - with self.subTest('device from keyword'): - out = func(x, **kwds, device=sharding) - self.assertEqual(out.sharding, sharding) - - with self.subTest('device from input array'): - out2 = func(out, **kwds) - self.assertEqual(out2.devices(), out.devices()) - - def testDuckTypedLike(self): - x = jax.ShapeDtypeStruct((1, 2, 3), np.dtype("int32")) - self.assertArraysEqual(jnp.zeros_like(x), jnp.zeros(x.shape, x.dtype)) - self.assertArraysEqual(jnp.ones_like(x), jnp.ones(x.shape, x.dtype)) - self.assertArraysEqual(jnp.empty_like(x), jnp.empty(x.shape, x.dtype)) - self.assertArraysEqual(jnp.full_like(x, 2), jnp.full(x.shape, 2, x.dtype)) - - @jtu.sample_product( - [dict(func=func, args=args) - for func, args in [("full_like", (-100,)), ("ones_like", ()), ("zeros_like", ())] - ], - shape=array_shapes, - #in_dtype=[np.int32, np.float32, np.complex64], - in_dtype=[np.int32, np.float32], - weak_type=[True, False], - out_shape=[None, (), (10,)], - out_dtype=[None, float], - ) - def testZerosOnesFullLikeWeakType(self, func, args, shape, in_dtype, weak_type, out_shape, out_dtype): - rng = jtu.rand_default(self.rng()) - x = lax_internal._convert_element_type(rng(shape, in_dtype), - weak_type=weak_type) - fun = lambda x: getattr(jnp, func)(x, *args, dtype=out_dtype, shape=out_shape) - expected_weak_type = weak_type and (out_dtype is None) - self.assertEqual(dtypes.is_weakly_typed(fun(x)), expected_weak_type) - self.assertEqual(dtypes.is_weakly_typed(jax.jit(fun)(x)), expected_weak_type) - - @jtu.sample_product( - funcname=["array", "asarray"], - dtype=[int, float, None], - val=[0, 1], - input_type=[int, float, np.int32, np.float32], - ) - def testArrayWeakType(self, funcname, input_type, val, dtype): - func = lambda x: getattr(jnp, funcname)(x, dtype=dtype) - fjit = jax.jit(func) - val = input_type(val) - expected_weak_type = dtype is None and input_type in set(dtypes._weak_types) - self.assertEqual(dtypes.is_weakly_typed(func(val)), expected_weak_type) - self.assertEqual(dtypes.is_weakly_typed(fjit(val)), expected_weak_type) - - @jtu.sample_product( - shape=nonempty_nonscalar_array_shapes, - #dtype=[int, float, complex], - dtype=[int, float], - weak_type=[True, False], - slc=[slice(None), slice(0), slice(3), 0, ...], - ) - def testSliceWeakTypes(self, shape, dtype, weak_type, slc): - rng = jtu.rand_default(self.rng()) - x = lax_internal._convert_element_type(rng(shape, dtype), - weak_type=weak_type) - op = lambda x: x[slc] - self.assertEqual(op(x).aval.weak_type, weak_type) - self.assertEqual(jax.jit(op)(x).aval.weak_type, weak_type) - - @jtu.sample_product( - [dict(shape=shape, axis=axis, num_sections=num_sections) - for shape, axis, num_sections in [ - ((3,), 0, 3), ((12,), 0, 3), ((12, 4), 0, 4), ((12, 4), 1, 2), - ((2, 3, 4), -1, 2), ((2, 3, 4), -2, 3)] - ], - dtype=default_dtypes, - ) - def testSplitStaticInt(self, shape, num_sections, axis, dtype): - rng = jtu.rand_default(self.rng()) - np_fun = lambda x: np.split(x, num_sections, axis=axis) - jnp_fun = lambda x: jnp.split(x, num_sections, axis=axis) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(shape=shape, axis=axis, num_sections=num_sections) - # All testcases split the specified axis unequally - for shape, axis, num_sections in [ - ((3,), 0, 2), ((12,), 0, 5), ((12, 4), 0, 7), ((12, 4), 1, 3), - ((2, 3, 5), -1, 2), ((2, 4, 4), -2, 3), ((7, 2, 2), 0, 3)] - ], - dtype=default_dtypes, - ) - def testArraySplitStaticInt(self, shape, num_sections, axis, dtype): - rng = jtu.rand_default(self.rng()) - np_fun = lambda x: np.array_split(x, num_sections, axis=axis) - jnp_fun = lambda x: jnp.array_split(x, num_sections, axis=axis) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - def testSplitTypeError(self): - # If we pass an ndarray for indices_or_sections -> no error - self.assertEqual(3, len(jnp.split(jnp.zeros(3), jnp.array([1, 2])))) - - CONCRETIZATION_MSG = "Abstract tracer value encountered where concrete value is expected." - with self.assertRaisesRegex(TypeError, CONCRETIZATION_MSG): - # An abstract tracer for idx - jax.jit(lambda idx: jnp.split(jnp.zeros((12, 2)), idx))(2.) - with self.assertRaisesRegex(TypeError, CONCRETIZATION_MSG): - # A list including an abstract tracer - jax.jit(lambda idx: jnp.split(jnp.zeros((12, 2)), [2, idx]))(2.) - - # A concrete tracer -> no error - jax.jvp(lambda idx: jnp.split(jnp.zeros((12, 2)), idx), - (2.,), (1.,)) - # A tuple including a concrete tracer -> no error - jax.jvp(lambda idx: jnp.split(jnp.zeros((12, 2)), (1, idx.astype(np.int32))), - (2.,), (1.,)) - - @jtu.sample_product( - shape=[(5,), (5, 5)], - dtype=number_dtypes, - bins=[10, np.arange(-5, 6), np.array([-5, 0, 3])], - range=[None, (0, 0), (0, 10)], - weights=[True, False], - ) - def testHistogramBinEdges(self, shape, dtype, bins, range, weights): - rng = jtu.rand_default(self.rng()) - _weights = lambda w: abs(w) if weights else None - np_fun = lambda a, w, r: np.histogram_bin_edges(a, bins=bins, range=r, - weights=_weights(w)) - jnp_fun = lambda a, w, r: jnp.histogram_bin_edges(a, bins=bins, range=r, - weights=_weights(w)) - args_maker = lambda: [rng(shape, dtype), rng(shape, dtype), range] - tol = {jnp.bfloat16: 2E-2, np.float16: 1E-2} - # linspace() compares poorly to numpy when using bfloat16 - if dtype != jnp.bfloat16: - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, tol=tol) - self._CompileAndCheck(jnp_fun, args_maker, - atol=tol, rtol=tol) - - @jtu.sample_product( - shape=[(5,), (4, 5)], - dtype=default_dtypes, - # We only test explicit integer-valued bin edges because in other cases - # rounding errors lead to flaky tests. - bins=[np.arange(-5, 6), np.array([-5, 0, 3])], - density=[True, False], - weights=[True, False], - ) - def testHistogram(self, shape, dtype, bins, density, weights): - rng = jtu.rand_default(self.rng()) - _weights = lambda w: abs(w) if weights else None - def np_fun(a, w): - # Numpy can't handle bfloat16 - a = a.astype('float32') if a.dtype == jnp.bfloat16 else a - w = w.astype('float32') if w.dtype == jnp.bfloat16 else w - return np.histogram(a, bins=bins, density=density, weights=_weights(w)) - jnp_fun = lambda a, w: jnp.histogram(a, bins=bins, density=density, - weights=_weights(w)) - args_maker = lambda: [rng(shape, dtype), rng(shape, dtype)] - tol = {jnp.bfloat16: 2E-2, np.float16: 1E-1} - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, - tol=tol) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - shape=[(5,), (12,)], - dtype=int_dtypes, - bins=[2, [2, 2], [np.array([0, 1, 3, 5]), np.array([0, 2, 3, 4, 6])]], - weights=[False, True], - density=[False, True], - range=[None, [(-1, 1), None], [(-1, 1), (-2, 2)]], - ) - def testHistogram2d(self, shape, dtype, bins, weights, density, range): - rng = jtu.rand_default(self.rng()) - _weights = lambda w: abs(w) if weights else None - np_fun = jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*")( - lambda a, b, w: np.histogram2d(a, b, bins=bins, weights=_weights(w), density=density, range=range)) - jnp_fun = lambda a, b, w: jnp.histogram2d(a, b, bins=bins, weights=_weights(w), density=density, range=range) - args_maker = lambda: [rng(shape, dtype), rng(shape, dtype), rng(shape, dtype)] - tol = {jnp.bfloat16: 2E-2, np.float16: 1E-1} - # np.searchsorted errors on bfloat16 with - # "TypeError: invalid type promotion with custom data type" - with np.errstate(divide='ignore', invalid='ignore'): - if dtype != jnp.bfloat16: - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, - tol=tol) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - shape=[(5, 3), (10, 3)], - dtype=int_dtypes, - bins=[(2, 2, 2), [np.array([-5, 0, 4]), np.array([-4, -1, 2]), np.array([-6, -1, 4])]], - weights=[False, True], - density=[False, True], - range=[None, [(-1, 1), None, None], [(-1, 1), (-2, 2), (-3, 3)]], - ) - def testHistogramdd(self, shape, dtype, bins, weights, density, range): - rng = jtu.rand_default(self.rng()) - _weights = lambda w: abs(w) if weights else None - np_fun = jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*")( - lambda a, w: np.histogramdd(a, bins=bins, weights=_weights(w), density=density, range=range)) - jnp_fun = lambda a, w: jnp.histogramdd(a, bins=bins, weights=_weights(w), density=density, range=range) - args_maker = lambda: [rng(shape, dtype), rng((shape[0],), dtype)] - tol = {jnp.bfloat16: 2E-2, np.float16: 1E-1} - # np.searchsorted errors on bfloat16 with - # "TypeError: invalid type promotion with custom data type" - if dtype != jnp.bfloat16: - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, - tol=tol) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(shape=shape, axis=axis, num_sections=num_sections) - for shape, axis, num_sections in [ - ((12, 4), 0, 4), ((12,), 1, 2), - ((2, 3, 4), 2, 2), ((4, 3, 4), 0, 2)]], - dtype=default_dtypes, - ) - def testHVDSplit(self, shape, num_sections, axis, dtype): - rng = jtu.rand_default(self.rng()) - def fn(module, axis): - if axis == 0: - return module.vsplit - elif axis == 1: - return module.hsplit - else: - assert axis == 2 - return module.dsplit - - np_fun = lambda x: fn(np, axis)(x, num_sections) - jnp_fun = lambda x: fn(jnp, axis)(x, num_sections) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(arg_shape=arg_shape, out_shape=out_shape) - for arg_shape, out_shape in [ - (jtu.NUMPY_SCALAR_SHAPE, (1, 1, 1)), - ((), (1, 1, 1)), - ((7, 0), (0, 42, 101)), - ((3, 4), 12), - ((3, 4), (12,)), - ((3, 4), -1), - ((2, 1, 4), (-1,)), - ((2, 2, 4), (2, 8)) - ] - ], - dtype=default_dtypes, - order=["C", "F"], - ) - def testReshape(self, arg_shape, out_shape, dtype, order): - rng = jtu.rand_default(self.rng()) - np_fun = lambda x: np.reshape(x, out_shape, order=order) - jnp_fun = lambda x: jnp.reshape(x, out_shape, order=order) - args_maker = lambda: [rng(arg_shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(arg_shape=arg_shape, out_shape=out_shape) - for arg_shape, out_shape in [ - ((7, 0), (0, 42, 101)), - ((2, 1, 4), (-1,)), - ((2, 2, 4), (2, 8)) - ] - ], - dtype=default_dtypes, - ) - def testReshapeMethod(self, arg_shape, out_shape, dtype): - rng = jtu.rand_default(self.rng()) - np_fun = lambda x: np.reshape(x, out_shape) - jnp_fun = lambda x: x.reshape(*out_shape) - args_maker = lambda: [rng(arg_shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(arg_shape=arg_shape, out_shape=out_shape) - for arg_shape, out_shape in itertools.product(all_shapes, array_shapes)], - dtype=default_dtypes, - ) - def testResize(self, arg_shape, out_shape, dtype): - rng = jtu.rand_default(self.rng()) - np_fun = lambda x: np.resize(x, out_shape) - jnp_fun = lambda x: jnp.resize(x, out_shape) - args_maker = lambda: [rng(arg_shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(arg_shape=arg_shape, dim=dim) - for arg_shape in [(), (3,), (3, 4)] - for dim in (list(range(-len(arg_shape)+1, len(arg_shape))) - + [np.array(0), np.array(-1), (0,), [np.array(0)], - (len(arg_shape), len(arg_shape) + 1)]) - ], - dtype=default_dtypes, - ) - def testExpandDimsStaticDim(self, arg_shape, dtype, dim): - rng = jtu.rand_default(self.rng()) - np_fun = lambda x: np.expand_dims(x, dim) - jnp_fun = lambda x: jnp.expand_dims(x, dim) - args_maker = lambda: [rng(arg_shape, dtype)] - self._CompileAndCheck(jnp_fun, args_maker) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - - def testExpandDimsRepeatedAxisError(self): - x = jnp.ones((2, 3)) - self.assertRaisesRegex( - ValueError, 'repeated axis.*', - lambda: jnp.expand_dims(x, [1, 1])) - self.assertRaisesRegex( - ValueError, 'repeated axis.*', - lambda: jnp.expand_dims(x, [3, -1])) - - # ensure this is numpy's behavior too, so that we remain consistent - x = np.ones((2, 3)) - self.assertRaisesRegex( - ValueError, 'repeated axis.*', - lambda: np.expand_dims(x, [1, 1])) - self.assertRaisesRegex( - ValueError, 'repeated axis.*', - lambda: np.expand_dims(x, [3, -1])) - - @jtu.sample_product( - [dict(arg_shape=arg_shape, ax1=ax1, ax2=ax2) - for arg_shape, ax1, ax2 in [ - ((3, 4), 0, 1), ((3, 4), 1, 0), ((3, 4, 5), 1, 2), - ((3, 4, 5), -1, -2), ((3, 4, 5), 0, 1)] - ], - dtype=default_dtypes, - ) - def testSwapAxesStaticAxes(self, arg_shape, dtype, ax1, ax2): - rng = jtu.rand_default(self.rng()) - np_fun = lambda x: np.swapaxes(x, ax1, ax2) - jnp_fun = lambda x: jnp.swapaxes(x, ax1, ax2) - args_maker = lambda: [rng(arg_shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(arg_shape=arg_shape, ax=ax) - for arg_shape, ax in [ - ((3, 1), None), - ((3, 1), 1), - ((3, 1), -1), - ((3, 1), np.array(1)), - ((1, 3, 1), (0, 2)), - ((1, 3, 1), (0,)), - ((1, 4, 1), (np.array(0),))] - ], - dtype=default_dtypes, - ) - def testSqueeze(self, arg_shape, dtype, ax): - rng = jtu.rand_default(self.rng()) - np_fun = lambda x: np.squeeze(x, ax) - jnp_fun = lambda x: jnp.squeeze(x, ax) - args_maker = lambda: [rng(arg_shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - def testArrayFromMasked(self): - args_maker = lambda: [np.ma.array([1, 2], mask=[True, False])] - # Like np.array, jnp.array strips the mask from masked array inputs. - self._CheckAgainstNumpy(np.array, jnp.array, args_maker) - # Under JIT, masked arrays are flagged as invalid. - with self.assertRaisesRegex(ValueError, "numpy masked arrays are not supported"): - jax.jit(jnp.asarray)(*args_maker()) - - @jtu.sample_product( - [dict(arg=arg, dtype=dtype, ndmin=ndmin) - for arg, dtypes in [ - ([True, False, True], all_dtypes), - (3., all_dtypes), - ([1, 2, 3], all_dtypes), - (np.array([1, 2, 3], dtype=np.int64), all_dtypes), - ([1., 2., 3.], all_dtypes), - ([[1, 2], [3, 4], [5, 6]], all_dtypes), - ([[1, 2.], [3, 4], [5, 6]], all_dtypes), - ([[1., 2j], [3., 4.], [5., 6.]], complex_dtypes), - ([[3, np.array(2, dtype=jnp.float_), 1], - np.arange(3., dtype=jnp.float_)], all_dtypes), - ] - for dtype in [None] + dtypes - for ndmin in [None, np.ndim(arg), np.ndim(arg) + 1, np.ndim(arg) + 2] - ], - ) - def testArray(self, arg, ndmin, dtype): - args_maker = lambda: [arg] - canonical_dtype = dtypes.canonicalize_dtype(dtype or np.array(arg).dtype) - if ndmin is not None: - np_fun = partial(np.array, ndmin=ndmin, dtype=canonical_dtype) - jnp_fun = partial(jnp.array, ndmin=ndmin, dtype=dtype) - else: - np_fun = partial(np.array, dtype=canonical_dtype) - jnp_fun = partial(jnp.array, dtype=dtype) - - # We are testing correct canonicalization behavior here, so we turn off the - # permissive canonicalization logic in the test harness. - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, - canonicalize_dtypes=False) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product(copy=[None, True, False]) - def testAsarrayCopy(self, copy): - x_jax = jnp.arange(4) - x_np = np.arange(4) - x_list = [0, 1, 2, 3] - x_buf = make_python_array('l', x_list) - - func = partial(jnp.asarray, copy=copy) - self.assertArraysEqual(x_jax, func(x_jax)) - self.assertArraysEqual(x_jax, func(x_list), check_dtypes=False) - - if copy is False and jax.default_backend() != 'cpu': - # copy=False is strict: it must raise if the input supports the buffer protocol - # but a copy is still required. - self.assertRaises(ValueError, func, x_np) - self.assertRaises(ValueError, func, x_buf) - else: - self.assertArraysEqual(x_jax, func(x_np), check_dtypes=False) - self.assertArraysEqual(x_jax, func(x_buf), check_dtypes=False) - - @unittest.skip("Jax-metal don't support all dtypes.") - @jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype.*") - def testArrayDtypeInference(self): - def _check(obj, out_dtype, weak_type): - dtype_reference = np.array(obj, dtype=out_dtype) - - out = jnp.array(obj) - self.assertDtypesMatch(out, dtype_reference) - self.assertEqual(dtypes.is_weakly_typed(out), weak_type) - - out_jit = jax.jit(jnp.array)(obj) - self.assertDtypesMatch(out_jit, dtype_reference) - self.assertEqual(dtypes.is_weakly_typed(out_jit), weak_type) - - # Python scalars become 64-bit weak types. - _check(1, np.int64, True) - _check(1.0, np.float64, True) - _check(1.0j, np.complex128, True) - - # Lists become strongly-typed defaults. - _check([1], jnp.int64, False) - _check([1.0], jnp.float64, False) - _check([1.0j], jnp.complex128, False) - - # Lists of weakly-typed objects become strongly-typed defaults. - _check([jnp.array(1)], jnp.int64, False) - _check([jnp.array(1.0)], jnp.float64, False) - _check([jnp.array(1.0j)], jnp.complex128, False) - - # Lists of strongly-typed objects maintain their strong type. - _check([jnp.int64(1)], np.int64, False) - _check([jnp.float64(1)], np.float64, False) - _check([jnp.complex128(1)], np.complex128, False) - - # Mixed inputs use JAX-style promotion. - # (regression test for https://github.com/jax-ml/jax/issues/8945) - _check([0, np.int16(1)], np.int16, False) - _check([0.0, np.float16(1)], np.float16, False) - - @jtu.sample_product( - dtype=all_dtypes, - func=["array", "copy", "copy.copy", "copy.deepcopy"], - ) - def testArrayCopy(self, dtype, func): - x = jnp.ones(10, dtype=dtype) - if func == "copy.deepcopy": - copy_func = copy.deepcopy - elif func == "copy.copy": - copy_func = copy.copy - else: - copy_func = getattr(jnp, func) - - x_view = jnp.asarray(x) - x_view_jit = jax.jit(jnp.asarray)(x) - x_copy = copy_func(x) - x_copy_jit = jax.jit(copy_func)(x) - - _ptr = lambda x: x.unsafe_buffer_pointer() - - self.assertEqual(_ptr(x), _ptr(x_view)) - self.assertNotEqual(_ptr(x), _ptr(x_view_jit)) - self.assertNotEqual(_ptr(x), _ptr(x_copy)) - self.assertNotEqual(_ptr(x), _ptr(x_copy_jit)) - - x.delete() - - self.assertTrue(x_view.is_deleted()) - self.assertFalse(x_view_jit.is_deleted()) - - self.assertFalse(x_copy.is_deleted()) - self.assertFalse(x_copy_jit.is_deleted()) - - def testArrayCopyAutodiff(self): - f = lambda x: jnp.array(x, copy=True) - - x = jnp.ones(10) - xdot = jnp.ones(10) - y, ydot = jax.jvp(f, (x,), (xdot,)) - self.assertIsNot(x, y) - self.assertIsNot(xdot, ydot) - - ybar = jnp.ones(10) - y, f_vjp = jax.vjp(f, x) - xbar, = f_vjp(ybar) - self.assertIsNot(x, y) - self.assertIsNot(xbar, ybar) - - def testArrayCopyVmap(self): - f = lambda x: jnp.array(x, copy=True) - x = jnp.ones(10) - y = jax.vmap(f)(x) - self.assertIsNot(x, y) - - def testArrayUnsupportedDtypeError(self): - with self.assertRaisesRegex(TypeError, - "JAX only supports number and bool dtypes.*"): - jnp.array(3, [('a',' 0.: - return x * 2 - else: - return x + 2 - - self.assertRaises(jax.errors.ConcretizationTypeError, lambda: g(3.)) - - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in [(3,), (2, 3)] - for axis in list(range(-len(shape), len(shape))) + [None] + [tuple(range(len(shape)))] # Test negative axes and tuples - ], - dtype=default_dtypes, - ) - def testFlip(self, shape, dtype, axis): - rng = jtu.rand_default(self.rng()) - args_maker = self._GetArgsMaker(rng, [shape], [dtype]) - jnp_op = lambda x: jnp.flip(x, axis) - np_op = lambda x: np.flip(x, axis) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - @jtu.sample_product( - shape=[(3,), (2, 3), (3, 2, 4)], - dtype=default_dtypes, - ) - def testFlipud(self, shape, dtype): - rng = jtu.rand_default(self.rng()) - args_maker = self._GetArgsMaker(rng, [shape], [dtype]) - jnp_op = lambda x: jnp.flipud(x) - np_op = lambda x: np.flipud(x) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - @jtu.sample_product( - shape=[(3, 2), (2, 3), (3, 2, 4)], - dtype=default_dtypes, - ) - def testFliplr(self, shape, dtype): - rng = jtu.rand_default(self.rng()) - args_maker = self._GetArgsMaker(rng, [shape], [dtype]) - jnp_op = lambda x: jnp.fliplr(x) - np_op = lambda x: np.fliplr(x) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - @jtu.sample_product( - [dict(shape=shape, axes=axes) - for shape, axes in [ - [(2, 3), (0, 1)], - [(2, 3), (1, 0)], - [(4, 3, 2), (0, 2)], - [(4, 3, 2), (2, 1)], - ] - ], - k=range(-3, 4), - dtype=default_dtypes, - ) - def testRot90(self, shape, dtype, k, axes): - rng = jtu.rand_default(self.rng()) - args_maker = self._GetArgsMaker(rng, [shape], [dtype]) - jnp_op = lambda x: jnp.rot90(x, k, axes) - np_op = lambda x: np.rot90(x, k, axes) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - # TODO(mattjj): test infix operator overrides - - def testRavel(self): - rng = self.rng() - args_maker = lambda: [rng.randn(3, 4).astype("float32")] - self._CompileAndCheck(lambda x: x.ravel(), args_maker) - - @jtu.sample_product( - shape=nonempty_nonscalar_array_shapes, - order=['C', 'F'], - mode=['wrap', 'clip', 'raise'], - ) - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - def testRavelMultiIndex(self, shape, order, mode): - # generate indices in each dimension with a few out of bounds. - rngs = [jtu.rand_int(self.rng(), low=-1, high=dim + 1) - for dim in shape] - # generate multi_indices of different dimensions that broadcast. - args_maker = lambda: [tuple(rng(ndim * (3,), jnp.int_) - for ndim, rng in enumerate(rngs))] - def np_fun(x): - try: - return np.ravel_multi_index(x, shape, order=order, mode=mode) - except ValueError as err: - if str(err).startswith('invalid entry'): - # sentinel indicating expected error. - return -999 - else: - raise - def jnp_fun(x): - try: - return jnp.ravel_multi_index(x, shape, order=order, mode=mode) - except ValueError as err: - if str(err).startswith('invalid entry'): - # sentinel indicating expected error. - return -999 - else: - raise - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) - if mode == 'raise': - msg = ("The error occurred because ravel_multi_index was jit-compiled " - "with mode='raise'. Use mode='wrap' or mode='clip' instead.") - with self.assertRaisesRegex(core.ConcretizationTypeError, msg): - jax.jit(jnp_fun)(*args_maker()) - else: - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - ashape=((), (4,), (3, 4)), - cshapes=[ - [(), (4,)], - [(3, 4), (4,), (3, 1)] - ], - adtype=int_dtypes, - cdtype=default_dtypes, - mode=['wrap', 'clip', 'raise'], - ) - def testChoose(self, ashape, adtype, cshapes, cdtype, mode): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(ashape, adtype), [rng(s, cdtype) for s in cshapes]] - def np_fun(a, c): - try: - return np.choose(a, c, mode=mode) - except ValueError as err: - if mode == 'raise' and str(err).startswith('invalid entry'): - return -999 # sentinel indicating expected error. - else: - raise - def jnp_fun(a, c): - try: - return jnp.choose(a, c, mode=mode) - except ValueError as err: - if mode == 'raise' and str(err).startswith('invalid entry'): - return -999 # sentinel indicating expected error. - else: - raise - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) - if mode == 'raise': - msg = ("The error occurred because jnp.choose was jit-compiled" - " with mode='raise'. Use mode='wrap' or mode='clip' instead.") - with self.assertRaisesRegex(core.ConcretizationTypeError, msg): - jax.jit(jnp_fun)(*args_maker()) - else: - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - shape=nonempty_nonscalar_array_shapes, - dtype=int_dtypes, - idx_shape=all_shapes, - ) - def testUnravelIndex(self, shape, idx_shape, dtype): - size = math.prod(shape) - rng = jtu.rand_int(self.rng(), low=-((2 * size) // 3), high=(2 * size) // 3) - - def np_fun(index, shape): - # JAX's version outputs the same dtype as the input in the typical case - # where shape is weakly-typed. - out_dtype = index.dtype - # Adjust out-of-bounds behavior to match jax's documented behavior. - index = np.clip(index, -size, size - 1) - index = np.where(index < 0, index + size, index) - return [i.astype(out_dtype) for i in np.unravel_index(index, shape)] - - jnp_fun = jnp.unravel_index - args_maker = lambda: [rng(idx_shape, dtype), shape] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - from_dtype=['int32', 'float32'], - to_dtype=['int32', 'float32', None], - use_method=[True, False], - ) - def testAstype(self, from_dtype, to_dtype, use_method): - rng = self.rng() - args_maker = lambda: [rng.randn(3, 4).astype(from_dtype)] - if (not use_method) and hasattr(np, "astype"): # Added in numpy 2.0 - np_op = lambda x: np.astype(x, to_dtype) - else: - np_op = lambda x: np.asarray(x).astype(to_dtype) - if use_method: - jnp_op = lambda x: jnp.asarray(x).astype(to_dtype) - else: - jnp_op = lambda x: jnp.astype(x, to_dtype) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - @unittest.skip("Jax-metal don't support all dtypes") - def testAstypeInt4(self): - # Test converting from int4 to int8 - x = np.array([1, -2, -3, 4, -8, 7], dtype=jnp.int4) - args_maker = lambda: [x] - np_op = lambda x: np.asarray(x).astype(jnp.int8) - jnp_op = lambda x: jnp.asarray(x).astype(jnp.int8) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - # Test converting from int8 to int4 - x = np.array([1, -2, -3, 4, -8, 7], dtype=jnp.int8) - args_maker = lambda: [x] - np_op = lambda x: np.asarray(x).astype(jnp.int4) - jnp_op = lambda x: jnp.asarray(x).astype(jnp.int4) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - @jtu.sample_product( - shape=array_shapes, - dtype=all_dtypes, - ) - def testNbytes(self, shape, dtype): - rng = jtu.rand_default(self.rng()) - np_op = lambda x: np.asarray(x).nbytes - jnp_op = lambda x: jnp.asarray(x).nbytes - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - @jtu.sample_product( - shape=array_shapes, - dtype=all_dtypes, - ) - def testItemsize(self, shape, dtype): - rng = jtu.rand_default(self.rng()) - np_op = lambda x: np.asarray(x).itemsize - jnp_op = lambda x: jnp.asarray(x).itemsize - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - @jtu.sample_product( - shape=nonempty_array_shapes, - dtype=all_dtypes, - num_args=[0, 1, "all"], - use_tuple=[True, False] - ) - def testItem(self, shape, dtype, num_args, use_tuple): - rng = jtu.rand_default(self.rng()) - size = math.prod(shape) - - if num_args == 0: - args = () - elif num_args == 1: - args = (self.rng().randint(0, size),) - else: - args = tuple(self.rng().randint(0, s) for s in shape) - args = (args,) if use_tuple else args - - np_op = lambda x: np.asarray(x).item(*args) - jnp_op = lambda x: jnp.asarray(x).item(*args) - args_maker = lambda: [rng(shape, dtype)] - - if size != 1 and num_args == 0: - with self.assertRaises(ValueError): - jnp_op(*args_maker()) - else: - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - - @jtu.sample_product( - # Final dimension must be a multiple of 16 to ensure compatibilty of all dtype pairs. - shape=[(0,), (32,), (2, 16)], - a_dtype=all_dtypes, - dtype=(*all_dtypes, None) if config.enable_x64.value else all_dtypes, - ) - def testView(self, shape, a_dtype, dtype): - if jtu.test_device_matches(["tpu"]): - if jnp.dtype(a_dtype).itemsize in [1, 2] or jnp.dtype(dtype).itemsize in [1, 2]: - self.skipTest("arr.view() not supported on TPU for 8- or 16-bit types.") - # It is possible to fill bool arrays with arbitrary bits (not just 0/1 - # bytes), but the behavior is implementation-defined. We therefore only test - # the well-defined case. - rng = (jtu.rand_bool if a_dtype == np.bool_ else jtu.rand_fullrange)( - self.rng() - ) - args_maker = lambda: [rng(shape, a_dtype)] - np_op = lambda x: np.asarray(x).view(dtype) - jnp_op = lambda x: jnp.asarray(x).view(dtype) - # Above may produce signaling nans; ignore warnings from invalid values. - with np.errstate(invalid='ignore'): - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - @jtu.sample_product([ - {'a_dtype': a_dtype, 'dtype': dtype} - for a_dtype in all_dtypes - for dtype in all_dtypes - if np.dtype(a_dtype).itemsize == np.dtype(dtype).itemsize - ]) - def testViewScalar(self, a_dtype, dtype): - if jtu.test_device_matches(["tpu"]): - if jnp.dtype(a_dtype).itemsize in [1, 2] or jnp.dtype(dtype).itemsize in [1, 2]: - self.skipTest("arr.view() not supported on TPU for 8- or 16-bit types.") - rng = jtu.rand_fullrange(self.rng()) - args_maker = lambda: [jnp.array(rng((), a_dtype))] - np_op = lambda x: np.asarray(x).view(dtype) - jnp_op = lambda x: jnp.asarray(x).view(dtype) - # Above may produce signaling nans; ignore warnings from invalid values. - with np.errstate(invalid='ignore'): - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - def testPathologicalFloats(self): - args_maker = lambda: [np.array([ - 0b_0111_1111_1000_0000_0000_0000_0000_0000, # inf - 0b_1111_1111_1000_0000_0000_0000_0000_0000, # -inf - 0b_0111_1111_1100_0000_0000_0000_0000_0000, # qnan - 0b_1111_1111_1100_0000_0000_0000_0000_0000, # -qnan - 0b_0111_1111_1000_0000_0000_0000_0000_0001, # snan - 0b_1111_1111_1000_0000_0000_0000_0000_0001, # -snan - 0b_0111_1111_1000_0000_0000_1100_0000_0000, # nonstandard nan - 0b_1111_1111_1000_0000_0000_1100_0000_0000, # -nonstandard nan - 0b_0000_0000_0000_0000_0000_0000_0000_0000, # zero - 0b_1000_0000_0000_0000_0000_0000_0000_0000, # -zero - ], dtype='uint32')] - - np_op = lambda x: np.asarray(x).view('float32').view('uint32') - jnp_op = lambda x: jnp.asarray(x).view('float32').view('uint32') - - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - # TODO(mattjj): test other ndarray-like method overrides - - def testNpMean(self): - # from https://github.com/jax-ml/jax/issues/125 - x = jnp.eye(3, dtype=float) + 0. - ans = np.mean(x) - self.assertAllClose(ans, np.array(1./3), check_dtypes=False) - - def testArangeOnFloats(self): - np_arange = jtu.with_jax_dtype_defaults(np.arange) - # from https://github.com/jax-ml/jax/issues/145 - self.assertAllClose(np_arange(0.0, 1.0, 0.1), - jnp.arange(0.0, 1.0, 0.1)) - # from https://github.com/jax-ml/jax/issues/3450 - self.assertAllClose(np_arange(2.5), - jnp.arange(2.5)) - self.assertAllClose(np_arange(0., 2.5), - jnp.arange(0., 2.5)) - - def testArangeTypes(self): - # Test that arange() output type is equal to the default types. - int_ = dtypes.canonicalize_dtype(jnp.int_) - float_ = dtypes.canonicalize_dtype(jnp.float_) - - self.assertEqual(jnp.arange(10).dtype, int_) - self.assertEqual(jnp.arange(10.).dtype, float_) - self.assertEqual(jnp.arange(10, dtype='uint16').dtype, np.uint16) - #self.assertEqual(jnp.arange(10, dtype='bfloat16').dtype, jnp.bfloat16) - - self.assertEqual(jnp.arange(0, 10, 1).dtype, int_) - with jax.numpy_dtype_promotion('standard'): - self.assertEqual(jnp.arange(0, 10, 1.).dtype, float_) - self.assertEqual(jnp.arange(0., 10, 1).dtype, float_) - - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in nonzerodim_shapes - for axis in (NO_VALUE, None, *range(-len(shape), len(shape))) - ], - stable=[True, False], - dtype=all_dtypes, - ) - def testSort(self, dtype, shape, axis, stable): - rng = jtu.rand_some_equal(self.rng()) if stable else jtu.rand_some_inf_and_nan(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - kwds = {} if axis is NO_VALUE else {'axis': axis} - - def np_fun(arr): - # Note: numpy sort fails on NaN and Inf values with bfloat16 - dtype = arr.dtype - if arr.dtype == jnp.bfloat16: - arr = arr.astype('float32') - # TODO(jakevdp): switch to stable=stable when supported by numpy. - result = np.sort(arr, kind='stable' if stable else None, **kwds) - with jtu.ignore_warning(category=RuntimeWarning, message='invalid value'): - return result.astype(dtype) - jnp_fun = partial(jnp.sort, stable=stable, **kwds) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - def testSortStableDescending(self): - # TODO(jakevdp): test directly against np.sort when descending is supported. - x = jnp.array([0, 1, jnp.nan, 0, 2, jnp.nan, -jnp.inf, jnp.inf]) - x_sorted = jnp.array([-jnp.inf, 0, 0, 1, 2, jnp.inf, jnp.nan, jnp.nan]) - argsorted_stable = jnp.array([6, 0, 3, 1, 4, 7, 2, 5]) - argsorted_rev_stable = jnp.array([2, 5, 7, 4, 1, 0, 3, 6]) - - self.assertArraysEqual(jnp.sort(x), x_sorted) - self.assertArraysEqual(jnp.sort(x, descending=True), lax.rev(x_sorted, [0])) - self.assertArraysEqual(jnp.argsort(x), argsorted_stable) - self.assertArraysEqual(jnp.argsort(x, descending=True), argsorted_rev_stable) - - @unittest.skip("Jax-metal don't support complex.") - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in one_dim_array_shapes - for axis in [None] - ], - dtype=all_dtypes, - ) - def testSortComplex(self, dtype, shape, axis): - rng = jtu.rand_some_equal(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np.sort_complex, jnp.sort_complex, args_maker, - check_dtypes=False) - self._CompileAndCheck(jnp.sort_complex, args_maker) - - @unittest.skip("Jax-metal fail to convert sort op.") - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in nonempty_nonscalar_array_shapes - for axis in (-1, *range(len(shape) - 1)) - ], - dtype=all_dtypes, - input_type=[np.array, tuple], - ) - def testLexsort(self, dtype, shape, input_type, axis): - rng = jtu.rand_some_equal(self.rng()) - args_maker = lambda: [input_type(rng(shape, dtype))] - jnp_op = lambda x: jnp.lexsort(x, axis=axis) - np_op = jtu.with_jax_dtype_defaults(lambda x: np.lexsort(x, axis=axis)) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - @unittest.skip("JAX-metal crash.") - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in nonzerodim_shapes - for axis in (NO_VALUE, None, *range(-len(shape), len(shape))) - ], - dtype=all_dtypes, - ) - def testArgsort(self, dtype, shape, axis): - rng = jtu.rand_some_equal(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - kwds = {} if axis is NO_VALUE else {'axis': axis} - - @jtu.with_jax_dtype_defaults - def np_fun(arr): - # Note: numpy sort fails on NaN and Inf values with bfloat16 - if arr.dtype == jnp.bfloat16: - arr = arr.astype('float32') - # TODO(jakevdp): switch to stable=True when supported by numpy. - return np.argsort(arr, kind='stable', **kwds) - jnp_fun = partial(jnp.argsort, stable=True, **kwds) - - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @unittest.skip("JAX-metal crash.") - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in nonempty_nonscalar_array_shapes - for axis in (NO_VALUE, None, *range(-len(shape), len(shape))) - ], - descending=[True, False], - dtype=all_dtypes, - ) - def testArgsortUnstable(self, dtype, shape, axis, descending): - # We cannot directly compare unstable argsorts, so instead check that indexed values match. - rng = jtu.rand_some_equal(self.rng()) - x = rng(shape, dtype) - kwds = {} if axis is NO_VALUE else {'axis': axis} - expected = jnp.sort(x, descending=descending, stable=False, **kwds) - indices = jnp.argsort(x, descending=descending, stable=False, **kwds) - if axis is None: - actual = jnp.ravel(x)[indices] - else: - actual = jnp.take_along_axis(x, indices, axis=-1 if axis is NO_VALUE else axis) - self.assertArraysEqual(actual, expected) - - @jtu.sample_product( - [{'shape': shape, 'axis': axis, 'kth': kth} - for shape in nonzerodim_shapes - for axis in range(-len(shape), len(shape)) - for kth in range(-shape[axis], shape[axis])], - dtype=default_dtypes, - ) - def testPartition(self, shape, dtype, axis, kth): - rng = jtu.rand_default(self.rng()) - arg = rng(shape, dtype) - jnp_output = jnp.partition(arg, axis=axis, kth=kth) - np_output = np.partition(arg, axis=axis, kth=kth) - - # Assert that pivot point is equal: - self.assertArraysEqual( - lax.index_in_dim(jnp_output, axis=axis, index=kth), - lax.index_in_dim(np_output, axis=axis, index=kth)) - - # Assert remaining values are correctly partitioned: - self.assertArraysEqual( - lax.sort(lax.slice_in_dim(jnp_output, start_index=0, limit_index=kth, axis=axis), dimension=axis), - lax.sort(lax.slice_in_dim(np_output, start_index=0, limit_index=kth, axis=axis), dimension=axis)) - self.assertArraysEqual( - lax.sort(lax.slice_in_dim(jnp_output, start_index=kth + 1, limit_index=shape[axis], axis=axis), dimension=axis), - lax.sort(lax.slice_in_dim(np_output, start_index=kth + 1, limit_index=shape[axis], axis=axis), dimension=axis)) - - #@unittest.skipIf(jtu.device_under_test=="METAL", "Jax-metal fail on empty dim shape.") - @jtu.sample_product( - [{'shape': shape, 'axis': axis, 'kth': kth} - for shape in nonempty_shapes# nonzerodim_shapes - for axis in range(-len(shape), len(shape)) - for kth in range(-shape[axis], shape[axis])], - dtype=default_dtypes, - ) - def testArgpartition(self, shape, dtype, axis, kth): - rng = jtu.rand_default(self.rng()) - arg = rng(shape, dtype) - - jnp_output = jnp.argpartition(arg, axis=axis, kth=kth) - np_output = np.argpartition(arg, axis=axis, kth=kth) - - # Assert that all indices are present - self.assertArraysEqual(jnp.sort(jnp_output, axis), np.sort(np_output, axis), check_dtypes=False) - - # Because JAX & numpy may treat duplicates differently, we must compare values - # rather than indices. - getvals = lambda x, ind: x[ind] - for ax in range(arg.ndim): - if ax != range(arg.ndim)[axis]: - getvals = jax.vmap(getvals, in_axes=ax, out_axes=ax) - jnp_values = getvals(arg, jnp_output) - np_values = getvals(arg, np_output) - - # Assert that pivot point is equal: - self.assertArraysEqual( - lax.index_in_dim(jnp_values, axis=axis, index=kth), - lax.index_in_dim(np_values, axis=axis, index=kth)) - - # Assert remaining values are correctly partitioned: - self.assertArraysEqual( - lax.sort(lax.slice_in_dim(jnp_values, start_index=0, limit_index=kth, axis=axis), dimension=axis), - lax.sort(lax.slice_in_dim(np_values, start_index=0, limit_index=kth, axis=axis), dimension=axis)) - self.assertArraysEqual( - lax.sort(lax.slice_in_dim(jnp_values, start_index=kth + 1, limit_index=shape[axis], axis=axis), dimension=axis), - lax.sort(lax.slice_in_dim(np_values, start_index=kth + 1, limit_index=shape[axis], axis=axis), dimension=axis)) - - @jtu.sample_product( - [dict(shifts=shifts, axis=axis) - for shifts, axis in [ - (3, None), - (1, 1), - ((3,), (0,)), - ((-2,), (-2,)), - ((1, 2), (0, -1)), - ((4, 2, 5, 5, 2, 4), None), - (100, None), - ] - ], - dtype=all_dtypes, - shape=[(3, 4), (3, 4, 5), (7, 4, 0)], - ) - def testRoll(self, shape, dtype, shifts, axis): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype), np.array(shifts)] - jnp_op = partial(jnp.roll, axis=axis) - np_op = partial(np.roll, axis=axis) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - @jtu.sample_product( - dtype=all_dtypes, - shape=[(1, 2, 3, 4)], - axis=[-3, 0, 2, 3], - start=[-4, -1, 2, 4], - ) - def testRollaxis(self, shape, dtype, start, axis): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - jnp_op = partial(jnp.rollaxis, axis=axis, start=start) - np_op = partial(np.rollaxis, axis=axis, start=start) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - @unittest.skip("jax-metal generates a different result from cpu.") - @jtu.sample_product( - dtype=[np.uint8, np.bool_], - bitorder=['big', 'little'], - shape=[(1, 2, 3, 4)], - axis=[None, 0, 1, -2, -1], - ) - def testPackbits(self, shape, dtype, axis, bitorder): - rng = jtu.rand_some_zero(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - jnp_op = partial(jnp.packbits, axis=axis, bitorder=bitorder) - np_op = partial(np.packbits, axis=axis, bitorder=bitorder) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - @jtu.sample_product( - dtype=[np.uint8], - bitorder=['big', 'little'], - shape=[(1, 2, 3, 4)], - axis=[None, 0, 1, -2, -1], - count=[None, 20], - ) - def testUnpackbits(self, shape, dtype, axis, bitorder, count): - rng = jtu.rand_int(self.rng(), 0, 256) - args_maker = lambda: [rng(shape, dtype)] - jnp_op = partial(jnp.unpackbits, axis=axis, bitorder=bitorder, count=count) - np_op = partial(np.unpackbits, axis=axis, bitorder=bitorder, count=count) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - #@unittest.skip("jax-metal generates a different result from cpu.") - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in [(3,), (3, 4), (3, 4, 5)] - for axis in itertools.chain(range(-len(shape), len(shape)), - [cast(Union[int, None], None)]) - ], - index_shape=scalar_shapes + [(3,), (2, 1, 3)], - dtype=all_dtypes, - index_dtype=int_dtypes, - #mode=[None, 'wrap', 'clip'], - mode=[None, 'wrap'], - ) - def testTake(self, shape, dtype, index_shape, index_dtype, axis, mode): - def args_maker(): - x = rng(shape, dtype) - i = rng_indices(index_shape, index_dtype) - return x, i - - rng = jtu.rand_default(self.rng()) - if mode is None: - rng_indices = jtu.rand_int(self.rng(), -shape[axis or 0], shape[axis or 0]) - else: - rng_indices = jtu.rand_int(self.rng(), -5, 5) - jnp_op = lambda x, i: jnp.take(x, i, axis=axis, mode=mode) - np_op = lambda x, i: np.take(x, i, axis=axis, mode=mode) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - def testTakeEmpty(self): - np.testing.assert_array_equal( - jnp.array([], dtype=jnp.float32), - jnp.take(jnp.array([], jnp.float32), jnp.array([], jnp.int32))) - - np.testing.assert_array_equal( - jnp.ones((2, 0, 4), dtype=jnp.float32), - jnp.take(jnp.ones((2, 0, 4), dtype=jnp.float32), jnp.array([], jnp.int32), - axis=1)) - - with self.assertRaisesRegex(IndexError, "non-empty jnp.take"): - jnp.take(jnp.ones((2, 0, 4), dtype=jnp.float32), - jnp.array([0], jnp.int32), axis=1) - - def testTakeOptionalArgs(self): - x = jnp.arange(5.0) - ind = jnp.array([0, 2, 4, 6]) - expected = jnp.array([0.0, 2.0, 4.0, 10.0], dtype=x.dtype) - actual = jnp.take(x, ind, unique_indices=True, - indices_are_sorted=True, fill_value=10.0) - self.assertArraysEqual(expected, actual) - - @jtu.sample_product( - [dict(x_shape=x_shape, i_shape=i_shape, axis=axis) - for x_shape, i_shape in filter( - _shapes_are_equal_length, - filter(_shapes_are_broadcast_compatible, - itertools.combinations_with_replacement(nonempty_nonscalar_array_shapes, 2))) - for axis in itertools.chain(range(len(x_shape)), [-1], - [cast(Union[int, None], None)]) - ], - dtype=default_dtypes, - index_dtype=int_dtypes, - ) - def testTakeAlongAxis(self, x_shape, i_shape, dtype, index_dtype, axis): - rng = jtu.rand_default(self.rng()) - - i_shape = list(i_shape) - if axis is None: - i_shape = [math.prod(i_shape)] - else: - # Test the case where the size of the axis doesn't necessarily broadcast. - i_shape[axis] *= 3 - def args_maker(): - x = rng(x_shape, dtype) - n = math.prod(x_shape) if axis is None else x_shape[axis] - if np.issubdtype(index_dtype, np.unsignedinteger): - index_rng = jtu.rand_int(self.rng(), 0, n) - else: - index_rng = jtu.rand_int(self.rng(), -n, n) - i = index_rng(i_shape, index_dtype) - return x, i - - jnp_op = lambda x, i: jnp.take_along_axis(x, i, axis=axis) - - if hasattr(np, "take_along_axis"): - np_op = lambda x, i: np.take_along_axis(x, i, axis=axis) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - def testTakeAlongAxisWithUint8IndicesDoesNotOverflow(self): - # https://github.com/jax-ml/jax/issues/5088 - h = jtu.rand_default(self.rng())((256, 256, 100), np.float32) - g = jtu.rand_int(self.rng(), 0, 100)((256, 256, 1), np.uint8) - q0 = jnp.take_along_axis(h, g, axis=-1) - q1 = np.take_along_axis( h, g, axis=-1) - np.testing.assert_equal(q0, q1) - - @unittest.skip("Jax-metal fail.") - def testTakeAlongAxisOutOfBounds(self): - x = jnp.arange(10, dtype=jnp.float32) - idx = jnp.array([-11, -10, -9, -5, -1, 0, 1, 5, 9, 10, 11]) - out = jnp.take_along_axis(x, idx, axis=0) - expected_fill = np.array([jnp.nan, 0, 1, 5, 9, 0, 1, 5, 9, jnp.nan, - jnp.nan], np.float32) - np.testing.assert_array_equal(expected_fill, out) - out = jnp.take_along_axis(x, idx, axis=0, mode="fill") - np.testing.assert_array_equal(expected_fill, out) - - expected_clip = np.array([0, 0, 1, 5, 9, 0, 1, 5, 9, 9, 9], np.float32) - out = jnp.take_along_axis(x, idx, axis=0, mode="clip") - np.testing.assert_array_equal(expected_clip, out) - - def testTakeAlongAxisRequiresIntIndices(self): - x = jnp.arange(5) - idx = jnp.array([3.], jnp.float32) - with self.assertRaisesRegex( - TypeError, - "take_along_axis indices must be of integer type, got float32"): - jnp.take_along_axis(x, idx, axis=0) - - def testTakeAlongAxisWithEmptyArgs(self): - # take_along_axis should allow us to gather an empty list of indices from - # an empty input axis without raising a shape error. - x = jnp.ones((4, 0, 3), dtype=jnp.int32) - np.testing.assert_array_equal(x, jnp.take_along_axis(x, x, axis=1)) - - @jtu.sample_product( - dtype=inexact_dtypes, - shape=[0, 5], - n=[2, 4], - increasing=[False, True], - ) - def testVander(self, shape, dtype, n, increasing): - rng = jtu.rand_default(self.rng()) - def np_fun(arg): - arg = arg.astype(np.float32) if dtype == jnp.bfloat16 else arg - return np.vander(arg, N=n, increasing=increasing) - jnp_fun = lambda arg: jnp.vander(arg, N=n, increasing=increasing) - args_maker = lambda: [rng([shape], dtype)] - # np.vander seems to return float64 for all floating types. We could obey - # those semantics, but they seem like a bug. - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, - tol={np.float32: 1e-3, np.complex64: 1e-3}) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=False) - - @jtu.sample_product( - shape=array_shapes, - dtype=all_dtypes, - ) - def testNanToNum(self, shape, dtype): - rng = jtu.rand_some_inf_and_nan(self.rng()) - dtype = np.dtype(dtypes.canonicalize_dtype(dtype)).type - def np_fun(x): - if dtype == jnp.bfloat16: - x = np.where(np.isnan(x), dtype(0), x) - x = np.where(np.isposinf(x), jnp.finfo(dtype).max, x) - x = np.where(np.isneginf(x), jnp.finfo(dtype).min, x) - return x - else: - return np.nan_to_num(x).astype(dtype) - - args_maker = lambda: [rng(shape, dtype)] - check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE - self._CheckAgainstNumpy(np_fun, jnp.nan_to_num, args_maker, - check_dtypes=check_dtypes) - self._CompileAndCheck(jnp.nan_to_num, args_maker, - check_dtypes=check_dtypes) - - @jtu.sample_product( - [dict(shapes=shapes, dtypes=dtypes) - for shapes, dtypes in ( - ((), ()), - (((7,),), (np.int32,)), - (((3,), (4,)), (np.int32, np.int32)), - (((3,), (1,), (4,)), (np.int32, np.int32, np.int32)), - ) - ], - ) - def testIx_(self, shapes, dtypes): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype) - for shape, dtype in zip(shapes, dtypes)] - self._CheckAgainstNumpy(np.ix_, jnp.ix_, args_maker) - self._CompileAndCheck(jnp.ix_, args_maker) - - @jtu.sample_product( - dimensions=[(), (2,), (3, 0), (4, 5, 6)], - dtype=number_dtypes, - sparse=[True, False], - ) - def testIndices(self, dimensions, dtype, sparse): - def args_maker(): return [] - np_fun = partial(np.indices, dimensions=dimensions, - dtype=dtype, sparse=sparse) - jnp_fun = partial(jnp.indices, dimensions=dimensions, - dtype=dtype, sparse=sparse) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - shape=nonzerodim_shapes, dtype=all_dtypes, - ) - def testWhereOneArgument(self, shape, dtype): - rng = jtu.rand_some_zero(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np.where, jnp.where, args_maker, check_dtypes=False) - - # JIT compilation requires specifying a size statically. Full test of - # this behavior is in testNonzeroSize(). - jnp_fun = lambda x: jnp.where(x, size=np.size(x) // 2) - - with jtu.ignore_warning(category=DeprecationWarning, - message="Calling nonzero on 0d arrays.*"): - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - shapes=filter(_shapes_are_broadcast_compatible, - itertools.combinations_with_replacement(all_shapes, 3)), - dtypes=itertools.combinations_with_replacement(all_dtypes, 3), - ) - def testWhereThreeArgument(self, shapes, dtypes): - rng = jtu.rand_default(self.rng()) - args_maker = self._GetArgsMaker(rng, shapes, dtypes) - def np_fun(cond, x, y): - return jtu.promote_like_jnp(partial(np.where, cond))(x, y) - with jtu.strict_promotion_if_dtypes_match(dtypes): - self._CheckAgainstNumpy(np_fun, jnp.where, args_maker) - self._CompileAndCheck(jnp.where, args_maker) - - def testWhereExtraCode(self): - def f(x): - return jnp.where(x > 0, x, -x) - - # Test no comparison literal True/False in jaxpr, and hence no comparison to - # literals - jaxpr = jax.make_jaxpr(jax.grad(f))(3.) - self.assertNotIn('False', str(jaxpr)) - self.assertNotIn('True', str(jaxpr)) - - def testWhereScalarPromotion(self): - x = jnp.where(jnp.array([True, False]), 3, - jnp.ones((2,), dtype=jnp.float32)) - self.assertEqual(x.dtype, np.dtype(np.float32)) - - @jtu.sample_product( - [dict(n=n, shapes=shapes) - for n in range(1, 3) - for shapes in filter( - _shapes_are_broadcast_compatible, - itertools.combinations_with_replacement(all_shapes, 2 * n + 1)) - ], - # To avoid forming the full product of shapes and dtypes we always sample - # maximal set of dtypes. - dtypes=itertools.combinations_with_replacement(all_dtypes, 3), - ) - def testSelect(self, n, shapes, dtypes): - dtypes = dtypes[:n+1] - rng = jtu.rand_default(self.rng()) - n = len(dtypes) - 1 - def args_maker(): - condlist = [rng(shape, np.bool_) for shape in shapes[:n]] - choicelist = [rng(shape, dtype) - for shape, dtype in zip(shapes[n:-1], dtypes[:n])] - default = rng(shapes[-1], dtypes[-1]) - return condlist, choicelist, default - # TODO(phawkins): float32/float64 type mismatches - @jax.numpy_dtype_promotion('standard') - def np_fun(condlist, choicelist, default): - choicelist = [x if jnp.result_type(x) != jnp.bfloat16 - else x.astype(np.float32) for x in choicelist] - dtype = jnp.result_type(default, *choicelist) - return np.select(condlist, - [np.asarray(x, dtype=dtype) for x in choicelist], - np.asarray(default, dtype=dtype)) - with jtu.strict_promotion_if_dtypes_match(dtypes): - self._CheckAgainstNumpy(np_fun, jnp.select, args_maker, - check_dtypes=False) - self._CompileAndCheck(jnp.select, args_maker, - rtol={np.float64: 1e-7, np.complex128: 1e-7}) - - def testIssue330(self): - x = jnp.full((1, 1), jnp.array([1])[0]) # doesn't crash - self.assertEqual(x[0, 0], 1) - - def testScalarDtypePromotion(self): - orig_numpy_result = (1 + np.eye(1, dtype=np.float32)).dtype - jax_numpy_result = (1 + jnp.eye(1, dtype=jnp.float32)).dtype - self.assertEqual(orig_numpy_result, jax_numpy_result) - - def testSymmetrizeDtypePromotion(self): - x = np.eye(3, dtype=np.float32) - orig_numpy_result = ((x + x.T) / 2).dtype - - x = jnp.eye(3, dtype=jnp.float32) - jax_numpy_result = ((x + x.T) / 2).dtype - self.assertEqual(orig_numpy_result, jax_numpy_result) - - # NOTE(mattjj): I disabled this test when removing lax._safe_mul because - # introducing the convention 0 * inf = 0 leads to silently wrong results in - # some cases. See this comment for details: - # https://github.com/jax-ml/jax/issues/1052#issuecomment-514083352 - # def testIssue347(self): - # # https://github.com/jax-ml/jax/issues/347 - # def test_fail(x): - # x = jnp.sqrt(jnp.sum(x ** 2, axis=1)) - # ones = jnp.ones_like(x) - # x = jnp.where(x > 0.5, x, ones) - # return jnp.sum(x) - # x = jnp.array([[1, 2], [3, 4], [0, 0]], dtype=jnp.float64) - # result = jax.grad(test_fail)(x) - # assert not np.any(np.isnan(result)) - - def testIssue453(self): - # https://github.com/jax-ml/jax/issues/453 - a = np.arange(6) + 1 - ans = jnp.reshape(a, (3, 2), order='F') - expected = np.reshape(a, (3, 2), order='F') - self.assertAllClose(ans, expected) - - @jtu.sample_product( - #dtype=[int, float, bool, complex], - dtype=[int, float, bool], - op=["atleast_1d", "atleast_2d", "atleast_3d"], - ) - def testAtLeastNdLiterals(self, dtype, op): - # Fixes: https://github.com/jax-ml/jax/issues/634 - np_fun = lambda arg: getattr(np, op)(arg).astype(dtypes.python_scalar_dtypes[dtype]) - jnp_fun = lambda arg: getattr(jnp, op)(arg) - args_maker = lambda: [dtype(2)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - shape=[(0,), (5,), (10,)], - dtype=int_dtypes, - weights=[True, False], - minlength=[0, 20], - length=[None, 8], - ) - def testBincount(self, shape, dtype, weights, minlength, length): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: (rng(shape, dtype), (rng(shape, 'float32') if weights else None)) - - def np_fun(x, *args): - x = np.clip(x, 0, None) # jnp.bincount clips negative values to zero. - out = np.bincount(x, *args, minlength=minlength) - if length and length > out.size: - return np.pad(out, (0, length - out.size)) - return out[:length] - jnp_fun = partial(jnp.bincount, minlength=minlength, length=length) - - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) - if length is not None: - self._CompileAndCheck(jnp_fun, args_maker) - - def testBincountNegative(self): - # Test that jnp.bincount ignores negative values. - x_rng = jtu.rand_int(self.rng(), -100, 100) - w_rng = jtu.rand_uniform(self.rng()) - shape = (1000,) - x = x_rng(shape, 'int32') - w = w_rng(shape, 'float32') - - xn = np.array(x) - xn[xn < 0] = 0 - wn = np.array(w) - np_result = np.bincount(xn[xn >= 0], wn[xn >= 0]) - jnp_result = jnp.bincount(x, w) - self.assertAllClose(np_result, jnp_result, check_dtypes=False) - - @jtu.sample_product( - input=[ - 3, - [3], - [np.array(3)], - [np.array([3])], - [[np.array(3)]], - [[np.array([3])]], - [3, 4, 5], - [ - [np.eye(2, dtype=np.int32) * 2, np.zeros((2, 3), dtype=np.int32)], - [np.ones((3, 2), dtype=np.int32), np.eye(3, dtype=np.int32) * 3], - ], - [np.array([1, 2, 3]), np.array([2, 3, 4]), 10], - [np.ones((2, 2), dtype=np.int32), np.zeros((2, 2), dtype=np.int32)], - [[np.array([1, 2, 3])], [np.array([2, 3, 4])]], - ], - ) - def testBlock(self, input): - args_maker = lambda: [input] - self._CheckAgainstNumpy(np.block, jnp.block, args_maker) - self._CompileAndCheck(jnp.block, args_maker) - - def testLongLong(self): - self.assertAllClose(np.int64(7), jax.jit(lambda x: x)(np.longlong(7))) - - @jtu.ignore_warning(category=UserWarning, - message="Explicitly requested dtype.*") - @jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion - def testArange(self): - # test cases inspired by dask tests at - # https://github.com/dask/dask/blob/main/dask/array/tests/test_creation.py#L92 - np_arange = jtu.with_jax_dtype_defaults(np.arange) - self.assertAllClose(jnp.arange(77), - np_arange(77)) - self.assertAllClose(jnp.arange(2, 13), - np_arange(2, 13)) - self.assertAllClose(jnp.arange(4, 21, 9), - np_arange(4, 21, 9)) - self.assertAllClose(jnp.arange(53, 5, -3), - np_arange(53, 5, -3)) - self.assertAllClose(jnp.arange(77, dtype=float), - np_arange(77, dtype=float)) - self.assertAllClose(jnp.arange(2, 13, dtype=int), - np_arange(2, 13, dtype=int)) - self.assertAllClose(jnp.arange(0, 1, -0.5), - np_arange(0, 1, -0.5)) - - self.assertRaises(TypeError, lambda: jnp.arange()) - - # test that jnp.arange(N) doesn't instantiate an ndarray - self.assertNotEqual(type(jnp.arange(77)), type(np.arange(77))) - self.assertEqual(type(jnp.arange(77)), type(lax.iota(np.int32, 77))) - - # test that jnp.arange(N, dtype=int32) doesn't instantiate an ndarray - self.assertNotEqual(type(jnp.arange(77, dtype=jnp.int32)), - type(np.arange(77, dtype=np.int32))) - self.assertEqual(type(jnp.arange(77, dtype=jnp.int32)), - type(lax.iota(np.int32, 77))) - - def testArangeJit(self): - ans = jax.jit(lambda: jnp.arange(5))() - expected = jtu.with_jax_dtype_defaults(np.arange)(5) - self.assertAllClose(ans, expected) - - @jtu.sample_product(args=[(5,), (0, 5)]) - def testArangeJaxpr(self, args): - jaxpr = jax.make_jaxpr(lambda: jnp.arange(*args))() - self.assertEqual(len(jaxpr.jaxpr.eqns), 1) - self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.iota_p) - - @unittest.skip("Jax-metal don't support complex.") - def testIssue830(self): - a = jnp.arange(4, dtype=jnp.complex64) - self.assertEqual(a.dtype, jnp.complex64) - - def testIssue728(self): - np_eye = jtu.with_jax_dtype_defaults(np.eye) - self.assertAllClose(jnp.eye(5000), np_eye(5000)) - self.assertEqual(0, np.sum(jnp.eye(1050) - np_eye(1050))) - - def testIssue746(self): - jnp.arange(12).reshape(3, 4) # doesn't crash - - def testIssue764(self): - x = jnp.linspace(190, 200, 4) - f = jax.grad(lambda x: jnp.sum(jnp.tanh(x))) - # Expected values computed with autograd in float64 precision. - expected = np.array([3.71669453e-165, 4.72999108e-168, 6.01954653e-171, - 7.66067839e-174], np.float64) - self.assertAllClose(f(x), expected, check_dtypes=False) - - # Test removed because tie_in is deprecated. - # def testIssue776(self): - # """Tests that the scatter-add transpose rule instantiates symbolic zeros.""" - # def f(u): - # y = jnp.ones_like(u, shape=10).at[np.array([2, 4, 5])].add(u) - # # The transpose rule for lax.tie_in returns a symbolic zero for its first - # # argument. - # return lax.tie_in(y, 7.) - - # self.assertAllClose(np.zeros(3,), jax.grad(f)(np.ones(3,))) - - # NOTE(mattjj): I disabled this test when removing lax._safe_mul because this - # is a numerical stability issue that should be solved with a custom jvp rule - # of the sigmoid function being differentiated here, not by safe_mul. - # def testIssue777(self): - # x = jnp.linspace(-200, 0, 4, dtype=np.float32) - # f = jax.grad(lambda x: jnp.sum(1 / (1 + jnp.exp(-x)))) - # self.assertAllClose(f(x), np.array([0., 0., 0., 0.25], dtype=np.float32)) - - #unittest.skip("Jax-metal fail on tanh with np.nan") - @jtu.sample_product( - dtype=float_dtypes, - op=("sqrt", "arccos", "arcsin", "arctan", "sin", "cos", "tan", - "sinh", "cosh", "tanh", "arccosh", "arcsinh", "arctanh", "exp", - "log", "expm1", "log1p"), - ) - def testMathSpecialFloatValues(self, op, dtype): - np_op = getattr(np, op) - np_op = jtu.ignore_warning(category=RuntimeWarning, - message="invalid value.*")(np_op) - np_op = jtu.ignore_warning(category=RuntimeWarning, - message="divide by zero.*")(np_op) - np_op = jtu.ignore_warning(category=RuntimeWarning, - message="overflow.*")(np_op) - - jnp_op = getattr(jnp, op) - dtype = np.dtype(dtypes.canonicalize_dtype(dtype)).type - for x in (-np.inf, -100., -2., -1., 0., 1., 2., 100., np.inf, - jnp.finfo(dtype).max, np.sqrt(jnp.finfo(dtype).max), - np.sqrt(jnp.finfo(dtype).max) * 2.): #np.nan - x = dtype(x) - expected = np_op(x) - actual = jnp_op(x) - tol = jtu.tolerance(dtype, {np.float32: 1e-3, np.float64: 1e-7}) - self.assertAllClose(expected, actual, atol=tol, - rtol=tol) - - def testIssue956(self): - self.assertRaises(TypeError, lambda: jnp.ndarray((1, 1))) - - def testIssue967(self): - self.assertRaises(TypeError, lambda: jnp.zeros(1.5)) - - @jtu.sample_product( - shape=[(5,), (10, 5), (4, 10)], - dtype=number_dtypes, - rowvar=[True, False], - ) - @jax.default_matmul_precision("float32") - def testCorrCoef(self, shape, dtype, rowvar): - rng = jtu.rand_default(self.rng()) - def args_maker(): - ok = False - while not ok: - x = rng(shape, dtype) - ok = not np.any(np.isclose(np.std(x), 0.0)) - return (x,) - np_fun = partial(np.corrcoef, rowvar=rowvar) - np_fun = jtu.ignore_warning( - category=RuntimeWarning, message="invalid value encountered.*")(np_fun) - jnp_fun = partial(jnp.corrcoef, rowvar=rowvar) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(dtype=dtype, end_dtype=end_dtype, begin_dtype=begin_dtype, - shape=shape, begin_shape=begin_shape, end_shape=end_shape) - for dtype in number_dtypes - for end_dtype in [None] + [dtype] - for begin_dtype in [None] + [dtype] - for shape in [s for s in all_shapes if s != jtu.PYTHON_SCALAR_SHAPE] - for begin_shape in ( - [None] if begin_dtype is None - else [s for s in all_shapes if s != jtu.PYTHON_SCALAR_SHAPE]) - for end_shape in ( - [None] if end_dtype is None - else [s for s in all_shapes if s != jtu.PYTHON_SCALAR_SHAPE]) - ], - ) - def testEDiff1d(self, shape, dtype, end_shape, end_dtype, begin_shape, - begin_dtype): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype), - (None if end_dtype is None else rng(end_shape, end_dtype)), - (None if begin_dtype is None else rng(begin_shape, begin_dtype))] - np_fun = lambda x, to_end, to_begin: np.ediff1d(x, to_end, to_begin) - jnp_fun = lambda x, to_end, to_begin: jnp.ediff1d(x, to_end, to_begin) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - def testEDiff1dWithDtypeCast(self): - rng = jtu.rand_default(self.rng()) - shape = jtu.NUMPY_SCALAR_SHAPE - dtype = jnp.float32 - end_dtype = jnp.int32 - args_maker = lambda: [rng(shape, dtype), rng(shape, end_dtype), rng(shape, dtype)] - np_fun = lambda x, to_end, to_begin: np.ediff1d(x, to_end, to_begin) - jnp_fun = lambda x, to_end, to_begin: jnp.ediff1d(x, to_end, to_begin) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - shapes=[(), (5,), (5, 3)], - dtype=number_dtypes, - indexing=['xy', 'ij'], - sparse=[True, False], - ) - def testMeshGrid(self, shapes, dtype, indexing, sparse): - rng = jtu.rand_default(self.rng()) - args_maker = self._GetArgsMaker(rng, [(x,) for x in shapes], - [dtype] * len(shapes)) - np_fun = partial(np.meshgrid, indexing=indexing, sparse=sparse) - jnp_fun = partial(jnp.meshgrid, indexing=indexing, sparse=sparse) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - def testMgrid(self): - # wrap indexer for appropriate dtype defaults. - np_mgrid = _indexer_with_default_outputs(np.mgrid) - assertAllEqual = partial(self.assertAllClose, atol=0, rtol=0) - assertAllEqual(np_mgrid[()], jnp.mgrid[()]) - assertAllEqual(np_mgrid[:4], jnp.mgrid[:4]) - assertAllEqual(np_mgrid[:4,], jnp.mgrid[:4,]) - assertAllEqual(np_mgrid[:4], jax.jit(lambda: jnp.mgrid[:4])()) - assertAllEqual(np_mgrid[:5, :5], jnp.mgrid[:5, :5]) - assertAllEqual(np_mgrid[:3, :2], jnp.mgrid[:3, :2]) - assertAllEqual(np_mgrid[1:4:2], jnp.mgrid[1:4:2]) - assertAllEqual(np_mgrid[1:5:3, :5], jnp.mgrid[1:5:3, :5]) - assertAllEqual(np_mgrid[:3, :2, :5], jnp.mgrid[:3, :2, :5]) - assertAllEqual(np_mgrid[:3:2, :2, :5], jnp.mgrid[:3:2, :2, :5]) - # Corner cases - assertAllEqual(np_mgrid[:], jnp.mgrid[:]) - # When the step length is a complex number, because of float calculation, - # the values between jnp and np might slightly different. - atol = 1e-6 - rtol = 1e-6 - self.assertAllClose(np_mgrid[-1:1:5j], - jnp.mgrid[-1:1:5j], - atol=atol, - rtol=rtol) - self.assertAllClose(np_mgrid[3:4:7j], - jnp.mgrid[3:4:7j], - atol=atol, - rtol=rtol) - self.assertAllClose(np_mgrid[1:6:8j, 2:4], - jnp.mgrid[1:6:8j, 2:4], - atol=atol, - rtol=rtol) - # Non-integer steps - self.assertAllClose(np_mgrid[0:3.5:0.5], - jnp.mgrid[0:3.5:0.5], - atol=atol, - rtol=rtol) - self.assertAllClose(np_mgrid[1.3:4.2:0.3], - jnp.mgrid[1.3:4.2:0.3], - atol=atol, - rtol=rtol) - # abstract tracer value for jnp.mgrid slice - with self.assertRaisesRegex(core.ConcretizationTypeError, - "slice start of jnp.mgrid"): - jax.jit(lambda a, b: jnp.mgrid[a:b])(0, 2) - - def testOgrid(self): - # wrap indexer for appropriate dtype defaults. - np_ogrid = _indexer_with_default_outputs(np.ogrid) - def assertSequenceOfArraysEqual(xs, ys): - self.assertIsInstance(xs, (list, tuple)) - self.assertIsInstance(ys, (list, tuple)) - self.assertEqual(len(xs), len(ys)) - for x, y in zip(xs, ys): - self.assertArraysEqual(x, y) - - self.assertArraysEqual(np_ogrid[:5], jnp.ogrid[:5]) - self.assertArraysEqual(np_ogrid[:5], jax.jit(lambda: jnp.ogrid[:5])()) - self.assertArraysEqual(np_ogrid[1:7:2], jnp.ogrid[1:7:2]) - # List of arrays - assertSequenceOfArraysEqual(np_ogrid[:5,], jnp.ogrid[:5,]) - assertSequenceOfArraysEqual(np_ogrid[0:5, 1:3], jnp.ogrid[0:5, 1:3]) - assertSequenceOfArraysEqual(np_ogrid[1:3:2, 2:9:3], jnp.ogrid[1:3:2, 2:9:3]) - assertSequenceOfArraysEqual(np_ogrid[:5, :9, :11], jnp.ogrid[:5, :9, :11]) - # Corner cases - self.assertArraysEqual(np_ogrid[:], jnp.ogrid[:]) - # Complex number steps - atol = 1e-6 - rtol = 1e-6 - self.assertAllClose(np_ogrid[-1:1:5j], - jnp.ogrid[-1:1:5j], - atol=atol, - rtol=rtol) - # Non-integer steps - self.assertAllClose(np_ogrid[0:3.5:0.3], - jnp.ogrid[0:3.5:0.3], - atol=atol, - rtol=rtol) - self.assertAllClose(np_ogrid[1.2:4.8:0.24], - jnp.ogrid[1.2:4.8:0.24], - atol=atol, - rtol=rtol) - # abstract tracer value for ogrid slice - with self.assertRaisesRegex(core.ConcretizationTypeError, - "slice start of jnp.ogrid"): - jax.jit(lambda a, b: jnp.ogrid[a:b])(0, 2) - - def testR_(self): - a = np.arange(6).reshape((2,3)) - self.assertArraysEqual(np.r_[np.array([1,2,3]), 0, 0, np.array([4,5,6])], - jnp.r_[np.array([1,2,3]), 0, 0, np.array([4,5,6])]) - self.assertArraysEqual(np.r_['-1', a, a], jnp.r_['-1', a, a]) - - self.assertArraysEqual(np.r_['0,2', [1,2,3], [4,5,6]], jnp.r_['0,2', [1,2,3], [4,5,6]]) - self.assertArraysEqual(np.r_['0,2,0', [1,2,3], [4,5,6]], jnp.r_['0,2,0', [1,2,3], [4,5,6]]) - self.assertArraysEqual(np.r_['1,2,0', [1,2,3], [4,5,6]], jnp.r_['1,2,0', [1,2,3], [4,5,6]]) - # negative 1d axis start - self.assertArraysEqual(np.r_['0,4,-1', [1,2,3], [4,5,6]], jnp.r_['0,4,-1', [1,2,3], [4,5,6]]) - self.assertArraysEqual(np.r_['0,4,-2', [1,2,3], [4,5,6]], jnp.r_['0,4,-2', [1,2,3], [4,5,6]]) - - # matrix directives - with jtu.ignore_warning(category=PendingDeprecationWarning): - self.assertArraysEqual(np.r_['r',[1,2,3], [4,5,6]], jnp.r_['r',[1,2,3], [4,5,6]]) - self.assertArraysEqual(np.r_['c', [1, 2, 3], [4, 5, 6]], jnp.r_['c', [1, 2, 3], [4, 5, 6]]) - - # bad directive - with self.assertRaisesRegex(ValueError, "could not understand directive.*"): - jnp.r_["asdfgh",[1,2,3]] - # abstract tracer value for r_ slice - with self.assertRaisesRegex(core.ConcretizationTypeError, - "slice start of jnp.r_"): - jax.jit(lambda a, b: jnp.r_[a:b])(0, 2) - - # wrap indexer for appropriate dtype defaults. - np_r_ = _indexer_with_default_outputs(np.r_) - - # Complex number steps - atol = 1e-6 - rtol = 1e-6 - self.assertAllClose(np_r_[-1:1:6j], - jnp.r_[-1:1:6j], - atol=atol, - rtol=rtol) - with jax.numpy_dtype_promotion('standard'): # Requires dtype promotion. - self.assertAllClose(np_r_[-1:1:6j, [0]*3, 5, 6], - jnp.r_[-1:1:6j, [0]*3, 5, 6], - atol=atol, - rtol=rtol) - # Non-integer steps - self.assertAllClose(np_r_[1.2:4.8:0.24], - jnp.r_[1.2:4.8:0.24], - atol=atol, - rtol=rtol) - - def testC_(self): - a = np.arange(6).reshape((2, 3)) - self.assertArraysEqual(np.c_[np.array([1,2,3]), np.array([4,5,6])], - jnp.c_[np.array([1,2,3]), np.array([4,5,6])]) - self.assertArraysEqual(np.c_[np.array([[1,2,3]]), 0, 0, np.array([[4,5,6]])], - jnp.c_[np.array([[1,2,3]]), 0, 0, np.array([[4,5,6]])]) - self.assertArraysEqual(np.c_['-1', a, a], jnp.c_['-1', a, a]) - - self.assertArraysEqual(np.c_['0,2', [1,2,3], [4,5,6]], jnp.c_['0,2', [1,2,3], [4,5,6]]) - self.assertArraysEqual(np.c_['0,2,0', [1,2,3], [4,5,6]], jnp.c_['0,2,0', [1,2,3], [4,5,6]]) - self.assertArraysEqual(np.c_['1,2,0', [1,2,3], [4,5,6]], jnp.c_['1,2,0', [1,2,3], [4,5,6]]) - # negative 1d axis start - self.assertArraysEqual(np.c_['0,4,-1', [1,2,3], [4,5,6]], jnp.c_['0,4,-1', [1,2,3], [4,5,6]]) - self.assertArraysEqual(np.c_['0,4,-2', [1,2,3], [4,5,6]], jnp.c_['0,4,-2', [1,2,3], [4,5,6]]) - # matrix directives, avoid numpy deprecation warning - with jtu.ignore_warning(category=PendingDeprecationWarning): - self.assertArraysEqual(np.c_['r',[1,2,3], [4,5,6]], jnp.c_['r',[1,2,3], [4,5,6]]) - self.assertArraysEqual(np.c_['c', [1, 2, 3], [4, 5, 6]], jnp.c_['c', [1, 2, 3], [4, 5, 6]]) - - # bad directive - with self.assertRaisesRegex(ValueError, "could not understand directive.*"): - jnp.c_["asdfgh",[1,2,3]] - # abstract tracer value for c_ slice - with self.assertRaisesRegex(core.ConcretizationTypeError, - "slice start of jnp.c_"): - jax.jit(lambda a, b: jnp.c_[a:b])(0, 2) - - # wrap indexer for appropriate dtype defaults. - np_c_ = _indexer_with_default_outputs(np.c_) - - # Complex number steps - atol = 1e-6 - rtol = 1e-6 - self.assertAllClose(np_c_[-1:1:6j], - jnp.c_[-1:1:6j], - atol=atol, - rtol=rtol) - - # Non-integer steps - self.assertAllClose(np_c_[1.2:4.8:0.24], - jnp.c_[1.2:4.8:0.24], - atol=atol, - rtol=rtol) - - def testS_(self): - self.assertEqual(np.s_[1:2:20],jnp.s_[1:2:20]) - - def testIndex_exp(self): - self.assertEqual(np.index_exp[5:3:2j],jnp.index_exp[5:3:2j]) - - @jtu.sample_product( - start_shape=[(), (2,), (2, 2)], - stop_shape=[(), (2,), (2, 2)], - num=[0, 1, 2, 5, 20], - endpoint=[True, False], - retstep=[True, False], - # floating-point compute between jitted platforms and non-jit + rounding - # cause unavoidable variation in integer truncation for some inputs, so - # we currently only test inexact 'dtype' arguments. - dtype=inexact_dtypes + [None,], - ) - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - def testLinspace(self, start_shape, stop_shape, num, endpoint, retstep, dtype): - rng = jtu.rand_default(self.rng()) - # relax default tolerances slightly - tol = jtu.tolerance(dtype if dtype else np.float32) * 10 - args_maker = self._GetArgsMaker(rng, - [start_shape, stop_shape], - [dtype, dtype]) - start, stop = args_maker() - ndim = len(np.shape(start + stop)) - for axis in range(-ndim, ndim): - jnp_op = lambda start, stop: jnp.linspace( - start, stop, num, - endpoint=endpoint, retstep=retstep, dtype=dtype, axis=axis) - np_op = lambda start, stop: np.linspace( - start, stop, num, - endpoint=endpoint, retstep=retstep, dtype=dtype, axis=axis) - - self._CheckAgainstNumpy(np_op, jnp_op, args_maker, - check_dtypes=False, tol=tol) - self._CompileAndCheck(jnp_op, args_maker, - check_dtypes=False, atol=tol, rtol=tol) - - @jtu.sample_product(dtype=number_dtypes) - def testLinspaceEndpoints(self, dtype): - """Regression test for Issue #3014.""" - rng = jtu.rand_default(self.rng()) - endpoints = rng((2,), dtype) - out = jnp.linspace(*endpoints, 10, dtype=dtype) - self.assertAllClose(out[np.array([0, -1])], endpoints, rtol=0, atol=0) - - @jtu.sample_product( - start_shape=[(), (2,), (2, 2)], - stop_shape=[(), (2,), (2, 2)], - num=[0, 1, 2, 5, 20], - endpoint=[True, False], - base=[10.0, 2, np.e], - # skip 16-bit floats due to insufficient precision for the test. - dtype=jtu.dtypes.inexact + [None,], - ) - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - def testLogspace(self, start_shape, stop_shape, num, - endpoint, base, dtype): - if (dtype in int_dtypes and - jtu.test_device_matches(["gpu", "tpu"]) and - not config.enable_x64.value): - raise unittest.SkipTest("GPUx32 truncated exponentiation" - " doesn't exactly match other platforms.") - rng = jtu.rand_default(self.rng()) - # relax default tolerances slightly - tol = {np.float32: 1e-2, np.float64: 1e-6, np.complex64: 1e-3, np.complex128: 1e-6} - args_maker = self._GetArgsMaker(rng, - [start_shape, stop_shape], - [dtype, dtype]) - start, stop = args_maker() - ndim = len(np.shape(start + stop)) - for axis in range(-ndim, ndim): - jnp_op = lambda start, stop: jnp.logspace( - start, stop, num, endpoint=endpoint, base=base, dtype=dtype, axis=axis) - @jtu.ignore_warning(category=RuntimeWarning, - message="overflow encountered in power") - def np_op(start, stop): - return np.logspace(start, stop, num, endpoint=endpoint, - base=base, dtype=dtype, axis=axis) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker, - check_dtypes=False, tol=tol) - if dtype in (inexact_dtypes + [None,]): - # Why do compiled and op-by-op float16 np.power numbers differ - # slightly more than expected? - atol = {np.float16: 1e-2} - self._CompileAndCheck(jnp_op, args_maker, - check_dtypes=False, atol=atol, rtol=tol) - - @jtu.sample_product( - [dict(start_shape=start_shape, stop_shape=stop_shape, axis=axis) - for start_shape in [(), (2,), (2, 2)] - for stop_shape in [(), (2,), (2, 2)] - for axis in range(-max(len(start_shape), len(stop_shape)), - max(len(start_shape), len(stop_shape))) - ], - num=[0, 1, 2, 5, 20], - endpoint=[True, False], - # NB: numpy's geomspace gives nonsense results on integer types - dtype=inexact_dtypes + [None,], - ) - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - def testGeomspace(self, start_shape, stop_shape, num, - endpoint, dtype, axis): - rng = jtu.rand_default(self.rng()) - # relax default tolerances slightly - tol = {dtypes.bfloat16: 2e-2, np.float16: 4e-3, np.float32: 2e-3, - np.float64: 1e-14, np.complex64: 2e-3, np.complex128: 1e-14} - def args_maker(): - """Test the set of inputs np.geomspace is well-defined on.""" - start, stop = self._GetArgsMaker(rng, - [start_shape, stop_shape], - [dtype, dtype])() - # np.geomspace can't handle differently ranked tensors - # w. negative numbers! - start, stop = jnp.broadcast_arrays(start, stop) - if dtype in complex_dtypes: - return start, stop - # to avoid NaNs, non-complex start and stop cannot - # differ in sign, elementwise - start = start * jnp.sign(start) * jnp.sign(stop) - return start, stop - start, stop = args_maker() - def jnp_op(start, stop): - return jnp.geomspace(start, stop, num, endpoint=endpoint, dtype=dtype, - axis=axis) - def np_op(start, stop): - start = start.astype(np.float32) if dtype == jnp.bfloat16 else start - stop = stop.astype(np.float32) if dtype == jnp.bfloat16 else stop - return np.geomspace( - start, stop, num, endpoint=endpoint, - dtype=dtype if dtype != jnp.bfloat16 else np.float32, - axis=axis).astype(dtype) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker, - check_dtypes=False, tol=tol) - if dtype in (inexact_dtypes + [None,]): - self._CompileAndCheck(jnp_op, args_maker, - check_dtypes=False, atol=tol, rtol=tol) - - def testDisableNumpyRankPromotionBroadcasting(self): - with jax.numpy_rank_promotion('allow'): - jnp.ones(2) + jnp.ones((1, 2)) # works just fine - - with jax.numpy_rank_promotion('raise'): - self.assertRaises(ValueError, lambda: jnp.ones(2) + jnp.ones((1, 2))) - jnp.ones(2) + 3 # don't want to raise for scalars - - with jax.numpy_rank_promotion('warn'): - self.assertWarnsRegex(UserWarning, "Following NumPy automatic rank promotion for add on " - r"shapes \(2,\) \(1, 2\).*", lambda: jnp.ones(2) + jnp.ones((1, 2))) - jnp.ones(2) + 3 # don't want to warn for scalars - - @unittest.skip("Test fails on CI, perhaps due to JIT caching") - def testDisableNumpyRankPromotionBroadcastingDecorator(self): - with jax.numpy_rank_promotion("allow"): - jnp.ones(2) + jnp.ones((1, 2)) # works just fine - - with jax.numpy_rank_promotion("raise"): - self.assertRaises(ValueError, lambda: jnp.ones(2) + jnp.ones((1, 2))) - jnp.ones(2) + 3 # don't want to raise for scalars - - with jax.numpy_rank_promotion("warn"): - self.assertWarnsRegex(UserWarning, "Following NumPy automatic rank promotion for add on " - r"shapes \(2,\) \(1, 2\).*", lambda: jnp.ones(2) + jnp.ones((1, 2))) - jnp.ones(2) + 3 # don't want to warn for scalars - - def testStackArrayArgument(self): - # tests https://github.com/jax-ml/jax/issues/1271 - @jax.jit - def foo(x): - return jnp.stack(x) - foo(np.zeros(2)) # doesn't crash - - @jax.jit - def foo(x): - return jnp.concatenate(x) - foo(np.zeros((2, 2))) # doesn't crash - - def testReluGradientConstants(self): - # This is a regression test that verifies that constants associated with the - # gradient of np.maximum (from lax._balanced_eq) aren't hoisted into the - # outermost jaxpr. This was producing some large materialized constants for - # every relu activation in a model. - def body(i, xy): - x, y = xy - y = y + jax.grad(lambda z: jnp.sum(jnp.maximum(z, 0.)))(x) - return x, y - - f = lambda y: lax.fori_loop(0, 5, body, (y, y)) - jaxpr = jax.make_jaxpr(f)(np.zeros((3, 4), np.float32)) - self.assertFalse( - any(np.array_equal(x, np.full((3, 4), 2., dtype=np.float32)) - for x in jaxpr.consts)) - - @jtu.sample_product( - [dict(from_shape=from_shape, to_shape=to_shape) - for from_shape, to_shape in [ - [(1, 3), (4, 3)], - [(3,), (2, 1, 3)], - [(3,), (3, 3)], - [(1,), (3,)], - [(1,), 3], - ] - ], - ) - def testBroadcastTo(self, from_shape, to_shape): - rng = jtu.rand_default(self.rng()) - args_maker = self._GetArgsMaker(rng, [from_shape], [np.float32]) - np_op = lambda x: np.broadcast_to(x, to_shape) - jnp_op = lambda x: jnp.broadcast_to(x, to_shape) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - @jtu.sample_product( - [dict(shapes=shapes, broadcasted_shape=broadcasted_shape) - for shapes, broadcasted_shape in [ - [[], ()], - [[()], ()], - [[(1, 3), (4, 3)], (4, 3)], - [[(3,), (2, 1, 3)], (2, 1, 3)], - [[(3,), (3, 3)], (3, 3)], - [[(1,), (3,)], (3,)], - [[(1,), 3], (3,)], - [[(6, 7), (5, 6, 1), (7,), (5, 1, 7)], (5, 6, 7)], - [[[1], [0, 1]], (0, 1)], - [[(1,), np.array([0, 1])], (0, 1)], - ] - ], - ) - def testBroadcastShapes(self, shapes, broadcasted_shape): - # Test against np.broadcast_shapes once numpy 1.20 is minimum required version - np.testing.assert_equal(jnp.broadcast_shapes(*shapes), broadcasted_shape) - - def testBroadcastToIssue1522(self): - self.assertRaisesRegex( - ValueError, "Incompatible shapes for broadcasting: .*", - lambda: jnp.broadcast_to(np.ones((2, 3)), (1, 3))) - - def testBroadcastToIntIssue1548(self): - self.assertAllClose(jnp.broadcast_to(1, (3, 2)), np.ones((3, 2)), - check_dtypes=False) - - def testBroadcastToOnScalar(self): - self.assertIsInstance(jnp.broadcast_to(10.0, ()), jax.Array) - self.assertIsInstance(np.broadcast_to(10.0, ()), np.ndarray) - - def testPrecision(self): - - ones_1d = np.ones((2,)) - ones_2d = np.ones((2, 2)) - ones_3d = np.ones((2, 2, 2)) - HIGHEST = lax.Precision.HIGHEST - - jtu.assert_dot_precision(None, jnp.dot, ones_1d, ones_1d) - jtu.assert_dot_precision( - HIGHEST, - partial(jnp.dot, precision=HIGHEST), - ones_1d, ones_1d) - jtu.assert_dot_precision( - HIGHEST, - partial(jnp.dot, precision=HIGHEST), - ones_3d, ones_3d) - jtu.assert_dot_precision( - HIGHEST, - partial(jnp.matmul, precision=HIGHEST), - ones_2d, ones_2d) - jtu.assert_dot_precision( - HIGHEST, - partial(jnp.vdot, precision=HIGHEST), - ones_1d, ones_1d) - jtu.assert_dot_precision( - HIGHEST, - partial(jnp.vecdot, precision=HIGHEST), - ones_1d, ones_1d) - jtu.assert_dot_precision( - HIGHEST, - partial(jnp.tensordot, axes=2, precision=HIGHEST), - ones_2d, ones_2d) - jtu.assert_dot_precision( - HIGHEST, - partial(jnp.tensordot, axes=(0, 0), precision=HIGHEST), - ones_1d, ones_1d) - jtu.assert_dot_precision( - HIGHEST, - partial(jnp.tensordot, axes=((0,), (0,)), precision=HIGHEST), - ones_1d, ones_1d) - jtu.assert_dot_precision( - HIGHEST, - partial(jnp.einsum, 'i,i', precision=HIGHEST), - ones_1d, ones_1d) - jtu.assert_dot_precision( - HIGHEST, - partial(jnp.einsum, 'ij,ij', precision=HIGHEST), - ones_2d, ones_2d) - jtu.assert_dot_precision( - HIGHEST, - partial(jnp.inner, precision=HIGHEST), - ones_1d, ones_1d) - - @jtu.sample_product( - funcname=['inner', 'matmul', 'dot', 'vdot', 'tensordot', 'vecdot'] - ) - def testPreferredElementType(self, funcname): - func = getattr(jnp, funcname) - kwargs = dict(axes=0) if funcname == 'tensordot' else {} - - ones_i32 = np.ones(2, dtype='int32') - ones_f32 = np.ones(2, dtype='float32') - - with jax.numpy_dtype_promotion('strict'): - jtu.assert_dot_preferred_element_type('int32', func, ones_i32, ones_i32, **kwargs) - jtu.assert_dot_preferred_element_type('float32', func, ones_f32, ones_f32, **kwargs) - jtu.assert_dot_preferred_element_type('bfloat16', func, ones_f32, ones_f32, **kwargs, - preferred_element_type='bfloat16') - with jax.numpy_dtype_promotion('standard'): - jtu.assert_dot_preferred_element_type('float32', func, ones_i32, ones_f32, **kwargs) - - @jtu.sample_product( - [dict(shape=shape, varargs=varargs, axis=axis) - for shape in [(10,), (10, 15), (10, 15, 20)] - for _num_axes in range(len(shape)) - for varargs in itertools.combinations(range(1, len(shape) + 1), _num_axes) - for axis in itertools.combinations(range(len(shape)), _num_axes) - ], - dtype=inexact_dtypes, - ) - def testGradient(self, shape, varargs, axis, dtype): - rng = jtu.rand_default(self.rng()) - args_maker = self._GetArgsMaker(rng, [shape], [dtype]) - jnp_fun = lambda y: jnp.gradient(y, *varargs, axis=axis) - np_fun = lambda y: np.gradient(y, *varargs, axis=axis) - self._CheckAgainstNumpy( - np_fun, jnp_fun, args_maker, check_dtypes=False) - self._CompileAndCheck(jnp_fun, args_maker) - - def testZerosShapeErrors(self): - # see https://github.com/jax-ml/jax/issues/1822 - self.assertRaisesRegex( - TypeError, - "Shapes must be 1D sequences of concrete values of integer type.*", - lambda: jnp.zeros(1.)) - self.assertRaisesRegex( - TypeError, - r"Shapes must be 1D sequences of concrete values of integer type.*\n" - "If using `jit`, try using `static_argnums` or applying `jit` to " - "smaller subfunctions.", - lambda: jax.jit(jnp.zeros)(2)) - - def testTraceMethod(self): - x = self.rng().randn(3, 4).astype(jnp.float_) - self.assertAllClose(x.trace(), jnp.array(x).trace()) - self.assertAllClose(x.trace(), jax.jit(lambda y: y.trace())(x)) - - def testIntegerPowersArePrecise(self): - # See https://github.com/jax-ml/jax/pull/3036 - # Checks if the squares of float32 integers have no numerical errors. - # It should be satisfied with all integers less than sqrt(2**24). - x = jnp.arange(-2**12, 2**12, dtype=jnp.int32) - np.testing.assert_array_equal(jnp.square(x.astype(jnp.float32)), x * x) - np.testing.assert_array_equal(x.astype(jnp.float32) ** 2, x * x) - - # Similarly for cubes. - x = jnp.arange(-2**8, 2**8, dtype=jnp.int32) - np.testing.assert_array_equal(x.astype(jnp.float32) ** 3, x * x * x) - - x = np.arange(10, dtype=np.float32) - for i in range(10): - self.assertAllClose(x.astype(jnp.float32) ** i, x ** i, - check_dtypes=False) - - def testToBytes(self): - v = np.arange(12, dtype=np.int32).reshape(3, 4) - for order in ['C', 'F']: - self.assertEqual(jnp.asarray(v).tobytes(order), v.tobytes(order)) - - def testToBytesJitError(self): - v = np.arange(12, dtype=np.int32).reshape(3, 4) - f = jax.jit(lambda x: x.tobytes()) - msg = r".*The tobytes\(\) method was called on traced array" - with self.assertRaisesRegex(core.ConcretizationTypeError, msg): - f(v) - - def testToList(self): - v = np.arange(12, dtype=np.int32).reshape(3, 4) - self.assertEqual(jnp.asarray(v).tolist(), v.tolist()) - - def testToListJitError(self): - v = np.arange(12, dtype=np.int32).reshape(3, 4) - f = jax.jit(lambda x: x.tolist()) - msg = r".*The tolist\(\) method was called on traced array" - with self.assertRaisesRegex(core.ConcretizationTypeError, msg): - f(v) - - def testArangeConcretizationError(self): - msg = r"It arose in the jnp.arange argument '{}'".format - with self.assertRaisesRegex(core.ConcretizationTypeError, msg('stop')): - jax.jit(jnp.arange)(3) - - with self.assertRaisesRegex(core.ConcretizationTypeError, msg('start')): - jax.jit(lambda start: jnp.arange(start, 3))(0) - - with self.assertRaisesRegex(core.ConcretizationTypeError, msg('stop')): - jax.jit(lambda stop: jnp.arange(0, stop))(3) - - @jtu.sample_product(dtype=[None] + float_dtypes) - def testArange64Bit(self, dtype): - # Test that jnp.arange uses 64-bit arithmetic to define its range, even if the - # output has another dtype. The issue here is that if python scalar inputs to - # jnp.arange are cast to float32 before the range is computed, it changes the - # number of elements output by the range. It's unclear whether this was deliberate - # behavior in the initial implementation, but it's behavior that downstream users - # have come to rely on. - args = (1.2, 4.8, 0.24) - - # Ensure that this test case leads to differing lengths if cast to float32. - self.assertLen(np.arange(*args), 15) - self.assertLen(np.arange(*map(np.float32, args)), 16) - - jnp_fun = lambda: jnp.arange(*args, dtype=dtype) - np_fun = jtu.with_jax_dtype_defaults(lambda: np.arange(*args, dtype=dtype), dtype is None) - args_maker = lambda: [] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - def testIssue2347(self): - # https://github.com/jax-ml/jax/issues/2347 - object_list = list[tuple[jnp.array, float, float, jnp.array, bool]] - self.assertRaises(TypeError, jnp.array, object_list) - - np_object_list = np.array(object_list) - self.assertRaises(TypeError, jnp.array, np_object_list) - - @unittest.skip("JAX-metal don't support complex type yet.") - @jtu.sample_product( - [dict(shapes=shapes, dtypes=dtypes) - for shapes in filter( - _shapes_are_broadcast_compatible, - itertools.combinations_with_replacement(all_shapes, 2)) - for dtypes in itertools.product( - *(_valid_dtypes_for_shape(s, complex_dtypes + [None]) for s in shapes)) - ], - ) - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - def testLogaddexpComplex(self, shapes, dtypes): - @jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*") - def np_op(x1, x2): - return np.log(np.exp(x1) + np.exp(x2)) - - rng = jtu.rand_some_nan(self.rng()) - args_maker = lambda: tuple(rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)) - if jtu.test_device_matches(["tpu"]): - tol = {np.complex64: 1e-3, np.complex128: 1e-10} - else: - tol = {np.complex64: 1e-5, np.complex128: 1e-14} - - with jtu.strict_promotion_if_dtypes_match(dtypes): - self._CheckAgainstNumpy(jtu.promote_like_jnp(np_op), jnp.logaddexp, args_maker, tol=tol) - self._CompileAndCheck(jnp.logaddexp, args_maker, rtol=tol, atol=tol) - - @unittest.skip("JAX-metal don't support complex type yet.") - @jtu.sample_product( - [dict(shapes=shapes, dtypes=dtypes) - for shapes in filter( - _shapes_are_broadcast_compatible, - itertools.combinations_with_replacement(all_shapes, 2)) - for dtypes in itertools.product( - *(_valid_dtypes_for_shape(s, complex_dtypes + [None]) for s in shapes)) - ], - ) - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - def testLogaddexp2Complex(self, shapes, dtypes): - @jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*") - def np_op(x1, x2): - return np.log2(np.exp2(x1) + np.exp2(x2)) - - rng = jtu.rand_some_nan(self.rng()) - args_maker = lambda: tuple(rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)) - if jtu.test_device_matches(["tpu"]): - tol = {np.complex64: 1e-3, np.complex128: 1e-10} - else: - tol = {np.complex64: 1e-5, np.complex128: 1e-14} - - with jtu.strict_promotion_if_dtypes_match(dtypes): - self._CheckAgainstNumpy(jtu.promote_like_jnp(np_op), jnp.logaddexp2, args_maker, tol=tol) - self._CompileAndCheck(jnp.logaddexp2, args_maker, rtol=tol, atol=tol) - - def testDefaultDtypes(self): - precision = config.default_dtype_bits.value - assert precision in ['32', '64'] - self.assertEqual(jnp.bool_, np.bool_) - self.assertEqual(jnp.int_, np.int32 if precision == '32' else np.int64) - self.assertEqual(jnp.uint, np.uint32 if precision == '32' else np.uint64) - self.assertEqual(jnp.float_, np.float32 if precision == '32' else np.float64) - self.assertEqual(jnp.complex_, np.complex64 if precision == '32' else np.complex128) - - def testFromBuffer(self): - buf = b'\x01\x02\x03' - expected = np.frombuffer(buf, dtype='uint8') - actual = jnp.frombuffer(buf, dtype='uint8') - self.assertArraysEqual(expected, actual) - - def testFromFunction(self): - def f(x, y, z): - return x + 2 * y + 3 * z - shape = (3, 4, 5) - expected = np.fromfunction(f, shape=shape) - actual = jnp.fromfunction(f, shape=shape) - self.assertArraysEqual(expected, actual, check_dtypes=False) - - def testFromString(self): - s = "1,2,3" - expected = np.fromstring(s, sep=',', dtype=int) - actual = jnp.fromstring(s, sep=',', dtype=int) - self.assertArraysEqual(expected, actual) - - @jtu.sample_product( - a_shape=nonempty_nonscalar_array_shapes, - v_shape=nonempty_shapes, - dtype=jtu.dtypes.all, - ) - def testPlace(self, a_shape, v_shape, dtype): - rng = jtu.rand_default(self.rng()) - mask_rng = jtu.rand_bool(self.rng()) - - def args_maker(): - a = rng(a_shape, dtype) - m = mask_rng(a_shape, bool) - v = rng(v_shape, dtype) - return a, m, v - - def np_fun(a, m, v): - a_copy = a.copy() - np.place(a_copy, m, v) - return a_copy - - jnp_fun = partial(jnp.place, inplace=False) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - a_shape=nonempty_nonscalar_array_shapes, - i_shape=all_shapes, - v_shape=all_shapes, - dtype=jtu.dtypes.all, - mode=[None, 'wrap', 'clip'], - ) - def testPut(self, mode, a_shape, i_shape, v_shape, dtype): - size = math.prod(a_shape) - if math.prod(i_shape) > size: - self.skipTest("too many indices") - rng = jtu.rand_default(self.rng()) - # Must test unique integers, because overlapping updates in - # JAX have implementation-defined order - idx_rng = jtu.rand_unique_int(self.rng(), size) - - def args_maker(): - a = rng(a_shape, dtype) - i = idx_rng(i_shape, np.int32) - v = rng(v_shape, dtype) - # put some indices out of range without duplicating indices - if mode == "clip" and i.size: - np.put(i, np.argmax(i), size + 2) - np.put(i, np.argmin(i), -2) - if mode == "wrap" and i.size: - np.put(i, 0, np.take(i, 0) + size) - return a, i, v - - def np_fun(a, i, v): - a_copy = a.copy() - np.put(a_copy, i, v, mode=mode) - return a_copy - - jnp_fun = partial(jnp.put, mode=mode, inplace=False) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - def test_rot90_error(self): - with self.assertRaisesRegex( - ValueError, - "rot90 requires its first argument to have ndim at least two, " - "but got first argument of"): - jnp.rot90(jnp.ones(2)) - - @parameterized.named_parameters( - ('ones', jnp.ones), - ('zeros', jnp.zeros), - ('empty', jnp.empty)) - def test_error_hint(self, fn): - with self.assertRaisesRegex( - TypeError, - r"Did you accidentally write `jax\.numpy\..*?\(2, 3\)` " - r"when you meant `jax\.numpy\..*?\(\(2, 3\)\)`"): - fn(2, 3) - - @jtu.sample_product( - dtype=jtu.dtypes.all, - kind=['bool', 'signed integer', 'unsigned integer', 'integral', - 'real floating', 'complex floating', 'numeric'] - ) - def test_isdtype(self, dtype, kind): - # Full tests also in dtypes_test.py; here we just compare against numpy - jax_result = jnp.isdtype(dtype, kind) - if jtu.numpy_version() < (2, 0, 0) or dtype == dtypes.bfloat16: - # just a smoke test - self.assertIsInstance(jax_result, bool) - else: - numpy_result = np.isdtype(dtype, kind) - self.assertEqual(jax_result, numpy_result) - - -@unittest.skipIf(metal_plugin == None, "Tests require jax-metal plugin.") -class ReportedIssuesTests(jtu.JaxTestCase): - def dispatchOn(self, args, func, device=jax.devices('cpu')[0]): - deviceArgs = [] - for arg in args: - deviceArgs.append(jax.device_put(arg, device)) - return func(*deviceArgs) - - @staticmethod - def compile_and_exec(module, args, run_on_cpu=False): - from jax.extend.backend import get_backend - backend = get_backend('METAL') - if run_on_cpu: - backend = get_backend('cpu') - executable = backend.compile(module) - def put(arg): - return backend.buffer_from_pyval(arg, device=executable.local_devices()[0]) - arguments = [put(arg) for arg in args] - outputs = executable.execute(arguments) - return [np.asarray(x) for x in outputs] - - @staticmethod - def jax_metal_supported(target_ver): - if metal_plugin is None or not hasattr(metal_plugin, 'version'): - return False - curr_ver = metal_plugin.version() - if hasattr(jtu, 'parse_version'): - return jtu.parse_version(curr_ver) >= jtu.parse_version(target_ver) - return False - - - #https://github.com/jax-ml/jax/issues/16420 - def test_broadcast_dim(self): - x = jnp.arange(2) - f = lambda x : jax.lax.broadcast_in_dim(x, (2, 2), (0,)) - res = f(x) - print(res) - res_cpu = self.dispatchOn([x],f) - jtu.check_eq(res, res_cpu) - f = lambda x : jax.lax.broadcast_in_dim(x, (2, 2), (1,)) - res = f(x) - print(res) - res_cpu = self.dispatchOn([x],f) - jtu.check_eq(res, res_cpu) - - def test_identity(self): - x = jnp.identity(4) - jtu.check_eq(x, np.identity(4)) - - def test_triu(self): - x = np.ones((4,4)) - res = jnp.triu(x) - jtu.check_eq(res, np.triu(x)) - - #https://github.com/jax-ml/jax/issues/16471 - def test_matmul_1d(self): - x = np.array(np.random.rand(3, 3)) - y = np.array(np.random.rand(3)) - z = np.array(np.random.rand(3)) - res = jnp.dot(y, z) - self.assertArraysAllClose(res, np.dot(y,z)) - res = jnp.dot(x, y) - self.assertArraysAllClose(res, np.dot(x,y)) - - #https://github.com/jax-ml/jax/issues/17175 - def test_indexing(self): - x = jnp.array([[1,2,3],[4,5,6],[7,8,9],[10,11,12]], dtype=jnp.float32) - @jax.vmap - def f(i): - return x[i] - f = jax.jit(f) - idx = jnp.array([1,1,2,2,0]) - res = f(idx) - jtu.check_eq(res, np.array([[4., 5., 6.], [4., 5., 6.], [7., 8., 9.], [7., 8., 9.], [1., 2., 3.]])) - - #https://github.com/jax-ml/jax/issues/17344 - def test_take_along_axis(self): - @jax.jit - def f(): - idx = jnp.array([[0],[0],[0]]) - x = jnp.array([[0.3756883, 0.05820537, 0.7399422, 0.45242703], - [0.5848844, 0.18772626, 0.47942543, 0.20703673], - [0.1071583, 0.26139486, 0.25664794, 0.8109596]]) - return jnp.take_along_axis(x, idx, axis=1) - jtu.check_eq(f(), self.dispatchOn([], f)) - - #https://github.com/jax-ml/jax/issues/17590 - def test_in1d(self): - a = np.array([123,2,4]) - b = np.array([123,1]) - res = jnp.isin(a,b) - jtu.check_eq(res, np.isin(a, b)) - - def test_indexing_update(self): - x = jnp.array([[1,2,3],[4,5,6],[7,8,9],[10,11,12]], dtype=jnp.float32) - @jax.vmap - def f(x): - return x.at[0].set(1.0) - f = jax.jit(f) - res = f(x) - jtu.check_eq(res, np.array([[1., 2., 3.], [1., 5., 6.,], [1., 8., 9.], [1., 11., 12.]])) - - #https://github.com/jax-ml/jax/issues/16326 - def test_indexing_update2(self): - @jax.jit - def f(x, r): - x = x.at[:, 0].set(x[:, 0] / r) - return x - x = jnp.array([[1.0, 2.0], [3.0, 4.0]]) - fx = f(x, jnp.array([10.0])) - jtu.check_eq(fx, np.array([[0.1, 2.0], [0.3, 4.]])) - - def test_gather_ir(self): - ir = ''' -#loc = loc(unknown) -module @jit_gather attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<3x2x3xf32> {mhlo.sharding = "{replicated}"} loc(unknown), %arg1: tensor<3x2xi32> {mhlo.sharding = "{replicated}"} loc(unknown)) -> tensor<3x2xf32> { - %0 = "stablehlo.gather"(%arg0, %arg1) {dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array} : (tensor<3x2x3xf32>, tensor<3x2xi32>) -> tensor<3x2xf32> loc(#loc2) - return %0 : tensor<3x2xf32> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("/Users/shuhan/Code/jax-metal/tests/lax_numpy_indexing_test.py":1156:0) -#loc2 = loc("jit(gather)/jit(main)/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(1,), collapsed_slice_dims=(0, 2), start_index_map=(0, 2)) slice_sizes=(1, 2, 1) unique_indices=False indices_are_sorted=False mode=GatherScatterMode.CLIP fill_value=None]"(#loc1)) - ''' - data = np.array([[[0.6369617, 0.26978672, 0.04097353], - [0.01652764, 0.8132702, 0.91275555]], - [[0.60663575, 0.72949654, 0.543625 ], - [0.9350724, 0.81585354, 0.0027385 ]], - [[0.8574043, 0.03358557, 0.72965544], - [0.17565562, 0.8631789, 0.5414612 ]]], dtype=np.float32) - index = np.array([[1, 0],[2, 1],[0, 2]], dtype=np.int32) - res = ReportedIssuesTests.compile_and_exec(ir, [data, index]) - res_ref = ReportedIssuesTests.compile_and_exec(ir, [data, index], run_on_cpu = True) - print(res) - jtu.check_eq(res, res_ref) - - #https://github.com/jax-ml/jax/issues/16366 - def test_pad_interior_1(self): - if not ReportedIssuesTests.jax_metal_supported('0.0.6'): - raise unittest.SkipTest("jax-metal version doesn't support it.") - ir = ''' - module @jit_gather attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<128x7x7x64xf32> {mhlo.sharding = "{replicated}"} loc(unknown), %arg1: tensor {mhlo.sharding = "{replicated}"} loc(unknown)) -> tensor<128x15x15x64xf32> { - %206 = "mhlo.pad"(%arg0, %arg1) {edge_padding_high = dense<[0, 1, 1, 0]> : tensor<4xi64>, edge_padding_low = dense<[0, 1, 1, 0]> : tensor<4xi64>, interior_padding = dense<[0, 1, 1, 0]> : tensor<4xi64>} : (tensor<128x7x7x64xf32>, tensor) -> tensor<128x15x15x64xf32> - return %206 : tensor<128x15x15x64xf32> - } - } - ''' - data = np.random.rand(128,7,7,64).astype(np.float32) - padding = np.array(0.5, dtype=np.float32) - res = ReportedIssuesTests.compile_and_exec(ir, [data, padding]) - res_ref = ReportedIssuesTests.compile_and_exec(ir, [data, padding], run_on_cpu = True) - jtu.check_eq(res, res_ref) - - def test_pad_interior_2(self): - if not ReportedIssuesTests.jax_metal_supported('0.0.6'): - raise unittest.SkipTest("jax-metal version doesn't support it.") - batch = 2 - seq_len = 8 - num_decode = 32 - - seq = np.random.randint(size=(batch, seq_len, num_decode), low=0, high=256, dtype=np.uint8) - res = jnp.cumsum(seq, axis=-1) - res_ref = np.cumsum(seq, axis=-1, dtype=np.uint8) - jtu.check_eq(res, res_ref) - - @unittest.expectedFailure - def test_issue_pad(self): - ir = ''' - module @jit_dummy attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<2x2xf32> {mhlo.sharding = "{replicated}"} loc(unknown), %arg1: tensor<3x4xf32> {mhlo.sharding = "{replicated}"} loc(unknown)) -> tensor<4x4xf32> { - %12 = stablehlo.slice %arg0 [0:1, 1:2] : (tensor<2x2xf32>) -> tensor<1x1xf32> - %13 = stablehlo.reshape %12 : (tensor<1x1xf32>) -> tensor - %14 = stablehlo.pad %arg1, %13, low = [0, 0], high = [1, 0], interior = [0, 0] : (tensor<3x4xf32>, tensor) -> tensor<4x4xf32> - return %14 : tensor<4x4xf32> - } - } - ''' - data = np.array([[1, 3], [1, 3]], dtype=np.float32) - input = np.random.rand(3,4).astype(np.float32) - res = ReportedIssuesTests.compile_and_exec(ir, [data, input]) - res_ref = ReportedIssuesTests.compile_and_exec(ir, [data, input], run_on_cpu = True) - jtu.check_eq(res, res_ref) - -if __name__ == "__main__": - absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/lax_numpy_einsum_test.py b/tests/lax_numpy_einsum_test.py index ea7bff1d09fc..a40f063e8970 100644 --- a/tests/lax_numpy_einsum_test.py +++ b/tests/lax_numpy_einsum_test.py @@ -432,6 +432,13 @@ def jnp_fun(*args, signature=signature, optimize=optimize): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=1E-4) self._CompileAndCheck(jnp_fun, args_maker, rtol=1E-4) + def test_einsum_unsupported_optimization(self): + rng = jtu.rand_default(self.rng()) + arrs = (rng(shape, 'float32') for shape in [(2, 3), (2, 4), (2, 5)]) + msg = "jax.numpy.einsum does not support simultaneous contraction of 3 or more operands" + with self.assertRaisesRegex(NotImplementedError, msg): + jnp.einsum('ij,ik,il->jkl', *arrs, optimize=[(0, 1, 2)]) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py index 63a725ad3643..defecb9acea5 100644 --- a/tests/lax_numpy_indexing_test.py +++ b/tests/lax_numpy_indexing_test.py @@ -35,7 +35,7 @@ from jax._src import test_util as jtu from jax._src import util from jax._src.lax import lax as lax_internal -from jax._src.util import NumpyComplexWarning +from jax._src.numpy import indexing config.parse_flags_with_absl() @@ -68,7 +68,7 @@ def check_grads(f, args, order, atol=None, rtol=None, eps=None): jtu.check_vjp(f, partial(jax.vjp, f), args, atol, rtol, eps) -STATIC_INDEXING_TESTS = [ +STATIC_SLICE_TESTS = [ ("OneIntIndex", [ IndexSpec(shape=(3,), indexer=1, out_shape=()), IndexSpec(shape=(3, 3), indexer=0, out_shape=(3,)), @@ -103,13 +103,6 @@ def check_grads(f, args, order, atol=None, rtol=None, eps=None): IndexSpec(shape=(10, 8), indexer=slice(0, 8, -1), out_shape=(0, 8)), IndexSpec(shape=(10, 8), indexer=slice(None, None, -1), out_shape=(10, 8)), ]), - ("SliceIndexClamping", [ - IndexSpec(shape=(10,), indexer=slice(2, 11, 1), out_shape=(8,)), - IndexSpec(shape=(10,), indexer=slice(11, 12, 1), out_shape=(0,)), - IndexSpec(shape=(10,), indexer=slice(-11, -2, 1), out_shape=(8,)), - IndexSpec(shape=(10,), indexer=slice(-2, -12, -1), out_shape=(9,)), - IndexSpec(shape=(10,), indexer=slice(12, -12, -1), out_shape=(10,)), - ]), ("OneSliceIndexNonUnitStride", [ IndexSpec(shape=(10,), indexer=slice(0, 8, 2), out_shape=(4,)), IndexSpec(shape=(10,), indexer=slice(0, 8, 3), out_shape=(3,)), @@ -157,6 +150,13 @@ def check_grads(f, args, order, atol=None, rtol=None, eps=None): IndexSpec(shape=(3, 4, 5), indexer=(0, Ellipsis), out_shape=(4, 5)), IndexSpec(shape=(3, 4, 5), indexer=(Ellipsis, 2, 3), out_shape=(3,)), ]), + ("SliceIndexClamping", [ + IndexSpec(shape=(10,), indexer=slice(2, 11, 1), out_shape=(8,)), + IndexSpec(shape=(10,), indexer=slice(11, 12, 1), out_shape=(0,)), + IndexSpec(shape=(10,), indexer=slice(-11, -2, 1), out_shape=(8,)), + IndexSpec(shape=(10,), indexer=slice(-2, -12, -1), out_shape=(9,)), + IndexSpec(shape=(10,), indexer=slice(12, -12, -1), out_shape=(10,)), + ]), ("NoneIndex", [ IndexSpec(shape=(), indexer=None, out_shape=(1,)), IndexSpec(shape=(), indexer=(None, None), out_shape=(1, 1)), @@ -166,12 +166,18 @@ def check_grads(f, args, order, atol=None, rtol=None, eps=None): IndexSpec(shape=(3, 4), indexer=(Ellipsis, None), out_shape=(3, 4, 1)), IndexSpec(shape=(3, 4), indexer=(0, None, Ellipsis), out_shape=(1, 4)), IndexSpec(shape=(3, 4, 5), indexer=(1, None, Ellipsis), out_shape=(1, 4, 5)), + IndexSpec(shape=(3, 4, 5), indexer=(1, None, slice(None), None), out_shape=(1, 4, 1, 5)), ]), ("EmptyIndex", [ IndexSpec(shape=(), indexer=(), out_shape=()), IndexSpec(shape=(3,), indexer=(), out_shape=(3,)), IndexSpec(shape=(3, 4), indexer=(), out_shape=(3, 4)), ]), +] + + +STATIC_INDEXING_TESTS = [ + *STATIC_SLICE_TESTS, ("TupleOfIntAndSliceAndIntArray", [ IndexSpec(shape=(3, 2, 3), indexer=(0, slice(None), np.arange(3)), out_shape=(3, 2)), @@ -232,6 +238,11 @@ def check_grads(f, args, order, atol=None, rtol=None, eps=None): IndexSpec(shape=(3,), indexer=np.array([0, 1, 0]), out_shape=(3,)), IndexSpec(shape=(3, 4, 5), indexer=np.array([ 0, -1]), out_shape=(2, 4, 5)), ]), + ("TupleOfEmptyList", [ + IndexSpec(shape=(3, 4), indexer=([],), out_shape=(0, 4)), + IndexSpec(shape=(3, 4), indexer=([], 0), out_shape=(0,)), + IndexSpec(shape=(3, 4), indexer=([], []), out_shape=(0,)), + ]), ("TupleOfListsOfPythonInts", [ IndexSpec(shape=(3, 4, 5), indexer=([0, 1],), out_shape=(2, 4, 5)), IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]], [[2, 3, 0, 3]]), @@ -431,6 +442,80 @@ def check_grads(f, args, order, atol=None, rtol=None, eps=None): MODES = ["clip", "drop", "promise_in_bounds"] +class IndexingStrategyTest(jtu.JaxTestCase): + """Tests for arr.static_slice[...]""" + + @jtu.sample_product( + [dict(name=name, shape=shape, indexer=indexer) + for name, index_specs in STATIC_SLICE_TESTS + for shape, indexer, _ in index_specs], + dtype=all_dtypes, + strategy=[indexing.IndexingStrategy.AUTO, + indexing.IndexingStrategy.DYNAMIC_SLICE, + indexing.IndexingStrategy.STATIC_SLICE, + indexing.IndexingStrategy.GATHER], + mode=[None, "clip", "promise_in_bounds"], + ) + def test_simple_indexing(self, name, shape, dtype, indexer, strategy, mode): + del name # unused within test + tuple_indexer = indexer if isinstance(indexer, tuple) else (indexer,) + if (strategy == indexing.IndexingStrategy.STATIC_SLICE and + any(isinstance(i, np.ndarray) for i in tuple_indexer)): + self.skipTest("array indices not supported with STATIC_SLICE.") + if (strategy == indexing.IndexingStrategy.DYNAMIC_SLICE and + any(isinstance(i, slice) and not (i.step is None or i.step in [-1, 1]) + for i in tuple_indexer)): + self.skipTest("non-unit step sizes not supported with DYNAMIC_SLICE") + + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + np_fun = lambda x: np.asarray(x)[indexer] + jnp_fun = partial(indexing.rewriting_take, idx=indexer, strategy=strategy, mode=mode) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + shape=[(3, 4), (3, 5, 2)], + dtype=all_dtypes, + indexer=[(-2,), (-1, -2), (10,), (10, 1)], + strategy=[indexing.IndexingStrategy.AUTO, + indexing.IndexingStrategy.DYNAMIC_SLICE, + indexing.IndexingStrategy.STATIC_SLICE, + indexing.IndexingStrategy.GATHER], + normalize_indices=[True, False] + ) + def test_simple_indexing_oob(self, shape, dtype, indexer, strategy, normalize_indices): + """Test negative and out-of-bound index handling for indexing strategies.""" + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + if normalize_indices: + np_indexer = tuple(np.clip(i, -size, size - 1) + for i, size in zip(indexer, shape)) + else: + np_indexer = tuple(np.clip(i, 0, size - 1) + for i, size in zip(indexer, shape)) + np_fun = lambda x: np.asarray(x)[np_indexer] + jnp_fun = partial(indexing.rewriting_take, idx=indexer, strategy=strategy, + normalize_indices=normalize_indices, mode='clip') + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @parameterized.parameters( + ((2,), -4, IndexError, "index -4 out of bounds for axis 0 with size 2"), + ((2,), 4, IndexError, "index 4 out of bounds for axis 0 with size 2"), + ((2, 3), np.index_exp[:, 4], IndexError, "index 4 out of bounds for axis 1 with size 3"), + ((2, 3), np.index_exp[..., -4], IndexError, "index -4 out of bounds for axis 1 with size 3"), + ((2, 3, 5), np.index_exp[3, :, 0], IndexError, "index 3 out of bounds for axis 0 with size 2"), + ((2, 3), ([1, 2], 0), TypeError, "static_slice: indices must be static scalars or slices."), + ((2, 3), (np.arange(2), 0), TypeError, "static_slice: indices must be static scalars or slices."), + ((2, 3), (1, 2, 3), IndexError, "Too many indices: array is 2-dimensional, but 3 were indexed"), + ) + def test_slice_oob_indexing_fails(self, shape, idx, err, msg): + arr = jnp.zeros(shape) + with self.assertRaisesRegex(err, msg): + indexing.rewriting_take(arr, idx, strategy=indexing.IndexingStrategy.STATIC_SLICE) + + class IndexingTest(jtu.JaxTestCase): """Tests for Numpy indexing translation rules.""" @@ -908,34 +993,49 @@ def testJVPOfGradOfIndexing(self): def testSimpleIndexingUsesSlice(self): jaxpr = jax.make_jaxpr(lambda x: x[:2, :2])(jnp.ones((3, 4))) - self.assertEqual(len(jaxpr.jaxpr.eqns), 1) - self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.slice_p) + eqn, = jaxpr.jaxpr.eqns + self.assertEqual(eqn.primitive, lax.slice_p) + self.assertIsNone(eqn.params['strides']) jaxpr = jax.make_jaxpr(lambda x: x[0, :2, 1])(jnp.ones((3, 4, 5))) - self.assertEqual(len(jaxpr.jaxpr.eqns), 2) - self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.slice_p) - self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p) + slice_eqn, squeeze_eqn = jaxpr.jaxpr.eqns + self.assertEqual(slice_eqn.primitive, lax.slice_p) + self.assertEqual(squeeze_eqn.primitive, lax.squeeze_p) + self.assertIsNone(slice_eqn.params['strides']) jaxpr = jax.make_jaxpr(lambda x: x[0, 0])(jnp.ones((3, 4, 5))) - self.assertEqual(len(jaxpr.jaxpr.eqns), 2) - self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.slice_p) - self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p) + slice_eqn, squeeze_eqn = jaxpr.jaxpr.eqns + self.assertEqual(slice_eqn.primitive, lax.slice_p) + self.assertEqual(squeeze_eqn.primitive, lax.squeeze_p) + self.assertIsNone(slice_eqn.params['strides']) jaxpr = jax.make_jaxpr(lambda x: x[:, 1])(jnp.ones((3, 4, 5))) - self.assertEqual(len(jaxpr.jaxpr.eqns), 2) - self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.slice_p) - self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p) + slice_eqn, squeeze_eqn = jaxpr.jaxpr.eqns + self.assertEqual(slice_eqn.primitive, lax.slice_p) + self.assertEqual(squeeze_eqn.primitive, lax.squeeze_p) + self.assertIsNone(slice_eqn.params['strides']) - # Indexing with `Ellipsis` is not lowered to `gather`. + # Indexing with `Ellipsis` is not lowered to `gather` ... jaxpr = jax.make_jaxpr(lambda x: x[..., 0])(jnp.ones((3, 4, 5))) - self.assertLen((jaxpr.jaxpr.eqns), 2) - self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.slice_p) - self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p) + slice_eqn, squeeze_eqn = jaxpr.jaxpr.eqns + self.assertEqual(slice_eqn.primitive, lax.slice_p) + self.assertEqual(squeeze_eqn.primitive, lax.squeeze_p) + self.assertIsNone(slice_eqn.params['strides']) + + # ... even when the ellipsis expands to no dimensions. + jaxpr = jax.make_jaxpr(lambda x: x[..., 0:1])(jnp.ones((3,))) + eqn, = jaxpr.jaxpr.eqns + self.assertEqual(eqn.primitive, lax.slice_p) + self.assertIsNone(eqn.params['strides']) + jaxpr = jax.make_jaxpr(lambda x: x[0:1, ...])(jnp.ones((3,))) + eqn, = jaxpr.jaxpr.eqns + self.assertEqual(eqn.primitive, lax.slice_p) + self.assertIsNone(eqn.params['strides']) # Simple reverses lower to lax.rev_p jaxpr = jax.make_jaxpr(lambda x: x[:, ::-1])(jnp.ones((3, 4))) - self.assertEqual(len(jaxpr.jaxpr.eqns), 1) - self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.rev_p) + eqn, = jaxpr.jaxpr.eqns + self.assertEqual(eqn.primitive, lax.rev_p) # Non-static indices produce a dynamic slice jaxpr = jax.make_jaxpr(lambda x, i: x[i])(jnp.ones((4,)), 2) @@ -980,10 +1080,10 @@ def testIndexingEmptyDimension(self): "index .* is out of bounds for axis .* with size 0"): _ = np.ones((2, 0))[0, 0] # The numpy error with self.assertRaisesRegex(IndexError, - "index is out of bounds for axis .* with size 0"): + "index .* out of bounds for axis .* with size 0"): _ = x[0, 0] # JAX indexing with self.assertRaisesRegex(IndexError, - "index is out of bounds for axis .* with size 0"): + "index .* out of bounds for axis .* with size 0"): jax.jit(lambda i: x[0, i])(0) # JAX indexing under jit def testBooleanIndexingWithEmptyResult(self): @@ -1132,6 +1232,47 @@ def testStrIndexingError(self): with self.assertRaisesRegex(TypeError, msg): jnp.zeros((2, 3))[:, 'abc'] + @jtu.sample_product( + mode=["promise_in_bounds", "fill", "clip", "drop"], + wrap_negative_indices=[True, False], + shape=[(5,), (10,)], + idx_shape=[(5,)], + ) + def testWrapNegativeIndices1D(self, mode, wrap_negative_indices, shape, idx_shape): + """Test the behavior of the wrap_negative_indices parameter in array.at[...].get()""" + fill_value = 99 + + data_rng = jtu.rand_default(self.rng()) + idx_rng = jtu.rand_uniform(self.rng(), low=-12, high=12) + + args_maker = lambda: [data_rng(shape, 'float32'), idx_rng(idx_shape, 'int32')] + + def jnp_fun(data, idx): + return jnp.array(data).at[idx].get( + mode=mode, + fill_value=fill_value, + wrap_negative_indices=wrap_negative_indices) + + def np_fun(data, idx): + if wrap_negative_indices: + idx = np.where(idx < 0, idx + len(data), idx) + out_of_bound = (idx < 0) | (idx >= len(data)) + safe_idx = np.where(out_of_bound, 0, idx) + result = data[safe_idx] + if mode in ["fill", "drop"]: + result = np.where(out_of_bound, fill_value, result) + elif mode in ["promise_in_bounds", "clip"]: + result = np.where(idx < 0, data[0], + np.where(idx >= len(data), data[-1], + result)) + else: + raise ValueError(f"Unrecognized mode {mode!r}") + return result + + tol = 1E-4 if jtu.test_device_matches(["tpu"]) else None + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol) + self._CompileAndCheck(jnp_fun, args_maker, tol=tol) + def testIndexOutOfBounds(self): # https://github.com/jax-ml/jax/issues/2245 x = jnp.arange(5, dtype=jnp.int32) + 1 self.assertAllClose(x, x[:10]) @@ -1157,7 +1298,8 @@ def testIndexOutOfBounds(self): # https://github.com/jax-ml/jax/issues/2245 jnp.array([7, 7, 1, 2, 1, 4, 5, 7, 7, 7], jnp.int32)) def testIndexingWeakTypes(self): - x = lax_internal._convert_element_type(jnp.arange(5), float, weak_type=True) + x = lax_internal._convert_element_type(jnp.arange(5), dtypes.dtype(float), + weak_type=True) a = x.at[0].set(1.0) self.assertEqual(a.dtype, x.dtype) @@ -1178,7 +1320,7 @@ def _check(x_type, y_type): out = x.at[0].set(y) self.assertEqual(x.dtype, out.dtype) - @jtu.ignore_warning(category=NumpyComplexWarning, + @jtu.ignore_warning(category=np.exceptions.ComplexWarning, message="Casting complex values to real") def _check_warns(x_type, y_type, msg): with self.assertWarnsRegex(FutureWarning, msg): @@ -1235,13 +1377,28 @@ def _check_raises(x_type, y_type, msg): def testWrongNumberOfIndices(self): with self.assertRaisesRegex( IndexError, - "Too many indices: 0-dimensional array indexed with 1 regular index."): + "Too many indices: array is 0-dimensional, but 1 were indexed"): jnp.array(1)[0] with self.assertRaisesRegex( IndexError, - "Too many indices: 1-dimensional array indexed with 2 regular indices."): + "Too many indices: array is 1-dimensional, but 2 were indexed"): jnp.zeros(3)[:, 5] + @jtu.sample_product(shape=[(), (1,)]) + def testIndexDtypePromotion(self, shape): + # Regression test for https://github.com/jax-ml/jax/issues/31396 + numbers = jnp.arange(1000)[:, None] + idx = jnp.int8(0).reshape(shape) + expected = np.array(999).reshape(shape) + self.assertArraysEqual(numbers[999, idx], expected) + + def testIndexingTypedNdArray(self): + x = jnp.arange(4) + i = dtypes.canonicalize_value(np.array([2, 0, 1])) + result = x[i] + expected = x[jnp.asarray(i)] + self.assertArraysEqual(result, expected) + def _broadcastable_shapes(shape): """Returns all shapes that broadcast to `shape`.""" @@ -1284,22 +1441,30 @@ class UpdateOps(enum.Enum): def np_fn(op, indexer, x, y): x = x.copy() - x[indexer] = { - UpdateOps.UPDATE: lambda: y, - UpdateOps.ADD: lambda: x[indexer] + y, - UpdateOps.SUB: lambda: x[indexer] - y, - UpdateOps.MUL: lambda: x[indexer] * y, - UpdateOps.DIV: jtu.ignore_warning(category=RuntimeWarning)( - lambda: x[indexer] / y.astype(x.dtype)), - UpdateOps.POW: jtu.ignore_warning(category=RuntimeWarning)( - lambda: x[indexer] ** y.astype(x.dtype)), - UpdateOps.MIN: lambda: np.minimum(x[indexer], y), - UpdateOps.MAX: lambda: np.maximum(x[indexer], y), - }[op]() + if op == UpdateOps.UPDATE: + x[indexer] = y + elif op == UpdateOps.ADD: + np.add.at(x, indexer, y) + elif op == UpdateOps.SUB: + np.subtract.at(x, indexer, y) + elif op == UpdateOps.MUL: + np.multiply.at(x, indexer, y) + elif op == UpdateOps.DIV: + with jtu.ignore_warning(category=RuntimeWarning): + np.divide.at(x, indexer, y) + elif op == UpdateOps.POW: + with jtu.ignore_warning(category=RuntimeWarning): + np.power.at(x, indexer, y) + elif op == UpdateOps.MIN: + np.minimum.at(x, indexer, y.astype(x.dtype)) + elif op == UpdateOps.MAX: + np.maximum.at(x, indexer, y.astype(x.dtype)) + else: + raise ValueError(f"{op=}") return x def jax_fn(op, indexer, x, y, indices_are_sorted=False, - unique_indices=False, mode=None): + unique_indices=False, mode=None, wrap_negative_indices=True): x = jnp.array(x) return { UpdateOps.UPDATE: x.at[indexer].set, @@ -1311,7 +1476,8 @@ def jax_fn(op, indexer, x, y, indices_are_sorted=False, UpdateOps.MIN: x.at[indexer].min, UpdateOps.MAX: x.at[indexer].max, }[op](y, indices_are_sorted=indices_are_sorted, - unique_indices=unique_indices, mode=mode) + unique_indices=unique_indices, mode=mode, + wrap_negative_indices=wrap_negative_indices) def dtypes(op): if op == UpdateOps.UPDATE: @@ -1424,6 +1590,52 @@ def testMixedAdvancedIndexing(self, name, shape, dtype, update_shape, self._CheckAgainstNumpy(np_fn, jax_fn, args_maker, tol=_update_tol(op)) self._CompileAndCheck(jax_fn, args_maker) + @jtu.sample_product( + op=UpdateOps, + mode=["fill", "clip"], + wrap_negative_indices=[True, False], + shape=[(5,), (10,)], + update_shape=[(5,)], + ) + def testWrapNegativeIndices1D(self, op, mode, wrap_negative_indices, shape, update_shape): + rng = jtu.rand_default(self.rng()) + idx_rng = jtu.rand_unique_int(self.rng(), high=shape[0]) + + def args_maker(): + data = rng(shape, 'float32').round(1) + update = rng(update_shape, 'float32').round(1) + # we need indices to be unique, so we generate unique values in [0, N) + # and then subtract N from half of them. To test out-of-bound behavior + # we push the bottom and top index out-of-bounds + idx = idx_rng(update_shape, 'int32') + idx = np.where(rng(update_shape, bool), idx, idx - shape[0]) + idx[idx == shape[0] - 1] = shape[0] + 2 # out-of-bound positive + idx[idx == -shape[0]] = -(shape[0] + 2) # out-of-bound negative + return data, idx, update + + def jnp_fun(data, idx, values): + return UpdateOps.jax_fn(op, idx, data, values, + mode=mode, + wrap_negative_indices=wrap_negative_indices) + + def np_fun(data, idx, values): + if wrap_negative_indices: + idx = np.where(idx < 0, idx + len(data), idx) + if mode in ["fill", "drop", "promise_in_bounds"]: + ok = (idx >= 0) & (idx < len(data)) + idx = idx[ok] + values = values[ok] + elif mode == "clip": + idx = np.where(idx < 0, 0, idx) + idx = np.where(idx >= len(data), len(data) - 1, idx) + else: + raise ValueError(f"Unrecognized mode {mode!r}") + return UpdateOps.np_fn(op, idx, data, values) + + tol = 1E-4 if jtu.test_device_matches(["tpu"]) else None + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol) + self._CompileAndCheck(jnp_fun, args_maker, tol=tol) + @jtu.sample_product( [dict(name=name, mode=mode, shape=shape, indexer=indexer, update_shape=update_shape) @@ -1691,5 +1903,54 @@ def f(x, i): self.assertArraysEqual(jax.jacrev(f)(x, i), expected) self.assertArraysEqual(jax.jacrev(jax.vmap(f, (None, 0)))(x, i), expected) +@jtu.with_config(jax_check_static_indices=True) +class ValidateIndicesTest(jtu.JaxTestCase): + @parameterized.parameters( + ((2,), -4, IndexError, "index -4 out of bounds for axis 0 with size 2"), + ((2,), 4, IndexError, "index 4 out of bounds for axis 0 with size 2"), + ((2, 3), np.index_exp[:, 4], IndexError, "index 4 out of bounds for axis 1 with size 3"), + ((2, 3), np.index_exp[..., -4], IndexError, "index -4 out of bounds for axis 1 with size 3"), + ((2, 3, 5), np.index_exp[3, :, 0], IndexError, "index 3 out of bounds for axis 0 with size 2"), + ((2, 3, 5), np.index_exp[:5, :, 6], IndexError, "index 6 out of bounds for axis 2 with size 5"), + ((2, 3, 5), np.index_exp[:, [1, 2], 6], IndexError, "index 6 out of bounds for axis 2 with size 5"), + ((2, 3, 5), np.index_exp[np.arange(3), 6, None], IndexError, "index 6 out of bounds for axis 1 with size 3"), + ((2, 3), (1, 2, 3), IndexError, "Too many indices: array is 2-dimensional, but 3 were indexed"), + ) + def test_out_of_bound_indices(self, shape, idx, err, msg): + """Test that out-of-bound indexing """ + arr = jnp.zeros(shape) + + with self.subTest("eager"): + with self.assertRaisesRegex(err, msg): + arr[idx] + + with self.subTest("jit"): + with self.assertRaisesRegex(err, msg): + jax.jit(lambda x: x[idx])(arr) + + with self.subTest("arr.at[idx].get()"): + with self.assertRaisesRegex(err, msg): + arr.at[idx].get() + + @jtu.sample_product( + [dict(name=name, shape=shape, indexer=indexer) + for name, index_specs in STATIC_INDEXING_TESTS + for shape, indexer, _ in index_specs], + dtype=all_dtypes + ) + def test_simple_indexing(self, name, shape, dtype, indexer): + """Test that in-bound indexing works correctly.""" + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + np_fun = lambda x: np.asarray(x)[indexer] + jnp_fun = lambda x: jnp.asarray(x)[indexer] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + # Tests x.at[...].get(...) as well. + jnp_fun = lambda x: jnp.asarray(x).at[indexer].get() + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/lax_numpy_operators_test.py b/tests/lax_numpy_operators_test.py index 744a99fb70e7..4ea6c655fe86 100644 --- a/tests/lax_numpy_operators_test.py +++ b/tests/lax_numpy_operators_test.py @@ -27,7 +27,6 @@ import numpy as np import jax -import jax.ops from jax import lax from jax import numpy as jnp @@ -213,9 +212,6 @@ def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes, test_name="expm1_large", tolerance={np.float64: 1e-8}, inexact=True), op_record("expm1", 1, number_dtypes, all_shapes, jtu.rand_small_positive, [], tolerance={np.float64: 1e-8}, inexact=True), - op_record("fix", 1, float_dtypes, all_shapes, jtu.rand_default, []), - op_record("fix", 1, int_dtypes + unsigned_dtypes, all_shapes, - jtu.rand_default, [], check_dtypes=False), op_record("floor_divide", 2, default_dtypes + unsigned_dtypes, all_shapes, jtu.rand_nonzero, ["rev"]), op_record("fmin", 2, number_dtypes, all_shapes, jtu.rand_some_nan, []), @@ -275,9 +271,8 @@ def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes, []), op_record("rint", 1, int_dtypes + unsigned_dtypes, all_shapes, jtu.rand_default, [], check_dtypes=False), - # numpy < 2.0.0 has a different convention for complex sign. - op_record("sign", 1, real_dtypes if jtu.numpy_version() < (2, 0, 0) else number_dtypes, - all_shapes, jtu.rand_some_inf_and_nan, []), + op_record("sign", 1, number_dtypes, all_shapes, jtu.rand_some_inf_and_nan, + []), # numpy 1.16 has trouble mixing uint and bfloat16, so we test these separately. op_record("copysign", 2, default_dtypes + unsigned_dtypes, all_shapes, jtu.rand_some_inf_and_nan, [], check_dtypes=False), @@ -325,11 +320,8 @@ def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes, jtu.rand_fullrange, []), op_record("bitwise_xor", 2, int_dtypes + unsigned_dtypes, all_shapes, jtu.rand_fullrange, []), + op_record("bitwise_count", 1, int_dtypes, all_shapes, jtu.rand_fullrange, []), ] -if hasattr(np, "bitwise_count"): - # Numpy versions after 1.26 - JAX_BITWISE_OP_RECORDS.append( - op_record("bitwise_count", 1, int_dtypes, all_shapes, jtu.rand_fullrange, [])) JAX_OPERATOR_OVERLOADS = [ op_record("__add__", 2, number_dtypes, all_shapes, jtu.rand_default, []), @@ -480,7 +472,7 @@ def testOp(self, op_name, rng_factory, shapes, dtypes, check_dtypes, "arccosh", "arcsinh", "sinh", "cosh", "tanh", "sin", "cos", "tan", "log", "log1p", "log2", "log10", "exp", "expm1", "exp2", "pow", "power", "logaddexp", "logaddexp2", "i0", "acosh", "asinh"): - tol = jtu.join_tolerance(tol, 1e-4) + tol = jtu.join_tolerance(tol, 2e-4) tol = functools.reduce(jtu.join_tolerance, [tolerance, tol, jtu.default_tolerance()]) @@ -618,14 +610,9 @@ def testBitwiseOp(self, name, rng_factory, shapes, dtypes, alias): dtype=int_dtypes + unsigned_dtypes, ) def testBitwiseCount(self, shape, dtype): - # np.bitwise_count added after numpy 1.26, but - # np_scalar.bit_count() is available before that. - np_fun = getattr( - np, "bitwise_count", - np.vectorize(lambda x: np.ravel(x)[0].bit_count(), otypes=['uint8'])) rng = jtu.rand_fullrange(self.rng()) args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp.bitwise_count, args_maker) + self._CheckAgainstNumpy(np.bitwise_count, jnp.bitwise_count, args_maker) self._CompileAndCheck(jnp.bitwise_count, args_maker) @jtu.sample_product( @@ -655,14 +642,7 @@ def testShiftOpAgainstNumpy(self, op, dtypes, shapes): # NumPy requires shifts to be non-negative and below the bit width: shift_rng = jtu.rand_int(self.rng(), high=max(info.bits, shift_info.bits)) args_maker = lambda: (x_rng(shapes[0], dtype), shift_rng(shapes[1], shift_dtype)) - - if jtu.numpy_version() < (2, 0, 0) and op.__name__ in ("bitwise_left_shift", "bitwise_right_shift"): - # numpy < 2.0.0 does not have bitwise shift functions. - op_name = op.__name__.removeprefix("bitwise_") - else: - op_name = op.__name__ - - np_op = getattr(np, op_name) + np_op = getattr(np, op.__name__) with jtu.strict_promotion_if_dtypes_match(dtypes): self._CompileAndCheck(op, args_maker) @@ -675,10 +655,7 @@ def testShiftOpAgainstNumpy(self, op, dtypes, shapes): ) def testSignComplex(self, shape, dtype): rng = jtu.rand_default(self.rng()) - if jtu.numpy_version() >= (2, 0, 0): - np_fun = np.sign - else: - np_fun = lambda x: (x / np.where(x == 0, 1, abs(x))).astype(np.result_type(x)) + np_fun = np.sign jnp_fun = jnp.sign args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index 0c3f1d1471fb..b868c2fa4694 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -29,11 +29,9 @@ from jax._src import config from jax._src import dtypes from jax._src import test_util as jtu -from jax._src.util import NumpyComplexWarning config.parse_flags_with_absl() -numpy_version = jtu.numpy_version() nonempty_nonscalar_array_shapes = [(4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)] nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes @@ -169,8 +167,11 @@ def _reducer_output_dtype(name: str, input_dtype: np.dtype, promote_integers: bo input_dtype = dtypes.to_numeric_dtype(input_dtype) if promote_integers: if dtypes.issubdtype(input_dtype, np.integer): - default_int = dtypes.canonicalize_dtype( - dtypes.uint if dtypes.issubdtype(input_dtype, np.unsignedinteger) else dtypes.int_) + default_int = ( + dtypes.default_uint_dtype() + if dtypes.issubdtype(input_dtype, np.unsignedinteger) + else dtypes.default_int_dtype() + ) if np.iinfo(input_dtype).bits < np.iinfo(default_int).bits: return default_int return input_dtype @@ -209,7 +210,7 @@ def testReducer(self, name, rng_factory, shape, dtype, out_dtype, np_op = getattr(np, name) jnp_op = getattr(jnp, name) rng = rng_factory(self.rng()) - @jtu.ignore_warning(category=NumpyComplexWarning) + @jtu.ignore_warning(category=np.exceptions.ComplexWarning) @jtu.ignore_warning(category=RuntimeWarning, message="Mean of empty slice.*") @jtu.ignore_warning(category=RuntimeWarning, @@ -225,7 +226,7 @@ def np_fun(x): return np_op(x_cast, axis, dtype=t, keepdims=keepdims) jnp_fun = lambda x: jnp_op(x, axis, dtype=out_dtype, keepdims=keepdims) - jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) + jnp_fun = jtu.ignore_warning(category=np.exceptions.ComplexWarning)(jnp_fun) args_maker = lambda: [rng(shape, dtype)] tol_spec = {np.float16: 1e-2, np.int16: 2e-7, np.int32: 1E-3, np.uint32: 3e-7, np.float32: 1e-3, np.complex64: 1e-3, @@ -313,7 +314,7 @@ def testReducerInitial(self, name, rng_factory, shape, dtype, axis, is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan' @jtu.ignore_warning(category=RuntimeWarning, message="Degrees of freedom <= 0 for slice.*") - @jtu.ignore_warning(category=NumpyComplexWarning) + @jtu.ignore_warning(category=np.exceptions.ComplexWarning) def np_fun(x): x = np.asarray(x) if inexact: @@ -324,7 +325,7 @@ def np_fun(x): return res.astype(_reducer_output_dtype(name, x.dtype)) jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, initial=initial) - jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) + jnp_fun = jtu.ignore_warning(category=np.exceptions.ComplexWarning)(jnp_fun) args_maker = lambda: [rng(shape, dtype)] tol = {jnp.bfloat16: 3E-2} self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=tol, atol=tol) @@ -353,7 +354,7 @@ def testReducerPromoteInt(self, name, rng_factory, shape, dtype, axis, rng_factory.__name__ == 'rand_some_nan') @jtu.ignore_warning(category=RuntimeWarning, message="Degrees of freedom <= 0 for slice.*") - @jtu.ignore_warning(category=NumpyComplexWarning) + @jtu.ignore_warning(category=np.exceptions.ComplexWarning) def np_fun(x): x = np.asarray(x) if inexact: @@ -364,7 +365,7 @@ def np_fun(x): return res.astype(_reducer_output_dtype(name, x.dtype, promote_integers)) jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, initial=initial, promote_integers=promote_integers) - jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) + jnp_fun = jtu.ignore_warning(category=np.exceptions.ComplexWarning)(jnp_fun) args_maker = lambda: [rng(shape, dtype)] tol = {jnp.bfloat16: 3E-2, jnp.float16: 5e-3} self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=tol) @@ -390,7 +391,7 @@ def testReducerNoInitialZeroDims(self, name, rng_factory, shape, dtype, axis, is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan' @jtu.ignore_warning(category=RuntimeWarning, message="Degrees of freedom <= 0 for slice.*") - @jtu.ignore_warning(category=NumpyComplexWarning) + @jtu.ignore_warning(category=np.exceptions.ComplexWarning) def np_fun(x): x = np.asarray(x) if inexact: @@ -401,7 +402,7 @@ def np_fun(x): return res.astype(_reducer_output_dtype(name, x.dtype)) jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims) - jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) + jnp_fun = jtu.ignore_warning(category=np.exceptions.ComplexWarning)(jnp_fun) args_maker = lambda: [rng(shape, dtype)] tol = {jnp.bfloat16: 3E-2} self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=tol) @@ -436,7 +437,7 @@ def testReducerWhere(self, name, rng_factory, shape, dtype, axis, where = jtu.rand_bool(self.rng())(whereshape, np.bool_) @jtu.ignore_warning(category=RuntimeWarning, message="Degrees of freedom <= 0 for slice.*") - @jtu.ignore_warning(category=NumpyComplexWarning) + @jtu.ignore_warning(category=np.exceptions.ComplexWarning) def np_fun(x): x = np.asarray(x) if inexact: @@ -447,7 +448,7 @@ def np_fun(x): return res.astype(_reducer_output_dtype(name, x.dtype)) jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, initial=initial, where=where) - jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) + jnp_fun = jtu.ignore_warning(category=np.exceptions.ComplexWarning)(jnp_fun) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, atol=tol, rtol=tol) self._CompileAndCheck(jnp_fun, args_maker) @@ -458,8 +459,8 @@ def testReducerWhereNonBooleanErrorInitial(self, rec): x = jnp.zeros((10,), dtype) where = jnp.ones(10, dtype=int) func = getattr(jnp, rec.name) - with self.assertDeprecationWarnsOrRaises("jax-numpy-reduction-non-boolean-where", - f"jnp.{rec.name}: where must be None or a boolean array"): + with self.assertRaisesRegex( + ValueError, f"jnp.{rec.name}: where must be None or a boolean array"): func(x, where=where, initial=jnp.array(0, dtype=dtype)) @jtu.sample_product(rec=JAX_REDUCER_WHERE_NO_INITIAL_RECORDS) @@ -468,8 +469,8 @@ def testReducerWhereNonBooleanErrorNoInitial(self, rec): x = jnp.zeros((10,), dtype) where = jnp.ones(10, dtype=int) func = getattr(jnp, rec.name) - with self.assertDeprecationWarnsOrRaises("jax-numpy-reduction-non-boolean-where", - f"jnp.{rec.name}: where must be None or a boolean array"): + with self.assertRaisesRegex( + ValueError, f"jnp.{rec.name}: where must be None or a boolean array"): func(x, where=where) @parameterized.parameters(itertools.chain.from_iterable( @@ -499,7 +500,7 @@ def testReducerWhereNoInitial(self, name, rng_factory, shape, dtype, axis, message="Mean of empty slice.*") @jtu.ignore_warning(category=RuntimeWarning, message="invalid value encountered.*") - @jtu.ignore_warning(category=NumpyComplexWarning) + @jtu.ignore_warning(category=np.exceptions.ComplexWarning) def np_fun(x): x = np.asarray(x) if inexact: @@ -510,7 +511,7 @@ def np_fun(x): return res jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, where=where) - jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) + jnp_fun = jtu.ignore_warning(category=np.exceptions.ComplexWarning)(jnp_fun) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, atol=tol, rtol=tol) self._CompileAndCheck(jnp_fun, args_maker) @@ -558,6 +559,26 @@ def testAverage(self, shape, dtype, axis, weights_shape, returned, keepdims): self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=check_dtypes, rtol=tol, atol=tol) + @parameterized.parameters( + dict(shape=(2, 3, 4), axis=(1, 2)), + dict(shape=(2, 3, 4), axis=(2, 0)), + dict(shape=(2, 3, 4), axis=(0, 1, 2)), + dict(shape=(2, 3, 4), axis=(2, 0, 1)), + dict(shape=(2, 3, 4), axis=(2, 1, 0)), + dict(shape=(2, 3, 4, 5), axis=(3, 0)), + dict(shape=(2, 3, 4, 5), axis=(3, 0, 2, 1)), + ) + def testAverageNDWeights(self, shape, axis): + weights_shape = tuple(shape[ax] for ax in axis) + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, np.float32), rng(weights_shape, np.float32)] + np_fun = lambda x, weights: np.average(x, axis, weights) + jnp_fun = lambda x, weights: jnp.average(x, axis, weights) + tol = {dtypes.bfloat16: 2e-1, np.float16: 1e-2, np.float32: 1e-5, + np.float64: 1e-12, np.complex64: 1e-5} + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol) + self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( test_fns=[(np.var, jnp.var), (np.std, jnp.std)], shape=[(5,), (10, 5)], @@ -574,13 +595,12 @@ def testStdOrVar(self, test_fns, shape, dtype, out_dtype, axis, ddof_correction, args_maker = self._GetArgsMaker(rng, [shape], [dtype]) @jtu.ignore_warning(category=RuntimeWarning, message="Degrees of freedom <= 0 for slice.") - @jtu.ignore_warning(category=NumpyComplexWarning) + @jtu.ignore_warning(category=np.exceptions.ComplexWarning) def np_fun(x): # setup ddof and correction kwargs excluding case when correction is not specified ddof_correction_kwargs = {"ddof": ddof} if correction is not None: - key = "correction" if numpy_version >= (2, 0) else "ddof" - ddof_correction_kwargs[key] = correction + ddof_correction_kwargs["correction"] = correction # Numpy fails with bfloat16 inputs out = np_fn(x.astype(np.float32 if dtype == dtypes.bfloat16 else dtype), dtype=np.float32 if out_dtype == dtypes.bfloat16 else out_dtype, @@ -625,7 +645,7 @@ def testNanVar(self, shape, dtype, out_dtype, axis, ddof, keepdims): args_maker = self._GetArgsMaker(rng, [shape], [dtype]) @jtu.ignore_warning(category=RuntimeWarning, message="Degrees of freedom <= 0 for slice.") - @jtu.ignore_warning(category=NumpyComplexWarning) + @jtu.ignore_warning(category=np.exceptions.ComplexWarning) def np_fun(x): # Numpy fails with bfloat16 inputs out = np.nanvar(x.astype(np.float32 if dtype == dtypes.bfloat16 else dtype), @@ -654,6 +674,58 @@ def testNanStdGrad(self): z = jax.grad(jnp.nanstd)(x) self.assertEqual(jnp.isnan(z).sum(), 0) + @jtu.sample_product( + [ + dict(shape=(5,), axis=None), + dict(shape=(5,), axis=0), + dict(shape=(5,), axis=-1), + dict(shape=(10, 5), axis=None), + dict(shape=(10, 5), axis=0), + dict(shape=(10, 5), axis=1), + dict(shape=(10, 5), axis=-1), + dict(shape=(10, 5), axis=(0, 1)), + dict(shape=(5, 4, 3), axis=(0, 2)), + dict(shape=(5, 4, 3), axis=(1, -1)), + ], + jnp_fn_name=["var", "std", "nanvar", "nanstd"], + dtype=inexact_dtypes + int_dtypes, + ddof=[0, 1], + keepdims=[False, True], + ) + def testReducerWithMean(self, jnp_fn_name, shape, dtype, axis, ddof, keepdims): + """Tests variance and standard deviation functions with a pre-supplied mean.""" + jnp_fn = getattr(jnp, jnp_fn_name) + np_fn = getattr(np, jnp_fn_name) + is_nan_test = "nan" in jnp_fn_name + + # Generate a random mean value. This should have NaNs if the test is for NaNs. + input_rng = jtu.rand_some_nan(self.rng()) if is_nan_test else jtu.rand_default(self.rng()) + # Generate a random mean value. This should never have NaNs. + mean_rng = jtu.rand_default(self.rng()) + mean_shape = np.mean(np.zeros(shape, dtype=dtype), axis=axis, keepdims=True).shape + mean_dtype = dtypes.to_inexact_dtype(dtype) + + args_maker = lambda: [input_rng(shape, dtype), mean_rng(mean_shape, mean_dtype)] + + def jnp_wrapper(x, mean_val): + return jnp_fn(x, axis=axis, ddof=ddof, keepdims=keepdims, mean=mean_val) + + @jtu.ignore_warning(category=RuntimeWarning, message="Degrees of freedom <= 0 for slice.") + @jtu.ignore_warning(category=np.exceptions.ComplexWarning) + def np_wrapper(x, mean_val): + if dtype in int_dtypes: + x = x.astype(dtypes.to_inexact_dtype(x.dtype)) + x_cast = x.astype(np.float32) if x.dtype == dtypes.bfloat16 else x + mean_cast = mean_val.astype(np.float32) if mean_val.dtype == dtypes.bfloat16 else mean_val + return np_fn(x_cast, axis=axis, ddof=ddof, keepdims=keepdims, mean=mean_cast) + + tol_spec = {np.float16: 1e-1, np.float32: 1e-3, np.float64: 1e-5, np.complex128: 1e-6, + np.int8: 1e-4, np.int16: 1e-4, np.int32: 1e-4, np.int64: 1e-5} + self._CheckAgainstNumpy(np_wrapper, jnp_wrapper, args_maker, + check_dtypes=dtype != jnp.bfloat16, + tol=tol_spec) + self._CompileAndCheck(jnp_wrapper, args_maker, rtol=tol_spec, atol=tol_spec) + @jtu.sample_product( [dict(shape=shape, dtype=dtype, y_dtype=y_dtype, rowvar=rowvar, y_shape=y_shape) @@ -691,6 +763,48 @@ def testCov(self, shape, dtype, y_shape, y_dtype, rowvar, ddof, bias, fweights, self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol) + @jtu.sample_product( + shape=[(0,), (3, 0), (0, 5)], + dtype=jtu.dtypes.floating, + rowvar=[True, False], + ) + @jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion + @jax.default_matmul_precision('float32') + def testEmptyCov(self, shape, dtype, rowvar): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + ignore_warning = jtu.ignore_warning( + category=RuntimeWarning, + message="(Mean of empty slice|[Ii]nvalid value|Degrees of freedom|divide by zero)") + np_fun = ignore_warning(partial(np.cov, rowvar=rowvar)) + jnp_fun = partial(jnp.cov, rowvar=rowvar) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) + self._CompileAndCheck(jnp_fun, args_maker) + + @unittest.skipIf(jtu.numpy_version() < (2, 2, 0), "test covers NumPy 2.2+ behavior.") + @jtu.sample_product( + shape=[(1, 3), (3, 1)], + rowvar=[True, False] + ) + def testCovTransposeBehavior(self, shape, rowvar): + # Tests compatibility with NumPy 2.2 API change: + # https://github.com/numpy/numpy/pull/27661 + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, np.float32)] + np_fun = partial(np.cov, rowvar=rowvar, ddof=0) + jnp_fun = partial(jnp.cov, rowvar=rowvar, ddof=0) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) + self._CompileAndCheck(jnp_fun, args_maker) + + def testCovDtype(self): + x = jnp.arange(5) + result_bf16 = jnp.cov(x, dtype='bfloat16') + self.assertEqual(result_bf16.dtype, np.dtype('bfloat16')) + + with self.assertRaisesRegex(ValueError, "cov: dtype must be a subclass of float or complex"): + jnp.cov(x, dtype=int) + + @jtu.sample_product( [dict(op=op, q_rng=q_rng) for (op, q_rng) in ( @@ -747,15 +861,6 @@ def np_fun(*args): tol=tol) self._CompileAndCheck(jnp_fun, args_maker, rtol=tol) - @jtu.sample_product( - op=['quantile', 'nanquantile', 'percentile', 'nanpercentile'] - ) - def testQuantileDeprecatedArgs(self, op): - func = getattr(jnp, op) - with self.assertDeprecationWarnsOrRaises("jax-numpy-quantile-interpolation", - f"The interpolation= argument to '{op}' is deprecated. "): - func(jnp.arange(4), 0.5, interpolation='linear') - @unittest.skipIf(not config.enable_x64.value, "test requires X64") @jtu.run_on_devices("cpu") # test is for CPU float64 precision def testPercentilePrecision(self): @@ -803,6 +908,11 @@ def testMeanLargeArray(self): self.assertEqual(1.0, jnp.mean(x)) self.assertEqual(1.0, jnp.mean(x, where=True)) + def testMeanVeryLargeArray(self): + # https://github.com/jax-ml/jax/pull/30769 + x = jax.ShapeDtypeStruct((1 << 32,), jnp.dtype('float32')) + jax.eval_shape(jnp.mean, x) + def testStdLargeArray(self): # https://github.com/jax-ml/jax/issues/15068 raise unittest.SkipTest("test is slow, but it passes!") @@ -834,7 +944,7 @@ def test_f16_mean(self, dtype): ], include_initial=[False, True], ) - @jtu.ignore_warning(category=NumpyComplexWarning) + @jtu.ignore_warning(category=np.exceptions.ComplexWarning) @jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion def testCumulativeSum(self, shape, axis, dtype, out_dtype, include_initial): rng = jtu.rand_some_zero(self.rng()) @@ -902,11 +1012,11 @@ def testCumulativeSumBool(self): ], include_initial=[False, True], ) - @jtu.ignore_warning(category=NumpyComplexWarning) + @jtu.ignore_warning(category=np.exceptions.ComplexWarning) @jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion def testCumulativeProd(self, shape, axis, dtype, out_dtype, include_initial): - if jtu.is_device_tpu(6): - raise unittest.SkipTest("TODO(b/364258243): Test fails on TPU v6e") + if jtu.is_device_tpu_at_least(6): + raise unittest.SkipTest("TODO(b/364258243): Test fails on TPU v6+") rng = jtu.rand_some_zero(self.rng()) # We currently "cheat" to ensure we have JAX arrays, not NumPy arrays as diff --git a/tests/lax_numpy_setops_test.py b/tests/lax_numpy_setops_test.py new file mode 100644 index 000000000000..bdb1634fcef4 --- /dev/null +++ b/tests/lax_numpy_setops_test.py @@ -0,0 +1,406 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial, wraps +import itertools +import math + +from absl.testing import absltest + +import numpy as np + +import jax +from jax import lax +import jax.numpy as jnp +from jax._src import config +from jax._src import test_util as jtu + +config.parse_flags_with_absl() + + +nonempty_array_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)] +empty_array_shapes = [(0,), (0, 4), (3, 0),] + +scalar_shapes = [jtu.NUMPY_SCALAR_SHAPE, jtu.PYTHON_SCALAR_SHAPE] +array_shapes = nonempty_array_shapes + empty_array_shapes +nonempty_shapes = scalar_shapes + nonempty_array_shapes +all_shapes = scalar_shapes + array_shapes + +default_dtypes = jtu.dtypes.all_floating + jtu.dtypes.all_integer +inexact_dtypes = jtu.dtypes.all_floating + jtu.dtypes.complex +number_dtypes = default_dtypes + jtu.dtypes.complex + jtu.dtypes.all_unsigned + + +def np_unique_backport(ar, return_index=False, return_inverse=False, return_counts=False, + axis=None, **kwds): + # Wrapper for np.unique, handling the change to inverse_indices in numpy 2.0 + result = np.unique(ar, return_index=return_index, return_inverse=return_inverse, + return_counts=return_counts, axis=axis, **kwds) + if jtu.numpy_version() >= (2, 0, 1) or np.ndim(ar) == 1 or not return_inverse: + return result + + idx = 2 if return_index else 1 + inverse_indices = result[idx] + if axis is None: + inverse_indices = inverse_indices.reshape(np.shape(ar)) + elif jtu.numpy_version() == (2, 0, 0): + inverse_indices = inverse_indices.reshape(-1) + return (*result[:idx], inverse_indices, *result[idx + 1:]) + +def arrays_with_overlapping_values(rng, shapes, dtypes, unique=False, overlap=0.5) -> list[jax.Array]: + """Generate multiple arrays with some overlapping values. + + This is useful for tests of set-like operations. + """ + assert 0 <= overlap <= 1 + sizes = [math.prod(jtu._dims_of_shape(shape)) for shape in shapes] + total_size = int(sum(sizes) * (1 - overlap)) + max(sizes) # non-strict upper-bound. + if unique: + vals = jtu.rand_unique_int(rng)((total_size,), 'int32') + else: + vals = jtu.rand_default(rng)((total_size,), 'int32') + offsets = [int(sum(sizes[:i]) * (1 - overlap)) for i in range(len(sizes))] + return [rng.permutation(vals[offset: offset + size]).reshape(shape).astype(dtype) + for (offset, size, shape, dtype) in zip(offsets, sizes, shapes, dtypes)] + +def with_size_argument(fun): + @wraps(fun) + def wrapped(*args, size=None, fill_value=None, **kwargs): + result = fun(*args, **kwargs) + if size is None or size == len(result): + return result + elif size < len(result): + return result[:size] + else: + if fill_value is None: + fill_value = result.min() if result.size else 0 + return np.pad(result, (0, size - len(result)), constant_values=fill_value) + return wrapped + + +class LaxNumpySetopsTest(jtu.JaxTestCase): + """Tests of set-like operations from jax.numpy.""" + + @jtu.sample_product( + element_shape=all_shapes, + test_shape=all_shapes, + dtype=default_dtypes, + invert=[False, True], + method=['auto', 'compare_all', 'binary_search', 'sort'] + ) + def testIsin(self, element_shape, test_shape, dtype, invert, method): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(element_shape, dtype), rng(test_shape, dtype)] + jnp_fun = lambda e, t: jnp.isin(e, t, invert=invert, method=method) + np_fun = lambda e, t: np.isin(e, t, invert=invert) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + dtype1=[s for s in default_dtypes if s != jnp.bfloat16], + dtype2=[s for s in default_dtypes if s != jnp.bfloat16], + shape1=all_shapes, + shape2=all_shapes, + overlap=[0.1, 0.5, 0.9], + ) + def testSetdiff1d(self, shape1, shape2, dtype1, dtype2, overlap): + args_maker = partial(arrays_with_overlapping_values, self.rng(), + shapes=[shape1, shape2], dtypes=[dtype1, dtype2], + overlap=overlap) + with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): + self._CheckAgainstNumpy(np.setdiff1d, jnp.setdiff1d, args_maker) + + @jtu.sample_product( + dtype1=[s for s in default_dtypes if s != jnp.bfloat16], + dtype2=[s for s in default_dtypes if s != jnp.bfloat16], + shape1=all_shapes, + shape2=all_shapes, + size=[1, 5, 10], + fill_value=[None, -1], + overlap=[0.1, 0.5, 0.9], + ) + def testSetdiff1dSize(self, shape1, shape2, dtype1, dtype2, size, fill_value, overlap): + args_maker = partial(arrays_with_overlapping_values, self.rng(), + shapes=[shape1, shape2], dtypes=[dtype1, dtype2], + overlap=overlap) + def np_fun(arg1, arg2): + result = np.setdiff1d(arg1, arg2) + if size <= len(result): + return result[:size] + else: + return np.pad(result, (0, size-len(result)), constant_values=fill_value or 0) + def jnp_fun(arg1, arg2): + return jnp.setdiff1d(arg1, arg2, size=size, fill_value=fill_value) + with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + shape1=all_shapes, + shape2=all_shapes, + ) + def testSetdiff1dAssumeUnique(self, shape1, shape2): + # regression test for https://github.com/jax-ml/jax/issues/32335 + args_maker = lambda: (jnp.arange(math.prod(shape1), dtype='int32').reshape(shape1), + jnp.arange(math.prod(shape2), dtype='int32').reshape(shape2)) + np_op = partial(np.setdiff1d, assume_unique=True) + jnp_op = partial(jnp.setdiff1d, assume_unique=True) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + + @jtu.sample_product( + dtype1=[s for s in default_dtypes if s != jnp.bfloat16], + dtype2=[s for s in default_dtypes if s != jnp.bfloat16], + shape1=all_shapes, + shape2=all_shapes, + overlap=[0.1, 0.5, 0.9], + ) + def testUnion1d(self, shape1, shape2, dtype1, dtype2, overlap): + args_maker = partial(arrays_with_overlapping_values, self.rng(), + shapes=[shape1, shape2], dtypes=[dtype1, dtype2], + overlap=overlap) + def np_fun(arg1, arg2): + dtype = jnp.promote_types(arg1.dtype, arg2.dtype) + return np.union1d(arg1, arg2).astype(dtype) + with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): + self._CheckAgainstNumpy(np_fun, jnp.union1d, args_maker) + + @jtu.sample_product( + dtype1=[s for s in default_dtypes if s != jnp.bfloat16], + dtype2=[s for s in default_dtypes if s != jnp.bfloat16], + shape1=nonempty_shapes, + shape2=nonempty_shapes, + size=[1, 5, 10], + fill_value=[None, -1], + overlap=[0.1, 0.5, 0.9], + ) + def testUnion1dSize(self, shape1, shape2, dtype1, dtype2, size, fill_value, overlap): + args_maker = partial(arrays_with_overlapping_values, self.rng(), + shapes=[shape1, shape2], dtypes=[dtype1, dtype2], + overlap=overlap) + def np_fun(arg1, arg2): + dtype = jnp.promote_types(arg1.dtype, arg2.dtype) + result = np.union1d(arg1, arg2).astype(dtype) + fv = result.min() if fill_value is None else fill_value + if size <= len(result): + return result[:size] + else: + return np.concatenate([result, np.full(size - len(result), fv, result.dtype)]) + def jnp_fun(arg1, arg2): + return jnp.union1d(arg1, arg2, size=size, fill_value=fill_value) + with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + dtype1=[s for s in default_dtypes if s != jnp.bfloat16], + dtype2=[s for s in default_dtypes if s != jnp.bfloat16], + shape1=all_shapes, + shape2=all_shapes, + assume_unique=[False, True], + size=[None, 2, 5], + fill_value=[None, 99], + overlap=[0.1, 0.5, 0.9], + ) + def testSetxor1d(self, shape1, dtype1, shape2, dtype2, assume_unique, size, fill_value, overlap): + args_maker = partial(arrays_with_overlapping_values, self.rng(), + shapes=[shape1, shape2], dtypes=[dtype1, dtype2], + overlap=overlap) + jnp_fun = lambda ar1, ar2: jnp.setxor1d(ar1, ar2, assume_unique=assume_unique, + size=size, fill_value=fill_value) + def np_fun(ar1, ar2): + if assume_unique: + # numpy requires 1D inputs when assume_unique is True. + ar1 = np.ravel(ar1) + ar2 = np.ravel(ar2) + return with_size_argument(np.setxor1d)(ar1, ar2, assume_unique, size=size, fill_value=fill_value) + with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) + + @jtu.sample_product( + dtype1=[s for s in default_dtypes if s != jnp.bfloat16], + dtype2=[s for s in default_dtypes if s != jnp.bfloat16], + shape1=nonempty_shapes, + shape2=nonempty_shapes, + assume_unique=[False, True], + return_indices=[False, True], + size=[None, 3, 5], + fill_value=[None, -1], + overlap=[0.1, 0.5, 0.9], + ) + def testIntersect1d(self, shape1, dtype1, shape2, dtype2, assume_unique, + return_indices, size, fill_value, overlap): + args_maker = partial(arrays_with_overlapping_values, self.rng(), + shapes=[shape1, shape2], dtypes=[dtype1, dtype2], + unique=assume_unique, overlap=overlap) + + def jnp_fun(ar1, ar2): + return jnp.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices, + size=size, fill_value=fill_value) + + def np_fun(ar1, ar2): + result = np.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices) + def correct_size(x, fill_value): + if size is None or size == len(x): + return x + elif size < len(x): + return x[:size] + else: + if fill_value is None: + fill_value = x.min() + return np.pad(x, (0, size - len(x)), constant_values=fill_value) + if return_indices: + return tuple(correct_size(r, f) for r, f in zip(result, [fill_value, ar1.size, ar2.size])) + else: + return correct_size(result, fill_value) + + with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) + + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in all_shapes + for axis in [None] + list(range(len(shape)))], + dtype=number_dtypes, + return_index=[False, True], + return_inverse=[False, True], + return_counts=[False, True], + ) + def testUnique(self, shape, dtype, axis, return_index, return_inverse, return_counts): + rng = jtu.rand_some_equal(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + extra_args = (return_index, return_inverse, return_counts) + use_defaults = (False, *(True for arg in extra_args if arg)) if any(extra_args) else False + np_fun = jtu.with_jax_dtype_defaults(lambda x: np_unique_backport(x, *extra_args, axis=axis), use_defaults) + jnp_fun = lambda x: jnp.unique(x, *extra_args, axis=axis) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + + @jtu.sample_product(shape=all_shapes, dtype=number_dtypes) + def testUniqueAll(self, shape, dtype): + rng = jtu.rand_some_equal(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(jnp.unique_all, np.unique_all, args_maker) + + @jtu.sample_product(shape=all_shapes, dtype=number_dtypes) + def testUniqueCounts(self, shape, dtype): + rng = jtu.rand_some_equal(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(jnp.unique_counts, np.unique_counts, args_maker) + + @jtu.sample_product(shape=all_shapes, dtype=number_dtypes) + def testUniqueInverse(self, shape, dtype): + rng = jtu.rand_some_equal(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(jnp.unique_inverse, np.unique_inverse, args_maker) + + @jtu.sample_product(shape=all_shapes, dtype=number_dtypes) + def testUniqueValues(self, shape, dtype): + rng = jtu.rand_some_equal(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + np_fun = lambda *args: np.sort(np.unique_values(*args)) + self._CheckAgainstNumpy(jnp.unique_values, np_fun, args_maker) + + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in nonempty_array_shapes + for axis in [None] + list(range(len(shape)))], + dtype=number_dtypes, + size=[1, 5, 10], + fill_value=[None, 0, "slice"], + ) + def testUniqueSize(self, shape, dtype, axis, size, fill_value): + rng = jtu.rand_some_equal(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + kwds = dict(axis=axis, return_index=True, return_inverse=True, return_counts=True) + + if fill_value == "slice": + if axis is None: + fill_value = rng((), dtype) + else: + fill_value = rng(shape[:axis] + shape[axis + 1:], dtype) + elif fill_value is not None: + fill_value = np.array(fill_value).astype(dtype) + + @partial(jtu.with_jax_dtype_defaults, use_defaults=(False, True, True, True)) + def np_fun(x, fill_value=fill_value): + u, ind, inv, counts = np_unique_backport(x, **kwds) + axis = kwds['axis'] + if axis is None: + x = x.ravel() + axis = 0 + + n_unique = u.shape[axis] + if size <= u.shape[axis]: + slc = (slice(None),) * axis + (slice(size),) + u, ind, counts = u[slc], ind[:size], counts[:size] + else: + extra = (0, size - n_unique) + pads = [(0, 0)] * u.ndim + pads[axis] = extra + u = np.pad(u, pads, constant_values=0) + slices = [slice(None)] * u.ndim + slices[axis] = slice(1) + if fill_value is None: + fill_value = u[tuple(slices)] + elif np.ndim(fill_value): + fill_value = lax.expand_dims(fill_value, (axis,)) + slices[axis] = slice(n_unique, None) + u[tuple(slices)] = fill_value + ind = np.pad(ind, extra, constant_values=ind[0]) + counts = np.pad(counts, extra, constant_values=0) + return u, ind, inv, counts + + jnp_fun = lambda x: jnp.unique(x, size=size, fill_value=fill_value, **kwds) + + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product(dtype=inexact_dtypes) + def testUniqueNans(self, dtype): + def args_maker(): + x = [-0.0, 0.0, 1.0, 1.0, np.nan, -np.nan] + if np.issubdtype(dtype, np.complexfloating): + x = [complex(i, j) for i, j in itertools.product(x, repeat=2)] + return [np.array(x, dtype=dtype)] + + kwds = dict(return_index=True, return_inverse=True, return_counts=True) + jnp_fun = partial(jnp.unique, **kwds) + def np_fun(x): + dtype = x.dtype + # numpy unique fails for bfloat16 NaNs, so we cast to float64 + if x.dtype == jnp.bfloat16: + x = x.astype('float64') + u, *rest = np.unique(x, **kwds) + return (u.astype(dtype), *rest) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + + @jtu.sample_product(dtype=inexact_dtypes, equal_nan=[True, False]) + @jtu.ignore_warning( + category=RuntimeWarning, message='invalid value encountered in cast' + ) + def testUniqueEqualNan(self, dtype, equal_nan): + shape = (20,) + rng = jtu.rand_some_nan(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + def np_fun(x): + dtype = x.dtype + # numpy unique fails for bfloat16 NaNs, so we cast to float64 + if x.dtype == jnp.bfloat16: + x = x.astype('float64') + return np.unique(x, equal_nan=equal_nan).astype(dtype) + jnp_fun = partial(jnp.unique, equal_nan=equal_nan) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 98f10d9c02b3..2d694c8c83b6 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -18,7 +18,7 @@ import collections from collections.abc import Iterator import copy -from functools import partial, wraps +from functools import partial import inspect import io import itertools @@ -38,7 +38,6 @@ numpy_dispatch = None import jax -import jax.ops from jax import lax from jax import numpy as jnp from jax.sharding import SingleDeviceSharding @@ -47,10 +46,11 @@ from jax._src import array from jax._src import config from jax._src import core +from jax._src import deprecations from jax._src import dtypes from jax._src import test_util as jtu from jax._src.lax import lax as lax_internal -from jax._src.util import safe_zip, NumpyComplexWarning, tuple_replace +from jax._src.util import safe_zip, tuple_update config.parse_flags_with_absl() @@ -99,11 +99,13 @@ def _bitcast_uint8_to_uint4(operand): result[..., 1::2] = ((operand & 0b11110000) >> 4).astype('uint4') return result -def np_view(arr, dtype): +def np_view(arr: np.ndarray, dtype) -> np.ndarray: # Implementation of np.ndarray.view() that works for int4/uint4 + if dtype is None: + return arr dtype = np.dtype(dtype) - nbits_in = dtypes.bit_width(arr.dtype) - nbits_out = dtypes.bit_width(dtype) + nbits_in = dtypes.itemsize_bits(arr.dtype) + nbits_out = dtypes.itemsize_bits(dtype) if nbits_in == 4: arr = _bitcast_uint4_to_uint8(arr.view('uint4')) if nbits_out == 4: @@ -111,23 +113,6 @@ def np_view(arr, dtype): return arr.view(dtype) -def np_unique_backport(ar, return_index=False, return_inverse=False, return_counts=False, - axis=None, **kwds): - # Wrapper for np.unique, handling the change to inverse_indices in numpy 2.0 - result = np.unique(ar, return_index=return_index, return_inverse=return_inverse, - return_counts=return_counts, axis=axis, **kwds) - if jtu.numpy_version() >= (2, 0, 1) or np.ndim(ar) == 1 or not return_inverse: - return result - - idx = 2 if return_index else 1 - inverse_indices = result[idx] - if axis is None: - inverse_indices = inverse_indices.reshape(np.shape(ar)) - elif jtu.numpy_version() == (2, 0, 0): - inverse_indices = inverse_indices.reshape(-1) - return (*result[:idx], inverse_indices, *result[idx + 1:]) - - def _indexer_with_default_outputs(indexer, use_defaults=True): """Like jtu.with_jax_dtype_defaults, but for __getitem__ APIs""" class Indexer: @@ -185,38 +170,6 @@ def _shapes_are_equal_length(shapes): return all(len(shape) == len(shapes[0]) for shape in shapes[1:]) -def arrays_with_overlapping_values(rng, shapes, dtypes, unique=False, overlap=0.5) -> list[jax.Array]: - """Generate multiple arrays with some overlapping values. - - This is useful for tests of set-like operations. - """ - assert 0 <= overlap <= 1 - sizes = [math.prod(jtu._dims_of_shape(shape)) for shape in shapes] - total_size = int(sum(sizes) * (1 - overlap)) + max(sizes) # non-strict upper-bound. - if unique: - vals = jtu.rand_unique_int(rng)((total_size,), 'int32') - else: - vals = jtu.rand_default(rng)((total_size,), 'int32') - offsets = [int(sum(sizes[:i]) * (1 - overlap)) for i in range(len(sizes))] - return [rng.permutation(vals[offset: offset + size]).reshape(shape).astype(dtype) - for (offset, size, shape, dtype) in zip(offsets, sizes, shapes, dtypes)] - - -def with_size_argument(fun): - @wraps(fun) - def wrapped(*args, size=None, fill_value=None, **kwargs): - result = fun(*args, **kwargs) - if size is None or size == len(result): - return result - elif size < len(result): - return result[:size] - else: - if fill_value is None: - fill_value = result.min() if result.size else 0 - return np.pad(result, (0, size - len(result)), constant_values=fill_value) - return wrapped - - class LaxBackedNumpyTests(jtu.JaxTestCase): """Tests for LAX-backed Numpy implementation.""" @@ -337,7 +290,7 @@ def f(x): for shape in all_shapes for axis in list(range(-len(shape), len(shape)))], discont=[None, "pi", 2], - period=["2pi", "pi"], + period=["2pi", "pi", 2, 4], dtype=default_dtypes, ) def testUnwrap(self, shape, dtype, axis, discont, period): @@ -360,7 +313,7 @@ def np_fun(x): # This case requires implicit dtype promotion jnp_fun = jax.numpy_dtype_promotion('standard')(jnp_fun) args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True, atol={dtypes.bfloat16: 1e-1, np.float16: 1e-2}) self._CompileAndCheck(jnp_fun, args_maker, atol={dtypes.bfloat16: 1e-1}) @@ -665,8 +618,7 @@ def testVecdot(self, lhs_batch, rhs_batch, axis_size, axis, dtype): args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] @jtu.promote_like_jnp def np_fn(x, y, axis=axis): - f = jtu.numpy_vecdot if jtu.numpy_version() < (2, 0, 0) else np.vecdot - return f(x, y, axis=axis).astype(x.dtype) + return np.vecdot(x, y, axis=axis).astype(x.dtype) jnp_fn = partial(jnp.vecdot, axis=axis) tol = {np.float16: 1e-2, np.float32: 1E-3, np.float64: 1e-12, np.complex64: 1E-3, np.complex128: 1e-12, jnp.bfloat16: 1e-1} @@ -773,169 +725,6 @@ def testTensordotErrors(self): TypeError, "tensordot axes argument must be an int, a pair of ints, or a pair of lists.*", lambda: jnp.tensordot(a, b, axes='badaxes')) - @jtu.sample_product( - element_shape=all_shapes, - test_shape=all_shapes, - dtype=default_dtypes, - invert=[False, True], - method=['auto', 'compare_all', 'binary_search', 'sort'] - ) - def testIsin(self, element_shape, test_shape, dtype, invert, method): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(element_shape, dtype), rng(test_shape, dtype)] - jnp_fun = lambda e, t: jnp.isin(e, t, invert=invert, method=method) - np_fun = lambda e, t: np.isin(e, t, invert=invert) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - dtype1=[s for s in default_dtypes if s != jnp.bfloat16], - dtype2=[s for s in default_dtypes if s != jnp.bfloat16], - shape1=all_shapes, - shape2=all_shapes, - overlap=[0.1, 0.5, 0.9], - ) - def testSetdiff1d(self, shape1, shape2, dtype1, dtype2, overlap): - args_maker = partial(arrays_with_overlapping_values, self.rng(), - shapes=[shape1, shape2], dtypes=[dtype1, dtype2], - overlap=overlap) - with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): - self._CheckAgainstNumpy(np.setdiff1d, jnp.setdiff1d, args_maker) - - @jtu.sample_product( - dtype1=[s for s in default_dtypes if s != jnp.bfloat16], - dtype2=[s for s in default_dtypes if s != jnp.bfloat16], - shape1=all_shapes, - shape2=all_shapes, - size=[1, 5, 10], - fill_value=[None, -1], - overlap=[0.1, 0.5, 0.9], - ) - def testSetdiff1dSize(self, shape1, shape2, dtype1, dtype2, size, fill_value, overlap): - args_maker = partial(arrays_with_overlapping_values, self.rng(), - shapes=[shape1, shape2], dtypes=[dtype1, dtype2], - overlap=overlap) - def np_fun(arg1, arg2): - result = np.setdiff1d(arg1, arg2) - if size <= len(result): - return result[:size] - else: - return np.pad(result, (0, size-len(result)), constant_values=fill_value or 0) - def jnp_fun(arg1, arg2): - return jnp.setdiff1d(arg1, arg2, size=size, fill_value=fill_value) - with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - dtype1=[s for s in default_dtypes if s != jnp.bfloat16], - dtype2=[s for s in default_dtypes if s != jnp.bfloat16], - shape1=all_shapes, - shape2=all_shapes, - overlap=[0.1, 0.5, 0.9], - ) - def testUnion1d(self, shape1, shape2, dtype1, dtype2, overlap): - args_maker = partial(arrays_with_overlapping_values, self.rng(), - shapes=[shape1, shape2], dtypes=[dtype1, dtype2], - overlap=overlap) - def np_fun(arg1, arg2): - dtype = jnp.promote_types(arg1.dtype, arg2.dtype) - return np.union1d(arg1, arg2).astype(dtype) - with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): - self._CheckAgainstNumpy(np_fun, jnp.union1d, args_maker) - - @jtu.sample_product( - dtype1=[s for s in default_dtypes if s != jnp.bfloat16], - dtype2=[s for s in default_dtypes if s != jnp.bfloat16], - shape1=nonempty_shapes, - shape2=nonempty_shapes, - size=[1, 5, 10], - fill_value=[None, -1], - overlap=[0.1, 0.5, 0.9], - ) - def testUnion1dSize(self, shape1, shape2, dtype1, dtype2, size, fill_value, overlap): - args_maker = partial(arrays_with_overlapping_values, self.rng(), - shapes=[shape1, shape2], dtypes=[dtype1, dtype2], - overlap=overlap) - def np_fun(arg1, arg2): - dtype = jnp.promote_types(arg1.dtype, arg2.dtype) - result = np.union1d(arg1, arg2).astype(dtype) - fv = result.min() if fill_value is None else fill_value - if size <= len(result): - return result[:size] - else: - return np.concatenate([result, np.full(size - len(result), fv, result.dtype)]) - def jnp_fun(arg1, arg2): - return jnp.union1d(arg1, arg2, size=size, fill_value=fill_value) - with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - dtype1=[s for s in default_dtypes if s != jnp.bfloat16], - dtype2=[s for s in default_dtypes if s != jnp.bfloat16], - shape1=all_shapes, - shape2=all_shapes, - assume_unique=[False, True], - size=[None, 2, 5], - fill_value=[None, 99], - overlap=[0.1, 0.5, 0.9], - ) - def testSetxor1d(self, shape1, dtype1, shape2, dtype2, assume_unique, size, fill_value, overlap): - args_maker = partial(arrays_with_overlapping_values, self.rng(), - shapes=[shape1, shape2], dtypes=[dtype1, dtype2], - overlap=overlap) - jnp_fun = lambda ar1, ar2: jnp.setxor1d(ar1, ar2, assume_unique=assume_unique, - size=size, fill_value=fill_value) - def np_fun(ar1, ar2): - if assume_unique: - # numpy requires 1D inputs when assume_unique is True. - ar1 = np.ravel(ar1) - ar2 = np.ravel(ar2) - return with_size_argument(np.setxor1d)(ar1, ar2, assume_unique, size=size, fill_value=fill_value) - with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) - - @jtu.sample_product( - dtype1=[s for s in default_dtypes if s != jnp.bfloat16], - dtype2=[s for s in default_dtypes if s != jnp.bfloat16], - shape1=nonempty_shapes, - shape2=nonempty_shapes, - assume_unique=[False, True], - return_indices=[False, True], - size=[None, 3, 5], - fill_value=[None, -1], - overlap=[0.1, 0.5, 0.9], - ) - def testIntersect1d(self, shape1, dtype1, shape2, dtype2, assume_unique, - return_indices, size, fill_value, overlap): - args_maker = partial(arrays_with_overlapping_values, self.rng(), - shapes=[shape1, shape2], dtypes=[dtype1, dtype2], - unique=assume_unique, overlap=overlap) - - def jnp_fun(ar1, ar2): - return jnp.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices, - size=size, fill_value=fill_value) - - def np_fun(ar1, ar2): - result = np.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices) - def correct_size(x, fill_value): - if size is None or size == len(x): - return x - elif size < len(x): - return x[:size] - else: - if fill_value is None: - fill_value = x.min() - return np.pad(x, (0, size - len(x)), constant_values=fill_value) - if return_indices: - return tuple(correct_size(r, f) for r, f in zip(result, [fill_value, ar1.size, ar2.size])) - else: - return correct_size(result, fill_value) - - with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) - @jtu.sample_product( [dict(lhs_shape=lhs_shape, lhs_dtype=lhs_dtype, rhs_shape=rhs_shape, rhs_dtype=rhs_dtype) @@ -968,6 +757,7 @@ def np_fun(lhs, rhs): @jtu.sample_product( dtype=[dt for dt in float_dtypes if dt not in [jnp.float16, jnp.bfloat16]], shape=[shape for shape in one_dim_array_shapes if shape != (1,)], + num_rhs=[1, 5], deg=[1, 2, 3], rcond=[None, -1, 10e-3, 10e-5, 10e-10], full=[False, True], @@ -975,12 +765,13 @@ def np_fun(lhs, rhs): cov=[False, True, "unscaled"], ) @jax.default_matmul_precision("float32") - def testPolyfit(self, shape, dtype, deg, rcond, full, w, cov): + def testPolyfit(self, shape, num_rhs, dtype, deg, rcond, full, w, cov): rng = jtu.rand_default(self.rng()) tol_spec = {np.float32: 1e-3, np.float64: 1e-13, np.complex64: 1e-5} tol = jtu.tolerance(dtype, tol_spec) _w = lambda a: abs(a) if w else None - args_maker = lambda: [rng(shape, dtype), rng(shape, dtype), rng(shape, dtype)] + rhs_shape = shape + (num_rhs,) if num_rhs > 1 else shape + args_maker = lambda: [rng(shape, dtype), rng(rhs_shape, dtype), rng(shape, dtype)] jnp_fun = lambda x, y, a: jnp.polyfit(x, y, deg=deg, rcond=rcond, full=full, w=_w(a), cov=cov) np_fun = jtu.ignore_warning( message="Polyfit may be poorly conditioned*")(lambda x, y, a: np.polyfit(x, y, deg=deg, rcond=rcond, full=full, w=_w(a), cov=cov)) @@ -1063,6 +854,14 @@ def testClipDeprecatedArgs(self): "Passing arguments 'a', 'a_min' or 'a_max' to jax.numpy.clip is deprecated"): jnp.clip(jnp.arange(4), a_min=2, a_max=3) + def testClipUpperPrecedence(self): + a_min = 3 * np.ones(1) + a_max = 2 * np.ones(1) + x = 4 * np.ones(1) + y = jnp.clip(x, min=a_min, max=a_max) + assert y == a_max, f"Expected {y} to equal {a_max} when a_min > a_max." + assert y == jnp.asarray(np.clip(x, a_min=a_min, a_max=a_max)) + def testHypotComplexInputError(self): rng = jtu.rand_default(self.rng()) x = rng((5,), dtype=jnp.complex64) @@ -1095,6 +894,18 @@ def testRoundStaticDecimals(self, shape, dtype, decimals): self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=check_dtypes, atol=tol, rtol=tol) + @jtu.sample_product( + dtype=number_dtypes, + decimals=[1, 10, 100, 1000], + ) + def testRoundLargeDecimals(self, dtype, decimals): + # Regression test for https://github.com/jax-ml/jax/issues/31689. + # Avoid testing against NumPy here because it returns NaN for large decimals. + rng = jtu.rand_default(self.rng()) + x = rng((10,), dtype) + result = jnp.round(x, decimals) + self.assertArraysAllClose(x, result, atol=2 * 10. ** -decimals) + @jtu.sample_product(jit=[False, True]) def testOperatorRound(self, jit): jround = jax.jit(round, static_argnums=1) if jit else round @@ -1528,32 +1339,37 @@ def testMatrixTranspose(self, shape, dtype, use_property): jnp_fun = lambda x: jnp.asarray(x).mT else: jnp_fun = jnp.matrix_transpose - if hasattr(np, 'matrix_transpose'): - np_fun = np.matrix_transpose - else: - np_fun = lambda x: np.swapaxes(x, -1, -2) rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CheckAgainstNumpy(np.matrix_transpose, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( dtype=default_dtypes, - a_shape=one_dim_array_shapes, + shape=one_dim_array_shapes if jtu.numpy_version() < (2, 2, 0) else all_shapes, trim=["f", "b", "fb"], ) - def testTrimZeros(self, a_shape, dtype, trim): + def testTrimZeros(self, shape, dtype, trim): rng = jtu.rand_some_zero(self.rng()) - args_maker = lambda: [rng(a_shape, dtype)] - np_fun = lambda arg1: np.trim_zeros(arg1, trim) + args_maker = lambda: [rng(shape, dtype)] + np_fun = lambda arg1: np.trim_zeros(np.asarray(arg1), trim) jnp_fun = lambda arg1: jnp.trim_zeros(arg1, trim) self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - def testTrimZerosNotOneDArray(self): - # TODO: make this an error after the deprecation period. - with self.assertWarnsRegex(DeprecationWarning, - r"Passing arrays with ndim != 1 to jnp.trim_zeros\(\)"): - jnp.trim_zeros(jnp.array([[0.0, 1.0, 0.0],[2.0, 4.5, 0.0]])) + @jtu.sample_product( + dtype=default_dtypes, + shape=[(2, 3), (3, 4)], + trim=["f", "b", "fb"], + axis=[None, 0, -1] # note: contrary to its docs, NumPy errors for multiple axes. + ) + @unittest.skipIf(jtu.numpy_version() < (2, 2, 0), "n-dimensional trim_zeros requires NumPy 2.2") + def testTrimZerosAxis(self, shape, dtype, trim, axis): + print(shape, trim, axis) + rng = jtu.rand_some_zero(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + np_fun = lambda arg1: np.trim_zeros(arg1, trim, axis=axis) + jnp_fun = lambda arg1: jnp.trim_zeros(arg1, trim, axis=axis) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) @jtu.sample_product( rank=(1, 2), @@ -1644,7 +1460,10 @@ def testIntegerPower(self, ptype): def testIntegerPowerOverflow(self, x, y): # Regression test for https://github.com/jax-ml/jax/issues/5987 args_maker = lambda: [x, y] - self._CheckAgainstNumpy(np.power, jnp.power, args_maker) + check_dtypes = platform.system() != 'Windows' + self._CheckAgainstNumpy( + np.power, jnp.power, args_maker, check_dtypes=check_dtypes + ) self._CompileAndCheck(jnp.power, args_maker) @jtu.sample_product( @@ -1808,10 +1627,7 @@ def testConcat(self, axis, base_shape, dtype): for size in [3, 1, 4]] @jtu.promote_like_jnp def np_fun(*args): - if jtu.numpy_version() >= (2, 0, 0): - return np.concat(args, axis=axis) - else: - return np.concatenate(args, axis=axis) + return np.concat(args, axis=axis) jnp_fun = lambda *args: jnp.concat(args, axis=axis) args_maker = lambda: [rng(shape, dtype) for shape in shapes] self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) @@ -1928,9 +1744,6 @@ def testDeleteMaskArray(self, shape, dtype, axis): rng = jtu.rand_default(self.rng()) mask_size = np.zeros(shape).size if axis is None else np.zeros(shape).shape[axis] mask = jtu.rand_int(self.rng(), low=0, high=2)(mask_size, bool) - if numpy_version == (1, 23, 0) and mask.shape == (1,): - # https://github.com/numpy/numpy/issues/21840 - self.skipTest("test fails for numpy v1.23.0") args_maker = lambda: [rng(shape, dtype)] np_fun = lambda arg: np.delete(arg, mask, axis=axis) jnp_fun = lambda arg: jnp.delete(arg, mask, axis=axis) @@ -2072,155 +1885,6 @@ def testRepeatScalarFastPath(self): jaxpr = jax.make_jaxpr(f)(a) self.assertLessEqual(len(jaxpr.jaxpr.eqns), 6) - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in all_shapes - for axis in [None] + list(range(len(shape)))], - dtype=number_dtypes, - return_index=[False, True], - return_inverse=[False, True], - return_counts=[False, True], - ) - def testUnique(self, shape, dtype, axis, return_index, return_inverse, return_counts): - rng = jtu.rand_some_equal(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - extra_args = (return_index, return_inverse, return_counts) - use_defaults = (False, *(True for arg in extra_args if arg)) if any(extra_args) else False - np_fun = jtu.with_jax_dtype_defaults(lambda x: np_unique_backport(x, *extra_args, axis=axis), use_defaults) - jnp_fun = lambda x: jnp.unique(x, *extra_args, axis=axis) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - - @jtu.sample_product(shape=all_shapes, dtype=number_dtypes) - def testUniqueAll(self, shape, dtype): - rng = jtu.rand_some_equal(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - if jtu.numpy_version() < (2, 0, 0): - np_fun = partial(np_unique_backport, return_index=True, return_inverse=True, return_counts=True) - else: - np_fun = np.unique_all - self._CheckAgainstNumpy(jnp.unique_all, np_fun, args_maker) - - @jtu.sample_product(shape=all_shapes, dtype=number_dtypes) - def testUniqueCounts(self, shape, dtype): - rng = jtu.rand_some_equal(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - if jtu.numpy_version() < (2, 0, 0): - np_fun = lambda x: np.unique(x, return_counts=True) - else: - np_fun = np.unique_counts - self._CheckAgainstNumpy(jnp.unique_counts, np_fun, args_maker) - - @jtu.sample_product(shape=all_shapes, dtype=number_dtypes) - def testUniqueInverse(self, shape, dtype): - rng = jtu.rand_some_equal(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - if jtu.numpy_version() < (2, 0, 0): - np_fun = partial(np_unique_backport, return_inverse=True) - else: - np_fun = np.unique_inverse - self._CheckAgainstNumpy(jnp.unique_inverse, np_fun, args_maker) - - @jtu.sample_product(shape=all_shapes, dtype=number_dtypes) - def testUniqueValues(self, shape, dtype): - rng = jtu.rand_some_equal(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - if jtu.numpy_version() < (2, 0, 0): - np_fun = np.unique - else: - np_fun = lambda *args: np.sort(np.unique_values(*args)) - self._CheckAgainstNumpy(jnp.unique_values, np_fun, args_maker) - - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in nonempty_array_shapes - for axis in [None] + list(range(len(shape)))], - dtype=number_dtypes, - size=[1, 5, 10], - fill_value=[None, 0, "slice"], - ) - def testUniqueSize(self, shape, dtype, axis, size, fill_value): - rng = jtu.rand_some_equal(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - kwds = dict(axis=axis, return_index=True, return_inverse=True, return_counts=True) - - if fill_value == "slice": - if axis is None: - fill_value = rng((), dtype) - else: - fill_value = rng(shape[:axis] + shape[axis + 1:], dtype) - elif fill_value is not None: - fill_value = np.array(fill_value).astype(dtype) - - @partial(jtu.with_jax_dtype_defaults, use_defaults=(False, True, True, True)) - def np_fun(x, fill_value=fill_value): - u, ind, inv, counts = np_unique_backport(x, **kwds) - axis = kwds['axis'] - if axis is None: - x = x.ravel() - axis = 0 - - n_unique = u.shape[axis] - if size <= u.shape[axis]: - slc = (slice(None),) * axis + (slice(size),) - u, ind, counts = u[slc], ind[:size], counts[:size] - else: - extra = (0, size - n_unique) - pads = [(0, 0)] * u.ndim - pads[axis] = extra - u = np.pad(u, pads, constant_values=0) - slices = [slice(None)] * u.ndim - slices[axis] = slice(1) - if fill_value is None: - fill_value = u[tuple(slices)] - elif np.ndim(fill_value): - fill_value = lax.expand_dims(fill_value, (axis,)) - slices[axis] = slice(n_unique, None) - u[tuple(slices)] = fill_value - ind = np.pad(ind, extra, constant_values=ind[0]) - counts = np.pad(counts, extra, constant_values=0) - return u, ind, inv, counts - - jnp_fun = lambda x: jnp.unique(x, size=size, fill_value=fill_value, **kwds) - - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product(dtype=inexact_dtypes) - def testUniqueNans(self, dtype): - def args_maker(): - x = [-0.0, 0.0, 1.0, 1.0, np.nan, -np.nan] - if np.issubdtype(dtype, np.complexfloating): - x = [complex(i, j) for i, j in itertools.product(x, repeat=2)] - return [np.array(x, dtype=dtype)] - - kwds = dict(return_index=True, return_inverse=True, return_counts=True) - jnp_fun = partial(jnp.unique, **kwds) - def np_fun(x): - dtype = x.dtype - # numpy unique fails for bfloat16 NaNs, so we cast to float64 - if x.dtype == jnp.bfloat16: - x = x.astype('float64') - u, *rest = np.unique(x, **kwds) - return (u.astype(dtype), *rest) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - - @jtu.sample_product(dtype=inexact_dtypes, equal_nan=[True, False]) - @jtu.ignore_warning( - category=RuntimeWarning, message='invalid value encountered in cast' - ) - def testUniqueEqualNan(self, dtype, equal_nan): - shape = (20,) - rng = jtu.rand_some_nan(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - def np_fun(x): - dtype = x.dtype - # numpy unique fails for bfloat16 NaNs, so we cast to float64 - if x.dtype == jnp.bfloat16: - x = x.astype('float64') - return np.unique(x, equal_nan=equal_nan).astype(dtype) - jnp_fun = partial(jnp.unique, equal_nan=equal_nan) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - @jtu.sample_product(fixed_size=[False, True]) def testNonScalarRepeats(self, fixed_size): ''' @@ -2318,8 +1982,10 @@ def np_fun(x, y): xshape=one_dim_array_shapes, yshape=one_dim_array_shapes, ) - @jtu.skip_on_devices("cuda", "rocm") # backends don't support all dtypes. + @jtu.skip_on_devices("cuda") # backends don't support all dtypes. def testConvolutionsPreferredElementType(self, xshape, yshape, dtype, mode, op): + if jtu.test_device_matches(["rocm"]) and not dtypes.issubdtype(dtype, np.inexact): + self.skipTest(f"preferred_element_type={dtype} unsupported for ROCm GPU convolutions") jnp_op = getattr(jnp, op) np_op = getattr(np, op) rng = jtu.rand_default(self.rng()) @@ -2347,11 +2013,11 @@ def testCumSumProd(self, axis, shape, dtype, out_dtype, op): np_op = getattr(np, op) rng = jtu.rand_default(self.rng()) np_fun = lambda arg: np_op(arg, axis=axis, dtype=out_dtype) - np_fun = jtu.ignore_warning(category=NumpyComplexWarning)(np_fun) + np_fun = jtu.ignore_warning(category=np.exceptions.ComplexWarning)(np_fun) np_fun = jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered.*")(np_fun) jnp_fun = lambda arg: jnp_op(arg, axis=axis, dtype=out_dtype) - jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) + jnp_fun = jtu.ignore_warning(category=np.exceptions.ComplexWarning)(jnp_fun) args_maker = lambda: [rng(shape, dtype)] @@ -2375,11 +2041,11 @@ def testNanCumSumProd(self, axis, shape, dtype, out_dtype, op): np_op = getattr(np, op) rng = jtu.rand_some_nan(self.rng()) np_fun = partial(np_op, axis=axis, dtype=out_dtype) - np_fun = jtu.ignore_warning(category=NumpyComplexWarning)(np_fun) + np_fun = jtu.ignore_warning(category=np.exceptions.ComplexWarning)(np_fun) np_fun = jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered.*")(np_fun) jnp_fun = partial(jnp_op, axis=axis, dtype=out_dtype) - jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) + jnp_fun = jtu.ignore_warning(category=np.exceptions.ComplexWarning)(jnp_fun) args_maker = lambda: [rng(shape, dtype)] @@ -2422,17 +2088,26 @@ def testEyeDynamicK(self, n, m, k, dtype): dtype=default_dtypes, n=[0, 4], m=[None, 0, 1, 3, 4], - k=range(-4, 4), + k=[*range(-4, 4), -2**33, 2**33], ) def testTri(self, m, n, k, dtype): - np_fun = lambda: np.tri(n, M=m, k=k, dtype=dtype) - jnp_fun = lambda: jnp.tri(n, M=m, k=k, dtype=dtype) - args_maker = lambda: [] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) + np_fun = lambda k: np.tri(n, M=m, k=k, dtype=dtype) + jnp_fun = lambda k: jnp.tri(n, M=m, k=k, dtype=dtype) + args_maker = lambda: [k] + if not config.enable_x64.value and ( + k < np.iinfo(np.int32).min or k > np.iinfo(np.int32).max + ): + with self.assertRaises(OverflowError): + jnp_fun(k) + else: + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) def test_tri_bug_22751(self): - with self.assertRaisesRegex(core.ConcretizationTypeError, "jax.numpy.tri"): + with self.assertRaisesRegex( + TypeError, + 'Shapes must be 1D sequences of concrete values of integer type', + ): jax.jit(jnp.tri)(3, M=3, k=0) @jtu.sample_product( @@ -2754,6 +2429,28 @@ def np_fun(x1, x2): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + @parameterized.parameters(*float_dtypes) + def testLdexpOverflow(self, dtype): + # Regression test for https://github.com/jax-ml/jax/issues/28040 + args_maker = lambda: [np.array(0.5, dtype), 1 << (dtypes.finfo(dtype).nexp - 1)] + def np_ldexp(x1, x2): + return np.ldexp(x1, x2).astype(x1.dtype) + self._CheckAgainstNumpy(np_ldexp, jnp.ldexp, args_maker) + self._CompileAndCheck(jnp.ldexp, args_maker) + + @parameterized.parameters(*float_dtypes) + def testLdexpExtremeValues(self, dtype): + # Regression test for https://github.com/jax-ml/jax/issues/28040 + def args_maker(): + info = dtypes.finfo(dtype) + span = int(np.log2(float(info.max)) - np.log2(float(info.tiny))) - 1 + return [np.array([info.tiny, info.max], dtype=dtype), + np.array([span, -span])] + def np_ldexp(x1, x2): + return np.ldexp(x1, x2).astype(x1.dtype) + self._CheckAgainstNumpy(np_ldexp, jnp.ldexp, args_maker) + self._CompileAndCheck(jnp.ldexp, args_maker) + @jtu.sample_product( rng_factory=[ jtu.rand_some_inf_and_nan, @@ -2764,6 +2461,12 @@ def np_fun(x1, x2): dtype=default_dtypes, ) @jtu.ignore_warning(category=RuntimeWarning, message="overflow") + @jtu.ignore_warning(category=RuntimeWarning, + message="invalid value encountered in isinf") + @unittest.skipIf( + platform.system() == "Windows", + "TODO (ybaturina): Test fails on Windows b/435663064." + ) def testFrexp(self, shape, dtype, rng_factory): # integer types are converted to float64 in numpy's implementation if (dtype not in [jnp.bfloat16, np.float16, np.float32] @@ -2847,18 +2550,16 @@ def testSearchsortedDtype(self): out_int32 = jax.eval_shape(jnp.searchsorted, a_int32, v) self.assertEqual(out_int32.dtype, np.int32) - if config.enable_x64.value: + if ( + config.enable_x64.value + or config.explicit_x64_dtypes.value == config.ExplicitX64Mode.ALLOW + ): out_int64 = jax.eval_shape(jnp.searchsorted, a_int64, v) self.assertEqual(out_int64.dtype, np.int64) - elif jtu.numpy_version() < (2, 0, 0): - with self.assertWarnsRegex(UserWarning, "Explicitly requested dtype int64"): - with jtu.ignore_warning(category=DeprecationWarning, - message="NumPy will stop allowing conversion.*"): - out_int64 = jax.eval_shape(jnp.searchsorted, a_int64, v) - else: + elif config.explicit_x64_dtypes.value == config.ExplicitX64Mode.WARN: with self.assertWarnsRegex(UserWarning, "Explicitly requested dtype.*int64"): with self.assertRaisesRegex(OverflowError, "Python integer 2147483648 out of bounds.*"): - out_int64 = jax.eval_shape(jnp.searchsorted, a_int64, v) + jax.eval_shape(jnp.searchsorted, a_int64, v) @jtu.sample_product( dtype=inexact_dtypes, @@ -3496,11 +3197,6 @@ def testReshape(self, arg_shape, out_shape, dtype, order): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) - def testReshapeDeprecatedArgs(self): - msg = "The newshape argument to jnp.reshape was removed in JAX v0.4.36." - with self.assertRaisesRegex(TypeError, msg): - jnp.reshape(jnp.arange(4), newshape=(2, 2)) - @jtu.sample_product( [dict(arg_shape=arg_shape, out_shape=out_shape) for arg_shape, out_shape in [ @@ -3666,6 +3362,53 @@ def testAsarrayCopy(self, copy): self.assertArraysEqual(x_jax, func(x_np), check_dtypes=False) self.assertArraysEqual(x_jax, func(x_buf), check_dtypes=False) + @jtu.sample_product(numpy_array=[True, False]) + def testAsarrayWithCopyFalse(self, numpy_array): + x_jax = jnp.arange(4) + if numpy_array: + x = np.arange(4) + else: + x = make_python_array('l', [0, 1, 2, 3]) + device_error_msg = ('jnp.asarray: cannot convert object of type .* to JAX' + ' Array on platform={} with copy=False. Consider using' + ' copy=None or copy=True instead.') + + if jax.default_backend() != 'cpu': + # test accelerator devices - no support for copy=False + expected_platform = jax.local_devices()[0].platform + with self.assertRaisesRegex( + ValueError, device_error_msg.format(expected_platform)): + jnp.asarray(x, copy=False, device=jax.local_devices()[0]) + sharding = SingleDeviceSharding(jax.local_devices()[0]) + with self.assertRaisesRegex( + ValueError, device_error_msg.format(expected_platform)): + jnp.asarray(x, copy=False, device=sharding) + + # test None defaults to default backend - no support for copy=False + with self.assertRaisesRegex( + ValueError, device_error_msg.format(expected_platform)): + jnp.asarray(x, copy=False, device=None) + else: + self.assertArraysEqual(jnp.asarray(x, copy=False, device=None), x_jax, + check_dtypes=False) + + # test explicit CPU device or default CPU device context managers overwrite the default backend + x = make_python_array('l', [0, 1, 2, 3]) + for device in [jax.local_devices(backend='cpu')[0], + SingleDeviceSharding(jax.local_devices(backend='cpu')[0])]: + self.assertArraysEqual(jnp.asarray(x, copy=False, device=device), + x_jax, check_dtypes=False) + with jax.default_device('cpu'): + self.assertArraysEqual(jnp.asarray(x, copy=False), x_jax, + check_dtypes=False) + self.assertArraysEqual(jnp.asarray(x, copy=False, device=None), x_jax, + check_dtypes=False) + with jax.default_device(jax.local_devices(backend='cpu')[0]): + self.assertArraysEqual(jnp.asarray(x, copy=False), x_jax, + check_dtypes=False) + self.assertArraysEqual(jnp.asarray(x, copy=False, device=None), x_jax, + check_dtypes=False) + @jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype.*") def testArrayDtypeInference(self): def _check(obj, out_dtype, weak_type): @@ -3695,9 +3438,10 @@ def _check(obj, out_dtype, weak_type): _check([jnp.array(1.0j)], jnp.complex128, False) # Lists of strongly-typed objects maintain their strong type. - _check([jnp.int64(1)], np.int64, False) - _check([jnp.float64(1)], np.float64, False) - _check([jnp.complex128(1)], np.complex128, False) + if config.explicit_x64_dtypes.value != config.ExplicitX64Mode.ERROR: + _check([jnp.int64(1)], np.int64, False) + _check([jnp.float64(1)], np.float64, False) + _check([jnp.complex128(1)], np.complex128, False) # Mixed inputs use JAX-style promotion. # (regression test for https://github.com/jax-ml/jax/issues/8945) @@ -3765,7 +3509,7 @@ def testArrayUnsupportedDtypeError(self): jnp.array(3, [('a','= (2, 3, 0) and (np.isinf(atol) or np.isinf(rtol)): + self.skipTest("NumPy 2.3.0 now throws warnings for inf atol/rtol") + vals = np.array([-np.nan, -np.inf, -1.00001, -1.0, -0.00001, -0.0, + 0.0, 0.00001, 1.0, 1.00001, np.inf, np.nan]) + x, y = np.meshgrid(vals, vals) + self.assertArraysEqual( + np.isclose(x, y, atol=atol, rtol=rtol, equal_nan=equal_nan), + jnp.isclose(x, y, atol=atol, rtol=rtol, equal_nan=equal_nan) + ) + @jtu.sample_product( x=[1, [1], [1, 1 + 1E-4], [1, np.nan]], y=[1, [1], [1, 1 + 1E-4], [1, np.nan]], @@ -4105,19 +3883,23 @@ def jnp_fun(a, c): @jtu.sample_product( shape=nonempty_nonscalar_array_shapes, - dtype=int_dtypes, + dtype=int_dtypes + unsigned_dtypes, idx_shape=all_shapes, ) def testUnravelIndex(self, shape, idx_shape, dtype): + if not jtu.is_optimized_build(): + self.skipTest("Test fails under debug mode. See https://github.com/numpy/numpy/issues/29690.") size = math.prod(shape) - rng = jtu.rand_int(self.rng(), low=-((2 * size) // 3), high=(2 * size) // 3) + unsigned = dtypes.issubdtype(dtype, np.unsignedinteger) + rng = jtu.rand_int( + self.rng(), low=0 if unsigned else -((2 * size) // 3), high=(2 * size) // 3) def np_fun(index, shape): # JAX's version outputs the same dtype as the input in the typical case # where shape is weakly-typed. out_dtype = index.dtype # Adjust out-of-bounds behavior to match jax's documented behavior. - index = np.clip(index, -size, size - 1) + index = np.clip(index, 0 if unsigned else -size, size - 1) index = np.where(index < 0, index + size, index) return [i.astype(out_dtype) for i in np.unravel_index(index, shape)] @@ -4134,7 +3916,7 @@ def np_fun(index, shape): def testAstype(self, from_dtype, to_dtype, use_method): rng = self.rng() args_maker = lambda: [rng.randn(3, 4).astype(from_dtype)] - if (not use_method) and hasattr(np, "astype"): # Added in numpy 2.0 + if not use_method: np_op = lambda x: np.astype(x, to_dtype) else: np_op = lambda x: np.asarray(x).astype(to_dtype) @@ -4152,7 +3934,7 @@ def testAstype(self, from_dtype, to_dtype, use_method): def testAstypeBool(self, from_dtype, use_method, to_dtype='bool'): rng = jtu.rand_some_zero(self.rng()) args_maker = lambda: [rng((3, 4), from_dtype)] - if (not use_method) and hasattr(np, "astype"): # Added in numpy 2.0 + if not use_method: np_op = lambda x: np.astype(x, to_dtype) else: np_op = lambda x: np.asarray(x).astype(to_dtype) @@ -4285,7 +4067,7 @@ def testView(self, shape, a_dtype, dtype): {'a_dtype': a_dtype, 'dtype': dtype} for a_dtype in [jnp.int4, jnp.uint4, *all_dtypes] for dtype in [jnp.int4, jnp.uint4, *all_dtypes] - if dtypes.bit_width(a_dtype) == dtypes.bit_width(dtype) + if dtypes.itemsize_bits(a_dtype) == dtypes.itemsize_bits(dtype) ]) def testViewScalar(self, a_dtype, dtype): if jtu.test_device_matches(["tpu"]): @@ -4341,8 +4123,8 @@ def testArangeOnFloats(self): def testArangeTypes(self): # Test that arange() output type is equal to the default types. - int_ = dtypes.canonicalize_dtype(jnp.int_) - float_ = dtypes.canonicalize_dtype(jnp.float_) + int_ = dtypes.default_int_dtype() + float_ = dtypes.default_float_dtype() self.assertEqual(jnp.arange(10).dtype, int_) self.assertEqual(jnp.arange(10.).dtype, float_) @@ -4594,7 +4376,7 @@ def testRollaxis(self, shape, dtype, start, axis): self._CompileAndCheck(jnp_op, args_maker) @jtu.sample_product( - dtype=[np.uint8, np.bool_], + dtype=int_dtypes + unsigned_dtypes + bool_dtypes, bitorder=['big', 'little'], shape=[(1, 2, 3, 4)], axis=[None, 0, 1, -2, -1], @@ -4706,14 +4488,19 @@ def args_maker(): jnp_one_hot_op = lambda x, i: jnp.take_along_axis( x, i, axis=axis, mode='one_hot' ) - - if hasattr(np, "take_along_axis"): - np_op = lambda x, i: np.take_along_axis(x, i, axis=axis) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CheckAgainstNumpy(np_op, jnp_one_hot_op, args_maker) + np_op = lambda x, i: np.take_along_axis(x, i, axis=axis) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CheckAgainstNumpy(np_op, jnp_one_hot_op, args_maker) self._CompileAndCheck(jnp_op, args_maker) self._CompileAndCheck(jnp_one_hot_op, args_maker) + def testTakeAlongAxisDefaultAxis(self): + arr = jtu.rand_default(self.rng())((10, 20), np.float32) + indices = jtu.rand_int(self.rng(), 0, 100)((10, 30), np.uint8) + q0 = jnp.take_along_axis(arr, indices, axis=-1) + q1 = jnp.take_along_axis(arr, indices) + np.testing.assert_array_equal(q0, q1) + def testTakeAlongAxisWithUint8IndicesDoesNotOverflow(self): # https://github.com/jax-ml/jax/issues/5088 h = jtu.rand_default(self.rng())((256, 256, 100), np.float32) @@ -4839,7 +4626,7 @@ def args_maker(): return [] def testIndicesDefaultDtype(self): self.assertEqual(jnp.indices((2, 3)).dtype, - dtypes.canonicalize_dtype(np.int64)) + dtypes.default_int_dtype()) @jtu.sample_product( shape=nonzerodim_shapes, @@ -4963,14 +4750,16 @@ def testIssue453(self): self.assertAllClose(ans, expected) @jtu.sample_product( - dtype=[int, float, bool, complex], + scalar_type=[int, float, bool, complex], op=["atleast_1d", "atleast_2d", "atleast_3d"], ) - def testAtLeastNdLiterals(self, dtype, op): + def testAtLeastNdLiterals(self, scalar_type, op): # Fixes: https://github.com/jax-ml/jax/issues/634 - np_fun = lambda arg: getattr(np, op)(arg).astype(dtypes.python_scalar_dtypes[dtype]) + np_fun = lambda arg: getattr(np, op)(arg).astype( + dtypes.python_scalar_types_to_dtypes[scalar_type] + ) jnp_fun = lambda arg: getattr(jnp, op)(arg) - args_maker = lambda: [dtype(2)] + args_maker = lambda: [scalar_type(2)] self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) @@ -5092,6 +4881,53 @@ def testArangeJaxpr(self, args, specify_device): self.assertEqual(len(jaxpr.jaxpr.eqns), num_eqs) self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.iota_p) + @jtu.sample_product(specify_device=[True, False]) + def testArangeJaxprNonZeroStart(self, specify_device): + device = jax.devices()[-1] if specify_device else None + jaxpr = jax.make_jaxpr(lambda: jnp.arange(1, 5, device=device))() + # Non-zero start should produce iota + add (+ device_put if device specified) + num_eqs = 3 if device is not None else 2 + self.assertEqual(len(jaxpr.jaxpr.eqns), num_eqs) + self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.iota_p) + self.assertEqual(jaxpr.jaxpr.eqns[1].primitive, lax.add_p) + + @jtu.sample_product( + dtype=[np.int32, np.float32], + iteration=range(10) + ) + def testArangeRandomValues(self, dtype, iteration): + del iteration # not needed: each test case gets its own random seed. + rng = jtu.rand_default(self.rng()) + start = rng((), dtype) + stop = rng((), dtype) + jax_result = jnp.arange(start, stop, dtype=dtype) + np_result = np.arange(start, stop, dtype=dtype) + self.assertAllClose(jax_result, np_result) + + @parameterized.parameters( + (1+2j, 5+3j), + (0+0j, 5+0j), + (1.0+0j, 5.0+0j), + (0, 5, 1+1j), + ) + def testArangeComplex(self, *args): + dep_id = "jax-numpy-arange-complex" + msg = "Passing complex start/stop/step to jnp.arange is deprecated" + if deprecations.is_accelerated(dep_id): + with self.assertRaisesRegex(ValueError, msg): + jax_result = jnp.arange(*args) + else: + with self.assertWarnsRegex(DeprecationWarning, msg): + jax_result = jnp.arange(*args) + np_result = np.arange(*args) + self.assertArraysEqual(jax_result, np_result) + + @parameterized.parameters(int, float, np.int32, np.float32) + def testArangeTransferGuard(self, typ): + # Ensure that simple arange calls avoid host-to-device transfer. + with jax.transfer_guard("disallow"): + jnp.arange(typ(5)) + def testIssue830(self): a = jnp.arange(4, dtype=jnp.complex64) self.assertEqual(a.dtype, jnp.complex64) @@ -5107,10 +4943,9 @@ def testIssue746(self): def testIssue764(self): x = jnp.linspace(190, 200, 4) f = jax.grad(lambda x: jnp.sum(jnp.tanh(x))) - # Expected values computed with autograd in float64 precision. - expected = np.array([3.71669453e-165, 4.72999108e-168, 6.01954653e-171, - 7.66067839e-174], np.float64) - self.assertAllClose(f(x), expected, check_dtypes=False) + # tanh(190) and tanh(200) are both 1, so the gradient is 0 in f64. + expected = np.array([0,0,0,0], np.float64) + self.assertAllClose(f(x), expected, check_dtypes=False, atol=1e-14, rtol=0) # Test removed because tie_in is deprecated. # def testIssue776(self): @@ -5185,6 +5020,14 @@ def args_maker(): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) self._CompileAndCheck(jnp_fun, args_maker) + def testCorrCoefDtype(self): + x = jnp.arange(5) + result_bf16 = jnp.corrcoef(x, x, dtype='bfloat16') + self.assertEqual(result_bf16.dtype, np.dtype('bfloat16')) + + with self.assertRaisesRegex(ValueError, "corrcoef: dtype must be a subclass of float or complex"): + jnp.corrcoef(x, x, dtype=int) + @jtu.sample_product( [dict(dtype=dtype, end_dtype=end_dtype, begin_dtype=begin_dtype, shape=shape, begin_shape=begin_shape, end_shape=end_shape) @@ -5560,7 +5403,7 @@ def np_op(start, stop): axis=axis).astype(dtype) # JAX follows NumPy 2.0 semantics for complex geomspace. - if not (jtu.numpy_version() < (2, 0, 0) and dtypes.issubdtype(dtype, jnp.complexfloating)): + if not dtypes.issubdtype(dtype, jnp.complexfloating): self._CheckAgainstNumpy(np_op, jnp_op, args_maker, check_dtypes=False, tol=tol) if dtype in (inexact_dtypes + [None,]): @@ -5948,13 +5791,11 @@ def np_op(x1, x2): self._CompileAndCheck(jnp.logaddexp2, args_maker, rtol=tol, atol=tol) def testDefaultDtypes(self): - precision = config.default_dtype_bits.value - assert precision in ['32', '64'] self.assertEqual(jnp.bool_, np.bool_) - self.assertEqual(jnp.int_, np.int32 if precision == '32' else np.int64) - self.assertEqual(jnp.uint, np.uint32 if precision == '32' else np.uint64) - self.assertEqual(jnp.float_, np.float32 if precision == '32' else np.float64) - self.assertEqual(jnp.complex_, np.complex64 if precision == '32' else np.complex128) + self.assertEqual(jnp.int_, np.int64) + self.assertEqual(jnp.uint, np.uint64) + self.assertEqual(jnp.float_, np.float64) + self.assertEqual(jnp.complex_, np.complex128) def testFromBuffer(self): buf = b'\x01\x02\x03' @@ -6042,7 +5883,10 @@ def np_fun(a, i, v): dict(a_shape=a_shape, i_shape=i_shape, v_shape=v_shape, axis=axis) for a_shape in nonempty_array_shapes for axis in list(range(-len(a_shape), len(a_shape))) - for i_shape in [tuple_replace(a_shape, axis, J) for J in range(a_shape[axis] + 1)] + for i_shape in [ + tuple_update(a_shape, axis if axis >= 0 else axis + len(a_shape), J) + for J in range(a_shape[axis] + 1) + ] for v_shape in [(), (1,), i_shape] ] + [ dict(a_shape=a_shape, i_shape=i_shape, v_shape=v_shape, axis=None) @@ -6102,7 +5946,7 @@ def test_error_hint(self, fn): def test_isdtype(self, dtype, kind): # Full tests also in dtypes_test.py; here we just compare against numpy jax_result = jnp.isdtype(dtype, kind) - if jtu.numpy_version() < (2, 0, 0) or dtype == dtypes.bfloat16: + if dtype == dtypes.bfloat16: # just a smoke test self.assertIsInstance(jax_result, bool) else: @@ -6123,15 +5967,12 @@ def test_isdtype(self, dtype, kind): ], dtype=float_dtypes + int_dtypes, ) - @jtu.skip_on_devices("tpu") # TODO(jakevdp): fix and reenable this test. + @jtu.skip_on_devices("tpu") # TODO(jakevdp): fix and re-enable this test. @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. def test_trapezoid(self, yshape, xshape, dtype, dx, axis): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(yshape, dtype), rng(xshape, dtype) if xshape is not None else None] - if jtu.numpy_version() >= (2, 0, 0): - np_fun = partial(np.trapezoid, dx=dx, axis=axis) - else: - np_fun = partial(np.trapz, dx=dx, axis=axis) + np_fun = partial(np.trapezoid, dx=dx, axis=axis) jnp_fun = partial(jnp.trapezoid, dx=dx, axis=axis) tol = jtu.tolerance(dtype, {np.float16: 2e-3, np.float64: 1e-12, jax.dtypes.bfloat16: 4e-2}) @@ -6168,6 +6009,33 @@ def testSizeAlongAxis(self, shape, dtype): self._CheckAgainstNumpy(np_op, jnp_op, args_maker) self._CompileAndCheck(jnp_op, args_maker) + @jtu.sample_product( + tuple_size=[0, 1, 2, 3, 4] + ) + def testSizeAlongAxisTuple(self, tuple_size): + rng = self.rng() + + ndim = tuple_size + rng.randint(10) + + shape = rng.randint(0, 10, ndim) + tuples = list(itertools.combinations(range(ndim), tuple_size)) + axis = tuples[rng.randint(len(tuples))] + + array = jnp.zeros(shape) + output = jnp.size(array, axis) + expected = math.prod(shape[i] for i in axis) + assert output == expected + + @jtu.sample_product( + axis=[(0, 0), (0, -3), (1, 1), (1, -2), (2, 2), (2, -1)], + ) + def testSizeAlongAxisDuplicate(self, axis): + shape = (2, 3, 4) + array = jnp.zeros(shape) + msg = "repeated axis" + with self.assertRaisesRegex(ValueError, msg): + jnp.size(array, axis) + @jtu.sample_product( op=[jnp.ndim, jnp.shape, jnp.size], ) @@ -6247,8 +6115,9 @@ def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol): ) for rec in GRAD_SPECIAL_VALUE_TEST_RECORDS)) def testOpGradSpecialValue(self, op, special_value, order): - check_grads(op, (special_value,), order, ["fwd", "rev"], - atol={np.float32: 3e-3}) + check_grads( + op, (special_value,), order, ['fwd', 'rev'], atol={np.float32: 3.4e-3} + ) def testSincAtZero(self): # Some manual tests for sinc at zero, since it doesn't have well-behaved @@ -6327,9 +6196,18 @@ def testGradLogaddexp2Complex(self, shapes, dtype): ) def testGradLdexp(self, n, dtype): rng = jtu.rand_default(self.rng()) - x = rng((), dtype) + x = rng((10,), dtype) check_grads(lambda x: jnp.ldexp(x, n), (x,), 1) + @jtu.sample_product( + n=range(-4, 5), + dtype=[jnp.float32, jnp.float64], + ) + def testGradFrexp(self, n, dtype): + rng = jtu.rand_default(self.rng()) + x = rng((10,), dtype) * 2 ** n + check_grads(lambda x: jnp.frexp(x)[0], (x,), 1) + class NumpySignaturesTest(jtu.JaxTestCase): @@ -6354,6 +6232,7 @@ def testWrappedSignaturesMatch(self): 'datetime_as_string', 'datetime_data', 'errstate', + 'fix', 'flatiter', 'format_float_positional', 'format_float_scientific', @@ -6397,43 +6276,6 @@ def testWrappedSignaturesMatch(self): 'trapz', 'typename'} - # symbols removed in NumPy 2.0 - skip |= {'add_docstring', - 'add_newdoc', - 'add_newdoc_ufunc', - 'alltrue', - 'asfarray', - 'byte_bounds', - 'compare_chararrays', - 'cumproduct', - 'deprecate', - 'deprecate_with_doc', - 'disp', - 'fastCopyAndTranspose', - 'find_common_type', - 'get_array_wrap', - 'geterrobj', - 'issctype', - 'issubclass_', - 'issubsctype', - 'lookfor', - 'mat', - 'maximum_sctype', - 'msort', - 'obj2sctype', - 'product', - 'recfromcsv', - 'recfromtxt', - 'round_', - 'safe_eval', - 'sctype2char', - 'set_numeric_ops', - 'set_string_function', - 'seterrobj', - 'sometrue', - 'source', - 'who'} - self.assertEmpty(skip.intersection(dir(jnp))) names = (name for name in dir(np) if not (name.startswith('_') or name in skip)) @@ -6445,46 +6287,55 @@ def testWrappedSignaturesMatch(self): # TODO(jakevdp): fix some of the following signatures. Some are due to wrong argument names. unsupported_params = { + 'arange': ['start_or_stop', 'like'], + 'array': ['ndmax', 'like', 'subok'], 'argpartition': ['kind', 'order'], 'asarray': ['like'], 'broadcast_to': ['subok'], 'clip': ['kwargs', 'out'], + 'concat': ['out', 'dtype', 'casting'], + 'concatenate': ['out', 'casting'], 'copy': ['subok'], - 'corrcoef': ['ddof', 'bias', 'dtype'], - 'cov': ['dtype'], + 'corrcoef': ['ddof', 'bias'], 'cumulative_prod': ['out'], 'cumulative_sum': ['out'], + 'dot': ['out'], 'empty_like': ['subok', 'order'], 'einsum': ['kwargs'], 'einsum_path': ['einsum_call'], + 'empty': ['order', 'like'], 'eye': ['order', 'like'], 'hstack': ['casting'], 'identity': ['like'], 'isin': ['kind'], 'full': ['order', 'like'], 'full_like': ['subok', 'order'], + 'frombuffer': ['like'], 'fromfunction': ['like'], + 'frompyfunc': ['kwargs'], + 'fromstring': ['like'], 'load': ['mmap_mode', 'allow_pickle', 'fix_imports', 'encoding', 'max_header_size'], - 'nanpercentile': ['weights'], - 'nanquantile': ['weights'], - 'nanstd': ['correction', 'mean'], - 'nanvar': ['correction', 'mean'], + 'nanpercentile': ['interpolation', 'weights'], + 'nanquantile': ['interpolation', 'weights'], + 'nanstd': ['correction'], + 'nanvar': ['correction'], 'ones': ['order', 'like'], 'ones_like': ['subok', 'order'], 'partition': ['kind', 'order'], - 'percentile': ['weights'], - 'quantile': ['weights'], + 'percentile': ['interpolation', 'weights'], + 'promote_types': ['type1', 'type2'], + 'quantile': ['interpolation', 'weights'], 'row_stack': ['casting'], 'stack': ['casting'], - 'std': ['mean'], 'tri': ['like'], - 'trim_zeros': ['axis'], - 'var': ['mean'], + 'unravel_index': ['order'], 'vstack': ['casting'], + 'zeros': ['order', 'like'], 'zeros_like': ['subok', 'order'] } extra_params = { + 'arange': ['start'], 'compress': ['size', 'fill_value'], 'einsum': ['subscripts', 'precision'], 'einsum_path': ['subscripts'], @@ -6499,6 +6350,11 @@ def testWrappedSignaturesMatch(self): for name in names: jnp_fun = getattr(jnp, name) np_fun = getattr(np, name) + if isinstance(getattr(np, name), np.ufunc): + # Skip all `np.ufunc`s since many of the missing ufunc keywords may not + # be relevant for JAX. However, args such as `axis` and `keepdims` may + # be useful to `matmul` and others. + continue if name in ['histogram', 'histogram2d', 'histogramdd']: # numpy 1.24 re-orders the density and weights arguments. # TODO(jakevdp): migrate histogram APIs to match newer numpy versions. @@ -6516,6 +6372,12 @@ def testWrappedSignaturesMatch(self): # Similar issue to clip: we'd need logic specific to the NumPy version # because of the change in argument name from `newshape` to `shape`. continue + if name == "asarray": + # The order of the `device` and `copy` kwargs are swapped between jnp + # and np. + # jnp.asarray: a, dtype, order, copy, device, out_sharding + # np.asarray: a, dtype, order, device, copy, like + continue # Note: can't use inspect.getfullargspec for some functions due to numpy issue # https://github.com/numpy/numpy/issues/12225 try: diff --git a/tests/lax_numpy_ufuncs_test.py b/tests/lax_numpy_ufuncs_test.py index fd5050a5829b..84923b1bc14a 100644 --- a/tests/lax_numpy_ufuncs_test.py +++ b/tests/lax_numpy_ufuncs_test.py @@ -56,7 +56,7 @@ def _jnp_ufunc_props(name): jnp_func = getattr(jnp, name) assert isinstance(jnp_func, jnp.ufunc) np_func = getattr(np, name) - dtypes = [np.dtype(c) for c in "Ffi?" if f"{c}{c}->{c}" in np_func.types or f"{c}->{c}" in np_func.types] + dtypes = [np.dtype(c) for c in "FfIi?" if f"{c}{c}->{c}" in np_func.types or f"{c}->{c}" in np_func.types] return [dict(name=name, dtype=dtype) for dtype in dtypes] @@ -87,6 +87,9 @@ def _jnp_ufunc_props(name): broadcast_compatible_shapes = [(), (1,), (3,), (1, 3), (4, 1), (4, 3)] nonscalar_shapes = [(3,), (4,), (4, 3)] +empty_shapes = [(0, 3), (4, 0)] +all_shapes = nonscalar_shapes + empty_shapes + def cast_outputs(fun): def wrapped(*args, **kwargs): @@ -226,26 +229,34 @@ def test_binary_ufunc_outer(self, name, lhs_shape, rhs_shape, dtype): @jtu.sample_product( SCALAR_FUNCS, [{'shape': shape, 'axis': axis} - for shape in nonscalar_shapes + for shape in all_shapes for axis in [None, *range(-len(shape), len(shape))]], dtype=jtu.dtypes.floating, ) def test_frompyfunc_reduce(self, func, nin, nout, identity, shape, axis, dtype): if (nin, nout) != (2, 1): self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}") + jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduce, axis=axis) np_fun = cast_outputs(partial(np.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduce, axis=axis)) rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] + along_zero_dim = (0 in shape) if axis is None else shape[axis] == 0 + if identity is None and along_zero_dim: + with self.assertRaises(ValueError): + jnp_fun(*args_maker()) + return + self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( BINARY_UFUNCS_WITH_DTYPES, [{'shape': shape, 'axis': axis} - for shape in nonscalar_shapes + for shape in all_shapes for axis in [None, *range(-len(shape), len(shape))]], ) def test_binary_ufunc_reduce(self, name, shape, axis, dtype): @@ -261,6 +272,11 @@ def test_binary_ufunc_reduce(self, name, shape, axis, dtype): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] + if jnp_fun.identity is None and axis is not None and shape[axis] == 0: + with self.assertRaises(ValueError): + jnp_fun_reduce(*args_maker()) + return + tol = {np.float32: 1E-4} if jtu.test_device_matches(['tpu']) else None self._CheckAgainstNumpy(jnp_fun_reduce, np_fun_reduce, args_maker, tol=tol) @@ -269,7 +285,7 @@ def test_binary_ufunc_reduce(self, name, shape, axis, dtype): @jtu.sample_product( SCALAR_FUNCS, [{'shape': shape, 'axis': axis} - for shape in nonscalar_shapes + for shape in all_shapes for axis in [None, *range(-len(shape), len(shape))]], dtype=jtu.dtypes.floating, ) @@ -302,7 +318,7 @@ def np_fun(arr, where): @jtu.sample_product( BINARY_UFUNCS_WITH_DTYPES, [{'shape': shape, 'axis': axis} - for shape in nonscalar_shapes + for shape in all_shapes for axis in [None, *range(-len(shape), len(shape))]], ) def test_binary_ufunc_reduce_where(self, name, shape, axis, dtype): @@ -324,10 +340,68 @@ def test_binary_ufunc_reduce_where(self, name, shape, axis, dtype): self._CheckAgainstNumpy(jnp_fun_reduce, np_fun_reduce, args_maker, tol=tol) self._CompileAndCheck(jnp_fun_reduce, args_maker) + @jtu.sample_product( + BINARY_UFUNCS_WITH_DTYPES, + [{'shape': shape, 'axis': axis} + for shape in all_shapes + for axis in [None, *range(-len(shape), len(shape))]], + ) + def test_binary_ufunc_reduce_initial(self, name, shape, axis, dtype): + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) + + if jnp_fun.identity is None and axis is None and len(shape) > 1: + self.skipTest("Multiple-axis reduction over non-reorderable ufunc.") + + jnp_fun_reduce = lambda a, initial: jnp_fun.reduce(a, axis=axis, initial=initial) + np_fun_reduce = lambda a, initial: np_fun.reduce(a, axis=axis, initial=initial) + + rng = jtu.rand_default(self.rng()) + rng_initial = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype), rng_initial((), dtype)] + + tol = {np.float32: 1E-4} if jtu.test_device_matches(['tpu']) else None + + self._CheckAgainstNumpy(jnp_fun_reduce, np_fun_reduce, args_maker, tol=tol) + self._CompileAndCheck(jnp_fun_reduce, args_maker) + + @jtu.sample_product( + BINARY_UFUNCS_WITH_DTYPES, + [{'shape': shape, 'axis': axis} + for shape in all_shapes + for axis in [None, *range(-len(shape), len(shape))]], + ) + def test_binary_ufunc_reduce_where_initial(self, name, shape, axis, dtype): + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) + + # Skip if the ufunc doesn't have an identity and we're doing a multi-axis reduction + if jnp_fun.identity is None and axis is None and len(shape) > 1: + self.skipTest("Multiple-axis reduction over non-reorderable ufunc.") + + jnp_fun_reduce = lambda a, where, initial: jnp_fun.reduce( + a, axis=axis, where=where, initial=initial) + np_fun_reduce = lambda a, where, initial: np_fun.reduce( + a, axis=axis, where=where, initial=initial) + + rng = jtu.rand_default(self.rng()) + rng_where = jtu.rand_bool(self.rng()) + rng_initial = jtu.rand_default(self.rng()) + args_maker = lambda: [ + rng(shape, dtype), + rng_where(shape, bool), + rng_initial((), dtype) + ] + + tol = {np.float32: 1E-4} if jtu.test_device_matches(['tpu']) else None + + self._CheckAgainstNumpy(jnp_fun_reduce, np_fun_reduce, args_maker, tol=tol) + self._CompileAndCheck(jnp_fun_reduce, args_maker) + @jtu.sample_product( SCALAR_FUNCS, [{'shape': shape, 'axis': axis} - for shape in nonscalar_shapes + for shape in all_shapes for axis in range(-len(shape), len(shape))], dtype=jtu.dtypes.floating, ) @@ -346,7 +420,7 @@ def test_frompyfunc_accumulate(self, func, nin, nout, identity, shape, axis, dty @jtu.sample_product( BINARY_UFUNCS_WITH_DTYPES, [{'shape': shape, 'axis': axis} - for shape in nonscalar_shapes + for shape in all_shapes for axis in range(-len(shape), len(shape))] ) def test_binary_ufunc_accumulate(self, name, shape, axis, dtype): @@ -375,7 +449,7 @@ def np_fun_accumulate(x): ) def test_frompyfunc_at(self, func, nin, nout, identity, shape, idx_shape, dtype): if (nin, nout) != (2, 1): - self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(nin, nout)=}") + self.skipTest(f"at requires (nin, nout)=(2, 1); got {(nin, nout)=}") jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).at, inplace=False) def np_fun(x, idx, y): x_copy = x.copy() @@ -453,14 +527,15 @@ def np_fun(x, idx, y): @jtu.sample_product( SCALAR_FUNCS, [{'shape': shape, 'axis': axis} - for shape in nonscalar_shapes - for axis in [*range(-len(shape), len(shape))]], + for shape in all_shapes + for axis in [*range(-len(shape), len(shape))] + if shape[axis] != 0], idx_shape=[(0,), (3,), (5,)], dtype=jtu.dtypes.floating, ) def test_frompyfunc_reduceat(self, func, nin, nout, identity, shape, axis, idx_shape, dtype): if (nin, nout) != (2, 1): - self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(nin, nout)=}") + self.skipTest(f"reduceat requires (nin, nout)=(2, 1); got {(nin, nout)=}") jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduceat, axis=axis) np_fun = cast_outputs(partial(np.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduceat, axis=axis)) @@ -474,31 +549,33 @@ def test_frompyfunc_reduceat(self, func, nin, nout, identity, shape, axis, idx_s @jtu.sample_product( BINARY_UFUNCS_WITH_DTYPES, [{'shape': shape, 'axis': axis} - for shape in nonscalar_shapes - for axis in [*range(-len(shape), len(shape))]], + for shape in all_shapes + for axis in [*range(-len(shape), len(shape))] + if shape[axis] != 0], idx_shape=[(0,), (3,), (5,)], ) def test_binary_ufunc_reduceat(self, name, shape, axis, idx_shape, dtype): jnp_fun = getattr(jnp, name) np_fun = getattr(np, name) if (jnp_fun.nin, jnp_fun.nout) != (2, 1): - self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}") + self.skipTest(f"reduceat requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}") if name in ['add', 'multiply'] and dtype == bool: - # TODO(jakevdp): figure out how to fix thest cases. + # TODO(jakevdp): figure out how to fix test cases. self.skipTest(f"known failure for {name}.reduceat with {dtype=}") rng = jtu.rand_default(self.rng()) idx_rng = jtu.rand_int(self.rng(), low=0, high=shape[axis]) args_maker = lambda: [rng(shape, dtype), idx_rng(idx_shape, 'int32')] + jnp_fun_reduceat = partial(jnp_fun.reduceat, axis=axis) def np_fun_reduceat(x, i): # Numpy has different casting behavior. - return np_fun.reduceat(x, i).astype(x.dtype) + return np_fun.reduceat(x, i, axis=axis).astype(x.dtype) tol = {np.float32: 1E-4} if jtu.test_device_matches(['tpu']) else None - self._CheckAgainstNumpy(jnp_fun.reduceat, np_fun_reduceat, args_maker, tol=tol) - self._CompileAndCheck(jnp_fun.reduceat, args_maker) + self._CheckAgainstNumpy(jnp_fun_reduceat, np_fun_reduceat, args_maker, tol=tol) + self._CompileAndCheck(jnp_fun_reduceat, args_maker) if __name__ == "__main__": diff --git a/tests/lax_numpy_vectorize_test.py b/tests/lax_numpy_vectorize_test.py index 985dba484845..8fbd393dc3f8 100644 --- a/tests/lax_numpy_vectorize_test.py +++ b/tests/lax_numpy_vectorize_test.py @@ -287,6 +287,15 @@ def test_rank_promotion_error(self): with self.assertNoWarnings(): f2(rank2, rank1) + def test_non_scalar_outputs_and_default_signature(self): + def f(x): + self.assertEqual(np.shape(x), ()) + return x + jnp.linspace(-1, 1, out_dim) + + out_dim = 5 + self.assertEqual(jnp.vectorize(f)(0.5).shape, (out_dim,)) + self.assertEqual(jnp.vectorize(f)(jnp.ones(3)).shape, (3, out_dim)) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/lax_scipy_sparse_test.py b/tests/lax_scipy_sparse_test.py index 303c67c5860d..21169a4d9a1a 100644 --- a/tests/lax_scipy_sparse_test.py +++ b/tests/lax_scipy_sparse_test.py @@ -261,7 +261,7 @@ def args_maker(): dtype=float_types + complex_types, preconditioner=[None, 'identity', 'exact'], ) - @jtu.skip_on_devices("gpu") + @jtu.skip_on_devices("cuda") def test_bicgstab_on_identity_system(self, shape, dtype, preconditioner): A = jnp.eye(shape[1], dtype=dtype) solution = jnp.ones(shape[1], dtype=dtype) @@ -280,7 +280,7 @@ def test_bicgstab_on_identity_system(self, shape, dtype, preconditioner): dtype=float_types + complex_types, preconditioner=[None, 'identity', 'exact'], ) - @jtu.skip_on_devices("gpu") + @jtu.skip_on_devices("cuda") def test_bicgstab_on_random_system(self, shape, dtype, preconditioner): rng = jtu.rand_default(self.rng()) A = rng(shape, dtype) @@ -367,7 +367,7 @@ def args_maker(): preconditioner=[None, 'identity', 'exact'], solve_method=['batched', 'incremental'], ) - @jtu.skip_on_devices("gpu") + @jtu.skip_on_devices("cuda") def test_gmres_on_identity_system(self, shape, dtype, preconditioner, solve_method): A = jnp.eye(shape[1], dtype=dtype) @@ -391,7 +391,7 @@ def test_gmres_on_identity_system(self, shape, dtype, preconditioner, preconditioner=[None, 'identity', 'exact'], solve_method=['incremental', 'batched'], ) - @jtu.skip_on_devices("gpu") + @jtu.skip_on_devices("cuda") def test_gmres_on_random_system(self, shape, dtype, preconditioner, solve_method): rng = jtu.rand_default(self.rng()) diff --git a/tests/lax_scipy_special_functions_test.py b/tests/lax_scipy_special_functions_test.py index f4e4e4f48213..4b4bfd6b0e9e 100644 --- a/tests/lax_scipy_special_functions_test.py +++ b/tests/lax_scipy_special_functions_test.py @@ -25,6 +25,7 @@ import jax import jax.numpy as jnp +from jax._src import dtypes from jax._src import test_util as jtu from jax.scipy import special as lsp_special @@ -89,6 +90,9 @@ def op_record(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums=(), t op_record( "expit", 1, float_dtypes, jtu.rand_small_positive, True ), + op_record( + "sici", 1, float_dtypes, jtu.rand_default, True + ), # TODO: gammaln has slightly high error. op_record( "gammaln", 1, float_dtypes, jtu.rand_positive, False @@ -157,6 +161,10 @@ def op_record(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums=(), t "hyp1f1", 3, float_dtypes, functools.partial(jtu.rand_uniform, low=0.5, high=30), True ), + op_record( + "hyp2f1", 4, float_dtypes, + functools.partial(jtu.rand_uniform, low=0.1, high=0.9), True + ), op_record("log_softmax", 1, float_dtypes, jtu.rand_default, True), op_record("softmax", 1, float_dtypes, jtu.rand_default, True), ] @@ -256,6 +264,17 @@ def testNdtriExtremeValues(self): self._CheckAgainstNumpy(osp_special.ndtri, lsp_special.ndtri, args_maker, rtol=rtol) self._CompileAndCheck(lsp_special.ndtri, args_maker, rtol=rtol) + @parameterized.parameters([True, False]) + def testNdtriDebugInfs(self, with_jit): + # ref: https://github.com/jax-ml/jax/issues/29328 + f = jax.jit(lsp_special.ndtri) if with_jit else lsp_special.ndtri + with jax.debug_infs(True): + f(0.5) # Doesn't crash + with self.assertRaisesRegex(FloatingPointError, "invalid value \\(inf\\)"): + f(1.0) + with self.assertRaisesRegex(FloatingPointError, "invalid value \\(inf\\)"): + f(0.0) + def testRelEntrExtremeValues(self): # Testing at the extreme values (bounds (0. and 1.) and outside the bounds). dtype = jnp.zeros(0).dtype # default float dtype. @@ -288,35 +307,117 @@ def testExpiDisableJit(self): self.assertAllClose(result_jit, result_nojit) def testGammaIncBoundaryValues(self): - dtype = jax.numpy.zeros(0).dtype # default float dtype. + dtype = dtypes.default_float_dtype() nan = float('nan') inf = float('inf') if jtu.parse_version(scipy.__version__) >= (1, 16): - samples_slice = slice(None) + a_samples = [0, 0, 0, 1, nan, 1, nan, 0, 1, 1, nan] + x_samples = [0, 1, 2, 0, 1, nan, nan, inf, inf, -1, inf] else: # disable samples that contradict with scipy/scipy#22441 - samples_slice = slice(None, -1) - args_maker = lambda: [np.array([0, 0, 0, 1, nan, 1, nan, 0, 1, nan][samples_slice]).astype(dtype), - np.array([0, 1, 2, 0, 1, nan, nan, inf, inf, inf][samples_slice]).astype(dtype)] + a_samples = [0, 0, 0, 1, nan, 1, nan, 0, 1, 1] + x_samples = [0, 1, 2, 0, 1, nan, nan, inf, inf, -1] + + args_maker = lambda: (np.array(a_samples, dtype=dtype), np.array(x_samples, dtype=dtype)) + rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5 self._CheckAgainstNumpy(lsp_special.gammainc, osp_special.gammainc, args_maker, rtol=rtol) self._CompileAndCheck(lsp_special.gammainc, args_maker, rtol=rtol) def testGammaIncCBoundaryValues(self): - dtype = jax.numpy.zeros(0).dtype # default float dtype. + dtype = dtypes.default_float_dtype() nan = float('nan') inf = float('inf') if jtu.parse_version(scipy.__version__) >= (1, 16): - samples_slice = slice(None) + a_samples = [0, 0, 0, 1, nan, 1, nan, 0, 1, 1, nan] + x_samples = [0, 1, 2, 0, 1, nan, nan, inf, inf, -1, inf] else: # disable samples that contradict with scipy/scipy#22441 - samples_slice = slice(None, -1) - args_maker = lambda: [np.array([0, 0, 0, 1, nan, 1, nan, 0, 1, 1, nan][samples_slice]).astype(dtype), - np.array([0, 1, 2, 0, 1, nan, nan, inf, inf, -1, inf][samples_slice]).astype(dtype)] + a_samples = [0, 0, 0, 1, nan, 1, nan, 0, 1, 1] + x_samples = [0, 1, 2, 0, 1, nan, nan, inf, inf, -1] + + args_maker = lambda: (np.array(a_samples, dtype=dtype), np.array(x_samples, dtype=dtype)) + rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5 self._CheckAgainstNumpy(lsp_special.gammaincc, osp_special.gammaincc, args_maker, rtol=rtol) self._CompileAndCheck(lsp_special.gammaincc, args_maker, rtol=rtol) + def testBetaIncBoundaryValues(self): + dtype = dtypes.default_float_dtype() + fi = jax.numpy.finfo(dtype) + nan = float('nan') + inf = float('inf') + tiny = fi.tiny + eps = fi.eps + if jtu.parse_version(scipy.__version__) >= (1, 16): + # TODO(pearu): enable tiny samples when a fix to scipy/scipy#22682 + # will be available + a_samples = [nan, -0.5, inf, 0, eps, 1, tiny][:-1] + b_samples = [nan, -0.5, inf, 0, eps, 1, tiny][:-1] + elif jtu.parse_version(scipy.__version__) >= (1, 12): + # disabled samples that contradict with scipy/scipy#22425 + a_samples = [nan, -0.5, 0.5] + b_samples = [nan, -0.5, 0.5] + else: + a_samples = [-0.5, 0.5] + b_samples = [-0.5, 0.5] + x_samples = [nan, -0.5, 0, 0.5, 1, 1.5] + + a_samples = np.array(a_samples, dtype=dtype) + b_samples = np.array(b_samples, dtype=dtype) + x_samples = np.array(x_samples, dtype=dtype) + + args_maker = lambda: np.meshgrid(a_samples, b_samples, x_samples) + + rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 5e-5 + self._CheckAgainstNumpy(osp_special.betainc, lsp_special.betainc, args_maker, rtol=rtol) + self._CompileAndCheck(lsp_special.betainc, args_maker, rtol=rtol) + + def testHyp2f1SpecialCases(self): + dtype = dtypes.default_float_dtype() + + a_samples = np.array([0, 1, 1, 1, 1, 5, 5, 0.245, 0.45, 0.45, 2, 0.4, 0.32, 4, 4], dtype=dtype) + b_samples = np.array([1, 0, 1, 1, 1, 1, 1, 3, 0.7, 0.7, 1, 0.7, 0.76, 2, 3], dtype=dtype) + c_samples = np.array([1, 1, 0, 1, -1, 3, 3, 3, 0.45, 0.45, 5, 0.3, 0.11, 7, 7], dtype=dtype) + x_samples = np.array([1, 1, 1, 0, 1, 0.5, 1, 0.35, 0.35, 1.5, 1, 0.4, 0.95, 0.95, 0.95], dtype=dtype) + + args_maker = lambda: (a_samples, b_samples, c_samples, x_samples) + rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 5e-5 + self._CheckAgainstNumpy(osp_special.hyp2f1, lsp_special.hyp2f1, args_maker, rtol=rtol) + self._CompileAndCheck(lsp_special.hyp2f1, args_maker, rtol=rtol) + + def testSiciEdgeCases(self): + dtype = jnp.zeros(0).dtype + x_samples = np.array([0.0, np.inf, -np.inf], dtype=dtype) + scipy_op = lambda x: osp_special.sici(x) + lax_op = lambda x: lsp_special.sici(x) + si_scipy, ci_scipy = scipy_op(x_samples) + si_jax, ci_jax = lax_op(x_samples) + + expected_si = np.array([0.0, np.pi/2, -np.pi/2], dtype=dtype) + expected_ci = np.array([-np.inf, 0.0, np.nan], dtype=dtype) + self.assertAllClose(si_jax, si_scipy, atol=1e-6, rtol=1e-6) + self.assertAllClose(ci_jax, ci_scipy, atol=1e-6, rtol=1e-6) + self.assertAllClose(si_jax, expected_si, atol=1e-6, rtol=1e-6) + self.assertAllClose(ci_jax, expected_ci, atol=1e-6, rtol=1e-6) + + @jtu.sample_product( + scale=[1, 10, 1e9], + shape=[(5,), (10,)] + ) + def testSiciValueRanges(self, scale, shape): + rng = jtu.rand_default(self.rng(), scale=scale) + args_maker = lambda: [rng(shape, jnp.float32)] + rtol = 5e-3 if jtu.test_device_matches(["tpu"]) else 1e-6 + self._CheckAgainstNumpy( + osp_special.sici, lsp_special.sici, args_maker, rtol=rtol) + + def testSiciRaiseOnComplexInput(self): + samples = jnp.arange(5, dtype=complex) + with self.assertRaisesRegex(ValueError, "Argument `x` to sici must be real-valued."): + lsp_special.sici(samples) + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/lax_scipy_spectral_dac_test.py b/tests/lax_scipy_spectral_dac_test.py index a09dcac5371c..4359318a7997 100644 --- a/tests/lax_scipy_spectral_dac_test.py +++ b/tests/lax_scipy_spectral_dac_test.py @@ -18,7 +18,7 @@ from jax import lax from jax import numpy as jnp from jax._src import test_util as jtu -from jax._src.lax import eigh as lax_eigh +from jax._src.tpu.linalg import eigh as lax_eigh from absl.testing import absltest diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py index 388d053d9608..fbc74cc6d072 100644 --- a/tests/lax_scipy_test.py +++ b/tests/lax_scipy_test.py @@ -113,10 +113,12 @@ def testLogSumExp(self, shapes, dtype, axis, keepdims, return_sign, use_b): if jnp.issubdtype(dtype, jnp.complexfloating) and scipy_version < (1, 13, 0): self.skipTest("logsumexp of complex input uses scipy 1.13.0 semantics.") - if not jtu.test_device_matches(["cpu"]): - rng = jtu.rand_some_inf_and_nan(self.rng()) - else: - rng = jtu.rand_default(self.rng()) + if use_b and scipy_version >= (1, 15) and scipy_version < (1, 15, 3): + self.skipTest( + "TODO(https://github.com/scipy/scipy/issues/22903): logsumexp with a" + " b scale array is buggy in scipy 1.15" + ) + rng = jtu.rand_default(self.rng()) # TODO(mattjj): test autodiff if use_b: def scipy_fun(array_to_reduce, scale_array): @@ -189,6 +191,11 @@ def testLogSumExpNans(self): result = lsp_special.logsumexp(1.0, b=1.0) self.assertEqual(result, 1.0) + def testLogSumExpInfs(self): + out, sign = lsp_special.logsumexp(jnp.array([1.0, np.inf]), return_sign=True) + self.assertEqual(out, np.inf) + self.assertEqual(sign, 1.0) + @jtu.sample_product( shape=[(0,), (1,), (2,), (3,), (4,), (5,)], dtype=float_dtypes, @@ -338,9 +345,10 @@ def scipy_fun(z): dtype=float_dtypes, ) @jtu.ignore_warning(category=DeprecationWarning, message=".*scipy.special.lpmn.*") + @unittest.skipIf(scipy_version >= (1, 17, 0), "scipy.special.lpmn has been removed.") def testLpmn(self, l_max, shape, dtype): - if jtu.is_device_tpu(6, "e"): - self.skipTest("TODO(b/364258243): fails on TPU v6e") + if jtu.is_device_tpu_at_least(6): + self.skipTest("TODO(b/364258243): fails on TPU v6+") rng = jtu.rand_uniform(self.rng(), low=-0.2, high=0.9) args_maker = lambda: [rng(shape, dtype)] @@ -361,6 +369,7 @@ def scipy_fun(z, m=l_max, n=l_max): dtype=float_dtypes, ) @jtu.ignore_warning(category=DeprecationWarning, message=".*scipy.special.lpmn.*") + @unittest.skipIf(scipy_version >= (1, 17, 0), "scipy.special.lpmn_values has been removed.") def testNormalizedLpmnValues(self, l_max, shape, dtype): rng = jtu.rand_uniform(self.rng(), low=-0.2, high=0.9) args_maker = lambda: [rng(shape, dtype)] @@ -390,8 +399,11 @@ def scipy_fun(z, m=l_max, n=l_max): @jtu.ignore_warning(category=DeprecationWarning, message=".*scipy.special.sph_harm.*") + @unittest.skipIf(scipy_version >= (1, 17, 0), "scipy.special.sph_harm has been removed.") @jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion def testSphHarmAccuracy(self): + if not hasattr(lsp_special, 'sph_harm'): + self.skipTest("jax.scipy.special.sph_harm has been removed.") m = jnp.arange(-3, 3)[:, None] n = jnp.arange(3, 6) n_max = 5 @@ -406,9 +418,12 @@ def testSphHarmAccuracy(self): @jtu.ignore_warning(category=DeprecationWarning, message=".*scipy.special.sph_harm.*") + @unittest.skipIf(scipy_version >= (1, 17, 0), "scipy.special.sph_harm has been removed.") @jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion def testSphHarmOrderZeroDegreeZero(self): """Tests the spherical harmonics of order zero and degree zero.""" + if not hasattr(lsp_special, 'sph_harm'): + self.skipTest("jax.scipy.special.sph_harm has been removed.") theta = jnp.array([0.3]) phi = jnp.array([2.3]) n_max = 0 @@ -421,9 +436,12 @@ def testSphHarmOrderZeroDegreeZero(self): @jtu.ignore_warning(category=DeprecationWarning, message=".*scipy.special.sph_harm.*") + @unittest.skipIf(scipy_version >= (1, 17, 0), "scipy.special.sph_harm has been removed.") @jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion def testSphHarmOrderZeroDegreeOne(self): """Tests the spherical harmonics of order one and degree zero.""" + if not hasattr(lsp_special, 'sph_harm'): + self.skipTest("jax.scipy.special.sph_harm has been removed.") theta = jnp.array([2.0]) phi = jnp.array([3.1]) n_max = 1 @@ -436,9 +454,12 @@ def testSphHarmOrderZeroDegreeOne(self): @jtu.ignore_warning(category=DeprecationWarning, message=".*scipy.special.sph_harm.*") + @unittest.skipIf(scipy_version >= (1, 17, 0), "scipy.special.sph_harm has been removed.") @jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion def testSphHarmOrderOneDegreeOne(self): """Tests the spherical harmonics of order one and degree one.""" + if not hasattr(lsp_special, 'sph_harm'): + self.skipTest("jax.scipy.special.sph_harm has been removed.") theta = jnp.array([2.0]) phi = jnp.array([2.5]) n_max = 1 @@ -458,11 +479,14 @@ def testSphHarmOrderOneDegreeOne(self): ) @jtu.ignore_warning(category=DeprecationWarning, message=".*scipy.special.sph_harm.*") + @unittest.skipIf(scipy_version >= (1, 17, 0), "scipy.special.sph_harm has been removed.") @jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion def testSphHarmForJitAndAgainstNumpy(self, l_max, num_z, dtype): """Tests against JIT compatibility and Numpy.""" - if jtu.is_device_tpu(6, "e"): - self.skipTest("TODO(b/364258243): fails on TPU v6e") + if not hasattr(lsp_special, 'sph_harm'): + self.skipTest("jax.scipy.special.sph_harm has been removed.") + if jtu.is_device_tpu_at_least(6): + self.skipTest("TODO(b/364258243): fails on TPU v6+") n_max = l_max shape = (num_z,) rng = jtu.rand_int(self.rng(), -l_max, l_max + 1) @@ -484,9 +508,12 @@ def args_maker(): @jtu.ignore_warning(category=DeprecationWarning, message=".*scipy.special.sph_harm.*") + @unittest.skipIf(scipy_version >= (1, 17, 0), "scipy.special.sph_harm has been removed.") @jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion def testSphHarmCornerCaseWithWrongNmax(self): """Tests the corner case where `n_max` is not the maximum value of `n`.""" + if not hasattr(lsp_special, 'sph_harm'): + self.skipTest("jax.scipy.special.sph_harm has been removed.") m = jnp.array([2]) n = jnp.array([10]) n_clipped = jnp.array([6]) @@ -508,8 +535,8 @@ def testSphHarmCornerCaseWithWrongNmax(self): ) @jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion def testSphHarmY(self, l_max, num_z, dtype): - if jtu.is_device_tpu(6, "e"): - self.skipTest("TODO(b/364258243): fails on TPU v6e") + if jtu.is_device_tpu_at_least(6): + self.skipTest("TODO(b/364258243): fails on TPU v6+") n_max = l_max shape = (num_z,) rng = jtu.rand_int(self.rng(), -l_max, l_max + 1) @@ -641,7 +668,7 @@ def test_spence(self, shape, dtype): ], dtype=float_dtypes + int_dtypes, ) - @jtu.skip_on_devices("tpu") # TODO(jakevdp): fix and reenable this test. + @jtu.skip_on_devices("tpu") # TODO(jakevdp): fix and re-enable this test. @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. def testIntegrateTrapezoid(self, yshape, xshape, dtype, dx, axis): rng = jtu.rand_default(self.rng()) diff --git a/tests/lax_test.py b/tests/lax_test.py index 8764caeb2e49..7d189b50dbaf 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -28,17 +28,16 @@ import numpy as np import jax -from jax._src import core +from jax import export from jax import jvp, grad from jax import lax import jax.numpy as jnp from jax.test_util import check_grads -import jax.util from jax.interpreters import batching -from jax.interpreters import xla from jax._src import array from jax._src import config +from jax._src import core from jax._src import dtypes from jax._src import lax_reference from jax._src import test_util as jtu @@ -47,9 +46,9 @@ from jax._src.interpreters import pxla from jax._src.internal_test_util import lax_test_util from jax._src.lax import lax as lax_internal -from jax._src.util import NumpyComplexWarning, safe_zip +from jax._src.lax import utils as lax_utils +from jax._src.util import safe_zip from jax._src.tree_util import tree_map -from jax._src.lib import xla_extension_version config.parse_flags_with_absl() @@ -136,6 +135,25 @@ def testOpAgainstNumpy(self, op_name, rng_factory, shapes, dtype, tol): tol = jtu.join_tolerance(tol, 2e-15) self._CheckAgainstNumpy(numpy_op, op, args_maker, tol=tol) + @parameterized.parameters(["logistic", "tanh"]) + def testEvenFunctionGrads(self, op_name): + op = getattr(lax, op_name) + x = jnp.arange(0.0, 80.0, 1.0, dtype=jnp.float32) + high_acc_op = lambda x: op(x, accuracy=lax.AccuracyMode.HIGHEST) + grads = jax.vmap(jax.grad(high_acc_op))(x) + neg_grads = jax.vmap(jax.grad(high_acc_op))(-x) + self.assertAllClose( + grads, neg_grads, atol=jtu.default_tolerance()[np.dtype(np.float32)], rtol=0.0 + ) + + def testExpm1Grad(self): + x = jnp.arange(-80.0, 80.0, 1.0, dtype=jnp.float32) + expected = jax.vmap(jax.grad(lambda x: lax.exp(x, accuracy=lax.AccuracyMode.HIGHEST)))(x) + actual = jax.vmap(jax.grad(lambda x: lax.expm1(x, accuracy=lax.AccuracyMode.HIGHEST)))(x) + self.assertAllClose( + actual, expected, atol=jtu.default_tolerance()[np.dtype(np.float32)], rtol=0.0 + ) + # TODO test shift_left, shift_right_arithmetic, shift_right_logical @jtu.sample_product( @@ -147,7 +165,12 @@ def testOpAgainstNumpy(self, op_name, rng_factory, shapes, dtype, tol): def testConvertElementType(self, from_dtype, to_dtype, weak_type): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng((2, 3), from_dtype)] - op = lambda x: lax_internal._convert_element_type(x, to_dtype, weak_type) + to_dtype_canonicalized = ( + dtypes.canonicalize_dtype(to_dtype) if to_dtype is not None else None + ) + op = lambda x: lax_internal._convert_element_type( + x, to_dtype_canonicalized, weak_type + ) self._CompileAndCheck(op, args_maker) x = rng((1,), from_dtype) @@ -178,8 +201,8 @@ def testConvertElementTypeAgainstNumpy(self, from_dtype, to_dtype): ) def testBitcastConvertType(self, from_dtype, to_dtype, shape): rng = jtu.rand_default(self.rng()) - nbits_in = dtypes.bit_width(from_dtype) - nbits_out = dtypes.bit_width(to_dtype) + nbits_in = dtypes.itemsize_bits(from_dtype) + nbits_out = dtypes.itemsize_bits(to_dtype) if nbits_in < nbits_out: shape = (*shape, nbits_out // nbits_in) args_maker = lambda: [rng(shape, from_dtype)] @@ -206,8 +229,8 @@ def testBitcastConvertType(self, from_dtype, to_dtype, shape): shape=[(4,), (2, 4), (2, 3, 4)] ) def testBitcastConvertTypeAgainstNumpy(self, from_dtype, to_dtype, shape): - nbits_in = dtypes.bit_width(from_dtype) - nbits_out = dtypes.bit_width(to_dtype) + nbits_in = dtypes.itemsize_bits(from_dtype) + nbits_out = dtypes.itemsize_bits(to_dtype) if nbits_in < nbits_out: shape = (*shape, nbits_out // nbits_in) rng = jtu.rand_default(self.rng()) @@ -224,7 +247,7 @@ def testBitcastConvertTypeAgainstNumpy(self, from_dtype, to_dtype, shape): ) def testBitcastConvertWeakType(self, from_dtype, to_dtype, weak_type): rng = jtu.rand_default(self.rng()) - x_in = lax_internal._convert_element_type(rng((2, 3), from_dtype), + x_in = lax_internal._convert_element_type(rng((2, 3), np.dtype(from_dtype)), weak_type=weak_type) op = lambda x: lax.bitcast_convert_type(x, to_dtype) self.assertEqual(dtypes.is_weakly_typed(x_in), weak_type) @@ -864,6 +887,9 @@ def _conv_transpose_via_grad(data, kernel, strides, padding, for i in range(nspatial)] elif padding == 'SAME': o_sdims = [in_sdims[i]*strides[i] for i in range(nspatial)] + else: + o_sdims = [in_sdims[i]*strides[i] + max(e_k_sdims[i]-strides[i],0) - np.sum(p) + for i, p in enumerate(padding)] o_shape = [in_shape[0], k_shape[1]] + o_sdims out_spec_inv = [x[0] for x in sorted(enumerate(dn.out_spec), key=lambda x: x[1])] @@ -899,7 +925,9 @@ def _transpose_conv_kernel(data, kernel, dimension_numbers): ], dtype=lax_test_util.float_dtypes, strides=[(1, 1), (1, 2), (2, 1), (2, 2), (3, 3)], - padding=["VALID", "SAME"], + padding=list(itertools.product( + itertools.product([0,1,2], [0,1,2]), + itertools.product([0,1,2], [0,1,2]))) + ["VALID", "SAME"], dspec=[ ("NHWC", "HWIO", "NHWC"), ], @@ -917,7 +945,8 @@ def fun(lhs, rhs): return lax.conv_transpose(lhs, rhs, strides, padding, rhs_dilation=rhs_dilation, dimension_numbers=dspec, - transpose_kernel=True) + transpose_kernel=True, + use_consistent_padding=True) def fun_via_grad(lhs, rhs): return self._conv_transpose_via_grad(lhs, rhs, strides, padding, @@ -939,7 +968,9 @@ def fun_via_grad(lhs, rhs): ], dtype=lax_test_util.float_dtypes, strides=[(1, 1), (1, 2), (2, 1), (2, 2), (3, 3)], - padding=["VALID", "SAME"], + padding=list(itertools.product( + itertools.product([0,1,2], [0,1,2]), + itertools.product([0,1,2], [0,1,2]))) + ["VALID", "SAME"], dspec=[ ("NHWC", "HWIO", "NHWC"), ], @@ -955,7 +986,8 @@ def fun(lhs, rhs): return lax.conv_transpose(lhs, rhs, strides, padding, rhs_dilation=rhs_dilation, dimension_numbers=dspec, - transpose_kernel=False) + transpose_kernel=False, + use_consistent_padding=True) def fun_via_grad(lhs, rhs): rhs_t = self._transpose_conv_kernel(lhs, rhs, dimension_numbers=dspec) @@ -1079,6 +1111,13 @@ def testDot(self, lhs_shape, rhs_shape, lhs_dtype, rhs_dtype, precision): args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] self._CompileAndCheck(partial(lax.dot, precision=precision), args_maker) + def testDotPositionalArgumentDeprecation(self): + lhs = jnp.arange(5.0) + rhs = jnp.arange(5.0) + + with self.assertRaisesRegex(TypeError, r"dot\(\) takes 2 positional arguments"): + lax.dot(lhs, rhs, lax.Precision.DEFAULT) + @parameterized.parameters([ (algorithm, dtype) for algorithm, test_dtypes in [ @@ -1128,11 +1167,6 @@ def testDotAlgorithm(self, algorithm, dtype): raise SkipTest( f"The dot algorithm '{algorithm}' is not supported on CPU.") if jtu.test_device_matches(["gpu"]): - if (algorithm == lax.DotAlgorithmPreset.BF16_BF16_F32_X9 and - xla_extension_version < 320): - raise SkipTest( - f"The dot algorithm ${algorithm} requires XLA extension version " - ">= 320.") # GPU algorithm support is a little spotty. It is checked in # xla/service/algorithm_util.cc and the logic is copied here. if algorithm in { @@ -1143,7 +1177,7 @@ def testDotAlgorithm(self, algorithm, dtype): lax.DotAlgorithmPreset.BF16_BF16_F32_X6, lax.DotAlgorithmPreset.BF16_BF16_F32_X9, }: - if not jtu.is_cuda_compute_capability_at_least("8.0"): + if jtu.test_device_matches(["cuda"]) and not jtu.is_cuda_compute_capability_at_least("8.0"): raise SkipTest( f"The dot algorithm '{algorithm}' requires CUDA compute " "capability >= 8.0.") @@ -1157,9 +1191,6 @@ def testDotAlgorithm(self, algorithm, dtype): raise SkipTest( f"The dot algorithm '{algorithm}' is not supported on GPU.") if jtu.test_device_matches(["tpu"]): - # TODO(apaszke): Remove after 12 weeks have passed. - if not jtu.if_cloud_tpu_at_least(2024, 12, 19): - self.skipTest("Requires libtpu built after 2024-12-19") if algorithm not in { lax.DotAlgorithmPreset.DEFAULT, lax.DotAlgorithmPreset.BF16_BF16_F32, @@ -1173,7 +1204,8 @@ def testDotAlgorithm(self, algorithm, dtype): rhs_shape = (4, 3) rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] - self._CompileAndCheck(partial(lax.dot, precision=algorithm), args_maker) + self._CompileAndCheck(partial(lax.dot, precision=algorithm), args_maker, + rtol={np.float64: 3e-15}) self.assertEqual(lax.dot(*args_maker(), precision=algorithm).dtype, dtype) def testDotAlgorithmInvalidFloat8Type(self): @@ -1232,7 +1264,8 @@ def testDotPreferredElement(self, lhs_shape, rhs_shape, dtype, preferred_element_type): if (not config.enable_x64.value and (dtype == np.float64 or preferred_element_type == np.float64 - or dtype == np.int64 or preferred_element_type == np.int64)): + or dtype == np.int64 or preferred_element_type == np.int64 + or dtype == np.complex128 or preferred_element_type == np.complex128)): raise SkipTest("64-bit mode disabled") if (jtu.test_device_matches(["tpu"]) and (dtype == np.complex128 or preferred_element_type == np.complex128)): @@ -1485,6 +1518,28 @@ def testBroadcastInDimAgainstNumpy(self, inshape, dtype, outshape, dimensions): numpy_op = lambda x: lax_reference.broadcast_in_dim(x, outshape, dimensions) self._CheckAgainstNumpy(numpy_op, op, args_maker) + @jtu.sample_product( + [ + dict(arg_shape=arg_shape, reps=reps) + for arg_shape, reps in [ + [(3,), (2,)], + [(2, 3), (1, 0)], + [(2, 3), (1, 2)], + [(2, 3), (2, 1)], + [(2, 1, 3), (1, 2, 3)], + [(1, 1, 4), (1, 3, 1)], + ] + ], + dtype=lax_test_util.default_dtypes, + ) + def testTile(self, arg_shape, reps, dtype): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(arg_shape, dtype)] + op = lambda x: lax.tile(x, reps) + numpy_op = lambda x: np.tile(x, reps) + self._CompileAndCheck(op, args_maker) + self._CheckAgainstNumpy(numpy_op, op, args_maker) + @parameterized.parameters( {"inshape": inshape, "dimensions": dimensions, "error_type": error_type, "err_msg": err_msg} @@ -1632,6 +1687,32 @@ def testPadErrors(self): with self.assertRaisesRegex(ValueError, "Dimension size after padding is not at least 0"): lax.pad(np.zeros(2), 0., [(-4, 0, 1)]) + @jtu.sample_product( + [dict(in_shape=in_shape, window_shape=window_shape, + window_strides=window_strides, padding=padding) + for in_shape, window_shape, window_strides, padding in [ + ((10, 10), (5, 5), (1, 1), 'SAME'), + ((8, 8), (3, 3), (2, 2), 'SAME_LOWER'), + ((7, 7), (3, 3), (2, 2), 'VALID'), + ] + ], + ) + def testPadtypeToPadsReturnsInts(self, in_shape, window_shape, window_strides, + padding): + """Test that padtype_to_pads returns Python ints, not NumPy scalars.""" + in_shape_arr = np.array(in_shape, dtype=np.int64) + window_shape_arr = np.array(window_shape, dtype=np.int64) + window_strides_arr = np.array(window_strides, dtype=np.int64) + + pads = lax.padtype_to_pads(in_shape_arr, window_shape_arr, + window_strides_arr, padding) + + for i, (low, high) in enumerate(pads): + self.assertIsInstance(low, int, + f"Padding dimension {i} low value is {type(low)}, expected int") + self.assertIsInstance(high, int, + f"Padding dimension {i} high value is {type(high)}, expected int") + def testReverse(self): rev = jax.jit(lambda operand: lax.rev(operand, dimensions)) @@ -1984,7 +2065,6 @@ def reference_fun(operand): self._CompileAndCheck(lax_fun, args_maker) self._CheckAgainstNumpy(reference_fun, lax_fun, args_maker) - @jtu.sample_product( op=["add", "mul"], op_namespace=[lax, operator], @@ -1993,9 +2073,10 @@ def reference_fun(operand): ) def testReduceWeakType(self, op_namespace, op, arr_weak_type, init_weak_type): op = getattr(op_namespace, op) - arr = lax_internal._convert_element_type(np.arange(10), int, + arr = lax_internal._convert_element_type(np.arange(10), dtypes.dtype(int), weak_type=arr_weak_type) - init = lax_internal._convert_element_type(1, int, weak_type=init_weak_type) + init = lax_internal._convert_element_type(1, dtypes.dtype(int), + weak_type=init_weak_type) fun = lambda arr, init: lax.reduce(arr, init, op, (0,)) out = fun(arr, init) self.assertEqual(dtypes.is_weakly_typed(out), arr_weak_type and init_weak_type) @@ -2132,7 +2213,7 @@ def fun(operand): ) ], ) - @jtu.skip_on_devices('gpu') # jax.lax.mul has an XLA bug on GPU b/339071103 + @jtu.skip_on_devices('cuda') # jax.lax.mul has an XLA bug on CUDA GPU b/339071103 @jtu.skip_on_devices('tpu') # b/39342488 def testReduceWindowGeneralJVP( self, @@ -2228,7 +2309,7 @@ def fun2(operand): ) ], ) - @jtu.skip_on_devices('gpu') # jax.lax.mul has an XLA bug on GPU b/339071103 + @jtu.skip_on_devices('cuda') # jax.lax.mul has an XLA bug on CUDA GPU b/339071103 @jtu.skip_on_devices('tpu') # b/39342488 def testReduceWindowCustomSameAsMonoid( self, @@ -2351,7 +2432,7 @@ def fun(op_, operand_): ], dtype=[np.float32], ) - @jtu.skip_on_devices('gpu') + @jtu.skip_on_devices('cuda') def testReduceWindowVariadic(self, dtype, shape, dims, strides, padding, base_dilation, window_dilation): if (jtu.test_device_matches(["tpu"]) and @@ -2427,7 +2508,7 @@ def testReduceWindowShapeDilation(self, shape, window_dimensions, window_dimensions=window_dimensions) # With a stride of 1 in each direction and a padding of 'SAME', the # shape of the input should be equal to the shape of the result according - # to https://www.tensorflow.org/xla/operation_semantics#reducewindow. + # to https://www.openxla.org/xla/operation_semantics#reducewindow. self.assertEqual(shape, result.shape) def testReduceWindowWithEmptyOutput(self): @@ -2530,8 +2611,8 @@ def testSortFloatSpecialValues(self, dtype): # - NaNs are sorted to the end, regardless of representation # - sign bit of 0.0 is ignored x = jnp.array([-np.inf, 0.0, -0.0, np.inf, np.nan, -np.nan], dtype=dtype) - index = lax.iota(dtypes.int_, x.size) - argsort = lambda x: lax.sort_key_val(x, lax.iota(dtypes.int_, x.size), is_stable=True)[1] + index = lax.iota(int, x.size) + argsort = lambda x: lax.sort_key_val(x, lax.iota(int, x.size), is_stable=True)[1] self.assertArraysEqual(argsort(x), index) self.assertArraysEqual(jax.jit(argsort)(x), index) @@ -2617,25 +2698,24 @@ def args_maker(): @jtu.sample_product( dtype=[np.float32, np.int32, np.uint32], - shape=[(20,), (5, 20), (2000,)], - k=[1, 3, 12], - negative=[False, True] + shape=[(20,), (8, 20), (2000,)], + k=[1, 3, 8], + axis=[0, -1] ) - def testTopK(self, shape, dtype, k, negative): + def testTopK(self, shape, dtype, k, axis): + rng = jtu.rand_some_equal(self.rng()) def args_maker(): - flat_values = np.arange(math.prod(shape), dtype=dtype) - values = self.rng().permutation(flat_values).reshape(shape) - if negative: - values = -values - return [values] - def reference_top_k(x): - bcast_idxs = np.broadcast_to(np.arange(shape[-1], dtype=np.int32), shape) - sorted_vals, sorted_idxs = lax_reference.sort_key_val(x, bcast_idxs) - return sorted_vals[..., :-k-1:-1], sorted_idxs[..., :-k-1:-1] - op = lambda vs: lax.top_k(vs, k=k) - self._CheckAgainstNumpy(op, reference_top_k, args_maker) + return [rng(shape, dtype)] + op = lambda vs: lax.top_k(vs, k=k, axis=axis) + ref_op = lambda vs: lax_reference.top_k(vs, k=k, axis=axis) + self._CheckAgainstNumpy(op, ref_op, args_maker) self._CompileAndCheck(op, args_maker) + def testTopKOverflow(self): + x = jax.ShapeDtypeStruct((2 ** 31 + 1,), np.dtype('bfloat16')) + with self.assertRaisesRegex(ValueError, "top_k returns int32 indices, which will overflow"): + jax.eval_shape(lambda x: jax.lax.top_k(x, 100), x) + @jtu.sample_product( [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape) for lhs_shape, rhs_shape in [((3, 2), (2, 4)), @@ -2717,7 +2797,12 @@ def testIndexTake(self, shape, dtype, idxs, axes): offset_dims=(2,), collapsed_slice_dims=(), start_index_map=(2,), operand_batching_dims=(0, 1), start_indices_batching_dims=(1, 0)), - (1, 1, 3)) + (1, 1, 3)), + # This test verifies that we allow slice sizes that would not fit in + # the operand if indices were empty. This is a useful base case. + ((0,), np.zeros((0, 1), dtype=np.int32), lax.GatherDimensionNumbers( + offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)), + (1,)), ]], dtype=lax_test_util.all_dtypes, ) @@ -3527,11 +3612,11 @@ def test_const(self, dtype, weak_type): if dtype in set(lax_test_util.python_scalar_types): val = dtype(0) else: - val = lax_internal._convert_element_type(0, dtype, weak_type=weak_type) + val = lax_internal._convert_element_type(0, np.dtype(dtype), + weak_type=weak_type) const = lax_internal._const(val, 0) - self.assertEqual(dtypes.dtype(val, canonicalize=True), - dtypes.dtype(const, canonicalize=True)) + self.assertEqual(dtypes.dtype(val), dtypes.dtype(const)) def testIgammaSpecial(self): self.assertEqual(lax.igamma(1., np.inf), 1.) @@ -3627,6 +3712,126 @@ def f(x): g = jax.grad(f)(5.) # doesn't crash self.assertAllClose(g, 3., check_dtypes=False) + def test_shape_as_value_handles_static_shapes(self): + result = lax.shape_as_value(()) + self.assertArraysEqual(result, lax.full((0,), np.array(0, np.int32))) + + result = lax.shape_as_value((2,)) + self.assertArraysEqual(result, np.asarray((2,), np.int32)) + + result = lax.shape_as_value((2, 3)) + self.assertArraysEqual(result, np.asarray((2, 3), np.int32)) + + def test_shape_as_value_handles_polymorphic_shapes(self): + @jax.jit + def f(x): + return lax.shape_as_value(x.shape) + + exported = export.export(f)( + jax.ShapeDtypeStruct(export.symbolic_shape("a"), jnp.float32) + ) + result = exported.call(np.ones((1), dtype=np.float32)) + self.assertArraysEqual(result, np.asarray((1,), np.int64)) + result = exported.call(np.ones((2), dtype=np.float32)) + self.assertArraysEqual(result, np.asarray((2,), np.int64)) + + exported = export.export(f)( + jax.ShapeDtypeStruct(export.symbolic_shape("a, b"), jnp.float32) + ) + result = exported.call(np.ones((1, 2), dtype=np.float32)) + self.assertArraysEqual(result, np.asarray((1, 2), np.int64)) + result = exported.call(np.ones((3, 4), dtype=np.float32)) + self.assertArraysEqual(result, np.asarray((3, 4), np.int64)) + + @jtu.sample_product( + name = ['abs'], + dtype = ['int4', 'uint4'], + ) + def test_int4_non_support_errors(self, name, dtype): + func = getattr(lax, name) + arg = lax.iota(dtype, 3) + with self.assertRaisesRegex(TypeError, f'{name} does not accept dtype {dtype}.'): + func(arg) + + @jtu.sample_product( + name = ['bitwise_not', 'neg', 'sign'], + dtype = ['int4', 'uint4'], + ) + def test_int4_unary_ops(self, name, dtype): + func = getattr(lax, name) + rng = jtu.rand_default(self.rng()) + x = rng(3, dtype) + actual = func(x) + expected = func(x.astype('int8')).astype(dtype) + self.assertArraysEqual(actual, expected, check_dtypes=True) + + @jtu.sample_product( + name = ['add', 'sub', 'mul', 'div', 'rem', 'max', 'min', + 'shift_left', 'shift_right_arithmetic', 'shift_right_logical', + 'bitwise_and', 'bitwise_or', 'bitwise_xor', + 'eq', 'ne', 'gt', 'ge', 'lt', 'le'], + dtype = ['int4', 'uint4'], + ) + def test_int4_binary_ops(self, name, dtype): + func = getattr(lax, name) + rng = jtu.rand_default(self.rng()) + x, y = rng(3, dtype), rng(3, dtype) + actual = func(x, y) + expected = func(x.astype('int8'), y.astype('int8')) + if expected.dtype == 'int8': + expected = expected.astype(dtype) + self.assertArraysEqual(actual, expected, check_dtypes=True) + + def test_gather_with_asymmetric_dtype(self): + @jax.custom_vjp + def f(x): + return x + + def f_fwd(x): + return f(x), () + + def f_bwd(res, g): + del res + return g.astype(jnp.bfloat16), + + f.defvjp(f_fwd, f_bwd) + + def g(x): + idx = jnp.argsort(x) + x = x.at[idx].get() + return f(x) + + x = jnp.arange(8, dtype=jnp.float8_e4m3fn) + _, vjp_fn = jax.vjp(g, x) + cts = vjp_fn(jnp.ones((8,), dtype=jnp.float8_e4m3fn)) # Don't crash + self.assertEqual(cts[0].dtype, jnp.bfloat16) + + def test_stop_gradient_on_ints(self): + # https://github.com/jax-ml/jax/issues/33689 + @jax.custom_gradient + def f(x): + def fbwd(g): + return jnp.ones_like(x) + return (x, jnp.round(x).astype(jnp.int32)), fbwd + + def loss(x): + y, i = f(x) + y_nograd, i_nograd = jax.lax.stop_gradient((y, i)) + self.assertEqual(type(y_nograd), type(i_nograd)) + return jnp.sum(f(y)[0]) + + jax.grad(loss)(jnp.ones((3,))) + + def test_no_complex_to_real_cast_warning_in_transpose(self): + # https://github.com/jax-ml/jax/issues/33521 + def f(x, y): + return jax.lax.dot(x, y).real + + x = jnp.arange(5, dtype='float32') + y = jnp.arange(5, dtype='complex64') + with self.assertNoWarnings(): + jax.jacobian(f)(x, y) + class LazyConstantTest(jtu.JaxTestCase): def _Check(self, make_const, expected): @@ -3718,7 +3923,7 @@ def testConvertElementReturnType(self, input_type, dtype, value, jit): @jtu.sample_product( dtype_in=lax_test_util.all_dtypes, dtype_out=lax_test_util.all_dtypes) - @jtu.ignore_warning(category=NumpyComplexWarning) + @jtu.ignore_warning(category=np.exceptions.ComplexWarning) def testConvertElementTypeAvoidsCopies(self, dtype_in, dtype_out): x = jax.device_put(np.zeros(5, dtype_in)) self.assertEqual(x.dtype, dtype_in) @@ -3837,14 +4042,11 @@ def handler(_, buf): @staticmethod def global_sharded_result_handler(aval, out_sharding, committed): - def handler(arr): - from jax._src.array import ArrayImpl - if isinstance(arr, ArrayImpl): - buf, = arr._arrays - else: - buf, = arr - return FooArray(aval.shape, buf) - return handler + phys_sharding = out_sharding # unlike KeyTyRules, assume same shape + phys_aval = core.physical_aval(aval) + phys_handler_maker = pxla.global_result_handlers[core.ShapedArray] + phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed) + return phys_handler.wrap(lambda arr: FooArray(aval.shape, arr)) class FooTy(dtypes.ExtendedDType): @@ -3916,8 +4118,8 @@ def shard_foo_array_handler(xs, shardings, layouts, copy_semantics): aval, jax.sharding.SingleDeviceSharding(device), [x.data], [device])) return results -def foo_array_constant_handler(x): - return array._array_mlir_constant_handler(x.data) +def foo_array_constant_handler(x, aval): + return array._array_mlir_constant_handler(x.data, aval) def make_lowering(*, shape): return jnp.zeros((*shape, 2), 'uint32') @@ -3948,7 +4150,7 @@ class CustomElementTypesTest(jtu.JaxTestCase): def setUp(self): core.pytype_aval_mappings[FooArray] = \ lambda x: core.ShapedArray(x.shape, FooTy(), sharding=None) - xla.canonicalize_dtype_handlers[FooArray] = lambda x: x + dtypes.canonicalize_value_handlers[FooArray] = lambda x: x pxla.shard_arg_handlers[FooArray] = shard_foo_array_handler mlir._constant_handlers[FooArray] = foo_array_constant_handler mlir.register_lowering(make_p, mlir.lower_fun(make_lowering, False)) @@ -3960,7 +4162,7 @@ def setUp(self): def tearDown(self): del core.pytype_aval_mappings[FooArray] - del xla.canonicalize_dtype_handlers[FooArray] + del dtypes.canonicalize_value_handlers[FooArray] del mlir._constant_handlers[FooArray] del mlir._lowerings[make_p] del mlir._lowerings[bake_p] @@ -4074,6 +4276,7 @@ def test_scan_jaxpr(self): b, = e.outvars self.assertEqual(b.aval, core.ShapedArray((3, 4), FooTy())) + @unittest.skip('removed split_transpose') def test_scan_jaxpr_split_transpose(self): def stage(x, w): x = x @ w @@ -4369,7 +4572,7 @@ def _testOnComplexPlaneWorker(self, name, dtype, kind): # # In addition, the 1/3 middle parts of regions q1, q2, q3, q4, # neg, pos are tested separately as these don't contain extremely - # small or extremelly large values and functions on these regions + # small or extremely large values and functions on these regions # ought not to possess any incorrectness issues. s0, s1 = size_re, size_im @@ -4406,7 +4609,7 @@ def _testOnComplexPlaneWorker(self, name, dtype, kind): # return values) to (i) workaround numpy 1.x assert_allclose bug # in comparing complex infinities, and (ii) expose more details # about failing cases: - s_dict_parts = dict() + s_dict_parts = {} for k, v in s_dict.items(): s_dict_parts[k + '.real'] = v s_dict_parts[k + '.imag'] = v @@ -4618,7 +4821,7 @@ def my_square_impl(x): def test_composite_with_attributes(self): # The static_argnames is required here since k is a constant that should # come out of a larger context, but we unit test one op (composite) here. - @partial(jax.jit, static_argnames=['k']) + @jax.jit(static_argnames=['k']) @partial(lax.composite, name="my.top_k") def my_top_k(x, *, k): return lax.top_k(x, k) @@ -4747,7 +4950,7 @@ def my_square(x): ValueError, "JVP rule for composite not implemented. You can use `jax.custom_jvp` " "to add support. See " - "https://jax.readthedocs.io/en/latest/_autosummary/jax.custom_jvp.html" + "https://docs.jax.dev/en/latest/_autosummary/jax.custom_jvp.html" ): jvp(my_square, (1.0,), (2.0,)) @@ -4760,7 +4963,7 @@ def my_square(x): ValueError, "JVP rule for composite not implemented. You can use `jax.custom_jvp` " "to add support. See " - "https://jax.readthedocs.io/en/latest/_autosummary/jax.custom_jvp.html" + "https://docs.jax.dev/en/latest/_autosummary/jax.custom_jvp.html" ): grad(my_square)(1.0) @@ -4773,11 +4976,15 @@ def my_consts(x, /, *, scale): x = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32) self.assertAllClose(my_consts(x, scale=scale), jnp.round(x / scale)) - # The constant must not appear as an extra input argument to the composite. mlir_module = jax.jit(partial(my_consts, scale=scale)).lower(x).as_text() - self.assertIn( - "@my.consts(%arg0: tensor<3xf32>) -> tensor<3xf32>", mlir_module - ) + if config.use_simplified_jaxpr_constants.value: + self.assertIn( + "@my.consts(%arg0: tensor<3xf32> {jax.const = true}, %arg1: tensor<3xf32>) -> tensor<3xf32>", mlir_module + ) + else: + self.assertIn( + "@my.consts(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32>", mlir_module + ) def test_composite_with_tracer_consts(self): def fun(x, scale): @@ -4800,14 +5007,7 @@ def my_consts(y): class RaggedTest(jtu.JaxTestCase): - @jtu.sample_product( - [ - {'m': 5, 'k': 4, 'n': 3, 'num_groups': 1}, - {'m': 10, 'k': 9, 'n': 8, 'num_groups': 2}, - ], - dtype=jtu.dtypes.numeric, - ) - def test_ragged_dot(self, m, k, n, num_groups, dtype): + def _test_ragged_dot(self, m, k, n, num_groups, dtype): """Tests ragged_dot. The ragged_dot is tested against numpy reference implementation, and by @@ -4816,6 +5016,8 @@ def test_ragged_dot(self, m, k, n, num_groups, dtype): Raises: SkipTest: in the case dtype is not supported. """ + if (dtype == np.float16): + raise SkipTest(f"unsupported dtype for ragged_dot: {dtype}") lhs_shape = (m, k) rhs_shape = (num_groups, k, n) @@ -4837,6 +5039,50 @@ def group_sizes(m, num_groups): self._CheckAgainstNumpy( lax_reference.ragged_dot, lax.ragged_dot, args_maker) + @jtu.sample_product( + [ + {"m": 64, "k": 4, "n": 3, "num_groups": 1}, + {"m": 64, "k": 9, "n": 8, "num_groups": 2}, + ], + dtype=jtu.dtypes.all_floating, + ) + def test_ragged_dot(self, m, k, n, num_groups, dtype): + return self._test_ragged_dot(m, k, n, num_groups, dtype) + + @parameterized.parameters([True, False]) + def test_ragged_dot_use_ragged_dot_instruction(self, use_instruction): + with config.jax_ragged_dot_use_ragged_dot_instruction(use_instruction): + self._test_ragged_dot(16, 4, 3, 2, jnp.float32) + if jtu.test_device_matches(["tpu"]) and use_instruction: + self.assertIn( + "chlo.ragged_dot", + jax.jit(lax.ragged_dot) + .lower( + core.ShapedArray((16, 4), dtype=jnp.float32), + core.ShapedArray((2, 4, 3), dtype=jnp.float32), + core.ShapedArray((2,), dtype=jnp.int32), + ) + .as_text(dialect="stablehlo"), + ) + + @parameterized.parameters( + {"m": 5, "k": 4, "n": 3, "num_groups": 1}, + {"m": 5, "k": 4, "n": 3, "num_groups": 2}, + {"m": 9, "k": 4, "n": 3, "num_groups": 1}, + {"m": 10, "k": 9, "n": 8, "num_groups": 2}, + ) + def test_ragged_dot_small_m(self, m, k, n, num_groups): + lhs_shape = (m, k) + rhs_shape = (num_groups, k, n) + group_sizes_shape = (num_groups,) + + args_maker = lambda: [ + jnp.ones(lhs_shape, dtype=jnp.float32), + jnp.ones(rhs_shape, dtype=jnp.float32), + jnp.ones(group_sizes_shape, dtype=jnp.int32), + ] + self._CompileAndCheck(lax.ragged_dot, args_maker) + @parameterized.parameters( { "lhs_shape": lhs_shape, @@ -5026,7 +5272,7 @@ def test_ragged_dot_general_shape_inference_failure( "out_shape": out_shape, } for lhs_shape, rhs_shape, group_sizes_shape, ragged_dnums, out_shape in [ - ( + ( # Ragged non-contracting. [11, 5], [3, 5, 7], [3], @@ -5037,7 +5283,7 @@ def test_ragged_dot_general_shape_inference_failure( ), (11, 7), ), - ( + ( # Ragged contracting. [11, 5], [5, 7], [3], @@ -5048,6 +5294,18 @@ def test_ragged_dot_general_shape_inference_failure( ), (3, 11, 7), ), + ( # Ragged contracting with batch dimensions. + [2, 11, 5], + [2, 5, 7], + [2, 3], + lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(([2], [1]), ([0], [0]), + ), + lhs_ragged_dimensions=[2], + rhs_group_dimensions=[], + ), + (3, 2, 11, 7), + ), ] ) def test_ragged_dot_general_shape_inference_success( @@ -5055,10 +5313,124 @@ def test_ragged_dot_general_shape_inference_success( lhs = jnp.ones(lhs_shape, dtype=jnp.float32) rhs = jnp.ones(rhs_shape, dtype=jnp.float32) group_sizes = jnp.ones(group_sizes_shape, dtype=jnp.int32) + if jtu.test_device_matches(["tpu"]): + actual_shape = lax_internal._ragged_dot_general_shape_rule( + lhs, rhs, group_sizes, ragged_dot_dimension_numbers=ragged_dnums, + precision=jax.lax.Precision.DEFAULT, + preferred_element_type=jnp.float32, + ) + else: + actual_shape = lax.ragged_dot_general( + lhs, rhs, group_sizes, ragged_dnums + ).shape + self.assertEqual(actual_shape, out_shape) + + @parameterized.product( + batch_size=[3, 5], + m=[128, 1024], + k=[128, 1024], + n=[128, 1024], + num_groups=[2, 4], + ) + def test_ragged_dot_general_vmap( + self, batch_size: int, m: int, k: int, n: int, num_groups: int + ): + if (jtu.test_device_matches(["tpu"])): + raise SkipTest("batched ragged_dot not yet supported on TPU") + + lhs_shape = (batch_size, m, k) + rhs_shape = (batch_size, num_groups, k, n) + dtype = jnp.float32 + + def make_group_sizes(m, num_groups): + ends_no_final = jnp.sort(self.rng().choice(m, size=num_groups - 1)) + ends = jnp.concatenate( + [ends_no_final, jnp.array([m], dtype=ends_no_final.dtype)]) + starts = jnp.concatenate( + [jnp.zeros(1, dtype=ends_no_final.dtype), ends_no_final]) + return ends - starts + + rng = jtu.rand_small(self.rng()) + args_maker = lambda: [ + rng(lhs_shape, dtype), + rng(rhs_shape, dtype), + jnp.array([make_group_sizes(m, num_groups) for _ in range(batch_size)]), + ] + lhs, rhs, group_sizes = args_maker() + + out_dtype = jnp.float32 + precision = jax.lax.Precision.HIGHEST + ragged_dot = partial( + jax.lax.ragged_dot, + preferred_element_type=out_dtype, + precision=precision, + ) + tol = 1e-5 + + batch_res = jax.vmap(ragged_dot)(lhs, rhs, group_sizes) + for i in range(batch_size): + # The ragged_dot does not zero out the output in the case sum(group_sizes) + # < m, hence we need to compare only the valid part of the output. + upper_bound = group_sizes[i].sum(axis=0) + ref_res = ragged_dot(lhs[i], rhs[i], group_sizes[i])[0:upper_bound, :] + self.assertArraysAllClose( + batch_res[i, 0:upper_bound, :], ref_res, rtol=tol, atol=tol + ) + +class LaxUtilsTest(jtu.JaxTestCase): + + def test_int_dtype_for_dim(self): + self.assertEqual(lax_utils.int_dtype_for_dim(10, signed=True), np.int32) + self.assertEqual(lax_utils.int_dtype_for_dim(10, signed=False), np.uint32) self.assertEqual( - lax.ragged_dot_general(lhs, rhs, group_sizes, ragged_dnums).shape, - out_shape, + lax_utils.int_dtype_for_dim(np.iinfo(np.int32).max, signed=True), + np.int32, ) + self.assertEqual( + lax_utils.int_dtype_for_dim(np.iinfo(np.int32).max + 1, signed=True), + np.int64, + ) + self.assertEqual( + lax_utils.int_dtype_for_dim(np.iinfo(np.uint32).max, signed=False), + np.uint32, + ) + self.assertEqual( + lax_utils.int_dtype_for_dim(np.iinfo(np.uint32).max + 1, signed=False), + np.uint64, + ) + + def test_int_dtype_for_shape(self): + self.assertEqual( + lax_utils.int_dtype_for_shape([10, 20], signed=True), np.int32 + ) + self.assertEqual( + lax_utils.int_dtype_for_shape([10, 20], signed=False), np.uint32 + ) + self.assertEqual( + lax_utils.int_dtype_for_shape( + [10, np.iinfo(np.int32).max], signed=True + ), + np.int32, + ) + self.assertEqual( + lax_utils.int_dtype_for_shape( + [np.iinfo(np.int32).max + 1, 20], signed=True + ), + np.int64, + ) + self.assertEqual( + lax_utils.int_dtype_for_shape( + [10, np.iinfo(np.uint32).max], signed=False + ), + np.uint32, + ) + self.assertEqual( + lax_utils.int_dtype_for_shape( + [np.iinfo(np.uint32).max + 1, 20], signed=False + ), + np.uint64, + ) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/lax_vmap_test.py b/tests/lax_vmap_test.py index 2fd817c5a45e..de260ec4ed93 100644 --- a/tests/lax_vmap_test.py +++ b/tests/lax_vmap_test.py @@ -695,18 +695,19 @@ def testBroadcastShapesFaultyInputs(self): for shape in [(4,), (3, 5, 3)] for bdims in lax_test_util.all_bdims(shape)], k=[1, 3], + axis=[0, -1], dtype=lax_test_util.default_dtypes, ) # The top_k indices for integer arrays with identical entries won't match between # vmap'd version and manual reference, so only test unique integer arrays for int_dtypes. # Note also that we chose 3 * 5 * 3 * 5 such that it fits in the range of # values a bfloat16 can represent exactly to avoid ties. - def testTopK(self, shape, dtype, k, bdims): + def testTopK(self, shape, dtype, k, bdims, axis): rng = jtu.rand_int(self.rng(), high=math.prod(shape)) # _CheckBatching doesn't work with tuple outputs, so test outputs separately. - op1 = lambda x: lax.top_k(x, k=k)[0] + op1 = lambda x: lax.top_k(x, k=k, axis=axis)[0] self._CheckBatching(op1, 5, bdims, (shape,), (dtype,), rng) - op2 = lambda x: lax.top_k(x, k=k)[1] + op2 = lambda x: lax.top_k(x, k=k, axis=axis)[1] self._CheckBatching(op2, 5, bdims, (shape,), (dtype,), rng) @jtu.sample_product( @@ -758,8 +759,8 @@ def testSort(self, shape, dimension, arity, bdims, is_stable): # TODO Collapse # TODO Scatter - # TODO(b/183233858): variadic reduce-window is not implemented on XLA:GPU - @jtu.skip_on_devices("gpu") + # b/183233858: variadic reduce-window not implemented on XLA:CUDA + @jtu.skip_on_devices("cuda") def test_variadic_reduce_window(self): # https://github.com/jax-ml/jax/discussions/9818 and # https://github.com/jax-ml/jax/issues/9837 @@ -790,6 +791,46 @@ def g(a, b): expected = jnp.array([[[2.0, 2.0, 0.0], [3.0, 0.0, 1.0]]]) self.assertAllClose(output, expected, check_dtypes=False) + @jtu.sample_product( + [ + dict(arg_shape=arg_shape, reps=reps) + for arg_shape, reps in [ + [(3,), (2,)], + [(2, 3), (1, 2)], + [(2, 3), (2, 1)], + [(2, 1, 3), (1, 2, 3)], + ] + ], + in_axes=[0, 1, -1], + out_axes=[0, 1, -1], + ) + def testTileBatching(self, arg_shape, reps, in_axes, out_axes): + rng = jtu.rand_default(self.rng()) + dtype = np.float32 + args_maker = lambda: [rng(arg_shape, dtype)] + op = lambda x: lax.tile(x, reps) + args = args_maker() + + # Construct batched arguments based on in_axes + if in_axes == 0: + batched_args = [jnp.stack([arg, arg], axis=0) for arg in args] + elif in_axes == 1: + batched_args = [jnp.stack([arg, arg], axis=1) for arg in args] + else: # in_axes == -1 + batched_args = [jnp.stack([arg, arg], axis=-1) for arg in args] + + # Compute expected output + out = op(*args) + if out_axes == 0: + expected = jnp.stack([out, out], axis=0) + elif out_axes == 1: + expected = jnp.stack([out, out], axis=1) + else: # out_axes == -1 + expected = jnp.stack([out, out], axis=-1) + + actual = jax.vmap(op, in_axes=in_axes, out_axes=out_axes)(*batched_args) + self.assertAllClose(expected, actual) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/layout_test.py b/tests/layout_test.py index b9062b8d21dc..a6ef84864eea 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -14,17 +14,18 @@ import math from functools import partial + from absl.testing import absltest +from absl.testing import parameterized import numpy as np import jax import jax.numpy as jnp from jax.sharding import NamedSharding, PartitionSpec as P, SingleDeviceSharding from jax._src import config -from jax._src.layout import Layout, DeviceLocalLayout as DLL from jax._src import test_util as jtu from jax._src.util import safe_zip -from jax.experimental.compute_on import compute_on +from jax.experimental.layout import with_layout_constraint, Format, Layout config.parse_flags_with_absl() jtu.request_cpu_devices(8) @@ -50,22 +51,22 @@ def init(x, y): sds1 = jax.ShapeDtypeStruct(np_inp1.shape, np_inp1.dtype, sharding=s1) sds2 = jax.ShapeDtypeStruct(np_inp2.shape, np_inp2.dtype, sharding=s2) - lowered_apply = jax.jit(apply, in_shardings=Layout(DLL.AUTO), - out_shardings=Layout(DLL.AUTO)).lower(sds1, sds2) + lowered_apply = jax.jit(apply, in_shardings=Format(Layout.AUTO), + out_shardings=Format(Layout.AUTO)).lower(sds1, sds2) compiled_apply = lowered_apply.compile() - arg_layouts, kw_layouts = compiled_apply.input_layouts + arg_formats, kw_layouts = compiled_apply.input_formats self.assertEmpty(kw_layouts) - for i, o in zip(arg_layouts, compiled_apply.output_layouts): - self.assertEqual(i.device_local_layout.major_to_minor, - o.device_local_layout.major_to_minor[::-1]) + for i, o in zip(arg_formats, compiled_apply.output_formats): + self.assertEqual(i.layout.major_to_minor, + o.layout.major_to_minor[::-1]) init_compiled = jax.jit( - init, out_shardings=arg_layouts).lower(sds1, sds2).compile() + init, out_shardings=arg_formats).lower(sds1, sds2).compile() - for i, o in zip(init_compiled.input_layouts[0], - init_compiled.output_layouts): + for i, o in zip(init_compiled.input_formats[0], + init_compiled.output_formats): self.assertEqual(i, o) arr1 = jax.device_put(np_inp1, s1) @@ -76,21 +77,21 @@ def init(x, y): init_compiled(arr1, arr2) self.assertEqual(init_count(), 1) - self.assertEqual(init_out[0].layout, init_compiled.output_layouts[0]) - self.assertEqual(init_out[1].layout, init_compiled.output_layouts[1]) + self.assertEqual(init_out[0].format, init_compiled.output_formats[0]) + self.assertEqual(init_out[1].format, init_compiled.output_formats[1]) with jtu.count_aot_jit_cpp_cache_miss() as apply_count: apply_out = compiled_apply(*init_out) compiled_apply(*init_out) self.assertEqual(apply_count(), 1) - self.assertEqual(apply_out[0].layout, compiled_apply.output_layouts[0]) - self.assertEqual(apply_out[1].layout, compiled_apply.output_layouts[1]) + self.assertEqual(apply_out[0].format, compiled_apply.output_formats[0]) + self.assertEqual(apply_out[1].format, compiled_apply.output_formats[1]) - self.assertTupleEqual(apply_out[0].layout.device_local_layout.major_to_minor, - init_out[0].layout.device_local_layout.major_to_minor[::-1]) - self.assertTupleEqual(apply_out[1].layout.device_local_layout.major_to_minor, - init_out[1].layout.device_local_layout.major_to_minor[::-1]) + self.assertTupleEqual(apply_out[0].format.layout.major_to_minor, + init_out[0].format.layout.major_to_minor[::-1]) + self.assertTupleEqual(apply_out[1].format.layout.major_to_minor, + init_out[1].format.layout.major_to_minor[::-1]) self.assertArraysEqual(init_out[0], np_inp1 * 2) self.assertArraysEqual(init_out[1], np_inp2 * 2) @@ -113,27 +114,27 @@ def f(x): out = compiled(arr) self.assertTupleEqual( - compiled.input_layouts[0][0].device_local_layout.major_to_minor[::-1], + compiled.input_formats[0][0].layout.major_to_minor[::-1], (2, 1, 0)) self.assertTupleEqual( - compiled.output_layouts.device_local_layout.major_to_minor[::-1], + compiled.output_formats.layout.major_to_minor[::-1], (2, 1, 0)) self.assertArraysEqual(out, np_inp.T) self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'y', 'x'))) - compiled_auto = jax.jit(f, in_shardings=Layout(DLL.AUTO), - out_shardings=Layout(DLL.AUTO)).lower(sds).compile() + compiled_auto = jax.jit(f, in_shardings=Format(Layout.AUTO), + out_shardings=Format(Layout.AUTO)).lower(sds).compile() self.assertTupleEqual( - compiled_auto.input_layouts[0][0].device_local_layout.major_to_minor[::-1], + compiled_auto.input_formats[0][0].layout.major_to_minor[::-1], (2, 1, 0)) self.assertTupleEqual( - compiled_auto.output_layouts.device_local_layout.major_to_minor[::-1], + compiled_auto.output_formats.layout.major_to_minor[::-1], (0, 1, 2)) with self.assertRaisesRegex( ValueError, "jax.jit` does not accept device-local layouts directly"): - jax.jit(f, in_shardings=DLL.AUTO, - out_shardings=DLL.AUTO).lower(sds).compile() + jax.jit(f, in_shardings=Layout.AUTO, + out_shardings=Layout.AUTO).lower(sds).compile() def test_in_layouts_out_layouts(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) @@ -145,18 +146,18 @@ def test_in_layouts_out_layouts(self): def f(x): return x.T - compiled = jax.jit(f, in_shardings=Layout(), - out_shardings=Layout(DLL.AUTO)).lower(arr).compile() + compiled = jax.jit(f, in_shardings=Format(), + out_shardings=Format(Layout.AUTO)).lower(arr).compile() self.assertTupleEqual( - compiled.input_layouts[0][0].device_local_layout.major_to_minor[::-1], + compiled.input_formats[0][0].layout.major_to_minor[::-1], (1, 0)) self.assertTupleEqual( - compiled.output_layouts.device_local_layout.major_to_minor[::-1], + compiled.output_formats.layout.major_to_minor[::-1], (0, 1)) out = compiled(arr) self.assertArraysEqual(out, np_inp.T) - self.assertEqual(out.layout, compiled.output_layouts) + self.assertEqual(out.format, compiled.output_formats) self.assertEqual(out.sharding, NamedSharding(mesh, P('y', 'x'))) def test_sharding_and_layouts(self): @@ -165,15 +166,15 @@ def test_sharding_and_layouts(self): np_inp = np.arange(math.prod(shape)).reshape(shape) s = NamedSharding(mesh, P('x', 'y')) - compiled = jax.jit(lambda x: x.T, in_shardings=Layout(DLL.AUTO, s), - out_shardings=Layout(DLL.AUTO, s)).lower(np_inp).compile() + compiled = jax.jit(lambda x: x.T, in_shardings=Format(Layout.AUTO, s), + out_shardings=Format(Layout.AUTO, s)).lower(np_inp).compile() out = compiled(np_inp) self.assertTupleEqual( - compiled.input_layouts[0][0].device_local_layout.major_to_minor[::-1], + compiled.input_formats[0][0].layout.major_to_minor[::-1], (1, 0)) if not jtu.test_device_matches(['cpu']): self.assertTupleEqual( - compiled.output_layouts.device_local_layout.major_to_minor[::-1], + compiled.output_formats.layout.major_to_minor[::-1], (0, 1)) self.assertArraysEqual(out, np_inp.T) self.assertEqual(out.sharding, s) @@ -184,21 +185,21 @@ def f(x, y, z, a, b, c): shape = (8, 2) inps = [np.arange(math.prod(shape)).reshape(shape)] * 6 - compiled = jax.jit(f, in_shardings=Layout(DLL.AUTO), - out_shardings=Layout(DLL.AUTO)).lower(*inps).compile() - arg_layouts, _ = compiled.input_layouts + compiled = jax.jit(f, in_shardings=Format(Layout.AUTO), + out_shardings=Format(Layout.AUTO)).lower(*inps).compile() + arg_formats, _ = compiled.input_formats out1, out2 = compiled(*inps) - compiled2 = jax.jit(f, in_shardings=arg_layouts).lower(*inps).compile() + compiled2 = jax.jit(f, in_shardings=arg_formats).lower(*inps).compile() out3, out4 = compiled2(*inps) - for l1, l2 in safe_zip(arg_layouts, compiled2.input_layouts[0]): + for l1, l2 in safe_zip(arg_formats, compiled2.input_formats[0]): self.assertEqual(l1, l2) self.assertArraysEqual(out1, out3) self.assertArraysEqual(out2, out4) - arrs = [jax.device_put(i, l) for i, l in zip(inps, arg_layouts)] + arrs = [jax.device_put(i, l) for i, l in zip(inps, arg_formats)] out5, out6 = jax.jit(f)(*arrs) self.assertArraysEqual(out1, out5) self.assertArraysEqual(out2, out6) @@ -215,11 +216,11 @@ def test_no_error_dced_args(self): def f(x, y): return x * 2 - jf = jax.jit(f, in_shardings=Layout(DLL.AUTO, s), - out_shardings=Layout(DLL.AUTO, s)) + jf = jax.jit(f, in_shardings=Format(Layout.AUTO, s), + out_shardings=Format(Layout.AUTO, s)) compiled = jf.lower(np_inp, np_inp).compile() - arg_layouts, _ = compiled.input_layouts - arrs = [jax.device_put(i, l) for i, l in zip(arrs, arg_layouts)] + arg_formats, _ = compiled.input_formats + arrs = [jax.device_put(i, l) for i, l in zip(arrs, arg_formats)] compiled(*arrs) def test_aot_layout_mismatch(self): @@ -243,16 +244,15 @@ def f(x): with self.assertRaisesRegex( ValueError, 'Layout passed to jit does not match the layout on the respective arg'): - jax.jit(f, in_shardings=Layout(DLL.AUTO)).lower(arr) + jax.jit(f, in_shardings=Format(Layout.AUTO)).lower(arr) - compiled = jax.jit(f, in_shardings=Layout(DLL.AUTO), - out_shardings=Layout(DLL.AUTO)).lower(sds).compile() + compiled = jax.jit(f, in_shardings=Format(Layout.AUTO), + out_shardings=Format(Layout.AUTO)).lower(sds).compile() with self.assertRaisesRegex( ValueError, - r'Compiled object called with input layout\(s\) does' - r' not match the layout\(s\) the computation was' - ' compiled with'): + r'Computation was compiled for input layouts that disagree with the ' + r'layouts of arguments passed to it.'): compiled(arr) @jtu.ignore_warning(category=DeprecationWarning, @@ -272,30 +272,30 @@ def test_device_put_concrete_layout(self): arr = jax.device_put(np_inp, s) compiled = jax.jit( - lambda x: x * 2, out_shardings=Layout(DLL.AUTO)).lower(arr).compile() - col = compiled.output_layouts + lambda x: x * 2, out_shardings=Format(Layout.AUTO)).lower(arr).compile() + col = compiled.output_formats out = jax.device_put(np_inp, col) - self.assertEqual(out.layout, col) + self.assertEqual(out.format, col) self.assertArraysEqual(out, np_inp) for s in out.addressable_shards: - self.assertEqual(out.layout.device_local_layout, - s.data.layout.device_local_layout) + self.assertEqual(out.format.layout, + s.data.format.layout) def test_device_put_non_concrete_layout_error(self): np_inp = np.arange(16).reshape(8, 2) - l1 = Layout(DLL.AUTO, SingleDeviceSharding(jax.devices()[0])) + l1 = Format(Layout.AUTO, SingleDeviceSharding(jax.devices()[0])) with self.assertRaisesRegex( - ValueError, 'sharding and device_local_layout.*should be concrete'): + ValueError, 'sharding and layout.*should be concrete'): jax.device_put(np_inp, l1) - l2 = Layout(DLL.AUTO) + l2 = Format(Layout.AUTO) with self.assertRaisesRegex( - ValueError, 'sharding and device_local_layout.*should be concrete'): + ValueError, 'sharding and layout.*should be concrete'): jax.device_put(np_inp, l2) - l3 = Layout(None, SingleDeviceSharding(jax.devices()[0])) + l3 = Format(None, SingleDeviceSharding(jax.devices()[0])) out = jax.device_put(np_inp, l3) self.assertArraysEqual(out, np_inp) self.assertTrue(out._committed) @@ -305,7 +305,7 @@ def invalid_layout_spec(self): compiled = jax.jit(lambda x: x).lower(x).compile() with self.assertRaisesRegex( ValueError, 'Sharding has to be concrete when layout.*'): - Layout(compiled.output_layouts[0], None) + Format(compiled.output_formats[0], None) def test_layout_on_sds(self): mesh = jtu.create_mesh((2, 1), ('x', 'y')) @@ -313,18 +313,18 @@ def test_layout_on_sds(self): np_inp = np.arange(16).reshape(8, 2) arr = jax.device_put(np_inp, s) - out_layout = jax.jit(jnp.sin, out_shardings=Layout(DLL.AUTO)).lower( - arr).compile().output_layouts + out_format = jax.jit(jnp.sin, out_shardings=Format(Layout.AUTO)).lower( + arr).compile().output_formats - sds = jax.ShapeDtypeStruct(arr.shape, arr.dtype, sharding=out_layout) - arg_layout, _ = jax.jit(lambda x: x * 2).lower(sds).compile().input_layouts - self.assertEqual(arg_layout[0], out_layout) + sds = jax.ShapeDtypeStruct(arr.shape, arr.dtype, sharding=out_format) + arg_format, _ = jax.jit(lambda x: x * 2).lower(sds).compile().input_formats + self.assertEqual(arg_format[0], out_format) with self.assertRaisesRegex( TypeError, - 'DeviceLocalLayout.AUTO` cannot be used in place of a device-local' + 'Layout.AUTO` cannot be used in place of a device-local' ' layout in a `ShapeDtypeStruct`'): - jax.ShapeDtypeStruct(arr.shape, arr.dtype, sharding=Layout(DLL.AUTO)) + jax.ShapeDtypeStruct(arr.shape, arr.dtype, sharding=Format(Layout.AUTO)) def test_make_array_from_callback(self): mesh = jtu.create_mesh((2, 1), ('x', 'y')) @@ -332,24 +332,24 @@ def test_make_array_from_callback(self): np_inp = np.arange(16).reshape(8, 2) sds = jax.ShapeDtypeStruct(np_inp.shape, np_inp.dtype, sharding=s) - layout = jax.jit(lambda x: x * 2).lower(sds).compile().output_layouts + format = jax.jit(lambda x: x * 2).lower(sds).compile().output_formats - out = jax.make_array_from_callback(np_inp.shape, layout, + out = jax.make_array_from_callback(np_inp.shape, format, lambda idx: np_inp[idx]) self.assertArraysEqual(out, np_inp) - self.assertEqual(out.layout, layout) + self.assertEqual(out.format, format) with self.assertRaisesRegex( TypeError, - '`DeviceLocalLayout.AUTO` cannot be used in place of a device-local' + '`Layout.AUTO` cannot be used in place of a device-local' ' layout'): - jax.make_array_from_callback(np_inp.shape, Layout(DLL.AUTO, s), + jax.make_array_from_callback(np_inp.shape, Format(Layout.AUTO, s), lambda idx: np_inp[idx]) with self.assertRaisesRegex( TypeError, 'sharding should be an instance of `jax.sharding`'): jax.make_array_from_callback( - np_inp.shape, Layout(None, None), lambda idx: np_inp[idx]) + np_inp.shape, Format(None, None), lambda idx: np_inp[idx]) def test_wsc_concrete_layout(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) @@ -359,52 +359,52 @@ def test_wsc_concrete_layout(self): arr = jax.device_put(np_inp, s) # Create a custom layout instead of using `arr.layout` to test the API. - custom_dll = DLL(major_to_minor=(0, 1)) + custom_dll = Layout(major_to_minor=(0, 1)) @jax.jit def f(x): y = x.T # Constrain `y` to the original layout of `arr` because without it, # the layout of `y` would be the transpose of `arr`. - return jax.lax.with_sharding_constraint(y, Layout(custom_dll, s)) + return jax.lax.with_sharding_constraint(y, Format(custom_dll, s)) out = f(arr) - self.assertEqual(out.layout.device_local_layout.major_to_minor, + self.assertEqual(out.format.layout.major_to_minor, custom_dll.major_to_minor) - self.assertEqual(out.layout, arr.layout) + self.assertEqual(out.format, arr.format) self.assertArraysEqual(out, np_inp.T) def test_wsc_bfloat16_concrete_layout(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) - shape = (16, 128) + shape = (64, 128) s = NamedSharding(mesh, P('x')) inp = jnp.arange(math.prod(shape), dtype=jnp.bfloat16).reshape(shape) arr = jax.device_put(inp, s) # Create a custom layout instead of using `arr.layout` to test the API. - custom_dll = DLL(major_to_minor=(0, 1)) + custom_dll = Layout(major_to_minor=(0, 1)) @jax.jit def f(x): y = x.T # Constrain `y` to the original layout of `arr` because without it, # the layout of `y` would be the transpose of `arr`. - return jax.lax.with_sharding_constraint(y, Layout(custom_dll, s)) + return jax.lax.with_sharding_constraint(y, Format(custom_dll, s)) out = f(arr) - self.assertEqual(out.layout.device_local_layout.major_to_minor, + self.assertEqual(out.format.layout.major_to_minor, custom_dll.major_to_minor) - self.assertEqual(out.layout, arr.layout) + self.assertEqual(out.format, arr.format) self.assertArraysEqual(out, inp.T) def test_device_put_user_concrete_layout(self): shape = (8, 128) np_inp = np.arange(math.prod(shape)).reshape(shape) - dll = DLL(major_to_minor=(1, 0)) + dll = Layout(major_to_minor=(1, 0)) s = SingleDeviceSharding(jax.devices()[0]) - out = jax.device_put(np_inp, Layout(dll, s)) - self.assertEqual(out.layout.device_local_layout.major_to_minor, + out = jax.device_put(np_inp, Format(dll, s)) + self.assertEqual(out.format.layout.major_to_minor, dll.major_to_minor) self.assertArraysEqual(out, np_inp) @@ -416,18 +416,18 @@ def test_device_put_user_concrete_layout_multi_device(self): jnp_inp = jnp.arange(math.prod(shape)).reshape(shape) arr = jax.device_put(np_inp, s) - custom_layout = Layout(DLL(major_to_minor=(0, 1)), s) - out1 = jax.device_put(arr, custom_layout) + custom_format = Format(Layout(major_to_minor=(0, 1)), s) + out1 = jax.device_put(arr, custom_format) - with jax.sharding.use_mesh(mesh): - out2 = jax.device_put(arr, custom_layout) - out3 = jax.device_put(jnp_inp, custom_layout) - out4 = jax.device_put(np_inp, custom_layout) + with jax.set_mesh(mesh): + out2 = jax.device_put(arr, custom_format) + out3 = jax.device_put(jnp_inp, custom_format) + out4 = jax.device_put(np_inp, custom_format) for o in [out1, out2, out3, out4]: self.assertArraysEqual(o, np_inp) - self.assertEqual(o.layout.device_local_layout.major_to_minor, - custom_layout.device_local_layout.major_to_minor) + self.assertEqual(o.format.layout.major_to_minor, + custom_format.layout.major_to_minor) def test_concrete_layout_jit(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) @@ -439,20 +439,20 @@ def test_concrete_layout_jit(self): def f(x): return x.T - custom_dll = DLL(major_to_minor=(0, 1)) - f = jax.jit(f, out_shardings=Layout(custom_dll, s)) + custom_dll = Layout(major_to_minor=(0, 1)) + f = jax.jit(f, out_shardings=Format(custom_dll, s)) out = f(arr) self.assertArraysEqual(out, np_inp.T) - self.assertEqual(out.layout.device_local_layout.major_to_minor, + self.assertEqual(out.format.layout.major_to_minor, custom_dll.major_to_minor) def test_compatible_aval_error(self): - custom_dll = DLL(major_to_minor=(0, 1, 2)) - l = Layout(custom_dll, SingleDeviceSharding(jax.devices()[0])) + custom_dll = Layout(major_to_minor=(0, 1, 2)) + l = Format(custom_dll, SingleDeviceSharding(jax.devices()[0])) inp = np.arange(8) - @partial(jax.jit, in_shardings=l) + @jax.jit(in_shardings=l) def f(x): return x * 2 @@ -462,8 +462,8 @@ def f(x): f(inp) def test_incompatible_aval_error_device_put(self): - custom_dll = DLL(major_to_minor=(0, 1, 2)) - l = Layout(custom_dll, SingleDeviceSharding(jax.devices()[0])) + custom_dll = Layout(major_to_minor=(0, 1, 2)) + l = Format(custom_dll, SingleDeviceSharding(jax.devices()[0])) inp = np.arange(8) with self.assertRaisesRegex( @@ -478,22 +478,22 @@ def test_concrete_layout_in_shardings(self): np_inp = np.arange(math.prod(shape)).reshape(shape) arr = jax.device_put(np_inp, s) - custom_dll = DLL(major_to_minor=(0, 1)) + custom_dll = Layout(major_to_minor=(0, 1)) @partial(jax.jit, - in_shardings=Layout(custom_dll, s), - out_shardings=Layout(DLL.AUTO)) + in_shardings=Format(custom_dll, s), + out_shardings=Format(Layout.AUTO)) def f(x): return x.T out = f(arr) self.assertArraysEqual(out, np_inp.T) - self.assertEqual(out.layout.device_local_layout.major_to_minor, + self.assertEqual(out.format.layout.major_to_minor, custom_dll.major_to_minor[::-1]) - custom_dll2 = DLL(major_to_minor=(1, 0)) + custom_dll2 = Layout(major_to_minor=(1, 0)) - @partial(jax.jit, in_shardings=Layout(custom_dll2, s)) + @jax.jit(in_shardings=Format(custom_dll2, s)) def g(x): return x.T @@ -503,11 +503,11 @@ def g(x): g(arr) def test_in_layouts_jit_jnp_input(self): - major_last_layout = DLL(major_to_minor=(1, 0)) + major_last_layout = Layout(major_to_minor=(1, 0)) sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0]) f = jax.jit(lambda x: x + 1, - in_shardings=Layout(major_last_layout, sharding)) + in_shardings=Format(major_last_layout, sharding)) arr = jnp.arange(8 * 128).reshape(8, 128) out = f(arr) @@ -531,10 +531,10 @@ def test_layout_donation(self): shape = (16, 128) np_inp = np.arange(math.prod(shape)).reshape(shape) - custom_dll = DLL(major_to_minor=(0, 1)) - arr = jax.device_put(np_inp, Layout(custom_dll, s)) + custom_dll = Layout(major_to_minor=(0, 1)) + arr = jax.device_put(np_inp, Format(custom_dll, s)) - @partial(jax.jit, in_shardings=Layout(custom_dll, s), donate_argnums=0) + @jax.jit(in_shardings=Format(custom_dll, s), donate_argnums=0) def f(x): return x @@ -549,7 +549,7 @@ def test_layout_donation_auto(self): arr = jax.device_put(np_inp, s) - @partial(jax.jit, out_shardings=Layout(DLL.AUTO), donate_argnums=0) + @jax.jit(out_shardings=Format(Layout.AUTO), donate_argnums=0) def f(x): return x * x @@ -562,11 +562,11 @@ def test_layout_donation_matching_in_and_out(self): shape = (128, 16) np_inp = np.arange(math.prod(shape)).reshape(shape) - custom_dll = DLL(major_to_minor=(0, 1)) - l = Layout(custom_dll, s) + custom_dll = Layout(major_to_minor=(0, 1)) + l = Format(custom_dll, s) arr = jax.device_put(np_inp, l) - @partial(jax.jit, in_shardings=l, out_shardings=l, donate_argnums=0) + @jax.jit(in_shardings=l, out_shardings=l, donate_argnums=0) def f(x): return x * x @@ -580,11 +580,13 @@ def test_layout_donation_mismatching_in_and_out_fails(self): shape = (16*2, 32016*2) np_inp = np.arange(math.prod(shape), dtype=jnp.bfloat16).reshape(shape) - custom_dll1 = DLL(major_to_minor=(1, 0), _tiling=((8,128), (2,1))) - l1 = Layout(custom_dll1, s) + tiling = (((16, 128), (2, 1)) if jtu.get_tpu_version() == 7 + else ((8, 128), (2, 1))) + custom_dll1 = Layout(major_to_minor=(1, 0), tiling=tiling) + l1 = Format(custom_dll1, s) arr = jax.device_put(np_inp, s) - @partial(jax.jit, out_shardings=l1, donate_argnums=0) + @jax.jit(out_shardings=l1, donate_argnums=0) def f(x): return x * x @@ -593,7 +595,7 @@ def f(x): self.assertFalse(arr.is_deleted()) def test_donation_error_on_auto(self): - @partial(jax.jit, donate_argnums=0, in_shardings=Layout(DLL.AUTO)) + @jax.jit(donate_argnums=0, in_shardings=Format(Layout.AUTO)) def f(x): return x * 2 @@ -601,7 +603,7 @@ def f(x): ValueError, ".*Did you mean to set the.*output layout.*AUTO.*"): f(jnp.arange(8)) - @partial(jax.jit, donate_argnums=0, out_shardings=Layout(DLL.AUTO)) + @jax.jit(donate_argnums=0, out_shardings=Format(Layout.AUTO)) def g(x): return x * 2 @@ -609,98 +611,6 @@ def g(x): ValueError, ".*Did you mean to set the.*input layout.*AUTO.*"): g(jnp.arange(8)) - def test_sparsecore_compute(self): - if not (jax.devices()[0].device_kind == 'TPU v5' or - jtu.is_device_tpu_at_least(6)): - self.skipTest('Does not have a sparsecore present') - shape = (128, 128) - inp = jnp.arange(math.prod(shape)).reshape(shape) - - dll = DLL(major_to_minor=(0, 1), _tiling=((8,),)) - s = SingleDeviceSharding(jax.devices()[0]) - sparse_layout = Layout(dll, s) - sparecore_arr = jax.device_put(inp, sparse_layout) - dense_layout = Layout(DLL(major_to_minor=(0, 1)), s) - - @compute_on('tpu_sparsecore') - @jax.jit - def sparsecore_compute(x): - return x * x - - @partial(jax.jit, out_shardings=(dense_layout, sparse_layout)) - def f(x, y): - return x * 2, sparsecore_compute(y) - - f(inp, sparecore_arr) - - def test_sparsecore_compute_twice(self): - if not ( - jax.devices()[0].device_kind == 'TPU v5' - or jtu.is_device_tpu_at_least(6) - ): - self.skipTest('Does not have a sparsecore present') - shape = (4096, 8) - inp = jnp.arange(math.prod(shape)).reshape(shape) - - dll = DLL(major_to_minor=(0, 1), _tiling=((8,),)) - s = SingleDeviceSharding(jax.devices()[0]) - sparse_layout = Layout(dll, s) - sparecore_arr = jax.device_put(inp, sparse_layout) - - @compute_on('tpu_sparsecore') - @jax.jit - def sparsecore_multiply(x, y): - return x * y - - @compute_on('tpu_sparsecore') - @jax.jit - def sparsecore_add(x, y): - return x + y - - @partial(jax.jit, donate_argnums=0, out_shardings=sparse_layout) - def f(x): - return sparsecore_multiply(sparsecore_add(x, x) + 1, x) - - f(sparecore_arr) - - def test_sparsecore_and_host_compute(self): - if not ( - jax.devices()[0].device_kind == 'TPU v5' - or jtu.is_device_tpu_at_least(6) - ): - self.skipTest('Does not have a sparsecore present') - shape = (128, 128) - inp = jnp.arange(math.prod(shape)).reshape(shape) - s = SingleDeviceSharding(jax.devices()[0]) - - sparse_dll = DLL(major_to_minor=(0, 1), _tiling=((8,),)) - sparse_layout = Layout(sparse_dll, s) - sparecore_arr = jax.device_put(inp, sparse_layout) - - host_dll = DLL(major_to_minor=(0, 1), _tiling=((1,),)) - host_layout = Layout(host_dll, s) - host_arr = jax.device_put(inp, host_layout) - - @compute_on('tpu_sparsecore') - @jax.jit - def sparsecore_compute(x): - return x * x - - @compute_on('device_host') - @jax.jit - def host_compute(x): - return x + x - - @partial( - jax.jit, - in_shardings=(sparse_layout, host_layout), - out_shardings=(sparse_layout, host_layout), - ) - def f(x, y): - return sparsecore_compute(x), host_compute(y) - - f(sparecore_arr, host_arr) - def test_cpp_layout_cache_miss(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) @@ -708,9 +618,9 @@ def test_cpp_layout_cache_miss(self): np_inp = np.arange(math.prod(shape)).reshape(shape) arr = jax.device_put(np_inp, s) - arr_m2m = arr.layout.device_local_layout.major_to_minor - custom_layout = Layout(DLL(major_to_minor=arr_m2m[::-1]), s) - arr2 = jax.device_put(np_inp, custom_layout) + arr_m2m = arr.format.layout.major_to_minor + custom_format = Format(Layout(major_to_minor=arr_m2m[::-1]), s) + arr2 = jax.device_put(np_inp, custom_format) @jax.jit def f(x): @@ -730,9 +640,9 @@ def test_layout_donation_with_default_layout(self): shape = (16, 16) np_inp = np.arange(math.prod(shape)).reshape(shape) arr = jax.device_put(np_inp, s) - out_layout = Layout(arr.layout.device_local_layout, s) + out_format = Format(arr.format.layout, s) - @partial(jax.jit, out_shardings=out_layout, donate_argnums=0) + @jax.jit(out_shardings=out_format, donate_argnums=0) def f(x): return x * 2 @@ -742,7 +652,133 @@ def f(x): out = f(arr) self.assertArraysEqual(out, np_inp * 2) - self.assertEqual(out.layout, out_layout) + self.assertEqual(out.format, out_format) + + def test_with_layout_constraint(self): + if not jtu.test_device_matches(['tpu']): + self.skipTest('Only works for TPU') + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + shape = (16, 128) + s = NamedSharding(mesh, P('x')) + np_inp = np.arange(math.prod(shape)).reshape(shape) + arr = jax.device_put(np_inp, s) + + # Create a custom layout instead of using `arr.layout` to test the API. + custom_dll = Layout(major_to_minor=arr.format.layout.major_to_minor[::-1]) + + def f(x): + y = x.T + # Constrain `y` to the original layout of `arr` because without it, + # the layout of `y` would be the transpose of `arr`. + y = with_layout_constraint(y, custom_dll) + return y * 2 + + f(arr) # doesn't crash + + f = jax.jit(f) + out = f(arr) + self.assertEqual(out.format.layout.major_to_minor, + custom_dll.major_to_minor) + self.assertArraysEqual(out, np_inp.T * 2) + + lowered_text = f.lower(arr).as_text() + self.assertIn('LayoutConstraint', lowered_text) + + def test_with_layout_constraint_vmap(self): + if not jtu.test_device_matches(['tpu']): + self.skipTest('Only works for TPU') + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + shape = (16, 128) + s = NamedSharding(mesh, P('x')) + np_inp = np.arange(math.prod(shape)).reshape(shape) + arr = jax.device_put(np_inp, s) + + def f(x): + y = x.T + # Constrain `y` to the original layout of `arr` because without it, + # the layout of `y` would be the transpose of `arr`. + y = with_layout_constraint(y, Layout(major_to_minor=(0,))) + return y * 2 + + out = jax.jit(jax.vmap(f))(arr) + self.assertEqual(out.format.layout.major_to_minor, (0, 1)) + + def test_eval_shape_format(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + s = NamedSharding(mesh, P('x', 'y')) + shape = (128, 16) + np_inp = np.arange(math.prod(shape)).reshape(shape) + + custom_dll = Layout(major_to_minor=(0, 1)) + l = Format(custom_dll, s) + arr = jax.device_put(np_inp, l) + + @jax.jit(in_shardings=l, out_shardings=l) + def f(x): + return x * x + + out = jax.eval_shape(f, arr) + self.assertEqual(out.format, l) + self.assertEqual(out.sharding, s) + + def test_valid_custom_layout_after_copy_across_clients(self): + if jax._src.lib.ifrt_version < 45: + self.skipTest('Only works for JAX_IFRT_VERSION_NUMBER >= 45') + if not jtu.test_device_matches(['tpu']): + self.skipTest('Only works for TPU') + + custom_dll = Layout(major_to_minor=(1, 0)) + + cpu_sharding = jax.sharding.SingleDeviceSharding( + jax.local_devices(backend='cpu')[0]) + cpu_format = Format(custom_dll, cpu_sharding) + cpu_array = jax.device_put(np.ones((128, 8)), cpu_format) + + mesh = jtu.create_mesh((1, 1), ('x', 'y')) + tpu_sharding = jax.sharding.NamedSharding(mesh, P()) + tpu_format = Format(custom_dll, tpu_sharding) + + copied_tpu_array = jax.device_put(cpu_array, tpu_format.sharding) + canonical_tpu_array = jax.device_put(np.ones((128, 8)), tpu_format) + self.assertEqual( + copied_tpu_array.format.layout, canonical_tpu_array.format.layout) + + @parameterized.named_parameters( + ('device_to_pinned_host', 'device', 'pinned_host'), + ('pinned_host_to_device', 'pinned_host', 'device'), + ('device_to_unpinned_host', 'device', 'unpinned_host'), + ('unpinned_host_to_device', 'unpinned_host', 'device'), + ('pinned_host_to_unpinned_host', 'pinned_host', 'unpinned_host'), + ('unpinned_host_to_pinned_host', 'unpinned_host', 'pinned_host'), + ) + def test_valid_layout_after_copy_across_memories( + self, src_memory_kind, dst_memory_kind): + if not jtu.test_device_matches(['tpu']): + self.skipTest('Only works for TPU') + custom_dll = Layout(major_to_minor=(1, 0)) + + mesh = jtu.create_mesh((1, 1), ('x', 'y')) + src_tpu_sharding = jax.sharding.NamedSharding( + mesh, P(), memory_kind=src_memory_kind) + dst_tpu_sharding = jax.sharding.NamedSharding( + mesh, P(), memory_kind=dst_memory_kind) + + # TPU unpinned_host memories do not support custom layouts. + if src_memory_kind == 'unpinned_host': + src_tpu_format = src_tpu_sharding + else: + src_tpu_format = Format(custom_dll, src_tpu_sharding) + if dst_memory_kind == 'unpinned_host': + dst_tpu_format = dst_tpu_sharding + else: + dst_tpu_format = Format(custom_dll, dst_tpu_sharding) + + tpu_array = jax.device_put(np.ones((128, 8)), src_tpu_format) + + copied_tpu_array = jax.device_put(tpu_array, dst_tpu_sharding) + canonical_tpu_array = jax.device_put(np.ones((128, 8)), dst_tpu_format) + self.assertEqual( + copied_tpu_array.format.layout, canonical_tpu_array.format.layout) if __name__ == '__main__': diff --git a/tests/linalg_sharding_test.py b/tests/linalg_sharding_test.py index d8e1e6a16871..551f2fb49454 100644 --- a/tests/linalg_sharding_test.py +++ b/tests/linalg_sharding_test.py @@ -14,7 +14,7 @@ import functools -from absl.testing import absltest +from absl.testing import absltest, parameterized import numpy as np import jax @@ -31,30 +31,22 @@ complex_types = jtu.dtypes.complex +# These functions are only supported on CPU. CPU_ONLY_FUN_AND_SHAPES = [ - # These functions are supported on GPU, but partitioning support will - # require updates to GSPMD, since they are lowered directly to HLO ops - # instead of custom calls on GPU. - (lax.linalg.cholesky, ((6, 6),)), - (lax.linalg.triangular_solve, ((6, 6), (4, 6))), - - # The GPU kernel for this function still uses an opaque descriptor to - # encode the input shapes so it is not partitionable. - # TODO(danfm): Update the kernel and enable this test on GPU. - (lax.linalg.tridiagonal_solve, ((6,), (6,), (6,), (6, 4))), - - # These functions are only supported on CPU. (lax.linalg.hessenberg, ((6, 6),)), (lax.linalg.schur, ((6, 6),)), ] CPU_AND_GPU_FUN_AND_SHAPES = [ + (lax.linalg.cholesky, ((6, 6),)), (lax.linalg.eig, ((6, 6),)), (lax.linalg.eigh, ((6, 6),)), (lax.linalg.lu, ((10, 6),)), (lax.linalg.qr, ((6, 6),)), (lax.linalg.svd, ((10, 6),)), + (lax.linalg.triangular_solve, ((6, 6), (4, 6))), (lax.linalg.tridiagonal, ((6, 6),)), + (lax.linalg.tridiagonal_solve, ((6,), (6,), (6,), (6, 4))), ] ALL_FUN_AND_SHAPES = CPU_ONLY_FUN_AND_SHAPES + CPU_AND_GPU_FUN_AND_SHAPES @@ -68,9 +60,19 @@ def setUp(self): self.skipTest("Requires multiple devices") def get_fun_and_shapes(self, fun_and_shapes, grad=False): - if (jtu.test_device_matches(["gpu"]) - and fun_and_shapes not in CPU_AND_GPU_FUN_AND_SHAPES): - self.skipTest(f"{fun_and_shapes[0].__name__} not supported on GPU") + if jtu.test_device_matches(["gpu"]): + if fun_and_shapes not in CPU_AND_GPU_FUN_AND_SHAPES: + self.skipTest( + f"Partitioning {fun_and_shapes[0].__name__} not supported on GPU.") + if (fun_and_shapes[0] in (lax.linalg.cholesky, lax.linalg.triangular_solve) + and not config.use_shardy_partitioner.value): + self.skipTest( + f"Partitioning {fun_and_shapes[0].__name__} only supported on GPU " + "when shardy is enabled.") + if fun_and_shapes[0] == lax.linalg.tridiagonal_solve: + self.skipTest( + f"Partitioning {fun_and_shapes[0].__name__} on GPU, requires a " + "more recent jaxlib version.") if not grad: return fun_and_shapes @@ -79,10 +81,10 @@ def get_fun_and_shapes(self, fun_and_shapes, grad=False): self.skipTest(f"{fun.__name__} does not support differentation") if jtu.test_device_matches(["gpu"]) and fun in ( lax.linalg.eig, lax.linalg.lu, lax.linalg.qr - ): + ) and not config.use_shardy_partitioner.value: self.skipTest( f"JVP of {fun.__name__} uses triangular solve on GPU, which doesn't " - "support batch partitioning yet") + "support batch partitioning unless shardy is enabled.") if fun == lax.linalg.eig: fun = functools.partial( @@ -107,9 +109,8 @@ def arg_maker(shape): return x return tuple(arg_maker(shape) for shape in shapes) - @jtu.sample_product( - fun_and_shapes=ALL_FUN_AND_SHAPES, - dtype=float_types + complex_types, + @parameterized.product( + fun_and_shapes=ALL_FUN_AND_SHAPES, dtype=float_types + complex_types ) @jtu.run_on_devices("gpu", "cpu") def test_batch_axis_sharding(self, fun_and_shapes, dtype): @@ -124,20 +125,17 @@ def test_batch_axis_sharding(self, fun_and_shapes, dtype): expected = fun(*args) actual = fun_jit(*args_sharded) self.assertAllClose(actual, expected) - # TODO(danfm): Re-enable this check after diganosing non-determinism. - # self.assertNotIn("all-", fun_jit.lower(*args_sharded).compile().as_text()) + self.assertNotIn("all-", fun_jit.lower(*args_sharded).compile().as_text()) vmap_fun = jax.vmap(fun) vmap_fun_jit = jax.jit(vmap_fun) actual = vmap_fun_jit(*args_sharded) self.assertAllClose(actual, expected) - # TODO(danfm): Re-enable this check after diganosing non-determinism. - # self.assertNotIn( - # "all-", vmap_fun_jit.lower(*args_sharded).compile().as_text()) + self.assertNotIn( + "all-", vmap_fun_jit.lower(*args_sharded).compile().as_text()) - @jtu.sample_product( - fun_and_shapes=ALL_FUN_AND_SHAPES, - dtype=float_types + complex_types, + @parameterized.product( + fun_and_shapes=ALL_FUN_AND_SHAPES, dtype=float_types + complex_types ) @jtu.run_on_devices("gpu", "cpu") def test_non_batch_axis_sharding(self, fun_and_shapes, dtype): @@ -155,12 +153,19 @@ def test_non_batch_axis_sharding(self, fun_and_shapes, dtype): self.assertIn( "all-gather", fun_jit.lower(*args_sharded).compile().as_text()) - @jtu.sample_product( - fun_and_shapes=ALL_FUN_AND_SHAPES, - dtype=float_types + complex_types, + @parameterized.product( + fun_and_shapes=ALL_FUN_AND_SHAPES, dtype=float_types + complex_types ) @jtu.run_on_devices("gpu", "cpu") def test_batch_axis_sharding_jvp(self, fun_and_shapes, dtype): + if (jtu.is_device_rocm() and + fun_and_shapes[0] is lax.linalg.qr and + dtype == np.complex64): + # numerical errors seen as of ROCm 7.2 due to rocSolver issue for qr with complex64 + # TODO: re-enable the test once the rocSolver issue is fixed + self.skipTest("test_batch_axis_sharding_jvp13 not supported on ROCm due to rocSolver issue") + if fun_and_shapes[0] is lax.linalg.tridiagonal_solve and jtu.is_device_rocm(): + self.skipTest("test_batch_axis_sharding_jvp is not supported on ROCm") fun, shapes = self.get_fun_and_shapes(fun_and_shapes, grad=True) primals = self.get_args(shapes, dtype, batch_size=8) tangents = tuple(map(jnp.ones_like, primals)) @@ -181,14 +186,14 @@ def jvp_fun(primals, tangents): (primals_sharded, tangents), ]: _, actual = jvp_fun_jit(*args) - self.assertAllClose(actual, expected) - # TODO(danfm): Re-enable this check after diganosing non-determinism. - # hlo = jvp_fun_jit.lower(primals_sharded, tangents_sharded).compile() - # self.assertNotIn("all-", hlo.as_text()) - - @jtu.sample_product( - fun_and_shapes=ALL_FUN_AND_SHAPES, - dtype=float_types + complex_types, + self.assertAllClose(actual, expected, rtol={ + np.float32: 1e-4, np.float64: 2e-11, np.complex64: 1e-4, + np.complex128: 1e-11}) + hlo = jvp_fun_jit.lower(primals_sharded, tangents_sharded).compile() + self.assertNotIn("all-", hlo.as_text()) + + @parameterized.product( + fun_and_shapes=ALL_FUN_AND_SHAPES, dtype=float_types + complex_types ) @jtu.run_on_devices("gpu", "cpu") def test_batch_axis_sharding_vjp(self, fun_and_shapes, dtype): @@ -204,10 +209,11 @@ def test_batch_axis_sharding_vjp(self, fun_and_shapes, dtype): vjp_fun_jit = jax.jit(vjp_fun) expected = vjp_fun(tangents) actual = vjp_fun_jit(tangents_sharded) - self.assertAllClose(actual, expected) - # TODO(danfm): Re-enable this check after diganosing non-determinism. - # hlo = vjp_fun_jit.lower(tangents_sharded).compile() - # self.assertNotIn("all-", hlo.as_text()) + self.assertAllClose(actual, expected, rtol={ + np.float32: 1e-4, np.float64: 1e-11, np.complex64: 1e-4, + np.complex128: 1e-11}) + hlo = vjp_fun_jit.lower(tangents_sharded).compile() + self.assertNotIn("all-", hlo.as_text()) if __name__ == "__main__": diff --git a/tests/linalg_test.py b/tests/linalg_test.py index feab105ccbe2..56b480b91d29 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -14,8 +14,9 @@ from functools import partial import itertools -from typing import Iterator -from unittest import skipIf +from collections.abc import Iterator +import platform +import unittest import numpy as np import scipy @@ -31,6 +32,7 @@ from jax import scipy as jsp from jax._src import config from jax._src.lax import linalg as lax_linalg +from jax._src.lib import cuda_versions from jax._src import test_util as jtu from jax._src import xla_bridge from jax._src.numpy.util import promote_dtypes_inexact @@ -67,9 +69,25 @@ def _axis_for_ndim(ndim: int) -> Iterator[None | int | tuple[int, ...]]: yield (-1, 0, 1) +def _random_invertible(rng, shape, dtype): + """ + Generate a random invertible matrix was specified shape and dtype + """ + while True: + a = rng(shape, dtype) + try: + np.linalg.inv(a) + except np.linalg.LinAlgError: + pass + else: + return a + + def osp_linalg_toeplitz(c: np.ndarray, r: np.ndarray | None = None) -> np.ndarray: """scipy.linalg.toeplitz with v1.17+ batching semantics.""" - if scipy_version >= (1, 17, 0): + # scipy 1.17 doesn't support zero batch size: https://github.com/scipy/scipy/pull/24151 + zero_batch = (0 in c.shape[:-1]) or (r is not None and 0 in r.shape[:-1]) + if scipy_version >= (1, 17, 0) and not zero_batch: return scipy.linalg.toeplitz(c, r) elif r is None: c = np.atleast_1d(c) @@ -82,6 +100,53 @@ def osp_linalg_toeplitz(c: np.ndarray, r: np.ndarray | None = None) -> np.ndarra scipy.linalg.toeplitz, signature="(m),(n)->(m,n)", otypes=(np.result_type(c, r),))(c, r) +def svd_algorithms(): + algorithms = [None] + if jtu.device_under_test() in ["cpu", "gpu"]: + algorithms.append(lax.linalg.SvdAlgorithm.QR) + if jtu.device_under_test() == "gpu": + algorithms.append(lax.linalg.SvdAlgorithm.JACOBI) + if jtu.device_under_test() == "tpu" or jtu.device_under_test() == "gpu": + algorithms.append(lax.linalg.SvdAlgorithm.POLAR) + return algorithms + + +# (complex) Eigenvectors are only unique up to an arbitrary phase. This makes the gradient +# tests based on finite differences unstable, since perturbing the input matri may cause an +# arbitrary sign flip of one or more of the eigenvectors. To remedy this, we normalize the +# vectors such that the first component has phase 0. +def _normalizing_eigh(H: np.ndarray, lower: bool, symmetrize_input: bool): + uplo = "L" if lower else "U" + e, v = jnp.linalg.eigh(H, UPLO=uplo, symmetrize_input=symmetrize_input) + top_rows = v[..., 0:1, :] + if np.issubdtype(H.dtype, np.complexfloating): + angle = -jnp.angle(top_rows) + phase = lax.complex(jnp.cos(angle), jnp.sin(angle)) + else: + phase = jnp.sign(top_rows) + v *= phase + return e, v + + +# (complex) singular vectors are only unique up to an arbitrary phase. This makes the gradient +# tests based on finite differences unstable, since perturbing the input matri may cause an +# arbitrary sign flip of one or more of the singular vectors. To remedy this, we normalize the +# singular vectors such that the first component of the left singular vectors has phase 0. +def _normalizing_svd(a: np.array, full_matrices: bool): + u, s, vt = jnp.linalg.svd(a, full_matrices=full_matrices, compute_uv=True) + top_rows = u[..., 0:1, :] + if np.issubdtype(a.dtype, np.complexfloating): + angle = -jnp.angle(top_rows) + u_phase = lax.complex(jnp.cos(angle), jnp.sin(angle)) + v_phase = lax.complex(jnp.cos(-angle), jnp.sin(-angle)) + else: + u_phase = jnp.sign(top_rows) + v_phase = u_phase + u *= u_phase + vt *= np.swapaxes(v_phase, -1, -2) + return u, s, vt + + class NumpyLinalgTest(jtu.JaxTestCase): @jtu.sample_product( @@ -96,21 +161,9 @@ def args_maker(): a = rng(factor_shape, dtype) return [np.matmul(a, jnp.conj(T(a)))] - jnp_fun = partial(jnp.linalg.cholesky, upper=upper) - - def np_fun(x, upper=upper): - # Upper argument added in NumPy 2.0.0 - if jtu.numpy_version() >= (2, 0, 0): - return np.linalg.cholesky(x, upper=upper) - result = np.linalg.cholesky(x) - if upper: - axes = list(range(x.ndim)) - axes[-1], axes[-2] = axes[-2], axes[-1] - return np.transpose(result, axes).conj() - return result - - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, - tol=1e-3) + np_fun = partial(np.linalg.cholesky, upper=upper) + jnp_fun = partial(jnp.linalg.cholesky, upper=upper, symmetrize_input=True) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=1e-3) self._CompileAndCheck(jnp_fun, args_maker) if jnp.finfo(dtype).bits == 64: @@ -232,6 +285,7 @@ def testTensorsolveAxes(self): shape=[(0, 0), (1, 1), (3, 3), (4, 4), (10, 10), (200, 200), (2, 2, 2), (2, 3, 3), (3, 2, 2)], ) + @jtu.ignore_warning(message="(divide by zero|overflow|invalid value)", category=RuntimeWarning) def testSlogdet(self, shape, dtype, method): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] @@ -286,15 +340,28 @@ def check_left_eigenvectors(a, w, vl): check_right_eigenvectors(aH, wC, vl) a, = args_maker() - results = lax.linalg.eig( - a, compute_left_eigenvectors=compute_left_eigenvectors, - compute_right_eigenvectors=compute_right_eigenvectors) - w = results[0] - if compute_left_eigenvectors: - check_left_eigenvectors(a, w, results[1]) - if compute_right_eigenvectors: - check_right_eigenvectors(a, w, results[1 + compute_left_eigenvectors]) + implementations = [None] + + if ( + jtu.is_device_cuda() + and not compute_left_eigenvectors + and cuda_versions + and cuda_versions.cusolver_get_version() >= 11701 + ): + implementations.append(jax.lax.linalg.EigImplementation.CUSOLVER) + + for implementation in implementations: + results = lax.linalg.eig( + a, compute_left_eigenvectors=compute_left_eigenvectors, + compute_right_eigenvectors=compute_right_eigenvectors, + implementation=implementation) + w = results[0] + + if compute_left_eigenvectors: + check_left_eigenvectors(a, w, results[1]) + if compute_right_eigenvectors: + check_right_eigenvectors(a, w, results[1 + compute_left_eigenvectors]) self._CompileAndCheck(partial(jnp.linalg.eig), args_maker, rtol=1e-3) @@ -308,10 +375,16 @@ def check_left_eigenvectors(a, w, vl): def testEigHandlesNanInputs(self, shape, dtype, compute_left_eigenvectors, compute_right_eigenvectors): """Verifies that `eig` fails gracefully if given non-finite inputs.""" + if jtu.is_device_cuda(): + # TODO(phawkins): CUSOLVER's implementation does not pass this test. + implementation = jax.lax.linalg.EigImplementation.LAPACK + else: + implementation = None a = jnp.full(shape, jnp.nan, dtype) results = lax.linalg.eig( a, compute_left_eigenvectors=compute_left_eigenvectors, - compute_right_eigenvectors=compute_right_eigenvectors) + compute_right_eigenvectors=compute_right_eigenvectors, + implementation=implementation) for result in results: self.assertTrue(np.all(np.isnan(result))) @@ -341,7 +414,13 @@ def testEigvals(self, shape, dtype): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] a, = args_maker() - w1, _ = jnp.linalg.eig(a) + result = jnp.linalg.eig(a) + # Check that eig returns a namedtuple with the right fields + self.assertTrue(hasattr(result, 'eigenvalues')) + self.assertTrue(hasattr(result, 'eigenvectors')) + self.assertIs(result.eigenvalues, result[0]) + self.assertIs(result.eigenvectors, result[1]) + w1 = result.eigenvalues w2 = jnp.linalg.eigvals(a) self.assertAllClose(w1, w2, rtol={np.complex64: 1e-5, np.complex128: 2e-14}) @@ -413,6 +492,15 @@ def testEigh(self, n, dtype, lower): w_np.astype(w.dtype), w, atol=tol * np.linalg.norm(a), rtol=tol ) + @jax._src.config.explicit_x64_dtypes("allow") + @jtu.run_on_devices("gpu") + @unittest.skip("Needs a large amount of GPU memory, doesn't work in CI") + def testEighLargeMatrix(self): + # https://github.com/jax-ml/jax/issues/33062 + n = 16384 + A = jnp.eye(n, dtype=jnp.float64) + jax.block_until_ready(jax.lax.linalg.eigh(A)) + @jtu.sample_product( start=[0, 1, 63, 64, 65, 255], end=[1, 63, 64, 65, 256], @@ -479,10 +567,15 @@ def testEighZeroDiagonal(self): eps = jnp.finfo(a.dtype).eps with jax.numpy_rank_promotion('allow'): self.assertLessEqual( - np.linalg.norm(np.matmul(a, v) - w * v), 2 * eps * np.linalg.norm(a) + np.linalg.norm(np.matmul(a, v) - w * v), 2.5 * eps * np.linalg.norm(a) ) + def testEighTinyNorm(self): + if jtu.is_device_rocm(): + # numerical errors seen as of ROCm 7.2 due to hipSolver issue + # TODO: re-enable the test once the hipSolver issue is fixed + self.skipTest("testEighNorm not supported on ROCm due to hipSOLVER issue") rng = jtu.rand_default(self.rng()) a = rng((300, 300), dtype=np.float32) eps = jnp.finfo(a.dtype).eps @@ -545,14 +638,14 @@ def args_maker(): ) @jtu.sample_product( - shape=[(1, 1), (4, 4), (5, 5), (50, 50), (2, 10, 10)], - dtype=float_types + complex_types, - lower=[True, False], + shape=[(1, 1), (4, 4), (5, 5), (25, 25), (2, 10, 10)], + dtype=float_types + complex_types, + lower=[True, False], ) def testEighGrad(self, shape, dtype, lower): + if platform.system() == "Windows": + self.skipTest("Skip on Windows due to tolerance issues.") rng = jtu.rand_default(self.rng()) - self.skipTest("Test fails with numeric errors.") - uplo = "L" if lower else "U" a = rng(shape, dtype) a = (a + np.conj(T(a))) / 2 ones = np.ones((a.shape[-1], a.shape[-1]), dtype=dtype) @@ -560,51 +653,12 @@ def testEighGrad(self, shape, dtype, lower): # Gradient checks will fail without symmetrization as the eigh jvp rule # is only correct for tangents in the symmetric subspace, whereas the # checker checks against unconstrained (co)tangents. - if dtype not in complex_types: - f = partial(jnp.linalg.eigh, UPLO=uplo, symmetrize_input=True) - else: # only check eigenvalue grads for complex matrices - f = lambda a: partial(jnp.linalg.eigh, UPLO=uplo, symmetrize_input=True)(a)[0] - jtu.check_grads(f, (a,), 2, rtol=1e-5) - - @jtu.sample_product( - shape=[(1, 1), (4, 4), (5, 5), (50, 50)], - dtype=complex_types, - lower=[True, False], - eps=[1e-5], - ) - def testEighGradVectorComplex(self, shape, dtype, lower, eps): - rng = jtu.rand_default(self.rng()) - # Special case to test for complex eigenvector grad correctness. - # Exact eigenvector coordinate gradients are hard to test numerically for complex - # eigensystem solvers given the extra degrees of per-eigenvector phase freedom. - # Instead, we numerically verify the eigensystem properties on the perturbed - # eigenvectors. You only ever want to optimize eigenvector directions, not coordinates! - uplo = "L" if lower else "U" - a = rng(shape, dtype) - a = (a + np.conj(a.T)) / 2 - a = np.tril(a) if lower else np.triu(a) - a_dot = eps * rng(shape, dtype) - a_dot = (a_dot + np.conj(a_dot.T)) / 2 - a_dot = np.tril(a_dot) if lower else np.triu(a_dot) - # evaluate eigenvector gradient and groundtruth eigensystem for perturbed input matrix - f = partial(jnp.linalg.eigh, UPLO=uplo) - (w, v), (dw, dv) = jvp(f, primals=(a,), tangents=(a_dot,)) - self.assertTrue(jnp.issubdtype(w.dtype, jnp.floating)) - self.assertTrue(jnp.issubdtype(dw.dtype, jnp.floating)) - new_a = a + a_dot - new_w, new_v = f(new_a) - new_a = (new_a + np.conj(new_a.T)) / 2 - new_w = new_w.astype(new_a.dtype) - # Assert rtol eigenvalue delta between perturbed eigenvectors vs new true eigenvalues. - RTOL = 1e-2 - with jax.numpy_rank_promotion('allow'): - assert np.max( - np.abs((np.diag(np.dot(np.conj((v+dv).T), np.dot(new_a,(v+dv)))) - new_w) / new_w)) < RTOL - # Redundant to above, but also assert rtol for eigenvector property with new true eigenvalues. - assert np.max( - np.linalg.norm(np.abs(new_w*(v+dv) - np.dot(new_a, (v+dv))), axis=0) / - np.linalg.norm(np.abs(new_w*(v+dv)), axis=0) - ) < RTOL + f = partial(_normalizing_eigh, lower=lower, symmetrize_input=True) + norm_a = jnp.linalg.norm(a) + eps = 2e-5 * norm_a + atol = 5e-3 * norm_a + rtol = 0.025 + jtu.check_grads(f, (a,), 2, atol=atol, rtol=rtol, eps=eps) def testEighGradPrecision(self): rng = jtu.rand_default(self.rng()) @@ -705,10 +759,7 @@ def testStringInfNorm(self): def testMatrixNorm(self, shape, dtype, keepdims, ord): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] - if jtu.numpy_version() < (2, 0, 0): - np_fn = partial(np.linalg.norm, ord=ord, keepdims=keepdims, axis=(-2, -1)) - else: - np_fn = partial(np.linalg.matrix_norm, ord=ord, keepdims=keepdims) + np_fn = partial(np.linalg.matrix_norm, ord=ord, keepdims=keepdims) jnp_fn = partial(jnp.linalg.matrix_norm, ord=ord, keepdims=keepdims) self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=1e-3) self._CompileAndCheck(jnp_fn, args_maker) @@ -723,7 +774,6 @@ def testEmptyMatrixNorm(self, shape, dtype, ord): norm = jnp.linalg.matrix_norm(x, ord=ord) self.assertEqual(norm, 0) - @skipIf(jtu.numpy_version() < (2, 0, 0), "np.linalg.vector_norm requires NumPy 2.0") @jtu.sample_product( [ dict(shape=shape, axis=axis) @@ -766,7 +816,7 @@ def testEmptyVectorNorm(self, dtype, ord): def testVecdot(self, lhs_shape, rhs_shape, axis, dtype): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] - np_fn = jtu.numpy_vecdot if jtu.numpy_version() < (2, 0, 0) else np.linalg.vecdot + np_fn = np.linalg.vecdot np_fn = jtu.promote_like_jnp(partial(np_fn, axis=axis)) jnp_fn = partial(jnp.linalg.vecdot, axis=axis) tol = {np.float16: 1e-2, np.float32: 2e-2, np.float64: 1e-12, @@ -794,8 +844,7 @@ def testVecdot(self, lhs_shape, rhs_shape, axis, dtype): def testMatmul(self, lhs_shape, rhs_shape, dtype): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] - np_fn = jtu.promote_like_jnp( - np.matmul if jtu.numpy_version() < (2, 0, 0) else np.linalg.matmul) + np_fn = jtu.promote_like_jnp(np.linalg.matmul) jnp_fn = jnp.linalg.matmul tol = {np.float16: 1e-2, np.float32: 2e-2, np.float64: 1e-12, np.complex128: 1e-12} @@ -821,10 +870,7 @@ def testMatmul(self, lhs_shape, rhs_shape, dtype): def testTensordot(self, lhs_shape, rhs_shape, axes, dtype): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] - np_fn = jtu.promote_like_jnp( - partial( - np.tensordot if jtu.numpy_version() < (2, 0, 0) else np.linalg.tensordot, - axes=axes)) + np_fn = jtu.promote_like_jnp(partial(np.linalg.tensordot, axes=axes)) jnp_fn = partial(jnp.linalg.tensordot, axes=axes) tol = {np.float16: 1e-2, np.float32: 2e-2, np.float64: 1e-12, np.complex128: 1e-12} @@ -837,44 +883,52 @@ def testTensordot(self, lhs_shape, rhs_shape, axes, dtype): preferred_element_type=dtype) self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol) - @jtu.sample_product( - [ - dict(m=m, n=n, full_matrices=full_matrices, hermitian=hermitian) - for (m, n), full_matrices in ( - list( - itertools.product( - itertools.product([0, 2, 7, 29, 32, 53], repeat=2), - [False, True], + @parameterized.product( + jtu.sample_product_testcases( + [ + dict(m=m, n=n, full_matrices=full_matrices, hermitian=hermitian) + for (m, n), full_matrices in ( + list( + itertools.product( + itertools.product([0, 2, 7, 29, 32, 53], repeat=2), + [False, True], + ) ) + + + # Test cases that ensure we are economical when computing the SVD + # and its gradient. If we form a 400kx400k matrix explicitly we + # will OOM. + [((400000, 2), False), ((2, 400000), False)] ) - + - # Test cases that ensure we are economical when computing the SVD - # and its gradient. If we form a 400kx400k matrix explicitly we - # will OOM. - [((400000, 2), False), ((2, 400000), False)] - ) - for hermitian in ([False, True] if m == n else [False]) - ], - b=[(), (3,), (2, 3)], - dtype=float_types + complex_types, - compute_uv=[False, True], - algorithm=[None, lax.linalg.SvdAlgorithm.QR, lax.linalg.SvdAlgorithm.JACOBI], + for hermitian in ([False, True] if m == n else [False]) + ], + b=[(), (3,), (2, 3)], + dtype=float_types + complex_types, + compute_uv=[False, True], + ), + algorithm=svd_algorithms() ) @jax.default_matmul_precision("float32") - def testSVD(self, b, m, n, dtype, full_matrices, compute_uv, hermitian, algorithm): - if algorithm is not None: - if hermitian: - self.skipTest("Hermitian SVD doesn't support the algorithm parameter.") - if not jtu.test_device_matches(["cpu", "gpu"]): - self.skipTest("SVD algorithm selection only supported on CPU and GPU.") - # TODO(danfm): Remove this check after 0.5.2 is released. - if jtu.test_device_matches(["cpu"]) and jtu.jaxlib_version() <= (0, 5, 1): - self.skipTest("SVD algorithm selection on CPU requires a newer jaxlib version.") - if jtu.test_device_matches(["cpu"]) and algorithm == lax.linalg.SvdAlgorithm.JACOBI: - self.skipTest("Jacobi SVD not supported on GPU.") - - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(b + (m, n), dtype)] + def testSVD(self, b, m, n, dtype, full_matrices, compute_uv, hermitian, + algorithm): + if hermitian and algorithm is not None: + # Hermitian SVD doesn't support the algorithm parameter. + self.skipTest("Hermitian SVD doesn't support the algorithm parameter") + + if jtu.is_device_rocm() and algorithm == lax.linalg.SvdAlgorithm.POLAR: + self.skipTest("ROCM polar SVD not implemented") + + if ( + jtu.test_device_matches(["cuda"]) + and (algorithm, m, n) in [ + (lax.linalg.SvdAlgorithm.POLAR, 400000, 2), + (lax.linalg.SvdAlgorithm.POLAR, 2, 400000), + (lax.linalg.SvdAlgorithm.JACOBI, 400000, 2), + (lax.linalg.SvdAlgorithm.JACOBI, 2, 400000), + ] + ): + # Test fails with CUDA polar and jacobi decompositions + self.skipTest("Test fails with CUDA polar and jacobi decompositions") def compute_max_backward_error(operand, reconstructed_operand): error_norm = np.linalg.norm(operand - reconstructed_operand, @@ -884,6 +938,9 @@ def compute_max_backward_error(operand, reconstructed_operand): max_backward_error = np.amax(backward_error) return max_backward_error + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(b + (m, n), dtype)] + tol = 100 * jnp.finfo(dtype).eps reconstruction_tol = 2 * tol unitariness_tol = 3 * tol @@ -935,8 +992,8 @@ def compute_max_backward_error(operand, reconstructed_operand): unitary_mat, rtol=unitariness_tol, atol=unitariness_tol) else: - self.assertTrue(np.allclose(np.linalg.svd(a, compute_uv=False), - np.asarray(out), atol=1e-4, rtol=1e-4)) + self.assertAllClose(np.linalg.svd(a, compute_uv=False), np.asarray(out), + atol=1e-4, rtol=3e-4) self._CompileAndCheck(partial(fun, full_matrices=full_matrices, compute_uv=compute_uv), @@ -952,26 +1009,45 @@ def compute_max_backward_error(operand, reconstructed_operand): jtu.check_jvp(svd, partial(jvp, svd), (a,), rtol=5e-2, atol=2e-1) if compute_uv and (not full_matrices): - b, = args_maker() + d, = args_maker() def f(x): u, s, v = jnp.linalg.svd( - a + x * b, + a + x * d, full_matrices=full_matrices, compute_uv=compute_uv) vdiag = jnp.vectorize(jnp.diag, signature='(k)->(k,k)') return jnp.matmul(jnp.matmul(u, vdiag(s).astype(u.dtype)), v).real _, t_out = jvp(f, (1.,), (1.,)) if dtype == np.complex128: - atol = 2e-13 + tol = 2e-13 else: - atol = 6e-4 - self.assertArraysAllClose(t_out, b.real, atol=atol) + tol = 6e-4 + self.assertArraysAllClose(t_out, d.real, atol=tol, rtol=tol) def testJspSVDBasic(self): # since jax.scipy.linalg.svd is almost the same as jax.numpy.linalg.svd # do not check it functionality here jsp.linalg.svd(np.ones((2, 2), dtype=np.float32)) + @jtu.sample_product( + shape=[(1, 1), (4, 4), (2, 5), (5, 2), (5, 5), (2, 5, 5)], + dtype=float_types + complex_types, + full_matrices=[True, False], + compute_uv=[True, False], + ) + @jax.default_matmul_precision("float32") + def testSVDGrad(self, shape, dtype, full_matrices, compute_uv): + rng = jtu.rand_default(self.rng()) + a = rng(shape, dtype) + if not compute_uv: + f = partial(jnp.linalg.svd, full_matrices=False, compute_uv=False) + else: + f = partial(_normalizing_svd, full_matrices=full_matrices) + if full_matrices and shape[-1] != shape[-2]: + self.skipTest("JVP for SVD not implemented for full matrices.") + + jtu.check_grads(f, (a,), order=2, rtol=0.035, eps=1.0 / 512) + @jtu.sample_product( shape=[(0, 2), (2, 0), (3, 4), (3, 3), (4, 3)], dtype=[np.float32], @@ -1031,7 +1107,8 @@ def compare_orthogonal(q1, q2): phases = np.divide(sum_of_ratios, np.abs(sum_of_ratios)) q1 *= phases nm = norm(q1 - q2) - self.assertTrue(np.all(nm < 160), msg=f"norm={np.amax(nm)}") + max_norm = 220 if jtu.is_device_tpu(7, 'x') else 160 + self.assertTrue(np.all(nm < max_norm), msg=f"norm={np.amax(nm)}") # Check a ~= qr norm_error = norm(a - np.matmul(lq, lr)) @@ -1064,7 +1141,7 @@ def testQrInvalidDtypeCPU(self, shape=(5, 6), dtype=np.float16): else: err, msg = Exception, "Unsupported dtype" with self.assertRaisesRegex(err, msg): - jnp.linalg.qr(arr) + jax.block_until_ready(jnp.linalg.qr(arr)) @jtu.sample_product( shape=[(10, 4, 5), (5, 3, 3), (7, 6, 4)], @@ -1081,8 +1158,17 @@ def testQrBatching(self, shape, dtype): pnorm=[jnp.inf, -jnp.inf, 1, -1, 2, -2, 'fro'], dtype=float_types + complex_types, ) - @jtu.skip_on_devices("gpu") # TODO(#2203): numerical errors def testCond(self, shape, pnorm, dtype): + + if jtu.test_device_matches(['gpu']): + # Unskipping test for ROCm while leaving + # original skip in place for other GPUs as per + # commit: e81024f5053def119eddb7fb06ff6c4f7b5948a8 + # + # Original note: TODO(#2203): numerical errors + if not jtu.is_device_rocm(): + self.skipTest("Unsupported platform") + def gen_mat(): # arr_gen = jtu.rand_some_nan(self.rng()) arr_gen = jtu.rand_default(self.rng()) @@ -1149,8 +1235,7 @@ def testSolveBroadcasting(self, lhs_shape, rhs_shape): # that we match NumPy's convention in all cases. rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(lhs_shape, 'float32'), rng(rhs_shape, 'float32')] - if jtu.numpy_version() >= (2, 0, 0): # NumPy 2.0 semantics - self._CheckAgainstNumpy(np.linalg.solve, jnp.linalg.solve, args_maker, tol=1E-3) + self._CheckAgainstNumpy(np.linalg.solve, jnp.linalg.solve, args_maker, tol=1E-3) self._CompileAndCheck(jnp.linalg.solve, args_maker) @jtu.sample_product( @@ -1161,14 +1246,7 @@ def testInv(self, shape, dtype): rng = jtu.rand_default(self.rng()) def args_maker(): - invertible = False - while not invertible: - a = rng(shape, dtype) - try: - np.linalg.inv(a) - invertible = True - except np.linalg.LinAlgError: - pass + a = _random_invertible(rng=rng, shape=shape, dtype=dtype) return [a] self._CheckAgainstNumpy(np.linalg.inv, jnp.linalg.inv, args_maker, @@ -1183,6 +1261,7 @@ def args_maker(): for hermitian in ([False, True] if shape[-1] == shape[-2] else [False])], dtype=float_types + complex_types, ) + @jtu.ignore_warning(message="invalid value", category=RuntimeWarning) def testPinv(self, shape, hermitian, dtype): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] @@ -1199,11 +1278,15 @@ def np_fn(a): # TODO(phawkins): 6e-2 seems like a very loose tolerance. jtu.check_grads(jnp_fn, args_maker(), 1, rtol=6e-2, atol=1e-3) - def testPinvDeprecatedArgs(self): + def testPinvRcond(self): x = jnp.ones((3, 3)) - with self.assertDeprecationWarnsOrRaises("jax-numpy-linalg-pinv-rcond", - "The rcond argument for linalg.pinv is deprecated."): - jnp.linalg.pinv(x, rcond=1E-2) + with self.assertRaisesWithLiteralMatch( + ValueError, "pinv: only one of rtol and rcond may be specified."): + jnp.linalg.pinv(x, rcond=1E-2, rtol=1E-2) + self.assertArraysEqual( + jnp.linalg.pinv(x, rcond=1E-2), + jnp.linalg.pinv(x, rtol=1E-2) + ) def testPinvGradIssue2792(self): def f(p): @@ -1233,6 +1316,13 @@ def testMatrixPower(self, shape, dtype, n): self._CompileAndCheck(partial(jnp.linalg.matrix_power, n=n), args_maker, rtol=1e-3) + def testMatrixPowerBool(self): + # Regression test for https://github.com/jax-ml/jax/issues/28603 + mat = np.array([[True,True], [False,True]]) + np_result = np.linalg.matrix_power(mat, 2) + jnp_result = jnp.linalg.matrix_power(mat, 2) + self.assertArraysEqual(np_result, jnp_result) + @jtu.sample_product( shape=[(3, ), (1, 2), (8, 5), (4, 4), (5, 5), (50, 50), (3, 4, 5), (2, 3, 4, 5)], @@ -1247,11 +1337,15 @@ def testMatrixRank(self, shape, dtype): self._CompileAndCheck(jnp.linalg.matrix_rank, args_maker, check_dtypes=False, rtol=1e-3) - def testMatrixRankDeprecatedArgs(self): + def testMatrixRankTol(self): x = jnp.ones((3, 3)) - with self.assertDeprecationWarnsOrRaises("jax-numpy-linalg-matrix_rank-tol", - "The tol argument for linalg.matrix_rank is deprecated."): - jnp.linalg.matrix_rank(x, tol=1E-2) + with self.assertRaisesWithLiteralMatch( + ValueError, "matrix_rank: only one of tol or rtol may be specified."): + jnp.linalg.matrix_rank(x, rtol=1E-2, tol=1E-2) + self.assertArraysEqual( + jnp.linalg.matrix_rank(x, rtol=1E-2), + jnp.linalg.matrix_rank(x, tol=1E-2) + ) @jtu.sample_product( shapes=[ @@ -1305,6 +1399,16 @@ def testLstsq(self, lhs_shape, rhs_shape, dtype, rcond): # TODO: # jtu.check_grads(lambda *args: jnp_fun(*args)[0], args_maker(), order=2, atol=1e-2, rtol=1e-2) + @jtu.sample_product( + shape=[(2, 1), (2, 2), (1, 2)] + ) + def testLstsqZeroMatrix(self, shape): + # Regression test for https://github.com/jax-ml/jax/issues/32666 + args_maker = lambda: [np.zeros(shape), np.ones((shape))] + np_fun = np.linalg.lstsq + jnp_fun = partial(jnp.linalg.lstsq, numpy_resid=True) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) + # Regression test for incorrect type for eigenvalues of a complex matrix. def testIssue669(self): def test(x): @@ -1361,9 +1465,7 @@ def testCross(self, lhs_shape, rhs_shape, lhs_dtype, rhs_dtype, axis): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] lax_fun = partial(jnp.linalg.cross, axis=axis) - np_fun = jtu.promote_like_jnp(partial( - np.cross if jtu.numpy_version() < (2, 0, 0) else np.linalg.cross, - axis=axis)) + np_fun = jtu.promote_like_jnp(partial(np.linalg.cross, axis=axis)) with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): self._CheckAgainstNumpy(np_fun, lax_fun, args_maker) self._CompileAndCheck(lax_fun, args_maker) @@ -1378,8 +1480,7 @@ def testOuter(self, lhs_shape, rhs_shape, lhs_dtype, rhs_dtype): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] lax_fun = jnp.linalg.outer - np_fun = jtu.promote_like_jnp( - np.outer if jtu.numpy_version() < (2, 0, 0) else np.linalg.outer) + np_fun = jtu.promote_like_jnp(np.linalg.outer) with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): self._CheckAgainstNumpy(np_fun, lax_fun, args_maker) self._CompileAndCheck(lax_fun, args_maker) @@ -1393,10 +1494,7 @@ def testDiagonal(self, shape, dtype, offset): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] lax_fun = partial(jnp.linalg.diagonal, offset=offset) - if jtu.numpy_version() >= (2, 0, 0): - np_fun = partial(np.linalg.diagonal, offset=offset) - else: - np_fun = partial(np.diagonal, offset=offset, axis1=-2, axis2=-1) + np_fun = partial(np.linalg.diagonal, offset=offset) self._CheckAgainstNumpy(np_fun, lax_fun, args_maker) self._CompileAndCheck(lax_fun, args_maker) @@ -1405,10 +1503,7 @@ def testTrace(self): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] lax_fun = partial(jnp.linalg.trace, offset=offset, dtype=out_dtype) - if jtu.numpy_version() >= (2, 0, 0): - np_fun = partial(np.linalg.trace, offset=offset) - else: - np_fun = partial(np.trace, offset=offset, axis1=-2, axis2=-1, dtype=out_dtype) + np_fun = partial(np.linalg.trace, offset=offset) self._CheckAgainstNumpy(np_fun, lax_fun, args_maker) self._CompileAndCheck(lax_fun, args_maker) @@ -1809,7 +1904,7 @@ def testHessenberg(self, shape, dtype, calc_q): dtype=float_types + complex_types, lower=[False, True], ) - @jtu.skip_on_devices("tpu","rocm") + @jtu.skip_on_devices("tpu") def testTridiagonal(self, shape, dtype, lower): rng = jtu.rand_default(self.rng()) def jax_func(a): @@ -1999,6 +2094,9 @@ def func(x): dtype=float_types + complex_types, ) @jtu.run_on_devices("cpu") + @jtu.ignore_warning( + category=RuntimeWarning, message='invalid value encountered in matmul' + ) def testSqrtmPSDMatrix(self, shape, dtype): # Checks against scipy.linalg.sqrtm when the principal square root # is guaranteed to be unique (i.e no negative real eigenvalue) @@ -2010,11 +2108,8 @@ def testSqrtmPSDMatrix(self, shape, dtype): tol = 1e-4 else: tol = 1e-8 - self._CheckAgainstNumpy(osp.linalg.sqrtm, - jsp.linalg.sqrtm, - args_maker, - tol=tol, - check_dtypes=False) + self._CheckAgainstNumpy(osp.linalg.sqrtm, jsp.linalg.sqrtm, args_maker, + tol=tol, check_dtypes=False) self._CompileAndCheck(jsp.linalg.sqrtm, args_maker) @jtu.sample_product( @@ -2078,7 +2173,6 @@ def testToeplitzConstruction(self, rshape, rdtype, cshape, cdtype): @jtu.sample_product( shape=[(), (3,), (1, 4), (1, 5, 9), (11, 0, 13)], dtype=float_types + complex_types + int_types) - @jtu.skip_on_devices("rocm") def testToeplitzSymmetricConstruction(self, shape, dtype): if (dtype in [np.float64, np.complex128] and not config.enable_x64.value): @@ -2109,6 +2203,70 @@ def testToeplitzConstructionWithKnownCases(self): [2, 1, 4, 5], [3, 2, 1, 4]], dtype=np.float32)) + @jtu.sample_product( + shape=[(2, 3), (4, 6), (50, 7), (100, 110)], + dtype = float_types + complex_types, + method = ["schur", "eigen"] + ) + @jtu.run_on_devices("cpu", "gpu") + @jax.default_matmul_precision("float32") + def test_solve_sylvester(self, shape, dtype, method): + if jtu.test_device_matches(["gpu"]) and method == "schur": + self.skipTest("Schur not supported on GPU.") + + tol = {np.float32: 5e-2, np.float64: 1e-9, np.complex64: 5e-2, np.complex128: 1e-9} + + def args_maker(): + rng = jtu.rand_default(self.rng()) + m, n = shape + + A = rng(shape=(m, m), dtype=dtype) + B = rng(shape=(n, n), dtype=dtype) + X_true = rng(shape=(m, n), dtype=dtype) + + C = A @ X_true + X_true @ B + return [A, B, C] + + jnp_fun = partial(jsp.linalg.solve_sylvester, method=method) + + self._CheckAgainstNumpy(osp.linalg.solve_sylvester, jnp_fun, args_maker, tol=tol) + self._CompileAndCheck(jnp_fun, args_maker) + + + @jtu.sample_product( + n=[3, 6, 7, 100], + dtype = float_types + complex_types, + method = ["schur", "eigen"] + ) + @jtu.run_on_devices("cpu", "gpu") + def test_ill_conditioned_sylvester(self, n, dtype, method): + """ + Test no solution case to AX + XB = C using the eigen decomposition method. + When the sum of the eigenvalues of A and B are zero there is no solution. + We simulate this case below by randomly selecting the eigenvalues of A and then assign the + eigenvalues of B as negative eigenvalues of A. We say that A and B are ill-conditioned. + """ + if jtu.test_device_matches(["gpu"]) and method == "schur": + self.skipTest("Schur not supported on GPU.") + + rng = jtu.rand_default(self.rng()) + + # Define eigenvalues that sum to zero + eigenvalues_A = rng(shape=(n,), dtype=dtype) + eigenvalues_B = -eigenvalues_A + P = _random_invertible(rng=rng, shape=(n, n), dtype=dtype) + + # Construct A and B matrices using selected eigenvalues that positionally sum to zero + D_A = np.diag(eigenvalues_A) + D_B = np.diag(eigenvalues_B) + P_inv = np.linalg.inv(P) + A = P @ D_A @ P_inv + B = P @ D_B @ P_inv + + C = rng(shape=(n, n), dtype=dtype) + sylv_solution = jsp.linalg.solve_sylvester(A, B, C, method=method, tol=1e-5) + self.assertArraysEqual(sylv_solution, np.full((n, n), np.nan, dtype)) + class LaxLinalgTest(jtu.JaxTestCase): """Tests for lax.linalg primitives.""" @@ -2120,25 +2278,50 @@ class LaxLinalgTest(jtu.JaxTestCase): sort_eigenvalues=[True, False], ) def testEigh(self, n, dtype, lower, sort_eigenvalues): - rng = jtu.rand_default(self.rng()) - tol = 1e-3 - args_maker = lambda: [rng((n, n), dtype)] + implementations = [ + None, + lax.linalg.EighImplementation.QR, + lax.linalg.EighImplementation.JACOBI, + lax.linalg.EighImplementation.QDWH, + ] - a, = args_maker() - a = (a + np.conj(a.T)) / 2 - v, w = lax.linalg.eigh(np.tril(a) if lower else np.triu(a), - lower=lower, symmetrize_input=False, - sort_eigenvalues=sort_eigenvalues) - w = np.asarray(w) - v = np.asarray(v) - self.assertLessEqual( - np.linalg.norm(np.eye(n) - np.matmul(np.conj(T(v)), v)), 1e-3) - self.assertLessEqual(np.linalg.norm(np.matmul(a, v) - w * v), - tol * np.linalg.norm(a)) + for implementation in implementations: + if ( + implementation == lax.linalg.EighImplementation.QR + and jtu.test_device_matches(["tpu"]) + ): + continue + if ( + implementation == lax.linalg.EighImplementation.JACOBI + and jtu.test_device_matches(["cpu"]) + ): + continue + if ( + implementation == lax.linalg.EighImplementation.QDWH + and jtu.test_device_matches(["cpu", "gpu"]) + ): + continue + + rng = jtu.rand_default(self.rng()) + tol = 1e-3 + args_maker = lambda: [rng((n, n), dtype)] + + a, = args_maker() + a = (a + np.conj(a.T)) / 2 + v, w = lax.linalg.eigh(np.tril(a) if lower else np.triu(a), + lower=lower, symmetrize_input=False, + sort_eigenvalues=sort_eigenvalues, + implementation=implementation) + w = np.asarray(w) + v = np.asarray(v) + self.assertLessEqual( + np.linalg.norm(np.eye(n) - np.matmul(np.conj(T(v)), v)), 1e-3) + self.assertLessEqual(np.linalg.norm(np.matmul(a, v) - w * v), + tol * np.linalg.norm(a)) - w_expected, v_expected = np.linalg.eigh(np.asarray(a)) - self.assertAllClose(w_expected, w if sort_eigenvalues else np.sort(w), - rtol=1e-4, atol=1e-4) + w_expected, v_expected = np.linalg.eigh(np.asarray(a)) + self.assertAllClose(w_expected, w if sort_eigenvalues else np.sort(w), + rtol=1e-4, atol=1e-4) def run_eigh_tridiagonal_test(self, alpha, beta): n = alpha.shape[-1] @@ -2197,6 +2380,7 @@ def testSelect(self, dtype): @jtu.sample_product(shape=[(3,), (3, 4), (3, 4, 5)], dtype=float_types + complex_types) + @jtu.skip_on_devices("rocm") # Numerical errors on ROCm def test_tridiagonal_solve(self, shape, dtype): if dtype not in float_types and jtu.test_device_matches(["gpu"]): self.skipTest("Data type not supported on GPU") @@ -2213,7 +2397,10 @@ def build_tri(dl, d, du): build_tri = jax.vmap(build_tri) a = build_tri(dl, d, du) - self.assertAllClose(a @ x, b, atol=5e-5, rtol=1e-4) + with jax.default_matmul_precision("float32"): + self.assertAllClose(a @ x, b, atol={ + np.float32: 1e-3, np.float64: 1e-10, np.complex64: 1e-3, + np.complex128: 1e-10}) def test_tridiagonal_solve_endpoints(self): # tridagonal_solve shouldn't depend on the endpoints being explicitly zero. @@ -2234,6 +2421,10 @@ def test_tridiagonal_solve_endpoints(self): @jtu.sample_product(shape=[(3,), (3, 4)], dtype=float_types + complex_types) def test_tridiagonal_solve_grad(self, shape, dtype): + if jtu.is_device_rocm() and shape == (3, 4) and dtype == np.float32: + # numerical errors seen as of ROCm 7.2 due to rocSparse issue for grad0 variant + # TODO: re-enable the test once the rocSparse issue is fixed + self.skipTest("test_tridiagonal_solve_grad0 not supported on ROCm due to rocSparse issue") if dtype not in float_types and jtu.test_device_matches(["gpu"]): self.skipTest("Data type not supported on GPU") rng = self.rng() @@ -2283,10 +2474,7 @@ def testMatrixTranspose(self, shape, dtype): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] jnp_fun = jnp.linalg.matrix_transpose - if jtu.numpy_version() < (2, 0, 0): - np_fun = lambda x: np.swapaxes(x, -1, -2) - else: - np_fun = np.linalg.matrix_transpose + np_fun = np.linalg.matrix_transpose self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) @@ -2332,6 +2520,22 @@ def testSymmetricProduct(self, shape, dtype, symmetrize_output): self.assertAllClose( new_product_with_batching, old_product, atol=atol) + @jtu.sample_product( + n=[0, 1, 5, 10, 20], + kind=["symmetric", "lower", "upper"], + ) + @jax.default_matmul_precision("float32") + def testPascal(self, n, kind): + args_maker = lambda: [] + osp_fun = partial(osp.linalg.pascal, n=n, kind=kind, exact=False) + jsp_fun = partial(jsp.linalg.pascal, n=n, kind=kind) + self._CheckAgainstNumpy(osp_fun, + jsp_fun, args_maker, + atol=1e-3, + rtol=1e-2 if jtu.test_device_matches(['tpu']) else 1e-3, + check_dtypes=False) + self._CompileAndCheck(jsp_fun, args_maker) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/lobpcg_test.py b/tests/lobpcg_test.py index fc2b0df849d1..bd67c8d522ac 100644 --- a/tests/lobpcg_test.py +++ b/tests/lobpcg_test.py @@ -21,7 +21,6 @@ import functools import re import os -import unittest from absl.testing import absltest from absl.testing import parameterized @@ -272,7 +271,7 @@ def checkLobpcgMonotonicity(self, matrix_name, n, k, m, tol, dtype): self._possibly_plot(A, eigs, X, m, matrix_name) def _possibly_plot(self, A, eigs, X, m, matrix_name): - if not os.getenv('LOBPCG_EMIT_DEBUG_PLOTS'): + if os.getenv('LOBPCG_EMIT_DEBUG_PLOTS', '0') != '1': return if isinstance(A, (np.ndarray, jax.Array)): @@ -370,12 +369,6 @@ def checkApproxEigs(self, example_name, dtype): class F32LobpcgTest(LobpcgTest): - def setUp(self): - # TODO(phawkins): investigate this failure - if jtu.test_device_matches(["gpu"]): - raise unittest.SkipTest("Test is failing on CUDA gpus") - super().setUp() - def testLobpcgValidatesArguments(self): A, _ = _concrete_generators(np.float32)['id'](100, 10) X = self.rng().standard_normal(size=(100, 10)).astype(np.float32) @@ -394,7 +387,7 @@ def testLobpcgValidatesArguments(self): linalg.lobpcg_standard(A[:50, :50], X[:50]) @parameterized.named_parameters(_make_concrete_cases(f64=False)) - @jtu.skip_on_devices("gpu") + @jtu.skip_on_devices("cuda") def testLobpcgConsistencyF32(self, matrix_name, n, k, m, tol): self.checkLobpcgConsistency(matrix_name, n, k, m, tol, jnp.float32) @@ -410,24 +403,18 @@ def testCallableMatricesF32(self, matrix_name): @jtu.with_config(jax_enable_x64=True) class F64LobpcgTest(LobpcgTest): - def setUp(self): - # TODO(phawkins): investigate this failure - if jtu.test_device_matches(["gpu"]): - raise unittest.SkipTest("Test is failing on CUDA gpus") - super().setUp() - @parameterized.named_parameters(_make_concrete_cases(f64=True)) - @jtu.skip_on_devices("tpu", "gpu") + @jtu.skip_on_devices("tpu", "cuda") def testLobpcgConsistencyF64(self, matrix_name, n, k, m, tol): self.checkLobpcgConsistency(matrix_name, n, k, m, tol, jnp.float64) @parameterized.named_parameters(_make_concrete_cases(f64=True)) - @jtu.skip_on_devices("tpu", "gpu") + @jtu.skip_on_devices("tpu", "cuda") def testLobpcgMonotonicityF64(self, matrix_name, n, k, m, tol): self.checkLobpcgMonotonicity(matrix_name, n, k, m, tol, jnp.float64) @parameterized.named_parameters(_make_callable_cases(f64=True)) - @jtu.skip_on_devices("tpu", "gpu") + @jtu.skip_on_devices("tpu", "cuda") def testCallableMatricesF64(self, matrix_name): self.checkApproxEigs(matrix_name, jnp.float64) diff --git a/tests/logging_test.py b/tests/logging_test.py index cfe10c5a90e2..bb6eaf0f3b68 100644 --- a/tests/logging_test.py +++ b/tests/logging_test.py @@ -27,7 +27,6 @@ import jax import jax._src.test_util as jtu from jax._src import xla_bridge -from jax._src.logging_config import _default_TF_CPP_MIN_LOG_LEVEL # Note: importing absltest causes an extra absl root log handler to be # registered, which causes extra debug log messages. We don't expect users to @@ -150,7 +149,7 @@ def test_debug_logging(self): with capture_jax_logs() as log_output: jax.jit(lambda x: x + 1)(1) self.assertIn("Finished tracing + transforming", log_output.getvalue()) - self.assertIn("Compiling ", log_output.getvalue()) + self.assertIn("Compiling jit()", log_output.getvalue()) # Turn off all debug logging. with jax_debug_log_modules(""): @@ -163,7 +162,7 @@ def test_debug_logging(self): with capture_jax_logs() as log_output: jax.jit(lambda x: x + 1)(1) self.assertIn("Finished tracing + transforming", log_output.getvalue()) - self.assertNotIn("Compiling ", log_output.getvalue()) + self.assertNotIn("Compiling jit()", log_output.getvalue()) # Turn everything off again. with jax_debug_log_modules(""): @@ -282,8 +281,11 @@ def test_subprocess_cpp_logging_level(self): self.assertNotIn("Initializing CoordinationService", o.stderr) o = _run(program) - if int(_default_TF_CPP_MIN_LOG_LEVEL) >= 1: + default_cpp_log_level = os.environ.get("TF_CPP_MIN_LOG_LEVEL") + if default_cpp_log_level is not None and int(default_cpp_log_level) >= 1: self.assertNotIn("Initializing CoordinationService", o.stderr) + else: + self.assertIn("Initializing CoordinationService", o.stderr) if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/magma_linalg_test.py b/tests/magma_linalg_test.py index 37bc8833959e..3de44f039405 100644 --- a/tests/magma_linalg_test.py +++ b/tests/magma_linalg_test.py @@ -115,7 +115,7 @@ def testEigMagmaConfig(self): self.assertIn('magma = "on"', hlo) @jtu.sample_product( - shape=[(3, 4), (3, 3), (4, 3), (4, 3)], + shape=[(3, 4), (3, 3), (4, 3), (100, 100), (100, 10)], dtype=float_types + complex_types, ) @jtu.run_on_devices("gpu") diff --git a/tests/memories_test.py b/tests/memories_test.py index 0ca973c4d221..1f2edf3be7cf 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -23,19 +23,20 @@ import jax from jax import lax +from jax._src import core from jax._src import test_util as jtu from jax._src import xla_bridge as xb -from jax._src.layout import DeviceLocalLayout as DLL, Layout +from jax._src.layout import Layout as DLL, Format from jax._src import config -from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint import jax.numpy as jnp -from jax.ad_checkpoint import Offloadable, remat, Recompute +from jax.ad_checkpoint import checkpoint_name, Offloadable, Recompute from jax._src.sharding import common_devices_indices_map -from jax._src.sharding_impls import (NamedSharding, PositionalSharding, - SingleDeviceSharding, GSPMDSharding, - TransferToMemoryKind, PartitionSpec as P) +from jax._src.sharding_impls import ( + NamedSharding, SingleDeviceSharding, GSPMDSharding, PartitionSpec as P) +from jax._src.xla_metadata import set_xla_metadata from jax.experimental.compute_on import compute_on -from jax.experimental.shard_map import shard_map +from jax._src.compute_on import compute_on2 +from jax._src.shard_map import shard_map import numpy as np config.parse_flags_with_absl() @@ -59,14 +60,10 @@ class ShardingMemoriesTest(jtu.JaxTestCase): def setUp(self): super().setUp() - if jtu.test_device_matches(["cpu"]): - self._default_memory_kind = "unpinned_host" - else: - self._default_memory_kind = "device" + self._default_memory_kind = "device" @parameterized.named_parameters( ("named_sharding", "named_sharding"), - ("positional_sharding", "positional_sharding"), ("single_device_sharding", "single_device_sharding"), ("gspmd_sharding", "gspmd_sharding"), ) @@ -75,9 +72,6 @@ def test_canonicalize_memory_kind(self, name): mesh = jtu.create_mesh((1,), "x") ns = NamedSharding(mesh, P("x")) self.assertEqual(ns.memory_kind, self._default_memory_kind) - elif name == "positional_sharding": - ps = PositionalSharding(jax.devices()) - self.assertEqual(ps.memory_kind, self._default_memory_kind) elif name == "single_device_sharding": ss = SingleDeviceSharding(jax.devices()[0]) self.assertEqual(ss.memory_kind, self._default_memory_kind) @@ -88,7 +82,6 @@ def test_canonicalize_memory_kind(self, name): @parameterized.named_parameters( ("named_sharding", "named_sharding"), - ("positional_sharding", "positional_sharding"), ("single_device_sharding", "single_device_sharding"), ("gspmd_sharding", "gspmd_sharding"), ) @@ -99,11 +92,6 @@ def test_wrong_memory_kind(self, name): ): mesh = jtu.create_mesh((1,), ("x",)) NamedSharding(mesh, P("x"), memory_kind="hbm") - elif name == "positional_sharding": - with self.assertRaisesRegex( - ValueError, "Could not find memory addressable by device.*" - ): - PositionalSharding(jax.devices(), memory_kind="gpu_hbm") elif name == "single_device_sharding": with self.assertRaisesRegex( ValueError, @@ -120,7 +108,6 @@ def test_wrong_memory_kind(self, name): @parameterized.named_parameters( ("named_sharding", "named_sharding"), - ("positional_sharding", "positional_sharding"), ("single_device_sharding", "single_device_sharding"), ("gspmd_sharding", "gspmd_sharding"), ) @@ -131,8 +118,6 @@ def test_correct_tpu_memory_kind(self, name): if name == "named_sharding": mesh = jtu.create_mesh((1,), ("x",)) NamedSharding(mesh, P("x"), memory_kind=self._default_memory_kind) - elif name == "positional_sharding": - PositionalSharding(jax.devices(), memory_kind=self._default_memory_kind) elif name == "single_device_sharding": SingleDeviceSharding(jax.devices()[0], memory_kind="unpinned_host") else: @@ -141,7 +126,6 @@ def test_correct_tpu_memory_kind(self, name): @parameterized.named_parameters( ("named_sharding", "named_sharding"), - ("positional_sharding", "positional_sharding"), ("single_device_sharding", "single_device_sharding"), ("gspmd_sharding", "gspmd_sharding"), ) @@ -151,10 +135,6 @@ def test_sharding_eq(self, name): s1 = NamedSharding(mesh, P("x")) s2 = NamedSharding(mesh, P("x"), memory_kind=self._default_memory_kind) self.assertEqual(s1, s2) - elif name == "positional_sharding": - s1 = PositionalSharding(jax.devices()) - s2 = PositionalSharding(jax.devices(), memory_kind=self._default_memory_kind) - self.assertEqual(s1, s2) elif name == "single_device_sharding": s1 = SingleDeviceSharding(jax.devices()[0]) s2 = SingleDeviceSharding(jax.devices()[0], memory_kind=self._default_memory_kind) @@ -205,13 +185,6 @@ def _check_device_put_addressable_shards( self.assertArraysEqual(s.data, inp) self.assertEqual(s.data.sharding.memory_kind, expected_mem_kind) - def test_error_transfer_to_memory_kind_outside_jit(self): - with self.assertRaisesRegex( - ValueError, - "TransferToMemoryKind argument to jax.device_put can only be used" - " inside jax.jit"): - jax.device_put(np.arange(16), TransferToMemoryKind("device")) - @parameterized.parameters("unpinned_host", "pinned_host") def test_device_put_host_to_hbm(self, host_memory_kind: str): if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host": @@ -563,7 +536,7 @@ def test_parameter_streaming_inside_scan(self): @jax.jit def f(xs): def body(carry, x): - x_tpu = jax.device_put(x, TransferToMemoryKind("device")) + x_tpu = jax.device_put(x, jax.memory.Space.Device) return carry, x_tpu + carry return jax.lax.scan(body, 1.0, xs) @@ -574,6 +547,22 @@ def body(carry, x): out_s = NamedSharding(mesh, P(None, None, "z"), memory_kind="device") self.assertEqual(out_hbm.sharding, out_s) + def test_diff_mem_space_error(self): + mesh = jtu.create_mesh((2,), ("x",)) + np_inp = np.arange(16.0).reshape(8, 2) + arr_hbm = jax.device_put( + np_inp, NamedSharding(mesh, P("x"), memory_kind="device")) + arr_host = jax.device_put( + np_inp, NamedSharding(mesh, P("x"), memory_kind="pinned_host")) + + @jax.jit + def f(x, y): + return x + y + + with self.assertRaisesRegex( + ValueError, "memory_space of all inputs.*must be the same"): + f(arr_hbm, arr_host) + def test_output_streaming(self): mesh = jtu.create_mesh((1, 1), ("x", "y")) np_inp = np.arange(16.0).reshape(8, 2) @@ -598,7 +587,9 @@ def test_weight_offload_with_dp_on_output(self): @jax.jit def f(x): x = x * 2 + self.assertEqual(x.aval.memory_space, core.MemorySpace.Device) y = jax.device_put(x, s_host) + self.assertEqual(y.aval.memory_space, core.MemorySpace.Host) return y out_host = f(inp_dev) @@ -620,6 +611,7 @@ def body(carry, x): return carry, jax.device_put( out_tpu, NamedSharding(mesh, P("y", "z"), memory_kind="pinned_host")) _, res = jax.lax.scan(body, 1, xs) + self.assertEqual(res.aval.memory_space, core.MemorySpace.Host) return res out = f(arr_hbm) @@ -655,7 +647,7 @@ def f(): @jtu.run_on_devices('tpu') def test_ragged_copy_on_host(self): mesh = jtu.create_mesh((2,), ('x')) - sharding = jax.sharding.NamedSharding(mesh, P(('x'))) + sharding = jax.sharding.NamedSharding(mesh, P('x')) cpu_sharding = sharding.with_memory_kind('pinned_host') num_pages = 512 * 1024 @@ -665,13 +657,13 @@ def test_ragged_copy_on_host(self): def write(x): return x.at[16 * 1024:].set(0) - x = shard_map(write, mesh, P(('x'),), P(('x')))(x) + x = shard_map(write, mesh=mesh, in_specs=P(('x'),), out_specs=P('x'))(x) chunk_size = 8 def inner(state): idx, x, output = state chunk = jax.lax.dynamic_slice_in_dim(x, idx * chunk_size, chunk_size) - chunk_host = jax.device_put(chunk, TransferToMemoryKind('pinned_host')) + chunk_host = jax.device_put(chunk, jax.memory.Space.Host) output = jax.lax.dynamic_update_slice_in_dim( output, chunk_host, idx * chunk_size, axis=0) return (idx + 1, x, output) @@ -682,18 +674,16 @@ def cond(state): return (idx * chunk_size < x.shape[0]) & jnp.any(chunk > 0) def foo(x): - output = jnp.zeros_like(x, device=cpu_sharding) + output = jax.device_put(jnp.zeros_like(x), + jax.memory.Space.Host) _, _, cpu_x = jax.lax.while_loop(cond, inner, (0, x, output)) return cpu_x - fn = jax.jit(shard_map(foo, mesh, P(('x'),), P(('x')), - check_rep=False), + fn = jax.jit(shard_map(foo, mesh=mesh, in_specs=P(('x'),), + out_specs=P('x'), check_vma=False), out_shardings=cpu_sharding) y = fn(x) jax.block_until_ready(y) - compiled_text = fn.lower(x).compile().as_text() - if compiled_text is not None: - self.assertIn('custom_call_target="AllocateBuffer"', compiled_text) def test_disallow_alias_copies_arrays(self): mesh = jtu.create_mesh((2,), ("x",)) @@ -708,6 +698,19 @@ def test_disallow_alias_copies_arrays(self): jax.block_until_ready(inp_host_copy) + def test_device_put_memory_space(self): + mesh = jtu.create_mesh((2,), ("x",)) + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, NamedSharding(mesh, P("x"))) + + out = jax.device_put(arr, jax.memory.Space.Host) + self.assertEqual(out.sharding, + NamedSharding(mesh, P("x"), memory_kind='pinned_host')) + + out = jax.device_put(arr, jax.memory.Space.Device) + self.assertEqual(out.sharding, + NamedSharding(mesh, P("x"), memory_kind='device')) + def test_disallow_alias_copies_arrays_with_donated_input(self): mesh = jtu.create_mesh((2,), ("x",)) np_inp = np.arange(16).reshape(8, 2) @@ -723,12 +726,28 @@ def test_disallow_alias_copies_arrays_with_donated_input(self): jax.block_until_ready(inp_host_donate_copy) + def test_host_to_device_transfer(self): + orig = np.arange(8) + d = jax.device_put(orig, jax.memory.Space.Device) + self.assertTrue(d.committed) + + for _ in range(2): + h = jax.device_put(d, jax.memory.Space.Host) + self.assertTrue(h.committed) + self.assertEqual(h.sharding.memory_kind, 'pinned_host') + self.assertArraysEqual(h, orig) + + d = jax.device_put(h, jax.memory.Space.Device) + self.assertTrue(d.committed) + self.assertEqual(d.sharding.memory_kind, 'device') + self.assertArraysEqual(d, orig) + class ComputeOffload(jtu.BufferDonationTestCase): def setUp(self): - if not jtu.test_device_matches(["tpu"]): - self.skipTest("Memories do not work on CPU and GPU backends yet.") + if not jtu.test_device_matches(["tpu", "gpu"]): + self.skipTest("Memories do not work on CPU backends yet.") super().setUp() def _check_mem_kind(self, executable_kind, out_sharding, expected_kind): @@ -754,11 +773,6 @@ def init(): self.assertEqual(cpu_array.sharding, cpu_sharding) def test_compute_no_inputs_host_replicated(self): - if xb.backend_xla_version() is not None and xb.backend_xla_version() < 3: - self.skipTest("This test requires an xla_version >= 3.") - if config.use_shardy_partitioner.value: - self.skipTest("XLA failure due to b/370786664 and b/366411266. " - "Enable when fixed.") mesh = jtu.create_mesh((4,), ('data')) tpu_sharding = NamedSharding(mesh, P('data')) @@ -777,8 +791,8 @@ def init(): def test_compute_on_basic(self): out_s = SingleDeviceSharding(jax.devices()[0], memory_kind='pinned_host') - @compute_on('device_host') - @jax.jit + @compute_on2(compute_type='device_host', + out_memory_spaces=jax.memory.Space.Device) def g(x): return x * 2 @@ -794,6 +808,36 @@ def f(x): lowered_text = f.lower(jnp.arange(8)).as_text() self.assertIn('_xla_compute_type', lowered_text) + @functools.partial(jax.jit, out_shardings=out_s) + def h(x): + y = g(x) + return y * 3 + + out2 = h(inp) + self.assertArraysEqual(out2, inp * 6) + self.assertEqual(out2.sharding.memory_kind, "pinned_host") + + def test_compute_on_2d(self): + out_s = SingleDeviceSharding(jax.devices()[0], memory_kind="pinned_host") + + @compute_on("device_host") + @jax.jit + def g(x): + return x * 2 + + @jax.jit + def f(x): + y = g(x) + return y * 3 + + inp = jnp.arange(9943.0) + inp = jnp.reshape(inp, (61, 163)) + out = f(inp) + self.assertArraysEqual(out, inp * 6) + + lowered_text = f.lower(inp).as_text() + self.assertIn("_xla_compute_type", lowered_text) + @functools.partial(jax.jit, out_shardings=out_s) def h(x): y = g(x) @@ -809,18 +853,16 @@ def test_compute_on_host_shared_sharding(self): host_sharding = device_sharding.with_memory_kind("pinned_host") @compute_on("device_host") - @functools.partial( - jax.jit, - in_shardings=(host_sharding, device_sharding), - out_shardings=(host_sharding, device_sharding), - donate_argnums=(0, 1), - ) + @jax.jit def host_func(x, y): - return (x * y), ((x**2) * (y**2)) + y = jax.device_put(y, host_sharding) + out1 = x * y + out2 = (x ** 2) * (y ** 2) + return (jax.device_put(out1, host_sharding), + jax.device_put(out2, device_sharding)) @functools.partial( jax.jit, - in_shardings=(host_sharding, device_sharding), out_shardings=(host_sharding, device_sharding), donate_argnums=(0), ) @@ -828,10 +870,9 @@ def device_func(host_data, device_data): host_data, device_data = host_func(host_data, device_data) device_data = device_data * 2 host_data, device_data = host_func(host_data, device_data) - return (host_data, device_data) + return host_data, device_data - input_x = jnp.ones(8) - input_host = jax.device_put(input_x, host_sharding) + input_host = jax.device_put(jnp.ones(8), host_sharding) input_device = jnp.arange(8) input_device = jnp.where(input_device < 4, 0, 1) @@ -900,9 +941,6 @@ def h(x): self.assertEqual(out2.sharding.memory_kind, 'pinned_host') def test_compute_host_loop(self): - # TODO(apaszke): Remove after 12 weeks have passed. - if not jtu.if_cloud_tpu_at_least(2024, 12, 19): - self.skipTest("Requires libtpu built after 2024-12-19") @compute_on('device_host') @jax.jit def fn(): @@ -938,8 +976,8 @@ def f2(x): f2(jnp.arange(8)) def test_compute_on_grad(self): - @compute_on('device_host') - @jax.jit + @compute_on2(compute_type='device_host', + out_memory_spaces=jax.memory.Space.Device) def g(x): return jnp.sin(x) @@ -954,7 +992,7 @@ def f(x): lowered_text = jf.lower(inp).as_text('hlo') out = re.findall(r"call.*to_apply.*_xla_compute_type", lowered_text) - self.assertLen(out, 2) + self.assertLen(out, 1) def test_compute_on_remat(self): inp = jnp.arange(16.) @@ -962,15 +1000,15 @@ def test_compute_on_remat(self): def policy(prim, *avals, **params): return Recompute - @compute_on('device_host') - @jax.jit + @compute_on2(compute_type='device_host', + out_memory_spaces=jax.memory.Space.Device) def g(x): x = jnp.sin(x) x = jnp.sin(x) x = jnp.sin(x) return x - @functools.partial(remat, policy=policy) + @functools.partial(jax.remat, policy=policy) def f(x): x = g(x) return jnp.sum(x) @@ -981,7 +1019,7 @@ def f(x): lowered_text = jf.lower(inp).as_text('hlo') out = re.findall(r"call.*to_apply.*_xla_compute_type", lowered_text) - self.assertLen(out, 2) + self.assertLen(out, 1) def test_nested_no_op_compute(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) @@ -1014,8 +1052,8 @@ def test_sharded_compute_on_host(self): np_inp = np.arange(16).reshape(8, 2) arr = jax.device_put(np_inp, s) - @compute_on('device_host') - @jax.jit + @compute_on2(compute_type='device_host', + out_memory_spaces=jax.memory.Space.Device) def g(x, y): return x * y @@ -1043,20 +1081,20 @@ def eq(x, y): def f_fwd(x): y = x * 2 - z = jax.device_put(y, TransferToMemoryKind('pinned_host')) + z = jax.device_put(y, jax.memory.Space.Host) return y, (x, z) def f_bwd(res, tx): x, z = res y = x * 2 - z2 = jax.device_put(y, TransferToMemoryKind('pinned_host')) + z2 = jax.device_put(y, jax.memory.Space.Host) return (eq(z, z2),) f.defvjp(f_fwd, f_bwd) g = jax.jit(jax.grad(lambda x: f(x).sum())) x = jnp.ones(3) * 4 - all_true = jnp.ones(3) + all_true = jnp.ones(3, jnp.float32) self.assertArraysEqual(g(x), all_true) def test_host_offload_in_custom_vjp_sharded(self): @@ -1089,7 +1127,7 @@ def f_bwd(res, tx): g = jax.jit(jax.grad(lambda x: f(x).sum())) arr = jax.device_put(jnp.ones(4) * 4, s) - all_true = jnp.ones(4) + all_true = jnp.ones(4, dtype=jnp.float32) self.assertArraysEqual(g(arr), all_true) def test_scan_offload(self): @@ -1243,15 +1281,15 @@ def f(x): with self.assertRaisesRegex( ValueError, "Memory kinds passed to jax.jit does not match memory kind on the" - " respective arg. Got pjit memory kind: pinned_host, arg memory kind:" - " device for arg shape.*"): + " respective arg. Got jit memory kind: pinned_host, arg memory kind:" + " device for arg.*"): f(jnp.arange(16).reshape(8, 2)) # uncommitted inp also raises error with self.assertRaisesRegex( ValueError, "Memory kinds passed to jax.jit does not match memory kind on the" - " respective arg. Got pjit memory kind: pinned_host, arg memory kind:" - " device for arg shape.*"): + " respective arg. Got jit memory kind: pinned_host, arg memory kind:" + " device for arg.*"): f(inp) # committed inp raises error. @functools.partial(jax.jit, in_shardings=s.with_memory_kind("device")) @@ -1379,8 +1417,6 @@ def test_sharding_devices_indices_map_cache_hit(self): self.assertEqual(cache_info2.misses, cache_info1.misses) def test_no_donation_across_memory_kinds(self): - if xb.using_pjrt_c_api(): - raise unittest.SkipTest("GetOutputShardings not supported in PJRT C API") mesh = jtu.create_mesh((2, 1), ("x", "y")) np_inp = np.arange(16).reshape(8, 2) s_hbm = NamedSharding(mesh, P("x")) @@ -1438,6 +1474,10 @@ def f(x): def test_qr_decomposition_offload(self): if jtu.is_cloud_tpu(): self.skipTest("Test fails on cloud TPU") + if jtu.test_device_matches(["gpu"]): + # TODO(b/446898771) This test fails on GPU in OSS, it will work + # internally. + self.skipTest("Test doesn't work on GPU in OSS.") shape = (3, 3) dtype = np.float32 @@ -1457,7 +1497,8 @@ def f(x): out = f(operand) # doesn't crash lowered_text = f.lower(operand).as_text() self.assertIn('@lapack_sgeqrf', lowered_text) - self.assertIn('@Qr', lowered_text) + if jtu.test_device_matches(["tpu"]): + self.assertIn("@Qr", lowered_text) @jax.jit def h(x): @@ -1474,8 +1515,8 @@ def test_mem_kind_donation_pinned_host(self): s = NamedSharding(mesh, P(), memory_kind='pinned_host') s_dev = s.with_memory_kind('device') - @compute_on('device_host') @functools.partial(jax.jit, out_shardings=(s, s_dev), donate_argnums=(0, 1)) + @compute_on('device_host') def f(inp1, inp2): return inp1 * 2, inp2 * 2 @@ -1541,9 +1582,8 @@ def test_fn(x_in, y_in): self.assertArraysEqual(y_out, y1 + y1) def test_compute_offload_with_linear_layout(self): - # TODO(apaszke): Remove after 12 weeks have passed. - if not jtu.if_cloud_tpu_at_least(2024, 12, 19): - self.skipTest("Requires libtpu built after 2024-12-19") + if jtu.test_device_matches(["gpu"]): + self.skipTest("GPU does not support tiling.") sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0]) p_sharding = jax.sharding.SingleDeviceSharding( jax.devices()[0], memory_kind="pinned_host" @@ -1562,10 +1602,10 @@ def test_fn(x_in, y_in): x = jnp.reshape(x, (16, 64)) y = jnp.arange(0, 1024, dtype=jnp.float32) y = jnp.reshape(y, (16, 64)) - custom_dll = DLL(major_to_minor=(0, 1), _tiling=((8, 128),)) - custom_dll_linear = DLL(major_to_minor=(0, 1), _tiling=((1,),)) - x = jax.device_put(x, Layout(custom_dll, sharding)) - y = jax.device_put(y, Layout(custom_dll_linear, p_sharding)) + custom_dll = DLL(major_to_minor=(0, 1), tiling=((8, 128),)) + custom_dll_linear = DLL(major_to_minor=(0, 1), tiling=((1,),)) + x = jax.device_put(x, Format(custom_dll, sharding)) + y = jax.device_put(y, Format(custom_dll_linear, p_sharding)) x1 = jnp.arange(0, 1024, dtype=jnp.float32) x1 = jnp.reshape(x1, (16, 64)) @@ -1575,8 +1615,8 @@ def test_fn(x_in, y_in): jit_fn = jax.jit( test_fn, out_shardings=( - Layout(custom_dll, sharding), - Layout(custom_dll_linear, p_sharding), + Format(custom_dll, sharding), + Format(custom_dll_linear, p_sharding), ), ) x_out, y_out = jit_fn(x, y) @@ -1584,6 +1624,8 @@ def test_fn(x_in, y_in): self.assertArraysEqual(y_out, y1 + y1) def test_compute_offload_mesh_with_linear_layout(self): + if jtu.test_device_matches(["gpu"]): + self.skipTest("GPU does not support tiling.") mesh = jtu.create_mesh((2, 2), ("x", "y")) sharding = NamedSharding(mesh, P("x", "y")) p_sharding = NamedSharding(mesh, P("x", "y"), memory_kind="pinned_host") @@ -1601,10 +1643,10 @@ def test_fn(x_in, y_in): x = jnp.reshape(x, (32, 64)) y = jnp.arange(0, 2048, dtype=jnp.float32) y = jnp.reshape(y, (32, 64)) - custom_dll = DLL(major_to_minor=(0, 1), _tiling=((8, 128),)) - custom_dll_linear = DLL(major_to_minor=(0, 1), _tiling=((1,),)) - x = jax.device_put(x, Layout(custom_dll, sharding)) - y = jax.device_put(y, Layout(custom_dll_linear, p_sharding)) + custom_dll = DLL(major_to_minor=(0, 1), tiling=((8, 128),)) + custom_dll_linear = DLL(major_to_minor=(0, 1), tiling=((1,),)) + x = jax.device_put(x, Format(custom_dll, sharding)) + y = jax.device_put(y, Format(custom_dll_linear, p_sharding)) x1 = jnp.arange(0, 2048, dtype=jnp.float32) x1 = jnp.reshape(x1, (32, 64)) @@ -1614,14 +1656,27 @@ def test_fn(x_in, y_in): jit_fn = jax.jit( test_fn, out_shardings=( - Layout(custom_dll, sharding), - Layout(custom_dll_linear, p_sharding), + Format(custom_dll, sharding), + Format(custom_dll_linear, p_sharding), ), ) x_out, y_out = jit_fn(x, y) self.assertArraysEqual(x_out, x1 * x1) self.assertArraysEqual(y_out, y1 + y1) + def test_indexing_on_host(self): + @jax.jit + @compute_on("device_host") + def fn2(x): + x = jax.device_put(x, jax.memory.Space.Host) + y = jnp.ones((2, 1, 4)) + y = jax.device_put(y, jax.memory.Space.Host) + z = x.at[:, 1:2, :].set(y) + return z + + x_host = jax.device_put(jnp.ones((2,3,4)), jax.memory.Space.Host) + fn2(x_host) # doesn't crash + def test_compute_on_cache_miss(self): @jax.jit def f(x): @@ -1638,10 +1693,23 @@ def f(x): # 2 for `f` and `2` for `mul` (compute type changes for `mul`) self.assertEqual(count(), 4) + def test_compute_on_aot(self): + operand = np.float32(0.) + + @jax.jit + @compute_on("device_host") + def f_host(x): + # Adds 1 on CPU and adds 2 on other platforms + return jax.lax.platform_dependent(x, + cpu=lambda x: x + 1., + default=lambda x: x + 2.) + + self.assertAllClose(jnp.float32(1.0), f_host(operand)) + self.assertAllClose( + jnp.float32(1.0), f_host.lower(operand).compile()(operand) + ) + def test_offload_take_host(self): - # TODO(apaszke): Remove after 12 weeks have passed. - if not jtu.if_cloud_tpu_at_least(2024, 12, 19): - self.skipTest("Requires libtpu built after 2024-12-19") @compute_on('device_host') @jax.jit def peer_forward(x, experts, indices, scores): @@ -1658,12 +1726,232 @@ def peer_forward(x, experts, indices, scores): scores = jnp.ones((16, 4, 4, 2)) jax.jit(peer_forward)(x, experts, indices, scores) # doesn't crash + def test_int4_host_compute(self): + + @compute_on("device_host") + @jax.jit + def g(x): + return x + x + + @jax.jit + def f(x): + y = g(x) + return 2 * y + + inp = jnp.arange(4, dtype=jnp.uint4) + out = f(inp) + self.assertArraysEqual(out, 4 * inp) + + lowered_text = f.lower(inp).as_text() + self.assertIn("_xla_compute_type", lowered_text) + + def test_sparsecore_unsupported_gather(self): + if not ( + jax.devices()[0].device_kind == "TPU v5" + or jtu.is_device_tpu_at_least(6) + ): + self.skipTest("Does not have a sparsecore present") + + dnums = jax.lax.GatherDimensionNumbers( + offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1) + ) + slice_sizes = (1, 3) + + @compute_on("tpu_sparsecore") + @jax.jit + def f_sc(operand, indices): + return jax.lax.gather(operand, indices, dnums, slice_sizes) + + inputs = ( + np.linspace(0, 1, 10 * 5).reshape(10, 5), + np.array([[4, 2], [3, 2]]), + ) + + unsupported_gather = False + error_msg = None + try: + jax.jit(f_sc).lower(*inputs).compile() + except jax.errors.JaxRuntimeError as e: + unsupported_gather = True + error_msg = str(e) + self.assertTrue(unsupported_gather) + self.assertIn("UNIMPLEMENTED", error_msg) + + def test_sparsecore_supported_gather(self): + if not ( + jax.devices()[0].device_kind == "TPU v5" + or jtu.is_device_tpu_at_least(6) + ): + self.skipTest("Does not have a sparsecore present") + + dnums = jax.lax.GatherDimensionNumbers( + offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,) + ) + slice_sizes = (1, 128) + + @jax.jit + def f_tc(operand, indices): + return jax.lax.gather(operand, indices, dnums, slice_sizes) + + @compute_on("tpu_sparsecore") + @jax.jit + def f_sc(operand, indices): + return jax.lax.gather(operand, indices, dnums, slice_sizes) + + inputs = ( + np.linspace(0, 1, 122479 * 128).reshape(122479, 128), + np.random.randint(2, size=32768).reshape(32768, 1), + ) + + self.assertAllClose(f_tc(*inputs), f_sc(*inputs)) + + compiled_f_sc = jax.jit(f_sc).lower(*inputs).compile() + compiled_text = compiled_f_sc.as_text() + self.assertIn('async_execution_thread="sparsecore"', compiled_text) + + def test_sparsecore_unsupported_scatter(self): + if not ( + jax.devices()[0].device_kind == "TPU v5" + or jtu.is_device_tpu_at_least(6) + ): + self.skipTest("Does not have a sparsecore present") + + dnums = jax.lax.ScatterDimensionNumbers( + update_window_dims=(), + inserted_window_dims=(0,), + scatter_dims_to_operand_dims=(0,), + ) + + @compute_on("tpu_sparsecore") + @jax.jit + def f_sc(operand, indices, updates): + return jax.lax.scatter(operand, indices, updates, dnums) + + inputs = ( + np.linspace(0, 1, 15677312).reshape(15677312), + np.random.randint(15677312, size=524288).reshape(524288, 1), + np.linspace(0, 1, 524288).reshape(524288), + ) + + unsupported_scatter = False + error_msg = None + try: + jax.jit(f_sc).lower(*inputs).compile() + except jax.errors.JaxRuntimeError as e: + unsupported_scatter = True + error_msg = str(e) + self.assertTrue(unsupported_scatter) + self.assertIn("UNIMPLEMENTED", error_msg) + + def test_sparsecore_supported_scatter(self): + if not ( + jax.devices()[0].device_kind == "TPU v5" + or jtu.is_device_tpu_at_least(6) + ): + self.skipTest("Does not have a sparsecore present") + + dnums = jax.lax.ScatterDimensionNumbers( + update_window_dims=(), + inserted_window_dims=(0,), + scatter_dims_to_operand_dims=(0,), + ) + + @jax.jit + def f_tc(operand, indices, updates): + return jax.lax.scatter_add(operand, indices, updates, dnums) + + @compute_on("tpu_sparsecore") + @jax.jit + def f_sc(operand, indices, updates): + return jax.lax.scatter_add(operand, indices, updates, dnums) + + inputs = ( + np.linspace(0, 1, 15677312).reshape(15677312), + np.random.randint(15677312, size=524288).reshape(524288, 1), + np.linspace(0, 1, 524288).reshape(524288), + ) + + self.assertAllClose(f_tc(*inputs), f_sc(*inputs)) + + compiled_f_sc = jax.jit(f_sc).lower(*inputs).compile() + compiled_text = compiled_f_sc.as_text() + self.assertIn('async_execution_thread="sparsecore"', compiled_text) + class StreamAnnotationTest(jtu.JaxTestCase): + def test_stream_annotation_single_instruction(self): + # E2E test for fix https://github.com/openxla/xla/pull/24269 + if not jtu.test_device_matches(["gpu"]): + self.skipTest("Stream annotation is only supported on GPU.") + + mesh = jtu.create_mesh((2,), ('x',)) + s = NamedSharding(mesh, P('x')) + np_inp = np.ones((8,)) + arr1 = jax.device_put(np_inp, s) + arr2 = jax.device_put(np_inp, s) + + @compute_on('gpu_stream:1') + @jax.jit + def g(x, y): + return x + y + + @jax.jit + def f(x, y): + return g(x, y) + + compiled_f = jax.jit(f).lower(arr1, arr2).compile() + compiled_text = compiled_f.as_text() + self.assertIn('call-start', compiled_text) + self.assertIn('_xla_stream_annotation="1"', compiled_text) + self.assertIn('wrapped_add', compiled_text) + self.assertArraysEqual(compiled_f(arr1, arr2), arr1 * 2) + + def test_streamed_gemm_overlap(self): + if not jtu.test_device_matches(["gpu"]): + self.skipTest("Stream annotation is only supported on GPU.") + + mesh = jtu.create_mesh((2,), ('x',)) + s = NamedSharding(mesh, P('x')) + + @compute_on('gpu_stream:1') + @jax.jit + def g(x, y): + return x @ y + + @compute_on('gpu_stream:2') + @jax.jit + def h(x, y): + return x @ y + + @jax.jit + @functools.partial( + jax.shard_map, mesh=mesh, in_specs=(P('x'), P('x')), + out_specs=P('x')) + def f(x, y): + with set_xla_metadata(_scheduling_group_id="1"): + a = g(x, y) + b = h(y, x) + return a + b + + np_input = np.ones((1024, 512)) + + arr1 = jax.device_put(np_input, s) + arr2 = jax.device_put(np_input, s) + + compiled_f = jax.jit(f).lower(arr1, arr2).compile() + compiled_text = compiled_f.as_text() + self.assertIn('call-start', compiled_text) + self.assertIn('_xla_stream_annotation="1"', compiled_text) + self.assertIn('call-start.1', compiled_text) + self.assertIn('_xla_stream_annotation="2"', compiled_text) + self.assertIn('_scheduling_group_id="1"', compiled_text) + self.assertArraysEqual(compiled_f(arr1, arr2), arr1 * 1024) + def test_stream_annotation_inside_shmap(self): if not jtu.test_device_matches(["gpu"]): self.skipTest("Stream annotation is only supported on GPU.") + mesh = jtu.create_mesh((2,), ('x',)) s = NamedSharding(mesh, P('x')) np_inp = np.ones((8,)) @@ -1673,22 +1961,27 @@ def test_stream_annotation_inside_shmap(self): @compute_on('gpu_stream:1') @jax.jit def g(x, y): - return x * y + return x * y + x @compute_on('gpu_stream:2') @jax.jit def h(x, y): - return x * y + return x * y + x def f(x, y): z = g(x, y) w = h(3 * x, 2 * y) return z + w - out = jax.jit(shard_map(f, mesh=mesh, in_specs=(P('x'), P('x')), - out_specs=P('x')))(arr1, arr2) - self.assertArraysEqual(out, arr1 * 7) - + compiled_f = jax.jit( + shard_map(f, mesh=mesh, in_specs=(P('x'), P('x')), + out_specs=P('x'))).lower(arr1, arr2).compile() + compiled_text = compiled_f.as_text() + self.assertIn('call-start', compiled_text) + self.assertIn('_xla_stream_annotation="1"', compiled_text) + self.assertIn('call-start.1', compiled_f.as_text()) + self.assertIn('_xla_stream_annotation="2"', compiled_text) + self.assertArraysEqual(compiled_f(arr1, arr2), arr1 * 11) class ActivationOffloadingTest(jtu.JaxTestCase): @@ -1704,7 +1997,7 @@ def test_remat_jaxpr_offloadable(self): def policy(prim, *avals, **params): return Offloadable(src="device", dst="pinned_host") - @functools.partial(remat, policy=policy) + @functools.partial(jax.remat, policy=policy) def f(x): x = jnp.sin(x) x = jnp.sin(x) @@ -1714,13 +2007,11 @@ def f(x): fwd_jaxpr, bwd_jaxpr = jtu.fwd_bwd_jaxprs(f, inp) self.assertLen(fwd_jaxpr.out_avals, 4) # 1 output, 3 offloaded residuals - fwd_mem_kind_count = str(fwd_jaxpr).count( - "TransferToMemoryKind(memory_kind='pinned_host')") + fwd_mem_kind_count = str(fwd_jaxpr).count("MemorySpace.Host") self.assertEqual(fwd_mem_kind_count, 3) self.assertLen(bwd_jaxpr.in_avals, 4) # 3 offloaded residuals, 1 input - bwd_mem_kind_count = str(bwd_jaxpr).count( - "TransferToMemoryKind(memory_kind='device')") + bwd_mem_kind_count = str(bwd_jaxpr).count("MemorySpace.Device") self.assertEqual(bwd_mem_kind_count, 3) # Execution test. @@ -1757,7 +2048,7 @@ def test_remat_scan_jaxpr_offloadable(self): names_which_can_be_saved=["y"], names_which_can_be_offloaded=["z", "w"], offload_src='device', offload_dst='pinned_host') - @functools.partial(remat, policy=policy) + @functools.partial(jax.remat, policy=policy) def f(x): def g(ys, _): y, _ = ys @@ -1772,13 +2063,11 @@ def g(ys, _): fwd_jaxpr, bwd_jaxpr = jtu.fwd_bwd_jaxprs(f, inp) self.assertLen(fwd_jaxpr.out_avals, 5) # 2 output, 3 offloaded residuals - fwd_mem_kind_count = str(fwd_jaxpr).count( - "TransferToMemoryKind(memory_kind='pinned_host')") + fwd_mem_kind_count = str(fwd_jaxpr).count("MemorySpace.Host") self.assertEqual(fwd_mem_kind_count, 2) self.assertLen(bwd_jaxpr.in_avals, 5) # 3 offloaded residuals, 2 input - bwd_mem_kind_count = str(bwd_jaxpr).count( - "TransferToMemoryKind(memory_kind='device')") + bwd_mem_kind_count = str(bwd_jaxpr).count("MemorySpace.Device") self.assertEqual(bwd_mem_kind_count, 2) f = jax.jit(jax.grad(f)) @@ -1789,8 +2078,6 @@ def g(ys, _): compiled_text = compiled_f.as_text() if compiled_text is not None: self.assertIn('S(5)', compiled_text) - self.assertNotRegex(compiled_text, r"copy-start.*S\(5\)") - self.assertNotRegex(compiled_text, r"copy-done.*S\(5\)") compiled_stats = compiled_f.memory_analysis() if compiled_stats is not None: @@ -1807,7 +2094,7 @@ def test_remat_scan_layout_change_offloadable(self): names_which_can_be_saved=["y"], names_which_can_be_offloaded=["z", "w"], offload_src='device', offload_dst='pinned_host') - @functools.partial(remat, policy=policy) + @functools.partial(jax.remat, policy=policy) def f(x): def g(ys, _): y, _ = ys @@ -1828,8 +2115,6 @@ def g(ys, _): compiled_text = compiled_f.as_text() if compiled_text is not None: self.assertIn('S(5)', compiled_text) - self.assertNotRegex(compiled_text, r"copy-start.*S\(5\)") - self.assertNotRegex(compiled_text, r"copy-done.*S\(5\)") self.assertRegex(compiled_text, r"dynamic-update-slice-start.*S\(5\)") self.assertRegex(compiled_text, r"dynamic-update-slice-done.*S\(5\)") self.assertRegex(compiled_text, r"dynamic-slice-start.*S\(5\)") @@ -1843,7 +2128,7 @@ def test_remat_checkpoint_dots_with_no_batch_dims(self): policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims( "device", "pinned_host") - @functools.partial(new_checkpoint, policy=policy) + @functools.partial(jax.checkpoint, policy=policy) def f(x): x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST) x = jnp.sin(x) @@ -1881,7 +2166,7 @@ def policy(prim, *args, **kwargs): return Offloadable("device", "pinned_host") return Recompute - @functools.partial(remat, policy=policy) + @functools.partial(jax.remat, policy=policy) def test_fn(x): # Need any primitive with multiple outputs and a non-trivial grad. x1, _ = jax.lax.approx_max_k(x, k=2) diff --git a/tests/mesh_utils_test.py b/tests/mesh_utils_test.py index 136b507942e7..64566fd3a281 100644 --- a/tests/mesh_utils_test.py +++ b/tests/mesh_utils_test.py @@ -1,4 +1,4 @@ -# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# Copyright 2021 The JAX Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -650,6 +650,65 @@ def test_create_contiguous_submeshes_errors(self): mesh_shape, devices=devices, contiguous_submeshes=True ) + @parameterized.named_parameters( + # <-logical-> <-physical-> + ('1x1x2', [1, 1, 2], [1, 1, 1], [[[0, 1]]]), + ('2x1x4', [2, 1, 4], [2, 2, 1], [[[0, 1, 2, 3]], [[6, 7, 4, 5]]]), + ('4x1x2', [4, 1, 2], [2, 2, 1], [[[0, 1]], [[2, 3]], [[6, 7]], [[4, 5]]]), + ('4x2x2', [4, 2, 2], [2, 2, 2], [[[0, 1], [2, 3]], + [[6, 7], [4, 5]], + [[8, 9], [10, 11]], + [[14, 15], [12, 13]]]), + ('8x2', [2, 8], [2, 2, 2], [[0, 1, 2, 3, 6, 7, 4, 5], + [8, 9, 10, 11, 14, 15, 12, 13]]), + ('4x4x2', [4, 4, 2], [2, 2, 4], [[[0, 1], [2, 3], + [6, 7], [4, 5]], + [[8, 9], [10, 11], + [14, 15], [12, 13]], + [[16, 17], [18, 19], + [22, 23], [20, 21]], + [[24, 25], [26, 27], + [30, 31], [28, 29]]]), + ('4x2x4', [4, 2, 4], [2, 2, 4], [[[0, 1, 2, 3], [6, 7, 4, 5]], + [[8, 9, 10, 11], [14, 15, 12, 13]], + [[16, 17, 18, 19], [22, 23, 20, 21]], + [[24, 25, 26, 27], [30, 31, 28, 29]]]), + ('8x4', [8, 4], [2, 2, 4], [[0, 1, 2, 3], [6, 7, 4, 5], + [8, 9, 10, 11], [14, 15, 12, 13], + [16, 17, 18, 19], [22, 23, 20, 21], + [24, 25, 26, 27], [30, 31, 28, 29]]), + ('4x8', [4, 8], [2, 2, 4], [[0, 1, 2, 3, 6, 7, 4, 5], + [8, 9, 10, 11, 14, 15, 12, 13], + [16, 17, 18, 19, 22, 23, 20, 21], + [24, 25, 26, 27, 30, 31, 28, 29]]), + ) + def test_v7x_create_device_mesh( + self, logical_mesh_shape, physical_mesh_shape, expected_device_id_mesh + ): + global_devices = mock_tpu_devices( + physical_mesh_shape[0], + physical_mesh_shape[1], + physical_mesh_shape[2], + mesh_utils._TPU_7X, + one_device_per_chip=False, + ) + mesh = mesh_utils.create_device_mesh( + logical_mesh_shape, devices=global_devices, contiguous_submeshes=False + ) + device_id_mesh = np.vectorize(lambda d: d.id)(mesh) + self.assertAllClose(device_id_mesh, np.array(expected_device_id_mesh)) + + def test_v7x_create_device_mesh_fallback(self): + devices = mock_tpu_devices(2, 4, 4, mesh_utils._TPU_7X, + one_device_per_chip=False) + mesh = mesh_utils.create_device_mesh( + (1, 32), devices=devices[:32], contiguous_submeshes=False) + self.assertEqual(mesh.shape, (1, 32)) + + mesh = mesh_utils.create_device_mesh( + (1, 32), devices=devices[32:], contiguous_submeshes=False) + self.assertEqual(mesh.shape, (1, 32)) + def int64_array(x) -> np.ndarray: return np.array(x, dtype=np.int64) diff --git a/tests/metadata_test.py b/tests/metadata_test.py index 9fe6950773b9..524768aaed87 100644 --- a/tests/metadata_test.py +++ b/tests/metadata_test.py @@ -37,11 +37,11 @@ class MetadataTest(jtu.JaxTestCase): def test_jit_metadata(self): hlo = module_to_string(jax.jit(jnp.sin).lower(1.).compiler_ir()) - self.assertRegex(hlo, r'loc\("jit\(sin\)/jit\(main\)/sin"') + self.assertRegex(hlo, r'loc\("jit\(sin\)/sin"') def foo(x): return jnp.sin(x) hlo = module_to_string(jax.jit(foo).lower(1.).compiler_ir()) - self.assertRegex(hlo, r'loc\("jit\(foo\)/jit\(main\)/sin"') + self.assertRegex(hlo, r'loc\("jit\(foo\)/sin"') @unittest.skip("TODO") # TODO(jekbradbury) def test_nested_jit_metadata(self): @@ -70,8 +70,8 @@ def test_grad_jit_metadata(self): def foo(x): return jnp.sin(x) hlo = module_to_string(jax.jit(jax.grad(foo)).lower(1.).compiler_ir()) - self.assertRegex(hlo, r'loc\(".*jvp\(jit\(foo\)\)/cos"') - self.assertRegex(hlo, r'loc\(".*transpose\(jvp\(jit\(foo\)\)\)/mul"') + self.assertRegex(hlo, r'loc\(".*jvp\(jit\(foo\)\)"') + self.assertRegex(hlo, r'loc\(".*transpose\(jvp\(jit\(foo\)\)\)"') def test_cond_metadata(self): def true_fun(x): @@ -79,7 +79,7 @@ def true_fun(x): def false_fun(x): return jnp.cos(x) def f(which, x): - return jax.lax.cond(which, x, true_fun, x, false_fun) + return jax.lax.cond(which, true_fun, false_fun, x) hlo = module_to_string(jax.jit(f).lower(True, 1.).compiler_ir()) self.assertRegex(hlo, r'loc\(".*cond/branch_0_fun/cos"') self.assertRegex(hlo, r'loc\(".*cond/branch_1_fun/sin"') diff --git a/tests/mock_gpu_topology_test.py b/tests/mock_gpu_topology_test.py index 59c511ae61cf..8e409d6ed331 100644 --- a/tests/mock_gpu_topology_test.py +++ b/tests/mock_gpu_topology_test.py @@ -14,6 +14,7 @@ from absl.testing import absltest import jax +from jax._src import config from jax._src import test_util as jtu import jax.numpy as jnp from jax.sharding import NamedSharding @@ -49,13 +50,18 @@ def testMockWithSharding(self): f_lowered = f.lower(jnp.arange(16)) hlo = f_lowered.compiler_ir() + hlo_str = str(hlo) mocked_count = NUM_SLICES * NUM_HOSTS_PER_SLICE - self.assertIn(f'num_partitions = {mocked_count}', str(hlo)) - self.assertIn( - f'sharding = "{{devices=[{mocked_count}]<=[{mocked_count}]}}"', - str(hlo) - ) + self.assertIn(f'num_partitions = {mocked_count}', hlo_str) + + if config.use_shardy_partitioner.value: + expected_sharding = 'sharding = #sdy.sharding<@mesh, [{"x"}]>' + else: + expected_sharding = ( + f'sharding = "{{devices=[{mocked_count}]<=[{mocked_count}]}}"' + ) + self.assertIn(expected_sharding, hlo_str) if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/monitoring_test.py b/tests/monitoring_test.py index 52b53895c2cc..674a2d91e9d2 100644 --- a/tests/monitoring_test.py +++ b/tests/monitoring_test.py @@ -29,7 +29,7 @@ def tearDown(self): def test_record_event(self): events = [] - counters = {} # Map event names to frequency. + counters = {} # Map event names to frequency. def increment_event_counter(event): if event not in counters: counters[event] = 0 @@ -48,8 +48,9 @@ def increment_event_counter(event): "test_common_event": 2}) def test_record_event_durations(self): - durations = {} # Map event names to frequency. - def increment_event_duration(event, duration): + durations = {} # Map event names to frequency. + def increment_event_duration(event, duration, **kwargs): + del kwargs if event not in durations: durations[event] = 0. durations[event] += duration @@ -62,9 +63,33 @@ def increment_event_duration(event, duration): self.assertDictEqual(durations, {"test_short_event": 3, "test_long_event": 10}) + def test_record_scalar(self): + observed_keys = [] + observed_values = [] + + monitoring.register_scalar_listener( + lambda key, _, **kwargs: observed_keys.append(key), + ) + monitoring.register_scalar_listener( + lambda _, value, **kwargs: observed_values.append(value), + ) + + monitoring.record_scalar("test_unique_event", 1) + monitoring.record_scalar("test_common_event", 2.5) + monitoring.record_scalar("test_common_event", 5e5) + + self.assertListEqual( + observed_keys, + ["test_unique_event", "test_common_event", "test_common_event"], + ) + self.assertListEqual( + observed_values, + [1, 2.5, 5e5], + ) + def test_unregister_exist_callback_success(self): original_duration_listeners = jax_src_monitoring.get_event_duration_listeners() - callback = lambda event, durations: None + callback = lambda event, durations, **kwargs: None self.assertNotIn(callback, original_duration_listeners) monitoring.register_event_duration_secs_listener(callback) self.assertIn(callback, jax_src_monitoring.get_event_duration_listeners()) @@ -72,49 +97,22 @@ def test_unregister_exist_callback_success(self): self.assertNotEqual(original_duration_listeners, jax_src_monitoring.get_event_duration_listeners()) - jax_src_monitoring._unregister_event_duration_listener_by_callback(callback) + jax_src_monitoring.unregister_event_duration_listener(callback) self.assertEqual(original_duration_listeners, jax_src_monitoring.get_event_duration_listeners()) def test_unregister_not_exist_callback_fail(self): - callback = lambda event, durations: None + callback = lambda event, durations, **kwargs: None self.assertNotIn(callback, jax_src_monitoring.get_event_duration_listeners()) with self.assertRaises(AssertionError): - jax_src_monitoring._unregister_event_duration_listener_by_callback( - callback) - - def test_unregister_callback_index_in_range_success(self): - original_duration_listeners = jax_src_monitoring.get_event_duration_listeners() - callback = lambda event, durations: None - self.assertNotIn(callback, original_duration_listeners) - monitoring.register_event_duration_secs_listener(callback) - self.assertIn(callback, jax_src_monitoring.get_event_duration_listeners()) - # Verify that original listeners list is not modified by register function. - self.assertNotEqual(original_duration_listeners, - jax_src_monitoring.get_event_duration_listeners()) - - jax_src_monitoring._unregister_event_duration_listener_by_index(-1) - - self.assertEqual(original_duration_listeners, - jax_src_monitoring.get_event_duration_listeners()) - - def test_unregister_callback_index_out_of_range_fail(self): - size = len(jax_src_monitoring.get_event_duration_listeners()) - - # Verify index >= size raises AssertionError. - with self.assertRaises(AssertionError): - jax_src_monitoring._unregister_event_duration_listener_by_index(size) - - # Verify index < -size raises AssertionError. - with self.assertRaises(AssertionError): - jax_src_monitoring._unregister_event_duration_listener_by_index(-size - 1) + jax_src_monitoring.unregister_event_duration_listener(callback) def test_get_event_duration_listeners_returns_a_copy(self): original_duration_listeners = jax_src_monitoring.get_event_duration_listeners() - callback = lambda event, durations: None + callback = lambda event, durations, **kwargs: None original_duration_listeners.append(callback) @@ -132,7 +130,7 @@ def test_unregister_exist_event_callback_success(self): self.assertNotEqual(original_event_listeners, jax_src_monitoring.get_event_listeners()) - jax_src_monitoring._unregister_event_listener_by_callback(callback) + jax_src_monitoring.unregister_event_listener(callback) self.assertEqual(original_event_listeners, jax_src_monitoring.get_event_listeners()) @@ -142,7 +140,7 @@ def test_unregister_not_exist_event_callback_fail(self): self.assertNotIn(callback, jax_src_monitoring.get_event_listeners()) with self.assertRaises(AssertionError): - jax_src_monitoring._unregister_event_listener_by_callback(callback) + jax_src_monitoring.unregister_event_listener(callback) if __name__ == "__main__": absltest.main() diff --git a/tests/mosaic/BUILD b/tests/mosaic/BUILD index 71b2b7d80570..ad2f5afb068d 100644 --- a/tests/mosaic/BUILD +++ b/tests/mosaic/BUILD @@ -14,6 +14,7 @@ load( "//jaxlib:jax.bzl", + "if_cuda_is_configured", "jax_generate_backend_suites", "jax_multiplatform_test", "jax_py_test", @@ -27,102 +28,216 @@ package( default_visibility = ["//visibility:private"], ) +test_suite( + name = "gpu_mlir_deviceless_tests", + tags = ["gpu_mlir_deviceless_test"], +) + jax_generate_backend_suites() jax_multiplatform_test( name = "gpu_test", srcs = ["gpu_test.py"], + config_tags_overrides = { + # TODO(b/448760413): Re-enable once llvm fixed. + "gpu_b200": {"notap": True}, + }, enable_backends = [], enable_configs = [ "gpu_h100", - "gpu_h100x2", + "gpu_b200", ], env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"}, - shard_count = 16, + shard_count = 8, tags = [ - "multiaccelerator", "noasan", # Times out. ], + deps = if_cuda_is_configured([ + "//jax/experimental:mosaic_gpu", + "//jax/experimental:mosaic_gpu_test_util", + ]) + py_deps([ + "absl/testing", + "numpy", + "hypothesis", + ]), +) + +jax_multiplatform_test( + name = "gpu_torch_test", + srcs = ["gpu_torch_test.py"], + enable_backends = [], + enable_configs = [ + "gpu_h100", + "gpu_b200", + ], + env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"}, + tags = [ + "noasan", # ASAN is unsupported. + ], + deps = if_cuda_is_configured([ + "//jax/experimental:mosaic_gpu", + ]) + py_deps([ + "absl/testing", + "numpy", + "torch", + ]), +) + +jax_multiplatform_test( + name = "gpu_test_multidevice", + srcs = ["gpu_test_multidevice.py"], + enable_backends = [], + enable_configs = [ + "gpu_h100x2", + "gpu_b200x2", + ], + env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"}, + tags = ["multiaccelerator"], + deps = if_cuda_is_configured([ + "//jax/experimental:mosaic_gpu", + ]) + py_deps([ + "absl/testing", + "numpy", + ]), +) + +jax_multiplatform_test( + name = "gpu_test_distributed", + srcs = ["gpu_test_distributed.py"], + args = [ + "--num_processes=2", + "--gpus_per_process=1", + ], + enable_backends = [], + enable_configs = [ + "gpu_h100x2", + "gpu_b200x2", + ], + env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0 --xla_gpu_experimental_enable_nvshmem=true"}, + tags = ["multiaccelerator"], deps = [ - "//jax:mosaic_gpu", - ] + py_deps("absl/testing") + py_deps("numpy"), + "//jax/_src:test_multiprocess", + ] + if_cuda_is_configured([ + "//jax/experimental:mosaic_gpu", + ]) + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_py_test( name = "gpu_dialect_test", srcs = ["gpu_dialect_test.py"], + tags = ["gpu_mlir_deviceless_test"], deps = [ "//jax", - "//jax:mosaic_gpu", - "//jax:test_util", - ] + py_deps("absl/testing"), + "//jax/_src:test_util", + ] + if_cuda_is_configured([ + "//jax/experimental:mosaic_gpu", + ]) + py_deps("absl/testing"), ) jax_py_test( - name = "gpu_layout_inference_test", - srcs = ["gpu_layout_inference_test.py"], + name = "gpu_constraints_test", + srcs = ["gpu_constraints_test.py"], deps = [ "//jax", - "//jax:mosaic_gpu", - "//jax:test_util", - ] + py_deps("absl/testing"), + "//jax/_src:test_util", + ] + if_cuda_is_configured([ + "//jax/experimental:mosaic_gpu", + ]) + py_deps("absl/testing"), ) jax_py_test( - name = "gpu_transform_inference_test", - srcs = ["gpu_transform_inference_test.py"], + name = "gpu_layout_inference_test", + size = "small", + srcs = ["gpu_layout_inference_test.py"], + tags = ["gpu_mlir_deviceless_test"], deps = [ "//jax", - "//jax:mosaic_gpu", - "//jax:test_util", - ] + py_deps("absl/testing"), + "//jax/_src:test_util", + ] + if_cuda_is_configured([ + "//jax/experimental:mosaic_gpu", + "//jax/experimental:mosaic_gpu_test_util", + ]) + py_deps("absl/testing"), ) jax_multiplatform_test( name = "matmul_test", srcs = ["matmul_test.py"], enable_backends = [], - enable_configs = ["gpu_h100"], + enable_configs = [ + "gpu_h100", + "gpu_b200", + ], shard_count = 5, - deps = [ - "//jax:mosaic_gpu", + tags = [ + "noasan", + ], + deps = if_cuda_is_configured([ + "//jax/experimental:mosaic_gpu", "//jax/experimental/mosaic/gpu/examples:matmul", - ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"), + "//jax/experimental/mosaic/gpu/examples:matmul_blackwell", + ]) + py_deps([ + "absl/testing", + "numpy", + "hypothesis", + ]), ) jax_multiplatform_test( name = "flash_attention", srcs = ["//jax/experimental/mosaic/gpu/examples:flash_attention.py"], enable_backends = [], - enable_configs = ["gpu_h100"], + enable_configs = [ + "gpu_h100", + "gpu_b200", + ], main = "//jax/experimental/mosaic/gpu/examples:flash_attention.py", - tags = ["notap"], - deps = [ - "//jax:mosaic_gpu", - ] + py_deps("numpy"), + tags = [ + "manual", + "noasan", + "notap", + ], + deps = if_cuda_is_configured([ + "//jax/experimental:mosaic_gpu", + ]) + py_deps([ + "numpy", + "absl/testing", + ]), ) jax_multiplatform_test( name = "flash_attention_test", srcs = ["flash_attention_test.py"], enable_backends = [], - enable_configs = ["gpu_h100"], - deps = [ - "//jax:mosaic_gpu", + enable_configs = [ + "gpu_h100", + "gpu_b200", + ], + shard_count = 8, + tags = [ + "noasan", # Remove the tag once the CUPTI issue is fixed. + ], + deps = if_cuda_is_configured([ + "//jax/experimental:mosaic_gpu", "//jax/experimental/mosaic/gpu/examples:flash_attention", - ] + py_deps("absl/testing"), + ]) + py_deps("absl/testing"), ) jax_multiplatform_test( name = "profiler_cupti_test", srcs = ["profiler_cupti_test.py"], enable_backends = [], - enable_configs = ["gpu_h100"], + enable_configs = [ + "gpu_h100", + "gpu_b200", + ], tags = [ "noasan", # CUPTI leaks memory "nomsan", ], - deps = [ - "//jax:mosaic_gpu", - ] + py_deps("absl/testing"), + deps = if_cuda_is_configured([ + "//jax/experimental:mosaic_gpu", + ]) + py_deps("absl/testing"), ) diff --git a/tests/mosaic/gpu_constraints_test.py b/tests/mosaic/gpu_constraints_test.py new file mode 100644 index 000000000000..17b1403837db --- /dev/null +++ b/tests/mosaic/gpu_constraints_test.py @@ -0,0 +1,477 @@ +# Copyright 2025 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Mosaic GPU's `constraints` module.""" + +from absl.testing import parameterized +from jax._src import config +from jax._src import test_util as jtu +import jax.experimental.mosaic.gpu as mgpu +from jax.experimental.mosaic.gpu import constraints as cs +from jax.experimental.mosaic.gpu import fragmented_array as fa +from jax.experimental.mosaic.gpu import launch_context as lc +from jax.experimental.mosaic.gpu import tcgen05 + +config.parse_flags_with_absl() + +RL = cs.RegisterLayout +Eq = cs.Equals +V = cs.Variable + + +class ConstraintSystemTest(parameterized.TestCase): + + def test_constraint_system_is_unsatisfiable_if_assignments_are_incompatible( + self, + ): + v0 = V(0) + layout0, layout1 = [RL(mgpu.WGSplatFragLayout((1, i))) for i in (1, 2)] + system = cs.ConstraintSystem( + constraints=[Eq(v0, layout0), Eq(v0, layout1)], + ) + self.assertIsInstance(cs.reduce(system), cs.Unsatisfiable) + + def test_constraint_system_is_unsatisfiable_if_constraints_are_unsatisfiable( + self, + ): + v0 = V(0) + layout0, layout1 = [RL(mgpu.WGSplatFragLayout((1, i))) for i in (1, 2)] + system = cs.ConstraintSystem( + assignments={v0: layout0}, + constraints=[cs.Relayout(v0, layout1, 32)], + ) + self.assertIsInstance(cs.reduce(system), cs.Unsatisfiable) + + @parameterized.parameters(*cs._SUPPORTED_TILED_RELAYOUTS) + def test_reduce_constraint_system_removes_satisfed_relayouts(self, src, tgt): + system = cs.ConstraintSystem( + constraints=[cs.Relayout(RL(src), RL(tgt), 4)], + ) + self.assertEqual(cs.reduce(system), cs.ConstraintSystem()) + + def test_relayout_constraint_does_not_hold_for_incompatible_layouts(self): + self.assertFalse( + cs.Relayout( + RL(mgpu.WGMMA_ROW_LAYOUT), RL(mgpu.WGMMA_COL_LAYOUT), 32 + ).holds() + ) + + def test_not_of_type_constraint_holds_for_different_types(self): + layout = RL(mgpu.WGMMA_LAYOUT) + self.assertTrue(cs.NotOfType(layout, mgpu.WGSplatFragLayout).holds()) + + def test_not_of_type_constraint_does_not_holds_for_same_types(self): + layout = RL(mgpu.WGSplatFragLayout((1, 128))) + self.assertFalse(cs.NotOfType(layout, mgpu.WGSplatFragLayout).holds()) + + def test_not_of_type_constraint_is_unknown_for_unreduced_expression(self): + self.assertIsNone(cs.NotOfType(V(0), mgpu.WGSplatFragLayout).holds()) + + def test_reduce_constraint_system_removes_tautological_constraints_and_constraints( + self, + ): + v0, v1 = V(0), V(1) + system = cs.ConstraintSystem( + constraints=[ + Eq(v0, v1), + Eq(v0, v0), + cs.Relayout(v0, v0, 32), + cs.NotOfType(RL(mgpu.WGMMA_LAYOUT), mgpu.WGSplatFragLayout), + cs.NotOfType(v1, mgpu.WGSplatFragLayout), + ], + ) + self.assertLen(cs.reduce(system).constraints, 2) + + def test_reduce_constraint_system_of_simplified_system_is_noop(self): + v0, v1 = V(0), V(1) + system = cs.ConstraintSystem(constraints=[Eq(v0, v1)]) + self.assertEqual(cs.reduce(system), system) + + def test_reduce_constraint_system_assigns_variables_with_known_constraints( + self, + ): + v0, v1 = V(0), V(1) + layout = RL(mgpu.WGSplatFragLayout((1, 1))) + + with self.subTest("left-to-right-assignment"): + system = cs.ConstraintSystem( + constraints=[Eq(v0, layout), Eq(v0, v1)], + ) + self.assertEqual( + cs.reduce(system), + cs.ConstraintSystem(assignments={v0: layout, v1: layout}), + ) + + with self.subTest("right-to-left-assignment"): + system = cs.ConstraintSystem( + constraints=[Eq(v1, layout), Eq(v0, v1)], + ) + self.assertEqual( + cs.reduce(system), + cs.ConstraintSystem(assignments={v0: layout, v1: layout}), + ) + + def test_constraint_system_unknowns_are_all_the_variables_without_assignment( + self, + ): + v0, v1, v2, v3 = V(0), V(1), V(2), V(3) + layout = RL(mgpu.WGSplatFragLayout((1, 1))) + system = cs.ConstraintSystem( + assignments={v0: layout}, + constraints=[Eq(v1, v2), cs.Relayout(v2, v3, 32)], + ) + self.assertSequenceEqual(system.unknowns(), [v1, v2, v3]) + + def test_intersection_of_conflicting_systems_is_unsatisfiable(self): + v0 = V(0) + layout0, layout1 = [RL(mgpu.WGSplatFragLayout((1, i))) for i in (1, 2)] + system0 = cs.ConstraintSystem(assignments={v0: layout0}) + system1 = cs.ConstraintSystem(assignments={v0: layout1}) + self.assertIsInstance(system0 & system1, cs.Unsatisfiable) + + def test_intersection_of_compatible_systems_is_union_of_fields(self): + v0, v1, v2 = V(0), V(1), V(2) + layout0, layout1, layout2 = [ + RL(mgpu.WGSplatFragLayout((1, i))) for i in (1, 2, 3) + ] + system0 = cs.ConstraintSystem(constraints=[Eq(v0, layout0)]) + system1 = cs.ConstraintSystem( + assignments={v2: layout2}, + constraints=[Eq(v1, layout1)], + ) + system_intersection = system0 & system1 + self.assertEqual( + system_intersection, + cs.ConstraintSystem( + assignments={v2: layout2}, + constraints=[Eq(v0, layout0), Eq(v1, layout1)], + ), + ) + self.assertSequenceEqual(system0.unknowns(), [v0]) + self.assertSequenceEqual(system1.unknowns(), [v1]) + self.assertSequenceEqual(system_intersection.unknowns(), [v0, v1]) + + @parameterized.named_parameters( + ("reduce_to_row_layout", (1,), mgpu.WGMMA_ROW_LAYOUT), + ("reduce_to_col_layout", (0,), mgpu.WGMMA_COL_LAYOUT), + ) + def test_reduce_reduce_expression_reduces_layout(self, axes, expected_layout): + tiled_layout = RL(mgpu.WGMMA_LAYOUT) + self.assertEqual( + cs.reduce_expression(cs.Reduce(tiled_layout, axes=axes), {}), + RL(expected_layout), + ) + + def test_reduce_reduce_expression_with_unsupported_layout_is_irreducible(self): + layout = RL(mgpu.WGStridedFragLayout((128, 8), vec_size=8)) + expr = cs.Reduce(layout, axes=(0,)) + self.assertEqual(cs.reduce_expression(expr, {}), expr) + + def test_reduce_broadcast_of_splat_layout_is_reduced_to_splat_layout(self): + layout = RL(mgpu.WGSplatFragLayout((128,))) + valid_shape = (128, 8) + self.assertEqual( + cs.reduce_expression( + cs.BroadcastInDim(layout, axes=(0,), shape=valid_shape), {} + ), + RL(mgpu.WGSplatFragLayout((128, 8))), + ) + + def test_reduce_broadcast_of_splat_layout_is_unsatisfiable_for_incompatible_shape(self): + layout = RL(mgpu.WGSplatFragLayout((128,))) + invalid_shape = (129, 8) + self.assertIsInstance( + cs.reduce_expression( + cs.BroadcastInDim(layout, axes=(0,), shape=invalid_shape), + {}, + ), + cs.Unsatisfiable, + ) + + def test_reduce_broadcast_of_strided_layout_is_irreducible(self): + layout = RL(mgpu.WGStridedFragLayout((128,), vec_size=1)) + expr = cs.BroadcastInDim(layout, axes=(0,), shape=(128, 8)) + self.assertEqual(cs.reduce_expression(expr, {}), expr) + + def test_reduce_broadcast_of_tiled_layout_is_irreducible(self): + layout = RL(mgpu.WGMMA_LAYOUT) + expr = cs.BroadcastInDim(layout, axes=(1, 2), shape=(8, 128, 8)) + self.assertEqual(cs.reduce_expression(expr, {}), expr) + + def test_reduce_reshape_of_splat_layout_is_reduced_to_splat_layout(self): + layout = RL(mgpu.WGSplatFragLayout((1024,))) + source_shape, target_shape = (1024,), (128, 8) + self.assertEqual( + cs.reduce_expression( + cs.Reshape(layout, source_shape, target_shape), {} + ), + RL(mgpu.WGSplatFragLayout(target_shape)), + ) + + def test_reduce_reshape_of_strided_layout_is_reduced_to_strided_layout(self): + layout = RL(mgpu.WGStridedFragLayout((1024,), vec_size=8)) + source_shape, target_shape = (1024,), (128, 8) + self.assertEqual( + cs.reduce_expression( + cs.Reshape(layout, source_shape, target_shape), {} + ), + RL(mgpu.WGStridedFragLayout(target_shape, vec_size=8)), + ) + + def test_reduce_reshape_of_tiled_layout_with_indivisible_shape_is_irreducible(self): + layout = RL(mgpu.WGMMA_LAYOUT) + source_shape, target_shape = (128, 8), (129, 8) + eq = cs.Reshape(layout, source_shape, target_shape) + self.assertEqual(cs.reduce_expression(eq, {}), eq) + + def test_reduce_reshape_of_tiled_layout_with_modified_minor_tiled_dimensions_is_irreducible( + self, + ): + layout = RL(mgpu.WGMMA_LAYOUT) + source_shape, target_shape = (2, 128, 8), (2, 64, 16) + eq = cs.Reshape(layout, source_shape, target_shape) + self.assertEqual(cs.reduce_expression(eq, {}), eq) + + def test_reduce_reshape_of_tiled_layout_with_compatible_shape_is_identity( + self, + ): + layout = RL(mgpu.WGMMA_LAYOUT) + source_shape, target_shape = (2, 128, 8), (256, 8) + eq = cs.Reshape(layout, source_shape, target_shape) + self.assertEqual(cs.reduce_expression(eq, {}), layout) + + def test_relayout_of_non_splat_to_splat_is_unsatisfiable_shortcut( + self, + ): + splat_layout = RL(mgpu.WGSplatFragLayout((128,))) + v0, v1 = V(0), V(1) + system = cs.ConstraintSystem( + assignments={v1: splat_layout}, + constraints=[ + cs.NotOfType(v0, mgpu.WGSplatFragLayout), + cs.Relayout(v0, v1, 32), + ], + ) + self.assertIsInstance(cs.reduce(system), cs.Unsatisfiable) + + def test_saturate_distinct_from_splat_does_not_create_duplicate_constraints( + self, + ): + bw = 32 + v0, v1, v2 = V(0), V(1), V(2) + system = cs.ConstraintSystem( + constraints=[ + cs.NotOfType(v0, mgpu.WGSplatFragLayout), + cs.NotOfType(v1, mgpu.WGSplatFragLayout), + cs.Relayout(v0, v2, bw), + cs.Relayout(v1, v2, bw), + ], + ) + + self.assertEqual( + cs.saturate_distinct_from_splat(system), + cs.ConstraintSystem( + constraints=[ + cs.NotOfType(v0, mgpu.WGSplatFragLayout), + cs.NotOfType(v1, mgpu.WGSplatFragLayout), + cs.Relayout(v0, v2, bw), + cs.Relayout(v1, v2, bw), + cs.NotOfType(v2, mgpu.WGSplatFragLayout), + ], + ), + ) + + def test_saturate_distinct_from_splat_does_not_affect_non_splat( + self, + ): + bw = 32 + v0, v1, v2, v3, v4 = V(0), V(1), V(2), V(3), V(4) + system = cs.ConstraintSystem( + constraints=[ + cs.NotOfType(v0, mgpu.WGSplatFragLayout), + cs.NotOfType(v1, mgpu.WGStridedFragLayout), + cs.Relayout(v0, v2, bw), + cs.Relayout(v1, v3, bw), + cs.Relayout(v4, v0, bw), + ], + ) + + self.assertEqual( + cs.saturate_distinct_from_splat(system), + cs.ConstraintSystem( + constraints=[ + cs.NotOfType(v0, mgpu.WGSplatFragLayout), + cs.NotOfType(v1, mgpu.WGStridedFragLayout), + cs.Relayout(v0, v2, bw), + cs.Relayout(v1, v3, bw), + cs.Relayout(v4, v0, bw), + cs.NotOfType(v2, mgpu.WGSplatFragLayout), + ], + ), + ) + + @parameterized.parameters( + (mgpu.WGMMA_LAYOUT, (64, 64), True), + (mgpu.WGMMA_LAYOUT, (64,), False), + (mgpu.WGMMA_LAYOUT, None, False), + (mgpu.WGMMA_ROW_LAYOUT, None, True), + (mgpu.WGMMA_ROW_LAYOUT, (64,), False), + (mgpu.WGMMA_COL_LAYOUT, None, True), + (mgpu.WGMMA_COL_LAYOUT, (64,), False), + (mgpu.WGSplatFragLayout((16, 16)), None, True), + (mgpu.WGSplatFragLayout((16, 16)), (16,), False), + (mgpu.WGStridedFragLayout((16, 128), vec_size=4), None, True), + (mgpu.WGStridedFragLayout((16, 128), vec_size=4), (1,), False), + ) + def test_smem_is_transferable(self, layout, tiling, expected): + eq_layout = cs.RegisterLayout(layout) + eq_tiling = cs.SMEMTiling(lc.TileTransform(tiling) if tiling else None) + + reg_to_smem = cs.IsTransferable(eq_layout, eq_tiling, ()) + self.assertEqual(reg_to_smem.holds(), expected) + smem_to_reg = cs.IsTransferable(eq_tiling, eq_layout, ()) + self.assertEqual(smem_to_reg.holds(), expected) + + def test_transpose_expression(self): + def transpose(tiling): + transform = None if tiling is None else lc.TileTransform(tiling) + return cs.Transpose(cs.SMEMTiling(transform)) + + self.assertEqual( + cs.reduce_expression(transpose(None), {}), + cs.SMEMTiling(None), + ) + self.assertEqual( + cs.reduce_expression(transpose((2, 3)), {}), + cs.SMEMTiling(lc.TileTransform((3, 2))), + ) + + def test_divides_constraint_are_satisfied_by_empty_tiling(self): + self.assertTrue(cs.Divides(cs.SMEMTiling(None), (1, 2)).holds()) + + def test_divides_constraints_are_satisfied_by_divisor_tiling(self): + with self.subTest("SMEMTiling"): + tiling = cs.SMEMTiling(lc.TileTransform((2, 2))) + self.assertTrue(cs.Divides(tiling, (4, 6)).holds()) + with self.subTest("RegisterLayout"): + tiling = cs.RegisterLayout(fa.WGMMA_LAYOUT) + self.assertTrue(cs.Divides(tiling, (0, 64)).holds()) + with self.subTest("TMEMLayout"): + layout = tcgen05.tmem_default_layout(packing=1) + tiling = cs.TMEMLayout(layout) + self.assertTrue(cs.Divides(tiling, (0, 64)).holds()) + + def test_divides_constraints_are_not_satisfied_by_non_divisor_tiling(self): + with self.subTest("SMEMTiling"): + tiling = cs.SMEMTiling(lc.TileTransform((2, 2))) + self.assertFalse(cs.Divides(tiling, (4, 3)).holds()) + with self.subTest("RegisterLayout"): + tiling = cs.RegisterLayout(fa.WGMMA_LAYOUT) + self.assertFalse(cs.Divides(tiling, (3, 64)).holds()) + with self.subTest("TMEMLayout"): + layout = tcgen05.tmem_default_layout(packing=1) + tiling = cs.TMEMLayout(layout) + self.assertFalse(cs.Divides(tiling, (3, 64)).holds()) + + def test_reduce_merges_divides_constraints_on_same_variable(self): + v0, v1 = cs.Variable(0), cs.Variable(1) + constraints = [ + cs.Divides(v0, (18, 17)), + cs.Divides(v0, (3, 19)), + cs.Divides(v1, (6, 1, 3)), + ] + self.assertEqual( + cs.reduce(cs.ConstraintSystem(constraints=constraints)).constraints, + [ + cs.Divides(v0, (3, 1)), + cs.Divides(v1, (6, 1, 3)), + ], + ) + + # Check that merging constraints with different lenghts yields a constraint + # whose length matches the one of the shorter tiling_multiple. + constraints = [ + cs.Divides(v0, (16, 10)), + cs.Divides(v0, (8,)), + ] + self.assertEqual( + cs.reduce(cs.ConstraintSystem(constraints=constraints)).constraints, + [ + cs.Divides(v0, (2,)), + ], + ) + + def test_saturate_divides_constraints_for_equal_vars(self): + def equals(a, b): + return cs.Equals(cs.Variable(a), cs.Variable(b)) + def divides(var, dims): + return cs.Divides(cs.Variable(var), dims) + + # One equality + s = cs.ConstraintSystem( + constraints=[ + equals(0, 1), + divides(0, (1,)), + ], + ) + got = cs.saturate_divides_constraints_for_equal_vars(s) + want = [equals(0, 1), divides(0, (1,)), divides(1, (1,))] + self.assertEqual(got.constraints, want) + + # Five transitively equal variables and one disconnected one. + s = cs.ConstraintSystem( + constraints=[ + equals(0, 1), + equals(2, 3), + equals(2, 4), + equals(1, 4), + divides(0, (1,)), + divides(5, (1,)), + ], + ) + got = cs.saturate_divides_constraints_for_equal_vars(s) + want = [ + equals(0, 1), + equals(2, 3), + equals(2, 4), + equals(1, 4), + divides(0, (1,)), + divides(1, (1,)), + divides(2, (1,)), + divides(3, (1,)), + divides(4, (1,)), + divides(5, (1,)), + ] + self.assertEqual(got.constraints, want) + + @parameterized.parameters( + (fa.WGMMA_LAYOUT_UPCAST_4X, fa.WGMMA_LAYOUT_UPCAST_2X, 8), + (fa.WGMMA_LAYOUT_UPCAST_4X, fa.WGMMA_LAYOUT, 8), + (fa.WGMMA_LAYOUT_UPCAST_2X, fa.WGMMA_LAYOUT, 32), + ) + def test_forcing_relayout_on_unsupported_bitwidth_raises(self, src, dst, bitwidth): + self.assertFalse(cs.Relayout(cs.RegisterLayout(src), cs.RegisterLayout(dst), bitwidth).holds()) + + @parameterized.parameters( + (fa.WGMMA_LAYOUT_UPCAST_4X, fa.WGMMA_LAYOUT_UPCAST_2X, 4), + (fa.WGMMA_LAYOUT_UPCAST_4X, fa.WGMMA_LAYOUT, 4), + (fa.WGMMA_LAYOUT_UPCAST_2X, fa.WGMMA_LAYOUT, 16), + ) + def test_forcing_relayout_on_supported_bitwidth_succeeds(self, src, dst, bitwidth): + self.assertTrue(cs.Relayout(cs.RegisterLayout(src), cs.RegisterLayout(dst), bitwidth).holds()) + + + +if __name__ == "__main__": + parameterized.absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/mosaic/gpu_dialect_test.py b/tests/mosaic/gpu_dialect_test.py index ba9d23fa5b4f..7dbd9f85de27 100644 --- a/tests/mosaic/gpu_dialect_test.py +++ b/tests/mosaic/gpu_dialect_test.py @@ -14,7 +14,7 @@ # ============================================================================== """(Deviceless) tests for the Mosaic GPU MLIR dialect.""" -from typing import Callable +from collections.abc import Callable from absl.testing import parameterized import jax @@ -24,7 +24,7 @@ from jax._src.interpreters import mlir as mlir_interpreter from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith -from jax._src.lib.mlir.dialects import func +from jax._src.lib.mlir.dialects import builtin from jax._src.lib.mlir.dialects import gpu from jax._src.lib.mlir.dialects import llvm from jax._src.lib.mlir.dialects import memref @@ -32,7 +32,9 @@ from jax._src.lib.mlir.dialects import scf from jax._src.lib.mlir.dialects import vector from jax.experimental.mosaic import gpu as mgpu +from jax.experimental.mosaic.gpu import dialect_lowering as lowering from jax.experimental.mosaic.gpu import layouts +from jax.experimental.mosaic.gpu import tcgen05 from jax.experimental.mosaic.gpu import utils as mgpu_utils _cext = mgpu.dialect._cext if mgpu.dialect is not None else None @@ -82,6 +84,12 @@ def workgroup_ptr_ty() -> ir.Type: return ir.Type.parse(f"!llvm.ptr<{workgroup_nvptx_address_space}>") +def undefs(*tys: ir.Type) -> list[ir.Value]: + """Returns a list of undefined values of the given types.""" + # TODO(allanrenucci): Use `ub.poison` once Python bindings are available. + return [builtin.unrealized_conversion_cast([ty], []) for ty in tys] + + class MosaicGpuTest(parameterized.TestCase): def setUp(self): @@ -91,6 +99,13 @@ def setUp(self): self.enter_context(_make_ir_context()) self.enter_context(ir.Location.unknown()) self.module = ir.Module.create() + i32 = ir.IntegerType.get_signless(32) + self.module.operation.attributes["mosaic_gpu.arch_major"] = ( + ir.IntegerAttr.get(i32, 9) + ) + self.module.operation.attributes["mosaic_gpu.arch_minor"] = ( + ir.IntegerAttr.get(i32, 0) + ) class DialectTest(MosaicGpuTest): @@ -98,24 +113,12 @@ class DialectTest(MosaicGpuTest): def test_dialect_module_is_loaded(self): self.assertTrue(_cext.globals._check_dialect_module_loaded("mosaic_gpu")) - def test_initialize_barrier_op_result_memref_must_wrap_barriers(self): - with ir.InsertionPoint(self.module.body): - mgpu.dialect.initialize_barrier( - ir.MemRefType.get((1, 2), ir.F32Type.get()), - llvm.UndefOp(workgroup_ptr_ty()), - arrival_count=1, - ) - with self.assertRaisesRegex( - ir.MLIRError, "must be memref of barrier values" - ): - self.module.operation.verify() - def test_initialize_barrier_op_arrival_count_must_be_strictly_positive(self): with ir.InsertionPoint(self.module.body): mgpu.dialect.initialize_barrier( - ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")), llvm.UndefOp(workgroup_ptr_ty()), arrival_count=0, + num_barriers=2, ) with self.assertRaisesRegex(ir.MLIRError, "value is positive"): self.module.operation.verify() @@ -123,9 +126,9 @@ def test_initialize_barrier_op_arrival_count_must_be_strictly_positive(self): def test_initialize_barrier_op_with_a_non_shared_base_pointer_fails(self): with ir.InsertionPoint(self.module.body): mgpu.dialect.initialize_barrier( - ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")), llvm.UndefOp(ir.Type.parse(f"!llvm.ptr<{0}>")), arrival_count=1, + num_barriers=2, ) with self.assertRaisesRegex(ir.MLIRError, "pointer in address space 3"): self.module.operation.verify() @@ -133,63 +136,28 @@ def test_initialize_barrier_op_with_a_non_shared_base_pointer_fails(self): def test_initialize_barrier_op_with_a_positive_arrival_count_passes(self): with ir.InsertionPoint(self.module.body): mgpu.dialect.initialize_barrier( - ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")), llvm.UndefOp(workgroup_ptr_ty()), arrival_count=1, + num_barriers=2, ) self.assertTrue(self.module.operation.verify()) - self.assertIsInstance( - self.module.body.operations[1], mgpu.dialect.InitializeBarrierOp - ) - - def test_async_load_op_dest_must_be_contiguous(self): - with ir.InsertionPoint(self.module.body): - func.FuncOp.from_py_func( - ir.MemRefType.get([4, 8], ir.F32Type.get()), - ir.MemRefType.get( - [4, 8], - ir.F32Type.get(), - layout=ir.Attribute.parse("strided<[16, 1]>"), - ), - ir.MemRefType.get([], ir.Type.parse("!mosaic_gpu.barrier")), - ir.IntegerType.get_signless(32), - ir.IntegerType.get_signless(32), - name="async_load", - )( - lambda source, destination, barrier, *indices: mgpu.dialect.async_load( - source, - destination, - barrier, - indices, - slice_lengths=[4, 8], - collective=ir.ArrayAttr.get([]), - ) - ) - - with self.assertRaisesRegex( - ir.MLIRError, - "The `destination` memref must be contiguous", - ): - self.module.operation.verify() def test_async_load_op_source_and_dest_must_have_same_element_type(self): with ir.InsertionPoint(self.module.body): - func.FuncOp.from_py_func( + source, destination, barrier, *indices = undefs( ir.MemRefType.get([4, 8], ir.F32Type.get()), ir.MemRefType.get([4, 8], ir.F64Type.get()), ir.MemRefType.get([], ir.Type.parse("!mosaic_gpu.barrier")), ir.IntegerType.get_signless(32), ir.IntegerType.get_signless(32), - name="async_load", - )( - lambda source, destination, barrier, *indices: mgpu.dialect.async_load( - source, - destination, - barrier, - indices, - slice_lengths=[4, 8], - collective=ir.ArrayAttr.get([]), - ) + ) + mgpu.dialect.async_load( + source, + destination, + barrier, + indices, + slice_lengths=[4, 8], + collective=ir.ArrayAttr.get([]), ) with self.assertRaisesRegex( @@ -200,22 +168,20 @@ def test_async_load_op_source_and_dest_must_have_same_element_type(self): def test_async_load_op_slice_lengths_must_be_larger_than_minus_two(self): with ir.InsertionPoint(self.module.body): - func.FuncOp.from_py_func( + source, destination, barrier, *indices = undefs( ir.MemRefType.get([4, 8], ir.F32Type.get()), ir.MemRefType.get([4, 8], ir.F32Type.get()), ir.MemRefType.get([], ir.Type.parse("!mosaic_gpu.barrier")), ir.IntegerType.get_signless(32), ir.IntegerType.get_signless(32), - name="async_load", - )( - lambda source, destination, barrier, *indices: mgpu.dialect.async_load( - source, - destination, - barrier, - indices, - slice_lengths=[-2, 8], - collective=ir.ArrayAttr.get([]), - ) + ) + mgpu.dialect.async_load( + source, + destination, + barrier, + indices, + slice_lengths=[-2, 8], + collective=ir.ArrayAttr.get([]), ) with self.assertRaisesRegex( @@ -226,23 +192,21 @@ def test_async_load_op_slice_lengths_must_be_larger_than_minus_two(self): def test_async_load_op_source_and_dest_ranks_must_match_with_collapse(self): with ir.InsertionPoint(self.module.body): - func.FuncOp.from_py_func( + source, destination, barrier, *indices = undefs( ir.MemRefType.get([1, 4, 8], ir.F32Type.get()), ir.MemRefType.get([4], ir.F32Type.get()), ir.MemRefType.get([], ir.Type.parse("!mosaic_gpu.barrier")), ir.IntegerType.get_signless(32), ir.IntegerType.get_signless(32), ir.IntegerType.get_signless(32), - name="async_load", - )( - lambda source, destination, barrier, *indices: mgpu.dialect.async_load( - source, - destination, - barrier, - indices, - slice_lengths=[-1, 4, 8], - collective=ir.ArrayAttr.get([]), - ) + ) + mgpu.dialect.async_load( + source, + destination, + barrier, + indices, + slice_lengths=[-1, 4, 8], + collective=ir.ArrayAttr.get([]), ) with self.assertRaisesRegex( @@ -253,21 +217,19 @@ def test_async_load_op_source_and_dest_ranks_must_match_with_collapse(self): def test_async_load_op_indices_size_must_match_source_rank(self): with ir.InsertionPoint(self.module.body): - func.FuncOp.from_py_func( + source, destination, barrier, *indices = undefs( ir.MemRefType.get([4, 8], ir.F32Type.get()), ir.MemRefType.get([4, 8], ir.F32Type.get()), ir.MemRefType.get([], ir.Type.parse("!mosaic_gpu.barrier")), ir.IntegerType.get_signless(32), - name="async_load", - )( - lambda source, destination, barrier, *indices: mgpu.dialect.async_load( - source, - destination, - barrier, - indices, - slice_lengths=[4, 8], - collective=ir.ArrayAttr.get([]), - ) + ) + mgpu.dialect.async_load( + source, + destination, + barrier, + indices, + slice_lengths=[4, 8], + collective=ir.ArrayAttr.get([]), ) with self.assertRaisesRegex( @@ -278,21 +240,19 @@ def test_async_load_op_indices_size_must_match_source_rank(self): def test_async_load_op_slice_lengths_size_must_match_source_rank(self): with ir.InsertionPoint(self.module.body): - func.FuncOp.from_py_func( + source, destination, barrier, *indices = undefs( ir.MemRefType.get([4], ir.F32Type.get()), ir.MemRefType.get([4], ir.F32Type.get()), ir.MemRefType.get([], ir.Type.parse("!mosaic_gpu.barrier")), ir.IntegerType.get_signless(32), - name="async_load", - )( - lambda source, destination, barrier, *indices: mgpu.dialect.async_load( - source, - destination, - barrier, - indices, - slice_lengths=[4, 8], - collective=ir.ArrayAttr.get([]), - ) + ) + mgpu.dialect.async_load( + source, + destination, + barrier, + indices, + slice_lengths=[4, 8], + collective=ir.ArrayAttr.get([]), ) with self.assertRaisesRegex( @@ -304,24 +264,22 @@ def test_async_load_op_slice_lengths_size_must_match_source_rank(self): def test_async_load_op_slice_collective_must_be_unique(self): with ir.InsertionPoint(self.module.body): i32 = ir.IntegerType.get_signless(32) - func.FuncOp.from_py_func( + source, destination, barrier, *indices = undefs( ir.MemRefType.get([4], ir.F32Type.get()), ir.MemRefType.get([4], ir.F32Type.get()), ir.MemRefType.get([], ir.Type.parse("!mosaic_gpu.barrier")), i32, - name="async_load", - )( - lambda source, destination, barrier, *indices: mgpu.dialect.async_load( - source, - destination, - barrier, - indices, - slice_lengths=[4], - collective=ir.ArrayAttr.get([ - ir.IntegerAttr.get(i32, mgpu.dialect.Dimension.x), - ir.IntegerAttr.get(i32, mgpu.dialect.Dimension.x), - ]), - ) + ) + mgpu.dialect.async_load( + source, + destination, + barrier, + indices, + slice_lengths=[4], + collective=ir.ArrayAttr.get([ + ir.IntegerAttr.get(i32, mgpu.dialect.Dimension.x), + ir.IntegerAttr.get(i32, mgpu.dialect.Dimension.x), + ]), ) with self.assertRaisesRegex( @@ -330,48 +288,77 @@ def test_async_load_op_slice_collective_must_be_unique(self): ): self.module.operation.verify() - def test_async_store_op_source_must_be_contiguous(self): + def test_async_load_op_vector_indices_shape_must_match_slice_lengths(self): + # TODO(b/415721295): Remove when the minimum jaxlib version is 0.8.3. + if not hasattr(mgpu.dialect, "tma_gather_supported"): + self.skipTest("TMA gather support is required.") + with ir.InsertionPoint(self.module.body): - func.FuncOp.from_py_func( - ir.MemRefType.get( - [4, 8], - ir.F32Type.get(), - layout=ir.Attribute.parse("strided<[16, 1]>"), - ), + i32 = ir.IntegerType.get_signless(32) + source, destination, barrier, *indices = undefs( ir.MemRefType.get([4, 8], ir.F32Type.get()), - ir.IntegerType.get_signless(32), - ir.IntegerType.get_signless(32), - name="async_store", - )( - lambda source, destination, *indices: mgpu.dialect.async_store( - source, - destination, - indices, - slice_lengths=[4, 8], - ) + ir.MemRefType.get([4, 8], ir.F32Type.get()), + ir.MemRefType.get([], ir.Type.parse("!mosaic_gpu.barrier")), + i32, + ir.VectorType.get([4], i32), + ) + mgpu.dialect.async_load( + source, + destination, + barrier, + indices, + slice_lengths=[4, 8], + collective=ir.ArrayAttr.get([]), ) with self.assertRaisesRegex( ir.MLIRError, - "The `source` memref must be contiguous", + "The size of the vector index must be equal to the slice length", + ): + self.module.operation.verify() + + def test_async_load_op_only_one_vector_index_allowed(self): + # TODO(b/415721295): Remove when the minimum jaxlib version is 0.8.3. + if not hasattr(mgpu.dialect, "tma_gather_supported"): + self.skipTest("TMA gather support is required.") + + with ir.InsertionPoint(self.module.body): + i32 = ir.IntegerType.get_signless(32) + source, destination, barrier, *indices = undefs( + ir.MemRefType.get([4, 8], ir.F32Type.get()), + ir.MemRefType.get([4, 8], ir.F32Type.get()), + ir.MemRefType.get([], ir.Type.parse("!mosaic_gpu.barrier")), + ir.VectorType.get([4], i32), + ir.VectorType.get([4], i32), + ) + mgpu.dialect.async_load( + source, + destination, + barrier, + indices, + slice_lengths=[4, 4], + collective=ir.ArrayAttr.get([]), + ) + + with self.assertRaisesRegex( + ir.MLIRError, + "Only one index may be a vector.*dimensions 0 and 1.", ): self.module.operation.verify() def test_async_store_op_source_and_dest_must_have_same_element_type(self): with ir.InsertionPoint(self.module.body): - func.FuncOp.from_py_func( + source, destination, *indices = undefs( ir.MemRefType.get([4, 8], ir.F32Type.get()), ir.MemRefType.get([4, 8], ir.F64Type.get()), ir.IntegerType.get_signless(32), ir.IntegerType.get_signless(32), - name="async_store", - )( - lambda source, destination, *indices: mgpu.dialect.async_store( - source, - destination, - indices, - slice_lengths=[4, 8], - ) + ) + mgpu.dialect.async_store( + source, + destination, + indices, + slice_lengths=[4, 8], ) with self.assertRaisesRegex( @@ -382,19 +369,17 @@ def test_async_store_op_source_and_dest_must_have_same_element_type(self): def test_async_store_op_slice_lengths_must_be_larger_than_minus_two(self): with ir.InsertionPoint(self.module.body): - func.FuncOp.from_py_func( + source, destination, *indices = undefs( ir.MemRefType.get([4, 8], ir.F32Type.get()), ir.MemRefType.get([4, 8], ir.F32Type.get()), ir.IntegerType.get_signless(32), ir.IntegerType.get_signless(32), - name="async_store", - )( - lambda source, destination, *indices: mgpu.dialect.async_store( - source, - destination, - indices, - slice_lengths=[-2, 8], - ) + ) + mgpu.dialect.async_store( + source, + destination, + indices, + slice_lengths=[-2, 8], ) with self.assertRaisesRegex( @@ -405,20 +390,18 @@ def test_async_store_op_slice_lengths_must_be_larger_than_minus_two(self): def test_async_store_op_source_and_dest_ranks_must_match_with_collapse(self): with ir.InsertionPoint(self.module.body): - func.FuncOp.from_py_func( + source, destination, *indices = undefs( ir.MemRefType.get([4], ir.F32Type.get()), ir.MemRefType.get([1, 4, 8], ir.F32Type.get()), ir.IntegerType.get_signless(32), ir.IntegerType.get_signless(32), ir.IntegerType.get_signless(32), - name="async_store", - )( - lambda source, destination, *indices: mgpu.dialect.async_store( - source, - destination, - indices, - slice_lengths=[-1, 4, 8], - ) + ) + mgpu.dialect.async_store( + source, + destination, + indices, + slice_lengths=[-1, 4, 8], ) with self.assertRaisesRegex( @@ -429,18 +412,16 @@ def test_async_store_op_source_and_dest_ranks_must_match_with_collapse(self): def test_async_store_op_indices_size_must_match_destination_rank(self): with ir.InsertionPoint(self.module.body): - func.FuncOp.from_py_func( + source, destination, *indices = undefs( ir.MemRefType.get([4, 8], ir.F32Type.get()), ir.MemRefType.get([4, 8], ir.F32Type.get()), ir.IntegerType.get_signless(32), - name="async_store", - )( - lambda source, destination, *indices: mgpu.dialect.async_store( - source, - destination, - indices, - slice_lengths=[4, 8], - ) + ) + mgpu.dialect.async_store( + source, + destination, + indices, + slice_lengths=[4, 8], ) with self.assertRaisesRegex( @@ -451,18 +432,16 @@ def test_async_store_op_indices_size_must_match_destination_rank(self): def test_async_store_op_slice_lengths_size_must_match_source_rank(self): with ir.InsertionPoint(self.module.body): - func.FuncOp.from_py_func( + source, destination, *indices = undefs( ir.MemRefType.get([4], ir.F32Type.get()), ir.MemRefType.get([4], ir.F32Type.get()), ir.IntegerType.get_signless(32), - name="async_store", - )( - lambda source, destination, *indices: mgpu.dialect.async_store( - source, - destination, - indices, - slice_lengths=[4, 8], - ) + ) + mgpu.dialect.async_store( + source, + destination, + indices, + slice_lengths=[4, 8], ) with self.assertRaisesRegex( @@ -474,12 +453,12 @@ def test_async_store_op_slice_lengths_size_must_match_source_rank(self): def test_wgmma_types_match(self): with ir.InsertionPoint(self.module.body): - func.FuncOp.from_py_func( - ir.VectorType.get([128, 160], ir.BF16Type.get()), + acc, a, b = undefs( + ir.VectorType.get([128, 160], ir.F32Type.get()), ir.MemRefType.get([128, 128], ir.F16Type.get()), ir.MemRefType.get([128, 160], ir.BF16Type.get()), - name="wgmma", - )(mgpu.dialect.wgmma) + ) + mgpu.dialect.wgmma(acc, a, b) with self.assertRaisesRegex( ir.MLIRError, @@ -487,74 +466,108 @@ def test_wgmma_types_match(self): ): self.module.operation.verify() - def test_wgmma_a_rank_is_2(self): + def test_wgmma_acc_m_dim_not_multiple_of_64(self): with ir.InsertionPoint(self.module.body): - func.FuncOp.from_py_func( - ir.VectorType.get([128, 160], ir.BF16Type.get()), - ir.MemRefType.get([3, 128, 128], ir.BF16Type.get()), + acc, a, b = undefs( + ir.VectorType.get([127, 160], ir.F32Type.get()), + ir.MemRefType.get([127, 128], ir.BF16Type.get()), ir.MemRefType.get([128, 160], ir.BF16Type.get()), - name="wgmma", - )(mgpu.dialect.wgmma) + ) + mgpu.dialect.wgmma(acc, a, b) with self.assertRaisesRegex( ir.MLIRError, - "The `a` input must have rank 2.", + r"accumulator.*must be a multiple of 64", ): self.module.operation.verify() - def test_wgmma_b_rank_is_2(self): + def test_wgmma_acc_m_not_equal_to_a_m_dim(self): with ir.InsertionPoint(self.module.body): - func.FuncOp.from_py_func( - ir.VectorType.get([128, 160], ir.BF16Type.get()), - ir.MemRefType.get([128, 128], ir.BF16Type.get()), - ir.MemRefType.get([2, 128, 160], ir.BF16Type.get()), - name="wgmma", - )(mgpu.dialect.wgmma) + acc, a, b = undefs( + ir.VectorType.get([256, 160], ir.F32Type.get()), + ir.MemRefType.get([512, 128], ir.BF16Type.get()), + ir.MemRefType.get([128, 160], ir.BF16Type.get()), + ) + mgpu.dialect.wgmma(acc, a, b) with self.assertRaisesRegex( ir.MLIRError, - "The `b` input must have rank 2.", + r"accumulator's first dimension 256 must be equal to.*`a`", ): self.module.operation.verify() - def test_wgmma_acc_rank_is_2(self): + def test_wgmma_a_k_dim_not_equal_to_b_k_dim(self): with ir.InsertionPoint(self.module.body): - func.FuncOp.from_py_func( - ir.VectorType.get([2, 128, 160], ir.BF16Type.get()), + acc, a, b = undefs( + ir.VectorType.get([128, 160], ir.F32Type.get()), ir.MemRefType.get([128, 128], ir.BF16Type.get()), - ir.MemRefType.get([128, 160], ir.BF16Type.get()), - name="wgmma", - )(mgpu.dialect.wgmma) + ir.MemRefType.get([160, 160], ir.BF16Type.get()), + ) + mgpu.dialect.wgmma(acc, a, b) with self.assertRaisesRegex( ir.MLIRError, - "The accumulator must have rank 2.", + "`a`'s contracting dimension 128 must be equal to the first dimension" + " of `b`", ): self.module.operation.verify() - def test_wgmma_acc_m_dim_not_multiple_of_64(self): + def test_wgmma_b_n_dim_not_equal_to_acc_n_dim(self): with ir.InsertionPoint(self.module.body): - func.FuncOp.from_py_func( - ir.VectorType.get([127, 160], ir.BF16Type.get()), + acc, a, b = undefs( + ir.VectorType.get([128, 160], ir.F32Type.get()), ir.MemRefType.get([128, 128], ir.BF16Type.get()), - ir.MemRefType.get([128, 160], ir.BF16Type.get()), - name="wgmma", - )(mgpu.dialect.wgmma) + ir.MemRefType.get([128, 192], ir.BF16Type.get()), + ) + mgpu.dialect.wgmma(acc, a, b) with self.assertRaisesRegex( ir.MLIRError, - r"accumulator.*must be a multiple of 64", + r"`b`'s non-contracting dimension 192 must be equal to the", ): self.module.operation.verify() - def test_wgmma_acc_m_not_equal_to_a_m_dim(self): + def test_tcgen05_mma_types_match(self): with ir.InsertionPoint(self.module.body): - func.FuncOp.from_py_func( - ir.VectorType.get([256, 160], ir.BF16Type.get()), - ir.MemRefType.get([512, 128], ir.BF16Type.get()), + acc, a, b, accumulate = undefs( + ir.MemRefType.get([128, 160], ir.F16Type.get()), + ir.MemRefType.get([128, 128], ir.F16Type.get()), ir.MemRefType.get([128, 160], ir.BF16Type.get()), - name="wgmma", - )(mgpu.dialect.wgmma) + ir.IntegerType.get_signless(1), + ) + mgpu.dialect.tcgen05_mma(acc, a, b, accumulate) + + with self.assertRaisesRegex( + ir.MLIRError, + "The `a` and `b` inputs must have the same element type.", + ): + self.module.operation.verify() + + def test_tcgen05_mma_acc_m_dim_not_multiple_of_128(self): + with ir.InsertionPoint(self.module.body): + acc, a, b, accumulate = undefs( + ir.MemRefType.get([127, 160], ir.F16Type.get()), + ir.MemRefType.get([127, 128], ir.F16Type.get()), + ir.MemRefType.get([128, 160], ir.F16Type.get()), + ir.IntegerType.get_signless(1), + ) + mgpu.dialect.tcgen05_mma(acc, a, b, accumulate) + + with self.assertRaisesRegex( + ir.MLIRError, + r"accumulator.*must be a multiple of 32", + ): + self.module.operation.verify() + + def test_tcgen05_mma_acc_m_not_equal_to_a_m_dim(self): + with ir.InsertionPoint(self.module.body): + acc, a, b, accumulate = undefs( + ir.MemRefType.get([256, 160], ir.F16Type.get()), + ir.MemRefType.get([512, 128], ir.F16Type.get()), + ir.MemRefType.get([128, 160], ir.F16Type.get()), + ir.IntegerType.get_signless(1), + ) + mgpu.dialect.tcgen05_mma(acc, a, b, accumulate) with self.assertRaisesRegex( ir.MLIRError, @@ -562,29 +575,32 @@ def test_wgmma_acc_m_not_equal_to_a_m_dim(self): ): self.module.operation.verify() - def test_wgmma_a_k_dim_not_equal_to_b_k_dim(self): + def test_tcgen05_mma_a_k_dim_not_equal_to_b_k_dim(self): with ir.InsertionPoint(self.module.body): - func.FuncOp.from_py_func( - ir.VectorType.get([128, 160], ir.BF16Type.get()), - ir.MemRefType.get([128, 128], ir.BF16Type.get()), - ir.MemRefType.get([160, 160], ir.BF16Type.get()), - name="wgmma", - )(mgpu.dialect.wgmma) + acc, a, b, accumulate = undefs( + ir.MemRefType.get([128, 160], ir.F16Type.get()), + ir.MemRefType.get([128, 128], ir.F16Type.get()), + ir.MemRefType.get([160, 160], ir.F16Type.get()), + ir.IntegerType.get_signless(1), + ) + mgpu.dialect.tcgen05_mma(acc, a, b, accumulate) with self.assertRaisesRegex( ir.MLIRError, - r"`a`'s contracting dimension 128 must be equal to one of.*`b`", + "`a`'s contracting dimension 128 must be equal to the first dimension" + " of `b`", ): self.module.operation.verify() - def test_wgmma_b_n_dim_not_equal_to_acc_n_dim(self): + def test_tcgen05_mma_b_n_dim_not_equal_to_acc_n_dim(self): with ir.InsertionPoint(self.module.body): - func.FuncOp.from_py_func( - ir.VectorType.get([128, 160], ir.BF16Type.get()), - ir.MemRefType.get([128, 128], ir.BF16Type.get()), - ir.MemRefType.get([128, 192], ir.BF16Type.get()), - name="wgmma", - )(mgpu.dialect.wgmma) + acc, a, b, accumulate = undefs( + ir.MemRefType.get([128, 160], ir.F16Type.get()), + ir.MemRefType.get([128, 128], ir.F16Type.get()), + ir.MemRefType.get([128, 192], ir.F16Type.get()), + ir.IntegerType.get_signless(1), + ) + mgpu.dialect.tcgen05_mma(acc, a, b, accumulate) with self.assertRaisesRegex( ir.MLIRError, @@ -592,15 +608,573 @@ def test_wgmma_b_n_dim_not_equal_to_acc_n_dim(self): ): self.module.operation.verify() + def test_tcgen05_mma_b_n_dim_not_equal_to_half_acc_n_dim(self): + with ir.InsertionPoint(self.module.body): + acc, a, b, accumulate = undefs( + ir.MemRefType.get([128, 160], ir.F16Type.get()), + ir.MemRefType.get([128, 128], ir.F16Type.get()), + ir.MemRefType.get([128, 160], ir.F16Type.get()), + ir.IntegerType.get_signless(1), + ) + mgpu.dialect.tcgen05_mma(acc, a, b, accumulate, collective=True) + + with self.assertRaisesRegex( + ir.MLIRError, + r"`b`'s non-contracting dimension 160 must be half", + ): + self.module.operation.verify() + + def test_tcgen05_mma_acc_mem_space_is_tmem(self): + smem = mgpu_utils.smem() + with ir.InsertionPoint(self.module.body): + acc, a, b, accumulate = undefs( + ir.MemRefType.get([128, 160], ir.F16Type.get(), memory_space=smem), + ir.MemRefType.get([128, 128], ir.F16Type.get()), + ir.MemRefType.get([128, 160], ir.F16Type.get()), + ir.IntegerType.get_signless(1), + ) + mgpu.dialect.tcgen05_mma(acc, a, b, accumulate) + + with self.assertRaisesRegex( + ir.MLIRError, + r"The accumulator must be in TMEM", + ): + self.module.operation.verify() + + def test_tcgen05_mma_a_mem_space_is_smem_or_tmem(self): + tmem = mgpu_utils.tmem() + with ir.InsertionPoint(self.module.body): + acc, a, b, accumulate = undefs( + ir.MemRefType.get([128, 160], ir.F16Type.get(), memory_space=tmem), + ir.MemRefType.get([128, 128], ir.F16Type.get()), + ir.MemRefType.get([128, 160], ir.F16Type.get()), + ir.IntegerType.get_signless(1), + ) + mgpu.dialect.tcgen05_mma(acc, a, b, accumulate) + + with self.assertRaisesRegex( + ir.MLIRError, + r"The `a` input must be in TMEM or SMEM", + ): + self.module.operation.verify() + + def test_tcgen05_mma_b_mem_space_is_smem(self): + smem, tmem = mgpu_utils.smem(), mgpu_utils.tmem() + with ir.InsertionPoint(self.module.body): + acc, a, b, accumulate = undefs( + ir.MemRefType.get([128, 160], ir.F16Type.get(), memory_space=tmem), + ir.MemRefType.get([128, 128], ir.F16Type.get(), memory_space=smem), + ir.MemRefType.get([128, 160], ir.F16Type.get(), memory_space=tmem), + ir.IntegerType.get_signless(1), + ) + mgpu.dialect.tcgen05_mma(acc, a, b, accumulate) + + with self.assertRaisesRegex( + ir.MLIRError, + r"The `b` input must be in SMEM", + ): + self.module.operation.verify() + + def test_tcgen05_mma_scale_arg_missing(self): + smem, tmem = mgpu_utils.smem(), mgpu_utils.tmem() + f8e0m0 = ir.Float8E8M0FNUType.get() + with ir.InsertionPoint(self.module.body): + acc, a, b, accumulate, a_scale = undefs( + ir.MemRefType.get([128, 160], ir.F16Type.get(), memory_space=tmem), + ir.MemRefType.get([128, 128], ir.F16Type.get(), memory_space=smem), + ir.MemRefType.get([128, 160], ir.F16Type.get(), memory_space=smem), + ir.IntegerType.get_signless(1), + ir.MemRefType.get([128, 4], f8e0m0, memory_space=tmem), + ) + mgpu.dialect.tcgen05_mma(acc, a, b, accumulate, a_scale=a_scale) + + with self.assertRaisesRegex( + ir.MLIRError, + r"Either none or both scales should be provided.", + ): + self.module.operation.verify() + + def test_tcgen05_mma_a_scale_mem_space_is_tmem(self): + smem, tmem = mgpu_utils.smem(), mgpu_utils.tmem() + f8e0m0 = ir.Float8E8M0FNUType.get() + with ir.InsertionPoint(self.module.body): + acc, a, b, accumulate, a_scale, b_scale = undefs( + ir.MemRefType.get([128, 160], ir.F16Type.get(), memory_space=tmem), + ir.MemRefType.get([128, 128], ir.F16Type.get(), memory_space=smem), + ir.MemRefType.get([128, 160], ir.F16Type.get(), memory_space=smem), + ir.IntegerType.get_signless(1), + ir.MemRefType.get([128, 4], f8e0m0, memory_space=smem), + ir.MemRefType.get([160, 4], f8e0m0, memory_space=tmem), + ) + mgpu.dialect.tcgen05_mma( + acc, a, b, accumulate, a_scale=a_scale, b_scale=b_scale + ) + + with self.assertRaisesRegex( +ir.MLIRError, + r"The `a_scale` input must be in TMEM", + ): + self.module.operation.verify() + + def test_tcgen05_mma_b_scale_mem_space_is_tmem(self): + smem, tmem = mgpu_utils.smem(), mgpu_utils.tmem() + f8e0m0 = ir.Float8E8M0FNUType.get() + with ir.InsertionPoint(self.module.body): + acc, a, b, accumulate, a_scale, b_scale = undefs( + ir.MemRefType.get([128, 160], ir.F16Type.get(), memory_space=tmem), + ir.MemRefType.get([128, 128], ir.F16Type.get(), memory_space=smem), + ir.MemRefType.get([128, 160], ir.F16Type.get(), memory_space=smem), + ir.IntegerType.get_signless(1), + ir.MemRefType.get([128, 4], f8e0m0, memory_space=tmem), + ir.MemRefType.get([160, 4], f8e0m0, memory_space=smem), + ) + mgpu.dialect.tcgen05_mma( + acc, a, b, accumulate, a_scale=a_scale, b_scale=b_scale + ) + + with self.assertRaisesRegex( + ir.MLIRError, + r"The `b_scale` input must be in TMEM", + ): + self.module.operation.verify() + + def test_tiled_layout_attr_parsing(self): + with ir.InsertionPoint(self.module.body): + for layout in ( + mgpu.WGMMA_LAYOUT, + mgpu.WGMMA_ROW_LAYOUT, + mgpu.WGMMA_COL_LAYOUT, + mgpu.WGMMA_TRANSPOSED_LAYOUT, + ): + attr = layouts.to_tiled_layout_attr(layout) + parsed_layout = layouts.from_tiled_layout_attr(attr) + self.assertEqual(layout, parsed_layout) + + def test_broadcast_in_dim_ok(self): + with ir.InsertionPoint(self.module.body): + (operand,) = undefs(ir.VectorType.get([64], ir.F32Type.get())) + mgpu.dialect.broadcast_in_dim( + ir.VectorType.get([64, 64], ir.F32Type.get()), + operand, + broadcast_dimensions=[0], + ) + + self.assertTrue(self.module.operation.verify()) + + def test_broadcast_in_dim_no_0d(self): + with ir.InsertionPoint(self.module.body): + (operand,) = undefs(ir.VectorType.get([], ir.F32Type.get())) + mgpu.dialect.broadcast_in_dim( + ir.VectorType.get([64], ir.F32Type.get()), + operand, + broadcast_dimensions=[], + ) + + with self.assertRaisesRegex( + ir.MLIRError, + r"The input vector must have rank > 0", + ): + self.module.operation.verify() + + def test_broadcast_in_dim_no_input_larger_than_output(self): + with ir.InsertionPoint(self.module.body): + (operand,) = undefs(ir.VectorType.get([64, 64], ir.F32Type.get())) + mgpu.dialect.broadcast_in_dim( + ir.VectorType.get([64], ir.F32Type.get()), + operand, + broadcast_dimensions=[], + ) + + with self.assertRaisesRegex( + ir.MLIRError, + r"rank of the input vector must be smaller", + ): + self.module.operation.verify() + + def test_broadcast_in_dim_too_many_dims(self): + with ir.InsertionPoint(self.module.body): + (operand,) = undefs(ir.VectorType.get([64], ir.F32Type.get())) + mgpu.dialect.broadcast_in_dim( + ir.VectorType.get([64, 64], ir.F32Type.get()), + operand, + broadcast_dimensions=[0, 1], + ) + + with self.assertRaisesRegex( + ir.MLIRError, + r"size of the `broadcast_dimensions` attribute must be", + ): + self.module.operation.verify() + + def test_broadcast_in_dim_dim_oob(self): + with ir.InsertionPoint(self.module.body): + (operand,) = undefs(ir.VectorType.get([64], ir.F32Type.get())) + mgpu.dialect.broadcast_in_dim( + ir.VectorType.get([64, 64], ir.F32Type.get()), + operand, + broadcast_dimensions=[2], + ) + + with self.assertRaisesRegex( + ir.MLIRError, + r"must be in the range \[0, result.shape.rank", + ): + self.module.operation.verify() + + def test_broadcast_in_dim_dim_transpose(self): + with ir.InsertionPoint(self.module.body): + (operand,) = undefs(ir.VectorType.get([64, 64, 64, 64], ir.F32Type.get())) + mgpu.dialect.broadcast_in_dim( + ir.VectorType.get([64, 64, 64, 64], ir.F32Type.get()), + operand, + broadcast_dimensions=[0, 1, 3, 2], + ) + + with self.assertRaisesRegex( + ir.MLIRError, + r"`broadcast_dimensions` attribute must be strictly increasing", + ): + self.module.operation.verify() + + def test_custom_primitive_op_args_must_match_args_of_terminator(self): + with ir.InsertionPoint(self.module.body): + shape = (128,) + elt_ty = ir.F32Type.get() + ty = ir.VectorType.get(shape, elt_ty) + strided_layout = mgpu.WGStridedFragLayout.from_shaped_type(ty) + assert strided_layout is not None + out_layouts = ir.ArrayAttr.get([layouts.to_layout_attr(strided_layout)]) + + op = mgpu.dialect.CustomPrimitiveOp( + result=[ty], + operands_=[], + in_layouts=[], + in_transforms=[], + out_layouts=out_layouts, + ) + block = op.body.blocks.append() + with ir.InsertionPoint(block): + v = llvm.mlir_undef(ir.VectorType.get([256], ir.F32Type.get())) + mgpu.dialect.ReturnOp(operands_=[v]) + + with self.assertRaisesRegex( + ir.MLIRError, + r"type of return operand 0 \('vector<256xf32>'\) doesn't match the" + r" result type \('vector<128xf32>'\) in custom_primitive", + ): + self.module.operation.verify() + + def test_custom_primitive_op_results_must_be_scalar_or_vector(self): + with ir.InsertionPoint(self.module.body): + ref_ty = ir.MemRefType.get((128, 128), ir.F32Type.get()) + op = mgpu.dialect.CustomPrimitiveOp( + result=[ref_ty], + operands_=[], + in_layouts=[], + in_transforms=[], + out_layouts=[], + ) + block = op.body.blocks.append() + with ir.InsertionPoint(block): + [ref] = undefs(ref_ty) + mgpu.dialect.ReturnOp(operands_=[ref]) + + with self.assertRaisesRegex( + ir.MLIRError, + r"Custom primitive can only return scalars or vectors.", + ): + self.module.operation.verify() + + def test_tmem_alloc_op_must_have_smem_ref_input(self): + with ir.InsertionPoint(self.module.body): + (smem_ptr,) = undefs( + ir.MemRefType.get([], ir.IntegerType.get_signless(32)) + ) + mgpu.dialect.tmem_alloc( + result=ir.MemRefType.get( + [128, 32], + ir.BF16Type.get(), + memory_space=mgpu_utils.tmem(), + ), + smem_ptr=smem_ptr, + collective=False, + packing=1, + ) + + with self.assertRaisesRegex( + ir.MLIRError, + "The `smem_ptr` memref must have the Workgroup address space", + ): + self.module.operation.verify() + + def test_tmem_alloc_op_result_must_have_tmem_memory_space(self): + with ir.InsertionPoint(self.module.body): + (smem_ptr,) = undefs( + ir.MemRefType.get( + [], + ir.IntegerType.get_signless(32), + memory_space=mgpu_utils.smem(), + ) + ) + mgpu.dialect.tmem_alloc( + result=ir.MemRefType.get( + [128, 32], + ir.BF16Type.get(), + ), + smem_ptr=smem_ptr, + collective=False, + packing=1, + ) + + with self.assertRaisesRegex( + ir.MLIRError, + "The tmem memref must have a mosaic_gpu.tmem memory space", + ): + self.module.operation.verify() + + def test_tmem_alloc_op_exact_column_count_must_be_at_most_512(self): + with ir.InsertionPoint(self.module.body): + (smem_ptr,) = undefs( + ir.MemRefType.get( + [], + ir.IntegerType.get_signless(32), + memory_space=mgpu_utils.smem(), + ) + ) + mgpu.dialect.tmem_alloc( + result=ir.MemRefType.get( + [128, 1024], + ir.BF16Type.get(), + memory_space=mgpu_utils.tmem(), + ), + smem_ptr=smem_ptr, + collective=False, + packing=1, + ) + + with self.assertRaisesRegex( + ir.MLIRError, + "The number of allocated columns must be less than or equal to 512 but" + " got: 1024", + ): + self.module.operation.verify() + + def test_tmem_alloc_op_bad_packing(self): + with ir.InsertionPoint(self.module.body): + (smem_ptr,) = undefs( + ir.MemRefType.get( + [], + ir.IntegerType.get_signless(32), + memory_space=mgpu_utils.smem(), + ) + ) + mgpu.dialect.tmem_alloc( + result=ir.MemRefType.get( + [128, 128], + ir.BF16Type.get(), + memory_space=mgpu_utils.tmem(), + ), + smem_ptr=smem_ptr, + collective=False, + packing=4, + ) + + with self.assertRaisesRegex( + ir.MLIRError, + "Only unpacked, or fully packed allocations are supported.", + ): + self.module.operation.verify() + + def test_tmem_alloc_op_exact_false_column_count_15_ok(self): + with ir.InsertionPoint(self.module.body): + (smem_ptr,) = undefs( + ir.MemRefType.get( + [], + ir.IntegerType.get_signless(32), + memory_space=mgpu_utils.smem(), + ) + ) + mgpu.dialect.tmem_alloc( + result=ir.MemRefType.get( + [128, 15], + ir.BF16Type.get(), + memory_space=mgpu_utils.tmem(), + ), + smem_ptr=smem_ptr, + collective=False, + packing=1, + ) + + self.assertTrue(self.module.operation.verify()) + + def test_tmem_alloc_op_exact_false_column_count_100_ok(self): + with ir.InsertionPoint(self.module.body): + (smem_ptr,) = undefs( + ir.MemRefType.get( + [], + ir.IntegerType.get_signless(32), + memory_space=mgpu_utils.smem(), + ) + ) + mgpu.dialect.tmem_alloc( + result=ir.MemRefType.get( + [128, 100], + ir.BF16Type.get(), + memory_space=mgpu_utils.tmem(), + ), + smem_ptr=smem_ptr, + collective=False, + packing=1, + ) + + self.assertTrue(self.module.operation.verify()) + + def test_tmem_alloc_op_exact_false_column_count_777_packed_not_ok(self): + with ir.InsertionPoint(self.module.body): + (smem_ptr,) = undefs( + ir.MemRefType.get( + [], + ir.IntegerType.get_signless(32), + memory_space=mgpu_utils.smem(), + ) + ) + mgpu.dialect.tmem_alloc( + result=ir.MemRefType.get( + [128, 777], + ir.BF16Type.get(), + memory_space=mgpu_utils.tmem(), + ), + smem_ptr=smem_ptr, + collective=False, + packing=2, + ) + + with self.assertRaisesRegex( + ir.MLIRError, + "The number of unpacked columns must be divisible by the packing", + ): + self.module.operation.verify() + + def test_tmem_alloc_op_exact_false_column_count_778_packed_ok(self): + with ir.InsertionPoint(self.module.body): + (smem_ptr,) = undefs( + ir.MemRefType.get( + [], + ir.IntegerType.get_signless(32), + memory_space=mgpu_utils.smem(), + ) + ) + mgpu.dialect.tmem_alloc( + result=ir.MemRefType.get( + [128, 778], + ir.BF16Type.get(), + memory_space=mgpu_utils.tmem(), + ), + smem_ptr=smem_ptr, + collective=False, + packing=2, + ) + + self.assertTrue(self.module.operation.verify()) + + def test_tmem_alloc_dealloc_packed_large_shape_ok(self): + with ir.InsertionPoint(self.module.body): + ref_ty = ir.MemRefType.get( + [128, 1024], + ir.BF16Type.get(), + memory_space=mgpu_utils.tmem(), + ) + (smem_ptr,) = undefs( + ir.MemRefType.get( + [], + ir.IntegerType.get_signless(32), + memory_space=mgpu_utils.smem(), + ) + ) + # This allocation would exceed the 512 columns limit if it were not packed. + ref = mgpu.dialect.tmem_alloc( + result=ref_ty, + smem_ptr=smem_ptr, + collective=False, + packing=2, + ) + mgpu.dialect.tmem_dealloc(ref) + self.assertTrue(self.module.operation.verify()) + + def test_tmem_layout_cast_invalid_tmem_ref(self): + with ir.InsertionPoint(self.module.body): + (tmem_ref,) = undefs( + ir.MemRefType.get( + [128, 128], + ir.BF16Type.get(), + memory_space=mgpu_utils.smem(), + ) + ) + mgpu.dialect.tmem_layout_cast( + tmem_ref, layouts.to_layout_attr(tcgen05.TMEM_NATIVE_LAYOUT) + ) + with self.assertRaisesRegex( + ir.MLIRError, + "The tmem memref must have a mosaic_gpu.tmem memory space", + ): + self.module.operation.verify() + + def test_vector_store_op_src_dst_shape_mismatch(self): + with ir.InsertionPoint(self.module.body): + src_ty = ir.VectorType.get((8,), ir.BF16Type.get()) + dst_ty = ir.MemRefType.get((4,), ir.BF16Type.get()) + (src, dst) = undefs(src_ty, dst_ty) + mgpu.dialect.vector_store(src, dst) + with self.assertRaisesRegex( + ir.MLIRError, + "The source and destination must have the same shape", + ): + self.module.operation.verify() + + def test_vector_store_op_src_dst_dtype_mismatch(self): + with ir.InsertionPoint(self.module.body): + src_ty = ir.VectorType.get((8,), ir.BF16Type.get()) + dst_ty = ir.MemRefType.get((8,), ir.F32Type.get()) + (src, dst) = undefs(src_ty, dst_ty) + mgpu.dialect.vector_store(src, dst) + with self.assertRaisesRegex( + ir.MLIRError, + "The source and destination must have the same element type", + ): + self.module.operation.verify() + + def test_broadcasted_iota_op_invalid_dimension(self): + with ir.InsertionPoint(self.module.body): + ty = ir.VectorType.get((2,), ir.F32Type.get()) + mgpu.dialect.broadcasted_iota(ty, dimension=2) + with self.assertRaisesRegex( + ir.MLIRError, + "dimension=2 must be smaller than the rank=1 of the result.", + ): + self.module.operation.verify() + + def test_print_layout_op_invalid_ref(self): + with ir.InsertionPoint(self.module.body): + ref_ty = ir.MemRefType.get( + (2,), ir.F32Type.get(), memory_space=mgpu_utils.smem() + ) + (ref,) = undefs(ref_ty) + mgpu.dialect.print_layout("tmem: {}", ref) + with self.assertRaisesRegex( + ir.MLIRError, + "The tmem memref must have a mosaic_gpu.tmem memory space", + ): + self.module.operation.verify() + class DialectLoweringTest(MosaicGpuTest): def test_lowering_removes_mosaic_gpu_ops(self): with ir.InsertionPoint(self.module.body): mgpu.dialect.initialize_barrier( - ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")), llvm.UndefOp(workgroup_ptr_ty()), arrival_count=1, + num_barriers=2, ) mgpu.lower_mgpu_dialect(self.module, None) @@ -615,9 +1189,9 @@ def test_lowering_traverses_regions_correctly(self): if_op = scf.IfOp(cst_true) with ir.InsertionPoint(if_op.then_block): mgpu.dialect.initialize_barrier( - ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")), llvm.UndefOp(workgroup_ptr_ty()), arrival_count=1, + num_barriers=2, ) scf.yield_([]) mgpu.lower_mgpu_dialect(self.module, None) @@ -627,46 +1201,43 @@ def test_lowering_traverses_regions_correctly(self): ) def test_initialize_barrier_op_lowering_rule(self): - shape = (3, 4) - num_shape_elements = shape[0] * shape[1] + num_barriers = 4 arrival_count = 1337 with ir.InsertionPoint(self.module.body): - barriers_ref = mgpu.dialect.initialize_barrier( - ir.MemRefType.get(shape, ir.Type.parse("!mosaic_gpu.barrier")), + mgpu.dialect.initialize_barrier( llvm.UndefOp(workgroup_ptr_ty()), arrival_count=arrival_count, + num_barriers=num_barriers, ) - # Add a user for barriers_ref to make sure that the lowering keeps types - # consistent. - memref.copy(barriers_ref, barriers_ref) self.assertTrue(self.module.operation.verify()) mgpu.lower_mgpu_dialect(self.module, None) self.assertTrue(self.module.operation.verify()) - all_mbarrier_init_shared_ops = find_if( + all_mbarrier_init_ops = find_if( self.module, - lambda op: op.name == nvvm.MBarrierInitSharedOp.OPERATION_NAME, + lambda op: op.name == nvvm.MBarrierInitOp.OPERATION_NAME, ) # One nvvm.mbarrier_init_shared is issued per barrier. - self.assertLen(all_mbarrier_init_shared_ops, num_shape_elements) + self.assertLen(all_mbarrier_init_ops, num_barriers) - # Each barrier has its count equal to the arrival count. - for op in all_mbarrier_init_shared_ops: + # Each barrier has its count equal to the arrival count times the + # warpgroup size. + for op in all_mbarrier_init_ops: count = op.count.owner.opview self.assertIsInstance(count, arith.ConstantOp) - self.assertEqual(count.literal_value, arrival_count) + self.assertEqual( + count.literal_value, arrival_count * mgpu_utils.WARPGROUP_SIZE + ) def test_lowering_vector_op_without_layout_fails(self): shape = (3, 4) elt_ty = ir.BF16Type.get() with ir.InsertionPoint(self.module.body): ref = llvm.mlir_undef(ir.MemRefType.get(shape, elt_ty)) - zero_index = arith.constant(ir.IndexType.get(), 0) - ty = ir.VectorType.get(shape, elt_ty) - vector.load(ty, ref, [zero_index, zero_index]) + mgpu.dialect.vector_load(ref) with self.assertRaisesRegex( ValueError, "missing a layout and can not be lowered" ): @@ -677,11 +1248,11 @@ def test_lowering_eliminates_layouts(self): elt_ty = ir.BF16Type.get() with ir.InsertionPoint(self.module.body): ref = llvm.mlir_undef(ir.MemRefType.get(shape, elt_ty)) - zero_index = arith.constant(ir.IndexType.get(), 0) - ty = ir.VectorType.get(shape, elt_ty) - load = vector.load(ty, ref, [zero_index, zero_index]) + load = mgpu.dialect.vector_load(ref) + strided_layout = mgpu.WGStridedFragLayout.from_shaped_type(load.type) + assert strided_layout is not None load.owner.attributes["out_layouts"] = ir.ArrayAttr.get([ - layouts.to_layout_attr(mgpu.WGStridedFragLayout.from_shaped_type(ty)) + layouts.to_layout_attr(strided_layout) ]) mgpu.lower_mgpu_dialect(self.module, None) @@ -695,13 +1266,11 @@ def test_lowering_eliminates_layouts(self): self.assertEmpty(all_ops_with_layouts) def test_lowering_splat_constant(self): - cst = None elt_ty = ir.BF16Type.get() - def body(): + with ir.InsertionPoint(self.module.body): vec_ty = ir.VectorType.get((16, 8), elt_ty) zero = ir.FloatAttr.get(elt_ty, 0) - nonlocal cst cst = arith.ConstantOp( vec_ty, ir.DenseElementsAttr.get_splat(vec_ty, zero) ) @@ -711,9 +1280,6 @@ def body(): ) ]) - with ir.InsertionPoint(self.module.body): - func.FuncOp.from_py_func()(body) - mgpu.lower_mgpu_dialect(self.module, None) cst_ops = find_if( @@ -728,10 +1294,8 @@ def test_lowering_vector_load_and_store_ops(self): elt_ty = ir.BF16Type.get() with ir.InsertionPoint(self.module.body): ref = llvm.mlir_undef(ir.MemRefType.get(shape, elt_ty)) - zero_index = arith.constant(ir.IndexType.get(), 0) - ty = ir.VectorType.get(shape, elt_ty) - array = vector.load(ty, ref, [zero_index, zero_index]) - vector.store(array, ref, [zero_index, zero_index]) + reg = mgpu.dialect.vector_load(ref) + mgpu.dialect.vector_store(reg, ref) mgpu.infer_layout(self.module) mgpu.lower_mgpu_dialect(self.module, None) @@ -752,7 +1316,9 @@ def test_lowering_vector_load_and_store_ops(self): self.assertLen(all_stores, 2) def check_type(ty: ir.Type): - self.assertTrue(ir.VectorType.get((4,), elt_ty).isinstance(ty)) + self.assertIsInstance(ty, ir.VectorType) + self.assertEqual(ty.element_type, elt_ty) + self.assertEqual(ty.shape, [4]) load1, load2, *_ = all_loads # Variadic unpacking to silence linter. check_type(load1.result.type) @@ -765,24 +1331,23 @@ def check_type(ty: ir.Type): def test_lowering_for(self): shape = (4, 128) i32 = ir.IntegerType.get_signless(32) - vec_ty = ir.VectorType.get(shape, i32) splat_layout_attr = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape)) - strided_layout_attr = layouts.to_layout_attr( - mgpu.WGStridedFragLayout.from_shaped_type(vec_ty) - ) with ir.InsertionPoint(self.module.body): i1 = arith.constant(ir.IndexType.get(), 1) c1 = arith.constant(i32, 1) - splat = vector.SplatOp( - ir.VectorType.get(shape, i32), arith.constant(i32, 1234), + splat = vector.BroadcastOp( + ir.VectorType.get(shape, i32), + arith.constant(i32, 1234), ) splat.attributes["out_layouts"] = ir.ArrayAttr.get([ splat_layout_attr ]) ptr = llvm.mlir_undef(ir.Type.parse("!llvm.ptr")) ref = mgpu_utils.ptr_as_memref(ptr, ir.MemRefType.get(shape, i32)) - i0 = arith.constant(ir.IndexType.get(), 0) - other_vec = vector.LoadOp(vec_ty, ref, [i0, i0]) + other_vec = mgpu.dialect.VectorLoadOp(ref) + strided_layout_attr = layouts.to_layout_attr( + mgpu.WGStridedFragLayout.from_shaped_type(other_vec.result.type) + ) other_vec.attributes["out_layouts"] = ir.ArrayAttr.get([strided_layout_attr]) for_op = scf.ForOp(i1, i1, i1, [c1, splat.result]) for_op.attributes["in_layouts"] = ir.ArrayAttr.get([strided_layout_attr]) @@ -804,17 +1369,13 @@ def test_lowering_for(self): self.assertSequenceEqual(result_types, [i32, reg_vec_ty, reg_vec_ty]) def test_lowering_slice_smem_op(self): - shift = 1234 - offset = None - - def body(): - nonlocal offset + with ir.InsertionPoint(self.module.body): + shift = 1234 i32 = ir.IntegerType.get_signless(32) + memref_ty = ir.MemRefType.get((4, 32), i32, memory_space=mgpu_utils.smem()) offset = arith.constant(i32, shift) - mgpu.dialect.slice_smem(i32, offset) - - with ir.InsertionPoint(self.module.body): - func.FuncOp.from_py_func()(body) + op = mgpu.dialect.SliceSMEMOp(memref_ty, offset) + op.attributes["out_transforms"] = ir.ArrayAttr.get([ir.ArrayAttr.get([])]) mgpu.lower_mgpu_dialect(self.module, None) # Avoid making a change detector, only validate that lowering runs as @@ -844,7 +1405,7 @@ def test_lower_conversion_op_lowers_to_same_op(self, op, in_dtype, out_dtype): scalar_out_ty = mgpu_utils.dtype_to_ir_type(out_dtype) in_ty = ir.VectorType.get(shape, scalar_in_ty) out_ty = ir.VectorType.get(shape, scalar_out_ty) - if ir.IntegerType.isinstance(scalar_in_ty): + if isinstance(scalar_in_ty, ir.IntegerType): zero = ir.IntegerAttr.get(scalar_in_ty, 0) else: zero = ir.FloatAttr.get(scalar_in_ty, 0) @@ -862,6 +1423,121 @@ def test_lower_conversion_op_lowers_to_same_op(self, op, in_dtype, out_dtype): self.assertLen(conversion_ops, 1) self.assertEqual(conversion_ops[0].result.type, scalar_out_ty) + @parameterized.parameters( + (True, False, False), + (False, True, False), + (False, False, True), + ) + def test_custom_primitive_op_must_have_number_of_annotations_matching_operands_and_results( + self, omit_in_layouts, omit_in_transforms, omit_out_layouts + ): + vec_ty = ir.VectorType.get((4, 32), ir.BF16Type.get()) + out_layouts = [ + layouts.to_layout_attr( + mgpu.WGStridedFragLayout.from_shaped_type(vec_ty) + ) + ] + in_layouts = out_layouts * 2 + in_transforms = [ + ir.ArrayAttr.get([mgpu.dialect.SwizzleTransformAttr.get(128)]) + ] + + in_layouts = [] if omit_in_layouts else in_layouts + in_transforms = [] if omit_in_transforms else in_transforms + out_layouts = [] if omit_out_layouts else out_layouts + + with ir.InsertionPoint(self.module.body): + ref_ty = ir.MemRefType.get( + (4, 32), ir.BF16Type.get(), memory_space=mgpu_utils.smem() + ) + vec1, vec2, ref = undefs(vec_ty, vec_ty, ref_ty) + op = mgpu.dialect.CustomPrimitiveOp( + [vec_ty], [vec1, vec2, ref], in_layouts, in_transforms, out_layouts + ) + args_ty = [arg.type for arg in op.operands_] + block = op.body.blocks.append(*args_ty) + with ir.InsertionPoint(block): + out = undefs(vec_ty) + mgpu.dialect.ReturnOp(out) + + if omit_in_layouts: + error = "layout for each vector operand" + elif omit_in_transforms: + error = "transforms for each memref operand in smem" + else: + assert omit_out_layouts + error = "layout for each vector result" + + with self.assertRaisesRegex(ir.MLIRError, error): + self.module.operation.verify() + + def test_memref_transforms_with_transpose(self): + with ir.InsertionPoint(self.module.body): + ty_in = ir.MemRefType.get( + (64, 128), + ir.BF16Type.get(), + memory_space=mgpu_utils.smem(), + ) + ref = memref.alloc(ty_in, [], []) + + ref = mgpu_utils.memref_transpose(ref, (1, 0)) + # This tiling is applied to the transposed memref. + transforms = [mgpu.TileTransform(tiling=(16, 32))] + + ref_transformed = lowering.reinterpret_smem_ref(ref, transforms) + ty_transformed = ir.MemRefType(ref_transformed.type) + self.assertEqual(ty_transformed.shape, [8, 2, 16, 32]) + strides, _ = ty_transformed.get_strides_and_offset() + self.assertEqual(strides, [512, 4096, 1, 16]) + + def test_optimized_gmem_transfers_are_not_supported(self): + def body(ctx, input, output, scratch): + del ctx, output, scratch + reg = mgpu.dialect.vector_load(input, optimized=True) + layout = layouts.to_layout_attr(mgpu.WGMMA_LAYOUT) + mgpu.dialect.layout_cast(reg, layout) + + shape = (128, 128) + dtype = jnp.bfloat16 + with self.assertRaisesRegex( + NotImplementedError, "Only optimized transfers to SMEM supported" + ): + mgpu.as_gpu_kernel( + body, + grid=(1, 1, 1), + block=(128, 1, 1), + in_shape=jax.ShapeDtypeStruct(shape, dtype), + out_shape=jax.ShapeDtypeStruct(shape, dtype), + smem_scratch_shape=(), + thread_semantics=mgpu.LoweringSemantics.Warpgroup, + ) + + def test_inconsistent_collective_attributes_in_kernel_raise(self): + def body(ctx, out, smem_ptr): + del ctx, out + ref_ty = ir.MemRefType.get( + (128, 128), + ir.BF16Type.get(), + memory_space=mgpu_utils.tmem(), + ) + mgpu.dialect.tmem_alloc(ref_ty, smem_ptr, collective=False) + mgpu.dialect.tmem_alloc(ref_ty, smem_ptr, collective=True) + + with self.assertRaisesRegex( + ValueError, + "Collective attributes are inconsistent across operations in the" + " kernel", + ): + mgpu.as_gpu_kernel( + body, + grid=(1, 1, 1), + block=(128, 1, 1), + in_shape=(), + out_shape=(jax.ShapeDtypeStruct((), jnp.int32),), + smem_scratch_shape=jax.ShapeDtypeStruct((), jnp.int32), + thread_semantics=mgpu.LoweringSemantics.Warpgroup, + ) + if __name__ == "__main__": parameterized.absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/mosaic/gpu_layout_inference_test.py b/tests/mosaic/gpu_layout_inference_test.py index 36c8ff9cf47e..a005e5f05812 100644 --- a/tests/mosaic/gpu_layout_inference_test.py +++ b/tests/mosaic/gpu_layout_inference_test.py @@ -18,16 +18,28 @@ from absl.testing import parameterized import jax +from jax import numpy as jnp from jax._src import config from jax._src import test_util as jtu from jax._src.interpreters import mlir as mlir_interpreter from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith -from jax._src.lib.mlir.dialects import func +from jax._src.lib.mlir.dialects import builtin +from jax._src.lib.mlir.dialects import llvm +from jax._src.lib.mlir.dialects import math as math_dialect +from jax._src.lib.mlir.dialects import memref from jax._src.lib.mlir.dialects import scf from jax._src.lib.mlir.dialects import vector import jax.experimental.mosaic.gpu as mgpu +from jax.experimental.mosaic.gpu import constraints as cs +from jax.experimental.mosaic.gpu import fragmented_array as fa +from jax.experimental.mosaic.gpu import inference_utils +from jax.experimental.mosaic.gpu import launch_context as lc +from jax.experimental.mosaic.gpu import layout_inference from jax.experimental.mosaic.gpu import layouts +from jax.experimental.mosaic.gpu import tcgen05 +from jax.experimental.mosaic.gpu import test_util as mtu +import numpy as np config.parse_flags_with_absl() @@ -40,7 +52,53 @@ def _make_ir_context(): return context +def layout_cast(x: ir.Value, layout: mgpu.FragmentedLayout | ir.Attribute) -> ir.Value: + """Convenience wrapper around `mgpu.dialect.layout_cast`.""" + if isinstance(layout, mgpu.FragmentedLayout): + layout = layouts.to_layout_attr(layout) + return mgpu.dialect.layout_cast(x, layout) + + +def undefs(*tys: ir.Type) -> list[ir.Value]: + """Returns a list of `llvm.mlir_undef` values of the given types.""" + return [llvm.mlir_undef(ty) for ty in tys] + + +V = cs.Variable +E = cs.Equals +RL = cs.RegisterLayout + + +def _undef_constraint_system( + ctx: layout_inference.DerivationContext, + op: llvm.UndefOp, +) -> tuple[ + cs.ConstraintSystem, + layout_inference.ValueSitesForVariable, +]: + del ctx + # This rule is only called if the single output of the undef op is a vector or + # TMEM reference, so we can just return a trivial mapping. + result = layout_inference.ValueSite( + op, layout_inference.VariableType.RESULT, 0 + ) + return cs.ConstraintSystem(), {cs.Variable(result): [result]} + + class LayoutInferenceTest(parameterized.TestCase): + @classmethod + def setUpClass(cls): + super().setUpClass() + layout_inference._add_constraint_system_derivation_rule(llvm.UndefOp)( + _undef_constraint_system + ) + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + del layout_inference._constraint_system_derivation_rules[ + llvm.UndefOp.OPERATION_NAME + ] def setUp(self): if jax.version._version != jax.lib.__version__: @@ -50,60 +108,74 @@ def setUp(self): self.enter_context(ir.Location.unknown()) self.module = ir.Module.create() - def test_infer_strided_layout_default(self): - shape = (16, 8) - elt_type = ir.BF16Type.get() - add = None - - def body(a, b): - nonlocal add - add = arith.AddFOp(a, b) + def checkInLayouts(self, op, in_layouts): + in_layouts = [ + layouts.to_layout_attr(l) if isinstance(l, mgpu.FragmentedLayout) else l + for l in in_layouts + ] + self.assertSequenceEqual(op.attributes["in_layouts"], in_layouts) + + def checkOutLayouts(self, op, out_layouts): + out_layouts = [ + layouts.to_layout_attr(l) if isinstance(l, mgpu.FragmentedLayout) else l + for l in out_layouts + ] + self.assertSequenceEqual(op.attributes["out_layouts"], out_layouts) + + def checkInTmemLayouts(self, op, in_layouts): + in_layouts = [ + layouts.to_layout_attr(l) if isinstance(l, tcgen05.TMEMLayout) else l + for l in in_layouts + ] + self.assertSequenceEqual(op.attributes["in_tmem_layouts"], in_layouts) + + def checkOutTmemLayouts(self, op, out_layouts): + out_layouts = [ + layouts.to_layout_attr(l) if isinstance(l, tcgen05.TMEMLayout) else l + for l in out_layouts + ] + self.assertSequenceEqual(op.attributes["out_tmem_layouts"], out_layouts) + def test_infer_strided_layout_default(self): with ir.InsertionPoint(self.module.body): - ty = ir.VectorType.get(shape, elt_type) - func.FuncOp.from_py_func(ty, ty)(body) + ty = ir.VectorType.get((128,), ir.BF16Type.get()) + x = llvm.UndefOp(ty) # Not setting any layouts on the module should default in ops having a # strided fragmented layout. mgpu.infer_layout(self.module) - layout = layouts.to_layout_attr( - mgpu.WGStridedFragLayout.from_shaped_type(ty) - ) + strided_layout = mgpu.WGStridedFragLayout.from_shaped_type(ty) + assert strided_layout is not None + layout = layouts.to_layout_attr(strided_layout) - self.assertSequenceEqual(add.attributes["in_layouts"], [layout, layout]) - self.assertSequenceEqual(add.attributes["out_layouts"], [layout]) + self.assertNotIn("in_layouts", x.attributes) + self.checkOutLayouts(x, [layout]) - def test_infer_strided_layout_from_shape_cast(self): - shape = (16, 8) + @parameterized.parameters( + (mgpu.WGMMA_LAYOUT, None), (None, mgpu.WGMMA_LAYOUT) + ) + def test_infer_layout_bidirectionally_through_shape_cast( + self, in_layout, out_layout + ): + assert in_layout is not None or out_layout is not None elt_type = ir.BF16Type.get() - src_type = ir.VectorType.get(shape, elt_type) - dst_type = ir.VectorType.get([*reversed(shape)], elt_type) - op = None - - def body(x): - nonlocal op - op = vector.ShapeCastOp(dst_type, x) + src_type = ir.VectorType.get((2, 128, 8), elt_type) + dst_type = ir.VectorType.get((256, 8), elt_type) with ir.InsertionPoint(self.module.body): - func.FuncOp.from_py_func(src_type)(body) + [x] = undefs(src_type) + if in_layout is not None: + x = layout_cast(x, in_layout) + op = vector.ShapeCastOp(dst_type, x) + if out_layout is not None: + layout_cast(op.result, out_layout) mgpu.infer_layout(self.module) - in_layout = layouts.to_layout_attr( - mgpu.WGStridedFragLayout.from_shaped_type(src_type) - ) - out_layout = layouts.to_layout_attr( - mgpu.WGStridedFragLayout.from_shaped_type(dst_type) - ) - - self.assertSequenceEqual(op.attributes["in_layouts"], [in_layout]) - self.assertSequenceEqual(op.attributes["out_layouts"], [out_layout]) - - # Ensure that we can recover the original layout. - del op.attributes["in_layouts"] - mgpu.infer_layout(self.module) - self.assertSequenceEqual(op.attributes["in_layouts"], [in_layout]) + expected_layout = in_layout or out_layout + self.checkInLayouts(op, [expected_layout]) + self.checkOutLayouts(op, [expected_layout]) def test_infer_splat_layout_for_splat_constants(self): shape = (16, 8) @@ -112,10 +184,7 @@ def test_infer_splat_layout_for_splat_constants(self): with ir.InsertionPoint(self.module.body): ty = ir.VectorType.get(shape, elt_type) c0 = ir.FloatAttr.get(elt_type, 0) - c1 = ir.FloatAttr.get(elt_type, 1) - splat0 = arith.ConstantOp(ty, ir.DenseElementsAttr.get_splat(ty, c0)) - splat1 = arith.ConstantOp(ty, ir.DenseElementsAttr.get_splat(ty, c1)) - add = arith.AddFOp(splat0, splat1) + splat = arith.ConstantOp(ty, ir.DenseElementsAttr.get_splat(ty, c0)) # Not setting any layouts on the module should default in all ops having a # splat fragmented layout. @@ -123,18 +192,15 @@ def test_infer_splat_layout_for_splat_constants(self): layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape=shape)) - self.assertEmpty(splat0.attributes["in_layouts"]) - self.assertSequenceEqual(splat0.attributes["out_layouts"], [layout]) - - self.assertEmpty(splat1.attributes["in_layouts"]) - self.assertSequenceEqual(splat1.attributes["out_layouts"], [layout]) - - self.assertSequenceEqual(add.attributes["in_layouts"], [layout, layout]) - self.assertSequenceEqual(add.attributes["out_layouts"], [layout]) + self.assertNotIn("in_layouts", splat.attributes) + self.checkOutLayouts(splat, [layout]) def test_infer_layout_from_consumer_for_non_splat_constant(self): shape = (16, 8) elt_type = ir.BF16Type.get() + layout = layouts.to_layout_attr( + mgpu.WGStridedFragLayout(shape=shape, vec_size=1) + ) with ir.InsertionPoint(self.module.body): ty = ir.VectorType.get(shape, elt_type) @@ -142,96 +208,195 @@ def test_infer_layout_from_consumer_for_non_splat_constant(self): ir.FloatAttr.get(elt_type, i) for i in range(shape[0] * shape[1]) ] c = arith.ConstantOp(ty, ir.DenseElementsAttr.get(attr_list, ty)) - add = arith.AddFOp(c, c) - - layout = layouts.to_layout_attr( - mgpu.WGStridedFragLayout(shape=shape, vec_size=1) - ) - add.attributes["in_layouts"] = ir.ArrayAttr.get([layout, layout]) + layout_cast(c, layout) mgpu.infer_layout(self.module) - self.assertEmpty(c.attributes["in_layouts"]) - self.assertSequenceEqual(c.attributes["out_layouts"], [layout]) - - @parameterized.parameters(True, False) - def test_infer_splat_layout_for_vector_splat(self, rhs_splat): - add = splat = None - - def body(lhs, rhs): - nonlocal add, splat - splat = vector.SplatOp(rhs.type, lhs) - add = arith.AddFOp(splat.result, rhs) + self.assertNotIn("in_layouts", c.attributes) + self.checkOutLayouts(c, [layout]) + def test_infer_splat_layout_for_vector_splat(self): + shape = (16, 8) + splat_layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape=shape)) with ir.InsertionPoint(self.module.body): - shape = (16, 8) - elt_type = ir.BF16Type.get() - ty = ir.VectorType.get(shape, elt_type) - func_op = func.FuncOp.from_py_func(elt_type, ty)(body).func_op + bf16 = ir.BF16Type.get() + ty = ir.VectorType.get(shape, bf16) + lhs, rhs = undefs(bf16, ty) + rhs = layout_cast(rhs, splat_layout) + splat = vector.BroadcastOp(rhs.type, lhs) + add = arith.AddFOp(splat.result, rhs) - layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape=shape)) - if rhs_splat: - func_op.attributes["in_layouts"] = ir.ArrayAttr.get([layout]) mgpu.infer_layout(self.module) - self.assertEmpty(splat.attributes["in_layouts"]) - self.assertSequenceEqual(splat.attributes["out_layouts"], [layout]) - - add_layout = layout - if not rhs_splat: - add_layout = layouts.to_layout_attr( - mgpu.WGStridedFragLayout.from_shaped_type(ty) - ) + self.assertNotIn("in_layouts", splat.attributes) + self.checkOutLayouts(splat, [splat_layout]) - self.assertSequenceEqual(add.attributes["in_layouts"], [add_layout, add_layout]) - self.assertSequenceEqual(add.attributes["out_layouts"], [add_layout]) + self.checkInLayouts(add, [splat_layout, splat_layout]) + self.checkOutLayouts(add, [splat_layout]) @parameterized.parameters( mgpu.WGSplatFragLayout(shape=(32, 4)), mgpu.WGStridedFragLayout(shape=(32, 4), vec_size=1), ) def test_pointwise_op_propagates_argument_layouts(self, layout): - add = None - - def body(lhs, rhs): - nonlocal add - add = arith.AddFOp(lhs, rhs) with ir.InsertionPoint(self.module.body): ty = ir.VectorType.get(layout.shape, ir.BF16Type.get()) - func.FuncOp.from_py_func(ty, ty)(body) + lhs, rhs = undefs(ty, ty) + lhs = layout_cast(lhs, layout) + rhs = layout_cast(rhs, layout) + add = arith.AddFOp(lhs, rhs) + + mgpu.infer_layout(self.module) - [f] = self.module.body.operations layout_attr = layouts.to_layout_attr(layout) - f.attributes["in_layouts"] = ir.ArrayAttr.get([layout_attr, layout_attr]) + self.checkInLayouts(add, [layout_attr, layout_attr]) + self.checkOutLayouts(add, [layout_attr]) + + def test_vector_load_does_not_allow_splat_result(self): + shape = (32, 4) + splat_layout_attr = layouts.to_layout_attr( + mgpu.WGSplatFragLayout(shape=shape) + ) + strided_layout_attr = layouts.to_layout_attr( + mgpu.WGStridedFragLayout(shape=shape, vec_size=1) + ) + + with ir.InsertionPoint(self.module.body): + vec_ty = ir.VectorType.get(shape, ir.BF16Type.get()) + ref_ty = ir.MemRefType.get(shape, ir.BF16Type.get()) + vec, ref = undefs(vec_ty, ref_ty) + load_op = mgpu.dialect.VectorLoadOp(ref) + lhs = layout_cast(vec, splat_layout_attr) + arith.AddFOp(lhs, load_op.result) mgpu.infer_layout(self.module) - self.assertSequenceEqual( - add.attributes["in_layouts"], [layout_attr, layout_attr] - ) - self.assertSequenceEqual(add.attributes["out_layouts"], [layout_attr]) + self.assertNotIn("in_layouts", load_op.attributes) + self.checkOutLayouts(load_op, [strided_layout_attr]) + + def test_infer_layout_cast_layout(self): + shape = (128, 64) + splat_layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape=shape)) + wgmma_layout = layouts.to_layout_attr(mgpu.WGMMA_LAYOUT) + + with ir.InsertionPoint(self.module.body): + [x] = undefs(ir.VectorType.get(shape, ir.BF16Type.get())) + x = mgpu.dialect.layout_cast(x, splat_layout) + add = arith.AddFOp(x, x) + cast = mgpu.dialect.LayoutCastOp(add.result, wgmma_layout) + + mgpu.infer_layout(self.module) + # The layout of `add` may be either WGMMA or SPLAT. + self.checkOutLayouts(add, [wgmma_layout]) + self.checkInLayouts(cast, [wgmma_layout]) + self.checkOutLayouts(cast, [wgmma_layout]) + + @parameterized.product( + layout=( + mtu.RegisterLayout.WGMMA, + mtu.RegisterLayout.TCGEN05, + mtu.RegisterLayout.TCGEN05_TMEM_NATIVE, + mtu.RegisterLayout.TCGEN05_M64_COLLECTIVE, + ), + axis=(0, 1), + hint_on_input=(True, False), + ) + def test_infer_broadcast_in_dim_layout(self, layout, axis, hint_on_input): + in_shape = (128,) + out_shape = (128, 128) + dtype = ir.F32Type.get() + out_layout = layout.to_mgpu(out_shape, dtype) + in_layout = out_layout.reduce((1 - axis,)) + + with ir.InsertionPoint(self.module.body): + [x] = undefs(ir.VectorType.get(in_shape, dtype)) + if hint_on_input: + x = layout_cast(x, in_layout) + out_type = ir.VectorType.get(out_shape, dtype) + bcast = mgpu.dialect.BroadcastInDimOp(out_type, x, [axis]) + if not hint_on_input: + layout_cast(bcast.result, out_layout) + + if hint_on_input and axis == 1 and layout == mtu.RegisterLayout.TCGEN05: + # Both TCGEN05 and WGMMA are valid layout candidates. WGMMA is tried first. + out_layout = fa.WGMMA_LAYOUT + + mgpu.infer_layout(self.module) + self.checkInLayouts(bcast, [in_layout]) + self.checkOutLayouts(bcast, [out_layout]) + + # TODO(allanrenucci): Turn into a positive test. This is currently not + # implemented. The test checks we fail gracefully. + @parameterized.parameters(True, False) + def test_cant_infer_reduced_strided_layout(self, hint_on_input): + with ir.InsertionPoint(self.module.body): + [x] = undefs(ir.VectorType.get((128,), ir.F32Type.get())) + if hint_on_input: + layout = mgpu.WGStridedFragLayout.from_shaped_type(x.type) + x = layout_cast(x, layout) + out_type = ir.VectorType.get((128, 128), ir.F32Type.get()) + out = mgpu.dialect.broadcast_in_dim(out_type, x, [0]) + if not hint_on_input: + layout = mgpu.WGStridedFragLayout.from_shaped_type(out.type) + layout_cast(out, layout) + + with self.assertRaisesRegex( + ValueError, "Failed to infer a possible set of layouts" + ): + mgpu.infer_layout(self.module) + + @parameterized.parameters( + (1, mgpu.WGMMA_LAYOUT, None, None), + (0, mgpu.WGMMA_LAYOUT, None, None), + (1, None, None, mgpu.WGMMA_ROW_LAYOUT), + (0, None, None, mgpu.WGMMA_COL_LAYOUT), + (1, None, mgpu.WGMMA_ROW_LAYOUT, None), + (0, None, mgpu.WGMMA_COL_LAYOUT, None), + (1, None, mgpu.WGMMA_ROW_LAYOUT, mgpu.WGMMA_ROW_LAYOUT), + (0, None, mgpu.WGMMA_COL_LAYOUT, mgpu.WGMMA_COL_LAYOUT), + (1, mgpu.TCGEN05_LAYOUT, None, None), + (1, None, None, mgpu.TCGEN05_ROW_LAYOUT), + (1, None, mgpu.TCGEN05_ROW_LAYOUT, None), + (1, None, mgpu.TCGEN05_ROW_LAYOUT, mgpu.TCGEN05_ROW_LAYOUT) + ) + def test_infer_multi_reduce_layout( + self, reduce_dim, in_cast, acc_cast, out_cast + ): + with ir.InsertionPoint(self.module.body): + in_ty = ir.VectorType.get((128, 128), ir.F32Type.get()) + acc_ty = ir.VectorType.get((128,), ir.F32Type.get()) + x, acc = undefs(in_ty, acc_ty) + x = layout_cast(x, in_cast) if in_cast is not None else x + acc = layout_cast(acc, acc_cast) if acc_cast is not None else acc + kind = vector.CombiningKind.MAXIMUMF + red = vector.MultiDimReductionOp(kind, x, acc, [reduce_dim]) + if out_cast is not None: + layout_cast(red.result, out_cast) + + mgpu.infer_layout(self.module) + targets_tcgen05 = any(layout in {mgpu.TCGEN05_LAYOUT, mgpu.TCGEN05_ROW_LAYOUT} for layout in [in_cast, acc_cast, out_cast]) + # The tests always expect WGMMA or TCGEN05 as the source layout. + in_layout = mgpu.TCGEN05_LAYOUT if targets_tcgen05 else mgpu.WGMMA_LAYOUT + out_layout = in_layout.reduce((reduce_dim,)) + self.checkInLayouts(red, [in_layout, out_layout]) + self.checkOutLayouts(red, [out_layout]) def test_infer_layout_traverses_ops_correctly(self): shape = (16, 8) elt_type = ir.BF16Type.get() - add = None - def body(a, b): + with ir.InsertionPoint(self.module.body): + ab_type = ir.VectorType.get(shape, elt_type) + a, b = undefs(ab_type, ab_type) bool_type = ir.IntegerType.get_signless(1) cst_true = arith.constant(bool_type, ir.IntegerAttr.get(bool_type, 1)) if_op = scf.IfOp(cst_true) with ir.InsertionPoint(if_op.then_block): - nonlocal add add = arith.AddFOp(a, b) scf.yield_([]) - with ir.InsertionPoint(self.module.body): - ab_type = ir.VectorType.get(shape, elt_type) - func.FuncOp.from_py_func(ab_type, ab_type)(body) - mgpu.infer_layout(self.module) - self.assertIn("in_layouts", add.attributes) self.assertIn("out_layouts", add.attributes) @@ -247,200 +412,1838 @@ def body(a, b): def test_infer_layout_from_yield_op_in_layouts_for_for_op( self, shape, layout ): - add_op = for_op = yield_op = None - - def body(lower_bound, upper_bound, step, a, b): - nonlocal for_op - for_op = scf.ForOp(lower_bound, upper_bound, step, [a, b]) - [loop_a, loop_b] = list(for_op.inner_iter_args) - with ir.InsertionPoint(for_op.body): - nonlocal add_op, yield_op - add_op = arith.AddFOp(loop_a, loop_b) - yield_op = scf.YieldOp([add_op.result, add_op.result]) - with ir.InsertionPoint(self.module.body): - ab_type = ir.VectorType.get(shape, ir.BF16Type.get()) + elt_ty = ir.BF16Type.get() + ab_type = ir.VectorType.get(shape, elt_ty) + ref_ty = ir.MemRefType.get(shape, elt_ty, memory_space=mgpu.utils.smem()) i32 = ir.IntegerType.get_signless(32) - func.FuncOp.from_py_func(i32, i32, i32, ab_type, ab_type)(body) + lower_bound, upper_bound, step, a, b, ref = undefs( + i32, i32, i32, ab_type, ab_type, ref_ty + ) + for_op = scf.ForOp(lower_bound, upper_bound, step, [a, b, ref]) + [loop_a, loop_b, loop_ref] = list(for_op.inner_iter_args) + with ir.InsertionPoint(for_op.body): + add = layout_cast(arith.addf(loop_a, loop_b), layout) - add_op.attributes["out_layouts"] = ir.ArrayAttr.get( - [layouts.to_layout_attr(layout)] - ) + transforms = ir.ArrayAttr.get([ + mgpu.dialect.TileTransformAttr.get((8, 32)), + mgpu.dialect.SwizzleTransformAttr.get(64), + ]) + loop_ref = mgpu.dialect.with_transforms(loop_ref, transforms) + + yield_op = scf.YieldOp([add, add, loop_ref]) mgpu.infer_layout(self.module) - if isinstance(layout, mgpu.WGSplatFragLayout): - # In the splat case, we should not propagate the splat layout from the - # yield op. That is because we can not convert other layouts to a splat - # layout---which could cause trouble if the initial carries have a - # different layout. Instead, we should get the default annotation, i.e. - # strided layouts. - strided_layout = layouts.to_layout_attr( - mgpu.WGStridedFragLayout.from_shaped_type(ab_type) - ) - carry_layouts = [strided_layout, strided_layout] - self.assertSequenceEqual(yield_op.attributes["out_layouts"], []) - self.assertSequenceEqual(for_op.attributes["in_layouts"], carry_layouts) - self.assertSequenceEqual(for_op.attributes["out_layouts"], carry_layouts) - else: - carry_layouts = [layouts.to_layout_attr(layout)] * 2 - self.assertSequenceEqual(yield_op.attributes["out_layouts"], []) - self.assertSequenceEqual(for_op.attributes["in_layouts"], carry_layouts) - self.assertSequenceEqual(for_op.attributes["out_layouts"], carry_layouts) + carry_layouts = [layouts.to_layout_attr(layout)] * 2 + self.assertNotIn("out_layouts", yield_op.attributes) + self.checkInLayouts(for_op, carry_layouts) + self.checkOutLayouts(for_op, carry_layouts) + [in_transform] = inference_utils.in_transforms(for_op) + self.assertSequenceEqual(in_transform, transforms) + [out_transform] = inference_utils.out_transforms(for_op) + self.assertSequenceEqual(out_transform, transforms) def test_infer_layout_from_body_op_to_yield_op_to_for_op(self): - for_op = yield_op = None shape = (64, 64) - - def body(lower_bound, upper_bound, step, a, b, c): - nonlocal for_op + with ir.InsertionPoint(self.module.body): + elt_ty = ir.BF16Type.get() + c_ty = ir.VectorType.get(shape, elt_ty) + ab_type = ir.MemRefType.get(shape, elt_ty, memory_space=mgpu.utils.smem()) + i32 = ir.IntegerType.get_signless(32) + lower_bound, upper_bound, step, a, b, c = undefs( + i32, i32, i32, ab_type, ab_type, c_ty + ) for_op = scf.ForOp(lower_bound, upper_bound, step, [a, b, c]) with ir.InsertionPoint(for_op.body): - nonlocal yield_op [loop_a, loop_b, loop_c] = list(for_op.inner_iter_args) new_loop_c = mgpu.dialect.wgmma(loop_c, loop_a, loop_b) yield_op = scf.YieldOp([loop_a, loop_b, new_loop_c]) + mgpu.infer_layout(self.module) + + wgmma_layout = layouts.to_layout_attr(mgpu.WGMMA_LAYOUT) + self.checkInLayouts(yield_op, [wgmma_layout]) + self.assertNotIn("out_layouts", yield_op.attributes) + self.checkInLayouts(for_op, [wgmma_layout]) + self.checkOutLayouts(for_op, [wgmma_layout]) + + transforms = ir.ArrayAttr.get([ + mgpu.dialect.TileTransformAttr.get((8, 64)), + mgpu.dialect.SwizzleTransformAttr.get(128), + ]) + in_transforms = inference_utils.in_transforms(for_op) + self.assertSequenceEqual(in_transforms, [transforms, transforms]) + out_transforms = inference_utils.out_transforms(for_op) + self.assertSequenceEqual(out_transforms, [transforms, transforms]) + + @parameterized.parameters( + ((), None, (), None), + ((64, 32), mgpu.WGMMA_LAYOUT, (), None), + ((), None, (64, 32), mgpu.WGMMA_LAYOUT), + ((64,), mgpu.WGMMA_ROW_LAYOUT, (64, 32), mgpu.WGMMA_LAYOUT), + ) + def test_infer_while_op_layouts( + self, init_shape, init_layout, result_shape, result_layout + ): + f32 = ir.F32Type.get() + in_type = ir.VectorType.get(init_shape, f32) if init_shape else f32 + out_type = ir.VectorType.get(result_shape, f32) if result_shape else f32 with ir.InsertionPoint(self.module.body): - c_ty = ir.VectorType.get(shape, ir.BF16Type.get()) - ab_ty = ir.MemRefType.get(shape, ir.BF16Type.get()) - i32 = ir.IntegerType.get_signless(32) - func.FuncOp.from_py_func(i32, i32, i32, ab_ty, ab_ty, c_ty)(body) + i1 = ir.IntegerType.get_signless(1) + condition, init, result = undefs(i1, in_type, out_type) + init = layout_cast(init, init_layout) if init_layout else init + result = layout_cast(result, result_layout) if result_layout else result + while_op = scf.WhileOp([out_type], [init]) + before_block = while_op.before.blocks.append(init.type) + with ir.InsertionPoint(before_block): + scf.condition(condition, [result]) + after_block = while_op.after.blocks.append(out_type) + with ir.InsertionPoint(after_block): + scf.yield_([init]) mgpu.infer_layout(self.module) - wgmma_layout = layouts.to_layout_attr(mgpu.WGMMA_LAYOUT) - self.assertSequenceEqual(yield_op.attributes["in_layouts"], [wgmma_layout]) - self.assertSequenceEqual(yield_op.attributes["out_layouts"], []) - self.assertSequenceEqual(for_op.attributes["in_layouts"], [wgmma_layout]) - self.assertSequenceEqual(for_op.attributes["out_layouts"], [wgmma_layout]) + if init_layout: + self.checkInLayouts(while_op, [layouts.to_layout_attr(init_layout)]) + if result_layout: + self.checkOutLayouts(while_op, [layouts.to_layout_attr(result_layout)]) - def test_infer_layout_has_no_layout_for_non_vector_types(self): - shape = (32, 4) - elt_ty = ir.BF16Type.get() + @parameterized.parameters( + (None, mgpu.WGMMA_ROW_LAYOUT, mgpu.WGMMA_LAYOUT), + (mgpu.WGMMA_LAYOUT, mgpu.WGMMA_COL_LAYOUT, None), + ) + def test_infer_index_switch_op_layouts( + self, + out0_layout: mgpu.FragmentedLayout | None, + out3_layout: mgpu.FragmentedLayout, + out4_layout: mgpu.FragmentedLayout | None + ): + out_layouts = [out0_layout or out4_layout, out3_layout] + assert None not in out_layouts + f32 = ir.F32Type.get() + out_type = ir.VectorType.get((128, 128), f32) + with ir.InsertionPoint(self.module.body): + i1 = ir.IntegerType.get_signless(1) + [condition] = undefs(i1) + index_switch = scf.IndexSwitchOp( + [out_type, out_type, f32], + condition, + range(2), + ) + with ir.InsertionPoint(index_switch.caseRegions[0].blocks[0]): + out0, out1, dummy0 = undefs(out_type, out_type, f32) + if out0_layout is not None: + out0 = layout_cast(out0, out0_layout) + yield0 = scf.YieldOp([out0, out1, dummy0]) + with ir.InsertionPoint(index_switch.caseRegions[1].blocks[0]): + out2, out3, dummy1 = undefs(out_type, out_type, f32) + if out3_layout is not None: + out3 = layout_cast(out3, out3_layout) + yield1 = scf.YieldOp([out2, out3, dummy1]) + with ir.InsertionPoint(index_switch.defaultRegion.blocks[0]): + out4, out5, dummy2 = undefs(out_type, out_type, f32) + if out4_layout is not None: + out4 = layout_cast(out4, out4_layout) + yield2 = scf.YieldOp([out4, out5, dummy2]) - vector_store = None + mgpu.infer_layout(self.module) + + self.assertNotIn("in_layouts", index_switch.attributes) + self.assertNotIn("out_layouts", yield0.attributes) + self.assertNotIn("out_layouts", yield1.attributes) + self.assertNotIn("out_layouts", yield2.attributes) - def body(ref, array): - nonlocal vector_store - zero_index = arith.constant(ir.IndexType.get(), 0) - vector_store = vector.store(array, ref, [zero_index, zero_index]) + self.checkOutLayouts(index_switch, out_layouts) + self.checkInLayouts(yield0, out_layouts) + self.checkInLayouts(yield1, out_layouts) + self.checkInLayouts(yield2, out_layouts) + def test_infer_layout_has_no_layout_for_non_vector_types(self): + shape = (32, 4) + elt_ty = ir.BF16Type.get() with ir.InsertionPoint(self.module.body): ref_ty = ir.MemRefType.get(shape, elt_ty) array_ty = ir.VectorType.get(shape, elt_ty) - func.FuncOp.from_py_func(ref_ty, array_ty)(body) + ref, array = undefs(ref_ty, array_ty) + op = mgpu.dialect.VectorStoreOp(array, ref) mgpu.infer_layout(self.module) - self.assertIn("in_layouts", vector_store.attributes) - self.assertIn("out_layouts", vector_store.attributes) - # The vector store should have a layout for the input array, but not for the # memref. - self.assertLen(vector_store.attributes["in_layouts"], 1) - self.assertEmpty(vector_store.attributes["out_layouts"]) + self.assertIn("in_layouts", op.attributes) + self.assertLen(op.attributes["in_layouts"], 1) + self.assertNotIn("out_layouts", op.attributes) @parameterized.parameters( - mgpu.WGStridedFragLayout((32, 4), vec_size=1), + mgpu.WGStridedFragLayout((64, 16), vec_size=1), mgpu.WGMMA_LAYOUT, ) - def test_infer_layout_picks_non_splat_layout_over_splat_layout( - self, layout - ): - add = None - - def body(lhs, rhs): - nonlocal add + def test_infer_layout_picks_non_splat_layout_over_splat_layout(self, layout): + shape = (64, 16) + splat_layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape)) + non_splat_layout = layouts.to_layout_attr(layout) + with ir.InsertionPoint(self.module.body): + elt_type = ir.BF16Type.get() + ty = ir.VectorType.get(shape, elt_type) + lhs, rhs = undefs(ty, ty) + lhs = layout_cast(lhs, non_splat_layout) + rhs = layout_cast(rhs, splat_layout) add = arith.AddFOp(lhs, rhs) + mgpu.infer_layout(self.module) + + self.checkInLayouts(add, [non_splat_layout, non_splat_layout]) + self.checkOutLayouts(add, [non_splat_layout]) + + def test_infer_layout_preserves_splat_layouts_in_producers(self): + shape = (32, 4) + splat_layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape)) + strided_layout = layouts.to_layout_attr( + mgpu.WGStridedFragLayout(shape, vec_size=1) + ) with ir.InsertionPoint(self.module.body): - shape = (32, 4) elt_type = ir.BF16Type.get() ty = ir.VectorType.get(shape, elt_type) + lhs, rhs = undefs(ty, ty) + lhs = layout_cast(lhs, splat_layout) + rhs = layout_cast(rhs, splat_layout) + add0 = arith.AddFOp(lhs, rhs) + cast = layout_cast(add0.result, strided_layout) + add1 = arith.AddFOp(cast, cast) + + mgpu.infer_layout(self.module) + + self.checkInLayouts(add0, [splat_layout, splat_layout]) + self.checkOutLayouts(add0, [splat_layout]) + self.checkInLayouts(add1, [strided_layout, strided_layout]) + self.checkOutLayouts(add1, [strided_layout]) + + def test_optimization_barrier_op_propagates_user_layouts(self): + wgmma_layout = layouts.to_layout_attr(mgpu.WGMMA_LAYOUT) + + with ir.InsertionPoint(self.module.body): + ty = ir.VectorType.get((64, 16), ir.BF16Type.get()) + lhs, rhs = undefs(ty, ty) + optimization_barrier = mgpu.dialect.OptimizationBarrierOp([lhs, rhs]) + lhs, rhs = optimization_barrier.results + layout_cast(arith.addf(lhs, rhs), wgmma_layout) + + mgpu.infer_layout(self.module) - f = func.FuncOp.from_py_func(ty, ty)(body).func_op + self.checkInLayouts(optimization_barrier, [wgmma_layout, wgmma_layout]) + self.checkOutLayouts(optimization_barrier, [wgmma_layout, wgmma_layout]) + def test_optimization_barrier_op_propagates_producer_layouts(self): + shape = (32, 4) splat_layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape)) - non_splat_layout = layouts.to_layout_attr(layout) + with ir.InsertionPoint(self.module.body): + ty = ir.VectorType.get(shape, ir.BF16Type.get()) + lhs, rhs = undefs(ty, ty) + lhs = layout_cast(lhs, splat_layout) + rhs = layout_cast(rhs, splat_layout) + add = arith.addf(lhs, rhs) + optimization_barrier = mgpu.dialect.OptimizationBarrierOp([add]) + + mgpu.infer_layout(self.module) + + self.checkInLayouts(optimization_barrier, [splat_layout]) + self.checkOutLayouts(optimization_barrier, [splat_layout]) + + def test_custom_primitive_op_retains_layouts(self): + with ir.InsertionPoint(self.module.body): + wgmma_layout = layouts.to_layout_attr(mgpu.WGMMA_LAYOUT) + wgmma_row_layout = layouts.to_layout_attr(mgpu.WGMMA_ROW_LAYOUT) + op = mgpu.dialect.custom_primitive( + result=[], + operands_=[], + in_layouts=[wgmma_layout], + in_transforms=[], + out_layouts=[wgmma_row_layout], + ) + + mgpu.infer_layout(self.module) + self.checkInLayouts(op, [wgmma_layout]) + self.checkOutLayouts(op, [wgmma_row_layout]) + + def test_constraint_extraction_works_correctly(self): + layout = mgpu.WGMMA_ROW_LAYOUT + with ir.InsertionPoint(self.module.body): + x = llvm.UndefOp(ir.VectorType.get((64,), ir.BF16Type.get())) + lc = layout_cast(x.result, layouts.to_layout_attr(layout)).owner.opview + + ctx = layout_inference.DerivationContext() + _, x_mapping = _undef_constraint_system(ctx, x) + _, lc_mapping = layout_inference._layout_cast_constraint_system( + ctx, lc + ) + [constraint] = layout_inference.derive_relayout_constraints( + x_mapping | lc_mapping + ) + [x_variable] = x_mapping.keys() + [lc_variable] = lc_mapping.keys() + self.assertEqual(constraint, cs.Relayout(x_variable, lc_variable, 16)) + + @parameterized.parameters(*layout_inference.MemorySpace) + def test_relayout_only_derived_for_registers(self, memory_space): + with ir.InsertionPoint(self.module.body): + shape = (128,) + f32 = ir.F32Type.get() + match memory_space: + case layout_inference.MemorySpace.REG: + ty = ir.VectorType.get(shape, f32) + case layout_inference.MemorySpace.TMEM: + ty = ir.MemRefType.get(shape, f32, memory_space=mgpu.utils.tmem()) + case layout_inference.MemorySpace.SMEM: + ty = ir.MemRefType.get(shape, f32, memory_space=mgpu.utils.smem()) + case _: + raise ValueError(f"Unsupported memory space: {memory_space}") + + [producer] = undefs(ty) + consumer = builtin.unrealized_conversion_cast([ty], [producer]) + + r = layout_inference.ValueSite( + producer.owner, layout_inference.VariableType.RESULT, 0 + ) + r_var = cs.Variable(r) + o = layout_inference.ValueSite( + consumer.owner, layout_inference.VariableType.OPERAND, 0 + ) + o_var = cs.Variable(o) + + relayouts = layout_inference.derive_relayout_constraints( + layout_inference.ValueSitesForVariable({r_var: [r], o_var: [o]}) + ) + + if memory_space == layout_inference.MemorySpace.REG: + self.assertEqual(relayouts, [cs.Relayout(r_var, o_var, 32)]) + else: + self.assertEmpty(relayouts) + + def test_find_assignments_for_is_transferable_constraints_is_deterministic( + self, + ): + v0 = V(0) + tmem_layout = tcgen05.tmem_default_layout(packing=1) + constraint = cs.IsTransferable( + v0, cs.TMEMLayout(tmem_layout), shape=(128, 128) + ) + assignments, _ = layout_inference.find_assignments_for( + {v0}, + cs.ConstraintSystem(constraints=[constraint]), + fuel=1000, + ) + # Another valid layout is TMEM_NATIVE_LAYOUT but TCGEN05_LAYOUT is tried + # first. This may require updating if we decide to change the traversal + # order in the future. + self.assertEqual(assignments, {v0: RL(mgpu.TCGEN05_LAYOUT)}) + + def test_cannot_find_assignments_for_unsatisfiable_constraint_system(self): + with ir.InsertionPoint(self.module.body): + x = llvm.UndefOp(ir.VectorType.get((64,), ir.BF16Type.get())) + + [key] = layout_inference.vector_value_sites(x) + variable = cs.Variable(key) + assignments, _ = layout_inference.find_assignments_for( + {variable}, + cs.ConstraintSystem( + constraints=[ + E(variable, RL(mgpu.WGMMA_ROW_LAYOUT)), + E(variable, RL(mgpu.WGMMA_COL_LAYOUT)), + ] + ), + fuel=1000, + ) + self.assertIsInstance(assignments, cs.Unsatisfiable) + + def test_vector_broadcast_from_scalar_infers_splat_layout(self): + shape = (128,) + f32 = ir.F32Type.get() + layout = mgpu.WGSplatFragLayout(shape) + with ir.InsertionPoint(self.module.body): + source, = undefs(f32) + bcast = vector.BroadcastOp(ir.VectorType.get(shape, f32), source) + + mgpu.infer_layout(self.module) + self.assertNotIn("in_layouts", bcast.attributes) + self.checkOutLayouts(bcast, [layout]) + + def test_vector_reduction_infers_reducible_producer_layout(self): + shape = (128,) + f32 = ir.F32Type.get() + layout = mgpu.WGMMA_ROW_LAYOUT + with ir.InsertionPoint(self.module.body): + source, = undefs(ir.VectorType.get(shape, f32)) + source = layout_cast(source, layout) + reduction = vector.ReductionOp(f32, vector.CombiningKind.ADD, source) + + mgpu.infer_layout(self.module) + self.checkInLayouts(reduction, [layout]) + self.assertNotIn("out_layouts", reduction.attributes) + + def test_infer_layout_of_custom_primitive_op_uses_argument_layouts(self): + in_layouts = [mgpu.WGMMA_LAYOUT, mgpu.WGMMA_ROW_LAYOUT] + out_layouts = [mgpu.WGMMA_COL_LAYOUT] + with ir.InsertionPoint(self.module.body): + f32 = ir.F32Type.get() + vec_ty = ir.VectorType.get((128, 128), f32) + op = mgpu.dialect.CustomPrimitiveOp( + result=[vec_ty], + operands_=undefs(f32, vec_ty, vec_ty, f32), + in_layouts=[layouts.to_layout_attr(l) for l in in_layouts], + in_transforms=[], + out_layouts=[layouts.to_layout_attr(l) for l in out_layouts], + ) + + mgpu.infer_layout(self.module) + self.checkInLayouts(op, in_layouts) + self.checkOutLayouts(op, out_layouts) + + def test_layout_cast_of_vector_load_to_splat_raises(self): + shape = (32, 4) + splat_layout = mgpu.WGSplatFragLayout(shape=shape) + with ir.InsertionPoint(self.module.body): + ref_ty = ir.MemRefType.get(shape, ir.BF16Type.get()) + [ref] = undefs(ref_ty) + loaded = mgpu.dialect.vector_load(ref) + layout_cast(loaded, splat_layout) + + with self.assertRaisesRegex( + ValueError, "user-provided layout casts are unsatisfiable" + ): + mgpu.infer_layout(self.module) + + def test_layout_cast_of_non_splat_constant_to_splat_raises(self): + shape = (128,) + splat_layout = mgpu.WGSplatFragLayout(shape=shape) + with ir.InsertionPoint(self.module.body): + bf16 = ir.BF16Type.get() + ty = ir.VectorType.get(shape, bf16) + values = [ir.FloatAttr.get(bf16, float(i)) for i in range(shape[0])] + constant = arith.constant(ty, ir.DenseElementsAttr.get(values, ty)) + layout_cast(constant, splat_layout) + + with self.assertRaisesRegex( + ValueError, "user-provided layout casts are unsatisfiable" + ): + mgpu.infer_layout(self.module) + + def test_layout_of_wgmma_layout_to_wgmma_row_layout_raises(self): + with ir.InsertionPoint(self.module.body): + [ref] = undefs(ir.VectorType.get((128, 128), ir.F32Type.get())) + wgmma_layout = layouts.to_layout_attr(fa.WGMMA_LAYOUT) + wgmma_row_layout = layouts.to_layout_attr(fa.WGMMA_ROW_LAYOUT) + ref = mgpu.dialect.layout_cast(ref, wgmma_layout) + mgpu.dialect.layout_cast(ref, wgmma_row_layout) + + with self.assertRaisesRegex( + ValueError, "user-provided layout casts are unsatisfiable" + ): + mgpu.infer_layout(self.module) + + def test_infer_layout_for_tmem_alloc_by_default(self): + f32 = ir.F32Type.get() + i32 = ir.IntegerType.get_signless(32) + shape = (128, 128) + ptr_type = ir.MemRefType.get((1,), i32, memory_space=mgpu.utils.smem()) + ref_ty = ir.MemRefType.get(shape, f32, memory_space=mgpu.utils.tmem()) + + with ir.InsertionPoint(self.module.body): + ptr = llvm.mlir_undef(ptr_type) + op = mgpu.dialect.TmemAllocOp(result=ref_ty, smem_ptr=ptr) + + mgpu.infer_layout(self.module) + self.assertNotIn("in_tmem_layouts", op.attributes) + layout = tcgen05._infer_tmem_layout(shape, collective=False, packing=1) + self.checkOutTmemLayouts(op, [layout]) + + def test_infer_tmem_layout_cast_correctly(self): + f32 = ir.F32Type.get() + ref_ty = ir.MemRefType.get((128, 128), f32, memory_space=mgpu.utils.tmem()) + layout = layouts.to_layout_attr(mgpu.TMEM_NATIVE_LAYOUT) + + with ir.InsertionPoint(self.module.body): + ref = llvm.mlir_undef(ref_ty) + op = mgpu.dialect.TmemLayoutCastOp(ref, layout) + + mgpu.infer_layout(self.module) + self.checkInTmemLayouts(op, [layout]) + self.checkOutTmemLayouts(op, [layout]) + + def test_cant_relayout_tmem(self): + f32 = ir.F32Type.get() + ref_ty = ir.MemRefType.get((128, 128), f32, memory_space=mgpu.utils.tmem()) + + with ir.InsertionPoint(self.module.body): + ref = llvm.mlir_undef(ref_ty) + layout = tcgen05.tmem_default_layout(packing=1) + ref = mgpu.dialect.tmem_layout_cast(ref, layouts.to_layout_attr(layout)) + layout = tcgen05.tmem_half_lane_layout(columns=128, packing=1) + mgpu.dialect.tmem_layout_cast(ref, layouts.to_layout_attr(layout)) + + with self.assertRaisesRegex( + ValueError, "Failed to infer a possible set of layouts." + ): + mgpu.infer_layout(self.module) + + def test_infer_tmem_alloc_layout_correctly(self): + f32 = ir.F32Type.get() + i32 = ir.IntegerType.get_signless(32) + ptr_type = ir.MemRefType.get((1,), i32, memory_space=mgpu.utils.smem()) + ref_ty = ir.MemRefType.get((128, 128), f32, memory_space=mgpu.utils.tmem()) + layout = layouts.to_layout_attr(mgpu.TMEM_NATIVE_LAYOUT) + + with ir.InsertionPoint(self.module.body): + ptr = llvm.mlir_undef(ptr_type) + op = mgpu.dialect.TmemAllocOp(ref_ty, ptr) + mgpu.dialect.tmem_layout_cast(op.result, layout) + + mgpu.infer_layout(self.module) + self.assertNotIn("in_tmem_layouts", op.attributes) + self.checkOutTmemLayouts(op, [layout]) + + def test_tmem_dealloc_propagates_producer_layout(self): + f32 = ir.F32Type.get() + ref_ty = ir.MemRefType.get((128, 128), f32, memory_space=mgpu.utils.tmem()) + layout = layouts.to_layout_attr(mgpu.TMEM_NATIVE_LAYOUT) + + with ir.InsertionPoint(self.module.body): + ref = llvm.mlir_undef(ref_ty) + ref = mgpu.dialect.tmem_layout_cast(ref, layout) + op = mgpu.dialect.TmemDeallocOp(ref) + + mgpu.infer_layout(self.module) + self.checkInTmemLayouts(op, [layout]) + self.assertNotIn("out_tmem_layouts", op.attributes) + + def test_infer_async_load_chooses_in_tmem_layouts_compatible_with_register_layout(self): + f32 = ir.F32Type.get() + shape = (128, 128) + ref_type = ir.MemRefType.get(shape, f32, memory_space=mgpu.utils.tmem()) + out_layout = layouts.to_layout_attr(fa.TCGEN05_LAYOUT) + + with ir.InsertionPoint(self.module.body): + [ref] = undefs(ref_type) + op = mgpu.dialect.AsyncLoadTmemOp(ref) + mgpu.dialect.layout_cast(op.result, out_layout) + + mgpu.infer_layout(self.module) + in_layout = tcgen05.tmem_default_layout(packing=1) + in_layout = layouts.to_layout_attr(in_layout) + self.checkInTmemLayouts(op, [in_layout]) + self.checkOutLayouts(op, [out_layout]) + + def test_infer_async_load_chooses_out_layouts_compatible_with_tmem_layout(self): + f32 = ir.F32Type.get() + shape = (128, 128) + ref_type = ir.MemRefType.get(shape, f32, memory_space=mgpu.utils.tmem()) + in_layout = tcgen05.tmem_default_layout(packing=1) + in_layout = layouts.to_layout_attr(in_layout) + + with ir.InsertionPoint(self.module.body): + [ref] = undefs(ref_type) + ref = mgpu.dialect.tmem_layout_cast(ref, in_layout) + op = mgpu.dialect.AsyncLoadTmemOp(ref) + + mgpu.infer_layout(self.module) + self.checkInTmemLayouts(op, [in_layout]) + out_layout = layouts.to_layout_attr(fa.TCGEN05_LAYOUT) + self.checkOutLayouts(op, [out_layout]) + + @parameterized.parameters( + mtu.RegisterLayout.TCGEN05, mtu.RegisterLayout.TCGEN05_TMEM_NATIVE + ) + def test_async_load_tmem_accepts_expected_in_out_layouts(self, out_layout): + f32 = ir.F32Type.get() + shape = (128, 128) + ref_type = ir.MemRefType.get(shape, f32, memory_space=mgpu.utils.tmem()) + in_layout = tcgen05.tmem_default_layout(packing=1) + in_layout = layouts.to_layout_attr(in_layout) + out_layout = out_layout.to_layout_attr(shape, f32) + + with ir.InsertionPoint(self.module.body): + [ref] = undefs(ref_type) + ref = mgpu.dialect.tmem_layout_cast(ref, in_layout) + op = mgpu.dialect.AsyncLoadTmemOp(ref) + mgpu.dialect.layout_cast(op.result, out_layout) + + mgpu.infer_layout(self.module) + self.checkInTmemLayouts(op, [in_layout]) + self.checkOutLayouts(op, [out_layout]) + + def test_async_load_tmem_rejects_incompatible_in_out_layouts(self): + f32 = ir.F32Type.get() + shape = (128, 128) + ref_type = ir.MemRefType.get(shape, f32, memory_space=mgpu.utils.tmem()) + in_layout = tcgen05.tmem_half_lane_layout(columns=shape[1], packing=1) + in_layout = layouts.to_layout_attr(in_layout) + out_layout = layouts.to_layout_attr(fa.TCGEN05_LAYOUT) + + with ir.InsertionPoint(self.module.body): + [ref] = undefs(ref_type) + ref = mgpu.dialect.tmem_layout_cast(ref, in_layout) + op = mgpu.dialect.AsyncLoadTmemOp(ref) + mgpu.dialect.layout_cast(op.result, out_layout) + + with self.assertRaisesRegex( + ValueError, "Failed to infer a possible set of layouts." + ): + mgpu.infer_layout(self.module) + + @parameterized.parameters( + mtu.RegisterLayout.TCGEN05, mtu.RegisterLayout.TCGEN05_TMEM_NATIVE + ) + def test_async_store_tmem_accepts_expected_src_dest_layouts( + self, src_layout + ): + f32 = ir.F32Type.get() + shape = (128, 128) + dest_type = ir.MemRefType.get(shape, f32, memory_space=mgpu.utils.tmem()) + src_type = ir.VectorType.get(shape, f32) + src_layout = src_layout.to_layout_attr(shape, f32) + dest_layout = tcgen05.tmem_default_layout(packing=1) + dest_layout = layouts.to_layout_attr(dest_layout) + + with ir.InsertionPoint(self.module.body): + [src, dest] = undefs(src_type, dest_type) + src = mgpu.dialect.layout_cast(src, src_layout) + dest = mgpu.dialect.tmem_layout_cast(dest, dest_layout) + op = mgpu.dialect.AsyncStoreTmemOp(src, dest) + + mgpu.infer_layout(self.module) + self.checkInLayouts(op, [src_layout]) + self.checkInTmemLayouts(op, [dest_layout]) + + def test_layout_inference_gelu_does_not_timeout(self): + # This test is intended to make sure that the constraint-based layout + # inference does not timeout on a Gelu kernel. This was previously the case, + # and we want to make sure that regressions don't happen. + with ir.InsertionPoint(self.module.body): + shape = (128,) + f32 = ir.F32Type.get() + vector_ty = ir.VectorType.get(shape, f32) + memref_ty = ir.MemRefType.get(shape, f32) + + # The code below is essentially jax.nn.gelu(). + c_05 = arith.constant(vector_ty, ir.DenseElementsAttr.get_splat(vector_ty, ir.FloatAttr.get(f32, 0.5))) + c_1 = arith.constant(vector_ty, ir.DenseElementsAttr.get_splat(vector_ty, ir.FloatAttr.get(f32, 1.0))) + c_079 = arith.constant(vector_ty, ir.DenseElementsAttr.get_splat(vector_ty, ir.FloatAttr.get(f32, 0.797884583))) + c_044 = arith.constant(vector_ty, ir.DenseElementsAttr.get_splat(vector_ty, ir.FloatAttr.get(f32, 0.044715))) + + memref = llvm.mlir_undef(memref_ty) + load = mgpu.dialect.VectorLoadOp(memref) + x = load.result + x2 = arith.mulf(x, x) + x3 = arith.mulf(x2, x) + y = arith.mulf(x3, c_044) + x_y = arith.addf(x, y) + z = arith.mulf(x_y, c_079) + t = math_dialect.tanh(z) + u = arith.addf(t, c_1) + v = arith.mulf(u, c_05) + r = arith.mulf(x, v) + store = mgpu.dialect.VectorStoreOp(r, memref) + + mgpu.infer_layout(self.module) + + strided_layout = layouts.to_layout_attr(mgpu.WGStridedFragLayout(shape, 1)) + self.checkOutLayouts(load, [strided_layout]) + self.checkInLayouts(store, [strided_layout]) + + @parameterized.parameters( + ((32, 256), ir.BF16Type, False, None, 16), + ((32, 256), ir.BF16Type, False, (2, 64), 128), + ((32, 256), ir.BF16Type, False, (2, 32), 64), + ((32, 256), ir.BF16Type, False, (2, 16), 32), + ((32, 256), ir.BF16Type, False, (2, 8), 16), + ((5, 32, 256), ir.BF16Type, False, (2, 64), 128), + ((5, 32, 256), ir.BF16Type, False, (2, 16), 32), + ((3, 32, 256), ir.Float8E4M3FNType, False, (2, 128), 128), + ((3, 32, 256), ir.Float8E4M3FNType, False, (2, 64), 64), + ((3, 32, 256), ir.Float8E4M3FNType, False, (2, 32), 32), + ((3, 32, 256), ir.Float8E4M3FNType, False, (2, 16), 16), + ((3, 32, 256), ir.BF16Type, True, (16, 32), 32), + ((3, 32, 256), ir.BF16Type, False, (64,), 128), + ((256,), ir.BF16Type, False, (2, 2), None), + ) + def test_compute_swizzle(self, shape, type, transposed, tiling, want_swizzle): + with ir.InsertionPoint(self.module.body): + ref_ty = ir.MemRefType.get(shape, type.get()) + if transposed: + strides, offset = ref_ty.get_strides_and_offset() + strides[-1], strides[-2] = strides[-2], strides[-1] + layout = ir.StridedLayoutAttr.get(offset, strides) + ref_ty = ir.MemRefType.get(shape, type.get(), layout) + + tile_transform = None if tiling is None else lc.TileTransform(tiling) + + if want_swizzle is None: + with self.assertRaises(ValueError): + layout_inference._compute_swizzle(ref_ty, tile_transform) + else: + swizzle = layout_inference._compute_swizzle(ref_ty, tile_transform) + self.assertEqual(swizzle, mgpu.dialect.SwizzlingMode(want_swizzle)) + + @parameterized.parameters([False, True]) + def test_conjure_smem_assignment_from_is_transferrable(self, transposed): + # Create a var to use in the constraint system. + shape = (128, 128) + f32 = ir.F32Type.get() + layout = ir.StridedLayoutAttr.get(0, [1, 128]) if transposed else None + ref_ty = ir.MemRefType.get(shape, f32, layout=layout, memory_space=mgpu.utils.smem()) + [ref] = undefs(ref_ty) + value_site = layout_inference.ValueSite( + operation=ref.owner, + type=layout_inference.VariableType.RESULT, + index=0, + ) + var = cs.Variable(value_site) + + def conjure(constraints) -> list[tuple[cs.Variable, cs.Constant]]: + system = cs.ConstraintSystem(constraints=constraints) + return list(layout_inference.conjure_assignment({var}, system)) + + # Yield only empty tiling with no constraints. + with self.subTest("no_constraints_yield_empty_tiling"): + self.assertEqual(conjure([]), [(var, cs.SMEMTiling(None))]) + + # Yield empty if not an mma layout. + with self.subTest("not_mma_layout_yield_empty_tiling"): + layout = cs.RegisterLayout(fa.WGSplatFragLayout(shape)) + constraints = [cs.IsTransferable(layout, var, (128, 128))] + conjured = conjure(constraints) + self.assertEqual(conjured, [(var, cs.SMEMTiling(None))]) + + wgmma_layout = cs.RegisterLayout(fa.WGMMA_LAYOUT) + + # Yield also maximal tiling with no Divides constraints. + with self.subTest("no_divides_constraints_yield_maximal_tiling_with_mma"): + constraints = [cs.IsTransferable(wgmma_layout, var, (128, 128))] + conjured = conjure(constraints) + if transposed: + expected_tiling = (32, 8) + else: + expected_tiling = (8, 32) + self.assertEqual( + conjured, + [ + (var, cs.SMEMTiling(lc.TileTransform(expected_tiling))), + (var, cs.SMEMTiling(None)), + ], + ) - f.attributes["in_layouts"] = ir.ArrayAttr.get( - [non_splat_layout, splat_layout] + # Yield also valid tiling with Divides constraints. + with self.subTest("divides_constraints_yield_valid_tiling"): + constraints = [ + cs.IsTransferable(wgmma_layout, var, (128, 128)), + cs.Divides(var, (32, 16)), + ] + conjured = conjure(constraints) + if transposed: + expected_tiling = (32, 8) + else: + expected_tiling = (8, 16) + self.assertEqual( + conjured, + [ + (var, cs.SMEMTiling(lc.TileTransform(expected_tiling))), + (var, cs.SMEMTiling(None)), + ], + ) + + def test_conjure_tries_high_priority_assignments_first(self): + shape = (128, 128) + f32 = ir.F32Type.get() + [val] = undefs(ir.VectorType.get(shape, f32)) + value_site = layout_inference.ValueSite( + operation=val.owner, + type=layout_inference.VariableType.RESULT, + index=0, ) + var = cs.Variable(value_site) + bitwidth = mgpu.utils.bitwidth(f32) + + constraints = [ + cs.Relayout( + var, + cs.RegisterLayout(fa.WGSplatFragLayout((128, 128))), + bitwidth, + ), + cs.Relayout( + var, + cs.RegisterLayout(fa.WGMMA_LAYOUT), + bitwidth, + ), + cs.Relayout( + var, + cs.RegisterLayout(fa.WGStridedFragLayout(shape, vec_size=4)), + bitwidth, + ), + ] + + system = cs.ConstraintSystem(constraints=constraints) + ordered = list(layout_inference.conjure_assignment({var}, system)) + expected = [ + (var, cs.RegisterLayout(fa.WGMMA_LAYOUT)), + (var, cs.RegisterLayout(fa.WGSplatFragLayout((128, 128)))), + (var, cs.RegisterLayout(fa.WGStridedFragLayout(shape, vec_size=4))), + (var, cs.RegisterLayout(fa.WGStridedFragLayout(shape, vec_size=2))), + ] + self.assertEqual(ordered, expected) + + def test_memref_load_store_op_transforms_are_empty(self): + with ir.InsertionPoint(self.module.body): + i32 = ir.IntegerType.get_signless(32) + ref_ty = ir.MemRefType.get((), i32, memory_space=mgpu.utils.smem()) + + [val, load_ref, store_ref] = undefs(i32, ref_ty, ref_ty) + load_op = memref.LoadOp(load_ref, []) + store_op = memref.StoreOp(val, store_ref, []) + + mgpu.infer_layout(self.module) + + want = ir.ArrayAttr.get([ir.ArrayAttr.get([])]) + self.assertEqual(inference_utils.in_transforms(load_op), want) + self.assertEqual(inference_utils.in_transforms(store_op), want) + + @parameterized.product( + swizzle=tuple(mgpu.dialect.SwizzlingMode), + dtype=(jnp.bfloat16, jnp.float32), + lhs_in_registers=(False, True), + ) + def test_infer_transforms_for_wgmma_op(self, swizzle, dtype, lhs_in_registers): + if swizzle == mgpu.dialect.SwizzlingMode.kNoSwizzle: + self.skipTest("kNoSwizzle is not supported by this test.") + + swizzle_elems = swizzle // np.dtype(dtype).itemsize + m = 64 + # Note: `group_m` and `group_k` should be coprime with 2 for the test to be + # correct. Otherwise, we may infer larger swizzles than the test intends to + # check. + group_m, group_k = 3, 3 + lhs_shape = (group_m * m, group_k * swizzle_elems) + rhs_shape = (group_k * swizzle_elems, group_k * swizzle_elems) + out_shape = (group_m * m, group_k * swizzle_elems) + + with ir.InsertionPoint(self.module.body): + elt_ty = mgpu.utils.dtype_to_ir_type(dtype) + lhs_ref_ty = ir.MemRefType.get(lhs_shape, elt_ty, memory_space=mgpu.utils.smem()) + lhs_vec_ty = ir.VectorType.get(lhs_shape, elt_ty) + lhs_ty = lhs_vec_ty if lhs_in_registers else lhs_ref_ty + rhs_ty = ir.MemRefType.get(rhs_shape, elt_ty, memory_space=mgpu.utils.smem()) + acc_ty = ir.VectorType.get(out_shape, elt_ty) + [acc, lhs, rhs] = undefs(acc_ty, lhs_ty, rhs_ty) + wgmma_op = mgpu.dialect.WGMMAOp(acc, lhs, rhs) mgpu.infer_layout(self.module) + wgmma_layout = layouts.to_layout_attr(mgpu.WGMMA_LAYOUT) + arg_transforms = ir.ArrayAttr.get([ + mgpu.dialect.TileTransformAttr.get((8, swizzle_elems)), + mgpu.dialect.SwizzleTransformAttr.get(int(swizzle)), + ]) + + in_layouts = [wgmma_layout] + out_layouts = [wgmma_layout] + in_transforms = [arg_transforms] + if lhs_in_registers: + in_layouts.append(wgmma_layout) + else: + in_transforms.append(arg_transforms) + + self.checkInLayouts(wgmma_op, in_layouts) + self.checkOutLayouts(wgmma_op, out_layouts) self.assertSequenceEqual( - add.attributes["in_layouts"], - [non_splat_layout, non_splat_layout], + inference_utils.in_transforms(wgmma_op), in_transforms ) - self.assertSequenceEqual(add.attributes["out_layouts"], [non_splat_layout]) - def test_infer_layout_preserves_splat_layouts_in_producers(self): - add0 = add1 = None + @parameterized.product( + dtype=(jnp.int8, jnp.uint8), + lhs_in_registers=(False, True), + ) + def test_infer_layouts_for_8bits_wgmma_op(self, dtype, lhs_in_registers): + shape = (128, 128) + with ir.InsertionPoint(self.module.body): + elt_ty = mgpu.utils.dtype_to_ir_type(dtype) + lhs_ref_ty = ir.MemRefType.get( + shape, elt_ty, memory_space=mgpu.utils.smem() + ) + lhs_vec_ty = ir.VectorType.get(shape, elt_ty) + lhs_ty = lhs_vec_ty if lhs_in_registers else lhs_ref_ty + rhs_ty = ir.MemRefType.get(shape, elt_ty, memory_space=mgpu.utils.smem()) + acc_ty = ir.VectorType.get(shape, elt_ty) + [acc, lhs, rhs] = undefs(acc_ty, lhs_ty, rhs_ty) + wgmma_op = mgpu.dialect.WGMMAOp(acc, lhs, rhs) - def body(lhs, rhs): - nonlocal add0, add1 - add0 = arith.AddFOp(lhs, rhs) - add1 = arith.AddFOp(add0.result, add0) + mgpu.infer_layout(self.module) + + if lhs_in_registers: + self.checkInLayouts(wgmma_op, [mgpu.WGMMA_LAYOUT, mgpu.WGMMA_LAYOUT_8BIT]) + else: + self.checkInLayouts(wgmma_op, [mgpu.WGMMA_LAYOUT]) + self.checkOutLayouts(wgmma_op, [mgpu.WGMMA_LAYOUT]) + + @parameterized.product( + swizzle_lhs=tuple(mgpu.dialect.SwizzlingMode), + swizzle_rhs=tuple(mgpu.dialect.SwizzlingMode), + dtype=(jnp.bfloat16, jnp.float32), + lhs_in_tmem=(False, True), + ) + def test_infer_transforms_for_tcgen05_mma_op( + self, swizzle_lhs, swizzle_rhs, dtype, lhs_in_tmem + ): + if mgpu.dialect.SwizzlingMode.kNoSwizzle in (swizzle_lhs, swizzle_rhs): + self.skipTest("kNoSwizzle is not supported by this test.") + + swizzle_elems_lhs = swizzle_lhs // np.dtype(dtype).itemsize + swizzle_elems_rhs = swizzle_rhs // np.dtype(dtype).itemsize + m = 128 + # Note: `group_m` and `group_k` should be coprime with 2 for the test to be + # correct. Otherwise, we may infer larger swizzles than the test intends to + # check. + group_k, group_n = 3, 5 + lhs_shape = (m, group_k * swizzle_elems_lhs) + rhs_shape = (group_k * swizzle_elems_lhs, group_n * swizzle_elems_rhs) + out_shape = (m, group_n * swizzle_elems_rhs) with ir.InsertionPoint(self.module.body): - shape = (32, 4) - elt_type = ir.BF16Type.get() - ty = ir.VectorType.get(shape, elt_type) - f = func.FuncOp.from_py_func(ty, ty)(body).func_op + elt_ty = mgpu.utils.dtype_to_ir_type(dtype) + lhs_mem_space = mgpu.utils.tmem() if lhs_in_tmem else mgpu.utils.smem() + lhs_ty = ir.MemRefType.get(lhs_shape, elt_ty, memory_space=lhs_mem_space) + rhs_ty = ir.MemRefType.get(rhs_shape, elt_ty, memory_space=mgpu.utils.smem()) + acc_ty = ir.MemRefType.get(out_shape, elt_ty, memory_space=mgpu.utils.tmem()) + [acc, lhs, rhs] = undefs(acc_ty, lhs_ty, rhs_ty) + accumulate = arith.constant(ir.IntegerType.get_signless(1), 1) + tcgen05_mma_op = mgpu.dialect.TcGen05MMAOp(acc, lhs, rhs, accumulate) - splat_layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape)) - strided_layout = layouts.to_layout_attr( - mgpu.WGStridedFragLayout(shape, vec_size=1) + mgpu.infer_layout(self.module) + + self.assertNotIn("out_tmem_layouts", tcgen05_mma_op.attributes) + acc_layout = tcgen05._infer_tmem_layout(out_shape, collective=False, packing=1) + a_packing = 2 if dtype == jnp.bfloat16 else 1 + a_layout = tcgen05._infer_tmem_layout(lhs_shape, collective=False, packing=a_packing) + expected_layouts = [acc_layout, a_layout] if lhs_in_tmem else [acc_layout] + self.checkInTmemLayouts(tcgen05_mma_op, expected_layouts) + + arg_transforms_lhs = ir.ArrayAttr.get([ + mgpu.dialect.TileTransformAttr.get((8, swizzle_elems_lhs)), + mgpu.dialect.SwizzleTransformAttr.get(int(swizzle_lhs)), + ]) + arg_transforms_rhs = ir.ArrayAttr.get([ + mgpu.dialect.TileTransformAttr.get((8, swizzle_elems_rhs)), + mgpu.dialect.SwizzleTransformAttr.get(int(swizzle_rhs)), + ]) + if lhs_in_tmem: + transforms = [arg_transforms_rhs] + else: + transforms = [arg_transforms_lhs, arg_transforms_rhs] + + self.assertSequenceEqual( + inference_utils.in_transforms(tcgen05_mma_op), transforms + ) + + def test_infer_correct_swizzle_for_tcgen05_mma_op_with_m64(self): + with ir.InsertionPoint(self.module.body): + dtype = ir.IntegerType.get_signless(8) + shape = (64, 64) + lhs_ty = ir.MemRefType.get(shape, dtype, memory_space=mgpu.utils.smem()) + rhs_ty = ir.MemRefType.get(shape, dtype, memory_space=mgpu.utils.smem()) + acc_ty = ir.MemRefType.get(shape, dtype, memory_space=mgpu.utils.tmem()) + [acc, lhs, rhs] = undefs(acc_ty, lhs_ty, rhs_ty) + accumulate = arith.constant(ir.IntegerType.get_signless(1), 1) + op = mgpu.dialect.TcGen05MMAOp(acc, lhs, rhs, accumulate) + + mgpu.infer_layout(self.module) + lhs_transforms = ir.ArrayAttr.get([ + mgpu.dialect.TileTransformAttr.get((8, 64)), + mgpu.dialect.SwizzleTransformAttr.get(64), + ]) + rhs_transforms = ir.ArrayAttr.get([ + mgpu.dialect.TileTransformAttr.get((8, 32)), + mgpu.dialect.SwizzleTransformAttr.get(32), + ]) + self.assertSequenceEqual( + inference_utils.in_transforms(op), [lhs_transforms, rhs_transforms] + ) + + @parameterized.parameters(mgpu.dialect.AsyncLoadOp, mgpu.dialect.AsyncStoreOp) + def test_infer_transforms_for_async_load_store_works_on_ok_input(self, op_type): + # OK input means that the indices are a multiple of the tile size. + shape = (64, 64) + elt_ty = ir.BF16Type.get() + + with ir.InsertionPoint(self.module.body): + gmem_ty = ir.MemRefType.get(shape, elt_ty) + smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=mgpu.utils.smem()) + barrier_ty = ir.Type.parse("!mosaic_gpu.barrier") + gmem_ref, smem_ref, barrier = undefs(gmem_ty, smem_ty, barrier_ty) + + transforms = ir.ArrayAttr.get([ + mgpu.dialect.TileTransformAttr.get((8, 32)), + mgpu.dialect.SwizzleTransformAttr.get(64), + ]) + zero = arith.constant(ir.IntegerType.get_signless(32), 0) + smem_ref = mgpu.dialect.with_transforms(smem_ref, transforms) + if op_type == mgpu.dialect.AsyncLoadOp: + op = mgpu.dialect.AsyncLoadOp( + source=gmem_ref, + destination=smem_ref, + barrier=barrier, + indices=[zero, zero], + slice_lengths=shape, + collective=ir.ArrayAttr.get([]), + ) + else: + op = mgpu.dialect.AsyncStoreOp( + source=smem_ref, + destination=gmem_ref, + indices=[zero, zero], + slice_lengths=shape, + ) + + mgpu.infer_layout(self.module) + + self.assertSequenceEqual( + inference_utils.in_transforms(op), [transforms] ) - f.attributes["in_layouts"] = ir.ArrayAttr.get([splat_layout, splat_layout]) - add1.attributes["out_layouts"] = ir.ArrayAttr.get([strided_layout]) + + @parameterized.parameters(mgpu.dialect.AsyncLoadOp, mgpu.dialect.AsyncStoreOp) + def test_infer_transforms_for_async_load_store_raises_on_unaligned_tiles(self, op_type): + shape = (64, 64) + elt_ty = ir.BF16Type.get() + + with ir.InsertionPoint(self.module.body): + gmem_ty = ir.MemRefType.get(shape, elt_ty) + smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=mgpu.utils.smem()) + barrier_ty = ir.Type.parse("!mosaic_gpu.barrier") + gmem_ref, smem_ref, barrier = undefs(gmem_ty, smem_ty, barrier_ty) + + transforms = ir.ArrayAttr.get( + [mgpu.dialect.TileTransformAttr.get((8, 32))] + ) + one = arith.constant(ir.IntegerType.get_signless(32), 1) + smem_ref = mgpu.dialect.with_transforms(smem_ref, transforms) + if op_type == mgpu.dialect.AsyncLoadOp: + mgpu.dialect.AsyncLoadOp( + source=gmem_ref, + destination=smem_ref, + barrier=barrier, + indices=[one, one], + slice_lengths=shape, + collective=ir.ArrayAttr.get([]), + ) + else: + mgpu.dialect.AsyncStoreOp( + source=smem_ref, + destination=gmem_ref, + indices=[one, one], + slice_lengths=shape, + ) + + with self.assertRaisesRegex(ValueError, "Failed to infer"): + mgpu.infer_layout(self.module) + + @parameterized.parameters(*mtu.RegisterLayout) + def test_infer_transforms_for_vector_load_op(self, layout): + if layout == mtu.RegisterLayout.WG_SPLAT: + self.skipTest("WG_SPLAT is not supported for `vector_load`.") + + shape = (128, 128) + elt_ty = ir.BF16Type.get() + layout = layout.to_mgpu(shape, elt_ty) + + with ir.InsertionPoint(self.module.body): + smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=mgpu.utils.smem()) + [smem_ref] = undefs(smem_ty) + op = mgpu.dialect.VectorLoadOp(smem_ref) + layout_cast(op.result, layout) + + if inference_utils.is_mma_layout(layout): + expected_transforms = ir.ArrayAttr.get([ + mgpu.dialect.TileTransformAttr.get((8, 64)), + mgpu.dialect.SwizzleTransformAttr.get(128), + ]) + else: + expected_transforms = ir.ArrayAttr.get([]) + mgpu.infer_layout(self.module) + self.assertSequenceEqual( + inference_utils.in_transforms(op), [expected_transforms] + ) + + @parameterized.parameters(*mtu.RegisterLayout) + def test_infer_transforms_for_vector_store_op(self, layout): + shape = (128, 128) + elt_ty = ir.BF16Type.get() + layout = layout.to_mgpu(shape, elt_ty) + with ir.InsertionPoint(self.module.body): + smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=mgpu.utils.smem()) + value_ty = ir.VectorType.get(shape, elt_ty) + [smem_ref, value_to_store] = undefs(smem_ty, value_ty) + value_to_store = layout_cast(value_to_store, layout) + op = mgpu.dialect.VectorStoreOp(value_to_store, smem_ref) + + if inference_utils.is_mma_layout(layout): + expected_transforms = ir.ArrayAttr.get([ + mgpu.dialect.TileTransformAttr.get((8, 64)), + mgpu.dialect.SwizzleTransformAttr.get(128), + ]) + else: + expected_transforms = ir.ArrayAttr.get([]) + + mgpu.infer_layout(self.module) self.assertSequenceEqual( - add0.attributes["in_layouts"], [splat_layout, splat_layout] + inference_utils.in_transforms(op), [expected_transforms] ) + + def test_slice_smem_gets_empty_by_default(self): + with ir.InsertionPoint(self.module.body): + shape = (64, 64) + elt_ty = ir.BF16Type.get() + i32 = ir.IntegerType.get_signless(32) + [offset] = undefs(i32) + ref_ty = ir.MemRefType.get(shape, elt_ty, memory_space=mgpu.utils.smem()) + slice_smem_op = mgpu.dialect.SliceSMEMOp(ref_ty, offset) + + transforms = ir.ArrayAttr.get([]) + mgpu.infer_layout(self.module) + self.assertSequenceEqual( + inference_utils.out_transforms(slice_smem_op), [transforms] + ) + + def test_infer_transforms_preserves_with_transforms_requirements(self): + shape = (64, 64) + elt_ty = ir.BF16Type.get() + + with ir.InsertionPoint(self.module.body): + ref_ty = ir.MemRefType.get(shape, elt_ty, memory_space=mgpu.utils.smem()) + [ref] = undefs(ref_ty) + + transforms = ir.ArrayAttr.get([ + mgpu.dialect.TileTransformAttr.get((8, 64)), + mgpu.dialect.SwizzleTransformAttr.get(128), + ]) + mgpu.dialect.with_transforms(ref, transforms) + + mgpu.infer_layout(self.module) self.assertSequenceEqual( - add1.attributes["in_layouts"], [strided_layout, strided_layout] + inference_utils.out_transforms(ref.owner), [transforms] ) - self.assertSequenceEqual(add0.attributes["out_layouts"], [splat_layout]) - self.assertSequenceEqual(add1.attributes["out_layouts"], [strided_layout]) + def test_infer_transforms_fails_on_conflicting_with_transforms_requirements(self): + shape = (64, 64) + elt_ty = ir.BF16Type.get() - def test_infer_layout_propagates_func_layouts_to_ops(self): - add = None + with ir.InsertionPoint(self.module.body): + ref_ty = ir.MemRefType.get(shape, elt_ty, memory_space=mgpu.utils.smem()) + [ref] = undefs(ref_ty) + + transforms1 = ir.ArrayAttr.get([ + mgpu.dialect.TileTransformAttr.get((8, 64)), + mgpu.dialect.SwizzleTransformAttr.get(128), + ]) + transforms2 = ir.ArrayAttr.get([ + mgpu.dialect.TileTransformAttr.get((16, 64)), + mgpu.dialect.SwizzleTransformAttr.get(128), + ]) + mgpu.dialect.with_transforms(ref, transforms1) + mgpu.dialect.with_transforms(ref, transforms2) + + with self.assertRaisesRegex(ValueError, "Failed to infer"): + mgpu.infer_layout(self.module) + + def test_infer_transforms_sets_default_empty_transforms_on_async_load(self): + shape = (64, 64) + elt_ty = ir.BF16Type.get() - def body(lhs, rhs): - nonlocal add - add = arith.AddFOp(lhs, rhs) + with ir.InsertionPoint(self.module.body): + gmem_ty = ir.MemRefType.get(shape, elt_ty) + smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=mgpu.utils.smem()) + barrier_ty = ir.Type.parse("!mosaic_gpu.barrier") + [gmem_ref, smem_ref, barrier] = undefs(gmem_ty, smem_ty, barrier_ty) + + zero = arith.constant(ir.IntegerType.get_signless(32), 0) + async_load_op = mgpu.dialect.AsyncLoadOp( + source=gmem_ref, + destination=smem_ref, + barrier=barrier, + indices=[zero, zero], + slice_lengths=shape, + collective=ir.ArrayAttr.get([]), + ) + mgpu.infer_layout(self.module) + [in_transform] = inference_utils.in_transforms(async_load_op) + self.assertSequenceEqual(in_transform, ir.ArrayAttr.get([])) + + @parameterized.parameters([False, True]) + def test_infer_transforms_for_memref_cast_op(self, annotate_producer): with ir.InsertionPoint(self.module.body): - shape = (32, 4) - ty = ir.VectorType.get(shape, ir.BF16Type.get()) - f = func.FuncOp.from_py_func(ty, ty)(body).func_op + shape = (64, 64) + elt_ty = ir.BF16Type.get() + ref_ty = ir.MemRefType.get(shape, elt_ty, memory_space=mgpu.utils.smem()) + [ref] = undefs(ref_ty) + + transforms = ir.ArrayAttr.get([ + mgpu.dialect.TileTransformAttr.get((8, 64)), + mgpu.dialect.SwizzleTransformAttr.get(128), + ]) + + if annotate_producer: + ref = mgpu.dialect.with_transforms(ref, transforms) + cast = memref.cast(ref_ty, ref) + if not annotate_producer: + mgpu.dialect.with_transforms(cast, transforms) + + mgpu.infer_layout(self.module) + self.assertSequenceEqual( + inference_utils.in_transforms(cast.owner), [transforms] + ) + self.assertSequenceEqual( + inference_utils.out_transforms(cast.owner), [transforms] + ) + + @parameterized.parameters([False, True]) + def test_infer_transforms_for_subview_raises_on_slice_incompatible_with_tile( + self, annotate_input + ): + with ir.InsertionPoint(self.module.body): + in_ref_ty = ir.MemRefType.get( + (2, 64, 64), ir.BF16Type.get(), memory_space=mgpu.utils.smem() + ) + [in_ref] = undefs(in_ref_ty) + + transforms = ir.ArrayAttr.get([ + mgpu.dialect.TileTransformAttr.get((32, 16)), + mgpu.dialect.SwizzleTransformAttr.get(32), + ]) + + if annotate_input: + in_ref = mgpu.dialect.with_transforms(in_ref, transforms) + + out_ref = memref.subview( + in_ref, offsets=[1, 0, 0], sizes=[2, 64, 8], strides=[1, 1, 1] + ) + + if not annotate_input: + mgpu.dialect.with_transforms(out_ref, transforms) + + with self.assertRaisesRegex(ValueError, "Failed to infer"): + mgpu.infer_layout(self.module) + + @parameterized.parameters([False, True]) + def test_infer_tmem_layouts_for_subview_raises_on_slice_incompatible_with_tile( + self, annotate_input + ): + with ir.InsertionPoint(self.module.body): + in_ref_ty = ir.MemRefType.get( + (128, 64), ir.BF16Type.get(), memory_space=mgpu.utils.tmem() + ) + [in_ref] = undefs(in_ref_ty) + + layout = tcgen05.tmem_default_layout(packing=1) + layout_attr = layouts.to_layout_attr(layout) + + if annotate_input: + in_ref = mgpu.dialect.tmem_layout_cast(in_ref, layout_attr) + + out_ref = memref.subview( + in_ref, offsets=[1, 0], sizes=[2, 64], strides=[1, 1] + ) + + if not annotate_input: + mgpu.dialect.tmem_layout_cast(out_ref, layout_attr) + + with self.assertRaisesRegex(ValueError, "Failed to infer"): + mgpu.infer_layout(self.module) + + @parameterized.parameters([False, True]) + def test_infer_transforms_for_sibling_subviews_and_distant_op( + self, even_offsets + ): + # This test uses the following op tree extracted from this ragged dot + # kernel: + # https://github.com/jax-ml/jax/blob/main/jax/experimental/pallas/ops/gpu/ragged_dot_mgpu.py + # + # subview_op0 (slice = 64, 64) + # - subview_op1 (slice = 2, 64) + # - subview_op2 (slice = 4, 64, either at an even or odd offset) + # - subview_op3 (slice = 8, 64) + # - user_op0 (in_transforms = [tile(64, 64), swizzle(32)]) + # + # First the in_transforms of user_op0 have to be propagated up to + # subview_op0. Then they have to be propagated down and resolved. Finally + # all subview ops need to have the same transforms. + + source_shape = (64, 64) + elt_ty = ir.BF16Type.get() + source_ref_ty = ir.MemRefType.get(source_shape, elt_ty, memory_space=mgpu.utils.smem()) + + slice1_shape = (2, 64) + slice2_shape = (4, 64) + slice3_shape = (8, 64) + + slice0_ref_ty = ir.MemRefType.get(source_shape, elt_ty, memory_space=mgpu.utils.smem()) + slice1_ref_ty = ir.MemRefType.get(slice1_shape, elt_ty, memory_space=mgpu.utils.smem()) + slice2_ref_ty = ir.MemRefType.get(slice2_shape, elt_ty, memory_space=mgpu.utils.smem()) + slice3_ref_ty = ir.MemRefType.get(slice3_shape, elt_ty, memory_space=mgpu.utils.smem()) + + want_tt = mgpu.dialect.TileTransformAttr.get((2 if even_offsets else 1, 64)) + + with ir.InsertionPoint(self.module.body): + [source_ref] = undefs(source_ref_ty) + subview_op0 = memref.SubViewOp( + slice0_ref_ty, + source_ref, + [], # dynamic offsets + [], # dynamic sizes + [], # dynamic strides + static_offsets=[0, 0], + static_sizes=source_shape, + static_strides=[1, 1], + ) + + transforms_0 = ir.ArrayAttr.get([want_tt]) + mgpu.dialect.WithTransformsOp(subview_op0.result, transforms_0) + + subview_op1 = memref.SubViewOp( + slice1_ref_ty, + subview_op0, + [], # dynamic offsets + [], # dynamic sizes + [], # dynamic strides + static_offsets=[0, 0], + static_sizes=slice1_shape, + static_strides=[1, 1], + ) + + subview_op2 = memref.SubViewOp( + slice2_ref_ty, + subview_op0, + [], # dynamic offsets + [], # dynamic sizes + [], # dynamic strides + static_offsets=[16 if even_offsets else 15, 0], + static_sizes=slice2_shape, + static_strides=[1, 1], + ) + + # The following ops are just to test the dynamic offsets support. + c = lambda x: arith.constant(ir.IntegerType.get_signless(32), x) + c64 = c(64) + c32 = c(32) + c16 = c(16) + subi = arith.subi(c64, c32) + maxsi = arith.maxsi(c16, subi) + addi = arith.addi(maxsi, subi) + andi = arith.andi(addi, maxsi) + idx = arith.index_cast(ir.IndexType.get(), andi) + subview_op3 = memref.SubViewOp( + slice3_ref_ty, + subview_op0, + [idx], # dynamic offsets + [], # dynamic sizes + [], # dynamic strides + static_offsets=[ir.ShapedType.get_dynamic_size(), 0], + static_sizes=slice3_shape, + static_strides=[1, 1], + ) - splat_layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape)) - f.attributes["in_layouts"] = ir.ArrayAttr.get([splat_layout, splat_layout]) mgpu.infer_layout(self.module) + want = ir.ArrayAttr.get([ + want_tt, + mgpu.dialect.SwizzleTransformAttr.get(128), + ]) + + self.assertSequenceEqual(inference_utils.out_transforms(source_ref.owner), [want]) + self.assertSequenceEqual(inference_utils.in_transforms(subview_op0), [want]) + self.assertSequenceEqual(inference_utils.out_transforms(subview_op0), [want]) + self.assertSequenceEqual(inference_utils.in_transforms(subview_op1), [want]) + self.assertSequenceEqual(inference_utils.out_transforms(subview_op1), [want]) + self.assertSequenceEqual(inference_utils.in_transforms(subview_op2), [want]) + self.assertSequenceEqual(inference_utils.out_transforms(subview_op2), [want]) + self.assertSequenceEqual(inference_utils.in_transforms(subview_op3), [want]) + self.assertSequenceEqual(inference_utils.out_transforms(subview_op3), [want]) + + @parameterized.parameters([False, True]) + def test_infer_transforms_for_subview_handles_dynamic_offsets( + self, annotate_input + ): + with ir.InsertionPoint(self.module.body): + in_ref_ty = ir.MemRefType.get( + (32, 32, 32, 32), ir.BF16Type.get(), memory_space=mgpu.utils.smem() + ) + [in_ref] = undefs(in_ref_ty) + + transforms = ir.ArrayAttr.get([ + mgpu.dialect.TileTransformAttr.get((4, 8, 16)), + mgpu.dialect.SwizzleTransformAttr.get(32), + ]) + + if annotate_input: + in_ref = mgpu.dialect.with_transforms(in_ref, transforms) + + c = lambda x: arith.constant(ir.IntegerType.get_signless(32), x) + out_ref = memref.subview( + in_ref, + offsets=[c(16), c(4), arith.muli(c(8), c(3)), 0], + sizes=[16, 16, 32, 32], + strides=[1, 1, 1, 1], + ) + + if not annotate_input: + mgpu.dialect.with_transforms(out_ref, transforms) + + mgpu.infer_layout(self.module) + self.assertSequenceEqual( + inference_utils.in_transforms(out_ref.owner), [transforms] + ) self.assertSequenceEqual( - add.attributes["in_layouts"], [splat_layout, splat_layout]) - self.assertSequenceEqual(add.attributes["out_layouts"], [splat_layout]) + inference_utils.out_transforms(out_ref.owner), [transforms] + ) + + @parameterized.parameters([False, True]) + def test_infer_tmem_layouts_for_subview_handles_dynamic_offsets( + self, annotate_input + ): + with ir.InsertionPoint(self.module.body): + in_ref_ty = ir.MemRefType.get( + (128, 256), ir.BF16Type.get(), memory_space=mgpu.utils.tmem() + ) + [in_ref] = undefs(in_ref_ty) - def test_infer_layout_does_not_assign_default_layouts_to_func(self): + layout = tcgen05.tmem_default_layout(packing=1) + layout_attr = layouts.to_layout_attr(layout) - def body(lhs, rhs): - arith.AddFOp(lhs, rhs) + if annotate_input: + in_ref = mgpu.dialect.tmem_layout_cast(in_ref, layout_attr) + c = lambda x: arith.constant(ir.IntegerType.get_signless(32), x) + out_ref = memref.subview( + in_ref, + offsets=[c(0), arith.muli(c(16), c(4))], + sizes=[128, 128], + strides=[1, 1], + ) + + if not annotate_input: + mgpu.dialect.tmem_layout_cast(out_ref, layout_attr) + + mgpu.infer_layout(self.module) + self.checkInTmemLayouts(out_ref.owner, [layout]) + self.checkOutTmemLayouts(out_ref.owner, [layout]) + + def test_custom_primitive_op_retains_transforms(self): with ir.InsertionPoint(self.module.body): - shape = (32, 4) - ty = ir.VectorType.get(shape, ir.BF16Type.get()) - f = func.FuncOp.from_py_func(ty, ty)(body).func_op + transforms = ir.ArrayAttr.get([ + mgpu.dialect.TileTransformAttr.get((64, 64)), + mgpu.dialect.SwizzleTransformAttr.get(128), + ]) + ref_ty = ir.MemRefType.get( + (128, 128), ir.BF16Type.get(), memory_space=mgpu.utils.smem() + ) + [ref] = undefs(ref_ty) + op = mgpu.dialect.custom_primitive( + result=[], + operands_=[ref], + in_layouts=[], + in_transforms=[transforms], + out_layouts=[], + ) + + mgpu.infer_layout(self.module) + self.assertSequenceEqual(inference_utils.in_transforms(op), [transforms]) + + def test_custom_primitive_op_with_conflicting_transforms_is_unsat(self): + with ir.InsertionPoint(self.module.body): + transforms_a = ir.ArrayAttr.get([ + mgpu.dialect.TileTransformAttr.get((64, 64)), + ]) + transforms_b = ir.ArrayAttr.get([ + mgpu.dialect.TileTransformAttr.get((32, 32)), + ]) + ref_ty = ir.MemRefType.get( + (128, 128), ir.BF16Type.get(), memory_space=mgpu.utils.smem() + ) + [ref] = undefs(ref_ty) + mgpu.dialect.custom_primitive( + result=[], + operands_=[ref, ref], + in_layouts=[], + in_transforms=[transforms_a, transforms_b], + out_layouts=[], + ) + + with self.assertRaisesRegex(ValueError, "Failed to infer"): + mgpu.infer_layout(self.module) + + @parameterized.parameters([False, True]) + def test_infer_transforms_for_memref_transpose(self, annotate_input): + in_shape = (32, 64) + out_shape = (64, 32) + elt_ty = ir.BF16Type.get() + + in_ref_ty = ir.MemRefType.get( + in_shape, elt_ty, memory_space=mgpu.utils.smem() + ) + layout = ir.StridedLayoutAttr.get(0, strides=[1, 64]) + out_ref_ty = ir.MemRefType.get( + out_shape, elt_ty, layout=layout, memory_space=mgpu.utils.smem() + ) + + with ir.InsertionPoint(self.module.body): + [in_ref] = undefs(in_ref_ty) + + in_transforms = ir.ArrayAttr.get([ + mgpu.dialect.TileTransformAttr.get((8, 16)), + mgpu.dialect.SwizzleTransformAttr.get(32), + ]) + + if annotate_input: + in_ref = mgpu.dialect.with_transforms(in_ref, in_transforms) + + permutation = ir.AffineMap.get_permutation((1, 0)) + transpose_op = memref.TransposeOp(out_ref_ty, in_ref, permutation) + + out_transforms = ir.ArrayAttr.get([ + mgpu.dialect.TileTransformAttr.get((16, 8)), + mgpu.dialect.SwizzleTransformAttr.get(32), + ]) + + if not annotate_input: + mgpu.dialect.with_transforms(transpose_op.result, out_transforms) + + mgpu.infer_layout(self.module) + + self.assertSequenceEqual( + inference_utils.in_transforms(transpose_op), [in_transforms] + ) + self.assertSequenceEqual( + inference_utils.out_transforms(transpose_op), [out_transforms] + ) + + def test_default_strided_layout_assignment_is_deterministic(self): + with ir.InsertionPoint(self.module.body): + shape = (8, 128) + src_elt_ty = ir.IntegerType.get_signless(32) + dst_elt_ty = ir.IntegerType.get_signless(16) + src_ref_ty = ir.MemRefType.get(shape, src_elt_ty) + dst_ref_ty = ir.MemRefType.get(shape, dst_elt_ty) + src_ref, dst_ref = undefs(src_ref_ty, dst_ref_ty) + + # Make sure to have at least three ops such that the default assignment + # can pick a vector size from data types of various lengths. + src = mgpu.dialect.vector_load(src_ref) + conversion = arith.TruncIOp(ir.VectorType.get(shape, dst_elt_ty), src) + mgpu.dialect.vector_store(conversion.result, dst_ref) + + mgpu.infer_layout(self.module) + + # The default assignment should yield a strided layout here. The specific + # vector size does not matter to the test, but it is important that it is + # consistent between several runs of the test. If the logic changes such + # that another vector size is deterministically chosen, it is likely fine to + # edit this. + layout = fa.WGStridedFragLayout(shape, vec_size=2) + self.checkInLayouts(conversion, [layout]) + self.checkOutLayouts(conversion, [layout]) + + def test_infer_layout_for_vector_extract_strided_slice(self): + layout = layouts.to_layout_attr(fa.WGMMA_LAYOUT) + with ir.InsertionPoint(self.module.body): + i16 = ir.IntegerType.get_signless(16) + src_ty = ir.VectorType.get([128, 128], i16) + [src] = undefs(src_ty) + src = mgpu.dialect.layout_cast(src, layout) + dest_ty = ir.VectorType.get([64, 64], i16) + op = vector.ExtractStridedSliceOp(dest_ty, src, [0, 64], [64, 64], [1, 1]) + mgpu.infer_layout(self.module) + self.checkInLayouts(op, [layout]) + self.checkOutLayouts(op, [layout]) + + @parameterized.named_parameters( + ( + "tiled_layout_non_divisible_by_offset", + mtu.RegisterLayout.WGMMA, + [3, 64], + ), + ("strided_layout", mtu.RegisterLayout.WG_STRIDED, [0, 64]), + ("splat_layout", mtu.RegisterLayout.WG_SPLAT, [0, 64]), + ) + def test_infer_layout_for_vector_extract_strided_slice_fails( + self, layout, offsets + ): + with ir.InsertionPoint(self.module.body): + i16 = ir.IntegerType.get_signless(16) + src_ty = ir.VectorType.get([128, 128], i16) + [src] = undefs(src_ty) + layout_attr = layout.to_layout_attr(src_ty.shape, src_ty.element_type) + src = mgpu.dialect.layout_cast(src, layout_attr) + dest_ty = ir.VectorType.get([64, 64], i16) + vector.extract_strided_slice(dest_ty, src, offsets, [64, 64], [1, 1]) + with self.assertRaisesRegex( + ValueError, "Failed to infer a possible set of layouts." + ): + mgpu.infer_layout(self.module) + + def test_infer_layout_for_vector_extract(self): + layout = layouts.to_layout_attr(fa.WGMMA_LAYOUT) + with ir.InsertionPoint(self.module.body): + i16 = ir.IntegerType.get_signless(16) + src_ty = ir.VectorType.get([2, 3, 64, 8], i16) + [src] = undefs(src_ty) + src = mgpu.dialect.layout_cast(src, layout) + op = vector.ExtractOp(src, dynamic_position=[], static_position=[1, 1]) + mgpu.infer_layout(self.module) + self.checkInLayouts(op, [layout]) + self.checkOutLayouts(op, [layout]) + + def test_infer_layout_for_vector_extract_to_scalar(self): + with ir.InsertionPoint(self.module.body): + i16 = ir.IntegerType.get_signless(16) + src_ty = ir.VectorType.get([64, 8], i16) + [src] = undefs(src_ty) + op = vector.ExtractOp(src, dynamic_position=[], static_position=[1, 1]) + mgpu.infer_layout(self.module) + self.checkInLayouts(op, [mgpu.WGSplatFragLayout(tuple(src_ty.shape))]) + self.assertNotIn("out_layouts", op.attributes) + + def test_infer_layout_for_vector_extract_fails_if_not_dividing_result_shape(self): + layout = layouts.to_layout_attr(fa.WGMMA_LAYOUT) + with ir.InsertionPoint(self.module.body): + i16 = ir.IntegerType.get_signless(16) + src_ty = ir.VectorType.get([64, 64], i16) + [src] = undefs(src_ty) + src = mgpu.dialect.layout_cast(src, layout) + vector.extract(src, dynamic_position=[], static_position=[0]) + with self.assertRaisesRegex( + ValueError, "Failed to infer a possible set of layouts." + ): + mgpu.infer_layout(self.module) + + def test_infer_tmem_layout_for_slice_tmem_op(self): + # in and out layouts can be different. + in_layout = layouts.to_layout_attr(tcgen05.tmem_default_layout(packing=1)) + out_layout = layouts.to_layout_attr(tcgen05.tmem_default_layout(packing=2)) + with ir.InsertionPoint(self.module.body): + i32 = ir.IntegerType.get_signless(32) + src_tmem_type = ir.MemRefType.get( + (128, 512), i32, memory_space=mgpu.utils.tmem() + ) + [src] = undefs(src_tmem_type) + src = mgpu.dialect.tmem_layout_cast(src, in_layout) + dst_tmem_type = ir.MemRefType.get( + (128, 64), ir.BF16Type.get(), memory_space=mgpu.utils.tmem() + ) + op = mgpu.dialect.SliceTmemOp(dst_tmem_type, src, 64) + mgpu.dialect.tmem_layout_cast(op.result, out_layout) mgpu.infer_layout(self.module) - self.assertNotIn("in_layouts", f.attributes) - self.assertNotIn("out_layouts", f.attributes) + self.checkInTmemLayouts(op, [in_layout]) + self.checkOutTmemLayouts(op, [out_layout]) + + def test_infer_layout_fails_if_not_enough_fuel(self): + layout = fa.WGStridedFragLayout((128, 128), vec_size=4) + with ir.InsertionPoint(self.module.body): + vec_ty = ir.VectorType.get((128, 128), ir.BF16Type.get()) + a, b = undefs(vec_ty, vec_ty) + a = layout_cast(a, layout) + add = arith.AddFOp(a, b) + + with self.assertRaisesRegex(ValueError, "Consider adding layout annotations"): + mgpu.infer_layout(self.module, fuel=1) + + mgpu.infer_layout(self.module, fuel=100) + + self.checkInLayouts(add, [layout, layout]) + self.checkOutLayouts(add, [layout]) + + def test_infer_layout_for_broadcasted_iota_rejects_splat_layout(self): + with ir.InsertionPoint(self.module.body): + vec_ty = ir.VectorType.get((128, 128), ir.BF16Type.get()) + iota = mgpu.dialect.broadcasted_iota(vec_ty, 0) + layout_cast(iota, fa.WGSplatFragLayout(vec_ty.shape)) + + with self.assertRaisesRegex( + ValueError, "user-provided layout casts are unsatisfiable" + ): + mgpu.infer_layout(self.module) + + def test_infer_layout_for_print_register_layout_op(self): + with ir.InsertionPoint(self.module.body): + vec_ty = ir.VectorType.get((128, 128), ir.BF16Type.get()) + [vec] = undefs(vec_ty) + vec = layout_cast(vec, fa.WGMMA_LAYOUT) + op = mgpu.dialect.PrintLayoutOp("{}", vec) + mgpu.infer_layout(self.module) + self.checkInLayouts(op, [fa.WGMMA_LAYOUT]) + + def test_infer_layout_for_print_tmem_layout_op(self): + layout = tcgen05.tmem_default_layout(packing=1) + with ir.InsertionPoint(self.module.body): + ref_ty = ir.MemRefType.get( + (128, 128), ir.BF16Type.get(), memory_space=mgpu.utils.tmem() + ) + [ref] = undefs(ref_ty) + ref = mgpu.dialect.tmem_layout_cast(ref, layouts.to_layout_attr(layout)) + op = mgpu.dialect.PrintLayoutOp("{}", ref) + mgpu.infer_layout(self.module) + self.checkInTmemLayouts(op, [layout]) + + @parameterized.product( + op_type=( + mgpu.dialect.AsyncLoadOp, + mgpu.dialect.AsyncStoreOp, + mgpu.dialect.AsyncPrefetchOp, + ), + vec_offset=(1, 2), + ) + def test_infer_layout_for_async_ops_with_vector_indices( + self, op_type, vec_offset, + ): + # TODO(b/415721295): Remove when the minimum jaxlib version is 0.8.3. + if not hasattr(mgpu.dialect, "tma_gather_supported"): + self.skipTest("TMA gather support is required.") + with ir.InsertionPoint(self.module.body): + elt_ty = ir.BF16Type.get() + vec_len = 64 + gmem_shape = (8, 128, 128) + smem_shape = (4, vec_len, 128) if vec_offset == 1 else (4, 128, vec_len) + i32 = ir.IntegerType.get_signless(32) + vec_ty = ir.VectorType.get((vec_len,), i32) + + gmem_ty = ir.MemRefType.get(gmem_shape, elt_ty) + smem_ty = ir.MemRefType.get(smem_shape, elt_ty, memory_space=mgpu.utils.smem()) + barrier_ty = ir.Type.parse("!mosaic_gpu.barrier") + + gmem_ref, smem_ref, barrier, scalar_idx, vec_idx = undefs( + gmem_ty, smem_ty, barrier_ty, i32, vec_ty + ) + + if vec_offset == 1: + indices = [scalar_idx, vec_idx, scalar_idx] + slice_lengths = [4, vec_len, 128] + else: + indices = [scalar_idx, scalar_idx, vec_idx] + slice_lengths = [4, 128, vec_len] + + if op_type == mgpu.dialect.AsyncLoadOp: + op = op_type( + source=gmem_ref, + destination=smem_ref, + barrier=barrier, + indices=indices, + slice_lengths=slice_lengths, + collective=ir.ArrayAttr.get([]), + ) + elif op_type == mgpu.dialect.AsyncStoreOp: + op = op_type( + source=smem_ref, + destination=gmem_ref, + indices=indices, + slice_lengths=slice_lengths, + ) + elif op_type == mgpu.dialect.AsyncPrefetchOp: + op = op_type( + source=gmem_ref, + indices=indices, + slice_lengths=slice_lengths, + collective=ir.ArrayAttr.get([]), + ) + + layout = mgpu.TMA_GATHER_INDICES_LAYOUT + mgpu.infer_layout(self.module) + self.checkInLayouts(op, [layout]) + + @parameterized.parameters( + ((32, 64, 128), [[0], [1], [2]], (32, 64, 128), False), + ((32, 64, 128), [[0], [1, 2], [3]], (32, 4, 16, 128), False), + ((32, 64, 128), [[0, 1], [2], [3]], (4, 8, 64, 128), True), + ( + (ir.ShapedType.get_dynamic_size(), 64, 128), + [[0, 1], [2], [3]], + ( + ir.ShapedType.get_dynamic_size(), + ir.ShapedType.get_dynamic_size(), + 64, + 128, + ), + True, + ), + ) + def test_infer_layout_for_memref_expand_shape_op(self, input_shape, reassociation, output_shape, has_transforms): + with ir.InsertionPoint(self.module.body): + ref_ty = ir.MemRefType.get( + input_shape, ir.BF16Type.get(), memory_space=mgpu.utils.smem() + ) + [in_ref, idx] = undefs(ref_ty, ir.IndexType.get()) + + if has_transforms: + transforms = ir.ArrayAttr.get([ + mgpu.dialect.TileTransformAttr.get((32, 32)), + mgpu.dialect.SwizzleTransformAttr.get(64), + ]) + in_ref = mgpu.dialect.with_transforms(in_ref, transforms) + else: + transforms = [] + + dynamic_output_sizes = [ + idx + for size in output_shape + if size == ir.ShapedType.get_dynamic_size() + ] + + op = memref.ExpandShapeOp( + result=ref_ty, + src=in_ref, + reassociation=reassociation, + output_shape=dynamic_output_sizes, + static_output_shape=output_shape, + ) + mgpu.infer_layout(self.module) + [in_transform] = inference_utils.in_transforms(op) + self.assertSequenceEqual(in_transform, transforms) + [out_transform] = inference_utils.out_transforms(op) + self.assertSequenceEqual(out_transform, transforms) + + def test_layout_cast_incompatible_with_vector_shape_is_unsatisfiable(self): + with ir.InsertionPoint(self.module.body): + [vec] = undefs(ir.VectorType.get((4, 4), ir.BF16Type.get())) + mgpu.dialect.layout_cast(vec, layouts.to_layout_attr(fa.WGMMA_LAYOUT)) + with self.assertRaisesRegex( + ValueError, "Failed to infer a possible set of layouts" + ): + mgpu.infer_layout(self.module) + + def test_tmem_layout_cast_incompatible_with_ref_shape_is_unsatisfiable(self): + with ir.InsertionPoint(self.module.body): + f32 = ir.F32Type.get() + ref_ty = ir.MemRefType.get((4, 4), f32, memory_space=mgpu.utils.tmem()) + [ref] = undefs(ref_ty) + mgpu.dialect.tmem_layout_cast( + ref, layouts.to_layout_attr(mgpu.TMEM_NATIVE_LAYOUT) + ) + with self.assertRaisesRegex( + ValueError, "Failed to infer a possible set of layouts" + ): + mgpu.infer_layout(self.module) + + def test_with_transforms_incompatible_with_smem_shape_is_unsatisfiable(self): + with ir.InsertionPoint(self.module.body): + f32 = ir.F32Type.get() + ref_ty = ir.MemRefType.get((4, 4), f32, memory_space=mgpu.utils.smem()) + [ref] = undefs(ref_ty) + transforms = ir.ArrayAttr.get([ + mgpu.dialect.TileTransformAttr.get((8, 2)), + ]) + mgpu.dialect.with_transforms(ref, transforms) + with self.assertRaisesRegex( + ValueError, "Failed to infer a possible set of layouts" + ): + mgpu.infer_layout(self.module) if __name__ == "__main__": diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index e7bd7fad3798..cde9d10d688d 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -22,27 +22,45 @@ import operator import os import re -import unittest +import sys +import tempfile -from absl.testing import absltest, parameterized +from absl.testing import absltest +from absl.testing import parameterized import jax from jax._src import config +from jax._src import dtypes from jax._src import test_util as jtu from jax._src.interpreters import mlir from jax._src.lib.mlir import ir from jax._src.lib.mlir import passmanager from jax._src.lib.mlir.dialects import arith +from jax._src.lib.mlir.dialects import cf +from jax._src.lib.mlir.dialects import gpu +from jax._src.lib.mlir.dialects import llvm from jax._src.lib.mlir.dialects import scf from jax._src.lib.mlir.dialects import vector +import jax.experimental.mosaic.gpu as mgpu +from jax.experimental.mosaic.gpu import core from jax.experimental.mosaic.gpu import dialect as mgpu_dialect # pylint: disable=g-importing-member from jax.experimental.mosaic.gpu import fragmented_array as fa +from jax.experimental.mosaic.gpu import inference_utils +from jax.experimental.mosaic.gpu import launch_context +from jax.experimental.mosaic.gpu import layouts +from jax.experimental.mosaic.gpu import profiler from jax.experimental.mosaic.gpu import tcgen05 +from jax.experimental.mosaic.gpu import test_util as mtu +from jax.experimental.mosaic.gpu import utils +from jax.experimental.mosaic.gpu.utils import * # noqa: F403 import jax.numpy as jnp import numpy as np + + try: - import jax._src.lib.mosaic_gpu # noqa: F401 + import jax._src.lib.mosaic_gpu as mosaic_gpu_lib # noqa: F401 HAS_MOSAIC_GPU = True except ImportError: + mosaic_gpu_lib = None HAS_MOSAIC_GPU = False class Dimension(enum.IntEnum): # Just to make parameterized tests expand ok @@ -50,17 +68,15 @@ class Dimension(enum.IntEnum): # Just to make parameterized tests expand ok y = 1 z = 2 else: - import jax.experimental.mosaic.gpu as mgpu - from jax.experimental.mosaic.gpu import core - from jax.experimental.mosaic.gpu import launch_context - from jax.experimental.mosaic.gpu import utils as utils - from jax.experimental.mosaic.gpu import profiler - from jax.experimental.mosaic.gpu import inference_utils - from jax.experimental.mosaic.gpu.utils import * # noqa: F403 - from jax._src.lib.mlir.dialects import gpu - from jax._src.lib.mlir.dialects import llvm Dimension = gpu.Dimension +try: + import hypothesis as hp + import hypothesis.strategies as hps + jtu.setup_hypothesis() +except ImportError: + hp = hps = None + # ruff: noqa: F405 # pylint: disable=g-complex-comprehension @@ -85,20 +101,6 @@ def mlir_sum(elems): return total -@contextlib.contextmanager -def get_sass(): - prev_dump = os.environ.get("MOSAIC_GPU_DUMP_SASS", None) - os.environ["MOSAIC_GPU_DUMP_SASS"] = "1" - try: - with jtu.capture_stdout() as output: - yield output - finally: - if prev_dump is not None: - os.environ["MOSAIC_GPU_DUMP_SASS"] = prev_dump - else: - del os.environ["MOSAIC_GPU_DUMP_SASS"] - - def copy(src: ir.Value, dst: ir.Value, swizzle: int | None = None): index = ir.IndexType.get() thread_id = gpu.thread_id(gpu.Dimension.x) @@ -134,7 +136,6 @@ def copy(src: ir.Value, dst: ir.Value, swizzle: int | None = None): packing = 8 // bw if shape[-1] % packing: raise NotImplementedError - workgroup_mem = ir.Attribute.parse("#gpu.address_space") shape = (*shape[:-1], shape[-1] // packing) contig_strides = get_contiguous_strides(shape) def bitcast(ref): @@ -149,12 +150,7 @@ def bitcast(ref): ir.StridedLayoutAttr.get(0, new_strides), ref_ty.memory_space, ) - ptr_space = ( - 3 - if ref_ty.memory_space is not None - and ref_ty.memory_space == workgroup_mem - else None - ) + ptr_space = 3 if utils.is_smem_ref(ref_ty) else None return ptr_as_memref( # NOTE: memref_ptr applies the offset in case there was any. memref_ptr(ref, memory_space=ptr_space), @@ -198,15 +194,15 @@ def body(*idx): nvvm.fence_proxy(nvvm.ProxyKind.async_) -def iota_tensor(m, n, dtype): - """A wgmma tensor where arr[i, j] = i * N + j.""" +def iota_tensor(m, n, dtype, layout=mgpu.WGMMA_LAYOUT): + """A tensor with given layout where arr[i, j] = i * N + j.""" index = ir.IndexType.get() mlir_dtype = utils.dtype_to_ir_type(dtype) int_ty = ir.IntegerType.get_signless(bitwidth(mlir_dtype)) ret = mgpu.FragmentedArray.splat( llvm.mlir_undef(int_ty), (m, n), is_signed=False ) - ret = ret.to_layout(mgpu.WGMMA_LAYOUT) + ret = ret.to_layout(layout) def iota_value(_, idx): assert len(idx) == 2 @@ -233,12 +229,22 @@ def setUp(self): super().setUp() self.prng = np.random.default_rng(1234) self.context = mlir.make_ir_context() - if mgpu_dialect is not None: - mgpu_dialect.register_dialect(self.context) + mgpu_dialect.register_dialect(self.context) self.enter_context(config.traceback_filtering("off")) self.enter_context(self.context) self.enter_context(ir.Location.unknown()) + @contextlib.contextmanager + def capture_stdout(self): + if "pytest" in sys.modules: + self.skipTest("pytest interacts badly with GPU stdout capture") + if mosaic_gpu_lib is None: + raise ValueError("Running tests but missing Mosaic GPU extension") + with jtu.capture_stdout() as stdout: + yield stdout + # We need to cudaDeviceSynchronize to make sure printfs are flushed. + mosaic_gpu_lib._mosaic_gpu_ext._sync_all_devices() + class Sm90ATestCase(TestCase, jtu.CudaArchSpecificTest): @@ -286,7 +292,11 @@ def kernel(ctx, dst, _): assert registers.size == 16, registers.size for i, vec_reg in enumerate(registers.flat): for j in range(2): - reg = vector.extractelement(vec_reg, position=c(j, index)) + reg = vector.extract( + source=vec_reg, + dynamic_position=[], + static_position=ir.DenseI64ArrayAttr.get([j]), + ) memref.store( reg, dst, [gpu.thread_id(gpu.Dimension.x), c(2 * i + j, index)] ) @@ -382,17 +392,20 @@ def kernel(ctx, inp, out, _): ("add_1s", (5, 1, 2), (1, 1, 5, 1, 1, 2, 1, 1)), ("fold", (1, 5, 2, 1,), (1, 10, 1)), ("un", (1, 10, 1), (1, 5, 2, 1,)), + ("to_scalar", (1, 1, 1), ()), + ("from_scalar", (), (1, 1, 1)), + ("arbitrary", (2 * 5, 7 * 3), (2, 7, 5, 3)), ) def test_reshape(self, inp_shape, out_shape): def kernel(ctx, inp, out, _): copy(memref_reshape(inp, out_shape), out) - x = np.arange(math.prod(inp_shape), dtype=jnp.float32).reshape(*inp_shape) + x = np.arange(math.prod(inp_shape), dtype=jnp.float32).reshape(inp_shape) out_ty = jax.ShapeDtypeStruct(out_shape, jnp.float32) y = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), x, out_ty, () )(x) - np.testing.assert_array_equal(y, x.reshape(*out_shape)) + np.testing.assert_array_equal(y, x.reshape(out_shape)) @parameterized.named_parameters([ ("packed", (4, 4, 4), (16, 4, 1), 1, 2, False), @@ -404,7 +417,7 @@ def kernel(ctx, inp, out, _): # ("overap", (2, 4, 4), (16, 1, 1), 0, 3, True), ]) def test_fold_strided( - self, shape, strides, dim, fold_rank, throws_not_impl + self, shape, strides, dim, fold_rank, throws ): expanded_shape = get_packed_shape(strides, shape) total_size = np.prod(expanded_shape) @@ -417,7 +430,7 @@ def np_fold(inp, dim, fold_rank): out_shape[dim : dim + fold_rank] = [ int(np.prod(inp.shape[dim : dim + fold_rank])) ] - if throws_not_impl: + if throws: return jax.ShapeDtypeStruct(shape=out_shape, dtype=inp.dtype) else: return inp.reshape(*out_shape) @@ -433,12 +446,12 @@ def kernel(ctx, inp, out, _): kernel, (1, 1, 1), (128, 1, 1), np_inp, out, () )(np_inp) assert ( - not throws_not_impl + not throws ), "If it should have thrown it would during the call." np.testing.assert_array_equal(y, out) - if throws_not_impl: - with self.assertRaises(NotImplementedError): + if throws: + with self.assertRaises(ValueError): do_test() else: do_test() @@ -452,7 +465,7 @@ def test_scalar_argument(self, dtype): " values read from the 32-bit input buffer to sometimes" " (nondeterministically) contain garbage.") - scalar = 42 + scalar = dtype(42) expected = np.full((128, 128), scalar, dtype=dtype) def kernel(ctx, inp, out, _): @@ -470,6 +483,29 @@ def kernel(ctx, inp, out, _): )(scalar) np.testing.assert_array_equal(res, expected) + @parameterized.parameters(gpu.Dimension.x, gpu.Dimension.y) + def test_cluster_ref(self, dim): + index = ir.IndexType.get() + dims = (gpu.Dimension.x, gpu.Dimension.y) + def kernel(ctx, src, dst, scratch): + smem, barrier = scratch + cluster_idx = tuple(gpu.cluster_block_id(dim) for dim in dims) + peer_idx = arith.subi(arith.constant(index, 1), cluster_idx[dim]) + peer_smem = ctx.get_cluster_ref(smem, dim, peer_idx) + a = mgpu.FragmentedArray.load_strided(memref_slice(src, cluster_idx)).store_untiled(smem) + utils.warpgroup_barrier() + barrier.arrive() + barrier.wait() + mgpu.FragmentedArray.load_strided(peer_smem).store_untiled(memref_slice(dst, cluster_idx)) + + barrier = mgpu.ClusterBarrier(collective_dims=(dim,)) + x = np.arange(2 * 2 * 512, dtype=jnp.float32).reshape(2, 2, 512) + smem = jax.ShapeDtypeStruct(shape=(x.shape[-1],), dtype=jnp.float32) + f = mgpu.as_gpu_kernel( + kernel, (2, 2, 1), (128, 1, 1), x, x, (smem, barrier), cluster=(2, 2, 1) + ) + np.testing.assert_array_equal(f(x), np.flip(x, axis=int(dim))) + def get_packed_shape(strides, shape): perm = sorted(range(len(strides)), key=lambda i: strides[i], reverse=True) @@ -489,24 +525,51 @@ def get_packed_shape(strides, shape): class WGMMALayoutTest(TestCase): - @parameterized.product(dtype=[jnp.float16, jnp.float32], - transposed_smem=[False, True]) - def test_store_untiled(self, dtype, transposed_smem): + @parameterized.product(dtype=[jnp.float16, jnp.float32]) + def test_store_untiled(self, dtype): def kernel(ctx, out, _): del ctx - if transposed_smem: - out = memref_transpose(out, (1, 0)) - iota_tensor(64, 64, dtype).store_untiled( - out, vector_store=not transposed_smem - ) + iota_tensor(64, 64, dtype).store_untiled(out, optimized=False) expected = np.arange(64 * 64, dtype=dtype).reshape(64, 64) - if transposed_smem: - expected = expected.T iota = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), expected, () )() np.testing.assert_array_equal(iota, expected) + @parameterized.product( + dtype=[jnp.float8_e5m2fnuz, jnp.float8_e5m2, jnp.float8_e4m3b11fnuz, + jnp.float8_e4m3fn, jnp.float8_e4m3fnuz], + swizzle=(32, 64, 128), + num_col_tiles=(1, 2, 3), + ) + def test_load_and_store_tiled_f8(self, dtype, swizzle, num_col_tiles): + # We use a different test than `test_store_tiled` because converting + # `iota` to `f8` type requires additional specialized logic that is not + # yet available. + col_tiling = swizzle + m = 128 + n = col_tiling * num_col_tiles + tiling = (64, col_tiling) + def kernel(ctx, inp, out, smem): + del ctx + smem_inp, smem_out = smem + copy(inp, smem_inp, swizzle=swizzle) + arr = mgpu.FragmentedArray.load_tiled(smem_inp, swizzle=swizzle) + arr.store_tiled(smem_out, swizzle=swizzle) + copy(smem_out, out, swizzle=swizzle) + expected = ( + jax.random.randint( + jax.random.key(42), (m * n,), -16, 15, dtype=jnp.int8 + ) + .reshape(m // tiling[0], tiling[0], n // tiling[1], tiling[1]) + .astype(dtype) + .transpose(0, 2, 1, 3) + ) + res = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), expected, expected, (expected,) * 2 + )(expected) + np.testing.assert_array_equal(res, expected) + @parameterized.product( dtype=[jnp.float32, jnp.float16, jnp.int8], swizzle=(32, 64, 128), @@ -534,84 +597,168 @@ def kernel(ctx, out, smem): )() np.testing.assert_array_equal(iota, expected) - @parameterized.parameters(jnp.int8, jnp.int16, jnp.int32) - def test_sub_byte_conversion(self, jax_dtype_to): + @parameterized.product( + jax_dtype_to=( + jnp.int8, jnp.int16, jnp.int32, jnp.bfloat16, jnp.float8_e4m3fn, + ), + # Use different layouts to vary the size of the vector dimension. + layout=( + fa.WGMMA_LAYOUT, + fa.WGMMA_LAYOUT_UPCAST_2X, + fa.WGMMA_LAYOUT_UPCAST_4X, + ), + ) + def test_sub_byte_conversion(self, jax_dtype_to, layout: fa.TiledLayout): + if jax_dtype_to == jnp.int32 and layout.vector_length == 8: + self.skipTest( + "Raises: failed to prove that vector transfers don't cross swizzle" + " tile boundaries.") jax_dtype_from = jnp.int4 + is_signed = utils.is_signed(jax_dtype_to) def kernel(ctx, inp, out, smem): del ctx # Unused. smem_inp, smem_out = smem copy(inp, smem_inp, swizzle=16) - t = mgpu.FragmentedArray.load_tiled(smem_inp, is_signed=True, swizzle=16) - t = t.astype(utils.dtype_to_ir_type(jax_dtype_to), is_signed=True) + t = mgpu.FragmentedArray.load_tiled( + smem_inp, is_signed=True, swizzle=16, layout=layout + ) + t = t.astype(utils.dtype_to_ir_type(jax_dtype_to), is_signed=is_signed) t.store_tiled(smem_out, swizzle=32 * jnp.dtype(jax_dtype_to).itemsize) copy(smem_out, out, swizzle=32 * jnp.dtype(jax_dtype_to).itemsize) x = self.prng.integers( low=-8, high=7, size=(1, 1, 64, 64), dtype=np.int32 ).astype(jax_dtype_from) - y = x.astype(jax_dtype_to) + y = jax.lax.convert_element_type(x, jax_dtype_to) f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, y, (x, y)) np.testing.assert_array_equal(f(x), y) + @parameterized.parameters( + (jnp.float32, jnp.float8_e4m3fn), + (jnp.bfloat16, jnp.float8_e4m3fn) + ) + def test_f8_conversions(self, jax_dtype_from, jax_dtype_to): + mlir_dtype_to = utils.dtype_to_ir_type(jax_dtype_to) + def kernel(ctx, inp, out, smem): + del ctx + smem_from, smem_to = smem + copy(inp, smem_from, swizzle=128) + t = mgpu.FragmentedArray.load_tiled( + smem_from, + swizzle=128, + is_signed=None, + layout=fa.WGMMA_LAYOUT, + ) + t = t.astype(mlir_dtype_to, is_signed=utils.is_signed(jax_dtype_to)) + t.store_tiled(smem_to, swizzle=128) + copy(smem_to, out, swizzle=128) + + # These generative shenanigans are to ensure that we don't generate values + # that are too large for the target type. That is because the saturation + # behavior of the conversion is different between XLA and Mosaic GPU here + # (to use the NVIDIA internal, we allow Mosaic GPU to use the .satfinite + # modifier, which saturates to the largest finite value---while XLA would + # give us NaNs in this case). + max_finite_val = 0b111_1110 + + expected = jax.lax.bitcast_convert_type( + jax.random.randint( + jax.random.key(42), + (1, 1, 64, 128), + -max_finite_val, + max_finite_val + 1, + dtype=jnp.uint8, + ), + jax_dtype_to, + ) + x = expected.astype(jax_dtype_from) + + res = mgpu.as_gpu_kernel( + kernel, + (1, 1, 1), + (128, 1, 1), + x, + expected, + (x, expected), + )(x) + np.testing.assert_array_equal(res, expected) + @parameterized.product( jax_dtype_from_to=( (jnp.int8, jnp.bfloat16), (jnp.int4, jnp.bfloat16), + (jnp.int4, jnp.float8_e4m3fn), + (jnp.int4, jnp.int8), + # TODO(apaszke,bchetioui): bf16/f32 -> f8e4m3fn ), - layout=( - fa.WGMMA_LAYOUT, - fa.WGMMA_LAYOUT_UPCAST_2X, - fa.WGMMA_LAYOUT_UPCAST_4X, + layout_descs=( + ("WGMMA_LAYOUT", "WGMMA_LAYOUT"), + ("WGMMA_LAYOUT_8BIT", "WGMMA_LAYOUT_8BIT"), + ("WGMMA_LAYOUT_UPCAST_2X", "WGMMA_LAYOUT_UPCAST_2X"), + ("WGMMA_LAYOUT_UPCAST_2X", "WGMMA_LAYOUT"), + ("WGMMA_LAYOUT_UPCAST_4X", "WGMMA_LAYOUT_UPCAST_4X"), + ("WGMMA_LAYOUT_UPCAST_4X", "WGMMA_LAYOUT_UPCAST_2X"), + ("WGMMA_LAYOUT_UPCAST_4X", "WGMMA_LAYOUT"), ), ) - def test_optimized_conversion(self, jax_dtype_from_to, layout): + @jtu.skip_if_mosaic_gpu_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell") + def test_optimized_conversion(self, jax_dtype_from_to, layout_descs): + layout_desc_from, layout_desc_to = layout_descs + layout_from: fa.TiledLayout = getattr(fa, layout_desc_from) + layout_to: fa.TiledLayout = getattr(fa, layout_desc_to) jax_dtype_from, jax_dtype_to = jax_dtype_from_to mlir_dtype_from = utils.dtype_to_ir_type(jax_dtype_from) mlir_dtype_to = utils.dtype_to_ir_type(jax_dtype_to) m = 128 - n = 256 * 8 // bitwidth(mlir_dtype_from) + n = 256 def kernel(ctx, inp, out, smem): - del ctx - smem_from, smem_to = smem - copy(inp, smem_from, swizzle=128) - t = mgpu.FragmentedArray.load_tiled( - smem_from, - swizzle=128, + del ctx, smem + t = mgpu.FragmentedArray.load_untiled( + inp, is_signed=utils.is_signed(jax_dtype_from), - layout=layout, + layout=layout_from, + optimized=False, ) + if layout_from != layout_to: + if ( + layout_from == fa.WGMMA_LAYOUT_UPCAST_4X + and utils.bitwidth(mlir_dtype_from) != 4 + ): + self.skipTest("Unimplemented relayout") + t = t.to_layout(layout_to) t = t.astype(mlir_dtype_to, is_signed=utils.is_signed(jax_dtype_to)) - t.store_tiled(smem_to, swizzle=128) - copy(smem_to, out, swizzle=128) + t.store_untiled(out, optimized=False) - from_tiling = (64, 128 * 8 // bitwidth(mlir_dtype_from)) - to_tiling = (64, 128 * 8 // bitwidth(mlir_dtype_to)) - # We only test lossless conversions for now. - # TODO(apaszke): Test and fix failures that appear with lossy conversions. int_sample_dtype = getattr( jnp, "int" + str(min(bitwidth(mlir_dtype_from), bitwidth(mlir_dtype_to))), ) sample_iinfo = jnp.iinfo(int_sample_dtype) - expected_raw = self.prng.integers( - low=sample_iinfo.min, high=sample_iinfo.max, - size=(m, n), dtype=np.int32 - ) - expected = lambda jax_dtype, tiling: expected_raw.reshape( - m // tiling[0], tiling[0], n // tiling[1], tiling[1] - ).transpose(0, 2, 1, 3).astype(jax_dtype) + values = self.prng.integers( + low=sample_iinfo.min, high=sample_iinfo.max, size=(m, n), dtype=np.int32 + ).astype(jax_dtype_from) - expected_from = expected(jax_dtype_from, from_tiling) - expected_to = expected(jax_dtype_to, to_tiling) - res = mgpu.as_gpu_kernel( - kernel, - (1, 1, 1), - (128, 1, 1), - expected_from, - expected_to, - (expected_from, expected_to), - )(expected_from) - np.testing.assert_array_equal(res, expected_to) + expected = values.astype(np.int32).astype(jax_dtype_to) + @contextlib.contextmanager + def _maybe_profile(): + yield; return # Comment to gather statistics. + with jtu.set_env(MOSAIC_GPU_DUMP_SASS="1"), self.capture_stdout() as sass: + yield + log_dir = os.getenv("TEST_UNDECLARED_OUTPUTS_DIR", "/tmp") + file_path = os.path.join(log_dir, "conversion_stats.csv") + with open(file_path, "a") as f: + data = ( + jnp.dtype(jax_dtype_from).name, jnp.dtype(jax_dtype_to).name, + layout_desc_from, layout_desc_to, sass().count("\n") + ) + f.write(",".join(map(str, data)) + "\n") + f.flush() + self.fail("Disable profiling before submission") + with _maybe_profile(): + res = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), values, expected, () + )(values) + np.testing.assert_array_equal(res, expected) @parameterized.named_parameters( ("f32", jnp.float32), @@ -643,6 +790,19 @@ def kernel(ctx, in_, out, smem): np.testing.assert_array_equal(iota, expected) +class I8Type: + """A type that represents a 8-bit signed integer. + + This is a workaround to bypass the fact that we don't have a proper 8-bit + integer type class available in MLIR, and can't instantiate types without a + MLIR context. + """ + + @staticmethod + def get(): # pylint: disable=no-method-argument + return ir.IntegerType.get_signless(8) + + class WGMMATest(TestCase): def setUp(self): @@ -653,16 +813,99 @@ def setUp(self): @parameterized.product( lhs_transpose=(False, True), rhs_transpose=(False, True), - in_mlir_dtype_cls=(ir.F16Type, ir.BF16Type, ir.F32Type), + in_mlir_dtype_cls=( + ir.F16Type, + ir.BF16Type, + ir.F32Type, + ir.Float8E5M2Type, + ir.Float8E4M3FNType, + ), m=(64, 128, 192), n=(64, 128, 192), - k_steps=(1, 2), swizzle=(32, 64, 128), jax_out_dtype=(jnp.float16, jnp.float32), + ) + def test_wgmma_basic_float( + self, + lhs_transpose, + rhs_transpose, + in_mlir_dtype_cls, + m, + n, + swizzle, + jax_out_dtype, + ): + self._test_wgmma_basic( + m, + n, + k_steps=2, # Decrease to 1 to simplify debugging. + in_mlir_dtype_cls=in_mlir_dtype_cls, + lhs_transpose=lhs_transpose, + rhs_transpose=rhs_transpose, + swizzle=swizzle, + jax_out_dtype=jax_out_dtype, + lhs_tiling_kind="small+no_transpose" if lhs_transpose else "small", + rhs_tiling_kind="small+no_transpose" if rhs_transpose else "small", + ) + + @parameterized.product( + in_mlir_dtype_cls=(I8Type,), + m=(64, 128, 192), + n=(64, 128, 192), + swizzle=(32, 64, 128), + jax_out_dtype=(jnp.int32,), + ) + def test_wgmma_basic_int( + self, in_mlir_dtype_cls, m, n, swizzle, jax_out_dtype, + ): + self._test_wgmma_basic( + m, + n, + k_steps=2, # Decrease to 1 to simplify debugging. + in_mlir_dtype_cls=in_mlir_dtype_cls, + lhs_transpose=False, + rhs_transpose=True, + swizzle=swizzle, + jax_out_dtype=jax_out_dtype, + rhs_tiling_kind="small", + lhs_tiling_kind="small+no_transpose", + ) + + @parameterized.product( + lhs_transpose=(False, True), + rhs_transpose=(False, True), + in_mlir_dtype_cls=( + ir.F32Type, + ir.F16Type, + ir.Float8E5M2Type, + ), + swizzle=(32, 64, 128), rhs_tiling_kind=("large", "small", "small+no_transpose"), lhs_tiling_kind=("large", "small", "small+no_transpose"), ) - def test_wgmma_basic( + def test_wgmma_transposes( + self, + lhs_transpose, + rhs_transpose, + in_mlir_dtype_cls, + swizzle, + rhs_tiling_kind, + lhs_tiling_kind, + ): + self._test_wgmma_basic( + m=128, + n=192, + k_steps=2, # Decrease to 1 to simplify debugging. + in_mlir_dtype_cls=in_mlir_dtype_cls, + lhs_transpose=lhs_transpose, + rhs_transpose=rhs_transpose, + swizzle=swizzle, + jax_out_dtype=jnp.float32, + rhs_tiling_kind=rhs_tiling_kind, + lhs_tiling_kind=lhs_tiling_kind, + ) + + def _test_wgmma_basic( self, m, n, @@ -675,8 +918,12 @@ def test_wgmma_basic( rhs_tiling_kind, lhs_tiling_kind, ): - if jax_out_dtype == jnp.float16 and in_mlir_dtype_cls is not ir.F16Type: - self.skipTest("Only f16 input is supported for f16 output.") + if jax_out_dtype == jnp.int32 and in_mlir_dtype_cls != I8Type: + self.skipTest("s32 accumulator only supported with s8 inputs") + if jax_out_dtype != jnp.int32 and in_mlir_dtype_cls == I8Type: + self.skipTest("s8 inputs only supported with s32 accumulator") + if jax_out_dtype == jnp.float16 and in_mlir_dtype_cls in {ir.F32Type, ir.BF16Type}: + self.skipTest(f"{in_mlir_dtype_cls.get()} does not support f16 output.") if swizzle != 128 and lhs_transpose and lhs_tiling_kind == "large": self.skipTest("Transpose only supported in 128B swizzled WGMMA") if rhs_tiling_kind == "small+no_transpose" and not rhs_transpose: @@ -686,26 +933,37 @@ def test_wgmma_basic( in_mlir_dtype = in_mlir_dtype_cls.get() out_mlir_dtype = utils.dtype_to_ir_type(jax_out_dtype) - if ir.F32Type.isinstance(in_mlir_dtype): # We actually use tf32 instead + if (lhs_transpose or not rhs_transpose) and bytewidth(in_mlir_dtype) != 2: + self.skipTest("Transpose only supported in 16-bit WGMMA") + if isinstance(in_mlir_dtype, ir.F32Type): # We actually use tf32 instead in_jax_dtype = jnp.float32 - if lhs_transpose or not rhs_transpose: - self.skipTest("Transpose only supported in 16-bit WGMMA") exponent_bits, mantissa_bits = 8, 10 # Use tf32 elif bytewidth(in_mlir_dtype) == 2: if n % 64 != 0: self.skipTest("16-bit WGMMA only supports n % 64 == 0") - if ir.F16Type.isinstance(in_mlir_dtype): + if isinstance(in_mlir_dtype, ir.F16Type): in_jax_dtype = jnp.float16 exponent_bits, mantissa_bits = 5, 10 - elif ir.BF16Type.isinstance(in_mlir_dtype): + elif isinstance(in_mlir_dtype, ir.BF16Type): in_jax_dtype = jnp.bfloat16 exponent_bits, mantissa_bits = 8, 7 else: raise NotImplementedError(in_mlir_dtype) + elif in_mlir_dtype_cls == ir.Float8E5M2Type: + in_jax_dtype = jnp.float8_e5m2 + exponent_bits, mantissa_bits = 5, 2 + elif in_mlir_dtype_cls == ir.Float8E4M3FNType: + in_jax_dtype = jnp.float8_e4m3fn + exponent_bits, mantissa_bits = 4, 3 + elif in_mlir_dtype_cls == I8Type: + in_jax_dtype = jnp.int8 + exponent_bits = mantissa_bits = None else: raise NotImplementedError(in_mlir_dtype) nk_tile = swizzle // bytewidth(in_mlir_dtype) k = nk_tile * k_steps + if n % nk_tile: + self.skipTest("tiling does not divide N") assert m % 64 == 0 and n % nk_tile == 0 small_rhs_tile = rhs_tiling_kind != "large" @@ -739,7 +997,8 @@ def kernel(ctx, lhs, rhs, out, scratch): ) for i in range(2): barriers[i].wait() - init_acc = mgpu.WGMMAAccumulator.zero(m=m, n=n, dtype=out_mlir_dtype) + is_signed = True if isinstance(in_mlir_dtype, ir.IntegerType) else None + init_acc = mgpu.WGMMAAccumulator.zero(m=m, n=n, dtype=out_mlir_dtype, is_signed=is_signed) if lhs_transpose: perm = (0, 1, 3, 2) if transpose_lhs_tiles else (1, 0, 3, 2) lhs_smem = memref_transpose(lhs_smem, perm) @@ -749,16 +1008,20 @@ def kernel(ctx, lhs, rhs, out, scratch): acc = mgpu.wgmma(init_acc, lhs_smem, rhs_smem, swizzle=swizzle) nvvm.wgmma_commit_group_sync_aligned() nvvm.wgmma_wait_group_sync_aligned(0) - acc.value.store_untiled(out) + acc.value.store_untiled(out, optimized=False) def quantize(x): # Quantize the input to avoid rounding when feeding the WGMMA return jax.lax.reduce_precision(x, exponent_bits, mantissa_bits) x_shape = (k, m) if lhs_transpose else (m, k) - x = quantize(self.prng.uniform(-1, 1, x_shape)).astype(in_jax_dtype) y_shape = (n, k) if rhs_transpose else (k, n) - y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(in_jax_dtype) + if in_mlir_dtype_cls == I8Type: + x = self.prng.integers(-128, 127, x_shape).astype(in_jax_dtype) + y = self.prng.integers(-128, 127, y_shape).astype(in_jax_dtype) + else: + x = quantize(self.prng.uniform(-1, 1, x_shape)).astype(in_jax_dtype) + y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(in_jax_dtype) out_shape = jax.ShapeDtypeStruct((m, n), jax_out_dtype) if transpose_rhs_tiles: rhs_tiling_t = rhs_tiling[::-1] if rhs_transpose else rhs_tiling @@ -781,6 +1044,10 @@ def quantize(x): x32, y32 = x.astype(np.float32), y.astype(np.float32) ref = (x32.T if lhs_transpose else x32) @ (y32.T if rhs_transpose else y32) atol = 2e-2 if jax_out_dtype == jnp.float16 else 5e-6 + if isinstance(in_mlir_dtype, ir.IntegerType) and isinstance(out_mlir_dtype, ir.IntegerType): + atol = 0 + elif utils.bitwidth(in_mlir_dtype) == 8: + atol = 3e-2 np.testing.assert_allclose(z, ref, atol=atol) # TODO(apaszke): Add support for f32 @@ -793,13 +1060,35 @@ def quantize(x): dtype=[jnp.float16, jnp.bfloat16], ) def test_wgmma_reg_lhs(self, m, n, k_steps, rhs_transpose, swizzle, dtype): - index = ir.IndexType.get() + self._test_wgmma_reg_lhs(m, n, k_steps, rhs_transpose, swizzle, dtype) - bytewidth = 2 + @parameterized.product( + m=(64, 128, 192), + n=(64, 128, 192), + k_steps=(1, 2), + swizzle=(32, 64, 128), + dtype=(jnp.int8, jnp.float8_e5m2, jnp.float8_e4m3fn), + ) + def test_wgmma_reg_lhs_8bit(self, m, n, k_steps, swizzle, dtype): + # TODO(bchetioui): relax this when ptxas is fixed. As of ptxas 12.8, + # optimizations eliminate MMA instructions, leading to only the first tile + # of the result being computed correctly. + if swizzle == 32 and dtype == jnp.int8: + self.skipTest("32-bit swizzle not supported for int8") + self._test_wgmma_reg_lhs( + m, n, k_steps, rhs_transpose=True, swizzle=swizzle, dtype=dtype + ) + + def _test_wgmma_reg_lhs(self, m, n, k_steps, rhs_transpose, swizzle, dtype): + index = ir.IndexType.get() + out_dtype = jnp.int32 if dtype == jnp.int8 else jnp.float32 + bytewidth = jnp.dtype(dtype).itemsize nk_tile = swizzle // bytewidth k = nk_tile * k_steps + if n % nk_tile: + self.skipTest("swizzle must divide N") - def kernel(ctx, rhs, out, rhs_smem): + def kernel(ctx, lhs, rhs, out, rhs_smem): del ctx for ki in range(k_steps): for ni in range(n // nk_tile): @@ -814,30 +1103,55 @@ def kernel(ctx, rhs, out, rhs_smem): dst=memref_slice(rhs_smem, (ki, ni)), swizzle=swizzle, ) - init_acc = mgpu.WGMMAAccumulator.zero(m=m, n=n) - lhs_regs = iota_tensor(m, k, dtype) + init_acc = mgpu.WGMMAAccumulator.zero( + m=m, n=n, dtype=utils.dtype_to_ir_type(out_dtype), + is_signed=True if dtype == jnp.int8 else None, + ) + layout = fa.WGMMA_LAYOUT_8BIT if dtypes.itemsize_bits(dtype) == 8 else fa.WGMMA_LAYOUT + lhs_regs = fa.FragmentedArray.load_untiled( + lhs, layout=layout, optimized=False, is_signed=utils.is_signed(dtype), + ) if rhs_transpose: rhs_smem = memref_transpose(rhs_smem, (0, 1, 3, 2)) acc = mgpu.wgmma(init_acc, lhs_regs, rhs_smem, swizzle=swizzle) nvvm.wgmma_commit_group_sync_aligned() nvvm.wgmma_wait_group_sync_aligned(0) - acc.value.store_untiled(out) + acc.value.store_untiled(out, optimized=False) y_shape = (n, k) if rhs_transpose else (k, n) - y = self.prng.uniform(-1, 1, y_shape).astype(dtype) - out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) + if dtype == jnp.int8: + x = np.arange(m * k, dtype=dtype).reshape(m, k) + y = self.prng.integers(-128, 127, y_shape, dtype=dtype) + else: + def quantize_f8(x): + if dtype not in {jnp.float8_e4m3fn, jnp.float8_e5m2}: + return x + if dtype == jnp.float8_e4m3fn: + exponent_bits, mantissa_bits = 4, 3 + else: + exponent_bits, mantissa_bits = 5, 2 + return jax.lax.reduce_precision(x, exponent_bits, mantissa_bits) + x = quantize_f8(self.prng.uniform(-1, 1, (m, k))).astype(dtype) + y = quantize_f8(self.prng.uniform(-1, 1, y_shape)).astype(dtype) + out_shape = jax.ShapeDtypeStruct((m, n), out_dtype) scratch_shape = jax.ShapeDtypeStruct( - (k_steps, n // nk_tile, nk_tile, nk_tile), dtype + (k_steps, n // nk_tile, nk_tile, nk_tile), dtype ) z = mgpu.as_gpu_kernel( - kernel, (1, 1, 1), (128, 1, 1), y, out_shape, scratch_shape - )(y) - x = np.arange(m * k, dtype=dtype).reshape(m, k) + kernel, (1, 1, 1), (128, 1, 1), (x, y), out_shape, scratch_shape + )(x, y) ref = jax.lax.dot( - x, (y.T if rhs_transpose else y), preferred_element_type=jnp.float32 + x, (y.T if rhs_transpose else y), preferred_element_type=out_dtype ) - rtol = 5e-4 - np.testing.assert_allclose(z, ref, rtol=rtol, atol=0) + if dtype == jnp.int8: + atol = rtol = 0 + elif dtype == jnp.float8_e4m3fn: + atol = rtol = 6e-3 + elif dtype == jnp.float8_e5m2: + atol = rtol = 3e-3 + else: + atol, rtol = 0, 5e-4 + np.testing.assert_allclose(z, ref, rtol=rtol, atol=atol) @parameterized.product( rhs_transpose=(False, True), @@ -881,7 +1195,7 @@ def kernel(ctx, rhs, out, smem): acc = mgpu.wgmma(init_acc, lhs_regs, rhs_smem, swizzle=swizzle) nvvm.wgmma_commit_group_sync_aligned() nvvm.wgmma_wait_group_sync_aligned(0) - acc.value.store_untiled(out) + acc.value.store_untiled(out, optimized=False) jax_dtype = jnp.float16 y_shape = (n, k) if rhs_transpose else (k, n) @@ -897,7 +1211,7 @@ def kernel(ctx, rhs, out, smem): ref = jax.lax.dot( x, (y.T if rhs_transpose else y), preferred_element_type=jnp.float32 ) - np.testing.assert_allclose(z, ref, rtol=5e-4, atol=0) + np.testing.assert_allclose(z, ref, rtol=1e-3, atol=0) class TCGen05Test(TestCase): @@ -908,19 +1222,207 @@ def setUp(self): if not any(jtu.is_cuda_compute_capability_equal(sm) for sm in capabilities): self.skipTest("Only works on GPU with capability sm_100a or sm_101a") + @parameterized.product( + jax_dtype_packing=[(jnp.float32, 1), (jnp.float16, 1), (jnp.float16, 2), (jnp.float8_e5m2, 4)], + reg_tmem_layout_m=[ + (lambda _c, _p: tcgen05.LAYOUT, lambda _, p: tcgen05.tmem_default_layout(p), 128), + (lambda _c, _p: fa.WGMMA_LAYOUT, tcgen05.tmem_half_lane_layout, 64), + ( + lambda c, _p: tcgen05.fa_m64_collective_layout(c), + tcgen05.tmem_m64_collective_layout, + 64, + ), + ( + lambda c, p: tcgen05.tmem_m64_collective_layout(c, p).as_tiled_layout(), + tcgen05.tmem_m64_collective_layout, + 64, + ), + ], + ) + def test_load_store_tmem(self, jax_dtype_packing, reg_tmem_layout_m): + jax_dtype, packing = jax_dtype_packing + reg_layout_f, tmem_layout_f, m = reg_tmem_layout_m + n = 160 + reg_layout = reg_layout_f(n, packing) + if tmem_layout_f is tcgen05.tmem_m64_collective_layout: + if jax_dtype == jnp.float16 and packing == 1: + self.skipTest("Not implemented yet") + is_native_transfer = tmem_layout_f(n, packing).as_tiled_layout() == reg_layout + if not is_native_transfer and jax_dtype == jnp.float8_e5m2: + self.skipTest("Not implemented yet") + + def kernel(ctx, input, output, tmem): + del ctx + tmem.store(fa.FragmentedArray.load_untiled(input, layout=reg_layout, optimized=False)) + tcgen05.commit_tmem() + tmem.load(reg_layout).store_untiled(output, optimized=False) + + x = self.prng.uniform(-1, 1, (m, n)).astype(jax_dtype) + y = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, x, + mgpu.TMEM(x.shape, jax_dtype, layout=tmem_layout_f(n, packing)), + )(x) + np.testing.assert_array_equal(x, y) + + @parameterized.parameters([ + (jnp.float32, 1), + (jnp.float16, 1), + (jnp.float16, 2), + (jnp.float8_e5m2, 4), + (jnp.float4_e2m1fn, 8), + ]) + def test_load_store_tmem_native(self, jax_dtype, packing): + # TODO(bchetioui): add a test for int8 with a native layout with vector + # length equal to 4 once TMEM load is implemented for it. + def kernel(ctx, input, output, tmem): + del ctx + reg_layout = tcgen05.tmem_default_layout(max(packing, 2)).as_tiled_layout() + tmem.store(fa.FragmentedArray.load_untiled(input, layout=reg_layout, optimized=False)) + tcgen05.commit_tmem() + tmem.load(reg_layout).store_untiled(output, optimized=False) + + x = self.prng.uniform(-1, 1, (128, 128)).astype(jax_dtype) + y = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, x, mgpu.TMEM(x.shape, jax_dtype, packing=packing) + )(x) + np.testing.assert_array_equal(x, y) + + def test_mixed_tmem_allocations_raise(self): + def body(ctx, out, scratch): + del ctx, out, scratch + + with self.assertRaisesRegex( + ValueError, + "Can't mix collective and non-collective TMEM allocations within the" + " same kernel.", + ): + mgpu.as_gpu_kernel( + body, + grid=(1, 1, 1), + block=(128, 1, 1), + in_shape=(), + out_shape=(jax.ShapeDtypeStruct((), jnp.int32),), + smem_scratch_shape=[ + mgpu.TMEM((128, 128), jnp.float16, collective=True), + mgpu.TMEM((128, 128), jnp.float16, collective=False), + ], + ) + + @parameterized.parameters([ + (jnp.float32, 1, "130.0000"), + (jnp.float16, 1, "130.0000"), + (jnp.float16, 2, "[132.000000,133.000000]"), + ]) + @jtu.thread_unsafe_test() + def test_tmem_debug_print(self, jax_dtype, packing, expected): + def kernel(ctx, input, output, tmem): + del ctx, output + tmem.store(fa.FragmentedArray.load_untiled(input, layout=tcgen05.LAYOUT, optimized=False)) + tcgen05.commit_tmem() + tmem.slice(slice(None), slice(0, 8))._debug_print() + + x = jnp.arange(128 * 128, dtype=jax_dtype).reshape(128, 128) + with self.capture_stdout() as stdout: + mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, x, mgpu.TMEM(x.shape, jax_dtype, packing=packing), + )(x).block_until_ready() + self.assertIn("[1, 2]: " + expected, stdout()) + @parameterized.product( lhs_transpose=(False, True), rhs_transpose=(False, True), - in_jax_dtype=(jnp.float16, jnp.bfloat16), # TODO(apaszke): f32 - out_jax_dtype=(jnp.float32,), # TODO(apaszke): f16 accumulation - m=(128,), # TODO(apaszke): 64, 192, 256 - n=(64, 128, 256, 512), # TODO(apaszke): 192, other non-power-of-2 - k_steps=(1, 2), + in_jax_dtype=(jnp.float16, jnp.bfloat16, jnp.float8_e5m2, jnp.float8_e4m3fn), # TODO(apaszke): f32 + out_jax_dtype=(jnp.float16, jnp.float32,), + m=(64, 128,), # TODO(apaszke): 64, 192, 256 + n=(64, 128, 192, 224, 256, 512), + swizzle=(32, 64, 128,), + ) + def test_mma_basic_float(self, **kwargs): + in_bytewidth = jnp.dtype(kwargs["in_jax_dtype"]).itemsize + lhs_transpose = kwargs["lhs_transpose"] + swizzle = kwargs["swizzle"] + if lhs_transpose and kwargs["m"] * in_bytewidth < swizzle: + self.skipTest("swizzle too large for input (lhs)") + n_steps = 2 if kwargs["m"] == 64 else 1 + n_instr_size = kwargs["n"] * in_bytewidth // n_steps + if n_instr_size < swizzle or n_instr_size % swizzle != 0: + self.skipTest("swizzle doesn't work with this instruction size") + if dtypes.itemsize_bits(kwargs["in_jax_dtype"]) <= 8 and kwargs["n"] == swizzle: + self.skipTest("Only 8-bit and larger inputs are supported for MMA") + self._basic_mma_test( + **kwargs, + k_steps=2, # Reducing to 1 can be helpful while debugging. + lhs_transpose_tiles=False, + rhs_transpose_tiles=False, + ) + + @parameterized.product( + lhs_transpose=(False, True), + rhs_transpose=(False, True), + in_jax_dtype=(jnp.int8,), + out_jax_dtype=(jnp.int32,), + m=(64, 128,), # TODO(apaszke): 192, 256 + n=(64, 128, 160, 192, 256, 512), + swizzle=(32, 64, 128,), + ) + def test_mma_basic_int(self, **kwargs): + in_bytewidth = jnp.dtype(kwargs["in_jax_dtype"]).itemsize + lhs_transpose = kwargs["lhs_transpose"] + swizzle = kwargs["swizzle"] + if lhs_transpose and kwargs["m"] * in_bytewidth < swizzle: + self.skipTest("swizzle too large for input (lhs)") + n_steps = 2 if kwargs["m"] == 64 else 1 + n_instr_size = kwargs["n"] * in_bytewidth // n_steps + if n_instr_size < swizzle or n_instr_size % swizzle != 0: + self.skipTest("swizzle doesn't work with this instruction size") + if dtypes.itemsize_bits(kwargs["in_jax_dtype"]) <= 8 and kwargs["n"] == swizzle: + self.skipTest("Only 8-bit and larger inputs are supported for MMA") + self._basic_mma_test( + **kwargs, + k_steps=2, # Reducing to 1 can be helpful while debugging. + lhs_transpose_tiles=False, + rhs_transpose_tiles=False, + ) + + @parameterized.product( + lhs_transpose=(False, True), + rhs_transpose=(False, True), + in_jax_dtype=(jnp.float16,), + out_jax_dtype=(jnp.float32,), + m=(128,), + n=(128, 512), swizzle=(32, 64, 128,), - rhs_transpose_tiles=(False, True), lhs_transpose_tiles=(False, True), + rhs_transpose_tiles=(False, True), + ) + def test_mma_transposed_tiles(self, **kwargs): + if not kwargs["lhs_transpose_tiles"] and not kwargs["rhs_transpose_tiles"]: + self.skipTest("This is already tested in test_mma_basic") + self._basic_mma_test( + **kwargs, + k_steps=2, # Reducing to 1 can be helpful while debugging. + ) + + @parameterized.product( + lhs_transpose=(False, True), + rhs_transpose=(False, True), + m=(64, 128,), + n=(128, 256, 512), + lhs_swizzle=(32, 64, 128,), + rhs_swizzle=(32, 64, 128,), ) - def test_mma_basic( + def test_mma_different_swizzle(self, **kwargs): + if kwargs["lhs_swizzle"] == kwargs["rhs_swizzle"]: + self.skipTest("Swizzle is equal") + self._basic_mma_test( + in_jax_dtype=jnp.float16, + out_jax_dtype=jnp.float32, + swizzle=None, + k_steps=2, # Reducing to 1 can be helpful while debugging. + **kwargs, + ) + + def _basic_mma_test( self, m, n, @@ -930,19 +1432,29 @@ def test_mma_basic( rhs_transpose, in_jax_dtype, out_jax_dtype, - rhs_transpose_tiles, - lhs_transpose_tiles, + rhs_transpose_tiles=False, + lhs_transpose_tiles=False, + lhs_swizzle=None, + rhs_swizzle=None, ): - if out_jax_dtype == jnp.float16 and in_jax_dtype != jnp.float16: - self.skipTest("Only f16 input is supported for f16 output.") + if lhs_swizzle is None: + lhs_swizzle = swizzle + if rhs_swizzle is None: + rhs_swizzle = swizzle + swizzle = max(lhs_swizzle, rhs_swizzle) + if out_jax_dtype != jnp.float32 and ( + in_jax_dtype == jnp.float32 or in_jax_dtype == jnp.bfloat16 + ): + self.skipTest("Only f32 output is supported for f32 and bf16 input.") in_mlir_dtype = utils.dtype_to_ir_type(in_jax_dtype) swizzle_elems = swizzle // bytewidth(in_mlir_dtype) k = swizzle_elems * k_steps - lhs_tiling = rhs_tiling = (8, swizzle_elems) + lhs_tiling = (8, lhs_swizzle // bytewidth(in_mlir_dtype)) + rhs_tiling = (8, rhs_swizzle // bytewidth(in_mlir_dtype)) def kernel(ctx, lhs, rhs, out, scratch): - lhs_smem, rhs_smem, barriers, acc = scratch + lhs_smem, rhs_smem, barriers, mma_barrier, acc = scratch lhs_transform = (mgpu.TileTransform(lhs_tiling),) if lhs_transpose_tiles: lhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),) @@ -952,14 +1464,14 @@ def kernel(ctx, lhs, rhs, out, scratch): ctx.async_copy( src_ref=lhs, dst_ref=lhs_smem, - swizzle=swizzle, + swizzle=lhs_swizzle, gmem_transform=lhs_transform, barrier=barriers[0], ) ctx.async_copy( src_ref=rhs, dst_ref=rhs_smem, - swizzle=swizzle, + swizzle=rhs_swizzle, gmem_transform=rhs_transform, barrier=barriers[1], ) @@ -975,70 +1487,805 @@ def kernel(ctx, lhs, rhs, out, scratch): if rhs_transpose: rhs_smem = memref_transpose(rhs_smem, (1, 0, 3, 2)) tcgen05.mma( - acc, lhs_smem, rhs_smem, a_swizzle=swizzle, b_swizzle=swizzle, accumulate=False, + acc, lhs_smem, rhs_smem, a_swizzle=lhs_swizzle, b_swizzle=rhs_swizzle, accumulate=False, ) - tcgen05.commit_arrive(barriers[2]) - barriers[2].wait(for_tensor_core=True) - acc[:].store_untiled(out) - - in_finfo = jnp.finfo(in_jax_dtype) - exponent_bits, mantissa_bits = in_finfo.nexp, in_finfo.nmant - def quantize(x): - # Quantize the input to avoid rounding when feeding the TensorCore - return jax.lax.reduce_precision(x, exponent_bits, mantissa_bits) + tcgen05.commit_arrive(mma_barrier) + mma_barrier.wait(orders_tensor_core=True) + is_signed = True if jnp.issubdtype(in_jax_dtype, jnp.integer) else None + acc.load(is_signed=is_signed).store_untiled(out, optimized=False) x_shape = (k, m) if lhs_transpose else (m, k) - x = quantize(self.prng.uniform(-1, 1, x_shape)).astype(in_jax_dtype) + x = self.prng.uniform(-1, 1, x_shape).astype(in_jax_dtype) y_shape = (n, k) if rhs_transpose else (k, n) - y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(in_jax_dtype) + y = self.prng.uniform(-1, 1, y_shape).astype(in_jax_dtype) out_shape = jax.ShapeDtypeStruct((m, n), out_jax_dtype) + if y_shape[0] % rhs_tiling[0] != 0 or y_shape[1] % rhs_tiling[1] != 0: + self.skipTest("rhs tiling must divide y_shape") + rhs_smem_shape = tile_shape(y_shape, rhs_tiling) if rhs_transpose_tiles: rhs_smem_shape = ( - y_shape[1] // rhs_tiling[1], y_shape[0] // rhs_tiling[0], *rhs_tiling, + rhs_smem_shape[1], rhs_smem_shape[0], *rhs_smem_shape[2:] ) - else: - rhs_smem_shape = tile_shape(y_shape, rhs_tiling) + if x_shape[0] % lhs_tiling[0] != 0 or x_shape[1] % lhs_tiling[1] != 0: + self.skipTest("lhs tiling must divide x_shape") + lhs_smem_shape = tile_shape(x_shape, lhs_tiling) if lhs_transpose_tiles: lhs_smem_shape = ( - x_shape[1] // lhs_tiling[1], x_shape[0] // lhs_tiling[0], *lhs_tiling, + lhs_smem_shape[1], lhs_smem_shape[0], *lhs_smem_shape[2:] ) - else: - lhs_smem_shape = tile_shape(x_shape, lhs_tiling) scratch_shape = [ jax.ShapeDtypeStruct(lhs_smem_shape, in_jax_dtype), jax.ShapeDtypeStruct(rhs_smem_shape, in_jax_dtype), - mgpu.TMABarrier(3), - mgpu.TMEM((128, n), out_jax_dtype), + mgpu.TMABarrier(2), + mgpu.Barrier(1), + mgpu.TMEM((m, n), out_jax_dtype), ] z = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (x, y), out_shape, scratch_shape )(x, y) x32, y32 = x.astype(np.float32), y.astype(np.float32) ref = (x32.T if lhs_transpose else x32) @ (y32.T if rhs_transpose else y32) - atol = 2e-2 if out_jax_dtype == jnp.float16 else 5e-6 - np.testing.assert_allclose(z, ref, atol=atol) + atol = 2e-2 if out_jax_dtype == jnp.float16 else 2e-5 + rtol = 8e-4 if out_jax_dtype == jnp.float16 else 1e-7 + np.testing.assert_allclose(z, ref, atol=atol, rtol=rtol) @parameterized.product( - lhs_transpose=(False, True), - rhs_transpose=(False, True), - in_jax_dtype=(jnp.float16,), # TODO(apaszke): f32 - out_jax_dtype=(jnp.float32,), # TODO(apaszke): f16 accumulation - m=(256,), # TODO(apaszke): 64, 192, 256 - n=(128, 256, 512), # TODO(apaszke): 192, other non-power-of-2 - k_steps=(1, 2), - swizzle=(32, 64, 128,), + in_jax_dtype=(jnp.float16, jnp.bfloat16), # TODO(apaszke): f32 + out_jax_dtype=(jnp.float16, jnp.float32,), + m=(128,), # TODO(apaszke): 64, 192, 256 + n=(64, 160, 128, 256), ) - def test_mma_collective( - self, - m, - n, - k_steps, - swizzle, - lhs_transpose, - rhs_transpose, - in_jax_dtype, + def test_mma_lhs_tmem_float(self, m, n, in_jax_dtype, out_jax_dtype): + self._basic_mma_lhs_tmem_test( + m, n, in_jax_dtype, out_jax_dtype, tcgen05.LAYOUT, swizzle=128 + ) + + @parameterized.product( + in_jax_dtype=(jnp.int8, jnp.uint8), + out_jax_dtype=(jnp.int32,), + m=(128,), + n=(64, 128, 256), + ) + def test_mma_lhs_tmem_integer(self, m, n, in_jax_dtype, out_jax_dtype): + self._basic_mma_lhs_tmem_test( + m, n, in_jax_dtype, out_jax_dtype, fa.tmem_native_layout(vector_length=4), + swizzle=math.gcd(n, 128) + ) + + def _basic_mma_lhs_tmem_test( + self, m, n, in_jax_dtype, out_jax_dtype, lhs_layout, swizzle + ): + k_steps = 2 # Reducing to 1 can be helpful while debugging. + if out_jax_dtype == jnp.float16 and in_jax_dtype != jnp.float16: + self.skipTest("Only f16 input is supported for f16 output.") + + in_mlir_dtype = utils.dtype_to_ir_type(in_jax_dtype) + swizzle_elems = swizzle // bytewidth(in_mlir_dtype) + k = swizzle_elems * k_steps + rhs_tiling = (8, swizzle_elems) + + def kernel(ctx, lhs, rhs, out, scratch): + rhs_smem, barrier, mma_barrier, acc, lhs_tmem = scratch + ctx.async_copy( + src_ref=rhs, + dst_ref=rhs_smem, + swizzle=swizzle, + gmem_transform=mgpu.TileTransform(rhs_tiling), + barrier=barrier, + ) + barrier.wait() + if jnp.issubdtype(in_jax_dtype, jnp.integer): + is_signed = jnp.issubdtype(in_jax_dtype, jnp.signedinteger) + else: + is_signed = None + lhs_tmem.store( + fa.FragmentedArray.load_untiled( + lhs, layout=lhs_layout, is_signed=is_signed, optimized=False + ) + ) + tcgen05.commit_tmem() + with mgpu.single_thread(): + tcgen05.mma( + acc, lhs_tmem, rhs_smem, a_swizzle=swizzle, b_swizzle=swizzle, accumulate=False, + ) + tcgen05.commit_arrive(mma_barrier) + mma_barrier.wait(orders_tensor_core=True) + acc.load(is_signed=is_signed).store_untiled(out, optimized=False) + + x_shape = (m, k) + x = self.prng.uniform(-1, 1, x_shape).astype(in_jax_dtype) + y_shape = (k, n) + y = self.prng.uniform(-1, 1, y_shape).astype(in_jax_dtype) + out_shape = jax.ShapeDtypeStruct((m, n), out_jax_dtype) + if y_shape[0] % rhs_tiling[0] != 0 or y_shape[1] % rhs_tiling[1] != 0: + self.skipTest("rhs tiling must divide y_shape") + scratch_shape = [ + jax.ShapeDtypeStruct(tile_shape(y_shape, rhs_tiling), in_jax_dtype), + mgpu.TMABarrier(), + mgpu.Barrier(1), + mgpu.TMEM((128, n), out_jax_dtype), + mgpu.TMEM((128, k), in_jax_dtype, packing=4 // bytewidth(in_mlir_dtype)), + ] + z = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (x, y), out_shape, scratch_shape + )(x, y) + x32, y32 = x.astype(np.float32), y.astype(np.float32) + ref = x32 @ y32 + atol = 2e-2 if out_jax_dtype == jnp.float16 else 2e-5 + rtol = 8e-4 if out_jax_dtype == jnp.float16 else 1e-7 + np.testing.assert_allclose(z, ref, atol=atol, rtol=rtol) + + def test_tmem_copy_scales(self): + dtype = jnp.float8_e8m0fnu + + def kernel(ctx, src, out, scratch): + smem, barrier, tmem = scratch + ctx.async_copy(src_ref=src, dst_ref=smem, barrier=barrier) + barrier.wait() + with mgpu.single_thread(): + tcgen05.async_copy_scales_smem_to_tmem(smem, tmem) + tcgen05.commit_arrive(barrier) + barrier.wait(orders_tensor_core=True) + # We print as i32, because i8 seems to overflow the CUDA printf buffer and + # produce a truncated output. + tcgen05.TMEMRef( + tmem.address, + (128, 4), + ir.IntegerType.get_signless(32), + tcgen05.tmem_default_layout(), + )._debug_print() + copy(src, out) + + shape = (1, 1, 32, 16) + x = jax.lax.bitcast_convert_type( + np.arange(math.prod(shape), dtype=np.uint8).reshape(shape), dtype + ) + scratch_shape = [ + x, + mgpu.TMABarrier(1), + mgpu.TMEM((128, 4), dtype, layout=tcgen05.scales_layout()), + ] + with self.capture_stdout() as stdout: + mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, x, scratch_shape + )(x) + matches = 0 + for l in stdout().splitlines(): + if ":" not in l: + continue + idxs, value = l.split(":") + row, col = map(int, idxs[1:-1].split(",")) + base = (row % 32) * 16 + col * 4 + base %= 256 # int8 has very limited range + expected = base | (base + 1) << 8 | (base + 2) << 16 | (base + 3) << 24 + self.assertEqual(int(value), expected) + matches += 1 + self.assertEqual(matches, 128 * 4) + + def _sample_scales(self, m, k, n, block_size, scale_jax_dtype): + ka, kb = jax.random.split(jax.random.key(1234), 2) + if scale_jax_dtype == jnp.float8_e8m0fnu: + a_scales = jax.lax.bitcast_convert_type( + jax.random.randint(ka, (m, k // block_size), 122, 132, dtype=jnp.uint8), + scale_jax_dtype + ) + b_scales = jax.lax.bitcast_convert_type( + jax.random.randint(kb, (n, k // block_size), 122, 132, dtype=jnp.uint8), + scale_jax_dtype + ) + elif scale_jax_dtype == jnp.float8_e4m3fn: + a_scales = jnp.abs( + jax.random.normal(ka, (m, k // block_size), dtype=jnp.float32).astype( + scale_jax_dtype + ) + ) + b_scales = jnp.abs( + jax.random.normal(kb, (n, k // block_size), dtype=jnp.float32).astype( + scale_jax_dtype + ) + ) + else: + raise ValueError(f"Unsupported scale dtype: {scale_jax_dtype}") + return a_scales, b_scales + + @parameterized.product( + in_jax_dtype=(jnp.float8_e5m2, jnp.float8_e4m3fn, jnp.float4_e2m1fn), + scale_jax_dtype=(jnp.float8_e8m0fnu, jnp.float8_e4m3fn), + m=(128,), # TODO(apaszke): 256 + n=(128, 256), # TODO(apaszke): 192, other non-power-of-2 + swizzle=(32, 128), + ) + def test_mma_block_scaled_basic(self, m, n, in_jax_dtype, scale_jax_dtype, swizzle): + out_jax_dtype = jnp.float32 + # When swizzle is small, we need to take many steps to make it large enough + # to make the scale count a multiple of 4. + k_steps = 4 if swizzle == 32 else 2 + if scale_jax_dtype == jnp.float8_e8m0fnu: + block_size = 32 + elif scale_jax_dtype == jnp.float8_e4m3fn: + if in_jax_dtype != jnp.float4_e2m1fn: + self.skipTest("Only float4_e2m1fn input is supported for e4m3fn scale.") + block_size = 16 + else: + raise ValueError(f"Unsupported scale dtype: {scale_jax_dtype}") + if out_jax_dtype == jnp.float16 and in_jax_dtype != jnp.float16: + self.skipTest("Only f16 input is supported for f16 output.") + + in_mlir_dtype = utils.dtype_to_ir_type(in_jax_dtype) + swizzle_elems = 8 * swizzle // bitwidth(in_mlir_dtype) + k = swizzle_elems * k_steps + lhs_tiling = rhs_tiling = (8, swizzle_elems) + + def kernel(ctx, lhs, rhs, lhs_scales_gmem, rhs_scales_gmem, out, scratch): + lhs_smem, rhs_smem, lhs_scales_smem, rhs_scales_smem, barriers, mma_barrier, acc, lhs_scales, rhs_scales = scratch + operand_kwargs = dict( + swizzle=swizzle, + gmem_transform=mgpu.TileTransform(lhs_tiling), + ) + ctx.async_copy(src_ref=lhs, dst_ref=lhs_smem, barrier=barriers[0], **operand_kwargs) + ctx.async_copy(src_ref=rhs, dst_ref=rhs_smem, barrier=barriers[1], **operand_kwargs) + ctx.async_copy(src_ref=lhs_scales_gmem, dst_ref=lhs_scales_smem, barrier=barriers[2]) + ctx.async_copy(src_ref=rhs_scales_gmem, dst_ref=rhs_scales_smem, barrier=barriers[3]) + for i in range(4): + barriers[i].wait() + with mgpu.single_thread(): + tcgen05.async_copy_scales_smem_to_tmem(lhs_scales_smem, lhs_scales) + tcgen05.async_copy_scales_smem_to_tmem(rhs_scales_smem, rhs_scales) + tcgen05.mma( + acc, + lhs_smem, + mgpu.memref_transpose(rhs_smem, (1, 0, 3, 2)), + a_swizzle=swizzle, + b_swizzle=swizzle, + a_scale=lhs_scales, + b_scale=rhs_scales, + accumulate=False, + ) + tcgen05.commit_arrive(mma_barrier) + mma_barrier.wait(orders_tensor_core=True) + acc.load().store_untiled(out, optimized=False) + + x_shape = (m, k) + x = self.prng.uniform(-1, 1, x_shape).astype(in_jax_dtype) + y_shape = (n, k) + y = self.prng.uniform(-1, 1, y_shape).astype(in_jax_dtype) + out_shape = jax.ShapeDtypeStruct((m, n), out_jax_dtype) + scratch_shape = [ + jax.ShapeDtypeStruct(tile_shape(x_shape, lhs_tiling), in_jax_dtype), + jax.ShapeDtypeStruct(tile_shape(y_shape, rhs_tiling), in_jax_dtype), + jax.ShapeDtypeStruct((m // 128, k // (block_size * 4), 32, 16), scale_jax_dtype), + jax.ShapeDtypeStruct((n // 128, k // (block_size * 4), 32, 16), scale_jax_dtype), + mgpu.TMABarrier(4), + mgpu.Barrier(1), + mgpu.TMEM((m, n), out_jax_dtype), + mgpu.TMEM((m, k // block_size), scale_jax_dtype, layout=tcgen05.scales_layout()), + mgpu.TMEM((n, k // block_size), scale_jax_dtype, layout=tcgen05.scales_layout()), + ] + a_scales, b_scales = self._sample_scales(m, k, n, block_size, scale_jax_dtype) + def format_scales(scales): + mn, k = scales.shape + assert mn % 128 == 0 and k % 4 == 0, scales.shape + return ( + scales.reshape(mn // 128, 4, 32, k // 4, 4) + .transpose(0, 3, 2, 1, 4) + .reshape(mn // 128, k // 4, 32, 16) + ) + a_gpu_scales, b_gpu_scales = map(format_scales, (a_scales, b_scales)) + args = (x, y, a_gpu_scales, b_gpu_scales) + z = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), args, out_shape, scratch_shape + )(*args) + x32, y32 = x.astype(np.float32), y.astype(np.float32) + a_logical_scales = jnp.repeat(a_scales, block_size, axis=1).astype(jnp.float32) + b_logical_scales = jnp.repeat(b_scales, block_size, axis=1).astype(jnp.float32) + ref = (x32 * a_logical_scales) @ (y32 * b_logical_scales).T + np.testing.assert_allclose(z, ref, atol=2e-4, rtol=5e-6) + + @parameterized.product( + m=(256,), + n=(128, 256), + scale_jax_dtype=(jnp.float8_e8m0fnu, jnp.float8_e4m3fn), + ) + def test_mma_block_scaled_collective(self, m, n, scale_jax_dtype): + in_jax_dtype = jnp.float4_e2m1fn + out_jax_dtype = jnp.float32 + scale_block = 32 if scale_jax_dtype == jnp.float8_e8m0fnu else 16 + swizzle = 128 + k_steps = 2 + + in_mlir_dtype = utils.dtype_to_ir_type(in_jax_dtype) + swizzle_elems = 8 * swizzle // bitwidth(in_mlir_dtype) + k = swizzle_elems * k_steps + lhs_tiling = rhs_tiling = (8, swizzle_elems) + + def kernel(ctx, lhs, rhs, lhs_scales_gmem, rhs_scales_gmem, out, scratch): + ( + lhs_smem, rhs_smem, lhs_scales_smem, rhs_scales_smem, + barriers, mma_barrier, acc, lhs_scales, rhs_scales + ) = scratch + ctx.async_copy( + src_ref=lhs, + dst_ref=lhs_smem, + barrier=barriers[0], + swizzle=swizzle, + gmem_transform=mgpu.TileTransform(lhs_tiling), + collective=gpu.Dimension.x, + partitioned=0, + ) + ctx.async_copy( + src_ref=rhs, + dst_ref=rhs_smem, + barrier=barriers[1], + swizzle=swizzle, + gmem_transform=mgpu.TileTransform(rhs_tiling), + collective=gpu.Dimension.x, + partitioned=0, + ) + ctx.async_copy( + src_ref=lhs_scales_gmem, + dst_ref=lhs_scales_smem, + barrier=barriers[2], + collective=gpu.Dimension.x, + partitioned=0, + ) + # B scales are replicated! Note that this does not use 2CTA TMA and will + # need to be awaited in the non-leader CTA or else we will double arrive. + ctx.async_copy( + src_ref=rhs_scales_gmem, + dst_ref=rhs_scales_smem, + barrier=barriers[3], + collective=gpu.Dimension.x, + ) + + is_leader_thread = single_thread_predicate() + index = ir.IndexType.get() + block_id = gpu.cluster_block_id(gpu.Dimension.x) + is_first_block = arith.cmpi(arith.CmpIPredicate.eq, block_id, c(0, index)) + with when(arith.andi(is_first_block, is_leader_thread)): + for i in range(4): + barriers[i].wait() + tcgen05.async_copy_scales_smem_to_tmem(lhs_scales_smem, lhs_scales, collective=True) + tcgen05.async_copy_scales_smem_to_tmem(rhs_scales_smem, rhs_scales, collective=True) + tcgen05.mma( + acc, + lhs_smem, + mgpu.memref_transpose(rhs_smem, (1, 0, 3, 2)), + a_swizzle=swizzle, + b_swizzle=swizzle, + a_scale=lhs_scales, + b_scale=rhs_scales, + accumulate=False, + collective=True, + ) + tcgen05.commit_arrive(mma_barrier, collective=True, ctx=ctx) + mma_barrier.wait(orders_tensor_core=True) + m_block_tile = m // 2 + m_slice = ds(arith.muli(block_id, c(m_block_tile, index)), m_block_tile) + acc.load().store_untiled(memref_slice(out, m_slice), optimized=False) + + x_shape = (m, k) + x = self.prng.uniform(-1, 1, x_shape).astype(in_jax_dtype) + y_shape = (n, k) + y = self.prng.uniform(-1, 1, y_shape).astype(in_jax_dtype) + out_shape = jax.ShapeDtypeStruct((m, n), out_jax_dtype) + + m_block = m // 2 + n_block = n // 2 + + scratch_shape = [ + jax.ShapeDtypeStruct( + tile_shape((m_block, k), lhs_tiling), in_jax_dtype + ), + jax.ShapeDtypeStruct( + tile_shape((n_block, k), rhs_tiling), in_jax_dtype + ), + jax.ShapeDtypeStruct( + (m_block // 128, k // (scale_block * 4), 32, 16), scale_jax_dtype + ), + jax.ShapeDtypeStruct( + (n // 128, k // (scale_block * 4), 32, 16), scale_jax_dtype + ), + mgpu.TMABarrier(4), + mgpu.Barrier(1), + mgpu.TMEM((m_block, n), out_jax_dtype, collective=True), + mgpu.TMEM( + (m_block, k // scale_block), + scale_jax_dtype, + layout=tcgen05.scales_layout(), + collective=True, + ), + mgpu.TMEM( + (n, k // scale_block), + scale_jax_dtype, + layout=tcgen05.scales_layout(), + collective=True, + ), + ] + + a_scales, b_scales = self._sample_scales(m, k, n, scale_block, scale_jax_dtype) + + def format_scales(scales): + mn, k = scales.shape + assert mn % 128 == 0 and k % 4 == 0, scales.shape + return ( + scales.reshape(mn // 128, 4, 32, k // 4, 4) + .transpose(0, 3, 2, 1, 4) + .reshape(mn // 128, k // 4, 32, 16) + ) + + a_gpu_scales = format_scales(a_scales) + b_gpu_scales = format_scales(b_scales) + args = (x, y, a_gpu_scales, b_gpu_scales) + z = mgpu.as_gpu_kernel( + kernel, (2, 1, 1), (128, 1, 1), args, out_shape, scratch_shape, cluster=(2, 1, 1), + )(*args) + + x32, y32 = x.astype(np.float32), y.astype(np.float32) + a_logical_scales = jnp.repeat(a_scales, scale_block, axis=1).astype(jnp.float32) + b_logical_scales = jnp.repeat(b_scales, scale_block, axis=1).astype(jnp.float32) + ref = (x32 * a_logical_scales) @ (y32 * b_logical_scales).T + np.testing.assert_allclose(z, ref, atol=2e-4, rtol=5e-6) + + @parameterized.product( + lhs_transpose=(False, True), + rhs_transpose=(False, True), + in_jax_dtype=(jnp.float16, jnp.bfloat16, jnp.int8, jnp.float8_e4m3fn), + m=(128,), # TODO(apaszke): 256 + n=(128, 256), # TODO(apaszke): other non-power-of-2 + lhs_swizzle=(32, 64, 128), + rhs_swizzle=(64, 128), # 32 is too small and unsuported. + ) + def test_mma_sparse(self, m, n, in_jax_dtype, lhs_swizzle, rhs_swizzle, lhs_transpose, rhs_transpose): + if jnp.issubdtype(in_jax_dtype, jnp.floating): + out_jax_dtype = jnp.float32 + else: + out_jax_dtype = jnp.int32 + sparse_meta_dtype = jnp.uint2 + + in_mlir_dtype = utils.dtype_to_ir_type(in_jax_dtype) + k = 256 + lhs_tiling = (8, 8 * lhs_swizzle // bitwidth(in_mlir_dtype)) + rhs_tiling = (8, 8 * rhs_swizzle // bitwidth(in_mlir_dtype)) + + def kernel(ctx, lhs, rhs, lhs_sparse_gmem, out, scratch): + lhs_smem, rhs_smem, lhs_sparse_smem, barriers, mma_barrier, acc, lhs_sparse = scratch + ctx.async_copy(src_ref=lhs, dst_ref=lhs_smem, barrier=barriers[0], swizzle=lhs_swizzle, gmem_transform=mgpu.TileTransform(lhs_tiling)) + ctx.async_copy(src_ref=rhs, dst_ref=rhs_smem, barrier=barriers[1], swizzle=rhs_swizzle, gmem_transform=mgpu.TileTransform(rhs_tiling)) + ctx.async_copy(src_ref=lhs_sparse_gmem, dst_ref=lhs_sparse_smem, barrier=barriers[2]) + for i in range(3): + barriers[i].wait() + with mgpu.single_thread(): + tcgen05.async_copy_sparse_metadata_smem_to_tmem(lhs_sparse_smem, lhs_sparse) + if lhs_transpose: + lhs_smem = mgpu.memref_transpose(lhs_smem, (1, 0, 3, 2)) + if rhs_transpose: + rhs_smem = mgpu.memref_transpose(rhs_smem, (1, 0, 3, 2)) + tcgen05.mma( + acc, + lhs_smem, + rhs_smem, + a_swizzle=lhs_swizzle, + b_swizzle=rhs_swizzle, + a_sparse_metadata=lhs_sparse, + accumulate=False, + ) + tcgen05.commit_arrive(mma_barrier) + mma_barrier.wait(orders_tensor_core=True) + is_signed = True if jnp.issubdtype(in_jax_dtype, jnp.integer) else None + acc.load(is_signed=is_signed).store_untiled(out, optimized=False) + + x_shape = (k // 2, m) if lhs_transpose else (m, k // 2) + y_shape = (n, k) if rhs_transpose else (k, n) + if jnp.issubdtype(in_jax_dtype, jnp.integer): + x = jax.random.randint(jax.random.key(1234), x_shape, -64, 64, dtype=in_jax_dtype) + y = jax.random.randint(jax.random.key(2567), y_shape, -64, 64, dtype=in_jax_dtype) + else: + x = self.prng.uniform(-1, 1, x_shape).astype(in_jax_dtype) + y = self.prng.uniform(-1, 1, y_shape).astype(in_jax_dtype) + out_shape = jax.ShapeDtypeStruct((m, n), out_jax_dtype) + scratch_shape = [ + jax.ShapeDtypeStruct(tile_shape(x_shape, lhs_tiling), in_jax_dtype), + jax.ShapeDtypeStruct(tile_shape(y_shape, rhs_tiling), in_jax_dtype), + jax.ShapeDtypeStruct((m // 128, k // 128, 128, 64), sparse_meta_dtype), + mgpu.TMABarrier(3), + mgpu.Barrier(1), + mgpu.TMEM((m, n), out_jax_dtype), + mgpu.TMEM((m, k // 2), sparse_meta_dtype, layout=tcgen05.sparse_meta_layout()), + ] + index_pairs = np.asarray(np.meshgrid(range(4), range(4))).T.reshape(-1, 2) + valid_pairs = index_pairs[index_pairs[:, 0] < index_pairs[:, 1]] + assert len(valid_pairs) == 6 + x_pairs = jax.random.randint(jax.random.key(1234), (m, k // 4), 0, 6, dtype=jnp.uint8) + x_sparse = valid_pairs[x_pairs] + assert x_sparse.shape == (m, k // 4, 2) + def format_sparse_meta(meta): + mn, k, _2 = meta.shape + assert _2 == 2 + k *= 2 + if jnp.dtype(in_jax_dtype).itemsize == 1: + meta_tiled = ( + meta.reshape(mn // 128, 128, k // 64, 64).transpose(0, 2, 1, 3) + ) + else: + meta_tiled = ( + meta.reshape(mn // 128, 8, 2, 8, k // 64, 4, 2, 8) + .transpose(0, 4, 1, 6, 3, 5, 2, 7) + ) + return ( + meta_tiled.reshape(mn // 128, k // 64, 128, 64) + .astype(sparse_meta_dtype) + ) + x_gpu_sparse = format_sparse_meta(x_sparse) + args = (x, y, x_gpu_sparse) + z = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), args, out_shape, scratch_shape + )(*args) + if lhs_transpose: + x = x.T + if rhs_transpose: + y = y.T + x_logical = np.zeros_like(x, shape=(m, k // 4, 4)) + np.put_along_axis(x_logical, x_sparse, x.reshape(x_sparse.shape), axis=-1) + x_logical = x_logical.reshape(m, k) + ref = x_logical.astype(jnp.float32) @ y.astype(jnp.float32) + atol = 2e-2 if out_jax_dtype == jnp.float16 else 7e-5 + rtol = 8e-4 if out_jax_dtype == jnp.float16 else 5e-6 + np.testing.assert_allclose(z, ref, atol=atol, rtol=rtol) + + @parameterized.product( + in_jax_dtype=(jnp.float16, jnp.bfloat16), + m=(128,), # TODO(apaszke): 256 + n=(128, 256), # TODO(apaszke): other non-power-of-2 + lhs_swizzle=(32, 64, 128), + rhs_swizzle=(64, 128), # 32 is too small and unsuported. + ) + def test_mma_sparse_lhs_tmem( + self, m, n, in_jax_dtype, lhs_swizzle, rhs_swizzle + ): + out_jax_dtype = jnp.float32 + sparse_meta_dtype = jnp.uint2 + + in_mlir_dtype = utils.dtype_to_ir_type(in_jax_dtype) + k = 256 + rhs_tiling = (8, 8 * rhs_swizzle // bitwidth(in_mlir_dtype)) + + def kernel(ctx, lhs, rhs, lhs_sparse_gmem, out, scratch): + ( + rhs_smem, + lhs_sparse_smem, + barriers, + mma_barrier, + acc, + lhs_tmem, + lhs_sparse, + ) = scratch + ctx.async_copy( + src_ref=rhs, + dst_ref=rhs_smem, + barrier=barriers[0], + swizzle=rhs_swizzle, + gmem_transform=mgpu.TileTransform(rhs_tiling), + ) + ctx.async_copy( + src_ref=lhs_sparse_gmem, dst_ref=lhs_sparse_smem, barrier=barriers[1] + ) + barriers[0].wait() + barriers[1].wait() + lhs_tmem.store( + fa.FragmentedArray.load_untiled( + lhs, layout=tcgen05.LAYOUT, optimized=False + ) + ) + tcgen05.commit_tmem() + with mgpu.single_thread(): + tcgen05.async_copy_sparse_metadata_smem_to_tmem( + lhs_sparse_smem, lhs_sparse + ) + tcgen05.mma( + acc, + lhs_tmem, + rhs_smem, + a_swizzle=lhs_swizzle, + b_swizzle=rhs_swizzle, + a_sparse_metadata=lhs_sparse, + accumulate=False, + ) + tcgen05.commit_arrive(mma_barrier) + mma_barrier.wait(orders_tensor_core=True) + acc.load().store_untiled(out, optimized=False) + + x_shape = (m, k // 2) + x = self.prng.uniform(-1, 1, x_shape).astype(in_jax_dtype) + y_shape = (k, n) + y = self.prng.uniform(-1, 1, y_shape).astype(in_jax_dtype) + out_shape = jax.ShapeDtypeStruct((m, n), out_jax_dtype) + scratch_shape = [ + jax.ShapeDtypeStruct(tile_shape(y_shape, rhs_tiling), in_jax_dtype), + jax.ShapeDtypeStruct((m // 128, k // 128, 128, 64), sparse_meta_dtype), + mgpu.TMABarrier(2), + mgpu.Barrier(1), + mgpu.TMEM((m, n), out_jax_dtype), + mgpu.TMEM((m, k // 2), in_jax_dtype, packing=2), + mgpu.TMEM( + (m, k // 2), sparse_meta_dtype, layout=tcgen05.sparse_meta_layout() + ), + ] + index_pairs = np.asarray(np.meshgrid(range(4), range(4))).T.reshape(-1, 2) + valid_pairs = index_pairs[index_pairs[:, 0] < index_pairs[:, 1]] + assert len(valid_pairs) == 6 + x_pairs = jax.random.randint( + jax.random.key(1234), (m, k // 4), 0, 6, dtype=jnp.uint8 + ) + x_sparse = valid_pairs[x_pairs] + assert x_sparse.shape == (m, k // 4, 2) + + def format_sparse_meta(meta): + mn, k, _2 = meta.shape + assert _2 == 2 + k *= 2 + return ( + meta.reshape(mn // 128, 8, 2, 8, k // 64, 4, 2, 8) + .transpose(0, 4, 1, 6, 3, 5, 2, 7) + .reshape(mn // 128, k // 64, 128, 64) + .astype(sparse_meta_dtype) + ) + + x_gpu_sparse = format_sparse_meta(x_sparse) + args = (x, y, x_gpu_sparse) + z = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), args, out_shape, scratch_shape + )(*args) + x_logical = np.zeros_like(x, shape=(m, k // 4, 4)) + np.put_along_axis(x_logical, x_sparse, x.reshape(x_sparse.shape), axis=-1) + x_logical = x_logical.reshape(m, k) + ref = x_logical.astype(jnp.float32) @ y.astype(jnp.float32) + atol = 2e-2 if out_jax_dtype == jnp.float16 else 7e-5 + rtol = 8e-4 if out_jax_dtype == jnp.float16 else 5e-6 + np.testing.assert_allclose(z, ref, atol=atol, rtol=rtol) + + @parameterized.product( + in_jax_dtype=(jnp.float16, jnp.float8_e4m3fn), + m=(256,), # TODO(apaszke): 256 + n=(128, 256), # TODO(apaszke): other non-power-of-2 + lhs_swizzle=(32, 64, 128), + rhs_swizzle=(64, 128), # 32 is too small and unsupported. + ) + def test_mma_sparse_collective(self, m, n, in_jax_dtype, lhs_swizzle, rhs_swizzle): + out_jax_dtype = jnp.float32 + sparse_meta_dtype = jnp.uint2 + + in_mlir_dtype = utils.dtype_to_ir_type(in_jax_dtype) + k = 256 + lhs_tiling = (8, 8 * lhs_swizzle // bitwidth(in_mlir_dtype)) + rhs_tiling = (8, 8 * rhs_swizzle // bitwidth(in_mlir_dtype)) + if m // 2 < lhs_tiling[1]: + self.skipTest("LHS too small for this swizzle") + if n // 2 < rhs_tiling[1]: + self.skipTest("RHS too small for this swizzle") + + def kernel(ctx, lhs, rhs, lhs_sparse_gmem, out, scratch): + lhs_smem, rhs_smem, lhs_sparse_smem, barriers, mma_barrier, acc, lhs_sparse = scratch + ctx.async_copy( + src_ref=lhs, + dst_ref=lhs_smem, + barrier=barriers[0], + swizzle=lhs_swizzle, + gmem_transform=mgpu.TileTransform(lhs_tiling), + collective=gpu.Dimension.x, + partitioned=0, + ) + ctx.async_copy( + src_ref=rhs, + dst_ref=rhs_smem, + barrier=barriers[1], + swizzle=rhs_swizzle, + gmem_transform=mgpu.TileTransform(rhs_tiling), + collective=gpu.Dimension.x, + partitioned=1, + ) + ctx.async_copy( + src_ref=lhs_sparse_gmem, + dst_ref=lhs_sparse_smem, + barrier=barriers[2], + collective=gpu.Dimension.x, + partitioned=0, + ) + index = ir.IndexType.get() + block_id = gpu.cluster_block_id(gpu.Dimension.x) + is_first_block = arith.cmpi(arith.CmpIPredicate.eq, block_id, c(0, index)) + is_leader_thread = single_thread_predicate() + with when(arith.andi(is_first_block, is_leader_thread)): + for i in range(3): + barriers[i].wait() + tcgen05.async_copy_sparse_metadata_smem_to_tmem(lhs_sparse_smem, lhs_sparse, collective=True) + tcgen05.mma( + acc, + lhs_smem, + rhs_smem, + a_swizzle=lhs_swizzle, + b_swizzle=rhs_swizzle, + a_sparse_metadata=lhs_sparse, + accumulate=False, + collective=True, + ) + tcgen05.commit_arrive(mma_barrier, collective=True, ctx=ctx) + mma_barrier.wait(orders_tensor_core=True) + m_block_tile = m // 2 + m_slice = ds(arith.muli(block_id, c(m_block_tile, index)), m_block_tile) + acc.load().store_untiled(memref_slice(out, m_slice), optimized=False) + + x_shape = (m, k // 2) + y_shape = (k, n) + x = self.prng.uniform(-1, 1, x_shape).astype(in_jax_dtype) + y = self.prng.uniform(-1, 1, y_shape).astype(in_jax_dtype) + out_shape = jax.ShapeDtypeStruct((m, n), out_jax_dtype) + m_block = m // 2 + n_block = n // 2 + scratch_shape = [ + jax.ShapeDtypeStruct(tile_shape((m_block, k // 2), lhs_tiling), in_jax_dtype), + jax.ShapeDtypeStruct(tile_shape((k, n_block), rhs_tiling), in_jax_dtype), + jax.ShapeDtypeStruct((m_block // 128, k // 128, 128, 64), sparse_meta_dtype), + mgpu.TMABarrier(3), + mgpu.Barrier(1), + mgpu.TMEM((m_block, n), out_jax_dtype, collective=True), + mgpu.TMEM((m_block, k // 2), sparse_meta_dtype, layout=tcgen05.sparse_meta_layout(), collective=True), + ] + index_pairs = np.asarray(np.meshgrid(range(4), range(4))).T.reshape(-1, 2) + valid_pairs = index_pairs[index_pairs[:, 0] < index_pairs[:, 1]] + assert len(valid_pairs) == 6 + x_pairs = jax.random.randint(jax.random.key(1234), (m, k // 4), 0, 6, dtype=jnp.uint8) + x_sparse = valid_pairs[x_pairs] + assert x_sparse.shape == (m, k // 4, 2) + def format_sparse_meta(meta): + mn, k, _2 = meta.shape + assert _2 == 2 + k *= 2 + if jnp.dtype(in_jax_dtype).itemsize == 1: + meta_tiled = ( + meta.reshape(mn // 128, 128, k // 64, 64).transpose(0, 2, 1, 3) + ) + else: + meta_tiled = ( + meta.reshape(mn // 128, 8, 2, 8, k // 64, 4, 2, 8) + .transpose(0, 4, 1, 6, 3, 5, 2, 7) + ) + return ( + meta_tiled.reshape(mn // 128, k // 64, 128, 64) + .astype(sparse_meta_dtype) + ) + x_gpu_sparse = format_sparse_meta(x_sparse) + args = (x, y, x_gpu_sparse) + z = mgpu.as_gpu_kernel( + kernel, (2, 1, 1), (128, 1, 1), args, out_shape, scratch_shape, cluster=(2, 1, 1) + )(*args) + x_logical = np.zeros_like(x, shape=(m, k // 4, 4)) + np.put_along_axis(x_logical, x_sparse, x.reshape(x_sparse.shape), axis=-1) + x_logical = x_logical.reshape(m, k) + ref = x_logical.astype(jnp.float32) @ y.astype(jnp.float32) + atol = 2e-2 if out_jax_dtype == jnp.float16 else 7e-5 + rtol = 8e-4 if out_jax_dtype == jnp.float16 else 5e-6 + np.testing.assert_allclose(z, ref, atol=atol, rtol=rtol) + + @parameterized.product( + lhs_transpose=(False, True), + rhs_transpose=(False, True), + in_jax_dtype=(jnp.float16,), + out_jax_dtype=(jnp.float32,), + m=(128, 256), # TODO(apaszke): 192, 256 + n=(128, 160, 256), + swizzle=(32, 64, 128,), + ) + def test_mma_collective( + self, + m, + n, + swizzle, + lhs_transpose, + rhs_transpose, + in_jax_dtype, out_jax_dtype, ): + k_steps = 2 # Reducing to 1 can be helpful while debugging. if out_jax_dtype == jnp.float16 and in_jax_dtype != jnp.float16: raise self.skipTest("Only f16 input is supported for f16 output.") @@ -1052,7 +2299,7 @@ def test_mma_collective( tiling = (8, swizzle_elems) def kernel(ctx, lhs, rhs, out, scratch): - lhs_smem, rhs_smem, barriers, acc = scratch + lhs_smem, rhs_smem, barriers, mma_barrier, acc = scratch block_id = gpu.cluster_block_id(gpu.Dimension.x) ctx.async_copy( src_ref=lhs, @@ -1084,10 +2331,10 @@ def kernel(ctx, lhs, rhs, out, scratch): tcgen05.mma( acc, lhs_smem, rhs_smem, a_swizzle=swizzle, b_swizzle=swizzle, accumulate=False, collective=True ) - tcgen05.commit_arrive(barriers[2], collective=True, ctx=ctx) - barriers[2].wait(for_tensor_core=True) + tcgen05.commit_arrive(mma_barrier, collective=True, ctx=ctx) + mma_barrier.wait(orders_tensor_core=True) m_slice = ds(arith.muli(block_id, c(m_block_tile, index)), m_block_tile) - acc[:].store_untiled(memref_slice(out, m_slice)) + acc.load().store_untiled(memref_slice(out, m_slice), optimized=False) in_finfo = jnp.finfo(in_jax_dtype) exponent_bits, mantissa_bits = in_finfo.nexp, in_finfo.nmant @@ -1102,11 +2349,16 @@ def quantize(x): y_block_shape = (n_block_tile, k) if rhs_transpose else (k, n_block_tile) y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(in_jax_dtype) out_shape = jax.ShapeDtypeStruct((m, n), out_jax_dtype) + if any(s % t for s, t in zip(x_block_shape, tiling)): + self.skipTest("LHS block shape not divisible by tiling.") + if any(s % t for s, t in zip(y_block_shape, tiling)): + self.skipTest("RHS block shape not divisible by tiling.") scratch_shape = [ jax.ShapeDtypeStruct(tile_shape(x_block_shape, tiling), in_jax_dtype), jax.ShapeDtypeStruct(tile_shape(y_block_shape, tiling), in_jax_dtype), - mgpu.TMABarrier(3), - mgpu.TMEM((128, n), out_jax_dtype, collective=True), + mgpu.TMABarrier(2), + mgpu.Barrier(1), + mgpu.TMEM((m_block_tile, n), out_jax_dtype, collective=True), ] z = mgpu.as_gpu_kernel( kernel, (2, 1, 1), (128, 1, 1), (x, y), out_shape, scratch_shape, cluster=(2, 1, 1) @@ -1116,6 +2368,208 @@ def quantize(x): atol = 2e-2 if out_jax_dtype == jnp.float16 else 5e-6 np.testing.assert_allclose(z, ref, atol=atol) + @parameterized.product( + in_jax_dtype=(jnp.float16,), + out_jax_dtype=(jnp.float32,), + m=(256,), # TODO(apaszke): 64, 192, 256 + n=(128, 192, 224, 256,), + k_steps=(2,), # Note: reducing to 1 can be useful for debugging. + swizzle=(32, 64, 128,), + ) + def test_mma_collective_lhs_tmem( + self, + m, + n, + k_steps, + swizzle, + in_jax_dtype, + out_jax_dtype, + ): + if out_jax_dtype == jnp.float16 and in_jax_dtype != jnp.float16: + raise self.skipTest("Only f16 input is supported for f16 output.") + + in_mlir_dtype = utils.dtype_to_ir_type(in_jax_dtype) + m_block_tile = m // 2 + n_block_tile = n // 2 + swizzle_elems = swizzle // bytewidth(in_mlir_dtype) + k = swizzle_elems * k_steps + index = ir.IndexType.get() + + tiling = (8, swizzle_elems) + + def kernel(ctx, lhs, rhs, out, scratch): + lhs_smem, rhs_smem, barriers, mma_barrier, cluster_barrier, acc, lhs_tmem = scratch + block_id = gpu.cluster_block_id(gpu.Dimension.x) + ctx.async_copy( + src_ref=lhs, + dst_ref=lhs_smem, + swizzle=swizzle, + gmem_transform=mgpu.TileTransform(tiling), + barrier=barriers[0], + collective=gpu.Dimension.x, + partitioned=0, # Split non-contracting dim. + ) + ctx.async_copy( + src_ref=rhs, + dst_ref=rhs_smem, + swizzle=swizzle, + gmem_transform=mgpu.TileTransform(tiling), + barrier=barriers[1], + collective=gpu.Dimension.x, + partitioned=1, # Split non-contracting dim. + ) + + is_leader_thread = single_thread_predicate() + is_first_block = arith.cmpi(arith.CmpIPredicate.eq, block_id, c(0, index)) + + with when(arith.andi(is_first_block, is_leader_thread)): + barriers[0].wait() + gpu.barrier() + # Because only block 1 waits on the TMA, we need a cluster barrier so + # that the SMEM updates are visible on block 2. + cluster_barrier.arrive(orders_tensor_core=True) + cluster_barrier.wait(orders_tensor_core=True) + lhs_tmem.store( + fa.FragmentedArray.load_tiled( + lhs_smem, swizzle, layout=tcgen05.LAYOUT + ) + ) + tcgen05.commit_tmem() + # Make sure TMEM has been loaded on both blocks. + cluster_barrier.arrive(orders_tensor_core=True) + cluster_barrier.wait(orders_tensor_core=True) + with when(arith.andi(is_first_block, is_leader_thread)): + barriers[1].wait() + tcgen05.mma( + acc, + lhs_tmem, + rhs_smem, + a_swizzle=swizzle, + b_swizzle=swizzle, + accumulate=False, + collective=True, + ) + tcgen05.commit_arrive(mma_barrier, collective=True, ctx=ctx) + mma_barrier.wait(orders_tensor_core=True) + m_slice = ds(arith.muli(block_id, c(m_block_tile, index)), m_block_tile) + acc.load().store_untiled(memref_slice(out, m_slice), optimized=False) + + in_finfo = jnp.finfo(in_jax_dtype) + exponent_bits, mantissa_bits = in_finfo.nexp, in_finfo.nmant + + def quantize(x): + # Quantize the input to avoid rounding when feeding the TensorCore + return jax.lax.reduce_precision(x, exponent_bits, mantissa_bits) + + x_shape = (m, k) + x_block_shape = (m_block_tile, k) + x = quantize(self.prng.uniform(-1, 1, x_shape)).astype(in_jax_dtype) + y_shape = (k, n) + y_block_shape = (k, n_block_tile) + y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(in_jax_dtype) + out_shape = jax.ShapeDtypeStruct((m, n), out_jax_dtype) + if any(s % t for s, t in zip(x_block_shape, tiling)): + self.skipTest("LHS block shape not divisible by tiling.") + if any(s % t for s, t in zip(y_block_shape, tiling)): + self.skipTest("RHS block shape not divisible by tiling.") + scratch_shape = [ + jax.ShapeDtypeStruct(tile_shape(x_block_shape, tiling), in_jax_dtype), + jax.ShapeDtypeStruct(tile_shape(y_block_shape, tiling), in_jax_dtype), + mgpu.TMABarrier(2), + mgpu.Barrier(1), + mgpu.ClusterBarrier(collective_dims=(gpu.Dimension.x,)), + mgpu.TMEM((128, n), out_jax_dtype, collective=True), + mgpu.TMEM((128, k), in_jax_dtype, collective=True, packing=2), + ] + z = mgpu.as_gpu_kernel( + kernel, + (2, 1, 1), + (128, 1, 1), + (x, y), + out_shape, + scratch_shape, + cluster=(2, 1, 1), + )(x, y) + x32, y32 = x.astype(np.float32), y.astype(np.float32) + ref = x32 @ y32 + atol = 2e-2 if out_jax_dtype == jnp.float16 else 5e-6 + np.testing.assert_allclose(z, ref, atol=atol) + + def test_raises_error_if_tmem_oom(self): + def kernel(ctx, input, output, scratch): + del ctx, input, output, scratch + + x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128) + scratch_shape = [ + mgpu.TMEM((128, 384), jnp.float32), # Should round up to 512 columns. + mgpu.TMEM((128, 64), jnp.float32), # Will trigger OOM. + ] + with self.assertRaisesRegex(ValueError, + "Total TMEM allocation exceeds memory limit."): + mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, x, scratch_shape + )(x).block_until_ready() + + def test_raises_error_if_collective_tmem_without_cluster(self): + def kernel(ctx, input, output, scratch): + del ctx, input, output, scratch + + x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128) + scratch_shape = [mgpu.TMEM((128, 384), jnp.float32, collective=True)] + with self.assertRaisesRegex( + ValueError, + "Collective TMEM allocations are only supported for clusters with an" + " even number of blocks in them.", + ): + mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, x, scratch_shape + )(x).block_until_ready() + + @parameterized.parameters((0,), (1,), (2,)) + def test_cluster_launch_control(self, dim): + # Let's say we have 148 SMs in our gpu. We attempt to schedule 149 blocks on + # 148 SMs. Only one SM will succeed in stealing the 149th block, and the + # others will fail. Therefore we test that there is exactly 1 stolen block + # and the others fail and return -1. + num_sms = jax.devices()[0].core_count + num_blocks = num_sms + 1 + grid = [1, 1, 1] + grid[dim] = num_blocks + + def kernel(ctx, out, scratch): + del ctx + cancel_result_ref, barrier, _ = scratch + + is_leader_thread = single_thread_predicate() + barrier.arrive_expect_tx(16, predicate=is_leader_thread) + mgpu.try_cluster_cancel(cancel_result_ref, barrier, is_leader_thread) + + barrier.wait() + *cta_ids, cancelled_launch = mgpu.query_cluster_cancel(cancel_result_ref) + cta_id = arith.addi(cta_ids[0], arith.addi(cta_ids[1], cta_ids[2])) + + # Store a sentinel value if no work can be scheduled. + idx = arith.index_cast(ir.IndexType.get(), utils.block_idx()) + sentinel_val = arith.constant(ir.IntegerType.get_signless(32), -1) + + value = arith.select(cancelled_launch, cta_id, sentinel_val) + memref.store(value, out, [idx]) + + cancel_result_ref = jax.ShapeDtypeStruct((16,), jnp.int8) # 128 bits + out_ty = jax.ShapeDtypeStruct((num_sms,), jnp.int32) + scratch = ( + cancel_result_ref, + mgpu.Barrier(1), + # Requesting SMEM close to the 228kb limit to ensure that each SM only + # schedules 1 block. + jax.ShapeDtypeStruct((220 * 1024,), jnp.int8), + ) + out = mgpu.as_gpu_kernel(kernel, grid, (128, 1, 1), (), out_ty, scratch)() + + out = np.sort(out) + out_ref = np.array([-1] * (num_sms - 1) + [num_sms]) + np.testing.assert_array_equal(out, out_ref) + class BarrierTest(TestCase): @@ -1140,7 +2594,7 @@ def kernel(ctx, dst, scratch): final_arr = arr + mgpu.FragmentedArray.load_strided( tmp, is_signed=False ) - final_arr.store_untiled(memref_slice(dst, 0)) + final_arr.store_untiled(memref_slice(dst, 0), optimized=False) scf.yield_([]) with ir.InsertionPoint(scf.IfOp(is_second_wg).then_block): barriers[0].wait() @@ -1151,7 +2605,7 @@ def kernel(ctx, dst, scratch): barriers[2].wait() # Synchronize this warpgroup before we overwrite tmp. arr.store_untiled(tmp) barriers[1].arrive() # Signal that tmp is ready. - final_arr.store_untiled(memref_slice(dst, 1)) + final_arr.store_untiled(memref_slice(dst, 1), optimized=False) scf.yield_([]) out_shape = jax.ShapeDtypeStruct((2, 128), jnp.int32) y = mgpu.as_gpu_kernel( @@ -1193,7 +2647,7 @@ def test_collective_arrive(self, collective_dims, noncollective_dims, collective cluster[d] = collective_size for d in noncollective_dims: cluster[d] = 2 - if math.prod(cluster) > 16: + if math.prod(cluster) > jtu.get_cuda_nonportable_max_cluster_size(): self.skipTest("Cluster too big") is_trivial = math.prod(cluster[d] for d in collective_dims) == 1 def kernel(ctx, dst, mask, collective_barrier): @@ -1237,8 +2691,55 @@ def kernel(ctx, dst, mask, collective_barrier): expected_mask |= np.bitwise_or.reduce(mask_bits, axis=None) self.assertEqual(min(mask), expected_mask) + def test_collective_arrival_count(self): + i32 = ir.IntegerType.get_signless(32) + cluster = [2, 1, 1] + def kernel(ctx, dst, collective_barrier): + collective_barrier.arrive() + collective_barrier.arrive() + collective_barrier.arrive() + collective_barrier.arrive() + collective_barrier.wait() + memref.store(arith.constant(i32, 1), dst, []) + out_shape = jax.ShapeDtypeStruct((), jnp.int32) + scratch = mgpu.ClusterBarrier((gpu.Dimension.x,), arrival_count=4) + y = mgpu.as_gpu_kernel( + kernel, cluster, (128, 1, 1), (), out_shape, scratch, cluster=cluster, + )() + np.testing.assert_array_equal(y, np.ones((), dtype=np.int32)) -class TMATest(TestCase): + @parameterized.parameters(False, True) + def test_mbarrier_complete_tx(self, predicated): + i32 = ir.IntegerType.get_signless(32) + + def kernel(ctx, dst, mbar): + mbar.arrive_expect_tx(1024 // 128) + + if predicated: + is_leader = mgpu.single_thread_predicate(mgpu.ThreadSubset.BLOCK) + mbar.complete_tx(1024, predicate=is_leader) + else: + mbar.complete_tx(1024 // 128) + + mbar.wait() + + with mgpu.single_thread(scope=mgpu.ThreadSubset.BLOCK): + memref.store(arith.constant(i32, 1), dst, []) + + out_shape = jax.ShapeDtypeStruct((), jnp.int32) + y = mgpu.as_gpu_kernel( + kernel, + (1, 1, 1), + (128, 1, 1), + (), + out_shape, + mgpu.Barrier(arrival_count=128), + )() + + np.testing.assert_array_equal(y, np.array(1, dtype=np.int32)) + + +class AsyncCopyTest(TestCase): @parameterized.product( swizzle=(None, 32, 64, 128), @@ -1260,6 +2761,138 @@ def kernel(ctx, src, dst, smem): y = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, smem)(x) np.testing.assert_array_equal(y, x) + @parameterized.product( + swizzle=(None, 32, 64, 128), + shape=((64, None), (5, None), (2, 3, 5, None)), + dtype=(jnp.float32, jnp.float16, jnp.int4), + ) + def test_tma_prefetch_basic(self, swizzle, shape, dtype): + bw = bitwidth(dtype_to_ir_type(dtype)) + minor_size = 64 if swizzle is None else 8 * swizzle // bw + shape = (*shape[:-1], minor_size) + i1 = ir.IntegerType.get_signless(1) + def kernel(ctx, src, dst, smem): + tmp, barrier = smem + ctx.async_prefetch(gmem_ref=src, swizzle=swizzle) + ctx.async_copy(src_ref=src, dst_ref=tmp, swizzle=swizzle, barrier=barrier) + barrier.wait_parity(c(0, i1)) + copy(tmp, dst, swizzle=swizzle) + x = np.arange(np.prod(shape), dtype=dtype).reshape(shape) + smem = (x, mgpu.TMABarrier()) + y = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, smem)(x) + np.testing.assert_array_equal(y, x) + + @parameterized.product( + swizzle=(16, 32, 64, 128), + shape=((64, None),), + dtype=(jnp.int32, jnp.int16), + idx_dtype=(jnp.int32, jnp.int8), + ) + def test_tma_gather_basic(self, swizzle, shape, dtype, idx_dtype): + if not jtu.is_cuda_compute_capability_at_least("10.0"): + self.skipTest("TMA gather requires CUDA compute capability 10.0 or higher") + i1 = ir.IntegerType.get_signless(1) + swizzle_elems = 8 * swizzle // bitwidth(dtype_to_ir_type(dtype)) + col_slice = swizzle_elems if swizzle != 16 else 128 + shape = (*shape[:-1], 2 * col_slice) + def kernel(ctx, src, idx, dst, smem): + tmp, barrier = smem + idxs = mgpu.FragmentedArray.load_untiled( + idx, layout=fa.TMA_GATHER_INDICES_LAYOUT, optimized=False, is_signed=False + ) + ctx.async_copy( + src_ref=src, + dst_ref=tmp, + swizzle=swizzle, + barrier=barrier, + gmem_slice=(idxs, mgpu.ds(col_slice, col_slice)), + ) + barrier.wait_parity(c(0, i1)) + copy(tmp, dst, swizzle=swizzle) + x = np.arange(np.prod(shape), dtype=dtype).reshape(shape) + idx = jax.random.permutation(jax.random.key(1234), 48).astype(idx_dtype) + out_type = jax.ShapeDtypeStruct((len(idx), col_slice), dtype) + smem = (out_type, mgpu.TMABarrier()) + y = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (x, idx), out_type, smem, + )(x, idx) + np.testing.assert_array_equal(y, x[idx, slice(col_slice, 2 * col_slice)]) + + @parameterized.product( + swizzle=(16, 32, 64, 128), + shape=((64, None),), + dtype=(jnp.int32, jnp.int16), + transpose_tiles=(False, True), + ) + def test_tma_gather_tiled(self, swizzle, shape, dtype, transpose_tiles): + if not jtu.is_cuda_compute_capability_at_least("10.0"): + self.skipTest("TMA gather requires CUDA compute capability 10.0 or higher") + i1 = ir.IntegerType.get_signless(1) + swizzle_elems = 8 * swizzle // bitwidth(dtype_to_ir_type(dtype)) + col_slice = swizzle_elems if swizzle != 16 else 128 + shape = (*shape[:-1], 3 * col_slice) + # Using (8, swizzle_elems) produces too short transfers (we'd end up with + # misaligned SMEM addresses). + tiling = (8, swizzle_elems) if swizzle != 16 else (8, 2 * swizzle_elems) + if transpose_tiles: + transforms = (mgpu.TileTransform(tiling), mgpu.TransposeTransform((1, 0, 2, 3))) + else: + transforms = mgpu.TileTransform(tiling) + def kernel(ctx, src, idx, dst, smem): + tmp, barrier = smem + idxs = mgpu.FragmentedArray.load_untiled( + idx, layout=fa.TMA_GATHER_INDICES_LAYOUT, optimized=False, is_signed=False + ) + ctx.async_copy( + src_ref=src, + dst_ref=tmp, + swizzle=swizzle, + barrier=barrier, + gmem_slice=(idxs, mgpu.ds(col_slice, 2 * col_slice)), + gmem_transform=transforms, + ) + barrier.wait_parity(c(0, i1)) + ctx.async_copy( + src_ref=tmp, + dst_ref=dst, + swizzle=swizzle, + gmem_transform=transforms, + ) + ctx.await_async_copy(0) + x = np.arange(np.prod(shape), dtype=dtype).reshape(shape) + idx = jax.random.permutation(jax.random.key(1234), 48).astype(jnp.int32) + out_type = jax.ShapeDtypeStruct((len(idx), 2 * col_slice), dtype) + smem_shape = tile_shape((len(idx), 2 * col_slice), tiling) + if transpose_tiles: + smem_shape = (smem_shape[1], smem_shape[0], *smem_shape[2:]) + smem = ( + jax.ShapeDtypeStruct(smem_shape, dtype), + mgpu.TMABarrier(), + ) + y = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (x, idx), out_type, smem, + )(x, idx) + np.testing.assert_array_equal(y, x[idx, slice(col_slice, 3 * col_slice)]) + + def test_tma_with_1d_tiling(self): + swizzle = 128 + dtype = jnp.float16 + shape = (64, 128) + tiling = (1, swizzle // jnp.dtype(dtype).itemsize) + def kernel(ctx, dst, smem): + iota_tensor(*shape, dtype=dtype).store_tiled(smem, swizzle=swizzle) + ctx.async_copy( + src_ref=smem, + dst_ref=dst, + swizzle=swizzle, + gmem_transform=mgpu.TileTransform(tiling), + ) + ctx.await_async_copy(0) + x = np.arange(np.prod(shape), dtype=dtype).reshape(shape) + smem = jax.ShapeDtypeStruct(utils.tile_shape(shape, tiling), dtype) + y = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), (), x, smem)() + np.testing.assert_array_equal(y, x) + @parameterized.named_parameters( ( f"_{''.join(map(str, collective_dims))}={collective_size}{'_' + ''.join(map(str, noncollective_dims)) if noncollective_dims else ''}", @@ -1285,7 +2918,7 @@ def test_tma_load_multicast(self, collective_dims, noncollective_dims, collectiv cluster[d] = collective_dim_size for d in noncollective_dims: cluster[d] = 2 - if math.prod(cluster) > 16: + if math.prod(cluster) > jtu.get_cuda_nonportable_max_cluster_size(): self.skipTest("Cluster too big") collective_size = math.prod(cluster[d] for d in collective_dims) noncollective_size = math.prod(cluster) // collective_size @@ -1380,6 +3013,56 @@ def kernel(ctx, src, dst, scratch): y = f(x) np.testing.assert_array_equal(y, x) + @parameterized.product( + swizzle=(None, 128), + shape=((128, 128), (5, 32, 128)), + dtype=(jnp.float16, jnp.float32), + ) + @jtu.thread_unsafe_test() + def test_tma_prefetch_tiled(self, swizzle, shape, dtype): + # TODO(apaszke): ptxas seems to freeze when generating code for copy with + # swizzle 32 and 64. + i1 = ir.IntegerType.get_signless(1) + index = ir.IndexType.get() + tiling = (32, (swizzle or 128) // jnp.dtype(dtype).itemsize) + tiled_shape = tile_shape(shape, tiling)[:len(shape)] + def kernel(ctx, src, dst, scratch): + tmp, barrier = scratch + ctx.async_prefetch( + gmem_ref=src, swizzle=swizzle, gmem_transform=mgpu.TileTransform(tiling) + ) + ctx.async_copy( + src_ref=src, + dst_ref=tmp, + swizzle=swizzle, + barrier=barrier, + gmem_transform=mgpu.TileTransform(tiling), + ) + barrier.wait_parity(c(0, i1)) + for idxs in np.ndindex(tiled_shape): + untiled_idxs, tiled_idxs = idxs[:-len(tiling)], idxs[-len(tiling):] + s = ( + *untiled_idxs, + *(ds(c(ix * t, index), t) for ix, t in zip(tiled_idxs, tiling)), + ) + copy(memref_slice(tmp, idxs), memref_slice(dst, s), swizzle=swizzle) + x = np.arange(np.prod(shape), dtype=dtype).reshape(shape) + smem = ( + jax.ShapeDtypeStruct(tile_shape(shape, tiling), dtype), + mgpu.TMABarrier(), + ) + env_vars = { + "MOSAIC_GPU_DUMP_HOST_LLVM": "1", + "MOSAIC_GPU_DUMP_PTX": "1", + } + with jtu.set_env(**env_vars), self.capture_stdout() as ptx_llvm_ir: + f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, smem) + y = f(x) + # We should only create one descriptor for both prefetch and copy. + self.assertEqual(ptx_llvm_ir().count("call void @mosaic_gpu_init_tma_desc("), 1) + self.assertIn("cp.async.bulk.prefetch.tensor", ptx_llvm_ir()) + np.testing.assert_array_equal(y, x) + @parameterized.product(swizzle=(None, 128)) def test_tma_load_tiled_rounding(self, swizzle): # TODO(apaszke): ptxas seems to freeze when generating code for copy with @@ -1572,16 +3255,70 @@ def kernel(ctx, src, dst, tmp): ctx.await_async_copy(0) def run_kernel(shape): - x = np.arange(np.prod(shape)).reshape(shape) + x = np.arange(np.prod(shape), dtype=np.int32).reshape(shape) _ = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x) with self.assertRaisesRegex(ValueError, "all GMEM strides except the last"): run_kernel([1] * 6) - with self.assertRaisesRegex( - ValueError, "last dimension to be divisible by 16" - ): - run_kernel([23]) + with self.assertRaisesRegex( + ValueError, "last dimension to be divisible by 128" + ): + run_kernel([23]) + + @parameterized.product( + swizzle=(16, 32, 64, 128), + shape=((64, 128), (128, 32)), + dtype=(jnp.float32, jnp.float16, jnp.float8_e5m2, jnp.int4), + ) + def test_cp_async(self, swizzle, shape, dtype): + bw = bitwidth(dtype_to_ir_type(dtype)) + swizzle_elems = 8 * swizzle // bw + tiling = (8, swizzle_elems) + if shape[-1] < swizzle_elems: + self.skipTest("Minor dimension too small") + minor_size = 64 if swizzle is None else swizzle_elems + shape = (*shape[:-1], minor_size) + def kernel(ctx, src, dst, tmp): + ctx.async_copy( + src_ref=src, + dst_ref=tmp, + swizzle=swizzle, + gmem_transform=mgpu.TileTransform(tiling), + implementation=mgpu.AsyncCopyImplementation.CP_ASYNC, + ) + ctx.await_cp_async_copy(0) + mgpu.copy_tiled(tmp, dst, swizzle=swizzle) + x = np.arange(np.prod(shape), dtype=dtype).reshape(shape) + smem = jax.ShapeDtypeStruct(mgpu.tile_shape(shape, tiling), dtype) + y = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, smem)(x) + np.testing.assert_array_equal(y, x) + + def test_tma_collective_async_cp_with_no_swizzle(self): + def body(ctx, src, dst, scratch): + tmp, barrier = scratch + ctx.async_copy( + src_ref=src, dst_ref=tmp, collective=gpu.Dimension.x, barrier=barrier + ) + barrier.wait() + block_id = gpu.cluster_block_id(gpu.Dimension.x) + ctx.async_copy(src_ref=tmp, dst_ref=dst, gmem_slice=block_id) + + dtype = jnp.float32 + kernel = mgpu.as_gpu_kernel( + body, + grid=(2, 1, 1), + cluster=(2, 1, 1), + block=(128, 1, 1), + in_shape=jax.ShapeDtypeStruct((128,), dtype), + out_shape=jax.ShapeDtypeStruct((2, 128), dtype), + smem_scratch_shape=[ + jax.ShapeDtypeStruct((128,), dtype), + mgpu.TMABarrier(), + ], + ) + x = jnp.arange(128, dtype=jnp.float32) + np.testing.assert_array_equal(kernel(x), jnp.stack([x, x], axis=0)) class FragmentedArrayTest(TestCase): @@ -1591,12 +3328,13 @@ class FragmentedArrayTest(TestCase): operator.add, operator.mul, operator.sub, - (lambda x, y: mgpu.FragmentedArray.min(x, y), np.minimum), - (lambda x, y: mgpu.FragmentedArray.max(x, y), np.maximum), + (mgpu.FragmentedArray.min, np.minimum), + (mgpu.FragmentedArray.max, np.maximum), ), - dtype=[jnp.float32, jnp.int32, jnp.uint32], + # TODO(apaszke): Enable float8 + dtype=[jnp.float32, jnp.int32, jnp.uint32, jnp.float16], m=(64, 128), - n=(8, 16, 32, 64, 80, 128, 256), + n=(8, 64, 256), ) @jtu.ignore_warning( message="(invalid value|divide by zero)", category=RuntimeWarning @@ -1608,18 +3346,24 @@ def test_binary(self, op, dtype, m=64, n=32): np_op = op for scalar_rhs in [None, 2]: - def kernel(ctx, dst, _): + def kernel(ctx, lhs, rhs, dst, _): mlir_dtype = utils.dtype_to_ir_type(dtype) - iota = iota_tensor(m, n, dtype) - rhs = iota if scalar_rhs is None else c(scalar_rhs, mlir_dtype) - op(iota, rhs).store_untiled(dst) - out_shape = jax.ShapeDtypeStruct((m, n), dtype) + lhs = mgpu.FragmentedArray.load_strided(lhs, is_signed=utils.is_signed(dtype)) + if scalar_rhs is None: + rhs = mgpu.FragmentedArray.load_strided(rhs, is_signed=utils.is_signed(dtype)) + else: + rhs = c(scalar_rhs, mlir_dtype) + op(lhs, rhs).store_untiled(dst, optimized=False) + if jnp.issubdtype(dtype, jnp.floating): + x = self.prng.uniform(-1, 1, (m, n)).astype(dtype) + y = self.prng.uniform(-1, 1, (m, n)).astype(dtype) + else: + x = self.prng.integers(-16000, 16000, (m, n)).astype(dtype) + y = self.prng.integers(-16000, 16000, (m, n)).astype(dtype) result = mgpu.as_gpu_kernel( - kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () - )() - ref_x = np.arange(m * n, dtype=dtype).reshape(m, n) - ref_rhs = scalar_rhs or ref_x - np.testing.assert_array_equal(result, np_op(ref_x, ref_rhs)) + kernel, (1, 1, 1), (128, 1, 1), (x, y), x, () + )(x, y) + np.testing.assert_array_equal(result, np_op(x, scalar_rhs or y)) def test_minimum_np_compatibility(self): one = np.ones((128, 128)).astype(np.float32) @@ -1658,7 +3402,7 @@ def test_division(self, op, dtype, m=64, n=32): def kernel(ctx, dst, _): iota = iota_tensor(m, n, dtype) - op(dtype(4.2).item() * iota, iota + 1).store_untiled(dst) + op(dtype(4.2).item() * iota, iota + 1).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), dtype) result = mgpu.as_gpu_kernel( @@ -1688,22 +3432,46 @@ def kernel(ctx, dst, _): rhs = 0 if rhs_is_literal else iota + 1 res = op(iota, rhs) assert not res.is_signed - res.astype(i8, is_signed=False).store_untiled(dst) + res.astype(i8, is_signed=False).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), jnp.int8) result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () )() iota = np.arange(m * n, dtype=dtype).reshape(m, n) - rhs = rhs = 0 if rhs_is_literal else iota + 1 + rhs = 0 if rhs_is_literal else iota + 1 np.testing.assert_array_equal(result, op(iota, rhs).astype(jnp.int8)) + def test_foreach_wgmma_row_array(self): + def kernel(ctx, out, smem): + del ctx, smem + x = iota_tensor(128, 128, jnp.float32) + row = x.reduce("add", 1) + # Test returning an array + row = row.foreach( + lambda x, _: arith.addf(x, c(1, row.mlir_dtype)), create_array=True + ) + # Test no array return + @row.foreach + def _(v, idx): + memref.store(v, out, idx) + + result = mgpu.as_gpu_kernel( + kernel, + grid=(1, 1, 1), + block=(128, 1, 1), + in_shape=(), + out_shape=jax.ShapeDtypeStruct(shape=(128,), dtype=jnp.float32), + smem_scratch_shape=(), + )() + iota = np.arange(128 * 128, dtype=jnp.float32).reshape(128, 128) + np.testing.assert_array_equal(result, iota.sum(axis=1) + 1) + def test_foreach(self): dtype = jnp.int32 swizzle = 128 - tile = 64, swizzle // jnp.dtype(dtype).itemsize + tiling = (8, swizzle // jnp.dtype(dtype).itemsize) shape = 128, 192 - tiled_shape = mgpu.tile_shape(shape, tile) mlir_dtype = utils.dtype_to_ir_type(dtype) cst = 9999 def causal(val, idx): @@ -1711,12 +3479,16 @@ def causal(val, idx): mask = arith.cmpi(arith.CmpIPredicate.uge, row, col) return arith.select(mask, val, c(cst, mlir_dtype)) - tiling = mgpu.TileTransform(tile) def kernel(ctx, dst, smem): x = iota_tensor(shape[0], shape[1], dtype) - x.foreach(causal, create_array=True, is_signed=False).store_untiled(smem) + x.foreach(causal, create_array=True, is_signed=False).store_tiled(smem, swizzle=128) mgpu.commit_shared() - ctx.async_copy(src_ref=smem, dst_ref=dst) + ctx.async_copy( + src_ref=smem, + dst_ref=dst, + gmem_transform=mgpu.TileTransform(tiling), + swizzle=128, + ) ctx.await_async_copy(0) iota = np.arange(np.prod(shape), dtype=dtype).reshape(*shape) @@ -1726,7 +3498,7 @@ def kernel(ctx, dst, smem): (128, 1, 1), (), jax.ShapeDtypeStruct(shape=shape, dtype=dtype), - jax.ShapeDtypeStruct(shape=shape, dtype=dtype), + jax.ShapeDtypeStruct(shape=mgpu.tile_shape(shape, tiling), dtype=dtype), )() expected = jnp.tril(iota) + jnp.triu(jnp.ones(shape), k=1) * cst np.testing.assert_array_equal(result, expected) @@ -1738,7 +3510,7 @@ def kernel(ctx, dst, smem): def test_bitwise(self, op, dtype, m=64, n=8): def kernel(ctx, dst, _): iota = iota_tensor(m, n, dtype) - op(iota, iota + 1).store_untiled(dst) + op(iota, iota + 1).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), dtype) result = mgpu.as_gpu_kernel( @@ -1752,6 +3524,7 @@ def kernel(ctx, dst, _): (lambda x: -x, jax.lax.neg), (lambda x: x + 42, lambda x: x + 42), (lambda x: x.tanh(), jax.lax.tanh), + (lambda x: x.abs(), np.abs), ), dtype=[jnp.float32, jnp.int32, jnp.uint32], ) @@ -1762,7 +3535,7 @@ def test_unary(self, ops, dtype, m=64, n=32): def kernel(ctx, dst, _): iota = iota_tensor(m, n, dtype) - op(iota).store_untiled(dst) + op(iota).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), dtype) result = mgpu.as_gpu_kernel( @@ -1775,7 +3548,7 @@ def test_select(self, m=64, n=32): def kernel(ctx, dst, _): iota = iota_tensor(m, n, jnp.int32) - (iota < 16).select(iota * 2, iota * 3).store_untiled(dst) + (iota < 16).select(iota * 2, iota * 3).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), jnp.int32) result = mgpu.as_gpu_kernel( @@ -1786,19 +3559,25 @@ def kernel(ctx, dst, _): @parameterized.product( ops=[ - (lambda x: mgpu.FragmentedArray.exp(x), np.exp), - (lambda x: mgpu.FragmentedArray.sin(x), np.sin), - (lambda x: mgpu.FragmentedArray.cos(x), np.cos), - (lambda x: mgpu.FragmentedArray.rsqrt(x), jax.lax.rsqrt), + (mgpu.FragmentedArray.exp, np.exp), + (mgpu.FragmentedArray.sin, np.sin), + (mgpu.FragmentedArray.cos, np.cos), + (mgpu.FragmentedArray.rsqrt, jax.lax.rsqrt), + (mgpu.FragmentedArray.erf, jax.scipy.special.erf), ], approx=[False, True], ) @jtu.ignore_warning(message="overflow encountered", category=RuntimeWarning) def test_math(self, ops, approx, m=64, n=32): op, np_op = ops + kwargs = dict(approx=approx) + if op is mgpu.FragmentedArray.erf: + if approx: + raise self.skipTest("ERF not supported with approximation") + kwargs = {} def kernel(ctx, dst, _): iota = iota_tensor(m, n, jnp.float32) - op(iota).store_untiled(dst) + op(iota, **kwargs).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () @@ -1808,6 +3587,61 @@ def kernel(ctx, dst, _): rtol = 4e-6 if approx else 2e-7 np.testing.assert_allclose(result, np_op(x), atol=atol, rtol=rtol) + def test_atan2(self, m=64, n=32): + def kernel(ctx, dst, _): + y = iota_tensor(m, n, jnp.float32) + 1 # Avoid zero + x = iota_tensor(m, n, jnp.float32) + 2 + y.atan2(x).store_untiled(dst, optimized=False) + + out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) + result = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () + )() + y = np.arange(m * n, dtype=jnp.float32).reshape(m, n) + 1 + x = np.arange(m * n, dtype=jnp.float32).reshape(m, n) + 2 + np.testing.assert_allclose(result, np.arctan2(y, x), atol=2e-7, rtol=2e-7) + + def test_strided_copy_noncontig_good(self): + def kernel(ctx, src, dst, _): + src_slice = mgpu.memref_slice(src, (slice(None), 1)) + mgpu.FragmentedArray.load_strided(src_slice, is_signed=True, vec_size=4).store_untiled(dst) + + in_shape = jax.ShapeDtypeStruct((32, 2, 32), jnp.int32) + out_shape = jax.ShapeDtypeStruct((32, 32), jnp.int32) + + kernel_fn = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), in_shape, out_shape, () + ) + x = np.arange(math.prod(in_shape.shape), dtype=jnp.int32).reshape(in_shape.shape) + np.testing.assert_array_equal(kernel_fn(x), x[:, 1]) + + def test_strided_copy_noncontig_bad(self): + def kernel(ctx, src, dst, _): + src_slice = mgpu.memref_slice(src, (slice(None), 1)) + mgpu.FragmentedArray.load_strided(src_slice, is_signed=True, vec_size=2).store_untiled(dst) + + out_shape = jax.ShapeDtypeStruct((256, 7), jnp.int32) + + in_shape = jax.ShapeDtypeStruct((256, 6, 7), jnp.int32) + msg = ( + "The contiguous dimension of the reference must be a multiple of the" + " layout's vector size (got 7 and vector size 2)" + ) + with self.assertRaises(ValueError, msg=msg): + mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), in_shape, out_shape, () + ) + + in_shape = jax.ShapeDtypeStruct((256, 5, 7), jnp.int32) + msg = ( + "Non-contiguous dimension of the reference must have strides that are" + " multiples of the layout's vector size (got 35 and vector size 2)" + ) + with self.assertRaises(ValueError, msg=msg): + mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), in_shape, out_shape, () + ) + @parameterized.product( dtype=[jnp.float32, jnp.int32], m=[128], @@ -1818,8 +3652,8 @@ def kernel(ctx, src, dst, scratch): src = mgpu.FragmentedArray.load_strided( src, is_signed=utils.is_signed(dtype) ) - acc = src.reduce_sum(scratch).broadcast((m,)) - acc.store_untiled(dst) + acc = src.reduce("add", (0, 1), scratch).broadcast((m,)) + acc.store_untiled(dst, optimized=False) in_shape = jax.ShapeDtypeStruct((m, n), dtype) out_shape = jax.ShapeDtypeStruct((m,), dtype) @@ -1834,20 +3668,38 @@ def kernel(ctx, src, dst, scratch): x = np.arange(m * n, dtype=dtype).reshape(m, n) np.testing.assert_array_equal(kernel_fn(x), jnp.full((m,), x.sum())) + def test_dimension_compression_for_vec_size(self): + def body(ctx, src, dst, _): + src_arr = mgpu.FragmentedArray.load_strided( + mgpu.memref_slice(src, (slice(None), slice(4, None))), vec_size=4 + ) + src_arr.store_untiled(dst, optimized=False) + in_shape = jax.ShapeDtypeStruct((8, 20, 4, 3, 1), jnp.float32) + out_shape = jax.ShapeDtypeStruct((8, 16, 4, 3, 1), jnp.float32) + kernel = mgpu.as_gpu_kernel( + body, (1, 1, 1), (128, 1, 1), in_shape, out_shape, () + ) + x = np.arange(math.prod(in_shape.shape), dtype=np.float32).reshape(in_shape.shape) + np.testing.assert_array_equal(kernel(x), x[:, 4:]) + @parameterized.product( dtype=[jnp.float32, jnp.int32], m=[128], n=[32, 64], + reduce_both=[False, True], ) - def test_splat_reduce_sum(self, dtype, m, n): + def test_splat_reduce_sum(self, dtype, m, n, reduce_both): def kernel(ctx, dst, _): src = mgpu.FragmentedArray.splat( utils.c(1, utils.dtype_to_ir_type(dtype)), (m, n), is_signed=utils.is_signed(dtype), ) - acc = src.reduce_sum().broadcast((m,)) - acc.store_untiled(dst) + if reduce_both: + acc = src.reduce("add", (0, 1)).broadcast((m,)) + else: + acc = src.reduce("add", 1) + acc.store_untiled(dst, optimized=False) kernel_fn = mgpu.as_gpu_kernel( kernel, @@ -1857,7 +3709,40 @@ def kernel(ctx, dst, _): out_shape=jax.ShapeDtypeStruct((m,), dtype), smem_scratch_shape=(), ) - np.testing.assert_array_equal(kernel_fn(), jnp.full((m,), m * n * 1.0)) + result = m * n if reduce_both else n + np.testing.assert_array_equal(kernel_fn(), jnp.full((m,), result, dtype)) + + @parameterized.named_parameters( + ("wgmma_row", fa.WGMMA_LAYOUT, fa.WGMMA_ROW_LAYOUT, 1), + ("wgmma_col", fa.WGMMA_LAYOUT, fa.WGMMA_COL_LAYOUT, 0), + ("tcgen05_row", tcgen05.LAYOUT, tcgen05.ROW_LAYOUT, 1), + ("tcgen05_col", tcgen05.LAYOUT, tcgen05.COL_LAYOUT, 0), + ) + def test_layout_reduction_definition(self, layout, expected_reduced_layout, axis): + self.assertEqual(layout.reduce((axis,)), expected_reduced_layout) + + def test_layout_reduction_handles_tiles_with_three_different_ranks(self): + layout = fa.TiledLayout( + tiling=fa.Tiling(tiles=((1, 2, 64), (2, 16), (8,), (4,), (2,), (1,))), + warp_dims=(-7,), + lane_dims=(-6, -5, -4, -3, -2), + vector_dim=-1, + ) + self.assertEqual( + layout.reduce((2,)), + fa.TiledLayout( + tiling=fa.Tiling(tiles=((1, 2), (1,))), + warp_dims=(fa.Replicated(times=4),), + lane_dims=( + -2, + fa.Replicated(times=2), + fa.Replicated(times=2), + fa.Replicated(times=2), + fa.Replicated(times=2), + ), + vector_dim=-1, + ), + ) @parameterized.product( op=(arith.addf, arith.maximumf), @@ -1867,7 +3752,9 @@ def kernel(ctx, dst, _): def test_reduce(self, op, m=64, n=32): def kernel(ctx, dst, _): iota = iota_tensor(m, n, jnp.float32) - iota.reduce(op, axis=1).broadcast_minor(n).store_untiled(dst) + iota.reduce(op, axis=1).broadcast_in_dim( + (m, n), (0,), mgpu.WGMMA_LAYOUT + ).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () @@ -1881,6 +3768,52 @@ def kernel(ctx, dst, _): raise NotImplementedError(f"Unsupported op: {op}") np.testing.assert_array_equal(result, expected) + @parameterized.product( + vec_size=(4, 3, 1), + dtype=(jnp.float32, jnp.float16, jnp.bfloat16, + jnp.int32, jnp.int16, jnp.uint32, jnp.uint16), + ) + @jtu.thread_unsafe_test() + def test_max(self, vec_size, dtype): + def kernel(ctx, src, src2, dst, _): + is_signed = utils.is_signed(dtype) + src = fa.FragmentedArray.load_strided(src, vec_size=vec_size, is_signed=is_signed) + src2 = fa.FragmentedArray.load_strided(src2, vec_size=vec_size, is_signed=is_signed) + src.max(src2).store_untiled(dst) + x = self.prng.uniform(-1, 1, (12 * 128,)).astype(dtype) + y = self.prng.uniform(-1, 1, (12 * 128,)).astype(dtype) + f = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (x, y), x, () + ) + with jtu.set_env(MOSAIC_GPU_DUMP_PTX="1"), self.capture_stdout() as ptx: + z = f(x, y).block_until_ready() + if dtype == jnp.float32: + dtype_short = "f32" + elif dtype == jnp.float16: + dtype_short = "f16" + elif dtype == jnp.bfloat16: + dtype_short = "bf16" + elif jnp.issubdtype(dtype, jnp.signedinteger): + dtype_short = f"s{dtypes.itemsize_bits(dtype)}" + elif jnp.issubdtype(dtype, jnp.unsignedinteger): + dtype_short = f"u{dtypes.itemsize_bits(dtype)}" + else: + raise NotImplementedError(f"Unsupported dtype: {dtype}") + ptx = ptx() + nan_modifier = ".NaN" if jnp.issubdtype(dtype, jnp.floating) else "" + instr = f"max{nan_modifier}.{dtype_short} " + instr_double = f"max{nan_modifier}.{dtype_short}x2 " + single_converts = ptx.count(instr) + double_converts = ptx.count(instr_double) + self.assertEqual(128 * (single_converts + 2 * double_converts), 12 * 128) + if vec_size % 2: + self.assertGreater(single_converts, 0) + elif dtypes.itemsize_bits(dtype) < 32: + # This, together with the assertion above, implies that all converts + # happened through doubled operations. + self.assertEqual(single_converts, 0) + np.testing.assert_array_equal(z, np.maximum(x, y)) + def test_splat_layout(self): m, n = 64, 8 def kernel(ctx, dst, _): @@ -1888,7 +3821,7 @@ def kernel(ctx, dst, _): cte = c(1, iota.mlir_dtype) cte_arr = mgpu.FragmentedArray.splat(cte, ()) cte_arr = cte_arr.reshape((1, 1)).broadcast((m, n)) - (iota + cte_arr).store_untiled(dst) + (iota + cte_arr).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () @@ -1903,7 +3836,7 @@ def kernel(ctx, dst, _): t = mgpu.FragmentedArray.splat( v, (128,), mgpu.WGMMA_ROW_LAYOUT ) - t.broadcast_minor(32).store_untiled(dst) + t.broadcast_in_dim((128, 32), (0,), mgpu.WGMMA_LAYOUT).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((128, 32), jnp.float32) result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () @@ -1922,7 +3855,7 @@ def kernel(ctx, src, dst, _): assert isinstance(pi_arr_sq.layout, mgpu.WGStridedFragLayout) pi_arr_cube = pi_splat.broadcast(pi_arr.shape) * pi_arr_sq assert isinstance(pi_arr_cube.layout, mgpu.WGStridedFragLayout) - (pi_arr == pi_arr).select(pi_splat, pi_arr_cube).store_untiled(dst) + (pi_arr == pi_arr).select(pi_splat, pi_arr_cube).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((128, 32), jnp.float32) inp = jnp.ones_like(out_shape) * 3.14 @@ -1932,6 +3865,8 @@ def kernel(ctx, src, dst, _): np.testing.assert_allclose(result, np.full((128, 32), 3.14, np.float32)) @parameterized.product(in_shape=((128, 128), (128, 64), (64, 128))) + @jtu.skip_if_mosaic_gpu_exceeds_shared_memory( + device_patterns=("RTX PRO 6000 Blackwell", "GB10$")) def test_strided_load_store(self, in_shape): def kernel(ctx, *args): gmem_input, gmem_output, (smem_input, smem_output) = args @@ -1946,20 +3881,137 @@ def kernel(ctx, *args): )(inp) np.testing.assert_array_equal(inp, result) - @parameterized.product(in_shape=((128,), (64,))) - def test_wgmma_row_load_store_with_layout(self, in_shape): + @parameterized.product( + in_shape=((1024,), (256,), (128,), (64,)), + dtype=(jnp.float16, jnp.float32), + swizzle=(16, 32, 64, 128) + ) + def test_wgmma_row_load_store_with_layout(self, in_shape, dtype, swizzle): + def kernel(ctx, gmem_input, gmem_output, smem): + smem_input, smem_output = smem + copy(gmem_input, smem_input, swizzle=swizzle) + t = mgpu.FragmentedArray.load_untiled( + smem_input, layout=mgpu.WGMMA_ROW_LAYOUT, swizzle=swizzle + ) + t.store_untiled(smem_output) + copy(smem_output, gmem_output) + + inp = out = self.prng.uniform(-1, 1, in_shape).astype(dtype) + result = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (inp,), out, [inp, out], + )(inp) + np.testing.assert_array_equal(inp, result) + + @parameterized.product( + in_shape=((128,), (64,)), + dtype=(jnp.float16, jnp.float32), + swizzle=(16, 32, 64, 128), + ) + def test_wgmma_col_load_store_with_layout(self, in_shape, dtype, swizzle): def kernel(ctx, *args): gmem_input, gmem_output, (smem_input, smem_output) = args - copy(gmem_input, smem_input) - t = mgpu.FragmentedArray.load_wgmma_row(smem_input) + copy(gmem_input, smem_input, swizzle=swizzle) + t = mgpu.FragmentedArray.load_untiled( + smem_input, swizzle=swizzle, layout=mgpu.WGMMA_COL_LAYOUT + ) t.store_untiled(smem_output) copy(smem_output, gmem_output) - inp = out = self.prng.uniform(-1, 1, in_shape).astype(jnp.float32) + inp = out = self.prng.uniform(-1, 1, in_shape).astype(dtype) result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (inp,), out, [inp, out], )(inp) - np.testing.assert_array_equal(inp, result) + np.testing.assert_array_equal(result, inp) + + @parameterized.parameters( + (128, 128), (64, 128), (64, 256) + ) + def test_broadcast_in_dim_major_strided(self, m, n): + dtype = jnp.float16 + def kernel(ctx, gmem_input, gmem_output, _): + t = mgpu.FragmentedArray.load_strided( + gmem_input, vec_size=1 + ) + t.broadcast_in_dim((m, n), (1,), + mgpu.WGStridedFragLayout(shape=(m, n), vec_size=1), + ).store_untiled(gmem_output, optimized=False) + + inp = self.prng.uniform(-1, 1, (n,)).astype(dtype) + out_shape = jax.ShapeDtypeStruct((m, n), dtype) + result = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (inp,), out_shape, inp + )(inp) + out_ref = jax.lax.broadcast_in_dim(inp, (m, n), (1,)) + np.testing.assert_array_equal(result, out_ref) + + @parameterized.parameters( + (128, 128), (128, 64), (64, 128) + ) + def test_broadcast_in_dim_major_wgmma(self, m, n): + dtype = jnp.float16 + + def kernel(ctx, gmem_input, gmem_output, _): + t = mgpu.FragmentedArray.load_untiled( + gmem_input, layout=mgpu.WGMMA_COL_LAYOUT, optimized=False + ) + t.broadcast_in_dim( + (m, n), (1,), mgpu.WGMMA_LAYOUT + ).store_untiled(gmem_output, optimized=False) + + inp = self.prng.uniform(-1, 1, (n,)).astype(dtype) + out_shape = jax.ShapeDtypeStruct((m, n), dtype) + result = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (inp,), out_shape, inp + )(inp) + out_ref = jax.lax.broadcast_in_dim(inp, (m, n), (1,)) + np.testing.assert_array_equal(result, out_ref) + + @parameterized.parameters( + ((128), (4, 128)), + ((1, 128), (2, 128)), + ((1, 128), (4, 128)), + ((1, 256), (2, 256)), + ((128, ), (1, 3, 1, 2, 4, 128)), + ((1, 1, 128,), (1, 3, 1, 2, 4, 128)), + ((1, 1, 1, 1, 1, 128,), (1, 3, 1, 2, 4, 128)), + ((2, 4, 128,), (1, 3, 1, 2, 4, 128)), + ((1, 1, 1, 2, 4, 128,), (1, 3, 1, 2, 4, 128)), + ((2, 8, 8), (2, 8, 8)), + ) + def test_broadcast_major_strided(self, in_shape, out_shape): + dtype = jnp.float16 + def kernel(ctx, gmem_input, gmem_output, _): + t = mgpu.FragmentedArray.load_strided(gmem_input, vec_size=1) + t.broadcast(out_shape).store_untiled(gmem_output, optimized=False) + inp = self.prng.uniform(-1, 1, in_shape).astype(dtype) + result = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (inp,), jax.ShapeDtypeStruct(out_shape, dtype), inp + )(inp) + np.testing.assert_array_equal(result, jnp.broadcast_to(inp, out_shape)) + + @parameterized.parameters(*mtu.RegisterLayout) + def test_broadcast_splat(self, layout): + out_shape = (128, 128) + + def body(ctx, out_ref, scratch): + del ctx, scratch + c42 = arith.constant(ir.IntegerType.get_signless(32), 42) + arr = mgpu.FragmentedArray.splat(c42, (128,), is_signed=True) + out_layout = layout.to_mgpu(out_shape, jnp.int32) + result = arr.broadcast_in_dim(out_shape, (0,), out_layout) + result.store_untiled(out_ref, optimized=False) + + kernel = mgpu.as_gpu_kernel( + body, + grid=(1, 1, 1), + block=(128, 1, 1), + in_shape=(), + out_shape=jax.ShapeDtypeStruct(out_shape, jnp.int32), + smem_scratch_shape=[], + ) + np.testing.assert_array_equal( + kernel(), np.full(out_shape, 42, dtype=np.int32) + ) def test_warp_tree_reduce(self): def kernel(ctx, out, *_): @@ -1988,7 +4040,7 @@ def kernel(ctx, inp, out, smem): del ctx, smem arr = mgpu.FragmentedArray.load_strided(inp, is_signed=True) assert ir.VectorType(arr.registers.flat[0].type).shape == [reg_length] - arr.astype(mlir_dtype_to).store_untiled(out) + arr.astype(mlir_dtype_to).store_untiled(out, optimized=False) x = jnp.arange(-128, 128, dtype=jax_dtype_from) x = jnp.tile(x, reg_length // 2) @@ -2000,17 +4052,55 @@ def kernel(ctx, inp, out, smem): np.testing.assert_array_equal(result, reference) @parameterized.parameters( - ([64 * 4], "WGMMA_ROW_LAYOUT"), - ([64 * 4, 8 * 2], "WGMMA_LAYOUT"), + ([64 * 4], mgpu.WGMMA_ROW_LAYOUT), + ([64 * 4, 8 * 2], mgpu.WGMMA_LAYOUT), ) - def test_to_layout(self, shape, new_layout): + def test_splat_relayout(self, shape, new_layout): def kernel(ctx, _): # No assertions, we are just checking there are no compile-time errors. arr = mgpu.FragmentedArray.splat(c(42.0, ir.F32Type.get()), shape) - arr.to_layout(getattr(mgpu, new_layout)) + arr.to_layout(new_layout) _ = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), (), (), None)() + @parameterized.parameters( + (mgpu.WGMMA_LAYOUT, mgpu.WGMMA_TRANSPOSED_LAYOUT), + (mgpu.TCGEN05_LAYOUT, mgpu.TCGEN05_TRANSPOSED_LAYOUT), + (mgpu.WGMMA_TRANSPOSED_LAYOUT, mgpu.WGMMA_LAYOUT), + (mgpu.TCGEN05_TRANSPOSED_LAYOUT, mgpu.TCGEN05_LAYOUT), + ) + def test_transpose_relayout(self, src_layout, dst_layout): + def is_transposed(layout): + return ( + layout == mgpu.WGMMA_TRANSPOSED_LAYOUT + or layout == mgpu.TCGEN05_TRANSPOSED_LAYOUT + ) + + def body(ctx, src, dst, scratch): + del ctx, scratch + if is_transposed(src_layout): + src = utils.memref_transpose(src, (1, 0)) + src_reg = mgpu.FragmentedArray.load_untiled( + src, layout=src_layout, optimized=False + ) + dst_reg = src_reg.to_layout(dst_layout) + if is_transposed(dst_layout): + dst = utils.memref_transpose(dst, (1, 0)) + dst_reg.store_untiled(dst, optimized=False) + + shape = (128, 128) + dtype = jnp.float32 + kernel = mgpu.as_gpu_kernel( + body, + grid=(1, 1, 1), + block=(128, 1, 1), + in_shape=(jax.ShapeDtypeStruct(shape, dtype),), + out_shape=jax.ShapeDtypeStruct(shape, dtype), + smem_scratch_shape=[], + ) + x = self.prng.uniform(-1, 1, shape).astype(dtype) + np.testing.assert_array_equal(kernel(x), x.T) + @parameterized.parameters( (jnp.float16, jnp.float16), # Noop (jnp.int16, jnp.bfloat16), @@ -2059,12 +4149,27 @@ def kernel(ctx, inp, out, smem): f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, None) np.testing.assert_array_equal(f(x), x * 3) + def test_optimization_barrier_with_single_value(self): + shape = (64, 64) + value = 5.0 + dtype = jnp.float32 + def kernel(ctx, out, smem): + del ctx, smem + mlir_type = utils.dtype_to_ir_type(dtype) + arr = mgpu.FragmentedArray.splat(c(value, mlir_type), shape) + arr = mgpu.optimization_barrier(arr) + arr.store_untiled(out) + + out_shape = jax.ShapeDtypeStruct(shape, dtype) + f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()) + np.testing.assert_array_equal(f(), jnp.full(shape, value, dtype=dtype)) + def test_convert_bool_to_u8(self): m, n = 128, 128 def kernel(ctx, dst, _): i8 = ir.IntegerType.get_signless(8) iota = iota_tensor(m, n, jnp.uint8) - (iota > 10).astype(i8, is_signed=False).store_untiled(dst) + (iota > 10).astype(i8, is_signed=False).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), jnp.int8) result = mgpu.as_gpu_kernel( @@ -2073,6 +4178,59 @@ def kernel(ctx, dst, _): iota = np.arange(m * n, dtype=jnp.uint8).reshape(m, n) np.testing.assert_array_equal(result, (iota > 10).astype(jnp.uint8)) + @parameterized.product(dtype=(jnp.bfloat16, jnp.float16)) + def test_mma(self, dtype): + m, n, k = 128, 128, 128 + def kernel(ctx: mgpu.LaunchContext, acc, a, b, out, scratch): + (acc_smem, a_smem, b_smem), barrier = scratch + + def load(x, x_smem, layout, swizzle=32): + ctx.async_copy( + src_ref=x, + dst_ref=x_smem, + gmem_transform=mgpu.TileTransform(tuple(x_smem.type.shape[2:])), + swizzle=swizzle, + barrier=barrier, + ) + barrier.wait() + return fa.FragmentedArray.load_tiled(x_smem, swizzle=swizzle, layout=layout) + + b_fa = load(b, b_smem, mgpu.MMALayouts.rhs) + a_fa = load(a, a_smem, mgpu.MMALayouts.lhs) + acc_fa = load(acc, acc_smem, mgpu.MMALayouts.acc) + result_fa: mgpu.FragmentedArray = mgpu.mma(acc_fa, a_fa, b_fa) + result_fa.store_tiled(acc_smem, swizzle=32) + mgpu.commit_shared() + ctx.async_copy( + src_ref=acc_smem, + dst_ref=out, + gmem_transform=mgpu.TileTransform(tuple(acc_smem.type.shape[2:])), + swizzle=32, + ) + ctx.await_async_copy(0) + + a = self.prng.uniform(-1, 1, (m, k)).astype(dtype) + b = self.prng.uniform(-1, 1, (n, k)).astype(dtype) + acc = self.prng.uniform(-1, 1, (m, n)).astype(jnp.float32) + + expected = acc + a.astype(jnp.float32) @ b.astype(jnp.float32).T + result = mgpu.as_gpu_kernel( + kernel, + (1, 1, 1), + (128, 1, 1), + (acc, a, b), + out_shape=expected, + smem_scratch_shape=( + mgpu.Union([ + jax.ShapeDtypeStruct(mgpu.tile_shape((m, n), (8, 8)), dtype=jnp.float32), + jax.ShapeDtypeStruct(mgpu.tile_shape((m, k), (8, 16)), dtype=dtype), + jax.ShapeDtypeStruct(mgpu.tile_shape((n, k), (8, 16)), dtype=dtype), + ]), + mgpu.Barrier(1) + ), + )(acc, a, b) + np.testing.assert_allclose(result, expected, atol=1e-5) + @parameterized.parameters( (jnp.uint8, jnp.uint16, 255), (jnp.uint8, jnp.int16, 255), @@ -2100,60 +4258,100 @@ def kernel(ctx, dst, _): expected = jnp.full((m, n), value, dtype=from_dtype).astype(to_dtype) np.testing.assert_array_equal(result, expected) + @parameterized.product( + swizzle=(16, 32, 64, 128), + shape=((128, 128), (8, 128), (128, 32), (48, 64)), + to_smem=(True, False), + ) + def test_copy_tiled(self, swizzle, shape, to_smem): + dtype = jnp.int32 + tiling = (8, 8 * swizzle // jnp.iinfo(dtype).bits) + def kernel(ctx, src, dst, scratch): + smem, barrier = scratch + if to_smem: + mgpu.copy_tiled(src, smem, swizzle=swizzle) + mgpu.commit_shared() + ctx.async_copy( + src_ref=smem, + dst_ref=dst, + gmem_transform=mgpu.TileTransform(tiling), + swizzle=swizzle, + ) + ctx.await_async_copy(0) + else: + ctx.async_copy( + src_ref=src, + dst_ref=smem, + gmem_transform=mgpu.TileTransform(tiling), + swizzle=swizzle, + barrier=barrier, + ) + barrier.wait() + mgpu.copy_tiled(smem, dst, swizzle=swizzle) -class ProfilerTest(TestCase): + x = jnp.arange(math.prod(shape), dtype=dtype).reshape(shape) + scratch_shape = [ + jax.ShapeDtypeStruct(mgpu.tile_shape(shape, tiling), dtype), + mgpu.TMABarrier(1), + ] + y = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, x, scratch_shape + )(x) + np.testing.assert_array_equal(y, x) - def test_measure_events_explicit(self): - x = jnp.arange(1024 * 1024) - _, runtime_ms = profiler.measure(lambda x, y: x + y, mode="events")(x, x) - self.assertIsInstance(runtime_ms, float) + @parameterized.parameters( + ((32, 32), (0, 5)), + ((32, 128), (3,)), + ((32, 32, 128), (slice(1, 3), 0)), + ) + def test_splat_indexing(self, shape, indices): + def _kernel(ctx, out_ref, scratch): + del ctx, scratch + splat = mgpu.FragmentedArray.splat(c(1.0, ir.F32Type.get()), shape) + splat[indices].store_untiled(out_ref) - def test_profile(self): - def kernel(ctx, src, dst, _): - mgpu.FragmentedArray.load_strided(src).store_untiled(dst) - x = np.arange(64 * 64, dtype=jnp.float32).reshape(64, 64) - spec = profiler.ProfilerSpec(1024) - # This is just a smoke test. - f = jax.jit(mgpu.as_gpu_kernel( - kernel, (1, 1, 1), (128, 1, 1), x, x, (), prof_spec=spec - )) - jax.block_until_ready(f(x)) - - def test_multigpu(self): - if len(jax.devices()) < 2: - self.skipTest("Need at least 2 devices") - def kernel(ctx, src, dst, _): - mgpu.FragmentedArray.load_strided(src).store_untiled(dst) - x = np.arange(64 * 64, dtype=jnp.float32).reshape(64, 64) - f = jax.jit(mgpu.as_gpu_kernel( - kernel, (1, 1, 1), (128, 1, 1), x, x, () - )) - # Make sure we can invoke the same program on different devices. - for xd in (jax.device_put(x, d) for d in jax.devices()[:2]): - jax.block_until_ready(f(xd)) + expected = np.ones(shape, dtype=jnp.float32)[indices] + kernel = mgpu.as_gpu_kernel( + _kernel, + grid=(1, 1, 1), + block=(128, 1, 1), + in_shape=(), + out_shape=expected, + smem_scratch_shape=(), + ) + np.testing.assert_array_equal(kernel(), expected) -class TorchTest(TestCase): +class ProfilerTest(TestCase, jtu.JaxTestCase): - def setUp(self): - super().setUp() - try: - import torch - except ImportError: - raise unittest.SkipTest("Test requires PyTorch") - self.torch = torch - - def test_basic(self): - def kernel(ctx, i_gmem, o_gmem, _): - x = mgpu.FragmentedArray.load_strided(i_gmem) - (x + x).store_untiled(o_gmem) - - ty = jax.ShapeDtypeStruct((128, 128), jnp.float32) - x = self.torch.randn((128, 128), dtype=self.torch.float, device='cuda') - f = mgpu.as_torch_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), ty, ty, ()) - y = f(x) - np.testing.assert_allclose(y.cpu(), x.cpu() * 2) - del y # Make sure the destructor runs successfully. + def test_profiler(self): + def body(ctx, input, result, scratch): + del scratch + with ctx.named_region("load"): + reg = mgpu.FragmentedArray.load_strided(input) + with ctx.named_region("store"): + reg.store_untiled(result) + + dtype = jnp.bfloat16 + shape = (128, 128) + jax_shape = jax.ShapeDtypeStruct(shape, dtype) + with tempfile.TemporaryDirectory() as tmpdir: + kernel = mgpu.as_gpu_kernel( + body, + grid=(1, 1, 1), + block=(128, 1, 1), + in_shape=(jax_shape), + out_shape=jax_shape, + smem_scratch_shape=[], + prof_spec=profiler.ProfilerSpec(1024, dump_path=tmpdir), + ) + param = self.prng.uniform(-1, 1, shape).astype(dtype) + self.assertArraysEqual(kernel(param), param) + [name] = os.listdir(tmpdir) + with open(os.path.join(tmpdir, name)) as f: + data = f.read() + self.assertEqual(data.count('"name": "load"'), 2) + self.assertEqual(data.count('"name": "store"'), 2) class LayoutTest(TestCase): @@ -2188,11 +4386,11 @@ def kernel(ctx, dst, _): # Note that WGMMA layouts are always (shape[0] // 64, shape[1] // 8, 2, 1) self.assertEqual( tiled.registers.shape, - (shape[0] // 64, shape[1] // 8, 1, 1, 2, 1, 1, 1, 1, 1), + (shape[0] // 64, shape[1] // 8, 1, 1, 2, 1, 1, 1, 1), ) self.assertEqual(tiled.shape, shape) self.assertEqual(tiled.mlir_dtype, iota.mlir_dtype) - tiled.store_untiled(dst) + tiled.store_untiled(dst, optimized=False) ty = jax.ShapeDtypeStruct(shape, dtype) f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), (), ty, ()) expected = np.arange(math.prod(shape), dtype=dtype).reshape(shape) @@ -2204,6 +4402,7 @@ def kernel(ctx, dst, _): num_col_tiles=[1, 2, 3], row_tiling=[8, 64], ) + @jtu.thread_unsafe_test() # Modifies ``os.environ``. def test_copy_tiled(self, dtype, swizzle, num_col_tiles, row_tiling): mlir_dtype = utils.dtype_to_ir_type(dtype) bw = bytewidth(mlir_dtype) @@ -2229,7 +4428,7 @@ def kernel(ctx, in_, out, smems): .transpose(0, 2, 1, 3) ) - with get_sass() as sass: + with jtu.set_env(MOSAIC_GPU_DUMP_SASS="1"), self.capture_stdout() as sass: iota = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), expected, expected, [expected, expected, mgpu.TMABarrier()], @@ -2252,7 +4451,13 @@ def get_reg(addr): return addr[:pos] return addr used_regs = {get_reg(addr) for addr in addrs} - self.assertLessEqual(len(used_regs), expected_regs) + try: + self.assertLessEqual(len(used_regs), expected_regs) + except: + problematic_device_patterns = ("RTX PRO 6000 Blackwell", "GB10$") + if match := jtu.device_kind_match(problematic_device_patterns): + self.skipTest(f"{match} uses more registers for an unknown reason") + raise def test_copy_for_upcast(self): dtype = jnp.int8 @@ -2281,23 +4486,35 @@ def kernel(ctx, in_, out, smems): np.testing.assert_array_equal(f(x), x) @parameterized.product( - dtype=[jnp.int16], # TODO(apaszke): More dtypes - # TODO(apaszke): swizzle=64 <- not implemented in transfer_tiled right now - swizzle=[16, 32, 128], + dtype=[jnp.int16, jnp.int32], # TODO(apaszke): More dtypes + swizzle=[16, 32, 64, 128], + layouts=[ + (fa.WGMMA_LAYOUT, fa.WGMMA_TRANSPOSED_LAYOUT), + (fa.TCGEN05_LAYOUT, fa.TCGEN05_TRANSPOSED_LAYOUT), + ], ) - def test_transpose_tiled(self, dtype, swizzle): + @jtu.skip_if_mosaic_gpu_exceeds_shared_memory( + device_patterns=("RTX PRO 6000 Blackwell", "GB10$")) + def test_transpose_tiled(self, dtype, swizzle, layouts): mlir_dtype = utils.dtype_to_ir_type(dtype) bw = bytewidth(mlir_dtype) col_tiling = swizzle // bw - m, n = 128, 256 + if bw == 2: + m, n = 256, 192 + elif bw == 4: + m, n = 256, 96 + else: + raise ValueError(f"Unsupported bitwidth: {bw}") tiling = (8, col_tiling) - transpose_layout = fa.WGMMA_TRANSPOSED_LAYOUT + if col_tiling < 8: + self.skipTest("Swizzle too small") + layout, transpose_layout = layouts def kernel(ctx, in_, out, smems): smem_in, smem_out, barrier = smems ctx.async_copy(src_ref=in_, dst_ref=smem_in, swizzle=swizzle, barrier=barrier) barrier.wait() t = mgpu.FragmentedArray.load_tiled( - smem_in, swizzle=swizzle, is_signed=True, layout=fa.WGMMA_LAYOUT + smem_in, swizzle=swizzle, is_signed=True, layout=layout ) smem_out_t = memref_transpose(smem_out, (1, 0, 3, 2)) t.to_layout(transpose_layout).store_tiled(smem_out_t, swizzle=swizzle) @@ -2328,6 +4545,8 @@ def kernel(ctx, in_, out, smems): (fa.WGMMA_LAYOUT_UPCAST_2X, fa.WGMMA_LAYOUT, jnp.int4, jnp.int4, 0.5), (fa.WGMMA_LAYOUT_UPCAST_4X, fa.WGMMA_LAYOUT, jnp.int4, jnp.int4, 2), ) + @jtu.thread_unsafe_test() # Modifies ``os.environ``. + @jtu.skip_if_mosaic_gpu_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell") def test_upcast_to_wgmma( self, start_layout, end_layout, in_dtype, cast_dtype, shfl_per_reg ): @@ -2339,7 +4558,7 @@ def test_upcast_to_wgmma( in_tiling = (8, in_col_tiling) out_col_tiling = swizzle // out_dtype.itemsize out_tiling = (8, out_col_tiling) - m, n = 128, in_col_tiling * 2 + m, n = 64, in_col_tiling * 2 regs_per_thread = None def kernel(ctx, in_, out, smems): nonlocal regs_per_thread @@ -2371,10 +4590,46 @@ def tile(x, tiling): f = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), xt, yt, [xt, yt, mgpu.TMABarrier()], ) - with get_sass() as sass: + with jtu.set_env(MOSAIC_GPU_DUMP_SASS="1"), self.capture_stdout() as sass: yt_kernel = f(xt) + jax.block_until_ready(yt_kernel) np.testing.assert_array_equal(yt_kernel, yt) - self.assertEqual(sass().count("SHFL.BFLY"), regs_per_thread * shfl_per_reg) + try: + self.assertEqual(sass().count("SHFL.BFLY"), regs_per_thread * shfl_per_reg) + except: + problematic_device_patterns = ("RTX PRO 6000 Blackwell", "GB10$") + if match := jtu.device_kind_match(problematic_device_patterns): + self.skipTest(f"{match} requires more SHFL.BFLY for an unknown reason") + raise + + @parameterized.product( + in_length=[1, 2, 4, 8], + out_length=[1, 2, 4, 8], + ) + def test_convert_tmem_native_vector_length(self, in_length, out_length): + dtype = jnp.dtype(jnp.int16) + def kernel(ctx, in_, out, smems): + smem_in, smem_out, barrier = smems + ctx.async_copy(src_ref=in_, dst_ref=smem_in, barrier=barrier) + barrier.wait() + t = mgpu.FragmentedArray.load_untiled( + smem_in, layout=mgpu.tmem_native_layout(in_length), + is_signed=True, optimized=False + ) + t = t.to_layout(mgpu.tmem_native_layout(out_length)) + t.store_untiled(smem_out, optimized=False) + mgpu.commit_shared() + ctx.async_copy(src_ref=smem_out, dst_ref=out) + ctx.await_async_copy(0) + iinfo = jnp.iinfo(dtype) + x = jax.random.randint( + jax.random.key(42), (128, 128), iinfo.min, iinfo.max, dtype=jnp.int16 + ) + f = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, x, [x, x, mgpu.TMABarrier()], + ) + y = f(x) + np.testing.assert_array_equal(y, x) @dataclasses.dataclass(frozen=True) @@ -2428,43 +4683,119 @@ def set_in_transforms( if not transforms: return - in_transforms = [] - smem_refs = filter(inference_utils.is_transformable_smem_memref, op.operands) # pylint: disable=undefined-variable - for _, result_transforms in jax.util.safe_zip(smem_refs, transforms): - in_transforms.append( - ir.ArrayAttr.get([t.attr() for t in result_transforms]) + in_transforms = [] + smem_refs = filter(inference_utils.is_transformable_smem_memref, op.operands) # pylint: disable=undefined-variable + for _, result_transforms in jax._src.util.safe_zip(smem_refs, transforms): + in_transforms.append( + ir.ArrayAttr.get([t.attr() for t in result_transforms]) + ) + + op.attributes["in_transforms"] = ir.ArrayAttr.get(in_transforms) + +class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase): + """Device tests with lowering from the MLIR dialect and layout inference.""" + + def setUp(self): + if mgpu_dialect is None: + raise self.skipTest("Test requires Mosaic GPU dialect") + super().setUp() + + @parameterized.product( + layout=tuple(mtu.RegisterLayout), + dtype=(jnp.bfloat16, jnp.int8), + optimized=(True, False, None), + ) + def test_smem_gmem_registers_load_store(self, layout, dtype, optimized): + if layout == mtu.RegisterLayout.WG_SPLAT: + self.skipTest("WG_SPLAT is not supported for `vector.load`.") + # We don't infer optimized transfer-compatible transforms for load/store to + # registers with TCGEN05_TMEM_NATIVE layout. + if optimized and layout == mtu.RegisterLayout.TCGEN05_TMEM_NATIVE: + self.skipTest( + "Optimized loads not supported for TCGEN05_TMEM_NATIVE layout" + ) + shape = (128, 128) + layout_attr = layout.to_layout_attr(shape, dtype) + + def body(ctx, param: ir.Value, result: ir.Value, smem: list[ir.Value]): + del ctx + + # GMEM -> Registers + reg = mgpu_dialect.vector_load(param) + reg = mgpu_dialect.layout_cast(reg, layout_attr) + + # Registers -> SMEM + mgpu_dialect.vector_store(reg, smem, optimized=optimized) + + # SMEM -> Registers + reg = mgpu_dialect.vector_load(smem, optimized=optimized) + reg = mgpu_dialect.layout_cast(reg, layout_attr) + + # Registers -> GMEM + mgpu_dialect.vector_store(reg, result) + + jax_shape = jax.ShapeDtypeStruct(shape, dtype) + kernel = mgpu.as_gpu_kernel( + body, + grid=(1, 1, 1), + block=(128, 1, 1), + in_shape=jax_shape, + out_shape=jax_shape, + smem_scratch_shape=jax_shape, + thread_semantics=mgpu.LoweringSemantics.Warpgroup, ) - op.attributes["in_transforms"] = ir.ArrayAttr.get(in_transforms) - + param = self.prng.uniform(-1, 1, shape).astype(dtype) + self.assertArraysEqual(kernel(param), param) -class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase): - """Device tests with lowering from the MLIR dialect and layout inference.""" + @parameterized.parameters( + (mgpu.WGMMA_LAYOUT, mgpu.WGMMA_TRANSPOSED_LAYOUT), + (mgpu.WGMMA_TRANSPOSED_LAYOUT, mgpu.WGMMA_LAYOUT), + ) + def test_transposed_load_store(self, src_layout, dst_layout): + def is_transposed(layout): + return layout == mgpu.WGMMA_TRANSPOSED_LAYOUT + + def body(ctx, src_ref, dst_ref, scratch): + del ctx, scratch + if is_transposed(src_layout): + src_ref = utils.memref_transpose(src_ref, (1, 0)) + if is_transposed(dst_layout): + dst_ref = utils.memref_transpose(dst_ref, (1, 0)) + src_reg = mgpu_dialect.vector_load(src_ref) + src_layout_attr = layouts.to_tiled_layout_attr(src_layout) + src_reg = mgpu_dialect.layout_cast(src_reg, src_layout_attr) + dst_layout_attr = layouts.to_tiled_layout_attr(dst_layout) + dst_reg = mgpu_dialect.layout_cast(src_reg, dst_layout_attr) + mgpu_dialect.vector_store(dst_reg, dst_ref) - def setUp(self): - if mgpu_dialect is None: - raise self.skipTest("Test requires Mosaic GPU dialect") - super().setUp() + shape = (128, 128) + dtype = jnp.float32 + kernel = mgpu.as_gpu_kernel( + body, + grid=(1, 1, 1), + block=(128, 1, 1), + in_shape=(jax.ShapeDtypeStruct(shape, dtype),), + out_shape=jax.ShapeDtypeStruct(shape, dtype), + smem_scratch_shape=[], + thread_semantics=mgpu.LoweringSemantics.Warpgroup, + ) + x = self.prng.uniform(-1, 1, shape).astype(dtype) + np.testing.assert_array_equal(kernel(x), x.T) def test_pointwise_kernel(self): def add(ctx, a, b, result, smem): del ctx, smem - shape = ir.MemRefType(a.type).shape - elt_type = ir.MemRefType(a.type).element_type - - zero_index = arith.constant(ir.IndexType.get(), 0) - zero_vector_indices = [zero_index] * len(shape) # GMEM -> registers - ab_type = ir.VectorType.get(shape, elt_type) - a = vector.load(ab_type, a, zero_vector_indices) - b = vector.load(ab_type, b, zero_vector_indices) + a = mgpu_dialect.vector_load(a) + b = mgpu_dialect.vector_load(b) # Computation add = arith.addf(a, b) # Registers -> GMEM - vector.store(add, result, zero_vector_indices) + mgpu_dialect.vector_store(add, result) dtype = jnp.bfloat16 shape = (128, 128) @@ -2476,13 +4807,13 @@ def add(ctx, a, b, result, smem): in_shape=(jax_shape, jax_shape), out_shape=jax_shape, smem_scratch_shape=[], - thread_semantics=mgpu.ThreadSemantics.Warpgroup, + thread_semantics=mgpu.LoweringSemantics.Warpgroup, ) x = self.prng.uniform(-1, 1, shape).astype(dtype) y = self.prng.uniform(-1, 1, shape).astype(dtype) - self.assertArraysEqual(jax.jit(kernel)(x, y), x + y) + self.assertArraysEqual(kernel(x, y), x + y) @staticmethod def kernel_with_tma_cases(dtype: jnp.dtype): @@ -2568,33 +4899,27 @@ def add( ): del ctx smem_ref, tma_barrier = smem - dialect_barrier = tma_barrier.as_dialect_barrier_memref() elt_type = ir.MemRefType(in_gmem_ref.type).element_type memref_bytes = utils.bytewidth(elt_type) * math.prod( test_case.shape_sliced ) - mgpu_dialect.arrive_expect_tx( - barrier=dialect_barrier, expect_tx= memref_bytes - ) i32 = ir.IntegerType.get_signless(32) slice_indices = [arith.constant(i32, i) for i in test_case.slice_indices] # GMEM -> SMEM + tma_barrier.arrive_expect_tx(memref_bytes) load_op = mgpu_dialect.AsyncLoadOp( source=in_gmem_ref, destination=smem_ref, - barrier=dialect_barrier, + barrier=tma_barrier.as_barrier_memref(), indices=slice_indices, slice_lengths=test_case.slice_lengths, collective=ir.ArrayAttr.get([]), ) set_in_transforms(load_op, [test_case.transforms]) - - parities = memref.load(tma_barrier.phases, []) - parity, _ = tma_barrier.update_parities(parities) - mgpu_dialect.wait(dialect_barrier, parity) + tma_barrier.wait() # SMEM -> GMEM zero_index = arith.constant(i32, 0) @@ -2621,7 +4946,7 @@ def add( jax_shape_sliced, core.TMABarrier(1), ], - thread_semantics=mgpu.ThreadSemantics.Warpgroup, + thread_semantics=mgpu.LoweringSemantics.Warpgroup, ) x = self.prng.uniform(-1, 1, test_case.shape).astype(dtype) @@ -2631,7 +4956,7 @@ def add( for i, l in zip(test_case.slice_indices, test_case.slice_lengths) ) self.assertArraysEqual( - jax.jit(kernel)(x), + kernel(x), (x[input_slice]).reshape(test_case.shape_sliced), ) @@ -2645,25 +4970,21 @@ def add( ): del ctx a_smem_ref, b_smem_ref, result_smem_ref, tma_barrier = smem - dialect_barrier = tma_barrier.as_dialect_barrier_memref() memref_type = ir.MemRefType(a_gmem_ref.type) shape = memref_type.shape elt_type = memref_type.element_type - memref_bytes = utils.bytewidth(elt_type) * math.prod(shape) - mgpu_dialect.arrive_expect_tx( - barrier=dialect_barrier, expect_tx=2 * memref_bytes - ) - zero_i32 = arith.constant(ir.IntegerType.get_signless(32), 0) zero_slice_indices = [zero_i32] * memref_type.rank # GMEM -> SMEM + memref_bytes = utils.bytewidth(elt_type) * math.prod(shape) + tma_barrier.arrive_expect_tx(2 * memref_bytes) mgpu_dialect.async_load( source=a_gmem_ref, destination=a_smem_ref, - barrier=dialect_barrier, + barrier=tma_barrier.as_barrier_memref(), indices=zero_slice_indices, slice_lengths=shape, collective=ir.ArrayAttr.get([]), @@ -2671,29 +4992,22 @@ def add( mgpu_dialect.async_load( source=b_gmem_ref, destination=b_smem_ref, - barrier=dialect_barrier, + barrier=tma_barrier.as_barrier_memref(), indices=zero_slice_indices, slice_lengths=shape, collective=ir.ArrayAttr.get([]), ) - - parities = memref.load(tma_barrier.phases, []) - parity, _ = tma_barrier.update_parities(parities) - mgpu_dialect.wait(dialect_barrier, parity) - - zero_index = arith.constant(ir.IndexType.get(), 0) - zero_vector_indices = [zero_index] * memref_type.rank + tma_barrier.wait() # SMEM -> registers - ab_type = ir.VectorType.get(shape, elt_type) - a = vector.load(ab_type, a_smem_ref, zero_vector_indices) - b = vector.load(ab_type, b_smem_ref, zero_vector_indices) + a = mgpu_dialect.vector_load(a_smem_ref) + b = mgpu_dialect.vector_load(b_smem_ref) # Computation add = arith.addf(arith.addf(a, b), b) # Registers -> SMEM - vector.store(add, result_smem_ref, zero_vector_indices) + mgpu_dialect.vector_store(add, result_smem_ref) # SMEM -> GMEM mgpu_dialect.async_store( @@ -2720,38 +5034,676 @@ def add( spec, core.TMABarrier(1), ], - thread_semantics=mgpu.ThreadSemantics.Warpgroup, + thread_semantics=mgpu.LoweringSemantics.Warpgroup, + ) + + x = self.prng.uniform(-1, 1, spec.shape).astype(dtype) + y = self.prng.uniform(-1, 1, spec.shape).astype(dtype) + + self.assertArraysEqual(kernel(x, y), x + y + y) + + @parameterized.parameters( + ((64,), (64, 128), [0]), + ((64,), (128, 64), [1]), + ) + def test_broadcast_in_dim(self, input_shape, output_shape, bcast_dims): + element_value = 42.0 + layout = fa.WGMMA_ROW_LAYOUT if bcast_dims[0] == 0 else fa.WGMMA_COL_LAYOUT + def body(ctx, result_gmem_ref, scratch): + del ctx, scratch + + # Create input in registers + f32 = ir.F32Type.get() + x_type = ir.VectorType.get(input_shape, f32) + c = arith.constant(f32, element_value) + x = vector.broadcast(x_type, c) + + # Computation + out_type = ir.VectorType.get(output_shape, f32) + cast = mgpu_dialect.layout_cast(x, layouts.to_layout_attr(layout)) + expanded = mgpu_dialect.broadcast_in_dim(out_type, cast, bcast_dims) + + # Registers -> GMEM + mgpu_dialect.vector_store(expanded, result_gmem_ref) + + dtype = jnp.float32 + kernel = mgpu.as_gpu_kernel( + body, + grid=(1, 1, 1), + block=(128, 1, 1), + in_shape=(), + out_shape=jax.ShapeDtypeStruct(output_shape, dtype), + smem_scratch_shape=[], + thread_semantics=mgpu.LoweringSemantics.Warpgroup, + ) + + x = np.full(input_shape, element_value, dtype=dtype) + self.assertArraysEqual( + kernel(), jax.lax.broadcast_in_dim(x, output_shape, bcast_dims) + ) + + @parameterized.parameters( + (jnp.float32, 5.0, 2.0, vector.CombiningKind.ADD), + (jnp.float32, 5.0, 2.0, vector.CombiningKind.MAXIMUMF), + (jnp.float32, 5.0, 7.0, vector.CombiningKind.MAXIMUMF), + (jnp.int32, 5, 2, vector.CombiningKind.MAXSI), + (jnp.int32, -5, -2, vector.CombiningKind.MAXSI), + (jnp.int32, -2, -5, vector.CombiningKind.MAXSI), + (jnp.uint32, 5, 2, vector.CombiningKind.MAXUI), + (jnp.uint32, 2, 5, vector.CombiningKind.MAXUI), + # + # TODO(dasenov): Add tests for wgmma_col_layout output once + # fragmented_array.reduce supports that. + ) + def test_vector_multi_dim_reduction( + self, + dtype, + input_value, + init_value, + kind, + ): + input_shape = (128, 64) + output_shape = (128,) + red_dims = [1] + + def body(ctx, result_gmem_ref, scratch): + del ctx, scratch + el_type = utils.dtype_to_ir_type(dtype) + + # Create source in registers + source_type = ir.VectorType.get(input_shape, el_type) + c = arith.constant(el_type, input_value) + source = vector.broadcast(source_type, c) + + # Create accumulator in registers + acc_type = ir.VectorType.get(output_shape, el_type) + c = arith.constant(el_type, init_value) + acc = vector.broadcast(acc_type, c) + + # Cast inputs + source = mgpu_dialect.layout_cast( + source, layouts.to_layout_attr(fa.WGMMA_LAYOUT) + ) + acc_layout = ( + fa.WGMMA_ROW_LAYOUT if red_dims[0] == 1 else fa.WGMMA_COL_LAYOUT + ) + acc = mgpu_dialect.layout_cast(acc, layouts.to_layout_attr(acc_layout)) + + # Computation + reduced = vector.multi_reduction(kind, source, acc, red_dims) + + # Registers -> GMEM + mgpu_dialect.vector_store(reduced, result_gmem_ref) + + kernel = mgpu.as_gpu_kernel( + body, + grid=(1, 1, 1), + block=(128, 1, 1), + in_shape=(), + out_shape=jax.ShapeDtypeStruct(output_shape, dtype), + smem_scratch_shape=[], + thread_semantics=mgpu.LoweringSemantics.Warpgroup, + ) + + source = np.full(input_shape, input_value, dtype=dtype) + acc = np.full(output_shape, init_value, dtype=dtype) + if kind == vector.CombiningKind.ADD: + red = jax.lax.reduce_sum(source, red_dims) + red = red + acc + else: + red = jax.lax.reduce_max(source, red_dims) + red = jax.lax.max(red, acc) + self.assertArraysEqual(kernel(), red) + + @parameterized.parameters(fa.WGMMA_ROW_LAYOUT, fa.WGMMA_COL_LAYOUT) + def test_wgmma_row_col_store(self, in_layout): + element_value = 42.0 + shape = (64, ) + def body(ctx, result_gmem_ref, smem): + del ctx + + # Create input in registers + f32 = ir.F32Type.get() + x_type = ir.VectorType.get(shape, f32) + c = arith.constant(f32, element_value) + x = vector.broadcast(x_type, c) + cast = mgpu_dialect.layout_cast(x, layouts.to_layout_attr(in_layout)) + + # Registers -> SMEM + mgpu_dialect.vector_store(cast, smem) + + # SMEM -> GMEM + zero_i32 = arith.constant(ir.IntegerType.get_signless(32), 0) + mgpu_dialect.async_store( + source=smem, + destination=result_gmem_ref, + indices=[zero_i32], + slice_lengths=shape, + ) + nvvm.cp_async_bulk_wait_group(0) + utils.warpgroup_barrier() + + dtype = jnp.float32 + kernel = mgpu.as_gpu_kernel( + body, + grid=(1, 1, 1), + block=(128, 1, 1), + in_shape=(), + out_shape=jax.ShapeDtypeStruct(shape, dtype), + smem_scratch_shape=jax.ShapeDtypeStruct(shape, dtype), + thread_semantics=mgpu.LoweringSemantics.Warpgroup, + ) + + x = np.full(shape, element_value, dtype=dtype) + self.assertArraysEqual(kernel(), x) + + @parameterized.parameters( + # Positive offsets will be passsed as static offsets. + # Negative offsets will be converted to positive dynamic offsets. + dict( + full_shape=(2, 3, 128, 64), + sub_shape=(32, 64), + offsets=[-1, 0, -96, 0], + tiling=None, + swizzle=None, + ), + dict( + full_shape=(3, 128, 64), + sub_shape=(32, 64), + offsets=[-2, -96, 0], + tiling=[32, 64], + swizzle=mgpu_dialect.SwizzlingMode.k128ByteSwizzle, + ), + dict( + full_shape=(128, 128), + sub_shape=(64,), + offsets=[-1, 64], + tiling=[64], + swizzle=mgpu_dialect.SwizzlingMode.k128ByteSwizzle, + ), + ) + def test_subview( + self, + full_shape, + sub_shape, + offsets, + tiling, + swizzle, + ): + assert len(sub_shape) <= 2 + sizes = [1] * (len(full_shape) - len(sub_shape)) + list(sub_shape) + + def body( + ctx: launch_context.LaunchContext, + full_gmem_ref: ir.Value, + sub_gmem_ref: ir.Value, + smem: list[ir.Value], + ): + del ctx + full_smem_ref, tma_barrier = smem + + zero_i32 = arith.constant(ir.IntegerType.get_signless(32), 0) + # GMEM -> SMEM + operand_elt_type = ir.MemRefType(full_gmem_ref.type).element_type + bytes = utils.bytewidth(operand_elt_type) * math.prod(full_shape) + tma_barrier.arrive_expect_tx(bytes) + mgpu_dialect.async_load( + source=full_gmem_ref, + destination=full_smem_ref, + barrier=tma_barrier.as_barrier_memref(), + indices=[zero_i32] * len(full_shape), + slice_lengths=full_shape, + collective=ir.ArrayAttr.get([]), + ) + tma_barrier.wait() + + # SubView + mixed_offsets = [ + o if o >= 0 else arith.constant(ir.IndexType.get(), -o) + for o in offsets + ] + + full_ref_type = ir.MemRefType(full_smem_ref.type) + dynamic = ir.ShapedType.get_dynamic_stride_or_offset() + rhs_subview_ref_type = ir.MemRefType.get( + shape=sub_shape, + element_type=full_ref_type.element_type, + layout=ir.StridedLayoutAttr.get( + dynamic, [full_shape[-1], 1] if len(sub_shape) == 2 else [1] + ), + memory_space=full_ref_type.memory_space, + ) + sub_smem_ref = memref.subview( + full_smem_ref, + mixed_offsets, + sizes, + strides=[1] * len(sizes), + result_type=rhs_subview_ref_type, + ) + + transforms = [] + if tiling is not None: + transforms.append(mgpu_dialect.TileTransformAttr.get(tiling)) + if swizzle is not None: + transforms.append(mgpu_dialect.SwizzleTransformAttr.get(swizzle)) + + if transforms: + sub_smem_ref = mgpu_dialect.with_transforms( + sub_smem_ref, + transforms=ir.ArrayAttr.get(transforms), + ) + + # SMEM -> GMEM + mgpu_dialect.async_store( + source=sub_smem_ref, + destination=sub_gmem_ref, + indices=[zero_i32] * len(sub_shape), + slice_lengths=sub_shape, + ) + nvvm.cp_async_bulk_wait_group(0) + + el_type = jnp.bfloat16 + full_jax_shape = jax.ShapeDtypeStruct(full_shape, el_type) + result_jax_shape = jax.ShapeDtypeStruct(sub_shape, el_type) + + kernel = mgpu.as_gpu_kernel( + body, + grid=(1, 1, 1), + block=(128, 1, 1), + in_shape=(full_jax_shape), + out_shape=result_jax_shape, + smem_scratch_shape=[full_jax_shape, core.TMABarrier(1)], + thread_semantics=mgpu.LoweringSemantics.Warpgroup, + ) + x = self.prng.uniform(0, 10, full_shape).astype(el_type) + slicing = tuple(slice(abs(o), abs(o) + s) for o, s in zip(offsets, sizes)) + self.assertArraysEqual(kernel(x), x[slicing].reshape(sub_shape)) + + def test_custom_primitive_op(self): + # This test exercises the following cases: + # - The lowering handles nested blocks and regions (e.g. `scf.IfOp`). + # - The lowering updates references to inlined operations. + def body(ctx, result, scratch): + del ctx, scratch + i64 = ir.IntegerType.get_signless(64) + index = ir.IndexType.get() + op = mgpu_dialect.CustomPrimitiveOp( + result=[], + operands_=[result], + in_layouts=[], + in_transforms=[], + out_layouts=[], + ) + args_ty = [arg.type for arg in op.operands_] + block = op.body.blocks.append(*args_ty) + with ir.InsertionPoint(block): + is_leader_thread = single_thread_predicate() + with when(is_leader_thread): + c5 = arith.constant(i64, 5) + memref.store(c5, block.arguments[0], [c(0, index)]) + mgpu_dialect.return_([]) + + kernel = mgpu.as_gpu_kernel( + body, + grid=(1, 1, 1), + cluster=(1, 1, 1), + block=(128, 1, 1), + in_shape=(), + out_shape=jax.ShapeDtypeStruct((1,), jnp.int64), + smem_scratch_shape=(), + thread_semantics=mgpu.LoweringSemantics.Warpgroup, + ) + self.assertArraysEqual(kernel(), [5]) + + def test_profiler(self): + def body(ctx, input, result, scratch): + del scratch + with ctx.named_region("load"): + reg = mgpu_dialect.vector_load(input) + with ctx.named_region("store"): + mgpu_dialect.vector_store(reg, result) + + dtype = jnp.bfloat16 + shape = (128, 128) + jax_shape = jax.ShapeDtypeStruct(shape, dtype) + with tempfile.TemporaryDirectory() as tmpdir: + kernel = mgpu.as_gpu_kernel( + body, + grid=(1, 1, 1), + block=(128, 1, 1), + in_shape=(jax_shape), + out_shape=jax_shape, + smem_scratch_shape=[], + prof_spec=profiler.ProfilerSpec(1024, dump_path=tmpdir), + thread_semantics=mgpu.LoweringSemantics.Warpgroup, + ) + param = self.prng.uniform(-1, 1, shape).astype(dtype) + self.assertArraysEqual(kernel(param), param) + [name] = os.listdir(tmpdir) + with open(os.path.join(tmpdir, name)) as f: + data = f.read() + self.assertEqual(data.count('"name": "load"'), 2) + self.assertEqual(data.count('"name": "store"'), 2) + + @parameterized.parameters(((128,),), ((128, 128),)) + def test_tma_collective_async_cp(self, in_shape): + def body(ctx, src, dst, scratch): + del ctx + tmp, barrier = scratch + i32 = ir.IntegerType.get_signless(32) + zero_i32 = arith.constant(i32, 0) + src_type = ir.MemRefType(src.type) + barrier.arrive_expect_tx( + utils.bytewidth(src_type.element_type) * math.prod(src_type.shape) + ) + mgpu_dialect.async_load( + source=src, + destination=tmp, + indices=[zero_i32] * src_type.rank, + slice_lengths=src_type.shape, + collective=ir.ArrayAttr.get([ + ir.IntegerAttr.get(i32, mgpu_dialect.Dimension.x), + ]), + barrier=barrier.as_barrier_memref(), + ) + barrier.wait() + block_id = gpu.cluster_block_id(gpu.Dimension.x) + block_id = arith.index_cast(i32, block_id) + mgpu_dialect.async_store( + source=tmp, + destination=dst, + indices=[block_id] + [zero_i32] * src_type.rank, + slice_lengths=[-1, *src_type.shape], + ) + + dtype = jnp.float32 + kernel = mgpu.as_gpu_kernel( + body, + grid=(2, 1, 1), + cluster=(2, 1, 1), + block=(128, 1, 1), + in_shape=jax.ShapeDtypeStruct(in_shape, dtype), + out_shape=jax.ShapeDtypeStruct((2, *in_shape), dtype), + smem_scratch_shape=[ + jax.ShapeDtypeStruct(in_shape, dtype), + mgpu.TMABarrier(), + ], + thread_semantics=mgpu.LoweringSemantics.Warpgroup, + ) + x = self.prng.uniform(-1, 1, in_shape).astype(dtype) + self.assertArraysEqual(kernel(x), jnp.stack([x, x], axis=0)) + + def test_vector_extract_strided_slice(self): + def body(ctx, src, dst, scratch): + del ctx, scratch + src_vec = mgpu_dialect.vector_load(src) + src_vec = mgpu_dialect.layout_cast( + src_vec, layouts.to_layout_attr(fa.WGMMA_LAYOUT) + ) + dst_type = ir.MemRefType(dst.type) + dest_vec_type = ir.VectorType.get(dst_type.shape, dst_type.element_type) + sliced_vec = vector.extract_strided_slice( + dest_vec_type, + src_vec, + offsets=[0, 64], + sizes=[64, 64], + strides=[1, 1], + ) + mgpu_dialect.vector_store(sliced_vec, dst) + + dtype = jnp.float32 + kernel = mgpu.as_gpu_kernel( + body, + grid=(1, 1, 1), + block=(128, 1, 1), + in_shape=jax.ShapeDtypeStruct((128, 128), dtype), + out_shape=jax.ShapeDtypeStruct((64, 64), dtype), + smem_scratch_shape=[], + thread_semantics=mgpu.LoweringSemantics.Warpgroup, + ) + + x = self.prng.uniform(-1, 1, (128, 128)).astype(dtype) + self.assertArraysEqual(kernel(x), x[0:64, 64:128]) + + @parameterized.product( + dtype=(jnp.float32, jnp.int32, jnp.uint32), + dimension=(0, 1), + ) + def test_broadcasted_iota(self, dtype, dimension): + def body(ctx, out, scratch): + del ctx, scratch + result_type = ir.VectorType.get(out.type.shape, out.type.element_type) + iota = mgpu_dialect.broadcasted_iota(result_type, dimension) + mgpu_dialect.vector_store(iota, out) + + shape = (128, 128) + kernel = mgpu.as_gpu_kernel( + body, + grid=(1, 1, 1), + block=(128, 1, 1), + in_shape=(), + out_shape=jax.ShapeDtypeStruct(shape, dtype), + smem_scratch_shape=[], + thread_semantics=mgpu.LoweringSemantics.Warpgroup, ) + expected = jax.lax.broadcasted_iota(dtype, shape, dimension) + self.assertArraysEqual(kernel(), expected) + + @parameterized.parameters( + ((4, 64, 64), [[0], [1], [2]], (4, 64, 64), False), + ((4, 64, 64), [[0], [1, 2], [3]], (4, 4, 16, 64), False), + ((4, 8, 16, 64), [[0], [1], [2, 3], [4]], (4, 8, 2, 8, 64), False), + ((4, 64, 64), [[0, 1], [2], [3]], (2, 2, 64, 64), True), + ) + def test_memref_expand_shape( + self, input_shape, reassociation, output_shape, has_transforms + ): + def body( + ctx: launch_context.LaunchContext, + in_gmem_ref: ir.Value, + out_gmem_ref: ir.Value, + smem: list[ir.Value], + ): + del ctx + in_smem_ref, tma_barrier = smem + + zero_i32 = arith.constant(ir.IntegerType.get_signless(32), 0) + # GMEM -> SMEM + operand_elt_type = ir.MemRefType(in_gmem_ref.type).element_type + bytes = utils.bytewidth(operand_elt_type) * math.prod(input_shape) + tma_barrier.arrive_expect_tx(bytes) + mgpu_dialect.async_load( + source=in_gmem_ref, + destination=in_smem_ref, + barrier=tma_barrier.as_barrier_memref(), + indices=[zero_i32] * len(input_shape), + slice_lengths=input_shape, + collective=ir.ArrayAttr.get([]), + ) + tma_barrier.wait() + + # ExpandShape + expanded_smem_ref = memref.expand_shape( + result=ir.MemRefType.get( + output_shape, + in_smem_ref.type.element_type, + memory_space=in_smem_ref.type.memory_space, + ), + src=in_smem_ref, + reassociation=reassociation, + output_shape=[], + static_output_shape=output_shape, + ) + + if has_transforms: + transforms = [ + mgpu_dialect.TileTransformAttr.get((32,)), + mgpu_dialect.SwizzleTransformAttr.get(64), + ] + expanded_smem_ref = mgpu_dialect.with_transforms( + expanded_smem_ref, transforms=ir.ArrayAttr.get(transforms), + ) + + # SMEM -> GMEM + mgpu_dialect.async_store( + source=expanded_smem_ref, + destination=out_gmem_ref, + indices=[zero_i32] * len(output_shape), + slice_lengths=output_shape, + ) + nvvm.cp_async_bulk_wait_group(0) + + el_type = jnp.bfloat16 + in_jax_shape = jax.ShapeDtypeStruct(input_shape, el_type) + result_jax_shape = jax.ShapeDtypeStruct(output_shape, el_type) + + kernel = mgpu.as_gpu_kernel( + body, + grid=(1, 1, 1), + block=(128, 1, 1), + in_shape=(in_jax_shape), + out_shape=result_jax_shape, + smem_scratch_shape=[in_jax_shape, core.TMABarrier(1)], + thread_semantics=mgpu.LoweringSemantics.Warpgroup, + ) + x = self.prng.uniform(0, 10, input_shape).astype(el_type) + self.assertArraysEqual(kernel(x), x.reshape(output_shape)) + + @parameterized.product( + dtype=(jnp.int32, jnp.int64, jnp.uint32, jnp.uint64, jnp.float32, jnp.float16, jnp.bfloat16), + reduction_op=("add", "min", "max", "inc", "dec", "and", "or", "xor"), + ) + def test_async_store_reduction(self, dtype, reduction_op): + + if not config.enable_x64.value and dtype in (jnp.int64, jnp.uint64): + self.skipTest("x64 support is disabled") + + # TODO(b/415721295):Clean up after the minimal jaxlib version is 0.8.2. + if not hasattr(mgpu_dialect, "TMAReduction"): + self.skipTest("The mgpu_dialect.TMAReduction attribute is required.") + + if reduction_op in ("min", "max"): + if dtype in (jnp.int32, jnp.int64): + reduction_op = "s" + reduction_op + elif dtype in (jnp.uint32, jnp.uint64): + reduction_op = "u" + reduction_op + + if reduction_op in ("smin", "smax", "umin", "umax") and not hasattr(mgpu_dialect.TMAReduction, "Smin"): + self.skipTest("The Smin/Smax/Umin/Umax reduction types are required.") + + if ( + not launch_context._is_tma_reduction_op_supported( + reduction_op, + utils.dtype_to_ir_type(dtype), + ) + or ( + dtype in (jnp.uint32, jnp.uint64) + and reduction_op in ("smin", "smax") + ) + or ( + dtype in (jnp.int32, jnp.int64) and reduction_op in ("umin", "umax") + ) + or dtype == jnp.int32 and reduction_op in ("inc", "dec") + ): + self.skipTest("TMA does not support this reduction op for this dtype") + + shape = (8, 128) + + def body(ctx, src, dst, smem): + del ctx + src_smem_ref, tma_barrier = smem + i32 = ir.IntegerType.get_signless(32) + zero = arith.constant(i32, 0) + indices = [zero, zero] + slice_lengths = src_smem_ref.type.shape + + tma_barrier.arrive_expect_tx( + utils.bitwidth(src_smem_ref.type.element_type) * math.prod(shape) // 8 + ) + + mgpu_dialect.async_load( + source=src, + destination=src_smem_ref, + barrier=tma_barrier.as_barrier_memref(), + indices=indices, + slice_lengths=slice_lengths, + collective=ir.ArrayAttr.get([]), + ) + + tma_barrier.wait() + + reduction_attr = getattr( + mgpu_dialect.TMAReduction, reduction_op.capitalize() + ) + + mgpu_dialect.async_store( + source=src_smem_ref, + destination=dst, + indices=indices, + slice_lengths=slice_lengths, + reduction_op=reduction_attr, + ) + nvvm.cp_async_bulk_wait_group(0) + + prng_key = jax.random.key(1234) + k0, k1 = jax.random.split(prng_key, 2) + if dtype in (jnp.bfloat16, jnp.float16, jnp.float32): + src = jax.random.uniform(k0, shape, dtype, -10, 10) + dst = jax.random.uniform(k1, shape, dtype, -10, 10) + else: + src = jax.random.randint(k0, shape, -10, 10).astype(dtype) + dst = jax.random.randint(k1, shape, -10, 10).astype(dtype) + + if reduction_op == "add": + expected = src + dst + elif reduction_op in ("min", "smin", "umin"): + expected = jnp.minimum(src, dst) + elif reduction_op in ("max", "smax", "umax"): + expected = jnp.maximum(src, dst) + elif reduction_op == "and": + expected = src & dst + elif reduction_op == "or": + expected = src | dst + elif reduction_op == "xor": + expected = src ^ dst + elif reduction_op == "inc": + expected = jnp.where(dst >= src, 0, dst + 1) + elif reduction_op == "dec": + expected = jnp.where((dst == 0) | (dst > src), src, dst - 1) + else: + raise ValueError(f"Unsupported reduction op: {reduction_op}") - x = self.prng.uniform(-1, 1, spec.shape).astype(dtype) - y = self.prng.uniform(-1, 1, spec.shape).astype(dtype) + jax_shape = jax.ShapeDtypeStruct(shape, dtype) + kernel = mgpu.as_gpu_kernel( + body, + grid=(1, 1, 1), + block=(128, 1, 1), + in_shape=(jax_shape), + out_shape=(), + inout_shape=(jax_shape,), + smem_scratch_shape=[jax_shape, core.TMABarrier(1)], + thread_semantics=mgpu.LoweringSemantics.Warpgroup, + ) - self.assertArraysEqual(jax.jit(kernel)(x, y), x + y + y) + np.testing.assert_array_equal(kernel(src, dst)[0], expected) class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase): - @parameterized.named_parameters( - ( - f"swizzle={int(swizzle)}_{transpose_lhs=}_{transpose_rhs=}_{lhs_in_registers=}", - swizzle, - transpose_lhs, - transpose_rhs, - lhs_in_registers, - ) - for swizzle in mgpu_dialect.SwizzlingMode - for transpose_lhs in [False, True] - for transpose_rhs in [False, True] - for lhs_in_registers in [False, True] + @parameterized.product( + swizzle=tuple(mgpu_dialect.SwizzlingMode), + transpose_lhs=(False, True), + transpose_rhs=(False, True), + lhs_in_registers=(False, True), ) def test_wgmma_kernel_with_tma( - self, swizzle, transpose_lhs, transpose_rhs, load_a_in_registers + self, swizzle, transpose_lhs, transpose_rhs, lhs_in_registers ): if swizzle == mgpu_dialect.SwizzlingMode.kNoSwizzle: self.skipTest("No swizzle is not supported by wgmma") - if transpose_lhs or transpose_rhs: - self.skipTest("Transposes are not supported by transform inference yet.") + if transpose_lhs and lhs_in_registers: + self.skipTest("The A operand can only be transposed if it is in SMEM.") swizzle_elems = swizzle // np.dtype(jnp.bfloat16).itemsize tiling_m, tiling_n, tiling_k = 64, swizzle_elems, swizzle_elems @@ -2772,23 +5724,18 @@ def matmul( ): del ctx lhs_smem_ref, rhs_smem_ref, result_smem_ref, tma_barrier = smem - dialect_barrier = tma_barrier.as_dialect_barrier_memref() operand_elt_type = ir.MemRefType(lhs_gmem_ref.type).element_type bytes_a = utils.bytewidth(operand_elt_type) * math.prod(lhs_shape) bytes_b = utils.bytewidth(operand_elt_type) * math.prod(rhs_shape) - mgpu_dialect.arrive_expect_tx( - barrier=dialect_barrier, - expect_tx=bytes_a + bytes_b, - ) - - zero_i32 = arith.constant(ir.IntegerType.get_signless(32), 0) # GMEM -> SMEM + zero_i32 = arith.constant(ir.IntegerType.get_signless(32), 0) + tma_barrier.arrive_expect_tx(bytes_a + bytes_b) mgpu_dialect.async_load( source=lhs_gmem_ref, destination=lhs_smem_ref, - barrier=dialect_barrier, + barrier=tma_barrier.as_barrier_memref(), indices=[zero_i32] * len(lhs_shape), slice_lengths=lhs_shape, collective=ir.ArrayAttr.get([]), @@ -2796,15 +5743,12 @@ def matmul( mgpu_dialect.async_load( source=rhs_gmem_ref, destination=rhs_smem_ref, - barrier=dialect_barrier, + barrier=tma_barrier.as_barrier_memref(), indices=[zero_i32] * len(rhs_shape), slice_lengths=rhs_shape, collective=ir.ArrayAttr.get([]), ) - - parities = memref.load(tma_barrier.phases, []) - parity, _ = tma_barrier.update_parities(parities) - mgpu_dialect.wait(dialect_barrier, parity) + tma_barrier.wait() # Computation shape_result = ir.MemRefType(result_gmem_ref.type).shape @@ -2814,19 +5758,16 @@ def matmul( zero_acc = arith.constant( result_elt_type, ir.FloatAttr.get(acc_elt_type, 0.0) ) - accumulator = vector.splat(acc_type, zero_acc) + accumulator = vector.broadcast(acc_type, zero_acc) if transpose_lhs: lhs_smem_ref = utils.memref_transpose(lhs_smem_ref, (1, 0)) if transpose_rhs: rhs_smem_ref = utils.memref_transpose(rhs_smem_ref, (1, 0)) - zero_index = arith.constant(ir.IndexType.get(), 0) - if load_a_in_registers: + if lhs_in_registers: # SMEM -> Registers - lhs_ty = ir.VectorType.get(lhs_shape, operand_elt_type) - zero_vector_indices = [zero_index] * len(lhs_shape) - lhs_operand = vector.load(lhs_ty, lhs_smem_ref, zero_vector_indices) + lhs_operand = mgpu_dialect.vector_load(lhs_smem_ref) else: lhs_operand = lhs_smem_ref @@ -2840,7 +5781,7 @@ def matmul( nvvm.wgmma_wait_group_sync_aligned(0) # Registers -> SMEM - vector.store(result, result_smem_ref, [zero_index] * len(shape_result)) + mgpu_dialect.vector_store(result, result_smem_ref) # SMEM -> GMEM mgpu_dialect.async_store( @@ -2868,7 +5809,7 @@ def matmul( result_jax_shape, core.TMABarrier(1), ], - thread_semantics=mgpu.ThreadSemantics.Warpgroup, + thread_semantics=mgpu.LoweringSemantics.Warpgroup, ) prng_key = jax.random.key(1234) @@ -2879,15 +5820,606 @@ def matmul( transpose = lambda x, t: x.T if t else x self.assertArraysAllClose( - jax.jit(kernel)(x, y), - np.matmul( - transpose(x, transpose_lhs), - transpose(y, transpose_rhs) - ), + kernel(x, y), + np.matmul(transpose(x, transpose_lhs), transpose(y, transpose_rhs)), atol=0, rtol=0, ) + @parameterized.product( + dtype=(jnp.int8, jnp.uint8), + lhs_in_smem=(False, True), + ) + def test_integer_wgmma(self, dtype, lhs_in_smem): + m, k, n = 64, 128, 64 + + def body(ctx, lhs_gmem, rhs_gmem, result_gmem, scratch): + del ctx + lhs_smem, rhs_smem, tma_barrier = scratch + + i32 = ir.IntegerType.get_signless(32) + zero = arith.constant(i32, 0) + + tma_barrier.arrive_expect_tx(m * k + k * n) + mgpu_dialect.async_load( + source=lhs_gmem, + destination=lhs_smem, + barrier=tma_barrier.as_barrier_memref(), + indices=[zero, zero], + slice_lengths=lhs_smem.type.shape, + collective=ir.ArrayAttr.get([]), + ) + mgpu_dialect.async_load( + source=rhs_gmem, + destination=rhs_smem, + barrier=tma_barrier.as_barrier_memref(), + indices=[zero, zero], + slice_lengths=rhs_smem.type.shape, + collective=ir.ArrayAttr.get([]), + ) + tma_barrier.wait() + + acc_type = ir.VectorType.get((m, n), i32) + acc = vector.broadcast(acc_type, zero) + lhs = lhs_smem if lhs_in_smem else mgpu_dialect.vector_load(lhs_smem) + # Only f16 WGMMA supports transposes + rhs_smem = utils.memref_transpose(rhs_smem, (1, 0)) + result = mgpu_dialect.wgmma(acc, lhs, rhs_smem) + nvvm.wgmma_commit_group_sync_aligned() + nvvm.wgmma_wait_group_sync_aligned(0) + mgpu_dialect.vector_store(result, result_gmem) + + kernel = mgpu.as_gpu_kernel( + body, + grid=(1, 1, 1), + block=(128, 1, 1), + in_shape=( + jax.ShapeDtypeStruct((m, k), dtype), + jax.ShapeDtypeStruct((n, k), dtype), + ), + out_shape=jax.ShapeDtypeStruct((m, n), jnp.int32), + smem_scratch_shape=[ + jax.ShapeDtypeStruct((m, k), dtype), + jax.ShapeDtypeStruct((n, k), dtype), + core.TMABarrier(1), + ], + thread_semantics=mgpu.LoweringSemantics.Warpgroup, + ) + # Use small values to avoid overflow, [0, 8) for u8 and (-8, 8) for s8. + is_signed = jnp.issubdtype(dtype, jnp.signedinteger) + low, high = (-8, 8) if is_signed else (0, 8) + lhs = self.prng.uniform(low, high, (m, k)).astype(dtype) + rhs = self.prng.uniform(low, high, (n, k)).astype(dtype) + self.assertArraysEqual( + kernel(lhs, rhs), + np.matmul(lhs.astype(jnp.int32), rhs.astype(jnp.int32).T), + ) + + +class MosaicGpuDialectTCGen05Test(TestCase, jtu.JaxTestCase): + + def setUp(self): + super().setUp() + capabilities = ("10.0", "10.1") + if not any(jtu.is_cuda_compute_capability_equal(sm) for sm in capabilities): + self.skipTest("Only works on GPU with capability sm_100a or sm_101a") + + @parameterized.named_parameters( + ("unpacked", (128, 77), jnp.bfloat16, 1, False), + ("packed", (128, 128), jnp.bfloat16, 2, False), + ("collective", (128, 64), jnp.bfloat16, 1, True), + ) + def test_tmem_alloc_dealloc(self, shape, dtype, packing, collective): + tmem_type = ir.MemRefType.get( + shape, + utils.dtype_to_ir_type(dtype), + memory_space=utils.tmem(), + ) + + def body( + ctx: launch_context.LaunchContext, x: ir.Value, smem: list[ir.Value] + ): + # We need to have a result `x` otherwise the kernel will not be generated. + del ctx, x + tmem_ref = mgpu_dialect.tmem_alloc( + result=tmem_type, + smem_ptr=smem, + collective=collective, + packing=packing, + ) + + mgpu_dialect.tmem_relinquish_alloc_permit(collective=collective) + mgpu_dialect.tmem_dealloc(tmem_ref) + + with jtu.set_env(MOSAIC_GPU_DUMP_PTX="1"), self.capture_stdout() as ptx: + mgpu.as_gpu_kernel( + body, + grid=(2 if collective else 1, 1, 1), + cluster=(2 if collective else 1, 1, 1), + block=(128, 1, 1), + in_shape=(), + out_shape=(jax.ShapeDtypeStruct((), jnp.int32),), + smem_scratch_shape=jax.ShapeDtypeStruct((), jnp.int32), + thread_semantics=mgpu.LoweringSemantics.Warpgroup, + )() + [alloc] = re.findall( + r"tcgen05.alloc.cta_group::([12]).sync.aligned.shared::cta.b32", + ptx(), + ) + self.assertEqual(alloc[0], '2' if collective else '1') + + [ld] = re.findall( + r"ld.shared.b32\s+([%\w]+),\s+\[__dynamic_shmem__0\];", + ptx(), + ) + [dealloc] = re.findall( + r"tcgen05.dealloc.cta_group::([12]).sync.aligned.b32\s+([%\w]+),", + ptx(), + ) + self.assertEqual(dealloc[0], '2' if collective else '1') + self.assertEqual(dealloc[1], ld) + [relinquish] = re.findall( + r"tcgen05.relinquish_alloc_permit.cta_group::([12]).sync.aligned;", + ptx(), + ) + self.assertEqual(relinquish[0], "2" if collective else "1") + + @parameterized.named_parameters( + ("unpacked", 1, None), + ("packed", 2, None), + ("custom layout", None, tcgen05.tmem_default_layout(packing=1)), + ) + def test_tmem_load_store(self, packing, layout): + # TODO(bchetioui): add layout inference logic to handle packed/unpacked int8s. + dtype = jnp.bfloat16 + shape = (128, 128) + + def body( + ctx: launch_context.LaunchContext, + input: ir.Value, + result: ir.Value, + tmem: list[ir.Value], + ): + del ctx + + # GMEM -> registers + r_in = mgpu_dialect.vector_load(input) + + # registers -> TMEM + mgpu_dialect.async_store_tmem(r_in, tmem) + tcgen05.commit_tmem() + + # TMEM ->registers + r_out = mgpu_dialect.async_load_tmem(tmem) + # no need to wait in this case, see: + # https://docs.jax.dev/en/latest/pallas/gpu/reference.html#allocating-the-accumulator-using-tmem + + # Registers -> GMEM + mgpu_dialect.vector_store(r_out, result) + + jax_shape = jax.ShapeDtypeStruct(shape, dtype) + kernel = mgpu.as_gpu_kernel( + body, + grid=(1, 1, 1), + cluster=(1, 1, 1), + block=(128, 1, 1), + in_shape=jax_shape, + out_shape=jax_shape, + smem_scratch_shape=mgpu.TMEM( + shape, dtype, packing=packing, layout=layout + ), + thread_semantics=mgpu.LoweringSemantics.Warpgroup, + ) + + key = jax.random.key(1234) + x = jax.random.randint(key, shape, -10, 10).astype(dtype) + self.assertArraysEqual(kernel(x), x) + + @parameterized.product( + m=(64, 128), + n=(128, 256, 512), + # TODO(allanrenucci): Add 32-byte swizzle once implemented. + swizzle=(64, 128), + ab_type=(jnp.float16, jnp.bfloat16), + acc_type=(jnp.float16, jnp.float32), + a_in_tmem=(False, True), + ) + def test_tcgen05_mma(self, m, n, swizzle, ab_type, acc_type, a_in_tmem): + if acc_type == jnp.float16 and ab_type != jnp.float16: + self.skipTest("Only f16 input is supported for f16 output.") + if a_in_tmem and m != 128: + self.skipTest("Only M=128 is supported for MMA with A in TMEM.") + + swizzle_elems = swizzle // np.dtype(ab_type).itemsize + groups_k = 2 + k = swizzle_elems * groups_k + a_packing = 4 // np.dtype(ab_type).itemsize + tmem_cols = tcgen05.tmem_alloc_exact_ncols(n, exact=False) + if a_in_tmem: + tmem_cols += tcgen05.tmem_alloc_exact_ncols(k // a_packing, exact=False) + if tmem_cols > 512: + self.skipTest( + f"Number of TMEM colums ({tmem_cols}) exceeds the limit of 512" + " columns." + ) + a_shape = (m, k) + b_shape = (k, n) + bytes_a = np.dtype(ab_type).itemsize * math.prod(a_shape) + bytes_b = np.dtype(ab_type).itemsize * math.prod(b_shape) + acc_shape = (m, n) + + def matmul(ctx, a_gmem, b_gmem, result_gmem, scratch): + del ctx + a_smem, b_smem, tma_barrier, mma_barrier, acc_tmem, a_tmem = scratch + + zero_i32 = arith.constant(ir.IntegerType.get_signless(32), 0) + + # GMEM -> SMEM + tma_barrier.arrive_expect_tx(bytes_b) + mgpu_dialect.async_load( + source=b_gmem, + destination=b_smem, + barrier=tma_barrier.as_barrier_memref(), + indices=[zero_i32] * len(b_shape), + slice_lengths=b_shape, + collective=ir.ArrayAttr.get([]), + ) + tma_barrier.wait() + + if a_in_tmem: + # GMEM -> Registers -> TMEM + reg = mgpu_dialect.vector_load(a_gmem) + mgpu_dialect.async_store_tmem(reg, a_tmem) + tcgen05.commit_tmem() + else: + # GMEM -> SMEM + tma_barrier.arrive_expect_tx(bytes_a) + mgpu_dialect.async_load( + source=a_gmem, + destination=a_smem, + barrier=tma_barrier.as_barrier_memref(), + indices=[zero_i32] * len(a_shape), + slice_lengths=a_shape, + collective=ir.ArrayAttr.get([]), + ) + tma_barrier.wait() + + mgpu_dialect.tcgen05_mma( + accumulator=acc_tmem, + a=a_tmem if a_in_tmem else a_smem, + b=b_smem, + accumulate=arith.constant(ir.IntegerType.get_signless(1), False), + ) + tcgen05.commit_arrive(mma_barrier.barrier_ref) + + mma_barrier.wait(orders_tensor_core=True) + + # TMEM -> Registers -> GMEM + r_out = mgpu_dialect.async_load_tmem(acc_tmem) + mgpu_dialect.vector_store(r_out, result_gmem) + + # Required order: SMEM -> Barrier -> TMEM. + scratch_shape = [ + jax.ShapeDtypeStruct(a_shape, ab_type) if not a_in_tmem else None, + jax.ShapeDtypeStruct(b_shape, ab_type), + core.TMABarrier(1), + mgpu.Barrier(1), + mgpu.TMEM(acc_shape, acc_type), + mgpu.TMEM(a_shape, ab_type, packing=a_packing) if a_in_tmem else None, + ] + kernel = mgpu.as_gpu_kernel( + matmul, + grid=(1, 1, 1), + block=(128, 1, 1), + in_shape=( + jax.ShapeDtypeStruct(a_shape, ab_type), + jax.ShapeDtypeStruct(b_shape, ab_type), + ), + out_shape=jax.ShapeDtypeStruct(acc_shape, acc_type), + smem_scratch_shape=scratch_shape, + thread_semantics=mgpu.LoweringSemantics.Warpgroup, + ) + + a = self.prng.uniform(-1, 1, a_shape).astype(ab_type) + b = self.prng.uniform(-1, 1, b_shape).astype(ab_type) + + atol = 2e-2 if acc_type == jnp.float16 else 2e-5 + rtol = 8e-4 if acc_type == jnp.float16 else 1e-7 + self.assertArraysAllClose( + kernel(a, b), + np.matmul(a.astype(acc_type), b.astype(acc_type)), + atol=atol, + rtol=rtol, + ) + + @parameterized.product( + m=(128, 256), + n=(128, 256), + # TODO(allanrenucci): Add 32-byte swizzle once implemented. + swizzle=(64, 128), + ab_type=(jnp.float16, jnp.bfloat16), + acc_type=(jnp.float16, jnp.float32), + a_in_tmem=(False, True), + ) + def test_tcgen05_collective_mma(self, m, n, swizzle, ab_type, acc_type, a_in_tmem): + if acc_type == jnp.float16 and ab_type != jnp.float16: + self.skipTest("Only f16 input is supported for f16 output.") + if a_in_tmem and m != 256: + self.skipTest("Only M=256 is supported for MMA with A in TMEM.") + + swizzle_elems = swizzle // np.dtype(ab_type).itemsize + groups_k = 2 + k = swizzle_elems * groups_k + a_shape = (m, k) + a_block_shape = (m // 2, k) + a_packing = 4 // np.dtype(ab_type).itemsize + b_shape = (k, n) + b_block_shape = (k, n // 2) + bytes_a = np.dtype(ab_type).itemsize * math.prod(a_block_shape) + bytes_b = np.dtype(ab_type).itemsize * math.prod(b_block_shape) + acc_shape = (m, n) + acc_block_shape = (m // 2, n) + + def matmul(ctx, a_gmem, b_gmem, result_gmem, scratch): + (a_smem, b_smem, tma_barrier, mma_barrier, cluster_barrier, acc_tmem, a_tmem) = scratch + + i32_type = ir.IntegerType.get_signless(32) + zero_i32 = arith.constant(i32_type, 0) + + block_id = gpu.cluster_block_id(gpu.Dimension.x) + block_id_i32 = arith.index_cast(i32_type, block_id) + + m_index = arith.muli(block_id, arith.constant(ir.IndexType.get(), m // 2)) + m_index_i32 = arith.muli(block_id_i32, arith.constant(i32_type, m // 2)) + n_index_i32 = arith.muli(block_id_i32, arith.constant(i32_type, n // 2)) + + # GMEM -> SMEM + tma_barrier.arrive_expect_tx(bytes_b) + mgpu_dialect.async_load( + source=b_gmem, + destination=b_smem, + barrier=tma_barrier.as_barrier_memref(), + indices=[zero_i32, n_index_i32], + slice_lengths=b_block_shape, + collective=ir.ArrayAttr.get([]), + ) + tma_barrier.wait() + + if a_in_tmem: + # GMEM -> Registers -> TMEM + sliced_a_gmem = memref_slice(a_gmem, ds(m_index, m // 2)) + reg = mgpu_dialect.vector_load(sliced_a_gmem) + mgpu_dialect.async_store_tmem(reg, a_tmem) + tcgen05.commit_tmem() + else: + # GMEM -> SMEM + tma_barrier.arrive_expect_tx(bytes_a) + mgpu_dialect.async_load( + source=a_gmem, + destination=a_smem, + barrier=tma_barrier.as_barrier_memref(), + indices=[m_index_i32, zero_i32], + slice_lengths=a_block_shape, + collective=ir.ArrayAttr.get([]), + ) + tma_barrier.wait() + + # Make sure operands have been loaded on both blocks. + cluster_barrier.arrive(orders_tensor_core=True) + cluster_barrier.wait(orders_tensor_core=True) + + is_first_block = arith.cmpi( + arith.CmpIPredicate.eq, block_id, c(0, ir.IndexType.get()) + ) + with when(is_first_block): + mgpu_dialect.tcgen05_mma( + accumulator=acc_tmem, + a=a_tmem if a_in_tmem else a_smem, + b=b_smem, + accumulate=arith.constant(ir.IntegerType.get_signless(1), False), + collective=True, + ) + tcgen05.commit_arrive(mma_barrier.barrier_ref, collective=True, ctx=ctx) + + mma_barrier.wait(orders_tensor_core=True) + + # TMEM -> Registers -> GMEM + r_out = mgpu_dialect.async_load_tmem(acc_tmem) + sliced_result_gmem = memref_slice(result_gmem, ds(m_index, m // 2)) + mgpu_dialect.vector_store(r_out, sliced_result_gmem) + + # Required order: SMEM -> Barrier -> TMEM. + scratch_shape = [ + jax.ShapeDtypeStruct(a_block_shape, ab_type) if not a_in_tmem else None, + jax.ShapeDtypeStruct(b_block_shape, ab_type), + core.TMABarrier(1), + mgpu.Barrier(1), + mgpu.ClusterBarrier(collective_dims=(gpu.Dimension.x,)), + mgpu.TMEM(acc_block_shape, acc_type, collective=True), + mgpu.TMEM(a_block_shape, ab_type, collective=True, packing=a_packing) + if a_in_tmem + else None, + ] + kernel = mgpu.as_gpu_kernel( + matmul, + grid=(2, 1, 1), + cluster=(2, 1, 1), + block=(128, 1, 1), + in_shape=( + jax.ShapeDtypeStruct(a_shape, ab_type), + jax.ShapeDtypeStruct(b_shape, ab_type), + ), + out_shape=jax.ShapeDtypeStruct(acc_shape, acc_type), + smem_scratch_shape=scratch_shape, + thread_semantics=mgpu.LoweringSemantics.Warpgroup, + ) + + a = self.prng.uniform(-1, 1, a_shape).astype(ab_type) + b = self.prng.uniform(-1, 1, b_shape).astype(ab_type) + + atol = 2e-2 if acc_type == jnp.float16 else 2e-5 + rtol = 8e-4 if acc_type == jnp.float16 else 1e-7 + self.assertArraysAllClose( + kernel(a, b), + np.matmul(a.astype(acc_type), b.astype(acc_type)), + atol=atol, + rtol=rtol, + ) + + def test_slice_tmem(self): + def tmem_type(ref: ir.Value): + return ir.MemRefType.get( + ref.type.shape, ref.type.element_type, memory_space=utils.tmem() + ) + + def body(ctx, x, y, x_out, y_out, tmem): + del ctx + x_tmem = mgpu_dialect.slice_tmem(tmem_type(x), tmem, offset=0) + y_tmem = mgpu_dialect.slice_tmem(tmem_type(y), tmem, offset=128) + x_layout = layouts.to_layout_attr(tcgen05.tmem_default_layout(packing=2)) + x_tmem = mgpu_dialect.tmem_layout_cast(x_tmem, x_layout) + y_layout = layouts.to_layout_attr(tcgen05.tmem_default_layout(packing=1)) + y_tmem = mgpu_dialect.tmem_layout_cast(y_tmem, y_layout) + + # GMEM -> Registers -> TMEM + x_reg = mgpu_dialect.vector_load(x) + y_reg = mgpu_dialect.vector_load(y) + mgpu_dialect.async_store_tmem(x_reg, x_tmem) + mgpu_dialect.async_store_tmem(y_reg, y_tmem) + tcgen05.commit_tmem() + + # TMEM -> Registers -> GMEM + x_reg = mgpu_dialect.async_load_tmem(x_tmem) + y_reg = mgpu_dialect.async_load_tmem(y_tmem) + mgpu_dialect.vector_store(x_reg, x_out) + mgpu_dialect.vector_store(y_reg, y_out) + + in_out_shapes = ( + jax.ShapeDtypeStruct((128, 128), jnp.bfloat16), + jax.ShapeDtypeStruct((128, 64), jnp.int32), + ) + kernel = mgpu.as_gpu_kernel( + body, + grid=(1, 1, 1), + block=(128, 1, 1), + in_shape=in_out_shapes, + out_shape=in_out_shapes, + smem_scratch_shape=mgpu.TMEM((128, 512), jnp.int32), + thread_semantics=mgpu.LoweringSemantics.Warpgroup, + ) + x = self.prng.uniform(-100, 100, (128, 128)).astype(jnp.bfloat16) + y = self.prng.uniform(-100, 100, (128, 64)).astype(jnp.int32) + x_out, y_out = kernel(x, y) + self.assertArraysEqual(x_out, x) + self.assertArraysEqual(y_out, y) + + def test_tmem_subview(self): + def body(ctx, in_ref, out_ref, tmem): + del ctx + # GMEM -> Registers -> TMEM + in_reg = mgpu_dialect.vector_load(in_ref) + slice_in = memref.subview( + tmem, offsets=[0, 8], sizes=[128, 200], strides=[1, 1] + ) + slice_in = memref.subview( + slice_in, offsets=[0, 0], sizes=[128, 128], strides=[1, 1] + ) + mgpu_dialect.async_store_tmem(in_reg, slice_in) + tcgen05.commit_tmem() + + def dynamic_idx(idx: int) -> ir.Value: + idx_type = ir.IndexType.get() + return arith.constant(idx_type, idx) + + # TMEM -> Registers -> GMEM + slice_out = memref.subview( + tmem, + offsets=[dynamic_idx(0), dynamic_idx(8)], + sizes=[128, 128], + strides=[1, 1], + ) + out_reg = mgpu_dialect.async_load_tmem(slice_out) + mgpu_dialect.vector_store(out_reg, out_ref) + + kernel = mgpu.as_gpu_kernel( + body, + grid=(1, 1, 1), + block=(128, 1, 1), + in_shape=jax.ShapeDtypeStruct((128, 128), jnp.float32), + out_shape=jax.ShapeDtypeStruct((128, 128), jnp.float32), + smem_scratch_shape=mgpu.TMEM((128, 256), jnp.float32), + thread_semantics=mgpu.LoweringSemantics.Warpgroup, + ) + x = self.prng.uniform(-100, 100, (128, 128)).astype(jnp.float32) + self.assertArraysEqual(kernel(x), x) + + @parameterized.parameters(jnp.int32, jnp.int16, jnp.int8) + def test_tma_gather(self, index_dtype): + # TODO(b/415721295): Remove when the minimum jaxlib version is 0.8.3. + if not hasattr(mgpu_dialect, "tma_gather_supported"): + self.skipTest("TMA gather support is required.") + + dtype = jnp.bfloat16 + src_shape = (128, 64) + dst_shape = (32, 64) + indices_shape = (32,) + + def body(ctx, src, indices, dst, smem): + del ctx + smem_ref, tma_barrier = smem + + # Load indices into registers + indices_vec = mgpu_dialect.vector_load(indices) + i32 = ir.IntegerType.get_signless(32) + zero = arith.constant(i32, 0) + + slice_lengths = (32, 64) + + # Load + Gather + expected_bytes = math.prod(slice_lengths) * np.dtype(dtype).itemsize + tma_barrier.arrive_expect_tx(expected_bytes) + mgpu_dialect.async_load( + source=src, + destination=smem_ref, + barrier=tma_barrier.as_barrier_memref(), + indices=[indices_vec, zero], + slice_lengths=slice_lengths, + collective=ir.ArrayAttr.get([]), + ) + tma_barrier.wait() + + # Store + mgpu_dialect.async_store( + source=smem_ref, + destination=dst, + indices=[zero, zero], + slice_lengths=slice_lengths, + ) + nvvm.cp_async_bulk_wait_group(0) + + kernel = mgpu.as_gpu_kernel( + body, + grid=(1, 1, 1), + block=(128, 1, 1), + in_shape=( + jax.ShapeDtypeStruct(src_shape, dtype), + jax.ShapeDtypeStruct(indices_shape, index_dtype), + ), + out_shape=jax.ShapeDtypeStruct(dst_shape, dtype), + smem_scratch_shape=[ + jax.ShapeDtypeStruct((32, 64), dtype), + core.TMABarrier(1), + ], + thread_semantics=mgpu.LoweringSemantics.Warpgroup, + ) + + src = self.prng.uniform(-1, 1, src_shape).astype(dtype) + indices = jax.random.permutation(jax.random.key(0), 128)[:32].astype(index_dtype) + result = kernel(src, indices) + + # Verification + np.testing.assert_array_equal(result, src[indices.astype(jnp.int32)]) + class UtilsTest(TestCase): @parameterized.parameters( @@ -2916,6 +6448,34 @@ def test_parse_indices_oob(self, indices): with self.assertRaisesRegex(IndexError, "out of bounds"): utils.parse_indices(indices, (2, 3, 4)) + @jtu.thread_unsafe_test() # Modifies ``os.environ``. + def test_assert(self): + if cf is None: + self.skipTest("``cf`` is not available") + + def kernel(ctx: mgpu.LaunchContext, x_ref, out, scratch) -> None: + del ctx, out # Unused. + # TODO(b/408271232): Use a False condition once the bug is fixed. + x = mgpu.FragmentedArray.load_strided(x_ref) + cond = x.reduce("add", 0, *scratch) != 42.0 + cf.assert_(cond.registers.item(), "OOOPS") + + f = mgpu.as_gpu_kernel( + kernel, + grid=(1, 1, 1), + block=(128, 1, 1), + in_shape=(jax.ShapeDtypeStruct((128,), jnp.float32),), + out_shape=jax.ShapeDtypeStruct((128,), jnp.float32), + smem_scratch_shape=(jax.ShapeDtypeStruct((4,), jnp.float32),), + ) + + with jtu.set_env(MOSAIC_GPU_DUMP_SASS="1"), self.capture_stdout() as sass: + jax.block_until_ready(f(jnp.ones((128,), jnp.float32))) + + # SASS doesn't seem to include the assertion message, so we are just + # checking that __assertfail appears in the symbol table for the kernel. + self.assertIn("__assertfail", sass()) + class SerializationTest(absltest.TestCase): @@ -2931,5 +6491,418 @@ def test_pass_is_registered(self): pipeline.run(module.operation) +class ApiTest(TestCase): + + def test_inout(self): + def kernel(ctx, src, inout, dst, smem): + val = memref.load(inout, []) + gpu.barrier() + new_val = arith.constant(ir.IntegerType.get_signless(32), 42) + memref.store(new_val, inout, []) + x = mgpu.FragmentedArray.load_strided(src, is_signed=True) + (x + val).store_untiled(dst) + x = jnp.arange(128, dtype=jnp.int32) + y = jnp.asarray(2.0, dtype=jnp.int32) + kernel = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, x, (), inout_shape=y, + ) + xo, yo = kernel(x, y) + np.testing.assert_array_equal(xo, x + 2.0) + np.testing.assert_array_equal(yo, jnp.asarray(42, dtype=jnp.int32)) + + def test_serialize_uses_bytecode_format(self): + def kernel(ctx, src, dst, smem): + del ctx, smem + x = mgpu.FragmentedArray.load_strided(src, is_signed=True) + (x + 1).store_untiled(dst) + x = jnp.arange(128, dtype=jnp.int32) + with self.subTest("bytecode"): + f = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, x, (), + ) + bytecode_stablehlo = jax.jit(f).lower(x).as_text() + module_prefix = "module = \"ML\\EFR" + +if hp is not None: + @hps.composite + def tiled_layouts( + draw, initial_tile, vector_transfer: bool = False + ) -> fa.TiledLayout: + assert all(t.bit_count() == 1 for t in initial_tile) + assert math.prod(initial_tile) >= 128 + tiles = [initial_tile] + dim_offset = len(initial_tile) + if draw(hps.booleans()): + warp_dims = [fa.Replicated(2) if draw(hps.booleans()) else None for _ in range(2)] + else: + warp_dims = [fa.Replicated(4) if draw(hps.booleans()) else None] + for i, dim in enumerate(warp_dims): + if isinstance(dim, fa.Replicated): + continue + dim_size = 4 // len(warp_dims) + warp_dim = draw( + hps.sampled_from( + [i for i, t in enumerate(tiles[-1]) if t % dim_size == 0] + ) + ) + warp_tile = list(tiles[-1]) + warp_tile[warp_dim] //= dim_size + warp_dims[i] = dim_offset + warp_dim + tiles.append(warp_tile) + dim_offset += len(warp_tile) + lane_dims = [fa.Replicated(2) if draw(hps.booleans()) else None for _ in range(5)] + for i, dim in enumerate(lane_dims): + if isinstance(dim, fa.Replicated): + continue + lane_dim = draw(hps.sampled_from( + [i for i, t in enumerate(tiles[-1]) if t % 2 == 0] + )) + lane_tile = list(tiles[-1]) + lane_tile[lane_dim] //= 2 + lane_dims[i] = dim_offset + lane_dim + tiles.append(lane_tile) + dim_offset += len(lane_tile) + # Permute lane dims so that they don't always partition the data in order. + lane_dims = draw(hps.permutations(lane_dims)) + if vector_transfer: + min_vector_dim = len(tiles[-1]) - 1 + else: + min_vector_dim = 0 + vector_dim = draw(hps.integers(min_vector_dim, len(tiles[-1]) - 1)) + vector_size = 2 ** draw( + hps.integers(0, tiles[-1][vector_dim].bit_length() - 1) + ) + vector_tile = list(tiles[-1]) + assert vector_tile[vector_dim] % vector_size == 0 + vector_tile[vector_dim] //= vector_size + tiles.append(vector_tile) + dim_offset += len(vector_tile) + vector_dim += dim_offset + dim_offset += len(vector_tile) # This is the remainder after tiling! + + warp_dims = tuple( + d if isinstance(d, fa.Replicated) else d - dim_offset + for d in warp_dims + ) + lane_dims = tuple( + d if isinstance(d, fa.Replicated) else d - dim_offset + for d in lane_dims + ) + vector_dim = vector_dim - dim_offset + return fa.TiledLayout( + tiling=fa.Tiling(tuple(map(tuple, tiles))), + warp_dims=warp_dims, + lane_dims=lane_dims, + vector_dim=vector_dim, + _check_canonical=False, + ).canonicalize() + + @hps.composite + def shape_and_tiled_layout( + draw, vector_transfer: bool = False + ) -> tuple[tuple[int, ...], fa.TiledLayout]: + rank = draw(hps.integers(2, 3)) + initial_tile = tuple( + draw(hps.sampled_from([1, 2, 4, 8, 16, 32, 64, 128])) + for _ in range(rank) + ) + hp.assume(128 <= math.prod(initial_tile) < 128 * 32) + shape = tuple(t * draw(hps.integers(1, 5)) for t in initial_tile) + hp.assume(math.prod(shape) <= 128 * 128) + layout = draw(tiled_layouts(initial_tile, vector_transfer=vector_transfer)) + return shape, layout + + class HypothesisTest(TestCase): + + def test_reduce(self): + @hps.composite + def strategy(draw): + shape, layout = draw(shape_and_tiled_layout(vector_transfer=True)) + rank = len(shape) + reduced_dims = draw(hps.sets(hps.integers(0, rank - 1), min_size=1)) + dtype = draw(hps.sampled_from([jnp.int32, jnp.int16])) + op = draw(hps.sampled_from(["add", "max"])) + return shape, layout, tuple(reduced_dims), dtype, op + + warp_replicated_major = fa.TiledLayout( + fa.Tiling(((2,), (1,))), (fa.Replicated(2,), -2), (fa.Replicated(32,),), -1 + ) + warp_replicated_minor = fa.TiledLayout( + fa.Tiling(((2,), (1,))), (-2, fa.Replicated(2,)), (fa.Replicated(32,),), -1 + ) + warp_row_col_layout = fa.TiledLayout( + fa.Tiling(((2, 2), (1,))), (-3, -2), (fa.Replicated(32,),), -1 + ) + even_lane_split_layout = fa.TiledLayout( + fa.Tiling(((8,), (4,), (2,), (1,))), + (fa.Replicated(4),), + (-4, fa.Replicated(2), -3, fa.Replicated(2), -2), + -1, + ) + odd_lane_split_layout = fa.TiledLayout( + fa.Tiling(((4,), (2,), (1,))), + (fa.Replicated(4),), + (fa.Replicated(2), -3, fa.Replicated(2), -2, fa.Replicated(2)), + -1, + ) + + @hp.given(strategy()) + @hp.example(((16,), warp_replicated_major, (0,), jnp.int32, "add")) + @hp.example(((16,), warp_replicated_minor, (0,), jnp.int32, "add")) + @hp.example(((16, 16), warp_row_col_layout, (0,), jnp.int32, "add")) + @hp.example(((16, 16), warp_row_col_layout, (1,), jnp.int32, "add")) + @hp.example(((256,), even_lane_split_layout, (0,), jnp.float32, "max")) + @hp.example(((256,), odd_lane_split_layout, (0,), jnp.int32, "max")) + def run(args): + shape, layout, reduced_dims, dtype, op = args + out_shape = list(shape) + for d in sorted(reduced_dims, reverse=True): + del out_shape[d] + def kernel(ctx, src, dst, scratch): + del ctx + arr = fa.FragmentedArray.load_untiled(src, layout=layout, optimized=False, + is_signed=utils.is_signed(dtype)) + arr.reduce(op, reduced_dims, scratch).store_untiled(dst, optimized=False) + if jnp.issubdtype(dtype, jnp.integer): + x = jax.random.randint(jax.random.key(1234), shape, -1000, 1000, dtype) + else: + x = jax.random.normal(jax.random.key(1234), shape, dtype) + out_type = jax.ShapeDtypeStruct(out_shape, dtype) + scratch_type = jax.ShapeDtypeStruct((2048,), dtype) + hp.assume(layout.vector_length <= 16) # Otherwise we run out of scratch + try: + result = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, out_type, scratch_type + )(x) + except NotImplementedError: + hp.assume(False) + return + if op == "max": + ref = x.max(reduced_dims) + else: + ref = x.sum(reduced_dims, dtype=dtype) + np.testing.assert_array_equal(result, ref) + run() + + def test_slice(self): + i32 = ir.IntegerType.get_signless(32) + index = ir.IndexType.get() + + @hps.composite + def strategy(draw): + shape, layout = draw(shape_and_tiled_layout(vector_transfer=True)) + tiling = layout.base_tile_shape + tiled_shape = mgpu.tile_shape(shape, tiling)[:len(shape)] + def draw_slice(size, tile): + start = draw(hps.integers(0, size - 1)) + length = draw(hps.integers(1, size - start)) + return slice(start * tile, (start + length) * tile) + slices = tuple(map(draw_slice, tiled_shape, tiling)) + return shape, layout, slices + + basic_slices = (slice(128, 256), slice(16, 16 + 32)) + @hp.given(strategy()) + @hp.example(((256, 256), fa.WGMMA_LAYOUT, basic_slices)) + @hp.example(((256, 256), tcgen05.LAYOUT, basic_slices)) + @hp.example(((256, 256), tcgen05.TMEM_NATIVE_LAYOUT, basic_slices)) + def run(args): + shape, layout, slices = args + def kernel(ctx, dst, _): + def linear_index(*idxs): + total = arith.constant(index, 0) + stride = 1 + for i, size in zip(idxs[::-1], shape[::-1]): + total = arith.addi(total, arith.muli(i, c(stride, index))) + stride *= size + return arith.index_cast(i32, total) + x = mgpu.FragmentedArray.build( + shape, layout, linear_index, is_signed=True + ) + x[slices].store_untiled(dst, optimized=False) + + slice_shape = tuple(len(range(size)[s]) for s, size in zip(slices, shape)) + out_shape = jax.ShapeDtypeStruct(shape=slice_shape, dtype=jnp.int32) + result = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () + )() + iota = np.arange(np.prod(shape), dtype=jnp.int32).reshape(*shape) + np.testing.assert_array_equal(result, iota[slices]) + run() + + def test_broadcast(self): + @hps.composite + def strategy(draw): + shape, layout = draw(shape_and_tiled_layout(vector_transfer=True)) + rank = len(shape) + broadcast_dims = draw( + hps.sets(hps.integers(0, rank - 1), min_size=1, max_size=rank - 1) + ) + dtype = draw(hps.sampled_from([jnp.float32, jnp.bfloat16])) + return shape, layout, tuple(broadcast_dims), dtype + + @hp.given(strategy()) + def run(args): + out_shape, out_layout, broadcast_dims, dtype = args + in_shape = list(out_shape) + for d in sorted(broadcast_dims, reverse=True): + del in_shape[d] + in_layout = out_layout.reduce(broadcast_dims) + dims = tuple(d for d in range(len(out_shape)) if d not in broadcast_dims) + def kernel(ctx, src, dst, scratch): + del ctx, scratch # Unused. + arr = fa.FragmentedArray.load_untiled(src, layout=in_layout, optimized=False) + arr.broadcast_in_dim(out_shape, dims, out_layout).store_untiled(dst, optimized=False) + x = jax.random.normal(jax.random.key(1234), in_shape, dtype) + out_type = jax.ShapeDtypeStruct(out_shape, dtype) + try: + result = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, out_type, () + )(x) + except NotImplementedError: + hp.assume(False) + return + np.testing.assert_array_equal(result, jax.lax.broadcast_in_dim(x, out_shape, dims)) + run() + + @hp.given(hps.data()) + def test_canonicalize_trivial_dims(self, data): + layout = data.draw(tiled_layouts((128, 1))) + trivial_dims = [ + i + for i, d in fa.enumerate_negative(layout.tiled_tiling_shape) + if d == 1 and i != layout.vector_dim + ] + if not trivial_dims: + hp.assume(False) + # That should not happen in canonical layouts. + self.assertNoCommonElements(trivial_dims, layout.partitioned_warp_dims) + self.assertNoCommonElements(trivial_dims, layout.partitioned_lane_dims) + # vector_dim can be trivial. + canonical_layout = layout + use_trivial_dim = data.draw( + hps.lists(hps.booleans(), min_size=len(trivial_dims), max_size=len(trivial_dims)) + ) + hp.assume(any(use_trivial_dim)) + for d, use in zip(trivial_dims, use_trivial_dim): + if not use: + continue + if data.draw(hps.booleans()): # Should we put it in warp or lane dims? + new_warp_dims = list(layout.warp_dims) + position = data.draw(hps.integers(0, len(layout.warp_dims))) + new_warp_dims.insert(position, d) + layout = dataclasses.replace( + layout, warp_dims=tuple(new_warp_dims), _check_canonical=False + ) + else: + new_lane_dims = list(layout.lane_dims) + position = data.draw(hps.integers(0, len(layout.lane_dims))) + new_lane_dims.insert(position, d) + layout = dataclasses.replace( + layout, lane_dims=tuple(new_lane_dims), _check_canonical=False + ) + self.assertNotEqual(layout, canonical_layout) + self.assertEqual(layout.canonicalize(), canonical_layout) + + def test_copy_tiled(self): + @hps.composite + def strategy(draw): + swizzle = draw(hps.sampled_from([16, 32, 64, 128])) + dtype = draw(hps.sampled_from([jnp.int32, jnp.int16, jnp.int8])) + tiling = (8, swizzle // jnp.dtype(dtype).itemsize) + shape = [draw(hps.integers(1, 6)) for t in tiling] + while math.prod(shape) % 4: + shape[draw(hps.booleans())] *= 2 + shape = [s * t for s, t in zip(shape, tiling)] + to_smem = draw(hps.booleans()) + return shape, dtype, swizzle, to_smem + + @hp.given(strategy()) + @hp.example(((48, 64), jnp.int32, 16, False)) + @hp.example(((48, 64), jnp.int32, 32, False)) + @hp.example(((48, 64), jnp.int32, 64, False)) + @hp.example(((48, 64), jnp.int32, 128, False)) + @hp.example(((64, 4), jnp.int32, 16, False)) + def run(args): + shape, dtype, swizzle, to_smem = args + tiling = (8, 8 * swizzle // jnp.iinfo(dtype).bits) + def kernel(ctx, src, dst, scratch): + smem, barrier = scratch + if to_smem: + mgpu.copy_tiled(src, smem, swizzle=swizzle) + mgpu.commit_shared() + ctx.async_copy( + src_ref=smem, + dst_ref=dst, + gmem_transform=mgpu.TileTransform(tiling), + swizzle=swizzle, + ) + ctx.await_async_copy(0) + else: + ctx.async_copy( + src_ref=src, + dst_ref=smem, + gmem_transform=mgpu.TileTransform(tiling), + swizzle=swizzle, + barrier=barrier, + ) + barrier.wait() + mgpu.copy_tiled(smem, dst, swizzle=swizzle) + + x = jnp.arange(math.prod(shape), dtype=dtype).reshape(shape) + scratch_shape = [ + jax.ShapeDtypeStruct(mgpu.tile_shape(shape, tiling), dtype), + mgpu.TMABarrier(1), + ] + y = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, x, scratch_shape + )(x) + np.testing.assert_array_equal(y, x) + run() + + def test_dialect_vector_load_store(self): + @hps.composite + def strategy(draw): + shape, layout = draw(shape_and_tiled_layout(vector_transfer=True)) + return shape, layout + + @hp.given(strategy()) + @hp.example(((128, 128), fa.WGMMA_LAYOUT)) + @hp.example(((128, 128), fa.TCGEN05_LAYOUT)) + @hp.example(((128, 128), fa.TMEM_NATIVE_LAYOUT)) + def run(args): + shape, layout = args + dtype = jnp.float32 + layout_attr = layouts.to_layout_attr(layout) + + def body(ctx, input, result, smem): + del ctx + # GMEM -> Registers + reg = mgpu_dialect.vector_load(input) + reg = mgpu_dialect.layout_cast(reg, layout_attr) + # Registers -> SMEM + mgpu_dialect.vector_store(reg, smem) + # SMEM -> Registers + reg = mgpu_dialect.vector_load(smem) + reg = mgpu_dialect.layout_cast(reg, layout_attr) + # Registers -> GMEM + mgpu_dialect.vector_store(reg, result) + + jax_shape = jax.ShapeDtypeStruct(shape, dtype) + kernel = mgpu.as_gpu_kernel( + body, + grid=(1, 1, 1), + block=(128, 1, 1), + in_shape=jax_shape, + out_shape=jax_shape, + smem_scratch_shape=jax_shape, + thread_semantics=mgpu.LoweringSemantics.Warpgroup, + ) + + input = self.prng.uniform(-1, 1, shape).astype(dtype) + np.testing.assert_array_equal(kernel(input), input) + + run() + + if __name__ == "__main__": absltest.main(argv=["python"], testLoader=jtu.JaxTestLoader()) diff --git a/tests/mosaic/gpu_test_distributed.py b/tests/mosaic/gpu_test_distributed.py new file mode 100644 index 000000000000..244a43e9279c --- /dev/null +++ b/tests/mosaic/gpu_test_distributed.py @@ -0,0 +1,421 @@ +# Copyright 2025 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import os + +from absl.testing import parameterized +import jax +from jax._src import config +from jax._src import test_util as jtu +from jax._src import test_multiprocess as jt_multiprocess +from jax._src.interpreters import mlir +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import arith +from jax._src.lib.mlir.dialects import memref +from jax._src.lib.mlir.dialects import vector +from jax.experimental.mosaic.gpu import dialect as mgpu_dialect # pylint: disable=g-importing-member +from jax.experimental import multihost_utils +import jax.numpy as jnp +import numpy as np +import jax.experimental.mosaic.gpu as mgpu +import jax.experimental.mosaic.gpu.fragmented_array as fa + + +# ruff: noqa: F405 +# pylint: disable=g-complex-comprehension +P = jax.sharding.PartitionSpec + + +class TestCase(parameterized.TestCase): + + def setUp(self): + if (not jtu.test_device_matches(["cuda"]) or + not jtu.is_cuda_compute_capability_at_least("9.0")): + self.skipTest("Only works on GPU with capability >= sm90") + if not mgpu.supports_cross_device_collectives(): + self.skipTest("NVSHMEM library unavailable.") + if os.environ.get("XLA_PYTHON_CLIENT_ALLOCATOR", "") == "platform": + self.skipTest("NVSHMEM doesn't work with the platform allocator.") + if jax.process_count() == 1: + self.skipTest("Test requires multiple processes.") + if jax.device_count() != jax.process_count(): + self.skipTest("Need 1 device per process") + super().setUp() + self.prng = np.random.default_rng(1234) + self.context = mlir.make_ir_context() + if mgpu_dialect is not None: + mgpu_dialect.register_dialect(self.context) + self.enter_context(config.traceback_filtering("off")) + self.enter_context(self.context) + self.enter_context(ir.Location.unknown()) + + +class ProfilerTest(TestCase): + + def test_get_device_id(self): + index = ir.IndexType.get() + def kernel(ctx, dst, _): + device_id = ctx.device_id() + memref.store(device_id, dst, [arith.constant(index, 0)]) + mesh = jax.make_mesh( + (jax.device_count(),), ("x",), axis_types=(jax.sharding.AxisType.Explicit,) + ) + with jax.set_mesh(mesh): + out_shape = jax.ShapeDtypeStruct((1,), jnp.int32) + y = jax.jit( + jax.shard_map( + mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () + ), + out_specs=P("x"), + check_vma=False, + ) + )() + y_np = multihost_utils.process_allgather(y, tiled=True) + np.testing.assert_array_equal(y_np, np.arange(jax.device_count())) + + def test_remote_async_copy_basic(self): + i32 = ir.IntegerType.get_signless(32) + def kernel(ctx, src, sem, dst, scratch): + tmp, barrier = scratch + other_device = arith.subi(arith.constant(i32, 1), ctx.device_id()) + ctx.async_copy(src_ref=src, dst_ref=tmp, barrier=barrier) + barrier.wait() + ctx.async_copy(src_ref=tmp, dst_ref=dst, gmem_peer_id=other_device) + ctx.await_async_copy(0) + other_sem = mgpu.SemaphoreRef( + mgpu.utils.memref_ptr(ctx.to_remote(sem, other_device)) + ) + other_sem.signal(1) + my_sem = mgpu.SemaphoreRef(mgpu.utils.memref_ptr(sem)) + my_sem.wait(1) + + mesh = jax.make_mesh( + (2,), ("x",), axis_types=(jax.sharding.AxisType.Explicit,) + ) + with jax.set_mesh(mesh): + x_np = np.arange(64 * 64, dtype=jnp.float32).reshape(64, 64) + x = jax.sharding.reshard(x_np, P("x")) + sem = jax.sharding.reshard(jnp.zeros((1,), dtype=jnp.int32), P()) + y, _ = jax.jit( + jax.shard_map( + lambda x, sem: mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, x, (x, mgpu.TMABarrier()), inout_shape=sem + )(x, sem), + in_specs=(P("x"), P(None)), + out_specs=[P("x"), P(None)], + check_vma=False, + ) + )(x, sem) + y_np = multihost_utils.process_allgather(y, tiled=True) + np.testing.assert_array_equal( + y_np, np.concatenate(np.split(x_np, 2)[::-1], axis=0) + ) + + def test_remote_async_copy_add(self): + i32 = ir.IntegerType.get_signless(32) + def kernel(ctx, src, sem, dst, scratch): + tmp, barrier = scratch + other_device = arith.subi(arith.constant(i32, 1), ctx.device_id()) + other_sem = mgpu.SemaphoreRef( + mgpu.utils.memref_ptr(ctx.to_remote(sem, other_device)) + ) + my_sem = mgpu.SemaphoreRef(mgpu.utils.memref_ptr(sem)) + ctx.async_copy(src_ref=src, dst_ref=tmp, barrier=barrier) + barrier.wait() + fa.FragmentedArray.splat(arith.constant(ir.F32Type.get(), 1.0), (32, 64)).store_untiled(dst) + mgpu.warpgroup_barrier() + other_sem.signal(1) + my_sem.wait(1) + ctx.async_copy(src_ref=tmp, dst_ref=dst, gmem_peer_id=other_device, reduction_op="add") + ctx.await_async_copy(0) + other_sem.signal(1) + my_sem.wait(1) + + mesh = jax.make_mesh( + (2,), ("x",), axis_types=(jax.sharding.AxisType.Explicit,) + ) + with jax.set_mesh(mesh): + x_np = np.arange(64 * 64, dtype=jnp.float32).reshape(64, 64) + x = jax.sharding.reshard(x_np, P("x")) + sem = jax.sharding.reshard(jnp.zeros((1,), dtype=jnp.int32), P()) + y, _ = jax.jit( + jax.shard_map( + lambda x, sem: mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, x, (x, mgpu.TMABarrier()), inout_shape=sem + )(x, sem), + in_specs=(P("x"), P(None)), + out_specs=[P("x"), P(None)], + check_vma=False, + ) + )(x, sem) + y_np = multihost_utils.process_allgather(y, tiled=True) + np.testing.assert_array_equal( + y_np, 1 + np.concatenate(np.split(x_np, 2)[::-1], axis=0) + ) + + def test_remote_semaphore(self): + i32 = ir.IntegerType.get_signless(32) + def kernel(ctx, sem, _): + my_device = ctx.device_id() + other_device = arith.subi(arith.constant(i32, 1), my_device) + my_sem = mgpu.SemaphoreRef(mgpu.utils.memref_ptr(sem)) + other_dst = ctx.to_remote(sem, other_device) + other_sem = mgpu.SemaphoreRef(mgpu.utils.memref_ptr(other_dst)) + # We signal and wait a different amount on each device to make sure we're + # really communicating here. + other_sem.signal(arith.addi(arith.constant(i32, 1), other_device)) + @mgpu.fori(arith.addi(arith.constant(i32, 1), my_device), None) + def wait_loop(i, _): + my_sem.wait(1) + + mesh = jax.make_mesh( + (2,), ("x",), axis_types=(jax.sharding.AxisType.Explicit,) + ) + with jax.set_mesh(mesh): + sem = jax.sharding.reshard(jnp.zeros((1,), dtype=jnp.int32), P()) + out_sem = jax.jit( + jax.shard_map( + mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), (), (), inout_shape=sem + ), + out_specs=P("x"), + check_vma=False, + ) + )(sem) + out_sems = multihost_utils.process_allgather(out_sem, tiled=True) + np.testing.assert_array_equal(out_sems, np.zeros_like(out_sems)) + + @parameterized.parameters(1, 2, 4) + def test_multimem_basic(self, vector_length): + i32 = ir.IntegerType.get_signless(32) + index = ir.IndexType.get() + def kernel(ctx, sem, out, _): + my_device = ctx.device_id() + other_device = arith.subi(arith.constant(i32, 1), my_device) + my_sem = mgpu.SemaphoreRef(mgpu.utils.memref_ptr(sem)) + other_dst = ctx.to_remote(sem, other_device) + other_sem = mgpu.SemaphoreRef(mgpu.utils.memref_ptr(other_dst)) + with mgpu.when(arith.cmpi(arith.CmpIPredicate.eq, my_device, arith.constant(i32, 0))): + c = arith.constant(i32, 1) + vc = vector.broadcast(ir.VectorType.get((vector_length,), i32), c) + multicast_ref = ctx.to_remote_multicast(out) + multicast_ref.store(vc, [arith.constant(index, 0)]) + other_sem.signal(arith.constant(i32, 1)) + my_sem.wait(1) + + mesh = jax.make_mesh( + (2,), ("x",), axis_types=(jax.sharding.AxisType.Explicit,) + ) + with jax.set_mesh(mesh): + sem = jax.sharding.reshard(jnp.zeros((1,), dtype=jnp.int32), P()) + out_shape = jax.ShapeDtypeStruct((vector_length,), jnp.int32) + out, out_sem = jax.jit( + jax.shard_map( + mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), out_shape, (), inout_shape=sem + ), + out_specs=P("x"), + check_vma=False, + ) + )(sem) + out_sems = multihost_utils.process_allgather(out_sem, tiled=True) + np.testing.assert_array_equal(out_sems, np.zeros_like(out_sems)) + out = multihost_utils.process_allgather(out, tiled=True) + np.testing.assert_array_equal(out, np.ones_like(out)) + + def test_multimem_store_registers(self): + i32 = ir.IntegerType.get_signless(32) + def kernel(ctx, inp, sem, out, _): + my_device = ctx.device_id() + other_device = arith.subi(arith.constant(i32, 1), my_device) + my_sem = mgpu.SemaphoreRef(mgpu.utils.memref_ptr(sem)) + other_dst = ctx.to_remote(sem, other_device) + other_sem = mgpu.SemaphoreRef(mgpu.utils.memref_ptr(other_dst)) + with mgpu.when(arith.cmpi(arith.CmpIPredicate.eq, my_device, arith.constant(i32, 0))): + arr = mgpu.FragmentedArray.load_strided(inp, is_signed=True) + arr.store_untiled(ctx.to_remote_multicast(out), optimized=False) + other_sem.signal(arith.constant(i32, 1)) + my_sem.wait(1) + + mesh = jax.make_mesh( + (2,), ("x",), axis_types=(jax.sharding.AxisType.Explicit,) + ) + with jax.set_mesh(mesh): + sem = jax.sharding.reshard(jnp.zeros((1,), dtype=jnp.int32), P()) + x = jax.sharding.reshard(jnp.arange(2048, dtype=jnp.int32).reshape(64, 32), P()) + y, out_sem = jax.jit( + jax.shard_map( + mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, x, (), inout_shape=sem + ), + out_specs=P("x"), + check_vma=False, + ) + )(x, sem) + out_sems = multihost_utils.process_allgather(out_sem, tiled=True) + np.testing.assert_array_equal(out_sems, np.zeros_like(out_sems)) + y = multihost_utils.process_allgather(y, tiled=True).reshape(2, *x.shape) + np.testing.assert_array_equal(y, jnp.stack([x, x])) + + def test_multimem_store_tma(self): + i32 = ir.IntegerType.get_signless(32) + def kernel(ctx, inp, sem, out, scratch): + my_device = ctx.device_id() + other_device = arith.subi(arith.constant(i32, 1), my_device) + my_sem = mgpu.SemaphoreRef(mgpu.utils.memref_ptr(sem)) + other_dst = ctx.to_remote(sem, other_device) + other_sem = mgpu.SemaphoreRef(mgpu.utils.memref_ptr(other_dst)) + with mgpu.when(arith.cmpi(arith.CmpIPredicate.eq, my_device, arith.constant(i32, 0))): + arr = mgpu.FragmentedArray.load_strided(inp, is_signed=True) + arr.store_untiled(scratch) + mgpu.commit_shared() + ctx.async_copy( + src_ref=scratch, dst_ref=out, gmem_peer_id=mgpu.GLOBAL_BROADCAST + ) + ctx.await_async_copy(0) + other_sem.signal(arith.constant(i32, 1)) + my_sem.wait(1) + + mesh = jax.make_mesh( + (2,), ("x",), axis_types=(jax.sharding.AxisType.Explicit,) + ) + with jax.set_mesh(mesh): + sem = jax.sharding.reshard(jnp.zeros((1,), dtype=jnp.int32), P()) + x = jax.sharding.reshard(jnp.arange(2048, dtype=jnp.int32).reshape(64, 32), P()) + y, out_sem = jax.jit( + jax.shard_map( + mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, x, x, inout_shape=sem + ), + out_specs=P("x"), + check_vma=False, + ) + )(x, sem) + out_sems = multihost_utils.process_allgather(out_sem, tiled=True) + np.testing.assert_array_equal(out_sems, np.zeros_like(out_sems)) + y = multihost_utils.process_allgather(y, tiled=True).reshape(2, *x.shape) + np.testing.assert_array_equal(y, jnp.stack([x, x])) + + @parameterized.parameters( + (jnp.int32, 1, "add"), + (jnp.int32, 1, "min"), + (jnp.int32, 1, "max"), + (jnp.int32, 1, "and"), + (jnp.int32, 1, "or"), + (jnp.int32, 1, "xor"), + (jnp.float32, 1, "add"), + (jnp.float32, 2, "add"), + (jnp.float32, 4, "add"), + (jnp.float16, 2, "add"), + (jnp.float16, 2, "min"), + (jnp.float16, 4, "max"), + (jnp.float16, 8, "add"), + (jnp.bfloat16, 2, "max"), + (jnp.bfloat16, 8, "add"), + (jnp.float8_e5m2, 4, "add"), + (jnp.float8_e5m2, 8, "min"), + (jnp.float8_e5m2, 16, "max"), + (jnp.float8_e4m3fn, 4, "min"), + (jnp.float8_e4m3fn, 8, "max"), + (jnp.float8_e4m3fn, 16, "add"), + ) + def test_multimem_load_reduce(self, dtype, vector_length, reduction): + if dtype in ( + jnp.float8_e5m2, + jnp.float8_e4m3fn, + ) and not jtu.is_cuda_compute_capability_at_least("10.0"): + self.skipTest("Only works on GPU with capability >= sm100") + i32 = ir.IntegerType.get_signless(32) + def kernel(ctx, inp, sem, out, _): + my_device = ctx.device_id() + other_device = arith.subi(arith.constant(i32, 1), my_device) + my_sem = mgpu.SemaphoreRef(mgpu.utils.memref_ptr(sem)) + other_dst = ctx.to_remote(sem, other_device) + other_sem = mgpu.SemaphoreRef(mgpu.utils.memref_ptr(other_dst)) + layout = fa.WGStridedFragLayout((64, 32), vec_size=vector_length) + arr = mgpu.FragmentedArray.load_reduce_untiled( + ctx.to_remote_multicast(inp), + layout=layout, + is_signed=mgpu.utils.is_signed(dtype), + reduction=reduction, + ) + arr.store_untiled(out, optimized=False) + other_sem.signal(arith.constant(i32, 1)) + my_sem.wait(1) + + mesh = jax.make_mesh( + (2,), ("x",), axis_types=(jax.sharding.AxisType.Explicit,) + ) + with jax.set_mesh(mesh): + sem = jax.sharding.reshard(jnp.zeros((1,), dtype=jnp.int32), P()) + # The rounding we see in low precision types seems to be different from + # what JAX/XLA use. + match jnp.dtype(dtype).itemsize: + case 4: + bound = 800000 + case 2: + bound = 128 + case 1: + bound = 4 + case _: + raise ValueError(f"Unsupported dtype: {dtype}") + x_local = jax.random.randint( + jax.random.key(1234), (128, 32), dtype=jnp.int32, minval=-bound, maxval=bound + ).astype(dtype) + x = jax.sharding.reshard(x_local, P("x")) + x_shard = jax.ShapeDtypeStruct((64, 32), dtype) + # TODO(b/448323639): We don't need x to be inout here, but without aliasing + # XLA doesn't actually insert the copy that puts the operand in symmetric + # memory, which causes the kernel to crash. + y, _, out_sem = jax.jit( + jax.shard_map( + mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), x_shard, (), inout_shape=(x_shard, sem) + ), + out_specs=P("x"), + check_vma=False, + ) + )(x, sem) + out_sems = multihost_utils.process_allgather(out_sem, tiled=True) + np.testing.assert_array_equal(out_sems, np.zeros_like(out_sems)) + y = multihost_utils.process_allgather(y, tiled=True) + match reduction: + case "add": + np_reduction = jnp.add + case "min": + np_reduction = jnp.minimum + case "max": + np_reduction = jnp.maximum + case "and": + np_reduction = jnp.bitwise_and + case "or": + np_reduction = jnp.bitwise_or + case "xor": + np_reduction = jnp.bitwise_xor + case _: + raise ValueError(reduction) + np.testing.assert_array_equal( + y.astype(jnp.float32), np.tile(np_reduction(x_local[:64], x_local[64:]), (2, 1)) + ) + + +if __name__ == "__main__": + # This test doesn't work with the platform allocator, so we override it + # if it's ran alone. If it's part of a larger test suite and the platform + # allocator is used, setUp will skip the test. + os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.01' + os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'default' + jt_multiprocess.main() diff --git a/tests/mosaic/gpu_test_multidevice.py b/tests/mosaic/gpu_test_multidevice.py new file mode 100644 index 000000000000..114409a5efd8 --- /dev/null +++ b/tests/mosaic/gpu_test_multidevice.py @@ -0,0 +1,74 @@ +# Copyright 2025 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from absl.testing import absltest, parameterized +import jax +from jax._src import config +from jax._src import test_util as jtu +from jax._src.interpreters import mlir +from jax._src.lib.mlir import ir +from jax.experimental.mosaic.gpu import dialect as mgpu_dialect # pylint: disable=g-importing-member +import jax.numpy as jnp +import numpy as np +try: + import jax._src.lib.mosaic_gpu # noqa: F401 + HAS_MOSAIC_GPU = True +except ImportError: + HAS_MOSAIC_GPU = False +else: + import jax.experimental.mosaic.gpu as mgpu + + +# ruff: noqa: F405 +# pylint: disable=g-complex-comprehension +config.parse_flags_with_absl() + + +class TestCase(parameterized.TestCase): + + def setUp(self): + if not HAS_MOSAIC_GPU: + self.skipTest("jaxlib built without Mosaic GPU") + if (not jtu.test_device_matches(["cuda"]) or + not jtu.is_cuda_compute_capability_at_least("9.0")): + self.skipTest("Only works on GPU with capability >= sm90") + super().setUp() + self.prng = np.random.default_rng(1234) + self.context = mlir.make_ir_context() + if mgpu_dialect is not None: + mgpu_dialect.register_dialect(self.context) + self.enter_context(config.traceback_filtering("off")) + self.enter_context(self.context) + self.enter_context(ir.Location.unknown()) + + +class ProfilerTest(TestCase): + + def test_multigpu(self): + if len(jax.devices()) < 2: + self.skipTest("Need at least 2 devices") + def kernel(ctx, src, dst, _): + mgpu.FragmentedArray.load_strided(src).store_untiled(dst) + x = np.arange(64 * 64, dtype=jnp.float32).reshape(64, 64) + f = jax.jit(mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, x, () + )) + # Make sure we can invoke the same program on different devices. + for xd in (jax.device_put(x, d) for d in jax.devices()[:2]): + jax.block_until_ready(f(xd)) + + +if __name__ == "__main__": + absltest.main(argv=["python"], testLoader=jtu.JaxTestLoader()) diff --git a/tests/mosaic/gpu_torch_test.py b/tests/mosaic/gpu_torch_test.py new file mode 100644 index 000000000000..9fb4160686b7 --- /dev/null +++ b/tests/mosaic/gpu_torch_test.py @@ -0,0 +1,93 @@ +# Copyright 2025 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import unittest + +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax._src import config +from jax._src import test_util as jtu +from jax._src.interpreters import mlir +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import arith +from jax._src.lib.mlir.dialects import gpu +import jax.experimental.mosaic.gpu as mgpu +from jax.experimental.mosaic.gpu import dialect as mgpu_dialect # pylint: disable=g-importing-member +from jax.experimental.mosaic.gpu.utils import * # noqa: F403 +import jax.numpy as jnp +import numpy as np + +try: + import torch +except ImportError: + torch = None + + +# ruff: noqa: F405 +# pylint: disable=g-complex-comprehension +config.parse_flags_with_absl() + + +class TorchTest(parameterized.TestCase): + + def setUp(self): + if (not jtu.test_device_matches(["cuda"]) or + not jtu.is_cuda_compute_capability_at_least("9.0")): + self.skipTest("Only works on GPU with capability >= sm90") + super().setUp() + self.prng = np.random.default_rng(1234) + self.context = mlir.make_ir_context() + mgpu_dialect.register_dialect(self.context) + self.enter_context(config.traceback_filtering("off")) + self.enter_context(self.context) + self.enter_context(ir.Location.unknown()) + if torch is None: + raise unittest.SkipTest("Test requires PyTorch") + + def test_basic(self): + def kernel(ctx, i_gmem, o_gmem, _): + x = mgpu.FragmentedArray.load_strided(i_gmem) + (x + x).store_untiled(o_gmem) + + ty = jax.ShapeDtypeStruct((128, 128), jnp.float32) + x = torch.randn((128, 128), dtype=torch.float, device='cuda') + f = mgpu.as_torch_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), ty, ty, ()) + y = f(x) + np.testing.assert_allclose(y.cpu(), x.cpu() * 2) + del y # Make sure the destructor runs successfully. + + def test_inout(self): + def kernel(ctx, src, inout, dst, smem): + val = memref.load(inout, []) + gpu.barrier() + new_val = arith.constant(ir.IntegerType.get_signless(32), 42) + memref.store(new_val, inout, []) + x = mgpu.FragmentedArray.load_strided(src, is_signed=True) + (x + val).store_untiled(dst) + x = torch.arange(128, dtype=torch.int32, device='cuda') + y = torch.tensor(2.0, dtype=torch.int32, device='cuda') + x_ty = jax.ShapeDtypeStruct((128,), jnp.int32) + y_ty = jax.ShapeDtypeStruct((), jnp.int32) + kernel = mgpu.as_torch_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x_ty, x_ty, (), inout_shape=y_ty, + ) + xo, yo = kernel(x, y) + np.testing.assert_array_equal(xo.cpu(), x.cpu() + 2.0) + np.testing.assert_array_equal(yo.cpu(), torch.tensor(42, dtype=torch.int32)) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/mosaic/gpu_torch_test_distributed.py b/tests/mosaic/gpu_torch_test_distributed.py new file mode 100644 index 000000000000..61dab1e9abda --- /dev/null +++ b/tests/mosaic/gpu_torch_test_distributed.py @@ -0,0 +1,132 @@ +# Copyright 2025 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import os +import unittest + +from absl.testing import parameterized +import jax +from jax._src import config +from jax._src import test_util as jtu +from jax._src import test_multiprocess as jt_multiprocess +from jax._src.interpreters import mlir +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import arith +from jax._src.lib.mlir.dialects import memref +from jax.experimental.mosaic.gpu import dialect as mgpu_dialect # pylint: disable=g-importing-member +import jax.numpy as jnp +import numpy as np +import jax.experimental.mosaic.gpu as mgpu +try: + import torch + import torch.distributed as dist + import torch.distributed._symmetric_memory as symm_mem +except ImportError: + torch = None + + +# ruff: noqa: F405 +# pylint: disable=g-complex-comprehension + + +class TorchTest(parameterized.TestCase): + + def setUpClass(): + torch.cuda.set_device("cuda:0") + torch.set_default_device("cuda") + if torch is None: + raise unittest.SkipTest("Test requires torch") + if not torch.cuda.is_available(): + raise unittest.SkipTest("Test requires torch with CUDA support") + if (not jtu.test_device_matches(["cuda"]) or + not jtu.is_cuda_compute_capability_at_least("9.0")): + raise unittest.SkipTest("Only works on GPU with capability >= sm90") + device_count = torch.cuda.device_count() + for d1 in range(device_count - 1): + for d2 in range(d1 + 1, device_count): + if not torch.cuda.can_device_access_peer(d1, d2): + raise unittest.SkipTest("Test requires p2p access") + if jax.process_count() == 1: + raise unittest.SkipTest("Test requires multiple processes.") + if jax.device_count() != jax.process_count(): + raise unittest.SkipTest("Need 1 device per process") + + os.environ["RANK"] = str(jax.process_index()) + os.environ["WORLD_SIZE"] = str(jax.process_count()) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "5728" + dist.init_process_group("nccl") + symm_mem.enable_symm_mem_for_group(dist.group.WORLD.group_name) + assert dist.is_initialized() + assert symm_mem.is_nvshmem_available() + symm_mem.set_backend("NVSHMEM") + symm_mem.empty(1) # Just to initialize NVSHMEM + + def setUp(self): + self.prng = np.random.default_rng(1234) + self.context = mlir.make_ir_context() + if mgpu_dialect is not None: + mgpu_dialect.register_dialect(self.context) + self.enter_context(config.traceback_filtering("off")) + self.enter_context(self.context) + self.enter_context(ir.Location.unknown()) + + def test_get_device_id(self): + index = ir.IndexType.get() + def kernel_body(ctx, dst, _): + device_id = ctx.device_id() + memref.store(device_id, dst, [arith.constant(index, 0)]) + + out_shape = jax.ShapeDtypeStruct((1,), jnp.int32) + kernel = mgpu.as_torch_gpu_kernel( + kernel_body, (1, 1, 1), (128, 1, 1), (), out_shape, () + ) + gathered = torch.empty((2,), dtype=torch.int32) + dist.all_gather_into_tensor(gathered, kernel()) + self.assertEqual(gathered.tolist(), list(range(jax.process_count()))) + + def test_remote_semaphore(self): + if dist.get_world_size() != 2: + self.skipTest("Test assumes 2 devices") + + i32 = ir.IntegerType.get_signless(32) + def kernel(ctx, sem, _): + my_device = ctx.device_id() + other_device = arith.subi(arith.constant(i32, 1), my_device) + my_sem = mgpu.SemaphoreRef(mgpu.utils.memref_ptr(sem)) + other_dst = ctx.to_remote(sem, other_device) + other_sem = mgpu.SemaphoreRef(mgpu.utils.memref_ptr(other_dst)) + # We signal and wait a different amount on each device to make sure we're + # really communicating here. + other_sem.signal(arith.addi(arith.constant(i32, 1), other_device)) + @mgpu.fori(arith.addi(arith.constant(i32, 1), my_device), None) + def wait_loop(i, _): + my_sem.wait(1) + + sem_shape = jax.ShapeDtypeStruct((1,), jnp.int32) + kernel = mgpu.as_torch_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), (), (), inout_shape=sem_shape + ) + gathered = torch.empty((2,), dtype=torch.int32) + sem = symm_mem.empty((1,), dtype=torch.int32) + sem_symm = symm_mem.rendezvous(sem, dist.group.WORLD) + (sem_again,) = kernel(sem) + self.assertEqual(sem_again.data_ptr(), sem.data_ptr()) + dist.all_gather_into_tensor(gathered, sem) + self.assertEqual(gathered.tolist(), [0, 0]) + + +if __name__ == "__main__": + jt_multiprocess.main() diff --git a/tests/mosaic/gpu_transform_inference_test.py b/tests/mosaic/gpu_transform_inference_test.py deleted file mode 100644 index b7cd146dfdb6..000000000000 --- a/tests/mosaic/gpu_transform_inference_test.py +++ /dev/null @@ -1,423 +0,0 @@ -# Copyright 2025 The JAX Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Transform inference tests for the Mosaic GPU MLIR dialect.""" - -# pylint: disable=g-complex-comprehension - -from absl.testing import parameterized -import jax -from jax import numpy as jnp -from jax._src import config -from jax._src import test_util as jtu -from jax._src.interpreters import mlir as mlir_interpreter -from jax._src.lib.mlir import ir -from jax._src.lib.mlir.dialects import arith -from jax._src.lib.mlir.dialects import func -from jax._src.lib.mlir.dialects import vector -import jax.experimental.mosaic.gpu as mgpu -from jax.experimental.mosaic.gpu import fragmented_array as fa -from jax.experimental.mosaic.gpu import inference_utils -from jax.experimental.mosaic.gpu import layouts as layouts_lib -import numpy as np - - -config.parse_flags_with_absl() - - -def _make_ir_context(): - context = ir.Context() - context.append_dialect_registry(mlir_interpreter.upstream_dialects) - context.load_all_available_dialects() - mgpu.dialect.register_dialect(context) - return context - - -class TransformInferenceTest(parameterized.TestCase): - - def setUp(self): - if jax.version._version != jax.lib.__version__: - raise self.skipTest("Test requires matching jax and jaxlib versions") - super().setUp() - self.enter_context(_make_ir_context()) - self.enter_context(ir.Location.unknown()) - self.module = ir.Module.create() - - @parameterized.parameters( - (swizzle, dtype) - for swizzle in mgpu.dialect.SwizzlingMode - for dtype in [jnp.bfloat16, jnp.float32] - ) - def test_infer_transforms_for_wgmma_op(self, swizzle, dtype): - swizzle_elems = swizzle // np.dtype(dtype).itemsize - m = 64 - # Note: `group_m` and `group_k` should be coprime with 2 for the test to be - # correct. Otherwise, we may infer larger swizzles than the test intends to - # check. - group_m, group_k = 3, 3 - lhs_shape = (group_m * m, group_k * swizzle_elems) - rhs_shape = (group_k * swizzle_elems, group_k * swizzle_elems) - out_shape = (group_m * m, group_k * swizzle_elems) - wgmma_op = None - - def body(accumulator, lhs, rhs): - nonlocal wgmma_op - wgmma_op = mgpu.dialect.WGMMAOp(accumulator, lhs, rhs) - - with ir.InsertionPoint(self.module.body): - smem = ir.Attribute.parse("#gpu.address_space") - elt_ty = mgpu.utils.dtype_to_ir_type(dtype) - lhs_ty = ir.MemRefType.get(lhs_shape, elt_ty, memory_space=smem) - rhs_ty = ir.MemRefType.get(rhs_shape, elt_ty, memory_space=smem) - acc_ty = ir.VectorType.get(out_shape, elt_ty) - func.FuncOp.from_py_func(acc_ty, lhs_ty, rhs_ty)(body) - - mgpu.infer_transforms(self.module) - - arg_transforms = ir.ArrayAttr.get([ - mgpu.dialect.TileTransformAttr.get((8, swizzle_elems)), - mgpu.dialect.SwizzleTransformAttr.get(int(swizzle)), - ]) - - self.assertSequenceEqual( - inference_utils.in_transforms(wgmma_op), - [arg_transforms, arg_transforms], - ) - self.assertEmpty(inference_utils.out_transforms(wgmma_op)) - - def test_infer_transforms_for_async_load_derives_from_destination(self): - async_load_op = None - shape = (64, 64) - elt_ty = ir.BF16Type.get() - - def body(gmem_ref, smem_ref, barrier): - nonlocal async_load_op - zero = arith.constant(ir.IntegerType.get_signless(32), 0) - async_load_op = mgpu.dialect.AsyncLoadOp( - source=gmem_ref, - destination=smem_ref, - barrier=barrier, - indices=[zero, zero], - slice_lengths=shape, - collective=ir.ArrayAttr.get([]), - ) - - with ir.InsertionPoint(self.module.body): - smem = ir.Attribute.parse("#gpu.address_space") - gmem_ty = ir.MemRefType.get(shape, elt_ty) - smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem) - barrier_ty = ir.Type.parse("!mosaic_gpu.barrier") - f = func.FuncOp.from_py_func(gmem_ty, smem_ty, barrier_ty)(body).func_op - - transforms = ir.ArrayAttr.get( - [mgpu.dialect.TransposeTransformAttr.get((1, 0))] - ) - f.attributes["in_transforms"] = ir.ArrayAttr.get([transforms]) - - mgpu.infer_transforms(self.module) - - self.assertSequenceEqual( - inference_utils.in_transforms(async_load_op), [transforms] - ) - self.assertEmpty(inference_utils.out_transforms(async_load_op)) - - def test_infer_transforms_for_async_store_op_derives_from_source(self): - async_store_op = None - shape = (64, 64) - elt_ty = ir.BF16Type.get() - - def body(gmem_ref, smem_ref): - nonlocal async_store_op - zero = arith.constant(ir.IntegerType.get_signless(32), 0) - async_store_op = mgpu.dialect.AsyncStoreOp( - source=smem_ref, - destination=gmem_ref, - indices=[zero, zero], - slice_lengths=shape, - ) - - with ir.InsertionPoint(self.module.body): - smem = ir.Attribute.parse("#gpu.address_space") - gmem_ty = ir.MemRefType.get(shape, elt_ty) - smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem) - f = func.FuncOp.from_py_func(gmem_ty, smem_ty)(body).func_op - - transforms = ir.ArrayAttr.get( - [mgpu.dialect.TransposeTransformAttr.get((1, 0))] - ) - f.attributes["in_transforms"] = ir.ArrayAttr.get([transforms]) - - mgpu.infer_transforms(self.module) - - self.assertSequenceEqual( - inference_utils.in_transforms(async_store_op), [transforms] - ) - self.assertEmpty(inference_utils.out_transforms(async_store_op)) - - def test_infer_transforms_for_vector_load_op_derives_from_destination(self): - vector_load_op = None - shape = (64, 64) - elt_ty = ir.BF16Type.get() - - def body(smem_ref): - nonlocal vector_load_op - zero = arith.constant(ir.IntegerType.get_signless(32), 0) - vector_load_op = vector.LoadOp( - ir.VectorType.get(shape, elt_ty), smem_ref, [zero] * len(shape) - ) - - with ir.InsertionPoint(self.module.body): - smem = ir.Attribute.parse("#gpu.address_space") - smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem) - func.FuncOp.from_py_func(smem_ty)(body) - - vector_load_op.attributes["out_layouts"] = ir.ArrayAttr.get( - [layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)] - ) - - mgpu.infer_transforms(self.module) - - expected_transforms = ir.ArrayAttr.get([ - mgpu.dialect.TileTransformAttr.get((8, 64)), - mgpu.dialect.SwizzleTransformAttr.get(128), - ]) - - self.assertSequenceEqual( - inference_utils.in_transforms(vector_load_op), [expected_transforms] - ) - self.assertEmpty(inference_utils.out_transforms(vector_load_op)) - - def test_infer_transforms_for_vector_load_op_derives_from_source(self): - vector_load_op = None - shape = (64, 64) - elt_ty = ir.BF16Type.get() - - def body(smem_ref): - nonlocal vector_load_op - zero = arith.constant(ir.IntegerType.get_signless(32), 0) - vector_load_op = vector.LoadOp( - ir.VectorType.get(shape, elt_ty), smem_ref, [zero] * len(shape) - ) - - with ir.InsertionPoint(self.module.body): - smem = ir.Attribute.parse("#gpu.address_space") - smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem) - f = func.FuncOp.from_py_func(smem_ty)(body).func_op - - vector_load_op.attributes["out_layouts"] = ir.ArrayAttr.get( - [layouts_lib.to_layout_attr(fa.WGStridedFragLayout(shape, vec_size=4))] - ) - transforms = ir.ArrayAttr.get([mgpu.dialect.TileTransformAttr.get((8, 64))]) - f.attributes["in_transforms"] = ir.ArrayAttr.get([transforms]) - - mgpu.infer_transforms(self.module) - - self.assertSequenceEqual( - inference_utils.in_transforms(vector_load_op), [transforms] - ) - self.assertEmpty(inference_utils.out_transforms(vector_load_op)) - - def test_infer_transforms_for_vector_load_op_raises_on_mismatches(self): - vector_load_op = None - shape = (64, 64) - elt_ty = ir.BF16Type.get() - - def body(smem_ref): - nonlocal vector_load_op - zero = arith.constant(ir.IntegerType.get_signless(32), 0) - vector_load_op = vector.LoadOp( - ir.VectorType.get(shape, elt_ty), smem_ref, [zero] * len(shape) - ) - - with ir.InsertionPoint(self.module.body): - smem = ir.Attribute.parse("#gpu.address_space") - smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem) - f = func.FuncOp.from_py_func(smem_ty)(body).func_op - - vector_load_op.attributes["out_layouts"] = ir.ArrayAttr.get( - [layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)] - ) - transforms = ir.ArrayAttr.get([mgpu.dialect.TileTransformAttr.get((8, 64))]) - f.attributes["in_transforms"] = ir.ArrayAttr.get([transforms]) - - with self.assertRaisesRegex(NotImplementedError, "Conflicting transforms"): - mgpu.infer_transforms(self.module) - - def test_infer_transforms_for_vector_store_op_derives_from_destination(self): - vector_store_op = None - shape = (64, 64) - elt_ty = ir.BF16Type.get() - - def body(smem_ref, value_to_store): - nonlocal vector_store_op - zero = arith.constant(ir.IntegerType.get_signless(32), 0) - vector_store_op = vector.StoreOp( - value_to_store, smem_ref, [zero] * len(shape) - ) - - with ir.InsertionPoint(self.module.body): - smem = ir.Attribute.parse("#gpu.address_space") - smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem) - value_ty = ir.VectorType.get(shape, elt_ty) - func.FuncOp.from_py_func(smem_ty, value_ty)(body) - - vector_store_op.attributes["in_layouts"] = ir.ArrayAttr.get( - [layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)] - ) - - mgpu.infer_transforms(self.module) - - expected_transforms = ir.ArrayAttr.get([ - mgpu.dialect.TileTransformAttr.get((8, 64)), - mgpu.dialect.SwizzleTransformAttr.get(128), - ]) - - self.assertSequenceEqual( - inference_utils.in_transforms(vector_store_op), [expected_transforms] - ) - self.assertEmpty(inference_utils.out_transforms(vector_store_op)) - - def test_infer_transforms_for_vector_store_op_derives_from_source(self): - vector_store_op = None - shape = (64, 64) - elt_ty = ir.BF16Type.get() - - def body(smem_ref, value_to_store): - nonlocal vector_store_op - zero = arith.constant(ir.IntegerType.get_signless(32), 0) - vector_store_op = vector.StoreOp( - value_to_store, smem_ref, [zero] * len(shape) - ) - - with ir.InsertionPoint(self.module.body): - smem = ir.Attribute.parse("#gpu.address_space") - smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem) - value_ty = ir.VectorType.get(shape, elt_ty) - f = func.FuncOp.from_py_func(smem_ty, value_ty)(body).func_op - - vector_store_op.attributes["in_layouts"] = ir.ArrayAttr.get( - [layouts_lib.to_layout_attr(fa.WGStridedFragLayout(shape, vec_size=4))] - ) - transforms = ir.ArrayAttr.get([mgpu.dialect.TileTransformAttr.get((8, 64))]) - f.attributes["in_transforms"] = ir.ArrayAttr.get([transforms]) - - mgpu.infer_transforms(self.module) - - self.assertSequenceEqual( - inference_utils.in_transforms(vector_store_op), [transforms] - ) - self.assertEmpty(inference_utils.out_transforms(vector_store_op)) - - def test_infer_transforms_for_vector_store_op_raises_on_mismatches(self): - vector_store_op = None - shape = (64, 64) - elt_ty = ir.BF16Type.get() - - def body(smem_ref, value_to_store): - nonlocal vector_store_op - zero = arith.constant(ir.IntegerType.get_signless(32), 0) - vector_store_op = vector.StoreOp( - value_to_store, smem_ref, [zero] * len(shape) - ) - - with ir.InsertionPoint(self.module.body): - smem = ir.Attribute.parse("#gpu.address_space") - smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem) - value_ty = ir.VectorType.get(shape, elt_ty) - f = func.FuncOp.from_py_func(smem_ty, value_ty)(body).func_op - - vector_store_op.attributes["in_layouts"] = ir.ArrayAttr.get( - [layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)] - ) - transforms = ir.ArrayAttr.get([mgpu.dialect.TileTransformAttr.get((8, 64))]) - f.attributes["in_transforms"] = ir.ArrayAttr.get([transforms]) - - with self.assertRaisesRegex(NotImplementedError, "Conflicting transforms"): - mgpu.infer_transforms(self.module) - - def test_infer_transforms_for_slice_smem_op_derives_from_user(self): - slice_smem_op = vector_load_op = None - shape = (64, 64) - elt_ty = ir.BF16Type.get() - smem = ir.Attribute.parse("#gpu.address_space") - - def body(offset): - nonlocal slice_smem_op, vector_load_op - slice_smem_op = mgpu.dialect.SliceSMEMOp( - ir.MemRefType.get(shape, elt_ty, memory_space=smem), offset - ) - zero = arith.constant(ir.IntegerType.get_signless(32), 0) - load_offsets = [zero] * len(shape) - vector_load_op = vector.LoadOp( - ir.VectorType.get(shape, elt_ty), slice_smem_op.result, load_offsets - ) - - with ir.InsertionPoint(self.module.body): - func.FuncOp.from_py_func(ir.IntegerType.get_signless(32))(body) - - vector_load_op.attributes["out_layouts"] = ir.ArrayAttr.get( - [layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)] - ) - - mgpu.infer_transforms(self.module) - - expected_transforms = ir.ArrayAttr.get([ - mgpu.dialect.TileTransformAttr.get((8, 64)), - mgpu.dialect.SwizzleTransformAttr.get(128), - ]) - - self.assertEmpty(inference_utils.in_transforms(slice_smem_op)) - self.assertSequenceEqual( - inference_utils.out_transforms(slice_smem_op), [expected_transforms] - ) - - def test_infer_transforms_for_slice_smem_op_raises_on_mismatches(self): - slice_smem_op = vector_load_op1 = vector_load_op2 = None - shape = (64, 64) - elt_ty = ir.BF16Type.get() - smem = ir.Attribute.parse("#gpu.address_space") - - def body(offset): - nonlocal slice_smem_op, vector_load_op1, vector_load_op2 - slice_smem_op = mgpu.dialect.SliceSMEMOp( - ir.MemRefType.get(shape, elt_ty, memory_space=smem), offset - ) - zero = arith.constant(ir.IntegerType.get_signless(32), 0) - load_offsets = [zero] * len(shape) - vector_load_op1 = vector.LoadOp( - ir.VectorType.get(shape, elt_ty), slice_smem_op.result, load_offsets - ) - vector_load_op2 = vector.LoadOp( - ir.VectorType.get(shape, elt_ty), slice_smem_op.result, load_offsets - ) - - with ir.InsertionPoint(self.module.body): - func.FuncOp.from_py_func(ir.IntegerType.get_signless(32))(body) - - vector_load_op1.attributes["out_layouts"] = ir.ArrayAttr.get( - [layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)] - ) - vector_load_op2.attributes["out_layouts"] = ir.ArrayAttr.get( - [layouts_lib.to_layout_attr(fa.WGStridedFragLayout(shape, vec_size=4))] - ) - vector_load_op2.attributes["in_transforms"] = ir.ArrayAttr.get( - [ir.ArrayAttr.get([mgpu.dialect.TransposeTransformAttr.get((1, 0))])] - ) - - with self.assertRaisesRegex(NotImplementedError, "Conflicting transforms"): - mgpu.infer_transforms(self.module) - - -if __name__ == "__main__": - parameterized.absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/mosaic/matmul_test.py b/tests/mosaic/matmul_test.py index d598d7d0c0ec..e609c9959741 100644 --- a/tests/mosaic/matmul_test.py +++ b/tests/mosaic/matmul_test.py @@ -15,12 +15,19 @@ """Test different parameterizations of a matmul.""" import os -import unittest from absl.testing import absltest, parameterized from jax._src import config from jax._src import test_util as jtu +from jax._src.interpreters import mlir +from jax._src.lib.mlir import ir +from jax.experimental.mosaic.gpu import dialect as mgpu_dialect # pylint: disable=g-importing-member import jax.numpy as jnp +import numpy as np + +import hypothesis as hp +import hypothesis.strategies as hps + try: # We only import this to see if Mosaic is available. import jax.experimental.mosaic.gpu # noqa: F401 @@ -28,11 +35,7 @@ matmul = None else: from jax.experimental.mosaic.gpu.examples import matmul -try: - import hypothesis as hp - import hypothesis.strategies as hps -except (ModuleNotFoundError, ImportError): - raise unittest.SkipTest("these tests require hypothesis") + from jax.experimental.mosaic.gpu.examples import matmul_blackwell config.parse_flags_with_absl() @@ -48,15 +51,20 @@ def wrapper(self, seed): @jtu.with_config(jax_traceback_filtering="off") +@jtu.thread_unsafe_test_class() # hypothesis is not thread safe class MatmulTestCase(jtu.JaxTestCase): def setUp(self): super().setUp() if matmul is None: self.skipTest("Mosaic GPU not available.") - if (not jtu.test_device_matches(["cuda"]) or - not jtu.is_cuda_compute_capability_equal("9.0")): - self.skipTest("Only works on GPU with capability sm90a") + if not jtu.test_device_matches(["cuda"]): + self.skipTest("Test needs a GPU device") + self.context = mlir.make_ir_context() + mgpu_dialect.register_dialect(self.context) + self.enter_context(config.traceback_filtering("off")) + self.enter_context(self.context) + self.enter_context(ir.Location.unknown()) @parameterized.named_parameters( (f"_shard{i}", i) for i in range(5) @@ -64,7 +72,10 @@ def setUp(self): @seed_hypothesis @hp.settings(max_examples=100) # Add verbosity=hp.Verbosity.verbose to debug @hp.given(hps.data()) - def test_matmul(self, data): + def test_matmul_sm90(self, data): + if not jtu.is_cuda_compute_capability_equal("9.0"): + self.skipTest("Only works on GPU with capability sm90a") + in_dtype = data.draw( hps.sampled_from([jnp.float16, jnp.bfloat16, jnp.float32]), label="in_dtype", @@ -123,6 +134,73 @@ def test_matmul(self, data): hp.assume(False) raise e + @parameterized.named_parameters( + # TODO(apaszke): Increase shard count once we have more B200s in CI. + (f"_shard{i}", i) for i in range(1) + ) + @seed_hypothesis + @hp.settings(max_examples=100) # Add verbosity=hp.Verbosity.verbose to debug + @hp.given(hps.data()) + def test_matmul_sm100(self, data): + if not jtu.is_cuda_compute_capability_equal("10.0"): + self.skipTest("Only works on GPU with capability sm100a") + + dtype = data.draw( + hps.sampled_from([jnp.float16, jnp.bfloat16]), + label="dtype", + ) + m, n, k = ( + data.draw(hps.sampled_from([128, 256, 512, 2048, 8192]), label=d) for d in "mnk" + ) + max_concurrent_steps = data.draw( + hps.integers(2, 5), label="max_concurrent_steps" + ) + collective = data.draw(hps.booleans(), label="collective") + num_ctas = 2 if collective else 1 + hp.assume(not (m == 128 and collective)) # Too small for collective MMA. + tile_m = data.draw( + hps.sampled_from([t for t in [128] if t * num_ctas <= m]), label="tile_m" + ) + tmem_cols = 512 + tile_n = data.draw( + hps.sampled_from([ + t + for t in [64, 128, 256] + # We're double buffering TMEM in the kernel, hence the 2x. + if t * num_ctas <= n and 2 * t * num_ctas <= tmem_cols + ]), + label="tile_n", + ) + grid_m = m // (num_ctas * tile_m) + grid_tile_m = data.draw(hps.sampled_from([1, 2, 4, 8, 16]), label="grid_tile_m") + hp.assume(grid_m % grid_tile_m == 0) + + try: + kernel = matmul_blackwell.build_kernel( + m, + k, + n, + dtype=dtype, + tile_m=tile_m, + tile_n=tile_n, + grid_tile_m=grid_tile_m, + max_concurrent_steps=max_concurrent_steps, + collective=collective, + ) + except ValueError as e: + if "Mosaic GPU kernel exceeds available shared memory" in str(e): + hp.assume(False) + raise + + ka, kb = jax.random.split(jax.random.key(0), 2) + a = jax.random.normal(key=ka, shape=(m, k), dtype=dtype) + b = jax.random.normal(key=kb, shape=(n, k), dtype=dtype) + out = kernel(a, b) + out_ref = jnp.dot(a, b.T) + np.testing.assert_allclose( + out, out_ref, atol=2e-3, rtol=1e-2 + ) + if __name__ == "__main__": - absltest.main(testLoader=jtu.JaxTestLoader()) + absltest.main(argv=["python"], testLoader=jtu.JaxTestLoader()) diff --git a/tests/mosaic/profiler_cupti_test.py b/tests/mosaic/profiler_cupti_test.py index f3dcf71cab14..9b7eb9bd9036 100644 --- a/tests/mosaic/profiler_cupti_test.py +++ b/tests/mosaic/profiler_cupti_test.py @@ -14,7 +14,6 @@ # ============================================================================== from absl.testing import absltest, parameterized -import jax from jax._src import config from jax._src import test_util as jtu import jax.numpy as jnp @@ -43,11 +42,11 @@ def setUp(self): self.f = lambda x: 2*x def test_measure_cupti_explicit(self): - _, runtime_ms = profiler.measure(self.f, mode="cupti")(self.x) + _, runtime_ms = profiler.measure(self.f)(self.x) self.assertIsInstance(runtime_ms, float) def test_measure_per_kernel(self): - _, runtimes_ms = profiler.measure(self.f, mode="cupti", aggregate=False)(self.x) + _, runtimes_ms = profiler.measure(self.f, aggregate=False)(self.x) for item in runtimes_ms: self.assertIsInstance(item, tuple) self.assertEqual(len(item), 2) @@ -56,7 +55,7 @@ def test_measure_per_kernel(self): self.assertIsInstance(runtime_ms, float) def test_measure_cupti_repeated(self): - f_profiled = profiler.measure(self.f, mode="cupti") + f_profiled = profiler.measure(self.f) n = 3 timings = [f_profiled(self.x)[1] for _ in range(n)] for item in timings: @@ -64,12 +63,30 @@ def test_measure_cupti_repeated(self): def test_measure_repeated_interleaved(self): # test that kernels run outside of measure() are not captured - _, timings = profiler.measure(self.f, mode="cupti", aggregate=False)(self.x) + _, timings = profiler.measure(self.f, aggregate=False)(self.x) self.assertEqual(len(timings), 1) self.f(self.x) - _, timings = profiler.measure(self.f, mode="cupti", aggregate=False)(self.x) + _, timings = profiler.measure(self.f, aggregate=False)(self.x) self.assertEqual(len(timings), 1) + def test_iterations(self): + _, timings = profiler.measure( + self.f, aggregate=False, iterations=10 + )(self.x) + self.assertEqual(len(timings), 10) + self.assertTrue( + all( + isinstance(n, str) and isinstance(t, float) + for iter_timings in timings + for n, t in iter_timings + ) + ) + _, timings = profiler.measure( + self.f, aggregate=True, iterations=5 + )(self.x) + self.assertEqual(len(timings), 5) + self.assertTrue(all(isinstance(t, float) for t in timings)) + def test_measure_double_subscription(self): # This needs to run in a separate process, otherwise it affects the # outcomes of other tests since CUPTI state is global. diff --git a/tests/multi_device_test.py b/tests/multi_device_test.py index 38a37844ebf8..7053a8e5e29a 100644 --- a/tests/multi_device_test.py +++ b/tests/multi_device_test.py @@ -102,7 +102,7 @@ def test_computation_follows_data(self): self.assert_uncommitted_to_device(z3, devices[0]) - # A jitted computation with an device specification behaves as if the + # A jitted computation with a device specification behaves as if the # arguments are first device_put to the specified device. The result # will be committed on the specified. # The `device` parameter is experimental, and subject to change. @@ -304,7 +304,10 @@ def test_lax_full_like_efficient(self): self.skipTest('Only can run test on device with mem_stats') mesh = Mesh(devices, axis_names=("i")) sharding = NamedSharding(mesh, P('i')) - available_memory = mem_stats['bytes_reservable_limit'] + if jtu.is_device_rocm(): + available_memory = mem_stats['bytes_limit'] + else: + available_memory = mem_stats['bytes_reservable_limit'] array_size = available_memory // (6 * len(devices)) * len(devices) # Set up tracemalloc to track memory usage. tm.start() diff --git a/tests/multibackend_test.py b/tests/multibackend_test.py index 4697ba8b2858..74d4be77cc99 100644 --- a/tests/multibackend_test.py +++ b/tests/multibackend_test.py @@ -13,8 +13,6 @@ # limitations under the License. -from functools import partial - from absl.testing import absltest import numpy as np @@ -39,7 +37,7 @@ class MultiBackendTest(jtu.JaxTestCase): def testMultiBackend(self, backend): if backend not in ('cpu', jtu.device_under_test(), None): raise SkipTest("Backend is not CPU or the device under test") - @partial(jax.jit, backend=backend) + @jax.jit(backend=backend) def fun(x, y): return jnp.matmul(x, y) @@ -60,10 +58,10 @@ def testMultiBackendNestedJit(self, ordering): outer, inner = ordering if outer not in ('cpu', jtu.device_under_test(), None): raise SkipTest("Backend is not CPU or the device under test") - @partial(jax.jit, backend=outer) + @jax.jit(backend=outer) def fun(x, y): - @partial(jax.jit, backend=inner) + @jax.jit(backend=inner) def infun(x, y): return jnp.matmul(x, y) @@ -97,10 +95,10 @@ def testMultiBackendNestedJitConflict(self, ordering): "the entire computation. So if inner is CPU and outer is " "None, then the computation will be execute on CPU.") - @partial(jax.jit, backend=outer) + @jax.jit(backend=outer) def fun(x, y): - @partial(jax.jit, backend=inner) + @jax.jit(backend=inner) def infun(x, y): return jnp.matmul(x, y) @@ -116,7 +114,7 @@ def infun(x, y): def testGpuMultiBackendOpByOpReturn(self, backend): if backend not in ('cpu', jtu.device_under_test()): raise SkipTest("Backend is not CPU or the device under test") - @partial(jax.jit, backend=backend) + @jax.jit(backend=backend) def fun(x, y): return jnp.matmul(x, y) x = npr.uniform(size=(10,10)) @@ -130,7 +128,7 @@ def fun(x, y): @jtu.ignore_warning(category=DeprecationWarning, message="backend and device argument") def testJitCpu(self): - @partial(jax.jit, backend='cpu') + @jax.jit(backend='cpu') def get_arr(scale): return scale + jnp.ones((2, 2)) diff --git a/tests/multiprocess/BUILD b/tests/multiprocess/BUILD new file mode 100644 index 000000000000..ed6523ce10e9 --- /dev/null +++ b/tests/multiprocess/BUILD @@ -0,0 +1,219 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load( + "//jaxlib:jax.bzl", + "if_oss", + "jax_multiprocess_generate_backend_suites", + "jax_multiprocess_test", +) + +licenses(["notice"]) + +package( + default_applicable_licenses = [], +) + +jax_multiprocess_generate_backend_suites() + +jax_multiprocess_test( + name = "all_reduce_test", + srcs = ["all_reduce_test.py"], + backend_tags = { + "gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143 + }, + enable_configs = ["cpu_megascale"], + main = "all_reduce_test.py", + deps = ["//jax/_src:test_multiprocess"], +) + +jax_multiprocess_test( + name = "all_gather_test", + srcs = ["all_gather_test.py"], + backend_tags = { + "gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143 + }, + enable_configs = ["cpu_megascale"], + main = "all_gather_test.py", + deps = ["//jax/_src:test_multiprocess"], +) + +jax_multiprocess_test( + name = "all_to_all_test", + srcs = ["all_to_all_test.py"], + backend_tags = { + "gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143 + }, + enable_configs = ["cpu_megascale"], + main = "all_to_all_test.py", + deps = ["//jax/_src:test_multiprocess"], +) + +jax_multiprocess_test( + name = "axis_index_test", + srcs = ["axis_index_test.py"], + main = "axis_index_test.py", + deps = [ + "//jax/_src:test_multiprocess", + ], +) + +jax_multiprocess_test( + name = "array_test", + srcs = ["array_test.py"], + main = "array_test.py", + deps = [ + "//jax/_src:test_multiprocess", + ], +) + +jax_multiprocess_test( + name = "colocated_python_test", + srcs = ["colocated_python_test.py"], + disable_configs = [ + # This config has two cores per chip, and JAX distributed does not get + # the correct number of logical devices per host. + "tpu_v3_x4", + ], + main = "colocated_python_test.py", + deps = [ + "//jax/_src:test_multiprocess", + "//jax/experimental:colocated_python", + ], +) + +jax_multiprocess_test( + name = "device_id_test", + srcs = ["device_id_test.py"], + main = "device_id_test.py", + deps = [ + "//jax/_src:test_multiprocess", + ], +) + +jax_multiprocess_test( + name = "host_callback_test", + srcs = ["host_callback_test.py"], + main = "host_callback_test.py", + deps = [ + "//jax/_src:test_multiprocess", + "//jax/experimental", + "//jax/experimental:multihost_utils", + ], +) + +jax_multiprocess_test( + name = "key_value_store_test", + srcs = ["key_value_store_test.py"], + main = "key_value_store_test.py", + deps = [ + "//jax/_src:test_multiprocess", + ], +) + +jax_multiprocess_test( + name = "multihost_utils_test", + srcs = ["multihost_utils_test.py"], + backend_tags = { + "gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143 + }, + enable_backends = if_oss( + [ + "cpu", + "tpu", + ], + None, + ), # b/453057226 + main = "multihost_utils_test.py", + deps = [ + "//jax/_src:test_multiprocess", + ], +) + +jax_multiprocess_test( + name = "pjit_test", + srcs = ["pjit_test.py"], + main = "pjit_test.py", + deps = [ + "//jax/_src:test_multiprocess", + ], +) + +jax_multiprocess_test( + name = "pgle_test", + srcs = ["pgle_test.py"], + backend_tags = { + "gpu": ["noasan"], # Memory leaks in NCCL. + }, + enable_backends = ["gpu"], + main = "pgle_test.py", + deps = [ + "//jax/_src:test_multiprocess", + ], +) + +jax_multiprocess_test( + name = "pmap_test", + srcs = ["pmap_test.py"], + backend_tags = { + "gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143 + }, + main = "pmap_test.py", + deps = [ + "//jax/_src:test_multiprocess", + ], +) + +jax_multiprocess_test( + name = "socket_transfer_test", + srcs = ["socket_transfer_test.py"], + # TODO(b/476419684): Remove the following line once OSS tests are fixed. + enable_backends = if_oss([], None), + main = "socket_transfer_test.py", + deps = [ + "//jax/_src:test_multiprocess", + ], +) + +jax_multiprocess_test( + name = "thread_guard_test", + srcs = ["thread_guard_test.py"], + main = "thread_guard_test.py", + deps = [ + "//jax/_src:test_multiprocess", + ], +) + +jax_multiprocess_test( + name = "tpu_device_test", + srcs = ["tpu_device_test.py"], + enable_backends = ["tpu"], + main = "tpu_device_test.py", + deps = [ + "//jax/_src:test_multiprocess", + ], +) + +jax_multiprocess_test( + name = "wait_barrier_test", + srcs = ["wait_barrier_test.py"], + enable_backends = [ + "cpu", + "gpu", + ], + main = "wait_barrier_test.py", + deps = [ + "//jax/_src:test_multiprocess", + ], +) diff --git a/tests/multiprocess/all_gather_test.py b/tests/multiprocess/all_gather_test.py new file mode 100644 index 000000000000..0a5a39576db8 --- /dev/null +++ b/tests/multiprocess/all_gather_test.py @@ -0,0 +1,54 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import parameterized +import jax +from jax import lax +from jax._src import test_multiprocess as jt_multiprocess +from jax._src import test_util as jtu +import jax.numpy as jnp +import numpy as np + + +class AllGatherTest(jt_multiprocess.MultiProcessTest): + + @parameterized.parameters( + (np.int32,), (jnp.float32,), (jnp.float16,), (jnp.bfloat16,) + ) + def test_all_gather_shard_map(self, dtype): + mesh_shape = (jax.process_count(), jax.local_device_count()) + mesh = jtu.create_mesh(mesh_shape, ("x", "y")) + spec = jax.P("x", "y") + + @jax.shard_map( + mesh=mesh, in_specs=spec, out_specs=jax.P(None, None), check_vma=False + ) + def f(x): + out = lax.all_gather(x, "x", axis=0, tiled=True) + return lax.all_gather(out, "y", axis=1, tiled=True) + + global_len = np.prod(mesh_shape) + global_arr = jnp.arange(global_len, dtype=dtype).reshape(mesh_shape) + sharding = jax.NamedSharding(mesh, spec) + global_xs = jax.make_array_from_callback( + mesh_shape, sharding, lambda index: global_arr[index] + ) + + out = f(global_xs) + for actual in out.addressable_shards: + jtu.check_close(actual.data, global_arr[actual.index]) + + +if __name__ == "__main__": + jt_multiprocess.main() diff --git a/tests/multiprocess/all_reduce_test.py b/tests/multiprocess/all_reduce_test.py new file mode 100644 index 000000000000..0276baf62f80 --- /dev/null +++ b/tests/multiprocess/all_reduce_test.py @@ -0,0 +1,150 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import parameterized +import jax +from jax import lax +from jax import numpy as jnp +from jax._src import test_multiprocess as jt_multiprocess +from jax._src import test_util as jtu +import numpy as np + + +def randint_sample(shape): + return jax.random.randint(jax.random.PRNGKey(42), shape, -100, 100) + + +class AllReduceTest(jt_multiprocess.MultiProcessTest): + + def test_psum_simple(self): + mesh = jtu.create_mesh((jax.device_count(),), "x") + spec = jax.P("x") + + @jax.shard_map(mesh=mesh, in_specs=spec, out_specs=spec) + def f(x): + return lax.psum(x, "x") + + out = f(jnp.array([1] * jax.device_count())) + + for o in out.addressable_shards: + self.assertEqual(o.data, np.array([jax.device_count()])) + + @parameterized.parameters( + (np.int32,), (jnp.float32,), (jnp.float16,), (jnp.bfloat16,) + ) + def test_psum(self, dtype): + mesh_shape = (jax.process_count(), jax.local_device_count()) + mesh = jtu.create_mesh(mesh_shape, ("x", "y")) + spec = jax.P("x", "y") + + @jax.shard_map(mesh=mesh, in_specs=spec, out_specs=spec) + def f(x): + return lax.psum(x, ("x", "y")) + + xs = ( + jnp.arange(jax.local_device_count()) + + jax.process_index() * jax.local_device_count() + ) + xs = jnp.expand_dims(xs, axis=0).astype(dtype) + sharding = jax.NamedSharding(mesh, spec) + global_xs = jax.make_array_from_process_local_data(sharding, xs, mesh_shape) + local_xs = jnp.sum(jnp.arange(jax.device_count())).reshape(1, 1) + out = f(global_xs) + for actual in out.addressable_shards: + jtu.check_close(actual.data, local_xs) + + def test_psum_subset_devices(self): + mesh_shape = (jax.process_count(), jax.local_device_count()) + mesh = jtu.create_mesh(mesh_shape, ("x", "y")) + spec = jax.P("x") + + @jax.shard_map(mesh=mesh, in_specs=spec, out_specs=spec) + def f(x): + return lax.psum(x, "x") + + xs = ( + jnp.arange(jax.local_device_count()) + + jax.process_index() * jax.local_device_count() + ) + xs = jnp.expand_dims(xs, axis=0) + sharding = jax.NamedSharding(mesh, spec) + global_xs = jax.make_array_from_process_local_data(sharding, xs, mesh_shape) + local_xs = ( + jnp.arange(jax.device_count()) + .reshape(mesh_shape) + .sum(axis=0, keepdims=True) + ) + out = f(global_xs) + for actual in out.addressable_shards: + jtu.check_close(actual.data, local_xs) + + def test_psum_multiple_operands(self): + mesh_shape = (jax.process_count(), jax.local_device_count()) + mesh = jtu.create_mesh(mesh_shape, ("x", "y")) + spec = jax.P("x", "y") + sharding = jax.NamedSharding(mesh, spec) + x = ( + jnp.arange(jax.local_device_count()) + + jax.process_index() * jax.local_device_count() + ) + x = jnp.expand_dims(x, axis=(0, -1)) + + @jax.shard_map(mesh=mesh, in_specs=spec, out_specs=spec) + def f(x): + return lax.psum(x, ("x", "y")) + + length = 100 + xs = jnp.tile(x, (1, 1, length)) + global_shape = mesh_shape + (length,) + global_xs = jax.make_array_from_process_local_data(sharding, xs, global_shape) + local_xs = jnp.sum(jnp.arange(jax.device_count())) * jnp.ones((1, 1, length)) + out = f(global_xs) + for actual in out.addressable_shards: + jtu.check_close(actual.data, local_xs) + + length = 200 + xs = jnp.tile(x, (1, 1, length)) + global_shape = mesh_shape + (length,) + global_xs = jax.make_array_from_process_local_data(sharding, xs, global_shape) + local_xs = jnp.sum(jnp.arange(jax.device_count())) * jnp.ones((1, 1, length)) + out = f(global_xs) + for actual in out.addressable_shards: + jtu.check_close(actual.data, local_xs) + + # TODO(dsuo): Remove this warning once PmapSharding is removed. We don't + # convert this to shard_map since axis_index_groups raises a + # NotImplementedError. + @jtu.ignore_warning(category=DeprecationWarning) + def test_psum_axis_index_groups(self): + devices = list(range(jax.device_count())) + axis_index_groups = [devices[0::2], devices[1::2]] + print(axis_index_groups, jax.devices()) + f = jax.pmap( + lambda x: lax.psum(x, "i", axis_index_groups=axis_index_groups), + axis_name="i", + ) + xs = randint_sample([jax.process_count(), jax.local_device_count(), 100]) + out = f(xs[jax.process_index()]) + + xs = xs.reshape([jax.device_count(), 100]) + group0_expected = sum(xs[0::2, :]) + group1_expected = sum(xs[1::2, :]) + for i, actual in enumerate(out): + device_id = i + jax.process_index() * jax.local_device_count() + expected = group0_expected if device_id % 2 == 0 else group1_expected + np.testing.assert_array_equal(actual, expected) + + +if __name__ == "__main__": + jt_multiprocess.main() diff --git a/tests/multiprocess/all_to_all_test.py b/tests/multiprocess/all_to_all_test.py new file mode 100644 index 000000000000..167ae5971bd2 --- /dev/null +++ b/tests/multiprocess/all_to_all_test.py @@ -0,0 +1,77 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import parameterized +import jax +from jax import lax +from jax._src import test_multiprocess as jt_multiprocess +from jax._src import test_util as jtu +import jax.numpy as jnp +import numpy as np + + +class AllToAllTest(jt_multiprocess.MultiProcessTest): + + @parameterized.parameters( + (np.int32,), (jnp.float32,), (jnp.float16,), (jnp.bfloat16,) + ) + def test_all_to_all_shard_map(self, dtype): + rng = np.random.RandomState(42) + devices = jax.devices() + mesh = jax.sharding.Mesh(devices, ("i",)) + device_to_index = {d: i for i, d in enumerate(devices)} + + @jax.shard_map( + mesh=mesh, + in_specs=jax.P("i", None, None), + out_specs=jax.P("i", None, None), + ) + def f(x): + x = jnp.squeeze(x, 0) + out = lax.all_to_all(x, "i", split_axis=0, concat_axis=0) + return jnp.expand_dims(out, 0) + + shape = [ + jax.process_count(), + jax.local_device_count(), + jax.device_count(), + 100, + ] + + if jnp.issubdtype(dtype, jnp.floating): + xs = rng.randn(*shape).astype(dtype) + else: + xs = rng.randint(0, 100, size=shape).astype(dtype) + + global_shape = (jax.device_count(), jax.device_count(), 100) + sharding = jax.NamedSharding(mesh, jax.P("i", None, None)) + local_data = xs[jax.process_index()] + global_xs = jax.make_array_from_process_local_data( + sharding, local_data, global_shape + ) + + global_out = f(global_xs) + + local_shards = global_out.addressable_shards + local_shards = sorted(local_shards, key=lambda s: device_to_index[s.device]) + + for shard in local_shards: + rank = device_to_index[shard.device] + actual = np.array(shard.data).squeeze(0) # (D, 100) + expected = np.reshape(xs[:, :, rank, :], [jax.device_count(), 100]) + jtu.check_close(actual, expected) + + +if __name__ == "__main__": + jt_multiprocess.main() diff --git a/tests/multiprocess/array_test.py b/tests/multiprocess/array_test.py new file mode 100644 index 000000000000..2e6e456bb930 --- /dev/null +++ b/tests/multiprocess/array_test.py @@ -0,0 +1,1080 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multihost tests for jax.Array.""" + +import math +import unittest + +from absl.testing import parameterized +import jax +from jax._src import array +from jax._src import sharding_impls +from jax._src import test_multiprocess as jt_multiprocess +from jax._src import test_util as jtu +import jax.numpy as jnp +from jax.sharding import PartitionSpec as P +import numpy as np + + +def create_array(shape, arr_sharding, global_data=None): + if global_data is None: + global_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) + + return array.make_array_from_callback( + shape, arr_sharding, lambda idx: global_data[idx], + dtype=global_data.dtype), global_data + + +def create_nonaddressable_array(shape, spec=None): + """Creates an array that is non-addressable in half of the processes. + + Args: + shape: Shape of the array. + spec: Sharding spec of the array. If None, the array is sharded over all + participating devices. + + Returns: + A tuple of the created array and the global data. + """ + n_dev = len(jax.devices()) // 2 + mesh = jax.make_mesh((n_dev,), ("x",), devices=jax.devices()[:n_dev], + axis_types=(jax.sharding.AxisType.Explicit,)) + if spec is None: + spec = P("x") + s = jax.sharding.NamedSharding(mesh, spec) + return create_array(shape, s) + + +class ArrayTestMultiHost(jt_multiprocess.MultiProcessTest): + + @parameterized.named_parameters( + ( + "mesh_x_y", + P("x", "y"), + # device_id -> (index, replica_id) + { + 0: ((slice(0, 2), slice(0, 1)), 0), + 1: ((slice(0, 2), slice(1, 2)), 0), + 2: ((slice(2, 4), slice(0, 1)), 0), + 3: ((slice(2, 4), slice(1, 2)), 0), + 4: ((slice(4, 6), slice(0, 1)), 0), + 5: ((slice(4, 6), slice(1, 2)), 0), + 6: ((slice(6, 8), slice(0, 1)), 0), + 7: ((slice(6, 8), slice(1, 2)), 0), + }, + (2, 1), + False, + False, + ), + ( + "mesh_x", + P("x"), + # device_id -> (index, replica_id) + { + 0: ((slice(0, 2), slice(None)), 0), + 1: ((slice(0, 2), slice(None)), 1), + 2: ((slice(2, 4), slice(None)), 0), + 3: ((slice(2, 4), slice(None)), 1), + 4: ((slice(4, 6), slice(None)), 0), + 5: ((slice(4, 6), slice(None)), 1), + 6: ((slice(6, 8), slice(None)), 0), + 7: ((slice(6, 8), slice(None)), 1), + }, + (2, 2), + False, + False, + ), + ( + "mesh_y", + P("y"), + # device_id -> (index, replica_id) + { + 0: ((slice(0, 4), slice(None)), 0), + 1: ((slice(4, 8), slice(None)), 0), + 2: ((slice(0, 4), slice(None)), 1), + 3: ((slice(4, 8), slice(None)), 1), + 4: ((slice(0, 4), slice(None)), 2), + 5: ((slice(4, 8), slice(None)), 2), + 6: ((slice(0, 4), slice(None)), 3), + 7: ((slice(4, 8), slice(None)), 3), + }, + (4, 2), + False, + True, + ), + ( + "mesh_xy", + P(("x", "y")), + # device_id -> (index, replica_id) + { + 0: ((slice(0, 1), slice(None)), 0), + 1: ((slice(1, 2), slice(None)), 0), + 2: ((slice(2, 3), slice(None)), 0), + 3: ((slice(3, 4), slice(None)), 0), + 4: ((slice(4, 5), slice(None)), 0), + 5: ((slice(5, 6), slice(None)), 0), + 6: ((slice(6, 7), slice(None)), 0), + 7: ((slice(7, 8), slice(None)), 0), + }, + (1, 2), + False, + False, + ), + ( + "mesh_fully_replicated", + P(), + # device_id -> (index, replica_id) + { + 0: ((slice(None), slice(None)), 0), + 1: ((slice(None), slice(None)), 1), + 2: ((slice(None), slice(None)), 2), + 3: ((slice(None), slice(None)), 3), + 4: ((slice(None), slice(None)), 4), + 5: ((slice(None), slice(None)), 5), + 6: ((slice(None), slice(None)), 6), + 7: ((slice(None), slice(None)), 7), + }, + (8, 2), + True, + True, + ), + ) + # Test does not work with non-contiguous device IDs. + @jtu.skip_on_devices("cpu") + def test_array_2d_shard(self, pspec, expected_idx_rid, + expected_shard_shape, expected_is_fully_replicated, + fetch_to_host): + if jtu.is_device_tpu("5", "e"): + raise unittest.SkipTest("Test fails on v5e") + global_mesh = jtu.create_mesh((4, 2), ("x", "y"), iota_order=True) + global_input_shape = (8, 2) + arr, global_input_data = create_array( + global_input_shape, jax.sharding.NamedSharding(global_mesh, pspec)) + + self.assertEqual(arr.is_fully_replicated, expected_is_fully_replicated) + + for s in arr.addressable_shards: + sd = s.device.id + expected_index = expected_idx_rid[sd][0] + expected_replica_id = expected_idx_rid[sd][1] + self.assertEqual(s.index, expected_index) + self.assertEqual(s.replica_id, expected_replica_id) + self.assertEqual(s.data.shape, expected_shard_shape) + np.testing.assert_array_equal(np.asarray(s.data), + global_input_data[expected_index]) + + for s in arr.global_shards: + sd = s.device.id + expected_index = expected_idx_rid[sd][0] + expected_replica_id = expected_idx_rid[sd][1] + self.assertEqual(s.index, expected_index) + self.assertEqual(s.replica_id, expected_replica_id) + if s.data is not None: + self.assertEqual(s.data.shape, expected_shard_shape) + np.testing.assert_array_equal(np.asarray(s.data), + global_input_data[expected_index]) + + if fetch_to_host: + np.testing.assert_array_equal(arr._value, global_input_data) + else: + with self.assertRaisesRegex( + RuntimeError, + r"Fetching value for `jax.Array` that spans non-addressable \(non" + r" process local\) devices is not possible", + ): + _ = arr._value + + @parameterized.named_parameters( + ("mesh_x_y_z", P("x", "y", "z"), (4, 2, 1), False), + ("mesh_xy_z", P(("x", "y"), "z"), (2, 2, 2), False), + ("mesh_z", P("z"), (4, 4, 2), True), + ("mesh_None_z", P(None, None, "z"), (8, 4, 1), True), + ) + def test_array_3d_shard(self, pspec, expected_shard_shape, fetch_to_host): + if jtu.is_device_tpu("5", "e"): + raise unittest.SkipTest("Test fails on v5e") + global_mesh = jtu.create_mesh((2, 2, 2), ("x", "y", "z")) + global_input_shape = (8, 4, 2) + arr, global_input_data = create_array( + global_input_shape, jax.sharding.NamedSharding(global_mesh, pspec)) + + self.assertEqual(arr.ndim, 3) + self.assertEqual(arr.size, 64) + self.assertEqual(arr.addressable_data(0).shape, expected_shard_shape) + if fetch_to_host: + np.testing.assert_array_equal(arr._value, global_input_data) + else: + with self.assertRaisesRegex( + RuntimeError, + r"Fetching value for `jax.Array` that spans non-addressable \(non" + r" process local\) devices is not possible", + ): + _ = arr._value + + def test_sharded_zeros_like(self): + if jtu.is_device_tpu("5", "e"): + raise unittest.SkipTest("Test fails on v5e") + global_mesh = jtu.create_mesh((4, 2), ("x", "y")) + input_shape = (8, 2) + a, input_data = create_array( + input_shape, jax.sharding.NamedSharding(global_mesh, P("x", "y"))) + out = jnp.zeros_like(a) + expected = np.zeros_like(input_data) + self.assertLen(out.addressable_shards, 2) + for i in out.addressable_shards: + np.testing.assert_array_equal(i.data, expected[i.index]) + self.assertEqual(i.replica_id, 0) + self.assertEqual(i.data.shape, (2, 1)) + + @parameterized.product( + spec=[ + (("a", "b", "c"),), + (("a", "b"), "c"), + (("a", "b"),), + (("a",),), + (("b",),), + (("c",),), + ((),), + ], + infer_shape=[True, False], + ) + def test_make_array_from_process_data(self, spec, infer_shape): + mesh = jtu.create_mesh((2, 2, 2), ("a", "b", "c"), iota_order=True) + # Key: number of processes. Value: axes corresponding to hosts. + host_axes_dict = {1: (), 2: ("a",), 4: ("a", "b"), 8: ("a", "b", "c")} + + host_axes = set(host_axes_dict[jax.process_count()]) + axis0_spec = (spec[0],) if isinstance(spec[0], str) else spec[0] + expected_process_shards = 2 ** len(host_axes.intersection(axis0_spec)) + sharding = jax.sharding.NamedSharding(mesh, P(*spec)) + replicated = jax.sharding.NamedSharding(mesh, P(None)) + num_indices0 = sharding_impls.num_addressable_indices(sharding, 0, (8, 4)) + num_indices1 = sharding_impls.num_addressable_indices(sharding, 1, (8, 4)) + global_shape = None if infer_shape else (8, 4) + if infer_shape and num_indices1 < 4: + # If 2nd dimension is sharded across hosts (as it is on v5e 4x2) + # it would affect the computed global_shape, for test's simplicity we + # set explicit global_shape global shape to be 2. + global_shape = (8, 4) + + process_index, num_shards = sharding_impls.get_process_index_and_count( + sharding, + 0, + ndims=2, + ) + self.assertEqual(num_shards, expected_process_shards) + process_data = np.arange(4)[None, :] + 4 * process_index + b = np.broadcast_to(process_data, (num_indices0, 4)) + r = jax.make_array_from_process_local_data(sharding, b, global_shape) + self.assertEqual(r.shape, (8, 4)) + self.assertEqual(r.sharding, sharding) + r = np.array(jax.jit(lambda x: x, out_shardings=replicated)(r)) + global_target = [np.arange(4) + 4 * (i * num_shards // 8) for i in range(8)] + np.testing.assert_array_equal(sorted(r, key=lambda x: x[0]), global_target) + + def test_make_array_from_process_data_shape_inference(self): + mesh = jtu.create_mesh((2, 2, 2), ("a", "b", "c"), iota_order=True) + sharding = jax.sharding.NamedSharding(mesh, P(("a", "b"), "c")) + r = jax.make_array_from_process_local_data(sharding, np.ones([4, 4])) + self.assertEqual(r.sharding, sharding) + process_to_target_shape = {1: (4, 4), 2: (8, 4), 4: (16, 4), 8: (16, 8)} + target_shape = process_to_target_shape[jax.process_count()] + self.assertEqual(target_shape, r.shape) + + # Check if we can specify that local input actually contains full-span + # across different axes. + r2 = jax.make_array_from_process_local_data( + sharding, np.ones([4, 4]), global_shape=(target_shape[0], 4) + ) + self.assertEqual(r2.sharding, sharding) + self.assertEqual((target_shape[0], 4), r2.shape) + + r2 = jax.make_array_from_process_local_data( + sharding, np.ones([4, 4]), global_shape=(4, target_shape[1]) + ) + self.assertEqual(r2.sharding, sharding) + self.assertEqual((4, target_shape[1]), r2.shape) + + r2 = jax.make_array_from_process_local_data( + sharding, np.ones([4, 4]), global_shape=(4, 4) + ) + self.assertEqual(r2.sharding, sharding) + self.assertEqual((4, 4), r2.shape) + # Verify that we get not-supported message rather than non-uniform + with self.assertRaisesRegex(ValueError, ".*supported"): + jax.make_array_from_process_local_data( + sharding, np.ones([4, 4]), global_shape=(4, None) + ) + + @parameterized.named_parameters( + ("shape_none", None), + ("shape_tuple", (16, 4)), + ("shape_pytree", {"a": (16, 4), "b": (16, 4)}), + ) + @jtu.run_on_devices("cpu") + def test_make_array_from_process_local_data_pytree(self, global_shape): + mesh = jtu.create_mesh((2, 2, 2), ("a", "b", "c"), iota_order=True) + with jax.set_mesh(mesh): + r = jax.make_array_from_process_local_data( + P(("a", "b"), "c"), + {"a": np.ones([4, 4]), "b": np.ones([4, 4])}, + global_shape=global_shape, + ) + self.assertTupleEqual(r["a"].shape, (16, 4)) + self.assertTupleEqual(r["b"].shape, (16, 4)) + + def test_multi_process_to_py(self): + global_mesh = jtu.create_mesh((4, 2), ("x", "y")) + input_shape = (8, 2) + a, input_data = create_array( + input_shape, jax.sharding.NamedSharding(global_mesh, P(None)) + ) + self.assertIsInstance(np.asarray(a), np.ndarray) + np.testing.assert_array_equal(np.asarray(a), input_data) + + a, input_data = create_array( + input_shape, jax.sharding.NamedSharding(global_mesh, P("x")) + ) + with self.assertRaisesRegex( + RuntimeError, + r"Fetching value for `jax.Array` that spans non-addressable \(non" + r" process local\) devices is not possible.", + ): + np.asarray(a) + + def test_multi_process_repr(self): + global_mesh = jtu.create_mesh((4, 2), ("x", "y")) + input_shape = (8, 2) + a, _ = create_array(input_shape, + jax.sharding.NamedSharding(global_mesh, P(None))) + val = repr(a) + self.assertIn("Array([[ 0., 1.]", val) + + a, _ = create_array(input_shape, + jax.sharding.NamedSharding(global_mesh, P("x"))) + val = repr(a) + self.assertEqual(val, "Array(shape=(8, 2), dtype=float32)") + + def test_getitem(self): + if jtu.is_device_tpu("5", "e"): + raise unittest.SkipTest("Test fails on v5e") + global_mesh = jtu.create_mesh((4, 2), ("x", "y")) + input_shape = (16, 8) + arr, input_data = create_array( + input_shape, jax.sharding.NamedSharding(global_mesh, P("x", "y"))) + + s = arr[2:4, 0:1] + np.testing.assert_array_equal(s, input_data[2:4, 0:1]) + + p = arr[:2] + np.testing.assert_array_equal(p, input_data[:2]) + + def test_array_fully_replicated_shard(self): + global_mesh = jtu.create_mesh((4, 2), ("x", "y")) + inp_shape = (8, 2) + arr, inp_data = create_array( + inp_shape, jax.sharding.NamedSharding(global_mesh, P())) + fs = arr._fully_replicated_shard() + self.assertEqual(fs.shape, inp_shape) + self.assertLen(fs.sharding.device_set, 1) + self.assertEqual(fs.devices(), {jax.local_devices()[0]}) + np.testing.assert_array_equal(fs, inp_data) + np.testing.assert_array_equal(arr.addressable_data(0), inp_data) + + def test_device_put_uncommitted_array(self): + mesh = jtu.create_mesh((4, 2), ("x", "y")) + s = jax.sharding.NamedSharding(mesh, P("x", "y")) + inp = jnp.arange(16).reshape(8, 2) + out = jax.device_put(inp, s) + + for shard in out.addressable_shards: + np.testing.assert_array_equal(shard.data, inp[shard.index]) + self.assertEqual(out.sharding, s) + + def test_device_put_np_array(self): + mesh = jtu.create_mesh((4, 2), ("x", "y")) + s = jax.sharding.NamedSharding(mesh, P("x", "y")) + inp = np.arange(16).reshape(8, 2) + out = jax.device_put(inp, s) + + for shard in out.addressable_shards: + np.testing.assert_array_equal(shard.data, inp[shard.index]) + self.assertEqual(out.sharding, s) + + def test_device_put_python_scalar(self): + mesh = jtu.create_mesh((2, 2), ("x", "y")) + s = jax.sharding.NamedSharding(mesh, P()) + out = jax.device_put(1, s) + + for shard in out.addressable_shards: + np.testing.assert_array_equal(shard.data, 1) + self.assertEqual(out.sharding, s) + + def test_device_put_python_scalar_different_error(self): + mesh = jtu.create_mesh((4, 2), ("x", "y")) + s = jax.sharding.NamedSharding(mesh, P()) + + with self.assertRaisesRegex( + AssertionError, + ".*passed to device_put is not the same on each process.*"): + if jax.process_index() == 0: + jax.device_put(1., s) + else: + jax.device_put(2., s) + + def test_device_put_uncommitted_array_different_inputs_error(self): + mesh = jtu.create_mesh((4, 2), ("x", "y")) + s = jax.sharding.NamedSharding(mesh, P("x", "y")) + + with self.assertRaisesRegex( + AssertionError, + ".*passed to device_put is not the same on each process.*"): + if jax.process_index() == 0: + jax.device_put(jnp.arange(16).reshape(8, 2), s) + else: + jax.device_put(jnp.arange(16, stop=32).reshape(8, 2), s) + + def test_device_put_committed_array_error(self): + mesh = jtu.create_mesh((4, 2), ("x", "y")) + s = jax.sharding.NamedSharding(mesh, P("x", "y")) + inp = jax.device_put(jnp.arange(16).reshape(8, 2), jax.local_devices()[0]) + + with self.assertRaisesRegex(ValueError, "device_put's second argument.*"): + jax.device_put(inp, s) + + def test_closed_over_global_array_error(self): + mesh = jtu.create_mesh((4, 2), ("x", "y")) + s = jax.sharding.NamedSharding(mesh, P("x", "y")) + arr, np_inp = create_array((8, 2), s) + + @jax.jit + def f(x): + return x + arr + + with self.assertRaisesRegex( + RuntimeError, + r"Closing over jax.Array that spans non-addressable \(non process" + r" local\) devices is not allowed"): + f(np_inp) + + def test_zeros_like_use_mesh(self): + mesh = jtu.create_mesh((4, 2), ("x", "y")) + s = jax.sharding.NamedSharding(mesh, P()) + np_inp = np.array(0, dtype=np.float32) + arr = jax.device_put(np_inp, s) + + with jax.set_mesh(mesh): + out = jnp.zeros_like(arr) + np.testing.assert_array_equal(out, np_inp) + + def test_sharding_process_indices_all_devices(self): + mesh = jax.make_mesh((jax.device_count(),), ("x",), devices=jax.devices(), + axis_types=(jax.sharding.AxisType.Explicit,)) + s = jax.sharding.NamedSharding(mesh, P("x",)) + + expected_pids = {d.process_index for d in s.device_set} + self.assertEqual(s._internal_device_list.process_indices, expected_pids) + self.assertLen(s._internal_device_list.process_indices, jax.process_count()) + + +class NonaddressableArrayTestMultiHost(jt_multiprocess.MultiProcessTest): + + def test_create_nonaddressable_array(self): + y, x = create_nonaddressable_array((8, 8)) + # The array is non-addressable in at least one process. + self.assertLess(len(y.sharding._internal_device_list.process_indices), + jax.process_count()) + for a in y.addressable_shards: + np.testing.assert_array_equal(a.data, x[a.index]) + + fr, x = create_nonaddressable_array((8, 8), spec=P()) + self.assertTrue(fr.sharding.is_fully_replicated) + self.assertLess(len(fr.sharding._internal_device_list.process_indices), + jax.process_count()) + if fr.sharding.has_addressable_devices: + np.testing.assert_array_equal(x, fr) + + def test_named_sharding_is_fully_addressable(self): + pid = 0 + ds = jax.local_devices(process_index=pid) + mesh = jtu.create_mesh((len(ds),), ("x",)) + s = jax.sharding.NamedSharding(mesh, P("x")) + self.assertEqual(s.is_fully_addressable, jax.process_index() == pid) + + def test_single_device_sharding_is_fully_addressable(self): + d = jax.devices()[0] + s = jax.sharding.SingleDeviceSharding(d) + self.assertEqual(s.is_fully_addressable, + jax.process_index() == d.process_index) + + def test_array_with_no_local_shards_has_valid_layout(self): + d = jax.devices()[0] + s = jax.sharding.SingleDeviceSharding(d) + shape = (8, 8) + np_inp = np.arange(math.prod(shape), dtype=np.int32).reshape(shape) + xs = [] + if jax.process_index() == d.process_index: + x = jax.device_put(np_inp, s) + xs.append(x) + + arr = jax.make_array_from_single_device_arrays( + shape, s, xs, dtype=jnp.int32) + self.assertIsNotNone(arr.format.layout) + + def test_device_put_uncommitted_array_namedsharding(self): + n_local = len(jax.local_devices()) + pid = 0 + mesh = jax.make_mesh( + (n_local,), ("x",), devices=jax.local_devices(process_index=pid), + axis_types=(jax.sharding.AxisType.Explicit,)) + s = jax.sharding.NamedSharding(mesh, P("x",)) + inp = jnp.arange(16).reshape(8, 2) + out = jax.device_put(inp, s) + + # device_put of an uncommitted array to a sharding that is addressable only + # in process `pid` should return an array with addressable shards only in + # process `pid`. In other processes, the returned array has no addressable + # shards. + expected_num_shards = n_local if jax.process_index() == pid else 0 + self.assertLen(out.addressable_shards, expected_num_shards) + for shard in out.addressable_shards: + np.testing.assert_array_equal(shard.data, inp[shard.index]) + self.assertEqual(out.sharding, s) + + def test_device_put_numpy_array_namedsharding(self): + n_local = len(jax.local_devices()) + pid = 1 + mesh = jax.make_mesh( + (n_local,), ("x",), devices=jax.local_devices(process_index=pid), + axis_types=(jax.sharding.AxisType.Explicit,)) + s = jax.sharding.NamedSharding(mesh, P("x",)) + inp = np.arange(16).reshape(8, 2) + out = jax.device_put(inp, s) + + # device_put of a numpy array to a sharding that is addressable only in + # process `pid` should return an array with addressable shards only in + # process `pid`. In other processes, the returned array has no addressable + # shards. + expected_num_shards = n_local if jax.process_index() == pid else 0 + self.assertLen(out.addressable_shards, expected_num_shards) + for shard in out.addressable_shards: + np.testing.assert_array_equal(shard.data, inp[shard.index]) + self.assertEqual(out.sharding, s) + + def test_device_put_numpy_array_singledevice(self): + inp = np.arange(16).reshape(8, 2) + d = jax.devices()[0] + out = jax.device_put(inp, d) + + # device_put of a numpy array to a sharding that is addressable only in + # process `pid` should return an array with addressable shards only in + # process `pid`. In other processes, the returned array has no addressable + # shards. + expected_num_shards = 1 if jax.process_index() == d.process_index else 0 + self.assertLen(out.addressable_shards, expected_num_shards) + for shard in out.addressable_shards: + np.testing.assert_array_equal(shard.data, inp[shard.index]) + self.assertEqual(out.sharding, jax.sharding.SingleDeviceSharding(d)) + + def test_device_put_committed_array_error(self): + inp = jax.device_put(jnp.arange(16).reshape(8, 2), jax.local_devices()[0]) + + # device_put of a committed array to a nonaddressable sharding should raise + # an error (until cross-host transfers are supported). + with self.assertRaisesRegex(RuntimeError, + "Cannot copy array to non-addressable device"): + nonlocal_pid = (jax.process_index() + 1) % jax.process_count() + jax.device_put(inp, jax.local_devices(process_index=nonlocal_pid)[0]) + + def test_make_array_from_callback(self): + n_local = jax.local_device_count() + pid = 1 + + mesh = jax.make_mesh( + (n_local,), ("x",), devices=jax.local_devices(process_index=pid), + axis_types=(jax.sharding.AxisType.Explicit,)) + s = jax.sharding.NamedSharding(mesh, P("x",)) + + # Create an array that is non-addressable in processes besides `pid`. + global_data = np.arange(16, dtype=np.int32).reshape(8, 2) + arr = jax.make_array_from_callback( + global_data.shape, s, lambda idx: global_data[idx], + dtype=global_data.dtype) + + # The returned array should only contain addressable shards in process + # `pid`. + expected_num_shards = n_local if jax.process_index() == pid else 0 + self.assertLen(arr.addressable_shards, expected_num_shards) + np.testing.assert_array_equal(arr.shape, global_data.shape) + for shard in arr.addressable_shards: + np.testing.assert_array_equal(shard.data, global_data[shard.index]) + + def test_make_array_from_callback_prngkey(self): + n_local = jax.local_device_count() + pid = 1 + mesh = jax.make_mesh( + (n_local,), ("x",), devices=jax.local_devices(process_index=pid), + axis_types=(jax.sharding.AxisType.Explicit,)) + s = jax.sharding.NamedSharding(mesh, P("x",)) + + # Create a PRNG key array that is non-addressable in processes besides + # `pid`. + seeds = jnp.arange(8) + global_data = jax.vmap(lambda x: jax.random.key(seed=x))(seeds) + k = jax.random.key(0) + arr = jax.make_array_from_callback( + global_data.shape, s, lambda idx: global_data[idx], + dtype=k.dtype) + + # The returned array should only contain addressable shards in process + # `pid`. + expected_num_shards = n_local if jax.process_index() == pid else 0 + self.assertLen(arr.addressable_shards, expected_num_shards) + np.testing.assert_array_equal(arr.shape, global_data.shape) + for shard in arr.addressable_shards: + np.testing.assert_array_equal(shard.data.shape, (8 // n_local,)) + + def test_sharding_process_indices_device_subset(self): + n_devices = jax.device_count() + mesh = jax.make_mesh( + (n_devices // 2,), ("x",), devices=jax.devices()[:n_devices // 2], + axis_types=(jax.sharding.AxisType.Explicit,)) + s = jax.sharding.NamedSharding(mesh, P("x",)) + + expected_pids = {d.process_index for d in s.device_set} + self.assertEqual(s._internal_device_list.process_indices, expected_pids) + self.assertLen(s._internal_device_list.process_indices, + jax.process_count() // 2) + + def test_jit_no_local_devices_named_sharding(self): + x = np.arange(64).reshape(8, 8) + n_local = jax.local_device_count() + pid = 1 + + # Create a sharding that is non-addressable in processes besides `pid`. + mesh = jax.make_mesh( + (n_local,), ("x",), devices=jax.local_devices(process_index=pid), + axis_types=(jax.sharding.AxisType.Explicit,)) + s = jax.sharding.NamedSharding(mesh, P("x",)) + y = jax.device_put(x, s) + expected_num_shards = n_local if jax.process_index() == pid else 0 + self.assertLen(y.addressable_shards, expected_num_shards) + + @jax.jit + def f(x): + return x + 1 + + # The returned array should only contain addressable shards in process + # `pid`. No work is done in other processes. + z = f(y) + z.block_until_ready() + self.assertLen(z.addressable_shards, expected_num_shards) + if jax.process_index() == pid: + for shard in z.addressable_shards: + np.testing.assert_array_equal(shard.data, x[shard.index] + 1) + + def test_jit_no_local_devices_named_sharding_collective(self): + x = np.arange(64).reshape(8, 8) + n_local = jax.local_device_count() + pid = 1 + + # Create a sharding that is non-addressable in processes besides `pid`. + mesh = jax.make_mesh( + (n_local,), ("x",), devices=jax.local_devices(process_index=pid), + axis_types=(jax.sharding.AxisType.Explicit,)) + s = jax.sharding.NamedSharding(mesh, P("x",)) + y = jax.device_put(x, s) + expected_num_shards = n_local if jax.process_index() == pid else 0 + self.assertLen(y.addressable_shards, expected_num_shards) + + @jax.jit + def f(x): + return jnp.sum(x) + + # The returned array should only contain addressable shards in process + # `pid`. No work is done in other processes. + z = f(y) + z.block_until_ready() + self.assertLen(z.addressable_shards, expected_num_shards) + if jax.process_index() == pid: + expected = x.sum() + for shard in z.addressable_shards: + np.testing.assert_array_equal(shard.data, expected) + + def test_jit_no_local_devices_single_device_sharding(self): + x = np.arange(64).reshape(8, 8) + pid = 1 + + # Create a single device sharding for a device local to process `pid`. + s = jax.sharding.SingleDeviceSharding( + jax.local_devices(process_index=pid)[0]) + y = jax.device_put(x, s) + expected_num_shards = 1 if jax.process_index() == pid else 0 + self.assertLen(y.addressable_shards, expected_num_shards) + + @jax.jit + def f(x): + return x + 1 + + # The returned array should only contain an addressable shard in process + # `pid`. No work is done in other processes. + z = f(y) + z.block_until_ready() + self.assertLen(z.addressable_shards, expected_num_shards) + if jax.process_index() == pid: + np.testing.assert_array_equal(z.addressable_shards[0].data, x + 1) + + def test_jit_fastpath_matmul(self): + mesh = jax.sharding.Mesh( + jax.devices()[: len(jax.devices()) // 2], axis_names=("devices")) + sharding = jax.sharding.NamedSharding(mesh, P()) + + x = jax.device_put( + jnp.arange(8 * 16, dtype=jnp.float32).reshape((8, 16)), sharding) + w = jax.device_put( + jnp.arange(16 * 4, dtype=jnp.float32).reshape((16, 4)), sharding) + + jax.experimental.multihost_utils.sync_global_devices("start") + matmul = jax.jit(lambda x, w: x @ w, out_shardings=sharding) + + _ = matmul(x, w) + y = matmul(x, w) # doesn't crash on second call + expected = x @ w + for shard in y.addressable_shards: + np.testing.assert_array_equal(shard.data, expected[shard.index]) + + def test_numpy_asarray_no_local_devices(self): + y, x = create_nonaddressable_array((8, 8), spec=P()) + + # In processes with local shards, we can fetch the value of the array using + # np.asarray, since the sharding is fully replicated. In processes with no + # local shards, attempting to fetch the NumPy array is an error. + if y.sharding.has_addressable_devices: + np.testing.assert_array_equal(np.asarray(y), x) + else: + with self.assertRaisesRegex( + RuntimeError, + r"Fetching value for `jax.Array` that spans non-addressable \(non" + r" process local\) devices is not possible."): + np.asarray(y) + + def test_shard_map_no_local_devices(self): + x, x_np = create_nonaddressable_array((8, 8)) + + # shard_map works as expected when there are nonparticipating hosts. + shard_map_f = jax.shard_map( + lambda x: jax.lax.psum(x, "x"), mesh=x.sharding.mesh, in_specs=P("x"), + out_specs=P()) + y = shard_map_f(x) + expected_y = sum(np.split(x_np, len(x.sharding.device_set))) + sharding_process_indices = x.sharding._internal_device_list.process_indices + expected_num_shards = (jax.local_device_count() + if jax.process_index() in sharding_process_indices + else 0) + self.assertLen(y.addressable_shards, expected_num_shards) + for shard in y.addressable_shards: + np.testing.assert_array_equal(shard.data, expected_y[shard.index]) + + def test_array_delete(self): + y, _ = create_nonaddressable_array((8, 8)) + y.delete() + with self.assertRaisesRegex(RuntimeError, "Array has been deleted."): + y._check_if_deleted() + self.assertIsNone(y._npy_value) + self.assertIsNone(y._arrays) + + def test_single_device_array_usage_after_delete(self): + y, _ = create_nonaddressable_array((8, 8)) + y.delete() + + with self.assertRaisesRegex(RuntimeError, "Array has been deleted."): + _ = y + 1 + + def test_repr(self): + y, _ = create_nonaddressable_array((8, 8)) + if y.is_fully_addressable: + self.assertStartsWith(repr(y), "Array([[ 0., 1., 2., 3.,") + else: + self.assertEqual(repr(y), "Array(shape=(8, 8), dtype=float32)") + + def test_str(self): + y, _ = create_nonaddressable_array((8, 8)) + if y.is_fully_addressable: + self.assertStartsWith(str(y), "[[ 0. 1. 2. 3.") + else: + self.assertEqual(str(y), "Array(shape=(8, 8), dtype=float32)") + + def test_format(self): + y, _ = create_nonaddressable_array((8, 8)) + if y.is_fully_addressable: + self.assertStartsWith(format(y), "[[ 0. 1. 2. 3.") + else: + self.assertEqual(format(y), "Array(shape=(8, 8), dtype=float32)") + + def test_array_astype(self): + y, _ = create_nonaddressable_array((8, 8)) + y = y.astype(np.int32) + self.assertEqual(y.dtype, np.int32) + + def test_sharded_add(self): + y, y_np = create_nonaddressable_array((8, 8)) + z, z_np = create_nonaddressable_array((8, 8), spec=P()) + out = y + z + expected = y_np + z_np + self.assertLen(out.addressable_shards, len(y.sharding.addressable_devices)) + for shard in out.addressable_shards: + np.testing.assert_array_equal(shard.data, expected[shard.index]) + + def test_sharded_zeros_like(self): + y, _ = create_nonaddressable_array((8, 8)) + out = jnp.zeros_like(y) + expected = jnp.zeros(y.shape, dtype=y.dtype) + self.assertLen(out.addressable_shards, len(y.sharding.addressable_devices)) + for i in out.addressable_shards: + np.testing.assert_array_equal(i.data, expected[i.index]) + + def test_array_not_hashable(self): + y, _ = create_nonaddressable_array((8, 8)) + with self.assertRaisesRegex(TypeError, "unhashable type"): + hash(y) + + def test_on_device_size_in_bytes(self): + a, _ = create_nonaddressable_array((8, 8)) + if not a.sharding.has_addressable_devices: + with self.assertRaisesRegex( + RuntimeError, + r"GetOnDeviceSizeInBytes\(\) is not yet supported for arrays with no " + r"addressable devices"): + a.on_device_size_in_bytes() + else: + shard_size = a.addressable_shards[0].data.on_device_size_in_bytes() + self.assertEqual(shard_size * len(a.global_shards), + a.on_device_size_in_bytes()) + + def test_array_is_ready(self): + y, _ = create_nonaddressable_array((8, 8)) + y.is_ready() # doesn't crash + + def test_array_copy_to_host_async(self): + y, x = create_nonaddressable_array((8, 8)) + y.copy_to_host_async() # doesn't crash + for shard in y.addressable_shards: + np.testing.assert_array_equal(shard.data, x[shard.index]) + + def test_device_get_replicated(self): + y, x = create_nonaddressable_array((8, 8), spec=P()) + if y.sharding.has_addressable_devices: + np.testing.assert_array_equal(jax.device_get(y), x) + else: + with self.assertRaisesRegex( + RuntimeError, + r"Fetching value for `jax.Array` that spans non-addressable \(non" + r" process local\) devices is not possible."): + jax.device_get(y) + + # Skipped on GPU since there are two processes with one device each, so we + # can't construct a sharding that is nonaddressable in one of the processes + # and also not fully replicated (since the sharding must contain one device). + @jtu.skip_on_devices("gpu") + def test_device_get_sharded(self): + y, _ = create_nonaddressable_array((8, 8)) + with self.assertRaisesRegex( + RuntimeError, + r"Fetching value for `jax.Array` that spans non-addressable \(non" + r" process local\) devices is not possible."): + jax.device_get(y) + + def test_array_fully_replicated_shard(self): + y, x = create_nonaddressable_array((8, 8), spec=P()) + if y.sharding.has_addressable_devices: + fs = y.addressable_data(0) + self.assertEqual(fs.shape, x.shape) + self.assertLen(fs.sharding.device_set, 1) + self.assertEqual(fs.devices(), {jax.local_devices()[0]}) + np.testing.assert_array_equal(fs, x) + np.testing.assert_array_equal(y.addressable_data(0), x) + else: + with self.assertRaisesRegex( + RuntimeError, "FullyReplicatedShard: Array has no addressable shards." + ): + y.addressable_data(0) + + def test_array_iter_replicated(self): + y, _ = create_nonaddressable_array((8, 8), spec=P()) + y_iter = iter(y) + self.assertLen(list(y_iter), 8) + + # Skipped on GPU since the sharding contains one device and is therefore fully + # replicated. + @jtu.skip_on_devices("gpu") + def test_array_iter_sharded(self): + y, _ = create_nonaddressable_array((8, 8)) + with self.assertRaises(AssertionError): + iter(y) + + +class CrossHostTransferTest(jt_multiprocess.MultiProcessTest): + + @jtu.run_on_devices("cpu") + def test_cross_host_transfer_cpu_error(self): + x = np.arange(64).reshape(8, 8) + src_pid = 0 + dst_pid = 1 + src_sharding = jax.sharding.SingleDeviceSharding( + jax.local_devices(process_index=src_pid)[0]) + dst_sharding = jax.sharding.SingleDeviceSharding( + jax.local_devices(process_index=dst_pid)[0]) + y = jax.device_put(x, src_sharding) + with self.assertRaisesRegex( + ValueError, "does not support cross-host device transfers"): + jax.device_put(y, dst_sharding) + + @jtu.skip_on_devices("cpu") + def test_cross_host_transfer_single_device_sharding(self): + x = np.arange(64).reshape(8, 8) + src_pid = 0 + dst_pid = 1 + src_sharding = jax.sharding.SingleDeviceSharding( + jax.local_devices(process_index=src_pid)[0]) + dst_sharding = jax.sharding.SingleDeviceSharding( + jax.local_devices(process_index=dst_pid)[0]) + y = jax.device_put(x, src_sharding) + z = jax.device_put(y, dst_sharding) + if jax.process_index() == dst_pid: + self.assertLen(z.addressable_shards, 1) + np.testing.assert_array_equal(z.addressable_shards[0].data, x) + else: + self.assertEmpty(z.addressable_shards) + + @jtu.skip_on_devices("cpu") + def test_cross_host_transfer_named_sharding(self): + x = np.arange(64).reshape(8, 8) + n_local = jax.local_device_count() + src_pid = 0 + dst_pid = 1 + src_sharding = jax.sharding.NamedSharding( + jax.make_mesh((n_local,), ("x",), + devices=jax.local_devices(process_index=src_pid), + axis_types=(jax.sharding.AxisType.Explicit,)), + P("x")) + dst_sharding = jax.sharding.NamedSharding( + jax.make_mesh((n_local,), ("x",), + devices=jax.local_devices(process_index=dst_pid), + axis_types=(jax.sharding.AxisType.Explicit,)), + P("x")) + y = jax.device_put(x, src_sharding) + z = jax.device_put(y, dst_sharding) + if jax.process_index() == dst_pid: + self.assertLen(z.addressable_shards, n_local) + for shard in z.addressable_shards: + np.testing.assert_array_equal(shard.data, x[shard.index]) + else: + self.assertEmpty(z.addressable_shards) + + @jtu.skip_on_devices("cpu") + def test_cross_host_transfer_named_sharding_replicated(self): + x = np.arange(64).reshape(8, 8) + n_dev = jax.device_count() // 2 + src_sharding = jax.sharding.NamedSharding( + jax.make_mesh((n_dev,), ("x",), devices=jax.devices()[:n_dev], + axis_types=(jax.sharding.AxisType.Explicit,)), + P() + ) + dst_sharding = jax.sharding.NamedSharding( + jax.make_mesh((n_dev,), ("x",), devices=jax.devices()[n_dev:], + axis_types=(jax.sharding.AxisType.Explicit,)), + P() + ) + y = jax.device_put(x, src_sharding) + z = jax.device_put(y, dst_sharding) + for shard in z.addressable_shards: + np.testing.assert_array_equal(shard.data, x[shard.index]) + + @jtu.skip_on_devices("cpu") + def test_cross_host_transfer_batched(self): + num_arrays = 3 + xs = [] + for i in range(1, num_arrays + 1): + xs.append(jnp.arange(64 * i).reshape(8, 8 * i)) + # TODO(emilyaf): Smaller sizes fail on TPU because the dst buffer size + # returned by TransferSizeUtil::ShapeSizeCompact is larger than the src + # buffer size. Investigate this further. + # xs.append(jnp.arange(16 * i).reshape(8, 2 * i)) + xs[0] = xs[0].astype(jnp.float32) + + n_local = jax.local_device_count() + src_pid = 0 + dst_pid = 1 + src_sharding = jax.sharding.NamedSharding( + jax.make_mesh((n_local,), ("x",), + devices=jax.local_devices(process_index=src_pid), + axis_types=(jax.sharding.AxisType.Explicit,)), + P("x")) + dst_sharding = jax.sharding.NamedSharding( + jax.make_mesh((n_local,), ("x",), + devices=jax.local_devices(process_index=dst_pid), + axis_types=(jax.sharding.AxisType.Explicit,)), + P("x")) + + ys = jax.device_put(xs, src_sharding) + zs = jax.device_put(ys, dst_sharding) + for (x, z) in zip(xs, zs): + if jax.process_index() == dst_pid: + self.assertLen(z.addressable_shards, n_local) + for shard in z.addressable_shards: + np.testing.assert_array_equal(shard.data, x[shard.index]) + else: + self.assertEmpty(z.addressable_shards) + + @jtu.skip_on_devices("cpu") + def test_device_to_cpu_transfer_jit(self): + x = jnp.arange(64).reshape(8, 8) + with self.assertWarnsRegex( + DeprecationWarning, + r"backend and device argument on jit is deprecated", + ): + cpu_transfer_f = jax.jit(lambda x: x + 1, backend="cpu") + cpu_transfer_f(x) # Should not raise a cross-host transfer error. + + @jtu.skip_on_devices("cpu") + def test_device_put_to_cpu(self): + x = jnp.arange(64).reshape(8, 8) + devices = jax.devices() + cpu_devices = jax.devices(backend="cpu") + num_devices = min(len(devices), len(cpu_devices)) + + # Create CPU and GPU/TPU shardings that are not fully addressable. + cpu_sharding = jax.sharding.NamedSharding( + jax.make_mesh( + (num_devices,), ("x",), devices=cpu_devices[:num_devices], + axis_types=(jax.sharding.AxisType.Explicit,)), + P("x")) + sharding = jax.sharding.NamedSharding( + jax.make_mesh( + (num_devices,), ("x",), devices=devices[:num_devices], + axis_types=(jax.sharding.AxisType.Explicit,)), + P("x")) + y = jax.device_put(x, sharding) + + # device_put of a GPU/TPU array to the CPU sharding should raise a helpful + # error. + with self.assertRaisesRegex( + ValueError, ("For a cross-host reshard in multi-controller JAX|" + "device_put's second argument must be a Device")): + jax.device_put(y, cpu_sharding) + + +if __name__ == "__main__": + jt_multiprocess.main() diff --git a/jaxlib/tools/build_wheel_test.py b/tests/multiprocess/axis_index_test.py similarity index 51% rename from jaxlib/tools/build_wheel_test.py rename to tests/multiprocess/axis_index_test.py index a33491f1c606..f64fc435a5aa 100644 --- a/jaxlib/tools/build_wheel_test.py +++ b/tests/multiprocess/axis_index_test.py @@ -1,4 +1,4 @@ -# Copyright 2024 The JAX Authors. +# Copyright 2025 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,21 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -# This test verifies that the build_wheel.py runs successfully. +import jax +from jax import lax +from jax._src import test_multiprocess as jt_multiprocess +import numpy as np -import platform -import subprocess -import sys -import tempfile -from bazel_tools.tools.python.runfiles import runfiles +class AxisIndexTest(jt_multiprocess.MultiProcessTest): -r = runfiles.Create() + def test(self): + f = jax.pmap(lambda _: lax.axis_index("i"), axis_name="i") + n = jax.local_device_count() + xs = np.arange(n) + out = f(xs * 0) + np.testing.assert_equal(out, xs + (n * jax.process_index())) -with tempfile.TemporaryDirectory(prefix="jax_build_wheel_test") as tmpdir: - subprocess.run([ - sys.executable, r.Rlocation("__main__/jaxlib/tools/build_wheel.py"), - f"--cpu={platform.machine()}", - f"--output_path={tmpdir}", - "--jaxlib_git_hash=12345678" - ], check=True) + +if __name__ == "__main__": + jt_multiprocess.main() diff --git a/tests/multiprocess/colocated_python_test.py b/tests/multiprocess/colocated_python_test.py new file mode 100644 index 000000000000..14fa45a5544c --- /dev/null +++ b/tests/multiprocess/colocated_python_test.py @@ -0,0 +1,78 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multihost tests for jax.Array.""" + +import jax +from jax._src import test_multiprocess as jt_multiprocess +from jax._src import test_util as jtu +from jax.experimental import colocated_python +import numpy as np + +try: + import cloudpickle # noqa + HAS_CLOUDPICKLE = True +except (ModuleNotFoundError, ImportError): + HAS_CLOUDPICKLE = False + +class ColocatedPythonTestMultiHost(jt_multiprocess.MultiProcessTest): + + def setUp(self): + super().setUp() + if not HAS_CLOUDPICKLE: + self.skipTest( + "ColocatedPythonTestMultiHost depends on cloudpickle library" + ) + jtu.request_cpu_devices(jax.local_device_count()) + + def test_colocated_cpu_devices(self): + if jax.device_count() % 2 == 0: + mesh_shape = (2, jax.device_count() // 2) + else: + mesh_shape = (1, jax.device_count()) + mesh = jax.make_mesh(mesh_shape, ("x", "y"), + axis_types=(jax.sharding.AxisType.Explicit,) * 2) + cpu_mesh1 = colocated_python.colocated_cpu_devices(mesh) + + cpu_devices = colocated_python.colocated_cpu_devices(mesh.devices.flat) + cpu_mesh2 = jax.make_mesh(mesh_shape, ("x", "y"), + axis_types=(jax.sharding.AxisType.Explicit,) * 2, + devices=cpu_devices) + self.assertEqual(cpu_mesh1, cpu_mesh2) + + def test_simple_function(self): + @colocated_python.colocated_python + def add_one(x): + return jax.make_array_from_single_device_arrays( + x.shape, x.sharding, [s.data + 1 for s in x.addressable_shards]) + + mesh = jax.make_mesh((jax.device_count(),), ("x",), + axis_types=(jax.sharding.AxisType.Explicit,)) + cpu_mesh = colocated_python.colocated_cpu_devices(mesh) + cpu_sharding = jax.NamedSharding(cpu_mesh, jax.P("x")) + + x = np.arange(cpu_mesh.size) + x = jax.device_put(x, cpu_sharding) + + out = add_one(x) + + out = jax.jit(lambda x: x, + out_shardings=jax.NamedSharding(cpu_mesh, jax.P()))(out) + out = jax.device_get(out) + + np.testing.assert_equal(out, np.arange(cpu_mesh.size) + 1) + + +if __name__ == "__main__": + jt_multiprocess.main() diff --git a/tests/multiprocess/device_id_test.py b/tests/multiprocess/device_id_test.py new file mode 100644 index 000000000000..83d38b4e7dc6 --- /dev/null +++ b/tests/multiprocess/device_id_test.py @@ -0,0 +1,91 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import jax +from jax._src import test_multiprocess as jt_multiprocess +from jax._src import test_util as jtu + + +class DeviceIdTest(jt_multiprocess.MultiProcessTest): + + def testDeviceIds(self): + # TODO(phawkins): TPU process IDs won't necessarily match the global + # process index. + if not jtu.test_device_matches(["tpu"]): + self.assertEqual( + jax.process_index(), + jt_multiprocess.MULTIPROCESS_TEST_WORKER_ID.value, + ) + self.assertLen( + jax.devices(), + jt_multiprocess.NUM_PROCESSES.value * jax.local_device_count(), + ) + self.assertEqual( + jax.local_devices()[0].process_index, + jax.process_index(), + ) + + def testPrimitive(self): + with jax.default_device(jax.local_devices(backend="cpu")[0]): + self.assertEqual(2, jax.lax.neg(jax.lax.neg(2))) + + def testJit(self): + """Verifies that local computation works inside a distributed job.""" + x = jax.device_put(1) + self.assertEqual(x, 1) + y = jax.jit(lambda x: x + 1)(x) + self.assertEqual(y, 2) + + # TODO(phawkins): this test CHECK-fails on TPU. + @jtu.skip_on_devices("tpu") + def testNonaddressableDeviceToDevicePut(self): + source_device = jax.local_devices(backend="cpu")[0] + x = jax.device_put(0, source_device) + for device in jax.devices(): + if device.process_index != jax.process_index(): + with self.assertRaisesRegex( + RuntimeError, + "(Cannot copy array to non-addressable device.*|.*is not a local" + " device.*)", + ): + jax.device_put(x, device) + + def testDefaultDevicePlatformString(self): + with jax.default_device("cpu"): + result = jax.jit(lambda x: x + 1)(1) + self.assertEqual(result.device.platform, "cpu") + self.assertEqual(result.device, jax.local_devices(backend="cpu")[0]) + + result = jax.jit(lambda x: x + 1)(1) + self.assertEqual(result.device.platform, jax.default_backend()) + self.assertEqual(result.device, jax.local_devices()[0]) + + # def testCrossProcessReduceScatter(self): + # i = multiprocess_test.MULTIPROCESS_TEST_WORKER_ID.value + # n = multiprocess_test.NUM_PROCESSES.value + # f = jax.pmap( + # lambda x: lax.psum_scatter( + # x, + # "i", + # ), + # axis_name="i", + # ) + # x = np.arange(n * n).reshape(n, n) + # out = f(x[i : i + 1]) + # expected = np.sum(x, axis=0) + # np.testing.assert_allclose(expected[i : i + 1], out) + + +if __name__ == "__main__": + jt_multiprocess.main() diff --git a/tests/multiprocess/host_callback_test.py b/tests/multiprocess/host_callback_test.py new file mode 100644 index 000000000000..454e8ad2d996 --- /dev/null +++ b/tests/multiprocess/host_callback_test.py @@ -0,0 +1,171 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for host_callback on multi-host setup.""" + +import unittest + +import jax +from jax import lax +from jax import numpy as jnp +from jax._src import pjit +from jax._src import test_multiprocess as jt_multiprocess +from jax._src import test_util as jtu +from jax.experimental import io_callback +from jax.experimental import multihost_utils +from jax.sharding import PartitionSpec as P +import numpy as np + + +class _CollectCallbacks: + """Collect the callback arguments.""" + + def __init__(self, test_method_name): + self.collected = [] + self.test_method_name = test_method_name + + def collect(self, what) -> None: + print(f"collect[{self.test_method_name}]: {what}") + self.collected.append(what) + + +callback_collector = None + +NR_PROCESSES = 4 +NR_LOCAL_DEVICES = 2 +NR_DEVICES = NR_PROCESSES * NR_LOCAL_DEVICES + + +def sorted_devices(): + devices = sorted( + jax.devices(), + key=lambda d: (d.process_index, getattr(d, "core_on_chip", 0)), + ) + if len(devices) != NR_DEVICES: + raise unittest.SkipTest("Test assumes that it runs on 8 devices.") + if jax.process_count() != NR_PROCESSES: + raise unittest.SkipTest(f"Test assumes we have {NR_PROCESSES} processes.") + return devices + + +class IoCallbackMultiProcessTest(jtu.JaxTestCase, + jt_multiprocess.MultiProcessTest): + + def setUp(self): + super(jtu.JaxTestCase, self).setUp() + global callback_collector + callback_collector = _CollectCallbacks(self._testMethodName) + + def tearDown(self): + super(jtu.JaxTestCase, self).tearDown() + jax.effects_barrier() + + def test_pure_callback_pmap(self): + # x_global: i32[D, 2] = [[0, 1], [10, 11], [20, 21], ...] + # x_local: i32[L, 2] + x_global = np.arange(100, dtype=np.int32).reshape((10, 10))[:NR_DEVICES, :2] + + process_idx = jax.process_index() + local_device_idx = process_idx * NR_LOCAL_DEVICES + x_local = x_global[local_device_idx:local_device_idx + NR_LOCAL_DEVICES] + + def func(x): # Runs on each device. + sum_global = jax.lax.psum(x, "d") + return jax.pure_callback(callback_func, + x, # result_shapes_dtype + lax.axis_index("d"), x, sum_global) + + def callback_func(axis_index, x, sum_global): + callback_collector.collect((axis_index, x, sum_global)) + return x * np.array(3, np.int32) + sum_global + + pmap_func = jax.pmap(func, axis_name="d", devices=sorted_devices()) + res = pmap_func(x_local) + expected_sum_global = np.sum(x_global, axis=0, dtype=np.int32) + # On each host we only get the local result. + self.assertAllClose(x_local * np.array(3, np.int32) + expected_sum_global, + res) + + jax.effects_barrier() + + # Each process gets only the callbacks for its local devices. + self.assertAllClose( + sorted(callback_collector.collected, key=lambda x: x[0]), + [(np.array(process_idx * NR_LOCAL_DEVICES, dtype=np.int32), + np.array([10 * local_device_idx, + 10 * local_device_idx + 1], dtype=np.int32), + expected_sum_global), + (np.array(process_idx * NR_LOCAL_DEVICES + 1, dtype=np.int32), + np.array([10 * local_device_idx + 10, + 10 * local_device_idx + 11], dtype=np.int32), + expected_sum_global)]) + + @jtu.ignore_warning(category=DeprecationWarning) + def test_io_callback_pjit(self): + devices = np.array(sorted_devices()).reshape( + (NR_PROCESSES, NR_LOCAL_DEVICES)) + mesh = jax.sharding.Mesh(devices, ["p", "l"]) + + # x_global: i32[P, L, 3] = [[[0, 1, 2], [10, 11, 12]], + # [[100, 101, 102], [110, 111, 112]], + # ...] + # x_local: i32[1, L, 3] + # y: i32[3, 5] + x_global = jnp.arange( + 1000, dtype=jnp.int32).reshape( + (10, 10, 10))[:NR_PROCESSES, :NR_LOCAL_DEVICES, :3] + process_id = jax.process_index() + x_local = x_global[process_id:process_id + 1] + + def callback_times5_func(x): + callback_collector.collect(x) + return x * np.array(5, np.int32) + + def fun(x_local): + return io_callback(callback_times5_func, + x_local, # result shape dtypes + x_local) + + expected_res = x_local * np.array(5, np.int32) + pjit_fun = pjit.pjit(fun, + in_shardings=P("p", "l"), + out_shardings=P("p", "l")) + + with mesh: + gx = multihost_utils.host_local_array_to_global_array( + x_local, mesh, P("p", "l")) + global_res = pjit_fun(gx) + res = multihost_utils.global_array_to_host_local_array( + global_res, mesh, P("p", "l")) + + self.assertAllClose(expected_res, res) + jax.effects_barrier() + + if jax.process_index() == 0: + # All calls are on the process 0; the 100s digit specifies the device + self.assertAllClose(callback_collector.collected, + [np.array([[[0, 1, 2], + [10, 11, 12]], + [[100, 101, 102], + [110, 111, 112]], + [[200, 201, 202], + [210, 211, 212]], + [[300, 301, 302], + [310, 311, 312]]], dtype=np.int32)]) + else: + self.assertAllClose(callback_collector.collected, []) + + +if __name__ == "__main__": + jt_multiprocess.main() diff --git a/tests/multiprocess/key_value_store_test.py b/tests/multiprocess/key_value_store_test.py new file mode 100644 index 000000000000..e02c3d36c608 --- /dev/null +++ b/tests/multiprocess/key_value_store_test.py @@ -0,0 +1,146 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Distributed key value store test.""" + +import jax +from jax._src import distributed +from jax._src import test_multiprocess as jt_multiprocess + + +class KeyValueStoreTest(jt_multiprocess.MultiProcessTest): + def testBlockingKeyValueGet(self): + client = distributed.global_state.client + key = 'test_key' + expected_value = 'JAX is great!' + timeout_in_ms = 1000 + + if jax.process_index() == 0: + client.key_value_set(key, expected_value) + actual_value = client.blocking_key_value_get(key, timeout_in_ms) + + self.assertEqual(expected_value, actual_value) + + def testBlockingKeyValueSetTwice(self): + client = distributed.global_state.client + key = 'test_key_' + str(jax.process_index()) + expected_value = 'JAX is great!' + + with self.assertRaisesRegex( + jax.errors.JaxRuntimeError, + r'ALREADY_EXISTS: key .* already exists.' + ): + client.key_value_set(key, expected_value) + client.key_value_set(key, expected_value) + + def testBlockingKeyValueSetTwice_Overwrite(self): + client = distributed.global_state.client + key = 'test_key_overwrite_' + str(jax.process_index()) + initial_value = 'JAX is okay!' + overwritten_value = 'JAX is great!' + timeout_in_ms = 1000 + + client.key_value_set(key, initial_value) + client.key_value_set(key, overwritten_value, allow_overwrite=True) + actual_value = client.blocking_key_value_get(key, timeout_in_ms) + + self.assertEqual(overwritten_value, actual_value) + + def testBlockingKeyValueGetBytes(self): + client = distributed.global_state.client + key = 'test_key2' + expected_value = b'JAX is great!' + timeout_in_ms = 1000 + + if jax.process_index() == 0: + client.key_value_set_bytes(key, expected_value) + actual_value = client.blocking_key_value_get_bytes(key, timeout_in_ms) + + self.assertEqual(expected_value, actual_value) + + def testKeyValueTryGet(self): + client = distributed.global_state.client + key = 'test_key_try_get' + expected_value = 'JAX is great!' + if jax.process_index() == 0: + client.key_value_set(key, expected_value) + client.wait_at_barrier('kv_try_get_barrier', 1000) # 1 second. + + actual_value = client.key_value_try_get(key) + + self.assertEqual(expected_value, actual_value) + + def testKeyValueTryGet_NotFound(self): + client = distributed.global_state.client + key = 'test_key_not_found' + with self.assertRaisesRegex( + jax.errors.JaxRuntimeError, + r'NOT_FOUND: Config key .* not found.' + ): + client.key_value_try_get(key) + + def testKeyValueTryGetBytes(self): + client = distributed.global_state.client + key = 'test_key_try_get_bytes' + expected_value = b'JAX is great!' + if jax.process_index() == 0: + client.key_value_set_bytes(key, expected_value) + client.wait_at_barrier('kv_try_get_bytes_barrier', 1000) # 1 second. + + actual_value = client.key_value_try_get_bytes(key) + + self.assertEqual(expected_value, actual_value) + + def testKeyValueDirGet(self): + client = distributed.global_state.client + kvs = [('dir/key0', 'value0'), ('dir/key2', 'value2'), + ('dir/nested/key3', 'value3')] + timeout_in_ms = 1000 + + if jax.process_index() == 0: + for kv in kvs: + client.key_value_set(kv[0], kv[1]) + client.wait_at_barrier('wait_for_kv_set1', timeout_in_ms) + actual_kvs = client.key_value_dir_get('dir/') + self.assertSameElements(kvs, actual_kvs) + + def testKeyValueDirGetBytes(self): + client = distributed.global_state.client + kvs = [('dir2/key0', b'value0'), ('dir2/key2', b'avalue2'), + ('dir2/nested/key3', b'avalue3')] + timeout_in_ms = 1000 + + if jax.process_index() == 0: + for kv in kvs: + client.key_value_set_bytes(kv[0], kv[1]) + client.wait_at_barrier('wait_for_kv_set2', timeout_in_ms) + actual_kvs = client.key_value_dir_get_bytes('dir2/') + self.assertSameElements(kvs, actual_kvs) + + def testLargeKeyValueDirGet(self): + client = distributed.global_state.client + value_size = 1024 * 1024 # bytes + num_keys = 10 + kvs = [(f'dir3/key{i}', 'x' * value_size) for i in range(num_keys)] + timeout_in_ms = 30 * 1000 + + if jax.process_index() == 0: + for kv in kvs: + client.key_value_set(kv[0], kv[1]) + client.wait_at_barrier('wait_for_kv_set3', timeout_in_ms) + actual_kvs = client.key_value_dir_get('dir3/') + self.assertSameElements(kvs, actual_kvs) + +if __name__ == '__main__': + jt_multiprocess.main() diff --git a/tests/multiprocess/multihost_utils_test.py b/tests/multiprocess/multihost_utils_test.py new file mode 100644 index 000000000000..0b2527e4fb32 --- /dev/null +++ b/tests/multiprocess/multihost_utils_test.py @@ -0,0 +1,541 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multihost tests for pjit.""" + +import math +import unittest + +from absl.testing import parameterized +import jax +from jax import numpy as jnp +from jax._src import test_multiprocess as jt_multiprocess +from jax._src import test_util as jtu +from jax.experimental import multihost_utils +from jax.sharding import PartitionSpec as P +import numpy as np + + +class MultiHostUtilsTest(jt_multiprocess.MultiProcessTest): + + def test_process_allgather_stacked(self): + elems_per_host = 4 + + num_processes = jax.process_count() + x = jnp.ones((4,)).reshape((2, 2)) + out = multihost_utils.process_allgather(x, tiled=False) + self.assertEqual(out.shape, (num_processes, 2, 2)) + np.testing.assert_array_equal(out, np.stack([x] * num_processes)) + + x = jnp.ones((64,)).reshape((8, 4, 2)) + out = multihost_utils.process_allgather(x, tiled=False) + self.assertEqual(out.shape, (num_processes, 8, 4, 2)) + np.testing.assert_array_equal(out, np.stack([x] * num_processes)) + + x = np.arange(elems_per_host) + jax.process_index() * elems_per_host + out = multihost_utils.process_allgather(x, tiled=False) + self.assertEqual(out.shape, (num_processes, 4)) + np.testing.assert_array_equal( + out, + np.arange(elems_per_host * jax.process_count()).reshape( + num_processes, elems_per_host + ), + ) + + x = np.array(0) + jax.process_index() * elems_per_host + out = multihost_utils.process_allgather(x, tiled=False) + self.assertEqual(out.shape, (num_processes,)) + np.testing.assert_array_equal( + out, np.arange(num_processes) * elems_per_host + ) + + def test_process_allgather_concatenated(self): + elems_per_host = 4 + + num_processes = jax.process_count() + x = jnp.ones((4,)).reshape((2, 2)) + out = multihost_utils.process_allgather(x, tiled=True) + self.assertEqual(out.shape, (2 * num_processes, 2)) + np.testing.assert_array_equal(out, np.concatenate([x] * num_processes)) + + x = jnp.ones((64,)).reshape((8, 4, 2)) + out = multihost_utils.process_allgather(x, tiled=True) + self.assertEqual(out.shape, (8 * num_processes, 4, 2)) + np.testing.assert_array_equal(out, np.concatenate([x] * num_processes)) + + x = np.arange(elems_per_host) + jax.process_index() * elems_per_host + out = multihost_utils.process_allgather(x, tiled=True) + self.assertEqual(out.shape, (elems_per_host * num_processes,)) + np.testing.assert_array_equal( + out, np.arange(elems_per_host * jax.process_count()) + ) + + x = np.array(0) + jax.process_index() * elems_per_host + out = multihost_utils.process_allgather(x, tiled=True) + self.assertEqual(out.shape, (num_processes,)) + np.testing.assert_array_equal( + out, np.arange(num_processes) * elems_per_host + ) + + def test_process_allgather_set_mesh(self): + devices = jax.devices()[1:] + [jax.devices()[0]] + user_mesh = jax.sharding.Mesh( + np.array(devices).reshape(jax.device_count(), 1, 1), + ('x', 'y', 'z'), + ) + x = jnp.ones((4,)).reshape((2, 2)) + # process_allgather should not be impacted by any global mesh context. + with jax.set_mesh(user_mesh): + num_processes = jax.process_count() + out = multihost_utils.process_allgather(x, tiled=True) + self.assertEqual(out.shape, (2 * num_processes, 2)) + np.testing.assert_array_equal(out, np.concatenate([x] * num_processes)) + + @jtu.ignore_warning( + category=DeprecationWarning, + message='jax.sharding.PmapSharding is deprecated', + ) + def test_broadcast_one_to_all(self): + elems_per_host = 4 + + x = np.arange(elems_per_host) + jax.process_index() * elems_per_host + out = multihost_utils.broadcast_one_to_all((x, x)) + jax.tree.map( + lambda x: np.testing.assert_array_equal( # pylint: disable=g-long-lambda + x, np.arange(elems_per_host) + ), + out, + ) + + x = np.array(0) + jax.process_index() * elems_per_host + out = multihost_utils.broadcast_one_to_all(x) + np.testing.assert_array_equal(out, np.array(0)) + + @jtu.ignore_warning( + category=DeprecationWarning, + message='jax.sharding.PmapSharding is deprecated', + ) + def test_broadcast_one_to_all_set_mesh(self): + devices = jax.devices()[1:] + [jax.devices()[0]] + user_mesh = jax.sharding.Mesh( + np.array(devices).reshape(jax.device_count(), 1, 1), + ('x', 'y', 'z'), + ) + # broadcast_one_to_all should not be impacted by any global mesh context. + with jax.set_mesh(user_mesh): + elems_per_host = 4 + + x = np.arange(elems_per_host) + jax.process_index() * elems_per_host + out = multihost_utils.broadcast_one_to_all((x, x)) + jax.tree.map( + lambda x: np.testing.assert_array_equal( # pylint: disable=g-long-lambda + x, np.arange(elems_per_host) + ), + out, + ) + + x = np.array(0) + jax.process_index() * elems_per_host + out = multihost_utils.broadcast_one_to_all(x) + np.testing.assert_array_equal(out, np.array(0)) + + @jtu.ignore_warning( + category=DeprecationWarning, + message='jax.sharding.PmapSharding is deprecated', + ) + def test_broadcast_one_to_all_uint8(self): + elems_per_host = 4 + + x = (np.arange(elems_per_host, dtype=jnp.uint8) + + jax.process_index() * elems_per_host) + out = multihost_utils.broadcast_one_to_all((x, x)) + jax.tree.map( + lambda x: np.testing.assert_array_equal( # pylint: disable=g-long-lambda + x, np.arange(elems_per_host, dtype=jnp.uint8) + ), + out, + ) + jax.tree.map(lambda o: self.assertEqual(o.dtype, jnp.uint8), out) + + x = np.array(0, dtype=jnp.uint8) + jax.process_index() * elems_per_host + out = multihost_utils.broadcast_one_to_all(x) + self.assertEqual(out.dtype, jnp.uint8) + np.testing.assert_array_equal(out, np.array(0, dtype=jnp.uint8)) + + def test_sync_global_devices(self): + multihost_utils.sync_global_devices('test sync global devices') + + def test_sync_global_devices_error(self): + # All processes should raise. + with self.assertRaises(AssertionError): + if jax.process_index() == 0: + multihost_utils.sync_global_devices('test message') + else: + multihost_utils.sync_global_devices('test message2') + + def test_sync_global_devices_mesh_context_manager(self): + global_mesh = jtu.create_mesh((2, 2), ('x', 'y'), iota_order=True) + with global_mesh: + multihost_utils.sync_global_devices('test sync global devices') + + def test_assert_equal_global(self): + mesh = jtu.create_mesh((8,), 'x') + shape = (8, 2) + np_inp = np.arange(math.prod(shape)).reshape(shape) + inp = jax.make_array_from_callback( + shape, jax.NamedSharding(mesh, P()), lambda idx: np_inp[idx]) + multihost_utils.assert_equal(inp) + + def test_process_allgather_cache_hit(self): + x = jnp.ones((4,)).reshape(2, 2) + y = jnp.arange(4.0).reshape(2, 2) + + num_processes = jax.process_count() + with jtu.count_pjit_cpp_cache_miss() as count: + out = multihost_utils.process_allgather(x, tiled=False) + out2 = multihost_utils.process_allgather(y, tiled=False) + + # Cpp cache hit. + self.assertEqual(count(), 1) + + self.assertEqual(out.shape, (num_processes, 2, 2)) + np.testing.assert_array_equal(out, np.stack([x] * num_processes)) + self.assertEqual(out2.shape, (num_processes, 2, 2)) + np.testing.assert_array_equal(out2, np.stack([y] * num_processes)) + + def test_reshard(self): + mesh1 = jtu.create_mesh((8,), 'x') + mesh2 = jax.sharding.Mesh( + np.asarray(jax.devices()[::-1]).reshape(4, 2), ('x', 'y') + ) + + shape = (8, 2) + np_inp = np.arange(math.prod(shape)).reshape(shape) + inp = jax.make_array_from_callback( + shape, + jax.sharding.NamedSharding(mesh1, P('x')), + lambda idx: np_inp[idx], + ) + + out = jax.device_put(inp, jax.sharding.NamedSharding(mesh2, P('x', 'y'))) + self.assertIsInstance(out.sharding, jax.sharding.NamedSharding) + for s in out.addressable_shards: + np.testing.assert_array_equal(s.data, np_inp[s.index]) + + @parameterized.named_parameters( + ('inp_replicated', P(), P('x', 'y')), + ('target_replicated', P('x'), P()), + ('both_replicated', P(), P()), + ) + def test_reshard_replicated_sharding(self, inp_spec, target_spec): + mesh1 = jtu.create_mesh((8,), 'x') + mesh2 = jax.sharding.Mesh( + np.asarray(jax.devices()[::-1]).reshape(4, 2), ('x', 'y') + ) + + shape = (8, 2) + np_inp = np.arange(math.prod(shape)).reshape(shape) + inp = jax.make_array_from_callback( + shape, + jax.sharding.NamedSharding(mesh1, inp_spec), + lambda idx: np_inp[idx], + ) + + out = jax.device_put(inp, jax.sharding.NamedSharding(mesh2, target_spec)) + self.assertIsInstance(out.sharding, jax.sharding.NamedSharding) + for s in out.addressable_shards: + np.testing.assert_array_equal(s.data, np_inp[s.index]) + + def test_reshard_same_device_assignment(self): + mesh1 = jtu.create_mesh((4, 2), ('x', 'y')) + mesh2 = jtu.create_mesh((2, 4), ('x', 'y')) + + shape = (8, 2) + np_inp = np.arange(math.prod(shape)).reshape(shape) + inp = jax.make_array_from_callback( + shape, + jax.sharding.NamedSharding(mesh1, P('x', 'y')), + lambda idx: np_inp[idx], + ) + + out = jax.device_put(inp, jax.sharding.NamedSharding(mesh2, P('y'))) + self.assertIsInstance(out.sharding, jax.sharding.NamedSharding) + for s in out.addressable_shards: + np.testing.assert_array_equal(s.data, np_inp[s.index]) + + def test_reshard_pytree(self): + mesh1 = jtu.create_mesh((8,), 'x') + + dev = jax.devices() + if len(dev) < 8: + raise unittest.SkipTest('Test requires 8 devices') + dev_list = [dev[0], dev[7], dev[6], dev[2], dev[4], dev[3], dev[5], dev[1]] + mesh2 = jax.sharding.Mesh( + np.asarray(dev_list).reshape(2, 2, 2), ('x', 'y', 'z') + ) + + shape = (8, 2) + np_inp = np.arange(math.prod(shape)).reshape(shape) + inp = jax.make_array_from_callback( + shape, + jax.sharding.NamedSharding(mesh1, P('x')), + lambda idx: np_inp[idx], + ) + + out1, out2 = jax.device_put( + (inp, inp), jax.sharding.NamedSharding(mesh2, P('x', 'y')) + ) + + for out in (out1, out2): + self.assertIsInstance(out.sharding, jax.sharding.NamedSharding) + for s in out.addressable_shards: + np.testing.assert_array_equal(s.data, np_inp[s.index]) + + def test_reshard_different_devices(self): + if jtu.is_device_tpu('5', 'e'): + raise unittest.SkipTest('Test fails on v5e') + dev = jax.devices() + if len(dev) < 8: + raise unittest.SkipTest('Test requires 8 devices') + mesh1 = jax.sharding.Mesh([dev[0], dev[2], dev[4], dev[6]], 'x') + mesh2 = jax.sharding.Mesh(jax.devices(), 'x') + + shape = (8, 2) + np_inp = np.arange(math.prod(shape)).reshape(shape) + inp = jax.make_array_from_callback( + shape, + jax.sharding.NamedSharding(mesh1, P('x')), + lambda idx: np_inp[idx], + ) + + with self.assertRaisesRegex( + ValueError, + 'input and target sharding should have the same set of devices', + ): + jax.device_put(inp, jax.sharding.NamedSharding(mesh2, P('x'))) + + def test_process_allgather_array_not_fully_addressable(self): + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) + global_input_shape = (8, 2) + global_input_data = np.arange(math.prod(global_input_shape)).reshape( + global_input_shape + ) + + arr = jax.make_array_from_callback( + global_input_shape, + jax.sharding.NamedSharding(global_mesh, P('x', 'y')), + lambda idx: global_input_data[idx], + ) + + out = multihost_utils.process_allgather(arr, tiled=True) + np.testing.assert_array_equal(out, global_input_data) + + with self.assertRaisesRegex( + ValueError, + 'Gathering global non-fully-addressable arrays only supports' + ' tiled=True'): + multihost_utils.process_allgather(arr, tiled=False) + + @jtu.ignore_warning( + category=DeprecationWarning, + message='jax.sharding.PmapSharding is deprecated', + ) + def test_host_local_array_to_global_array_already_global(self): + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) + global_input_shape = (8, 2) + global_input_data = np.arange(math.prod(global_input_shape)).reshape( + global_input_shape + ) + + arr = jax.make_array_from_callback( + global_input_shape, + jax.sharding.NamedSharding(global_mesh, P('x', 'y')), + lambda idx: global_input_data[idx], + ) + + out = multihost_utils.host_local_array_to_global_array( + arr, global_mesh, P('x', 'y') + ) + + self.assertEqual(id(arr), id(out)) + + @jtu.ignore_warning( + category=DeprecationWarning, + message='jax.sharding.PmapSharding is deprecated', + ) + def test_host_local_array_to_global_array_same_sharding_array(self): + if jtu.is_device_tpu('5', 'e'): + raise unittest.SkipTest('Test fails on v5e') + global_mesh = jtu.create_mesh((4, 2), ('x', 'y'), iota_order=True) + local_input_shape = (2, 2) + + elems_per_host = 4 + local_input_data = ( + jnp.arange(elems_per_host) + jax.process_index() * elems_per_host + ).reshape(local_input_shape) + + arr = jax.make_array_from_callback( + local_input_shape, + jax.sharding.NamedSharding(global_mesh.local_mesh, P('x', 'y')), + lambda idx: local_input_data[idx], + ) + + out = multihost_utils.host_local_array_to_global_array( + arr, global_mesh, P('x', 'y') + ) + + expected_global_shape = (8, 2) + self.assertEqual(out.shape, expected_global_shape) + + global_data = np.arange(math.prod(expected_global_shape)).reshape( + expected_global_shape + ) + for a, o in zip(arr.addressable_shards, out.addressable_shards): + self.assertEqual( + a.data.unsafe_buffer_pointer(), o.data.unsafe_buffer_pointer() + ) + np.testing.assert_array_equal(o.data, global_data[o.index]) + + @jtu.ignore_warning( + category=DeprecationWarning, + message='jax.sharding.PmapSharding is deprecated', + ) + def test_host_local_to_global_reshard_committed_single_device_array(self): + if jtu.is_device_tpu('5', 'e'): + raise unittest.SkipTest('Test fails on v5e') + global_mesh = jtu.create_mesh((4, 2), ('x', 'y'), iota_order=True) + local_input_shape = (2, 2) + + elems_per_host = 4 + local_input_data = ( + jnp.arange(elems_per_host) + jax.process_index() * elems_per_host + ).reshape(local_input_shape) + + arr = jax.make_array_from_callback( + local_input_shape, + jax.sharding.NamedSharding(global_mesh.local_mesh, P('x', 'y')), + lambda idx: local_input_data[idx], + ) + + out = multihost_utils.host_local_array_to_global_array( + arr, global_mesh, P('x', 'y') + ) + + expected_global_shape = (8, 2) + self.assertEqual(out.shape, expected_global_shape) + + global_data = np.arange(math.prod(expected_global_shape)).reshape( + expected_global_shape + ) + for a, o in zip(arr.addressable_shards, out.addressable_shards): + self.assertEqual( + a.data.unsafe_buffer_pointer(), o.data.unsafe_buffer_pointer() + ) + np.testing.assert_array_equal(o.data, global_data[o.index]) + + @jtu.ignore_warning(category=DeprecationWarning) + def test_host_local_to_global_replicated(self): + num_local_devices = jax.local_device_count() + global_mesh = jax.sharding.Mesh(jax.devices(), axis_names=['x']) + local_input_shape = (2, 2) + local_input_data = jnp.arange(4).reshape(local_input_shape) + + out = multihost_utils.host_local_array_to_global_array( + local_input_data, global_mesh, P() + ) + + expected_global_shape = (2, 2) + self.assertEqual(out.shape, expected_global_shape) + self.assertLen(out.addressable_shards, num_local_devices) + # Array is accessible on every host. + np.testing.assert_array_equal(out, local_input_data) + + @jtu.ignore_warning(category=DeprecationWarning) + def test_host_local_to_global_locally_replicated(self): + # Make an array which is locally replicated but sharded across hosts. + num_processes = jax.process_count() + num_local_devices = jax.local_device_count() + global_mesh = jtu.create_mesh( + (num_processes, num_local_devices), ('host', 'dev'), iota_order=True) + local_input_shape = (2, 2) + host_id = jax.process_index() + local_input_data = jnp.arange(4).reshape(local_input_shape) * host_id + + out = multihost_utils.host_local_array_to_global_array( + local_input_data, global_mesh, P('host', None)) + global_data = np.concatenate([jnp.arange(4).reshape(local_input_shape) * i + for i in range(num_processes)]) + expected_global_shape = global_data.shape + self.assertEqual(out.shape, expected_global_shape) + self.assertLen(out.addressable_shards, num_local_devices) + for o in out.addressable_shards: + # Each shard has the same shape matching local_input_shape and smae + # global index. + self.assertEqual(o.data.shape, local_input_shape) + self.assertEqual(o.index, out.addressable_shards[0].index) + np.testing.assert_array_equal(o.data, global_data[o.index]) + + @jtu.ignore_warning( + category=DeprecationWarning, + message='jax.sharding.PmapSharding is deprecated', + ) + def test_global_array_to_host_local_array(self): + if jtu.is_device_tpu('5', 'e'): + raise unittest.SkipTest('Test fails on v5e') + global_mesh = jtu.create_mesh((4, 2), ('x', 'y'), iota_order=True) + global_shape = (8, 2) + global_data = np.arange(math.prod(global_shape)).reshape(global_shape) + + arr = jax.make_array_from_callback( + global_shape, + jax.sharding.NamedSharding(global_mesh, P('x', 'y')), + lambda idx: global_data[idx], + ) + + out = multihost_utils.global_array_to_host_local_array( + arr, global_mesh, P('x') + ) + + self.assertEqual(out.shape, (2, 2)) + self.assertEqual( + out.sharding, jax.sharding.NamedSharding(global_mesh.local_mesh, P('x')) + ) + + local_input_data = (np.arange(4) + jax.process_index() * 4).reshape( + out.shape + ) + for s in out.addressable_shards: + np.testing.assert_array_equal(s.data, local_input_data) + + def test_host_local_array_to_global_array_none_error(self): + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) + global_shape = (8, 2) + data = np.arange(math.prod(global_shape)).reshape(global_shape) + + with self.assertRaisesRegex( + ValueError, '`None` is not a valid input to the pspecs argument' + ): + multihost_utils.host_local_array_to_global_array(data, global_mesh, None) + + with self.assertRaisesRegex( + ValueError, '`None` is not a valid input to the pspecs argument' + ): + multihost_utils.global_array_to_host_local_array(data, global_mesh, None) + + def test_live_devices(self): + with multihost_utils.live_devices(jax.devices()) as live: + self.assertEqual(set(live), set(jax.devices())) + + +if __name__ == '__main__': + jt_multiprocess.main() diff --git a/tests/multiprocess/pgle_test.py b/tests/multiprocess/pgle_test.py new file mode 100644 index 000000000000..775a94c030e8 --- /dev/null +++ b/tests/multiprocess/pgle_test.py @@ -0,0 +1,110 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multihost tests for pgle.""" + +import functools +import math +import os +import tempfile + +from absl.testing import parameterized +import jax +from jax._src import config +from jax._src import test_multiprocess as jt_multiprocess +from jax._src import test_util as jtu +import jax.numpy as jnp +from jax.sharding import NamedSharding +from jax.sharding import PartitionSpec +import numpy as np + + +class PgleTestMultiHost(jt_multiprocess.MultiProcessTest): + + def get_fdo_profiles(self, dump_dir): + jit_f_fdo_profiles = [ + x + for x in os.listdir(dump_dir) + if 'jit_f' in x and x.endswith('.fdo_profile') + ] + return jit_f_fdo_profiles + + @parameterized.parameters(True, False) + def testAutoPGLE(self, use_compilation_cache: bool): + mesh = jtu.create_mesh((jax.device_count(),), ('x',)) + + its = 500 + + with tempfile.TemporaryDirectory() as dump_dir: + + @functools.partial( + jax.jit, + in_shardings=NamedSharding(mesh, PartitionSpec('x')), + out_shardings=NamedSharding(mesh, PartitionSpec('x')), + compiler_options={ + 'xla_gpu_enable_latency_hiding_scheduler': 'True', + # TODO(patrios): Remove this flag once b/376647494 is fixed. + 'xla_gpu_graph_min_graph_size': '100000', + 'xla_dump_to': dump_dir, + 'xla_gpu_experimental_dump_fdo_profiles': 'True', + }, + ) + def f(x): + agg = x + for _ in range(its): + agg = agg @ x + return agg + + shape = (16, 16) + x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32) + + num_runs = 2 + with ( + config.pgle_profiling_runs(num_runs), + config.enable_pgle(True), + config.enable_compilation_cache(use_compilation_cache), + config.raise_persistent_cache_errors(True), + config.raise_persistent_cache_errors(True), + config.persistent_cache_min_entry_size_bytes(0), + config.persistent_cache_min_compile_time_secs(0), + ): + for _ in range(num_runs): + f(x) + + # There should be 3 fdo profiles: before optimization, after + # SPMD-partitioning, and after optimization. + fdo_profiles_before_pgle = self.get_fdo_profiles(dump_dir) + self.assertLen(fdo_profiles_before_pgle, 3) + self.assertEqual( + os.path.getsize( + os.path.join(dump_dir, fdo_profiles_before_pgle[0]) + ), + 0, + ) + + # Should recompile with the FDO profile. + f(x) + + # Expect 3 additional non-empty fdo profiles. + fdo_profiles_after_pgle = self.get_fdo_profiles(dump_dir) + self.assertLen(fdo_profiles_after_pgle, 6) + for fdo_profile in fdo_profiles_after_pgle: + if fdo_profile not in fdo_profiles_before_pgle: + self.assertGreater( + os.path.getsize(os.path.join(dump_dir, fdo_profile)), 0 + ) + + +if __name__ == '__main__': + jt_multiprocess.main() diff --git a/tests/multiprocess/pjit_test.py b/tests/multiprocess/pjit_test.py new file mode 100644 index 000000000000..79c0721ab66b --- /dev/null +++ b/tests/multiprocess/pjit_test.py @@ -0,0 +1,556 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multihost tests for pjit.""" + +import collections +from concurrent import futures +import contextlib +import functools +import io +import math +import unittest + +import jax +from jax import numpy as jnp +from jax._src import array +from jax._src import debugging +from jax._src import pjit +from jax._src import test_multiprocess as jt_multiprocess +from jax._src import test_util as jtu +from jax.sharding import PartitionSpec as P +import numpy as np + +X_SIZE = 2 +Y_SIZE = 2 +CHIPS_SIZE = 2 +ALL_AXES = ("x", "y", "chips") + + +def sorted_devices(): + devices = sorted( + jax.devices(), key=lambda d: (d.host_id, getattr(d, "core_on_chip", 0)) + ) + if len(devices) != 8: + raise unittest.SkipTest("Test assumes that it runs on a TPU donut") + return devices + + +@contextlib.contextmanager +def use_default_mesh(): + devices = sorted_devices() + mesh_devices = np.array(devices).reshape((X_SIZE, Y_SIZE, CHIPS_SIZE)) + with jax.sharding.Mesh(mesh_devices, ("x", "y", "chips")): + yield + + +def create_2d_non_contiguous_mesh(): + devices = sorted_devices() + device_mesh = np.array([ + [devices[0], devices[2]], + [devices[3], devices[1]], + [devices[4], devices[6]], + [devices[7], devices[5]], + ]) + # On TPUv3, the mesh looks like this (the integers are process index): + # 0 1 + # 1 0 + # 2 3 + # 3 2 + return jax.sharding.Mesh(device_mesh, ("x", "y")) + + +def create_2d_non_contiguous_mesh2(): + devices = sorted_devices() + device_mesh = np.array([ + [devices[0], devices[2]], + [devices[1], devices[3]], + [devices[4], devices[6]], + [devices[5], devices[7]], + ]) + # On TPUv3, the mesh looks like this (the integers are process index): + # 0 1 + # 0 1 + # 2 3 + # 2 3 + return jax.sharding.Mesh(device_mesh, ("x", "y")) + + +# TODO(apaszke): Test with mesh that has host-tiled axes (especially nesting!) +class PJitTestMultiHost(jt_multiprocess.MultiProcessTest): + + @jtu.ignore_warning(category=DeprecationWarning) + def testLocalInputsWithJaxArray(self): + # Note that this is too small to shard over the global mesh, but fine for + # the local mesh and so should be accepted. + mesh = jtu.create_mesh((4, 2), ("x", "y")) + elems_per_host = 4 + x = jnp.arange(elems_per_host) + jax.process_index() * elems_per_host + iar = jax.sharding.PartitionSpec("x") + oar = jax.sharding.PartitionSpec("x") + with mesh: + f = pjit.pjit(lambda x, y: (x, y), in_shardings=iar, out_shardings=oar) + gx = jax.experimental.multihost_utils.host_local_array_to_global_array( + (x, x), mesh, iar + ) + global_out = f(*gx) + out1, out2 = ( + jax.experimental.multihost_utils.global_array_to_host_local_array( + global_out, mesh, oar + ) + ) + np.testing.assert_array_equal(out1, x) + np.testing.assert_array_equal(out2, x) + + +class ArrayPjitMultiHost(jt_multiprocess.MultiProcessTest): + + def test_pjit_array_single_output(self): + global_mesh = jtu.create_mesh((4, 2), ("x", "y")) + global_input_shape = (8, 2) + mesh_axes = jax.sharding.PartitionSpec("x", "y") + global_input_data = np.arange(math.prod(global_input_shape)).reshape( + global_input_shape + ) + s = jax.sharding.NamedSharding(global_mesh, mesh_axes) + + arr = array.make_array_from_callback( + global_input_shape, s, lambda idx: global_input_data[idx] + ) + + @functools.partial(pjit.pjit, out_shardings=s) + def f(x): + return x @ x.T + + expected_matrix_mul = global_input_data @ global_input_data.T + + out = f(arr) + self.assertIsInstance(out, array.ArrayImpl) + self.assertEqual(out.shape, (8, 8)) + self.assertEqual(out.addressable_shards[0].data.shape, (2, 4)) + for s in out.addressable_shards: + np.testing.assert_array_equal( + np.asarray(s.data), expected_matrix_mul[s.index] + ) + + # Test does not work with non-contiguous device IDs. + @jtu.skip_on_devices("cpu") + def test_pjit_array_non_contiguous_mesh_2d(self): + global_mesh = create_2d_non_contiguous_mesh() + global_input_shape = (8, 2) + pspec = jax.sharding.PartitionSpec("x", "y") + input_data = np.arange(math.prod(global_input_shape)).reshape( + global_input_shape + ) + in_sharding = jax.sharding.NamedSharding(global_mesh, pspec) + out_sharding = jax.sharding.NamedSharding(global_mesh, pspec) + + a1 = array.make_array_from_callback( + global_input_shape, in_sharding, lambda idx: input_data[idx] + ) + + # device_id -> (index, replica_id) + expected_idx_rid = { + 0: ((slice(0, 2), slice(0, 1)), 0), + 1: ((slice(2, 4), slice(1, 2)), 0), + 2: ((slice(0, 2), slice(1, 2)), 0), + 3: ((slice(2, 4), slice(0, 1)), 0), + 4: ((slice(4, 6), slice(0, 1)), 0), + 5: ((slice(6, 8), slice(1, 2)), 0), + 6: ((slice(4, 6), slice(1, 2)), 0), + 7: ((slice(6, 8), slice(0, 1)), 0), + } + + with global_mesh: + f = pjit.pjit(lambda x: x, out_shardings=out_sharding) + out = f(a1) + + for s in out.addressable_shards: + device_id = s.device.id + expected_index = expected_idx_rid[device_id][0] + expected_replica_id = expected_idx_rid[device_id][1] + self.assertEqual(s.index, expected_index) + self.assertEqual(s.replica_id, expected_replica_id) + self.assertEqual(s.data.shape, (2, 1)) + np.testing.assert_array_equal(s.data._value, input_data[expected_index]) + + with global_mesh: + f = pjit.pjit(lambda x: x) + out = f(a1) + + for s in out.addressable_shards: + device_id = s.device.id + expected_index = expected_idx_rid[device_id][0] + expected_replica_id = expected_idx_rid[device_id][1] + self.assertEqual(s.index, expected_index) + self.assertEqual(s.replica_id, expected_replica_id) + self.assertEqual(s.data.shape, (2, 1)) + np.testing.assert_array_equal(s.data._value, input_data[expected_index]) + + none_sharding = jax.sharding.NamedSharding( + global_mesh, jax.sharding.PartitionSpec(None) + ) + + with global_mesh: + f = pjit.pjit( + lambda x: x, in_shardings=none_sharding, out_shardings=out_sharding + ) + # Fully replicated values allows a non-contiguous mesh. + out = f(input_data) + self.assertIsInstance(out, array.ArrayImpl) + + a2 = array.make_array_from_callback( + global_input_shape, none_sharding, lambda idx: input_data[idx] + ) + + with global_mesh: + f = pjit.pjit( + lambda x, y: (x, y), + in_shardings=(none_sharding, none_sharding), + out_shardings=(out_sharding, out_sharding), + ) + # Fully replicated values + Array allows a non-contiguous mesh. + out1, out2 = f(input_data, a2) + self.assertIsInstance(out1, array.ArrayImpl) + self.assertIsInstance(out2, array.ArrayImpl) + + def test_sharded_add(self): + global_mesh = create_2d_non_contiguous_mesh() + input_shape = (8, 2) + input_data = np.arange(math.prod(input_shape)).reshape(input_shape) + a_s = jax.sharding.NamedSharding(global_mesh, P("x", "y")) + b_s = jax.sharding.NamedSharding(global_mesh, P("x")) + + a = array.make_array_from_callback( + input_shape, a_s, lambda idx: input_data[idx] + ) + b = array.make_array_from_callback( + input_shape, b_s, lambda idx: input_data[idx] + ) + + out = a + b + for s in out.addressable_shards: + np.testing.assert_array_equal( + s.data, (input_data + input_data)[s.index] + ) + + def test_sharded_jit_add(self): + global_mesh = create_2d_non_contiguous_mesh() + input_shape = (8, 2) + input_data = np.arange(math.prod(input_shape)).reshape(input_shape) + a_s = jax.sharding.NamedSharding(global_mesh, P("x", "y")) + b_s = jax.sharding.NamedSharding(global_mesh, P("x")) + + a = array.make_array_from_callback( + input_shape, a_s, lambda idx: input_data[idx] + ) + b = array.make_array_from_callback( + input_shape, b_s, lambda idx: input_data[idx] + ) + + out = jax.jit(lambda x, y: x + y)(a, b) + for s in out.addressable_shards: + np.testing.assert_array_equal(s.data, (input_data + input_data)[s.index]) + + def test_sharded_copy(self): + global_mesh = create_2d_non_contiguous_mesh() + input_shape = (8, 2) + input_data = np.arange(math.prod(input_shape)).reshape(input_shape) + s = jax.sharding.NamedSharding(global_mesh, P("x", "y")) + arr = array.make_array_from_callback( + input_shape, s, lambda idx: input_data[idx] + ) + # Copy the array sharded over multiple devices across multiple processes. + copy_arr = jnp.copy(arr) + + for c, a in zip(copy_arr.addressable_shards, arr.addressable_shards): + self.assertNotEqual( + c.data.unsafe_buffer_pointer(), a.data.unsafe_buffer_pointer() + ) + self.assertEqual(c.index, a.index) + self.assertEqual(c.replica_id, a.replica_id) + self.assertEqual(c.device, a.device) + np.testing.assert_array_equal(c.data, a.data) + + def test_sharded_mul(self): + global_mesh = create_2d_non_contiguous_mesh() + input_shape = (8, 2) + input_data = np.arange(math.prod(input_shape)).reshape(input_shape) + a_s = jax.sharding.NamedSharding(global_mesh, P("x", "y")) + + a = array.make_array_from_callback( + input_shape, a_s, lambda idx: input_data[idx] + ) + + out = a @ a.T + for s in out.addressable_shards: + np.testing.assert_array_equal( + s.data, (input_data @ input_data.T)[s.index] + ) + + def test_pjit_array_eval_shape(self): + with jtu.create_mesh((8,), "x"): + + @functools.partial( + pjit.pjit, + in_shardings=jax.sharding.PartitionSpec(None), + out_shardings=jax.sharding.PartitionSpec("x"), + ) + def f(): + return jnp.zeros([32, 10]) + + self.assertEqual(f().shape, (32, 10)) + self.assertEqual(jax.eval_shape(f).shape, (32, 10)) + + def test_trace_with_global_avals(self): + devices = sorted_devices() + mesh_devices = np.array(devices[::2] + devices[1::2]) + # The device order in the below mesh is: + # (0, 2, 4, 6, 1, 3, 5, 7) + # each having the following process index: + # (0, 1, 2, 3, 0, 1, 2, 3) + # self.assertListEqual([d.process_index for d in mesh_devices], + # [0, 1, 2, 3, 0, 1, 2, 3]) + global_mesh = jax.sharding.Mesh(mesh_devices, ("x",)) + x = jnp.arange(16) + + def check_shape(x): + self.assertEqual(x.shape, (16,)) + return x + + with global_mesh: + f = pjit.pjit( + check_shape, + in_shardings=jax.sharding.PartitionSpec("x"), + out_shardings=None, + ) + np.testing.assert_array_equal(f(x), jnp.arange(16)) + + @use_default_mesh() + def test_pjit_in_pjit(self): + # The global size of x is 16. The shape should remain constant i.e. (16,) + # within all `pjit`'s since with Array, pjit only accepts global shaped + # inputs and doesn't lift the shape. + x = jnp.arange(16) + + def pjit_all(f): + return pjit.pjit( + f, + in_shardings=jax.sharding.PartitionSpec(ALL_AXES), + out_shardings=jax.sharding.PartitionSpec(ALL_AXES), + ) + + def check_shape(x): + assert x.shape == (16,) + return x + + pjit_all(check_shape)(x) + pjit_all(pjit_all(check_shape))(x) + pjit_all(pjit_all(pjit_all(check_shape)))(x) + + def test_compile_parallel(self): + x = jnp.arange(16) + global_mesh = jtu.create_mesh((4, 2), ("x", "y")) + + def _lower_compile(inp): + with global_mesh: + f = pjit.pjit( + lambda x: x.sum(), + in_shardings=jax.sharding.PartitionSpec("x"), + out_shardings=None, + ) + exe = f.lower(inp).compile() + return exe + + with futures.ThreadPoolExecutor(max_workers=5) as executor: + result = executor.map(_lower_compile, [x] * 5) + + expected_out = np.arange(16).sum() + + for out in list(result): + np.testing.assert_array_equal(out(x), expected_out) + + def test_fully_sharded_on_all_devices(self): + if jax.local_device_count() > 1: + self.skipTest("This test only works with 1 process per device.") + num_devices = jax.device_count() + x = jnp.arange(num_devices) + global_mesh = jtu.create_mesh((num_devices,), "x") + with global_mesh: + f = pjit.pjit( + lambda x: x, + in_shardings=jax.sharding.PartitionSpec("x"), + out_shardings=jax.sharding.PartitionSpec("x"), + ) + out = f(x) + expected_out = np.arange(num_devices) + for s in out.addressable_shards: + np.testing.assert_array_equal(s.data, expected_out[s.index]) + + def test_on_device_size_in_bytes(self): + global_mesh = create_2d_non_contiguous_mesh() + input_shape = (8, 2) + input_data = np.arange(math.prod(input_shape)).reshape(input_shape) + a_s = jax.sharding.NamedSharding(global_mesh, P("x", "y")) + + a = array.make_array_from_callback( + input_shape, a_s, lambda idx: input_data[idx] + ) + shard_size = a.addressable_shards[0].data.on_device_size_in_bytes() + self.assertGreaterEqual(shard_size, 4 * 2) + self.assertEqual( + shard_size * len(a.global_shards), a.on_device_size_in_bytes() + ) + + def test_numpy_input_error_with_non_trivial_sharding(self): + global_mesh = jtu.create_mesh((8,), "x") + inp = np.arange(8) + with global_mesh: + f = pjit.pjit( + lambda x: x, + in_shardings=jax.sharding.PartitionSpec(None), + out_shardings=jax.sharding.PartitionSpec(None), + ) + out = f(inp) + np.testing.assert_array_equal(out, inp) + + # If no in_axis_resources are specified, then pjit assumes that the + # numpy input is fully replicated. + f = pjit.pjit(lambda x: x, out_shardings=jax.sharding.PartitionSpec(None)) + out = f(inp) + np.testing.assert_array_equal(out, inp) + + f = pjit.pjit( + lambda x: x, + in_shardings=jax.sharding.PartitionSpec("x"), + out_shardings=jax.sharding.PartitionSpec("x"), + ) + with self.assertRaisesRegex( + ValueError, + "Passing non-trivial shardings for numpy inputs is not allowed", + ): + f(inp) + + def test_non_contiguous_mesh_fetch_to_host(self): + if jax.local_device_count() != 2: + raise unittest.SkipTest("Test assumes 2 devices per process") + global_mesh = create_2d_non_contiguous_mesh() + input_shape = (8, 2) + input_data = np.arange(math.prod(input_shape)).reshape(input_shape) + a_s = jax.sharding.NamedSharding(global_mesh, P(None, "y")) + a = array.make_array_from_callback( + input_shape, a_s, lambda idx: input_data[idx] + ) + np.testing.assert_array_equal(a, input_data) + + def test_non_contiguous_mesh_fetch_to_host2(self): + global_mesh = create_2d_non_contiguous_mesh2() + input_shape = (8, 2) + input_data = np.arange(math.prod(input_shape)).reshape(input_shape) + a_s = jax.sharding.NamedSharding(global_mesh, P(None, "y")) + a = array.make_array_from_callback( + input_shape, a_s, lambda idx: input_data[idx] + ) + with self.assertRaisesRegex( + RuntimeError, + r"Fetching value for `jax.Array` that spans non-addressable \(non" + r" process local\) devices is not possible", + ): + _ = a._value + + def test_no_python_shard_arg_fallback(self): + global_mesh = jtu.create_mesh((4, 2), ("x", "y")) + input_shape = (8, 2) + input_data = np.arange(math.prod(input_shape)).reshape(input_shape) + a_s = jax.sharding.NamedSharding(global_mesh, P("x", "y")) + arr = array.make_array_from_callback( + input_shape, a_s, lambda idx: input_data[idx]) + + @jax.jit + def f(x): + return x * 2 + + with jtu.count_jax_array_shard_arg_calls() as count: + f(arr) + f(arr) + self.assertEqual(count(), 1) + + +@contextlib.contextmanager +def capture_stdout(): + with unittest.mock.patch("sys.stdout", new_callable=io.StringIO) as fp: + + def _read() -> str: + return fp.getvalue() + + yield _read + + +class MultiHostDebuggingTest(jt_multiprocess.MultiProcessTest): + + def _assert_lines_equal(self, text1, text2): + def _count(lines): + return collections.Counter(lines) + + self.assertDictEqual(_count(text1.split("\n")), _count(text2.split("\n"))) + + @use_default_mesh() + def test_print_in_multihost_pjit_array(self): + x = jnp.arange(16, dtype=jnp.int32) + + def f(x): + debugging.debug_print("{}", x, ordered=False) + return x + + f = pjit.pjit( + f, + in_shardings=jax.sharding.PartitionSpec(ALL_AXES), + out_shardings=jax.sharding.PartitionSpec(ALL_AXES), + ) + with capture_stdout() as output: + f(x) + jax.effects_barrier() + if jax.process_index() == 0: + self.assertEqual(output(), f"{np.arange(16, dtype=np.int32)}\n") + else: + self.assertEqual(output(), "") + + def test_print_in_multihost_shard_map(self): + devices = jax.devices() + mesh = jax.sharding.Mesh(devices, ("i",)) + num_devices = jax.local_device_count() + local_x = ( + jnp.arange(num_devices, dtype=jnp.int32) + + jax.process_index() * num_devices + ) + global_shape = (jax.device_count(),) + sharding = jax.NamedSharding(mesh, jax.P("i")) + global_x = jax.make_array_from_process_local_data(sharding, local_x, global_shape) + + @jax.jit + @jax.shard_map(mesh=mesh, in_specs=jax.P("i"), out_specs=jax.P("i")) + def f(x): + debugging.debug_print("{}", x[0], ordered=False) + return x + + with capture_stdout() as output: + out = f(global_x) + out.block_until_ready() + jax.effects_barrier() + + lines = [f"{i}" for i in local_x] + [""] + self._assert_lines_equal(output(), "\n".join(lines)) + +if __name__ == "__main__": + jt_multiprocess.main() diff --git a/tests/multiprocess/pmap_test.py b/tests/multiprocess/pmap_test.py new file mode 100644 index 000000000000..16737307e106 --- /dev/null +++ b/tests/multiprocess/pmap_test.py @@ -0,0 +1,97 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multihost tests for pmap.""" + +import unittest + +from absl.testing import parameterized +import jax +from jax import lax +from jax._src import array +from jax._src import test_multiprocess as jt_multiprocess +from jax._src import test_util as jtu +import jax.numpy as jnp +import numpy as np + + +def sorted_devices(): + devices = sorted( + jax.devices(), key=lambda d: (d.process_index(), d.core_on_chip)) + if len(devices) != 8: + raise unittest.SkipTest("Test assumes that it runs on a TPU donut") + return devices + + +class PmapTestMultiHost(jt_multiprocess.MultiProcessTest): + + @jtu.ignore_warning(category=DeprecationWarning) + def testBasic(self): + elems_per_host = 4 + devices = jax.local_devices() + x = [np.arange(i, i + elems_per_host) + jax.process_index() * elems_per_host + for i in range(len(devices))] + y = jax.device_put_sharded(x, devices) + f = jax.pmap(lambda x: lax.psum(x, "i"), axis_name="i") + out = f(y) + + expected_out = np.array([ + np.arange(i, i + elems_per_host) + p * elems_per_host # pylint: disable=g-complex-comprehension + for p in range(jax.process_count()) for i in range(len(devices)) + ]) + + self.assertIsInstance(out, array.ArrayImpl) + if jax.config.jax_pmap_shmap_merge: + self.assertIsInstance(out.sharding, jax.sharding.NamedSharding) + else: + self.assertIsInstance(out.sharding, jax.sharding.PmapSharding) + np.testing.assert_array_equal( + out, np.array([expected_out.sum(axis=0)] * len(devices))) + + def testLocalPmap(self): + z = jax.pmap( + lambda x: lax.axis_index("i"), + axis_name="i", + devices=jax.local_devices(), + )(np.arange(jax.local_device_count())) + np.testing.assert_array_equal(z, np.arange(jax.local_device_count())) + + @parameterized.named_parameters( + ("sharded_dim_0", 0), + ("sharded_dim_1", 1), + ) + @jtu.ignore_warning(category=DeprecationWarning) + def test_default_pmap_sharding(self, sharded_dim): + if jax.config.jax_pmap_shmap_merge: + self.skipTest("Does not apply for pmap shard_map merge") + + n = jax.local_device_count() + shape = (n, 1) if sharded_dim == 0 else (1, n) + + ps = jax.sharding.PmapSharding.default(shape, sharded_dim) + inp = jnp.arange(np.prod(shape)).reshape(shape) + compiled = jax.pmap(lambda x: x, in_axes=sharded_dim).lower(inp).compile() + pmap_in_sharding, = compiled._executable.unsafe_call.in_handler.in_shardings + + self.assertEqual(ps._device_assignment, pmap_in_sharding._device_assignment) + self.assertEqual(ps.sharding_spec, pmap_in_sharding.sharding_spec) + + def test_global_axis_size_initial_style(self): + xs = jnp.ones(jax.local_device_count()) + pmapped_f = jax.pmap(lambda x: jax.lax.all_gather(x, "i"), axis_name="i") + jaxpr = jax.make_jaxpr(pmapped_f)(xs) + jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, xs) # does not crash + +if __name__ == "__main__": + jt_multiprocess.main() diff --git a/tests/multiprocess/socket_transfer_test.py b/tests/multiprocess/socket_transfer_test.py new file mode 100644 index 000000000000..cc25d4cfe6c5 --- /dev/null +++ b/tests/multiprocess/socket_transfer_test.py @@ -0,0 +1,112 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for socket transfer.""" + +import jax +from jax._src import test_multiprocess as jt_multiprocess +from jax._src.lib import jaxlib_extension_version +from jax.sharding import PartitionSpec as P +import numpy as np + +try: + import portpicker # pytype: disable=import-error +except ImportError: + portpicker = None + + +class SocketTransferTest(jt_multiprocess.MultiProcessTest): + + def test_cross_host_transfer_single_device_sharding(self): + if jaxlib_extension_version < 397: + self.skipTest("Fails in older versions of jaxlib.") + x = np.arange(64).reshape(8, 8) + src_pid = 0 + dst_pid = 1 + src_sharding = jax.sharding.SingleDeviceSharding( + jax.local_devices(process_index=src_pid)[0]) + dst_sharding = jax.sharding.SingleDeviceSharding( + jax.local_devices(process_index=dst_pid)[0]) + y = jax.device_put(x, src_sharding) + z = jax.device_put(y, dst_sharding) + z.block_until_ready() + if jax.process_index() == dst_pid: + self.assertLen(z.addressable_shards, 1) + np.testing.assert_array_equal(z.addressable_shards[0].data, x) + else: + self.assertEmpty(z.addressable_shards) + + def test_cross_host_transfer_named_sharding(self): + if jaxlib_extension_version < 397: + self.skipTest("Fails in older versions of jaxlib.") + x = np.arange(64).reshape(8, 8) + n_local = jax.local_device_count() + src_pid = 0 + dst_pid = 1 + src_sharding = jax.sharding.NamedSharding( + jax.make_mesh((n_local,), ("x",), + devices=jax.local_devices(process_index=src_pid), + axis_types=(jax.sharding.AxisType.Explicit,)), + P("x")) + dst_sharding = jax.sharding.NamedSharding( + jax.make_mesh((n_local,), ("x",), + devices=jax.local_devices(process_index=dst_pid), + axis_types=(jax.sharding.AxisType.Explicit,)), + P("x")) + y = jax.device_put(x, src_sharding) + z = jax.device_put(y, dst_sharding) + z.block_until_ready() + if jax.process_index() == dst_pid: + self.assertLen(z.addressable_shards, n_local) + for shard in z.addressable_shards: + np.testing.assert_array_equal(shard.data, x[shard.index]) + else: + self.assertEmpty(z.addressable_shards) + + def test_cross_host_transfer_named_sharding_replicated(self): + if jaxlib_extension_version < 397: + self.skipTest("Fails in older versions of jaxlib.") + x = np.arange(64).reshape(8, 8) + n_dev = jax.device_count() // 2 + src_sharding = jax.sharding.NamedSharding( + jax.make_mesh((n_dev,), ("x",), devices=jax.devices()[:n_dev], + axis_types=(jax.sharding.AxisType.Explicit,)), + P() + ) + dst_sharding = jax.sharding.NamedSharding( + jax.make_mesh((n_dev,), ("x",), devices=jax.devices()[n_dev:], + axis_types=(jax.sharding.AxisType.Explicit,)), + P() + ) + y = jax.device_put(x, src_sharding) + z = jax.device_put(y, dst_sharding) + z.block_until_ready() + for shard in z.addressable_shards: + np.testing.assert_array_equal(shard.data, x[shard.index]) + + +if __name__ == "__main__": + if portpicker is None: + socket_port = 12345 + else: + socket_port = portpicker.pick_unused_port() + jax.config.update("jax_force_dcn_cross_host_transfers", True) + jax.config.update( + "jax_cross_host_transfer_socket_address", f"127.0.0.1:{socket_port}") + # Too small for good performance, but set to avoid oom in msan tests. + jax.config.update( + "jax_cross_host_transfer_transfer_size", + 64 * 1024, + ) + jt_multiprocess.main() diff --git a/tests/multiprocess/thread_guard_test.py b/tests/multiprocess/thread_guard_test.py new file mode 100644 index 000000000000..6f9eb74db110 --- /dev/null +++ b/tests/multiprocess/thread_guard_test.py @@ -0,0 +1,140 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test thread guard for multiprocess arrays.""" + +import concurrent.futures + +import jax +from jax._src import test_multiprocess as jt_multiprocess +from jax._src.lib import jaxlib_extension_version +import jax.numpy as jnp + + +@jax.jit +def f(x): + y = jnp.square(x) + z = jnp.cos(y) + return jnp.sum(z) + + +@jax.jit +def g(x): + y = jnp.square(x) + z = jnp.sin(y) + return z + 1 + + +class ThreadGuardTest(jt_multiprocess.MultiProcessTest): + + # Use a single test method since the thread guard affects global state and + # tests can't run in parallel. + def test_thread_guard(self): + if jaxlib_extension_version < 395: + self.skipTest( + 'Thread guard is supported only in jaxlib version >= 395. Jaxlib ' + f'version is {jaxlib_extension_version}') + + mesh = jax.make_mesh( + (jax.device_count(),), ('i',), + axis_types=(jax.sharding.AxisType.Explicit,), devices=jax.devices()) + sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('i')) + x = jnp.ones((jax.device_count(),)) + arr = jax.device_put(x, sharding) + + # Test slow JIT path. + with self.assertRaisesRegex( + (RuntimeError, ValueError), 'thread guard was set'): + with (jax.thread_guard(True), + concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor): + y = executor.submit(f, arr) + jax.block_until_ready(y.result()) + + # Test fast JIT path. + x = g(arr) + x = g(x) + with self.assertRaisesRegex( + (RuntimeError, ValueError), 'thread guard was set'): + with (jax.thread_guard(True), + concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor): + y = executor.submit(g, x) + jax.block_until_ready(y.result()) + + # Test local devices only. + mesh = jax.make_mesh( + (jax.local_device_count(),), ('i',), + axis_types=(jax.sharding.AxisType.Explicit,), + devices=jax.local_devices()) + sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('i')) + x = jnp.ones((jax.local_device_count(),)) + x = jax.device_put(x, sharding) + + with jax.thread_guard(True): + z = g(x) + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: + y = executor.submit(f, x).result() + out = y + z + jax.block_until_ready(out) # No cross-process arrays, so no error. + + # Test nested thread guard context managers. + with jax.thread_guard(True): + y = g(arr) + with jax.thread_guard(True): + z = g(y) # No error when context manager is redundantly nested. + jax.block_until_ready(z) + with jax.thread_guard(False): + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: + w = executor.submit(f, z).result() + jax.block_until_ready(w) # No error, thread guard deactivated. + + # Thread guard is re-activated by the outer context manager. + with self.assertRaisesRegex( + (RuntimeError, ValueError), 'thread guard was set'): + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: + v = executor.submit(g, w) + jax.block_until_ready(v.result()) + + # No error on the call in a different thread outside the context manager. + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: + y = executor.submit(f, x).result() + jax.block_until_ready(y) + + # Test thread guard in a subthread. + def f_with_thread_guard_should_raise(x): + with (jax.thread_guard(True), + concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor): + return executor.submit(f, x).result() + + with self.assertRaisesRegex( + (RuntimeError, ValueError), 'thread guard was set'): + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: + y = executor.submit(f_with_thread_guard_should_raise, arr).result() + jax.block_until_ready(y) + + # Test nested thread guard context managers in different threads raises. + def f_with_thread_guard(x): + with jax.thread_guard(True): + return f(x) + + with self.assertRaisesRegex( + (RuntimeError, ValueError), + 'Nested thread guards in different threads are not supported'): + with jax.thread_guard(True): + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: + y = executor.submit(f_with_thread_guard, arr).result() + jax.block_until_ready(y) + + +if __name__ == '__main__': + jt_multiprocess.main() diff --git a/tests/multiprocess/tpu_device_test.py b/tests/multiprocess/tpu_device_test.py new file mode 100644 index 000000000000..24dd08556008 --- /dev/null +++ b/tests/multiprocess/tpu_device_test.py @@ -0,0 +1,83 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test TpuDevice API on multiprocess setup.""" + +import unittest + +import jax +from jax._src import test_multiprocess as jt_multiprocess + + +class TpuDeviceTest(jt_multiprocess.MultiProcessTest): + + def test_coords(self): + for device in jax.local_devices(): + coords = device.coords + self.assertIsInstance(coords, list) + self.assertLen(coords, 3) + for coord in coords: + self.assertIsInstance(coord, int) + + def test_core(self): + for device in jax.local_devices(): + core = device.core_on_chip + self.assertIsInstance(core, int) + self.assertGreaterEqual(core, 0) + self.assertLess(core, 2) + + def test_missing_attribute(self): + for device in jax.local_devices(): + with self.assertRaises(AttributeError): + device.gpu_type # pylint: disable=pointless-statement + + def test_memory(self): + for device in jax.devices(): + device_is_local = device.process_index == jax.process_index() + self.assertLen(device.addressable_memories(), 3) + hbm = device.addressable_memories()[0] + self.assertEqual( + hbm.process_index == device.process_index, device_is_local) + self.assertEqual(hbm.platform, device.platform) + self.assertEqual(hbm.kind, 'device') + self.assertEqual(hbm, device.memory(hbm.kind)) + self.assertListEqual(hbm.addressable_by_devices(), [device]) + + host = device.addressable_memories()[1] + self.assertEqual( + host.process_index == device.process_index, device_is_local) + self.assertEqual(host.platform, device.platform) + self.assertEqual(host.kind, 'pinned_host') + self.assertEqual(host, device.memory(host.kind)) + self.assertListEqual(host.addressable_by_devices(), [device]) + + with self.assertRaisesRegex( + jax.errors.JaxRuntimeError, + 'INVALID_ARGUMENT: Could not find memory addressable by device TPU.*' + ' Device TPU.* can address the following memory kinds: device,' + ' pinned_host, unpinned_host. Got memory kind: gpu_hbm', + ): + device.memory('gpu_hbm') + + def test_host_memory_id(self): + if jax.local_device_count() < 2: + raise unittest.SkipTest('test requires 2 devices per process') + self.assertGreaterEqual(len(jax.local_devices()), 2) + host_0 = jax.local_devices()[0].memory('unpinned_host') + host_1 = jax.local_devices()[1].memory('unpinned_host') + self.assertNotEqual(id(host_0), id(host_1)) + + +if __name__ == '__main__': + jt_multiprocess.main() diff --git a/tests/multiprocess/wait_barrier_test.py b/tests/multiprocess/wait_barrier_test.py new file mode 100644 index 000000000000..d8434e227cc0 --- /dev/null +++ b/tests/multiprocess/wait_barrier_test.py @@ -0,0 +1,93 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Wait barrier test.""" + +import jax +from jax._src import distributed +from jax._src import test_multiprocess as jt_multiprocess + + +class WaitBarrierTest(jt_multiprocess.MultiProcessTest): + + def test_only_participants_call_succeeds(self): + client = distributed.global_state.client + timeout_in_ms = 1000 + + # Only even process ids will participate in the barrier. + barrier_participants = [] + for process_index in range(jax.process_count()): + if process_index % 2 == 0: + barrier_participants.append(process_index) + + if jax.process_index() % 2 == 0: + client.wait_at_barrier( + 'only_even_participants_call', + timeout_in_ms, + process_ids=barrier_participants, + ) + # This test is intended to implicitly verify that no exceptions are raised + # when the barrier is called if only the barrier participants including + # each one of them call the barrier. Thus there are no explicit assertions. + + def test_non_participant_calls_fails(self): + client = distributed.global_state.client + timeout_in_ms = 1000 + + process_group = [] + for process_index in range(jax.process_count()): + if process_index % 2 == 0: + process_group.append(process_index) + + # Processes 0, 2, 4 ... wait here. + # Processes 1, 3, 5 ... go ahead to the barrier call. + if jax.process_index() % 2 == 0: + client.blocking_key_value_get('sync', timeout_in_ms=1000) + + # 1, 3, 5 ... arrive and hit an error as they are non-participating. + # 0, 2, 4 ... has not yet arrived here. They will arrive once 1 unblocks + # them after leaving the barrier. But when they arrive at the barrier, they + # would see the error state even though they are participating. + with self.assertRaisesRegex( + jax.errors.JaxRuntimeError, + r'INVALID_ARGUMENT: A non-participating task.*' + ): + client.wait_at_barrier( + 'non_participant_calls', timeout_in_ms, process_ids=process_group + ) + + # 1 unblocks 0, 2 and 4. + if jax.process_index() == 1: + client.key_value_set('sync', 'process 1 exiting') + + def test_all_participate_succeeds(self): + client = distributed.global_state.client + timeout_in_ms = 1000 + client.wait_at_barrier('all_processes_call', timeout_in_ms) + + # This test checks that processes do wait in `wait_at_barrier`and do not + # leave until all participating processes arrive. + def test_one_participant_never_calls_fails(self): + client = distributed.global_state.client + timeout_in_ms = 1000 + + if jax.process_index() != 0: + with self.assertRaisesRegex( + jax.errors.JaxRuntimeError, r'DEADLINE_EXCEEDED: Barrier timed out.*' + ): + client.wait_at_barrier('one_participant_never_calls', timeout_in_ms) + + +if __name__ == '__main__': + jt_multiprocess.main() diff --git a/tests/multiprocess_gpu_test.py b/tests/multiprocess_gpu_test.py index fe9922148ab4..4b1723de7697 100644 --- a/tests/multiprocess_gpu_test.py +++ b/tests/multiprocess_gpu_test.py @@ -12,13 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import contextlib +import functools import os import shutil import subprocess import sys import unittest -import functools from absl.testing import absltest import numpy as np @@ -53,41 +52,47 @@ def test_gpu_distributed_initialize(self): num_gpus_per_task = 1 num_tasks = num_gpus // num_gpus_per_task - with contextlib.ExitStack() as exit_stack: - subprocesses = [] - for task in range(num_tasks): - env = os.environ.copy() - env["JAX_PORT"] = str(port) - env["NUM_TASKS"] = str(num_tasks) - env["TASK"] = str(task) - if jtu.is_device_rocm(): - env["HIP_VISIBLE_DEVICES"] = ",".join( - str((task * num_gpus_per_task) + i) for i in range(num_gpus_per_task)) - else: - env["CUDA_VISIBLE_DEVICES"] = ",".join( - str((task * num_gpus_per_task) + i) for i in range(num_gpus_per_task)) - args = [ - sys.executable, - "-c", - ('import jax, os; ' - 'jax.distributed.initialize(' - 'f\'localhost:{os.environ["JAX_PORT"]}\', ' - 'int(os.environ["NUM_TASKS"]), int(os.environ["TASK"])); ' - 'print(f\'{jax.local_device_count()},{jax.device_count()}\', end="")' - ) - ] - proc = subprocess.Popen(args, env=env, stdout=subprocess.PIPE, - stderr=subprocess.PIPE, universal_newlines=True) - subprocesses.append(exit_stack.enter_context(proc)) - - try: - for proc in subprocesses: - out, _ = proc.communicate() - self.assertEqual(proc.returncode, 0) - self.assertEqual(out, f'{num_gpus_per_task},{num_gpus}') - finally: - for proc in subprocesses: - proc.kill() + if jax.device_count() < num_gpus: + raise unittest.SkipTest( + f"Test requires >={num_gpus} GPUs; got {jax.device_count()}." + ) + + subprocesses = [] + for task in range(num_tasks): + env = os.environ.copy() + env["JAX_PORT"] = str(port) + env["NUM_TASKS"] = str(num_tasks) + env["TASK"] = str(task) + if jtu.is_device_rocm(): + env["HIP_VISIBLE_DEVICES"] = ",".join( + str((task * num_gpus_per_task) + i) for i in range(num_gpus_per_task)) + else: + env["CUDA_VISIBLE_DEVICES"] = ",".join( + str((task * num_gpus_per_task) + i) for i in range(num_gpus_per_task)) + args = [ + sys.executable, + "-c", + ('import jax, os; ' + 'jax.distributed.initialize(' + 'f\'localhost:{os.environ["JAX_PORT"]}\', ' + 'int(os.environ["NUM_TASKS"]), int(os.environ["TASK"])); ' + 'print(f\'{jax.local_device_count()},{jax.device_count()}\', end="")' + ) + ] + proc = subprocess.Popen(args, env=env, stdout=subprocess.PIPE, + stderr=subprocess.PIPE, universal_newlines=True) + subprocesses.append(self.enter_context(proc)) + + try: + for proc in subprocesses: + out, err = proc.communicate() + self.assertEqual(proc.returncode, 0, msg=f"Process failed:\n\n{out}\n\n{err}") + self.assertEqual( + out, f"{num_gpus_per_task},{num_gpus}", msg=f"Process failed:\n\n{out}\n\n{err}", + ) + finally: + for proc in subprocesses: + proc.kill() def test_distributed_jax_visible_devices(self): """Test jax_visible_devices works in distributed settings.""" @@ -99,49 +104,36 @@ def test_distributed_jax_visible_devices(self): num_gpus_per_task = 1 num_tasks = num_gpus // num_gpus_per_task - with contextlib.ExitStack() as exit_stack: - subprocesses = [] - for task in range(num_tasks): - env = os.environ.copy() - env["JAX_PORT"] = str(port) - env["NUM_TASKS"] = str(num_tasks) - env["TASK"] = str(task) - visible_devices = ",".join( - str((task * num_gpus_per_task) + i) for i in range(num_gpus_per_task)) - - if jtu.is_device_rocm(): - program = ( - 'import jax, os; ' - f'jax.config.update("jax_rocm_visible_devices", "{visible_devices}"); ' - 'jax.distributed.initialize(' - 'f\'localhost:{os.environ["JAX_PORT"]}\', ' - 'int(os.environ["NUM_TASKS"]), int(os.environ["TASK"])); ' - 's = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(jax.numpy.ones(jax.local_device_count())); ' - 'print(f\'{jax.local_device_count()},{jax.device_count()},{s}\', end=""); ' - ) - else: - program = ( - 'import jax, os; ' - f'jax.config.update("jax_cuda_visible_devices", "{visible_devices}"); ' - 'jax.distributed.initialize(' - 'f\'localhost:{os.environ["JAX_PORT"]}\', ' - 'int(os.environ["NUM_TASKS"]), int(os.environ["TASK"])); ' - 's = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(jax.numpy.ones(jax.local_device_count())); ' - 'print(f\'{jax.local_device_count()},{jax.device_count()},{s}\', end=""); ' - ) - args = [sys.executable, "-c", program] - proc = subprocess.Popen(args, env=env, stdout=subprocess.PIPE, - stderr=subprocess.PIPE, universal_newlines=True) - subprocesses.append(exit_stack.enter_context(proc)) - - try: - for proc in subprocesses: - out, err = proc.communicate() - self.assertEqual(proc.returncode, 0, msg=f"Process failed:\n\n{err}") - self.assertRegex(out, f'{num_gpus_per_task},{num_gpus},\\[{num_gpus}.\\]$') - finally: - for proc in subprocesses: - proc.kill() + subprocesses = [] + for task in range(num_tasks): + env = os.environ.copy() + env["JAX_PORT"] = str(port) + env["NUM_TASKS"] = str(num_tasks) + env["TASK"] = str(task) + visible_devices = [ + (task * num_gpus_per_task) + i for i in range(num_gpus_per_task) + ] + program = ( + 'import jax, os; ' + 'jax.distributed.initialize(' + 'f\'localhost:{os.environ["JAX_PORT"]}\', ' + f'int(os.environ["NUM_TASKS"]), int(os.environ["TASK"]), {visible_devices}); ' + 's = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(jax.numpy.ones(jax.local_device_count())); ' + 'print(f\'{jax.local_device_count()},{jax.device_count()},{s}\', end=""); ' + ) + args = [sys.executable, "-c", program] + proc = subprocess.Popen(args, env=env, stdout=subprocess.PIPE, + stderr=subprocess.PIPE, universal_newlines=True) + subprocesses.append(self.enter_context(proc)) + + try: + for proc in subprocesses: + out, err = proc.communicate() + self.assertEqual(proc.returncode, 0, msg=f"Process failed:\n\n{err}") + self.assertRegex(out, f'{num_gpus_per_task},{num_gpus},\\[{num_gpus}.\\]$') + finally: + for proc in subprocesses: + proc.kill() def test_gpu_ompi_distributed_initialize(self): if not jtu.test_device_matches(['gpu']): @@ -152,34 +144,33 @@ def test_gpu_ompi_distributed_initialize(self): num_gpus = 4 num_gpus_per_task = 1 - with contextlib.ExitStack() as exit_stack: - args = [ - 'mpirun', - '--oversubscribe', - '--allow-run-as-root', - '-n', - str(num_gpus), - sys.executable, - '-c', - ('import jax, os; ' - 'jax.distributed.initialize(); ' - 'print(f\'{jax.local_device_count()},{jax.device_count()}\' if jax.process_index() == 0 else \'\', end="")' - ) - ] - env = os.environ.copy() - # In case the job was launched via Slurm, - # prevent OpenMPI from detecting Slurm environment - env.pop('SLURM_JOBID', None) - proc = subprocess.Popen(args, env=env, stdout=subprocess.PIPE, - stderr=subprocess.PIPE, universal_newlines=True) - proc = exit_stack.enter_context(proc) - - try: - out, _ = proc.communicate() - self.assertEqual(proc.returncode, 0) - self.assertEqual(out, f'{num_gpus_per_task},{num_gpus}') - finally: - proc.kill() + args = [ + 'mpirun', + '--oversubscribe', + '--allow-run-as-root', + '-n', + str(num_gpus), + sys.executable, + '-c', + ('import jax, os; ' + 'jax.distributed.initialize(); ' + 'print(f\'{jax.local_device_count()},{jax.device_count()}\' if jax.process_index() == 0 else \'\', end="")' + ) + ] + env = os.environ.copy() + # In case the job was launched via Slurm, + # prevent OpenMPI from detecting Slurm environment + env.pop('SLURM_JOBID', None) + proc = subprocess.Popen(args, env=env, stdout=subprocess.PIPE, + stderr=subprocess.PIPE, universal_newlines=True) + proc = self.enter_context(proc) + + try: + out, _ = proc.communicate() + self.assertEqual(proc.returncode, 0) + self.assertEqual(out, f'{num_gpus_per_task},{num_gpus}') + finally: + proc.kill() def test_gpu_mpi4py_distributed_initialize(self): if not jtu.test_device_matches(['gpu']): @@ -192,34 +183,33 @@ def test_gpu_mpi4py_distributed_initialize(self): num_gpus = 4 num_gpus_per_task = 1 - with contextlib.ExitStack() as exit_stack: - args = [ - 'mpirun', - '--oversubscribe', - '--allow-run-as-root', - '-n', - str(num_gpus), - sys.executable, - '-c', - ('import jax, os; ' - 'jax.distributed.initialize(spec_detection_method="mpi4py"); ' - 'print(f\'{jax.local_device_count()},{jax.device_count()}\' if jax.process_index() == 0 else \'\', end="")' - ) - ] - env = os.environ.copy() - # In case the job was launched via Slurm, - # prevent OpenMPI from detecting Slurm environment - env.pop('SLURM_JOBID', None) - proc = subprocess.Popen(args, env=env, stdout=subprocess.PIPE, - stderr=subprocess.PIPE, universal_newlines=True) - proc = exit_stack.enter_context(proc) - - try: - out, _ = proc.communicate() - self.assertEqual(proc.returncode, 0) - self.assertEqual(out, f'{num_gpus_per_task},{num_gpus}') - finally: - proc.kill() + args = [ + 'mpirun', + '--oversubscribe', + '--allow-run-as-root', + '-n', + str(num_gpus), + sys.executable, + '-c', + ('import jax, os; ' + 'jax.distributed.initialize(spec_detection_method="mpi4py"); ' + 'print(f\'{jax.local_device_count()},{jax.device_count()}\' if jax.process_index() == 0 else \'\', end="")' + ) + ] + env = os.environ.copy() + # In case the job was launched via Slurm, + # prevent OpenMPI from detecting Slurm environment + env.pop('SLURM_JOBID', None) + proc = subprocess.Popen(args, env=env, stdout=subprocess.PIPE, + stderr=subprocess.PIPE, universal_newlines=True) + proc = self.enter_context(proc) + + try: + out, _ = proc.communicate() + self.assertEqual(proc.returncode, 0) + self.assertEqual(out, f'{num_gpus_per_task},{num_gpus}') + finally: + proc.kill() @unittest.skipIf( diff --git a/tests/mutable_array_test.py b/tests/mutable_array_test.py index e962653ed32d..9234bf57f28f 100644 --- a/tests/mutable_array_test.py +++ b/tests/mutable_array_test.py @@ -14,13 +14,19 @@ from __future__ import annotations +import unittest + from absl.testing import absltest from absl.testing import parameterized +from functools import partial +import itertools as it import numpy as np import jax from jax._src import core from jax._src import config from jax._src import test_util as jtu +from jax._src.util import safe_map, safe_zip +from jax._src.interpreters import mlir from jax.sharding import NamedSharding, PartitionSpec as P, AxisType import jax.numpy as jnp @@ -28,6 +34,12 @@ config.parse_flags_with_absl() +jtu.request_cpu_devices(8) + +map, unsafe_map = safe_map, map +zip, unsafe_zip = safe_zip, zip + + class MutableArrayTest(jtu.JaxTestCase): @parameterized.parameters([True, False]) @@ -40,7 +52,7 @@ def f(x_mut): if jit: f = jax.jit(f) - x_mut = core.mutable_array(jnp.zeros(3)) + x_mut = core.new_ref(jnp.zeros(3)) f(x_mut) self.assertAllClose(x_mut[...], jnp.array([2., 6., 1.]), @@ -49,6 +61,72 @@ def f(x_mut): jaxpr = jax.make_jaxpr(f)(x_mut) self.assertTrue(any(isinstance(e, RefEffect) for e in jaxpr.effects)) + def test_basic_aot(self): + @jax.jit + def f(x_mut): + x_mut[...] += 1. + x_mut[0] += 1 + x_mut[1] += 5 + + x_mut = core.new_ref(jnp.zeros(3)) + f.lower(x_mut).compile()(x_mut) + self.assertAllClose(x_mut[...], jnp.array([2., 6., 1.]), + check_dtypes=False) + + def test_basic_aot_closure(self): + x_mut = core.new_ref(jnp.zeros(3)) + + @jax.jit + def f(): + x_mut[...] += 1. + x_mut[0] += 1 + x_mut[1] += 5 + + c = f.lower().compile() + c() + c() + self.assertAllClose(x_mut[...], jnp.array([4., 12., 2.]), + check_dtypes=False) + + def test_basic_sharded_aot(self): + mesh = jtu.create_mesh((2,), ('x',)) + arr = jax.device_put(np.arange(8.), NamedSharding(mesh, P('x'))) + + @jax.jit + def f(x_mut): + x_mut[...] += 1. + x_mut[0] += 1 + x_mut[1] += 5 + + x_mut = core.new_ref(arr) + f.lower(x_mut).compile()(x_mut) + expected = np.arange(8.) + 1 + expected[0] += 1 + expected[1] += 5 + self.assertAllClose(x_mut[...], expected) + + def test_sharded_aot_mutable_sds(self): + mesh = jtu.create_mesh((2,), ('x',)) + arr = jax.device_put(np.arange(8.), NamedSharding(mesh, P('x'))) + + @jax.jit + def f(x_mut): + x_mut[...] += 1. + x_mut[0] += 1 + x_mut[1] += 5 + + sds_mut = jax.ShapeDtypeStruct(arr.shape, arr.dtype, sharding=arr.sharding, + is_ref=True) + compiled = f.lower(sds_mut).compile() + + x_mut = core.new_ref(arr) + compiled(x_mut) + + expected = np.arange(8.) + 1 + expected[0] += 1 + expected[1] += 5 + self.assertAllClose(x_mut[...], expected) + @parameterized.parameters([True, False]) def test_multiple_inputs_and_outputs(self, jit): def f(x_mut, y, z_mut, w): @@ -59,9 +137,9 @@ def f(x_mut, y, z_mut, w): if jit: f = jax.jit(f) - x_mut = core.mutable_array(jnp.zeros((1, 3))) + x_mut = core.new_ref(jnp.zeros((1, 3))) y = jnp.ones((2, 3)) - z_mut = core.mutable_array(jnp.zeros((2, 3))) + z_mut = core.new_ref(jnp.zeros((2, 3))) w = jnp.ones((2, 1)) out1, out2 = f(x_mut, y, z_mut, w) @@ -73,7 +151,7 @@ def f(x_mut, y, z_mut, w): @parameterized.parameters([True, False]) def test_closed_over_basic(self, jit): - x_mut = core.mutable_array(jnp.zeros(3)) + x_mut = core.new_ref(jnp.zeros(3)) def f(): x_mut[...] += 1. x_mut[0] += 1 @@ -92,7 +170,7 @@ def f(): @parameterized.parameters([True, False]) def test_closed_over_nested(self, jit): - x_mut = core.mutable_array(jnp.zeros(3)) + x_mut = core.new_ref(jnp.zeros(3)) @jax.jit def f(y_mut, z): @@ -106,7 +184,7 @@ def f(y_mut, z): if jit: f = jax.jit(f) - y_mut = core.mutable_array(np.zeros(3)) + y_mut = core.new_ref(np.zeros(3)) w = f(y_mut, 1) @@ -116,10 +194,22 @@ def f(y_mut, z): check_dtypes=False) self.assertAllClose(w, 10, check_dtypes=False) + @parameterized.parameters([True, False]) + def test_len_mutable_array(self, jit): + x_mut = core.new_ref(jnp.zeros(3)) + + def f(): + return jnp.int32(len(x_mut)) + + if jit: + f = jax.jit(f) + + self.assertEqual(f(), 3) + @parameterized.parameters([True, False]) def test_internal_mutarray_basic(self, jit): def f(): - x_mut = core.mutable_array(jnp.zeros(3)) + x_mut = core.new_ref(jnp.zeros(3)) x_mut[0] += 1 x_mut[0] += 1 x_mut[2] += 1 @@ -134,7 +224,7 @@ def f(): @parameterized.parameters([True, False]) def test_scan_internal_mut_array(self, jit): def body_fun(_, x): - x_mut = core.mutable_array(x) + x_mut = core.new_ref(x) x_mut[...] += 2 return ((), x_mut[...]) doit = lambda: jax.lax.scan(body_fun, (), np.arange(5)) @@ -145,7 +235,7 @@ def body_fun(_, x): @parameterized.parameters([True, False]) def test_scan_closed_over_mut_array(self, jit): - x_mut = core.mutable_array(0) + x_mut = core.new_ref(0) def body_fun(_, x): x_mut[...] += 2 return ((), x_mut[...]) @@ -164,7 +254,7 @@ def body_fun(_, index_x): x[...] += index return ((), x[...]) - x_mut = core.mutable_array(np.arange(5)) + x_mut = core.new_ref(np.arange(5)) doit = lambda: jax.lax.scan(body_fun, (), (np.arange(5), x_mut)) if jit: doit = jax.jit(doit) @@ -175,33 +265,37 @@ def test_double_jit_mutable_array(self): @jax.jit @jax.jit def f(): - x_ref = core.mutable_array(jnp.zeros(8)) + x_ref = core.new_ref(jnp.zeros(8)) return x_ref[...] x = f() self.assertArraysEqual(x, jnp.zeros(8)) - def test_grad_mutable_array(self): - @jax.jit + @parameterized.parameters([False, True]) + def test_grad_mutable_array(self, jit): + def f(x): - x_ = core.mutable_array(x) + x_ = core.new_ref(x) x_[()] = x_[()] + x_[()] y = core.freeze(x_) return y + if jit: + f = jax.jit(f) + ans = jax.grad(f)(1.) expected = 2.0 self.assertAllClose(ans, expected, check_dtypes=False) def test_defensive_copy(self): x = jnp.arange(3.) - _ = jax.jit(lambda x_ref: x_ref[...])(core.mutable_array(x)) + _ = jax.jit(lambda x_ref: x_ref[...])(core.new_ref(x)) x + 1 # don't crash def test_sharding_persists(self): mesh = jtu.create_mesh((1,), ('i',)) x = jax.device_put(jnp.arange(2), NamedSharding(mesh, P('i'))) s = x.sharding - a = core.mutable_array(x) + a = core.new_ref(x) self.assertEqual(s, a.sharding) self.assertEqual(s, a[...].sharding) f = jax.jit(lambda: a[...]) @@ -222,11 +316,770 @@ def f(x_ref): y = x_ref[...] + 1 return y - with jax.sharding.use_mesh(mesh): + with jax.set_mesh(mesh): x = jnp.zeros((4, 4), jnp.int32, device=sharding) - x_ref = core.mutable_array(x) + x_ref = core.new_ref(x) y = f(x_ref) + def test_vmap_basic(self): + @jax.vmap + def f(x): + x_ref = core.new_ref(x) + x_ref[...] = x_ref[...] * x_ref[...] + return x_ref[...] + xs = jnp.arange(4.) + ys = f(xs) + self.assertAllClose(ys, xs ** 2, check_dtypes=False) + + def test_vmap_extensive_inputs(self): + def f(x_ref, val): + x_ref[...] += val + x_ref[...] += val + + xs_ref = core.new_ref(jnp.array([0, 0, 0])) + vals = jnp.arange(3) + jax.vmap(f)(xs_ref, vals) + self.assertAllClose(xs_ref[...], 2 * vals, check_dtypes=False) + + def test_vmap_closed_over_read_only(self): + y_ref = core.new_ref(1) + + def f(x_ref): + x_ref[...] += y_ref[...] + x_ref[...] += y_ref[...] + + xs_ref = core.new_ref(jnp.array([0, 0, 0])) + jax.vmap(f)(xs_ref) + self.assertAllClose(xs_ref[...], jnp.array([2, 2, 2]), check_dtypes=False) + + def test_implicit_bitcast_regression(self): + # https://github.com/jax-ml/jax/issues/27683 + v = core.new_ref(jnp.array([0, 0, 0])) + with self.assertRaises(ValueError): + v[...] += 1.0 + + def test_implicit_cast_in_swap(self): + v = core.new_ref(jnp.array(0, dtype='bfloat16')) + v[...] += 1.0 # don't crash + + def test_rng_key(self): + key = core.new_ref(jax.random.key(0)) + # test read/write + key[...] = jax.random.fold_in(key[...], 1) # don't crash + + def test_scan_grad_doesnt_hoist_mutable_stuff(self): + x_ref = core.new_ref(0) + + def f(x): + def body(c, _): + x_ref[...] += 1 + return c, () + x, () = jax.lax.scan(body, x, (), length=3) + return x + + jax.grad(f)(1.0) + self.assertAllClose(x_ref[...], 3, check_dtypes=False) + + def test_scan_grad_doesnt_hoist_mutable_stuff2(self): + x_ref = core.new_ref(0) + const = jnp.arange(3) + const2 = jnp.zeros(()) + + def f(x): + def body(c, _): + x_ref[...] += const.sum() + return c + const2, () + x, () = jax.lax.scan(body, x, (), length=4) + return x + + jax.grad(f)(1.0) + self.assertAllClose(x_ref[...], 12, check_dtypes=False) + + @parameterized.parameters([False, True]) + def test_custom_vjp_grad_stats_plumbing(self, jit): + + @jax.custom_vjp + def gradient_history_calculator(x, ref): + del ref + return x + + def gradient_history_calculator_fwd(x, ref): + return x, ref + + def gradient_history_calculator_bwd(amax_history, grad_output): + amax_update = jnp.max(jnp.abs(grad_output)) + shifted = jnp.roll(amax_history[:], 1) + shifted = shifted.at[0].set(amax_update) + amax_history[:] = shifted + amax_from_history = jnp.max(amax_history[:]) + grad_output = grad_output / amax_from_history + return grad_output, None + + gradient_history_calculator.defvjp( + gradient_history_calculator_fwd, + gradient_history_calculator_bwd) + + class DotOp: + def __init__(self): + self.amax_history = core.new_ref(jnp.zeros(5,)) + + def forward(self, x, y): + out = jnp.dot(x, y) + out = gradient_history_calculator(out, self.amax_history) + return out + + dot_op = DotOp() + x_top = jnp.ones((5,)) + y_top = jnp.ones((5,)) + + def loss(x, y): + return dot_op.forward(x, y).sum() + + if jit: + loss = jax.jit(loss) + + for i in range(3): + jax.grad(loss, (0,1))(x_top, y_top) + self.assertAllClose(dot_op.amax_history[:], jnp.zeros((5,)).at[:i+1].set(1.0), check_dtypes=False) + + @parameterized.parameters([False, True]) + def test_custom_vjp_grad_stats_plumbing_basic(self, jit): + def primal(grads_ref, x): # note: jit-abstracted! + x = jnp.sin(x) + x = stash_grads(grads_ref, x) + x = jnp.sin(x) + x = stash_grads(grads_ref, x) # ignored, order-preserved + return x + + if jit: + primal = jax.jit(primal) + + @jax.custom_vjp + def stash_grads(grads_ref, x): + return x + def stash_grads_fwd(grads_ref, x): + return x, grads_ref + def stash_grads_bwd(grads_ref, g): + grads_ref[...] = g + return None, g + stash_grads.defvjp(stash_grads_fwd, stash_grads_bwd) + + grads_ref = core.new_ref(jnp.float32(0.)) + jax.grad(primal, 1)(grads_ref, jnp.float32(1.0)) + self.assertAllClose(grads_ref[...], jnp.cos(jnp.sin(1.)), check_dtypes=False) + + @parameterized.parameters(it.product([False, True], repeat=2)) + def test_custom_vjp_grad_stats_plumbing_scan(self, jit, remat): + def primal(grads_ref, x): # note: jit-abstracted! + def body(x, _): + x = jnp.sin(x) + x = stash_grads(grads_ref, x) + x = jnp.sin(x) + return x, () + if remat: + body = jax.remat(body) + x, () = jax.lax.scan(body, x, None, length=1) + return x + + if jit: + primal = jax.jit(primal) + + @jax.custom_vjp + def stash_grads(grads_ref, x): + return x + def stash_grads_fwd(grads_ref, x): + return x, grads_ref + def stash_grads_bwd(grads_ref, g): + grads_ref[...] = g + return None, g + stash_grads.defvjp(stash_grads_fwd, stash_grads_bwd) + + grads_ref = core.new_ref(jnp.float32(0.)) + jax.grad(primal, argnums=1)(grads_ref, jnp.float32(1.0)) + self.assertAllClose(grads_ref[...], jnp.cos(jnp.sin(1.)), check_dtypes=False) + + @parameterized.product(jit=[False, True], has_aux=[False, True]) + def test_custom_vjp_grad_stats_plumbing_basic_vjp3(self, jit, has_aux): + def primal(grads_ref, x): # note: abstracts over jit and has_aux! + x0 = x + x = jnp.sin(x) + x = stash_grads(grads_ref, x) + x = jnp.sin(x) + x = stash_grads(grads_ref, x) # ignored, order-preserved + return (x, x0) if has_aux else x + + if jit: + primal = jax.jit(primal) + + @jax.custom_vjp + def stash_grads(grads_ref, x): + return x + def stash_grads_fwd(grads_ref, x): + return x, grads_ref + def stash_grads_bwd(grads_ref, g): + grads_ref[...] = g + return None, g + stash_grads.defvjp(stash_grads_fwd, stash_grads_bwd) + + grads_ref = core.new_ref(jnp.float32(0.)) + x = jnp.float32(1.) + _, f_vjp, *maybe_aux = jax.vjp( + lambda x: primal(grads_ref, x), x, has_aux=has_aux) + _ = f_vjp(jnp.float32(1.)) + self.assertAllClose(grads_ref[...], jnp.cos(jnp.sin(1.)), check_dtypes=False) + if has_aux: + aux, = maybe_aux + self.assertAllClose(aux, x) + + def test_custom_vjp_grad_stats_plumbing_scan_vjp3(self): + def primal(stash_ref, x): # note: jit-abstracted! + def body(x, _): + x = jnp.sin(x) + x = stash_grads(stash_ref, x) + x = jnp.sin(x) + return x, () + x, () = jax.lax.scan(body, x, None, length=1) + return x + + @jax.custom_vjp + def stash_grads(stash_ref, x): + return x + def stash_grads_fwd(stash_ref, x): + return x, stash_ref + def stash_grads_bwd(stash_ref, g): + stash_ref[...] = g + return None, g + stash_grads.defvjp(stash_grads_fwd, stash_grads_bwd) + + stash_ref = core.new_ref(jnp.float32(0.)) + _, f_vjp = jax.vjp(lambda x: primal(stash_ref, x), jnp.float32(1.)) + grads_val, = f_vjp(jnp.float32(1.)) + self.assertAllClose(stash_ref[...], jnp.cos(jnp.sin(1.)), check_dtypes=False) + self.assertAllClose(grads_val, jnp.cos(jnp.sin(1.)) * jnp.cos(1.), + check_dtypes=False) + + stash_ref = core.new_ref(jnp.float32(0.)) + grads_ref = core.new_ref(jnp.float32(0.)) + _, f_vjp = jax.vjp(lambda x: primal(stash_ref, x), jnp.float32(1.)) + _ = f_vjp.with_refs(grads_ref)(jnp.float32(1.)) + self.assertAllClose(stash_ref[...], jnp.cos(jnp.sin(1.)), check_dtypes=False) + self.assertAllClose(grads_ref[...], jnp.cos(jnp.sin(1.)) * jnp.cos(1.), + check_dtypes=False) + + @parameterized.parameters([False, True], [False, True]) + def test_freeze_insertion(self, inner_jit, outer_jit): + def f(x): + x_ref = core.new_ref(x) + + def g(): + x_ref[...] = x_ref[...] + x_ref[...] + if inner_jit: + g = jax.jit(g) + g() + + return x_ref[...] + + if outer_jit: + f = jax.jit(f) + + self.assertAllClose(jax.grad(f)(3.), 2., check_dtypes=False) + + @parameterized.parameters([False, True]) + def test_grad_jit(self, jit): + def f(x): + x_ref = core.new_ref(x) + + @jax.jit + def g(): + x_ref[...] = jnp.sin(x_ref[...]) + jnp.sin(x_ref[...]) + g() + + return core.freeze(x_ref) + + if jit: + f = jax.jit(f) + + ans = jax.grad(f)(2.) + expected = 2. * jnp.cos(2.) + self.assertAllClose(ans, expected, check_dtypes=False) + + @parameterized.parameters([False, True]) + def test_grad_scan(self, jit): + def f(x): + x_ref = core.new_ref(x) + + def g(_, __): + x_ref[...] = jnp.sin(x_ref[...]) + jnp.sin(x_ref[...]) + return None, None + jax.lax.scan(g, None, None, length=1) + + return core.freeze(x_ref) + + if jit: + f = jax.jit(f) + + ans = jax.grad(f)(2.) + expected = 2. * jnp.cos(2.) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_grad_scan_extensive(self): + def f(xs): + xs_ref = core.new_ref(xs) + + def g(c, x_ref): + return c + x_ref[...], None + out, _ = jax.lax.scan(g, 0., xs_ref) + + return out + + ans = jax.grad(f)(jnp.arange(3.)) + expected = jnp.ones(3) + self.assertAllClose(ans, expected, check_dtypes=False) + + @parameterized.parameters([False, True]) + def test_grad_jit_readonly(self, jit): + def f(x): + x_ref = core.new_ref(jnp.zeros_like(x)) + x_ref[...] = x + return x_ref[...] + + if jit: + f = jax.jit(f) + + jtu.check_grads(f, (1.5,), 2, ['fwd', 'rev']) + + def test_grad_jit_readonly_1(self): + @jax.jit + def f(x): + x_ref = core.new_ref(x) + + def inner(): + return jnp.sin(x_ref[...]) + + return inner() + + jtu.check_grads(f, (1.5,), 2, ['fwd', 'rev']) + + def test_grad_jit_readonly_2(self): + def f(x): + x_ref = core.new_ref(x) + + @jax.jit + def inner(): + return jnp.sin(x_ref[...]) + + return inner() + + jtu.check_grads(f, (1.5,), 2, ['fwd', 'rev']) + + @jtu.sample_product( + seed=range(6), + num_consts=range(2, 6), + num_args=[0, 3], + ) + @jtu.run_on_devices("cpu") + def test_jit_vjp_systematic_readonly(self, seed, num_consts, num_args): + num_mut_consts = num_consts // 2 + num_pure_consts = num_consts - num_mut_consts + + rng = np.random.RandomState(seed) + pure_consts = [rng.normal() for _ in range(num_pure_consts)] + mut_const_vals = [rng.normal() for _ in range(num_mut_consts)] + + args = [rng.normal() for _ in range(num_args)] + + mutable_bools = rng.permutation([True] * num_mut_consts + + [False] * num_pure_consts) + + def f(mut_const_vals, pure_consts, args): + consts = pure_consts[:], map(core.new_ref, mut_const_vals) + + @jax.jit + def inner(args): + tot = 0. + for is_mut in mutable_bools: + const = consts[int(is_mut)].pop() + if is_mut: const = const[...] + tot += jnp.sin(const) + for x in args: + tot += jnp.sin(x) + return tot + + return inner(args) + + jtu.check_grads(f, (mut_const_vals, pure_consts, args), 2, ['rev']) + + @jtu.sample_product( + seed=range(6), + num_consts=range(2, 6), + num_carry=[0, 3], + num_ext_in=[0, 3], + num_iters=[1, 3], + ) + @jtu.run_on_devices("cpu") + def test_scan_vjp_systematic_readonly( + self, seed, num_consts, num_carry, num_ext_in, num_iters): + num_mut_consts = num_consts // 2 + num_pure_consts = num_consts - num_mut_consts + + rng = np.random.RandomState(seed) + pure_consts = [rng.normal() for _ in range(num_pure_consts)] + mut_const_vals = [rng.normal() for _ in range(num_mut_consts)] + + init_carry = [rng.normal() for _ in range(num_carry)] + xs = [rng.normal(size=num_iters) for _ in range(num_ext_in)] + + mutable_bools = rng.permutation([True] * num_mut_consts + + [False] * num_pure_consts) + + def f(mut_const_vals, pure_consts, c, xs): + consts = pure_consts[:], map(core.new_ref, mut_const_vals) + + def body(c, x): + tot = 0. + for is_mut in mutable_bools: + const = consts[int(is_mut)].pop() + if is_mut: const = const[...] + tot += jnp.sin(const) + new_c = [jnp.sin(carry) + tot for carry in c] + y = sum(map(jnp.sin, x)) * 1.0 + return new_c, y + + return jax.lax.scan(body, init_carry, xs, length=num_iters) + + jtu.check_grads(f, (mut_const_vals, pure_consts, init_carry, xs), + 2, ['fwd', 'rev'], rtol=1.5e-2) + + @parameterized.parameters([False, True]) + def test_remat_basic_internal(self, jit): + @jax.remat + def f(y, x): + x_ref = jax.new_ref(x) + out = y * x_ref[...] + x_ref[...] += 1 + return out + + if jit: + f = jax.jit(f) + + g = jax.grad(f)(2., 1.) + self.assertAllClose(g, 1.) + + @parameterized.parameters([False, True]) + def test_remat_basic_arg(self, jit): + @jax.remat + def f(y, x_ref): + out = y * y + x_ref[...] += out + return out + + if jit: + f = jax.jit(f) + + x_ref = core.new_ref(1., kind='anselm_ref') + g = jax.grad(f)(2., x_ref) + self.assertAllClose(x_ref[...], 5.) + self.assertAllClose(g, 4.) + + @parameterized.parameters([False, True]) + def test_remat_basic_closed_over(self, jit): + @jax.remat + def f(y): + out = y * x_ref[...] + x_ref[...] += 1 + return out + + if jit: + f = jax.jit(f) + + x_ref = core.new_ref(1., kind='anselm_ref') + g = jax.grad(f)(2.) + self.assertAllClose(x_ref[...], 2.) + self.assertAllClose(g, 1.) + + def test_remat_basic_closed_over_nested(self): + @jax.remat + @partial(jax.remat, policy=lambda *_, **__: False) + @jax.remat + def f(y): + jax.debug.callback(lambda _: lst.append('hi'), y) + out = y * x_ref[...] + x_ref[...] += 1 + return jnp.sin(out) + + lst = [] + x_ref = core.new_ref(1., kind='anselm_ref') + g = jax.grad(f)(2.) + self.assertAllClose(x_ref[...], 2.) + self.assertAllClose(g, jnp.cos(2.)) + self.assertLen(lst, 4) + + def test_remat_grad_stats_plumbing_basic(self): + @jax.remat + def f(x_ref, y): + stash_grads(x_ref, y) + return y + + @jax.custom_vjp + def stash_grads(grads_ref, x): + return x + def stash_grads_fwd(grads_ref, x): + return x, grads_ref + def stash_grads_bwd(grads_ref, g): + grads_ref[...] = g + return None, g + stash_grads.defvjp(stash_grads_fwd, stash_grads_bwd) + + x_ref = core.new_ref(0) + jax.grad(f, 1)(x_ref, 3.14) + + @jtu.run_on_devices("cpu") # tolerances, lol + def test_vjp3_ref_grads_for_val_primals(self): + NUM_LAYERS = 3 + NUM_MUBATCHES = 5 + MUBATCH_SIZE = 7 + + def mubatch_loss(Ws, xs): + # Inner loop: scan over layers + act, _ = jax.lax.scan(lambda xs, W: (jnp.dot(xs, W), None), xs, Ws) + return jnp.mean(act) + + def process_batch(Ws, xs_batch): + grad_acc = jax.new_ref(jnp.zeros_like(Ws)) # CHANGED + + def process_mubatch(_, xs): + loss, f_vjp = jax.vjp(lambda Ws: mubatch_loss(Ws, xs), Ws) # CHANGED + f_vjp.with_refs(grad_acc)(jnp.ones_like(loss)) # CHANGED + return (), loss + + assert xs_batch.shape[0] == NUM_MUBATCHES * MUBATCH_SIZE + xs_mubatches = xs_batch.reshape(NUM_MUBATCHES, MUBATCH_SIZE, *xs_batch.shape[1:]) + + # Outer loop: scan over microbatches + (), _losses = jax.lax.scan(process_mubatch, (), xs_mubatches) + return jax.ref.freeze(grad_acc) + + Ws = jnp.ones((NUM_LAYERS, 4, 4)) + xs_batch = jnp.ones((NUM_MUBATCHES * MUBATCH_SIZE, 4)) + g = process_batch(Ws, xs_batch) + self.assertAllClose(g, 20. * jnp.ones_like(Ws), atol=1e-3, rtol=1e-3) + + @parameterized.parameters([False, True]) + def test_custom_vjp_internal_ref(self, jit): + @jax.custom_vjp + def f(x): + x_ref = jax.new_ref(jnp.zeros_like(x)) + x_ref[...] = x + return x_ref[...] + def f_fwd(x): + return x, None + def f_bwd(_, g): + return g, + f.defvjp(f_fwd, f_bwd) + + if jit: + f = jax.jit(f) + + x = jax.jit(f)(3.) # no ad, doesn't crash + self.assertAllClose(x, 3., check_dtypes=False) + + g = jax.grad(f)(3.) + self.assertAllClose(g, 1., check_dtypes=False) + + def test_custom_vjp_ad_after_discharge_error(self): + @jax.custom_vjp + def f(x): + x_ref = jax.new_ref(jnp.zeros_like(x)) + x_ref[...] = x + return x_ref[...] + def f_fwd(x): + return x, None + def f_bwd(_, g): + return g, + + f.defvjp(f_fwd, f_bwd) + from jax._src import core + from jax._src.state.discharge import discharge_state + jaxpr = jax.make_jaxpr(f)(3.) + jaxpr_, consts_ = discharge_state(jaxpr.jaxpr, jaxpr.consts) + with self.assertRaises(Exception): + jax.grad(lambda x: core.eval_jaxpr(jaxpr_, consts_, x)[0])(3.) + + @parameterized.parameters([False, True]) + def test_custom_vjp_differentiated_ref(self, jit): + @jax.custom_vjp + def f(x_ref): + return x_ref[...] + def f_fwd(x_ref): + return f(x_ref), None + def f_bwd(_, g): + return g, + f.defvjp(f_fwd, f_bwd) + + if jit: + f = jax.jit(f) + + y = f(jax.new_ref(3.14)) + self.assertAllClose(y, 3.14, check_dtypes=False) + + # this exercises the fallback path, not a fancy transpose + _, f_vjp = jax.vjp(lambda x: f(jax.new_ref(x)), 3.14) + g, = f_vjp(1.) + self.assertAllClose(g, 1., check_dtypes=False) + + def test_get_transpose_uninstantiated_grad_ref(self): + # from https://github.com/jax-ml/jax/pull/31412#discussion_r2308151559 + f = lambda x: jax.new_ref(x)[0] + jax.grad(f)(jnp.array([3.])) # don't crash + + def test_vmap_create_ref_from_unbatched_value(self): + @jax.jit + def internally_pure(x): + ref = jax.new_ref(1.) + ref[...] += x + return ref[...] + + ans = jax.vmap(internally_pure)(jnp.arange(4.)) + self.assertAllClose(ans, jnp.array([1., 2., 3., 4.])) + + def test_isinstance(self): + ref = jax.new_ref(1.) + self.assertIsInstance(ref, jax.Ref) + + @jax.jit + def f(x_ref): + self.assertIsInstance(x_ref, jax.Ref) + f(ref) + + self.assertNotIsInstance(ref, jax.Array) + + arr = jnp.ones(3) + self.assertNotIsInstance(arr, jax.Ref) + + def test_scan_vjp3_reverse(self): + # https://github.com/jax-ml/jax/issues/32411 + def f(xs): + ys = jnp.arange(5.) + + def body(_, xy): + x, y = xy + return (), x * y + (), z = jax.lax.scan(body, (), (xs, ys)) + return z.sum() + + grad_accum = jax.new_ref(jnp.zeros(5)) + _, f_vjp = jax.vjp(f, jnp.ones(5)) + _, = f_vjp.with_refs(grad_accum)(1.) + self.assertAllClose(grad_accum[...], jnp.arange(5.)) + + def test_vmap_with_vjp3(self): + # https://github.com/jax-ml/jax/issues/32479 + def grad_via_ref(f): + def wrapper(*args): + grad_accum = jax.tree.map(lambda x: jax.new_ref(jnp.zeros_like(x)), args) + out, f_vjp = jax.vjp(f, *args) + f_vjp.with_refs(*grad_accum)(jnp.ones_like(out)) + return jax.tree.map(lambda x: jax.freeze(x), grad_accum) + return wrapper + + def issue_vmap1(): + def f(x): + return x + 1 + x = jnp.ones((4,)) + # g = grad_via_ref(jax.vmap(f)) # good + g = jax.vmap(grad_via_ref(f)) # bad + g(x) # crash + + def issue_vmap1_minimized(): + def f(x): + x.addupdate(1.0) # bad (assumes non-empty list of indexers) + jax.vmap(f)(jax.new_ref(jnp.zeros((4,)))) # crash + + def issue_vmap2(): + def f(x): + x[...] = 1.0 # bad (mismatched shapes) + jax.vmap(f)(jax.new_ref(jnp.zeros((4,)))) # crash + + # don't crash + issue_vmap1() + issue_vmap1_minimized() + issue_vmap2() + + def test_slicing_with_vjp3(self): + @jax.jit + def f(x, i): + return x[i] ** 2 + + x = jnp.arange(10.) + + grad_accum = jax.new_ref(jnp.zeros(10)) + not_needed = object() + + @jax.make_jaxpr + def run(): + _, f_vjp = jax.vjp(f, x, 5) + f_vjp = f_vjp.with_refs(grad_accum, not_needed) + f_vjp(1.) + + jaxpr = run() + self.assertIn('+=', str(jaxpr)) + self.assertNotIn('0.0', str(jaxpr)) + + @absltest.skip("Not yet implemented") + def test_none_index(self): + ref = jax.new_ref(jnp.array([1, 2, 3])) + y = ref[None] + self.assertEqual(y.shape, (1, 3)) + + def test_what_if_you_lower_fun_something_with_internal_effects(self): + bjp_p = core.Primitive('bjp') + + @bjp_p.def_abstract_eval + def _(aval): + return aval + + def lowering(x): + x_ref = jax.new_ref(x) + x_ref[...] += 1 + x_ref[...] += -1 + return jax.freeze(x_ref) + + mlir.register_lowering(bjp_p, mlir.lower_fun(lowering, multiple_results=False)) + + @jax.jit + def f(x): + return bjp_p.bind(x) + + f(3.) # don't crash + + def test_remat_while_loop_residuals(self): + @jax.custom_vjp + def ra2a(x): + return jax.freeze(jax.new_ref(x)) + + def ra2a_fwd(x): + o = ra2a(x) + return o, () + + def ra2a_bwd(res, g): + return (ra2a(g),) + + ra2a.defvjp(ra2a_fwd, ra2a_bwd) + + @jax.jit + @jax.remat + def f(x): + + def g(x): + def body(carry): + i, x = carry + x = ra2a(x) + return i + 1, x + return jax.lax.while_loop(lambda x: x[0] < 5, body, (0, x))[1] + return g(x) + + jax.linearize(f, 5.) # don't crash + @jtu.with_config(jax_mutable_array_checks=True) class MutableArrayErrorsTest(jtu.JaxTestCase): @@ -235,36 +1088,38 @@ def test_return_from_jit(self): ValueError, r"traced for jit returned a mutable array reference.*\n\n" r".*was created on line"): - jax.jit(core.mutable_array)(jnp.arange(3)) + jax.jit(core.new_ref)(jnp.arange(3)) def test_return_from_jit_arg(self): with self.assertRaisesRegex( ValueError, r"traced for jit returned a mutable array reference.*\n\n" r".*was passed in as the argument x_ref"): - jax.jit(lambda x_ref: x_ref)(core.mutable_array(jnp.arange(3))) + jax.jit(lambda x_ref: x_ref)(core.new_ref(jnp.arange(3))) + @unittest.skip("regressed") # TODO(mattjj): fix def test_return_from_jit_pytree(self): with self.assertRaisesRegex( ValueError, r"tree path result\['hi'\]"): - jax.jit(lambda x_ref: {'hi': x_ref})(core.mutable_array(jnp.arange(3))) + jax.jit(lambda x_ref: {'hi': x_ref})(core.new_ref(jnp.arange(3))) + @unittest.skip("regressed") # TODO(mattjj): fix def test_return_from_jit_closure(self): with self.assertRaisesRegex( ValueError, r"tree path result\['hi'\]"): - x_ref = core.mutable_array(jnp.arange(3)) + x_ref = core.new_ref(jnp.arange(3)) jax.jit(lambda: {'hi': x_ref})() def test_argument_aliases_jit(self): - x_ref = core.mutable_array(0.) + x_ref = core.new_ref(0.) with self.assertRaisesRegex( ValueError, "appeared at both x_ref and y_ref"): jax.jit(lambda x_ref, y_ref: x_ref[...] + y_ref[...])(x_ref, x_ref) def test_closure_and_argument_aliases_jit(self): - x_ref = core.mutable_array(0.) + x_ref = core.new_ref(0.) with self.assertRaisesRegex( ValueError, "closed over and passed as the argument y_ref"): jax.jit(lambda y_ref: x_ref[...] + y_ref[...])(x_ref) @@ -272,16 +1127,16 @@ def test_closure_and_argument_aliases_jit(self): def test_return_from_scan(self): with self.assertRaisesRegex( ValueError, "traced for scan returned a mutable array reference of type"): - jax.lax.scan(lambda c, x: (core.mutable_array(c), x), 0, jnp.arange(3)) + jax.lax.scan(lambda c, x: (core.new_ref(c), x), 0, jnp.arange(3)) def test_argument_aliases_scan(self): - x_ref = core.mutable_array(0.) + x_ref = core.new_ref(0.) with self.assertRaisesRegex( ValueError, r"appeared at both c\[0\] and c\[1\]"): jax.lax.scan(lambda c, _: (None, None), (x_ref, x_ref), None, length=1) def test_closure_and_argument_aliases_scan(self): - x_ref = core.mutable_array(0.) + x_ref = core.new_ref(0.) with self.assertRaisesRegex( ValueError, r"closed over and passed as the argument y_ref"): jax.lax.scan(lambda y_ref, _: (x_ref[...] + y_ref[...], None), x_ref, @@ -290,15 +1145,15 @@ def test_closure_and_argument_aliases_scan(self): def test_return_from_cond(self): with self.assertRaisesRegex( ValueError, "traced for cond returned a mutable array reference of type"): - jax.lax.cond(True, lambda: core.mutable_array(1.0), lambda: core.mutable_array(2.0)) + jax.lax.cond(True, lambda: core.new_ref(1.0), lambda: core.new_ref(2.0)) def test_argument_aliases_cond(self): - x_ref = core.mutable_array(0.) + x_ref = core.new_ref(0.) with self.assertRaisesRegex( ValueError, r"for cond.*at both x1 and x2"): jax.lax.cond(True, lambda x1, x2: ..., lambda x1, x2: ..., x_ref, x_ref) def test_closure_and_argument_aliases_cond(self): - x_ref = core.mutable_array(0.) + x_ref = core.new_ref(0.) with self.assertRaisesRegex( ValueError, r"closed over and passed as the argument y_ref"): jax.lax.cond(True, @@ -314,11 +1169,24 @@ def f(ref): f.defvjp(lambda ref: ..., lambda *_: ...) if jit: f = jax.jit(f) - x_ref = core.mutable_array(0.) + x_ref = core.new_ref(0.) with self.assertRaisesRegex( ValueError, "custom_vjp primal function"): f(x_ref) + @parameterized.parameters([False, True]) + def test_return_from_custom_vjp_primal_nondiff_argnum(self, jit): + @partial(jax.custom_vjp, nondiff_argnums=(0,)) + def f(_, ref): + return ref + f.defvjp(lambda _, ref: ..., lambda *_: ...) + if jit: + f = jax.jit(f, static_argnums=0) + x_ref = core.new_ref(0.) + with self.assertRaisesRegex( + ValueError, "custom_vjp primal function"): + f('hi', x_ref) + @parameterized.parameters([False, True]) def test_return_from_custom_vjp_fwd(self, jit): @jax.custom_vjp @@ -327,10 +1195,24 @@ def f(x, ref): f.defvjp(lambda x, ref: (x, ref), lambda ref, g: g) if jit: f = jax.jit(f) - x_ref = core.mutable_array(0.) + x_ref = core.new_ref(0.) + + jax.vjp(f, 3., x_ref) # returning input ref, okay + + @jax.custom_vjp + def g(x, ref): + return x + def g_fwd(x, _): + y_ref = core.new_ref(0) + return x, y_ref + g.defvjp(g_fwd, lambda ref, g: g) + if jit: + g = jax.jit(g) + x_ref = core.new_ref(0.) + with self.assertRaisesRegex( ValueError, "custom_vjp fwd function"): - jax.vjp(f, 3., x_ref) + jax.vjp(g, 3., x_ref) @parameterized.parameters([False, True]) def test_argument_aliases_custom_vjp_primal(self, jit): @@ -340,29 +1222,28 @@ def f(x_ref, y_ref): f.defvjp(lambda x_ref, y_ref: (None, None), lambda _, g: (None, None)) if jit: f = jax.jit(f) - x_ref = core.mutable_array(0.) + x_ref = core.new_ref(0.) with self.assertRaisesRegex(ValueError, "x_ref and y_ref"): f(x_ref, x_ref) - # TODO(mattjj): re-enable test after direct-linearize - # @parameterized.parameters([False, True]) - # def test_argument_aliases_custom_vjp_fwd(self, jit): - # @jax.custom_vjp - # def f(x_ref, y_ref): - # ... - # f.defvjp(lambda x_ref, y_ref: (None, None), lambda _, g: (None, None)) - # if jit: - # f = jax.jit(f) - # x_ref = core.mutable_array(0.) - # with self.assertRaisesRegex(ValueError, "x_ref and y_ref"): - # jax.vjp(f, x_ref, x_ref) + @parameterized.parameters([False, True]) + def test_argument_aliases_custom_vjp_fwd(self, jit): + @jax.custom_vjp + def f(x_ref, y_ref): + ... + f.defvjp(lambda x_ref, y_ref: (None, None), lambda _, g: (None, None)) + if jit: + f = jax.jit(f) + x_ref = core.new_ref(0.) + with self.assertRaisesRegex(ValueError, "x_ref and y_ref"): + jax.vjp(f, x_ref, x_ref) # TODO(mattjj): add test test_closure_and_argument_aliases_custom_vjp @parameterized.parameters([False, True]) def test_cond_both_branches_close_over_same_mutable_array(self, jit): # see also test_cond_with_ref_reuse in state_test.py - x_ref = core.mutable_array(0.) + x_ref = core.new_ref(0.) def f(pred): def true_fun(): x_ref[()] = 1. @@ -376,6 +1257,35 @@ def false_fun(): out_false = f(False) self.assertAllClose(x_ref[...], 2.) + def test_vmap_closed_over_ref_write(self): + x_ref = core.new_ref(jnp.zeros((), 'int32')) + + def f(val): + x_ref[...] += val + + vals = jnp.arange(3, dtype='int32') + with self.assertRaisesRegex(Exception, "unbatched array reference"): + jax.vmap(f)(vals) + + def test_vmap_aliased_arguments(self): + def f(ref_1, ref_2): + pass + + x_ref = core.new_ref(jnp.zeros((3, 3))) + with self.assertRaisesRegex( + ValueError, + "only one reference to a mutable array may be passed as an argument"): + jax.vmap(f)(x_ref, x_ref) + + def test_jvp_closed_over_ref_error(self): + ref = core.new_ref(0.) + def f(x): + ref[...] = x + return x + with self.assertRaisesRegex( + Exception, "Move the array reference"): + jax.jvp(f, [1.], [1.]) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/name_stack_test.py b/tests/name_stack_test.py index f371c431e7ee..44084017f46f 100644 --- a/tests/name_stack_test.py +++ b/tests/name_stack_test.py @@ -16,13 +16,12 @@ from absl.testing import absltest import jax from jax import api_util +from jax import lax import jax.numpy as jnp +from jax._src import config from jax._src import core -from jax import lax -from jax._src.pjit import pjit from jax._src import linear_util as lu from jax._src import test_util as jtu -from jax._src import ad_checkpoint jax.config.parse_flags_with_absl() @@ -97,7 +96,8 @@ def _f(x): self.assertEqual(str(jaxpr.eqns[0].params['call_jaxpr'].eqns[0].source_info.name_stack), 'bar') hlo_text = _get_hlo(f)(2) - self.assertIn('foo/jit(core_call)/bar', hlo_text) + self.assertIn('jit(f)/foo/call', hlo_text) + self.assertIn('bar/add', hlo_text) def test_jit_jaxpr_should_not_store_outer_name_stack(self): @jax.named_scope('foo') @@ -116,7 +116,7 @@ def _f(x): 'bar') hlo_text = _get_hlo(f)(2) - self.assertIn('foo/jit(_f)/bar', hlo_text) + self.assertIn('foo/jit(_f)', hlo_text) def test_pmap_call_primitive_jaxpr_should_not_store_outer_name_stack(self): @jax.named_scope('foo') @@ -126,7 +126,10 @@ def f(x): return x + 1 jaxpr = jax.make_jaxpr(f)(jnp.ones(1)).jaxpr self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'foo') - self.assertEqual(str(jaxpr.eqns[0].params['call_jaxpr'].eqns[0].source_info.name_stack), 'bar') + if config.pmap_shmap_merge.value: + self.assertEqual(str(jaxpr.eqns[0].params['jaxpr'].eqns[0].params['jaxpr'].eqns[1].source_info.name_stack), 'bar') + else: + self.assertEqual(str(jaxpr.eqns[0].params['call_jaxpr'].eqns[0].source_info.name_stack), 'bar') class NameStackTransformationTest(jtu.JaxTestCase): @@ -161,13 +164,14 @@ def _f(x): jaxpr = jax.make_jaxpr(f)(jnp.ones(2)).jaxpr jaxpr_param = 'jaxpr' - self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'foo') + self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'foo/vmap()') self.assertEqual( str(jaxpr.eqns[0].params[jaxpr_param].eqns[0].source_info.name_stack), 'bar') hlo_text = _get_hlo(f)(jnp.ones(2)) - self.assertIn('foo/vmap(jit(_f))/bar', hlo_text) + self.assertIn('foo/vmap(jit(_f))', hlo_text) + self.assertIn('bar', hlo_text) def test_jvp_should_transform_stacks(self): def f(x): @@ -188,13 +192,14 @@ def f(x): g = jax.named_scope('foo')(lambda x, t: jax.jvp(f, (x,), (t,))) jaxpr = jax.make_jaxpr(g)(1., 1.).jaxpr jaxpr_param = 'jaxpr' - self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'foo') + self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'foo/jvp()') self.assertEqual( str(jaxpr.eqns[0].params[jaxpr_param].eqns[0].source_info.name_stack), 'bar/baz') hlo_text = _get_hlo(g)(1., 1.) - self.assertIn('foo/jvp(jit(f))/bar/baz/mul', hlo_text) + self.assertIn('foo/jvp(jit(f))', hlo_text) + self.assertIn('bar/baz/mul', hlo_text) def test_grad_should_add_jvp_and_transpose_to_name_stack(self): @jax.value_and_grad @@ -234,9 +239,11 @@ def f(x): jaxpr.eqns[1].params[jaxpr_param].eqns[0].source_info.name_stack), 'bar') hlo_text = _get_hlo(f)(1.) - self.assertIn('jvp(foo)/jit(f)/bar/sin', hlo_text) - self.assertIn('jvp(foo)/jit(f)/bar/cos', hlo_text) - self.assertIn('transpose(jvp(foo))/jit(f)/bar/mul', hlo_text) + self.assertIn('jvp(foo)/jit(f)', hlo_text) + self.assertIn('bar/sin', hlo_text) + self.assertIn('bar/cos', hlo_text) + self.assertIn('transpose(jvp(foo))/jit(f)', hlo_text) + self.assertIn('bar/mul', hlo_text) def test_nested_jit_stack(self): @@ -249,26 +256,12 @@ def g(y): return g(x) hlo_text = _get_hlo(f)(2.) - self.assertIn('jvp(jit(f))/jit(g)/sin', hlo_text) - self.assertIn('jvp(jit(f))/jit(g)/cos', hlo_text) - self.assertIn('transpose(jvp(jit(f)))/jit(g)/mul', hlo_text) + self.assertIn('jvp(jit(f))', hlo_text) + self.assertIn('jit(g)', hlo_text) + self.assertIn('transpose(jvp(jit(f)))', hlo_text) - def test_nested_pjit_stack(self): - @jax.value_and_grad - @pjit - def f(x): - @pjit - def g(y): - return jnp.sin(y) - return g(x) - - hlo_text = _get_hlo(f)(2.) - self.assertIn('jvp(jit(f))/jit(g)/sin', hlo_text) - self.assertIn('jvp(jit(f))/jit(g)/cos', hlo_text) - self.assertIn('transpose(jvp(jit(f)))/jit(g)/mul', hlo_text) - - def test_remat_appears_in_hlo(self): - @ad_checkpoint.remat + def test_re_materalization_appears_in_hlo(self): + @jax.remat def f(x): return jnp.sin(x) @@ -381,7 +374,6 @@ def cond(x): self.assertIn('vmap(jvp(foo))/while/body/bar/add', hlo_text) self.assertIn('vmap(jvp(foo))/while/body_pred/bar_cond', hlo_text) - def test_cond_body_should_not_have_name_stack(self): @jax.named_scope('foo') @@ -432,8 +424,8 @@ def false_fn(x): 'true') hlo_text = _get_hlo(f)(jnp.arange(2.), True) - self.assertIn('foo/vmap(cond)/branch_0_fun/false/sub', hlo_text) - self.assertIn('foo/vmap(cond)/branch_1_fun/true/add', hlo_text) + self.assertIn('foo/vmap()/cond/branch_0_fun/false/sub', hlo_text) + self.assertIn('foo/vmap()/cond/branch_1_fun/true/add', hlo_text) def test_jvp_of_cond_transforms_name_stack(self): @@ -460,8 +452,9 @@ def false_fn(x): 'true') hlo_text = _get_hlo(g)(jnp.arange(2.), jnp.ones(2)) - self.assertIn('jvp(jit(f))/foo/cond/branch_0_fun/false/sub', hlo_text) - self.assertIn('jvp(jit(f))/foo/cond/branch_1_fun/true/add', hlo_text) + self.assertIn('jvp(jit(f))', hlo_text) + self.assertIn('foo/cond/branch_0_fun/false/sub', hlo_text) + self.assertIn('foo/cond/branch_1_fun/true/add', hlo_text) def test_vmap_of_jvp_of_cond_transforms_name_stack(self): @@ -488,12 +481,9 @@ def false_fn(x): 'true') hlo_text = _get_hlo(g)(jnp.arange(2.), jnp.ones(2)) - self.assertIn( - 'vmap(jvp(jit(f)))/foo/cond/branch_0_fun/false/sub"', - hlo_text) - self.assertIn( - 'vmap(jvp(jit(f)))/foo/cond/branch_1_fun/true/add"', - hlo_text) + self.assertIn('vmap(jvp(jit(f)))', hlo_text) + self.assertIn('foo/cond/branch_0_fun/false/sub', hlo_text) + self.assertIn('foo/cond/branch_1_fun/true/add"', hlo_text) def test_grad_of_cond_transforms_name_stack(self): @@ -576,7 +566,8 @@ def body(carry, x): 'scan_body') hlo_text = _get_hlo(f)(1.) - self.assertIn('foo/while/body/scan_body', hlo_text) + self.assertIn('foo/while/body', hlo_text) + self.assertIn('scan_body/add', hlo_text) def test_vmap_of_scan_should_transform_stack(self): @@ -594,7 +585,8 @@ def body(carry, x): 'scan_body') hlo_text = _get_hlo(f)(jnp.arange(2.)) - self.assertIn('vmap(foo)/while/body/scan_body/add', hlo_text) + self.assertIn('vmap(foo)/while/body', hlo_text) + self.assertIn('scan_body/add', hlo_text) def test_jvp_of_scan_should_transform_stack(self): @@ -612,7 +604,8 @@ def body(carry, x): 'scan_body') hlo_text = _get_hlo(g)(1., 1.) - self.assertIn('jvp(foo)/while/body/scan_body/add', hlo_text) + self.assertIn('jvp(foo)/while/body', hlo_text) + self.assertIn('scan_body/add', hlo_text) def test_grad_of_scan_should_transform_stack(self): @@ -632,8 +625,9 @@ def body(carry, x): 'scan_body') hlo_text = _get_hlo(f)(1.) - self.assertIn('jvp(foo)/while/body/scan_body/mul', hlo_text) - self.assertIn('transpose(jvp(foo))/while/body/scan_body/mul', hlo_text) + self.assertIn('jvp(foo)/while/body', hlo_text) + self.assertIn('scan_body/mul', hlo_text) + self.assertIn('transpose(jvp(foo))/while/body/', hlo_text) def test_vmap_of_grad_of_scan_should_transform_stack(self): @@ -654,8 +648,9 @@ def body(carry, x): 'scan_body') hlo_text = _get_hlo(f)(jnp.arange(2.)) - self.assertIn('vmap(jvp(foo))/while/body/scan_body/mul', hlo_text) - self.assertIn('vmap(transpose(jvp(foo)))/while/body/scan_body/mul', hlo_text) + self.assertIn('vmap(jvp(foo))/while/body', hlo_text) + self.assertIn('scan_body/mul', hlo_text) + self.assertIn('vmap(transpose(jvp(foo)))/while/body', hlo_text) if __name__ == '__main__': diff --git a/tests/nn_test.py b/tests/nn_test.py index ed016ec349ef..7fb2fce402fb 100644 --- a/tests/nn_test.py +++ b/tests/nn_test.py @@ -27,11 +27,9 @@ from jax._src import core from jax._src import dtypes as _dtypes from jax._src import test_util as jtu -from jax._src.lib import cuda_versions from jax._src.cudnn.scaled_matmul_stablehlo import ( quantize, shape_normalization, - BlockScaleConfig, ) from jax.test_util import check_grads from jax import nn @@ -41,13 +39,6 @@ config.parse_flags_with_absl() -def _is_required_cudnn_version_satisfied(min_cc, min_cudnn_version): - return ( - jtu.is_cuda_compute_capability_at_least(min_cc) and - cuda_versions is not None and - cuda_versions.cudnn_get_version() >= min_cudnn_version - ) - def _check_cudnn_backend(fn, *args, **kwargs): lowered = jax.jit(fn).lower(*args, **kwargs) hlo = lowered.as_text('stablehlo', debug_info=True) @@ -110,17 +101,7 @@ def create_mxfp8_configs_if_available(): if _dtypes.float8_e8m0fnu is None: raise unittest.SkipTest("float8_e8m0fnu is not available.") - def _create_mxfp8_config(): - return BlockScaleConfig( - mode='mxfp8', - block_size=32, - data_type=jnp.float8_e4m3fn, - scale_type=jnp.float8_e8m0fnu, - global_scale=None, - infer_only=False - ) - - return [_create_mxfp8_config() for _ in range(3)] + return [nn.get_scaled_dot_general_config("mxfp8") for _ in range(3)] @jtu.with_config(jax_legacy_prng_key="allow", @@ -130,11 +111,10 @@ class NNFunctionsTest(jtu.JaxTestCase): contract=[160, 96], lhs_non_contract=[240, 100], dtype=[jnp.float16, jnp.bfloat16, jnp.float32], - impl=['cudnn',], ) - def testScaledMatmul(self, contract, lhs_non_contract, dtype, impl): - if impl == 'cudnn' and not _is_required_cudnn_version_satisfied("10.0", 90700): - raise unittest.SkipTest("CUDA or cuDNN versions are not compatible") + def testScaledMatmul(self, contract, lhs_non_contract, dtype): + if not jtu.is_cuda_compute_capability_at_least("10.0"): + raise unittest.SkipTest("Needs compute capability 10.0 or higher.") # Check if float8_e8m0fnu is available configs = create_mxfp8_configs_if_available() batch, rhs_non_contract = 4, 256 @@ -153,12 +133,11 @@ def testScaledMatmul(self, contract, lhs_non_contract, dtype, impl): @parameterized.product( is_training=[True, False], output_type=[jnp.float16, jnp.bfloat16, jnp.float32], - impl=['cudnn',], ) def testScaledDotGeneral( - self, is_training, output_type, impl): - if impl == 'cudnn' and not _is_required_cudnn_version_satisfied("10.0", 90700): - raise unittest.SkipTest("CUDA or cuDNN versions are not compatible") + self, is_training, output_type): + if not jtu.is_cuda_compute_capability_at_least("10.0"): + raise unittest.SkipTest("Needs compute capability 10.0 or higher.") configs = create_mxfp8_configs_if_available() cast_to_representable = partial( @@ -214,10 +193,12 @@ def fwd(a, b, is_ref=False): impl=['cudnn', 'xla'], ) def testDotProductAttention(self, dtype, group_num, use_vmap, impl): - if impl == 'cudnn' and not _is_required_cudnn_version_satisfied("8.0", 8904): - raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.") + if impl == 'cudnn' and not jtu.is_cuda_compute_capability_at_least("8.0"): + raise unittest.SkipTest("Needs compute capability 8.0 or higher.") if impl == 'cudnn' and dtype == jnp.float32: raise unittest.SkipTest("cuDNN only supports fp16 or bf16.") + if impl == 'cudnn' and jtu.is_cuda_version_at_least(13, 0): + raise unittest.SkipTest("cuDNN creates no execution plans on CUDA 13.0.") B, S, T, N, H, G = 2, 128, 128, 4, 32, group_num keys = random.split(random.PRNGKey(0), 5) @@ -225,13 +206,19 @@ def testDotProductAttention(self, dtype, group_num, use_vmap, impl): K = random.normal(keys[1], (B, S, N // G, H), dtype) V = random.normal(keys[2], (B, S, N // G, H), dtype) grad = random.normal(keys[3], (B, T, N, H), dtype) + lse_grad = random.normal(keys[4], (B, T, N), dtype) bias, mask = None, None sdpa = nn.dot_product_attention sdpa_ref = partial(sdpa, implementation=None) sdpa_ans = partial(sdpa, implementation=impl) + sdpa_ref_lse = partial(sdpa, implementation=None, return_residual=True) + sdpa_ans_lse = partial(sdpa, implementation=impl, return_residual=True) if use_vmap: sdpa_ans = jax.vmap(sdpa_ans, in_axes=(0, 0, 0, None, None), out_axes=0) + spda_ans_lse = jax.vmap( + sdpa_ans_lse, in_axes=(0, 0, 0, None, None), out_axes=0 + ) # For testing purposes, we call the non-GQA version without vmap in the # reference code @@ -240,20 +227,36 @@ def testDotProductAttention(self, dtype, group_num, use_vmap, impl): out_ref, sdpa_vjp_ref = jax.vjp(sdpa_ref, Q, K_ref, V_ref, bias, mask) out_ans, sdpa_vjp_ans = jax.vjp(sdpa_ans, Q, K, V, bias, mask) + out_ref_lse, sdpa_vjp_ref_lse = jax.vjp(sdpa_ref_lse, Q, K_ref, V_ref, bias, mask) + out_ans_lse, sdpa_vjp_ans_lse = jax.vjp(sdpa_ans_lse, Q, K, V, bias, mask) + dQ_ref, dK_ref, dV_ref = sdpa_vjp_ref(grad)[:3] dQ_ans, dK_ans, dV_ans = sdpa_vjp_ans(grad)[:3] dK_ref = dK_ref.reshape(B, S, N // G, G, H).sum(axis=3) dV_ref = dV_ref.reshape(B, S, N // G, G, H).sum(axis=3) + dQ_ref_lse, dK_ref_lse, dV_ref_lse = sdpa_vjp_ref_lse((grad, lse_grad))[:3] + dQ_ans_lse, dK_ans_lse, dV_ans_lse = sdpa_vjp_ans_lse((grad, lse_grad))[:3] + dK_ref_lse = dK_ref_lse.reshape(B, S, N // G, G, H).sum(axis=3) + dV_ref_lse = dV_ref_lse.reshape(B, S, N // G, G, H).sum(axis=3) + if impl == 'cudnn': self.assertTrue(_check_cudnn_backend(sdpa_ans, Q, K, V, bias, mask)) self.assertTrue(_check_cudnn_backend(sdpa_vjp_ans, grad)) + self.assertTrue(_check_cudnn_backend(sdpa_ans_lse, Q, K, V, bias, mask)) + self.assertTrue(_check_cudnn_backend(sdpa_vjp_ans_lse, (grad, lse_grad))) self.assertAllClose(out_ref, out_ans, atol=.01, rtol=.01) self.assertAllClose(dQ_ref, dQ_ans, rtol=.01, atol=.01) self.assertAllClose(dK_ref, dK_ans, rtol=.01, atol=.01) self.assertAllClose(dV_ref, dV_ans, rtol=.01, atol=.01) + self.assertAllClose(out_ref_lse[0], out_ans_lse[0], atol=.01, rtol=.01) + self.assertAllClose(out_ref_lse[1], out_ans_lse[1], atol=.01, rtol=.01) + self.assertAllClose(dQ_ref_lse, dQ_ans_lse, rtol=.01, atol=.01) + self.assertAllClose(dK_ref_lse, dK_ans_lse, rtol=.01, atol=.01) + self.assertAllClose(dV_ref_lse, dV_ans_lse, rtol=.01, atol=.01) + @parameterized.product( mask_mode=['bias', 'causal', 'padding', 'custom', ('causal', 'padding'), ('custom', 'padding'), ('bias', 'causal'), @@ -262,9 +265,10 @@ def testDotProductAttention(self, dtype, group_num, use_vmap, impl): def testDotProductAttentionMask(self, mask_mode): if isinstance(mask_mode, str): mask_mode = (mask_mode,) - min_cudnn_version = 90200 if 'sliding_window' in mask_mode else 8904 - if not _is_required_cudnn_version_satisfied("8.0", min_cudnn_version): - raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.") + if not jtu.is_cuda_compute_capability_at_least("8.0"): + raise unittest.SkipTest("Requires compute capability 8.0 or higher.") + if jtu.is_cuda_version_at_least(13, 0): + raise unittest.SkipTest("cuDNN creates no execution plans on CUDA 13.0.") dtype = jnp.bfloat16 B, S, T, N, H = 2, 128, 128, 4, 32 @@ -326,8 +330,10 @@ def testDotProductAttentionMask(self, mask_mode): use_vmap=[False, True], ) def testDotProductAttentionBiasGradient(self, batch_size, use_vmap): - if not _is_required_cudnn_version_satisfied("8.0", 8904): - raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.") + if not jtu.is_cuda_compute_capability_at_least("8.0"): + raise unittest.SkipTest("Requires compute capability 8.0 or higher.") + if jtu.is_cuda_version_at_least(13, 0): + raise unittest.SkipTest("cuDNN creates no execution plans on CUDA 13.0.") dtype = jnp.bfloat16 B, S, N, H = batch_size, 128, 4, 32 @@ -370,18 +376,17 @@ def bwd_ans(x, bias, mask): _, f_vjp = jax.vjp(attn_ans, x, bias, mask) return f_vjp(x) - if batch_size != 1: - with self.assertRaisesRegex(ValueError, _cudnn_dbias_error): - _, dbias_ans, _ = bwd_ans(x, bias, mask) - else: - _, dbias_ref, _ = bwd_ref(x, bias, mask) - _, dbias_ans, _ = bwd_ans(x, bias, mask) - self.assertAllClose(dbias_ans, dbias_ref, rtol=0.1, atol=0.1) + _, dbias_ref, _ = bwd_ref(x, bias, mask) + _, dbias_ans, _ = bwd_ans(x, bias, mask) + self.assertAllClose(dbias_ans, dbias_ref, rtol=0.1, atol=0.1) - @jtu.skip_on_flag("jax_skip_slow_tests", True) def testSoftplusGrad(self): check_grads(nn.softplus, (1e-8,), order=4, - rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None) + rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None, + modes=["fwd"]) + check_grads(nn.softplus, (1e-8,), order=4, + rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None, + modes=["rev"]) def testSoftplusGradZero(self): check_grads(nn.softplus, (0.,), order=1, @@ -424,7 +429,11 @@ def testSparseplusAndSparseSigmoid(self): def testSquareplusGrad(self): check_grads(nn.squareplus, (1e-8,), order=4, - rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None) + rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None, + modes=["fwd"]) + check_grads(nn.squareplus, (1e-8,), order=4, + rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None, + modes=["rev"]) def testSquareplusGradZero(self): check_grads(nn.squareplus, (0.,), order=1, @@ -444,7 +453,11 @@ def testSquareplusZero(self, dtype): def testMishGrad(self): check_grads(nn.mish, (1e-8,), order=4, - rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None) + rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None, + modes=["fwd"]) + check_grads(nn.mish, (1e-8,), order=4, + rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None, + modes=["rev"]) def testMishGradZero(self): check_grads(nn.mish, (0.,), order=1, @@ -502,9 +515,9 @@ def testMishValue(self): val = nn.mish(1e3) self.assertAllClose(val, 1e3, check_dtypes=False, atol=1e-3) - @jtu.skip_on_flag("jax_skip_slow_tests", True) def testEluGrad(self): - check_grads(nn.elu, (1e4,), order=4, eps=1.) + check_grads(nn.elu, (1e4,), order=4, eps=1., modes=["fwd"]) + check_grads(nn.elu, (1e4,), order=4, eps=1., modes=["rev"]) def testEluValue(self): val = nn.elu(1e4) @@ -541,7 +554,7 @@ def gelu_reference(x): (jnp.float32, jnp.bfloat16, jnp.float16), (partial(nn.gelu, approximate=False), partial(nn.gelu, approximate=True), - nn.relu, nn.softplus, nn.sparse_plus, nn.sigmoid, nn.squareplus, nn.mish))) + nn.relu, nn.identity, nn.softplus, nn.sparse_plus, nn.sigmoid, nn.squareplus, nn.mish))) def testDtypeMatchesInput(self, dtype, fn): x = jnp.zeros((), dtype=dtype) out = fn(x) @@ -632,6 +645,12 @@ def testStandardizeWhereMask(self): self.assertAllClose(out_masked, out_filtered) + def testStandardizeNegativeVariance(self): + # Regression test for https://github.com/google/jax/issues/30426 + x = jnp.array([-11., -11., -11.]) + 3e-6 + result = jax.nn.standardize(x) + self.assertFalse(jnp.any(jnp.isnan(result))) + def testOneHot(self): actual = nn.one_hot(jnp.array([0, 1, 2]), 3) expected = jnp.array([[1., 0., 0.], @@ -731,12 +750,74 @@ def f(hx, _): with jax.checking_leaks(): fwd() # doesn't crash + @parameterized.product( + shape=[(5,), (3, 5), (2, 3, 5)], + use_where=[True, False], + keepdims=[True, False], + ) + def testLogMeanExp(self, shape, use_where, keepdims): + x = self.rng().rand(*shape) * 2 - 1 + axis = self.rng().randint(0, x.ndim) + if use_where: + where = self.rng().randint(0, 2, size=shape).astype(bool) + else: + where = None + got = nn.logmeanexp(x, axis=axis, where=where, keepdims=keepdims) + expected = jnp.log(jnp.mean(jnp.exp(x), axis=axis, where=where, keepdims=keepdims)) + self.assertAllClose(got, expected, atol=1e-3) + + def testLog1mExp(self): + x, expected = jnp.array([ + [0.1, jnp.log(1 - jnp.exp(-0.1))], + [1.1, jnp.log(1 - jnp.exp(-1.1))], + [0, -jnp.inf], + [1, -0.45867515], + [1e2, 0.0], + [1e-5, jnp.log(1e-5)], + [-1, jnp.nan], + [-1e-2, jnp.nan], + [-1e2, jnp.nan], + [jnp.inf, 0.0], + ]).T + got = nn.log1mexp(x) + self.assertAllClose(got, expected, rtol=1e-3, atol=1e-3) + + def testLog1mExpGrad(self): + check_grads( + nn.log1mexp, + (jnp.array([1e-2, 1e-1, 1e0, 1e1, 1e2]),), + order=1, + rtol=1e-2 if jtu.test_device_matches(["tpu"]) else 1e-3, + atol=1e-3, + ) + + def testDotProductAttention_localWindowSizeWithoutMask(self): + dtype = jnp.float32 + B, S, T, N, H = 2, 128, 128, 4, 32 + keys = random.split(random.PRNGKey(0), 3) + Q = random.normal(keys[0], (B, T, N, H), dtype) + K = random.normal(keys[1], (B, S, N, H), dtype) + V = random.normal(keys[2], (B, S, N, H), dtype) + + output_large_window = nn.dot_product_attention( + Q, K, V, mask=None, local_window_size=(32, 32) + ) + + output_small_window = nn.dot_product_attention( + Q, K, V, mask=None, local_window_size=(1, 1) + ) + + self.assertFalse( + jnp.allclose(output_large_window, output_small_window), + "Attention output should differ with different local_window_size, even without a mask.", + ) + InitializerRecord = collections.namedtuple( "InitializerRecord", ["name", "initializer", "shapes", "dtypes"]) -ALL_SHAPES = [(2,), (2, 2), (2, 3), (3, 2), (2, 3, 4), (4, 3, 2), (2, 3, 4, 5)] +ALL_SHAPES = [(), (2,), (2, 2), (2, 3), (3, 2), (2, 3, 4), (4, 3, 2), (2, 3, 4, 5)] def initializer_record(name, initializer, dtypes, min_dims=2, max_dims=4): shapes = [shape for shape in ALL_SHAPES @@ -760,6 +841,24 @@ def initializer_record(name, initializer, dtypes, min_dims=2, max_dims=4): partial(nn.initializers.variance_scaling, 1, "fan_geo_avg", "normal"), jtu.dtypes.floating, ), + initializer_record( + "variance_scaling_fan_in", + partial(nn.initializers.variance_scaling, 1, "fan_in", "normal", in_axis=[0], out_axis=[]), + jtu.dtypes.floating, + min_dims=1, + ), + initializer_record( + "variance_scaling_fan_in", + partial(nn.initializers.variance_scaling, 1, "fan_in", "normal", in_axis=[], out_axis=[0]), + jtu.dtypes.floating, + min_dims=1, + ), + initializer_record( + "variance_scaling_fan_in", + partial(nn.initializers.variance_scaling, 1, "fan_in", "normal", in_axis=[], out_axis=[]), + jtu.dtypes.floating, + min_dims=0, + ), ] @@ -778,7 +877,7 @@ def testInitializer(self, initializer, shape, dtype): val = initializer(rng, shape, dtype) self.assertEqual(shape, jnp.shape(val)) - self.assertEqual(jax.dtypes.canonicalize_dtype(dtype), jnp.dtype(val)) + self.assertEqual(jax.dtypes.canonicalize_dtype(dtype), val.dtype) @parameterized.parameters(itertools.chain.from_iterable( jtu.sample_product_testcases( @@ -794,7 +893,7 @@ def testInitializerProvider(self, initializer_provider, shape, dtype): val = initializer(rng, shape) self.assertEqual(shape, jnp.shape(val)) - self.assertEqual(jax.dtypes.canonicalize_dtype(dtype), jnp.dtype(val)) + self.assertEqual(jax.dtypes.canonicalize_dtype(dtype), val.dtype) def testVarianceScalingMultiAxis(self): rng = random.PRNGKey(0) @@ -824,11 +923,18 @@ def testVarianceScalingError(self): with self.assertRaisesRegex( ValueError, - "Can't compute input and output sizes of a 1" - "-dimensional weights tensor. Must be at least 2D." + "Can't compute input and output sizes of a 1-dimensional" + " weights tensor with default in_axis. Must be at least 2D or specify" + " in_axis explicitly.", ): initializer(rng, shape) + def testIdentity(self): + x = jnp.array([1., 2., 3.]) + self.assertAllClose(nn.identity(x), x, check_dtypes=False) + grad = jax.grad(nn.identity)(6.0) + self.assertEqual(grad, 1.) + def testAccidentalUpcasting(self): rng = random.PRNGKey(0) shape = (4, 4) diff --git a/tests/notebooks/colab_cpu.ipynb b/tests/notebooks/colab_cpu.ipynb index f5dcff837838..2f0c92cbe0a6 100644 --- a/tests/notebooks/colab_cpu.ipynb +++ b/tests/notebooks/colab_cpu.ipynb @@ -54,16 +54,16 @@ "print(jax.__version__)\n", "print(jaxlib.__version__)" ], - "execution_count": 6, + "execution_count": 1, "outputs": [ { + "name": "stdout", "output_type": "stream", "text": [ "m-s-1p12yf76kgzz\n", "0.1.64\n", "0.1.45\n" - ], - "name": "stdout" + ] } ] }, @@ -91,14 +91,15 @@ "execution_count": 2, "outputs": [ { + "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.6/dist-packages/jax/lib/xla_bridge.py:123: UserWarning: No GPU/TPU found, falling back to CPU.\n", " warnings.warn('No GPU/TPU found, falling back to CPU.')\n" - ], - "name": "stderr" + ] }, { + "name": "stdout", "output_type": "stream", "text": [ "JAX device type: cpu:0\n" @@ -106,7 +107,6 @@ } ], "source": [ - "from jaxlib import xla_extension\n", "import jax\n", "key = jax.random.PRNGKey(1701)\n", "arr = jax.random.normal(key, (1000,))\n", @@ -149,11 +149,11 @@ "execution_count": 3, "outputs": [ { + "name": "stdout", "output_type": "stream", "text": [ "1.0216691\n" - ], - "name": "stdout" + ] } ] }, @@ -195,12 +195,12 @@ "execution_count": 4, "outputs": [ { + "name": "stdout", "output_type": "stream", "text": [ "[6.9178133 5.9580317 5.581113 4.506963 4.111582 3.973543 3.3307292\n", " 2.8664916 1.8229378 1.5478933]\n" - ], - "name": "stdout" + ] } ] }, @@ -236,12 +236,12 @@ "execution_count": 5, "outputs": [ { + "name": "stdout", "output_type": "stream", "text": [ "[ 0.34676832 -0.7532232 1.7060695 ... 2.1208048 -0.42621925\n", " 0.13093236]\n" - ], - "name": "stdout" + ] } ] } diff --git a/tests/optimizers_test.py b/tests/optimizers_test.py index 2e027e6150be..724af37bc062 100644 --- a/tests/optimizers_test.py +++ b/tests/optimizers_test.py @@ -41,7 +41,7 @@ def _CheckFuns(self, optimizer, loss, x0, *args): self.assertEqual(jax.tree.structure(opt_state), jax.tree.structure(opt_state2)) - @jtu.skip_on_devices('gpu') + @jtu.skip_on_devices('cuda') def _CheckRun(self, optimizer, loss, x0, num_steps, *args, **kwargs): init_fun, update_fun, get_params = optimizer(*args) diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 987a3aa9d50a..8741a0a1344f 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -14,6 +14,7 @@ load( "//jaxlib:jax.bzl", + "if_oss", "jax_generate_backend_suites", "jax_gpu_support_deps", "jax_multiplatform_test", @@ -30,6 +31,11 @@ package( jax_generate_backend_suites() +test_suite( + name = "mosaic_gpu_tests", + tags = ["mosaic_gpu_test"], +) + jax_multiplatform_test( name = "pallas_test", srcs = [ @@ -42,6 +48,7 @@ jax_multiplatform_test( enable_configs = [ "gpu_a100", "gpu_h100", + "gpu_b200", ], shard_count = { "cpu": 8, @@ -49,46 +56,32 @@ jax_multiplatform_test( "tpu": 4, }, deps = [ - "//jax:pallas", - "//jax:pallas_gpu", - "//jax:pallas_gpu_ops", - "//jax:pallas_tpu", - "//jax:pallas_tpu_ops", - ] + py_deps("absl/testing") + py_deps("numpy"), + "//jax/_src/pallas:pallas_test_util", + "//jax/experimental:pallas", + "//jax/experimental:pallas_gpu", + "//jax/experimental:pallas_gpu_ops", + "//jax/experimental:pallas_mosaic_gpu", + "//jax/experimental:pallas_tpu", + "//jax/experimental:pallas_tpu_ops", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) -jax_multiplatform_test( +jax_py_test( name = "pallas_cost_estimate_test", srcs = [ "pallas_cost_estimate_test.py", ], + args = ["--jax_test_dut=cpu"], deps = [ - "//jax:pallas", - "//jax:pallas_gpu", - "//jax:pallas_gpu_ops", - "//jax:pallas_tpu", - "//jax:pallas_tpu_ops", - ] + py_deps("absl/testing") + py_deps("numpy"), -) - -jax_multiplatform_test( - name = "pallas_jumble_test", - srcs = [ - "pallas_jumble_test.py", - ], - disable_configs = [ - "gpu_v100", - "gpu_v100_x32", - "gpu_a100", - "gpu_p100", - "gpu_p100_x32", - "gpu_h100", - ], - deps = [ - "//jax:pallas", - "//jax:pallas_tpu", - "//jax:pallas_tpu_ops", - ] + py_deps("absl/testing") + py_deps("numpy"), + "//jax/_src:test_util", + "//jax/experimental:pallas", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -107,11 +100,18 @@ jax_multiplatform_test( "gpu_a100_x32", "gpu_h100", "gpu_h100_x32", - "tpu_v6e_1x1", + "gpu_b200", + "tpu_v6e", ], + env = { + "JAX_PALLAS_USE_MOSAIC_GPU": "0", + }, + minimal_shard_count = { + "tpu": 8, + }, shard_count = { "cpu": 16, - "gpu": 16, + "gpu": 32, "tpu": 16, }, tags = [ @@ -120,11 +120,17 @@ jax_multiplatform_test( "notsan", # Times out. ], deps = [ - "//jax:pallas", - "//jax:pallas_gpu", # build_cleaner: keep - "//jax:pallas_tpu", - "//jax:pallas_tpu_ops", - ] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"), + "//jax/_src/pallas:pallas_test_util", + "//jax/experimental:pallas", + "//jax/experimental:pallas_gpu", # build_cleaner: keep + "//jax/experimental:pallas_tpu", + "//jax/experimental:pallas_tpu_ops", + ] + py_deps([ + "absl/testing:flagsaver", + "absl/testing", + "hypothesis", + "numpy", + ]), ) jax_multiplatform_test( @@ -146,22 +152,30 @@ jax_multiplatform_test( enable_configs = [ "gpu_h100", "gpu_h100_x32", + "gpu_b200", ], env = { "JAX_PALLAS_USE_MOSAIC_GPU": "1", - "JAX_PALLAS_VERBOSE_ERRORS": "0", }, + shard_count = 16, tags = [ + "mosaic_gpu_test", "noasan", # Times out. "nomsan", # Times out. "notsan", # Times out. ], deps = [ - "//jax:pallas", - "//jax:pallas_gpu", # build_cleaner: keep - "//jax:pallas_mosaic_gpu", # build_cleaner: keep - "//jax:pallas_tpu", - ] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"), + "//jax/_src/pallas:pallas_test_util", + "//jax/experimental:pallas", + "//jax/experimental:pallas_gpu", # build_cleaner: keep + "//jax/experimental:pallas_mosaic_gpu", # build_cleaner: keep + "//jax/experimental:pallas_tpu", + ] + py_deps([ + "absl/testing:flagsaver", + "absl/testing", + "hypothesis", + "numpy", + ]), ) jax_multiplatform_test( @@ -179,9 +193,13 @@ jax_multiplatform_test( "notsan", # Times out. ], deps = [ - "//jax:pallas", - "//jax:pallas_tpu", - ] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"), + "//jax/experimental:pallas", + "//jax/experimental:pallas_tpu", + ] + py_deps([ + "absl/testing", + "hypothesis", + "numpy", + ]), ) jax_multiplatform_test( @@ -193,15 +211,19 @@ jax_multiplatform_test( enable_configs = [ "gpu_a100_x32", "gpu_h100_x32", + "gpu_b200", ], shard_count = 4, deps = [ - "//jax:pallas", - "//jax:pallas_gpu", - "//jax:pallas_gpu_ops", - "//jax:pallas_tpu", - "//jax:pallas_tpu_ops", - ] + py_deps("absl/testing") + py_deps("numpy"), + "//jax/experimental:pallas", + "//jax/experimental:pallas_gpu", + "//jax/experimental:pallas_gpu_ops", + "//jax/experimental:pallas_tpu", + "//jax/experimental:pallas_tpu_ops", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -209,19 +231,29 @@ jax_multiplatform_test( srcs = [ "mosaic_gpu_test.py", ], + disable_configs = [ + # TODO(b/462499936): Re-enable when test passes on MIG partition. + "gpu_b200", + ], enable_backends = [], enable_configs = [ "gpu_h100_x32", "gpu_h100", + # TODO(b/462499936): Remove gpu_b200_full when test passes onMIG partition. + "gpu_b200_full", + ], + shard_count = 4, + tags = [ + "mosaic_gpu_test", + "notsan", # Times out. ], - env = { - "JAX_PALLAS_USE_MOSAIC_GPU": "1", - "JAX_PALLAS_VERBOSE_ERRORS": "0", - }, deps = [ - "//jax:pallas", - "//jax:pallas_mosaic_gpu", # build_cleaner: keep - ] + py_deps("absl/testing") + py_deps("numpy"), + "//jax/experimental:pallas", + "//jax/experimental:pallas_mosaic_gpu", # build_cleaner: keep + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -231,15 +263,17 @@ jax_multiplatform_test( enable_configs = [ "gpu_a100_x32", "gpu_h100_x32", + "gpu_b200", ], tags = [], deps = [ - "//jax:internal_export_back_compat_test_data", - "//jax:internal_export_back_compat_test_util", - "//jax:pallas", - "//jax:pallas_gpu", # build_cleaner: keep - "//jax:pallas_tpu_ops", # build_cleaner: keep - ], + "//jax/_src:internal_export_back_compat_test_data", + "//jax/_src:internal_export_back_compat_test_util", + "//jax/experimental:pallas", + "//jax/experimental:pallas_gpu", # build_cleaner: keep + "//jax/experimental:pallas_mosaic_gpu", # build_cleaner: keep + "//jax/experimental:pallas_tpu_ops", # build_cleaner: keep + ] + py_deps("absl/testing"), ) jax_py_test( @@ -248,11 +282,15 @@ jax_py_test( args = ["--jax_test_dut=cpu"], main = "export_pallas_test.py", deps = [ - "//jax:pallas", - "//jax:pallas_gpu", # build_cleaner: keep - "//jax:pallas_tpu", # build_cleaner: keep - "//jax:test_util", - ] + jax_gpu_support_deps, + "//jax/_src:test_util", + "//jax/experimental:pallas", + "//jax/experimental:pallas_gpu", # build_cleaner: keep + "//jax/experimental:pallas_mosaic_gpu", # build_cleaner: keep + "//jax/experimental:pallas_tpu", # build_cleaner: keep + ] + jax_gpu_support_deps + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -261,7 +299,6 @@ jax_multiplatform_test( # Cross-compilation on CPU is tested separately. disable_configs = [ "cpu", - "cpu_shardy", "cpu_x32", ], enable_configs = [ @@ -269,10 +306,14 @@ jax_multiplatform_test( ], tags = [], deps = [ - "//jax:pallas", - "//jax:pallas_gpu", # build_cleaner: keep - "//jax:pallas_tpu", # build_cleaner: keep - ], + "//jax/experimental:pallas", + "//jax/experimental:pallas_gpu", # build_cleaner: keep + "//jax/experimental:pallas_mosaic_gpu", # build_cleaner: keep + "//jax/experimental:pallas_tpu", # build_cleaner: keep + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -283,17 +324,19 @@ jax_multiplatform_test( "gpu_p100", "gpu_p100_x32", "gpu_v100_x32", - "gpu_p100_pjrt_c_api", ], enable_configs = [ "gpu_a100_x32", ], tags = [], deps = [ - "//jax:pallas", - "//jax:pallas_gpu", # build_cleaner: keep - "//jax:pallas_tpu", # build_cleaner: keep - ], + "//jax/experimental:pallas", + "//jax/experimental:pallas_gpu", # build_cleaner: keep + "//jax/experimental:pallas_tpu", # build_cleaner: keep + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -303,10 +346,13 @@ jax_multiplatform_test( ], enable_backends = ["tpu"], deps = [ - "//jax:pallas", - "//jax:pallas_tpu", "//jax/_src/pallas/mosaic:random", - ] + py_deps("absl/testing") + py_deps("numpy"), + "//jax/experimental:pallas", + "//jax/experimental:pallas_tpu", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -316,11 +362,16 @@ jax_multiplatform_test( ], enable_backends = [], enable_configs = [ - "tpu_v5e_4x2", + "tpu_v5e_x8", ], deps = [ - "//jax:pallas_tpu_ops", - ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"), + "//jax/experimental:mesh_utils", + "//jax/experimental:pallas_tpu_ops", + ] + py_deps([ + "absl/testing", + "numpy", + "hypothesis", + ]), ) jax_multiplatform_test( @@ -329,14 +380,16 @@ jax_multiplatform_test( "tpu_gmm_test.py", ], enable_backends = ["tpu"], - shard_count = 50, + minimal_shard_count = 5, + shard_count = 5, tags = [ "noasan", # Times out. "nomsan", # Times out. + "notap", # TODO(b/455849275): Timing out. "notsan", # Times out. ], deps = [ - "//jax:pallas_tpu_ops", + "//jax/experimental:pallas_tpu_ops", ] + py_deps([ "absl/testing", "absl/flags", @@ -348,18 +401,85 @@ jax_multiplatform_test( jax_multiplatform_test( name = "tpu_pallas_test", srcs = ["tpu_pallas_test.py"], + enable_backends = [ + "tpu", + "cpu", + ], + enable_configs = [ + "tpu_v5e", + "tpu_v5p", + ], + shard_count = 8, + deps = [ + "//jax/_src/pallas:pallas_test_util", + "//jax/experimental:mesh_utils", + "//jax/experimental:pallas_tpu", + "//jax/experimental:pallas_tpu_ops", + "//jax/extend", + ] + py_deps([ + "absl/testing", + "numpy", + ]), +) + +jax_multiplatform_test( + name = "tpu_pallas_call_print_test", + srcs = ["tpu_pallas_call_print_test.py"], # The flag is necessary for ``pl.debug_print`` tests to work on TPU. args = ["--logtostderr"], enable_backends = ["tpu"], enable_configs = [ "tpu_v5e", - "tpu_v5p_1x1", - ], + "tpu_v5p", + ], + env = if_oss( + { + "TPU_STDERR_LOG_LEVEL": "0", + }, + google_value = {}, + ), + shard_count = 8, deps = [ - "//jax:extend", - "//jax:pallas_tpu", - "//jax:pallas_tpu_ops", + "//jax/_src/pallas:pallas_test_util", + "//jax/experimental:pallas_tpu", + "//jax/experimental:pallas_tpu_ops", + "//jax/extend", + ] + py_deps([ + "absl/testing", + "numpy", + ]), +) + +jax_multiplatform_test( + name = "gpu_pallas_distributed_test", + srcs = ["gpu_pallas_distributed_test.py"], + args = [ + "--num_processes=2", + "--gpus_per_process=1", + ], + enable_backends = [], + enable_configs = [ + "gpu_h100x2", + "gpu_b200x2", + ], + env = { + "JAX_PALLAS_USE_MOSAIC_GPU": "1", + "XLA_FLAGS": "--xla_gpu_experimental_enable_nvshmem=true", + }, + tags = [ + "mosaic_gpu_test", + "multiaccelerator", ], + deps = [ + "//jax/_src:test_multiprocess", + "//jax/experimental:pallas_experimental_gpu_ops", + "//jax/experimental:pallas_mosaic_gpu", + "//jax/extend", + ] + py_deps([ + "portpicker", + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -367,43 +487,59 @@ jax_multiplatform_test( srcs = [ "tpu_ops_test.py", ], - enable_backends = [ - "cpu", - "tpu", - ], + enable_backends = ["tpu"], + shard_count = 8, deps = [ - "//jax:pallas", - "//jax:pallas_gpu", # build_cleaner: keep - "//jax:pallas_tpu", - "//jax:pallas_tpu_ops", - ] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"), + "//jax/_src/pallas:pallas_test_util", + "//jax/experimental:pallas", + "//jax/experimental:pallas_gpu", # build_cleaner: keep + "//jax/experimental:pallas_tpu", + "//jax/experimental:pallas_tpu_ops", + ] + py_deps([ + "absl/testing", + "hypothesis", + "numpy", + ]), ) jax_multiplatform_test( name = "tpu_pallas_distributed_test", srcs = ["tpu_pallas_distributed_test.py"], - enable_backends = ["tpu"], - enable_configs = [ - "tpu_v5e_4x2", - "tpu_v5p_2x2", - "tpu_v4_2x2", - "tpu_v3_2x2", + # TODO(sharadmv): fix timeouts + disable_configs = [ + "tpu_7x", + "tpu_7x_x4", ], + enable_backends = ["tpu"], + tags = ["multiaccelerator"], deps = [ - "//jax:extend", - "//jax:pallas_tpu", - "//jax:pallas_tpu_ops", - ], + "//jax/experimental:mesh_utils", + "//jax/experimental:pallas_tpu", + "//jax/experimental:pallas_tpu_ops", + "//jax/extend", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "tpu_pallas_pipeline_test", srcs = ["tpu_pallas_pipeline_test.py"], + args = if_oss( + [], + google_value = [ + # Timeout on deadlocks for a better error message. + "--xla_tpu_debug_sflag_wait_timeout_ms=500", + "--xla_tpu_debug_sflag_wait_shalt_on_detection", + ], + ), enable_backends = ["tpu"], enable_configs = [ - "tpu_v5e_4x2", - "tpu_v5p_1x1", + "tpu_v5e_x8", + "tpu_v5p", ], + minimal_shard_count = 5, shard_count = 5, tags = [ "noasan", # Times out. @@ -411,10 +547,16 @@ jax_multiplatform_test( "notsan", # Times out. ], deps = [ - "//jax:extend", - "//jax:pallas_tpu", - "//jax:pallas_tpu_ops", - ] + py_deps("hypothesis"), + "//jax/experimental:hijax", + "//jax/experimental:mesh_utils", + "//jax/experimental:pallas_tpu", + "//jax/experimental:pallas_tpu_ops", + "//jax/extend", + ] + py_deps([ + "absl/testing", + "numpy", + "hypothesis", + ]), ) jax_multiplatform_test( @@ -422,12 +564,44 @@ jax_multiplatform_test( srcs = ["tpu_pallas_async_test.py"], enable_backends = ["tpu"], enable_configs = [ - "tpu_v5e_4x2", - "tpu_v5p_1x1", + "tpu_v5e_x8", + "tpu_v5p", + "tpu_v5p_x4", ], deps = [ - "//jax:pallas_tpu", + "//jax/experimental:pallas_tpu", + ] + py_deps([ + "absl/testing", + "numpy", + ]), +) + +jax_multiplatform_test( + name = "tpu_pallas_memory_space_test", + srcs = ["tpu_pallas_memory_space_test.py"], + enable_backends = ["tpu"], + enable_configs = [ + "tpu_v5p", ], + deps = [ + "//jax/experimental:pallas_tpu", + ] + py_deps([ + "absl/testing", + "numpy", + ]), +) + +jax_multiplatform_test( + name = "tpu_side_effects_test", + srcs = ["tpu_side_effects_test.py"], + enable_backends = ["tpu"], + deps = [ + "//jax/experimental:pallas", + "//jax/experimental:pallas_tpu", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -440,9 +614,12 @@ jax_multiplatform_test( "notsan", ], deps = [ - "//jax:extend", - "//jax:pallas_tpu", - ], + "//jax/experimental:pallas_tpu", + "//jax/extend", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -452,14 +629,17 @@ jax_multiplatform_test( ], enable_backends = ["tpu"], enable_configs = [ - "tpu_v5p_2x2", + "tpu_v5p_x4", ], deps = [ - "//jax:pallas", - "//jax:pallas_tpu", - "//jax:pallas_tpu_ops", "//jax/_src/pallas/mosaic:random", - ] + py_deps("absl/testing") + py_deps("numpy"), + "//jax/experimental:pallas", + "//jax/experimental:pallas_tpu", + "//jax/experimental:pallas_tpu_ops", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -467,12 +647,30 @@ jax_multiplatform_test( srcs = [ "tpu_pallas_interpret_test.py", ], - disable_configs = ["cpu_shardy"], enable_backends = ["cpu"], deps = [ - "//jax:pallas", - "//jax:pallas_tpu", - ] + py_deps("absl/testing") + py_deps("numpy"), + "//jax/experimental:pallas", + "//jax/experimental:pallas_tpu", + ] + py_deps([ + "absl/testing", + "numpy", + ]), +) + +jax_multiplatform_test( + name = "tpu_pallas_interpret_thread_map_test", + srcs = [ + "tpu_pallas_interpret_thread_map_test.py", + ], + enable_backends = ["cpu"], + deps = [ + "//jax/experimental", + "//jax/experimental:pallas", + "//jax/experimental:pallas_tpu", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -480,39 +678,46 @@ jax_multiplatform_test( srcs = [ "tpu_pallas_interpret_distributed_test.py", ], - disable_configs = ["cpu_shardy"], enable_backends = ["cpu"], deps = [ - "//jax:pallas", - "//jax:pallas_tpu", - ] + py_deps("absl/testing") + py_deps("numpy"), + "//jax/experimental:pallas", + "//jax/experimental:pallas_tpu", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "tpu_paged_attention_kernel_test", srcs = ["tpu_paged_attention_kernel_test.py"], disable_configs = [ - "tpu_v5p_1x1", + "tpu_v5p", ], enable_backends = ["tpu"], - shard_count = 5, + minimal_shard_count = 10, + shard_count = 10, tags = [ "noasan", # Times out. "nomsan", # Times out. "notsan", # Times out. ], deps = [ - "//jax:pallas_tpu_ops", - ] + py_deps("absl/testing") + py_deps("numpy"), + "//jax/experimental:pallas_tpu_ops", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "tpu_ragged_paged_attention_test", srcs = ["tpu_ragged_paged_attention_test.py"], disable_configs = [ - "tpu_v5p_1x1", + "tpu_v5p", ], enable_backends = ["tpu"], + minimal_shard_count = 8, shard_count = 24, tags = [ "noasan", # Times out. @@ -520,8 +725,11 @@ jax_multiplatform_test( "notsan", # Times out. ], deps = [ - "//jax:pallas_tpu_ops", - ] + py_deps("absl/testing") + py_deps("numpy"), + "//jax/experimental:pallas_tpu_ops", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -529,7 +737,11 @@ jax_multiplatform_test( srcs = [ "tpu_splash_attention_kernel_test.py", ], + disable_configs = [ + "tpu_7x_x4", # not a multi-chip test + ], enable_backends = ["tpu"], + minimal_shard_count = 8, shard_count = 24, tags = [ "noasan", # Times out. @@ -537,8 +749,32 @@ jax_multiplatform_test( "notsan", # Times out. ], deps = [ - "//jax:pallas_tpu_ops", - ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"), + "//jax/_src/pallas:pallas_test_util", + "//jax/experimental:pallas_tpu_ops", + ] + py_deps([ + "absl/testing", + "numpy", + "hypothesis", + ]), +) + +jax_multiplatform_test( + name = "tpu_splash_attention_kernel_sharded_test", + srcs = ["tpu_splash_attention_kernel_sharded_test.py"], + enable_configs = [ + "tpu_v5e_x8", + "tpu_v5p_x4", + ], + shard_count = 10, + deps = [ + "//jax/_src/pallas:pallas_test_util", + "//jax/experimental:pallas_tpu", + "//jax/experimental:pallas_tpu_ops", + "//jax/extend", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) # This test doesn't need a TPU; it only tests numpy-using helpers. @@ -549,9 +785,13 @@ jax_py_test( ], deps = [ "//jax", - "//jax:pallas_tpu_ops", - "//jax:test_util", - ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"), + "//jax/_src:test_util", + "//jax/experimental:pallas_tpu_ops", + ] + py_deps([ + "absl/testing", + "numpy", + "hypothesis", + ]), ) jax_multiplatform_test( @@ -559,17 +799,24 @@ jax_multiplatform_test( srcs = [ "gpu_attention_test.py", ], - enable_backends = ["cpu"], + enable_backends = [], enable_configs = [ "gpu_a100_x32", "gpu_h100_x32", + "gpu_b200", ], shard_count = 1, + tags = [ + "noasan", # Times out. + ], deps = [ - "//jax:pallas", - "//jax:pallas_gpu", # build_cleaner: keep - "//jax:pallas_gpu_ops", - ] + py_deps("absl/testing") + py_deps("numpy"), + "//jax/experimental:pallas", + "//jax/experimental:pallas_gpu", # build_cleaner: keep + "//jax/experimental:pallas_gpu_ops", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -581,6 +828,7 @@ jax_multiplatform_test( enable_configs = [ "gpu_a100_x32", "gpu_h100_x32", + "gpu_b200", ], shard_count = 20, tags = [ @@ -589,10 +837,13 @@ jax_multiplatform_test( "notsan", # Times out. ], deps = [ - "//jax:pallas", - "//jax:pallas_gpu", - "//jax:pallas_gpu_ops", - ] + py_deps("absl/testing") + py_deps("numpy"), + "//jax/experimental:pallas", + "//jax/experimental:pallas_gpu", + "//jax/experimental:pallas_gpu_ops", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -604,29 +855,65 @@ jax_multiplatform_test( enable_configs = [ "gpu_a100_x32", "gpu_h100_x32", + "gpu_b200", ], - shard_count = 6, - deps = [ - "//jax:pallas", - "//jax:pallas_gpu", - "//jax:pallas_gpu_ops", - ] + py_deps("absl/testing") + py_deps("numpy"), + shard_count = 12, + tags = [ + "noasan", # Times out. + "notsan", # Times out. + ], + deps = [ + "//jax/experimental:pallas", + "//jax/experimental:pallas_gpu", + "//jax/experimental:pallas_gpu_ops", + ] + py_deps([ + "absl/testing", + "numpy", + ]), +) + +jax_multiplatform_test( + name = "triton_pallas_test", + srcs = [ + "triton_pallas_test.py", + ], + enable_backends = ["cpu"], + enable_configs = [ + "gpu_h100_x32", + "gpu_b200", + ], + env = { + "JAX_PALLAS_USE_MOSAIC_GPU": "0", + }, + shard_count = 1, + deps = [ + "//jax/experimental:pallas", + "//jax/experimental:pallas_gpu", + ] + py_deps([ + "absl/testing", + ]), ) jax_multiplatform_test( name = "mgpu_attention_run", srcs = ["//jax/experimental/pallas/ops/gpu:attention_mgpu.py"], enable_backends = [], - enable_configs = ["gpu_h100_x32"], + enable_configs = [ + "gpu_h100_x32", + "gpu_b200", + ], env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"}, tags = [ "manual", "notap", ], deps = [ - "//jax:pallas", - "//jax:pallas_mosaic_gpu", - ] + py_deps("absl/testing") + py_deps("numpy"), + "//jax/experimental:pallas", + "//jax/experimental:pallas_mosaic_gpu", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -636,15 +923,349 @@ jax_multiplatform_test( enable_configs = [ "gpu_h100_x32", "gpu_h100", + "gpu_b200", + ], + env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"}, + shard_count = 8, + tags = [ + "mosaic_gpu_test", + ], + deps = [ + "//jax/experimental:pallas", + "//jax/experimental:pallas_experimental_gpu_ops", + "//jax/experimental:pallas_mosaic_gpu", + ] + py_deps([ + "absl/testing", + "numpy", + ]), +) + +jax_multiplatform_test( + name = "mgpu_examples_test_b200", + srcs = ["mgpu_examples_test.py"], + enable_backends = [], + enable_configs = ["gpu_b200"], + # env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"}, + shard_count = 8, + tags = [ + "mosaic_gpu_test", + # TODO(b/330364373): Remove when B200 is fully supported. + "notap", + ], + deps = [ + "//jax/experimental:pallas", + "//jax/experimental:pallas_experimental_gpu_ops", + "//jax/experimental:pallas_mosaic_gpu", + ] + py_deps([ + "absl/testing", + "numpy", + ]), +) + +jax_multiplatform_test( + name = "mgpu_matmul_test", + srcs = ["mgpu_matmul_test.py"], + enable_backends = [], + enable_configs = [ + "gpu_h100", + "gpu_b200", + ], + env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"}, + shard_count = 8, + tags = ["mosaic_gpu_test"], + deps = [ + "//jax/experimental:pallas", + "//jax/experimental:pallas_experimental_gpu_ops", + "//jax/experimental:pallas_mosaic_gpu", + ] + py_deps([ + "absl/testing", + "numpy", + ]), +) + +jax_multiplatform_test( + name = "hopper_matmul_mgpu_run", + srcs = ["//jax/experimental/pallas/ops/gpu:hopper_matmul_mgpu.py"], + enable_backends = [], + enable_configs = ["gpu_h100"], + env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"}, + tags = [ + "manual", + "notap", ], + deps = [ + "//jax/experimental:pallas", + "//jax/experimental:pallas_mosaic_gpu", + ] + py_deps([ + "absl/testing", + "numpy", + ]), +) + +jax_multiplatform_test( + name = "hopper_mixed_type_matmul_mgpu_run", + srcs = ["//jax/experimental/pallas/ops/gpu:hopper_mixed_type_matmul_mgpu.py"], + enable_backends = [], + enable_configs = ["gpu_h100"], env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"}, + tags = [ + "manual", + "notap", + ], + deps = [ + "//jax/experimental:pallas", + "//jax/experimental:pallas_mosaic_gpu", + ] + py_deps([ + "absl/testing", + "numpy", + ]), +) + +jax_multiplatform_test( + name = "blackwell_matmul_mgpu_run", + srcs = ["//jax/experimental/pallas/ops/gpu:blackwell_matmul_mgpu.py"], + enable_backends = [], + enable_configs = ["gpu_b200"], + tags = [ + "manual", + "notap", + ], + deps = [ + "//jax/experimental:pallas", + "//jax/experimental:pallas_mosaic_gpu", + ] + py_deps([ + "absl/testing", + "numpy", + ]), +) + +jax_multiplatform_test( + name = "blackwell_ragged_dot_mgpu_run", + srcs = [ + "//jax/experimental/pallas/ops/gpu:blackwell_matmul_mgpu.py", + "//jax/experimental/pallas/ops/gpu:blackwell_ragged_dot_mgpu.py", + "//jax/experimental/pallas/ops/gpu:ragged_dot_mgpu.py", + ], + enable_backends = [], + enable_configs = ["gpu_b200"], + env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"}, + main = "//jax/experimental/pallas/ops/gpu:blackwell_ragged_dot_mgpu.py", + tags = [ + "manual", + "notap", + ], + deps = [ + "//jax/experimental:pallas", + "//jax/experimental:pallas_mosaic_gpu", + ] + py_deps([ + "absl/testing", + "numpy", + ]), +) + +jax_multiplatform_test( + name = "mgpu_ragged_dot_run", + srcs = ["//jax/experimental/pallas/ops/gpu:ragged_dot_mgpu.py"], + enable_backends = [], + enable_configs = [ + "gpu_h100_x32", + "gpu_h100", + "gpu_b200", + ], + env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"}, + tags = [ + "manual", + "notap", + ], + deps = [ + "//jax/experimental:pallas", + "//jax/experimental:pallas_mosaic_gpu", + ] + py_deps("absl/testing") + py_deps("numpy"), +) + +jax_multiplatform_test( + name = "mgpu_transposed_ragged_dot_run", + srcs = [ + "//jax/experimental/pallas/ops/gpu:transposed_ragged_dot_mgpu.py", + ], + enable_backends = [], + enable_configs = [ + "gpu_h100_x32", + "gpu_h100", + ], + env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"}, + tags = [ + "manual", + "notap", + ], + deps = [ + "//jax/experimental:pallas", + "//jax/experimental:pallas_mosaic_gpu", + ] + py_deps("absl/testing") + py_deps("numpy"), +) + +jax_multiplatform_test( + name = "mgpu_ragged_dot_test", + size = "large", # Increased timout to account for the extra tests. + srcs = ["mgpu_ragged_dot_test.py"], + enable_backends = [], + enable_configs = [ + "gpu_h100", + "gpu_b200", + ], + env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"}, + shard_count = 20, + tags = [ + "mosaic_gpu_test", + "noasan", # Times out. + ], + deps = [ + "//jax/experimental:pallas", + "//jax/experimental:pallas_experimental_gpu_ops", + "//jax/experimental:pallas_mosaic_gpu", + ] + py_deps([ + "absl/testing", + "numpy", + ]), +) + +jax_multiplatform_test( + name = "mgpu_collective_matmul_test", + srcs = ["mgpu_collective_matmul_test.py"], + args = [ + "--num_processes=2", + "--gpus_per_process=1", + ], + enable_backends = [], + enable_configs = [ + "gpu_h100x2", + "gpu_b200x2", + ], + env = { + "XLA_FLAGS": "--xla_gpu_experimental_enable_nvshmem=true", + "JAX_PALLAS_USE_MOSAIC_GPU": "1", + "FAIL_ON_NVSHMEM_UNAVAILABLE": "", + }, + shard_count = 8, + tags = [ + "manual", + "multiaccelerator", + "notap", + ], deps = [ - "//jax:pallas", - "//jax:pallas_experimental_gpu_ops", - "//jax:pallas_mosaic_gpu", + "//jax/_src:test_multiprocess", + "//jax/experimental:pallas", + "//jax/experimental:pallas_experimental_gpu_ops", + "//jax/experimental:pallas_mosaic_gpu", ] + py_deps("absl/testing") + py_deps("numpy"), ) +jax_multiplatform_test( + name = "mgpu_torch_test", + srcs = ["mgpu_torch_test.py"], + enable_backends = [], + enable_configs = [ + "gpu_h100_x32", + ], + env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"}, + tags = [ + "noasan", + ], + deps = [ + "//jax/experimental:pallas", + "//jax/experimental:pallas_experimental_gpu_ops", + "//jax/experimental:pallas_mosaic_gpu", + ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("torch"), +) + +jax_multiplatform_test( + name = "mgpu_collective_matmul_run", + srcs = ["//jax/experimental/pallas/ops/gpu:collective_matmul_mgpu.py"], + args = [ + "--num_processes=2", + "--gpus_per_process=1", + ], + enable_backends = [], + enable_configs = [ + "gpu_h100x2", + ], + env = { + "XLA_FLAGS": "--xla_gpu_experimental_enable_nvshmem=true", + "JAX_PALLAS_USE_MOSAIC_GPU": "1", + }, + tags = [ + "manual", + "multiaccelerator", + "notap", + ], + deps = [ + "//jax/_src:test_multiprocess", + "//jax/experimental:pallas", + "//jax/experimental:pallas_experimental_gpu_ops", + "//jax/experimental:pallas_mosaic_gpu", + ], +) + +jax_multiplatform_test( + name = "mgpu_reduce_scatter_run", + srcs = ["//jax/experimental/pallas/ops/gpu:reduce_scatter_mgpu.py"], + args = [ + "--num_processes=2", + "--gpus_per_process=1", + ], + enable_backends = [], + enable_configs = [ + "gpu_h100x2", + ], + env = { + "XLA_FLAGS": "--xla_gpu_experimental_enable_nvshmem=true", + "JAX_PALLAS_USE_MOSAIC_GPU": "1", + }, + tags = [ + "manual", + "multiaccelerator", + "notap", + ], + deps = [ + "//jax/_src:test_multiprocess", + "//jax/experimental:pallas", + "//jax/experimental:pallas_experimental_gpu_ops", + "//jax/experimental:pallas_mosaic_gpu", + ], +) + +jax_multiplatform_test( + name = "mgpu_all_gather_run", + srcs = ["//jax/experimental/pallas/ops/gpu:all_gather_mgpu.py"], + args = [ + "--num_processes=2", + "--gpus_per_process=1", + ], + enable_backends = [], + enable_configs = [ + "gpu_h100x2", + ], + env = { + "XLA_FLAGS": "--xla_gpu_experimental_enable_nvshmem=true", + "JAX_PALLAS_USE_MOSAIC_GPU": "1", + }, + tags = [ + "manual", + "mosaic_gpu_test", + "multiaccelerator", + ], + deps = [ + "//jax/_src:test_multiprocess", + "//jax/experimental:pallas_experimental_gpu_ops", + "//jax/experimental:pallas_mosaic_gpu", + "//jax/extend", + ] + py_deps([ + "portpicker", + "absl/testing", + "numpy", + ]), +) + jax_multiplatform_test( name = "fuser_block_spec_test", srcs = [ @@ -652,7 +1273,6 @@ jax_multiplatform_test( ], disable_configs = [ "cpu", - "cpu_shardy", ], enable_backends = ["cpu"], tags = [ @@ -661,34 +1281,44 @@ jax_multiplatform_test( "notsan", ], deps = [ - "//jax:pallas", "//jax/_src/pallas/fuser", - ] + py_deps("absl/testing") + py_deps("numpy"), + "//jax/experimental:pallas", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( - name = "tpu_fusable_matmul_test", - srcs = ["tpu_fusable_matmul_test.py"], + name = "fusion_test", + srcs = [ + "fusion_test.py", + ], disable_configs = [ - "tpu_v3_1x1", - "tpu_pjrt_c_api", - "gpu_v100", - "gpu_v100_x32", - "gpu_a100", - "gpu_p100", - "gpu_p100_x32", - "gpu_h100", "cpu", - "cpu_x32", - "cpu_shardy", ], - enable_backends = ["tpu"], - enable_configs = [ - "tpu_v4_1x1", - "tpu_v5e", - "tpu_v5p_1x1", - "tpu_v6e_1x1", + enable_backends = ["cpu"], + tags = [ + "noasan", + "nomsan", + "notsan", + ], + deps = [ + "//jax/experimental:pallas", + "//jax/experimental:pallas_fuser", + ] + py_deps([ + "absl/testing", + "numpy", + ]), +) + +jax_multiplatform_test( + name = "tpu_fusible_matmul_test", + srcs = ["tpu_fusible_matmul_test.py"], + disable_configs = [ + "tpu_v3", ], + enable_backends = ["tpu"], shard_count = 4, tags = [ "noasan", @@ -696,8 +1326,124 @@ jax_multiplatform_test( "notsan", ], deps = [ - "//jax:pallas_tpu", - "//jax:pallas_tpu_ops", - "//jax/_src/pallas/fuser", - ] + py_deps("absl/testing") + py_deps("numpy"), + "//jax/experimental:pallas_fuser", + "//jax/experimental:pallas_tpu", + "//jax/experimental:pallas_tpu_ops", + ] + py_deps([ + "absl/testing", + "numpy", + ]), +) + +jax_multiplatform_test( + name = "tpu_sparsecore_pallas_test", + srcs = ["tpu_sparsecore_pallas_test.py"], + config_tags_overrides = { + # TODO(slebedev): Flip to False once we have more v5p capacity. + "tpu_v5p": {"ondemand": True}, + }, + enable_backends = [], + enable_configs = [ + "tpu_v5p", + "tpu_v6e", + ], + shard_count = 10, + deps = [ + "//jax/experimental:pallas", + "//jax/experimental:pallas_tpu", + "//jax/experimental:pallas_tpu_sc", + ] + py_deps([ + "absl/testing", + "hypothesis", + "numpy", + ]), +) + +jax_multiplatform_test( + name = "tpu_sparsecore_pallas_debug_check_test", + srcs = ["tpu_sparsecore_pallas_debug_check_test.py"], + config_tags_overrides = { + # TODO(slebedev): Flip to False once we have more v6e capacity. + "tpu_v6e": {"ondemand": True}, + }, + enable_backends = [], + enable_configs = [ + "tpu_v5p", + "tpu_v6e", + "cpu", # Just to make sure we at least check the shard_count in presubmit. + ], + shard_count = 3, + deps = [ + "//jax/experimental:pallas", + "//jax/experimental:pallas_tpu", + "//jax/experimental:pallas_tpu_sc", + ] + py_deps([ + "numpy", + "absl/testing", + "absl/flags", + ]), +) + +jax_multiplatform_test( + name = "tpu_sparsecore_pallas_distributed_test", + srcs = ["tpu_sparsecore_pallas_distributed_test.py"], + config_tags_overrides = { + # TODO(slebedev): Flip to False once we have more v5p capacity. + "tpu_v5p_x4": {"ondemand": True}, + "tpu_v6e_x8": {"ondemand": True}, + }, + enable_backends = [], + enable_configs = [ + "tpu_v5p_x4", + "tpu_v6e_x8", + ], + deps = [ + "//jax/experimental:mesh_utils", + "//jax/experimental:pallas_tpu", + "//jax/experimental:pallas_tpu_sc", + ] + py_deps([ + "numpy", + "absl/testing", + ]), +) + +jax_multiplatform_test( + name = "tpu_info_test", + srcs = ["tpu_info_test.py"], + enable_backends = ["tpu"], + enable_configs = [ + "tpu_7x", + ], + deps = [ + "//jax/experimental:pallas_tpu", + ] + py_deps("absl/testing"), +) + +jax_multiplatform_test( + name = "gpu_pallas_interpret_test", + srcs = ["gpu_pallas_interpret_test.py"], + enable_backends = ["cpu"], + deps = [ + "//jax/_src/pallas/mosaic_gpu/interpret:interpret_pallas_call", + "//jax/experimental:pallas", + "//jax/experimental:pallas_mosaic_gpu", + ] + py_deps([ + "absl/testing", + "numpy", + ]), +) + +jax_multiplatform_test( + name = "tpu_trace_value_test", + srcs = ["tpu_trace_value_test.py"], + enable_backends = ["tpu"], + enable_configs = [ + "tpu_7x", + ], + deps = [ + "//jax/experimental:pallas", + "//jax/experimental:pallas_tpu", + ] + py_deps([ + "absl/testing", + ]), ) diff --git a/tests/pallas/export_back_compat_pallas_test.py b/tests/pallas/export_back_compat_pallas_test.py index addf14d73792..6dbe7cfb0287 100644 --- a/tests/pallas/export_back_compat_pallas_test.py +++ b/tests/pallas/export_back_compat_pallas_test.py @@ -17,6 +17,7 @@ update these tests. """ +import functools import math import unittest @@ -25,11 +26,13 @@ from jax._src import config from jax._src import test_util as jtu from jax._src.internal_test_util import export_back_compat_test_util as bctu +from jax._src.internal_test_util.export_back_compat_test_data.pallas import mosaic_gpu_add_one from jax._src.internal_test_util.export_back_compat_test_data.pallas import mosaic_matmul from jax._src.internal_test_util.export_back_compat_test_data.pallas import mosaic_semaphore_dma from jax._src.internal_test_util.export_back_compat_test_data.pallas import triton_add_one from jax.experimental import pallas as pl from jax.experimental.pallas import tpu as pltpu +from jax.experimental.pallas import mosaic_gpu as plgpu from jax.experimental.pallas.ops.tpu import matmul import jax.numpy as jnp @@ -43,9 +46,6 @@ class CompatTest(bctu.CompatTestBase): def setUp(self): if jax.config.x64_enabled: self.skipTest("Only works in 32-bit") - if (jtu.test_device_matches(["cuda"]) and - not jtu.is_cuda_compute_capability_at_least("8.0")): - self.skipTest("Only works on GPUs with capability >= sm80") super().setUp() @unittest.skip("This test is checking backwards compatibility " @@ -53,6 +53,9 @@ def setUp(self): "compatibility for its IR, and we have since removed " "the corresponding custom call from the guaranteed stable list.") def test_triton_add_one(self): + if not jtu.is_cuda_compute_capability_at_least("8.0"): + self.skipTest("Only works on GPUs with capability >= sm80") + def func(x): def add_one(x_ref, o_ref): o_ref[0] = x_ref[0] + 1 @@ -65,11 +68,40 @@ def add_one(x_ref, o_ref): self.run_one_test(func, data) + def test_mosaic_gpu_add_one(self): + if not jtu.is_cuda_compute_capability_at_least("9.0"): + self.skipTest("Only works on GPUs with capability >= sm90") + + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct((128 * 2,), jnp.float32), + grid=2, + backend="mosaic_gpu", + ) + def add_one(x_ref, o_ref): + o_ref[...] = x_ref[...] + 1 + + data = self.load_testdata(mosaic_gpu_add_one.data_2025_04_22) + self.run_one_test(add_one, data, expect_current_custom_calls=["mosaic_gpu_v2"]) + + def test_mosaic_gpu_kernel_add_one(self): + if not jtu.is_cuda_compute_capability_at_least("9.0"): + self.skipTest("Only works on GPUs with capability >= sm90") + + @functools.partial( + plgpu.kernel, + out_shape=jax.ShapeDtypeStruct((128,), jnp.float32), + grid=(2,), + grid_names=("x",), + ) + def add_one(x_ref, o_ref): + o_ref[...] = x_ref[...] + 1 + + data = self.load_testdata(mosaic_gpu_add_one.kernel_data_2025_09_07) + self.run_one_test(add_one, data) + @jax.default_matmul_precision("bfloat16") def test_mosaic_matmul(self): - # TODO(apaszke): Remove after 12 weeks have passed. - if not jtu.if_cloud_tpu_at_least(2024, 9, 30): - self.skipTest("Requires libtpu built after 2024-09-30") dtype = jnp.float32 def func(): # Build the inputs here, to reduce the size of the golden inputs. diff --git a/tests/pallas/export_pallas_test.py b/tests/pallas/export_pallas_test.py index 0c989f098220..e67592a21e3c 100644 --- a/tests/pallas/export_pallas_test.py +++ b/tests/pallas/export_pallas_test.py @@ -20,7 +20,9 @@ import jax from jax import export from jax._src import test_util as jtu +from jax._src.pallas import pallas_call as pallas_call_lib from jax.experimental import pallas as pl + import numpy as np try: from jax._src.lib import triton @@ -31,18 +33,20 @@ jax.config.parse_flags_with_absl() -class ExportTest(jtu.JaxTestCase): +class ExportTestWithTriton(jtu.JaxTestCase): def setUp(self): if sys.platform == "win32": self.skipTest("Only works on non-Windows platforms") - + self.enter_context(pallas_call_lib._PALLAS_USE_MOSAIC_GPU(False)) super().setUp() + def _check_cuda_export(self, exp): + self.assertRegex( + exp.mlir_module(), + r"stablehlo.custom_call @__gpu\$xla\.gpu\.triton.+name\s*=\s*\"my_custom_kernel_name\"") + def test_cross_platform(self): - # TODO(apaszke): Remove after 12 weeks have passed. - if not jtu.if_cloud_tpu_at_least(2024, 12, 19): - self.skipTest("Requires libtpu built after 2024-12-19") def add_vectors_kernel(x_ref, y_ref, o_ref): x, y = x_ref[...], y_ref[...] o_ref[...] = x + y @@ -83,9 +87,24 @@ def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array: exp.mlir_module(), r"stablehlo.custom_call @tpu_custom_call.+kernel_name\s*=\s*\"my_custom_kernel_name\"") if "cuda" in platforms: - self.assertRegex( - exp.mlir_module(), - r"stablehlo.custom_call @__gpu\$xla\.gpu\.triton.+name\s*=\s*\"my_custom_kernel_name\"") + self._check_cuda_export(exp) + + +class ExportTestWithMosaicGpu(ExportTestWithTriton): + + def setUp(self): + # TODO(b/432678342): remove once this is fixed. + if jtu.is_device_cuda() and not jtu.is_cuda_compute_capability_at_least("9.0"): + self.skipTest( + "LLVM seems to care about the compute capability if a GPU is present" + ) + super().setUp() + self.enter_context(pallas_call_lib._PALLAS_USE_MOSAIC_GPU(True)) + + def _check_cuda_export(self, exp): + self.assertRegex( + exp.mlir_module(), + r"stablehlo.custom_call @mosaic_gpu_v2.*my_custom_kernel_name") if __name__ == '__main__': diff --git a/tests/pallas/fuser_block_spec_test.py b/tests/pallas/fuser_block_spec_test.py index 1b3a215876ec..50e8753978be 100644 --- a/tests/pallas/fuser_block_spec_test.py +++ b/tests/pallas/fuser_block_spec_test.py @@ -19,6 +19,7 @@ from jax._src import config from jax._src import test_util as jtu from jax._src.pallas.fuser import block_spec as block_spec_lib +from jax._src.pallas.fuser import custom_fusion_lib from jax.experimental import pallas as pl import jax.numpy as jnp import numpy as np @@ -170,20 +171,62 @@ def f(x): fn(x) + b, ) + def test_custom_fusion(self): + @custom_fusion_lib.custom_fusion + def fn(x, y): + return x + y + + fn.def_pull_block_spec(lambda bss: (bss[0], bss[0])) + fn.def_push_block_spec(lambda bss: (bss[0],)) + fn.def_eval_rule(lambda _, x, y: (fn(x, y),)) + + in_type = ( + jax.ShapeDtypeStruct((512, 512), jnp.float32), + jax.ShapeDtypeStruct((512, 512), jnp.float32), + ) + f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values( + fn, *in_type + ) + self.assertEmpty(new_values) + self.assertEmpty(scalar_prefetch_values) + + block_spec = pl.BlockSpec((128, 128), lambda i, j, k: (i, j)) + kernel_fn, (value_block_specs, *in_block_specs), _ = ( + block_spec_lib.pull_block_spec( + f2, + block_spec, + grid=(1, 1, 1), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )(new_values, *in_type) + ) + self.assertEmpty(value_block_specs) + self.assertLen(in_block_specs, 2) + x_block_spec, y_block_spec = in_block_specs + self.assertEqual(x_block_spec.block_shape, (128, 128)) + self.assertEqual( + x_block_spec.index_map(0, 1, 2), block_spec.index_map(0, 1, 2) + ) + self.assertEqual(y_block_spec.block_shape, (128, 128)) + self.assertEqual( + y_block_spec.index_map(0, 1, 2), block_spec.index_map(0, 1, 2) + ) + + x = np.ones((128, 128), dtype=np.float32) + y = np.ones((128, 128), dtype=np.float32) + np.testing.assert_array_equal( + kernel_fn((0, 0, 0), scalar_prefetch_values, new_values, x, y), x + y + ) + @parameterized.product( fn=[lax.mul, lax.add, lax.sub, lax.div, lax.max, lax.lt, lax.eq, lax.gt], ) def test_binop(self, fn): - - def f(x, y): - return fn(x, y) - in_type = ( jax.ShapeDtypeStruct((512, 512), jnp.float32), jax.ShapeDtypeStruct((512, 512), jnp.float32), ) f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values( - f, *in_type + fn, *in_type ) self.assertEmpty(new_values) self.assertEmpty(scalar_prefetch_values) @@ -216,6 +259,39 @@ def f(x, y): fn(x, y), ) + def test_binop_bcast_mapped_dim(self): + in_type = ( + jax.ShapeDtypeStruct((128, 512), jnp.float32), + jax.ShapeDtypeStruct((1, 512), jnp.float32), + ) + f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values( + lax.add, *in_type + ) + self.assertEmpty(new_values) + self.assertEmpty(scalar_prefetch_values) + + block_spec = pl.BlockSpec((None, 128), lambda i, j: (i, j)) + kernel_fn, (value_block_specs, *in_block_specs), _ = ( + block_spec_lib.pull_block_spec( + f2, + block_spec, + grid=(128, 4), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )(new_values, *in_type) + ) + self.assertEmpty(value_block_specs) + self.assertLen(in_block_specs, 2) + x_block_spec, y_block_spec = in_block_specs + self.assertEqual(x_block_spec.block_shape, (None, 128)) + self.assertEqual(x_block_spec.index_map(2, 3), (2, 3)) + self.assertEqual(y_block_spec.block_shape, (None, 128)) + self.assertEqual(y_block_spec.index_map(2, 3), (0, 3)) + + x = y = np.ones((128,), dtype=np.float32) + np.testing.assert_array_equal( + kernel_fn((0, 0), scalar_prefetch_values, new_values, x, y), x + y + ) + def test_slice(self): x = jax.random.normal(jax.random.key(0), (4, 512, 512), dtype=np.float32) @@ -653,9 +729,12 @@ def f(): kernel_fn((0, 0, 3, 0), scalar_prefetch_values, ()), x ) - def test_broadcast_array(self): + @parameterized.parameters( + (False, False), (False, True), (True, False), (True, True) + ) + def test_broadcast_array(self, bcast0, bcast1): - x = jnp.ones((512, 512)) + x = jnp.ones((1 if bcast0 else 512, 1 if bcast1 else 512)) def f(): return jax.lax.broadcast_in_dim(x, (2, 2, 512, 512), (2, 3)) @@ -664,9 +743,47 @@ def f(): self.assertLen(new_values, 1) self.assertEmpty(scalar_prefetch_values) - block_spec = pl.BlockSpec( - (None, 1, 128, 128), lambda i, j, k, l: (i, j, k, l) + block_shape = (None, 1, 128, 128) + block_spec = pl.BlockSpec(block_shape, lambda i, j, k, l: (i, j, k, l)) + kernel_fn, (value_block_specs,), _ = block_spec_lib.pull_block_spec( + f2, + block_spec, + grid=(2, 2, 4, 4), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )(new_values) + self.assertLen(value_block_specs, 1) + x_index_map = value_block_specs[0].index_map + self.assertEqual( + x_index_map(0, 0, 1, 2), (0 if bcast0 else 1, 0 if bcast1 else 2) ) + self.assertEqual( + x_index_map(1, 2, 3, 3), (0 if bcast0 else 3, 0 if bcast1 else 3) + ) + + block_shape = (1 if bcast0 else 128, 1 if bcast1 else 128) + self.assertEqual(block_shape, value_block_specs[0].block_shape) + x = jnp.full(block_shape, fill_value=1.2345, dtype=jnp.float32) + y = jax.lax.broadcast_in_dim(x, (1, 128, 128), (1, 2)) + np.testing.assert_array_equal(kernel_fn((0, 0, 0, 0), (), (x,)), y) + np.testing.assert_array_equal(kernel_fn((1, 1, 0, 0), (), (x,)), y) + np.testing.assert_array_equal(kernel_fn((0, 0, 0, 1), (), (x,)), y) + np.testing.assert_array_equal(kernel_fn((0, 0, 1, 0), (), (x,)), y) + np.testing.assert_array_equal(kernel_fn((0, 0, 3, 0), (), (x,)), y) + + @parameterized.parameters(0, 1, 2, 3) + def test_broadcast_1d_array(self, bcast_dim): + full_shape = (2, 2, 512, 512) + x = jnp.ones((full_shape[bcast_dim],)) + + def f(): + return jax.lax.broadcast_in_dim(x, full_shape, (bcast_dim,)) + + f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(f) + self.assertLen(new_values, 1) + self.assertEmpty(scalar_prefetch_values) + + block_shape = (None, 1, 128, 128) + block_spec = pl.BlockSpec(block_shape, lambda i, j, k, l: (i, j, k, l)) kernel_fn, (value_block_specs,), _ = block_spec_lib.pull_block_spec( f2, block_spec, @@ -674,26 +791,364 @@ def f(): scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), )(new_values) self.assertLen(value_block_specs, 1) - x_block_spec = value_block_specs[0] - self.assertEqual(x_block_spec.index_map(0, 0, 1, 2), (1, 2)) - self.assertEqual(x_block_spec.index_map(1, 2, 3, 3), (3, 3)) + x_index_map = value_block_specs[0].index_map + self.assertEqual(x_index_map(0, 0, 1, 2), ((0, 0, 1, 2)[bcast_dim],)) + self.assertEqual(x_index_map(1, 2, 3, 3), ((1, 2, 3, 3)[bcast_dim],)) - x = jnp.full((128, 128), fill_value=1.2345, dtype=jnp.float32) - np.testing.assert_array_equal( - kernel_fn((0, 0, 0, 0), scalar_prefetch_values, (x,)), x + if block_shape[bcast_dim] is None: + x = jnp.ones(()) + y = jax.lax.broadcast_in_dim(x, (1, 128, 128), ()) + else: + x = jnp.arange(block_shape[bcast_dim] or 1, dtype=jnp.float32) + y = jax.lax.broadcast_in_dim(x, (1, 128, 128), (bcast_dim - 1,)) + + np.testing.assert_array_equal(kernel_fn((0, 0, 0, 0), (), (x,)), y) + np.testing.assert_array_equal(kernel_fn((1, 1, 0, 0), (), (x,)), y) + np.testing.assert_array_equal(kernel_fn((0, 0, 0, 1), (), (x,)), y) + np.testing.assert_array_equal(kernel_fn((0, 0, 1, 0), (), (x,)), y) + np.testing.assert_array_equal(kernel_fn((0, 0, 3, 0), (), (x,)), y) + + def test_element_indexing(self): + + x = np.zeros((512, 512), dtype=np.float32) + + def f(): + return x + + f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(f) + self.assertLen(new_values, 1) + self.assertEmpty(scalar_prefetch_values) + + # Block spec with an offset on the first dimension + block_spec = pl.BlockSpec( + (pl.Element(128, (0, 16)), 128), lambda i, j, k: (128 * i + 16, j) ) - np.testing.assert_array_equal( - kernel_fn((1, 1, 0, 0), scalar_prefetch_values, (x,)), x + kernel_fn, (value_block_specs,), _ = block_spec_lib.pull_block_spec( + f2, + block_spec, + grid=(1, 1, 1), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )(new_values) + self.assertLen(value_block_specs, 1) + self.assertEmpty(scalar_prefetch_values) + self.assertEqual( + value_block_specs[0].block_shape, (pl.Element(128, (0, 16)), 128) ) + self.assertEqual(value_block_specs[0].index_map(0, 1, 2), (16, 1)) + self.assertEqual(value_block_specs[0].index_map(1, 1, 2), (128 + 16, 1)) + + x_block = np.ones((128, 128), dtype=np.float32) np.testing.assert_array_equal( - kernel_fn((0, 0, 0, 1), scalar_prefetch_values, (x,)), x + kernel_fn( + (0, 0, 0), + scalar_prefetch_values, + (np.ones((128, 128), dtype=np.float32),), + ), + x_block, ) - np.testing.assert_array_equal( - kernel_fn((0, 0, 1, 0), scalar_prefetch_values, (x,)), x + + @parameterized.parameters( + # Merge two dimensions. + ((8, 8, 128), (1, 2, 128), (1, 1, 2, 128), (0, 2, 3, 5)), + ((2, 32, 128), (2, 4, 128), (2, 1, 4, 128), (2, 1, 1, 5)), + ((2, 4, 1024), (2, 1, 128), (2, 1, 1, 128), (2, 3, 5, 0)), + ((2, 4, 1024), (2, None, 128), (2, 1, 1, 128), (2, 3, 5, 0)), + # Merge three dimensions. + ((64, 128), (4, 128), (1, 1, 4, 128), (0, 1, 0, 3)), + ((2, 4096), (1, 64), (1, 1, 1, 64), (2, 0, 1, 1)), + # Merge two pairs of dimensions. + ((8, 1024), (1, 256), (1, 1, 2, 128), (0, 2, 3, 0)), + # Merge three dims and expand in trailing dim. + ((64, 128, 1), (4, 128, 1), (1, 1, 4, 128), (0, 1, 0, 3)), + # Test propagating pl.BoundedSlice. + ( + (8, 8, 128), + (2, pl.BoundedSlice(2), 16), + (1, 2, pl.BoundedSlice(2), 16), + (1, 0, 3, 5), + ), + ) + def test_reshape( + self, shape, block_shape, expected_x_block_shape, expected_x_index + ): + f = lambda x: x.reshape(shape) + in_type = jax.ShapeDtypeStruct((2, 4, 8, 128), jnp.float32) + f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values( + f, in_type ) - np.testing.assert_array_equal( - kernel_fn((0, 0, 3, 0), scalar_prefetch_values, (x,)), x + self.assertEmpty(new_values) + self.assertEmpty(scalar_prefetch_values) + + block_spec = pl.BlockSpec(block_shape, lambda *pids: pids) + kernel_fn, (value_block_specs, x_block_spec), _ = ( + block_spec_lib.pull_block_spec( + f2, + block_spec, + grid=(2, 3, 4)[: len(shape)], + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )(new_values, in_type) + ) + pids = (2, 3, 5)[: len(shape)] + self.assertEmpty(value_block_specs) + self.assertEqual(x_block_spec.block_shape, expected_x_block_shape) + self.assertEqual(x_block_spec.index_map(*pids), expected_x_index) + + def shape_to_concrete(block_shape): + return [ + bd.block_size if isinstance(bd, pl.BoundedSlice) else bd + for bd in block_shape + if bd is not None + ] + + concrete_block_shape = shape_to_concrete(block_shape) + concrete_expected_block_shape = shape_to_concrete(expected_x_block_shape) + + x = jnp.arange( + np.prod(concrete_block_shape), + dtype=jnp.float32, + ) + x = x.reshape(concrete_expected_block_shape) + y = kernel_fn((0, 1, 2), scalar_prefetch_values, (), x) + np.testing.assert_array_equal(y, x.reshape(concrete_block_shape)) + + def test_basic_reshape_sublanes_to_lanes(self): + + def f(x): + return x.reshape((512, 4096)) + + in_type = jax.ShapeDtypeStruct((512, 32, 128), jnp.float32) + f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values( + f, in_type + ) + self.assertEmpty(new_values) + self.assertEmpty(scalar_prefetch_values) + + block_spec = pl.BlockSpec((256, 1024), lambda i, j, k: (i, k)) + kernel_fn, (value_block_specs, x_block_spec), _ = ( + block_spec_lib.pull_block_spec( + f2, + block_spec, + grid=(2, 3, 4), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )(new_values, in_type) ) + self.assertEmpty(value_block_specs) + self.assertEqual(x_block_spec.block_shape, (256, 8, 128)) + self.assertEqual(x_block_spec.index_map(0, 1, 2), (0, 2, 0)) + self.assertEqual(x_block_spec.index_map(3, 2, 1), (3, 1, 0)) + + x = jnp.arange((256 * 1024), dtype=jnp.float32).reshape((256, 8, 128)) + y = kernel_fn((0, 1, 2), scalar_prefetch_values, (), x) + np.testing.assert_array_equal(y, x.reshape((256, 1024))) + + def test_basic_reshape_lanes_to_sublanes(self): + + def f(x): + return x.reshape((512, 32, 128)) + + in_type = jax.ShapeDtypeStruct((512, 4096), jnp.float32) + f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values( + f, in_type + ) + self.assertEmpty(new_values) + self.assertEmpty(scalar_prefetch_values) + + block_spec = pl.BlockSpec((256, 8, 128), lambda i, j, k: (i, k, 0)) + kernel_fn, (value_block_specs, x_block_spec), _ = ( + block_spec_lib.pull_block_spec( + f2, + block_spec, + grid=(2, 3, 4), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )(new_values, in_type) + ) + self.assertEmpty(value_block_specs) + self.assertEqual(x_block_spec.index_map(0, 1, 2), (0, 2)) + self.assertEqual(x_block_spec.index_map(3, 2, 1), (3, 1)) + + x = jnp.arange((256 * 1024), dtype=jnp.float32).reshape((256, 1024)) + y = kernel_fn((0, 1, 2), scalar_prefetch_values, (), x) + np.testing.assert_array_equal(y, x.reshape((256, 8, 128))) + + block_spec = pl.BlockSpec((256, 4, 256), lambda i, j, k: (i, j, k)) + with self.assertRaises(NotImplementedError): + _ = block_spec_lib.pull_block_spec( + f2, + block_spec, + grid=(2, 3, 4), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )(new_values, in_type) + + def test_basic_swap(self): + value = jnp.arange((512 * 1024), dtype=jnp.int32).reshape((512, 1024)) * 2 + x = jnp.zeros((256, 512), dtype=jnp.int32) + + def outer(refs): + ref, y_ref = refs + + def f(x): + return ref.swap(x) + + in_type = jax.ShapeDtypeStruct((512, 1024), jnp.int32) + f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values( + f, in_type + ) + self.assertLen(new_values, 1) # Captures Ref + self.assertEmpty(scalar_prefetch_values) + + block_spec = pl.BlockSpec((256, 512), lambda i, j, k: (i, k)) + kernel_fn, (value_block_specs, x_block_spec), _ = ( + block_spec_lib.pull_block_spec( + f2, + block_spec, + grid=(2, 3, 4), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )(new_values, in_type) + ) + self.assertLen(value_block_specs, 1) + self.assertEqual(x_block_spec.index_map(0, 1, 2), (0, 2)) + self.assertEqual(x_block_spec.index_map(3, 2, 1), (3, 1)) + + y_ref[...] = kernel_fn((0, 1, 1), scalar_prefetch_values, (ref,), x) + + y = jnp.zeros((256, 512), jnp.int32) + _, y = pl.run_state(outer)((value, y)) + np.testing.assert_array_equal(y, value[:256, 512:1024]) + + def test_basic_get(self): + value = jnp.arange((512 * 1024), dtype=jnp.int32).reshape((512, 1024)) * 2 + + def outer(refs): + ref, y_ref = refs + + def f(): + return ref.get() + + block_spec = pl.BlockSpec((256, 512), lambda i, j, k: (i, k)) + kernel_fn, (), _ = block_spec_lib.pull_block_spec( + f, + block_spec, + grid=(2, 3, 4), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )() + y_ref[...] = kernel_fn((0, 1, 1), ()) + + y = jnp.zeros((256, 512), jnp.int32) + _, y = pl.run_state(outer)((value, y)) + np.testing.assert_array_equal(y, value[:256, 512:1024]) + + def test_get_with_squeezed_block_spec(self): + value = ( + jnp.arange((4 * 512 * 1024), dtype=jnp.int32).reshape((4, 512, 1024)) + * 2 + ) + + def outer(refs): + ref, y_ref = refs + + def f(): + return ref.get() + + block_spec = pl.BlockSpec( + (pl.Squeezed(), 256, 512), lambda i, j, k: (j, i, k) + ) + kernel_fn, (), _ = block_spec_lib.pull_block_spec( + f, + block_spec, + grid=(2, 3, 4), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )() + y_ref[...] = kernel_fn((0, 3, 1), ()) + + y = jnp.zeros((256, 512), jnp.int32) + _, y = pl.run_state(outer)((value, y)) + np.testing.assert_array_equal(y, value[3, :256, 512:1024]) + + def test_get_with_squeezed_indexer(self): + value = ( + jnp.arange((4 * 512 * 1024), dtype=jnp.int32).reshape((4, 512, 1024)) + * 2 + ) + + def outer(refs): + ref, y_ref = refs + + def f(): + return ref[3] + + block_spec = pl.BlockSpec((256, 512), lambda i, j, k: (i, k)) + kernel_fn, (), _ = block_spec_lib.pull_block_spec( + f, + block_spec, + grid=(2, 3, 4), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )() + y_ref[...] = kernel_fn((0, 2, 1), ()) + + y = jnp.zeros((256, 512), jnp.int32) + _, y = pl.run_state(outer)((value, y)) + np.testing.assert_array_equal(y, value[3, :256, 512:1024]) + + def test_random_noise(self): + key = jax.random.key(0, impl='threefry2x32') + + def f(key): + return jax.random.uniform(key, (512, 512), dtype=jnp.float32) + + f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values( + f, key + ) + self.assertEmpty(new_values) + self.assertEmpty(scalar_prefetch_values) + + block_spec = pl.BlockSpec((128, 256), lambda i, j: (i, j)) + kernel_fn, (value_block_specs, key_block_spec), _ = ( + block_spec_lib.pull_block_spec( + f2, + block_spec, + grid=(4, 2), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )(new_values, key) + ) + self.assertEmpty(value_block_specs) + self.assertEqual(key_block_spec.memory_space, pl.MemorySpace.KEY) + self.assertIsNone(key_block_spec.block_shape) + + @jax.jit + def gen(idx): + k = key + for i in idx: + k = jax.random.fold_in(k, i) + return jax.random.uniform(k, (128, 256), dtype=jnp.float32) + + for i in range(4): + for j in range(2): + out = kernel_fn((i, j), scalar_prefetch_values, (), key) + out_ref = gen((i, j)) + np.testing.assert_array_equal(out, out_ref) + + def test_reduce_sum(self): + + x = jnp.arange(1024 * 256, dtype=jnp.float32).reshape((1024, 256)) + + def f(): + return x.sum(axis=1) + + f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(f) + self.assertLen(new_values, 1) + self.assertEmpty(scalar_prefetch_values) + + block_spec = pl.BlockSpec((128,), lambda i: (i,)) + kernel_fn, (value_block_specs,), _ = block_spec_lib.pull_block_spec( + f2, + block_spec, + grid=(8,), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )(new_values) + self.assertLen(value_block_specs, 1) + y = x[128:256] + out = kernel_fn((1,), scalar_prefetch_values, (y,)) + np.testing.assert_array_equal(out, y.sum(axis=1)) class PullBlockSpecHOPTest(jtu.JaxTestCase): @@ -769,6 +1224,87 @@ def f(x): kernel_fn((0, 0, 0, 0), scalar_prefetch_values, (), x), relu_x ) + def test_custom_vjp(self): + @jax.custom_vjp + def act(x): + return jax.nn.relu(x) * x + + def act_fwd(x): + return jax.nn.relu(x) * x, (x,) + + def act_bwd(res, dy): + (x,) = res + return (dy * x * 2.34,) + + act.defvjp(act_fwd, act_bwd) + + def f(x): + return act(x) + + in_type = jax.ShapeDtypeStruct((512, 512), jnp.float32) + f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values( + f, in_type + ) + self.assertEmpty(new_values) + self.assertEmpty(scalar_prefetch_values) + + block_spec = pl.BlockSpec( + (None, 1, 128, 128), lambda i, j, k, l, _: (i, j, k, l) + ) + kernel_fn, (value_block_specs, *in_block_specs), _ = ( + block_spec_lib.pull_block_spec( + f2, + block_spec, + grid=(2, 2, 4, 4), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )(new_values, in_type) + ) + self.assertEmpty(value_block_specs) + x_block_spec = in_block_specs[0] + self.assertEqual(x_block_spec.index_map(0, 0, 1, 2, ()), (0, 0, 1, 2)) + self.assertEqual(x_block_spec.index_map(1, 2, 3, 3, ()), (1, 2, 3, 3)) + + x = jax.random.normal(jax.random.key(0), (1, 128, 128), dtype=np.float32) + relu_x_x = jax.nn.relu(x) * x + np.testing.assert_array_equal( + kernel_fn((0, 0, 0, 0), scalar_prefetch_values, (), x), relu_x_x + ) + + def test_pull_block_spec_handles_closed_over_constants(self): + x = jnp.ones((2, 512, 512)) + i = jnp.array(1) + + def f(): + return x[i] + + f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(f) + self.assertLen(new_values, 1) + self.assertLen(scalar_prefetch_values, 1) + + block_spec = pl.BlockSpec( + (None, 1, 128, 128), lambda i, j, k, l, _: (i, j, k, l) + ) + kernel_fn, (value_block_specs,), _ = block_spec_lib.pull_block_spec( + f2, + block_spec, + grid=(2, 2, 4, 4), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )(new_values) + self.assertLen(value_block_specs, 1) + scalar_prefetch_values = jax.tree.map( + lambda x: x[None], scalar_prefetch_values + ) + fn = lambda x: kernel_fn((0, 0, 0, 0), scalar_prefetch_values, x) + new_values_type = (jax.ShapeDtypeStruct((1, 128, 128), jnp.float32),) + # Try pulling again + # This should not raise an error. + _ = block_spec_lib.pull_block_spec( + fn, + block_spec, + grid=(1,), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )(new_values_type) + class PushBlockSpecTest(parameterized.TestCase): @@ -777,6 +1313,49 @@ def setUp(self): if config.enable_x64.value: self.skipTest('x64 not supported') + def test_binop(self): + + def f(x): + return x + jnp.ones_like(x) + + block_spec = pl.BlockSpec((128, 128), lambda i, j: (i, j)) + x_type = jax.ShapeDtypeStruct((512, 512), jnp.float32) + out_block_spec = block_spec_lib.push_block_spec(f, block_spec)(x_type) + self.assertEqual(out_block_spec.block_shape, block_spec.block_shape) + + def f(x, y): + return x + y + + x_block_spec = pl.BlockSpec((128, 128), lambda i, j: (i, j)) + y_block_spec = pl.BlockSpec((128, 1), lambda i, j: (i, 0)) + x_type = jax.ShapeDtypeStruct((512, 512), jnp.float32) + y_type = jax.ShapeDtypeStruct((512, 1), jnp.float32) + with self.assertRaisesRegex( + ValueError, 'Cannot propagate block spec through RHS broadcast.' + ): + block_spec_lib.push_block_spec(f, pl.no_block_spec, y_block_spec)( + x_type, y_type + ) + out_block_spec = block_spec_lib.push_block_spec( + f, x_block_spec, pl.no_block_spec + )(x_type, y_type) + self.assertIs(x_block_spec, out_block_spec) + + x_block_spec = pl.BlockSpec((1, 128), lambda i, j: (0, j)) + y_block_spec = pl.BlockSpec((128, 128), lambda i, j: (i, j)) + x_type = jax.ShapeDtypeStruct((1, 512), jnp.float32) + y_type = jax.ShapeDtypeStruct((512, 512), jnp.float32) + with self.assertRaisesRegex( + ValueError, 'Cannot propagate block spec through LHS broadcast.' + ): + block_spec_lib.push_block_spec(f, x_block_spec, pl.no_block_spec)( + x_type, y_type + ) + out_block_spec = block_spec_lib.push_block_spec( + f, pl.no_block_spec, y_block_spec + )(x_type, y_type) + self.assertIs(out_block_spec, y_block_spec) + def test_jit(self): def f(x): @@ -800,6 +1379,149 @@ def f(x): out_block_spec = block_spec_lib.push_block_spec(f, block_spec)(x_type) self.assertEqual(out_block_spec.block_shape, block_spec.block_shape) + def test_push_reshape_lanes_to_sublanes(self): + def f(x): + return x.reshape((512, 32, 128)) + + x_type = jax.ShapeDtypeStruct((512, 4096), jnp.float32) + block_spec = pl.BlockSpec((256, 1024), lambda i, j, k: (i, k)) + out_block_spec = block_spec_lib.push_block_spec(f, block_spec)(x_type) + self.assertEqual(out_block_spec.block_shape, (256, 8, 128)) + self.assertTupleEqual(out_block_spec.index_map(0, 1, 2), (0, 2, 0)) + self.assertEqual(out_block_spec.index_map(3, 2, 1), (3, 1, 0)) + + def f(x): + return x.reshape((512, 16, 256)) + + x_type = jax.ShapeDtypeStruct((512, 4096), jnp.float32) + block_spec = pl.BlockSpec((256, 1024), lambda i, j, k: (i, k)) + out_block_spec = block_spec_lib.push_block_spec(f, block_spec)(x_type) + self.assertEqual(out_block_spec.block_shape, (256, 4, 256)) + self.assertTupleEqual(out_block_spec.index_map(0, 1, 2), (0, 2, 0)) + self.assertEqual(out_block_spec.index_map(3, 2, 1), (3, 1, 0)) + + def test_custom_vjp(self): + @jax.custom_vjp + def act(x): + return jax.nn.relu(x) * x + + def act_fwd(x): + return jax.nn.relu(x) * x, (x,) + + def act_bwd(res, dy): + (x,) = res + return (dy * x * 2.34,) + + act.defvjp(act_fwd, act_bwd) + + def f(x): + return act(x) + + x_type = jax.ShapeDtypeStruct((1, 1, 512, 512), jnp.float32) + block_spec = pl.BlockSpec( + (None, 1, 128, 128), lambda i, j, k, l, _: (i, l, k, j) + ) + out_block_spec = block_spec_lib.push_block_spec(f, block_spec)(x_type) + self.assertEqual(out_block_spec.block_shape, block_spec.block_shape) + + def test_reduce_sum_push(self): + def f(x): + return x.sum(axis=0) + + x_type = jax.ShapeDtypeStruct((256, 512), jnp.float32) + block_spec = pl.BlockSpec((256, 256), lambda i, j: (i, j)) + out_block_spec = block_spec_lib.push_block_spec(f, block_spec)(x_type) + self.assertEqual(out_block_spec.block_shape, (256,)) + self.assertEqual(out_block_spec.index_map(2, 3), (3,)) + + def f(x): + return x.sum(axis=1) + + x_type = jax.ShapeDtypeStruct((128, 512), jnp.float32) + block_spec = pl.BlockSpec((64, 512), lambda i, j: (i, j)) + out_block_spec = block_spec_lib.push_block_spec(f, block_spec)(x_type) + self.assertEqual(out_block_spec.block_shape, (64,)) + self.assertEqual(out_block_spec.index_map(2, 3), (2,)) + + def test_broadcast_in_dim_push(self): + def f(x): + return jnp.broadcast_to(x, (128, 512)) + + x_type = jax.ShapeDtypeStruct((512,), jnp.float32) + block_spec = pl.BlockSpec((128,), lambda i: (i,)) + out_block_spec = block_spec_lib.push_block_spec(f, block_spec)(x_type) + self.assertEqual(out_block_spec.block_shape, (128, 128)) + self.assertEqual(out_block_spec.index_map(3), (0, 3)) + + def f(x): + return jnp.broadcast_to(x, (128, 512)) + + x_type = jax.ShapeDtypeStruct((1, 512), jnp.float32) + block_spec = pl.BlockSpec((1, 128), lambda i, j: (i, j)) + out_block_spec = block_spec_lib.push_block_spec(f, block_spec)(x_type) + self.assertEqual(out_block_spec.block_shape, (128, 128)) + self.assertEqual(out_block_spec.index_map(0, 3), (0, 3)) + + def f(x): + x = jnp.expand_dims(x, axis=1) + return jnp.broadcast_to(x, (128, 512)) + + x_type = jax.ShapeDtypeStruct((128,), jnp.float32) + block_spec = pl.BlockSpec((64,), lambda i: (i,)) + out_block_spec = block_spec_lib.push_block_spec(f, block_spec)(x_type) + self.assertEqual(out_block_spec.block_shape, (64, 512)) + self.assertEqual(out_block_spec.index_map(1), (1, 0)) + + def f(x): + x = jnp.expand_dims(x, axis=0) + return jnp.broadcast_to(x, (128, 512)) + + x_type = jax.ShapeDtypeStruct((512,), jnp.float32) + block_spec = pl.BlockSpec((256,), lambda i: (i,)) + out_block_spec = block_spec_lib.push_block_spec(f, block_spec)(x_type) + self.assertEqual(out_block_spec.block_shape, (128, 256)) + self.assertEqual(out_block_spec.index_map(1), (0, 1)) + + def test_concatenate_push(self): + def f(x1, x2): + return jnp.concatenate((x1, x2), axis=0) + + x_type = jax.ShapeDtypeStruct((512,), jnp.float32) + block_spec = pl.BlockSpec((128,), lambda i: (i,)) + with self.assertRaisesRegex( + NotImplementedError, 'concatenate not supported yet' + ): + block_spec_lib.push_block_spec(f, block_spec, block_spec)(x_type, x_type) + x_type = jax.ShapeDtypeStruct((512,), jnp.float32) + block_spec = pl.BlockSpec((512,), lambda i: (i,)) + out_block_spec = block_spec_lib.push_block_spec(f, block_spec, block_spec)( + x_type, x_type + ) + self.assertEqual(out_block_spec.block_shape, (1024,)) + self.assertEqual(out_block_spec.index_map(0), (0,)) + + def f(x1, x2): + return jnp.stack([x1, x2], axis=0) + + x_type = jax.ShapeDtypeStruct((512,), jnp.float32) + block_spec = pl.BlockSpec((128,), lambda i: (i,)) + out_block_spec = block_spec_lib.push_block_spec(f, block_spec, block_spec)( + x_type, x_type + ) + self.assertEqual(out_block_spec.block_shape, (2, 128)) + self.assertEqual(out_block_spec.index_map(3), (0, 3)) + + def f(x1, x2): + return jnp.stack([x1, x2], axis=1) + + x_type = jax.ShapeDtypeStruct((512,), jnp.float32) + block_spec = pl.BlockSpec((128,), lambda i: (i,)) + out_block_spec = block_spec_lib.push_block_spec(f, block_spec, block_spec)( + x_type, x_type + ) + self.assertEqual(out_block_spec.block_shape, (128, 2)) + self.assertEqual(out_block_spec.index_map(3), (3, 0)) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/fusion_test.py b/tests/pallas/fusion_test.py new file mode 100644 index 000000000000..a625b52e9559 --- /dev/null +++ b/tests/pallas/fusion_test.py @@ -0,0 +1,322 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import dataclasses + +from absl.testing import absltest +import jax +from jax import lax +from jax._src import core as jax_core +from jax._src import hijax +from jax._src import test_util as jtu +from jax.experimental.pallas import fuser +import jax.numpy as jnp +import numpy as np + +jax.config.parse_flags_with_absl() + + +class FusionTest(jtu.JaxTestCase): + + def test_basic_fusion(self): + + @jax.jit + @fuser.fuse + @fuser.fusible + def f(x_fn, y_fn): + x = x_fn() + if y_fn is None: + y_fn = lambda x: x + return y_fn(x) + + x = jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + np.testing.assert_array_equal(f(x), x) + + def test_separate_output_fusions_trivial(self): + + @fuser.fusible(output_fusion_prefix=(True, True)) + def f(x_fn, y_fn, z_fns): + x = x_fn() + y = y_fn() + if z_fns is None: + z_fns = lambda x: x, lambda x: x + z_fn1, z_fn2 = z_fns + return z_fn1(x), z_fn2(y) + + @jax.jit + @fuser.fuse + def g(x, y): + x, y = f(x, y) + return x, y * 2 + + x = jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + y = jax.random.normal(jax.random.key(1), (1, 128), dtype=jnp.float32) + x_out, y_out = g(x, y) + np.testing.assert_array_equal(x_out, x) + np.testing.assert_array_equal(y_out, y * 2) + + def test_custom_fusion(self): + const = jnp.array(1.0, dtype=jnp.float32) + const2 = jnp.array(2.0, dtype=jnp.float32) + const3 = jnp.array(3.0, dtype=jnp.float32) + + @fuser.custom_fusion + def c(x, y): + return x + y + const + + c.def_pull_block_spec(lambda bss: (bss[0], bss[0])) + c.def_push_block_spec(lambda bss: (bss[0],)) + c.def_eval_rule(lambda _, x, y: (c(x, y),)) + c.def_pallas_impl(lambda x, y: x + y + const2 + const3) + + @fuser.fusible(output_fusion_prefix=(True, True)) + def f(x_fn, y_fn, z_fns): + x = x_fn() + y = y_fn() + if z_fns is None: + z_fns = lambda x: x, lambda x: x + z_fn1, z_fn2 = z_fns + return z_fn1(x), z_fn2(y) + + def g(x, y, z): + x, y = f(x, c(y, z)) + return c(x, z), y * 2 + + x = jax.random.normal(jax.random.key(0), (4, 4), dtype=jnp.float32) + y = jax.random.normal(jax.random.key(1), (1, 4), dtype=jnp.float32) + z = jax.random.normal(jax.random.key(2), (1, 4), dtype=jnp.float32) + x_out, y_out = g(x, y, z) + np.testing.assert_array_equal(x_out, (x + z + 1.0)) + np.testing.assert_array_equal(y_out, (y + z + 1.0) * 2) + + g_fused = jax.jit(fuser.fuse(g)) + x_out, y_out = g_fused(x, y, z) + np.testing.assert_allclose(x_out, (x + z + 1.0)) + np.testing.assert_allclose(y_out, (y + z + 1.0) * 2) + + def test_separate_output_fusions_should_error_if_not_disjoint(self): + + @fuser.fusible(output_fusion_prefix=(True, True)) + def f(x_fn, y_fn, z_fns): + x = x_fn() + y = y_fn() + if z_fns is None: + z_fns = lambda x: x, lambda x: x + z_fn1, z_fn2 = z_fns + return z_fn1(x), z_fn2(y) + + @jax.jit + @fuser.fuse + def g(x, y): + x_res, y_res = f(x, y) + return x_res + y_res + + x = jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + y = jax.random.normal(jax.random.key(1), (128, 128), dtype=jnp.float32) + + with self.assertRaisesRegex( + ValueError, + "Outputs must be disjoint in order to use separate output fusions", + ): + g(x, y) + + def test_separate_output_fusions_allows_permute(self): + + @fuser.fusible(output_fusion_prefix=(True, True)) + def f(x_fn, y_fn, z_fns): + x = x_fn() + y = y_fn() + if z_fns is None: + z_fns = lambda x: x, lambda x: x + z_fn1, z_fn2 = z_fns + return z_fn1(x), z_fn2(y) + + @jax.jit + @fuser.fuse + def g(x, y): + x_res, y_res = f(x, y) + return y_res * 2, x_res + + x = jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + y = jax.random.normal(jax.random.key(1), (1, 128), dtype=jnp.float32) + y_out, x_out = g(x, y) + np.testing.assert_array_equal(x_out, x) + np.testing.assert_array_equal(y_out, y * 2) + + def test_separate_output_fusions_with_nesting(self): + + @fuser.fusible(output_fusion_prefix=(True, True)) + def f(x_fn, y_fn, z_fns): + x = x_fn() + y = y_fn() + if z_fns is None: + z_fns = lambda x: x, lambda x: x + z_fn1, z_fn2 = z_fns + return z_fn1(x), z_fn2(y) + + @jax.jit + @fuser.fuse + def g(x, y): + x_res, y_res = f(x, y) + return (x_res * 2, x_res + x_res), y_res + + x = jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + y = jax.random.normal(jax.random.key(1), (1, 128), dtype=jnp.float32) + (x1_out, x2_out), y_out = g(x, y) + np.testing.assert_array_equal(x1_out, x * 2) + np.testing.assert_array_equal(x2_out, x + x) + np.testing.assert_array_equal(y_out, y) + + def test_separate_output_fusions_with_nesting_and_permutation(self): + + @fuser.fusible(output_fusion_prefix=(True, True)) + def f(x_fn, y_fn, z_fns): + x = x_fn() + y = y_fn() + if z_fns is None: + z_fns = lambda x: x, lambda x: x + z_fn1, z_fn2 = z_fns + return z_fn1(x), z_fn2(y) + + @jax.jit + @fuser.fuse + def g(x, y): + x_res, y_res = f(x, y) + return y_res, (x_res * 2, x_res + x_res) + + x = jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + y = jax.random.normal(jax.random.key(1), (1, 128), dtype=jnp.float32) + y_out, (x1_out, x2_out) = g(x, y) + np.testing.assert_array_equal(x1_out, x * 2) + np.testing.assert_array_equal(x2_out, x + x) + np.testing.assert_array_equal(y_out, y) + + def test_separate_output_fusions_with_deep_output_mask(self): + + @fuser.fusible(output_fusion_prefix=(True, (True, True))) + def f(x_fn, y_fn, z_fn, o_fns): + x = x_fn() + y = y_fn() + z = z_fn() + if o_fns is None: + o_fns = lambda x: x, (lambda x: x, lambda x: x) + o_fn1, (o_fn2, o_fn3) = o_fns + return o_fn1(x), (o_fn2(y), o_fn3(z)) + + @jax.jit + @fuser.fuse + def g(x, y, z): + x_res, (y_res, z_res) = f(x, y, z) + return (x_res * 2, (y_res, z_res + z_res)) + + x = jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + y = jax.random.normal(jax.random.key(1), (1, 128), dtype=jnp.float32) + z = jax.random.normal(jax.random.key(1), (128, 1), dtype=jnp.float32) + x_out, (y_out, z_out) = g(x, y, z) + np.testing.assert_array_equal(x_out, x * 2) + np.testing.assert_array_equal(y_out, y) + np.testing.assert_array_equal(z_out, z + z) + + def test_separate_output_fusions_with_reused_value(self): + + @fuser.fusible(output_fusion_prefix=(True, True)) + def f(x_fn, y_fn, z_fns): + x = x_fn() + y = y_fn() + if z_fns is None: + z_fns = lambda x: x, lambda x: x + z_fn1, z_fn2 = z_fns + return z_fn1(x), z_fn2(y) + + @jax.jit + @fuser.fuse + def g(x, y, a): + x_res, y_res = f(x, y) + return y_res + a, (x_res * 2, x_res + x_res + a) + + x = jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + y = jax.random.normal(jax.random.key(1), (1, 128), dtype=jnp.float32) + a = jax.random.normal(jax.random.key(1), (1, 128), dtype=jnp.float32) + y_out, (x1_out, x2_out) = g(x, y, a) + np.testing.assert_array_equal(x1_out, x * 2) + np.testing.assert_array_equal(x2_out, x + x + a) + np.testing.assert_array_equal(y_out, y + a) + + def test_empty_fusion(self): + + @fuser.fusible + def f(x_fn, y_fn): + x = x_fn() + if y_fn is None: + y_fn = lambda x: x + return y_fn(x) + + @jax.jit + @fuser.fuse + def g(x, a): + _ = lax.dce_sink(f(x)) + return a + + x = jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + a = jax.random.normal(jax.random.key(1), (128, 128), dtype=jnp.float32) + y_out = g(x, a) + np.testing.assert_array_equal(y_out, a) + + +@dataclasses.dataclass(frozen=True) +class ArrayTuple: + x0: jax.Array + x1: jax.Array + + +@dataclasses.dataclass(frozen=True) +class ArrayTupleTy(hijax.HiType): + x0: jax_core.ShapedArray + x1: jax_core.ShapedArray + + def lo_ty(self) -> list[jax_core.ShapedArray]: + return [self.x0, self.x1] + + def lower_val(self, hi_val: ArrayTuple) -> list[jax.Array]: + return [hi_val.x0, hi_val.x1] + + def raise_val(self, x0, x1) -> ArrayTuple: + return ArrayTuple(x0, x1) + + +hijax.register_hitype( + ArrayTuple, lambda t: ArrayTupleTy(jax.typeof(t.x0), jax.typeof(t.x1)) +) + + +class FusionHijaxTest(jtu.JaxTestCase): + + def test_basic_fusion(self): + + @jax.jit + @fuser.fuse + @fuser.fusible + def f(x_fn, y_fn): + x = x_fn() + if y_fn is None: + y_fn = lambda x: x + return y_fn(x) + + xt = ArrayTuple(x0=jnp.ones((8, 8)), x1=jnp.zeros(4)) + ot = f(xt) + np.testing.assert_array_equal(ot.x0, xt.x0) + np.testing.assert_array_equal(ot.x1, xt.x1) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/gpu_attention_test.py b/tests/pallas/gpu_attention_test.py index 7f42793eb7c0..fed3946ad606 100644 --- a/tests/pallas/gpu_attention_test.py +++ b/tests/pallas/gpu_attention_test.py @@ -112,7 +112,7 @@ def test_mqa( k = random.normal(k2, (batch_size, seq_len, head_dim), dtype=jnp.float16) v = random.normal(k3, (batch_size, seq_len, head_dim), dtype=jnp.float16) - o, *res = decode_attention.mqa( + outputs = decode_attention.mqa( q, k, v, @@ -122,7 +122,7 @@ def test_mqa( normalize_output=normalize_output, interpret=self.INTERPRET, ) - o_ref, *res_ref = decode_attention.mqa_reference( + outputs_ref = decode_attention.mqa_reference( q, k, v, @@ -131,12 +131,17 @@ def test_mqa( return_residuals=return_residuals, normalize_output=normalize_output ) - np.testing.assert_allclose(o, o_ref, atol=0.05) + if return_residuals: - l, m = res[0] - l_ref, m_ref = res_ref[0] + o, (l, m) = outputs + o_ref, (l_ref, m_ref) = outputs_ref np.testing.assert_allclose(l, l_ref, atol=0.05) np.testing.assert_allclose(m, m_ref, atol=0.05) + else: + o = outputs + o_ref = outputs_ref + np.testing.assert_allclose(o, o_ref, atol=0.05) + self.assertTupleEqual(o.shape, q.shape) @parameterized.named_parameters(*[ ( @@ -163,6 +168,7 @@ def test_mqa( kwargs, ) in [ (1, 1024, 16, 4, 64, {}), + (2, 1024, 16, 4, 64, {}), (1, 1024, 16, 16, 64, {}), (1, 1024, 32, 32, 64, {}), ] @@ -196,7 +202,7 @@ def test_gqa( v = random.normal( k3, (batch_size, seq_len, num_kv_heads, head_dim), dtype=jnp.float16 ) - o, *res = decode_attention.gqa( + outputs = decode_attention.gqa( q, k, v, @@ -206,7 +212,7 @@ def test_gqa( normalize_output=normalize_output, interpret=self.INTERPRET, ) - o_ref, *res_ref = decode_attention.gqa_reference( + outputs_ref = decode_attention.gqa_reference( q, k, v, @@ -215,12 +221,16 @@ def test_gqa( return_residuals=return_residuals, normalize_output=normalize_output ) - np.testing.assert_allclose(o, o_ref, atol=0.05) if return_residuals: - l, m = res[0] - l_ref, m_ref = res_ref[0] + o, (l, m) = outputs + o_ref, (l_ref, m_ref) = outputs_ref np.testing.assert_allclose(l, l_ref, atol=0.05) np.testing.assert_allclose(m, m_ref, atol=0.05) + else: + o = outputs + o_ref = outputs_ref + np.testing.assert_allclose(o, o_ref, atol=0.05) + self.assertTupleEqual(o.shape, q.shape) class DecodeAttentionInterpretTest(DecodeAttentionTest): diff --git a/tests/pallas/gpu_ops_test.py b/tests/pallas/gpu_ops_test.py index a33760cbfa86..65cdfd93dae0 100644 --- a/tests/pallas/gpu_ops_test.py +++ b/tests/pallas/gpu_ops_test.py @@ -21,12 +21,9 @@ from absl.testing import absltest from absl.testing import parameterized import jax -from jax import lax from jax import random from jax._src import config from jax._src import test_util as jtu -from jax._src.lax.control_flow.for_loop import for_loop -from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr from jax.experimental import pallas as pl if sys.platform != "win32": from jax.experimental.pallas.ops.gpu import attention @@ -49,77 +46,6 @@ config.parse_flags_with_absl() -@functools.partial(jax.jit, static_argnames=["bm", "bn", "gm", "bk", - "interpret", "debug"]) -def matmul(x, y, *, bm, bn, gm, bk, interpret, debug=False): - m, n, k = x.shape[0], y.shape[1], x.shape[1] - @functools.partial( - pl.pallas_call, out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32), - interpret=interpret, - debug=debug, - grid=pl.cdiv(m, bm) * pl.cdiv(n, bn)) - def matmul_kernel(x_ref, y_ref, o_ref): - pid = pl.program_id(axis=0) - num_pid_m = m // bm - num_pid_n = n // bn - num_pid_in_group = gm * num_pid_n - group_id = lax.div(pid, num_pid_in_group) - first_pid_m = group_id * gm - group_size_m = jnp.minimum(num_pid_m - first_pid_m, gm) - pid_m = first_pid_m + lax.rem(pid, group_size_m) - pid_n = lax.div(lax.rem(pid, num_pid_in_group), group_size_m) - idx_m = pid_m * bm + jnp.arange(bm) - idx_n = pid_n * bn + jnp.arange(bn) - idx_m = pl.max_contiguous(pl.multiple_of(idx_m, bm), bm) - idx_n = pl.max_contiguous(pl.multiple_of(idx_n, bn), bn) - acc = jnp.zeros((bm, bn), dtype=jnp.float32) - def body(i, acc_ref): - idx_k = i * bk + jnp.arange(bk) - x_idx = ( - jax.lax.broadcast_in_dim(idx_m, (bm, bk), (0,)), - jax.lax.broadcast_in_dim(idx_k, (bm, bk), (1,))) - y_idx = ( - jax.lax.broadcast_in_dim(idx_k, (bk, bn), (0,)), - jax.lax.broadcast_in_dim(idx_n, (bk, bn), (1,))) - x_block, y_block = x_ref[x_idx], y_ref[y_idx] - out = pl.dot(x_block, y_block) - acc_ref[:, :] += out - acc = for_loop(k // bk, body, acc).astype(o_ref.dtype) - o_idx = ( - jax.lax.broadcast_in_dim(idx_m, (bm, bn), (0,)), - jax.lax.broadcast_in_dim(idx_n, (bm, bn), (1,)), - ) - o_ref[o_idx] = acc - return matmul_kernel(x, y) - - -@functools.partial(jax.jit, static_argnames=["bm", "bn", "bk", - "interpret", "debug"]) -def matmul_block_spec(x, y, *, bm, bn, bk, interpret, debug=False): - m, n, k = x.shape[0], y.shape[1], x.shape[1] - @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32), - interpret=interpret, - debug=debug, - in_specs=[ - pl.BlockSpec((bm, x.shape[1]), lambda i, _: (i, 0)), - pl.BlockSpec((y.shape[0], bn), lambda _, j: (0, j)), - ], - out_specs=pl.BlockSpec((bm, bn), lambda i, j: (i, j)), - grid=(pl.cdiv(m, bm), pl.cdiv(n, bn)), - ) - def matmul_kernel(x_ref, y_ref, o_ref): - acc = jnp.zeros(o_ref.shape, dtype=jnp.float32) - def body(i, acc_ref): - x_block = pl.load(x_ref, (slice(None), pl.ds(i * bk, bk))) - y_block = pl.load(y_ref, (pl.ds(i * bk, bk), slice(None))) - acc_ref[:, :] += pl.dot(x_block, y_block) - acc = for_loop(k // bk, body, acc).astype(o_ref.dtype) - o_ref[:, :] = acc - return matmul_kernel(x, y) - - @jtu.with_config(jax_traceback_filtering="off") class PallasBaseTest(jtu.JaxTestCase): INTERPRET = False @@ -136,7 +62,6 @@ def setUp(self): self.skipTest("Only works on non-Windows platforms") super().setUp() - _trace_kernel_to_jaxpr.cache_clear() def pallas_call(self, *args, **kwargs): return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET) @@ -153,7 +78,7 @@ def setUp(self): batch_size=(1, 2), seq_len=(128, 384), num_heads=(1, 2, 8), - head_dim=(32, 64, 128), + head_dim=(32, 64, 72, 128), block_sizes=( (("block_q", 128), ("block_k", 128)), (("block_q", 64), ("block_k", 64)), @@ -226,14 +151,14 @@ def impl(q, k, v): batch_size=(1, 2), seq_len=(128, 384), num_heads=(1, 2), - head_dim=(32, 64, 128,), + head_dim=(32, 64, 72, 128,), block_sizes=( ( ("block_q", 128), ("block_k", 128), - ("block_q_dkv", 128), - ("block_kv_dkv", 128), - ("block_q_dq", 128), + ("block_q_dkv", 32), + ("block_kv_dkv", 32), + ("block_q_dq", 32), ("block_kv_dq", 128), ), ( @@ -248,8 +173,8 @@ def impl(q, k, v): ("block_q", 64), ("block_k", 128), ("block_q_dkv", 64), - ("block_kv_dkv", 128), - ("block_q_dq", 128), + ("block_kv_dkv", 32), + ("block_q_dq", 32), ("block_kv_dq", 64), ), ), @@ -267,6 +192,10 @@ def test_fused_attention_bwd( causal, use_segment_ids, ): + if jtu.is_cuda_compute_capability_at_least("8.0"): + # TODO(b/416306534) + self.skipTest("Precision issues after CUDA 12.8.1 upgrade") + k1, k2, k3 = random.split(random.key(0), 3) q = random.normal( k1, (batch_size, seq_len, num_heads, head_dim), dtype=jnp.float16 @@ -302,6 +231,30 @@ def f_ref(q, k, v): self.assertAllClose(dk, dk_ref, atol=5e-2) self.assertAllClose(dv, dv_ref, atol=5e-2) + def test_return_residuals_not_differentiable(self): + batch_size, seq_len, num_heads, head_dim = 2, 128, 2, 128 + causal = False + k1, k2, k3 = random.split(random.key(0), 3) + q = random.normal( + k1, (batch_size, seq_len, num_heads, head_dim), dtype=jnp.float16 + ) + k = random.normal( + k2, (batch_size, seq_len, num_heads, head_dim), dtype=jnp.float16 + ) + v = random.normal( + k3, (batch_size, seq_len, num_heads, head_dim), dtype=jnp.float16 + ) + segment_ids = None + + def f(q, k, v): + return attention.mha(q, k, v, causal=causal, segment_ids=segment_ids, + interpret=self.INTERPRET, + return_residuals=True)[0].sum() + + with self.assertRaisesRegex(ValueError, "Kernel differentiation is not" + " supported if return_residuals is True."): + _ = jax.grad(f, argnums=(0, 1, 2))(q, k, v) + class FusedAttentionInterpretTest(FusedAttentionTest): INTERPRET = True diff --git a/tests/pallas/gpu_paged_attention_test.py b/tests/pallas/gpu_paged_attention_test.py index 081051f15dae..6a1d8de22a78 100644 --- a/tests/pallas/gpu_paged_attention_test.py +++ b/tests/pallas/gpu_paged_attention_test.py @@ -44,9 +44,11 @@ def _generate_qkv( k_pages = jax.random.normal( k1, (num_kv_heads, total_pages, page_size, head_dim), dtype=dtype ) + k_pages = k_pages / jnp.linalg.norm(k_pages, axis=-1)[..., None] v_pages = jax.random.normal( k2, (num_kv_heads, total_pages, page_size, head_dim), dtype=dtype ) + v_pages = v_pages / jnp.linalg.norm(v_pages, axis=-1)[..., None] block_tables = jnp.arange( batch_size * max_num_blocks_per_seq, dtype=jnp.int32 @@ -54,6 +56,7 @@ def _generate_qkv( block_tables = jax.random.permutation(k3, block_tables, independent=True) block_tables = block_tables.reshape(batch_size, max_num_blocks_per_seq) q = jax.random.normal(k4, (batch_size, num_heads, head_dim), dtype=dtype) + q = q / jnp.linalg.norm(q, axis=-1)[..., None] return q, k_pages, v_pages, block_tables @@ -72,6 +75,17 @@ def fn(_block_tables, _pages): return out +def _quantize(x: jax.Array, dtype=jnp.int8): + if isinstance(dtype, jnp.floating): + max_val = jnp.astype(jnp.finfo(dtype).max, x.dtype) + else: + max_val = 127 + x_scale = jnp.max(jnp.abs(x), axis=-1) / (0.95 * max_val) + x_quant = (x / x_scale[..., None]) + if isinstance(dtype, jnp.floating): + x_quant = jnp.rint(x_quant) + return x_quant.astype(dtype), x_scale.astype(x.dtype) + @jtu.with_config(jax_traceback_filtering="off") class PallasBaseTest(jtu.JaxTestCase): @@ -93,12 +107,69 @@ def setUp(self): super().setUp() - class PagedAttentionKernelTest(PallasBaseTest): def setUp(self): super().setUp() + def _estimate_shared_memory_bytes(self, block_h, pages_per_compute_block, + page_size, head_dim, dtype): + """Estimate shared memory usage for paged attention kernel.""" + dtype_size = jnp.dtype(dtype).itemsize + # Approximate calculation based on kernel's memory usage + # Q block: block_h * head_dim + # K/V blocks: pages_per_compute_block * page_size * head_dim + # Plus accumulators and intermediate values + block_k = pages_per_compute_block * page_size + estimated = dtype_size * ( + block_h * head_dim + # Q + 2 * block_k * head_dim + # K and V + block_h * block_k + # logits/attention weights + block_h * 8 # accumulators (m, l, etc.) in float32 + ) + return estimated + + def _adjust_params_for_shared_memory(self, block_h, pages_per_compute_block, + page_size, head_dim, dtype): + """Adjust parameters to fit within device shared memory limits. + + Uses XLA's DeviceDescription.shared_memory_per_block_optin() to query + the actual device capability rather than hardcoding values. + """ + try: + device = jax.local_devices()[0] + # Query XLA DeviceDescription for max shared memory per block + # This is exposed from stream_executor::DeviceDescription::shared_memory_per_block_optin() + max_smem = device.shared_memory_per_block_optin + except (AttributeError, IndexError): + # Fallback if XLA doesn't expose shared_memory_per_block_optin (older versions) + # or if no devices are available. Use conservative 48KB (safe for most GPUs). + max_smem = 48 * 1024 + + estimated = self._estimate_shared_memory_bytes( + block_h, pages_per_compute_block, page_size, head_dim, dtype) + + # If within limits, no adjustment needed + if estimated <= max_smem: + return block_h, pages_per_compute_block, page_size + + # Try to reduce parameters to fit + while estimated > max_smem: + if pages_per_compute_block > 2: + pages_per_compute_block = pages_per_compute_block // 2 + elif page_size > 8: + page_size = page_size // 2 + elif block_h > 8: + block_h = block_h // 2 + else: + # Can't reduce further, will need to skip + return None, None, None + + estimated = self._estimate_shared_memory_bytes( + block_h, pages_per_compute_block, page_size, head_dim, dtype) + + return block_h, pages_per_compute_block, page_size + @jtu.sample_product( dtype=(jnp.float16,), page_size=(8, 16, 32), @@ -154,6 +225,94 @@ def test_paged_attention( self.assertArraysAllClose(o, o_ref, rtol=5e-2, atol=5e-2) + @jtu.sample_product( + dtype=(jnp.float16,), + page_size=(8, 16, 32), + num_kv_heads=(1, 2), + q_kv_head_ratio=(2, 16, 20), + head_dim=(32, 64), + block_h=(16, 32), + pages_per_compute_block=(4, 8), + k_splits=(4, 16), + attn_logits_soft_cap=(None,), + quantize_k=(True, False), + quantize_v=(True, False), + quant_dtype=(jnp.float8_e5m2, jnp.float8_e4m3fn, jnp.int8), + ) + def test_quantized_paged_attention( + self, + dtype, + page_size, + num_kv_heads, + q_kv_head_ratio, + head_dim, + block_h, + pages_per_compute_block, + k_splits, + attn_logits_soft_cap, + quantize_k, + quantize_v, + quant_dtype, + ): + if not quantize_k and not quantize_v: + self.skipTest("Skipping since neither (k, v) quantization requested.") + if (quant_dtype == jnp.float8_e4m3fn + and not jtu.is_cuda_compute_capability_at_least("8.9")): + self.skipTest("Skipping since float8_e4m3fn is not supported on < sm89") + + # Check and adjust parameters if needed to fit device limits for ROCm + if jtu.is_device_rocm(): + adjusted = self._adjust_params_for_shared_memory( + block_h, pages_per_compute_block, page_size, head_dim, dtype) + + if adjusted == (None, None, None): + self.skipTest("Cannot adjust parameters to fit ROCm device shared memory limits") + + block_h, pages_per_compute_block, page_size = adjusted + + max_kv_len = 2048 + seq_lens = np.asarray([3, 256, 513, 1023, 2048], dtype=jnp.int32) + q, k_pages, v_pages, block_tables = _generate_qkv( + seq_lens.shape[0], + page_size, + max_kv_len, + num_kv_heads, + num_kv_heads * q_kv_head_ratio, + head_dim, + jax.random.key(0), + dtype, + ) + k = _reconstruct_kv(block_tables, k_pages) + v = _reconstruct_kv(block_tables, v_pages) + + k_, k_scales = (_quantize(k_pages, quant_dtype) + if quantize_k else (k_pages, None)) + v_, v_scales = (_quantize(v_pages, quant_dtype) + if quantize_v else (v_pages, None)) + + o = paged_attention.paged_attention( + q, + k_, + v_, + block_tables, + seq_lens, + k_scales_pages=k_scales, + v_scales_pages=v_scales, + block_h=block_h, + pages_per_compute_block=pages_per_compute_block, + k_splits=k_splits, + attn_logits_soft_cap=attn_logits_soft_cap, + interpret=self.INTERPRET, + ) + + o_ref = paged_attention.paged_attention_reference(q, k, v, lengths=seq_lens) + + error = (jnp.linalg.norm((o - o_ref).astype(jnp.float32), axis=-1) + / jnp.linalg.norm(o_ref.astype(jnp.float32))) + + admissible_error = 3e-1 + self.assertLessEqual(jnp.mean(error), admissible_error) + class PagedAttentionInterpretTest(PagedAttentionKernelTest): INTERPRET = True diff --git a/tests/pallas/gpu_pallas_distributed_test.py b/tests/pallas/gpu_pallas_distributed_test.py new file mode 100644 index 000000000000..165e543c1d88 --- /dev/null +++ b/tests/pallas/gpu_pallas_distributed_test.py @@ -0,0 +1,737 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for distributed pallas GPU operations.""" + +import functools +import os + +from absl.testing import parameterized +import jax +from jax import lax +from jax._src import test_multiprocess as jt_multiprocess +from jax._src import test_util as jtu +from jax.experimental import multihost_utils +from jax.experimental import pallas as pl +import jax.experimental.mosaic.gpu as mgpu +from jax.experimental.pallas import mosaic_gpu as plgpu +from jax.experimental.pallas.ops.gpu.reduce_scatter_mgpu import reduce_scatter +from jax.experimental.pallas.ops.gpu.all_gather_mgpu import all_gather +import jax.numpy as jnp +import numpy as np + + +P = jax.sharding.PartitionSpec +partial = functools.partial + + +class TestCase(jt_multiprocess.MultiProcessTest): + + def setUp(self): + # Check mosaic support first (before GPU capability check) + if not mgpu.supports_cross_device_collectives(): + if jtu.test_device_matches(["rocm"]): + self.skipTest("Mosaic not supported on ROCm currently.") + else: + self.skipTest("NVSHMEM library unavailable.") + if (not jtu.test_device_matches(["cuda"]) or + not jtu.is_cuda_compute_capability_at_least("9.0")): + self.skipTest("Only works on GPU with capability >= sm90") + if jax.process_count() == 1: + self.skipTest("Test requires multiple processes.") + if os.environ.get("XLA_PYTHON_CLIENT_ALLOCATOR", "") == "platform": + self.skipTest("NVSHMEM doesn't work with the platform allocator.") + super().setUp() + +class PallasCallRemoteDMATest(TestCase): + + def test_remote_dma_basic(self): + if jax.process_index() > 2: + return # Only 2 processes needed. + def kernel(x_ref, y_ref, ready_sem, recv_sem): + other_dev_id = 1 - lax.axis_index('x') + y_ref[...] = x_ref[...] + pl.semaphore_signal(ready_sem, device_id=other_dev_id) + pl.semaphore_wait(ready_sem) + neighbor_ptr = plgpu.remote_ref(y_ref, other_dev_id) + neighbor_ptr[...] = x_ref[...] + pl.semaphore_signal(recv_sem, device_id=other_dev_id) + pl.semaphore_wait(recv_sem) + + x = jnp.arange(2 * 8 * 128.0, dtype=jnp.float32).reshape((2 * 8, 128)) + def body(x): + return pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + scratch_shapes=[ + plgpu.SemaphoreType.REGULAR, + plgpu.SemaphoreType.REGULAR, + ], + )(x) + + devices = jax.devices()[:2] + mesh = jax.sharding.Mesh(devices, ['x']) + y = jax.jit( + jax.shard_map( + body, mesh=mesh, in_specs=P('x'), out_specs=P('x'), check_vma=False, + ) + )(x) + + expected = x[8:] if jax.process_index() == 0 else x[:8] + np.testing.assert_allclose(y.addressable_shards[0].data, expected) + + @parameterized.parameters(('x',), ('y',)) + def test_remote_dma_2d_mesh(self, axis): + if jax.process_count() < 4: + self.skipTest('Test requires at least 4 devices (and processes).') + if jax.process_index() > 4: + return # Only 4 processes needed. + def kernel(x_ref, y_ref, recv_sem): + other_dev_id = {axis: 1 - lax.axis_index(axis)} + other_y_ref = plgpu.remote_ref(y_ref, other_dev_id) + other_y_ref[...] = x_ref[...] + pl.semaphore_signal(recv_sem, device_id=other_dev_id) + pl.semaphore_wait(recv_sem) + + x = jnp.arange(2 * 8 * 128.0, dtype=jnp.float32).reshape((2 * 8, 128)) + def body(x): + return pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + scratch_shapes=[plgpu.SemaphoreType.REGULAR], + )(x) + + devices = jax.devices()[:4] + mesh = jax.sharding.Mesh(np.asarray(devices).reshape(2, 2), ['x', 'y']) + y = jax.jit( + jax.shard_map( + body, mesh=mesh, in_specs=P(axis), out_specs=P(axis), check_vma=False, + ) + )(x) + + expected = x[8:] if jax.process_index() == 0 else x[:8] + np.testing.assert_allclose(y.addressable_shards[0].data, expected) + + def test_wait_twice(self): + if jax.process_index() > 2: + return # Only 2 processes needed. + + def kernel(y_ref, sem): + other_dev_id = 1 - lax.axis_index('x') + pl.semaphore_signal(sem, 2, device_id=other_dev_id) + pl.semaphore_wait(sem) + pl.semaphore_wait(sem) + y_ref[...] = jnp.ones_like(y_ref) + + kernel_call = pl.pallas_call( + kernel, + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + scratch_shapes=[plgpu.SemaphoreType.REGULAR], + ) + + devices = jax.devices()[:2] + mesh = jax.sharding.Mesh(devices, ['x']) + y = jax.jit( + jax.shard_map( + kernel_call, mesh=mesh, in_specs=(), out_specs=P(None), check_vma=False, + ) + )() + np.testing.assert_allclose(y, jnp.ones_like(y)) + + def test_wait_nodec(self): + if jax.process_index() > 2: + return # Only 2 processes needed. + + def kernel(y_ref, sem): + other_dev_id = 1 - lax.axis_index('x') + pl.semaphore_signal(sem, 2, device_id=other_dev_id) + pl.semaphore_wait(sem, decrement=False) + pl.semaphore_wait(sem, 2, decrement=False) + pl.semaphore_wait(sem, 2) + y_ref[...] = jnp.ones_like(y_ref) + + kernel_call = pl.pallas_call( + kernel, + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + scratch_shapes=[plgpu.SemaphoreType.REGULAR], + ) + + devices = jax.devices()[:2] + mesh = jax.sharding.Mesh(devices, ['x']) + y = jax.jit( + jax.shard_map( + kernel_call, mesh=mesh, in_specs=(), out_specs=P(None), check_vma=False, + ) + )() + np.testing.assert_allclose(y, jnp.ones_like(y)) + + def test_signal_parallel(self): + if jax.process_index() > 2: + return # Only 2 processes needed. + + def kernel(y_ref, sem, sem2): + other_dev_id = 1 - lax.axis_index('x') + plgpu.semaphore_signal_parallel( + plgpu.SemaphoreSignal(sem, device_id=other_dev_id), + plgpu.SemaphoreSignal(sem2, device_id=other_dev_id), + ) + pl.semaphore_wait(sem) + pl.semaphore_wait(sem2) + y_ref[...] = jnp.ones_like(y_ref) + + kernel_call = pl.pallas_call( + kernel, + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + scratch_shapes=[plgpu.SemaphoreType.REGULAR] * 2, + ) + + devices = jax.devices()[:2] + mesh = jax.sharding.Mesh(devices, ['x']) + y = jax.jit( + jax.shard_map( + kernel_call, mesh=mesh, in_specs=(), out_specs=P(None), check_vma=False, + ) + )() + np.testing.assert_allclose(y, jnp.ones_like(y)) + + def test_semaphore_signal_collective_axes(self): + if jax.process_index() > 2: + return # Only 2 processes needed. + + def kernel(y_ref, sem): + plgpu.semaphore_signal_multicast(sem, collective_axes='x') + # Wait for the multicast signal (each device gets signaled by all devices) + pl.semaphore_wait(sem, 2) # Wait for signals from both devices + y_ref[...] = jnp.ones_like(y_ref) + + kernel_call = pl.pallas_call( + kernel, + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + scratch_shapes=[plgpu.SemaphoreType.REGULAR], + ) + + devices = jax.devices()[:2] + mesh = jax.sharding.Mesh(devices, ['x']) + y = jax.jit( + jax.shard_map( + kernel_call, mesh=mesh, in_specs=(), out_specs=P(None), check_vma=False, + ) + )() + np.testing.assert_allclose(y, jnp.ones_like(y)) + + def test_permuted_mesh(self): + def kernel(y_ref, sem): + other_dev_id = 1 - lax.axis_index('x') + pl.semaphore_signal(sem, 1, device_id=other_dev_id) + pl.semaphore_wait(sem) + + kernel_call = pl.pallas_call( + kernel, + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + scratch_shapes=[plgpu.SemaphoreType.REGULAR], + ) + mesh = jax.sharding.Mesh(jax.devices()[::-1], ['x']) # Reverse the devices. + f = jax.jit( + jax.shard_map( + kernel_call, mesh=mesh, in_specs=(), out_specs=P(None), check_vma=False, + ) + ) + msg = ( + 'Mosaic GPU only supports meshes with device ordering that follows' + ' row-major device ids.' + ) + with self.assertRaisesRegex(NotImplementedError, msg): + f() + + @parameterized.parameters(False, True) + def test_copy_tma(self, use_dict): + if jax.process_index() > 2: + return # Only 2 processes needed. + + def kernel(y_ref, smem_ref, sem): + dev_id = lax.axis_index("y") + other_dev_id = 1 - dev_id + if use_dict: + ids = lambda x, y: dict(x=x, y=y) + else: + ids = lambda x, y: (x, y) + + # Device ID must be an int32. + zero = jnp.int32(0) + + @pl.when(dev_id == zero) + def _store(): + output = plgpu.layout_cast(lax.broadcasted_iota(jnp.int32, (128, 128), 1), plgpu.Layout.WGMMA) + smem_ref[...] = output + plgpu.copy_smem_to_gmem(smem_ref, plgpu.remote_ref(y_ref, ids(zero, dev_id))) + plgpu.copy_smem_to_gmem(smem_ref, plgpu.remote_ref(y_ref, ids(zero, other_dev_id))) + plgpu.wait_smem_to_gmem(0) + pl.semaphore_signal(sem, 1, device_id=ids(zero, other_dev_id)) + pl.semaphore_wait(sem) + + transforms = (plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(128)) + kernel_call = pl.pallas_call( + kernel, + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct((128, 128), jnp.int32), + scratch_shapes=[ + plgpu.SMEM((128, 128), jnp.int32, transforms=transforms), + plgpu.SemaphoreType.REGULAR, + ], + ) + mesh = jtu.create_mesh((1, 2), ("x", "y")) + y = jax.jit( + jax.shard_map( + kernel_call, mesh=mesh, in_specs=(), out_specs=P("y"), check_vma=False, + ) + )() + y = multihost_utils.process_allgather(y, tiled=True) + ref = lax.broadcasted_iota(jnp.int32, (128, 128), 1) + np.testing.assert_array_equal(y, np.concat([ref, ref], axis=0)) + + +class PallasCallMultimemTest(TestCase): + + def _get_reduction_impl(self, reduction): + match reduction: + case "add": + return jnp.add + case "min": + return jnp.minimum + case "max": + return jnp.maximum + case "and": + return jnp.bitwise_and + case "or": + return jnp.bitwise_or + case "xor": + return jnp.bitwise_xor + case _: + raise ValueError(reduction) + + def test_multimem_store_regs(self): + if jax.process_index() > 2: + return # Only 2 processes needed. + + def kernel(y_ref, sem): + @pl.when(lax.axis_index('x') == 0) + def _store(): + output = plgpu.layout_cast(lax.broadcasted_iota(jnp.int32, (128, 128), 1), plgpu.Layout.WGMMA) + plgpu.multimem_store(output, y_ref, 'x') + other_dev_id = 1 - lax.axis_index('x') + pl.semaphore_signal(sem, 1, device_id=other_dev_id) + pl.semaphore_wait(sem) + + kernel_call = pl.pallas_call( + kernel, + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct((128, 128), jnp.int32), + scratch_shapes=[plgpu.SemaphoreType.REGULAR], + ) + mesh = jax.sharding.Mesh(jax.devices(), ['x']) + y = jax.jit( + jax.shard_map( + kernel_call, mesh=mesh, in_specs=(), out_specs=P("x"), check_vma=False, + ) + )() + y = multihost_utils.process_allgather(y, tiled=True) + ref = lax.broadcasted_iota(jnp.int32, (128, 128), 1) + np.testing.assert_array_equal(y, np.concat([ref, ref], axis=0)) + + def test_multimem_store_tma(self): + if jax.process_index() > 2: + return # Only 2 processes needed. + + def kernel(y_ref, smem_ref, sem): + @pl.when(lax.axis_index('x') == 0) + def _store(): + output = plgpu.layout_cast(lax.broadcasted_iota(jnp.int32, (128, 128), 1), plgpu.Layout.WGMMA) + smem_ref[...] = output + plgpu.copy_smem_to_gmem(smem_ref, plgpu.multicast_ref(y_ref, 'x')) + plgpu.wait_smem_to_gmem(0) + other_dev_id = 1 - lax.axis_index('x') + pl.semaphore_signal(sem, 1, device_id=other_dev_id) + pl.semaphore_wait(sem) + + transforms = (plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(128)) + kernel_call = pl.pallas_call( + kernel, + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct((128, 128), jnp.int32), + scratch_shapes=[ + plgpu.SMEM((128, 128), jnp.int32, transforms=transforms), + plgpu.SemaphoreType.REGULAR, + ], + ) + mesh = jax.sharding.Mesh(jax.devices(), ['x']) + y = jax.jit( + jax.shard_map( + kernel_call, mesh=mesh, in_specs=(), out_specs=P("x"), check_vma=False, + ) + )() + y = multihost_utils.process_allgather(y, tiled=True) + ref = lax.broadcasted_iota(jnp.int32, (128, 128), 1) + np.testing.assert_array_equal(y, np.concat([ref, ref], axis=0)) + + @parameterized.parameters( + (jnp.int32, 1, "add"), + (jnp.int32, 1, "min"), + (jnp.int32, 1, "max"), + (jnp.int32, 1, "and"), + (jnp.int32, 1, "or"), + (jnp.int32, 1, "xor"), + (jnp.float32, 1, "add"), + (jnp.float32, 2, "add", True), + (jnp.float32, 4, "add"), + (jnp.float16, 2, "add"), + (jnp.float16, 2, "min"), + (jnp.float16, 4, "max"), + (jnp.float16, 8, "add", True), + (jnp.bfloat16, 2, "max"), + (jnp.bfloat16, 8, "add"), + (jnp.float8_e5m2, 4, "add"), + (jnp.float8_e5m2, 8, "min"), + (jnp.float8_e5m2, 16, "max", True), + (jnp.float8_e4m3fn, 4, "min", True), + (jnp.float8_e4m3fn, 8, "max"), + (jnp.float8_e4m3fn, 16, "add"), + ) + def test_multimem_load_reduce(self, dtype, vector_length, reduction, tiled_layout=False): + if dtype in ( + jnp.float8_e5m2, + jnp.float8_e4m3fn, + ) and not jtu.is_cuda_compute_capability_at_least("10.0"): + self.skipTest("Only works on GPU with capability >= sm100") + if jax.process_index() > 2: + return # Only 2 processes needed. + devices = jax.devices()[:2] + + def kernel(x_ref, y_ref, _, sem_ref): + if tiled_layout: + layout = plgpu.Layout.TILED( + plgpu.Tiling( + ( + (64, 2 * vector_length), + (16, 2 * vector_length), + (vector_length,), + ) + ), + warp_dims=(-5,), + lane_dims=(-3, -2), + vector_dim=-1, + ) + else: + layout = plgpu.Layout.WG_STRIDED((64, 32), vec_size=vector_length) + y_ref[...] = plgpu.layout_cast( + plgpu.multimem_load_reduce( + x_ref.at[16:-16], collective_axes="x", reduction_op=reduction, + ), + layout + ) + my_device = lax.axis_index("x") + other_device = 1 - my_device + pl.semaphore_signal(sem_ref, 1, device_id=other_device) + pl.semaphore_wait(sem_ref) + + # The rounding we see in low precision types seems to be different from + # what JAX/XLA use. + match jnp.dtype(dtype).itemsize: + case 4: + bound = 800000 + case 2: + bound = 128 + case 1: + bound = 4 + case _: + raise ValueError(f"Unsupported dtype: {dtype}") + x_local = jax.random.randint( + jax.random.key(1234), (128 + 64, 32), dtype=jnp.int32, minval=-bound, maxval=bound, + ).astype(dtype) + mesh = jax.sharding.Mesh(devices, ("x",)) + x_shard = jax.ShapeDtypeStruct((64 + 32, 32), dtype) + y_shape = jax.ShapeDtypeStruct((64, 32), dtype) + y, _ = jax.jit( + jax.shard_map( + pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=[ + pl.BlockSpec(memory_space=plgpu.SMEM), + pl.BlockSpec(memory_space=plgpu.GMEM), + ], + out_shape=(y_shape, x_shard), + scratch_shapes=[plgpu.SemaphoreType.REGULAR], + # TODO(b/448323639): Without aliasing XLA doesn't actually + # insert the copy that puts the operand in symmetric memory, + # which causes the kernel to crash. + input_output_aliases={0: 1}, + ), + mesh=mesh, + in_specs=P("x"), + out_specs=P("x"), # Not really, but lets us test. + check_vma=False, + ) + )(x_local) + y = multihost_utils.process_allgather(y, tiled=True) + np_reduction = self._get_reduction_impl(reduction) + np.testing.assert_array_equal( + y.astype(jnp.float32), + np.tile(np_reduction(x_local[16:64+16], x_local[64+48:128+48]), (2, 1)), + ) + + def _test_reduce_scatter( + self, + shape, + dtype, + reduction, + scatter_dimension=0, + tile_size=None, + vec_size=None, + num_blocks=None, + ): + if jax.process_index() > 2: + return + + devices = jax.devices()[:2] + mesh = jax.sharding.Mesh(devices, ["x"]) + if jnp.issubdtype(dtype, jnp.floating): + x = jax.random.uniform(jax.random.key(42), shape, dtype=dtype, minval=-1.0, maxval=1.0) + else: + x = jax.random.randint(jax.random.key(42), shape, dtype=dtype, minval=-1000, maxval=1000) + + def body(x): + return reduce_scatter( + x, + axis_name="x", + scatter_dimension=scatter_dimension, + reduction=reduction, + vec_size=vec_size, + tile_size=tile_size, + num_blocks=num_blocks, + ) + + spec = P(*([None] * scatter_dimension), "x") + y = jax.jit( + jax.shard_map( + body, mesh=mesh, in_specs=spec, out_specs=spec, check_vma=False + ) + )(x) + + y = multihost_utils.process_allgather(y, tiled=True) + np_reduction = self._get_reduction_impl(reduction) + + split_idx = x.shape[scatter_dimension] // 2 + slices_first = [slice(None)] * len(shape) + slices_first[scatter_dimension] = slice(None, split_idx) + slices_second = [slice(None)] * len(shape) + slices_second[scatter_dimension] = slice(split_idx, None) + expected = np_reduction(x[tuple(slices_first)], x[tuple(slices_second)]) + tol = 1e-5 if reduction == "add" else 0 + np.testing.assert_allclose(y, expected, rtol=tol, atol=tol) + + @parameterized.parameters( + (jnp.float32, "add", 1), + (jnp.float16, "add", 2), + (jnp.bfloat16, "add", 2), + (jnp.float16, "min", 4), + (jnp.float16, "max", 8), + (jnp.int32, "add", 1), + ) + def test_reduce_scatter(self, dtype, reduction, vec_size): + # 16 rows * 64 cols = 1024 elements = 8 elements per thread + self._test_reduce_scatter( + (1024, 64), dtype, reduction, tile_size=1024, vec_size=vec_size, num_blocks=4 + ) + + def test_reduce_scatter_large_minor_dims(self): + self._test_reduce_scatter( + (512, 32768), jnp.float16, "add", tile_size=8192, vec_size=4, num_blocks=4 + ) + + @parameterized.parameters(2048, 256, None) + def test_reduce_scatter_auto_vec_size(self, tile_size): + self._test_reduce_scatter( + (1024, 64), jnp.float16, "add", tile_size=tile_size, vec_size=None, num_blocks=4 + ) + + @parameterized.parameters(2048, 256, None) + def test_reduce_scatter_auto_vec_size_int(self, tile_size): + self._test_reduce_scatter( + (1024, 64), jnp.int32, "add", tile_size=tile_size, vec_size=None, num_blocks=4 + ) + + @parameterized.parameters(1, 2) + def test_reduce_scatter_different_axes(self, axis): + if axis == 1: + shape = (64, 1024, 32) + tile_size = 2048 + else: # axis == 2 + shape = (32, 64, 1024) + tile_size = 2048 + self._test_reduce_scatter( + shape, jnp.float16, "add", scatter_dimension=axis, tile_size=tile_size, vec_size=None, num_blocks=4 + ) + + @parameterized.parameters( + (jnp.float16, "add"), + (jnp.float32, "add"), + (jnp.bfloat16, "max"), + ) + def test_all_reduce(self, dtype, reduction): + """Test all-reduce functionality when scatter_dimension=None.""" + self._test_all_reduce( + (1024, 1024), dtype, reduction, tile_size=512, vec_size=None, num_blocks=4 + ) + + def _test_all_reduce( + self, + shape, + dtype, + reduction, + tile_size=None, + vec_size=None, + num_blocks=None, + ): + """Helper function to test all-reduce functionality.""" + devices = jax.devices()[:2] + mesh = jax.sharding.Mesh(devices, ['x']) + x = jax.random.normal(jax.random.key(42), (2, *shape), dtype) + + def body(x): + return reduce_scatter( + x, + axis_name="x", + scatter_dimension=None, # All-reduce mode + reduction=reduction, + vec_size=vec_size, + tile_size=tile_size, + num_blocks=num_blocks, + ) + + spec = P("x") + y = jax.jit( + jax.shard_map( + body, mesh=mesh, in_specs=spec, out_specs=spec, check_vma=False + ) + )(x) + y = multihost_utils.process_allgather(y, tiled=True) + np_reduction = self._get_reduction_impl(reduction) + expected = np_reduction(x[0], x[1]) + tol = 1e-5 if reduction == "add" else 0 + for ys in y: + # It seems that the rounding used by the switch is different from what + # XLA uses. + y_rounded = np.nextafter(ys, expected) + np.testing.assert_allclose(y_rounded, expected, rtol=tol, atol=tol) + + def _test_all_gather( + self, + shape, + dtype, + gather_dimension=0, + tile_size=None, + vec_size=None, + num_blocks=None, + ): + if jax.process_index() > 2: + return + + if jnp.issubdtype(dtype, jnp.floating): + x = jax.random.uniform(jax.random.key(42), shape, dtype=dtype, minval=-1.0, maxval=1.0) + else: + x = jax.random.randint(jax.random.key(42), shape, dtype=dtype, minval=-1000, maxval=1000) + + def body(x): + return all_gather( + x, + axis_name="x", + gather_dimension=gather_dimension, + vec_size=vec_size, + tile_size=tile_size, + num_blocks=num_blocks, + ) + + spec = P(*([None] * gather_dimension), "x") + devices = jax.devices()[:2] + mesh = jax.sharding.Mesh(devices, ["x"]) + y = jax.jit( + jax.shard_map( + body, mesh=mesh, in_specs=spec, out_specs=spec, check_vma=False + ) + )(x) + y = multihost_utils.process_allgather(y, tiled=True) + repeats = [1] * len(x.shape) + repeats[gather_dimension] = 2 + np.testing.assert_array_equal(y, np.tile(x, repeats)) + + @parameterized.parameters( + (jnp.float32, 1), + (jnp.float16, 2), + (jnp.bfloat16, 2), + (jnp.float16, 4), + (jnp.float16, 8), + (jnp.int32, 1), + ) + def test_all_gather(self, dtype, vec_size): + # 16 rows * 64 cols = 1024 elements = 8 elements per thread + self._test_all_gather( + (1024, 64), dtype, tile_size=1024, vec_size=vec_size, num_blocks=4 + ) + + def test_all_gather_large_minor_dims(self): + self._test_all_gather( + (512, 32768), jnp.float16, tile_size=8192, vec_size=4, num_blocks=4 + ) + + @parameterized.parameters(2048, 256, None) + def test_all_gather_auto_vec_size(self, tile_size): + self._test_all_gather( + (1024, 64), jnp.float16, tile_size=tile_size, vec_size=None, num_blocks=4 + ) + + @parameterized.parameters(2048, 256, None) + def test_all_gather_auto_vec_size_int(self, tile_size): + self._test_all_gather( + (1024, 64), jnp.int32, tile_size=tile_size, vec_size=None, num_blocks=4 + ) + + @parameterized.parameters(1, 2) + def test_all_gather_different_axes(self, axis): + if axis == 1: + shape = (64, 1024, 32) + tile_size = 2048 + else: # axis == 2 + shape = (32, 64, 1024) + tile_size = 2048 + self._test_all_gather( + shape, jnp.float16, gather_dimension=axis, tile_size=tile_size, vec_size=None, num_blocks=4 + ) + + +if __name__ == '__main__': + # This test doesn't work with the platform allocator, so we override it + # if it's ran alone. If it's part of a larger test suite and the platform + # allocator is used, setUp will skip the test. + os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.01' + os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'default' + jt_multiprocess.main() diff --git a/tests/pallas/gpu_pallas_interpret_test.py b/tests/pallas/gpu_pallas_interpret_test.py new file mode 100644 index 000000000000..424423d6bee8 --- /dev/null +++ b/tests/pallas/gpu_pallas_interpret_test.py @@ -0,0 +1,243 @@ +# Copyright 2026 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +from absl.testing import absltest +import jax +from jax._src import test_util as jtu +from jax._src.pallas.mosaic_gpu.interpret import interpret_pallas_call as mosaic_interpret +from jax.experimental import pallas as pl +from jax.experimental.pallas import mosaic_gpu as plgpu +import jax.numpy as jnp +import numpy as np + + +jax.config.parse_flags_with_absl() + + +# TODO(nrink): Figure out how to safely run different instance of GPU +# interpret mode in parallel, and then remove this decorator. +@jtu.thread_unsafe_test_class() +class InterpretTest(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + + if not jtu.test_device_matches(['cpu']): + self.skipTest('CPU-only test') + + self.num_devices = jax.device_count() + if self.num_devices > 1: + self.skipTest(f'requires 1 device, found {self.num_devices}') + + def test_interpret_pallas_call(self): + def _kernel(o_ref): + o_ref[0] = 42 + + @jax.jit + def kernel(): + return pl.pallas_call( + _kernel, + out_shape=jax.ShapeDtypeStruct((1,), jnp.int32), + interpret=mosaic_interpret.InterpretParams(detect_races=True), + )() + + np.testing.assert_equal(kernel(), np.array([42], dtype=jnp.int32)) + self.assertFalse(mosaic_interpret.get_races().races_found) + + @jtu.parameterized.parameters(1, 2, 4, 8, 16) + def test_interpret_core_map(self, num_threads: int): + @pl.run_state + def kernel(o_ref): + mesh = plgpu.Mesh(num_threads=num_threads, thread_name='x') + + @pl.core_map( + mesh, + interpret=mosaic_interpret.InterpretParams(detect_races=True), + ) + def _(): + thread_idx = jax.lax.axis_index('x') + o_ref[thread_idx] = thread_idx + + y = kernel(jnp.zeros((num_threads,), jnp.int32)) + np.testing.assert_equal(y, np.arange(num_threads, dtype=jnp.int32)) + self.assertFalse(mosaic_interpret.get_races().races_found) + + def test_interpret_core_map_with_race(self): + @pl.run_state + def kernel(o_ref): + mesh = plgpu.Mesh(num_threads=2, thread_name='x') + + @pl.core_map( + mesh, + interpret=mosaic_interpret.InterpretParams(detect_races=True), + ) + def _(): + thread_idx = jax.lax.axis_index('x') + o_ref[...] = thread_idx + + kernel(jnp.zeros((), jnp.int32)) + self.assertTrue(mosaic_interpret.get_races().races_found) + + @jtu.parameterized.parameters(1, 2, 4, 8, 16) + def test_interpret_kernel(self, num_threads): + @functools.partial( + plgpu.kernel, + out_shape=jax.ShapeDtypeStruct((num_threads,), jnp.int32), + num_threads=num_threads, + thread_name='x', + interpret=mosaic_interpret.InterpretParams(detect_races=True), + ) + def _kernel(o_ref): + thread_idx = jax.lax.axis_index('x') + o_ref[thread_idx] = thread_idx + + np.testing.assert_equal(jax.jit(_kernel)(), np.arange(num_threads)) + self.assertFalse(mosaic_interpret.get_races().races_found) + + def test_skip_floating_point_ops(self): + def matmul_kernel(x_ref, y_ref, z_ref): + z_ref[...] = x_ref[...] @ y_ref[...] + + def matmul(x: jax.Array, y: jax.Array): + return pl.pallas_call( + matmul_kernel, + out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype), + interpret=mosaic_interpret.InterpretParams( + skip_floating_point_ops=True + ), + )(x, y) + + k1, k2 = jax.random.split(jax.random.key(0)) + x = jax.random.normal(k1, (1024, 1024)) + y = jax.random.normal(k2, (1024, 1024)) + z = jax.jit(matmul)(x, y) + np.testing.assert_array_equal(z, jnp.full_like(z, jnp.inf)) + + lowered = jax.jit(matmul).lower(x, y).as_text(dialect='stablehlo') + self.assertNotIn('dot_general', lowered) + + @jtu.parameterized.parameters( + (1, 1, 1), + (2, 1, 2), + (2, 2, 1), + (4, 1, 4), + (4, 2, 2), + (4, 4, 1), + (8, 1, 8), + (8, 2, 4), + (8, 4, 2), + (8, 8, 1), + (16, 1, 16), + (16, 2, 8), + (16, 4, 4), + (16, 8, 2), + (16, 16, 1), + ) + def test_matmul_example(self, num_threads, num_row_blocks, num_col_blocks): + assert num_threads == num_row_blocks * num_col_blocks + + @jax.jit + def matmul(x: jax.Array, y: jax.Array): + num_rows_per_block = x.shape[0] // num_row_blocks + num_cols_per_block = y.shape[1] // num_col_blocks + + @functools.partial( + plgpu.kernel, + out_shape=jax.ShapeDtypeStruct( + ( + x.shape[0], + y.shape[1], + ), + x.dtype, + ), + num_threads=num_threads, + thread_name='t', + interpret=mosaic_interpret.InterpretParams( + detect_races=True, num_cores_or_threads=num_threads + ), + ) + def _matmul_kernel(x_ref, y_ref, o_ref): + thread_idx = jax.lax.axis_index('t') + + row_block_idx = thread_idx // num_col_blocks + row_slice = pl.ds( + row_block_idx * num_rows_per_block, num_rows_per_block + ) + + col_block_idx = jax.lax.rem(thread_idx, jnp.int32(num_col_blocks)) + col_slice = pl.ds( + col_block_idx * num_cols_per_block, num_cols_per_block + ) + + o_ref[row_slice, col_slice] = x_ref[row_slice, :] @ y_ref[:, col_slice] + + return _matmul_kernel(x, y) + + k1, k2 = jax.random.split(jax.random.key(0)) + x = jax.random.normal(k1, (1024, 1024)) + y = jax.random.normal(k2, (1024, 1024)) + z = matmul(x, y) + np.testing.assert_allclose(z, x @ y, atol=1e-3) + self.assertFalse(mosaic_interpret.get_races().races_found) + + @jtu.parameterized.parameters(False, True) + def test_run_scoped(self, with_race): + mesh = plgpu.Mesh(num_threads=2, thread_name='n') + + @jax.jit + def f(x): + def inner(o_ref): + @pl.core_map( + mesh, + interpret=mosaic_interpret.InterpretParams( + detect_races=True, + ), + ) # type: ignore[wrong-arg-types] + def _(): + def body(ref): + @pl.when(jax.lax.axis_index('n') == 0) + def _(): + ref[...] = jnp.zeros_like(ref[...]) + o_ref[0, ...] = ref[...] + + @pl.when(jax.lax.axis_index('n') == 1) + def _(): + ref[...] = jnp.ones_like(ref[...]) + o_ref[1, ...] = ref[...] + + pl.run_scoped( + body, + plgpu.GMEM(o_ref.shape[1:], dtype=o_ref.dtype), + collective_axes=('n',) if with_race else (), + ) + + y = pl.run_state(inner)(x) + return y + + y = f(jnp.zeros((2, 16, 128))) + + if with_race: + # Due to the presence of a race, we cannot expect `y` to have a + # well-defined value. Hence, we do not assert anything about `y`. + self.assertTrue(mosaic_interpret.get_races().races_found) + else: + np.testing.assert_array_equal( + y, np.broadcast_to(np.arange(2).reshape(2, 1, 1), y.shape) + ) + self.assertFalse(mosaic_interpret.get_races().races_found) + + +if __name__ == '__main__': + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/indexing_test.py b/tests/pallas/indexing_test.py index c3f3fa6e80a8..d057387e0b9b 100644 --- a/tests/pallas/indexing_test.py +++ b/tests/pallas/indexing_test.py @@ -14,12 +14,12 @@ from __future__ import annotations import sys -import unittest from absl.testing import absltest from absl.testing import parameterized import jax from jax import random +from jax._src import core from jax._src import test_util as jtu from jax._src import util from jax._src.state import indexing @@ -32,11 +32,7 @@ else: pltpu = None -try: - import hypothesis as hp -except (ModuleNotFoundError, ImportError): - raise unittest.SkipTest("tests depend on hypothesis library") - +import hypothesis as hp import hypothesis.extra.numpy as hnp import hypothesis.strategies as hps @@ -95,7 +91,7 @@ def array_indexer_strategy(draw, shape) -> jax.Array: @hps.composite def indexer_strategy(draw, dim, int_indexer_shape - ) -> int | Slice | jax.Array: + ) -> int | Slice | jax.Array: return draw(hps.one_of( int_indexer_strategy(dim), slice_indexer_strategy(dim), @@ -104,12 +100,12 @@ def indexer_strategy(draw, dim, int_indexer_shape @hps.composite -def nd_indexer_strategy(draw, shape) -> NDIndexer: +def nd_indices_strategy(draw, shape) -> tuple[int | Slice | jax.Array, ...]: num_indices = draw(hps.integers(min_value=0, max_value=len(shape))) int_indexer_shape = draw(hnp.array_shapes()) indices = tuple(draw(indexer_strategy(dim, int_indexer_shape)) for dim in shape[:num_indices]) - return NDIndexer.from_indices_shape(indices, shape) + return indices class PallasBaseTest(jtu.JaxTestCase): @@ -127,6 +123,7 @@ def pallas_call(cls, *args, **kwargs): return pl.pallas_call(*args, interpret=cls.INTERPRET, **kwargs) +@jtu.thread_unsafe_test_class() # hypothesis is not thread safe class IndexerTest(jtu.JaxTestCase): """These are unit tests for the indexer logic, not using pallas_call.""" @@ -199,7 +196,43 @@ def test_ndindexer_with_arrays_and_invalid_broadcasting(self): with self.assertRaisesRegex( ValueError, "Cannot broadcast shapes for indexing" ): + NDIndexer.from_indices_shape(indices, shape) + + def test_ndindexer_with_ref(self): + indices = (core.new_ref(jnp.tile(jnp.arange(4), (2,))),) + shape = (4, 2) + indexer = NDIndexer.from_indices_shape(indices, shape) + self.assertTupleEqual(indexer.get_indexer_shape(), (8, 2)) + + def test_ndindexer_with_transformed_ref(self): + @jax.jit + def f(ref, size): + # We need the jit to make sure the ref has a dynamic size. + indices = (ref.at[pl.ds(0, size)],) + shape = (4, 2) indexer = NDIndexer.from_indices_shape(indices, shape) + return indexer.get_indexer_shape() + + size = 4 + np.testing.assert_array_equal( + f(core.new_ref(jnp.tile(jnp.arange(4), (size,))), 4), (size, 2) + ) + + def test_ndindexer_with_multiple_refs(self): + indices = (core.new_ref(jnp.tile(jnp.arange(4), (2,))),) * 2 + shape = (4, 2) + with self.assertRaisesRegex( + NotImplementedError, "Multiple Ref indexers are not supported" + ): + NDIndexer.from_indices_shape(indices, shape) + + def test_ndindexer_with_ref_and_int(self): + indices = (core.new_ref(jnp.tile(jnp.arange(4), (2,))), 0) + shape = (4, 2) + with self.assertRaisesRegex( + NotImplementedError, "Ref cannot be mixed with other non-slice indexers" + ): + NDIndexer.from_indices_shape(indices, shape) def test_indexer_with_all_types(self): indices = (0, slice(10), np.arange(5)) @@ -217,12 +250,15 @@ def test_indexer_with_all_types(self): indices = (ds(0, 2), np.arange(5)[:, None], np.arange(4)[None]) indexer = NDIndexer.from_indices_shape(indices, shape) - self.assertTupleEqual(indexer.get_indexer_shape(), (5, 4, 2)) + self.assertTupleEqual(indexer.get_indexer_shape(), (2, 5, 4)) @hp.given(hps.data()) + @hp.settings(suppress_health_check=[hp.HealthCheck.too_slow]) # ASAN is slow def test_ndindexer(self, data): shape = data.draw(hnp.array_shapes()) - indexer = data.draw(nd_indexer_strategy(shape)) + indices = data.draw(nd_indices_strategy(shape)) + indexer = NDIndexer.from_indices_shape(indices, shape) + is_int_indexer = [not isinstance(idx, Slice) for idx in indexer.indices] rest_indexers, int_indexers = util.partition_list( is_int_indexer, indexer.indices @@ -234,18 +270,15 @@ def test_ndindexer(self, data): self.assertTupleEqual( indexer.int_indexer_shape, expected_int_indexer_shape ) + for idx in rest_indexers: self.assertIsInstance(idx, (np.ndarray, Slice)) if isinstance(idx, np.ndarray): self.assertTupleEqual(idx.shape, ()) self.assertEqual(idx.dtype, np.dtype("int32")) - rest_shape = tuple( - r.size for r in rest_indexers if not isinstance(r, np.ndarray) - ) - self.assertTupleEqual((*indexer.int_indexer_shape, *rest_shape), - indexer.get_indexer_shape()) +@jtu.thread_unsafe_test_class() # hypothesis is not thread safe class IndexerOpsTest(PallasBaseTest): def test_multi_indexing_interpreter_only(self): @@ -373,24 +406,25 @@ def permute_columns_in_row_kernel(left, right, new_left, new_right): def test_vmap_nd_indexing(self, data): self.skipTest("TODO(necula): enable this test; was in jax_triton.") vmap_shape = data.draw(hnp.array_shapes(min_dims=1, max_dims=3, min_side=2), - label="vmap_shape") + label="vmap_shape") el_shape = data.draw(hnp.array_shapes(min_dims=2), label="el_shape") # TODO(sharadmv,apaszke): enable rank 0 and rank 1 Refs # hp.assume(len(el_shape) >= 2) - nd_indexer = data.draw(nd_indexer_strategy(el_shape), label="nd_indexer") + nd_indexer = NDIndexer.from_indices_shape( + data.draw(nd_indices_strategy(el_shape), label="nd_indexer"), + el_shape) expected_shape = jax.eval_shape(lambda x: x[nd_indexer], jax.ShapeDtypeStruct(el_shape, jnp.float32)) ref = lambda x: x[nd_indexer] def kernel(x_ref, y_ref): - x = pl.load(x_ref, nd_indexer) - pl.store(y_ref, (slice(None),) * len(y_ref.shape), x) + y_ref[...] = x_ref[nd_indexer] func = pl.pallas_call(kernel, out_shape=expected_shape) shape = el_shape for vmap_dim in vmap_shape[::-1]: index = data.draw(hps.integers(min_value=0, - max_value=max(0, len(shape) - 2)), + max_value=max(0, len(shape) - 2)), label="index") # hp.assume(index <= max(0, len(shape) - 2)) # TODO(sharadmv,apaszke): enable vmapping over batch axes in 2 minormost @@ -404,14 +438,8 @@ def kernel(x_ref, y_ref): y = func(x) np.testing.assert_array_equal(y, expected) - @parameterized.product( - indexer_type=["state", "pallas"], - case=_INDEXING_TEST_CASES, - ) - def test_can_load_with_ref_at(self, indexer_type, case): - # TODO(apaszke): Remove after 12 weeks have passed. - if not jtu.if_cloud_tpu_at_least(2024, 12, 19): - self.skipTest("Requires libtpu built after 2024-12-19") + @parameterized.product(case=_INDEXING_TEST_CASES) + def test_can_load_with_ref_at(self, case): if self.INTERPRET: self.skipTest("TODO: fails in interpret mode.") in_shape, indexers, out_shape = case @@ -419,12 +447,8 @@ def test_can_load_with_ref_at(self, indexer_type, case): def body(x_ref, y_ref): for indexer in indexers[:-1]: x_ref = x_ref.at[indexer] - if indexer_type == "state": - x = x_ref[indexers[-1]] - y_ref[...] = x - elif indexer_type == "pallas": - x = pl.load(x_ref, indexers[-1]) - pl.store(y_ref, ..., x) + x = x_ref[indexers[-1]] + y_ref[...] = x x = random.normal(random.key(0), in_shape, dtype=dtype) y = x @@ -437,11 +461,8 @@ def body(x_ref, y_ref): out = self.pallas_call(body, out_shape=y)(x) self.assertAllClose(out, y) - @parameterized.product( - indexer_type=["state", "pallas"], - case=_INDEXING_TEST_CASES, - ) - def test_can_store_with_ref_at(self, indexer_type, case): + @parameterized.product(case=_INDEXING_TEST_CASES) + def test_can_store_with_ref_at(self, case): if self.INTERPRET: self.skipTest("TODO: fails in interpret mode.") in_shape, indexers, val_shape = case @@ -450,12 +471,8 @@ def body(x_ref, y_ref): y_ref[...] = jnp.zeros_like(y_ref) for indexer in indexers[:-1]: y_ref = y_ref.at[indexer] - if indexer_type == "state": - x = x_ref[...] - y_ref[indexers[-1]] = x - elif indexer_type == "pallas": - x = pl.load(x_ref, ...) - pl.store(y_ref, indexers[-1], x) + x = x_ref[...] + y_ref[indexers[-1]] = x val = random.normal(random.key(0), val_shape, dtype=dtype) # Use NumPy arrays to do nested indexing and mutation. This is really @@ -472,10 +489,7 @@ def body(x_ref, y_ref): out = self.pallas_call(body, out_shape=x)(val) self.assertAllClose(out, x) - @parameterized.product( - indexer_type=["state", "pallas"], - slice_type=["slice", "ds"], - ) + @parameterized.product(slice_type=["slice", "ds"]) @hp.given( ref_shape=hps.sampled_from(((8, 8, 32), (7, 7, 33))), indices=hps.tuples( @@ -486,7 +500,7 @@ def body(x_ref, y_ref): ), ) def test_strided_load_and_store( - self, indexer_type, slice_type, ref_shape, indices, strides + self, slice_type, ref_shape, indices, strides ): if self.INTERPRET: self.skipTest("TODO: fails in interpret mode.") @@ -507,12 +521,8 @@ def body(x_ref, y_ref1, y_ref2): slices = tuple( pl.ds(i, vs, s) for i, vs, s in zip(indices, vec_shape, strides) ) - if indexer_type == "state": - y_ref1[...] = x_ref[slices] - y_ref2[slices] = y_ref1[...] - elif indexer_type == "pallas": - pl.store(y_ref1, ..., pl.load(x_ref, slices)) - pl.store(y_ref2, slices, pl.load(y_ref1, ...)) + y_ref1[...] = x_ref[slices] + y_ref2[slices] = y_ref1[...] x = random.normal(random.key(0), ref_shape, dtype=dtype) y1, y2 = self.pallas_call( @@ -531,6 +541,43 @@ def body(x_ref, y_ref1, y_ref2): y2[slices], expected, err_msg="Strided Store Error" ) + @hp.given(hps.data()) + def test_load_and_broadcast_with_stride_0(self, data): + if self.INTERPRET: + self.skipTest("TODO: fails in interpret mode.") + dtype = jnp.float32 + rank = data.draw(hps.integers(min_value=2, max_value=4)) + shape = data.draw(hps.tuples( + *(hps.integers(min_value=1, max_value=10) for _ in range(rank - 1)))) + shape = (*shape, 128) + + strides = data.draw(hps.tuples( + *(hps.sampled_from([0, 1]) for _ in range(rank - 1)))) + strides = (*strides, 1) + + indices = [] + for i in range(rank): + index = (data.draw(hps.integers(min_value=0, max_value=shape[i] - 1)) + if strides[i] == 0 else 0) + indices.append(index) + + def body(x_ref, y_ref): + slices = tuple( + pl.ds(i, l, s) for i, l, s in zip(indices, shape, strides) + ) + y_ref[...] = x_ref[slices] + + x = random.normal(random.key(33), shape, dtype=dtype) + y = self.pallas_call( + body, + out_shape=jax.ShapeDtypeStruct(shape, dtype), + )(x) + slices = tuple(slice(i, l, 1) if s != 0 else slice(i, i + 1, 1) + for i, l, s in zip(indices, shape, strides)) + + expected = jnp.broadcast_to(x[slices], shape) + self.assertAllClose(y, expected) + def test_load_with_dynamic_2nd_minor_index(self): if pltpu is None: self.skipTest("No TPU module available.") @@ -541,7 +588,7 @@ def test_load_with_dynamic_2nd_minor_index(self): start = 2 def kernel(x_ref, indices, y_ref): - y_ref[...] = pl.load(x_ref, pl.ds(indices[0], k)) + y_ref[...] = x_ref[pl.ds(indices[0], k)] x = jnp.arange(m * n, dtype=jnp.int32).reshape((m, n)) indices = jnp.array([start]) @@ -569,7 +616,7 @@ def test_store_with_dynamic_2nd_minor_index(self): start = 2 def kernel(x_ref, indices, y_ref): - pl.store(y_ref, pl.ds(indices[0], m), x_ref[...]) + y_ref[pl.ds(indices[0], m)] = x_ref[...] x = jnp.arange(m * n, dtype=jnp.int32).reshape((m, n)) indices = jnp.array([start]) @@ -641,6 +688,34 @@ def kernel(x_ref, indices, y_ref): )(x, indices) self.assertAllClose(res[:, start : start + 1, :], x, atol=0., rtol=0.) + def test_scalar_load_from_vmem(self): + if not jtu.is_device_tpu_at_least(4): + self.skipTest("Requires TPU v4 or later") + def kernel(x_ref, o_ref, sem_ref): + o_ref[...] = jnp.zeros_like(o_ref) + scalar_val = x_ref[1, 2] + # Use scalar_val in both async_copy and store. + o_ref[scalar_val] = jnp.ones_like(o_ref[0]) * scalar_val + desc = pltpu.make_async_copy( + o_ref.at[scalar_val], + o_ref.at[scalar_val + 1], + sem_ref, + ) + desc.start() + desc.wait() + + x = jnp.array([[1, 2, 3], [4, 5, 6]], dtype=jnp.int32) + res = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((8, 8, 128), jnp.int32), + grid=(1,), + scratch_shapes=[pltpu.SemaphoreType.DMA] + )(x) + expected = jnp.zeros_like(res) + expected = expected.at[6].set(jnp.ones((8, 128), jnp.int32) * 6) + expected = expected.at[7].set(jnp.ones((8, 128), jnp.int32) * 6) + self.assertArraysEqual(res, expected) + class IndexerOpsInterpretTest(IndexerOpsTest): INTERPRET = True @@ -662,18 +737,18 @@ class IndexerOpsInterpretTest(IndexerOpsTest): ((4, 3), lambda arr, a, b, c, d: arr[a, 2]), # slice + 1-D array ((4, 3), lambda arr, a, b, c, d: arr[a, :]), - # ((4, 3), lambda arr, a, b, c, d: arr[:, a]), + ((4, 3), lambda arr, a, b, c, d: arr[:, a]), ((6, 8, 3), lambda arr, a, b, c, d: arr[c, ::3]), - # ((8, 6, 3), lambda arr, a, b, c, d: arr[::3, c]), - # ((8, 8, 3), lambda arr, a, b, c, d: arr[::4, ::2, a]), - # ((8, 8, 3), lambda arr, a, b, c, d: arr[::4, a, ::2]), + ((8, 6, 3), lambda arr, a, b, c, d: arr[::3, c]), + ((8, 8, 3), lambda arr, a, b, c, d: arr[::4, ::2, a]), + ((8, 8, 3), lambda arr, a, b, c, d: arr[::4, a, ::2]), ((8, 8, 3, 7), lambda arr, a, b, c, d: arr[b, ::4, a, ::2]), ((3, 8, 8, 7), lambda arr, a, b, c, d: arr[b, a, ::4, ::2]), # ((8, 8, 3, 7), lambda arr, a, b, c, d: arr[::4, b, a, ::2]), ((16, 3, 6, 2), lambda arr, a, b, c, d: arr[::4, a, 1::2, b]), ((8, 8, 3, 6), lambda arr, a, b, c, d: arr[b, ::4, a, a]), # slice + array w/ broadcasting - ((8, 8, 3, 6), lambda arr, a, b, c, d: \ + ((8, 8, 3, 6), lambda arr, a, b, c, d: arr[b[:, None], ::4, a[None], a[:, None]]), # integer + slice + 1-D array ((5, 8, 8, 3), lambda arr, a, b, c, d: arr[2, ::4, ::2, a]), diff --git a/tests/pallas/mgpu_attention_test.py b/tests/pallas/mgpu_attention_test.py index cf8ed30925bf..7e0934cd08c4 100644 --- a/tests/pallas/mgpu_attention_test.py +++ b/tests/pallas/mgpu_attention_test.py @@ -20,6 +20,8 @@ from absl.testing import absltest, parameterized from jax._src import config from jax._src import test_util as jtu +from jax._src.lib import cuda_versions +from jax._src.pallas import pallas_call import jax.numpy as jnp # pylint: disable=g-import-not-at-top @@ -47,6 +49,7 @@ def setUp(self): if (not jtu.test_device_matches(["cuda"]) or not jtu.is_cuda_compute_capability_equal("9.0")): self.skipTest("Only works on GPU with capability sm90a") + self.enter_context(pallas_call._PALLAS_USE_MOSAIC_GPU(True)) @parameterized.product( batch_size=(1, 4), @@ -58,10 +61,13 @@ def setUp(self): (4, 4), ), # MHA head_dim=(64, 128, 256), + blocks=((64, 64), (64, 128), (128, 64)), attention_impl=( attention_mgpu.attention, attention_mgpu.attention_with_pipeline_emitter, ), + save_residuals=(True,), + causal=(True, False,), ) def test_flash_attention( self, @@ -70,24 +76,106 @@ def test_flash_attention( kv_seq_len, num_q_and_kv_heads, head_dim, + blocks, attention_impl, + save_residuals, + causal, ): + assert cuda_versions is not None + cuda_runtime_version = cuda_versions.cuda_runtime_get_version() + # TODO(pobudzey): Undo when we upgrade to cuda 12.9.1. + if causal and (cuda_runtime_version >= 12080 and cuda_runtime_version < 12091): + self.skipTest("Skipping because of ptxas miscompilation.") + + if causal and attention_impl == attention_mgpu.attention_with_pipeline_emitter: + self.skipTest("Pipeline emitter does not support causal attention.") + + if head_dim >= 256 and max(blocks) >= 128: + self.skipTest("Head dim too large for block sizes.") + num_q_heads, num_kv_heads = num_q_and_kv_heads + block_q, block_kv = blocks k1, k2, k3 = jax.random.split(jax.random.key(42), 3) q = jax.random.normal(k1, (batch_size, q_seq_len, num_q_heads, head_dim), jnp.float16) k = jax.random.normal(k2, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16) v = jax.random.normal(k3, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16) - out = attention_impl( + out, *res = attention_impl( q, k, v, attention_mgpu.TuningConfig( - block_q=64, block_kv=64, max_concurrent_steps=2 + block_q=block_q, block_kv=block_kv, max_concurrent_steps=2, causal=causal ), + save_residuals=save_residuals, ) - out_ref = attention_mgpu.attention_reference(q, k, v) + out_ref, *res_ref = attention_mgpu.attention_reference( + q, k, v, causal=causal, save_residuals=save_residuals) np.testing.assert_allclose(out, out_ref, atol=2e-3, rtol=1e-3) + if save_residuals: + (lse,) = res[0] + (lse_ref,) = res_ref[0] + np.testing.assert_allclose(lse, lse_ref, atol=2e-3, rtol=1e-3) + + @parameterized.product( + batch_size=(3,), + seq_lens=((512, 512), (3584, 4096)), + num_q_and_kv_heads=( + (4, 4), # MHA + (4, 1), # MQA + (6, 3), # GQA + ), + bwd_blocks = ( + (64, 64, 64, 64), + (64, 128, 128, 64), + (128, 128, 128, 128), + ), + head_dim=(64, 128, 256), + ) + def test_bwd_flash_attention( + self, + batch_size, + seq_lens, + num_q_and_kv_heads, + bwd_blocks, + head_dim, + ): + num_q_heads, num_kv_heads = num_q_and_kv_heads + kv_seq_len, q_seq_len = seq_lens + block_q_dq, block_kv_dq, block_q_dkv, block_kv_dkv = bwd_blocks + compute_wgs = 2 if head_dim <= 128 else 1 + k1, k2, k3 = jax.random.split(jax.random.key(42), 3) + q = jax.random.normal(k1, (batch_size, q_seq_len, num_q_heads, head_dim), jnp.float16) + k = jax.random.normal(k2, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16) + v = jax.random.normal(k3, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16) + + def f(q, k, v): + return attention_mgpu.attention( + q, + k, + v, + attention_mgpu.TuningConfig( + block_q=block_q_dq, block_kv=block_kv_dq, + max_concurrent_steps=2, compute_wgs_bwd=compute_wgs, + block_q_dkv=block_q_dkv, block_kv_dkv=block_kv_dkv, + block_q_dq=block_q_dq, block_kv_dq=block_kv_dq, + ) + ).sum() + + def f_ref(q, k, v): + return attention_mgpu.attention_reference(q, k, v).sum() + + try: + # TODO(pobudzey): Replace with `jtu.check_grads` when it's fixed. + dq, dk, dv = jax.grad(f, argnums=(0, 1, 2))(q, k, v) + dq_ref, dk_ref, dv_ref = jax.grad(f_ref, argnums=(0, 1, 2))(q, k, v) + + self.assertAllClose(dq, dq_ref, atol=7e-2) + self.assertAllClose(dk, dk_ref, atol=7e-2) + self.assertAllClose(dv, dv_ref, atol=5e-2) + except ValueError as e: + if "exceeds available shared memory" in e.args[0]: + self.skipTest("Not enough SMEM for this configuration.") if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/mgpu_collective_matmul_test.py b/tests/pallas/mgpu_collective_matmul_test.py new file mode 100644 index 000000000000..9f9995be0f5d --- /dev/null +++ b/tests/pallas/mgpu_collective_matmul_test.py @@ -0,0 +1,151 @@ +# Copyright 2025 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test different parameterizations of our Mosaic GPU collective matmul.""" + +import functools +import os + +from absl.testing import parameterized # pylint: disable=g-multiple-import +import jax +from jax import lax +from jax import random +from jax._src import test_multiprocess as jt_multiprocess +from jax._src import test_util as jtu +from jax._src.pallas import pallas_call +from jax.experimental.mosaic import gpu as mgpu +from jax.experimental.pallas.ops.gpu import collective_matmul_mgpu +import jax.numpy as jnp +import numpy as np + + +P = jax.sharding.PartitionSpec + + +@jtu.with_config(jax_traceback_filtering="off") +class CollectiveMatmulTestCase(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + if collective_matmul_mgpu is None: + self.skipTest("Mosaic GPU not available.") + if (not jtu.test_device_matches(["cuda"]) or + not jtu.is_cuda_compute_capability_equal("9.0")): + self.skipTest("Only works on GPU with capability sm90a") + if not mgpu.supports_cross_device_collectives(): + if "FAIL_ON_NVSHMEM_UNAVAILABLE" in os.environ: + raise ValueError("NVSHMEM library unavailable.") + else: + self.skipTest("NVSHMEM library unavailable.") + if jax.process_count() == 1: + self.skipTest("Test requires multiple processes.") + if os.environ.get("XLA_PYTHON_CLIENT_ALLOCATOR", "") == "platform": + self.skipTest("NVSHMEM doesn't work with the platform allocator.") + self.enter_context(pallas_call._PALLAS_USE_MOSAIC_GPU(True)) + num_devices = jax.device_count() + mesh = jax.make_mesh( + (num_devices,), ("x",), axis_types=(jax.sharding.AxisType.Explicit,) + ) + self.enter_context(jax.set_mesh(mesh)) + + @parameterized.product( + m_shard=(3072,), + n_shard=(256, 576), + k=(4096,), + tile_m=(64, 128, 192), + tile_n=(64, 128, 192), + tile_k=(64, 128), + grid_minor_dim=(collective_matmul_mgpu.MatmulDimension.N,), + grid_tile_width=(1,), + wg_dimension=(collective_matmul_mgpu.MatmulDimension.N,), + max_concurrent_steps=(2, 4), + dtype=(jnp.bfloat16,), + ) + def test_all_gather_lhs_matmul( + self, + m_shard, + n_shard, + k, + tile_m, + tile_n, + tile_k, + max_concurrent_steps, + grid_minor_dim, + grid_tile_width, + wg_dimension, + dtype, + ): + num_devices = jax.device_count() + epi_tile_size = 64 * 64 + num_epi_tiles = tile_m * tile_n // epi_tile_size + cta_tile_m = tile_m * (1 + (wg_dimension == collective_matmul_mgpu.MatmulDimension.M)) + cta_tile_n = tile_n * (1 + (wg_dimension == collective_matmul_mgpu.MatmulDimension.N)) + if ( + (cta_tile_m + cta_tile_n) * tile_k * max_concurrent_steps + + 2 * min(2, num_epi_tiles) * epi_tile_size + ) * 2 > 228000: + self.skipTest("Tile too big to fit into SMEM") + if n_shard % cta_tile_n: + self.skipTest("n_shard must be divisible by block_n for now.") + if m_shard % cta_tile_m: + self.skipTest("m_shard must be divisible by block_m for now.") + + k1, k2 = random.split(random.key(1234), num=2) + lhs = random.normal(k1, (num_devices * m_shard, k), dtype) + rhs = random.normal(k2, (k, num_devices * n_shard), dtype) + lhs = jax.sharding.reshard(lhs, P("x", None)) + rhs = jax.sharding.reshard(rhs, P(None, "x")) + + def run(body): + out = jax.jit( + jax.shard_map(body, out_specs=P(None, "x"), check_vma=False) + )(lhs, rhs) + # Gather output, for NumPy comparison on the host. + out = jax.shard_map( + lambda x: lax.all_gather(x, "x", axis=1, tiled=True), + out_specs=P(None), check_vma=False, + )(out) + return out + + ref_out = run(lambda x, y: lax.all_gather(x, "x", axis=0, tiled=True) @ y) + config = collective_matmul_mgpu.TuningConfig( + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + max_concurrent_steps=max_concurrent_steps, + grid_minor_dim=grid_minor_dim, + grid_tile_width=grid_tile_width, + wg_dimension=wg_dimension, + ) + out = run( + functools.partial( + collective_matmul_mgpu.all_gather_lhs_matmul, + axis_name="x", + config=config, + dtype=dtype, + ) + ) + np.testing.assert_allclose(out, ref_out) + + +if __name__ == "__main__": + # This test doesn't work with the platform allocator, so we override it + # if it's ran alone. If it's part of a larger test suite and the platform + # allocator is used, setUp will skip the test. + os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.01" + os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "default" + os.environ["XLA_FLAGS"] = ( + os.environ.get("XLA_FLAGS", "") + " --xla_gpu_autotune_level=0" + ) + jt_multiprocess.main() diff --git a/tests/pallas/mgpu_examples_test.py b/tests/pallas/mgpu_examples_test.py new file mode 100644 index 000000000000..560c8c37e534 --- /dev/null +++ b/tests/pallas/mgpu_examples_test.py @@ -0,0 +1,988 @@ +# Copyright 2025 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for examples from Pallas:MGPU documentation.""" + +import dataclasses +import functools +import itertools +import statistics + +from absl.testing import absltest +from absl.testing import parameterized +from jax import lax +from jax.extend import backend +from jax._src import config +from jax._src import test_util as jtu +from jax._src.pallas import pallas_call +import jax.experimental.mosaic.gpu # noqa: F401 +import jax.experimental.pallas as pl +import jax.experimental.pallas.mosaic_gpu as plgpu +from jax.experimental.mosaic.gpu import profiler +import jax.numpy as jnp +import numpy as np + + +config.parse_flags_with_absl() + + +@dataclasses.dataclass(frozen=True) +class TuningConfig: + tile_m: int + tile_n: int + tile_k: int + max_concurrent_steps: int + epilogue_tile_n: int = 64 + grid_minor_dim: int = 0 + grid_tile_width: int = 1 + + +def matmul0(a, b, config: TuningConfig): + dtype = a.dtype + m, k = a.shape + _, n = b.shape + tile_m, tile_n, tile_k = config.tile_m, config.tile_n, config.tile_k + swizzle = plgpu.find_swizzle(tile_k * jnp.dtype(dtype).itemsize * 8) + swizzle_elems = swizzle // jnp.dtype(dtype).itemsize + transforms = ( + plgpu.TilingTransform((8, swizzle_elems)), plgpu.SwizzleTransform(swizzle) + ) + if m % tile_m != 0: + raise ValueError(f"{m=} must be divisible by {tile_m=}") + if n % tile_n != 0: + raise ValueError(f"{n=} must be divisible by {tile_n=}") + if k % tile_k != 0: + raise ValueError(f"{k=} must be divisible by {tile_k=}") + m_iters = m // tile_m + n_iters = n // tile_n + k_iters = k // tile_k + max_concurrent_steps = config.max_concurrent_steps + + def kernel(a_gmem, b_gmem, out_gmem, acc_tmem, acc_smem, consumed_barriers): + mi = lax.axis_index("m") + ni = lax.axis_index("n") + m_slice = pl.ds(mi * tile_m, tile_m) + n_slice = pl.ds(ni * tile_n, tile_n) + + def do_mma(idxs, a_smem, b_smem): + (ki,) = idxs + arrive_barrier_slot = ki % 2 + wait_barrier_slot = 1 - arrive_barrier_slot + plgpu.tcgen05_mma( + acc_tmem, + a_smem, + b_smem, + barrier=consumed_barriers.at[arrive_barrier_slot], + accumulate=(ki > 0), + ) + plgpu.barrier_wait(consumed_barriers.at[wait_barrier_slot]) + + # Make sure the wait succeeds in the first iteration. + plgpu.barrier_arrive(consumed_barriers.at[1]) + block_kwargs = dict(transforms=transforms, delay_release=1) + plgpu.emit_pipeline( + do_mma, + in_specs=[ + plgpu.BlockSpec((tile_m, tile_k), lambda ki: (mi, ki), **block_kwargs), + plgpu.BlockSpec((tile_k, tile_n), lambda ki: (ki, ni), **block_kwargs), + ], + grid=(k_iters,), + max_concurrent_steps=max_concurrent_steps, + )(a_gmem, b_gmem) + + final_barrier = 1 - (k_iters % 2) + plgpu.barrier_wait(consumed_barriers.at[final_barrier]) + acc_smem[...] = plgpu.async_load_tmem(acc_tmem).astype(dtype) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(acc_smem, out_gmem.at[m_slice, n_slice]) + plgpu.wait_smem_to_gmem(0, wait_read_only=True) + + f = plgpu.kernel( + kernel, + out_shape=jax.ShapeDtypeStruct((m, n), dtype), + grid=(m_iters, n_iters), + grid_names=("m", "n"), + scratch_shapes=dict( + acc_tmem=plgpu.TMEM((tile_m, tile_n), jnp.float32), + acc_smem=plgpu.SMEM((tile_m, tile_n), dtype, transforms=transforms), + consumed_barriers=plgpu.Barrier( + num_arrivals=1, num_barriers=2, orders_tensor_core=True + ), + ), + ) + return f(a, b) + + +def matmul1(a, b, config: TuningConfig): + dtype = a.dtype + m, k = a.shape + _, n = b.shape + tile_m, tile_n, tile_k = config.tile_m, config.tile_n, config.tile_k + swizzle = plgpu.find_swizzle(tile_k * jnp.dtype(dtype).itemsize * 8) + swizzle_elems = swizzle // jnp.dtype(dtype).itemsize + transforms = ( + plgpu.TilingTransform((8, swizzle_elems)), plgpu.SwizzleTransform(swizzle) + ) + if m % tile_m != 0: + raise ValueError(f"{m=} must be divisible by {tile_m=}") + if n % tile_n != 0: + raise ValueError(f"{n=} must be divisible by {tile_n=}") + if k % tile_k != 0: + raise ValueError(f"{k=} must be divisible by {tile_k=}") + m_iters = m // tile_m + n_iters = n // tile_n + k_iters = k // tile_k + max_concurrent_steps = config.max_concurrent_steps + + def kernel(a_gmem, b_gmem, out_gmem, + a_smem, b_smem, acc_tmem, acc_smem, + load_barriers, consumed_barriers, mma_done_barrier): + m_index = lax.axis_index("m") + n_index = lax.axis_index("n") + m_slice = pl.ds(m_index * tile_m, tile_m) + n_slice = pl.ds(n_index * tile_n, tile_n) + + @pl.core_map(plgpu.WarpMesh(axis_name="warp")) + def _per_warp(): + warp_id = lax.axis_index("warp") + + @pl.when(warp_id == 0) + def _memory(): + def _loop_body(ki, _): + slot = lax.rem(ki, max_concurrent_steps) + @pl.when(ki >= max_concurrent_steps) + def _(): # Make sure the data has been consumed before overwriting. + plgpu.barrier_wait(consumed_barriers.at[slot]) + k_slice = pl.ds(ki * tile_k, tile_k) + plgpu.copy_gmem_to_smem( + a_gmem.at[m_slice, k_slice], a_smem.at[slot], load_barriers.at[slot] + ) + plgpu.copy_gmem_to_smem( + b_gmem.at[k_slice, n_slice], b_smem.at[slot], load_barriers.at[slot] + ) + lax.fori_loop(0, k_iters, _loop_body, None) + + @pl.when(warp_id == 1) + def _compute(): + def _loop_body(ki, _): + slot = lax.rem(ki, max_concurrent_steps) + plgpu.barrier_wait(load_barriers.at[slot]) # Wait for data to arrive. + plgpu.tcgen05_mma( + acc_tmem, + a_smem.at[slot], + b_smem.at[slot], + consumed_barriers.at[slot], + accumulate=(ki > 0), + ) + lax.fori_loop(0, k_iters, _loop_body, None) + plgpu.tcgen05_commit_arrive(mma_done_barrier) + + plgpu.barrier_wait(mma_done_barrier) + acc_smem[...] = plgpu.async_load_tmem(acc_tmem).astype(dtype) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(acc_smem, out_gmem.at[m_slice, n_slice]) + plgpu.wait_smem_to_gmem(0, wait_read_only=True) + + f = plgpu.kernel( + kernel, + out_shape=jax.ShapeDtypeStruct((m, n), dtype), + grid=(m_iters, n_iters), + grid_names=("m", "n"), + scratch_shapes=dict( + a_smem=plgpu.SMEM( + (max_concurrent_steps, tile_m, tile_k), + dtype, + transforms=transforms, + ), + b_smem=plgpu.SMEM( + (max_concurrent_steps, tile_k, tile_n), + dtype, + transforms=transforms, + ), + acc_tmem=plgpu.TMEM((tile_m, tile_n), jnp.float32), + acc_smem=plgpu.SMEM((tile_m, tile_n), dtype, transforms=transforms), + load_barriers=plgpu.Barrier( + num_arrivals=2, num_barriers=max_concurrent_steps + ), + consumed_barriers=plgpu.Barrier( + num_arrivals=1, + num_barriers=max_concurrent_steps, + orders_tensor_core=True, + ), + mma_done_barrier=plgpu.Barrier( + num_arrivals=1, num_barriers=1, orders_tensor_core=True + ), + ), + ) + return f(a, b) + + +def matmul2(a, b, config: TuningConfig): + dtype = a.dtype + m, k = a.shape + _, n = b.shape + tile_m, tile_n, tile_k = config.tile_m, config.tile_n, config.tile_k + swizzle = plgpu.find_swizzle(tile_k * jnp.dtype(dtype).itemsize * 8) + swizzle_elems = swizzle // jnp.dtype(dtype).itemsize + transforms = ( + plgpu.TilingTransform((8, swizzle_elems)), plgpu.SwizzleTransform(swizzle) + ) + if m % tile_m != 0: + raise ValueError(f"{m=} must be divisible by {tile_m=}") + if n % tile_n != 0: + raise ValueError(f"{n=} must be divisible by {tile_n=}") + if k % tile_k != 0: + raise ValueError(f"{k=} must be divisible by {tile_k=}") + m_iters = m // tile_m + n_iters = n // tile_n + k_iters = k // tile_k + max_concurrent_steps = config.max_concurrent_steps + + def kernel(a_gmem, b_gmem, out_gmem, + a_smem, b_smem, acc_tmem, acc_smem, + load_barriers, consumed_barriers, mma_done_barrier): + m_index = lax.axis_index("m") + n_index = lax.axis_index("n") + m_slice = pl.ds(m_index * tile_m, tile_m) + n_slice = pl.ds(n_index * tile_n, tile_n) + + @pl.core_map(plgpu.WarpMesh(axis_name="warp")) + def _per_warp(): + warp_id = lax.axis_index("warp") + + @pl.when(warp_id == 0) + def _memory(): + def _loop_body(ki, _): + slot = lax.rem(ki, max_concurrent_steps) + @pl.when(ki >= max_concurrent_steps) + def _(): # Make sure the data has been consumed before overwriting. + plgpu.barrier_wait(consumed_barriers.at[slot]) + k_slice = pl.ds(ki * tile_k, tile_k) + plgpu.copy_gmem_to_smem( + a_gmem.at[m_slice, k_slice], a_smem.at[slot], load_barriers.at[slot] + ) + plgpu.copy_gmem_to_smem( + b_gmem.at[k_slice, n_slice], b_smem.at[slot], load_barriers.at[slot] + ) + lax.fori_loop(0, k_iters, _loop_body, None) + + @pl.when(warp_id == 1) + def _compute(): + def _loop_body(ki, _): + slot = lax.rem(ki, max_concurrent_steps) + plgpu.barrier_wait(load_barriers.at[slot]) # Wait for data to arrive. + plgpu.tcgen05_mma( + acc_tmem, + a_smem.at[slot], + b_smem.at[slot], + consumed_barriers.at[slot], + accumulate=(ki > 0), + ) + lax.fori_loop(0, k_iters, _loop_body, None) + plgpu.tcgen05_commit_arrive(mma_done_barrier) + + plgpu.barrier_wait(mma_done_barrier) + out_gmem_window = out_gmem.at[m_slice, n_slice] + for ni in range(tile_n // config.epilogue_tile_n): + acc_smem_ni = acc_smem.at[ni % 2] + ni_slice = pl.ds(ni * config.epilogue_tile_n, config.epilogue_tile_n) + # Make sure that previous copy is done before we overwrite. + plgpu.wait_smem_to_gmem(1, wait_read_only=True) + acc_smem_ni[...] = plgpu.async_load_tmem(acc_tmem.at[:, ni_slice]).astype(dtype) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(acc_smem_ni, out_gmem_window.at[:, ni_slice]) + plgpu.wait_smem_to_gmem(0, wait_read_only=True) + + f = plgpu.kernel( + kernel, + out_shape=jax.ShapeDtypeStruct((m, n), dtype), + grid=(m_iters, n_iters), + grid_names=("m", "n"), + scratch_shapes=dict( + a_smem=plgpu.SMEM( + (max_concurrent_steps, tile_m, tile_k), dtype, transforms=transforms + ), + b_smem=plgpu.SMEM( + (max_concurrent_steps, tile_k, tile_n), dtype, transforms=transforms + ), + acc_tmem=plgpu.TMEM((tile_m, tile_n), jnp.float32), + acc_smem=plgpu.SMEM((2, tile_m, config.epilogue_tile_n), dtype, transforms=transforms), + load_barriers=plgpu.Barrier(num_arrivals=2, num_barriers=max_concurrent_steps), + consumed_barriers=plgpu.Barrier( + num_arrivals=1, + num_barriers=max_concurrent_steps, + orders_tensor_core=True, + ), + mma_done_barrier=plgpu.Barrier(num_arrivals=1, num_barriers=1, orders_tensor_core=True), + ) + ) + return f(a, b) + + +def matmul3(a, b, config: TuningConfig): + dtype = a.dtype + m, k = a.shape + _, n = b.shape + tile_m, tile_n, tile_k = config.tile_m, config.tile_n, config.tile_k + swizzle = plgpu.find_swizzle(tile_k * jnp.dtype(dtype).itemsize * 8) + swizzle_elems = swizzle // jnp.dtype(dtype).itemsize + transforms = ( + plgpu.TilingTransform((8, swizzle_elems)), plgpu.SwizzleTransform(swizzle) + ) + if m % tile_m != 0: + raise ValueError(f"{m=} must be divisible by {tile_m=}") + if n % tile_n != 0: + raise ValueError(f"{n=} must be divisible by {tile_n=}") + if k % tile_k != 0: + raise ValueError(f"{k=} must be divisible by {tile_k=}") + cluster_tile_m = 2 * tile_m + cluster_tile_n = 2 * tile_n + m_iters = m // cluster_tile_m + n_iters = n // cluster_tile_n + k_iters = k // tile_k + max_concurrent_steps = config.max_concurrent_steps + + def kernel(a_gmem, b_gmem, out_gmem, + a_smem, b_smem, acc_tmem, acc_smem, + load_barriers, consumed_barriers, mma_done_barrier): + is_lead_block = lax.axis_index("cluster") == 0 + m_index = lax.axis_index("m") + n_index = lax.axis_index("n") + m_slice = pl.ds(m_index * cluster_tile_m, cluster_tile_m) + n_slice = pl.ds(n_index * cluster_tile_n, cluster_tile_n) + + @pl.core_map(plgpu.WarpMesh(axis_name="warp")) + def _per_warp(): + warp_id = lax.axis_index("warp") + + @pl.when(warp_id == 0) + def _memory(): + def _loop_body(ki, _): + slot = lax.rem(ki, max_concurrent_steps) + @pl.when(ki >= max_concurrent_steps) + def _(): # Make sure the data has been consumed before overwriting. + plgpu.barrier_wait(consumed_barriers.at[slot]) + k_slice = pl.ds(ki * tile_k, tile_k) + plgpu.copy_gmem_to_smem( + a_gmem.at[m_slice, k_slice], a_smem.at[slot], load_barriers.at[slot], + collective_axes="cluster", partitioned_axis=0 + ) + plgpu.copy_gmem_to_smem( + b_gmem.at[k_slice, n_slice], b_smem.at[slot], load_barriers.at[slot], + collective_axes="cluster", partitioned_axis=1 + ) + lax.fori_loop(0, k_iters, _loop_body, None) + + @pl.when(jnp.logical_and(warp_id == 1, is_lead_block)) + def _compute(): + def _loop_body(ki, _): + slot = lax.rem(ki, max_concurrent_steps) + plgpu.barrier_wait(load_barriers.at[slot]) # Wait for data to arrive. + plgpu.tcgen05_mma( + acc_tmem, + a_smem.at[slot], + b_smem.at[slot], + consumed_barriers.at[slot], + accumulate=(ki > 0), + collective_axis="cluster", + ) + lax.fori_loop(0, k_iters, _loop_body, None) + plgpu.tcgen05_commit_arrive(mma_done_barrier, collective_axis="cluster") + + plgpu.barrier_wait(mma_done_barrier) + out_m_index = m_index * 2 + lax.axis_index("cluster") + out_m_slice = pl.ds(out_m_index * tile_m, tile_m) + out_gmem_window = out_gmem.at[out_m_slice, n_slice] + for ni in range(cluster_tile_n // config.epilogue_tile_n): + acc_smem_ni = acc_smem.at[ni % 2] + ni_slice = pl.ds(ni * config.epilogue_tile_n, config.epilogue_tile_n) + # Make sure that previous copy is done before we overwrite. + plgpu.wait_smem_to_gmem(1, wait_read_only=True) + acc_smem_ni[...] = plgpu.async_load_tmem(acc_tmem.at[:, ni_slice]).astype(dtype) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(acc_smem_ni, out_gmem_window.at[:, ni_slice]) + plgpu.wait_smem_to_gmem(0, wait_read_only=True) + + f = plgpu.kernel( + kernel, + out_shape=jax.ShapeDtypeStruct((m, n), dtype), + grid=(m_iters, n_iters), + grid_names=("m", "n"), + cluster=(2,), + cluster_names=("cluster",), + scratch_shapes=dict( + a_smem=plgpu.SMEM( + (max_concurrent_steps, tile_m, tile_k), dtype, transforms=transforms, + ), + b_smem=plgpu.SMEM( + (max_concurrent_steps, tile_k, tile_n), dtype, transforms=transforms, + ), + acc_tmem=plgpu.TMEM((tile_m, cluster_tile_n), jnp.float32, collective=True), + acc_smem=plgpu.SMEM((2, tile_m, config.epilogue_tile_n), dtype, transforms=transforms), + load_barriers=plgpu.Barrier(num_arrivals=2, num_barriers=max_concurrent_steps), + consumed_barriers=plgpu.Barrier( + num_arrivals=1, + num_barriers=max_concurrent_steps, + orders_tensor_core=True, + ), + mma_done_barrier=plgpu.Barrier(num_arrivals=1, num_barriers=1, orders_tensor_core=True), + ) + ) + return f(a, b) + + +def matmul4(a, b, config: TuningConfig): + dtype = a.dtype + m, k = a.shape + _, n = b.shape + tile_m, tile_n, tile_k = config.tile_m, config.tile_n, config.tile_k + swizzle = plgpu.find_swizzle(tile_k * jnp.dtype(dtype).itemsize * 8) + swizzle_elems = swizzle // jnp.dtype(dtype).itemsize + transforms = ( + plgpu.TilingTransform((8, swizzle_elems)), plgpu.SwizzleTransform(swizzle) + ) + if m % tile_m != 0: + raise ValueError(f"{m=} must be divisible by {tile_m=}") + if n % tile_n != 0: + raise ValueError(f"{n=} must be divisible by {tile_n=}") + if k % tile_k != 0: + raise ValueError(f"{k=} must be divisible by {tile_k=}") + cluster_tile_m = 2 * tile_m + cluster_tile_n = 2 * tile_n + m_iters = m // cluster_tile_m + n_iters = n // cluster_tile_n + k_iters = k // tile_k + max_concurrent_steps = config.max_concurrent_steps + + def kernel(a_gmem, b_gmem, out_gmem, + a_smem, b_smem, acc_tmem, acc_smem, + load_barriers, consumed_barriers, mma_done_barrier): + is_lead_block = lax.axis_index("cluster") == 0 + + @plgpu.nd_loop((m_iters, n_iters), collective_axes="cluster_grid") + def _mn_loop(loop_info: plgpu.NDLoopInfo): + m_index, n_index = loop_info.index + m_slice = pl.ds(m_index * cluster_tile_m, cluster_tile_m) + n_slice = pl.ds(n_index * cluster_tile_n, cluster_tile_n) + + @pl.core_map(plgpu.WarpMesh(axis_name="warp")) + def _per_warp(): + warp_id = lax.axis_index("warp") + + @pl.when(warp_id == 0) + def _memory(): + def _loop_body(ki, _): + slot = lax.rem(ki, max_concurrent_steps) + @pl.when(jnp.logical_or(ki >= max_concurrent_steps, loop_info.local_index > 0)) + def _(): # Make sure the data has been consumed before overwriting. + plgpu.barrier_wait(consumed_barriers.at[slot]) + k_slice = pl.ds(ki * tile_k, tile_k) + plgpu.copy_gmem_to_smem( + a_gmem.at[m_slice, k_slice], a_smem.at[slot], load_barriers.at[slot], + collective_axes="cluster", partitioned_axis=0 + ) + plgpu.copy_gmem_to_smem( + b_gmem.at[k_slice, n_slice], b_smem.at[slot], load_barriers.at[slot], + collective_axes="cluster", partitioned_axis=1 + ) + + lax.fori_loop(0, k_iters, _loop_body, None) + + @pl.when(jnp.logical_and(warp_id == 1, is_lead_block)) + def _compute(): + def _loop_body(ki, _): + slot = lax.rem(ki, max_concurrent_steps) + plgpu.barrier_wait(load_barriers.at[slot]) # Wait for data to arrive. + plgpu.tcgen05_mma( + acc_tmem, + a_smem.at[slot], + b_smem.at[slot], + consumed_barriers.at[slot], + accumulate=(ki > 0), + collective_axis="cluster", + ) + lax.fori_loop(0, k_iters, _loop_body, None) + plgpu.tcgen05_commit_arrive( + mma_done_barrier, + collective_axis="cluster", + ) + + plgpu.wait_smem_to_gmem(0, wait_read_only=True) # Make sure that previous store is done. + plgpu.barrier_wait(mma_done_barrier) + out_m_index = m_index * 2 + lax.axis_index("cluster") + out_m_slice = pl.ds(out_m_index * tile_m, tile_m) + out_gmem_window = out_gmem.at[out_m_slice, n_slice] + for ni in range(cluster_tile_n // config.epilogue_tile_n): + acc_smem_ni = acc_smem.at[ni % 2] + ni_slice = pl.ds(ni * config.epilogue_tile_n, config.epilogue_tile_n) + # Make sure that previous copy is done before we overwrite. + plgpu.wait_smem_to_gmem(1, wait_read_only=True) + acc_smem_ni[...] = plgpu.async_load_tmem(acc_tmem.at[:, ni_slice]).astype(dtype) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(acc_smem_ni, out_gmem_window.at[:, ni_slice]) + plgpu.wait_load_tmem() # Load must complete before MMA can overwrite TMEM. + plgpu.wait_smem_to_gmem(0, wait_read_only=True) + + num_sms = backend.get_default_device().core_count + f = plgpu.kernel( + kernel, + out_shape=jax.ShapeDtypeStruct((m, n), dtype), + grid=(num_sms // 2,), + grid_names=("cluster_grid",), + cluster=(2,), + cluster_names=("cluster",), + scratch_shapes=dict( + a_smem=plgpu.SMEM( + (max_concurrent_steps, tile_m, tile_k), dtype, transforms=transforms, + ), + b_smem=plgpu.SMEM( + (max_concurrent_steps, tile_k, tile_n), dtype, transforms=transforms, + ), + acc_tmem=plgpu.TMEM((tile_m, cluster_tile_n), jnp.float32, collective=True), + acc_smem=plgpu.SMEM((2, tile_m, config.epilogue_tile_n), dtype, transforms=transforms), + load_barriers=plgpu.Barrier(num_arrivals=2, num_barriers=max_concurrent_steps), + consumed_barriers=plgpu.Barrier( + num_arrivals=1, + num_barriers=max_concurrent_steps, + orders_tensor_core=True, + ), + mma_done_barrier=plgpu.Barrier(num_arrivals=1, num_barriers=1, orders_tensor_core=True), + ), + ) + return f(a, b) + + +def matmul5(a, b, config: TuningConfig): + dtype = a.dtype + m, k = a.shape + _, n = b.shape + tile_m, tile_n, tile_k = config.tile_m, config.tile_n, config.tile_k + swizzle = plgpu.find_swizzle(tile_k * jnp.dtype(dtype).itemsize * 8) + swizzle_elems = swizzle // jnp.dtype(dtype).itemsize + transforms = ( + plgpu.TilingTransform((8, swizzle_elems)), plgpu.SwizzleTransform(swizzle) + ) + if m % tile_m != 0: + raise ValueError(f"{m=} must be divisible by {tile_m=}") + if n % tile_n != 0: + raise ValueError(f"{n=} must be divisible by {tile_n=}") + if k % tile_k != 0: + raise ValueError(f"{k=} must be divisible by {tile_k=}") + cluster_tile_m = 2 * tile_m + cluster_tile_n = 2 * tile_n + m_iters = m // cluster_tile_m + n_iters = n // cluster_tile_n + k_iters = k // tile_k + max_concurrent_steps = config.max_concurrent_steps + + def kernel(a_gmem, b_gmem, out_gmem, + a_smem, b_smem, acc_tmem, acc_smem, + load_barriers, consumed_barriers, mma_done_barrier, store_done_barrier): + wg_idx = lax.axis_index("wg") + is_lead_block = lax.axis_index("cluster") == 0 + + @plgpu.nd_loop((m_iters, n_iters), collective_axes="cluster_grid") + def _mn_loop(loop_info: plgpu.NDLoopInfo): + m_index, n_index = loop_info.index + m_slice = pl.ds(m_index * cluster_tile_m, cluster_tile_m) + n_slice = pl.ds(n_index * cluster_tile_n, cluster_tile_n) + acc_slot = lax.rem(loop_info.local_index, jnp.int32(2)) + mn_acc_tmem = acc_tmem.at[:, pl.ds(acc_slot * cluster_tile_n, cluster_tile_n)] + + @pl.when(wg_idx == 0) + def _compute_wg(): + @pl.core_map(plgpu.WarpMesh(axis_name="warp")) + def _per_warp(): + warp_id = lax.axis_index("warp") + + @pl.when(warp_id == 0) + def _memory(): + def _loop_body(ki, _): + slot = lax.rem(ki, max_concurrent_steps) + @pl.when(jnp.logical_or(ki >= max_concurrent_steps, loop_info.local_index > 0)) + def _(): # Make sure the data has been consumed before overwriting. + plgpu.barrier_wait(consumed_barriers.at[slot]) + k_slice = pl.ds(ki * tile_k, tile_k) + plgpu.copy_gmem_to_smem( + a_gmem.at[m_slice, k_slice], a_smem.at[slot], load_barriers.at[slot], + collective_axes="cluster", partitioned_axis=0 + ) + plgpu.copy_gmem_to_smem( + b_gmem.at[k_slice, n_slice], b_smem.at[slot], load_barriers.at[slot], + collective_axes="cluster", partitioned_axis=1 + ) + + lax.fori_loop(0, k_iters, _loop_body, None) + + # Wait for store to complete (except for the first two steps). + @pl.when(jnp.logical_and(warp_id == 1, loop_info.local_index >= 2)) + def _wait_store(): + plgpu.barrier_wait(store_done_barrier.at[acc_slot]) + @pl.when(jnp.logical_and(warp_id == 1, is_lead_block)) + def _compute(): + def _loop_body(ki, _): + slot = lax.rem(ki, max_concurrent_steps) + plgpu.barrier_wait(load_barriers.at[slot]) # Wait for data to arrive. + plgpu.tcgen05_mma( + mn_acc_tmem, + a_smem.at[slot], + b_smem.at[slot], + consumed_barriers.at[slot], + accumulate=(ki > 0), + collective_axis="cluster", + ) + lax.fori_loop(0, k_iters, _loop_body, None) + plgpu.tcgen05_commit_arrive( + mma_done_barrier.at[acc_slot], + collective_axis="cluster", + ) + + @pl.when(wg_idx == 1) + def _store_wg(): + # Ensure that copies from the previous mn step have completed. + plgpu.wait_smem_to_gmem(0, wait_read_only=True) + plgpu.barrier_wait(mma_done_barrier.at[acc_slot]) + out_m_index = m_index * 2 + lax.axis_index("cluster") + out_m_slice = pl.ds(out_m_index * tile_m, tile_m) + out_gmem_window = out_gmem.at[out_m_slice, n_slice] + for ni in range(cluster_tile_n // config.epilogue_tile_n): + acc_smem_ni = acc_smem.at[ni % 2] + ni_slice = pl.ds(ni * config.epilogue_tile_n, config.epilogue_tile_n) + # Make sure that previous copy is done before we overwrite. + plgpu.wait_smem_to_gmem(1, wait_read_only=True) + acc_smem_ni[...] = plgpu.async_load_tmem(mn_acc_tmem.at[:, ni_slice]).astype(dtype) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(acc_smem_ni, out_gmem_window.at[:, ni_slice]) + plgpu.wait_load_tmem() # Load must complete before we signal. + plgpu.barrier_arrive(store_done_barrier.at[acc_slot]) + plgpu.wait_smem_to_gmem(0, wait_read_only=True) + + num_sms = backend.get_default_device().core_count + f = plgpu.kernel( + kernel, + out_shape=jax.ShapeDtypeStruct((m, n), dtype), + grid=(num_sms // 2,), + grid_names=("cluster_grid",), + cluster=(2,), + cluster_names=("cluster",), + num_threads=2, + thread_name="wg", + scratch_shapes=dict( + a_smem=plgpu.SMEM( + (max_concurrent_steps, tile_m, tile_k), dtype, transforms=transforms, + ), + b_smem=plgpu.SMEM( + (max_concurrent_steps, tile_k, tile_n), dtype, transforms=transforms, + ), + acc_tmem=plgpu.TMEM((tile_m, 2 * cluster_tile_n), jnp.float32, collective=True), + acc_smem=plgpu.SMEM((2, tile_m, config.epilogue_tile_n), dtype, transforms=transforms), + load_barriers=plgpu.Barrier(num_arrivals=2, num_barriers=max_concurrent_steps), + consumed_barriers=plgpu.Barrier( + num_arrivals=1, + num_barriers=max_concurrent_steps, + orders_tensor_core=True, + ), + mma_done_barrier=plgpu.Barrier(num_arrivals=1, num_barriers=2, orders_tensor_core=True), + store_done_barrier=plgpu.ClusterBarrier( + collective_axes=("cluster",), + num_arrivals=1, + num_barriers=2, + orders_tensor_core=True, + ), + ), + ) + return f(a, b) + + +def matmul6(a, b, config: TuningConfig): + dtype = a.dtype + m, k = a.shape + _, n = b.shape + tile_m, tile_n, tile_k = config.tile_m, config.tile_n, config.tile_k + swizzle = plgpu.find_swizzle(tile_k * jnp.dtype(dtype).itemsize * 8) + swizzle_elems = swizzle // jnp.dtype(dtype).itemsize + transforms = ( + plgpu.TilingTransform((8, swizzle_elems)), plgpu.SwizzleTransform(swizzle) + ) + if m % tile_m != 0: + raise ValueError(f"{m=} must be divisible by {tile_m=}") + if n % tile_n != 0: + raise ValueError(f"{n=} must be divisible by {tile_n=}") + if k % tile_k != 0: + raise ValueError(f"{k=} must be divisible by {tile_k=}") + cluster_tile_m = 2 * tile_m + cluster_tile_n = 2 * tile_n + m_iters = m // cluster_tile_m + n_iters = n // cluster_tile_n + k_iters = k // tile_k + max_concurrent_steps = config.max_concurrent_steps + + def kernel(a_gmem, b_gmem, out_gmem, + a_smem, b_smem, acc_tmem, acc_smem, + load_barriers, consumed_barriers, mma_done_barrier, store_done_barrier): + wg_idx = lax.axis_index("wg") + is_lead_block = lax.axis_index("cluster") == 0 + + @plgpu.nd_loop((m_iters * n_iters,), collective_axes="cluster_grid") + def _mn_loop(loop_info: plgpu.NDLoopInfo): + (lin_idx,) = loop_info.index + m_index, n_index = plgpu.planar_snake( + lin_idx, + (m_iters, n_iters), + config.grid_minor_dim, + config.grid_tile_width, + ) + m_slice = pl.ds(m_index * cluster_tile_m, cluster_tile_m) + n_slice = pl.ds(n_index * cluster_tile_n, cluster_tile_n) + acc_slot = lax.rem(loop_info.local_index, jnp.int32(2)) + mn_acc_tmem = acc_tmem.at[:, pl.ds(acc_slot * cluster_tile_n, cluster_tile_n)] + + @pl.when(wg_idx == 0) + def _compute_wg(): + @pl.core_map(plgpu.WarpMesh(axis_name="warp")) + def _per_warp(): + warp_id = lax.axis_index("warp") + + @pl.when(warp_id == 0) + def _memory(): + def _loop_body(ki, _): + slot = lax.rem(ki, max_concurrent_steps) + @pl.when(jnp.logical_or(ki >= max_concurrent_steps, loop_info.local_index > 0)) + def _(): # Make sure the data has been consumed before overwriting. + plgpu.barrier_wait(consumed_barriers.at[slot]) + k_slice = pl.ds(ki * tile_k, tile_k) + plgpu.copy_gmem_to_smem( + a_gmem.at[m_slice, k_slice], a_smem.at[slot], load_barriers.at[slot], + collective_axes="cluster", partitioned_axis=0 + ) + plgpu.copy_gmem_to_smem( + b_gmem.at[k_slice, n_slice], b_smem.at[slot], load_barriers.at[slot], + collective_axes="cluster", partitioned_axis=1 + ) + + lax.fori_loop(0, k_iters, _loop_body, None) + + # Wait for store to complete (except for the first two steps). + @pl.when(jnp.logical_and(warp_id == 1, loop_info.local_index >= 2)) + def _wait_store(): + plgpu.barrier_wait(store_done_barrier.at[acc_slot]) + @pl.when(jnp.logical_and(warp_id == 1, is_lead_block)) + def _compute(): + def _loop_body(ki, _): + slot = lax.rem(ki, max_concurrent_steps) + plgpu.barrier_wait(load_barriers.at[slot]) # Wait for data to arrive. + plgpu.tcgen05_mma( + mn_acc_tmem, + a_smem.at[slot], + b_smem.at[slot], + consumed_barriers.at[slot], + accumulate=(ki > 0), + collective_axis="cluster", + ) + lax.fori_loop(0, k_iters, _loop_body, None) + plgpu.tcgen05_commit_arrive( + mma_done_barrier.at[acc_slot], + collective_axis="cluster", + ) + + @pl.when(wg_idx == 1) + def _store_wg(): + # Ensure that copies from the previous mn step have completed. + plgpu.wait_smem_to_gmem(0, wait_read_only=True) + plgpu.barrier_wait(mma_done_barrier.at[acc_slot]) + out_m_index = m_index * 2 + lax.axis_index("cluster") + out_m_slice = pl.ds(out_m_index * tile_m, tile_m) + out_gmem_window = out_gmem.at[out_m_slice, n_slice] + for ni in range(cluster_tile_n // config.epilogue_tile_n): + acc_smem_ni = acc_smem.at[ni % 2] + ni_slice = pl.ds(ni * config.epilogue_tile_n, config.epilogue_tile_n) + # Make sure that previous copy is done before we overwrite. + plgpu.wait_smem_to_gmem(1, wait_read_only=True) + acc_smem_ni[...] = plgpu.async_load_tmem(mn_acc_tmem.at[:, ni_slice]).astype(dtype) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(acc_smem_ni, out_gmem_window.at[:, ni_slice]) + plgpu.wait_load_tmem() # Load must complete before we signal. + plgpu.barrier_arrive(store_done_barrier.at[acc_slot]) + plgpu.wait_smem_to_gmem(0, wait_read_only=True) + + num_sms = backend.get_default_device().core_count + f = plgpu.kernel( + kernel, + out_shape=jax.ShapeDtypeStruct((m, n), dtype), + grid=(num_sms // 2,), + grid_names=("cluster_grid",), + cluster=(2,), + cluster_names=("cluster",), + num_threads=2, + thread_name="wg", + scratch_shapes=dict( + a_smem=plgpu.SMEM( + (max_concurrent_steps, tile_m, tile_k), dtype, transforms=transforms + ), + b_smem=plgpu.SMEM( + (max_concurrent_steps, tile_k, tile_n), dtype, transforms=transforms + ), + acc_tmem=plgpu.TMEM((tile_m, 2 * cluster_tile_n), jnp.float32, collective=True), + acc_smem=plgpu.SMEM((2, tile_m, config.epilogue_tile_n), dtype, transforms=transforms), + load_barriers=plgpu.Barrier(num_arrivals=2, num_barriers=max_concurrent_steps), + consumed_barriers=plgpu.Barrier( + num_arrivals=1, + num_barriers=max_concurrent_steps, + orders_tensor_core=True, + ), + mma_done_barrier=plgpu.Barrier(num_arrivals=1, num_barriers=2, orders_tensor_core=True), + store_done_barrier=plgpu.ClusterBarrier( + collective_axes=("cluster",), + num_arrivals=1, + num_barriers=2, + orders_tensor_core=True, + ), + ) + ) + return f(a, b) + + +@jtu.with_config(jax_traceback_filtering="off") +class MatmulTutorialSm100ATest(jtu.JaxTestCase): + BENCHMARK = False + + def setUp(self): + super().setUp() + if not jtu.test_device_matches(["cuda"]): + self.skipTest("Test requires an NVIDIA GPU") + if not jtu.is_cuda_compute_capability_equal("10.0"): + self.skipTest("Only works on GPU with capability sm100a") + self.enter_context(pallas_call._PALLAS_USE_MOSAIC_GPU(True)) + + def benchmark(self, matmul_impl, a, b, config_search_space): + if not self.BENCHMARK: + return + config_names = config_search_space.keys() + config_all_values = config_search_space.values() + peak_flops = 2250e12 # f16 TensorCore peak = 2250 TFLOPS + matmul_flops = 2 * a.shape[0] * b.shape[0] * b.shape[1] + optimal_time_us = matmul_flops / peak_flops * 1e6 # us + best_util = 0.0 + ref = jnp.dot(a, b, precision=jax.lax.DotAlgorithmPreset.F16_F16_F32) + for config_values in itertools.product(*config_all_values): + config = TuningConfig(**dict(zip(config_names, config_values))) + try: + out, runtimes_ms = profiler.measure( + functools.partial(matmul_impl, config=config), iterations=100 + )(a, b) + except ValueError as e: + if "exceeds available shared memory" in e.args[0]: # Ignore SMEM OOMs. + continue + raise + assert runtimes_ms is not None + runtime_ms = statistics.median(runtimes_ms) + runtime_us = runtime_ms * 1e3 # type: ignore + achieved_tc_util = optimal_time_us / runtime_us * 100 + print(f"{config} {achieved_tc_util:.2f}% TC utilization") + if achieved_tc_util > best_util: + best_util = achieved_tc_util + np.testing.assert_allclose(out, ref) + print(f"Best result for {matmul_impl.__name__}: {best_util:.2f}% TC utilization") + _, runtimes_ms = profiler.measure( + functools.partial( + jnp.dot, precision=jax.lax.DotAlgorithmPreset.F16_F16_F32 + ), + iterations=100, + )(a, b) + runtime_ms = statistics.median(runtimes_ms) + runtime_us = runtime_ms * 1e3 # type: ignore + achieved_tc_util = optimal_time_us / runtime_us * 100 + print(f"Reference: {achieved_tc_util:.2f}% TC utilization") + + def _test_matmul(self, matmul_impl, example_config, config_search_space): + dtype = jnp.float16 + m = 4096 + n = 8192 + k = 4096 + k1, k2, = jax.random.split(jax.random.key(42), 2) + a = jax.random.normal(k1, (m, k), dtype) + b = jax.random.normal(k2, (k, n), dtype) + + out = matmul_impl(a, b, example_config) + out_ref = jnp.dot(a, b, precision=jax.lax.DotAlgorithmPreset.F16_F16_F32) + np.testing.assert_allclose(out, out_ref) + self.benchmark(matmul_impl, a, b, config_search_space) + + @parameterized.parameters(matmul0, matmul1, matmul2) + def test_matmul(self, matmul_impl): + example_config = TuningConfig( + tile_m=128, tile_n=128, tile_k=64, max_concurrent_steps=4, + ) + config_search_space = { + "tile_m": (128,), + "tile_n": (128, 256, 512), + "tile_k": (64,), + "max_concurrent_steps": (4, 6), + } + self._test_matmul(matmul_impl, example_config, config_search_space) + + def test_matmul3(self): + example_config = TuningConfig( + tile_m=128, tile_n=128, tile_k=64, max_concurrent_steps=4, + ) + config_search_space = { + "tile_m": (128,), + "tile_n": (128,), + "tile_k": (64,), + "max_concurrent_steps": (6,), + } + self._test_matmul(matmul3, example_config, config_search_space) + + def test_matmul4(self): + example_config = TuningConfig( + tile_m=128, tile_n=128, tile_k=64, max_concurrent_steps=4, + ) + config_search_space = { + "tile_m": (128,), + "tile_n": (128,), + "tile_k": (64,), + "max_concurrent_steps": (6,), + } + self._test_matmul(matmul4, example_config, config_search_space) + + def test_matmul5(self): + example_config = TuningConfig( + tile_m=128, tile_n=128, tile_k=64, max_concurrent_steps=4, + ) + config_search_space = { + "tile_m": (128,), + "tile_n": (128,), + "tile_k": (64,), + "max_concurrent_steps": (6,), + } + self._test_matmul(matmul5, example_config, config_search_space) + + def test_matmul6(self): + example_config = TuningConfig( + tile_m=128, tile_n=128, tile_k=64, max_concurrent_steps=4, + grid_minor_dim=0, grid_tile_width=6, + ) + config_search_space = { + "tile_m": (128,), + "tile_n": (128,), + "tile_k": (64,), + "max_concurrent_steps": (6,), + "grid_minor_dim": (0, 1), + "grid_tile_width": (1, 4, 12, 16), + } + self._test_matmul(matmul6, example_config, config_search_space) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/mgpu_matmul_test.py b/tests/pallas/mgpu_matmul_test.py new file mode 100644 index 000000000000..76c04ae71afd --- /dev/null +++ b/tests/pallas/mgpu_matmul_test.py @@ -0,0 +1,336 @@ +# Copyright 2025 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test different parameterizations of matrix multiplication.""" + +import os + +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax._src import config +from jax._src import dtypes +from jax._src import test_util as jtu +from jax._src.pallas import pallas_call +import jax.experimental.mosaic.gpu # noqa: F401 +from jax.experimental.pallas.ops.gpu import blackwell_matmul_mgpu +from jax.experimental.pallas.ops.gpu import hopper_matmul_mgpu +from jax.experimental.pallas.ops.gpu import hopper_mixed_type_matmul_mgpu +import jax.numpy as jnp +import numpy as np + + +config.parse_flags_with_absl() +os.environ["XLA_FLAGS"] = ( + os.environ.get("XLA_FLAGS", "") + " --xla_gpu_autotune_level=0") + + +def exceeds_h100_smem(alloc_bytes: int) -> bool: + """Whether the given allocation will exceed the amount of SMEM on H100.""" + return alloc_bytes > 228000 + + +@jtu.with_config(jax_traceback_filtering="off") +class MatrixMultiplicationSm100ATest(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + if not jtu.test_device_matches(["cuda"]): + self.skipTest("Test requires an NVIDIA GPU") + self.enter_context(pallas_call._PALLAS_USE_MOSAIC_GPU(True)) + + @parameterized.product( + m=(1024, 4096), + k=(1024, 4096), + n=(1024, 4096), + dtype=(jnp.float16,), + ) + def test_blackwell_matmul( + self, + m, + n, + k, + dtype, + ): + if not jtu.is_cuda_compute_capability_equal("10.0"): + self.skipTest("Only works on GPU with capability sm100a") + k1, k2, = jax.random.split(jax.random.key(42), 2) + a = jax.random.normal(k1, (m, k), dtype) + b = jax.random.normal(k2, (k, n), dtype) + + out = blackwell_matmul_mgpu.matmul_kernel( + a, + b, + blackwell_matmul_mgpu.TuningConfig( + tile_m=128, tile_n=128, tile_k=128, + max_concurrent_steps=2, + collective=False, + ), + ) + out_ref = a @ b + np.testing.assert_allclose(out, out_ref, atol=2e-3, rtol=1e-3) + + +@jtu.with_config(jax_traceback_filtering="off") +class MatrixMultiplicationSm90ATest(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + if not jtu.test_device_matches(["cuda"]): + self.skipTest("Test requires an NVIDIA GPU") + self.enter_context(pallas_call._PALLAS_USE_MOSAIC_GPU(True)) + + @parameterized.product( + m=(4096,), + k=(4096,), + n=(4096,), + tile_m=(64, 128), + tile_n=(64, 128), + tile_k=(64, 128), + max_concurrent_steps=(2, 4), + dtype=(jnp.float16,), + epi_tile_n=(None, 64), + epi_tile_m=(None, 64), + wg_dimension=tuple(hopper_matmul_mgpu.MatmulDimension), + ) + def test_hopper_matmul(self, *args, **kwargs): + self.check_hopper_matmul(*args, **kwargs) + + # Grid tiling doesn't really interact with many other options so we can test + # it separately. + @parameterized.product( + grid_minor_dim=tuple(hopper_matmul_mgpu.MatmulDimension), + grid_tile_width=(1, 3, 4), + ) + def test_hopper_matmul_grid_tiling(self, grid_minor_dim, grid_tile_width): + self.check_hopper_matmul( + m=4096, + k=4096, + n=4096, + dtype=jnp.float16, + tile_m=64, + tile_n=64, + tile_k=64, + max_concurrent_steps=2, + epi_tile_m=64, + epi_tile_n=64, + wg_dimension=hopper_matmul_mgpu.MatmulDimension.M, + grid_minor_dim=grid_minor_dim, + grid_tile_width=grid_tile_width, + ) + + @parameterized.product( + tile_m=(64, 128), + tile_n=(64, 128), + wg_dimension=tuple(hopper_matmul_mgpu.MatmulDimension), + cluster_dimension=tuple(hopper_matmul_mgpu.MatmulDimension), + ) + def test_hopper_matmul_cluster(self, tile_m, tile_n, wg_dimension, cluster_dimension): + self.check_hopper_matmul( + m=4096, + k=4096, + n=4096, + dtype=jnp.float16, + tile_m=tile_m, + tile_n=tile_n, + tile_k=64, + max_concurrent_steps=4, + epi_tile_m=64, + epi_tile_n=64, + wg_dimension=wg_dimension, + cluster_dimension=cluster_dimension, + ) + + def check_hopper_matmul( + self, + m, + n, + k, + dtype, + tile_m, + tile_n, + tile_k, + max_concurrent_steps, + epi_tile_m, + epi_tile_n, + wg_dimension, + **kwargs + ): + if not jtu.is_cuda_compute_capability_equal("9.0"): + self.skipTest("Only works on GPU with capability sm90a") + + epi_tile_size = (epi_tile_m or tile_m) * (epi_tile_n or tile_n) + num_epi_tiles = tile_m * tile_n // epi_tile_size + cta_tile_m = tile_m * (1 + (wg_dimension == hopper_matmul_mgpu.MatmulDimension.M)) + cta_tile_n = tile_n * (1 + (wg_dimension == hopper_matmul_mgpu.MatmulDimension.N)) + if exceeds_h100_smem( + ((cta_tile_m + cta_tile_n) * tile_k * max_concurrent_steps + + 2 * min(2, num_epi_tiles) * epi_tile_size) * 2 + ): + self.skipTest("Tile too big to fit into SMEM") + k1, k2, = jax.random.split(jax.random.key(42), 2) + a = jax.random.normal(k1, (m, k), dtype) + b = jax.random.normal(k2, (k, n), dtype) + + spec = hopper_matmul_mgpu.TuningConfig( + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + max_concurrent_steps=max_concurrent_steps, + epi_tile_m=epi_tile_m, + epi_tile_n=epi_tile_n, + wg_dimension=wg_dimension, + **kwargs, + ) + out = hopper_matmul_mgpu.matmul(a, b, spec) + out_ref = jnp.dot(a, b, precision=jax.lax.DotAlgorithmPreset.F16_F16_F32) + np.testing.assert_allclose(out, out_ref) + + @parameterized.product( + m=(4096,), + k=(4096,), + n=(4096,), + tile_m=(64, 128), + tile_n=(64, 128, 256), + tile_k=(64, 128), + epi_tile_m=(None, 64), + epi_tile_n=(None, 64), + max_concurrent_steps=(2, 4), + lhs_dtype=(jnp.int8,), # TODO(bchetioui): add int4. + rhs_dtype=(jnp.bfloat16, jnp.float16), + wg_dimension=tuple(hopper_mixed_type_matmul_mgpu.MatmulDimension), + ) + def test_hopper_mixed_type_matmul(self, *args, **kwargs): + self.check_hopper_mixed_type_matmul(*args, **kwargs) + + def check_hopper_mixed_type_matmul( + self, + m, + n, + k, + tile_m, + tile_n, + tile_k, + max_concurrent_steps, + epi_tile_m, + epi_tile_n, + wg_dimension, + lhs_dtype, + rhs_dtype, + **kwargs, + ): + if not jtu.is_cuda_compute_capability_equal("9.0"): + self.skipTest("Only works on GPU with capability sm90a") + out_dtype = rhs_dtype + lhs_bits = dtypes.itemsize_bits(lhs_dtype) + rhs_bits = dtypes.itemsize_bits(rhs_dtype) + out_bits = dtypes.itemsize_bits(out_dtype) + + cta_tile_m = tile_m * (1 + (wg_dimension == hopper_mixed_type_matmul_mgpu.MatmulDimension.M)) + cta_tile_n = tile_n * (1 + (wg_dimension == hopper_mixed_type_matmul_mgpu.MatmulDimension.N)) + lhs_smem_bytes = cta_tile_m * tile_k * lhs_bits // 8 + rhs_smem_bytes = tile_k * cta_tile_n * rhs_bits // 8 + + epi_tile_size = (epi_tile_m or tile_m) * (epi_tile_n or tile_n) + num_epi_tiles = tile_m * tile_n // epi_tile_size + out_smem_bytes = 2 * min(2, num_epi_tiles) * epi_tile_size * out_bits // 8 + + if exceeds_h100_smem( + max_concurrent_steps * (lhs_smem_bytes + rhs_smem_bytes) + + out_smem_bytes + ): + self.skipTest("Tile too big to fit into SMEM") + (k1, k2) = jax.random.split(jax.random.key(42), 2) + lhs = jax.random.randint( + k1, (m, k), minval=-5, maxval=5, dtype=jnp.int8 + ).astype(lhs_dtype) + rhs = jax.random.normal(k2, (k, n), rhs_dtype) + + tuning_config = hopper_mixed_type_matmul_mgpu.TuningConfig( + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + epi_tile_m=epi_tile_m, + epi_tile_n=epi_tile_n, + max_concurrent_steps=max_concurrent_steps, + wg_dimension=wg_dimension, + **kwargs, + ) + + out = hopper_mixed_type_matmul_mgpu.mixed_matmul_kernel( + lhs, rhs, out_dtype=out_dtype, config=tuning_config + ) + precision = { + jnp.float16: jax.lax.DotAlgorithmPreset.F16_F16_F32, + jnp.bfloat16: jax.lax.DotAlgorithmPreset.BF16_BF16_F32, + }[rhs_dtype] + out_ref = jnp.dot( + lhs.astype(rhs_dtype), rhs, precision=precision, + ).astype(out_dtype) + np.testing.assert_allclose(out, out_ref, strict=True) + + # Grid tiling doesn't really interact with many other options so we can test + # it separately. + @parameterized.product( + grid_minor_dim=tuple(hopper_matmul_mgpu.MatmulDimension), + grid_tile_width=(1, 3, 4), + ) + def test_hopper_mixed_type_matmul_grid_tiling( + self, grid_minor_dim, grid_tile_width + ): + self.check_hopper_mixed_type_matmul( + m=4096, + k=4096, + n=4096, + lhs_dtype=jnp.int8, + rhs_dtype=jnp.float16, + tile_m=64, + tile_n=64, + tile_k=64, + max_concurrent_steps=2, + epi_tile_m=64, + epi_tile_n=64, + wg_dimension=hopper_matmul_mgpu.MatmulDimension.M, + grid_minor_dim=grid_minor_dim, + grid_tile_width=grid_tile_width, + ) + + @parameterized.product( + tile_m=(64, 128), + tile_n=(64, 128), + wg_dimension=tuple(hopper_matmul_mgpu.MatmulDimension), + cluster_dimension=tuple(hopper_matmul_mgpu.MatmulDimension), + ) + def test_hopper_mixed_type_matmul_cluster( + self, tile_m, tile_n, wg_dimension, cluster_dimension + ): + self.check_hopper_mixed_type_matmul( + m=4096, + k=4096, + n=4096, + lhs_dtype=jnp.int8, + rhs_dtype=jnp.float16, + tile_m=tile_m, + tile_n=tile_n, + tile_k=64, + max_concurrent_steps=4, + epi_tile_m=64, + epi_tile_n=64, + wg_dimension=wg_dimension, + cluster_dimension=cluster_dimension, + ) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/mgpu_ragged_dot_test.py b/tests/pallas/mgpu_ragged_dot_test.py new file mode 100644 index 000000000000..8ecc20aedf66 --- /dev/null +++ b/tests/pallas/mgpu_ragged_dot_test.py @@ -0,0 +1,241 @@ +# Copyright 2025 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test different parameterizations of our Mosaic GPU ragged dot kernel.""" + +import os + +from absl.testing import absltest, parameterized # pylint: disable=g-multiple-import +import jax +from jax import random +from jax._src import config +from jax._src import test_util as jtu +from jax._src.pallas import pallas_call +from jax.experimental.pallas.ops.gpu import blackwell_ragged_dot_mgpu +from jax.experimental.pallas.ops.gpu import ragged_dot_mgpu +from jax.experimental.pallas.ops.gpu import transposed_ragged_dot_mgpu +import jax.numpy as jnp +import numpy as np + + +config.parse_flags_with_absl() + + +# TODO(justinfu): Test empty groups +def sample_inputs( + key, m, k, n, num_groups, dtype=jnp.float16, transposed=False, +): + kx, ky, kz = random.split(key, num=3) + if transposed: + lhs = jax.random.normal(kx, (k, m), dtype) + rhs = jax.random.normal(ky, (k, n), dtype) + batch_size = k + else: + lhs = jax.random.normal(kx, (m, k), dtype) + rhs = jax.random.normal(ky, (num_groups, k, n), dtype) + batch_size = m + group_boundaries = jax.lax.sort( + jax.random.randint(kz, (num_groups - 1,), 0, batch_size, jnp.int32) + ) + group_starts = jax.lax.concatenate( + [jnp.array([0], dtype=jnp.int32), group_boundaries], 0 + ) + group_ends = jax.lax.concatenate( + [group_boundaries, jnp.array([batch_size], dtype=jnp.int32)], 0 + ) + group_sizes = group_ends - group_starts + assert group_sizes.shape == (num_groups,) + return lhs, rhs, group_sizes + + +@jtu.with_config(jax_traceback_filtering="off") +class RaggedDotTestCase(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + if ragged_dot_mgpu is None: + self.skipTest("Mosaic GPU not available.") + if (not jtu.test_device_matches(["cuda"]) or + not jtu.is_cuda_compute_capability_equal("9.0")): + self.skipTest("Only works on GPU with capability sm90a") + self.enter_context(pallas_call._PALLAS_USE_MOSAIC_GPU(True)) + + @parameterized.product( + block_m=(64, 128), + block_n=(64, 128, 192), + block_k=(64, 128), + grid_block_n=(2, 4), + max_concurrent_steps=(2, 4), + num_groups=(1, 3, 16), + transpose_rhs=(False, True), + ) + def test_ragged_dot( + self, + block_m, + block_n, + block_k, + grid_block_n, + max_concurrent_steps, + num_groups, + transpose_rhs, + ): + dtype = jnp.float16 + lhs_smem_size = block_m * block_k * max_concurrent_steps * 2 + rhs_smem_size = block_k * block_n * max_concurrent_steps * 2 + # H100 SMEM limit is 228kB. + if lhs_smem_size + rhs_smem_size > 228_000: + self.skipTest("This configuration requires too much SMEM.") + + m, k, n = 16 * 1024, 2048, 16 * 1024 + lhs, rhs, group_sizes = sample_inputs( + random.key(1234), m, k, n, num_groups, dtype + ) + + out = ragged_dot_mgpu.ragged_dot( + lhs, + jnp.transpose(rhs, (0, 2, 1)) if transpose_rhs else rhs, + group_sizes=group_sizes, + block_m=block_m, + block_n=block_n, + block_k=block_k, + max_concurrent_steps=max_concurrent_steps, + grid_block_n=grid_block_n, + transpose_rhs=transpose_rhs, + ) + out_ref = jax.lax.ragged_dot(lhs, rhs, group_sizes=group_sizes) + np.testing.assert_allclose(out, out_ref, atol=1e-3, rtol=1e-3) + + @parameterized.product( + block_m=(64, 128), + block_n=(64, 128), + block_k=(64, 128), + grid_block_n=(2, 4), + max_concurrent_steps=(1, 2, 4), + ) + def test_ragged_dot_transposed( + self, + block_m, + block_n, + block_k, + grid_block_n, + max_concurrent_steps, + ): + # See log at + # https://github.com/jax-ml/jax/actions/runs/20821451405/job/59813460647. + self.skipTest("TODO(bchetioui): this test has broken in CI. Investigate.") + dtype = jnp.float16 + lhs_smem_size = block_m * block_k * max_concurrent_steps * 2 + rhs_smem_size = block_k * block_n * max_concurrent_steps * 2 + # H100 SMEM limit is 228kB. + if lhs_smem_size + rhs_smem_size > 228_000: + self.skipTest("This configuration requires too much SMEM.") + + k, m, n, num_groups = 16 * 1024, 2048, 2048, 16 + lhs, rhs, group_sizes = sample_inputs( + random.key(1234), m, k, n, num_groups, + dtype=dtype, transposed=True, + ) + + with jax.numpy_dtype_promotion("standard"): + # We need standard dtype promotion for dynamic grid size to work, because + # python integers are treated as int64, and some of the dtypes inside + # emit_pipeline are hardcoded to use int32. + out = transposed_ragged_dot_mgpu.transposed_ragged_dot( + lhs, + rhs, + group_sizes=group_sizes, + block_m=block_m, + block_n=block_n, + block_k=block_k, + max_concurrent_steps=max_concurrent_steps, + grid_block_n=grid_block_n, + ) + out_ref = jax.lax.ragged_dot_general( + lhs, rhs, group_sizes, + ragged_dot_dimension_numbers=jax.lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(((0,), (0,)), ((), ())), + lhs_ragged_dimensions=[0], + rhs_group_dimensions=[], + ) + ) + np.testing.assert_allclose(out, out_ref, atol=1e-3, rtol=1e-3) + + +@jtu.with_config(jax_traceback_filtering="off") +class RaggedDotSm100aTestCase(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + if blackwell_ragged_dot_mgpu is None: + self.skipTest("Mosaic GPU not available.") + if (not jtu.test_device_matches(["cuda"]) or + not jtu.is_cuda_compute_capability_equal("10.0")): + self.skipTest("Only works on GPU with capability sm100a") + self.enter_context(pallas_call._PALLAS_USE_MOSAIC_GPU(True)) + + @parameterized.product( + grid_tile_width=(1, 8, 16), + grid_minor_dim=(0, 1), + max_concurrent_steps=(2, 4), + num_groups=(1, 3, 16), + tile_k=(64, 128) + ) + def test_ragged_dot( + self, + grid_tile_width, + grid_minor_dim, + max_concurrent_steps, + num_groups, + tile_k, + ): + # Kernel does not support other tiling on M and N dimensions currently. + tile_m = 128 + tile_n = 128 + + lhs_smem_size = tile_m * tile_k * max_concurrent_steps * 2 + rhs_smem_size = tile_k * tile_n * max_concurrent_steps * 2 + # B200 SMEM limit is 228kB. + if lhs_smem_size + rhs_smem_size > 228_000: + self.skipTest("This configuration requires too much SMEM.") + + dtype = jnp.float16 + m, k, n = 16 * 1024, 2048, 16 * 1024 + lhs, rhs, group_sizes = sample_inputs( + random.key(1234), m, k, n, num_groups, dtype + ) + tuning_config = blackwell_ragged_dot_mgpu.TuningConfig( + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + grid_tile_width=grid_tile_width, + grid_minor_dim=grid_minor_dim, + max_concurrent_steps=max_concurrent_steps, + collective=True, + ) + out = blackwell_ragged_dot_mgpu.ragged_dot_kernel( + lhs, + rhs, + group_sizes=group_sizes, + config=tuning_config, + ) + out_ref = jax.lax.ragged_dot(lhs, rhs, group_sizes=group_sizes, + preferred_element_type=dtype) + np.testing.assert_allclose(out, out_ref, atol=1e-3, rtol=1e-3) + + +if __name__ == "__main__": + os.environ["XLA_FLAGS"] = ( + os.environ.get("XLA_FLAGS", "") + " --xla_gpu_autotune_level=0" + ) + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/mgpu_torch_test.py b/tests/pallas/mgpu_torch_test.py new file mode 100644 index 000000000000..d0f80394b880 --- /dev/null +++ b/tests/pallas/mgpu_torch_test.py @@ -0,0 +1,182 @@ +# Copyright 2025 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import functools + +from absl.testing import absltest + +import jax +from jax._src import config +from jax._src import test_util as jtu +from jax._src.pallas import pallas_call +from jax.experimental import pallas as pl +from jax.experimental.pallas import mosaic_gpu as plgpu +import jax.numpy as jnp +import numpy as np + +try: + import torch +except ImportError: + torch = None + +# pylint: disable=g-import-not-at-top +try: + # We only import this to see if Mosaic is available. + import jax.experimental.mosaic.gpu # noqa: F401 +except ImportError: + attention_mgpu = None +else: + from jax.experimental.pallas.ops.gpu import attention_mgpu + + +config.parse_flags_with_absl() + + +class TorchTest(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + if torch is None: + self.skipTest("Test requires PyTorch") + if attention_mgpu is None: + self.skipTest("Mosaic GPU not available.") + if (not jtu.test_device_matches(["cuda"]) or + not jtu.is_cuda_compute_capability_at_least("9.0")): + self.skipTest("Only works on GPU with capability sm90a+") + self.enter_context(pallas_call._PALLAS_USE_MOSAIC_GPU(True)) + + def test_simple_pallas_call(self): + @plgpu.as_torch_kernel + @functools.partial( + pl.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.int32) + ) + def kernel(x_ref, y_ref, o_ref): + o_ref[...] = x_ref[...] + y_ref[0] + + x = torch.arange(128, dtype=torch.int32, device="cuda") + y = torch.arange(128, dtype=torch.int32, device="cuda") + np.testing.assert_array_equal(kernel(x, y).cpu(), (x + y[0]).cpu()) + + def test_simple_plgpu_kernel(self): + @plgpu.as_torch_kernel + @functools.partial( + plgpu.kernel, out_shape=jax.ShapeDtypeStruct([128], jnp.int32) + ) + def kernel(x_ref, y_ref, o_ref): + o_ref[...] = x_ref[...] + y_ref[0] + + x = torch.arange(128, dtype=torch.int32, device="cuda") + y = torch.arange(128, dtype=torch.int32, device="cuda") + np.testing.assert_array_equal(kernel(x, y).cpu(), (x + y[0]).cpu()) + + def test_flip(self): + @functools.partial( + pl.pallas_call, out_shape=(jax.ShapeDtypeStruct([128], jnp.int32),) * 2 + ) + def kernel(x_ref, y_ref, x_o_ref, y_o_ref): + x_o_ref[...] = x_ref[...] + y_o_ref[...] = y_ref[...] + + x = torch.arange(128, dtype=torch.int32, device="cuda") + y = torch.arange(128, dtype=torch.int32, device="cuda") + yo, xo = plgpu.as_torch_kernel(lambda x, y: kernel(x, y)[::-1])(x, y) + np.testing.assert_array_equal(xo.cpu(), x.cpu()) + np.testing.assert_array_equal(yo.cpu(), y.cpu()) + + def test_not_all_returned(self): + @functools.partial( + pl.pallas_call, out_shape=(jax.ShapeDtypeStruct([128], jnp.int32),) * 2 + ) + def kernel(x_ref, y_ref, x_o_ref, y_o_ref): + x_o_ref[...] = x_ref[...] + y_o_ref[...] = y_ref[...] + + x = torch.arange(128, dtype=torch.int32, device="cuda") + y = torch.arange(128, dtype=torch.int32, device="cuda") + xo = plgpu.as_torch_kernel(lambda x, y: kernel(x, y)[0])(x, y) + np.testing.assert_array_equal(xo.cpu(), x.cpu()) + + def test_invalid(self): + @functools.partial( + pl.pallas_call, out_shape=(jax.ShapeDtypeStruct([128], jnp.int32),) * 2 + ) + def kernel(x_ref, y_ref, x_o_ref, y_o_ref): + x_o_ref[...] = x_ref[...] + y_o_ref[...] = y_ref[...] + + x = torch.arange(128, dtype=torch.int32, device="cuda") + y = torch.arange(128, dtype=torch.int32, device="cuda") + + with self.assertRaisesRegex(ValueError, "Unsupported operation .* stablehlo.add"): + plgpu.as_torch_kernel(lambda x, y: x + y)(x, y) + + with self.assertRaisesRegex(ValueError, "Multiple Mosaic GPU kernels"): + plgpu.as_torch_kernel(lambda x, y: kernel(*kernel(x, y)))(x, y) + with self.assertRaisesRegex(ValueError, "Unsupported operation .* stablehlo.add"): + plgpu.as_torch_kernel(lambda x, y: kernel(x, y + jnp.ones_like(x)))(x, y) + with self.assertRaisesRegex(ValueError, "The function can only return kernel results"): + plgpu.as_torch_kernel(lambda x, y: (kernel(x, y), x, y))(x, y) + + def test_attention(self): + if not jtu.is_cuda_compute_capability_equal("9.0"): + self.skipTest("Test requires compute capability == 9.0") + batch_size = 1 + q_seq_len = 4096 + kv_seq_len = 4096 + head_dim = 64 + num_q_heads, num_kv_heads = 4, 1 + block_q = block_kv = 64 + q = torch.randn( + (batch_size, q_seq_len, num_q_heads, head_dim), + dtype=torch.float16, + device="cuda", + ) + k = torch.randn( + (batch_size, kv_seq_len, num_kv_heads, head_dim), + dtype=torch.float16, + device="cuda", + ) + v = torch.randn( + (batch_size, kv_seq_len, num_kv_heads, head_dim), + dtype=torch.float16, + device="cuda", + ) + kernel_fn = functools.partial( + attention_mgpu.attention, + config=attention_mgpu.TuningConfig( + block_q=block_q, + block_kv=block_kv, + max_concurrent_steps=2, + ), + ) + np.testing.assert_array_equal( + plgpu.as_torch_kernel(kernel_fn)(q, k, v).cpu(), + kernel_fn(jnp.asarray(q), jnp.asarray(k), jnp.asarray(v)), + ) + + def test_torch_aliasing(self): + @pl.kernel(mesh=plgpu.Mesh(), out_shape=(), compiler_params=plgpu.CompilerParams()) + def kernel(x_ref): + x_ref[...] = jnp.ones_like(x_ref) + + x = torch.zeros(128, dtype=torch.float32, device="cuda") + plgpu.as_torch_kernel(kernel)(x) # Run for side effects + np.testing.assert_array_equal( + x.cpu(), torch.ones((128,), dtype=torch.float32, device="cpu") + ) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index b3c3ddb84e09..ab14e9c88eff 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -12,24 +12,45 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Sequence import contextlib +import dataclasses import functools +import itertools import math import operator import os import re +import sys import tempfile +import traceback +from typing import ClassVar from absl.testing import absltest from absl.testing import parameterized import jax +from jax import export from jax import lax +from jax._src import core as jax_core +from jax._src import dtypes from jax._src import test_util as jtu +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import arith as arith_dialect +from jax._src.lib.mlir.dialects import gpu as gpu_dialect +from jax._src.pallas import core as pallas_core +from jax._src.pallas import pallas_call +from jax._src.pallas import primitives as pallas_primitives +from jax._src.pallas.mosaic_gpu import core as gpu_core +from jax._src.pallas.mosaic_gpu import lowering as mgpu_lowering from jax._src.pallas.mosaic_gpu import pipeline as mgpu_pipeline +from jax._src.pallas.mosaic_gpu import primitives as mgpu_primitives +from jax._src.state import types as state_types from jax.experimental import pallas as pl +import jax.experimental.mosaic.gpu as mgpu from jax.experimental.pallas import mosaic_gpu as plgpu import jax.numpy as jnp import numpy as np + try: from jax._src.lib import mosaic_gpu as mosaic_gpu_lib except ImportError: @@ -54,16 +75,94 @@ def _sum_same_dtype(x): return jnp.sum(x, dtype=x.dtype) -class PallasTest(jtu.JaxTestCase): +def _get_linearized_cuda_grid_index(): + shape = () + layout = plgpu.Layout.WG_SPLAT + + @plgpu.inline_mgpu( + arg_types=(), + return_type=plgpu.ShapeDtypeStruct(shape, jnp.int32, layout=layout), + ) + def fn(_): + grid_x = gpu_dialect.grid_dim(gpu_dialect.Dimension.x) + grid_y = gpu_dialect.grid_dim(gpu_dialect.Dimension.y) + block_x = gpu_dialect.block_id(gpu_dialect.Dimension.x) + block_y = gpu_dialect.block_id(gpu_dialect.Dimension.y) + block_z = gpu_dialect.block_id(gpu_dialect.Dimension.z) + + grid_idx = arith_dialect.addi( + block_x, + arith_dialect.addi( + arith_dialect.muli(block_y, grid_x), + arith_dialect.muli(block_z, arith_dialect.muli(grid_x, grid_y)), + ), + ) + + return mgpu.FragmentedArray.splat( + arith_dialect.index_cast(ir.IntegerType.get_signless(32), grid_idx), + shape=shape, + layout=layout.to_mgpu(), + is_signed=False + ) + return fn() + + +def _array_splat(value, shape: tuple[int, ...]): + """Same as `jnp.full(shape, value, jnp.float32)` but implemented using `inline_mgpu`. + + This is useful to prevent the result from being optimized away. + """ + @plgpu.inline_mgpu( + return_type=plgpu.ShapeDtypeStruct( + shape, jnp.float32, layout=plgpu.Layout.WG_SPLAT(shape) + ), + ) + def fn(_): + ir_value = mgpu.c(value, ir.F32Type.get()) + return mgpu.FragmentedArray.splat(ir_value, shape) + return fn() + + +class PallasTestMetaclass(parameterized.TestGeneratorMetaclass): + + def __new__(mcs, *args, lowering_semantics=plgpu.LoweringSemantics.Lane): + cls = super().__new__(mcs, *args) + cls.LOWERING_SEMANTICS = lowering_semantics + return cls + + +class PallasTest(jtu.JaxTestCase, metaclass=PallasTestMetaclass): + LOWERING_SEMANTICS: ClassVar[plgpu.LoweringSemantics] def setUp(self): if not jtu.is_cuda_compute_capability_at_least("9.0"): self.skipTest("Only works on a GPU with capability >= sm90") + self.enter_context(pallas_call._PALLAS_USE_MOSAIC_GPU(True)) super().setUp() + def skip_if_wg_semantics(self): + if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Warpgroup: + self.skipTest("Not supported under WG semantics") + + def kernel(self, *args, **kwargs): + compiler_params = dataclasses.replace( + kwargs.pop("compiler_params", plgpu.CompilerParams()), + lowering_semantics=self.LOWERING_SEMANTICS, + ) + return plgpu.kernel(*args, compiler_params=compiler_params, **kwargs) + + def pallas_call(self, *args, **kwargs): + compiler_params = dataclasses.replace( + kwargs.pop("compiler_params", plgpu.CompilerParams()), + lowering_semantics=self.LOWERING_SEMANTICS, + ) + return pl.pallas_call(*args, compiler_params=compiler_params, **kwargs) + @contextlib.contextmanager def capture_stdout(self): + if "pytest" in sys.modules: + self.skipTest("pytest interacts badly with GPU stdout capture") if mosaic_gpu_lib is None: raise ValueError("Running tests but missing Mosaic GPU extension") with jtu.capture_stdout() as stdout: @@ -71,6 +170,17 @@ def capture_stdout(self): # We need to cudaDeviceSynchronize to make sure printfs are flushed. mosaic_gpu_lib._mosaic_gpu_ext._sync_all_devices() + def default_transforms( + self, *, swizzle: int = 128, dtype: jnp.dtype + ) -> Sequence[plgpu.MemoryRefTransform]: + if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Warpgroup: + return () + swizzle_elems = 8 * swizzle // dtypes.itemsize_bits(dtype) + return ( + plgpu.TilingTransform((8, swizzle_elems)), + plgpu.SwizzleTransform(swizzle), + ) + class PallasSm90ATest(PallasTest, jtu.CudaArchSpecificTest): @@ -79,8 +189,37 @@ def setUp(self): super().setUp() +class PallasSm100ATest(PallasTest, jtu.CudaArchSpecificTest): + + def setUp(self): + self.skip_unless_sm100a() + super().setUp() + + class PallasCallTest(PallasTest): + def test_jitted_function_containing_multiple_pallas_calls(self): + # This test aims to ensure that execution works correctly inside CUDA + # graphs. This is complementary to the test in + # jaxlib/mosaic/gpu/custom_call_test.cc that checks that such jitted + # functions do invoke CUDA graphs. + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct([256], jnp.float32), + ) + def kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] + 1 + + @jax.jit + def f(x): + # Run the kernel 10 times because CUDA graphs only trigger for >= 5 ops. + for _ in range(10): + x = kernel(x) + return x + + x = jnp.arange(256).astype(jnp.float32) + np.testing.assert_array_equal(f(x), x + 10) + @parameterized.product( op=[ lax.neg, @@ -88,22 +227,24 @@ class PallasCallTest(PallasTest): lax.logistic, lax.exp, lambda x: x**2, + lambda x: x**5, lax.rsqrt, lax.tanh, lax.log, + jax.nn.gelu, + lax.abs, + lax.round, + lambda x: lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN), ], approx_math=[True, False], - thread_semantics=[*plgpu.ThreadSemantics], ) - def test_unary_op(self, op, approx_math, thread_semantics): + def test_unary_op(self, op, approx_math): dtype = jnp.int32 if op is lax.bitwise_not else jnp.float32 @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], dtype), - compiler_params=plgpu.GPUCompilerParams( - approx_math=approx_math, thread_semantics=thread_semantics - ), + compiler_params=plgpu.CompilerParams(approx_math=approx_math), ) def kernel(x_ref, o_ref): o_ref[...] = op(x_ref[...]) @@ -113,6 +254,19 @@ def kernel(x_ref, o_ref): kernel(x), op(x), rtol=1e-5 if approx_math else 3e-7 ) + @parameterized.parameters(jnp.float32, jnp.int32, jnp.uint32) + def test_sign(self, dtype): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct([256], dtype), + ) + def kernel(src_ref, dst_ref): + dst_ref[...] = lax.sign(src_ref[...]) + + # Use values that include negative, zero, and positive. + src = np.arange(256, dtype=dtype) - 128 + np.testing.assert_array_equal(kernel(src), lax.sign(src)) + @parameterized.product( op=[ operator.add, @@ -124,16 +278,10 @@ def kernel(x_ref, o_ref): jnp.maximum, ], dtype=[jnp.float32, jnp.int32, jnp.uint32], - thread_semantics=[*plgpu.ThreadSemantics], ) - def test_binary_op(self, op, dtype, thread_semantics): - + def test_binary_op(self, op, dtype): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], dtype), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], dtype) ) def kernel(x_ref, y_ref, o_ref): o_ref[...] = op(x_ref[...], y_ref[...]) @@ -154,16 +302,10 @@ def kernel(x_ref, y_ref, o_ref): ], # TODO(slebedev): Support integral types. dtype=[jnp.float32, jnp.int32, jnp.uint32], - thread_semantics=[*plgpu.ThreadSemantics], ) - def test_comparison_op(self, op, dtype, thread_semantics): - + def test_comparison_op(self, op, dtype): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], dtype), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], dtype) ) def kernel(o_ref): o_ref[...] = jnp.broadcast_to( @@ -173,8 +315,9 @@ def kernel(o_ref): np.testing.assert_array_equal(kernel(), jnp.full([256], op(42, 24), dtype)) def test_add_first(self): + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), ) def kernel(x_ref, y_ref, o_ref): @@ -185,28 +328,25 @@ def kernel(x_ref, y_ref, o_ref): np.testing.assert_array_equal(kernel(x, y), x + y[0]) @parameterized.product( - shape=[(128,), (128, 128)], thread_semantics=[*plgpu.ThreadSemantics] + op=(jnp.sum, jnp.max, jnp.min), + shape=((128,), (128, 64)), + dtype=(jnp.float32, jnp.int32, jnp.uint32), ) - def test_reduce_sum(self, shape, thread_semantics): - @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), - ) + def test_scalar_reduce(self, op, shape, dtype): + @functools.partial(self.kernel, out_shape=jax.ShapeDtypeStruct((), dtype)) def kernel(x_ref, o_ref): - o_ref[...] = jnp.broadcast_to(_sum_same_dtype(x_ref[...]), o_ref.shape) + o_ref[...] = op(x_ref[...]) - x = jnp.arange(math.prod(shape)).reshape(shape).astype(jnp.float32) - np.testing.assert_array_equal(kernel(x), jnp.sum(x)) + x = jnp.arange(math.prod(shape)).reshape(shape).astype(dtype) + if dtype != jnp.uint32: + x = x - x.size // 2 # include negative values. + np.testing.assert_array_equal(kernel(x), op(x)) def test_reshape(self): shape1, shape2 = (128,), (2, 16, 4) @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct(shape2, jnp.float32), + self.pallas_call, out_shape=jax.ShapeDtypeStruct(shape2, jnp.float32) ) def kernel(x_ref, out_ref): x_ref_reshaped = x_ref.reshape(shape2) @@ -217,14 +357,63 @@ def kernel(x_ref, out_ref): x = jnp.arange(math.prod(shape1)).astype(jnp.float32) np.testing.assert_array_equal(kernel(x), x.reshape(shape2)) - @parameterized.product(thread_semantics=[*plgpu.ThreadSemantics]) - def test_add_xy_indexed(self, thread_semantics): + def test_reshape_tiled(self): + shape1, shape2 = (6 * 64, 8), (2, 3, 64, 8) + @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([128], jnp.float32), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), + self.kernel, + out_shape=jax.ShapeDtypeStruct(shape2, jnp.float32), + ) + def kernel(x_ref, out_ref): + x = plgpu.load(x_ref, (), layout=plgpu.Layout.WGMMA, optimized=False) + out_ref[...] = x.reshape(shape2) + + x = jnp.arange(math.prod(shape1)).reshape(shape1).astype(jnp.float32) + np.testing.assert_array_equal(kernel(x), x.reshape(shape2)) + + def test_reshape_splat(self): + shape = (1, 1, 1) + + @functools.partial( + self.kernel, + out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), + ) + def kernel(out_ref): + x = jnp.array(42, dtype=jnp.float32) + out_ref[...] = x.reshape(shape) + + np.testing.assert_array_equal( + kernel(), jnp.array(42, dtype=jnp.float32).reshape(shape) + ) + + def test_slice_untiled_dim(self): + shape = (2, 3, 64, 8) + + @functools.partial( + self.kernel, + out_shape=jax.ShapeDtypeStruct(shape[2:], jnp.float32), + ) + def kernel(x_ref, out_ref): + x = plgpu.load(x_ref, (), layout=plgpu.Layout.WGMMA, optimized=False) + out_ref[...] = x[1, 1] + + x = jnp.arange(math.prod(shape)).reshape(shape).astype(jnp.float32) + np.testing.assert_array_equal(kernel(x), x[1, 1]) + + def test_squeeze_to_scalar(self): + @functools.partial( + self.kernel, + out_shape=jax.ShapeDtypeStruct((), jnp.float32), + ) + def kernel(out_ref): + x = _array_splat(42, (1, 1, 1)) + out_ref[...] = lax.squeeze(x, dimensions=(0, 1, 2)) + + np.testing.assert_array_equal(kernel(), jnp.array(42, dtype=jnp.float32)) + + def test_add_xy_indexed(self): + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.float32) ) def kernel(x_ref, y_ref, o_ref): idx = _sum_same_dtype(y_ref[...]) @@ -235,8 +424,9 @@ def kernel(x_ref, y_ref, o_ref): np.testing.assert_array_equal(kernel(x, y), x[jnp.sum(y)]) def test_add_one_grid(self): + @functools.partial( - pl.pallas_call, + self.pallas_call, in_specs=[pl.BlockSpec((128,), lambda *i: i)], out_specs=pl.BlockSpec((128,), lambda *i: i), out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.float32), @@ -249,9 +439,8 @@ def kernel(x_ref, o_ref): np.testing.assert_array_equal(kernel(x), x + 1.0) def test_add_one_grid_with_scratch(self): - @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.float32), in_specs=[pl.BlockSpec((128,), lambda *i: i)], out_specs=pl.BlockSpec((128,), lambda *i: i), @@ -269,11 +458,11 @@ def kernel(x_ref, o_ref, scratch_ref): def test_add_one_grid_pipelined(self, max_concurrent_steps): @functools.partial( - pl.pallas_call, + self.pallas_call, in_specs=[pl.BlockSpec((128, 16), lambda i, j: (i, j))], out_specs=pl.BlockSpec((128, 16), lambda i, j: (i, j)), out_shape=jax.ShapeDtypeStruct([128 * 2, 64], jnp.float32), - compiler_params=plgpu.GPUCompilerParams( + compiler_params=plgpu.CompilerParams( dimension_semantics=["parallel", "sequential"], max_concurrent_steps=max_concurrent_steps, ), @@ -285,13 +474,30 @@ def kernel(x_ref, o_ref): x = jnp.arange(128 * 2 * 64).reshape((128 * 2, 64)).astype(jnp.float32) np.testing.assert_array_equal(kernel(x), x + 1.0) + def test_add_one_grid_pipelined_with_leading_sequential_dimension(self): + @functools.partial( + self.pallas_call, + in_specs=[pl.BlockSpec((128, 16), lambda i, j: (i, j))], + out_specs=pl.BlockSpec((128, 16), lambda i, j: (i, j)), + out_shape=jax.ShapeDtypeStruct([128 * 2, 64], jnp.float32), + compiler_params=plgpu.CompilerParams( + dimension_semantics=["sequential", "parallel"], + ), + grid=(2, 4), + ) + def kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] + 1.0 + + x = jnp.arange(128 * 2 * 64).reshape((128 * 2, 64)).astype(jnp.float32) + np.testing.assert_array_equal(kernel(x), x + 1.0) + def test_add_one_grid_pipelined_program_id(self): @functools.partial( - pl.pallas_call, + self.pallas_call, out_specs=pl.BlockSpec((16, 16), lambda i, j: (i, j)), out_shape=jax.ShapeDtypeStruct([16, 64], jnp.int32), - compiler_params=plgpu.GPUCompilerParams( + compiler_params=plgpu.CompilerParams( dimension_semantics=["parallel", "sequential"], max_concurrent_steps=2, ), @@ -306,12 +512,13 @@ def kernel(o_ref): ) def test_add_one_grid_pipelined_sequential_invariant_output(self): + @functools.partial( - pl.pallas_call, + self.pallas_call, in_specs=[pl.BlockSpec((32, 16), lambda i, j: (i, j))], out_specs=pl.BlockSpec((32, 16), lambda i, j: (i, 0)), out_shape=jax.ShapeDtypeStruct([32 * 2, 64], jnp.float32), - compiler_params=plgpu.GPUCompilerParams( + compiler_params=plgpu.CompilerParams( dimension_semantics=["parallel", "sequential"], max_concurrent_steps=2, ), @@ -335,29 +542,181 @@ def kernel(x_ref, o_ref): @parameterized.parameters(jnp.float32, jnp.int32, jnp.uint32) def test_iota(self, dtype): dimension = 1 + @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct((128, 128), dtype), + self.pallas_call, out_shape=jax.ShapeDtypeStruct((128, 128), dtype) ) def kernel(o_ref): - o_ref[...] = plgpu.broadcasted_iota(dtype, (128, 128), dimension, layout=plgpu.Layout.WGMMA) + o_ref[...] = plgpu.broadcasted_iota( + dtype, o_ref.shape, dimension, layout=plgpu.Layout.WGMMA + ) - np.testing.assert_array_equal(kernel(), jax.lax.broadcasted_iota(dtype, (128, 128), dimension)) + np.testing.assert_array_equal( + kernel(), jax.lax.broadcasted_iota(dtype, (128, 128), dimension) + ) - @parameterized.product( - indexer=[..., slice(128), slice(None, 128)], - thread_semantics=[*plgpu.ThreadSemantics], + @parameterized.parameters(jnp.bfloat16, jnp.int16, jnp.uint16) + def test_inline_mgpu(self, jnp_type): + dtype = jnp.dtype(jnp_type) + is_signed = mgpu.utils.is_signed(dtype) + shape = (128, 128) + tile = (64, 128 // dtype.itemsize) + tiled_shape = list(mgpu.tile_shape(shape, tile)) + + key = jax.random.key(0) + x = jax.random.uniform(key, (2, *shape), minval=-10, maxval=10).astype( + dtype + ) + + transforms = ( + plgpu.TilingTransform(tile), + plgpu.SwizzleTransform(128), + ) + + if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Warpgroup: + pallas_call_transforms = () + else: + pallas_call_transforms = transforms + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct(shape, dtype), + in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), + scratch_shapes=[ + plgpu.SMEM( + x.shape, + dtype, + transforms=pallas_call_transforms, + ), + plgpu.Barrier(), + ], + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + ) + def kernel(x_ref, o_ref, smem_ref, barrier): + plgpu.copy_gmem_to_smem(x_ref, smem_ref, barrier) + plgpu.barrier_wait(barrier) + # Add an indexer at the end. + sliced_smem_ref = smem_ref.at[0] + @plgpu.inline_mgpu( + arg_types=(plgpu.RefType(( + plgpu.TilingTransform(tile), + plgpu.SwizzleTransform(128), + )),), + return_type=plgpu.ShapeDtypeStruct( + shape, dtype, layout=plgpu.Layout.WGMMA + ), + ) + def foo(ctx, smem_ref): + del ctx + assert smem_ref.type.shape == tiled_shape, (smem_ref.type, tiled_shape) + x = mgpu.FragmentedArray.load_tiled( + smem_ref, swizzle=128, is_signed=is_signed + ) + y = mgpu.FragmentedArray.splat( + mgpu.c(1, x.mlir_dtype), + shape=x.shape, + layout=x.layout, + is_signed=is_signed, + ) + return (x + x + y) + + arr = foo(sliced_smem_ref) + @plgpu.inline_mgpu(arg_types=(plgpu.Layout.WGMMA, plgpu.RefType(transforms), plgpu.RefType())) + def store(ctx, arr, smem_ref, o_ref): + sliced_smem_ref = mgpu.memref_slice(smem_ref, (0,)) + arr.store_tiled(sliced_smem_ref, swizzle=128) + mgpu.commit_shared() + ctx.async_copy( + src_ref=sliced_smem_ref, + dst_ref=o_ref, + swizzle=128, + gmem_transform=(mgpu.TileTransform(tile)), + ) + ctx.await_async_copy(0) + + # A dummy if statement to make sure we inline nested blocks correctly. + is_leader_thread = mgpu.utils.single_thread_predicate() + with mgpu.utils.when(is_leader_thread): + pass + + # This time we slice inside the inline_mgpu body. + store(arr, smem_ref, o_ref) + + np.testing.assert_array_equal(kernel(x), x[0] + x[0] + 1) + + @parameterized.parameters( + plgpu.Layout.WGMMA, + plgpu.Layout.WGMMA_UPCAST_2X, + plgpu.Layout.WGMMA_UPCAST_4X, + plgpu.Layout.TCGEN05, ) - def test_copy_smem_to_gmem(self, indexer, thread_semantics): + def test_inline_mgpu_layout_args(self, layout: gpu_core.SomeLayout): + quant_dtype = jnp.int8 + dtype = jnp.bfloat16 + mgpu_layout = layout.to_mgpu() + shape = (128, 128) + + rngs = list(jax.random.split(jax.random.key(0))) + x = jax.random.randint(rngs.pop(), shape, minval=-10, maxval=10).astype( + quant_dtype + ) + x_s = jax.random.uniform( + rngs.pop(), shape[0], minval=-100, maxval=100 + ).astype(dtype) + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct(shape, dtype), + in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM), + pl.BlockSpec(memory_space=plgpu.GMEM)), + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + ) + def kernel(x_ref, x_scale_ref, o_ref): + x = plgpu.load(x_ref, (), layout=layout, optimized=False).astype(x_scale_ref.dtype) + x_s = plgpu.load(x_scale_ref, (), layout=layout.reduce(1), optimized=False) + + @plgpu.inline_mgpu( + arg_types=(layout, layout.reduce(1)), + return_type=plgpu.ShapeDtypeStruct(shape, dtype, layout=layout), + ) + def custom_broadcast(ctx, x_fa, xs_fa): + del ctx + return xs_fa.broadcast_in_dim(shape, [0], layout=mgpu_layout) * x_fa + + o_ref[...] = custom_broadcast(x, x_s) + + np.testing.assert_array_equal( + kernel(x, x_s), + x.astype(dtype) * jnp.broadcast_to(x_s[:, None], x.shape), + ) + + def test_sync_copy(self): + shape = (128, 128) + transforms = self.default_transforms(dtype=jnp.float32) + @functools.partial( + self.pallas_call, + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + scratch_shapes=[plgpu.SMEM(shape, jnp.float32, transforms=transforms)], + ) + def kernel(x_ref, y_ref, scratch_ref): + layout = plgpu.Layout.SMEM_GMEM_COPY(shape, jnp.float32, swizzle=128) + # GMEM loads require optimized=False, because we can't prove coalescing. + # But with this layout they should be fast. + scratch_ref[...] = plgpu.load(x_ref, (), layout=layout, optimized=False) + y_ref[...] = plgpu.layout_cast(scratch_ref[...], layout) + + x = jnp.arange(math.prod(shape), dtype=jnp.float32).reshape(shape) + np.testing.assert_array_equal(kernel(x), x) + @parameterized.product(indexer=[..., slice(128), slice(None, 128)]) + def test_copy_smem_to_gmem(self, indexer): @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), scratch_shapes=[plgpu.SMEM((256,), jnp.float32)], - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), ) def kernel(x_ref, o_ref_gmem, scratch_ref): scratch_ref[...] = x_ref[...] + 1 @@ -368,6 +727,33 @@ def kernel(x_ref, o_ref_gmem, scratch_ref): x = jnp.arange(256).astype(jnp.float32) np.testing.assert_array_equal(kernel(x)[indexer], x[indexer] + 1.0) + @parameterized.parameters(jnp.bfloat16, jnp.float16, jnp.float32) + def test_copy_smem_to_gmem_reduction(self, dtype): + # TODO(b/415721295):Remove after the minimal jaxlib version is 0.8.2. + if not hasattr(mgpu.dialect, "TMAReduction"): + self.skip_if_wg_semantics() + + @functools.partial( + self.pallas_call, + grid=(200,), + in_specs=[pl.BlockSpec((128,), lambda *i: i), pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct([128], dtype), + scratch_shapes=[plgpu.SMEM((128,), dtype)], + input_output_aliases={1:0} + ) + def kernel(x_ref, o_ref_gmem, o_ref_gmem_alias, scratch_ref): + del o_ref_gmem_alias + scratch_ref[...] = x_ref[...] + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(scratch_ref.at[...], o_ref_gmem.at[...], reduction_op="add") + plgpu.wait_smem_to_gmem(0) + x = jnp.ones(200 * 128).astype(dtype) # 200 blocks + output = jnp.zeros(128).astype(dtype) + output = kernel(x, output) + output_val = x.reshape(-1, 128).sum(axis=0) + np.testing.assert_array_equal(output, output_val) + @parameterized.named_parameters( {"testcase_name": "1d_none", "shape": (256,), "indexers": (slice(0, 128), slice(None, 32))}, @@ -377,8 +763,9 @@ def kernel(x_ref, o_ref_gmem, scratch_ref): "shape": (64, 64), "indexers": (4, slice(0, 64))}, ) def test_copy_smem_to_gmem_with_multiple_gmem_indexers(self, shape, indexers): + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), scratch_shapes=[plgpu.SMEM(shape, jnp.float32)], @@ -402,13 +789,14 @@ def kernel(x_ref, o_ref_gmem, scratch_ref): @parameterized.product(indexer=[..., slice(128), slice(None, 128)]) def test_copy_gmem_to_smem(self, indexer): + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), scratch_shapes=[ plgpu.SMEM((256,), jnp.float32), - plgpu.Barrier(num_arrivals=1), + plgpu.Barrier(), ], ) def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): @@ -421,6 +809,68 @@ def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): x = jnp.arange(256).astype(jnp.float32) np.testing.assert_array_equal(kernel(x)[indexer], x[indexer] + 1.0) + def test_collective_copy_gmem_to_smem(self): + @functools.partial( + self.kernel, + out_shape=jax.ShapeDtypeStruct((2, 128), jnp.float32), + scratch_shapes=dict( + smem_ref=plgpu.SMEM((128,), jnp.float32), + barrier_ref=plgpu.Barrier(), + ), + cluster=(2,), + cluster_names=("cluster",), + ) + def kernel(x_ref, y_ref, smem_ref, barrier_ref): + # Specifying collective_axes will enable TMA multicast automatically. + plgpu.copy_gmem_to_smem( + x_ref, smem_ref, barrier_ref, collective_axes="cluster" + ) + plgpu.barrier_wait(barrier_ref) + plgpu.copy_smem_to_gmem(smem_ref, y_ref.at[jax.lax.axis_index("cluster")]) + plgpu.wait_smem_to_gmem(0) + + x = jnp.arange(128, dtype=jnp.float32) + y = kernel(x) + # Each block gets the same data and writes it out. + np.testing.assert_array_equal(y, jnp.stack([x, x], axis=0)) + + @parameterized.product(indexer=[..., slice(128), slice(None, 128)]) + def test_async_prefetch(self, indexer): + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct([256], jnp.float32), + in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), + scratch_shapes=[ + plgpu.SMEM((256,), jnp.float32), + plgpu.Barrier(), + ], + ) + def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): + plgpu.async_prefetch(x_ref_gmem.at[indexer]) + plgpu.copy_gmem_to_smem( + x_ref_gmem.at[indexer], scratch_ref.at[indexer], barrier_ref + ) + plgpu.barrier_wait(barrier_ref) + o_ref[...] = scratch_ref[...] + 1 + + x = jnp.arange(256).astype(jnp.float32) + np.testing.assert_array_equal(kernel(x)[indexer], x[indexer] + 1.0) + + def test_barrier_repeated_indexing(self): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((128,), jnp.float32), + scratch_shapes=[plgpu.Barrier(num_barriers=4)], + ) + def kernel(o_ref, barrier_ref): + b = barrier_ref.at[2:4].at[1:2] + plgpu.barrier_arrive(b) + plgpu.barrier_wait(barrier_ref.at[3]) + o_ref[...] = jnp.ones_like(o_ref) + + np.testing.assert_array_equal(kernel(), np.ones((128,), dtype=jnp.float32)) + @parameterized.named_parameters( { "testcase_name": "1d_none", @@ -447,13 +897,15 @@ def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): }, ) def test_copy_gmem_to_smem_with_multiple_gmem_indexers(self, shape, indexers): + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), - scratch_shapes=[plgpu.SMEM(shape, jnp.float32), - plgpu.Barrier(num_arrivals=1), - ], + scratch_shapes=[ + plgpu.SMEM(shape, jnp.float32), + plgpu.Barrier(), + ], grid=(1,), ) def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): @@ -478,12 +930,12 @@ def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): def test_gmem_to_smem_with_multiple_smem_indexers(self): x = jax.random.uniform(jax.random.key(0), (2, 64, 64), dtype=jnp.float32) @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([64, 64], jnp.float32), in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), scratch_shapes=[ plgpu.SMEM(x.shape, jnp.float32), - plgpu.Barrier(num_arrivals=1), + plgpu.Barrier(), ], ) def extract_x0(x_ref_gmem, o_ref, scratch_ref, barrier_ref): @@ -495,21 +947,27 @@ def extract_x0(x_ref_gmem, o_ref, scratch_ref, barrier_ref): np.testing.assert_array_equal(extract_x0(x), x[0]) def test_gmem_to_smem_with_multiple_smem_indexers_and_transforms(self): + transforms = self.default_transforms(dtype=jnp.int32) x = jnp.arange(512 * 512, dtype=jnp.int32).reshape(512, 512) @functools.partial( - pl.pallas_call, + self.pallas_call, grid=(4, 4), out_shape=jax.ShapeDtypeStruct((256, 128), jnp.int32), - in_specs=(plgpu.GPUBlockSpec( - block_shape=(128, 128), - index_map=lambda i, j: (i, j), - memory_space=plgpu.SMEM, - transforms=(plgpu.TilingTransform((64, 32)), - plgpu.SwizzleTransform(128))),), - out_specs=(plgpu.GPUBlockSpec( - block_shape=(64, 32), - index_map=lambda i, j: (i, j), - memory_space=plgpu.SMEM,)), + in_specs=( + plgpu.BlockSpec( + block_shape=(128, 128), + index_map=lambda i, j: (i, j), + memory_space=plgpu.SMEM, + transforms=transforms, + ), + ), + out_specs=( + plgpu.BlockSpec( + block_shape=(64, 32), + index_map=lambda i, j: (i, j), + memory_space=plgpu.SMEM, + ) + ), ) def kernel(x_ref, o_ref): x_sliced = x_ref.at[0:64, 32:96].at[:, 0:32] # get x_ref[0:64, 32:64] @@ -521,13 +979,14 @@ def kernel(x_ref, o_ref): @parameterized.product(indexer=[0, 1, 2, 3]) def test_copy_gmem_to_smem_with_indexed_barrier(self, indexer): + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.float32), in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), scratch_shapes=[ plgpu.SMEM((128,), jnp.float32), - plgpu.Barrier(num_arrivals=1, num_barriers=4), + plgpu.Barrier(num_barriers=4), ], ) def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): @@ -542,6 +1001,8 @@ def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): @parameterized.named_parameters(("_g2s", False), ("_s2g", True)) def test_copy_with_transforms(self, to_smem): + transforms = self.default_transforms(dtype=jnp.float32) + def kernel(x_ref, o_ref, barrier_ref): if to_smem: plgpu.copy_gmem_to_smem(x_ref, o_ref, barrier_ref) @@ -552,105 +1013,234 @@ def kernel(x_ref, o_ref, barrier_ref): plgpu.wait_smem_to_gmem(0) in_spec = pl.BlockSpec(memory_space=plgpu.GMEM) - out_spec = plgpu.GPUBlockSpec( - (128, 128), - lambda: (0, 0), - transforms=( - plgpu.TilingTransform((64, 32)), - plgpu.SwizzleTransform(128), - ), + out_spec = plgpu.BlockSpec( + transforms=transforms, memory_space=plgpu.SMEM, ) if not to_smem: in_spec, out_spec = out_spec, in_spec - f = pl.pallas_call( + f = self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct([128, 128], jnp.float32), in_specs=(in_spec,), out_specs=out_spec, - scratch_shapes=[plgpu.Barrier(num_arrivals=1)], + scratch_shapes=[plgpu.Barrier()], ) x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128) np.testing.assert_array_equal(f(x), x) def test_scoped_copy_with_transforms(self): - ts = (plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128)) + ts = self.default_transforms(dtype=jnp.float32) def kernel(x_ref, o_ref, barrier_ref): def body(tmp_ref): plgpu.copy_gmem_to_smem(x_ref, tmp_ref, barrier_ref) plgpu.barrier_wait(barrier_ref) o_ref[...] = tmp_ref[...] * 2 - pl.run_scoped(body, plgpu.SMEM((128, 128), jnp.float32, transforms=ts)) + pl.run_scoped(body, plgpu.SMEM((128, 64), jnp.float32, transforms=ts)) in_spec = pl.BlockSpec(memory_space=plgpu.GMEM) - out_spec = plgpu.GPUBlockSpec( - (128, 128), lambda: (0, 0), transforms=ts, memory_space=plgpu.SMEM, - ) - f = pl.pallas_call( + out_spec = plgpu.BlockSpec(transforms=ts, memory_space=plgpu.SMEM) + f = self.pallas_call( kernel, - out_shape=jax.ShapeDtypeStruct([128, 128], jnp.float32), + out_shape=jax.ShapeDtypeStruct([128, 64], jnp.float32), in_specs=(in_spec,), out_specs=out_spec, - scratch_shapes=[plgpu.Barrier(num_arrivals=1)], + scratch_shapes=[plgpu.Barrier()], ) - x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128) + x = jnp.arange(128 * 64, dtype=jnp.float32).reshape(128, 64) + np.testing.assert_array_equal(f(x), x * 2) + + @jtu.skip_if_mosaic_gpu_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell") + def test_scoped_copy_with_user_transforms(self): + self.skip_if_wg_semantics() + + def kernel(x_ref, o_ref, barrier_ref): + def body(tmp_ref): + tmp_ref = plgpu.unswizzle_ref(tmp_ref, 128) + tmp_ref = plgpu.untile_ref(tmp_ref, (8, 32)) + plgpu.copy_gmem_to_smem(x_ref, tmp_ref, barrier_ref) + plgpu.barrier_wait(barrier_ref) + o_ref[...] = tmp_ref[...] * 2 + pl.run_scoped(body, plgpu.SMEM((8, 4, 8, 32), jnp.float32)) + + in_spec = pl.BlockSpec(memory_space=plgpu.GMEM) + f = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct([64, 128], jnp.float32), + in_specs=(in_spec,), + scratch_shapes=[plgpu.Barrier()], + ) + x = jnp.arange(64 * 128, dtype=jnp.float32).reshape(64, 128) np.testing.assert_array_equal(f(x), x * 2) def test_copy_with_transforms_and_indexing(self): + self.skip_if_wg_semantics() + def kernel(x_ref, o_ref, barrier_ref): for i in range(2): plgpu.copy_gmem_to_smem(x_ref, o_ref.at[i], barrier_ref) plgpu.barrier_wait(barrier_ref) in_spec = pl.BlockSpec(memory_space=plgpu.GMEM) - out_spec = plgpu.GPUBlockSpec( - (2, 128, 128), - lambda: (0, 0, 0), + out_spec = plgpu.BlockSpec( transforms=( - plgpu.TilingTransform((64, 32)), + plgpu.TilingTransform((8, 32)), plgpu.TransposeTransform((0, 2, 1, 3, 4)), plgpu.SwizzleTransform(128), ), memory_space=plgpu.SMEM, ) - f = pl.pallas_call( + f = self.pallas_call( kernel, - out_shape=jax.ShapeDtypeStruct([2, 128, 128], jnp.float32), + out_shape=jax.ShapeDtypeStruct([2, 64, 128], jnp.float32), in_specs=(in_spec,), out_specs=out_spec, - scratch_shapes=[plgpu.Barrier(num_arrivals=1)], + scratch_shapes=[plgpu.Barrier()], ) - x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128) + x = jnp.arange(64 * 128, dtype=jnp.float32).reshape(64, 128) np.testing.assert_array_equal(f(x), np.stack([x, x], axis=0)) - def test_indexing_before_transpose(self): - def kernel(x_ref, o_ref, barrier_ref): - for i in range(2): - plgpu.copy_gmem_to_smem( - x_ref, plgpu.transpose_ref(o_ref.at[i], (1, 0, 2)), barrier_ref - ) - plgpu.barrier_wait(barrier_ref) - - in_spec = pl.BlockSpec(memory_space=plgpu.GMEM) - out_spec = plgpu.GPUBlockSpec( - (2, 64, 2, 128), lambda: (0, 0, 0, 0), memory_space=plgpu.SMEM, - ) - f = pl.pallas_call( - kernel, - out_shape=jax.ShapeDtypeStruct([2, 64, 2, 128], jnp.float32), - in_specs=(in_spec,), - out_specs=out_spec, - scratch_shapes=[plgpu.Barrier(num_arrivals=1)], + @parameterized.parameters( + ((),), + ((plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(128)),), + ( + ( + plgpu.TilingTransform((8, 32)), + plgpu.TransposeTransform((1, 0, 2, 3)), + plgpu.SwizzleTransform(128), + ), + ), + ) + def test_copy_gmem_to_smem_gather(self, transforms): + if not jtu.is_cuda_compute_capability_at_least("10.0"): + self.skipTest("Only works on a GPU with capability >= sm100") + if transforms: + # We cannot yet specify transforms on block specs for WG semantics. + self.skip_if_wg_semantics() + + # TODO(b/415721295): Remove when the minimum jaxlib version is 0.8.3. + if not hasattr(mgpu.dialect, "tma_gather_supported"): + self.skip_if_wg_semantics() + + dtype = jnp.int32 + out_shape = (64, 128) + shape = (128, 64 + out_shape[-1]) + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct(out_shape, dtype), + out_specs=plgpu.BlockSpec(memory_space=plgpu.SMEM, transforms=transforms), + in_specs=( + pl.BlockSpec(memory_space=plgpu.GMEM), + pl.BlockSpec(memory_space=plgpu.SMEM), + ), + scratch_shapes=[plgpu.Barrier()], ) - x = jnp.arange(2 * 64 * 128, dtype=jnp.float32).reshape(2, 64, 128) - xt = x.transpose((1, 0, 2)) - np.testing.assert_array_equal(f(x), np.stack([xt, xt], axis=0)) + def kernel(x_ref_gmem, idx_ref, o_ref, barrier_ref): + idxs = plgpu.load(idx_ref, (), layout=plgpu.Layout.TMA_GATHER_INDICES) + plgpu.copy_gmem_to_smem(x_ref_gmem.at[idxs, 64:], o_ref, barrier_ref) + plgpu.barrier_wait(barrier_ref) - def test_copy_gmem_to_smem_in_run_scoped(self): + x = jnp.arange(math.prod(shape)).reshape(shape).astype(dtype) + idx = jax.random.permutation(jax.random.key(1234), out_shape[0]).astype(jnp.uint32) + np.testing.assert_array_equal(kernel(x, idx), x[idx, 64:]) + + @parameterized.product( + src_transposed=(False, True), shape=((128, 128), (1, 128, 128)) + ) + def test_transposed_load_store(self, src_transposed, shape): + dtype = jnp.float32 + permutation = (0, 2, 1) if len(shape) == 3 else (1, 0) @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], jnp.float32), - in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), + self.kernel, + out_shape=jax.ShapeDtypeStruct(shape, dtype), + ) + def kernel(src_ref, dst_ref): + if src_transposed: + src_ref = plgpu.transpose_ref(src_ref, permutation) + src_layout = plgpu.Layout.WGMMA_TRANSPOSED + dst_layout = plgpu.Layout.WGMMA + else: + dst_ref = plgpu.transpose_ref(dst_ref, permutation) + src_layout = plgpu.Layout.WGMMA + dst_layout = plgpu.Layout.WGMMA_TRANSPOSED + src = plgpu.load(src_ref, (), layout=src_layout, optimized=False) + dst = plgpu.layout_cast(src, dst_layout) + dst_ref[...] = dst + + x = jnp.arange(math.prod(shape), dtype=dtype).reshape(shape) + np.testing.assert_array_equal(kernel(x), jnp.transpose(x, permutation)) + + @parameterized.product( + src_memory_space=[plgpu.SMEM, plgpu.GMEM], + layout=[plgpu.Layout.WG_STRIDED((128,), vec_size=1), None, + ] + ) + def test_load_to_strided_layout_with_indexing(self, src_memory_space, layout): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct([2, 128], jnp.float32), + in_specs=[pl.BlockSpec(memory_space=src_memory_space)], + out_specs=plgpu.BlockSpec(memory_space=plgpu.SMEM), + ) + def kernel(x_ref, o_ref): + for i in range(2): + x = plgpu.load(x_ref, (i,), layout=layout) + o_ref[i, ...] = x + + x = jnp.arange(2 * 128, dtype=jnp.float32).reshape(2, 128) + np.testing.assert_array_equal(kernel(x), x) + + @parameterized.parameters(None, plgpu.TilingTransform((32,))) + def test_smem_gmem_transposed_copies(self, tiling): + shape = (2, 2, 64) + transforms = (tiling,) if tiling is not None else () + if transforms: + self.skip_if_wg_semantics() # Can't specify user transforms under WG semantics. + + @functools.partial( + self.kernel, + out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), + scratch_shapes=[ + plgpu.SMEM(shape, jnp.float32, transforms=transforms), + plgpu.Barrier() + ], + ) + def kernel(src_ref, dst_ref, smem_ref, barrier_ref): + smem_ref = plgpu.transpose_ref(smem_ref, (1, 0, 2)) + plgpu.copy_gmem_to_smem(src_ref, smem_ref, barrier_ref) + plgpu.barrier_wait(barrier_ref) + plgpu.copy_smem_to_gmem(smem_ref, dst_ref) + + x = jnp.arange(math.prod(shape), dtype=jnp.float32).reshape(shape) + np.testing.assert_array_equal(kernel(x), x) + + def test_indexing_before_transpose(self): + def kernel(x_ref, o_ref, barrier_ref): + for i in range(2): + plgpu.copy_gmem_to_smem( + x_ref, plgpu.transpose_ref(o_ref.at[i], (1, 0, 2)), barrier_ref + ) + plgpu.barrier_wait(barrier_ref) + + in_spec = pl.BlockSpec(memory_space=plgpu.GMEM) + out_spec = plgpu.BlockSpec(memory_space=plgpu.SMEM) + f = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct([2, 32, 2, 128], jnp.float32), + in_specs=(in_spec,), + out_specs=out_spec, + scratch_shapes=[plgpu.Barrier()], + ) + x = jnp.arange(2 * 32 * 128, dtype=jnp.float32).reshape(2, 32, 128) + xt = x.transpose((1, 0, 2)) + np.testing.assert_array_equal(f(x), np.stack([xt, xt], axis=0)) + + def test_copy_gmem_to_smem_in_run_scoped(self): + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct([256], jnp.float32), + in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), ) def kernel(x_ref_gmem, o_ref): def body(barrier_ref): @@ -659,14 +1249,15 @@ def inner_body(scratch_ref): plgpu.barrier_wait(barrier_ref) o_ref[...] = scratch_ref[...] + 1 pl.run_scoped(inner_body, plgpu.SMEM((256,), jnp.float32)) - pl.run_scoped(body, plgpu.Barrier(num_arrivals=1)) + pl.run_scoped(body, plgpu.Barrier()) x = jnp.arange(256).astype(jnp.float32) np.testing.assert_array_equal(kernel(x), x + 1.0) def test_add_doubled_sum(self): + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.float32), ) def kernel(x_ref, o_ref): @@ -675,26 +1266,6 @@ def kernel(x_ref, o_ref): x = jnp.arange(128).astype(jnp.float32) np.testing.assert_array_equal(kernel(x), x + x.sum()*2) - @parameterized.named_parameters( - ("rsqrt", jax.lax.rsqrt, ), - ("log", jax.lax.log, 5e-7), - ("exp", jax.lax.exp, ), - ("exp2", jax.lax.exp2, 5e-7), - ("logistic", jax.lax.logistic, ), - ("tanh", jax.lax.tanh, 5e-7), - ) - def test_approx_math_unary_op(self, unary_op, rtol=1e-7): - @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([128], jnp.float32), - compiler_params=plgpu.GPUCompilerParams(approx_math=True), - ) - def kernel(x_ref, o_ref): - o_ref[...] = unary_op(x_ref[...]) - - x = jnp.arange(128).astype(jnp.float32) / 128 - np.testing.assert_allclose(kernel(x), unary_op(x), rtol=rtol, atol=1e-5) - @parameterized.product(input_factor=[0.001, 1, 10, 100, 100]) def test_layer_norm(self, input_factor): eps = 1e-5 @@ -702,7 +1273,7 @@ def test_layer_norm(self, input_factor): beta = 1.0 @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), ) def layer_norm(x_ref, o_ref): @@ -730,8 +1301,9 @@ def layer_norm_np(x): np.testing.assert_allclose(layer_norm(x), layer_norm_np(x), rtol=5e-5) def test_print(self): + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), ) def kernel(x_ref, o_ref): @@ -744,16 +1316,27 @@ def kernel(x_ref, o_ref): self.assertEqual(output(), "It works!\n") def test_print_wgmma_tiled_layout(self): - shape = (128, 64) + # The default printf buffer on some smaller GPUs (e.g. Thor) only has space for + # 4096 threads to printf (short) messages. Keep this shape below that. + shape = (128, 32) size = math.prod(shape) + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), + in_specs=[ + plgpu.BlockSpec( + transforms= self.default_transforms(dtype=jnp.float32), + ) + ], + ) def kernel(x_ref, o_ref): + del o_ref # Unused. pl.debug_print("prefix {}", x_ref[...]) - spec = plgpu.GPUBlockSpec(shape, lambda: (0, 0), transforms=(plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128))) - x = jnp.arange(size, dtype=jnp.float32).reshape(shape) - f = pl.pallas_call(kernel, out_shape=x, in_specs=[spec], out_specs=spec) + x = jnp.arange(size, dtype=jnp.float32).reshape(shape) with self.capture_stdout() as get_output: - jax.block_until_ready(f(x)) + jax.block_until_ready(kernel(x)) output = get_output() results = re.findall(r"prefix \[(\d+), (\d+)\]: (\d+).?\d*", output) @@ -764,7 +1347,7 @@ def kernel(x_ref, o_ref): def test_print_scalar(self): @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32), ) def kernel(x_ref, o_ref): @@ -779,7 +1362,7 @@ def kernel(x_ref, o_ref): def test_print_scalar_array(self): @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32), ) def kernel(x_ref, o_ref): @@ -796,7 +1379,7 @@ def test_print_array(self): in_shape = [2, 1, 64, 64] @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct(in_shape, jnp.int32), ) def kernel(x_ref, o_ref): @@ -809,11 +1392,85 @@ def kernel(x_ref, o_ref): self.assertIn("x: [1, 0, 43, 23]: 6871\n", output()) + def test_print_layout(self): + shape = (128,) + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct(shape, jnp.bfloat16), + ) + def kernel(x_ref, o_ref): + del o_ref + x = plgpu.layout_cast(x_ref[...], plgpu.Layout.WGMMA_ROW) + plgpu.print_layout("x: {}", x) + + x = jnp.arange(math.prod(shape), dtype=jnp.bfloat16).reshape(shape) + with self.capture_stdout() as output: + jax.block_until_ready(kernel(x)) + + self.assertIn("x: WGMMA_ROW\n", output()) + + @parameterized.parameters( + (plgpu.TilingTransform((1, 32)), plgpu.SwizzleTransform(128)), + (plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(128)), + (), + ) + def test_get_swap_with_transforms(self, *transforms): + self.skip_if_wg_semantics() + + shape = (128, 128) + + @functools.partial( + self.pallas_call, + in_specs=[plgpu.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=plgpu.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct(shape, jnp.int32), + scratch_shapes=[ + plgpu.SMEM(shape, jnp.int32, transforms=tuple(transforms)), + plgpu.Barrier(), + ] + ) + def kernel(x_ref, o_ref, scratch_ref, barrier_ref): + plgpu.copy_gmem_to_smem(x_ref, scratch_ref, barrier_ref) + plgpu.barrier_wait(barrier_ref) + scratch_ref[...] = scratch_ref[...] * 2 + plgpu.copy_smem_to_gmem(scratch_ref, o_ref) + plgpu.wait_smem_to_gmem(0) + + x = jnp.arange(math.prod(shape), dtype=jnp.int32).reshape(shape) + np.testing.assert_array_equal(kernel(x), x * 2) + + def test_swap_scalar_constant(self): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((), jnp.int32), + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + ) + def kernel(o_ref): + o_ref[...] = jnp.array(42) + + np.testing.assert_array_equal(kernel(), jnp.array(42, jnp.int32)) + + def test_check(self): + self.enter_context(pl.enable_debug_checks(True)) + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct([256], jnp.int32), + ) + def kernel(x_ref, o_ref): + pl.debug_check(_sum_same_dtype(x_ref[...]) > 0, "x.sum() is negative") + o_ref[...] = x_ref[...] + + x = jnp.arange(256, dtype=jnp.int32) + np.testing.assert_array_equal(kernel(x), x) + def test_load_scalar(self): + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct((128,), jnp.int32), - in_specs=[plgpu.GPUBlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM)], + in_specs=[plgpu.BlockSpec(memory_space=plgpu.GMEM)], ) def kernel(x_ref, o_ref): o_ref[...] = jnp.broadcast_to(x_ref[10], (128,)) @@ -821,9 +1478,11 @@ def kernel(x_ref, o_ref): np.testing.assert_array_equal(kernel(jnp.arange(11, dtype=jnp.int32)), jnp.full((128,), 10, dtype=jnp.int32)) - @parameterized.product(thread_semantics=[*plgpu.ThreadSemantics]) - def test_run_scoped(self, thread_semantics): - + def test_run_scoped(self): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + ) def kernel(x_ref, o_ref): def body(tmp_ref): self.assertEqual(tmp_ref.shape, (8, 128)) @@ -834,20 +1493,32 @@ def body(tmp_ref): self.assertEqual(tmp.shape, (8, 128)) o_ref[...] = tmp - inp = np.ones((8, 128), jnp.float32) - f = pl.pallas_call( - kernel, - out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), + x = np.ones((8, 128), jnp.float32) + np.testing.assert_array_equal(kernel(x), x + 1.0) + + def test_run_scoped_in_cond(self): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct([256], jnp.int32), + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=pl.BlockSpec(memory_space=plgpu.SMEM), ) - o = f(inp) - np.testing.assert_array_equal(o, inp + 1.0) + def kernel(x_ref_gmem, o_ref): + def scoped_kernel(barrier_ref): + plgpu.copy_gmem_to_smem(x_ref_gmem, o_ref, barrier_ref) + plgpu.barrier_wait(barrier_ref) + + def branch(): + pl.run_scoped(scoped_kernel, plgpu.Barrier()) + + jax.lax.cond(x_ref_gmem[0] % 2 == 0, branch, branch) + + x = jnp.full((256,), 1234, dtype=jnp.int32) + np.testing.assert_array_equal(kernel(x), x) def test_program_id(self): @functools.partial( - pl.pallas_call, + self.pallas_call, in_specs=(), out_specs=pl.BlockSpec((128,), lambda *i: i), out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.int32), @@ -866,7 +1537,7 @@ def test_program_id_in_squashed_grid(self): # 3 CUDA grid dimensions. grid = (2, 3, 4, 5) @functools.partial( - pl.pallas_call, + self.pallas_call, in_specs=(), out_specs=pl.BlockSpec((1,) * len(grid) + (128,), lambda *i: (*i, 0)), out_shape=jax.ShapeDtypeStruct([*grid, 128], jnp.int32), @@ -885,9 +1556,58 @@ def kernel(o_ref): jnp.arange(math.prod(grid), dtype=jnp.int32).reshape(*grid) ) + @parameterized.parameters( + ((2, 3), ("a", "b"), (), ()), + ((2, 3), ("a", "b"), (2,), ("x",)), + ((2, 3, 4), ("a", "b", "c"), (), ()), + ((2, 3, 4), ("a", "b", "c"), (2,), ("x",)), + ((2, 3, 4), ("a", "b", "c"), (2, 3), ("x", "y")), + ((2, 3, 4, 5), ("a", "b", "c", "d"), (), ()), + ((2, 3, 4, 5), ("a", "b", "c", "d"), (2,), ("x",)), + ) + def test_axis_indices_in_grid(self, grid, grid_names, cluster, cluster_names): + @functools.partial( + self.kernel, + out_shape=[ + jax.ShapeDtypeStruct([*cluster, *grid, 128], jnp.int32), + jax.ShapeDtypeStruct([*cluster, *grid, 128], jnp.int32) + ], + grid=grid, + grid_names=grid_names, + cluster=cluster, + cluster_names=cluster_names, + ) + def kernel(out1_ref, out2_ref): + pallas_grid_idx = lax.axis_index(grid_names) + cuda_grid_idx = _get_linearized_cuda_grid_index() + + out_indices = [lax.axis_index(ax) for ax in (*cluster_names, *grid_names)] + out1_ref[*out_indices] = jnp.full((128,), pallas_grid_idx) + out2_ref[*out_indices] = jnp.full((128,), cuda_grid_idx) + out1, out2 = kernel() + + out_per_cta = jnp.arange(math.prod(grid), dtype=jnp.int32).reshape(grid) + out1_ref = jnp.broadcast_to(out_per_cta[..., None], (*cluster, *grid, 128)) + np.testing.assert_array_equal(out1, out1_ref) + + padded_cluster = (1,) * (len(grid) - len(cluster)) + cluster + scaled_grid = tuple(g * c for g, c in zip(grid, padded_cluster)) + original = jnp.arange(math.prod(scaled_grid), dtype=jnp.int32).reshape( + scaled_grid + ) + + # Untile the scaled grid to get the per-cluster grid. + interleaved_shape = tuple(val for pair in zip(grid, padded_cluster) for val in pair) + perm = tuple(range(1, 2 * len(grid), 2)) + tuple(range(0, 2 * len(grid), 2)) + + out2_ref = original.reshape(interleaved_shape).transpose(perm).squeeze() + out2_ref = jnp.broadcast_to(out2_ref[..., None], out2_ref.shape + (128,)) + + np.testing.assert_array_equal(out2, out2_ref) + def test_program_id_in_block_spec(self): @functools.partial( - pl.pallas_call, + self.pallas_call, in_specs=(pl.BlockSpec((2, 128), lambda i: (pl.program_id(0), i)),), out_specs=pl.BlockSpec((2, 128), lambda i: (pl.program_id(0), i)), out_shape=jax.ShapeDtypeStruct([2, 128], jnp.int32), @@ -901,7 +1621,7 @@ def kernel(x_ref, o_ref): def test_num_programs(self): @functools.partial( - pl.pallas_call, + self.pallas_call, in_specs=(), out_specs=pl.BlockSpec((128,), lambda *i: i), out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.int32), @@ -916,17 +1636,13 @@ def kernel(o_ref): ) def test_swizzled_blockspec_shapes(self): - - spec = plgpu.GPUBlockSpec( + spec = plgpu.BlockSpec( (128, 64), lambda *i: i, - transforms=( - plgpu.TilingTransform((64, 64)), - plgpu.SwizzleTransform(128), - ), + transforms=self.default_transforms(dtype=jnp.float16), ) @functools.partial( - pl.pallas_call, + self.pallas_call, in_specs=[spec], out_specs=spec, out_shape=jax.ShapeDtypeStruct((128, 128), jnp.float16), @@ -939,30 +1655,38 @@ def kernel(x_ref, o_ref): x = jnp.arange(128 * 128).astype(jnp.float16).reshape(128, 128) np.testing.assert_array_equal(kernel(x), x) - @parameterized.product( - force_while=[False, True], thread_semantics=[*plgpu.ThreadSemantics] - ) - def test_fori_loop_array(self, force_while, thread_semantics): + @parameterized.product(force_while=[False, True]) + def test_fori_loop_array(self, force_while): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], jnp.int32), - compiler_params=plgpu.GPUCompilerParams(thread_semantics=thread_semantics), + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32) ) def kernel(x_ref, o_ref): # Equivalent to x_ref[...] + 2 + 3. - o_ref[...] = _fori_loop(force_while, 2, 4, lambda i, x: x + i, x_ref[...]) + o_ref[...] = _fori_loop( + force_while, 2, 4, lambda i, x: x + i, x_ref[...] + ) x = jnp.arange(256, dtype=jnp.int32) np.testing.assert_array_equal(kernel(x), x + 2 + 3) - @parameterized.product( - force_while=[False, True], thread_semantics=[*plgpu.ThreadSemantics] - ) - def test_fori_loop_scalar(self, force_while, thread_semantics): + @parameterized.product(unroll=[1, 2, 4]) + def test_fori_loop_array_unrolled(self, unroll): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], jnp.int32), - compiler_params=plgpu.GPUCompilerParams(thread_semantics=thread_semantics), + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32) + ) + def kernel(x_ref, o_ref): + # Equivalent to x_ref[...] + 2 + 3 + 4 + 5. + o_ref[...] = lax.fori_loop( + 2, 6, lambda i, x: x + i, x_ref[...], unroll=unroll + ) + + x = jnp.arange(256, dtype=jnp.int32) + np.testing.assert_array_equal(kernel(x), x + 2 + 3 + 4 + 5) + + @parameterized.product(force_while=[False, True]) + def test_fori_loop_scalar(self, force_while): + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32) ) def kernel(o_ref): # Equivalent to 2 + 3. @@ -974,9 +1698,8 @@ def kernel(o_ref): np.testing.assert_array_equal(kernel(), jnp.full([256], 5, jnp.int32)) def test_fori_loop_dynamic_bounds(self): - @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32), grid=(1,) ) @@ -989,16 +1712,10 @@ def kernel(o_ref): np.testing.assert_array_equal(kernel(), jnp.full([256], 5, dtype=jnp.int32)) - @parameterized.product( - force_while=[False, True], thread_semantics=[*plgpu.ThreadSemantics] - ) - def test_fori_loop_tuple(self, force_while, thread_semantics): + @parameterized.product(force_while=[False, True]) + def test_fori_loop_tuple(self, force_while): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], jnp.int32), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32) ) def kernel(o_ref): def body(step, xs): @@ -1017,16 +1734,11 @@ def body(step, xs): kernel(), jnp.full([256], 3 * (0 + 1), jnp.int32) ) - @parameterized.product( - force_while=[False, True], thread_semantics=[*plgpu.ThreadSemantics] - ) - def test_fori_loop_indexed_store(self, force_while, thread_semantics): + @parameterized.product(force_while=[False, True]) + def test_fori_loop_indexed_store(self, force_while): @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([4, 128], jnp.float32), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), ) def kernel(x_ref, y_ref, o_ref): def body(idx, _): @@ -1039,17 +1751,9 @@ def body(idx, _): y = x + 1 np.testing.assert_array_equal(kernel(x, y), x + y) - @parameterized.product(thread_semantics=[*plgpu.ThreadSemantics]) - def test_while_loop(self, thread_semantics): - if thread_semantics == plgpu.ThreadSemantics.Warpgroup: - self.skipTest("WG lowering does not support reduce_sum_p needed for this test") - + def test_while_loop(self): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([128], jnp.int32), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), + self.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.int32) ) def kernel(x_ref, o_ref): o_ref[...] = jnp.zeros(o_ref.shape, dtype=jnp.int32) @@ -1071,7 +1775,7 @@ def body(acc): def test_while_loop_layout_mismatch(self): @functools.partial( - pl.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.int32) + self.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.int32) ) def kernel(o_ref): def cond(acc): @@ -1079,23 +1783,33 @@ def cond(acc): def body(acc): del acc # Unused. + o_ref[...] = o_ref[...] # side-effect to prevent DCE # We deliberately do a cast here to trigger a layout mismatch. return plgpu.layout_cast( jnp.zeros(o_ref.shape, o_ref.dtype), plgpu.Layout.WGMMA_ROW ) - - _ = jax.lax.while_loop(cond, body, o_ref[...]) - - with self.assertRaisesRegex(ValueError, "has layout .*, when it should be"): - kernel() - - @parameterized.parameters([*plgpu.ThreadSemantics]) - def test_cond(self, thread_semantics): + # Cast explicitly to cause the mismatch, otherwise layout inference will + # succeed at constructing a working program. + strided_input = plgpu.layout_cast( + o_ref[...], plgpu.Layout.WG_STRIDED(shape=(128,), vec_size=1) + ) + _ = jax.lax.while_loop(cond, body, strided_input) + + if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Warpgroup: + with self.assertRaisesRegex( + ValueError, "Failed to infer a possible set of layouts", + ): + kernel() + else: + with self.assertRaisesRegex( + ValueError, "has layout .*, when it should be" + ): + kernel() + + def test_cond(self): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], jnp.int32), - compiler_params=plgpu.GPUCompilerParams(thread_semantics=thread_semantics), + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32) ) def kernel(x_ref, o_ref): jax.lax.cond( @@ -1111,27 +1825,49 @@ def kernel(x_ref, o_ref): self.assertIn("acc % 2", output()) - @parameterized.parameters([*plgpu.ThreadSemantics]) - def test_cond_returning_array(self, thread_semantics): + def test_cond_returning_array(self): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], jnp.int32), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32) ) def kernel(x_ref, o_ref): - acc = _sum_same_dtype(x_ref[...]) + acc_sum = _sum_same_dtype(x_ref[...]) acc2, acc = jax.lax.cond( - acc % 2 == 0, - lambda: (acc * 2, acc), - lambda: (acc, acc * 2), + acc_sum % 2 == 0, + lambda: (acc_sum * 2, x_ref[...]), + lambda: (acc_sum, x_ref[...]), ) - o_ref[...] = jnp.broadcast_to(acc + acc2, o_ref.shape) + o_ref[...] = jnp.broadcast_to(_sum_same_dtype(acc) + acc2, o_ref.shape) x = jnp.arange(256, dtype=jnp.int32) np.testing.assert_array_equal(kernel(x), jnp.broadcast_to(jnp.sum(x) * 3, [256])) + def test_tile_slicing(self): + # Not testing with warpgroup semantics, because we want to enforce a layout. + self.skip_if_wg_semantics() + + shape = (256, 128) + block_spec = plgpu.BlockSpec( + transforms=self.default_transforms(dtype=jnp.uint16) + ) + @functools.partial( + self.pallas_call, + in_specs=[block_spec], + out_specs=block_spec, + out_shape=jax.ShapeDtypeStruct((64, 64), jnp.uint16), + ) + def kernel(x_ref, o_ref): + def sum_tiles(row, acc): + row_slice = pl.ds(row * 64, 64) + for col in range(128 // 64): + acc += x_ref[row_slice, pl.ds(col * 64, 64)] + return acc + acc = plgpu.layout_cast(jnp.zeros((64, 64), jnp.uint16), plgpu.Layout.WGMMA) + o_ref[...] = _fori_loop(False, 0, 256 // 64, sum_tiles, acc) + + x = jnp.arange(math.prod(shape), dtype=jnp.uint16).reshape(shape) + y = x.reshape(256 // 64, 64, 128 // 64, 64).sum(axis=(0, 2), dtype=jnp.uint16) + np.testing.assert_array_equal(kernel(x), y) + def test_input_output_aliases(self): # Note that we're writing to the input pointer, which should alias b_ptr. def kernel(a_ref, b_ref): @@ -1139,10 +1875,10 @@ def kernel(a_ref, b_ref): a_ref[...] = jnp.ones_like(a_ref) a = np.zeros((64, 64), dtype=jnp.float32) - b = pl.pallas_call( + b = self.pallas_call( kernel, - in_specs=[plgpu.GPUBlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM)], - out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM), + in_specs=[plgpu.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=plgpu.BlockSpec(memory_space=plgpu.GMEM), input_output_aliases={0: 0}, out_shape=a, )(a) @@ -1159,22 +1895,17 @@ def rotate(src, dst): dst[lower, left] = src[lower, right] x = jnp.arange(128 * 128).astype(jnp.float16).reshape(128, 128) - spec = plgpu.GPUBlockSpec( - (128, 128), - lambda: (0, 0), - transforms=( - plgpu.TilingTransform((64, 64)), - plgpu.SwizzleTransform(128), - ), + spec = plgpu.BlockSpec( + transforms=self.default_transforms(dtype=jnp.float16) ) - f = pl.pallas_call(rotate, out_shape=x, in_specs=[spec], out_specs=spec) + f = self.pallas_call(rotate, out_shape=x, in_specs=[spec], out_specs=spec) expected = np.empty_like(x) rotate(x, expected) np.testing.assert_array_equal(f(x), expected) def test_layout_cast(self, shape=(256, 64)): @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), ) def kernel(o_ref): @@ -1183,6 +1914,32 @@ def kernel(o_ref): x = jnp.full(shape, 42.0, jnp.float32) np.testing.assert_array_equal(kernel(), x) + @parameterized.product( + layouts=[ + (plgpu.Layout.WGMMA, plgpu.Layout.WGMMA_TRANSPOSED), + (plgpu.Layout.TCGEN05, plgpu.Layout.TCGEN05_TRANSPOSED), + ], + ) + def test_transposed_layout(self, layouts): + layout, transposed_layout = layouts + dtype = jnp.dtype(jnp.float16) + shape = (256, 192) + transforms = self.default_transforms(dtype=dtype) + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct(shape[::-1], dtype), + out_specs=plgpu.BlockSpec(transforms=transforms), + ) + def kernel(o_ref): + iota = plgpu.broadcasted_iota(dtype, shape, 0, layout=layout) + iota *= shape[1] + iota += plgpu.broadcasted_iota(dtype, shape, 1, layout=layout) + o_ref_t = plgpu.transpose_ref(o_ref, (1, 0)) + o_ref_t[...] = plgpu.layout_cast(iota, transposed_layout) + + x = jnp.arange(math.prod(shape), dtype=dtype).reshape(shape).T + np.testing.assert_array_equal(kernel(), x) + def test_profiler(self): def kernel(x_ref, o_ref): with jax.named_scope("add"): @@ -1193,17 +1950,17 @@ def kernel(x_ref, o_ref): o_ref[...] = o with tempfile.TemporaryDirectory() as tmpdir: x = jnp.arange(256).astype(jnp.float32) - y = pl.pallas_call( + y = self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), - compiler_params=plgpu.GPUCompilerParams( + compiler_params=plgpu.CompilerParams( profile_space=16, profile_dir=tmpdir ), )(x) jax.block_until_ready(y) jax.effects_barrier() [name] = os.listdir(tmpdir) - with open(os.path.join(tmpdir, name), "r") as f: + with open(os.path.join(tmpdir, name)) as f: data = f.read() self.assertEqual(data.count('"name": "add"'), 2) self.assertEqual(data.count('"name": "load"'), 2) @@ -1221,20 +1978,13 @@ def kernel(x_ref, o_ref): (jnp.uint32, jnp.int32), (jnp.int32, jnp.uint32), ], - thread_semantics=[*plgpu.ThreadSemantics], ) - def test_bitcast_convert_type(self, dtypes, thread_semantics): + def test_bitcast_convert_type(self, dtypes): in_dtype, out_dtype = dtypes m, n = 16, 8 out_shape = jax.ShapeDtypeStruct((m, n), out_dtype) - @functools.partial( - pl.pallas_call, - out_shape=out_shape, - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), - ) + @functools.partial(self.pallas_call, out_shape=out_shape) def convert(x_ref, y_ref): y_ref[...] = jax.lax.bitcast_convert_type(x_ref[...], out_shape) @@ -1243,96 +1993,1082 @@ def convert(x_ref, y_ref): convert(x), jax.lax.bitcast_convert_type(x, out_dtype) ) + def test_optimization_barrier(self): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((128,), jnp.float32), + ) + def kernel(x_ref, o_ref): + o_ref[...] = lax.optimization_barrier(x_ref[...]) -class PallasCallSm90ATest(PallasSm90ATest): + x = jax.lax.iota(jnp.float32, 128) + np.testing.assert_array_equal(kernel(x), x) - @parameterized.parameters(False, True) - def test_fori_loop_accumulator(self, force_while): - transforms = (plgpu.TilingTransform((64, 64)), plgpu.SwizzleTransform(128)) + def test_optimization_barrier_multiple_inputs(self): @functools.partial( - pl.pallas_call, - in_specs=[plgpu.GPUBlockSpec((64, 64), lambda: (0, 0), transforms=transforms)], - out_shape=jax.ShapeDtypeStruct((64, 64), jnp.float16), - out_specs=plgpu.GPUBlockSpec((64, 64), lambda: (0, 0)), + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((128,), jnp.float32), ) - def kernel(i_ref, o_ref): - def scope(acc_ref): - return _fori_loop(force_while, 0, 4, lambda _, v: v + acc_ref[...], acc_ref[...]) - o_ref[...] = pl.run_state(scope)(plgpu.ACC.init(i_ref[...])) + def kernel(x_ref, y_ref, o_ref): + x, y = lax.optimization_barrier([x_ref[...], y_ref[...]]) + o_ref[...] = x + y - acc_ini = jnp.ones((64, 64), dtype=jnp.float16) - np.testing.assert_array_equal(kernel(acc_ini), jnp.full((64, 64), 5, dtype=jnp.float16)) + x = jax.lax.iota(jnp.float32, 128) + y = jax.lax.iota(jnp.float32, 128) * 3 + np.testing.assert_array_equal(kernel(x, y), x + y) - def test_realistic_matmul(self): - dtype = jnp.float16 - swizzle = 128 - elems_128b = swizzle // jnp.dtype(dtype).itemsize - grid_m, grid_k, grid_n = 132, 10, 4 - tile_m = tile_n = 128 - tile_k = elems_128b - m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n - def kernel(a_ref, b_ref, o_ref, acc_ref): - # Make sure tiling does not alter the shape of references - assert a_ref.shape == (tile_m, tile_k) - assert b_ref.shape == (tile_k, tile_n) - assert o_ref.shape == acc_ref.shape == (tile_m, tile_n) - plgpu.wgmma(acc_ref, a_ref, b_ref) - is_last_step = pl.program_id(2) == grid_k - 1 - @pl.when(is_last_step) - def _epilogue(): - o_ref[...] = acc_ref[...].astype(dtype) - plgpu.wgmma_wait(1) # We don't await the last WGMMA, hence delay_release=1 + def test_smem_aliasing_works_basic(self): + self.skip_if_wg_semantics() - key1, key2 = jax.random.split(jax.random.key(42), 2) - a = jax.random.uniform(key1, shape=(m, k), dtype=dtype) - b = jax.random.uniform(key2, shape=(k, n), dtype=dtype) + in_shape = (2, 256) - res = pl.pallas_call( - kernel, - in_specs=[ - plgpu.GPUBlockSpec( - (tile_m, tile_k), - lambda m, n, k: (m, k), - transforms=( - plgpu.TilingTransform((64, elems_128b)), - plgpu.SwizzleTransform(128), - ), - ), - plgpu.GPUBlockSpec( - (tile_k, tile_n), - lambda m, n, k: (k, n), - transforms=( - plgpu.TilingTransform((elems_128b, elems_128b)), - plgpu.SwizzleTransform(128), - ), - ), + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct([128], jnp.float32), + in_specs=[pl.BlockSpec(in_shape)], + out_specs=pl.BlockSpec((128,), memory_space=plgpu.GMEM), + scratch_shapes=[ + plgpu.RefUnion( + # Note: this test exposes internals that we don't particularly + # want to phold for the sake of testing the functionality of the + # API. It's expected that this test might end up breaking in the + # future, e.g. if we decide to change our alignment requirements + # on SMEM refs---and that's OK. Users should explicitly NOT rely + # on this exact behaviour. + # + # Use a value larger than the number of bytes used for SMEM + # alignment (1024) in order to make sure that the second ref + # in the second group aliases the single ref in the first group. + plgpu.SMEM(in_shape, jnp.float32), + [ + plgpu.SMEM((256,), jnp.bfloat16), + # Add an arbitrary level of nesting to make sure that we + # support PyTrees. + [ + plgpu.SMEM( + (128,), + jnp.float32, + transforms=(plgpu.TilingTransform((64,)),)), + ] + ], + ) ], - out_specs=plgpu.GPUBlockSpec( - (tile_m, tile_n), - lambda m, n, k: (m, n), - transforms=( - plgpu.TilingTransform((64, elems_128b)), - plgpu.SwizzleTransform(128), - ), - ), - out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16), - scratch_shapes=[plgpu.ACC((tile_m, tile_n), jnp.float32)], - grid=(grid_m, grid_n, grid_k), - compiler_params=plgpu.GPUCompilerParams( - dimension_semantics=["parallel", "parallel", "sequential"], - max_concurrent_steps=2, - delay_release=1, - ), - )(a, b) - np.testing.assert_allclose(res, a @ b, rtol=1e-3) + ) + def kernel(x_ref, o_ref128, aliased_ref): + smem_ref256, [_, [smem_ref128]] = aliased_ref + # Ensure that extraction via index works the same as unfolding. + smem_ref128_2 = aliased_ref[1][1][0] + self.assertIsInstance(smem_ref128, state_types.TransformedRef) + self.assertIsInstance(smem_ref128_2, state_types.TransformedRef) + self.assertIs(smem_ref128.ref, smem_ref128_2.ref) + self.assertEqual(smem_ref128.transforms, smem_ref128_2.transforms) + extract_alias_transform, tile_transform = smem_ref128.transforms + # Ensure that the transforms provided in the scratch shapes have been + # passed correctly. + self.assertIsInstance(extract_alias_transform, gpu_core.ExtractAliasedRef) + self.assertIsInstance(tile_transform, gpu_core.UntileRef) + smem_ref256[...] = x_ref[...] + 1 + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(smem_ref128, o_ref128) - @parameterized.parameters(jnp.float16, jnp.float32) - def test_wgmma(self, dtype): - # TensorCores can only fuse transposes of 16-bit values, and RHS - # is expected to be column major by default. - rhs_transpose = jnp.dtype(dtype).itemsize != 2 - swizzle = 128 + x = jnp.arange(512).astype(jnp.float32) + np.testing.assert_array_equal( + kernel(x.reshape(in_shape)).reshape((128,)), x[256 : 256 + 128] + 1 + ) + + def test_smem_aliasing_works_with_subbyte_dtypes(self): + self.skip_if_wg_semantics() + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct([256], jnp.uint4), + in_specs=[pl.BlockSpec((128,))], + out_specs=pl.BlockSpec((256,), memory_space=plgpu.GMEM), + scratch_shapes=[ + plgpu.RefUnion( + # Note: this test exposes internals that we don't particularly + # want to phold for the sake of testing the functionality of the + # API. It's expected that this test might end up breaking in the + # future, e.g. if we decide to change our alignment requirements + # on SMEM refs---and that's OK. Users should explicitly NOT rely + # on this exact behaviour. + # + # This allocation scheme is a bit complicated, but serves to + # test that + # 1. Refs are aligned correctly (currently to 1024 bytes); + # 2. (u)int4 references are not allocated more than 1 byte per + # 2 elements. + # The first group of refs serves to create two allocations, each + # aligned to 1024 bytes. The second group serves to create two + # allocations where the first one is exactly 1024 bytes, + # assuming 1 byte per 2 uint4 elements. As a result, if our + # implementation is correct, the second allocation of the second + # group should exactly alias the second allocation of the first + # group. + [ + plgpu.SMEM((128,), jnp.int8), + plgpu.SMEM((128,), jnp.int8), + ], + [plgpu.SMEM((2048,), jnp.uint4), plgpu.SMEM((256,), jnp.uint4)], + ) + ], + ) + def kernel(x_ref, o_refi4, aliased_ref): + [_, smem_refi8], [_, smem_refi4] = aliased_ref + smem_refi8[...] = x_ref[...] + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(smem_refi4, o_refi4) + + def unpack_i4_as_i8(x): + x = x.reshape((128, 1)) + x_high = x >> 4 + x_low = x & 0xF + return jnp.concatenate([x_low, x_high], axis=-1).reshape((256,)) + + x = jnp.arange(128).astype(jnp.int8) + test_as_i8 = jax.lax.convert_element_type(kernel(x), new_dtype=jnp.int8) + np.testing.assert_array_equal(test_as_i8[:256], unpack_i4_as_i8(x)) + + def test_smem_aliasing_works_for_quantization(self): + self.skip_if_wg_semantics() + shape = (64, 256) + large_ty, small_ty = jnp.bfloat16, jnp.uint4 + large_swizzle = plgpu.SwizzleTransform(64 * jnp.finfo(large_ty).bits // 8) + small_swizzle = plgpu.SwizzleTransform(64 * jnp.iinfo(small_ty).bits // 8) + tiling = plgpu.TilingTransform((8, 64)) + + def kernel(x_gmem, o_gmem): + return pl.run_scoped( + functools.partial(scoped_kernel, x_gmem, o_gmem), + plgpu.RefUnion( + plgpu.SMEM(shape, large_ty, transforms=(tiling, large_swizzle)), + plgpu.SMEM(shape, small_ty, transforms=(tiling, small_swizzle)) + ), + plgpu.Barrier(num_barriers=1), + ) + + def scoped_kernel(x_gmem, o_gmem, aliased_ref, barrier): + ref_large_ty, ref_small_ty = aliased_ref + plgpu.copy_gmem_to_smem(x_gmem, ref_small_ty, barrier=barrier) + plgpu.barrier_wait(barrier) + ref_large_ty[...] = ref_small_ty[...].astype(ref_large_ty.dtype) * 3 + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(ref_large_ty, o_gmem) + plgpu.wait_smem_to_gmem(0, wait_read_only=True) + + kernel_fn = self.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct(shape, large_ty), + grid=(1, 1), + ) + key = jax.random.key(42) + x = jax.random.randint(key, shape, 0, 4).astype(small_ty) + expected = x * 3 + np.testing.assert_array_equal(kernel_fn(x), expected) + + def test_assigning_to_ref_union_raises(self): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct([128], jnp.float32), + in_specs=[pl.BlockSpec((128,))], + out_specs=pl.BlockSpec((128,), memory_space=plgpu.GMEM), + scratch_shapes=[plgpu.RefUnion(plgpu.SMEM((128,), jnp.float32))], + ) + def kernel(x_ref, o_ref128, aliased_ref): + aliased_ref[...] = x_ref[...] + 1 + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(aliased_ref, o_ref128) + + with self.assertRaisesRegex(ValueError, "can't be assigned to"): + kernel(jnp.arange(128).astype(jnp.float32)) + + def test_loading_from_ref_union_works(self): + self.skip_if_wg_semantics() # Transform inference not implemented. + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct([128], jnp.float32), + in_specs=[pl.BlockSpec((128,))] * 2, + out_specs=pl.BlockSpec((128,), memory_space=plgpu.GMEM), + scratch_shapes=[plgpu.RefUnion(plgpu.SMEM((128,), jnp.float32)), + plgpu.SMEM((128,), jnp.float32)], + ) + def kernel(x_ref, y_ref, o_ref128, ref_union, o_smem): + [aliased_ref] = ref_union + aliased_ref[...] = x_ref[...] + plgpu.commit_smem() + load_ref = lambda r: plgpu.load(r, (), layout=plgpu.Layout.TCGEN05_ROW) + # This is a regression test for b/423697560, where we used to fail to + # transform the dtype correctly when processing an aliased ref. + o_smem[...] = load_ref(aliased_ref) + load_ref(y_ref) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(o_smem, o_ref128) + + x, y = (jnp.arange(128).astype(jnp.float32) for _ in range(2)) + np.testing.assert_array_equal(kernel(x, y), x + y) + + @parameterized.parameters(1, 2, 3) + def test_nd_loop_with_carry(self, sm_steps): + @functools.partial( + self.kernel, + out_shape=( + jax.ShapeDtypeStruct((sm_steps, 132, 128), jnp.int32), + jax.ShapeDtypeStruct((132,), jnp.int32) + ), + grid=(132,), + grid_names=("sm",), + ) + def kernel(o_ref, steps_ref): + def body(loop_info, carry): + idx = loop_info.index + assert len(idx) == 3 + # We need to use `mode="clip"`, because the indices are not static. + flat_idx = jnp.ravel_multi_index(idx, (sm_steps, 4, 33), mode="clip") + sm_step = lax.div( + flat_idx, lax.convert_element_type(lax.axis_size("sm"), jnp.int32) + ) + o_ref[sm_step, lax.axis_index("sm")] = lax.broadcast( + flat_idx, o_ref.shape[-1:] + ) + return carry + 1 + + steps_ref[lax.axis_index("sm")] = plgpu.nd_loop( + (sm_steps, 4, 33), collective_axes="sm", init_carry=0 + )(body) + + result, steps = kernel() # pylint: disable=unpacking-non-sequence + for sm_step in range(sm_steps): + np.testing.assert_array_equal(steps, jnp.full((132,), sm_steps)) + + np.testing.assert_array_equal( + result[sm_step], + jnp.tile( + (132 * sm_step + jnp.arange(132))[:, None], + 128, + ), + ) + + @parameterized.product( + sm_steps=(1, 2, 3), + tiling=(None, 1, 2, 4), + ) + def test_nd_loop(self, sm_steps: int, tiling: int | None): + if tiling is not None: + tiling = (sm_steps, tiling, 33) + @functools.partial( + self.kernel, + out_shape=jax.ShapeDtypeStruct((sm_steps, 132, 128), jnp.int32), + grid=(132,), + grid_names=("sm",), + ) + def kernel(o_ref): + @plgpu.nd_loop((sm_steps, 4, 33), tiling=tiling, collective_axes="sm") + def _(loop_info): + idx = loop_info.index + assert len(idx) == 3 + # We need to use `mode="clip"`, because the indices are not static. + grid = (sm_steps, 4, 33) + if tiling: + # Reconstruct the tiled grid and index. + tiled_grid = tuple(g // t for g, t in zip(grid, tiling)) + grid = tiled_grid + tiling + tile_idx = tuple( + lax.div(idx, jnp.int32(t)) for idx, t in zip(idx, tiling)) + subtile_idx = tuple( + lax.rem(idx, jnp.int32(t)) for idx, t in zip(idx, tiling)) + idx = tile_idx + subtile_idx + flat_idx = jnp.ravel_multi_index(idx, grid, mode="clip") + sm_step = lax.div( + flat_idx, lax.convert_element_type(lax.axis_size("sm"), jnp.int32) + ) + o_ref[sm_step, lax.axis_index("sm")] = lax.broadcast( + flat_idx, o_ref.shape[-1:] + ) + + result = kernel() + for sm_step in range(sm_steps): + np.testing.assert_array_equal( + result[sm_step], + jnp.tile((132 * sm_step + jnp.arange(132))[:, None], 128), + ) + + def test_lowering_error_context(self): + def body(x_ref, y_ref, barrier): + plgpu.copy_gmem_to_smem(x_ref, y_ref, barrier) + plgpu.barrier_wait(barrier) + + x = jnp.arange(127, dtype=jnp.int4) # Size is not a multiple of bytes + offending_line = "plgpu.copy_gmem_to_smem(x_ref, y_ref, barrier)" + try: + self.pallas_call( + body, + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=pl.BlockSpec(memory_space=plgpu.SMEM), + out_shape=x, + scratch_shapes=[plgpu.Barrier()], + )(x) + except: + # assertRaisesRegex raises does not let us match the traceback. + self.assertIn(offending_line, traceback.format_exc()) + else: + self.fail("Should have raised an exception") + + def test_lower_with_abstract_mesh(self): + def kernel(y_ref, sem): + plgpu.semaphore_signal_multicast(sem, collective_axes='x') + # Wait for the multicast signal (each device gets signaled by all devices) + pl.semaphore_wait(sem, 2) # Wait for signals from both devices + y_ref[...] = jnp.ones_like(y_ref) + + kernel_jax = pl.pallas_call( + kernel, + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + scratch_shapes=[plgpu.SemaphoreType.REGULAR], + ) + abstract_mesh = jax.sharding.AbstractMesh((2,), ('x',)) + jax.jit(jax.shard_map( + kernel_jax, mesh=abstract_mesh, in_specs=(), + out_specs=jax.P(), check_vma=False)).trace().lower( + lowering_platforms=('gpu',)) # doesn't crash + + @parameterized.named_parameters( + ( + f"_{''.join(map(str, collective_dims))}={collective_size}{'_' + ''.join(map(str, noncollective_dims)) if noncollective_dims else ''}", + collective_dims, + noncollective_dims, + collective_size, + ) + for collective_dims in itertools.chain.from_iterable( + itertools.combinations("xyz", n) for n in range(1, 4) + ) + for noncollective_dims in itertools.chain.from_iterable( + itertools.combinations("xyz", n) for n in range(3) + ) + for collective_size in (1, 2, 4) + if all(d not in noncollective_dims for d in collective_dims) + ) + def test_tma_load_multicast(self, collective_dims, noncollective_dims, collective_dim_size): + """ + 1. Broadcast a GMEM slice to SMEM across collective CTAs. + 2. Send a SMEM slice from each collective CTA to reconstruct the GMEM slice. + It's not strictly necessary to use every collective CTA, but we use them + to test that the cluster axes are used correctly. + """ + + self.skip_if_wg_semantics() # User transforms are not supported. + + dtype = jnp.float16 + cluster = [1, 1, 1] + for d in collective_dims: + cluster["xyz".index(d)] = collective_dim_size + for d in noncollective_dims: + cluster["xyz".index(d)] = 2 + if math.prod(cluster) > jtu.get_cuda_nonportable_max_cluster_size(): + self.skipTest("Cluster is too big.") + + collective_size = math.prod(cluster["xyz".index(d)] for d in collective_dims) + noncollective_size = math.prod(cluster) // collective_size + + swizzle = 128 + swizzle_elems = swizzle // jnp.dtype(dtype).itemsize + transforms = ( + plgpu.TilingTransform((8, swizzle_elems)), + plgpu.SwizzleTransform(swizzle), + ) + shape = (noncollective_size, collective_size * 8, swizzle_elems) + + def body(x_gmem, out_gmem, smem, tma_barrier): + # Compute the index in a subset of the cluster. + def cluster_id(axes): + idx, stride = 0, 1 + for d in sorted(axes): + idx += lax.axis_index(d) * stride + stride *= lax.axis_size(d) + return idx + + noncollective_idx = cluster_id(noncollective_dims) + collective_idx = cluster_id(collective_dims) + + plgpu.copy_gmem_to_smem( + x_gmem.at[noncollective_idx], + smem, + tma_barrier, + collective_axes=collective_dims) + plgpu.barrier_wait(tma_barrier) + + plgpu.commit_smem() + collective_slice = pl.ds(8 * collective_idx, 8) + plgpu.copy_smem_to_gmem( + smem.at[collective_slice], + out_gmem.at[noncollective_idx, collective_slice, :], + ) + plgpu.wait_smem_to_gmem(0) + + x = np.arange(np.prod(shape), dtype=dtype).reshape(shape) + kernel = self.kernel( + body, + grid=cluster, + grid_names=("grid_x", "grid_y", "grid_z"), + cluster=cluster, + cluster_names=("x", "y", "z"), + out_shape=jax.ShapeDtypeStruct(shape, dtype), + scratch_shapes=( + plgpu.SMEM(shape[1:], dtype, transforms=transforms), + plgpu.Barrier(), + ) + ) + np.testing.assert_array_equal(kernel(x), x) + + @parameterized.product( + layout=( + plgpu.Layout.WGMMA, + plgpu.Layout.TCGEN05, + plgpu.Layout.TCGEN05_TMEM_NATIVE, + plgpu.Layout.TCGEN05_M64_COLLECTIVE(128), + plgpu.Layout.TILED( # WGMMA, but defined as a custom tiling. + plgpu.Tiling(((64, 8), (16, 8), (8, 8), (2,))), + warp_dims=(-7,), + lane_dims=(-3, -2), + vector_dim=-1, + ), + # To have some layout with vector length of 1 + plgpu.Layout.TCGEN05_TMEM_NATIVE(1), + # To have some layout with vector length > 2 + plgpu.Layout.TCGEN05_TMEM_NATIVE(4), + ), + op=(jnp.sum, jnp.max, jnp.min), + # TODO(apaszke): Add support for f8 (MLIR/LLVM barfs at the moment). + dtype=(jnp.float32, jnp.float16, jnp.bfloat16, jnp.int32, jnp.uint32), + ) + def test_reduce_with_layout(self, layout, op, dtype): + if layout == plgpu.Layout.TCGEN05_M64_COLLECTIVE(128): + self.skip_if_wg_semantics() # cross-warp reductions are not supported. + axis = -1 + @functools.partial( + self.kernel, + out_shape=jnp.zeros((128,), dtype), + ) + def kernel(x_ref, y_ref): + x_val = plgpu.load(x_ref, (), layout=layout, optimized=False) + y_ref[...] = op(x_val, axis=axis) + + if jnp.issubdtype(dtype, jnp.integer): + x = jnp.arange(128 * 128).reshape((128, 128)).astype(dtype) + if dtype == jnp.int32: + x = x - x.size // 2 # include negative values. + else: + x = jax.random.uniform(jax.random.key(0), shape=(128, 128), dtype=dtype) + x_result = jax.block_until_ready(kernel(x)) + np.testing.assert_allclose(x_result, op(x, axis=axis), atol=5e-5) + + def test_cross_warp_reduction(self): + self.skip_if_wg_semantics() # cross-warp reductions are not supported. + @functools.partial( + self.kernel, + out_shape=jnp.zeros((128,), jnp.float32), + ) + def kernel(x_ref, y_ref): + layout = plgpu.Layout.TCGEN05_TMEM_NATIVE(4) + x_val = plgpu.load(x_ref, (), layout=layout, optimized=False) + y_ref[...] = jnp.sum(x_val, axis=0) + + x = jax.random.uniform(jax.random.key(0), shape=(128, 128), dtype=jnp.float32) + np.testing.assert_allclose(kernel(x), jnp.sum(x, axis=0), atol=5e-5) + + def _test_broadcast_in_dim_base(self, shape, layout, *, axis, hint): + assert len(shape) == 2 + + @functools.partial( + self.kernel, + out_shape=jnp.zeros(shape, jnp.float32), + ) + def kernel(x_ref, y_ref): + reduced_layout = layout.reduce(axis) + reduced = plgpu.load(x_ref, (), layout=reduced_layout, optimized=False) + broadcasted = lax.broadcast_in_dim(reduced, shape, [1 - axis]) + if hint: + broadcasted = plgpu.layout_cast(broadcasted, layout) + # Note that without the hint, the layout of broadcasted is not guaranteed + # to be the same as the layout argument! + y_ref[...] = broadcasted + + x = jax.random.uniform(jax.random.key(0), shape=(128,), dtype=jnp.float32) + x_result = jax.block_until_ready(kernel(x)) + expected = jnp.expand_dims(x, axis=axis) + expected = jnp.broadcast_to(expected, shape) + np.testing.assert_array_equal(x_result, expected) + + @parameterized.product( + layout=( + plgpu.Layout.WGMMA, + plgpu.Layout.TCGEN05, + plgpu.Layout.TCGEN05_TMEM_NATIVE, + plgpu.Layout.TCGEN05_M64_COLLECTIVE(128), + ), + axis=(0, 1), + hint=(True, False), + ) + def test_broadcast_in_dim(self, layout, axis, hint): + self._test_broadcast_in_dim_base((128, 128), layout, axis=axis, hint=hint) + + # Regression test for a crash when using a small shape. + def test_broadcast_in_dim_does_not_crash_on_small_shape(self): + shape = (128, 4) + self._test_broadcast_in_dim_base( + shape, plgpu.Layout.TCGEN05_TMEM_NATIVE, axis=1, hint=False + ) + + def test_broadcast_in_dim_wg_strided_majormost_dim(self): + self.skip_if_wg_semantics() + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((256, 128), jnp.float32), + ) + def kernel(x_ref, y_ref): + to_be_broadcasted = plgpu.load( + x_ref, (), layout=plgpu.Layout.WG_STRIDED((128,), 1) + ) + broadcasted = lax.broadcast_in_dim(to_be_broadcasted, (256, 128), (1,)) + y_ref[...] = broadcasted + + result = jax.random.uniform(jax.random.key(0), shape=(128,), dtype=jnp.float32) + np.testing.assert_array_equal(kernel(result), jnp.broadcast_to(result[None,:], (256, 128))) + + @parameterized.parameters( + ((4, 128),), + ((2, 4, 128),), + ) + def test_broadcast_wg_strided_majormost_dim(self, out_shape): + self.skip_if_wg_semantics() # Lowering not implemented. + dtype = jnp.float32 + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct(out_shape, dtype) + ) + def kernel(x_ref, side_load_ref, y_ref): + x_strided = plgpu.load( + x_ref, (), layout=plgpu.Layout.WG_STRIDED((128,), vec_size=1) + ) + side_load_strided = plgpu.load( + side_load_ref, (), layout=plgpu.Layout.WG_STRIDED(out_shape, vec_size=1) + ) + for _ in range(len(out_shape) - 1): + x_strided = x_strided[None, ...] + y_ref[...] = x_strided + side_load_strided[...] + + inp = jax.random.uniform(jax.random.key(0), (128,), dtype) + side_load = jax.random.uniform(jax.random.key(1), out_shape, dtype) + np.testing.assert_array_equal(kernel(inp, side_load), + jnp.broadcast_to(inp, out_shape) + side_load) + + def test_broadcast_in_dim_tcgen05_native_layout(self): + @functools.partial( + self.kernel, + out_shape=jnp.zeros((128, 128), jnp.float32), + num_threads=1, + thread_name="x", + ) + def kernel(x_ref, y_ref): + reduced = plgpu.load(x_ref, (), layout=plgpu.Layout.TCGEN05_TMEM_NATIVE.reduce(1), optimized=False) + broadcasted = lax.broadcast_in_dim(reduced, (128, 128), [0]) + broadcasted = plgpu.layout_cast(broadcasted, plgpu.Layout.TCGEN05_TMEM_NATIVE) + y_ref[...] = broadcasted + + x = jax.random.uniform(jax.random.key(0), shape=(128,), dtype=jnp.float32) + np.testing.assert_array_equal(kernel(x), jnp.broadcast_to(x[:, None], (128, 128))) + + @parameterized.named_parameters((l.name.lower(), l) for l in plgpu.Layout) + @jtu.skip_if_mosaic_gpu_exceeds_shared_memory( + device_patterns=("RTX PRO 6000 Blackwell", "GB10$")) + def test_copy_layout(self, layout): + if layout in { + plgpu.Layout.WG_SPLAT, + plgpu.Layout.WGMMA_TRANSPOSED, + plgpu.Layout.TCGEN05_TRANSPOSED, + plgpu.Layout.TILED + }: + self.skipTest("Not the right layout for this test") + + # We don't infer optimized transfer-compatible transforms for load to + # registers with TCGEN05_TMEM_NATIVE layout. + # TODO(allanrenucci): Manually specify transforms when supported for WG + # lowering semantic. + optimized = ( + self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane + or layout != plgpu.Layout.TCGEN05_TMEM_NATIVE + ) and layout != plgpu.Layout.TCGEN05_M64_COLLECTIVE_NATIVE + + shape = (128, 128) if "tcgen05" in layout.name.lower() else (64, 128) + dtype = jnp.float32 + swizzle = 128 + if layout in (plgpu.Layout.WGMMA_UPCAST_4X, plgpu.Layout.WGMMA_UPCAST_2X): + dtype = jnp.float8_e5m2 + swizzle = 64 + transforms = self.default_transforms(dtype=dtype, swizzle=swizzle) + + if layout == plgpu.Layout.TCGEN05_M64_COLLECTIVE: + layout = plgpu.Layout.TCGEN05_M64_COLLECTIVE(128) + elif layout == plgpu.Layout.TCGEN05_M64_COLLECTIVE_NATIVE: + layout = plgpu.Layout.TCGEN05_M64_COLLECTIVE_NATIVE(128) + if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane: + self.skipTest("Need to add support for optimized= for stores") + elif layout == plgpu.Layout.WG_STRIDED: + layout = plgpu.Layout.WG_STRIDED(shape, 2) + transforms = () + elif layout == plgpu.Layout.SMEM_GMEM_COPY: + layout = plgpu.Layout.SMEM_GMEM_COPY(shape, jnp.float32, swizzle=128) + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct(shape, dtype), + in_specs=[plgpu.BlockSpec(transforms=transforms)], + out_specs=plgpu.BlockSpec(transforms=transforms), + ) + def kernel(x_ref, o_ref): + o_ref[...] = plgpu.load(x_ref, (), layout=layout, optimized=optimized) + + x = jnp.arange(math.prod(shape), dtype=dtype).reshape(shape) + np.testing.assert_array_equal(kernel(x), x) + + @parameterized.parameters( + (((0, 0),), (128, 128), (128, 128)), + (((0, 1),), (128, 128), (128, 128)), + (((1, None),), (128, 128), (128,)), + (((0, 0),), (128, 128), (128, 128)), + (((0, 0), (0, 0)), (128, 128), (128, 128)), + ) + def test_vmap_kernel(self, vmap_axes, x_shape, y_shape): + rng0, rng1 = jax.random.split(jax.random.key(0)) + x = jax.random.uniform(rng0, x_shape, jnp.float32) + y = jax.random.uniform(rng1, y_shape, jnp.float32) + + out_shape = list(x_shape) + for x_axis, _ in vmap_axes: + del out_shape[x_axis] + out_shape = jax.ShapeDtypeStruct(out_shape, jnp.float32) + + @functools.partial(self.kernel, out_shape=out_shape) + def f(x_ref, y_ref, o_ref): + o_ref[...] = x_ref[...] + y_ref[...] + + f_ref = lambda x, y: x + y + for in_axes in vmap_axes: + f = jax.vmap(f, in_axes) + f_ref = jax.vmap(f_ref, in_axes) + + np.testing.assert_array_equal(f(x, y), f_ref(x, y)) + + def test_discharge_comms_effect(self): + def body(out, sem): + pl.semaphore_signal(sem, device_id=jnp.asarray(2, jnp.int32)) + + f = self.kernel( + body, + out_shape=jax.ShapeDtypeStruct((), jnp.int32), + scratch_shapes=[plgpu.SemaphoreType.REGULAR], + ) + jax_core.check_jaxpr(jax.make_jaxpr(f)().jaxpr) + + @jtu.thread_unsafe_test() # Modifies ``os.environ``. + @jtu.skip_under_pytest("Test fails under pytest in CI") + def test_line_info(self): + self.skip_if_wg_semantics() + + with jtu.set_env(MOSAIC_GPU_DUMP_PTX="1"), jtu.capture_stdout() as output: + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct([256], jnp.float32), + ) + def kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] + x_ref[0] + + jax.block_until_ready(kernel(jnp.arange(256, dtype=jnp.float32))) + + ptx = output() + self.assertIn(".file", ptx) + self.assertIn(".loc", ptx) + [path] = re.findall(r'.file\s+\d+\s+"(.+)"', ptx) + self.assertEndsWith(__file__, path) + + def test_collective_arrival_count(self): + def kernel(dst, collective_barrier): + plgpu.barrier_arrive(collective_barrier) + plgpu.barrier_arrive(collective_barrier) + plgpu.barrier_arrive(collective_barrier) + plgpu.barrier_arrive(collective_barrier) + plgpu.barrier_wait(collective_barrier) + dst[...] = jnp.ones_like(dst) + y = self.kernel( + kernel, + out_shape=jax.ShapeDtypeStruct((128,), jnp.int32), + scratch_shapes=[plgpu.ClusterBarrier(collective_axes=("x",), num_arrivals=4)], + cluster=(2,), + cluster_names=("x",) + )() + np.testing.assert_array_equal(y, np.ones((), dtype=np.int32)) + + def test_replicated_layout(self): + shape = (32,) + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), + ) + def kernel(src_ref, dst_ref): + layout = plgpu.Layout.TILED( + plgpu.Tiling(((32,), (1,))), + warp_dims=(plgpu.Replicated(4),), + lane_dims=(-2,), + vector_dim=-1, + ) + dst_ref[...] = plgpu.load(src_ref, (), layout=layout) + src = jnp.arange(shape[0], dtype=jnp.float32) + np.testing.assert_array_equal(kernel(src), src) + + +class PallasCallWarpPrimitiveSemanticsTest(PallasTest): + def setUp(self): + super().setUp() + if self.LOWERING_SEMANTICS != plgpu.LoweringSemantics.Lane: + self.skipTest("Test only works on Lane semantics") + + def test_axis_index(self): + warp_mesh = plgpu.WarpMesh(axis_name="warp") + @functools.partial(self.kernel, + out_shape=jax.ShapeDtypeStruct((2, 128), jnp.int32)) + def kernel(y_ref): + def scope(ones_smem_ref, threes_smem_ref): + # Prepare data to copy. + ones_smem_ref[:] = jnp.ones((1, 128), jnp.int32) + threes_smem_ref[:] = jnp.ones((1, 128), jnp.int32) * 3 + plgpu.commit_smem() + @pl.core_map(warp_mesh) + def _(): + warp_id = lax.axis_index("warp") + # We cannot load/store inside of core_map, so we issue async + # copies instead to produce a testable result. + @pl.when(warp_id == 1) + def _(): + plgpu.async_prefetch(y_ref.at[1:2]) + plgpu.copy_smem_to_gmem(ones_smem_ref, y_ref.at[0:1]) + @pl.when(warp_id == 3) + def _(): + plgpu.copy_smem_to_gmem(threes_smem_ref, y_ref.at[1:2]) + plgpu.wait_smem_to_gmem(0) + pl.run_scoped(scope, + plgpu.SMEM((1, 128), jnp.int32), + plgpu.SMEM((1, 128), jnp.int32) + ) + result = kernel() + expected = jnp.stack((jnp.ones((128,), jnp.int32), + jnp.ones((128,), jnp.int32) * 3), axis=0) + np.testing.assert_array_equal(result, expected) + + def test_scalar_load(self): + warp_mesh = plgpu.WarpMesh(axis_name="warp") + @functools.partial(self.kernel, + out_shape=jax.ShapeDtypeStruct((), jnp.int32)) + def kernel(x_ref, y_ref): + @pl.core_map(warp_mesh) + def _(): + warp_id = lax.axis_index("warp") + @pl.when(warp_id == 1) + def _(): + y_ref[...] = x_ref[...] + np.testing.assert_array_equal(kernel(4), 4) + + def test_non_scalar_load_raises(self): + warp_mesh = plgpu.WarpMesh(axis_name="warp") + @functools.partial(self.kernel, + out_shape=jax.ShapeDtypeStruct((2,), jnp.int32)) + def kernel(x_ref, y_ref): + @pl.core_map(warp_mesh) + def _(): + warp_id = lax.axis_index("warp") + @pl.when(warp_id == 1) + def _(): + y_ref[...] = x_ref[...] + with self.assertRaisesRegex(ValueError, "Can only load scalars",): + kernel(jnp.ones((2,), jnp.int32)) + + @parameterized.parameters( + lax.add, lax.sub, lax.mul, lax.div, lax.rem, lax.bitwise_and, + lax.bitwise_or, lax.bitwise_xor, lax.max, lax.min, + lax.gt, lax.lt, lax.ge, lax.le, lax.eq, lax.ne, + ) + def test_scalar_binary_op(self, op): + warp_mesh = plgpu.WarpMesh(axis_name="warp") + @functools.partial(self.kernel, + out_shape=jax.ShapeDtypeStruct((), jnp.int32)) + def kernel(y_ref): + @pl.core_map(warp_mesh) + def _(): + warp_id = lax.axis_index("warp") + @pl.when(warp_id == 1) + def _(): + x = jnp.array(1234, dtype=jnp.int32) + y = jnp.array(6543, dtype=jnp.int32) + y_ref[...] = op(x, y).astype(jnp.int32) + result = kernel() + x = jnp.array(1234, dtype=jnp.int32) + y = jnp.array(6543, dtype=jnp.int32) + np.testing.assert_array_equal(result, op(x, y).astype(jnp.int32)) + + def test_errors_when_closing_over_array(self): + # We currently do not allow closing over arrays when mapping over + # a mesh, since we would need to present a view of the array local + # to each warp. + warp_mesh = plgpu.WarpMesh(axis_name="warp") + @functools.partial(self.kernel, + out_shape=jax.ShapeDtypeStruct((32, 32), jnp.float32), + scratch_shapes=[plgpu.SMEM((32, 32), jnp.float32)]) + def kernel(out_ref, smem_ref): + arr = jnp.ones((32, 32), dtype=jnp.float32) + @pl.core_map(warp_mesh) + def _(): + smem_ref[...] = arr + 1 + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(smem_ref, out_ref) + plgpu.wait_smem_to_gmem(0) + with self.assertRaisesRegex( + mgpu_lowering.LoweringError, + "Can only close over scalars and Refs .* with WarpMesh", + ): + kernel() + + @parameterized.parameters(True, False) + def test_single_warp_loop(self, force_while): + warp_mesh = plgpu.WarpMesh(axis_name="warp") + @functools.partial(self.kernel, + out_shape=jax.ShapeDtypeStruct((10, 128), jnp.int32)) + def kernel(y_ref): + def scope(smem_ref): + # Prepare data to copy. + for i in range(10): + smem_ref[i, :] = jnp.ones_like(smem_ref.at[i]) * i + plgpu.commit_smem() + @pl.core_map(warp_mesh) + def _(): + warp_id = lax.axis_index("warp") + @pl.when(warp_id == 0) + def _(): + def loop_body(i, _): + _slice = pl.ds(i, 1) + plgpu.copy_smem_to_gmem(smem_ref.at[_slice], y_ref.at[_slice]) + _fori_loop(force_while, 0, 10, loop_body, None) + plgpu.wait_smem_to_gmem(0) + pl.run_scoped(scope, plgpu.SMEM((10, 128), jnp.int32)) + result = kernel() + expected = jnp.stack( + [jnp.ones((128,), jnp.int32) * i for i in range(10)], axis=0) + np.testing.assert_array_equal(result, expected) + + def test_debug_print(self): + warp_mesh = plgpu.WarpMesh(axis_name="warp") + @functools.partial( + self.kernel, + out_shape=jnp.zeros(128, np.int32), + ) + def kernel(ref): + ref[...] = ref[...] # Prevent kernel from being DCE'd + @pl.core_map(warp_mesh) + def _(): + warp_id = lax.axis_index("warp") + pl.debug_print("warp: {}", warp_id) + + with self.capture_stdout() as output: + jax.block_until_ready(kernel()) + self.assertEqual( + set(output().splitlines()), + { + "warp: 0", + "warp: 1", + "warp: 2", + "warp: 3", + }, + ) + + @parameterized.parameters(False, True) + def test_copy_gmem_to_smem_from_different_warps(self, + wait_smem_to_gmem_in_warp): + # In this test, we issue a copy from from warp 0 and await it in warp 1. + warp_mesh = plgpu.WarpMesh(axis_name="warp") + @functools.partial(self.kernel, + out_shape=jax.ShapeDtypeStruct((32, 32), jnp.float32)) + def kernel(x_ref, y_ref): + def scope(smem_ref, tma_barrier): + @pl.core_map(warp_mesh) + def _(): + warp_id = lax.axis_index("warp") + @pl.when(warp_id == 0) + def _(): + plgpu.copy_gmem_to_smem(x_ref.at[32:64], smem_ref, tma_barrier) + + @pl.when(warp_id == 1) + def _(): + plgpu.barrier_wait(tma_barrier) + plgpu.copy_smem_to_gmem(smem_ref, y_ref) + if wait_smem_to_gmem_in_warp: + plgpu.wait_smem_to_gmem(0) + if not wait_smem_to_gmem_in_warp: + plgpu.wait_smem_to_gmem(0) + pl.run_scoped(scope, + smem_ref=plgpu.SMEM((32, 32), jnp.float32), + tma_barrier=plgpu.Barrier()) + x = jax.random.uniform(jax.random.key(42), (64, 32), jnp.float32) + result = kernel(x) + np.testing.assert_array_equal(result, x[32:64]) + + +class PallasCallWGTest( + PallasCallTest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup +): + ... + + def test_missing_primitive_lowerings_are_tracked(self): + # This test is a way to keep track of which primitives need to be adapted + # to using warpgroup semantics. Once the set is empty, we should be able to + # enable warpgroup semantics by default (assuming we haven't overspecialized + # lowerings). + rules = mgpu_lowering.mosaic_lowering_rules + wg_wg_lowered_primitives = set( + rules[(plgpu.LoweringSemantics.Warpgroup, + gpu_core.PrimitiveSemantics.Warpgroup)]) + lane_wg_lowered_primitives = set(rules[ + (plgpu.LoweringSemantics.Lane, gpu_core.PrimitiveSemantics.Warpgroup)]) + + actual_missing_primitives = (lane_wg_lowered_primitives - + wg_wg_lowered_primitives) + expected_missing_primitives = { + mgpu_primitives.async_copy_scales_to_tmem_p, + mgpu_primitives.async_copy_sparse_metadata_to_tmem_p, + mgpu_primitives.wait_load_tmem_p, + mgpu_primitives.semaphore_signal_parallel_p, + mgpu_primitives.semaphore_signal_multicast_p, + mgpu_primitives.try_cluster_cancel_p, + mgpu_primitives.query_cluster_cancel_p, + mgpu_primitives.multimem_store_p, + mgpu_primitives.multimem_load_reduce_p, + pallas_core.core_map_p, + pallas_primitives.semaphore_signal_p, + pallas_primitives.semaphore_wait_p, + pallas_primitives.semaphore_read_p, + pallas_primitives.delay_p, + } + + self.assertSetEqual(actual_missing_primitives, expected_missing_primitives) + + +class PallasCallSm90ATest(PallasSm90ATest): + + @parameterized.parameters(False, True) + def test_fori_loop_accumulator(self, force_while): + transforms = self.default_transforms(dtype=jnp.float16) + + @functools.partial( + self.pallas_call, + in_specs=[plgpu.BlockSpec((64, 64), transforms=transforms)], + out_shape=jax.ShapeDtypeStruct((64, 64), jnp.float16), + out_specs=plgpu.BlockSpec((64, 64)), + ) + def kernel(i_ref, o_ref): + def scope(acc_ref): + return _fori_loop(force_while, 0, 4, lambda _, v: v + acc_ref[...], acc_ref[...]) + o_ref[...] = pl.run_state(scope)(plgpu.ACC.init(i_ref[...])) + + acc_ini = jnp.ones((64, 64), dtype=jnp.float16) + np.testing.assert_array_equal(kernel(acc_ini), jnp.full((64, 64), 5, dtype=jnp.float16)) + + @parameterized.product(lhs_transpose=[False, True], rhs_transpose=[False, True]) + def test_realistic_matmul(self, lhs_transpose, rhs_transpose): + dtype = jnp.float16 + swizzle = 128 elems_128b = swizzle // jnp.dtype(dtype).itemsize + grid_m, grid_k, grid_n = 132, 10, 4 + tile_m = tile_n = 128 + tile_k = elems_128b + m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n + def kernel(a_ref, b_ref, o_ref, acc_ref): + # Make sure tiling does not alter the shape of references + if lhs_transpose: + a_ref = plgpu.transpose_ref(a_ref, (1, 0)) + assert a_ref.shape == (tile_m, tile_k) + if rhs_transpose: + b_ref = plgpu.transpose_ref(b_ref, (1, 0)) + assert b_ref.shape == (tile_k, tile_n) + assert o_ref.shape == acc_ref.shape == (tile_m, tile_n) + plgpu.wgmma(acc_ref, a_ref, b_ref) + is_last_step = pl.program_id(2) == grid_k - 1 + @pl.when(is_last_step) + def _epilogue(): + o_ref[...] = acc_ref[...].astype(dtype) + plgpu.wgmma_wait(1) # We don't await the last WGMMA, hence delay_release=1 + + key1, key2 = jax.random.split(jax.random.key(42), 2) + a_shape = (k, m) if lhs_transpose else (m, k) + a = jax.random.uniform(key1, shape=a_shape, dtype=dtype) + b_shape = (n, k) if rhs_transpose else (k, n) + b = jax.random.uniform(key2, shape=b_shape, dtype=dtype) + + transforms = self.default_transforms(dtype=dtype) + + if lhs_transpose: + lhs_spec = plgpu.BlockSpec( + (tile_k, tile_m), + lambda m, n, k: (k, m), + delay_release=1, + transforms=transforms, + ) + else: + lhs_spec = plgpu.BlockSpec( + (tile_m, tile_k), + lambda m, n, k: (m, k), + delay_release=1, + transforms=transforms, + ) + if rhs_transpose: + rhs_spec = plgpu.BlockSpec( + (tile_n, tile_k), + lambda m, n, k: (n, k), + delay_release=1, + transforms=transforms, + ) + else: + rhs_spec = plgpu.BlockSpec( + (tile_k, tile_n), + lambda m, n, k: (k, n), + delay_release=1, + transforms=transforms, + ) + out_spec = plgpu.BlockSpec( + (tile_m, tile_n), + lambda m, n, k: (m, n), + transforms=transforms, + ) + + res = self.pallas_call( + kernel, + in_specs=[lhs_spec, rhs_spec], + out_specs=out_spec, + out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16), + scratch_shapes=[plgpu.ACC((tile_m, tile_n), jnp.float32)], + grid=(grid_m, grid_n, grid_k), + compiler_params=plgpu.CompilerParams( + dimension_semantics=["parallel", "parallel", "sequential"], + max_concurrent_steps=2, + ), + )(a, b) + np.testing.assert_allclose( + res, + (a.T if lhs_transpose else a) @ (b.T if rhs_transpose else b), + rtol=1e-3, + ) + + @parameterized.parameters(jnp.float16, jnp.float32) + def test_wgmma(self, dtype): + # TensorCores can only fuse transposes of 16-bit values, and RHS + # is expected to be column major by default. + rhs_transpose = jnp.dtype(dtype).itemsize != 2 def kernel(a_ref, b_ref, o_ref): if rhs_transpose: b_ref = plgpu.transpose_ref(b_ref, (1, 0)) @@ -1349,27 +3085,22 @@ def scope(acc_ref): b_shape = b_shape[::-1] b = jax.random.uniform(key2, shape=b_shape, dtype=dtype) - rhs_transforms = (plgpu.TilingTransform((elems_128b, elems_128b)),) - if rhs_transpose: - rhs_transforms += (plgpu.TransposeTransform((1, 0, 2, 3)),) - res = pl.pallas_call( + transforms = self.default_transforms(dtype=dtype) + res = self.pallas_call( kernel, in_specs=[ - plgpu.GPUBlockSpec( + plgpu.BlockSpec( (64, 128), lambda i, j: (i, j), - transforms=( - plgpu.TilingTransform((64, elems_128b)), - plgpu.SwizzleTransform(128), - ), + transforms=transforms, ), - plgpu.GPUBlockSpec( + plgpu.BlockSpec( b_shape, lambda *i: i, - transforms=(*rhs_transforms, plgpu.SwizzleTransform(128)), + transforms=transforms, ), ], - out_specs=plgpu.GPUBlockSpec((64, 192), lambda *i: i), + out_specs=plgpu.BlockSpec((64, 192), lambda *i: i), out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32), grid=(1, 1), )(a, b) @@ -1377,7 +3108,83 @@ def scope(acc_ref): res, a @ (b.T if rhs_transpose else b), rtol=1e-3 ) - def test_wgmma_registers(self): + @parameterized.parameters(jnp.int8, jnp.uint8) + def test_wgmma_integer(self, dtype): + m, k, n = 64, 128, 64 + + is_signed = jnp.issubdtype(dtype, jnp.signedinteger) + acc_type = jnp.int32 + + def kernel(a_ref, b_ref, o_ref): + + def scope(acc_ref): + plgpu.wgmma(acc_ref, a_ref, plgpu.transpose_ref(b_ref, (1, 0))) + return acc_ref[...] + + o_ref[...] = pl.run_scoped(scope, plgpu.ACC((m, n), acc_type)) + + # use small values to avoid overflow, [0, 8) for u8 and (-8, 8) for s8 + random_int_input = lambda key, shape: jax.random.randint( + key, minval=-8 * is_signed, maxval=8, shape=shape, dtype=dtype + ) + + a = random_int_input(jax.random.key(0), shape=(m, k)) + b = random_int_input(jax.random.key(1), shape=(n, k)) + + transforms = self.default_transforms(dtype=dtype) + res = self.pallas_call( + kernel, + in_specs=[ + plgpu.BlockSpec( + (m, k), + lambda i, j: (i, j), + transforms=transforms, + ), + plgpu.BlockSpec( + (n, k), + lambda *i: i, + transforms=transforms, + ), + ], + out_specs=plgpu.BlockSpec((m, n), lambda *i: i), + out_shape=jax.ShapeDtypeStruct((m, n), acc_type), + grid=(1, 1), + )(a, b) + np.testing.assert_array_equal( + res, a.astype(acc_type) @ b.T.astype(acc_type) + ) + + def test_wgmma_sliced_acc_flip(self): + self.skip_if_wg_semantics() + dtype = jnp.float16 + + key1, key2 = jax.random.split(jax.random.key(42), 2) + a = jax.random.uniform(key1, shape=(64, 128), dtype=dtype) + b = jax.random.uniform(key2, shape=(128, 256), dtype=dtype) + + def kernel(a_ref, b_ref, o_ref): + def scope(acc_ref): + plgpu.wgmma(acc_ref.at[:, :128], a_ref, b_ref.at[:, 128:]) + plgpu.wgmma(acc_ref.at[:, 128:], a_ref, b_ref.at[:, :128]) + return acc_ref[...] + + o_ref[...] = pl.run_scoped(scope, plgpu.ACC((64, 256), jnp.float32)) + + transforms = self.default_transforms(dtype=dtype) + res = self.pallas_call( + kernel, + in_specs=[plgpu.BlockSpec(transforms=transforms)] * 2, + out_shape=jax.ShapeDtypeStruct((64, 256), jnp.float32), + )(a, b) + + def flip_halves(x): + y = x.reshape(*x.shape[:-1], 2, x.shape[-1] // 2) + y = y[..., ::-1, :] + return y.reshape(x.shape) + + np.testing.assert_allclose(res, a @ flip_halves(b), rtol=1e-3) + + def test_wgmma_registers(self): def kernel(a_ref, b_ref, o_ref): def scope(acc_ref): plgpu.wgmma(acc_ref, a_ref[...], b_ref) @@ -1388,18 +3195,57 @@ def scope(acc_ref): a = jax.random.uniform(key1, shape=(64, 128), dtype=jnp.float16) b = jax.random.uniform(key2, shape=(128, 192), dtype=jnp.float16) - transforms = (plgpu.TilingTransform((64, 64)), plgpu.SwizzleTransform(128)) - res = pl.pallas_call( + transforms = self.default_transforms(dtype=jnp.float16) + res = self.pallas_call( kernel, in_specs=[ - plgpu.GPUBlockSpec((64, 128), lambda: (0, 0), transforms=transforms), - plgpu.GPUBlockSpec((128, 192), lambda: (0, 0), transforms=transforms), + plgpu.BlockSpec(transforms=transforms), + plgpu.BlockSpec(transforms=transforms), ], - out_specs=plgpu.GPUBlockSpec((64, 192), lambda: (0, 0)), out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32), )(a, b) np.testing.assert_allclose(res, a @ b, rtol=1e-3) + @parameterized.parameters(jnp.int8, jnp.float8_e4m3fn, jnp.float8_e5m2) + def test_wgmma_registers_8bit(self, input_dtype): + if jnp.issubdtype(input_dtype, jnp.integer): + out_dtype = jnp.int32 + else: + out_dtype = jnp.float32 + def kernel(a_ref, b_ref, o_ref): + def scope(acc_ref): + a_regs = plgpu.load(a_ref, (), layout=plgpu.Layout.WGMMA_8BIT) + plgpu.wgmma(acc_ref, a_regs, plgpu.transpose_ref(b_ref, (1, 0))) + return acc_ref[...] + o_ref[...] = pl.run_scoped(scope, plgpu.ACC((64, 192), out_dtype)) + + key1, key2 = jax.random.split(jax.random.key(42), 2) + m = 64 + k = 128 + n = 192 + if input_dtype == jnp.int8: + a = jax.random.randint(key1, shape=(m, k), minval=-128, maxval=127, dtype=jnp.int8) + b = jax.random.randint(key2, shape=(n, k), minval=-128, maxval=127, dtype=jnp.int8) + else: + assert jnp.issubdtype(input_dtype, jnp.floating) + a = jax.random.uniform(key1, shape=(m, k), dtype=input_dtype) + b = jax.random.uniform(key2, shape=(n, k), dtype=input_dtype) + + transforms = self.default_transforms(swizzle=64, dtype=input_dtype) + res = self.pallas_call( + kernel, + in_specs=[ + plgpu.BlockSpec(transforms=transforms), + plgpu.BlockSpec(transforms=transforms), + ], + out_shape=jax.ShapeDtypeStruct((64, 192), out_dtype), + )(a, b) + ref = a.astype(out_dtype) @ b.T.astype(out_dtype) + if input_dtype == jnp.int8: + np.testing.assert_array_equal(res, ref) + else: + np.testing.assert_allclose(res, ref) + def test_wgmma_registers_init(self): def kernel(a_ref, b_ref, i_ref, o_ref): def scope(acc_ref): @@ -1411,15 +3257,14 @@ def scope(acc_ref): b = jax.random.uniform(key2, shape=(128, 192), dtype=jnp.float16) i = jax.random.uniform(key3, shape=(64, 192), dtype=jnp.float16) * 10 - transforms = (plgpu.TilingTransform((64, 64)), plgpu.SwizzleTransform(128)) - res = pl.pallas_call( + transforms = self.default_transforms(dtype=jnp.float16) + res = self.pallas_call( kernel, in_specs=[ - plgpu.GPUBlockSpec((64, 128), lambda: (0, 0), transforms=transforms), - plgpu.GPUBlockSpec((128, 192), lambda: (0, 0), transforms=transforms), - plgpu.GPUBlockSpec((64, 192), lambda: (0, 0), transforms=transforms), + plgpu.BlockSpec(transforms=transforms), + plgpu.BlockSpec(transforms=transforms), + plgpu.BlockSpec(transforms=transforms), ], - out_specs=plgpu.GPUBlockSpec((64, 192), lambda: (0, 0)), out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float16), )(a, b, i) np.testing.assert_allclose(res, i + a @ b, rtol=2e-3) @@ -1436,32 +3281,20 @@ def scope(acc_ref): a = jax.random.uniform(key1, shape=(2, 64, 128), dtype=jnp.float16) b = jax.random.uniform(key2, shape=(2, 128, 192), dtype=jnp.float16) - res = pl.pallas_call( + transforms = self.default_transforms(dtype=jnp.float16) + res = self.pallas_call( kernel, in_specs=[ - plgpu.GPUBlockSpec( - (2, 64, 128), lambda: (0, 0, 0), - transforms=( - plgpu.TilingTransform((64, 64)), - plgpu.SwizzleTransform(128), - ), - ), - plgpu.GPUBlockSpec( - (2, 128, 192), lambda: (0, 0, 0), - transforms=( - plgpu.TilingTransform((64, 64)), - plgpu.SwizzleTransform(128), - ), - ), + plgpu.BlockSpec(transforms=transforms), + plgpu.BlockSpec(transforms=transforms), ], - out_specs=plgpu.GPUBlockSpec((64, 192), lambda: (0, 0)), out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32), )(a, b) np.testing.assert_allclose(res, a[0] @ b[0], rtol=1e-3) - def test_wgmma_sliced_acc(self): - swizzle = 128 - elems_128b = swizzle // jnp.dtype(jnp.float16).itemsize + def test_wgmma_sliced_acc_read(self): + self.skip_if_wg_semantics() # MLIR verifier error for `memref.subview`. + def kernel(a_ref, b_ref, o_ref): def scope(acc_ref): plgpu.wgmma(acc_ref, a_ref, b_ref) @@ -1472,685 +3305,2847 @@ def scope(acc_ref): key1, key2 = jax.random.split(jax.random.key(42), 2) a = jax.random.uniform(key1, shape=(64, 128), dtype=jnp.float16) b = jax.random.uniform(key2, shape=(128, 128), dtype=jnp.float16) - res = pl.pallas_call( + transforms = self.default_transforms(dtype=jnp.float16) + res = self.pallas_call( kernel, in_specs=[ - plgpu.GPUBlockSpec( - (64, 128), - lambda i, j: (i, j), - transforms=( - plgpu.TilingTransform((64, elems_128b)), - plgpu.SwizzleTransform(128), - ), - ), - plgpu.GPUBlockSpec( - (128, 128), - lambda *i: i, - transforms=( - plgpu.TilingTransform((elems_128b, elems_128b)), - plgpu.SwizzleTransform(128), - ), - ), + plgpu.BlockSpec((64, 128), lambda *ij: ij, transforms=transforms), + plgpu.BlockSpec((128, 128), lambda *ij: ij, transforms=transforms), ], - out_specs=plgpu.GPUBlockSpec((64, 128), lambda *i: i), + out_specs=plgpu.BlockSpec((64, 128), lambda *ij: ij), out_shape=jax.ShapeDtypeStruct((64, 128), jnp.float32), grid=(1, 1), )(a, b) np.testing.assert_allclose(res, a @ b, rtol=1e-3) + @parameterized.product( + src_memory_space=[plgpu.SMEM, plgpu.GMEM], + layout=[plgpu.Layout.WGMMA_ROW, plgpu.Layout.WGMMA_COL], + m=[64, 128, 192], + ) + def test_load_to_wgmma_row_col_layout_with_indexing(self, src_memory_space, layout, m): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct([2, m], jnp.float32), + in_specs=[pl.BlockSpec(memory_space=src_memory_space)], + out_specs=plgpu.BlockSpec(memory_space=plgpu.SMEM), + ) + def kernel(x_ref, o_ref): + for i in range(2): + x = plgpu.load( + x_ref, (i,), layout=layout, optimized=src_memory_space == plgpu.SMEM + ) + o_ref[i, ...] = x + + x = jnp.arange(2 * m, dtype=jnp.float32).reshape(2, m) + np.testing.assert_array_equal(kernel(x), x) -class PipelineTest(PallasTest): + @parameterized.product( + src_memory_space=[plgpu.SMEM], + layout=[plgpu.Layout.WGMMA_ROW, plgpu.Layout.WGMMA_COL], + ) + def test_load_row_input_to_wgmma_with_transforms(self, src_memory_space, layout): + m, k, n = 64, 128, 192 + key1, key2 = jax.random.split(jax.random.key(42), 2) + if layout == plgpu.Layout.WGMMA_ROW: + input_shape = (m,) + broadcast_dim = 0 + expand_dim = 1 + else: + input_shape = (k,) + broadcast_dim = 1 + expand_dim = 0 + a = jax.random.uniform(key1, shape=input_shape, dtype=jnp.float16) + b = jax.random.uniform(key2, shape=(k, n), dtype=jnp.float16) + def kernel(x_ref, y_ref, o_ref): + x = plgpu.load(x_ref, (), layout=layout) + x = lax.broadcast_in_dim(x, (m, k), [broadcast_dim]) - def test_pipeline_mode(self): - def body(x_ref, y_ref, o_ref): - x = x_ref[:] - y = y_ref[:] - o_ref[:] = x + y + def compute(acc_ref): + plgpu.wgmma(acc_ref, x, y_ref) + return acc_ref[...] - data_size = 64 * 256 - block_size = 256 + out = pl.run_scoped(compute, plgpu.ACC((m, n), jnp.float32)) + o_ref[...] = out + f = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct([m, n], jnp.float32), + in_specs=( + pl.BlockSpec(memory_space=src_memory_space), + plgpu.BlockSpec( + transforms=self.default_transforms(dtype=jnp.float16), + ), + ), + out_specs=plgpu.BlockSpec(memory_space=plgpu.SMEM), + ) - x = jnp.arange(data_size, dtype=jnp.float32) - y = jnp.arange(data_size, dtype=jnp.float32) - in_specs = [ - pl.BlockSpec((block_size,), lambda *i: i, pipeline_mode=pl.Buffered(2)), - pl.BlockSpec((block_size,), lambda *i: i, pipeline_mode=pl.Buffered(1)) + out_ref = ( + jnp.broadcast_to(jnp.expand_dims(a, axis=expand_dim), (m, k)) @ b + ) + np.testing.assert_allclose(f(a, b), out_ref, rtol=1e-3) + + def test_load_store_wgmma_transposed(self): + transforms = self.default_transforms(swizzle=64, dtype=jnp.float32) + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct([8, 64], jnp.float32), + in_specs=[ + pl.BlockSpec(memory_space=plgpu.GMEM), + ], + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + scratch_shapes=[ + plgpu.SMEM((8, 64), jnp.float32, transforms=transforms), + plgpu.Barrier(), + ], + ) + def kernel(x_gmem, o_ref, x_smem, barrier): + plgpu.copy_gmem_to_smem(x_gmem, x_smem, barrier) + plgpu.barrier_wait(barrier) + x = plgpu.load(x_smem.T, (), layout=plgpu.Layout.WGMMA_TRANSPOSED) + x_smem.T[...] = x + 1 + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(x_smem, o_ref) + plgpu.wait_smem_to_gmem(0) + + x = jax.random.uniform(jax.random.key(42), shape=(8, 64), dtype=jnp.float32) + result = kernel(x) + np.testing.assert_array_equal(result, x + 1) + + +class PallasCallSm90AWGTest( + PallasCallSm90ATest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup +): + ... + + +class PallasCallSm100ATest(PallasSm100ATest): + + def test_print_layout_tmem(self): + shape = (128, 256) + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct(shape, jnp.bfloat16), + scratch_shapes=[plgpu.TMEM(shape, jnp.bfloat16, packed=True)], + ) + def kernel(o_ref, tmem_ref): + del o_ref + # Slicing TMEM to make sure we handle transforms correctly. + plgpu.print_layout("tmem: {}", tmem_ref.at[:, :128]) + + with self.capture_stdout() as output: + jax.block_until_ready(kernel()) + + self.assertIn("tmem: TMEM_DEFAULT(packing=2)\n", output()) + + def test_mixed_tmem_allocations_raise(self): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((), jnp.float32), + scratch_shapes=[ + plgpu.TMEM((128, 128), jnp.float32, collective=True), + plgpu.TMEM((128, 128), jnp.float32, collective=False), + ], + ) + def kernel(out_ref, tmem_ref0, tmem_ref1): + del out_ref, tmem_ref0, tmem_ref1 + + with self.assertRaisesRegex( + ValueError, + "Can't mix collective and non-collective TMEM allocations within the" + " same kernel.", + ): + kernel() + + def test_transposed_tmem_ref_raises(self): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct([], jnp.float32), + scratch_shapes=[plgpu.TMEM((128, 128), jnp.float32)], + ) + def kernel(out, tmem_ref): + del out + plgpu.transpose_ref(tmem_ref, (1, 0)) + + with self.assertRaisesRegex(ValueError, "Can't transpose a TMEM reference"): + kernel() + + @parameterized.parameters((False,), (True,)) + def test_tmem(self, collective): + transforms = self.default_transforms(dtype=jnp.float32) + @functools.partial( + self.kernel, + out_shape=jnp.zeros((128, 128), jnp.float32), + scratch_shapes=[ + plgpu.TMEM((128, 128), jnp.float32, collective=collective), + plgpu.TMEM((128, 128), jnp.float32, collective=collective), + plgpu.SMEM((128, 128), jnp.float32, transforms=transforms), + plgpu.Barrier(), + ], + num_threads=1, + thread_name="x", + cluster=(2,) if collective else (), + cluster_names=("x",) if collective else (), + ) + def kernel(x_ref, y_ref, tmem_ref, tmem_ref2, smem_ref, barrier_ref): + plgpu.copy_gmem_to_smem(x_ref, smem_ref, barrier_ref) + plgpu.barrier_wait(barrier_ref) + # Exercise TMEM by roundtripping SMEM -> TMEM -> TMEM -> SMEM. + x_val = plgpu.load(smem_ref, (), layout=plgpu.Layout.TCGEN05) + plgpu.async_store_tmem(tmem_ref, x_val + 1) + plgpu.commit_tmem() + # We don't await the load, because we never overwrite tmem_ref + tmem_read = plgpu.async_load_tmem(tmem_ref) + plgpu.async_store_tmem(tmem_ref2, tmem_read) + plgpu.commit_tmem() + # We don't await the load, because we never overwrite tmem_ref2 + smem_ref[...] = plgpu.async_load_tmem(tmem_ref2) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(smem_ref, y_ref) + plgpu.wait_smem_to_gmem(0) + + x = jax.random.uniform( + jax.random.key(0), shape=(128, 128), dtype=jnp.float32) + x_result = jax.block_until_ready(kernel(x)) + np.testing.assert_array_equal(x_result, x + 1) + + def test_tmem_allocation_estimation(self): + """Make sure that we don't overestimate the TMEM allocation. + + All of the refs below are packed and should fit into TMEM at once. + """ + transforms = self.default_transforms(dtype=jnp.bfloat16) + @functools.partial( + self.kernel, + out_shape=jnp.zeros((128, 256), jnp.bfloat16), + scratch_shapes=[ + plgpu.TMEM((128, 256), jnp.bfloat16, packed=True), + plgpu.TMEM((128, 256), jnp.bfloat16, packed=True), + plgpu.TMEM((128, 256), jnp.bfloat16, packed=True), + plgpu.SMEM((128, 256), jnp.bfloat16, transforms=transforms), + plgpu.Barrier(), + ], + num_threads=1, + thread_name="x", + ) + def kernel(x_ref, y_ref, tmem_ref1, tmem_ref2, tmem_ref3, smem_ref, barrier_ref): + plgpu.copy_gmem_to_smem(x_ref, smem_ref, barrier_ref) + plgpu.barrier_wait(barrier_ref) + x_val = plgpu.load(smem_ref, (), layout=plgpu.Layout.TCGEN05) + plgpu.async_store_tmem(tmem_ref1, x_val + 1) + plgpu.commit_tmem() + x_val = plgpu.async_load_tmem(tmem_ref1) + plgpu.async_store_tmem(tmem_ref2, x_val + 1) + plgpu.commit_tmem() + x_val = plgpu.async_load_tmem(tmem_ref2) + plgpu.async_store_tmem(tmem_ref3, x_val + 1) + plgpu.commit_tmem() + smem_ref[...] = plgpu.async_load_tmem(tmem_ref3) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(smem_ref, y_ref) + plgpu.wait_smem_to_gmem(0) + + x = jax.random.uniform(jax.random.key(0), shape=(128, 256), dtype=jnp.bfloat16) + x_result = jax.block_until_ready(kernel(x)) + np.testing.assert_array_equal(x_result, x + 3) + + def test_tmem_ref_aliasing(self): + self.skip_if_wg_semantics() + transforms = self.default_transforms(dtype=jnp.float32) + @functools.partial( + self.kernel, + out_shape=jnp.zeros((128, 128), jnp.float32), + scratch_shapes=[ + plgpu.RefUnion( + [plgpu.TMEM((128, 32), jnp.float32), + plgpu.TMEM((128, 32), jnp.float32)], + plgpu.TMEM((128, 64), jnp.float32), + ), + plgpu.SMEM((128, 128), jnp.float32, transforms=transforms), + plgpu.Barrier(), + ], + num_threads=1, + thread_name="x", + ) + def kernel(x_ref, y_ref, aliased_ref, smem_ref, barrier_ref): + [tmem_128x32a, tmem_128x32b], tmem_128x64 = aliased_ref + plgpu.copy_gmem_to_smem(x_ref, smem_ref, barrier_ref) + plgpu.barrier_wait(barrier_ref) + # Test tmem_128x32 a and b + x_val = plgpu.load(smem_ref.at[:, 0:32], (), layout=plgpu.Layout.TCGEN05) + plgpu.async_store_tmem(tmem_128x32a, x_val + 1) + plgpu.commit_tmem() + smem_ref[:, 0:32] = plgpu.async_load_tmem(tmem_128x32a) + plgpu.wait_load_tmem() # Make sure the load is done before we write to TMEM again. + + x_val = plgpu.load(smem_ref.at[:, 32:64], (), layout=plgpu.Layout.TCGEN05) + plgpu.async_store_tmem(tmem_128x32b, x_val + 1) + plgpu.commit_tmem() + smem_ref[:, 32:64] = plgpu.async_load_tmem(tmem_128x32b) + plgpu.wait_load_tmem() # Make sure the load is done before we write to TMEM again. + + # Test tmem_128x64 + x_val = plgpu.load(smem_ref.at[:, 64:128], (), layout=plgpu.Layout.TCGEN05) + plgpu.async_store_tmem(tmem_128x64, x_val + 1) + plgpu.commit_tmem() + smem_ref[:, 64:128] = plgpu.async_load_tmem(tmem_128x64) + + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(smem_ref, y_ref) + plgpu.wait_smem_to_gmem(0) + x = jax.random.uniform( + jax.random.key(0), shape=(128, 128), dtype=jnp.float32) + x_result = jax.block_until_ready(kernel(x)) + np.testing.assert_array_equal(x_result, x + 1) + + @parameterized.parameters( + plgpu.Layout.TCGEN05, plgpu.Layout.TCGEN05_TMEM_NATIVE + ) + def test_tmem_load_layout(self, layout): + transforms = self.default_transforms(dtype=jnp.float32) + @functools.partial( + self.kernel, + out_shape=jax.ShapeDtypeStruct((128, 128), jnp.float32), + scratch_shapes=[ + plgpu.TMEM((128, 128), jnp.float32), + plgpu.SMEM((128, 128), jnp.float32, transforms=transforms), + plgpu.Barrier(), + ], + ) + def kernel(x_ref, y_ref, tmem_ref, smem_ref, barrier_ref): + plgpu.copy_gmem_to_smem(x_ref, smem_ref, barrier_ref) + plgpu.barrier_wait(barrier_ref) + optimized = layout != plgpu.Layout.TCGEN05_TMEM_NATIVE + x_val = plgpu.load(smem_ref, (), layout=layout, optimized=optimized) + plgpu.async_store_tmem(tmem_ref, x_val + 1) + plgpu.commit_tmem() + # We don't wait for the load to complete, because we never overwrite + # tmem_ref. + smem_ref[...] = plgpu.async_load_tmem(tmem_ref, layout=layout) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(smem_ref, y_ref) + plgpu.wait_smem_to_gmem(0) + + x = jax.random.uniform( + jax.random.key(0), shape=(128, 128), dtype=jnp.float32) + x_result = jax.block_until_ready(kernel(x)) + np.testing.assert_array_equal(x_result, x + 1) + + @parameterized.parameters( + plgpu.Layout.TCGEN05_M64_COLLECTIVE(160), + plgpu.Layout.TCGEN05_M64_COLLECTIVE_NATIVE(160) + ) + def test_tmem_store_load_collective(self, layout): + self.skip_if_wg_semantics() # Failed to infer a possible set of layouts. + @functools.partial( + self.kernel, + out_shape=jax.ShapeDtypeStruct((64, 160), jnp.float32), + cluster=(2,), + cluster_names=("cluster",), + scratch_shapes=[ + plgpu.TMEM( + (64, 160), jnp.float32, collective=True, + layout=plgpu.TMEMLayout.M64_COLLECTIVE_LAYOUT(160), + ), + ], + ) + def kernel(x_ref, y_ref, tmem_ref): + x_val = plgpu.load(x_ref, (), layout=layout, optimized=False) + plgpu.async_store_tmem(tmem_ref, x_val + 1) + plgpu.commit_tmem() + # We don't wait for the load to complete, because we never overwrite + # tmem_ref. + y_ref[...] = plgpu.async_load_tmem(tmem_ref, layout=layout) + + x = jax.random.uniform( + jax.random.key(0), shape=(64, 160), dtype=jnp.float32) + x_result = jax.block_until_ready(kernel(x)) + np.testing.assert_array_equal(x_result, x + 1) + + def test_tmem_column_slicing(self): + @functools.partial( + self.kernel, + out_shape=jax.ShapeDtypeStruct((128, 128), jnp.float32), + scratch_shapes=[ + plgpu.TMEM((128, 256), jnp.float32), + ], + num_threads=1, + thread_name="x", + ) + def kernel(x_ref, y_ref, tmem_ref): + x_val = plgpu.load(x_ref, (), layout=plgpu.Layout.TCGEN05, optimized=False) + tmem_slice = tmem_ref.at[:, 8:208].at[:, 0:128] + plgpu.async_store_tmem(tmem_slice, x_val + 1) + plgpu.commit_tmem() + y_ref[...] = plgpu.async_load_tmem(tmem_ref.at[:, 8:136]) + + x = jax.random.uniform( + jax.random.key(0), shape=(128, 128), dtype=jnp.float32) + np.testing.assert_array_equal(kernel(x), (x + 1)[:, 0:128]) + + @parameterized.product( + m=[64, 128], + n=[64, 128, 256], + swizzle=[64, 32], + dtype=[jnp.int8, jnp.uint8], + lhs_tmem=[False, True], + ) + def test_integer_matmul(self, m, n, swizzle, dtype, lhs_tmem): + if n * jnp.dtype(dtype).itemsize <= swizzle: + self.skipTest("swizzle too big") + if lhs_tmem and m == 64: + self.skipTest("m=64 not supported for LHS in TMEM") + if lhs_tmem: + self.skip_if_wg_semantics() # Layout inference fails to find a solution. + k = 128 + is_signed = jnp.issubdtype(dtype, jnp.signedinteger) + o_dtype = jnp.int32 + + in_transforms = self.default_transforms(dtype=dtype, swizzle=swizzle) + + def kernel( + a_smem, b_smem, out_ref, acc_tmem, barrier_ref, a_tmem_ref + ): + if lhs_tmem: + lhs_ref = a_tmem_ref + layout = plgpu.Layout.TCGEN05_TMEM_NATIVE(4) + plgpu.async_store_tmem(lhs_ref, plgpu.load(a_smem, (), layout=layout, optimized=False)) + plgpu.commit_tmem() + else: + lhs_ref = a_smem + + plgpu.tcgen05_mma( + acc_tmem, lhs_ref, b_smem, barrier_ref, accumulate=False + ) + plgpu.barrier_wait(barrier_ref) + out_ref[...] = plgpu.async_load_tmem(acc_tmem) + + scratch_shapes = [ + plgpu.TMEM((m, n), o_dtype, packed=False), + plgpu.Barrier(orders_tensor_core=True), ] - out_specs = pl.BlockSpec((block_size,), lambda *i: i) + if lhs_tmem: + scratch_shapes.append(plgpu.TMEM((m, k), dtype, packed=True)) + else: + scratch_shapes.append(None) - @jax.jit - def vadd(x, y): - return pl.pallas_call( - body, - out_shape=jax.ShapeDtypeStruct(x.shape, jnp.float32), - in_specs=in_specs, - out_specs=out_specs, - grid=data_size // block_size, - )(x, y) + f = self.pallas_call( + kernel, + in_specs=( + plgpu.BlockSpec(transforms=in_transforms, memory_space=plgpu.SMEM), + plgpu.BlockSpec(transforms=in_transforms, memory_space=plgpu.SMEM), + ), + out_specs=plgpu.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct((m, n), o_dtype), + scratch_shapes=scratch_shapes, + ) + # use small values to avoid overflow, [0, 8) for u8 and (-8, 8) for s8 + random_int_input = lambda key, shape: jax.random.randint( + key, minval=-8 * is_signed, maxval=8, shape=shape, dtype=dtype + ) + + x = random_int_input(jax.random.key(0), shape=(m, k)) + y = random_int_input(jax.random.key(1), shape=(k, n)) + + result = f(x, y) + expected = x.astype(o_dtype) @ y.astype(o_dtype) + np.testing.assert_array_equal(result, expected) + + @parameterized.product(m=[64, 128], + n=[64, 128, 256], + swizzle=[128, 64, 32], + dtype=[jnp.float16, jnp.bfloat16], + lhs_tmem=[False, True], + transpose_rhs=[False, True], + transpose_lhs=[False, True]) + def test_simple_matmul( + self, m, n, swizzle, dtype, lhs_tmem, transpose_lhs, transpose_rhs + ): + if transpose_lhs and lhs_tmem: + self.skipTest("TMEM transpose not supported") + if n * jnp.dtype(dtype).itemsize <= swizzle: + self.skipTest("swizzle too big") + if lhs_tmem and m == 64: + self.skipTest("m=64 not supported for LHS in TMEM") + k = 128 + # Test a matmul with a single block. + transforms = self.default_transforms(dtype=dtype, swizzle=swizzle) + + def kernel(a_smem, b_smem, out_ref, acc_tmem, barrier_ref, + a_tmem_ref): + if transpose_lhs: + a_smem = plgpu.transpose_ref(a_smem, (1, 0)) + if transpose_rhs: + b_smem = plgpu.transpose_ref(b_smem, (1, 0)) + if lhs_tmem: + lhs_ref = a_tmem_ref + layout = plgpu.Layout.TCGEN05 if m == 128 else plgpu.Layout.WGMMA + plgpu.async_store_tmem(lhs_ref, plgpu.load(a_smem, (), layout=layout)) + plgpu.commit_tmem() + else: + lhs_ref = a_smem + plgpu.tcgen05_mma(acc_tmem, + lhs_ref, + b_smem, + barrier_ref, + accumulate=False) + plgpu.barrier_wait(barrier_ref) + # We don't await the load because acc_tmem is never modified again. + out_ref[...] = plgpu.async_load_tmem(acc_tmem).astype(dtype) + + scratch_shapes = [ + plgpu.TMEM((m, n), jnp.float32, packed=False), + plgpu.Barrier(orders_tensor_core=True), + ] + if lhs_tmem: + scratch_shapes.append(plgpu.TMEM((m, k), dtype, packed=True)) + else: + scratch_shapes.append(None) + + f = self.pallas_call( + kernel, + in_specs=( + plgpu.BlockSpec(transforms=transforms, memory_space=plgpu.SMEM), + plgpu.BlockSpec(transforms=transforms, memory_space=plgpu.SMEM), + ), + out_specs=plgpu.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct((m, n), dtype), + scratch_shapes=scratch_shapes, + ) + lhs_shape = (k, m) if transpose_lhs else (m, k) + rhs_shape = (n, k) if transpose_rhs else (k, n) + x = jax.random.uniform(jax.random.key(0), shape=lhs_shape, dtype=dtype) + y = jax.random.uniform(jax.random.key(1), shape=rhs_shape, dtype=dtype) + result = f(x, y) + if transpose_lhs: + x = jnp.transpose(x, (1, 0)) + if transpose_rhs: + y = jnp.transpose(y, (1, 0)) + expected = x @ y + np.testing.assert_allclose(result, expected, rtol=1e-3) + + def test_matmul_alignment(self): + m = k = n = 128 + dtype = jnp.float16 + transforms = self.default_transforms(dtype=dtype) + + def kernel(a_smem, b_smem, out_ref, _, acc_tmem, barrier_ref): + plgpu.tcgen05_mma(acc_tmem, a_smem, b_smem, barrier_ref, accumulate=False) + plgpu.barrier_wait(barrier_ref) + # We don't await the load because acc_tmem is never modified again. + out_ref[...] = plgpu.async_load_tmem(acc_tmem).astype(dtype) + + spec = plgpu.BlockSpec(transforms=transforms, memory_space=plgpu.SMEM) + f = self.pallas_call( + kernel, + in_specs=(spec, spec), + out_specs=spec, + out_shape=jax.ShapeDtypeStruct((m, n), dtype), + # Add a one column space to test if we align the accumulator. + scratch_shapes=( + plgpu.TMEM((128, 1), jnp.float32), + plgpu.TMEM((m, n), jnp.float32), + plgpu.Barrier(orders_tensor_core=True), + ), + ) + lhs_shape = (m, k) + rhs_shape = (k, n) + x = jax.random.uniform(jax.random.key(0), shape=lhs_shape, dtype=dtype) + y = jax.random.uniform(jax.random.key(1), shape=rhs_shape, dtype=dtype) + result = f(x, y) + expected = x @ y + np.testing.assert_allclose(result, expected, rtol=1e-3) + + @parameterized.product( + m=[128], + n=[128, 256], + dtype=[jnp.float8_e5m2, jnp.float8_e4m3fn, jnp.float4_e2m1fn], + ) + def test_simple_scaled_matmul(self, m, n, dtype): + self.skip_if_wg_semantics() + # TODO(apaszke): Add support for single-buffering in pallas_call. + causes_oom = jnp.finfo(dtype).bits == 8 and n == 256 + k = 128 if causes_oom else 256 + swizzle = 128 + transforms = self.default_transforms(swizzle=swizzle, dtype=dtype) + out_transforms = self.default_transforms(dtype=jnp.float32) + + def kernel(a_smem, b_smem, a_scale_smem, b_scale_smem, out_ref, + barrier_ref, acc_tmem, a_scale_tmem, b_scale_tmem): + plgpu.async_copy_scales_to_tmem(a_scale_smem, a_scale_tmem) + plgpu.async_copy_scales_to_tmem(b_scale_smem, b_scale_tmem) + # We don't have to await the copy because it's only used by the MMA. + plgpu.tcgen05_mma(acc_tmem, + a_smem, + plgpu.transpose_ref(b_smem, (1, 0)), + a_scale=a_scale_tmem, + b_scale=b_scale_tmem, + accumulate=False) + plgpu.tcgen05_commit_arrive(barrier_ref) + plgpu.barrier_wait(barrier_ref) + # We don't await the load because acc_tmem is never modified again. + out_ref[...] = plgpu.async_load_tmem(acc_tmem) + + scratch_shapes = [ + plgpu.Barrier(orders_tensor_core=True), + plgpu.TMEM((m, n), jnp.float32), + plgpu.TMEM((m, k // 32), jnp.float8_e8m0fnu, layout=plgpu.TMEMLayout.SCALES_LAYOUT), + plgpu.TMEM((n, k // 32), jnp.float8_e8m0fnu, layout=plgpu.TMEMLayout.SCALES_LAYOUT), + ] + + f = self.pallas_call( + kernel, + in_specs=( + plgpu.BlockSpec(memory_space=plgpu.SMEM, transforms=transforms), + plgpu.BlockSpec(memory_space=plgpu.SMEM, transforms=transforms), + plgpu.BlockSpec(memory_space=plgpu.SMEM), + plgpu.BlockSpec(memory_space=plgpu.SMEM), + ), + out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32), + out_specs=plgpu.BlockSpec(transforms=out_transforms), + scratch_shapes=scratch_shapes, + ) + x = jax.random.uniform(jax.random.key(1), shape=(m, k), dtype=jnp.float32).astype(dtype) + y = jax.random.uniform(jax.random.key(2), shape=(n, k), dtype=jnp.float32).astype(dtype) + ksx, ksy = jax.random.split(jax.random.key(1234), 2) + x_scale = jax.lax.bitcast_convert_type( + jax.random.randint(ksx, (m, k // 32), 122, 132, dtype=jnp.uint8), + jnp.float8_e8m0fnu + ) + y_scale = jax.lax.bitcast_convert_type( + jax.random.randint(ksy, (n, k // 32), 122, 132, dtype=jnp.uint8), + jnp.float8_e8m0fnu + ) + def format_scales(scales): + mn, k = scales.shape + assert mn % 128 == 0 and k % 4 == 0 + return ( + scales.reshape(mn // 128, 4, 32, k // 4, 4) + .transpose(0, 3, 2, 1, 4) + .reshape(mn // 128, k // 4, 32, 16) + ) + result = f(x, y, format_scales(x_scale), format_scales(y_scale)) + x_logical_scale = jnp.repeat(x_scale, 32, axis=1).astype(jnp.float32) + y_logical_scale = jnp.repeat(y_scale, 32, axis=1).astype(jnp.float32) + expected = jnp.dot( + x.astype(jnp.float32) * x_logical_scale, + (y.astype(jnp.float32) * y_logical_scale).T, + ) + np.testing.assert_allclose(result, expected, rtol=1e-3) + + @parameterized.product( + m=[256], + n=[128, 256], + scale_jax_dtype=[jnp.float8_e8m0fnu, jnp.float8_e4m3fn], + ) + def test_collective_scaled_matmul(self, m, n, scale_jax_dtype): + self.skip_if_wg_semantics() + + in_jax_dtype = jnp.float4_e2m1fn + out_jax_dtype = jnp.float32 + scale_block = 32 if scale_jax_dtype == jnp.float8_e8m0fnu else 16 + swizzle = 128 + k_steps = 2 + swizzle_elems = 8 * swizzle // dtypes.itemsize_bits(in_jax_dtype) + k = swizzle_elems * k_steps + tiling = (8, swizzle_elems) + transforms = ( + plgpu.TilingTransform(tiling), plgpu.SwizzleTransform(swizzle) + ) + + m_block = m // 2 + n_block = n // 2 + + def kernel(lhs_gmem, rhs_gmem, lhs_scales_gmem, rhs_scales_gmem, out_gmem, + lhs_smem, rhs_smem, lhs_scales_smem, rhs_scales_smem, + tma_barrier, mma_barrier, + acc_tmem, lhs_scales_tmem, rhs_scales_tmem): + plgpu.copy_gmem_to_smem(lhs_gmem, lhs_smem, tma_barrier, + collective_axes="x", partitioned_axis=0) + plgpu.copy_gmem_to_smem(rhs_gmem, rhs_smem, tma_barrier, + collective_axes="x", partitioned_axis=0) + plgpu.copy_gmem_to_smem(lhs_scales_gmem, lhs_scales_smem, tma_barrier, + collective_axes="x", partitioned_axis=0) + # RHS scales are replicated (multicast) + plgpu.copy_gmem_to_smem(rhs_scales_gmem, rhs_scales_smem, tma_barrier, + collective_axes="x", partitioned_axis=None) + cluster_idx = lax.axis_index("x") + + @pl.when(cluster_idx == 0) + def _leader_block(): + plgpu.barrier_wait(tma_barrier) + plgpu.async_copy_scales_to_tmem(lhs_scales_smem, lhs_scales_tmem, collective_axis="x") + plgpu.async_copy_scales_to_tmem(rhs_scales_smem, rhs_scales_tmem, collective_axis="x") + plgpu.tcgen05_mma( + acc_tmem, + lhs_smem, + plgpu.transpose_ref(rhs_smem, (1, 0)), + mma_barrier, + a_scale=lhs_scales_tmem, + b_scale=rhs_scales_tmem, + accumulate=False, + collective_axis="x" + ) + plgpu.barrier_wait(mma_barrier) + + slice_out = pl.ds(cluster_idx * m_block, m_block) + out_gmem[slice_out, :] = plgpu.async_load_tmem(acc_tmem) + + scratch_shapes = [ + plgpu.SMEM((m_block, k), in_jax_dtype, transforms=transforms), + plgpu.SMEM((n_block, k), in_jax_dtype, transforms=transforms), + plgpu.SMEM((m_block // 128, k // (scale_block * 4), 32, 16), scale_jax_dtype), + plgpu.SMEM((n // 128, k // (scale_block * 4), 32, 16), scale_jax_dtype), + plgpu.Barrier(num_arrivals=4), + plgpu.Barrier(orders_tensor_core=True), + plgpu.TMEM((m_block, n), out_jax_dtype, collective=True), + plgpu.TMEM((m_block, k // scale_block), scale_jax_dtype, + layout=plgpu.TMEMLayout.SCALES_LAYOUT, collective=True), + plgpu.TMEM((n, k // scale_block), scale_jax_dtype, + layout=plgpu.TMEMLayout.SCALES_LAYOUT, collective=True), + ] + + f = self.kernel( + kernel, + out_shape=jax.ShapeDtypeStruct((m, n), out_jax_dtype), + grid=(1,), + grid_names=("_",), + cluster=(2,), + cluster_names=("x",), + scratch_shapes=scratch_shapes, + ) + + x = jax.random.uniform(jax.random.key(1), shape=(m, k), dtype=jnp.float32).astype(in_jax_dtype) + y = jax.random.uniform(jax.random.key(2), shape=(n, k), dtype=jnp.float32).astype(in_jax_dtype) + + ka, kb = jax.random.split(jax.random.key(1234), 2) + if scale_jax_dtype == jnp.float8_e8m0fnu: + x_scale = jax.lax.bitcast_convert_type( + jax.random.randint(ka, (m, k // scale_block), 122, 132, dtype=jnp.uint8), + scale_jax_dtype + ) + y_scale = jax.lax.bitcast_convert_type( + jax.random.randint(kb, (n, k // scale_block), 122, 132, dtype=jnp.uint8), + scale_jax_dtype + ) + else: + x_scale = jnp.abs( + jax.random.normal(ka, (m, k // scale_block), dtype=jnp.float32).astype(scale_jax_dtype) + ) + y_scale = jnp.abs( + jax.random.normal(kb, (n, k // scale_block), dtype=jnp.float32).astype(scale_jax_dtype) + ) + + def format_scales(scales): + mn, k = scales.shape + assert mn % 128 == 0 and k % 4 == 0 + return ( + scales.reshape(mn // 128, 4, 32, k // 4, 4) + .transpose(0, 3, 2, 1, 4) + .reshape(mn // 128, k // 4, 32, 16) + ) + + result = f(x, y, format_scales(x_scale), format_scales(y_scale)) + + x_logical_scale = jnp.repeat(x_scale, scale_block, axis=1).astype(jnp.float32) + y_logical_scale = jnp.repeat(y_scale, scale_block, axis=1).astype(jnp.float32) + expected = jnp.dot( + x.astype(jnp.float32) * x_logical_scale, + (y.astype(jnp.float32) * y_logical_scale).T, + ) + np.testing.assert_allclose(result, expected, rtol=1e-3) + + @parameterized.product( + m=[128], + n=[128, 256], + dtype=[jnp.float16], + ) + def test_simple_sparse_matmul(self, m, n, dtype): + self.skip_if_wg_semantics() + k = 128 + swizzle = 128 // jnp.dtype(dtype).itemsize + transforms = self.default_transforms(swizzle=swizzle, dtype=dtype) + out_transforms = self.default_transforms(dtype=jnp.float32) + + def kernel(a_smem, b_smem, a_sparse_smem, out_ref, + barrier_ref, acc_tmem, a_sparse_tmem): + plgpu.async_copy_sparse_metadata_to_tmem(a_sparse_smem, a_sparse_tmem) + # We don't have to await the copy because it's only used by the MMA. + plgpu.tcgen05_mma(acc_tmem, + a_smem, + plgpu.transpose_ref(b_smem, (1, 0)), + a_sparse_metadata=a_sparse_tmem, + accumulate=False) + plgpu.tcgen05_commit_arrive(barrier_ref) + plgpu.barrier_wait(barrier_ref) + # We don't await the load because acc_tmem is never modified again. + out_ref[...] = plgpu.async_load_tmem(acc_tmem) + + scratch_shapes = [ + plgpu.Barrier(orders_tensor_core=True), + plgpu.TMEM((m, n), jnp.float32), + plgpu.TMEM((m, k // 2), jnp.uint2, layout=plgpu.TMEMLayout.SPARSE_METADATA_LAYOUT), + ] + + f = self.pallas_call( + kernel, + in_specs=( + plgpu.BlockSpec(memory_space=plgpu.SMEM, transforms=transforms), + plgpu.BlockSpec(memory_space=plgpu.SMEM, transforms=transforms), + plgpu.BlockSpec(memory_space=plgpu.SMEM), + ), + out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32), + out_specs=plgpu.BlockSpec(transforms=out_transforms), + scratch_shapes=scratch_shapes, + ) + x = jax.random.uniform(jax.random.key(1), shape=(m, k // 2), dtype=dtype) + y = jax.random.uniform(jax.random.key(2), shape=(n, k), dtype=dtype) + index_pairs = np.asarray(np.meshgrid(range(4), range(4))).T.reshape(-1, 2) + valid_pairs = index_pairs[index_pairs[:, 0] < index_pairs[:, 1]] + assert len(valid_pairs) == 6 + x_pairs = jax.random.randint(jax.random.key(1234), (m, k // 4), 0, 6, dtype=jnp.uint8) + x_sparse = valid_pairs[x_pairs] + assert x_sparse.shape == (m, k // 4, 2) + z = f(x, y, plgpu.format_tcgen05_sparse_metadata(x_sparse.astype(jnp.uint2))) + x_logical = np.zeros_like(x, shape=(m, k // 4, 4)) + np.put_along_axis(x_logical, x_sparse, x.reshape(x_sparse.shape), axis=-1) + x_logical = x_logical.reshape(m, k) + ref = x_logical.astype(jnp.float32) @ y.T.astype(jnp.float32) + np.testing.assert_allclose(z, ref, atol=7e-5, rtol=5e-6) + + @parameterized.parameters( + (128, jnp.float16) + ) + def test_manual_tcgen05_commit_arrive(self, swizzle, dtype): + shape = (128, 128) + transforms = self.default_transforms(swizzle=swizzle, dtype=dtype) + + def kernel(a_gmem, b_gmem, out_gmem, + a_smem, b_smem, tma_barrier, mma_barrier, acc_tmem): + plgpu.copy_gmem_to_smem(a_gmem, a_smem, tma_barrier) + plgpu.barrier_wait(tma_barrier) + plgpu.copy_gmem_to_smem(b_gmem, b_smem, tma_barrier) + plgpu.barrier_wait(tma_barrier) + + plgpu.commit_tmem() + # Don't pass a barrier directly into tcgen05_mma and arrive manually. + plgpu.tcgen05_mma(acc_tmem, + a_smem, + b_smem, + accumulate=False) + plgpu.tcgen05_commit_arrive(mma_barrier) + plgpu.barrier_wait(mma_barrier) + # We don't await the load because acc_tmem is never modified again. + out_gmem[...] = plgpu.async_load_tmem(acc_tmem).astype(dtype) + + f = self.kernel( + kernel, + out_shape=jax.ShapeDtypeStruct(shape, dtype), + scratch_shapes=[ + plgpu.SMEM(shape, dtype, transforms=transforms), # a_smem + plgpu.SMEM(shape, dtype, transforms=transforms), # b_smem + plgpu.Barrier(), # tma_barrier + plgpu.Barrier(orders_tensor_core=True), # mma_barrier + plgpu.TMEM((128, 128), jnp.float32), # acc + ], + ) + x = jax.random.uniform(jax.random.key(0), shape=shape, dtype=dtype) + y = jax.random.uniform(jax.random.key(1), shape=shape, dtype=dtype) + result = f(x, y) + np.testing.assert_allclose(result, x @ y, rtol=1e-3) + + def test_matmul_with_sliced_accumulator(self): + dtype = jnp.bfloat16 + shape = (128, 128) + tmem_shape = (128, 2 * 128) + swizzle = 128 + + # Test a matmul with a single block. + transforms = self.default_transforms(swizzle=swizzle, dtype=dtype) + + def kernel(a_smem, b_smem, out_ref, acc_tmem, barrier_ref): + acc_tmem_slice = acc_tmem.at[slice(None), pl.dslice(0, 128)] + plgpu.tcgen05_mma(acc_tmem_slice, + a_smem, + b_smem, + barrier_ref, + accumulate=False) + plgpu.barrier_wait(barrier_ref) + # We don't await the load because acc_tmem is never modified again. + out_ref[...] = plgpu.async_load_tmem(acc_tmem_slice).astype(dtype) + + scratch_shapes = [ + plgpu.TMEM(tmem_shape, jnp.float32, packed=False), + plgpu.Barrier(orders_tensor_core=True), + ] + + f = self.pallas_call( + kernel, + in_specs=( + plgpu.BlockSpec(transforms=transforms, memory_space=plgpu.SMEM), + plgpu.BlockSpec(transforms=transforms, memory_space=plgpu.SMEM), + ), + out_specs=plgpu.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct(shape, dtype), + scratch_shapes=scratch_shapes, + ) + x = jax.random.uniform(jax.random.key(0), shape=shape, dtype=dtype) + y = jax.random.uniform(jax.random.key(1), shape=shape, dtype=dtype) + result = f(x, y) + expected = x @ y + np.testing.assert_allclose(result, expected, rtol=1e-3) + + @parameterized.product( + m_n_k=[ + (256, 256, 256), + (256, 128, 128), + (256, 256, 64), + (128, 64, 128), + (128, 64, 128), + ], + swizzle=[128, 64, 32], + dtype=[jnp.float16, jnp.bfloat16], + lhs_tmem=[False, True], + ) + def test_simple_collective_matmul(self, m_n_k, swizzle, dtype, lhs_tmem): + m, n, k = m_n_k + if (n // 2) * jnp.dtype(dtype).itemsize < swizzle: + self.skipTest("swizzle too big") + full_lhs_shape = (m, k) + full_rhs_shape = (k, n) + full_acc_shape = (m, n) + block_acc_shape = (m // 2, n) + block_lhs_shape = (m // 2, k) + block_rhs_shape = (k, n // 2) + # Test a collective (paired CTA) matmul on a single block. + transforms = self.default_transforms(swizzle=swizzle, dtype=dtype) + if lhs_tmem and m == 128: + self.skipTest("m=128 not supported for LHS in TMEM") + + def kernel(a_gmem, b_gmem, out_gmem, a_smem, b_smem, + acc_tmem, tma_barrier, mma_barrier, + cluster_barrier, lhs_tmem_ref): + cluster_idx = lax.axis_index("x") + slice_lhs = pl.ds(cluster_idx * block_lhs_shape[0], block_lhs_shape[0]) + slice_rhs = pl.ds(cluster_idx * block_rhs_shape[1], block_rhs_shape[1]) + + plgpu.copy_gmem_to_smem(a_gmem.at[slice_lhs, :], a_smem, tma_barrier) + plgpu.barrier_wait(tma_barrier) + plgpu.copy_gmem_to_smem(b_gmem.at[:, slice_rhs], b_smem, tma_barrier) + plgpu.barrier_wait(tma_barrier) + + if lhs_tmem: + lhs_ref = lhs_tmem_ref + plgpu.async_store_tmem(lhs_ref, plgpu.load(a_smem, (), layout=plgpu.Layout.TCGEN05)) + plgpu.commit_tmem() + else: + lhs_ref = a_smem + + plgpu.barrier_arrive(cluster_barrier) + plgpu.barrier_wait(cluster_barrier) + + plgpu.tcgen05_mma( + acc_tmem, + lhs_ref, + b_smem, + mma_barrier, + accumulate=False, + collective_axis="x", + ) + plgpu.barrier_wait(mma_barrier) + if m == 128: + layout = plgpu.Layout.TCGEN05_M64_COLLECTIVE(n) + else: + layout = plgpu.Layout.TCGEN05 + # We don't await the load because acc_tmem is never modified again. + out_gmem[slice_lhs, :] = plgpu.async_load_tmem(acc_tmem, layout=layout).astype(dtype) + + scratch_shapes = [ + plgpu.SMEM(block_lhs_shape, dtype, transforms=transforms), + plgpu.SMEM(block_rhs_shape, dtype, transforms=transforms), + plgpu.TMEM(block_acc_shape, jnp.float32, collective=True), + plgpu.Barrier(), + plgpu.Barrier(orders_tensor_core=True), + plgpu.ClusterBarrier(collective_axes=("x",)), + ] + if lhs_tmem: + scratch_shapes.append( + plgpu.TMEM(block_lhs_shape, dtype, collective=True, packed=True) + ) + else: + scratch_shapes.append(None) + + f = self.kernel( + kernel, + out_shape=jax.ShapeDtypeStruct(full_acc_shape, dtype), + grid=(1,), + grid_names=("_",), + cluster=(2,), + cluster_names=("x",), + scratch_shapes=scratch_shapes, + ) + x = jax.random.uniform(jax.random.key(0), shape=full_lhs_shape, dtype=dtype) + y = jax.random.uniform(jax.random.key(1), shape=full_rhs_shape, dtype=dtype) + result = f(x, y) + expected = x @ y + np.testing.assert_allclose(result, expected, rtol=1e-3) + + @parameterized.parameters( + (128, jnp.float16) + ) + def test_matmul_with_smem_aliasing(self, swizzle, dtype): + # Perform a 128x128 @ 128x128 matmul and a 128x64 @ 64x128 matmul + # using aliased Refs pointing to the same SMEM address. + self.skip_if_wg_semantics() + shape = (128, 128) + transforms = self.default_transforms(swizzle=swizzle, dtype=dtype) + + def kernel(a_gmem, b_gmem, out_gmem128, out_gmem64, + a_aliased, b_aliased, tma_barrier, mma_barrier, acc_tmem): + # Note: We directly copy into 128-sized refs assuming that both aliased + # refs point to the same address, so we can skip the copy for + # the 64-sized ref. We transpose the LHS Ref so that the 64-sized Ref + # receives the correct slice of data from this TMA. + # As this is implementation dependent, this test may break if we change + # the underlying aliasing behavior. + a_smem_128, a_smem_64 = a_aliased + plgpu.copy_gmem_to_smem(a_gmem, a_smem_128, tma_barrier) + plgpu.barrier_wait(tma_barrier) + b_smem_128, b_smem_64 = b_aliased + plgpu.copy_gmem_to_smem(b_gmem, b_smem_128, tma_barrier) + plgpu.barrier_wait(tma_barrier) + + # Do 128x128 @ 128x128 matmul + plgpu.tcgen05_mma(acc_tmem, + plgpu.transpose_ref(a_smem_128, (1, 0)), + b_smem_128, + mma_barrier, + accumulate=False) + plgpu.barrier_wait(mma_barrier) + out_gmem128[...] = plgpu.async_load_tmem(acc_tmem).astype(dtype) + + # Do 128x64 @ 64x128 matmul + plgpu.wait_load_tmem() # Make sure the loads are complete + plgpu.tcgen05_mma(acc_tmem, + plgpu.transpose_ref(a_smem_64, (1, 0)), + b_smem_64, + mma_barrier, + accumulate=False) + plgpu.barrier_wait(mma_barrier) + out_gmem64[...] = plgpu.async_load_tmem(acc_tmem).astype(dtype) + + f = self.kernel( + kernel, + out_shape=[jax.ShapeDtypeStruct(shape, dtype), + jax.ShapeDtypeStruct(shape, dtype)], + scratch_shapes=[ + plgpu.RefUnion( # aliased a_smem + plgpu.SMEM(shape, dtype, transforms=transforms), + plgpu.SMEM((64, 128), dtype, transforms=transforms), + ), + plgpu.RefUnion( # aliased b_smem + plgpu.SMEM(shape, dtype, transforms=transforms), + plgpu.SMEM((64, 128), dtype, transforms=transforms), + ), + plgpu.Barrier(), # tma_barrier + plgpu.Barrier(orders_tensor_core=True), # mma_barrier + plgpu.TMEM(shape, jnp.float32), # acc + ], + ) + x = jax.random.uniform(jax.random.key(0), shape=shape, dtype=dtype) + y = jax.random.uniform(jax.random.key(1), shape=shape, dtype=dtype) + result_128, result_64 = f(x.T, y) + np.testing.assert_allclose(result_128, x @ y, rtol=1e-3) + np.testing.assert_allclose(result_64, x[:, :64] @ y[:64, :], rtol=1e-3) + + @parameterized.parameters( + (128, jnp.float16) + ) + def test_matmul_with_tmem_aliasing(self, swizzle, dtype): + # Perform a 128x128 @ 128x128 matmul and a 128x64 @ 64x128 matmul + # using aliased Refs pointing to the same TMEM address. + self.skip_if_wg_semantics() + shape = (128, 128) + swizzle_elems = swizzle // jnp.dtype(dtype).itemsize + transforms = ( + plgpu.TilingTransform((8, swizzle_elems)), + plgpu.SwizzleTransform(swizzle), + ) + + def kernel(a_gmem, b_gmem, out_gmem128, out_gmem64, + a_smem, b_smem, tma_barrier, mma_barrier, aliased_refs): + plgpu.copy_gmem_to_smem(a_gmem, a_smem, tma_barrier) + plgpu.barrier_wait(tma_barrier) + plgpu.copy_gmem_to_smem(b_gmem, b_smem, tma_barrier) + plgpu.barrier_wait(tma_barrier) + [acc_128, lhs_128], [lhs_64, acc_64], _ = aliased_refs + + # Do 128x128 @ 128x128 matmul + plgpu.async_store_tmem(lhs_128, plgpu.load(a_smem, (), layout=plgpu.Layout.TCGEN05)) + plgpu.commit_tmem() + plgpu.tcgen05_mma(acc_128, + lhs_128, + b_smem, + mma_barrier, + accumulate=False) + plgpu.barrier_wait(mma_barrier) + out_gmem128[...] = plgpu.async_load_tmem(acc_128).astype(dtype) + + # Do 128x64 @ 64x128 matmul + plgpu.wait_load_tmem() # Make sure the loads have completed + plgpu.async_store_tmem( + lhs_64, + plgpu.load(a_smem.at[:, 0:64], (), layout=plgpu.Layout.TCGEN05), + ) + plgpu.commit_tmem() + plgpu.tcgen05_mma(acc_64, + lhs_64, + b_smem.at[0:64, :], + mma_barrier, + accumulate=False) + plgpu.barrier_wait(mma_barrier) + # We don't await the load because TMEM is never modified again. + out_gmem64[...] = plgpu.async_load_tmem(acc_64).astype(dtype) + + f = self.kernel( + kernel, + out_shape=[ + jax.ShapeDtypeStruct(shape, dtype), + jax.ShapeDtypeStruct(shape, dtype), + ], + scratch_shapes=[ + plgpu.SMEM(shape, dtype, transforms=transforms), # a_smem + plgpu.SMEM(shape, dtype, transforms=transforms), # b_smem + plgpu.Barrier(), # tma_barrier + plgpu.Barrier(orders_tensor_core=True), # mma_barrier + plgpu.RefUnion( # aliased_refs + [ + plgpu.TMEM((128, 128), jnp.float32), # acc + plgpu.TMEM((128, 128), dtype, packed=True), # lhs + ], + [ + plgpu.TMEM((128, 64), dtype, packed=True), # lhs + plgpu.TMEM((128, 128), jnp.float32), # acc + ], + plgpu.TMEM((128, 128), jnp.float32), # unused + ), + ], + ) + x = jax.random.uniform(jax.random.key(0), shape=shape, dtype=dtype) + y = jax.random.uniform(jax.random.key(1), shape=shape, dtype=dtype) + result_128, result_64 = f(x, y) + np.testing.assert_allclose(result_128, x @ y, rtol=1e-3) + np.testing.assert_allclose(result_64, x[:, :64] @ y[:64, :], rtol=1e-3) + + @parameterized.parameters((0,), (1,)) + def test_mma_barrier_indexing( + self, barrier_index, shape=(128, 128), swizzle=128, dtype=jnp.float16 + ): + transforms = self.default_transforms(swizzle=swizzle, dtype=dtype) + + def kernel(a_smem, b_smem, out_ref, acc_tmem, barrier_ref): + plgpu.tcgen05_mma( + acc_tmem, + a_smem, + b_smem, + barrier_ref.at[barrier_index], + accumulate=False, + ) + plgpu.barrier_wait(barrier_ref.at[barrier_index]) + out_ref[...] = plgpu.async_load_tmem(acc_tmem).astype(dtype) + + scratch_shapes = [ + plgpu.TMEM(shape, jnp.float32, packed=False), + plgpu.Barrier(num_barriers=2, orders_tensor_core=True), + ] + f = self.pallas_call( + kernel, + in_specs=( + plgpu.BlockSpec(transforms=transforms, memory_space=plgpu.SMEM), + plgpu.BlockSpec(transforms=transforms, memory_space=plgpu.SMEM), + ), + out_specs=plgpu.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct(shape, dtype), + scratch_shapes=scratch_shapes, + ) + x = jax.random.uniform(jax.random.key(0), shape=shape, dtype=dtype) + y = jax.random.uniform(jax.random.key(1), shape=shape, dtype=dtype) + result = f(x, y) + expected = x @ y + np.testing.assert_allclose(result, expected, rtol=1e-3) + + @parameterized.product( + warp_level=(True, False), + squeezed_index=(True, False), + ) + def test_copy_gmem_to_smem_partitioned(self, warp_level, squeezed_index): + self.skip_if_wg_semantics() # `pl.core_map` not implemented for warpgroup. + block_size = (128, 128) + partitioned_block_size = (block_size[0] // 2, block_size[1]) + a = jax.random.uniform( + jax.random.key(0), shape=block_size, dtype=jnp.float32) + if squeezed_index: + a = a.reshape(1, *block_size) + b = jax.random.uniform( + jax.random.key(1), shape=block_size, dtype=jnp.float32) + def kernel(a_gmem, b_gmem, out_gmem, + a_smem, b_smem, + a_tma_barrier, b_tma_barrier, cluster_barrier): + if squeezed_index: + a_gmem = a_gmem.at[0] + cluster_idx = lax.axis_index("x") + out_slice = pl.ds(cluster_idx * partitioned_block_size[0], + partitioned_block_size[0]) + + if warp_level: + @pl.core_map(plgpu.WarpMesh(axis_name="warp")) + def _per_warp(): + warp_id = lax.axis_index("warp") + @pl.when(warp_id == 0) + def _(): + plgpu.copy_gmem_to_smem( + a_gmem, + a_smem, + a_tma_barrier, + collective_axes="x", + partitioned_axis=0, + ) + plgpu.copy_gmem_to_smem( + b_gmem, + b_smem, + b_tma_barrier, + collective_axes="x", + partitioned_axis=0, + ) + else: + plgpu.copy_gmem_to_smem( + a_gmem, + a_smem, + a_tma_barrier, + collective_axes="x", + partitioned_axis=0, + ) + plgpu.copy_gmem_to_smem( + b_gmem, + b_smem, + b_tma_barrier, + collective_axes="x", + partitioned_axis=0, + ) + @pl.when(cluster_idx == 0) + def _(): + plgpu.barrier_wait(a_tma_barrier) + plgpu.barrier_wait(b_tma_barrier) + plgpu.barrier_arrive(cluster_barrier) + plgpu.barrier_wait(cluster_barrier) + out_gmem[out_slice] = a_smem[...] + b_smem[...] + f = self.kernel( + kernel, + out_shape=jax.ShapeDtypeStruct(block_size, jnp.float32), + grid=(1,), + grid_names=("_"), + cluster_names=("x",), + cluster=(2,), + scratch_shapes=( # type: ignore + plgpu.SMEM(partitioned_block_size, jnp.float32), + plgpu.SMEM(partitioned_block_size, jnp.float32), + plgpu.Barrier(num_arrivals=1), + plgpu.Barrier(num_arrivals=1), + plgpu.ClusterBarrier(collective_axes=("x",)), + ), + ) + result = f(a, b) + if squeezed_index: + a = a[0] + np.testing.assert_array_equal(result, a + b) + + def test_arrive_wait_on_tc_barrier(self): + def kernel(out_ref, barrier): + plgpu.barrier_arrive(barrier) + plgpu.barrier_wait(barrier) + out_ref[...] = jnp.ones_like(out_ref) + + f = self.kernel( + kernel, + out_shape=jax.ShapeDtypeStruct((128,), jnp.float32), + scratch_shapes=( # type: ignore + plgpu.Barrier(num_arrivals=1, orders_tensor_core=True), + ), + ) + np.testing.assert_array_equal(f(), np.ones((128,), np.float32)) + + @parameterized.parameters( + (0, (1,), False), + (0, (1,), True), + (1, (1,), False), + (2, (1,), False), + (0, (1, 2,), False), + (0, (2, 1,), False), + ) + def test_cluster_launch_control(self, dim, cluster, with_indexing): + self.skip_if_wg_semantics() + # We attempt to schedule 1 more CTA than can be scheduled at once. Only + # one CTA will succeed in stealing the last block, and the others will + # fail. Therefore we test that there is exactly 1 stolen block and the + # others fail and return -1. + + num_sms = jax.devices()[0].core_count + cluster_size = math.prod(cluster) + + grid = [1, 1, 1] + grid[dim] = num_sms // cluster_size + 1 + + grid_names = tuple("xyz"[: len(grid)]) + cluster_names = tuple("abc"[: len(cluster)]) + + def kernel(out_ref, cancel_result_ref, barrier, _): + if with_indexing: + cancel_result_ref = cancel_result_ref.at[0] + plgpu.try_cluster_cancel(cancel_result_ref, barrier) + plgpu.barrier_wait(barrier) + + cta_ids, cancelled_launch = plgpu.query_cluster_cancel( + cancel_result_ref, grid_names=grid_names) + cta_id = sum(cta_ids) + + # Store a sentinel value if no work can be scheduled. + value = lax.select(cancelled_launch, cta_id, jnp.int32(-1)) + + grid_idx = lax.axis_index(grid_names) * lax.axis_size( + cluster_names + ) + lax.axis_index(cluster_names) + out_ref[grid_idx] = value + + f = self.kernel( + kernel, + out_shape=jax.ShapeDtypeStruct((num_sms,), jnp.int32), + grid=grid, + grid_names=grid_names, + num_threads=2, + thread_name="wg", + cluster=cluster, + cluster_names=cluster_names, + scratch_shapes=[ + plgpu.TryClusterCancelResult(2 if with_indexing else None), + plgpu.Barrier(num_arrivals=2), + # Requesting SMEM close to the 228kb limit to ensure that each SM + # only schedules 1 block. + plgpu.SMEM((220 * 1024,), jnp.int8), + ], + ) + result = np.sort(f()) + last_cta_id = math.ceil(num_sms / cluster_size) + expected = np.array([-1] * (num_sms - cluster_size) + [last_cta_id] * cluster_size) + np.testing.assert_equal(result, expected) + + +class PallasCallSm100AWGTest( + PallasCallSm100ATest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup +): + ... + + +class PipelineTest(PallasTest): + + def test_pipeline_mode(self): + def body(x_ref, y_ref, o_ref): + x = x_ref[:] + y = y_ref[:] + o_ref[:] = x + y + + data_size = 64 * 256 + block_size = 256 + + x = jnp.arange(data_size, dtype=jnp.float32) + y = jnp.arange(data_size, dtype=jnp.float32) + in_specs = [ + pl.BlockSpec((block_size,), lambda *i: i, pipeline_mode=pl.Buffered(2)), + pl.BlockSpec((block_size,), lambda *i: i, pipeline_mode=pl.Buffered(1)) + ] + out_specs = pl.BlockSpec((block_size,), lambda *i: i) + + @jax.jit + def vadd(x, y): + return self.pallas_call( + body, + out_shape=jax.ShapeDtypeStruct(x.shape, jnp.float32), + in_specs=in_specs, + out_specs=out_specs, + grid=data_size // block_size, + )(x, y) + + with self.assertRaisesRegex(Exception, "Pipeline mode is not supported"): + vadd(x, y) + + def test_manual(self): + max_concurrent_steps = 2 + num_steps = 4 + + def kernel(x_gmem, o_gmem): + return pl.run_scoped( + functools.partial(scoped_kernel, x_gmem, o_gmem), + plgpu.SMEM((max_concurrent_steps, 32, 16), jnp.float32), + plgpu.SMEM((max_concurrent_steps, 32, 16), jnp.float32), + plgpu.Barrier(num_barriers=max_concurrent_steps), + ) + + def scoped_kernel(x_gmem, o_gmem, x_smem, o_smem, barrier): + gmem_slice = pl.ds(pl.program_id(0) * 32, 32) + + def body(step, _): + slot = step % max_concurrent_steps + + # Wait for the current GMEM->SMEM copy to complete. + plgpu.barrier_wait(barrier.at[slot]) + # Wait for the previous output SMEM->GMEM copy to complete. + plgpu.wait_smem_to_gmem(max_concurrent_steps - 1) + + o_smem.at[slot][...] = x_smem.at[slot][...] + 1.0 + + plgpu.commit_smem() + plgpu.copy_smem_to_gmem( + o_smem.at[slot], o_gmem.at[gmem_slice, pl.ds(step * 16, 16)] + ) + + fetch_step = step + max_concurrent_steps + fetch_slot = slot # (x + y) % y == x % y + jax.lax.cond( + fetch_step < num_steps, + lambda: plgpu.copy_gmem_to_smem( + x_gmem.at[gmem_slice, pl.ds(fetch_step * 16, 16)], + x_smem.at[fetch_slot], + barrier.at[fetch_slot], + ), + lambda: None, + ) + return () + + # Initialize the pipeline. + for slot in range(min(max_concurrent_steps, num_steps)): + plgpu.copy_gmem_to_smem( + x_gmem.at[gmem_slice, pl.ds(slot * 16, 16)], + x_smem.at[slot], + barrier.at[slot], + ) + + jax.lax.fori_loop(0, num_steps, body, ()) + + # Finalize the pipeline. + plgpu.wait_smem_to_gmem(0) + + x = jnp.arange(32 * 4 * 64).reshape(32 * 4, 64).astype(jnp.float32) + kernel_fn = self.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + grid=(4, 1), + ) + np.testing.assert_array_equal(kernel_fn(x), x + 1.0) + + @parameterized.product( + transforms=( + (), + (plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(128)), + ), + repeats=(1, 10), + ) + def test_emit(self, transforms, repeats): + if transforms: + self.skip_if_wg_semantics() + + num_steps = 4 + + def kernel(x_gmem, o_gmem): + for _ in range(repeats): + plgpu.emit_pipeline( + kernel_body, + in_specs=[ + plgpu.BlockSpec( + (64, 64), lambda i: (0, i), transforms=transforms + ) + ], + out_specs=[ + plgpu.BlockSpec( + (64, 64), lambda i: (0, i), transforms=transforms + ) + ], + grid=(num_steps,), + max_concurrent_steps=2, + )(x_gmem, o_gmem) + + def kernel_body(_, x_smem, o_smem): + # +1 for the indexing done by ``emit_pipeline`. + self.assertLen(x_smem.transforms, len(transforms) + 1) + o_smem[...] = x_smem[...] + 1.0 + + x = jnp.arange(64 * num_steps * 64) + x = x.reshape(-1, num_steps * 64).astype(jnp.float32) + kernel_fn = self.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + ) + np.testing.assert_array_equal(kernel_fn(x), x + 1.0) + + def test_nested_emit(self): + num_steps = 4 + + def kernel(x_gmem, o_gmem): + plgpu.emit_pipeline( + nested_kernel, + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + grid=(), + )(x_gmem, o_gmem) + + def nested_kernel(_, x_gmem, o_gmem): + plgpu.emit_pipeline( + nested_kernel_body, + in_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))], + out_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))], + grid=(num_steps,), + max_concurrent_steps=2, + )(x_gmem, o_gmem) + + def nested_kernel_body(_, x_smem, o_smem): + o_smem[...] = x_smem[...] + 1.0 + + x = jnp.arange(32 * num_steps * 16) + x = x.reshape(-1, num_steps * 16).astype(jnp.float32) + kernel_fn = self.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + ) + np.testing.assert_array_equal(kernel_fn(x), x + 1.0) + + def test_emit_with_grid_invariant_output(self): + num_steps = 4 + + def kernel(x_gmem, o_gmem): + plgpu.emit_pipeline( + kernel_body, + in_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))], + out_specs=[pl.BlockSpec((32, 16), lambda i: (0, 0))], + grid=(num_steps,), + max_concurrent_steps=2, + )(x_gmem, o_gmem) + + def kernel_body(_, x_smem, o_smem): + o_smem[...] = x_smem[...] + 1.0 + + x = jnp.arange(32 * num_steps * 16) + x = x.reshape(-1, num_steps * 16).astype(jnp.float32) + kernel_fn = self.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + ) + y = jnp.empty_like(x) + for i in range(num_steps): + i_slice = slice(16 * i, 16 * (i + 1)) + y = y.at[:, :16].set(x[:, i_slice] + 1) + # We only compare the elements in the first 16 columns, because the rest + # are never written to. + np.testing.assert_array_equal(kernel_fn(x)[:, :16], y[:, :16]) + + def test_emit_with_no_output(self): + m, n = 16, 128 + + def kernel(x_gmem, o_gmem): + def acc_scope(acc_ref): + acc_ref[...] = jnp.zeros_like(acc_ref) + def body(_, x_smem): + acc_ref[...] += x_smem[...] # Can't += in a lambda... + in_specs = [plgpu.BlockSpec((1, n), lambda i: (i, 0), delay_release=1)] + plgpu.emit_pipeline( + body, + in_specs=in_specs, + grid=(m,), + max_concurrent_steps=2, + )(x_gmem) + return acc_ref[...] + + acc = pl.run_scoped(acc_scope, plgpu.SMEM((1, n), dtype=jnp.float32)) + + o_gmem[...] = acc[...] + + dtype = jnp.float32 + x = jax.random.uniform(jax.random.key(0), (m, n)).astype(dtype) + + kernel_fn = self.kernel( + kernel, + out_shape=jax.ShapeDtypeStruct((1, n), dtype), + ) + + np.testing.assert_allclose(kernel_fn(x), x.sum(0, keepdims=True), rtol=1e-6) + + def test_emit_with_parallel_grid(self): + num_steps1 = 4 + num_steps2 = 5 + + def kernel(x_gmem, o_gmem): + pid = pl.program_id(0) + plgpu.emit_pipeline( + kernel_body, + in_specs=[pl.BlockSpec((32, 16), lambda i: (pid, i))], + out_specs=[pl.BlockSpec((32, 16), lambda i: (pid, i))], + grid=(num_steps2,), + max_concurrent_steps=2, + )(x_gmem, o_gmem) + + def kernel_body(_, x_smem, o_smem): + o_smem[...] = x_smem[...] + 1.0 + + x = jnp.arange(num_steps1 * 32 * num_steps2 * 16) + x = x.reshape(-1, num_steps2 * 16).astype(jnp.float32) + kernel_fn = self.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + grid=(num_steps1,), + ) + y = x + 1.0 + np.testing.assert_array_equal(kernel_fn(x), y) + + def test_emit_with_dynamic_grid_smaller_than_concurrent_steps(self): + block_x = 128 + x = jax.random.randint(jax.random.key(1234), (block_x,), + minval=-128, maxval=128, dtype=jnp.int32) + + def body(num_blocks_gmem, x_gmem, o_gmem): + num_blocks = num_blocks_gmem[...] + def pipeline_body(_, x_smem, o_smem): + o_smem[...] = x_smem[...] + for _ in range(2): + plgpu.emit_pipeline( + pipeline_body, + grid=(num_blocks,), + in_specs=[plgpu.BlockSpec((block_x,), lambda i: (i,))], + out_specs=[plgpu.BlockSpec((block_x,), lambda i: (i,))], + max_concurrent_steps=2, + )(x_gmem, o_gmem) + + # The test only intends to check that this does not crash/hang. + plgpu.kernel( + body, + out_shape=jax.ShapeDtypeStruct((block_x,), jnp.int32), + grid=(1,), + grid_names=("blocks",) + )(0, x).block_until_ready() + + @parameterized.product(static=[False, True], short=[False, True]) + def test_emit_with_2d_grid(self, static, short): + num_steps1 = 4 + num_steps2 = 5 + if short: + num_steps1 = num_steps2 = 1 + + def kernel(x_gmem, o_gmem): + grid = (num_steps1, num_steps2) + if static: + grid = jax.tree.map(jnp.asarray, grid) + + plgpu.emit_pipeline( + kernel_body, + in_specs=[pl.BlockSpec((32, 16, 8), lambda i, j: (0, i, j))], + out_specs=[pl.BlockSpec((32, 16, 8), lambda i, j: (0, i, j))], + grid=grid, + max_concurrent_steps=2, + )(x_gmem, o_gmem) + + def kernel_body(_, x_smem, o_smem): + o_smem[...] = x_smem[...] + 1.0 + + x = jnp.arange(32 * num_steps1 * 16 * num_steps2 * 8) + x = x.reshape(-1, num_steps1 * 16, num_steps2 * 8).astype(jnp.float32) + kernel_fn = self.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + ) + np.testing.assert_array_equal(kernel_fn(x), x + 1.0) + + def test_emit_with_carry(self): + num_steps = 4 + + def kernel(o_gmem): + plgpu.emit_pipeline( + kernel_body, + out_specs=[pl.BlockSpec((64, 64), lambda i: (0, i))], + grid=(num_steps,), + max_concurrent_steps=2, + init_carry=0, + )(o_gmem) + + def kernel_body(_, o_smem, carry): + o_smem[...] = lax.broadcast(carry, o_smem.shape) + return carry + 1 + + kernel_fn = self.pallas_call( + kernel, + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct((64, num_steps * 64), jnp.int32), + ) + np.testing.assert_array_equal( + kernel_fn(), jnp.tile(jnp.repeat(jnp.arange(num_steps), 64), (64, 1)) + ) + + @parameterized.parameters((pl.Squeezed(),), (None,)) + def test_emit_with_squeezed_dim(self, squeezed_dim): + + shape = (16, 256) + num_steps = shape[0] + + def kernel(x_gmem, o_gmem): + plgpu.emit_pipeline( + kernel_body, + in_specs=[pl.BlockSpec((squeezed_dim, shape[1]), lambda i: (i, 0))], + out_specs=[pl.BlockSpec((squeezed_dim, shape[1]), lambda i: (i, 0))], + grid=(num_steps,), + max_concurrent_steps=2, + )(x_gmem, o_gmem) + + def kernel_body(_, in_smem, o_smem): + assert in_smem.shape == (shape[1],) + assert o_smem.shape == (shape[1],) + o_smem[...] = in_smem[...] + 1 + + kernel_fn = self.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct((16, 256), jnp.int32), + ) + x = jnp.arange(16 * 256, dtype=jnp.int32).reshape(16, 256) + np.testing.assert_array_equal(kernel_fn(x), x + 1) + + +class PipelineWGTest( + PipelineTest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup +): + ... + + +class PipelineSm90ATest(PallasSm90ATest): + + def test_realistic_matmul(self): + dtype = jnp.float16 + swizzle = 128 + elems_128b = swizzle // jnp.dtype(dtype).itemsize + grid_m, grid_k, grid_n = 132, 10, 4 + tile_m = tile_n = 128 + assert tile_m % elems_128b == 0 + tile_k = elems_128b + m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n + + transforms = self.default_transforms(swizzle=swizzle, dtype=dtype) + + def kernel(a_gmem, b_gmem, o_smem, acc): + def kernel_body(_, a_smem, b_smem): + assert a_smem.shape == (tile_m, tile_k) + assert b_smem.shape == (tile_k, tile_n) + plgpu.wgmma(acc, a_smem, b_smem) + plgpu.wgmma_wait(1) + + pid_m = pl.program_id(0) + pid_n = pl.program_id(1) + in_specs = [ + plgpu.BlockSpec( + (tile_m, tile_k), lambda k: (pid_m, k), transforms=transforms, + delay_release=1, + ), + plgpu.BlockSpec( + (tile_k, tile_n), lambda k: (k, pid_n), transforms=transforms, + delay_release=1, + ), + ] + plgpu.emit_pipeline( + kernel_body, + in_specs=in_specs, + grid=(grid_k,), + max_concurrent_steps=2, + )(a_gmem, b_gmem) + + o_smem[...] = acc[...].astype(dtype) + + key1, key2 = jax.random.split(jax.random.key(42), 2) + a = jax.random.uniform(key1, shape=(m, k), dtype=dtype) + b = jax.random.uniform(key2, shape=(k, n), dtype=dtype) + + res = self.pallas_call( + kernel, + in_specs=[ + pl.BlockSpec(memory_space=plgpu.GMEM), + pl.BlockSpec(memory_space=plgpu.GMEM), + ], + out_specs=plgpu.BlockSpec( + (tile_m, tile_n), lambda m, n: (m, n), transforms=transforms + ), + out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16), + scratch_shapes=[plgpu.ACC((tile_m, tile_n), jnp.float32)], + grid=(grid_m, grid_n), + )(a, b) + np.testing.assert_array_equal(res, a @ b) + + +class PipelineSm90AWGTest( + PipelineSm90ATest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup +): + ... + + +class WarpSpecializedPipelineTest(PallasTest): + + @parameterized.product(m=[512], n=[512], repeats=[1, 10], + manual_consumed_barriers=[False, True], + max_concurrent_steps=[2, 3]) + def test_pipelined_copy( + self, m, n, repeats, manual_consumed_barriers, max_concurrent_steps + ): + x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float16) + blk_m = blk_n = 32 + + def copy_kernel(_, x_smem, o_smem, o_last_block_smem, *consumed_barriers): + wg_idx = lax.axis_index("wg") + o_smem[...] = x_smem[...] + o_last_block_smem[...] = x_smem[...] + if manual_consumed_barriers: + [x_barrier] = consumed_barriers + plgpu.barrier_arrive(x_barrier) + + spec = pl.BlockSpec( + block_shape=(2 * blk_m, blk_n), index_map=lambda i, j: (i, j) + ) + def body(*gmem_refs): + pipeline = mgpu_pipeline.emit_pipeline_warp_specialized( + copy_kernel, + grid=(m // (2 * blk_m), n // blk_n), + memory_registers=40, + max_concurrent_steps=max_concurrent_steps, + num_compute_wgs=1, + wg_axis="wg", + manual_consumed_barriers=manual_consumed_barriers, + in_specs=[spec], + out_specs=[ + spec, + # Create an index-invariant output. + pl.BlockSpec( + block_shape=(2 * blk_m, blk_n), index_map=lambda i, j: (0, 0) + ), + ], + ) + for _ in range(repeats): + pipeline(*gmem_refs) # Make sure we can run the pipeline multiple times + kernel = self.kernel( + body, + out_shape=( + jax.ShapeDtypeStruct((m, n), jnp.float16), + jax.ShapeDtypeStruct((2 * blk_m, blk_n), jnp.float16), + ), + compiler_params=plgpu.CompilerParams(approx_math=True), + grid=(1,), + grid_names=("_",), + num_threads=2, + thread_name="wg", + ) + out, out_last_block = kernel(x) + np.testing.assert_array_equal(out, x) + np.testing.assert_array_equal(out_last_block, x[-(2 * blk_m):, -blk_n:]) + + @parameterized.product( + m=[256, 64], + n=[256, 64], + num_compute_wgs=[1], # TODO(apaszke): Use 2WGs once we add support for outputs. + static=[False, True], + manual_consumed_barriers=[False, True], + in_tree_template=[(0, 1), ((0, (1,), None))], + ) + @jtu.skip_if_mosaic_gpu_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell") + def test_elementwise_add(self, m, n, num_compute_wgs, static, + manual_consumed_barriers, in_tree_template): + blk_m = blk_n = 64 + if m % (num_compute_wgs * blk_m): + self.skipTest(f"{m=} must be divisible by {num_compute_wgs=} * {blk_m=}") + spec = pl.BlockSpec( + block_shape=(num_compute_wgs * blk_m, blk_n), + index_map=lambda i, j: (i, j), + ) + in_treedef = jax.tree.structure(in_tree_template) + in_specs = jax.tree.unflatten(in_treedef, (spec, spec)) + + def tiled_add_kernel(_, *smems): + flat_smems, _ = jax.tree.flatten(smems) + x_smem, y_smem, o_smem, *consumed_barriers = flat_smems + + wg_idx = lax.axis_index("wg") + m_slice = pl.ds(wg_idx * blk_m, blk_m) + o_smem[m_slice] = x_smem[m_slice] + y_smem[m_slice] + if manual_consumed_barriers: + [x_consumed_barrier, y_consumed_barrier] = consumed_barriers + plgpu.barrier_arrive(x_consumed_barrier) + plgpu.barrier_arrive(y_consumed_barrier) + + def pipeline(*gmem_refs): + grid = (m // (num_compute_wgs * blk_m), n // blk_n) + if not static: + grid = jax.tree.map(jnp.asarray, grid) + return mgpu_pipeline.emit_pipeline_warp_specialized( + tiled_add_kernel, + grid=grid, + max_concurrent_steps=2, + num_compute_wgs=num_compute_wgs, + memory_registers=40, + wg_axis="wg", + in_specs=in_specs, + out_specs=[spec], + manual_consumed_barriers=manual_consumed_barriers, + )(*gmem_refs) + + kernel = self.kernel( + pipeline, + out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32), + compiler_params=plgpu.CompilerParams(approx_math=True), + grid=(1,), + grid_names=("_",), + num_threads=num_compute_wgs + 1, + thread_name="wg", + ) + x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float32) + y = jax.random.uniform(jax.random.key(1), (m, n), dtype=jnp.float32) + inputs = jax.tree.unflatten(in_treedef, (x, y)) + np.testing.assert_allclose(kernel(*inputs), x + y, atol=1e-4) + + def test_carry_accumulate(self, m=256, n=256, num_compute_wgs=2): + blk_m = blk_n = 64 + + @functools.partial( + self.kernel, + out_shape=jax.ShapeDtypeStruct((blk_m, blk_n), jnp.float32), + scratch_shapes=[ + plgpu.SMEM((blk_m, blk_n), jnp.float32), + ], + compiler_params=plgpu.CompilerParams(approx_math=True), + grid=(1,), + grid_names=("_",), + num_threads=num_compute_wgs + 1, + thread_name="wg", + ) + def kernel(x_gmem, acc_gmem, acc_smem): + def _compute_thread(pipeline_fn): + # Cast the init value to the same layout as x_smem, so the pipeline loop + # carry has a constant signature. + o_acc = plgpu.layout_cast( + jnp.full((blk_m, blk_n,), 0, dtype=jnp.float32), + plgpu.Layout.WG_STRIDED((blk_m, blk_n), vec_size=2)) + # Pass control to the pipeline emitter and return the final carry. + o_final = pipeline_fn(o_acc) + # Note that both compute WGs are doing identical work so the potential + # race condition on the store here won't affect the result. + acc_smem[...] = o_final + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(acc_smem, acc_gmem) + plgpu.wait_smem_to_gmem(0) + + def tiled_acc_kernel(_, x_smem, carry): + new_carry = x_smem[...] + carry + return new_carry + + pipeline = mgpu_pipeline.emit_pipeline_warp_specialized( + tiled_acc_kernel, + grid=(m // blk_m, n // blk_n), + max_concurrent_steps=2, + num_compute_wgs=num_compute_wgs, + memory_registers=40, + wg_axis="wg", + compute_context=_compute_thread, + in_specs=[ + pl.BlockSpec( + block_shape=(blk_m, blk_n), index_map=lambda i, j: (i, j) + ) + ], + out_specs=[], + ) + pipeline(x_gmem) + + x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float32) + ref = jnp.sum(jnp.stack(np.split(x, m // blk_m, axis=0)), axis=0) + ref = jnp.sum(jnp.stack(np.split(ref, n // blk_n, axis=1)), axis=0) + np.testing.assert_allclose(kernel(x), ref, atol=1e-4) + + @parameterized.product( + num_compute_wgs=[1], # TODO(apaszke): Use 2WGs once we add support for outputs. + static=[False, True], + manual_consumed_barriers=[False, True], + small_shape=[True, False], + max_concurrent_steps=[2, 3, 4], + ) + @jtu.skip_if_mosaic_gpu_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell") + def test_delay_release( + self, num_compute_wgs, static, manual_consumed_barriers, small_shape, + max_concurrent_steps + ): + if small_shape: + m = n = 64 + else: + m = n = 256 + blk_m, blk_n = 32, 64 + spec = plgpu.BlockSpec( + block_shape=(num_compute_wgs * blk_m, blk_n), + index_map=lambda i, j: (i, j), + delay_release=1, + ) + out_spec = pl.BlockSpec( + block_shape=(num_compute_wgs * blk_m, blk_n), + index_map=lambda i, j: (i, j), + ) + + def tiled_add_kernel(idx, x_smem, y_smem, o_smem, *consumed_barriers): + wg_idx = lax.axis_index("wg") + m_slice = pl.ds(wg_idx * blk_m, blk_m) + o_smem[m_slice] = x_smem[m_slice] + y_smem[m_slice] + if manual_consumed_barriers: + @pl.when(jnp.logical_or(idx[0] != 0, idx[1] != 0)) + def _signal_consumed(): + for b in consumed_barriers: + plgpu.barrier_arrive(b) + + def pipeline(*gmem_refs): + grid = (m // (num_compute_wgs * blk_m), n // blk_n) + if not static: + grid = jax.tree.map(jnp.asarray, grid) + return mgpu_pipeline.emit_pipeline_warp_specialized( + tiled_add_kernel, + grid=grid, + max_concurrent_steps=max_concurrent_steps, + manual_consumed_barriers=manual_consumed_barriers, + num_compute_wgs=num_compute_wgs, + memory_registers=40, + wg_axis="wg", + in_specs=[spec, spec], + out_specs=[out_spec], + )(*gmem_refs) + + kernel = self.kernel( + pipeline, + out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32), + compiler_params=plgpu.CompilerParams(approx_math=True), + grid=(1,), + grid_names=("_",), + num_threads=num_compute_wgs + 1, + thread_name="wg", + ) + x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float32) + y = jax.random.uniform(jax.random.key(1), (m, n), dtype=jnp.float32) + np.testing.assert_allclose(kernel(x, y), x + y, atol=1e-4) + + def test_different_delay_release(self): + m, n = 128, 64 + blk_m, blk_n = 32, 32 + in_specs = [ + plgpu.BlockSpec( + block_shape=(blk_m, blk_n), + index_map=lambda i, j: (i, j), + delay_release=delay, + ) + for delay in range(3) + ] + out_spec = pl.BlockSpec( + block_shape=(blk_m, blk_n), + index_map=lambda i, j: (i, j), + ) + + def tiled_add_kernel(_, x_smem, y_smem, z_smem, o_smem): + o_smem[...] = x_smem[...] + y_smem[...] + z_smem[...] + + def pipeline(*gmem_refs): + grid = (m // blk_m, n // blk_n) + return mgpu_pipeline.emit_pipeline( + tiled_add_kernel, + grid=grid, + max_concurrent_steps=4, + in_specs=in_specs, + out_specs=[out_spec], + )(*gmem_refs) + + kernel = self.kernel( + pipeline, + out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32), + grid=(1,), + grid_names=("_",) + ) + x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float32) + y = jax.random.uniform(jax.random.key(1), (m, n), dtype=jnp.float32) + z = jax.random.uniform(jax.random.key(3), (m, n), dtype=jnp.float32) + np.testing.assert_allclose(kernel(x, y, z), x + y + z) + + @parameterized.product( + delay_release=[0, 1], + ) + def test_repeat(self, delay_release): + num_steps = 4 + + def kernel_body(_, x_smem, o_smem): + o_smem[...] = x_smem[...] + 1.0 + + def kernel(x_gmem, o_gmem): + in_specs = [ + plgpu.BlockSpec((64, 64), lambda i: (0, i), delay_release=delay_release) + ] + out_specs = [plgpu.BlockSpec((64, 64), lambda i: (0, i))] + for _ in range(3): + plgpu.emit_pipeline_warp_specialized( + kernel_body, + in_specs=in_specs, + out_specs=out_specs, + grid=(num_steps,), + max_concurrent_steps=2, + num_compute_wgs=1, + memory_registers=40, + wg_axis="wg", + )(x_gmem, o_gmem) + + x = jnp.arange(64 * num_steps * 64) + x = x.reshape(-1, num_steps * 64).astype(jnp.float32) + kernel_fn = self.kernel( + kernel, + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + grid=(1,), + grid_names=("_",), + num_threads=2, + thread_name="wg", + ) + np.testing.assert_array_equal(kernel_fn(x), x + 1.0) + + @parameterized.parameters((False,), (True,)) + def test_stationary_input(self, flip): + m = n = 256 + blk_m = blk_n = 64 + + def add_kernel(_, x_smem, y_smem, o_smem): + if flip: + x_smem, y_smem = y_smem, x_smem + o_smem[...] = x_smem[...] + y_smem[...] + + def body(*gmem_refs): + mgpu_pipeline.emit_pipeline_warp_specialized( + add_kernel, + grid=(m // blk_m, n // blk_n), + memory_registers=40, + max_concurrent_steps=2, + num_compute_wgs=1, + wg_axis="wg", + in_specs=[ + pl.BlockSpec( + block_shape=(blk_m, blk_n), index_map=lambda i, j: (i, j) + ), + pl.BlockSpec( + block_shape=(blk_m, blk_n), index_map=lambda i, j: (0, 0) + ) + ][::(-1 if flip else 1)], + out_specs=[ + pl.BlockSpec( + block_shape=(blk_m, blk_n), index_map=lambda i, j: (i, j) + ), + ], + )(*gmem_refs) + kernel = self.kernel( + body, + out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16), + grid=(1,), + grid_names=("_",), + num_threads=2, + thread_name="wg", + ) + x = jax.random.uniform(jax.random.key(1), (m, n), dtype=jnp.float16) + y = jax.random.uniform(jax.random.key(2), (blk_m, blk_n), dtype=jnp.float16) + ref = x + np.tile(y, (m // blk_m, n // blk_n)) + if flip: + x, y = y, x + # TODO(apaszke,justinfu): Fix the bug (this test freezes) and remove this restriction. + with self.assertRaisesRegex( + NotImplementedError, + "Only inputs with a dependency on the grid are supported.", + ): + out = kernel(x, y) + np.testing.assert_array_equal(out, ref) + + def test_no_output(self): + m = n = 256 + blk_m = blk_n = 64 + + def body(x_ref, o_ref, o_scratch, barrier): + @pl.when(lax.axis_index("wg") == 0) + def _(): + o_scratch[...] = jnp.zeros_like(o_scratch) + + # Wait for scratch to be initialized + plgpu.barrier_arrive(barrier) + plgpu.barrier_wait(barrier) + + # Make sure we can run the pipeline many times. This also introduces + # extra jitter into warp scheduling and has uncovered bugs in the past. + @pl.loop(0, 10) + def _pipeline_loop(_): + def add(_, x_smem): + slc = pl.ds(lax.axis_index("wg") * (blk_m // 2), blk_m // 2) + o_scratch[slc] += x_smem[slc] + mgpu_pipeline.emit_pipeline_warp_specialized( + add, + grid=(m // blk_m, n // blk_n), + memory_registers=40, + max_concurrent_steps=2, + num_compute_wgs=2, + wg_axis="wg", + in_specs=[ + pl.BlockSpec( + block_shape=(blk_m, blk_n), index_map=lambda i, j: (i, j) + ), + ] + )(x_ref) + + # Wait for both compute WGs to finish initializing the output + plgpu.barrier_arrive(barrier) + plgpu.barrier_wait(barrier) + + @pl.when(lax.axis_index("wg") == 0) + def _(): + plgpu.copy_smem_to_gmem(o_scratch, o_ref) + plgpu.wait_smem_to_gmem(0, wait_read_only=True) + + kernel = self.kernel( + body, + out_shape=jax.ShapeDtypeStruct((blk_m, blk_n), jnp.float32), + num_threads=3, + thread_name="wg", + scratch_shapes=[ + plgpu.SMEM((blk_m, blk_n), jnp.float32), + plgpu.Barrier(num_arrivals=3), + ], + ) + x = jax.random.uniform(jax.random.key(1234), (m, n), dtype=jnp.float32) + ref = 10 * x.reshape(m // blk_m, blk_m, n // blk_n, blk_n).sum((0, 2)) + np.testing.assert_allclose(kernel(x), ref, rtol=5e-6) + + @parameterized.product(manual_consumed_barriers=[False, True]) + def test_pipelined_pipeline(self, manual_consumed_barriers): + m = n = 512 + + x = jax.random.randint(jax.random.key(0), (m, n), -10, 15, dtype=jnp.int32) + blk_m = blk_n = 64 - with self.assertRaisesRegex(Exception, "Pipeline mode is not supported"): - vadd(x, y) + def body(x_ref, out_gmem_ref, out_ref): + wg_idx = jax.lax.axis_index("wg") + @pl.when(wg_idx == 0) + def _zero_output(): + out_ref[...] = jnp.zeros_like(out_ref) - def test_manual(self): - max_concurrent_steps = 2 - num_steps = 4 + def pipeline_body(_, x_smem, *consumed_barriers): + out_ref[...] += x_smem[...] + if manual_consumed_barriers: + [x_barrier] = consumed_barriers + plgpu.barrier_arrive(x_barrier) - def kernel(x_gmem, o_gmem): - return pl.run_scoped( - functools.partial(scoped_kernel, x_gmem, o_gmem), - plgpu.SMEM((max_concurrent_steps, 32, 16), jnp.float32), - plgpu.SMEM((max_concurrent_steps, 32, 16), jnp.float32), - plgpu.Barrier(1, num_barriers=max_concurrent_steps), + spec = pl.BlockSpec( + block_shape=(blk_m, blk_n), index_map=lambda i, j: (i, j) + ) + pipeline = functools.partial( + mgpu_pipeline.emit_pipeline_warp_specialized, + body=pipeline_body, + grid=(m // blk_m, n // blk_n), + memory_registers=40, + max_concurrent_steps=2, + num_compute_wgs=1, + wg_axis="wg", + manual_consumed_barriers=manual_consumed_barriers, + in_specs=[spec], ) - def scoped_kernel(x_gmem, o_gmem, x_smem, o_smem, barrier): - gmem_slice = pl.ds(pl.program_id(0) * 32, 32) - - def body(step, _): - slot = step % max_concurrent_steps + @functools.partial( + pl.run_scoped, + allocs=pipeline(pipeline_state=None).get_allocations(x_ref), + collective_axes="wg", + ) + def _pipeline_scope(allocs): + @pl.loop(0, 2) + def _outer_loop(_): + @pl.loop(0, 4) + def _pipeline_loop(i): + state = plgpu.PipelinePipeline.START + state = jnp.where(i > 0, plgpu.PipelinePipeline.STEADY, state) + state = jnp.where(i == 3, plgpu.PipelinePipeline.STOP, state) + pipeline(pipeline_state=state)(x_ref, allocations=allocs) + # Make sure we have properly quiesced the pipeline. + pipeline(pipeline_state=None)(x_ref, allocations=allocs) + + @pl.when(wg_idx == 0) + def _store_out(): + out_gmem_ref[...] = out_ref[...] + + kernel = self.kernel( + body, + out_shape=jax.ShapeDtypeStruct((blk_m, blk_n), jnp.int32), + compiler_params=plgpu.CompilerParams(approx_math=True), + scratch_shapes=[plgpu.SMEM((blk_m, blk_n), jnp.int32)], + grid=(1,), + grid_names=("_",), + num_threads=2, + thread_name="wg", + ) + out = kernel(x) + np.testing.assert_array_equal( + out, x.reshape(m // blk_m, blk_m, n // blk_n, blk_n).sum((0, 2)) * 10 + ) - # Wait for the current GMEM->SMEM copy to complete. - plgpu.barrier_wait(barrier.at[slot]) - # Wait for the previous output SMEM->GMEM copy to complete. - plgpu.wait_smem_to_gmem(max_concurrent_steps - 1) + @parameterized.product(manual_consumed_barriers=[False, True]) + def test_pipeline_with_manual_allocation(self, manual_consumed_barriers): + m = n = 512 - o_smem.at[slot][...] = x_smem.at[slot][...] + 1.0 + x = jax.random.randint(jax.random.key(4), (m, n), -10, 15, dtype=jnp.int32) + y = jax.random.randint(jax.random.key(5), (m, n), -10, 15, dtype=jnp.int32) + blk_m = blk_n = 64 - plgpu.commit_smem() - plgpu.copy_smem_to_gmem( - o_smem.at[slot], o_gmem.at[gmem_slice, pl.ds(step * 16, 16)] - ) + def body(x_ref, y_ref, out_gmem_ref, out_ref): + wg_idx = jax.lax.axis_index("wg") + @pl.when(wg_idx == 0) + def _zero_output(): + out_ref[...] = jnp.zeros_like(out_ref) - fetch_step = step + max_concurrent_steps - fetch_slot = slot # (x + y) % y == x % y - jax.lax.cond( - fetch_step < num_steps, - lambda: plgpu.copy_gmem_to_smem( - x_gmem.at[gmem_slice, pl.ds(fetch_step * 16, 16)], - x_smem.at[fetch_slot], - barrier.at[fetch_slot], - ), - lambda: None, - ) - return () + def pipeline_body(_, x_smem, y_smem, *consumed_barriers): + out_ref[...] += x_smem[...] + y_smem[...] + for b in consumed_barriers: + plgpu.barrier_arrive(b) - # Initialize the pipeline. - for slot in range(min(max_concurrent_steps, num_steps)): - plgpu.copy_gmem_to_smem( - x_gmem.at[gmem_slice, pl.ds(slot * 16, 16)], - x_smem.at[slot], - barrier.at[slot], - ) + spec = pl.BlockSpec( + block_shape=(blk_m, blk_n), index_map=lambda i, j: (i, j) + ) + pipeline = mgpu_pipeline.emit_pipeline_warp_specialized( + body=pipeline_body, + grid=(m // blk_m, n // blk_n), + memory_registers=40, + max_concurrent_steps=2, + num_compute_wgs=1, + wg_axis="wg", + manual_consumed_barriers=manual_consumed_barriers, + in_specs=[spec, spec], + ) - jax.lax.fori_loop(0, num_steps, body, ()) + @functools.partial( + pl.run_scoped, + allocs=pipeline.get_allocations(x_ref, y_ref), + collective_axes="wg", + ) + def _alloc_scope(allocs): + @pl.loop(0, 4) + def _outer_loop(_): + pipeline(x_ref, y_ref, allocations=allocs) - # Finalize the pipeline. - plgpu.wait_smem_to_gmem(0) + @pl.when(wg_idx == 0) + def _store_out(): + out_gmem_ref[...] = out_ref[...] - x = jnp.arange(32 * 4 * 64).reshape(32 * 4, 64).astype(jnp.float32) - kernel_fn = pl.pallas_call( - kernel, - in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], - out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), - out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), - grid=(4, 1), + kernel = self.kernel( + body, + out_shape=jax.ShapeDtypeStruct((blk_m, blk_n), jnp.int32), + compiler_params=plgpu.CompilerParams(approx_math=True), + scratch_shapes=[plgpu.SMEM((blk_m, blk_n), jnp.int32)], + grid=(1,), + grid_names=("_",), + num_threads=2, + thread_name="wg", + ) + np.testing.assert_array_equal( + kernel(x, y), + (x + y).reshape(m // blk_m, blk_m, n // blk_n, blk_n).sum((0, 2)) * 4, ) - np.testing.assert_array_equal(kernel_fn(x), x + 1.0) - @parameterized.parameters( - ((),), - ((plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128)),), - ) - def test_emit(self, transforms): + @jtu.thread_unsafe_test() # Modifies ``os.environ``. + def test_collective(self): num_steps = 4 def kernel(x_gmem, o_gmem): - plgpu.emit_pipeline( - kernel_body, - in_specs=[ - plgpu.GPUBlockSpec( - (64, 64), lambda i: (0, i), transforms=transforms - ) - ], - out_specs=[ - plgpu.GPUBlockSpec( - (64, 64), lambda i: (0, i), transforms=transforms - ) - ], - grid=(num_steps,), - max_concurrent_steps=2, - )(x_gmem, o_gmem) - - def kernel_body(x_smem, o_smem): - # +1 for the indexing done by ``emit_pipeline`. - self.assertLen(x_smem.transforms, len(transforms) + 1) - o_smem[...] = x_smem[...] + 1.0 + cluster_idx = lax.axis_index("cluster") + in_specs = [ + plgpu.BlockSpec( + (64, 64), lambda i: (0, i), collective_axes=("cluster",) + ) + ] + out_specs = [plgpu.BlockSpec((1, 64, 64), lambda i: (cluster_idx, 0, i))] + # Run a few times to make sure we leave barriers in a good state. + for _ in range(3): + def pipeline_body(_, x_smem, o_smem): + o_smem[0, ...] = x_smem[...] + 1.0 + plgpu.emit_pipeline_warp_specialized( + pipeline_body, + in_specs=in_specs, + out_specs=out_specs, + grid=(num_steps,), + max_concurrent_steps=2, + num_compute_wgs=1, + memory_registers=40, + wg_axis="wg", + )(x_gmem, o_gmem) x = jnp.arange(64 * num_steps * 64) x = x.reshape(-1, num_steps * 64).astype(jnp.float32) - kernel_fn = pl.pallas_call( + kernel_fn = self.kernel( kernel, - in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], - out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), - out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + out_shape=jax.ShapeDtypeStruct((2, *x.shape), x.dtype), + num_threads=2, + thread_name="wg", + cluster=(2,), + cluster_names=("cluster",) ) - np.testing.assert_array_equal(kernel_fn(x), x + 1.0) + with jtu.set_env(MOSAIC_GPU_DUMP_PTX="1"), self.capture_stdout() as ptx: + y = jax.block_until_ready(kernel_fn(x)) + self.assertIn( + "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster", + ptx(), + ) + np.testing.assert_array_equal(y, np.stack([x + 1.0, x + 1.0])) - def test_nested_emit(self): - num_steps = 4 + @parameterized.parameters((pl.Squeezed(),), (None,)) + def test_emit_with_squeezed_dim(self, squeezed_dim): + shape = (16, 256) + num_steps = shape[0] def kernel(x_gmem, o_gmem): - plgpu.emit_pipeline( - nested_kernel, - in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], - out_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], - grid=(), - )(x_gmem, o_gmem) - - def nested_kernel(x_gmem, o_gmem): - plgpu.emit_pipeline( - nested_kernel_body, - in_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))], - out_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))], + plgpu.emit_pipeline_warp_specialized( + kernel_body, + in_specs=[pl.BlockSpec((squeezed_dim, shape[1]), lambda i: (i, 0))], + out_specs=[pl.BlockSpec((squeezed_dim, shape[1]), lambda i: (i, 0))], grid=(num_steps,), max_concurrent_steps=2, + num_compute_wgs=1, + memory_registers=40, + wg_axis="wg", )(x_gmem, o_gmem) - def nested_kernel_body(x_smem, o_smem): - o_smem[...] = x_smem[...] + 1.0 + def kernel_body(_, in_smem, o_smem): + o_smem[...] = in_smem[...] + 1 - x = jnp.arange(32 * num_steps * 16) - x = x.reshape(-1, num_steps * 16).astype(jnp.float32) - kernel_fn = pl.pallas_call( + kernel_fn = self.kernel( kernel, - in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], - out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), - out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + out_shape=jax.ShapeDtypeStruct((16, 256), jnp.int32), + num_threads=2, + thread_name="wg", ) - np.testing.assert_array_equal(kernel_fn(x), x + 1.0) + x = jnp.arange(16 * 256, dtype=jnp.int32).reshape(16, 256) + np.testing.assert_array_equal(kernel_fn(x), x + 1) - def test_emit_with_grid_invariant_output(self): - num_steps = 4 - def kernel(x_gmem, o_gmem): - plgpu.emit_pipeline( - kernel_body, - in_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))], - out_specs=[pl.BlockSpec((32, 16), lambda i: (0, 0))], - grid=(num_steps,), - max_concurrent_steps=2, - )(x_gmem, o_gmem) +class WarpSpecializedPipelineWGTest( + WarpSpecializedPipelineTest, + lowering_semantics=plgpu.LoweringSemantics.Warpgroup, +): + ... - def kernel_body(x_smem, o_smem): - o_smem[...] = x_smem[...] + 1.0 - x = jnp.arange(32 * num_steps * 16) - x = x.reshape(-1, num_steps * 16).astype(jnp.float32) - kernel_fn = pl.pallas_call( - kernel, - in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], - out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), - out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), +class CoreMapTest(PallasTest, jtu.CudaArchSpecificTest): + + def test_multiple_wg(self): + + @functools.partial( + self.kernel, + out_shape=jnp.zeros((2, 128), np.int32), + num_threads=2, + thread_name="wg", ) - y = jnp.empty_like(x) - for i in range(num_steps): - i_slice = slice(16 * i, 16 * (i + 1)) - y = y.at[:, :16].set(x[:, i_slice] + 1) - # We only compare the elements in the first 16 columns, because the rest - # are never written to. - np.testing.assert_array_equal(kernel_fn(x)[:, :16], y[:, :16]) + def kernel(o_ref): + wg_idx = jax.lax.axis_index("wg") + o_ref[wg_idx] = jnp.broadcast_to(wg_idx, (128,)) - def test_emit_with_parallel_grid(self): - num_steps1 = 4 - num_steps2 = 5 + np.testing.assert_array_equal( + kernel(), np.repeat(np.arange(2), 128).reshape(2, 128) + ) - def kernel(x_gmem, o_gmem): - pid = pl.program_id(0) - plgpu.emit_pipeline( - kernel_body, - in_specs=[pl.BlockSpec((32, 16), lambda i: (pid, i))], - out_specs=[pl.BlockSpec((32, 16), lambda i: (pid, i))], - grid=(num_steps2,), - max_concurrent_steps=2, - )(x_gmem, o_gmem) + def test_multiple_wg_with_grid(self): - def kernel_body(x_smem, o_smem): - o_smem[...] = x_smem[...] + 1.0 + @functools.partial( + self.kernel, + out_shape=jnp.zeros((4, 2, 128), np.int32), + grid=(2, 2), + grid_names=("x", "y"), + num_threads=2, + thread_name="wg", + ) + def kernel(o_ref): + xy_idx = jax.lax.axis_index(("x", "y")) + yx_idx = jax.lax.axis_index(("y", "x")) + wg_idx = jax.lax.axis_index("wg") + num_wgs = jax.lax.axis_size("wg") + o_ref[xy_idx, wg_idx] = jnp.broadcast_to( + yx_idx * num_wgs + wg_idx, (128,) + ) - x = jnp.arange(num_steps1 * 32 * num_steps2 * 16) - x = x.reshape(-1, num_steps2 * 16).astype(jnp.float32) - kernel_fn = pl.pallas_call( - kernel, - in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], - out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), - out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), - grid=(num_steps1,), + np.testing.assert_array_equal( + kernel(), np.repeat([0, 1, 4, 5, 2, 3, 6, 7], 128).reshape(4, 2, 128) ) - y = x + 1.0 - np.testing.assert_array_equal(kernel_fn(x), y) - def test_emit_with_2d_grid(self): - num_steps1 = 4 - num_steps2 = 5 + def test_multiple_wg_with_squashed_grid(self): + # Tests whether a grid with >3 logical dimensions is correctly squashed to + # 3 CUDA grid dimensions. + b = 4 + x_dim = 3 + y_dim = 5 + z_dim = 7 + num_threads = 2 - def kernel(x_gmem, o_gmem): - plgpu.emit_pipeline( - kernel_body, - in_specs=[pl.BlockSpec((32, 16, 8), lambda i, j: (0, i, j))], - out_specs=[pl.BlockSpec((32, 16, 8), lambda i, j: (0, i, j))], - grid=(num_steps1, num_steps2), - max_concurrent_steps=2, - )(x_gmem, o_gmem) + @functools.partial( + self.kernel, + out_shape=jnp.zeros( + (b, x_dim, y_dim, z_dim, num_threads, 128), np.int32 + ), + grid=(b, x_dim, y_dim, z_dim), + grid_names=("b", "x", "y", "z"), + num_threads=num_threads, + thread_name="wg", + ) + def kernel(o_ref): + b_idx = jax.lax.axis_index("b") + x_idx = jax.lax.axis_index("x") + y_idx = jax.lax.axis_index("y") + z_idx = jax.lax.axis_index("z") + wg_idx = jax.lax.axis_index("wg") + bxyzw_idx = jax.lax.axis_index(("b", "x", "y", "z", "wg")) + o_ref[b_idx, x_idx, y_idx, z_idx, wg_idx] = jnp.broadcast_to( + bxyzw_idx, (128,) + ) - def kernel_body(x_smem, o_smem): - o_smem[...] = x_smem[...] + 1.0 + result = kernel()[:, :, :, :, :, 0] + ref = np.arange(b * x_dim * y_dim * z_dim * num_threads).reshape( + result.shape + ) + np.testing.assert_array_equal(result, ref) - x = jnp.arange(32 * num_steps1 * 16 * num_steps2 * 8) - x = x.reshape(-1, num_steps1 * 16, num_steps2 * 8).astype(jnp.float32) - kernel_fn = pl.pallas_call( - kernel, - in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], - out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), - out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + def test_cross_wg_barrier(self): + @functools.partial( + self.kernel, + out_shape=jnp.zeros((2, 128), np.int32), + # Each warpgroup is a single logical thread! + scratch_shapes=[plgpu.Barrier(num_arrivals=2)], + num_threads=2, + thread_name="wg", ) - np.testing.assert_array_equal(kernel_fn(x), x + 1.0) + def kernel(o_ref, barrier): + plgpu.barrier_arrive(barrier) + plgpu.barrier_wait(barrier) + wg_idx = jax.lax.axis_index("wg") + o_ref[wg_idx] = jnp.broadcast_to(wg_idx, (128,)) + + np.testing.assert_array_equal( + kernel(), np.repeat([0, 1], 128).reshape(2, 128) + ) + + def test_cluster(self): + @functools.partial( + self.kernel, + out_shape=jnp.zeros(128, np.int32), + grid=(2,), + grid_names=("x",), + cluster=(2,), + cluster_names=("cluster",), + ) + def kernel(ref): + block_idx = jax.lax.axis_index("x") + cluster_idx = jax.lax.axis_index("cluster") + pl.debug_print("block: {} cluster: {}", block_idx, cluster_idx) + + ref[...] = ref[...] + + with self.capture_stdout() as output: + jax.block_until_ready(kernel()) + self.assertEqual( + set(output().splitlines()), + { + "block: 0 cluster: 0", + "block: 1 cluster: 0", + "block: 0 cluster: 1", + "block: 1 cluster: 1", + }, + ) + + def test_realistic_matmul_with_cluster(self): + self.skip_unless_sm90a() # Requires WGMMA. + + dtype = jnp.float16 + swizzle = 128 + elems_128b = swizzle // jnp.dtype(dtype).itemsize + grid_m, grid_k, grid_n = 132, 10, 32 + # TODO(slebedev): Remove ``grid_tile_n`` to simplify the test. + grid_tile_n = 4 + assert grid_n % grid_tile_n == 0 + cluster_m = 2 + cluster_n = 2 + cluster_tile_n = min(cluster_n, grid_tile_n) + tile_m = tile_n = 128 + assert tile_m % elems_128b == 0 + tile_k = elems_128b + m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n + + transforms = self.default_transforms(dtype=dtype) + + max_concurrent_steps = 2 + delay_release = 1 + + @functools.partial( + self.kernel, + out_shape=jax.ShapeDtypeStruct((m, n), dtype), + scratch_shapes=[ + plgpu.SMEM( + (max_concurrent_steps, tile_m, tile_k), + dtype, + transforms=transforms, + ), + plgpu.SMEM( + (max_concurrent_steps, tile_k, tile_n), + dtype, + transforms=transforms, + ), + plgpu.SMEM((tile_m, tile_n), dtype, transforms=transforms), + plgpu.ACC((tile_m, tile_n), jnp.float32), + plgpu.Barrier(num_arrivals=2, num_barriers=max_concurrent_steps), + plgpu.ClusterBarrier( + collective_axes=(("x", "z"), "y"), + num_barriers=max_concurrent_steps, + ), + ], + grid=(grid_tile_n, grid_m, grid_n // grid_tile_n), + grid_names=("tile_n", "m", "n"), + cluster=(cluster_tile_n, cluster_m, cluster_n // cluster_tile_n), + cluster_names=("x", "y", "z"), + ) + def kernel( + a_gmem, + b_gmem, + o_gmem, + a_smem, + b_smem, + o_smem, + acc, + barrier, + cluster_barrier, + ): + m_slice = pl.ds(lax.axis_index("m") * tile_m, tile_m) + n_slice = pl.ds( + (lax.axis_index("tile_n") + lax.axis_index("n") * grid_tile_n) + * tile_n, + tile_n, + ) + + def fetch(step, slot): + if not isinstance(slot, int): # Skip in initialization. + plgpu.barrier_arrive(cluster_barrier.at[slot]) + plgpu.barrier_wait(cluster_barrier.at[slot]) + + k_slice = pl.ds(step * tile_k, tile_k) + plgpu.copy_gmem_to_smem( + a_gmem.at[m_slice, k_slice], + a_smem.at[slot], + barrier.at[slot], + collective_axes=("x", "z"), + ) + plgpu.copy_gmem_to_smem( + b_gmem.at[k_slice, n_slice], + b_smem.at[slot], + barrier.at[slot], + collective_axes="y", + ) + # Initialize the pipeline. + for slot in range(min(max_concurrent_steps, grid_k)): + fetch(slot, slot) -class PipelineSm90ATest(PallasSm90ATest): + def body(step, _): + slot = step % max_concurrent_steps + plgpu.barrier_wait(barrier.at[slot]) - def test_realistic_matmul(self): - dtype = jnp.float16 - swizzle = 128 - elems_128b = swizzle // jnp.dtype(dtype).itemsize - grid_m, grid_k, grid_n = 132, 10, 4 - tile_m = tile_n = 128 - assert tile_m % elems_128b == 0 - tile_k = elems_128b - m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n + plgpu.wgmma(acc, a_smem.at[slot], b_smem.at[slot]) + plgpu.wgmma_wait(delay_release) - def kernel(a_gmem, b_gmem, o_smem, acc): - def kernel_body(a_smem, b_smem): - assert a_smem.shape == (tile_m, tile_k) - assert b_smem.shape == (tile_k, tile_n) - plgpu.wgmma(acc, a_smem, b_smem) - plgpu.wgmma_wait(1) + fetch_step = step + (max_concurrent_steps - delay_release) + fetch_slot = lax.rem(fetch_step, max_concurrent_steps) + jax.lax.cond( + lax.bitwise_and(step >= delay_release, fetch_step < grid_k), + lambda: fetch(fetch_step, fetch_slot), + lambda: None, + ) + return () - pid_m = pl.program_id(0) - pid_n = pl.program_id(1) - plgpu.emit_pipeline( - kernel_body, - in_specs=[ - plgpu.GPUBlockSpec( - (tile_m, tile_k), - lambda k: (pid_m, k), - transforms=( - plgpu.TilingTransform((64, elems_128b)), - plgpu.SwizzleTransform(128), - ), - ), - plgpu.GPUBlockSpec( - (tile_k, tile_n), - lambda k: (k, pid_n), - transforms=( - plgpu.TilingTransform((elems_128b, elems_128b)), - plgpu.SwizzleTransform(128), - ), - ), - ], - grid=(grid_k,), - max_concurrent_steps=2, - delay_release=1, - )(a_gmem, b_gmem) + jax.lax.fori_loop(0, grid_k, body, ()) + # Finalize the pipeline. o_smem[...] = acc[...].astype(dtype) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(o_smem, o_gmem.at[m_slice, n_slice]) + plgpu.wait_smem_to_gmem(0) key1, key2 = jax.random.split(jax.random.key(42), 2) a = jax.random.uniform(key1, shape=(m, k), dtype=dtype) b = jax.random.uniform(key2, shape=(k, n), dtype=dtype) - - res = pl.pallas_call( - kernel, - in_specs=[ - pl.BlockSpec(memory_space=plgpu.GMEM), - pl.BlockSpec(memory_space=plgpu.GMEM) - ], - out_specs=plgpu.GPUBlockSpec( - (tile_m, tile_n), - lambda m, n: (m, n), - transforms=( - plgpu.TilingTransform((64, elems_128b)), - plgpu.SwizzleTransform(128), - ), - ), - out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16), - scratch_shapes=[plgpu.ACC((tile_m, tile_n), jnp.float32)], - grid=(grid_m, grid_n), - )(a, b) - np.testing.assert_array_equal(res, a @ b) + np.testing.assert_array_equal(kernel(a, b), a @ b) -class WarpSpecializedPipelineTest(PallasTest): - - @parameterized.product(m=[512], n=[512], - manual_consumed_barriers=[False, True]) - def test_pipelined_copy(self, m, n, manual_consumed_barriers): - x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float16) - o = jnp.zeros((m, n), dtype=jnp.float16) - blk_m = blk_n = 64 - o_last_block = jnp.zeros((blk_m, blk_n), dtype=jnp.float16) +class CoreMapWGTest( + CoreMapTest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup +): + ... - def copy_kernel(x_smem, o_smem, o_last_block_smem, *consumed_barriers): - # TODO(justinfu): Have each wg compute a separate slice - # after multiple-indexers are supported. - # This is currently a race, but the values written are the same. - o_smem[...] = x_smem[...] - o_last_block_smem[...] = x_smem[...] - if manual_consumed_barriers: - [x_barrier] = consumed_barriers - plgpu.barrier_arrive(x_barrier) - block_spec = plgpu.GPUBlockSpec( - block_shape=(blk_m, blk_n), - index_map=lambda i, j: (i, j), - transforms=[], - ) - pipeline = mgpu_pipeline.emit_pipeline_warp_specialized( - copy_kernel, - grid=(m // blk_m, n // blk_n), - memory_registers=40, - max_concurrent_steps=2, - num_compute_wgs=2, - wg_axis="wg", - manual_consumed_barriers=manual_consumed_barriers, - in_specs=[block_spec], - out_specs=[block_spec, - # Create an index-invariant output. - plgpu.GPUBlockSpec(block_shape=(blk_m, blk_n), - index_map=lambda i, j: (0, 0)) - ], - ) - mesh = plgpu.GPUMesh(grid=(1,), num_threads=3, axis_names=("_", "wg")) - def run(refs): - @pl.core_map( - mesh, compiler_params=plgpu.GPUCompilerParams(approx_math=True) - ) - def _kernel_entry(): - pipeline(*refs) - @jax.jit - def run_function(x, o, o_last_block): - _, out, out_last = pl.run_state(run)((x, o, o_last_block)) - return (out, out_last) - out, out_last_block = run_function(x, o, o_last_block) - np.testing.assert_array_equal(out, x) - np.testing.assert_array_equal(out_last_block, x[-blk_m:, -blk_n:]) - def test_elementwise_add(self, m=256, n=256, num_compute_wgs=2): - blk_m = blk_n = 64 - x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float32) - y = jax.random.uniform(jax.random.key(1), (m, n), dtype=jnp.float32) - o = jnp.zeros((m, n), dtype=jnp.float32) +class PrettyPrintingTest(PallasTest): - def tiled_add_kernel(x_smem, y_smem, o_smem): - # TODO(justinfu): Have each wg compute a separate slice - # after multiple-indexers are supported. - # This is currently a race, but the values written are the same. - o_smem[...] = x_smem[...] + y_smem[...] + def test_load(self): - pipeline = mgpu_pipeline.emit_pipeline_warp_specialized( - tiled_add_kernel, - grid=(m // blk_m, n // blk_n), - max_concurrent_steps=2, - num_compute_wgs=num_compute_wgs, - memory_registers=40, - wg_axis="wg", - in_specs=[ - plgpu.GPUBlockSpec( - block_shape=(blk_m, blk_n), - index_map=lambda i, j: (i, j), - transforms=[]), - plgpu.GPUBlockSpec( - block_shape=(blk_m, blk_n), - index_map=lambda i, j: (i, j), - transforms=[]), - ], - out_specs=[ - plgpu.GPUBlockSpec( - block_shape=(blk_m, blk_n), - index_map=lambda i, j: (i, j), - transforms=[])], - ) - mesh = plgpu.GPUMesh( - grid=(1,), num_threads=num_compute_wgs + 1, axis_names=("_", "wg") + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct([2, 128], jnp.float32), + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=plgpu.BlockSpec(memory_space=plgpu.SMEM), ) - def run(refs): - @pl.core_map( - mesh, compiler_params=plgpu.GPUCompilerParams(approx_math=True) - ) - def _kernel_entry(): - pipeline(*refs) - @jax.jit - def run_function(x, y, o): - _, _, out = pl.run_state(run)((x, y, o)) - return out - out = run_function(x, y, o) - reference = x + y - np.testing.assert_allclose(out, reference, atol=1e-4) - - def test_carry_accumulate(self, m=256, n=256, num_compute_wgs=2): - blk_m = blk_n = 64 - x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float32) - acc_init = jnp.zeros((blk_m, blk_n), dtype=jnp.float32) + def kernel(x_ref, o_ref): + for i in range(2): + x = plgpu.load(x_ref, (i,)) + o_ref[i, ...] = x - def _scoped(acc_smem, x_gmem, acc_gmem): - def _compute_thread(): - # Cast the init value to the same layout as x_smem, so the pipeline loop - # carry has a constant signature. - o_acc = plgpu.layout_cast( - jnp.full((blk_m, blk_n,), 0, dtype=jnp.float32), - plgpu.Layout.WG_STRIDED((blk_m, blk_n), vec_size=2)) - carry_init = (o_acc,) - # Pass control to the pipeline emitter and return the final carry. - final_carry = (yield carry_init) - o_final, = final_carry - # Note that both compute WGs are doing identical work so the potential - # race condition on the store here won't affect the result. - acc_smem[...] = o_final - plgpu.commit_smem() - plgpu.copy_smem_to_gmem(acc_smem, acc_gmem) - plgpu.wait_smem_to_gmem(0) + _ = str(jax.make_jaxpr(kernel)(jax.ShapeDtypeStruct((2, 128), jnp.float32))) - def tiled_acc_kernel(x_smem, carry): - o_carry, = carry - new_carry = x_smem[...] + o_carry - return (new_carry,) + def test_copy_primitives(self): + num_steps = 4 - pipeline = mgpu_pipeline.emit_pipeline_warp_specialized( - tiled_acc_kernel, - grid=(m // blk_m, n // blk_n), - max_concurrent_steps=2, - num_compute_wgs=num_compute_wgs, - memory_registers=40, - wg_axis="wg", - carry_coroutine=_compute_thread, - in_specs=[ - plgpu.GPUBlockSpec( - block_shape=(blk_m, blk_n), - index_map=lambda i, j: (i, j), - transforms=[]), + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((64, 64), jnp.float32), + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + ) + def kernel(x_gmem, o_gmem): + # ``plgpu.emit_pipeline`` is implemented in terms of async copy and + # synchronization primitives. + plgpu.emit_pipeline( + kernel_body, + in_specs=[pl.BlockSpec((64, 64), lambda i: (0, i))], + out_specs=[ + pl.BlockSpec( + (64, 64), + lambda i: (0, i), + ) ], - out_specs=[], - ) - pipeline(x_gmem) - - mesh = plgpu.GPUMesh( - grid=(1,), - num_threads=num_compute_wgs + 1, - axis_names=("_", "wg",), - ) - def run(refs): - x_ref, acc_ref = refs - @pl.core_map(mesh) - def _kernel_entry(): - pl.run_scoped( - functools.partial(_scoped, x_gmem=x_ref, acc_gmem=acc_ref), - plgpu.SMEM((blk_m, blk_n), jnp.float32) - ) - @jax.jit - def run_function(x, acc): - _, out_acc = pl.run_state(run)((x, acc)) - return out_acc - out_acc = run_function(x, acc_init) - ref = jnp.sum(jnp.stack(np.split(x, m // blk_m, axis=0)), axis=0) - ref = jnp.sum(jnp.stack(np.split(ref, n // blk_n, axis=1)), axis=0) - np.testing.assert_allclose(out_acc, ref, atol=1e-4) + grid=(num_steps,), + max_concurrent_steps=2, + )(x_gmem, o_gmem) + def kernel_body(_, x_smem, o_smem): + o_smem[...] = x_smem[...] + 1.0 -class CoreMapTest(PallasTest): + _ = str(jax.make_jaxpr(kernel)(jax.ShapeDtypeStruct((64, 64), jnp.float32))) - def test_multiple_wg(self): - mesh = plgpu.GPUMesh(num_threads=2, axis_names=("y",)) + def test_wgmma(self): + transforms = self.default_transforms(dtype=jnp.float16) - @jax.jit - def f(): - @pl.run_state - def inner(y_ref): - @pl.core_map(mesh) - def kernel(): - wg_idx = jax.lax.axis_index("y") - y_ref[wg_idx] = jnp.broadcast_to(wg_idx, (128,)) - y_init = jnp.zeros((2, 128), np.int32) - return inner(y_init) - np.testing.assert_array_equal( - f(), np.repeat(np.arange(2), 128).reshape(2, 128) + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32), + in_specs=[ + plgpu.BlockSpec(transforms=transforms), + plgpu.BlockSpec(transforms=transforms), + ], ) + def kernel(a_ref, b_ref, o_ref): + def scope(acc_ref): + plgpu.wgmma(acc_ref, a_ref[...], b_ref) + return acc_ref[...] - def test_multiple_wg_with_grid(self): - mesh = plgpu.GPUMesh(grid=(2, 2), num_threads=2, axis_names=("x", "y", "wg")) + o_ref[...] = pl.run_scoped(scope, plgpu.ACC((64, 192), jnp.float32)) - @jax.jit - def f(): - @pl.run_state - def inner(y_ref): - @pl.core_map(mesh) - def kernel(): - xy_idx = jax.lax.axis_index(("x", "y")) - yx_idx = jax.lax.axis_index(("y", "x")) - wg_idx = jax.lax.axis_index("wg") - num_wgs = jax.lax.psum(1, "wg") - y_ref[xy_idx, wg_idx] = jnp.broadcast_to( - yx_idx * num_wgs + wg_idx, (128,) - ) - y_init = jnp.zeros((4, 2, 128), np.int32) - return inner(y_init) - np.testing.assert_array_equal( - f(), np.repeat([0, 1, 4, 5, 2, 3, 6, 7], 128).reshape(4, 2, 128) + _ = str( + jax.make_jaxpr(kernel)( + jax.ShapeDtypeStruct((64, 128), jnp.float16), + jax.ShapeDtypeStruct((128, 192), jnp.float16), + ) ) - def test_multiple_wg_with_squashed_grid(self): - # Tests whether a grid with >3 logical dimensions is correctly squashed to - # 3 CUDA grid dimensions. - b = 4 - x_dim = 3 - y_dim = 5 - z_dim = 7 - num_threads = 2 - mesh = plgpu.GPUMesh(grid=(b, x_dim, y_dim, z_dim), - num_threads=num_threads, - axis_names=("b", "x", "y", "z", "wg")) - @jax.jit - def f(): - @pl.run_state - def inner(y_ref): - @pl.core_map(mesh) - def _(): - b_idx = jax.lax.axis_index("b") - x_idx = jax.lax.axis_index("x") - y_idx = jax.lax.axis_index("y") - z_idx = jax.lax.axis_index("z") - wg_idx = jax.lax.axis_index("wg") - bxyzw_idx = jax.lax.axis_index(("b", "x", "y", "z", "wg")) - y_ref[b_idx, x_idx, y_idx, z_idx, wg_idx] = jnp.broadcast_to( - bxyzw_idx, (128,) - ) - y_init = jnp.zeros((b, x_dim, y_dim, z_dim, num_threads, 128), np.int32) - return inner(y_init) - result = f()[:, :, :, :, :, 0] - ref = np.arange(b * x_dim * y_dim * z_dim * num_threads).reshape( - result.shape) - np.testing.assert_array_equal(result, ref) +class ExportTest(PallasTest): + def test_export_succeeds(self): + out_shape = jax.ShapeDtypeStruct([128], jnp.float32) - def test_cross_wg_barrier(self): - mesh = plgpu.GPUMesh(num_threads=2, axis_names=("wg",)) + @functools.partial(self.pallas_call, out_shape=out_shape) + def kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] + 1.0 - @jax.jit - def f(): - @pl.run_state - def inner(y_ref): - @pl.core_map(mesh) - def kernel(): - def scoped(barrier): - plgpu.barrier_arrive(barrier) - plgpu.barrier_wait(barrier) - wg_idx = jax.lax.axis_index("wg") - y_ref[wg_idx] = jnp.broadcast_to(wg_idx, (128,)) - # Each warpgroup is a single logical thread! - pl.run_scoped(scoped, plgpu.Barrier(num_arrivals=2)) - y_init = jnp.zeros((2, 128), np.int32) - return inner(y_init) - np.testing.assert_array_equal(f(), np.repeat([0, 1], 128).reshape(2, 128)) + _ = export.export(kernel)(out_shape) class ExamplesTest(PallasTest): # Basic def test_stage0(self): - def body(l_ref, r_ref, o_ref): + x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) + + @functools.partial(self.kernel, out_shape=x) + def kernel(l_ref, r_ref, o_ref): o_ref[...] = l_ref[...] + r_ref[...] - x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) - out = plgpu.kernel(body, out_shape=x)(x, x) - np.testing.assert_allclose(out, x + x) + np.testing.assert_allclose(kernel(x, x), x + x) # Multi-block kernels def test_stage1(self): row_block = 64 - def body(l_ref, r_ref, o_ref): + x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) + + @functools.partial( + self.kernel, out_shape=x, grid=(2,), grid_names=("rows",) + ) + def kernel(l_ref, r_ref, o_ref): my_slice = pl.ds(lax.axis_index("rows") * row_block, row_block) o_ref[my_slice] = l_ref[my_slice] + r_ref[my_slice] - x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) - out = plgpu.kernel(body, out_shape=x, grid=(2,), axis_names=("rows",))(x, x) - np.testing.assert_allclose(out, x + x) + np.testing.assert_allclose(kernel(x, x), x + x) # Async copies def test_stage3(self): row_block, col_block = 64, 128 - def body(l_ref, r_ref, o_ref): + + @functools.partial( + self.kernel, + out_shape=jax.ShapeDtypeStruct((128, 128), jnp.float16), + scratch_shapes=[ + *([plgpu.SMEM((row_block, col_block), jnp.float16)] * 3), + plgpu.Barrier(num_arrivals=2), + ], + grid=(2,), + grid_names=("rows",), + ) + def kernel(l_ref, r_ref, o_ref, l_smem, r_smem, o_smem, barrier): my_slice = pl.ds(lax.axis_index("rows") * row_block, row_block) - def scoped(l_smem, r_smem, o_smem, barrier): - plgpu.copy_gmem_to_smem(l_ref.at[my_slice], l_smem, barrier) - plgpu.copy_gmem_to_smem(r_ref.at[my_slice], r_smem, barrier) - plgpu.barrier_wait(barrier) - o_smem[...] = l_smem[...] + r_smem[...] - plgpu.commit_smem() - plgpu.copy_smem_to_gmem(o_smem, o_ref.at[my_slice]) - plgpu.wait_smem_to_gmem(0) - pl.run_scoped( - scoped, - *([plgpu.SMEM((row_block, col_block), jnp.float16)] * 3), - plgpu.Barrier(num_arrivals=2), - ) + plgpu.copy_gmem_to_smem(l_ref.at[my_slice], l_smem, barrier) + plgpu.copy_gmem_to_smem(r_ref.at[my_slice], r_smem, barrier) + plgpu.barrier_wait(barrier) + o_smem[...] = l_smem[...] + r_smem[...] + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(o_smem, o_ref.at[my_slice]) + plgpu.wait_smem_to_gmem(0) x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) - out = plgpu.kernel(body, out_shape=x, grid=(2,), axis_names=("rows",))(x, x) - np.testing.assert_allclose(out, x + x) + np.testing.assert_allclose(kernel(x, x), x + x) # Pipelining def test_stage4(self): row_block, col_block = 64, 32 - def body(l_ref, r_ref, o_ref): - def compute(l_smem, r_smem, o_smem): + x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) + + @functools.partial( + self.kernel, out_shape=x, grid=(2,), grid_names=("rows",) + ) + def kernel(l_ref, r_ref, o_ref): + def compute(_, l_smem, r_smem, o_smem): o_smem[...] = l_smem[...] + r_smem[...] r = lax.axis_index("rows") block = pl.BlockSpec((row_block, col_block), lambda c: (r, c)) @@ -2161,20 +6156,24 @@ def compute(l_smem, r_smem, o_smem): out_specs=[block], )(l_ref, r_ref, o_ref) - x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) - out = plgpu.kernel(body, out_shape=x, grid=(2,), axis_names=("rows",))(x, x) - np.testing.assert_allclose(out, x + x) + np.testing.assert_allclose(kernel(x, x), x + x) # Transforms def test_stage5(self): row_block, col_block = 64, 32 - def body(l_ref, r_ref, o_ref): - def compute(l_smem, r_smem, o_smem): + x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) + + @functools.partial( + self.kernel, out_shape=x, grid=(2,), grid_names=("rows",) + ) + def kernel(l_ref, r_ref, o_ref): + def compute(_, l_smem, r_smem, o_smem): o_smem[...] = l_smem[...] + r_smem[...] r = lax.axis_index("rows") - block = plgpu.GPUBlockSpec( - (row_block, col_block), lambda c: (r, c), - transforms=(plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(64)), + block = plgpu.BlockSpec( + (row_block, col_block), + lambda c: (r, c), + transforms=self.default_transforms(swizzle=64, dtype=jnp.float16), ) plgpu.emit_pipeline( compute, @@ -2183,40 +6182,457 @@ def compute(l_smem, r_smem, o_smem): out_specs=[block], )(l_ref, r_ref, o_ref) - x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) - out = plgpu.kernel(body, out_shape=x, grid=(2,), axis_names=("rows",))(x, x) - np.testing.assert_allclose(out, x + x) + np.testing.assert_allclose(kernel(x, x), x + x) + + +class SemaphoreTest(PallasTest): + + def test_lowering(self): + # This is a smoke test until we add support for lowering of semaphore ops. + def body(i_ref1, i_ref2, o_ref, sem_ref): + del i_ref2 # Only here to have a different number of inputs and outputs. + assert sem_ref.shape == (4,) + assert jnp.issubdtype(sem_ref.dtype, pl.semaphore) + o_ref[...] = i_ref1[...] + x = jnp.arange(128, dtype=jnp.float32).reshape((128,)) + kernel = self.pallas_call( + body, + out_shape=x, + scratch_shapes=[plgpu.SemaphoreType.REGULAR((4,))], + ) + text = jax.jit(kernel).lower(x, x).as_text() + self.assertIn( + r"output_operand_aliases =" + r" [#stablehlo.output_operand_alias]", + text, + ) + self.assertIn( + r"(tensor<128xf32>, tensor<128xf32>, tensor<4xi32>) ->" + r" (tensor<128xf32>, tensor<4xi32>)", + text, + ) + + def test_basic(self): + def body(o_ref, sem_ref): + assert jnp.issubdtype(sem_ref.dtype, pl.semaphore) + pl.semaphore_signal(sem_ref) + o_ref[...] = jnp.ones_like(o_ref) + pl.semaphore_wait(sem_ref) + kernel = self.kernel( + body, + out_shape=jax.ShapeDtypeStruct((128,), jnp.float32), + scratch_shapes=[plgpu.SemaphoreType.REGULAR], + grid=(2,), + grid_names=("x",), + ) + text = jax.jit(kernel).lower().as_text() + np.testing.assert_array_equal(kernel(), jnp.ones((128,), jnp.float32)) + # The semaphore array is scaled up by the grid size. + self.assertIn( + r"(tensor<128xf32>, tensor<2xi32>) -> (tensor<128xf32>, tensor<2xi32>)", + text, + ) + + def test_with_profiler(self): + # Dealing with profiler and semaphores together is tricky because they both + # add extra outputs to the HLO op. + def body(o_ref, sem_ref): + assert jnp.issubdtype(sem_ref.dtype, pl.semaphore) + with jax.named_scope("output"): + o_ref[...] = jnp.ones_like(o_ref) + with tempfile.TemporaryDirectory() as tmp_dir: + kernel = self.kernel( + body, + out_shape=jax.ShapeDtypeStruct((128,), jnp.float32), + scratch_shapes=[plgpu.SemaphoreType.REGULAR], + grid=(2,), + grid_names=("x",), + compiler_params=plgpu.CompilerParams(profile_space=32, profile_dir=tmp_dir), + ) + text = jax.jit(kernel).lower().as_text() + np.testing.assert_array_equal(kernel(), jnp.ones((128,), jnp.float32)) + self.assertIn( + r"(tensor<128xf32>, tensor<2xi32>) ->" + r" (tensor<128xf32>, tensor<2xi32>, tensor<512xui32>)", + text, + ) + + def test_global_semaphore(self): + # We signal from block 0 and wait on block 1 to test whether the semaphore + # is globally shared. + def body(out_ref): + sem_ref = pl.get_global(plgpu.SemaphoreType.REGULAR) + block_id = lax.axis_index("x") + @pl.when(block_id == 0) + def _(): + pl.semaphore_signal(sem_ref) + @pl.when(block_id == 1) + def _(): + pl.semaphore_wait(sem_ref) + out_ref[...] = jnp.ones_like(out_ref) + + kernel = self.kernel( + body, + out_shape=jax.ShapeDtypeStruct((128,), jnp.float32), + grid=(10,), + grid_names=("x",), + ) + result = kernel() + np.testing.assert_array_equal(result, jnp.ones((128,), jnp.float32)) + + def test_global_semaphore_with_multiple_threads(self): + def body(out_ref): + sem_ref = pl.get_global(plgpu.SemaphoreType.REGULAR) + block_id = lax.axis_index("x") + @pl.when(block_id == 0) + def _(): + pl.semaphore_signal(sem_ref) + @pl.when(block_id == 1) + def _(): + pl.semaphore_wait(sem_ref) + out_ref[...] = jnp.ones_like(out_ref) + + kernel = self.kernel( + body, + out_shape=jax.ShapeDtypeStruct((128,), jnp.float32), + grid=(10,), + grid_names=("x",), + thread_name="wg", + num_threads=2, + ) + result = kernel() + np.testing.assert_array_equal(result, jnp.ones((128,), jnp.float32)) + + def test_multiple_get_global_semaphores(self): + def body(out_ref): + sem1 = pl.get_global(plgpu.SemaphoreType.REGULAR) + sem2 = pl.get_global(plgpu.SemaphoreType.REGULAR) + block_id = lax.axis_index("x") + @pl.when(block_id == 0) + def _(): + pl.semaphore_signal(sem1, inc=5) + pl.semaphore_signal(sem2, inc=10) + @pl.when(block_id == 1) + def _(): + pl.semaphore_wait(sem1, value=5, decrement=False) + pl.semaphore_wait(sem2, value=10, decrement=False) + val1 = pl.semaphore_read(sem1) + val2 = pl.semaphore_read(sem2) + out_ref[0] = val1 + out_ref[1] = val2 + + kernel = self.kernel( + body, + out_shape=jax.ShapeDtypeStruct((2,), jnp.int32), + grid=(10,), + grid_names=("x",), + ) + result = kernel() + np.testing.assert_array_equal(result, jnp.array([5, 10], jnp.int32)) + + def test_get_global_in_and_outside_control_flow(self): + def body(out_ref): + sem_before = pl.get_global(plgpu.SemaphoreType.REGULAR) + block_id = lax.axis_index("x") + + @pl.when(block_id == 0) + def _(): + sem_inside = pl.get_global(plgpu.SemaphoreType.REGULAR) + pl.semaphore_signal(sem_inside, 7) + pl.semaphore_signal(sem_before, 3) + val_inside = pl.semaphore_read(sem_inside) + out_ref[1] = val_inside + + sem_after = pl.get_global(plgpu.SemaphoreType.REGULAR) + pl.semaphore_signal(sem_after, 11) + val_before = pl.semaphore_read(sem_before) + val_after = pl.semaphore_read(sem_after) + out_ref[0] = val_before + out_ref[2] = val_after + + kernel = self.kernel( + body, + out_shape=jax.ShapeDtypeStruct((3,), jnp.int32), + grid=(1,), + grid_names=("x",), + ) + result = kernel() + np.testing.assert_array_equal(result, jnp.array([3, 7, 11], jnp.int32)) + + def test_multiple_semaphore_scopes(self): + def body(out_ref): + global_sem = pl.get_global(plgpu.SemaphoreType.REGULAR) + + @functools.partial(pl.run_scoped, block_sem=plgpu.SemaphoreType.REGULAR) + def _scope2(block_sem): + block_id = lax.axis_index("x") + pl.semaphore_signal(block_sem) + + @pl.when(block_id == 0) + def _(): + pl.semaphore_signal(global_sem) + + @pl.when(block_id == 1) + def _(): + pl.semaphore_wait(global_sem) + out_ref[...] = jnp.ones_like(out_ref) + + pl.semaphore_wait(block_sem) + + kernel = self.kernel( + body, + out_shape=jax.ShapeDtypeStruct((128,), jnp.float32), + grid=(10,), + grid_names=("x",), + ) + result = kernel() + np.testing.assert_array_equal(result, jnp.ones((128,), jnp.float32)) + + +class ExamplesWGTest( + ExamplesTest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup +): + ... class ExamplesSm90ATest(PallasSm90ATest): # WGMMA def test_stage6(self): + self.skip_if_wg_semantics() # `fa.optimization_barrier` does not support f16 arrays. + m_block = n_block = 64 k_block = 32 - def body(l_ref, r_ref, o_ref): - def compute(l_smem, r_smem, o_smem): + x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) + + @functools.partial( + self.kernel, out_shape=x, grid=(2, 2), grid_names=("m", "n") + ) + def kernel(l_ref, r_ref, o_ref): + def compute(_, l_smem, r_smem, o_smem): def do_wgmma(acc_ref): plgpu.wgmma(acc_ref, l_smem, r_smem) return acc_ref[...] o_smem[...] += pl.run_scoped(do_wgmma, plgpu.ACC((m_block, n_block), jnp.float16)) - m, n = lax.axis_index("m"), lax.axis_index("n") - lo_transforms = (plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(64)) - r_transforms = (plgpu.TilingTransform((32, 32)), plgpu.SwizzleTransform(64)) + m = lax.axis_index("m") + n = lax.axis_index("n") + transforms = self.default_transforms(swizzle=64, dtype=jnp.float16) plgpu.emit_pipeline( compute, grid=(l_ref.shape[1] // k_block,), - in_specs=[plgpu.GPUBlockSpec((m_block, k_block), lambda k: (m, k), transforms=lo_transforms), - plgpu.GPUBlockSpec((k_block, n_block), lambda k: (k, n), transforms=r_transforms)], - out_specs=[plgpu.GPUBlockSpec((m_block, n_block), lambda k: (m, n), transforms=lo_transforms)], + in_specs=[ + plgpu.BlockSpec( + (m_block, k_block), lambda k: (m, k), transforms=transforms + ), + plgpu.BlockSpec( + (k_block, n_block), lambda k: (k, n), transforms=transforms + ), + ], + out_specs=[ + plgpu.BlockSpec( + (m_block, n_block), lambda k: (m, n), transforms=transforms + ) + ], )(l_ref, r_ref, o_ref) - x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) - out = plgpu.kernel(body, out_shape=x, grid=(2, 2), axis_names=("m", "n"))(x, x) - np.testing.assert_allclose(out, x @ x) + np.testing.assert_allclose(kernel(x, x), x @ x) # TODO(apaszke): Clusters and multicast +class ExamplesSm90AWGTest( + ExamplesSm90ATest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup +): + ... + + +class HelpersTest(PallasTest): + + @parameterized.product( + m=[4, 16], + n=[4, 16], + minor_dim=[0, 1], + tile_width=[1, 2, 4], + ) + def test_planar_snake(self, m, n, minor_dim, tile_width): + reference = np.full((m, n), -1) + counter = itertools.count() + minor_size, major_size = (m, n) if minor_dim == 0 else (n, m) + for minor_tile in range(minor_size // tile_width): + for major in range(major_size): + major = major if minor_tile % 2 == 0 else major_size - 1 - major + for minor_in_tile in range(tile_width): + minor = minor_tile * tile_width + minor_in_tile + idx = (minor, major) if minor_dim == 0 else (major, minor) + reference[idx] = next(counter) + results = np.full((m, n), -1) + for lin in range(m * n): + results[plgpu.planar_snake(np.int32(lin), (m, n), minor_dim, tile_width)] = lin + np.testing.assert_array_equal(results, reference) + + def test_planar_snake_golden_with_partial_tile(self): + m, n = 5, 5 + with self.subTest("minor_dim=0 tile_width=3"): + results = np.full((m, n), -1) + for lin in range(m * n): + results[plgpu.planar_snake(np.int32(lin), (m, n), 0, 3)] = lin + expected = np.array([ + [0, 3, 6, 9, 12], + [1, 4, 7, 10, 13], + [2, 5, 8, 11, 14], + [23, 21, 19, 17, 15], + [24, 22, 20, 18, 16]]) + np.testing.assert_array_equal(results, expected) + with self.subTest("minor_dim=1 tile_width=3"): + results = np.full((m, n), -1) + for lin in range(m * n): + results[plgpu.planar_snake(np.int32(lin), (m, n), 1, 3)] = lin + expected = np.array([ + [0, 1, 2, 23, 24], + [3, 4, 5, 21, 22], + [6, 7, 8, 19, 20], + [9, 10, 11, 17, 18], + [12, 13, 14, 15, 16]]) + np.testing.assert_array_equal(results, expected) + + def test_planar_snake_golden(self): + m, n = 8, 8 + with self.subTest("minor_dim=0 tile_width=2"): + results = np.full((m, n), -1) + for lin in range(m * n): + results[plgpu.planar_snake(np.int32(lin), (m, n), 0, 2)] = lin + expected = np.array([ + [0, 2, 4, 6, 8, 10, 12, 14], + [1, 3, 5, 7, 9, 11, 13, 15], + [30, 28, 26, 24, 22, 20, 18, 16], + [31, 29, 27, 25, 23, 21, 19, 17], + [32, 34, 36, 38, 40, 42, 44, 46], + [33, 35, 37, 39, 41, 43, 45, 47], + [62, 60, 58, 56, 54, 52, 50, 48], + [63, 61, 59, 57, 55, 53, 51, 49], + ]) + np.testing.assert_array_equal(results, expected) + with self.subTest("minor_dim=1 tile_width=2"): + results = np.full((m, n), -1) + for lin in range(m * n): + results[plgpu.planar_snake(np.int32(lin), (m, n), 1, 2)] = lin + expected = np.array([ + [0, 1, 30, 31, 32, 33, 62, 63], + [2, 3, 28, 29, 34, 35, 60, 61], + [4, 5, 26, 27, 36, 37, 58, 59], + [6, 7, 24, 25, 38, 39, 56, 57], + [8, 9, 22, 23, 40, 41, 54, 55], + [10, 11, 20, 21, 42, 43, 52, 53], + [12, 13, 18, 19, 44, 45, 50, 51], + [14, 15, 16, 17, 46, 47, 48, 49], + ]) + np.testing.assert_array_equal(results, expected) + with self.subTest("minor_dim=0 tile_width=1"): + results = np.full((m, n), -1) + for lin in range(m * n): + results[plgpu.planar_snake(np.int32(lin), (m, n), 0, 1)] = lin + expected = np.array([ + [0, 1, 2, 3, 4, 5, 6, 7], + [15, 14, 13, 12, 11, 10, 9, 8], + [16, 17, 18, 19, 20, 21, 22, 23], + [31, 30, 29, 28, 27, 26, 25, 24], + [32, 33, 34, 35, 36, 37, 38, 39], + [47, 46, 45, 44, 43, 42, 41, 40], + [48, 49, 50, 51, 52, 53, 54, 55], + [63, 62, 61, 60, 59, 58, 57, 56], + ]) + np.testing.assert_array_equal(results, expected) + + @parameterized.parameters( + ((100,), ()), # grid < SM count + ((300,), ()), # grid > SM count + ((3, 3, 3, 3, 3), ()), # squashed grid dimensions + ((50,), (2, 1)), # small grid w/ cluster + ((50, 4), (1, 2)), # large grid w/ cluster + ) + def test_dynamic_work_scheduling(self, grid, cluster): + if not jtu.is_cuda_compute_capability_at_least("10.0"): + self.skipTest("Only works on a GPU with capability >= sm100a") + grid_names = tuple(str(i) for i in range(len(grid))) + cluster_names = tuple("c"+str(i) for i in range(len(cluster))) + def body(out_gmem, _): + sm_idx = lax.axis_index(grid_names) + cluster_idx = () + if cluster: + cluster_idx = tuple(lax.axis_index(axis) for axis in cluster_names) + @plgpu.dynamic_scheduling_loop(grid_names) + def loop_body(loop_info: plgpu.NDLoopInfo): + out_gmem[*loop_info.index, *cluster_idx] = sm_idx + out_shape = (*grid, *cluster) + max_shared_memory = jax.local_devices()[0].shared_memory_per_block_optin + # Mosaic GPU uses some shared memory implicitly, so we can't + # explicitly request the full amount. + large_amount_of_shared_memory = int(0.9 * max_shared_memory) + result = self.kernel(body, + out_shape=jax.ShapeDtypeStruct(out_shape, jnp.int32), + grid=grid, + grid_names=grid_names, + cluster=cluster, + cluster_names=cluster_names, + # Allocate a large amount of SMEM to prevent multiple blocks + # being scheduled on the same SM. + scratch_shapes=[ + plgpu.SMEM((large_amount_of_shared_memory,), jnp.int8)], + )() + + # Result maps grid_idx -> SM that performed the work. + # Check that each SM had at least 1 block of work. + cluster_size = int(np.prod(cluster)) + num_sms = min(jax.devices()[0].core_count // cluster_size, np.prod(grid)) + histogram = np.histogram(result, bins=range(num_sms+1))[0] + self.assertEqual(np.sum(histogram), np.prod(out_shape)) + self.assertGreaterEqual(np.min(histogram), 1) + # Make sure all blocks > num_sms were stolen. + self.assertEqual(np.max(result), jnp.int32(num_sms) - 1) + + def test_dynamic_work_scheduling_with_carry(self): + if not jtu.is_cuda_compute_capability_at_least("10.0"): + self.skipTest("Only works on a GPU with capability >= sm100a") + # In this test we make SM 0 run a the dynamic scheduling loop while all + # other SMs spin. This means SM 0 should steal all of the work and we + # keep track of the number of stolen blocks in the carry. + blocks_to_steal = 100 + sm_count = jax.devices()[0].core_count + def body(out_gmem, _): + sm_idx = lax.axis_index("x") + global_semaphore = pl.get_global(plgpu.SemaphoreType.REGULAR) + + @pl.when(sm_idx == 0) + def _steal_loop(): + def loop_body(loop_info: plgpu.NDLoopInfo, carry: jax.Array): + del loop_info + return carry + jnp.int32(1) + + final_carry = plgpu.dynamic_scheduling_loop( + ("x",), init_carry=jnp.int32(0) + )(loop_body) + out_gmem[0] = final_carry + pl.semaphore_signal(global_semaphore, inc=sm_count) + + # All SMs wait until SM 0 has finished all blocks. + pl.semaphore_wait(global_semaphore) + + max_shared_memory = jax.local_devices()[0].shared_memory_per_block_optin + # Mosaic GPU uses some shared memory implicitly, so we can't + # explicitly request the full amount. + large_amount_of_shared_memory = int(0.9 * max_shared_memory) + result = self.kernel(body, + out_shape=jax.ShapeDtypeStruct((1,), jnp.int32), + grid=(sm_count + blocks_to_steal,), + grid_names=("x",), + # Allocate a large amount of SMEM to prevent multiple blocks + # being scheduled on the same SM. + scratch_shapes=[ + plgpu.SMEM((large_amount_of_shared_memory,), jnp.int8)], + )() + self.assertEqual(result[0], blocks_to_steal + 1) + + if __name__ == "__main__": absltest.main() diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 0fc375bf64a1..26071b23e66a 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Sequence +from collections.abc import Callable, Sequence import functools import itertools import math import sys -from typing import Any, Callable +from typing import Any import unittest from absl.testing import absltest @@ -30,6 +30,9 @@ from jax._src import linear_util as lu from jax._src import state from jax._src import test_util as jtu +from jax._src.pallas import pallas_call +from jax._src.pallas import pallas_test_util as ptu +from jax._src.pallas import primitives as pallas_primitives from jax.experimental import pallas as pl from jax.interpreters import partial_eval as pe import jax.numpy as jnp @@ -40,31 +43,26 @@ from jax.experimental.pallas import mosaic_gpu as plgpu_mgpu except ImportError: plgpu_mgpu = None - from jax.experimental.pallas import triton as plgpu_triton from jax.experimental.pallas import tpu as pltpu else: plgpu_mgpu = None - plgpu_triton = None pltpu = None -try: - import hypothesis as hp -except (ModuleNotFoundError, ImportError): - raise unittest.SkipTest("tests depend on hypothesis library") - +import hypothesis as hp import hypothesis.extra.numpy as hnp import hypothesis.strategies as hps + # There are many inherited redefinitions of _ # ruff: noqa: F811 jax.config.parse_flags_with_absl() jtu.setup_hypothesis(max_examples=50) -use_mosaic_gpu = jax.config.read("jax_pallas_use_mosaic_gpu") +use_mosaic_gpu = pallas_call._PALLAS_USE_MOSAIC_GPU.value -intx = dtypes.canonicalize_dtype(jnp.int64) -floatx = dtypes.canonicalize_dtype(jnp.float64) +intx = dtypes.default_int_dtype() +floatx = dtypes.default_float_dtype() def wrap_init(f: Callable, nr_args: int): # wrapper for lu.wrap_init with debugging info @@ -76,6 +74,34 @@ def is_power_of_two(n: int) -> bool: return (n > 0) and (n & (n - 1) == 0) +def get_rocm_shared_memory_limit() -> int: + """Get the shared memory (LDS) limit in bytes for ROCm devices. + + Queries rocminfo to get the GROUP segment size dynamically. + Returns 64KB as default if rocminfo fails (MI100/MI200/MI300 all have 64KB LDS). + """ + try: + result = subprocess.run( + ['rocminfo'], capture_output=True, text=True, timeout=10 + ) + if result.returncode != 0: + return 64 * 1024 # Default if rocminfo fails + lines = result.stdout.split('\n') + for i, line in enumerate(lines): + if 'Segment:' in line and 'GROUP' in line: + if i + 1 < len(lines): + size_line = lines[i + 1] + # Match "Size: () KB" with case-insensitive KB check + match = re.search(r'Size:\s+(\d+)\s*\([^)]+\)\s*KB', size_line, re.IGNORECASE) + if match: + size_kb = int(match.group(1)) + return size_kb * 1024 # Convert KB to bytes + except Exception: + pass + # Default for AMD GPUs (MI100/MI200/MI300 all have 64KB LDS) + return 64 * 1024 + + def smem_on_tpu(): if jtu.test_device_matches(["tpu"]): return pltpu.SMEM @@ -127,7 +153,7 @@ def make_shape_dtype_strategy( min_size_exp: int, max_size_exp: int, valid_dtypes: Sequence[jnp.dtype], - max_bytes: int = 2**16, + max_bytes: int = 2**15, ) -> jax.ShapeDtypeStruct: dtype = draw(hps.sampled_from(valid_dtypes)) # To generate shapes with power-of-two sizes, we draw the exponents of the @@ -175,6 +201,8 @@ def select_n_strategy( # TODO(sharadmv,apaszke): enable bf16 # np.dtype(jnp.bfloat16), ], + # Max 4K bytes. Helps avoid slow input generation. + max_bytes=2**12, ) ) allowed_elements = hps.integers(min_value=0, max_value=n_cases - 1) @@ -187,7 +215,7 @@ def select_n_strategy( else: pred_dtype = np.int32 pred = draw(arrays(shape=pred_shape, dtype=pred_dtype, - elements=allowed_elements)) + elements=allowed_elements)) cases = ( draw( arrays(shape=case_shape_dtype.shape, dtype=case_shape_dtype.dtype) @@ -203,7 +231,7 @@ def select_n_strategy( # TODO(sharadmv,apaszke): enable zero dim sizes # TODO(sharadmv,apaszke): enable one dim sizes ( - lax.neg_p, + lax.neg_p, {}, make_shape_dtype_strategy( min_rank=2, max_rank=3, @@ -213,7 +241,7 @@ def select_n_strategy( ), ), ( - lax.not_p, + lax.not_p, {}, make_shape_dtype_strategy( min_rank=2, max_rank=3, @@ -225,6 +253,7 @@ def select_n_strategy( *[ ( prim, + params, make_shape_dtype_strategy( min_rank=2, max_rank=3, @@ -233,23 +262,23 @@ def select_n_strategy( valid_dtypes=[jnp.dtype("float32")], ), ) - for prim in [ - lax.exp_p, - lax.tanh_p, - lax.logistic_p, - lax.rsqrt_p, - lax.log_p, - lax.exp2_p, - lax.abs_p, - lax.log1p_p, - lax.sin_p, - lax.sqrt_p, + for prim, params in [ + (lax.abs_p, {}), + (lax.exp_p, {"accuracy": None}), + (lax.tanh_p, {"accuracy": None}), + (lax.logistic_p, {"accuracy": None}), + (lax.rsqrt_p, {"accuracy": None}), + (lax.log_p, {"accuracy": None}), + (lax.exp2_p, {"accuracy": None}), + (lax.log1p_p, {"accuracy": None}), + (lax.sin_p, {"accuracy": None}), + (lax.sqrt_p, {"accuracy": None}), ] ], ] UNARY_FUNCTIONS = [ - (prim.name, prim.bind, strategy) for prim, strategy in UNARY_PRIMITIVES + (prim.name, functools.partial(prim.bind, **params), strategy) for prim, params, strategy in UNARY_PRIMITIVES ] + [ ( name, @@ -273,38 +302,25 @@ def select_n_strategy( ] -class PallasBaseTest(jtu.JaxTestCase): - INTERPRET = False - - def setUp(self): - if not self.INTERPRET: - if jtu.device_under_test() == "cpu": - self.skipTest("Only interpret mode supported on CPU") - if (jtu.test_device_matches(["cuda"]) and - not jtu.is_cuda_compute_capability_at_least("8.0")): - self.skipTest("Only works on GPUs with capability >= sm80") - if (jtu.test_device_matches(["cuda"]) and use_mosaic_gpu and - not jtu.is_cuda_compute_capability_at_least("9.0")): - self.skipTest("Mosaic GPU requires capability >= sm90") - - super().setUp() +class PallasBaseTest(ptu.PallasTest): @classmethod def pallas_call(cls, *args, **kwargs): if jtu.test_device_matches(["cuda"]) and use_mosaic_gpu: assert plgpu_mgpu is not None - compiler_params = plgpu_mgpu.GPUCompilerParams( - thread_semantics=plgpu_mgpu.ThreadSemantics.Warpgroup + compiler_params = plgpu_mgpu.CompilerParams( + lowering_semantics=plgpu_mgpu.LoweringSemantics.Warpgroup ) kwargs["compiler_params"] = compiler_params return pl.pallas_call(*args, interpret=cls.INTERPRET, **kwargs) def skip_if_mosaic_gpu(self): - if jtu.test_device_matches(["cuda"]) and use_mosaic_gpu: + if jtu.test_device_matches(["gpu"]) and use_mosaic_gpu: self.skipTest("TODO: Mosaic GPU does not support this yet") +@jtu.thread_unsafe_test_class(condition=not jtu.hypothesis_is_thread_safe()) class OpsTest(PallasBaseTest): @parameterized.named_parameters( @@ -329,7 +345,7 @@ def kernel(x_ref, y_ref, o_ref): x = jnp.full((8, 128), 4, dtype=dtype) y = jnp.full((8, 128), 2 if jnp.issubdtype(dtype, jnp.integer) else 2.0, - dtype=dtype) + dtype=dtype) np.testing.assert_allclose(kernel(x, y), fn(x, y)) @parameterized.named_parameters( @@ -524,9 +540,45 @@ def kernel(x_ref, o_ref): np.testing.assert_allclose(result[0, 0], reduction_op(x), atol=1e-5) + @parameterized.named_parameters( + ("sum", jnp.sum, (32, 256)), + ("max", jnp.max, (32, 256)), + ("min", jnp.min, (32, 256)), + ("sum_irregular", jnp.sum, (31, 300)), + ("max_irregular", jnp.max, (31, 300)), + ("min_irregular", jnp.min, (31, 300)), + ) + def test_reduce_int32(self, reduction_op, input_shape): + if jtu.test_device_matches(["gpu"]): + self.skipTest("TODO: error on GPU") + + def kernel(x_ref, o_ref): + o_ref[0, 0] = reduction_op(x_ref[...]) + + x = jax.random.randint( + jax.random.key(0), + shape=input_shape, + minval=-100, + maxval=100, + dtype=jnp.int32, + ) + result = self.pallas_call( + kernel, + in_specs=[ + pl.BlockSpec(input_shape, lambda *_: (0, 0)), + ], + out_specs=pl.BlockSpec((1, 1), memory_space=smem_on_tpu()), + out_shape=jax.ShapeDtypeStruct([1, 1], intx), + grid=(1,), + )(x) + + np.testing.assert_allclose(result[0, 0], reduction_op(x), atol=1e-5) + # TODO(sharadmv): test rank < 2, size < 2 @hp.given(select_n_strategy(max_cases=2, min_rank=2, max_rank=4, min_size_exp=1)) + @hp.settings(suppress_health_check=([hp.HealthCheck.too_slow] + if jtu.is_asan() else [])) def test_select_n(self, args): if jtu.test_device_matches(["gpu"]): self.skipTest("TODO: error on GPU, lowering bug for select_n") @@ -559,13 +611,18 @@ def kernel(*refs): for name, func, strategy in UNARY_FUNCTIONS ) @hp.given(hps.data()) + @jtu.skip_if_mosaic_gpu_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell") def test_unary_primitives(self, name, func, shape_dtype_strategy, data): - self.skip_if_mosaic_gpu() + if name in ["abs", "log1p", "pow2", "reciprocal", "relu", "sin", "sqrt"]: + self.skip_if_mosaic_gpu() if self.INTERPRET: self.skipTest("This hypothesis test is slow, even more so in interpret mode.") # We want exact equality here to match how JAX lowers to XLA tol = 0. + if jtu.test_device_matches(["tpu"]): + if name == "exp2": + tol = 1e-6 if jtu.test_device_matches(["gpu"]): if func == jnp.round or func == jnp.rint: self.skipTest("TODO: not implemented on GPU") @@ -577,6 +634,12 @@ def test_unary_primitives(self, name, func, shape_dtype_strategy, data): def kernel(x_ref, y_ref): y_ref[...] = func(x_ref[...]) x_shape_dtype = data.draw(shape_dtype_strategy) + + sut_is_mosaic_gpu = jtu.test_device_matches(["gpu"]) and use_mosaic_gpu + if sut_is_mosaic_gpu: + hp.assume(math.prod(x_shape_dtype.shape) % 128 == 0) + hp.assume(x_shape_dtype.shape[-1] >= 16) + key = random.key(0) x = _random_value(key, x_shape_dtype) out = self.pallas_call(kernel, out_shape=x_shape_dtype)(x) @@ -584,25 +647,20 @@ def kernel(x_ref, y_ref): @parameterized.product(from_dtype=_DTYPES_32BIT, to_dtype=_DTYPES) @hp.given(hps.data()) + @hp.settings(suppress_health_check=[hp.HealthCheck.too_slow]) # ASAN is slow def test_cast_from_32bit(self, from_dtype, to_dtype, data): + if jtu.test_device_matches(["cpu"]) and jtu.SKIP_SLOW_TESTS.value: + self.skipTest("Test is slow on CPU.") + sut_is_mosaic_gpu = jtu.test_device_matches(["gpu"]) and use_mosaic_gpu if to_dtype in {"float8_e4m3b11fnuz", "float8_e5m2", "float8_e4m3fn"}: - if not jtu.test_device_matches(["tpu"]) or jtu.get_tpu_version() < 5: + if not jtu.test_device_matches(["tpu"]): self.skipTest("Not supported on this hardware") - if not jtu.if_cloud_tpu_at_least(2025, 3, 8): - self.skipTest("Test requires libtpu from 2025/3/8 or later") - if from_dtype in {"int2", "uint2"} or to_dtype in {"int2", "uint2"}: - if jtu.test_device_matches(["tpu"]) and not jtu.if_cloud_tpu_at_least( - 2025, 4, 1 - ): - self.skipTest("Test requires libtpu from 2025/4/1 or later") if from_dtype == to_dtype: self.skipTest("Unnecessary test") if jtu.is_device_tpu(version=4): - if to_dtype in {"int8", "uint8", "int4", "uint4", "int2", "uint2"}: + if to_dtype in {"int2", "uint2"}: self.skipTest("Not supported on this TPU generation") - if to_dtype in {"int16", "uint16"} and not jtu.if_cloud_tpu_at_least(2025, 1, 18): - self.skipTest("Test requires libtpu from 2025/1/18 or later") if jtu.test_device_matches(["tpu"]) and jtu.get_tpu_version() < 4: # Currently only casts between 32-bit types and to bf16 are supported. if to_dtype not in {"int32", "uint32", "float32", "bfloat16"}: @@ -624,8 +682,10 @@ def test_cast_from_32bit(self, from_dtype, to_dtype, data): shape = (8, 128) if to_dtype in {"int2", "uint2"}: # Make sure #rows is a least the packing factor of int2. - # TODO(b/343490729): XLA convert(f32[16, 128]) fails on v5p. - shape = (32, 128) + # TODO: b/343490729 - XLA convert(f32[16, 128]) fails on v5p. + # TODO: b/459440496 - Support more shapes for int2. The number of rows is + # required to be an even multiple of 128. + shape = (128, 128) x = data.draw(hnp.arrays(from_dtype, shape, elements=elements)) x = jnp.asarray(x) def kernel(x_ref, y_ref): @@ -658,46 +718,22 @@ def test_cast_from_sub_32bit(self, from_dtype, to_dtype, randomize): if from_dtype == to_dtype: self.skipTest("Unnecessary test") - if from_dtype in {"int2", "uint2"} or to_dtype in {"int2", "uint2"}: - if jtu.test_device_matches(["tpu"]) and not jtu.if_cloud_tpu_at_least( - 2025, 4, 1 - ): - self.skipTest("Test requires libtpu from 2025/4/1 or later") if jtu.is_device_tpu(version=4): - allowed_v4_cats = {("int16", "int32"): (2025, 1, 18)} + allowed_v4_casts = {("int16", "int32")} if ( - from_dtype - in { - "int16", - "int8", - "uint16", - "uint8", - "int4", - "uint4", - "int2", - "uint2", - } - or to_dtype in {"int8", "uint8", "int4", "uint4", "int2", "uint2"} - ) and (from_dtype, to_dtype) not in allowed_v4_cats: + from_dtype in {"int2", "uint2"} or to_dtype in {"int2", "uint2"} + ) and (from_dtype, to_dtype) not in allowed_v4_casts: self.skipTest("Not supported on this TPU generation") - if minimum_libtpu_date := allowed_v4_cats.get((from_dtype, to_dtype), None): - if not jtu.if_cloud_tpu_at_least(*minimum_libtpu_date): - self.skipTest("Test requires a newer libtpu") - if to_dtype in {"int16", "uint16"} and not jtu.if_cloud_tpu_at_least(2025, 1, 18): - self.skipTest("Test requires libtpu from 2025/1/18 or later") if jtu.test_device_matches(["tpu"]) and jtu.get_tpu_version() < 4: self.skipTest("Not supported on this TPU generation") if jtu.test_device_matches(["gpu"]) and ( - to_dtype - in { - "int4", - "uint4", - "int2", - "uint2", - } + to_dtype in {"int4", "uint4", "int2", "uint2"} or from_dtype in {"int2", "uint2"} ): self.skipTest("sub-byte casts are buggy on GPU") # b/391292861 + if self.INTERPRET and (to_dtype in {"int2", "uint2"} or + from_dtype in {"int2", "uint2"}): + self.skipTest("Test fails on CPU.") if from_dtype == "float16" or to_dtype == "float16" and not sut_is_mosaic_gpu: self.skipTest("float16 is only supported with Mosaic GPU") if sut_is_mosaic_gpu: @@ -712,11 +748,9 @@ def test_cast_from_sub_32bit(self, from_dtype, to_dtype, randomize): "float8_e5m2", "float8_e4m3fn", } or to_dtype in {"float8_e4m3b11fnuz", "float8_e5m2", "float8_e4m3fn"}: - if not jtu.test_device_matches(["tpu"]) or jtu.get_tpu_version() < 5: + if not jtu.test_device_matches(["tpu"]): self.skipTest("Not supported on this hardware") - if not jtu.if_cloud_tpu_at_least(2025, 3, 9): - self.skipTest("Test requires libtpu from 2025/3/9 or later") - if from_dtype == "int2" and to_dtype == "bool": + if from_dtype in ("uint2", "int2") and to_dtype == "bool": self.skipTest( "TODO(b/343490729): XLA compare(s2, s2) yields wrong results" ) @@ -732,14 +766,6 @@ def test_cast_from_sub_32bit(self, from_dtype, to_dtype, randomize): "TODO(b/401624977): Mask on int2 is not yet supported in Mosaic" ) - from_int = np.issubdtype(np.dtype(from_dtype), np.integer) - to_int = np.issubdtype(np.dtype(to_dtype), np.integer) - if ( - from_int and to_int and np.dtype(from_dtype).itemsize != 4 - and not jtu.if_cloud_tpu_at_least(2025, 1, 12) - ): - self.skipTest("trunc from non-32 bit only implemented recently") - # TODO(sharadmv,apaszke): add support for the following casts if from_dtype == "bool" and to_dtype in { "int16", @@ -753,16 +779,8 @@ def test_cast_from_sub_32bit(self, from_dtype, to_dtype, randomize): }: self.skipTest("Not supported: cannot extend to sub-32 bit types") - def bitwidth(dtype): - if jnp.issubdtype(dtype, jnp.integer): - return jnp.iinfo(dtype).bits - elif jnp.issubdtype(dtype, jnp.floating): - return jnp.finfo(dtype).bits - else: - raise ValueError(f"Unsupported dtype: {dtype}") - if from_dtype != "bool": - from_bitwidth = bitwidth(from_dtype) + from_bitwidth = dtypes.itemsize_bits(from_dtype) from_int_dtype = getattr(jnp, "uint" + str(from_bitwidth)) if randomize: # randint has no support for 4 bit integers. @@ -1001,9 +1019,25 @@ def kernel(x_ref, o_ref): -0.2, jnp.inf, -jnp.inf, jnp.nan, 0.0, 1.0, -1.0, 0.5, ] - def test_is_finite(self): - if jtu.test_device_matches(["gpu"]): - self.skipTest("Not supported on GPU") + @parameterized.named_parameters( + (dtype.__name__, dtype) + for dtype in (jnp.float32, jnp.float16, jnp.bfloat16) + # other dtypes are TBD once is_nan and is_inf supports them + ) + def test_is_finite(self, dtype): + if jtu.test_device_matches(["tpu"]) and dtype != jnp.float32: + # The original test worked only fp32@TPU. Have no way to test TPU with other types. + self.skipTest("Not tested on TPU, todo for the respective team") + if jtu.test_device_matches(["cuda"]): + # The original test worked only on fp32@TPU, have no way to test CUDA + self.skipTest("Not tested on CUDA, todo for the respective team") + + if jtu.test_device_matches(["cuda"]): + self.skipTest("Not tested on CUDA") # set this b/c this how the test was + # originally configured. Have no way to test cuda. + + if jtu.is_device_rocm(): + self.skipTest("is_finite not in Triton lowering for jax 0.9.0") size = len(self.IS_FINITE_TEST_VALUES) @@ -1014,14 +1048,49 @@ def test_is_finite(self): def kernel(x_ref, o_ref): o_ref[...] = lax.is_finite(x_ref[...]) - x = jnp.array(self.IS_FINITE_TEST_VALUES, dtype=jnp.float32) + x = jnp.array(self.IS_FINITE_TEST_VALUES, dtype=dtype) out = kernel(x) expected = lax.is_finite(x) self.assertArraysEqual(out, expected) - def test_is_finite_scalar(self): - if jtu.test_device_matches(["gpu"]): - self.skipTest("Not supported on GPU") + @parameterized.parameters(jnp.float32, jnp.bfloat16, jnp.int32, jnp.int16) + def test_clamp(self, dtype): + if dtype == jnp.int16 and jtu.test_device_matches(["tpu"]): + self.skipTest("int16 is not supported on TPU") + + k1, k2, k3 = random.split(jax.random.key(0), num=3) + if jnp.issubdtype(dtype, jnp.floating): + lo_ = random.normal(k1, (8, 128), dtype=dtype) + hi_ = random.normal(k2, (8, 128), dtype=dtype) + x = random.normal(k3, (8, 128), dtype=dtype) + else: + lo_ = random.randint(k1, (8, 128), -100, 100, dtype=dtype) + hi_ = random.randint(k2, (8, 128), -100, 100, dtype=dtype) + x = random.randint(k3, (8, 128), -100, 100, dtype=dtype) + lo = jnp.minimum(lo_, hi_) + hi = jnp.maximum(lo_, hi_) + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((8, 128), dtype), + ) + def kernel(lo_ref, x_ref, hi_ref, o_ref): + o_ref[...] = lax.clamp(lo_ref[...], x_ref[...], hi_ref[...]) + np.testing.assert_array_equal(kernel(lo, x, hi), lax.clamp(lo, x, hi)) + + @parameterized.named_parameters( + (dtype.__name__, dtype) + for dtype in (jnp.float32, jnp.float16, jnp.bfloat16) + # other dtypes are TBD once is_nan and is_inf supports them + ) + def test_is_finite_scalar(self, dtype): + if jtu.test_device_matches(["tpu"]) and dtype != jnp.float32: + # The original test worked only fp32@TPU. Have no way to test TPU with other types. + self.skipTest("Not tested on TPU, todo for the respective team") + if jtu.test_device_matches(["cuda"]): + # The original test worked only on fp32@TPU, have no way to test CUDA + self.skipTest("Not tested on CUDA, todo for the respective team") + + if jtu.is_device_rocm(): + self.skipTest("is_finite not in Triton lowering for jax 0.9.0") size = len(self.IS_FINITE_TEST_VALUES) @@ -1035,7 +1104,7 @@ def kernel(x_ref, o_ref): for i in range(8): o_ref[i] = jnp.isfinite(x_ref[i]) - x = jnp.array(self.IS_FINITE_TEST_VALUES, dtype=jnp.float32) + x = jnp.array(self.IS_FINITE_TEST_VALUES, dtype=dtype) out = kernel(x) expected = lax.is_finite(x) self.assertArraysEqual(out, expected) @@ -1061,8 +1130,8 @@ def kernel(x_ref, o_ref): ( # fmt: off [jnp.expm1, jnp.log1p, jnp.cbrt, lax.rsqrt, jnp.tan, jnp.asin, - jnp.acos, jnp.atan, jnp.sinh, jnp.cosh, jnp.tanh, jnp.asinh, - jnp.acosh, jnp.atanh], + jnp.acos, jnp.atan, jnp.sinh, jnp.cosh, jnp.tanh, jnp.asinh, + jnp.acosh, jnp.atanh], # fmt: on ["bfloat16", "float32", "float64"], ), @@ -1076,7 +1145,8 @@ def kernel(x_ref, o_ref): for fn, dtype in itertools.product(*args) ) def test_elementwise(self, fn, dtype): - self.skip_if_mosaic_gpu() + if fn not in (jnp.sin, jnp.cos) or dtype == "float64": + self.skip_if_mosaic_gpu() if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8: self.skipTest("64-bit types require x64_enabled") @@ -1086,7 +1156,7 @@ def test_elementwise(self, fn, dtype): self.skipTest("int16 and float16 are not supported on TPU") if ( fn in (jnp.ceil, jnp.floor, jnp.negative, jnp.exp, jnp.exp2, jnp.log, - jnp.sqrt, lax.rsqrt) + jnp.sqrt, lax.rsqrt) and dtype == "bfloat16" and not jtu.is_device_tpu_at_least(6) ): @@ -1102,11 +1172,6 @@ def test_elementwise(self, fn, dtype): jnp.cbrt, jnp.cosh, jnp.expm1, jnp.sinh, ): self.skipTest(f"{fn.__name__} not implemented on TPU") - # TODO(apaszke): Remove after 12 weeks have passed. - if not jtu.if_cloud_tpu_at_least(2024, 12, 19): - self.skipTest("Requires libtpu built at least on 2024-12-19") - if fn == jnp.exp2 and dtype == "bfloat16" and not jtu.if_cloud_tpu_at_least(2025, 1, 31): - self.skipTest("Test requires newer libtpu") if ( jtu.test_device_matches(["gpu"]) @@ -1285,8 +1350,6 @@ def kernel(x_ref, y_ref, o_ref): ) ) def test_comparison(self, fn, dtype): - self.skip_if_mosaic_gpu() - if jtu.test_device_matches(["gpu"]) and dtype == jnp.bool_: self.skipTest("Not implemented on GPU.") @@ -1295,16 +1358,16 @@ def test_comparison(self, fn, dtype): @functools.partial( self.pallas_call, - out_shape=jax.ShapeDtypeStruct((8,), jnp.bool_), + out_shape=jax.ShapeDtypeStruct((128,), jnp.int32), ) def kernel(x_ref, y_ref, o_ref): - o_ref[:] = fn(x_ref[...], y_ref[...]) + o_ref[:] = fn(x_ref[...], y_ref[...]).astype(jnp.int32) - x = jnp.array([0, 3, -4, -6, 0, 5, 4, -7]).astype(dtype) - y = jnp.array([3, 1, -4, -5, 0, -2, 2, 4]).astype(dtype) + x = jnp.tile(jnp.array([0, 3, -4, -6, 0, 5, 4, -7]).astype(dtype), 16) + y = jnp.tile(jnp.array([3, 1, -4, -5, 0, -2, 2, 4]).astype(dtype), 16) out = kernel(x, y) expected = fn(x, y) - self.assertArraysEqual(out, expected) + self.assertArraysEqual(out != 0, expected) @parameterized.named_parameters( (f"{fn.__name__}_{dtype.__name__}", fn, dtype) @@ -1314,7 +1377,8 @@ def kernel(x_ref, y_ref, o_ref): ) ) def test_comparison_scalar(self, fn, dtype): - self.skip_if_mosaic_gpu() + if jtu.test_device_matches(["cpu"]) and jtu.SKIP_SLOW_TESTS.value: + self.skipTest("Test is slow on CPU.") if jtu.test_device_matches(["tpu"]) and dtype == jnp.float16: self.skipTest("float16 is not supported on TPU") @@ -1325,6 +1389,9 @@ def test_comparison_scalar(self, fn, dtype): ): self.skipTest("Only works on GPUs with capability >= sm80") + if jtu.test_device_matches(["gpu"]) and dtype == jnp.bool_: + self.skip_if_mosaic_gpu() + @functools.partial( self.pallas_call, in_specs=( @@ -1332,17 +1399,17 @@ def test_comparison_scalar(self, fn, dtype): pl.BlockSpec(memory_space=smem_on_tpu()), ), out_specs=pl.BlockSpec(memory_space=smem_on_tpu()), - out_shape=jax.ShapeDtypeStruct((8,), jnp.bool_), + out_shape=jax.ShapeDtypeStruct((128,), jnp.int32), ) def kernel(x_ref, y_ref, o_ref): - for i in range(8): - o_ref[i] = fn(x_ref[i], y_ref[i]) + for i in range(128): + o_ref[i] = fn(x_ref[i], y_ref[i]).astype(jnp.int32) - x = jnp.array([0, 3, -4, -6, 0, 5, 4, -7]).astype(dtype) - y = jnp.array([3, 1, -4, -5, 0, -2, 2, 4]).astype(dtype) + x = jnp.tile(jnp.array([0, 3, -4, -6, 0, 5, 4, -7]).astype(dtype), 16) + y = jnp.tile(jnp.array([3, 1, -4, -5, 0, -2, 2, 4]).astype(dtype), 16) out = kernel(x, y) expected = fn(x, y) - self.assertArraysEqual(out, expected) + self.assertArraysEqual(out != 0, expected) def test_isnan(self): self.skip_if_mosaic_gpu() @@ -1373,7 +1440,10 @@ def kernel(x_ref, y_ref, out_ref): )(x, y) np.testing.assert_array_equal(out, jnp.einsum('mk,mn->kn', x, y)) - def test_dot_dims_batched_simple_dot_general(self): + @parameterized.product( + dtype=[jnp.float32, jnp.bfloat16], + ) + def test_dot_dims_batched_simple_dot_general(self, dtype): """This test is only meant to exercise a simple batch lowering of dot_general. It is not meant to be a comprehensive test of dot_general lowering, for that @@ -1382,24 +1452,335 @@ def test_dot_dims_batched_simple_dot_general(self): if jtu.test_device_matches(["gpu"]): self.skipTest("TPU only test") - x = jnp.arange(11 * 16 * 256, dtype=jnp.float32).reshape((11, 16, 256)) - y = jnp.arange(11 * 256 * 128, dtype=jnp.float32).reshape((11, 256, 128)) + if jtu.test_device_matches(["tpu"]): + if dtype == jnp.bfloat16: + if not jtu.is_device_tpu_at_least(version=4): + self.skipTest("Requires TPUv4+") + + k1, k2 = random.split(jax.random.key(0)) + x = random.normal(k1, (11, 16, 256), dtype=dtype) + y = random.normal(k2, (11, 256, 128), dtype=dtype) def kernel(x_ref, y_ref, out_ref): out_ref[...] = jax.lax.dot_general( x_ref[...], y_ref[...], dimension_numbers=(([2], [1]), ([0], [0])), + preferred_element_type=jnp.float32, ) out = self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct((11, 16, 128), jnp.float32) )(x, y) - np.testing.assert_array_equal( + np.testing.assert_allclose( out, - jax.lax.dot_general(x, y, dimension_numbers=(([2], [1]), ([0], [0]))), + jax.lax.dot_general( + x, + y, + dimension_numbers=(([2], [1]), ([0], [0])), + preferred_element_type=jnp.float32, + ), + rtol=5e-3 if dtype == jnp.bfloat16 else 1e-7, + ) + + @parameterized.product( + batch_size=(None, 1, 2), + # dims_numbers is without batch dims + shapes_and_dims_numbers=( + # leading lhs non contracting dims. + ((8, 8, 256), (256, 128), ([2], [0])), + ((3, 4, 128), (256, 128), ([2], [1])), + # trailing lhs non contracting dims. + ((256, 8, 128), (256, 128), ([0], [0])), + ((128, 8, 128), (256, 128), ([0], [1])), + # leading rhs non contracting dims. + ((8, 128), (8, 128, 128), ([1], [2])), + ((128, 8), (8, 128, 128), ([0], [2])), + # trailing rhs non contracting dims. + ((8, 128), (128, 8, 128), ([1], [0])), + ((128, 8), (128, 8, 128), ([0], [0])), + # leading lhs and rhs non contracting dims. + ((8, 8, 128), (8, 128, 128), ([2], [2])), + # leading lhs and trailing rhs non contracting dims. + ((8, 8, 128), (128, 8, 128), ([2], [0])), + # trailing lhs and leading rhs non contracting dims. + ((32, 8, 128), (8, 128, 32), ([0], [2])), + # trailing lhs and trailing rhs non contracting dims. + ((8, 8, 128), (8, 8, 128), ([0], [0])), + # non-contiguous lhs non contracting dims and leading rhs non contracting dims. + ((2, 128, 128), (2, 128, 128), ([1], [2])), + # non-contiguous lhs non contracting dims and trailing rhs non contracting dims. + ((2, 128, 128), (128, 2, 128), ([1], [0])), + # leading lhs non contracting dims and non-contiguous rhs non contracting dims. + ((2, 128, 128), (2, 128, 128), ([2], [1])), + # trailing lhs non contracting dims and non-contiguous rhs non contracting dims. + ((128, 2, 128), (2, 128, 128), ([0], [1])), + # non-contiguous lhs and rhs non contracting dims. + ((2, 128, 128), (2, 128, 128), ([1], [1])), + ), + ) + def test_dot_general_multiple_non_contracting_dims( + self, batch_size, shapes_and_dims_numbers + ): + if jtu.test_device_matches(["gpu"]): + self.skipTest("TPU only test") + + if jtu.test_device_matches(["tpu"]) and not jtu.is_cloud_tpu_at_least( + 2025, 10, 5 + ): + self.skipTest("Requires libtpu built after 2025-10-05") + + x_shape, y_shape, dims_numbers = shapes_and_dims_numbers + if batch_size is not None: + x_shape = (batch_size,) + x_shape + y_shape = (batch_size,) + y_shape + + # Batch size is always the first dimension so we need to offset + # dims_numbers by 1. + def offset_by_one(x): + return [a + 1 for a in x] + + dims_numbers = ( + (offset_by_one(dims_numbers[0]), offset_by_one(dims_numbers[1])), + ([0], [0]), + ) + else: + dims_numbers = ( + (dims_numbers[0], dims_numbers[1]), + ([], []), + ) + + k1, k2 = random.split(jax.random.key(0)) + x = jax.random.normal(k1, x_shape, dtype=jnp.float32) + y = jax.random.normal(k2, y_shape, dtype=jnp.float32) + + # Just infer shape from jax. + expected = jax.lax.dot_general(x, y, dimension_numbers=dims_numbers) + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct(expected.shape, jnp.float32), + ) + def kernel(x_ref, y_ref, out_ref): + out_ref[...] = jax.lax.dot_general( + x_ref[...], + y_ref[...], + dimension_numbers=dims_numbers, + ) + + np.testing.assert_allclose( + kernel(x, y), + expected, ) + @parameterized.product( + shapes_and_dims_numbers=( + ((3, 4, 128), (4, 2, 128), (((2,), (2,)), ((1,), (0,)))), + ((3, 4, 128), (2, 4, 128), (((2,), (2,)), ((1,), (1,)))), + ((3, 4, 256), (2, 3, 256), (((2,), (2,)), ((0,), (1,)))), + ((4, 3, 2, 32), (2, 128, 32, 2), (((3,), (2,)), ((2,), (3,)))), + ), + ) + def test_dot_general_non_front_batch_dims(self, shapes_and_dims_numbers): + if jtu.test_device_matches(["gpu"]): + self.skipTest("TPU only test") + + if jtu.test_device_matches(["tpu"]) and not jtu.is_cloud_tpu_at_least( + 2025, 11, 30 + ): + self.skipTest("Requires libtpu built after 2025-11-30") + + x_shape, y_shape, dims_numbers = shapes_and_dims_numbers + + k1, k2 = random.split(jax.random.key(0)) + x = jax.random.normal(k1, x_shape, dtype=jnp.float32) + y = jax.random.normal(k2, y_shape, dtype=jnp.float32) + + # Just infer shape from jax. + expected = jax.lax.dot_general(x, y, dimension_numbers=dims_numbers) + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct(expected.shape, jnp.float32), + ) + def kernel(x_ref, y_ref, out_ref): + out_ref[...] = jax.lax.dot_general( + x_ref[...], + y_ref[...], + dimension_numbers=dims_numbers, + ) + + np.testing.assert_allclose(kernel(x, y), expected, atol=1e-5, rtol=1e-5) + + @parameterized.product( + batch_size=(None, 1, 2), + # dims_numbers is without batch dims + shapes_and_dims_numbers=( + # Case with LHS already being a transpose + ( + (4, 3, 2, 6), + (4, 6), + ([2], [1]), + (2, 1, 3, 0), + None, + ), + # Case with RHS already being a transpose + ( + (3, 4, 2, 6), + (6, 4), + ([1], [0]), + None, + (1, 0), + ), + # Case with both LHS and RHS already being transposes + ( + (3, 2, 6, 4), + (4, 6), + ([2], [1]), + (0, 1, 3, 2), + (1, 0), + ), + ), + ) + def test_dot_general_multiple_non_contracting_dims_with_transposes( + self, batch_size, shapes_and_dims_numbers + ): + if jtu.test_device_matches(["gpu"]): + self.skipTest("TPU only test") + + if jtu.test_device_matches(["tpu"]) and not jtu.is_cloud_tpu_at_least( + 2025, 10, 5 + ): + self.skipTest("Requires libtpu built after 2025-10-05") + + ( + x_shape_unbatched, + y_shape_unbatched, + dims_numbers_unbatched, + x_perm_unbatched, + y_perm_unbatched, + ) = shapes_and_dims_numbers + if batch_size is not None: + x_shape = (batch_size,) + x_shape_unbatched + y_shape = (batch_size,) + y_shape_unbatched + + x_perm = ( + tuple([0] + [i + 1 for i in x_perm_unbatched]) + if x_perm_unbatched is not None + else None + ) + y_perm = ( + tuple([0] + [i + 1 for i in y_perm_unbatched]) + if y_perm_unbatched is not None + else None + ) + + # Batch size is always the first dimension so we need to offset + # dims_numbers by 1. + def offset_by_one(x): + return [a + 1 for a in x] + + dims_numbers = ( + ( + offset_by_one(dims_numbers_unbatched[0]), + offset_by_one(dims_numbers_unbatched[1]), + ), + ([0], [0]), + ) + else: + x_shape = x_shape_unbatched + y_shape = y_shape_unbatched + x_perm = x_perm_unbatched + y_perm = y_perm_unbatched + dims_numbers = ( + (dims_numbers_unbatched[0], dims_numbers_unbatched[1]), + ([], []), + ) + + k1, k2 = random.split(jax.random.key(0)) + x = jax.random.normal(k1, x_shape, dtype=jnp.float32) + y = jax.random.normal(k2, y_shape, dtype=jnp.float32) + expected = jax.lax.dot_general( + jnp.transpose(x, x_perm) if x_perm else x, + jnp.transpose(y, y_perm) if y_perm else y, + dimension_numbers=dims_numbers, + ) + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct(expected.shape, jnp.float32), + ) + def kernel(x_ref, y_ref, out_ref): + if x_perm: + x_in_kernel = jnp.transpose(x_ref[...], x_perm) + else: + x_in_kernel = x_ref[...] + if y_perm: + y_in_kernel = jnp.transpose(y_ref[...], y_perm) + else: + y_in_kernel = y_ref[...] + out_ref[...] = jax.lax.dot_general( + x_in_kernel, + y_in_kernel, + dimension_numbers=dims_numbers, + ) + + np.testing.assert_allclose(kernel(x, y), expected, atol=1e-6, rtol=1e-6) + + @parameterized.product( + batch_size=(None, 1, 2), + lhs_non_contracting_shape=((8,), (8, 128)), + rhs_is_vector=(True, False), # [K] or [1, K]. + dtype=(jnp.float32,), + ) + def test_matrix_vector_like_dot_general( + self, + batch_size, + lhs_non_contracting_shape, + rhs_is_vector, + dtype, + ): + if jtu.test_device_matches(["gpu"]): + self.skipTest("TPU only test") + + if jtu.test_device_matches(["tpu"]): + if not jtu.is_device_tpu_at_least(5) and rhs_is_vector: + self.skipTest("Requires TPUv5+ for sublane gather") + + contracting_shape = 128 + rhs_non_contracting_shape = () if rhs_is_vector else (1,) + batch_shape = (batch_size,) if batch_size is not None else () + batch_dim = [0] if batch_size else [] + lhs_shape = (*batch_shape, *lhs_non_contracting_shape, contracting_shape) + rhs_shape = (*batch_shape, *rhs_non_contracting_shape, contracting_shape) + k1, k2 = random.split(jax.random.key(0)) + lhs = jax.random.normal(k1, lhs_shape, dtype=dtype) + rhs = jax.random.normal(k2, rhs_shape, dtype=dtype) + dims_numbers = ( + ([len(lhs_shape) - 1], [len(rhs_shape) - 1]), + (batch_dim, batch_dim), + ) + expected = jax.lax.dot_general( + lhs, + rhs, + dimension_numbers=dims_numbers, + preferred_element_type=jnp.float32, + ) + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct(expected.shape, jnp.float32), + ) + def kernel(lhs_ref, rhs_ref, out_ref): + out_ref[...] = jax.lax.dot_general( + lhs_ref[...], + rhs_ref[...], + dimension_numbers=dims_numbers, + preferred_element_type=jnp.float32, + ) + + np.testing.assert_allclose(kernel(lhs, rhs), expected, atol=5e-6, rtol=5e-4) + @parameterized.parameters( ("int32", "float32"), ("float32", "float32"), @@ -1411,8 +1792,6 @@ def test_true_divide(self, dtype, out_dtype): if jtu.test_device_matches(["tpu"]): if out_dtype == "bfloat16" and not jtu.is_device_tpu_at_least(6): self.skipTest("bfloat16 is not supported on older TPU generations") - if not jtu.if_cloud_tpu_at_least(2025, 1, 9): - self.skipTest("Requires libtpu built after 2025-01-09") elif jtu.test_device_matches(["gpu"]): if dtype == "bfloat16": self.skipTest("bfloat16 not supported") @@ -1464,7 +1843,7 @@ def kernel(x_ref, y_ref, o_ref): ( # fmt: off [jnp.bitwise_and, jnp.bitwise_or, jnp.bitwise_xor, - jnp.bitwise_left_shift, jnp.bitwise_right_shift], + jnp.bitwise_left_shift, jnp.bitwise_right_shift], # fmt: on ["int32", "uint32"], ), @@ -1510,10 +1889,10 @@ def test_binary_scalar(self, f, dtype): @functools.partial( self.pallas_call, - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM), + pl.BlockSpec(memory_space=pltpu.SMEM), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), out_shape=jax.ShapeDtypeStruct((1,), dtype), ) def kernel(x_ref, y_ref, o_ref): @@ -1525,14 +1904,15 @@ def kernel(x_ref, y_ref, o_ref): np.testing.assert_allclose(f(x, y), kernel(x, y)) @parameterized.parameters( + ((32,), jnp.int32, 0), ((8, 4), jnp.int32, 0), ((8, 16), jnp.float32, 1), ((8, 16, 2), jnp.int8, 1), ) - def test_broadcasted_iota(self, shape, dtype, dimension): + def test_iota(self, shape, dtype, dimension): self.skip_if_mosaic_gpu() - if jtu.test_device_matches(["tpu"]): + if jtu.test_device_matches(["tpu"]) and dtype != jnp.int32: self.skipTest("Only 32-bit integer iota supported") f = lambda: jax.lax.broadcasted_iota(dtype, shape, dimension) @@ -1545,7 +1925,7 @@ def kernel(o_ref): np.testing.assert_allclose(f(), kernel()) - @parameterized.parameters("float16", "bfloat16", "float32") + @parameterized.parameters("float16", "bfloat16", "float32", "float64") def test_approx_tanh(self, dtype): self.skip_if_mosaic_gpu() @@ -1556,9 +1936,20 @@ def test_approx_tanh(self, dtype): self.skipTest("approx_tanh is not supported in interpret mode") if (dtype == "bfloat16" and + jtu.test_device_matches(["cuda"]) and not jtu.is_cuda_compute_capability_at_least("9.0")): self.skipTest("tanh.approx.bf16 requires a GPU with capability >= sm90") + if dtype == "float64": + if jtu.test_device_matches(["cuda"]): + self.skipTest("f64 approx_tanh is only supported on ROCm") + + # Enable x64 for f64 test if not already enabled, restore after test + original_x64 = jax.config.x64_enabled + if dtype == "float64" and not original_x64: + jax.config.update("jax_enable_x64", True) + self.addCleanup(lambda: jax.config.update("jax_enable_x64", False)) + @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), dtype), ) @@ -1575,113 +1966,6 @@ def kernel(x_ref, o_ref): rtol=5e-3, ) - def test_elementwise_inline_asm(self): - self.skip_if_mosaic_gpu() - - if jtu.test_device_matches(["tpu"]): - self.skipTest("Not implemented: elementwise_inline_asm_p") - - if self.INTERPRET: - self.skipTest( - "elementwise_inline_asm is not supported in interpret mode" - ) - - @functools.partial( - self.pallas_call, - out_shape=jax.ShapeDtypeStruct((256,), jnp.float16), - ) - def kernel(x_ref, o_ref): - [o_ref[...]] = plgpu_triton.elementwise_inline_asm( - "tanh.approx.f16x2 $0, $1;", - args=[x_ref[...]], - constraints="=r,r", - pack=2, - result_shape_dtypes=[jax.ShapeDtypeStruct(x_ref.shape, x_ref.dtype)], - ) - - x = jnp.arange(256).astype(jnp.float16) - np.testing.assert_allclose(kernel(x), jnp.tanh(x), atol=5e-3, rtol=5e-3) - - def test_debug_barrier(self): - self.skip_if_mosaic_gpu() - - if jtu.test_device_matches(["tpu"]): - self.skipTest("Not implemented: debug_barrier_p") - - if self.INTERPRET: - self.skipTest("debug_barrier is not supported in interpret mode") - - @functools.partial( - self.pallas_call, - out_shape=jax.ShapeDtypeStruct((2,), jnp.float32), - ) - def kernel(x_ref, o_ref): - o_ref[...] = x_ref[...] - plgpu_triton.debug_barrier() - - x = jnp.array([4.2, 2.4]).astype(jnp.float32) - np.testing.assert_array_equal(kernel(x), x) - - @unittest.skipIf( - sys.platform == "win32", - "plgpu_triton.TritonCompilerParams unavailable on Windows", - ) - def test_debug_print(self): - self.skip_if_mosaic_gpu() - - if jtu.test_device_matches(["tpu"]): - self.skipTest("Test for TPU is covered in tpu_pallas_test.py") - - # TODO: this test flakes on gpu - if jtu.test_device_matches(["gpu"]): - self.skipTest("This test flakes on gpu") - - @functools.partial( - self.pallas_call, - out_shape=jax.ShapeDtypeStruct((2,), jnp.float32), - compiler_params=plgpu_triton.TritonCompilerParams( - num_warps=1, num_stages=1 - ), - ) - def kernel(x_ref, o_ref): - pl.debug_print("It works!") - - x = jnp.array([4.2, 2.4]).astype(jnp.float32) - with jtu.capture_stdout() as output: - jax.block_until_ready(kernel(x)) - jax.effects_barrier() - - self.assertIn("It works!", output()) - - @unittest.skipIf( - sys.platform == "win32", - "plgpu_triton.TritonCompilerParams unavailable on Windows", - ) - def test_debug_print_with_values(self): - if jtu.test_device_matches(["tpu"]): - self.skipTest("Test for TPU is covered in tpu_pallas_test.py") - - # TODO: this test flakes on gpu - if jtu.test_device_matches(["gpu"]): - self.skipTest("This test flakes on gpu") - - @functools.partial( - self.pallas_call, - out_shape=jax.ShapeDtypeStruct((2,), jnp.float32), - compiler_params=plgpu_triton.TritonCompilerParams( - num_warps=1, num_stages=1 - ), - ) - def kernel(x_ref, o_ref): - pl.debug_print("x[0] =", x_ref[0]) - - x = jnp.array([4.2, 2.4]).astype(jnp.float32) - with jtu.capture_stdout() as output: - jax.block_until_ready(kernel(x)) - jax.effects_barrier() - - self.assertIn("x[0] = 4.2", output()) - @parameterized.parameters( ((2, 4), (8,)), ((2, 4), (8, 1)), @@ -1739,6 +2023,27 @@ def f(x_ref, o_ref): expected = x.reshape(out_shape) np.testing.assert_allclose(f(x), expected) + def test_reshape_to_scalar(self): + self.skip_if_mosaic_gpu() + # Test reshapes from (1, 1) to (). + # Because TPUs distinguish between VREGs/SREGs this tests an implicit + # copy from VREG -> SREG that must be inserted by Pallas. + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.int32), + ) + def f(x_ref, o_ref): + o_ref[...] = jnp.zeros_like(o_ref) + vector_val = x_ref[1:2, 0:1] + scalar_val = jnp.reshape(vector_val, ()) + o_ref[scalar_val] = jnp.ones_like(o_ref[0]) * scalar_val + + in_shape = (4, 4) + x = jnp.arange(int(np.prod(in_shape)), dtype=jnp.int32).reshape(in_shape) + expected = jnp.zeros((8, 128), jnp.int32) + expected = expected.at[x[1, 0]].set(x[1, 0]) + np.testing.assert_allclose(f(x), expected) + def test_num_programs(self): self.skip_if_mosaic_gpu() @@ -1779,30 +2084,47 @@ def copyitem(x_ref, in_idx_ref, out_idx_ref, o_ref): np.testing.assert_allclose(out[oi], x[ii]) np.testing.assert_allclose(out[oi + 1 :], jnp.zeros_like(out[oi + 1 :])) - @parameterized.parameters( - ((), (2,), ()), - ((1,), (2,), (0,)), - ((1, 1), (2, 2), (0, 1)), - ((), (2, 2), ()), + @parameterized.product( + shape_spec=[ + ((), (2,), ()), + ((1,), (2,), (0,)), + ((1, 128), (8, 128), (0, 1)), # row broadcasting + ((), (2, 2), ()), + ], + dtype=[jnp.int32, jnp.int16, jnp.int8, jnp.bool_], ) - def test_broadcast_in_dim(self, in_shape, out_shape, dims): + def test_broadcast_in_dim(self, shape_spec, dtype): self.skip_if_mosaic_gpu() - # The Pallas TPU lowering currently supports only blocks of rank >= 1 + in_shape, out_shape, dims = shape_spec if jtu.test_device_matches(["tpu"]): - self.skipTest("Not supported on TPU") + if not in_shape: + self.skipTest( + "The Pallas TPU lowering currently supports only blocks of rank" + " >= 1" + ) + if ( + len(in_shape) == 1 + and len(out_shape) == 1 + and dtype not in {jnp.int32, jnp.bool_} + ): + self.skipTest("Unsupported tiling") @functools.partial( self.pallas_call, - out_shape=jax.ShapeDtypeStruct(out_shape, jnp.float32), + out_shape=jax.ShapeDtypeStruct(out_shape, dtype), ) def f(x_ref, o_ref): x = x_ref[...] o_ref[...] = jax.lax.broadcast_in_dim(x, out_shape, dims) - x = jnp.arange(int(np.prod(in_shape)), dtype=jnp.float32).reshape(in_shape) + x = ( + jnp.arange(math.prod(in_shape), dtype=jnp.int32) + .reshape(in_shape) + .astype(dtype) + ) expected = jax.lax.broadcast_in_dim(x, out_shape, dims) - np.testing.assert_allclose(f(x), expected) + np.testing.assert_array_equal(f(x), expected) @parameterized.product( lhs_and_rhs_shape=[ @@ -1834,12 +2156,13 @@ def f(x_ref, o_ref): trans_x=[False, True], trans_y=[False, True], ) + @jtu.skip_if_triton_exceeds_shared_memory( + device_patterns=("RTX PRO 6000 Blackwell", "GB10$")) + @jtu.skip_if_mosaic_gpu_exceeds_shared_memory( + device_patterns=("RTX PRO 6000 Blackwell", "GB10$")) def test_dot(self, lhs_and_rhs_shape, dtype, trans_x, trans_y): self.skip_if_mosaic_gpu() - # TODO(apaszke): Remove after 12 weeks have passed. - if not jtu.if_cloud_tpu_at_least(2024, 12, 19): - self.skipTest("Requires libtpu built after 2024-12-19") lhs_shape, rhs_shape = lhs_and_rhs_shape final_lhs_shape = lhs_shape[::-1] if trans_x else lhs_shape @@ -1860,10 +2183,29 @@ def test_dot(self, lhs_and_rhs_shape, dtype, trans_x, trans_y): if jtu.test_device_matches(["gpu"]): if dtype == jnp.bfloat16: self.skipTest("bfloat16 type are not supported on GPU") - if ( - math.prod(lhs_shape) + math.prod(rhs_shape) + math.prod(out_shape) - > (256 * 256) * 2 - ): + # Check shared memory limit: Triton loads lhs + rhs into shared memory + if jtu.is_device_rocm(): + # ROCm: use correct formula with dynamic limit from rocminfo + dtype_size = jnp.dtype(dtype).itemsize + shared_mem_bytes = (math.prod(lhs_shape) + math.prod(rhs_shape)) * dtype_size + shared_mem_limit = get_rocm_shared_memory_limit() + if shared_mem_bytes > shared_mem_limit: + self.skipTest("Shared memory size limit exceeded") + else: + # NVIDIA: keep original check + if ( + math.prod(lhs_shape) + math.prod(rhs_shape) + math.prod(out_shape) + > (256 * 256) * 2 + ): + self.skipTest("Shared memory size limit exceeded") + if (jax.local_devices()[0].device_kind == "NVIDIA L4" and + dtype == jnp.float32 and + lhs_and_rhs_shape in [ + ((128, 16), (128, 256)), + ((16, 128), (128, 256)), + ((16, 256), (256, 128)), + ((256, 16), (256, 128)), + ]): self.skipTest("Shared memory size limit exceeded") if min(*lhs_shape, *rhs_shape) < 16: self.skipTest("All dimensions of lhs and rhs must be >= 16") @@ -1886,7 +2228,7 @@ def dot(x_ref, y_ref, o_ref): # Pallas always accumulates in FP32, so we are explicit about # preferred_element_type here. expected = jnp.dot(x.T if trans_x else x, y.T if trans_y else y, - preferred_element_type=jnp.float32).astype(dtype) + preferred_element_type=jnp.float32).astype(dtype) np.testing.assert_allclose( out.astype(jnp.float32), expected.astype(jnp.float32), @@ -1895,56 +2237,62 @@ def dot(x_ref, y_ref, o_ref): ) @parameterized.product( - size=[1, 2, 64, 129, 1021], - block_size=[1, 2, 32, 64, 128], + dtype=[jnp.int8, jnp.int4], + trans_x=[False, True], + trans_y=[False, True], ) - def test_masked_load_store(self, size, block_size): + def test_itof_dot_canonicalization(self, dtype, trans_x, trans_y): self.skip_if_mosaic_gpu() - - if jtu.test_device_matches(["tpu"]): - self.skipTest("Not implemented") + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Not supported on this hardware") + if jtu.get_tpu_version() != 7: + self.skipTest("The canonicalization pass being tested is on v7 only.") + if self.INTERPRET and dtype == jnp.int4: + self.skipTest("Interpret mode does not support int4") + lhs_shape = rhs_shape = out_shape = (256, 256) @functools.partial( self.pallas_call, - out_shape=(jax.ShapeDtypeStruct((size,), floatx)), - grid=pl.cdiv(size, block_size), + out_shape=jax.ShapeDtypeStruct(out_shape, jnp.int32), ) - def kernel(x_ref, o_ref): - idx = pl.program_id(0) * block_size + jnp.arange( - block_size, dtype=jnp.int32) - mask = idx < x_ref.shape[0] - x = pl.load(x_ref, (idx,), mask=mask) - pl.store(o_ref, (idx,), x + 1.0, mask=mask) - - key = random.key(0) - x = random.normal(key, (size,)) - np.testing.assert_allclose(kernel(x), x + 1.0, atol=1e-5, rtol=1e-5) - - def test_masked_oob_load_store_slice(self): - self.skip_if_mosaic_gpu() - - # The Pallas TPU lowering currently supports only blocks of rank >= 1 - if jtu.test_device_matches(["tpu"]): - self.skipTest("Not supported on TPU") - - n = 16 + def dot(x_ref, y_ref, o_ref): + x = x_ref[:, :] + y = y_ref[:, :] + o_ref[:, :] = pl.dot(x, y, trans_x, trans_y).astype(o_ref.dtype) - @functools.partial( - self.pallas_call, - out_shape=(jax.ShapeDtypeStruct((n,), floatx)), + # random.randint does not support int4, so create as int8. + x = random.randint( + key=random.key(0), + shape=lhs_shape, + minval=jnp.iinfo(dtype).min, + maxval=jnp.iinfo(dtype).max, + dtype=jnp.int8, + ) + y = random.randint( + key=random.key(1), + shape=rhs_shape, + minval=jnp.iinfo(dtype).min, + maxval=jnp.iinfo(dtype).max, + dtype=jnp.int8, + ) + if dtype == jnp.int4: + out = dot(x.astype(jnp.int4), y.astype(jnp.int4)) + else: + out = dot(x, y) + # TODO: b/438321086 - investigate and fix jnp.dot with int4. + # For now, use int8 instead. + expected = jnp.dot( + x.T if trans_x else x, + y.T if trans_y else y, + preferred_element_type=jnp.int32, + ).astype(jnp.int32) + np.testing.assert_allclose( + out.astype(jnp.int32), + expected.astype(jnp.int32), + # s4->f8 introduces more error than s8->bf16. + atol=3 if dtype == jnp.int4 else 0, + rtol=0.05 if dtype == jnp.int4 else 0, ) - def masked_oob_load_store_slice(x_ref, mask_ref, start_idx_ref, o_ref): - x = pl.load(x_ref, (pl.dslice(start_idx_ref[()], n)), - mask=mask_ref[:], other=-1.) - pl.store(o_ref, (pl.dslice(None),), x) - - x = random.normal(random.key(0), (n,)) - slice_start = random.randint(random.key(2), (), 1, n) - indices = jnp.arange(n) + slice_start - mask = indices < n - out = masked_oob_load_store_slice(x, mask, slice_start) - o_new = jnp.where(mask, x[indices], jnp.full_like(x, -1.)) - np.testing.assert_array_equal(out, o_new) def test_strided_load(self): self.skip_if_mosaic_gpu() @@ -1973,46 +2321,13 @@ def test_broadcasted_load_store(self): out_shape=(jax.ShapeDtypeStruct((m, n), floatx)), ) def load(x_ref, o_ref): - x = pl.load(x_ref, (jnp.arange(m)[:, None], jnp.arange(n)[None, :])) - pl.store(o_ref, (jnp.arange(m)[:, None], jnp.arange(n)[None, :]), x + 1.0) + idx = (jnp.arange(m)[:, None], jnp.arange(n)[None, :]) + o_ref[idx] = x_ref[idx] + 1.0 key = random.key(0) x = random.normal(key, (m, n)) np.testing.assert_allclose(load(x), x + 1.0, atol=1e-5, rtol=1e-5) - @parameterized.parameters( - ((16, 32), (16,)), - ((16, 32), (32,)), - ((16, 32), (16, 16)), - ) - def test_invalid_broadcasted_load(self, x_shape, mask_shape): - self.skip_if_mosaic_gpu() - - # The Pallas TPU lowering currently supports only blocks of rank >= 1 - if jtu.test_device_matches(["tpu"]): - self.skipTest("Not supported on TPU") - - if self.INTERPRET: - self.skipTest("No broadcasting checks in pl.load in interpret mode") - - @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32) - ) - def kernel(x_ref, mask_ref, o_ref): - del o_ref # Unused. - pl.load(x_ref, slice(None), mask=mask_ref[:]) - - x = jnp.ones(x_shape, dtype=jnp.float32) - mask = jnp.ones(mask_shape, dtype=jnp.bool_) - # assertRaises* methods do not support inspecting the __cause__, so - # we have to check it manually. - try: - kernel(x, mask) - except Exception as e: - self.assertIn("Cannot broadcast", str(e.__cause__)) - else: - self.fail("Expected exception due to invalid broadcasting") - def test_swap(self): self.skip_if_mosaic_gpu() @@ -2029,7 +2344,7 @@ def test_swap(self): ) def swap(_, _2, x_ref, y_ref): x = x_ref[:] - y = pl.swap(y_ref, (slice(None),), x) + y = pallas_primitives.swap(y_ref, (slice(None),), x) x_ref[:] = y x = random.normal(random.key(0), (m, n)) @@ -2053,7 +2368,7 @@ def test_masked_swap(self): ) def masked_swap(_, _2, mask_ref, x_ref, y_ref): x = x_ref[:] - y = pl.swap(y_ref, (slice(None),), x, mask=mask_ref[:]) + y = pallas_primitives.swap(y_ref, (slice(None),), x, mask=mask_ref[:]) x_ref[:] = y x = random.normal(random.key(0), (m, n)) @@ -2075,12 +2390,14 @@ def test_masked_oob_swap_slice(self): @functools.partial( self.pallas_call, out_shape=(jax.ShapeDtypeStruct((n,), floatx), - jax.ShapeDtypeStruct((m,), floatx)), + jax.ShapeDtypeStruct((m,), floatx)), input_output_aliases={0: 0, 1: 1}, ) def masked_oob_swap_slice(_, _2, mask_ref, start_idx_ref, x_ref, y_ref): x, mask = x_ref[:], mask_ref[:] - y = pl.swap(y_ref, (pl.dslice(start_idx_ref[()], n)), x, mask=mask) + y = pallas_primitives.swap( + y_ref, (pl.dslice(start_idx_ref[()], n)), x, mask=mask + ) x_ref[:] = y x = random.normal(random.key(0), (n,)) @@ -2097,150 +2414,7 @@ def masked_oob_swap_slice(_, _2, mask_ref, start_idx_ref, x_ref, y_ref): np.testing.assert_array_equal(out[0], x_new) np.testing.assert_array_equal(out[1], y_new) - @parameterized.named_parameters( - ("add_i32", pl.atomic_add, np.array([1, 2, 3, 4], np.int32), np.sum), - ("max_i", pl.atomic_max, np.array([1, 2, 3, 4], np.int32), np.max), - ("min_i32", pl.atomic_min, np.array([1, 2, 3, 4], np.int32), np.min), - ("add_f16", pl.atomic_add, np.array([1, 2, 3, 4], np.float16), np.sum), - ("add_f32", pl.atomic_add, np.array([1, 2, 3, 4], np.float32), np.sum), - ("max_f32", pl.atomic_max, np.array([1, 2, 3, 4], np.float32), np.max), - ("min_f32", pl.atomic_min, np.array([1, 2, 3, 4], np.float32), np.min), - ) - def test_scalar_atomic(self, op, value, numpy_op): - self.skip_if_mosaic_gpu() - - # The Pallas TPU lowering currently supports only blocks of rank >= 1 - if jtu.test_device_matches(["tpu"]): - self.skipTest("Not supported on TPU") - - @functools.partial( - self.pallas_call, - out_shape=jax.ShapeDtypeStruct((), value.dtype), - grid=value.shape[0], - input_output_aliases={1: 0}, - ) - def atomic_kernel(x_ref, _, o_ref): - pid = pl.program_id(axis=0) - op(o_ref, (), x_ref[pid]) - - if op == pl.atomic_add: - neutral = np.array(0, dtype=value.dtype) - elif op == pl.atomic_max: - if np.issubdtype(value.dtype, np.integer): - neutral = np.array(np.iinfo(value.dtype).min, value.dtype) - else: - neutral = np.array(-float("inf"), value.dtype) - elif op == pl.atomic_min: - if np.issubdtype(value.dtype, np.integer): - neutral = np.array(np.iinfo(value.dtype).max, value.dtype) - else: - neutral = np.array(float("inf"), value.dtype) - elif op == pl.atomic_or: - neutral = np.array(False, value.dtype) - else: - raise NotImplementedError() - out = atomic_kernel(value, neutral) - np.testing.assert_allclose(out, numpy_op(value)) - - @parameterized.parameters((0,), (1,)) - def test_array_atomic_add(self, axis): - self.skip_if_mosaic_gpu() - - if jtu.test_device_matches(["tpu"]): - self.skipTest("Unimplemented primitive: broadcast_to") - - m, n = 32, 8 - if axis == 0: - grid = m - else: - grid = n - out_shape = jax.ShapeDtypeStruct((n if axis == 0 else m,), floatx) - - @functools.partial( - self.pallas_call, - out_shape=out_shape, - grid=grid, - input_output_aliases={1: 0}, - ) - def reduce(x_ref, _, y_ref): - i = pl.program_id(axis=0) - if axis == 0: - idx = (i, jnp.arange(n)) - else: - idx = (jnp.arange(m), i) - x = pl.load(x_ref, idx) - pl.atomic_add(y_ref, (jnp.arange(y.shape[0]),), x) - - x = random.normal(random.key(0), (m, n)) - y = jnp.zeros(out_shape.shape, out_shape.dtype) - y = reduce(x, y) - y_ref = np.sum(x, axis=axis) - np.testing.assert_allclose(y, y_ref, atol=1e-2, rtol=1e-2) - - @parameterized.parameters( - (0, 0, 1), - (0, 1, 1), - (1, 0, 1), - (1, 1, 1), - (2, 1, 1), - (2, 1, 1), - ) - def test_atomic_cas(self, init_value, cmp, new_value): - self.skip_if_mosaic_gpu() - - # The Pallas TPU lowering currently supports only blocks of rank >= 1 - if jtu.test_device_matches(["tpu"]): - self.skipTest("Not supported on TPU") - - if jax.config.x64_enabled and jtu.test_device_matches(["gpu"]): - self.skipTest("Not supported on GPU in 64-bit mode") - - @functools.partial( - self.pallas_call, out_shape=( - jax.ShapeDtypeStruct((), intx), - jax.ShapeDtypeStruct((), intx)), - input_output_aliases={0: 0}) - def swap(_, lock_ref, out_ref): - out_ref[()] = pl.atomic_cas(lock_ref, cmp, new_value) - - lock, out = swap(init_value) - np.testing.assert_allclose(lock, new_value if cmp == init_value else - init_value) - np.testing.assert_allclose(out, init_value) - - @parameterized.parameters(1, 2, 3, 4, 8) - def test_atomic_counter(self, num_threads): - self.skip_if_mosaic_gpu() - - # The Pallas TPU lowering currently supports only blocks of rank >= 1 - if jtu.test_device_matches(["tpu"]): - self.skipTest("Not supported on TPU") - - if self.INTERPRET: - self.skipTest("While loop not supported in interpret mode.") - - if jax.config.x64_enabled and jtu.test_device_matches(["gpu"]): - self.skipTest("Not supported on GPU in 64-bit mode") - - @functools.partial( - self.pallas_call, out_shape=( - jax.ShapeDtypeStruct((), intx), - jax.ShapeDtypeStruct((), intx)), - input_output_aliases={0: 0, 1: 1}, - grid=(num_threads,)) - def increment(_, __, lock_ref, counter_ref): - def _cond(_): - return pl.atomic_cas(lock_ref, 0, 1) == 1 - lax.while_loop(_cond, lambda a: a, 0) - counter_ref[...] += 1 - pl.atomic_xchg(lock_ref, (), 0) - - lock, count = increment(0, 0) - np.testing.assert_allclose(lock, 0) - np.testing.assert_allclose(count, num_threads) - - @parameterized.parameters(False, True) - def test_reduce_only_dim(self, use_store): + def test_reduce_only_dim(self): self.skip_if_mosaic_gpu() # The Pallas TPU lowering currently supports only blocks of rank >= 1 @@ -2253,12 +2427,7 @@ def test_reduce_only_dim(self, use_store): @functools.partial(self.pallas_call, out_shape=out_shape) def reduce(x_ref, y_ref): - x = pl.load(x_ref, (jnp.arange(m),)) - y = jnp.sum(x, axis=-1) - if use_store: - pl.store(y_ref, (), y) - else: - y_ref[...] = y + y_ref[...] = jnp.sum(x_ref[jnp.arange(m)], axis=-1) y = reduce(x) y_ref = jnp.sum(x, axis=-1) @@ -2328,11 +2497,12 @@ def make_x(key): @functools.partial(self.pallas_call, out_shape=out_shape, grid=grid) def reduce(x_ref, y_ref): - x = pl.load(x_ref, (jnp.arange(m, dtype=jnp.int32)[:, None], - jnp.arange(n, dtype=jnp.int32)[None])) + x = x_ref[ + jnp.arange(m, dtype=jnp.int32)[:, None], + jnp.arange(n, dtype=jnp.int32)[None], + ] y = op(x, axis=axis) - pl.store(y_ref, - tuple(jnp.arange(d, dtype=jnp.int32) for d in y.shape), y) + y_ref[tuple(jnp.arange(d, dtype=jnp.int32) for d in y.shape)] = y for i, key in enumerate(random.split(random.key(0), 20)): x = make_x(key) @@ -2467,9 +2637,6 @@ def test_arbitrary_padding_jnp_pad( ): if jtu.test_device_matches(["gpu"]): self.skipTest("Not implemented on GPU") - # TODO(apaszke): Remove after 12 weeks have passed. - if not jtu.if_cloud_tpu_at_least(2024, 12, 19): - self.skipTest("Requires libtpu built after 2024-12-19") x = jnp.arange(np.prod(array_shapes), dtype=dtype).reshape(array_shapes) @@ -2517,6 +2684,162 @@ def kernel(x_ref, out_ref): )(x) np.testing.assert_array_equal(out, np.diagonal(x)) + @parameterized.product( + # Skip some steps to just run less cases + # TODO(mvoz): Hypothesis? + x_dim_size=tuple(8 * i for i in range(1, 5)), + y_dim_size=tuple(8 * i for i in range(1, 5)), + z_dim_size=tuple(128 * i for i in range(1, 3)), + dtype=(jnp.float32,), + ) + def test_jnp_swapaxes_major_minor( + self, x_dim_size, y_dim_size, z_dim_size, dtype + ): + if jtu.test_device_matches(["gpu"]): + if any( + not is_power_of_two(x) for x in [x_dim_size, y_dim_size, z_dim_size] + ): + self.skipTest( + "the Pallas Triton lowering currently requires that all operations" + " have array arguments and results whose size is a power of 2." + f" Encountered an array of shape ({x_dim_size}, {y_dim_size}," + f" {z_dim_size})" + ) + if x_dim_size * y_dim_size * z_dim_size * 4 > 32768: + self.skipTest( + "Mosaic GPU kernel exceeds available shared memory" + f" smem_bytes={x_dim_size * y_dim_size * z_dim_size * 4} > 32768" + ) + self.skip_if_mosaic_gpu() + + x = jnp.arange(x_dim_size * y_dim_size * z_dim_size, dtype=dtype).reshape( + (x_dim_size, y_dim_size, z_dim_size) + ) + + def kernel(x_ref, out_ref): + out_ref[...] = jnp.swapaxes(x_ref[...], 0, 1) + + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct( + (y_dim_size, x_dim_size, z_dim_size), dtype + ), + )(x) + expected = jnp.swapaxes(x, 0, 1) + np.testing.assert_array_equal(out, expected) + + @parameterized.product( + shape_and_axes=[ + ((8, 16, 128), (0, 2, 1)), + ((16, 128, 128), (1, 0, 2)), + ((8, 16, 128), (1, 2, 0)), + ((8, 16, 128), (2, 0, 1)), + ((8, 16, 128), (2, 1, 0)), + ((1, 2, 16, 128), (0, 1, 2, 3)), + ((1, 2, 16, 128), (1, 0, 3, 2)), + ((8, 16, 8, 128), (0, 2, 1, 3)), + ((2, 8, 8, 3), (0, 2, 1, 3)), + ((2, 8, 16, 8, 128), (0, 1, 3, 2, 4)), + ((1, 2, 8, 8, 1), (0, 1, 3, 2, 4)), + ((3, 2, 8, 8, 8), (1, 0, 3, 2, 4)), + ] + + [((1, 2, 9, 3, 4), perm) for perm in itertools.permutations(range(5))] + ) + def test_transpose(self, shape_and_axes): + if jtu.test_device_matches(["gpu"]): + self.skipTest("Not implemented on GPU") + in_shape, transpose_axes = shape_and_axes + x = jnp.arange(math.prod(in_shape), dtype=jnp.float32).reshape(in_shape) + expected = jnp.transpose(x, axes=transpose_axes) + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct(expected.shape, jnp.float32), + ) + def kernel(x_ref, out_ref): + out_ref[...] = jnp.transpose(x_ref[...], axes=transpose_axes) + + np.testing.assert_array_equal( + kernel(x), + expected, + ) + + @hp.given(batch_size=hps.integers(1, 16)) + def test_8bit_gather(self, batch_size): + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Not supported on this hardware") + if jtu.get_tpu_version() < 6: + self.skipTest("Requires TPUv6 or newer") + + dtype = jnp.int8 + xspec = pl.BlockSpec((32, 128), lambda i: (i, 0)) + lspec = pl.BlockSpec((32, 128), lambda i: (0, 0)) + + data = jax.random.randint( + key=jax.random.key(1234), + shape=(32 * batch_size, 128), + minval=0, + maxval=32, + dtype=jnp.int8, + ) + lut = jax.random.randint( + key=jax.random.key(1234), + shape=(32, 128), + minval=-128, + maxval=127, + dtype=jnp.int8, + ) + + def kernel(data_ref, lut_ref, output_ref): + data_chunk = data_ref[...] + lut_chunk = lut_ref[...] + data_chunk = data_chunk.reshape((32, 128, 1)) + output = lax.gather( + lut_chunk, + data_chunk, + dimension_numbers=lax.GatherDimensionNumbers( + offset_dims=(), + start_index_map=(0,), + operand_batching_dims=(1,), + start_indices_batching_dims=(1,), + collapsed_slice_dims=(0,) + ), + slice_sizes=(1, 1), + mode="promise_in_bounds", + ) + output_ref[...] = output + + deq_call = pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct(data.shape, dtype), + grid=(batch_size,), + out_specs=xspec, + in_specs=[xspec, lspec], + ) + result = deq_call(data, lut) + expected = jnp.take_along_axis( + lut, data, axis=0, mode="promise_in_bounds" + ) + np.testing.assert_array_equal(result, expected) + + def test_delay(self): + if jtu.test_device_matches(["gpu"]): + # ROCm always uses Triton (Mosaic GPU doesn't support ROCm). + if jtu.is_device_rocm() or not use_mosaic_gpu: + self.skipTest("Delay is only implemented on the MGPU backend for GPUs.") + if self.INTERPRET: + self.skipTest("Not implemented in interpret mode.") + # This is mostly to test that the kernel compiles. It's difficult to + # test the exact timing of the delay. + def kernel(x_ref, o_ref): + pl.delay(100_000) + o_ref[...] = x_ref[...] + x = jax.random.normal(jax.random.key(0), (128,), dtype=jnp.float32) + result = pl.pallas_call( + kernel, out_shape=x, + interpret=self.INTERPRET)(x) + np.testing.assert_array_equal(result, x) + class OpsInterpretTest(OpsTest): INTERPRET = True @@ -2548,7 +2871,7 @@ class PallasPrimitivesTest(PallasBaseTest): ]) def test_load_pretty_print(self, expr, expected): def body(x_ref): - x = pl.load(x_ref, expr()) + x = pallas_primitives.load(x_ref, expr()) return [x] jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic( wrap_init(body, 1), [state.shaped_array_ref((4, 3, 2), jnp.int32)]) @@ -2563,7 +2886,9 @@ def body(x_ref): ]) def test_store_pretty_print(self, expr, expected): def body(x_ref): - pl.store(x_ref, expr(), pl.load(x_ref, expr())) + pallas_primitives.store( + x_ref, expr(), pallas_primitives.load(x_ref, expr()) + ) return [] jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic( wrap_init(body, 1), [state.shaped_array_ref((4, 3, 2), jnp.int32)]) @@ -2571,19 +2896,21 @@ def body(x_ref): @parameterized.parameters(*[ (lambda: (pl.dslice(0, 4), slice(None), slice(None)), - "c:i32[4,3,2], a[:,:,:] <-"), + "c:i32[4,3,2], a[:,:,:] <-"), (lambda: (pl.dslice(0, 3), slice(None), slice(None)), - "c:i32[3,3,2], a[:3,:,:] <-"), + "c:i32[3,3,2], a[:3,:,:] <-"), (lambda: (pl.dslice(1, 3), slice(None), pl.dslice(0, 4)), - "c:i32[3,3,4], a[1:,:,:4] <-"), + "c:i32[3,3,4], a[1:,:,:4] <-"), (lambda: (jnp.arange(5), slice(None), pl.dslice(0, 4)), - "e:i32[5,3,4], a[b,:,:4] <-"), + "e:i32[5,3,4], a[b,:,:4] <-"), (lambda: (jnp.arange(5)[:, None], jnp.arange(3)[None], pl.dslice(4)), - "o:i32[5,3,4], a[m,n,:4] <-"), + "o:i32[5,3,4], a[m,n,:4] <-"), ]) def test_swap_pretty_print(self, expr, expected): def body(x_ref): - x = pl.swap(x_ref, expr(), pl.load(x_ref, expr())) + x = pallas_primitives.swap( + x_ref, expr(), pallas_primitives.load(x_ref, expr()) + ) return [x] jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic( wrap_init(body, 1), [state.shaped_array_ref((4, 3, 2), jnp.int32)]) @@ -2593,8 +2920,6 @@ def body(x_ref): def test_reciprocal(self, approx): if not jtu.test_device_matches(["tpu"]): self.skipTest("Not implemented on non-TPU devices") - if not jtu.if_cloud_tpu_at_least(2025, 3, 8): - self.skipTest("Test requires libtpu from 2025/3/8 or later") shape = (32, 256) x = jnp.arange(np.prod(shape), dtype=jnp.float32).reshape(shape) diff --git a/tests/pallas/pallas_cost_estimate_test.py b/tests/pallas/pallas_cost_estimate_test.py index d9eb18e6f540..6b0b39ed15dc 100644 --- a/tests/pallas/pallas_cost_estimate_test.py +++ b/tests/pallas/pallas_cost_estimate_test.py @@ -61,6 +61,20 @@ def matmul(a, b): self.assertEqual(cost.transcendentals, 0) self.assertEqual(cost.bytes_accessed, 4*(b*m*k + b*n*k + b*m*n)) + @parameterized.parameters( + ((10, 11, 12), (11, 12), "abc,bc->a", 2640), + ((10, 11, 12), (13, 11, 12), "abc,dbc->ad", 34320), + ((10, 11, 12), (9, 10, 11, 12), "abc,dabc->d", 23760), + ) + def test_einsum(self, a_shape, b_shape, pattern, expected_flops): + def matmul(a, b): + return jnp.einsum(pattern, a, b) + cost = cost_estimate.estimate_cost( + matmul, + jax.ShapeDtypeStruct(a_shape, jnp.float32), + jax.ShapeDtypeStruct(b_shape, jnp.float32)) + self.assertEqual(cost.flops, expected_flops) + def test_attention(self): qk_dim = 16 v_dim = 4 diff --git a/tests/pallas/pallas_error_handling_test.py b/tests/pallas/pallas_error_handling_test.py index cd5ceecfc9a8..e64c5f70b446 100644 --- a/tests/pallas/pallas_error_handling_test.py +++ b/tests/pallas/pallas_error_handling_test.py @@ -16,6 +16,7 @@ import traceback from absl.testing import absltest +from absl.testing import parameterized import jax from jax import numpy as jnp from jax._src import config @@ -23,6 +24,7 @@ from jax._src.pallas.mosaic import error_handling from jax.experimental import pallas as pl from jax.experimental.pallas import tpu as pltpu +import numpy as np config.parse_flags_with_absl() @@ -50,9 +52,9 @@ def test_non_singular_stride(self): grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.VMEM), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), ) @functools.partial(pl.pallas_call, out_shape=out_shape, grid_spec=grid_spec) @@ -92,20 +94,21 @@ def kernel_in_jitted_fn(x): tb_string = "".join(tb_string) self.assertEndsWith(tb_string, "x = input_ref[:, ::8]\n") - def test_invalid_smem_vmem_verification_error(self): + def test_index_with_f32_verification_error(self): input_arr = jax.random.uniform(jax.random.key(0), (2, 2), dtype=jnp.float32) out_shape = jax.ShapeDtypeStruct((1, 1), jnp.float32) grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.VMEM), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), ) @functools.partial(pl.pallas_call, out_shape=out_shape, grid_spec=grid_spec) def test_kernel(input_ref, output_ref): - output_ref[0, 0] = input_ref[0, 0] + idx = input_ref[0, 0] + output_ref[idx, 0] = input_ref[0, 0] # Test that a verification error is raised. This assert is a guard against # underlying changes in Pallas lowering. @@ -113,8 +116,8 @@ def test_kernel(input_ref, output_ref): # the test example to force a different error. with self.assertRaisesRegex( error_handling.VerificationError, - "'memref.store' op failed to verify that type of 'value' matches " - "element type of 'memref'", + "must be signless-integer-like or memref of signless-integer, " + "but got 'f32'" ): test_kernel(input_arr) @@ -125,7 +128,37 @@ def test_kernel(input_ref, output_ref): except error_handling.MosaicError as e: tb_string = traceback.format_tb(e.__traceback__) tb_string = "".join(tb_string) - self.assertEndsWith(tb_string, "output_ref[0, 0] = input_ref[0, 0]\n") + self.assertEndsWith(tb_string, "output_ref[idx, 0] = input_ref[0, 0]\n") + + @parameterized.parameters( + ((2048,), (256,)), + ((2048,), (512,)), + ) + def test_small_1d_block_spec_raises(self, total_shape, block_shape): + # https://github.com/jax-ml/jax/issues/25379 + dtype = jnp.float32 + + def kernel(x_ref, y_ref): + y_ref[...] = x_ref[...] * 2 + + x = jnp.arange(np.prod(total_shape), dtype=dtype).reshape(total_shape) + x_spec = pl.BlockSpec(block_shape, lambda *args: args) + fn = pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct(total_shape, dtype), + in_specs=[x_spec], + out_specs=x_spec, + grid=tuple(tot // blk for tot, blk in zip(total_shape, block_shape, + strict=True)), + ) + # Having a block size that is too small should raise a suggestion + # to increase the block size. + with self.assertRaisesRegex( + jax.errors.JaxRuntimeError, + r"Try changing your kernel block shape to \([0-9,\s]+\) to align with" + " the XLA layout", + ): + fn(x) def test_parse_location_string(self): name, frames = error_handling.parse_location_string(LOCATION_TEST_STRING) diff --git a/tests/pallas/pallas_jumble_test.py b/tests/pallas/pallas_jumble_test.py deleted file mode 100644 index 509ef08a987f..000000000000 --- a/tests/pallas/pallas_jumble_test.py +++ /dev/null @@ -1,373 +0,0 @@ -# Copyright 2023 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import sys - -os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5" - -from absl.testing import absltest -import jax -from jax import lax -from jax._src import config -from jax._src import core -from jax._src import dtypes -from jax._src import test_util as jtu -from jax._src.interpreters import batching -from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr -from jax.experimental import pallas as pl -import jax.numpy as jnp -import numpy as np - - -# TODO(mvoz): Update signatures of pallas_call to correct inputs/outputs. -# pylint: disable=no-value-for-parameter - -config.parse_flags_with_absl() - - -intx = dtypes.canonicalize_dtype(jnp.int64) -floatx = dtypes.canonicalize_dtype(jnp.float64) - - -def _assert_ragged_equal_with_elementwise_mask( - row_count, col_grid_size, ragged_shape, res, ref -): - total_columns = col_grid_size * 128 - mask = jnp.zeros((len(ragged_shape), row_count, total_columns), dtype=bool) - - for i, r in enumerate(ragged_shape): - mask = mask.at[i, :, : r * 128].set(True) - - res_valid = jnp.where(mask, res, -1) - ref_valid = jnp.where(mask, ref, -1) - - np.testing.assert_allclose(res_valid, ref_valid) - - -@jtu.with_config(jax_traceback_filtering="off") -class PallasBaseTest(jtu.JaxTestCase): - INTERPRET = False - - def setUp(self): - if jtu.test_device_matches(["cpu"]) and not self.INTERPRET: - self.skipTest("On CPU the test works only in interpret mode") - if jtu.test_device_matches( - ["cuda"] - ) and not jtu.is_cuda_compute_capability_at_least("8.0"): - self.skipTest("Only works on GPU with capability >= sm80") - if sys.platform == "win32" and not self.INTERPRET: - self.skipTest("Only works on non-Windows platforms") - - super().setUp() - _trace_kernel_to_jaxpr.cache_clear() - - def pallas_call(self, *args, **kwargs): - return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET) - - -@jtu.with_config(jax_dynamic_shapes=True, jax_numpy_dtype_promotion="standard") -class PallasCallRaggedVmapTest(PallasBaseTest): - - def test_vmap_jumble_over_sin_kernel(self): - if not jtu.test_device_matches(["tpu"]): - self.skipTest("Only tested on TPU") - - row_count = 8 - col_grid_size = 5 - ragged_shape = [3, 1, 4] - sizes = lax.convert_element_type( - jnp.array([128 * x for x in ragged_shape]), - core.bint(col_grid_size * 128), - ) - x = jax.vmap( - lambda n: jnp.ones((row_count, n)), out_axes=batching.jumble_axis - )(sizes) - - def kernel(x_ref, o_ref): - o_ref[...] = jnp.sin(x_ref[...]) - - def invoke_kernel(x): - return pl.pallas_call( - kernel, - in_specs=[pl.BlockSpec((8, 128), lambda j, k: (j, k))], - out_specs=pl.BlockSpec((8, 128), lambda j, k: (j, k)), - out_shape=jax.ShapeDtypeStruct( - (8, col_grid_size * 128), dtype=jnp.float32 - ), - grid=(1, col_grid_size), - interpret=self.INTERPRET, - # See note - on zero filling counterfactuals - debug=True, - )(x) - - res = jax.vmap( - invoke_kernel, - out_axes=batching.jumble_axis, - in_axes=batching.jumble_axis, - axis_size=3, - )(x) - - ref = jax.vmap( - jnp.sin, - out_axes=batching.jumble_axis, - in_axes=batching.jumble_axis, - axis_size=3, - )(x) - - _assert_ragged_equal_with_elementwise_mask( - row_count, col_grid_size, ragged_shape, res.data, ref.data - ) - - def test_vmap_jumble_over_add_kernel(self): - if not jtu.test_device_matches(["tpu"]): - self.skipTest("Only tested on TPU") - - row_count = 8 - col_grid_size = 5 - ragged_shape = [3, 1, 4] - sizes = lax.convert_element_type( - jnp.array([128 * x for x in ragged_shape]), - core.bint(col_grid_size * 128), - ) - x = jax.vmap( - lambda n: jnp.ones((row_count, n)), out_axes=batching.jumble_axis - )(sizes) - y = jax.vmap( - lambda n: jnp.ones((row_count, n)), out_axes=batching.jumble_axis - )(sizes) - - def kernel(x_ref, y_ref, o_ref): - o_ref[...] = x_ref[...] + y_ref[...] - - def invoke_kernel(x, y): - return pl.pallas_call( - kernel, - in_specs=[ - pl.BlockSpec((8, 128), lambda j, k: (j, k)), - pl.BlockSpec((8, 128), lambda j, k: (j, k)), - ], - out_specs=pl.BlockSpec((8, 128), lambda j, k: (j, k)), - out_shape=jax.ShapeDtypeStruct( - (8, col_grid_size * 128), dtype=jnp.float32 - ), - grid=(1, col_grid_size), - interpret=self.INTERPRET, - )(x, y) - - # We've had this test fail with data corruption due to multiple - # invocations, so we run it k times to make sure it's not setting up - # memory incorrectly for subsequent invocations. - for _ in range(4): - res = jax.vmap( - invoke_kernel, - out_axes=batching.jumble_axis, - in_axes=batching.jumble_axis, - axis_size=3, - )(x, y) - - res = res.data - total = len(ragged_shape) * row_count * col_grid_size * 128 - res_total = np.prod(res.shape) - self.assertEqual(res_total, total) - - ref = jax.vmap( - lambda x, y: x + y, - out_axes=batching.jumble_axis, - in_axes=batching.jumble_axis, - axis_size=3, - )(x, y) - _assert_ragged_equal_with_elementwise_mask( - row_count, col_grid_size, ragged_shape, res, ref.data - ) - - def test_vmap_jumble_over_sin_kernel_grid_remapping(self): - if not jtu.test_device_matches(["tpu"]): - self.skipTest("Only tested on TPU") - - row_count = 8 - col_grid_size = 5 - ragged_shape = [3, 1, 4] - sizes = lax.convert_element_type( - jnp.array([128 * x for x in ragged_shape]), - core.bint(col_grid_size * 128), - ) - x = jax.vmap( - lambda n: jnp.ones((row_count, n)), out_axes=batching.jumble_axis - )(sizes) - - def kernel(x_ref, o_ref): - o_ref[...] = jnp.sin(x_ref[...]) * pl.program_id(2) - - def invoke_kernel(x): - return pl.pallas_call( - kernel, - in_specs=[pl.BlockSpec((8, 128), lambda j, k: (j, k))], - out_specs=pl.BlockSpec((8, 128), lambda j, k: (j, k)), - out_shape=jax.ShapeDtypeStruct((8, 640), dtype=jnp.float32), - grid=(1, 5), - interpret=self.INTERPRET, - )(x) - - with self.assertRaisesRegex(ValueError, "Axis 2 is out of bounds for grid"): - jax.vmap( - invoke_kernel, - out_axes=batching.jumble_axis, - in_axes=batching.jumble_axis, - axis_size=3, - )(x) - - def test_vmap_jumble_over_matmul_kernel(self): - if not jtu.test_device_matches(["tpu"]): - self.skipTest("Only tested on TPU") - - if jtu.is_device_tpu(version=4): - self.skipTest("Flaky 15% of the time on tpuv4?") - - m = 128 - k = 640 - n = 640 - - def matmul_kernel(x_ref, y_ref, x_sentinel, z_ref): - # weird little once-only reset - @pl.when(x_sentinel[...][0][0] == 1.0) - def _(): - z_ref[...] = jnp.zeros_like(z_ref) - x_sentinel[...] = jnp.zeros_like(x_sentinel) - - z_ref[...] += x_ref[...] @ y_ref[...] - - def matmul( - x: jax.Array, - y: jax.Array, - x_sentinel: jax.Array, - *, - bm: int = 128, - bk: int = 128, - bn: int = 640, - ): - # m, k = x.shape - # _, n = y.shape - # a (1, 5) grid - # TODO(mvoz): parameterize this grid? - grid = (n // bn, k // bk) - return pl.pallas_call( - matmul_kernel, - out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32), - in_specs=[ - pl.BlockSpec( - (bm, bk), - lambda j, k: (0, k), - ), - pl.BlockSpec( - (bk, bn), - lambda j, k: (k, j), - ), - pl.BlockSpec( - (bm, bn), - lambda j, k: (0, j), - ), - ], - out_specs=pl.BlockSpec( - (bm, bn), - lambda j, k: (0, j), - ), - grid=grid, - input_output_aliases={2: 0}, - interpret=self.INTERPRET, - )(x, y, x_sentinel) - - # TODO(mvoz): parameterize this shape? - ragged_shape = [3, 1, 4] - sizes = lax.convert_element_type( - jnp.array([128 * x for x in ragged_shape]), - core.bint(k), - ) - x = jax.vmap(lambda k_: jnp.ones((m, k_)), out_axes=batching.jumble_axis)( - sizes - ) - x_sentinel = jax.vmap( - lambda k_: jnp.ones((m, k_)), out_axes=batching.jumble_axis - )(sizes) - y = jax.vmap(lambda k_: jnp.ones((k_, n)), out_axes=batching.jumble_axis)( - sizes - ) - - res = jax.vmap( - matmul, - out_axes=batching.jumble_axis, - in_axes=batching.jumble_axis, - axis_size=3, - )(x, y, x_sentinel) - - ref = jax.vmap( - jnp.dot, - out_axes=batching.jumble_axis, - in_axes=batching.jumble_axis, - axis_size=3, - )(x, y) - - ref = ref.data - res = res.data - np.testing.assert_allclose(ref, res) - - def test_vmap_jumble_ragged_boundary_unaligned_with_grid(self): - if not jtu.test_device_matches(["tpu"]): - self.skipTest("Only tested on TPU") - - self.skipTest("Checkify NYI") - - row_count = 8 - col_grid_size = 5 - ragged_shape = [3, 1, 4] - sizes = lax.convert_element_type( - jnp.array([(128 * x) - 1 for x in ragged_shape]), - core.bint(col_grid_size * 128), - ) - x = jax.vmap( - lambda n: jnp.ones((row_count, n)), out_axes=batching.jumble_axis - )(sizes) - - def kernel(x_ref, o_ref): - o_ref[...] = jnp.sin(x_ref[...]) - - def invoke_kernel(x): - return pl.pallas_call( - kernel, - in_specs=[pl.BlockSpec((8, 128), lambda j, k: (j, k))], - out_specs=pl.BlockSpec((8, 128), lambda j, k: (j, k)), - out_shape=jax.ShapeDtypeStruct((8, 640), dtype=jnp.float32), - grid=(1, 5), - interpret=False, - )(x) - - with self.assertRaisesRegex( - ValueError, - "Ragged input shape must be evenly divisble by the grid" # noqa: W605 - " size at the ragged dimension 2", - ): - jax.vmap( - invoke_kernel, - out_axes=batching.jumble_axis, - in_axes=batching.jumble_axis, - axis_size=3, - )(x) - - -class PallasCallNamedGridInterpretTest(PallasCallRaggedVmapTest): - INTERPRET = True - - -if __name__ == "__main__": - absltest.main() diff --git a/tests/pallas/pallas_shape_poly_test.py b/tests/pallas/pallas_shape_poly_test.py index e9384afec37c..04d7b11b9379 100644 --- a/tests/pallas/pallas_shape_poly_test.py +++ b/tests/pallas/pallas_shape_poly_test.py @@ -27,7 +27,7 @@ import jax from jax._src import config from jax._src import test_util as jtu -from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr +from jax._src.pallas import pallas_call import jax.numpy as jnp from jax.experimental import pallas as pl from jax import export @@ -93,13 +93,15 @@ def setUp(self): if sys.platform == "win32": self.skipTest("Only works on non-Windows platforms") super().setUp() - _trace_kernel_to_jaxpr.cache_clear() + # TODO(bchetioui): Remove this for H100+ once tests are all compatible with + # Pallas/Mosaic GPU. + self.enter_context(pallas_call._PALLAS_USE_MOSAIC_GPU(False)) def test_copy(self): # The blocks are static, but the input and the grid are of polymorphic # dimensions. block_shape = (8, 128) - def f(x, *, eager=False): # x: i32[w, h] + def f(x, *, eager=False, backend=None): # x: i32[w, h] def copy_kernel(x_ref, o_ref): o_ref[...] = x_ref[...] # Use both pl.cdiv and // for specifying the grid @@ -111,6 +113,7 @@ def copy_kernel(x_ref, o_ref): in_specs=[pl.BlockSpec(block_shape, lambda i, j: (i, j))], out_specs=pl.BlockSpec(block_shape, lambda i, j: (i, j)), grid=grid, + backend=backend, interpret=eager and jtu.test_device_matches(["cpu"]))(x) shape1 = (128, 256) @@ -137,7 +140,7 @@ def copy_kernel(x_ref, o_ref): NotImplementedError, "dynamic grid bounds not supported in the Triton backend"): export.export( - jax.jit(f), + jax.jit(functools.partial(f, backend="triton")), platforms=["cuda"])(jax.ShapeDtypeStruct((w, h), jnp.int32)) def test_block_sizes_must_be_static_no_grid(self): diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 745c30ba98cb..d05a0508db19 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -1,4 +1,3 @@ -import contextlib # Copyright 2023 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,7 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations +import contextlib +import dataclasses import functools import itertools import math @@ -20,32 +22,34 @@ import re import sys -os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5" - from absl.testing import absltest from absl.testing import parameterized import jax -import jax.export from jax import lax from jax import random from jax._src import checkify from jax._src import config from jax._src import core as jax_core from jax._src import dtypes +from jax._src import hijax from jax._src import test_util as jtu -from jax._src.lax.control_flow.for_loop import for_loop from jax._src.pallas import pallas_call -from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr +from jax._src.pallas import pallas_test_util as ptu from jax.experimental import pallas as pl +import jax.export import jax.numpy as jnp import numpy as np +os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5" + if sys.platform != "win32": from jax.experimental.pallas import tpu as pltpu - from jax.experimental.pallas import triton as plgpu + from jax.experimental.pallas import triton as pltriton + from jax.experimental.pallas import mosaic_gpu as plmgpu else: pltpu = None - plgpu = None + pltriton = None + plmgpu = None # TODO(sharadmv): Update signatures of pallas_call to correct inputs/outputs. @@ -61,52 +65,8 @@ def smem_on_tpu(): return None -intx = dtypes.canonicalize_dtype(jnp.int64) -floatx = dtypes.canonicalize_dtype(jnp.float64) - - -@functools.partial(jax.jit, static_argnames=["bm", "bn", "gm", "bk", - "interpret", "debug"]) -def matmul(x, y, *, bm, bn, gm, bk, interpret, debug=False): - m, n, k = x.shape[0], y.shape[1], x.shape[1] - @functools.partial( - pl.pallas_call, out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32), - interpret=interpret, - debug=debug, - grid=pl.cdiv(m, bm) * pl.cdiv(n, bn)) - def matmul_kernel(x_ref, y_ref, o_ref): - pid = pl.program_id(axis=0).astype(intx) - num_pid_m = m // bm - num_pid_n = n // bn - num_pid_in_group = gm * num_pid_n - group_id = lax.div(pid, num_pid_in_group) - first_pid_m = group_id * gm - group_size_m = jnp.minimum(num_pid_m - first_pid_m, gm) - pid_m = first_pid_m + lax.rem(pid, group_size_m) - pid_n = lax.div(lax.rem(pid, num_pid_in_group), group_size_m) - idx_m = pid_m * bm + jnp.arange(bm) - idx_n = pid_n * bn + jnp.arange(bn) - idx_m = pl.max_contiguous(pl.multiple_of(idx_m, bm), bm) - idx_n = pl.max_contiguous(pl.multiple_of(idx_n, bn), bn) - acc = jnp.zeros((bm, bn), dtype=jnp.float32) - def body(i, acc_ref): - idx_k = i * bk + jnp.arange(bk) - x_idx = ( - jax.lax.broadcast_in_dim(idx_m, (bm, bk), (0,)), - jax.lax.broadcast_in_dim(idx_k, (bm, bk), (1,))) - y_idx = ( - jax.lax.broadcast_in_dim(idx_k, (bk, bn), (0,)), - jax.lax.broadcast_in_dim(idx_n, (bk, bn), (1,))) - x_block, y_block = x_ref[x_idx], y_ref[y_idx] - out = pl.dot(x_block, y_block) - acc_ref[:, :] += out - acc = for_loop(k // bk, body, acc).astype(o_ref.dtype) - o_idx = ( - jax.lax.broadcast_in_dim(idx_m, (bm, bn), (0,)), - jax.lax.broadcast_in_dim(idx_n, (bm, bn), (1,)), - ) - o_ref[o_idx] = acc - return matmul_kernel(x, y) +intx = dtypes.default_int_dtype() +floatx = dtypes.default_float_dtype() @functools.partial(jax.jit, static_argnames=["bm", "bn", "bk", @@ -127,36 +87,62 @@ def matmul_block_spec(x, y, *, bm, bn, bk, interpret, debug=False): ) def matmul_kernel(x_ref, y_ref, o_ref): acc = jnp.zeros(o_ref.shape, dtype=jnp.float32) - def body(i, acc_ref): - x_block = pl.load(x_ref, (slice(None), pl.ds(i * bk, bk))) - y_block = pl.load(y_ref, (pl.ds(i * bk, bk), slice(None))) - acc_ref[:, :] += pl.dot(x_block, y_block) - acc = for_loop(k // bk, body, acc).astype(o_ref.dtype) + def body(i, acc): + x_block = x_ref[:, pl.ds(i * bk, bk)] + y_block = y_ref[pl.ds(i * bk, bk), :] + return acc + pl.dot(x_block, y_block) + acc = lax.fori_loop(0, k // bk, body, acc).astype(o_ref.dtype) o_ref[:, :] = acc return matmul_kernel(x, y) -@jtu.with_config(jax_traceback_filtering="off") -class PallasBaseTest(jtu.JaxTestCase): - INTERPRET = False +class PallasCallTest(ptu.PallasTest): def setUp(self): - if jtu.test_device_matches(["cpu"]) and not self.INTERPRET: - self.skipTest("On CPU the test works only in interpret mode") - if (jtu.test_device_matches(["cuda"]) and - not jtu.is_cuda_compute_capability_at_least("8.0")): - self.skipTest("Only works on GPU with capability >= sm80") - if sys.platform == "win32" and not self.INTERPRET: - self.skipTest("Only works on non-Windows platforms") - super().setUp() - _trace_kernel_to_jaxpr.cache_clear() + # TODO(bchetioui): Remove this once tests are all compatible with + # Pallas/Mosaic GPU. + self.enter_context(pallas_call._PALLAS_USE_MOSAIC_GPU(False)) + + def test_pallas_call_infers_backend_from_compiler_params(self): + if not jtu.test_device_matches(["gpu"]): + self.skipTest("Only works on GPU.") + if not jtu.is_cuda_compute_capability_at_least("9.0"): + self.skipTest("Only works on a GPU with capability >= sm90") + + triton_params = pltriton.CompilerParams( + num_warps=2, + num_stages=1, + ) + mosaic_gpu_params = plmgpu.CompilerParams() - def pallas_call(self, *args, **kwargs): - return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET) + pallas_call = functools.partial( + pl.pallas_call, + grid=(1,), + out_shape=jax.ShapeDtypeStruct((128, 128), jnp.float32), + ) + def add_one(x_ref, o_ref): + x = x_ref[:] + # Use a Pallas/Mosaic GPU-specific primitive to trigger a failure when + # using a different backend. + plmgpu.print_layout("x: {}", x) + o_ref[:] = x + 1 + + add_one_mgpu = pallas_call(add_one, compiler_params=mosaic_gpu_params) + add_one_triton = pallas_call(add_one, compiler_params=triton_params) + x = jnp.ones((128, 128), jnp.float32) -class PallasCallTest(PallasBaseTest): + # Running on the Mosaic GPU backend should be fine. + self.assertArraysEqual(add_one_mgpu(x), x + 1) + + # But Triton doesn't have the required primitive, so it should fail to + # lower. + with self.assertRaisesRegex( + NotImplementedError, + "Unimplemented primitive in Pallas GPU lowering: print_layout." + ): + add_one_triton(x) def test_add_one(self): if jtu.test_device_matches(["tpu"]) and not self.INTERPRET: @@ -496,7 +482,9 @@ def kernel(o_ref): self.assertEqual(o_ref_shape, (4,)) self.assertAllClose(pids[0:4], np.array([0] * 4, dtype=np.int32)) - def test_hoisted_consts(self): + def test_const_args(self): + if config.use_simplified_jaxpr_constants.value: + self.skipTest("TODO: decide if we want to keep these errors") # See https://github.com/jax-ml/jax/issues/21557. # to_store will be hoisted as a constant. Choose distinct shapes from in/outs. to_store = np.arange(128, dtype=np.float32).reshape((1, 128)) @@ -532,32 +520,6 @@ def index(x_ref, idx_ref, o_ref): idx = jnp.arange(i, i + 2) np.testing.assert_allclose(index(x, idx), x[idx]) - @parameterized.named_parameters(*[ - (f"m_{m}_n_{n}_k_{k}_dtype_{dtype}_bm_{block_size_m}_" - f"bn_{block_size_n}_bk_{block_size_k}_gm_{group_size_m}", m, n, k, dtype, - block_size_m, block_size_n, block_size_k, group_size_m) - for m in [512, 1024] - for k in [512] - for n in [512, 1024] - for dtype in ["float32", "float16"] - for block_size_m in [64, 128] - for block_size_n in [64, 128] - for block_size_k in [32] - for group_size_m in [8] - if block_size_m <= m and block_size_n <= n and block_size_k <= k - ]) - def test_matmul(self, m, n, k, dtype, bm, bn, bk, gm): - if jtu.test_device_matches(["tpu"]) and not self.INTERPRET: - self.skipTest("On TPU the test works only in interpret mode") - k1, k2 = random.split(random.key(0)) - x = random.normal(k1, (m, k), dtype=dtype) - y = random.normal(k2, (k, n), dtype=dtype) - out = matmul(x, y, bm=bm, bn=bn, bk=bk, gm=gm, - interpret=self.INTERPRET) - expected = jnp.matmul( - x, y, preferred_element_type=jnp.float32).astype(dtype) - np.testing.assert_allclose(out, expected, atol=0.05, rtol=0.05) - @parameterized.named_parameters(*[ (f"m_{m}_n_{n}_k_{k}_dtype_{dtype}_bm_{block_size_m}_" f"bn_{block_size_n}_bk_{block_size_k}", m, n, k, dtype, @@ -583,38 +545,6 @@ def test_matmul_block_spec(self, m, n, k, dtype, bm, bn, bk): x, y, preferred_element_type=jnp.float32).astype(dtype) np.testing.assert_allclose(out, expected, atol=0.05, rtol=0.05) - @parameterized.named_parameters(*( - dict(testcase_name=f"{batch_size}_{size}_{block_size}_{dtype}", - batch_size=batch_size, size=size, block_size=block_size, dtype=dtype) - for batch_size in [1, 2, 4, 23] - for size in [1, 2, 129, 255, 256] - for block_size in [1, 2, 32, 64, 128, 256] - for dtype in ["float32"] - if size < block_size - )) - def test_softmax(self, batch_size, size, block_size, dtype): - if jtu.test_device_matches(["tpu"]) and not self.INTERPRET: - self.skipTest("On TPU the test works only in interpret mode") - @functools.partial(self.pallas_call, - out_shape=jax.ShapeDtypeStruct((batch_size, size), dtype), - grid=batch_size) - def softmax(x_ref, o_ref): - row_idx = pl.program_id(0) - x_idx = jnp.arange(block_size) - row_idxs = (row_idx, x_idx) - mask = x_idx < x_ref.shape[1] - row = pl.load(x_ref, row_idxs, mask=mask, other=-float("inf")) - row_minus_max = row - jnp.max(row, axis=0) - numerator = jnp.exp(row_minus_max) - denominator = jnp.sum(numerator, axis=0) - softmax_output = numerator / denominator - pl.store(o_ref, row_idxs, softmax_output, mask=mask) - - key = random.key(0) - x = random.normal(key, [batch_size, size], dtype=dtype) - np.testing.assert_allclose(softmax(x), jax.nn.softmax(x, axis=-1), - atol=1e-5, rtol=1e-5) - def test_unused_ref(self): if jtu.test_device_matches(["tpu"]) and not self.INTERPRET: self.skipTest("On TPU the test works only in interpret mode") @@ -624,39 +554,14 @@ def test_unused_ref(self): out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32), ) def dummy(_, o_ref): - pl.store(o_ref, (jnp.arange(m)[:, None], jnp.arange(n)[None, :]), - jnp.ones_like(o_ref)) + o_ref[jnp.arange(m)[:, None], jnp.arange(n)[None, :]] = jnp.ones_like( + o_ref + ) key = random.key(0) x = random.normal(key, (m, n)) np.testing.assert_allclose(dummy(x), jnp.ones_like(x), atol=1e-5, rtol=1e-5) - def test_with_input_output_aliasing(self): - if jtu.test_device_matches(["tpu"]) and not self.INTERPRET: - self.skipTest("On TPU the test works only in interpret mode") - def add_inplace_kernel(_, o_ref, *, block_size): - pid = pl.program_id(axis=0) # we use a 1d launch grid so axis is 0 - block_start = pid * block_size - offsets = block_start + jnp.arange(block_size, dtype=jnp.int32) - mask = offsets < o_ref.shape[0] - x = pl.load(o_ref, (offsets,), mask=mask) - output = x + 1 - pl.store(o_ref, (offsets,), output, mask=mask) - - grid = (8,) - size = 8 - dtype = "float32" - k1 = random.key(0) - block_size = 1 - x = random.normal(k1, [size], dtype=dtype) - kernel = functools.partial(add_inplace_kernel, block_size=block_size) - out = self.pallas_call( - kernel, - out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), - grid=grid, input_output_aliases={0: 0})(x) - expected = x + 1 - np.testing.assert_allclose(out, expected) - def test_using_pallas_slice(self): if jtu.test_device_matches(["tpu"]) and not self.INTERPRET: self.skipTest("On TPU the test works only in interpret mode") @@ -667,8 +572,7 @@ def test_using_pallas_slice(self): out_shape=out_shape, ) def slice_kernel(x_ref, y_ref): - x = pl.load(x_ref, (pl.dslice(0, 4), pl.dslice(0, 4))) - pl.store(y_ref, (pl.dslice(4), pl.dslice(4)), x) + y_ref[:4, :4] = x_ref[:4, :4] x = random.normal(random.key(0), (m, n)) y = slice_kernel(x) y_ref = x[:4] @@ -694,6 +598,22 @@ def f(x): self.assertEqual(f(x), 2.) self.assertEqual(trace_count, 1) + def test_pallas_call_under_disable_jit(self): + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), jnp.float32), + ) + def add_one(x_ref, o_ref): + o_ref[...] = x_ref[...] + 1. + + x = jnp.arange(8, dtype=jnp.float32) + + result = add_one(x) + np.testing.assert_array_equal(result, x + 1.) + + with jax.disable_jit(): + result = add_one(x) + np.testing.assert_array_equal(result, x + 1.) + @parameterized.parameters( ("float32", None), ("float32", jax.lax.Precision.DEFAULT), @@ -702,6 +622,9 @@ def f(x): ("float32", jax.lax.DotAlgorithmPreset.DEFAULT), ("float32", jax.lax.DotAlgorithmPreset.F16_F16_F32), ("float32", jax.lax.DotAlgorithmPreset.BF16_BF16_F32), + ("float32", jax.lax.DotAlgorithmPreset.BF16_BF16_F32_X3), + ("float32", jax.lax.DotAlgorithmPreset.BF16_BF16_F32_X6), + ("float32", jax.lax.DotAlgorithmPreset.BF16_BF16_F32_X9), ("float32", jax.lax.DotAlgorithmPreset.TF32_TF32_F32), ("float32", jax.lax.DotAlgorithmPreset.TF32_TF32_F32_X3), ("float32", jax.lax.DotAlgorithmPreset.F32_F32_F32), @@ -731,7 +654,21 @@ def dot_kernel(x_ref, y_ref, o_ref): precision=jax.lax.Precision.HIGHEST, preferred_element_type=jnp.float32, ) - self.assertAllClose(dot_kernel(x, y), expected, atol=5e-2, rtol=5e-3) + if dtype == "bfloat16" or precision in ( + jax.lax.Precision.HIGHEST, + jax.lax.DotAlgorithmPreset.F32_F32_F32, + jax.lax.DotAlgorithmPreset.BF16_BF16_F32_X6, + jax.lax.DotAlgorithmPreset.BF16_BF16_F32_X9, + ): + atol = 5e-6 + elif precision in ( + jax.lax.DotAlgorithmPreset.BF16_BF16_F32_X3, + jax.lax.DotAlgorithmPreset.TF32_TF32_F32_X3, + ): + atol = 5e-4 + else: + atol = 5e-2 + self.assertAllClose(dot_kernel(x, y), expected, atol=atol, rtol=atol / 10) @parameterized.parameters(jnp.int8, jnp.uint8) def test_integer_dot(self, dtype): @@ -805,6 +742,8 @@ def copy_kernel(x_ref, o_ref): def test_float8_e4m3b11fnuz_dot(self, transpose): if not jtu.test_device_matches(["tpu"]) or not jtu.is_device_tpu_at_least(5): self.skipTest("`float8_e4m3b11fnuz` dot only supported on TPU.") + if jtu.is_device_tpu(7, "x"): + self.skipTest("Unsupported type for matmul.") dtype = jnp.float8_e4m3b11fnuz x = jax.random.normal(jax.random.key(0), (2048, 1024), dtype=jnp.bfloat16) @@ -826,18 +765,42 @@ def dot_kernel(x_ref, y_ref, o_ref): self.assertAllClose(dot_kernel(x, y), expected) + @parameterized.parameters( + ((32,), 2, 0), ((32, 64), 4, 0), ((32, 16), 8, 1), ((32, 16, 2), 16, 1) + ) + def test_split(self, shape, num_parts, axis): + if jtu.test_device_matches(["tpu"]) and shape[axis] == num_parts: + self.skipTest("TPU doesn't support fully split axis.") + + x = jax.random.normal(jax.random.key(0), shape) + expected = jnp.split(x, num_parts, axis) + + @functools.partial(self.pallas_call, out_shape=expected) + def kernel(x_ref, *o_ref): + x_parts = jnp.split(x_ref[()], num_parts, axis) + for o_ref, x_part in zip(o_ref, x_parts): + o_ref[...] = x_part + + self.assertAllClose(kernel(x), expected) + + class PallasCallInterpretTest(PallasCallTest): INTERPRET = True -class PallasCallUnblockedIndexingTest(PallasBaseTest): +class PallasCallElementIndexingTest(ptu.PallasTest): + def setUp(self): + super().setUp() + # TODO(bchetioui): Remove this once tests are all compatible with + # Pallas/Mosaic GPU. + self.enter_context(pallas_call._PALLAS_USE_MOSAIC_GPU(False)) - def test_block_spec_unblocked(self): + def test_block_spec_element(self): def show_program_ids( - *, shape, block_shape, grid, indexing_mode: pl.IndexingMode + *, shape, block_shape, grid, ): def kernel(o1_ref): - assert o1_ref.shape == block_shape + assert o1_ref.shape == (8, 128) o1_ref[...] = jnp.full(o1_ref.shape, pl.program_id(0)) return self.pallas_call( @@ -845,16 +808,15 @@ def kernel(o1_ref): jax.ShapeDtypeStruct(shape, dtype=np.int32), grid=grid, out_specs=pl.BlockSpec( - block_shape, lambda i: (8 * i, 0), indexing_mode=indexing_mode + block_shape, lambda i: (8 * i, 0), ), )() # No padding pids = show_program_ids( shape=(16, 128), - block_shape=(8, 128), + block_shape=(pl.Element(8), pl.Element(128)), grid=(2,), - indexing_mode=pl.Unblocked(), ) expected_pids = np.array([[0] * 128] * 8 + [[1] * 128] * 8, dtype=np.int32) self.assertAllClose(pids, expected_pids) @@ -865,9 +827,8 @@ def kernel(o1_ref): # Only high padding pids = show_program_ids( shape=(14, 128), - block_shape=(8, 128), + block_shape=(pl.Element(8, (0, 2)), pl.Element(128, (0, 0))), grid=(2,), - indexing_mode=pl.Unblocked(((0, 2), (0, 0))), ) expected_pids = np.array([[0] * 128] * 8 + [[1] * 128] * 6, dtype=np.int32) self.assertAllClose(pids, expected_pids) @@ -876,15 +837,14 @@ def kernel(o1_ref): self.skipTest("TODO: low padding not supported yet") pids = show_program_ids( shape=(11, 128), - block_shape=(8, 128), + block_shape=(pl.Element(8, (3, 2)), pl.Element(128, (0, 0))), grid=(2,), - indexing_mode=pl.Unblocked(((3, 2), (0, 0))), ) expected_pids = np.array([[0] * 128] * 5 + [[1] * 128] * 6, dtype=np.int32) self.assertAllClose(pids, expected_pids) @parameterized.parameters("int32", "float32") - def test_block_spec_unblocked_padding_is_nan(self, dtype_name): + def test_block_spec_element_padding_is_nan(self, dtype_name): if not self.INTERPRET: self.skipTest("Only applicable for the interpret mode") @@ -899,7 +859,7 @@ def copy_kernel(x_ref, o_ref): grid=(1,), in_specs=[ pl.BlockSpec( - (6,), lambda i: 0, indexing_mode=pl.Unblocked(((1, 2),)) + (pl.Element(6, (1, 2)),), lambda i: 0, ) ], )(np.full((3,), 42, dtype=dtype)) @@ -913,7 +873,7 @@ def copy_kernel(x_ref, o_ref): ), ) - def test_unblocked_indexing(self): + def test_element_indexing(self): shape = (16 * 8, 128) result_ty = jax.ShapeDtypeStruct((15 * 8, 128), jnp.float32) @@ -926,7 +886,7 @@ def kernel(x_ref, o_ref): grid=(15,), in_specs=( pl.BlockSpec( - (2 * 8, 128), lambda i: (i * 8, 0), indexing_mode=pl.unblocked + (pl.Element(2 * 8), pl.Element(128)), lambda i: (i * 8, 0), ), ), out_specs=pl.BlockSpec((8, 128), lambda i: (i, 0)), @@ -955,9 +915,8 @@ def kernel(x_ref, y_ref): grid=(1,), in_specs=( pl.BlockSpec( - (2 * 8, 128), + (pl.Element(2 * 8, (0, 8)), pl.Element(128)), lambda i: (0, 0), - indexing_mode=pl.Unblocked(((0, 8), (0, 0))), ), ), out_specs=pl.BlockSpec((8, 128), lambda i: (0, 0)), @@ -966,11 +925,48 @@ def kernel(x_ref, y_ref): np.testing.assert_array_equal(y, x) -class PallasCallUnblockedIndexingInterpretTest(PallasCallUnblockedIndexingTest): +class PallasCallElementIndexingInterpretTest(PallasCallElementIndexingTest): INTERPRET = True -class ApiErrorTest(PallasBaseTest): +class PallasCallBoundedSliceIndexingTest(ptu.PallasTest): + + def setUp(self): + super().setUp() + if not jtu.is_device_tpu(): + self.skipTest("Only applicable for TPU") + + def test_block_spec_bounded_slice_static(self): + shape = (16, 8, 128) + def kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] + + x = jnp.arange(np.prod(shape), dtype=np.int32).reshape(shape) + with self.assertRaisesRegex(NotImplementedError, + "Unsupported block dimension type:"): + _ = self.pallas_call( + kernel, + jax.ShapeDtypeStruct((8, 8, 128), dtype=np.int32), + grid=(1,), + in_specs=( + pl.BlockSpec( + (pl.BoundedSlice(8), 8, 128), lambda i: (pl.ds(4, 8), 0, 0), + ), + ), + out_specs=pl.BlockSpec( + (8, 8, 128), lambda i: (0, 0, 0), + ), + )(x) + + +class ApiErrorTest(ptu.PallasTest): + + def setUp(self): + super().setUp() + # TODO(bchetioui): Remove this once tests are all compatible with + # Pallas/Mosaic GPU. + self.enter_context(pallas_call._PALLAS_USE_MOSAIC_GPU(False)) + def test_pallas_call_kernel_args_mismatch(self): a = np.arange(256, dtype=np.int32) f = self.pallas_call(lambda x_ref: None, # Missing o_ref @@ -1022,10 +1018,10 @@ def test_pallas_call_in_specs_mismatch_inputs(self): pl.BlockSpec((4,), lambda: 0)]) with self.assertRaisesRegex( ValueError, - re.compile("Pytree for `in_specs` and inputs do not match. " + re.compile("Pytree for `in_specs` and `inputs` do not match. " "There are 1 mismatches, including:" ".* at \\[1\\], `in_specs` is a pytree leaf but " - "inputs is a.*", re.DOTALL)): + "`inputs` is a.*", re.DOTALL)): f(a, dict(a=a)) def test_pallas_call_index_map_wrong_number_of_arguments(self): @@ -1067,7 +1063,6 @@ def my_index_map(): "Currently returning 2 values."): f(dict(one=a, two=a)) - def test_pallas_call_index_map_wrong_return_type(self): a = np.arange(256, dtype=np.int32) def my_index_map(i): @@ -1099,6 +1094,8 @@ def my_index_map(i): f(a) def test_pallas_call_index_map_captures_consts(self): + if config.use_simplified_jaxpr_constants.value: + self.skipTest("TODO: decide if we want to keep these errors") a = np.arange(256, dtype=np.int32) index_map_result = np.array([0], dtype=np.int32) f = self.pallas_call(lambda x_ref, o1_ref: None, @@ -1181,14 +1178,55 @@ def test_pallas_call_input_output_aliases_errors(self): out_shape=[jax.ShapeDtypeStruct(x.shape, jnp.float32)], input_output_aliases={1: 0})(x, x) + def test_pallas_error_for_ref_to_jax(self): + m, n, k = 8, 16, 32 + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32), + ) + def dot_general_kernel(x_ref, y_ref, o_ref): + o_ref[...] = jax.lax.dot_general(x_ref, y_ref, (((2), (1)), ((1,), (2,)))) + + key1, key2 = random.split(random.key(0)) + x = random.normal(key1, (m, k), dtype=jnp.float32) + y = random.normal(key2, (k, n), dtype=jnp.float32) + with self.assertRaisesRegex( + ValueError, + r"Attempting to pass a Ref" + r" Ref{float32\[8,32\]}" + r" to a primitive: dot_general -- did you forget to unpack \(\[...\]\)" + r" the ref?", + ): + dot_general_kernel(x, y) + + def test_pallas_error_for_writing_ref_to_ref(self): + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + ) + def kernel(x_ref, o_ref): + o_ref[...] = x_ref + + x = jnp.ones((8, 128), dtype=jnp.float32) + with self.assertRaisesRegex( + ValueError, "Cannot store a Ref into another Ref", + ): + kernel(x) + class ApiErrorInterpretTest(ApiErrorTest): INTERPRET = True -class PallasCallInputOutputAliasingTest(PallasBaseTest): +class PallasCallInputOutputAliasingTest(ptu.PallasTest): - def test_basic_input_output_aliasing(self): + def setUp(self): + super().setUp() + # TODO(bchetioui): Remove this once tests are all compatible with + # Pallas/Mosaic GPU. + self.enter_context(pallas_call._PALLAS_USE_MOSAIC_GPU(False)) + + def test_vector_input_output_aliasing(self): # Input needs to be big so it doesn't fit in VMEM size = 1024 if jtu.is_device_cuda(): @@ -1217,17 +1255,99 @@ def f(x): self.assertEqual(mem_analysis.alias_size_in_bytes, expected_num_bytes) self.assertEqual(mem_analysis.temp_size_in_bytes, 0) + def test_scalar_input_output_aliasing(self): + if jtu.test_device_matches(["tpu"]) and not jtu.is_cloud_tpu_at_least( + 2025, 10, 7 + ): + self.skipTest("Requires libtpu built after 2025-10-07") + + x = jnp.array([41.0], dtype=jnp.float32) + expected = x + 1.0 + + def kernel(x_ref, y_ref): + y_ref[0] = x_ref[0] + 1.0 + + shape = jax.ShapeDtypeStruct(x.shape, x.dtype) + scalar_smem_spec = pl.BlockSpec( + block_shape=(1,), index_map=lambda *_: (0,), memory_space=pltpu.SMEM + ) + + @functools.partial(jax.jit, donate_argnums=(0,)) + def f(x_in): + return self.pallas_call( + kernel, + out_shape=shape, + in_specs=[scalar_smem_spec], + out_specs=scalar_smem_spec, + grid=(1,), + input_output_aliases={0: 0}, + )(x_in) -class PallasCallInputOutputAliasingInterpretTest(PallasBaseTest): + o = f(x) + np.testing.assert_array_equal(o, expected) + with self.assertRaisesRegex(RuntimeError, "Array has been deleted"): + print(x) + + def test_mixed_scalar_vector_input_output_aliasing(self): + if jtu.test_device_matches(["tpu"]) and not jtu.is_cloud_tpu_at_least( + 2025, 10, 7 + ): + self.skipTest("Requires libtpu built after 2025-10-07") + + x_scalar = jnp.array([41.0], dtype=jnp.float32) + x_vector = jnp.arange(1024, dtype=jnp.float32).reshape((8, 128)) + expected_scalar = x_scalar + 1.0 + expected_vector = x_vector + 1.0 + + def kernel(scalar_in_ref, vector_in_ref, scalar_out_ref, vector_out_ref): + scalar_out_ref[0] = scalar_in_ref[0] + 1.0 + vector_out_ref[:] = vector_in_ref[:] + 1.0 + + scalar_shape = jax.ShapeDtypeStruct(x_scalar.shape, x_scalar.dtype) + vector_shape = jax.ShapeDtypeStruct(x_vector.shape, x_vector.dtype) + scalar_spec = pl.BlockSpec( + block_shape=(1,), index_map=lambda *_: (0,), memory_space=pltpu.SMEM + ) + vector_spec = pl.BlockSpec( + block_shape=x_vector.shape, index_map=lambda *_: (0,) * x_vector.ndim + ) + + @functools.partial(jax.jit, donate_argnums=(0, 1)) + def f(x_scalar_in, x_vector_in): + return self.pallas_call( + kernel, + out_shape=(scalar_shape, vector_shape), + in_specs=[scalar_spec, vector_spec], + out_specs=[scalar_spec, vector_spec], + grid=(1,), + input_output_aliases={ + 0: 0, + 1: 1, + }, + )(x_scalar_in, x_vector_in) + + o_scalar, o_vector = f(x_scalar, x_vector) + np.testing.assert_array_equal(o_scalar, expected_scalar) + np.testing.assert_array_equal(o_vector, expected_vector) + with self.assertRaisesRegex(RuntimeError, "Array has been deleted"): + print(x_scalar) + with self.assertRaisesRegex(RuntimeError, "Array has been deleted"): + print(x_vector) + + +class PallasCallInputOutputAliasingInterpretTest(ptu.PallasTest): INTERPRET = True -class PallasControlFlowTest(PallasBaseTest): +class PallasControlFlowTest(ptu.PallasTest): def setUp(self): super().setUp() if self.INTERPRET: self.skipTest("Control flow not supported in interpret mode yet.") + # TODO(bchetioui): Remove this once tests are all compatible with + # Pallas/Mosaic GPU. + self.enter_context(pallas_call._PALLAS_USE_MOSAIC_GPU(False)) def test_loop_with_float64_carry(self): if jtu.test_device_matches(["tpu"]) and not self.INTERPRET: @@ -1531,6 +1651,28 @@ def body_fn(i, args): jax.value_and_grad(lambda params, x: f(program, params, x).sum())( params, x) + @parameterized.product(start=[0, 1, 2], stop=[6, 7, 8], step=[None, 3]) + def test_loop(self, start, stop, step): + + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((128,), jnp.int32) + ) + def f(x_ref, y_ref): + y_ref[...] = x_ref[...] + + @pl.loop( + jnp.int32(start), + jnp.int32(stop), + **{} if step is None else dict(step=jnp.astype(step, jnp.int32)), + ) + def _(i): + y_ref[...] += i + + x = jnp.zeros((128,), jnp.int32) + np.testing.assert_array_equal( + f(x), jnp.full_like(x, sum(range(start, stop, step or 1))) + ) + def test_fori_loop_simple(self): if jtu.test_device_matches(["tpu"]) and not self.INTERPRET: self.skipTest("TODO: error on TPU") @@ -1697,7 +1839,7 @@ def test_range_while_loop(self): def kernel(x_ref, r_ref): @pl.when(pl.program_id(0) == 0) def _(): - pl.store(r_ref, (0, 0), 0) + r_ref[0, 0] = 0 def cond(carry): i, j = carry @@ -1709,8 +1851,7 @@ def body(carry): sl = jax.lax.div(i, 128) l = jax.lax.rem(i, 128) v = x_ref[0, sl, l] - s = pl.load(r_ref, (0, 0)) - pl.store(r_ref, (0, 0), s + v) + r_ref[0, 0] += v return io + 1, j i = 128 @@ -1762,7 +1903,7 @@ def test_non_range_while_loop(self): def kernel(x_ref, r_ref): @pl.when(pl.program_id(0) == 0) def _(): - pl.store(r_ref, (0, 0), 0) + r_ref[0, 0] = 0 def cond(state): i, s = state @@ -1772,14 +1913,11 @@ def body(state): i, s = state sl = jax.lax.div(i, jnp.astype(128, i.dtype)) l = jax.lax.rem(i, jnp.astype(128, i.dtype)) - v = pl.load(x_ref, (0, sl, l)) + v = x_ref[0, sl, l] return i + 1, s + v i = jnp.int32(0) - s = pl.load(r_ref, (0, 0)) - - i, s = jax.lax.while_loop(cond, body, (i, s)) - pl.store(r_ref, (0, 0), s) + _, r_ref[0, 0] = jax.lax.while_loop(cond, body, (i, r_ref[0, 0])) x = jnp.arange(4096) x = jnp.reshape(x, [4, 8, 128]) @@ -1826,6 +1964,34 @@ def body(v): # 3 -> 6 -> 12 -> 24 np.testing.assert_array_equal(reduced, 1024 * 24) + def test_vector_1d_slice_carry_while_loop(self): + """Tests lowering of a while_loop which carries a sliced vector quantity.""" + if jtu.test_device_matches(["gpu"]) and not self.INTERPRET: + self.skipTest("TODO: slice not implemented on GPU") + + def kernel(x_ref, r_ref): + + def cond(v): + return v[0] < 16 + + def body(v): + return jnp.concatenate([v, v])[1:101] * 2 + + r_ref[:] = jax.lax.while_loop(cond, body, x_ref[:]) + + x = jnp.full((100,), 3, dtype=jnp.int32) + fn = pl.pallas_call( + kernel, + grid=(1,), + in_specs=[pl.BlockSpec((100,), lambda i: (0,))], + out_specs=pl.BlockSpec((100,), lambda i: (0,)), + out_shape=jax.ShapeDtypeStruct((100,), jnp.int32), + ) + r = fn(x) + reduced = jnp.sum(r) + # 3 -> 6 -> 12 -> 24 + np.testing.assert_array_equal(reduced, 100 * 24) + @parameterized.named_parameters( ('1x128', (1, 128)), ('2x128', (2, 128)), @@ -1965,7 +2131,7 @@ class PallasControlFlowInterpretTest(PallasControlFlowTest): ] -class PallasCallAutodifferentiationTest(PallasBaseTest): +class PallasCallAutodifferentiationTest(ptu.PallasTest): def setUp(self): super().setUp() @@ -1975,6 +2141,9 @@ def setUp(self): # TODO: improve tolerance setting self.tol = 1e-5 self.grad_tol = jtu.default_gradient_tolerance[np.dtype(jnp.float32)] + # TODO(bchetioui): Remove this once tests are all compatible with + # Pallas/Mosaic GPU. + self.enter_context(pallas_call._PALLAS_USE_MOSAIC_GPU(False)) @parameterized.named_parameters(*AD_TEST_CASES) def test_jvp(self, impl): @@ -2080,7 +2249,7 @@ class PallasCallAutodifferentiationInterpretTest(PallasCallAutodifferentiationTe INTERPRET = True -class PallasOutOfBoundsInterpretTest(PallasBaseTest): +class PallasOutOfBoundsInterpretTest(ptu.PallasTest): INTERPRET = True def test_interpret_mode_out_of_bounds_access(self): @@ -2160,7 +2329,7 @@ def _(): np.testing.assert_allclose(out, expected, atol=atol, rtol=rtol) -class PallasCheckifyTest(PallasBaseTest): +class PallasCheckifyTest(ptu.PallasTest): INTERPRET = False def test_basic_runtime_assert(self): @@ -2175,7 +2344,7 @@ def kernel(x_ref, y_ref): checkify.check(False, "second check failed") input_ = jnp.arange(4, dtype=jnp.int32) out_shape = jax.ShapeDtypeStruct(input_.shape, input_.dtype) - with pltpu.enable_runtime_assert(True): + with pl.enable_debug_checks(True): pallas_call = pl.pallas_call(kernel, out_shape=out_shape) pallas_call(input_) # This should log "second check failed" @@ -2185,11 +2354,10 @@ def test_runtime_assert_is_noop_when_not_enabled(self): self.skipTest("Runtime check only implemented on TPU.") def kernel(x_ref, y_ref): y_ref[...] = x_ref[...] - checkify.check(False, "failed check", - debug=True) # This check always fails. + pl.debug_check(False, "failed check") # This check always fails. input_ = jnp.arange(4, dtype=jnp.int32) out_shape = jax.ShapeDtypeStruct(input_.shape, input_.dtype) - with pltpu.enable_runtime_assert(False): + with pl.enable_debug_checks(False): pallas_call = pl.pallas_call(kernel, out_shape=out_shape) result = pallas_call(input_) np.testing.assert_allclose(result, input_) @@ -2341,7 +2509,14 @@ class PallasCheckifyInterpretTest(PallasCheckifyTest): INTERPRET = True -class PallasCallNamedGridTest(PallasBaseTest): +class PallasCallNamedGridTest(ptu.PallasTest): + + def setUp(self): + super().setUp() + # TODO(bchetioui): Remove this once tests are all compatible with + # Pallas/Mosaic GPU. + self.enter_context(pallas_call._PALLAS_USE_MOSAIC_GPU(False)) + def test_named_grid(self): def kernel(x_ref, y_ref): @@ -2379,8 +2554,8 @@ def kernel(x_ref, y_ref): def test_can_query_named_grid_size_in_kernel_via_psum(self): def kernel(x_ref, y_ref): - self.assertEqual(lax.psum(1, "i"), 2) - self.assertEqual(lax.psum(1, "j"), 4) + self.assertEqual(lax.axis_size("i"), 2) + self.assertEqual(lax.axis_size("j"), 4) y_ref[...] = x_ref[...] x = jnp.arange(4 * 16 * 128, dtype=np.int32).reshape((4, 16, 128)) @@ -2400,8 +2575,8 @@ def test_can_query_named_dynamic_grid_size_in_kernel_via_psum(self): self.skipTest("Not supported.") def kernel(x_ref, y_ref): - self.assertEqual(lax.psum(1, "i"), 2) - self.assertEqual(lax.psum(1, "j"), 4) + self.assertEqual(lax.axis_size("i"), 2) + self.assertEqual(lax.axis_size("j"), 4) y_ref[...] = x_ref[...] x = jnp.arange(4 * 8 * 128, dtype=np.int32).reshape((4, 8, 128)) @@ -2441,7 +2616,7 @@ def kernel(x_ref, y_ref): ) -class SymbolicPallasTest(PallasBaseTest): +class SymbolicPallasTest(ptu.PallasTest): def test_simple_symbolic_matmul_export(self): if jtu.test_device_matches(["gpu"]): @@ -2517,51 +2692,207 @@ def sym_matmul_kernel(x_ref, y_ref, z_ref): str(exported_module), ) + def test_pallas_shape_poly_no_cache_collision(self): + + def kernel(x, y): + y[:] = x[:] + + f = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + ) + f = jax.vmap(f) + + x1_shape = jax.ShapeDtypeStruct( + jax.export.symbolic_shape('b1, 8, 128'), jnp.float32 + ) + exported_module1 = pl.lower_as_mlir(jax.jit(f), x1_shape, dynamic_shapes=True) + self.assertIn("(b1, 8, 128)", str(exported_module1)) + x2_shape = jax.ShapeDtypeStruct( + jax.export.symbolic_shape('b2, 8, 128'), jnp.float32 + ) + exported_module2 = pl.lower_as_mlir(jax.jit(f), x2_shape, dynamic_shapes=True) + self.assertIn("(b2, 8, 128)", str(exported_module2)) + class PallasCallNamedGridInterpretTest(PallasCallNamedGridTest): INTERPRET = True -def _find_pallas_call_in_jaxpr( - jaxpr: jax_core.Jaxpr) -> jax_core.JaxprEqn | None: - for eqn in jaxpr.eqns: - call_eqn = None - if eqn.primitive == pallas_call.pallas_call_p: - call_eqn = eqn - elif 'jaxpr' in eqn.params: - call_eqn = _find_pallas_call_in_jaxpr(eqn.params['jaxpr']) - if call_eqn is not None: - return call_eqn - return None +@dataclasses.dataclass(frozen=True) +class WeirdTuple: + x0: jax.Array + x1: jax.Array -class PallasCompilerParamsTest(PallasBaseTest): - def test_triton_params_consistent_across_double_jit(self): - # Test for https://github.com/jax-ml/jax/issues/25714 - if not jtu.test_device_matches(["gpu"]): - self.skipTest("Triton backend only works on GPU.") - params = plgpu.TritonCompilerParams(num_warps=8) +@dataclasses.dataclass(frozen=True) +class WeirdTupleTy(hijax.HiType): + x0_aval: jax_core.ShapedArray + x1_aval: jax_core.ShapedArray - @jax.jit - @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32), - compiler_params=params) - def copy_kernel(x_ref, o_ref): - o_ref[...] = x_ref[...] + @property + def shape(self) -> tuple[int, ...]: + return self.x0_aval.shape - @functools.partial(jax.jit, static_argnames=["z"]) - def plus_z(x, z): - return copy_kernel(x+z) + @property + def dtype(self) -> jnp.dtype: + return self.x0_aval.dtype - x = 0. - extracted_params = _find_pallas_call_in_jaxpr( - plus_z.trace(x, 1).jaxpr).params["compiler_params"] - self.assertEqual(plus_z(0., 1.), 1.) - self.assertEqual(extracted_params["triton"]["num_warps"], 8) - extracted_params = _find_pallas_call_in_jaxpr( - plus_z.trace(x, 2).jaxpr).params["compiler_params"] - self.assertEqual(plus_z(0., 2.), 2.) - self.assertEqual(extracted_params["triton"]["num_warps"], 8) + def update(self, *, shape: tuple[int, ...]) -> WeirdTupleTy: + return dataclasses.replace( + self, x0_aval=self.x0_aval.update(shape=shape), + x1_aval=self.x1_aval.update(shape=shape[1:]) + ) + + def lo_ty(self) -> list[jax_core.ShapedArray]: + return [self.x0_aval, self.x1_aval] + + def lower_val(self, hi_val: WeirdTuple) -> list[jax.Array]: + return [hi_val.x0, hi_val.x1] + + def raise_val(self, x0, x1) -> WeirdTuple: + return WeirdTuple(x0, x1) + + def lower_block_spec(self, block_spec: pl.BlockSpec): + x1_block_spec = block_spec.replace(block_shape=block_spec.block_shape[1:], + index_map=lambda *args: (0,)) + return [block_spec, x1_block_spec] + +hijax.register_hitype( + WeirdTuple, lambda t: WeirdTupleTy(jax.typeof(t.x0), jax.typeof(t.x1)) +) + + +@dataclasses.dataclass(frozen=True) +class SlicedArray: + x: jax.Array # any shape/dtype + s: jax.Array # i32[] + + @property + def shape(self) -> tuple[int, ...]: + return self.x.shape[1:] + + @property + def dtype(self) -> jnp.dtype: + return self.x.dtype + + +@dataclasses.dataclass(frozen=True) +class SlicedArrayTy(hijax.HiType): + pre_sliced_aval: jax_core.ShapedArray + + @property + def shape(self) -> tuple[int, ...]: + return self.pre_sliced_aval.shape[1:] + + @property + def dtype(self) -> jnp.dtype: + return self.pre_sliced_aval.dtype + + def update(self, *, shape: tuple[int, ...]) -> WeirdTupleTy: + return dataclasses.replace( + self, + pre_sliced_aval=self.pre_sliced_aval.update( + shape=(self.pre_sliced_aval.shape[0],) + shape + ), + ) + + def lo_ty(self) -> list[jax_core.ShapedArray]: + return [self.pre_sliced_aval, jax_core.ShapedArray((1,), jnp.int32)] + + def lower_val(self, hi_val: SlicedArray) -> list[jax.Array]: + return [hi_val.x, hi_val.s] + + def raise_val(self, x, s) -> WeirdTuple: + return SlicedArray(x, s) + + def lower_block_spec(self, block_spec: pl.BlockSpec): + def index_map(*args): + idx = block_spec.index_map(*args) + return 0, *idx + new_block_shape = (pl.Blocked(self.pre_sliced_aval.shape[0]), *block_spec.block_shape) + x_block_spec = block_spec.replace(index_map=index_map, block_shape=new_block_shape) + return [x_block_spec, pl.BlockSpec(memory_space=pltpu.SMEM)] + +hijax.register_hitype( + SlicedArray, lambda t: SlicedArrayTy(jax.typeof(t.x)) +) + +index_p = jax_core.Primitive('index_p') +index_p.is_high = lambda *_: True +index_p.def_abstract_eval(lambda xt: jax_core.ShapedArray(xt.shape, xt.dtype)) + + +def index_to_lojax(xt: jax.Ref) -> jax.Array: + assert isinstance(xt, jax.Ref) + x_ref = xt._refs.x + s_ref = xt._refs.s + s = s_ref[0] + return x_ref[s] +index_p.to_lojax = index_to_lojax + + +class PallasHiJaxTest(ptu.PallasTest): + + def setUp(self): + super().setUp() + # TODO(bchetioui): Remove this once tests are all compatible with + # Pallas/Mosaic GPU. + self.enter_context(pallas_call._PALLAS_USE_MOSAIC_GPU(False)) + + def test_pass_weird_tuple_into_pallas_call(self): + + xt = WeirdTuple(x0=jnp.ones((8, 8)), x1=jnp.zeros((8,))) + + def kernel(xt_ref, ot_ref): + xt = xt_ref[...] + ot_ref[...] = xt + + ot = self.pallas_call(kernel, out_shape=jax.typeof(xt))(xt) + self.assertArraysEqual(ot.x0, xt.x0) + self.assertArraysEqual(ot.x1, xt.x1) + + def test_pass_sliced_array_into_pallas_call(self): + + xs = SlicedArray( + x=jnp.arange(8 * 16 * 128).reshape(8, 16, 128), + s=jnp.array([2], jnp.int32), + ) + + def kernel(xs_ref, o_ref): + x = index_p.bind(xs_ref) + o_ref[...] = x + + o = self.pallas_call( + kernel, out_shape=jax.ShapeDtypeStruct(xs.shape, xs.dtype), + in_specs=[pl.BlockSpec((8, 128), lambda i: (i, 0))], + out_specs=pl.BlockSpec((8, 128), lambda i: (i, 0)), + grid=(2,) + )(xs) + self.assertArraysEqual(o, xs.x[xs.s[0]]) + + def test_pass_hi_type_with_aliasing(self): + + xs = SlicedArray( + x=jnp.arange(8 * 16 * 128).reshape(8, 16, 128), + s=jnp.array([2], jnp.int32), + ) + + def kernel(xs_ref, o_ref): + o_ref[...] = xs_ref[...] + + @jax.jit + def f(xs): + return self.pallas_call( + kernel, out_shape=jax.typeof(xs), + in_specs=[pl.BlockSpec((8, 128), lambda i: (i, 0))], + out_specs=pl.BlockSpec((8, 128), lambda i: (i, 0)), + grid=(2,), + input_output_aliases={0: 0} + )(xs) + os = f(xs) + self.assertArraysEqual(os.x, xs.x) + self.assertArraysEqual(os.s, xs.s) if __name__ == "__main__": diff --git a/tests/pallas/pallas_vmap_test.py b/tests/pallas/pallas_vmap_test.py index ffa6195625dd..fdfeb2437504 100644 --- a/tests/pallas/pallas_vmap_test.py +++ b/tests/pallas/pallas_vmap_test.py @@ -24,7 +24,7 @@ from jax._src import config from jax._src import dtypes from jax._src import test_util as jtu -from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr +from jax._src.pallas import pallas_call from jax.experimental import pallas as pl import jax.numpy as jnp import numpy as np @@ -36,8 +36,8 @@ config.parse_flags_with_absl() -intx = dtypes.canonicalize_dtype(jnp.int64) -floatx = dtypes.canonicalize_dtype(jnp.float64) +intx = dtypes.default_int_dtype() +floatx = dtypes.default_float_dtype() @jtu.with_config(jax_traceback_filtering="off") @@ -54,7 +54,9 @@ def setUp(self): self.skipTest("Only works on non-Windows platforms") super().setUp() - _trace_kernel_to_jaxpr.cache_clear() + # TODO(bchetioui): Remove this once tests are all compatible with + # Pallas/Mosaic GPU. + self.enter_context(pallas_call._PALLAS_USE_MOSAIC_GPU(False)) def pallas_call(self, *args, **kwargs): return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET) @@ -132,7 +134,9 @@ def add_one(x_ref, o_ref): out_ref = jnp.arange(1, 9).reshape((4, 2)) np.testing.assert_allclose(out, out_ref) - def test_vmap_with_hoisted_consts(self): + def test_vmap_with_const_args(self): + if config.use_simplified_jaxpr_constants.value: + self.skipTest("TODO: decide if we want to keep these errors") to_store = np.arange(128, dtype=np.float32).reshape((1, 128)) x = np.arange(4 * 16 * 128, dtype=np.float32).reshape((4, 16, 128)) diff --git a/tests/pallas/pipelining/BUILD b/tests/pallas/pipelining/BUILD new file mode 100644 index 000000000000..ac0659d59d7e --- /dev/null +++ b/tests/pallas/pipelining/BUILD @@ -0,0 +1,58 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +load( + "//jaxlib:jax.bzl", + "jax_generate_backend_suites", + "jax_multiplatform_test", + "py_deps", +) + +licenses(["notice"]) + +package( + default_applicable_licenses = [], + default_visibility = ["//visibility:private"], +) + +jax_generate_backend_suites(backends = ["cpu"]) + +jax_multiplatform_test( + name = "schedule_api_test", + srcs = ["schedule_api_test.py"], + enable_backends = ["cpu"], + enable_configs = ["cpu"], + deps = [ + "//jax/_src:core", + "//jax/_src:state_types", + "//jax/_src/pallas/pipelining:pipeline_test_util", + "//jax/_src/pallas/pipelining:schedule_api", + ] + py_deps([ + "absl/testing", + "numpy", + ]), +) + +jax_multiplatform_test( + name = "schedulers_test", + srcs = ["schedulers_test.py"], + enable_backends = ["cpu"], + enable_configs = ["cpu"], + deps = [ + "//jax/_src/pallas/pipelining:internal", + "//jax/_src/pallas/pipelining:pipeline_test_util", + "//jax/_src/pallas/pipelining:schedulers", + ] + py_deps([ + "absl/testing", + ]), +) diff --git a/tests/pallas/pipelining/schedule_api_test.py b/tests/pallas/pipelining/schedule_api_test.py new file mode 100644 index 000000000000..21281419bf6e --- /dev/null +++ b/tests/pallas/pipelining/schedule_api_test.py @@ -0,0 +1,134 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import dataclasses +from typing import Any + +from absl.testing import absltest +import jax +from jax import numpy as jnp +from jax._src import core as jax_core +from jax._src import test_util as jtu +from jax._src.pallas.pipelining import pipeline_test_util as test_util +from jax._src.pallas.pipelining import schedule_api +from jax._src.state import types as state_types +import numpy as np + + +jax.config.parse_flags_with_absl() + + +@dataclasses.dataclass(frozen=True) +class MemoryRef: + shape: tuple[int, ...] + dtype: np.dtype + memory_space: Any | None = None + + def get_ref_aval(self) -> state_types.AbstractRef: + return state_types.AbstractRef( + inner_aval=jax_core.ShapedArray(shape=self.shape, dtype=self.dtype), + memory_space=self.memory_space, + ) + + +class ApiTest(absltest.TestCase): + + def setUp(self): + super().setUp() + if not jtu.test_device_matches(["cpu"]): + self.skipTest("Only works on CPU") + + def test_basic_pipeline(self): + # Use reads/writes to mimic the Ref effects of DMAs. + copy_in = schedule_api.AsyncStage(max_in_flight=2) + + @copy_in.def_start + def copy_in_start(_, x_ref, o_ref): + del o_ref + # dma_start creates a write_effect to x_ref + x_ref[...] = jnp.ones_like(x_ref) + + @copy_in.def_end + def copy_in_end(_, x_ref, o_ref): + del o_ref + # dma_end creates a write_effect to x_ref + x_ref[...] = jnp.ones_like(x_ref) + + @schedule_api.stage(max_in_flight=2) + def kernel_body(_, x_ref, o_ref): + o_ref[...] = x_ref[...] + 1.0 + + copy_out = schedule_api.AsyncStage(max_in_flight=2) + @copy_out.def_start + def copy_out_start(_, x_ref, o_ref): + del x_ref + # dma_start creates a read_effect to o_ref + _ = o_ref[...] + + @copy_out.def_end + def copy_out_end(_, x_ref, o_ref): + del x_ref + # dma_end creates a read_effect to o_ref + _ = o_ref[...] + + pipeline = schedule_api.schedule_pipeline( + stages=(copy_in, kernel_body, copy_out), + grid=(4,), + args=( + MemoryRef(shape=(128, 128), dtype=jnp.dtype(jnp.float32), + memory_space="VMEM"), + MemoryRef(shape=(128, 128), dtype=jnp.dtype(jnp.float32), + memory_space="VMEM"), + ), + eval_fn=test_util.print_stage, + ) + ref = jnp.ones((128, 128), jnp.float32) + ref = jax.new_ref(ref) + with jtu.capture_stdout() as stdout: + pipeline(ref, ref) + output = stdout().strip().split("\n") + expected = [ + # step + "[itr=0] copy_in_start", + "[itr=1] copy_in_start", + # step + "[itr=0] copy_in_end", + "[itr=0] kernel_body", + "[itr=0] copy_out_start", + "[itr=2] copy_in_start", + # step + "[itr=1] copy_in_end", + "[itr=1] kernel_body", + "[itr=1] copy_out_start", + "[itr=3] copy_in_start", + # step + test_util.AnyOrder([ + "[itr=0] copy_out_end", + "[itr=2] copy_in_end"]), + "[itr=2] kernel_body", + "[itr=2] copy_out_start", + # step + test_util.AnyOrder([ + "[itr=1] copy_out_end", + "[itr=3] copy_in_end"]), + "[itr=3] kernel_body", + "[itr=3] copy_out_start", + # step + "[itr=2] copy_out_end", + "[itr=3] copy_out_end", + ] + self.assertTrue(test_util.compare_lists(output, expected)) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/pallas/pipelining/schedulers_test.py b/tests/pallas/pipelining/schedulers_test.py new file mode 100644 index 000000000000..d48196ce664e --- /dev/null +++ b/tests/pallas/pipelining/schedulers_test.py @@ -0,0 +1,493 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax._src import test_util as jtu +from jax._src.pallas.pipelining import internal +from jax._src.pallas.pipelining import pipeline_test_util as test_util +from jax._src.pallas.pipelining import schedulers + + +jax.config.parse_flags_with_absl() + + +def empty_jaxpr(): + def noop(): + pass + return jax.make_jaxpr(noop) + + +class SchedulersGoldenTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + if not jtu.test_device_matches(["cpu"]): + self.skipTest("Only works on CPU") + + def test_2_async_stages(self): + # This test uses 2 stages that are both async. + # 1 + # | + # 2 + token1 = internal.make_token("a") + token2 = internal.make_token("b") + stage1_start = internal.PipelineStage( + jaxpr=empty_jaxpr(), + effects=(internal.WriteEffect(token1),), + properties=internal.SchedulingProperties( + max_in_flight=3, is_async_start=True, is_async_done=False), + name="stage1_start" + ) + stage1_done = internal.PipelineStage( + jaxpr=empty_jaxpr(), + effects=(internal.ReadEffect(token1), internal.WriteEffect(0)), + properties=internal.SchedulingProperties( + max_in_flight=3, is_async_start=False, is_async_done=True), + name="stage1_end" + ) + stage2_start = internal.PipelineStage( + jaxpr=empty_jaxpr(), + effects=(internal.WriteEffect(token2), + # We need to insert this token so that stage1_start + # does not clobber input 0. + internal.ReadEffect(token1), + internal.ReadEffect(0)), + properties=internal.SchedulingProperties( + max_in_flight=3, is_async_start=True, is_async_done=False), + name="stage2_start" + ) + stage2_done = internal.PipelineStage( + jaxpr=empty_jaxpr(), + effects=(internal.ReadEffect(token2),), + properties=internal.SchedulingProperties( + max_in_flight=3, is_async_start=False, is_async_done=True), + name="stage2_end" + ) + loop_struct = internal.NDLoopStruct( + stages=(stage1_start, stage1_done, stage2_start, stage2_done), + grid=(4,) + ) + with jtu.capture_stdout() as stdout: + schedulers.static_nd_loop_scheduler( + loop_struct, + args=(None,), + eval_fn=test_util.print_stage) + output = stdout().strip().split("\n") + expected = [ + "[itr=0] stage1_start", + "[itr=1] stage1_start", + "[itr=2] stage1_start", + "[itr=0] stage1_end", + "[itr=0] stage2_start", + "[itr=3] stage1_start", + "[itr=1] stage1_end", + "[itr=1] stage2_start", + "[itr=2] stage1_end", + "[itr=2] stage2_start", + "[itr=3] stage1_end", + "[itr=0] stage2_end", + "[itr=3] stage2_start", + "[itr=1] stage2_end", + "[itr=2] stage2_end", + "[itr=3] stage2_end", + ] + self.assertEqual(output, expected) + + def test_async_inputs_with_different_buffering(self): + # This test uses 2 input stages (1a, 1b) that feed into a synchronous stage. + # 1a 1b + # \ / + # 2 + token1a = internal.make_token("1a") + token1b = internal.make_token("1b") + stage1a_start = internal.PipelineStage( + jaxpr=empty_jaxpr(), + effects=(internal.WriteEffect(token1a),), + properties=internal.SchedulingProperties( + max_in_flight=2, is_async_start=True, is_async_done=False), + name="stage1a_start" + ) + stage1a_done = internal.PipelineStage( + jaxpr=empty_jaxpr(), + effects=(internal.ReadEffect(token1a), internal.WriteEffect(0)), + properties=internal.SchedulingProperties( + max_in_flight=2, is_async_start=False, is_async_done=True), + name="stage1a_end" + ) + stage1b_start = internal.PipelineStage( + jaxpr=empty_jaxpr(), + effects=(internal.WriteEffect(token1b),), + properties=internal.SchedulingProperties( + max_in_flight=4, is_async_start=True, is_async_done=False), + name="stage1b_start" + ) + stage1b_done = internal.PipelineStage( + jaxpr=empty_jaxpr(), + effects=(internal.ReadEffect(token1b), internal.WriteEffect(1)), + properties=internal.SchedulingProperties( + max_in_flight=4, is_async_start=False, is_async_done=True), + name="stage1b_end" + ) + stage2 = internal.PipelineStage( + jaxpr=empty_jaxpr(), + effects=(internal.ReadEffect(token1a), + internal.ReadEffect(token1b), + internal.ReadEffect(0), + internal.ReadEffect(1)), + properties=internal.SchedulingProperties( + max_in_flight=2, is_async_start=False, is_async_done=False), + name="stage2" + ) + loop_struct = internal.NDLoopStruct( + stages=(stage1a_start, stage1a_done, + stage1b_start, stage1b_done, + stage2,), + grid=(6,) + ) + with jtu.capture_stdout() as stdout: + schedulers.static_nd_loop_scheduler( + loop_struct, + args=(None, None), + eval_fn=test_util.print_stage) + output = stdout().strip().split("\n") + expected = [ + "[itr=0] stage1a_start", + "[itr=0] stage1b_start", + "[itr=1] stage1b_start", + "[itr=2] stage1b_start", + "[itr=1] stage1a_start", + "[itr=3] stage1b_start", + "[itr=0] stage1a_end", + "[itr=0] stage1b_end", + "[itr=0] stage2", + "[itr=2] stage1a_start", + "[itr=4] stage1b_start", + "[itr=1] stage1a_end", + "[itr=1] stage1b_end", + "[itr=1] stage2", + "[itr=3] stage1a_start", + "[itr=5] stage1b_start", + "[itr=2] stage1a_end", + "[itr=2] stage1b_end", + "[itr=2] stage2", + "[itr=4] stage1a_start", + "[itr=3] stage1b_end", + "[itr=3] stage1a_end", + "[itr=3] stage2", + "[itr=5] stage1a_start", + "[itr=4] stage1b_end", + "[itr=4] stage1a_end", + "[itr=4] stage2", + "[itr=5] stage1a_end", + "[itr=5] stage1b_end", + "[itr=5] stage2", + ] + self.assertEqual(output, expected) + + def test_synchronous_3_stage(self): + # This test models a 3-stage pipeline where all stages are synchronous. + # 1 + # | + # 2 + # | + # 3 + stage1 = internal.PipelineStage( + jaxpr=empty_jaxpr(), + effects=(internal.WriteEffect(0),), + properties=internal.SchedulingProperties( + max_in_flight=3, is_async_start=False, is_async_done=False), + name="stage1" + ) + stage2 = internal.PipelineStage( + jaxpr=empty_jaxpr(), + effects=(internal.ReadEffect(0), internal.WriteEffect(1),), + properties=internal.SchedulingProperties( + max_in_flight=3, is_async_start=False, is_async_done=False), + name="stage2" + ) + stage3 = internal.PipelineStage( + jaxpr=empty_jaxpr(), + effects=(internal.ReadEffect(1),), + properties=internal.SchedulingProperties( + max_in_flight=3, is_async_start=False, is_async_done=False), + name="stage3" + ) + loop_struct = internal.NDLoopStruct( + stages=(stage1, stage2, stage3), + grid=(4,) + ) + + with jtu.capture_stdout() as stdout: + schedulers.static_nd_loop_scheduler( + loop_struct, + args=(None, None), + eval_fn=test_util.print_stage) + output = stdout().strip().split("\n") + expected = [ + # step + "[itr=0] stage1", + # step + "[itr=1] stage1", + "[itr=0] stage2", + # step + "[itr=2] stage1", + "[itr=1] stage2", + "[itr=0] stage3", + # step + "[itr=3] stage1", + "[itr=2] stage2", + "[itr=1] stage3", + # step + "[itr=3] stage2", + "[itr=2] stage3", + # step + "[itr=3] stage3", + ] + self.assertEqual(output, expected) + + def test_standard_emit_pipeline(self): + # This test uses 3 stages where copy_in and copy_out are async. + # copy_in + # | + # body + # | + # copy_out + token1 = internal.make_token("a") + token2 = internal.make_token("b") + copy_in_start = internal.PipelineStage( + jaxpr=empty_jaxpr(), + effects=(internal.WriteEffect(token1),), + properties=internal.SchedulingProperties( + max_in_flight=2, is_async_start=True, is_async_done=False), + name="copy_in_start" + ) + copy_in_done = internal.PipelineStage( + jaxpr=empty_jaxpr(), + effects=(internal.ReadEffect(token1), internal.WriteEffect(0)), + properties=internal.SchedulingProperties( + max_in_flight=2, is_async_start=False, is_async_done=True), + name="copy_in_done" + ) + body_stage = internal.PipelineStage( + jaxpr=empty_jaxpr(), + effects=(internal.ReadEffect(token1), internal.ReadEffect(0), + internal.WriteEffect(1)), + properties=internal.SchedulingProperties( + max_in_flight=2, is_async_start=False, is_async_done=False), + name="body" + ) + copy_out_start = internal.PipelineStage( + jaxpr=empty_jaxpr(), + effects=(internal.WriteEffect(token2), + internal.ReadEffect(1)), + properties=internal.SchedulingProperties( + max_in_flight=2, is_async_start=True, is_async_done=False), + name="copy_out_start" + ) + copy_out_done = internal.PipelineStage( + jaxpr=empty_jaxpr(), + effects=(internal.ReadEffect(token2),), + properties=internal.SchedulingProperties( + max_in_flight=2, is_async_start=False, is_async_done=True), + name="copy_out_done" + ) + loop_struct = internal.NDLoopStruct( + stages=(copy_in_start, copy_in_done, body_stage, + copy_out_start, copy_out_done), + grid=(4, 4) + ) + with jtu.capture_stdout() as stdout: + schedulers.static_nd_loop_scheduler( + loop_struct, + args=(None,), + eval_fn=test_util.print_stage) + output = stdout().strip().split("\n") + prologue = [ + "[itr=0] copy_in_start", + "[itr=1] copy_in_start", + "[itr=0] copy_in_done", + "[itr=0] body", + "[itr=0] copy_out_start", + "[itr=2] copy_in_start", + "[itr=1] copy_in_done", + "[itr=1] body", + "[itr=1] copy_out_start", + ] + steady_state = [] + for itr in range(3, 16): + steady_state.extend([ + f"[itr={itr}] copy_in_start", + test_util.AnyOrder([ + f"[itr={itr-3}] copy_out_done", + f"[itr={itr-1}] copy_in_done",]), + f"[itr={itr-1}] body", + f"[itr={itr-1}] copy_out_start", + ]) + epilogue = [ + "[itr=15] copy_in_done", + "[itr=13] copy_out_done", + "[itr=15] body", + "[itr=15] copy_out_start", + "[itr=14] copy_out_done", + "[itr=15] copy_out_done", + ] + expected = prologue + steady_state + epilogue + list_equal = test_util.compare_lists(output, expected) + self.assertTrue(list_equal) + + def test_pipelined_prefetch(self): + # This test uses 4 stages where prefetch, copy_in and copy_out are async. + # prefetch + # | + # copy_in + # | + # body + # | + # copy_out + token1 = internal.make_token("a") + token2 = internal.make_token("b") + token_prefetch = internal.make_token("c") + prefetch_start = internal.PipelineStage( + jaxpr=empty_jaxpr(), + effects=(internal.WriteEffect(token_prefetch),), + properties=internal.SchedulingProperties( + max_in_flight=2, is_async_start=True, is_async_done=False), + name="prefetch_start" + ) + prefetch_done = internal.PipelineStage( + jaxpr=empty_jaxpr(), + effects=(internal.ReadEffect(token_prefetch), internal.WriteEffect(0)), + properties=internal.SchedulingProperties( + max_in_flight=2, is_async_start=False, is_async_done=True), + name="prefetch_done" + ) + copy_in_start = internal.PipelineStage( + jaxpr=empty_jaxpr(), + effects=(internal.ReadEffect(token_prefetch), + internal.ReadEffect(0), + internal.WriteEffect(token1),), + properties=internal.SchedulingProperties( + max_in_flight=2, is_async_start=True, is_async_done=False), + name="copy_in_start" + ) + copy_in_done = internal.PipelineStage( + jaxpr=empty_jaxpr(), + effects=(internal.ReadEffect(token1), internal.WriteEffect(1)), + properties=internal.SchedulingProperties( + max_in_flight=2, is_async_start=False, is_async_done=True), + name="copy_in_done" + ) + body_stage = internal.PipelineStage( + jaxpr=empty_jaxpr(), + effects=(internal.ReadEffect(token1), internal.ReadEffect(1), + internal.WriteEffect(2)), + properties=internal.SchedulingProperties( + max_in_flight=2, is_async_start=False, is_async_done=False), + name="body" + ) + copy_out_start = internal.PipelineStage( + jaxpr=empty_jaxpr(), + effects=(internal.WriteEffect(token2), + internal.ReadEffect(2)), + properties=internal.SchedulingProperties( + max_in_flight=2, is_async_start=True, is_async_done=False), + name="copy_out_start" + ) + copy_out_done = internal.PipelineStage( + jaxpr=empty_jaxpr(), + effects=(internal.ReadEffect(token2),), + properties=internal.SchedulingProperties( + max_in_flight=2, is_async_start=False, is_async_done=True), + name="copy_out_done" + ) + loop_struct = internal.NDLoopStruct( + stages=(prefetch_start, prefetch_done, + copy_in_start, copy_in_done, + body_stage, + copy_out_start, copy_out_done), + grid=(4, 4) + ) + with jtu.capture_stdout() as stdout: + schedulers.static_nd_loop_scheduler( + loop_struct, + args=(None,), + eval_fn=test_util.print_stage) + output = stdout().strip().split("\n") + # The schedule is slightly suboptimal here, noted in the comments. + prologue = [ + "[itr=0] prefetch_start", + "[itr=1] prefetch_start", + "[itr=0] prefetch_done", + "[itr=0] copy_in_start", + "[itr=2] prefetch_start", + "[itr=1] prefetch_done", + "[itr=1] copy_in_start", + "[itr=3] prefetch_start", + "[itr=0] copy_in_done", + # This can be pushed after body, before [itr=2] copy_in_start + "[itr=2] prefetch_done", + "[itr=0] body", + "[itr=0] copy_out_start", + "[itr=2] copy_in_start", + "[itr=4] prefetch_start", + "[itr=1] copy_in_done", + # This can be pushed after body, before [itr=2] copy_in_start + "[itr=3] prefetch_done", + "[itr=1] body", + "[itr=1] copy_out_start", + "[itr=3] copy_in_start", + # This can be pushed after [itr=5] prefetch_start + "[itr=0] copy_out_done", + "[itr=5] prefetch_start", + "[itr=2] copy_in_done", + "[itr=4] prefetch_done", + "[itr=2] body", + "[itr=2] copy_out_start", + ] + steady_state = [] + for i in range(6, 16): + steady_state.extend([ + f"[itr={i-2}] copy_in_start", + f"[itr={i-5}] copy_out_done", + f"[itr={i}] prefetch_start", + f"[itr={i-3}] copy_in_done", + f"[itr={i-1}] prefetch_done", + f"[itr={i-3}] body", + f"[itr={i-3}] copy_out_start", + ]) + epilogue = [ + "[itr=14] copy_in_start", + "[itr=11] copy_out_done", + "[itr=15] prefetch_done", + "[itr=13] copy_in_done", + "[itr=13] body", + "[itr=13] copy_out_start", + "[itr=15] copy_in_start", + "[itr=12] copy_out_done", + "[itr=14] copy_in_done", + "[itr=14] body", + "[itr=14] copy_out_start", + "[itr=15] copy_in_done", + "[itr=13] copy_out_done", + "[itr=15] body", + "[itr=15] copy_out_start", + "[itr=14] copy_out_done", + "[itr=15] copy_out_done", + ] + expected = prologue + steady_state + epilogue + self.assertEqual(output, expected) + +if __name__ == "__main__": + absltest.main() diff --git a/tests/pallas/tpu_all_gather_test.py b/tests/pallas/tpu_all_gather_test.py index 98b3e5b40135..72477e7230ce 100644 --- a/tests/pallas/tpu_all_gather_test.py +++ b/tests/pallas/tpu_all_gather_test.py @@ -20,119 +20,115 @@ from jax import random from jax._src import test_util as jtu from jax.experimental import mesh_utils +from jax.experimental import pallas as pl from jax.experimental.pallas import tpu as pltpu from jax.experimental.pallas.ops.tpu import all_gather import jax.numpy as jnp import numpy as np -try: - import hypothesis as hp - import hypothesis.strategies as hps - CAN_USE_HYPOTHESIS = True -except (ModuleNotFoundError, ImportError): - CAN_USE_HYPOTHESIS = False +import hypothesis as hp +import hypothesis.strategies as hps jax.config.parse_flags_with_absl() P = jax.sharding.PartitionSpec -if CAN_USE_HYPOTHESIS: - - hp.settings.register_profile( - "deterministic", - database=None, - derandomize=True, - deadline=None, - max_examples=50, - print_blob=True, - verbosity=hp.Verbosity.verbose, +hp.settings.register_profile( + "deterministic", + database=None, + derandomize=True, + deadline=None, + max_examples=50, + print_blob=True, + verbosity=hp.Verbosity.verbose, +) +hp.settings.load_profile("deterministic") + + +@hps.composite +def _array_shapes(draw): + # TODO(sharadmv, apaszke): enable this on a wider variety of shapes + valid_shapes = [ + (128, 128), + (256, 128), + (256, 512), + (256, 1024), + # TODO(sharadmv,apaszke): enable these shapes + # (256, 129), + # (129, 128), + # (64, 64), + # (1, 1), + ] + return draw(hps.sampled_from(valid_shapes)) + + +@hps.composite +def _array_dtypes(draw): + return draw( + hps.sampled_from([ + jnp.float32, + jnp.bfloat16, + jnp.int32, + # jnp.float16, # TODO(sharadmv,apaszke): enable float16 all gather + # jnp.int16, # TODO(sharadmv,apaszke): enable int16 all gather + # jnp.int8, # TODO(sharadmv,apaszke): enable int8 all gather + ]) ) - hp.settings.load_profile("deterministic") - - - @hps.composite - def _array_shapes(draw): - # TODO(sharadmv, apaszke): enable this on a wider variety of shapes - valid_shapes = [ - (128, 128), - (256, 128), - (256, 512), - (256, 1024), - # TODO(sharadmv,apaszke): enable these shapes - # (256, 129), - # (129, 128), - # (64, 64), - # (1, 1), - ] - return draw(hps.sampled_from(valid_shapes)) - - - @hps.composite - def _array_dtypes(draw): - return draw( - hps.sampled_from([ - jnp.float32, - jnp.bfloat16, - jnp.int32, - # jnp.float16, # TODO(sharadmv,apaszke): enable float16 all gather - # jnp.int16, # TODO(sharadmv,apaszke): enable int16 all gather - # jnp.int8, # TODO(sharadmv,apaszke): enable int8 all gather - ]) - ) - class AllGatherTest(jtu.JaxTestCase): - - def setUp(self): - if not jtu.test_device_matches(["tpu"]): - self.skipTest("Need TPU devices") - if not jtu.is_device_tpu(version=5, variant="e"): - # TODO(sharadmv,apaszke): expand support to more versions - self.skipTest("Currently only supported on TPU v5e") - - super().setUp() - - @hp.given(hps.booleans(), _array_shapes(), _array_dtypes()) - def test_all_gather_1d_mesh(self, is_vmem, shape, dtype): - if jax.device_count() < 2: - self.skipTest("Need more devices") - memory_space = pltpu.VMEM if is_vmem else pltpu.ANY - mesh_shape = (jax.device_count(),) - mesh = jax.sharding.Mesh( - mesh_utils.create_device_mesh(mesh_shape, jax.devices()), ["x"] - ) - leading, *rest = shape - shape = (mesh.shape["x"] * leading, *rest) - x = random.normal(random.key(0), shape, dtype=jnp.float32).astype(dtype) - x_sharded = jax.device_put(x, jax.sharding.NamedSharding(mesh, P("x"))) - y = all_gather.all_gather(x_sharded, mesh=mesh, axis_name="x", - memory_space=memory_space) - np.testing.assert_array_equal(y, x) - - @hp.given(hps.booleans(), _array_shapes(), _array_dtypes(), - hps.sampled_from(["x", "y"])) - def test_all_gather_2d_mesh(self, is_vmem, shape, dtype, - axis_name): - if jax.device_count() < 2: - self.skipTest("Need more devices") - if jax.device_count() % 2: - self.skipTest("Need an even number of devices") - memory_space = pltpu.VMEM if is_vmem else pltpu.ANY - mesh_shape = (2, jax.device_count() // 2) - mesh = jax.sharding.Mesh( - mesh_utils.create_device_mesh(mesh_shape, jax.devices()), ["x", "y"] - ) - if axis_name == "x": - sharding = jax.sharding.NamedSharding(mesh, P("x", None)) - else: - sharding = jax.sharding.NamedSharding(mesh, P("y", None)) - leading, *rest = shape - shape = (mesh.shape[axis_name] * leading, *rest) - x = random.normal(random.key(0), shape, dtype=jnp.float32).astype(dtype) - x_sharded = jax.device_put(x, sharding) - y = all_gather.all_gather(x_sharded, mesh=mesh, axis_name=axis_name, - memory_space=memory_space) - np.testing.assert_array_equal(y, x) +@jtu.thread_unsafe_test_class() # hypothesis is not thread safe +class AllGatherTest(jtu.JaxTestCase): + + def setUp(self): + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Need TPU devices") + if not jtu.is_device_tpu(version=5, variant="e"): + # TODO(sharadmv,apaszke): expand support to more versions + self.skipTest("Currently only supported on TPU v5e") + + super().setUp() + + @hp.given(hps.booleans(), _array_shapes(), _array_dtypes()) + def test_all_gather_1d_mesh(self, is_vmem, shape, dtype): + if jax.device_count() < 2: + self.skipTest("Need more devices") + memory_space = pltpu.VMEM if is_vmem else pl.ANY + mesh_shape = (jax.device_count(),) + mesh = jax.sharding.Mesh( + mesh_utils.create_device_mesh(mesh_shape, jax.devices()), ["x"] + ) + leading, *rest = shape + shape = (mesh.shape["x"] * leading, *rest) + x = random.normal(random.key(0), shape, dtype=jnp.float32).astype(dtype) + x_sharded = jax.device_put(x, jax.sharding.NamedSharding(mesh, P("x"))) + y = all_gather.all_gather(x_sharded, mesh=mesh, axis_name="x", + memory_space=memory_space) + np.testing.assert_array_equal(y, x) + + @hp.given(hps.booleans(), _array_shapes(), _array_dtypes(), + hps.sampled_from(["x", "y"])) + def test_all_gather_2d_mesh(self, is_vmem, shape, dtype, + axis_name): + if jax.device_count() < 2: + self.skipTest("Need more devices") + if jax.device_count() % 2: + self.skipTest("Need an even number of devices") + memory_space = pltpu.VMEM if is_vmem else pl.ANY + mesh_shape = (2, jax.device_count() // 2) + mesh = jax.sharding.Mesh( + mesh_utils.create_device_mesh(mesh_shape, jax.devices()), ["x", "y"] + ) + if axis_name == "x": + sharding = jax.sharding.NamedSharding(mesh, P("x", None)) + else: + sharding = jax.sharding.NamedSharding(mesh, P("y", None)) + leading, *rest = shape + shape = (mesh.shape[axis_name] * leading, *rest) + x = random.normal(random.key(0), shape, dtype=jnp.float32).astype(dtype) + x_sharded = jax.device_put(x, sharding) + y = all_gather.all_gather(x_sharded, mesh=mesh, axis_name=axis_name, + memory_space=memory_space) + np.testing.assert_array_equal(y, x) if __name__ == "__main__": diff --git a/tests/pallas/tpu_fusable_matmul_test.py b/tests/pallas/tpu_fusible_matmul_test.py similarity index 76% rename from tests/pallas/tpu_fusable_matmul_test.py rename to tests/pallas/tpu_fusible_matmul_test.py index df7c1221bb0c..45b8537a26dc 100644 --- a/tests/pallas/tpu_fusable_matmul_test.py +++ b/tests/pallas/tpu_fusible_matmul_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Fusable matmul test.""" +"""Fusible matmul test.""" import functools from typing import Any @@ -21,8 +21,8 @@ from absl.testing import parameterized import jax from jax._src import test_util as jtu -from jax._src.pallas import fuser from jax.experimental import pallas as pl +from jax.experimental.pallas import fuser from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp import numpy as np @@ -34,7 +34,6 @@ ) -@jit_no_excess_precision def mm_ref(x, y): return jnp.dot(x, y, preferred_element_type=jnp.float32) @@ -71,10 +70,11 @@ def _(): def _(): acc = acc_ref[...].astype(out_dtype) z_values = jax.tree.map(lambda ref: ref.get(), z_value_refs) - o_ref[...] = z_fn(pids, scalar_prefetch, z_values, acc) + out = z_fn(pids, scalar_prefetch, z_values, acc) + jax.tree.map(lambda ref, x: ref.set(x), o_ref, out) -def _fusable_matmul( +def _fusible_matmul( x: fuser.Fusion[[], jax.Array], # pytype: disable=invalid-annotation y: fuser.Fusion[[], jax.Array], # pytype: disable=invalid-annotation z: fuser.Fusion[[jax.Array], jax.Array] | None, # pytype: disable=invalid-annotation @@ -174,23 +174,18 @@ def z_index_map(i, j, k, *_): y_value_block_specs, z_value_block_specs, ], - out_specs=z_out_block_spec, + out_specs=[z_out_block_spec], ), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=dimension_semantics, ), - out_shape=z_out_type, + out_shape=[z_out_type], interpret=interpret, debug=debug, - )( - *scalar_prefetch, - x_values, - y_values, - z_values, - ) + )(*scalar_prefetch, x_values, y_values, z_values,)[0] -def fusable_matmul( +def fusible_matmul( x: jax.Array, y: jax.Array, *, @@ -200,9 +195,9 @@ def fusable_matmul( debug: bool = False, interpret: bool = False, ) -> jax.Array: - return fuser.fusable( + return fuser.fusible( functools.partial( - _fusable_matmul, + _fusible_matmul, bm=bm, bk=bk, bn=bn, @@ -212,7 +207,7 @@ def fusable_matmul( )(x, y) -class FusableMatmulTest(jtu.JaxTestCase): +class FusibleMatmulTest(jtu.JaxTestCase): def setUp(self): if not jtu.is_device_tpu_at_least(4): @@ -225,7 +220,9 @@ def test_matmul(self, dtype): x = jax.random.normal(k0, (512, 512), dtype) y = jax.random.normal(k1, (512, 512), dtype) np.testing.assert_allclose( - jax.jit(fusable_matmul)(x, y), mm_ref(x, y), atol=5e-5 + jax.jit(fusible_matmul)(x, y), + jit_no_excess_precision(mm_ref)(x, y), + atol=5e-5, ) @parameterized.parameters('float32', 'bfloat16') @@ -237,12 +234,104 @@ def test_matmul_with_activation(self, dtype): @jax.jit @fuser.fuse def matmul_relu(x, y): - x = fusable_matmul(x, y) + x = fusible_matmul(x, y) x = jnp.maximum(x, 0.0) return x + @jit_no_excess_precision + def matmul_relu_ref(x, y): + return jax.nn.relu(mm_ref(x, y)) + np.testing.assert_allclose( - matmul_relu(x, y), jax.nn.relu(mm_ref(x, y)), atol=5e-5 + matmul_relu(x, y), matmul_relu_ref(x, y), atol=5e-5 + ) + + def test_matmul_reduce_sum(self): + if not jtu.is_device_tpu_at_least(5): + self.skipTest('Only works with TPU v5+') + dtype = jnp.float32 + k0, k1 = jax.random.split(jax.random.key(0)) + x = jax.random.normal(k0, (512, 512), dtype) + y = jax.random.normal(k1, (512, 512), dtype) + + @jax.jit + @fuser.fuse + def matmul_relu(x, y): + x = fusible_matmul(x, y, bm=512, bn=512) + x = jnp.sum(x, axis=1) + return x + + @jit_no_excess_precision + def matmul_relu_ref(x, y): + return mm_ref(x, y).sum(axis=1) + + np.testing.assert_allclose( + matmul_relu(x, y), matmul_relu_ref(x, y), atol=1e-3 + ) + + def test_matmul_reduce_sum_broadcast(self): + if not jtu.is_device_tpu_at_least(5): + self.skipTest('Only works with TPU v5+') + dtype = jnp.float32 + k0, k1 = jax.random.split(jax.random.key(0)) + x = jax.random.normal(k0, (512, 512), dtype) + y = jax.random.normal(k1, (512, 512), dtype) + + @jax.jit + @fuser.fuse + def matmul_sum_bcast(x, y): + x = fusible_matmul(x, y, bm=512, bn=512) + x = jnp.sum(x, axis=1, keepdims=True) + x = jnp.broadcast_to(x, (x.shape[0], 128)) + return x + + @jit_no_excess_precision + def matmul_sum_bcast_ref(x, y): + return jnp.broadcast_to( + mm_ref(x, y).sum(axis=1, keepdims=True), (x.shape[0], 128) + ) + + np.testing.assert_allclose( + matmul_sum_bcast(x, y), + matmul_sum_bcast_ref(x, y), + atol=1e-3, + ) + + @parameterized.parameters('float32', 'bfloat16') + def test_matmul_plus_iota_custom_fusion(self, dtype): + def make_iota_custom_fusion(shape, dtype): + @fuser.custom_fusion + def iota(start=0): + return jnp.broadcast_to( + jnp.astype(jnp.arange(shape[-1]) + start, dtype), shape + ) + + iota.def_pull_block_spec(lambda bss: (None,)) + + @iota.def_eval_rule + def iota_eval_rule(ctx, _): + shape = ctx.out_block_specs[0].block_shape + i = ctx.out_block_indices[0][-1] + return (make_iota_custom_fusion(shape, dtype)(shape[-1] * i),) + + return iota + + iota_custom_fusion = make_iota_custom_fusion((512, 512), dtype) + + @jax.jit + @fuser.fuse + def matmul_plus_iota(x, y): + return fusible_matmul(x, y).astype(dtype) + iota_custom_fusion() + + @jit_no_excess_precision + def matmul_plus_iota_ref(x, y): + return mm_ref(x, y).astype(dtype) + iota_custom_fusion() + + k0, k1 = jax.random.split(jax.random.key(0)) + x = jax.random.normal(k0, (512, 512), dtype) + y = jax.random.normal(k1, (512, 512), dtype) + np.testing.assert_allclose( + matmul_plus_iota(x, y), matmul_plus_iota_ref(x, y), atol=5e-5 ) @parameterized.parameters('float32', 'bfloat16') @@ -257,13 +346,17 @@ def test_matmul_with_bias(self, dtype): @jax.jit @fuser.fuse def matmul_bias(x, y, b): - x = fusable_matmul(x, y).astype(dtype) + b + x = fusible_matmul(x, y).astype(dtype) + b x = jnp.maximum(x, 0.0) return x + @jit_no_excess_precision + def matmul_bias_ref(x, y, b): + return jax.nn.relu(mm_ref(x, y).astype(dtype) + b) + np.testing.assert_allclose( matmul_bias(x, y, b), - jax.nn.relu(mm_ref(x, y).astype(dtype) + b), + matmul_bias_ref(x, y, b), atol=5e-5 if dtype == 'float32' else 0.5, ) @@ -276,10 +369,16 @@ def test_matmul_with_slice(self, dtype): @jax.jit @fuser.fuse def matmul_slice(x, y): - x = fusable_matmul(x, y[1]) + x = fusible_matmul(x, y[1]) return x - np.testing.assert_allclose(matmul_slice(x, y), mm_ref(x, y[1]), atol=5e-5) + @jit_no_excess_precision + def matmul_slice_ref(x, y): + return mm_ref(x, y[1]) + + np.testing.assert_allclose( + matmul_slice(x, y), matmul_slice_ref(x, y), atol=5e-5 + ) @parameterized.parameters('float32', 'bfloat16') def test_matmul_with_dynamic_slice(self, dtype): @@ -290,11 +389,15 @@ def test_matmul_with_dynamic_slice(self, dtype): @jax.jit @fuser.fuse def matmul_slice(x, y, i): - x = fusable_matmul(x, y[i]) + x = fusible_matmul(x, y[i]) return x + @jit_no_excess_precision + def matmul_slice_ref(x, y, i): + return mm_ref(x, y[i]) + np.testing.assert_allclose( - matmul_slice(x, y, 1), mm_ref(x, y[1]), atol=5e-5 + matmul_slice(x, y, 1), matmul_slice_ref(x, y, 1), atol=5e-5 ) @parameterized.parameters('float32', 'bfloat16') @@ -307,12 +410,16 @@ def test_matmul_with_dynamic_slice_bias(self, dtype): @jax.jit @fuser.fuse def matmul_slice(x, y, b, i, j): - x = fusable_matmul(x, y[j]).astype(dtype) + b[i] + x = fusible_matmul(x, y[j]).astype(dtype) + b[i] return x + @jit_no_excess_precision + def matmul_slice_ref(x, y, b, i, j): + return mm_ref(x, y[j]).astype(dtype) + b[i] + np.testing.assert_allclose( matmul_slice(x, y, b, 1, 2), - mm_ref(x, y[2]).astype(dtype) + b[1], + matmul_slice_ref(x, y, b, 1, 2), atol=5e-5 if dtype == 'float32' else 0.5, ) @@ -325,11 +432,15 @@ def test_matmul_with_multi_slice(self, dtype): @jax.jit @fuser.fuse def matmul_slice(x, y): - x = fusable_matmul(x, y[1, 1]) + x = fusible_matmul(x, y[1, 1]) return x + @jit_no_excess_precision + def matmul_slice_ref(x, y): + return mm_ref(x, y[1, 1]) + np.testing.assert_allclose( - matmul_slice(x, y), mm_ref(x, y[1, 1]), atol=5e-5 + matmul_slice(x, y), matmul_slice_ref(x, y), atol=5e-5 ) @parameterized.parameters('float32', 'bfloat16') @@ -341,11 +452,15 @@ def test_matmul_with_multiple_slices(self, dtype): @jax.jit @fuser.fuse def matmul_slice(x, y): - x = fusable_matmul(x, y[1][1]) + x = fusible_matmul(x, y[1][1]) return x + @jit_no_excess_precision + def matmul_slice_ref(x, y): + return mm_ref(x, y[1][1]) + np.testing.assert_allclose( - matmul_slice(x, y), mm_ref(x, y[1, 1]), atol=5e-5 + matmul_slice(x, y), matmul_slice_ref(x, y), atol=5e-5 ) @parameterized.parameters('float32', 'bfloat16') @@ -357,13 +472,17 @@ def test_matmul_with_multiple_dynamic_slices(self, dtype): @jax.jit @fuser.fuse def matmul_slice(x, y, i, j): - x = fusable_matmul(x, y[i][j]) + x = fusible_matmul(x, y[i][j]) return x + @jit_no_excess_precision + def matmul_slice_ref(x, y, i, j): + return mm_ref(x, y[i][j]) + for i in range(2): for j in range(3): np.testing.assert_allclose( - matmul_slice(x, y, i, j), mm_ref(x, y[i, j]), atol=5e-5 + matmul_slice(x, y, i, j), matmul_slice_ref(x, y, i, j), atol=5e-5 ) @parameterized.parameters('float32', 'bfloat16') @@ -375,13 +494,17 @@ def test_matmul_with_mixed_slices(self, dtype): @jax.jit @fuser.fuse def matmul_slice(x, y, i, j): - x = fusable_matmul(x, y[2][i, j]) + x = fusible_matmul(x, y[2][i, j]) return x + @jit_no_excess_precision + def matmul_slice_ref(x, y, i, j): + return mm_ref(x, y[2, i, j]) + for i in range(2): for j in range(3): np.testing.assert_allclose( - matmul_slice(x, y, i, j), mm_ref(x, y[2, i, j]), atol=5e-5 + matmul_slice(x, y, i, j), matmul_slice_ref(x, y, i, j), atol=5e-5 ) @parameterized.parameters('float32', 'bfloat16') @@ -396,7 +519,7 @@ def test_matmul_with_multiple_mixed_slices_and_bias(self, dtype): @jax.jit @fuser.fuse def matmul_slice(x, y, b, i, j, k): - x = fusable_matmul(x[k][3], y[2][i, j]).astype(dtype) + x = fusible_matmul(x[k][3], y[2][i, j]).astype(dtype) return x + b[i, j] @jit_no_excess_precision @@ -415,7 +538,7 @@ def matmul_slice_ref(x, y, b, i, j, k): @parameterized.parameters('float32', 'bfloat16') def test_matmul_input_concat_output(self, dtype): - self.skipTest('select_n doesnt support more than 3 elements') + self.skipTest('select_n does not support more than 3 elements') # TODO(sharadmv): fix this test k0, k1, k2, k3 = jax.random.split(jax.random.key(0), 4) x = jax.random.normal(k0, (128, 128), dtype) @@ -427,10 +550,10 @@ def test_matmul_input_concat_output(self, dtype): @fuser.fuse def matmul_concat(x, ys): y = jnp.concatenate(ys, axis=1) - x = fusable_matmul(x, y) + x = fusible_matmul(x, y) return x - @jax.jit + @jit_no_excess_precision def matmul_concat_ref(x, ys): y = jnp.concatenate(ys, axis=1) return mm_ref(x, y) @@ -453,7 +576,7 @@ def test_matmul_input_concat_contract(self, dtype): @fuser.fuse def matmul_concat(x, ys): y = jnp.concatenate(ys, axis=0) - x = fusable_matmul(x, y) + x = fusible_matmul(x, y) return x @jit_no_excess_precision @@ -481,7 +604,7 @@ def test_matmul_double_concat(self, dtype): def matmul_concat(x, ys, y3): y = jnp.concatenate(ys, axis=0) y = jnp.concatenate([y, y3], axis=1) - x = fusable_matmul(x, y) + x = fusible_matmul(x, y) return x @jit_no_excess_precision @@ -508,7 +631,7 @@ def test_matmul_slice_concat(self, dtype): @fuser.fuse def matmul_concat(x, y1, y2): y = jnp.concatenate([y1, y2[3]], axis=0) - x = fusable_matmul(x, y) + x = fusible_matmul(x, y) return x @jit_no_excess_precision @@ -533,7 +656,7 @@ def test_matmul_slice_concat_slice(self, dtype): @fuser.fuse def matmul_concat(x, y1, y2): y = jnp.concatenate([y1, y2[3]], axis=1)[1] - x = fusable_matmul(x, y) + x = fusible_matmul(x, y) return x @jit_no_excess_precision @@ -558,7 +681,7 @@ def test_matmul_dynamic_slice_concat(self, dtype): @fuser.fuse def matmul_concat(x, y1, y2, i, j): y = jnp.concatenate([y1, y2[i]], axis=1)[j] - x = fusable_matmul(x, y) + x = fusible_matmul(x, y) return x @jit_no_excess_precision @@ -584,7 +707,7 @@ def matmul(impl, x, y): return z impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -606,7 +729,7 @@ def matmul(impl, x, y): return z impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -628,7 +751,7 @@ def matmul(impl, x, y): return z impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -650,7 +773,7 @@ def matmul(impl, x, y): return z impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -672,7 +795,7 @@ def matmul(impl, x, y): return z impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -694,7 +817,7 @@ def matmul(impl, x, y): return z impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -715,7 +838,7 @@ def matmul(impl, x, y): return z impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -737,7 +860,7 @@ def matmul(impl, x, y): return z impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bm=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bm=256)) ) ref = functools.partial(matmul, mm_ref) @@ -759,7 +882,7 @@ def matmul(impl, x, y): return z impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bm=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bm=256)) ) ref = functools.partial(matmul, mm_ref) @@ -781,7 +904,7 @@ def matmul(impl, x, y): return z.T impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -791,6 +914,87 @@ def matmul(impl, x, y): atol=5e-5, ) + @parameterized.parameters('float32', 'bfloat16') + def test_matmul_out_custom_vjp_fwd(self, dtype): + k0, k1 = jax.random.split(jax.random.key(0), 2) + x = jax.random.normal(k0, (256, 256), dtype) + y = jax.random.normal(k1, (256, 512), dtype) + + @jax.custom_vjp + def act(x): + return jax.nn.relu(x) * x + + def act_fwd(x): + del x + assert False, 'unreachable' + + def act_bwd(res, dy): + del res, dy + assert False, 'unreachable' + + act.defvjp(act_fwd, act_bwd) + + def matmul(impl, x, y): + z = impl(x, y) + return act(z) + + impl = fuser.fuse( + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) + ) + ref = functools.partial(matmul, mm_ref) + self.assertAllClose( + jax.jit(impl)(x, y), + jax.jit(ref)(x, y), + atol=5e-5, + ) + + @parameterized.parameters('float32', 'bfloat16') + def test_matmul_out_custom_vjp_bwd(self, dtype): + k0, k1, k2 = jax.random.split(jax.random.key(0), 3) + x = jax.random.normal(k0, (256, 256), dtype) + y = jax.random.normal(k1, (256, 512), dtype) + dz = jax.random.normal(k2, (256, 512), dtype) + + @jax.custom_vjp + def act(x): + return jax.nn.relu(x) * x + + def act_fwd(x): + return jax.nn.relu(x) * x, (x,) + + def act_bwd(res, dy): + (x,) = res + return (dy * x * 2.34,) + + act.defvjp(act_fwd, act_bwd) + + def matmul(impl, x, y, dz): + z = impl(x, y) + dz = dz.astype(z.dtype) + return jax.vjp(act, z)[1](dz)[0].astype(dtype) + + impl = fuser.fuse( + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) + ) + ref = functools.partial(matmul, mm_ref) + out_dz = jax.jit(impl)(x, y, dz) + out_ref_dz = jax.jit(ref)(x, y, dz) + expected_dz = ( + dz.astype(jnp.float32) + * 2.34 + * jnp.dot(x, y, preferred_element_type=jnp.float32) + ).astype(dtype) + self.assertAllClose( + out_dz, + expected_dz, + atol=5e-5, + ) + self.assertAllClose( + out_dz, + out_ref_dz, + atol=5e-5, + ) + @parameterized.parameters('float32', 'bfloat16') def test_matmul_out_transpose_mul(self, dtype): k0, k1 = jax.random.split(jax.random.key(0), 2) @@ -803,7 +1007,7 @@ def matmul(impl, x, y): return z.T * 2 impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -866,7 +1070,7 @@ def matmul(impl, x, y): impl = fuser.fuse( functools.partial( matmul, - fusable_matmul, + fusible_matmul, ) ) ref = functools.partial(matmul, dot_ref) @@ -892,7 +1096,7 @@ def matmul(impl, x, y): out_ref = jit_no_excess_precision(ref)(x, y) - impl = fuser.fuse(functools.partial(matmul, fusable_matmul)) + impl = fuser.fuse(functools.partial(matmul, fusible_matmul)) out = jax.jit(impl)(x, y) self.assertAllClose(out, out_ref, atol=0) @@ -907,16 +1111,14 @@ def matmul(impl, x, y): z = impl(x, y) return z - ref = functools.partial( - matmul, mm_ref - ) + ref = functools.partial(matmul, mm_ref) out_ref = jit_no_excess_precision(ref)(x, y) impl = fuser.fuse( functools.partial( matmul, - functools.partial(fusable_matmul, bk=256, bn=128), + functools.partial(fusible_matmul, bk=256, bn=128), ) ) out = jax.jit(impl)(x, y) @@ -924,7 +1126,7 @@ def matmul(impl, x, y): atol = 0 if jtu.is_device_tpu_at_least(6): # 256 MXU changes some tols. - atol = 1e-6 + atol = 1e-5 self.assertAllClose(out, out_ref, atol=atol) def test_matmul_f32_out_fused_downcast(self): @@ -952,7 +1154,7 @@ def matmul(impl, x, y): functools.partial( matmul, functools.partial( - fusable_matmul, + fusible_matmul, bm=bm, bk=bk, bn=bn, @@ -989,7 +1191,7 @@ def matmul(impl, x, y): functools.partial( matmul, functools.partial( - fusable_matmul, + fusible_matmul, bm=bm, bk=bk, bn=bn, @@ -1024,7 +1226,7 @@ def matmul(impl, x, y): functools.partial( matmul, functools.partial( - fusable_matmul, + fusible_matmul, bm=bm, bk=bk, bn=bn, diff --git a/tests/pallas/tpu_gmm_test.py b/tests/pallas/tpu_gmm_test.py index 9c416dabaeb1..464c1498f4b1 100644 --- a/tests/pallas/tpu_gmm_test.py +++ b/tests/pallas/tpu_gmm_test.py @@ -24,12 +24,8 @@ import jax.numpy as jnp import numpy as np -try: - import hypothesis as hp - import hypothesis.strategies as hps - CAN_USE_HYPOTHESIS = True -except (ModuleNotFoundError, ImportError): - CAN_USE_HYPOTHESIS = False +import hypothesis as hp +import hypothesis.strategies as hps jax.config.parse_flags_with_absl() @@ -37,326 +33,328 @@ partial = functools.partial -if CAN_USE_HYPOTHESIS: - hp.settings.register_profile( - "deterministic", - database=None, - derandomize=True, - deadline=None, - max_examples=10, - print_blob=True, +hp.settings.register_profile( + "deterministic", + database=None, + derandomize=True, + deadline=None, + max_examples=10, + print_blob=True, +) +hp.settings.load_profile("deterministic") + +def seed_strategy() -> hps.SearchStrategy[int]: + return hps.integers(min_value=0, max_value=4) + +@hps.composite +def group_strategy( + draw: hps.DrawFn, + max_groups: int = 32, + max_stride: int = 32, + min_groups: int = 1, +) -> tuple[int, int]: + assert max_stride <= max_groups + + # Sample the number of groups owned by each shard. + group_stride = draw(hps.integers(min_value=1, max_value=max_stride)) + + # Sample the number of groups as a multiple of the stride to ensure that we + # have an equal number of groups per shard. Round down s.t. num_groups <= + # max_groups. + num_groups = group_stride * draw( + hps.integers(min_value=min_groups, max_value=max_groups // group_stride) ) - hp.settings.load_profile("deterministic") - - def seed_strategy() -> hps.SearchStrategy[int]: - return hps.integers(min_value=0, max_value=4) - - @hps.composite - def group_strategy( - draw: hps.DrawFn, - max_groups: int = 32, - max_stride: int = 32, - min_groups: int = 1, - ) -> tuple[int, int]: - assert max_stride <= max_groups - - # Sample the number of groups owned by each shard. - group_stride = draw(hps.integers(min_value=1, max_value=max_stride)) - - # Sample the number of groups as a multiple of the stride to ensure that we - # have an equal number of groups per shard. Round down s.t. num_groups <= - # max_groups. - num_groups = group_stride * draw( - hps.integers(min_value=min_groups, max_value=max_groups // group_stride) - ) - return num_groups, group_stride - - @hps.composite - def group_sizes_strategy( - draw: hps.DrawFn, m: int, num_groups: int - ) -> jnp.ndarray: - # Randomly sample the ends of the groups in the m-dimension. Let the fuzzer - # sample with replacement so that it's possible to get zero-sized groups. Get - # 'num_groups - 1' run ends. The final group will end at 'm'. - ends_no_final = np.sort( - np.array( - [ - draw(hps.integers(min_value=0, max_value=m)) - for _ in range(num_groups - 1) - ], - dtype=np.int32, - ), + return num_groups, group_stride + +@hps.composite +def group_sizes_strategy( + draw: hps.DrawFn, m: int, num_groups: int +) -> jnp.ndarray: + # Randomly sample the ends of the groups in the m-dimension. Let the fuzzer + # sample with replacement so that it's possible to get zero-sized groups. Get + # 'num_groups - 1' run ends. The final group will end at 'm'. + ends_no_final = np.sort( + np.array( + [ + draw(hps.integers(min_value=0, max_value=m)) + for _ in range(num_groups - 1) + ], + dtype=np.int32, + ), + ) + ends = np.concatenate([ends_no_final, np.array([m], dtype=np.int32)]) + + # Calculate the run starts by shifting ends 1 to the right. The first run + # starts at zero. + starts = np.concatenate([np.zeros(1, dtype=np.int32), ends_no_final]) + return jnp.array(ends - starts, dtype=jnp.int32) + +GROUPED_MATMUL_TESTS = ( + (128, 128, 128), # Small + (512, 2048, 256), # Big + (128, 8, 16), # Test partial tiles. +) + +def random_dense( + shape: tuple[int, ...], + key: jax.Array, + dtype: jnp.dtype, + limit: int | None = None, +) -> jnp.ndarray: + if limit is None: + limit = 1 / np.prod(shape) + x = jax.random.uniform(key, shape, dtype, minval=-limit, maxval=limit) # pylint: disable=invalid-unary-operand-type + return x.astype(jnp.bfloat16).astype(dtype) + +def dot( + lhs: jnp.ndarray, + rhs: jnp.ndarray, + transpose_lhs: bool = False, + transpose_rhs: bool = False, + preferred_element_type: jnp.dtype = jnp.float32, +) -> jnp.ndarray: + lhs = jnp.transpose(lhs) if transpose_lhs else lhs + rhs = jnp.transpose(rhs) if transpose_rhs else rhs + return jax.lax.dot(lhs, rhs, preferred_element_type=preferred_element_type) + +def reference_gmm( + lhs: jnp.ndarray, + rhs: jnp.ndarray, + group_sizes: jnp.ndarray, + preferred_element_type: jnp.dtype = jnp.float32, +) -> jnp.ndarray: + + start = 0 + out = [] + for i, size in enumerate(group_sizes): + result = dot( + lhs[start : start + size, :], + rhs[i, :, :], + preferred_element_type=preferred_element_type, ) - ends = np.concatenate([ends_no_final, np.array([m], dtype=np.int32)]) - # Calculate the run starts by shifting ends 1 to the right. The first run - # starts at zero. - starts = np.concatenate([np.zeros(1, dtype=np.int32), ends_no_final]) - return jnp.array(ends - starts, dtype=jnp.int32) - - GROUPED_MATMUL_TESTS = ( - (128, 128, 128), # Small - (512, 2048, 256), # Big - (128, 8, 16), # Test partial tiles. - ) + out.append(result) + start += group_sizes[i] + return jnp.concatenate(out, axis=0) + +def with_dtype_arguments(xs: tuple[Any, ...]) -> tuple[Any, ...]: + dtypes = [jnp.float32, jnp.bfloat16] + + result = [] + for x in xs: + for dtypes_tuple in itertools.product(dtypes, dtypes, dtypes): + result.append(x + dtypes_tuple) + return tuple(result) + +def with_transpose_argument(xs: tuple[Any, ...]) -> tuple[Any, ...]: + flags = [False, True] + result = [] + for x in xs: + for flag in flags: + result.append(x + (flag,)) + return tuple(result) + +def tolerances( + lhs_dtype: jnp.dtype, rhs_dtype: jnp.dtype, out_dtype: jnp.dtype +) -> tuple[float, float]: + if ( + lhs_dtype == jnp.bfloat16 + or rhs_dtype == jnp.bfloat16 + or out_dtype == jnp.bfloat16 + ): + if jtu.is_device_tpu(7): + return 2e-2, 1e-2 + return 1e-3, 1e-2 # atol, rtol + return 1e-3, 1e-5 # atol, rtol + +# TODO(tgale): Fix errors with strict dtype promotion. +@jtu.with_config(jax_numpy_dtype_promotion="standard") +@jtu.thread_unsafe_test_class() # hypothesis is not thread safe +class GroupedMatmulTest(jtu.JaxTestCase): + + def setUp(self): + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Test requires TPU device.") + + super().setUp() + self.key = jax.random.PRNGKey(1234) + + def assert_allclose( + self, + out: jnp.ndarray, + expected_out: jnp.ndarray, + *, + atol: float = 1e-5, + rtol: float = 1e-5, + ): + self.assertEqual(out.dtype, expected_out.dtype) + np.testing.assert_allclose( + out.astype(jnp.float32), + expected_out.astype(jnp.float32), + atol=atol, + rtol=rtol, + ) - def random_dense( - shape: tuple[int, ...], - key: jax.Array, - dtype: jnp.dtype, - limit: int | None = None, - ) -> jnp.ndarray: - if limit is None: - limit = 1 / np.prod(shape) - x = jax.random.uniform(key, shape, dtype, minval=-limit, maxval=limit) # pylint: disable=invalid-unary-operand-type - return x.astype(jnp.bfloat16).astype(dtype) - - def dot( - lhs: jnp.ndarray, - rhs: jnp.ndarray, - transpose_lhs: bool = False, - transpose_rhs: bool = False, - preferred_element_type: jnp.dtype = jnp.float32, - ) -> jnp.ndarray: - lhs = jnp.transpose(lhs) if transpose_lhs else lhs - rhs = jnp.transpose(rhs) if transpose_rhs else rhs - return jax.lax.dot(lhs, rhs, preferred_element_type=preferred_element_type) - - def reference_gmm( - lhs: jnp.ndarray, - rhs: jnp.ndarray, - group_sizes: jnp.ndarray, - preferred_element_type: jnp.dtype = jnp.float32, - ) -> jnp.ndarray: - - start = 0 - out = [] - for i, size in enumerate(group_sizes): - result = dot( - lhs[start : start + size, :], - rhs[i, :, :], - preferred_element_type=preferred_element_type, - ) + def gmm_test( + self, + m: int, + k: int, + n: int, + data: hps.SearchStrategy[hps.DataObject], + interpret: bool = False, + ): + seed = data.draw(seed_strategy()) + num_groups, _ = data.draw(group_strategy(max_stride=1)) + lhs_dtype, rhs_dtype, out_dtype = ( + data.draw(hps.sampled_from([jnp.float32, jnp.bfloat16])) + for _ in range(3) + ) + transpose_rhs = data.draw(hps.booleans()) + + key = jax.random.key(seed) + k1, k2 = jax.random.split(key, 2) + lhs = random_dense((m, k), k1, lhs_dtype, limit=1) + rhs = random_dense((num_groups, k, n), k2, rhs_dtype, limit=1) + group_sizes = data.draw(group_sizes_strategy(m=m, num_groups=num_groups)) + + out, vjpfun = jax.vjp( + partial( + mblx.gmm, + preferred_element_type=out_dtype, + transpose_rhs=transpose_rhs, + interpret=interpret, + ), + lhs, + rhs.swapaxes(1, 2) if transpose_rhs else rhs, + group_sizes, + ) - out.append(result) - start += group_sizes[i] - return jnp.concatenate(out, axis=0) - - def with_dtype_arguments(xs: tuple[Any, ...]) -> tuple[Any, ...]: - dtypes = [jnp.float32, jnp.bfloat16] - - result = [] - for x in xs: - for dtypes_tuple in itertools.product(dtypes, dtypes, dtypes): - result.append(x + dtypes_tuple) - return tuple(result) - - def with_transpose_argument(xs: tuple[Any, ...]) -> tuple[Any, ...]: - flags = [False, True] - result = [] - for x in xs: - for flag in flags: - result.append(x + (flag,)) - return tuple(result) - - def tolerances( - lhs_dtype: jnp.dtype, rhs_dtype: jnp.dtype, out_dtype: jnp.dtype - ) -> tuple[float, float]: - if ( - lhs_dtype == jnp.bfloat16 - or rhs_dtype == jnp.bfloat16 - or out_dtype == jnp.bfloat16 - ): - return 1e-3, 1e-2 # atol, rtol - return 1e-3, 1e-5 # atol, rtol - - # TODO(tgale): Fix errors with strict dtype promotion. - @jtu.with_config(jax_numpy_dtype_promotion="standard") - class GroupedMatmulTest(jtu.JaxTestCase): - - def setUp(self): - if not jtu.test_device_matches(["tpu"]): - self.skipTest("Test requires TPU device.") - - super().setUp() - self.key = jax.random.PRNGKey(1234) - - def assert_allclose( - self, - out: jnp.ndarray, - expected_out: jnp.ndarray, - *, - atol: float = 1e-5, - rtol: float = 1e-5, - ): - self.assertEqual(out.dtype, expected_out.dtype) - np.testing.assert_allclose( - out.astype(jnp.float32), - expected_out.astype(jnp.float32), - atol=atol, - rtol=rtol, + def reference_fn(lhs, rhs, group_sizes, preferred_element_type): + rhs = rhs.swapaxes(1, 2) if transpose_rhs else rhs + return reference_gmm( + lhs, rhs, group_sizes, preferred_element_type=preferred_element_type ) - def gmm_test( - self, - m: int, - k: int, - n: int, - data: hps.SearchStrategy[hps.DataObject], - interpret: bool = False, - ): - seed = data.draw(seed_strategy()) - num_groups, _ = data.draw(group_strategy(max_stride=1)) - lhs_dtype, rhs_dtype, out_dtype = [ - data.draw(hps.sampled_from([jnp.float32, jnp.bfloat16])) - for _ in range(3) - ] - transpose_rhs = data.draw(hps.booleans()) - - key = jax.random.key(seed) - k1, k2 = jax.random.split(key, 2) - lhs = random_dense((m, k), k1, lhs_dtype, limit=1) - rhs = random_dense((num_groups, k, n), k2, rhs_dtype, limit=1) - group_sizes = data.draw(group_sizes_strategy(m=m, num_groups=num_groups)) - - out, vjpfun = jax.vjp( - partial( - mblx.gmm, - preferred_element_type=out_dtype, - transpose_rhs=transpose_rhs, - interpret=interpret, - ), - lhs, - rhs.swapaxes(1, 2) if transpose_rhs else rhs, - group_sizes, - ) + expected_out, reference_vjpfun = jax.vjp( + partial(reference_fn, preferred_element_type=out_dtype), + lhs, + rhs.swapaxes(1, 2) if transpose_rhs else rhs, + group_sizes, + ) + self.assertEqual(out.dtype, out_dtype) + self.assertEqual(expected_out.dtype, out_dtype) + + atol, rtol = tolerances(lhs_dtype, rhs_dtype, out_dtype) + self.assert_allclose(out, expected_out, atol=atol, rtol=rtol) + + cotangent = random_dense((m, n), k1, out_dtype, limit=1) + grad_lhs, grad_rhs, *_ = vjpfun(cotangent) + expected_grad_lhs, expected_grad_rhs, *_ = reference_vjpfun(cotangent) + self.assert_allclose(grad_lhs, expected_grad_lhs, atol=atol, rtol=rtol) + self.assert_allclose(grad_rhs, expected_grad_rhs, atol=atol, rtol=rtol) + + @parameterized.parameters(*GROUPED_MATMUL_TESTS) + @hp.given(hps.data()) + def test_gmm( + self, + m: int, + k: int, + n: int, + data: hps.SearchStrategy[hps.DataObject], + ): + self.gmm_test(m, k, n, data) + + # NOTE: Run fewer tests with interpret mode. We just want to sanity check that + # changes do not break running these kernels with interpret=True. + @parameterized.parameters(*GROUPED_MATMUL_TESTS[0:1]) + @hp.given(hps.data()) + def test_gmm_interpret( + self, + m: int, + k: int, + n: int, + data: hps.SearchStrategy[hps.DataObject], + ): + self.skipTest("interpret mode with dynamic grids is unsupported") + self.gmm_test( + m, + k, + n, + data=data, + interpret=True, + ) - def reference_fn(lhs, rhs, group_sizes, preferred_element_type): - rhs = rhs.swapaxes(1, 2) if transpose_rhs else rhs - return reference_gmm( - lhs, rhs, group_sizes, preferred_element_type=preferred_element_type - ) + @parameterized.parameters(*GROUPED_MATMUL_TESTS) + @hp.given(hps.data()) + def test_gmm_sharded_groups( + self, + m: int, + k: int, + n: int, + data: hps.SearchStrategy[hps.DataObject], + ): + seed = data.draw(seed_strategy()) + num_groups, group_stride = data.draw(group_strategy()) + lhs_dtype, rhs_dtype, out_dtype = ( + data.draw(hps.sampled_from([jnp.float32, jnp.bfloat16])) + for _ in range(3) + ) - expected_out, reference_vjpfun = jax.vjp( - partial(reference_fn, preferred_element_type=out_dtype), + key = jax.random.key(seed) + k1, k2 = jax.random.split(key, 2) + lhs = random_dense((m, k), k1, lhs_dtype, limit=1) + rhs = random_dense((num_groups, k, n), k2, rhs_dtype, limit=1) + group_sizes = data.draw(group_sizes_strategy(m=m, num_groups=num_groups)) + + out, shard_vjpfun = jax.vjp( + partial(mblx.gmm, preferred_element_type=out_dtype), + lhs, + rhs[0:group_stride], + group_sizes, + ) + vjpfuns = [shard_vjpfun] + for group_offset in range(group_stride, num_groups, group_stride): + out, shard_vjpfun = jax.vjp( + lambda lhs, rhs, group_sizes, out: mblx.gmm( + lhs, + rhs, + group_sizes, + out_dtype, + group_offset=jnp.array(group_offset, dtype=jnp.int32), # pylint: disable=cell-var-from-loop + existing_out=out, + ), lhs, - rhs.swapaxes(1, 2) if transpose_rhs else rhs, + rhs[group_offset : group_offset + group_stride], group_sizes, + out, ) - self.assertEqual(out.dtype, out_dtype) - self.assertEqual(expected_out.dtype, out_dtype) - - atol, rtol = tolerances(lhs_dtype, rhs_dtype, out_dtype) - self.assert_allclose(out, expected_out, atol=atol, rtol=rtol) - - cotangent = random_dense((m, n), k1, out_dtype, limit=1) - grad_lhs, grad_rhs, *_ = vjpfun(cotangent) - expected_grad_lhs, expected_grad_rhs, *_ = reference_vjpfun(cotangent) - self.assert_allclose(grad_lhs, expected_grad_lhs, atol=atol, rtol=rtol) - self.assert_allclose(grad_rhs, expected_grad_rhs, atol=atol, rtol=rtol) - - @parameterized.parameters(*GROUPED_MATMUL_TESTS) - @hp.given(hps.data()) - def test_gmm( - self, - m: int, - k: int, - n: int, - data: hps.SearchStrategy[hps.DataObject], - ): - self.gmm_test(m, k, n, data) - - # NOTE: Run fewer tests with interpret mode. We just want to sanity check that - # changes do not break running these kernels with interpret=True. - @parameterized.parameters(*GROUPED_MATMUL_TESTS[0:1]) - @hp.given(hps.data()) - def test_gmm_interpret( - self, - m: int, - k: int, - n: int, - data: hps.SearchStrategy[hps.DataObject], - ): - self.skipTest("interpret mode with dynamic grids is unsupported") - self.gmm_test( - m, - k, - n, - data=data, - interpret=True, - ) + vjpfuns.append(shard_vjpfun) - @parameterized.parameters(*GROUPED_MATMUL_TESTS) - @hp.given(hps.data()) - def test_gmm_sharded_groups( - self, - m: int, - k: int, - n: int, - data: hps.SearchStrategy[hps.DataObject], + expected_out, reference_vjpfun = jax.vjp( + partial(reference_gmm, preferred_element_type=out_dtype), + lhs, + rhs, + group_sizes, + ) + self.assertEqual(out.dtype, out_dtype) + self.assertEqual(expected_out.dtype, out_dtype) + atol, rtol = tolerances(lhs_dtype, rhs_dtype, out_dtype) + self.assert_allclose(out, expected_out, atol=atol, rtol=rtol) + + cotangent = random_dense((m, n), k1, out_dtype, limit=1) + shard_grad_lhs, shard_grad_rhs, *_ = vjpfuns[0](cotangent) + grad_lhs = shard_grad_lhs + grad_rhs = [shard_grad_rhs] + for i, group_offset in enumerate( + range(group_stride, num_groups, group_stride) ): - seed = data.draw(seed_strategy()) - num_groups, group_stride = data.draw(group_strategy()) - lhs_dtype, rhs_dtype, out_dtype = [ - data.draw(hps.sampled_from([jnp.float32, jnp.bfloat16])) - for _ in range(3) - ] - - key = jax.random.key(seed) - k1, k2 = jax.random.split(key, 2) - lhs = random_dense((m, k), k1, lhs_dtype, limit=1) - rhs = random_dense((num_groups, k, n), k2, rhs_dtype, limit=1) - group_sizes = data.draw(group_sizes_strategy(m=m, num_groups=num_groups)) - - out, shard_vjpfun = jax.vjp( - partial(mblx.gmm, preferred_element_type=out_dtype), - lhs, - rhs[0:group_stride], - group_sizes, - ) - vjpfuns = [shard_vjpfun] - for group_offset in range(group_stride, num_groups, group_stride): - out, shard_vjpfun = jax.vjp( - lambda lhs, rhs, group_sizes, out: mblx.gmm( - lhs, - rhs, - group_sizes, - out_dtype, - group_offset=jnp.array(group_offset, dtype=jnp.int32), # pylint: disable=cell-var-from-loop - existing_out=out, - ), - lhs, - rhs[group_offset : group_offset + group_stride], - group_sizes, - out, - ) - vjpfuns.append(shard_vjpfun) - - expected_out, reference_vjpfun = jax.vjp( - partial(reference_gmm, preferred_element_type=out_dtype), - lhs, - rhs, - group_sizes, - ) - self.assertEqual(out.dtype, out_dtype) - self.assertEqual(expected_out.dtype, out_dtype) - atol, rtol = tolerances(lhs_dtype, rhs_dtype, out_dtype) - self.assert_allclose(out, expected_out, atol=atol, rtol=rtol) - - cotangent = random_dense((m, n), k1, out_dtype, limit=1) - shard_grad_lhs, shard_grad_rhs, *_ = vjpfuns[0](cotangent) - grad_lhs = shard_grad_lhs - grad_rhs = [shard_grad_rhs] - for i, group_offset in enumerate( - range(group_stride, num_groups, group_stride) - ): - shard_grad_lhs, shard_grad_rhs, *_ = vjpfuns[i + 1](cotangent) - grad_lhs += shard_grad_lhs - grad_rhs.append(shard_grad_rhs) - grad_rhs = jnp.concatenate(grad_rhs, axis=0) - expected_grad_lhs, expected_grad_rhs, *_ = reference_vjpfun(cotangent) - self.assert_allclose(grad_lhs, expected_grad_lhs, atol=atol, rtol=rtol) - self.assert_allclose(grad_rhs, expected_grad_rhs, atol=atol, rtol=rtol) + shard_grad_lhs, shard_grad_rhs, *_ = vjpfuns[i + 1](cotangent) + grad_lhs += shard_grad_lhs + grad_rhs.append(shard_grad_rhs) + grad_rhs = jnp.concatenate(grad_rhs, axis=0) + expected_grad_lhs, expected_grad_rhs, *_ = reference_vjpfun(cotangent) + self.assert_allclose(grad_lhs, expected_grad_lhs, atol=atol, rtol=rtol) + self.assert_allclose(grad_rhs, expected_grad_rhs, atol=atol, rtol=rtol) if __name__ == "__main__": diff --git a/tests/pallas/tpu_info_test.py b/tests/pallas/tpu_info_test.py new file mode 100644 index 000000000000..9107a1f0d26b --- /dev/null +++ b/tests/pallas/tpu_info_test.py @@ -0,0 +1,52 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest +import jax +from jax._src import test_util as jtu +from jax.experimental.pallas import tpu as pltpu + + +class TpuInfoTest(jtu.JaxTestCase): + + def test_get_tpu_info(self): + device = jax.devices()[0] + if not jtu.is_device_tpu(): + self.assertFalse(pltpu.is_tpu_device()) + return + self.assertTrue(pltpu.is_tpu_device()) + info = pltpu.get_tpu_info() + self.assertIsInstance(info, pltpu.TpuInfo) + match device.device_kind: + case "TPU v3": + self.assertEqual(info.chip_version, pltpu.ChipVersion.TPU_V3) + case "TPU v4 lite": + self.assertEqual(info.chip_version, pltpu.ChipVersion.TPU_V4I) + case "TPU v4": + self.assertEqual(info.chip_version, pltpu.ChipVersion.TPU_V4) + case "TPU v5 lite": + self.assertEqual(info.chip_version, pltpu.ChipVersion.TPU_V5E) + case "TPU v5": + self.assertEqual(info.chip_version, pltpu.ChipVersion.TPU_V5P) + case "TPU v6 lite": + self.assertEqual(info.chip_version, pltpu.ChipVersion.TPU_V6E) + case "TPU7x": + self.assertEqual(info.chip_version, pltpu.ChipVersion.TPU_7X) + case _: + self.fail(f"Unexpected device kind: {device.device_kind}") + + +if __name__ == "__main__": + jax.config.parse_flags_with_absl() + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/tpu_ops_test.py b/tests/pallas/tpu_ops_test.py index c8def2627462..0183add13837 100644 --- a/tests/pallas/tpu_ops_test.py +++ b/tests/pallas/tpu_ops_test.py @@ -15,14 +15,14 @@ import functools import math import sys -import unittest from absl.testing import absltest from absl.testing import parameterized import jax from jax import lax +from jax._src import dtypes from jax._src import test_util as jtu -from jax._src.pallas import utils as pallas_utils +from jax._src.pallas import pallas_test_util as ptu from jax.experimental import pallas as pl import jax.numpy as jnp import numpy as np @@ -32,46 +32,63 @@ else: pltpu = None -try: - import hypothesis as hp -except (ModuleNotFoundError, ImportError): - raise unittest.SkipTest("tests depend on hypothesis library") - -import hypothesis.strategies as hps jax.config.parse_flags_with_absl() jtu.setup_hypothesis(max_examples=100) -_JAX_DTYPES = ( +_JAX_DTYPES_NO_BOOL = ( jnp.float32, jnp.bfloat16, jnp.int32, jnp.int16, jnp.int8, - jnp.bool_, + jnp.int4, + jnp.float8_e5m2, ) +_JAX_DTYPES = ( + *_JAX_DTYPES_NO_BOOL, + jnp.bool_, +) -class PallasBaseTest(jtu.JaxTestCase): - INTERPRET = False - - def setUp(self): - if not jtu.test_device_matches(["tpu"]): - self.skipTest("Test only supported on TPU.") +_JAX_INT_DTYPES = ( + jnp.int32, + jnp.int16, + jnp.int8, + jnp.int4, + jnp.uint32, + jnp.uint16, + jnp.uint8, + jnp.uint4, +) - super().setUp() - @classmethod - def pallas_call(cls, *args, **kwargs): - return pl.pallas_call(*args, interpret=cls.INTERPRET, **kwargs) +def rand( + shape: tuple[int, ...], dtype: np.dtype | jnp.dtype, seed: int = 1234 +) -> np.ndarray: + """A helper function to generate random data for testing.""" + rng = np.random.Generator(np.random.Philox(counter=0, key=seed)) + if jnp.issubdtype(dtype, jnp.floating): + return rng.normal(size=shape).astype(dtype) + if jnp.issubdtype(dtype, jnp.integer): + return rng.integers( + jnp.iinfo(dtype).min, jnp.iinfo(dtype).max, shape, dtype=np.int32 + ).astype(dtype) + raise NotImplementedError(f"Unsupported random data generation for {dtype=}") -class OpsTest(PallasBaseTest): +@jtu.thread_unsafe_test_class(condition=not jtu.hypothesis_is_thread_safe()) +class OpsTest(ptu.PallasTPUTest): @parameterized.product( - from_dtype=_JAX_DTYPES, to_dtype=_JAX_DTYPES, is_ref_bitcast=[False, True] + from_dtype=_JAX_DTYPES, + to_dtype=_JAX_DTYPES, + is_ref_bitcast=[False, True], + use_primitive_io_op=[False, True], ) - def test_bitcast(self, from_dtype, to_dtype, is_ref_bitcast): + def test_bitcast( + self, from_dtype, to_dtype, is_ref_bitcast, use_primitive_io_op + ): if not jtu.is_device_tpu_at_least(version=4): self.skipTest("Run on TPUv4+ to have expected memory layout") if from_dtype == to_dtype: @@ -81,13 +98,19 @@ def test_bitcast(self, from_dtype, to_dtype, is_ref_bitcast): def kernel(x_ref, y_ref): if is_ref_bitcast: - y_ref[...] = x_ref.bitcast(to_dtype)[...] + if use_primitive_io_op: + pltpu.store(y_ref, pltpu.load(x_ref.bitcast(to_dtype))) + else: + y_ref[...] = x_ref.bitcast(to_dtype)[...] else: - y_ref[...] = pltpu.bitcast(x_ref[...], to_dtype) + if use_primitive_io_op: + pltpu.store(y_ref, pltpu.bitcast(pltpu.load(x_ref), to_dtype)) + else: + y_ref[...] = pltpu.bitcast(x_ref[...], to_dtype) m, n = 1, 256 - in_packing = 32 // pallas_utils.dtype_bitwidth(from_dtype) - out_packing = 32 // pallas_utils.dtype_bitwidth(to_dtype) + in_packing = 32 // dtypes.itemsize_bits(from_dtype) + out_packing = 32 // dtypes.itemsize_bits(to_dtype) in_shape = (m * in_packing, n) out_shape = (m * out_packing, n) inp = np.arange(np.prod(in_shape), dtype=from_dtype).reshape(in_shape) @@ -103,57 +126,13 @@ def kernel(x_ref, y_ref): )(inp) self.assertAllClose(out, out_interpret) - @parameterized.product(is_dynamic=(False, True)) - @hp.given( - axis=hps.integers(0, 3), - shift=hps.integers(0, 3), - stride=hps.one_of(hps.just(None), hps.integers(0, 2)), - # Stride dimension on the minor most is not supported. - stride_axis=hps.one_of(hps.just(None), hps.integers(0, 2)), - ) - @hp.example(3, 9, 1, 2) - @hp.example(3, 9, 2, 2) - @hp.example(0, 9, 0, 1) - @hp.example(0, 9, 1, 1) - def test_roll(self, is_dynamic, axis, shift, stride, stride_axis): - if (stride is None) != (stride_axis is None): - self.skipTest( - "Roll op requires both stride and stride_axis to be either specified" - " or not specified." - ) - if (not jtu.is_device_tpu(version=5)) and stride_axis == 2: - self.skipTest( - "Roll op with stride axis on 2nd minor requires at least TPU v5" - ) - shape = (4, 4, 32, 512) + def test_stop_gradient(self): + def kernel(x_ref, y_ref): + y_ref[...] = jax.lax.stop_gradient(x_ref[...] + 1) - def kernel(s_ref, x_ref, y_ref): - amt = s_ref[0] if is_dynamic else shift - y_ref[...] = pltpu.roll( - x_ref[...], amt, axis, stride=stride, stride_axis=stride_axis - ) - - def roll(x, shift, axis, stride=None, stride_axis=None): - assert (stride is None) == (stride_axis is None) - if stride is None: - return np.roll(x, shift, axis) - outputs = [ - np.roll(xs, shift + i * stride, axis) - for i, xs in enumerate(np.split(x, x.shape[stride_axis], stride_axis)) - ] - return np.concatenate(outputs, stride_axis) - - inp = np.arange(np.prod(shape), dtype=jnp.int32).reshape(shape) - ref = roll(inp, shift, axis, stride, stride_axis) - dynamic_shift = jnp.array([abs(shift)], jnp.int32) - for interpret in [False, True]: - out = pl.pallas_call( - kernel, - out_shape=jax.ShapeDtypeStruct(shape, jnp.int32), - grid_spec=pltpu.PrefetchScalarGridSpec(num_scalar_prefetch=1), - interpret=interpret, - )(dynamic_shift, inp) - np.testing.assert_array_equal(out, ref, err_msg=f"{interpret=}") + x = jnp.arange(1024, dtype=jnp.float32) + y = pl.pallas_call(kernel, out_shape=x)(x) + self.assertAllClose(y, x + 1) def test_interleave_vectors(self): if not jtu.is_device_tpu_at_least(version=4): @@ -178,10 +157,9 @@ def kernel(x_ref, y_ref, out_ref): @parameterized.parameters([jnp.int32, jnp.int16, jnp.int8, jnp.int4]) def test_row_broadcast(self, dtype): - if not jtu.if_cloud_tpu_at_least(2025, 1, 10): - self.skipTest("Requires libtpu built after 2025-01-10") - if not self.INTERPRET and jtu.get_tpu_version() < 5: - self.skipTest("Requires TPUv5+") + bitwidth = dtypes.itemsize_bits(dtype) + if not self.INTERPRET and jtu.get_tpu_version() < 4 and bitwidth < 8: + self.skipTest("Requires TPUv4+ for sub-byte types") def kernel(x_ref, y_ref): y_ref[...] = jnp.broadcast_to(x_ref[pl.ds(3, 1)], y_ref.shape).astype(y_ref.dtype) m, n = 4, 1152 @@ -193,30 +171,18 @@ def kernel(x_ref, y_ref): )(x) np.testing.assert_array_equal(y, jnp.broadcast_to(x[3:4], y.shape)) - def test_tpu_unsigned_int(self): - self.skipTest("TODO(apaszke): Unsigned upcasts were implemented incorrectly") - def body(x_ref, o_ref): - # Test cast from uint16 -> uint32 - ux = lax.convert_element_type(x_ref[...], jnp.uint32) - res = ux + 1 - # Test cast from uint32 -> float32 - o_ref[...] = res.astype(jnp.float32) - out = jax.ShapeDtypeStruct((8, 128), jnp.float32) - x = jnp.arange(8 * 128, dtype=jnp.uint16).reshape((8, 128)) - result = self.pallas_call(body, out_shape=out)(x) - np.testing.assert_array_equal(result, x.astype(jnp.float32) + 1.0) - - def test_tpu_signed_int_upcast(self): + @parameterized.parameters([jnp.uint4, jnp.int4]) + def test_tpu_int4_upcast_and_matmul(self, dtype): if not jtu.is_device_tpu_at_least(version=5): self.skipTest("TPUv5+ needed for integer matmuls") def body(x_ref, o_ref): - # Test cast from int4 -> int8 + # Test cast from (u)int4 -> int8 ux = lax.convert_element_type(x_ref[...], jnp.int8) o_ref[...] = jax.lax.dot(ux, ux, preferred_element_type=jnp.int32) out = jax.ShapeDtypeStruct((128, 128), jnp.int32) - x = jnp.arange(128 * 128, dtype=jnp.int4).reshape((128, 128)) + x = jnp.arange(128 * 128, dtype=dtype).reshape((128, 128)) result = self.pallas_call(body, out_shape=out)(x) np.testing.assert_array_equal( result, @@ -227,6 +193,65 @@ def body(x_ref, o_ref): ), ) + def test_sum_of_two_matmuls(self): + if not jtu.is_device_tpu_at_least(version=5): + self.skipTest("Test requires TPUv5+") + + M, K = 8, 8 + k1, k2, k3, k4 = jax.random.split(jax.random.key(42), 4) + a_val = jax.random.normal(k1, (M, K), dtype=jnp.float32) + b_val = jax.random.normal(k2, (K,), dtype=jnp.float32) + c_val = jax.random.normal(k3, (M, K), dtype=jnp.float32) + d_val = jax.random.normal(k4, (K,), dtype=jnp.float32) + + def kernel(a_ref, b_ref, c_ref, d_ref, o_ref): + a = a_ref[:] + b = b_ref[:] + c = c_ref[:] + d = d_ref[:] + res1 = jnp.dot(a, b) + res2 = jnp.dot(c, d) + + o_ref[:] = res1 + res2 + + @jax.jit + def pallas_fn(a, b, c, d): + return pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((M,), np.float32), + grid=(1,), + )(a, b, c, d) + + result_pallas = pallas_fn(a_val, b_val, c_val, d_val) + expected = jnp.dot(a_val, b_val) + jnp.dot(c_val, d_val) + self.assertAllClose(result_pallas, expected, atol=1e-5, rtol=1e-5) + + @parameterized.product(from_dtype=_JAX_INT_DTYPES, + to_dtype=_JAX_INT_DTYPES) + def test_integer_cast(self, from_dtype, to_dtype): + if not jtu.is_device_tpu_at_least(4): + self.skipTest("Expect TPUv4+") + # Generate both low and high values to better cover the entire range + # of the source dtype. + min_val = from_dtype(jnp.iinfo(from_dtype).min) + max_val = from_dtype(jnp.iinfo(from_dtype).max) + if jnp.iinfo(from_dtype).bits > 4: + x_random = jax.random.randint(jax.random.key(0), shape=(112, 256), + minval=min_val, maxval=max_val, dtype=from_dtype) + else: + # randint does not support sub-byte types. + x_random = jnp.arange(112 * 256, dtype=from_dtype).reshape((112, 256)) + arange = jnp.arange(8 * 256, dtype=from_dtype).reshape((8, 256)) + x = jnp.concatenate([min_val + arange, x_random, max_val - arange], axis=0) + + def body(x_ref, o_ref): + o_ref[...] = lax.convert_element_type(x_ref[...], to_dtype) + + out = jax.ShapeDtypeStruct(x.shape, to_dtype) + expected = x.astype(to_dtype) + result = self.pallas_call(body, out_shape=out)(x) + np.testing.assert_array_equal(result, expected) + def test_select_with_scalar_condition(self): def kernel(cond, lhs, rhs, out): out[:] = jax.lax.select(cond[0] != 0, lhs[:], rhs[:]) @@ -272,10 +297,8 @@ def body(x_ref, y_ref): @parameterized.product(dtype=[jnp.float32, jnp.bfloat16, jnp.int16, jnp.int8]) def test_cast_vector_to_mask(self, dtype): - if not jtu.if_cloud_tpu_at_least(2025, 1, 22): - self.skipTest("Requires libtpu built after 2025-01-22") shape = (128, 128) - bitwidth = pallas_utils.dtype_bitwidth(dtype) + bitwidth = dtypes.itemsize_bits(dtype) if jtu.get_tpu_version() < 5 and bitwidth < 32: self.skipTest( f"Not implemented: cast vector to mask with bitwidth == {bitwidth}" @@ -302,15 +325,6 @@ def kernel(x_ref, mask_ref, o_ref): reduce_func = [jnp.sum, jnp.max, jnp.min] ) def test_reduction(self, dtype, axis, reduce_func): - if dtype == jnp.int32: - # TODO(apaszke): Remove after 12 weeks have passed. - if not jtu.if_cloud_tpu_at_least(2024, 12, 19): - self.skipTest("Requires libtpu built after 2024-12-19") - if axis == 2: - self.skipTest("Int32 reduction on minor is not supported.") - # TODO(b/384127570): fix bfloat16 reduction. - if dtype == jnp.bfloat16 and reduce_func != jnp.sum: - self.skipTest("b/384127570") in_shape = (2, 16, 128) out_shape = list(in_shape) out_shape[axis] = 1 @@ -327,15 +341,52 @@ def kernel(x, out): np.testing.assert_array_equal(result, expected) @parameterized.product( + axis=[0, 1, 2], + in_shape=[(2, 29, 206), (12, 28), (12,)], + reduce_func = [jnp.argmax, jnp.argmin], + keepdims=[False, True], + ) + def test_reduce_index(self, axis, in_shape, reduce_func, keepdims): + dtype = jnp.float32 + rank = len(in_shape) + if axis >= rank: + self.skipTest("Requires axis < rank") + if axis == rank - 1: + if keepdims and not jtu.is_device_tpu_at_least(version=4): + self.skipTest("Requires TPUv4+ for axis=rank-1 and keepdims=True") + if not keepdims and not jtu.is_device_tpu_at_least(version=5): + self.skipTest("Requires TPUv5+ for axis=rank-1 and keepdims=False") + if rank == 1 and not keepdims: + self.skipTest("Scalar output not supported") + + out_shape = list(in_shape) + if keepdims: + out_shape[axis] = 1 + else: + del out_shape[axis] + + def kernel(x, out): + out[:] = reduce_func(x[:], axis, keepdims=keepdims) + + x = jax.random.permutation( + jax.random.key(22), + jnp.arange(np.prod(in_shape), dtype=dtype) + ).reshape(in_shape) + result = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct(out_shape, jnp.int32), + )(x) + expected = reduce_func(x, axis, keepdims=keepdims) + np.testing.assert_array_equal(result, expected) + + @parameterized.product( + shape=[(129, 129), (1, 129), (2, 129), (4, 129)], msk_dtype=[jnp.float32, jnp.bfloat16, jnp.int8], dtype=[jnp.float32, jnp.bfloat16], ) - def test_i1_relayout_with_bitwidth_change(self, msk_dtype, dtype): - if not jtu.if_cloud_tpu_at_least(2025, 1, 25): - self.skipTest("Requires libtpu built after 2025-01-25") - shape = (129, 129) - msk_bitwidth = pallas_utils.dtype_bitwidth(msk_dtype) - bitwidth = pallas_utils.dtype_bitwidth(dtype) + def test_i1_relayout_bw(self, shape, msk_dtype, dtype): + msk_bitwidth = dtypes.itemsize_bits(msk_dtype) + bitwidth = dtypes.itemsize_bits(dtype) if jtu.get_tpu_version() < 5 and msk_bitwidth < 32: self.skipTest( "Not implemented: cast vector to mask with bitwidth ==" @@ -361,13 +412,61 @@ def kernel(x_ref, mask_ref, o_ref): expected = jnp.where(mask, x, jnp.zeros_like(x)) self.assertArraysEqual(out, expected) + @parameterized.product( + msk_dtype=[jnp.float32, jnp.bfloat16, jnp.int8], + dtype=[jnp.float32, jnp.bfloat16, jnp.int8], + ) + def test_i1_relayout_bw_tiling(self, msk_dtype, dtype): + self.skipTest("TODO: jevinjiang - Enable once presubmits pass.") + shape = (256, 256) + bitwidth = dtypes.itemsize_bits(dtype) + msk_bitwidth = dtypes.itemsize_bits(msk_dtype) + msk_packing = 32 // msk_bitwidth + if jtu.get_tpu_version() < 5 and msk_bitwidth < 32: + self.skipTest( + "Not implemented: cast vector to mask with bitwidth ==" + f" {msk_bitwidth}" + ) + if jtu.get_tpu_version() < 5 and bitwidth < 32: + self.skipTest(f"Not implemented: comparison with bitwidth == {bitwidth}") + + # Creating large tiling for masks by passing i32 vector first and + # then bitcast to msk_dtype so the tiling is also bitcasted from + # T(8, 128) to T(8 * msk_packing, 128). + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct(shape, dtype), + ) + def kernel(x_ref, msk_ref, o_ref): + zeros = jnp.zeros_like(x_ref) + msk = pltpu.bitcast(msk_ref[...], msk_dtype) + o_ref[...] = jnp.where(msk, x_ref[...], zeros) + + mask = jax.random.bernoulli(jax.random.key(1234), 0.5, shape).astype( + msk_dtype + ) + if msk_bitwidth < 32: + mask_for_bitcast = mask.reshape( + shape[0] // msk_packing, msk_packing, shape[1] + ).swapaxes(-1, -2) + else: + mask_for_bitcast = mask + + mask_i32 = lax.bitcast_convert_type( + mask_for_bitcast, + jnp.int32, + ) + x = jnp.arange(np.prod(shape), dtype=dtype).reshape(shape) + 1 + + out = kernel(x, mask_i32) + expected = jnp.where(mask, x, jnp.zeros_like(x)) + self.assertArraysEqual(out, expected) + @parameterized.product( target=(jnp.int8,), # TODO(apaszke): Add int4. round=(False, True), ) def test_quantize(self, target, round): - if not jtu.if_cloud_tpu_at_least(2025, 1, 15): - self.skipTest("Requires libtpu built after 2025-01-15") if not jtu.is_device_tpu_at_least(version=6): self.skipTest("Requires TPUv6+") shape = (256, 256) @@ -390,8 +489,6 @@ def kernel(x_ref, o_ref): @parameterized.product(axis=[0, 1], mode=["promise_in_bounds", None]) def test_dynamic_gather_along_axis(self, axis, mode): - if not jtu.if_cloud_tpu_at_least(2025, 2, 5): - self.skipTest("Requires libtpu built after 2025-02-05") if (axis == 0 and not jtu.is_device_tpu_at_least(version=5)) or ( axis == 1 and not jtu.is_device_tpu_at_least(version=4) ): @@ -418,12 +515,10 @@ def kernel(x, indices, out): @parameterized.product(dtype=[jnp.float32, jnp.bfloat16]) def test_float_div(self, dtype): - if not jtu.if_cloud_tpu_at_least(2025, 2, 13): - self.skipTest("Requires libtpu built after 2025-02-13") if not jtu.is_device_tpu_at_least(version=4): self.skipTest("Requires TPUv4+") kwargs = {} - if jtu.get_tpu_version() == 6: + if jtu.is_device_tpu_at_least(version=6): kwargs.update(dict(rtol=1e-2)) def kernel(x, y, out): out[:] = jax.lax.div(x[:], y[:]) @@ -441,9 +536,7 @@ def kernel(x, y, out): dtype=[jnp.float32, jnp.bfloat16, jnp.int8], ) def test_concat_mask(self, dtype): - if not jtu.if_cloud_tpu_at_least(2025, 2, 19): - self.skipTest("Requires libtpu built after 2025-02-19") - bitwidth = pallas_utils.dtype_bitwidth(dtype) + bitwidth = dtypes.itemsize_bits(dtype) if jtu.get_tpu_version() < 5 and bitwidth < 32: self.skipTest( f"Not implemented: cast vector to mask with bitwidth == {bitwidth}" @@ -490,9 +583,323 @@ def kernel(x, out): expected = dot(x[:], jnp.ones((1, d), jnp.bfloat16)) np.testing.assert_array_equal(output, expected) + # We need to manually run the test with the env variable + # `export LIBTPU_INIT_ARGS="--xla_jf_bounds_check=true"` + def test_disable_bounds_check(self): + if jtu.get_tpu_version() < 4: + self.skipTest("Requires TPUv4+") + src_shape = (8, 128) + tgt_shape = (16, 256) + + def kernel(src, tgt): + tgt[:] = src[tuple(pl.ds(d) for d in tgt.shape)] + + x = jnp.arange(np.prod(src_shape), dtype=jnp.float32).reshape(src_shape) + run = pl.pallas_call( + kernel, + jax.ShapeDtypeStruct(tgt_shape, jnp.float32), + compiler_params=pltpu.CompilerParams(disable_bounds_checks=True), + ) + output = run(x) + np.testing.assert_array_equal( + output[tuple(slice(0, d) for d in src_shape)], x + ) -class OpsInterpretTest(OpsTest): - INTERPRET = True + def test_while_loop_arg_num_change(self): + # This kernel will generate a while loop that will be CSEd by MLIR to have + # the different number of argments in before region and after region. + def kernel( + out_ref, + a, + ): + def loop_cond(state): + _, y = state + return y + + def loop_body(state): + x, y = state + + def then_0(): + def then_1(): + return jnp.int32(0) + + def else_1(): + a[0] = a[0] + 1 + + return jnp.int32(1) + + z = lax.cond(x == 0, then_1, else_1) + new_x = z + new_y = z != 0 + return new_x, new_y + + def else_0(): + return x, jnp.bool_(False) + + new_x, new_y = lax.cond(y, then_0, else_0) + + return (new_x, new_y) + + out_ref[0] = lax.while_loop( + loop_cond, loop_body, (jnp.int32(0), jnp.bool_(True)) + )[0] + + output = pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((1,), jnp.int32), + in_specs=(), + out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), + scratch_shapes=(pltpu.SMEM((1,), jnp.int32),), + )()[0] + self.assertEqual(output, 0) + + def test_produce_predicate_phi(self): + def kernel( + out_ref, + a, + ): + def loop_cond(state): + x, y = state + return jnp.logical_or(y, (x == 1)) + + def loop_body(state): + x, y = state + + def then_0(): + def then_1(): + return jnp.int32(0) + + def else_1(): + a[0] = a[0] + 1 + + return jnp.int32(1) + + z = lax.cond(x == 0, then_1, else_1) + new_x = z + new_y = z != 0 + return new_x, new_y + + def else_0(): + return x, jnp.bool_(False) + + new_x, new_y = lax.cond(y, then_0, else_0) + + return (new_x, new_y) + + out_ref[0] = lax.while_loop( + loop_cond, loop_body, (jnp.int32(0), jnp.bool_(True)) + )[0] + + output = pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((1,), jnp.int32), + in_specs=(), + out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), + scratch_shapes=(pltpu.SMEM((1,), jnp.int32),), + )()[0] + self.assertEqual(output, 0) + + def test_retiling_with_replicated_lane(self): + shape = (32, 1) + broadcast_shape = (32, 256) + + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct((8, 4, 256), jnp.float32), + ) + def kernel(x_ref, o_ref): + o_ref[...] = jnp.broadcast_to( + x_ref[...], broadcast_shape + ).reshape(o_ref.shape) + + x = jnp.arange(np.prod(shape), dtype=jnp.float32).reshape(shape) + out = kernel(x).reshape(broadcast_shape) + expected = jnp.broadcast_to(x, broadcast_shape) + np.testing.assert_array_equal(out, expected) + + @parameterized.parameters( + [jnp.bfloat16, jnp.float8_e5m2, jnp.float8_e4m3fn, jnp.float8_e4m3b11fnuz] + ) + def test_stochastic_round(self, target_dtype): + if not jtu.is_device_tpu_at_least(version=5): + self.skipTest("Requires TPU v5+") + + def kernel(x_ref, b_ref, o_ref): + o_ref[...] = pltpu.stochastic_round( + x_ref[...], b_ref[...], target_dtype=target_dtype + ) + + shape = (8, 128) + k1, k2 = jax.random.split(jax.random.key(4242), 2) + x = jax.random.normal(k1, shape, dtype=jnp.float32) + bits = jax.random.bits(k2, shape, dtype=jnp.uint32) + x_cast = x.astype(target_dtype) + x_cast_as_f32 = x_cast.astype(jnp.float32) + max_val = jnp.finfo(target_dtype).max + min_val = jnp.finfo(target_dtype).min + lower = jnp.where(x_cast_as_f32 > x, jnp.nextafter(x_cast, min_val), x_cast) + upper = jnp.where(x_cast_as_f32 < x, jnp.nextafter(x_cast, max_val), x_cast) + + result = pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct(x.shape, target_dtype), + )(x, bits) + + int_dtype = getattr(jnp, f"uint{dtypes.itemsize_bits(target_dtype)}") + is_correct_bitwise = ( + (result.view(int_dtype) == lower.view(int_dtype)) | + (result.view(int_dtype) == upper.view(int_dtype)) + ) + is_correct = jnp.where( + jnp.isnan(x_cast), jnp.isnan(result), is_correct_bitwise + ) + self.assertTrue(jnp.all(is_correct)) + + def _pack_unpack_elementwise_test_data( + self, shape, unpacked_dtype, packed_dtype): + """Generates data for test_pack_elementwise and test_unpack_elementwise.""" + unpacked_bitwidth = dtypes.itemsize_bits(unpacked_dtype) + packed_bitwidth = dtypes.itemsize_bits(packed_dtype) + num_sources = unpacked_bitwidth // packed_bitwidth + if jnp.issubdtype(unpacked_dtype, jnp.integer): + stacked_sources = jax.random.randint( + jax.random.key(0), + (num_sources, *shape), + minval=-1000, + maxval=1000, + dtype=jnp.int32, + ).astype(unpacked_dtype) + else: + stacked_sources = jax.random.uniform( + jax.random.key(0), (num_sources, *shape), dtype=unpacked_dtype + ) + stacked_results = ( + stacked_sources.astype(packed_dtype) + .view(getattr(jnp, f"uint{packed_bitwidth}")) + .astype(getattr(jnp, f"uint{unpacked_bitwidth}")) + ) + shifts = jnp.arange(num_sources, dtype=jnp.uint32) * packed_bitwidth + shifts = jnp.expand_dims(shifts, axis=tuple(range(1, stacked_results.ndim))) + packed_data = jnp.bitwise_or.reduce( + stacked_results.astype(jnp.uint32) << shifts, axis=0 + ).astype(getattr(jnp, f"uint{unpacked_bitwidth}")) + return stacked_sources, packed_data + + @parameterized.product( + config=[ + (jnp.float32, jnp.bfloat16), + (jnp.int32, jnp.int16), + (jnp.int32, jnp.int8), + (jnp.int32, jnp.int4), + (jnp.int16, jnp.int8), + (jnp.int8, jnp.int4), + ], + shape=[(8, 128), (2, 15, 300)], + ) + def test_pack_elementwise(self, config, shape): + unpacked_dtype, packed_dtype = config + if not jtu.is_device_tpu_at_least(version=5): + self.skipTest("Requires TPU v5+") + if dtypes.itemsize_bits( + unpacked_dtype + ) != 32 and not jtu.is_cloud_tpu_at_least(2026, 1, 2): + self.skipTest("Test requires libtpu from 2026/01/02 or later") + + src_bitwidth = dtypes.itemsize_bits(unpacked_dtype) + tgt_bitwidth = dtypes.itemsize_bits(packed_dtype) + num_sources = src_bitwidth // tgt_bitwidth + output_dtype = getattr(jnp, f"uint{src_bitwidth}") + + def kernel(xs_ref, o_ref): + xs = [xs_ref[i] for i in range(num_sources)] + o_ref[...] = pltpu.pack_elementwise(xs, packed_dtype=packed_dtype) + + stacked_sources, expected = self._pack_unpack_elementwise_test_data( + shape, unpacked_dtype, packed_dtype + ) + + result = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct(shape, output_dtype), + )(stacked_sources) + + np.testing.assert_array_equal(result, expected) + + @parameterized.product( + config=[ + (jnp.float32, jnp.bfloat16), + (jnp.int32, jnp.int16), + (jnp.int32, jnp.int8), + (jnp.int32, jnp.int4), + ], + index=[0, 1, 3], + shape=[(8, 128), (2, 15, 300)], + ) + def test_unpack_elementwise(self, config, index, shape): + unpacked_dtype, packed_dtype = config + if not jtu.is_device_tpu_at_least(version=5): + self.skipTest("Requires TPU v5+") + + bitwidth = dtypes.itemsize_bits(packed_dtype) + packing_factor = 32 // bitwidth + + if index >= packing_factor: + self.skipTest( + f"Index {index} out of bounds for packing factor {packing_factor}") + + def kernel(x_ref, o_ref): + o_ref[...] = pltpu.unpack_elementwise( + x_ref[...], index=index, + packed_dtype=packed_dtype, unpacked_dtype=unpacked_dtype + ) + + sources, packed = self._pack_unpack_elementwise_test_data( + shape, unpacked_dtype, packed_dtype + ) + expected = sources[index].astype(packed_dtype).astype(unpacked_dtype) + + result = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct(shape, unpacked_dtype), + )(packed) + + np.testing.assert_array_equal(result, expected) + + @parameterized.product( + dtype=[jnp.float32, jnp.bfloat16, jnp.int32, jnp.int8], + shape_reps=[ + ((8, 128), (2, 1)), + ((8, 128), (1, 2)), + ((8, 128), (2, 2)), + ((128, 8, 128), (3,2,1)), + ], + ) + def test_tile(self, dtype, shape_reps): + shape, reps = shape_reps + + k1 = jax.random.key(1234) + if jnp.issubdtype(dtype, jnp.integer): + x = jax.random.randint( + k1, + shape, + minval=dtypes.iinfo(dtype).min, + maxval=dtypes.iinfo(dtype).max, + dtype=dtype, + ) + else: + x = jax.random.normal(k1, shape, dtype=dtype) + + def kernel(x_ref, y_ref): + y_ref[...] = jnp.tile(x_ref[...], reps) + + run = pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct(np.multiply(shape, reps), dtype), + ) + + output = run(x) + expected = jnp.tile(x, reps) + np.testing.assert_array_equal(output, expected) if __name__ == "__main__": diff --git a/tests/pallas/tpu_paged_attention_kernel_test.py b/tests/pallas/tpu_paged_attention_kernel_test.py index 7fbccdb338d4..348f11acb5b8 100644 --- a/tests/pallas/tpu_paged_attention_kernel_test.py +++ b/tests/pallas/tpu_paged_attention_kernel_test.py @@ -18,19 +18,176 @@ from jax._src import test_util as jtu from jax.experimental.pallas.ops.tpu import paged_attention from jax.experimental.pallas.ops.tpu.paged_attention import quantization_utils +from jax.experimental.pallas.ops.tpu.paged_attention import util import jax.numpy as jnp import numpy as np -jax.config.parse_flags_with_absl() +def _generate_qkv_simplest( + dtype: jnp.dtype, +) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: + """Generates queries with one query head, kv pages, and attention.""" + max_seq_len = 4 + seq_lens = jnp.asarray([max_seq_len // 2]) + assert seq_lens.shape == (1,) + + # q_shape = (batch_size=1, num_q_heads=1, head_dim=1) + queries = jnp.asarray([[[1.2]]], dtype) + assert queries.shape == (1, 1, 1) + + # kv_shape = (batch_size=1, num_kv_heads=1, max_seq_len=4, head_dim=1) + k_pages = jnp.asarray([[[[0.1], [0.2], [0.3], [0.4]]]], dtype) + v_pages = jnp.asarray([[[[4.0], [3.0], [2.0], [1.0]]]], dtype) + assert k_pages.shape == (1, 1, 4, 1) + assert v_pages.shape == k_pages.shape + + # q*k: [[[ [.12, .24, .36, .48] ]]] + # masked: [[[ [.12, .24, -inf, -inf] ]]] + # softmax: [[[ [.47, .53, 0, 0] ]]] + # softmax(q*k) * v: .47*4 + .53*3 + 0*... = 3.47 + attention = jnp.asarray([[[3.47]]], dtype) + assert attention.shape == queries.shape + return seq_lens, queries, k_pages, v_pages, attention + + +def _generate_qkv_with_one_q_head( + dtype: jnp.dtype, +) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: + """Generates queries with one query head, kv pages, and attention.""" + max_seq_len = 4 + seq_lens = jnp.asarray([max_seq_len - 1]) + assert seq_lens.shape == (1,) + + # q_shape = (batch_size=1, num_q_heads=1, head_dim=1) + queries = jnp.asarray([[[1.7]]], dtype) + assert queries.shape == (1, 1, 1) + + # kv_shape = (batch_size=1, num_kv_heads=1, max_seq_len=4, head_dim=1) + k_pages = jnp.asarray([[[[0.12], [0.23], [0.34], [0.45]]]], dtype) + v_pages = jnp.asarray([[[[4.32], [3.21], [2.10], [1.09]]]], dtype) + assert k_pages.shape == (1, 1, 4, 1) + assert v_pages.shape == k_pages.shape + + # q*k: [[[ [.204, .391, .578, .765] ]]] + # masked: [[[ [.204, .391, .578, -inf] ]]] + # softmax: [[[ [.273, .330, .397, 0] ]]] + # softmax(q*k) * v: .273*4.32 + .330*3.21 + .397*2.10 + 0*... = 3.0723 + attention = jnp.asarray([[[3.0723]]], dtype) + assert attention.shape == queries.shape + return seq_lens, queries, k_pages, v_pages, attention + + +def _generate_qkv_with_two_q_heads( + dtype: jnp.dtype, +) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: + """Generates queries with two query heads, kv pages, and attention.""" + max_seq_len = 4 + seq_lens = jnp.asarray([max_seq_len]) + assert seq_lens.shape == (1,) + + # q_shape = (batch_size=1, num_q_heads=2, head_dim=1) + queries = jnp.asarray([[[1.3], [9.7]]], dtype) + assert queries.shape == (1, 2, 1) + + # kv_shape = (batch_size=1, num_kv_heads=1, max_seq_len=4, head_dim=1) + k_pages = jnp.asarray([[[[0.12], [0.23], [0.34], [0.45]]]], dtype) + v_pages = jnp.asarray([[[[4.32], [3.21], [2.10], [1.09]]]], dtype) + assert k_pages.shape == (1, 1, 4, 1) + assert v_pages.shape == k_pages.shape + + # q*k: [[[ [ .156, .299, .442, .585], + # [1.164, 2.231, 3.298, 4.365] ]]] + # softmax: [[[ [ .199, .230, .265, .306], + # [ .027, .079, .229, .665] ]]] + # softmax(q*k) * v: .199*4.32 + .230*3.21 + .265*2.10 + .306*1.09 = 2.488 + # softmax(q*k) * v: .027*4.32 + .079*3.21 + .229*2.10 + .665*1.09 = 1.576 + attention = jnp.asarray([[[2.488], [1.576]]], dtype) + assert attention.shape == queries.shape + return seq_lens, queries, k_pages, v_pages, attention + + +def _generate_qkv_with_head_dim_two( + dtype: jnp.dtype, +) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: + """Generates queries, kv pages, and attention with head_dim=2.""" + max_seq_len = 4 + seq_lens = jnp.asarray([max_seq_len // 2]) + assert seq_lens.shape == (1,) + + # q_shape = (batch_size=1, num_q_heads=1, head_dim=2) + queries = jnp.asarray([[[1.2, 9.0]]], dtype) + assert queries.shape == (1, 1, 2) + + # kv_shape = (batch_size=1, num_kv_heads=1, max_seq_len=4, head_dim=2) + k_pages = jnp.asarray( + [[[[0.1, 0.2], [0.2, 0.3], [0.3, 0.4], [0.4, 0.5]]]], dtype + ) + v_pages = jnp.asarray( + [[[[4.0, 5.0], [3.0, 6.0], [2.0, 7.0], [1.0, 8.0]]]], dtype + ) + assert k_pages.shape == (1, 1, 4, 2) + assert v_pages.shape == k_pages.shape + + # q*k: [[[ [ 1.92, 2.94, 3.96, 4.98] ]]] + # masked: [[[ [ 1.92, 2.94, -inf, -inf] ]]] + # softmax: [[[ [ .265, .735, 0, 0] ]]] + # softmax(q*k) * v: .265*4 + 0.735*3 + 0*... = 3.265 + # softmax(q*k) * v: .265*5 + 0.735*6 + 0*... = 5.735 + attention = jnp.asarray([[[3.265, 5.735]]], dtype) + assert attention.shape == queries.shape + return seq_lens, queries, k_pages, v_pages, attention def _generate_qkv( + dtype: jnp.dtype, + case: int, +) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: + match case: + case 0: + return _generate_qkv_simplest(dtype) + case 1: + return _generate_qkv_with_one_q_head(dtype) + case 2: + return _generate_qkv_with_two_q_heads(dtype) + case 3: + return _generate_qkv_with_head_dim_two(dtype) + case _: + raise ValueError(f"Unsupported case: {case}") + + +@jtu.with_config(jax_numpy_dtype_promotion="standard") +class JaxGroupedQueryAttentionReferenceTest(jtu.JaxTestCase): + + @parameterized.product( + dtype=(jnp.float32, jnp.bfloat16), + case=(0, 1, 2, 3), + ) + def test_grouped_query_attention(self, dtype: jnp.dtype, case: int): + # generate queries, kv pages, and seq_lens + seq_lens, queries, k_pages, v_pages, expected = _generate_qkv(dtype, case) + jax.debug.print("seq_lens: {seq_lens}", seq_lens=seq_lens) + jax.debug.print("queries: {queries}", queries=queries) + jax.debug.print("k_pages: {k_pages}", k_pages=k_pages) + jax.debug.print("v_pages: {v_pages}", v_pages=v_pages) + jax.debug.print("expected: {expected}", expected=expected) + + # calculate grouped query attention + attention = util.grouped_query_attention_reference( + queries, k_pages, v_pages, seq_lens + ) + jax.debug.print("attention: {attention}", attention=attention) + + # compare the results + atol, rtol = (3e-3, 5e-3) if dtype == jnp.bfloat16 else (2e-4, 2e-4) + self.assertAllClose(attention, expected, atol=atol, rtol=rtol) + + +def _generate_random_qkv( seq_lens, page_size, max_seq_len, num_kv_heads, - num_heads, + num_q_heads, head_dim, prng_key, dtype=jnp.float32, @@ -55,7 +212,7 @@ def _generate_qkv( page_indices = jnp.arange(batch_size * pages_per_sequence, dtype=jnp.int32) page_indices = jax.random.permutation(k3, page_indices, independent=True) page_indices = page_indices.reshape(batch_size, pages_per_sequence) - q = jax.random.normal(k4, (batch_size, num_heads, head_dim), dtype=dtype) + q = jax.random.normal(k4, (batch_size, num_q_heads, head_dim), dtype=dtype) return q, k_pages, v_pages, page_indices @@ -64,7 +221,7 @@ def _reconstruct_kv(page_indices, pages): pages = quantization_utils.unquantize_from_int8(pages, dtype=jnp.float32) batch_size = page_indices.shape[0] - num_heads, _, _, head_dim = pages.shape + num_kv_heads, _, _, head_dim = pages.shape def per_sequence_page_gather(pages, page_indices): return jnp.take(pages, page_indices, 1) @@ -72,32 +229,7 @@ def per_sequence_page_gather(pages, page_indices): gathered = jax.vmap(per_sequence_page_gather, in_axes=(None, 0))( pages, page_indices ) - return gathered.reshape(batch_size, num_heads, -1, head_dim) - - -def _grouped_query_attention_reference(q, k, v, lengths, attn_logits_soft_cap): - batch_size, num_heads, head_dim = q.shape - _, num_kv_heads, max_seq_len, _ = k.shape - assert k.shape == v.shape - assert num_heads % num_kv_heads == 0 - q = q.reshape(batch_size, num_kv_heads, num_heads // num_kv_heads, head_dim) - - if isinstance(k, quantization_utils.QuantizedTensor): - k = quantization_utils.unquantize_from_int8(k, dtype=jnp.float32) - if isinstance(v, quantization_utils.QuantizedTensor): - v = quantization_utils.unquantize_from_int8(v, dtype=jnp.float32) - - logits = jnp.einsum( - "bhgd,bhtd->bhgt", q.astype(jnp.float32), k.astype(jnp.float32) - ) - if attn_logits_soft_cap is not None: - logits = jnp.tanh(logits / attn_logits_soft_cap) * attn_logits_soft_cap - mask = jnp.arange(max_seq_len)[None] < lengths[:, None] - mask_value = -0.7 * float(np.finfo(np.dtype("float32")).max) - logits = logits + jnp.where(mask, 0.0, mask_value)[:, None, None, :] - weights = jax.nn.softmax(logits, axis=-1) - o = jnp.einsum("bhgt,bhtd->bhgd", weights.astype(v.dtype), v) - return o.reshape(batch_size, num_heads, head_dim) + return gathered.reshape(batch_size, num_kv_heads, -1, head_dim) def _megacore_enabled(): @@ -146,10 +278,13 @@ def test_paged_attention( self.skipTest("Megacore is only available on TPU v4 or TPU v5p") if num_kv_heads % 2 != 0 and megacore_mode == "kv_head": self.skipTest("Skip kv_head megacore mode when num_kv_heads is odd") + if (jtu.is_device_tpu(version=7, variant='x') and dtype == jnp.bfloat16 and + num_kv_heads == 8): + self.skipTest("Test does not work with large second-minor layout.") max_kv_len = 2048 block_size = 512 seq_lens = np.asarray([0, 3, 256, 513, 1023, 2048]) - q, k_pages, v_pages, page_indices = _generate_qkv( + q, k_pages, v_pages, page_indices = _generate_random_qkv( seq_lens, page_size, max_kv_len, @@ -172,8 +307,9 @@ def test_paged_attention( ) k = _reconstruct_kv(page_indices, k_pages) v = _reconstruct_kv(page_indices, v_pages) - o_ref = _grouped_query_attention_reference( - q, k, v, seq_lens, attn_logits_soft_cap) + o_ref = util.grouped_query_attention_reference( + q, k, v, seq_lens, attn_logits_soft_cap + ) if q_kv_head_ratio > 1: atol, rtol = 1e-2, 2e-2 @@ -188,4 +324,5 @@ def test_paged_attention( if __name__ == "__main__": + jax.config.config_with_absl() absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/tpu_pallas_async_test.py b/tests/pallas/tpu_pallas_async_test.py index 3dfc9bf1637a..e9f2cd45e5ad 100644 --- a/tests/pallas/tpu_pallas_async_test.py +++ b/tests/pallas/tpu_pallas_async_test.py @@ -22,7 +22,7 @@ from jax._src import test_util as jtu from jax._src.state import discharge as state_discharge from jax.experimental import pallas as pl -from jax.experimental import shard_map +from jax._src import shard_map from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp import numpy as np @@ -37,7 +37,7 @@ def make_async_copy(target_memory_space=None): if target_memory_space is None: - target_memory_space = pltpu.ANY + target_memory_space = pl.ANY @jax.named_call def copy_start(x: jax.Array) -> tuple[jax.Array, Future]: @@ -53,10 +53,10 @@ def copy_start_kernel(x_ref, aliased_x_ref, o_ref, sem): pltpu.SemaphoreType.DMA(()), ), in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], out_specs=( - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), pl.BlockSpec(memory_space=target_memory_space), pl.BlockSpec(memory_space=pltpu.SEMAPHORE), ), @@ -76,7 +76,7 @@ def copy_done_kernel(x_ref, o_ref, sem, aliased_o_ref): copy_done_kernel, out_shape=target_memory_space(x.shape, x.dtype), # out in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), pl.BlockSpec(memory_space=target_memory_space), pl.BlockSpec(memory_space=pltpu.SEMAPHORE), ], @@ -109,11 +109,11 @@ def async_slice_start(x: jax.Array) -> tuple[jax.Array, Future]: pltpu.SemaphoreType.DMA(()), ), in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], out_specs=( - pl.BlockSpec(memory_space=pltpu.ANY), - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pl.ANY), pl.BlockSpec(memory_space=pltpu.SEMAPHORE), ), input_output_aliases={0: 0}, @@ -129,11 +129,11 @@ def async_slice_done( async_slice_done_kernel, out_shape=(jax.ShapeDtypeStruct(x.shape[1:], x.dtype)), # out in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pl.ANY), pl.BlockSpec(memory_space=pltpu.SEMAPHORE), ], - out_specs=(pl.BlockSpec(memory_space=pltpu.ANY)), + out_specs=(pl.BlockSpec(memory_space=pl.ANY)), input_output_aliases={1: 0}, )(x, out, sem) return out @@ -164,11 +164,11 @@ def async_dslice_start(x: jax.Array) -> tuple[jax.Array, Future]: grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=1, in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], out_specs=( - pl.BlockSpec(memory_space=pltpu.ANY), - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pl.ANY), pl.BlockSpec(memory_space=pltpu.SEMAPHORE), ), ), @@ -185,11 +185,11 @@ def async_dslice_done( async_dslice_done_kernel, out_shape=(jax.ShapeDtypeStruct(x.shape[1:], x.dtype)), # out in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pl.ANY), pl.BlockSpec(memory_space=pltpu.SEMAPHORE), ], - out_specs=(pl.BlockSpec(memory_space=pltpu.ANY)), + out_specs=(pl.BlockSpec(memory_space=pl.ANY)), input_output_aliases={1: 0}, )(x, out, sem) return out @@ -388,17 +388,169 @@ def body(i, carry): y = f(x) np.testing.assert_array_equal(y, x) + @parameterized.product(joint_axis=[True, False]) + def test_device_id_as_axis_dict(self, joint_axis): + if jax.device_count() < 2: + self.skipTest('Requires at least 2 devices for a 2d mesh.') + xdim, ydim = 2, jax.device_count() // 2 + mesh = jax.make_mesh( + (xdim, ydim), + ('x', 'y'), + axis_types=(jax.sharding.AxisType.Auto,) * 2, + ) + + xlocal, ylocal = 8, 128 + if joint_axis: + axis_name = ('x', 'y') + pspec = P(('x', 'y'), None) + input_arr = jax.device_put( + jax.random.uniform(jax.random.key(0), (xlocal * xdim * ydim, ylocal)), + jax.sharding.NamedSharding(mesh, pspec), + ) + else: + axis_name = 'x' + pspec = P('x', 'y') + input_arr = jax.device_put( + jax.random.uniform(jax.random.key(0), (xlocal * xdim, ylocal * ydim)), + jax.sharding.NamedSharding(mesh, pspec), + ) + + def copy_kernel(input_ref, output_ref, send_sem, recv_sem, local_copy_sem): + xid = jax.lax.axis_index(axis_name) + x0_local_copy = pltpu.make_async_copy( + src_ref=input_ref, dst_ref=output_ref, sem=local_copy_sem + ) + copy_x0_to_x1 = pltpu.make_async_remote_copy( + src_ref=input_ref, + dst_ref=output_ref, + send_sem=send_sem, + recv_sem=recv_sem, + device_id={axis_name: 1}, + ) + + @pl.when(xid == 0) + def _(): + copy_x0_to_x1.start() + x0_local_copy.start() + x0_local_copy.wait() + copy_x0_to_x1.wait_send() + @pl.when(xid == 1) + def _(): + copy_x0_to_x1.wait_recv() + + copy = pl.pallas_call( + copy_kernel, + out_shape=jax.ShapeDtypeStruct((xlocal, ylocal), jnp.float32), + in_specs=[pl.BlockSpec(memory_space=pl.ANY),], + out_specs=pl.BlockSpec(memory_space=pl.ANY), + scratch_shapes=[pltpu.SemaphoreType.DMA] * 3, + ) + + # Wrap the kernel within a shard_map to call. + pallas_out = jax.jit( + jax.shard_map( + copy, mesh=mesh, in_specs=pspec, out_specs=pspec, check_vma=False + ) + )(input_arr) + + # x=1 devices are flushed with x=0 device contents + np.testing.assert_array_equal(input_arr[:xlocal], pallas_out[:xlocal]) + np.testing.assert_array_equal(pallas_out[:xlocal], + pallas_out[xlocal:(2*xlocal)]) + + def test_axis_dict_with_core_single_device(self): + if jax.device_count() > 2 or (jax.devices()[0].num_cores) != 2: + self.skipTest('Testing single device two cores') + mesh = jax.make_mesh( + (jax.device_count(),), + ('device',), + axis_types=(jax.sharding.AxisType.Auto,), + ) + ddim = jax.device_count() + tcmesh = pltpu.create_tensorcore_mesh('core') + pspec = P('device', None) + sharding = jax.sharding.NamedSharding(mesh, pspec) + + # Array is fully sharded. + xlocal, ylocal = 8, 256 + input_arr = jnp.arange(xlocal * ddim * ylocal, dtype=jnp.int32).reshape( + (xlocal * ddim, ylocal) + ) + input_arr = jax.device_put(input_arr, sharding) + + def core_copy(refs): + in_ref, out_ref = refs + + @pl.core_map(tcmesh, compiler_params=pltpu.CompilerParams(collective_id=7)) + def _(): + num_cores = jax.lax.axis_size('core') + slc_size = ylocal // num_cores + vmem_shape = (xlocal, slc_size) + + # This runs on every core, for every vmem iterations + def alloc(out_vmem_ref, sem, send_sem, recv_sem): + core_index = jax.lax.axis_index('core') + slc = pl.ds(core_index * slc_size, slc_size) + + # Make sure all cores have entered run_scoped. + sem0 = pltpu.get_barrier_semaphore() + for i in range(ddim): + for j in range(num_cores): + pltpu.semaphore_signal( + sem0, 1, device_id={'device': i, 'core': j}, + device_id_type=pltpu.DeviceIdType.MESH) + pltpu.semaphore_wait(sem0, ddim * num_cores) + + # Identity function by default + pltpu.async_copy(in_ref.at[:, slc], out_ref.at[:, slc], sem).wait() + + copy_c0_to_c1 = pltpu.make_async_remote_copy( + src_ref=in_ref.at[:, slc], + dst_ref=out_vmem_ref, + send_sem=send_sem, + recv_sem=recv_sem, + device_id={'core': 1}, + device_id_type=pltpu.DeviceIdType.MESH, + ) + + @pl.when(core_index == 0) + def _(): + copy_c0_to_c1.start() + copy_c0_to_c1.wait_send() + + @pl.when(core_index == 1) + def _(): + copy_c0_to_c1.wait_recv() + pltpu.async_copy(out_vmem_ref, out_ref.at[:, slc], sem).wait() + + pl.run_scoped( + alloc, + pltpu.VMEM(vmem_shape, out_ref.dtype), + *([pltpu.SemaphoreType.DMA] * 3), + ) + + @partial(jax.shard_map, mesh=mesh, in_specs=pspec, out_specs=pspec, check_vma=False) + def run_core_kernel(input): + output = jnp.zeros_like(input) + _, output = pl.run_state(core_copy)((input, output)) + return output + pallas_out = jax.jit(run_core_kernel)(input_arr) + + # The device=0 core=1 slice was flushed with device=0 core=0 contents + np.testing.assert_array_equal(pallas_out[:, 128:], input_arr[:, :128]) + np.testing.assert_array_equal(pallas_out[:, :128], input_arr[:, :128]) + def make_async_remote_copy(axis_name: str, direction: str = 'right', target_memory_space=None): if target_memory_space is None: - target_memory_space = pltpu.ANY + target_memory_space = pl.ANY @jax.named_call def copy_start(x: jax.Array) -> tuple[jax.Array, Future]: def copy_start_kernel(x_ref, aliased_x_ref, o_ref, send_sem, recv_sem): del aliased_x_ref - axis_size = jax.lax.psum(1, axis_name) + axis_size = jax.lax.axis_size(axis_name) left_neighbor = jax.lax.rem( jax.lax.axis_index(axis_name) - 1 + axis_size, axis_size ) @@ -412,7 +564,7 @@ def copy_start_kernel(x_ref, aliased_x_ref, o_ref, send_sem, recv_sem): src_neighbor = right_neighbor dst_neighbor = left_neighbor barrier_sem = pltpu.get_barrier_semaphore() - pltpu.semaphore_signal(barrier_sem, device_id=src_neighbor, core_index=0) + pltpu.semaphore_signal(barrier_sem, device_id=src_neighbor) pltpu.semaphore_wait(barrier_sem, 1) pltpu.make_async_remote_copy( x_ref, o_ref, send_sem, recv_sem, device_id=dst_neighbor, @@ -427,16 +579,16 @@ def copy_start_kernel(x_ref, aliased_x_ref, o_ref, send_sem, recv_sem): pltpu.SemaphoreType.DMA(()), # recv_sem ), in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], out_specs=( - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), pl.BlockSpec(memory_space=target_memory_space), pl.BlockSpec(memory_space=pltpu.SEMAPHORE), pl.BlockSpec(memory_space=pltpu.SEMAPHORE), ), input_output_aliases={0: 0}, - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( collective_id=0, has_side_effects=True ), )(x) @@ -454,10 +606,10 @@ def send_done_kernel(x_ref, send_sem, aliased_o_ref): send_done_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), # out in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), pl.BlockSpec(memory_space=pltpu.SEMAPHORE), ], - out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), input_output_aliases={0: 0}, )(x, send_sem) return x @@ -474,7 +626,7 @@ def send_done_kernel(x_ref, o_ref, send_sem, aliased_o_ref): send_done_kernel, out_shape=target_memory_space(x.shape, x.dtype), # out in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), pl.BlockSpec(memory_space=target_memory_space), pl.BlockSpec(memory_space=pltpu.SEMAPHORE), ], @@ -492,7 +644,7 @@ def copy_start(x: jax.Array) -> tuple[jax.Array, Future]: def copy_start_kernel(x_ref, aliased_x_ref, o_ref, left_sems, right_sems): del aliased_x_ref - axis_size = jax.lax.psum(1, axis_name) + axis_size = jax.lax.axis_size(axis_name) left_neighbor = jax.lax.rem( jax.lax.axis_index(axis_name) - 1 + axis_size, axis_size ) @@ -500,10 +652,8 @@ def copy_start_kernel(x_ref, aliased_x_ref, o_ref, left_sems, right_sems): jax.lax.axis_index(axis_name) + 1, axis_size ) barrier_sem = pltpu.get_barrier_semaphore() - pltpu.semaphore_signal(barrier_sem, device_id=left_neighbor, core_index=0) - pltpu.semaphore_signal( - barrier_sem, device_id=right_neighbor, core_index=0 - ) + pltpu.semaphore_signal(barrier_sem, device_id=left_neighbor) + pltpu.semaphore_signal(barrier_sem, device_id=right_neighbor) pltpu.semaphore_wait(barrier_sem, 2) assert x.shape[0] % 2 == 0, x.shape pltpu.make_async_remote_copy( @@ -525,21 +675,21 @@ def copy_start_kernel(x_ref, aliased_x_ref, o_ref, left_sems, right_sems): copy_start_kernel, out_shape=( jax.ShapeDtypeStruct(x.shape, x.dtype), # aliased x - pltpu.ANY(x.shape, x.dtype), # out + pl.ANY(x.shape, x.dtype), # out (pltpu.SemaphoreType.DMA(()),) * 2, # left_sems (pltpu.SemaphoreType.DMA(()),) * 2, # right_sems ), in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], out_specs=( - pl.BlockSpec(memory_space=pltpu.ANY), - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pl.ANY), (pl.BlockSpec(memory_space=pltpu.SEMAPHORE),) * 2, (pl.BlockSpec(memory_space=pltpu.SEMAPHORE),) * 2, ), input_output_aliases={0: 0}, - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( collective_id=0, has_side_effects=False ), )(x) @@ -566,11 +716,11 @@ def send_done_kernel(x_ref, send_left_sem, send_right_sem, aliased_o_ref): send_done_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), # out in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), pl.BlockSpec(memory_space=pltpu.SEMAPHORE), pl.BlockSpec(memory_space=pltpu.SEMAPHORE), ], - out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), input_output_aliases={0: 0}, )(x, send_left_sem, send_right_sem) return x @@ -597,12 +747,12 @@ def recv_done_kernel(o_ref, x_ref, recv_left_sem, recv_right_sem, recv_done_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), # out in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pl.ANY), pl.BlockSpec(memory_space=pltpu.SEMAPHORE), pl.BlockSpec(memory_space=pltpu.SEMAPHORE), ], - out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), input_output_aliases={0: 0}, )(out, x, recv_left_sem, recv_right_sem) return out @@ -620,12 +770,16 @@ def setUp(self): def test_basic_remote_copy(self): - mesh = jax.make_mesh((jax.device_count(),), ('x',)) + mesh = jax.make_mesh( + (jax.device_count(),), + ('x',), + axis_types=(jax.sharding.AxisType.Auto,), + ) @jax.jit @partial( shard_map.shard_map, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'), - check_rep=False, + check_vma=False, ) def f(x): copy_start, send_done, recv_done = make_async_remote_copy('x') @@ -643,12 +797,16 @@ def f(x): def test_multi_remote_copy(self): - mesh = jax.make_mesh((jax.device_count(),), ('x',)) + mesh = jax.make_mesh( + (jax.device_count(),), + ('x',), + axis_types=(jax.sharding.AxisType.Auto,), + ) @jax.jit @partial( shard_map.shard_map, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'), - check_rep=False, + check_vma=False, ) def f(x): copy_start, send_done, recv_done = make_async_remote_copy( @@ -676,12 +834,16 @@ def f(x): def test_basic_collective_permute_loop(self): - mesh = jax.make_mesh((jax.device_count(),), ('x',)) + mesh = jax.make_mesh( + (jax.device_count(),), + ('x',), + axis_types=(jax.sharding.AxisType.Auto,), + ) @jax.jit @partial( shard_map.shard_map, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'), - check_rep=False, + check_vma=False, ) def f(x): copy_start, send_done, recv_done = make_async_remote_copy('x') @@ -701,12 +863,16 @@ def body(_, x): def test_staggered_collective_permute_loop(self): - mesh = jax.make_mesh((jax.device_count(),), ('x',)) + mesh = jax.make_mesh( + (jax.device_count(),), + ('x',), + axis_types=(jax.sharding.AxisType.Auto,), + ) @jax.jit @partial( shard_map.shard_map, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'), - check_rep=False, + check_vma=False, ) def f(x): assert x.shape[0] == 1 @@ -734,12 +900,16 @@ def body(_, carry): np.testing.assert_array_equal(y, expected) def test_bidi_collective_permute_loop(self): - mesh = jax.make_mesh((jax.device_count(),), ('x',)) + mesh = jax.make_mesh( + (jax.device_count(),), + ('x',), + axis_types=(jax.sharding.AxisType.Auto,), + ) @jax.jit @partial( shard_map.shard_map, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'), - check_rep=False, + check_vma=False, ) def f(x): assert x.shape[0] == 1 diff --git a/tests/pallas/tpu_pallas_call_print_test.py b/tests/pallas/tpu_pallas_call_print_test.py new file mode 100644 index 000000000000..5978a886576f --- /dev/null +++ b/tests/pallas/tpu_pallas_call_print_test.py @@ -0,0 +1,155 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test TPU-specific extensions to pallas print call.""" + +import functools +import re +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax._src import test_util as jtu +from jax._src.pallas import pallas_test_util as ptu +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu +import jax.numpy as jnp +import numpy as np + +jax.config.parse_flags_with_absl() + +P = jax.sharding.PartitionSpec + +partial = functools.partial + + +@jtu.thread_unsafe_test_class() # debug print test is not thread safe +class PallasCallPrintTest(ptu.PallasTPUTest): + + def test_debug_print(self): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + ) + def kernel(x_ref, o_ref): + pl.debug_print('It works!') + + x = jnp.arange(8 * 128, dtype=jnp.float32).reshape((8, 128)) + compiled_kernel = ( + jax.jit(kernel) + .lower(x) + .compile({'xla_tpu_enable_log_recorder': 'true'}) + ) + with jtu.capture_stderr() as get_output: + jax.block_until_ready(compiled_kernel(x)) + self.assertIn('It works!', get_output()) + + def test_debug_print_in_index_map(self): + def index_map(i): + pl.debug_print('It works!') + return (i, 0) + + @functools.partial( + self.pallas_call, + grid=(1,), + in_specs=(pl.BlockSpec(index_map=index_map),), + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + ) + def kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] + + x = jnp.arange(8 * 128, dtype=jnp.float32).reshape((8, 128)) + compiled_kernel = ( + jax.jit(kernel) + .lower(x) + .compile({'xla_tpu_enable_log_recorder': 'true'}) + ) + with jtu.capture_stderr() as get_output: + jax.block_until_ready(compiled_kernel(x)) + self.assertIn('It works!', get_output()) + + @parameterized.product(dtype=[jnp.int32, jnp.float32]) + def test_debug_print_with_values(self, dtype): + @functools.partial( + self.pallas_call, + in_specs=(pl.BlockSpec(memory_space=pltpu.SMEM),), + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + ) + def kernel(x_ref, o_ref): + if dtype == jnp.int32: + pl.debug_print('BEGIN1 x[0] == {}', x_ref[0]) + pl.debug_print( + 'BEGIN2 x[0] == {} ; x[1] == {} ; END', x_ref[0], x_ref[1] + ) + else: + pl.debug_print('BEGIN1 x[0] == ', x_ref[0]) + + x = jnp.array([42, 24], dtype=dtype) + compiled_kernel = ( + jax.jit(kernel) + .lower(x) + .compile({'xla_tpu_enable_log_recorder': 'true'}) + ) + with jtu.capture_stderr() as get_output: + jax.block_until_ready(compiled_kernel(x)) + output = get_output() + if dtype == jnp.int32: + self.assertIn('BEGIN1 x[0] == 42', output) + self.assertIn('BEGIN2 x[0] == 42 ; x[1] == 24 ; END', output) + else: + self.assertIn('BEGIN1 x[0] == f32[] 42', output) + + @parameterized.named_parameters( + (f"{'_'.join(map(str, shape))}_{dtype.__name__}", shape, dtype) + for shape in ( + (2, 8, 128), + # test unaligned shapes + (3,), + (3, 4), + (2, 3, 4), + (2, 9, 129), + ) + for dtype in (jnp.int32, jnp.uint32, jnp.float32) + ) + def test_debug_print_vector(self, shape, dtype): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct(shape, dtype), + ) + def kernel(x_ref, o_ref): + pl.debug_print("{}", x_ref[...]) + o_ref[...] = x_ref[...] + + n = np.prod(shape) + x = jnp.arange(n, dtype=dtype).reshape(shape) + compiled_kernel = ( + jax.jit(kernel) + .lower(x) + .compile({"xla_tpu_enable_log_recorder": "true"}) + ) + with jtu.capture_stderr() as get_output: + jax.block_until_ready(compiled_kernel(x)) + output = get_output() + numbers = [ + int(num) + for line in output.splitlines() + if (match := re.search(r"\{(.*)", line)) # extract contents after `{` + for num in re.findall(r"\d+", match.group(1)) + ] + # Check if the numbers in the output match the values generated by `arange`. + self.assertLen(numbers, n) + self.assertTrue(all(num == i for i, num in enumerate(numbers))) + + +if __name__ == '__main__': + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/tpu_pallas_distributed_test.py b/tests/pallas/tpu_pallas_distributed_test.py index f7d7daf1874f..35c5bd27c35c 100644 --- a/tests/pallas/tpu_pallas_distributed_test.py +++ b/tests/pallas/tpu_pallas_distributed_test.py @@ -13,6 +13,7 @@ # limitations under the License. import functools +import json import os import tempfile from absl.testing import absltest @@ -22,7 +23,7 @@ from jax._src import test_util as jtu from jax.experimental import mesh_utils from jax.experimental import pallas as pl -from jax.experimental import shard_map +from jax._src import shard_map from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp import numpy as np @@ -40,19 +41,21 @@ def setUp(self): super().setUp() if jax.device_count() < 2: self.skipTest('Only >=2 devices are supported.') - if not jtu.is_device_tpu(5, 'e'): - self.skipTest('Only works with TPU v5e.') + if not jtu.is_device_tpu_at_least(4): + self.skipTest('Only TPUs v4+ are supported.') + if jtu.is_device_tpu(7, 'x'): + # TODO(sharadmv): Enable these tests. + self.skipTest('Tests time out on TPUs v7x.') @parameterized.named_parameters( - ('vmem', pltpu.TPUMemorySpace.VMEM), - ('hbm', pltpu.TPUMemorySpace.ANY), + ('vmem', pltpu.VMEM), + ('hbm', pl.ANY), ) def test_basic_remote_vmem_dma(self, mem): # Implements very simple collective permute def kernel(x_ref, y_ref): def body(ready_sem, send_sem, recv_sem): - dev_id = pltpu.device_id() - other_dev_id = 1 - dev_id + other_dev_id = 1 - lax.axis_index('x') pltpu.semaphore_signal(ready_sem, device_id=other_dev_id, device_id_type=pltpu.DeviceIdType.LOGICAL) pltpu.semaphore_wait(ready_sem) @@ -77,19 +80,66 @@ def body(x): kernel, in_specs=[pl.BlockSpec(memory_space=mem)], out_specs=pl.BlockSpec(memory_space=mem), - out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32, vma=frozenset('x')), )(x) devices = jax.devices()[:2] mesh = jax.sharding.Mesh(devices, ['x']) - y = jax.jit( + f = jax.jit( shard_map.shard_map( - body, mesh, in_specs=P('x'), out_specs=P('x'), check_rep=False + body, mesh=mesh, in_specs=P('x'), out_specs=P('x'), ) - )(x) + ) + jaxpr = f.trace(x).jaxpr + self.assertNotIn('pvary', str(jaxpr)) + y = f(x) expected = jnp.concatenate([x[8:], x[:8]]) np.testing.assert_allclose(y, expected) + def test_vma_error(self): + def kernel(x_ref, y_ref): + def body(ready_sem, send_sem, recv_sem): + other_dev_id = 1 - lax.axis_index('x') + pltpu.semaphore_signal(ready_sem, device_id=other_dev_id, + device_id_type=pltpu.DeviceIdType.LOGICAL) + pltpu.semaphore_wait(ready_sem) + copy_done = pltpu.async_remote_copy( + x_ref, y_ref, send_sem, recv_sem, other_dev_id, + device_id_type=pltpu.DeviceIdType.LOGICAL, + ) + copy_done.wait_send() + copy_done.wait_recv() + + pl.run_scoped( + body, + pltpu.SemaphoreType.REGULAR, + pltpu.SemaphoreType.DMA, + pltpu.SemaphoreType.DMA, + ) + + x = jnp.arange(2 * 8 * 128.0).reshape((2 * 8, 128)) + + def body(x): + return pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=pl.ANY)], + out_specs=pl.BlockSpec(memory_space=pl.ANY), + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + )(x) + + devices = jax.devices()[:2] + mesh = jax.sharding.Mesh(devices, ['x']) + f = jax.jit( + shard_map.shard_map( + body, mesh=mesh, in_specs=P('x'), out_specs=P('x'), + ) + ) + with self.assertRaisesRegex( + ValueError, + 'When `check_vma=True` on `jax.shard_map`, `vma` on' + ' `jax.ShapeDtypeStruct` must not be `None`'): + f(x) + @parameterized.named_parameters( ('left', 'left'), ('right', 'right') @@ -99,7 +149,7 @@ def test_pallas_call_axis_index(self, direction): def kernel(x_ref, y_ref): def body(ready_sem, send_sem, recv_sem): my_id = lax.axis_index('x') - num_devices = lax.psum(1, 'x') + num_devices = lax.axis_size('x') if direction == 'right': neighbor = lax.rem(my_id + 1, num_devices) else: @@ -127,8 +177,8 @@ def body(ready_sem, send_sem, recv_sem): def body(x): return pl.pallas_call( kernel, - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM)], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + in_specs=[pl.BlockSpec(memory_space=pltpu.VMEM)], + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), out_shape=x, )(x) @@ -137,7 +187,7 @@ def body(x): mesh = jax.sharding.Mesh(device_mesh, ['x']) y = jax.jit( shard_map.shard_map( - body, mesh, in_specs=P('x'), out_specs=P('x'), check_rep=False + body, mesh=mesh, in_specs=P('x'), out_specs=P('x'), check_vma=False ) )(x) if direction == 'right': @@ -153,7 +203,7 @@ def kernel(x_ref, y_ref): def body(ready_sem, send_sem, recv_sem): my_id = lax.axis_index('x') my_other_id = lax.axis_index('y') - axis_size = lax.psum(1, 'x') + axis_size = lax.axis_size('x') if direction == 'right': neighbor = lax.rem(my_id + 1, axis_size) else: @@ -181,8 +231,8 @@ def body(ready_sem, send_sem, recv_sem): def body(x): return pl.pallas_call( kernel, - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM)], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + in_specs=[pl.BlockSpec(memory_space=pltpu.VMEM)], + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), out_shape=x, )(x) @@ -193,10 +243,10 @@ def body(x): y = jax.jit( shard_map.shard_map( body, - mesh, + mesh=mesh, in_specs=P('x', None), out_specs=P('x', None), - check_rep=False, + check_vma=False, ) )(x) if direction == 'right': @@ -209,7 +259,7 @@ def test_barrier_semaphore(self): def kernel(x_ref, y_ref): def body(ready_sem, send_sem, recv_sem): my_id = lax.axis_index('x') - num_devices = lax.psum(1, 'x') + num_devices = lax.axis_size('x') neighbor = lax.rem(my_id + 1, num_devices) barrier_sem = pltpu.get_barrier_semaphore() pltpu.semaphore_signal(barrier_sem, device_id=neighbor) @@ -233,10 +283,10 @@ def body(ready_sem, send_sem, recv_sem): def body(x): return pl.pallas_call( kernel, - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM)], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + in_specs=[pl.BlockSpec(memory_space=pltpu.VMEM)], + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), out_shape=x, - compiler_params=dict(mosaic=dict(collective_id=0)), + compiler_params=pltpu.CompilerParams(collective_id=0), )(x) device_mesh = mesh_utils.create_device_mesh( @@ -244,12 +294,226 @@ def body(x): mesh = jax.sharding.Mesh(device_mesh, ['x']) y = jax.jit( shard_map.shard_map( - body, mesh, in_specs=P('x'), out_specs=P('x'), check_rep=False + body, mesh=mesh, in_specs=P('x'), out_specs=P('x'), check_vma=False ) )(x) expected = jnp.concatenate([x[-8:], x[:-8]]) np.testing.assert_allclose(y, expected) + def test_barrier_semaphore_no_axis_name(self): + def kernel(x_ref, y_ref): + num_devices = lax.axis_size('x') + barrier_sem = pltpu.get_barrier_semaphore() + for i in range(num_devices): + pltpu.semaphore_signal(barrier_sem, device_id=i) + pltpu.semaphore_wait(barrier_sem, num_devices) + pltpu.sync_copy(x_ref, y_ref) + + x = jnp.arange(8 * 128).reshape((8, 128)) + + def body(x): + return pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=pltpu.VMEM)], + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), + out_shape=x, + compiler_params=pltpu.CompilerParams(collective_id=0), + )(x) + + device_mesh = mesh_utils.create_device_mesh( + (jax.device_count(),), jax.devices()) + mesh = jax.sharding.Mesh(device_mesh, ['x']) + y = jax.jit( + shard_map.shard_map( + body, mesh=mesh, in_specs=P('x'), out_specs=P('x'), check_vma=False + ) + )(x) + np.testing.assert_allclose(y, x) + + @parameterized.product(joint_axis=[True, False]) + def test_axis_dict_with_core_multi_device(self, joint_axis): + if jax.device_count() < 2: + self.skipTest('Requires at least 2 devices for DMAs.') + if (cdim := jax.devices()[0].num_cores) < 2: + self.skipTest('Requires a TPU with at least 2 cores.') + if pltpu.get_tpu_info().num_cores > 1 and joint_axis: + self.skipTest('Joint axis is not supported on multi-core TPUs.') + mesh = jax.make_mesh( + (jax.device_count(),), + ('device',), + axis_types=(jax.sharding.AxisType.Auto,), + ) + ddim = jax.device_count() + tcmesh = pltpu.create_tensorcore_mesh('core') + pspec = P('device', None) + sharding = jax.sharding.NamedSharding(mesh, pspec) + + # Array is fully sharded. + xlocal, ylocal = 8, 256 + input_arr = jnp.arange(xlocal * ddim * ylocal, dtype=jnp.int32).reshape( + (xlocal * ddim, ylocal) + ) + input_arr = jax.device_put(input_arr, sharding) + + def core_copy(refs): + in_ref, out_ref = refs + + @pl.core_map(tcmesh, compiler_params=pltpu.CompilerParams(collective_id=7)) + def _(): + num_cores = jax.lax.axis_size('core') + slc_size = ylocal // num_cores + vmem_shape = (xlocal, slc_size) + + # This runs on every core, for every vmem iterations + def alloc(core_sem, out_vmem_ref, sem, send_sem, recv_sem): + core_index = jax.lax.axis_index('core') + # Make sure all cores have entered run_scoped. + for j in range(num_cores): + pltpu.semaphore_signal(core_sem, 1, device_id={'core': j}) + pltpu.semaphore_wait(core_sem, num_cores) + + device_index = jax.lax.axis_index('device') + slc = pl.ds(core_index * slc_size, slc_size) + + # Make sure all devices and cores have entered run_scoped. + sem0 = pltpu.get_barrier_semaphore() + for i in range(ddim): + for j in range(num_cores): + pltpu.semaphore_signal( + sem0, 1, device_id={'device': i, 'core': j} + ) + pltpu.semaphore_wait(sem0, ddim * num_cores) + + # Identity function by default + pltpu.async_copy(in_ref.at[:, slc], out_ref.at[:, slc], sem).wait() + + if joint_axis: + device_id = {('device', 'core'): cdim + 1} + else: + device_id = {'device': 1, 'core': 1} + copy_d0c0_to_d1c1 = pltpu.make_async_remote_copy( + src_ref=in_ref.at[:, slc], + dst_ref=out_vmem_ref, + send_sem=send_sem, + recv_sem=recv_sem, + device_id=device_id, + device_id_type=pltpu.DeviceIdType.MESH, + ) + + @pl.when(device_index == 0) + def _(): + @pl.when(core_index == 0) + def _(): + copy_d0c0_to_d1c1.start() + copy_d0c0_to_d1c1.wait_send() + + @pl.when(device_index == 1) + def _(): + @pl.when(core_index == 1) + def _(): + copy_d0c0_to_d1c1.wait_recv() + pltpu.async_copy(out_vmem_ref, out_ref.at[:, slc], sem).wait() + + pl.run_scoped( + alloc, + pltpu.SemaphoreType.REGULAR, + pltpu.VMEM(vmem_shape, out_ref.dtype), + *([pltpu.SemaphoreType.DMA] * 3), + ) + + @partial(jax.shard_map, mesh=mesh, in_specs=pspec, out_specs=pspec, check_vma=False) + def run_core_kernel(input): + output = jnp.zeros_like(input) + _, output = pl.run_state(core_copy)((input, output)) + return output + pallas_out = jax.jit(run_core_kernel)(input_arr) + + # The device=1 core=1 slice was flushed with device=0 core=0 contents + np.testing.assert_array_equal(pallas_out[8:16, 128:], input_arr[:8, :128]) + # Mask that slice out and all should be the same. + mask = jnp.zeros((8, 128), jnp.int32) + masked_in = jax.lax.dynamic_update_slice(input_arr, mask, (8, 128)) + masked_out = jax.lax.dynamic_update_slice(pallas_out, mask, (8, 128)) + np.testing.assert_array_equal(masked_in, masked_out) + + def test_multi_device_core_local_kernel(self): + num_devices = jax.device_count() + num_cores = pltpu.get_tpu_info().num_cores + x = jnp.arange(num_devices * num_cores * 8 * 128).reshape( + (num_devices, num_cores, 8, 128) + ) + + def body(x): + x_ref = jax.new_ref(x) + y_ref = jax.new_ref(jnp.empty_like(x)) + + tcmesh = pltpu.create_tensorcore_mesh('core') + @pl.core_map(tcmesh) + def _(): + num_cores = jax.lax.axis_size('core') + def inner(sem): + for i in range(num_cores): + pltpu.semaphore_signal(sem, 1, device_id={'core': i}) + pltpu.semaphore_wait(sem, num_cores) + core_id = jax.lax.axis_index('core') + pltpu.sync_copy(x_ref.at[:, core_id], y_ref.at[:, core_id]) + pl.run_scoped(inner, pltpu.SemaphoreType.REGULAR) + return jax.freeze(y_ref) + + mesh = jax.make_mesh((jax.device_count(),), ['x']) + y = jax.jit( + shard_map.shard_map( + body, mesh=mesh, in_specs=P('x'), out_specs=P('x'), check_vma=False + ) + )(x) + np.testing.assert_allclose(y, x) + + def test_no_barrier_semaphore(self): + def alloc_sem(_): + num_devices = lax.axis_size('x') + barrier_sem = pltpu.get_barrier_semaphore() + for i in range(num_devices): + pltpu.semaphore_signal(barrier_sem, device_id=i) + pltpu.semaphore_wait(barrier_sem, num_devices) + + def barrier_kernel(x_ref, sem_ref, out_ref): + num_devices = lax.axis_size('x') + for i in range(num_devices): + pltpu.semaphore_signal(sem_ref, device_id=i) + pltpu.semaphore_wait(sem_ref, num_devices) + out_ref[...] = x_ref[...] + 1 + + x = jnp.arange(8 * 128).reshape((8, 128)) + + def body(x): + sem = pl.pallas_call( + alloc_sem, + in_specs=[], + out_specs=pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + out_shape=pltpu.SemaphoreType.REGULAR(()), + compiler_params=pltpu.CompilerParams(collective_id=0), + )() + return pl.pallas_call( + barrier_kernel, + in_specs=[ + pl.BlockSpec(memory_space=pltpu.VMEM), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + ], + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), + out_shape=x, + compiler_params=pltpu.CompilerParams(skip_device_barrier=True), + )(x, sem) + + device_mesh = mesh_utils.create_device_mesh( + (jax.device_count(),), jax.devices()) + mesh = jax.sharding.Mesh(device_mesh, ['x']) + y = jax.jit( + shard_map.shard_map( + body, mesh=mesh, in_specs=P('x'), out_specs=P('x'), check_vma=False + ) + )(x) + np.testing.assert_allclose(y, x + 1) + class PallasCallRemoteDMAInterpretTest(parameterized.TestCase): @@ -257,6 +521,9 @@ def setUp(self): super().setUp() if not jtu.is_device_tpu(): self.skipTest('Test requires TPU') + if jtu.is_device_tpu(7, 'x'): + # TODO(sharadmv): Enable these tests. + self.skipTest('Tests time out on TPUs v7x.') @parameterized.parameters(('left',), ('right',)) def test_interpret_remote_dma_ppermute(self, permutation): @@ -292,7 +559,7 @@ def test_kernel(x_ref, grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], scratch_shapes=( [pltpu.SemaphoreType.DMA] * 2 @@ -317,7 +584,7 @@ def test_kernel(x_ref, mesh=mesh, in_specs=P(None, 'x'), out_specs=P(None, 'x'), - check_rep=False)) + check_vma=False)) result = compiled_func(sharded_arr) perm = tuple((src, permute_fn(src)) for src in range(num_devices)) @@ -341,12 +608,23 @@ def test_interpret_remote_dma_asymmetrical_indexer(self): def test_kernel(x_ref, output_ref, send_sem, - recv_sem): + recv_sem, barrier_sem): output_ref[...] = jnp.zeros_like(output_ref[...]) my_id = lax.axis_index('x') even_device = lax.rem(my_id, 2) odd_device = 1 - even_device - neighbor = lax.rem(my_id + 1, num_devices) + next_device = lax.rem(my_id + 1, num_devices) + + del barrier_sem + # This kernel as written is racey, but remote semaphore_signal is not + # supported in HLO interpret mode yet. HLO interpret will not race + # because DMAs are implemented as collectives which will barrier. + # Signal to the sender to this device that output_ref has been zeroed + # and this device is ready to receive. + # prev_device = (my_id - 1) % num_devices + # pltpu.semaphore_signal(barrier_sem, 1, device_id=prev_device) + # pltpu.semaphore_wait(barrier_sem) + # If the device_id is even, we copy to output_ref[1]. # If it's odd, we copy to output_ref[0]. @pl.when(even_device) @@ -356,7 +634,7 @@ def _(): dst_ref=output_ref.at[1], send_sem=send_sem, recv_sem=recv_sem, - device_id=neighbor, + device_id=next_device, ) remote_dma.start() remote_dma.wait() @@ -367,7 +645,7 @@ def _(): dst_ref=output_ref.at[0], send_sem=send_sem, recv_sem=recv_sem, - device_id=neighbor, + device_id=next_device, ) remote_dma.start() remote_dma.wait() @@ -376,11 +654,13 @@ def _(): grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.VMEM), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), scratch_shapes=( - [pltpu.SemaphoreType.DMA] * 2 + [pltpu.SemaphoreType.DMA, + pltpu.SemaphoreType.DMA, + pltpu.SemaphoreType.REGULAR] ) ) @@ -403,25 +683,24 @@ def _(): mesh=mesh, in_specs=P(None, 'x'), out_specs=P(None, 'x'), - check_rep=False)) + check_vma=False)) result_interpret = compiled_func(sharded_arr) - kernel = pl.pallas_call( - test_kernel, - out_shape=out_shape, - grid_spec=grid_spec, - ) - compiled_func = jax.jit(shard_map.shard_map( - kernel, - mesh=mesh, - in_specs=P(None, 'x'), - out_specs=P(None, 'x'), - check_rep=False)) - result_noninterpret = compiled_func(sharded_arr) - np.testing.assert_allclose(result_interpret, - result_noninterpret, - atol=1e-5, - rtol=1e-3) + expected = [] + zeros = jnp.zeros((8, 128), jnp.float32) + for i in range(num_devices): + if i == 0: + x_slice = unsharded_arr[:, 128 * (num_devices - 1):] + else: + x_slice = unsharded_arr[:, 128 * (i-1):128 * i] + if i % 2 == 0: + expected.append(jnp.stack([zeros, x_slice], axis=0)) + else: + expected.append(jnp.stack([x_slice, zeros], axis=0)) + expected = jnp.concatenate(expected, axis=1) + + np.testing.assert_array_equal(result_interpret, + expected) def test_interpret_remote_dma_asymmetrical_refs(self): # Test DMAs where dst refs are not the same. @@ -468,11 +747,11 @@ def _(): grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.VMEM), ], out_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.VMEM), + pl.BlockSpec(memory_space=pltpu.VMEM), ], scratch_shapes=( [pltpu.SemaphoreType.DMA] * 2 @@ -498,7 +777,7 @@ def _(): mesh=mesh, in_specs=P(None, 'x'), out_specs=P(None, 'x'), - check_rep=False)) + check_vma=False)) result_interpret = compiled_func(sharded_arr) kernel = pl.pallas_call( @@ -511,7 +790,7 @@ def _(): mesh=mesh, in_specs=P(None, 'x'), out_specs=P(None, 'x'), - check_rep=False)) + check_vma=False)) result_noninterpret = compiled_func(sharded_arr) np.testing.assert_allclose(result_interpret, result_noninterpret, @@ -522,6 +801,13 @@ def _(): class VerificationTest(jtu.JaxTestCase): def test_verification(self): + if jtu.is_device_tpu(7, 'x'): + # TODO(sharadmv): Enable these tests. + self.skipTest('Tests time out on TPUs v7x.') + + self.skipTest( + 'TODO(b/455847773): Fix MLIR layout mismatch in tpu.memref_slice (dynamic offset issue).' + ) if (num_devices := jax.local_device_count()) <= 1: self.skipTest('Test requires multiple devices.') if not jtu.is_device_tpu_at_least(4) or jax.devices()[0].num_cores > 1: @@ -569,11 +855,83 @@ def _(i, _): previous_config = jax.config.read('jax_pallas_dump_promela_to') jax.config.update('jax_pallas_dump_promela_to', tmpdir) shard_map.shard_map( - kernel, mesh=mesh, in_specs=P('x'), out_specs=P(None), check_rep=False + kernel, mesh=mesh, in_specs=P('x'), out_specs=P(None), check_vma=False )(jnp.ones((8, 128, 128), jnp.float32)) jax.config.update('jax_pallas_dump_promela_to', previous_config) self.assertNotEmpty(os.listdir(tmpdir)) +class PallasKernelMetadataDistributedTest(parameterized.TestCase): + + @parameterized.product( + axis_names=[['x', 'y'], [('x', 'y')], ['x'], ['y']], + op=['copy', 'signal'], + ) + def test_mesh_axes_metadata_is_preserved(self, axis_names, op): + if not jtu.is_device_tpu_at_least(4): + self.skipTest('Remote async copy only supported on TPU v4+') + if jtu.is_device_tpu(7, 'x'): + # TODO(sharadmv): Enable these tests. + self.skipTest('Tests time out on TPUs v7x.') + if len(jax.devices()) < 4: + self.skipTest('Not enough devices') + devices = np.array(jax.devices()[:4]).reshape((2, 2)) + mesh = jax.sharding.Mesh(devices, ('x', 'y')) + + def kernel(x_ref, out_ref): + def body(send_sem, recv_sem, sem): + if len(jax.tree.leaves(axis_names)) > 0: + device_id = {a: 0 for a in axis_names} + if op == 'copy': + pltpu.async_remote_copy( + x_ref, + out_ref, + send_sem, + recv_sem, + device_id=device_id, + ).wait() + else: + pl.semaphore_signal(sem, device_id=device_id) + else: + out_ref[...] = x_ref[...] + pl.run_scoped( + body, + send_sem=pltpu.SemaphoreType.DMA, + recv_sem=pltpu.SemaphoreType.DMA, + sem=pltpu.SemaphoreType.REGULAR, + ) + + @functools.partial( + jax.jit, + out_shardings=jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec('x', 'y') + ), + ) + @functools.partial( + jax.shard_map, + mesh=mesh, + in_specs=jax.sharding.PartitionSpec('x', 'y'), + out_specs=jax.sharding.PartitionSpec('x', 'y'), + check_vma=False, + ) + def f(x): + return pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((1, 1, 1, 128), jnp.float32), + in_specs=[pl.BlockSpec(memory_space=pltpu.VMEM)], + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), + )(x) + + x = jnp.zeros((2, 2, 1, 128), dtype=jnp.float32) + hlo = f.lower(x).compile().as_text() + axis_names_text = json.dumps( + json.dumps(sorted(jax.tree.leaves(axis_names))) + ) + self.assertIn( + f'"mesh_axes":{axis_names_text}', + hlo, + ) + + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/tpu_pallas_interpret_distributed_test.py b/tests/pallas/tpu_pallas_interpret_distributed_test.py index 518c16ed2109..d7bccfd6475d 100644 --- a/tests/pallas/tpu_pallas_interpret_distributed_test.py +++ b/tests/pallas/tpu_pallas_interpret_distributed_test.py @@ -18,20 +18,16 @@ contains only tests that use shard_map. """ -import functools - from absl.testing import absltest from absl.testing import parameterized - import jax from jax import lax +from jax._src import shard_map from jax._src import test_util as jtu -import jax._src.pallas.mosaic.interpret as mosaic_interpret +from jax._src.pallas.mosaic.interpret import interpret_pallas_call as mosaic_interpret from jax.experimental import pallas as pl -from jax.experimental import shard_map from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp - import numpy as np jax.config.parse_flags_with_absl() @@ -40,9 +36,16 @@ P = jax.sharding.PartitionSpec +# TODO(jburnim): Figure out how to safely run different instance of TPU +# interpret mode in parallel, and then remove this decorator. +@jtu.thread_unsafe_test_class() class InterpretDistributedTest(jtu.JaxTestCase): def setUp(self): super().setUp() + + if not jtu.test_device_matches(['cpu']): + self.skipTest('CPU-only test') + if jax.device_count() < 4: self.skipTest(f'requires at least 4 devices, found {jax.device_count()}') @@ -52,7 +55,7 @@ def setUp(self): def test_right_permute_example(self, dma_execution_mode, detect_races): num_devices = jax.device_count() partition = P(None, 'x') - mesh = jax.make_mesh((num_devices,), ('x',)) + mesh = jtu.create_mesh((num_devices,), ('x',)) sharding = jax.sharding.NamedSharding(mesh, partition) # Create an input array that shards the last dimension across @@ -91,11 +94,11 @@ def right_permute_kernel(input_ref, output_ref, send_sem, recv_sem): out_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32) grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, - # TPUMemorySpace.ANY will (usually) place the tensor in HBM. + # MemorySpace.ANY will (usually) place the tensor in HBM. in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), scratch_shapes=( # We allocate DMA semaphores in scratch memory. [pltpu.SemaphoreType.DMA] * 2 @@ -105,8 +108,8 @@ def right_permute_kernel(input_ref, output_ref, send_sem, recv_sem): right_permute_kernel, out_shape=out_shape, grid_spec=grid_spec, - compiler_params=pltpu.TPUCompilerParams(collective_id=13), - interpret=mosaic_interpret.TPUInterpretParams( + compiler_params=pltpu.CompilerParams(collective_id=13), + interpret=pltpu.InterpretParams( dma_execution_mode=dma_execution_mode, detect_races=detect_races), ) # Wrap the kernel within a shard_map to call. @@ -116,7 +119,7 @@ def right_permute_kernel(input_ref, output_ref, send_sem, recv_sem): mesh=mesh, in_specs=partition, out_specs=partition, - check_rep=False, + check_vma=False, ) )(input_arr) @@ -138,7 +141,7 @@ def right_permute_kernel(input_ref, output_ref, send_sem, recv_sem): def test_all_gather_example(self, dma_execution_mode, detect_races): num_devices = jax.device_count() partition = P('x', None) - mesh = jax.make_mesh((num_devices,), ('x',)) + mesh = jtu.create_mesh((num_devices,), ('x',)) sharding = jax.sharding.NamedSharding(mesh, partition) # Create an input array that shards the first dimension across @@ -205,10 +208,10 @@ def _(): grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ - # TPUMemorySpace.ANY will (usually) place the tensor in HBM. - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + # MemorySpace.ANY will (usually) place the tensor in HBM. + pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), scratch_shapes=( # DMA semaphores are allocated in scratch memory. # We allocated one semaphore for a local HBM-VMEM copy, @@ -227,9 +230,9 @@ def _(): all_gather_kernel, out_shape=out_shape, grid_spec=grid_spec, - interpret=mosaic_interpret.TPUInterpretParams( + interpret=pltpu.InterpretParams( dma_execution_mode=dma_execution_mode, detect_races=detect_races), - compiler_params=pltpu.TPUCompilerParams(collective_id=0), + compiler_params=pltpu.CompilerParams(collective_id=0), ) # Wrap the kernel within a shard_map to call. @@ -239,7 +242,7 @@ def _(): mesh=mesh, in_specs=partition, out_specs=partition, - check_rep=False + check_vma=False ) )(input_arr) @@ -261,7 +264,7 @@ def _(): def test_all_reduce_sum_example(self, dma_execution_mode, detect_races): num_devices = jax.device_count() partition = P(None, 'x') - mesh = jax.make_mesh((num_devices,), ('x',)) + mesh = jtu.create_mesh((num_devices,), ('x',)) sharding = jax.sharding.NamedSharding(mesh, partition) input_arr = jax.random.uniform( @@ -367,13 +370,13 @@ def _(): num_scalar_prefetch=0, in_specs=[ # Our input lives in VMEM - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.VMEM), ], out_specs=[ # Our output lives in VMEM - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.VMEM), # Our double-buffer lives in HBM - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], grid=(num_devices,), scratch_shapes=( @@ -387,9 +390,9 @@ def _(): all_reduce_kernel, out_shape=out_shape, grid_spec=grid_spec, - interpret=mosaic_interpret.TPUInterpretParams( + interpret=pltpu.InterpretParams( dma_execution_mode=dma_execution_mode, detect_races=detect_races), - compiler_params=pltpu.TPUCompilerParams(collective_id=0), + compiler_params=pltpu.CompilerParams(collective_id=0), ) pallas_result = jax.jit( @@ -398,7 +401,7 @@ def _(): mesh=mesh, in_specs=partition, out_specs=partition, - check_rep=False, + check_vma=False, ) )(input_arr) pallas_result = jax.block_until_ready(pallas_result)[0] @@ -422,7 +425,7 @@ def lax_sum(x): def test_reduce_scatter_sum_example(self, dma_execution_mode, detect_races): num_devices = jax.device_count() partition = P(None, 'x') - mesh = jax.make_mesh((num_devices,), ('x',)) + mesh = jtu.create_mesh((num_devices,), ('x',)) sharding = jax.sharding.NamedSharding(mesh, partition) # We need a block size of (16, 128) to ensure that a half-slice is at least @@ -649,11 +652,11 @@ def _(): grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.VMEM), ], out_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.VMEM), + pl.BlockSpec(memory_space=pl.ANY), ], grid=(num_devices, 2), scratch_shapes=( @@ -671,9 +674,9 @@ def pallas_reduce_scatter(input_arr): reduce_scatter_kernel, out_shape=out_shape, grid_spec=grid_spec, - interpret=mosaic_interpret.TPUInterpretParams( + interpret=pltpu.InterpretParams( dma_execution_mode=dma_execution_mode, detect_races=True), - compiler_params=pltpu.TPUCompilerParams(collective_id=7), + compiler_params=pltpu.CompilerParams(collective_id=7), )(input_arr)[0] pallas_result = jax.jit( @@ -682,7 +685,7 @@ def pallas_reduce_scatter(input_arr): mesh=mesh, in_specs=P(None, 'x'), out_specs=P('x', None), - check_rep=False, + check_vma=False, ) )(input_arr) pallas_result = jax.block_until_ready(pallas_result) @@ -716,7 +719,7 @@ def test_reduce_scatter_sum_with_emit_pipeline_example( self.skipTest('pallas.emit_pipeline + x64 is not currently supported') num_devices = jax.device_count() partition = P(None, 'x') - mesh = jax.make_mesh((num_devices,), ('x',)) + mesh = jtu.create_mesh((num_devices,), ('x',)) sharding = jax.sharding.NamedSharding(mesh, partition) # We pick a large outer kernel block size that we do not want to place @@ -742,7 +745,7 @@ def test_reduce_scatter_sum_with_emit_pipeline_example( inner_block_spec = pl.BlockSpec( index_map=lambda i, j: (i, j), block_shape=inner_block_size, - memory_space=pltpu.TPUMemorySpace.ANY, + memory_space=pl.ANY, ) LEFT = 0 @@ -954,11 +957,11 @@ def _(): grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], out_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], grid=(num_devices, 2), scratch_shapes=( @@ -975,9 +978,9 @@ def pallas_reduce_scatter(input_arr): reduce_scatter_kernel, out_shape=out_shape, grid_spec=grid_spec, - interpret=mosaic_interpret.TPUInterpretParams( + interpret=pltpu.InterpretParams( dma_execution_mode=dma_execution_mode, detect_races=detect_races), - compiler_params=pltpu.TPUCompilerParams(collective_id=19), + compiler_params=pltpu.CompilerParams(collective_id=19), )(input_arr)[0] pallas_result = jax.jit( @@ -986,7 +989,7 @@ def pallas_reduce_scatter(input_arr): mesh=mesh, in_specs=P(None, 'x'), out_specs=P('x', None), - check_rep=False, + check_vma=False, ) )(input_arr) pallas_result = jax.block_until_ready(pallas_result) @@ -1017,19 +1020,6 @@ def test_race_detection(self): input_arr = jax.device_put(input_arr, sharding) def kernel(src_dst_ids_ref, x_ref, o_ref, send_sem, recv_sem): - # Barrier with all devices before doing any DMAs. - barrier_sem = pltpu.get_barrier_semaphore() - @functools.partial(jax.lax.fori_loop, 0, num_devices, init_val=None) - def _(i, _): - pltpu.semaphore_signal( - barrier_sem, - inc=1, - device_id=(jnp.int32(i),), - device_id_type=pltpu.DeviceIdType.MESH, - ) - return None - pltpu.semaphore_wait(barrier_sem, num_devices) - # Send the specified DMAs. my_id = lax.axis_index('x') src_dst_ids = src_dst_ids_ref[:] @@ -1071,13 +1061,12 @@ def run(src_dst_ids): kernel, out_shape=jax.ShapeDtypeStruct((8, 128), input_arr.dtype), in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.SMEM), + pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), scratch_shapes=[pltpu.SemaphoreType.DMA, pltpu.SemaphoreType.DMA], - compiler_params=pltpu.TPUCompilerParams(collective_id=0), - interpret=mosaic_interpret.TPUInterpretParams( + interpret=pltpu.InterpretParams( dma_execution_mode='eager', detect_races=True, ), @@ -1085,7 +1074,7 @@ def run(src_dst_ids): mesh=mesh, in_specs=(P(None), P('x', None)), out_specs=P('x', None), - check_rep=False, + check_vma=False, )(src_dst_ids, input_arr) run(jnp.array([[0, 1], [1, 2], [2, 3]], jnp.int32)).block_until_ready() @@ -1095,6 +1084,67 @@ def run(src_dst_ids): run(jnp.array([[0, 1], [1, 2], [3, 2], [3, 0]], jnp.int32)).block_until_ready() self.assertTrue(mosaic_interpret.races.races_found) + @parameterized.parameters(1, 2, 4) + def test_shard_map_of_core_map(self, num_cores): + num_devices = jax.device_count() + partition = P('x', None) + mesh = jtu.create_mesh((num_devices,), ('x',)) + sharding = jax.sharding.NamedSharding(mesh, partition) + + core_mesh = pltpu.create_tensorcore_mesh('core', num_cores=num_cores) + interpret = pltpu.InterpretParams(detect_races=True) + + @jax.jit + def f(x): + y = jnp.zeros_like(x) + def inner(refs): + x_ref, y_ref = refs + @pl.core_map(core_mesh, interpret=interpret) + def _(): + num_cores = jax.lax.axis_size('core') + slc_size = 16 // num_cores + def alloc(x_vmem_ref, y_vmem_ref, dma_sem, sem): + # Barrier so we deadlock unless the core_map is actually parallel. + for i in range(num_cores): + pl.semaphore_signal(sem, 1, core_index=i) + pl.semaphore_wait(sem, num_cores) + + core_index = jax.lax.axis_index('core') + slc = pl.ds(core_index * slc_size, slc_size) + pltpu.async_copy( + x_ref.at[slc], + x_vmem_ref, + dma_sem, + ).wait() + y = (x_vmem_ref[...] + num_cores * jax.lax.axis_index('x') + + core_index + 1) + y_vmem_ref[...] = y + pltpu.async_copy(y_vmem_ref, y_ref.at[slc], dma_sem).wait() + pl.run_scoped( + alloc, + pltpu.VMEM((slc_size, 128), x_ref.dtype), + pltpu.VMEM((slc_size, 128), y_ref.dtype), + pltpu.SemaphoreType.DMA, + pltpu.SemaphoreType.REGULAR, + ) + _, y = pl.run_state(inner)((x, y)) + return y + + x = jnp.arange(num_devices * 16 * 128, dtype=jnp.int32).reshape((-1, 128)) + y = jax.jit( + shard_map.shard_map(f, + mesh=mesh, + in_specs=partition, + out_specs=partition, + check_vma=False, + ) + )(x).block_until_ready() + expected_out = ( + x.reshape((num_devices, num_cores, -1, 128)) + 1 + + jnp.arange(num_devices, dtype=jnp.int32)[..., None, None, None] * num_cores + + jnp.arange(num_cores, dtype=jnp.int32)[None, ..., None, None] + ).reshape(x.shape) + np.testing.assert_array_equal(y, expected_out) if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/tpu_pallas_interpret_test.py b/tests/pallas/tpu_pallas_interpret_test.py index bc589855b836..223eb1286ccf 100644 --- a/tests/pallas/tpu_pallas_interpret_test.py +++ b/tests/pallas/tpu_pallas_interpret_test.py @@ -18,30 +18,124 @@ contains only tests that do not use shard_map. """ +from collections.abc import Callable +import dataclasses +import functools + from absl.testing import absltest from absl.testing import parameterized - import jax from jax._src import test_util as jtu -import jax._src.pallas.mosaic.interpret as mosaic_interpret +from jax._src.pallas.mosaic.interpret import interpret_pallas_call as mosaic_interpret from jax.experimental import pallas as pl from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp - import numpy as np jax.config.parse_flags_with_absl() +jax.config.update('jax_threefry_partitionable', True) + + +class CountStoreCallbacksContext: + """Wraps the I/O callback `store` into a callback that counts the number of calls to `store`.""" + + def __init__(self): + self._num_stores = 0 + self._saved = mosaic_interpret.store + + def __enter__(self): + def _store_callback(self, *args, **kwargs): + self._num_stores += 1 + return self._saved(*args, **kwargs) + + mosaic_interpret.store = functools.partial(_store_callback, self) + return self + + def __exit__(self, ty, value, traceback): + del ty, value, traceback + mosaic_interpret.store = self._saved + + @property + def num_stores(self): + return self._num_stores + + +@dataclasses.dataclass(frozen=True) +class ProcessedGridPoint(): + """Represents a grid point and the ID of the core that has processed it.""" + grid_point: tuple[int, ...] + core_id: int + +class GridPointRecorderContext: + """Records grid points in the order in which they are procsessed.""" + def __init__(self): + self._grid_points: list[ProcessedGridPoint] = [] + + def __enter__(self): + return self + + def __exit__(self, ty, value, traceback): + ... + + def get_recorder(self) -> Callable[[tuple[np.int32, ...], np.int32], None]: + def _recorder(grid_point, core_id): + processed_grid_point = ProcessedGridPoint( + tuple(int(coord) for coord in grid_point), int(core_id) + ) + self._grid_points.append(processed_grid_point) + + return _recorder + + @property + def grid_points(self) -> list[ProcessedGridPoint]: + return sorted(self._grid_points, key=lambda x: x.core_id) + + +# TODO(jburnim): Figure out how to safely run different instance of TPU +# interpret mode in parallel, and then remove this decorator. +@jtu.thread_unsafe_test_class() class InterpretTest(jtu.JaxTestCase): + def setUp(self): super().setUp() + + if not jtu.test_device_matches(['cpu']): + self.skipTest('CPU-only test') + self.num_devices = jax.device_count() if self.num_devices > 1: # Workaround for https://github.com/jax-ml/jax/issues/25671 self.skipTest(f'requires 1 device, found {self.num_devices}') + def test_revisiting_is_an_error(self): + def kernel(x_ref, o1_ref, o2_ref): + pass + + @jax.jit + def run(): + return pl.pallas_call( + kernel, + out_shape=[ + jax.ShapeDtypeStruct((16, 256), jnp.float32), + jax.ShapeDtypeStruct((16, 256), jnp.float32), + ], + grid=(4, 4), + in_specs=[pl.BlockSpec(memory_space=pl.ANY)], + out_specs=[ + pl.BlockSpec((4, 128), lambda i, j: (i, j // 2)), + pl.BlockSpec((4, 128), lambda i, j: (j // 2, i % 2)), + ], + interpret=pltpu.InterpretParams(), + )(jnp.zeros((8, 128))) + + with self.assertRaisesRegex( + Exception, r'Revisited block .* of output 1 in iteration \(2, 0\)'): + run()[0].block_until_ready() + pltpu.reset_tpu_interpret_mode_state() + def test_matmul_example(self): def matmul_kernel(x_ref, y_ref, z_ref): z_ref[...] = x_ref[...] @ y_ref[...] @@ -49,50 +143,242 @@ def matmul_kernel(x_ref, y_ref, z_ref): @jax.jit def matmul(x: jax.Array, y: jax.Array): return pl.pallas_call( - matmul_kernel, - out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype), - grid=(2, 2), - in_specs=[ - pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0)), - pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j)) - ], - out_specs=pl.BlockSpec( - (x.shape[0] // 2, y.shape[1] // 2), lambda i, j: (i, j), - ), - interpret=mosaic_interpret.TPUInterpretParams(), + matmul_kernel, + out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype), + grid=(2, 2), + in_specs=[ + pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0)), + pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j)), + ], + out_specs=pl.BlockSpec( + (x.shape[0] // 2, y.shape[1] // 2), + lambda i, j: (i, j), + ), + interpret=pltpu.InterpretParams(), )(x, y) k1, k2 = jax.random.split(jax.random.key(0)) x = jax.random.normal(k1, (1024, 1024)) y = jax.random.normal(k2, (1024, 1024)) z = matmul(x, y) - np.testing.assert_allclose(z, x @ y, atol=1e-4) + np.testing.assert_allclose(z, x @ y, atol=1e-3) - def test_dynamic_grid_and_aliasing(self): + @parameterized.parameters('raise', 'uninitialized') + def test_out_of_bounds_block_spec(self, out_of_bounds_reads): + def kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] + + @functools.partial(jax.jit, static_argnums=(0, 1)) + def run(input_offset, output_offset): + return pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((16, 128), jnp.float32), + out_specs=pl.BlockSpec((4, 128), lambda i: (i+output_offset, 0)), + in_specs=[pl.BlockSpec((4, 128), lambda i: (i+input_offset, 0))], + grid=(4,), + interpret=pltpu.InterpretParams( + out_of_bounds_reads=out_of_bounds_reads), + )(jnp.zeros((16, 128), jnp.float32)) + + # Out-of-bounds input block. + if out_of_bounds_reads == 'uninitialized': + out = np.array(run(1, 0)) + np.testing.assert_equal(out[:12], 0.0) + self.assertTrue(np.isnan(out[12:]).all()) + elif out_of_bounds_reads == 'raise': + with self.assertRaisesRegex( + Exception, 'Out-of-bounds block index .* for input'): + run(1, 0) + pltpu.reset_tpu_interpret_mode_state() + + # Out-of-bounds output block. + if out_of_bounds_reads == 'raise': + with self.assertRaisesRegex( + Exception, 'Out-of-bounds block index .* for output'): + run(0, 2) + pltpu.reset_tpu_interpret_mode_state() + + @parameterized.parameters('raise', 'uninitialized') + def test_out_of_bounds_read_index(self, out_of_bounds_reads): def kernel(s_ref, x_ref, o_ref): - o_ref[...] = x_ref[...] + s_ref[0].astype(x_ref.dtype) + def read(ref, i): + return ref[i] + def body(carry): + i, accum = carry + accum += read(x_ref, i) + return (i + 1, accum) + start = read(x_ref, s_ref[0]) + stop = read(x_ref, s_ref[1]) + o_ref[0] = jax.lax.while_loop( + lambda c: c[0] < stop, + body, + (start, jnp.int32(0)))[1] - iters = jax.random.randint(jax.random.key(0), (), 10, 20, dtype=jnp.int32) @jax.jit - def f(s, x): + def run(s, x): return pl.pallas_call( kernel, - out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), - grid=(iters,), + out_shape=jax.ShapeDtypeStruct((1,), jnp.int32), + out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), - pl.BlockSpec(x.shape, lambda i: (0, 0)), + pl.BlockSpec(memory_space=pltpu.SMEM), + pl.BlockSpec(memory_space=pltpu.SMEM) ], - out_specs=pl.BlockSpec(x.shape, lambda i: (0, 0)), - input_output_aliases={1: 0}, - interpret=mosaic_interpret.TPUInterpretParams() + interpret=pltpu.InterpretParams( + out_of_bounds_reads=out_of_bounds_reads), )(s, x) - s = jnp.array([1], dtype=jnp.int32) - x = jnp.arange(32 * 128.).reshape((32, 128)) - y = f(s, x) + self.assertEqual(run(jnp.array([0, 1], jnp.int32), + jnp.array([2, 5, 9, 15, 17], jnp.int32)), + 9 + 15 + 17) + + if out_of_bounds_reads == 'uninitialized': + self.assertLess(run(jnp.array([0, 1], jnp.int32), + jnp.array([2, 6, 9, 15, 17], jnp.int32)), + 0) # sum includes one uninitialized value + elif out_of_bounds_reads == 'raise': + with self.assertRaisesRegex(Exception, 'Out-of-bounds read'): + run(jnp.array([0, 1], jnp.int32), + jnp.array([2, 6, 9, 15, 17], jnp.int32)), + pltpu.reset_tpu_interpret_mode_state() + + @parameterized.parameters('raise', 'uninitialized') + def test_out_of_bounds_read_range(self, out_of_bounds_reads): + def kernel(x_ref, o_ref, sem): + pltpu.async_copy(x_ref.at[pl.ds(jnp.int32(4), 8), 1], o_ref, sem).wait() + + @jax.jit + def run(): + return pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((8, 128,), jnp.float32), + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), + in_specs=[pl.BlockSpec(memory_space=pl.ANY)], + scratch_shapes=[pltpu.SemaphoreType.DMA], + interpret=pltpu.InterpretParams( + out_of_bounds_reads=out_of_bounds_reads), + )(jnp.zeros((8, 4, 128), jnp.float32)) + + if out_of_bounds_reads == 'raise': + with self.assertRaisesRegex(Exception, 'Out-of-bounds read'): + run().block_until_ready() + pltpu.reset_tpu_interpret_mode_state() + else: + out = run().block_until_ready() + np.testing.assert_equal(np.array(out[:4]), 0.0) + self.assertTrue(np.isnan(out[4:]).all()) + + def test_scalar_prefetch_example(self): + def dynamic_slice_kernel(indices, x_ref, o_ref): + del indices + o_ref[...] = x_ref[...] + + @functools.partial(jax.jit, static_argnums=(2,)) + def block_dynamic_slice(x, starts, sizes): + grid_spec = pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + grid=(1, 1), + in_specs=[ + pl.BlockSpec( + sizes, lambda i, j, block_idx: (block_idx[0], block_idx[1]) + ) + ], + out_specs=pl.BlockSpec(sizes, lambda *_: (0, 0)), + ) + + kernel = pl.pallas_call( + dynamic_slice_kernel, + grid_spec=grid_spec, + out_shape=jax.ShapeDtypeStruct(shape=sizes, dtype=x.dtype), + interpret=pltpu.InterpretParams(), + ) + block_idx = jnp.array([starts[0] // sizes[0], starts[1] // sizes[1]]) + return kernel(block_idx, x) + + shape = (512, 512) + x = jnp.reshape(jnp.arange(np.prod(shape), dtype=jnp.int32), shape) + result = block_dynamic_slice( + x, starts=jnp.array([128, 256]), sizes=(128, 128) + ) + ref = jax.lax.dynamic_slice( + x, start_indices=(128, 256), slice_sizes=(128, 128) + ) + diff = jnp.max(jnp.abs(result - ref)) + np.testing.assert_allclose(result, ref) + + def test_dynamic_grid_and_aliasing(self): + def kernel(s1_ref, s2_ref, x_ref, o_ref): + del s2_ref + o_ref[...] = x_ref[...] + s1_ref[0].astype(x_ref.dtype) + + iters = jax.random.randint(jax.random.key(0), (), 10, 20, dtype=jnp.int32) + + @jax.jit + def f(s1, s2, x): + return pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=2, + grid=(iters,), + in_specs=[ + pl.BlockSpec(x.shape, lambda i, *_: (0, 0)), + ], + out_specs=pl.BlockSpec(x.shape, lambda i, *_: (0, 0)), + ), + input_output_aliases={2: 0}, + interpret=pltpu.InterpretParams(), + )(s1, s2, x) + + s1 = jnp.array([1], dtype=jnp.int32) + s2 = jnp.array([2], dtype=jnp.int32) + x = jnp.arange(32 * 128.0).reshape((32, 128)) + y = f(s1, s2, x) + # NOTE: No matter how many times the kernel body is run, the kernel input + # buffer will only be written once by the pallas_call machinery, just + # before the first iteration. So the output will be x + 1 , despite the + # aliasing in HBM. np.testing.assert_allclose(y, x + 1.0) + def test_aliasing(self): + def kernel(x_ref, o_ref, s_ref): + @pl.when((pl.program_id(0) == 0) & (pl.program_id(1) == 0)) + def _(): + s_ref[0] = jnp.int32(0) + + s = s_ref[0] + s_ref[0] = s + 1 + o_ref[:] = x_ref[:] + s.astype(x_ref.dtype) + + x = jnp.zeros((4 * 8, 4 * 128)) + y = pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + grid=(4, 4), + in_specs=[ + pl.BlockSpec(block_shape=(8, 128), index_map=lambda i, j: (i, j)), + ], + out_specs=pl.BlockSpec( + block_shape=(8, 128), index_map=lambda i, j: (j, i) + ), + scratch_shapes=(pltpu.SMEM((1,), jnp.int32),), + input_output_aliases={0: 0}, + interpret=pltpu.InterpretParams(), + )(x) + + expected = np.zeros((4, 4)) + t = 0 + for i in range(4): + for j in range(4): + expected[j, i] = expected[i, j] + t + t += 1 + # NOTE: expected is + # [[0, 5, 10, 15], + # [1, 5, 15, 20], + # [2, 6, 10, 25], + # [3, 7, 11, 15]] + np.testing.assert_allclose(y[::8, ::128], expected) + @parameterized.parameters('eager', 'on_wait') def test_race_detection(self, dma_execution_mode): def kernel_without_race(x_ref, o_ref, t_ref, sem): @@ -109,28 +395,32 @@ def kernel_with_race(x_ref, o_ref, t_ref, sem): copy.wait() x = jnp.zeros((8, 128), jnp.float32) - y = pl.pallas_call(kernel_without_race, + y = pl.pallas_call( + kernel_without_race, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)], + in_specs=[pl.BlockSpec(memory_space=pl.ANY)], scratch_shapes=[ pltpu.VMEM(x.shape, x.dtype), pltpu.SemaphoreType.DMA, ], - interpret=mosaic_interpret.TPUInterpretParams( - detect_races=True, dma_execution_mode=dma_execution_mode), + interpret=pltpu.InterpretParams( + detect_races=True, dma_execution_mode=dma_execution_mode + ), )(x).block_until_ready() self.assertFalse(mosaic_interpret.races.races_found) np.testing.assert_allclose(y, x + 1.0) - pl.pallas_call(kernel_with_race, + pl.pallas_call( + kernel_with_race, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)], + in_specs=[pl.BlockSpec(memory_space=pl.ANY)], scratch_shapes=[ pltpu.VMEM(x.shape, x.dtype), pltpu.SemaphoreType.DMA, ], - interpret=mosaic_interpret.TPUInterpretParams( - detect_races=True, dma_execution_mode=dma_execution_mode), + interpret=pltpu.InterpretParams( + detect_races=True, dma_execution_mode=dma_execution_mode + ), )(x).block_until_ready() self.assertTrue(mosaic_interpret.races.races_found) @@ -142,7 +432,7 @@ def matmul(x: jax.Array, y: jax.Array): return pl.pallas_call( matmul_kernel, out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype), - interpret=mosaic_interpret.TPUInterpretParams( + interpret=pltpu.InterpretParams( skip_floating_point_ops=True ), )(x, y) @@ -153,8 +443,8 @@ def matmul(x: jax.Array, y: jax.Array): z = jax.jit(matmul)(x, y) np.testing.assert_array_equal(z, jnp.full_like(z, jnp.inf)) - lowered = jax.jit(matmul).lower(x, y).as_text(dialect="stablehlo") - self.assertNotIn("dot_general", lowered) + lowered = jax.jit(matmul).lower(x, y).as_text(dialect='stablehlo') + self.assertNotIn('dot_general', lowered) @parameterized.parameters('nan', 'zero') def test_uninitialized_memory(self, uninitialized_memory): @@ -174,8 +464,9 @@ def kernel(o1_ref, o2_ref, o3_ref, t1_ref, t2_ref): pltpu.VMEM((8, 128), jnp.bfloat16), pltpu.VMEM((8, 128), jnp.int16), ], - interpret=mosaic_interpret.TPUInterpretParams( - uninitialized_memory=uninitialized_memory), + interpret=pltpu.InterpretParams( + uninitialized_memory=uninitialized_memory + ), )() if uninitialized_memory == 'nan': self.assertTrue(jnp.isnan(x).all()) @@ -186,6 +477,795 @@ def kernel(o1_ref, o2_ref, o3_ref, t1_ref, t2_ref): np.testing.assert_equal(np.array(y), 0) np.testing.assert_equal(np.array(z), 0) + def test_correct_number_of_stores(self): + def kernel(x_ref, s_ref, o_ref): + s = s_ref[0] + x_ref[:] += jax.lax.full_like(x_ref, s) + s_ref[0] = s + 1 + o_ref[:] = x_ref[:] + + def kernel_call(x, s): + return pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((16, 256), jnp.float32), + grid=(2, 2), + in_specs=[ + pl.BlockSpec((8, 256), lambda i, j: (i, 0)), + pl.BlockSpec(memory_space=pltpu.SMEM), + ], + out_specs=pl.BlockSpec((8, 256), lambda i, j: (i, 0)), + interpret=pltpu.InterpretParams(), + )(x, s) + + with CountStoreCallbacksContext() as store_callbacks_counter: + result = jax.jit(kernel_call)( + jnp.zeros((16, 256), jnp.float32), jnp.zeros((1,), jnp.int32) + ) + np.testing.assert_allclose(result[::8, ::256], [[1.0], [5.0]]) + self.assertEqual(store_callbacks_counter.num_stores, 5) + + def test_randomization_of_parallel_dimensions(self): + def kernel(s_ref, o_ref): + s = s_ref[0] + s_ref[0] = s + 1 + o_ref[:] = jax.lax.full_like(o_ref, s) + + def kernel_call_dimensions_parallel_arbitrary(s, grid_point_recorder): + return pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((32, 512), jnp.float32), + grid=(4, 4), + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], + out_specs=pl.BlockSpec((8, 128), lambda i, j: (i, j)), + interpret=pltpu.InterpretParams( + random_seed=12345, grid_point_recorder=grid_point_recorder + ), + compiler_params=pltpu.CompilerParams( + dimension_semantics=('parallel', 'arbitrary') + ), + )(s) + + with GridPointRecorderContext() as grid_point_recorder: + result = jax.jit( + kernel_call_dimensions_parallel_arbitrary, static_argnums=1 + )( + jnp.zeros((1,), jnp.int32), + grid_point_recorder.get_recorder(), + ) + np.testing.assert_allclose( + result[::8, ::128], + [ + [ 8.0, 9.0, 10.0, 11.0], + [12.0, 13.0, 14.0, 15.0], + [ 0.0, 1.0, 2.0, 3.0], + [ 4.0, 5.0, 6.0, 7.0], + ], + ) + self.assertListEqual( + grid_point_recorder.grid_points, + [ + ProcessedGridPoint((2, 0), 0), + ProcessedGridPoint((2, 1), 0), + ProcessedGridPoint((2, 2), 0), + ProcessedGridPoint((2, 3), 0), + ProcessedGridPoint((3, 0), 0), + ProcessedGridPoint((3, 1), 0), + ProcessedGridPoint((3, 2), 0), + ProcessedGridPoint((3, 3), 0), + ProcessedGridPoint((0, 0), 0), + ProcessedGridPoint((0, 1), 0), + ProcessedGridPoint((0, 2), 0), + ProcessedGridPoint((0, 3), 0), + ProcessedGridPoint((1, 0), 0), + ProcessedGridPoint((1, 1), 0), + ProcessedGridPoint((1, 2), 0), + ProcessedGridPoint((1, 3), 0), + ], + ) + + def test_dimensions_arbitrary_parallel_raises(self): + def kernel_call(s): + def kernel(s_ref, o_ref): + s = s_ref[0] + o_ref[0] = s + + return pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((32, 512), jnp.float32), + grid=(4, 4), + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], + out_specs=pl.BlockSpec((8, 128), lambda i, j: (i, j)), + interpret=pltpu.InterpretParams(random_seed=12345), + compiler_params=pltpu.CompilerParams( + dimension_semantics=('arbitrary', 'parallel') + ), + )(s) + + with self.assertRaises(ValueError): + jax.jit(kernel_call)( + jnp.zeros((1,), jnp.int32), + ) + + def test_dynamic_parallel_dimension_raises(self): + def kernel(o_ref): + o_ref[0] = 42.0 + + @jax.jit + def kernel_call_dynamic_parallel_dimension(): + dim_size = jax.random.randint( + jax.random.key(0), (), 10, 20, dtype=jnp.int32 + ) + return pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((1,), jnp.float32), + grid=(dim_size,), + in_specs=[], + out_specs=pl.BlockSpec((1,), lambda _: (0,)), + interpret=pltpu.InterpretParams(), + compiler_params=pltpu.CompilerParams( + dimension_semantics=('parallel',) + ), + )() + + with self.assertRaises(jax.errors.ConcretizationTypeError): + kernel_call_dynamic_parallel_dimension() + + @parameterized.product( + num_cores=[1, 2, 4], + use_context_manager=[False, True] + ) + def test_core_map(self, num_cores, use_context_manager): + mesh = pltpu.create_tensorcore_mesh('x', num_cores=num_cores) + interpret = False if use_context_manager else pltpu.InterpretParams() + + @jax.jit + def f(x): + y = jnp.zeros_like(x) + def inner(refs): + x_ref, y_ref = refs + @pl.core_map(mesh, interpret=interpret) + def _(): + num_cores = jax.lax.axis_size('x') + slc_size = 16 // num_cores + def alloc(x_vmem_ref, y_vmem_ref, dma_sem, sem): + # Barrier so we deadlock unless the core_map is actually parallel. + for i in range(num_cores): + pl.semaphore_signal(sem, 1, core_index=i) + pl.semaphore_wait(sem, num_cores) + + core_index = jax.lax.axis_index('x') + slc = pl.ds(core_index * slc_size, slc_size) + pltpu.async_copy( + x_ref.at[slc], + x_vmem_ref, + dma_sem, + ).wait() + y = x_vmem_ref[...] + jax.lax.axis_index('x') + 1 + y_vmem_ref[...] = y + pltpu.async_copy(y_vmem_ref, y_ref.at[slc], dma_sem).wait() + pl.run_scoped( + alloc, + pltpu.VMEM((slc_size, 128), x_ref.dtype), + pltpu.VMEM((slc_size, 128), y_ref.dtype), + pltpu.SemaphoreType.DMA, + pltpu.SemaphoreType.REGULAR, + ) + _, y = pl.run_state(inner)((x, y)) + return y + + x = jnp.arange(16 * 128, dtype=jnp.int32).reshape((16, 128)) + expected_out = ( + x.reshape((num_cores, -1, 128)) + 1 + + jnp.arange(num_cores, dtype=jnp.int32)[..., None, None] + ).reshape(x.shape) + + if use_context_manager: + with pltpu.force_tpu_interpret_mode(): + y = f(x) + else: + y = f(x) + np.testing.assert_array_equal(y, expected_out) + + def test_hbm_allocation_in_run_scoped_raises(self): + mesh = pltpu.create_tensorcore_mesh('x', num_cores=1) + + @jax.jit + def f(x): + y = jnp.zeros_like(x) + + def inner(x): + x_ref, y_ref = x + + @pl.core_map( + mesh, + interpret=pltpu.InterpretParams( + allow_hbm_allocation_in_run_scoped=False + ), + ) + def _(): + def copy(hbm): + pltpu.sync_copy(x_ref, hbm) + pltpu.sync_copy(hbm, y_ref) + + pl.run_scoped( + copy, + pltpu.HBM(x_ref.shape, x_ref.dtype), + ) + + _, y = pl.run_state(inner)((x, y)) + return y + + with self.assertRaisesRegex( + ValueError, r'Cannot allocate HBM in `run_scoped`.' + ): + f(jnp.arange(8)) + + @parameterized.product( + first_core_to_copy=[0, 1], dma_execution_mode=['eager', 'on_wait'] + ) + def test_allocate_shared_buffer_in_core_map( + self, first_core_to_copy, dma_execution_mode + ): + mesh = pltpu.create_tensorcore_mesh('x', num_cores=2) + second_core_to_copy = 1 if first_core_to_copy == 0 else 0 + + @jax.jit + def f(x): + y = jnp.zeros_like(x) + + def inner(refs): + x_ref, y_ref = refs + # Thanks to the semaphore `sem` below, this test is race-free, and both + # cores access the shared HBM buffer entirely sequentially. If the + # runtime management of buffers were not done carefully, two issues + # could arise, each resulting in an attempt to access unallocated + # memory: + # 1. The first core to reach the `copy` function, inside the nested + # `run_scoped`, might find that the shared HBM buffer has not been + # allocated yet. This could happen if the other core were + # repsonsible for allocating the HBM buffer when entering the nested + # `run_scoped`; but that other core has not reached the nested + # `run_scoped` yet, and hence has not allocated the HBM buffer yet. + # 2. The second core to reach the `copy` function might find that the + # shared HBM buffer has been deallocated already. This could happen + # if the other core (i.e. the first one to reach the `copy` + # function) were responsible for deallocating the HBM buffer when + # exiting the nested `run_scoped`. If that other core has already + # run ahead to the end of the nested `run_scoped`, it will have + # deallocated the HBM buffer. + @pl.core_map( + mesh, + interpret=pltpu.InterpretParams( + detect_races=True, allow_hbm_allocation_in_run_scoped=True, + dma_execution_mode=dma_execution_mode, + ), + ) + def _(): + def body(sem): + @pl.when(jax.lax.axis_index('x') == second_core_to_copy) + def _(): + pltpu.semaphore_wait(sem, 1) + + def copy(x_hbm_ref): + pltpu.sync_copy(x_ref, x_hbm_ref) + pltpu.sync_copy(x_hbm_ref, y_ref) + + pl.run_scoped( + copy, + pltpu.HBM(x_ref.shape, x_ref.dtype), + ) + + @pl.when(jax.lax.axis_index('x') == first_core_to_copy) + def _(): + pltpu.semaphore_signal(sem, 1, core_index=second_core_to_copy) + + pl.run_scoped( + body, + pltpu.SemaphoreType.REGULAR, + ) + + _, y = pl.run_state(inner)((x, y)) + return y + + x = jnp.arange(16 * 128, dtype=jnp.int32).reshape((16, 128)) + y = f(x) + np.testing.assert_array_equal(y, x) + self.assertFalse(mosaic_interpret.races.races_found) + + @parameterized.product( + slow_core=[0, 1], dma_execution_mode=['eager', 'on_wait'] + ) + def test_allocate_shared_buffer_in_core_map_with_race( + self, slow_core, dma_execution_mode + ): + mesh = pltpu.create_tensorcore_mesh('x', num_cores=2) + + @jax.jit + def f(x, y): + z = jnp.zeros_like(x) + o = jnp.zeros((y.shape[0], y.shape[0]), dtype=y.dtype) + + def inner(refs): + """Copies `x_ref` to `z_ref` and computes `y_ref @ y_ref^t` into `o_ref`.""" + x_ref, y_ref, z_ref, o_ref = refs + @pl.core_map( + mesh, + interpret=pltpu.InterpretParams( + detect_races=True, + allow_hbm_allocation_in_run_scoped=True, + dma_execution_mode=dma_execution_mode, + ), + ) + def _(): + # The slow core performs an expensive matrix multiplication, and then + # copies from `x_ref` to `z_ref`, going through an HBM buffer that is + # shared between the two cores. The other core, aka. the fast core, + # proceeds directly to copying from `x_ref` to `z_ref`, going through + # the same shared HBM buffer. If the shared buffer were, incorrectly, + # deallocated by the fast core (once it is done copying from `x_ref` + # to `z_ref`) and then reallocated by the slow core (before it starts + # copying from `x_ref` to `z_ref`), we would not see any attempts of + # accessing unallocated memory. However, we would also not detect any + # races since the cores would operate on separate buffers. + def body(x_hbm_ref, vmem_ref_0, vmem_ref_1): + @pl.when(jax.lax.axis_index('x') == slow_core) + def _(): + pltpu.sync_copy(y_ref, vmem_ref_0) + vmem_ref_1[...] = vmem_ref_0[...] @ jnp.transpose(vmem_ref_0[...]) + pltpu.sync_copy(vmem_ref_1, o_ref) + + pltpu.sync_copy(x_ref, x_hbm_ref) + pltpu.sync_copy(x_hbm_ref, z_ref) + + pl.run_scoped( + body, + pltpu.HBM(x.shape, dtype=x.dtype), + pltpu.VMEM(y.shape, dtype=y.dtype), + pltpu.VMEM((y.shape[0], y.shape[0]), dtype=y.dtype), + ) + + _, _, z, o = pl.run_state(inner)((x, y, z, o)) + return z, o + + x = jnp.arange(16 * 128, dtype=jnp.int32).reshape((16, 128)) + y = jax.random.randint( + jax.random.key(0), (1024, 1024), minval=-100, maxval=100 + ) + _, o = f(x, y) + # We do not assert that the first result of `f` must be equal to `x`. This + # is because of the copying from `x` to the first result of `f` is racy, and + # we should therefore not expect the first result of `f` to have a + # well-defined value. + np.testing.assert_array_equal(o, y @ jnp.transpose(y)) + self.assertTrue(mosaic_interpret.races.races_found) + + def test_two_cores_along_parallel_dimension_with_race(self): + def kernel(x_ref, o_ref, vmem_ref): + vmem_ref[...] = x_ref[...] + o_ref[...] = x_ref[...] + vmem_ref[...] + + x = jnp.ones((8, 128), jnp.float32) + trace_count = [0] + + @jax.jit + def f(x): + trace_count[0] += 1 + return pl.pallas_call( + kernel, + grid=(2,), + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + in_specs=[pl.BlockSpec(memory_space=pltpu.VMEM)], + scratch_shapes=[ + pltpu.VMEM(x.shape, x.dtype), + ], + compiler_params=pltpu.CompilerParams( + dimension_semantics=('parallel',), + ), + interpret=pltpu.InterpretParams( + num_cores_or_threads=2, + detect_races=False, + ), + )(x) + + y = f(x).block_until_ready() + self.assertFalse(mosaic_interpret.races.races_found) + np.testing.assert_allclose(y, 2.0 * x) + + with pltpu.force_tpu_interpret_mode(pltpu.InterpretParams( + num_cores_or_threads=1, + detect_races=True, + )): + y = f(x).block_until_ready() + self.assertFalse(mosaic_interpret.races.races_found) + np.testing.assert_allclose(y, 2.0 * x) + self.assertEqual(trace_count[0], 2) + + with pltpu.force_tpu_interpret_mode(pltpu.InterpretParams( + num_cores_or_threads=2, + detect_races=True, + )): + y = f(x).block_until_ready() + self.assertTrue(mosaic_interpret.races.races_found) + np.testing.assert_allclose(y, 2.0 * x) + self.assertEqual(trace_count[0], 3) + + def test_two_cores_along_parallel_dimension_no_race(self): + def kernel(x_ref, o_ref, vmem_ref): + vmem_ref[...] = x_ref[...] + o_ref[...] = x_ref[...] + vmem_ref[...] + + x = jnp.ones((16, 128), jnp.float32) + y = pl.pallas_call( + kernel, + grid=(2,), + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + out_specs=pl.BlockSpec( + (8, 128), + lambda i: (i, 0), + ), + in_specs=[ + pl.BlockSpec( + (8, 128), + lambda i: (i, 0), + ), + ], + scratch_shapes=[ + pltpu.VMEM((8, 128), x.dtype), + ], + interpret=pltpu.InterpretParams( + num_cores_or_threads=2, + detect_races=True, + ), + compiler_params=pltpu.CompilerParams( + dimension_semantics=('parallel',) + ), + )(x).block_until_ready() + self.assertFalse(mosaic_interpret.races.races_found) + np.testing.assert_allclose(y, 2.0 * x) + + def test_parallel_dimension_and_multiple_cores(self): + def kernel(s_ref, in_ref, o_ref): + # NOTE: diff should be 0. + diff = in_ref[...] - jnp.float32(4 * pl.program_id(0) + pl.program_id(1)) + + s = s_ref[0] + s_ref[0] = s + 1 + o_ref[:] = jax.lax.full_like(o_ref, s) + diff + + def kernel_call(s, num_cores_per_device, grid_point_recorder): + block_input = jnp.repeat( + jnp.repeat( + jnp.arange(16, dtype=jnp.float32).reshape((4, 4)), 128, axis=1), + 8, axis=0) + return pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((32, 512), jnp.float32), + grid=(4, 4), + in_specs=[ + pl.BlockSpec(memory_space=pltpu.SMEM), + pl.BlockSpec((8, 128), lambda i, j: (i, j)), + ], + out_specs=pl.BlockSpec((8, 128), lambda i, j: (i, j)), + interpret=pltpu.InterpretParams( + random_seed=12345, + num_cores_or_threads=num_cores_per_device, + grid_point_recorder=grid_point_recorder, + detect_races=True, + ), + compiler_params=pltpu.CompilerParams( + dimension_semantics=('parallel', 'arbitrary') + ), + )(s, block_input) + + with self.subTest('num_cores_per_device=1'): + with GridPointRecorderContext() as grid_point_recorder: + result = jax.jit(kernel_call, static_argnums=(1, 2))( + jnp.zeros((1,), jnp.float32), 1, grid_point_recorder.get_recorder() + ) + np.testing.assert_allclose( + result[::8, ::128], + [ + [8.0, 9.0, 10.0, 11.0], + [12.0, 13.0, 14.0, 15.0], + [0.0, 1.0, 2.0, 3.0], + [4.0, 5.0, 6.0, 7.0], + ], + ) + self.assertListEqual( + grid_point_recorder.grid_points, + # parallel_subgrid_size = 4 + # num_parallel_points_per_core = (4 + 1 - 1) // 1 = 4 + # num_iterations_per_core = 4 * (16 // 4) = 16 + [ + ProcessedGridPoint((2, 0), 0), + ProcessedGridPoint((2, 1), 0), + ProcessedGridPoint((2, 2), 0), + ProcessedGridPoint((2, 3), 0), + ProcessedGridPoint((3, 0), 0), + ProcessedGridPoint((3, 1), 0), + ProcessedGridPoint((3, 2), 0), + ProcessedGridPoint((3, 3), 0), + ProcessedGridPoint((0, 0), 0), + ProcessedGridPoint((0, 1), 0), + ProcessedGridPoint((0, 2), 0), + ProcessedGridPoint((0, 3), 0), + ProcessedGridPoint((1, 0), 0), + ProcessedGridPoint((1, 1), 0), + ProcessedGridPoint((1, 2), 0), + ProcessedGridPoint((1, 3), 0), + ], + ) + + with self.subTest('num_cores_per_device=2'): + with GridPointRecorderContext() as grid_point_recorder: + result = jax.jit(kernel_call, static_argnums=(1, 2))( + jnp.zeros((1,), jnp.float32), 2, grid_point_recorder.get_recorder() + ) + np.testing.assert_allclose( + result[::8, ::128], + [ + [0.0, 1.0, 2.0, 3.0], + [4.0, 5.0, 6.0, 7.0], + [0.0, 1.0, 2.0, 3.0], + [4.0, 5.0, 6.0, 7.0], + ], + ) + self.assertListEqual( + grid_point_recorder.grid_points, + # parallel_subgrid_size = 4 + # num_parallel_points_per_core = (4 + 2 - 1) // 2 = 2 + # num_iterations_per_core = 2 * (16 // 4) = 8 + [ + ProcessedGridPoint((2, 0), 0), + ProcessedGridPoint((2, 1), 0), + ProcessedGridPoint((2, 2), 0), + ProcessedGridPoint((2, 3), 0), + ProcessedGridPoint((3, 0), 0), + ProcessedGridPoint((3, 1), 0), + ProcessedGridPoint((3, 2), 0), + ProcessedGridPoint((3, 3), 0), + ProcessedGridPoint((0, 0), 1), + ProcessedGridPoint((0, 1), 1), + ProcessedGridPoint((0, 2), 1), + ProcessedGridPoint((0, 3), 1), + ProcessedGridPoint((1, 0), 1), + ProcessedGridPoint((1, 1), 1), + ProcessedGridPoint((1, 2), 1), + ProcessedGridPoint((1, 3), 1), + ], + ) + + with self.subTest('num_cores_per_device=3'): + with GridPointRecorderContext() as grid_point_recorder: + result = jax.jit(kernel_call, static_argnums=(1, 2))( + jnp.zeros((1,), jnp.float32), 3, grid_point_recorder.get_recorder() + ) + np.testing.assert_allclose( + result[::8, ::128], + [ + [0.0, 1.0, 2.0, 3.0], + [4.0, 5.0, 6.0, 7.0], + [0.0, 1.0, 2.0, 3.0], + [4.0, 5.0, 6.0, 7.0], + ], + ) + self.assertListEqual( + grid_point_recorder.grid_points, + # parallel_subgrid_size = 4 + # num_parallel_points_per_core = (4 + 3 - 1) // 3 = 2 + # num_iterations_per_core = 2 * (16 // 4) = 8 + [ + ProcessedGridPoint((2, 0), 0), + ProcessedGridPoint((2, 1), 0), + ProcessedGridPoint((2, 2), 0), + ProcessedGridPoint((2, 3), 0), + ProcessedGridPoint((3, 0), 0), + ProcessedGridPoint((3, 1), 0), + ProcessedGridPoint((3, 2), 0), + ProcessedGridPoint((3, 3), 0), + ProcessedGridPoint((0, 0), 1), + ProcessedGridPoint((0, 1), 1), + ProcessedGridPoint((0, 2), 1), + ProcessedGridPoint((0, 3), 1), + ProcessedGridPoint((1, 0), 1), + ProcessedGridPoint((1, 1), 1), + ProcessedGridPoint((1, 2), 1), + ProcessedGridPoint((1, 3), 1), + ], + ) + + with self.subTest('num_cores_per_device=4'): + with GridPointRecorderContext() as grid_point_recorder: + result = jax.jit(kernel_call, static_argnums=(1, 2))( + jnp.zeros((1,), jnp.float32), 4, grid_point_recorder.get_recorder() + ) + np.testing.assert_allclose( + result[::8, ::128], + [ + [0.0, 1.0, 2.0, 3.0], + [0.0, 1.0, 2.0, 3.0], + [0.0, 1.0, 2.0, 3.0], + [0.0, 1.0, 2.0, 3.0], + ], + ) + self.assertListEqual( + grid_point_recorder.grid_points, + # parallel_subgrid_size = 4 + # num_parallel_points_per_core = (4 + 4 - 1) // 4 = 1 + # num_iterations_per_core = 1 * (16 // 4) = 4 + [ + ProcessedGridPoint((2, 0), 0), + ProcessedGridPoint((2, 1), 0), + ProcessedGridPoint((2, 2), 0), + ProcessedGridPoint((2, 3), 0), + ProcessedGridPoint((3, 0), 1), + ProcessedGridPoint((3, 1), 1), + ProcessedGridPoint((3, 2), 1), + ProcessedGridPoint((3, 3), 1), + ProcessedGridPoint((0, 0), 2), + ProcessedGridPoint((0, 1), 2), + ProcessedGridPoint((0, 2), 2), + ProcessedGridPoint((0, 3), 2), + ProcessedGridPoint((1, 0), 3), + ProcessedGridPoint((1, 1), 3), + ProcessedGridPoint((1, 2), 3), + ProcessedGridPoint((1, 3), 3), + ], + ) + + with self.subTest('num_cores_per_device=5'): + with GridPointRecorderContext() as grid_point_recorder: + result = jax.jit(kernel_call, static_argnums=(1, 2))( + jnp.zeros((1,), jnp.float32), 5, grid_point_recorder.get_recorder() + ) + np.testing.assert_allclose( + result[::8, ::128], + [ + [0.0, 1.0, 2.0, 3.0], + [0.0, 1.0, 2.0, 3.0], + [0.0, 1.0, 2.0, 3.0], + [0.0, 1.0, 2.0, 3.0], + ], + ) + self.assertListEqual( + grid_point_recorder.grid_points, + # parallel_subgrid_size = 4 + # num_parallel_points_per_core = (4 + 5 - 1) // 5 = 1 + # num_iterations_per_core = 1 * (16 // 4) = 4 + [ + ProcessedGridPoint((2, 0), 0), + ProcessedGridPoint((2, 1), 0), + ProcessedGridPoint((2, 2), 0), + ProcessedGridPoint((2, 3), 0), + ProcessedGridPoint((3, 0), 1), + ProcessedGridPoint((3, 1), 1), + ProcessedGridPoint((3, 2), 1), + ProcessedGridPoint((3, 3), 1), + ProcessedGridPoint((0, 0), 2), + ProcessedGridPoint((0, 1), 2), + ProcessedGridPoint((0, 2), 2), + ProcessedGridPoint((0, 3), 2), + ProcessedGridPoint((1, 0), 3), + ProcessedGridPoint((1, 1), 3), + ProcessedGridPoint((1, 2), 3), + ProcessedGridPoint((1, 3), 3), + ], + ) + + with self.subTest('num_cores_per_device=6'): + with GridPointRecorderContext() as grid_point_recorder: + result = jax.jit(kernel_call, static_argnums=(1, 2))( + jnp.zeros((1,), jnp.float32), 6, grid_point_recorder.get_recorder() + ) + np.testing.assert_allclose( + result[::8, ::128], + [ + [0.0, 1.0, 2.0, 3.0], + [0.0, 1.0, 2.0, 3.0], + [0.0, 1.0, 2.0, 3.0], + [0.0, 1.0, 2.0, 3.0], + ], + ) + self.assertListEqual( + grid_point_recorder.grid_points, + # parallel_subgrid_size = 4 + # num_parallel_points_per_core = (4 + 6 - 1) // 6 = 1 + # num_iterations_per_core = 1 * (16 // 4) = 4 + [ + ProcessedGridPoint((2, 0), 0), + ProcessedGridPoint((2, 1), 0), + ProcessedGridPoint((2, 2), 0), + ProcessedGridPoint((2, 3), 0), + ProcessedGridPoint((3, 0), 1), + ProcessedGridPoint((3, 1), 1), + ProcessedGridPoint((3, 2), 1), + ProcessedGridPoint((3, 3), 1), + ProcessedGridPoint((0, 0), 2), + ProcessedGridPoint((0, 1), 2), + ProcessedGridPoint((0, 2), 2), + ProcessedGridPoint((0, 3), 2), + ProcessedGridPoint((1, 0), 3), + ProcessedGridPoint((1, 1), 3), + ProcessedGridPoint((1, 2), 3), + ProcessedGridPoint((1, 3), 3), + ], + ) + + @parameterized.parameters(pltpu.HBM, pl.ANY) + def test_referencing_hbm_raises(self, disallowed_memory_space): + def jax_load_and_store(in_ref, o_ref): + o_ref[...] = in_ref[...] + + def pallas_load_and_store(in_ref, o_ref): + t = pltpu.load(in_ref) + pltpu.store(o_ref, t) + + def kernel_call(kernel, x, *, in_memory_space, out_memory_space): + return pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + grid=(1,), + in_specs=[pl.BlockSpec(memory_space=in_memory_space)], + out_specs=pl.BlockSpec(memory_space=out_memory_space), + interpret=pltpu.InterpretParams(), + )(x) + + with self.assertRaisesRegex( + ValueError, + r'get_p: Buffers with a memory space of HBM or ANY cannot be' + r' referenced directly. Instead, use `pltpu.sync_copy` or' + r' `pltpu.async_copy`.', + ): + kernel_call( + jax_load_and_store, + jnp.zeros((8, 128), jnp.float32), + in_memory_space=disallowed_memory_space, + out_memory_space=pltpu.VMEM, + ) + pltpu.reset_tpu_interpret_mode_state() + + with self.assertRaisesRegex( + ValueError, + r'load_p: Buffers with a memory space of HBM or ANY cannot be' + r' referenced directly. Instead, use `pltpu.sync_copy` or' + r' `pltpu.async_copy`.', + ): + kernel_call( + pallas_load_and_store, + jnp.zeros((8, 128), jnp.float32), + in_memory_space=disallowed_memory_space, + out_memory_space=pltpu.VMEM, + ) + pltpu.reset_tpu_interpret_mode_state() + + with self.assertRaisesRegex( + ValueError, + r'swap_p: Buffers with a memory space of HBM or ANY cannot be' + r' referenced directly. Instead, use `pltpu.sync_copy` or' + r' `pltpu.async_copy`.', + ): + kernel_call( + jax_load_and_store, + jnp.zeros((8, 128), jnp.float32), + in_memory_space=pltpu.VMEM, + out_memory_space=disallowed_memory_space, + ) + pltpu.reset_tpu_interpret_mode_state() + + with self.assertRaisesRegex( + ValueError, + r'swap_p: Buffers with a memory space of HBM or ANY cannot be' + r' referenced directly. Instead, use `pltpu.sync_copy` or' + r' `pltpu.async_copy`.', + ): + kernel_call( + pallas_load_and_store, + jnp.zeros((8, 128), jnp.float32), + in_memory_space=pltpu.VMEM, + out_memory_space=pltpu.HBM, + ) + pltpu.reset_tpu_interpret_mode_state() + -if __name__ == "__main__": +if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/tpu_pallas_interpret_thread_map_test.py b/tests/pallas/tpu_pallas_interpret_thread_map_test.py new file mode 100644 index 000000000000..c0a6b05f3574 --- /dev/null +++ b/tests/pallas/tpu_pallas_interpret_thread_map_test.py @@ -0,0 +1,72 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Thread map test for TPU-specific interpret mode.""" + +import threading + +from absl.testing import absltest +import jax +from jax._src import test_util as jtu +from jax._src.pallas.mosaic.interpret.thread_map import thread_map + + +jax.config.parse_flags_with_absl() +jax.config.update('jax_threefry_partitionable', True) + + +# TODO(jburnim): Figure out how to safely run different instance of TPU +# interpret mode in parallel, and then remove this decorator. +@jtu.thread_unsafe_test_class() +class InterpretThreadMapTest(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + + if not jtu.test_device_matches(['cpu']): + self.skipTest('CPU-only test') + + self.num_devices = jax.device_count() + if self.num_devices > 1: + # Workaround for https://github.com/jax-ml/jax/issues/25671 + self.skipTest(f'requires 1 device, found {self.num_devices}') + + def test_thread_map(self): + barrier = threading.Barrier(8) + lock = threading.Lock() + concurrent_calls = [0] + max_concurrent_calls = [0] + + def _barrier(): + with lock: + concurrent_calls[0] += 1 + max_concurrent_calls[0] = max( + max_concurrent_calls[0], concurrent_calls[0]) + barrier.wait() + with lock: + concurrent_calls[0] -= 1 + + def f(core_index): + del core_index + jax.experimental.io_callback(_barrier, (), ordered=True) + + thread_map(f, 8) + self.assertEqual(max_concurrent_calls[0], 8) + # `thread_map` returns only after all threads have completed, so the final + # value of `concurrent_calls` should be zero. + self.assertEqual(concurrent_calls[0], 0) + + +if __name__ == '__main__': + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/tpu_pallas_memory_space_test.py b/tests/pallas/tpu_pallas_memory_space_test.py new file mode 100644 index 000000000000..2e9b7a1c1d9f --- /dev/null +++ b/tests/pallas/tpu_pallas_memory_space_test.py @@ -0,0 +1,292 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test TPU-specific uses of Pallas memory space APIs.""" + +import functools +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax._src import test_util as jtu +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu +import jax.numpy as jnp +import numpy as np + + +jax.config.parse_flags_with_absl() +P = jax.sharding.PartitionSpec +partial = functools.partial + + +class TPUPallasCallMemorySpaceTest(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + if not jtu.is_device_tpu_at_least(5): + self.skipTest('Needs a newer TPU') + + @parameterized.parameters( + (pltpu.VMEM, 1), + (pltpu.SMEM, 4), + (pltpu.HBM, 0), + (pl.ANY, None), + ) + def test_basic_input_memory_space_constraint(self, memory_space, color): + def kernel(x_ref, y_ref): + pltpu.sync_copy(x_ref, y_ref) + + def g(x): + return pl.pallas_call( + kernel, + out_shape=x, + in_specs=[pl.BlockSpec(memory_space=memory_space)], + out_specs=pl.BlockSpec(memory_space=pl.ANY), + )(x) + + @jax.jit + def f(x): + x = pltpu.with_memory_space_constraint(x, memory_space=memory_space) + if color is not None: + self.assertEqual(jax.typeof(x).memory_space, memory_space) + x = g(x) + return x + + x = jnp.ones((8, 128), dtype=jnp.float32) + y = f(x) + np.testing.assert_array_equal(y, x) + hlo = jax.jit(f).lower(x).compile().as_text() + if color is None or memory_space == pltpu.SMEM: + self.assertIn('"input_memory_space_colors":[]', hlo) + else: + self.assertIn( + f'"input_memory_space_colors":[{{"operand_index":"0","color":"{color}","shape_index":[]}}]', + hlo, + ) + + @parameterized.parameters( + (pltpu.VMEM, 1), + (pltpu.SMEM, 4), + (pltpu.HBM, 0), + (pl.ANY, None), + (pltpu.HOST, 5), + ) + def test_basic_output_memory_space_constraint(self, memory_space, color): + out_shape_ctor = memory_space + if color is None: + out_shape_ctor = jax.ShapeDtypeStruct + + def kernel(x_ref, y_ref): + pltpu.sync_copy(x_ref, y_ref) + + def g(x): + return pl.pallas_call( + kernel, + out_shape=out_shape_ctor(x.shape, x.dtype), + in_specs=[pl.BlockSpec(memory_space=pl.ANY)], + out_specs=pl.BlockSpec(memory_space=memory_space), + )(x) + + if memory_space == pltpu.HOST: + if jax.device_count() > 1: + self.skipTest('Test only works with a single device.') + out_sharding = jax.sharding.NamedSharding( + jax.sharding.Mesh(jax.devices(), 'x'), + jax.sharding.PartitionSpec(), + memory_kind='pinned_host', + ) + else: + out_sharding = None + + @functools.partial(jax.jit, out_shardings=out_sharding) + def f(x): + x = g(x) + return x + + x = jnp.ones((8, 128), dtype=jnp.float32) + y = f(x) + np.testing.assert_array_equal(y, x) + hlo = jax.jit(f, out_shardings=out_sharding).lower(x).compile().as_text() + if color is None: + self.assertIn('"output_memory_colors":[]', hlo) + else: + self.assertIn( + f'"output_memory_colors":["{color}"]', + hlo, + ) + + +class TPUCoreMapMemorySpaceTest(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + if not jtu.is_device_tpu_at_least(5): + self.skipTest('Needs a newer TPU') + + @parameterized.parameters( + (pltpu.VMEM, 1), + (pltpu.SMEM, 4), + (pltpu.HBM, 0), + (pl.ANY, None), + ) + def test_basic_ref_memory_space_constraint(self, memory_space, color): + @jax.jit + def f(x): + x_ref = jax.new_ref(x, memory_space=memory_space) + y_ref = jax.new_ref(pl.empty_like(x), memory_space=memory_space) + + self.assertEqual(jax.typeof(x_ref).memory_space, memory_space) + self.assertEqual(jax.typeof(y_ref).memory_space, memory_space) + + @pl.core_map(mesh=pltpu.create_tensorcore_mesh('core')) + def _(): + if jax.typeof(x_ref).memory_space is pltpu.VMEM: + y_ref[...] = x_ref[...] + else: + pltpu.sync_copy(x_ref, y_ref) + + return y_ref[...] + + x = jnp.arange(1024, dtype=jnp.float32).reshape((8, 128)) + num_cores = jax.devices()[0].num_cores + if num_cores > 1 and memory_space == pltpu.VMEM: + with self.assertRaisesRegex( + NotImplementedError, + 'TensorCoreMesh does not support VMEM inputs/outputs when there are' + ' >1 cores. Use HBM or ANY instead.', + ): + f.lower(x).compile() + return + lowered = f.lower(x) + compiled = lowered.compile() + hlo = compiled.as_text() + if color is None or memory_space == pltpu.SMEM: + self.assertIn('"input_memory_space_colors":[]', hlo) + else: + self.assertIn( + f'"input_memory_space_colors":[{{"operand_index":"0","color":"{color}","shape_index":[]}},{{"operand_index":"1","color":"{color}","shape_index":[]}}]', + hlo, + ) + y = compiled(x) + np.testing.assert_array_equal(y, x) + + def test_smem_copy(self): + mesh = pltpu.create_tensorcore_mesh('core') + if len(mesh.devices) > 1: + self.skipTest('Only one core is supported for this test.') + + kernel = pl.core_map(mesh=mesh) + + @jax.jit + def f(): + y_ref = pl.empty_ref_like(pltpu.SMEM((8,), jnp.int32)) + + @kernel + def _(): + for i in range(y_ref.shape[0]): + y_ref[i] = i + + @kernel + def _(): + for i in range(y_ref.shape[0]): + y_ref[i] = y_ref[i] + 1 + + return y_ref[...] + + np.testing.assert_array_equal(f(), np.arange(8) + 1) + + def test_smem_async_copy(self): + mesh = pltpu.create_tensorcore_mesh('core') + if len(mesh.devices) > 1: + self.skipTest('Only one core is supported for this test.') + + kernel = pl.core_map(mesh=mesh) + + @jax.jit + def f(): + y_ref = pl.empty_ref_like(pltpu.SMEM((8,), jnp.int32)) + + @kernel + def _(): + for i in range(y_ref.shape[0]): + y_ref[i] = i + + @kernel + def _(): + for i in range(y_ref.shape[0]): + y_ref[i] = y_ref[i] + 1 + + y_out_ref = pl.empty_ref_like(pltpu.HBM((8,), jnp.int32)) + + sem = pl.empty_ref_like(pltpu.SemaphoreType.DMA(())) + + @kernel + def _(): + pltpu.make_async_copy(y_ref, y_out_ref, sem).start() + + @kernel + def _(): + pltpu.make_async_copy(y_ref, y_out_ref, sem).wait() + + return y_out_ref[...] + + np.testing.assert_array_equal(f(), np.arange(8) + 1) + + def test_smem_async_copy_megacore(self): + mesh = pltpu.create_tensorcore_mesh('core') + num_cores = len(mesh.devices) + if num_cores == 1: + self.skipTest('Only megacore is supported for this test.') + + kernel = pl.core_map(mesh=mesh) + n = 256 + + @jax.jit + def f(): + y_ref = pl.empty_ref_like(pltpu.SMEM((1, n), jnp.int32)) + + @kernel + def _(): + core_i = jax.lax.axis_index('core') + for i in range(n): + y_ref[0, i] = i + core_i * n + + @kernel + def _(): + for i in range(n): + y_ref[0, i] = y_ref[0, i] + 1 + + y_out_ref = pl.empty_ref_like(pltpu.HBM((num_cores, 1, n), jnp.int32)) + + sem = pl.empty_ref_like(pltpu.SemaphoreType.DMA(())) + + @kernel + def _(): + core_i = jax.lax.axis_index('core') + pltpu.make_async_copy(y_ref, y_out_ref.at[core_i, ...], sem).start() + + @kernel + def _(): + core_i = jax.lax.axis_index('core') + pltpu.make_async_copy(y_ref, y_out_ref.at[core_i, ...], sem).wait() + + return y_out_ref[...] + + np.testing.assert_array_equal( + f(), np.arange(num_cores * n).reshape((num_cores, 1, n)) + 1 + ) + + +if __name__ == '__main__': + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/tpu_pallas_pipeline_test.py b/tests/pallas/tpu_pallas_pipeline_test.py index 8e72c49e2598..62149556edc5 100644 --- a/tests/pallas/tpu_pallas_pipeline_test.py +++ b/tests/pallas/tpu_pallas_pipeline_test.py @@ -14,37 +14,37 @@ """Test TPU-specific extensions to pallas_call.""" +import dataclasses import functools from absl.testing import absltest from absl.testing import parameterized +import hypothesis as hp +import hypothesis.strategies as hps import jax from jax import lax +from jax._src import hijax +from jax._src import shard_map +from jax._src import state from jax._src import test_util as jtu +from jax._src.state import indexing +from jax._src.state import primitives as state_primitives from jax.experimental import mesh_utils from jax.experimental import pallas as pl -from jax.experimental import shard_map from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp import numpy as np -try: - import hypothesis as hp - import hypothesis.strategies as hps - CAN_USE_HYPOTHESIS = True -except (ModuleNotFoundError, ImportError): - CAN_USE_HYPOTHESIS = False - - -if CAN_USE_HYPOTHESIS: - hp.settings.register_profile( - 'deterministic', - database=None, - derandomize=True, - deadline=None, - max_examples=200, - print_blob=True, - verbosity=hp.Verbosity.verbose, - ) - hp.settings.load_profile('deterministic') + + +hp.settings.register_profile( + 'deterministic', + database=None, + derandomize=True, + deadline=None, + max_examples=200, + print_blob=True, + verbosity=hp.Verbosity.verbose, +) +hp.settings.load_profile('deterministic') jax.config.parse_flags_with_absl() @@ -127,20 +127,56 @@ def _reduce_out(): class PallasCallPipelineTest(parameterized.TestCase): def setUp(self): - if jax.device_count() < 2: - self.skipTest('Only >=2 devices are supported.') if not jtu.is_device_tpu_at_least(5): self.skipTest('Only works with TPU v5') super().setUp() - @parameterized.named_parameters( - ('vmem', pltpu.TPUMemorySpace.VMEM), - ('hbm', pltpu.TPUMemorySpace.ANY), + def test_pipeline_without_inputs(self): + def kernel(o_hbm_ref): + def body(o_ref): + o_ref[...] = jnp.full(o_ref.shape, 42, dtype=o_ref.dtype) + + pltpu.emit_pipeline( + body, grid=(4,), out_specs=pl.BlockSpec((8, 128), lambda i: (0, i)) + )(o_hbm_ref) + + out = pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((8, 512), jnp.int32), + out_specs=pl.BlockSpec(memory_space=pl.ANY), + )() + np.testing.assert_allclose(out, jnp.full_like(out, 42)) + + def test_hbm_output(self): + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct((8, 512), jnp.int32), + in_specs=[pl.BlockSpec(memory_space=pltpu.HBM)], + out_specs=pl.BlockSpec(memory_space=pltpu.HBM), + ) + def kernel(x_hbm_ref, o_hbm_ref): + @functools.partial( + pltpu.emit_pipeline, + grid=(4,), + in_specs=pl.BlockSpec((8, 128), lambda i: (0, i)), + out_specs=pl.BlockSpec( + (8, 512), lambda i: (0, 0), memory_space=pltpu.HBM + ), + ) + def pipeline(x_ref, o_ref): + i = pl.program_id(0) + pltpu.sync_copy(x_ref, o_ref.at[:, pl.ds(i * 128, 128)]) + + pipeline(x_hbm_ref, o_hbm_ref) + + x = jnp.arange(8 * 512).reshape(8, 512) + np.testing.assert_allclose(kernel(x), x) + + @parameterized.product( + no_pipelining=[False, True], ) - def test_pipeline_matmul(self, memory_space): - # TODO(b/358121809): Re-enable this test once the bug is fixed. - self.skipTest('Broken test.') + def test_pipeline_matmul(self, no_pipelining): k1, k2 = jax.random.split(jax.random.key(0)) x = jax.random.uniform(k1, (512, 512)) y = jax.random.uniform(k2, (512, 512)) @@ -161,16 +197,17 @@ def matmul_kernel(x_ref, y_ref, z_ref): pl.BlockSpec((128, 128), lambda i, j, k: (k, j)), ], out_specs=pl.BlockSpec((128, 128), lambda i, j, k: (i, j)), + no_pipelining=no_pipelining, )(x_ref, y_ref, z_ref) z = pl.pallas_call( matmul_kernel, out_shape=jax.ShapeDtypeStruct((512, 512), jnp.float32), in_specs=[ - pl.BlockSpec(memory_space=memory_space), - pl.BlockSpec(memory_space=memory_space), + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=memory_space), + out_specs=pl.BlockSpec(memory_space=pl.ANY), ) jax.block_until_ready(z(x, y)) @@ -179,11 +216,11 @@ def matmul_kernel(x_ref, y_ref, z_ref): out = jax.block_until_ready(z(x, y)) expected_out = jax.block_until_ready(jnp.dot(x, y)) - np.testing.assert_allclose(out, expected_out) + np.testing.assert_allclose(out, expected_out, atol=5e-5) @parameterized.named_parameters( - ('vmem', pltpu.TPUMemorySpace.VMEM), - ('hbm', pltpu.TPUMemorySpace.ANY), + ('vmem', pltpu.VMEM), + ('hbm', pl.ANY), ) def test_double_pipeline_matmul(self, memory_space): # TODO(b/358121809): Re-enable this test once the bug is fixed. @@ -229,22 +266,525 @@ def emit_pipeline(should_accumulate_out): np.testing.assert_allclose(z, jnp.dot(x, y) + jnp.dot(x, y)) +class PallasCallMultipleBufferedPipelineTest(parameterized.TestCase): + + def setUp(self): + if jax.device_count() > 1: + self.skipTest('Only 1 device is supported.') + if not jtu.is_device_tpu_at_least(5): + self.skipTest('Only works with TPU v5+') + super().setUp() + + @parameterized.product( + in_buffer_count=[2, 4], + out_buffer_count=[2], + ) + def test_copy(self, in_buffer_count, out_buffer_count): + x = jnp.reshape(jnp.arange(512 * 512), (512, 512)) + def copy_kernel(x_hbm_ref, o_hbm_ref): + def inner_kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] + pltpu.emit_pipeline( + inner_kernel, + grid=(4, 4), + in_specs=[ + pl.BlockSpec((128, 128), lambda i, j: (i, j), + pipeline_mode=pl.Buffered(buffer_count=in_buffer_count)), + ], + out_specs=pl.BlockSpec((128, 128), lambda i, j: (i, j), + pipeline_mode=pl.Buffered(buffer_count=out_buffer_count)), + )(x_hbm_ref, o_hbm_ref) + fn = pl.pallas_call( + copy_kernel, + out_shape=jax.ShapeDtypeStruct(x.shape, jnp.int32), + in_specs=[ + pl.BlockSpec(memory_space=pl.ANY), + ], + out_specs=pl.BlockSpec(memory_space=pl.ANY), + ) + result = fn(x) + np.testing.assert_allclose(result, x) + + @parameterized.product( + x_buffer_count=[2, 4], + y_buffer_count=[2, 4], + out_buffer_count=[2], + ) + def test_matmul(self, x_buffer_count, y_buffer_count, out_buffer_count): + block_shape = (128, 128) + x = jax.random.uniform(jax.random.key(0), (512, 512)) + y = jax.random.uniform(jax.random.key(1), (512, 512)) + def matmul_kernel(x_hbm_ref, y_hbm_ref, o_hbm_ref): + def pipeline_step(x_ref, y_ref, o_ref): + @pl.when(pl.program_id(2) == 0) + def _(): + o_ref[...] = jnp.zeros(o_ref.shape, jnp.float32) + o_ref[...] += x_ref[...] @ y_ref[...] + pltpu.emit_pipeline( + pipeline_step, + grid=( + 512 // block_shape[0], + 512 // block_shape[0], + 512 // block_shape[0], + ), + in_specs=[ + pl.BlockSpec( + block_shape, + lambda i, j, k: (i, k), + pipeline_mode=pl.Buffered(buffer_count=x_buffer_count), + ), + pl.BlockSpec( + block_shape, + lambda i, j, k: (k, j), + pipeline_mode=pl.Buffered(buffer_count=y_buffer_count), + ), + ], + out_specs=pl.BlockSpec( + block_shape, + lambda i, j, k: (i, j), + pipeline_mode=pl.Buffered(buffer_count=out_buffer_count), + ), + )(x_hbm_ref, y_hbm_ref, o_hbm_ref) + fn = pl.pallas_call( + matmul_kernel, + out_shape=jax.ShapeDtypeStruct(x.shape, jnp.float32), + in_specs=[ + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pl.ANY), + ], + out_specs=pl.BlockSpec(memory_space=pl.ANY), + ) + result = fn(x, y) + np.testing.assert_allclose(result, x @ y, atol=5e-5) + + @parameterized.product( + x_buffer_count=[2, 4], + y_buffer_count=[2, 4], + out_buffer_count=[2], + ) + def test_matmul_megacore(self, + x_buffer_count, y_buffer_count, out_buffer_count): + block_shape = (128, 128) + x = jax.random.uniform(jax.random.key(0), (512, 512)) + y = jax.random.uniform(jax.random.key(1), (512, 512)) + def matmul_kernel(x_hbm_ref, y_hbm_ref, o_hbm_ref): + def pipeline_step(x_ref, y_ref, o_ref): + @pl.when(pl.program_id(2) == 0) + def _(): + o_ref[...] = jnp.zeros(o_ref.shape, jnp.float32) + o_ref[...] += x_ref[...] @ y_ref[...] + pltpu.emit_pipeline( + pipeline_step, + core_axis=0, + grid=(512 // block_shape[0], 512 // block_shape[0], 512 // block_shape[0]), + dimension_semantics=(pltpu.PARALLEL, pltpu.PARALLEL, pltpu.ARBITRARY), + in_specs=[ + pl.BlockSpec(block_shape, lambda i, j, k: (i, k), + pipeline_mode=pl.Buffered(buffer_count=x_buffer_count)), + pl.BlockSpec(block_shape, lambda i, j, k: (k, j), + pipeline_mode=pl.Buffered(buffer_count=y_buffer_count)), + ], + out_specs=pl.BlockSpec(block_shape, lambda i, j, k: (i, j), + pipeline_mode=pl.Buffered(buffer_count=out_buffer_count)), + )(x_hbm_ref, y_hbm_ref, o_hbm_ref) + fn = pl.pallas_call( + matmul_kernel, + grid=(2,), + out_shape=jax.ShapeDtypeStruct(x.shape, jnp.float32), + in_specs=[ + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pl.ANY), + ], + out_specs=pl.BlockSpec(memory_space=pl.ANY), + compiler_params=pltpu.CompilerParams( + dimension_semantics=(pltpu.PARALLEL,) + ), + ) + result = fn(x, y) + np.testing.assert_allclose(result, x @ y, atol=5e-5) + + @parameterized.product( + in_buffer_count=[2, 4], + out_buffer_count=[2], + in_block_indices=[ + [2, 2, 2, 2, 2, 2, 2, 2, 2, 2], + [0, 0, 2, 3, 3, 3, 1, 7, 6, 6], + [3, 3, 3, 3, 3, 3, 3, 7, 6, 6], + [0, 1, 2, 3, 4, 5, 0, 1, 2, 3], + ], + use_lookahead=[True, False], + ) + def test_block_gather(self, in_block_indices, in_buffer_count, + out_buffer_count, use_lookahead): + # Excercises pipeline with repeated input block indices. + block_size = 128 + x = jnp.reshape(jnp.arange(1024 * 128), (1024, 128)) + + def copy_kernel(x_hbm_ref, blk_indices_ref, o_hbm_ref): + def x_index_map(i): + return (blk_indices_ref[i], 0) + def inner_kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] + pltpu.emit_pipeline( + inner_kernel, + grid=(len(in_block_indices),), + in_specs=[ + pl.BlockSpec( + (128, 128), + index_map=x_index_map, + pipeline_mode=pl.Buffered(buffer_count=in_buffer_count, + use_lookahead=use_lookahead), + ), + ], + out_specs=pl.BlockSpec( + (128, 128), + lambda i: (i, 0), + pipeline_mode=pl.Buffered(buffer_count=out_buffer_count), + ), + )(x_hbm_ref, o_hbm_ref) + fn = pl.pallas_call( + copy_kernel, + out_shape=jax.ShapeDtypeStruct((len(in_block_indices) * 128, 128), jnp.int32), + in_specs=[ + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pltpu.SMEM), + ], + out_specs=pl.BlockSpec(memory_space=pl.ANY), + ) + result = fn(x, jnp.array(in_block_indices)) + + expected = [] + for blk_idx in in_block_indices: + expected.append(x[blk_idx * block_size:(blk_idx + 1) * block_size, :]) + expected = jnp.concatenate(expected, axis=0) + np.testing.assert_allclose(result, expected) + + @parameterized.product( + in_buffer_count=[2, 4], + out_buffer_count=[2], + out_block_indices=[ + [0, 0, 2, 2, 2, 5, 3, 3], + [5, 5, 5, 5, 5, 5, 5, 5], + ], + ) + def test_block_scatter(self, out_block_indices, in_buffer_count, + out_buffer_count): + # Excercises pipeline with repeated output block indices. + block_size = 128 + x = jnp.reshape(jnp.arange(1024 * 128), (1024, 128)) + + def copy_kernel(x_hbm_ref, blk_indices_ref, o_hbm_ref): + # zero-out o_hbm_ref + @functools.partial(pl.run_scoped, + o_vmem=pltpu.VMEM((1024, 128), jnp.int32)) + def _(o_vmem): + o_vmem[...] = jnp.zeros(o_vmem.shape, jnp.int32) + pltpu.sync_copy(o_vmem, o_hbm_ref) + + def o_index_map(i): + return (blk_indices_ref[i], 0) + def inner_kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] + pltpu.emit_pipeline( + inner_kernel, + grid=(8,), + in_specs=[ + pl.BlockSpec((128, 128), index_map=lambda i: (i, 0), + pipeline_mode=pl.Buffered(buffer_count=in_buffer_count)), + ], + out_specs=pl.BlockSpec((128, 128), o_index_map, + pipeline_mode=pl.Buffered(buffer_count=out_buffer_count)), + )(x_hbm_ref, o_hbm_ref) + fn = pl.pallas_call( + copy_kernel, + out_shape=jax.ShapeDtypeStruct((1024, 128), jnp.int32), + in_specs=[ + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pltpu.SMEM), + ], + out_specs=pl.BlockSpec(memory_space=pl.ANY), + ) + result = fn(x, jnp.array(out_block_indices)) + + expected = [jnp.zeros((128, 128), jnp.int32)] * 8 + for i, blk_idx in enumerate(out_block_indices): + expected[blk_idx] = x[i * block_size:(i + 1) * block_size, :] + expected = jnp.concatenate(expected, axis=0) + np.testing.assert_allclose(result, expected) + + @parameterized.product( + in_buffer_count=[2, 4], + out_buffer_count=[2], + + ) + def test_copy_with_multiple_cycles(self, in_buffer_count, out_buffer_count): + x = jnp.reshape(jnp.arange(512 * 512), (512, 512)) + def copy_kernel(x_hbm_ref, o_hbm_ref): + def inner_kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] + pipeline_fn, make_allocations = pltpu.emit_pipeline_with_allocations( + inner_kernel, + grid=(2, 4), + in_specs=[ + pl.BlockSpec((128, 128), lambda i, j: (i, j), + pipeline_mode=pl.Buffered(buffer_count=in_buffer_count)), + ], + out_specs=pl.BlockSpec((128, 128), lambda i, j: (i, j), + pipeline_mode=pl.Buffered(buffer_count=out_buffer_count)), + ) + def prefetch(x_bref, o_bref, scheduler): + del o_bref + # Prefetch will use a 0, 0 index so we need to slice x_hbm_ref + scheduler.prefetch(x_bref, x_hbm_ref.at[256:, :]) + + @functools.partial(pl.run_scoped, + allocations=make_allocations(x_hbm_ref, + o_hbm_ref, + should_accumulate_out=(False,))) + def _(allocations): + pipeline_fn(x_hbm_ref.at[:256, :], o_hbm_ref.at[:256, :], + allocations=allocations, + first_cycle=True, + last_cycle=False, + prefetch=prefetch, + postyeet=None, + ) + pipeline_fn(x_hbm_ref.at[256:, :], o_hbm_ref.at[256:, :], + allocations=allocations, + first_cycle=False, + last_cycle=True, + prefetch=None, + postyeet=None, + ) + fn = pl.pallas_call( + copy_kernel, + out_shape=jax.ShapeDtypeStruct(x.shape, jnp.int32), + in_specs=[ + pl.BlockSpec(memory_space=pl.ANY), + ], + out_specs=pl.BlockSpec(memory_space=pl.ANY), + ) + result = fn(x) + np.testing.assert_allclose(result, x) + + @parameterized.product( + in_buffer_count=[2, 4], + out_buffer_count=[2], + in_block_indices=[ + [0, 1, 2, 3, 4, 5], + [2, 2, 2, 2, 2, 2], + [0, 0, 7, 7, 4, 4], + [3, 3, 7, 7, 5, 3, 3], + [5], + ], + use_lookahead=[True, False], + ) + def test_block_gather_with_multiple_cycles( + self, in_block_indices, in_buffer_count, out_buffer_count, use_lookahead + ): + # Exercises pipeline with repeated input block indices. + block_size = 128 + x = jnp.reshape(jnp.arange(1024 * 128), (1024, 128)) + blk_len = len(in_block_indices) + + def copy_kernel(x_hbm_ref, blk_indices_ref, o_hbm_ref, blk_idx_offset): + blk_idx_offset[0] = 0 + def inner_kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] + def x_index_map(i): + return (blk_indices_ref[i], 0) + pipeline_fn, make_allocations = pltpu.emit_pipeline_with_allocations( + inner_kernel, + grid=(blk_len,), + in_specs=[ + pl.BlockSpec( + (128, 128), + index_map=x_index_map, + pipeline_mode=pl.Buffered(buffer_count=in_buffer_count, + use_lookahead=use_lookahead), + ), + ], + out_specs=pl.BlockSpec( + (128, 128), + lambda i: (i + blk_idx_offset[0], 0), + pipeline_mode=pl.Buffered(buffer_count=out_buffer_count), + ), + ) + def prefetch(x_bref, o_bref, scheduler): + del o_bref + scheduler.prefetch(x_bref, x_hbm_ref) + + @functools.partial(pl.run_scoped, + allocations=make_allocations(x_hbm_ref, + o_hbm_ref, + should_accumulate_out=(False,))) + def _(allocations): + pipeline_fn(x_hbm_ref, o_hbm_ref, + allocations=allocations, + first_cycle=True, + last_cycle=False, + prefetch=prefetch, + postyeet=None, + ) + blk_idx_offset[0] = blk_len + pipeline_fn(x_hbm_ref, o_hbm_ref, + allocations=allocations, + first_cycle=False, + last_cycle=True, + prefetch=None, + postyeet=None, + ) + fn = pl.pallas_call( + copy_kernel, + out_shape=jax.ShapeDtypeStruct((blk_len * 2 * 128, 128), jnp.int32), + in_specs=[ + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pltpu.SMEM), + ], + out_specs=pl.BlockSpec(memory_space=pl.ANY), + scratch_shapes = [pltpu.SMEM((1,), dtype=jnp.int32)] + ) + result = jax.block_until_ready(fn(x, jnp.array(in_block_indices))) + + expected = [] + for blk_idx in [*in_block_indices, *in_block_indices]: + x_block = x[blk_idx * block_size:(blk_idx + 1) * block_size, :] + expected.append(x_block) + expected = jax.block_until_ready(jnp.concatenate(expected, axis=0)) + np.testing.assert_allclose(result, expected) + + @parameterized.product( + in_buffer_count=[2, 4], + ) + def test_pipeline_with_accumulator(self, in_buffer_count): + x = jnp.reshape(jnp.arange(1024 * 128), (1024, 128)) // (128*128) + accum_schedule = pltpu.get_pipeline_schedule('fixed') + def copy_kernel(x_hbm_ref, o_hbm_ref): + def inner_kernel(x_ref, o_ref): + @pl.when(pl.program_id(0) == 0) + def _(): + o_ref[...] = jnp.zeros_like(o_ref) + o_ref[...] += x_ref[...] + pipeline_fn, make_allocations = pltpu.emit_pipeline_with_allocations( + inner_kernel, + grid=(4,), + in_specs=[ + pl.BlockSpec( + (128, 128), + lambda i: (i, 0), + pipeline_mode=pl.Buffered(buffer_count=in_buffer_count), + ), + ], + out_specs=pl.BlockSpec((128, 128), lambda i: (0, 0)), + should_accumulate_out=True, + ) + def prefetch(x_bref, o_bref, scheduler): + del o_bref + # Prefetch will use a 0, 0 index so we need to slice x_hbm_ref + scheduler.prefetch(x_bref, x_hbm_ref.at[512:, :]) + + @functools.partial(pl.run_scoped, + allocations=make_allocations(x_hbm_ref, + o_hbm_ref, + should_accumulate_out=(True,))) + def _(allocations): + pipeline_fn(x_hbm_ref.at[:512, :], o_hbm_ref, + allocations=allocations, + first_cycle=True, + last_cycle=False, + prefetch=prefetch, + postyeet=None, + init_accumulators=True, + schedule=(None, accum_schedule) + ) + pipeline_fn(x_hbm_ref.at[512:, :], o_hbm_ref, + allocations=allocations, + first_cycle=False, + last_cycle=True, + prefetch=None, + postyeet=None, + init_accumulators=False, + schedule=(None, accum_schedule) + ) + fn = pl.pallas_call( + copy_kernel, + out_shape=jax.ShapeDtypeStruct((128, 128), jnp.int32), + in_specs=[ + pl.BlockSpec(memory_space=pl.ANY), + ], + out_specs=pl.BlockSpec(memory_space=pl.ANY), + ) + result = fn(x) + expected = 0 + for i in range(x.shape[0] // 128): + x_blk = x[i * 128:(i + 1) * 128, :] + expected += x_blk + np.testing.assert_allclose(result, expected) + + def test_matmul_with_input_output(self): + M, N, K = 512, 512, 512 + blk_m, blk_n, blk_k = 128, 128, 128 + nm, nn, nk = M // blk_m, N // blk_n, K // blk_k + inner_allocs = [ + pltpu.BufferedRef.input( + pl.BlockSpec((blk_m, blk_k), lambda n, m, k: (m, k)), jnp.float32), + pltpu.BufferedRef.input( + pl.BlockSpec((blk_k, blk_n), lambda n, m, k: (k, n)), jnp.float32), + pltpu.BufferedRef.input_output( + pl.BlockSpec((blk_m, blk_n), lambda n, m, k: (m, n)), jnp.float32), + ] + + def matmul_kernel(x_hbm, y_hbm, o_hbm, x_bref, y_bref, o_bref): + def inner_kernel(x_ref, y_ref, o_ref): + @pl.when(pl.program_id(2) == 0) + def _(): + o_ref[...] = jnp.zeros_like(o_ref) + o_ref[...] += x_ref[...] @ y_ref[...] + + pltpu.emit_pipeline( + inner_kernel, + grid=(nm, nn, nk), + )( + x_hbm, y_hbm, o_hbm, + allocations=[x_bref, y_bref, o_bref] + ) + + x = jax.random.uniform(jax.random.key(0), (M, K), jnp.float32) + y = jax.random.uniform(jax.random.key(1), (K, N), jnp.float32) + fn = pl.pallas_call( + matmul_kernel, + out_shape=jax.ShapeDtypeStruct((M, N), jnp.float32), + in_specs=[ + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pl.ANY), + ], + out_specs=pl.BlockSpec(memory_space=pl.ANY), + scratch_shapes=inner_allocs, + ) + result = fn(x, y) + np.testing.assert_allclose(result, x @ y, atol=5e-5) + + class PallasCallCollectivePipelineTest(parameterized.TestCase): def setUp(self): if jax.device_count() < 2: self.skipTest('Only >=2 devices are supported.') + if jtu.is_device_tpu(7, variant='x') and jax.device_count() < 4: + # v7x consists of pairs of chips that share the same ICI connection, + # so we need at least 4 chips to test collectives. + self.skipTest('Only >=4 devices are supported on TPU v7x.') if not jtu.is_device_tpu_at_least(5): self.skipTest('Only works with TPU v5') super().setUp() @parameterized.named_parameters( - ('vmem', pltpu.TPUMemorySpace.VMEM, jnp.bfloat16, 2, 2, 2), - ('hbm', pltpu.TPUMemorySpace.ANY, jnp.bfloat16, 2, 2, 2), - ('hbm_float32', pltpu.TPUMemorySpace.ANY, jnp.float32, 2, 2, 2), - ('hbm_float32_112', pltpu.TPUMemorySpace.ANY, jnp.float32, 1, 1, 2), - ('hbm_float32_111', pltpu.TPUMemorySpace.ANY, jnp.float32, 1, 1, 1), + ('vmem', pltpu.VMEM, jnp.bfloat16, 2, 2, 2), + ('hbm', pl.ANY, jnp.bfloat16, 2, 2, 2), + ('hbm_float32', pl.ANY, jnp.float32, 2, 2, 2), + ('hbm_float32_112', pl.ANY, jnp.float32, 1, 1, 2), + ('hbm_float32_111', pl.ANY, jnp.float32, 1, 1, 1), ) def test_pipeline_latency_optimized_allgather_matmul( self, memory_space, out_dtype, n_tiles, m_tiles, k_tiles): @@ -486,7 +1026,7 @@ def _wait_on_prev_dma(): + [pltpu.SemaphoreType.DMA] * 4 + inner_allocs ), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( collective_id=0, # must set scoped vmem flag *larger* than below! e.g.: # flags.FLAGS.xla_tpu_scoped_vmem_limit_kib = 131072 @@ -502,7 +1042,7 @@ def _wait_on_prev_dma(): ), in_specs=(P(None, 'x'), P(None, None)), out_specs=P(None, None), - check_rep=False, + check_vma=False, ) test = jax.jit(shard(kernel)) @@ -530,11 +1070,11 @@ def reference(x, y): ) @parameterized.named_parameters( - ('vmem', pltpu.TPUMemorySpace.VMEM, jnp.bfloat16, 2, 2, 2), - ('hbm', pltpu.TPUMemorySpace.ANY, jnp.bfloat16, 2, 2, 2), - ('hbm_float32', pltpu.TPUMemorySpace.ANY, jnp.float32, 2, 2, 2), - ('hbm_float32_122', pltpu.TPUMemorySpace.ANY, jnp.float32, 1, 2, 2), - ('hbm_float32_121', pltpu.TPUMemorySpace.ANY, jnp.float32, 1, 2, 1), + ('vmem', pltpu.VMEM, jnp.bfloat16, 2, 2, 2), + ('hbm', pl.ANY, jnp.bfloat16, 2, 2, 2), + ('hbm_float32', pl.ANY, jnp.float32, 2, 2, 2), + ('hbm_float32_122', pl.ANY, jnp.float32, 1, 2, 2), + ('hbm_float32_121', pl.ANY, jnp.float32, 1, 2, 1), ) def test_pipeline_throughput_optimized_allgather_matmul( self, memory_space, out_dtype, n_tiles, m_tiles, k_tiles): @@ -720,20 +1260,20 @@ def _wait_on_prev_dma(): pl.BlockSpec(memory_space=memory_space), pl.BlockSpec(memory_space=memory_space), ], - out_specs=[pl.BlockSpec(memory_space=memory_space), - pl.BlockSpec(memory_space=memory_space)], + out_specs=[ + pl.BlockSpec(memory_space=memory_space), + pl.BlockSpec(memory_space=memory_space), + ], grid=(outer_steps, 2), - scratch_shapes=[ - pltpu.VMEM((tm, tn), jnp.float32)] + scratch_shapes=[pltpu.VMEM((tm, tn), jnp.float32)] + [pltpu.SemaphoreType.DMA] * 4 - + inner_allocs + + inner_allocs, ), - compiler_params=dict( - mosaic=dict(collective_id=0, - # must set scoped vmem flag *larger* than below! e.g.: - # flags.FLAGS.xla_tpu_scoped_vmem_limit_kib = 131072 - vmem_limit_bytes=int(134217728 * 0.9) # 0.9 * 128MiB - ) + compiler_params=pltpu.CompilerParams( + collective_id=0, + # must set scoped vmem flag *larger* than below! e.g.: + # flags.FLAGS.xla_tpu_scoped_vmem_limit_kib = 131072 + vmem_limit_bytes=int(134217728 * 0.9), # 0.9 * 128MiB ), ) @@ -745,7 +1285,7 @@ def _wait_on_prev_dma(): ), in_specs=(P(None, 'x'), P(None, None)), out_specs=P(None, None), - check_rep=False, + check_vma=False, ) test = jax.jit(shard(kernel)) @@ -773,11 +1313,11 @@ def reference(x, y): ) @parameterized.named_parameters( - ('vmem', pltpu.TPUMemorySpace.VMEM, jnp.bfloat16, 2, 2, 2), - ('hbm', pltpu.TPUMemorySpace.ANY, jnp.bfloat16, 2, 2, 2), - ('hbm_float32', pltpu.TPUMemorySpace.ANY, jnp.float32, 2, 4, 2), - ('hbm_float32_112', pltpu.TPUMemorySpace.ANY, jnp.float32, 1, 1, 2), - ('hbm_float32_111', pltpu.TPUMemorySpace.ANY, jnp.float32, 1, 1, 1), + ('vmem', pltpu.VMEM, jnp.bfloat16, 2, 2, 2), + ('hbm', pl.ANY, jnp.bfloat16, 2, 2, 2), + ('hbm_float32', pl.ANY, jnp.float32, 2, 4, 2), + ('hbm_float32_112', pl.ANY, jnp.float32, 1, 1, 2), + ('hbm_float32_111', pl.ANY, jnp.float32, 1, 1, 1), ) def test_pipeline_latency_optimized_matmul_reducescatter( self, memory_space, out_dtype, n_tiles, m_tiles, k_tiles): @@ -1010,15 +1550,13 @@ def _loop_epilogue(): grid=(outer_steps, 2), scratch_shapes=[pltpu.VMEM((tm, tn), jnp.float32)] + [pltpu.SemaphoreType.DMA] * 4 - + inner_allocs + + inner_allocs, ), - compiler_params=dict( - mosaic=dict( - collective_id=0, - # must set scoped vmem flag *larger* than below! - # e.g. flags.FLAGS.xla_tpu_scoped_vmem_limit_kib = 131072 - vmem_limit_bytes=int(134217728 * 0.9) # 0.9 * 128MiB - ) + compiler_params=pltpu.CompilerParams( + collective_id=0, + # must set scoped vmem flag *larger* than below! + # e.g. flags.FLAGS.xla_tpu_scoped_vmem_limit_kib = 131072 + vmem_limit_bytes=int(134217728 * 0.9), # 0.9 * 128MiB ), ) @@ -1031,7 +1569,7 @@ def _loop_epilogue(): ), in_specs=(P(None, 'x'), P('x', None)), out_specs=P('x', None), - check_rep=False, + check_vma=False, ) test = jax.jit(shard(lambda x, y: kernel(x, y)[0, 0])) @@ -1062,11 +1600,11 @@ def reference(x, y): np.mean(np.abs(out - expected_out)) @parameterized.named_parameters( - ('vmem', pltpu.TPUMemorySpace.VMEM, jnp.bfloat16, 2, 2, 2), - ('hbm', pltpu.TPUMemorySpace.ANY, jnp.bfloat16, 2, 2, 2), - ('hbm_float32', pltpu.TPUMemorySpace.ANY, jnp.float32, 2, 4, 2), - ('hbm_float32_112', pltpu.TPUMemorySpace.ANY, jnp.float32, 1, 2, 2), - ('hbm_float32_111', pltpu.TPUMemorySpace.ANY, jnp.float32, 1, 2, 1), + ('vmem', pltpu.VMEM, jnp.bfloat16, 2, 2, 2), + ('hbm', pl.ANY, jnp.bfloat16, 2, 2, 2), + ('hbm_float32', pl.ANY, jnp.float32, 2, 4, 2), + ('hbm_float32_112', pl.ANY, jnp.float32, 1, 2, 2), + ('hbm_float32_111', pl.ANY, jnp.float32, 1, 2, 1), ) def test_pipeline_throughput_optimized_matmul_reducescatter( self, memory_space, out_dtype, n_tiles, m_tiles, k_tiles): @@ -1273,15 +1811,13 @@ def _prefetch_accumulator(): grid=(outer_steps, 2), scratch_shapes=[pltpu.VMEM((tm, tn), jnp.float32)] + [pltpu.SemaphoreType.DMA] * 4 - + inner_allocs + + inner_allocs, ), - compiler_params=dict( - mosaic=dict( - collective_id=0, - # must set scoped vmem flag *larger* than below! - # e.g. flags.FLAGS.xla_tpu_scoped_vmem_limit_kib = 131072 - vmem_limit_bytes=int(134217728 * 0.9) # 0.9 * 128MiB - ) + compiler_params=pltpu.CompilerParams( + collective_id=0, + # must set scoped vmem flag *larger* than below! + # e.g. flags.FLAGS.xla_tpu_scoped_vmem_limit_kib = 131072 + vmem_limit_bytes=int(134217728 * 0.9), # 0.9 * 128MiB ), ) @@ -1294,7 +1830,7 @@ def _prefetch_accumulator(): ), in_specs=(P(None, 'x'), P('x', None)), out_specs=P('x', None), - check_rep=False, + check_vma=False, ) test = jax.jit(shard(lambda x, y: kernel(x, y)[1])) @@ -1357,12 +1893,14 @@ def mul_kernel(iters_ref, x_ref, y_ref): grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=1, in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), grid=(num_cores,), ), - compiler_params=dict(mosaic=dict(dimension_semantics=('parallel',))), + compiler_params=pltpu.CompilerParams( + dimension_semantics=('parallel',) + ), ) x = jax.random.uniform(jax.random.key(0), (640, 640)) np.testing.assert_allclose(func(jnp.array([5]), x), x * 2) @@ -1392,11 +1930,13 @@ def matmul_kernel(x_ref, y_ref): matmul_kernel, out_shape=jax.ShapeDtypeStruct((512, 512), jnp.float32), in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), grid=(num_cores,), - compiler_params=dict(mosaic=dict(dimension_semantics=('parallel',))), + compiler_params=pltpu.CompilerParams( + dimension_semantics=('parallel',) + ), ) np.testing.assert_allclose(func(x), x * 2) @@ -1440,114 +1980,491 @@ def matmul_kernel(x_ref, y_ref, z_ref, *, bm, bk, bn): functools.partial(matmul_kernel, bm=bm, bk=bk, bn=bn), out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32), in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), grid=(num_cores,), - compiler_params=dict(mosaic=dict(dimension_semantics=('parallel',))), + compiler_params=pltpu.CompilerParams( + dimension_semantics=('parallel',) + ), ) np.testing.assert_allclose(func(x, y), x @ y, atol=7e-5) -if CAN_USE_HYPOTHESIS: +@partial(jax.jit, static_argnames=['bm', 'bk', 'bn']) +def matmul(x: jax.Array, y: jax.Array, *, bm: int, bk: int, bn: int): + + m, k = x.shape + _, n = y.shape - @partial(jax.jit, static_argnames=['bm', 'bk', 'bn']) - def matmul(x: jax.Array, y: jax.Array, *, bm: int, bk: int, bn: int): + def kernel(x_hbm_ref, y_hbm_ref, o_hbm_ref): - m, k = x.shape - _, n = y.shape + grid = (pl.cdiv(m, bm), pl.cdiv(n, bn), pl.cdiv(k, bk)) - def kernel(x_hbm_ref, y_hbm_ref, o_hbm_ref): + def run(acc_scratch_ref): + pltpu.emit_pipeline( + partial(basic_matmul_kernel, acc_scratch_ref=acc_scratch_ref, k=k), + in_specs=[ + pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)), + pl.BlockSpec((bk, bn), lambda i, j, k: (k, j)), + ], + out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)), + grid=grid, + core_axis=0, + dimension_semantics=( + pltpu.PARALLEL, + pltpu.PARALLEL, + pltpu.ARBITRARY, + ), + )(x_hbm_ref, y_hbm_ref, o_hbm_ref) + + accum_dtype = ( + jnp.float32 if jnp.issubdtype(x.dtype, jnp.floating) else jnp.int32 + ) + pl.run_scoped(run, pltpu.VMEM((bm, bn), accum_dtype)) + + num_cores = jax.devices()[0].num_cores + return pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((m, n), x.dtype), + in_specs=[ + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pl.ANY), + ], + out_specs=pl.BlockSpec(memory_space=pl.ANY), + grid=(num_cores,), + )(x, y) + +@jtu.thread_unsafe_test_class(condition=not jtu.hypothesis_is_thread_safe()) +class PaddedPipelineEmitterTest(parameterized.TestCase): - grid = (pl.cdiv(m, bm), pl.cdiv(n, bn), pl.cdiv(k, bk)) + def setUp(self): + super().setUp() + if not jtu.is_device_tpu_at_least(4): + self.skipTest('Only TPU v4+ allowed.') - def run(acc_scratch_ref): + @parameterized.named_parameters( + ('float32', 'float32'), ('bfloat16', 'bfloat16'), ('int8', 'int8') + ) + @hp.given( + hps.integers(1, 1024), + hps.integers(1, 1024), + hps.integers(1, 1024), + hps.sampled_from([8, 16, 32, 128, 256, 512]), + hps.sampled_from([128, 256, 512]), + hps.sampled_from([128, 256, 512]), + hps.integers(0, 4), + ) + def test_padded_matmul(self, dtype, m, k, n, bm, bk, bn, seed): + if dtype == 'int8' and jtu.is_device_tpu_at_least(6): + self.skipTest('Not implemented for TPU v6.') + + hp.assume(bm <= m) + hp.assume(bn <= n) + hp.assume(bk <= k) + if dtype == 'bfloat16': + hp.assume(bm >= 16) + if dtype == 'int8': + if not jtu.is_device_tpu_at_least(5): + self.skipTest('Only TPU v5+ allowed for int8.') + hp.assume(bm >= 32) + k1, k2 = jax.random.split(jax.random.key(seed)) + x = jax.random.normal(k1, (m, k), jnp.float32).astype(dtype) + y = jax.random.normal(k2, (k, n), jnp.float32).astype(dtype) + + out = matmul(x, y, bm=bm, bk=bk, bn=bn) + expected = x @ y + atol = rtol = 2.3e-5 + if dtype == 'bfloat16': + out = out.astype('float32') + expected = expected.astype('float32') + atol = rtol = 1e-2 + np.testing.assert_allclose(out, expected, atol=atol, rtol=rtol) + + +class PallasCallBoundedSliceIndexingTest(parameterized.TestCase): + + def test_block_spec_bounded_slice_invalid_index(self): + if not jtu.is_device_tpu(): + self.skipTest('Only works on TPU.') + shape = (16, 8, 128) + + def kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] + + def main(refs): + x_ref, y_ref = refs + + @pl.core_map(pltpu.create_tensorcore_mesh('core')) + def _(): pltpu.emit_pipeline( - partial(basic_matmul_kernel, acc_scratch_ref=acc_scratch_ref, k=k), - in_specs=[ - pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)), - pl.BlockSpec((bk, bn), lambda i, j, k: (k, j)), - ], - out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)), - grid=grid, - core_axis=0, - dimension_semantics=( - pltpu.PARALLEL, - pltpu.PARALLEL, - pltpu.ARBITRARY, + kernel, + grid=(1,), + in_specs=( + pl.BlockSpec( + (pl.BoundedSlice(8), 8, 128), + lambda i: (0, 0, 0), # first index needs to be a pl.ds + ), ), - )(x_hbm_ref, y_hbm_ref, o_hbm_ref) + out_specs=pl.BlockSpec( + (8, 8, 128), + lambda i: (0, 0, 0), + ), + )(x_ref, y_ref) - accum_dtype = ( - jnp.float32 if jnp.issubdtype(x.dtype, jnp.floating) else jnp.int32 - ) - pl.run_scoped(run, pltpu.VMEM((bm, bn), accum_dtype)) + @jax.jit + def f(x): + y = jnp.ones((8, 8, 128), dtype=jnp.int32) + _, y = pl.run_state(main)((x, y)) + return y + with self.assertRaisesRegex( + ValueError, + 'Must return a pl.ds from the index_map for a BoundedSlice dimension.' + ): + f.trace(jax.ShapeDtypeStruct(shape, jnp.int32)) - num_cores = jax.devices()[0].num_cores - return pl.pallas_call( - kernel, - out_shape=jax.ShapeDtypeStruct((m, n), x.dtype), - in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), - pl.BlockSpec(memory_space=pltpu.ANY), - ], - out_specs=pl.BlockSpec(memory_space=pltpu.ANY), - grid=(num_cores,), - )(x, y) + def test_block_spec_bounded_slice_static(self): + if not jtu.is_device_tpu(): + self.skipTest('Only works on TPU.') + if not jtu.is_device_tpu_at_least(4): + self.skipTest('Only works on TPU v4+') + shape = (16, 8, 128) + + def kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] + + def main(refs): + x_ref, y_ref = refs + + @pl.core_map(pltpu.create_tensorcore_mesh('core')) + def _(): + pltpu.emit_pipeline( + kernel, + grid=(1,), + in_specs=( + pl.BlockSpec( + (pl.BoundedSlice(8), 8, 128), + lambda i: (pl.ds(4, 8), 0, 0), + ), + ), + out_specs=pl.BlockSpec( + (8, 8, 128), + lambda i: (0, 0, 0), + ), + )(x_ref, y_ref) + + x = jnp.arange(np.prod(shape), dtype=np.int32).reshape(shape) + + @jax.jit + def f(x): + y = jnp.ones((8, 8, 128), dtype=jnp.int32) + _, y = pl.run_state(main)((x, y)) + return y + + out = f(x) + np.testing.assert_allclose(out, x[4:12]) + + def test_block_spec_bounded_slice_dynamic(self): + if not jtu.is_device_tpu(): + self.skipTest('Only works on TPU.') + if not jtu.is_device_tpu_at_least(4): + self.skipTest('Only works on TPU v4+') + shape = (16, 8, 128) + + slices = jnp.array([[0, 3], [3, 8], [8, 11], [11, 16]], dtype=jnp.int32)[ + ::-1 + ] + + def kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] + + def main(refs): + x_ref, y_ref, slices_ref = refs + + @pl.core_map(pltpu.create_tensorcore_mesh('core')) + def _(): + + @functools.partial( + pl.run_scoped, slices_smem=pltpu.SMEM(slices.shape, slices.dtype) + ) + def _(slices_smem): + pltpu.sync_copy(slices_ref, slices_smem) + def index_map(i): + return ( + pl.ds(slices_smem[i, 0], slices_smem[i, 1] - slices_smem[i, 0]), + 0, + 0, + ) + block_spec = pl.BlockSpec( + (pl.BoundedSlice(16), 8, 128), + index_map, + ) + pltpu.emit_pipeline( + kernel, + grid=(slices.shape[0],), + in_specs=(block_spec,), + out_specs=block_spec, + )(x_ref, y_ref) + + x = jnp.arange(np.prod(shape), dtype=np.int32).reshape(shape) + + @jax.jit + def f(x, slices): + y = pl.empty_like(x) + _, y, _ = pl.run_state(main)((x, y, slices)) + return y + + out = f(x, slices) + np.testing.assert_allclose(out, x) + + +class PipelineHijaxTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + if not jtu.is_device_tpu_at_least(4): + self.skipTest('Only works on TPU v4+.') + + def test_emit_pipeline_hijax(self): + @dataclasses.dataclass(frozen=True) + class ArrayTuple: + x0: jax.Array + x1: jax.Array + + @property + def shape(self): + assert self.x0.shape == self.x1.shape + return self.x0.shape - class PaddedPipelineEmitterTest(parameterized.TestCase): + @property + def dtype(self): + assert self.x0.dtype == self.x1.dtype + return self.x0.dtype - def setUp(self): - super().setUp() - if not jtu.is_device_tpu_at_least(4): - self.skipTest('Only TPU v4+ allowed.') + @dataclasses.dataclass(frozen=True) + class ShapedArrayTuple(hijax.HiType): + shape: tuple[int, ...] + dtype: jnp.dtype - @parameterized.named_parameters( - ('float32', 'float32'), ('bfloat16', 'bfloat16'), ('int8', 'int8') + update = dataclasses.replace + + def lo_ty(self) -> list[hijax.ShapedArray]: + return [hijax.ShapedArray(self.shape, self.dtype)] * 2 + + def lower_val(self, hi_val: ArrayTuple) -> list[jax.Array]: + return [hi_val.x0, hi_val.x1] + + def raise_val(self, x0, x1) -> ArrayTuple: + return ArrayTuple(x0, x1) + + def ref_get_abstract_eval(self, ref_aval, *args, tree): + arr_aval = hijax.ShapedArray(self.shape, self.dtype) + updated_ref = ref_aval.update(inner_aval=arr_aval) + out, effects = state_primitives.get_p.abstract_eval( + updated_ref, *args, tree=tree + ) + assert isinstance(out, hijax.ShapedArray) + return ShapedArrayTuple(out.shape, out.dtype), effects + + def ref_get_to_lojax( + self, ref: state.TransformedRef | jax.Ref, idx: indexing.NDIndexer + ): + tup_ref, transforms = ref._refs, ref.transforms # pylint: disable=protected-access + assert isinstance(transforms, tuple) + transforms += (idx,) + + flat_transforms, tree = jax.tree.flatten(transforms) + x0_out = state_primitives.get_p.bind( + tup_ref.x0, *flat_transforms, tree=tree + ) + x1_out = state_primitives.get_p.bind( + tup_ref.x1, *flat_transforms, tree=tree + ) + return ShapedArrayTuple(x0_out, x1_out).raise_val(x0_out, x1_out) + + def ref_swap_abstract_eval(self, ref_aval, val_aval, *args, tree): + arr_aval = hijax.ShapedArray(self.shape, self.dtype) + val_arr_aval = hijax.ShapedArray(val_aval.shape, val_aval.dtype) + updated_ref = ref_aval.update(inner_aval=arr_aval) + out_aval, effects = state_primitives.swap_p.abstract_eval( + updated_ref, val_arr_aval, *args, tree=tree + ) + assert isinstance(out_aval, hijax.ShapedArray) + return ShapedArrayTuple(out_aval.shape, out_aval.dtype), effects + + def ref_swap_to_lojax( + self, + ref: state.TransformedRef | jax.Ref, + val: ArrayTuple, + idx: indexing.NDIndexer, + ): + tup_ref, transforms = ref._refs, ref.transforms # pylint: disable=protected-access + assert isinstance(transforms, tuple) + transforms += (idx,) + + flat_transforms, tree = jax.tree.flatten(transforms) + x0_out = state_primitives.swap_p.bind( + tup_ref.x0, val.x0, *flat_transforms, tree=tree + ) + x1_out = state_primitives.swap_p.bind( + tup_ref.x1, val.x1, *flat_transforms, tree=tree + ) + return self.raise_val(x0_out, x1_out) + + def lower_block_spec( + self, block_spec: pl.BlockSpec + ) -> list[pl.BlockSpec]: + return [block_spec, block_spec] + + def dma_start( + self, + src_ref: state.TransformedRef, + dst_ref: state.TransformedRef, + src_sem: state.TransformedRef, + dst_sem: state.TransformedRef, + device_id: jax.Array | int | None, + device_id_type: pl.DeviceIdType, + priority: int, + add: bool, + ) -> None: + del add + src_aval = jax.typeof(src_ref.ref).inner_aval + assert isinstance(src_aval, ShapedArrayTuple) + dst_aval = jax.typeof(dst_ref.ref).inner_aval + assert isinstance(dst_aval, ShapedArrayTuple) + + src_ref, src_transforms = src_ref.ref._refs, src_ref.transforms # pylint: disable=protected-access + dst_ref, dst_transforms = dst_ref.ref._refs, dst_ref.transforms # pylint: disable=protected-access + + def _run_dma( + src_ref, + dst_ref, + src_sem, + dst_sem, + device_id, + device_id_type, + priority, + ): + if src_sem is not None: + desc = pltpu.make_async_remote_copy( + src_ref, + dst_ref, + src_sem, + dst_sem, + device_id=device_id, + device_id_type=device_id_type, + ) + else: + assert device_id is None + desc = pltpu.make_async_copy(src_ref, dst_ref, dst_sem) + desc.start(priority=priority) + + src_x0_ref, src_x1_ref = src_ref.x0, src_ref.x1 + dst_x0_ref, dst_x1_ref = dst_ref.x0, dst_ref.x1 + + _run_dma( + state.TransformedRef(src_x0_ref, src_transforms), + state.TransformedRef(dst_x0_ref, dst_transforms), + src_sem, + dst_sem, + device_id, + device_id_type, + priority, + ) + _run_dma( + state.TransformedRef(src_x1_ref, src_transforms), + state.TransformedRef(dst_x1_ref, dst_transforms), + src_sem, + dst_sem, + device_id, + device_id_type, + priority, + ) + + def dma_wait( + self, src_ref, dst_ref, src_sem, dst_sem, device_id, device_id_type + ): + assert isinstance(jax.typeof(src_ref.ref).inner_aval, ShapedArrayTuple) + assert isinstance(jax.typeof(dst_ref.ref).inner_aval, ShapedArrayTuple) + + src_ref, src_transforms = src_ref.ref._refs, src_ref.transforms # pylint: disable=protected-access + dst_ref, dst_transforms = dst_ref.ref._refs, dst_ref.transforms # pylint: disable=protected-access + + def _run_dma( + src_ref, dst_ref, src_sem, dst_sem, device_id, device_id_type + ): + if src_sem is not None: + desc = pltpu.make_async_remote_copy( + src_ref, + dst_ref, + src_sem, + dst_sem, + device_id=device_id, + device_id_type=device_id_type, + ) + else: + assert device_id is None + desc = pltpu.make_async_copy(src_ref, dst_ref, dst_sem) + desc.wait() + + src_x0_ref, src_x1_ref = src_ref.x0, src_ref.x1 + dst_x0_ref, dst_x1_ref = dst_ref.x0, dst_ref.x1 + + _run_dma( + state.TransformedRef(src_x0_ref, src_transforms), + state.TransformedRef(dst_x0_ref, dst_transforms), + src_sem, + dst_sem, + device_id, + device_id_type, + ) + _run_dma( + state.TransformedRef(src_x1_ref, src_transforms), + state.TransformedRef(dst_x1_ref, dst_transforms), + src_sem, + dst_sem, + device_id, + device_id_type, + ) + + hijax.register_hitype( + ArrayTuple, lambda q: ShapedArrayTuple(q.shape, q.dtype) ) - @hp.given( - hps.integers(1, 1024), - hps.integers(1, 1024), - hps.integers(1, 1024), - hps.sampled_from([8, 16, 32, 128, 256, 512]), - hps.sampled_from([128, 256, 512]), - hps.sampled_from([128, 256, 512]), - hps.integers(0, 4), + + def kernel(x_hbm_ref, o_hbm_ref): + def body(x_ref, o_ref): + o_ref[...] = x_ref[...] + + num_steps = 4 + block_shape = (x_hbm_ref.shape[0] // num_steps, x_hbm_ref.shape[1]) + + pltpu.emit_pipeline( + body, + grid=(num_steps,), + in_specs=(pl.BlockSpec(block_shape, lambda i: (i, 0)),), + out_specs=pl.BlockSpec(block_shape, lambda i: (i, 0)), + )(x_hbm_ref, o_hbm_ref) + + inp = ArrayTuple( + jnp.arange(32 * 128, dtype=jnp.int32).reshape((32, 128)), + jnp.arange(32 * 128, dtype=jnp.int32).reshape((32, 128)), + ) + + out_ty = ShapedArrayTuple( + inp.shape, + inp.dtype, ) - def test_padded_matmul(self, dtype, m, k, n, bm, bk, bn, seed): - if dtype == 'int8' and jtu.is_device_tpu_at_least(6): - self.skipTest('Not implemented for TPU v6.') - - def align_up_to(x, y): - return (x + y - 1) // y * y - - hp.assume(bm <= m) - hp.assume(bn <= n) - hp.assume(bk <= k) - if dtype == 'bfloat16': - hp.assume(bm >= 16) - if dtype == 'int8': - if not jtu.is_device_tpu_at_least(5): - self.skipTest('Only TPU v5+ allowed for int8.') - hp.assume(bm >= 32) - # TODO(apaszke): Relax DMA restrictions and remove this. - packing = 4 // jnp.dtype(dtype).itemsize - if packing != 1: - m = align_up_to(m, 8 * packing) - k = align_up_to(k, 8 * packing) - k1, k2 = jax.random.split(jax.random.key(seed)) - x = jax.random.normal(k1, (m, k), jnp.float32).astype(dtype) - y = jax.random.normal(k2, (k, n), jnp.float32).astype(dtype) - - out = matmul(x, y, bm=bm, bk=bk, bn=bn) - expected = x @ y - atol = rtol = 2.3e-5 - if dtype == 'bfloat16': - out = out.astype('float32') - expected = expected.astype('float32') - atol = rtol = 1e-2 - np.testing.assert_allclose(out, expected, atol=atol, rtol=rtol) + + out = pl.pallas_call( + kernel, + in_specs=(pl.BlockSpec(memory_space=pl.ANY),), + out_shape=out_ty, + out_specs=pl.BlockSpec(memory_space=pl.ANY), + )(inp) + + np.testing.assert_allclose(out.x0, inp.x0) + np.testing.assert_allclose(out.x1, inp.x1) if __name__ == '__main__': diff --git a/tests/pallas/tpu_pallas_random_test.py b/tests/pallas/tpu_pallas_random_test.py index ca8edf7a269e..6a675bb0c7ae 100644 --- a/tests/pallas/tpu_pallas_random_test.py +++ b/tests/pallas/tpu_pallas_random_test.py @@ -14,12 +14,13 @@ from absl.testing import absltest from absl.testing import parameterized +import functools import jax from jax import random as jax_random from jax._src import test_util as jtu from jax._src.pallas.mosaic import random as plrandom from jax.experimental import pallas as pl -from jax.experimental import shard_map +from jax._src import shard_map from jax.experimental.pallas import tpu as pltpu from jax.experimental.pallas.ops.tpu.random import philox # pylint: disable=unused-import # noqa: F401 from jax.experimental.pallas.ops.tpu.random import threefry # pylint: disable=unused-import # noqa: F401 @@ -28,7 +29,6 @@ P = jax.sharding.PartitionSpec - jax.config.parse_flags_with_absl() @@ -47,15 +47,15 @@ def test_to_pallas_key_under_vmap(self, use_legacy_key: bool): else: key = jax.random.key(42, impl="rbg") key = jax.random.split(key, 10) - batched_key = plrandom.to_pallas_key(key) + batched_key = pltpu.to_pallas_key(key) batched_key_data = jax.random.key_data(batched_key) - vmapped_key = jax.vmap(plrandom.to_pallas_key)(key) + vmapped_key = jax.vmap(pltpu.to_pallas_key)(key) vmapped_key_data = jax.random.key_data(vmapped_key) np.testing.assert_array_equal(batched_key_data, vmapped_key_data) def test_pallas_key_raise_not_implemented_outside_of_kernel(self): key = jax_random.key(0, impl="rbg") - pallas_key = plrandom.to_pallas_key(key) + pallas_key = pltpu.to_pallas_key(key) # Using a pallas key outside of a kernel should raise an error when # trying to lower TPU-specific ops to XLA. # TODO(justinfu): Make this error more specific to pallas PRNG usage. @@ -105,56 +105,89 @@ def body(o_ref): self.assertLessEqual(jnp.max(result), np.iinfo(jnp.int32).max) self.assertGreaterEqual(jnp.min(result), np.iinfo(jnp.int32).min) - def test_stateful_uniform_sample(self): + @parameterized.parameters( + (pltpu.stateful_uniform, jnp.float32), + (pltpu.stateful_normal, jnp.float32), + ) + def test_stateful_sample(self, generator, dtype): # Test stateful RNG using the jax.random API wrappers. def body(key_ref, o_ref): - plrandom.set_seed(key_ref[...]) - o_ref[...] = plrandom.uniform( - shape=o_ref[...].shape, minval=0.0, maxval=1.0) + pltpu.prng_seed(key_ref[...]) + o_ref[...] = generator(shape=o_ref[...].shape) rbg_key = jax_random.key(0, impl="rbg") - key = plrandom.to_pallas_key(rbg_key) - o_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32) + key = pltpu.to_pallas_key(rbg_key) + o_shape = jax.ShapeDtypeStruct((8, 128), dtype) result = pl.pallas_call( body, - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM)], + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], out_shape=o_shape, )(key) - self.assertGreaterEqual(jnp.min(result), 0) - self.assertLessEqual(jnp.max(result), 1.0) + # Check that the numbers are different. + self.assertGreaterEqual(jnp.max(result), jnp.min(result)) - def test_stateless_uniform_sample(self): + @parameterized.parameters( + (jax_random.uniform, jnp.float32, None), + (jax_random.uniform, jnp.float32, (1, 1)), + (jax_random.normal, jnp.float32, None), + (jax_random.bits, jnp.uint32, None), + ) + def test_stateless_sample(self, generator, dtype, key_shape): # Test keyed RNG using the jax.random API. def body(key_ref, o_ref): - o_ref[...] = jax_random.uniform( - key_ref[...], shape=o_ref[...].shape, minval=0.0, maxval=1.0 + if key_shape: + key_ref = key_ref.at[*((0,) * len(key_shape))] + o_ref[...] = generator( + key_ref[...], shape=o_ref[...].shape ) rbg_key = jax_random.key(0, impl="rbg") - key = plrandom.to_pallas_key(rbg_key) - o_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32) + key = pltpu.to_pallas_key(rbg_key) + if key_shape: + key = jnp.reshape(key, key_shape) + o_shape = jax.ShapeDtypeStruct((8, 128), dtype) result = pl.pallas_call( body, - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM)], + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], out_shape=o_shape, )(key) - self.assertGreaterEqual(jnp.min(result), 0) - self.assertLessEqual(jnp.max(result), 1.0) + # Check that the numbers are different. + self.assertGreaterEqual(jnp.max(result), jnp.min(result)) def test_key_data(self): def body(key_ref, o_ref): - o_ref[...] = jax.random.key_data(key_ref[...]) + x0, x1 = plrandom.unwrap_pallas_seed(key_ref[...]) + o_ref[0, 0] = x0 + o_ref[0, 1] = x1 rbg_key = jax_random.key(0, impl="rbg") - key = plrandom.to_pallas_key(rbg_key) + key = pltpu.to_pallas_key(rbg_key) expected_key_data = jax.random.key_data(key) o_shape = jax.ShapeDtypeStruct(expected_key_data.shape, expected_key_data.dtype) result = pl.pallas_call( body, - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM)], + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], + out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), out_shape=o_shape, )(key) - self.assertEqual(result, expected_key_data) + self.assertArraysEqual(result, expected_key_data) + + def test_squeezed_blockspec(self): + @functools.partial( + pl.pallas_call, + grid=(), + in_specs=[ + pl.BlockSpec((pl.squeezed,), lambda: (0,), memory_space=pltpu.SMEM) + ], + out_specs=pl.BlockSpec((8, 128)), + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + ) + def kernel(key_ref, o_ref): + o_ref[...] = jax_random.uniform(key_ref[...], shape=o_ref.shape) + + # Just make sure this does not crash. + k = pltpu.to_pallas_key(jax_random.key(0, impl="rbg")) + kernel(k[None]) def test_fold_in(self): # Test that folding in a value results in different random numbers. @@ -164,25 +197,53 @@ def body(key_ref, o_ref): key, shape=o_ref[0, ...].shape, minval=0.0, maxval=1.0 ) - key = jax_random.fold_in(key, 2) + key = jax_random.fold_in(key, jnp.uint32(2)) o_ref[1, ...] = jax_random.uniform( key, shape=o_ref[1, ...].shape, minval=0.0, maxval=1.0 ) rbg_key = jax_random.key(0, impl="rbg") - key = plrandom.to_pallas_key(rbg_key) + key = pltpu.to_pallas_key(rbg_key) o_shape = jax.ShapeDtypeStruct((2, 8, 128), jnp.float32) result = pl.pallas_call( body, - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM)], + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], out_shape=o_shape, )(key) result_a = result[0] result_b = result[1] np.testing.assert_array_compare(np.not_equal, result_a, result_b) - -class BlockInvarianceTest(parameterized.TestCase): + def test_key_in_core_map(self): + if not jtu.is_device_tpu_at_least(4): + self.skipTest("Fails on TPU <= v3") + + def main(refs): + key_hbm, o_ref = refs + @pl.core_map(pltpu.create_tensorcore_mesh('core')) + def _(): + @functools.partial(pl.run_scoped, + key_smem=pltpu.SMEM((), key_hbm.dtype), + o_vmem=pltpu.VMEM(o_ref.shape, o_ref.dtype)) + def _scoped(key_smem, o_vmem): + pltpu.sync_copy(key_hbm, key_smem) + o_vmem[...] = jax_random.uniform( + key_smem[...], shape=o_ref.shape, minval=0.0, maxval=1.0 + ) + pltpu.sync_copy(o_vmem, o_ref) + + @jax.jit + def f(rng_key): + y = jnp.zeros((8, 128), dtype=jnp.float32) + _, y = pl.run_state(main)((rng_key, y)) + return y + + key = pltpu.to_pallas_key(jax_random.key(0, impl="rbg")) + y = f(key) + self.assertGreaterEqual(jnp.max(y), jnp.min(y)) + + +class BlockInvarianceTest(jtu.JaxTestCase): def setUp(self): if not jtu.test_device_matches(["tpu"]): @@ -194,7 +255,7 @@ def test_block_invariance(self): def make_kernel_body(index_map): def body(key_ref, o_ref): key = key_ref[...] - samples = plrandom.sample_block( + samples = pltpu.sample_block( jax.random.uniform, key, block_size=o_ref[...].shape, @@ -208,7 +269,7 @@ def body(key_ref, o_ref): global_key = jax_random.key(0, impl="pallas_tpu") o_shape = jnp.ones((64, 512), dtype=jnp.float32) - key_spec = pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM) + key_spec = pl.BlockSpec(memory_space=pltpu.SMEM) out_spec = pl.BlockSpec((16, 128), lambda i, j: (i, j)) result_16x128 = pl.pallas_call( make_kernel_body(index_map=lambda i, j: (i, j)), @@ -229,37 +290,34 @@ def body(key_ref, o_ref): np.testing.assert_array_equal(result_16x128, result_32x256) -class ThreefryTest(parameterized.TestCase): +class ThreefryTest(jtu.JaxTestCase): def setUp(self): if not jtu.test_device_matches(["tpu"]): self.skipTest("Need TPU devices") super().setUp() - @parameterized.parameters( - ((8, 128),), - ((32, 256),), - ((4, 16, 128),), + @parameterized.product( + shape=((8, 128), (32, 256), (4, 16, 128)), + generator_and_dtype=((jax_random.uniform, jnp.float32), + (jax_random.normal, jnp.float32), + (jax_random.bits, jnp.uint32)) ) - def test_uniform_matches_jax_threefry(self, shape): + def test_pallas_matches_jax_threefry(self, shape, generator_and_dtype): + generator, dtype = generator_and_dtype def body(key_ref, o_ref): - key = jax.random.wrap_key_data(key_ref[0, ...], impl='threefry2x32') - o_ref[...] = jax_random.uniform( - key, shape=o_ref[...].shape, minval=0.0, maxval=1.0 - ) + key = key_ref[...] + o_ref[...] = generator(key, shape=o_ref[...].shape) - threefry_key = jax_random.key(0, impl="threefry2x32").reshape((1,)) - o_shape = jax.ShapeDtypeStruct(shape, jnp.float32) + threefry_key = jax_random.key(0, impl="threefry2x32") + o_shape = jax.ShapeDtypeStruct(shape, dtype) with jax.threefry_partitionable(True): - # TODO(justinfu): support passing keys into VMEM. result = pl.pallas_call( body, - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM)], + in_specs=[pl.BlockSpec(memory_space=pltpu.VMEM)], out_shape=o_shape, - )(jax.random.key_data(threefry_key)) - jax_result = jax_random.uniform( - threefry_key[0], shape=o_shape.shape, minval=0.0, maxval=1.0 - ) + )(threefry_key) + jax_result = generator(threefry_key, shape=o_shape.shape) np.testing.assert_array_equal(result, jax_result) @parameterized.parameters( @@ -267,7 +325,7 @@ def body(key_ref, o_ref): ((137, 275),), # Non block-aligned shape ((4, 512, 512),), # Greater than 2D shape ((34,),), # 1D - (tuple(),), # 0D + ((),), # 0D ) def test_threefry_kernel_matches_jax_threefry(self, shape): with jax.threefry_partitionable(True): @@ -288,7 +346,11 @@ def test_threefry_kernel_matches_jax_threefry_sharded(self, shape): self.skipTest("Need at least 2 devices") num_devices = jax.device_count() partition = P("x") - mesh = jax.make_mesh((num_devices,), ("x",)) + mesh = jax.make_mesh( + (num_devices,), + ("x",), + axis_types=(jax.sharding.AxisType.Auto,), + ) sharding = jax.sharding.NamedSharding(mesh, partition) with jax.threefry_partitionable(True): @@ -303,6 +365,7 @@ def test_threefry_kernel_matches_jax_threefry_sharded(self, shape): mesh=mesh, in_specs=partition, out_specs=partition, + check_vma=False, ) jax_gen = generate(key_jax) pl_gen = generate(key_pallas) @@ -310,19 +373,32 @@ def test_threefry_kernel_matches_jax_threefry_sharded(self, shape): np.testing.assert_array_equal(jax_gen, pl_gen) -class PhiloxTest(parameterized.TestCase): +class PhiloxTest(jtu.JaxTestCase): def setUp(self): if not jtu.test_device_matches(["tpu"]): self.skipTest("Need TPU devices") super().setUp() + @parameterized.product( + x=[0x1, 0x10000, 0xabcdef], + y=[0x1, 0x10000, 0xabcdef], + ) + def test_mul_hi_lo(self, x, y): + x = jnp.uint32(x) + y = jnp.uint32(y) + hi, lo = philox.mul32_hi_lo(x, y) + with jax.enable_x64(): + result = (hi.astype(jnp.uint64) << 32) + lo.astype(jnp.uint64) + ref = x.astype(jnp.uint64) * y.astype(jnp.uint64) + self.assertEqual(result, ref) + @parameterized.parameters( ((512, 512),), ((137, 275),), # Non block-aligned shape ((4, 512, 512),), # Greater than 2D shape ((34,),), # 1D - (tuple(),), # 0D + ((),), # 0D ) def test_generate_uniform(self, shape): key = jax_random.key(0, impl="pallas_philox4x32") diff --git a/tests/pallas/tpu_pallas_state_test.py b/tests/pallas/tpu_pallas_state_test.py index 46f98c087110..013b2e31fc7c 100644 --- a/tests/pallas/tpu_pallas_state_test.py +++ b/tests/pallas/tpu_pallas_state_test.py @@ -14,8 +14,10 @@ import functools from absl.testing import absltest +from absl.testing import parameterized import jax from jax._src import test_util as jtu +from jax._src.state.primitives import pin, unpin from jax.experimental import pallas as pl from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp @@ -117,6 +119,8 @@ def f_stateful(refs): x = pl.pallas_call( functools.partial(copy_kernel, x_ref, y_ref), + in_specs=[pl.BlockSpec(memory_space=pl.ANY)], + out_specs=pl.BlockSpec(memory_space=pl.ANY), scratch_shapes=[pltpu.SemaphoreType.DMA], out_shape=jax.ShapeDtypeStruct(x_ref.shape, x_ref.dtype), input_output_aliases={0: 0}, @@ -184,12 +188,54 @@ def matmul_pipeline_kernel(acc_ref): y = jax.random.normal(jax.random.key(1), (k, n), jnp.float32) o = matmul(x, y) atol = 0 - if jtu.is_device_tpu(6): + if jtu.is_device_tpu_at_least(6): atol = 2e-5 np.testing.assert_allclose(o, x @ y, atol=atol) -class ShmallasTest(jtu.JaxTestCase): +class PinnedBufferTest(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + if not jtu.is_device_tpu_at_least(4): + self.skipTest("Only supported on TPU v4+") + + def test_basic(self): + + @jax.jit + def f(x): + x_pinned = pin(x) + x_pinned = pl.pallas_call( + lambda *_: None, out_shape=x_pinned, + in_specs=[pl.BlockSpec(memory_space=pl.ANY)], + out_specs=pl.BlockSpec(memory_space=pl.ANY), + input_output_aliases={0: 0} + )(x_pinned) + return unpin(x_pinned) + + x = jnp.arange(3.) + y = f(x) + self.assertAllClose(y, x, check_dtypes=False) + + def test_error_if_not_aliased(self): + + @jax.jit + def f(x): + x_pinned = pin(x) + x_pinned = pl.pallas_call( + lambda *_: None, out_shape=x_pinned, + in_specs=[pl.BlockSpec(memory_space=pl.ANY)], + out_specs=pl.BlockSpec(memory_space=pl.ANY), + # input_output_aliases={0: 0} # no aliasing! + )(x_pinned) + return unpin(x_pinned) + + x = jnp.arange(3.) + with self.assertRaisesRegex(ValueError, r"pinned buffers without"): + f(x) + + +class CoreMapTest(jtu.JaxTestCase): def setUp(self): super().setUp() @@ -199,56 +245,81 @@ def setUp(self): def test_can_create_tensorcore_mesh(self): _ = pltpu.create_tensorcore_mesh("x") - def test_can_run_basic_pallas_kernel_with_core_map(self): + def test_kernel_helper_basic(self): mesh = pltpu.create_tensorcore_mesh("x") - + def body(x_ref, o_ref): + pltpu.sync_copy(x_ref, o_ref) + x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128)) + with self.subTest("decorator"): + result = pl.kernel(body, out_shape=x, mesh=mesh)(x) + np.testing.assert_array_equal(result, x) + with self.subTest("decorator_factory"): + result = pl.kernel(out_shape=x, mesh=mesh)(body)(x) + np.testing.assert_array_equal(result, x) + + def test_empty_core_map_raises_error(self): @jax.jit def f(x): y = jnp.zeros_like(x) def inner(refs): - x_ref, y_ref = refs - @pl.core_map(mesh) + del refs # Unused. + @pl.core_map(pltpu.create_tensorcore_mesh("x")) def _(): - def alloc(sem): - pltpu.async_copy(x_ref, y_ref, sem).wait() - pl.run_scoped(alloc, pltpu.SemaphoreType.DMA) + pass _, y = pl.run_state(inner)((x, y)) return y x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128)) - y = f(x) - np.testing.assert_array_equal(y, x) + with self.assertRaisesRegex(Exception, + "Attempted to lower core_map without discharging."): + f(x) - def test_can_query_core_index_pallas_kernel_with_core_map(self): + def test_can_signal_cores(self): + @jax.jit + def f(x): + x_ref = jax.new_ref(x) + y_ref = jax.new_ref(jnp.empty_like(x)) + @pl.core_map(pltpu.create_tensorcore_mesh("x")) + def _(): + @functools.partial(pl.run_scoped, sem=pltpu.SemaphoreType.REGULAR) + def inner(sem): + s = jax.lax.axis_size("x") + for i in range(s): + pl.semaphore_signal(sem, device_id={"x": i}) + pl.semaphore_wait(sem, s) + pltpu.sync_copy(x_ref, y_ref) + return jax.freeze(y_ref) + x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128)) + np.testing.assert_array_equal(f(x), x) + + def test_can_query_core_index(self): mesh = pltpu.create_tensorcore_mesh("x") + slc_size = 16 // mesh.shape["x"] @jax.jit def f(x): - y = jnp.zeros_like(x) - def inner(refs): - x_ref, y_ref = refs - @pl.core_map(mesh) - def _(): - num_cores = jax.lax.psum(1, "x") - slc_size = 16 // num_cores - def alloc(x_vmem_ref, y_vmem_ref, sem): - core_index = jax.lax.axis_index("x") - slc = pl.ds(core_index * slc_size, slc_size) - pltpu.async_copy( - x_ref.at[slc], - x_vmem_ref, - sem, - ).wait() - y = x_vmem_ref[...] + jax.lax.axis_index("x") - y_vmem_ref[...] = y - pltpu.async_copy(y_vmem_ref, y_ref.at[slc], sem).wait() - pl.run_scoped( - alloc, - pltpu.VMEM((slc_size, 128), x_ref.dtype), - pltpu.VMEM((slc_size, 128), y_ref.dtype), + @pl.kernel( + out_shape=x, + mesh=mesh, + scratch_shapes=[ + pltpu.VMEM((slc_size, 128), x.dtype), + pltpu.VMEM((slc_size, 128), x.dtype), pltpu.SemaphoreType.DMA, - ) - _, y = pl.run_state(inner)((x, y)) - return y + ], + ) + def kernel(x_ref, y_ref, x_vmem_ref, y_vmem_ref, sem): + num_cores = jax.lax.axis_size("x") + slc_size = 16 // num_cores + core_index = jax.lax.axis_index("x") + slc = pl.ds(core_index * slc_size, slc_size) + pltpu.async_copy( + x_ref.at[slc], + x_vmem_ref, + sem, + ).wait() + y = x_vmem_ref[...] + jax.lax.axis_index("x") + y_vmem_ref[...] = y + pltpu.async_copy(y_vmem_ref, y_ref.at[slc], sem).wait() + return kernel(x) num_cores = jax.devices()[0].num_cores x = jnp.arange(16 * 128, dtype=jnp.int32).reshape((16, 128)) expected_out = ( @@ -257,6 +328,101 @@ def alloc(x_vmem_ref, y_vmem_ref, sem): y = f(x) np.testing.assert_array_equal(y, expected_out) + def test_raises_on_captured_arrays(self): + @jax.jit + def f(x): + y = jnp.zeros_like(x) + + @pl.kernel(out_shape=x, + mesh=pltpu.create_tensorcore_mesh("x"), + scratch_shapes=dict(tmp_ref=pltpu.VMEM(x.shape, x.dtype))) + def kernel(x_ref, out_ref, tmp_ref): + pltpu.sync_copy(x_ref, tmp_ref) + tmp_ref[...] += y + out_ref[...] = tmp_ref[...] + return kernel(x) + + x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128)) + with self.assertRaisesRegex( + Exception, "core_map .* captures non-scalar constants" + ): + f(x) + + def test_capture_scalar(self): + @jax.jit + def f(x, i): + @pl.kernel(out_shape=jax.ShapeDtypeStruct(x.shape[1:], jnp.int32), + mesh=pltpu.create_tensorcore_mesh("x", num_cores=1)) + def kernel(x_ref, out_ref): + pltpu.sync_copy(x_ref.at[i], out_ref) + return kernel(x) + + x = jnp.arange(4 * 8 * 128, dtype=jnp.int32).reshape((4, 8, 128)) + for i in range(x.shape[0]): + out = f(x, i) + np.testing.assert_array_equal(out, x[i]) + + @jax.jit + def g(x, i): + @pl.kernel(out_shape=jax.ShapeDtypeStruct((2, *x.shape[1:]), jnp.int32), + mesh=pltpu.create_tensorcore_mesh("x", num_cores=1)) + def kernel(x_ref, out_ref): + pltpu.sync_copy(x_ref.at[pl.ds(i, 2)], out_ref) + return kernel(x) + + x = jnp.arange(4 * 8 * 128, dtype=jnp.int32).reshape((4, 8, 128)) + for i in range(3): + out = g(x, i) + np.testing.assert_array_equal(out, x[i:i+2]) + + def test_kernel_helper_with_scratch(self): + mesh = pltpu.create_tensorcore_mesh("x") + def body(x_ref, o_ref, scratch_ref): + pltpu.sync_copy(x_ref, scratch_ref) + scratch_ref[...] += 1 + pltpu.sync_copy(scratch_ref, o_ref) + x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128)) + result = pl.kernel( + body, out_shape=x, mesh=mesh, + scratch_shapes=dict(scratch_ref=pltpu.VMEM(x.shape, x.dtype)))(x) + np.testing.assert_array_equal(result, x + 1) + + def test_kernel_helper_with_out_tree(self): + mesh = pltpu.create_tensorcore_mesh("x") + def body(x_ref, o1_ref, o2_ref, scratch_ref): + pltpu.sync_copy(x_ref, o1_ref) + pltpu.sync_copy(x_ref, scratch_ref) + scratch_ref[...] += 1 + pltpu.sync_copy(scratch_ref, o2_ref) + x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128)) + result1, result2 = pl.kernel( + body, out_shape=[x, x], mesh=mesh, + scratch_shapes=[pltpu.VMEM(x.shape, x.dtype)])(x) + np.testing.assert_array_equal(result1, x) + np.testing.assert_array_equal(result2, x + 1) + + @parameterized.named_parameters( + ("HBM", pltpu.HBM, 0), + ("VMEM", pltpu.VMEM, 1), + ("SMEM", pltpu.SMEM, 4), + ("SEMAPHORE", pltpu.SEMAPHORE, 2), + ) + def test_kernel_with_output_memory_space(self, memory_space, color): + if not jtu.is_device_tpu_at_least(5): + self.skipTest("Only supported on TPU v5+") + mesh = pltpu.create_tensorcore_mesh("x", num_cores=1) + def body(x_ref, o_ref): + pltpu.sync_copy(x_ref, o_ref) + x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128)) + text = pl.kernel( + body, out_shape=memory_space(x.shape, x.dtype), mesh=mesh, + ).lower(x).as_text() + custom_call = [l for l in text.split("\n") if "@tpu_custom_call" in l] + self.assertLen(custom_call, 1) + custom_call = custom_call[0] + self.assertRegex(custom_call, + r".*output_memory_colors\\22: \[" + str(color) + r"\].*") + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 55831ff6af1d..95e0030cb91d 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -14,32 +14,33 @@ """Test TPU-specific extensions to pallas_call.""" +from collections.abc import Callable import contextlib import functools -import itertools import gc import io +import itertools +import json import math import re import sys -from typing import Callable from absl.testing import absltest from absl.testing import parameterized import jax from jax import api_util from jax import lax from jax._src import checkify +from jax._src import shard_map from jax._src import state from jax._src import test_util as jtu from jax._src.interpreters import partial_eval as pe -from jax._src.lib import xla_extension -from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr -from jax._src.state import utils as state_utils +from jax._src.pallas import pallas_test_util as ptu +from jax._src.pallas.mosaic import error_handling from jax._src.state import discharge as state_discharge +from jax._src.state import utils as state_utils from jax.experimental import mesh_utils from jax.experimental import mosaic from jax.experimental import pallas as pl -from jax.experimental import shard_map from jax.experimental.pallas import tpu as pltpu from jax.experimental.pallas.ops.tpu import example_kernel from jax.extend import linear_util as lu @@ -54,6 +55,25 @@ partial = functools.partial +def only_passes_in_interpret( + unless_generation: int | None = None, *args, **kwargs +): + def decorator(f, *args, **kwargs): + def wrapper(self, *args, **kwargs): + if self.INTERPRET or ( + unless_generation is not None + and jtu.is_device_tpu_at_least(unless_generation) + ): + f(self, *args, **kwargs) + else: + with self.assertRaises(Exception): + f(self, *args, **kwargs) + + return wrapper + + return decorator + + @contextlib.contextmanager def string_stdout(): """Redirects stdout to a string.""" @@ -71,27 +91,13 @@ def wrap_init(f: Callable, nr_args: int): debug_info=api_util.debug_info("state_test", f, (0,) * nr_args, {})) -class PallasBaseTest(jtu.JaxTestCase): - INTERPRET: bool = False - - def setUp(self): - if not jtu.test_device_matches(['tpu']) and not self.INTERPRET: - self.skipTest('Test requires TPUs, or interpret mode') - super().setUp() - _trace_kernel_to_jaxpr.cache_clear() - - def pallas_call(self, *args, **kwargs): - return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET) - -class TPUPipelineModeTest(PallasBaseTest): +class TPUPipelineModeTest(ptu.PallasTPUTest): @parameterized.parameters( (pl.Buffered(2), pl.Buffered(2)), (pl.Buffered(2), pl.Buffered(1)), (pl.Buffered(1), pl.Buffered(1))) def test_two_input_vadd(self, x_pmode : pl.Buffered, y_pmode : pl.Buffered): - if not jtu.if_cloud_tpu_at_least(2025, 2, 11): - self.skipTest("Needs a newer libTPU") def body(x_ref, y_ref, o_ref): x = x_ref[:] y = y_ref[:] @@ -136,7 +142,8 @@ def vadd(x, y): z = vadd(x, y) np.testing.assert_allclose(z, x + y) -class PallasCallScalarPrefetchTest(PallasBaseTest): + +class PallasCallScalarPrefetchTest(ptu.PallasTPUTest): def test_trivial_scalar_prefetch(self): def body(_, x_ref, o_ref): o_ref[...] = x_ref[...] @@ -145,8 +152,7 @@ def body(_, x_ref, o_ref): x = jnp.arange(8 * 8 * 128, dtype=jnp.int32).reshape((8 * 8, 128)) def _x_transform(i, s_ref): - s = pl.load(s_ref, (i,)) - return (s, 0) + return (s_ref[i], 0) out = self.pallas_call( body, @@ -164,6 +170,47 @@ def _x_transform(i, s_ref): )(s, x) np.testing.assert_allclose(out, x.reshape((8, 8, -1))[s].reshape(x.shape)) + @parameterized.parameters( + (jnp.bfloat16, 0), + (jnp.bfloat16, 3), + (jnp.bfloat16, 129), + (jnp.int16, 2), + (jnp.int16, 5), + (jnp.int16, 257), + (jnp.int8, 311), + (jnp.int8, 597), + (jnp.int8, 1025), + ) + def test_narrow_bitwidth_scalar_prefetch(self, dtype, index): + def body(s_ref, o_ref): + o_ref[...] = jnp.broadcast_to(s_ref[index], o_ref.shape) + + s = jnp.arange(16 * 128, dtype=dtype) + out = self.pallas_call( + body, + out_shape=jax.ShapeDtypeStruct((16, 128), dtype), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + ), + )(s) + np.testing.assert_array_equal(out, jnp.broadcast_to(s[index], (16, 128))) + + def test_f32_scalar_prefetch(self): + def body(s_ref, x_ref, o_ref): + o_ref[...] = x_ref[...] + jnp.broadcast_to(s_ref[0], x_ref.shape) + + s = jnp.array([5.0], jnp.float32) + x = jnp.arange(8 * 128, dtype=jnp.float32).reshape((8, 128)) + + out = self.pallas_call( + body, + out_shape=jax.ShapeDtypeStruct(x.shape, jnp.float32), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + ), + )(s, x) + np.testing.assert_allclose(out, x + jnp.broadcast_to(s, x.shape)) + def test_trivial_scalar_prefetch_with_windowless_args(self): def body(_, x_ref, o_ref): o_ref[...] = x_ref[...] @@ -210,13 +257,16 @@ def f(x, grid_size, to_store): grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=1, # 1 pytree grid=(grid_size,), - in_specs=[pl.BlockSpec((8, 128), - lambda i, s_ref: (pl.load(s_ref[0], (i,)), 0)), - pl.BlockSpec((1, 128), lambda i, s_ref: (0, 0))], - out_specs=pl.BlockSpec((32, 128), - lambda i, s_ref: (pl.load(s_ref[0], i), 0)), - scratch_shapes=([pltpu.SemaphoreType.REGULAR((3,))] if scratch - else []), + in_specs=[ + pl.BlockSpec((8, 128), lambda i, s_ref: (s_ref[0][i], 0)), + pl.BlockSpec((1, 128), lambda i, s_ref: (0, 0)), + ], + out_specs=pl.BlockSpec( + (32, 128), lambda i, s_ref: (s_ref[0][i], 0) + ), + scratch_shapes=( + [pltpu.SemaphoreType.REGULAR((3,))] if scratch else [] + ), ), ) def kernel(s_refs, src, to_store, dst, *scratch_refs): @@ -225,7 +275,7 @@ def kernel(s_refs, src, to_store, dst, *scratch_refs): assert s2.shape == (3,) assert s3 is None store_idx = s_ref[pl.program_id(0)] - pl.store(dst, (pl.dslice(store_idx, 1), slice(None)), to_store[...]) + dst[pl.dslice(store_idx, 1), :] = to_store[...] # Pass a pytree of scalar return kernel((s, np.arange(3, dtype=np.int32), None), x, to_store) @@ -281,7 +331,7 @@ def body(_, x_ref, o_ref): x = jnp.arange(2 * 8 * 8 * 128, dtype=jnp.int32).reshape((2, 8 * 8, 128)) def _x_transform(i, s_ref): - s = pl.load(s_ref, (i,)) + s = s_ref[i] return (s, 0) def f(x): @@ -423,7 +473,7 @@ def body(_, x_ref, o_ref): x = jnp.arange(8 * 8 * 128, dtype=jnp.int32).reshape((8 * 8, 128)) def _x_transform(i, s_ref): - s = pl.load(s_ref, (i,)) + s = s_ref[i] return (s, 0) s = s[None] @@ -457,7 +507,7 @@ def body(_, x_ref, o_ref): x = jnp.arange(2 * 8 * 8 * 128, dtype=jnp.int32).reshape((2, 8 * 8, 128)) def _x_transform(i, s_ref): - s = pl.load(s_ref, (i,)) + s = s_ref[i] return (s, 0) s = jnp.tile(s[None], [2, 1]) @@ -478,7 +528,7 @@ def kernel(s, x): ), grid=8, ), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( allow_input_fusion=[False, True] ), )(s, x) @@ -524,7 +574,7 @@ class PallasCallScalarPrefetchInterpretTest(PallasCallScalarPrefetchTest): INTERPRET: bool = True -class PallasCallDynamicGridTest(PallasBaseTest): +class PallasCallDynamicGridTest(ptu.PallasTPUTest): def test_can_query_grid_statically_via_num_programs(self): @@ -634,6 +684,33 @@ def dynamic_kernel(steps): dynamic_kernel(jnp.int32(4)), np.full(shape, 42.0, np.float32) ) + def test_dynamic_grid_scalar_input_with_input_memory_space(self): + if not jtu.is_device_tpu_at_least(5): + self.skipTest('Needs a newer TPU') + shape = (8, 128) + result_ty = jax.ShapeDtypeStruct(shape, jnp.float32) + + def kernel(scalar_input_ref, output_ref): + output_ref[...] = jnp.full_like(output_ref, scalar_input_ref[0, 0]) + + @jax.jit + def dynamic_kernel(steps): + scalar_input = jnp.array([[42]], dtype=jnp.int32) + scalar_input = pltpu.with_memory_space_constraint( + scalar_input, pltpu.VMEM + ) + return self.pallas_call( + kernel, + out_shape=result_ty, + in_specs=[pl.BlockSpec(memory_space=pltpu.VMEM)], + out_specs=pl.BlockSpec(shape, lambda i: (0, 0)), + grid=(steps * 2,), + )(scalar_input) + + np.testing.assert_array_equal( + dynamic_kernel(jnp.int32(4)), np.full(shape, 42.0, np.float32) + ) + def test_vmap_trivial_dynamic_grid(self): shape = (8, 128) result_ty = jax.ShapeDtypeStruct(shape, jnp.float32) @@ -774,7 +851,7 @@ class PallasCallDynamicGridInterpretTest(PallasCallDynamicGridTest): INTERPRET = True -class PallasCallDMATest(PallasBaseTest): +class PallasCallDMATest(ptu.PallasTPUTest): def setUp(self): super().setUp() @@ -804,7 +881,7 @@ def body(temp_ref): pl.run_scoped(body, pltpu.VMEM((8,), jnp.float32)) return [] - jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( + jaxpr, _, _ = pe.trace_to_jaxpr_dynamic( wrap_init(kernel, 2), [ state.shaped_array_ref((8,), jnp.float32), @@ -842,7 +919,7 @@ def body(x_ref): kernel, grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), ), out_shape=jax.ShapeDtypeStruct((1,), jnp.int32), )() @@ -862,7 +939,7 @@ def body(x_ref): kernel, grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), ), out_shape=jax.ShapeDtypeStruct((2,), jnp.int32), )() @@ -881,7 +958,7 @@ def body(x_ref): kernel, grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), ), out_shape=jax.ShapeDtypeStruct((16, 128), jnp.int32), )() @@ -900,7 +977,7 @@ def body(x_ref): kernel, grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), ), out_shape=jax.ShapeDtypeStruct((17, 128), jnp.int32), )() @@ -935,8 +1012,7 @@ def scope(): aref1 = state.AbstractRef(jax.core.ShapedArray((4,), jnp.dtype('float32'))) aref2 = state.AbstractRef(jax.core.ShapedArray((4,), jnp.dtype('float32'))) in_avals = [aref1, aref2] - stateful_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 2), - in_avals) + stateful_jaxpr, _, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 2), in_avals) discharged_jaxpr, _ = state_discharge.discharge_state( stateful_jaxpr, consts=(), should_discharge=[False, True]) self.assertLen(discharged_jaxpr.invars, 2) @@ -1100,7 +1176,7 @@ def body(sems): y = jax.block_until_ready( self.pallas_call( kernel, - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), out_shape=jax.ShapeDtypeStruct((m, n), jnp.int32), )() ) @@ -1123,7 +1199,7 @@ def kernel(x_hbm_ref, y_hbm_ref, sem_val_ref, dma_sem): in_specs=[pl.BlockSpec(memory_space=pl.ANY)], out_specs=[ pl.BlockSpec(memory_space=pl.ANY), - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + pl.BlockSpec(memory_space=pltpu.SMEM), ], scratch_shapes=[pltpu.SemaphoreType.DMA], ), @@ -1136,11 +1212,41 @@ def kernel(x_hbm_ref, y_hbm_ref, sem_val_ref, dma_sem): np.testing.assert_array_equal(y, x) np.testing.assert_array_equal(sem_val, 0) + def test_set_dma_priority(self): + if jtu.get_tpu_version() < 5: + self.skipTest('Target does not support DMA prefetch between HBM and VMEM') + def kernel(x1, x2, y1, y2, scratch1, scratch2, sem1, sem2): + copy1 = pltpu.async_copy(x1, scratch1, sem1, priority=1) + copy2 = pltpu.async_copy(x2, scratch2, sem2, priority=0) + copy1.wait() + copy2.wait() + copy1 = pltpu.async_copy(scratch1, y1, sem1, priority=0) + copy2 = pltpu.async_copy(scratch2, y2, sem2, priority=1) + copy1.wait() + copy2.wait() + + shape = (8, 128) + dtype = jnp.int32 + x1 = jnp.arange(np.prod(shape), dtype=dtype).reshape(shape) + x2 = x1 + 1 + y1, y2 = self.pallas_call( + kernel, + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[pl.BlockSpec(memory_space=pl.ANY)] * 2, + scratch_shapes=[pltpu.VMEM(shape, dtype)] * 2 + + [pltpu.SemaphoreType.DMA] * 2, + out_specs=[pl.BlockSpec(memory_space=pl.ANY)] * 2, + ), + out_shape=[jax.ShapeDtypeStruct(shape, dtype)] * 2, + )(x1, x2) + np.testing.assert_array_equal(y1, x1) + np.testing.assert_array_equal(y2, x2) + def test_hbm_hbm_dma(self): def kernel(x_hbm_ref, y_hbm_ref): def body(sem): - pltpu.async_copy(x_hbm_ref.at[pl.ds(8), :], y_hbm_ref.at[:, pl.ds(128)], - sem).wait() + pltpu.async_copy(x_hbm_ref.at[:8, :], y_hbm_ref.at[:, :128], sem).wait() pl.run_scoped(body, pltpu.SemaphoreType.DMA) x = jnp.arange(8 * 128.).reshape((8, 128)) y = self.pallas_call( @@ -1153,6 +1259,67 @@ def body(sem): )(x) np.testing.assert_array_equal(y, x) + def test_host_input_host_to_hbm_dma(self): + if self.INTERPRET: + self.skipTest('Interpret mode does not support host memory.') + if jax.device_count() > 1: + self.skipTest("Test only works with a single device.") + def kernel(x_host_ref, y_hbm_ref): + def body(sem): + pltpu.async_copy(x_host_ref, y_hbm_ref, sem).wait() + + pl.run_scoped(body, pltpu.SemaphoreType.DMA) + + x = jnp.arange(8 * 128.0).reshape((8, 128)) + # Move input to the host. + x = jax.device_put( + x, + jax.sharding.NamedSharding( + jax.sharding.Mesh(jax.devices(), 'x'), + jax.sharding.PartitionSpec(), + memory_kind='pinned_host', + ), + ) + y = self.pallas_call( + kernel, + in_specs=[ + pl.BlockSpec(memory_space=pl.HOST), + ], + out_specs=pl.BlockSpec(memory_space=pl.ANY), + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + )(x) + np.testing.assert_array_equal(y, x) + + def test_hbm_to_host_host_output_dma(self): + if jax.device_count() > 1: + self.skipTest("Test only works with a single device.") + def kernel(y_hbm_ref, x_host_ref): + def body(sem): + pltpu.async_copy(y_hbm_ref, x_host_ref, sem).wait() + + pl.run_scoped(body, pltpu.SemaphoreType.DMA) + + host_sharding = jax.sharding.NamedSharding( + jax.sharding.Mesh(jax.devices(), 'x'), + jax.sharding.PartitionSpec(), + memory_kind='pinned_host', + ) + x = jnp.arange(8 * 128.0).reshape((8, 128)) + + @functools.partial(jax.jit, out_shardings=host_sharding) + def f(x): + return self.pallas_call( + kernel, + in_specs=[ + pl.BlockSpec(memory_space=pl.ANY), + ], + out_specs=pl.BlockSpec(memory_space=pl.HOST), + out_shape=pltpu.HOST(shape=(8, 128), dtype=jnp.float32), + )(x) + + y = f(x) + np.testing.assert_array_equal(y, x) + def test_cannot_dma_with_nonscalar_semaphore_ref(self): def kernel(x_hbm_ref, y_hbm_ref): def body(sem): @@ -1347,7 +1514,7 @@ def body(y_ref, sem): y = self.pallas_call( kernel, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + pl.BlockSpec(memory_space=pltpu.SMEM), ], out_specs=pl.BlockSpec(memory_space=pl.ANY), out_shape=jax.ShapeDtypeStruct((1, 2), jnp.float32), @@ -1364,9 +1531,9 @@ def body(sem): y = self.pallas_call( kernel, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.VMEM), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), )(x) np.testing.assert_allclose(y, x) @@ -1389,7 +1556,7 @@ def body(sem): in_specs=[ pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), out_shape=jax.ShapeDtypeStruct((16, 128), jnp.float32), )(x) np.testing.assert_allclose(y, x) @@ -1412,7 +1579,7 @@ def body(sem): in_specs=[ pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), out_shape=jax.ShapeDtypeStruct((16, 128), jnp.float32), )(x) np.testing.assert_allclose(y, x.reshape((16, 128))) @@ -1441,7 +1608,7 @@ def body(sem): in_specs=[ pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), out_shape=jax.ShapeDtypeStruct((3, 16, 128), jnp.float32), )(x) np.testing.assert_allclose(y, x.reshape((3, 16, 128))) @@ -1468,7 +1635,7 @@ def body(sem): in_specs=[ pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), out_shape=jax.ShapeDtypeStruct((16, 128), jnp.float32), )(x) @@ -1517,7 +1684,6 @@ def kernel(y_ref, scratch_ref): out_specs=pl.BlockSpec((None, 8, 128), lambda i: (i, 0, 0)), grid=(2,), ), - debug=True, out_shape=jax.ShapeDtypeStruct((2, 8, 128), jnp.int32), )() expected = jnp.broadcast_to(jnp.arange(2, dtype=jnp.int32)[..., None, None], @@ -1540,12 +1706,13 @@ def kernel(x_bbm_ref, y_ref, sem, dma_sem): ], scratch_shapes=[pltpu.SemaphoreType.REGULAR, pltpu.SemaphoreType.DMA], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), ), out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), )(x) np.testing.assert_array_equal(y, x) + @jtu.thread_unsafe_test() # Uses a lot of TPU memory. def test_large_array_indexing(self): n = 6 dtype = jnp.bfloat16 @@ -1586,7 +1753,7 @@ def run(array, data, index, size): kernel, out_shape=array, in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), pl.BlockSpec(memory_space=pltpu.VMEM), pl.BlockSpec(memory_space=pltpu.SMEM), pl.BlockSpec(memory_space=pltpu.SMEM), @@ -1594,7 +1761,7 @@ def run(array, data, index, size): scratch_shapes=[ pltpu.SemaphoreType.DMA, ], - out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), input_output_aliases={0: 0}, )(array, data, index, size) @@ -1609,6 +1776,24 @@ def run(array, data, index, size): result = run(array, data, index, size) np.testing.assert_array_equal(result, expected) + def test_unused_dma_descriptor_error(self): + x = jnp.arange(8 * 128.0).reshape((8, 128)) + + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + in_specs=[pl.BlockSpec(memory_space=pltpu.HBM)], + scratch_shapes=[pltpu.SemaphoreType.DMA], + out_specs=pl.BlockSpec(memory_space=pltpu.HBM), + ) + def kernel(x_hbm_ref, o_hbm_ref, sem): + pltpu.make_async_copy(x_hbm_ref, o_hbm_ref, sem) + + with self.assertLogs(level='ERROR') as log: + kernel(x) + [message] = log.output + self.assertIn('AsyncCopyDescriptor was not used', message) + class PallasCallDMAInterpretTest(PallasCallDMATest): INTERPRET = True @@ -1707,7 +1892,7 @@ def test_kernel(o_ref, np.testing.assert_array_equal(results, expected) -class PallasCallTest(PallasBaseTest): +class PallasCallTest(ptu.PallasTPUTest): @parameterized.parameters([ dict(shape=shape, dty=dty) @@ -1716,8 +1901,6 @@ class PallasCallTest(PallasBaseTest): ) ]) def test_double_replicated_reduction(self, shape, dty): - if not jtu.if_cloud_tpu_at_least(2025, 2, 19): - self.skipTest("Needs a newer libTPU") def body(o_ref): x = jnp.full(shape, 2.0, dtype=dty) reduction = jnp.sum(x, axis=None) @@ -1745,6 +1928,54 @@ def reduce(): reduce_value = jnp.sum(jnp.full(shape, x), dtype=dty) np.testing.assert_allclose(z, reduce_value) + @jax.jit + def reduce_with_shape_invariant_numerics(): + return self.pallas_call( + body, + out_shape=jax.ShapeDtypeStruct((data_size,), dty), + in_specs=[], + out_specs=pl.BlockSpec((block_size,), lambda i: i), + grid=data_size // block_size, + compiler_params=pltpu.CompilerParams(shape_invariant_numerics=True), + )() + + np.testing.assert_allclose( + jax.block_until_ready(reduce_with_shape_invariant_numerics()), + reduce_value, + ) + + def test_scalar_any_input(self): + if not jtu.is_device_tpu_at_least(4): + self.skipTest("Needs a newer TPU") + def kernel(src, dst, sem): + pltpu.async_copy(src, dst, sem).wait() + + def run(src): + return pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct(src.shape, jnp.float32), + in_specs=[pl.BlockSpec(memory_space=pl.ANY)], + scratch_shapes=[pltpu.SemaphoreType.DMA], + out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), + )(src) + x = jnp.full((1,), 3.1415, dtype=jnp.float32) + np.testing.assert_array_equal(run(x), x) + + def test_sum_in_smem(self): + def kernel(x, out): + a = jnp.array(0, dtype=jnp.int32) + for i in range(4): + for j in range(4): + out[i, j] = a.astype(out.dtype) + a += x[i, j].astype(jnp.int32) + + x = jnp.ones((4, 4), jnp.int16) + spec = pl.BlockSpec(memory_space=pltpu.SMEM) + y = pl.pallas_call(kernel, in_specs=[spec], out_specs=spec, out_shape=x)(x) + np.testing.assert_array_equal( + y, jnp.arange(16, dtype=jnp.int32).reshape(4, 4) + ) + @parameterized.parameters([ dict( m=m, @@ -1764,11 +1995,12 @@ def reduce(): def test_replicated_broadcast_reduction( self, m, replicated, reduced_dims, dty, reduce_func ): - if not jtu.if_cloud_tpu_at_least(2025, 2, 19): - self.skipTest("Needs a newer libTPU") - if dty == jnp.int32 and 1 in reduced_dims: - # TODO(b/395579834): Remove this skip once we implement this. - self.skipTest('int32 reduction on last dimension not supported') + # TODO(b/395579834): Remove this skip later. + if ( + dty == jnp.int32 + and 1 in reduced_dims + ): + self.skipTest('Requires libtpu built after 2025-09-01') if not jtu.is_device_tpu_at_least(4) and len(replicated) == 2: self.skipTest( 'Brodcast in both sublanes and lanes not supported on this hardware' @@ -1801,6 +2033,21 @@ def reduce(x): expected = reduce_func(dilated_x, axis=reduced_dims).reshape(red_shape) np.testing.assert_allclose(y, expected) + @jax.jit + def reduce_with_shape_invariant_numerics(x): + return self.pallas_call( + body, + out_shape=jax.ShapeDtypeStruct(red_shape, dty), + in_specs=[pl.BlockSpec(in_shape)], + out_specs=pl.BlockSpec(red_shape), + grid=1, + compiler_params=pltpu.CompilerParams(shape_invariant_numerics=True), + )(x) + + np.testing.assert_allclose( + jax.block_until_ready(reduce_with_shape_invariant_numerics(x)), expected + ) + def test_cost_analysis(self): def kernel(x, y): y[:] = x[:] @@ -1822,7 +2069,7 @@ def kernel(x, y): y[:] = x[:] batch_size = 3 x = jnp.arange(batch_size * 1024.).reshape(batch_size, 8, 128) - f = pl.pallas_call( + f = self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), cost_estimate=pl.CostEstimate( @@ -1835,6 +2082,57 @@ def kernel(x, y): self.assertEqual(analysis_result['transcendentals'], batch_size * 21) self.assertEqual(analysis_result['bytes accessed'], batch_size * 12345) + def test_cost_analysis_vmap_symbolic_batch_size(self): + # When exporting a module with a symbolic batch size, the cost analysis + # should be stripped from the tpu_custom_call because we can't accurately + # scale it by the dynamic batch size. + + def kernel(x, y): + y[:] = x[:] + + flops = 1234 + transcendentals = 21 + bytes_accessed = 12345 + + f = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + cost_estimate=pl.CostEstimate( + flops=flops, + transcendentals=transcendentals, + bytes_accessed=bytes_accessed, + ), + ) + f = jax.vmap(f) + + batch_size = 3 + x = jnp.arange(batch_size * 1024.0).reshape(batch_size, 8, 128) + exported_module = pl.lower_as_mlir(jax.jit(f), x, dynamic_shapes=True) + + self.assertIn('tpu_custom_call', str(exported_module)) + self.assertIn('cost_estimate', str(exported_module)) + # The exported module string encodes " as \22. + self.assertIn(f'flops\\22:{batch_size * flops}', str(exported_module)) + self.assertIn( + f'transcendentals\\22:{batch_size * transcendentals}', + str(exported_module), + ) + self.assertIn( + f'bytes_accessed\\22:{batch_size * bytes_accessed}', + str(exported_module), + ) + + x_shape = jax.ShapeDtypeStruct( + jax.export.symbolic_shape('b, 8, 128'), jnp.float32 + ) + exported_module = pl.lower_as_mlir(jax.jit(f), x_shape, dynamic_shapes=True) + # Assert that the cost analysis is not present in the serialized module. + self.assertIn('tpu_custom_call', str(exported_module)) + self.assertNotIn('cost_estimate', str(exported_module)) + self.assertNotIn('flops', str(exported_module)) + self.assertNotIn('transcendentals', str(exported_module)) + self.assertNotIn('bytes_accessed', str(exported_module)) + def test_vmem_limit(self): shape = (128, 128) @@ -1842,18 +2140,187 @@ def kernel(x_ref, y_ref): y_ref[...] = x_ref[...] x = jnp.arange(np.prod(shape), dtype=np.float32).reshape(shape) - with self.assertRaises(xla_extension.XlaRuntimeError): + with self.assertRaises(jax.errors.JaxRuntimeError): self.pallas_call( kernel, out_shape=x, - compiler_params=pltpu.TPUCompilerParams(vmem_limit_bytes=256), + compiler_params=pltpu.CompilerParams(vmem_limit_bytes=256), )(x) self.pallas_call( kernel, out_shape=x, - compiler_params=pltpu.TPUCompilerParams(vmem_limit_bytes=int(2**18)), + compiler_params=pltpu.CompilerParams(vmem_limit_bytes=int(2**18)), )(x) + @parameterized.parameters([ + pl.Buffered(1), + pl.Buffered(2), + ]) + def test_vmem_oom_error_message_basics(self, pmode: pl.Buffered): + + if jtu.is_device_tpu(version=5, variant='e') or jtu.is_device_tpu( + version=6, variant='e' + ): + block_shape = (4096 // pmode.buffer_count, 8192) + elif jtu.is_device_tpu(version=5, variant='p'): + block_shape = (1024, 8192) + else: + self.skipTest('Unsupported TPU variant') + grid = (2, 2) + shape = (grid[0] * block_shape[0], grid[1] * block_shape[1]) + + def kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] + + x = jnp.arange(np.prod(shape), dtype=np.float32).reshape(shape) + out_shape = jax.ShapeDtypeStruct(shape, x.dtype) + + def index_map(i, j): + return (i * block_shape[0], j * block_shape[1]) + + spec = pl.BlockSpec( + block_shape=block_shape, index_map=index_map, pipeline_mode=pmode + ) + + with self.assertRaises(jax.errors.JaxRuntimeError) as cm: + self.pallas_call( + kernel, + out_shape=out_shape, + grid=grid, + in_specs=[spec], + out_specs=spec, + )(x) + + error_message = str(cm.exception) + self.assertIn( + 'input window allocation for operator input 0', + error_message, + ) + self.assertIn( + 'output window allocation for operator output 0', + error_message, + ) + self.assertIn( + f'The window shape is f32[{block_shape[0]},{block_shape[1]}], while the' + f' full shape is f32[{shape[0]},{shape[1]}].', + error_message, + ) + if jtu.is_cloud_tpu_at_least(2025, 11, 5): + self.assertIn( + 'This allocation is single buffered.' + if pmode.buffer_count == 1 + else 'This allocation has 2 buffering levels', + error_message, + ) + + def test_vmem_oom_error_message_dynamic_grid_scalar_prefetch_and_vmem_scratch( + self, + ): + if jax.device_count() > 1: + self.skipTest("Test only works with a single device.") + + def body(s_ref, x_hbm_ref, o_hbm_ref, vmem_scratch_ref): + del s_ref, vmem_scratch_ref + o_hbm_ref[...] = x_hbm_ref[...] + + s = jnp.array([5.0], jnp.float32) + if jtu.is_device_tpu(version=5, variant='e') or jtu.is_device_tpu( + version=6, variant='e' + ): + x_shape = (4096, 8192) + elif jtu.is_device_tpu(version=5, variant='p'): + x_shape = (1024, 8192) + else: + x_shape = (512, 8192) + scratch_shape = (x_shape[0] // 4, 8192) + x = jnp.arange(x_shape[0] * x_shape[1], dtype=jnp.float32).reshape(x_shape) + out_shape = jax.ShapeDtypeStruct(x_shape, jnp.float32) + + @jax.jit + def run(num_grid, s, x): + return pl.pallas_call( + body, + out_shape=out_shape, + # use dynamic grid, scalar prefetch, and scratch input. + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + grid=(num_grid,), + in_specs=[pl.BlockSpec()], + out_specs=pl.BlockSpec(), + scratch_shapes=[pltpu.VMEM(scratch_shape, jnp.float32)], + ), + )(s, x) + + with self.assertRaises(jax.errors.JaxRuntimeError) as cm: + run(4, s, x) + + error_message = str(cm.exception) + self.assertIn( + 'input window allocation for operator input 1', + error_message, + ) + self.assertIn( + 'output window allocation for operator output 0', + error_message, + ) + + def test_automatic_single_buffering(self,): + if self.INTERPRET: + self.skipTest('OOM tests need us to compile the kernels') + + def body(*_): + pass # We only want to compile the kernel. + + window_mib = 10 + if jtu.is_device_tpu_at_least(6): + window_mib = 20 + x = jax.ShapeDtypeStruct((100 * 1024 * 1024,), jnp.int8) + x_small = jax.ShapeDtypeStruct((window_mib * 1024 * 1024,), jnp.int8) + # Should recognize that the block specs only load a single window. + self.pallas_call(body, grid=(4,), out_shape=x_small).lower().compile() + # Should recognize that the block specs only load a single window, as it + # only depends on the 1-sized grid dim + self.pallas_call( + body, grid=(4, 1), out_shape=x, + out_specs=pl.BlockSpec((window_mib * 1024 * 1024,), lambda i, j: (j,)) + ).lower().compile() + self.pallas_call( + body, grid=(1, 4), out_shape=x, + out_specs=pl.BlockSpec((window_mib * 1024 * 1024,), lambda i, j: (i,)) + ).lower().compile() + # Should OOM, as now we are extracting different windows + with self.assertRaisesRegex( + jax.errors.JaxRuntimeError, '(Ran out of memory)|(exceed memory)' + ): + self.pallas_call( + body, grid=(4, 1), out_shape=x, + out_specs=pl.BlockSpec((window_mib * 1024 * 1024,), lambda i, j: (j + i,)) + ).lower().compile() + # Explicitly setting single-buffering should fix it, though. + self.pallas_call( + body, grid=(4, 1), out_shape=x, + out_specs=pl.BlockSpec((window_mib * 1024 * 1024,),lambda i, j: (j + i,), + pipeline_mode=pl.Buffered(1)) + ).lower().compile() + # Add unused scalar prefetch args to make sure we don't incorrectly consider + # them to be unused grid indices. + scalar = jnp.array([0], jnp.int32) + with self.assertRaisesRegex( + jax.errors.JaxRuntimeError, '(Ran out of memory)|(exceed memory)' + ): + self.pallas_call( + body, + out_shape=x, + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=2, + grid=(4, 1), + out_specs=pl.BlockSpec( + (window_mib * 1024 * 1024,), + lambda i, j, *_: (j + i,), + ), + ), + ).lower(scalar, scalar).compile() + def test_allow_input_fusion(self): shape = (3, 128, 128) @@ -1868,7 +2335,7 @@ def f(x, y): in_specs=[pl.BlockSpec((1, 128, 128), lambda i: (i, 0, 0))], out_specs=pl.BlockSpec((1, 128, 128), lambda i: (i, 0, 0)), out_shape=x, - compiler_params=pltpu.TPUCompilerParams(allow_input_fusion=[True]), + compiler_params=pltpu.CompilerParams(allow_input_fusion=[True]), )(z) x = jnp.arange(np.prod(shape), dtype=np.float32).reshape(shape) @@ -1896,16 +2363,19 @@ def kernel(x_ref, y_ref): self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( internal_scratch_in_bytes=requested_bytes, ), )(x) - @parameterized.product(dtype=[jnp.bfloat16, jnp.float32]) - def test_pltpu_repeat(self, dtype): + @parameterized.product( + dtype=[jnp.bfloat16, jnp.float32], + axis=[1, -1], + ) + def test_pltpu_repeat(self, dtype, axis): def test_kernel(x_ref, o_ref): x = x_ref[...] - o_ref[...] = pltpu.repeat(x, 2, axis=1) + o_ref[...] = pltpu.repeat(x, 2, axis=axis) @jax.jit def test(x: jax.Array) -> jax.Array: @@ -1919,11 +2389,10 @@ def test(x: jax.Array) -> jax.Array: np.testing.assert_array_equal(y, jnp.concatenate([x, x], axis=1)) def test_mixed_precision_dot(self): - if not jtu.if_cloud_tpu_at_least(2025, 2, 27): - self.skipTest("Needs a newer libTPU") - if not jtu.is_device_tpu_at_least(5): self.skipTest('float8_e4m3b11fnuz not supported on TPU generations <= 4') + if jtu.is_device_tpu(7, 'x'): + self.skipTest('float8_e4m3b11fnuz not supported on TPU v7x') def kernel(x_ref, w_ref, o_ref): o_ref[:] = jax.lax.dot_general( @@ -1950,6 +2419,70 @@ def kernel(x_ref, w_ref, o_ref): mosaic_nans = jnp.isnan(run(x, w)).sum() self.assertEqual(jax_nans, mosaic_nans) + @parameterized.product( + in_dtype=[ + jnp.int8, + jnp.int16, + jnp.int32, + jnp.float8_e5m2, + jnp.float8_e4m3fn, + jnp.float8_e4m3b11fnuz, + jnp.bfloat16, + jnp.float32, + ], + out_dtype=[ + jnp.int8, + jnp.int16, + jnp.int32, + jnp.float32, + ], + ) + def test_scalar_casting(self, in_dtype, out_dtype): + def kernel(x_ref, o_ref): + o_ref[0] = x_ref[0].astype(out_dtype) + + x = jnp.asarray([7], dtype=in_dtype) + if jnp.issubdtype(in_dtype, jnp.signedinteger): + x *= -1 + + y = pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], + out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), + out_shape=jax.ShapeDtypeStruct(x.shape, out_dtype), + )(x) + self.assertEqual(y, x.astype(out_dtype)) + + @parameterized.product(in_dtype=[jnp.int4, jnp.int8, jnp.int16, jnp.int32]) + def test_scalar_load_upcast(self, in_dtype): + if in_dtype == jnp.int4 and not jtu.is_device_tpu_at_least(4): + self.skipTest("Triggers an XLA bug") # TODO(b/413602952) + def kernel(x_ref, o_ref): + o_ref[0, 0] = x_ref[0, 0].astype(o_ref.dtype) + x = jnp.asarray([[-1]], dtype=in_dtype) + y = pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], + out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), + out_shape=jax.ShapeDtypeStruct((1, 1), jnp.int32), + )(x) + self.assertEqual(y, x.astype(jnp.int32)) + + @parameterized.product(in_dtype=[jnp.int4, jnp.int8, jnp.int16, jnp.int32]) + def test_scalar_indirect_load(self, in_dtype): + def kernel(x_ref, o_ref): + o_ref[0, 0] = x_ref[0, x_ref[0, 0].astype(jnp.int32)].astype(o_ref.dtype) + if in_dtype == jnp.int4 and not jtu.is_device_tpu_at_least(4): + self.skipTest("Triggers an XLA bug") # TODO(b/413602952) + x = jnp.asarray([[3, 0, 0, 1]], dtype=in_dtype) + y = pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], + out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), + out_shape=jax.ShapeDtypeStruct((1, 1), jnp.int32), + )(x) + self.assertEqual(y, x[0, x[0, 0]].astype(jnp.int32)[None, None]) + def test_masked_store(self): shape = (16, 256) mask_shape = (10, 130) @@ -1963,12 +2496,7 @@ def body(scalar_ref, x_ref, o_ref): iota1 = lax.broadcasted_iota(jnp.int32, shape, 1) mask0 = jnp.logical_and(b0 <= iota0, iota0 < e0) mask1 = jnp.logical_and(b1 <= iota1, iota1 < e1) - pl.store( - o_ref, - (slice(None), slice(None)), - x_ref[...], - mask=jnp.logical_and(mask0, mask1), - ) + pltpu.store(o_ref, x_ref[...], mask=jnp.logical_and(mask0, mask1)) s = jnp.array(mask_start, jnp.int32) x = jnp.arange(np.prod(shape), dtype=dtype).reshape(shape) @@ -1984,8 +2512,298 @@ def body(scalar_ref, x_ref, o_ref): expected = expected.at[slices].set(x[slices]) np.testing.assert_array_equal(out, expected) + def test_custom_vjp(self): -class PallasUXTest(PallasBaseTest): + @jax.custom_vjp + def f(x): + return jnp.tanh(x) + def f_fwd(x): + return jnp.tanh(x) * 2, () + def f_bwd(_, g): + return (g * 2,) + + f.defvjp(f_fwd, f_bwd) + + def kernel(x_ref, dy_ref, y_ref, y_p_ref, dx_ref): + x = x_ref[...] + y_ref[...] = f(x) + y_p, f_vjp = jax.vjp(f, x) + y_p_ref[...] = y_p + dx_ref[...] = f_vjp(dy_ref[...])[0] + + x = jax.random.normal(jax.random.key(0), (8, 128), dtype=jnp.float32) + dy = jax.random.normal(jax.random.key(1), (8, 128), dtype=jnp.float32) + y, y_p, dx = pl.pallas_call( + kernel, + out_shape=( + jax.ShapeDtypeStruct((8, 128), jnp.float32), + jax.ShapeDtypeStruct((8, 128), jnp.float32), + jax.ShapeDtypeStruct((8, 128), jnp.float32), + ), + )(x, dy) + np.testing.assert_array_equal(y, f(x)) + np.testing.assert_array_equal(y_p, f(x) * 2) + np.testing.assert_array_equal(dx, dy * 2) + + @parameterized.parameters([ + jnp.int4, + jnp.int8, + jnp.int16, + jnp.int32, + jnp.uint4, + jnp.uint8, + jnp.uint16, + jnp.uint32, + ]) + def test_scalar_integer_addition(self, dtype): + def kernel(x_ref, y_ref): + y_ref[0] = x_ref[0] + x_ref[0] + + x = jnp.asarray([3], dtype=dtype) + + if dtype in [jnp.int32, jnp.uint32]: + y = pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], + out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), + out_shape=jax.ShapeDtypeStruct(x.shape, dtype), + )(x) + np.testing.assert_array_equal(y, x + x) + else: + with self.assertRaisesRegex( + error_handling.MosaicError, + 'Not implemented: Only i32 addition is supported.', + ): + _ = pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], + out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), + out_shape=jax.ShapeDtypeStruct(x.shape, dtype), + )(x) + + @parameterized.parameters([ + jnp.int4, + jnp.int8, + jnp.int16, + jnp.int32, + jnp.uint4, + jnp.uint8, + jnp.uint16, + jnp.uint32, + ]) + def test_vector_integer_addition(self, dtype): + def kernel(x_ref, y_ref): + y_ref[...] = x_ref[...] + x_ref[...] + + x = jnp.full((128, 16), 7, dtype=dtype) + + if dtype in [jnp.int32, jnp.uint32, jnp.int16, jnp.uint16]: + y = pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=pltpu.VMEM)], + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), + out_shape=jax.ShapeDtypeStruct(x.shape, dtype), + )(x) + np.testing.assert_array_equal(y, x + x) + else: + with self.assertRaisesRegex( + error_handling.MosaicError, + 'Not implemented: Only vector and vector' + ): + _ = pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=pltpu.VMEM)], + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), + out_shape=jax.ShapeDtypeStruct(x.shape, dtype), + )(x) + + @parameterized.parameters([ + jnp.int32, + jnp.uint32, + jnp.float32, + ]) + def test_max_operation(self, dtype): + def kernel(x_ref, y_ref): + y_ref[0] = jnp.maximum(x_ref[0], x_ref[1]) + + x = jnp.asarray([242, 87], dtype=dtype) + + y = pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], + out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), + out_shape=jax.ShapeDtypeStruct(x.shape, dtype), + )(x) + np.testing.assert_array_equal(y[0], jnp.maximum(x[0], x[1])) + + @parameterized.parameters([ + jnp.int32, + jnp.uint32, + jnp.float32, + ]) + def test_min_operation(self, dtype): + def kernel(x_ref, y_ref): + y_ref[0] = jnp.minimum(x_ref[0], x_ref[1]) + + x = jnp.asarray([242, 87], dtype=dtype) + + y = pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], + out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), + out_shape=jax.ShapeDtypeStruct(x.shape, dtype), + )(x) + np.testing.assert_array_equal(y[0], jnp.minimum(x[0], x[1])) + + @parameterized.parameters([ + jnp.int32, + jnp.uint32, + jnp.int16, + jnp.uint16, + jnp.int8, + jnp.uint8, + jnp.int4, + jnp.uint4, + jnp.float32, + jnp.bfloat16, + ]) + def test_bool_select_operation(self, dtype): + def kernel(condlist, choicelist, out_ref): + out_ref[...] = jnp.where(condlist[...], choicelist[...], 0) + + if dtype in [jnp.int4, jnp.uint4] and not jtu.is_device_tpu_at_least(4): + self.skipTest('i4 is not supported on TPU generations <= 3') + + shape = (8, 128) + condlist = jax.random.bernoulli(jax.random.key(0), 0.5, shape) + choicelist = jnp.arange(shape[0]*shape[1], dtype=dtype).reshape(shape) + + z = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((shape[0],shape[1]), dtype=dtype), + )(condlist, choicelist) + np.testing.assert_array_equal(z, jnp.where(condlist, choicelist, 0)) + + +class PallasScalarIOpsTest(ptu.PallasTPUTest): + + @staticmethod + def parameterized_integer_types(func): + _DEFAULT_INT_TYPES = [ + jnp.int4, + jnp.int8, + jnp.int16, + jnp.int32, + jnp.uint4, + jnp.uint8, + jnp.uint16, + jnp.uint32, + ] + + @parameterized.parameters(_DEFAULT_INT_TYPES) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper + + def _integer_ops_canonicalization_helper(self, kernel, result, dtype): + """For integer scalar ops, only i1 and i32 are supported.""" + x = jnp.arange(3, dtype=dtype) + + if dtype in [jnp.int32, jnp.uint32]: + y = pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], + out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), + out_shape=jax.ShapeDtypeStruct((1,), dtype), + )(x) + np.testing.assert_array_equal(y, jnp.asarray([result], dtype=dtype)) + else: + with self.assertRaisesRegex( + error_handling.MosaicError, + 'Not implemented: Only i1 and i32 scalars are supported.', + ): + _ = pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], + out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), + out_shape=jax.ShapeDtypeStruct(x.shape, dtype), + )(x) + + @parameterized_integer_types + def test_andi_op_canonicalization(self, dtype): + def kernel(x_ref, y_ref): + y_ref[0] = x_ref[1] & x_ref[2] + + self._integer_ops_canonicalization_helper(kernel, 1 & 2, dtype) + + @parameterized_integer_types + def test_divi_op_canonicalization(self, dtype): + # both divsi and divui + def kernel(x_ref, y_ref): + y_ref[0] = x_ref[1] // x_ref[2] + + self._integer_ops_canonicalization_helper(kernel, 0, dtype) + + @parameterized_integer_types + def test_max_op_canonicalization(self, dtype): + def kernel(x_ref, y_ref): + y_ref[0] = jnp.maximum(x_ref[1], x_ref[2]) + + self._integer_ops_canonicalization_helper(kernel, max(1, 2), dtype) + + @parameterized_integer_types + def test_min_op_canonicalization(self, dtype): + def kernel(x_ref, y_ref): + y_ref[0] = jnp.minimum(x_ref[1], x_ref[2]) + + self._integer_ops_canonicalization_helper(kernel, min(1, 2), dtype) + + @parameterized_integer_types + def test_muli_op_canonicalization(self, dtype): + def kernel(x_ref, y_ref): + y_ref[0] = x_ref[1] * x_ref[2] + + self._integer_ops_canonicalization_helper(kernel, 1 * 2, dtype) + + @parameterized_integer_types + def test_ori_op_canonicalization(self, dtype): + def kernel(x_ref, y_ref): + y_ref[0] = x_ref[1] | x_ref[2] + + self._integer_ops_canonicalization_helper(kernel, 1 | 2, dtype) + + @parameterized_integer_types + def test_shli_op_canonicalization(self, dtype): + def kernel(x_ref, y_ref): + y_ref[0] = x_ref[1] << x_ref[2] + + self._integer_ops_canonicalization_helper(kernel, 1 << 2, dtype) + + @parameterized_integer_types + def test_shri_op_canonicalization(self, dtype): + # Includes both shrsi and shrui + def kernel(x_ref, y_ref): + y_ref[0] = x_ref[1] >> x_ref[2] + + self._integer_ops_canonicalization_helper(kernel, 1 >> 2, dtype) + + @parameterized_integer_types + def test_subi_op_canonicalization(self, dtype): + def kernel(x_ref, y_ref): + y_ref[0] = x_ref[2] - x_ref[1] + + self._integer_ops_canonicalization_helper(kernel, 2 - 1, dtype) + + @parameterized_integer_types + def test_xori_op_canonicalization(self, dtype): + def kernel(x_ref, y_ref): + y_ref[0] = x_ref[1] ^ x_ref[2] + + self._integer_ops_canonicalization_helper(kernel, 1 ^ 2, dtype) + + +class PallasUXTest(ptu.PallasTPUTest): def test_mlir_location(self): # Make sure that MLIR locations are correctly propagated to primitives. @@ -2003,7 +2821,7 @@ def capture_as_tpu_kernel(module, *args, **kwargs): mosaic.as_tpu_kernel = as_tpu_kernel -class PallasMegacoreTest(PallasBaseTest): +class PallasMegacoreTest(ptu.PallasTPUTest): def test_megacore_splitting(self): # We want to make sure a 3-sized dimension is split across megacore @@ -2031,7 +2849,6 @@ def _(): pl.BlockSpec((128, 128), lambda i, j, k: (k, j)), ], out_specs=pl.BlockSpec((128, 128), lambda i, j, k: (i, j)), - debug=True, ) ) )(x, y) @@ -2040,7 +2857,7 @@ def _(): ) -class PallasCallVmapTest(PallasBaseTest): +class PallasCallVmapTest(ptu.PallasTPUTest): def test_scratch_input_vmap(self): """Test that vmapp-ing a kernel with scratch inputs works correctly.""" @@ -2077,7 +2894,7 @@ def add_one_with_scratch(x_ref, o_ref, scratch_ref): np.testing.assert_array_equal(out, out_ref, strict=True) -class PallasCallDynamicDMATest(PallasBaseTest): +class PallasCallDynamicDMATest(ptu.PallasTPUTest): def setUp(self): super().setUp() @@ -2101,9 +2918,9 @@ def kernel(size_smem_ref, x_hbm_ref, _, o_hbm_ref, sem): grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM), - pl.BlockSpec(memory_space=pltpu.ANY), - pl.BlockSpec(memory_space=pltpu.ANY)], - out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pl.ANY)], + out_specs=pl.BlockSpec(memory_space=pl.ANY), scratch_shapes=[pltpu.SemaphoreType.DMA] ), out_shape=o, @@ -2129,9 +2946,9 @@ def kernel(size_smem_ref, x_hbm_ref, _, o_hbm_ref, sem): grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM), - pl.BlockSpec(memory_space=pltpu.ANY), - pl.BlockSpec(memory_space=pltpu.ANY)], - out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pl.ANY)], + out_specs=pl.BlockSpec(memory_space=pl.ANY), scratch_shapes=[pltpu.SemaphoreType.DMA] ), out_shape=o, @@ -2141,7 +2958,7 @@ def kernel(size_smem_ref, x_hbm_ref, _, o_hbm_ref, sem): np.testing.assert_array_equal(out, expected) -class PallasCallRefTransformTest(PallasBaseTest): +class PallasCallRefTransformTest(ptu.PallasTPUTest): @parameterized.product(slice_first=[True, False]) def test_dma_bitcasted_ref(self, slice_first): @@ -2175,15 +2992,22 @@ def body(sem): ) np.testing.assert_array_equal(y, expected) - @parameterized.product(slice_first=[True, False]) - def test_load_bitcasted_ref(self, slice_first: bool): + @parameterized.product( + slice_first=[True, False], use_primitive_io_op=[False, True] + ) + def test_load_bitcasted_ref( + self, slice_first: bool, use_primitive_io_op: bool + ): def kernel(x_ref, y_ref): ref = ( x_ref.at[:8, :128].bitcast(jnp.int16) if slice_first else x_ref.bitcast(jnp.int16).at[:16, :128] ) - y_ref[...] = ref[...] + if use_primitive_io_op: + pltpu.store(y_ref, pltpu.load(ref)) + else: + y_ref[...] = ref[...] x = jnp.arange(4 * 8 * 128, dtype=jnp.int32).reshape((16, 256)) y = self.pallas_call( @@ -2197,15 +3021,22 @@ def kernel(x_ref, y_ref): ) np.testing.assert_array_equal(y, expected) - @parameterized.product(slice_first=[True, False]) - def test_store_bitcasted_ref(self, slice_first): + @parameterized.product( + slice_first=[True, False], use_primitive_io_op=[False, True] + ) + def test_store_bitcasted_ref( + self, slice_first: bool, use_primitive_io_op: bool + ): def kernel(x_ref, y_ref): ref = ( y_ref.at[:8, :128].bitcast(jnp.bfloat16) if slice_first else y_ref.bitcast(jnp.bfloat16).at[:16, :128] ) - ref[...] = x_ref[...] + if use_primitive_io_op: + pltpu.store(ref, pltpu.load(x_ref)) + else: + ref[...] = x_ref[...] x = jnp.arange(16 * 128, dtype=jnp.bfloat16).reshape((16, 128)) y = self.pallas_call( @@ -2300,89 +3131,9 @@ def kernel(x_ref, y_ref): np.testing.assert_array_equal(y, x[8:16, :128]) -class PallasCallPrintTest(PallasBaseTest): - - def test_debug_print(self): - @functools.partial( - self.pallas_call, - out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), - ) - def kernel(x_ref, o_ref): - pl.debug_print('It works!') - - x = jnp.arange(8 * 128, dtype=jnp.float32).reshape((8, 128)) - compiled_kernel = ( - jax.jit(kernel) - .lower(x) - .compile({'xla_tpu_enable_log_recorder': 'true'}) - ) - with jtu.capture_stderr() as get_output: - jax.block_until_ready(compiled_kernel(x)) - self.assertIn('It works!', get_output()) - - def test_debug_print_with_values(self): - @functools.partial( - self.pallas_call, - in_specs=(pl.BlockSpec(memory_space=pltpu.SMEM),), - out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), - ) - def kernel(x_ref, o_ref): - pl.debug_print('x[0] == {}', x_ref[0]) - - x = jnp.array([42, 24]).astype(jnp.int32) - compiled_kernel = ( - jax.jit(kernel) - .lower(x) - .compile({'xla_tpu_enable_log_recorder': 'true'}) - ) - with jtu.capture_stderr() as get_output: - jax.block_until_ready(compiled_kernel(x)) - self.assertIn('x[0] == 42', get_output()) - - @parameterized.named_parameters( - (f"{'_'.join(map(str, shape))}_{dtype.__name__}", shape, dtype) - for shape in ( - (2, 8, 128), - # test unaligned shapes - (3,), - (3, 4), - (2, 3, 4), - (2, 9, 129), - ) - for dtype in (jnp.int32, jnp.uint32, jnp.float32) - ) - def test_debug_print_vector(self, shape, dtype): - @functools.partial( - self.pallas_call, - out_shape=jax.ShapeDtypeStruct(shape, dtype), - ) - def kernel(x_ref, o_ref): - pl.debug_print("{}", x_ref[...]) - o_ref[...] = x_ref[...] - - n = np.prod(shape) - x = jnp.arange(n, dtype=dtype).reshape(shape) - compiled_kernel = ( - jax.jit(kernel) - .lower(x) - .compile({"xla_tpu_enable_log_recorder": "true"}) - ) - with jtu.capture_stderr() as get_output: - jax.block_until_ready(compiled_kernel(x)) - output = get_output() - numbers = [ - int(num) - for line in output.splitlines() - if (match := re.search(r"\{(.*)", line)) # extract contents after `{` - for num in re.findall(r"\d+", match.group(1)) - ] - # Check if the numbers in the output match the values generated by `arange`. - self.assertLen(numbers, n) - self.assertTrue(all(num == i for i, num in enumerate(numbers))) - - -class PallasCallTraceTest(PallasBaseTest): +class PallasCallTraceTest(ptu.PallasTPUTest): + @jtu.thread_unsafe_test() # stdout redirection is not thread safe def test_trace_start_stop_match(self): def kernel(o_ref): with jax.named_scope('scope1'): @@ -2402,6 +3153,7 @@ def kernel(o_ref): self.assertEqual(num_start, 1) self.assertEqual(num_stop, 1) + @jtu.thread_unsafe_test() # stdout redirection is not thread safe def test_run_scoped(self): def kernel(o_ref): def scope1(): @@ -2429,7 +3181,7 @@ def scope2(): self.assertEqual(num_stop, 2) -class PallasCallTPUBooleanTest(PallasBaseTest): +class PallasCallTPUBooleanTest(ptu.PallasTPUTest): """Tests for loading/storing from bool memrefs on TPUs. We specifically test bools because they have special handling. @@ -2527,8 +3279,8 @@ def kernel(x_ref, o_ref, send_sem, recv_sem): output_shape = jax.ShapeDtypeStruct((8, 128), jnp.bool_) grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM)], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + in_specs=[pl.BlockSpec(memory_space=pltpu.VMEM)], + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), grid=(1,), scratch_shapes=[pltpu.SemaphoreType.DMA] * 2, ) @@ -2549,7 +3301,7 @@ def kernel(x_ref, o_ref, send_sem, recv_sem): mesh=mesh, in_specs=P(None, 'x'), out_specs=P(None, 'x'), - check_rep=False + check_vma=False ) )(input_arr) @@ -2558,7 +3310,7 @@ class PallasCallTPUBooleanInterpretTest(PallasCallTPUBooleanTest): INTERPRET: bool = True -class PallasCallTPUCheckifyTest(PallasBaseTest): +class PallasCallTPUCheckifyTest(ptu.PallasTPUTest): @parameterized.parameters((2,), (5,), (6,), (7,)) def test_checkify_with_scalar_prefetch(self, threshold): def body(scalar_ref, x_ref, o_ref): @@ -2570,8 +3322,7 @@ def body(scalar_ref, x_ref, o_ref): x = jnp.arange(8 * 8 * 128, dtype=jnp.int32).reshape((8 * 8, 128)) def _x_transform(i, s_ref): - s = pl.load(s_ref, (i,)) - return (s, 0) + return (s_ref[i], 0) pallas_call = self.pallas_call( body, @@ -2663,60 +3414,85 @@ class PallasCallTPUCheckifyInterpretTest(PallasCallTPUCheckifyTest): INTERPRET: bool = True -class PrettyPrintingTest(PallasBaseTest): +class PrettyPrintingTest(ptu.PallasTPUTest): @parameterized.parameters( ( - lambda i: (i, pl.ds(0, 8), pl.ds(0, 128)), - 'dma_start c[d,:,:] -> e[...] f', + lambda i: (i, pl.ds(0, 8), pl.ds(0, 128)), 0, False, + 'dma_start(p0) c[d,:,:] -> e[...] f', ), ( - lambda i: (0, pl.ds(i, 8), pl.ds(0, 128)), - 'dma_start c[0,d:d+8,:] -> e[...] f', + lambda i: (0, pl.ds(i, 8), pl.ds(0, 128)), 0, False, + 'dma_start(p0) c[0,d:d+8,:] -> e[...] f', ), ( - lambda i: (i, pl.ds(2, 4), pl.ds(0, 100)), - 'dma_start c[d,2:6,:100] -> e[...] f', + lambda i: (i, pl.ds(2, 4), pl.ds(0, 100)), 0, False, + 'dma_start(p0) c[d,2:6,:100] -> e[...] f', ), ( - lambda i: (i, pl.ds(2, 6), pl.ds(4, 100)), - 'dma_start c[d,2:,4:104] -> e[...] f', + lambda i: (i, pl.ds(2, 6), pl.ds(4, 100)), 1, False, + 'dma_start(p1) c[d,2:,4:104] -> e[...] f', + ), + ( + lambda i: (i, pl.ds(2, 6), pl.ds(4, 100)), 0, True, + 'dma_start(p0, add) c[d,2:,4:104] -> e[...] f', ), ) - def test_dma_custom_pretty_print(self, indexer, expected): + def test_dma_custom_pretty_print(self, indexer, priority, add, expected): def body(x_hbm_ref, i): def inner(x_ref, sem): - pltpu.async_copy(x_hbm_ref.at[indexer(i)], x_ref, sem).wait() + pltpu.async_copy(x_hbm_ref.at[indexer(i)], x_ref, sem, + priority=priority, + add=add).wait() pl.run_scoped( inner, pltpu.VMEM((8, 128), jnp.float32), pltpu.SemaphoreType.DMA ) return [] - jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( - wrap_init(body, 2), [state.shaped_array_ref((2, 8, 128), jnp.int32), - jax.core.ShapedArray((), jnp.int32)] + jaxpr, _, _ = pe.trace_to_jaxpr_dynamic( + wrap_init(body, 2), + [ + state.shaped_array_ref((2, 8, 128), jnp.int32), + jax.core.ShapedArray((), jnp.int32), + ], ) self.assertIn(expected, jaxpr.pretty_print(use_color=False)) -def only_passes_in_interpret(unless_generation: int | None = None): - def decorator(f): - def wrapper(self): - if self.INTERPRET or ( - unless_generation is not None - and jtu.is_device_tpu_at_least(unless_generation) - ): - f(self) - else: - with self.assertRaises(Exception): - f(self) - return wrapper - return decorator +class MiscellaneousTest(ptu.PallasTPUTest): + """Tests for reported bugs. Only pass in interpret mode unless fixed.""" + def test_casting_bool_to_i8(self): + if not jtu.is_device_tpu_at_least(5): + self.skipTest("Operation not supported on this TPU version.") -class MiscellaneousTest(PallasBaseTest): - """Tests for reported bugs. Only pass in interpret mode unless fixed.""" + def greater_than(x: jax.Array, y: jax.Array): + def kernel(x_ref, y_ref, out_ref): + cmp = (x_ref[...] > y_ref[...]).astype(jnp.int8) + out_ref[:] = cmp + + in_specs = [ + pl.BlockSpec(memory_space=pltpu.VMEM), + pl.BlockSpec(memory_space=pltpu.VMEM), + ] + out_specs = pl.BlockSpec(memory_space=pltpu.VMEM) + + return self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct(x.shape, jnp.int8), + in_specs=in_specs, + out_specs=out_specs, + )(x, y) + + key = jax.random.key(0) + x_key, y_key = jax.random.split(key) + x = jax.random.normal(x_key, (128, 16), dtype=jnp.float32) + y = jax.random.normal(y_key, (128, 16), dtype=jnp.float32) + out = jax.jit(greater_than)(x, y) + + expected = (x > y).astype(jnp.int8) + np.testing.assert_array_equal(out, expected) def test_float32_stack(self): x = np.arange(128, dtype=jnp.float32).reshape(1, 128) @@ -2730,9 +3506,9 @@ def kernel(x_ref, y_ref, out_ref): )(x, y) np.testing.assert_array_equal(out, np.stack([x, y], axis=1)) - @only_passes_in_interpret() def test_lane_to_chunk_reshape_bf16(self): - """b/348038320""" + if not jtu.is_device_tpu_at_least(4): + self.skipTest('Operation not supported on this TPU version.') x = np.arange(256 * 1024, dtype=jnp.bfloat16).reshape(1, 256, 1024) def kernel(x_ref, out_ref): @@ -2793,22 +3569,45 @@ def kernel(x_ref, out_ref): )(x) np.testing.assert_array_equal(out, state_utils.bitcast(x, jnp.uint32)) - @only_passes_in_interpret() - def test_roll_partial(self): - """b/337384645""" - x = np.arange(8192, dtype=jnp.float32).reshape(128, 64) + @parameterized.product( + shape=((128, 64), (15, 256), (16, 256)), + shift=(2, 3), + axis=(0, 1), + ) + def test_roll_partial_with_static_shift( + self, shape: tuple[int, int], shift: int, axis: int + ): + x = np.arange(math.prod(shape), dtype=jnp.float32).reshape(shape) def kernel(x_ref, out_ref): - out_ref[...] = pltpu.roll(x_ref[...], 3, 1) + out_ref[...] = pltpu.roll(x_ref[...], shift=shift, axis=axis) out = self.pallas_call( - kernel, out_shape=jax.ShapeDtypeStruct((128, 64), jnp.float32) + kernel, out_shape=jax.ShapeDtypeStruct(shape, jnp.float32) )(x) - np.testing.assert_array_equal(out, np.roll(x, 3, 1)) + np.testing.assert_array_equal(out, np.roll(x, shift, axis)) + + @parameterized.product( + shape_and_axis=(((128, 64), 1), ((63, 256), 0)), + ) + def test_roll_partial_with_dynamic_shift( + self, shape_and_axis: tuple[tuple[int, int], int] + ): + if self.INTERPRET: + self.skipTest('Test only applies to non-interpret mode.') + shape, axis = shape_and_axis + x = np.arange(math.prod(shape), dtype=jnp.float32).reshape(shape) + + def kernel(x_ref, out_ref): + amount = x_ref[0, 0].astype(jnp.int32) + out_ref[...] = pltpu.roll(x_ref[...], amount, axis=axis) + + with self.assertRaisesRegex(Exception, 'unsupported unaligned shape'): + _ = self.pallas_call( + kernel, out_shape=jax.ShapeDtypeStruct(shape, jnp.float32) + )(x) - @only_passes_in_interpret() def test_retiling1(self): - """b/352626602""" x = np.arange(1024, dtype=jnp.bfloat16).reshape(1024) def kernel(x_ref, out_ref): @@ -2848,9 +3647,9 @@ def kernel(x_ref, out_ref): np.testing.assert_array_equal(out, np.reshape(x, (8, 1, 128))) - @only_passes_in_interpret() def test_sublane_adding_shape_cast_bf16(self): - """b/352833257""" + if not jtu.is_device_tpu_at_least(4): + self.skipTest('Operation not supported on this TPU version.') x = np.arange(8 * 128, dtype=jnp.bfloat16).reshape(8, 128) def kernel(x_ref, out_ref): @@ -2863,8 +3662,8 @@ def kernel(x_ref, out_ref): np.testing.assert_array_equal(out, np.reshape(x, (8, 1, 128))) def test_mixed_strides(self): - x = np.zeros((8, 128), dtype=jnp.float32) - y = np.zeros((8, 2, 128), dtype=jnp.bfloat16) + x = np.full((8, 128), 1.0, dtype=jnp.float32) + y = np.full((8, 2, 128), 2.0, dtype=jnp.bfloat16) def kernel(x_ref, y_ref, out_ref): out_ref[:, :] = x_ref[:, :] + y_ref[:, 1, :].astype(jnp.float32) @@ -2874,7 +3673,9 @@ def kernel(x_ref, y_ref, out_ref): out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), )(x, y) - np.testing.assert_array_equal(out, np.zeros((8, 128), dtype=jnp.float32)) + np.testing.assert_array_equal( + out, np.full((8, 128), 3.0, dtype=jnp.float32) + ) def test_sum(self): x = np.zeros((8, 2, 8, 128), dtype=jnp.float32) @@ -2888,9 +3689,7 @@ def kernel(x_ref, out_ref): np.testing.assert_array_equal(out, np.zeros((8, 2, 128), dtype=jnp.float32)) - @only_passes_in_interpret() def test_transpose(self): - """b/356475128""" x = np.zeros((8, 2, 8, 128), dtype=jnp.float32) def kernel(x_ref, out_ref): @@ -2904,10 +3703,524 @@ def kernel(x_ref, out_ref): out, np.zeros((8, 8, 2, 128), dtype=jnp.float32) ) + @parameterized.parameters( + (3, 1, 2048, jnp.bfloat16), + (5, 1, 4096, jnp.int8), + ) + def test_1d_tiling_major_minor_transpose(self, q, m, n, dtype): + in_shape = (q, n) + mid_shape = (q, m, n) + out_shape = (m, q, n) + x = np.arange(np.prod(in_shape), dtype=dtype).reshape(in_shape) + + def kernel(x_ref, o_ref): + x = x_ref[...] + x = jnp.reshape(x, mid_shape) + o_ref[...] = jnp.transpose(x, axes=(1, 0, 2)) + + result = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct(out_shape, dtype), + )(x) + np.testing.assert_array_equal( + result, np.transpose(x.reshape(mid_shape), axes=(1, 0, 2)) + ) + + # (q, m, n) -> (q, m * n) where n % 128 == 0 + @parameterized.parameters( + (q, m, n, dtype) + for (q, m, n), dtype in itertools.product( + [ + (32, 16, 512), + (20, 19, 512), + (5, 3, 256), + (9, 15, 256), + (3, 2, 256), + ], + [jnp.float32, jnp.uint32, jnp.bfloat16, jnp.int8], + ) + ) + def test_reshape_two_minor_dims_to_R2(self, q, m, n, dtype): + if (dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least(4)) or ( + dtype == jnp.int8 and not jtu.is_device_tpu_at_least(5) + ): + self.skipTest('Operation not supported on this TPU version.') + def kernel(x_ref, y_ref): + y_ref[...] = x_ref[...].reshape( + x_ref.shape[0], x_ref.shape[1] * x_ref.shape[2] + ) + + x = np.arange(q * m * n, dtype=dtype).reshape(q, m, n) + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((q, m * n), dtype), + )(x) + jax.numpy.set_printoptions(threshold=jax.numpy.inf) + expected = x.reshape([q, m * n]) + np.testing.assert_array_equal(out, x.reshape([q, m * n])) + + # (q, m, n, k) -> (q, m, n * k) where k % 128 == 0 + @parameterized.parameters( + (q, m, n, k, dtype) + for (q, m, n, k), dtype in itertools.product( + [ + (3, 8, 17, 512), + (1, 8, 9, 256), + (1, 8, 3, 256), + (10, 1, 4, 256), + (1, 2, 2, 256), + (1, 9, 3, 256), + ], + [jnp.float32, jnp.uint32, jnp.bfloat16, jnp.int8], + ) + ) + def test_reshape_two_minor_dims_to_R3(self, q, m, n, k, dtype): + if (dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least(4)) or ( + dtype == jnp.int8 and not jtu.is_device_tpu_at_least(5) + ): + self.skipTest('Operation not supported on this TPU version.') + def kernel(x_ref, y_ref): + y_ref[...] = x_ref[...].reshape( + x_ref.shape[0], x_ref.shape[1], x_ref.shape[2] * x_ref.shape[3] + ) + + x = np.arange(q * m * n * k, dtype=dtype).reshape(q, m, n, k) + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((q, m, n * k), dtype), + )(x) + np.testing.assert_array_equal(out, x.reshape([q, m, n * k])) + + # (q, m, n) -> (q, m * n) where n % 128 != 0 + @parameterized.parameters( + (q, m, n, dtype) + for (q, m, n), dtype in itertools.product( + [ + (32, 16, 500), + (20, 19, 500), + (5, 3, 200), + (9, 15, 200), + (3, 2, 200), + (5, 1, 300), + ], + [jnp.float32, jnp.uint32, jnp.bfloat16, jnp.int8], + ) + ) + def test_reshape_two_minor_dims_to_R2_padded_last_dim(self, q, m, n, dtype): + if not jtu.is_cloud_tpu_at_least(2025, 12, 22): + self.skipTest('Needs a newer libTPU') + if (dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least(4)) or ( + dtype == jnp.int8 and not jtu.is_device_tpu_at_least(5) + ): + self.skipTest('Operation not supported on this TPU version.') + + def kernel(x_ref, y_ref): + y_ref[...] = x_ref[...].reshape( + x_ref.shape[0], x_ref.shape[1] * x_ref.shape[2] + ) + + x = np.arange(q * m * n, dtype=dtype).reshape(q, m, n) + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((q, m * n), dtype), + )(x) + np.testing.assert_array_equal(out, x.reshape([q, m * n])) + + # (q, m, n, k) -> (q, m, n * k) where k % 128 != 0 + @parameterized.parameters( + (q, m, n, k, dtype) + for (q, m, n, k), dtype in itertools.product( + [ + (3, 8, 17, 500), + (1, 8, 9, 200), + (1, 8, 3, 200), + (10, 1, 4, 200), + (1, 2, 2, 200), + (1, 9, 3, 200), + (4, 7, 1, 300), + ], + [jnp.float32, jnp.uint32, jnp.bfloat16, jnp.int8], + ) + ) + def test_reshape_two_minor_dims_to_R3_padded_last_dim( + self, q, m, n, k, dtype + ): + if not jtu.is_cloud_tpu_at_least(2025, 12, 22): + self.skipTest('Needs a newer libTPU') + if (dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least(4)) or ( + dtype == jnp.int8 and not jtu.is_device_tpu_at_least(5) + ): + self.skipTest('Operation not supported on this TPU version.') + + def kernel(x_ref, y_ref): + y_ref[...] = x_ref[...].reshape( + x_ref.shape[0], x_ref.shape[1], x_ref.shape[2] * x_ref.shape[3] + ) + + x = np.arange(q * m * n * k, dtype=dtype).reshape(q, m, n, k) + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((q, m, n * k), dtype), + )(x) + np.testing.assert_array_equal(out, x.reshape([q, m, n * k])) + + # (p, q, m, n, k) -> (p, q * m * n * k) where k % 128 == 0 + @parameterized.parameters( + (p, q, m, n, k, dtype) + for (p, q, m, n, k), dtype in itertools.product( + [ + (5, 3, 8, 17, 512), + (6, 1, 8, 9, 256), + (16, 1, 8, 3, 256), + (3, 2, 1, 4, 256), + (1, 7, 2, 2, 256), + ], + [jnp.float32, jnp.uint32, jnp.bfloat16, jnp.int8], + ) + ) + def test_reshape_four_minor_dims_to_R2(self, p, q, m, n, k, dtype): + if (dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least(4)) or ( + dtype == jnp.int8 and not jtu.is_device_tpu_at_least(5) + ): + self.skipTest('Operation not supported on this TPU version.') + def kernel(x_ref, y_ref): + y_ref[...] = x_ref[...].reshape( + x_ref.shape[0], + x_ref.shape[1] * x_ref.shape[2] * x_ref.shape[3] * x_ref.shape[4], + ) + + x = np.arange(p * q * m * n * k, dtype=dtype).reshape(p, q, m, n, k) + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((p, q * m * n * k), dtype), + )(x) + np.testing.assert_array_equal(out, x.reshape([p, q * m * n * k])) + + # (q, m, n, k) -> (q, m, 1, n * k) where k % 128 == 0 + @parameterized.parameters( + (q, m, n, k, dtype) + for (q, m, n, k), dtype in itertools.product( + [ + (10, 1, 4, 256), + ], + [jnp.float32, jnp.uint32, jnp.bfloat16, jnp.int8], + ) + ) + def test_reshape_two_minor_dims_preserve_rank(self, q, m, n, k, dtype): + if (dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least(4)) or ( + dtype == jnp.int8 and not jtu.is_device_tpu_at_least(5) + ): + self.skipTest('Operation not supported on this TPU version.') + def kernel(x_ref, y_ref): + y_ref[...] = ( + x_ref[...] + .reshape( + x_ref.shape[0], x_ref.shape[1], x_ref.shape[2] * x_ref.shape[3] + ) + .reshape( + x_ref.shape[0], 1, x_ref.shape[1], x_ref.shape[2] * x_ref.shape[3] + ) + ) + + q, m, n, k = 10, 1, 4, 256 + x = np.arange(q * m * n * k, dtype=dtype).reshape(q, m, n, k) + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((q, m, 1, n * k), dtype), + )(x) + np.testing.assert_array_equal(out, x.reshape([q, m, 1, n * k])) + + # (q, m, n, k) -> (q * m, n * k) where k % 128 == 0 + @parameterized.parameters( + (q, m, n, k, dtype) + for (q, m, n, k), dtype in itertools.product( + [ + (3, 9, 17, 512), + (1, 8, 9, 256), + (1, 8, 3, 384), + (10, 1, 4, 256), + (1, 2, 2, 256), + ], + [jnp.float32, jnp.uint32, jnp.bfloat16, jnp.int8], + ) + ) + def test_reshape_fold_two_leading_dims_and_two_minor_dims_R4_to_R2( + self, q, m, n, k, dtype + ): + if (dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least(4)) or ( + dtype == jnp.int8 and not jtu.is_device_tpu_at_least(5) + ): + self.skipTest('Operation not supported on this TPU version.') + def kernel(x_ref, y_ref): + y_ref[...] = x_ref[...].reshape( + x_ref.shape[0] * x_ref.shape[1], x_ref.shape[2] * x_ref.shape[3] + ) + + x = np.arange(q * m * n * k, dtype=dtype).reshape(q, m, n, k) + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((q * m, n * k), dtype), + )(x) + np.testing.assert_array_equal(out, x.reshape([q * m, n * k])) + + # (q * m, n, k) -> (q, m, n * k) where k % 128 == 0 + @parameterized.parameters( + (q, m, n, k, dtype) + for (q, m, n, k), dtype in itertools.product( + [ + (2, 2, 17, 512), + (3, 2, 3, 256), + (1, 5, 4, 384), + ], + [jnp.float32, jnp.uint32, jnp.bfloat16, jnp.int8], + ) + ) + def test_reshape_unfold_leading_dim_and_fold_two_minor_dims_R3_to_R3( + self, q, m, n, k, dtype + ): + if (dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least(4)) or ( + dtype == jnp.int8 and not jtu.is_device_tpu_at_least(5) + ): + self.skipTest('Operation not supported on this TPU version.') + def kernel(x_ref, y_ref): + y_ref[...] = x_ref[...].reshape( + q, + m, + x_ref.shape[1] * x_ref.shape[2], + ) + + x = np.arange(q * m * n * k, dtype=dtype).reshape(q * m, n, k) + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((q, m, n * k), dtype), + )(x) + np.testing.assert_array_equal(out, x.reshape([q, m, n * k])) + + # (q * m, n * k) -> (q, m, n, k) where k % 128 == 0 + @parameterized.parameters( + (q, m, n, k, dtype) + for (q, m, n, k), dtype in itertools.product( + [ + (2, 2, 17, 512), + (3, 2, 3, 256), + (1, 5, 4, 384), + ], + [jnp.float32, jnp.uint32, jnp.bfloat16, jnp.int8], + ) + ) + def test_reshape_unfold_leading_and_minor_dims_R2_to_R4( + self, q, m, n, k, dtype + ): + if (dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least(4)) or ( + dtype == jnp.int8 and not jtu.is_device_tpu_at_least(5) + ): + self.skipTest('Operation not supported on this TPU version.') + def kernel(x_ref, y_ref): + y_ref[...] = x_ref[...].reshape(q, m, n, k) + + x = np.arange(q * m * n * k, dtype=dtype).reshape(q * m, n * k) + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((q, m, n, k), dtype), + )(x) + np.testing.assert_array_equal(out, x.reshape([q, m, n, k])) + + # (q, m, n * k) -> (q * m, n, k) where k % 128 == 0 + @parameterized.parameters( + (q, m, n, k, dtype) + for (q, m, n, k), dtype in itertools.product( + [ + (2, 2, 17, 512), + (3, 2, 8, 256), + (1, 5, 4, 384), + ], + [jnp.float32, jnp.uint32, jnp.bfloat16, jnp.int8], + ) + ) + def test_reshape_fold_leading_dims_and_unfold_minor_dim( + self, q, m, n, k, dtype + ): + if (dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least(4)) or ( + dtype == jnp.int8 and not jtu.is_device_tpu_at_least(5) + ): + self.skipTest('Operation not supported on this TPU version.') + def kernel(x_ref, y_ref): + y_ref[...] = x_ref[...].reshape(q * m, n, k) + + x = np.arange(q * m * n * k, dtype=dtype).reshape(q, m, n * k) + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((q * m, n, k), dtype), + )(x) + np.testing.assert_array_equal(out, x.reshape([q * m, n, k])) + + # (q, m, n, k) -> (q, m * n, k) where k % 128 == 0 + @parameterized.parameters( + (q, m, n, k, dtype) + for (q, m, n, k), dtype in itertools.product( + [ + (2, 2, 17, 512), + (3, 2, 8, 256), + (1, 5, 4, 384), + ], + [jnp.float32, jnp.uint32, jnp.bfloat16, jnp.int8], + ) + ) + def test_reshape_fold_middle_dims(self, q, m, n, k, dtype): + if (dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least(4)) or ( + dtype == jnp.int8 and not jtu.is_device_tpu_at_least(5) + ): + self.skipTest('Operation not supported on this TPU version.') + def kernel(x_ref, y_ref): + y_ref[...] = x_ref[...].reshape(q, m * n, k) + + x = np.arange(q * m * n * k, dtype=dtype).reshape(q, m, n, k) + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((q, m * n, k), dtype), + )(x) + np.testing.assert_array_equal(out, x.reshape([q, m * n, k])) + + # (q, m * n, k) -> (q, m, n, k) where k % 128 == 0 + @parameterized.parameters( + (q, m, n, k, dtype) + for (q, m, n, k), dtype in itertools.product( + [ + (2, 2, 17, 512), + (3, 2, 8, 256), + (9, 5, 4, 384), + ], + [jnp.float32, jnp.uint32, jnp.bfloat16, jnp.int8], + ) + ) + def test_reshape_unfold_middle_dims(self, q, m, n, k, dtype): + if (dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least(4)) or ( + dtype == jnp.int8 and not jtu.is_device_tpu_at_least(5) + ): + self.skipTest('Operation not supported on this TPU version.') + def kernel(x_ref, y_ref): + y_ref[...] = x_ref[...].reshape(q, m, n, k) + + x = np.arange(q * m * n * k, dtype=dtype).reshape(q, m * n, k) + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((q, m, n, k), dtype), + )(x) + np.testing.assert_array_equal(out, x.reshape([q, m, n, k])) + + @parameterized.parameters([jnp.int8, jnp.bfloat16, jnp.float32]) + def test_reshape_shift_factor_from_minor_to_major(self, dtype): + if (dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least(4)) or ( + dtype == jnp.int8 and not jtu.is_device_tpu_at_least(5) + ): + self.skipTest('Operation not supported on this TPU version.') + q0, m0, n0 = 1, 3, 7680 + q1, m1, n1 = 3, 10, 768 + def kernel(x_ref, y_ref): + y_ref[...] = x_ref[...].reshape(q1, m1, n1) + + x = np.arange(q0 * m0 * n0, dtype=dtype).reshape(q0, m0, n0) + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((q1, m1, n1), dtype), + )(x) + np.testing.assert_array_equal(out, x.reshape([q1, m1, n1])) + + @parameterized.product( + dtype=[jnp.float32, jnp.bfloat16, jnp.float8_e4m3fn], + ) + def test_reshape_fold_minormost_dim(self, dtype): + packing = 32 // (8 * np.dtype(dtype).itemsize) + in_shape = (8 * packing, 128) + out_shape = (1, math.prod(in_shape)) + + def kernel(x_ref, y_ref): + x = x_ref[...] + y_ref[...] = x.reshape(out_shape) + + x = np.random.randn(*in_shape).astype(dtype) + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct(out_shape, dtype), + )(x) + np.testing.assert_array_equal(out, x.reshape(out_shape)) + + def test_dynamic_grid_with_smem_output(self): + if self.INTERPRET: + self.skipTest('Fail on interpreter.') + + def body(_, o_ref): + o_ref[0] = lax.cond( + pl.program_id(0) == 0, lambda: 1, lambda: o_ref[0] + 1 + ) + + def wrapper_dynamic(n): + return self.pallas_call( + body, + out_shape=pltpu.SMEM((1,), dtype=jnp.int32), + grid_spec=pl.GridSpec( + grid=(n,), + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], + out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), + ), + )(n) + + n = jax.random.randint(jax.random.key(0), (1,), 1, 10, dtype=jnp.int32) + compiled_kernel = jax.jit(wrapper_dynamic).lower(n).compile() + np.testing.assert_array_equal(compiled_kernel(n), n) + class MiscellaneousInterpretTest(MiscellaneousTest): INTERPRET: bool = True + def test_async_copy_slice(self): + # https://github.com/jax-ml/jax/issues/33260 + def kernel(o): + @functools.partial(pl.run_scoped, + sem=pltpu.SemaphoreType.DMA, + x=pltpu.VMEM((1,), jnp.float32)) + def _(sem, x): + x[...] = jnp.ones_like(x) + @functools.partial(pl.run_scoped, + y=pltpu.VMEM((1, 1,), jnp.float32)) + def _(y): + pltpu.async_copy(x, y.at[0], sem).wait() + o[...] = y[0] + + result = pl.pallas_call(kernel, out_shape=jax.ShapeDtypeStruct( + (1,), jnp.float32), interpret=True)() + np.testing.assert_array_equal(result, np.ones((1,), dtype=jnp.float32)) + + +class PallasKernelMetadataTest(ptu.PallasTPUTest): + + @parameterized.parameters( + (dict(foo='bar'),), + (dict(foo='afjafo'),), + (dict(problem_info=json.dumps(dict(tiling_info=dict(bm=128, bk=128)))),), + ) + def test_metadata_is_preserved(self, metadata): + + def kernel(x_ref, y_ref, out_ref): + out_ref[...] = x_ref[...] + y_ref[...] + + x = jnp.arange(1024, dtype=jnp.float32).reshape((8, 128)) + y = jnp.arange(1024, dtype=jnp.float32).reshape((8, 128)) + + @jax.jit + def f(x, y): + return self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + metadata=metadata, + )(x, y) + + hlo = f.lower(x, y).compile().as_text() + self.assertIn( + json.dumps(metadata, sort_keys=True, indent=0, separators=(',', ':')), + hlo, + ) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/tpu_ragged_paged_attention_test.py b/tests/pallas/tpu_ragged_paged_attention_test.py index bffcebc5254b..4e088f7e3ad7 100644 --- a/tests/pallas/tpu_ragged_paged_attention_test.py +++ b/tests/pallas/tpu_ragged_paged_attention_test.py @@ -12,29 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. -import random from absl.testing import absltest from absl.testing import parameterized import jax +from jax._src import dtypes from jax._src import test_util as jtu +from jax.experimental import pallas as pl from jax.experimental.pallas.ops.tpu.ragged_paged_attention import ( + dynamic_validate_inputs, ragged_paged_attention, ref_ragged_paged_attention, - validate_inputs_on_runtime, ) import jax.numpy as jnp +import numpy as np jax.config.parse_flags_with_absl() -def ceil_div(x, a): - assert a != 0 - return (x + a - 1) // a - - @jtu.with_config(jax_numpy_dtype_promotion="standard") -class PagedAttentionKernelTest(jtu.JaxTestCase): +class RaggedPagedAttentionKernelTest(jtu.JaxTestCase): def _test_ragged_paged_attention( self, @@ -42,7 +39,8 @@ def _test_ragged_paged_attention( num_heads, # [num_q_heads, num_kv_heads] head_dim, page_size, - dtype, + q_dtype, + kv_dtype, num_pages, *, num_kv_pages_per_block=8, @@ -50,6 +48,10 @@ def _test_ragged_paged_attention( vmem_limit_bytes=32 * 1024 * 1024, max_num_batched_tokens=512, max_num_seq=8, + sliding_window: int | None = None, + soft_cap: float | None = None, + k_scale: float | None = None, + v_scale: float | None = None, ): if not jtu.is_device_tpu_at_least(version=4): self.skipTest("Expect TPUv4+") @@ -63,73 +65,116 @@ def _test_ragged_paged_attention( max_num_batched_tokens = max(cu_q_lens[-1], max_num_batched_tokens) max_num_seq = max(len(seq_lens), max_num_seq) max_kv_len = max(kv_lens) - pages_per_seq = ceil_div(max_kv_len, page_size) + pages_per_seq = pl.cdiv(max_kv_len, page_size) num_q_heads, num_kv_heads = num_heads - cu_q_lens = jnp.array(cu_q_lens, dtype=jnp.int32) - kv_lens = jnp.array(kv_lens, dtype=jnp.int32) - cu_q_lens = jnp.pad(cu_q_lens, (0, max_num_seq + 1 - cu_q_lens.shape[0])) - kv_lens = jnp.pad(kv_lens, (0, max_num_seq - kv_lens.shape[0])) prng_key = jax.random.key(1234) - k0, k1, k2, k3 = jax.random.split(prng_key, 4) + k0, k1 = jax.random.split(prng_key, 2) q = jax.random.normal( k0, (max_num_batched_tokens, num_q_heads, head_dim), - dtype=dtype, + dtype=q_dtype, ) - k_pages = jax.random.normal( - k1, - (num_pages, page_size, num_kv_heads, head_dim), - dtype=dtype, - ) - v_pages = jax.random.normal( - k2, - (num_pages, page_size, num_kv_heads, head_dim), - dtype=dtype, + page_cnt = 0 + page_indices_list = [] + kv_pages_list = [] + for kv_len in kv_lens: + if jnp.issubdtype(kv_dtype, jnp.integer): + # random.randint doesn't support int4, so we use jnp.int32 here and then + # convert to the desired dtype. + kv = jax.random.normal( + k1, + (kv_len, num_kv_heads * 2, head_dim), + dtype=jnp.int32, + ) + kv = kv.astype(kv_dtype) + else: + kv = jax.random.normal( + k1, + (kv_len, num_kv_heads * 2, head_dim), + dtype=kv_dtype, + ) + kv = jnp.pad( + kv, + ((0, pl.cdiv(kv_len, page_size) * page_size - kv_len), (0, 0), (0, 0)), + constant_values=jnp.nan, + ).reshape(-1, page_size, num_kv_heads * 2, head_dim) + indices = page_cnt + jnp.arange(kv.shape[0], dtype=jnp.int32) + indices = jnp.pad( + indices, + ((0, pages_per_seq - indices.shape[0]),), + constant_values=jnp.nan, + ) + page_indices_list.append(indices) + page_cnt += kv.shape[0] + kv_pages_list.append(kv) + + kv_pages = jnp.concatenate(kv_pages_list, axis=0) + kv_pages = jnp.pad( + kv_pages, + ((0, num_pages - kv_pages.shape[0]), (0, 0), (0, 0), (0, 0)), + constant_values=jnp.nan, ) - page_indices = jax.random.randint( - k3, (max_num_seq, pages_per_seq), 0, num_pages, dtype=jnp.int32 + page_indices = jnp.stack(page_indices_list, axis=0) + page_indices = jnp.pad( + page_indices, + ((0, max_num_seq - page_indices.shape[0]), (0, 0)), + constant_values=jnp.nan, ) - + cu_q_lens = jnp.array(cu_q_lens, dtype=jnp.int32) + cu_q_lens = jnp.pad(cu_q_lens, (0, max_num_seq + 1 - cu_q_lens.shape[0])) + kv_lens = jnp.array(kv_lens, dtype=jnp.int32) + kv_lens = jnp.pad(kv_lens, (0, max_num_seq - kv_lens.shape[0])) num_seqs = jnp.array([len(seq_lens)], dtype=jnp.int32) - validate_inputs_on_runtime( + dynamic_validate_inputs( q, - k_pages, - v_pages, + kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs, + sliding_window=sliding_window, + soft_cap=soft_cap, ) + actual_num_q_tokens = cu_q_lens[num_seqs[0]] output = ragged_paged_attention( q, - k_pages, - v_pages, + kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs=num_seqs, - num_kv_pages_per_block=num_kv_pages_per_block, + num_kv_pages_per_block=min(num_kv_pages_per_block, pages_per_seq), num_queries_per_block=num_queries_per_block, vmem_limit_bytes=vmem_limit_bytes, - )[: cu_q_lens[num_seqs[0]]] + sliding_window=sliding_window, + soft_cap=soft_cap, + k_scale=k_scale, + v_scale=v_scale, + )[:actual_num_q_tokens] expected = ref_ragged_paged_attention( q, - k_pages, - v_pages, + kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs=num_seqs, + sliding_window=sliding_window, + soft_cap=soft_cap, + k_scale=k_scale, + v_scale=v_scale, ) + dtype_bits = dtypes.itemsize_bits(jnp.dtype(kv_dtype)) tols = { - "float32": 0.15, - "bfloat16": 0.2, + 32: 0.15, + 16: 0.2, + 8: 0.2, + 4: 0.2, } - tol = tols[jnp.dtype(dtype).name] + tol = tols[dtype_bits] self.assertAllClose(output, expected, atol=tol, rtol=tol) @parameterized.product( @@ -148,9 +193,40 @@ def test_ragged_paged_attention_basic(self, dtype): head_dim, page_size, dtype, + dtype, num_pages, ) + # TODO: support int4 and int8 + @parameterized.product( + q_dtype=[jnp.bfloat16], + kv_dtype=[jnp.float8_e5m2, jnp.float8_e4m3fn], + kv_scales=[(0.5, 0.5), (None, None)], + ) + def test_ragged_paged_attention_quantized_kv_cache( + self, q_dtype, kv_dtype, kv_scales + ): + if not jtu.is_device_tpu_at_least(version=5): + self.skipTest("Expect TPUv5+") + seq_lens = [(192, 328), (128, 180), (64, 255)] + num_heads = (32, 8) + head_dim = 128 + page_size = 16 + num_pages = 1000 + k_scale, v_scale = kv_scales + + self._test_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + q_dtype, + kv_dtype, + num_pages, + k_scale=k_scale, + v_scale=v_scale, + ) + @parameterized.product( dtype=[jnp.float32, jnp.bfloat16], ) @@ -184,6 +260,7 @@ def test_ragged_paged_attention_decode_only(self, dtype): head_dim, page_size, dtype, + dtype, num_pages, ) @@ -220,6 +297,7 @@ def test_ragged_paged_attention_prefill_only(self, dtype): head_dim, page_size, dtype, + dtype, num_pages, ) @@ -256,13 +334,15 @@ def test_ragged_paged_attention_mixed(self, dtype): head_dim, page_size, dtype, + dtype, num_pages, ) @parameterized.product( num_seqs=[1, 5, 16], # TODO(jevinjiang): Support more num_heads! - num_heads=[(32, 8), (32, 16), (12, 2), (4, 4)], + # TODO(b/434082000): Investigate why (12, 2) does not work after libtpu-2025-07-21. + num_heads=[(32, 8), (32, 16), (16, 2), (4, 4), (8, 1)], dtype=[jnp.float32, jnp.bfloat16], num_kv_pages_per_block=[4, 8], num_queries_per_block=[32, 64], @@ -275,11 +355,45 @@ def test_ragged_paged_attention_complex( num_kv_pages_per_block, num_queries_per_block, ): - seq_lens = [] - for _ in range(num_seqs): - q_len = random.randint(1, 100) - kv_len = q_len + random.randint(0, 50) - seq_lens.append((q_len, kv_len)) + rng = np.random.default_rng(1234) + q_lens = rng.integers(1, 100, num_seqs) + kv_lens = q_lens + rng.integers(0, 50, num_seqs) + seq_lens = list(zip(q_lens.tolist(), kv_lens.tolist())) + # TODO(jevinjiang): Support non-128 head_dim! + head_dim = 128 + page_size = 16 + num_pages = 1000 + + self._test_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + dtype, + num_pages, + num_kv_pages_per_block=num_kv_pages_per_block, + num_queries_per_block=num_queries_per_block, + ) + + @parameterized.product( + num_kv_pages_per_block=[4, 8], + num_queries_per_block=[32, 64], + sliding_window=[None, 5, 128], + ) + def test_ragged_paged_attention_sliding_window( + self, + num_kv_pages_per_block, + num_queries_per_block, + sliding_window: int | None, + ): + num_seqs = 5 + num_heads = (4, 4) + dtype = jnp.float32 + rng = np.random.default_rng(1234) + q_lens = rng.integers(1, 100, num_seqs) + kv_lens = q_lens + rng.integers(0, 50, num_seqs) + seq_lens = list(zip(q_lens.tolist(), kv_lens.tolist())) # TODO(jevinjiang): Support non-128 head_dim! head_dim = 128 page_size = 16 @@ -291,11 +405,100 @@ def test_ragged_paged_attention_complex( head_dim, page_size, dtype, + dtype, num_pages, num_kv_pages_per_block=num_kv_pages_per_block, num_queries_per_block=num_queries_per_block, + sliding_window=sliding_window, ) + @parameterized.product( + num_kv_pages_per_block=[4, 8], + num_queries_per_block=[32, 64], + soft_cap=[None, 50.0], + ) + def test_ragged_paged_attention_logit_soft_capping( + self, + num_kv_pages_per_block, + num_queries_per_block, + soft_cap: float | None, + ): + num_heads = (16, 2) + num_seqs = 2 + dtype = jnp.float32 + rng = np.random.default_rng(1234) + q_lens = rng.integers(1, 100, num_seqs) + kv_lens = q_lens + rng.integers(0, 50, num_seqs) + seq_lens = list(zip(q_lens.tolist(), kv_lens.tolist())) + head_dim = 128 + page_size = 16 + num_pages = 1000 + + self._test_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + dtype, + num_pages, + num_kv_pages_per_block=num_kv_pages_per_block, + num_queries_per_block=num_queries_per_block, + soft_cap=soft_cap, + ) + + def test_ragged_paged_attention_sliding_window_should_be_positive(self): + dtype = jnp.float32 + seq_lens = [(192, 328), (128, 180), (64, 255)] + num_heads = (32, 8) + head_dim = 128 + page_size = 16 + num_pages = 1000 + + with self.assertRaisesRegex(ValueError, "must be positive"): + self._test_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + dtype, + num_pages, + sliding_window=0, + ) + + with self.assertRaisesRegex(ValueError, "must be positive"): + self._test_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + dtype, + num_pages, + sliding_window=-1, + ) + + def test_ragged_paged_attention_soft_cap_cannot_be_zero(self): + dtype = jnp.float32 + seq_lens = [(192, 328), (128, 180), (64, 255)] + num_heads = (32, 8) + head_dim = 128 + page_size = 16 + num_pages = 1000 + + with self.assertRaisesRegex(ValueError, "must not be 0.0"): + self._test_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + dtype, + num_pages, + soft_cap=0.0, + ) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/tpu_side_effects_test.py b/tests/pallas/tpu_side_effects_test.py new file mode 100644 index 000000000000..ad650921c889 --- /dev/null +++ b/tests/pallas/tpu_side_effects_test.py @@ -0,0 +1,121 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax._src import test_util as jtu +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu +import jax.numpy as jnp +import numpy as np + +jax.config.parse_flags_with_absl() + + +class SideEffectsTest(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + if not jtu.is_device_tpu(): + self.skipTest("TPU required") + + @parameterized.named_parameters( + ("pure", pltpu.SideEffectType.PURE), + ("side_effecting", pltpu.SideEffectType.SIDE_EFFECTING), + ("dataflow_side_effecting", pltpu.SideEffectType.DATAFLOW_SIDE_EFFECTING), + ("legacy_true", True), + ("legacy_false", False), + ) + def test_side_effects_enum(self, side_effect_type): + def kernel(x_ref, o_ref): + pltpu.sync_copy(x_ref, o_ref) + + @jax.jit + def f(x): + return pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + compiler_params=pltpu.CompilerParams( + has_side_effects=side_effect_type + ), + )(x) + + x = jnp.ones((8, 128), dtype=jnp.float32) + y = f(x) + np.testing.assert_array_equal(y, x) + + def test_invalid_side_effect_type(self): + def kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] + + @jax.jit + def f(x): + return pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + compiler_params=pltpu.CompilerParams(has_side_effects="invalid"), + )(x) + + with self.assertRaisesRegex(ValueError, "Invalid side effect type"): + f(jnp.ones((8, 8), dtype=jnp.float32)) + + def test_side_effecting_dce(self): + def kernel(x_ref, o_ref): + pltpu.sync_copy(x_ref, o_ref) + + def get_compiled_hlo(side_effect_type): + @jax.jit + def f(x): + # We use dce_sink to consume the output but allow DCE if the op is pure/dce-able. + out = pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + compiler_params=pltpu.CompilerParams( + has_side_effects=side_effect_type + ), + )(x) + jax.lax.dce_sink(out) + return x + + lowered = f.lower(jnp.ones((8, 8), dtype=jnp.float32)) + shlo = lowered.as_text() + hlo = lowered.compile().as_text() + return shlo, hlo + + # PURE kernels should be DCE'd. + shlo_pure, hlo_pure = get_compiled_hlo(pltpu.SideEffectType.PURE) + self.assertIn("custom_call", shlo_pure) + self.assertNotIn("custom-call", hlo_pure) + + # SIDE_EFFECTING kernels should NOT be DCE'd. + shlo_side_effecting, hlo_side_effecting = get_compiled_hlo( + pltpu.SideEffectType.SIDE_EFFECTING + ) + self.assertIn("custom_call", shlo_side_effecting) + self.assertIn("has_side_effect = true", shlo_side_effecting) + self.assertIn("custom-call", hlo_side_effecting) + self.assertIn("custom_call_has_side_effect=true", hlo_side_effecting) + + # DATAFLOW_SIDE_EFFECTING kernels SHOULD be DCE'd if outputs are unused. + shlo_dataflow, hlo_dataflow = get_compiled_hlo( + pltpu.SideEffectType.DATAFLOW_SIDE_EFFECTING + ) + self.assertIn("custom_call", shlo_dataflow) + self.assertIn("has_side_effect = true", shlo_dataflow) + self.assertNotIn("custom-call", hlo_dataflow) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/tpu_sparsecore_pallas_debug_check_test.py b/tests/pallas/tpu_sparsecore_pallas_debug_check_test.py new file mode 100644 index 000000000000..06844f280944 --- /dev/null +++ b/tests/pallas/tpu_sparsecore_pallas_debug_check_test.py @@ -0,0 +1,160 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""SparseCore Pallas tests with runtime assertions. + +Runtime assertions halt TPU execution, which can cause subsequent tests to get +stuck. Therefore, each test with a failing assertion should run in a separate +process. By separating these tests from the rest, we can set the shard count +such that each test runs in its own shard. + +The test class in this file makes an attempt to detect the simple scenario where +there are more test methods in the module than shards. +""" + +import functools +import os +import sys +import unittest + +from absl.testing import absltest +from absl import flags +import jax +from jax._src import test_util as jtu +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu +from jax.experimental.pallas import tpu_sc as plsc +import jax.numpy as jnp + + +jax.config.parse_flags_with_absl() + + +class DebugCheckTest(jtu.JaxTestCase): + @classmethod + def setUpClass(cls): + super().setUpClass() + + total_shards = int(os.environ.get("TEST_TOTAL_SHARDS", -1)) + if total_shards == -1: + raise unittest.SkipTest("Tests can only be run with Bazel.") + + loader = unittest.TestLoader() + test_cases = loader.loadTestsFromModule( + sys.modules['__main__'] + ).countTestCases() + if test_cases > total_shards: + raise RuntimeError( + "Each test with a failing assertion should be in a separate test" + " shard because they put the hardware in a halt state, causing" + " subsequent tests to fail. Make sure sharding is enabled and the" + f" shard count is at least {test_cases}." + ) + + def setUp(self): + if not jtu.is_device_tpu(5, "p") and not jtu.is_device_tpu_at_least(6): + self.skipTest("SparseCore only supported on TPU v5p+") + + super().setUp() + + def test_scalar_debug_check(self): + if not jtu.is_device_tpu_at_least(7): + # TODO: b/469486032 - Figure out why the test gets stuck on v5p, v6e. + self.skipTest("Fails on v5p and v6e.") + + x = jnp.arange(8) + + @pl.kernel( + out_shape=x, + mesh=plsc.ScalarSubcoreMesh(axis_name="core", num_cores=1), + ) + def kernel(o_hbm_ref): + @functools.partial( + pl.run_scoped, + sem=pltpu.SemaphoreType.DMA, + ) + def _(sem): + pltpu.async_copy(o_hbm_ref, o_hbm_ref, sem).wait() + pl.debug_check(True, "Check success!") + pl.debug_check(False, "Check failure!") + + with pl.enable_debug_checks(), self.assertRaises( + jax.errors.JaxRuntimeError + ) as error: + jax.block_until_ready(kernel()) + + self.assertNotIn("Check success!", str(error.exception)) + self.assertIn("Check failure!", str(error.exception)) + self.assertIn( + "check at DebugCheckTest.test_scalar_debug_check", str(error.exception) + ) + + def test_vector_debug_check(self): + if jtu.test_device_matches(["tpu"]) and not jtu.is_cloud_tpu_at_least(2026, 1, 16): + self.skipTest("Requires libtpu built after 2026-1-16") + + x = jnp.arange(8) + + @functools.partial( + pl.pallas_call, + out_shape=x, + compiler_params=pltpu.CompilerParams( + kernel_type=pltpu.KernelType.SC_VECTOR_SUBCORE + ), + ) + def kernel(_): + pl.debug_check(True, "Check success!") + pl.debug_check(False, "Check failure!") + + with pl.enable_debug_checks(), self.assertRaises( + jax.errors.JaxRuntimeError + ) as error: + jax.block_until_ready(kernel()) + + self.assertNotIn("Check success!", str(error.exception)) + self.assertIn("Check failure!", str(error.exception)) + self.assertIn( + "check at DebugCheckTest.test_vector_debug_check", str(error.exception) + ) + + def test_trigger_bounds_checker(self): + if "xla_sc_assert_level" in flags.FLAGS: + # The test crashes the process anyway, so no need to be clean. + flags.FLAGS.xla_sc_assert_level = "all-loads-stores" + else: + self.skipTest("TODO: Find another way to enable bounds checking.") + + x = jnp.arange(8, dtype=jnp.int32) + # Index 8 is out-of-bounds. + indices = jnp.array([0, 1, 2, 3, 4, 5, 6, 8], dtype=jnp.int32) + + @functools.partial( + pl.pallas_call, + out_shape=x, + compiler_params=pltpu.CompilerParams( + kernel_type=pltpu.KernelType.SC_VECTOR_SUBCORE + ), + ) + def kernel(x_ref, indices_ref, o_ref): + o_ref[...] = plsc.load_gather(x_ref, [indices_ref[...]]) + + # We expect this to fail with a runtime error from the bounds checker. + with self.assertRaisesRegex( + jax.errors.JaxRuntimeError, + "Trying to perform an indexed vector load from out of bounds address.", + ): + jax.block_until_ready(kernel(x, indices)) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/tpu_sparsecore_pallas_distributed_test.py b/tests/pallas/tpu_sparsecore_pallas_distributed_test.py new file mode 100644 index 000000000000..0488fcf72d0a --- /dev/null +++ b/tests/pallas/tpu_sparsecore_pallas_distributed_test.py @@ -0,0 +1,142 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for Pallas on SparseCore with multiple devices.""" +import math + +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax import lax +from jax._src import test_util as jtu +from jax.experimental import mesh_utils +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu +from jax.experimental.pallas import tpu_sc as plsc +import jax.numpy as jnp +import numpy as np + + +jax.config.parse_flags_with_absl() + + +class PallasCallRemoteDMATest(parameterized.TestCase): + + def setUp(self): + super().setUp() + if jax.device_count() < 2: + self.skipTest('Only >=2 devices are supported.') + if not jtu.is_device_tpu_at_least(5): + self.skipTest('SparseCore only supported on TPU v5+') + + @parameterized.product(direction=['left', 'right'], num_devices=[2, None]) + def test_collective_permute_1d(self, direction, num_devices): + shape = (8, 128) + + # Implements a very simple collective permute. + @pl.kernel( + out_shape=jax.ShapeDtypeStruct(shape, jnp.int32), + mesh=plsc.ScalarSubcoreMesh(axis_name='core', num_cores=1), + scratch_shapes=( + pltpu.SemaphoreType.REGULAR, + pltpu.SemaphoreType.DMA, + pltpu.SemaphoreType.DMA, + ), + ) + def kernel(x_ref, y_ref, ready_sem, send_sem, recv_sem): + + my_id = lax.axis_index('x') + axis_size = lax.axis_size('x') + if direction == 'right': + neighbor = lax.rem(my_id + 1, axis_size) + else: + neighbor = lax.rem(my_id + axis_size - 1, axis_size) + pltpu.semaphore_signal(ready_sem, device_id=neighbor) + pltpu.semaphore_wait(ready_sem) + pltpu.async_remote_copy( + x_ref, y_ref, send_sem, recv_sem, device_id=neighbor + ).wait() + + num_devices = num_devices or jax.device_count() + x = jnp.arange(num_devices * math.prod(shape), dtype=jnp.int32).reshape( + (-1, shape[-1]) + ) + device_mesh = mesh_utils.create_device_mesh( + (num_devices,), jax.devices()[:num_devices] + ) + mesh = jax.sharding.Mesh(device_mesh, ['x']) + f = jax.jit( + jax.shard_map( + kernel, + mesh=mesh, + in_specs=jax.P('x'), + out_specs=jax.P('x'), + check_vma=False, + ) + ) + if direction == 'right': + expected = jnp.concatenate([x[-8:], x[:-8]]) + else: + expected = jnp.concatenate([x[8:], x[:8]]) + np.testing.assert_allclose(f(x), expected) + + @parameterized.product(direction=['left', 'right']) + def test_collective_permute_2d(self, direction): + shape = (8, 128) + + @pl.kernel( + out_shape=jax.ShapeDtypeStruct(shape, jnp.int32), + mesh=plsc.ScalarSubcoreMesh(axis_name='core', num_cores=1), + scratch_shapes=( + pltpu.SemaphoreType.REGULAR, + pltpu.SemaphoreType.DMA, + pltpu.SemaphoreType.DMA, + ), + ) + def kernel(x_ref, y_ref, ready_sem, send_sem, recv_sem): + my_id = lax.axis_index('x') + my_other_id = lax.axis_index('y') + axis_size = lax.axis_size('x') + if direction == 'right': + neighbor = lax.rem(my_id + 1, axis_size) + else: + neighbor = lax.rem(my_id + axis_size - 1, axis_size) + pltpu.semaphore_signal(ready_sem, device_id=(my_other_id, neighbor)) + pltpu.semaphore_wait(ready_sem) + pltpu.async_remote_copy( + x_ref, y_ref, send_sem, recv_sem, device_id=(my_other_id, neighbor) + ).wait() + + axis_size = jax.device_count() // 2 + x = jnp.arange(axis_size * 8 * 128).reshape((axis_size * 8, 128)) + + device_mesh = mesh_utils.create_device_mesh((2, axis_size), jax.devices()) + mesh = jax.sharding.Mesh(device_mesh, ['y', 'x']) + y = jax.jit( + jax.shard_map( + kernel, + mesh=mesh, + in_specs=jax.P('x', None), + out_specs=jax.P('x', None), + check_vma=False, + ) + )(x) + if direction == 'right': + expected = jnp.concatenate([x[-8:], x[:-8]]) + else: + expected = jnp.concatenate([x[8:], x[:8]]) + np.testing.assert_allclose(y, expected) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/pallas/tpu_sparsecore_pallas_test.py b/tests/pallas/tpu_sparsecore_pallas_test.py new file mode 100644 index 000000000000..aa05360eed4f --- /dev/null +++ b/tests/pallas/tpu_sparsecore_pallas_test.py @@ -0,0 +1,2011 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for Pallas on SparseCore.""" + +import collections +import functools +import itertools +import math + +from absl.testing import absltest +from absl.testing import parameterized +import hypothesis as hp +import hypothesis.strategies as hps +import jax +from jax import lax +from jax._src import test_util as jtu +from jax._src.pallas.mosaic import sc_core +from jax._src.state import discharge as state_discharge +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu +from jax.experimental.pallas import tpu_sc as plsc +import jax.numpy as jnp +import numpy as np + + +jtu.setup_hypothesis() +jax.config.parse_flags_with_absl() + + +class PallasSCTest(jtu.JaxTestCase): + USE_TC_TILING = False + + def setUp(self): + if not jtu.is_device_tpu(5, "p") and not jtu.is_device_tpu_at_least(6): + self.skipTest("SparseCore only supported on TPU v5p+") + + if self.USE_TC_TILING and jtu.is_cloud_tpu(): + # TODO(apaszke,slebedev): Fix those. + self.skipTest("Many tests are failing on Cloud TPUs") + + if not jtu.is_cloud_tpu_at_least(2026, 1, 17): + self.skipTest("Need newer libtpu") + + super().setUp() + + @property + def sc_info(self): + return plsc.get_sparse_core_info() + + def vector_subcore_kernel(self, **kwargs): + assert "compiler_params" not in kwargs + return functools.partial( + pl.pallas_call, + **kwargs, + compiler_params=pltpu.CompilerParams( + kernel_type=pltpu.KernelType.SC_VECTOR_SUBCORE, + use_tc_tiling_on_sc=self.USE_TC_TILING, + ), + ) + + def kernel(self, **kwargs): + assert "compiler_params" not in kwargs + return functools.partial( + pl.kernel, + compiler_params=pltpu.CompilerParams( + use_tc_tiling_on_sc=self.USE_TC_TILING + ), + **kwargs, + ) + + def skip_if_tc_tiling(self, reason: str = ""): + if self.USE_TC_TILING: + self.skipTest(f"TC tiling is not supported. {reason}") + + +class DebugPrintTest(PallasSCTest): + + def setUp(self): + if jtu.is_cloud_tpu(): + # TODO(slebedev): Investigate this and remove the skip. + self.skipTest("Fails on Cloud TPUs") + + super().setUp() + + @parameterized.product(dtype=[jnp.int32, jnp.float32]) + def test_vector_subcore(self, dtype): + x = jnp.arange(16, dtype=dtype) + debug_int = 1234552 + debug_float = 12344.625 + + @self.vector_subcore_kernel(out_shape=x) + def kernel(x_hbm_ref, _): + pl.debug_print("Memref", x_hbm_ref) + x = x_hbm_ref[:8] + 100 + pl.debug_print("Vector value", x) + masks = x < 103 + pl.debug_print("Masks", masks) + pl.debug_print("Single int", debug_int) + pl.debug_print("Single float", debug_float) + pl.debug_print("No values") + + compiled_kernel = jax.jit( + kernel, + compiler_options={ + "xla_tpu_enable_sc_log_recorder": "true", + "xla_tpu_enable_tile_log_recorder": "true", + }, + ) + with jtu.capture_stderr() as get_output: + jax.block_until_ready(compiled_kernel(x)) + self.assertIn("Memref", get_output()) + self.assertIn("0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15", get_output()) + self.assertIn("Vector value", get_output()) + self.assertIn("100, 101, 102, 103, 104, 105, 106, 107", get_output()) + self.assertIn("Masks", get_output()) + self.assertIn("1, 1, 1, 0, 0, 0, 0, 0", get_output()) + self.assertIn("Single int, data: s32[1]", get_output()) + self.assertIn(str(debug_int), get_output()) + self.assertIn("Single float, data: f32[1]", get_output()) + self.assertIn(str(debug_float), get_output()) + self.assertIn("No values", get_output()) + + def test_scalar_subcore(self): + int32s = jnp.arange(512, dtype=jnp.int32).reshape(64, 8) + int16s = jnp.arange(512, dtype=jnp.int16).reshape(32, 16) + int8s = jnp.arange(512, dtype=jnp.int8).reshape(16, 32) + debug_int = 1234552 + debug_float = 12344.625 + + @self.kernel( + out_shape=int32s, + mesh=plsc.ScalarSubcoreMesh( + axis_name="core", num_cores=self.sc_info.num_cores + ), + ) + def kernel(int32s_hbm_ref, int16s_hbm_ref, int8s_hbm_ref, o_hbm_ref): + @functools.partial( + pl.run_scoped, + tmp_ref=pltpu.VMEM_SHARED(int32s.shape, int32s.dtype), + sem=pltpu.SemaphoreType.DMA, + ) + def _(tmp_ref, sem): + @pl.when(lax.axis_index("core") == 0) + def _(): + pltpu.async_copy(int32s_hbm_ref, tmp_ref, sem).wait() + pltpu.async_copy(tmp_ref, o_hbm_ref, sem).wait() + pl.debug_print("s32 array", tmp_ref) + pl.debug_print("s16 array", int16s_hbm_ref) + pl.debug_print("s8 array", int8s_hbm_ref) + pl.debug_print("Single int", debug_int) + pl.debug_print("Single float", debug_float) + pl.debug_print("No values") + + compiled_kernel = jax.jit( + kernel, + compiler_options={ + "xla_tpu_enable_sc_log_recorder": "true", + # TODO(slebedev): This should not be necessary. + "xla_sc_force_aligned_buffers": "false", + }, + ) + with jtu.capture_stderr() as get_output: + jax.block_until_ready(compiled_kernel(int32s, int16s, int8s)) + print(get_output()) + self.assertIn("s32 array, data: s32", get_output()) + self.assertIn("{ 8, 9, 10, 11, 12, 13, 14, 15 }", get_output()) + self.assertIn("s16 array, data: s16", get_output()) + self.assertIn( + "{ 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31 }", + get_output(), + ) + self.assertIn("s8 array, data: s8", get_output()) + self.assertIn( + "{ 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47" + ", 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63 }", + get_output(), + ) + self.assertIn("Single int", get_output()) + self.assertIn(str(debug_int), get_output()) + self.assertIn("Single float", get_output()) + self.assertIn(str(debug_float), get_output()) + self.assertIn("No values", get_output()) + + +class VectorSubcoreTest(PallasSCTest): + + # Used for testing masked loads and stores below + MASK_FNS = [lambda x: x < 4, lambda x: x >= 4, lambda x: x % 2 == 0] + + @parameterized.product( + dtype=[jnp.int32, jnp.float32], op=[jnp.add, jnp.subtract] + ) + def test_add_sub_one(self, dtype, op): + x = jnp.arange(8, dtype=dtype) + + @self.vector_subcore_kernel(out_shape=x) + def kernel(x_ref, o_ref): + x = x_ref[...] + o_ref[...] = op(x, 1) + + np.testing.assert_array_equal(kernel(x), op(x, 1)) + + def test_add_one_block_specs(self): + x = jnp.arange(32, dtype=jnp.int32) + + @self.vector_subcore_kernel( + out_shape=x, + grid=(4,), + out_specs=pl.BlockSpec([8], lambda i: i), + in_specs=[pl.BlockSpec([8], lambda i: i)], + ) + def kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] + 1 + + np.testing.assert_array_equal(kernel(x), x + 1) + + @parameterized.named_parameters(*( + dict( + testcase_name=( + f"_{'x'.join(map(str, shape))}x{dtype.name}_{minor_scale}" + ), + dtype=dtype, + out_shape=shape, + minor_scale=minor_scale, + ) + for dtype, shapes in sc_core.SUPPORTED_VECTOR_SHAPES.items() + for shape in shapes + if math.prod(shape) * dtype.itemsize == 32 + for minor_scale in [1, 2, 4] + )) + def test_slicing(self, dtype, out_shape, minor_scale): + self.skip_if_tc_tiling() + + if dtype == jnp.float16 and jtu.is_device_tpu(6, "e"): + # TODO(b/433704850): Remove this once the bug is fixed. + self.skipTest("Crashes") + + crashing = { + "int16": [(2, 8)], + "uint16": [(2, 8)], + "float16": [(2, 8)], + "bfloat16": [(2, 8)], + "int8": [(4, 8)], + "uint8": [(4, 8)], + } + if out_shape in crashing.get(dtype.name, []): + self.skipTest("Crashes") + + out_minor = out_shape[-1] + in_minor = out_minor * minor_scale + in_shape = (*out_shape[:-1], in_minor) + indices = [ + slice(i * out_minor, (i + 1) * out_minor) for i in range(minor_scale) + ] + + @self.vector_subcore_kernel( + out_shape=jax.ShapeDtypeStruct(shape=out_shape, dtype=dtype), + ) + def kernel(x_ref, o_ref): + o_ref[...] = sum(x_ref[..., idx] for idx in indices) + + x = jnp.arange(math.prod(in_shape), dtype=dtype).reshape(in_shape) + np.testing.assert_array_equal( + kernel(x), sum(x[..., idx] for idx in indices) + ) + + @parameterized.product(major_dim=[2, 3, 4]) + def test_get_index(self, major_dim): + @self.vector_subcore_kernel( + out_shape=jax.ShapeDtypeStruct(shape=(8,), dtype=jnp.int32), + ) + def kernel(x_ref, o_ref): + o_ref[...] = lax.fori_loop( + 1, major_dim, lambda i, acc: acc + x_ref[i], x_ref[0] + ) + + x = jnp.arange(8 * major_dim).reshape(major_dim, 8) + np.testing.assert_array_equal(kernel(x), x.sum(axis=0)) + + def test_get_multi_index(self): + @self.vector_subcore_kernel( + out_shape=jax.ShapeDtypeStruct(shape=(8,), dtype=jnp.int32) + ) + def kernel(x_ref, o_ref): + o_ref[...] = jnp.zeros_like(o_ref) + for i, j in itertools.product(*map(range, x_ref.shape[:-1])): + o_ref[...] += x_ref.at[i][j] + + x = jnp.arange(3 * 4 * 8).reshape(3, 4, 8) + np.testing.assert_array_equal(kernel(x), x.sum(axis=(0, 1))) + + @jtu.thread_unsafe_test(condition=not jtu.hypothesis_is_thread_safe()) + @hp.given(hps.data()) + def test_block_spec_untiled_slicing(self, data): + self.skipTest( + "Test uncovers a bug: @reproduce_failure('6.80.0', b'AAEBAQAAAAA=')" + ) + slice_shape = data.draw( + hps.lists( + hps.integers(1, 3), min_size=(1 + self.USE_TC_TILING), max_size=4 + ) + ) + if self.USE_TC_TILING: + slice_shape[-2] *= 8 + slice_shape[-1] *= 128 + else: + slice_shape[-1] *= 8 + max_elems = 12000 if jtu.is_device_tpu(6, "e") else 25000 + hp.assume(math.prod(slice_shape) <= max_elems) # Avoid OOMs. + rank = len(slice_shape) + offsets = data.draw( + hps.lists(hps.integers(0, 4), min_size=rank, max_size=rank) + ) + full_shape = tuple(s * (o + 2) for s, o in zip(slice_shape, offsets)) + + def nd_loop(bounds, body, *, _idxs = ()): + if not bounds: + body(*_idxs) + return + bound, *other_bounds = bounds + def _loop_body(i, _): + nd_loop(other_bounds, body, _idxs=(*_idxs, i)) + jax.lax.fori_loop(0, bound, _loop_body, None) + + @self.vector_subcore_kernel( + out_shape=jax.ShapeDtypeStruct(shape=slice_shape, dtype=jnp.int32), + in_specs=[pl.BlockSpec(slice_shape, lambda: offsets)], + ) + def kernel(x_ref, o_ref): + slice_vec_shape = (*slice_shape[:-1], slice_shape[-1] // 8) + def copy(*idxs): + idxs = (*idxs[:-1], pl.ds(idxs[-1] * 8, 8)) + o_ref[idxs] = x_ref[idxs] + nd_loop(slice_vec_shape, copy) + + x = jnp.arange(math.prod(full_shape)).reshape(full_shape) + np_slc = tuple(slice(o * s, (o + 1) * s) for o, s in zip(offsets, slice_shape)) + np.testing.assert_array_equal(kernel(x), x[np_slc]) + + @parameterized.product(major_dim=[2, 3, 4]) + def test_swap_index(self, major_dim): + @self.vector_subcore_kernel( + out_shape=jax.ShapeDtypeStruct(shape=(major_dim, 8), dtype=jnp.int32), + ) + def kernel(x_ref, o_ref): + @pl.loop(0, major_dim) + def _(i): + o_ref[major_dim - 1 - i] = x_ref[i] + + x = jnp.arange(major_dim * 8).reshape(major_dim, 8) + np.testing.assert_array_equal(kernel(x), x[::-1]) + + @parameterized.product(shape=[(8,), (16,), (8, 8), (16, 8), (8, 16, 8)]) + def test_scatter_major(self, shape): + self.skip_if_tc_tiling() + x = jnp.arange(math.prod(shape)).reshape(shape) + major_dim, *_ = shape + indices = jax.random.permutation(jax.random.key(42), jnp.arange(major_dim)) + + @self.vector_subcore_kernel( + out_shape=x, out_specs=pl.BlockSpec(memory_space=pltpu.HBM) + ) + def kernel(x_ref, indices_ref, o_hbm_ref): + @functools.partial(pl.run_scoped, sem=pltpu.SemaphoreType.DMA) + def _(sem): + pltpu.async_copy(x_ref, o_hbm_ref.at[indices_ref], sem).wait() + + np.testing.assert_array_equal( + kernel(x, indices), jnp.empty_like(x).at[indices].set(x) + ) + + def test_scatter_1d_array(self): + x = jnp.arange(8) + indices = jax.random.permutation(jax.random.key(42), jnp.arange(8)) + + @self.vector_subcore_kernel( + out_shape=x, out_specs=pl.BlockSpec(memory_space=pltpu.HBM) + ) + def kernel(x_ref, indices_ref, o_hbm_ref): + pltpu.sync_copy(x_ref, o_hbm_ref.at[indices_ref[...]]) + + np.testing.assert_array_equal( + kernel(x, indices), jnp.empty_like(x).at[indices].set(x) + ) + + def test_scatter_1d_array_from_transformed_src(self): + self.skip_if_tc_tiling() + x = jnp.arange(16).reshape(2, -1) + indices = jax.random.permutation(jax.random.key(42), jnp.arange(8)) + + @self.vector_subcore_kernel( + out_shape=x[0], out_specs=pl.BlockSpec(memory_space=pltpu.HBM), + ) + def kernel(x_ref, indices_ref, o_hbm_ref): + pltpu.sync_copy(x_ref.at[0], o_hbm_ref.at[indices_ref[...]]) + + np.testing.assert_array_equal( + kernel(x, indices), jnp.empty_like(x[0]).at[indices].set(x[0]) + ) + + @parameterized.product(kind=["ref", "array"]) + def test_gather_1d(self, kind): + x = jnp.arange(8) + indices = jax.random.permutation(jax.random.key(42), x) + + @self.vector_subcore_kernel( + out_shape=x, + in_specs=( + pl.BlockSpec(memory_space=pltpu.HBM), + pl.BlockSpec(memory_space=pltpu.VMEM), + ), + ) + def kernel(x_hbm_ref, indices_ref, o_ref): + indices = indices_ref if kind == "ref" else indices_ref[...] + pltpu.sync_copy(x_hbm_ref.at[indices], o_ref) + + np.testing.assert_array_equal(kernel(x, indices), x[indices]) + + @parameterized.product(kind=["ref", "array"]) + def test_gather_1d_to_transformed_dst(self, kind): + self.skip_if_tc_tiling() + x = jnp.arange(8) + indices = jax.random.permutation(jax.random.key(42), x) + + @self.vector_subcore_kernel( + out_shape=jax.ShapeDtypeStruct(shape=(2, 8,), dtype=jnp.int32), + in_specs=( + pl.BlockSpec(memory_space=pltpu.HBM), + pl.BlockSpec(memory_space=pltpu.VMEM), + ), + ) + def kernel(x_hbm_ref, indices_ref, o_ref): + indices = indices_ref if kind == "ref" else indices_ref[...] + pltpu.sync_copy(x_hbm_ref.at[indices], o_ref.at[0]) + + np.testing.assert_array_equal(kernel(x, indices)[0], x[indices]) + + def test_large_gather_1d(self): + x = jnp.arange(1024) + indices = jax.random.permutation(jax.random.key(42), x) + + @self.vector_subcore_kernel( + out_shape=x, + in_specs=( + pl.BlockSpec(memory_space=pltpu.HBM), + pl.BlockSpec(memory_space=pltpu.VMEM), + ), + ) + def kernel(x_hbm_ref, indices_ref, o_ref): + pltpu.sync_copy(x_hbm_ref.at[indices_ref], o_ref) + + np.testing.assert_array_equal(kernel(x, indices), x[indices]) + + def test_gather_1d_with_indexing(self): + self.skip_if_tc_tiling("Small 1d gather does not work on TC tiling.") + x = jnp.arange(4 * 4 * 8).reshape(4, 4, 8) + indices = jax.random.permutation(jax.random.key(42), jnp.arange(8)) + + @self.vector_subcore_kernel( + out_shape=jax.ShapeDtypeStruct(shape=(8,), dtype=jnp.int32), + in_specs=( + pl.BlockSpec(memory_space=pltpu.HBM), + pl.BlockSpec(memory_space=pltpu.VMEM), + ), + ) + def kernel(x_hbm_ref, indices_ref, o_ref): + pltpu.sync_copy(x_hbm_ref.at[1, 2].at[indices_ref], o_ref) + + np.testing.assert_array_equal(kernel(x, indices), x[1, 2, indices]) + + def test_gather_2d_with_indexing(self): + x = jnp.arange(4 * 16 * 128).reshape(4, 16, 128) + indices = jax.random.permutation(jax.random.key(42), jnp.arange(8)) + + @self.vector_subcore_kernel( + out_shape=jax.ShapeDtypeStruct(shape=(8, 128,), dtype=jnp.int32), + in_specs=( + pl.BlockSpec(memory_space=pltpu.HBM), + pl.BlockSpec(memory_space=pltpu.VMEM), + ), + ) + def kernel(x_hbm_ref, indices_ref, o_ref): + pltpu.sync_copy(x_hbm_ref.at[1, pl.ds(8, 8), :].at[indices_ref], o_ref) + + np.testing.assert_array_equal(kernel(x, indices), x[1, 8:][indices]) + + def test_gather_1d_with_indexed_ref(self): + x = jnp.arange(16) + indices = jax.random.permutation(jax.random.key(42), jnp.arange(16)) + + @self.vector_subcore_kernel( + out_shape=jax.ShapeDtypeStruct(shape=(8,), dtype=jnp.int32), + in_specs=( + pl.BlockSpec(memory_space=pltpu.HBM), + pl.BlockSpec(memory_space=pltpu.VMEM), + ), + ) + def kernel(x_hbm_ref, indices_ref, o_ref): + pltpu.sync_copy(x_hbm_ref.at[indices_ref.at[:indices.size // 2]], o_ref) + + np.testing.assert_array_equal( + kernel(x, indices), x[indices[:indices.size // 2]] + ) + + def test_gather_1d_with_dynamically_sized_ref(self): + self.skip_if_tc_tiling() + x = jnp.arange(16) + indices = jax.random.permutation(jax.random.key(42), jnp.arange(16)) + + @self.vector_subcore_kernel( + out_shape=jax.ShapeDtypeStruct(shape=(8,), dtype=jnp.int32), + grid=(1,), + in_specs=( + pl.BlockSpec(memory_space=pltpu.HBM), + pl.BlockSpec(memory_space=pltpu.VMEM), + ), + ) + def kernel(x_hbm_ref, indices_ref, o_ref): + pid = pl.program_id(0) # Always zero. + num_indices = pid + indices_ref.size // 2 + pltpu.sync_copy( + x_hbm_ref.at[indices_ref.at[pl.ds(0, num_indices)]], + o_ref.at[pl.ds(0, num_indices)], + ) + + np.testing.assert_array_equal( + kernel(x, indices), x[indices[: indices.size // 2]] + ) + + def test_gather_1d_with_dynamically_sized_2d_ref(self): + self.skip_if_tc_tiling() + + x = jnp.arange(16) + indices = jax.random.permutation( + jax.random.key(42), jnp.arange(2 * 16).reshape(2, -1), axis=1 + ) + + @self.vector_subcore_kernel( + out_shape=jax.ShapeDtypeStruct( + shape=(indices.size // 4,), dtype=jnp.int32 + ), + grid=(1,), + in_specs=( + pl.BlockSpec(memory_space=pltpu.HBM), + pl.BlockSpec(memory_space=pltpu.VMEM), + ), + ) + def kernel(x_hbm_ref, indices_ref, o_ref): + pid = pl.program_id(0) # Always zero. + num_indices = pid + indices_ref.size // 4 + pltpu.sync_copy( + x_hbm_ref.at[indices_ref.at[pid, pl.ds(0, num_indices)]], + o_ref.at[pl.ds(0, num_indices)], + ) + + np.testing.assert_array_equal( + kernel(x, indices), x[indices[0, : indices.size // 4]] + ) + + def test_invalid_gather_1d_with_extra_transforms(self): + x = jnp.arange(8) + indices = jax.random.permutation(jax.random.key(42), x) + + @self.vector_subcore_kernel( + out_shape=jax.ShapeDtypeStruct(shape=x.shape, dtype=jnp.int32), + in_specs=( + pl.BlockSpec(memory_space=pltpu.HBM), + pl.BlockSpec(memory_space=pltpu.VMEM), + ), + ) + def kernel(x_hbm_ref, indices_ref, o_ref): + pltpu.sync_copy(x_hbm_ref.at[indices_ref].reshape(o_ref.size), o_ref) + + with self.assertRaisesRegex( + NotImplementedError, "cannot have any transforms following the indexer" + ): + kernel(x, indices) + + def test_invalid_gather_1d_with_indexed_destination(self): + x = jnp.arange(8) + indices = jax.random.permutation(jax.random.key(42), x) + + @self.vector_subcore_kernel( + out_shape=jax.ShapeDtypeStruct(shape=x.shape, dtype=jnp.int32), + in_specs=( + pl.BlockSpec(memory_space=pltpu.HBM), + pl.BlockSpec(memory_space=pltpu.VMEM), + ), + ) + def kernel(x_hbm_ref, indices_ref, o_ref): + pltpu.sync_copy(x_hbm_ref.at[indices_ref], o_ref.at[indices_ref]) + + with self.assertRaisesRegex(ValueError, "source ref can be indexed"): + kernel(x, indices) + + def test_invalid_gather_1d_memory_space(self): + x = jnp.arange(8) + indices = jax.random.permutation(jax.random.key(42), x) + + @self.vector_subcore_kernel( + out_shape=jax.ShapeDtypeStruct(shape=x.shape, dtype=jnp.int32), + in_specs=( + pl.BlockSpec(memory_space=pltpu.HBM), + pl.BlockSpec(memory_space=pltpu.VMEM), + ), + out_specs=pl.BlockSpec(memory_space=pltpu.HBM), + ) + def kernel(x_hbm_ref, indices_ref, o_ref): + pltpu.sync_copy(x_hbm_ref.at[indices_ref], o_ref) + + with self.assertRaisesRegex( + NotImplementedError, "from HBM to HBM is not supported" + ): + kernel(x, indices) + + def test_invalid_gather_1d_offsets_memory_space(self): + x = jnp.arange(8) + indices = jax.random.permutation(jax.random.key(42), x) + + @self.vector_subcore_kernel( + out_shape=jax.ShapeDtypeStruct(shape=x.shape, dtype=jnp.int32), + in_specs=( + pl.BlockSpec(memory_space=pltpu.HBM), + pl.BlockSpec(memory_space=pltpu.HBM), + ), + ) + def kernel(x_hbm_ref, indices_ref, o_ref): + pltpu.sync_copy(x_hbm_ref.at[indices_ref], o_ref) + + with self.assertRaisesRegex( + NotImplementedError, "must be in VMEM, got HBM" + ): + kernel(x, indices) + + def test_implicit_gather_1d(self): + self.skip_if_tc_tiling() + num_steps = 4 + x = jnp.arange(num_steps * 8).reshape(num_steps, 8) + indices = jax.random.permutation(jax.random.key(42), jnp.arange(num_steps)) + + @self.vector_subcore_kernel( + out_shape=jax.ShapeDtypeStruct(shape=(num_steps, 8), dtype=jnp.int32), + grid=(num_steps,), + in_specs=( + plsc.BlockSpec((1, 8), indexed_by=1, indexed_dim=0), + pl.BlockSpec((1,), lambda i: i), + ), + out_specs=pl.BlockSpec((1, 8), lambda i: (0, i)), + ) + def kernel(x_ref, indices_ref, o_ref): + del indices_ref # Unused. + o_ref[...] = x_ref[...] + + np.testing.assert_array_equal( + kernel(x, indices), jnp.take(x, indices, axis=0) + ) + + def test_load_gather_1d(self): + x = jnp.arange(8) + indices = jax.random.permutation(jax.random.key(42), jnp.arange(8)) + + @self.vector_subcore_kernel(out_shape=x) + def kernel(x_ref, indices_ref, o_ref): + o_ref[...] = plsc.load_gather(x_ref, [indices_ref[...]]) + + np.testing.assert_array_equal(kernel(x, indices), x[indices]) + + def test_load_gather_2d(self): + x = jnp.arange(8 * 8).reshape(8, -1) + indices0 = indices1 = jax.random.permutation( + jax.random.key(42), jnp.arange(8) + ) + + @self.vector_subcore_kernel(out_shape=jax.ShapeDtypeStruct((8,), x.dtype)) + def kernel(x_ref, indices0_ref, indices1_ref, o_ref): + o_ref[...] = plsc.load_gather( + x_ref, [indices0_ref[...], indices1_ref[...]] + ) + + np.testing.assert_array_equal( + kernel(x, indices0, indices1), x[indices0, indices1] + ) + + def test_load_gather_with_indexing(self): + self.skip_if_tc_tiling() + num_steps = 4 + x = jnp.arange(num_steps * 8).reshape(num_steps, 8) + indices = jax.random.permutation(jax.random.key(42), jnp.arange(8)) + + @self.vector_subcore_kernel(out_shape=x) + def kernel(x_ref, indices_ref, o_ref): + indices = indices_ref[...] + for i in range(num_steps): + o_ref[i] = plsc.load_gather(x_ref.at[i], [indices]) + + out = kernel(x, indices) + for i in range(num_steps): + np.testing.assert_array_equal(out[i], x[i][indices]) + + @parameterized.parameters(*MASK_FNS) + def test_load_gather_masked(self, mask_fn): + x = jnp.arange(8) + indices = jax.random.permutation(jax.random.key(42), jnp.arange(8)) + + @self.vector_subcore_kernel(out_shape=x) + def kernel(x_ref, indices_ref, o_ref): + o_ref[...] = plsc.load_gather( + x_ref, [indices_ref[...]], mask=mask_fn(x_ref[...]) + ) + + mask = mask_fn(x) + np.testing.assert_array_equal(kernel(x, indices)[mask], x[indices][mask]) + + def test_store_scatter(self): + self.skip_if_tc_tiling() + num_steps = 4 + x = jnp.arange(num_steps * 8).reshape(num_steps, 8) + indices = jax.random.permutation(jax.random.key(42), jnp.arange(8)) + + @self.vector_subcore_kernel(out_shape=x) + def kernel(x_ref, indices_ref, o_ref): + indices = indices_ref[...] + o_ref[...] = jnp.zeros_like(o_ref) + for i in range(num_steps): + plsc.store_scatter(o_ref.at[i], [indices], x_ref[i]) + + out = kernel(x, indices) + for i in range(num_steps): + np.testing.assert_array_equal( + out[i], jnp.zeros_like(x[i]).at[indices].set(x[i]) + ) + + @parameterized.parameters(*MASK_FNS) + def test_store_scatter_masked(self, mask_fn): + x = jnp.arange(8) + indices = jax.random.permutation(jax.random.key(42), jnp.arange(8)) + + @self.vector_subcore_kernel(out_shape=x) + def kernel(x_ref, indices_ref, o_ref): + x = x_ref[...] + o_ref[...] = jnp.zeros_like(o_ref) + plsc.store_scatter(o_ref, [indices_ref[...]], x, mask=mask_fn(x)) + + mask = mask_fn(x) + np.testing.assert_array_equal( + kernel(x, indices), + jnp.zeros_like(x).at[indices[mask]].set(x[mask]), + ) + + def test_store_scatter_2d(self): + + num_steps = 4 + x = jnp.arange(num_steps * 8).reshape(num_steps, 8) + indices = jax.random.permutation(jax.random.key(42), jnp.arange(8)) + + @self.vector_subcore_kernel(out_shape=x) + def kernel(x_ref, indices_ref, o_ref): + indices = indices_ref[...] + o_ref[...] = jnp.zeros_like(o_ref) + for i in range(num_steps): + plsc.store_scatter( + o_ref, [jnp.full(indices.shape, i), indices], x_ref[i]) + + out = kernel(x, indices) + for i in range(num_steps): + np.testing.assert_array_equal( + out[i], jnp.zeros_like(x[i]).at[indices].set(x[i]) + ) + + @parameterized.parameters(*MASK_FNS) + def test_addupdate_scatter(self, mask_fn): + self.skip_if_tc_tiling() + x = jnp.arange(8) + indices = jax.random.permutation(jax.random.key(42), jnp.arange(8)) + + @self.vector_subcore_kernel(out_shape=x) + def kernel(x_ref, indices_ref, o_ref): + x = x_ref[...] + o_ref[...] = jnp.ones_like(o_ref) + plsc.addupdate_scatter(o_ref, [indices_ref[...]], x, mask=mask_fn(x)) + + mask = mask_fn(x) + np.testing.assert_array_equal( + kernel(x, indices), + jnp.ones_like(x).at[indices[mask]].add(x[mask]), + ) + + @parameterized.parameters(*MASK_FNS) + def test_load_expanded(self, mask_fn): + @self.vector_subcore_kernel( + out_shape=jax.ShapeDtypeStruct(shape=(8,), dtype=jnp.int32) + ) + def kernel(x_ref, o_ref): + o_ref[...] = plsc.load_expanded(x_ref.at[...], mask=mask_fn(x_ref[...])) + + x = jnp.arange(8) + mask = mask_fn(x) + expected = jnp.zeros_like(x).at[mask].set(x[: mask.sum()]) + np.testing.assert_array_equal(kernel(x)[mask], expected[mask]) + + @parameterized.parameters(*MASK_FNS) + def test_store_compressed(self, mask_fn): + @self.vector_subcore_kernel( + out_shape=jax.ShapeDtypeStruct(shape=(8,), dtype=jnp.int32) + ) + def kernel(x_ref, o_ref): + x = x_ref[...] + plsc.store_compressed(o_ref.at[...], x, mask=mask_fn(x)) + + x = jnp.arange(8) + mask = mask_fn(x) + np.testing.assert_array_equal(kernel(x)[: mask.sum()], x[mask]) + + def test_addupdate(self): + @self.vector_subcore_kernel( + out_shape=jax.ShapeDtypeStruct(shape=(8,), dtype=jnp.int32) + ) + def kernel(o_ref): + o_ref[...] = jnp.zeros_like(o_ref) + for i in range(8): + plsc.addupdate(o_ref.at[...], lax.broadcast(i, o_ref.shape)) + + np.testing.assert_array_equal(kernel(), jnp.full(8, jnp.arange(8).sum())) + + @parameterized.parameters(*MASK_FNS) + def test_addupdate_compressed(self, mask_fn): + @self.vector_subcore_kernel( + out_shape=jax.ShapeDtypeStruct(shape=(8,), dtype=jnp.int32) + ) + def kernel(x_ref, o_ref): + o_ref[...] = jnp.zeros_like(o_ref) + for i in range(8): + plsc.addupdate_compressed( + o_ref.at[...], + lax.broadcast(i, o_ref.shape), + mask=mask_fn(x_ref[...]), + ) + + x = jnp.arange(8) + mask = mask_fn(x) + np.testing.assert_array_equal( + kernel(x)[: mask.sum()], jnp.full(mask.sum(), jnp.arange(8).sum()) + ) + + @parameterized.product( + dtype=[jnp.int32], new_dtype=[jnp.int8, jnp.int16, jnp.float32] + ) + def test_bitcast(self, dtype, new_dtype): + self.skip_if_tc_tiling() + new_shape = ( + 8 * jnp.dtype(dtype).itemsize // jnp.dtype(new_dtype).itemsize, + ) + + @self.vector_subcore_kernel( + out_shape=jax.ShapeDtypeStruct(shape=new_shape, dtype=new_dtype) + ) + def kernel(x_ref, o_ref): + o_ref[...] = plsc.bitcast(x_ref[...], o_ref.dtype) + + x = jnp.arange(8, dtype=dtype) + np.testing.assert_array_equal(kernel(x), x.view(new_dtype)) + + def test_bitcast_invalid(self): + @self.vector_subcore_kernel( + out_shape=jax.ShapeDtypeStruct(shape=[1], dtype=jnp.int32) + ) + def kernel(x_ref, o_ref): + o_ref[...] = plsc.bitcast(x_ref[...], o_ref.dtype) + + x = jnp.arange(2, dtype=jnp.int8) + with self.assertRaisesRegex(ValueError, "is not divisible"): + kernel(x) + + def test_lax_bitcast(self): + @self.vector_subcore_kernel( + out_shape=jax.ShapeDtypeStruct((8,), jnp.uint32), + ) + def kernel(x_ref, o_ref): + o_ref[...] = x_ref[...].view(o_ref.dtype) + + x = jnp.arange(8, dtype=jnp.float32) + np.testing.assert_array_equal(kernel(x), x.view(np.uint32)) + + def test_ref_bitcast(self): + # TODO: b/443906446 - Remove the skip once we can lower such bitcasts. + self.skipTest("Ref bitcast is not supported yet") + + @self.vector_subcore_kernel( + out_shape=jax.ShapeDtypeStruct((8,), jnp.uint32), + ) + def kernel(x_ref, o_ref): + o_ref[...] = x_ref.bitcast(o_ref.dtype)[...] + + x = jnp.arange(8, dtype=jnp.float32) + np.testing.assert_array_equal(kernel(x), x.view(np.uint32)) + + @parameterized.product( + pack_format=[*plsc.PackFormat], + dtype=[jnp.float32, jnp.int32], + ) + def test_pack_unpack(self, pack_format, dtype): + shape = (8,) + + @self.vector_subcore_kernel( + out_shape=(jax.ShapeDtypeStruct((8,), dtype),) * 2 + ) + def kernel(a_ref, b_ref, oa_ref, ob_ref): + ab = plsc.pack(a_ref[...], b_ref[...], format=pack_format) + oa_ref[...], ob_ref[...] = plsc.unpack(ab, format=pack_format) + + a = jnp.arange(math.prod(shape), dtype=dtype).reshape(shape) + b = -a + out_a, out_b = kernel(a, b) + np.testing.assert_array_equal(out_a, a) + np.testing.assert_array_equal(out_b, b) + + @parameterized.parameters(jnp.int32, jnp.float32) + def test_scan_count(self, dtype): + shape = [8] + + @self.vector_subcore_kernel( + out_shape=( + jax.ShapeDtypeStruct(shape, jnp.int32), + jax.ShapeDtypeStruct(shape, jnp.int32), + ), + ) + def kernel(x_ref, counts_ref, mask_ref): + counts_ref[...], mask = plsc.scan_count(x_ref[...]) + mask_ref[...] = mask.astype(jnp.int32) + + key = jax.random.key(42) + x = jax.random.randint(key, shape, 0, 10, dtype=jnp.int32).astype(dtype) + counts, mask = kernel(x) + expected_counts = [] + expected_mask = [] + c = collections.Counter() + for item in x: + item = int(item) + c[item] += 1 + expected_counts.append(c[item]) + for item in x: + item = int(item) + c[item] -= 1 + expected_mask.append(c[item] == 0) + np.testing.assert_array_equal(counts, expected_counts) + np.testing.assert_array_equal(mask, expected_mask) + + def test_population_count(self): + key = jax.random.key(42) + x = jax.random.randint(key, [8], 0, 100) + + @self.vector_subcore_kernel(out_shape=x) + def kernel(x_ref, o_ref): + mask = x_ref[...] < 50 + # TODO: b/434208146 - Test with reduce!=1 when we support v6e packed masks + o_ref[...] = plsc.all_reduce_population_count(mask) + + np.testing.assert_array_equal( + kernel(x), np.broadcast_to(np.count_nonzero(x < 50), x.shape) + ) + + def test_iota(self): + key = jax.random.key(42) + x = jax.random.randint(key, [8], 0, 100) + + @self.vector_subcore_kernel(out_shape=x) + def kernel(x_ref, o_ref): + o_ref[...] = jnp.arange(8) + x_ref[...] + + np.testing.assert_array_equal( + kernel(x), x + np.arange(8) + ) + + def test_write_to_transformed_ref(self): + x = jnp.arange(16) + + @self.vector_subcore_kernel(out_shape=x) + def kernel(x_ref, o_ref): + plsc.store_compressed( + o_ref.at[pl.ds(5, 8)], x_ref[pl.ds(2, 8)], mask=jnp.ones(8, jnp.bool), + ) + np.testing.assert_array_equal(kernel(x)[5:13], x[2:10]) + + def test_load_transformed_ref(self): + x = jnp.arange(16) + + @self.vector_subcore_kernel(out_shape=x) + def kernel(x_ref, o_ref): + o_ref[pl.ds(5, 8)] = plsc.load_expanded( + x_ref.at[pl.ds(2, 8)], mask=jnp.arange(8) % 2 == 0) + np.testing.assert_array_equal(kernel(x)[5:13:2], x[2:6]) + + def test_scalar_load_store(self): + + @self.vector_subcore_kernel( + in_specs=(pl.BlockSpec(memory_space=pltpu.HBM),), + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), + out_shape=jax.ShapeDtypeStruct((8,), jnp.int32), + scratch_shapes=(pltpu.VMEM((1,), jnp.int32),), + ) + def kernel(x_ref, o_ref, tmp_ref): + pltpu.sync_copy(x_ref, tmp_ref) + o_ref[...] = lax.broadcast(tmp_ref[0], o_ref.shape) + + np.testing.assert_array_equal( + kernel(jnp.ones((1,), jnp.int32)), jnp.ones((8,), jnp.int32) + ) + + def test_scalar_load_hbm(self): + + @self.vector_subcore_kernel( + in_specs=(pl.BlockSpec(memory_space=pltpu.HBM),), + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), + out_shape=jax.ShapeDtypeStruct((8,), jnp.int32), + ) + def kernel(x_ref, o_ref): + o_ref[...] = lax.broadcast(x_ref[0], o_ref.shape) + + with self.assertRaisesRegex( + NotImplementedError, "Get does not support loading from HBM" + ): + _ = kernel(jnp.ones((1,), jnp.int32)) + + @parameterized.named_parameters( + ("mixed", [0, 0, 1, 0, 1, 0, 0, 0], 2), + ("all_zero", [0, 0, 0, 0, 0, 0, 0, 0], 8), + ("all_one", [1, 1, 1, 1, 1, 1, 1, 1], 0)) + def test_ffs(self, data, expected): + x = jnp.array(data) + + @self.vector_subcore_kernel(out_shape=x) + def kernel(x_ref, o_ref): + mask = x_ref[...] == 1 + # TODO: b/434208146 - Test with reduce!=1 when we support v6e packed masks + o_ref[...] = plsc.all_reduce_ffs(mask) + + np.testing.assert_array_equal(kernel(x), np.broadcast_to(expected, x.shape)) + + def test_run_scoped(self): + x = jnp.arange(8) + + @self.vector_subcore_kernel( + out_shape=x, out_specs=pl.BlockSpec(memory_space=pltpu.HBM) + ) + def kernel(x_ref, o_hbm_ref): + pltpu.sync_copy(x_ref, o_hbm_ref) + + np.testing.assert_array_equal(kernel(x), x) + + def test_run_scoped_with_tiling(self): + x = jnp.arange(2 * 8).reshape(-1, 8) + + @self.vector_subcore_kernel(out_shape=x) + def kernel(x_ref, o_ref): + def scoped_kernel(scratch_ref): + scratch_ref[...] = x_ref[...] + o_ref[...] = scratch_ref[...] + + pl.run_scoped( + scoped_kernel, + plsc.MemoryRef( + x.shape, x_ref.dtype, memory_space=pltpu.VMEM, tiling=[(1, 8)] + ), + ) + + # Just make sure it compiles. The unrolling logic in the SC compiler + # does not yet handle tiled layouts properly, so the result is wrong. + _ = kernel(x) + + @parameterized.product(sizes=[[1, 1], [2, 2], [1, 1, 1, 1]]) + def test_split_concatenate(self, sizes): + + shape = (sum(sizes), 8) + x = jnp.arange(math.prod(shape)).reshape(-1, 8) + + @self.vector_subcore_kernel(out_shape=x) + def kernel(x_ref, o_ref): + chunks = lax.split(x_ref[...], sizes, 0) + o_ref[...] = lax.concatenate(chunks, 0) + + np.testing.assert_array_equal(kernel(x), x) + + def test_scratch(self): + x = jnp.arange(8) + + @self.vector_subcore_kernel( + out_shape=x, + scratch_shapes=(pltpu.VMEM([8], jnp.float32),), + ) + def kernel(x_ref, o_ref, scratch_ref): + scratch_ref[...] = x_ref[...].astype(jnp.float32) + o_ref[...] = scratch_ref[...].astype(x.dtype) + + np.testing.assert_array_equal(kernel(x), x) + + def test_implicit_padding_unsupported(self): + x = jnp.arange(8, dtype=jnp.int32).reshape((8, 1)) + + @self.vector_subcore_kernel(out_shape=x, in_specs=(pl.BlockSpec((8, 1)),)) + def kernel(*args): + del args # Unused. + + with self.assertRaisesRegex(ValueError, "must be a multiple of 8"): + kernel(x) + + def test_subcore_parallel(self): + self.skip_if_tc_tiling() + num_subcores = 16 + + @self.kernel( + out_shape=jax.ShapeDtypeStruct( + shape=(num_subcores, 8), dtype=jnp.int32 + ), + mesh=plsc.VectorSubcoreMesh( + core_axis_name="core", subcore_axis_name="subcore", num_cores=1 + ), + ) + def kernel(x_ref, o_ref): + # This is a smoke test, since it does not in fact check that the kernel + # is executed in parallel over the subcores. + subcore_id = lax.axis_index("subcore") + pltpu.sync_copy(x_ref.at[subcore_id], o_ref.at[subcore_id]) + + x = jnp.arange(num_subcores * 8, dtype=jnp.int32).reshape(-1, 8) + np.testing.assert_array_equal(kernel(x), x) + + def test_smem_vmem_store_literals(self): + self.skip_if_tc_tiling() + num_subcores = 16 + + @self.kernel( + out_shape=jax.ShapeDtypeStruct( + shape=(num_subcores, 8), dtype=jnp.float32 + ), + mesh=plsc.VectorSubcoreMesh( + core_axis_name="core", subcore_axis_name="subcore", num_cores=1 + ), + scratch_shapes=(pltpu.SMEM([1], jnp.float32), + pltpu.VMEM([8], jnp.float32)), + ) + def kernel(x_ref, o_ref, scratch_scalar_ref, scratch_vec_ref): + subcore_id = lax.axis_index("subcore") + scratch_scalar_ref[0] = 7. + pltpu.sync_copy(x_ref.at[subcore_id], scratch_vec_ref) + scratch_vec_ref[...] = jnp.where( + subcore_id < 3, scratch_scalar_ref[0], scratch_vec_ref[...]) + pltpu.sync_copy(scratch_vec_ref, o_ref.at[subcore_id]) + + x = jnp.arange(num_subcores * 8, dtype=jnp.float32).reshape(-1, 8) + expected = jnp.where(jnp.arange(num_subcores)[:, jnp.newaxis] < 3, + jnp.full((num_subcores, 8), 7.), + x) + np.testing.assert_array_equal(kernel(x), expected) + + @parameterized.named_parameters( + ("barrier", lambda _: plsc.subcore_barrier()), + ("debug_print", lambda vec: pl.debug_print('test', vec)), + ) + def test_effect_discharge(self, effectful_op): + x = jnp.arange(self.sc_info.num_lanes) + mesh = plsc.VectorSubcoreMesh( + core_axis_name="core", subcore_axis_name="subcore", num_cores=1 + ) + def stateful(refs): + def body(x_ref, o_ref): + def with_scratch(scratch_ref): + pltpu.sync_copy(x_ref, scratch_ref) + scratch_ref[...] = scratch_ref[...] + 1 + effectful_op(scratch_ref[...]) + pltpu.sync_copy(scratch_ref, o_ref) + pl.run_scoped(with_scratch, pltpu.VMEM(x.shape, x.dtype)) + pl.core_map(mesh)(lambda: body(*refs)) + + _, out = jax.jit(state_discharge.run_state(stateful))( + (x, jnp.empty_like(x))) + np.testing.assert_array_equal(out, x + 1) + + def test_parallel_loop_effects(self): + chunk_size = 8 + + @self.kernel( + out_shape=(), + mesh=plsc.VectorSubcoreMesh( + core_axis_name="core", subcore_axis_name="subcore", num_cores=1 + ), + scratch_shapes=(pltpu.VMEM((chunk_size,), jnp.uint32),) * 3, + ) + def _kernel(a_ref, b_ref, c_ref): + @pl.loop(0, 4) + def outer(i): + const = jnp.array(0, jnp.uint32) + + @plsc.parallel_loop(0, chunk_size) + def body(_): + x = a_ref[...] >> i.astype(jnp.uint32) + plsc.store_compressed(c_ref.at[...], b_ref[...], mask=x > const) + + _kernel() + + def test_reshape(self): + shape = (8,) + dtype = jnp.int32 + + @self.vector_subcore_kernel(out_shape=jax.ShapeDtypeStruct(shape, dtype)) + def kernel(x_ref, o_ref): + o_ref[...] = x_ref[...].reshape(2, 4).reshape(8) + + x = jnp.arange(math.prod(shape), dtype=dtype).reshape(shape) + np.testing.assert_array_equal(kernel(x), x) + + @parameterized.product(dtype=[jnp.int32, jnp.float32]) + def test_cumsum(self, dtype): + x = jnp.arange(self.sc_info.num_lanes, dtype=dtype) + + @self.vector_subcore_kernel(out_shape=x) + def kernel(x_ref, o_ref): + o_ref[...] = jnp.cumsum(x_ref[...]) + + np.testing.assert_array_equal(kernel(x), np.cumsum(x)) + + @parameterized.product(dtype=[jnp.int32, jnp.float32], op=[jnp.sum, jnp.max]) + def test_reductions(self, dtype, op): + x = jnp.arange(self.sc_info.num_lanes, dtype=dtype) + @self.vector_subcore_kernel(out_shape=x) + def kernel(x_ref, o_ref): + o_ref[...] = jnp.full(o_ref.shape, op(x_ref[...])) + np.testing.assert_array_equal(kernel(x)[0], op(x)) + + @parameterized.product(dtype=[jnp.int32, jnp.float32]) + def test_cumsum_2d_not_supported(self, dtype): + x = jnp.arange(self.sc_info.num_lanes, dtype=dtype) + + with self.assertRaisesRegex(NotImplementedError, r"must be rank 1"): + @self.vector_subcore_kernel(out_shape=x) + def kernel(x_ref, o_ref): + o_ref[...] = jnp.cumsum(x_ref[...].reshape(4, 2), axis=0).reshape(-1) + + kernel(x) + + @parameterized.product(dtype=[jnp.int32, jnp.float32]) + def test_masked_cumsum(self, dtype): + x = jnp.arange(self.sc_info.num_lanes, dtype=dtype) + + @self.vector_subcore_kernel(out_shape=x) + def kernel(x_ref, o_ref): + o_ref[...] = plsc.cumsum(x_ref[...], mask=(x_ref[...] % 2) == 1) + + np.testing.assert_array_equal(kernel(x), np.cumsum(x * (x % 2))) + + @parameterized.product(dtype=[jnp.int32, jnp.float32]) + def test_masked_cummax(self, dtype): + x = np.arange(self.sc_info.num_lanes, dtype=dtype) + np.random.shuffle(x) + + @self.vector_subcore_kernel(out_shape=x) + def kernel(x_ref, o_ref): + o_ref[...] = plsc.cummax(x_ref[...], mask=(x_ref[...] % 2) == 1) + + row = np.arange(self.sc_info.num_lanes)[:, np.newaxis] + col = np.arange(self.sc_info.num_lanes)[np.newaxis, :] + mask = x % 2 + expected = (x * mask * (col <= row)).max(axis=1) + has_valid_value_so_far = np.cumsum(mask) > 0 + expected = np.where(has_valid_value_so_far, expected, x) + np.testing.assert_array_equal(kernel(x), expected) + + def test_parallel_loop_with_carry(self): + chunk_size = self.sc_info.num_lanes + nchunks = 4 + per_step_increment = 10 + sentinel_multiplier = 1000 + x = jnp.arange(16 * chunk_size * nchunks, dtype=np.int32) + + @self.vector_subcore_kernel( + out_shape=x, + grid=(16,), + in_specs=[pl.BlockSpec([chunk_size * nchunks], lambda i: (i,))], + out_specs=pl.BlockSpec([chunk_size * nchunks], lambda i: (i,)), + ) + def kernel(x_ref, o_ref): + @pl.when(pl.program_id(0) < 16) + def _(): + init = (jnp.zeros([], x_ref.dtype), # scalar + jnp.zeros([chunk_size], x_ref.dtype), # vector + ) + def for_each_chunk(i, carry): + incr, running_sum = carry + incr += per_step_increment + o_ref[pl.ds(i, chunk_size)] = x_ref[pl.ds(i, chunk_size)] + incr + return incr, running_sum + x_ref[pl.ds(i, chunk_size)] + result = plsc.parallel_loop(0, x_ref.shape[0], chunk_size, carry=init)( + for_each_chunk) + o_ref[pl.ds(0, chunk_size)] = jnp.where( + jnp.arange(chunk_size) == 0, + result[0] * sentinel_multiplier, + result[1]) + + output = kernel(x) + expected = np.array(x).reshape(16, nchunks, chunk_size) + # Check that the increment was properly applied. + expected += 10 * np.arange(1, 5)[:, None] + # Check the final carry values: + # - Scalar in 0th position. + expected[:, 0, 0] = sentinel_multiplier * per_step_increment * nchunks + # - Vector in 1:chunk_size-1 positions. + expected[:, 0, 1:] = x.reshape(16, nchunks, chunk_size).sum(1)[:, 1:] + np.testing.assert_array_equal(output, expected.reshape(-1)) + + @parameterized.parameters( + (lambda x_ref: x_ref, r"may not be.*Ref\{"), + (lambda x_ref: x_ref.at[pl.ds(0, 8)], r"TransformedRef.*not a valid"), + ) + def test_parallel_loop_disallows_ref_carries(self, carry_fn, expected_regex): + x = jnp.arange(64, dtype=jnp.int32) + + with self.assertRaisesRegex(TypeError, expected_regex): + @self.vector_subcore_kernel(out_shape=x) + def kernel(x_ref, o_ref): + @plsc.parallel_loop(0, 1, carry=carry_fn(x_ref)) + def _(i, carry): + del i # Unused. + x_ref[...] = o_ref[...] + return carry + + kernel(x) + + def test_parallel_loop_wrong_carry_return(self): + x = jnp.arange(64, dtype=jnp.int32) + + with self.assertRaisesRegex(ValueError, "should have same structure"): + @self.vector_subcore_kernel(out_shape=x) + def kernel(x_ref, o_ref): + init = dict(x=jnp.zeros([]), y=jnp.ones([8])) + @plsc.parallel_loop(0, 1, carry=init) + def _(i, carry): + del i # Unused. + x_ref[...] = o_ref[...] + return carry["x"] + + kernel(x) + + def test_squeezed_blockspec_error_message(self): + shape = (16, 8, 32) + spec_shape = (pl.squeezed, 8, 32) + x = jnp.arange(np.prod(shape), dtype=jnp.int32).reshape(*shape) + + @self.vector_subcore_kernel( + out_shape=x, + grid=16, + in_specs=[pl.BlockSpec(spec_shape, lambda i: (i, 0, 0))], + out_specs=pl.BlockSpec(spec_shape, lambda i: (i, 0, 0)), + ) + def kernel(x_ref, o_ref): + del x_ref, o_ref # Unused. + + with self.assertRaisesRegex( + NotImplementedError, r"Unsupported block dimension type.*Squeezed"): + kernel(x) + + def test_multiple_of(self): + x = jnp.arange(16) + + @self.vector_subcore_kernel(out_shape=x) + def kernel(x_ref, o_ref): + @pl.loop(0, 16, step=8) + def _(i): + i = pl.multiple_of(i, 8) + o_ref[pl.ds(i, 8)] = x_ref[pl.ds(i, 8)] + 1 + + np.testing.assert_array_equal(kernel(x), x + 1) + + def test_barrier_via_mesh(self): + self.skip_if_tc_tiling() + mesh = plsc.VectorSubcoreMesh( + core_axis_name="core", subcore_axis_name="subcore", num_cores=1 + ) + vec_dim = self.sc_info.num_lanes + @self.kernel( + out_shape=jax.ShapeDtypeStruct( + shape=(mesh.num_subcores, vec_dim), dtype=jnp.uint32 + ), + mesh=mesh, + scratch_shapes=[pltpu.VMEM((mesh.num_subcores, vec_dim), jnp.uint32)], + ) + def kernel(o_ref, vmem_ref): + subcore_id = lax.axis_index("subcore") + @pl.loop(0, 2 * subcore_id + 1) + def _(i): + vmem_ref[subcore_id] = jnp.full(vec_dim, i, dtype=jnp.uint32) + pltpu.sync_copy(vmem_ref.at[subcore_id], o_ref.at[subcore_id]) + plsc.subcore_barrier() + pltpu.sync_copy(o_ref.at[(subcore_id + 1) % mesh.num_subcores], + vmem_ref.at[subcore_id]) + pltpu.sync_copy(vmem_ref.at[subcore_id], o_ref.at[subcore_id]) + expected = 2 * jnp.roll(jnp.arange(mesh.num_subcores), -1) + expected = jnp.broadcast_to(expected[:, None], (mesh.num_subcores, vec_dim)) + np.testing.assert_array_equal(kernel(), expected) + + def test_barrier_via_pallas_call(self): + self.skip_if_tc_tiling() + + mesh = plsc.VectorSubcoreMesh( + core_axis_name="core", subcore_axis_name="subcore", num_cores=1 + ) + vec_dim = self.sc_info.num_lanes + @functools.partial( + pl.pallas_call, + grid=16, + compiler_params=pltpu.CompilerParams( + kernel_type=pltpu.KernelType.SC_VECTOR_SUBCORE, + dimension_semantics=["subcore_parallel"], + use_tc_tiling_on_sc=self.USE_TC_TILING, + ), + out_shape=jax.ShapeDtypeStruct( + shape=(mesh.num_subcores, vec_dim), dtype=jnp.uint32 + ), + out_specs=pl.BlockSpec((1, vec_dim), lambda i: (i, 0)), + scratch_shapes=( + pltpu.VMEM_SHARED((mesh.num_subcores, vec_dim), jnp.uint32), + pltpu.VMEM((vec_dim,), jnp.uint32), + ), + ) + def kernel(o_ref, shared_ref, vmem_ref): + subcore_id = pl.program_id(0) + @pl.loop(0, 10 * subcore_id + 1) + def _(i): + vmem_ref[:] = jnp.full(vec_dim, i, dtype=jnp.uint32) + pltpu.sync_copy(vmem_ref, shared_ref.at[subcore_id]) + plsc.subcore_barrier() + pltpu.sync_copy(shared_ref.at[(subcore_id + 1) % mesh.num_subcores], + o_ref.at[0]) + expected = 10 * jnp.roll(jnp.arange(mesh.num_subcores), -1) + expected = jnp.broadcast_to(expected[:, None], (mesh.num_subcores, vec_dim)) + np.testing.assert_array_equal(kernel(), expected) + + @parameterized.parameters(jnp.int32, jnp.float32) + def test_gather_add(self, dtype): + """Gather from HBM at indices added to contiguous VMEM.""" + self.skip_if_tc_tiling() + shape = (16, 64, 32) + x = jnp.arange(np.prod(shape), dtype=dtype).reshape(*shape) + + @self.kernel( + out_shape=x[:, :8], + mesh=plsc.VectorSubcoreMesh( + core_axis_name="core", subcore_axis_name="subcore", num_cores=1 + ), + scratch_shapes=[ + pltpu.VMEM([8], jnp.int32), + pltpu.VMEM([8, 32], dtype), + pltpu.SemaphoreType.DMA, + ], + ) + def kernel(x_ref, indices_ref, o_ref, indices_vmem, scratch_ref, sem): + subcore_id = lax.axis_index("subcore") + pltpu.sync_copy(indices_ref, indices_vmem) + # Initialize scratch space. + pltpu.sync_copy(x_ref.at[subcore_id, pl.ds(0, 8)], scratch_ref) + # Gather-add selected indices to scratch. + pltpu.async_copy( + # TODO: Can't mix array and ref indexers .at[subcore_id, indices_vmem] + x_ref.at[subcore_id].at[indices_vmem], + scratch_ref, + sem, + add=True, + ).wait() + pltpu.sync_copy(scratch_ref, o_ref.at[subcore_id]) + + indices = jnp.arange(8) * 8 + np.testing.assert_array_equal( + kernel(x, indices), x[:, :8] + x[:, indices]) + + @parameterized.parameters(jnp.int32, jnp.float32) + def test_scatter_add(self, dtype): + """Scatter from contiguous VMEM added to VMEM_SHARED at indices.""" + self.skip_if_tc_tiling() + shape = (16, 32) + x = jnp.arange(np.prod(shape), dtype=dtype).reshape(*shape) + + mesh = plsc.VectorSubcoreMesh( + core_axis_name="core", subcore_axis_name="subcore", num_cores=1 + ) + @functools.partial( + pl.pallas_call, + grid=mesh.num_subcores, + compiler_params=pltpu.CompilerParams( + kernel_type=pltpu.KernelType.SC_VECTOR_SUBCORE, + dimension_semantics=["subcore_parallel"], + use_tc_tiling_on_sc=self.USE_TC_TILING, + ), + out_shape=jax.ShapeDtypeStruct(shape[1:], dtype), + out_specs=pl.BlockSpec( + shape[1:], lambda i: (0,), memory_space=pltpu.HBM + ), + in_specs=[ + pl.BlockSpec(shape, lambda *_: (0, 0), memory_space=pltpu.HBM), + pl.BlockSpec(shape[1:], lambda _: (0,)), + ], + scratch_shapes=[ + pltpu.VMEM_SHARED(shape[1:], dtype), + pltpu.VMEM(shape[1:], dtype), + pltpu.SemaphoreType.DMA, + ], + ) + def kernel(x_ref, indices_ref, o_ref, + shared_scratch_ref, scratch_ref, sem): + subcore_id = pl.program_id(0) + pltpu.sync_copy(x_ref.at[subcore_id], scratch_ref) + # Subcore 0 to init shared scratch. + @pl.when(subcore_id == 0) + def _(): + pltpu.sync_copy(scratch_ref, shared_scratch_ref) + plsc.subcore_barrier() + # All cores to add their slice to shared scratch. + pltpu.async_copy( + scratch_ref, + shared_scratch_ref.at[indices_ref], + sem, + add=True, + ).wait() + plsc.subcore_barrier() + # Subcore 0 to copy shared scratch to output. + @pl.when(subcore_id == 0) + def _(): + pltpu.sync_copy(shared_scratch_ref, scratch_ref) + pltpu.sync_copy(scratch_ref, o_ref) + + indices = 31 - jnp.arange(32) + np.testing.assert_array_equal(kernel(x, indices), x[0] + x.sum(0)[::-1]) + + def test_shared_scratch(self): + self.skip_if_tc_tiling() + mesh = plsc.VectorSubcoreMesh( + core_axis_name="core", subcore_axis_name="subcore", num_cores=1 + ) + shape = (mesh.num_subcores, 8, 8) + x = jnp.arange(np.prod(shape), dtype=jnp.int32).reshape(*shape) + + @self.kernel(out_shape=x, mesh=mesh) + def kernel(x_ref, o_ref): + subcore_id = lax.axis_index("subcore") + shared_scratch_ref = pl.get_global( + pltpu.VMEM_SHARED(shape[1:], jnp.int32)) + @pl.when(subcore_id == 0) + def _(): + shared_scratch_ref2 = pl.get_global(pltpu.VMEM_SHARED(shape, jnp.int32)) + pltpu.sync_copy( + x_ref.at[subcore_id], shared_scratch_ref2.at[subcore_id]) + pltpu.sync_copy(x_ref.at[subcore_id], shared_scratch_ref) + pltpu.sync_copy(shared_scratch_ref, o_ref.at[subcore_id]) + + np.testing.assert_array_equal(kernel(x)[0], x[0]) + + def test_copy_in_shard_map(self): + self.skip_if_tc_tiling() + num_devices = len(jax.devices()) + mesh = jtu.create_mesh((num_devices,), ("x",)) + + rng = np.random.default_rng(0) + x = rng.integers(512, size=(num_devices * 1024, 16), dtype=np.int32) + + # The test ensures that JAX-level memory space for ``x`` is not propagated + # into Pallas, since Pallas cannot use it. + x = jax.device_put(x, jax.sharding.NamedSharding(mesh, jax.P("x", None))) + self.assertEqual(jax.typeof(x).memory_space, jax.memory.Space.Device) + + @functools.partial( + jax.shard_map, + in_specs=(jax.P("x", None),), + out_specs=jax.P("x", None), + mesh=mesh, + check_vma=True, + ) + def f(x): + @self.kernel( + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype, vma={"x"}), + mesh=plsc.VectorSubcoreMesh( + core_axis_name="core", subcore_axis_name="subcore", num_cores=1 + ), + scratch_shapes=(pltpu.VMEM(x.shape, x.dtype),), + ) + def kernel(in_ref, o_ref, scratch_ref): + pltpu.sync_copy(in_ref, scratch_ref) + pltpu.sync_copy(scratch_ref, o_ref) + + return kernel(x) + + np.testing.assert_array_equal(f(x), x) + + @parameterized.named_parameters( + ("exp", jnp.exp), ("neg", lambda x: -x), ("abs", jnp.abs) + ) + def test_unary_ops(self, op): + x = jnp.arange(8, dtype=jnp.float32) + + @self.vector_subcore_kernel(out_shape=x) + def kernel(x_ref, o_ref): + o_ref[...] = op(x_ref[...]) + + np.testing.assert_array_equal(kernel(x), op(x)) + + @parameterized.product(dtype=[np.int32, np.float32]) + def test_vector_gather(self, dtype): + vec_dim = self.sc_info.num_lanes + x = np.arange(vec_dim, dtype=dtype) + indices = np.random.randint(0, vec_dim, size=vec_dim, dtype=np.int32) + indices[[0, -2]] = 2 # Verify non-unique works. + indices[1] = -2 # Verify negative indices work. + + @self.vector_subcore_kernel(out_shape=x) + def kernel(x_ref, indices_ref, out_ref): + out_ref[...] = x_ref[...][indices_ref[...]] + + np.testing.assert_array_equal(kernel(x, indices), x[indices]) + + @parameterized.product( + keys_dtype=[np.int32, np.float32], + values_dtype=[np.int32, np.float32], + use_mask=[False, True], + descending=[False, True], + ) + def test_sort_key_val(self, keys_dtype, values_dtype, use_mask, descending): + vec_dim = self.sc_info.num_lanes + keys = np.arange(vec_dim, dtype=keys_dtype) + np.random.shuffle(keys) + keys[3] = keys[1] # Verify sort stability. + values = np.arange(vec_dim, dtype=values_dtype) + np.random.shuffle(values) + mask = np.random.choice([True, False], size=vec_dim) if use_mask else None + maybe_mask_arg = (mask.astype(jnp.int32),) if use_mask else () + + @self.vector_subcore_kernel(out_shape=(keys, values, *maybe_mask_arg)) + def kernel(*args): + if use_mask: + mask_ref, *args, o_mask_ref = args + mask = mask_ref[...].astype(jnp.bool) + else: + mask, o_mask_ref = None, None + keys_ref, values_ref, o_keys_ref, o_vals_ref = args + o_keys_ref[...], o_vals_ref[...], *maybe_out_mask = plsc.sort_key_val( + keys_ref[...], values_ref[...], mask=mask, descending=descending) + if use_mask: + [out_mask] = maybe_out_mask + o_mask_ref[...] = out_mask.astype(jnp.int32) + + out_keys, out_values, *maybe_out_mask = kernel( + *maybe_mask_arg, keys, values) + + keys_arg = keys + if descending: + keys_arg = -keys_arg + if use_mask: + keys_arg = jnp.where(mask, keys_arg, 100) + _, gt_keys = jax.lax.sort_key_val(keys_arg, keys) + _, gt_values = jax.lax.sort_key_val(keys_arg, values) + if use_mask: + [out_mask] = maybe_out_mask + gt_out_mask = jnp.arange(vec_dim) < mask.sum() + np.testing.assert_array_equal(out_mask, gt_out_mask.astype(jnp.int32)) + np.testing.assert_array_equal(out_keys, gt_keys) + np.testing.assert_array_equal(out_values, gt_values) + + @parameterized.product(dtype=[np.int32, np.float32]) + def test_rev_and_sort_desc(self, dtype): + vec_dim = self.sc_info.num_lanes + keys = np.arange(vec_dim, dtype=dtype) + np.random.shuffle(keys) + + @self.vector_subcore_kernel(out_shape=(keys, keys)) + def kernel(x_ref, o1_ref, o2_ref): + o1_ref[...] = jnp.sort(x_ref[...], descending=True) + o2_ref[...] = jnp.flip(x_ref[...], axis=-1) + + sorted_desc, reversed_keys = kernel(keys) # pylint: disable=unpacking-non-sequence + np.testing.assert_array_equal( + sorted_desc, jnp.arange(vec_dim, dtype=dtype)[::-1]) + np.testing.assert_array_equal(reversed_keys, keys[::-1]) + + @parameterized.product( + keys_dtype=[np.int32, np.float32], + values_dtypes=[(), (np.int32,), (np.float32, np.int32)], + ) + def test_sort(self, keys_dtype, values_dtypes): + vec_dim = self.sc_info.num_lanes + keys = np.arange(vec_dim, dtype=keys_dtype) + np.random.shuffle(keys) + values = [np.arange(vec_dim, dtype=dtype) for dtype in values_dtypes] + _ = [np.random.shuffle(v) for v in values] + + @self.vector_subcore_kernel(out_shape=(keys, *values)) + def kernel(*args): + keys_ref, *values_refs = args[: len(args) // 2] + keys_out, *all_values_out = jax.lax.sort( + (keys_ref[...], *(ref[...] for ref in values_refs)) + ) + keys_out_ref, *values_out_refs = args[len(args) // 2 :] + keys_out_ref[...] = keys_out + for values_out_ref, values_out in zip( + values_out_refs, all_values_out, strict=True + ): + values_out_ref[...] = values_out + + perm = np.argsort(keys) + keys_result, *values_results = kernel(keys, *values) + np.testing.assert_array_equal(keys_result, keys[perm]) + for values_result, values_in in zip(values_results, values, strict=True): + np.testing.assert_array_equal(values_result, values_in[perm]) + + +class VectorSubcoreTestWithTCTiling(VectorSubcoreTest): + USE_TC_TILING = True + + +class ScalarSubcoreTest(PallasSCTest): + + def test_copy(self): + x = jnp.arange(16) + + @self.kernel( + out_shape=x, + mesh=plsc.ScalarSubcoreMesh( + axis_name="core", num_cores=self.sc_info.num_cores + ), + ) + def kernel(x_ref, o_ref): + lax.cond( + lax.axis_index("core") == lax.axis_size("core") - 1, + lambda: pltpu.sync_copy(x_ref, o_ref), + lambda: None, + ) + + np.testing.assert_array_equal(kernel(x), x) + + def test_sliced_copy(self): + self.skip_if_tc_tiling() + x = jnp.arange(self.sc_info.num_cores * 8).reshape( + self.sc_info.num_cores, -1 + ) + + @self.kernel( + out_shape=x, + mesh=plsc.ScalarSubcoreMesh( + axis_name="core", num_cores=self.sc_info.num_cores + ), + ) + def kernel(x_ref, o_ref): + @functools.partial(pl.run_scoped, sems=pltpu.SemaphoreType.DMA(4)) + def _(sems): + core_id = lax.axis_index("core") + pltpu.async_copy( + x_ref.at[core_id], o_ref.at[core_id], sems.at[core_id] + ).wait() + + np.testing.assert_array_equal(kernel(x), x) + + def test_scalar_load_store(self): + x = jnp.arange(8) + + @self.kernel( + out_shape=x, mesh=plsc.ScalarSubcoreMesh(axis_name="core", num_cores=1) + ) + def kernel(x_ref, o_ref): + @functools.partial( + pl.run_scoped, + tmp_ref=pltpu.SMEM(x.shape, x.dtype), + sem=pltpu.SemaphoreType.DMA, + ) + def _(tmp_ref, sem): + pltpu.async_copy(x_ref, tmp_ref, sem).wait() + + @pl.loop(1, *x.shape) + def _(i): + tmp_ref[i] += tmp_ref[i - 1] + + pltpu.async_copy(tmp_ref, o_ref, sem).wait() + + np.testing.assert_array_equal(kernel(x), jnp.cumsum(x)) + + @parameterized.product( + first_parallel=[False, True], second_parallel=[False, True] + ) + def test_parallel_loop(self, first_parallel, second_parallel): + self.skip_if_tc_tiling() + x = jnp.arange(8*8).reshape(8, 8) + + loop = lambda start, end, parallel, **kwargs: ( + plsc.parallel_loop(start, end, **kwargs) + if parallel + else pl.loop(start, end, **kwargs) + ) + + @self.kernel( + out_shape=x, + mesh=plsc.ScalarSubcoreMesh(axis_name="core", num_cores=1), + scratch_shapes=( + pltpu.SMEM(x.shape, x.dtype), + pltpu.SemaphoreType.DMA, + ), + ) + def kernel(x_ref, o_ref, tmp_ref, sem): + pltpu.async_copy(x_ref, tmp_ref, sem).wait() + + @loop(0, tmp_ref.shape[0], first_parallel) + def _(i): + @loop(0, tmp_ref.shape[1], second_parallel, unroll=tmp_ref.shape[1]) + def _(j): + tmp_ref[i, j] += 1 + + pltpu.async_copy(tmp_ref, o_ref, sem).wait() + + np.testing.assert_array_equal(kernel(x), x + 1) + + def test_parallel_loop_with_carry(self): + self.skip_if_tc_tiling() + x = jnp.arange(8*8).reshape(8, 8) + + @self.kernel( + out_shape=x, + mesh=plsc.ScalarSubcoreMesh(axis_name="core", num_cores=1), + scratch_shapes=( + pltpu.SMEM(x.shape, x.dtype), + pltpu.SemaphoreType.DMA, + ), + ) + def kernel(x_ref, o_ref, tmp_ref, sem): + pltpu.async_copy(x_ref, tmp_ref, sem).wait() + + @plsc.parallel_loop(0, tmp_ref.shape[0], carry=jnp.zeros([], x.dtype)) + def _(i, carry): + carry += 1 + @plsc.parallel_loop(0, tmp_ref.shape[1], unroll=2) + def _(j): + tmp_ref[i, j] += carry + return carry + + pltpu.async_copy(tmp_ref, o_ref, sem).wait() + + np.testing.assert_array_equal(kernel(x), x + jnp.arange(1, 9)[:, None]) + + +class ScalarSubcoreTestWithTCTiling(ScalarSubcoreTest): + USE_TC_TILING = True + + +class PipelineTest(PallasSCTest): + + def test_basic(self): + self.skip_if_tc_tiling() + num_steps = 16 + x = jnp.arange(num_steps * 8).reshape(-1, 8) + + @self.vector_subcore_kernel( + out_shape=x, + in_specs=(pl.BlockSpec(memory_space=pltpu.HBM),), + out_specs=pl.BlockSpec(memory_space=pltpu.HBM), + ) + def kernel(x_hbm_ref, o_hbm_ref): + @functools.partial( + pltpu.emit_pipeline, + grid=(num_steps // 2,), + in_specs=pl.BlockSpec((2, 8), lambda i: (i, 0)), + out_specs=pl.BlockSpec((2, 8), lambda i: (i, 0)), + ) + def pipeline(x_ref, o_ref): + o_ref[...] = x_ref[...] + 1 + + pipeline(x_hbm_ref, o_hbm_ref) + + np.testing.assert_array_equal(kernel(x), x + 1) + + def test_explicit_sc_tiling_1d(self): + self.skip_if_tc_tiling("The test uses SC tiling.") + + num_steps = 4 + x = jnp.arange(num_steps * 8) + + @self.vector_subcore_kernel( + out_shape=x, + in_specs=(pl.BlockSpec(memory_space=pltpu.HBM),), + out_specs=pl.BlockSpec(memory_space=pltpu.HBM), + ) + def kernel(x_hbm_ref, o_hbm_ref): + spec = plsc.BlockSpec((8,), lambda i: (i,)) + + @functools.partial( + pltpu.emit_pipeline, + grid=(num_steps,), + in_specs=spec, + out_specs=spec, + tiling=pltpu.Tiling.SPARSE_CORE, + ) + def pipeline(x_ref, o_ref): + o_ref[...] = x_ref[...] + 1 + + pipeline(x_hbm_ref, o_hbm_ref) + + np.testing.assert_array_equal(kernel(x), x + 1) + + def test_explicit_sc_tiling_2d(self): + self.skip_if_tc_tiling("The test uses SC tiling.") + + num_steps = 16 + x = jnp.arange(num_steps * 8 * 128).reshape(-1, 8, 128) + + @self.vector_subcore_kernel( + out_shape=x, + in_specs=(pl.BlockSpec(memory_space=pltpu.HBM),), + out_specs=pl.BlockSpec(memory_space=pltpu.HBM), + ) + def kernel(x_hbm_ref, o_hbm_ref): + spec = plsc.BlockSpec((pl.Squeezed(), 8, 128), lambda i: (i, 0, 0)) + + @functools.partial( + pltpu.emit_pipeline, + grid=(num_steps,), + in_specs=[spec], + out_specs=[spec], + tiling=pltpu.Tiling.SPARSE_CORE, + ) + def pipeline(x_ref, o_ref): + @pl.loop(0, 8) + def _(i): + @pl.loop(0, 128, step=8) + def _(j): + o_ref[i, pl.ds(j, 8)] = x_ref[i, pl.ds(j, 8)] + 1 + + pipeline(x_hbm_ref, o_hbm_ref) + + np.testing.assert_array_equal(kernel(x), x + 1) + + +class PipelineTestWithTCTiling(PipelineTest): + USE_TC_TILING = True + + +class PallasSparsecoreAsyncTest(PallasSCTest): + + def setUp(self): + super().setUp() + + @parameterized.product( + shape=[ + (8, 128), + (8, 256), + (8, 512), + (8, 1024), + (16, 128), + (16, 256), + (16, 512), + (16, 1024), + # TODO(sharadmv): These shapes fail right now. + # (64, 8), + ], + dtype=[jnp.int32, jnp.float32, jnp.bfloat16], + ) + def test_basic_async_kernel(self, shape, dtype): + x = jnp.arange(shape[0] * shape[1], dtype=dtype).reshape(shape) + + @jax.jit + def foo(x): + sc_mesh = plsc.ScalarSubcoreMesh(axis_name="core", num_cores=1) + + sem = pl.pallas_call( + lambda _: None, + out_shape=pltpu.SemaphoreType.DMA(()), + out_specs=pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + grid=(1,), + compiler_params=pltpu.CompilerParams( + dimension_semantics=["core_parallel"], + kernel_type=pltpu.KernelType.SC_SCALAR_SUBCORE, + use_tc_tiling_on_sc=self.USE_TC_TILING, + ), + )() + + sem_ref = jax.new_ref(sem, memory_space=pltpu.SEMAPHORE) + y_ref = pl.empty_ref_like(pltpu.HBM(x.shape, x.dtype)) + x_ref = jax.new_ref(x) + + run_kernel = pl.core_map(mesh=sc_mesh) + + @run_kernel + def _(): + pltpu.make_async_copy(x_ref, y_ref, sem_ref).start() + + @run_kernel + def _(): + pltpu.make_async_copy(x_ref, y_ref, sem_ref).wait() + + return y_ref[...] + + o = jax.block_until_ready(foo(x)) + np.testing.assert_array_equal(o, x) + + +class PallasSparsecoreAsyncTestWithTCTiling(PallasSparsecoreAsyncTest): + USE_TC_TILING = True + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/tpu_splash_attention_kernel_sharded_test.py b/tests/pallas/tpu_splash_attention_kernel_sharded_test.py new file mode 100644 index 000000000000..4b5b9ee19a48 --- /dev/null +++ b/tests/pallas/tpu_splash_attention_kernel_sharded_test.py @@ -0,0 +1,282 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for partitioning splash_attention.""" + +import functools +import math +from absl.testing import absltest, parameterized +import jax +from jax import random +from jax._src import test_util as jtu +from jax._src.pallas import pallas_test_util as ptu +from jax._src.shard_map import shard_map +from jax.experimental.pallas.ops.tpu.splash_attention import ( + CausalMask, + MultiHeadMask, + SegmentIds, + make_splash_mha, +) +from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel as splash +from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib +import jax.numpy as jnp +from jax.sharding import PartitionSpec +import numpy as np + +partial = functools.partial + +jax.config.parse_flags_with_absl() + + +@jtu.with_config(jax_traceback_filtering="off") +class PallasBaseTest(ptu.PallasTPUTest): + + def setUp(self): + super().setUp() + if len(jax.devices()) < 4: + self.skipTest("This test requires at least 4 devices.") + + def _assert_allclose(self, x, y, **kwargs): + if x.dtype == np.dtype(jnp.bfloat16): + x = x.astype(np.float32) + if y.dtype == np.dtype(jnp.bfloat16): + y = y.astype(np.float32) + self.assertEqual(x.dtype, y.dtype) + self.assertTupleEqual(x.shape, y.shape) + np.testing.assert_allclose(x, y, **kwargs) + + +def generate_mask(shape, num_heads, seed) -> np.ndarray: + assert num_heads >= 2 + assert shape > (64, 64) + + masks = [ + mask_lib.make_causal_mask(shape), + mask_lib.make_local_attention_mask(shape, window_size=(64, 64)), + ] + masks += [mask_lib.make_random_mask(shape, 0.8, seed)] * (num_heads - 2) + return np.stack(masks, axis=0) + + +class SplashAttentionShardingTest(PallasBaseTest): + + @parameterized.product( + topology=[(1, 1), (2, 1), (2, 2), (1, 2), (1, 4), (4, 1)], + num_heads=[2, 4, 16], + dtype=[jnp.bfloat16], + is_dynamic_mask=[False, True], + ) + def test_dynamic_mask_manual_partitioning_mha( + self, topology, num_heads, dtype, is_dynamic_mask + ): + k1, k2, k3 = random.split(random.key(0), 3) + seq_len = 1024 + head_dim = 128 + + head_shards, q_seq_shards = topology + num_devices = math.prod(topology) + + if head_shards > num_heads: + self.skipTest( + f"This test requires {num_heads} heads, but has only" + f" {head_shards} head shards available." + ) + + if len(jax.devices()) < num_devices: + self.skipTest( + f"This test requires {num_devices} devices, but has only" + f" {len(jax.devices())} devices available." + ) + + q = random.uniform(k1, (num_heads, seq_len, head_dim), dtype=dtype) + k = random.uniform(k2, (num_heads, seq_len, head_dim), dtype=dtype) + v = random.uniform(k3, (num_heads, seq_len, head_dim), dtype=dtype) + + mask = generate_mask((seq_len, seq_len), num_heads, seed=0) + if is_dynamic_mask: + mask = jnp.array(mask) + + devices = np.asarray(jax.devices()[:num_devices]).reshape( + head_shards, q_seq_shards + ) + + mesh = jax.sharding.Mesh(devices, ("heads", "q_seq")) + q_spec = PartitionSpec( + "heads" if head_shards > 1 else None, + "q_seq" if q_seq_shards > 1 else None, + ) + kv_spec = PartitionSpec("heads" if head_shards > 1 else None, None) + kernel = splash.make_splash_mha( + mask, head_shards=head_shards, q_seq_shards=q_seq_shards + ) + kernel_spec = kernel.manual_sharding_spec( + jax.sharding.NamedSharding(mesh, q_spec) + ) + + @partial( + shard_map, + mesh=mesh, + in_specs=( + kernel_spec, + q_spec, + kv_spec, + kv_spec, + ), + out_specs=q_spec, + check_vma=False, + ) + def f(kernel, q, k, v): + return kernel(q, k, v) + + out = f(kernel, q, k, v) + out_ref = jax.vmap(splash.attention_reference)(mask, q, k, v, None) + self._assert_allclose(out, out_ref, rtol=3e-3, atol=3e-3) + + @parameterized.product( + topology=[(1, 1), (2, 1), (2, 2), (1, 2), (1, 4), (4, 1)], + num_heads=[2, 4], + dtype=[jnp.bfloat16], + is_dynamic_mask=[False, True], + ) + def test_dynamic_mask_manual_partitioning_mha_bwd( + self, topology, num_heads, dtype, is_dynamic_mask + ): + assert num_heads % 2 == 0 + k1, k2, k3, k4 = random.split(random.key(0), 4) + seq_len = 1024 + head_dim = 128 + + head_shards, q_seq_shards = topology + num_devices = math.prod(topology) + + if head_shards > num_heads: + self.skipTest( + f"This test requires {num_heads} heads, but has only" + f" {head_shards} head shards available." + ) + + q = random.uniform(k1, (num_heads, seq_len, head_dim), dtype=dtype) + k = random.uniform(k2, (num_heads, seq_len, head_dim), dtype=dtype) + v = random.uniform(k3, (num_heads, seq_len, head_dim), dtype=dtype) + + mask = generate_mask((seq_len, seq_len), num_heads, seed=0) + if is_dynamic_mask: + mask = jnp.array(mask) + + devices = np.asarray(jax.devices()[:num_devices]).reshape( + head_shards, q_seq_shards + ) + + mesh = jax.sharding.Mesh(devices, ("heads", "q_seq")) + q_spec = PartitionSpec( + "heads" if head_shards > 1 else None, + "q_seq" if q_seq_shards > 1 else None, + ) + kv_spec = PartitionSpec("heads" if head_shards > 1 else None, None) + + kernel = splash.make_splash_mha( + mask, head_shards=head_shards, q_seq_shards=q_seq_shards + ) + kernel_spec = kernel.manual_sharding_spec( + jax.sharding.NamedSharding(mesh, q_spec) + ) + + @partial( + shard_map, + mesh=mesh, + in_specs=( + kernel_spec, + q_spec, + kv_spec, + kv_spec, + ), + out_specs=q_spec, + check_vma=False, + ) + def f(kernel, q, k, v): + return kernel(q, k, v) + + f_ref = jax.vmap(splash.attention_reference) + + out, out_vjp = jax.vjp(f, kernel, q, k, v) + out_ref, out_vjp_ref = jax.vjp(f_ref, mask, q, k, v, None) + self._assert_allclose(out, out_ref, rtol=3e-3, atol=3e-3) + + do = random.uniform(k4, out.shape, dtype=out.dtype) + _, dq, dk, dv = out_vjp(do) + _, dq_ref, dk_ref, dv_ref, _ = out_vjp_ref(do.astype(jnp.float32)) + + self.assertAllClose(dq, dq_ref, atol=5e-2) + self.assertAllClose(dk, dk_ref, atol=5e-2) + self.assertAllClose(dv, dv_ref, atol=5e-2) + + def test_splash_explicit_vmap_one_mesh_axis(self): + mesh = jax.make_mesh((4,), ("dp",)) + + NUM_HEADS = 4 + SEQ_LEN = 256 + HEAD_DIM = 64 + d_model = NUM_HEADS * HEAD_DIM + + key = jax.random.key(0) + input_sharding = jax.NamedSharding(mesh, jax.P("dp", None, None)) + x_seq = jax.random.normal(key, (4, SEQ_LEN, d_model), dtype=jnp.bfloat16) + x_seq = jax.device_put(x_seq, input_sharding) + + def make_splash_kernel_with_shard_map(mesh): + mask = MultiHeadMask([CausalMask(shape=(SEQ_LEN, SEQ_LEN)) + for _ in range(NUM_HEADS)]) + splash_spec = jax.P(None, None) + sspec = jax.NamedSharding(mesh, splash_spec) + + kernel = make_splash_mha(mask, head_shards=1, q_seq_shards=1) + kspec = kernel.manual_sharding_spec(sspec) + + @jax.shard_map( + mesh=mesh, + in_specs=(kspec, splash_spec, splash_spec, splash_spec, jax.P()), + out_specs=splash_spec, + check_vma=False, + ) + def splash_sharded(kernel, q, k, v, segment_ids): + return kernel(q, k, v, segment_ids=segment_ids) + + return splash_sharded, kernel + + def attention_fn_with_shmap(splash_sharded, kernel, x_seq): + s = x_seq.shape[0] + q, k, v = jnp.ones((3, NUM_HEADS, s, HEAD_DIM), out_sharding=jax.P()) + segment_ids = SegmentIds(q=jnp.zeros((s,)), kv=jnp.zeros((s,))) + scale = HEAD_DIM ** -0.25 + out = splash_sharded(kernel, q * scale, k * scale, v, segment_ids) + return out + + splash_sharded, kernel = make_splash_kernel_with_shard_map(mesh) + + @jax.jit + def step(x_seq): + def loss_fn(x_seq): + attn_fn = partial(attention_fn_with_shmap, splash_sharded, kernel) + out = jax.vmap(attn_fn)(x_seq) + return out.sum() + + loss, grads = jax.value_and_grad(loss_fn)(x_seq) + return loss, grads + + with jax.set_mesh(mesh): + step(x_seq) # doesn't crash + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/tpu_splash_attention_kernel_test.py b/tests/pallas/tpu_splash_attention_kernel_test.py index dfe0bcc0da3b..dc2a33fe8f71 100644 --- a/tests/pallas/tpu_splash_attention_kernel_test.py +++ b/tests/pallas/tpu_splash_attention_kernel_test.py @@ -22,9 +22,12 @@ from absl.testing import absltest from absl.testing import parameterized +import hypothesis as hp +import hypothesis.strategies as hps import jax from jax import random from jax._src import test_util as jtu +from jax._src.pallas import pallas_test_util as ptu from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel as splash from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask_info import process_mask @@ -32,12 +35,6 @@ import numpy as np -try: - import hypothesis as hp - import hypothesis.strategies as hps -except (ModuleNotFoundError, ImportError): - raise unittest.SkipTest("these tests require hypothesis") - jax.config.parse_flags_with_absl() jtu.setup_hypothesis(max_examples=5) @@ -211,7 +208,9 @@ def sequence_length_strategy(draw: Draw) -> tuple[int, int]: def attention_strategy(draw: Draw) -> tuple[int, int, int, np.dtype]: q_seq_len, kv_seq_len = draw(sequence_length_strategy()) head_dim_qk, head_dim_v = draw( - hps.sampled_from([(128, 128), (256, 256), (192, 128)]) + hps.sampled_from( + [(64, 64), (128, 192), (128, 128), (256, 256), (192, 128)] + ) ) if q_seq_len >= 4096 and kv_seq_len >= 4096: # Do not draw bfloat16 on longer sequence lengths, as this increases @@ -303,27 +302,19 @@ def attn_logits_soft_cap_strategy() -> hps.SearchStrategy[float | None]: return hps.one_of(hps.just(None), hps.floats(min_value=1.0, max_value=50.0)) -def to_dynamic_mask(mask: mask_lib.MultiHeadMask) -> jax.Array: - q_seq_len, kv_seq_len = mask.masks[0].shape - full_mask_slice = (slice(0, q_seq_len), slice(0, kv_seq_len)) - dynamic_mask = jnp.stack([m[full_mask_slice] for m in mask.masks], axis=0) - - return dynamic_mask - - @jtu.with_config(jax_traceback_filtering="off") -class PallasBaseTest(jtu.JaxTestCase): - INTERPRET = False +class PallasBaseTest(ptu.PallasTPUTest): def setUp(self): if not self.INTERPRET: - if not jtu.test_device_matches(["tpu"]): - self.skipTest("Only interpret mode supported on non-TPU") # TODO(b/327487669): selectively re-enable tests that works on TPU v3. if not jtu.is_device_tpu_at_least(4): self.skipTest("Not supported on TPU generations <= 3") if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled: self.skipTest("On CPU the test works only in 32-bit") + if jtu.is_device_tpu(7, 'x'): + # TODO(sharadmv): Enable these tests. + self.skipTest('Tests time out on TPUs v7x.') super().setUp() @@ -337,6 +328,7 @@ def _assert_allclose(self, x, y, **kwargs): np.testing.assert_allclose(x, y, **kwargs) +@jtu.thread_unsafe_test_class() # hypothesis is not thread safe class SplashAttentionTest(PallasBaseTest): @parameterized.product( is_mqa=(False, True), @@ -384,7 +376,7 @@ def test_splash_attention(self, is_mqa, is_segmented, is_dynamic_mask, data): masks = data.draw(mha_mask_strategy(q_seq_len, kv_seq_len, num_q_heads)) mask = mask_lib.MultiHeadMask(tuple(m.get_mask() for m in masks)) if is_dynamic_mask: - mask = to_dynamic_mask(mask) + mask = jnp.array(mask[:, :, :]) block_sizes = data.draw(block_sizes_strategy(q_seq_len, kv_seq_len)) if is_mqa: @@ -460,7 +452,7 @@ def test_splash_attention_fwd( masks = data.draw(mha_mask_strategy(q_seq_len, kv_seq_len, num_q_heads)) mask = mask_lib.MultiHeadMask(tuple(m.get_mask() for m in masks)) if is_dynamic_mask: - mask = to_dynamic_mask(mask) + mask = jnp.array(mask[:, :, :]) block_sizes = data.draw(block_sizes_strategy(q_seq_len, kv_seq_len)) if is_mqa: attn_ref = splash.make_masked_mqa_reference(mask) @@ -522,9 +514,9 @@ def test_splash_attention_custom_bwd(self, is_segmented, data): masks = data.draw(mha_mask_strategy(q_seq_len, kv_seq_len, 1)) mask = jnp.array(masks[0].get_mask()[:, :]) attn_logits_soft_cap = data.draw(attn_logits_soft_cap_strategy(), - label="logit_cap") + label="logit_cap") attn_ref = partial(splash.attention_reference, mask, - attn_logits_soft_cap=attn_logits_soft_cap) + attn_logits_soft_cap=attn_logits_soft_cap) attn_custom = partial(splash.attention_reference_custom, mask, attn_logits_soft_cap=attn_logits_soft_cap) attn_custom_vanilla = partial(splash.attention_reference_custom, mask, @@ -532,7 +524,7 @@ def test_splash_attention_custom_bwd(self, is_segmented, data): attn_logits_soft_cap=attn_logits_soft_cap) o_ref, attn_vjp_ref = jax.vjp(attn_ref, q, k, v, segment_ids) q32, k32, v32 = jax.tree.map(lambda x: x.astype(jnp.float32), - (q, k, v)) + (q, k, v)) o_custom = attn_custom(q32, k32, v32, segment_ids) _, attn_vjp = jax.vjp(attn_custom, q32, k32, v32, segment_ids) _, attn_vanilla_vjp = jax.vjp(attn_custom_vanilla, q32, k32, v32, @@ -579,6 +571,7 @@ def test_splash_attention_custom_bwd(self, is_segmented, data): downcast_smem_data=(False, True), use_fused_bwd_kernel=(False, True), use_dynamic_mask=(False, True), + use_sinks=(False, True), ) @hp.given(hps.data()) def test_splash_attention_bwd( @@ -588,11 +581,12 @@ def test_splash_attention_bwd( downcast_smem_data, use_fused_bwd_kernel, use_dynamic_mask, + use_sinks, data, ): seed = data.draw(seed_strategy()) key = random.key(seed) - k1, k2, k3, k4 = random.split(key, 4) + k1, k2, k3, k4, k_sinks = random.split(key, 5) ( q_seq_len, @@ -619,6 +613,10 @@ def test_splash_attention_bwd( v = random.uniform( k3, (num_kv_heads, kv_seq_len, head_dim_v), dtype=dtype ) + if use_sinks: + sinks = 1.0 * random.uniform(k_sinks, (num_q_heads,), dtype=dtype) + else: + sinks = None segment_ids = None if is_segmented: @@ -628,7 +626,7 @@ def test_splash_attention_bwd( masks = data.draw(mha_mask_strategy(q_seq_len, kv_seq_len, num_q_heads)) mask = mask_lib.MultiHeadMask(tuple(m.get_mask() for m in masks)) if use_dynamic_mask: - mask = to_dynamic_mask(mask) + mask = jnp.array(mask[:, :, :]) block_sizes = data.draw( block_sizes_strategy(q_seq_len, kv_seq_len, include_bwd_blocks=True, use_fused_bwd_kernel=use_fused_bwd_kernel) @@ -651,7 +649,10 @@ def test_splash_attention_bwd( attn_logits_soft_cap=attn_logits_soft_cap, interpret=self.INTERPRET, ) - o, attn_vjp = jax.vjp(attn, q, k, v, segment_ids) + if use_sinks: + o, attn_vjp = jax.vjp(attn, q, k, v, segment_ids, sinks) + else: + o, attn_vjp = jax.vjp(attn, q, k, v, segment_ids) q32, k32, v32 = jax.tree.map( lambda x: x.astype(jnp.float32), (q, k, v) ) @@ -660,44 +661,50 @@ def test_splash_attention_bwd( k32, v32, segment_ids, + sinks=sinks, save_residuals=True, attn_logits_soft_cap=attn_logits_soft_cap, ) self._assert_allclose(o, o_ref, atol=3e-3, rtol=3e-3) do = random.uniform(k4, o.shape, dtype=o.dtype) - dq, dk, dv, _ = attn_vjp(do) + if use_sinks: + dq, dk, dv, _, dsinks = attn_vjp(do) + else: + dq, dk, dv, _ = attn_vjp(do) + dsinks = None def bwd( - mask, q, k, v, segment_ids, o, logsumexp, do + mask, q, k, v, segment_ids, sinks, o, logsumexp, do ) -> tuple[jax.Array, jax.Array, jax.Array]: - _, dq, dk, dv, _ = splash._attention_reference_custom_bwd( + _, dq, dk, dv, _, dsinks = splash._attention_reference_custom_bwd( splash.DEFAULT_MASK_VALUE, False, "flash", attn_logits_soft_cap, - (mask, q, k, v, segment_ids, o, logsumexp), + (mask, q, k, v, segment_ids, sinks, o, logsumexp), do, ) - return dq, dk, dv + return dq, dk, dv, dsinks is_grouped = not is_mqa and num_kv_heads < num_q_heads assert num_q_heads % num_kv_heads == 0 head_multiplier = num_q_heads // num_kv_heads if is_mqa: - bwd = jax.vmap(bwd, in_axes=(0, 0, None, None, None, 0, 0, 0)) + bwd = jax.vmap(bwd, in_axes=(0, 0, None, None, None, 0, 0, 0, 0)) else: - bwd = jax.vmap(bwd, in_axes=(0, 0, 0, 0, None, 0, 0, 0)) + bwd = jax.vmap(bwd, in_axes=(0, 0, 0, 0, None, 0, 0, 0, 0)) # Interleave the KV heads to match the corresponding Q heads. if is_grouped: k32 = jnp.repeat(k32, head_multiplier, axis=0) v32 = jnp.repeat(v32, head_multiplier, axis=0) - dq_ref, dk_ref, dv_ref = bwd( + dq_ref, dk_ref, dv_ref, dsinks_ref = bwd( mask[:, :, :], q32, k32, v32, segment_ids, + sinks, o.astype(jnp.float32), logsumexp, do.astype(jnp.float32), @@ -715,6 +722,8 @@ def bwd( self._assert_allclose(dv, dv_ref, atol=2e-2, rtol=3e-2) self._assert_allclose(dq, dq_ref, atol=2e-2, rtol=3e-2) self._assert_allclose(dk, dk_ref, atol=2e-2, rtol=3e-2) + if use_sinks: + self._assert_allclose(dsinks, dsinks_ref, atol=5e-3, rtol=3e-2) def test_grid_shrinking(self): """Make sure that grid shrinking does not change the attention output.""" diff --git a/tests/pallas/tpu_splash_attention_mask_test.py b/tests/pallas/tpu_splash_attention_mask_test.py index f39b4d839340..7e559716fab8 100644 --- a/tests/pallas/tpu_splash_attention_mask_test.py +++ b/tests/pallas/tpu_splash_attention_mask_test.py @@ -14,6 +14,8 @@ from __future__ import annotations +import sys + from absl.testing import absltest from absl.testing import parameterized import jax @@ -44,6 +46,15 @@ def _make_local_attention_mask(*args, **kwargs): return mask_lib.make_local_attention_mask(*args, **kwargs) +def _make_lazy_chunked_causal_mask(shape, chunk_size): + mask = mask_lib.ChunkedCausalMask(shape=shape, chunk_size=chunk_size) + return mask[:, :] + + +def _make_chunked_causal_mask(shape, chunk_size): + return mask_lib.make_chunk_attention_mask(shape=shape, chunk_size=chunk_size) + + class SplashAttentionMaskTest(jtu.JaxTestCase): @parameterized.parameters([_make_lazy_causal_mask, _make_causal_mask]) @@ -412,7 +423,186 @@ def test_lazy_local_mask_chunking( block_size, ) + @parameterized.parameters( + [_make_lazy_chunked_causal_mask, _make_chunked_causal_mask] + ) + def test_chunked_causal_mask(self, make_chunked_mask): + """Tests the chunked causal mask logic for various shapes and chunk sizes.""" + with self.subTest("unit"): + expected = np.array([[1]], dtype=np.bool_) + actual = make_chunked_mask(shape=(1, 1), chunk_size=1) + self.assertArraysEqual(actual, expected) + actual = make_chunked_mask(shape=(1, 1), chunk_size=2) + self.assertArraysEqual(actual, expected) + + with self.subTest("square_exact_chunks"): + # Chunk 0: [0, 1], Chunk 1: [2, 3] + expected = np.array( + [ + [1, 0, 0, 0], + [1, 1, 0, 0], + [0, 0, 1, 0], + [0, 0, 1, 1], + ], + dtype=np.bool_, + ) + actual = make_chunked_mask(shape=(4, 4), chunk_size=2) + self.assertArraysEqual(actual, expected) + + with self.subTest("square_uneven_chunks"): + expected = np.array( + [ + [1, 0, 0, 0, 0], + [1, 1, 0, 0, 0], + [1, 1, 1, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 1, 1], + ], + dtype=np.bool_, + ) + actual = make_chunked_mask(shape=(5, 5), chunk_size=3) + self.assertArraysEqual(actual, expected) + + with self.subTest("wide_rectangle"): + expected = np.array( + [ + [1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + ], + dtype=np.bool_, + ) + actual = make_chunked_mask(shape=(4, 6), chunk_size=3) + self.assertArraysEqual(actual, expected) + + with self.subTest("tall_rectangle"): + expected = np.array( + [ + [1, 0, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 0], + [0, 0, 0, 1], + [0, 0, 0, 1], + [0, 0, 0, 1], + ], + dtype=np.bool_, + ) + actual = make_chunked_mask(shape=(6, 4), chunk_size=3) + self.assertArraysEqual(actual, expected) + + with self.subTest("chunk_size_1"): + # Should only allow self-attention q==k and chunk_size == 1 + expected = np.array( + [ + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1], + ], + dtype=np.bool_, + ) + actual = make_chunked_mask(shape=(4, 4), chunk_size=1) + self.assertArraysEqual(actual, expected) + + with self.subTest("chunk_size_greater_equal_seqlen"): + # Should behave like a normal causal mask + expected = np.array( + [ + [1, 0, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 0], + [1, 1, 1, 1], + ], + dtype=np.bool_, + ) + # Test chunk_size == seqlen + actual_eq = make_chunked_mask(shape=(4, 4), chunk_size=4) + self.assertArraysEqual(actual_eq, expected) + # Test chunk_size > seqlen + actual_gt = make_chunked_mask(shape=(4, 4), chunk_size=5) + self.assertArraysEqual(actual_gt, expected) + + @parameterized.product( + block_size=[(128, 128), (256, 128), (128, 256)], + shape=[(512, 512), (512, 1024), (1024, 512)], + chunk_size=[64, 128, 256, 512, 1024], + ) + def test_lazy_chunked_causal_mask_chunking( + self, + block_size: tuple[int, int], + shape: tuple[int, int], + chunk_size: int, + ): + """Compares lazy chunked mask evaluation against the dense version block-by-block.""" + q_len, kv_len = shape + # Adjust block size if it exceeds shape dimensions + adjusted_block_size = ( + min(block_size[0], q_len), + min(block_size[1], kv_len), + ) + + if ( + q_len % adjusted_block_size[0] != 0 + or kv_len % adjusted_block_size[1] != 0 + ): + self.skipTest( + f"Shape {shape} not divisible by block_size {adjusted_block_size}" + ) + + dense_mask = _make_chunked_causal_mask(shape=shape, chunk_size=chunk_size) + lazy_mask = mask_lib.ChunkedCausalMask(shape=shape, chunk_size=chunk_size) + self._compare_masks( + dense_mask, + lazy_mask, + adjusted_block_size, + ) + + def test_chunked_causal_mask_invalid_chunk_size(self): + """Tests that invalid chunk_size raises ValueError.""" + with self.assertRaises(ValueError): + mask_lib.ChunkedCausalMask(shape=(10, 10), chunk_size=0) + with self.assertRaises(ValueError): + mask_lib.ChunkedCausalMask(shape=(10, 10), chunk_size=-1) + with self.assertRaises(ValueError): + mask_lib.make_chunk_attention_mask(shape=(10, 10), chunk_size=0) + + def test_chunked_causal_mask_minimal_equality_hash(self): + """Tests for __eq__ and __hash__ of ChunkedCausalMask.""" + shape1, chunk_size1 = (128, 256), 16 + shape2, chunk_size2 = (128, 128), 32 # Different shape/chunk_size + + # Create three masks: two identical, one with different shape/chunk_size. + mask1 = mask_lib.ChunkedCausalMask(shape=shape1, chunk_size=chunk_size1) + mask2 = mask_lib.ChunkedCausalMask(shape=shape1, chunk_size=chunk_size1) + mask_diff_shape = mask_lib.ChunkedCausalMask( + shape=shape2, chunk_size=chunk_size1 + ) + mask_diff_chunk = mask_lib.ChunkedCausalMask( + shape=shape1, chunk_size=chunk_size2 + ) + other_obj = object() + + # Test __eq__ + self.assertEqual(mask1, mask2) + self.assertNotEqual(mask1, mask_diff_shape) + self.assertNotEqual(mask1, mask_diff_chunk) + self.assertNotEqual(mask1, other_obj) + + # Test __hash__ of identical masks + self.assertEqual(hash(mask1), hash(mask2)) + + mask_set = {mask1, mask2, mask_diff_chunk} + self.assertLen(mask_set, 2) # mask1 and mask2 are duplicates + self.assertIn(mask1, mask_set) + self.assertIn(mask_diff_chunk, mask_set) + self.assertNotIn(mask_diff_shape, mask_set) + def test_using_logical_operators_raises_exception(self): + if sys.version_info == (3, 14, 0, "candidate", 1): + # Fails due to Python bug on 3.14.0rc1 + # https://github.com/python/cpython/issues/137288 + self.skipTest("Expected failure.") mask_1 = mask_lib.NumpyMask( mask_lib.make_random_mask((256, 256), 0.5, seed=1) ) @@ -1064,7 +1254,8 @@ def test_local_mask(self, is_lazy_mask: bool): mask_info, mask_info_dkv, mask_function = self._process_mask( multi_head, block_shape ) - self.assertIsNone(mask_function) + if is_lazy_mask: + self.assertIsNotNone(mask_function) expected_partial_mask_blocks = self._stack( [ @@ -1108,10 +1299,12 @@ def test_local_mask(self, is_lazy_mask: bool): expected_mask_info = mask_info_lib.MaskInfo( expected_local_data_next, - expected_local_mask_next, + expected_local_mask_next if not is_lazy_mask else None, expected_local_block_mask, - expected_partial_mask_blocks, - None, + expected_partial_mask_blocks if not is_lazy_mask else None, + np.arange(sequence_lengths[0], dtype=np.int32) + if is_lazy_mask + else None, ) expected_local_data_next_dkv = np.array( @@ -1143,10 +1336,14 @@ def test_local_mask(self, is_lazy_mask: bool): expected_mask_info_dkv = mask_info_lib.MaskInfo( expected_local_data_next_dkv, - expected_local_mask_next_dkv, + expected_local_mask_next_dkv if not is_lazy_mask else None, expected_local_block_mask_dkv, - expected_partial_mask_blocks.swapaxes(-1, -2), - None, + expected_partial_mask_blocks.swapaxes(-1, -2) + if not is_lazy_mask + else None, + np.arange(sequence_lengths[0], dtype=np.int32) + if is_lazy_mask + else None, ) self._assert_mask_info_match(mask_info, expected_mask_info) @@ -1175,7 +1372,9 @@ def test_local_mask_narrow(self, is_lazy_mask: bool): mask_info, mask_info_dkv, mask_function = self._process_mask( multi_head, block_shape ) - self.assertIsNone(mask_function) + + if is_lazy_mask: + self.assertIsNotNone(mask_function) expected_partial_mask_blocks = self._stack( [ @@ -1216,10 +1415,12 @@ def test_local_mask_narrow(self, is_lazy_mask: bool): expected_mask_info = mask_info_lib.MaskInfo( expected_local_data_next, - expected_local_mask_next, + expected_local_mask_next if not is_lazy_mask else None, expected_local_block_mask, - expected_partial_mask_blocks, - None, + expected_partial_mask_blocks if not is_lazy_mask else None, + np.arange(sequence_lengths[0], dtype=np.int32) + if is_lazy_mask + else None, ) expected_local_data_next_dkv = np.array( @@ -1248,10 +1449,14 @@ def test_local_mask_narrow(self, is_lazy_mask: bool): expected_mask_info_dkv = mask_info_lib.MaskInfo( expected_local_data_next_dkv, - expected_local_mask_next_dkv, + expected_local_mask_next_dkv if not is_lazy_mask else None, expected_local_block_mask_dkv, - expected_partial_mask_blocks.swapaxes(-1, -2), - None, + expected_partial_mask_blocks.swapaxes(-1, -2) + if not is_lazy_mask + else None, + np.arange(sequence_lengths[0], dtype=np.int32) + if is_lazy_mask + else None, ) self._assert_mask_info_match(mask_info, expected_mask_info) @@ -2066,11 +2271,12 @@ def test_huge_mask2(self): multi_head, block_shape ) - self.assertIsNone(mask_function) + self.assertIsNotNone(mask_function) self.assertIsNotNone(mask_info.block_mask) self.assertIsNotNone(mask_info.data_next) - self.assertIsNotNone(mask_info.mask_next) - self.assertIsNotNone(mask_info.partial_mask_blocks) + self.assertIsNone(mask_info.mask_next) + self.assertIsNone(mask_info.partial_mask_blocks) + self.assertIsNotNone(mask_info.q_sequence) def test_process_invalid_mask(self): """Masks with of an all-0 row causes undefined softmax, reject them.""" @@ -2166,7 +2372,9 @@ def test_dynamic_mask(self, is_dkv: bool): self.assertArraysEqual(mask_info.block_mask, _expected_block_mask) self.assertArraysEqual( - mask_info.partial_mask_blocks, + mask_info.partial_mask_blocks.reshape( + -1, *mask_info.partial_mask_blocks.shape[-2:] + ), _expected_partial_mask_blocks, ) self.assertArraysEqual(mask_info.mask_next, _expected_mask_next) diff --git a/tests/pallas/tpu_trace_value_test.py b/tests/pallas/tpu_trace_value_test.py new file mode 100644 index 000000000000..2371b9a8c724 --- /dev/null +++ b/tests/pallas/tpu_trace_value_test.py @@ -0,0 +1,74 @@ +# Copyright 2026 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Minimal test for pltpu.trace_value primitive.""" + +from absl.testing import absltest +import jax +from jax._src import test_util as jtu +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu +import jax.numpy as jnp + + +def simple_kernel_with_trace_value(x_ref, s_ref, o_ref): + """Simple kernel that emits trace metrics.""" + # Emit a constant to xprof trace + pltpu.trace_value("constant_value", jnp.float32(42.42)) + scale = s_ref[0] + + z = x_ref[...] + jnp.float32(48.0) + scale.astype(jnp.float32).reshape((1, 1)) + pltpu.trace_value("scale_value", scale) + o_ref[...] = z + + +class TraceMetricTest(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + if not jtu.is_device_tpu(): + self.skipTest("trace_value only supported on TPU.") + + def test_simple_trace_metric(self): + """Test that trace_metric compiles and runs without error.""" + if jtu.test_device_matches(["tpu"]) and not jtu.is_cloud_tpu_at_least(2026, 1, 16): + self.skipTest("Requires libtpu built after 2026-1-16") + + x = jnp.ones((8, 128), dtype=jnp.float32) + + s = jax.random.randint(jax.random.key(0), (1,), minval=0, maxval=100) + + result = pl.pallas_call( + simple_kernel_with_trace_value, + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + in_specs=[ + pl.BlockSpec((8, 128), memory_space=pltpu.VMEM), + pl.BlockSpec((1,), memory_space=pltpu.SMEM), + ], + out_specs=pl.BlockSpec((8, 128), memory_space=pltpu.VMEM), + compiler_params=pltpu.CompilerParams(has_side_effects=True), + name="trace_metric_test", + )(x, s) + + # Just verify the kernel runs and produces correct output + # TODO(amuzio): Verify the kernel runs and includes the vtrace in an actual + # xprof. + self.assertEqual(result.shape, (8, 128)) + self.assertTrue( + jnp.allclose(result, x + 48.0 + s.astype(jnp.float32).reshape((1, 1))) + ) + + +if __name__ == "__main__": + jax.config.parse_flags_with_absl() + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/triton_pallas_test.py b/tests/pallas/triton_pallas_test.py new file mode 100644 index 000000000000..fdc7d3840ff1 --- /dev/null +++ b/tests/pallas/triton_pallas_test.py @@ -0,0 +1,537 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import sys +import unittest + +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax import lax +from jax._src import config +from jax._src import dtypes +from jax._src import test_util as jtu +from jax._src.pallas import pallas_call +from jax.experimental import pallas as pl +import jax.numpy as jnp +import numpy as np + +if sys.platform != "win32": + from jax.experimental.pallas import triton as plgpu +else: + plgpu = None + +config.parse_flags_with_absl() + +intx = dtypes.default_int_dtype() +floatx = dtypes.default_float_dtype() + + +@jtu.with_config(jax_traceback_filtering="off") +class PallasBaseTest(jtu.JaxTestCase): + INTERPRET = False + + def setUp(self): + if jtu.test_device_matches(["cpu"]): + if not self.INTERPRET: + self.skipTest("On CPU the test works only in interpret mode") + elif jtu.test_device_matches(["gpu"]): + if jtu.test_device_matches(["cuda"]) and \ + not jtu.is_cuda_compute_capability_at_least("9.0"): + self.skipTest("Only works on GPU with capability >= sm90") + else: + self.skipTest("Test only works on CPU and GPU") + + super().setUp() + + def pallas_call(self, *args, **kwargs): + return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET) + + +DTYPE_LIST = [jnp.float32, jnp.float16, jnp.bfloat16, + jnp.float8_e4m3fn, jnp.float8_e5m2] + + +class TritonPallasTest(PallasBaseTest): + INTERPRET = False + + def setUp(self): + super().setUp() + # Force tests to use Triton. + self.enter_context(pallas_call._PALLAS_USE_MOSAIC_GPU(False)) + + @parameterized.product(src_dtype=DTYPE_LIST, dst_dtype=DTYPE_LIST) + def test_fp_dtype_cast(self, src_dtype, dst_dtype): + if src_dtype == dst_dtype: + self.skipTest("No need to test the same dtype") + if dtypes.itemsize_bits(src_dtype) == 8 and dtypes.itemsize_bits(dst_dtype) == 8: + self.skipTest("Not casting between 8-bit types") + + def body(x_ref, y_ref): + y_ref[...] = x_ref[...].astype(dst_dtype) + + x = 10 * jax.random.normal(jax.random.key(0), (64, 64), dtype=src_dtype) + y = self.pallas_call(body, + in_specs=[pl.BlockSpec((64, 64), lambda i: (0, 0))], + out_specs=pl.BlockSpec((64, 64), lambda i: (0, 0)), + out_shape=jax.ShapeDtypeStruct((64, 64), dst_dtype), + grid=(1,), + )(x) + self.assertEqual(y.dtype, dst_dtype) + self.assertArraysEqual(y, x.astype(dst_dtype)) + + @parameterized.named_parameters( + ("add_i32", "atomic_add", np.array([1, 2, 3, 4], np.int32), np.sum), + ("max_i32", "atomic_max", np.array([1, 2, 3, 4], np.int32), np.max), + ("min_i32", "atomic_min", np.array([1, 2, 3, 4], np.int32), np.min), + ("add_f16", "atomic_add", np.array([1, 2, 3, 4], np.float16), np.sum), + ("add_f32", "atomic_add", np.array([1, 2, 3, 4], np.float32), np.sum), + ("max_f32", "atomic_max", np.array([1, 2, 3, 4], np.float32), np.max), + ("min_f32", "atomic_min", np.array([1, 2, 3, 4], np.float32), np.min), + ) + def test_scalar_atomic(self, op, value, numpy_op): + if plgpu is None: + self.skipTest("plgpu not available on this platform.") + + op = getattr(plgpu, op) + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((), value.dtype), + grid=value.shape[0], + input_output_aliases={1: 0}, + ) + def atomic_kernel(x_ref, _, o_ref): + pid = pl.program_id(axis=0) + op(o_ref, (), x_ref[pid]) + + if op == plgpu.atomic_add: + neutral = np.array(0, dtype=value.dtype) + elif op == plgpu.atomic_max: + if np.issubdtype(value.dtype, np.integer): + neutral = np.array(np.iinfo(value.dtype).min, value.dtype) + else: + # JAX on ROCm does not currently handle atomic fmin/fmax correctly + if jtu.test_device_matches(["rocm"]): + self.skipTest("Atomic fmin/fmax not currently supported on ROCm.") + neutral = np.array(-float("inf"), value.dtype) + elif op == plgpu.atomic_min: + if np.issubdtype(value.dtype, np.integer): + neutral = np.array(np.iinfo(value.dtype).max, value.dtype) + else: + # JAX on ROCm does not currently handle atomic fmin/fmax correctly + if jtu.test_device_matches(["rocm"]): + self.skipTest("Atomic fmin/fmax not currently supported on ROCm.") + neutral = np.array(float("inf"), value.dtype) + elif op == plgpu.atomic_or: + neutral = np.array(False, value.dtype) + else: + raise NotImplementedError() + out = atomic_kernel(value, neutral) + np.testing.assert_allclose(out, numpy_op(value)) + + @parameterized.parameters((0,), (1,)) + def test_array_atomic_add(self, axis): + m, n = 32, 8 + if axis == 0: + grid = m + else: + grid = n + out_shape = jax.ShapeDtypeStruct((n if axis == 0 else m,), floatx) + + @functools.partial( + self.pallas_call, + out_shape=out_shape, + grid=grid, + input_output_aliases={1: 0}, + ) + def reduce(x_ref, _, y_ref): + i = pl.program_id(axis=0) + if axis == 0: + idx = (i, jnp.arange(n)) + else: + idx = (jnp.arange(m), i) + x = x_ref[idx] + plgpu.atomic_add(y_ref, (jnp.arange(y.shape[0]),), x) + + x = jax.random.normal(jax.random.key(0), (m, n)) + y = jnp.zeros(out_shape.shape, out_shape.dtype) + y = reduce(x, y) + y_ref = np.sum(x, axis=axis) + np.testing.assert_allclose(y, y_ref, atol=1e-2, rtol=1e-2) + + @parameterized.parameters( + (0, 0, 1), + (0, 1, 1), + (1, 0, 1), + (1, 1, 1), + (2, 1, 1), + (2, 1, 1), + ) + def test_atomic_cas(self, init_value, cmp, new_value): + if jax.config.x64_enabled: + self.skipTest("Not supported in 64-bit mode") + + @functools.partial( + self.pallas_call, + out_shape=( + jax.ShapeDtypeStruct((), intx), + jax.ShapeDtypeStruct((), intx), + ), + input_output_aliases={0: 0}, + ) + def swap(_, lock_ref, out_ref): + out_ref[()] = plgpu.atomic_cas(lock_ref, cmp, new_value) + + lock, out = swap(init_value) + np.testing.assert_allclose( + lock, new_value if cmp == init_value else init_value + ) + np.testing.assert_allclose(out, init_value) + + @parameterized.parameters(1, 2, 3, 4, 8) + def test_atomic_counter(self, num_threads): + if self.INTERPRET: + self.skipTest("While loop not supported in interpret mode.") + if jax.config.x64_enabled: + self.skipTest("Not supported in 64-bit mode") + + @functools.partial( + self.pallas_call, + out_shape=( + jax.ShapeDtypeStruct((), intx), + jax.ShapeDtypeStruct((), intx), + ), + input_output_aliases={0: 0, 1: 1}, + grid=(num_threads,), + ) + def increment(_, __, lock_ref, counter_ref): + def _cond(_): + return plgpu.atomic_cas(lock_ref, 0, 1) == 1 + + lax.while_loop(_cond, lambda a: a, 0) + counter_ref[...] += 1 + plgpu.atomic_xchg(lock_ref, (), 0) + + lock, count = increment(0, 0) + np.testing.assert_allclose(lock, 0) + np.testing.assert_allclose(count, num_threads) + + @parameterized.product( + size=[1, 2, 64, 129, 1021], + block_size=[1, 2, 32, 64, 128], + ) + def test_masked_load_store(self, size, block_size): + @functools.partial( + self.pallas_call, + out_shape=(jax.ShapeDtypeStruct((size,), floatx)), + grid=pl.cdiv(size, block_size), + ) + def kernel(x_ref, o_ref): + idx = pl.program_id(0) * block_size + jnp.arange( + block_size, dtype=jnp.int32 + ) + mask = idx < x_ref.shape[0] + x = plgpu.load(x_ref.at[idx], mask=mask) + plgpu.store(o_ref.at[idx], x + 1.0, mask=mask) + + key = jax.random.key(0) + x = jax.random.normal(key, (size,)) + np.testing.assert_allclose(kernel(x), x + 1.0, atol=1e-5, rtol=1e-5) + + def test_masked_oob_load_store_slice(self): + n = 16 + + @functools.partial( + self.pallas_call, + out_shape=(jax.ShapeDtypeStruct((n,), floatx)), + ) + def masked_oob_load_store_slice(x_ref, mask_ref, start_idx_ref, o_ref): + x = plgpu.load( + x_ref.at[pl.ds(start_idx_ref[()], n)], mask=mask_ref[:], other=-1.0 + ) + o_ref[...] = x + + x = jax.random.normal(jax.random.key(0), (n,)) + slice_start = jax.random.randint(jax.random.key(2), (), 1, n) + indices = jnp.arange(n) + slice_start + mask = indices < n + out = masked_oob_load_store_slice(x, mask, slice_start) + o_new = jnp.where(mask, x[indices], jnp.full_like(x, -1.0)) + np.testing.assert_array_equal(out, o_new) + + @parameterized.parameters( + ((16, 32), (16,)), + ((16, 32), (32,)), + ((16, 32), (16, 16)), + ) + def test_invalid_broadcasted_load(self, x_shape, mask_shape): + if self.INTERPRET: + self.skipTest("No broadcasting checks in pl.load in interpret mode") + + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32) + ) + def kernel(x_ref, mask_ref, o_ref): + del o_ref # Unused. + plgpu.load(x_ref, mask=mask_ref[:]) + + x = jnp.ones(x_shape, dtype=jnp.float32) + mask = jnp.ones(mask_shape, dtype=jnp.bool_) + # assertRaises* methods do not support inspecting the __cause__, so + # we have to check it manually. + try: + kernel(x, mask) + except Exception as e: + self.assertIn("Cannot broadcast", str(e.__cause__)) + else: + self.fail("Expected exception due to invalid broadcasting") + + @parameterized.parameters("float16", "bfloat16", "float32") + def test_approx_tanh(self, dtype): + if self.INTERPRET: + self.skipTest("approx_tanh is not supported in interpret mode") + + if (dtype == "bfloat16" and + not jtu.is_cuda_compute_capability_at_least("9.0")): + self.skipTest("tanh.approx.bf16 requires a GPU with capability >= sm90") + + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), dtype), + ) + def kernel(x_ref, o_ref): + o_ref[...] = plgpu.approx_tanh(x_ref[...]) + + x = jnp.asarray([-1, 0.42, 0.24, 1]).astype(dtype) + # We upcast to float32 because NumPy <2.0 does not handle custom dtypes + # properly. See https://github.com/jax-ml/jax/issues/11014. + np.testing.assert_allclose( + kernel(x).astype(jnp.float32), + jnp.tanh(x).astype(jnp.float32), + atol=5e-3, + rtol=5e-3, + ) + + def test_elementwise_inline_asm(self): + if self.INTERPRET: + self.skipTest( + "elementwise_inline_asm is not supported in interpret mode" + ) + + if jtu.is_device_rocm(): + self.skipTest("elementwise_inline_asm is not currently " + "supported on ROCm") + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((256,), jnp.float16), + ) + def kernel(x_ref, o_ref): + [o_ref[...]] = plgpu.elementwise_inline_asm( + "tanh.approx.f16x2 $0, $1;", + args=[x_ref[...]], + constraints="=r,r", + pack=2, + result_shape_dtypes=[jax.ShapeDtypeStruct(x_ref.shape, x_ref.dtype)], + ) + + x = jnp.arange(256).astype(jnp.float16) + np.testing.assert_allclose(kernel(x), jnp.tanh(x), atol=5e-3, rtol=5e-3) + + def test_debug_barrier(self): + if self.INTERPRET: + self.skipTest("debug_barrier is not supported in interpret mode") + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((2,), jnp.float32), + ) + def kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] + plgpu.debug_barrier() + + x = jnp.array([4.2, 2.4]).astype(jnp.float32) + np.testing.assert_array_equal(kernel(x), x) + + @unittest.skipIf( + sys.platform == "win32", + "plgpu.CompilerParams unavailable on Windows", + ) + def test_debug_print(self): + if jtu.test_device_matches(["gpu"]): + self.skipTest("This test flakes on gpu") + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((2,), jnp.float32), + compiler_params=plgpu.CompilerParams( + num_warps=1, num_stages=1 + ), + ) + def kernel(x_ref, o_ref): + pl.debug_print("It works!") + + x = jnp.array([4.2, 2.4]).astype(jnp.float32) + with jtu.capture_stdout() as output: + jax.block_until_ready(kernel(x)) + jax.effects_barrier() + + self.assertIn("It works!", output()) + + @unittest.skipIf( + sys.platform == "win32", + "plgpu.CompilerParams unavailable on Windows", + ) + def test_debug_print_with_values(self): + if jtu.test_device_matches(["gpu"]): + self.skipTest("This test flakes on gpu") + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((2,), jnp.float32), + compiler_params=plgpu.CompilerParams( + num_warps=1, num_stages=1 + ), + ) + def kernel(x_ref, o_ref): + pl.debug_print("x[0] =", x_ref[0]) + + x = jnp.array([4.2, 2.4]).astype(jnp.float32) + with jtu.capture_stdout() as output: + jax.block_until_ready(kernel(x)) + jax.effects_barrier() + + self.assertIn("x[0] = 4.2", output()) + + @parameterized.named_parameters(*[ + (f"m_{m}_n_{n}_k_{k}_dtype_{dtype}_bm_{block_size_m}_" + f"bn_{block_size_n}_bk_{block_size_k}_gm_{group_size_m}", m, n, k, dtype, + block_size_m, block_size_n, block_size_k, group_size_m) + for m in [512, 1024] + for k in [512] + for n in [512, 1024] + for dtype in ["float32", "float16"] + for block_size_m in [64, 128] + for block_size_n in [64, 128] + for block_size_k in [32] + for group_size_m in [8] + if block_size_m <= m and block_size_n <= n and block_size_k <= k + ]) + def test_matmul(self, m, n, k, dtype, bm, bn, bk, gm): + if jtu.test_device_matches(["tpu"]) and not self.INTERPRET: + self.skipTest("On TPU the test works only in interpret mode") + k1, k2 = jax.random.split(jax.random.key(0)) + x = jax.random.normal(k1, (m, k), dtype=dtype) + y = jax.random.normal(k2, (k, n), dtype=dtype) + out = matmul(x, y, bm=bm, bn=bn, bk=bk, gm=gm, + interpret=self.INTERPRET) + expected = jnp.matmul( + x, y, preferred_element_type=jnp.float32).astype(dtype) + np.testing.assert_allclose(out, expected, atol=0.05, rtol=0.05) + + @parameterized.named_parameters(*( + dict( + testcase_name=f"{batch_size}_{size}_{block_size}_{dtype}", + batch_size=batch_size, + size=size, + block_size=block_size, + dtype=dtype, + ) + for batch_size in [1, 2, 4, 23] + for size in [1, 2, 129, 255, 256] + for block_size in [1, 2, 32, 64, 128, 256] + for dtype in ["float32"] + if size < block_size + )) + def test_softmax(self, batch_size, size, block_size, dtype): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((batch_size, size), dtype), + grid=batch_size, + ) + def softmax(x_ref, o_ref): + row_idx = pl.program_id(0) + x_idx = jnp.arange(block_size) + row_idxs = (row_idx, x_idx) + mask = x_idx < x_ref.shape[1] + row = plgpu.load(x_ref.at[row_idxs], mask=mask, other=-float("inf")) + row_minus_max = row - jnp.max(row, axis=0) + numerator = jnp.exp(row_minus_max) + denominator = jnp.sum(numerator, axis=0) + softmax_output = numerator / denominator + plgpu.store(o_ref.at[row_idxs], softmax_output, mask=mask) + + key = jax.random.key(0) + x = jax.random.normal(key, [batch_size, size], dtype=dtype) + np.testing.assert_allclose( + softmax(x), jax.nn.softmax(x, axis=-1), atol=1e-5, rtol=1e-5 + ) + + +@functools.partial( + jax.jit, static_argnames=["bm", "bn", "gm", "bk", "interpret", "debug"] +) +def matmul(x, y, *, bm, bn, gm, bk, interpret, debug=False): + m, n, k = x.shape[0], y.shape[1], x.shape[1] + + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32), + interpret=interpret, + debug=debug, + grid=pl.cdiv(m, bm) * pl.cdiv(n, bn), + ) + def matmul_kernel(x_ref, y_ref, o_ref): + pid = pl.program_id(axis=0).astype(intx) + num_pid_m = m // bm + num_pid_n = n // bn + num_pid_in_group = gm * num_pid_n + group_id = lax.div(pid, num_pid_in_group) + first_pid_m = group_id * gm + group_size_m = jnp.minimum(num_pid_m - first_pid_m, gm) + pid_m = first_pid_m + lax.rem(pid, group_size_m) + pid_n = lax.div(lax.rem(pid, num_pid_in_group), group_size_m) + idx_m = pid_m * bm + jnp.arange(bm) + idx_n = pid_n * bn + jnp.arange(bn) + idx_m = plgpu.max_contiguous(pl.multiple_of(idx_m, bm), bm) + idx_n = plgpu.max_contiguous(pl.multiple_of(idx_n, bn), bn) + acc = jnp.zeros((bm, bn), dtype=jnp.float32) + + def body(i, acc): + idx_k = i * bk + jnp.arange(bk) + x_idx = ( + lax.broadcast_in_dim(idx_m, (bm, bk), (0,)), + lax.broadcast_in_dim(idx_k, (bm, bk), (1,)), + ) + y_idx = ( + lax.broadcast_in_dim(idx_k, (bk, bn), (0,)), + lax.broadcast_in_dim(idx_n, (bk, bn), (1,)), + ) + x_block, y_block = x_ref[x_idx], y_ref[y_idx] + out = pl.dot(x_block, y_block) + return acc + out + + acc = lax.fori_loop(0, k // bk, body, acc).astype(o_ref.dtype) + o_idx = ( + lax.broadcast_in_dim(idx_m, (bm, bn), (0,)), + lax.broadcast_in_dim(idx_n, (bm, bn), (1,)), + ) + o_ref[o_idx] = acc + + return matmul_kernel(x, y) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/pgle_test.py b/tests/pgle_test.py index 7f9ea598d51b..ac89977adb91 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -21,7 +21,7 @@ import tempfile import warnings -from absl.testing import absltest +from absl.testing import absltest, parameterized import jax from jax._src import api from jax._src import compilation_cache as cc @@ -65,7 +65,11 @@ def testPGLEProfilerGetFDOProfile(self): jax.jit, in_shardings=NamedSharding(mesh, PartitionSpec('x')), out_shardings=NamedSharding(mesh, PartitionSpec('x')), - compiler_options={'xla_gpu_enable_latency_hiding_scheduler': 'True'}, + compiler_options={ + 'xla_gpu_enable_latency_hiding_scheduler': 'True', + # Make sure that matmul is not emitted as Triton GEMM. + 'xla_gpu_enable_triton_gemm': 'False', + }, ) def f(x, y): return x @ y @@ -81,7 +85,7 @@ def f(x, y): pgle_profiler = profiler.PGLEProfiler(1, 90) with config.enable_pgle(False): with profiler.PGLEProfiler.trace(pgle_profiler): - compiled(x, y) + jax.block_until_ready(compiled(x, y)) fdo_profile = pgle_profiler.consume_fdo_profile() self.assertIsNotNone(fdo_profile) @@ -93,6 +97,8 @@ def testPGLEProfilerGetFDOProfileLarge(self): compiler_options = { 'xla_gpu_enable_latency_hiding_scheduler': 'True', + # Make sure that matmul is not emitted as Triton GEMM. + 'xla_gpu_enable_triton_gemm': 'False', } # TODO(b/37664749): Remove this flag once the bug is fixed. compiler_options['xla_gpu_enable_command_buffer'] = '' @@ -151,29 +157,31 @@ def f(x): with config.pgle_profiling_runs(2), config.enable_pgle(True): # Run 1: Module should be compiled without FDO. Two modules are expected - # One is the funtion f, the other one is multi slice module - with jtu.count_jit_compilation_cache_miss() as cache_miss_count: + # One is the function f, the other one is multi slice module + with jtu.count_pjit_cpp_cache_miss() as cache_miss_count: self.assertArraysEqual(f(x), expected) self.assertEqual(cache_miss_count(), 2) # Run 2: Second PGLE run. Profile should be empty. - with jtu.count_jit_compilation_cache_miss() as cache_miss_count: + with jtu.count_pjit_cpp_cache_miss() as cache_miss_count: self.assertArraysEqual(f(x), expected) self.assertEqual(cache_miss_count(), 2) fdo_profiles_before_pgle = self.get_fdo_profiles(dump_dir) - # One for before and one for after optimization. - self.assertLen(fdo_profiles_before_pgle, 2) + # One for before optimizatiom, one after SPMD partitioning, and one + # after optimization. + self.assertLen(fdo_profiles_before_pgle, 3) # The FDO profile file should be empty. self.assertEqual( os.path.getsize(os.path.join(dump_dir, fdo_profiles_before_pgle[0])), 0) # Run 3: The module should be recompiled with FDO profiles - with jtu.count_jit_compilation_cache_miss() as cache_miss_count: + with jtu.count_pjit_cpp_cache_miss() as cache_miss_count: self.assertArraysEqual(f(x), expected) self.assertEqual(cache_miss_count(), 2) fdo_profiles_after_pgle = self.get_fdo_profiles(dump_dir) - # One for before and one for after optimization. - self.assertLen(fdo_profiles_after_pgle, 4) + # One more before optimizatiom, one more after SPMD partitioning, and + # one more after optimization. + self.assertLen(fdo_profiles_after_pgle, 6) for fdo_profile in fdo_profiles_after_pgle: if fdo_profile not in fdo_profiles_before_pgle: @@ -182,7 +190,7 @@ def f(x): ) # Run 4: Fast-path should be used after PGLE is done - with jtu.count_jit_compilation_cache_miss() as cache_miss_count: + with jtu.count_pjit_cpp_cache_miss() as cache_miss_count: self.assertArraysEqual(f(x), expected) self.assertLess(cache_miss_count(), 2) @@ -196,7 +204,8 @@ def f(x): f_lowered = f.lower(x) serialized, in_tree, out_tree = serialize(f_lowered.compile()) - compiled = deserialize_and_load(serialized, in_tree, out_tree) + compiled = deserialize_and_load( + serialized, in_tree, out_tree, execution_devices=jax.devices()[:1]) with config.pgle_profiling_runs(1), config.enable_pgle(True): # Run 1 @@ -310,7 +319,7 @@ def check_if_cache_hit(event): monitoring.register_event_listener(check_if_cache_hit) f(x) - monitoring._unregister_event_listener_by_callback(check_if_cache_hit) + monitoring.unregister_event_listener(check_if_cache_hit) self.assertGreater(cache_hit, 0) @@ -321,7 +330,11 @@ def testPassingFDOProfile(self): jax.jit, in_shardings=NamedSharding(mesh, PartitionSpec('x')), out_shardings=NamedSharding(mesh, PartitionSpec('x')), - compiler_options={'xla_gpu_enable_latency_hiding_scheduler': 'True'}, + compiler_options={ + 'xla_gpu_enable_latency_hiding_scheduler': 'True', + # Make sure that matmul is not emitted as Triton GEMM. + 'xla_gpu_enable_triton_gemm': 'False', + }, ) def f(x, y): return x @ y @@ -435,7 +448,7 @@ def check_if_cache_hit(event): monitoring.register_event_listener(check_if_cache_hit) f(x) - monitoring._unregister_event_listener_by_callback(check_if_cache_hit) + monitoring.unregister_event_listener(check_if_cache_hit) self.assertGreater(cache_hit, 0) # Run 5: `g` was only executed once and did not get re-compiled with PGLE, so @@ -446,7 +459,7 @@ def check_if_cache_hit(event): cache_hit = 0 monitoring.register_event_listener(check_if_cache_hit) g(x) - monitoring._unregister_event_listener_by_callback(check_if_cache_hit) + monitoring.unregister_event_listener(check_if_cache_hit) self.assertEqual(cache_hit, 1) if len(w) != 1: print("Warnings:", [str(w_) for w_ in w], flush=True) @@ -468,5 +481,59 @@ def check_if_cache_hit(event): self.assertLen(w, 1) self.assertIn("PERSISTENT CACHE WRITE with key jit_h-", str(w[0].message)) + @parameterized.parameters([True, False]) + @jtu.thread_unsafe_test() + def testAutoPgleWithCommandBuffers(self, enable_compilation_cache): + with (config.pgle_profiling_runs(1), + config.enable_compilation_cache(enable_compilation_cache), + config.enable_pgle(True), + tempfile.TemporaryDirectory() as dump_dir, + tempfile.TemporaryDirectory() as cache_dir): + if enable_compilation_cache: + cc.reset_cache() + cc.set_cache_dir(cache_dir) + compiler_options = { + 'xla_dump_to': dump_dir, + # FUSION, see https://github.com/openxla/xla/issues/22459 + 'xla_gpu_enable_command_buffer': 1, + 'xla_gpu_graph_min_graph_size': 1, + } + @partial( + jax.jit, + compiler_options=compiler_options, + ) + def f(x): + return x * 2 + + x = jnp.arange(1) + expected = x * 2 + + # This is ugly, but it does not seem possible to get the AutoPGLE-recompiled + # executable text (.lower(x).compile().as_text() or similar). + def get_new_files(): + additions = set(os.listdir(dump_dir)) - get_new_files.seen_files + get_new_files.seen_files |= additions + new_files = list(filter(lambda f: f.endswith('debug_options'), additions)) + assert len(new_files) == 1 + with open(os.path.join(dump_dir, new_files[0])) as ifile: + return ifile.read() + + get_new_files.seen_files = set() + + # Run 1 + self.assertArraysEqual(f(x), expected) + self.assertNotIn( + 'xla_gpu_enable_command_buffer: 1', get_new_files() + ) # b/376647494 workaround + # Run 2 + self.assertArraysEqual(f(x), expected) + self.assertIn( + 'xla_gpu_enable_command_buffer', get_new_files() + ) # workaround disabled + + api.clear_caches() + pjit._pgle_profiler_dict.clear() + + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pickle_test.py b/tests/pickle_test.py index 185eebd90726..3f0f2c70de52 100644 --- a/tests/pickle_test.py +++ b/tests/pickle_test.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import pickle +import sys import unittest from absl.testing import absltest @@ -25,9 +27,12 @@ import jax from jax import numpy as jnp -from jax.interpreters import pxla +from jax._src import config +from jax._src import literals from jax._src import test_util as jtu +from jax._src.interpreters import pxla from jax._src.lib import xla_client as xc +from jax._src.sharding_impls import GSPMDSharding import numpy as np @@ -76,6 +81,13 @@ def g(z): @unittest.skipIf(cloudpickle is None, "Requires cloudpickle") def testPickleOfPmappedFunctions(self): + if config.pmap_shmap_merge.value: + self.skipTest( + 'Nested pmaps are not relevant for `pmap_shmap_merge=True` and' + ' `pmap`s pickled prior to `pmap_shmap_merge=True` may not work, but' + " perhaps it's worth making sure that freshly pickled `pmap`s still" + ' work?' + ) @jax.pmap def f(x, y): @@ -113,6 +125,8 @@ def testPickleOfArrayWeakType(self): self.assertIsInstance(y, type(x)) self.assertEqual(x.aval, y.aval) + @unittest.skipIf(sys.version_info[:2] == (3, 11), + "cannot pickle: b/470129766") @jtu.sample_product(prng_name=['threefry2x32', 'rbg', 'unsafe_rbg']) def testPickleOfKeyArray(self, prng_name): with jax.default_prng_impl(prng_name): @@ -138,11 +152,11 @@ def testPickleOfPartitionSpecs(self, partition_spec): self.assertEqual(partition_spec, restored_partition_spec) def testPickleX64(self): - with jax.experimental.enable_x64(): + with jax.enable_x64(True): x = jnp.array(4.0, dtype='float64') s = pickle.dumps(x) - with jax.experimental.disable_x64(): + with jax.enable_x64(False): y = pickle.loads(s) self.assertEqual(x.dtype, jnp.float64) @@ -174,6 +188,19 @@ def test_pickle_single_device_sharding(self): s = jax.sharding.SingleDeviceSharding(jax.devices()[0]) self.assertEqual(s, pickle.loads(pickle.dumps(s))) + def test_pickle_single_device_sharding_with_memory_kind(self): + for memory_kind in ( + *[memory.kind for memory in jax.devices()[0].addressable_memories()], + None, + ): + with self.subTest(memory_kind=memory_kind): + s = jax.sharding.SingleDeviceSharding( + jax.devices()[0], memory_kind=memory_kind + ) + self.assertEqual(s, pickle.loads(pickle.dumps(s))) + + @jtu.ignore_warning(category=DeprecationWarning, + message='jax.sharding.PmapSharding is deprecated') def test_pickle_pmap_sharding(self): ss = pxla.ShardingSpec( sharding=(pxla.Unstacked(8),), @@ -182,16 +209,55 @@ def test_pickle_pmap_sharding(self): self.assertEqual(s, pickle.loads(pickle.dumps(s))) def test_pickle_gspmd_sharding(self): - s = jax.sharding.GSPMDSharding.get_replicated(jax.devices()) + s = GSPMDSharding.get_replicated(jax.devices()) self.assertEqual(s, pickle.loads(pickle.dumps(s))) + def test_pickle_gspmd_sharding_with_memory_kind(self): + for memory_kind in ( + *[memory.kind for memory in jax.devices()[0].addressable_memories()], + None, + ): + with self.subTest(memory_kind=memory_kind): + s = GSPMDSharding.get_replicated(jax.devices(), memory_kind=memory_kind) + self.assertEqual(s, pickle.loads(pickle.dumps(s))) + @unittest.skipIf(cloudpickle is None, "Requires cloudpickle") def test_pickle_named_sharding(self): s = jax.sharding.NamedSharding( mesh=jax.sharding.Mesh(np.array(jax.devices()), 'd'), - spec=jax.sharding.PartitionSpec('d')) + spec=jax.sharding.PartitionSpec('d'), + ) self.assertEqual(s, pickle.loads(pickle.dumps(s))) + @unittest.skipIf(cloudpickle is None, 'Requires cloudpickle') + def test_pickle_named_sharding_with_memory_kind(self): + for memory_kind in ( + *[memory.kind for memory in jax.devices()[0].addressable_memories()], + None, + ): + with self.subTest(memory_kind=memory_kind): + s = jax.sharding.NamedSharding( + mesh=jax.sharding.Mesh(np.array(jax.devices()), 'd'), + spec=jax.sharding.PartitionSpec('d'), + memory_kind=memory_kind, + ) + self.assertEqual(s, pickle.loads(pickle.dumps(s))) + + def test_pickle_typed_scalar(self): + for l in [ + literals.TypedInt(3, np.dtype(np.int32)), + literals.TypedFloat(2.0, np.dtype(np.float32)), + literals.TypedComplex(1j, np.dtype(np.complex64)), + ]: + m = pickle.loads(pickle.dumps(l)) + self.assertEqual(type(l), type(m)) + self.assertEqual(l, m) + self.assertEqual(l.dtype, m.dtype) + + n = copy.deepcopy(l) + self.assertEqual(type(l), type(n)) + self.assertEqual(l, n) + self.assertEqual(l.dtype, n.dtype) if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 293b37a9fbc7..8b581fa7538f 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -13,13 +13,12 @@ # limitations under the License. from collections import OrderedDict, namedtuple +import itertools import re -from functools import partial -import logging +from functools import partial, wraps import json import math import textwrap -import threading import unittest from absl.testing import absltest @@ -30,9 +29,11 @@ import jax import jax.numpy as jnp +from jax import reshard from jax._src import core from jax._src import config from jax._src import dispatch +from jax._src import literals from jax._src import test_util as jtu from jax._src import dtypes from jax import stages @@ -40,29 +41,29 @@ from jax._src.lax import lax as lax_internal from jax.lax import with_sharding_constraint from jax._src import prng -from jax.sharding import PartitionSpec as P, Mesh +from jax._src import lib +from jax.sharding import (PartitionSpec as P, Mesh, auto_axes, explicit_axes, + AbstractDevice) from jax.experimental import multihost_utils -from jax.experimental.shard_map import shard_map +from jax._src.shard_map import shard_map from jax._src.compilation_cache import is_persistent_cache_enabled -from jax.experimental.custom_partitioning import ( - custom_partitioning, SdyShardingRule, BATCHING) +from jax.experimental import primal_tangent_dtype from jax._src import array from jax._src.sharding import Sharding, common_devices_indices_map from jax._src import op_shardings from jax._src import sharding_impls from jax._src.sharding_impls import ( - AUTO, UNSPECIFIED, NamedSharding, GSPMDSharding, PositionalSharding, + AUTO, UNSPECIFIED, NamedSharding, GSPMDSharding, SingleDeviceSharding, parse_flatten_op_sharding) -from jax._src.pjit import (pjit, mesh_cast, auto_axes, explicit_axes, - use_auto_axes, use_explicit_axes, reshard) +from jax._src.mesh import use_abstract_mesh +from jax._src.pjit import pjit, _pjit_lower +from jax._src.layout import Format, Layout as DLL from jax._src.named_sharding import DuplicateSpecError from jax._src import mesh as mesh_lib from jax._src.mesh import AxisType from jax._src.interpreters import pxla -from jax._src.lib.mlir import dialects -from jax._src import xla_bridge from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension +from jax._src.lib import ifrt_version from jax._src.util import curry, unzip2 config.parse_flags_with_absl() @@ -74,18 +75,14 @@ def create_array(global_shape, global_mesh, mesh_axes, global_data=None, if global_data is None: global_data = np.arange( math.prod(global_shape), dtype=dtype).reshape(global_shape) - - if isinstance(mesh_axes, Sharding): - sharding = mesh_axes - else: - sharding = NamedSharding(global_mesh, mesh_axes) - + sharding = (mesh_axes if isinstance(mesh_axes, Sharding) else + NamedSharding(global_mesh, mesh_axes)) return array.make_array_from_callback( global_shape, sharding, lambda idx: global_data[idx]), global_data -def _check_instance(self, x): - self.assertIsInstance(x, array.ArrayImpl) +def spec_regex(s): + return str(s).replace(r"(", r"\(").replace(r")", r"\)") @curry @@ -115,7 +112,7 @@ def f(x): actual = f(x) expected = x self.assertAllClose(actual, expected, check_dtypes=False) - _check_instance(self, actual) + self.assertIsInstance(actual, array.ArrayImpl) self.assertLen(actual.addressable_shards, 1) self.assertAllClose( np.asarray(actual.addressable_shards[0].data), expected, check_dtypes=False) @@ -135,7 +132,7 @@ def f(x, y): actual = f(x, x + 1) expected = x + (x + 1) self.assertAllClose(actual, expected, check_dtypes=False) - _check_instance(self, actual) + self.assertIsInstance(actual, array.ArrayImpl) self.assertLen(actual.addressable_shards, 2) self.assertAllClose(np.asarray(actual.addressable_shards[0].data), expected, check_dtypes=False) @@ -171,7 +168,7 @@ def f(x, y): actual = f(x, x + 1) expected = x + (x + 1) self.assertAllClose(actual[:3], expected[:3], check_dtypes=False) - _check_instance(self, actual) + self.assertIsInstance(actual, array.ArrayImpl) self.assertLen(actual.addressable_shards, 2) self.assertAllClose(np.asarray(actual.addressable_shards[0].data)[:3], expected[:3], check_dtypes=False) @@ -190,7 +187,7 @@ def f(x, y): expected = x + (x + 1) self.assertEqual(mesh, jtu.create_mesh((2,), ('x'))) self.assertAllClose(actual, expected, check_dtypes=False) - _check_instance(self, actual) + self.assertIsInstance(actual, array.ArrayImpl) self.assertLen(actual.addressable_shards, 2) self.assertAllClose(np.asarray(actual.addressable_shards[0].data), expected, check_dtypes=False) @@ -210,7 +207,7 @@ def f(x, y): actual = f(x, y) expected = x @ y self.assertAllClose(actual, expected, check_dtypes=False) - _check_instance(self, actual) + self.assertIsInstance(actual, array.ArrayImpl) self.assertLen(actual.addressable_shards, 4) split0, split1 = np.split(expected, 2) @@ -279,7 +276,7 @@ def f(x, y): actual = f(x, x + 1) expected = x @ (x + 1) self.assertAllClose(actual, expected, check_dtypes=False) - _check_instance(self, actual) + self.assertIsInstance(actual, array.ArrayImpl) self.assertLen(actual.addressable_shards, 4) splits = np.split(expected, 4) @@ -432,6 +429,30 @@ def f(x): f(x) self.assertNotDeleted(x) + @jtu.run_on_devices('tpu', 'cpu', 'gpu') + def testBufferDonationDifferentIOShapes(self): + mesh = jtu.create_mesh((2,), 'x') + + s1 = NamedSharding(mesh, P('x')) + s2 = NamedSharding(mesh, P(None, 'x', None)) + + x = jax.device_put(np.arange(16), s1) + y = jax.device_put(np.arange(16).reshape(16, 1), s1) + z = jax.device_put(np.arange(16).reshape(2, 2, 4), s1) + + @partial( + jax.jit, + out_shardings=(s1, s1, s2), + donate_argnames=('x', 'y', 'z'), + ) + def f(x, y, z): + return x, jnp.reshape(y, (16,)), z + + f(x, y, z) + self.assertDeleted(x) + self.assertDeleted(y) + self.assertDeleted(z) + @jtu.run_on_devices('tpu', 'cpu', 'gpu') def testBufferDonationMixedConstrainedness(self): mesh = jtu.create_mesh((2,), 'x') @@ -462,7 +483,7 @@ def f(x): expected = (x + 1) * 2 actual = f(x) self.assertAllClose(actual, expected, check_dtypes=False) - _check_instance(self, actual) + self.assertIsInstance(actual, array.ArrayImpl) self.assertLen(actual.addressable_shards, 2) self.assertAllClose(np.asarray(actual.addressable_shards[0].data), expected, check_dtypes=False) @@ -505,8 +526,6 @@ def f(x): self.assertIn("sharding={replicated}", hlo.as_hlo_text()) def testShardingConstraintWithArrayOpSharding(self): - if config.use_shardy_partitioner.value: - self.skipTest("Shardy doesn't support PositionalSharding") shape = (8, 8) mesh = jtu.create_mesh((2, 1), ('x', 'y')) s = NamedSharding(mesh, P(None)) @@ -644,7 +663,7 @@ def testNested(self): x = jnp.arange(16.).reshape((4, 4)) y = g(x) self.assertAllClose(y, jnp.sin(x).sum() + h.sum()) - _check_instance(self, y) + self.assertIsInstance(y, array.ArrayImpl) @check_1d_2d_mesh(set_mesh=True) def testAutodiff(self, mesh, resources): @@ -750,7 +769,7 @@ def testVMapShardingConstraint(self): self.assertTrue(op.is_tiled()) self.assertListEqual(op.tile_assignment_dimensions(), [1, 2]) self.assertListEqual(op.tile_assignment_devices(), [0, 1]) - self.assertFalse(op_shardings.is_op_sharding_replicated(op)) + self.assertFalse(op_shardings.is_hlo_sharding_replicated(op)) @jtu.with_mesh([('x', 2)]) def testVMapShardingConstraintWithSpmdAxis(self): @@ -770,7 +789,7 @@ def testVMapShardingConstraintWithSpmdAxis(self): self.assertTrue(op.is_tiled()) self.assertListEqual(op.tile_assignment_dimensions(), [2, 1]) self.assertListEqual(op.tile_assignment_devices(), [0, 1]) - self.assertFalse(op_shardings.is_op_sharding_replicated(op)) + self.assertFalse(op_shardings.is_hlo_sharding_replicated(op)) @jtu.with_mesh([('x', 2)]) def testLowerWithDuckTyping(self): @@ -800,138 +819,6 @@ def f(inp1): f_com = f_low.compile() f_low.donate_argnums == f_com.donate_argnums == (0,) - @unittest.skip('Fails in OSS builds on GPU with jax at HEAD and latest ' - 'jaxlib on pypi.') - def testInfeed(self): - devices = np.array(jax.local_devices()) - nr_devices = len(devices) - shape = (nr_devices * 3, nr_devices * 5) - - def f_for_jit(x): - token = lax.create_token(x) - (y,), token = lax.infeed( - token, shape=(core.ShapedArray(x.shape, np.float32),)) - (z,), token = lax.infeed( - token, shape=(core.ShapedArray(x.shape, np.float32),)) - (w,), token = lax.infeed( - token, shape=(core.ShapedArray(x.shape, np.float32),)) - - return x + y + z + w - - x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) - y = x * 2. - z = x * 3. - w = x * 4. - - # Transfer data to infeed before executing the function. For GPUs, the - # execution of the compiled function is blocking, so transferring data - # to infeed before executing ensures that the execution does not deadlock - # waiting for the infeed data. - logging.info('Transferring to infeed for the jit call') - d = devices[0] - d.transfer_to_infeed((y,)) - d.transfer_to_infeed((z,)) - d.transfer_to_infeed((w,)) - - # JIT - logging.info('Making jit call') - res0 = jax.jit(f_for_jit)(x) - self.assertAllClose(res0, x + y + z + w, check_dtypes=True) - - # PJIT - def f_for_pjit(x): - token = lax.create_token(x) - # A replicated infeed - (y,), token = lax.infeed( - token, - shape=(core.ShapedArray(x.shape, np.float32),), - partitions=(None,)) - # An infeed sharded on first axis - (z,), token = lax.infeed( - token, - shape=(core.ShapedArray(x.shape, np.float32),), - partitions=(P(nr_devices, 1),)) - # An infeed sharded on second axis - (w,), token = lax.infeed( - token, - shape=(core.ShapedArray(x.shape, np.float32),), - partitions=(P(1, nr_devices),)) - return x + y + z + w - - logging.info('Transferring to infeed for the pjit call') - for didx, d in enumerate(devices): - # Transfer the whole array to all devices for replicated. - d.transfer_to_infeed((y,)) - # For sharded infeed, transfer only the needed slices to each device. - d.transfer_to_infeed(z[3 * didx:3 * didx + 3, :]) - d.transfer_to_infeed((w[:, 5 * didx:5 * didx + 5],)) - - with jax.sharding.Mesh(devices, ['d']): - logging.info('Making pjit call') - res = pjit(f_for_pjit, in_shardings=(P('d'),), out_shardings=P('d'))(x) - - self.assertAllClose(res0, res, check_dtypes=True) - - def testOutfeed(self): - if xla_bridge.using_pjrt_c_api(): - raise unittest.SkipTest('outfeed not implemented in PJRT C API') - if config.use_shardy_partitioner.value: - self.skipTest( - 'b/355263220: outfeed lowering not supported by Shardy') - - devices = np.array(jax.local_devices()) - nr_devices = len(devices) - shape = (nr_devices * 3, nr_devices * 5) - - def f(x): - token = lax.create_token(x) - token = lax.outfeed(token, x, partitions=(None,)) - token = lax.outfeed(token, x, partitions=((nr_devices, 1),)) - token = lax.outfeed(token, x, partitions=((1, nr_devices),)) - return x - - x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) - - def _dispatch(): - with jax.sharding.Mesh(devices, ['d']): - logging.info('Making pjit call') - pjit(f, in_shardings=(P('d'),), out_shardings=P('d'))(x) - execution = threading.Thread(target=_dispatch) - execution.start() - - # Check the expected outfeed for all devices. - def check_outfeed(x_fn): - for didx, d in enumerate(devices): - x = x_fn(didx) - y, = d.transfer_from_outfeed( - xc.shape_from_pyval((x,)).with_major_to_minor_layout_if_absent()) - self.assertAllClose(x, y, check_dtypes=True) - - logging.info('Transferring from outfeed for the pjit call') - - # Note, when checking results of multiple outfeeds, the loop structure - # should be such that we check a given outfeed for all devices before - # moving on to the next outfeed. If there are any collectives generated - # by pjit, a loop structutre like: - # for each device: - # check outfeed#0; - # check outfeed#1; - # - # Could cause a deadlock if there is a collective scheduled between the - # 2 outfeeds, as device #0, after processing outfeed#0 will execute the - # collective, waiting for other devices to join, but other devices won't - # execute their collective until their outfeed#0 is executed. This is - # because, for GPU for example, execution of an outfeed on GPU is blocked - # till the corresponding `transfer_from_outfeed` is executed on the host. - - # Transfer the whole array from all devices for replicated. - check_outfeed(lambda didx: x) - # For sharded outfeed, the results are sliced. - check_outfeed(lambda didx: x[3 * didx:3 * didx + 3, :]) - check_outfeed(lambda didx: x[:, 5 * didx:5 * didx + 5]) - - execution.join() - @jtu.with_mesh([('x', 2)]) def testWithCustomPRNGKey(self): if not config.enable_custom_prng.value: @@ -940,16 +827,27 @@ def testWithCustomPRNGKey(self): # Make sure this doesn't crash pjit(lambda x: x, in_shardings=None, out_shardings=None)(key) + def test_lower_with_wrapper_error(self): + @jax.jit + def f(x): + return x + + self.assertAllClose(1., f(1.)) + self.assertAllClose(1., f.lower(1.).compile()(1.)) + wrapped_f = wraps(f)(lambda x: f(x + 1)) + + with self.assertRaisesRegex(AttributeError, "has no attribute 'lower'"): + wrapped_f.lower(1.) + @jtu.with_mesh([('x', 2), ('y', 2)]) def testLowerCompile(self): - @partial(pjit, - in_shardings=P(('x', 'y'),), - out_shardings=P(('x', 'y'),)) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + x = jnp.arange(64).reshape(8, 8) + + @partial(pjit, in_shardings=P(('x', 'y')), out_shardings=P(('x', 'y'))) def f(x, y): return x @ y - shape = (8, 8) - x = jnp.arange(math.prod(shape)).reshape(shape) expected = x @ (x + 1) lowered = f.lower(x, x + 1) @@ -957,9 +855,11 @@ def f(x, y): actual = compiled(x, x + 1) self.assertEqual(lowered.in_avals, compiled.in_avals) - self.assertEqual( - lowered.in_avals, - ((core.ShapedArray(x.shape, x.dtype, weak_type=False),) * 2, {})) + + abs_mesh = mesh.abstract_mesh + exp_aval = core.ShapedArray(x.shape, x.dtype, + sharding=NamedSharding(abs_mesh, P())) + self.assertEqual(lowered.in_avals, ((exp_aval,) * 2, {})) splits = np.split(expected, 4) self.assertAllClose(np.asarray(actual.addressable_shards[0].data), splits[0], @@ -1169,7 +1069,7 @@ def f(x, y): return x @ y shape = (8, 8) - aval = core.ShapedArray(shape, dtypes.canonicalize_dtype(jnp.int64)) + aval = core.ShapedArray(shape, dtypes.default_int_dtype()) x = jnp.arange(math.prod(shape)).reshape(shape) exe = f.lower(aval, x).compile() self.assertIsInstance(exe, stages.Compiled) @@ -1220,11 +1120,11 @@ def test_pretty_print(self): textwrap.dedent(""" let lambda = { lambda ; a:f32[1]. let b:f32[1] = integer_pow[y=2] a in (b,) } in { lambda ; c:f32[1]. let - d:f32[1] = pjit[ + d:f32[1] = jit[ name= jaxpr={ lambda ; c:f32[1]. let - e:f32[1] = pjit[name= jaxpr=lambda] c - f:f32[1] = pjit[name= jaxpr=lambda] c + e:f32[1] = jit[name= jaxpr=lambda] c + f:f32[1] = jit[name= jaxpr=lambda] c d:f32[1] = add e f in (d,) } ] c @@ -1240,9 +1140,12 @@ def test_pretty_print_pjit_id(self): jaxpr.pretty_print(use_color=False), textwrap.dedent(""" { lambda ; a:f32[1]. let - pjit[name= jaxpr={ lambda ; a:f32[1] b:f32[1]. let in () }] a a - c:f32[1] = add a a - in (c,) } + b:f32[1] = jit[ + name= + jaxpr={ lambda ; a:f32[1] c:f32[1]. let in (a,) } + ] a a + d:f32[1] = add a b + in (d,) } """).strip(), ) @@ -1254,10 +1157,10 @@ def test_pretty_print_with_constant_pjit_arg(self): jaxpr.pretty_print(use_color=False), textwrap.dedent(""" { lambda ; a:f32[1]. let - b:f32[1] = pjit[ + b:f32[1] = jit[ name= jaxpr={ lambda ; a:f32[1] c:f32[]. let b:f32[1] = mul a c in (b,) } - ] a 1.0 + ] a 1.0:f32[] in (b,) } """).strip(), ) @@ -1270,7 +1173,7 @@ def test_pretty_print_with_aliased_args(self): jaxpr.pretty_print(use_color=False), textwrap.dedent(""" { lambda ; a:f32[1]. let - b:f32[1] = pjit[ + b:f32[1] = jit[ name= jaxpr={ lambda ; a:f32[1] c:f32[1] d:f32[1]. let e:f32[1] = mul a c @@ -1289,8 +1192,11 @@ def test_pretty_print_with_literal_outvar(self): jaxpr.pretty_print(use_color=False), textwrap.dedent(""" { lambda ; a:f32[1]. let - b:i32[] = pjit[name= jaxpr={ lambda ; a:f32[1]. let in (2,) }] a - in (b, a) } + b:i32[] c:f32[1] = jit[ + name= + jaxpr={ lambda ; a:f32[1]. let in (2:i32[], a) } + ] a + in (b, c) } """).strip(), ) @@ -1309,11 +1215,11 @@ def f(x): textwrap.dedent(""" let f = { lambda ; a:f32[1] b:f32[1]. let c:f32[1] = mul b a in (c,) } in { lambda ; d:f32[1] e:f32[1]. let - g:f32[1] = pjit[ + g:f32[1] = jit[ name=g jaxpr={ lambda ; d:f32[1] e:f32[1]. let - h:f32[1] = pjit[name=f jaxpr=f] e d - i:f32[1] = pjit[name=f jaxpr=f] e e + h:f32[1] = jit[name=f jaxpr=f] e d + i:f32[1] = jit[name=f jaxpr=f] e e g:f32[1] = add h i in (g,) } ] d e @@ -1336,19 +1242,19 @@ def f(x): self.assertEqual( jaxpr.pretty_print(use_color=False), textwrap.dedent(""" - let f = { lambda ; a:f32[1]. let in () } in - let f1 = { lambda ; b:f32[2]. let in () } in + let f = { lambda ; a:f32[1]. let in (a,) } in + let f1 = { lambda ; b:f32[2]. let in (b,) } in { lambda ; c:f32[1] d:f32[2]. let - e:f32[2] = pjit[ + e:f32[2] = jit[ name=g jaxpr={ lambda ; c:f32[1] d:f32[2]. let - pjit[name=f jaxpr=f] c - pjit[name=f jaxpr=f] c - g:f32[1] = mul c c - pjit[name=f jaxpr=f1] d - pjit[name=f jaxpr=f1] d - h:f32[2] = mul d d - e:f32[2] = add g h + g:f32[1] = jit[name=f jaxpr=f] c + h:f32[1] = jit[name=f jaxpr=f] c + i:f32[1] = mul g h + j:f32[2] = jit[name=f jaxpr=f1] d + k:f32[2] = jit[name=f jaxpr=f1] d + l:f32[2] = mul j k + e:f32[2] = add i l in (e,) } ] c d in (e,) } @@ -1394,332 +1300,14 @@ def test_zero_literal_equality(self): self.assertIn("stablehlo.constant dense<0.000000e+00>", ir) self.assertIn("stablehlo.constant dense<-0.000000e+00>", ir) -@jtu.pytest_mark_if_available('multiaccelerator') -class CustomPartitionerTest(jtu.JaxTestCase): - - def skip_if_custom_partitioning_not_supported(self): - if jtu.is_cloud_tpu(): - raise unittest.SkipTest("Custom partitioning is not supported on libtpu.") - - @jtu.skip_on_devices('cpu') # Collectives don't seem to work on CPU. - @jtu.with_mesh([('x', 4), ('y', 2)]) - def test_custom_partitioner(self): - self.skip_if_custom_partitioning_not_supported() - - def partition(precision, mesh, arg_shapes, result_shape): - arg_shardings = jax.tree.map(lambda s: s.sharding, arg_shapes) - result_sharding = result_shape[0].sharding - self.assertEqual(arg_shardings[0], result_sharding) - self.assertEqual(P('x', None), result_sharding.spec) - self.assertEqual(P('y', None), arg_shardings[1].spec) - - def lower_fn(x, y): - axis_name = arg_shardings[1].spec[0][0] - i = jax.lax.axis_index(axis_name) - # Use offset i * 0 instead of 0 to ensure that the two offsets have the - # same dtype regardless the value of config.enable_x64. - z = jax.lax.psum( - jax.lax.dynamic_slice(x, (i * 0, i * 8), (8, 8)) @ y, (axis_name) - ) - return z, z * z - - return mesh, lower_fn, (result_sharding, result_sharding), arg_shardings - - def infer_sharding_from_operands(precision, mesh, arg_shapes, result_shape): - arg_shardings = jax.tree.map(lambda s: s.sharding, arg_shapes) - x_shard, y_shard = arg_shardings - x_shape, y_shape = arg_shapes - x_names = tuple(x_shard.spec) + tuple( - None for _ in range(len(x_shape.shape) - len(x_shard.spec))) - y_names = tuple(y_shard.spec) + tuple( - None for _ in range(len(y_shape.shape) - len(y_shard.spec))) - z_shard = NamedSharding(y_shard.mesh, P(*(x_names[:-1] + y_names[1:]))) - return z_shard, z_shard - - @partial(custom_partitioning, static_argnums=(2,)) - def f(x, y, precision=None): - z = jnp.matmul(x, y, precision=precision) - return z, z * z - - f.def_partition( - infer_sharding_from_operands=infer_sharding_from_operands, - partition=partition, - sharding_rule=SdyShardingRule(operand_mappings=(('i', 'j'), ('j', 'k')), result_mappings=(('i', 'k'), ('i', 'k')))) - - pjit_f = pjit(f, in_shardings=(P('x'), P('y')), out_shardings=P('x')) - x = np.asarray(np.random.randint(0, 20, (32, 16)), dtype=np.float32) - y = np.asarray(np.random.randint(0, 20, (16, 32)), dtype=np.float32) - result1 = jax.jit(f)(x, y) - result2 = f(x, y) - result0 = pjit_f(x, y) - self.assertArraysEqual(result0, result1) - self.assertArraysEqual(result1, result2) - - @jtu.with_mesh([('x', 4), ('y', 2)]) - def test_custom_partitioner_propagate_user_sharding(self): - self.skip_if_custom_partitioning_not_supported() - - def partition(mesh, arg_shapes, result_shape): - def lower_fn(x): - return x - - return ( - mesh, - lower_fn, - arg_shapes[0].sharding, - (arg_shapes[0].sharding,), - ) - - def infer_sharding_from_operands(mesh, arg_shapes, result_shape): - return arg_shapes[0].sharding - - def propagate_user_sharding(mesh, user_shape): - return user_shape.sharding - - @custom_partitioning - def f(x): - return x - - f.def_partition( - infer_sharding_from_operands=infer_sharding_from_operands, - partition=partition, - propagate_user_sharding=propagate_user_sharding, - sharding_rule='i j -> i j', - ) - - def f2(a): - return a + f(a) - - pjit_f = pjit(f2, in_shardings=(P(None, 'x')), out_shardings=P('x')) - x = np.asarray(np.random.randint(0, 20, (32, 16)), dtype=np.float32) - self.assertArraysEqual(x + x, pjit_f(x)) - - @jtu.with_mesh([('x', 4), ('y', 2)]) - def test_custom_partitioner_sharding_override(self): - self.skip_if_custom_partitioning_not_supported() - - def partition(mesh, arg_shapes, result_shape): - def lower_fn(x): - return x - - y_shard = arg_shapes[0].sharding - return ( - mesh, - lower_fn, - NamedSharding(y_shard.mesh, P(None)), - (NamedSharding(y_shard.mesh, P(None)),), - ) - - def infer_sharding_from_operands(mesh, arg_shapes, result_shape): - y_shard = arg_shapes[0].sharding - return NamedSharding(y_shard.mesh, P('x')) - - @custom_partitioning - def f(x): - return x - - f.def_partition( - infer_sharding_from_operands=infer_sharding_from_operands, - partition=partition, - sharding_rule=SdyShardingRule(operand_mappings=((BATCHING, 'i'),), result_mappings=((BATCHING, 'i'),))) - - pjit_f = pjit(f, in_shardings=(P(None, 'x')), out_shardings=P('x')) - x = np.asarray(np.random.randint(0, 20, (32, 16)), dtype=np.float32) - self.assertArraysEqual(x, pjit_f(x)) - - @jtu.with_mesh([('x', 4), ('y', 2)]) - def test_custom_partitioner_invalid_sharding(self): - self.skip_if_custom_partitioning_not_supported() - def partition(mesh, arg_shapes, result_shape): - def lower_fn(x): - return x - - y_shard = arg_shapes[0].sharding - return ( - mesh, - lower_fn, - NamedSharding(y_shard.mesh, P(None)), - (NamedSharding(y_shard.mesh, P(None, 'x')),), - ) - - def infer_sharding_from_operands(mesh, arg_shapes, result_shape): - y_shard = arg_shapes[0].sharding - return NamedSharding(y_shard.mesh, P('x')) - - @custom_partitioning - def f(x): - return x - - f.def_partition( - infer_sharding_from_operands=infer_sharding_from_operands, - partition=partition, - sharding_rule='i j -> i j', - ) - - pjit_f = pjit(f, in_shardings=(P(None, 'x')), out_shardings=P('x')) - x = np.asarray(np.random.randint(0, 20, (32, 16)), dtype=np.float32) - - with self.assertRaisesRegex(Exception, 'Mismatch in result shapes.'): - pjit_f(x).block_until_ready() - - @jtu.with_mesh([('x', 4)]) - def test_custom_partitioner_jit_annotated_function(self): - """Test correct lowering of function with a @jax.jit annotated callee. - - Annotating a callee with @jax.jit results in a module with a HLO CallOp. - This test is makes sure that the custom partitioner lowering supports - CallOps. - """ - - self.skip_if_custom_partitioning_not_supported() - - @custom_partitioning - def f(x): - return x - - def partition(mesh, arg_shapes, result_shape): - def lower_fn(x): - @jax.jit - def g(y): - return y - - return g(x) - - x_shard = arg_shapes[0].sharding - return ( - mesh, - lower_fn, - NamedSharding(x_shard.mesh, P('x')), - (NamedSharding(x_shard.mesh, P('x')),), - ) - - def infer_sharding_from_operands(mesh, arg_shapes, result_shape): - x_shard = arg_shapes[0].sharding - return NamedSharding(x_shard.mesh, P('x')) - - f.def_partition( - infer_sharding_from_operands=infer_sharding_from_operands, - partition=partition, - sharding_rule='i -> i', - ) - - jit_f = jax.jit(f) - x = np.asarray(np.random.randint(0, 20, (32,)), dtype=np.float32) - pjit_f = pjit(jit_f, in_shardings=(P('x')), out_shardings=P('x')) - self.assertArraysEqual(x, pjit_f(x)) - - @jtu.with_mesh([('x', 4)]) - def test_custom_partitioner_with_scan(self): - self.skip_if_custom_partitioning_not_supported() - - # This is a reproducer from https://github.com/jax-ml/jax/issues/20864. - - @custom_partitioning - def f(x): - return jnp.sum(x) - - def partition(mesh, arg_shapes, result_shape): - def lower_fn(xs): - def f(carry, x): - return carry + jax.lax.psum(jnp.sum(x), axis_name='x'), None - - carry, _ = jax.lax.scan(f, 0, xs) - return carry - - result_shardings = jax.tree.map(lambda x: x.sharding, result_shape) - arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes) - return mesh, lower_fn, result_shardings, arg_shardings - - f.def_partition( - partition, - infer_sharding_from_operands=lambda mesh, *_: NamedSharding(mesh, P()), - propagate_user_sharding=lambda _, user_shape: user_shape.sharding, - sharding_rule='i j -> ') # Result is a scalar. - - pjit_f = pjit(f, in_shardings=P(None, 'x')) - xs = jnp.ones([32, 16]) - self.assertEqual(pjit_f(xs), xs.sum()) - - def test_custom_partitioning_no_mesh_context(self): - self.skip_if_custom_partitioning_not_supported() - - @custom_partitioning - def f(x): - return x - - def partition(mesh, arg_shapes, result_shape): - def lower_fn(x): - @jax.jit - def g(y): - return y - - return g(x) - - x_shard = arg_shapes[0].sharding - return ( - mesh, - lower_fn, - NamedSharding(x_shard.mesh, P('x')), - (NamedSharding(x_shard.mesh, P('x')),), - ) - - def infer_sharding_from_operands(mesh, arg_shapes, result_shape): - x_shard = arg_shapes[0].sharding - return NamedSharding(x_shard.mesh, P('x')) - - f.def_partition( - infer_sharding_from_operands=infer_sharding_from_operands, - partition=partition, - sharding_rule='i -> i', - ) - - mesh = jtu.create_mesh((4,), ('x',)) - x = np.asarray(np.random.randint(0, 20, (32,)), dtype=np.float32) - s = NamedSharding(mesh, P('x')) - - jit_f = jax.jit(f, in_shardings=s, out_shardings=s) - self.assertArraysEqual(x, jit_f(x)) - - @jtu.with_mesh([('x', 4), ('y', 2)]) - def test_custom_partitioner_pytree_inputs(self): - self.skip_if_custom_partitioning_not_supported() - - def partition(mesh, arg_shapes, result_shape): - def lower_fn(xs): - x, y, z = xs - return x + y + z - - return ( - mesh, - lower_fn, - arg_shapes[0][0].sharding, - jax.tree.map(lambda x: x.sharding, arg_shapes), - ) - - def infer_sharding_from_operands(mesh, arg_shapes, result_shape): - return arg_shapes[0][0].sharding - - def propagate_user_sharding(mesh, user_shape): - return user_shape.sharding - - @custom_partitioning - def f(xs): - x, y, z = xs - return x + y + z - - f.def_partition( - infer_sharding_from_operands=infer_sharding_from_operands, - partition=partition, - propagate_user_sharding=propagate_user_sharding, - sharding_rule='i j, i j, i j -> i j', - ) - - def f2(a): - return a + f((a, a, a)) - - pjit_f = pjit(f2, in_shardings=(P(None, 'x')), out_shardings=P('x')) - x = np.asarray(np.random.randint(0, 20, (32, 16)), dtype=np.float32) - self.assertArraysEqual(x * 4, pjit_f(x)) + def test_device_put_copy_donate(self): + x = np.arange(1000) + y = jax.device_put(x, device=jax.devices()[0], may_alias=False, donate=False) + z = jax.device_put(y, device=jax.devices()[0], may_alias=False, donate=False) + a = jax.jit(lambda y: y * 2, donate_argnums=0)(y) + self.assertDeleted(y) + self.assertNotDeleted(z) + self.assertArraysEqual(a, x * 2) @jtu.pytest_mark_if_available('multiaccelerator') @@ -1951,6 +1539,11 @@ def test_pjit_array_single_output_with_mesh_context_manager( self.assertArraysEqual(s.data, expected_matrix_mul[s.index]) self.assertArraysEqual(out._value, expected_matrix_mul) + def test_empty_mesh_to_out_sharding(self): + sharding = jax.NamedSharding(mesh_lib.empty_concrete_mesh, P()) + with self.assertRaisesRegex(ValueError, "got an empty NamedSharding"): + jax.jit(lambda x: x, out_shardings=sharding)(jnp.ones((32,))) + def test_numpy_array_input_assume_fully_replicated(self): input_shape = (8, 2) global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) @@ -2083,8 +1676,8 @@ def test_in_axis_resources_mismatch_error(self): f = pjit(lambda x: x, in_shardings=NamedSharding(global_mesh, P('x'))) err_msg = re.compile( - "Sharding passed to pjit does not match the sharding on the " - r"respective arg.*arg shape.*\[8,2\]", re.M | re.S) + "Sharding passed to jit does not match the sharding on the " + r"respective arg.*arg.*", re.M | re.S) with self.assertRaisesRegex(ValueError, err_msg): f(input_array) @@ -2140,9 +1733,9 @@ def test_array_lower_compile(self): with self.assertRaisesRegex( ValueError, - r"Compiled object called with input sharding.*does not match the " - r"sharding.*the computation was compiled with. " - "Here are.*mismatches.*"): + r"Computation was compiled for input shardings.* that " + r"disagree with the shardings.* of arguments passed to it. " + r"Here are.*mismatches.*"): compiled(a2, a2, a2, a2, a2, a2) with global_mesh: @@ -2154,9 +1747,9 @@ def test_array_lower_compile(self): inp2 = {'x': a2, 'y': {'y1': a2}} with self.assertRaisesRegex( ValueError, - r"Compiled object called with input sharding.*does not match the " - r"sharding.*the computation was compiled with. " - "Here are the.*mismatches"): + r"Computation was compiled for input shardings.* that " + r"disagree with the shardings.* of arguments passed to it. " + r"Here are.*mismatches.*"): compiled(inp2) def test_globally_sharded_key_array_result_8x4_single_device(self): @@ -2252,7 +1845,7 @@ def test_mixed_inputs(self): in_shardings=NamedSharding(global_mesh, P(None))) with self.assertRaisesRegex( ValueError, - ('Sharding passed to pjit does not match the sharding on the ' + ('Sharding passed to jit does not match the sharding on the ' 'respective arg')): f(input_data, a1) @@ -2275,13 +1868,13 @@ def add(x, y): return x + y out = add(a, b) - cache_info1 = pxla._cached_lowering_to_hlo.cache_info() + cache_info1 = _pjit_lower.cache_info() self.assertIsInstance(out, array.ArrayImpl) self.assertArraysEqual(out, a + b) self.assertFalse(out._committed) out2 = add(out, out) - cache_info2 = pxla._cached_lowering_to_hlo.cache_info() + cache_info2 = _pjit_lower.cache_info() self.assertIsInstance(out2, array.ArrayImpl) self.assertArraysEqual(out2, 2 * (a + b)) self.assertFalse(out2._committed) @@ -2291,7 +1884,7 @@ def add(x, y): c = jax.device_put(a, jax.devices()[0]) out3 = add(c, c) - cache_info3 = pxla._cached_lowering_to_hlo.cache_info() + cache_info3 = _pjit_lower.cache_info() self.assertArraysEqual(out3, 2 * c) self.assertTrue(out3._committed) @@ -2313,6 +1906,30 @@ def mul(x): self.assertIsInstance(out, array.ArrayImpl) self.assertArraysEqual(out, a @ a.T) + def test_pjit_mixed_device_cache_miss(self): + try: + tpu_device = jax.devices('tpu')[0] + except: + raise unittest.SkipTest('Only a bug on TPU.') + + with jax.default_device('cpu'): + cpu_arr = jnp.array([1, 2, 3], dtype=np.float32) + tpu_arr = jax.device_put(cpu_arr, tpu_device) + + with jtu.count_pjit_cpp_cache_miss() as count: + for _ in range(2): + np.array(cpu_arr + tpu_arr) + self.assertEqual(count(), 1) + + def test_cpu_bfloat16_to_tpu(self): + mesh = jtu.create_mesh((1,), 'x') + np_inp = np.zeros((8, 2), dtype=jnp.bfloat16) + + arr_cpu = jax.device_put(np_inp, jax.devices('cpu')[0]) + arr_tpu = jax.device_put(arr_cpu, NamedSharding(mesh, P())) + + jax.jit(lambda x: jnp.sum(x))(arr_tpu) # doesn't crash + def test_pjit_single_device_sharding_cache(self): a = jnp.arange(16).reshape((8, 2)) f = pjit(lambda x: x) @@ -2388,9 +2005,7 @@ def test_fast_path_array(self): def test_array_enabled_non_empty_mesh_with_pspec(self): arr = jnp.array([1, 2, 3]) with self.assertRaisesRegex( - RuntimeError, - r'jit requires a non-empty mesh if you are passing `PartitionSpec`s or' - r' `None` to in_shardings.*'): + RuntimeError, r'pjit requires a non-empty mesh in context.*'): pjit(lambda x: x, in_shardings=P('x'))(arr) with self.assertRaisesRegex( @@ -2477,6 +2092,20 @@ def test_pjit_committed_array_different_devices_variadic_args(self): r"\[1\].*"): pjit(lambda *x: x)(a, b) + def test_jit_no_forwarding(self): + mesh = jtu.create_mesh((2,), ('x',)) + + @jax.jit(donate_argnums=(0,)) + def f(x): + return x, x * 2 + + x = jax.device_put(jnp.zeros(64, dtype="int32"), NamedSharding(mesh, P())) + jaxpr = jax.make_jaxpr(f)(x) + y = core.jaxpr_as_fun(jaxpr)(x) + self.assertTrue(x.is_deleted()) + self.assertFalse(y[0].is_deleted()) + self.assertFalse(y[1].is_deleted()) + def test_pjit_pytree_inp_device_assignment_mismatch(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) a = jax.device_put(np.array([1, 2, 3]), jax.devices()[0]) @@ -2606,14 +2235,14 @@ def test_device_put_on_different_sharding(self): def test_with_sharding_constraint_jit(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) - @partial(jax.jit, static_argnums=(0, 1)) + @jax.jit(static_argnums=(0, 1)) def sharded_zeros(shape, pspec): out = jnp.zeros(shape, jnp.bfloat16) return jax.lax.with_sharding_constraint(out, NamedSharding(mesh, pspec)) out = sharded_zeros((4096, 3072), P('x', 'y')) out_s = NamedSharding(mesh, P('x', 'y')) - self.assertTrue(op_shardings.are_op_shardings_equal( + self.assertTrue(op_shardings.are_hlo_shardings_equal( out.sharding._to_xla_hlo_sharding(out.ndim), out_s._to_xla_hlo_sharding(out.ndim))) @@ -2627,7 +2256,7 @@ def sharded_zeros(shape, pspec): out = sharded_zeros((4096, 3072), P('x', 'y')) out_s = NamedSharding(mesh, P('x', 'y')) - self.assertTrue(op_shardings.are_op_shardings_equal( + self.assertTrue(op_shardings.are_hlo_shardings_equal( out.sharding._to_xla_hlo_sharding(out.ndim), out_s._to_xla_hlo_sharding(out.ndim))) @@ -2660,7 +2289,7 @@ def f(x, y, z): ValueError, "Received incompatible devices for jitted computation. Got argument " r"inp1 of.*my_nested_pjit with shape bfloat16\[8,2\] and device ids \[0\].*" - r"pjit inside jit with device ids.*"): + r"jit inside jit with device ids.*"): my_nested_pjit(committed_inp, committed_inp, committed_inp) @jtu.ignore_warning(category=DeprecationWarning, @@ -2668,7 +2297,7 @@ def f(x, y, z): def test_jit_device_with_sharding_constraint_error(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) - @partial(jax.jit, static_argnums=(0, 1), device=jax.devices()[0]) + @jax.jit(static_argnums=(0, 1), device=jax.devices()[0]) def sharded_zeros(shape, pspec): out = jnp.zeros(shape, jnp.bfloat16) return jax.lax.with_sharding_constraint(out, NamedSharding(mesh, pspec)) @@ -2704,6 +2333,16 @@ def _invoke_with_mesh_twice(arg_tuple): for i, x, y in zip(range(n), xs, ys): self.assertAllClose(x + i, y) + def test_wsc_eager_copy(self): + sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0]) + x = jnp.arange(10) + y_wsc = jax.lax.with_sharding_constraint(x, sharding) + y_jit = jax.jit(lambda x: x, out_shardings=sharding)(x) + y_dp = jax.device_put(x, sharding) + self.assertTrue(y_wsc.committed) + self.assertTrue(y_jit.committed) + self.assertTrue(y_dp.committed) + def test_trivial_computation(self): shape = (8, 2) mesh = jtu.create_mesh((2, 2), ('x', 'y')) @@ -3067,6 +2706,25 @@ def f(x, y, z, a, b, c): # pylint: disable=unused-argument self.assertEqual(compiled._executable._kept_var_idx, {5}) self.assertLen(compiled._executable.in_avals, 1) + def test_abstract_device(self): + if not jtu.is_device_tpu(3): + self.skipTest('only works on TPU v3') + + inp = jnp.arange(8) + tpu_mesh = Mesh([jax.devices('tpu')[0]], 'x') + cpu_mesh = Mesh([jax.devices('cpu')[0]], 'x') + + @jax.jit + def f(x): + return x + + with jtu.count_jit_tracing_cache_miss() as tracing_count: + with jax.set_mesh(tpu_mesh): + f(inp) + with jax.set_mesh(cpu_mesh): + f(inp) + self.assertEqual(tracing_count(), 2) # twice for f + def test_pjit_relayout_multi_slice(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) @@ -3214,16 +2872,31 @@ def f(x): return x * 2 jaxpr = jax.make_jaxpr(f)(3) - self.assertIn('pjit', str(jaxpr)) + self.assertIn('jit', str(jaxpr)) @partial(pjit, inline=True) def g(x): return x * 2 jaxpr = jax.make_jaxpr(g)(3) - self.assertNotIn('pjit', str(jaxpr)) + self.assertNotIn('jit', str(jaxpr)) + + def test_pjit_inline_literal(self): + # https://github.com/jax-ml/jax/issues/27545 + def bar(x): + return jnp.array(1) + + def foo(x): + x = pjit(bar, inline=True)(x) + self.assertEqual(x.shape, ()) + + pjit(foo)(0) # doesn't crash + @jtu.ignore_warning(category=DeprecationWarning) def test_pmap_in_axis_resources_error(self): + if config.pmap_shmap_merge.value: + self.skipTest("Test does not raise under pmap_shmap_merge=True") + pmap_out = jax.pmap(lambda x: x)(jnp.arange(jax.device_count())) self.assertIsInstance(pmap_out.sharding, jax.sharding.PmapSharding) @@ -3237,41 +2910,66 @@ def test_pmap_in_axis_resources_error(self): r"One of out_shardings.*got sharding.*which is not allowed."): pjit(lambda x: x, out_shardings=pmap_out.sharding) + @jtu.ignore_warning(category=DeprecationWarning) def test_pmap_sharding_input_to_pjit_single_device(self): pmap_out = jax.pmap(lambda x: x)(jnp.arange(jax.device_count())) - self.assertIsInstance(pmap_out.sharding, jax.sharding.PmapSharding) + if config.pmap_shmap_merge.value: + self.assertIsInstance(pmap_out.sharding, jax.sharding.NamedSharding) + else: + self.assertIsInstance(pmap_out.sharding, jax.sharding.PmapSharding) self.assertLen(pmap_out.devices(), jax.device_count()) out = pjit(lambda x: x * 3)(pmap_out) self.assertArraysEqual(out, pmap_out * 3) - # Even though pmap out is on jax.device_count() number of devices, the - # output will be 1 device since it will be resharded. - self.assertLen(out.devices(), 1) + if config.pmap_shmap_merge.value: + self.assertLen(out.devices(), jax.device_count()) + else: + # Even though pmap out is on jax.device_count() number of devices, the + # output will be 1 device since it will be resharded. + self.assertLen(out.devices(), 1) + @jtu.ignore_warning(category=DeprecationWarning) def test_pmap_sharding_input_to_pjit_multi_device(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) - pmap_out = jax.pmap(lambda x: x)(jnp.arange(jax.device_count())) - self.assertIsInstance(pmap_out.sharding, jax.sharding.PmapSharding) - + pmap_out = jax.pmap(lambda x, y: x, in_axes=(None, 0), out_axes=None)( + jnp.arange(4), jnp.arange(4) + ) inp2 = jnp.arange(4) - with mesh: + if config.pmap_shmap_merge.value: + # NOTE(dsuo): Under `pmap_shmap_merge=True`, the mesh shape used by pmap + # to produce its output must match the mesh shape used by pjit for its + # inputs. Don't use the `mesh` context manager here since it does not + # match pmap's default mesh shape (i.e., (4,)). + self.assertIsInstance(pmap_out.sharding, jax.sharding.NamedSharding) out1, out2 = pjit(lambda x, y: (x * 2, y * 2))(pmap_out, inp2) + else: + self.assertIsInstance(pmap_out.sharding, jax.sharding.PmapSharding) + with mesh: + out1, out2 = pjit(lambda x, y: (x * 2, y * 2))(pmap_out, inp2) self.assertArraysEqual(out1, pmap_out * 2) self.assertArraysEqual(out2, inp2 * 2) self.assertLen(out1.devices(), 4) self.assertLen(out2.devices(), 4) - self.assertTrue(op_shardings.is_op_sharding_replicated( - out1.sharding._to_xla_hlo_sharding(pmap_out.ndim))) - self.assertTrue(op_shardings.is_op_sharding_replicated( - out2.sharding._to_xla_hlo_sharding(inp2.ndim))) + self.assertTrue(out1.is_fully_replicated) + self.assertTrue(out2.is_fully_replicated) + @jtu.ignore_warning(category=DeprecationWarning) def test_pmap_sharding_input_pjit_in_axis_resources(self): - mesh = jtu.create_mesh((2, 2), ('x', 'y')) + if config.pmap_shmap_merge.value: + # NOTE(dsuo): jax.pmap under `pmap_shmap_merge=True` will use this mesh + # shape by default, so we need pjit to have the same mesh shape as the one + # pmap uses. + mesh = jtu.create_mesh((4,), ('x')) + else: + mesh = jtu.create_mesh((2, 2), ('x', 'y')) - pmap_out = jax.pmap(lambda x: x)(jnp.arange(jax.device_count())) - self.assertIsInstance(pmap_out.sharding, jax.sharding.PmapSharding) + pmap_out = jax.pmap(lambda x: x)(jnp.arange(4)) + if config.pmap_shmap_merge.value: + self.assertIsInstance(pmap_out.sharding, jax.sharding.NamedSharding) + else: + self.assertIsInstance(pmap_out.sharding, jax.sharding.PmapSharding) out = pjit(lambda x: x * 2, in_shardings=NamedSharding(mesh, P('x')))(pmap_out) self.assertArraysEqual(out, pmap_out * 2) @@ -3353,9 +3051,7 @@ def f(x): def test_jit_with_mesh_context_manager(self): mesh = jtu.create_mesh((1,), ('x',)) with self.assertRaisesRegex( - RuntimeError, - "jax.jit only supports `Sharding`s being passed to " - "in_shardings"): + RuntimeError, "jit requires a non-empty mesh in context"): with mesh: jax.jit(lambda x: x, in_shardings=P('x'), out_shardings=P('x'))(jnp.arange(8)) @@ -3400,6 +3096,7 @@ def _pmapped_fun(inputs): pjit(_pmapped_fun)(inputs) # doesn't crash jax.jit(_pmapped_fun)(inputs) # doesn't crash + @unittest.skipIf(lib.jaxlib_extension_version < 396, "jaxlib version") @jtu.thread_unsafe_test() # logging is not thread-safe def test_cache_miss_explanations_sharding_mismatch(self): mesh = jtu.create_mesh((2,), ('x',)) @@ -3423,9 +3120,8 @@ def f(x, y): f(x_, y) self.assertLen(cm.output, expected_log_len) msg = cm.output[0] - self.assertIn('never seen input type signature', msg) - self.assertIn('closest seen input type signature has 1 mismatches', msg) - self.assertIn("seen f32[8]({}), but now given f32[8]({Auto: ('x',)})", msg) + self.assertIn("different input types", msg) + self.assertIn("at x, now f32[8]({Auto: ('x',)}) and before f32[8]({})", msg) def test_pjit_function_cache_cpp(self): def f(x): @@ -3438,6 +3134,7 @@ def f(x): pjit(f)(inp) self.assertEqual(count(), 1) + @jtu.thread_unsafe_test() # count_pjit_cpp_cache_miss is not thread-safe def test_pjit_no_global_cache_hit_axis_resources(self): mesh = jtu.create_mesh((1,), ('x',)) s = NamedSharding(mesh, P('x')) @@ -3508,11 +3205,18 @@ def test_device_put_sharding_nondivisible_sharding_error(self): 'divisible by 2, but it is equal to 1 '): jax.device_put((y, x), s) + if config.pmap_shmap_merge.value: + expected_regex = re.compile( + r"One of device_put args was given the sharding of .*" + ) + else: + expected_regex = ( + "The sharded dimension must be equal to the number of " + "devices passed to PmapSharding. Got sharded dimension 0 with value 1 " + r"in shape \(1,\) and the number of devices=2") + with self.assertRaisesRegex( - ValueError, - "The sharded dimension must be equal to the number of " - "devices passed to PmapSharding. Got sharded dimension 0 with value 1 " - r"in shape \(1,\) and the number of devices=2"): + ValueError, expected_regex): s2 = jax.pmap(lambda x: x, devices=list(mesh.devices.flat))(jnp.arange(2)).sharding jax.device_put(x, s2) @@ -3557,12 +3261,29 @@ def test_device_assignment_mismatch_apply_primitive(self): r"of concatenate with shape int.*\[8\].*and argument.*"): jnp.concatenate([arr, arr2]) + def test_closed_over_constant_diff_sharding(self): + if not config.use_simplified_jaxpr_constants.value: + self.skipTest('Requires use_simplified_jaxpr_constants=True') + if jax.device_count() < 2: + self.skipTest('Requires >=2 devices') + + arr = jax.device_put(np.arange(8), SingleDeviceSharding(jax.devices()[0])) + const = jax.device_put(np.arange(8), SingleDeviceSharding(jax.devices()[1])) + + @jax.jit + def f(x): + return x + const + + with self.assertRaisesRegex( + ValueError, "Received incompatible devices for jitted computation"): + f(arr) + + out = f(np.arange(8)) + self.assertEqual(out.sharding, const.sharding) + def test_device_put_grad(self): if jax.device_count() < 8: self.skipTest("Requires >=8 devices.") - if jtu.is_device_tpu(5, 'e'): - self.skipTest('TPU v5e does not support computations that run on a ' - 'non-singleton subset of cores.') def _test(fun, inp, np_inp, in_s): out = fun(inp) @@ -3612,20 +3333,17 @@ def g(x): @jtu.thread_unsafe_test() # cache_info isn't thread-safe def test_pjit_out_sharding_preserved(self): - if config.use_shardy_partitioner.value: - raise unittest.SkipTest("Shardy doesn't support PositionalSharding") mesh = jtu.create_mesh((2, 1), ('x', 'y')) ns = NamedSharding(mesh, P('x')) - ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1) + gs = GSPMDSharding(jax.devices()[:2], ns._to_xla_hlo_sharding(2)) arr = jax.device_put(np.arange(8).reshape(8, 1), ns) - arr2 = jax.device_put(np.arange(8).reshape(8, 1), ps) + arr2 = jax.device_put(np.arange(8).reshape(8, 1), gs) def mul(x): return x * 2 f = pjit(mul, out_shardings=ns) - f2 = pjit(mul, out_shardings=ps) with jtu.count_pjit_cpp_cache_miss() as count: out = f(arr) @@ -3636,24 +3354,12 @@ def mul(x): self.assertIsInstance(out.sharding, NamedSharding) self.assertEqual(count(), 1) - with jtu.count_pjit_cpp_cache_miss() as count: - out2 = f2(arr) - cache_info2 = pxla._cached_compilation.cache_info() - self.assertIsInstance(out2.sharding, PositionalSharding) - - out2 = f2(arr) - self.assertIsInstance(out2.sharding, PositionalSharding) - self.assertEqual(count(), 1) - - self.assertEqual(cache_info2.hits, cache_info1.hits + 1) - self.assertEqual(cache_info2.misses, cache_info1.misses) - with jtu.count_jit_tracing_cache_miss() as tracing_count: out3 = jnp.squeeze(arr, axis=-1) self.assertIsInstance(out3.sharding, NamedSharding) out4 = jnp.squeeze(arr2, axis=-1) - self.assertIsInstance(out4.sharding, PositionalSharding) + self.assertIsInstance(out4.sharding, GSPMDSharding) self.assertEqual(tracing_count(), 2) @jtu.thread_unsafe_test() # cache_info isn't thread-safe @@ -3686,25 +3392,6 @@ def test_list_in_pspec(self): out = with_sharding_constraint(jnp.arange(8), P(['x'])) self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) - def test_sharding_preserved_trivial(self): - if config.use_shardy_partitioner.value: - raise unittest.SkipTest("Shardy doesn't support PositionalSharding") - mesh = jtu.create_mesh((2, 1), ('x', 'y')) - ns = NamedSharding(mesh, P('x')) - ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1) - - arr = jax.device_put(np.arange(8).reshape(8, 1), ns) - arr2 = jax.device_put(np.arange(8).reshape(8, 1), ps) - - def identity(x): - return x - - out = pjit(identity)(arr) - self.assertIsInstance(out.sharding, NamedSharding) - - out2 = pjit(identity)(arr2) - self.assertIsInstance(out2.sharding, PositionalSharding) - def test_wsc_error_on_none(self): with self.assertRaisesRegex( ValueError, @@ -3712,23 +3399,6 @@ def test_wsc_error_on_none(self): ' not allowed'): with_sharding_constraint(jnp.arange(8), None) - def test_sharding_preserved_aot(self): - mesh = jtu.create_mesh((2, 1), ('x', 'y')) - ns = NamedSharding(mesh, P('x')) - ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1) - - arr = jax.device_put(np.arange(8).reshape(8, 1), ns) - arr2 = jax.device_put(np.arange(8).reshape(8, 1), ps) - - compiled = pjit(lambda x: x * 2).lower(arr).compile() - out = compiled(arr) - self.assertIsInstance(out.sharding, NamedSharding) - - out2 = compiled(arr2) - # The sharding won't be PositionalSharding since the pjit was already - # Compiled which bakes in the output sharding. - self.assertIsInstance(out2.sharding, NamedSharding) - def test_sharding_on_output_with_vmap(self): mesh = jtu.create_mesh((2, 1), ('x', 'y')) ns = NamedSharding(mesh, P('x')) @@ -3747,16 +3417,29 @@ def test_sharding_on_output_with_vmap(self): self.assertIsInstance(out3.sharding, NamedSharding) self.assertEqual(count(), 1) - @jtu.thread_unsafe_test() # cache_info isn't thread-safe + @config.numpy_dtype_promotion('standard') + def test_mutable_array_closed_over_multi_device(self): + mesh = jtu.create_mesh((2,), ('x',)) + key_data = jax.device_put(jax.random.key_data(jax.random.key(42)), + NamedSharding(mesh, P())) + key_data_ref = core.new_ref(key_data) + + @jax.jit(out_shardings= NamedSharding(mesh, P('x'))) + def generate_random_numbers(): + key_val = key_data_ref[...] + outputs = jnp.arange(8, dtype=jnp.float32) + key_val[0] + return outputs + + generate_random_numbers() # doesn't crash + + @jtu.thread_unsafe_test() # cache_info isn't thread-safe def test_jit_mul_sum_sharding_preserved(self): - if config.use_shardy_partitioner.value: - raise unittest.SkipTest("Shardy doesn't support PositionalSharding") mesh = jtu.create_mesh((2, 1), ('x', 'y')) ns = NamedSharding(mesh, P('x')) - ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1) + gs = GSPMDSharding(tuple(mesh.devices.flat), ns._to_xla_hlo_sharding(2)) arr = jax.device_put(np.arange(8).reshape(8, 1), ns) - arr2 = jax.device_put(np.arange(8).reshape(8, 1), ps) + arr2 = jax.device_put(np.arange(8).reshape(8, 1), gs) f = jax.jit(lambda x: x * 2) @@ -3766,11 +3449,11 @@ def test_jit_mul_sum_sharding_preserved(self): with jtu.count_pjit_cpp_cache_miss() as cpp_count: out2 = f(arr2) - self.assertIsInstance(out2.sharding, PositionalSharding) + self.assertIsInstance(out2.sharding, GSPMDSharding) # This will hit the cpp cache. out3 = f(out2) - self.assertIsInstance(out3.sharding, PositionalSharding) + self.assertIsInstance(out3.sharding, GSPMDSharding) self.assertEqual(compilation_count(), 2) self.assertEqual(cpp_count(), 1) @@ -3818,8 +3501,6 @@ def test_none_out_sharding(self): self.assertEqual(out2.sharding.spec, P()) def test_sharding_preserved_apply_primitive(self): - if config.use_shardy_partitioner.value: - raise unittest.SkipTest("Shardy doesn't support PositionalSharding") mesh = jtu.create_mesh((2, 1), ('x', 'y')) ns = NamedSharding(mesh, P('x')) @@ -3828,10 +3509,10 @@ def test_sharding_preserved_apply_primitive(self): out = jnp.copy(arr) self.assertIsInstance(out.sharding, NamedSharding) - ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1) - arr2 = jax.device_put(np.arange(8).reshape(8, 1), ps) + gs = GSPMDSharding(jax.devices()[:2], ns._to_xla_hlo_sharding(2)) + arr2 = jax.device_put(np.arange(8).reshape(8, 1), gs) out2 = jnp.copy(arr2) - self.assertIsInstance(out2.sharding, PositionalSharding) + self.assertIsInstance(out2.sharding, GSPMDSharding) arr3 = jnp.arange(8) out3 = jnp.copy(arr3) @@ -3883,6 +3564,11 @@ def identity(x): self.assertEqual(out2.devices(), {jax.devices()[0]}) self.assertArraysEqual(out2, np_inp) + def test_jnp_arange_concrete_sharding(self): + mesh = jtu.create_mesh((2,), 'x') + out = jnp.arange(8, device=NamedSharding(mesh, P('x'))) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + def test_jit_submhlo_cached(self): @jax.jit def nest(x): @@ -3912,13 +3598,6 @@ def test_wsc_eager(self): for s in out.addressable_shards: self.assertArraysEqual(s.data, np_inp[s.index]) - def test_wsc_eager_no_resharding(self): - mesh = jtu.create_mesh((2,), ('x',)) - np_inp = np.arange(8) - inp = jax.device_put(np_inp, NamedSharding(mesh, P('x'))) - out = with_sharding_constraint(inp, NamedSharding(mesh, P('x'))) - self.assertEqual(id(out), id(inp)) - def test_wsc_eager_different_order_devices(self): mesh1 = jtu.create_mesh((2,), ('x',)) mesh2 = jax.sharding.Mesh([jax.devices()[1], jax.devices()[0]], 'x') @@ -3989,8 +3668,8 @@ def trace_to_jaxpr(x): def test_shape_dtype_struct_as_const_error(self): const = jax.ShapeDtypeStruct((8,), jnp.int32) - with self.assertRaisesRegex(TypeError, - r"Argument.*is not a valid JAX type"): + with self.assertRaisesRegex(Exception, + r"A ShapeDtypeStruct does not have a value.*"): jax.jit(lambda x: (x, const))(jnp.arange(8)) def test_jit_out_shardings_none(self): @@ -4091,6 +3770,87 @@ def f(): with mesh: f() # doesn't crash + def test_closed_constants_at_top_level(self): + const = literals.TypedNdArray( + np.arange(8, dtype=np.float32), weak_type=False) + + @jax.jit + def f(x): + return x + jax.jit(lambda y: y + const)(x) + + jaxpr = f.trace(const).jaxpr + pjit_e, = [e for e in jaxpr.jaxpr.eqns if e.primitive.name == "jit"] + inner_pjit_jaxpr = pjit_e.params["jaxpr"] + if config.use_simplified_jaxpr_constants.value: + self.assertEmpty(jaxpr.consts) + self.assertEmpty(inner_pjit_jaxpr.consts) + else: + self.assertEmpty(jaxpr.consts) + self.assertIs(const, inner_pjit_jaxpr.consts[0]) + + def test_lowering_cache_hit_with_closed_over_constants_jit(self): + np_inp = np.arange(8) + arr = jnp.arange(8) + np_const = np.arange(9) # distinctive shape + arr_const = jnp.arange(10) # distinctive shape + @jax.jit + def f(x): + return x + np_const[:8] + arr_const[:8] + + # all misses + self.assertCacheMisses(lambda: f(np_inp), cpp=1) + # all hits + self.assertCacheMisses(lambda: f(np_inp), cpp=0, tracing=0, lowering=0) + + # misses in the C++ cache for a different argument + self.assertCacheMisses(lambda: f(arr), cpp=1, tracing=0, lowering=0) + # cpp hit + self.assertCacheMisses(lambda: f(arr), cpp=0, tracing=0, lowering=0) + + # Hits the lowering cache when using the AOT + self.assertCacheMisses(lambda: f.lower(arr), cpp=0, tracing=0, lowering=0) + + def test_lowering_cache_hit_with_closed_over_constants_sharded(self): + mesh = jtu.create_mesh((2,), 'x') + s = NamedSharding(mesh, P('x')) + + inp = jax.device_put(np.arange(4), s) + arr_const = jax.device_put(np.arange(4), s) + np_const = np.arange(5) + + @jax.jit + def f(x): + return x + arr_const + np_const[:4] + + # all misses + self.assertCacheMisses(lambda: f(inp), cpp=1) + # all hits + self.assertCacheMisses(lambda: f(inp), cpp=0, tracing=0, lowering=0) + + # Hits the lowering cache when using the AOT + self.assertCacheMisses(lambda: f.lower(inp), cpp=0, tracing=0, lowering=0) + + @jtu.thread_unsafe_test() + def test_lowering_cache_hit_with_closed_over_constants_scan(self): + np_inp = np.arange(8) + + @jax.jit + def scan_body(carry, x): + return carry + np_inp, None + + @jax.jit + def f(): + return lax.scan(scan_body, np.zeros_like(np_inp), + np.ones((8,), dtype=np.float32)) + + self.assertCacheMisses(f, cpp=1, lowering=1) + self.assertCacheMisses(f, cpp=0, tracing=0, lowering=0) + # Run the scan body directly + self.assertCacheMisses(lambda: scan_body(np_inp, np.float32(1)), + cpp=1, tracing=0, lowering=1) + self.assertCacheMisses(lambda: scan_body(np_inp, np.float32(1)), + cpp=0, tracing=0, lowering=0) + def test_lowering_cache_hit_different_devices(self): if jax.device_count() < 4: self.skipTest('Requires >=4 devices') @@ -4203,7 +3963,7 @@ def test_prng_sharding_propagation_with_nested_jit(self): @jax.jit def make_keys(seeds): - @partial(jax.jit, out_shardings=NamedSharding(mesh, P('y'))) + @jax.jit(out_shardings=NamedSharding(mesh, P('y'))) def f(): make_key = partial(prng.random_seed, impl=prng.threefry_prng_impl) return make_key(seeds) @@ -4250,6 +4010,13 @@ def make_keys(seeds): else: self.assertIn('unspecified_dims=[0,1,2]', lowered_text) + def test_wsc_with_scalar(self): + mesh = jtu.create_mesh((2,), 'x') + s = NamedSharding(mesh, P()) + out = jax.lax.with_sharding_constraint(1., s) + self.assertArraysEqual(out, 1.) + self.assertEqual(out.sharding, s) + def test_jit_partially_specified_shardings(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) @@ -4259,7 +4026,7 @@ def test_jit_partially_specified_shardings(self): arr = jax.device_put(np_inp, s) arr2 = jax.device_put(np_inp, s2) - @partial(jax.jit, in_shardings=(s, None, s2, UNSPECIFIED, UNSPECIFIED), + @jax.jit(in_shardings=(s, None, s2, UNSPECIFIED, UNSPECIFIED), out_shardings=(s2, None, None, s, None)) def f(x, y, z, a, b): return x * 2, y @ y.T, z ** 2, a * 3, b.T @@ -4301,11 +4068,10 @@ def f(*args): f(inps) # doesn't crash def test_spmd_preserves_input_sharding_vmap_grad(self): - if config.use_shardy_partitioner.value: - self.skipTest("Shardy doesn't support PositionalSharding") # https://github.com/jax-ml/jax/issues/20710 n_devices = jax.device_count() - sharding = PositionalSharding(jax.devices()) + mesh = Mesh(jax.devices(), 'x') + sharding = NamedSharding(mesh, P('x')) def model(params, x): return x @ params @@ -4318,8 +4084,8 @@ def model(params, x): params = jnp.ones(feature_dim) # Shard data, replicate params - x = jax.device_put(x, sharding.reshape(n_devices, 1)) - params = jax.device_put(params, sharding.replicate(axis=0)) + x = jax.device_put(x, sharding) + params = jax.device_put(params, NamedSharding(mesh, P())) model(params, x) # doesn't crash @@ -4422,7 +4188,7 @@ def f(x, y): input_shardings, _ = f.lower(inp, inp).compile().input_shardings self.assertLen(input_shardings, 2) - def test_aot_out_info(self): + def test_lowered_out_info(self): inp = np.arange(8, dtype=np.int32) out_info = jax.jit(lambda x: x).lower((inp, inp)).out_info self.assertEqual(out_info[0].shape, (8,)) @@ -4432,6 +4198,26 @@ def test_aot_out_info(self): self.assertEqual(out_info[0].sharding, None) self.assertEqual(out_info[1].sharding, None) + def test_lowered_out_info_mesh(self): + mesh = jtu.create_mesh((2,), 'x') + arr = jax.device_put(np.arange(8, dtype=np.int32), + NamedSharding(mesh, P('x'))) + lowered = jax.jit(lambda x: x * 2).lower(arr) + out_info = lowered.out_info + self.assertEqual(out_info.shape, (8,)) + self.assertEqual(out_info.dtype, np.int32) + self.assertEqual(out_info.sharding, None) + + def test_compiled_out_info(self): + mesh = jtu.create_mesh((2,), 'x') + arr = jax.device_put(np.arange(8, dtype=np.int32), + NamedSharding(mesh, P('x'))) + compiled = jax.jit(lambda x: x * 2).lower(arr).compile() + out_info = compiled.out_info + self.assertEqual(out_info.shape, (8,)) + self.assertEqual(out_info.dtype, np.int32) + self.assertEqual(out_info.sharding, NamedSharding(mesh, P('x'))) + def test_jit_trace(self): def f(x): return x * 2 @@ -4446,9 +4232,15 @@ def f(x): self.assertLen(traced.in_avals[0], 1) self.assertLen(traced.in_avals[1], 0) # empty kwarg + def test_in_out_shardings_unconstrained_error(self): + mesh = jtu.create_mesh((1,), ('x',)) + + with self.assertRaisesRegex( + ValueError, "Unconstrained dims are not allowed"): + jax.jit(lambda x: x, + in_shardings=NamedSharding(mesh, P(P.UNCONSTRAINED, 'x'))) + def test_empty_io_callback_under_shard_map(self): - if config.use_shardy_partitioner.value: - self.skipTest("TODO(b/384938613): Failing under shardy.") mesh = jtu.create_mesh((4,), 'i') def empty_callback(x): @@ -4460,7 +4252,7 @@ def _f(x, y): return x + y[..., jnp.newaxis] f = jax.jit(shard_map( - _f, mesh, in_specs=(P(None, 'i'), P(None)), + _f, mesh=mesh, in_specs=(P(None, 'i'), P(None)), out_specs=P(None, 'i'))) f(jnp.zeros((2, 16)), jnp.ones(2)) @@ -4478,7 +4270,7 @@ def _f(x, y): return x + y[..., jnp.newaxis] f = jax.jit(shard_map( - _f, mesh, in_specs=(P(None, 'i'), P(None)), + _f, mesh=mesh, in_specs=(P(None, 'i'), P(None)), out_specs=P(None, 'i'))) f(jnp.zeros((2, 16)), jnp.ones(2)) @@ -4510,7 +4302,7 @@ def f(x): def test_nullary_out_sharding_partial(self): mesh = jtu.create_mesh((jax.device_count(),), 'x') - @partial(jax.jit, out_shardings=(None, NamedSharding(mesh, P()))) + @jax.jit(out_shardings=(None, NamedSharding(mesh, P()))) def init(): tensor = jnp.zeros(shape=(1,)) other_tensor = jnp.zeros(shape=(1,)) @@ -4565,14 +4357,16 @@ def test_device_put_efficient_reshard_complex_mesh(self, shape): x_s1 = jax.device_put(np_inp, s1) # Reshard! - out = jax.device_put(x_s1, s2) + with jax.transfer_guard('disallow_explicit'): + out = jax.device_put(x_s1, s2) self.assertArraysEqual(out, np_inp) self.assertEqual(out.sharding, s2) s3 = NamedSharding(mesh2, P('model_q')) x_s3 = jax.device_put(np_inp, s3) # Reshard to iota device assignment! - out2 = jax.device_put(x_s3, s1) + with jax.transfer_guard('disallow_explicit'): + out2 = jax.device_put(x_s3, s1) self.assertArraysEqual(out2, np_inp) self.assertEqual(out2.sharding, s1) @@ -4615,15 +4409,13 @@ def test_convert_element_type_sharding(self): inp = np.arange(16).reshape(8, 2) out = lax_internal._convert_element_type( - inp, new_dtype=np.float32, weak_type=False, sharding=s) + inp, new_dtype=np.dtype(np.float32), weak_type=False, sharding=s) self.assertArraysEqual(out, inp.astype('float32')) self.assertEqual(out.dtype, np.float32) self.assertEqual(out.sharding, s) def test_jnp_array_sharding(self): - if jax.device_count() < 4: - self.skipTest('Requires >=4 devices') - mesh = jax.make_mesh((2, 2), ('x', 'y'), devices=jax.devices()[:4]) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) inp = np.arange(16).reshape(8, 2) @@ -4632,9 +4424,7 @@ def test_jnp_array_sharding(self): self.assertEqual(out.sharding, s) def test_jnp_array_inside_jit_sharding(self): - if jax.device_count() < 4: - self.skipTest('Requires >=4 devices') - mesh = jax.make_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) inp = np.arange(16).reshape(8, 2) @@ -4660,9 +4450,10 @@ def test_make_mesh_non_int_error(self): with self.assertRaisesRegex( ValueError, "`axis_shapes` passed to `make_mesh` should be a sequence of ints"): - jax.make_mesh(((4,), 4), ('x', 'y')) + jax.make_mesh(((4,), 4), ('x', 'y'), axis_types=(AxisType.Auto,) * 2) - jax.make_mesh((1, np.int32(1), np.int64(1)), ('x', 'y', 'z')) # doesn't crash + jax.make_mesh((1, np.int32(1), np.int64(1)), ('x', 'y', 'z'), + axis_types=(AxisType.Explicit,) * 3) # doesn't crash def test_jnp_array_reshard_error(self): if jax.device_count() < 2: @@ -4706,7 +4497,8 @@ def f(x): cpu_arr = jax.device_put(np_inp, jax.devices('cpu')[0]) with self.assertRaisesRegex( ValueError, - "Compiled object called with input sharding.*does not match"): + r'Computation was compiled for input shardings.* that ' + r'disagree with the shardings.* of arguments passed to it'): compiled(cpu_arr) def test_different_devices_wsc_abstract_mesh_cache_hit(self): @@ -4828,14 +4620,53 @@ def f(x): ins, _ = f.lower(np.arange(8)).compile().input_shardings self.assertEqual(ins[0], SingleDeviceSharding(jax.devices()[0])) + def test_aot_devices_to_compile(self): + mesh = jtu.create_mesh((2, 1), ('x', 'y')) + abstract_sds = jax.ShapeDtypeStruct( + (8, 2), jnp.float32, sharding=NamedSharding(mesh.abstract_mesh, P('x'))) + + @jax.jit + def f(x): + return x * 2 + + lowered = f.trace(abstract_sds).lower(lowering_platforms=('tpu',)) + self.assertIn('num_partitions = 2', lowered.as_text()) + + compiled = lowered.compile(device_assignment=tuple(mesh.devices.flat)) + + arr = jax.device_put(np.arange(16, dtype=jnp.float32).reshape(8, 2), + NamedSharding(mesh, P('x'))) + out = compiled(arr) + self.assertArraysEqual(out, arr * 2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + + with self.assertRaisesRegex( + ValueError, + 'The size of abstract mesh 2.*must match the length of device' + ' assignment: 1'): + lowered.compile(device_assignment=(jax.devices()[0],)) + + def test_aot_devices_to_compile_error(self): + mesh = jtu.create_mesh((2,), ('x',)) + arr = jax.ShapeDtypeStruct((8, 2), jnp.float32, + sharding=NamedSharding(mesh, P('x'))) + + @jax.jit + def f(x): + return x + + with self.assertRaisesRegex( + ValueError, + 'device_assignment passed to `.compile` must match the' + ' device_assignment calculated from array shardings and out_shardings'): + f.lower(arr).compile(device_assignment=(jax.devices()[0],)) + def test_abstract_mesh_lower(self): mesh = jtu.create_mesh((2,), 'x') mesh2 = jtu.create_mesh((1,), 'x') abstract_sds = jax.ShapeDtypeStruct( (8, 2), jnp.float32, sharding=NamedSharding(mesh.abstract_mesh, P('x'))) - abstract_sds2 = jax.ShapeDtypeStruct( - (8, 2), jnp.float32, sharding=NamedSharding(mesh2.abstract_mesh, P('x'))) @jax.jit def f(x): @@ -4845,7 +4676,7 @@ def f(x): self.assertIn('num_partitions = 2', lowered.as_text()) with self.assertRaisesRegex( - RuntimeError, 'A jitted computation cannot contain AbstractMesh'): + RuntimeError, 'device_assignment cannot be `None` during compilation'): lowered.compile() @jax.jit @@ -4854,6 +4685,8 @@ def g(x, y): concrete_s = NamedSharding(mesh, P('x')) concrete_sds = jax.ShapeDtypeStruct((8,), jnp.float32, sharding=concrete_s) + abstract_sds2 = jax.ShapeDtypeStruct( + (8, 2), jnp.float32, sharding=NamedSharding(mesh2.abstract_mesh, P('x'))) with self.assertRaisesRegex( ValueError, 'AbstractMesh size: 1 does not match the device assignment size: 2'): @@ -4867,14 +4700,24 @@ def g(x, y): lowering_platforms=('tpu',)) self.assertIn('num_partitions = 2', lowered2.as_text()) with self.assertRaisesRegex( - RuntimeError, 'A jitted computation cannot contain AbstractMesh'): + RuntimeError, 'device_assignment cannot be `None` during compilation'): lowered2.compile() - lowered3 = g.lower(abstract_sds, concrete_sds) - self.assertIn('num_partitions = 2', lowered3.as_text()) - with self.assertRaisesRegex( - RuntimeError, 'A jitted computation cannot contain AbstractMesh'): - lowered3.compile() + def test_abstract_mesh_lower_same_size_different_axis_name(self): + mesh = jtu.create_mesh((2,), 'x') + mesh2 = jtu.create_mesh((2,), 'y') + + a1 = jax.ShapeDtypeStruct( + (8, 2), jnp.float32, sharding=NamedSharding(mesh.abstract_mesh, P('x'))) + a2 = jax.ShapeDtypeStruct( + (8, 2), jnp.float32, sharding=NamedSharding(mesh2.abstract_mesh, P('y'))) + + @jax.jit + def f(x, y): + return x * y + + lowered = f.trace(a1, a2).lower(lowering_platforms=('tpu',)) + self.assertIn('num_partitions = 2', lowered.as_text()) def test_jit_out_shardings_unconstrained(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) @@ -4883,7 +4726,7 @@ def test_jit_out_shardings_unconstrained(self): arr = jax.device_put(np_inp, s) out_s = NamedSharding(mesh, P(P.UNCONSTRAINED, P.UNCONSTRAINED)) - @partial(jax.jit, out_shardings=out_s) + @jax.jit(out_shardings=out_s) def f(x): return x * 2 @@ -4891,7 +4734,7 @@ def f(x): self.assertEqual(out.sharding, s) self.assertArraysEqual(out, np_inp * 2) - @partial(jax.jit, out_shardings=NamedSharding(mesh, P(P.UNCONSTRAINED, 'y'))) + @jax.jit(out_shardings=NamedSharding(mesh, P(P.UNCONSTRAINED, 'y'))) def g(x): return x * 3 @@ -4900,13 +4743,237 @@ def g(x): self.assertEqual(out.sharding, s) lowered_text = g.lower(arr).as_text() if config.use_shardy_partitioner.value: - self.assertIn('<@mesh, [{?}, {"y"}]>', lowered_text) + self.assertIn('<@mesh, [{?}, {"y", ?}]>', lowered_text) else: - self.assertIn("unspecified_dims=[0]", lowered_text) + self.assertIn("unspecified_dims=[0,1]", lowered_text) + def test_prng_key_wsc(self): + mesh = jtu.create_mesh((2,), 'x') -def spec_regex(s): - return str(s).replace(r"(", r"\(").replace(r")", r"\)") + @jax.jit + def f(x): + y = lax.with_sharding_constraint(x, NamedSharding(mesh, P())) + return y.T + f(jax.random.key(0)) # doesn't crash + + @jax.jit + def g(x): + return lax.with_sharding_constraint(x, NamedSharding(mesh, P())) + g(jax.random.key(1)) # doesn't crash + + def test_prng_key_wsc_multi_axes_sharding(self): + input_shape = (8, 4) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) + spec = P('x', 'y') + + seeds, _ = create_array(input_shape, mesh, spec, dtype=np.uint32) + + @jax.jit + def make_keys(seeds): + make_key = partial(prng.random_seed, impl=prng.threefry_prng_impl) + return lax.with_sharding_constraint( + make_key(seeds), NamedSharding(mesh, P('x', 'y'))) + + out = make_keys(seeds) + self.assertTrue(jax.dtypes.issubdtype(out.dtype, jax.dtypes.prng_key)) + self.assertEqual(out.shape, input_shape) + jax.random.key_data(out) # doesn't crash + + def test_sds_update(self): + mesh = jtu.create_mesh((2, 1), ('x', 'y')) + s1 = jax.ShapeDtypeStruct((2, 2), jnp.int32) + s1_u = s1.update(shape=(4, 2), dtype=np.float32) + self.assertEqual(s1_u.shape, (4, 2)) + self.assertEqual(s1_u.dtype, np.float32) + self.assertFalse(s1_u.weak_type) + + s2 = jax.ShapeDtypeStruct((2, 2), jnp.int32) + s2_u = s2.update(shape=(4, 2), weak_type=True) + self.assertEqual(s2_u.shape, (4, 2)) + self.assertEqual(s2_u.dtype, np.int32) + self.assertTrue(s2_u.weak_type) + + s3 = jax.ShapeDtypeStruct((2, 2), jnp.int32, + sharding=NamedSharding(mesh, P())) + s3_u = s3.update(sharding=NamedSharding(mesh, P('x'))) + self.assertEqual(s3_u.sharding, NamedSharding(mesh, P('x'))) + + s32_u = s3.update(shape=(4, 2)) + self.assertEqual(s32_u.shape, (4, 2)) + self.assertEqual(s32_u.sharding, NamedSharding(mesh, P())) + + sh = NamedSharding(mesh, P()) + s4 = jax.ShapeDtypeStruct((2, 2), jnp.int32, + sharding=Format(DLL((0, 1)), sh)) + new_layout = Format(DLL((1, 0)), NamedSharding(mesh, P('x'))) + s4_u = s4.update(sharding=new_layout) + self.assertEqual(s4_u.sharding, new_layout.sharding) + self.assertEqual(s4_u.format, new_layout) + + with self.assertRaisesRegex(ValueError, "updating ShapeDtypeStruct"): + s4.update(sharding=NamedSharding(mesh, P('x'))) + + @jtu.with_explicit_mesh((2, 1), ('x', 'y'), axis_types=(AxisType.Auto,) * 2) + def test_sds_pspec_input(self, mesh): + inp = jax.ShapeDtypeStruct((2, 2), np.float32, sharding=P('x')) + lowered = jax.jit(lambda x: x * 2).lower(inp) + self.assertIn('num_partitions = 2', lowered.as_text()) + + np_inp = np.arange(4, dtype=np.float32).reshape(2, 2) + arr = jax.device_put(np_inp, P('x')) + out = lowered.compile()(arr) + self.assertArraysEqual(out, np_inp * 2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + + def test_sds_pspec_no_mesh_ctx_error(self): + sds = jax.ShapeDtypeStruct((2, 2), np.float32, sharding=P('x')) + + mesh = jtu.create_mesh((2,), 'x') + with jax.set_mesh(mesh): + out_s = sds.sharding + self.assertEqual(out_s, NamedSharding(mesh, P('x'))) + + with self.assertRaisesRegex( + TypeError, + 'When specifying PartitionSpec to `ShapeDtypeStruct`, the context mesh' + ' cannot be empty'): + _ = sds.sharding + + mesh = jtu.create_mesh((2,), 'y') + with jax.set_mesh(mesh): + with self.assertRaisesRegex( + ValueError, "Resource axis.*not found in mesh"): + _ = sds.sharding + + def test_set_mesh_none_out_sharding(self): + mesh = jtu.create_mesh((2,), 'x') + + @jax.jit(static_argnums=0, out_shardings=None) + def f(spec): + return jax.lax.with_sharding_constraint(jnp.arange(8), spec) + + # no mesh context and mesh context behavior has to be the same + with jax.set_mesh(mesh): + out = f(P('x')) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + + out = f(NamedSharding(mesh, P('x'))) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + + def test_lowering_cache_hit_inputs_with_different_mesh(self): + mesh1 = jtu.create_mesh((2, 2), ('x', 'y')) + devs = jax.devices()[:4] + mesh2 = Mesh(np.asarray(devs[::-1]).reshape(2, 2), ('x', 'y')) + + np_inp = np.arange(16).reshape(8, 2) + arr1 = jax.device_put(np_inp, NamedSharding(mesh1, P('x', 'y'))) + arr2 = jax.device_put(np_inp, NamedSharding(mesh2, P('x', 'y'))) + + @jax.jit + def f(x): + return x * 2 + + with (jtu.count_pjit_cpp_cache_miss() as cpp_count, + jtu.count_jit_and_pmap_lowerings() as lowering_count): + f(arr1) + f(arr2) + self.assertEqual(cpp_count(), 2) + self.assertEqual(lowering_count(), 1) + + def test_single_device_lowering_cache_hit_diff_devices(self): + if jax.device_count() < 2: + self.skipTest('Requires device_count() >= 2') + np_inp = np.arange(16).reshape(8, 2) + arr1 = jax.device_put(np_inp, SingleDeviceSharding(jax.devices()[0])) + arr2 = jax.device_put(np_inp, SingleDeviceSharding(jax.devices()[1])) + + @jax.jit + def f(x): + return x * 2 + + with (jtu.count_pjit_cpp_cache_miss() as cpp_count, + jtu.count_jit_and_pmap_lowerings() as lowering_count): + f(arr1) + f(arr2) + self.assertEqual(cpp_count(), 2) + self.assertEqual(lowering_count(), 1) + + def test_compiler_options_nested_jit_error(self): + @jax.jit(compiler_options={"invalid_key": "invalid_value"}) + def g(x): + return x * 2 + + @jax.jit + def f(x): + x = x + 2 + return g(x) + + with self.assertRaisesRegex( + ValueError, + '`compiler_options` can only be passed to top-level `jax.jit`'): + f(jnp.arange(4)) + + def test_wsc_aval_cache_hit(self): + mesh = jtu.create_mesh((2,), 'x') + sharding = NamedSharding(mesh, P('x')) + zeros = jnp.zeros((2,2)) + x = jax.device_put(zeros, sharding) + + @jax.jit + def inner(x): + return x + + @jax.jit + def init(): + x = with_sharding_constraint(zeros, sharding) + self.assertEqual(x.aval.sharding.mesh, sharding.mesh.abstract_mesh) + return inner(x) + + @jax.jit + def apply(x): + return inner(x) + + with jtu.count_jit_tracing_cache_miss() as count: + init() + apply(x) + self.assertEqual(count(), 3) # misses for init, apply and inner (only once) + + def test_wsc_aval_diff_shardings(self): + mesh = jtu.create_mesh((1,), 'x') + x = jax.device_put(jnp.zeros((2,2)), NamedSharding(mesh, P())) + + @jax.jit + def f(x): + out1 = with_sharding_constraint(x, NamedSharding(mesh, P('x'))) + self.assertEqual(out1.aval.sharding.mesh, mesh.abstract_mesh) + out2 = with_sharding_constraint(out1, SingleDeviceSharding(jax.devices()[0])) + self.assertTrue(out2.aval.sharding.mesh.empty) + return out2 + + f(x) # doesn't crash + + @jtu.with_explicit_mesh((2,), ('x',)) + def test_sds_input_to_zeros_like_propagates_sharding(self, mesh): + val = jax.ShapeDtypeStruct( + (32,), dtype=jnp.float32, + sharding=NamedSharding(mesh.abstract_mesh, P('x'))) + out = jnp.zeros_like(val) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + + def test_sds_incompatible_sharding(self): + mesh = jtu.create_mesh((2,), 'x') + with self.assertRaisesRegex( + ValueError, + "only valid for values of rank at least 3, but was applied to a value " + "of rank 2"): + jax.ShapeDtypeStruct((128, 128), jnp.float32, + sharding=NamedSharding(mesh, P(None, 'x', None))) + + with self.assertRaisesRegex( + ValueError, + "only valid for values of rank at least 3, but was applied to a value " + "of rank 2"): + jax.ShapeDtypeStruct((128, 128), jnp.float32, sharding=P(None, 'x', None)) class ShardingInTypesTest(jtu.JaxTestCase): @@ -4917,10 +4984,10 @@ def check_wsc_in_lowered(self, text): else: self.assertIn('@Sharding', text) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_basic_mul(self, mesh): np_inp = np.arange(16.).reshape(8, 2) - s = NamedSharding(mesh, P('x', 'y')) + s = NamedSharding(mesh, jax.P('x', 'y')) arr = jax.device_put(np_inp, s) def f(x): @@ -4962,14 +5029,14 @@ def g(x): jax.jit(jax.grad(g)).lower(sds) # doesn't crash - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_fully_replicated_array_mul(self, mesh): np_inp1 = np.arange(16).reshape(8, 2) - s = NamedSharding(mesh, P('x', 'y')) + s = NamedSharding(mesh, jax.P('x', 'y')) arr1 = jax.device_put(np_inp1, s) np_inp2 = np.arange(2).reshape(1, 2) - arr2 = jax.device_put(np_inp2, NamedSharding(mesh, P(None, None))) + arr2 = jax.device_put(np_inp2, NamedSharding(mesh, jax.P(None, None))) @jax.jit def f(x, y): @@ -4982,11 +5049,12 @@ def f(x, y): self.assertEqual(out.sharding, s) self.assertArraysEqual(out, (np_inp1 * np_inp2)) - out = f(arr1, jax.device_put(np_inp1, NamedSharding(mesh, P(('x',), ('y',))))) + out = f(arr1, jax.device_put( + np_inp1, NamedSharding(mesh, jax.P(('x',), ('y',))))) self.assertEqual(out.sharding, s) self.assertArraysEqual(out, (np_inp1 * np_inp1)) - out = f(arr1, jax.device_put(np_inp2, NamedSharding(mesh, P()))) + out = f(arr1, jax.device_put(np_inp2, NamedSharding(mesh, jax.P()))) self.assertEqual(out.sharding, s) self.assertArraysEqual(out, (np_inp1 * np_inp2)) @@ -4995,12 +5063,14 @@ def g(x, y): return x * y with self.assertRaisesRegex( - TypeError, "mul got incompatible shardings for broadcasting"): - g(arr1, jax.device_put(np_inp1, NamedSharding(mesh, P('y', 'x')))) + core.ShardingTypeError, + "mul got incompatible shardings for broadcasting"): + g(arr1, jax.device_put(np_inp1, NamedSharding(mesh, jax.P('y', 'x')))) with self.assertRaisesRegex( - TypeError, "mul got incompatible shardings for broadcasting"): - g(arr1, jax.device_put(np_inp1, NamedSharding(mesh, P(('x', 'y'))))) + core.ShardingTypeError, + "mul got incompatible shardings for broadcasting"): + g(arr1, jax.device_put(np_inp1, NamedSharding(mesh, jax.P(('x', 'y'))))) @parameterized.named_parameters( ('x_y', P('x', None), P(None, 'y'), P('x', 'y'), None), @@ -5009,7 +5079,7 @@ def g(x, y): ('fsdp', P('x', None), P('x', None), P('x', None), 'all-gather'), ('half_tp', P(None, 'y'), P(None, 'y'), P(None, 'y'), 'all-gather'), ) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_dot_general(self, spec1, spec2, out_spec, collective_name, mesh): np_inp1 = np.arange(16.).reshape(8, 2) arr1 = jax.device_put(np_inp1, NamedSharding(mesh, spec1)) @@ -5034,7 +5104,7 @@ def f(x, y): self.check_wsc_in_lowered(lowered.as_text()) compiled_text = lowered.compile().as_text() - if collective_name is not None and compiled_text is not None: + if collective_name is not None: self.assertIn(collective_name, compiled_text) @jax.jit @@ -5051,7 +5121,7 @@ def g(x, y): self.assertEqual(out[1].sharding, arr2.sharding) @parameterized.parameters([True, False]) - @jtu.with_user_mesh((4,), ('x',)) + @jtu.with_explicit_mesh((4,), ('x',)) def test_dot_general_out_sharding(self, use_jit, mesh): np_inp1 = np.arange(16.).reshape(8, 2) arr1 = jax.device_put(np_inp1, NamedSharding(mesh, P('x', None))) @@ -5073,7 +5143,7 @@ def f(x, y): ValueError, 'PartitionSpec passed to einsum cannot contain axis names that are of' ' type Auto or Manual'): - auto_axes(f, out_shardings=P())(arr1, arr2) + auto_axes(f, out_sharding=P())(arr1, arr2) out = jax.grad(f, argnums=(0, 1))(arr1, arr2) self.assertEqual(out[0].sharding, arr1.sharding) @@ -5086,28 +5156,28 @@ def f(x, y): self.assertEqual(out[1].sharding, arr2.sharding) jaxpr = jitted_grad.trace(arr1, arr2).jaxpr - bwd_jaxpr = jaxpr.eqns[1] - expected_spec = [('broadcast_in_dim', P('x', None)), - ('dot_general', P('x', None)), - ('transpose', P(None, 'x')), - ('dot_general', P('x', None))] - for eqn, spec in zip(bwd_jaxpr.params['jaxpr'].eqns, expected_spec): - self.assertEqual(eqn.primitive.name, spec[0]) - self.assertEqual(eqn.outvars[0].aval.sharding.spec, spec[1]) + bwd_jaxpr = next(e for e in reversed(jaxpr.eqns) if 'jaxpr' in e.params) + expected_spec = {'broadcast_in_dim': P('x', None), + 'dot_general': P('x', None), + 'transpose': P(None, 'x')} + for eqn in bwd_jaxpr.params['jaxpr'].eqns: + spec = expected_spec.get(eqn.primitive.name) + if spec is not None: + self.assertEqual(eqn.outvars[0].aval.sharding.spec, spec) @parameterized.named_parameters( ('fail1', P('x', None), P(None, 'x'), "dot_general operation.*produces an illegally sharded result", - TypeError), + core.ShardingTypeError), ('fail2', P('x', 'y'), P('x', 'y'), "dot_general requires contracting dimensions to have consistent sharding", - TypeError), + core.ShardingTypeError), ('contracting1', P('x', 'y'), P('y', None), - 'Contracting dimensions are sharded', ValueError), + 'Contracting dimensions are sharded', core.ShardingTypeError), ('other_half_tp', P(None, 'y'), P('y', None), - 'Contracting dimensions are sharded', ValueError), + 'Contracting dimensions are sharded', core.ShardingTypeError), ) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_dot_general_error(self, spec1, spec2, error_msg, error_type, mesh): np_inp1 = np.arange(16).reshape(8, 2) arr1 = jax.device_put(np_inp1, NamedSharding(mesh, spec1)) @@ -5120,26 +5190,26 @@ def f(x, y): with self.assertRaisesRegex(error_type, error_msg): f(arr1, arr2) - @jtu.with_user_mesh((2, 2, 1), ('x', 'y', 'z')) + @jtu.with_explicit_mesh((2, 2, 1), ('x', 'y', 'z')) def test_dot_general_batch_error(self, mesh): arr1 = jax.device_put(np.ones((8, 4, 2)), NamedSharding(mesh, P('x', 'y', 'z'))) arr2 = jax.device_put(np.ones((8, 2, 4)), NamedSharding(mesh, P('y', 'z', 'x'))) with self.assertRaisesRegex( - TypeError, + core.ShardingTypeError, 'dot_general requires lhs batch dimensions and rhs batch dimensions to' ' have the consistent sharding'): jax.lax.dot_general( arr1, arr2, dimension_numbers=(([2], [1]), ([0], [0]))) with self.assertRaisesRegex( - TypeError, + core.ShardingTypeError, 'dot_general requires lhs batch dimensions and rhs batch dimensions to' ' have the consistent sharding'): jnp.einsum('abc,acz->abz', arr1, arr2) - @jtu.with_user_mesh((2, 2), ('model', 'data')) + @jtu.with_explicit_mesh((2, 2), ('model', 'data')) def test_aval_repr(self, mesh): mesh = mesh.abstract_mesh aval = core.ShapedArray((128, 64), np.float32, @@ -5158,7 +5228,7 @@ def test_aval_repr(self, mesh): aval = aval.update(sharding=NamedSharding(mesh, P(('model', 'data'), None))) self.assertEqual(aval.str_short(), 'float32[128@(model,data),64]') - @jtu.with_user_mesh((2, 1), ('x', 'y')) + @jtu.with_explicit_mesh((2, 1), ('x', 'y')) def test_jnp_ones_mesh_context_eager(self, mesh): s = NamedSharding(mesh, P('x', None)) out = jnp.ones((8, 2), dtype=jnp.int32, device=s) @@ -5168,6 +5238,27 @@ def test_jnp_ones_mesh_context_eager(self, mesh): out = jnp.ones((8, 2), dtype=jnp.int32, device=s) self.assertEqual(out.sharding, s) + @jtu.with_explicit_mesh((2,), 'x') + def test_jnp_like_functions(self, mesh): + np_inp = np.arange(16).reshape(8, 2) + + out = jnp.zeros_like(np_inp, out_sharding=P('x')) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + self.assertArraysEqual(out, np.zeros_like(np_inp)) + + out = jnp.ones_like(np_inp, out_sharding=P('x')) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + self.assertArraysEqual(out, np.ones_like(np_inp)) + + @jtu.with_explicit_mesh((2, 1), ('x', 'y')) + def test_jnp_array_out_sharding(self, mesh): + np_inp = np.arange(16).reshape(8, 2) + out = jnp.array(np_inp, dtype=jnp.int32, out_sharding=P('x')) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + + out = jnp.asarray(np_inp, dtype=jnp.int32, out_sharding=P('x')) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + @parameterized.named_parameters( ('all', None, P('x', 'y'), P(), True), ('first', 0, P('x', 'y'), P('y'), True), @@ -5175,7 +5266,7 @@ def test_jnp_ones_mesh_context_eager(self, mesh): ('first2', 0, P(('x', 'y'), None), P(None), True), ('second2', 1, P(('x', 'y'), None), P(('x', 'y')), False), ) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_reduce_sum(self, axis, in_spec, out_spec, reduce, mesh): np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, in_spec) @@ -5196,7 +5287,7 @@ def f(x): self.check_wsc_in_lowered(lowered.as_text()) compiled_text = lowered.compile().as_text() - if reduce and compiled_text is not None: + if reduce: self.assertIn('all-reduce', compiled_text) @parameterized.named_parameters( @@ -5206,7 +5297,7 @@ def f(x): ('first2', 0, P(('x', 'y'), None), P(None), True), ('second2', 1, P(('x', 'y'), None), P(('x', 'y')), False), ) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_reduce_max(self, axis, in_spec, out_spec, reduce, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, in_spec) @@ -5227,7 +5318,7 @@ def f(x): self.check_wsc_in_lowered(lowered.as_text()) compiled_text = lowered.compile().as_text() - if reduce and compiled_text is not None: + if reduce: self.assertIn('all-reduce', compiled_text) @jax.jit @@ -5247,7 +5338,7 @@ def g(x): ('2', 2, P('x', 'y', None)), ('-1', -1, P('x', 'y', None)), ) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_broadcast_in_dim(self, axis, out_spec, mesh): np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -5273,7 +5364,7 @@ def f(x): ('3', 3), ('4', 4), ) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_integer_pow(self, pow, mesh): np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -5292,7 +5383,7 @@ def f(x): lowered_text = f.lower(arr).as_text() self.check_wsc_in_lowered(lowered_text) - @jtu.with_user_mesh((1,), 'x') + @jtu.with_explicit_mesh((1,), 'x') def test_broadcasting_nary_error(self, mesh): mesh2 = Mesh([jax.devices()[0]], 'y', axis_types=(mesh_lib.AxisType.Explicit,)) @@ -5308,7 +5399,24 @@ def f(x, y): ValueError, "For primitive.*context mesh.*aval mesh"): f(arr1, arr2) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2,), 'x') + def test_no_op_broadcast_except_for_sharding_change(self, mesh): + arr = jnp.arange(8.).reshape(4, 2) + + @jax.jit + def f(x): + out = jax.lax.broadcast_in_dim(x, (4, 2), [0, 1], out_sharding=P('x')) + self.assertEqual(out.aval.sharding.spec, P('x', None)) + return out + + out = f(arr) + self.assertArraysEqual(out, arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + + out_g = jax.jit(jax.grad(lambda x: f(x).sum()))(arr) + self.assertEqual(out_g.sharding, NamedSharding(mesh, P(None, None))) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_sin_unop(self, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -5326,7 +5434,7 @@ def f(x): lowered_text = f.lower(arr).as_text() self.check_wsc_in_lowered(lowered_text) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_jnp_array(self, mesh): np_inp = np.arange(16, dtype=jnp.int32).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -5342,7 +5450,7 @@ def f(x): f(arr) - @jtu.with_user_mesh((2, 2, 1), ('x', 'y', 'z')) + @jtu.with_explicit_mesh((2, 2, 1), ('x', 'y', 'z')) def test_lax_transpose_rule(self, mesh): np_inp = np.arange(16).reshape(4, 2, 2) s = NamedSharding(mesh, P('x', 'y', 'z')) @@ -5361,7 +5469,7 @@ def f(x): lowered_text = f.lower(arr).as_text() self.check_wsc_in_lowered(lowered_text) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_broadcasted_iota_with_sharding(self, mesh): np_inp = np.arange(4) s = NamedSharding(mesh, P('x')) @@ -5386,7 +5494,7 @@ def g(x): _, out = g(arr) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_einsum_with_out_sharding(self, mesh): np_inp = np.arange(16.).reshape(8, 2) arr1 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) @@ -5431,7 +5539,7 @@ def h2(x, y): self.assertEqual(out[0].sharding, arr3.sharding) self.assertEqual(out[1].sharding, arr4.sharding) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_einsum_inverse(self, mesh): np_inp = np.arange(64.) @@ -5465,24 +5573,59 @@ def h2(x, y): self.assertEqual(out[0].sharding, arr1.sharding) self.assertEqual(out[1].sharding, arr2.sharding) - @parameterized.named_parameters( - ('1', (16, 1), (1, 16, 1), P('x', None), P(None, 'x', None), False), - ('2', (8, 2, 1), (1, 16, 1), P('x', None, None), P(None, 'x', None), True), - ('3', (8, 1), (1, 4, 2), P('x', None), P(None, None, 'x'), True), - ('4', (1, 4, 1, 6, 1), (1, 4, 6), - P(None, 'x', None, None, None), P(None, 'x', None), False), - ('5', (4, 6), (4, 6), P(None, 'x'), P(None, 'x'), False), + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_fully_replicated_reshape(self, mesh): + np_inp = np.arange(64).reshape(64, 1) + arr = jax.device_put(np_inp, P(('x', 'y'))) + + @jax.jit + def f(x): + x = reshard(x, P(None, None)) + return jax.lax.reshape(x, (2, 32, 1)) + + out = f(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P(None, None, None))) + self.assertArraysEqual(out, np_inp.reshape(2, 32, 1)) + + @parameterized.parameters( + (src_shape, dst_shape, src_spec, dst_spec, use_sharding_arg, fun) + for fun in [jnp.reshape, jax.lax.reshape] + for src_shape, dst_shape, src_spec, dst_spec, use_sharding_arg in [ + ((16, 1), (1, 16, 1), P('x', None), P(None, 'x', None), + False), + ((8, 2, 1), (1, 16, 1), P('x', None, None), + P(None, 'x', None), True), + ((8, 1), (1, 4, 2), P('x', None), P(None, None, 'x'), + True), + ((1, 4, 1, 6, 1), (1, 4, 6), + P(None, 'x', None, None, None), P(None, 'x', None), False), + ((4, 6), (4, 6), P(None, 'x'), P(None, 'x'), False), + ((1024, 4096), (1024, 2048, 2, 1, 1, 1, 1), + P('x', None), P('x', None, None, None, None, None, None), False), + ((1024, 4096, 32), (1024, 2048, 2, 1, 1, 32), + P('x', None, None), P('x', None, None, None, None, None), False), + ((1024, 4096), (1024, 1, 1, 4096), + P('x', None), P('x', None, None, None), False), + ((1024, 4096), (1024, 1, 1, 4096), + P(None, 'x'), P(None, None, None, 'x'), False), + ((1024, 2048, 2, 1, 1, 1), (1024, 4096), + P('x', None, None, None, None, None), P('x', None), False), + ((1024, 2048, 2, 1, 1, 1), (1024, 4096), + P(None, 'x', None, None, None, None), P(None, 'x'), False), + ] ) - @jtu.with_user_mesh((2,), ('x',)) + @jtu.with_explicit_mesh((2,), ('x',)) def test_reshape(self, src_shape, dst_shape, src_spec, dst_spec, - use_sharding_arg, mesh): + use_sharding_arg, fun, mesh): np_inp = np.arange(math.prod(src_shape), dtype=np.float32).reshape(src_shape) arr = jax.device_put(np_inp, NamedSharding(mesh, src_spec)) - @partial(jax.jit, static_argnums=1) + @jax.jit(static_argnums=1) def f(x, new_sharding): - y = lax.reshape(x, dst_shape, out_sharding=new_sharding) + y = fun(x, dst_shape, out_sharding=new_sharding) + self.assertEqual(y.aval.sharding.spec, dst_spec) + self.assertEqual(y.shape, dst_shape) y = y * 2 self.assertEqual(y.aval.sharding.spec, dst_spec) return y @@ -5502,6 +5645,12 @@ def g(x): out = jax.jit(jax.grad(g))(arr) self.assertEqual(out.sharding, arr.sharding) + @jtu.with_explicit_mesh((2,), "x") + def test_jnp_reshape_order_F(self, mesh): + value = jnp.ones(16, out_sharding=P('x')) + out = jnp.reshape(value, (-1, 4), order='F', out_sharding=P(None, 'x')) + self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'x'))) + @parameterized.named_parameters( ('split_1', (4, 6, 8), (4, 2, 3, 8), P('x', None, 'y'), P('x', None, None, 'y'), '' @@ -5528,6 +5677,12 @@ def g(x): ('split_6_error', (4, 8, 9), (4, 2, 2, 3, 3, 2), P('x', None, None), None, 'This reshape is not supported' ), + ('split_7', (10, 1), (2, 5, 1), P('x', None), P('x', None, None), ''), + ('split_8', (10, 1), (2, 5, 1, 1), P('x', None), + P('x', None, None, None), ''), + ('split_9', (10, 1, 1), (2, 5, 1, 1), P('x', None, None), + P('x', None, None, None), ''), + ('split_10', (1, 10), (1, 2, 5), P(None, 'x'), P(None, 'x', None), ''), ('merge_1', (4, 2, 3, 8), (4, 6, 8), P('x', None, None, 'y'), P('x', None, 'y'), '' ), @@ -5553,8 +5708,9 @@ def g(x): ('merge_6_error', (4, 2, 3, 8), (4, 8, 6), P(None, 'y', None, 'x'), None, 'This reshape is not supported' ), + ('merge_7', (2, 5, 1), (10, 1), P('x', None, None), P('x', None), ''), ) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_reshape_split_merge_one_axis(self, src_shape, dst_shape, src_spec, dst_spec, error_msg, mesh): np_inp = np.arange(math.prod(src_shape), @@ -5569,7 +5725,7 @@ def f(x): return y if error_msg: - with self.assertRaisesRegex(ValueError, error_msg): + with self.assertRaisesRegex(core.ShardingTypeError, error_msg): f(arr) else: out = f(arr) @@ -5586,7 +5742,7 @@ def g(x): out = jax.jit(jax.grad(g))(arr) self.assertEqual(out.sharding, arr.sharding) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_select(self, mesh): np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -5608,7 +5764,7 @@ def f(pred, on_true, on_false): arr3 = jax.device_put(np_inp, NamedSharding(mesh, P('y', 'x'))) with self.assertRaisesRegex( - TypeError, "select cases must have the same shardings"): + core.ShardingTypeError, "select cases must have the same shardings"): f(arr1 == arr2, arr1, arr3) def test_explicit_mode_no_context_mesh(self): @@ -5655,61 +5811,17 @@ def f(x): out = f(arr) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) - @jtu.with_user_mesh((2, 2), ('x', 'y')) - def test_mesh_cast_reshard_error(self, mesh): + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_shard_map_full_manual(self, mesh): np_inp = np.arange(16).reshape(8, 2) - s = NamedSharding(mesh, P('x', 'y')) - arr = jax.device_put(np_inp, s) + arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) + arr2 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) - @jax.jit - def f(x): - y = mesh_cast(x, NamedSharding(x.aval.sharding.mesh, P('x', None))) - return y - - with self.assertRaisesRegex( - ValueError, - 'mesh_cast should only be used when AxisType changes between the input' - ' mesh and the target mesh'): - f(arr) - - @jax.jit - def g(x): - return mesh_cast(x, P('x', None)) - - with self.assertRaisesRegex( - ValueError, - 'mesh_cast should only be used when AxisType changes between the input' - ' mesh and the target mesh'): - g(arr) - - @jtu.with_user_mesh((2, 2), ('x', 'y'), - axis_types=(AxisType.Explicit, AxisType.Auto)) - def test_mesh_cast_explicit_data_movement_error(self, mesh): - np_inp = np.arange(16).reshape(8, 2) - s = NamedSharding(mesh, P('x', 'y')) - arr = jax.device_put(np_inp, s) - full_user_mesh = mesh_lib.AbstractMesh( - (2, 2), ('x', 'y'), axis_types=(AxisType.Explicit,) * 2) - - @jax.jit - def f(x): - return mesh_cast(x, NamedSharding(full_user_mesh, P('y', None))) - - with self.assertRaisesRegex( - ValueError, 'Explicit data movement in mesh_cast is not allowed'): - f(arr) - - @jtu.with_user_mesh((2, 2), ('x', 'y')) - def test_shard_map_full_manual(self, mesh): - np_inp = np.arange(16).reshape(8, 2) - arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) - arr2 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) - - def g(x, y): - self.assertTrue(x.aval.sharding.mesh._are_all_axes_manual) - self.assertTrue(y.aval.sharding.mesh._are_all_axes_manual) - self.assertTrue(mesh_lib.get_abstract_mesh()._are_all_axes_manual) - return x * y + def g(x, y): + self.assertTrue(x.aval.sharding.mesh.are_all_axes_manual) + self.assertTrue(y.aval.sharding.mesh.are_all_axes_manual) + self.assertTrue(mesh_lib.get_abstract_mesh().are_all_axes_manual) + return x * y @jax.jit def f(x, y): @@ -5725,16 +5837,16 @@ def f(x, y): self.assertArraysEqual(out, (np_inp * np_inp) * 2) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_shard_map_dot(self, mesh): np_inp = np.arange(16).reshape(8, 2) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', 'x'))) def g(x, y): - self.assertTrue(x.aval.sharding.mesh._are_all_axes_manual) - self.assertTrue(y.aval.sharding.mesh._are_all_axes_manual) - self.assertTrue(mesh_lib.get_abstract_mesh()._are_all_axes_manual) + self.assertTrue(x.aval.sharding.mesh.are_all_axes_manual) + self.assertTrue(y.aval.sharding.mesh.are_all_axes_manual) + self.assertTrue(mesh_lib.get_abstract_mesh().are_all_axes_manual) allgatherd_y = jax.lax.all_gather(y, axis_name='x', axis=1, tiled=True) z = x @ allgatherd_y return jax.lax.psum(z, axis_name='y') @@ -5753,7 +5865,15 @@ def f(x, y): self.assertArraysEqual(out, (np_inp @ np_inp.T) * 2) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_full_like_eager_non_concrete_sharding(self): + s = NamedSharding(mesh_lib.AbstractMesh((2,), ('x',)), P('x')) + arr = jax.ShapeDtypeStruct((8, 2), np.float32, sharding=s) + out = jax.lax.full_like(arr, 0) + # The sharding is single device because the sharding of input `arr`` to + # full_like is not concrete. + self.assertEqual(out.sharding, SingleDeviceSharding(jax.devices()[0])) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_slice(self, mesh): np_inp = np.arange(16.).reshape(4, 4) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None))) @@ -5778,13 +5898,13 @@ def g(x): out = jax.jit(jax.grad(g))(arr) self.assertEqual(out.sharding, arr.sharding) - with self.assertRaisesRegex(NotImplementedError, "slicing on sharded dims"): + with self.assertRaisesRegex(core.ShardingTypeError, "slicing on sharded dims"): f(jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))) - with self.assertRaisesRegex(NotImplementedError, "slicing on sharded dims"): + with self.assertRaisesRegex(core.ShardingTypeError, "slicing on sharded dims"): f(jax.device_put(np_inp, NamedSharding(mesh, P(None, ('x', 'y'))))) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_squeeze(self, mesh): np_inp = np.arange(16.).reshape(4, 4, 1) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None, None))) @@ -5810,12 +5930,12 @@ def g(x): out = jax.jit(jax.grad(g))(arr) self.assertEqual(out.sharding, arr.sharding) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_pad(self, mesh): np_inp = np.arange(8.) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x'))) - @partial(jax.jit, static_argnums=(1, 2)) + @jax.jit(static_argnums=(1, 2)) def f(x, padding_config, spec): y = lax.pad(x, 0., padding_config) self.assertEqual(y.aval.sharding.spec, spec) @@ -5842,24 +5962,24 @@ def g(x): out = jax.jit(jax.grad(g))(arr) self.assertEqual(out.sharding, arr.sharding) - with self.assertRaisesRegex(NotImplementedError, "padding on sharded dims"): + with self.assertRaisesRegex(core.ShardingTypeError, "padding on sharded dims"): f(arr, ((2, 3, 0), ), None) - with self.assertRaisesRegex(NotImplementedError, "padding on sharded dims"): + with self.assertRaisesRegex(core.ShardingTypeError, "padding on sharded dims"): f(arr, ((0, 3, 0), ), None) - with self.assertRaisesRegex(NotImplementedError, "padding on sharded dims"): + with self.assertRaisesRegex(core.ShardingTypeError, "padding on sharded dims"): arr = jax.device_put(np_inp, NamedSharding(mesh, P(('x', 'y')))) f(arr, ((4, 4, 1),), None) - @jtu.with_user_mesh((2, 1), ('x', 'y')) + @jtu.with_explicit_mesh((2, 1), ('x', 'y')) def test_concatenate(self, mesh): np_inp = np.arange(16.).reshape(4, 4) s = NamedSharding(mesh, P('x', 'y')) arr1 = jax.device_put(np_inp, s) arr2 = jax.device_put(np.arange(4.).reshape(4, 1), s) - @partial(jax.jit, static_argnums=2) + @jax.jit(static_argnums=2) def f(x, y, method='jnp'): if method == 'jnp': y = jnp.concatenate([x, y], axis=1) @@ -5879,7 +5999,7 @@ def f(x, y, method='jnp'): self.assertArraysEqual(out, np.concatenate([arr1, arr2], axis=1)) with self.assertRaisesRegex( - TypeError, "All operands should have the same sharding"): + core.ShardingTypeError, "All operands should have the same sharding"): arr3 = jax.device_put(np.arange(4.).reshape(4, 1), NamedSharding(mesh, P('x'))) f(arr1, arr3) @@ -5894,7 +6014,7 @@ def g(x, y): out = jax.jit(jax.grad(g))(arr1, arr2) self.assertEqual(out.sharding, s) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_scan(self, mesh): carry = jax.device_put(np.arange(16.).reshape(2, 8), NamedSharding(mesh, P(None, 'x'))) @@ -5932,7 +6052,21 @@ def g(carry, arr): ValueError, "0th dimension of all xs should be replicated"): f(carry, jax.device_put(arr, NamedSharding(mesh, P('x', None, None)))) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2,), 'x') + def test_scan_carry_shardings_diff_error(self, mesh): + @jax.jit + def f(x): + def g(carry, _): + y = reshard(carry, P()) + return y, None + return jax.lax.scan(g, x, None, length=1)[0] + + with self.assertRaisesRegex( + TypeError, + r'scan.*carry input and carry output must have equal types'): + f(jax.device_put(np.arange(4), P('x'))) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_argminmax(self, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -5953,7 +6087,7 @@ def f(x): self.assertEqual(out2.sharding, NamedSharding(mesh, P('x'))) self.check_wsc_in_lowered(f.lower(arr).as_text()) - @jtu.with_user_mesh((2, 2), ('x', 'y'), (mesh_lib.AxisType.Auto,) * 2) + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), (mesh_lib.AxisType.Auto,) * 2) def test_only_auto(self, mesh): np_inp = np.arange(16.).reshape(8, 2) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None))) @@ -5985,7 +6119,7 @@ def f(x, x2): a = z @ x2 return a - with jax.sharding.use_mesh(mesh): + with jax.set_mesh(mesh): out = f(arr, arr.T) self.assertEqual(out.sharding, NamedSharding(mesh, P('x',))) lowered_text = f.lower(arr, arr.T).as_text() @@ -5994,7 +6128,7 @@ def f(x, x2): mesh2 = jtu.create_mesh((2, 2), ('x', 'y'), axis_types=(mesh_lib.AxisType.Explicit, mesh_lib.AxisType.Auto)) - with jax.sharding.use_mesh(mesh2): + with jax.set_mesh(mesh2): arr = jax.device_put(arr, NamedSharding(mesh2, P('x', 'y'))) arr2 = jax.device_put(np_inp.T, NamedSharding(mesh2, P('y', None))) out = f(arr, arr2) @@ -6008,7 +6142,7 @@ def f(x, x2): mesh3 = jtu.create_mesh((2, 2), ('x', 'y'), axis_types=(mesh_lib.AxisType.Auto, mesh_lib.AxisType.Explicit)) - with jax.sharding.use_mesh(mesh3): + with jax.set_mesh(mesh3): arr = jax.device_put(arr, NamedSharding(mesh3, P('x', 'y'))) arr2 = jax.device_put(np_inp.T, NamedSharding(mesh3, P(None, 'x'))) out = f(arr, arr2) @@ -6020,12 +6154,7 @@ def f(x, x2): else: self.assertTrue(lowered_text.count("unspecified_dims") == 4) - with self.assertRaisesRegex( - ValueError, - "AxisTypes should be the same in a tuple subset of PartitionSpec"): - NamedSharding(mesh2, P(('x', 'y'))) - - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_where_with_scalar(self, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -6035,8 +6164,9 @@ def test_where_with_scalar(self, mesh): self.assertArraysEqual(out, x) self.assertEqual(out.sharding, s) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_full_user_to_full_auto(self, mesh): + am = mesh.abstract_mesh np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) arr = jax.device_put(np_inp, s) @@ -6044,14 +6174,15 @@ def test_full_user_to_full_auto(self, mesh): @jax.jit def f(x): y = x * 2 - with use_auto_axes('x', 'y'): - y = mesh_cast(y, P(None, None)) + with use_abstract_mesh( + am.update_axis_types({'x': AxisType.Auto, 'y': AxisType.Auto})): + y = reshard(y, P(None, None)) self.assertEqual(y.aval.sharding.spec, P(None, None)) z = jnp.sin(y) self.assertEqual(z.aval.sharding.spec, P(None, None)) a = z @ z.T self.assertEqual(a.aval.sharding.spec, P(None, None)) - a = mesh_cast(a, P('x', None)) + a = reshard(a, P('x', None)) self.assertEqual(a.aval.sharding.spec, P('x', None)) return a @@ -6062,9 +6193,10 @@ def f(x): out2 = core.jaxpr_as_fun(jaxpr)(arr) self.assertEqual(out2[0].sharding, NamedSharding(mesh, P('x', None))) - @jtu.with_user_mesh((2, 2), ('x', 'y'), + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), axis_types=(mesh_lib.AxisType.Auto,) * 2) def test_full_auto_to_full_user(self, mesh): + am = mesh.abstract_mesh np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) arr = jax.device_put(np_inp, s) @@ -6072,12 +6204,14 @@ def test_full_auto_to_full_user(self, mesh): @jax.jit def f(x): y = x * 2 - with use_explicit_axes('x', 'y'): - y = mesh_cast(y, P(None, 'y')) + with use_abstract_mesh( + am.update_axis_types({'x': AxisType.Explicit, + 'y': AxisType.Explicit})): + y = reshard(y, P(None, 'y')) self.assertEqual(y.aval.sharding.spec, P(None, 'y')) z = jnp.sin(y) self.assertEqual(z.aval.sharding.spec, P(None, 'y')) - a = mesh_cast(z, P(None, None)) + a = reshard(z, P(None, None)) self.assertEqual(a.aval.sharding.spec, P(None, None)) return a @@ -6087,7 +6221,7 @@ def f(x): jaxpr = f.trace(arr).jaxpr core.jaxpr_as_fun(jaxpr)(arr) # doesn't crash - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_full_user_to_auto_user_mix(self, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -6096,14 +6230,15 @@ def test_full_user_to_auto_user_mix(self, mesh): @jax.jit def f(x): y = x * 2 - with use_auto_axes('x'): - y = mesh_cast(y, P(None, 'y')) + with use_abstract_mesh( + mesh.abstract_mesh.update_axis_types({'x': AxisType.Auto})): + y = reshard(y, P(None, 'y')) self.assertEqual(y.aval.sharding.spec, P(None, 'y')) z = jnp.sin(y) self.assertEqual(z.aval.sharding.spec, P(None, 'y')) a = jnp.einsum('xy,yz->xz', z, z.T, out_sharding=P(None, None)) self.assertEqual(a.aval.sharding.spec, P(None, None)) - a = mesh_cast(a, P('x', None)) + a = reshard(a, P('x', None)) self.assertEqual(a.aval.sharding.spec, P('x', None)) return a @@ -6114,7 +6249,7 @@ def f(x): out2 = core.jaxpr_as_fun(jaxpr)(arr) self.assertEqual(out2[0].sharding, NamedSharding(mesh, P('x', None))) - @jtu.with_user_mesh((2, 1), ('x', 'y')) + @jtu.with_explicit_mesh((2, 1), ('x', 'y')) def test_user_auto_mix_error(self, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -6123,7 +6258,8 @@ def test_user_auto_mix_error(self, mesh): @jax.jit def f(x, y): x = x * 2 - with use_auto_axes('x'): + with use_abstract_mesh( + mesh.abstract_mesh.update_axis_types({'x': AxisType.Auto})): z = x @ y return z @@ -6131,13 +6267,13 @@ def f(x, y): ValueError, "For primitive dot_general, context mesh.*aval mesh"): f(arr, arr.T) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_split(self, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) arr = jax.device_put(np_inp, s) - @partial(jax.jit, static_argnums=(1, 2)) + @jax.jit(static_argnums=(1, 2)) def f(x, sizes=(4, 4), axis=0): ys = lax.split(x, sizes, axis=axis) self.assertEqual(ys[0].aval.sharding.spec, P('x', 'y')) @@ -6147,7 +6283,7 @@ def f(x, sizes=(4, 4), axis=0): f(arr) self.check_wsc_in_lowered(f.lower(arr).as_text()) - with self.assertRaisesRegex(NotImplementedError, "split on sharded dims"): + with self.assertRaisesRegex(core.ShardingTypeError, "split on sharded dims"): f(arr, sizes=(1, 1), axis=1) def g(x): @@ -6160,7 +6296,7 @@ def g(x): out = jax.jit(jax.grad(g))(arr) self.assertEqual(out.sharding, s) - @jtu.with_user_mesh((2,), 'x') + @jtu.with_explicit_mesh((2,), 'x') def test_return_output_different_context(self, mesh): np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x')) @@ -6168,53 +6304,72 @@ def test_return_output_different_context(self, mesh): @jax.jit def f(x): - with use_auto_axes('x'): - x = mesh_cast(x, P(None, None)) + with use_abstract_mesh( + mesh.abstract_mesh.update_axis_types({'x': AxisType.Auto})): + x = reshard(x, P(None, None)) return x - self.assertDictEqual(arr.sharding.mesh._axis_types_dict, - {AxisType.Explicit: ('x',)}) + self.assertTupleEqual(arr.sharding.mesh.axis_types, (AxisType.Explicit,)) out = f(arr) self.assertArraysEqual(out, np_inp) - self.assertDictEqual(out.sharding.mesh._axis_types_dict, - {AxisType.Auto: ('x',)}) + self.assertTupleEqual(out.sharding.mesh.axis_types, (AxisType.Auto,)) - @jtu.with_user_mesh((2,), 'x') - def test_device_put_use_mesh(self, mesh): + @jtu.with_explicit_mesh((2,), 'x') + def test_use_abstract_mesh_override(self, mesh): + with self.assertRaisesRegex( + ValueError, "use_abstract_mesh cannot change the size.*"): + new_am = jax.sharding.AbstractMesh((4,), ('x',)) + with use_abstract_mesh(new_am): + _ = jnp.arange(8) + + new_am = jax.sharding.AbstractMesh((2,), ('y',)) + with use_abstract_mesh(new_am): + out = jnp.arange(8) + self.assertEqual( + out.sharding, + NamedSharding(mesh.update(axis_names=('y',), axis_types=(AxisType.Auto,)), + P())) + + new_am = jax.sharding.AbstractMesh((2,), ('x',), (AxisType.Explicit,)) + with use_abstract_mesh(new_am): + out = jnp.arange(8) + self.assertEqual(out.sharding, NamedSharding( + mesh.update(axis_types=(AxisType.Explicit,)), P(None))) + self.assertArraysEqual(out, np.arange(8)) + + @jtu.with_explicit_mesh((2,), 'x') + def test_device_put_set_mesh(self, mesh): out = jax.device_put(np.arange(8), P('x')) self.assertArraysEqual(out, np.arange(8)) self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) - def test_device_put_no_use_mesh_error(self): + def test_device_put_no_set_mesh_error(self): with self.assertRaisesRegex( ValueError, - 'Please set a mesh via `jax.sharding.use_mesh` if a PartitionSpec is' + 'Please set a mesh via `jax.set_mesh` if a PartitionSpec is' ' passed to device_put'): jax.device_put(np.arange(8), P('x')) - @jtu.with_user_mesh((2,), 'x') + @jtu.with_explicit_mesh((2,), 'x') def test_inputs_different_context(self, mesh): np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x')) arr = jax.device_put(np_inp, s) auto_mesh = jax.make_mesh((2,), 'x', axis_types=(AxisType.Auto,)) - with jax.sharding.use_mesh(auto_mesh): + with jax.set_mesh(auto_mesh): arr2 = jnp.ones(8) - self.assertDictEqual(arr2.sharding.mesh._axis_types_dict, - {AxisType.Auto: ('x',)}) + self.assertTupleEqual(arr2.sharding.mesh.axis_types, (AxisType.Auto,)) @jax.jit def f(x, y): return x, y out1, out2 = f(arr, arr2) - self.assertDictEqual(out1.sharding.mesh._axis_types_dict, - {AxisType.Explicit: ('x',)}) - self.assertDictEqual(out2.sharding.mesh._axis_types_dict, - {AxisType.Auto: ('x',)}) + self.assertTupleEqual(out1.sharding.mesh.axis_types, (AxisType.Explicit,)) + self.assertTupleEqual(out2.sharding.mesh.axis_types, (AxisType.Auto,)) - @jtu.with_user_mesh((2,), 'x') + @jtu.with_explicit_mesh((2,), 'x') def test_output_different_context_error(self, mesh): np_inp1 = np.arange(16).reshape(8, 2) arr1 = jax.device_put(np_inp1, NamedSharding(mesh, P('x', None))) @@ -6233,7 +6388,8 @@ def f(x, y): @jax.jit def g(x, y): - with use_auto_axes('x'): + with use_abstract_mesh( + mesh.abstract_mesh.update_axis_types({'x': AxisType.Auto})): out = jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', None)) return out @@ -6241,7 +6397,7 @@ def g(x, y): ValueError, "PartitionSpec.*cannot contain axis names.*Auto"): g(arr1, arr2) - @jtu.with_user_mesh((2, 2, 2), ('x', 'y', 'z'), + @jtu.with_explicit_mesh((2, 2, 2), ('x', 'y', 'z'), axis_types=(AxisType.Explicit, AxisType.Explicit, AxisType.Auto)) def test_out_sharding_mix_axis_types(self, mesh): @@ -6262,17 +6418,17 @@ def f(x): lowered_text = f.lower(arr).as_text() if config.use_shardy_partitioner.value: self.assertTrue(lowered_text.count( - '[{"x"}, {?}, {?}], replicated={"y"}') == 3) + '[{"x", ?}, {?}, {?}], replicated={"y"}') == 3) else: - self.assertTrue(lowered_text.count("unspecified_dims=[1,2]") == 3) + self.assertTrue(lowered_text.count("unspecified_dims=[0,1,2]") == 3) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_auto_mode_mix(self, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) arr = jax.device_put(np_inp, s) - @partial(auto_axes, axes='x', out_shardings=P('x', None)) + @auto_axes(axes='x', out_sharding=P('x', None)) def h(y): self.assertEqual(y.aval.sharding.spec, P(None, 'y')) z = jnp.sin(y) @@ -6295,7 +6451,192 @@ def g(x): out2 = core.jaxpr_as_fun(jaxpr)(arr) self.assertEqual(out2[0].sharding, NamedSharding(mesh, P('x', None))) - @jtu.with_user_mesh((4,), ('x',)) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_auto_mode_mix_dual_sharded_output(self, mesh): + np_inp = np.arange(16.).reshape(8, 2) + s = NamedSharding(mesh, P('x', 'y')) + arr = jax.device_put(np_inp, s) + + @partial(auto_axes, axes='x', out_sharding=P('x', 'y')) + def h(y): + self.assertEqual(y.aval.sharding.spec, P(None, 'y')) + z = jnp.sin(y) + self.assertEqual(z.aval.sharding.spec, P(None, 'y')) + a = jnp.einsum('xy,yz->xz', z, z.T, out_sharding=P(None, 'y')) + self.assertEqual(a.aval.sharding.spec, P(None, 'y')) + return a + + @jax.jit + def g(x): + y = x * 2 + a = h(y) + self.assertEqual(a.aval.sharding.spec, P('x', 'y')) + return a + + out = g(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + + jaxpr = g.trace(arr).jaxpr + out2 = core.jaxpr_as_fun(jaxpr)(arr) + self.assertEqual(out2[0].sharding, NamedSharding(mesh, P('x', 'y'))) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), iota_order=True) + def test_device_put_different_dst_mesh(self, mesh): + np1 = np.arange(16).reshape(8, 2) + x = jax.device_put(np1, P('x', 'y')) + mesh2 = jtu.create_mesh((4,), ('a',), axis_types=(AxisType.Explicit,), + iota_order=True) + y = jax.device_put(x, NamedSharding(mesh2, P('a', None))) + self.assertEqual(y.sharding, NamedSharding(mesh2, P('a', None))) + self.assertArraysEqual(y, np1) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_auto_mode_mix_repeat(self, mesh): + np_inp = np.arange(16).reshape(8, 1, 2) + s = NamedSharding(mesh, P('x', None, 'y')) + arr = jax.device_put(np_inp, s) + + @partial(auto_axes, axes='x', out_sharding=P('x', None, 'y')) + def h(y): + return jnp.repeat(y, 2, axis=1) + + @jax.jit + def g(x): + y = x * 2 + a = h(y) + self.assertEqual(a.aval.sharding.spec, P('x', None, 'y')) + return a + + out = g(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None, 'y'))) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_manual_mode_mix_repeat(self, mesh): + np_inp = np.arange(16.).reshape(8, 2) + s = NamedSharding(mesh, P('x', 'y')) + arr = jax.device_put(np_inp, s) + + def h(y): + return jnp.repeat(y, 2, axis=1, out_sharding=P(None, 'y')) + + @jax.jit + def g(x): + y = x * 2 + a = jax.shard_map( + h, + out_specs=P('x'), + axis_names={'x'}, + )(y) + self.assertEqual(a.aval.sharding.spec, P('x', 'y')) + return a + + out = g(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_manual_mode_mix_repeat_no_out_sharding(self, mesh): + np_inp = np.arange(16.).reshape(8, 1, 2) + s = NamedSharding(mesh, P('x', None, 'y')) + arr = jax.device_put(np_inp, s) + + def h(y): + return jnp.repeat(y, 2, axis=1) + + @jax.jit + def g(x): + y = x * 2 + a = jax.shard_map( + h, + out_specs=P('x'), + axis_names={'x'}, + )(y) + self.assertEqual(a.aval.sharding.spec, P('x', None, 'y')) + return a + + out = g(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None, 'y'))) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_manual_mode_mix_random_no_out_sharding(self, mesh): + s = NamedSharding(mesh, P('x')) + keys = jax.random.split(jax.random.key(0), 2) + keys = reshard(keys, s) + + def h(key): + key = key.squeeze(0) + arr = jax.random.normal(key, (1, 4)) + return reshard(arr, P(None, 'y')) + + @jax.jit + def g(keys): + a = jax.shard_map( + h, + out_specs=P('x'), + axis_names={'x'}, + )(keys) + self.assertEqual(a.aval.sharding.spec, P('x', 'y')) + return a + out = g(keys) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_manual_mode_mix_random(self, mesh): + self.skipTest('Failing.') + s = NamedSharding(mesh, P('x')) + keys = jax.random.split(jax.random.key(0), 2) + keys = reshard(keys, s) + + def h(key): + key = key.squeeze(0) + return jax.random.normal(key, (1, 4), out_sharding=P(None, 'y')) + + @jax.jit + def g(keys): + a = jax.shard_map( + h, + out_specs=P('x'), + axis_names={'x'}, + )(keys) + self.assertEqual(a.aval.sharding.spec, P('x', 'y')) + return a + out = g(keys) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_manual_mode_mix_scatter_gather(self, mesh): + x = np.random.uniform(size=(mesh.size * 2, 4)) + i = np.random.randint(0, x.shape[1], len(x)) + j = np.random.randint(0, x.shape[1], len(x)) + x = jax.device_put(x, P('x', 'y')) + i = jax.device_put(i, P('y')) + j = jax.device_put(j, P('y')) + + @jax.jit + @partial(jax.shard_map, out_specs=P('x'), axis_names={'x'}) + def f1(x, i, j): + x_a_j = x.at[:, j].get(out_sharding=jax.typeof(i).sharding) + return x.at[:, i].set(x_a_j, out_sharding=jax.typeof(x).sharding) + f1(x,i,j) # doesn't crash + + @config.numpy_rank_promotion('allow') + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_manual_mode_mix_map(self, mesh): # pylint: disable=unused-argument + def simple_func(w, x): + return jnp.sum(w * x, axis=-1) + + w = jax.device_put(np.arange(4, dtype=np.float32), P()) + x = jax.device_put(np.ones((2, 5, 2, 4), dtype=np.float32), + P('x', None, 'y', None)) + + for batch_size in (None, 2): + @jax.jit + @partial(jax.shard_map, out_specs=P('x'), axis_names={'x'}) + def map_fn(x, w): + return jax.lax.map(partial(simple_func, w), x.squeeze(0), + batch_size=batch_size) # pylint: disable=cell-var-from-loop + map_fn(x, w) # doesn't crash + + @jtu.with_explicit_mesh((4,), ('x',)) def test_concat_vmap(self, mesh): @jax.jit def _f(sharded_array, replicated_array): @@ -6326,15 +6667,15 @@ def test_aval_spec_explicit_auto_complete(self): out = core.ShapedArray((8, 2), jnp.int32, sharding=s) self.assertEqual(out.sharding.spec, P('x', None)) - @jtu.with_user_mesh((2, 2), ('x', 'y'), - axis_types=(mesh_lib.AxisType.Auto,) * 2) - def test_full_user_mode(self, mesh): + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), + axis_types=(mesh_lib.AxisType.Auto,) * 2) + def test_full_explicit_mode(self, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) arr = jax.device_put(np_inp, s) - # No axes specified means full visible mode. - @partial(explicit_axes, in_shardings=P('x', 'y')) + # No axes specified means full explicit mode. + @partial(explicit_axes, in_sharding=P('x', 'y')) def h(y): self.assertEqual(y.aval.sharding.spec, P('x', 'y')) z = jnp.sin(y) @@ -6356,12 +6697,12 @@ def f(x): jaxpr = f.trace(arr).jaxpr core.jaxpr_as_fun(jaxpr)(arr) # doesn't crash - @jtu.with_user_mesh((4,), ('data',)) + @jtu.with_explicit_mesh((4,), ('data',)) def test_intermediate_einsum(self, mesh): shape1 = (8, 32, 1, 16) shape2 = (8, 32, 1, 8) - np_inp1 = np.arange(math.prod(shape1)).reshape(shape1) - np_inp2 = np.arange(math.prod(shape2)).reshape(shape2) + np_inp1 = np.ones(math.prod(shape1)).reshape(shape1) + np_inp2 = np.ones(math.prod(shape2)).reshape(shape2) s = NamedSharding(mesh, P('data')) arr1 = jax.device_put(np_inp1, s) @@ -6380,16 +6721,16 @@ def f(x, y, z): self.assertEqual(out.shape, (16, 8, 16)) self.assertEqual(out.sharding, NamedSharding(mesh, P('data', None, None))) - @jtu.with_user_mesh((4,), ('data',)) + @jtu.with_explicit_mesh((4,), ('data',)) def test_intermediate_einsum_auto_complete_spec(self, mesh): s = NamedSharding(mesh, P('data')) shape1 = (8, 32, 2*16) shape2 = (8, 32, 2, 8) shape3 = (8, 32, 2, 8) - np_inp1 = np.arange(math.prod(shape1)).reshape(shape1) - np_inp2 = np.arange(math.prod(shape2)).reshape(shape2) - np_inp3 = np.arange(math.prod(shape3)).reshape(shape3) + np_inp1 = np.ones(math.prod(shape1)).reshape(shape1) + np_inp2 = np.ones(math.prod(shape2)).reshape(shape2) + np_inp3 = np.ones(math.prod(shape3)).reshape(shape3) arr1 = jax.device_put(np_inp1, s) arr2 = jax.device_put(np_inp2, s) @@ -6421,8 +6762,8 @@ def test_where_with_prng_sharded_inp(self): def f(condition, x, y): condition = jnp.asarray(condition) - self.assertTrue(x.aval.sharding.mesh._are_all_axes_auto) - self.assertTrue(y.aval.sharding.mesh._are_all_axes_auto) + self.assertTrue(x.aval.sharding.mesh.are_all_axes_auto) + self.assertTrue(y.aval.sharding.mesh.are_all_axes_auto) x1 = jnp.asarray(x) self.assertEqual(x1.aval.sharding, x.aval.sharding) y1 = jnp.asarray(y) @@ -6432,12 +6773,12 @@ def f(condition, x, y): f = jax.jit(f, in_shardings=(sharding, sharding, sharding)) f(condition, x, x).block_until_ready() - @jtu.with_user_mesh((4,), ('data',)) + @jtu.with_explicit_mesh((4,), ('data',)) def test_intermediate_einsum_conflict_error(self, mesh): shape1 = (8, 32, 1, 16) shape2 = (8, 32, 1, 8) - np_inp1 = np.arange(math.prod(shape1)).reshape(shape1) - np_inp2 = np.arange(math.prod(shape2)).reshape(shape2) + np_inp1 = np.ones(math.prod(shape1)).reshape(shape1) + np_inp2 = np.ones(math.prod(shape2)).reshape(shape2) arr1 = jax.device_put( np_inp1, NamedSharding(mesh, P(None, None, None, 'data'))) @@ -6446,25 +6787,23 @@ def test_intermediate_einsum_conflict_error(self, mesh): @jax.jit def f(x, y, z): - return jnp.einsum('bthD, bthi, bthj->ijD', x, y, z, - out_sharding=P('data', None, None)) + out = jnp.einsum('bthD, bthi, bthj->ijD', x, y, z, + out_sharding=P('data', None, None)) + self.assertEqual(out.aval.sharding.spec, P('data', None, None)) + return out - # Errors out on the intermediate einsum: `bthj,bthD->bthjD` - # because of a conflict - with self.assertRaisesRegex( - TypeError, - 'dot_general operation.*produces an illegally sharded result'): - f(arr1, arr2, arr3) + out = f(arr1, arr2, arr3) + self.assertEqual(out.sharding, NamedSharding(mesh, P('data', None, None))) - @jtu.with_user_mesh((2, 2), ('x', 'y'), - axis_types=(mesh_lib.AxisType.Explicit, - mesh_lib.AxisType.Auto)) + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), + axis_types=(mesh_lib.AxisType.Explicit, + mesh_lib.AxisType.Auto)) def test_mix_to_full_user_mode(self, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) arr = jax.device_put(np_inp, s) - @partial(explicit_axes, axes='y', in_shardings=P('x', 'y')) + @partial(explicit_axes, axes='y', in_sharding=P('x', 'y')) def h(y): self.assertEqual(y.aval.sharding.spec, P('x', 'y')) z = jnp.sin(y) @@ -6483,14 +6822,14 @@ def f(x): out = f(arr) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) - @jtu.with_user_mesh((2, 2), ('x', 'y'), + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), axis_types=(mesh_lib.AxisType.Auto,) * 2) def test_full_auto_to_partial_user(self, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) arr = jax.device_put(np_inp, s) - @partial(explicit_axes, axes='y', in_shardings=P(None, 'y')) + @partial(explicit_axes, axes='y', in_sharding=P(None, 'y')) def h(y): self.assertEqual(y.aval.sharding.spec, P(None, 'y')) z = jnp.sin(y) @@ -6509,7 +6848,7 @@ def f(x): out = f(arr) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_auto_gather_out_sharding(self, mesh): embed = jax.device_put(jnp.arange(128 * 8.).reshape(64, 16), jax.NamedSharding(mesh, P(None, 'x'))) @@ -6544,8 +6883,79 @@ def g(x, y): out = jax.jit(jax.grad(g))(embed, tok) self.assertEqual(out.sharding, embed.sharding) - @jtu.with_user_mesh((2, 2), ('x', 'y')) - def test_reshard_error(self, mesh): + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_gather_sharding_rule(self, mesh): + embed = jax.device_put(jnp.arange(128 * 8.).reshape(64, 16), + jax.NamedSharding(mesh, P('x', None))) + tok = jax.device_put(jnp.arange(8 * 4).reshape(8, 4), + jax.NamedSharding(mesh, P())) + tok_vmap = jax.device_put(jnp.arange(64 * 4).reshape(64, 4), + jax.NamedSharding(mesh, P('x', None))) + + @jax.jit + def f(embed_vd, token_bt, token_vmap): + # Operand sharded in full slice dim and indices replicated. + out = embed_vd.at[:, token_bt].get() + self.assertEqual(out.shape, (64, 8, 4)) + self.assertEqual(out.aval.sharding.spec, P('x', None, None)) + + # Operand and indices sharded in batching dims. + out2 = jax.vmap(lambda x, y: x[y])(embed_vd, token_vmap) + self.assertEqual(out2.shape, (64, 4)) + self.assertEqual(out2.aval.sharding.spec, P('x', None)) + return out, out2 + + outs = f(embed, tok, tok_vmap) + self.assertEqual(outs[0].sharding, NamedSharding(mesh, P('x', None, None))) + + def g(x, y, z): + outs = f(x, y, z) + return outs[0].sum() + outs[1].sum() + + out = jax.jit(jax.grad(g))(embed, tok, tok_vmap) + self.assertEqual(out.sharding, embed.sharding) + + out = jax.grad(g)(embed, tok, tok_vmap) + self.assertEqual(out.sharding, embed.sharding) + + @parameterized.named_parameters( + (f'operand_{spec_name}_sharded_{op_name}', operand_spec, update_spec, op) + for (spec_name, operand_spec, update_spec), (op_name, op) in itertools.product( + (('xy', P('x', None, 'y'), P('x', 'y')), + ('x', P('x', None, None), P('x', None))), + (('set', lambda x, ind, y: x.at[ind].set(y)), + ('add', lambda x, ind, y: x.at[ind].add(y)), + ('mul', lambda x, ind, y: x.at[ind].mul(y)), + ('min', lambda x, ind, y: x.at[ind].min(y)), + ('max', lambda x, ind, y: x.at[ind].max(y)), + ('dynamic_update_slice_in_dim', lambda x, ind, y: ( + jax.lax.dynamic_update_slice_in_dim(x, y[None], ind, axis=0)))), + ) + ) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_scatter_sharding_rule(self, operand_spec, update_spec, scatter_fn, + mesh): + operand = jax.device_put(jnp.zeros((2, 10, 8)), + jax.NamedSharding(mesh, operand_spec)) + indices = jax.device_put(jnp.array([2, 3], dtype=jnp.int32), + jax.NamedSharding(mesh, P('x'))) + updates = jax.device_put(jnp.ones((2, 8)), + jax.NamedSharding(mesh, update_spec)) + + f = jax.jit(jax.vmap(scatter_fn)) + + out = f(operand, indices, updates) + self.assertEqual(out.sharding.spec, operand_spec) + + def g(*args): + return f(*args).sum() + + outs = jax.grad(g, argnums=(0, 2))(operand, indices, updates) + self.assertEqual(outs[0].sharding.spec, operand_spec) + self.assertEqual(outs[1].sharding.spec, update_spec) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_reshard_api(self, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) arr = jax.device_put(np_inp, s) @@ -6590,12 +7000,10 @@ def f_vmap(x): @jax.jit def h(x): - with use_auto_axes('x'): + with use_abstract_mesh( + mesh.abstract_mesh.update_axis_types({'x': AxisType.Auto})): return reshard(x, P('y', None)) - - with self.assertRaisesRegex( - ValueError, 'Mesh of the input.*does not equal.*target sharding'): - h(arr) + h(arr) # doesn't crash def test_auto_axes_top_level(self): mesh = jtu.create_mesh((2, 2), ('x', 'y'), @@ -6604,7 +7012,7 @@ def test_auto_axes_top_level(self): arr1 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', 'x'))) - @partial(auto_axes, out_shardings=P('x', None)) + @partial(auto_axes, out_sharding=P('x', None)) def auto_matmul(arr1, arr2): return arr1 @ arr2 @@ -6615,7 +7023,7 @@ def f(arr1, arr2): self.assertEqual(z.aval.sharding.spec, P('x', None)) return z + 1 - with jax.sharding.use_mesh(mesh): + with jax.set_mesh(mesh): out = f(arr1, arr2) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) @@ -6626,7 +7034,7 @@ def test_explicit_axes_top_level(self): arr1 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', 'x'))) - @partial(explicit_axes, in_shardings=(P('x', None), P('x', None))) + @partial(explicit_axes, in_sharding=(P('x', None), P('x', None))) def jax_matmul(arr1, arr2): out = arr1 @ arr2 self.assertEqual(out.aval.sharding.spec, P('x', None)) @@ -6638,7 +7046,7 @@ def f(arr1, arr2): z = jax_matmul(y, arr2) return z + 1 - with jax.sharding.use_mesh(mesh): + with jax.set_mesh(mesh): out = f(arr1, arr2) self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) @@ -6656,10 +7064,10 @@ def matmul_reshard(arr1, arr2): self.assertEqual(out.aval.sharding.spec, P('x', 'y')) return out - with jax.sharding.use_mesh(mesh): + with jax.set_mesh(mesh): matmul_reshard(arr1, arr2) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_full_auto_outside_jit(self, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -6675,11 +7083,11 @@ def f(x): self.assertEqual(a.aval.sharding.spec, P(None, None)) return a - hf = auto_axes(f, axes=('x', 'y'), out_shardings=P('x', 'y')) + hf = auto_axes(f, axes=('x', 'y'), out_sharding=P('x', 'y')) out = hf(arr) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) - @jtu.with_user_mesh((2, 2), ('x', 'y'), + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), axis_types=(AxisType.Auto,) * 2) def test_full_visible_outside_jit(self, mesh): np_inp = np.arange(16.).reshape(8, 2) @@ -6694,7 +7102,7 @@ def f(x): self.assertEqual(z.aval.sharding.spec, P('x', 'y')) return z - hf = explicit_axes(f, axes=('x', 'y'), in_shardings=P('x', 'y')) + hf = explicit_axes(f, axes=('x', 'y'), in_sharding=P('x', 'y')) out = hf(arr) # doesn't crash self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) @@ -6704,9 +7112,9 @@ def test_compilation_cache_miss_when_devices_change(self): mesh2 = Mesh(np.asarray(devs[::-1]).reshape(2, 2), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2) - with jax.sharding.use_mesh(mesh1): + with jax.set_mesh(mesh1): arr1 = jax.device_put(np_inp, NamedSharding(mesh1, P('x', 'y'))) - with jax.sharding.use_mesh(mesh2): + with jax.set_mesh(mesh2): arr2 = jax.device_put(np_inp, NamedSharding(mesh2, P('x', 'y'))) @jax.jit @@ -6717,9 +7125,9 @@ def f(x): jtu.count_jit_and_pmap_lowerings() as lowering_count, jtu.count_jit_compilation_cache_miss() as compilation_count, jtu.count_pjit_cpp_cache_miss() as cpp_cache_miss_count): - with jax.sharding.use_mesh(mesh1): + with jax.set_mesh(mesh1): out1 = f(arr1) - with jax.sharding.use_mesh(mesh2): + with jax.set_mesh(mesh2): out2 = f(arr2) self.assertEqual(tracing_count(), 1) @@ -6732,7 +7140,7 @@ def f(x): self.assertTupleEqual(out2.sharding._device_assignment, tuple(mesh2.devices.flat)) - @jtu.with_user_mesh((2, 1), ('x', 'y')) + @jtu.with_explicit_mesh((2, 1), ('x', 'y')) def test_svd(self, mesh): np_inp = np.zeros([128, 128]) arr = jax.device_put(np_inp, NamedSharding(mesh, P(None, None))) @@ -6757,7 +7165,7 @@ def f(x, y): self.assertNotIn("mhlo.sharding", lowered_text) @parameterized.parameters(True, False) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_mul_vmap(self, use_jit, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -6794,7 +7202,7 @@ def g(x): self.assertEqual(out.sharding, arr.sharding) @parameterized.parameters(True, False) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_dot_general_vmap(self, use_jit, mesh): np_inp1 = np.arange(16.).reshape(4, 2, 2) np_inp2 = np.arange(16.).reshape(2, 4, 2) @@ -6813,7 +7221,7 @@ def f(x, y): self.assertEqual(out.shape, (2, 2, 4)) self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'y', 'x'))) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_reshape_vmap(self, mesh): np_inp = np.arange(16).reshape(2, 8) arr = jax.device_put(np_inp, NamedSharding(mesh, P(None, 'x'))) @@ -6829,7 +7237,7 @@ def f(x): self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None, 'y'))) @parameterized.parameters(True, False) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_shit_vmap_error_check(self, use_jit, mesh): np_inp = np.arange(16).reshape(8, 2) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None))) @@ -6858,7 +7266,55 @@ def f(x, y): "Only one of spmd_axis_name or arrays sharded on.*spmd_axis_name"): jax.vmap(f, spmd_axis_name='y')(arr, arr) - @jtu.with_user_mesh((2,), ('x',)) + @parameterized.parameters('x', 'y') + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_spmd_axis_name_explicit_mode_assert(self, spmd_axis_name, mesh): + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) + + @jax.jit + @partial(jax.vmap, spmd_axis_name=spmd_axis_name) + def f(x): + self.assertEqual(x.aval.sharding.spec, P('y')) + out = x * 2 + self.assertEqual(out.aval.sharding.spec, P('y')) + return out + + if spmd_axis_name == 'x': + out = f(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + self.assertArraysEqual(out, np_inp * 2) + else: + assert spmd_axis_name == 'y' + with self.assertRaisesRegex( + ValueError, + "Only one of spmd_axis_name or arrays sharded on.*spmd_axis_name"): + f(arr) + + @parameterized.parameters( + (('data', 'model', 'stage'), ('data', 'model')), + (('data', 'stage'), 'data') + ) + @config.remove_size_one_mesh_axis_from_type(True) + @jtu.with_explicit_mesh((2, 2, 1), ('data', 'model', 'stage')) + def test_spmd_axis_name_explicit_mode_assert_remove_one_size( + self, in_spec, out_spec, mesh): + np_inp = np.arange(16).reshape(4, 2, 2) + arr = jax.device_put(np_inp, NamedSharding(mesh, P(in_spec, None))) + + @jax.jit + @partial(jax.vmap, spmd_axis_name=in_spec) + def f(x): + self.assertEqual(x.aval.sharding.spec, P(None, None)) + out = x * 2 + self.assertEqual(out.aval.sharding.spec, P(None, None)) + return out + + out = f(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P(out_spec, None, None))) + self.assertArraysEqual(out, np_inp * 2) + + @jtu.with_explicit_mesh((2,), ('x',)) def test_unmapped_last_vmap(self, mesh): np_inp = np.arange(8) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x',))) @@ -6871,7 +7327,7 @@ def f(x): self.assertEqual(out.shape, (4, 8)) self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'x'))) - @jtu.with_user_mesh((2,), ('x',), axis_types=AxisType.Auto) + @jtu.with_explicit_mesh((2,), ('x',), axis_types=AxisType.Auto) def test_shmap_close_over(self, mesh): const = jnp.arange(8) def f(): @@ -6881,7 +7337,7 @@ def f(): shmap_f() # doesn't crash jax.jit(shmap_f)() # doesn't crash - @jtu.with_user_mesh((2, 2), ('x', 'y'), + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), axis_types=(AxisType.Auto,) * 2) def test_shmap_close_over_partial_auto(self, mesh): const = jnp.arange(8) @@ -6889,44 +7345,34 @@ def f(): return const * 2 shmap_f = shard_map(f, mesh=mesh, in_specs=(), out_specs=P('x'), - auto=frozenset({'y'})) + axis_names={'x'}) f = jax.jit(shmap_f) out = f() self.assertArraysEqual(out, jnp.concatenate([const * 2, const * 2])) jaxpr = f.trace().jaxpr - self.assertIn('mesh_cast', str(jaxpr)) + self.assertIn('reshard', str(jaxpr)) - @jtu.with_user_mesh((2, 1), ('x', 'y')) + @jtu.with_explicit_mesh((2, 1), ('x', 'y')) def test_wsc_error(self, mesh): - s = NamedSharding(mesh, P('x')) - with self.assertRaisesRegex( - ValueError, - "The spec of NamedSharding passed to with_sharding_constraint"): - jax.lax.with_sharding_constraint(np.arange(8), s) - - s = NamedSharding(mesh, P(('x', 'y'), None)) with self.assertRaisesRegex( ValueError, - "The spec of NamedSharding passed to with_sharding_constraint"): - jax.lax.with_sharding_constraint(np.arange(8).reshape(4, 2), s) + 'PartitionSpec.*cannot contain `P.UNCONSTRAINED` when no mesh' + ' axis_types are `Auto`'): + NamedSharding(mesh, P(P.UNCONSTRAINED)) - s = NamedSharding(mesh, P()) - jax.lax.with_sharding_constraint(np.arange(8), s) + @jax.jit + def f(x): + return jax.lax.with_sharding_constraint(x, P('x')) - s = NamedSharding(Mesh(mesh.devices, mesh.axis_names, - axis_types=(AxisType.Explicit, AxisType.Auto)), - P('x', P.UNCONSTRAINED)) with self.assertRaisesRegex( - ValueError, - "The spec of NamedSharding passed to with_sharding_constraint"): - jax.lax.with_sharding_constraint(np.arange(8).reshape(4, 2), s) + AssertionError, + '`with_sharding_constraint` acts as an assert when all axes of mesh are' + ' of type `Explicit`'): + f(jax.device_put(np.arange(8), P())) - with self.assertRaisesRegex( - ValueError, - 'PartitionSpec.*cannot contain `P.UNCONSTRAINED` when no mesh' - ' axis_types are `Auto`'): - NamedSharding(mesh, P(P.UNCONSTRAINED)) + jaxpr = f.trace(jax.device_put(np.arange(8), P('x'))).jaxpr + self.assertNotIn('sharding_constraint', str(jaxpr)) def test_pspec_einsum_no_context_mesh(self): mesh = jtu.create_mesh((1, 1), ('x', 'y'), @@ -6944,7 +7390,7 @@ def f(x, y): "Using PartitionSpec when.*not under a mesh context.*is not allowed"): f(arr, arr2) - @jtu.with_user_mesh((2, 1), ('x', 'y'), + @jtu.with_explicit_mesh((2, 1), ('x', 'y'), axis_types=(AxisType.Auto,) * 2) def test_error_on_canonicalize_under_auto_mode(self, mesh): np_inp = np.arange(16).reshape(8, 2) @@ -6961,7 +7407,41 @@ def f(x, y): "PartitionSpec passed to einsum cannot contain axis names.*Auto.*Manual"): f(arr, arr2) - @jtu.with_user_mesh((2,), ('x',)) + def test_broadcasted_iota_mix_axes(self): + mesh = jtu.create_mesh( + (2, 2, 2), ('x', 'y', 'z'), + axis_types=(AxisType.Auto, AxisType.Explicit, AxisType.Explicit)) + yz_sharding = NamedSharding(mesh, P(('y', 'z'))) + + @jax.jit + def iota(): + out = jax.lax.broadcasted_iota( + dtype=jnp.int32, + shape=(16, 24), + dimension=1, + out_sharding=yz_sharding) + self.assertEqual(out.aval.sharding.spec, P(('y', 'z'), None)) + return out + + with jax.set_mesh(mesh): + out = iota() + self.assertEqual(out.sharding, yz_sharding) + + @jtu.with_explicit_mesh((2, 2, 2), ('x', 'y', 'z')) + def test_broadcast_to(self, mesh): + x = np.arange(24).reshape((1, 24)) + x = jax.device_put(x, P(None, ('y', 'z'))) + + @jax.jit + def f(x): + out = jnp.broadcast_to(x, (8, 24), out_sharding=P('x', ('y', 'z'))) + self.assertEqual(out.aval.sharding.spec, P('x', ('y', 'z'))) + return out + + out = f(x) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', ('y', 'z')))) + + @jtu.with_explicit_mesh((2,), ('x',)) def test_cumsum(self, mesh): np_inp = np.arange(16).reshape(8, 2) arr = jax.device_put(np_inp, NamedSharding(mesh, P())) @@ -6974,12 +7454,24 @@ def f(x): self.assertArraysEqual(out, np.cumsum(np_inp)) self.assertEqual(out.sharding, NamedSharding(mesh, P(None))) - def test_device_put_under_use_mesh(self): + @jax.jit + def f(x): + x = jnp.expand_dims(x, 1) + self.assertEqual(x.aval.sharding.spec, P('x', None)) + out = jnp.cumsum(x, axis=1) + self.assertEqual(out.aval.sharding.spec, P('x', None)) + return out + + arr2 = jax.device_put(np.arange(8), P('x')) + out = f(arr2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + + def test_device_put_under_set_mesh(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) x = jnp.zeros((4, 4), dtype=jnp.int32) x_np = np.zeros((4, 4), dtype=np.int32) s = NamedSharding(mesh, P('x', 'y')) - with jax.sharding.use_mesh(mesh): + with jax.set_mesh(mesh): y = jax.device_put(x, s) self.assertArraysEqual(y, x) self.assertEqual(y.sharding, s) @@ -6994,7 +7486,7 @@ def test_device_put_under_use_mesh(self): self.assertEqual(z.sharding, s2) @parameterized.parameters(True, False) - def test_wsc_pspec_use_mesh(self, sharded_inp): + def test_wsc_pspec_set_mesh(self, sharded_inp): mesh = jtu.create_mesh((2, 2), ('x', 'y')) np_inp = np.zeros((4, 4), dtype=np.int32) if sharded_inp: @@ -7002,12 +7494,12 @@ def test_wsc_pspec_use_mesh(self, sharded_inp): else: arr = np_inp - with jax.sharding.use_mesh(mesh): + with jax.set_mesh(mesh): out = with_sharding_constraint(arr, P('x', 'y')) self.assertArraysEqual(out, np_inp) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) - with jax.sharding.use_mesh(mesh): + with jax.set_mesh(mesh): f = jax.jit(lambda x: with_sharding_constraint(x, P('x', 'y'))) jaxpr = f.trace(arr).jaxpr self.assertIsInstance(jaxpr.eqns[0].params['sharding'].mesh, @@ -7021,26 +7513,26 @@ def test_wsc_pspec_use_mesh(self, sharded_inp): self.assertArraysEqual(out2, np_inp) self.assertEqual(out2.sharding, NamedSharding(mesh, P('x', 'y'))) - @jtu.with_user_mesh((2, 1), ('x', 'y'), + @jtu.with_explicit_mesh((2, 1), ('x', 'y'), axis_types=(AxisType.Auto,) * 2) def test_axes_api_error_manual_to_auto_explicit(self, mesh): def g(x): return auto_axes(lambda a: a * 2, axes=('x', 'y'), - out_shardings=P('x', 'y'))(x) + out_sharding=P('x', 'y'))(x) with self.assertRaisesRegex( NotImplementedError, "Going from `Manual`.*to.*`Auto`.*`Explicit`"): jax.jit(shard_map(g, mesh=mesh, in_specs=P('x', 'y'), out_specs=P('x', 'y')) )(np.arange(16).reshape(8, 2)) - @jtu.with_user_mesh((2,), ('x',)) + @jtu.with_explicit_mesh((2,), ('x',)) def test_auto_axes_numpy_array(self, mesh): @jax.jit def f(x): - self.assertTrue(x.aval.sharding.mesh._are_all_axes_auto) + self.assertTrue(x.aval.sharding.mesh.are_all_axes_auto) return x * 2 - out = auto_axes(f, out_shardings=P('x'))(np.arange(8)) + out = auto_axes(f, out_sharding=P('x'))(np.arange(8)) self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) self.assertArraysEqual(out, np.arange(8) * 2) @@ -7051,13 +7543,13 @@ def f(x): jtu.dtypes.all_integer + jtu.dtypes.all_unsigned), shape_and_spec=[((), P()), ((2,), P('x')), ((2, 4), P('x', 'y'))], ) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_bitcast_convert_type(self, from_dtype, to_dtype, shape_and_spec, mesh): shape, spec = shape_and_spec rng = jtu.rand_default(self.rng()) - nbits_in = dtypes.bit_width(from_dtype) - nbits_out = dtypes.bit_width(to_dtype) + nbits_in = dtypes.itemsize_bits(from_dtype) + nbits_out = dtypes.itemsize_bits(to_dtype) if nbits_in < nbits_out: shape = (*shape, nbits_out // nbits_in) spec = P(*(*spec, None)) @@ -7090,79 +7582,2522 @@ def f(x): self.assertEqual(out.shape, expected_shape) self.assertEqual(out.sharding, NamedSharding(mesh, expected_spec)) - def test_auto_axes_computation_follows_data_error(self): - mesh = jtu.create_mesh((2,), ('x',), axis_types=(AxisType.Explicit,)) + @jtu.with_explicit_mesh((2,), ('x',)) + def test_dynamic_slice(self, mesh): + np_inp = np.arange(16., dtype=np.float32) s = NamedSharding(mesh, P('x')) - arr = jax.device_put(np.arange(8), s) + arr = jax.device_put(np_inp, s) @jax.jit def f(x): - return x * 2 - - with self.assertRaisesRegex(ValueError, "Context mesh.*cannot be empty"): - auto_axes(f, out_shardings=s)(arr) + y = lax.dynamic_slice_in_dim(x, jnp.array(1, dtype=np.int32), 2) + self.assertEqual(y.aval.sharding.spec, P('x')) + return y - def test_divisbility_aval_error(self): - abstract_mesh = mesh_lib.AbstractMesh( - (2,), ('x',), axis_types=AxisType.Explicit) - s = NamedSharding(abstract_mesh, P('x')) - with self.assertRaisesRegex( - ValueError, 'does not evenly divide the dimension size'): - core.ShapedArray((5, 2), jnp.int32, sharding=s) + out = f(arr) + self.assertEqual(out.sharding, s) - @jtu.with_user_mesh((2, 2), ('x', 'y')) - def test_scan_unroll(self, mesh): - np_inp = np.arange(64, dtype=jnp.float32).reshape(8, 8) - arr = jax.device_put(np_inp, NamedSharding(mesh, P(None, 'y'))) - carry = jnp.ones((8,), dtype=jnp.float32) + def g(x): + return jnp.sum(f(x)) - @jax.jit - def f(carry, xs): - def body(carry, x): - return carry + x, x - return jax.lax.scan(body, carry, xs, unroll=2) + out = jax.jit(jax.grad(g))(arr) + self.assertEqual(out.sharding, arr.sharding) - f(carry, arr) # doesn't crash + out = jax.grad(g)(arr) + self.assertEqual(out.sharding, arr.sharding) - @jtu.with_user_mesh((2,), ('x',)) - def test_reshard_with_np_array(self, mesh): - out = reshard(np.arange(8), P('x')) - self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_gather_with_full_slice_in_index_start_map(self, mesh): + np_inp = np.arange(32, dtype=np.float32).reshape(2, 4, 4) + s = NamedSharding(mesh, P('x', 'y', None)) + arr = jax.device_put(np_inp, s) + # vmap dynamic_slice -> gather @jax.jit + @jax.vmap def f(x): - return reshard(x, P('x')) - out = f(np.arange(8)) - self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + zero = jnp.array(0, dtype=jnp.int32) + # Slice is full in dim 0, which is a sharded dim. + y = jax.lax.dynamic_slice(x, (zero, zero), (4, 1)) + self.assertEqual(y.aval.sharding.spec, P('y', None)) + return y - def test_set_mesh(self): - mesh = jtu.create_mesh((2,), ('x',), axis_types=(AxisType.Explicit,)) - try: - prev_mesh = jax.sharding.set_mesh(mesh) - out = reshard(np.arange(8), P('x')) - self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) - finally: - jax.sharding.set_mesh(prev_mesh) + out = f(arr) + self.assertEqual(out.sharding, s) + + def g(x): + return jnp.sum(f(x)) + + out = jax.jit(jax.grad(g))(arr) + self.assertEqual(out.sharding, arr.sharding) + + out = jax.grad(g)(arr) + self.assertEqual(out.sharding, arr.sharding) + + @jtu.with_explicit_mesh((2,), ('x',)) + def test_dynamic_update_slice(self, mesh): + arr = jax.device_put(np.arange(16., dtype=np.float32), P('x')) + update = jax.device_put(np.arange(8., dtype=np.float32), P('x')) + + @jax.jit + def f(arr, update): + y = lax.dynamic_update_slice(arr, update, (3,)) + self.assertEqual(y.aval.sharding.spec, P('x')) + return y + + out = f(arr, update) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + + def g(x, y): + return f(x, y).sum() + + out = jax.jit(jax.grad(g))(arr, update) + self.assertEqual(out.sharding, arr.sharding) + + out = jax.grad(g)(arr, update) + self.assertEqual(out.sharding, arr.sharding) + + def test_auto_axes_computation_follows_data(self): + mesh = jtu.create_mesh((2,), ('x',), axis_types=(AxisType.Explicit,)) + s = NamedSharding(mesh, P('x')) + arr = jax.device_put(np.arange(8), s) + + @jax.jit + def f(x): + return x * 2 + + out = auto_axes(f, out_sharding=s)(arr) + self.assertEqual(out.sharding, s) + self.assertArraysEqual(out, arr * 2) + + def test_divisbility_aval_error(self): + abstract_mesh = mesh_lib.AbstractMesh( + (2,), ('x',), axis_types=AxisType.Explicit) + s = NamedSharding(abstract_mesh, P('x')) + with self.assertRaisesRegex( + ValueError, 'does not evenly divide the dimension size'): + core.ShapedArray((5, 2), jnp.int32, sharding=s) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_scan_unroll(self, mesh): + np_inp = np.arange(64, dtype=jnp.float32).reshape(8, 8) + arr = jax.device_put(np_inp, NamedSharding(mesh, P(None, 'y'))) + carry = jnp.ones((8,), dtype=jnp.float32, out_sharding=P('y')) + + @jax.jit + def f(carry, xs): + def body(carry, x): + return carry + x, x + return jax.lax.scan(body, carry, xs, unroll=2) + + f(carry, arr) # doesn't crash + + def test_eval_shape_jitted_fun_cache_hit(self): + inp = jnp.zeros([1,1]) + + @jax.jit(inline=True) + def g(x): + return x * 2 + + with jtu.count_jit_tracing_cache_miss() as count: + jax.eval_shape(g, inp) + g.trace(inp) + self.assertEqual(count(), 2) # one for `g`, one for `*` + + @jtu.with_explicit_mesh((2,), ('x',)) + def test_reshard_with_np_array(self, mesh): + out = reshard(np.arange(8), P('x')) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + + @jax.jit + def f(x): + return reshard(x, P('x')) + out = f(np.arange(8)) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + + @jtu.thread_unsafe_test() + def test_set_mesh(self): + mesh = jtu.create_mesh((2,), ('x',), axis_types=(AxisType.Explicit,)) + try: + jax.set_mesh(mesh) + out = reshard(np.arange(8), P('x')) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + out_mesh = jax.sharding.get_mesh() + self.assertEqual(out_mesh, mesh) + finally: + config.abstract_mesh_context_manager.set_local( + mesh_lib.empty_abstract_mesh) + config.device_context.set_local(None) + + @jtu.with_explicit_mesh((2,), ('x',)) + def test_auto_axes_late_bind(self, mesh): + @auto_axes + def f(x): + return x * 2 + + out = f(np.arange(8), out_sharding=P('x')) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + self.assertArraysEqual(out, np.arange(8) * 2) + + @jtu.with_explicit_mesh((2,), ('x',), axis_types=AxisType.Auto) + def test_explicit_axes_late_bind(self, mesh): + @explicit_axes + def f(x): + return x * 2 + + out = f(np.arange(8), in_sharding=P('x')) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + self.assertArraysEqual(out, np.arange(8) * 2) + + @jtu.with_explicit_mesh((2,), ('x',)) + def test_rng_bit_generator(self, mesh): + def f(key): + out = lax.rng_bit_generator(key, shape=(4, 8), out_sharding=P('x')) + self.assertEqual(out[0].aval.sharding.spec, P(None)) + self.assertEqual(out[1].aval.sharding.spec, P('x', None)) + return out + + key = np.array((1, 2, 3, 4)).astype(np.uint32) + out1 = f(key) + jit_f = jax.jit(f) + out2 = jit_f(key) + self.assertEqual(out1[0].shape, (4,)) + self.assertEqual(out1[1].shape, (4, 8)) + self.assertEqual(out2[0].sharding, NamedSharding(mesh, P())) + self.assertEqual(out2[1].sharding, NamedSharding(mesh, P('x', None))) + self.assertEqual(out1[0].sharding, out2[0].sharding) + self.assertEqual(out1[1].sharding, out2[1].sharding) + self.assertArraysEqual(out1[0], out2[0]) + self.assertArraysEqual(out1[1], out2[1]) + + @jtu.with_explicit_mesh((2,), ('x',)) + def test_fold_in(self, mesh): + key = jax.random.key(72) + key = jax.device_put(key, NamedSharding(mesh, P())) + + @jax.jit + def f(key): + f1 = jax.random.fold_in(key, 1) + self.assertEqual(jax.random.key_data(f1).aval.sharding.spec, P(None)) + return f1 + + out = f(key) + self.assertEqual(out.sharding, NamedSharding(mesh, P())) + + @parameterized.named_parameters( + ("bits", partial(jax.random.bits, shape=(8, 12)), P('x', 'y')), + ("uniform", partial(jax.random.uniform, shape=(8, 12)), P('x', 'y')), + ("normal", partial(jax.random.normal, shape=(8, 12)), P('x', 'y')), + ("gumbel", partial(jax.random.gumbel, shape=(8, 12)), P('x', 'y')), + ("randint", partial(jax.random.randint, shape=(8, 12), minval=0, maxval=10), + P('x', 'y')), + ("permutation_1d", partial(jax.random.permutation, x=8), P('x')), + ("permutation_2d", partial(jax.random.permutation, + x=np.arange(8 * 12).reshape(8, 12)), + P('x', 'y')), + ) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_random_functions(self, fun, out_spec, mesh): + @jax.jit + def f(key): + out = fun(key, out_sharding=out_spec) + self.assertEqual(out.aval.sharding.spec, out_spec) + return out + + key = jax.random.key(1) + out = f(key) + self.assertEqual(out.sharding, NamedSharding(mesh, out_spec)) + + lowered_text = f.lower(key).as_text() + if config.use_shardy_partitioner.value: + self.assertIn('sdy.sharding_constraint', lowered_text) + if out_spec == P('x', 'y'): + self.assertIn('<@mesh, [{"x"}, {"y"}]>', lowered_text) + else: + assert out_spec == P('x') + self.assertIn('<@mesh, [{"x"}]>', lowered_text) + else: + if out_spec == P('x', 'y'): + self.assertIn('mhlo.sharding = "{devices=[2,2]<=[4]}"}', lowered_text) + else: + assert out_spec == P('x') + self.assertIn( + 'mhlo.sharding = "{devices=[2,2]<=[4] last_tile_dim_replicate}"}', + lowered_text) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_random_truncated_normal(self, mesh): + @jax.jit + def f(key, lower): + out = jax.random.truncated_normal(key, lower, 2., shape=(8, 12), + out_sharding=P('x', 'y')) + self.assertEqual(out.aval.sharding.spec, P('x', 'y')) + return out + + key = jax.random.key(1) + out = f(key, -1.) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + + lowered_text = f.lower(key, -1.).as_text() + if config.use_shardy_partitioner.value: + self.assertIn('sdy.sharding_constraint', lowered_text) + self.assertIn('<@mesh, [{"x"}, {"y"}]>', lowered_text) + else: + self.assertIn('mhlo.sharding = "{devices=[2,2]<=[4]}"}', lowered_text) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_random_bernoulli(self, mesh): + @jax.jit + def f(key): + out = jax.random.bernoulli(key, shape=(8, 12), out_sharding=P('x', 'y')) + self.assertEqual(out.aval.sharding.spec, P('x', 'y')) + return out + + key = jax.random.key(1) + out = f(key) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + + lowered_text = f.lower(key).as_text() + self.assertIn('sdy.sharding_constraint', lowered_text) + self.assertIn('<@mesh, [{"x"}, {"y"}]>', lowered_text) + + def test_random_normal_wo_mesh_context_error(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y'), + axis_types=(AxisType.Explicit,) * 2) + s = NamedSharding(mesh, P('x', 'y')) + + @jax.jit + def f(key): + out = jax.random.normal(key, shape=(8, 12), out_sharding=s) + self.assertEqual(out.aval.sharding.spec, P('x', 'y')) + self.assertEqual(out.aval.sharding.mesh, mesh.abstract_mesh) + return out + + key = jax.random.key(1) + with self.assertRaisesRegex( + ValueError, + 'Length of device assignment.*is not equal to the size of the mesh'): + f(key) + + def test_random_normal_wo_mesh_context(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y'), + axis_types=(AxisType.Explicit,) * 2) + s = NamedSharding(mesh, P('x', 'y')) + + @jax.jit + def f(arr, key): + out = jax.random.normal(key, shape=(8, 12), out_sharding=s) + self.assertEqual(out.aval.sharding.spec, P('x', 'y')) + return arr + out + + key = jax.random.key(1) + out = f(jax.device_put(np.arange(8 * 12.).reshape(8, 12), s), key) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + + def test_auto_axes_no_context_mesh(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y'), axis_types=(AxisType.Explicit,) * 2) + np_inp = np.arange(16.).reshape(8, 2) + s = NamedSharding(mesh, P('x', 'y')) + arr = jax.device_put(np_inp, s) + + @partial(auto_axes, axes='x', + out_sharding=NamedSharding(mesh, P('x', 'y'))) + def h(y): + self.assertEqual(y.aval.sharding.spec, P(None, 'y')) + z = jnp.sin(y) + self.assertEqual(z.aval.sharding.spec, P(None, 'y')) + return z + + out = jax.jit(h)(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + + out = h(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + + def test_scan_with_random_key_inside_jit(self): + mesh = jtu.create_mesh((2,), ('x',)) + sharding = NamedSharding(mesh, P(None, 'x')) + + @jax.jit + def scan(xs): + def step(carry, x): + next_carry = jax.vmap(jax.random.fold_in)(carry, x) + next_carry = jnp.where(x % 2 == 0, carry, next_carry) + return next_carry, None + rng = jnp.broadcast_to(jax.random.key(0), xs.shape[1:]) + rng, _ = jax.lax.scan(step, rng, xs) + return rng + + xs = jnp.arange(8).reshape(2, 4) + scan(xs) + + xs = jax.device_put(xs, sharding) + scan(xs) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_select_batch(self, mesh): + y_sharding = NamedSharding(mesh, P('y', None)) + xy_sharding = NamedSharding(mesh, P('x', 'y', None)) + batch_a = jax.device_put(jnp.ones((4, 2, 3), dtype=jnp.float32), xy_sharding) + batch_b = jax.device_put(jnp.ones((4, 2, 2), dtype=jnp.int32), xy_sharding) + + out_s = NamedSharding(mesh, P('x', 'y', None, None)) + + def select(a, b): + c = a.at[b].get(out_sharding=y_sharding) + return c + + @jax.jit + def vmap_select(batch_a, batch_b): + out = jax.vmap(select)(batch_a, batch_b) + self.assertEqual(out.aval.sharding.spec, out_s.spec) + return out + + out = vmap_select(batch_a, batch_b) + self.assertEqual(out.sharding, out_s) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_where_vmap(self, mesh): + xy_sharding = NamedSharding(mesh, P('x', 'y', None)) + batch_a = jax.device_put(jnp.ones((4, 2, 3), dtype=jnp.float32), xy_sharding) + batch_b = jax.device_put(jnp.ones((4, 2, 3), dtype=jnp.bool), xy_sharding) + + def where(a, b): + out = jnp.where(b, a, 0) + return out + + @jax.jit + def vmap_where(batch_a, batch_b): + out = jax.vmap(where)(batch_a, batch_b) + self.assertEqual(out.aval.sharding.spec, xy_sharding.spec) + return out + + out = vmap_where(batch_a, batch_b) + self.assertEqual(out.sharding, xy_sharding) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_convert_element_type_vmap(self, mesh): + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, P('x', 'y')) + am = mesh.abstract_mesh + + @jax.jit + @jax.vmap + def f(x): + y = lax_internal._convert_element_type( + x, np.dtype(jnp.bfloat16), sharding=NamedSharding(am, P('y'))) + self.assertEqual(y.aval.sharding.spec, P('y')) + return y + + out = f(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + + @jtu.with_explicit_mesh((2,), ('x',)) + def test_jnp_repeat(self, mesh): + out = jnp.repeat(np.eye(3), np.array((2,2,2,)) - 1, axis=0) + self.assertEqual(out.sharding, NamedSharding(mesh, P(None, None))) + + a = jnp.eye(3) + out = jnp.repeat(a, np.array((2,2,2,)) - 1, axis=0) + self.assertEqual(out.sharding, a.sharding) + + a = jax.device_put(jnp.eye(4), P('x')) + out = jnp.repeat(a, np.array((2,2,2,2)) - 1, axis=0, out_sharding=P('x')) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + + a = jax.device_put(jnp.eye(16).reshape(16, 16), P('x')) + @jax.jit + def f(x): + return jnp.repeat(x, 3, axis=-1) + f(a) + + @jtu.with_explicit_mesh((2,), ('x',)) + def test_scatter_gather(self, mesh): + x = np.random.uniform(size=(mesh.size * 2, 3)) + i = np.random.randint(0, x.shape[1], len(x)) + j = np.random.randint(0, x.shape[1], len(x)) + x = jax.device_put(x, P("x")) + i = jax.device_put(i, P("x")) + j = jax.device_put(j, P("x")) + + @jax.jit + def f1(x, i, j): + x_a_j = x.at[:, j].get(out_sharding=jax.typeof(i).sharding) + return x.at[:, i].set(x_a_j, out_sharding=jax.typeof(x).sharding) + f1(x,i,j) # doesn't crash + + @jax.jit + @jax.vmap + def f2(x, i, j): + x_j = x.at[j].get(out_sharding=jax.typeof(x).sharding) + return x.at[i].set(x_j, out_sharding=jax.typeof(x).sharding) + f2(x,i,j) # doesn't crash + + @jtu.with_explicit_mesh((4, 2), ('x', 'y')) + def test_conv_general_dilated(self, mesh): + arr = jax.device_put(np.zeros((16, 128, 8)), P('x', 'y')) + + @jax.jit + def f(x): + # Conv1D across sharded y-axis: + out = jax.lax.conv_general_dilated( + x, np.zeros((5, 8, 10)), + window_strides=(1,), padding='SAME', feature_group_count=1, + lhs_dilation=(1,), rhs_dilation=(1,), + dimension_numbers=('NWC', 'WIO', 'NWC')) + self.assertEqual(out.aval.sharding.spec, P('x', 'y', None)) + # Max pooling along sharded y-axis. + out2 = jax.lax.reduce_window( + out, -np.inf, jax.lax.max, (1,2,1), (1,2,1), 'SAME') + self.assertEqual(out2.aval.sharding.spec, P('x', 'y', None)) + return out2 + + out = f(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y', None))) + self.check_wsc_in_lowered(f.lower(arr).as_text()) + + jax.jit(jax.grad(lambda x: f(x).sum()))(arr) # doesn't crash + + with self.assertRaises(core.ShardingTypeError): + arr2 = jax.device_put(np.zeros((16, 128, 8)), P('x', None, 'y')) + f(arr2) + + @parameterized.named_parameters( + ('spec1', P('x', 'y', None)), + ('spec2', P('x', None, 'y')), + ('spec3', P(None, 'x', 'y')), + ('spec4', P(('x', 'y'), None, None)) + ) + @jtu.with_explicit_mesh((4, 2), ('x', 'y')) + def test_reduce_window(self, spec, mesh): + arr = jax.device_put(np.zeros((16, 128, 8)), spec) + + @jax.jit + def f(x): + out = jax.lax.reduce_window( + x, -np.inf, jax.lax.max, (1,2,1), (1,2,1), 'SAME') + self.assertEqual(out.aval.sharding.spec, spec) + return out + + out = f(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, spec)) + self.check_wsc_in_lowered(f.lower(arr).as_text()) + + jax.jit(jax.grad(lambda x: f(x).sum()))(arr) # doesn't crash + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_jnp_dot(self, mesh): + np_inp1 = np.arange(16).reshape(8, 2) + np_inp2 = np.arange(16).reshape(2, 8) + arr1 = jax.device_put(np_inp1, P('x', 'y')) + arr2 = jax.device_put(np_inp2, P('x', 'y')) + + @jax.jit + def f(x, y): + out = jnp.dot(x, y, out_sharding=P('x')) + self.assertEqual(out.aval.sharding.spec, P('x', None)) + return out + + out = f(arr1, arr2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + self.assertArraysEqual(out, np.dot(np_inp1, np_inp2)) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_jnp_matmul(self, mesh): + np_inp1 = np.arange(16).reshape(8, 2) + np_inp2 = np.arange(16).reshape(2, 8) + arr1 = jax.device_put(np_inp1, P('x', 'y')) + arr2 = jax.device_put(np_inp2, P('x', 'y')) + + @jax.jit + def f(x, y): + out = jnp.matmul(x, y, out_sharding=P('x')) + self.assertEqual(out.aval.sharding.spec, P('x', None)) + return out + + out = f(arr1, arr2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + self.assertArraysEqual(out, np.dot(np_inp1, np_inp2)) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_jnp_tensordot(self, mesh): + np_inp1 = np.arange(16).reshape(8, 2) + np_inp2 = np.arange(12).reshape(2, 6) + arr1 = jax.device_put(np_inp1, P('x', 'y')) + arr2 = jax.device_put(np_inp2, P('x', 'y')) + + @jax.jit + def f(x, y): + out = jnp.tensordot(x, y, axes=([1], [0]), out_sharding=P('x', None)) + self.assertEqual(out.aval.sharding.spec, P('x', None)) + return out + + out = f(arr1, arr2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + self.assertArraysEqual(out, np.dot(np_inp1, np_inp2)) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_jnp_ravel(self, mesh): + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, P('x', 'y')) + + @jax.jit + def f(x): + out = jnp.ravel(x, out_sharding=P('x')) + self.assertEqual(out.aval.sharding.spec, P('x')) + return out + + out = f(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + self.assertArraysEqual(out, np.ravel(np_inp)) + + @jtu.with_explicit_mesh((4, 2), ('x', 'y')) + def test_broadcast_forwarding(self, mesh): + arr = jax.device_put(np.zeros(()), P()) + + def f(x): + out = jax.lax.full_like(x, 1.0) + self.assertEqual(jax.typeof(out).sharding, jax.typeof(x).sharding) + return out + + f(arr) # doesn't crash + jax.jit(f)(arr) # doesn't crash + + @jtu.with_explicit_mesh((2,), 'x') + def test_unreduced_einsum_basic(self, mesh): + np_inp = np.arange(4).reshape(2, 2) + x = jax.device_put(np_inp, P(None, 'x')) + y = jax.device_put(np_inp, P('x', None)) + + @jax.jit + def f(x, y): + out = jnp.einsum('ab,bc->ac', x, y, + out_sharding=P(None, None, unreduced={'x'})) + self.assertEqual(out.aval.sharding.spec, P(None, None, unreduced={'x'})) + return out + + out = f(x, y) + self.assertEqual(out.sharding, + NamedSharding(mesh, P(None, None, unreduced={'x'}))) + self.assertEqual(out.shape, (2, 2)) + self.assertEqual(out.sharding.shard_shape(out.shape), (2, 2)) + + expected_shards = [np.array([[0, 0], [0, 2]]), np.array([[2, 3], [6, 9]])] + for s, es in zip(out.addressable_shards, expected_shards): + self.assertEqual(s.data.shape, (2, 2)) + self.assertArraysEqual(s.data, es) + + reshard_out = reshard(out, P(None, None)) + self.assertArraysEqual(reshard_out, np_inp @ np_inp) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_unreduced_einsum_add_basic(self, mesh): + np_inp = np.arange(16.).reshape(8, 2) + x = jax.device_put(np_inp, P('x', 'y')) + y = jax.device_put(np_inp.T, P('y', None)) + a = jax.device_put(np_inp, P('x', 'y')) + b = jax.device_put(np_inp.T, P('y', None)) + + @jax.jit + def f(x, y, a, b): + m1 = jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', unreduced={'y'})) + self.assertEqual(m1.aval.sharding.spec, P('x', None, unreduced={'y'})) + + m2 = jnp.einsum('xy,yz->xz', a, b, out_sharding=P('x', unreduced={'y'})) + self.assertEqual(m2.aval.sharding.spec, P('x', None, unreduced={'y'})) + + s = m1 + m2 # unreduced + self.assertEqual(s.aval.sharding.spec, P('x', None, unreduced={'y'})) + + out = reshard(s, P('x')) # reduce + self.assertEqual(out.aval.sharding.spec, P('x', None)) + return out + + out = f(x, y, a, b) + self.assertArraysEqual(out, (np_inp @ np_inp.T) + (np_inp @ np_inp.T)) + + traced = f.trace(x, y, a, b) + lowered_text = traced.lower().as_text() + self.assertIn('unreduced={"y"}', lowered_text) + self.assertEqual(lowered_text.count('unreduced={"y"}'), 3) + + f_bar = jax.jit(jax.grad(lambda x, y, a, b: f(x, y, a, b).sum(), + argnums=(0, 1, 2, 3))) + f_bar(x, y, a, b) # doesn't crash + + grad_jaxpr = f_bar.trace(x, y, a, b).jaxpr + reshard_eqn = grad_jaxpr.eqns[4].params['jaxpr'].eqns[0] + self.assertEqual(reshard_eqn.params['dst_sharding'].spec.reduced, + frozenset('y')) + self.assertEqual(reshard_eqn.params['dst_sharding'].spec.unreduced, + frozenset()) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_einsum_unreduced_with_transpose(self, mesh): + arr1 = jax.device_put(jnp.arange(192).reshape(6, 4, 8), P(None, 'x', 'y')) + arr2 = jax.device_put(jnp.arange(320).reshape(4, 8, 10), P('x', 'y', None)) + + @jax.jit + def f(arr1, arr2): + out = jnp.einsum("heb,eba->hea", arr1, arr2, + out_sharding=P(None, 'x', None, unreduced={'y'})) + self.assertEqual(out.aval.sharding.spec, + P(None, 'x', None, unreduced={'y'})) + return out + + out = f(arr1, arr2) + self.assertEqual(out.sharding, + NamedSharding(mesh, P(None, 'x', None, unreduced={'y'}))) + + reshard_out = jax.sharding.reshard(out, P(None, 'x', None)) + expected_out = jnp.einsum("heb,eba->hea", arr1, arr2, + out_sharding=P(None, 'x', None)) + self.assertArraysEqual(reshard_out, expected_out) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_unreduced_multi_axes_einsum(self, mesh): + x = jax.device_put(np.arange(16.).reshape(8, 2), P(('x', 'y'), None)) + y = jax.device_put(np.arange(8.), P(('x', 'y'))) + + @jax.jit + def f(x, y): + out = jnp.einsum('ij,i->j', x, y, + out_sharding=P(None, unreduced={'x', 'y'})) + self.assertEqual(out.aval.sharding.spec.unreduced, {'x', 'y'}) + return out + + out = f(x, y) + self.assertEqual(out.sharding, + NamedSharding(mesh, P(None, unreduced={'x', 'y'}))) + + reshard_out = reshard(out, P(None)) + expected_out = jnp.einsum('ij,i->j', x, y, out_sharding=P(None)) + self.assertArraysEqual(reshard_out, expected_out) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_unreduced_multi_axes_none_einsum(self, mesh): + np_inp = np.arange(64.).reshape(4, 4, 2, 2) + x = jax.device_put(np_inp, P(None, 'x', None, 'y')) + y = jax.device_put(np_inp, P(None, 'x', None, 'y')) + + @jax.jit + def f(x, y): + out = jnp.einsum('bxyz,bxyz->b', x, y, + out_sharding=P(None, unreduced={'x', 'y'})) + self.assertEqual(out.aval.sharding.spec.unreduced, {'x', 'y'}) + return out + + out = f(x, y) + self.assertEqual(out.sharding, + NamedSharding(mesh, P(None, unreduced={'x', 'y'}))) + + reshard_out = reshard(out, P(None)) + expected_out = jnp.einsum('bxyz,bxyz->b', x, y, out_sharding=P(None)) + self.assertArraysEqual(reshard_out, expected_out) + + @jtu.with_explicit_mesh((2, 2, 1), ('x', 'y', 'z')) + def test_dot_general_unreduced_error(self, mesh): + np_inp = np.arange(16).reshape(8, 2) + # Case 1 + x = jax.device_put(np_inp, P('x', 'y')) + y = jax.device_put(np_inp.T, P('y', None)) + + @jax.jit + def f(x, y): + return jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', unreduced={'z'})) + with self.assertRaisesRegex( + core.ShardingTypeError, + "unreduced axes should be equal to the contracting specs"): + f.trace(x, y) + + # Case 2 + x = jax.device_put(np_inp, P('x', 'y')) + y = jax.device_put(np_inp.T, P(None, None)) + @jax.jit + def g(x, y): + return jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', unreduced={'y'})) + with self.assertRaisesRegex( + core.ShardingTypeError, + "lhs and rhs contracting dims should be sharded identically"): + g.trace(x, y) + + # Case 3 + x = jax.device_put(np_inp, P('x', None)) + y = jax.device_put(np_inp.T, P(None, None)) + + @jax.jit + def h(x, y): + return jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', unreduced={'y'})) + with self.assertRaisesRegex( + core.ShardingTypeError, + "unreduced axes should be equal to the contracting specs"): + h.trace(x, y) + + # Case 4 + x = jax.device_put(np_inp, P('x', 'y')) + y = jax.device_put(np_inp.T, P('y', None)) + + @jax.jit + def k(x, y): + z = jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', unreduced={'y'})) + return jnp.einsum('xy,yz->xz', z, x) + with self.assertRaisesRegex( + core.ShardingTypeError, + "lhs or rhs passed to dot_general cannot be unreduced"): + k.trace(x, y) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_three_operand_einsum_unreduced(self, mesh): + n1, n2, n3 = np.arange(2), np.arange(8).reshape(2, 4), np.arange(2) + arr1 = jax.device_put(n1, P('x')) + arr2 = jax.device_put(n2, P('x', 'y')) + arr3 = jax.device_put(n3, P('x')) + + @jax.jit + def f(arr1, arr2, arr3): + out = jnp.einsum('d,dh,d->h', arr1, arr2, arr3, + out_sharding=P('y', unreduced={'x'})) + self.assertEqual(out.aval.sharding.spec, P('y', unreduced={'x'})) + return out + + lowered_text = f.lower(arr1, arr2, arr3).as_text() + self.assertEqual(lowered_text.count('unreduced={"x"}'), 2) + + out = f(arr1, arr2, arr3) + self.assertEqual(out.sharding, NamedSharding(mesh, P('y', unreduced={'x'}))) + + reshard_out = reshard(out, P('y')) + expected_out = np.einsum("d,dh,d->h", n1, n2, n3) + self.assertArraysEqual(reshard_out, expected_out) + + @jtu.with_explicit_mesh((2, 2, 1), ('x', 'y', 'z')) + def test_add_unreduced_error(self, mesh): + np_inp = np.arange(16).reshape(8, 2) + x = jax.device_put(np_inp, P('x', 'y')) + y = jax.device_put(np_inp.T, P('y', None)) + a = jax.device_put(np_inp, P('x', 'z')) + b = jax.device_put(np_inp.T, P('z', None)) + + @jax.jit + def f(x, y, a, b): + m1 = jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', unreduced={'y'})) + m2 = jnp.einsum('xy,yz->xz', a, b, out_sharding=P('x', unreduced={'z'})) + return m1 + m2 + + with self.assertRaisesRegex( + core.ShardingTypeError, + "lhs and rhs to `add` must be unreduced along the same mesh axes"): + f.trace(x, y, a, b) + + @jax.jit + def g(x, y): + m1 = jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', unreduced={'y'})) + m2 = jnp.einsum('xy,yz->xz', a, b, out_sharding=P('x')) + return m1 + m2 + + with self.assertRaisesRegex( + core.ShardingTypeError, "lhs is unreduced while rhs is not"): + g.trace(x, y) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_eval_shape(self, mesh): + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, P('x', 'y')) + + @jax.jit + def f(x): + return x * 2 + + out = jax.eval_shape(f, arr) + self.assertIsInstance(out, jax.ShapeDtypeStruct) + self.assertEqual(out.sharding, + NamedSharding(mesh.abstract_mesh, P('x', 'y'))) + + @jtu.with_explicit_mesh((2,), ('x',)) + def test_he_normal(self, mesh): + init = jax.nn.initializers.he_normal(in_axis=0, out_axis=1) + key = jax.random.key(0) + out = init(key, (8, 2), jnp.float32, out_sharding=P('x')) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + + @jtu.with_explicit_mesh((2,), ('x',)) + def test_nn_uniform(self, mesh): + init = jax.nn.initializers.uniform() + key = jax.random.key(0) + out = init(key, (8, 2), jnp.float32, out_sharding=P('x')) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + + @jtu.with_explicit_mesh((2,), ('x',)) + def test_nn_constant(self, mesh): + init = jax.nn.initializers.constant(-7) + key = jax.random.key(0) + out = init(key, (8, 2), jnp.float32, out_sharding=P('x')) + self.assertArraysEqual(out, jnp.full((8, 2), -7, dtype=jnp.float32)) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + + @config.numpy_rank_promotion('allow') + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_lax_map(self, mesh): + def simple_func(w, x): + return jnp.sum(w * x, axis=-1) + + w = jax.device_put(np.arange(4, dtype=np.float32), P('x')) + x = jax.device_put(np.ones((4, 2, 4), dtype=np.float32), + P(None, 'y', None)) + + jax.lax.map(lambda _x: simple_func(w, _x), x) # doesn't crash + + jax.lax.map(lambda _x: simple_func(w, _x), x, batch_size=2) # doesn't crash + + @config.numpy_rank_promotion('allow') + @jtu.with_explicit_mesh((2,), ('x',)) + def test_lax_map_remainder(self, mesh): + def simple_func(w, x): + return jnp.sum(w * x, axis=-1) + + w = jax.device_put(np.arange(4, dtype=np.float32), P()) + x = jax.device_put(np.ones((5, 2, 4), dtype=np.float32), + P(None, 'x', None)) + + jax.lax.map(lambda _x: simple_func(w, _x), x, batch_size=2) # doesn't crash + + @jtu.with_explicit_mesh((2,), ('x',)) + def test_extended_dtypes(self, mesh): + dtype = primal_tangent_dtype(jnp.dtype('int8'), jnp.dtype('bfloat16')) + + @jax.jit + def f(x): + x = jax.lax.convert_element_type(x, dtype) + self.assertEqual(x.aval.sharding.spec, P('x')) + x = jax.lax.convert_element_type(x, 'int8') + self.assertEqual(x.aval.sharding.spec, P('x')) + return x + + x = jax.device_put(jnp.arange(8, dtype='int8'), P('x')) + out = f(x) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + + @jtu.with_explicit_mesh((2,), 'x') + def test_dot_empty_mesh_lhs_rhs(self, mesh): + np_inp = np.ones((2, 2)) + arr = jax.device_put(np.ones((2, 2)), P('x')) + + @jax.jit + def f(x, y): + return jnp.dot(x, y) + + out = f(np_inp, arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P(None, None))) + + def g(x, y): + return jnp.sum(f(x, y)) + + jax.jit(jax.grad(g))(np_inp, arr) # doesn't crash + + out2 = f(arr, np_inp) + self.assertEqual(out2.sharding, NamedSharding(mesh, P('x', None))) + jax.jit(jax.grad(g))(arr, np_inp) # doesn't crash + + @parameterized.named_parameters( + ('mesh1', (1, 4)), + ('mesh2', (2, 2)), + ) + def test_reshape_merge_replicated(self, axis_sizes): + mesh = jtu.create_mesh(axis_sizes, ('x', 'y'), + axis_types=(AxisType.Explicit,) * 2) + with jax.set_mesh(mesh): + np_inp = np.ones((8,4,4)) + arr = jax.device_put(np_inp, P(None, None, 'x')) + out = jnp.reshape(arr, (-1, arr.shape[-1])) + self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'x'))) + + @config.numpy_dtype_promotion('standard') + @jtu.with_explicit_mesh((2,), 'x') + def test_lax_switch_vmap(self, mesh): + batch_idx = jnp.array([0, 1, 2, 4]) + val = jax.numpy.ones((), dtype=jnp.float32) + batch_idx_shard = reshard(batch_idx, P('x')) + + def switch_fun(val, index): + def branch(args): + val_, index_ = args + return val_ + index_ + branches = [branch for _ in range(5)] + out = jax.lax.switch(index, branches, (val, index)) + return out + + vmap_switch_fun = jax.vmap(switch_fun, in_axes=(None, 0), out_axes=0) + vmap_switch_fun(val, batch_idx_shard) # doesn't crash + jax.jit(vmap_switch_fun)(val, batch_idx_shard) # doesn't crash + + @jtu.with_explicit_mesh((2,), 'x') + def test_lax_switch_vmap_random(self, mesh): + @partial(jax.vmap, in_axes=(None, 0,0)) + def flip_state_scalar(key, state, index): + def add_rng_to(_): + return jax.random.uniform(key, shape=(), minval=0.0, maxval=1.0) + + branches = [add_rng_to for _ in range(state.shape[0])] + new_state = jax.lax.switch(index, branches, state) + return new_state + + batch_states = reshard(jnp.zeros((4, 5)), P('x')) + batch_idxs = reshard(jnp.array([0, 1, 2, 4]), P('x')) + key = jax.random.key(42) + + flip_state_scalar(key, batch_states, batch_idxs) # doesn't crash + jax.jit(flip_state_scalar)(key, batch_states, batch_idxs) # doesn't crash + + @jtu.with_explicit_mesh((2,), 'x') + def test_jnp_zeros_out_sharding(self, mesh): + s = NamedSharding(mesh, P('x')) + + out = jnp.zeros((8,), jnp.float32, out_sharding=s) + self.assertEqual(out.sharding, s) + self.assertArraysEqual(out, np.zeros((8,), np.float32)) + + out = jnp.zeros((8,), jnp.float32, out_sharding=P('x')) + self.assertEqual(out.sharding, s) + self.assertArraysEqual(out, np.zeros((8,), np.float32)) + + @jax.jit + def f(x): + return jnp.zeros(x.shape, x.dtype, out_sharding=P('x')) + out = f(np.arange(8, dtype=np.float32)) + self.assertEqual(out.sharding, s) + self.assertArraysEqual(out, np.zeros((8,), np.float32)) + + @jtu.with_explicit_mesh((2,), 'x') + def test_at_add_lower_to_scatter_add(self, mesh): + x = jnp.zeros((2, 10)) + ids = jnp.array([1, 3, 7, 10]) + scalar = 5.0 + + x_updated = x.at[:, ids].add(scalar) + self.assertEqual(x_updated.sharding, NamedSharding(mesh, P(None, None))) + + xs = reshard(x, P('x', None)) + x_updated = xs.at[:, ids].add(scalar) + self.assertEqual(x_updated.sharding, NamedSharding(mesh, P('x', None))) + + @jax.jit(static_argnames=('out_sharding',)) + def f(x, ids, scalar, out_sharding = None): + out = x.at[:, ids].add(scalar, out_sharding=out_sharding) + self.assertEqual(out.aval.sharding.spec, P('x', None)) + return out + + out = f(xs, ids, scalar) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + + out = f(xs, reshard(ids, P('x')), scalar, out_sharding=P('x', None)) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + + @config.numpy_dtype_promotion('standard') + @jtu.with_explicit_mesh((2,), 'x') + def test_explicit_complex_grad(self, mesh): + @jax.jit + def f(x): + return (x * 1j).real + + jax.grad(f)(1.0) # doesn't crash + jax.jit(jax.grad(f))(1.0) # doesn't crash + + @jtu.with_explicit_mesh((2,), 'x') + def test_explicit_complex(self, mesh): + x = jnp.arange(8, dtype=np.float32) + y = np.arange(8, dtype=np.float32) + + @jax.jit + def f(x, y): + return jax.lax.complex(x, y) + + f(x, y) # doesn't crash + + @jtu.with_explicit_mesh((2,), 'x') + def test_scatter_jit_invariant(self, mesh): + def f(): + a = jnp.zeros((10,10)) + val = jnp.ones((5,5)) + b = a.at[::2,::2].set(val) + return b + + f() + jax.jit(f)() + + @config.numpy_rank_promotion('allow') + @jtu.with_explicit_mesh((2,), 'x') + def test_lax_map_batch_size_greater_than_input(self, mesh): + w = jnp.arange(4, dtype=np.float32) + x = jax.device_put(jnp.ones((5, 2, 4), dtype=np.float32), P(None, 'x', None)) + + def f(w, x): + return jnp.sum(w * x, axis=-1) + + jax.lax.map(lambda _x: f(w, _x), x, batch_size=10) # doesn't crash + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_explicit_ctx_vmap_over_auto_axes(self, mesh): + xs = jax.device_put(jnp.ones((20, 2)), P(('x', 'y'), None)) + + @partial(jax.sharding.auto_axes, out_sharding=P()) + def f(x): + return jnp.where(x, jnp.ones(2), x) + + out = jax.jit(jax.vmap(f))(xs) + self.assertEqual(out.sharding, NamedSharding(mesh, P(('x', 'y'), None))) + + @jtu.with_explicit_mesh((2,), 'x') + def test_jnp_arange_out_sharding(self, mesh): + s = NamedSharding(mesh, P('x')) + + out = jnp.arange(8, dtype=jnp.float32, out_sharding=s) + self.assertEqual(out.sharding, s) + self.assertArraysEqual(out, np.arange(8, dtype=np.float32)) + + out = jnp.arange(8, dtype=jnp.float32, out_sharding=P('x')) + self.assertEqual(out.sharding, s) + self.assertArraysEqual(out, np.arange(8, dtype=np.float32)) + + out = jnp.arange(start=8, stop=16, dtype=jnp.float32, out_sharding=P('x')) + self.assertEqual(out.sharding, s) + self.assertArraysEqual(out, np.arange(start=8, stop=16, dtype=np.float32)) + + out = jnp.arange(start=8, stop=16, step=2, dtype=jnp.float32, + out_sharding=P('x')) + self.assertEqual(out.sharding, s) + self.assertArraysEqual(out, np.arange(start=8, stop=16, step=2, + dtype=np.float32)) + + @jax.jit + def f(): + return jnp.arange(8, dtype=np.float32, out_sharding=P('x')) + out = f() + self.assertEqual(out.sharding, s) + self.assertArraysEqual(out, np.arange(8, dtype=np.float32)) + + def test_set_mesh_error(self): + mesh = Mesh(jax.devices(), 'x', (AxisType.Manual,)) + with self.assertRaisesRegex(ValueError, ".*contains manual axes"): + jax.set_mesh(mesh) + + @jtu.with_explicit_mesh((1,), 'x') + def test_auto_axes_single_device(self, mesh): + @partial(auto_axes, out_sharding=P('x')) + def f(x): + return x * 2 + + out = f(jnp.ones((2, 3))) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + + out = jax.jit(f)(jnp.ones((2, 3))) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), + axis_types=(AxisType.Explicit, AxisType.Auto)) + def test_mix_axis_type_in_pspec(self, mesh): + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, P(('x', 'y'), None)) + + @jax.jit + def f(x): + self.assertEqual(x.aval.sharding.spec, P('x', None)) + out = x * 2 + self.assertEqual(out.aval.sharding.spec, P('x', None)) + return out + + out = f(arr) + self.assertArraysEqual(out, np_inp * 2) + + lowered_text = f.lower(arr).as_text() + if config.use_shardy_partitioner.value: + self.assertEqual(lowered_text.count("{?}"), 3) + self.assertEqual(lowered_text.count('{"x", ?}'), 3) + else: + self.assertEqual(lowered_text.count('unspecified_dims=[0,1]'), 3) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), axis_types=(AxisType.Auto,) * 2) + def test_vmap_spmd_axis_name_explicit_axes_inside(self, mesh): + np_inp = np.arange(16).reshape(2, 8) + arr1 = jax.device_put(np_inp, P()) + arr2 = jax.device_put(np_inp, P()) + + @jax.jit + def f(x, y): + @explicit_axes(in_sharding=(P('y'), P('y'))) + def g(a, b): + self.assertEqual(a.aval.sharding.spec, P('y')) + self.assertEqual(b.aval.sharding.spec, P('y')) + a = reshard(a, P()) + self.assertEqual(a.aval.sharding.spec, P(None)) + out = a * b + self.assertEqual(out.aval.sharding.spec, P('y')) + return out + return g(x, y) + + out = jax.jit(jax.vmap(f, spmd_axis_name='x'))(arr1, arr2) # doesn't crash + + @parameterized.named_parameters( + ('replicated', None), + ('sharded', 'x'), + ) + @jtu.with_explicit_mesh((2,), ('x',)) + def test_gather(self, spec, mesh): + tokens = jnp.arange(32 * 257).reshape(32, 257, out_sharding=P(spec)) + + @jax.jit + def f(x): + out = tokens[:, :-1] + return out + + out = f(tokens) + self.assertEqual(out.shape, (32, 256)) + self.assertEqual(out.sharding, NamedSharding(mesh, P(spec, None))) + + @jtu.with_explicit_mesh((2,), 'x') + def test_linalg_slogdet_and_solve(self, mesh): + B, N = 8, 4 + key = jax.random.key(0) + A = jax.random.normal(key, (B, N, N)) + As = reshard(A, P('x')) + + jnp.linalg.slogdet(As) # doesn't crash + jax.vmap(jnp.linalg.slogdet)(As) # doesn't crash + + b = jax.random.normal(key, (B, N)) + bs = reshard(b, P('x')) + jax.vmap(jnp.linalg.solve)(As, bs) # doesn't crash + + b2 = b.reshape((*b.shape, 1)) + bs2 = reshard(b2, P('x')) + jnp.linalg.solve(As, bs2) # doesn't crash + + @jtu.with_explicit_mesh((2,), 'x') + def test_sparse_linalg_cg_indexing(self, mesh): + key = jax.random.key(123) + samples = jax.random.randint(key, (2, 2), 0, 2, dtype=jnp.int32) + params = jax.random.normal(key, (2, 2), jnp.complex64) + + def apply_fn(params, x): + def single_apply(qn): + qn_idx = qn[0] + result = params[qn_idx, :] + return jnp.sum(result) + return jax.vmap(single_apply)(x) + + def mat_vec(v): + _, jvp_fn = jax.linearize(lambda W: apply_fn(W, samples), params) + vjp_fn = jax.linear_transpose(jvp_fn, v) + w = jvp_fn(v) + (res,) = vjp_fn(w) + return res + + jax.scipy.sparse.linalg.cg(mat_vec, params) # doesn't crash + + @jtu.with_explicit_mesh((4, 2), ('x', 'y'), + axis_types=(AxisType.Explicit, AxisType.Auto)) + def test_wsc_mix_axis_types(self, mesh): + arr = jax.device_put(np.arange(16).reshape(8, 2), P('x')) + + @jax.jit + def f(x): + return jax.lax.with_sharding_constraint(x, P('y')) + + out = f(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P(('x', 'y')))) + + @jtu.with_explicit_mesh((2,), 'x') + def test_reshard_rng_keys(self, mesh): + key = jax.random.key(12) + + @jax.jit + def f(): + return reshard(key, P()) + + out = f() + self.assertEqual(out.sharding, NamedSharding(mesh, P())) + + @jtu.with_explicit_mesh((2,), 'x', axis_types=(AxisType.Auto,)) + def test_reshard_rng_keys_sharding(self, mesh): + key = jax.random.key(12) + explicit_mesh = jax.sharding.AbstractMesh((2,), ('x',), (AxisType.Explicit,)) + + @jax.jit + def f(): + return reshard(key, NamedSharding(explicit_mesh, P())) + + f() # doesn't crash + + @jtu.with_explicit_mesh((2,), 'x') + def test_add_transpose(self, mesh): + x = jax.device_put(jnp.arange(4.), P()) + y = jax.device_put(jnp.arange(4.), P('x')) + + @jax.jit + def f(x, y): + return x + y + + out = f(x, y) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + + out = jax.jit(jax.grad(lambda x, y: f(x, y).sum(), argnums=(0, 1)))(x, y) + self.assertEqual(out[0].sharding, NamedSharding(mesh, P(None))) + self.assertEqual(out[1].sharding, NamedSharding(mesh, P('x'))) + + @jtu.with_explicit_mesh((2,), ('x',)) + def test_dynamic_update_slice_transpose(self, mesh): + x = jax.device_put(jnp.arange(8.), P()) + y = jax.device_put(jnp.arange(4.), P()) + z = jax.device_put(jnp.arange(8.), P('x')) + + @jax.jit + def f(x, y, z): + x_updated = jax.lax.dynamic_update_slice(x, y, (1,)) + w = x_updated + z + self.assertEqual(w.aval.sharding.spec, P('x')) + return w.sum() + + out = jax.jit(jax.grad(f, argnums=(0, 1, 2)))(x, y, z) + self.assertEqual(out[0].sharding, NamedSharding(mesh, P(None))) + self.assertEqual(out[1].sharding, NamedSharding(mesh, P(None))) + self.assertEqual(out[2].sharding, NamedSharding(mesh, P('x'))) + + @parameterized.named_parameters( + ('fully_replicated', P(None, None)), + ('sharded', P(None, 'x')) + ) + @jtu.with_explicit_mesh((2,), ('x',)) + def test_gather_transpose(self, z_spec, mesh): + x = jax.device_put(jnp.arange(16.).reshape(2, 8), P()) + y = jax.device_put(jnp.arange(6.).reshape(2, 3), P()) + z = jax.device_put(jnp.arange(16.).reshape(2, 8), z_spec) + + @jax.jit + @jax.vmap + def f(x, y, z): + x_updated = jax.lax.dynamic_update_slice(x, y, (1,)) + w = x_updated + z + self.assertEqual(w.aval.sharding.spec, P(z_spec[-1])) + return w.sum() + + f(x, y, z) # doesn't crash + + def g(x, y, z): + return f(x, y, z).sum() + + out = jax.jit(jax.grad(g, argnums=(0, 1, 2)))(x, y, z) + self.assertEqual(out[0].sharding, NamedSharding(mesh, P(None, None))) + self.assertEqual(out[1].sharding, NamedSharding(mesh, P(None, None))) + self.assertEqual(out[2].sharding, NamedSharding(mesh, z_spec)) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_multi_einsum_intermediate_einsum_auto(self, mesh): + AB = jnp.ones((8,16), out_sharding=P('x', 'y')) + BC = jnp.ones((16, 4), out_sharding=P('y', None)) + AC = jnp.ones((8, 4), out_sharding=P('x', None)) + + out = jnp.einsum('ab,bc,ac->c', AB, BC, AC, out_sharding=jax.P()) + self.assertEqual(out.sharding, NamedSharding(mesh, P(None))) + + def test_auto_axes_no_op(self): + mesh = jtu.create_mesh((1, 1), ('x', 'y')) + with jax.set_mesh(mesh): + @auto_axes(out_sharding=P('x')) + def f(x): + return x * 2 + out = jax.jit(f)(jnp.arange(8)) # doesn't crash + self.assertArraysEqual(out, np.arange(8) * 2) + jaxpr = jax.jit(f).trace(jnp.arange(8)).jaxpr + self.assertNotIn('reshard', str(jaxpr)) + + mesh = jtu.create_mesh((1, 1), ('x', 'y'), + axis_types=(AxisType.Auto, AxisType.Explicit)) + with jax.set_mesh(mesh): + @auto_axes(out_sharding=P('y'), axes='x') + def f(x): + return x * 2 + out = jax.jit(f)(jnp.arange(8)) # doesn't crash + self.assertArraysEqual(out, np.arange(8) * 2) + jaxpr = jax.jit(f).trace(jnp.arange(8)).jaxpr + self.assertNotIn('reshard', str(jaxpr)) + + @auto_axes(out_sharding=P('y')) + def f(x): + return x * 2 + out = jax.jit(f)(jnp.arange(8)) # doesn't crash + self.assertArraysEqual(out, np.arange(8) * 2) + jaxpr = jax.jit(f).trace(jnp.arange(8)).jaxpr + self.assertEqual(str(jaxpr).count('reshard'), 2) + + @jtu.with_explicit_mesh((2,), 'x') + def test_nary_op_input_constraint(self, mesh): + arr = jax.device_put(np.arange(16).reshape(8, 2), P('x')) + + @jax.jit + def f(x): + var = jnp.broadcast_to(2, x.shape) + out = x * var + return out + + lowered_text = f.lower(arr).as_text() + matches = re.findall(r'sdy.sharding_constraint.*\[\{\"x\"\}, \{\}\]', + lowered_text) + self.assertEqual(len(matches), 2) + + out = f(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + self.assertArraysEqual(out, arr * 2) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_abstract_mesh_changing_names_sizes(self, mesh): + np_inp = np.arange(16.).reshape(8, 2) + abstract_mesh = mesh.abstract_mesh + + @jax.jit + def f(x): + x = jnp.sin(x) + with jax.sharding.use_abstract_mesh( + abstract_mesh.update(axis_sizes=(4, 1), axis_names=('a', 'b'))): + x = reshard(x, P(('a', 'b'), None)) + out = x * 2 + self.assertEqual(out.aval.sharding.spec, P(('a', 'b'), None)) + return out + + arr = jax.device_put(np_inp, P('x', 'y')) + out = f(arr) + self.assertEqual(out.sharding.spec, P(('a', 'b'), None)) + self.assertArraysEqual(out, jnp.sin(np_inp) * 2) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_abstract_mesh_changing_names_sizes_grad(self, mesh): + np_inp = np.arange(16.).reshape(8, 2) + abstract_mesh = mesh.abstract_mesh + + @jax.jit + def f(x): + x = jnp.sin(x) + with jax.sharding.use_abstract_mesh( + abstract_mesh.update(axis_sizes=(4, 1), axis_names=('a', 'b'))): + x = reshard(x, P(('a', 'b'), None)) + out = x * 2 + self.assertEqual(out.aval.sharding.spec, P(('a', 'b'), None)) + return reshard(out, P('y', None)) + + arr = jax.device_put(np_inp, P('x', 'y')) + out = f(arr) + self.assertEqual(out.sharding.spec, P('y', None)) + + out = jax.jit(jax.grad(lambda x: f(x).sum()))(arr) + self.assertEqual(out.sharding, arr.sharding) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), axis_types=(AxisType.Auto,) * 2) + def test_abstract_mesh_changing_names_sizes_auto_mode(self, mesh): + np_inp = np.arange(16.).reshape(8, 2) + abstract_mesh = mesh.abstract_mesh + + @jax.jit + def f(x): + x = jnp.sin(x) + new_am = abstract_mesh.update(axis_sizes=(4, 1), axis_names=('a', 'b')) + with jax.sharding.use_abstract_mesh(new_am): + x = with_sharding_constraint(x, P(('a', 'b'), None)) + out = x * 2 + self.assertEqual(out.aval.sharding.mesh, new_am) + return out + + arr = jax.device_put(np_inp, P('x', 'y')) + out = f(arr) + self.assertEqual(out.sharding.spec, P('a')) + + @jtu.run_on_devices('cpu') + @jtu.with_explicit_mesh((2,), 'x') + def test_changing_abstract_device(self, mesh): + inp = jnp.arange(8) + abstract_mesh = mesh.abstract_mesh + + @jax.jit + def f(x): + return x + + with jtu.count_jit_tracing_cache_miss() as tracing_count: + f(inp) + with jax.sharding.use_abstract_mesh(abstract_mesh.update( + abstract_device=AbstractDevice('tpu', None))): # induces a cache miss + f(inp) + self.assertEqual(tracing_count(), 2) # twice for f + + @parameterized.named_parameters( + ('1', P('x', 'y'), P('x', None)), + ('2', P(('x', 'y')), P('x', None)), + ('3', P('y'), P(None, None)), + ) + @config.remove_size_one_mesh_axis_from_type(True) + @jtu.with_explicit_mesh((2, 1), ('x', 'y')) + def test_remove_size_one_mesh_axis(self, arr_spec, type_spec, mesh): + arr = jax.device_put(np.arange(16).reshape(8, 2), arr_spec) + + @jax.jit + def f(x): + self.assertEqual(x.aval.sharding.spec, type_spec) + out = x * 2 + self.assertEqual(out.aval.sharding.spec, type_spec) + # wsc should act as an assert. + out = with_sharding_constraint(out, arr_spec) + return out + + out = f(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, type_spec)) + self.assertArraysEqual(out, arr * 2) + + @jtu.with_explicit_mesh((2, 1), ('x', 'y')) + def test_typeof_not_mesh_context_dependent(self, mesh): + arr = jax.device_put(np.arange(16).reshape(8, 2), P('x', 'y')) + self.assertEqual(jax.sharding.get_abstract_mesh().axis_names, ('x', 'y')) + self.assertEqual(jax.typeof(arr).sharding.spec, P('x', 'y')) + + with jax.set_mesh(mesh.update(axis_names=('a', 'b'))): + self.assertEqual(jax.sharding.get_abstract_mesh().axis_names, ('a', 'b')) + # Even though the mesh changed axis_names, the type's spec didn't change. + self.assertEqual(jax.typeof(arr).sharding.spec, P('x', 'y')) + + # There needs to be an explicit cast via reshard to change the + # type's spec. + arr = jax.sharding.reshard(arr, P('a', 'b')) + self.assertEqual(jax.typeof(arr).sharding.spec, P('a', 'b')) + + @parameterized.named_parameters( + ('Ux', (2,), ('x',), P(None, 'x'), P('x', None), + P(None, None, unreduced={'x'}), (8, 8)), + ('Sx,Uy', (2, 2), ('x', 'y'), P('x', 'y'), P('y', None), + P('x', None, unreduced={'y'}), (4, 8)), + ('Rx,Uy', (2, 2), ('x', 'y'), P(None, 'y'), P('y', None), + P(None, None, unreduced={'y'}), (8, 8)), + ('Sx,Uy,Rz', (2, 2, 2), ('x', 'y', 'z'), P('x', 'y'), P('y', None), + P('x', None, unreduced={'y'}), (4, 8)), + ) + def test_unreduced_output_from_jit( + self, axis_sizes, axis_names, x_spec, y_spec, out_spec, shard_shape): + mesh = jtu.create_mesh(axis_sizes, axis_names, + axis_types=(AxisType.Explicit,) * len(axis_names)) + with jax.set_mesh(mesh): + np_inp = np.arange(16.).reshape(8, 2) + x = jax.device_put(np_inp, x_spec) + y = jax.device_put(np_inp.T, y_spec) + + @jax.jit + def f(x, y): + out = jnp.einsum('xy,yz->xz', x, y, out_sharding=out_spec) + self.assertEqual(out.aval.sharding.spec, out_spec) + return out + + out = f(x, y) + self.assertEqual(out.sharding, NamedSharding(mesh, out_spec)) + self.assertNotEmpty(out.sharding.spec.unreduced) + self.assertEqual(out.shape, (8, 8)) + self.assertEqual(out.sharding.shard_shape(out.shape), shard_shape) + for s in out.addressable_shards: + self.assertEqual(s.data.shape, shard_shape) + + expected_out = jnp.dot(x, y, out_sharding=P('x', None)) + + reshard_out = jax.sharding.reshard(out, P('x', None)) + self.assertEmpty(reshard_out.sharding.spec.unreduced) + self.assertEqual(reshard_out.sharding, NamedSharding(mesh, P('x', None))) + self.assertArraysEqual(reshard_out, expected_out) + + @parameterized.named_parameters( + ('Ux', (2,), ('x',), P(None, 'x'), P('x', None), + P(None, None, unreduced={'x'}), (8, 8)), + ('Sx,Uy', (2, 2), ('x', 'y'), P('x', 'y'), P('y', None), + P('x', None, unreduced={'y'}), (4, 8)), + ('Rx,Uy', (2, 2), ('x', 'y'), P(None, 'y'), P('y', None), + P(None, None, unreduced={'y'}), (8, 8)), + ('Sx,Uy,Rz', (2, 2, 2), ('x', 'y', 'z'), P('x', 'y'), P('y', None), + P('x', None, unreduced={'y'}), (4, 8)), + ) + def test_unreduced_input_to_jit( + self, axis_sizes, axis_names, x_spec, y_spec, out_spec, shard_shape): + mesh = jtu.create_mesh(axis_sizes, axis_names, + axis_types=(AxisType.Explicit,) * len(axis_names)) + with jax.set_mesh(mesh): + np_inp = np.arange(16.).reshape(8, 2) + x = jax.device_put(np_inp, x_spec) + y = jax.device_put(np_inp.T, y_spec) + + @jax.jit + def f(x, y): + out = jnp.einsum('xy,yz->xz', x, y, out_sharding=out_spec) + self.assertEqual(out.aval.sharding.spec, out_spec) + return out + + out = f(x, y) + self.assertNotEmpty(out.sharding.spec.unreduced) + + @jax.jit + def g(x): + self.assertEqual(x.aval.sharding.spec, out_spec) + y = x + x + return jax.sharding.reshard(y, P('x', None)) + + expected_out = jnp.dot(x, y, out_sharding=P('x', None)) * 2 + + out2 = g(out) + self.assertArraysEqual(out2, expected_out) + self.assertEqual(out2.sharding, NamedSharding(mesh, P('x', None))) + self.assertEmpty(out2.sharding.spec.unreduced) + + @parameterized.named_parameters( + ('Ux', (2,), ('x',), P(None, 'x'), P('x', None), + P(None, None, unreduced={'x'}), (8, 8)), + ('Sx,Uy', (2, 2), ('x', 'y'), P('x', 'y'), P('y', None), + P('x', None, unreduced={'y'}), (4, 8)), + ('Rx,Uy', (2, 2), ('x', 'y'), P(None, 'y'), P('y', None), + P(None, None, unreduced={'y'}), (8, 8)), + ('Sx,Uy,Rz', (2, 2, 2), ('x', 'y', 'z'), P('x', 'y'), P('y', None), + P('x', None, unreduced={'y'}), (4, 8)), + ) + def test_in_and_out_unreduced( + self, axis_sizes, axis_names, x_spec, y_spec, out_spec, shard_shape): + mesh = jtu.create_mesh(axis_sizes, axis_names, + axis_types=(AxisType.Explicit,) * len(axis_names)) + with jax.set_mesh(mesh): + np_inp = np.arange(16.).reshape(8, 2) + x = jax.device_put(np_inp, x_spec) + y = jax.device_put(np_inp.T, y_spec) + + @jax.jit + def f(x, y): + out = jnp.einsum('xy,yz->xz', x, y, out_sharding=out_spec) + self.assertEqual(out.aval.sharding.spec, out_spec) + return out + + out = f(x, y) + self.assertNotEmpty(out.sharding.spec.unreduced) + + @jax.jit + def g(x): + self.assertEqual(x.aval.sharding.spec, out_spec) + y = x + x + self.assertEqual(y.aval.sharding.spec, out_spec) + return y + + out2 = g(out) + self.assertEqual(out2.sharding, NamedSharding(mesh, out_spec)) + self.assertNotEmpty(out2.sharding.spec.unreduced) + self.assertEqual(out2.shape, (8, 8)) + self.assertEqual(out2.sharding.shard_shape(out2.shape), shard_shape) + for s in out2.addressable_shards: + self.assertEqual(s.data.shape, shard_shape) + + expected_out = jnp.dot(x, y, out_sharding=P('x', None)) * 2 + + reshard_out2 = jax.sharding.reshard(out2, P('x', None)) + self.assertArraysEqual(reshard_out2, expected_out) + self.assertEqual(reshard_out2.sharding, NamedSharding(mesh, P('x', None))) + self.assertEmpty(reshard_out2.sharding.spec.unreduced) + + @jtu.with_explicit_mesh((2,), 'x') + def test_grad_conv_general_dilated(self, mesh): + @jax.jit + def model(kernel, inputs): + padded_inputs = jnp.pad(inputs, ((0, 0), (2, 0), (0, 0))) + return lax.conv_general_dilated( + padded_inputs, + kernel, + window_strides=(1,), + padding="VALID", + dimension_numbers=lax.ConvDimensionNumbers( + (0, 2, 1), (2, 1, 0), (0, 2, 1))) + + batch_size = 64 + seq_length = 32 + kernel = jax.random.normal(jax.random.key(0), (3, 1, 16)) + inputs = jax.random.normal(jax.random.key(1), (batch_size, seq_length, 1)) + inputs = jax.device_put(inputs, P('x', None, None)) + + out1, out2 = jax.jit(jax.grad(lambda x, y: model(x, y).sum(), argnums=(0, 1)) + )(kernel, inputs) + self.assertEqual(out1.sharding, kernel.sharding) + self.assertEqual(out2.sharding, inputs.sharding) + + @jtu.with_explicit_mesh((2, 2,), ('x', 'y')) + def test_vmap_conv_general_dilated(self, mesh): + @jax.jit + def model(kernel, inputs): + padded_inputs = jnp.pad(inputs, ((0, 0), (2, 0), (0, 0))) + return lax.conv_general_dilated( + padded_inputs, + kernel, + window_strides=(1,), + padding="VALID", + dimension_numbers=lax.ConvDimensionNumbers( + (0, 2, 1), (2, 1, 0), (0, 2, 1)), + out_sharding=P(None, None, None)) + + batch_size = 64 + seq_length = 32 + vmap_size = 16 + kernel = jax.random.normal( + jax.random.key(0), (3, 1, 16)) + inputs = jax.random.normal( + jax.random.key(1), (batch_size, vmap_size, seq_length, 1)) + inputs = jax.device_put(inputs, P('x', None, None, None)) + kernel = jax.device_put(kernel, P(None, None, None)) + + out = jax.jit( + jax.vmap(model, in_axes=(None, 0), out_axes=1))(kernel, inputs) + + self.assertEqual( + out.sharding, NamedSharding(mesh, P(None, 'x', None, None))) + + @jtu.with_explicit_mesh((2,), 'x') + def test_device_put_typeof(self, mesh): + array = jnp.zeros(8) + self.assertEqual(jax.typeof(array).sharding, + NamedSharding(mesh.abstract_mesh, P(None))) + + array = jax.device_put(array, SingleDeviceSharding(jax.devices()[0])) + self.assertTrue(jax.typeof(array).sharding.mesh.empty) + + @jtu.with_explicit_mesh((2,), 'x') + def test_indices_sharded_gather_error(self, mesh): + arr = jax.device_put(jnp.arange(8, dtype=jnp.float32).reshape(4, 2), + P(None, None)) + indices = jax.device_put(jnp.array([0, 2]), P('x')) + + @jax.jit + def f(arr, ids): + out = arr[ids, :] + return jnp.sum(out) + + with self.assertRaises(core.ShardingTypeError): + f(arr, indices) + + @config.numpy_dtype_promotion('standard') + @jtu.with_explicit_mesh((2,), 'x') + def test_vmap_vjp_complex_explicit_mode(self, mesh): + @jax.jit + def f(w, samples): + def f(w, x): + y = jnp.dot(x, w) + return jnp.sum(jnp.log1p(jnp.exp(-2 * y))) + 1j * jnp.sum(y) + + def vjp_step(x): + r, vjp_fn = jax.vjp(lambda w: f(w, x), w) + return vjp_fn(jnp.ones_like(r))[0] + + return jax.vmap(vjp_step)(samples) + + samples = reshard(jax.random.normal(jax.random.key(1), (4, 4)), P("x")) + w = jax.random.normal(jax.random.key(1), (4, 8)) + f(w, samples) # doesn't crash + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_unreduced_einsum_lowers_to_reduce_sum(self, mesh): + arr = jax.device_put(jnp.arange(8).reshape(4, 2), P('x', None)) + + @jax.jit + def f(x): + out = jnp.einsum("ab->b", x, out_sharding=P(None, unreduced={'x'})) + self.assertEqual(out.aval.sharding.spec, P(None, unreduced={'x'})) + return out + + compiled_text = f.lower(arr).compile().as_text() + self.assertNotIn('all-reduce', compiled_text) + + out = f(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P(None, unreduced={'x'}))) + for s in out.addressable_shards: + self.assertEqual(s.data.shape, (2,)) + + reshard_out = jax.sharding.reshard(out, P(None)) + self.assertArraysEqual(reshard_out, jnp.sum(arr, axis=0)) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_reduce_sum_unreduced(self, mesh): + np_inp = np.arange(16).reshape(4, 2, 2) + arr = jax.device_put(np_inp, P('x', 'y', None)) + + @jax.jit + def f(x): + out = jax.lax.reduce_sum( + x, axes=(0, 1), out_sharding=P(None, unreduced={'x'})) + self.assertEqual(out.aval.sharding.spec, P(None, unreduced={'x'})) + return out + + out = f(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P(None, unreduced={'x'}))) + for s in out.addressable_shards: + self.assertEqual(s.data.shape, (2,)) + + reshard_out = jax.sharding.reshard(out, P(None)) + self.assertArraysEqual(reshard_out, jnp.sum(arr, axis=(0, 1))) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_reduce_sum_unreduced_error(self, mesh): + # Case 1 + arr2 = jax.device_put(np.arange(16).reshape(8, 2), P('y', None)) + + @jax.jit + def g(x): + return jax.lax.reduce_sum( + x, axes=(0,), out_sharding=P(None, unreduced={'x'})) + + with self.assertRaisesRegex( + core.ShardingTypeError, + "out_sharding's unreduced axes should be in operand's specs that were" + ' summed over'): + g(arr2) + + # Case 2 + @jax.jit + def f(x): + out = jax.lax.reduce_sum( + x, axes=(0,), out_sharding=P(None, unreduced={'y'})) + return jax.lax.reduce_sum(out, axes=(0,)) + + with self.assertRaisesRegex( + core.ShardingTypeError, + "operand passed to reduce_sum cannot be unreduced"): + f(arr2) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_unreduced_multi_axes_reduce_sum(self, mesh): + x = jax.device_put(np.arange(16.).reshape(8, 2), P(('x', 'y'), None)) + + @jax.jit + def f(x): + out = jax.lax.reduce_sum(x, axes=(0,), + out_sharding=P(None, unreduced={'x', 'y'})) + self.assertEqual(out.aval.sharding.spec.unreduced, {'x', 'y'}) + return out + + out = f(x) + self.assertEqual(out.sharding, + NamedSharding(mesh, P(None, unreduced={'x', 'y'}))) + + reshard_out = reshard(out, P(None)) + self.assertArraysEqual(reshard_out, jnp.sum(x, axis=0)) + + @jtu.with_explicit_mesh((2,), 'x') + def test_reduce_sum_scalar_unreduced(self, mesh): + x = jax.device_put(np.arange(8, dtype=np.float32), P('x')) + + @jax.jit + def f(x): + out = jax.lax.reduce_sum(x, axes=(0,), + out_sharding=P(unreduced={'x'})) + self.assertEqual(out.aval.sharding.spec.unreduced, {'x'}) + return out + + lowered_text = f.lower(x).as_text() + self.assertEqual(lowered_text.count('unreduced={"x"}'), 2) + + out = f(x) + self.assertEqual(out.sharding, NamedSharding(mesh, P(unreduced={'x'}))) + + expected_shards = [np.array(6, dtype=np.float32), + np.array(22, dtype=np.float32)] + for s, expected_shard in zip(out.addressable_shards, expected_shards): + self.assertArraysEqual(s.data, expected_shard) - @jtu.with_user_mesh((2,), ('x',)) - def test_auto_axes_late_bind(self, mesh): - @auto_axes + reshard_out = reshard(out, P()) + self.assertArraysEqual(reshard_out, jnp.sum(x)) + + out = jax.jit(jax.grad(f))(x) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + + @parameterized.named_parameters( + ('custom_vjp', True), + ('grad', False), + ) + @jtu.with_explicit_mesh((2,), 'x') + def test_minibatch_scan_unreduced(self, use_custom_vjp, mesh): + def assert_unreduced(tup): + for val in tup: + self.assertEqual(val.aval.sharding.spec.unreduced, {'x'}) + + if use_custom_vjp: + @jax.custom_vjp + def f(xs, w): + return jnp.dot(xs, w) + + def f_fwd(xs, w): + return f(xs, w), (xs, w) + + def f_bwd(res, g): + xs, w = res + return jnp.dot(g, w), jnp.dot(xs.T, g, out_sharding=P(unreduced={'x'})) + f.defvjp(f_fwd, f_bwd) + else: + def f(xs, w): + return jnp.dot(xs, w) + + def model(ws, xs_mubatch): + for w in ws: + xs_mubatch = f(xs_mubatch, w) + return jnp.sum(xs_mubatch) + + @jax.jit(donate_argnums=(0,)) + def step(ws, xs): + def mubatch_loop_body(grad_acc, xs_mubatch): + grad = jax.grad(model)(ws, xs_mubatch) + assert_unreduced(grad) + assert_unreduced(grad_acc) + grad_acc = jax.tree.map(jnp.add, grad_acc, grad) + assert_unreduced(grad_acc) + return grad_acc, None + + grad_acc = jax.tree.map(jnp.zeros_like, ws) + grad_acc = reshard(grad_acc, P(unreduced={'x'})) + grad_acc, _ = jax.lax.scan(mubatch_loop_body, grad_acc, xs) + assert_unreduced(grad_acc) + # AR once for a batch + grad_acc = reshard(grad_acc, P()) + ws = reshard(ws, P()) + return jax.tree.map(lambda W, g: W - g * 0.01, ws, grad_acc) + + ws = tuple(jax.device_put(jnp.ones((4, 4)), P(reduced={'x'})) + for _ in range(4)) + xs = jax.device_put(jnp.ones((2, 2, 4)), P(None, 'x', None)) + + step(ws, xs) # doesn't crash + + compiled_text = step.lower(ws, xs).compile().as_text() + if jtu.test_device_matches(['gpu']): + self.assertEqual(compiled_text.count('all-reduce-start('), 1) + self.assertEqual(compiled_text.count('all-reduce-done('), 1) + else: + self.assertEqual(compiled_text.count('all-reduce('), 1) + + @jtu.with_explicit_mesh((2,), 'x') + def test_vmap_mapped_input_sharding_error(self, mesh): + np_inp = np.arange(8).reshape(4, 2) + arr1 = jax.device_put(np_inp, P('x')) + arr2 = jax.device_put(np_inp, P()) + + @jax.jit + @jax.vmap + def f(x, y): + return x, y + + with self.assertRaisesRegex( + ValueError, + 'Mapped away dimension of inputs passed to vmap should be sharded the' + ' same'): + f(arr1, arr2) + + @parameterized.named_parameters( + ('custom_vjp', True), + ('grad', False), + ) + @config.numpy_dtype_promotion('standard') + @jtu.with_explicit_mesh((2,), 'x') + def test_scan_over_layers_minibatch_unreduced(self, use_custom_vjp, mesh): + def assert_unreduced(val): + self.assertEqual(val.aval.sharding.spec.unreduced, {'x'}) + + if use_custom_vjp: + @jax.custom_vjp + def f(xs, w): + return jnp.dot(xs.astype(jnp.bfloat16), w.astype(jnp.bfloat16)) + + def f_fwd(xs, w): + return f(xs, w), (xs, w) + + def f_bwd(res, g): + xs, w = res + return (jnp.dot(g.astype(jnp.bfloat16), w.T.astype(jnp.bfloat16)), + jnp.dot(xs.T.astype(jnp.bfloat16), g.astype(jnp.bfloat16), + out_sharding=P(unreduced={'x'}))) + f.defvjp(f_fwd, f_bwd) + else: + def f(xs, w): + return jnp.dot(xs.astype(jnp.bfloat16), w.astype(jnp.bfloat16)) + + def model(stacked_ws, xs_mubatch): + def scan_over_layers(carry_xs, w): + return f(carry_xs, w), None + final_xs, _ = jax.lax.scan(scan_over_layers, xs_mubatch, stacked_ws) + return jnp.sum(final_xs) + + @jax.jit(donate_argnums=(0,)) + def step(stacked_ws, xs): + def mubatch_loop_body(stacked_grad_acc, xs_mubatch): + grad = jax.grad(model)(stacked_ws, xs_mubatch) + assert_unreduced(grad) + assert_unreduced(stacked_grad_acc) + stacked_grad_acc = jax.tree.map(jnp.add, stacked_grad_acc, grad) + assert_unreduced(stacked_grad_acc) + return stacked_grad_acc, None + + stacked_grad_acc = jax.tree.map(jnp.zeros_like, stacked_ws) + stacked_grad_acc = reshard(stacked_grad_acc, P(unreduced={'x'})) + stacked_grad_acc, _ = jax.lax.scan( + mubatch_loop_body, stacked_grad_acc, xs) + assert_unreduced(stacked_grad_acc) + # AR once for a batch + stacked_grad_acc = reshard(stacked_grad_acc, P()) + stacked_ws = reshard(stacked_ws, P()) + return jax.tree.map( + lambda W, g: W - g * 0.01, stacked_ws, stacked_grad_acc) + + ws = tuple(jax.device_put(jnp.ones((4, 4), dtype=jnp.float32), + P(reduced={'x'})) + for _ in range(4)) + + xs = jax.device_put(jnp.ones((2, 4, 4), dtype=jnp.bfloat16), + P(None, 'x', None)) + stacked_ws = jnp.stack(ws, axis=0) + step(stacked_ws, xs) # doesn't crash + + compiled_text = step.lower(stacked_ws, xs).compile().as_text() + if jtu.test_device_matches(['gpu']): + self.assertEqual(compiled_text.count('all-reduce-start('), 1) + self.assertEqual(compiled_text.count('all-reduce-done('), 1) + else: + self.assertEqual(compiled_text.count('all-reduce('), 1) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_jacrev_sharded_broadcast(self, mesh): + np_inp = np.arange(16, dtype=np.float32) + arr = jax.device_put(np_inp, NamedSharding(mesh, P('y'))) + + @jax.jit + def broadcast_sharded(x): + xs = jax.typeof(x).sharding.spec + return lax.broadcast(x, sizes=[2], out_sharding=P('x', *xs)) + + out = jax.jit(jax.jacrev(broadcast_sharded))(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P(None, None, 'y'))) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_jacfwd_sharded_broadcast(self, mesh): + np_inp = np.arange(16, dtype=np.float32) + arr = jax.device_put(np_inp, NamedSharding(mesh, P('y'))) + + @jax.jit + def broadcast_sharded(x): + xs = jax.typeof(x).sharding.spec + return lax.broadcast(x, sizes=[2], out_sharding=P('x', *xs)) + + out = jax.jit(jax.jacfwd(broadcast_sharded))(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y', None))) + + @jtu.with_explicit_mesh((2, 4), ('x', 'y')) + def test_vmap_grad_sharded_gather(self, mesh): + tokens = jax.device_put( + np.arange(32 * 256, dtype=np.float32).reshape(32, 256), P('x', None)) + ids = jax.device_put( + np.arange(32 * 16, dtype=np.int32).reshape(32, 16), P('x', None)) + + def sum_gather(t, i): + return jnp.mean(t.at[i].get(out_sharding=P('y'))) + + @jax.jit + def f(t, i): + return jax.vmap(jax.value_and_grad(sum_gather))(t, i) + + value, grad = f(tokens, ids) + self.assertEqual(value.shape, (32,)) + self.assertEqual(value.sharding, NamedSharding(mesh, P('x'))) + self.assertEqual(grad.shape, tokens.shape) + self.assertEqual(grad.sharding, tokens.sharding) + + @jtu.with_explicit_mesh((2,), 'x') + def test_auto_axes_inside_scan(self, mesh): + def f(params, x): + @jax.sharding.auto_axes(out_sharding=jax.P()) + def model(x): + return params * x + + def body(c, _): + return model(c), () + + y = jax.lax.scan(body, x, None, length=3)[0] + return y + + jax.jit(jax.grad(f))(1., 2.) # doesn't crash + + @jtu.with_explicit_mesh((2,), ('x',)) + def test_reduced_sin_fwd_mul_bwd(self, mesh): + + np_inp1 = np.arange(8.).reshape(4, 2) + np_inp2 = np.arange(16.).reshape(2, 8) + arr1 = jax.device_put(np_inp1, P(reduced={'x'})) + arr2 = jax.device_put(np_inp2, P(None, 'x')) + + @jax.jit + def f(x, y): + x_ = jnp.sin(x) + y_ = jnp.sin(y) + z = x_ @ y_ + return z.sum() + + f(arr1, arr2) # doesn't crash + + out1, out2 = jax.jit(jax.grad(f, argnums=(0, 1)))(arr1, arr2) + self.assertEqual(out1.sharding, + NamedSharding(mesh, P(None, None, unreduced={'x'}))) + self.assertEqual(out2.sharding, NamedSharding(mesh, P(None, 'x'))) + + with jax.set_mesh(jtu.create_mesh((1,), 'x')): + ex_out1, ex_out2 = jax.jit(jax.grad(f, argnums=(0, 1)))(np_inp1, np_inp2) + + self.assertArraysAllClose(ex_out1, reshard(out1, P()), rtol=2e-4) + self.assertArraysAllClose(ex_out2, out2, rtol=2e-4) + + @jtu.with_explicit_mesh((2,), 'x') + def test_mul_reduced_error(self, mesh): + arr1 = jax.device_put(np.arange(8.), P(reduced={'x'})) + arr2 = jax.device_put(np.arange(8.), P('x')) + + @jax.jit + def f(x, y): + return x * y + + with self.assertRaisesRegex( + core.ShardingTypeError, + "Inputs cannot be sharded on the same axes that another input is " + "reduced on"): + f(arr1, arr2) + + with self.assertRaisesRegex( + core.ShardingTypeError, + "Inputs cannot be sharded on the same axes that another input is " + "reduced on"): + f(arr2, arr1) + + @jtu.with_explicit_mesh((2,), 'x') + def test_jnp_pad_reflect(self, mesh): + np_inp = np.arange(2 * 1024).reshape(2, 16, 16, -1) + arr = reshard(np_inp, P('x', None)) + + padded_arr = jnp.pad(arr, [(0, 0), (3, 3), (3, 3), (0, 0)], mode='reflect') + self.assertEqual(padded_arr.sharding, + NamedSharding(mesh, P('x', None, None, None))) + + @jtu.with_explicit_mesh((2,), 'x') + def test_jnp_repeat_arraylike(self, mesh): + positions = jnp.zeros((5, 3)) + charges = jnp.ones((5,), dtype=jnp.int32) + num_electrons = 5 + jnp.repeat(positions, charges, axis=0, total_repeat_length=num_electrons, + out_sharding=P()) # doesn't crash + jnp.repeat(positions, 5, axis=0, total_repeat_length=num_electrons, + out_sharding=P()) # doesn't crash + + @parameterized.named_parameters( + ('mul', jax.lax.mul), + ('add', jax.lax.add), + ) + @jtu.with_explicit_mesh((2,), 'x') + def test_both_inputs_reduced(self, func, mesh): + if ifrt_version < 46: + self.skipTest('Requires ifrt_version >= 46') + if not jtu.is_cloud_tpu_at_least(2025, 12, 22): + self.skipTest('Requires libtpu built after 2025-12-22') + arr1 = jax.device_put(np.arange(8.), P(reduced={'x'})) + arr2 = jax.device_put(np.arange(8.), P(reduced={'x'})) + + @jax.jit + def f(x, y): + z = func(x, y) + return z.sum() + + out1, out2 = jax.jit(jax.grad(f, argnums=(0, 1)))(arr1, arr2) + self.assertEqual(out1.sharding, + NamedSharding(mesh, P(None, unreduced={'x'}))) + self.assertEqual(out2.sharding, + NamedSharding(mesh, P(None, unreduced={'x'}))) + + arr3 = jax.device_put(np.arange(8.), P()) + arr4 = jax.device_put(np.arange(8.), P()) + ex_out1, ex_out2 = jax.jit(jax.grad(f, argnums=(0, 1)))(arr3, arr4) + self.assertArraysEqual(reshard(out1, P()), ex_out1) + self.assertArraysEqual(reshard(out2, P()), ex_out2) + + @parameterized.named_parameters( + ('mul', jax.lax.mul), + ('add', jax.lax.add), + ) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_one_input_reduced_another_replicated(self, func, mesh): + arr1 = jax.device_put(np.arange(8.).reshape(4, 2), P('x', reduced={'y'})) + arr2 = jax.device_put(np.arange(8.).reshape(4, 2), P('x', None)) + + def f(x, y): + z = func(x, y) + return z.sum() + + with self.assertRaisesRegex( + core.ShardingTypeError, + "Inputs cannot be replicated on the same axes that another input is " + "reduced on"): + jax.jit(f)(arr1, arr2) + + with self.assertRaisesRegex( + core.ShardingTypeError, + "Inputs cannot be replicated on the same axes that another input is " + "reduced on"): + jax.jit(jax.shard_map(f, out_specs=P()))(arr1, arr2) + + @parameterized.parameters( + ((8,), P('x'), P(None, unreduced={'x'})), + ((4, 2), P('x', 'y'), P(None, None, unreduced={'x', 'y'})), + ((4, 2), P('x', 'y'), P('x', None, unreduced={'y'})), + ((4, 2), P('x', 'y'), P(None, 'y', unreduced={'x'})), + ((4, 2), P('x', None), P(None, None, unreduced={'x'})), + ((4, 2), P('y', None), P(None, None, unreduced={'y'})), + ((4, 2), P(('x', 'y'), None), P(None, None, unreduced={'x', 'y'})), + ((4, 4), P(None, ('x', 'y')), P(None, None, unreduced={'x', 'y'})), + # TODO(yashkatariya): Enable this after collectives + S->U cast is enabled. + # ((4, 2), P('x', 'y'), P(None, 'x', unreduced={'y'})), + # ((4, 2), P('x', 'y'), P('y', None, unreduced={'x'})), + ) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_sharded_unreduced_roundtrip(self, shape, orig_spec, un_spec, mesh): + np1 = np.arange(math.prod(shape)).reshape(shape) + arr = jax.device_put(np1, orig_spec) + + arr2 = reshard(arr, un_spec) + self.assertEqual(arr2.sharding, NamedSharding(mesh, un_spec)) + + arr3 = reshard(arr2, orig_spec) + self.assertArraysEqual(arr, arr3) + self.assertEqual(arr.sharding, arr3.sharding) + + @jtu.with_explicit_mesh((2,), ('x',)) + def test_scalar_to_unreduced(self, mesh): + if ifrt_version < 46: + self.skipTest('Requires ifrt_version >= 46') + if not jtu.is_cloud_tpu_at_least(2025, 12, 22): + self.skipTest('Requires libtpu built after 2025-12-22') + inp = jnp.array(1) + for s in inp.addressable_shards: + self.assertArraysEqual(s.data, inp) + + out = reshard(inp, P(unreduced={'x'})) + expected_out = [inp, jnp.array(0)] + for s, ex_out in zip(out.addressable_shards, expected_out): + self.assertArraysEqual(s.data, ex_out) + + out2 = reshard(out, P()) + for s, inp_s in zip(out2.addressable_shards, inp.addressable_shards): + self.assertArraysEqual(s.data, inp_s.data) + + @parameterized.parameters( + ((4,), P(None), P(None, unreduced={'x'})), + ((4,), P(None), P(None, unreduced={'y'})), + ((4,), P(None), P(None, unreduced={'x', 'y'})), + ((4, 2), P(None, None), P(None, None, unreduced={'x'})), + ((4, 2), P(None, None), P(None, None, unreduced={'y'})), + ((4, 2), P(None, None), P(None, None, unreduced={'x', 'y'})), + ((4, 4), P('x', None), P('x', None, unreduced={'y'})), + ((4, 4), P(None, 'y'), P(None, 'y', unreduced={'x'})), + ((4, 4), P('x', None), P(None, None, unreduced={'x', 'y'})), + ((4, 4), P('y', None), P(None, None, unreduced={'x', 'y'})), + ((4, 4), P('x', 'y'), P(None, None, unreduced={'x', 'y', 'z'})), + ((4, 4), P('x', 'z'), P(None, None, unreduced={'x', 'y', 'z'})), + ((4, 4), P(('x', 'z'), 'y'), P(None, None, unreduced={'x', 'y', 'z'})), + ((4, 4), P(('x', 'z'), 'y'), P('x', None, unreduced={'y', 'z'})), + ((4, 4), P('z', 'y'), P(None, None, unreduced={'y', 'z'})), + ((4, 4), P('z', 'y'), P(None, None, unreduced={'x', 'y', 'z'})), + ) + @jtu.with_explicit_mesh((2, 2, 2), ('x', 'y', 'z')) + def test_replicated_sharded_unreduced_roundtrip( + self, shape, orig_spec, un_spec, mesh): + if ifrt_version < 46: + self.skipTest('Requires ifrt_version >= 46') + if not jtu.is_cloud_tpu_at_least(2025, 12, 22): + self.skipTest('Requires libtpu built after 2025-12-22') + np1 = np.arange(math.prod(shape)).reshape(shape) + arr = jax.device_put(np1, orig_spec) + + arr2 = reshard(arr, un_spec) + self.assertEqual(arr2.sharding, NamedSharding(mesh, un_spec)) + + arr3 = reshard(arr2, orig_spec) + self.assertArraysEqual(arr, arr3) + self.assertEqual(arr.sharding, arr3.sharding) + + @parameterized.named_parameters( + ('mul', jax.lax.mul), + ('add', jax.lax.add), + ) + @jtu.with_explicit_mesh((2,), 'x') + def test_one_input_sharded_another_reduced(self, func, mesh): + np1 = np.arange(8.) + arr1 = jax.device_put(np1, P('x')) + arr2 = jax.device_put(np1, P(None, reduced={'x'})) + + @jax.jit + def f(x, y): + y_ = reshard(y, P('x')) + z = func(x, y_) + return z + + out = f(arr1, arr2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + self.assertArraysEqual(out, func(np1, np1)) + + @jax.jit + def g(x, y): + return f(x, y).sum() + + out1, out2 = jax.jit(jax.grad(g, argnums=(0, 1)))(arr1, arr2) + self.assertEqual(out1.sharding, NamedSharding(mesh, P('x'))) + self.assertEqual(out2.sharding, + NamedSharding(mesh, P(None, unreduced={'x'}))) + + arr3 = jax.device_put(np1, P(None)) + ex_out1, ex_out2 = jax.jit(jax.grad(g, argnums=(0, 1)))(arr1, arr3) + self.assertArraysEqual(reshard(out1, P()), ex_out1) + self.assertArraysEqual(reshard(out2, P()), ex_out2) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_reduced_reshard_unreduced_bwd(self, mesh): + np1 = np.arange(4.) + arr = jax.device_put(np1, P(None, reduced={'x'})) + + @jax.jit + def f(x): + return jax.reshard(x, P('x')) + + out = f(arr) + self.assertArraysEqual(out, np1) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + + @jax.jit + def g(x): + return f(x).sum() + + out = jax.jit(jax.grad(g))(arr) + ex_data = [np.array([1., 1., 0., 0.]), np.array([1., 1., 0., 0.]), + np.array([0., 0., 1., 1.]), np.array([0., 0., 1., 1.])] + for s, d in zip(out.addressable_shards, ex_data): + self.assertArraysEqual(s.data, d) + + arr2 = jax.device_put(np1, P(None)) + expected_out = jax.jit(jax.grad(g))(arr2) + self.assertArraysEqual(reshard(out, P()), expected_out) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_reduced_reshard_unreduced_bwd_sharded(self, mesh): + np1 = np.arange(8.).reshape(4, 2) + arr = jax.device_put(np1, P('x', None, reduced={'y'})) + + @jax.jit + def f(x): + return jax.reshard(x, P('x', 'y')) + + out = f(arr) + self.assertArraysEqual(out, np1) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + + @jax.jit + def g(x): + return f(x).sum() + + out = jax.jit(jax.grad(g))(arr) + self.assertEqual(out.sharding, + NamedSharding(mesh, P('x', None, unreduced={'y'}))) + + arr2 = jax.device_put(np1, P('x', None)) + expected_out = jax.jit(jax.grad(g))(arr2) + self.assertEqual(expected_out.sharding, NamedSharding(mesh, P('x', None))) + + self.assertArraysEqual(reshard(out, P('x', None)), expected_out) + self.assertArraysEqual(reshard(out, P()), reshard(expected_out, P())) + + @jtu.with_explicit_mesh((2,), 'x') + def test_reduced_at_get_out_sharding(self, mesh): + np1 = np.ones((2048, 64), dtype=jnp.float32) + np2 = np.ones((4, 128), dtype=jnp.int32) + params = jax.device_put(np1, P(None, None, reduced={'x'})) + inputs = jax.device_put(np2, P('x', None)) + + @jax.jit + def f(params, inputs): + return params.at[inputs].get(out_sharding=P('x', None, None)) + + out = f(params, inputs) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None, None))) + self.assertArraysEqual(out, np1[np2]) + + @jax.jit + def g(params, inputs): + return f(params, inputs).sum() + + out = jax.jit(jax.grad(g))(params, inputs) + self.assertEqual(out.sharding, + NamedSharding(mesh, P(None, None, unreduced={'x'}))) + + @jtu.with_explicit_mesh((2,), 'x') + def test_out_aval_matches_out_sharding(self, mesh): + arr = jnp.arange(8) + + @jax.jit(out_shardings=P('x')) def f(x): return x * 2 - out = f(np.arange(8), out_shardings=P('x')) + out = f(arr) self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) - self.assertArraysEqual(out, np.arange(8) * 2) + self.assertEqual(out.aval.sharding, NamedSharding(mesh.abstract_mesh, P('x'))) - @jtu.with_user_mesh((2,), ('x',), axis_types=AxisType.Auto) - def test_explicit_axes_late_bind(self, mesh): - @explicit_axes + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_out_aval_matches_out_sharding_override(self, mesh): + arr = jax.device_put(jnp.arange(8).reshape(4, 2), P('x', None)) + + @jax.jit(out_shardings=P('x', 'y')) + def f(x): + return x * 2 + + out = f(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + self.assertEqual(out.aval.sharding, + NamedSharding(mesh.abstract_mesh, P('x', 'y'))) + + @jax.jit + def g(x): + return x * 2 + + out = g(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + self.assertEqual(out.aval.sharding, + NamedSharding(mesh.abstract_mesh, P('x', None))) + + @jtu.with_explicit_mesh((2,), 'x', axis_types=(AxisType.Auto,)) + def test_out_aval_auto_mode(self, mesh): + arr = jax.device_put(jnp.arange(8).reshape(4, 2), P('x')) + + @jax.jit + def f(x): + return x * 2 + + out = f(arr) + self.assertEqual(out.aval.sharding.mesh, mesh.abstract_mesh) + + @jtu.with_explicit_mesh((2,), 'x') + def test_in_aval_matches_in_sharding(self, mesh): + arr = np.arange(8) + + @jax.jit(in_shardings=P('x'), out_shardings=P('x')) def f(x): return x * 2 - out = f(np.arange(8), in_shardings=P('x')) + lowered = f.lower(arr) + l_in_aval = lowered.in_avals[0][0] + self.assertEqual(l_in_aval.sharding, + NamedSharding(mesh.abstract_mesh, P('x'))) + + compiled = lowered.compile() + c_in_aval = compiled.in_avals[0][0] + self.assertEqual(c_in_aval.sharding, + NamedSharding(mesh.abstract_mesh, P('x'))) + compiled(arr) # doesn't crash + + @jtu.with_explicit_mesh((2,), 'x') + def test_c64_to_f32_view_rountrip(self, mesh): + x = jnp.zeros((128, 64), dtype=jnp.complex64, out_sharding=P(('x'))) + y = jax.jit(lambda _x: _x.view(jnp.float32))(x) + self.assertEqual(y.sharding, NamedSharding(mesh, P('x', None))) + + x = jnp.zeros((128, 64), dtype=jnp.float32, out_sharding=P(('x'))) + y = jax.jit(lambda _x: _x.view(jnp.complex64))(x) + self.assertEqual(y.sharding, NamedSharding(mesh, P('x', None))) + + @jtu.with_explicit_mesh((2,), 'x') + def test_jnp_ones_mesh_ctx_aval(self, mesh): + @jax.jit + def f(): + out = jnp.ones((2,)) + self.assertEqual(out.aval.sharding.mesh, mesh.abstract_mesh) + self.assertEqual(out.aval.sharding.spec, P(None)) + return out + + self.assertEqual(f().sharding, NamedSharding(mesh, P(None))) + + def test_reshard_no_mesh_ctx(self): + mesh = jtu.create_mesh((2,), 'x') + with self.assertRaisesRegex( + ValueError, "cannot contain axis names that are of type Auto"): + reshard(np.arange(8), NamedSharding(mesh, P('x'))) + + mesh = jtu.create_mesh((2,), 'x', axis_types=(AxisType.Explicit,)) + out = reshard(np.arange(8), NamedSharding(mesh, P('x'))) self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) - self.assertArraysEqual(out, np.arange(8) * 2) + self.assertArraysEqual(out, np.arange(8)) + + @jtu.with_explicit_mesh((2,), 'x') + def test_sub_custom_jvp(self, mesh): + np1 = np.arange(2 * 1024, dtype=np.float32).reshape(4, 16, 16, -1) + arr1 = jax.device_put(np1, P("x", None)) + arr2 = jax.device_put(np1, P("x", None)) + + def f(logits, labels): + labels = jnp.astype(labels, logits.dtype) + log_p = jax.nn.log_sigmoid(logits) + log_not_p = jax.nn.log_sigmoid(-logits) + return -labels * log_p - (1.0 - labels) * log_not_p + + @jax.jit + def g(pl, tl): + x = pl - jnp.mean(tl) + y = tl - jnp.mean(pl) + loss_on_fake = jnp.mean(f(x, jnp.zeros_like(pl))) + loss_on_real = jnp.mean(f(y, jnp.ones_like(tl))) + disc_loss = loss_on_fake + loss_on_real + gen_loss = jnp.mean(f(x, jnp.ones_like(pl))) + gen_loss += jnp.mean(f(y, jnp.zeros_like(tl))) + return disc_loss, gen_loss + + jax.jit(jax.grad(lambda t1, t2: g(t1, t2)[0]))(arr1, arr2) # doesn't crash + + @jtu.with_explicit_mesh((2,), 'x') + def test_reshard_zero_cotangent(self, mesh): + @jax.custom_vjp + def f(x): + return x + + def f_fwd(x): + return x, None + + def f_bwd(res, g): + return None, + + f.defvjp(f_fwd, f_bwd) + + def g(x): + x = reshard(x, P('x')) + return f(x) + + jax.jit(jax.grad(lambda x: g(x).sum()))(jnp.arange(8.)) # doesn't crash + + @jtu.with_explicit_mesh((4, 2), ('x', 'y')) + def test_tile(self, mesh): + @jax.jit + def tile(x): + return jnp.tile(x, (2, 3)) + + x = jax.device_put(np.ones((32, 64)), P('x', 'y')) + out = tile(x) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + self.check_wsc_in_lowered(tile.lower(x).as_text()) + + x = jax.device_put(np.ones((32, 64)), P('x', None)) + out = tile(x) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + self.check_wsc_in_lowered(tile.lower(x).as_text()) + + x = jax.device_put(np.ones((32, 64)), P(None, 'y')) + out = tile(x) + self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'y'))) + self.check_wsc_in_lowered(tile.lower(x).as_text()) + + x = jax.device_put(np.ones((32, 64)), P(None, ('x', 'y'))) + out = tile(x) + self.assertEqual(out.sharding, NamedSharding(mesh, P(None, ('x', 'y')))) + self.check_wsc_in_lowered(tile.lower(x).as_text()) @jtu.pytest_mark_if_available('multiaccelerator') @@ -7181,6 +10116,7 @@ def testNonDivisibleArgs(self, mesh, resources): with self.assertRaisesRegex(ValueError, error): pjit(lambda x: x, in_shardings=spec, out_shardings=None)(x) + @unittest.skip("regressed") # TODO(mattjj): fix test @check_1d_2d_mesh(set_mesh=True) def testNonDivisibleOuts(self, mesh, resources): x = jnp.ones((3, 2)) @@ -7296,9 +10232,7 @@ def testEmptyMesh(self): self.assertEqual(out.sharding, SingleDeviceSharding(jax.devices()[0])) def test_pspec_to_wsc_without_mesh(self): - error = ( - r'with_sharding_constraint requires a non-empty mesh if you are ' - r'passing `PartitionSpec`s or `None` to shardings.*') + error = r'with_sharding_constraint requires a non-empty mesh in context.*' with self.assertRaisesRegex(RuntimeError, error): pjit(lambda x: with_sharding_constraint(x, P('x')))(jnp.arange(4)) @@ -7311,8 +10245,8 @@ def testAxisResourcesMismatch(self): error = re.escape( "pjit in_shardings specification must be a tree prefix of the " - "positional arguments tuple passed to the `pjit`-decorated function. " - "In particular, pjit in_shardings must either be a None, a " + "positional arguments tuple. " + "In particular, pjit in_shardings must either be a Sharding, a " "PartitionSpec, or a tuple of length equal to the number of positional " "arguments. But pjit in_shardings is the wrong length: got a " "tuple or list of length 3 for an args tuple of length 2.") @@ -7363,7 +10297,7 @@ def h(x): xshape = (2, 5, 6) x = jnp.arange(math.prod(xshape)).reshape(xshape) with self.assertRaisesRegex( - ValueError, "Received incompatible devices for jitted computation.*"): + ValueError, ".*cannot change the size of the mesh.*"): f(x) @parameterized.named_parameters( @@ -7437,7 +10371,8 @@ def f(a, b, c): def test_named_sharding_of_none(self): mesh = jtu.create_mesh((2,), ('x',)) - with self.assertRaisesRegex(TypeError, 'Unexpected None'): + with self.assertRaisesRegex( + TypeError, '(Unexpected None|incompatible function arguments)'): jax.NamedSharding(mesh, None) @@ -7499,14 +10434,14 @@ def test_op_sharding_equality_and_hash_equality(self): op3.tile_assignment_dimensions = [4, 2] op3.tile_assignment_devices = [0, 1, 2, 3, 4, 5, 6, 7] - self.assertTrue(op_shardings.are_op_shardings_equal(op1, op2)) - self.assertFalse(op_shardings.are_op_shardings_equal(op1, op3)) - self.assertFalse(op_shardings.are_op_shardings_equal(op2, op3)) - hs1 = xc.HloSharding.from_proto(op1) hs2 = xc.HloSharding.from_proto(op2) hs3 = xc.HloSharding.from_proto(op3) + self.assertTrue(op_shardings.are_hlo_shardings_equal(hs1, hs2)) + self.assertFalse(op_shardings.are_hlo_shardings_equal(hs1, hs3)) + self.assertFalse(op_shardings.are_hlo_shardings_equal(hs2, hs3)) + self.assertEqual(hs1, xc.HloSharding.iota_tile((2, 2))) self.assertEqual(hs2, xc.HloSharding.iota_tile((2, 2))) self.assertEqual(hs3, xc.HloSharding.iota_tile((4, 2))) @@ -7533,10 +10468,11 @@ def test_op_sharding_partial_sharding(self): op2.tile_assignment_devices = [0, 2, 1, 3] op2.last_tile_dims = [xc.OpSharding.Type.REPLICATED] - self.assertTrue(op_shardings.are_op_shardings_equal(op1, op2)) - hs1 = xc.HloSharding.from_proto(op1) hs2 = xc.HloSharding.from_proto(op2) + + self.assertTrue(op_shardings.are_hlo_shardings_equal(hs1, hs2)) + self.assertEqual( hs1, xc.HloSharding.iota_tile( @@ -7582,20 +10518,19 @@ def test_op_sharding_tuple_shardings(self): op2.type = xc.OpSharding.Type.TUPLE op2.tuple_shardings = [top2, top1] - self.assertFalse(op_shardings.are_op_shardings_equal(op1, op2)) - hs1 = xc.HloSharding.from_proto(op1) hs2 = xc.HloSharding.from_proto(op2) + self.assertFalse(op_shardings.are_hlo_shardings_equal(hs1, hs2)) self.assertNotEqual(hash(hs1), hash(hs2)) def test_hlo_sharding_iota_tile_error(self): self.assertRaisesRegex( - xla_extension.XlaRuntimeError, + jax.errors.JaxRuntimeError, 'INVALID_ARGUMENT: `dims` should not be empty.', lambda: xc.HloSharding.iota_tile(()) ) self.assertRaisesRegex( - xla_extension.XlaRuntimeError, + jax.errors.JaxRuntimeError, 'INVALID_ARGUMENT: Cannot reshape from', lambda: xc.HloSharding.iota_tile( (2, 2), @@ -7604,7 +10539,7 @@ def test_hlo_sharding_iota_tile_error(self): ), ) self.assertRaisesRegex( - xla_extension.XlaRuntimeError, + jax.errors.JaxRuntimeError, 'INVALID_ARGUMENT: `reshape_dims` and `transpose_perm` should have the' ' same size', lambda: xc.HloSharding.iota_tile( @@ -7612,10 +10547,10 @@ def test_hlo_sharding_iota_tile_error(self): transpose_perm=(1, 0), ), ) - self.assertRaisesWithLiteralMatch( - xla_extension.XlaRuntimeError, - 'INVALID_ARGUMENT: `subgroup_types`(3) should not have more dimensions ' - 'than `dims`(2).', + self.assertRaisesRegex( + jax.errors.JaxRuntimeError, + r'INVALID_ARGUMENT: `subgroup_types`\(3\) should not have more dimensions ' + r'than `dims`\(2\).', lambda: xc.HloSharding.iota_tile( (2, 2), subgroup_types=( @@ -7626,6 +10561,7 @@ def test_hlo_sharding_iota_tile_error(self): ), ) + @jtu.thread_unsafe_test() def test_device_indices_cache(self): op1 = xc.OpSharding() op1.type = xc.OpSharding.Type.OTHER @@ -7677,13 +10613,18 @@ def test_op_sharding_semantically_replicated(self): op4.tile_assignment_dimensions = [1] op4.tile_assignment_devices = [0] - self.assertTrue(op_shardings.is_op_sharding_replicated(op1)) - self.assertTrue(op_shardings.is_op_sharding_replicated(op2)) - self.assertTrue(op_shardings.is_op_sharding_replicated(op3)) - self.assertTrue(op_shardings.is_op_sharding_replicated(op4)) - self.assertTrue(op_shardings.are_op_shardings_equal(op1, op2)) - self.assertTrue(op_shardings.are_op_shardings_equal(op2, op3)) - self.assertTrue(op_shardings.are_op_shardings_equal(op3, op4)) + hs1 = xc.HloSharding.from_proto(op1) + hs2 = xc.HloSharding.from_proto(op2) + hs3 = xc.HloSharding.from_proto(op3) + hs4 = xc.HloSharding.from_proto(op4) + + self.assertTrue(op_shardings.is_hlo_sharding_replicated(hs1)) + self.assertTrue(op_shardings.is_hlo_sharding_replicated(hs2)) + self.assertTrue(op_shardings.is_hlo_sharding_replicated(hs3)) + self.assertTrue(op_shardings.is_hlo_sharding_replicated(hs4)) + self.assertTrue(op_shardings.are_hlo_shardings_equal(hs1, hs2)) + self.assertTrue(op_shardings.are_hlo_shardings_equal(hs2, hs3)) + self.assertTrue(op_shardings.are_hlo_shardings_equal(hs3, hs4)) def test_op_sharding_manual_replicated(self): op1 = xc.OpSharding() @@ -7701,12 +10642,15 @@ def test_op_sharding_manual_replicated(self): op3 = xc.OpSharding() op3.type = xc.OpSharding.Type.REPLICATED - self.assertTrue(op_shardings.is_op_sharding_replicated(op1)) - self.assertTrue(op_shardings.is_op_sharding_replicated(op2)) - self.assertTrue(op_shardings.are_op_shardings_equal(op1, op2)) - self.assertTrue(op_shardings.are_op_shardings_equal(op1, op3)) - hs1 = xc.HloSharding.from_proto(op1) + hs2 = xc.HloSharding.from_proto(op2) + hs3 = xc.HloSharding.from_proto(op3) + + self.assertTrue(op_shardings.is_hlo_sharding_replicated(hs1)) + self.assertTrue(op_shardings.is_hlo_sharding_replicated(hs2)) + self.assertTrue(op_shardings.are_hlo_shardings_equal(hs1, hs2)) + self.assertTrue(op_shardings.are_hlo_shardings_equal(hs1, hs3)) + self.assertEqual( hs1, xc.HloSharding.iota_tile( @@ -7720,7 +10664,6 @@ def test_op_sharding_manual_replicated(self): self.assertTrue(hs1.is_replicated()) self.assertFalse(hs1.replicate_on_last_tile_dim()) - hs2 = xc.HloSharding.from_proto(op2) self.assertEqual( xc.HloSharding.from_proto(op2), xc.HloSharding.iota_tile( @@ -7862,163 +10805,5 @@ def f(x, y): self.assertLen(out, 16) -@jtu.with_config(jax_use_shardy_partitioner=True) -class ShardyTest(jtu.JaxTestCase): - - # TODO(bartchr): Once JAX is released with SDY, remove setUp. - def setUp(self): - if not dialects.sdy: - raise unittest.SkipTest('Shardy is not available.') - super().setUp() - - def test_lowering_input_output_sharding(self): - mesh = jtu.create_mesh((4, 2), ('x', 'y')) - np_inp = np.arange(16).reshape(8, 2) - s = jax.sharding.NamedSharding(mesh, P('x', 'y')) - arr = jax.device_put(np_inp, s) - - @partial(jax.jit, out_shardings=s) - def f(x): - return x * 2 - - self.assertIn('sdy.sharding = #sdy.sharding', f.lower(arr).as_text()) - - def test_lowering_with_sharding_constraint(self): - mesh = jtu.create_mesh((2, 2), ('x', 'y')) - arr = np.arange(16).reshape(4, 2, 2) - - @jax.jit - def f(x): - return jax.lax.with_sharding_constraint( - x, NamedSharding(mesh, P('x', None, 'y'))) - lowered_str = jax.jit(f).lower(arr).as_text() - self.assertIn('sdy.sharding_constraint', lowered_str) - self.assertIn('<@mesh, [{"x"}, {}, {"y"}]>', lowered_str) - - def test_lowering_with_sharding_constraint_unconstrained(self): - mesh = jtu.create_mesh((4, 2), ('x', 'y')) - arr = np.arange(16).reshape(4, 2, 2) - - @jax.jit - def f(x): - return jax.lax.with_sharding_constraint( - x, NamedSharding(mesh, P('x', P.UNCONSTRAINED, 'y'))) - lowered_str = f.lower(arr).as_text() - self.assertIn('sdy.sharding_constraint', lowered_str) - self.assertIn('<@mesh, [{"x"}, {?}, {"y"}]>', lowered_str) - - # TODO(bartchr): run on CPU once Shardy is added to the XLA CPU pipeline. - @jtu.skip_on_devices('cpu') - def test_compile_with_inferred_out_sharding(self): - mesh = jtu.create_mesh((2, 2), ('x', 'y')) - x = jax.device_put(np.arange(8 * 4).reshape(8, 4), - NamedSharding(mesh, P('x', 'y'))) - y = jax.device_put(np.arange(4 * 16).reshape(4, 16), - NamedSharding(mesh, P('y'))) - - @jax.jit - def f(x, y): - return x @ y - - out = f(x, y) - self.assertArraysEqual(out, x @ y) - self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) - - def test_fully_automatic_sharding(self): - mesh = jtu.create_mesh((8,), ('x',)) - x = jax.ShapeDtypeStruct((128, 128), jnp.float32) - - @jax.jit - def f(x, y): - return x @ y - - lowered_str = jax.jit(f, in_shardings=[AUTO(mesh), AUTO(mesh)]).lower(x, x).as_text() - self.assertIn('sdy.mesh @mesh = <["x"=8]>', lowered_str) - - def test_array_sharding_repr_with_priority(self): - sharding = sharding_impls.SdyArraySharding( - mesh_shape=(('data', 4), ('model', 8), ('expert', 2)), - dimension_shardings=[ - sharding_impls.SdyDimSharding(axes=['data', 'expert'], is_closed=True), - sharding_impls.SdyDimSharding(axes=['model'], is_closed=False, priority=2)]) - self.assertEqual(repr(sharding), "SdyArraySharding([{'data', 'expert'}, {'model', ?}p2])") - - def test_array_sharding_repr_with_logical_ids(self): - abstract_mesh = jax.sharding.AbstractMesh((4, 8, 2), ('x', 'y', 'z')) - ns = NamedSharding(abstract_mesh, P(('x', 'y'), 'z', P.UNCONSTRAINED, None), - _logical_device_ids=[4, 5, 6, 7, 0, 1, 2, 3]) - self.assertEqual(repr(ns._to_sdy_sharding(4)), - "SdyArraySharding([{'x', 'y'}, {'z'}, {?}, {}], " - "device_ids=[4, 5, 6, 7, 0, 1, 2, 3])") - - def test_dimension_sharding_repr(self): - dim_sharding = sharding_impls.SdyDimSharding( - axes=['data', 'model'], is_closed=False, priority=2) - self.assertEqual(repr(dim_sharding), - "SdyDimSharding({'data', 'model', ?}p2)") - - def test_tensor_dialect(self): - # While this doesn't emit any `mlir::TensorDialect` ops, some pass in the - # compiler pipeline is temporarily introducing it before then discarding it - # again. Make sure this doesn't crash. - mesh = jtu.create_mesh((2,), ('x')) - in_sds = jax.ShapeDtypeStruct((4, 8), jnp.float32) - - @partial(jax.jit, out_shardings=NamedSharding(mesh, P('x'))) - def gen_dummy_inputs(): - return tuple(jax.random.normal(jax.random.key(42), shape=in_sds.shape - ).astype(in_sds.dtype)) - gen_dummy_inputs() # doesn't crash - - @jtu.skip_on_devices('cpu') - def test_custom_partition_with_sharding_rule_callback(self): - if jtu.is_cloud_tpu(): - raise unittest.SkipTest("Custom partitioning is not supported on libtpu.") - - def partition(static_arg0, static_arg1, mesh, arg_shapes, result_shape): - arg_shardings = jax.tree.map(lambda s: s.sharding, arg_shapes) - result_sharding = result_shape.sharding - rank = len(arg_shapes[0].shape) - - self.assertEqual(static_arg0, 1) - self.assertEqual(static_arg1, 2) - def lower_fn(x, y): - axis_name = arg_shardings[1].spec[rank-2][0] - i = jax.lax.axis_index(axis_name) - z = jax.lax.psum(jax.lax.dynamic_slice_in_dim( - jax.lax.dynamic_slice_in_dim(x, i * 0, 8, axis=rank-2), - i * 8, 8, axis=rank-1) @ y, (axis_name)) - return z - - return mesh, lower_fn, (result_sharding), arg_shardings - - def produce_sharding_rule(static_arg0, static_arg1, mesh, arg_shapes, result_shape): - self.assertEqual(static_arg0, 1) - self.assertEqual(static_arg1, 2) - rank = len(arg_shapes[0].shape) - leading_axes = "" - for i in range(rank - 2): - leading_axes += f" b{i}" - return f"{leading_axes} i j, {leading_axes} j k -> {leading_axes} i k" - - @partial(custom_partitioning, static_argnums=(2,3)) - def f(x, y, static_arg0=1, static_arg1=2): - return jnp.matmul(x, y) - - f.def_partition( - infer_sharding_from_operands=None, - partition=partition, - sharding_rule=produce_sharding_rule) - - mesh = jtu.create_mesh((4, 2), ('x', 'y')) - x = jax.device_put(np.arange(2 * 3 * 32 * 16).reshape(2, 3, 32, 16), - NamedSharding(mesh, P(None, None, 'x'))) - y = jax.device_put(np.arange(2 * 3 * 16 * 32).reshape(2, 3, 16, 32), - NamedSharding(mesh, P(None, None,'y'))) - result = jax.jit(f)(x, y) - expected_result = f(x, y) - self.assertArraysEqual(result, expected_result) - self.assertEqual(result.sharding, NamedSharding(mesh, P(None, None, 'x'))) - if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pmap_shmap_merge_test.py b/tests/pmap_shmap_merge_test.py new file mode 100644 index 000000000000..9a4bafcf42b5 --- /dev/null +++ b/tests/pmap_shmap_merge_test.py @@ -0,0 +1,121 @@ +# Copyright 2026 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import math +import unittest +import warnings + +from absl.testing import absltest +import jax +from jax._src import config +from jax._src import core +from jax._src import dtypes +from jax._src import stages +from jax._src import test_util as jtu +import jax.numpy as jnp +import numpy as np + +config.parse_flags_with_absl() +jtu.request_cpu_devices(8) + +# Suppress the deprecation warning from @config.pmap_shmap_merge(True) decorator +# which is triggered at class definition time. +warnings.filterwarnings( + 'ignore', + message='Setting `jax_pmap_shmap_merge` is deprecated', + category=DeprecationWarning, +) + + +class PmapShmapMergeTest(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + if jax.device_count() < 2: + raise unittest.SkipTest('test requires at least two devices') + + @config.pmap_shmap_merge(True) + def test_store_exception(self): + def f(x): + return x + + inp = jnp.ones((jax.device_count(), 1), dtype=jnp.float32) + jax.pmap(f, axis_name='i')(inp) + inp = jnp.ones((jax.device_count(), 1), dtype=jnp.int32) + jax.pmap(f, axis_name='i')(inp) + + @config.pmap_shmap_merge(True) + def test_prng_key(self): + keys = jax.random.split(jax.random.key(0), jax.device_count()) + out = jax.pmap(lambda x: x)(keys) + self.assertEqual(type(out), type(keys)) + out = jax.pmap(lambda x, y: y, in_axes=(0, None))(keys, jax.random.key(0)) + self.assertEqual(type(out), type(keys)) + out = jax.pmap(lambda x, y: y, in_axes=(0, None), out_axes=None)( + keys, jax.random.key(0) + ) + self.assertEqual(type(out), type(keys)) + + @config.pmap_shmap_merge(True) + def test_lower_with_flattened_args(self): + shape = (jax.device_count(), 3) + + inputs = np.reshape(np.arange(math.prod(shape)), shape) + # The shard_map implementation of pmap takes pytree args, but the inner + # jitted_f must take flattened args. + _ = jax.pmap(lambda x: x[0]).lower((inputs, ())).compile() # doesn't crash + + @config.pmap_shmap_merge(True) + def test_float0_dtype_input(self): + inputs = np.array([b''] * jax.device_count(), dtype=dtypes.float0) + _ = jax.pmap(lambda x: x)(inputs) # doesn't crash + + @config.pmap_shmap_merge(True) + def test_float0_dtype_output(self): + inputs = np.ones(jax.device_count()) + _ = jax.pmap(lambda x: jnp.array(b'', dtype=dtypes.float0))( + inputs + ) # doesn't crash + + @config.pmap_shmap_merge(True) + def test_lowered_args_info(self): + shmap_lowered = jax.pmap(lambda x: x).lower( + (jnp.ones((1,), jnp.float32), ()) + ) + aval = core.ShapedArray((1,), jnp.float32) + expected_args_info = ( + ( + ( + stages.ArgInfo(aval, donated=False), + (), + ), + ), + {}, + ) + self.assertEqual( + shmap_lowered.args_info, expected_args_info + ) # doesn't crash + + @config.pmap_shmap_merge(True) + def test_wrapped(self): + f = lambda x: x + g = jax.pmap(f) + self.assertTrue(hasattr(g, '__wrapped__')) + self.assertEqual(g.__wrapped__, f) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/pmap_test.py b/tests/pmap_test.py index af2d03e2945d..e42a48be0fb0 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -15,7 +15,6 @@ from __future__ import annotations from concurrent.futures import ThreadPoolExecutor -import contextlib from functools import partial import itertools as it import gc @@ -28,6 +27,7 @@ import weakref import numpy as np +from absl import flags from absl.testing import absltest from absl.testing import parameterized @@ -37,7 +37,6 @@ from jax import lax import jax.scipy.linalg from jax import random -from jax.ad_checkpoint import checkpoint as new_checkpoint import jax.numpy as jnp from jax._src import api as src_api from jax._src import array @@ -49,12 +48,23 @@ from jax._src.internal_test_util import lax_test_util from jax._src.interpreters import pxla from jax._src.lax import parallel -from jax._src.lib import xla_extension +from jax._src.lib import xla_client as xc from jax._src.util import safe_map, safe_zip config.parse_flags_with_absl() jtu.request_cpu_devices(8) +_PMAP_SHMAP_MERGE = flags.DEFINE_bool( + 'pmap_shmap_merge', True, + 'If False, run pmap tests with jax_pmap_shmap_merge=False.') + +if not _PMAP_SHMAP_MERGE.value: + with jtu.ignore_warning( + category=DeprecationWarning, + message='Setting `jax_pmap_shmap_merge` is deprecated', + ): + config.update('jax_pmap_shmap_merge', False) + compatible_shapes = [[(3,)], [(3, 4), (3, 1), (1, 4)], [(2, 3, 4), (2, 1, 4)]] @@ -112,6 +122,12 @@ def pmap(self): return src_api.pmap def testDeviceBufferToArray(self): + # NOTE(dsuo): Under `pmap_shmap_merge=True`, the resulting array is sharded, + # whereas under `pmap_shmap_merge=False`, the resulting array is + # "SingleDeviceSharded". The attribute `unsafe_buffer_pointer` is + # unavailable for sharded arrays. + if config.pmap_shmap_merge.value: + self.skipTest("Test fails because pmap is jit(shmap).") sda = self.pmap(lambda x: x)(jnp.ones((jax.device_count(), 2))) # Changed in https://github.com/jax-ml/jax/pull/10584 not to access @@ -161,7 +177,7 @@ def testDefaultDeviceOrdering(self): device_order = jax.devices() pmap_sharding = pmap(lambda x: x)(np.arange(jax.device_count())).sharding if config.pmap_shmap_merge.value: - self.assertListEqual(device_order, pmap_sharding._device_assignment) + self.assertListEqual(device_order, list(pmap_sharding._device_assignment)) else: self.assertListEqual(device_order, pmap_sharding.devices.tolist()) @@ -173,7 +189,6 @@ def testLowerCompile(self): lowered = f.lower(x) compiled = lowered.compile() ans = compiled(x) - self.assertAllClose(ans, expected) # It's a pair of: (positional args, as a tuple of their structures, kwargs). @@ -218,11 +233,17 @@ def testLowerCompileArgTypeMismatch(self): x_f32 = x.astype(jnp.float32) x_i32 = x.astype(jnp.int32) f_exe = f.lower(x_f32).compile() + if config.pmap_shmap_merge.value: + expected_regex = r"Argument types differ .*" + r"The mismatches are:\n" + r"Argument 'args[0]' compiled with.*float32.*and called with.*int32.*" + else: + expected_regex = r"Argument types differ .*" + r"The mismatches are:\n" + r"Argument 'x' compiled with.*float32.*and called with.*int32.*" self.assertRaisesRegex( TypeError, - r"Argument types differ .*" - r"The mismatches are:\n" - r"Argument 'x' compiled with.*float32.*and called with.*int32.*", + expected_regex, lambda: f_exe(x_i32)) def testLowerCompileMultiArg(self): @@ -318,12 +339,12 @@ def test_jit_lower_compile_with_compiler_options_invalid(self): lowered = f.lower(x) self.assertRaisesRegex( - xla_extension.XlaRuntimeError, "No such compile option: 'invalid_key'", + jax.errors.JaxRuntimeError, "No such compile option: 'invalid_key'", lambda: lowered.compile( compiler_options={"invalid_key": "invalid_value"})) self.assertRaisesRegex( - xla_extension.XlaRuntimeError, "is not a valid bool value.", + jax.errors.JaxRuntimeError, "is not a valid bool value.", lambda: lowered.compile( compiler_options={"xla_embed_ir_in_executable": "invalid_value"})) @@ -332,7 +353,10 @@ def test_pmap_replicated_copy(self): inp = jnp.arange(jax.device_count()) x = jax.pmap(lambda x: x, in_axes=0, out_axes=None)(inp) out = jnp.copy(x) - self.assertIsInstance(out.sharding, jax.sharding.SingleDeviceSharding) + if config.pmap_shmap_merge.value: + self.assertIsInstance(out.sharding, jax.sharding.NamedSharding) + else: + self.assertIsInstance(out.sharding, jax.sharding.SingleDeviceSharding) self.assertArraysEqual(out, inp[0]) def test_jit_lower_compile_with_compiler_options_multiple(self): @@ -356,7 +380,7 @@ def test_jit_lower_compile_with_compiler_options_multiple(self): # We should still error on invalid options after some valid compiles self.assertRaisesRegex( - xla_extension.XlaRuntimeError, "No such compile option: 'invalid_key'", + jax.errors.JaxRuntimeError, "No such compile option: 'invalid_key'", lambda: lowered.compile( compiler_options={"invalid_key": "invalid_value"})) @@ -365,7 +389,8 @@ def testLowerShapedArray(self): shape = (jax.device_count(), 4) x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) x_shape = core.ShapedArray(x.shape, x.dtype) - self.assertAllClose(f.lower(x_shape).compile()(x), f(x)) + ans = f.lower(x_shape).compile()(x) + self.assertAllClose(ans, f(x)) def testLowerHasReplicaAttributes(self): f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i') @@ -374,8 +399,12 @@ def testLowerHasReplicaAttributes(self): x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) lowered = f.lower(x) hlo = lowered.as_text("stablehlo") - self.assertIn(f"mhlo.num_replicas = {num_devices}", hlo) - self.assertIn("mhlo.num_partitions = 1", hlo) + if config.pmap_shmap_merge.value: + self.assertIn(f"mhlo.num_partitions = {num_devices}", hlo) + self.assertIn("mhlo.num_replicas = 1", hlo) + else: + self.assertIn(f"mhlo.num_replicas = {num_devices}", hlo) + self.assertIn("mhlo.num_partitions = 1", hlo) def testMean(self): f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i') @@ -499,7 +528,7 @@ def testReduceScatterReplicaGroupsTiled(self): def testTrees(self): ptranspose = lambda x, axis_name: lax.all_to_all(x, axis_name, 0, 0) def protate(x, axis_name): - n = lax.psum(1, axis_name) + n = lax.axis_size(axis_name) return lax.ppermute(x, axis_name, [(i, (i + 1) % n) for i in range(n)]) tree_f = lambda f: partial(jax.tree.map, f) @@ -576,6 +605,11 @@ def f(x): def testAllToAllSplitAxis(self, split_axis, concat_axis): if jax.device_count() < 4: raise SkipTest("test requires at least four devices") + if jtu.device_under_test() == "gpu": + raise SkipTest("TODO(b/456133538): Disable on GPUs until we figure out.") + if config.pmap_shmap_merge.value: + raise SkipTest("Ignore nested pmap when `pmap_shmap_merge=True`.") + pmap_in_axis = 0 shape = (4, 4, 4) x = np.arange(math.prod(shape)).reshape(shape) @@ -597,6 +631,9 @@ def f(x): self.assertAllClose(y, ref) def testNestedPmapAxisSwap(self): + if config.pmap_shmap_merge.value: + raise SkipTest("Ignore nested pmap when `pmap_shmap_merge=True`.") + # Regression test for https://github.com/jax-ml/jax/issues/5757 if jax.device_count() < 8: raise SkipTest("test requires at least 8 devices") @@ -606,6 +643,8 @@ def testNestedPmapAxisSwap(self): self.assertAllClose(A.transpose((0, 2, 1)), f(A)) def testNestedBasic(self): + if config.pmap_shmap_merge.value: + raise SkipTest("Ignore nested pmap when `pmap_shmap_merge=True`.") f = lambda x: lax.psum(lax.psum(x, 'i'), 'j') f = self.pmap(self.pmap(f, 'i'), 'j') @@ -651,6 +690,8 @@ def testOutAxesPyTreePrefixMismatchError(self): "device_mesh_shape": device_mesh_shape} for device_mesh_shape in [(1, 1), (2, -1), (-1, 2)]) def testNestedShardingAndStacking(self, device_mesh_shape): + if config.pmap_shmap_merge.value: + raise SkipTest("Ignore nested pmap when `pmap_shmap_merge=True`.") mesh_shape = self._getMeshShape(device_mesh_shape) f = lambda x: x @@ -677,26 +718,44 @@ def testPartiallyMapped(self): f_ans = f(x, y) self.assertAllClose(f_ans, f_expected) self.assertIsInstance(f_ans, array.ArrayImpl) - sharding_spec = f_ans.sharding.sharding_spec - # the output is actually replicated (has the same values in each device buffer) - # but out_axes is implicitly 0, so we shouldn't have replication in the - # sharding spec. - self.assertEmpty([a for a in sharding_spec.mesh_mapping - if isinstance(a, pxla.Replicated)]) + if config.pmap_shmap_merge.value: + if jax.device_count() == 1: + self.assertEmpty(f_ans.sharding.spec) + else: + self.assertLen(f_ans.sharding.spec, 1) + axis = f_ans.sharding.spec[0] + self.assertEqual(axis, f_ans.sharding.mesh.axis_names[0]) + else: + sharding_spec = f_ans.sharding.sharding_spec + # the output is actually replicated (has the same values in each device + # buffer) but out_axes is implicitly 0, so we shouldn't have replication + # in the sharding spec. + self.assertEmpty([a for a in sharding_spec.mesh_mapping + if isinstance(a, pxla.Replicated)]) g_expected = np.broadcast_to(x - np.sum(y, 0, keepdims=True), shape) g_ans = g(x, y) self.assertAllClose(g_ans, g_expected) self.assertIsInstance(g_ans, array.ArrayImpl) - sharding_spec = g_ans.sharding.sharding_spec - self.assertEmpty([a for a in sharding_spec.mesh_mapping - if isinstance(a, pxla.Replicated)]) + if config.pmap_shmap_merge.value: + if jax.device_count() == 1: + self.assertEmpty(g_ans.sharding.spec) + else: + self.assertLen(g_ans.sharding.spec, 1) + axis = g_ans.sharding.spec[0] + self.assertEqual(axis, g_ans.sharding.mesh.axis_names[0]) + else: + sharding_spec = g_ans.sharding.sharding_spec + self.assertEmpty([a for a in sharding_spec.mesh_mapping + if isinstance(a, pxla.Replicated)]) @parameterized.named_parameters( {"testcase_name": f"_mesh={device_mesh_shape}".replace(" ", ""), "device_mesh_shape": device_mesh_shape} for device_mesh_shape in [(1, 1), (2, -1), (-1, 2)]) def testPartiallyMappedNested(self, device_mesh_shape): + if config.pmap_shmap_merge.value: + raise SkipTest("Ignore nested pmap when `pmap_shmap_merge=True`.") mesh_shape = self._getMeshShape(device_mesh_shape) f = self.pmap(lambda x, y: x - lax.psum(y, 'i'), axis_name='i', in_axes=(None, 0)) @@ -787,6 +846,8 @@ def g(x, y): "device_mesh_shape": device_mesh_shape} for device_mesh_shape in [(1, 1), (2, -1), (-1, 2)]) def testNestedWithClosure(self, device_mesh_shape): + if config.pmap_shmap_merge.value: + raise SkipTest("Ignore nested pmap when `pmap_shmap_merge=True`.") mesh_shape = self._getMeshShape(device_mesh_shape) @partial(self.pmap, axis_name='i') @@ -816,9 +877,10 @@ def g(z): expected = grad(lambda x: jnp.sum(baseline_fun(x)))(x) self.assertAllClose(ans, expected, atol=1e-3, rtol=1e-3) + @jtu.ignore_warning(category=DeprecationWarning) def testArrays(self): - f = lambda x: 2 * x - f = self.pmap(f, axis_name='i') + inner_f = lambda x: 2 * x + f = self.pmap(inner_f, axis_name='i') shape = (jax.device_count(), 4) x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) @@ -845,8 +907,14 @@ def testArrays(self): # test that we can handle device movement on dispatch bufs = y._arrays[::-1] - sharding = jax.sharding.PmapSharding( - [list(b.devices())[0] for b in bufs], y.sharding.sharding_spec) + devices = [list(b.devices())[0] for b in bufs] + if config.pmap_shmap_merge.value: + mesh = jax.sharding.Mesh(devices, 'i') + sharding = jax.sharding.NamedSharding(mesh, y.sharding.spec) + # NOTE(dsuo): Need to redefine pmap with the updated devices. + f = self.pmap(inner_f, axis_name='i', devices=devices) + else: + sharding = jax.sharding.PmapSharding(devices, y.sharding.sharding_spec) y = jax.make_array_from_single_device_arrays(y.shape, sharding, bufs) z = f(y) self.assertAllClose(z, 2 * 2 * x[::-1], check_dtypes=False) @@ -876,6 +944,9 @@ def testArrayReshape(self, in_shape, out_shape): check_dtypes=False) def testPsumMultiple(self): + if config.pmap_shmap_merge.value: + raise SkipTest("Ignore nested pmap when `pmap_shmap_merge=True`.") + f = lambda x: lax.psum(x, ('i', 'j')) f = self.pmap(self.pmap(f, 'i'), 'j') @@ -1030,6 +1101,8 @@ def f(x): jtu.check_grads(f, (x,), 2, ["fwd", "rev"], 1e-2, 1e-2, eps=1.) def testNestedPmapReplicaGroups(self): + if config.pmap_shmap_merge.value: + raise SkipTest("Ignore nested pmap when `pmap_shmap_merge=True`.") replicas = jax.device_count() if replicas % 4 != 0: raise SkipTest @@ -1261,33 +1334,42 @@ def testReduceMin(self): def testDeviceCountError(self): device_count = jax.device_count() + # NOTE(dsuo): The error message is different depending on the version of + # this test. + if config.pmap_shmap_merge.value: + expected_regex = r".*" + else: + expected_regex = r".*requires.*replicas" + f = self.pmap(lambda x: 2 * x) x = jnp.arange(device_count + 1) - self.assertRaisesRegex(ValueError, ".*requires.*replicas", lambda: f(x)) + self.assertRaisesRegex(ValueError, expected_regex, lambda: f(x)) f = self.pmap(lambda x: 2 * x) x = np.ones((device_count + 1, 10)) - self.assertRaisesRegex(ValueError, ".*requires.*replicas", lambda: f(x)) + self.assertRaisesRegex(ValueError, expected_regex, lambda: f(x)) f = self.pmap(lambda x: self.pmap(lambda x: 2 * x)(x)) x = np.ones((device_count, 2, 10)) - self.assertRaisesRegex(ValueError, ".*requires.*replicas", lambda: f(x)) + self.assertRaisesRegex(ValueError, expected_regex, lambda: f(x)) def testPmapConstant(self): device_count = jax.device_count() - f = self.pmap(lambda x: 3) - x = jnp.arange(device_count) - with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841 - ans = f(x) - # self.assertEqual(count(), 0) # TODO(mattjj): fix this - expected = np.repeat(3, device_count) + const = jnp.arange(16, dtype=np.int32) # distinctive shape + f = self.pmap(lambda x: x + const[15]) + x = jnp.arange(device_count, dtype=np.int32) + expected = x + np.int32(15) + ans = f(x) self.assertAllClose(ans, expected, check_dtypes=False) + if not config.disable_jit.value: + self.assertCacheMisses(lambda: f(x), + compilation_after_persistent_cache_miss=0) if not config.disable_jit.value: - f = self.pmap(lambda x: (x, 3)) - x = np.arange(device_count) + f = self.pmap(lambda x: x + const[15]) + x = np.arange(device_count, dtype=np.int32) with jtu.assert_num_jit_and_pmap_compilations(1): - _, ans = f(x) + ans = f(x) self.assertAllClose(ans, expected, check_dtypes=False) def testPmapConstantDevices(self): @@ -1314,11 +1396,17 @@ def testPmapConstantError(self): device_count = jax.device_count() f = self.pmap(lambda x: 3) x = jnp.arange(device_count + 1) - self.assertRaisesRegex( - ValueError, - (r"compiling computation that requires \d+ logical devices, " - r"but only \d+ XLA devices are available .*"), - lambda: f(x)) + if config.pmap_shmap_merge.value: + expected_regex = [ + # NOTE(dsuo): We get different error messages depending on backend. + r'shard_map applied.*axis sizes.*not evenly divisible.*mesh axis sizes.*', + r'cannot select an axis to squeeze out which has size not equal to one.*', + r'Sharding.*implies that array.*but the dimension size is.*', + ] + expected_regex = '|'.join(expected_regex) + else: + expected_regex = r'compiling computation that requires \d+ logical devices, but only \d+ XLA devices are available .*' + self.assertRaisesRegex(ValueError, expected_regex, lambda: f(x)) # TODO(mattjj): test error message with explicit devices # f = pmap(lambda x: 3, devices=[jax.devices()[0]]) @@ -1330,6 +1418,8 @@ def testPmapConstantError(self): def testNestedPmapConstant(self): if jax.device_count() == 1: raise SkipTest("this test requires multiple devices") + if config.pmap_shmap_merge.value: + raise SkipTest("Ignore nested pmap when `pmap_shmap_merge=True`.") f = self.pmap(self.pmap(lambda x: 3)) shape = (2, jax.device_count() // 2, 3) @@ -1354,6 +1444,8 @@ def testNestedPmapConstant(self): def testNestedPmapConstantDevices(self): if jax.device_count() < 6: raise SkipTest("this test requires >= 6 devices") + if config.pmap_shmap_merge.value: + raise SkipTest("Ignore nested pmap when `pmap_shmap_merge=True`.") devices = jax.devices()[:-2] shuffle(devices) @@ -1373,6 +1465,8 @@ def testNestedPmapConstantDevices(self): def testNestedPmapConstantError(self): if config.disable_jit.value: raise SkipTest("error test doesn't apply with disable_jit") + if config.pmap_shmap_merge.value: + raise SkipTest("Ignore nested pmap when `pmap_shmap_merge=True`.") f = self.pmap(self.pmap(lambda x: 3)) shape = (2, jax.device_count() // 2 + 1, 3) x = jnp.arange(math.prod(shape)).reshape(shape) @@ -1395,22 +1489,24 @@ def testNestedPmapConstantError(self): def testCollectiveConstant(self): device_count = jax.device_count() - f = self.pmap(lambda x: lax.psum(1, 'i'), 'i') + f = self.pmap(lambda x: lax.axis_size('i'), 'i') x = jnp.arange(device_count) ans = f(x) expected = np.repeat(device_count, device_count) self.assertAllClose(ans, expected, check_dtypes=False) def testCollectiveConstantNested(self): + if config.pmap_shmap_merge.value: + raise SkipTest("Ignore nested pmap when `pmap_shmap_merge=True`.") device_count = jax.device_count() @partial(self.pmap, axis_name='i') def f(x): @partial(self.pmap, axis_name='j') def g(y): - a = lax.psum(1, 'i') - b = lax.psum(1, 'j') - c = lax.psum(1, ('i', 'j')) + a = lax.axis_size('i') + b = lax.axis_size('j') + c = lax.axis_size(('i', 'j')) return a, b, c return g(x) @@ -1435,6 +1531,8 @@ def testAxisIndex(self): self.assertAllClose(ans, expected, check_dtypes=False) def testAxisIndexNestedPmap(self): + if config.pmap_shmap_merge.value: + raise SkipTest("Ignore nested pmap when `pmap_shmap_merge=True`.") device_count = jax.device_count() if device_count < 4: raise SkipTest("test requires at least four devices") @@ -1445,6 +1543,8 @@ def testAxisIndexNestedPmap(self): self.assertAllClose(f('i')(x), expected_j.T, check_dtypes=False) def testAxisIndexNd(self): + if config.pmap_shmap_merge.value: + raise SkipTest("Ignore nested pmap when `pmap_shmap_merge=True`.") device_count = jax.device_count() if device_count < 4: raise SkipTest("test requires at least four devices") @@ -1670,6 +1770,7 @@ def time_evolution(state): multi_step_pmap(jnp.zeros((device_count,)), count=1) + @jtu.ignore_warning(category=DeprecationWarning) def test_typed_prng_key_sharded(self): devices = jax.local_devices() @@ -1764,7 +1865,7 @@ def distributed_matrix_vector(x, y): {"testcase_name": f"{suffix}", "remat": remat} for suffix, remat in [ ('', jax.remat), - ('_new', new_checkpoint), + ('_new', jax.checkpoint), ]) def testAxisIndexRemat(self, remat): # https://github.com/jax-ml/jax/issues/2716 @@ -1840,6 +1941,9 @@ def testPsumOnBooleanDtype(self): self.assertEqual(list(out), [1]) def testPsumWithNoAxisDoesntLeakFunctions(self): + if config.pmap_shmap_merge.value: + raise SkipTest("shmap implementation holds an additional weakref.") + x = jnp.ones((1, 1024), dtype=np.float32) f = lambda _: x w = weakref.ref(f) @@ -1852,6 +1956,9 @@ def testPsumWithNoAxisDoesntLeakFunctions(self): self.assertIs(w(), None) def testJitOfPmapWarningMessage(self): + if config.pmap_shmap_merge.value: + raise SkipTest("Test does not warn under `pmap_shmap_merge=True`.") + device_count = jax.device_count() if device_count == 1 or config.disable_jit.value: @@ -1859,7 +1966,7 @@ def testJitOfPmapWarningMessage(self): def foo(x): return x - with self.assertWarnsRegex(UserWarning, "The jitted function foo includes a pmap"): + with self.assertWarnsRegex(UserWarning, "The function jit.foo. includes a pmap"): jit(self.pmap(foo))(jnp.arange(device_count)) def testJitOfPmapOutputSharding(self): @@ -1870,7 +1977,8 @@ def testJitOfPmapOutputSharding(self): @jax.jit @jax.pmap - def foo(x): return x + x + def foo(x): + return x + x x = np.ones((2,2,2), dtype=np.float32) for _ in range(10): @@ -1889,13 +1997,18 @@ def testJitOfPmapLowerHasReplicaAttributes(self): @jax.jit @jax.pmap - def foo(x): return x + x + def foo(x): + return x + x x = np.ones((2,2,2), dtype=np.float32) hlo = foo.lower(x).as_text("stablehlo") - self.assertIn(f"mhlo.num_replicas = {2}", hlo) - self.assertIn("mhlo.num_partitions = 1", hlo) + if config.pmap_shmap_merge.value: + self.assertIn("mhlo.num_replicas = 1", hlo) + self.assertIn("mhlo.num_partitions = 2", hlo) + else: + self.assertIn(f"mhlo.num_replicas = {2}", hlo) + self.assertIn("mhlo.num_partitions = 1", hlo) def testPsumZeroCotangents(self): # https://github.com/jax-ml/jax/issues/3651 @@ -1944,6 +2057,9 @@ def pmapped_multi_step(state): @jtu.skip_on_devices("cpu") def test_replicate_backend(self): + if config.pmap_shmap_merge.value: + raise SkipTest("Ignore nested pmaps under `pmap_shmap_merge=True`.") + # TODO(skye): fix backend caching so we always have multiple CPUs available if jax.device_count("cpu") < 4: self.skipTest("test requires 4 CPU device") @@ -2074,7 +2190,7 @@ def g(x): {"testcase_name": f"{suffix}", "remat": remat} for suffix, remat in [ ('', jax.remat), - ('_new', new_checkpoint), + ('_new', jax.checkpoint), ]) def test_remat_of_pmap(self, remat): f = remat(jax.pmap(lambda x: jnp.sin(jnp.sin(x)))) @@ -2089,7 +2205,7 @@ def test_remat_of_pmap(self, remat): {"testcase_name": f"{suffix}", "remat": remat} for suffix, remat in [ ('', jax.remat), - ('_new', new_checkpoint), + ('_new', jax.checkpoint), ]) def test_remat_of_pmap_policy(self, remat): g = jax.pmap(lambda x: jnp.sin(jnp.sin(x))) @@ -2098,7 +2214,7 @@ def test_remat_of_pmap_policy(self, remat): save_cos = lambda prim, *_, **__: str(prim) == 'cos' f = remat(g, policy=save_cos) _, f_vjp = jax.vjp(f, x) - jaxpr = f_vjp.args[0].func.args[1] + jaxpr = f_vjp.jaxpr jaxpr_text = str(jaxpr) self.assertEqual(jaxpr_text.count(' sin '), 0) self.assertEqual(jaxpr_text.count(' cos '), 0) @@ -2106,7 +2222,7 @@ def test_remat_of_pmap_policy(self, remat): save_sin = lambda prim, *_, **__: str(prim) == 'sin' f = remat(g, policy=save_sin) _, f_vjp = jax.vjp(f, x) - jaxpr = f_vjp.args[0].func.args[1] + jaxpr = f_vjp.jaxpr jaxpr_text = str(jaxpr) self.assertEqual(jaxpr_text.count(' sin '), 0) self.assertEqual(jaxpr_text.count(' cos '), 2) @@ -2114,7 +2230,7 @@ def test_remat_of_pmap_policy(self, remat): save_nothing = lambda prim, *_, **__: False f = remat(g, policy=save_nothing) _, f_vjp = jax.vjp(f, x) - jaxpr = f_vjp.args[0].func.args[1] + jaxpr = f_vjp.jaxpr jaxpr_text = str(jaxpr) self.assertEqual(jaxpr_text.count(' sin '), 1) self.assertEqual(jaxpr_text.count(' cos '), 2) @@ -2152,7 +2268,7 @@ def test_pmap_of_prng_key(self): keys = jax.random.split(jax.random.key(0), jax.device_count()) result1 = jax.pmap(jax.random.bits)(keys) with jtu.ignore_warning( - category=UserWarning, message="The jitted function bits includes a pmap"): + category=UserWarning, message="The function jit.bits. includes a pmap"): result2 = jax.jit(jax.pmap(jax.random.bits))(keys) self.assertArraysEqual(result1, result2) @@ -2160,10 +2276,13 @@ def test_pmap_of_prng_key(self): @jtu.pytest_mark_if_available('multiaccelerator') class CppPmapTest(PythonPmapTest): + def setUp(self): + super().setUp() + if config.pmap_shmap_merge.value: + raise SkipTest('Not testing cpp_pmap when `pmap_shmap_merge=True`.') + @property def pmap(self): - if config.pmap_shmap_merge.value: - return src_api.pmap return src_api._cpp_pmap def pmap_fast_path_is_enabled(self): @@ -2271,7 +2390,11 @@ def f(x, y): if jax.device_count() < 4: raise SkipTest("test requires at least four devices") x = jnp.ones((2, 2, 64, 64)) - y = f(jax.pmap, jax.pmap)(x, x) + if config.pmap_shmap_merge.value: + # NOTE(dsuo): Ignore nested pmap when `pmap_shmap_merge=True`. + y = f(jax.vmap, jax.vmap)(x, x) + else: + y = f(jax.pmap, jax.pmap)(x, x) self.assertAllClose(f(jax.vmap, jax.vmap)(x, x), y) self.assertAllClose(f(jax.pmap, jax.vmap)(x, x), y) self.assertAllClose(f(jax.vmap, jax.pmap)(x, x), y) @@ -2281,6 +2404,8 @@ def f(x, y): "collective": collective} for collective in [lax.psum, lax.pmean, lax.pmax, lax.pmin]) def testCollectivesWithVmap2(self, collective): + if jtu.device_under_test() == "gpu": + raise SkipTest("TODO(b/456133538): Disable on GPUs until we figure out.") def f(map1, map2): @partial(map1, axis_name='i') @partial(map2, axis_name='j') @@ -2291,7 +2416,11 @@ def f(x, y): if jax.device_count() < 8: raise SkipTest("test requires at least eight devices") x = jnp.arange(4*2*64*64, dtype=float).reshape(4, 2, 64, 64) - y = f(jax.pmap, jax.pmap)(x, x) + if config.pmap_shmap_merge.value: + # NOTE(dsuo): Ignore nested pmap when `pmap_shmap_merge=True`. + y = f(jax.vmap, jax.vmap)(x, x) + else: + y = f(jax.pmap, jax.pmap)(x, x) self.assertAllClose(f(jax.vmap, jax.vmap)(x, x), y) self.assertAllClose(f(jax.pmap, jax.vmap)(x, x), y) self.assertAllClose(f(jax.vmap, jax.pmap)(x, x), y) @@ -2299,8 +2428,8 @@ def f(x, y): def testPPermuteWithVmap(self): perm = [(0, 1), (1, 0)] - def f(map2): - @partial(jax.pmap, axis_name='i') + def f(map1, map2): + @partial(map1, axis_name='i') @partial(map2) def f(x, y): return x + jax.lax.ppermute(x.dot(y), 'i', perm) @@ -2309,7 +2438,7 @@ def f(x, y): if jax.device_count() < 4: raise SkipTest("test requires at least four devices") x = jnp.ones((2, 2, 64, 64)) - self.assertAllClose(f(jax.pmap)(x, x), f(jax.vmap)(x, x)) + self.assertAllClose(f(jax.vmap, jax.pmap)(x, x), f(jax.pmap, jax.vmap)(x, x)) def testPPermuteAgreesWithVmap(self): if jax.device_count() < 3: @@ -2419,7 +2548,7 @@ def f(x): shape = (2, 2, 4, 4, 4) x = jnp.arange(math.prod(shape)).reshape(shape) - self.assertAllClose(pmap(pmap(f, axis_name='j'), axis_name='i')(x), + self.assertAllClose(pmap(vmap(f, axis_name='j'), axis_name='i')(x), vmap(vmap(f, axis_name='j'), axis_name='i')(x)) @parameterized.named_parameters([ @@ -2427,13 +2556,13 @@ def f(x): ('ReduceScatter', lax.psum_scatter), ]) def testWithVmap(self, prim): - def f(map2): - return jax.pmap(map2(partial(prim, axis_name='i')), axis_name='i') + def f(map1, map2): + return map1(map2(partial(prim, axis_name='i')), axis_name='i') if jax.device_count() < 4: raise SkipTest("test requires at least four devices") x = jnp.ones((2, 2, 2, 64)) - self.assertAllClose(f(jax.pmap)(x), f(jax.vmap)(x)) + self.assertAllClose(f(jax.vmap, jax.pmap)(x), f(jax.pmap, jax.vmap)(x)) @parameterized.named_parameters(it.chain.from_iterable([ ('AllGather' + ('Tiled' if tiled else ''), lax.all_gather, tiled), @@ -2487,6 +2616,8 @@ def testNoDevicesError(self): def testBadAxisSizeError(self): if jax.device_count() == 1: raise SkipTest("this test requires multiple devices") + if config.pmap_shmap_merge.value: + raise SkipTest("jit(shmap) does not raise error.") f = pmap(lambda x: lax.psum(x, 'i'), axis_name='i', devices=jax.devices()) @@ -2505,6 +2636,8 @@ def testBadAxisSizeError(self): def testBadAxisSizeErrorNested(self): if config.disable_jit.value: raise SkipTest("error doesn't apply when jit is disabled") + if config.pmap_shmap_merge.value: + raise SkipTest("Ignore nested pmap when `pmap_shmap_merge=True`.") f = pmap(pmap(lambda x: lax.psum(x, ('i', 'j')), axis_name='j'), axis_name='i', @@ -2520,6 +2653,8 @@ def testNestedPmaps(self): raise SkipTest if config.disable_jit.value: raise SkipTest("disable_jit requires num devices to equal axis size") + if config.pmap_shmap_merge.value: + raise SkipTest("Ignore nested pmap when `pmap_shmap_merge=True`.") # Devices specified in outer pmap are OK @partial(pmap, axis_name='i', devices=jax.devices()) @@ -2539,6 +2674,8 @@ def testNestedPmapsBools(self): raise SkipTest if config.disable_jit.value: raise SkipTest("disable_jit requires num devices to equal axis size") + if config.pmap_shmap_merge.value: + raise SkipTest("Ignore nested pmap when `pmap_shmap_merge=True`.") # Devices specified in outer pmap are OK @partial(pmap, axis_name='i', devices=jax.devices()) @@ -2554,6 +2691,8 @@ def bar(y): self.assertAllClose(ans, expected) def testNestedPmapsError(self): + if config.pmap_shmap_merge.value: + raise SkipTest('Ignore nested pmap when `pmap_shmap_merge=True`.') # Devices specified in inner pmap not OK @partial(pmap, axis_name='i') def foo(x): @@ -2709,6 +2848,10 @@ def h(y): class ArrayTest(jtu.JaxTestCase): def testThreadsafeIndexing(self): + if config.pmap_shmap_merge.value: + # NOTE(dsuo): This passes when not CPU, but fails for all platforms under + # disable_jit=True. + raise SkipTest('TODO(dsuo): See https://github.com/jax-ml/jax/issues/31911') # NOTE(skye): I picked these values to be big enough to cause interesting # execution overlap, but small enough to not use too much memory. YMMV. shape = (8, 4000, 1000) @@ -2750,13 +2893,17 @@ def testNoCopyIndexing1D(self): self.assertIsInstance(sharded_x[i], array.ArrayImpl) self.assertIsNone(sharded_x._npy_value) + @jtu.ignore_warning(category=DeprecationWarning) def test_device_put_sharded(self): devices = jax.local_devices() n_devices = len(devices) x = [np.arange(i, i + 4) for i in range(n_devices)] y = jax.device_put_sharded(x, devices) self.assertIsInstance(y, array.ArrayImpl) - self.assertIsInstance(y.sharding, jax.sharding.PmapSharding) + if config.pmap_shmap_merge.value: + self.assertIsInstance(y.sharding, jax.NamedSharding) + else: + self.assertIsInstance(y.sharding, jax.sharding.PmapSharding) for s in y.addressable_shards: self.assertArraysEqual(s.data, y[s.index]) self.assertEqual(s.replica_id, 0) @@ -2765,6 +2912,7 @@ def test_device_put_sharded(self): self.assertTrue(all(b.devices() == {d} for b, d in zip(buffers, devices))) self.assertArraysEqual(y, jnp.stack(x)) + @jtu.ignore_warning(category=DeprecationWarning) def test_device_put_sharded_pytree(self): devices = jax.local_devices() n_devices = len(devices) @@ -2781,6 +2929,7 @@ def test_device_put_sharded_pytree(self): y2_buffers = getattr(y2, '_arrays') self.assertTrue(all(b.devices() == {d} for b, d in zip(y2_buffers, devices))) + @jtu.ignore_warning(category=DeprecationWarning) def test_device_put_replicated(self): devices = jax.local_devices() x = np.arange(1, 5) @@ -2792,6 +2941,7 @@ def test_device_put_replicated(self): self.assertTrue(all(b.devices() == {d} for b, d in zip(buffers, devices))) self.assertArraysEqual(y, np.stack([x for _ in devices])) + @jtu.ignore_warning(category=DeprecationWarning) def test_device_put_replicated_pytree(self): devices = jax.local_devices() xs = {'a': np.arange(1, 5), 'b': np.arange(3)} @@ -2811,10 +2961,12 @@ def test_device_put_replicated_pytree(self): self.assertTrue(all(b.devices() == {d} for b, d in zip(y2_buffers, devices))) self.assertArraysEqual(y2, np.stack([xs['b'] for _ in devices])) + @jtu.ignore_warning(category=DeprecationWarning) def test_repr(self): x = jax.device_put_replicated(1, jax.devices()) self.assertStartsWith(repr(x), 'Array') + @jtu.ignore_warning(category=DeprecationWarning) def test_delete_is_idempotent(self): x = jax.device_put_replicated(1, jax.devices()) x.delete() @@ -2990,6 +3142,7 @@ def device_array(x): [(), pxla.ShardingSpec(sharding=(), mesh_mapping=(pxla.Replicated(2), pxla.Replicated(3)))], ]) + @jtu.ignore_warning(category=DeprecationWarning) def testShardArgs(self, shape, spec, make_arg): indices = sharding_specs.spec_to_indices(shape, spec) nshards = len(indices) @@ -2998,7 +3151,8 @@ def testShardArgs(self, shape, spec, make_arg): x = np.arange(math.prod(shape)).reshape(shape) arg = make_arg(x) sharding = jax.sharding.PmapSharding(jax.devices()[:nshards], spec) - results = pxla.shard_args([sharding], [None], [None], [arg]) + results = pxla.shard_args([sharding], [None], + [xc.ArrayCopySemantics.REUSE_INPUT], [arg]) self.assertEqual(len(results), 1) if isinstance(results[0], array.ArrayImpl): bufs = results[0]._arrays @@ -3012,6 +3166,7 @@ def testShardArgs(self, shape, spec, make_arg): @jtu.pytest_mark_if_available('multiaccelerator') class ArrayPmapTest(jtu.JaxTestCase): + @jtu.ignore_warning(category=DeprecationWarning) def test_pmap_input_array_output_array(self): input_shape = (jax.device_count(), 2) input_array, input_data = create_input_array_for_pmap(input_shape) @@ -3026,6 +3181,7 @@ def test_pmap_input_array_output_array(self): self.assertArraysEqual(s.data, expected[s.index]) self.assertArraysEqual(out, expected) + @jtu.ignore_warning(category=DeprecationWarning) def test_pmap_double_input_array_output_array(self): input_shape = (jax.device_count(), 2) input_array, input_data = create_input_array_for_pmap(input_shape) @@ -3046,6 +3202,7 @@ def f(x, y): self.assertArraysEqual(out1, input_data) self.assertArraysEqual(out2, input_data) + @jtu.ignore_warning(category=DeprecationWarning) def test_pmap_array_in_axes_out_axes(self): dc = jax.device_count() input_shape = (dc, 2) @@ -3067,11 +3224,9 @@ def f(x, y): self.assertEqual(out2.shape, (dc, dc, 2)) for i, (s1, s2) in enumerate(safe_zip(out1.addressable_shards, out2.addressable_shards)): self.assertArraysEqual(s1.data, input_data[i]) - if config.pmap_no_rank_reduction.value: - self.assertArraysEqual(s2.data, input_data[None]) - else: - self.assertArraysEqual(s2.data, input_data) + self.assertArraysEqual(s2.data, input_data[None]) + @jtu.ignore_warning(category=DeprecationWarning) def test_pmap_array_sharding_mismatch(self): input_shape = (jax.device_count(), 2) a1, inp_data = create_input_array_for_pmap(input_shape, in_axes=None, @@ -3082,6 +3237,7 @@ def test_pmap_array_sharding_mismatch(self): self.assertArraysEqual(out_array, inp_data) + @jtu.ignore_warning(category=DeprecationWarning) def test_pmap_array_devices_mismatch(self): if jax.device_count() <= 1: raise unittest.SkipTest('Skipping because this test needs more than ' @@ -3094,6 +3250,7 @@ def test_pmap_array_devices_mismatch(self): self.assertArraysEqual(out_array, inp_data) + @jtu.ignore_warning(category=DeprecationWarning) def test_amap(self): # Copied from an example mattjj@ posted in a chat thread. @@ -3122,6 +3279,7 @@ def dynamic_shape_function(y): self.assertArraysEqual(w, jnp.cos(jnp.sin(x) ** 2)) + @jtu.ignore_warning(category=DeprecationWarning) def test_same_out_sharding_id(self): if config.disable_jit.value: self.skipTest('Skip this under eager pmap mode.') @@ -3145,11 +3303,11 @@ def test_same_out_sharding_id(self): self.assertEqual(out1_sharding_id, out3_sharding_id) self.assertEqual(out2_sharding_id, out3_sharding_id) + @jtu.ignore_warning(category=DeprecationWarning) def test_array_with_pmap_sharding_copy_without_round_trip(self): def _compare_if_equal(out, out_copy): self.assertArraysEqual(out, out_copy) - self.assertIsInstance(out_copy.sharding, jax.sharding.PmapSharding) self.assertEqual(out.sharding, out_copy.sharding) for o, o_copy in safe_zip(out.addressable_shards, out_copy.addressable_shards): self.assertArraysEqual(o.data, o_copy.data) @@ -3159,14 +3317,27 @@ def _compare_if_equal(out, out_copy): self.assertNotEqual(o.data.unsafe_buffer_pointer(), o_copy.data.unsafe_buffer_pointer()) - out, _ = create_input_array_for_pmap((jax.device_count(),)) + if config.pmap_shmap_merge.value: + sharding = jax.sharding.NamedSharding( + jax.sharding.Mesh(np.array(jax.devices()), 'x'), + jax.sharding.PartitionSpec('x')) + out = jax.device_put(jnp.ones((jax.device_count(),)), sharding) + else: + out, _ = create_input_array_for_pmap((jax.device_count(),)) out_copy = jnp.copy(out) _compare_if_equal(out, out_copy) - out1, _ = create_input_array_for_pmap((1, jax.device_count(),), in_axes=1) + if config.pmap_shmap_merge.value: + sharding = jax.sharding.NamedSharding( + jax.sharding.Mesh(np.array(jax.devices()).reshape(1, -1), ('x', 'y')), + jax.sharding.PartitionSpec('x', 'y')) + out1 = jax.device_put(jnp.ones((1, jax.device_count())), sharding) + else: + out1, _ = create_input_array_for_pmap((1, jax.device_count(),), in_axes=1) out_copy1 = jnp.copy(out1) _compare_if_equal(out1, out_copy1) + @jtu.ignore_warning(category=DeprecationWarning) def test_device_put_sharded_transfer_guard(self): inp = jnp.arange(jax.device_count()) arr_inp = [jax.device_put(i, d) for i, d in zip(inp, jax.devices())] @@ -3188,17 +3359,20 @@ class EagerPmapMixin: def setUp(self): super().setUp() - stack = contextlib.ExitStack() - stack.enter_context(jtu.thread_local_config_context(jax_disable_jit=True, jax_eager_pmap=True)) - stack.enter_context(jtu.ignore_warning( + if config.pmap_shmap_merge.value: + # NOTE(dsuo): Most test do pass `pmap_shmap_merge=True` and + # `disable_jit=True` but are they still meaningful? They can also be much + # slower. + raise SkipTest('Not testing disable_jit when `pmap_shmap_merge=True`.') + self.enter_context(jtu.thread_local_config_context(jax_disable_jit=True)) + self.enter_context(jtu.ignore_warning( message="Some donated buffers were not usable", category=UserWarning)) - self.addCleanup(stack.close) + @jtu.pytest_mark_if_available('multiaccelerator') class PythonPmapEagerTest(EagerPmapMixin, PythonPmapTest): def test_custom_jvp(self): - @jax.custom_jvp def foo(x): return jnp.exp(x) @@ -3214,7 +3388,6 @@ def foo_jvp(xs, ts): self.assertAllClose(self.pmap(f)(x, x), jax.vmap(f)(x, x)) def test_custom_vjp(self): - @jax.custom_vjp def foo(x): return jnp.exp(x) diff --git a/tests/polynomial_test.py b/tests/polynomial_test.py index 3eeaec482719..e68a4b14115d 100644 --- a/tests/polynomial_test.py +++ b/tests/polynomial_test.py @@ -71,8 +71,8 @@ def assertSetsAllClose(self, x, y, rtol=None, atol=None, check_dtypes=True): leading=[0, 2], trailing=[0, 2], ) - # TODO(phawkins): no nonsymmetric eigendecomposition implementation on GPU. - @jtu.run_on_devices("cpu") + # TODO(phawkins): no nonsymmetric eigendecomposition implementation on TPU. + @jtu.skip_on_devices("tpu") def testRoots(self, dtype, length, leading, trailing): rng = jtu.rand_some_zero(self.rng()) @@ -97,8 +97,8 @@ def np_fun(arg): leading=[0, 2], trailing=[0, 2], ) - # TODO(phawkins): no nonsymmetric eigendecomposition implementation on GPU. - @jtu.run_on_devices("cpu") + # TODO(phawkins): no nonsymmetric eigendecomposition implementation on TPU. + @jtu.skip_on_devices("tpu") def testRootsNoStrip(self, dtype, length, leading, trailing): rng = jtu.rand_some_zero(self.rng()) diff --git a/tests/pretty_printer_test.py b/tests/pretty_printer_test.py index d87708c9d91c..b4363be1c965 100644 --- a/tests/pretty_printer_test.py +++ b/tests/pretty_printer_test.py @@ -13,24 +13,90 @@ # limitations under the License. from absl.testing import absltest - -from jax._src import test_util as jtu from jax._src import pretty_printer as pp +from jax._src import test_util as jtu class PrettyPrinterTest(jtu.JaxTestCase): def testSourceMap(self): doc = pp.concat([ - pp.text("abc"), pp.source_map(pp.text("def"), 101), - pp.source_map(pp.concat([pp.text("gh"), pp.brk(""), pp.text("ijkl")]), 77), - pp.text("mn"), + pp.text("abc"), + pp.source_map(pp.text("def"), 101), + pp.source_map( + pp.concat([pp.text("gh"), pp.brk(""), pp.text("ijkl")]), 77 + ), + pp.text("mn"), ]) source_map = [] out = doc.format(width=8, source_map=source_map) self.assertEqual(out, "abcdefgh\nijklmn") self.assertEqual(source_map, [[(3, 6, 101), (6, 8, 77)], [(0, 4, 77)]]) + def testBasics(self): + self.assertEqual(pp.nil().format(), "") + self.assertEqual(pp.text("").format(), "") + self.assertEqual(pp.text("testing").format(), "testing") + self.assertEqual(pp.text("\n").format(), "\n") + self.assertEqual(pp.brk().format(), "\n") + # Group that fits will use the space from brk() + self.assertEqual(pp.group(pp.brk()).format(), " ") + # Group that doesn't fit (due to width=0) will use newline + self.assertEqual(pp.group(pp.brk()).format(width=0), "\n") + + # Custom break text + self.assertEqual(pp.group(pp.brk("-")).format(), "-") + self.assertEqual(pp.group(pp.brk("-")).format(width=0), "\n") + + # Concatenation + self.assertEqual((pp.text("a") + pp.text("b")).format(), "ab") + self.assertEqual(pp.concat([pp.text("a"), pp.text("b c")]).format(), "ab c") + + x = pp.text("x") + y = pp.text("y") + z = pp.text("z") + + # Join + # Join with a break that becomes a space when fitting + join_doc_space = pp.join( + pp.text(",") + pp.brk(), [pp.text("a"), pp.text("b"), pp.text("c")] + ) + self.assertEqual(pp.group(join_doc_space).format(), "a, b, c") + self.assertEqual(pp.group(join_doc_space).format(width=5), "a,\nb,\nc") + self.assertEqual(pp.join(pp.text(","), [x, y, z]).format(), "x,y,z") + + j = pp.join( + pp.brk(), [pp.text("xx"), pp.text("yy"), pp.text("zz"), pp.text("ww")] + ) + self.assertEqual(pp.group(j).format(width=3), "xx\nyy\nzz\nww") + self.assertEqual(pp.group(j).format(width=80), "xx yy zz ww") + + bx = pp.brk() + x + bxbx = bx + bx + bx4 = bxbx + bxbx + + # Horizontal-like (fits) + self.assertEqual(pp.group(bx).format(), " x") + self.assertEqual(pp.group(bxbx).format(), " x x") + self.assertEqual(pp.group(bx4).format(), " x x x x") + + # Vertical-like (forced by width) + self.assertEqual(pp.group(bx).format(width=0), "\nx") + self.assertEqual(pp.group(bxbx).format(width=0), "\nx\nx") + self.assertEqual(pp.group(bx4).format(width=0), "\nx\nx\nx\nx") + self.assertEqual(pp.group(bxbx).format(width=3), "\nx\nx") + + # Nesting + xbybz = x + pp.brk() + y + pp.brk() + z + self.assertEqual(pp.nest(2, pp.group(bx)).format(), " x") # Stays flat + self.assertEqual(pp.nest(2, pp.group(bxbx)).format(), " x x") # Stays flat + self.assertEqual(pp.nest(2, pp.group(bx)).format(width=0), "\n x") + self.assertEqual( + pp.nest(2, pp.nest(2, pp.group(bx))).format(width=0), "\n x" + ) + self.assertEqual(pp.nest(2, pp.group(xbybz)).format(width=0), "x\n y\n z") + self.assertEqual(pp.nest(2, pp.group(bxbx)).format(width=0), "\n x\n x") + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/profiler_session_test.py b/tests/profiler_session_test.py new file mode 100644 index 000000000000..695a1552a9fc --- /dev/null +++ b/tests/profiler_session_test.py @@ -0,0 +1,76 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pathlib + +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax._src import test_util as jtu +import jax.numpy as jnp + +_TEST_SESSION_ID = 'my_custom_session_123' + + +@jtu.thread_unsafe_test_class() +class ProfilerSessionTest(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + # Ensure that any running profiler is stopped before starting the test. + # This is in setUp rather than tearDown to defend against previous tests + # that may have crashed or failed to clean up properly. + try: + jax.profiler.stop_trace() + except RuntimeError: + pass + + @parameterized.named_parameters( + dict(testcase_name='without_session_id', session_id=None), + dict(testcase_name='with_empty_session_id', session_id=''), + dict(testcase_name='with_custom_session_id', session_id=_TEST_SESSION_ID), + ) + def test_programmatic_profiling(self, session_id: str | None): + tmpdir = pathlib.Path(self.create_tempdir()) + + options = jax.profiler.ProfileOptions() + if session_id is not None: + options.session_id = session_id + + with jax.profiler.trace(tmpdir, profiler_options=options): + jax.pmap(lambda x: jax.lax.psum(x + 1, 'i'), axis_name='i')( + jnp.ones(jax.local_device_count()) + ).block_until_ready() + + profile_plugin_dir = tmpdir / 'plugins' / 'profile' + self.assertTrue(profile_plugin_dir.exists(), f'Not found at {profile_plugin_dir}') + + subdirs = [x.name for x in profile_plugin_dir.iterdir() if x.is_dir()] + self.assertLen(subdirs, 1) + + if session_id is None or not session_id: + self.assertNotIn(_TEST_SESSION_ID, subdirs) + self.assertNotIn('', subdirs) + target_dir = subdirs[0] + else: + self.assertIn(session_id, subdirs) + target_dir = session_id + + session_dir = profile_plugin_dir / target_dir + pb_files = list(session_dir.glob('*.xplane.pb')) + self.assertNotEmpty(pb_files, f'No .xplane.pb files found in {session_dir}') + + +if __name__ == '__main__': + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/profiler_test.py b/tests/profiler_test.py index 215e363e446d..a803e010859f 100644 --- a/tests/profiler_test.py +++ b/tests/profiler_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import concurrent.futures from functools import partial import glob import os @@ -21,6 +22,7 @@ import threading import time import unittest +import unittest.mock from absl.testing import absltest import pathlib @@ -28,6 +30,7 @@ import jax.numpy as jnp import jax.profiler import jax._src.test_util as jtu + from jax._src import profiler from jax import jit @@ -38,29 +41,30 @@ portpicker = None try: - from tensorflow.python.profiler import profiler_client - from tensorflow.python.profiler import profiler_v2 as tf_profiler -except ImportError: - profiler_client = None - tf_profiler = None - -TBP_ENABLED = False -try: - import tensorboard_plugin_profile - del tensorboard_plugin_profile - TBP_ENABLED = True + from xprof.convert import _pywrap_profiler_plugin + import jax.collect_profile except ImportError: - pass + _pywrap_profiler_plugin = None jax.config.parse_flags_with_absl() -@jtu.thread_unsafe_test_class() # profiler isn't thread-safe +# We do not allow multiple concurrent profiler sessions. +@jtu.thread_unsafe_test_class() class ProfilerTest(unittest.TestCase): # These tests simply test that the profiler API does not crash; they do not # check functional correctness. def setUp(self): + if ( + sys.version_info < (3, 14) + and hasattr(sys, "_is_gil_enabled") + and not sys._is_gil_enabled() + ): + self.skipTest( + "Profiler tests are not thread-safe under Python 3.13 free threading" + ) + super().setUp() self.worker_start = threading.Event() self.profile_done = False @@ -107,6 +111,57 @@ def testProgrammaticProfiling(self): self.assertIn(b"/device:TPU", proto) self.assertIn(b"pxla.py", proto) + def testProgrammaticProfilingConcurrency(self): + def work(): + x = jax.pmap(lambda x: jax.lax.psum(x + 1, 'i'), axis_name='i')( + jnp.ones(jax.local_device_count())) + jax.block_until_ready(x) + with tempfile.TemporaryDirectory() as tmpdir: + try: + jax.profiler.start_trace(tmpdir) + with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor: + for _ in range(10): + executor.submit(work) + finally: + jax.profiler.stop_trace() + + proto_path = glob.glob(os.path.join(tmpdir, "**/*.xplane.pb"), + recursive=True) + self.assertEqual(len(proto_path), 1) + with open(proto_path[0], "rb") as f: + proto = f.read() + # Sanity check that serialized proto contains host, device, and + # Python traces without deserializing. + self.assertIn(b"/host:CPU", proto) + if jtu.test_device_matches(["tpu"]): + self.assertIn(b"/device:TPU", proto) + self.assertIn(b"pxla.py", proto) + + def testProgrammaticProfilingWithOptions(self): + with tempfile.TemporaryDirectory() as tmpdir: + try: + options = jax.profiler.ProfileOptions() + options.python_tracer_level = 0 + jax.profiler.start_trace(tmpdir, profiler_options=options) + jax.pmap(lambda x: jax.lax.psum(x + 1, "i"), axis_name="i")( + jnp.ones(jax.local_device_count()) + ) + finally: + jax.profiler.stop_trace() + + proto_path = glob.glob( + os.path.join(tmpdir, "**/*.xplane.pb"), recursive=True + ) + self.assertEqual(len(proto_path), 1) + with open(proto_path[0], "rb") as f: + proto = f.read() + # Verify that the serialized proto contains host and device traces, and + # does not contain Python traces. + self.assertIn(b"/host:CPU", proto) + if jtu.test_device_matches(["tpu"]): + self.assertIn(b"/device:TPU", proto) + self.assertNotIn(b"pxla.py", proto) + def testProgrammaticProfilingPathlib(self): with tempfile.TemporaryDirectory() as tmpdir_string: tmpdir = pathlib.Path(tmpdir_string) @@ -127,6 +182,29 @@ def testProgrammaticProfilingPathlib(self): self.assertIn(b"/device:TPU", proto) self.assertIn(b"pxla.py", proto) + def testProgrammaticProfilingWithOptionsPathlib(self): + with tempfile.TemporaryDirectory() as tmpdir_string: + tmpdir = pathlib.Path(tmpdir_string) + try: + options = jax.profiler.ProfileOptions() + options.advanced_configuration = {"tpu_trace_mode": "TRACE_ONLY_HOST"} + jax.profiler.start_trace(tmpdir, profiler_options=options) + jax.pmap(lambda x: jax.lax.psum(x + 1, "i"), axis_name="i")( + jnp.ones(jax.local_device_count()) + ) + finally: + jax.profiler.stop_trace() + + proto_path = tuple(tmpdir.rglob("*.xplane.pb")) + self.assertEqual(len(proto_path), 1) + proto = proto_path[0].read_bytes() + # Verify that the serialized proto contains host traces and does not + # contain TPU device traces. + self.assertIn(b"/host:CPU", proto) + if jtu.test_device_matches(["tpu"]): + self.assertNotIn(b"/device:TPU", proto) + self.assertIn(b"pxla.py", proto) + def testProfilerGetFDOProfile(self): # Tests stop_and_get_fod_profile could run. try: @@ -176,7 +254,7 @@ def testProgrammaticProfilingContextManager(self): def testProgrammaticGpuCuptiTracing(self): @jit def xy_plus_z(x, y, z): - return jnp.float32(jax.lax.batch_matmul(jnp.bfloat16(x), y)) + z + return jnp.float32(jax.lax.batch_matmul(jnp.bfloat16(x), y)) + z k = jax.random.key(0) s = 1, 16, 16 jax.devices() @@ -190,8 +268,71 @@ def xy_plus_z(x, y, z): proto_path = tuple(tmpdir.rglob("*.xplane.pb")) proto_bytes = proto_path[0].read_bytes() - if jtu.test_device_matches(["gpu"]): - self.assertIn(b"/device:GPU", proto_bytes) + self.assertIn(b"/device:GPU", proto_bytes) + + @jtu.run_on_devices("gpu") + @jtu.thread_unsafe_test() + def testProgrammaticGpuCuptiTracingWithOptions(self): + @jit + def xy_plus_z(x, y, z): + return jnp.float32(jax.lax.batch_matmul(jnp.bfloat16(x), y)) + z + + k = jax.random.key(0) + s = 1, 16, 16 + jax.devices() + x = jnp.int8(jax.random.normal(k, shape=s)) + y = jnp.bfloat16(jax.random.normal(k, shape=s)) + z = jnp.float32(jax.random.normal(k, shape=s)) + with tempfile.TemporaryDirectory() as tmpdir_string: + tmpdir = pathlib.Path(tmpdir_string) + options = jax.profiler.ProfileOptions() + options.advanced_configuration = { + "gpu_max_callback_api_events": 1000000, + "gpu_enable_nvtx_tracking": True, + } + with jax.profiler.trace(tmpdir): + xy_plus_z(x, y, z).block_until_ready() + + proto_path = tuple(tmpdir.rglob("*.xplane.pb")) + proto_bytes = proto_path[0].read_bytes() + self.assertIn(b"/device:GPU", proto_bytes) + + # TODO: b/443121646 - Enable PM sampling test on JAX OSS once the Github CI + # host machine has privileged access. + # @jtu.run_on_devices("gpu") + # @jtu.thread_unsafe_test() + # def testProgrammaticGpuCuptiTracingWithPmSampling(self): + # if not (jtu.is_cuda_compute_capability_equal("9.0")): + # self.skipTest("Only works on GPU with capability sm90") + + # @jit + # def xy_plus_z(x, y, z): + # return jnp.float32(jax.lax.batch_matmul(jnp.bfloat16(x), y)) + z + + # k = jax.random.key(0) + # s = 1, 16, 16 + # jax.devices() + # x = jnp.int8(jax.random.normal(k, shape=s)) + # y = jnp.bfloat16(jax.random.normal(k, shape=s)) + # z = jnp.float32(jax.random.normal(k, shape=s)) + # with tempfile.TemporaryDirectory() as tmpdir_string: + # tmpdir = pathlib.Path(tmpdir_string) + # options = jax.profiler.ProfileOptions() + # options.advanced_configuration = { + # "gpu_pm_sample_counters": ( + # "sm__cycles_active.sum" + # ), + # "gpu_pm_sample_interval_us": 500, + # } + # with jax.profiler.trace(tmpdir, profiler_options=options): + # xy_plus_z(x, y, z).block_until_ready() + + # proto_path = tuple(tmpdir.rglob("*.xplane.pb")) + # proto_bytes = proto_path[0].read_bytes() + # self.assertIn(b"/device:GPU", proto_bytes) + # self.assertIn( + # b"sm__cycles_active.sum", proto_bytes + # ) def testProgrammaticProfilingContextManagerPathlib(self): with tempfile.TemporaryDirectory() as tmpdir_string: @@ -246,8 +387,8 @@ def _check_xspace_pb_exist(self, logdir): 'Expected one path match: ' + path) @unittest.skip("Test causes OOMs") - @unittest.skipIf(not (portpicker and profiler_client and tf_profiler), - "Test requires tensorflow.profiler and portpicker") + @unittest.skipIf(not (portpicker and _pywrap_profiler_plugin), + "Test requires xprof and portpicker") def testSingleWorkerSamplingMode(self, delay_ms=None): def on_worker(port, worker_start): jax.profiler.start_server(port) @@ -262,17 +403,24 @@ def on_worker(port, worker_start): def on_profile(port, logdir, worker_start): worker_start.wait() - options = tf_profiler.ProfilerOptions( - host_tracer_level=2, - python_tracer_level=2, - device_tracer_level=1, - delay_ms=delay_ms, - ) + options = { + "host_tracer_level": 2, + "python_tracer_level": 2, + "device_tracer_level": 1, + "delay_ms": delay_ms, + } # Request for 1000 milliseconds of profile. duration_ms = 1000 - profiler_client.trace(f'localhost:{port}', logdir, duration_ms, - '', 1000, options) + _pywrap_profiler_plugin.trace( + f'localhost:{port}', + logdir, + '', + True, + duration_ms, + 3, + options + ) self.profile_done = True logdir = absltest.get_default_test_tmpdir() @@ -290,9 +438,8 @@ def on_profile(port, logdir, worker_start): self._check_xspace_pb_exist(logdir) @unittest.skipIf( - not (portpicker and profiler_client and tf_profiler and TBP_ENABLED), - "Test requires tensorflow.profiler, portpicker and " - "tensorboard_profile_plugin") + not (portpicker and _pywrap_profiler_plugin), + "Test requires xprof and portpicker") def test_remote_profiler(self): port = portpicker.pick_unused_port() jax.profiler.start_server(port) @@ -322,5 +469,56 @@ def on_profile(): thread_profiler.join() self._check_xspace_pb_exist(logdir) + @unittest.skip("Profiler takes >30s on Cloud TPUs") + @unittest.skipIf( + not (portpicker and _pywrap_profiler_plugin), + "Test requires xprof and portpicker") + def test_remote_profiler_gcs_path(self): + port = portpicker.pick_unused_port() + jax.profiler.start_server(port) + + profile_done = threading.Event() + logdir = "gs://mock-test-bucket/test-dir" + # Mock XProf call in collect_profile. + _pywrap_profiler_plugin.trace = unittest.mock.MagicMock() + def on_profile(): + jax.collect_profile(port, 500, logdir, no_perfetto_link=True) + profile_done.set() + + thread_profiler = threading.Thread( + target=on_profile, args=()) + thread_profiler.start() + start_time = time.time() + y = jnp.zeros((5, 5)) + while not profile_done.is_set(): + # The timeout here must be relatively high. The profiler takes a while to + # start up on Cloud TPUs. + if time.time() - start_time > 30: + raise RuntimeError("Profile did not complete in 30s") + y = jnp.dot(y, y) + jax.profiler.stop_server() + thread_profiler.join() + _pywrap_profiler_plugin.trace.assert_called_once_with( + unittest.mock.ANY, + logdir, + unittest.mock.ANY, + unittest.mock.ANY, + unittest.mock.ANY, + unittest.mock.ANY, + unittest.mock.ANY, + ) + + def test_advanced_configuration_getter(self): + options = jax.profiler.ProfileOptions() + advanced_config = { + "tpu_trace_mode": "TRACE_COMPUTE", + "tpu_num_sparse_cores_to_trace": 1, + "enableFwThrottleEvent": True, + } + options.advanced_configuration = advanced_config + returned_config = options.advanced_configuration + self.assertDictEqual(returned_config, advanced_config) + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index 05b4c8d7c0ff..3a70b08ea912 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -30,7 +30,7 @@ from jax._src import util from jax.experimental import io_callback from jax.experimental import pjit -from jax.experimental.shard_map import shard_map +from jax._src.shard_map import shard_map import jax.numpy as jnp from jax.sharding import Mesh import numpy as np @@ -585,6 +585,89 @@ def fun(x): self.assertAllClose(2 * x, fun(x)) self.assertEqual(count(), 1) + @parameterized.parameters("int2", "int4", "uint2", "uint4", "float4_e2m1fn") + def test_subbyte_operands(self, dtype: str): + if "2" in dtype and jtu.test_device_matches(["tpu"]): + self.skipTest( + "TODO(dsuo): TPU callbacks send SIGABRT for int2, uint2, and" + " float4_e2m1fn." + ) + def get(x): + return x + def f(x): + y = jax.pure_callback( + get, + jax.ShapeDtypeStruct((8,), dtype=dtype), + x, + ) + return y + x = np.arange(8, dtype=dtype) + np.testing.assert_array_equal(jax.jit(f)(x), np.arange(8, dtype=dtype)) + + def test_pure_callback_sequential_vmap_method_eval_jaxpr(self): + def f(x): + return jax.pure_callback( + lambda x: x, jax.ShapeDtypeStruct(shape=(), dtype=jnp.float32), + x, vmap_method="sequential") + + jaxpr = jax.make_jaxpr(lambda: jax.vmap(f)( + jnp.zeros(100, dtype=jnp.float32)))() + with jax.ensure_compile_time_eval(): + jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts) # doesn't crash + + @parameterized.parameters("int2", "int4", "uint2", "uint4", "float4_e2m1fn") + def test_subbyte_results(self, dtype: str): + if "2" in dtype and jtu.test_device_matches(["tpu"]): + self.skipTest( + "TODO(dsuo): TPU callbacks send SIGABRT for int2, uint2, and" + " float4_e2m1fn." + ) + def get(): + return np.arange(8, dtype=dtype) + + def f(): + y = jax.pure_callback( + get, + jax.ShapeDtypeStruct((8,), dtype) + ) + return y + + np.testing.assert_array_equal(jax.jit(f)(), np.arange(8, dtype=dtype)) + + @parameterized.parameters("int2", "int4", "uint2", "uint4", "float4_e2m1fn") + def test_non_default_stride_subbyte_results(self, dtype: str): + if "2" in dtype and jtu.test_device_matches(["tpu"]): + self.skipTest( + "TODO(dsuo): TPU callbacks send SIGABRT for int2, uint2, and" + " float4_e2m1fn." + ) + x = jnp.arange(24, dtype=dtype).reshape(2, 3, 4) + def callback(x): + return np.asfortranarray(x) + + @jax.jit + def f(x): + return jax.pure_callback( + callback, jax.ShapeDtypeStruct(x.shape, x.dtype), x + ) + + result = f(x) + np.testing.assert_array_equal(x, result) + + def test_non_default_stride(self): + x = jnp.arange(24, dtype=jnp.float32).reshape(2, 3, 4) + def callback(x): + return np.asfortranarray(x) + + @jax.jit + def f(x): + return jax.pure_callback( + callback, jax.ShapeDtypeStruct(x.shape, x.dtype), x + ) + + result = f(x) + np.testing.assert_array_equal(x, result) + class PureCallbackTest(jtu.JaxTestCase): @@ -787,7 +870,7 @@ def sin_jvp(xs, ts): def f(x): return sin(x) out = f(2.) - np.testing.assert_allclose(out, jnp.cos(2.)) + np.testing.assert_allclose(out, jnp.cos(2.), atol=1e-7) def test_callback_inside_of_cond(self): @@ -990,26 +1073,11 @@ def f(x): def test_vmap_method_raise(self): @jax.vmap def f(x): - # Setting vectorized to None disables the current default behavior of - # falling back on sequential. - return jax.pure_callback(np.sin, x, x, vectorized=None) + return jax.pure_callback(np.sin, x, x) with self.assertRaisesRegex(NotImplementedError, "vmap is only supported"): f(jnp.arange(4.)) - def test_deprecated_vectorized(self): - def f(x, **kwargs): - return jax.pure_callback(np.sin, x, x, **kwargs) - - with self.assertWarnsRegex(DeprecationWarning, "The default behavior"): - jax.vmap(f)(jnp.arange(4.0)) - - with self.assertWarnsRegex(DeprecationWarning, "The vectorized argument"): - f(jnp.arange(4.0), vectorized=True) - - with self.assertWarnsRegex(DeprecationWarning, "The vectorized argument"): - f(jnp.arange(4.0), vectorized=False) - def test_vmap_method_expand_dims(self): def callback(x, y): self.assertTupleEqual(x.shape, (4,)) @@ -1057,19 +1125,17 @@ def fun(x): result += fun(jnp.ones((500, 500), jnp.complex64))[1] jax.block_until_ready(result) # doesn't deadlock - def test_non_default_stride(self): - x = jnp.arange(24, dtype=jnp.float32).reshape(2, 3, 4) - def callback(x): - return np.asfortranarray(x) - + def test_pure_callback_fastpath(self): + # Regression test for https://github.com/jax-ml/jax/issues/31319 @jax.jit def f(x): - return jax.pure_callback( - callback, jax.ShapeDtypeStruct(x.shape, x.dtype), x - ) + return jax.pure_callback(lambda x: x, x, x) - result = f(x) - np.testing.assert_array_equal(x, result) + x = jax.numpy.arange(5.0) + with jtu.count_pjit_cpp_cache_miss() as count: + f(x) + f(x) + self.assertEqual(count(), 1) class IOCallbackTest(jtu.JaxTestCase): @@ -1117,6 +1183,8 @@ def f(x): self.assertEqual(_mut, 8) def test_cannot_call_ordered_io_in_pmap(self): + if config.pmap_shmap_merge.value: + self.skipTest("Test does not raise under pmap_shmap_merge=True") def f(x): return io_callback( lambda x: x, jax.ShapeDtypeStruct((), jnp.int32), x, ordered=True) @@ -1201,6 +1269,8 @@ def f(x, y): for ordered in [True, False] for with_sharding in [True, False] ) + @jtu.ignore_warning(message='.*Please use `jax.jit` instead.*', + category=DeprecationWarning) def test_can_use_io_callback_in_pjit( self, *, ordered: bool, with_sharding: bool ): @@ -1261,7 +1331,12 @@ def f(x): else: self.assertIn(f"{{maximal device={callback_device_index}}}", stablehlo_ir) + @jtu.ignore_warning(message='.*Please use `jax.jit` instead.*', + category=DeprecationWarning) def test_sequence_pjit_io_callback_ordered(self): + if jtu.is_device_tpu(7, 'x'): + self.skipTest('TODO(b/453664256): Failing on TPU 7x.') + # A sequence of pairs of calls to pjit(io_callback(ordered=True)) with each # pair on a different device assignment. _collected: list[int] = [] @@ -1313,11 +1388,18 @@ def f_base(i, x): jax.effects_barrier() self.assertEqual(_collected, expected) - def test_can_shard_io_callback_manually(self): - if config.use_shardy_partitioner.value: - self.skipTest("TODO(b/384938613): Failing under shardy.") + @parameterized.named_parameters( + dict(testcase_name='multi_device', + single_device=False), + dict(testcase_name='single_device', + single_device=True) + ) + def test_can_shard_io_callback_manually(self, single_device: bool): - mesh = Mesh(np.array(jax.devices()), axis_names=('x',)) + devices = jax.devices() + if single_device: + devices = devices[:1] + mesh = Mesh(np.array(devices), axis_names=('x',)) spec = jax.sharding.PartitionSpec('x') sharding = jax.sharding.NamedSharding(mesh, spec) diff --git a/tests/pytorch_interoperability_test.py b/tests/pytorch_interoperability_test.py index e41c4329b95b..98f3db1fe9ec 100644 --- a/tests/pytorch_interoperability_test.py +++ b/tests/pytorch_interoperability_test.py @@ -17,11 +17,11 @@ from absl.testing import absltest import jax -import jax.dlpack from jax._src import config from jax._src import test_util as jtu from jax._src import xla_bridge from jax._src.lib import xla_client +import jax.dlpack import jax.numpy as jnp config.parse_flags_with_absl() @@ -64,7 +64,7 @@ def testTorchToJaxFailure(self): r'striding are supported') with self.assertRaisesRegex(RuntimeError, regex_str): xla_client._xla.dlpack_managed_tensor_to_buffer( - y, client, client) + y, client.devices()[0], None) @jtu.sample_product(shape=all_shapes, dtype=torch_dtypes) def testJaxToTorch(self, shape, dtype): @@ -77,8 +77,7 @@ def testJaxToTorch(self, shape, dtype): rng = jtu.rand_default(self.rng()) np = rng(shape, dtype) x = jnp.array(np) - dlpack = jax.dlpack.to_dlpack(x) - y = torch.utils.dlpack.from_dlpack(dlpack) + y = torch.utils.dlpack.from_dlpack(x) if dtype == jnp.bfloat16: # .numpy() doesn't work on Torch bfloat16 tensors. self.assertAllClose(np, @@ -108,18 +107,19 @@ def testJaxArrayToTorch(self, shape, dtype): else: self.assertAllClose(np, y.cpu().numpy()) - @jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor", - category=DeprecationWarning) def testTorchToJaxInt64(self): # See https://github.com/jax-ml/jax/issues/11895 x = jax.dlpack.from_dlpack( - torch.utils.dlpack.to_dlpack(torch.ones((2, 3), dtype=torch.int64))) + torch.ones((2, 3), dtype=torch.int64)) dtype_expected = jnp.int64 if config.enable_x64.value else jnp.int32 self.assertEqual(x.dtype, dtype_expected) + def testTorchToJaxNondefaultLayout(self): + x = torch.arange(4).reshape(2, 2).T + x = x.cuda() if jtu.test_device_matches(["gpu"]) else x + self.assertAllClose(x.cpu().numpy(), jax.dlpack.from_dlpack(x)) + @jtu.sample_product(shape=all_shapes, dtype=torch_dtypes) - @jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor", - category=DeprecationWarning) def testTorchToJax(self, shape, dtype): if not config.enable_x64.value and dtype in [ jnp.int64, @@ -135,8 +135,7 @@ def testTorchToJax(self, shape, dtype): else: x = torch.tensor(x_np) x = x.cuda() if jtu.test_device_matches(["gpu"]) else x - x = x.contiguous() - y = jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(x)) + y = jax.dlpack.from_dlpack(x) self.assertAllClose(x_np, y) # Verify the resulting value can be passed to a jit computation. @@ -159,7 +158,6 @@ def testTorchToJaxArray(self, shape, dtype): else: x = torch.tensor(x_np) x = x.cuda() if jtu.test_device_matches(["gpu"]) else x - x = x.contiguous() y = jax.dlpack.from_dlpack(x) self.assertAllClose(x_np, y) diff --git a/tests/qdwh_test.py b/tests/qdwh_test.py index 91cc3a51f876..955e23374fee 100644 --- a/tests/qdwh_test.py +++ b/tests/qdwh_test.py @@ -18,7 +18,7 @@ import jax from jax._src import config from jax._src import test_util as jtu -from jax._src.lax import qdwh +from jax._src.tpu.linalg import qdwh import jax.numpy as jnp import numpy as np diff --git a/tests/ragged_collective_test.py b/tests/ragged_collective_test.py index 844892adc052..d3b35a3acf6f 100644 --- a/tests/ragged_collective_test.py +++ b/tests/ragged_collective_test.py @@ -19,14 +19,14 @@ from absl.testing import parameterized import jax -import jax.ad_checkpoint from jax import lax +from jax import vmap from jax.sharding import PartitionSpec as P from jax._src import config from jax._src import test_util as jtu import jax.numpy as jnp -from jax.experimental.shard_map import shard_map +from jax._src.shard_map import shard_map config.parse_flags_with_absl() @@ -89,7 +89,7 @@ def test_ragged_all_to_all(self, axis_name, mesh_axes): P(axis_name, None), ), out_specs=P(axis_name), - check_rep=False, + check_vma=False, ) def fwd( operand, output, input_offsets, send_sizes, output_offsets, recv_sizes @@ -175,7 +175,7 @@ def test_ragged_all_to_all_grad(self, axis_name, mesh_axes): P(axis_name, None), ), out_specs=P(axis_name), - check_rep=False, + check_vma=False, ) def fwd( operand, output, input_offsets, send_sizes, output_offsets, recv_sizes @@ -256,7 +256,7 @@ def test_ragged_all_to_all_axis_index_groups(self, axis_name, mesh_axes): P(axis_name, None), ), out_specs=P(axis_name), - check_rep=False, + check_vma=False, ) def fwd( operand, output, input_offsets, send_sizes, output_offsets, recv_sizes @@ -345,7 +345,7 @@ def test_ragged_all_to_all_degenerate_groups(self, axis_name, mesh_axes): P(axis_name, None), ), out_specs=P(axis_name), - check_rep=False, + check_vma=False, ) def fwd( operand, output, input_offsets, send_sizes, output_offsets, recv_sizes @@ -381,6 +381,261 @@ def fwd( c, jnp.array([[0, 0, 1, 0], [0, 2, 3, 4]], dtype=jnp.int32) ) + def test_ragged_all_to_all_vmap_multi_dim_operand(self): + device_type = jax.devices()[0].platform + if device_type == 'tpu' and jtu.get_tpu_version() < 4: + raise unittest.SkipTest( + 'UNSUPPORTED: HLO opcode `ragged-all-to-all` is not supported by TPU' + f' v{jtu.get_tpu_version()}' + ) + + axis_name = 'x' + mesh = jtu.create_mesh((2,), ('x',)) + data_sharding = P(axis_name, None, None) + operand_data = jnp.zeros((2, 2, 3), dtype=jnp.int32) + output_data = jnp.zeros((2, 2, 4), dtype=jnp.int32) + input_offsets_data = jnp.zeros((2, 2, 2), dtype=jnp.int32) + send_sizes_data = jnp.zeros((2, 2, 2), dtype=jnp.int32) + output_offsets_data = jnp.zeros((2, 2, 2), dtype=jnp.int32) + recv_sizes_data = jnp.zeros((2, 2, 2), dtype=jnp.int32) + + operand = jax.device_put(operand_data, jax.sharding.NamedSharding(mesh, data_sharding)) + output = jax.device_put(output_data, jax.sharding.NamedSharding(mesh, data_sharding)) + input_offsets = jax.device_put(input_offsets_data, jax.sharding.NamedSharding(mesh, data_sharding)) + send_sizes = jax.device_put(send_sizes_data, jax.sharding.NamedSharding(mesh, data_sharding)) + output_offsets = jax.device_put(output_offsets_data, jax.sharding.NamedSharding(mesh, data_sharding)) + recv_sizes = jax.device_put(recv_sizes_data, jax.sharding.NamedSharding(mesh, data_sharding)) + + @partial( + shard_map, + mesh=mesh, + in_specs=( + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + ), + out_specs=P(axis_name), + check_vma=False, + ) + def fwd( + operand, output, input_offsets, send_sizes, output_offsets, recv_sizes + ): + return lax.ragged_all_to_all( + operand=operand.reshape(operand.shape[1:]), + output=output.reshape(output.shape[1:]), + input_offsets=input_offsets.reshape(input_offsets.shape[1:]), + send_sizes=send_sizes.reshape(send_sizes.shape[1:]), + output_offsets=output_offsets.reshape(output_offsets.shape[1:]), + recv_sizes=recv_sizes.reshape(recv_sizes.shape[1:]), + axis_name=axis_name, + ) + + res = vmap(fwd, in_axes=0, out_axes=0 + )(operand, output, input_offsets, send_sizes, output_offsets, recv_sizes) + ref = jnp.stack(list(map(fwd, operand, output, input_offsets, send_sizes, output_offsets, recv_sizes))) + self.assertEqual(res.shape, ref.shape) + self.assertAllClose(res, ref, check_dtypes=False) + + @parameterized.named_parameters( + dict( + testcase_name='_batch_0_data_shard_axis_0_input_0', + axis_name='x', + vmap_axis_name='y', + mesh_axes=dict(x=2, y=2), + vmap_batch_axis=0, + data_shard_axis=0, + input_config=0, + ), + dict( + testcase_name='_batch_0_data_shard_axis_1_input_0', + axis_name='x', + vmap_axis_name='y', + mesh_axes=dict(x=2, y=2), + vmap_batch_axis=0, + data_shard_axis=1, + input_config=0, + ), + dict( + testcase_name='_batch_1_data_shard_axis_0_input_1', + axis_name='x', + vmap_axis_name='y', + mesh_axes=dict(x=2, y=2), + vmap_batch_axis=1, + data_shard_axis=0, + input_config=1, + ), + dict( + testcase_name='_batch_1_data_shard_axis_1_input_1', + axis_name='x', + vmap_axis_name='y', + mesh_axes=dict(x=2, y=2), + vmap_batch_axis=1, + data_shard_axis=1, + input_config=1, + ), + ) + def test_ragged_all_to_all_vmap( + self, + axis_name, + vmap_axis_name, + mesh_axes, + vmap_batch_axis, + data_shard_axis, + input_config, + ): + device_type = jax.devices()[0].platform + if device_type == 'tpu' and jtu.get_tpu_version() < 4: + raise unittest.SkipTest( + 'UNSUPPORTED: HLO opcode `ragged-all-to-all` is not supported by TPU' + f' v{jtu.get_tpu_version()}' + ) + mesh = jtu.create_mesh(tuple(mesh_axes.values()), tuple(mesh_axes.keys())) + + def get_data_sharding(axis): + if axis == 0: + return P(axis_name, None, None) + elif axis == 1: + return P(None, axis_name, None) + else: + raise ValueError("Invalid data_shard_axis") + + data_sharding = get_data_sharding(data_shard_axis) + + if input_config == 0: + operand_data = jnp.array([[[1, 2, 3], [4, 5, 6]], + [[1, 2, 3], [4, 5, 6]]], dtype=jnp.int32) + send_sizes_data = jnp.array([[[1, 2], [1, 1]], + [[1, 2], [1, 1]]], dtype=jnp.int32) + output_offsets_data = jnp.array([[[0, 0], [1, 2]], + [[0, 0], [1, 2]]], dtype=jnp.int32) + recv_sizes_data = jnp.array([[[1, 1], [2, 1]], + [[1, 1], [2, 1]]], dtype=jnp.int32) + elif input_config == 1: + operand_data = jnp.array([[[1, 2, 3], [1, 2, 3]], + [[4, 5, 6], [4, 5, 6]]], dtype=jnp.int32) + send_sizes_data = jnp.array([[[1, 2], [1, 2]], + [[1, 1], [1, 1]]], dtype=jnp.int32) + output_offsets_data = jnp.array([[[0, 0], [0, 0]], + [[1, 2], [1, 2]]], dtype=jnp.int32) + recv_sizes_data = jnp.array([[[1, 1], [1, 1]], + [[2, 1], [2, 1]]], dtype=jnp.int32) + else: + raise ValueError("Invalid input config") + + output_data = jnp.zeros((2, 2, 4), dtype=jnp.int32) + input_offsets_data = jnp.array([[[0, 1], [0, 1]], + [[0, 1], [0, 1]]], dtype=jnp.int32) + + operand = jax.device_put(operand_data, jax.sharding.NamedSharding(mesh, data_sharding)) + output = jax.device_put(output_data, jax.sharding.NamedSharding(mesh, data_sharding)) + input_offsets = jax.device_put(input_offsets_data, jax.sharding.NamedSharding(mesh, data_sharding)) + send_sizes = jax.device_put(send_sizes_data, jax.sharding.NamedSharding(mesh, data_sharding)) + output_offsets = jax.device_put(output_offsets_data, jax.sharding.NamedSharding(mesh, data_sharding)) + recv_sizes = jax.device_put(recv_sizes_data, jax.sharding.NamedSharding(mesh, data_sharding)) + + @partial( + shard_map, + mesh=mesh, + in_specs=( + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + ), + out_specs=P(axis_name), + check_vma=False, + ) + def fwd( + operand, output, input_offsets, send_sizes, output_offsets, recv_sizes + ): + return lax.ragged_all_to_all( + operand=operand.reshape(operand.shape[1:]), + output=output.reshape(output.shape[1:]), + input_offsets=input_offsets.reshape(input_offsets.shape[1:]), + send_sizes=send_sizes.reshape(send_sizes.shape[1:]), + output_offsets=output_offsets.reshape(output_offsets.shape[1:]), + recv_sizes=recv_sizes.reshape(recv_sizes.shape[1:]), + axis_name=axis_name, + ) + + res = vmap( + fwd, in_axes=vmap_batch_axis, out_axes=0, + )( + operand, output, input_offsets, send_sizes, output_offsets, recv_sizes + ) + + expected_res = [] + vmap_size = output_data.shape[vmap_batch_axis] + args = operand, output, input_offsets, send_sizes, output_offsets, recv_sizes + for i in range(vmap_size): + args_ = [jax.lax.index_in_dim(x, i, vmap_batch_axis, False) for x in args] + expected_res.append(fwd(*args_)) + expected_res = jnp.stack(expected_res) + self.assertAllClose(res, expected_res) + + def test_ragged_all_to_all_vmap_unsupported_axis_index_groups(self): + device_type = jax.devices()[0].platform + if device_type == 'tpu' and jtu.get_tpu_version() < 4: + raise unittest.SkipTest( + 'UNSUPPORTED: HLO opcode `ragged-all-to-all` is not supported by TPU' + f' v{jtu.get_tpu_version()}' + ) + + axis_name = 'x' + mesh_axes = dict(x=2) + mesh = jtu.create_mesh(tuple(mesh_axes.values()), tuple(mesh_axes.keys())) + data_sharding = P(axis_name, None, None) + operand_data = jnp.zeros((2, 2, 3), dtype=jnp.int32) + output_data = jnp.zeros((2, 2, 4), dtype=jnp.int32) + input_offsets_data = jnp.zeros((2, 2, 2), dtype=jnp.int32) + send_sizes_data = jnp.zeros((2, 2, 2), dtype=jnp.int32) + output_offsets_data = jnp.zeros((2, 2, 2), dtype=jnp.int32) + recv_sizes_data = jnp.zeros((2, 2, 2), dtype=jnp.int32) + + operand = jax.device_put(operand_data, jax.sharding.NamedSharding(mesh, data_sharding)) + output = jax.device_put(output_data, jax.sharding.NamedSharding(mesh, data_sharding)) + input_offsets = jax.device_put(input_offsets_data, jax.sharding.NamedSharding(mesh, data_sharding)) + send_sizes = jax.device_put(send_sizes_data, jax.sharding.NamedSharding(mesh, data_sharding)) + output_offsets = jax.device_put(output_offsets_data, jax.sharding.NamedSharding(mesh, data_sharding)) + recv_sizes = jax.device_put(recv_sizes_data, jax.sharding.NamedSharding(mesh, data_sharding)) + + @partial( + shard_map, + mesh=mesh, + in_specs=( + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + ), + out_specs=P(axis_name), + check_vma=False, + ) + def fwd( + operand, output, input_offsets, send_sizes, output_offsets, recv_sizes + ): + return lax.ragged_all_to_all( + operand=operand.reshape(operand.shape[1:]), + output=output.reshape(output.shape[1:]), + input_offsets=input_offsets.reshape(input_offsets.shape[1:]), + send_sizes=send_sizes.reshape(send_sizes.shape[1:]), + output_offsets=output_offsets.reshape(output_offsets.shape[1:]), + recv_sizes=recv_sizes.reshape(recv_sizes.shape[1:]), + axis_name=axis_name, + axis_index_groups=[[0, 1]], + ) + + with self.assertRaisesWithLiteralMatch( + NotImplementedError, 'Please open a feature request!'): + vmap(fwd, in_axes=0, out_axes=0, axis_name='b')(operand, output, input_offsets, send_sizes, output_offsets, recv_sizes) + def test_ragged_all_to_all_errors(self): operand = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=jnp.float32) output = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=jnp.float32) @@ -494,6 +749,111 @@ def test_ragged_all_to_all_errors(self): operand, output, input_offsets, send_sizes, output_offsets, jnp.array([], dtype=jnp.int32), axis_name=axis_name) + def test_vmap_basic(self): + device_type = jax.devices()[0].platform + if device_type == 'tpu' and jtu.get_tpu_version() < 4: + raise unittest.SkipTest( + 'UNSUPPORTED: HLO opcode `ragged-all-to-all` is not supported by TPU' + f' v{jtu.get_tpu_version()}' + ) + num_devices = len(jax.devices()) # we expect either 4 or 8 total devices + if num_devices not in (4, 8): + raise unittest.SkipTest("test requires 4 or 8 devices") + + expert_parallelism = 2 + pipeline_parallelism = num_devices // expert_parallelism # We expect this is either 2 or 4 + batch = 2 * expert_parallelism**2 + model = 3 + axis_name = "expert" + + # Define a mesh with PP + EP + mesh = jtu.create_mesh((expert_parallelism, pipeline_parallelism), ('expert', 'pipeline')) + x_partition_spec = jax.sharding.PartitionSpec("expert", None) + x_sharding = jax.sharding.NamedSharding(mesh, x_partition_spec) + + @partial( + shard_map, + mesh=mesh, + in_specs=( + x_partition_spec, + x_partition_spec, + x_partition_spec, + x_partition_spec, + x_partition_spec, + x_partition_spec, + ), + out_specs=(x_partition_spec), + check_vma=False, + ) + def ra2a_wrapper(x, output_shape, input_offsets, send_sizes, output_offsets, recv_sizes): + input_offsets = input_offsets.reshape(input_offsets.shape[1:]) + send_sizes = send_sizes.reshape(send_sizes.shape[1:]) + output_offsets = output_offsets.reshape(output_offsets.shape[1:]) + recv_sizes = recv_sizes.reshape(recv_sizes.shape[1:]) + return jax.lax.ragged_all_to_all( + x, output_shape, input_offsets, send_sizes, output_offsets, + recv_sizes, axis_name=axis_name,) + + # create an array x which is [batch, model] and has elements like + # [[0,0,0], + # [1,1,1], + # ... + x = jnp.arange(0.0, batch) + x = jnp.expand_dims(x, axis=1) + x = jnp.tile(x, (1, model)) + x = jax.device_put(x, x_sharding) + + output_shape = x.copy() + + input_offsets = jnp.array([[0, 2],[0,2]], dtype=jnp.int32) + input_offsets = jax.device_put(input_offsets, x_sharding) + + send_sizes = jnp.array([[2, 2],[2,2]], dtype=jnp.int32) + send_sizes = jax.device_put(send_sizes, x_sharding) + + output_offsets = jnp.array([[0, 0],[2,2]], dtype=jnp.int32) + output_offsets = jax.device_put(output_offsets, x_sharding) + + recv_sizes = jnp.array([[2, 2],[2,2]], dtype=jnp.int32) + recv_sizes = jax.device_put(recv_sizes, x_sharding) + + expected_array = jnp.array([[0,0,0],[1,1,1],[4,4,4],[5,5,5],[2,2,2],[3,3,3],[6,6,6],[7,7,7]], dtype=jnp.int32) + + ##### Non-vmap ##### + jit_wrapper = jax.jit(ra2a_wrapper) + x_a2a = jit_wrapper(x, output_shape, input_offsets, send_sizes, output_offsets, recv_sizes) + self.assertEqual(x_a2a.shape, (batch, model)) + self.assertAllClose(x_a2a, expected_array, check_dtypes=False) + + + #### Vmap ##### + vmap_func = jax.vmap( + ra2a_wrapper, + ) + jit_vmap_func = jax.jit(vmap_func) + + vmap_sharding = jax.sharding.NamedSharding(mesh, jax.P("pipeline", "expert", None)) + + def expand_array_for_vmap(arr): + arr = jnp.expand_dims(arr, axis=0) + arr = jnp.tile(arr, (pipeline_parallelism, 1, 1)) + arr = jax.device_put(arr, vmap_sharding) + return arr + + x_vmap = expand_array_for_vmap(x) + output_shape_vmap = expand_array_for_vmap(output_shape) + input_offsets_vmap = expand_array_for_vmap(input_offsets) + send_sizes_vmap = expand_array_for_vmap(send_sizes) + output_offsets_vmap = expand_array_for_vmap(output_offsets) + recv_sizes_vmap = expand_array_for_vmap(recv_sizes) + + vmap_output = jit_vmap_func( + x_vmap, output_shape_vmap, input_offsets_vmap, send_sizes_vmap, + output_offsets_vmap, recv_sizes_vmap) + self.assertEqual(vmap_output.shape, (pipeline_parallelism, batch, model)) + for i in range(pipeline_parallelism): + self.assertAllClose(vmap_output[i,:,:], expected_array, check_dtypes=False) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/random_lax_test.py b/tests/random_lax_test.py index b6f8b4f132bf..48fef40a1cf9 100644 --- a/tests/random_lax_test.py +++ b/tests/random_lax_test.py @@ -21,7 +21,6 @@ import numpy as np import scipy.linalg -import scipy.special import scipy.stats import jax @@ -46,7 +45,7 @@ @jtu.with_config(jax_legacy_prng_key='allow') -class LaxRandomTest(jtu.JaxTestCase): +class RandomTestBase(jtu.JaxTestCase): def _CheckCollisions(self, samples, nbits): fail_prob = 0.01 # conservative bound on statistical fail prob by Chebyshev @@ -70,9 +69,8 @@ def _CheckKolmogorovSmirnovCDF(self, samples, cdf, pval=None): samples = samples.astype('float32') # kstest fails for infinities starting in scipy 1.12 # (https://github.com/scipy/scipy/issues/20386) - # TODO(jakevdp): remove this logic if/when fixed upstream. scipy_version = jtu.parse_version(scipy.__version__) - if scipy_version >= (1, 12) and np.issubdtype(samples.dtype, np.floating): + if scipy_version < (1, 14) and np.issubdtype(samples.dtype, np.floating): samples = np.array(samples, copy=True) samples[np.isposinf(samples)] = 0.01 * np.finfo(samples.dtype).max samples[np.isneginf(samples)] = 0.01 * np.finfo(samples.dtype).min @@ -110,6 +108,11 @@ def _CheckChiSquared(self, samples, pmf, *, pval=None): def make_key(self, seed): return random.PRNGKey(seed, impl='threefry2x32') + +class CommonRandomTest(RandomTestBase): + """ + Tests of common functionality that should be run with all PRNG impls. + """ @jtu.sample_product( num=(None, 6, (6,), (2, 3), (2, 3, 4)), ) @@ -164,6 +167,60 @@ def testRngRandint(self, dtype): self.assertTrue(np.all(lo <= samples)) self.assertTrue(np.all(samples < hi)) + def test_eval_shape_big_random_array(self): + def f(x): + return random.normal(self.make_key(x), (int(1e12),)) + with jax.enable_checks(False): # check_jaxpr will materialize array + jax.eval_shape(f, 0) # doesn't error + + @jtu.sample_product( + type_=["int", "np.array", "jnp.array"], + seed=[-1, 0, 1, (1 << 32) - 1, (1 << 63) - 1, np.uint64((1 << 64) - 1)], + ) + def test_prng_jit_invariance(self, seed, type_): + if type_ == "int" and seed == (1 << 64) - 1: + self.skipTest("Expected failure: Python int too large.") + if not config.enable_x64.value and seed > np.iinfo(np.int32).max: + self.skipTest("Expected failure: Python int too large.") + type_ = {"int": int, "np.array": np.array, "jnp.array": jnp.array}[type_] + args_maker = lambda: [type_(seed)] + f = lambda s: random.key_data(self.make_key(s)) + self._CompileAndCheck(f, args_maker) + + def test_prng_errors(self): + seed = np.iinfo(np.int64).max + 1 + with self.assertRaises(OverflowError): + self.make_key(seed) + with self.assertRaises(OverflowError): + jax.jit(self.make_key)(seed) + + def test_random_split_doesnt_device_put_during_tracing(self): + key = self.make_key(1).block_until_ready() + with jtu.count_device_put() as count: + jax.jit(random.split)(key) + self.assertLessEqual(count(), 1) # 1 for the argument device_put + + def test_large_prng(self): + # https://github.com/jax-ml/jax/issues/11010 + def f(): + return random.uniform( + self.make_key(3), (308000000, 128), dtype=jnp.bfloat16) + + # TODO(jakevdp): key reuse checks for this OOM because of slice masking. + # Can we fix this? + with jax.debug_key_reuse(False): + # just lower, don't run, takes too long + jax.jit(f).lower() + + +class DistributionsTest(RandomTestBase): + """ + Tests of distribution statistics that need only be run with the default PRNG. + + We limit this to the default PRNG to avoid repeated execution of very costly + tests. So long as the input bits are valid (as tested in BasicRandomTest) then + the distribution logic tested here will apply correctly. + """ @jtu.sample_product(dtype=float_dtypes) def testNormal(self, dtype): key = lambda: self.make_key(0) @@ -227,8 +284,9 @@ def testTruncatedNormal(self, dtype): ], dtype=jtu.dtypes.floating + jtu.dtypes.integer, weighted=[True, False], + mode=[None, 'low', 'high'] ) - def testChoice(self, dtype, input_range_or_shape, shape, replace, weighted, axis): + def testChoice(self, dtype, input_range_or_shape, shape, replace, weighted, axis, mode): # This is the function API that we test against (note that self.rng().choice differs) np_choice = np.random.default_rng(0).choice p_dtype = dtypes.to_inexact_dtype(dtype) @@ -244,7 +302,7 @@ def testChoice(self, dtype, input_range_or_shape, shape, replace, weighted, axis p /= p.sum() else: p = None - rand = lambda key, x: random.choice(key, x, shape, replace, p, axis) + rand = lambda key, x: random.choice(key, x, shape, replace, p, axis, mode=mode) sample = rand(key(), x) if not is_range: self.assertEqual(dtype, sample.dtype) @@ -313,11 +371,13 @@ def testPermutationErrors(self): @jtu.sample_product( p=[0.1, 0.5, 0.9], dtype=jtu.dtypes.floating, + mode=[None, 'low', 'high'], ) - def testBernoulli(self, p, dtype): + def testBernoulli(self, p, dtype, mode): key = lambda: self.make_key(0) p = np.array(p, dtype=dtype) - rand = lambda key, p: random.bernoulli(key, p, (10000,)) + kwds = {} if mode is None else {'mode': mode} + rand = lambda key, p: random.bernoulli(key, p, (10000,), **kwds) crand = jax.jit(rand) uncompiled_samples = rand(key(), p) @@ -336,15 +396,16 @@ def testBernoulli(self, p, dtype): ] ], sample_shape=[(10000,), (5000, 2)], + mode=[None, 'low', 'high'], dtype=jtu.dtypes.floating, ) - def testCategorical(self, p, axis, dtype, sample_shape): + def testCategorical(self, p, axis, dtype, sample_shape, mode): key = lambda: self.make_key(0) p = np.array(p, dtype=dtype) logits = np.log(p) - 42 # test unnormalized out_shape = tuple(np.delete(logits.shape, axis)) shape = sample_shape + out_shape - rand = partial(random.categorical, shape=shape, axis=axis) + rand = partial(random.categorical, shape=shape, axis=axis, mode=mode) crand = jax.jit(rand) uncompiled_samples = rand(key(), logits) @@ -396,13 +457,29 @@ def testCategoricalWithoutReplacement(self, logits_shape, prefix_shape): counts = jax.vmap(partial(jnp.bincount, length=n_categories), 1)(flat) assert (counts <= 1).all() - def testBernoulliShape(self): key = self.make_key(0) with jax.numpy_rank_promotion('allow'): x = random.bernoulli(key, np.array([0.2, 0.3]), shape=(3, 2)) assert x.shape == (3, 2) + def testBernoulliSmallProbabilty(self): + # Regression test for https://github.com/jax-ml/jax/issues/28017 + key = jax.random.key(0) + + # Choose such that N * p is much less than 1. + p = jnp.float32(1E-10) + N = int(1E8) + + # mode='low' fails for p<~1E-7 in float32 + samples = jax.random.bernoulli(key, p=p, shape=N, mode='low') + self.assertNotEqual(samples.sum(), 0) + + # mode='high' is good up to p<~1E-14 in float32 + samples = jax.random.bernoulli(key, p=p, shape=N, mode='high') + self.assertEqual(samples.sum(), 0) + + @jtu.sample_product( a=[0.2, 5.], b=[0.2, 5.], @@ -886,7 +963,7 @@ def testMultivariateNormalCovariance(self): check_dtypes=False) @jtu.sample_product(method=['cholesky', 'eigh', 'svd']) - @jtu.skip_on_devices('gpu', 'tpu') # Some NaNs on accelerators. + @jtu.skip_on_devices('cuda', 'tpu') # Some NaNs on accelerators. ROCm supported def testMultivariateNormalSingularCovariance(self, method): # Singular covariance matrix https://github.com/jax-ml/jax/discussions/13293 mu = jnp.zeros((2,)) @@ -936,7 +1013,7 @@ def feature_map(n, d, sigma=1.0, seed=123): def testIssue756(self): key = self.make_key(0) w = random.normal(key, ()) - self.assertEqual(w.dtype, dtypes.canonicalize_dtype(jnp.float_)) + self.assertEqual(w.dtype, dtypes.default_float_dtype()) def testIssue1789(self): def f(x): @@ -1071,46 +1148,13 @@ def testChoiceShapeIsNotSequenceError(self): with self.assertRaises(TypeError): random.choice(key, 5, 2, replace=True) - def test_eval_shape_big_random_array(self): - def f(x): - return random.normal(self.make_key(x), (int(1e12),)) - with jax.enable_checks(False): # check_jaxpr will materialize array - jax.eval_shape(f, 0) # doesn't error - - @jtu.sample_product( - type_=["int", "np.array", "jnp.array"], - seed=[-1, 0, 1, (1 << 32) - 1, (1 << 63) - 1, np.uint64((1 << 64) - 1)], - ) - def test_prng_jit_invariance(self, seed, type_): - if type_ == "int" and seed == (1 << 64) - 1: - self.skipTest("Expected failure: Python int too large.") - if not config.enable_x64.value and seed > np.iinfo(np.int32).max: - self.skipTest("Expected failure: Python int too large.") - type_ = {"int": int, "np.array": np.array, "jnp.array": jnp.array}[type_] - args_maker = lambda: [type_(seed)] - f = lambda s: random.key_data(self.make_key(s)) - self._CompileAndCheck(f, args_maker) - - def test_prng_errors(self): - seed = np.iinfo(np.int64).max + 1 - with self.assertRaises(OverflowError): - self.make_key(seed) - with self.assertRaises(OverflowError): - jax.jit(self.make_key)(seed) - - def test_random_split_doesnt_device_put_during_tracing(self): - key = self.make_key(1).block_until_ready() - with jtu.count_device_put() as count: - jax.jit(random.split)(key) - self.assertLessEqual(count(), 1) # 1 for the argument device_put - @jtu.sample_product(dtype=int_dtypes + uint_dtypes) def test_randint_bounds(self, dtype): min = np.iinfo(dtype).min max = np.iinfo(dtype).max key = lambda: self.make_key(1701) shape = (10,) - if np.iinfo(dtype).bits < np.iinfo(dtypes.canonicalize_dtype(int)).bits: + if np.iinfo(dtype).bits < np.iinfo(dtypes.default_int_dtype()).bits: expected = random.randint(key(), shape, min, max + 1, dtype) self.assertArraysEqual(expected, random.randint(key(), shape, min - 12345, max + 12345, dtype)) else: @@ -1131,18 +1175,6 @@ def test_randint_out_of_range(self): self.assertGreater((r == 0).sum(), 0) self.assertGreater((r == 255).sum(), 0) - def test_large_prng(self): - # https://github.com/jax-ml/jax/issues/11010 - def f(): - return random.uniform( - self.make_key(3), (308000000, 128), dtype=jnp.bfloat16) - - # TODO(jakevdp): key reuse checks for this OOM because of slice masking. - # Can we fix this? - with jax.debug_key_reuse(False): - # just lower, don't run, takes too long - jax.jit(f).lower() - @jtu.sample_product(shape=[(3, 4)], logits_shape_base=[(3, 4), (3, 1), (1, 4)], axis=[-3, -2, -1, 0, 1, 2]) @@ -1408,6 +1440,21 @@ def test_batched_key_errors(self): jax.random.key_data(keys()) jax.random.key_impl(keys()) + @jtu.sample_product( + dtype=['int8', 'uint8', 'int16', 'uint16'] + ) + def test_randint_narrow_int_bias(self, dtype): + # Regression test for https://github.com/jax-ml/jax/issues/27702 + key = self.make_key(7534892) + n_samples = 100_000 + n_bins = 100 + data = jax.random.randint(key, (n_samples,), 0, n_bins, dtype=dtype) + + # Check that counts within each bin are consistent with a uniform distribution: + # i.e. counts are poisson-distributed about the average count per bin. + counts = jnp.bincount(data, length=n_bins).astype(float) + self._CheckKolmogorovSmirnovCDF(counts, scipy.stats.poisson(n_samples / n_bins).cdf) + def get_energy_distance(samples_1, samples_2): """ @@ -1461,7 +1508,7 @@ def _double_threefry_fold_in(key, data): tag='fry2') @jtu.with_config(jax_default_prng_impl='threefry2x32') -class LaxRandomWithCustomPRNGTest(LaxRandomTest): +class CustomPRNGTest(CommonRandomTest): def make_key(self, seed): return prng_internal.random_seed(seed, impl=double_threefry_prng_impl) @@ -1522,7 +1569,7 @@ def test_grad_of_prng_key(self): @jtu.with_config(jax_default_prng_impl='rbg') -class LaxRandomWithRBGPRNGTest(LaxRandomTest): +class RBGPRNGTest(CommonRandomTest): def make_key(self, seed): return random.PRNGKey(seed, impl='rbg') @@ -1634,7 +1681,7 @@ def test_randint_out_of_range(self): @jtu.with_config(jax_default_prng_impl='unsafe_rbg') -class LaxRandomWithUnsafeRBGPRNGTest(LaxRandomWithRBGPRNGTest): +class UnsafeRBGPRNGTest(RBGPRNGTest): def make_key(self, seed): return random.PRNGKey(seed, impl="unsafe_rbg") @@ -1648,24 +1695,6 @@ def test_vmap_split_mapped_key_values(self): self.assertArraysEqual(random.key_data(vmapped_keys), random.key_data(ref_keys)) -def _sampler_unimplemented_with_custom_prng(*args, **kwargs): - raise SkipTest('sampler only implemented for default RNG') - -for test_prefix in [ - 'testPoisson', - 'testPoissonBatched', - 'testPoissonShape', - 'testPoissonZeros', -]: - for attr in dir(LaxRandomTest): - if attr.startswith(test_prefix): - setattr(LaxRandomWithCustomPRNGTest, attr, - _sampler_unimplemented_with_custom_prng) - setattr(LaxRandomWithRBGPRNGTest, attr, - _sampler_unimplemented_with_custom_prng) - setattr(LaxRandomWithUnsafeRBGPRNGTest, attr, - _sampler_unimplemented_with_custom_prng) - if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/random_test.py b/tests/random_test.py index a51e387dca76..ce3b2f6a9956 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -33,10 +33,10 @@ from jax import random from jax._src import config from jax._src import core +from jax._src import dispatch from jax._src import dtypes from jax._src import test_util as jtu from jax import vmap -from jax.interpreters import xla from jax._src import random as jax_random from jax._src import prng as prng_internal @@ -243,13 +243,14 @@ def testThreefry2x32Empty(self): jnp.ones((10, 0,), jnp.uint32)) np.testing.assert_equal(result, np.zeros((10, 0,), dtype=np.uint32)) + @jtu.thread_unsafe_test() def testNoOpByOpUnderHash(self): def fail(*args, **kwargs): assert False - apply_primitive, xla.apply_primitive = xla.apply_primitive, fail + apply_primitive, dispatch.apply_primitive = dispatch.apply_primitive, fail try: _ = prng_internal.threefry_2x32(np.zeros(2, np.uint32), np.arange(10, dtype=np.uint32)) finally: - xla.apply_primitive = apply_primitive + dispatch.apply_primitive = apply_primitive @skipIf(config.threefry_partitionable.value, 'changed random bit values') @parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS]) @@ -278,13 +279,18 @@ def testRngRandomBits(self, make_key): expected64 = np.array([56197195, 4200222568, 961309823], dtype=np.uint32) self.assertArraysEqual(bits64, expected64) - @jtu.sample_product(prng_name=[name for name, _ in PRNG_IMPLS], - make_key=KEY_CTORS) - def testRngRandomBitsShapeDtype(self, prng_name, make_key): + @jtu.sample_product( + prng_name=[name for name, _ in PRNG_IMPLS], + make_key=KEY_CTORS, + explicit_x64_dtypes=config.ExplicitX64Mode.__members__.values(), + ) + def testRngRandomBitsShapeDtype(self, prng_name, make_key, + explicit_x64_dtypes): # Like testRngRandomBits, but only meant to exercise random_bits # on every PRNG implementation. Instead of values, only checks # that shapes/dtypes are as expected. + @config.explicit_x64_dtypes(explicit_x64_dtypes) def random_bits(key, width, shape): dtype = jnp.dtype(f'uint{width}') return jax.random.bits(key, shape, dtype) @@ -304,11 +310,20 @@ def random_bits(key, width, shape): self.assertEqual(bits32.shape, (3,)) self.assertEqual(bits32.dtype, np.dtype('uint32')) - with jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype.*"): - bits64 = random_bits(make_key(seed), 64, (3,)) - expected_dtype = np.dtype('uint64' if config.enable_x64.value else 'uint32') - self.assertEqual(bits64.shape, (3,)) - self.assertEqual(bits64.dtype, expected_dtype) + if explicit_x64_dtypes == config.ExplicitX64Mode.ERROR and not config.enable_x64.value: + with self.assertRaisesRegex(ValueError, "Explicitly requested dtype.*"): + bits64 = random_bits(make_key(seed), 64, (3,)) + else: + with jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype.*"): + bits64 = random_bits(make_key(seed), 64, (3,)) + expected_dtype = np.dtype( + "uint64" + if config.enable_x64.value + or explicit_x64_dtypes == config.ExplicitX64Mode.ALLOW + else "uint32" + ) + self.assertEqual(bits64.shape, (3,)) + self.assertEqual(bits64.dtype, expected_dtype) @skipIf(config.threefry_partitionable.value, 'changed random bit values') @parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS]) @@ -326,7 +341,6 @@ def random_bits(key, width, shape): rand_bits_32 = np.array([np.array(r).view(np.uint32) for r in rand_bits]) assert np.all(rand_bits_32 == rand_bits_32[0]) - @jtu.sample_product(case=_RANDOM_VALUES_CASES, make_key=KEY_CTORS) @skipIf(config.threefry_partitionable.value, 'changed random bit values') @jtu.skip_on_devices("tpu") # TPU precision causes issues. @@ -338,7 +352,7 @@ def testRandomDistributionValues(self, case, make_key): Any refactoring of random distributions that leads to non-trivial differences in this test should follow the procedure outlined at - https://jax.readthedocs.io/en/latest/api_compatibility.html#numerics-and-randomness + https://docs.jax.dev/en/latest/api_compatibility.html#numerics-and-randomness This includes: * Announcing the change in the CHANGELOG.md @@ -364,9 +378,9 @@ def testPRNGValues(self, make_key): # Test to ensure consistent random values between JAX versions seed = 0 self.assertEqual(random.randint(make_key(seed), (3, 3), 0, 8).dtype, - dtypes.canonicalize_dtype(jnp.int_)) + dtypes.default_int_dtype()) if config.enable_x64.value: - self.assertAllClose( + self.assertAllClose( random.randint(make_key(seed), (3, 3), 0, 8, dtype='int64'), np.array([[7, 2, 6], [2, 1, 0], @@ -602,10 +616,26 @@ def assertKeysEqual(self, key1, key2): self.assertEqual(key1.dtype, key2.dtype) self.assertArraysEqual(random.key_data(key1), random.key_data(key2)) + def make_keys(self, *shape, seed=28, impl='threefry2x32'): + seeds = seed + jnp.arange(math.prod(shape), dtype=jnp.uint32) + return jax.vmap(partial(random.key, impl=impl))(seeds).reshape(shape) + def test_construction(self): key = random.key(42) self.assertIsInstance(key, prng_internal.PRNGKeyArray) + def test_numpy_construction(self): + key = random.wrap_key_data(np.array([42, 173], dtype=np.uint32), + impl='threefry2x32') + self.assertIsInstance(key, prng_internal.PRNGKeyArray) + self.assertIsInstance(key._base_array, jax.Array) + self.assertEqual(key._base_array.device, jax.devices()[0]) + self.assertEqual(key.device, jax.devices()[0]) + + def test_device_property(self): + key = random.key(42) + self.assertEqual(key.device, key._base_array.device) + def test_random_clone(self): # Here we test value semantics and compatibility with jit/vmap # key reuse semantics are tested in key_reuse_test.py @@ -624,18 +654,20 @@ def test_issubdtype(self): self.assertFalse(jnp.issubdtype(key.dtype, np.integer)) self.assertFalse(jnp.issubdtype(key.dtype, np.number)) - with self.assertRaisesRegex(TypeError, "Cannot interpret"): - jnp.issubdtype(key, dtypes.prng_key) + if jtu.numpy_version() < (2, 4, 0): + with self.assertRaisesRegex(TypeError, "Cannot interpret"): + jnp.issubdtype(key, dtypes.prng_key) + else: + with jtu.ignore_warning(category=DeprecationWarning, + message="Implicit conversion of an array to a dtype"): + with self.assertRaisesRegex(ValueError, "Could not convert Array"): + jnp.issubdtype(key, dtypes.prng_key) @skipIf(not config.enable_custom_prng.value, 'relies on typed key upgrade flag') def test_construction_upgrade_flag(self): key = random.PRNGKey(42) self.assertIsInstance(key, prng_internal.PRNGKeyArray) - def make_keys(self, *shape, seed=28): - seeds = seed + jnp.arange(math.prod(shape), dtype=jnp.uint32) - return jax.vmap(random.key)(seeds).reshape(shape) - def test_key_as_seed(self): key = self.make_keys() with self.assertRaisesRegex(TypeError, "PRNGKey accepts a scalar seed"): @@ -657,6 +689,11 @@ def test_non_integer_seed(self): with self.assertRaisesRegex(TypeError, "PRNG key seed must be an integer"): random.key(seed) + def test_nbytes_property(self): + key = self.make_keys() + self.assertEqual(key.nbytes, key._base_array.nbytes) + self.assertEqual(key.nbytes, key.itemsize * key.size) + def test_dtype_property(self): k1, k2 = self.make_keys(), self.make_keys() self.assertEqual(k1.dtype, k2.dtype) @@ -914,6 +951,20 @@ def test_gather(self): self.assertIsInstance(ys, prng_internal.PRNGKeyArray) self.assertEqual(ys.shape, (3, 2, 1)) + @parameterized.parameters("threefry2x32", "rbg", "unsafe_rbg") + def test_gather_fill(self, impl): + # Regression test for https://github.com/jax-ml/jax/issues/33476 + keys = self.make_keys(4, impl=impl) + + # Expected fill value is a key wrapping an array containing uint32 max. + expected = random.wrap_key_data( + jnp.full_like(random.key_data(keys)[0], fill_value=np.iinfo('uint32').max), + impl=random.key_impl(keys)) + + out = jax.jit(lambda x: x.at[100].get(mode='fill'))(keys) + self.assertIsInstance(out, prng_internal.PRNGKeyArray) + self.assertKeysEqual(out, expected) + def test_select(self): ks = self.make_keys(3, 2) cs = jnp.array([True, False, False, True, False, True]).reshape(3, 2) @@ -947,12 +998,14 @@ def test_device_put(self): keys_on_device = jax.device_put(keys, device) self.assertKeysEqual(keys, keys_on_device) + @jtu.ignore_warning(category=DeprecationWarning) def test_device_put_sharded(self): devices = jax.devices() keys = self.make_keys(len(devices)) keys_on_device = jax.device_put_sharded(list(keys), devices) self.assertKeysEqual(keys, keys_on_device) + @jtu.ignore_warning(category=DeprecationWarning) def test_device_put_replicated(self): devices = jax.devices() key = self.make_keys() @@ -974,7 +1027,7 @@ def callback(index): def test_make_array_from_single_device_arrays(self): devices = jax.devices() shape = (len(devices),) - mesh = jtu.create_mesh((len(devices),), ('x',)) + mesh = jtu.create_mesh((len(devices),), ('x',), iota_order=True) sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x')) keys = random.split(random.key(0), len(devices)) arrays = [jax.device_put(keys[i:i + 1], device) for i, device in enumerate(devices)] diff --git a/tests/roofline_test.py b/tests/roofline_test.py index 564b4a9a1f9e..5480406aa104 100644 --- a/tests/roofline_test.py +++ b/tests/roofline_test.py @@ -14,11 +14,11 @@ from __future__ import annotations from functools import partial -from typing import Sequence +from collections.abc import Sequence from absl.testing import absltest import jax -from jax._src import mesh +from jax._src import core from jax._src import test_util as jtu from jax.experimental import roofline import jax.lax as lax @@ -29,6 +29,8 @@ jax.config.parse_flags_with_absl() jtu.request_cpu_devices(8) +_VERY_LARGE_NUMBER = 512 * 1024 + def create_inputs( *shardings: P, @@ -45,6 +47,49 @@ def create_inputs( return mesh, tuple(arrays) +def example_function(x): + return jnp.sin(x) + x**2 + + +@jax.custom_jvp +def example_custom_function(x): + """Example custom function. + + Small wrapper around `example_function`. We define `example_custom_function` + separately since we add the `@jax.custom_jvp` decorator and want to compare + its behavior to `example_function`'s in tests. + """ + return example_function(x) + + +@example_custom_function.defjvp +def example_custom_function_jvp(primals, tangents): + """Example custom function jvp. + + Normally this function would define a mathematically correct JVP, but its + definition has 0 effect on the roofline result, so we keep it very simple. + """ + return example_custom_function(primals), tangents + +# A fake primitive without a roofline rule. This is used to test that roofline +# can handle primitives without a roofline rule. +fake_jax_primitive_p = core.Primitive("fake_jax_primitive") + + +@fake_jax_primitive_p.def_impl +def _fake_jax_primitive_impl(x): + return x + + +@fake_jax_primitive_p.def_abstract_eval +def _fake_jax_primitive_abstract_eval(x): + return core.ShapedArray(x.shape, x.dtype) + + +def fake_jax_primitive_function(x): + return fake_jax_primitive_p.bind(x) + + class RooflineTest(jtu.JaxTestCase): def setUp(self): @@ -465,18 +510,13 @@ def collective_matmul(a, b): ) def test_unary_ops(self, f, dtype): data = jnp.zeros((3, 8), dtype=dtype) - out, result = roofline.roofline( - f, - in_specs=(P()), - out_specs=P(), - )(data) - with self.subTest("flops"): - self.assertEqual(result.unfused_flops, 3 * 8) - with self.subTest("hbm_bytes"): - self.assertEqual( - result.unfused_hbm_bytes, - data.dtype.itemsize * 3 * 8 + out.dtype.itemsize * 3 * 8, - ) + out, result = roofline.roofline(f)(data) + + self.assertEqual(result.unfused_flops, 3 * 8) + self.assertEqual( + result.unfused_hbm_bytes, + data.dtype.itemsize * 3 * 8 + out.dtype.itemsize * 3 * 8, + ) def test_binary_ops(self): for f in [ @@ -495,12 +535,9 @@ def test_binary_ops(self): lambda a, b: jnp.minimum(a, b), lambda a, b: jnp.maximum(a, b), ]: - out, result = roofline.roofline( - f, - mesh=mesh.AbstractMesh((), ()), - in_specs=(P(), P()), - out_specs=P(), - )(jnp.zeros((3, 8), dtype=int), jnp.ones((3, 8), dtype=int)) + out, result = roofline.roofline(f)( + jnp.zeros((3, 8), dtype=int), jnp.ones((3, 8), dtype=int) + ) self.assertEqual(result.unfused_flops, 3 * 8) self.assertEqual( result.unfused_hbm_bytes, @@ -515,12 +552,7 @@ def test_broadcast(self): (2.0, jnp.ones((3, 8))), (jnp.zeros((3, 8)), 2.0), ]: - _, result = roofline.roofline( - lambda a, b: a + b, - mesh=mesh.AbstractMesh((), ()), - in_specs=(P(), P()), - out_specs=P(), - )(left, right) + _, result = roofline.roofline(lambda a, b: a + b)(left, right) self.assertEqual(result.unfused_flops, 3 * 8) def test_nested(self): @@ -531,27 +563,21 @@ def g(x): return g(x) + g(y) - _, result = roofline.roofline( - f, - mesh=mesh.AbstractMesh((), ()), - in_specs=(P(), P()), - out_specs=P(), - )(jnp.zeros((11, 4), dtype=int), jnp.ones((11, 4), dtype=int)) + _, result = roofline.roofline(f)( + jnp.zeros((11, 4), dtype=int), jnp.ones((11, 4), dtype=int) + ) self.assertEqual(result.unfused_flops, 3 * (11 * 4)) def test_no_mesh(self): - _, result = roofline.roofline( - lambda a, b: a + b, - in_specs=(P(), P()), - out_specs=P(), - )(jnp.zeros((3, 8), dtype=int), jnp.ones((3, 8), dtype=int)) + _, result = roofline.roofline(lambda a, b: a + b)( + jnp.zeros((3, 8), dtype=int), jnp.ones((3, 8), dtype=int) + ) self.assertEqual(result.unfused_flops, 3 * 8) def test_no_specs(self): - _, result = roofline.roofline( - lambda a, b: a + b, - mesh=mesh.AbstractMesh((), ()), - )(jnp.zeros((3, 8), dtype=int), jnp.ones((3, 8), dtype=int)) + _, result = roofline.roofline(lambda a, b: a + b)( + jnp.zeros((3, 8), dtype=int), jnp.ones((3, 8), dtype=int) + ) self.assertEqual(result.unfused_flops, 3 * 8) def test_no_mesh_and_no_specs(self): @@ -560,63 +586,109 @@ def test_no_mesh_and_no_specs(self): )(jnp.zeros((3, 8), dtype=int), jnp.ones((3, 8), dtype=int)) self.assertEqual(result.unfused_flops, 3 * 8) + @jtu.parameterized.product( + cumulative_function=[lax.cummax, lax.cummin, lax.cumprod, lax.cumsum], + axis=[0, 1, 2], + ) + def test_cumulative_ops(self, cumulative_function: int, axis: int): + f = lambda x: cumulative_function(operand=x, axis=axis) + x = jnp.zeros((3, 8, 15), dtype=int) + + _, result = roofline.roofline(f)(x) + + self.assertEqual(result.unfused_flops, x.shape[axis]) + self.assertEqual( + result.unfused_hbm_bytes, 2 * self._bytes_per_word * 3 * 8 * 15 + ) + + @jtu.parameterized.named_parameters( + dict(testcase_name="axis_0", axis=0), + dict(testcase_name="axis_1", axis=1), + dict(testcase_name="axis_2", axis=2), + ) + def test_cumlogsumexp_p_roofline(self, axis: int): + f = lambda x: lax.cumlogsumexp(operand=x, axis=axis) + x = jnp.zeros((3, 8, 15), dtype=int) + + _, result = roofline.roofline(f)(x) + + self.assertEqual(result.unfused_flops, 2 * x.shape[axis]) + self.assertEqual( + result.unfused_hbm_bytes, 2 * self._bytes_per_word * 3 * 8 * 15 + ) + def test_dot_general(self): - _, result = roofline.roofline( - lambda a, b: a @ b, - mesh=mesh.AbstractMesh((), ()), - in_specs=(P(), P()), - out_specs=P(), - )(jnp.zeros((3, 7), dtype=int), jnp.ones((7, 5), dtype=int)) + _, result = roofline.roofline(lambda a, b: a @ b)( + jnp.zeros((3, 7), dtype=int), jnp.ones((7, 5), dtype=int) + ) self.assertEqual(result.unfused_flops, 2 * 3 * 7 * 5) self.assertEqual( result.unfused_hbm_bytes, self._bytes_per_word * (3 * 7 + 7 * 5 + 3 * 5) ) - def get_conv_output_dim(self, i, k, pad_low, pad_high, stride): + def get_conv_output_dim(self, i, k, pad_low, pad_high, stride) -> int: return jnp.floor((i - k + pad_low + pad_high) / stride) + 1 - @jtu.parameterized.named_parameters( - dict( - testcase_name="simple", - window_strides=(1, 1), - padding=((0, 0), (0, 0)), - ), - dict( - testcase_name="padding", - window_strides=(1, 1), - padding=((1, 2), (3, 4)), - ), - dict( - testcase_name="window_strides", - window_strides=(2, 2), - padding=((0, 0), (0, 0)), - ), - dict( - testcase_name="window_strides_and_padding", - window_strides=(3, 3), - padding=((1, 2), (3, 4)), - ), + def get_conv_num_output_channels( + self, batch_group_count: int, feature_group_count: int + ) -> int: + if batch_group_count > 1: + return batch_group_count + elif feature_group_count > 1: + return feature_group_count + else: + return 1 + + @jtu.parameterized.product( + window_strides=[(1, 1), (2, 2)], + padding=[((0, 0), (0, 0)), ((1, 2), (3, 4))], + # batch must be divisible by batch_group_count, so we only include factors + # of batch_group_count. + batch=[6, 12], + batch_group_count=[1, 3], + # num_input_channels must be divisible by feature_group_count, so we only + # include factors of feature_group_count. + num_input_channels=[6, 12], + feature_group_count=[1, 3], ) def test_conv_general_dilated_unfused_hbm_bytes( - self, window_strides: Sequence[int, int], padding: Sequence[int, int] + self, + window_strides: Sequence[int, int], + padding: Sequence[int, int], + batch: int, + batch_group_count: int, + num_input_channels: int, + feature_group_count: int, ): + if batch_group_count > 1 and feature_group_count > 1: + self.skipTest( + "batch_group_count and feature_group_count cannot both be > 1" + ) + + num_output_channels = self.get_conv_num_output_channels( + batch_group_count, feature_group_count + ) + + num_input_features = int(num_input_channels / feature_group_count) iw, ih = 100, 200 kw, kh = 7, 7 - input_data = jnp.zeros((1, 1, iw, ih), dtype=int) - kernel_data = jnp.ones((1, 1, kw, kh), dtype=int) + input_data = jnp.zeros((batch, num_input_channels, iw, ih), dtype=int) + kernel_data = jnp.ones( + (num_output_channels, num_input_features, kw, kh), dtype=int + ) conv = lambda a, b: lax.conv_general_dilated( - lhs=a, rhs=b, window_strides=window_strides, padding=padding + lhs=a, + rhs=b, + window_strides=window_strides, + padding=padding, + batch_group_count=batch_group_count, + feature_group_count=feature_group_count, ) - _, result = roofline.roofline( - conv, - mesh=mesh.AbstractMesh((), ()), - in_specs=(P(), P()), - out_specs=P(), - )(input_data, kernel_data) + _, result = roofline.roofline(conv)(input_data, kernel_data) - expected_input_size = 1 * 1 * iw * ih - expected_kernel_size = 1 * 1 * kw * kh + expected_input_size = batch * num_input_channels * iw * ih + expected_kernel_size = num_output_channels * num_input_features * kw * kh ow = self.get_conv_output_dim( iw, kw, padding[0][0], padding[0][1], window_strides[0] @@ -624,12 +696,14 @@ def test_conv_general_dilated_unfused_hbm_bytes( oh = self.get_conv_output_dim( ih, kh, padding[1][0], padding[1][1], window_strides[1] ) - expected_output_size = 1 * 1 * ow * oh + expected_output_shape = jnp.array( + (batch / batch_group_count, num_output_channels, ow, oh) + ) + expected_output_size = jnp.prod(expected_output_shape) # Bytes accessed is sum of inputs and output. expected_unfused_hbm_bytes = self._bytes_per_word * ( expected_input_size + expected_kernel_size + expected_output_size ) - # TODO(b/394648206): add subtest for unfused_flops once they are supported. self.assertEqual(result.unfused_hbm_bytes, expected_unfused_hbm_bytes) @jtu.parameterized.named_parameters( @@ -642,24 +716,22 @@ def test_conv_general_dilated_unfused_hbm_bytes( padding="SAME_LOWER", ), ) - def test_conv_general_dilated_padding_string_unfused_hbm_bytes(self, padding: str): - input_data = jnp.zeros((1, 1, 10, 20), dtype=int) + def test_conv_general_dilated_padding_string( + self, padding: str + ): + input_data = jnp.zeros((1, 1, 3, 3), dtype=int) kernel_data = jnp.ones((1, 1, 3, 3), dtype=int) conv = lambda a, b: lax.conv_general_dilated( lhs=a, rhs=b, window_strides=(1, 1), padding=padding ) - _, result = roofline.roofline( - conv, - mesh=mesh.AbstractMesh((), ()), - in_specs=(P(), P()), - out_specs=P(), - )(input_data, kernel_data) + _, result = roofline.roofline(conv)(input_data, kernel_data) - expected_input_size = 1 * 1 * 10 * 20 + # Test hbm bytes. + expected_input_size = 1 * 1 * 3 * 3 expected_kernel_size = 1 * 1 * 3 * 3 # Because of same{_lower} padding, output shape should equal to input shape. - # This may not be true for other `{feature, batch}`_group_count`s.c + # This may not be true for other `{feature, batch}`_group_count`s. expected_output_size = expected_input_size # Bytes accessed is sum of inputs and output. expected_unfused_hbm_bytes = self._bytes_per_word * ( @@ -667,19 +739,28 @@ def test_conv_general_dilated_padding_string_unfused_hbm_bytes(self, padding: st ) self.assertEqual(result.unfused_hbm_bytes, expected_unfused_hbm_bytes) - def test_conv_general_dilated_padding_string_valid_unfused_hbm_bytes(self): + # Test flops. + # For spatial_valid_position_counts, we have 3x3 output with the following + # flops for each element: + # 4 6 4 + # 6 9 6 + # 4 6 4 + # Non_spatial_dims_factor = 1 because `{batch, feature}_group_count` are + # both equal to 1. + # Each FMA is 2 flops. + self.assertEqual( + result.unfused_flops, + 2 * (4 + 6 + 4 + 6 + 9 + 6 + 4 + 6 + 4), + ) + + def test_conv_general_dilated_padding_string_valid(self): input_data = jnp.zeros((1, 1, 10, 20), dtype=int) kernel_data = jnp.ones((1, 1, 3, 3), dtype=int) conv = lambda a, b: lax.conv_general_dilated( lhs=a, rhs=b, window_strides=(1, 1), padding="VALID" ) - _, result = roofline.roofline( - conv, - mesh=mesh.AbstractMesh((), ()), - in_specs=(P(), P()), - out_specs=P(), - )(input_data, kernel_data) + _, result = roofline.roofline(conv)(input_data, kernel_data) expected_input_size = 1 * 1 * 10 * 20 expected_kernel_size = 1 * 1 * 3 * 3 @@ -690,19 +771,91 @@ def test_conv_general_dilated_padding_string_valid_unfused_hbm_bytes(self): * self.get_conv_output_dim(10, 3, 0, 0, 1) * self.get_conv_output_dim(20, 3, 0, 0, 1) ) + # Bytes accessed is sum of inputs and output. expected_unfused_hbm_bytes = self._bytes_per_word * ( expected_input_size + expected_kernel_size + expected_output_size ) self.assertEqual(result.unfused_hbm_bytes, expected_unfused_hbm_bytes) + # Output shape is [1x1x8x18] and each output element requires (3x3) FMAs, + # and each FMA is 2 flops. + self.assertEqual( + result.unfused_flops, 2 * expected_output_size * 3 * 3 + ) + + @jtu.parameterized.named_parameters( + dict( + testcase_name="padding", + input_spatial_dim=1, + window_strides=[1], + padding=[(_VERY_LARGE_NUMBER - 1, _VERY_LARGE_NUMBER - 1)], + lhs_dilation=[1], + ), + dict( + testcase_name="input", + input_spatial_dim=_VERY_LARGE_NUMBER, + window_strides=[_VERY_LARGE_NUMBER - 1], + padding=[(0, 0)], + lhs_dilation=[_VERY_LARGE_NUMBER], + ), + ) + def test_conv_general_dilated_flops_very_large( + self, input_spatial_dim, window_strides, padding, lhs_dilation + ): + input_data = jnp.zeros((1, 1, input_spatial_dim), dtype=int) + kernel_data = jnp.ones((1, 1, _VERY_LARGE_NUMBER), dtype=int) + conv = lambda a, b: lax.conv_general_dilated( + lhs=a, + rhs=b, + window_strides=window_strides, + padding=padding, + lhs_dilation=lhs_dilation, + ) + _, result = roofline.roofline(conv)(input_data, kernel_data) + + self.assertEqual(result.unfused_flops, 2 * _VERY_LARGE_NUMBER) + + def test_conv_general_dilated_flops_feature_group_count(self): + feature_group_count = 120 + input_data = jnp.zeros((1, feature_group_count, 10, 20), dtype=int) + kernel_data = jnp.ones((feature_group_count, 1, 3, 3), dtype=int) + conv = lambda a, b: lax.conv_general_dilated( + lhs=a, + rhs=b, + window_strides=(1, 1), + padding=((0, 0), (0, 0)), + feature_group_count=feature_group_count, + ) + _, result = roofline.roofline(conv)(input_data, kernel_data) + + # Output shape is [1x120x8x18] and each output element requires (3x3) + # FMAs and one FMA is 2 flops. + self.assertEqual( + result.unfused_flops, 2 * 120 * 8 * 18 * 3 * 3 + ) + + def test_conv_general_dilated_flops_batch_group_count(self): + batch_group_count = 120 + input_data = jnp.zeros((batch_group_count, 1, 10, 20), dtype=int) + kernel_data = jnp.ones((batch_group_count, 1, 3, 3), dtype=int) + conv = lambda a, b: lax.conv_general_dilated( + lhs=a, + rhs=b, + window_strides=(1, 1), + padding=((0, 0), (0, 0)), + batch_group_count=batch_group_count, + ) + _, result = roofline.roofline(conv)(input_data, kernel_data) + + # Output shape is [120x1x8x18] and each output element requires (3x3) + # FMAs and one FMA is 2 flops. + self.assertEqual( + result.unfused_flops, 2 * 120 * 8 * 18 * 3 * 3 + ) + def test_reduce_sum_no_axis(self): - _, result = roofline.roofline( - lambda x: jnp.sum(x), - mesh=mesh.AbstractMesh((), ()), - in_specs=(P()), - out_specs=P(), - )(jnp.zeros((11, 4))) + _, result = roofline.roofline(lambda x: jnp.sum(x))(jnp.zeros((11, 4))) self.assertEqual(result.unfused_flops, 11 * 4 - 1) self.assertEqual( result.unfused_hbm_bytes, self._bytes_per_word * (11 * 4 + 1) @@ -715,17 +868,333 @@ def test_reduce_sum_with_axis(self): ([0, 1], 11 * 4 - 1, 11 * 4 + 1), ([], 0, 11 * 4 + 11 * 4), ]: - _, result = roofline.roofline( - lambda x: jnp.sum(x, axis=axis), - mesh=mesh.AbstractMesh((), ()), - in_specs=(P()), - out_specs=P(), - )(jnp.zeros((11, 4))) + _, result = roofline.roofline(lambda x: jnp.sum(x, axis=axis))( + jnp.zeros((11, 4)) + ) self.assertEqual(result.unfused_flops, expected_flops) self.assertEqual( result.unfused_hbm_bytes, self._bytes_per_word * expected_memory ) + def test_custom_jvp_call_p_roofline(self): + dummy_input = jnp.ones((3, 8)) + + _, base_result = roofline.roofline(example_function)(dummy_input) + _, custom_result = roofline.roofline(example_custom_function)(dummy_input) + + self.assertEqual(custom_result.unfused_flops, base_result.unfused_flops) + self.assertEqual( + custom_result.unfused_hbm_bytes, base_result.unfused_hbm_bytes + ) + + def test_custom_jvp_call_p_roofline_with_neg(self): + dummy_input = jnp.ones((3, 8)) + + def with_neg(f): + return lambda x: jax.lax.neg(f(x)) + + _, base_result = roofline.roofline(with_neg(example_function))(dummy_input) + _, custom_result = roofline.roofline(with_neg(example_custom_function))( + dummy_input + ) + + self.assertEqual(custom_result.unfused_flops, base_result.unfused_flops) + self.assertEqual( + custom_result.unfused_hbm_bytes, base_result.unfused_hbm_bytes + ) + + @jtu.parameterized.named_parameters( + dict( + testcase_name="promise_in_bounds", + mode=lax.GatherScatterMode.PROMISE_IN_BOUNDS, + expected_flops=0, + ), + dict( + testcase_name="clip", + mode=lax.GatherScatterMode.CLIP, + expected_flops=0, + ), + dict( + testcase_name="fill_or_drop", + mode=lax.GatherScatterMode.FILL_OR_DROP, + expected_flops=4 * 2 * 1 + 2 * 3, + ), + ) + def test_gather_roofline(self, mode, expected_flops): + operand = jnp.zeros((3, 3), dtype=jnp.int32) + indices = jnp.zeros((2, 1), dtype=jnp.int32) + + dimension_numbers = jax.lax.GatherDimensionNumbers( + offset_dims=(1,), + collapsed_slice_dims=(0,), + start_index_map=(0,), + ) + + f = lambda x, y: jax.lax.gather( + x, + y, + dimension_numbers=dimension_numbers, + slice_sizes=(1, 3), + mode=mode, + ) + + _, result = roofline.roofline(f)(operand, indices) + + self.assertEqual(result.unfused_flops, expected_flops) + # Expected bytes: + # operand: 2 * 3 * sizeof(int32) = 24 + # indices: 2 * 1 * sizeof(int32) = 8 + # output: 2 * 3 * sizeof(int32) = 24 + # total = 56 + self.assertEqual(result.unfused_hbm_bytes, 56) + + def test_gather_batching_dims_roofline(self): + operand = jnp.zeros((5, 3, 3), dtype=jnp.int32) + indices = jnp.zeros((5, 1), dtype=jnp.int32) + + dimension_numbers = jax.lax.GatherDimensionNumbers( + offset_dims=(1,), + collapsed_slice_dims=(1,), + start_index_map=(1,), + operand_batching_dims=(0,), + start_indices_batching_dims=(0,), + ) + + f = lambda x, y: jax.lax.gather( + x, + y, + dimension_numbers=dimension_numbers, + slice_sizes=(1, 1, 3), + ) + + _, result = roofline.roofline(f)(operand, indices) + + self.assertEqual(result.unfused_flops, 0) + # Expected bytes: + # operand: 5 * 3 * sizeof(int32) = 60 + # indices: 5 * 1 * sizeof(int32) = 20 + # output: 5 * 3 * sizeof(int32) = 60 + # total = 140 + self.assertEqual(result.unfused_hbm_bytes, 140) + + def _assert_scatter_hbm_bytes_is_correct( + self, + result: roofline.RooflineResult, + indices: jnp.ndarray, + updates: jnp.ndarray, + ): + self.assertEqual( + result.unfused_hbm_bytes, + 3 * updates.size * updates.dtype.itemsize + + indices.size * indices.dtype.itemsize, + ) + + @jtu.parameterized.named_parameters( + dict( + testcase_name="scatter_add", + scatter_fn=jax.lax.scatter_add, + ), + dict( + testcase_name="scatter_max", + scatter_fn=jax.lax.scatter_max, + ), + dict( + testcase_name="scatter_min", + scatter_fn=jax.lax.scatter_min, + ), + dict( + testcase_name="scatter_mul", + scatter_fn=jax.lax.scatter_mul, + ), + dict( + testcase_name="scatter_sub", + scatter_fn=jax.lax.scatter_sub, + ), + ) + def test_scatter_unary_roofline(self, scatter_fn): + operand = jnp.zeros((3, 3), dtype=jnp.float32) + indices = jnp.zeros((2, 1), dtype=jnp.int32) + updates = jnp.ones((2, 3), dtype=jnp.float32) + + f = lambda x, y, z: scatter_fn( + x, + y, + z, + dimension_numbers=jax.lax.ScatterDimensionNumbers( + update_window_dims=(1,), + inserted_window_dims=(0,), + scatter_dims_to_operand_dims=(0,), + ), + ) + + _, result = roofline.roofline(f)(operand, indices, updates) + + # The `update_jaxpr` computation is a simple unary op, which has 1 flop, and + # is applied for each element in the updates tensor, which has size 2 * 3. + self.assertEqual(result.unfused_flops, 1 * 2 * 3) + self._assert_scatter_hbm_bytes_is_correct(result, indices, updates) + + def test_scatter_with_batching_dims_roofline(self): + operand = jnp.zeros((5, 3, 3), dtype=jnp.float32) + indices = jnp.zeros((5, 1), dtype=jnp.int32) + updates = jnp.zeros((5, 3), dtype=jnp.float32) + + # Use `scatter_add` as an example. + f = lambda x, y, z: jax.lax.scatter_add( + x, + y, + z, + dimension_numbers=jax.lax.ScatterDimensionNumbers( + update_window_dims=(1,), + inserted_window_dims=(1,), + operand_batching_dims=(0,), + scatter_indices_batching_dims=(0,), + scatter_dims_to_operand_dims=(1,), + ), + ) + + _, result = roofline.roofline(f)(operand, indices, updates) + + # The `update_jaxpr` computation is a simple add, which has 1 flop, and + # is applied for each element in the updates tensor, which has size 5 * 3. + self.assertEqual(result.unfused_flops, 1 * 5 * 3) + self._assert_scatter_hbm_bytes_is_correct(result, indices, updates) + + def test_scatter_roofline(self): + operand = jnp.zeros((3, 3), dtype=jnp.float32) + indices = jnp.zeros((2, 1), dtype=jnp.int32) + updates = jnp.ones((2, 3), dtype=jnp.float32) + + dimension_numbers = jax.lax.ScatterDimensionNumbers( + update_window_dims=(1,), + inserted_window_dims=(0,), + scatter_dims_to_operand_dims=(0,), + ) + + f = lambda x, y, z: jax.lax.scatter( + x, + y, + z, + dimension_numbers=dimension_numbers, + ) + + _, result = roofline.roofline(f)(operand, indices, updates) + + # There is no update computation, so 0 flops. + self.assertEqual(result.unfused_flops, 0) + # Memory is still accessed though. + self._assert_scatter_hbm_bytes_is_correct(result, indices, updates) + + def test_scatter_apply_roofline(self): + operand = jnp.zeros((3, 3), dtype=jnp.float32) + indices = jnp.ones((2, 1), dtype=jnp.int32) + updates = jnp.ones((2, 3), dtype=jnp.float32) + + f = lambda x, y, z: jax.lax.scatter_apply( + operand=x, + scatter_indices=y, + func=example_function, + update_shape=z.shape, + dimension_numbers=lax.ScatterDimensionNumbers( + update_window_dims=(1,), + inserted_window_dims=(0,), + scatter_dims_to_operand_dims=(0,), + ), + indices_are_sorted=True, + mode=lax.GatherScatterMode.PROMISE_IN_BOUNDS, + ) + + _, result = roofline.roofline(f)(operand, indices, updates) + + # `example_function` should take 3 flops per element, and we compute it for + # each element in the updates tensor, which has size 2 * 3. + self.assertEqual(result.unfused_flops, 3 * 2 * 3) + self._assert_scatter_hbm_bytes_is_correct(result, indices, updates) + + def test_select_n_roofline(self): + which = jnp.zeros((4, 8), dtype=int) + cases = ( + jnp.zeros((4, 8), dtype=int), + jnp.zeros((4, 8), dtype=int), + jnp.zeros((4, 8), dtype=int), + ) + + out, result = roofline.roofline(lax.select_n)(which, *cases) + + self.assertEqual(result.unfused_flops, 4 * 8) + self.assertEqual( + result.unfused_hbm_bytes, + which.dtype.itemsize * 4 * 8 + out.dtype.itemsize * 4 * 8, + ) + + def test_random_primitive_roofline_does_not_raise_error(self): + """Tests that roofline can handle primitives with PRNG keys as ins/outs.""" + dummy_input = jax.random.PRNGKey(0) + + def f(key): + # Use jax.random.split to ensure the random_wrap primitive is used. + prng_key, _ = jax.random.split(key) + return prng_key + + _, result = roofline.roofline(f)(dummy_input) + + self.assertEqual(result.unfused_flops, 0) + self.assertEqual(result.unfused_hbm_bytes, 0) + + @jtu.parameterized.named_parameters( + dict( + testcase_name="pure_callback", + callback_fn=jax.pure_callback, + ), + dict( + testcase_name="io_callback", + callback_fn=jax._src.callback.io_callback, + ), + ) + def test_callback_with_output_roofline(self, callback_fn): + def _example_callback_function(x): + result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype) + return callback_fn(example_function, result_shape, x) + + x = jnp.zeros((3, 8), dtype=jnp.float32) + out, result = roofline.roofline(_example_callback_function)(x) + + # Bytes accessed is sum of inputs and output. + expected_hbm_bytes = ( + x.dtype.itemsize * x.size + out.dtype.itemsize * out.size + ) + self.assertEqual(result.unfused_hbm_bytes, expected_hbm_bytes) + self.assertEqual(result.unfused_flops, 0) + + def test_debug_callback_roofline(self): + def _example_debug_callback_function(x): + jax.debug.callback(example_function, x) + return x + + x = jnp.zeros((3, 8), dtype=jnp.float32) + _, result = roofline.roofline(_example_debug_callback_function)(x) + + # Bytes accessed is only the input, as debug.callback does not return a + # value. + self.assertEqual(result.unfused_hbm_bytes, x.dtype.itemsize * x.size) + self.assertEqual(result.unfused_flops, 0) + + def test_primitive_with_no_roofline_rule_contributes_nothing(self): + x = jnp.zeros((3, 8), dtype=jnp.float32) + _, result = roofline.roofline(fake_jax_primitive_function)(x) + self.assertEqual(result.flops, 0) + self.assertEqual(result.unfused_flops, 0) + self.assertEqual(result.hbm_bytes, 0) + self.assertEqual(result.unfused_hbm_bytes, 0) + + def test_primitive_with_no_roofline_rule_contributes_nothing_with_abs(self): + x = jnp.zeros((3, 8), dtype=jnp.float32) + f = lambda x: lax.abs(fake_jax_primitive_function(x)) + _, result = roofline.roofline(f)(x) + + _, expected_result = roofline.roofline(lax.abs)(x) + self.assertDataclassEqual(result, expected_result) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/scaled_matmul_stablehlo_test.py b/tests/scaled_matmul_stablehlo_test.py index 141839a19a08..6cd691634ede 100644 --- a/tests/scaled_matmul_stablehlo_test.py +++ b/tests/scaled_matmul_stablehlo_test.py @@ -19,6 +19,7 @@ import numpy as np import jax import jax.numpy as jnp +import jax.ad_checkpoint from jax.sharding import Mesh from jax.sharding import PartitionSpec, NamedSharding from jax._src import config @@ -43,17 +44,19 @@ ((None, "dp", "tp"), (None, "dp", "tp")), ((None, "tp", None), (None, "tp", None)), ((None, None, "tp"), (None, "tp", None)), + ((None, ("dp", "tp"), None), (None, ("dp"), None)), ] c_name = "__cudnn$blockScaledDot" expected_hlos = [ (c_name, "all-reduce", "f32[1,512,512]", "replica_groups={{0,1},{2,3}}"), - ("all-gather", "f8e4m3fn[1,512,512]", "replica_groups=[2,2]<=[4]", c_name), - ("all-gather", "f8e4m3fn[1,512,512]", "replica_groups=[2,2]<=[4]", c_name), + ("all-gather", "f8e4m3fn[512,512]", "replica_groups=[2,2]<=[4]", c_name), + ("all-gather", "f8e4m3fn[512,512]", "replica_groups=[2,2]<=[4]", c_name), + (c_name,), + ("all-gather", "f8e4m3fn[256,1024]", "replica_groups=[2,2]<=[4]", c_name), (c_name,), - ("all-gather", "f8e4m3fn[1,256,1024]", "replica_groups=[2,2]<=[4]", c_name), - (c_name, "reduce-scatter", "f32[2,256,512]", "replica_groups={{0,1},{2,3}}"), ("all-gather", "f8e4m3fn[2,512,1024]", "replica_groups=[2,2]<=[4]", c_name), ("all-gather", "f8e4m3fn[2,512,512]", "replica_groups=[2,2]<=[4]", c_name), + ("all-gather", "f8e4m3fn[2,256,1024]", "replica_groups=[2,2]<=[2,2]", c_name,), ] expected_output_spec = [ PartitionSpec('dp',), @@ -61,10 +64,18 @@ PartitionSpec('dp', None, 'tp'), PartitionSpec('dp', None, 'tp'), PartitionSpec('dp', 'tp', None), - PartitionSpec(None, 'dp', 'tp'), + PartitionSpec(None, 'dp'), PartitionSpec(None, 'tp', None), PartitionSpec(None, None, 'tp'), + PartitionSpec(None, ('dp', 'tp'), None), ] + +# The GSPMD sharding logic inserts additional reduce-scatters which don't exist +# in Shardy. +if not config.use_shardy_partitioner.value: + expected_output_spec[5] = PartitionSpec(None, 'dp', 'tp') + expected_hlos[5] += ("reduce-scatter", "f32[2,256,512]", "replica_groups={{0,1},{2,3}}") + sharding_configs = { input_sharding: (hlo, output_spec) for input_sharding, hlo, output_spec in zip(input_shardings, @@ -155,26 +166,46 @@ def shard_and_device_put( return a, b, in_shardings -def create_nvfp4_configs(global_scale=None): +def create_nvfp4_configs(tensors, enable_grad_clip=False): if _dtypes.float4_e2m1fn is None: return None - g_one_scale = jnp.ones((1, ), dtype=jnp.float32) - nvfp4_config = BlockScaleConfig( - mode='nvfp4', - block_size=16, - data_type=jnp.float4_e2m1fn, - scale_type=jnp.float8_e4m3fn, - global_scale=g_one_scale if global_scale is None else global_scale, - infer_only=False - ) - return [nvfp4_config for _ in range(3)] + DATA_TYPE = jnp.float4_e2m1fn + SCALE_TYPE = jnp.float8_e4m3fn + + def get_global_scale(tensor): + # If we have a tensor, compute the global scale from its maximum value. + # Otherwise, use the default. + if tensor is None: + return jnp.array(1, dtype=jnp.float32) -def update_global_scale(config, new_global_scale): - config.global_scale = new_global_scale - return config + # Compute maximum absolute values for scaling + amax = jnp.max(jnp.abs(tensor)).astype(jnp.float32) + amax *= 0.9 if enable_grad_clip else 1.0 -def generate_nvfp4_quantized_tensors(dot_config, output_type): + data_max = jnp.finfo(DATA_TYPE).max.astype(jnp.float32) + scale_max = jnp.finfo(SCALE_TYPE).max.astype(jnp.float32) + return amax / (data_max * scale_max) + + def make_config(tensor): + return BlockScaleConfig( + mode='nvfp4', + block_size=16, + data_type=DATA_TYPE, + scale_type=SCALE_TYPE, + global_scale=get_global_scale(tensor), + infer_only=False + ) + + return [make_config(tensor) for tensor in tensors] + +def dequantize_nvfp4_tensor(x, scale, output_type, config): + x_reshaped = x.astype(output_type).reshape(-1, 16) + scale_reshaped = scale.astype(output_type).reshape(-1, 1) + scaled = x_reshaped * scale_reshaped * config.global_scale.astype(output_type) + return scaled.reshape(x.shape) + +def generate_nvfp4_quantized_tensors(dot_config, output_type, enable_grad_clip=False): k1, k2 = jax.random.split(jax.random.key(0), 2) a_shape, b_shape, dimension_numbers = dot_config @@ -188,49 +219,21 @@ def generate_nvfp4_quantized_tensors(dot_config, output_type): b = shape_normalization(b_raw, b_dn) # Initialize NVFP4 configurations - block_scale_configs_nvfp4 = create_nvfp4_configs() - - # Compute maximum absolute values for scaling - amax_a = jnp.max(jnp.abs(a)).astype(jnp.float32) - amax_b = jnp.max(jnp.abs(b)).astype(jnp.float32) - - # Update global scales - data_max = jnp.finfo(block_scale_configs_nvfp4[0].data_type).max.astype( - jnp.float32 - ) - scale_max = jnp.finfo(block_scale_configs_nvfp4[0].scale_type).max.astype( - jnp.float32 - ) - - block_scale_configs_nvfp4[0] = update_global_scale( - block_scale_configs_nvfp4[0], amax_a / (data_max * scale_max)) - block_scale_configs_nvfp4[1] = update_global_scale( - block_scale_configs_nvfp4[1], amax_b / (data_max * scale_max)) + a_cfg, b_cfg, out_cfg = create_nvfp4_configs([a, b, None], enable_grad_clip) # Quantize tensors - a_nvfp4, a_scale = quantize(a, block_scale_configs_nvfp4[0]) - b_nvfp4, b_scale = quantize(b, block_scale_configs_nvfp4[1]) + a_nvfp4, a_scale = quantize(a, a_cfg) + b_nvfp4, b_scale = quantize(b, b_cfg) # Reshape and scale quantized tensors - def reshape_and_scale(x, scale, global_scale, bs, k): - reshaped = x.astype(output_type).reshape(*bs, k // 16, 16) - scaled = reshaped * jnp.expand_dims(scale.astype(output_type), -1) - return scaled.reshape(*bs, k) * global_scale.astype(output_type) - - *bs_a, k_a = a_nvfp4.shape - *bs_b, k_b = b_nvfp4.shape - assert k_a == k_b - - a_dequantized = reshape_and_scale( - a_nvfp4, a_scale, block_scale_configs_nvfp4[0].global_scale, bs_a, k_a) - b_dequantized = reshape_and_scale( - b_nvfp4, b_scale, block_scale_configs_nvfp4[1].global_scale, bs_b, k_b) + a_dequantized = dequantize_nvfp4_tensor(a_nvfp4, a_scale, output_type, a_cfg) + b_dequantized = dequantize_nvfp4_tensor(b_nvfp4, b_scale, output_type, b_cfg) return ( (a_raw, b_raw), (a_dequantized, b_dequantized), (a_nvfp4, b_nvfp4, a_scale, b_scale), - block_scale_configs_nvfp4 + [a_cfg, b_cfg, out_cfg] ) def create_mxfp8_configs(): @@ -270,16 +273,9 @@ class ScaledMatmulTest(jtu.JaxTestCase): def setUp(self): super().setUp() try: - cudnn_version = check_cudnn_version() + check_cudnn_version() except RuntimeError as e: self.skipTest(str(e)) - return - if _dtypes.float8_e8m0fnu is None: - self.skipTest("Requries >= ml_dtypes 0.5.0 to support float8_e8m0fnu") - if _dtypes.float4_e2m1fn is None: - self.skipTest("Requries >= ml_dtypes 0.5.0 to support float4_e2m1fn") - if cudnn_version < 90700: - self.skipTest("Requires >= cuDNN 9.7.0") if not jtu.is_cuda_compute_capability_at_least("10.0"): self.skipTest("Requires at least Blackwell arch") @@ -463,14 +459,9 @@ class ScaledDotGeneralTest(jtu.JaxTestCase): def setUp(self): super().setUp() try: - cudnn_version = check_cudnn_version() + check_cudnn_version() except RuntimeError as e: self.skipTest(str(e)) - return - if _dtypes.float8_e8m0fnu is None: - self.skipTest("Requries >= ml_dtypes 0.5.0 to support float8_e8m0fnu") - if cudnn_version < 90700: - self.skipTest("Requires >= cuDNN 9.7.0") if not jtu.is_cuda_compute_capability_at_least("10.0"): self.skipTest("Requires at least Blackwell arch") @@ -489,18 +480,12 @@ def test_quantize_nvfp4(self, shape): output_type = jnp.float32 k1, k2 = jax.random.split(jax.random.key(0), 2) - a = jax.random.uniform(k1, shape, minval=-1.0, dtype=output_type) + a = jax.random.uniform(k1, shape, minval=1.0, maxval=8.0, dtype=output_type) - block_scale_configs_nvfp4 = create_nvfp4_configs() - data_max = jnp.finfo(jnp.float4_e2m1fn).max.astype(jnp.float32) - scale_max = jnp.finfo(jnp.float8_e4m3fn).max.astype(jnp.float32) - amax_a = jnp.max(jnp.abs(a)).astype(jnp.float32) / (data_max * scale_max) - block_scale_configs_nvfp4[0] = update_global_scale( - block_scale_configs_nvfp4[0], jnp.asarray(amax_a, jnp.float32) - ) + config, = create_nvfp4_configs([a]) def fn(a): - a_nvfp4, a_scale = quantize(a, block_scale_configs_nvfp4[0]) + a_nvfp4, a_scale = quantize(a, config) return a_nvfp4, a_scale out_q, scale = jax.jit(fn)(a) @@ -508,6 +493,85 @@ def fn(a): self.assertArraysAllClose(out_q, out_q_ref, rtol=1e-5, atol=1e-5) self.assertArraysAllClose(scale, scale_ref, rtol=1e-5, atol=1e-5) + # Verify the quantization is close to the original. + self.assertArraysAllClose(dequantize_nvfp4_tensor(out_q, scale, output_type, config), + a, rtol=0.2, atol=0.5) + + @jtu.sample_product(value=[1e6, 1/4096]) + @jtu.run_on_devices("cuda") + def test_quantize_requires_global_scale(self, value): + output_type = jnp.float32 + k1, k2 = jax.random.split(jax.random.key(0), 2) + + a = jnp.array([value]*16) + config, = create_nvfp4_configs([a]) + out_q, scale = quantize(a, config) + # Without an adjusted global scale, the values clip at 2688/0. + self.assertArraysEqual(dequantize_nvfp4_tensor(out_q, scale, output_type, config), + jnp.array([value]*16, dtype=output_type)) + + @jtu.sample_product( + enable_grad_clip=[True, False], + configs=[ + # a_shape, b_shape, dimension_numbers + ((1, 128, 128), (1, 128, 128), (([2], [2]), ([0], [0]))), + ((30, 64), (100, 64), (([1], [1]), ([], []))), + ] + ) + @jtu.run_on_devices("cuda") + def test_nvfp4_gradient_clip(self, enable_grad_clip, configs): + output_type = jnp.float32 + (a_raw, b_raw), (a_dq, b_dq), _, block_scale_configs = ( + generate_nvfp4_quantized_tensors(configs, output_type, enable_grad_clip) + ) + a_gs = block_scale_configs[0].global_scale + b_gs = block_scale_configs[1].global_scale + dimension_numbers = configs[2] + + scaled_dot_general = partial( + scaled_dot_general_wrapper, + configs=block_scale_configs + ) + + def fwd(a, b, use_normalized=False): + y = scaled_dot_general( + a, b, dimension_numbers, + preferred_element_type=output_type + ) + return jnp.sum(y) + + j_train = jax.jit(jax.value_and_grad(fwd, argnums=[0, 1])) + _, (x_grad, w_grad) = j_train(a_raw, b_raw) + + data_max = jnp.finfo(jnp.float4_e2m1fn).max.astype(output_type) + scale_max = jnp.finfo(jnp.float8_e4m3fn).max.astype(output_type) + prev_amax_a = a_gs * data_max * scale_max + prev_amax_b = b_gs * data_max * scale_max + + # Use a large value to ensure no clipping + threshold_a = prev_amax_a if enable_grad_clip else 1e9 + threshold_b = prev_amax_b if enable_grad_clip else 1e9 + + # Verify gradients are clipped to 0 where |input| > global_scale * MAX * SCALE_MAX + self.assertArraysEqual( + jnp.where(jnp.abs(a_raw) > threshold_a, x_grad, 0), + jnp.zeros_like(x_grad), + ) + self.assertArraysEqual( + jnp.where(jnp.abs(b_raw) > threshold_b, w_grad, 0), + jnp.zeros_like(w_grad), + ) + if enable_grad_clip: + # Verify gradients are preserved where |input| <= global_scale * MAX * SCALE_MAX + self.assertArraysEqual( + jnp.where(jnp.abs(a_raw) <= prev_amax_a, x_grad, 0), + x_grad, + ) + self.assertArraysEqual( + jnp.where(jnp.abs(b_raw) <= prev_amax_b, w_grad, 0), + w_grad, + ) + @jtu.sample_product( configs=[ # a_shape, b_shape, dimension_numbers, is_training @@ -567,6 +631,16 @@ def fwd(a, b, is_ref=False, use_normalized=False): out_ref, _ = j_train_fwd_ref(a_dq, b_dq) self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e-2) + def _grad_clip(amax, x, grad): + return jnp.where(jnp.abs(x) <= amax, grad, 0) + + data_max = jnp.finfo(jnp.float4_e2m1fn).max.astype(output_type) + scale_max = jnp.finfo(jnp.float8_e4m3fn).max.astype(output_type) + prev_amax_a = a_gs * data_max * scale_max + prev_amax_b = b_gs * data_max * scale_max + + x_grad_ref = _grad_clip(prev_amax_a, a_raw, x_grad_ref) + w_grad_ref = _grad_clip(prev_amax_b, b_raw, w_grad_ref) self.assertArraysAllClose(x_grad, x_grad_ref, rtol=1e-2, atol=1e1) self.assertArraysAllClose(w_grad, w_grad_ref, rtol=1e-2, atol=1e1) else: @@ -659,11 +733,11 @@ def test_dot_general_sharded(self, in_shardings): k1, k2 = jax.random.split(jax.random.key(0), 2) a = cast_to_representable( - jax.random.uniform(k1, a_shape, minval=-1.0), + jax.random.uniform(k1, a_shape, minval=-1.0, dtype=jnp.float32), self.block_scale_configs[0].data_type, ) b = cast_to_representable( - jax.random.uniform(k2, b_shape, minval=-1.0), + jax.random.uniform(k2, b_shape, minval=-1.0, dtype=jnp.float32), self.block_scale_configs[1].data_type, ) @@ -694,10 +768,6 @@ def fwd(a, b, is_ref=False): j_train = jax.jit(jax.value_and_grad(partial(fwd), argnums=[0, 1]), in_shardings=input_shardings) - hlo_text = j_train.lower(a, b).compile().as_text() - hlo_pattern = re.compile( - r".*".join([re.escape(x) for x in ("custom-call", c_name)]) - ) j_train_ref = jax.jit( jax.value_and_grad(partial(fwd, is_ref=True), argnums=[0, 1]), @@ -731,11 +801,11 @@ def test_dot_general_vmap(self, configs): dimension_numbers = (([1], [1]), ([], [])) a = cast_to_representable( - jax.random.uniform(k1, a_shape, minval=-1.0), + jax.random.uniform(k1, a_shape, minval=-1.0, dtype=jnp.float32), self.block_scale_configs[0].data_type, ) b = cast_to_representable( - jax.random.uniform(k2, b_shape, minval=-1.0), + jax.random.uniform(k2, b_shape, minval=-1.0, dtype=jnp.float32), self.block_scale_configs[1].data_type, ) @@ -763,6 +833,74 @@ def fwd(a, b, is_ref=False): self.assertArraysAllClose(x_grad, x_grad_ref, rtol=1e-2, atol=1e1) self.assertArraysAllClose(w_grad, w_grad_ref, rtol=1e-2, atol=1e1) + @jtu.run_on_devices("cuda") + def test_remat_checkpoint_dots(self): + input = jnp.ones((1, 128, 128)) + config = create_nvfp4_configs([input])[0] + + def f(x): + x = jnp.sin(x) + x = scaled_dot_general_wrapper( + x, x, + configs=[config, config], + dimension_numbers=(([2], [2]), ([0], [0])), + preferred_element_type=jnp.float32, + ) + return jnp.sin(x) + + # First check that with "nothing_saveable" policy, the backwards pass + # recomputes the scaled matmul. + nothing_saved_f = jax.checkpoint( + f, policy=jax.checkpoint_policies.nothing_saveable) + _, nothing_saved_f_vjp = jax.vjp(nothing_saved_f, input) + jaxpr = str(nothing_saved_f_vjp.jaxpr) + self.assertEqual(jaxpr.count(' scaled_matmul_wrapper'), 1) + # Check that the custom backward for scaled_matmul is used. + self.assertEqual(jaxpr.count('bwd=scaled_dot_bwd'), 1) + + # With "checkpoint_dots" policy, the backwards pass should reuse + # the scaled matmul from the forward pass, so it should be missing from vjp. + saved_dots_f = jax.checkpoint( + f, policy=jax.checkpoint_policies.checkpoint_dots) + _, saved_dots_f_vjp = jax.vjp(saved_dots_f, input) + jaxpr = str(saved_dots_f_vjp.jaxpr) + self.assertEqual(jaxpr.count(' scaled_matmul_wrapper'), 0) + # Check that the custom backward for scaled_matmul is used. + self.assertEqual(jaxpr.count('bwd=scaled_dot_bwd'), 1) + + @jtu.run_on_devices("cuda") + def test_remat_checkpoint_dots_with_no_batch_dims(self): + input = jnp.ones((1, 128, 128)) + batched_input = jnp.ones((16, 128, 128)) + config = create_nvfp4_configs([input])[0] + + def f(x): + x = jnp.sin(x) + x = scaled_dot_general_wrapper( + x, x, + configs=[config, config], + dimension_numbers=(([2], [2]), ([0], [0])), + preferred_element_type=jnp.float32, + ) + return jnp.sin(x) + + # Verify that scaled_matmul without batch dimensions + # will be saved (i.e., not recomputed on backward pass). + checkpointed_f = jax.checkpoint( + f, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable) + _, dot_saved_f_vjp = jax.vjp(checkpointed_f, input) + jaxpr = str(dot_saved_f_vjp.jaxpr) + self.assertEqual(jaxpr.count(' scaled_matmul_wrapper'), 0) + # Check that the custom backward for scaled_matmul is used. + self.assertEqual(jaxpr.count('bwd=scaled_dot_bwd'), 1) + + # Scaled matmuls with batch dimensions will be recomputed + # on backward pass. Let's verify that here. + _, dot_not_saved_f_vjp = jax.vjp(checkpointed_f, batched_input) + jaxpr = str(dot_not_saved_f_vjp.jaxpr) + self.assertEqual(jaxpr.count(' scaled_matmul_wrapper'), 1) + # Check that the custom backward for scaled_matmul is used. + self.assertEqual(jaxpr.count('bwd=scaled_dot_bwd'), 1) if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/scheduling_groups_test.py b/tests/scheduling_groups_test.py new file mode 100644 index 000000000000..dab8a5a929ef --- /dev/null +++ b/tests/scheduling_groups_test.py @@ -0,0 +1,179 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest + +import jax +import jax.numpy as jnp +from jax._src import test_util as jtu + +from jax.experimental.scheduling_groups import ( + scheduling_group, xla_metadata_call) + +jax.config.parse_flags_with_absl() + + +class SchedulingGroupsTest(jtu.JaxTestCase): + + def test_basic(self): + a = 1. + b = 2. + x = 3. + y = 4. + + @scheduling_group(name="grp0:sub_grp0") + def fn0(a, b): + c = jnp.add(a, b) + return c + + @scheduling_group(name="grp0:sub_grp1") + def fn1(x, y): + z = jnp.multiply(x, y) + return z + + @scheduling_group(name="grp0") + def fn(a, b, x, y): + c = fn0(a, b) + z = fn1(x, y) + return c, z + + lowered = jax.jit(fn).lower(a, b, x, y) + self.assertIn('scheduling_group = "grp0"', lowered.as_text()) + + def test_transforms(self): + @scheduling_group(name='yash') + def f(x): + return 2 * x + + ans = jax.vmap(f)(jnp.arange(3.)) + self.assertAllClose(ans, 2. * jnp.arange(3.)) + + ans = jax.grad(f)(3.) + self.assertAllClose(ans, 2., check_dtypes=False) + + # TODO(yashkatariya): Enable this on TPU once XLA:TPU knows about inlineable + @jtu.run_on_devices('cpu') + def test_xla_metadata_call_inlineable(self): + inp = jnp.arange(8.) + + @xla_metadata_call(inlineable="false") + def g(x): + return x * 2 + + @jax.jit + def f(x): + y = g(x) + return jnp.sin(y).sum() + + f(inp) # doesn't crash + + lowered = jax.jit(jax.grad(f)).lower(inp) + self.assertIn('inlineable = "false"', lowered.as_text()) + compiled = lowered.compile() + self.assertIn('inlineable="false"', compiled.as_text()) + compiled(inp) # doesn't crash + + @jtu.run_on_devices('cpu') + def test_xla_metadata_call_inlineable_remat_in_scan(self): + @xla_metadata_call(inlineable="false") + def f(x, y): + return x + y, (x + y).sum() + + def g(x, use_remat=True): + maybe_rematted_f = jax.remat(f) if use_remat else f + _, b = maybe_rematted_f(x, x) + return b + + grad_f = jax.jit(jax.grad(g), static_argnums=(1,)) + grads = grad_f(jnp.array(5.0), use_remat=False) + self.assertIsNotNone(grads) + grads = grad_f(jnp.array(5.0), use_remat=True) + self.assertIsNotNone(grads) + + @jtu.run_on_devices('cpu') + def test_xla_metadata_call_deduplication(self): + inp = jnp.arange(8.) + + @xla_metadata_call(inlineable='false') + @jax.jit + def g(x): + return x * 2 + + def f(x): + y = g(x) + z = g(y) + return z.sum() + + f(inp) # doesn't crash + lowered = jax.jit(f).lower(inp) + self.assertEqual( + lowered.as_text().count('func.func private @xla_metadata_call'), 1) + compiled = lowered.compile() + compiled(inp) # doesn't crash + + jax.jit(jax.grad(f))(inp) # doesn't crash + lowered = jax.jit(jax.grad(f)).lower(inp) + self.assertEqual( + lowered.as_text().count('func.func private @xla_metadata_call'), 1) + compiled = lowered.compile() + compiled(inp) # doesn't crash + + @jtu.run_on_devices('cpu') + def test_xla_metadata_call_deduplication_remat(self): + inp = jnp.arange(8.) + + @jax.remat + @xla_metadata_call(inlineable='false') + @jax.jit + def g(x): + return x * 2 + + def f(x): + y = g(x) + z = g(y) + return z.sum() + + f(inp) # doesn't crash + lowered = jax.jit(f).lower(inp) + self.assertEqual( + lowered.as_text().count('func.func private @xla_metadata_call'), 1) + compiled = lowered.compile() + compiled(inp) # doesn't crash + + jax.jit(jax.value_and_grad(f))(inp) # doesn't crash + lowered = jax.jit(jax.value_and_grad(f)).lower(inp) + self.assertEqual( + lowered.as_text().count('func.func private @xla_metadata_call'), 2) + compiled = lowered.compile() + compiled(inp) # doesn't crash + + @jtu.run_on_devices('cpu') + def test_xla_metadata_call_deduplication_kwargs(self): + inp = jnp.arange(8.) + + @xla_metadata_call(inlineable='false') + @jax.jit + def g(x): + return x * 2 + + def f(x): + y = g(x=x) + z = g(x=y) + return z.sum() + + f(inp) # doesn't crash + + +if __name__ == '__main__': + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/scipy_fft_test.py b/tests/scipy_fft_test.py index a6fdd1b79f58..54f17fb1d605 100644 --- a/tests/scipy_fft_test.py +++ b/tests/scipy_fft_test.py @@ -88,9 +88,6 @@ def testDctn(self, shape, dtype, s, axes, norm): axis=[-1, 0], norm=[None, 'ortho', 'backward'], ) - # TODO(phawkins): these tests are failing on T4 GPUs in CI with a - # CUDA_ERROR_ILLEGAL_ADDRESS. - @jtu.skip_on_devices("cuda") def testiDct(self, shape, dtype, n, axis, norm): rng = jtu.rand_default(self.rng()) args_maker = lambda: (rng(shape, dtype),) @@ -108,9 +105,6 @@ def testiDct(self, shape, dtype, n, axis, norm): dtype=real_dtypes, norm=[None, 'ortho', 'backward'], ) - # TODO(phawkins): these tests are failing on T4 GPUs in CI with a - # CUDA_ERROR_ILLEGAL_ADDRESS. - @jtu.skip_on_devices("cuda") def testiDctn(self, shape, dtype, s, axes, norm): rng = jtu.rand_default(self.rng()) args_maker = lambda: (rng(shape, dtype),) @@ -130,5 +124,20 @@ def testIdctNormalizationPrecision(self): actual = jsp_fft.idct(x, n=n, type=2) self.assertArraysAllClose(actual, expected, atol=1e-14) + @jtu.sample_product(func=['idctn', 'dctn']) + def testDctnShape(self, func): + # Regression test for https://github.com/jax-ml/jax/issues/31836 + x = np.arange(10.0).reshape(5, 2) + kwds = dict(type=2, s=(12, 7), axes=(-2, -1)) + + osp_func = getattr(osp_fft, func) + jsp_func = getattr(jsp_fft, func) + + expected = osp_func(x, **kwds) + actual = jsp_func(x, **kwds) + rtol = {np.float64: 1E-12, np.float32: 1E-4} + self.assertArraysAllClose(actual, expected, rtol=rtol) + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/scipy_optimize_test.py b/tests/scipy_optimize_test.py index ffa576850538..19d6d9beda2b 100644 --- a/tests/scipy_optimize_test.py +++ b/tests/scipy_optimize_test.py @@ -14,7 +14,6 @@ from absl.testing import absltest import numpy as np -import scipy import scipy.optimize import jax diff --git a/tests/scipy_signal_test.py b/tests/scipy_signal_test.py index 11923257a9dd..b1c5d9c98fed 100644 --- a/tests/scipy_signal_test.py +++ b/tests/scipy_signal_test.py @@ -357,12 +357,11 @@ def testWelchWithDefaultStepArgsAgainstNumpy( if use_nperseg: kwargs['nperseg'] = nperseg if use_window: - kwargs['window'] = jnp.array(osp_signal.get_window('hann', nperseg), - dtype=dtypes.to_complex_dtype(dtype)) + kwargs['window'] = jnp.array(osp_signal.get_window('hann', nperseg)) if use_noverlap: kwargs['noverlap'] = noverlap - @jtu.ignore_warning(message="nperseg = 256 is greater than") + @jtu.ignore_warning(message="nperseg") def osp_fun(x): freqs, Pxx = osp_signal.welch(x, **kwargs) return freqs.astype(_real_dtype(dtype)), Pxx.astype(_real_dtype(dtype)) @@ -388,7 +387,7 @@ def osp_fun(x): ], dtype=default_dtypes, fs=[1.0, 16000.0], - window=['boxcar', 'triang', 'blackman', 'hamming', 'hann'], + window=['boxcar', 'triang', 'blackman', 'hamming', 'hann', 'USE_ARRAY'], onesided=[False, True], boundary=[False, True], ) @@ -399,6 +398,11 @@ def testIstftAgainstNumpy(self, *, shape, dtype, fs, window, nperseg, new_freq_len = (shape[freqaxis] - 1) * 2 shape = shape[:freqaxis] + (new_freq_len ,) + shape[freqaxis + 1:] + if window == 'USE_ARRAY': + # ensure dtype matches the expected dtype of `xsubs` within the implementation. + window = np.ones(nperseg, dtype=( + dtypes.to_floating_dtype(dtype) if onesided else dtypes.to_complex_dtype(dtype))) + kwds = dict(fs=fs, window=window, nperseg=nperseg, noverlap=noverlap, nfft=nfft, input_onesided=onesided, boundary=boundary, time_axis=timeaxis, freq_axis=freqaxis) diff --git a/tests/scipy_spatial_test.py b/tests/scipy_spatial_test.py index 3da98efce884..6b1c042b049e 100644 --- a/tests/scipy_spatial_test.py +++ b/tests/scipy_spatial_test.py @@ -123,8 +123,6 @@ def testRotationAsQuat(self, shape, dtype): shape=[(4,), (num_samples, 4)], ) def testRotationAsQuatCanonical(self, shape, dtype): - if scipy_version < (1, 11, 0): - self.skipTest("Scipy 1.11.0 added the `canonical` arg.") rng = jtu.rand_default(self.rng()) args_maker = lambda: (rng(shape, dtype),) jnp_fn = lambda q: jsp_Rotation.from_quat(q).as_quat(canonical=True) @@ -152,8 +150,6 @@ def testRotationAsQuatScalarFirst(self, shape, dtype): other_shape=[(num_samples, 4)], ) def testRotationConcatenate(self, shape, other_shape, dtype): - if scipy_version < (1, 8, 0): - self.skipTest("Scipy 1.8.0 needed for concatenate.") rng = jtu.rand_default(self.rng()) args_maker = lambda: (rng(shape, dtype), rng(other_shape, dtype),) jnp_fn = lambda q, o: jsp_Rotation.concatenate([jsp_Rotation.from_quat(q), jsp_Rotation.from_quat(o)]).as_rotvec() @@ -297,8 +293,6 @@ def testRotationInv(self, shape, dtype): shape=[(4,), (num_samples, 4)], ) def testRotationInvConjugate(self, shape, dtype): - if scipy_version < (1, 11, 0): - self.skipTest("Scipy prior to 1.11.0 used a negative conjugate.") rng = jtu.rand_default(self.rng()) args_maker = lambda: (rng(shape, dtype),) jnp_fn = lambda q: jsp_Rotation.from_quat(q).inv().as_quat() diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index 796d4490daea..5e1eb440f428 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -20,18 +20,15 @@ import numpy as np import scipy.stats as osp_stats -import scipy.version import jax import jax.numpy as jnp -from jax._src import dtypes, test_util as jtu +from jax._src import test_util as jtu from jax.scipy import stats as lsp_stats from jax.scipy.special import expit jax.config.parse_flags_with_absl() -scipy_version = jtu.parse_version(scipy.version.version) - all_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4)] one_and_two_dim_shapes = [(4,), (3, 4), (3, 1), (1, 4)] @@ -217,9 +214,6 @@ def testBernoulliPpf(self, shapes, dtypes): scipy_fun = osp_stats.bernoulli.ppf lax_fun = lsp_stats.bernoulli.ppf - if scipy_version < (1, 9, 2): - self.skipTest("Scipy 1.9.2 needed for fix https://github.com/scipy/scipy/pull/17166.") - def args_maker(): q, p = map(rng, shapes, dtypes) q = expit(q) @@ -1110,6 +1104,103 @@ def args_maker(): tol=1e-3) self._CompileAndCheck(lax_fun, args_maker) + @genNamedParametersNArgs(4) + def testParetoPdf(self, shapes, dtypes): + rng = jtu.rand_positive(self.rng()) + scipy_fun = osp_stats.pareto.pdf + lax_fun = lsp_stats.pareto.pdf + + def args_maker(): + x, b, loc, scale = map(rng, shapes, dtypes) + return [x, b, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy( + scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3 + ) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(4) + def testParetoLogCdf(self, shapes, dtypes): + rng = jtu.rand_positive(self.rng()) + scipy_fun = osp_stats.pareto.logcdf + lax_fun = lsp_stats.pareto.logcdf + + def args_maker(): + x, b, loc, scale = map(rng, shapes, dtypes) + return [x, b, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy( + scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3 + ) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(4) + def testParetoCdf(self, shapes, dtypes): + rng = jtu.rand_positive(self.rng()) + scipy_fun = osp_stats.pareto.cdf + lax_fun = lsp_stats.pareto.cdf + + def args_maker(): + x, b, loc, scale = map(rng, shapes, dtypes) + return [x, b, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy( + scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3 + ) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(4) + def testParetoPpf(self, shapes, dtypes): + rng_positive = jtu.rand_positive(self.rng()) + rng_uniform = jtu.rand_uniform(self.rng()) + scipy_fun = osp_stats.pareto.ppf + lax_fun = lsp_stats.pareto.ppf + + def args_maker(): + q = rng_uniform(shapes[0], dtypes[0]) + b, loc, scale = map(rng_positive, shapes[1:], dtypes[1:]) + return [q, b, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy( + scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3 + ) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(4) + def testParetoSf(self, shapes, dtypes): + rng = jtu.rand_positive(self.rng()) + scipy_fun = osp_stats.pareto.sf + lax_fun = lsp_stats.pareto.sf + + def args_maker(): + x, b, loc, scale = map(rng, shapes, dtypes) + return [x, b, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy( + scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3 + ) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(4) + def testParetoLogSf(self, shapes, dtypes): + rng = jtu.rand_positive(self.rng()) + scipy_fun = osp_stats.pareto.logsf + lax_fun = lsp_stats.pareto.logsf + + def args_maker(): + x, b, loc, scale = map(rng, shapes, dtypes) + return [x, b, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy( + scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3 + ) + self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(4) def testTLogPdf(self, shapes, dtypes): @@ -1441,6 +1532,248 @@ def testMultivariateNormalLogpdfBatch(self, ndim, nbatch, dtype): result2 = jax.vmap(lsp_stats.multivariate_normal.logpdf)(x, mean, cov) self.assertArraysAllClose(result1, result2, check_dtypes=False) + + @genNamedParametersNArgs(3) + def testGumbelRLogPdf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.gumbel_r.logpdf + lax_fun = lsp_stats.gumbel_r.logpdf + + def args_maker(): + x, loc, scale = map(rng, shapes, dtypes) + scale = np.abs(scale) + np.array(0.1, dtype=scale.dtype) # Ensure scale > 0 + return [x, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=5e-3) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(3) + def testGumbelRPdf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.gumbel_r.pdf + lax_fun = lsp_stats.gumbel_r.pdf + + def args_maker(): + x, loc, scale = map(rng, shapes, dtypes) + scale = np.abs(scale) + np.array(0.1, dtype=scale.dtype) # Ensure scale > 0 + return [x, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=5e-3) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(3) + def testGumbelRLogCdf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.gumbel_r.logcdf + lax_fun = lsp_stats.gumbel_r.logcdf + + def args_maker(): + x, loc, scale = map(rng, shapes, dtypes) + scale = np.abs(scale) + np.array(0.1, dtype=scale.dtype) # Ensure scale > 0 + return [x, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=5e-3) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(3) + def testGumbelRCdf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.gumbel_r.cdf + lax_fun = lsp_stats.gumbel_r.cdf + + def args_maker(): + x, loc, scale = map(rng, shapes, dtypes) + scale = np.abs(scale) + np.array(0.1, dtype=scale.dtype) # Ensure scale > 0 + return [x, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=5e-3) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(3) + def testGumbelRPpf(self, shapes, dtypes): + rng_p = jtu.rand_uniform(self.rng(), low=0.01, high=0.99) + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.gumbel_r.ppf + lax_fun = lsp_stats.gumbel_r.ppf + + def args_maker(): + p = rng_p(shapes[0], dtypes[0]) + loc = rng(shapes[1], dtypes[1]) + scale = rng(shapes[2], dtypes[2]) + scale = np.abs(scale) + np.array(0.1, dtype=scale.dtype) # Ensure scale > 0 + return [p, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=5e-3) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(3) + def testGumbelRSf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.gumbel_r.sf + lax_fun = lsp_stats.gumbel_r.sf + + def args_maker(): + x, loc, scale = map(rng, shapes, dtypes) + scale = np.abs(scale) + np.array(0.1, dtype=scale.dtype) # Ensure scale > 0 + return [x, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=5e-3) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(3) + def testGumbelRLogSf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.gumbel_r.logsf + lax_fun = lsp_stats.gumbel_r.logsf + + def args_maker(): + x, loc, scale = map(rng, shapes, dtypes) + scale = np.abs(scale) + np.array(0.1, dtype=scale.dtype) # Ensure scale > 0 + return [x, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=5e-3) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(3) + def testGumbelLLogPdf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.gumbel_l.logpdf + lax_fun = lsp_stats.gumbel_l.logpdf + + def args_maker(): + x, loc, scale = map(rng, shapes, dtypes) + scale = np.abs(scale) + np.array(0.1, dtype=scale.dtype) # Ensure scale > 0 + return [x, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=5e-3) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(3) + def testGumbelLPdf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.gumbel_l.pdf + lax_fun = lsp_stats.gumbel_l.pdf + + def args_maker(): + x, loc, scale = map(rng, shapes, dtypes) + scale = np.abs(scale) + np.array(0.1, dtype=scale.dtype) # Ensure scale > 0 + return [x, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=5e-3) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(3) + def testGumbelLLogCdf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.gumbel_l.logcdf + lax_fun = lsp_stats.gumbel_l.logcdf + + def args_maker(): + x, loc, scale = map(rng, shapes, dtypes) + scale = np.abs(scale) + np.array(0.1, dtype=scale.dtype) # Ensure scale > 0 + return [x, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=5e-3) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(3) + def testGumbelLCdf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.gumbel_l.cdf + lax_fun = lsp_stats.gumbel_l.cdf + + def args_maker(): + x, loc, scale = map(rng, shapes, dtypes) + scale = np.abs(scale) + np.array(0.1, dtype=scale.dtype) # Ensure scale > 0 + return [x, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=5e-3) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(3) + def testGumbelLPpf(self, shapes, dtypes): + rng_p = jtu.rand_uniform(self.rng(), low=0.01, high=0.99) + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.gumbel_l.ppf + lax_fun = lsp_stats.gumbel_l.ppf + + def args_maker(): + p = rng_p(shapes[0], dtypes[0]) + loc = rng(shapes[1], dtypes[1]) + scale = rng(shapes[2], dtypes[2]) + scale = np.abs(scale) + np.array(0.1, dtype=scale.dtype) # Ensure scale > 0 + return [p, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=5e-3) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(3) + def testGumbelLSf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.gumbel_l.sf + lax_fun = lsp_stats.gumbel_l.sf + + def args_maker(): + x, loc, scale = map(rng, shapes, dtypes) + scale = np.abs(scale) + np.array(0.1, dtype=scale.dtype) # Ensure scale > 0 + return [x, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=5e-3) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(3) + def testGumbelLLogSf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.gumbel_l.logsf + lax_fun = lsp_stats.gumbel_l.logsf + + def args_maker(): + x, loc, scale = map(rng, shapes, dtypes) + scale = np.abs(scale) + np.array(0.1, dtype=scale.dtype) # Ensure scale > 0 + return [x, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=5e-3) + self._CompileAndCheck(lax_fun, args_maker) + + # Edge case tests + def testGumbelRPdfZero(self): + # Test at specific values + self.assertAllClose( + osp_stats.gumbel_r.pdf(0.0, 0.0, 1.0), lsp_stats.gumbel_r.pdf(0.0, 0.0, 1.0), atol=1E-6) + + def testGumbelLPdfZero(self): + # Test at specific values + self.assertAllClose( + osp_stats.gumbel_l.pdf(0.0, 0.0, 1.0), lsp_stats.gumbel_l.pdf(0.0, 0.0, 1.0), atol=1E-6) + @jtu.sample_product( inshape=[(50,), (3, 50), (2, 12)], dtype=jtu.dtypes.floating, @@ -1664,9 +1997,6 @@ def evaluate_kde(kde, x): message="All axis-slices of one or more sample arguments are too small", ) def testMode(self, shape, dtype, axis, contains_nans, keepdims): - if scipy_version < (1, 9, 0) and keepdims != True: - self.skipTest("scipy < 1.9.0 only support keepdims == True") - if contains_nans: rng = jtu.rand_some_nan(self.rng()) else: @@ -1675,25 +2005,7 @@ def testMode(self, shape, dtype, axis, contains_nans, keepdims): def scipy_mode_wrapper(a, axis=0, nan_policy='propagate', keepdims=None): """Wrapper to manage the shape discrepancies between scipy and jax""" - if scipy_version < (1, 11, 0) and a.size == 0: - if keepdims: - if axis == None: - output_shape = tuple(1 for _ in a.shape) - else: - output_shape = tuple(1 if i == axis else s for i, s in enumerate(a.shape)) - else: - if axis == None: - output_shape = () - else: - output_shape = np.delete(np.array(a.shape, dtype=np.int64), axis) - t = dtypes.canonicalize_dtype(jax.numpy.float_) - return (np.full(output_shape, np.nan, dtype=t), - np.zeros(output_shape, dtype=t)) - - if scipy_version < (1, 9, 0): - result = osp_stats.mode(a, axis=axis, nan_policy=nan_policy) - else: - result = osp_stats.mode(a, axis=axis, nan_policy=nan_policy, keepdims=keepdims) + result = osp_stats.mode(a, axis=axis, nan_policy=nan_policy, keepdims=keepdims) if a.size != 0 and axis == None and keepdims == True: output_shape = tuple(1 for _ in a.shape) @@ -1748,16 +2060,78 @@ def testSEM(self, shape, dtype, axis, ddof, nan_policy, keepdims): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] - kwds = {} if scipy_version < (1, 11) else {'keepdims': keepdims} scipy_fun = partial(osp_stats.sem, axis=axis, ddof=ddof, nan_policy=nan_policy, - **kwds) + keepdims=keepdims) lax_fun = partial(lsp_stats.sem, axis=axis, ddof=ddof, nan_policy=nan_policy, - **kwds) + keepdims=keepdims) tol_spec = {np.float32: 2e-4, np.float64: 5e-6} tol = jtu.tolerance(dtype, tol_spec) self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, atol=tol) self._CompileAndCheck(lax_fun, args_maker, atol=tol) + @jtu.sample_product( + shape=[(), (5,), (3, 4)], + dtype=jtu.dtypes.floating, + ) + def testPoissonEntropy(self, shape, dtype): + rng = jtu.rand_positive(self.rng()) + scipy_fun = osp_stats.poisson.entropy + lax_fun = lsp_stats.poisson.entropy + + args_maker = lambda: [rng(shape, dtype)] + tol = ({np.float32: 1e-2, np.float64: 1e-4} if jtu.test_device_matches(["tpu"]) + else {np.float32: 2e-4, np.float64: 5e-6}) + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,check_dtypes=False, tol=tol) + self._CompileAndCheck(lax_fun, args_maker, rtol=1e-4) + + @genNamedParametersNArgs(2) + def testPoissonEntropyWithLoc(self, shapes, dtypes): + rng = jtu.rand_positive(self.rng()) + scipy_fun = lambda mu, loc: osp_stats.poisson.entropy(mu, loc=loc) + lax_fun = lambda mu, loc: lsp_stats.poisson.entropy(mu, loc) + + args_maker = lambda: [rng(shapes[0], dtypes[0]), rng(shapes[1], dtypes[1])] + tol = ({np.float32: 1e-2, np.float64: 1e-4} if jtu.test_device_matches(["tpu"]) + else {np.float32: 2e-4, np.float64: 5e-6}) + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=tol) + self._CompileAndCheck(lax_fun, args_maker, rtol=1e-4) + + @jtu.sample_product( + dtype=jtu.dtypes.floating, + ) + def testPoissonEntropyEdgeCases(self, dtype): + """Test edge cases: invalid mu and very small mu""" + tol = ({np.float32: 1e-2, np.float64: 1e-4} if jtu.test_device_matches(["tpu"]) + else {np.float32: 2e-4, np.float64: 5e-6}) + + # Invalid mu (should return NaN) + invalid_mu = jnp.array([-1.0, 0.0, -5.0], dtype=dtype) + jax_result = lsp_stats.poisson.entropy(invalid_mu) + scipy_result = osp_stats.poisson.entropy(np.array(invalid_mu)) + self.assertAllClose(jax_result, scipy_result, check_dtypes=False, rtol=tol) + + # Very small mu + small_mu = jnp.array([0.01, 0.1, 0.5], dtype=dtype) + jax_result = lsp_stats.poisson.entropy(small_mu) + scipy_result = osp_stats.poisson.entropy(np.array(small_mu)) + self.assertAllClose(jax_result, scipy_result,check_dtypes=False, rtol=tol) + + @jtu.sample_product( + dtype=jtu.dtypes.floating, + ) + def testPoissonEntropyRegimes(self, dtype): + """Test all three computational regimes""" + mu = jnp.array([2.0, 5.0, 9.0, 15.0, 50.0, 99.0, 100.0, 200.0, 500.0], dtype=dtype) + scipy_fun = lambda m: osp_stats.poisson.entropy(m) + lax_fun = lambda m: lsp_stats.poisson.entropy(m) + args_maker = lambda: [mu] + + tol = ({np.float32: 1e-2, np.float64: 1e-4} if jtu.test_device_matches(["tpu"]) + else {np.float32: 2e-4, np.float64: 5e-6}) + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,check_dtypes=False, tol=tol) + self._CompileAndCheck(lax_fun, args_maker, rtol=1e-4) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/shape_poly_test.py b/tests/shape_poly_test.py index 6d1ffe744ed9..c9ebfe0590a0 100644 --- a/tests/shape_poly_test.py +++ b/tests/shape_poly_test.py @@ -14,39 +14,38 @@ from __future__ import annotations -import enum +import collections from collections.abc import Callable, Sequence import cProfile +import enum +from functools import partial import itertools import math +import operator as op import os from pstats import Stats +import re from typing import Any import unittest from absl import logging from absl.testing import absltest -import collections -import functools -from functools import partial -import operator as op -import re - import jax from jax import export -from jax.experimental import pjit from jax import lax import jax.numpy as jnp from jax import ops from jax import random from jax._src import config from jax._src import core +from jax._src import dtypes from jax._src import test_util as jtu from jax._src.export import shape_poly from jax._src.export import shape_poly_decision from jax._src.lax import lax as lax_internal from jax._src.lax import control_flow as lax_control_flow +from jax._src.lax import utils as lax_utils from jax._src.state import discharge from jax._src.state import primitives as ref_primitives @@ -77,17 +76,6 @@ def _expect(*, current, best): def _bounds(e: shape_poly.DimSize) -> tuple[float, float]: return shape_poly._bounds_decision(e, shape_poly.BoundsPrecision.BEST) -def _assert_equal_bounds(tst: jtu.JaxTestCase, - e: shape_poly.DimSize, - bounds: tuple[float, float]): - if isinstance(e, shape_poly._DimExpr): - scope = e.scope - else: - scope = shape_poly.SymbolicScope() - decision = shape_poly._make_decision_state(scope) - found_bounds = decision.bounds(e) - tst.assertEqual(bounds, found_bounds) - def _start_profile(tst: jtu.JaxTestCase): tst.prof = None if os.getenv("JAX_PROFILE_TEST", False): @@ -673,7 +661,6 @@ def test_compare_ge(self): self.assertGreaterEqual(-8, -poly) self.assertGreater(-7, -poly) - def test_int_results(self): # Whenever the result is an integer, it should be represented as a # Python integer, not a symbolic dimension. @@ -718,7 +705,6 @@ def test_divmod(self, *, dividend, quotient, divisor, remainder): self.sampled_assertion(remainder, lambda *args: divmod(*args)[1], dividend, divisor) - def test_unit_combine_term_with_constraints(self): a, b, c, d, e = shape_poly.symbolic_shape("a, b, c, d, e", constraints=[ @@ -787,7 +773,6 @@ def _m(e: shape_poly._DimExpr) -> shape_poly._DimTerm: set(decision.combine_term_with_existing(_m(d), 2, scope=scope, only_smaller_than_t=True))) - def test_dilate_dim(self): """0 if d == 0 else 1 + dilation * (d - 1))""" a, = shape_poly.symbolic_shape("a,") @@ -961,7 +946,7 @@ def test_constraints_ge_complex_gen(self, self.assertEqual(bounds, _bounds(exp)) def test_constraints_ge_override(self): - # Some constaints override other + # Some constraints override other a, b = shape_poly.symbolic_shape("a, b", constraints=("a >= 5", "b <= 16", "a >= 10", "b <= 10")) @@ -979,7 +964,7 @@ def test_constraint_eq_0(self): self.assertIs(d, 5) def test_constraints_eq_1(self): - # Some constaints override other + # Some constraints override other a, b, c = shape_poly.symbolic_shape("a, b, c", constraints=("max(a, b) == c",)) self.assertEqual(_bounds(core.max_dim(a, b) - c + 3), (3, 3)) @@ -1002,7 +987,7 @@ def test_constraints_eq_3(self): self.assertEqual(_bounds(b), (2, np.inf)) # TODO: the following ought to work, but the way we wrote the equality # constraint, `min(b, 2)` gets rewritten to `2`. - #self.assertEqual(core.min_dim(a, b), b - core.min_dim(b, 2)) + # self.assertEqual(core.min_dim(a, b), b - core.min_dim(b, 2)) def test_constraints_eq_4(self): # Equalities of a variable with an expression @@ -1254,6 +1239,25 @@ def test_constraints_different_scope(self): "Invalid mixing of symbolic scopes"): o(a, a1) + def test_int_dtype_for_shape(self): + shape = shape_poly.symbolic_shape("2, a") + self.assertEqual( + dtypes.default_int_dtype(), + lax_utils.int_dtype_for_shape(shape, signed=True), + ) + self.assertEqual( + dtypes.default_uint_dtype(), + lax_utils.int_dtype_for_shape(shape, signed=False), + ) + self.assertEqual( + dtypes.default_int_dtype(), + lax_utils.int_dtype_for_dim(shape[1], signed=True), + ) + self.assertEqual( + dtypes.default_uint_dtype(), + lax_utils.int_dtype_for_dim(shape[1], signed=False), + ) + class PolyHarness(Harness): """Tests a function with shape polymorphism. @@ -1531,8 +1535,7 @@ def test_pytree(self): # Arguments are of the form [([x00, x01], [x10]), dict(a=ya, b=yb)] def add_all_jax(x_pair_of_list, y_dict): x_list_0, x_list_1 = x_pair_of_list - return functools.reduce(op.add, - x_list_0 + x_list_1 + [y_dict["a"], y_dict["b"]]) + return sum(x_list_0 + x_list_1 + [y_dict["a"], y_dict["b"]]) x = np.arange(4, dtype=_f32) args = (([x, x], [x]), dict(a=x, b=x)) @@ -1582,8 +1585,7 @@ def test_pytree_errors(self, polymorphic_shapes=("b", "b", "b")): args = (([x, x], [x]), dict(a=x, b=x)) def add_all_jax(x_pair_of_list, y_dict): x_list_0, x_list_1 = x_pair_of_list - return functools.reduce(op.add, - x_list_0 + x_list_1 + [y_dict["a"], y_dict["b"]]) + return sum(x_list_0 + x_list_1 + [y_dict["a"], y_dict["b"]]) with self.assertRaisesRegex(ValueError, "pytree structure error"): check_shape_poly(self, @@ -1966,7 +1968,7 @@ def collision_hash(obj): setattr(shape_poly._DimExpr, "__hash__", collision_hash) xs = [np.ones((3, 5, 6), dtype=np.float32)] - f_toconvert = jax.vmap(pjit.pjit(f_jax)) + f_toconvert = jax.vmap(jax.jit(f_jax)) res_1 = check_shape_poly(self, f_toconvert, arg_descriptors=xs, polymorphic_shapes=["..."]) res_2 = check_shape_poly(self, f_toconvert, arg_descriptors=xs, @@ -3463,6 +3465,26 @@ def f(x_ref): k=x.shape[1]), arg_descriptors=[RandArg((3, 4), _f32)], polymorphic_shapes=["m, n"]), + [ + PolyHarness("tril_indices", f"{has_k=}_{has_m=}", + lambda x: jnp.tril_indices(x.shape[0], # n + k=x.shape[0] - 1 if has_k else 0, + m=x.shape[1] if has_m else None), + arg_descriptors=[RandArg((3, 4), _f32)], + polymorphic_shapes=["n, m"]) + for has_k in [True, False] + for has_m in [True, False] + ], + [ + PolyHarness("triu_indices", f"{has_k=}_{has_m=}", + lambda x: jnp.triu_indices(x.shape[0], # n + k=x.shape[0] - 1 if has_k else 0, + m=x.shape[1] if has_m else None), + arg_descriptors=[RandArg((3, 4), _f32)], + polymorphic_shapes=["n, m"]) + for has_k in [True, False] + for has_m in [True, False] + ], [ PolyHarness("triangular_solve", f"shape={jtu.format_shape_dtype_string(a_shape, dtype)}_{left_side=}_{a_poly=}_{b_poly=}", diff --git a/tests/shard_alike_test.py b/tests/shard_alike_test.py index 2ad3e089e662..aeb218b478ad 100644 --- a/tests/shard_alike_test.py +++ b/tests/shard_alike_test.py @@ -19,7 +19,7 @@ from jax._src import test_util as jtu from jax.sharding import NamedSharding, PartitionSpec as P from jax.experimental.shard_alike import shard_alike -from jax.experimental.shard_map import shard_map +from jax._src.shard_map import shard_map jax.config.parse_flags_with_absl() jtu.request_cpu_devices(8) @@ -146,7 +146,7 @@ def g(x): @jax.jit def f(x): y = x @ x.T - s_out = shard_map(g, mesh, in_specs=P('x', 'y'), + s_out = shard_map(g, mesh=mesh, in_specs=P('x', 'y'), out_specs=P(None, 'y'))(y) z = s_out.T @ s_out return shard_alike(y, z) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index f8d5a11e842f..676d53bb869b 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -36,18 +36,19 @@ from jax._src import config from jax._src import core from jax._src import prng +from jax._src.shard_map import shard_map from jax._src import test_util as jtu -from jax._src.lib.mlir.dialects import sdy from jax._src.util import safe_zip, safe_map, partition_list, merge_lists from jax._src.ad_checkpoint import saved_residuals -from jax._src.mesh import AxisType +from jax._src.mesh import AxisType, get_abstract_mesh, empty_concrete_mesh +from jax._src.lax.parallel import all_gather_invariant from jax._src.interpreters import partial_eval as pe from jax._src import linear_util as lu from jax._src import tree_util +from jax.custom_derivatives import SymbolicZero import jax.numpy as jnp from jax.experimental.custom_partitioning import custom_partitioning -from jax.experimental.shard_map import shard_map config.parse_flags_with_absl() @@ -57,14 +58,14 @@ zip, unsafe_zip = safe_zip, zip # Helper for some tests. -def create_inputs(a_sharding, b_sharding): +def create_inputs(a_sharding, b_sharding, dtype=None): mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z')) b, e, f = 8, 8, 8 # pylint: disable=invalid-name m1 = jax.device_put( - jnp.arange(b * e).reshape((b, e)), + jnp.arange(b * e, dtype=dtype).reshape((b, e)), jax.sharding.NamedSharding(mesh, a_sharding)) m2 = jax.device_put( - jnp.arange(e * f).reshape((e, f)), + jnp.arange(e * f, dtype=dtype).reshape((e, f)), jax.sharding.NamedSharding(mesh, b_sharding)) return mesh, m1, m2 @@ -82,7 +83,7 @@ def identity(x): def fwd(a): c = shard_map( identity, - mesh, + mesh=mesh, in_specs=(P('z', ('x', 'y')),), out_specs=P('z', ('x', 'y')))(a) return c @@ -94,17 +95,13 @@ def test_all_gather(self): mesh, a, _ = create_inputs(P('z', ('x', 'y')), P(None, None)) assert a.addressable_data(0).shape == (4, 2) - # NOTE(mattjj): to use out_specs=P(None, ('x', 'y')), we need to use - # all_gather_invariant primitive, which differs in its output replication - # type compared to all_gather. @jax.jit @partial(shard_map, mesh=mesh, in_specs=(P('z', ('x', 'y')),), out_specs=P('z', ('x', 'y'))) def fwd(a): - return ( - lax.all_gather(a, 'z', axis=0, tiled=True), - lax.all_gather(a, ('x', 'y'), axis=-1, tiled=True), - ) + return (lax.all_gather(a, 'z', axis=0, tiled=True), + lax.all_gather(a, ('x', 'y'), axis=-1, tiled=True)) + c, d = fwd(a) self.assertEqual(c.addressable_data(0).shape, (8, 2)) for i, a_shard in enumerate(np.split(a, 4, axis=1)): @@ -113,6 +110,64 @@ def fwd(a): for i, a_shard in enumerate(np.split(a, 2, axis=0)): self.assertAllClose(d.addressable_data(i), a_shard) + def test_all_gather_invariant_basic(self): + mesh = jtu.create_mesh((4,), 'x') + arr = jnp.arange(8.) + + @jax.jit + @shard_map(mesh=mesh, in_specs=P('x'), out_specs=P()) + def f(a): + out = all_gather_invariant(a, 'x', tiled=True) + self.assertEqual(out.aval.vma, set()) + return out + + out = f(arr) + self.assertArraysEqual(out, arr) + + jtu.check_grads(f, (arr,), order=2) + + def g(x): + return f(x).sum() + out = jax.jit(jax.grad(g))(arr) + self.assertEqual(out.shape, (8,)) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + + def test_all_gather_invariant_complex(self): + mesh, a, _ = create_inputs(P('z', ('x', 'y')), P(None, None), + dtype=np.float32) + assert a.addressable_data(0).shape == (4, 2) + + @jax.jit + @shard_map(mesh=mesh, in_specs=(P('z', ('x', 'y')),), + out_specs=(P(None, ('x', 'y')), P('z'))) + def f(a): + c = all_gather_invariant(a, 'z', axis=0, tiled=True) + self.assertEqual(jax.typeof(c).vma, {'x', 'y'}) + d = all_gather_invariant(a, ('x', 'y'), axis=-1, tiled=True) + self.assertEqual(jax.typeof(d).vma, {'z'}) + return c, d + + c, d = f(a) + + self.assertEqual(c.addressable_data(0).shape, (8, 2)) + for i, a_shard in enumerate(np.split(a, 4, axis=1)): + self.assertAllClose(c.addressable_data(2 * i), a_shard) + + self.assertEqual(d.addressable_data(0).shape, (4, 8)) + for i, a_shard in enumerate(np.split(a, 2, axis=0)): + self.assertAllClose(d.addressable_data(i), a_shard) + + def g(x): + return f(x)[0].sum() + + out1 = jax.jit(jax.grad(g))(a) + self.assertEqual(out1.shape, (8, 8)) + self.assertEqual(out1.sharding, NamedSharding(mesh, P('z', ('x', 'y')))) + + out2 = jax.grad(g)(a) + self.assertEqual(out2.shape, (8, 8)) + self.assertEqual(out2.sharding, NamedSharding(mesh, P('z', ('x', 'y')))) + def test_all_gather_with_axis_index_groups(self): mesh, a, _ = create_inputs(P('x', ('y', 'z')), P(None, None)) @@ -135,21 +190,85 @@ def fwd(a): self.assertAllClose(c.addressable_data(4 * i + 2 * j), block) self.assertAllClose(c.addressable_data(4 * i + 2 * j + 1), block) - def test_matmul_partial(self): - raise unittest.SkipTest("invalid replication asserted by out_spec?") - - mesh, a, b = create_inputs(P('z', 'y'), P('y', None)) + def test_with_static_arg(self): + mesh, a, _ = create_inputs(P('x', 'y'), P(None, None)) assert a.addressable_data(0).shape == (4, 4) + def add_one(x, static_shape): + return x + jnp.ones(static_shape, dtype=x.dtype) + + @jax.jit(static_argnums=(1,)) + def fun(a, static_shape): + return shard_map( + add_one, + mesh=mesh, + in_specs=(P('x', 'y'), None), + out_specs=P('x', 'y'))(a, static_shape) + + self.assertAllClose(a + 1, fun(a, a.addressable_data(0).shape)) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_matmul_unreduced(self, mesh): + np_inp1 = np.arange(8.).reshape(2, 4) + np_inp2 = np.arange(8.).reshape(4, 2) + arr1 = jax.device_put(np_inp1, P('x', 'y')) + arr2 = jax.device_put(np_inp2, P('y', None)) + @jax.jit - @partial(shard_map, mesh=mesh, - in_specs=(P('z', 'y'), P('y', None)), out_specs=P('z', None)) - def fwd(a): - c = jnp.matmul(a, b) # [B.z, F] {y.unreduced} + @shard_map(in_specs=(P('x', 'y'), P('y', None)), + out_specs=P('x', None, unreduced={'y'})) + def f(a, b): + c = jnp.einsum('ab,bc->ac', a, b) + self.assertEqual(c.aval.vma, {'x', 'y'}) return c - c = fwd(a) - self.assertEqual(c.addressable_data(0).shape, (4, 8)) + out = f(arr1, arr2) + self.assertEqual(out.shape, (2, 2)) + self.assertEqual(out.sharding, + NamedSharding(mesh, P('x', None, unreduced={'y'}))) + + expected_shards = [np.array([[2., 3.]]), np.array([[26., 31.]]), + np.array([[10., 19.]]), np.array([[66., 79.]])] + for s, es in zip(out.addressable_shards, expected_shards): + self.assertEqual(s.data.shape, (1, 2)) + self.assertArraysEqual(s.data, es) + + resharded_out = jax.sharding.reshard(out, P('x', None)) + self.assertArraysEqual(resharded_out, np_inp1 @ np_inp2) + + def g(x, y): + out = f(x, y) + return jax.sharding.reshard(out, P('x', None)).sum() + + out1, out2 = jax.jit(jax.grad(g, argnums=(0, 1)))(arr1, arr2) + self.assertEqual(out1.sharding, NamedSharding(mesh, P('x', 'y'))) + self.assertEqual(out2.sharding, NamedSharding(mesh, P('y', None))) + + with jax.set_mesh(jtu.create_mesh((1,), 'x')): + ex_out1, ex_out2 = jax.jit(jax.grad(lambda x, y: (x @ y).sum(), + argnums=(0, 1)))(np_inp1, np_inp2) + self.assertArraysAllClose(ex_out1, out1, rtol=2e-4) + self.assertArraysAllClose(ex_out2, out2, rtol=2e-4) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_matmul_unreduced_error(self, mesh): + np_inp1 = np.arange(8.).reshape(2, 4) + np_inp2 = np.arange(8.).reshape(4, 2) + arr1 = jax.device_put(np_inp1, P('x', 'y')) + arr2 = jax.device_put(np_inp2, P('y', None)) + + @jax.jit + @shard_map(in_specs=(P('x', None), P(None, None)), + out_specs=P('x', None, unreduced={'y'})) + def f(a, b): + c = jnp.einsum('ab,bc->ac', a, b) + self.assertEqual(c.aval.vma, {'x'}) + return c + + with self.assertRaisesRegex( + ValueError, + "vary_unreduced_cast is a Varying->Unreduced collective"): + f(arr1, arr2) def test_matmul_reduce_scatter(self): mesh, a, b = create_inputs(P('z', 'y'), P('y', None)) @@ -217,13 +336,298 @@ def test_collective_permute(self): shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x', None) ) def fwd(a): - axis_size = lax.psum(1, 'x') + axis_size = lax.axis_size('x') perm = [(j, (j + 1) % axis_size) for j in range(axis_size)] return lax.ppermute(a, 'x', perm=perm) c = fwd(a) self.assertAllClose(c[1, :], a[0, :]) + @jtu.run_on_devices("gpu") + def test_psend_precv_basic_two_gpus(self): + mesh = jtu.create_mesh((2,), 'x') + a = jax.device_put( + jnp.arange(2 * 2, dtype=jnp.float32).reshape((2, 2)), + jax.sharding.NamedSharding(mesh, P('x', None))) + weights = jax.random.uniform( + key=jax.random.key(0), shape=(2, 1), dtype=jnp.float32) + + @jax.jit + @partial( + shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x', None) + ) + def fwd(a): + return_dtype_and_shape = jax.ShapeDtypeStruct(a.shape, a.dtype) + fwd_token = jax.lax.psend(a, 'x', [(0, 1)]) + data = jax.lax.precv( + fwd_token, return_dtype_and_shape, 'x', [(0, 1)]) + # Here we use an optimization barrier to enforce an arbitrary ordering of + # operations. This makes sure the mat mul only happens after the recv is + # complete. + weights_, _ = ( + jax.lax.optimization_barrier( + (weights, data) + ) + ) + res = jnp.dot(weights_, data) + + # send the compute result back to the first device + bwd_token = jax.lax.psend( + res, + axis_name='x', + perm=[(1, 0)], + ) + + bwd_data = jax.lax.precv( + bwd_token, + out_shape=return_dtype_and_shape, + axis_name='x', + perm=[(1, 0)] + ) + return bwd_data + + c = fwd(a) + self.assertEqual(c.shape, a.shape) + + @jtu.run_on_devices("gpu") + def test_psend_precv_basic_with_no_deadlock_cycle(self): + mesh = jtu.create_mesh((8,), 'x') + a = jax.device_put( + jnp.arange(8 * 8, dtype=jnp.float32).reshape((8, 8)), + jax.sharding.NamedSharding(mesh, P('x', None))) + weights = jax.random.uniform( + key=jax.random.key(0), shape=(8, 1), dtype=jnp.float32) + + @jax.jit + @partial( + jax.shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x', None) + ) + def fwd(a): + return_dtype_and_shape = jax.ShapeDtypeStruct(a.shape, a.dtype) + + # We define the "forward edge" to be the device-to-device communication + # originating from device 0 in increasing indices. + fwd_token = jax.lax.psend( + a, + axis_name="x", + perm=[(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)], + ) + + data = jax.lax.precv( + fwd_token, + out_shape=return_dtype_and_shape, + axis_name="x", + perm=[(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)], + ) + + # Here we use an optimization barrier to enforce an arbitrary ordering of + # collectives. This will make sure compute happens after recv on the forward + # edge, and by extension will make sure the send on the back edge happens + # after the recv on the forward edge. Without this optimization barrier, the + # send on the backward edge might slip before the forward edge recv ops are + # completed, and will cause a deadlock. + weights_, _ = ( + jax.lax.optimization_barrier( + (weights, data) + ) + ) + res = jnp.dot(weights_, data) + + # send the compute result back to the first device + bwd_token = jax.lax.psend( + res, + axis_name="x", + perm=[(7, 0)], + ) + + bwd_data = jax.lax.precv( + bwd_token, + out_shape=return_dtype_and_shape, + axis_name="x", + perm=[(7, 0)] + ) + return bwd_data + + c = fwd(a) + self.assertEqual(c.shape, a.shape) + + @jtu.run_on_devices('gpu') + def test_psend_precv_basic_with_deadlock_cycle(self): + mesh = jtu.create_mesh((2,), 'x') + a = jax.device_put( + jnp.arange(2 * 2, dtype=jnp.float32).reshape((2, 2)), + jax.sharding.NamedSharding(mesh, P('x', None)), + ) + + @jax.jit + @partial( + jax.shard_map, + mesh=mesh, + in_specs=P('x', None), + out_specs=P('x', None), + ) + def fwd(a): + return_dtype_and_shape = jax.ShapeDtypeStruct(a.shape, a.dtype) + fwd_token = jax.lax.psend( + a, + axis_name='x', + perm=[(0, 1), (1, 0)], + ) + + data = jax.lax.precv( + fwd_token, + out_shape=return_dtype_and_shape, + axis_name='x', + perm=[(0, 1), (1, 0)], + ) + return data + expected_error_message = ( + 'Expected send and recv instructions to have non-cyclical' + ' source-target pairs' + ) + with self.assertRaisesRegex( + jax.errors.JaxRuntimeError, expected_error_message + ): + fwd(a) + + @jtu.run_on_devices('gpu') + def test_psend_precv_basic_with_dangling_recv(self): + mesh = jtu.create_mesh((2,), 'x') + a = jax.device_put( + jnp.arange(2 * 2, dtype=jnp.float32).reshape((2, 2)), + jax.sharding.NamedSharding(mesh, P('x', None)), + ) + + @jax.jit + @partial( + jax.shard_map, + mesh=mesh, + in_specs=P('x', None), + out_specs=P('x', None), + ) + def fwd(a): + return_dtype_and_shape = jax.ShapeDtypeStruct(a.shape, a.dtype) + data = jax.lax.precv( + jax.lax.create_token(), + out_shape=return_dtype_and_shape, + axis_name='x', + perm=[(0, 1)], + ) + return data + + expected_error_message = 'Expected send to match recv' + with self.assertRaisesRegex( + jax.errors.JaxRuntimeError, expected_error_message + ): + fwd(a) + + @jtu.run_on_devices('gpu') + def test_psend_precv_basic_with_non_matching_source_target_pairs(self): + mesh = jtu.create_mesh((2,), 'x') + a = jax.device_put( + jnp.arange(2 * 2, dtype=jnp.float32).reshape((2, 2)), + jax.sharding.NamedSharding(mesh, P('x', None)), + ) + + @jax.jit + @partial( + jax.shard_map, + mesh=mesh, + in_specs=P('x', None), + out_specs=P('x', None), + ) + def fwd(a): + return_dtype_and_shape = jax.ShapeDtypeStruct(a.shape, a.dtype) + fwd_token = jax.lax.psend( + a, + axis_name='x', + perm=[(0, 1)], + ) + + data = jax.lax.precv( + fwd_token, + out_shape=return_dtype_and_shape, + axis_name='x', + perm=[(1, 0)], + ) + return data + + expected_error_message = ( + 'Deadlock detected. Last checked instructions: %psend' + ) + with self.assertRaisesRegex( + jax.errors.JaxRuntimeError, expected_error_message + ): + fwd(a) + + @jtu.run_on_devices('gpu') + def test_psend_precv_basic_with_duplicate_source_target_pairs(self): + mesh = jtu.create_mesh((2,), 'x') + a = jax.device_put( + jnp.arange(2 * 2, dtype=jnp.float32).reshape((2, 2)), + jax.sharding.NamedSharding(mesh, P('x', None)), + ) + + @jax.jit + @partial( + jax.shard_map, + mesh=mesh, + in_specs=P('x', None), + out_specs=P('x', None), + ) + def fwd(a): + return_dtype_and_shape = jax.ShapeDtypeStruct(a.shape, a.dtype) + fwd_token = jax.lax.psend( + a, + axis_name='x', + perm=[(0, 1), (0, 1)], + ) + + data = jax.lax.precv( + fwd_token, + out_shape=return_dtype_and_shape, + axis_name='x', + perm=[(1, 0)], + ) + return data + + expected_error_message = ( + 'psend sources and destinations must be unique' + ) + with self.assertRaisesRegex( + ValueError, expected_error_message + ): + fwd(a) + + @jtu.run_on_devices("gpu") + def test_psend_precv_reverse_two_gpus(self): + mesh = jtu.create_mesh((2,), 'x') + a = jax.device_put( + jnp.arange(2 * 2, dtype=jnp.float32).reshape((2, 2)), + jax.sharding.NamedSharding(mesh, P('x', None))) + @jax.jit + @partial( + jax.shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x', None) + ) + def fwd(a): + return_dtype_and_shape = jax.ShapeDtypeStruct(a.shape, a.dtype) + dummy_data = jax.lax.precv( + jax.lax.create_token(), + out_shape=return_dtype_and_shape, + axis_name="x", + perm=[(0, 1)], + ) + + _ = jax.lax.psend( + dummy_data, + axis_name="x", + perm=[(0, 1)], + ) + return dummy_data + + c = fwd(a) + self.assertAllClose(c, jnp.zeros_like(a)) + def test_collective_permute_with_multiple_axis_names(self): mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z')) a = jax.device_put( @@ -239,8 +643,8 @@ def test_collective_permute_with_multiple_axis_names(self): out_specs=P('x', ('y', 'z')), ) def fwd(a): - xy_axis_size = lax.psum(1, ('x', 'y')) - yz_axis_size = lax.psum(1, ('y', 'z')) + xy_axis_size = lax.axis_size(('x', 'y')) + yz_axis_size = lax.axis_size(('y', 'z')) xy_perm = [(j, (j + 1) % xy_axis_size) for j in range(xy_axis_size)] yz_perm = [(j, (j + 1) % yz_axis_size) for j in range(yz_axis_size)] return ( @@ -289,6 +693,39 @@ def fwd(a): c = fwd(a) assert (c == jnp.reshape(a.T, (1, 64))).all() + @parameterized.named_parameters( + dict( + testcase_name='_partial_replicated', replicate_on_axes='x', + ), + dict( + testcase_name='_fully_replicated', + replicate_on_axes=('x', 'y'), + ), + ) + @jtu.run_on_devices("gpu") + def test_pbroadcast(self, replicate_on_axes): + mesh = jtu.create_mesh((4, 2), ('x', 'y')) + sharded_axes = set(mesh.axis_names) - set(replicate_on_axes) + sharded_axes = None if not sharded_axes else list(sharded_axes) + in_out_sharding = jax.sharding.NamedSharding(mesh, P(sharded_axes, None)) + a = jax.device_put(jnp.arange(16).reshape((4, 4)), in_out_sharding) + + @jax.jit + @partial( + shard_map, + mesh=mesh, + in_specs=(in_out_sharding.spec,), + out_specs=in_out_sharding.spec, + check_vma=False, + ) + def fwd(x): + axis_index = lax.axis_index(replicate_on_axes) + x = jnp.where(axis_index == 0, x + 1, x) + return lax.pbroadcast(x, replicate_on_axes, source=0) + + c = fwd(a) # Don't crash + self.assertAllClose(c, a + 1) + def test_all_to_all_with_axis_index_groups(self): mesh = jtu.create_mesh((4,), ('x',)) a = jax.device_put( @@ -367,7 +804,7 @@ def f(x): def test_jvp_basic(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) - g = shard_map(lambda x: jnp.sin(jnp.cos(x)), mesh, + g = shard_map(lambda x: jnp.sin(jnp.cos(x)), mesh=mesh, in_specs=(P('x', 'y'),), out_specs=P('x', 'y')) args = np.arange(4 * 4.).reshape(4, 4), jtu.check_grads(g, args, 2, ['fwd']) @@ -375,7 +812,7 @@ def test_jvp_basic(self): def test_linearize_basic(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) - g = shard_map(lambda x: jnp.sin(jnp.cos(x)), mesh, + g = shard_map(lambda x: jnp.sin(jnp.cos(x)), mesh=mesh, in_specs=(P('x', 'y'),), out_specs=P('x', 'y')) x = np.arange(4 * 4.).reshape(4, 4) @@ -389,7 +826,7 @@ def test_linearize_basic(self): def test_linearize_basic_repres(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) - g = shard_map(lambda x: jax.lax.sin(jax.lax.cos(x)), mesh, + g = shard_map(lambda x: jax.lax.sin(jax.lax.cos(x)), mesh=mesh, in_specs=(P('x',),), out_specs=P('x',)) x = np.arange(4.) @@ -403,7 +840,7 @@ def test_linearize_basic_repres(self): def test_linearize_basic_repres_jit(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) - g = shard_map(lambda x: jnp.sin(jnp.cos(x)), mesh, + g = shard_map(lambda x: jnp.sin(jnp.cos(x)), mesh=mesh, in_specs=(P('x',),), out_specs=P('x',)) x = np.arange(4.) @@ -422,7 +859,7 @@ def test_replication_checker_eager(self): def f(x): return 2 * x def g(x): - return shard_map(f, mesh, in_specs=(P('x', 'y'),), out_specs=P(None, 'y'))(x) + return shard_map(f, mesh=mesh, in_specs=(P('x', 'y'),), out_specs=P(None, 'y'))(x) with self.assertRaisesRegex(ValueError, 'statically inferred'): g(x) @@ -430,26 +867,24 @@ def g(x): def f2(x): return jax.lax.psum(x, 'x') def g2(x): - return shard_map(f2, mesh, in_specs=(P('x', 'y'),), out_specs=P(None, 'y'))(x) + return shard_map(f2, mesh=mesh, in_specs=(P('x', 'y'),), out_specs=P(None, 'y'))(x) _ = g2(x) # doesn't crash def test_replication_checker_jit(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) x = np.arange(8 * 8.).reshape(8, 8) - def f(x): - return 2 * x def g(x): - return shard_map(f, mesh, in_specs=(P('x', 'y'),), out_specs=P(None, 'y'))(x) + return shard_map(lambda x: x * 2, mesh=mesh, in_specs=P('x', 'y'), + out_specs=P(None, 'y'))(x) with self.assertRaisesRegex(ValueError, 'statically inferred'): jax.jit(g)(x) - def f2(x): - return jax.lax.psum(x, 'x') def g2(x): - return shard_map(f2, mesh, in_specs=(P('x', 'y'),), out_specs=P(None, 'y'))(x) - _ = jax.jit(g2)(x) # doesn't crash + return shard_map(lambda x: jax.lax.psum(x, 'x'), mesh=mesh, + in_specs=P('x', 'y'), out_specs=P(None, 'y'))(x) + jax.jit(g2)(x) # doesn't crash def test_process_env_traces(self): mesh = Mesh(np.array(jax.devices()[:4]), ('x',)) @@ -457,7 +892,7 @@ def test_process_env_traces(self): def g(x): y = (3. * x).sum() - z = shard_map(lambda x: 2 * x * y, mesh, + z = shard_map(lambda x: 2 * x * y, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'))(np.arange(8.)) return z @@ -475,13 +910,14 @@ def f(x): return -x def g(x): - return shard_map(f, mesh, in_specs=(P('x', 'y'),), out_specs=P('x', 'y'))(x) + return shard_map(f, mesh=mesh, in_specs=(P('x', 'y'),), out_specs=P('x', 'y'))(x) y = g(x) self.assertAllClose(y, -x, check_dtypes=False) def test_outer_jit_detects_shard_map_mesh(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) - f = shard_map(lambda x: x.reshape(1, *x.shape), mesh, P(), P('x')) + f = shard_map(lambda x: x.reshape(1, *x.shape), mesh=mesh, in_specs=P(), + out_specs=P('x')) _ = jax.jit(f)(jnp.array(2.0)) # doesn't crash def test_vmap_basic(self): @@ -489,7 +925,7 @@ def test_vmap_basic(self): x = jnp.arange(8 * 8.).reshape(8, 8) def g(x): - return shard_map(lambda x: 2. * x, mesh, + return shard_map(lambda x: 2. * x, mesh=mesh, in_specs=P('y'), out_specs=P('y'))(x) y = jax.vmap(g)(x) self.assertAllClose(y, 2 * x, check_dtypes=False) @@ -499,7 +935,7 @@ def test_vmap_basic_axis_name(self): x = jnp.arange(8 * 8.).reshape(8, 8) def g(x): - return shard_map(lambda x: 2. * x, mesh, + return shard_map(lambda x: 2. * x, mesh=mesh, in_specs=P('y'), out_specs=P('y'))(x) y = jax.vmap(g, axis_name='i')(x) self.assertAllClose(y, 2 * x, check_dtypes=False) @@ -509,7 +945,7 @@ def test_vmap_basic_axis_name_reuse_mesh_name(self): x = jnp.arange(8 * 8.).reshape(8, 8) def g(x): - return shard_map(lambda x: 2. * x, mesh, + return shard_map(lambda x: 2. * x, mesh=mesh, in_specs=P('y'), out_specs=P('y'))(x) y = jax.vmap(g, axis_name='x')(x) # NOTE reuse same 'x' as on mesh self.assertAllClose(y, 2 * x, check_dtypes=False) @@ -588,6 +1024,32 @@ def f(): x = f() self.assertAllClose(x, jnp.arange(4), check_dtypes=False) + def test_optimize_remat(self): + mesh = jtu.create_mesh((4,), 'x') + + @jax.custom_vjp + def f(x): + return jnp.tan(x) + + def f_fwd(x): + return jax.lax.psum(x, 'x'), (x,) + + def f_bwd(res, g): + x, = res + cos_x = jnp.cos(x) + return (cos_x * g,) + + f.defvjp(f_fwd, f_bwd, optimize_remat=True) + + @jax.jit + @jax.shard_map(mesh=mesh, in_specs=P(), out_specs=P()) + def temp(x): + out = jax.remat(f)(x) + out = out ** 2 + return out + + jax.grad(lambda x: temp(x).sum())(jnp.arange(4.)) + def test_remat_basic(self): # this tests remat-of-shmap mesh = Mesh(np.array(jax.devices()[:4]), ('x',)) @@ -630,6 +1092,7 @@ def f2(x): g2 = jax.grad(lambda x: f2(x).sum())(x) # doesn't crash self.assertAllClose(g2, jnp.cos(x), check_dtypes=False) + @jtu.skip_on_flag("jax_skip_slow_tests", True) def test_remat_scalar_residuals(self): mesh = Mesh(np.array(jax.devices()[:4]), ('x',)) @@ -662,29 +1125,37 @@ def test_check_rep_false_doesnt_hit_rep_rules(self): prim.def_impl(lambda: []) prim.def_abstract_eval(lambda: []) - @partial(shard_map, mesh=mesh, in_specs=(), out_specs=None, check_rep=True) + @partial(shard_map, mesh=mesh, in_specs=(), out_specs=None, check_vma=True) def f(): prim.bind() - with self.assertRaises(NotImplementedError): - f() - with self.assertRaises(NotImplementedError): - jax.jit(f)() - - @partial(shard_map, mesh=mesh, in_specs=(), out_specs=None, check_rep=False) + @partial(shard_map, mesh=mesh, in_specs=(), out_specs=None, check_vma=False) def f2(): prim.bind() f2() jax.jit(f2)() - @partial(shard_map, mesh=mesh, in_specs=(), out_specs=None, check_rep=False) + @partial(shard_map, mesh=mesh, in_specs=(), out_specs=None, check_vma=False) def f3(): jax.jit(prim.bind)() f3() jax.jit(f3)() + def test_multiple_result_primitive_with_none_sharding(self): + # https://github.com/jax-ml/jax/issues/27673 + xs = jnp.arange(20).reshape(2, 10) + mesh = jtu.create_mesh((2,), ("i",)) + y = shard_map( + lambda x: jnp.split(x.squeeze(), 2), + mesh=mesh, + in_specs=(None,), + out_specs=P("i"), + )(xs) + expected = jnp.repeat(xs, 2, axis=0).reshape(2, 2, 10) + self.assertArraysEqual(y, expected) + def test_vmap_spmd_axis_name(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) @@ -695,16 +1166,68 @@ def f(x): x = jnp.arange(4 * 4).reshape(4, 4) jaxpr = jax.make_jaxpr(jax.vmap(f, spmd_axis_name='y'))(x).jaxpr e, = jaxpr.eqns - self.assertIn('in_names', e.params) - self.assertEqual(e.params['in_names'], ({0: ('y',), 1: ('x',)},)) - self.assertIn('out_names', e.params) - self.assertEqual(e.params['out_names'], ({0: ('y',), 1: ('x',)},)) + self.assertIn('in_specs', e.params) + self.assertEqual(e.params['in_specs'], (P('y', 'x'),)) + self.assertIn('out_specs', e.params) + self.assertEqual(e.params['out_specs'], (P('y', 'x'),)) + + def test_vmap_explicit_mesh_axis(self): + mesh = jtu.create_mesh( + (1, 2, 2), ('z', 'x', 'y'), axis_types=(AxisType.Explicit,) * 3) + + @shard_map(mesh=mesh, in_specs=P('y'), out_specs=P('y')) + def f(x): + return x + + s = NamedSharding(mesh, P(('z', 'x'), 'y')) + x = jax.device_put(jnp.arange(4 * 4).reshape(4, 4), s) + + f = jax.jit(jax.vmap(f)) + out = f(x) + self.assertEqual(out.sharding, s) + + @jtu.with_explicit_mesh((2, 2), ('data', 'model')) + def test_vmap_explicit_mesh_axis_single_axis(self, mesh): + x = jax.device_put(jnp.arange(4 * 4).reshape(4, 4), P('data', 'model')) + + @shard_map(in_specs=P('model'), out_specs=P('model')) + def f(x): + return x + + f = jax.jit(jax.vmap(f)) + out = f(x) + self.assertEqual(out.sharding, NamedSharding(mesh, P('data', 'model'))) + + def test_vmap_explicit_mesh_axis_error(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y'), + axis_types=(AxisType.Explicit,) * 2) + + @shard_map(mesh=mesh, in_specs=P('x'), out_specs=P('x'), + axis_names={'x', 'y'}) + def f(x): + return x + + x = jnp.arange(4 * 4).reshape(4, 4) + s = NamedSharding(mesh, P('x', 'y')) + x = jax.device_put(x, s) + + f = jax.jit(jax.vmap(f)) + with self.assertRaisesRegex( + ValueError, "jax.shard_map requires axis_names.*subset of"): + f(x) + + f = jax.jit(jax.vmap(f, spmd_axis_name='y')) + with self.assertRaisesRegex( + ValueError, + 'Only one of spmd_axis_name or arrays sharded on `Explicit` mesh axis' + ' type is allowed'): + f(x) def test_vmap_of_grad_spmd_axis_name(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) @partial( - shard_map, mesh=mesh, in_specs=P('y'), out_specs=P(), check_rep=False + shard_map, mesh=mesh, in_specs=P('y'), out_specs=P(), check_vma=False ) def f(x): return jnp.sin(jnp.sum(x)) @@ -730,10 +1253,10 @@ def f(x): x = jnp.arange(4 * 4).reshape(4, 4) jaxpr = jax.make_jaxpr(jax.vmap(f, spmd_axis_name=('x', 'y')))(x).jaxpr e, = jaxpr.eqns - self.assertIn('in_names', e.params) - self.assertEqual(e.params['in_names'], ({0: ('x', 'y',)},)) - self.assertIn('out_names', e.params) - self.assertEqual(e.params['out_names'], ({0: ('x', 'y',)},)) + self.assertIn('in_specs', e.params) + self.assertEqual(e.params['in_specs'][0], P(('x', 'y'))) + self.assertIn('out_specs', e.params) + self.assertEqual(e.params['out_specs'][0], P(('x', 'y'))) def test_nested_vmap_with_capture_spmd_axis_name(self): self.skipTest('https://github.com/jax-ml/jax/issues/23476') @@ -861,8 +1384,6 @@ def test_shmap_abstract_mesh_errors(self): @jtu.run_on_devices('cpu', 'gpu', 'tpu') @jtu.thread_unsafe_test() def test_debug_print_jit(self, jit): - if config.use_shardy_partitioner.value: - self.skipTest('TODO(b/384938613): Failing under shardy') mesh = Mesh(jax.devices(), ('i',)) @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i')) @@ -884,6 +1405,52 @@ def f(x): for i in range(len(jax.devices())): self.assertIn(f'instance {i} has value', output()) + @jtu.with_explicit_mesh((2,), ('x')) + def test_pure_callback_return_multiple_arrays(self, mesh): + def host_kernel(arr: np.ndarray): + return arr + 1, arr * 2.0 + + @jax.shard_map(in_specs=P('x'), out_specs=(P('x'), P('x'))) + def per_shard(x_shard): + spec = jax.ShapeDtypeStruct(x_shard.shape, x_shard.dtype) + return jax.pure_callback(host_kernel, (spec, spec), x_shard) + + x = np.arange(32, dtype=np.float32).reshape(16, 2) + per_shard(x) # doesn't crash + + def test_psum_transpose_non_zero_cts(self): + mesh = jtu.create_mesh((8,), 'x') + @shard_map(mesh=mesh, in_specs=P('x'), out_specs=(P('x'), P())) + def f1(x_block): + return x_block, jax.lax.psum(x_block, axis_name='x') + + x1 = jnp.arange(16.) + f1(x1) # doesn't crash + + def f2(x_block): + y, _ = f1(x_block) + return y.sum() + + jax.jit(jax.grad(f2))(x1) # doesn't crash + jax.grad(f2)(x1) # doesn't crash + + @jtu.run_on_devices('cpu', 'gpu', 'tpu') + @jtu.thread_unsafe_test() + def test_debug_print_jit_partial_auto(self): + mesh = jtu.create_mesh((2,2), ('x', 'y')) + + @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'), + axis_names=frozenset({'x'})) + def f(x): + idx = jax.lax.axis_index('x') + jax.debug.print("instance {i} has value x={x}", i=idx, x=x) + y = jnp.cos(x) + return y + + f = jax.jit(f) + x = jnp.arange(2 * len(jax.devices())) + f(x) # don't crash! + def test_debug_print_eager(self): mesh = Mesh(jax.devices(), ('i',)) @@ -902,11 +1469,25 @@ def f(x): for i in range(len(jax.devices())): self.assertIn(f'x=[{2*i} {2*i+1}]', output()) - def test_partial_eval_custom_axis_env(self): - mesh = Mesh(jax.devices(), ('i',)) + def test_partial_auto_axis_index_eager(self): + mesh = jtu.create_mesh((2, 2, 1), ('i', 'j', 'k')) - @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i')) - def f(_): + def f(): + return jax.lax.axis_index('i').reshape((1,)) + + def g(): + return jax.shard_map(f, mesh=mesh, in_specs=(), out_specs=P('i'), + axis_names={'i'}, check_vma=False)() + + out = g() + expected_out = jax.jit(g)() + self.assertArraysEqual(out, expected_out) + + def test_partial_eval_custom_axis_env(self): + mesh = Mesh(jax.devices(), ('i',)) + + @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i')) + def f(_): _, idx = jax.lax.scan(lambda _, __: (None, jax.lax.axis_index('i')), None, None, length=1) return idx @@ -930,9 +1511,37 @@ def f(key): dtype=jnp.int32) pspec = P('x') if config.enable_custom_prng.value else P('x', None) - g = shard_map(f, mesh, in_specs=(pspec,), out_specs=pspec) + g = shard_map(f, mesh=mesh, in_specs=(pspec,), out_specs=pspec) _ = g(sharded_rng) # don't crash! + @parameterized.parameters(['threefry2x32', 'rbg', 'unsafe_rbg']) + def test_sharded_random_bits(self, prng_impl): + mesh = jtu.create_mesh((4,), ('x',)) + sharding = jax.sharding.NamedSharding(mesh, P('x')) + + rng = jax.random.key(0, impl=prng_impl) + sharded_rng = jax.random.split(rng, num=4) + sharded_rng = jax.device_put(sharded_rng, sharding) + + def f(key): + return jax.random.bits(key[0], shape=(4,), dtype=jnp.uint8) + + g = shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x')) + g(sharded_rng) # don't crash! + + def test_vma_out_specs_error_check(self): + mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z')) + @shard_map(mesh=mesh, in_specs=P('x', 'y', 'z'), out_specs=P('x')) + def f(x): + return x * 2 + + with self.assertRaisesRegex( + ValueError, + r".*out_specs is PartitionSpec\('x',\) which implies that the.*" + r' output value is only varying across mesh axes \{x\} and not \{y,z\},' + r' but it was inferred to be possibly varying over \{x,y,z\}.*'): + f(np.arange(16).reshape(4, 2, 2)) + def test_functools_partial_rank_error(self): mesh = jtu.create_mesh((4,), ('x',)) @@ -940,7 +1549,7 @@ def test_functools_partial_rank_error(self): def f(x): return x - g = shard_map(f, mesh, in_specs=(P('x', None),), out_specs=P('x',)) + g = shard_map(f, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x',)) x = jnp.arange(4) with self.assertRaises(ValueError): g(x) @@ -950,14 +1559,9 @@ def test_in_specs_none_error(self): def f(x): return x - with self.assertRaisesRegex(TypeError, "but it was None"): - shard_map(f, mesh, in_specs=None, out_specs=P())(3.) - - # TODO(mattjj): enable this test once we fix the tree_map(f, None, 3.0) bug - # with self.assertRaises(TypeError): - # shard_map(f, mesh, in_specs=(None,), out_specs=P())(3.) - - shard_map(f, mesh, in_specs=P(), out_specs=P())(3.) # doesn't crash + shard_map(f, mesh=mesh, in_specs=None, out_specs=P())(3.) # doesn't crash + shard_map(f, mesh=mesh, in_specs=(None,), out_specs=P())(3.) # doesn't crash + shard_map(f, mesh=mesh, in_specs=P(), out_specs=P())(3.) # doesn't crash def test_scan_rep_rule(self): mesh = jtu.create_mesh((2, 2,), ('x', 'y')) @@ -967,24 +1571,25 @@ def f(x, y, z): def body(c, _): c, *cs = c return (*cs, c), None + x = lax.pcast(x, ('x', 'y'), to='varying') + y = lax.pcast(y, 'y', to='varying') out, _ = jax.lax.scan(body, (x, y, z), None, length=3) return [jnp.expand_dims(a, 0) for a in out] x = jnp.arange(4) - # doesn't crash, because out_spec assumes no replication (and there is none) - shard_map(f, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + shard_map(f, mesh=mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), out_specs=P(('x', 'y')))(x, x, x) # does crash, because output incorrectly promises replication with self.assertRaisesRegex(ValueError, "require replication"): - shard_map(f, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + shard_map(f, mesh=mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), out_specs=P('x'))(x, x, x) with self.assertRaisesRegex(ValueError, "require replication"): - shard_map(f, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + shard_map(f, mesh=mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), out_specs=P('y'))(x, x, x) with self.assertRaisesRegex(ValueError, "require replication"): - shard_map(f, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + shard_map(f, mesh=mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), out_specs=P(None))(x, x, x) def g(x, y, z): @@ -995,12 +1600,65 @@ def body(c, _): return [jnp.expand_dims(a, 0) for a in out] # doesn't crash, because everything matches - shard_map(g, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + shard_map(g, mesh=mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + out_specs=[P(None), P('x'), P(('x', 'y'))])(x, x, x) + + # does crash, because the second guy is wrong + with self.assertRaisesRegex(ValueError, "require replication"): + shard_map(g, mesh=mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + out_specs=[P(None), P(None), P(('x', 'y'))])(x, x, x) + + def test_while_rep_rule(self): + mesh = jtu.create_mesh((2, 2,), ('x', 'y')) + + def f(x, y, z): + x, y, z = x.sum(), y.sum(), z.sum() + def cond(c): + i, *_ = c + return i < 5 + def body(c): + i, c, *cs = c + return (i + 1, *cs, c) + x = lax.pcast(x, ('x', 'y'), to='varying') + y = lax.pcast(y, 'y', to='varying') + _, *out = jax.lax.while_loop(cond, body, (0, x, y, z)) + return [jnp.expand_dims(a, 0) for a in out] + + x = jnp.arange(4) + + # doesn't crash, because out_spec assumes no replication (and there is none) + shard_map(f, mesh=mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + out_specs=P(('x', 'y')))(x, x, x) + + # does crash, because output incorrectly promises replication + with self.assertRaisesRegex(ValueError, "require replication"): + shard_map(f, mesh=mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + out_specs=P('x'))(x, x, x) + with self.assertRaisesRegex(ValueError, "require replication"): + shard_map(f, mesh=mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + out_specs=P('y'))(x, x, x) + with self.assertRaisesRegex(ValueError, "require replication"): + shard_map(f, mesh=mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + out_specs=P(None))(x, x, x) + + def g(x, y, z): + x, y, z = x.sum(), y.sum(), z.sum() + def cond(c): + i, *_ = c + return i < 1 + def body(c): + i, *cs = c + return (i + 1, *cs) + _, *out = jax.lax.while_loop(cond, body, (0, x, y, z)) + return [jnp.expand_dims(a, 0) for a in out] + + # doesn't crash, because everything matches + shard_map(g, mesh=mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), out_specs=[P(None), P('x'), P(('x', 'y'))])(x, x, x) # does crash, because the second guy is wrong with self.assertRaisesRegex(ValueError, "require replication"): - shard_map(g, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + shard_map(g, mesh=mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), out_specs=[P(None), P(None), P(('x', 'y'))])(x, x, x) def test_cond_rep_rule(self): @@ -1014,20 +1672,22 @@ def false_fun(x, y): return x + 1 return jax.lax.cond(True, true_fn, false_fun, x, y) - shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x) + shard_map(f, mesh=mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x) + with self.assertRaisesRegex(ValueError, "require replication"): - shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(None))(x, x) + shard_map(f, mesh=mesh, in_specs=(P('x'), P('y')), out_specs=P(None))(x, x) def f(x, y): def true_fn(x, y): - return x + return lax.pcast(x, 'y', to='varying') def false_fun(x, y): - return y + return lax.pcast(y, 'x', to='varying') return jax.lax.cond(True, true_fn, false_fun, x, y) - shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(('x', 'y')))(x, x) + shard_map(f, mesh=mesh, in_specs=(P('x'), P('y')), out_specs=P(('x', 'y')))(x, x) + with self.assertRaisesRegex(ValueError, "require replication"): - shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x) + shard_map(f, mesh=mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x) def f(x, y): def true_fn(x, y): @@ -1036,9 +1696,10 @@ def false_fun(x, y): return x + 1 return jax.lax.cond(jnp.any(x > 0), true_fn, false_fun, x, y) - shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x) + shard_map(f, mesh=mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x) + with self.assertRaisesRegex(ValueError, "require replication"): - shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(None))(x, x) + shard_map(f, mesh=mesh, in_specs=(P('x'), P('y')), out_specs=P(None))(x, x) def f(x, y): def true_fn(x, y): @@ -1047,9 +1708,8 @@ def false_fun(x, y): return x + 1 return jax.lax.cond(jnp.any(y > 0), true_fn, false_fun, x, y) - shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(('x', 'y')))(x, x) - with self.assertRaisesRegex(ValueError, "require replication"): - shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x) + shard_map(f, mesh=mesh, in_specs=(P('x'), P('y')), out_specs=P(('x', 'y')))(x, x) + shard_map(f, mesh=mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x) # https://github.com/jax-ml/jax/issues/24418 def f(a): @@ -1058,7 +1718,7 @@ def f(a): mesh = jtu.create_mesh((2,), ('x',)) a = jnp.array([True, False]) - shard_map(f, mesh, in_specs=P('x'), out_specs=P('x'))(a) + shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(a) def test_switch_rep_rule(self): mesh = jtu.create_mesh((2, 2,), ('x', 'y')) @@ -1068,7 +1728,7 @@ def f(n, x, y): return jax.lax.switch( n, [lambda x, _: x, lambda x, _: x + 1, lambda x, _: x + 2], x, y) - shard_map(f, mesh, in_specs=(P(), P('x'), P('y')), out_specs=P('x'))(1, x, x) + shard_map(f, mesh=mesh, in_specs=(P(), P('x'), P('y')), out_specs=P('x'))(1, x, x) def test_eager_custom_jvp_basic(self): @jax.custom_jvp @@ -1081,7 +1741,7 @@ def foo_jvp(primals, tangents): return foo(x), 3. * x_dot mesh = jtu.create_mesh((4,), ('x',)) - g = shard_map(foo, mesh, in_specs=(P('x'),), out_specs=P('x')) + g = shard_map(foo, mesh=mesh, in_specs=(P('x'),), out_specs=P('x')) y, x_bar = jax.value_and_grad(lambda x: g(x).sum())(jnp.arange(4.)) self.assertAllClose(y, (2. * jnp.arange(4.)).sum()) self.assertAllClose(x_bar, 3. * jnp.ones(4), check_dtypes=False) @@ -1100,7 +1760,7 @@ def foo_bwd(_, y_bar): foo.defvjp(foo_fwd, foo_bwd) mesh = jtu.create_mesh((4,), ('x',)) - g = shard_map(foo, mesh, in_specs=(P('x'),), out_specs=P('x')) + g = shard_map(foo, mesh=mesh, in_specs=(P('x'),), out_specs=P('x')) y, x_bar = jax.value_and_grad(lambda x: g(x).sum())(jnp.arange(4.)) self.assertAllClose(y, (2. * jnp.arange(4.)).sum()) self.assertAllClose(x_bar, 3. * jnp.ones(4), check_dtypes=False) @@ -1114,7 +1774,7 @@ def foo(): foo = jax.jit(foo) mesh = jtu.create_mesh((4,), ('x',)) - ans = shard_map(foo, mesh, in_specs=(), out_specs=P('x'))() + ans = shard_map(foo, mesh=mesh, in_specs=(), out_specs=P('x'))() expected = jnp.arange(4.) self.assertAllClose(ans, expected, check_dtypes=False) @@ -1130,7 +1790,7 @@ def foo(): foo = jax.jit(foo) mesh = jtu.create_mesh((4, 2), ('i', 'j')) - ans1, ans2, ans3 = shard_map(foo, mesh, in_specs=(), + ans1, ans2, ans3 = shard_map(foo, mesh=mesh, in_specs=(), out_specs=P('i', 'j'))() expected1 = jnp.arange(4.)[:, None] + jnp.zeros((4, 2)) expected2 = jnp.arange(2.)[None, :] + jnp.zeros((4, 2)) @@ -1197,7 +1857,7 @@ def test_key_array_with_replicated_last_tile_dim(self): def f(rng): @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'), - check_rep=False) + check_vma=False) def g(rng): return jnp.array([jax.random.normal(rng[0])]) return g(jax.random.split(rng, 4)) @@ -1236,7 +1896,8 @@ def test_returned_out_sharding(self): mesh = jtu.create_mesh((1, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) inp = jax.device_put(jnp.zeros((2, 2)), s) - out = shard_map(lambda x: x, mesh, P('x', 'y'), P('x', 'y'))(inp) + out = shard_map(lambda x: x, mesh=mesh, in_specs=P('x', 'y'), + out_specs=P('x', 'y'))(inp) self.assertEqual(out.sharding, s) self.assertArraysEqual(out, inp) @@ -1304,9 +1965,9 @@ def test_sharding_metadata_in_hlo_attrs(self): def foo(x): x = jnp.sin(x) - x = shard_map(lambda x: jnp.cos(x * y), mesh, + x = shard_map(lambda x: jnp.cos(x * y), mesh=mesh, in_specs=P('i'), out_specs=P('i'))(x) - x = shard_map(lambda x: jnp.cos(x * y), mesh, + x = shard_map(lambda x: jnp.cos(x * y), mesh=mesh, in_specs=P('i'), out_specs=P('i'))(x) return x @@ -1318,11 +1979,13 @@ def foo(x): # When devices == 1, the `sdy.manual_computation` is inlined. self.assertEqual(0, hlo_str.count('sdy.manual_computation')) else: - self.assertIn('call @shmap_body', hlo_str) - self.assertIn('call @shmap_body_0', hlo_str) + self.assertIn('call @shmap_body(', hlo_str) + self.assertIn('call @shmap_body_', hlo_str) self.assertIn('%arg0: tensor<1xf32>', hlo_str) - self.assertIn('"[None]"', hlo_str) - self.assertIn('%arg1: tensor<1xf32>', hlo_str) + if not config.use_simplified_jaxpr_constants.value: + # A constvar is turned into an argument with location None in @shmap_body + self.assertIn('"[None]"', hlo_str) + self.assertIn('%arg1: tensor<1xf32>', hlo_str) self.assertIn('"[(\'i\',)]"', hlo_str) self.assertIn( '-> (tensor<1xf32> {jax.result_info = "[(\'i\',)]"})', hlo_str @@ -1337,7 +2000,7 @@ def f(x): x)[0] * x mesh = jtu.create_mesh((4,), ('x',)) - g = shard_map(f, mesh, in_specs=(P('x'),), out_specs=P('x')) + g = shard_map(f, mesh=mesh, in_specs=(P('x'),), out_specs=P('x')) x = jnp.arange(4.) y = jax.jit(g)(x) # eager requires shmap to have ShardMapTrace.process_call self.assertAllClose(y, 2 * x * x, check_dtypes=True) @@ -1371,7 +2034,7 @@ def foo_jvp(primals, tangents): return foo(x), 2. * x_dot mesh = jtu.create_mesh((4,), ('x',)) - g = shard_map(lambda x: foo(x) * x, mesh, + g = shard_map(lambda x: foo(x) * x, mesh=mesh, in_specs=(P('x'),), out_specs=P('x')) if jit: g = jax.jit(g) @@ -1399,7 +2062,7 @@ def foo_bwd(_, y_bar): foo.defvjp(foo_fwd, foo_bwd) mesh = jtu.create_mesh((4,), ('x',)) - g = shard_map(lambda x: foo(x) * x, mesh, + g = shard_map(lambda x: foo(x) * x, mesh=mesh, in_specs=(P('x'),), out_specs=P('x')) if jit: g = jax.jit(g) @@ -1427,7 +2090,7 @@ def foo_bwd(_, y_bar): foo.defvjp(foo_fwd, foo_bwd) mesh = jtu.create_mesh((4,), ('x',)) - g = shard_map(lambda x: foo(x) * x, mesh, + g = shard_map(lambda x: foo(x) * x, mesh=mesh, in_specs=(P('x'),), out_specs=P('x')) if jit: g = jax.jit(g) @@ -1456,38 +2119,6 @@ def f(x): y = shard_f(x) self.assertEqual(x_spec, y.sharding.spec) - @parameterized.parameters([True, False]) - def test_rewrite_process_custom_vjp_call_match_less_replicated(self, jit): - @jax.custom_vjp - def foo(x, y): - del y - return 2. * x - - def foo_fwd(x, y): - return foo(x, y), y - - def foo_bwd(y, _): - return y, None # diff! x_bar less replicated than primal/tangent - - foo.defvjp(foo_fwd, foo_bwd) - - mesh = jtu.create_mesh((4,), ('x',)) - g = shard_map(lambda x, y: foo(x, y) * y, mesh, - in_specs=(P(), P('x')), out_specs=P('x')) - if jit: - g = jax.jit(g) - - x = jnp.arange(4.) - y = jnp.arange(4 * 4.) - - z = g(x, y) - self.assertAllClose(z, 2 * jnp.tile(x, (4,)) * y, check_dtypes=False) - - z_, x_bar = jax.value_and_grad(lambda x, y: g(x, y).sum())(x, y) - self.assertAllClose(z.sum(), z_, check_dtypes=False) - self.assertAllClose(x_bar, jnp.arange(16).reshape(4, 4).sum(0), - check_dtypes=False) - @parameterized.parameters([True, False]) def test_rewrite_custom_vjp_call_jaxpr(self, jit): @jax.custom_vjp @@ -1507,7 +2138,7 @@ def foo_scan(x): return y mesh = jtu.create_mesh((4,), ('x',)) - g = shard_map(lambda x: foo_scan(x) * x, mesh, + g = shard_map(lambda x: foo_scan(x) * x, mesh=mesh, in_specs=(P('x'),), out_specs=P('x')) if jit: g = jax.jit(g) @@ -1555,7 +2186,7 @@ def f(x): jaxpr = jax.make_jaxpr(jax.vjp(f, jnp.arange(1.))[1])(jnp.arange(4.)) e, = jaxpr.jaxpr.eqns e2, = e.params['jaxpr'].eqns - self.assertEqual(str(e2.primitive), 'psum2') + self.assertEqual(str(e2.primitive), 'psum_invariant') self.assertEqual(e2.params['axes'], ('x',)) def test_fanin_psum_transposes_to_fanout(self): @@ -1568,7 +2199,7 @@ def f(x): jaxpr = jax.make_jaxpr(jax.vjp(f, jnp.arange(4.))[1])(jnp.array([1.])) e, = jaxpr.jaxpr.eqns e1, = e.params['jaxpr'].eqns - self.assertEqual(str(e1.primitive), 'pbroadcast') + self.assertEqual(str(e1.primitive), 'pvary') def test_psum_with_implicit_fanout_self_transposes(self): mesh = jtu.create_mesh((4,), ('x',)) @@ -1580,8 +2211,8 @@ def f(x): jaxpr = jax.make_jaxpr(jax.vjp(f, jnp.arange(4.))[1])(jnp.arange(4.)) e, = jaxpr.jaxpr.eqns e1, e2 = e.params['jaxpr'].eqns - self.assertEqual(str(e1.primitive), 'psum2') - self.assertEqual(str(e2.primitive), 'pbroadcast') + self.assertEqual(str(e1.primitive), 'psum_invariant') + self.assertEqual(str(e2.primitive), 'pvary') def test_transpose_float0(self): mesh = jtu.create_mesh((4,), ('x',)) @@ -1612,7 +2243,7 @@ def g_bwd(vjp_fn, result): def f_shmapped(x, y): return jax.lax.psum(f(x, y).sum(), axis_name=('x')) - @partial(shard_map, mesh=mesh, check_rep=False, + @partial(shard_map, mesh=mesh, check_vma=False, in_specs=P('x'), out_specs=(P('x'), P())) def f_shmapped2(x, y): return g(x, y) @@ -1621,7 +2252,7 @@ def f_wrapper(x, y): x, y = jax.lax.map(lambda xs: f_shmapped2(xs[0], xs[1]), (x, y)) return jax.lax.map(lambda xs: f_shmapped(xs[0], xs[1]), (x, y)).sum() - @partial(jax.jit, in_shardings=s, + @jax.jit(in_shardings=s, out_shardings=jax.sharding.NamedSharding(mesh, P())) def example(x, y): return jax.grad(f_wrapper, allow_int=True, argnums=(0, 1))(x, y) @@ -1632,6 +2263,36 @@ def example(x, y): dx, dy = example(x, y) self.assertEqual(dy.dtype, jax.dtypes.float0) + def test_pvary(self): + mesh = jtu.create_mesh((4,), ('x',)) + + @partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P('x')) + def f(x): + y = jax.lax.pcast(x, 'x', to='varying') + self.assertEqual(y.aval.vma, {'x'}) + return y + + f(jnp.arange(8.)) + jax.grad(lambda x: f(x).sum())(jnp.arange(8.)) + + @jtu.with_explicit_mesh((2,), 'x') + def test_pcast_axis_name_is_not_set(self, mesh): + def f(axis_name_type, x): + with self.assertRaisesRegex(TypeError, 'must be a tuple or a str'): + if axis_name_type == 'str': + jax.lax.pcast(x, {'x'}, to='varying') + elif axis_name_type == 'aval.vma': + jax.lax.pcast(x, x.aval.vma, to='varying') + + jax.shard_map(partial(f, 'str'), mesh=mesh, in_specs=P(), + out_specs=None)(np.arange(8.)) + jax.shard_map(partial(f, 'aval.vma'), mesh=mesh, in_specs=P(), + out_specs=None)(np.arange(8.)) + jax.jit(jax.shard_map(partial(f, 'str'), mesh=mesh, in_specs=P(), + out_specs=None))(np.arange(8.)) + jax.jit(jax.shard_map(partial(f, 'aval.vma'), mesh=mesh, in_specs=P(), + out_specs=None))(np.arange(8.)) + def test_rewrite_binops(self): mesh = jtu.create_mesh((4,), ('x',)) @@ -1642,7 +2303,7 @@ def f(x, y): jaxpr = jax.make_jaxpr(f)(jnp.arange(1.), jnp.arange(4.)) e, = jaxpr.jaxpr.eqns e = e.params['jaxpr'].eqns[0] - self.assertEqual(e.primitive.name, 'pbroadcast') + self.assertEqual(e.primitive.name, 'pvary') self.assertEqual(e.params['axes'], ('x',)) def test_rewrite_scan(self): @@ -1650,16 +2311,17 @@ def test_rewrite_scan(self): @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x')) def f(x): - x, _ = jax.lax.scan(lambda x, _: (jax.lax.psum(x, 'x'), None), x, None, - length=2) + def g(x, _): + return lax.pcast(jax.lax.psum(x, 'x'), 'x', to='varying'), None + x, _ = jax.lax.scan(g, x, None, length=2) return x jaxpr = jax.make_jaxpr(f)(jnp.arange(4.)) e, = jaxpr.jaxpr.eqns e, = e.params['jaxpr'].eqns e1, e2 = e.params['jaxpr'].eqns - self.assertEqual(e1.primitive.name, 'psum2') - self.assertEqual(e2.primitive.name, 'pbroadcast') + self.assertEqual(e1.primitive.name, 'psum_invariant') + self.assertEqual(e2.primitive.name, 'pvary') def test_check_rep_false_grads(self): if jtu.is_device_tpu(5, 'e'): @@ -1673,7 +2335,7 @@ def f(q, k, v): def body(q, k, v): return q * k[None, :] + v[None, :] - out = shard_map(body, mesh, check_rep=False, + out = shard_map(body, mesh=mesh, check_vma=False, in_specs=(q_spec, kv_spec, kv_spec,), out_specs=q_spec)(q, k, v) return out.sum() @@ -1698,7 +2360,7 @@ def foo(x): @partial(jax.remat, policy=lambda *args, **kwargs: True) def bar(x): return shard_map(foo, mesh=Mesh(jax.devices(), ['x']), in_specs=(P('x'),), - out_specs=P('x'), check_rep=False)(x) + out_specs=P('x'), check_vma=False)(x) jax.jit(jax.grad(lambda x: bar(x).sum()))(jnp.arange(8.)) # doesn't crash @@ -1706,7 +2368,7 @@ def bar(x): def test_res_forwarding_optimization(self, jit, remat): mesh = jtu.create_mesh((4,), ('i',)) - @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i')) + @shard_map(mesh=mesh, in_specs=P('i'), out_specs=P('i')) def f(x): return jax.lax.exp(x) if jit: @@ -1719,7 +2381,7 @@ def f(x): x = jnp.arange(16.) jaxpr_ = jax.make_jaxpr(jax.grad(g))(x) jaxpr, _ = pe.dce_jaxpr(jaxpr_.jaxpr, [True] * len(jaxpr_.out_avals)) - e1, _, e2 = jaxpr.eqns + e1, *_, e2 = jaxpr.eqns self.assertLen(e1.outvars, 1) # only primal output self.assertLen(e2.invars, 2) # res and cotangent inputs self.assertEqual(sum(e1.outvars[0] is v for v in e2.invars), 1) @@ -1729,7 +2391,7 @@ def test_res_forwarding_optimization_complex(self, jit, remat): # like the above test, but a different function `f` mesh = jtu.create_mesh((4,), ('i',)) - @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i')) + @shard_map(mesh=mesh, in_specs=P('i'), out_specs=P('i')) def f(x): return jax.lax.exp(x.sum()) + x, jax.lax.exp(x) if jit: @@ -1742,7 +2404,7 @@ def f(x): x = jnp.arange(16.) jaxpr_ = jax.make_jaxpr(jax.grad(g))(x) jaxpr, _ = pe.dce_jaxpr(jaxpr_.jaxpr, [True] * len(jaxpr_.out_avals)) - e1, _, e2 = jaxpr.eqns + e1, *_, e2 = jaxpr.eqns self.assertLen(e1.outvars, 2) # one primal and one res output self.assertLen(e2.invars, 4) # two res and two cotangent inputs self.assertEqual(sum(e1.outvars[-1] is v for v in e2.invars), 1) @@ -1752,7 +2414,7 @@ def test_check_rep_failure_inside_rule(self, jit): mesh = jtu.create_mesh((4,), ('i',)) def loss(w, x): - @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P()) + @shard_map(mesh=mesh, in_specs=P('i'), out_specs=P()) def f(x): return jax.lax.psum(((w * x) ** 2).sum(), 'i') return f(x) @@ -1768,8 +2430,8 @@ def test_conv_general_dilated(self): dot = partial(lax.conv_general_dilated, window_strides=(), padding='VALID', dimension_numbers=('NC', 'IO', 'NC')) - @partial(shard_map, mesh=mesh, in_specs=(P(None, 'i'), P('i', None)), - out_specs=P(None, None)) + @shard_map(mesh=mesh, in_specs=(P(None, 'i'), P('i', None)), + out_specs=P(None, None)) def f(x, y): return lax.psum(dot(x, y), 'i') @@ -1809,7 +2471,7 @@ def f(*args): return args[0] @ args[1] shard_f = shard_map( - f, mesh, in_specs=(P('x', 'y', None), P('x', 'y', None)), out_specs=P('x', 'y')) + f, mesh=mesh, in_specs=(P('x', 'y', None), P('x', 'y', None)), out_specs=P('x', 'y')) with self.assertRaisesRegex(ValueError, "shard_map applied to the function 'f'"): shard_f(jnp.ones((8, 8)), jnp.ones((8, 8))) @@ -1830,11 +2492,10 @@ def f_bwd(_, g): def g(x): return f(f(x)) - y, grad = jax.value_and_grad(lambda x: g(x).sum())(jnp.ones(4)) - # first psum sums, second psum multiplies by 4 - self.assertAllClose(y, (jnp.ones(4) * 4).sum(), check_dtypes=False) - # two psums on the backward pass, each one multiplies by 4 - self.assertAllClose(grad, jnp.ones(4) * 4 * 4, check_dtypes=False) + with self.assertRaisesRegex( + ValueError, + "Custom VJP bwd rule must produce an output with the same type"): + jax.value_and_grad(lambda x: g(x).sum())(jnp.ones(4)) def test_repeated_psum_allowed(self): # https://github.com/jax-ml/jax/issues/19175 @@ -1852,7 +2513,8 @@ def test_approx_top_k(self): mesh = Mesh(np.array(jax.devices()[:2]), ('i',)) x = jnp.array([3.0, 1.0, 4.0, 2.0]) - _ = shard_map(lambda x: lax.approx_max_k(x, 2), mesh, P('i'), P('i'))(x) + _ = shard_map(lambda x: lax.approx_max_k(x, 2), mesh=mesh, in_specs=P('i'), + out_specs=P('i'))(x) def test_disable_jit(self): mesh = Mesh(np.array(jax.devices()[:2]), ('i',)) @@ -1865,6 +2527,59 @@ def f(x): with jax.disable_jit(): f(x) # don't crash + @jtu.with_explicit_mesh((2,), 'x') + def test_jacrev_explicit(self, mesh): + B, N, H = 20, 6, 8 + w = jnp.arange(N * H).reshape(N, H).astype(jnp.float32) + x = jnp.arange(B * N).reshape(B, N).astype(jnp.float32) + + def f(w, x): + return jnp.sum(x @ w, axis=-1) + + @jax.jit + @shard_map(in_specs=(P(), P('x', None)), out_specs=P('x', None)) + def f_jac_sharded(w, x): + return jax.jacrev(f, argnums=1)(w, x) + + f_jac_sharded(w, x) # doesn't crash + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_jacrev_explicit_complex(self, mesh): + B, N, H = 20, 6, 8 + w = jnp.arange(N * H).reshape(N, H).astype(jnp.float32) + x = jnp.arange(B * N).reshape(B, N).astype(jnp.float32) + + def f(w, xs): + return jax.tree.map(lambda z: jnp.sum(z @ w, axis=-1), xs) + + @jax.jit + @shard_map(in_specs=(P(), P('x'), P('y')), + out_specs=(P('x'), P('y'), P('x'), P('y'))) + def f_jac_sharded(w, x, y): + ret = jax.jacrev(f, argnums=1)(w, (x, y)) + self.assertEqual(ret[0][0].aval.vma, {'x'}) + self.assertEqual(ret[0][1].aval.vma, {'y'}) + self.assertEqual(ret[1][0].aval.vma, {'x'}) + self.assertEqual(ret[1][1].aval.vma, {'y'}) + return ret[0][0], ret[0][1], ret[1][0], ret[1][1] + + f_jac_sharded(w, x, x) # doesn't crash + + @jtu.with_explicit_mesh((2,), 'x') + def test_random_choice_pvary(self, mesh): + B, C = 8, 3 + key = jax.random.key(0) + keys = jax.random.split(key, B) + hoppable_clusters = jax.random.randint(key, (B, C), minval=0, maxval=2) == 1 + + @jax.vmap + def _update_samples(key, hoppable_clusters): + return jax.random.choice(key, a=jnp.arange(C), p=hoppable_clusters, + replace=True) + + shard_map(_update_samples, in_specs=(P('x'), P('x')), + out_specs=P('x'))(keys, hoppable_clusters) # doesn't crash + @parameterized.parameters(it.product(range(4), repeat=3)) @jtu.run_on_devices("cpu") def test_forwarding_correctness(self, seed, num_input_fwd, num_output_fwd): @@ -1886,22 +2601,365 @@ def f(inputs): jtu.check_grads(f, (list(jnp.arange(float(num_args))[:,None]),), order=1, modes=['rev'], atol=1e-3, rtol=1e-3) + @jtu.with_explicit_mesh((2, 2), ('data', 'seq')) + def test_shmap_unreduced_fsdp_custom_vjp_bwd(self, mesh): + np_inp1 = np.arange(64.).reshape(8, 4, 2) + np_inp2 = np.arange(12.).reshape(2, 6) + arr1 = jax.device_put(np_inp1, P('data', 'seq', None)) + arr2 = jax.device_put(np_inp2, P('seq', None)) + + @jax.custom_vjp + def f(x, y): + @shard_map(in_specs=P('seq', None), + out_specs=P(None, None, reduced={'seq'})) + def ag(a): + self.assertEqual(a.aval.vma, {'seq'}) + self.assertEqual(a.aval.sharding.spec.unreduced, frozenset()) + out = lax.all_gather(a, axis_name='seq', tiled=True, to='reduced') + self.assertEqual(out.aval.vma, frozenset()) + self.assertEqual(out.aval.sharding.spec.unreduced, frozenset()) + self.assertEqual(out.aval.sharding.spec.reduced, {'seq'}) + return out + + y2 = ag(y) + return jnp.einsum('btd,df->btf', x, y2) + + def f_fwd(x, y): + return f(x, y), (x, y) + + def f_bwd(res, g): + x, y = res + + @shard_map(in_specs=P(unreduced={'data', 'seq'}), + out_specs=P('seq', None, unreduced={'data'})) + def rs(a): + self.assertEqual(a.aval.vma, frozenset()) + self.assertEqual(a.aval.sharding.spec.unreduced, {'data', 'seq'}) + out = lax.psum_scatter(a, axis_name='seq', tiled=True) + self.assertEqual(out.aval.vma, {'seq'}) + self.assertEqual(out.aval.sharding.spec.unreduced, {'data'}) + return out + + @shard_map(in_specs=P('seq', None, unreduced={'data'}), + out_specs=P('seq', None)) + def ar(a): + self.assertEqual(a.aval.vma, {'seq'}) + self.assertEqual(a.aval.sharding.spec.unreduced, {'data'}) + out = lax.psum(a, axis_name='data') + self.assertEqual(out.aval.vma, {'seq'}) + self.assertEqual(out.aval.sharding.spec.unreduced, frozenset()) + return out + + x_bar = jnp.einsum('btf,df->btd', g, y, out_sharding=P('data', 'seq', None)) + y_bar_ = jnp.einsum('btd,btf->df', x, g, + out_sharding=P(None, None, unreduced={'data', 'seq'})) + + self.assertEqual(y_bar_.aval.sharding.spec.unreduced, {'data', 'seq'}) + y_bar = rs(y_bar_) + + self.assertEqual(y_bar.aval.sharding.spec.unreduced, {'data'}) + y_bar = ar(y_bar) + self.assertEqual(y_bar.aval.sharding.spec.unreduced, frozenset()) + + return (x_bar, y_bar) + + f.defvjp(f_fwd, f_bwd) + f = jax.jit(f) + + f(arr1, arr2) # doesn't crash + + out1, out2 = jax.jit(jax.grad(lambda x, y: jnp.sin(f(x, y).sum()), + argnums=(0, 1)))(arr1, arr2) + + with jax.set_mesh(jtu.create_mesh((1,), 'x')): + ex_out1, ex_out2 = jax.jit(jax.grad(lambda x, y: jnp.sin((x @ y).sum()), + argnums=(0, 1)))(np_inp1, np_inp2) + self.assertArraysAllClose(ex_out1, out1, rtol=2e-4) + self.assertArraysAllClose(ex_out2, out2, rtol=2e-4) + + @jtu.with_explicit_mesh((2, 2), ('data', 'seq')) + def test_shmap_unreduced_fsdp_grad(self, mesh): + np_inp1 = np.arange(64.).reshape(8, 4, 2) + np_inp2 = np.arange(12.).reshape(2, 6) + arr1 = jax.device_put(np_inp1, P('data', 'seq', None)) + arr2 = jax.device_put(np_inp2, P('seq', None)) + + @shard_map(in_specs=P('seq', None), + out_specs=P('seq', None, reduced={'data'})) + def preduced(a): + self.assertEqual(a.aval.vma, {'seq'}) + self.assertEqual(a.aval.sharding.spec.unreduced, frozenset()) + out = jax.lax.pcast(a, axis_name='data', to='reduced') + self.assertEqual(out.aval.vma, {'seq'}) + self.assertEqual(out.aval.sharding.spec.unreduced, frozenset()) + self.assertEqual(out.aval.sharding.spec.reduced, {'data'}) + return out + + @shard_map(in_specs=P('seq', None, reduced={'data'}), + out_specs=P(None, None, reduced={'seq', 'data'})) + def ag(a): + self.assertEqual(a.aval.vma, {'seq'}) + self.assertEqual(a.aval.sharding.spec.unreduced, frozenset()) + self.assertEqual(a.aval.sharding.spec.reduced, {'data'}) + out = lax.all_gather(a, axis_name='seq', tiled=True, to='reduced') + self.assertEqual(out.aval.vma, frozenset()) + self.assertEqual(out.aval.sharding.spec.unreduced, frozenset()) + self.assertEqual(out.aval.sharding.spec.reduced, {'seq', 'data'}) + return out + + @jax.jit + def f(x, y): + y2 = preduced(y) + y3 = ag(y2) + self.assertEqual(y3.aval.sharding.spec.reduced, {'seq', 'data'}) + return jnp.einsum('btd,df->btf', x, y3) + + out = f(arr1, arr2) # doesn't crash + self.assertEqual(out.sharding, NamedSharding(mesh, P('data', 'seq', None))) + + out1, out2 = jax.jit(jax.grad(lambda x, y: jnp.sin(f(x, y).sum()), + argnums=(0, 1)))(arr1, arr2) + + jaxpr = jax.jit(jax.grad(lambda x, y: jnp.sin(f(x, y).sum()), + argnums=(0, 1))).trace(arr1, arr2).jaxpr + self.assertIn('unreduced_reduce_scatter', str(jaxpr)) + self.assertIn('unreduced_psum', str(jaxpr)) + + with jax.set_mesh(jtu.create_mesh((1,), 'x')): + ex_out1, ex_out2 = jax.jit(jax.grad(lambda x, y: jnp.sin((x @ y).sum()), + argnums=(0, 1)))(np_inp1, np_inp2) + self.assertArraysAllClose(ex_out1, out1, rtol=2e-4) + self.assertEqual(out1.sharding, NamedSharding(mesh, P('data', 'seq', None))) + self.assertArraysAllClose(ex_out2, out2, rtol=2e-4) + self.assertEqual(out2.sharding, NamedSharding(mesh, P('seq', None))) + + @jtu.with_explicit_mesh((2,), 'x') + def test_unreduced_psum_fwd_preduced_bwd(self, mesh): + np_inp1 = np.arange(8.).reshape(2, 4) + np_inp2 = np.arange(24.).reshape(4, 6) + arr1 = jax.device_put(np_inp1, P(None, 'x')) + arr2 = jax.device_put(np_inp2, P('x', None)) + + @shard_map(in_specs=P(unreduced={'x'}), out_specs=P()) + def ar(x): + self.assertEqual(x.aval.vma, frozenset()) + self.assertEqual(x.aval.sharding.spec.unreduced, {'x'}) + out = jax.lax.psum(x, 'x') + self.assertEqual(out.aval.vma, frozenset()) + self.assertEqual(out.aval.sharding.spec.unreduced, frozenset()) + return out + + @jax.jit + def f(x, y): + z = jnp.einsum('ab,bc->ac', x, y, out_sharding=P(unreduced={'x'})) + return ar(z).sum() + + out = f(arr1, arr2) + self.assertArraysEqual(out, np.sum(np_inp1 @ np_inp2)) + self.assertEqual(out.sharding, NamedSharding(mesh, P())) + + out1, out2 = jax.jit(jax.grad(f, argnums=(0, 1)))(arr1, arr2) + self.assertEqual(out1.sharding, NamedSharding(mesh, P(None, 'x'))) + self.assertEqual(out2.sharding, NamedSharding(mesh, P('x', None))) + + with jax.set_mesh(jtu.create_mesh((1,), 'x')): + ex_out1, ex_out2 = jax.jit(jax.grad(lambda x, y: (x @ y).sum(), + argnums=(0, 1)))(np_inp1, np_inp2) + self.assertArraysAllClose(ex_out1, out1, rtol=2e-4) + self.assertArraysAllClose(ex_out2, out2, rtol=2e-4) + + @jtu.with_explicit_mesh((2,), 'x') + def test_preduced_fwd_unreduced_psum_bwd(self, mesh): + np_inp1 = np.arange(8.).reshape(2, 4) + np_inp2 = np.arange(24.).reshape(4, 6) + arr1 = jax.device_put(np_inp1, P(None, None)) + arr2 = jax.device_put(np_inp2, P(None, 'x')) + + @shard_map(in_specs=P(), out_specs=P(reduced={'x'})) + def pr(x): + self.assertEqual(x.aval.vma, frozenset()) + self.assertEqual(x.aval.sharding.spec.unreduced, frozenset()) + self.assertEqual(x.aval.sharding.spec.reduced, frozenset()) + out = jax.lax.pcast(x, 'x', to='reduced') + self.assertEqual(out.aval.vma, frozenset()) + self.assertEqual(out.aval.sharding.spec.unreduced, frozenset()) + self.assertEqual(out.aval.sharding.spec.reduced, {'x'}) + return out + + @jax.jit + def f(x, y): + x = pr(x) + z = jnp.einsum('ab,bc->ac', x, y) + return z.sum() + + out = f(arr1, arr2) + self.assertArraysEqual(out, np.sum(np_inp1 @ np_inp2)) + self.assertEqual(out.sharding, NamedSharding(mesh, P())) + + out = jax.jit(jax.grad(f))(arr1, arr2) + self.assertEqual(out.sharding, NamedSharding(mesh, P(None, None))) + + with jax.set_mesh(jtu.create_mesh((1,), 'x')): + ex_out = jax.jit(jax.grad(lambda x, y: (x @ y).sum()))(np_inp1, np_inp2) + self.assertArraysAllClose(ex_out, out, rtol=2e-4) + + @jtu.with_explicit_mesh((2, 2), ('data', 'seq')) + def test_all_gather_reduced_fwd_unreduced_psum_scatter_bwd(self, mesh): + np_inp1 = np.arange(8.).reshape(4, 2) + np_inp2 = np.arange(12.).reshape(2, 6) + arr1 = jax.device_put(np_inp1, P('seq', None)) + arr2 = jax.device_put(np_inp2, P('seq', None)) + + @shard_map(in_specs=P('seq', None), out_specs=P(None, None, reduced={'seq'})) + def ag(a): + self.assertEqual(a.aval.vma, {'seq'}) + self.assertEqual(a.aval.sharding.spec.unreduced, frozenset()) + out = lax.all_gather(a, axis_name='seq', tiled=True, to='reduced') + self.assertEqual(out.aval.vma, frozenset()) + self.assertEqual(out.aval.sharding.spec.unreduced, frozenset()) + self.assertEqual(out.aval.sharding.spec.reduced, {'seq'}) + return out + + @jax.jit + def f(x, y): + y2 = ag(y) + z = jnp.einsum('td,df->tf', x, y2) + return z.sum() + + out = f(arr1, arr2) + self.assertArraysEqual(out, np.sum(np_inp1 @ np_inp2)) + self.assertEqual(out.sharding, NamedSharding(mesh, P())) + + out1, out2 = jax.jit(jax.grad(f, argnums=(0, 1)))(arr1, arr2) + self.assertEqual(out1.sharding, NamedSharding(mesh, P('seq', None))) + self.assertEqual(out2.sharding, NamedSharding(mesh, P('seq', None))) + + with jax.set_mesh(jtu.create_mesh((1,), 'x')): + ex_out1, ex_out2 = jax.jit(jax.grad(lambda x, y: (x @ y).sum(), + argnums=(0, 1)))(np_inp1, np_inp2) + self.assertArraysAllClose(ex_out1, out1, rtol=2e-4) + self.assertArraysAllClose(ex_out2, out2, rtol=2e-4) + + @jtu.with_explicit_mesh((2,), ('x',)) + def test_unreduced_psum_scatter_fwd_all_gather_reduced_bwd(self, mesh): + np_inp1 = np.arange(8.).reshape(4, 2) + np_inp2 = np.arange(12.).reshape(2, 6) + arr1 = jax.device_put(np_inp1, P(None, 'x')) + arr2 = jax.device_put(np_inp2, P('x', None)) + + @shard_map(in_specs=P(unreduced={'x'}), out_specs=P('x', None)) + def rs(a): + self.assertEqual(a.aval.vma, frozenset()) + self.assertEqual(a.aval.sharding.spec.unreduced, {'x'}) + out = lax.psum_scatter(a, axis_name='x', tiled=True) + self.assertEqual(out.aval.vma, {'x'}) + self.assertEqual(out.aval.sharding.spec.unreduced, frozenset()) + return out + + @jax.jit + def f(x, y): + z = jnp.einsum('ab,bc->ac', x, y, out_sharding=P(unreduced={'x'})) + return rs(z).sum() + + out = f(arr1, arr2) + self.assertArraysEqual(out, np.sum(np_inp1 @ np_inp2)) + self.assertEqual(out.sharding, NamedSharding(mesh, P())) + + out1, out2 = jax.jit(jax.grad(f, argnums=(0, 1)))(arr1, arr2) + self.assertEqual(out1.sharding, NamedSharding(mesh, P(None, 'x'))) + self.assertEqual(out2.sharding, NamedSharding(mesh, P('x', None))) + + with jax.set_mesh(jtu.create_mesh((1,), 'x')): + ex_out1, ex_out2 = jax.jit(jax.grad(lambda x, y: (x @ y).sum(), + argnums=(0, 1)))(np_inp1, np_inp2) + self.assertArraysAllClose(ex_out1, out1, rtol=2e-4) + self.assertArraysAllClose(ex_out2, out2, rtol=2e-4) + + @jtu.with_explicit_mesh((2,), 'x') + def test_sin_unreduced_error(self, mesh): + np_inp1 = np.arange(8.).reshape(4, 2) + np_inp2 = np.arange(12.).reshape(2, 6) + arr1 = jax.device_put(np_inp1, P(None, 'x')) + arr2 = jax.device_put(np_inp2, P('x', None)) + + @jax.jit + def f(x, y): + z = jnp.einsum('ab,bc->ac', x, y, out_sharding=P(unreduced={'x'})) + return shard_map(lambda x: jnp.sin(x), out_specs=P())(z) + + with self.assertRaisesRegex(NotImplementedError, "unreduced rule for sin"): + f(arr1, arr2) + + @jtu.with_explicit_mesh((2,), 'x') + def test_eval_shape_vma(self, mesh): + k1, k2 = jax.random.split(jax.random.key(123)) + p = jax.random.uniform(k1, shape=5, out_sharding=P()) + x = jax.random.uniform(k2, shape=(1024, 5), out_sharding=P('x')) + + def f(p, x): + return jnp.einsum('i, i->', x, p) + + @shard_map(in_specs=(P(), P('x')), out_specs=P('x')) + def g(p, x): + def _grad(f, p, x): + _, vjp_fun = jax.vjp(f, p, x) + y_eval_shape = jax.eval_shape(f, p, x) + self.assertEqual(core.typeof(y_eval_shape).vma, frozenset('x')) + one = jax.lax.full_like(y_eval_shape, 1) + self.assertEqual(core.typeof(one).vma, frozenset('x')) + return vjp_fun(one) + return jax.lax.map(partial(_grad, f, p), x) + + g(p, x) # doesn't crash + jax.jit(g)(p, x) # doesn't crash + + def test_psum_not_under_shmap_error(self): + mesh = jtu.create_mesh((2,), 'x') + + @jax.jit + def f(x): + return jax.lax.psum(x, 'x') + + with self.assertRaisesRegex( + NameError, + 'Found an unbound axis name: x. To fix this, please call psum under' + ' `jax.shard_map`'): + f(jnp.arange(8.)) + + # fixes the above error + shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P()) # doesn't crash + + def test_shmap_auto_unreduced_error(self): + mesh = jtu.create_mesh((2, 1), ('x', 'y')) + with self.assertRaisesRegex( + ValueError, + 'unreduced.*can only be used when the mesh passed to shard_map contains' + ' axis names all of type `Explicit`'): + shard_map(lambda x: x, mesh=mesh, in_specs=P(unreduced={'x'}), + out_specs=P())(np.arange(8)) + + with self.assertRaisesRegex( + NotImplementedError, + 'unreduced.*can only be passed to in_specs when shard_map is in full' + ' manual mode'): + shard_map(lambda x: x, mesh=mesh, in_specs=P(unreduced={'x'}), + out_specs=P(), axis_names={'x'})(np.arange(8)) + def test_partial_auto(self): mesh = jtu.create_mesh((2, 2), ('i', 'j')) def g(x): - self.assertDictEqual(x.aval.sharding.mesh._axis_types_dict, - {AxisType.Manual: ('i',), AxisType.Auto: ('j',)}) - x = jax.lax.with_sharding_constraint( - x, jax.sharding.NamedSharding(mesh, P(None, 'j'))) + self.assertTupleEqual(x.aval.sharding.mesh.axis_types, + (AxisType.Manual, AxisType.Auto)) + x = jax.lax.with_sharding_constraint(x, P(None, 'j')) return x * x @jax.jit def f(x): - x = shard_map(g, mesh, + x = shard_map(g, mesh=mesh, in_specs=P('i', None), out_specs=P('i', None), - auto=frozenset({'j'}))(x) + axis_names=frozenset({'i'}))(x) return jax.lax.with_sharding_constraint( x, jax.sharding.NamedSharding(mesh, P('i', 'j'))) @@ -1909,8 +2967,8 @@ def f(x): v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) if config.use_shardy_partitioner.value: self.assertIn( - 'in_shardings=[<@mesh, [{"i"}, {?}]>]' - ' out_shardings=[<@mesh, [{"i"}, {?}]>] manual_axes={"i"}', + 'in_shardings=[<@mesh, [{"i", ?}, {?}]>]' + ' out_shardings=[<@mesh, [{"i", ?}, {?}]>] manual_axes={"i"}', f.lower(v).as_text(), ) else: @@ -1921,13 +2979,13 @@ def f(x): ) self.assertAllClose(v * v, f(v), check_dtypes=False) - def test_partial_auto_explicit_no_use_mesh(self): + def test_partial_auto_explicit_no_set_mesh(self): mesh = jtu.create_mesh((2, 2), ('i', 'j'), axis_types=(AxisType.Explicit,) * 2) def g(x): - self.assertDictEqual(x.aval.sharding.mesh._axis_types_dict, - {AxisType.Manual: ('i',), AxisType.Explicit: ('j',)}) + self.assertTupleEqual(x.aval.sharding.mesh.axis_types, + (AxisType.Manual, AxisType.Explicit)) self.assertEqual(x.aval.sharding.spec, P(None, 'j')) out = x * x self.assertEqual(out.aval.sharding.spec, P(None, 'j')) @@ -1935,10 +2993,10 @@ def g(x): @jax.jit def f(x): - x = shard_map(g, mesh, + x = shard_map(g, mesh=mesh, in_specs=P('i', None), out_specs=P('i', None), - auto=frozenset({'j'}))(x) + axis_names=frozenset({'i'}))(x) self.assertEqual(x.aval.sharding.spec, P('i', 'j')) return x @@ -1949,11 +3007,11 @@ def f(x): self.assertEqual(out.sharding, NamedSharding(mesh, P('i', 'j'))) self.assertAllClose(v * v, out, check_dtypes=False) - @jtu.with_user_mesh((2, 2), ('i', 'j')) + @jtu.with_explicit_mesh((2, 2), ('i', 'j')) def test_partial_auto_explicit(self, mesh): def g(x): - self.assertDictEqual(x.aval.sharding.mesh._axis_types_dict, - {AxisType.Manual: ('i',), AxisType.Explicit: ('j',)}) + self.assertTupleEqual(x.aval.sharding.mesh.axis_types, + (AxisType.Manual, AxisType.Explicit)) self.assertEqual(x.aval.sharding.spec, P(None, 'j')) out = x * x self.assertEqual(out.aval.sharding.spec, P(None, 'j')) @@ -1961,10 +3019,7 @@ def g(x): @jax.jit def f(x): - x = shard_map(g, mesh, - in_specs=P('i', None), - out_specs=P('i', None), - auto=frozenset({'j'}))(x) + x = jax.shard_map(g, out_specs=P('i', None), axis_names=frozenset({'i'}))(x) self.assertEqual(x.aval.sharding.spec, P('i', 'j')) return x @@ -1993,12 +3048,12 @@ def h(x): jax.grad(h)(v) # doesn't crash jax.jit(jax.grad(h))(v) # doesn't crash - @jtu.with_user_mesh((2, 1, 2, 2), ('i', 'j', 'k', 'l')) + @jtu.with_explicit_mesh((2, 1, 2, 2), ('i', 'j', 'k', 'l')) def test_partial_auto_explicit_multi_explicit(self, mesh): def g(x): - self.assertDictEqual(x.aval.sharding.mesh._axis_types_dict, - {AxisType.Manual: ('i', 'j'), - AxisType.Explicit: ('k', 'l')}) + self.assertTupleEqual(x.aval.sharding.mesh.axis_types, + (AxisType.Manual, AxisType.Manual, + AxisType.Explicit, AxisType.Explicit)) self.assertEqual(x.aval.sharding.spec, P(None, None, 'k', 'l')) out = x.T self.assertEqual(out.aval.sharding.spec, P('l', 'k', None, None)) @@ -2006,10 +3061,8 @@ def g(x): @jax.jit def f(x): - x = shard_map(g, mesh, - in_specs=P('i', 'j', None, None), - out_specs=P('i', 'j', None, None), - auto=frozenset({'k', 'l'}))(x) + x = jax.shard_map(g, out_specs=P('i', 'j', None, None), + axis_names=frozenset({'i', 'j'}))(x) self.assertEqual(x.aval.sharding.spec, P(('i', 'l'), ('j', 'k'), None, None)) return x @@ -2031,11 +3084,11 @@ def g(x): def f(x): return shard_map( g, - mesh, + mesh=mesh, in_specs=P(), out_specs=P(), - check_rep=False, - auto=frozenset({'i'}), + check_vma=False, + axis_names=frozenset({'j', 'k'}), )(x) v = jnp.arange(32.0).reshape(4, 8) @@ -2067,13 +3120,43 @@ def update_fn(params, batch): def grad_fn(batch): return jax.value_and_grad(loss_fn)(params, batch) return shard_map(grad_fn, mesh=mesh, in_specs=P("data"), out_specs=P(), - check_rep=False)(batch) + check_vma=False)(batch) arr_sharded = jax.device_put(jnp.arange(32.0).reshape(4, 8), NamedSharding(mesh, P())) params = jnp.copy(arr_sharded) update_fn(params, arr_sharded) # doesn't crash + @jtu.with_explicit_mesh((2,), ('x',)) + def test_close_over_explicit_sharded_input_error(self, mesh): + def simple_func(w, x): + return jnp.sum(w * x, axis=-1) + + w = jnp.ones((2, 4), dtype=np.float32) + x = jnp.ones((4, 4), dtype=np.float32) + + shard_map(simple_func, in_specs=(P(), P('x')), out_specs=P('x'))(w, x) + + with self.assertRaisesRegex( + NotImplementedError, + 'Closing over inputs to shard_map where the input is sharded on' + ' `Explicit` axes is not implemented'): + shard_map(lambda xi: simple_func(w, xi), + in_specs=P('x'), out_specs=P('x'))(x) + + def test_close_over_input_explict_ctx_mesh(self): + mesh = jtu.create_mesh((2,), 'x', axis_types=(AxisType.Explicit,)) + w = jnp.ones((2, 4), dtype=np.float32) + x = jnp.ones((4, 4), dtype=np.float32) + + def simple_func(w, x): + return jnp.sum(w * x, axis=-1) + + shard_map(simple_func, mesh=mesh, in_specs=(P(), P('x')), + out_specs=P('x'))(w, x) + shard_map(lambda xi: simple_func(w, xi), mesh=mesh, + in_specs=P('x'), out_specs=P('x'))(x) + def test_shmap_close_over_unused_params_vmap(self): mesh = jtu.create_mesh((2,), ("data",)) @@ -2085,7 +3168,7 @@ def update_fn(params, batch): def grad_fn(batch): return jax.value_and_grad(loss_fn)(params, batch) return shard_map(jax.vmap(grad_fn), mesh=mesh, in_specs=P("data"), - out_specs=P("data"), check_rep=False)(batch) + out_specs=P("data"), check_vma=False)(batch) arr_sharded = jax.device_put(jnp.arange(32.0).reshape(4, 8), NamedSharding(mesh, P())) @@ -2119,11 +3202,11 @@ def g(x): @jax.jit def f(x): - x = shard_map(g, mesh, + x = shard_map(g, mesh=mesh, in_specs=P('i', None), out_specs=P('i', None), - check_rep=False, - auto=frozenset({'j'}))(x) + check_vma=False, + axis_names=frozenset({'i'}))(x) return jax.lax.with_sharding_constraint( x, jax.sharding.NamedSharding(mesh, P('i', 'j'))) @@ -2142,17 +3225,17 @@ def g(x): @jax.jit def f(x): - x = shard_map(g, mesh, + x = shard_map(g, mesh=mesh, in_specs=P('i', None), out_specs=P('i', None), - check_rep=False, - auto=frozenset({'k'}))(x) + check_vma=False, + axis_names=frozenset({'i', 'j'}))(x) return jax.lax.with_sharding_constraint( x, jax.sharding.NamedSharding(mesh, P('i', 'j'))) v = jnp.arange(32.).reshape(4, 8) v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) - with self.assertRaisesRegex(ValueError, "to be a subset of mesh.axis_names"): + with self.assertRaisesRegex(ValueError, "contains a manual axes.*of mesh"): f(v) def test_partial_auto_error_wrong_in_specs(self): @@ -2165,11 +3248,11 @@ def g(x): @jax.jit def f(x): - x = shard_map(g, mesh, + x = shard_map(g, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', None), - check_rep=False, - auto=frozenset({'j'}))(x) + check_vma=False, + axis_names=frozenset({'i'}))(x) return jax.lax.with_sharding_constraint( x, jax.sharding.NamedSharding(mesh, P('i', 'j'))) @@ -2178,28 +3261,83 @@ def f(x): with self.assertRaisesRegex(ValueError, "in_specs refers to 'j'"): f(v) - def test_nested_partial_auto(self): + def test_partial_auto_mismatch_mesh_error(self): mesh = jtu.create_mesh((2, 2), ('i', 'j')) + v = jnp.arange(32.).reshape(4, 8) + v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) def g(x): return x * x def h(x): - return shard_map(g, mesh, - in_specs=P(None, 'j'), - out_specs=P(None, 'j'))(x) + return shard_map(g, mesh=mesh, in_specs=P(None, 'j'), + out_specs=P(None, 'j'))(x) @jax.jit def f(x): - return shard_map(h, mesh, - in_specs=P('i', None), - out_specs=P('i', None), - check_rep=False, - auto=frozenset({'j'}))(x) + return shard_map(h, mesh=mesh, in_specs=P('i', None), + out_specs=P('i', None), check_vma=False, + axis_names=frozenset({'i'}))(x) + + with self.assertRaisesRegex( + ValueError, r"context mesh.*should match the mesh passed to shard_map"): + self.assertAllClose(v*v, f(v), check_dtypes=False) + def test_nested_partial_auto(self): + mesh = jtu.create_mesh((2, 2), ('i', 'j')) v = jnp.arange(32.).reshape(4, 8) v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) - self.assertAllClose(v*v, f(v), check_dtypes=False) + + def g(x): + return x * x + + def h(x): + return shard_map(g, in_specs=P(None, 'j'), out_specs=P(None, 'j'))(x) + + @jax.jit + def f(x): + return shard_map(h, in_specs=P('i', None), out_specs=P('i', None), + check_vma=False, axis_names=frozenset({'i'}))(x) + + with jax.set_mesh(mesh): + self.assertAllClose(v*v, f(v), check_dtypes=False) + + @parameterized.named_parameters( + ('0', 'x', 'y', {'x'}, {'x', 'y'}), + ('1', None, 'y', frozenset(), {'y'}), + ('2', 'x', None, {'x'}, {'x'}), + ('3', None, None, frozenset(), frozenset()), + ) + def test_nested_partial_auto_1d(self, dim1, dim2, outer_vma, inner_vma): + mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z')) + np_inp = np.arange(32.).reshape(4, 8) + arr = jax.device_put(np_inp, NamedSharding(mesh, P(dim1, dim2))) + + def g(x): + self.assertEqual(get_abstract_mesh().manual_axes, ('x', 'y')) + self.assertEqual(get_abstract_mesh().auto_axes, ('z',)) + self.assertEqual(x.aval.vma, inner_vma) + out = x * x + self.assertEqual(out.aval.vma, inner_vma) + return out + + def h(x): + self.assertEqual(get_abstract_mesh().manual_axes, ('x',)) + self.assertEqual(get_abstract_mesh().auto_axes, ('y', 'z')) + self.assertEqual(x.aval.vma, outer_vma) + out = shard_map(g, in_specs=P(None, dim2), + out_specs=P(None, dim2), axis_names={'y'})(x) + self.assertEqual(out.aval.vma, outer_vma) + return out + + @jax.jit + def f(x): + return shard_map(h, in_specs=P(dim1, None), + out_specs=P(dim1, None), axis_names={'x'})(x) + + with jax.set_mesh(mesh): + out = f(arr) + self.assertArraysEqual(out, np_inp * np_inp) def test_grad_nested_partial_auto(self): mesh = jtu.create_mesh((2, 2), ('i', 'j')) @@ -2210,22 +3348,19 @@ def g(x): def h(x): # auto: 'j', manual: 'i' - return shard_map(g, mesh, - in_specs=P(None, 'j'), - out_specs=P(None, 'j'))(x) + return shard_map(g, in_specs=P(None, 'j'), out_specs=P(None, 'j'))(x) @jax.jit def f(x): # auto: 'i', 'j' - return shard_map(h, mesh, - in_specs=P('i', None), - out_specs=P('i', None), - check_rep=False, - auto=frozenset({'j'}))(x).sum() + return shard_map(h, in_specs=P('i', None), out_specs=P('i', None), + check_vma=False, axis_names=frozenset({'i'}))(x).sum() v = jnp.arange(32.).reshape(4, 8) v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) - self.assertAllClose(v*2, jax.grad(f)(v), check_dtypes=False) + with jax.set_mesh(mesh): + out = jax.grad(f)(v) + self.assertAllClose(out, v * 2, check_dtypes=False) def test_grad_nested_partial_auto_with_residuals(self): mesh = jtu.create_mesh((2, 2), ('i', 'j')) @@ -2234,21 +3369,18 @@ def g(x): return x * x * x def h(x): - return shard_map(g, mesh, - in_specs=P(None, 'j'), - out_specs=P(None, 'j'))(x) + return shard_map(g, in_specs=P(None, 'j'), out_specs=P(None, 'j'))(x) @jax.jit def f(x): - return shard_map(h, mesh, - in_specs=P('i', None), - out_specs=P('i', None), - check_rep=False, - auto=frozenset({'j'}))(x).sum() + return shard_map(h, in_specs=P('i', None), out_specs=P('i', None), + check_vma=False, axis_names=frozenset({'i'}))(x).sum() v = jnp.arange(32.).reshape(4, 8) v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) - self.assertAllClose(v*v*3, jax.grad(f)(v), check_dtypes=False) + with jax.set_mesh(mesh): + out = jax.grad(f)(v) + self.assertAllClose(out, v * v * 3, check_dtypes=False) def test_axis_size_1_partial_auto(self): mesh = jtu.create_mesh((1, 2, 2), ('i', 'j', 'k')) @@ -2258,11 +3390,11 @@ def h(x): @jax.jit def f(x): - return shard_map(h, mesh, + return shard_map(h, mesh=mesh, in_specs=P('i', None), out_specs=P('i', None), - check_rep=False, - auto=frozenset({'j', 'k'}))(x) + check_vma=False, + axis_names=frozenset({'i'}))(x) v = jnp.arange(32.).reshape(4, 8) v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) @@ -2280,8 +3412,8 @@ def _make_zeros(): def f(): return shard_map( - h, mesh, in_specs=(), - out_specs=P('i'), check_rep=False, auto=frozenset({'j'}))() + h, mesh=mesh, in_specs=(), + out_specs=P('i'), check_vma=False, axis_names=frozenset({'i'}))() self.assertAllClose(jax.jit(f)(), jnp.zeros((2,))) @@ -2303,8 +3435,8 @@ def _make_zeros(): def f(): return shard_map( - h, mesh, in_specs=(), - out_specs=P('i'), check_rep=False, auto=frozenset({'j'}))() + h, mesh=mesh, in_specs=(), + out_specs=P('i'), check_vma=False, axis_names=frozenset({'i'}))() self.assertAllClose(jax.jit(f)(), jnp.zeros((2,))) @@ -2312,23 +3444,24 @@ def test_partial_auto_axis_index(self): mesh = jtu.create_mesh((4, 2), ('i', 'j')) out_sharding = NamedSharding(mesh, P('i', None)) - @partial(jax.jit, out_shardings=out_sharding) + @jax.jit(out_shardings=out_sharding) def f(): return shard_map(lambda: jax.lax.axis_index('i').reshape(1,1), - mesh, in_specs=P('i', None), out_specs=P('i', None), - check_rep=False, auto=frozenset({'j'}))() + in_specs=P('i', None), out_specs=P('i', None), + check_vma=False, axis_names=frozenset({'i'}))() - self.assertAllClose(f(), np.arange(4, dtype=np.int32).reshape(-1, 1)) + with jax.set_mesh(mesh): + self.assertAllClose(f(), np.arange(4, dtype=np.int32).reshape(-1, 1)) def test_partial_auto_axis_index_degenerated_axis(self): mesh = jtu.create_mesh((1, 2), ('i', 'j')) out_sharding = NamedSharding(mesh, P('i', None)) - @partial(jax.jit, out_shardings=out_sharding) + @jax.jit(out_shardings=out_sharding) def f(): return shard_map(lambda: jax.lax.axis_index('i').reshape(1, 1), - mesh, in_specs=P('i', None), out_specs=P('i', None), - check_rep=False, auto=frozenset({'j'}))() + mesh=mesh, in_specs=P('i', None), out_specs=P('i', None), + check_vma=False, axis_names=frozenset({'i'}))() self.assertAllClose(f(), np.arange(1, dtype=np.int32).reshape(-1, 1)) @@ -2343,8 +3476,8 @@ def g(x): @jax.jit def f(x): return shard_map(g, - mesh, in_specs=P('i'), out_specs=P('i'), - check_rep=False, auto=frozenset({'j'}))(x) + mesh=mesh, in_specs=P('i'), out_specs=P('i'), + check_vma=False, axis_names=frozenset({'i'}))(x) y = f(x) # don't crash self.assertAllClose(y, jnp.array([6., 7., 0., 1., 2., 3., 4., 5.]), @@ -2363,8 +3496,8 @@ def f(x): # @jax.jit # def f(x): # return shard_map(g, - # mesh, in_specs=P('i', None), out_specs=P(None, 'i'), - # check_rep=False, auto=frozenset({'j'}))(x) + # mesh=mesh, in_specs=P('i', None), out_specs=P(None, 'i'), + # check_vma=False, axis_names=frozenset({'i'}))(x) # # f(x) # don't crash @@ -2380,11 +3513,11 @@ def g(x): @jax.jit def f(x): - return shard_map(g, - mesh, in_specs=P('i'), out_specs=None, - check_rep=False, auto=frozenset({'j'}))(x) + return shard_map(g, mesh=mesh, in_specs=P('i'), out_specs=None, + check_vma=False, axis_names=frozenset({'i'}))(x) - y = f(x) # don't crash + with jax.set_mesh(mesh): + f(x) # don't crash def test_partial_auto_of_random_keys(self): mesh = jtu.create_mesh((4, 2), ('i', 'j')) @@ -2393,8 +3526,8 @@ def test_partial_auto_of_random_keys(self): @jax.jit def f(x): return shard_map(lambda k: k, - mesh, in_specs=P('i'), out_specs=P('i'), - check_rep=False, auto=frozenset({'j'}))(keys) + mesh=mesh, in_specs=P('i'), out_specs=P('i'), + check_vma=False, axis_names=frozenset({'i'}))(keys) y = f(keys) # doesn't crash self.assertAllClose(jax.random.key_data(y), jax.random.key_data(keys), @@ -2407,21 +3540,26 @@ def test_partial_auto_of_random_keys_slice(self): @jax.jit def f(x): return shard_map(lambda k: k[0], - mesh, in_specs=P('i'), out_specs=P('i'), - check_rep=False, auto=frozenset({'j'}))(x) + mesh=mesh, in_specs=P('i'), out_specs=P('i'), + check_vma=False, axis_names=frozenset({'i'}))(x) f(keys) # doesn't crash + def test_grad_remat(self): + mesh = jtu.create_mesh((1, 1), ('i', 'j')) + args = [jnp.arange(6.).reshape(3, 2), jnp.arange(6.).reshape(3, 2, 1)] + + @partial(jax.remat, policy=lambda *_, **__: True) + @shard_map(mesh=mesh, in_specs=(P('j'), P('i')), out_specs=P('i', 'j')) + def f(x, y): + return jnp.dot(x, y) + jax.grad(lambda x, y: f(x, y).sum())(*args) + def test_vmap_grad_shmap_spmd_axis_name_residuals(self): # https://github.com/jax-ml/jax/pull/21032 mesh = jtu.create_mesh((4, 2), ('i', 'j')) - @partial( - shard_map, - mesh=mesh, - in_specs=P('j'), - out_specs=P('j'), - ) + @shard_map(mesh=mesh, in_specs=P('j'), out_specs=P('j')) def f(x): return jnp.sin(x) @@ -2434,12 +3572,7 @@ def test_vmap_grad_remat_shmap_spmd_axis_name_residuals(self): mesh = jtu.create_mesh((4, 2), ('i', 'j')) @partial(jax.remat, policy=lambda *_, **__: True) - @partial( - shard_map, - mesh=mesh, - in_specs=P('j'), - out_specs=P('j'), - ) + @partial(shard_map, mesh=mesh, in_specs=P('j'), out_specs=P('j')) def f(x): return jnp.sin(x) @@ -2454,8 +3587,8 @@ def test_grad_shmap_residuals_axis_names_in_mesh_order(self): @partial( shard_map, mesh=mesh, - in_specs=P('j'), - out_specs=P('j'), + in_specs=P(('i', 'k')), + out_specs=P(('i', 'k')), ) def f(x): return jnp.sin(x) @@ -2465,22 +3598,45 @@ def f(x): ir = jax.jit(jax.grad(lambda x: f(x).sum())).lower(xs) if config.use_shardy_partitioner.value: self.assertIn( - 'out_shardings=[<@mesh, [{"i", "j", "k", "a"}]>]', ir.as_text() + 'out_shardings=[<@mesh, [{"i", "k"}]>]', ir.as_text() ) else: self.assertIn( - "{jax.result_info = \"[('i', 'j', 'k', 'a')]\"}", ir.as_text() + "{jax.result_info = \"[('i', 'k')]\"}", ir.as_text() ) + def test_dynamic_slice_transpose(self): + mesh = jtu.create_mesh((2,), ('x',)) + arr = np.arange(16., dtype=np.float32) + + @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x')) + def f(x): + return lax.dynamic_slice_in_dim(x, jnp.array(1, dtype=np.int32), 2) + + f(arr) # doesn't crash + jax.jit(f)(arr) # doesn't crash + + def g(x): + return jnp.sum(f(x)) + + jax.grad(g)(arr) # doesn't crash + jax.jit(jax.grad(g))(arr) # doesn't crash + + @parameterized.parameters([P()], [P('x')], [P(('x', 'y'))]) + def test_print_inside_shard_map(self, specs): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + x = jnp.arange(4.) + + @partial(shard_map, mesh=mesh, in_specs=specs, out_specs=specs) + def f(x): + print(x) + return 2 * x + f(x) # doesn't crash + def test_vmap_spmd_axis_name_error(self): mesh = jtu.create_mesh((4, 2), ('i', 'j')) - @partial( - shard_map, - mesh=mesh, - in_specs=P('i'), - out_specs=P('i'), - ) + @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i')) def f(x): return jnp.sin(x) @@ -2488,13 +3644,8 @@ def f(x): with self.assertRaisesRegex(ValueError, "spmd_axis_name cannot appear"): jax.vmap(f, spmd_axis_name='i')(xs) - @partial( - shard_map, - mesh=mesh, - in_specs=P('j'), - out_specs=P(('i', 'j')), - check_rep=False, - ) + @partial(shard_map, mesh=mesh, in_specs=P('j'), out_specs=P(('i', 'j')), + check_vma=False) def g(x): return jnp.sin(x) @@ -2512,11 +3663,11 @@ def f(o, x): return jnp.sin(x) obj = object() - y = shard_map(f, mesh, (None, P('i')), P('i'))(obj, x) + y = shard_map(f, mesh=mesh, in_specs=(None, P('i')), out_specs=P('i'))(obj, x) self.assertAllClose(y, jnp.sin(x), check_dtypes=False) obj = None - y = shard_map(f, mesh, (None, P('i')), P('i'))(None, x) + y = shard_map(f, mesh=mesh, in_specs=(None, P('i')), out_specs=P('i'))(None, x) self.assertAllClose(y, jnp.sin(x), check_dtypes=False) def f2(o, x): @@ -2525,7 +3676,7 @@ def f2(o, x): return jnp.sin(x) obj = {'a': object()} - y = shard_map(f2, mesh, ({'a': None}, P('i')), P('i'))(obj, x) + y = shard_map(f2, mesh=mesh, in_specs=({'a': None}, P('i')), out_specs=P('i'))(obj, x) self.assertAllClose(y, jnp.sin(x), check_dtypes=False) def f3(x, o): @@ -2533,11 +3684,11 @@ def f3(x, o): return jnp.sin(x) obj = object() - y = shard_map(f3, mesh, (P('i'), None), P('i'))(x, obj) + y = shard_map(f3, mesh=mesh, in_specs=(P('i'), None), out_specs=P('i'))(x, obj) self.assertAllClose(y, jnp.sin(x), check_dtypes=False) obj = None - y = shard_map(f3, mesh, (P('i'), None), P('i'))(x, obj) + y = shard_map(f3, mesh=mesh, in_specs=(P('i'), None), out_specs=P('i'))(x, obj) self.assertAllClose(y, jnp.sin(x), check_dtypes=False) def f4(o1, o2, x, o3): @@ -2550,7 +3701,8 @@ def f4(o1, o2, x, o3): obj1 = object() obj2 = (object(), object()) obj3 = object() - y = shard_map(f4, mesh, (None, None, P('i'), None), P('i'))(obj1, obj2, x, obj3) + y = shard_map(f4, mesh=mesh, in_specs=(None, None, P('i'), None), + out_specs=P('i'))(obj1, obj2, x, obj3) self.assertAllClose(y, jnp.sin(x), check_dtypes=False) def test_in_spec_none_divisibility_errors(self): @@ -2558,44 +3710,48 @@ def test_in_spec_none_divisibility_errors(self): x = jnp.arange(4).reshape(2, 2) with self.assertRaisesRegex(ValueError, 'divisible'): - shard_map(lambda *_: None, mesh, (None, P('i')), None)(object(), x) + shard_map(lambda *_: None, mesh=mesh, in_specs=(None, P('i')), + out_specs=None)(object(), x) with self.assertRaisesRegex(ValueError, 'divisible'): - shard_map(lambda *_: None, mesh, (P('i'), None), None)(x, object()) + shard_map(lambda *_: None, mesh=mesh, in_specs=(P('i'), None), + out_specs=None)(x, object()) with self.assertRaisesRegex(ValueError, 'divisible'): - shard_map(lambda *_: None, mesh, (P('i'), None), None - )(x, (object(), object())) + shard_map(lambda *_: None, mesh=mesh, in_specs=(P('i'), None), + out_specs=None)(x, (object(), object())) with self.assertRaisesRegex(ValueError, 'divisible'): - shard_map(lambda *_: None, mesh, (P('i'), (None, None)), None, - )(x, (object(), object())) + shard_map(lambda *_: None, mesh=mesh, in_specs=(P('i'), (None, None)), + out_specs=None)(x, (object(), object())) with self.assertRaisesRegex(ValueError, 'divisible'): - shard_map(lambda *_: None, mesh, ((None, None), P('i')), None, - )((object(), object()), x) + shard_map(lambda *_: None, mesh=mesh, in_specs=((None, None), P('i')), + out_specs=None)((object(), object()), x) def test_in_spec_none_rank_errors(self): mesh = jtu.create_mesh((4, 2), ('i', 'j')) x = jnp.arange(4) with self.assertRaisesRegex(ValueError, 'rank'): - shard_map(lambda *_: None, mesh, (None, P('i', 'j')), None)(object(), x) + shard_map(lambda *_: None, mesh=mesh, in_specs=(None, P('i', 'j')), + out_specs=None)(object(), x) with self.assertRaisesRegex(ValueError, 'rank'): - shard_map(lambda *_: None, mesh, (P('i', 'j'), None), None)(x, object()) + shard_map(lambda *_: None, mesh=mesh, in_specs=(P('i', 'j'), None), + out_specs=None)(x, object()) with self.assertRaisesRegex(ValueError, 'rank'): - shard_map(lambda *_: None, mesh, (P('i', 'j'), None), None - )(x, (object(), object())) + shard_map(lambda *_: None, mesh=mesh, in_specs=(P('i', 'j'), None), + out_specs=None)(x, (object(), object())) with self.assertRaisesRegex(ValueError, 'rank'): - shard_map(lambda *_: None, mesh, (P('i', 'j'), (None, None)), None, - )(x, (object(), object())) + shard_map(lambda *_: None, mesh=mesh, in_specs=(P('i', 'j'), (None, None)), + out_specs=None)(x, (object(), object())) with self.assertRaisesRegex(ValueError, 'rank'): - shard_map(lambda *_: None, mesh, ((None, None), P('i', 'j')), None, - )((object(), object()), x) + shard_map(lambda *_: None, mesh=mesh, in_specs=((None, None), P('i', 'j')), + out_specs=None)((object(), object()), x) def test_custom_linear_solve_rep_rules(self): # https://github.com/jax-ml/jax/issues/20162 @@ -2616,7 +3772,7 @@ def test_temporary_error_suppression_flag(self): def f(x, y): z = shard_map(lambda x, y: x + jax.lax.all_gather(y, 'i', tiled=True), mesh=mesh, in_specs=(P(None), P('i')), out_specs=P(None), - check_rep=False, + check_vma=False, )(x, y) return z @@ -2641,6 +3797,29 @@ def f(a): f(A()) # don't crash + @parameterized.named_parameters( + ('axis_name', True), + ('no_axis_name', False), + ) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_explicit_vmap_grad_shmap(self, use_axis_name, mesh): + np_inp = np.arange(6 * 24, dtype=np.float32).reshape(6, 24) + arr = jax.device_put(np_inp, P('x', None)) + + def g(x): + self.assertEqual(x.aval.vma, frozenset()) + self.assertEqual(x.aval.sharding.spec, P(None)) + if use_axis_name: + out = jax.shard_map(jnp.cos, in_specs=P('y'), out_specs=P('y'), + axis_names={'y'})(x) + else: + out = jax.shard_map(jnp.cos, in_specs=P('y'), out_specs=P('y'))(x) + self.assertEqual(out.aval.sharding.spec, P('y')) + return out.sum() + + out = jax.jit(jax.vmap(jax.grad(g)))(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + def test_get_check_rep(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) @@ -2650,13 +3829,7 @@ def f(x, reduce_along, use_jit): @partial(shard_map, mesh=mesh, in_specs=P('x', 'y'), out_specs=out_spec) def g(x): result = lax.psum(x, axis_name=reduce_along) - def check_rep(result): - self.assertEqual( - jax.experimental.shard_map.get_replication(result), - set(reduce_along)) - return result - result = check_rep(result) - result = jax.vmap(check_rep)(result) + self.assertEqual(result.aval.vma, x.aval.vma - set(reduce_along)) return result if use_jit: return jax.jit(g)(x) @@ -2673,26 +3846,1018 @@ def test_pmin(self): mesh = jtu.create_mesh((4,), ('i',)) x = jnp.arange(8., dtype=np.float32) y = shard_map(lambda x: jax.lax.pmin(x, 'i'), - mesh=mesh, in_specs=P('i'), out_specs=P() - )(x) # don't crash + mesh=mesh, in_specs=P('i'), out_specs=P())(x) # don't crash self.assertArraysEqual(y, np.array([0, 1], dtype=np.float32)) def test_pmax(self): mesh = jtu.create_mesh((4,), ('i',)) x = jnp.arange(8., dtype=np.float32) y = shard_map(lambda x: jax.lax.pmax(x, 'i'), - mesh=mesh, in_specs=P('i'), out_specs=P() - )(x) # don't crash + mesh=mesh, in_specs=P('i'), out_specs=P())(x) # don't crash self.assertArraysEqual(y, np.array([6, 7], dtype=np.float32)) + def test_pmax_vma_in_types(self): + mesh = jtu.create_mesh((4,), ('i',)) + x = jnp.arange(8., dtype=np.float32) + f = jax.jit(shard_map(lambda x: jax.lax.pmax(x, 'i'), mesh=mesh, + in_specs=P(), out_specs=P())) + jaxpr = f.trace(x).jaxpr + self.assertIn("pvary[axes=('i',)", str(jaxpr)) + f(x) # doesn't crash + + def test_mul_with_vma_in_types(self): + mesh = jtu.create_mesh((2,), ('x',)) + x = np.arange(8.) + + def f(x): + self.assertEqual(x.aval.vma, frozenset({'x'})) + out = x * 2 + self.assertEqual(out.aval.vma, frozenset({'x'})) + return out + + f = jax.jit(shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))) + jaxpr = f.trace(x).jaxpr + self.assertIn("pvary[axes=('x',)", str(jaxpr)) + out = f(x) + self.assertArraysEqual(out, x * 2) + + # TODO(yashkatariya): Enable grad test which requires adding psum_p support. + # def g(x, y): + # return jnp.sum(f(x, y)) + # print(jax.jit(jax.grad(g)).trace(x, y).jaxpr) + + def test_all_gather_with_vma_in_types(self): + mesh = jtu.create_mesh((2,), ('x',)) + x = np.arange(8.) + + def f(x): + self.assertEqual(x.aval.vma, frozenset()) + out = jax.lax.all_gather(x, 'x') + self.assertEqual(out.aval.vma, frozenset({'x'})) + return out + + f = jax.jit(shard_map(f, mesh=mesh, in_specs=P(), out_specs=P('x'))) + jaxpr = f.trace(x).jaxpr + self.assertIn("pvary[axes=('x',)", str(jaxpr)) + + f(x) # doesn't crash + + def test_rep_none_canonicalization(self): + # https://github.com/jax-ml/jax/issues/26621 + if config.use_shardy_partitioner.value: + self.skipTest('complex values fail under shardy') + N = 8 + xs = jnp.ones((8, N), dtype=jnp.int32) + variables = jax.random.normal(jax.random.key(1), (N, N), jnp.complex64) + mesh = jtu.create_mesh((2,), ('i',)) + in_specs = (P(), P("i"),) + out_specs = P("i") + + variables = jax.lax.with_sharding_constraint(variables, NamedSharding(mesh, P())) + xs = jax.lax.with_sharding_constraint(xs, NamedSharding(mesh, P('i'))) + + def fun(v, xs): + # Commenting this single line below makes everything work + v = jax.scipy.linalg.expm(v) + v = v.sum() + return v * xs.sum(axis=-1).astype(v.dtype) + + res = fun(variables, xs) + fun_shard_map = shard_map(fun, mesh=mesh, in_specs=in_specs, out_specs=out_specs) + res = fun_shard_map(variables, xs) # don't crash + + def test_rep_none_canonicalization_again(self): + # https://github.com/jax-ml/jax/issues/24762 + mesh = jtu.create_mesh((2,), ('i',)) + def f(x): + return jnp.insert(x, 0, 0)[None] + f = shard_map(f, mesh=mesh, in_specs=P('i'), out_specs=P('i')) + f(jnp.zeros(100)) # don't crash + + def test_custom_jvp_symbolic_zeros(self): + # https://github.com/jax-ml/jax/issues/26763 + mesh = jtu.create_mesh((4,), ('i',)) + @jax.custom_jvp + def f(a: jax.Array, b: jax.Array) -> jax.Array: + return a + b + + @partial(f.defjvp, symbolic_zeros=True) + def f_jvp(primals, tangents): + a, b = primals + a_dot, b_dot = tangents + y = f(a, b) + y_dot = jnp.zeros_like(y) + if not isinstance(a_dot, SymbolicZero): + y_dot += a_dot + if not isinstance(b_dot, SymbolicZero): + y_dot += b_dot + return y, y_dot + x = jax.random.normal(jax.random.key(0), (jax.device_count(), 20)) + A = jax.random.normal(jax.random.key(1), (jax.device_count(), 20)) + + g = shard_map(f, mesh=mesh, in_specs=P('i'), out_specs=P('i')) + jax.jvp(lambda x: g(x, A), (x,), (x,)) # don't crash + + def test_cond_pvary_errors(self): + mesh = jtu.create_mesh((1, 1), ('x', 'y')) + def f(x, y): + def true_fn(x, y): + return x + def false_fun(x, y): + return y + return jax.lax.cond(True, true_fn, false_fun, x, y) + x = jnp.arange(4.) + with self.assertRaisesRegex( + TypeError, + r"applying `jax.lax.pcast\(..., \('y',\).*to the output of true_fun"): + shard_map(f, mesh=mesh, in_specs=(P('x'), P('y')), out_specs=P(('x', 'y')))(x, x) + + def test_cond_pvary_errors_pytree(self): + mesh = jtu.create_mesh((1, 1), ('x', 'y')) + + def f(x, y): + def true_fn(x, y): + return x, y + def false_fun(x, y): + return y, x + return jax.lax.cond(True, true_fn, false_fun, x, y) + x = jnp.arange(4.) + with self.assertRaisesRegex( + TypeError, + r"applying `jax.lax.pcast\(..., \('y',\).*to the output of true_fun"): + shard_map(f, mesh=mesh, in_specs=(P('x'), P('y')), out_specs=P(('x', 'y')))(x, x) + + def test_scan_pvary_errors(self): + mesh = jtu.create_mesh((1, 1), ('i', 'j')) + x = jnp.arange(3.) + y = jnp.arange(3.) + + @partial(shard_map, mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i')) + def f(x, y): + def body(carry, _): + c1, c2 = carry + return (c2, c1), () # swap the carry + (x_, y_), _ = jax.lax.scan(body, (x, y), (), length=2) + return x_, y_ + + with self.assertRaisesRegex( + TypeError, + r"This might be fixed by applying `jax.lax.pcast\(..., \('i',\).*to" + r' the initial'): + f(x, y) + + @partial(shard_map, mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i')) + def g(x, y): + def body(carry, _): + c1, c2 = carry + return (c2, c1), () + y = jax.lax.pcast(y, 'i', to='varying') # fix the issue + (x_, y_), _ = jax.lax.scan(body, (x, y), (), length=2) + return x_, y_ + + g(x, y) # doesn't crash + + def test_scan_pvary_errors2(self): + mesh = jtu.create_mesh((1, 1), ('i', 'j')) + x = jnp.arange(3.) + y = jnp.arange(3.) + z = jnp.arange(3.) + + @partial(shard_map, mesh=mesh, in_specs=(P('i'), P(), P(('i', 'j'))), out_specs=P(('i', 'j'))) + def f(x, y, z): + def body(carry, _): + c1, c2, c3 = carry + return (c3, c1, c2), () # swap the carry + carry, _ = jax.lax.scan(body, (x, y, z), (), length=2) + return carry + + with self.assertRaisesRegex( + TypeError, + r'This might be fixed by:\n \* applying `jax.lax.pcast\(...,' + r" \('j',\)"): + f(x, y, z) + + @partial(shard_map, mesh=mesh, in_specs=(P('i'), P(), P(('i', 'j'))), out_specs=P(('i', 'j'))) + def g(x, y, z): + def body(carry, _): + c1, c2, c3 = carry + return (c3, c1, c2), () # swap the carry + + x = jax.lax.pcast(x, 'j', to='varying') # fix the issue + y = jax.lax.pcast(y, ('i', 'j'), to='varying') + carry, _ = jax.lax.scan(body, (x, y, z), (), length=2) + return carry + + g(x, y, z) # doesn't crash + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_shmap_full_manual_context_explicit(self, mesh): + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, P('x', 'y')) + + @partial(jax.shard_map, out_specs=P('x', 'y')) + def f(x): + self.assertEqual(get_abstract_mesh().manual_axes, ('x', 'y')) + self.assertEqual(x.aval.vma, {'x', 'y'}) + out = x * 2 + self.assertEqual(out.aval.vma, {'x', 'y'}) + return out + + out = f(arr) + self.assertArraysEqual(out, np_inp * 2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + jax.jit(f)(arr) # doesn't crash + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_shmap_partial_manual_explicit(self, mesh): + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, P('x', 'y')) + + @partial(jax.shard_map, axis_names=frozenset('x'), out_specs=P('x')) + def f(x): + self.assertEqual(get_abstract_mesh().manual_axes, ('x',)) + self.assertEqual(get_abstract_mesh().explicit_axes, ('y',)) + self.assertEqual(x.aval.sharding.spec, P(None, 'y')) + self.assertEqual(x.aval.vma, {'x'}) + out = x * 2 + self.assertEqual(out.aval.sharding.spec, P(None, 'y')) + self.assertEqual(out.aval.vma, {'x'}) + return out + + out = jax.jit(f)(arr) + self.assertArraysEqual(out, np_inp * 2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), axis_types=(AxisType.Auto,) * 2) + def test_shmap_full_manual_context_auto(self, mesh): + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, P('x', 'y')) + + @partial(jax.shard_map, in_specs=P('x', 'y'), out_specs=P('x', 'y')) + def f(x): + self.assertEqual(get_abstract_mesh().manual_axes, ('x', 'y')) + self.assertEqual(x.aval.vma, {'x', 'y'}) + out = x * 2 + self.assertEqual(out.aval.vma, {'x', 'y'}) + return out + + out = f(arr) + self.assertArraysEqual(out, np_inp * 2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + jax.jit(f)(arr) # doesn't crash + + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), axis_types=(AxisType.Auto,) * 2) + def test_shmap_partial_manual_auto(self, mesh): + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, P('x', 'y')) + + @partial(jax.shard_map, axis_names=frozenset('x'), in_specs=P('x'), + out_specs=P('x')) + def f(x): + self.assertEqual(get_abstract_mesh().manual_axes, ('x',)) + self.assertEqual(get_abstract_mesh().auto_axes, ('y',)) + self.assertEqual(x.aval.vma, {'x'}) + out = x * 2 + self.assertEqual(out.aval.vma, {'x'}) + return out + + out = jax.jit(f)(arr) + self.assertArraysEqual(out, np_inp * 2) + + def test_no_mesh_context_error(self): + with self.assertRaisesRegex(ValueError, "The context mesh cannot be empty"): + jax.shard_map(lambda x: x, in_specs=P(), out_specs=P())(np.arange(8)) + + def test_pvary_in_shmap_of_grad(self): + mesh = jtu.create_mesh((2,), 'x') + + def g(x): + return jnp.mean(x ** 2) + + def f(x): + val, grad = jax.value_and_grad(g)(x) + return (jnp.atleast_1d(val), jnp.atleast_1d(grad)) + + jax.shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x') + )(jnp.ones(2,)) # doesn't crash + + @jtu.with_explicit_mesh((2,), ('data',)) + def test_jnp_histogram(self, mesh): + x = jnp.arange(8 * 4 * 2).reshape(8, 4, 2) + + def f(x, bin_edges): + hist, _ = jax.vmap(lambda q: jnp.histogram(q, bins=bin_edges))(x) + return hist + + bin_edges = jnp.histogram_bin_edges(x, bins=100) + g = jax.shard_map(f, in_specs=(P('data'), P()), out_specs=P('data')) + g(x, bin_edges) # doesn't crash + jax.jit(g)(x, bin_edges) # doesn't crash + + def test_shmap_linearize_and_linearize_transpose_error(self): + mesh = jtu.create_mesh((2,), ('x',)) + + def f(x): + return jnp.mean(x ** 2) + + def m(p, t): + out_p, fwd = jax.linearize(f, p) + out_t = fwd(t) + bwd = jax.linear_transpose(fwd, p) + return bwd(out_t) + + with self.assertRaisesRegex( + ValueError, + r"applying `jax.lax.pcast\(..., \('x',\).*to the primal value passed"): + shard_map(partial(m, jnp.array([1.])), mesh=mesh, in_specs=P('x'), + out_specs=P('x'))(jnp.ones((2,))) # doesn't crash + + def m2(p, t): + p = jax.lax.pcast(p, 'x', to='varying') # fixes the issue + out_p, fwd = jax.linearize(f, p) + out_t = fwd(t) + bwd = jax.linear_transpose(fwd, p) + return bwd(out_t) + + shard_map(partial(m2, jnp.array([1.])), mesh=mesh, in_specs=P('x'), + out_specs=P('x'))(jnp.ones((2,))) # doesn't crash + + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), axis_types=(AxisType.Auto,) * 2) + def test_argmax_pvary(self, mesh): + @jax.shard_map(in_specs=P('x', 'y'), out_specs=P('x', 'y')) + def argmax_impl(x): + y = x.argmax(axis=-1, keepdims=1) + return y + + argmax_impl(jax.random.normal(jax.random.key(0), (1024, 1024))) # doesn't crash + + @parameterized.parameters([False, True]) + def test_smap(self, jit): + mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z')) + np_inp = np.arange(32.).reshape(4, 8) + arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) + + def g(x): + self.assertEqual(get_abstract_mesh().manual_axes, ('x', 'y')) + self.assertEqual(get_abstract_mesh().auto_axes, ('z',)) + self.assertEqual(x.aval.vma, {'x', 'y'}) + out = x * x + self.assertEqual(out.aval.vma, {'x', 'y'}) + return out + + def h(x): + self.assertEqual(get_abstract_mesh().manual_axes, ('x',)) + self.assertEqual(get_abstract_mesh().auto_axes, ('y', 'z')) + self.assertEqual(x.aval.vma, {'x'}) + out = jax.smap(g, in_axes=0, out_axes=0, axis_name='y')(x) + self.assertEqual(out.aval.vma, {'x'}) + return out + + def f(x): + return jax.smap(h, in_axes=0, out_axes=0, axis_name='x')(x) + + if jit: + f = jax.jit(f) + + with jax.set_mesh(mesh): + out = f(arr) + self.assertArraysEqual(out, np_inp * np_inp) + + @parameterized.parameters([False, True]) + @jtu.with_explicit_mesh((2, 2, 2), ('x', 'y', 'z')) + def test_smap_explicit(self, jit, mesh): + np_inp = np.arange(32.).reshape(4, 8) + arr = jax.device_put(np_inp, P('x', 'y')) + + def g(x): + self.assertEqual(get_abstract_mesh().manual_axes, ('x', 'y')) + self.assertEqual(get_abstract_mesh().explicit_axes, ('z',)) + self.assertEqual(x.aval.vma, {'x', 'y'}) + out = x * x + self.assertEqual(out.aval.vma, {'x', 'y'}) + return out + + def h(x): + self.assertEqual(get_abstract_mesh().manual_axes, ('x',)) + self.assertEqual(get_abstract_mesh().explicit_axes, ('y', 'z')) + self.assertEqual(x.aval.vma, {'x'}) + out = jax.smap(g, in_axes=0, out_axes=0, axis_name='y')(x) + self.assertEqual(out.aval.vma, {'x'}) + return out + + def f(x): + return jax.smap(h, out_axes=0, axis_name='x')(x) + + if jit: + f = jax.jit(f) + + out = f(arr) + self.assertArraysEqual(out, np_inp * np_inp) + + @parameterized.parameters([False, True]) + @jtu.with_explicit_mesh((2,), ('x',), axis_types=(AxisType.Auto,)) + def test_smap_replicated(self, jit, mesh): + @jax.smap(in_axes=None, out_axes=None, axis_name='x') + def f(x): + return x * 2 + out = f(np.arange(8)) + self.assertArraysEqual(out, np.arange(8) * 2) + self.assertEqual(out.sharding, NamedSharding(mesh, P())) + + @parameterized.parameters([False, True]) + @jtu.with_explicit_mesh((2,), ('data',), axis_types=(AxisType.Auto,)) + def test_smap_replicated_sharded(self, jit, mesh): + @jax.smap(in_axes=(None, 0), out_axes=(None, 0), axis_name='data') + def f(x, y): + return x * 2, y * 2 + + out1, out2 = f(np.arange(8), np.arange(8)) + self.assertArraysEqual(out1, np.arange(8) * 2) + self.assertEqual(out1.sharding, NamedSharding(mesh, P())) + self.assertArraysEqual(out2, np.arange(8) * 2) + self.assertEqual(out2.sharding, NamedSharding(mesh, P('data'))) + + @jax.smap(in_axes=(None, 0), out_axes=0, axis_name='data') + def g(x, y): + return x + y + + out = g(np.arange(4), np.arange(8)) + self.assertEqual(out.sharding, NamedSharding(mesh, P('data'))) + + @parameterized.parameters([False, True]) + @jtu.with_explicit_mesh((2,), ('x',), axis_types=(AxisType.Auto,)) + def test_smap_auto_error(self, jit, mesh): + with self.assertRaisesRegex(TypeError, "in_axes was not specified"): + jax.smap(lambda x: x * 2, out_axes=0, axis_name='x')(np.arange(4)) + + @parameterized.parameters([False, True]) + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), + axis_types=(AxisType.Explicit, AxisType.Auto)) + def test_smap_auto_explicit(self, jit, mesh): + def f(x): + self.assertEqual(x.aval.vma, {'x'}) + return x * 2 + + arr = jax.device_put(np.arange(4), P('x')) + g = jax.smap(f, out_axes=0, axis_name='x') + if jit: + g = jax.jit(g) + out = g(arr) + self.assertArraysEqual(out, np.arange(4) * 2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + + def g(x): + self.assertEqual(x.aval.vma, {'y'}) + return x * 2 + + arr = jax.device_put(np.arange(4), P('y')) + g = jax.smap(g, in_axes=0, out_axes=0, axis_name='y') + if jit: + g = jax.jit(g) + out = g(arr) + self.assertArraysEqual(out, np.arange(4) * 2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('y'))) + + @config.remove_size_one_mesh_axis_from_type(True) + @jtu.with_explicit_mesh((1,), 'x') + def test_pvary_no_op_one_sized_mesh_axis(self, mesh): + @jax.jit + def f(x): + return jax.lax.pcast(x, 'x', to='varying') + + jaxpr = f.trace(jnp.arange(8)).jaxpr + self.assertNotIn('pvary', str(jaxpr)) + + @parameterized.parameters([False, True]) + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), + axis_types=(AxisType.Explicit, AxisType.Auto)) + def test_smap_auto_explicit_nest(self, jit, mesh): + def g(b): + self.assertEqual(b.aval.vma, {'x', 'y'}) + return jnp.sin(b) + + def f(a): + self.assertEqual(a.aval.vma, {'y'}) + b = a * 2 + return jax.smap(g, in_axes=1, out_axes=1, axis_name='x')(b) + + arr = jax.device_put(np.arange(16).reshape(8, 2), P('y')) + h = jax.smap(f, in_axes=0, out_axes=0, axis_name='y') + if jit: + h = jax.jit(h) + h(arr) # doesn't crash + + @parameterized.parameters([False, True]) + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), + axis_types=(AxisType.Explicit, AxisType.Auto)) + def test_smap_auto_explicit_nest_inner_none(self, jit, mesh): + def g(b): + self.assertEqual(b.aval.vma, {'y'}) + return jnp.sin(b) + + def f(a): + self.assertEqual(a.aval.vma, {'y'}) + b = a * 2 + # Going manual over explicit axis `x` but in_axes is Infer and since + # input has no sharding, it will default to None. + return jax.smap(g, out_axes=1, axis_name='x')(b) + + arr = jax.device_put(np.arange(16).reshape(8, 2), P('y')) + h = jax.smap(f, in_axes=0, out_axes=0, axis_name='y') + if jit: + h = jax.jit(h) + h(arr) # doesn't crash + + @parameterized.parameters([False, True]) + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), + axis_types=(AxisType.Explicit, AxisType.Auto)) + def test_smap_auto_explicit_nest_mesh_call_time(self, jit, mesh): + @jax.smap(in_axes=1, out_axes=1, axis_name='x') + def g(b): + return jnp.sin(b) + + @jax.smap(in_axes=0, out_axes=0, axis_name='y') + def f(a): + self.assertEqual(a.aval.vma, {'y'}) + b = a * 2 + return g(b) + + if jit: + f = jax.jit(f) + + arr = jax.device_put(np.arange(16).reshape(8, 2), P('y')) + f(arr) # doesn't crash + + @parameterized.parameters([False, True]) + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), + axis_types=(AxisType.Auto, AxisType.Auto)) + def test_smap_nested_psum(self, jit, mesh): + @jax.smap(axis_name='x', in_axes=0, out_axes=0) + def f(x): + x = jnp.sin(x) + + @jax.smap(axis_name='y', in_axes=0, out_axes=None) + def g(x): + self.assertEqual(jax.typeof(x).vma, {'x', 'y'}) + x = jax.lax.psum(x, 'y') + self.assertEqual(jax.typeof(x).vma, {'x'}) + return x + + x = g(x) + self.assertEqual(jax.typeof(x).vma, {'x'}) + return x + + if jit: + f = jax.jit(f) + + x = jnp.arange(4.) + f(x) # asserts in f + + @jtu.with_explicit_mesh((2,), 'x') + def test_linalg_inv(self, mesh): + key = jax.random.key(123) + arr = jax.random.uniform(key, shape=(4,5,5), out_sharding=P('x')) + + @shard_map(out_specs=P('x')) + def f(x): + return jax.lax.map(jnp.linalg.inv, x) + + f(arr) # doesn't crash + + @jtu.with_explicit_mesh((2,), 'x') + def test_mutable_array_arg_basic(self, mesh): + x_ref = core.new_ref(jnp.zeros(4, 'float32', out_sharding=P('x'))) + + @jax.jit + @shard_map(out_specs=None) + def f(x_ref): + x_ref[...] += jax.lax.axis_index('x').astype('float32') + + f(x_ref) + + self.assertAllClose(x_ref[...], jnp.array([0., 0., 1., 1.]), + check_dtypes=False) + + @jtu.with_explicit_mesh((2,), 'x') + def test_mutable_array_internal_basic(self, mesh): + x = jnp.arange(4, dtype='float32', out_sharding=P('x')) + + @jax.jit + @shard_map(out_specs=P('x')) + def f(x): + x_ref = core.new_ref(jnp.zeros_like(x)) + x_ref[...] = x + return x_ref[...] + + y = f(x) + + self.assertAllClose(y, x, check_dtypes=False) + + def test_random_beta_vma(self): + mesh = jtu.create_mesh((2,), 'dp') + + rng = jax.random.key(42) + f = shard_map( + lambda x, y, z: jax.random.beta(jax.lax.pcast(x, ('dp',), to='varying'), + y, z), + mesh=mesh, in_specs=(P(), P('dp'), P('dp')), out_specs=P('dp')) + res = f(rng, jnp.ones((64, 1)), jnp.ones((64, 1))) + # explicit key resuse. + a, b = res.reshape(2, 32) + self.assertAllClose(a, b) # Key reuse. + + # Also works without key-reuse: + rng = jax.random.key(42) + f = shard_map(lambda x, y, z: jax.random.beta(x[0], y, z), mesh=mesh, + in_specs=(P('dp'), P('dp'), P('dp')), out_specs=P('dp')) + f(jax.random.split(rng, 2), jnp.ones((64, 1)), jnp.ones((64, 1))) # doesn't crash + + rng = jax.random.key(42) + f = shard_map(lambda x, y, z: jax.random.beta(x[0], y, z), mesh=mesh, + in_specs=(P('dp'), P(), P()), out_specs=P('dp')) + f(jax.random.split(rng, 2), jnp.ones((64, 1)), jnp.ones((64, 1))) # doesn't crash + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_pcast_to_unreduced(self, mesh): + arr1 = jax.device_put(np.arange(8).reshape(4, 2), P('x', 'y')) + arr2 = jax.device_put(np.arange(8).reshape(2, 4), P('y', None)) + + @jax.jit + @jax.shard_map(out_specs=P('x', unreduced={'y'})) + def f(x, y): + z = jnp.dot(x, y) + return jax.lax.pcast(z, 'y', to='unreduced') + f(arr1, arr2) # doesn't crash + + @jax.jit + @jax.shard_map(out_specs=P('x', unreduced={'y'})) + def f(x, y): + z = jnp.dot(x, y) + a = jax.lax.pcast(z, 'y', to='unreduced') + return jax.lax.pcast(a, ('x', 'y'), to='reduced') + + with self.assertRaisesRegex( + ValueError, "jax.lax.pcast can only accept axis_name which"): + f(arr1, arr2) + + @parameterized.named_parameters( + ('1', P('x'), {'x'}, P(None, 'y')), + ('2', P(None, 'y'), {'y'}, P('x', None)) + ) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_partial_manual_explicit_shmap(self, out_spec, axis_name, aval_spec, + mesh): + @jax.shard_map(out_specs=out_spec, axis_names=axis_name) + def f(x): + self.assertEqual(jax.typeof(x).sharding.spec, aval_spec) + return x * 2 + + arr = jax.device_put(np.arange(16).reshape(8, 2), P('x', 'y')) + out = f(arr) + out2 = jax.jit(f)(arr) + self.assertEqual(out.sharding, out2.sharding) + + @jtu.with_explicit_mesh((1, 2, 2), ('x', 'y', 'z')) + def test_axis_index_explicit_mesh_eager_shmap(self, mesh): + def f(): + return jnp.array([jax.lax.axis_index(n) for n in mesh.axis_names]) + + jax.shard_map(f, out_specs=P(mesh.axis_names))() # doesn't crash + jax.jit(jax.shard_map(f, out_specs=P(mesh.axis_names)))() # doesn't crash + + jax.shard_map(f, out_specs=P(), check_vma=False)() # doesn't crash + jax.jit(jax.shard_map(f, out_specs=P(), check_vma=False))() # doesn't crash + + @config.remove_size_one_mesh_axis_from_type(True) + @jtu.with_explicit_mesh((2, 1), ('x', 'y')) + def test_remove_one_sized_mesh_axis_from_vma(self, mesh): + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, P('x', 'y')) + + @jax.jit + @jax.shard_map(in_specs=P('x', 'y'), out_specs=P()) + def f(x): + self.assertEqual(x.aval.vma, {'x'}) + out = jax.lax.psum(x, 'x') + self.assertEqual(out.aval.vma, frozenset()) + return out + + out = f(arr) + self.assertEqual(out.shape, (4, 2)) + self.assertEqual(out.sharding, NamedSharding(mesh, P(None, None))) + self.assertArraysEqual(out, np_inp[:4, :] + np_inp[4:, :]) + + @jtu.with_explicit_mesh((2,), ('x',)) + def test_varargs_error(self, mesh): + @jax.shard_map(in_specs=jax.P('x'), out_specs=()) + def f(*foos, **kwargs): + return () + + with self.assertRaises(ValueError): # not AssertionError + f(jnp.arange(3.), jnp.arange(3.), jnp.arange(3.)) + + @jtu.with_explicit_mesh((2,), 'x') + def test_reduced_vary_automatically_inserted(self, mesh): + arr1 = jax.device_put(np.arange(8.).reshape(4, 2), P('x', None)) + arr2 = jax.device_put(np.arange(12.).reshape(2, 6), + P(None, None, reduced={'x'})) + + @jax.jit + @jax.shard_map(in_specs=(P('x', None), P(None, None, reduced={'x'})), + out_specs=P('x', None)) + def f(x, y): + return jnp.dot(x, y) + + out = f(arr1, arr2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + self.assertArraysEqual(out, arr1 @ arr2) + + def g(x, y): + return f(x, y).sum() + + out1, out2 = jax.jit(jax.grad(g, argnums=(0, 1)))(arr1, arr2) + self.assertEqual(out1.sharding, arr1.sharding) + self.assertEqual(out2.sharding, + NamedSharding(mesh, P(None, None, unreduced={'x'}))) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_reduced_vary_automatically_inserted_psum(self, mesh): + arr1 = jax.device_put(np.arange(8.).reshape(4, 2), P('x', 'y')) + arr2 = jax.device_put(np.arange(12.).reshape(2, 6), + P('y', None, reduced={'x'})) + + @jax.jit + @jax.shard_map(in_specs=(P('x', 'y'), P('y', None, reduced={'x'})), + out_specs=P('x', None)) + def f(x, y): + self.assertEqual(x.vma, {'x', 'y'}) + self.assertEqual(y.vma, {'y'}) + self.assertEqual(y.aval.sharding.spec.reduced, {'x'}) + z = jnp.dot(x, y) + self.assertEqual(z.vma, {'x', 'y'}) + return jax.lax.psum(z, axis_name='y') + + out = f(arr1, arr2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + self.assertArraysEqual(out, jnp.dot(arr1, arr2, out_sharding=P('x'))) + + def g(x, y): + return f(x, y).sum() + + out1, out2 = jax.jit(jax.grad(g, argnums=(0, 1)))(arr1, arr2) + self.assertEqual(out1.sharding, arr1.sharding) + self.assertEqual(out2.sharding, + NamedSharding(mesh, P('y', None, unreduced={'x'}))) + + @config.numpy_dtype_promotion('standard') + @jtu.with_explicit_mesh((2,), 'x') + def test_astype_reduced_fwd_unreduced_bwd_shmap(self, mesh): + inputs = jax.device_put(np.ones((32, 64), dtype=jnp.bfloat16), P('x', None)) + params = jax.device_put(np.ones((64, 64), dtype=jnp.float32), + P(None, None, reduced={'x'})) + targets = jax.device_put(np.ones((32, 64), dtype=jnp.bfloat16), P('x', None)) + + @jax.shard_map(in_specs=(P('x', None), P(None, None, reduced={'x'})), + out_specs=P('x', None)) + def dot(inputs, params): + self.assertEqual(params.aval.sharding.spec.reduced, {'x'}) + params = params.astype(jnp.bfloat16) + self.assertEqual(params.aval.sharding.spec.reduced, {'x'}) + return jnp.dot(inputs.astype(jnp.bfloat16), params) + + @jax.jit + def loss_fn(inputs, params, targets): + out = dot(inputs, params) + return jnp.mean(jnp.sum((out - targets) ** 2, axis=-1)) + + jax.jit(jax.grad(loss_fn, argnums=1))(inputs, params, targets) # doesn't crash + + @jtu.with_explicit_mesh((2,), 'x') + def test_transpose_unreduced_shmap(self, mesh): + arr1 = jax.device_put(np.arange(8.).reshape(2, 4), P(reduced={'x'})) + arr2 = jax.device_put(np.arange(12.).reshape(2, 6), P(None, 'x')) + + @jax.shard_map(out_specs=P(None, 'x')) + def f(x, y): + x_ = x.T + return jnp.dot(x_, y) + + @jax.jit + def g(x, y): + return f(x, y).sum() + + out = g(arr1, arr2) + self.assertEqual(out.sharding, NamedSharding(mesh, P())) + self.assertArraysEqual(out, (arr1.T @ arr2).sum()) + + out1, out2 = jax.jit(jax.grad(g, argnums=(0, 1)))(arr1, arr2) + self.assertEqual(out1.sharding, + NamedSharding(mesh, P(None, None, unreduced={'x'}))) + self.assertEqual(out2.sharding, arr2.sharding) + + @parameterized.named_parameters( + ('mul', jax.lax.mul), + ('add', jax.lax.add), + ) + @jtu.with_explicit_mesh((2,), 'x') + def test_one_input_sharded_another_reduced_shmap(self, func, mesh): + np1 = np.arange(16.) + np2 = np.arange(8.) + arr1 = jax.device_put(np1, P('x')) + arr2 = jax.device_put(np2, P(None, reduced={'x'})) + + @jax.jit + @jax.shard_map(out_specs=P()) + def f(x, y): + z = func(x, y) + return jax.lax.psum(z, 'x') + + out = f(arr1, arr2) + self.assertEqual(out.sharding, NamedSharding(mesh, P(None))) + with jax.set_mesh(empty_concrete_mesh): + ex_out = np.sum([func(s.data, np2) for s in arr1.addressable_shards], + axis=0) + self.assertArraysEqual(out, ex_out) + + @jax.jit + def g(x, y): + return f(x, y).sum() + + out1, out2 = jax.jit(jax.grad(g, argnums=(0, 1)))(arr1, arr2) + self.assertEqual(out1.sharding, NamedSharding(mesh, P('x'))) + self.assertEqual(out2.sharding, + NamedSharding(mesh, P(None, unreduced={'x'}))) + + arr3 = jax.device_put(np2, P(None)) + ex_out1, ex_out2 = jax.jit(jax.grad(g, argnums=(0, 1)))(arr1, arr3) + self.assertArraysEqual(out1, ex_out1) + self.assertArraysEqual(jax.reshard(out2, P()), ex_out2) + + @jtu.with_explicit_mesh((2,), 'x') + def test_reduced_pcast_fwd_unreduced_bwd(self, mesh): + np1 = np.arange(8.) + arr = jax.device_put(np1, P(None, reduced={'x'})) + + @jax.jit + @jax.shard_map(out_specs=P('x')) + def f(x): + return jax.lax.pcast(x, 'x', to='varying') + + out = f(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + self.assertArraysEqual(out, np.concat([np1, np1], axis=0)) + + @jax.jit + def g(x): + return f(x).sum() + + out = jax.jit(jax.grad(g))(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P(None, unreduced={'x'}))) + + arr2 = jax.device_put(np1, P(None)) + ex_out = jax.jit(jax.grad(g))(arr2) + self.assertArraysEqual(jax.reshard(out, P()), ex_out) + + @parameterized.named_parameters( + ('mul', jax.lax.mul), + ('add', jax.lax.add), + ) + @jtu.with_explicit_mesh((2,), 'x') + def test_one_input_sharded_another_reduced_shmap_no_psum(self, func, mesh): + np1 = np.arange(16.) + np2 = np.arange(8.) + arr1 = jax.device_put(np1, P('x')) + arr2 = jax.device_put(np2, P(None, reduced={'x'})) + + @jax.jit + @jax.shard_map(out_specs=P('x')) + def f(x, y): + z = func(x, y) + return z + + out = f(arr1, arr2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + with jax.set_mesh(empty_concrete_mesh): + ex_out = [func(s.data, np2) for s in arr1.addressable_shards] + for s, e in zip(out.addressable_shards, ex_out): + self.assertArraysEqual(s.data, e) + + @jax.jit + def g(x, y): + return f(x, y).sum() + + out1, out2 = jax.jit(jax.grad(g, argnums=(0, 1)))(arr1, arr2) + self.assertEqual(out1.sharding, NamedSharding(mesh, P('x'))) + self.assertEqual(out2.sharding, + NamedSharding(mesh, P(None, unreduced={'x'}))) + + arr3 = jax.device_put(np2, P(None)) + ex_out1, ex_out2 = jax.jit(jax.grad(g, argnums=(0, 1)))(arr1, arr3) + self.assertArraysEqual(out1, ex_out1) + self.assertArraysEqual(jax.reshard(out2, P()), ex_out2) + + @jtu.with_explicit_mesh((2,), 'x') + def test_split_with_unused_result_in_shardmap(self, mesh): + arr = jax.device_put(jnp.ones(8), P('x')) + + @jax.shard_map(in_specs=P('x'), out_specs=P('x')) + def f(x): + a, _ = jnp.split(x, 2, axis=0) # Important that one result is unused. + return a + + def g(x): + a = f(x) + b = 0.1 * a.mean(keepdims=True) + return b.squeeze(0) + + jax.jit(jax.grad(g))(arr) # doesn't crash + + @jtu.with_explicit_mesh((2,), 'x') + def test_shmap_primal_type_match_ct_type(self, mesh): + arr = jax.device_put(np.arange(8.), P('x')) + + @jax.jit + @jax.shard_map(in_specs=P(), out_specs=P('x')) + def f(x): + return x * 2 + + out = f(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + + out_g = jax.jit(jax.grad(lambda x: f(x).sum()))(arr) + self.assertEqual(out_g.sharding, NamedSharding(mesh, P('x'))) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_mix_manual_explicit_partial(self, mesh): + arr = jax.device_put(np.arange(16).reshape(8, 2), P(('x', 'y'), None)) + + @jax.jit + @jax.shard_map(out_specs=P('x'), axis_names={'x'}) + def f(x): + self.assertEqual(x.shape, (4, 2)) + self.assertEqual(x.aval.sharding.spec, P('y', None)) + self.assertEqual(x.aval.vma, {'x'}) + return x * 2 + + out = f(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P(('x', 'y'), None))) + self.assertArraysEqual(out, arr * 2) + + @jtu.with_explicit_mesh((2,), ('x',)) + def test_zero_cotangent_sharding(self, mesh): + @jax.custom_vjp + def inner(x): + return x + + def inner_fwd(x): + return x, None + + def inner_bwd(_, g): + return None, + + inner.defvjp(inner_fwd, inner_bwd) + + @jax.shard_map(out_specs=jax.P('x'), check_vma=False) + def f(x): + return inner(x) + + x = jax.device_put(jnp.arange(8.), jax.P('x')) + jax.grad(lambda x: f(x).sum())(x) # don't crash + + @jtu.with_explicit_mesh((2,), 'x') + def test_vmap_shmap_psum(self, mesh): + arr = jnp.arange(16).reshape(2, 8) + + @jax.shard_map(in_specs=P("x"), out_specs=P(None)) + def f(x): + return jax.lax.psum(x, axis_name='x') + + f(arr) # doesn't crash + jax.vmap(f)(arr) # doesn't crash + + jax.jit(f)(arr) # doesn't crash + jax.jit(jax.vmap(f))(arr) # doesn't crash + + @jtu.with_explicit_mesh((2,), 'x') + def test_subclass_partition_spec_error_message(self, mesh): + class MyP(jax.P): + pass + + @jax.shard_map(in_specs=(jax.P("x"), MyP("x"),), out_specs=()) + def f(x, y): + return () + + with self.assertRaisesRegex(ValueError, 'not evenly divisible'): + f(jnp.arange(8.), jnp.arange(9.)) + + +class FunSpec(NamedTuple): + name: str + num_inputs: int + fun: Callable + out_rep: Callable + valid_types: Callable | None = None -class FunSpec(NamedTuple): - name: str - num_inputs: int - fun: Callable - out_rep: Callable - valid_types: Callable | None = None - fun_specs = [ FunSpec('id', 1, lambda x: x, lambda r: r), FunSpec('flip', 2, lambda x, y: (y, x), lambda r_x, r_y: (r_y, r_x)), @@ -2937,7 +5102,8 @@ def make_mesh(mesh_shape): def test_eager_against_ref(self, fun, mesh, _, in_specs, out_specs, args, ref): mesh = self.make_mesh(mesh) args = map(jnp.array, args) - out = shard_map(fun, mesh, in_specs, out_specs)(*args) + out = shard_map(fun, mesh=mesh, in_specs=in_specs, + out_specs=out_specs)(*args) expected = ref(fun, mesh, in_specs, out_specs)(*args) self.assertAllClose(expected, out, check_dtypes=False) @@ -2946,7 +5112,8 @@ def test_eager_against_ref(self, fun, mesh, _, in_specs, out_specs, args, ref): def test_jit_against_ref(self, fun, mesh, _, in_specs, out_specs, args, ref): mesh = self.make_mesh(mesh) args = map(jnp.array, args) - out = jax.jit(shard_map(fun, mesh, in_specs, out_specs))(*args) + out = jax.jit(shard_map(fun, mesh=mesh, in_specs=in_specs, + out_specs=out_specs))(*args) expected = ref(fun, mesh, in_specs, out_specs)(*args) self.assertAllClose(expected, out, check_dtypes=False) @@ -2956,10 +5123,12 @@ def test_jit_against_ref(self, fun, mesh, _, in_specs, out_specs, args, ref): for check_rep in [True, False] ) @jax.default_matmul_precision("float32") + @jtu.skip_on_flag("jax_skip_slow_tests", True) def test_grads(self, fun, mesh, jit, in_specs, out_specs, args, _, check_rep): mesh = self.make_mesh(mesh) args = map(jnp.array, args) - f = shard_map(fun, mesh, in_specs, out_specs, check_rep=check_rep) + f = shard_map(fun, mesh=mesh, in_specs=in_specs, + out_specs=out_specs, check_vma=check_rep) if jit: f = jax.jit(f) jtu.check_grads(f, args, order=2, atol=1e-2, rtol=1e-2) @@ -2967,6 +5136,7 @@ def test_grads(self, fun, mesh, jit, in_specs, out_specs, args, _, check_rep): @parameterized.parameters( sample(jtu.NUM_GENERATED_CASES.value, sample_shmap)) @jax.default_matmul_precision("float32") + @jtu.skip_on_flag("jax_skip_slow_tests", True) def test_grads_closure(self, fun, mesh, jit, in_specs, out_specs, args, _): mesh = self.make_mesh(mesh) no_sharding = [all(elt is None for elt in spec) for spec in in_specs] @@ -2990,7 +5160,7 @@ def test_vmap(self, bdims, fun, mesh, jit, in_specs, out_specs, args, ref): mesh = self.make_mesh(mesh) args = map(jnp.array, args) - f = shard_map(fun, mesh, in_specs, out_specs) + f = shard_map(fun, mesh=mesh, in_specs=in_specs, out_specs=out_specs) if jit: f = jax.jit(f) ans = jax.vmap(f, bdims)(*args) @@ -3042,7 +5212,7 @@ def g(*args): else: slices = map(jnp.stack, zip(*expected_slices)) expected = jax.tree.unflatten(treedef, slices) - tol = 1e-2 if jtu.test_device_matches(['tpu']) else None + tol = 1e-2 if jtu.test_device_matches(['gpu', 'tpu']) else None self.assertAllClose(ans, expected, check_dtypes=False, atol=tol, rtol=tol) @jtu.pytest_mark_if_available('multiaccelerator') @@ -3083,14 +5253,15 @@ def f(x): infer_sharding_from_operands=infer_sharding_from_operands, partition=partition, propagate_user_sharding=propagate_user_sharding, + sharding_rule='i -> i', ) @jax.jit def fwd(a): c = shard_map( f, - mesh, - check_rep=False, + mesh=mesh, + check_vma=False, in_specs=(P('z', ('x', 'y')),), out_specs=P('z', ('x', 'y')))(a) return c @@ -3107,17 +5278,103 @@ def g(x): @jax.jit def f(x): x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P(('i', 'j')))) - re = shard_map(g, mesh, in_specs=P('i'), out_specs=P('i'), - check_rep=False, auto=frozenset({'j'}))(x) + re = shard_map(g, mesh=mesh, in_specs=P('i'), out_specs=P('i'), + check_vma=False, axis_names={'i'})(x) re = jax.lax.with_sharding_constraint(re, NamedSharding(mesh, P(('i', 'j')))) return re self.assertAllClose(f(jnp.arange(8.)), jnp.array([1., 5., 9., 13.])) -@jtu.with_config(jax_use_shardy_partitioner=True) -# TODO(phawkins): enable this test unconditionally once shardy is the default. -@unittest.skipIf(sdy is None, "shardy is not enabled") +def smap_ref(f, in_axes, out_axes, axis_name, axis_size): + del axis_name # no collectives + def smapped(*args): + split_args = zip(*[split_arg(x, d, axis_size) for x, d in zip(args, in_axes)]) + split_result = [f(*xs) for xs in split_args] + return concat_result(split_result, out_axes) + return smapped + +def split_arg(x, d, axis_size): + if d is None: + x = np.tile(x, [axis_size] + [1] * (x.ndim - 1)) + return np.split(x, axis_size, d or 0) + +def concat_result(results, out_axes): + if not isinstance(results[0], (list, tuple)): + return results[0] if out_axes is None else np.concatenate(results, out_axes) + return [res[0] if d is None else np.concatenate(res, d) + for res, d in zip(zip(*results), out_axes)] + +def sample_smap() -> Chooser: + spec = yield fun_specs + mesh_shape = yield mesh_shapes + axis_names = ('i', 'j', 'k', 'l')[:len(mesh_shape)] + mesh = SimpleNamespace(shape=dict(zip(axis_names, mesh_shape)), + axis_names=axis_names) + axis_name = yield axis_names + body_in_types = yield (tys for tys in it.product(input_shapes, repeat=spec.num_inputs) + if not spec.valid_types or spec.valid_types(*tys)) + in_axes = yield from sample_in_axes(body_in_types) + out_rep = spec.out_rep(*[ax is None for ax in in_axes]) + body_out_type = jax.eval_shape(spec.fun, *body_in_types) + out_axes = yield from sample_out_axes(out_rep, body_out_type) + in_str = '(' + ','.join(jax.core.ShapedArray(t.shape, t.dtype).str_short() + for t in body_in_types) + ')' + name = f'{spec.name}_{mesh.shape}_{in_axes}_{out_axes}_{axis_name}_{in_str}' + in_types = [ty.update(shape=dilate_axis(ty.shape, d, mesh.shape[axis_name])) + for ty, d in zip(body_in_types, in_axes)] + args = [np.arange(ty.size, dtype=ty.dtype).reshape(ty.shape) / ty.size + for ty in in_types] + return name, spec, mesh.shape, in_axes, out_axes, axis_name, args + +def sample_in_axes(body_in_types) -> Chooser: + in_axes = [] + for ty in body_in_types: + in_axes.append((yield [None, *range(ty.ndim)])) + return tuple(in_axes) + +def sample_out_axes(out_rep, body_out_type) -> Chooser: + if not isinstance(body_out_type, (list, tuple)): + out_axes = yield [None] * out_rep + list(range(body_out_type.ndim)) + else: + out_axes_ = [] + for ty, r in zip(body_out_type, out_rep): + out_axes_.append((yield [None] * r + list(range(ty.ndim)))) + out_axes = tuple(out_axes_) + return out_axes + +def dilate_axis(shape: tuple[int, ...], i: int | None, size: int) -> tuple[int, ...]: + if i is None: + return shape + shp = list(shape) + shp[i] *= size + return tuple(shp) + +class SmapSystematicTest(jtu.JaxTestCase): + + @staticmethod + def make_mesh(mesh_shape): + return jtu.create_mesh(tuple(mesh_shape.values()), tuple(mesh_shape)) + + @parameterized.parameters( + sample(jtu.NUM_GENERATED_CASES.value, sample_smap)) + def test_against_ref(self, fun_spec, mesh_shape, in_axes, out_axes, axis_name, args): + fun = fun_spec.fun + mesh = self.make_mesh(mesh_shape) + args = map(jnp.array, args) + + with jax.set_mesh(mesh): + fun_ = jax.smap(fun, in_axes=in_axes, out_axes=out_axes, + axis_name=axis_name) + out = jax.jit(fun_)(*args) + + fun_ref = smap_ref(fun, in_axes=in_axes, out_axes=out_axes, axis_name=axis_name, + axis_size=mesh_shape[axis_name]) + expected = fun_ref(*args) + + self.assertAllClose(out, expected, check_dtypes=False) + + class SdyIntegrationTest(jtu.JaxTestCase): # Verify we can lower to a `ManualComputationOp`. @@ -3133,7 +5390,7 @@ def test_shardy_collective_permute(self): shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x', None) ) def fwd(a): - axis_size = lax.psum(1, 'x') + axis_size = lax.axis_size('x') perm = [(j, (j + 1) % axis_size) for j in range(axis_size)] return lax.ppermute(a, 'x', perm=perm) diff --git a/tests/source_info_test.py b/tests/source_info_test.py index 0f876de1c20f..d26a318bd5c2 100644 --- a/tests/source_info_test.py +++ b/tests/source_info_test.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from functools import partial import inspect from absl.testing import absltest @@ -32,7 +31,7 @@ def test_inline_jit_location_uses_callee_location(self): # 'f' should be inlined into both 'g' and 'h', using the source line # information of the call site. In particular, the source line information # of 'h' should not refer to the source information of 'g'. - @partial(jax.jit, inline=True) + @jax.jit(inline=True) def f(x): return lax.add(x, 3) def g(x): return lax.add(f(x), 4) @@ -44,7 +43,7 @@ def h(x): return lax.add(f(x), 5) fn_endline = fn_startline + len(lines) jaxpr = jax.make_jaxpr(fn)(2) for eqn in jaxpr.eqns: - frame = source_info_util.user_frame(eqn.source_info) + frame = source_info_util.user_frame(eqn.source_info.traceback) assert frame is not None, eqn self.assertLessEqual(fn_startline, frame.start_line) self.assertLessEqual(frame.end_line, fn_endline) diff --git a/tests/source_mapper_test.py b/tests/source_mapper_test.py index ffb1d6252f77..64c489560bed 100644 --- a/tests/source_mapper_test.py +++ b/tests/source_mapper_test.py @@ -18,6 +18,33 @@ from jax import numpy as jnp from jax._src import test_util as jtu from jax.experimental import source_mapper +from jax.experimental.source_mapper import hlo + + +HLO_EXAMPLE = r"""HloModule m, entry_computation_layout={()->pred[]} + +FileNames +1 "" +2 "experimental/module.py" +3 "yet/another/test.py" + +FunctionNames +1 "main" +2 "method" + +FileLocations +1 {file_name_id=1 function_name_id=1 line=153 end_line=153 column=2 end_column=31} +2 {file_name_id=3 function_name_id=2 line=35 end_line=35 column=2 end_column=24} +3 {file_name_id=2 function_name_id=2 line=83 end_line=83 column=2 end_column=15} + +StackFrames +1 {file_location_id=1 parent_frame_id=1} +2 {file_location_id=2 parent_frame_id=2} + + +ENTRY %constant_pred () -> pred[] { + ROOT %constant = pred[] constant(true), metadata={op_type="const" op_name="opname" stack_frame_id=1} +}""" class SourceMapperTest(jtu.JaxTestCase): @@ -92,5 +119,28 @@ def jax_fn(x, y): # self.assertEqual(src_col, expected_col) +class HLOParserTest(jtu.JaxTestCase): + + def test_hlo_parser(self): + source_map = hlo._parse_hlo_new_format(HLO_EXAMPLE.split("\n")) + print(source_map) + self.assertLen(source_map.sources, 1) + self.assertEqual(source_map.sources[0], "") + mappings = source_map.mappings + constant_line_idx = -1 + for i, line in enumerate(HLO_EXAMPLE.split("\n")): + if r"ROOT %constant" in line: + constant_line_idx = i + break + line_mappings = mappings[constant_line_idx] + gen_col, file_idx, src_line, _ = line_mappings[0] + + self.assertEqual( + file_idx, 0 + ) # "" is the first and only used source + self.assertEqual(src_line, 152) # 153 - 1 + self.assertEqual(gen_col, 2) + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/sparse_bcoo_bcsr_test.py b/tests/sparse_bcoo_bcsr_test.py index e839bacbe5fc..a0798957e4f1 100644 --- a/tests/sparse_bcoo_bcsr_test.py +++ b/tests/sparse_bcoo_bcsr_test.py @@ -35,8 +35,7 @@ from jax.experimental.sparse import test_util as sptu from jax.experimental.sparse import util as sparse_util import jax.numpy as jnp -import jax.random -from jax.util import split_list +from jax._src.util import split_list import numpy as np jax.config.parse_flags_with_absl() @@ -603,7 +602,7 @@ def test_bcoo_batched_matmat_default_lowering( # with self.gpu_matmul_warning_context( # "bcoo_dot_general GPU lowering currently does not support this batch-mode computation.*"): matmat_default_lowering_fallback = sp_matmat(lhs_bcoo, rhs) - self.assertArraysEqual(matmat_expected, matmat_default_lowering_fallback) + self.assertArraysAllClose(matmat_expected, matmat_default_lowering_fallback) @jtu.run_on_devices("gpu") def test_bcoo_dot_general_oob_and_unsorted_indices_cusparse(self): @@ -974,6 +973,7 @@ def test_bcoo_spdot_general_nse(self, lhs_shape, rhs_shape): self.assertEqual(out.nse, expected_nse) @jtu.ignore_warning(message="bcoo_dot_general cusparse/hipsparse lowering not available") + @jtu.ignore_warning(category=sparse.CuSparseEfficiencyWarning) def test_bcoo_spdot_general_ad_bug(self): # Regression test for https://github.com/jax-ml/jax/issues/10163 A_indices = jnp.array([[0, 1], [0, 2], [1, 1], [1, 2], [1, 0]]) diff --git a/tests/sparse_nm_test.py b/tests/sparse_nm_test.py deleted file mode 100644 index 9ecf30eb6229..000000000000 --- a/tests/sparse_nm_test.py +++ /dev/null @@ -1,209 +0,0 @@ -# Copyright 2024 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math - -import numpy as np -from absl.testing import absltest -from absl.testing import parameterized - -import jax -import jax.numpy as jnp -from jax import dtypes -from jax._src import config -from jax._src import test_util as jtu -from jax.experimental.sparse import nm - -jax.config.parse_flags_with_absl() - - -class SpmmTest(jtu.JaxTestCase): - def setUp(self): - if not jtu.test_device_matches(["gpu"]): - self.skipTest("Only works on GPU") - if (jtu.test_device_matches(["cuda"]) and - not jtu.is_cuda_compute_capability_at_least("8.0")): - self.skipTest("Only works on GPUs with capability >= sm80") - super().setUp() - - # ----- Test different input shapes - @parameterized.product( - tile_m=(32, 128), - tile_n=(32, 128), - tile_k=(32, 128), - batch=(None, 5), - sparse_idx=(0, 1), - ) - @jtu.run_on_devices("gpu") - def test_shapes(self, tile_m, tile_n, tile_k, batch, sparse_idx): - # Build keyword arguments - kwargs = { - "dimension_numbers": (((1,), (1,)), (tuple(), tuple())), - "sparse_operand_idx": sparse_idx, - } - if batch: - kwargs["dimension_numbers"] = (((2,), (2,)), ((0,), (0,))) - - # Build input data - batch_dims = (batch,) if batch else tuple() - lhs = ( - (np.arange((batch or 1) * tile_m * tile_k) % 11) - .astype(dtypes.bfloat16) - .reshape(batch_dims + (tile_m, tile_k)) - ) - rhs = ( - (np.arange((batch or 1) * tile_n * tile_k) % 13) - .astype(dtypes.bfloat16) - .reshape(batch_dims + (tile_n, tile_k)) - ) - - # Build sparsity mask and metadata - sp = [lhs, rhs][sparse_idx] - mask = np.tile([True, False], math.prod(sp.shape) // 2).reshape(sp.shape) - sparse = sp[mask].reshape(sp.shape[:-1] + (sp.shape[-1] // 2,)) - meta = nm.nm_pack(mask) - - # Calculate sparse and dense dots - if sparse_idx == 0: - dot_sparse = nm.nm_spmm(sparse, rhs, meta, **kwargs) - dot_dense = jnp.einsum("...mk,...nk->...mn", (lhs * mask), rhs) - else: - dot_sparse = nm.nm_spmm(lhs, sparse, meta, **kwargs) - dot_dense = jnp.einsum("...mk,...nk->...mn", lhs, (rhs * mask)) - - # Verify the result - jtu.check_eq(dot_sparse, dot_dense.astype(dtypes.bfloat16)) - - # ----- Test different input types - @parameterized.product( - lhs_type=[jnp.int8, jnp.int16, jnp.float16, jnp.bfloat16], - rhs_type=[jnp.bfloat16], - output_type=[jnp.bfloat16, jnp.float32], - ) - @jtu.run_on_devices("gpu") - def test_types(self, lhs_type, rhs_type, output_type): - tile_m, tile_n, tile_k = 64, 32, 128 - - # Build input data - lhs = ( - (np.arange(tile_m * tile_k) % 17) - .astype(lhs_type) - .reshape((tile_m, tile_k)) - ) - rhs = ( - (np.arange(tile_k * tile_n) % 19) - .astype(rhs_type) - .reshape((tile_k, tile_n)) - ) - - # Build sparsity mask and metadata - mask = np.tile([True, False], tile_m * tile_k // 2).reshape(lhs.shape) - sparse = lhs[mask].reshape(tile_m, tile_k // 2) - meta = nm.nm_pack(mask) - - # Calculate sparse and dense dots - dot_sparse = nm.nm_spmm(sparse, rhs, meta, output_dtype=output_type) - dot_dense = (lhs * mask) @ rhs - - # Verify the result - jtu.check_close(dot_sparse, dot_dense.astype(output_type), rtol=0.01) - - # ----- Test validation - @jtu.run_on_devices("gpu") - def test_validate_nm_pack(self): - with self.assertRaisesRegex(TypeError, "Mask should be bool"): - nm.nm_pack(jnp.zeros(16, jnp.int8)) - with self.assertRaisesRegex( - TypeError, "Inner dimension size should be divisible by 16" - ): - nm.nm_pack(jnp.array([False] * 8)) - - @jtu.run_on_devices("gpu") - def test_validate_nm_spmm(self): - batch, tile_m, tile_n, tile_k = 2, 64, 32, 128 - lhs = jnp.zeros((batch, tile_m, tile_k // 2), dtype=jnp.bfloat16) - rhs = jnp.zeros((batch, tile_k, tile_n), dtype=jnp.bfloat16) - meta = jnp.zeros((batch, tile_m, tile_k // 16), dtype=jnp.uint16) - - if config.enable_x64.value: - with self.assertRaisesRegex(TypeError, "Unsupported lhs input type"): - nm.nm_spmm(jnp.zeros(lhs.shape, dtype=jnp.int64), rhs, meta) - with self.assertRaisesRegex(TypeError, "Unsupported rhs input type"): - nm.nm_spmm(lhs, jnp.zeros(rhs.shape, dtype=jnp.int64), meta) - with self.assertRaisesRegex(TypeError, "Unsupported output type"): - nm.nm_spmm(lhs, rhs, meta, output_dtype=jnp.int64) - - # Check dimension numbers - nm_spmm_with_dnums = lambda c, b: nm.nm_spmm( - lhs, rhs, meta, dimension_numbers=(c, b) - ) - with self.assertRaisesRegex( - TypeError, "Only single contracting dimension is supported" - ): - nm_spmm_with_dnums(((0, 2), (0, 1)), (tuple(), tuple())) - with self.assertRaisesRegex( - TypeError, "Incorrect dimension numbers for lhs" - ): - nm_spmm_with_dnums(((2,), (1,)), ((2,), (0,))) - with self.assertRaisesRegex( - TypeError, "Incorrect dimension numbers for rhs" - ): - nm_spmm_with_dnums(((2,), (1,)), ((0,), (1,))) - with self.assertRaisesRegex( - TypeError, "Only single non-contracting dimension is supported" - ): - nm_spmm_with_dnums(((2,), (1,)), (tuple(), tuple())) - with self.assertRaisesRegex( - TypeError, "Batch dimension sizes do not match" - ): - nm.nm_spmm( - lhs, - rhs.reshape(1, tile_k, tile_n * batch), - meta, - dimension_numbers=(((2,), (1,)), ((0,), (0,))), - ) - - # Check metadata - nm_spmm_with_meta = lambda m: nm.nm_spmm( - lhs, rhs, m, dimension_numbers=(((2,), (1,)), ((0,), (0,))) - ) - with self.assertRaisesRegex(TypeError, "Metadata must be uint16"): - nm_spmm_with_meta(jnp.zeros(meta.shape, dtype=jnp.uint8)) - with self.assertRaisesRegex( - TypeError, "Metadata shape must match the operand shape" - ): - nm_spmm_with_meta(meta.reshape(1, batch * tile_m, tile_k // 16)) - with self.assertRaisesRegex( - TypeError, - "Metadata must be exactly 8 times less than the contracting dimension" - " for 2:4 structured sparsity", - ): - nm_spmm_with_meta(jnp.repeat(meta, 2, axis=-1)) - with self.assertRaisesRegex( - TypeError, "Contracting dimension must be the minor one" - ): - nm.nm_spmm(lhs, rhs, meta, dimension_numbers=(((1,), (1,)), ((0,), (0,)))) - with self.assertRaisesRegex( - TypeError, "Contracting dimension sizes should have 2:4 ratio" - ): - nm.nm_spmm( - lhs, - jnp.repeat(rhs, 2, axis=1), - meta, - dimension_numbers=(((2,), (1,)), ((0,), (0,))), - ) - - -if __name__ == "__main__": - absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/sparse_test.py b/tests/sparse_test.py index eb8d70be1f05..30240ff69c50 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -16,6 +16,8 @@ from functools import partial import itertools import math +import os +from pathlib import Path from absl.testing import absltest from absl.testing import parameterized @@ -31,17 +33,23 @@ from jax.experimental.sparse import util as sparse_util from jax.experimental.sparse import test_util as sptu from jax.experimental.sparse import _lowerings -from jax._src import xla_bridge -from jax._src.lib import gpu_sparse from jax import jit from jax import vmap from jax._src import test_util as jtu -from jax.interpreters import mlir import jax.numpy as jnp -from jax.util import split_list +from jax._src.util import split_list import numpy as np import scipy.sparse +def get_rocm_version(): + rocm_path = os.environ.get("ROCM_PATH", "/opt/rocm") + version_path = Path(rocm_path) / ".info" / "version" + if not version_path.exists(): + raise FileNotFoundError(f"Expected ROCm version file at {version_path}") + version_str = version_path.read_text().strip() + major, minor, *_ = version_str.split(".") + return int(major), int(minor) + jax.config.parse_flags_with_absl() all_dtypes = jtu.dtypes.integer + jtu.dtypes.floating + jtu.dtypes.complex @@ -130,6 +138,8 @@ def test_csr_fromdense_ad(self, shape, dtype): ) @jax.default_matmul_precision("float32") def test_csr_matmul_ad(self, shape, dtype, bshape): + if jtu.is_device_rocm(): + self.skipTest("test_csr_matmul_ad not supported on ROCm due to hipSPARSE issue") csr_matmul = sparse_csr._csr_matvec if len(bshape) == 1 else sparse_csr._csr_matmat tol = {np.float32: 2E-5, np.float64: 1E-12, np.complex64: 1E-5, np.complex128: 1E-12} @@ -208,6 +218,9 @@ def test_csr_fromdense(self, shape, dtype): transpose=[True, False], ) def test_csr_matvec(self, shape, dtype, transpose): + if jtu.is_device_rocm(): + self.skipTest("test_csr_matvec not supported on ROCm due to hipSPARSE issue") + op = lambda M: M.T if transpose else M v_rng = jtu.rand_default(self.rng()) @@ -228,6 +241,14 @@ def test_csr_matvec(self, shape, dtype, transpose): transpose=[True, False], ) def test_csr_matmat(self, shape, dtype, transpose): + if ( + jtu.is_device_rocm() and + get_rocm_version() < (6, 4) and + dtype in (jtu.dtypes.floating + jtu.dtypes.complex) + ): + # TODO: Remove this check when ROCm 6.4+ is the minimum supported version + self.skipTest("ROCm <6.4 bug: NaN propagation when beta==0 (fixed in ROCm 6.4.0)") + op = lambda M: M.T if transpose else M B_rng = jtu.rand_default(self.rng()) @@ -411,25 +432,6 @@ def test_coo_sorted_indices_gpu_lowerings(self): self.assertArraysEqual(matmat_expected, matmat_unsorted) self.assertArraysEqual(matmat_expected, matmat_unsorted_fallback) - @jtu.run_on_devices("gpu") - def test_gpu_translation_rule(self): - version = xla_bridge.get_backend().platform_version - if "rocm" not in version.split(): - cuda_version = None if version == "" else int( - version.split()[-1]) - if cuda_version is None or cuda_version < 11000: - self.assertFalse(gpu_sparse and gpu_sparse.cuda_is_supported) - self.assertNotIn(sparse.csr_todense_p, - mlir._platform_specific_lowerings["cuda"]) - else: - self.assertTrue(gpu_sparse and gpu_sparse.cuda_is_supported) - self.assertIn(sparse.csr_todense_p, - mlir._platform_specific_lowerings["cuda"]) - else: - self.assertTrue(gpu_sparse and gpu_sparse.rocm_is_supported) - self.assertIn(sparse.csr_todense_p, - mlir._platform_specific_lowerings["rocm"]) - @jtu.sample_product( shape=[(5, 8), (8, 5), (5, 5), (8, 8)], dtype=jtu.dtypes.floating + jtu.dtypes.complex, @@ -587,6 +589,10 @@ def test_coo_spmm(self, shape, dtype, transpose): ) @jtu.run_on_devices("gpu") def test_csr_spmv(self, shape, dtype, transpose): + if jtu.is_device_rocm(): + self.skipTest("test_csr_spmv not supported on ROCm due to hipSPARSE issue") + tol = {np.float32: 2E-5, np.float64: 2E-14} + rng_sparse = sptu.rand_sparse(self.rng()) rng_dense = jtu.rand_default(self.rng()) @@ -599,7 +605,7 @@ def test_csr_spmv(self, shape, dtype, transpose): data, indices.astype('int32'), indptr.astype('int32'), vec, transpose=transpose, shape=mat.shape) - self.assertArraysAllClose(actual, expected) + self.assertArraysAllClose(actual, expected, atol=tol, rtol=tol) @jtu.sample_product( shape=[(4, 5), (3, 4), (5, 4)], @@ -1036,6 +1042,8 @@ def test_transpose(self, shape, dtype, Obj): for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO])) @jax.default_matmul_precision("float32") def test_matmul(self, shape, dtype, Obj, bshape): + if jtu.is_device_rocm(): + self.skipTest("test_matmul not supported on ROCm due to hipSPARSE issue") rng = sptu.rand_sparse(self.rng(), post=jnp.array) rng_b = jtu.rand_default(self.rng()) M = rng(shape, dtype) @@ -1102,7 +1110,9 @@ def test_bcoo_to_bcsr_round_trip(self, shape, dtype, n_batch): _, bcoo_indices = sparse_bcoo._bcoo_fromdense(M, nse=nse, n_batch=n_batch, n_dense=n_dense) - bcoo_to_bcsr = partial(sparse_bcsr._bcoo_to_bcsr, shape=shape) + bcoo_to_bcsr = partial( + sparse_bcsr._bcoo_to_bcsr, shape=shape, index_dtype=bcoo_indices.dtype + ) args_maker_bcoo_to_bcsr = lambda: [bcoo_indices] self._CompileAndCheck(bcoo_to_bcsr, args_maker_bcoo_to_bcsr) @@ -1177,7 +1187,12 @@ def sparse_solve(data, indices, indptr, b): return sparse.linalg.spsolve(data, indices, indptr, b, tol, reorder) x = sparse_solve(data, indices, indptr, b) - self.assertAllClose(a @ x, b, rtol=1e-2, atol=1e-3) + self.assertAllClose( + jnp.matmul(a, x, precision=jax.lax.Precision.HIGHEST), + b, + rtol=1e-2, + atol=1e-3, + ) self._CompileAndCheck(sparse_solve, args_maker) @jtu.sample_product( diff --git a/tests/stack_test.py b/tests/stack_test.py index aa1a02793b1a..8ebfc3489ff5 100644 --- a/tests/stack_test.py +++ b/tests/stack_test.py @@ -16,7 +16,7 @@ import jax import jax.numpy as jnp -from jax._src.lax.stack import Stack +from jax._src.tpu.linalg.stack import Stack from jax._src import test_util as jtu diff --git a/tests/state_test.py b/tests/state_test.py index 60a7d8bc9f8a..6230732204c1 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -28,27 +28,23 @@ from jax import lax from jax._src import core from jax._src import config +from jax._src import dtypes from jax._src import linear_util as lu from jax._src.interpreters import partial_eval as pe from jax._src import test_util as jtu from jax._src.state import types as state_types from jax._src.util import tuple_insert import jax.numpy as jnp -from jax._src.lax.control_flow import for_loop -try: - import hypothesis as hp - import hypothesis.extra.numpy as hnp - import hypothesis.strategies as hps - CAN_USE_HYPOTHESIS = True -except (ModuleNotFoundError, ImportError): - CAN_USE_HYPOTHESIS = False +import hypothesis as hp +import hypothesis.extra.numpy as hnp +import hypothesis.strategies as hps from jax._src.state.discharge import (run_state, run_state_reference, discharge_state) from jax._src.state.primitives import (get_p, swap_p, addupdate_p, ref_addupdate, ref_get, ref_set, - ref_swap) + ref_swap, pin, unpin) from jax._src.state.types import (shaped_array_ref, ReadEffect, WriteEffect, AccumEffect, AbstractRef) @@ -130,7 +126,7 @@ def f(x_ref): with self.assertRaises(Exception): pe.trace_to_jaxpr_dynamic(wrap_init(f, 1), [ref_aval]) else: - jaxpr, out_avals, _, () = pe.trace_to_jaxpr_dynamic( + jaxpr, out_avals, _ = pe.trace_to_jaxpr_dynamic( wrap_init(f, 1), [ref_aval]) self.assertSetEqual(jaxpr.effects, {ReadEffect(len(jaxpr.constvars))}) @@ -140,6 +136,30 @@ def f(x_ref): self.assertEqual(out_aval.shape, out_shape) self.assertEqual(out_aval.dtype, out_dtype) + @parameterized.parameters( + ((4, 5), 0, (0,)), + ((4, 5), 1, (0,)), + ((9, 10, 11, 12), 0, (slice(None), 0, 1)), # Contiguous int indexing + ((9, 10, 11, 12), 0, (0, slice(None), 1)), # Non-contiguous int indexing + ((9, 10, 11, 12), 1, (slice(None), 0, 1)), # Contiguous after batch + ((9, 10, 11, 12), 2, (slice(None), 0, 1)), # Non-contiguous after batch + ((9, 10, 11, 12), 3, (slice(None), slice(None), 0)), + # Shaped int indexer, contiguous after batch + ((9, 10, 11, 12), 3, + (slice(None), slice(None), np.array([[0,1]]))), + # Shaped int indexer, non-contiguous after batch + ((9, 10, 11, 12), 2, + (np.array([[0, 1]]), slice(None), np.array([[0, 1]]))), + ) + def test_vmap_of_get_regression(self, shape, in_axes, indexer): + # Regression test for https://github.com/jax-ml/jax/issues/33309 + def f(x): + return x[indexer] + x = jnp.ones(shape) + result = jax.vmap(f, in_axes=in_axes)(jax.new_ref(x)) + expected = jax.vmap(f, in_axes=in_axes)(x) + self.assertArraysEqual(result, expected) + def test_swap_abstract_eval_must_take_in_refs(self): ref_aval = core.ShapedArray((), jnp.float32) val_aval = core.ShapedArray((), jnp.float32) @@ -227,7 +247,7 @@ def f(x_ref, val): with self.assertRaises(Exception): pe.trace_to_jaxpr_dynamic(wrap_init(f, 2), [ref_aval, val_aval]) else: - jaxpr, out_avals, _, () = pe.trace_to_jaxpr_dynamic( + jaxpr, out_avals, _ = pe.trace_to_jaxpr_dynamic( wrap_init(f, 2), [ref_aval, val_aval]) self.assertSetEqual(jaxpr.effects, {WriteEffect(len(jaxpr.constvars))}) @@ -304,7 +324,7 @@ def f(x_ref, val): with self.assertRaises(Exception): pe.trace_to_jaxpr_dynamic(wrap_init(f, 2), [ref_aval, val_aval]) else: - jaxpr, out_avals, _, () = pe.trace_to_jaxpr_dynamic( + jaxpr, out_avals, _ = pe.trace_to_jaxpr_dynamic( wrap_init(f, 2), [ref_aval, val_aval]) self.assertSetEqual(jaxpr.effects, {AccumEffect(len(jaxpr.constvars))}) @@ -324,7 +344,7 @@ def body(x): x[()] = jnp.int32(1) x[()] = jnp.int32(2) return (x[()],) - jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic( + jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic( wrap_init(body, 1), [shaped_array_ref((), jnp.int32)]) self.assertLen(consts, 0) self.assertListEqual(out_avals, [core.ShapedArray((), jnp.int32)]) @@ -337,7 +357,7 @@ def test_can_represent_addupdate_in_jaxprs(self): def body(x): ref_addupdate(x, (), jnp.int32(1)) return (x[()],) - jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic( + jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic( wrap_init(body, 1), [shaped_array_ref((), jnp.int32)]) self.assertLen(consts, 0) self.assertListEqual(out_avals, [core.ShapedArray((), jnp.int32)]) @@ -347,14 +367,14 @@ def test_get_custom_pretty_printing_rule(self): def body(x_ref): x = x_ref[()] return [x] - jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic( + jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic( wrap_init(body, 1), [shaped_array_ref((), jnp.int32)]) - self.assertIn("b:i32[] <- a[]", jaxpr.pretty_print(use_color=False)) + self.assertIn("b:i32[] <- a[...]", jaxpr.pretty_print(use_color=False)) def body(x_ref): x = x_ref[:, 0] return [x] - jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic( + jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic( wrap_init(body, 1), [shaped_array_ref((1, 2), jnp.int32)]) self.assertIn("b:i32[1] <- a[:,0]", jaxpr.pretty_print(use_color=False)) @@ -362,14 +382,14 @@ def test_set_custom_pretty_printing_rule(self): def body(x_ref): x_ref[()] = jnp.int32(2) return [] - jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic( + jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic( wrap_init(body, 1), [shaped_array_ref((), jnp.int32)]) - self.assertIn("a[] <- 2", jaxpr.pretty_print(use_color=False)) + self.assertIn("a[...] <- 2:i32[]", jaxpr.pretty_print(use_color=False)) def body(x_ref, val): x_ref[:, 0] = val return [] - jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic( + jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic( wrap_init(body, 2), [shaped_array_ref((1, 2), jnp.int32), core.ShapedArray((1,), jnp.int32)]) self.assertIn("a[:,0] <- b", jaxpr.pretty_print(use_color=False)) @@ -378,14 +398,14 @@ def test_swap_custom_pretty_printing_rule(self): def body(x_ref): x = ref_swap(x_ref, (), jnp.int32(2)) return [x] - jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic( + jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic( wrap_init(body, 1), [shaped_array_ref((), jnp.int32)]) - self.assertIn("b:i32[], a[] <- a[], 2", jaxpr.pretty_print(use_color=False)) + self.assertIn("b:i32[], a[...] <- a[...], 2:i32[]", jaxpr.pretty_print(use_color=False)) def body(x_ref, val): x = ref_swap(x_ref, (slice(None), 0), val) return [x] - jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic( + jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic( wrap_init(body, 2), [shaped_array_ref((1, 2), jnp.int32), core.ShapedArray((1,), jnp.int32)]) self.assertIn("c:i32[1], a[:,0] <- a[:,0], b", @@ -395,20 +415,19 @@ def test_addupdate_custom_pretty_printing_rule(self): def body(x_ref): ref_addupdate(x_ref, (), jnp.int32(2)) return [] - jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic( + jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic( wrap_init(body, 1), [shaped_array_ref((), jnp.int32)]) - self.assertIn("a[] += 2", jaxpr.pretty_print(use_color=False)) + self.assertIn("a[...] += 2", jaxpr.pretty_print(use_color=False)) def body(x_ref, val): ref_addupdate(x_ref, (slice(None), 0), val) return [] - jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic( + jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic( wrap_init(body, 2), [shaped_array_ref((1, 2), jnp.int32), core.ShapedArray((1,), jnp.int32)]) self.assertIn("a[:,0] += b", jaxpr.pretty_print(use_color=False)) - def test_get_jvp(self): def f(r): @@ -420,7 +439,7 @@ def g(r, rdot): in_avals = [shaped_array_ref((), jnp.dtype('float32')), shaped_array_ref((), jnp.dtype('float32'))] - jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(wrap_init(g, 2), in_avals) + jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(wrap_init(g, 2), in_avals) self.assertEqual(jaxpr.eqns[0].primitive, get_p) self.assertEqual(jaxpr.eqns[1].primitive, get_p) @@ -436,7 +455,7 @@ def g(r, rdot): in_avals = [shaped_array_ref((), jnp.dtype('float32')), shaped_array_ref((), jnp.dtype('float32'))] - jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(wrap_init(g, 2), in_avals) + jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(wrap_init(g, 2), in_avals) self.assertEqual(jaxpr.eqns[0].primitive, get_p) self.assertEqual(jaxpr.eqns[1].primitive, get_p) self.assertEqual(jaxpr.eqns[2].primitive, lax.sin_p) @@ -456,7 +475,7 @@ def g(r, rdot): in_avals = [shaped_array_ref((), jnp.dtype('float32')), shaped_array_ref((), jnp.dtype('float32'))] - jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(wrap_init(g, 2), in_avals) + jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(wrap_init(g, 2), in_avals) self.assertEqual(jaxpr.eqns[0].primitive, addupdate_p) self.assertEqual(jaxpr.eqns[1].primitive, addupdate_p) self.assertEqual(jaxpr.eqns[2].primitive, get_p) @@ -477,27 +496,17 @@ def g(r, rdot): op=[ lambda x_ref, indexer: [x_ref[indexer]], lambda x_ref, indexer: [ - ref_swap(x_ref, indexer, - jnp.ones(x_ref.shape, x_ref.dtype)[None][(0, - *indexer)])], + ref_swap(x_ref, indexer, jnp.ones_like(x_ref[indexer]))], lambda x_ref, indexer: ( - ref_addupdate(x_ref, indexer, - jnp.ones(x_ref.shape, x_ref.dtype)[None][(0, - *indexer)]) - or [jnp.ones(x_ref.shape, x_ref.dtype)[None][(0, *indexer)]]) + ref_addupdate(x_ref, indexer, jnp.ones_like(x_ref[indexer])) + or [jnp.ones_like(x_ref[indexer])]), ], ) def test_vmap(self, ref_shape, ref_bdim, idx_shape, indexed_dims, idx_bdims, out_bdim, op): - - float_ = (jnp.dtype('float64') if config.enable_x64.value else - jnp.dtype('float32')) - int_ = (jnp.dtype('int64') if config.enable_x64.value else - jnp.dtype('int32')) + intx = dtypes.default_int_dtype() + floatx = dtypes.default_float_dtype() axis_size = 7 - out_shape = tuple(d for d, b in zip(ref_shape, indexed_dims) if not b) - if any(indexed_dims): - out_shape = (*idx_shape, *out_shape) def maybe_insert(shape, idx): if idx is None: @@ -505,13 +514,13 @@ def maybe_insert(shape, idx): return tuple_insert(shape, idx, axis_size) batched_ref_shape = maybe_insert(ref_shape, ref_bdim) - ref_aval = shaped_array_ref(ref_shape, float_) - bat_ref_aval = shaped_array_ref(batched_ref_shape, float_) + ref_aval = shaped_array_ref(ref_shape, floatx) + bat_ref_aval = shaped_array_ref(batched_ref_shape, floatx) - idx_avals = [core.ShapedArray(idx_shape, int_) + idx_avals = [core.ShapedArray(idx_shape, intx) for _ in idx_bdims] bat_idx_avals = [ - core.ShapedArray(maybe_insert(idx_shape, idx_bdim), int_) + core.ShapedArray(maybe_insert(idx_shape, idx_bdim), intx) for idx_bdim in idx_bdims] def f(x_ref, *idxs): @@ -520,19 +529,20 @@ def f(x_ref, *idxs): return op(x_ref, indexer) rng = self.rng() - a = rng.randn(*bat_ref_aval.shape) + a = rng.randn(*bat_ref_aval.shape).astype(floatx) his = [d for d, b in zip(ref_aval.shape, indexed_dims) if b] - idxs = [rng.randint(low=0, high=hi, size=i.shape) + idxs = [rng.randint(low=0, high=hi, size=i.shape, dtype=intx) for i, hi in zip(bat_idx_avals, his)] # discharge-of-vmap f_batched = jax.vmap(f, in_axes=(ref_bdim, *idx_bdims), out_axes=[out_bdim]) - stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( + stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic( wrap_init(f_batched, 1 + len(bat_idx_avals)), [bat_ref_aval, *bat_idx_avals]) jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts) discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, a, *idxs) + # vmap-of-discharge - stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( + stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic( wrap_init(f, 1 + len(idx_avals)), [ref_aval, *idx_avals]) jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts) f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_), @@ -551,7 +561,7 @@ def f(a_ref): a = ref_get(a_ref, ()) return [a + 1] in_avals = [shaped_array_ref((), jnp.dtype('float32'))] - stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1), + stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1), in_avals) # Discharging should just turn this into a jaxpr that just adds 1. discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts) @@ -567,7 +577,7 @@ def f(a_ref): a = ref_get(a_ref, (0, 1)) return [a + 1] in_avals = [shaped_array_ref((4, 3, 2), jnp.dtype('float32'))] - stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1), + stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1), in_avals) # Discharging should just turn this into a jaxpr that just adds 1. discharged_jaxpr, () = discharge_state(stateful_jaxpr, consts) @@ -586,7 +596,7 @@ def f(a_ref): a = a_ref[jnp.array([0, 1])] return [a + 1] in_avals = [shaped_array_ref((4, 3), jnp.dtype('float32'))] - stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( + stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic( wrap_init(f, 1), in_avals) discharged_jaxpr, discharged_consts = discharge_state( stateful_jaxpr, consts) @@ -601,7 +611,7 @@ def f(a_ref, b): return [] in_avals = [shaped_array_ref((), jnp.dtype('float32')), core.ShapedArray((), jnp.dtype('float32'))] - stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1), + stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrap_init(f, 2), in_avals) # Discharging should just turn this into a jaxpr that ignores the first # value and returns second value plus 1. @@ -618,8 +628,7 @@ def f(a_ref): ref_set(a_ref, (0, 1), jnp.ones(2, dtype=jnp.dtype('float32'))) return [] in_avals = [shaped_array_ref((4, 3, 2), jnp.dtype('float32'))] - stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1), - in_avals) + stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1), in_avals) # Discharging should just turn this into a jaxpr that just adds 1. discharged_jaxpr, () = discharge_state(stateful_jaxpr, consts) self.assertLen(discharged_jaxpr.invars, 1) @@ -638,7 +647,7 @@ def f(a_ref): a_ref[jnp.array([0, 1])] = jnp.ones((2, 3), 'float32') return [] in_avals = [shaped_array_ref((4, 3), jnp.dtype('float32'))] - stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1), + stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1), in_avals) discharged_jaxpr, discharged_consts = discharge_state( stateful_jaxpr, consts) @@ -654,7 +663,7 @@ def f(a_ref): jnp.zeros((2, 2), jnp.float32)) return [a + 1] in_avals = [shaped_array_ref((4, 3, 2), jnp.float32)] - stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( + stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic( wrap_init(f, 1), in_avals) discharged_jaxpr, () = discharge_state(stateful_jaxpr, consts) @@ -672,7 +681,7 @@ def f(a_ref, b): return [] in_avals = [shaped_array_ref((), jnp.dtype('float32')), core.ShapedArray((), jnp.dtype('float32'))] - stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 2), + stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrap_init(f, 2), in_avals) # Discharging should just turn this into a jaxpr that adds the first value, # second value, and 1. @@ -690,7 +699,7 @@ def f(a_ref): jnp.ones(2, dtype=jnp.dtype('float32'))) return [] in_avals = [shaped_array_ref((4, 3, 2), jnp.dtype('float32'))] - stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1), + stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1), in_avals) discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts) self.assertLen(discharged_jaxpr.invars, 1) @@ -711,7 +720,7 @@ def f(a_ref): jnp.ones((2, 3), 'float32')) return [] in_avals = [shaped_array_ref((4, 3), jnp.dtype('float32'))] - stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1), + stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1), in_avals) discharged_jaxpr, discharged_consts = discharge_state( stateful_jaxpr, consts) @@ -725,7 +734,7 @@ def f(a_ref): b = a + 1 return [a, b] in_avals = [shaped_array_ref((4,), jnp.dtype('float32'))] - stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1), + stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrap_init(f, 1), in_avals) discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts) self.assertLen(discharged_jaxpr.invars, 1) @@ -745,7 +754,7 @@ def f(a_ref, b_ref): shaped_array_ref((4,), jnp.dtype('float32')), shaped_array_ref((4,), jnp.dtype('float32')) ] - stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 2), + stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrap_init(f, 2), in_avals) discharged_jaxpr, _ = discharge_state( stateful_jaxpr, consts, should_discharge=[False, True]) @@ -775,7 +784,7 @@ def f(a_ref, b_ref): scalar_ref_1 = shaped_array_ref((), jnp.float32) scalar_ref_2 = shaped_array_ref((), jnp.float32) - jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( + jaxpr, _, _ = pe.trace_to_jaxpr_dynamic( wrap_init(f, 2), [scalar_ref_1, scalar_ref_2]) discharged_jaxpr, _ = discharge_state(jaxpr, (), should_discharge=[False, True]) @@ -792,7 +801,7 @@ def body(i, st): lax.fori_loop(0, 5, body, init_val=()) return a_ref[...], b_ref[...] - ref = lambda x: AbstractRef(core.raise_to_shaped(core.get_aval(x))) + ref = lambda x: AbstractRef(core.get_aval(x)) f_jaxpr = jax.make_jaxpr(f)(ref(1.), ref(2.)) jaxpr, _ = discharge_state(f_jaxpr.jaxpr, (), should_discharge=[False, True]) # Effects on y_ref were discharged away but not the effects on x_ref @@ -806,330 +815,310 @@ def body(i, st): self.assertLen(jaxpr.outvars, 3) -if CAN_USE_HYPOTHESIS: - - def index_arrays(size, idx_shape): - valid_idx = hps.integers(min_value=-size, max_value=size - 1) - return hnp.arrays(np.int32, idx_shape, elements=valid_idx) - - Shape = tuple[int, ...] - - class IndexParam(NamedTuple): - ref_aval: shaped_array_ref - ref_shape: Shape - indexed_dims: list[bool] - idx_avals: tuple[core.ShapedArray, ...] - idx_shape: Shape - slice_aval: core.ShapedArray - slice_shape: Shape - - @hps.composite - def index_params(draw): - ref_shape = draw(hnp.array_shapes(max_dims=4, max_side=7), label='ref_shape') - indexed_dims = draw(hps.lists(hps.booleans(), - min_size=len(ref_shape), - max_size=len(ref_shape))) - idx_shape = draw(hnp.array_shapes(max_dims=3, max_side=5)) - if any(indexed_dims): - sliced_shape = (s for s, b in zip(ref_shape, indexed_dims) if not b) +def index_arrays(size, idx_shape): + valid_idx = hps.integers(min_value=-size, max_value=size - 1) + return hnp.arrays(np.int32, idx_shape, elements=valid_idx) + +Shape = tuple[int, ...] + +class IndexParam(NamedTuple): + ref_aval: shaped_array_ref + ref_shape: Shape + indexed_dims: list[bool] + idx_avals: tuple[core.ShapedArray, ...] + idx_shape: Shape + slice_aval: core.ShapedArray + slice_shape: Shape + +@hps.composite +def index_params(draw): + ref_shape = draw(hnp.array_shapes(max_dims=4, max_side=7), label='ref_shape') + indexed_dims = draw(hps.lists(hps.booleans(), + min_size=len(ref_shape), + max_size=len(ref_shape))) + idx_shape = draw(hnp.array_shapes(max_dims=3, max_side=5)) + if not any(indexed_dims): + slice_shape = ref_shape + else: + sliced_shape = tuple(s for s, b in zip(ref_shape, indexed_dims) if not b) + int_indexers_contiguous = bool( + np.all(np.diff(np.where(indexed_dims)[0]) == 1) + ) + if not int_indexers_contiguous: slice_shape = (*idx_shape, *sliced_shape) else: - slice_shape = ref_shape - ref_aval = shaped_array_ref(ref_shape, np.float32) - idx_avals = tuple(core.ShapedArray(idx_shape, np.int32) for _ in - range(sum(indexed_dims))) - slice_aval = core.ShapedArray(slice_shape, np.float32) - return IndexParam(ref_aval, ref_shape, indexed_dims, idx_avals, idx_shape, - slice_aval, slice_shape) - - class VmappableIndexParam(NamedTuple): - index_param: IndexParam - ref_bdim: int | None - non_slice_idx_bdims: tuple[int | None, ...] - slice_bdim: int - bat_ref_aval: shaped_array_ref - bat_ref_shape: Shape - bat_non_slice_idx_avals: tuple[core.ShapedArray, ...] - bat_non_slice_idx_shapes: tuple[Shape, ...] - bat_slice_aval: core.ShapedArray - bat_slice_shape: Shape - - def maybe_tuple_insert(t: tuple[Any, ...], idx: int | None, - val: Any) -> tuple[Any, ...]: - if idx is None: - return t - return tuple_insert(t, idx, val) - - @hps.composite - def vmappable_index_params(draw, *, op_type: str): - axis_size = draw(hps.integers(min_value=1, max_value=7), label='axis_size') - index_param: IndexParam = draw(index_params()) - non_slice_idx_bdims = tuple( - draw(hps.one_of( - hps.none(), - hps.integers(min_value=0, max_value=len(index_param.idx_shape)))) - for b in index_param.indexed_dims if b) - bat_non_slice_idx_shapes = tuple( - maybe_tuple_insert(index_param.idx_shape, idx_bdim, axis_size) - for idx_bdim in non_slice_idx_bdims) - if op_type == "swap": - # In a swap, the ref *must* be batched - ref_bdim = draw(hps.integers(min_value=0, - max_value=len(index_param.ref_shape))) - if any(idx_bdim is not None for idx_bdim in non_slice_idx_bdims): - # If it's a swap, if indices are batched, val must be batched. - slice_bdim = draw(hps.integers( - min_value=0, max_value=len(index_param.slice_shape))) - else: - slice_bdim = draw(hps.one_of(hps.none(), hps.integers( - min_value=0, max_value=len(index_param.slice_shape)))) - elif op_type == "get": - # In a get, the indices must be batched or ref is batched - if all(idx_bdim is None for idx_bdim in non_slice_idx_bdims): - ref_bdim = draw(hps.integers(min_value=0, - max_value=len(index_param.ref_shape))) - else: - ref_bdim = draw(hps.one_of(hps.none(), - hps.integers(min_value=0, max_value=len(index_param.ref_shape)))) + insert_pos = indexed_dims.index(True) + slice_shape = ( + *sliced_shape[:insert_pos], + *idx_shape, + *sliced_shape[insert_pos:], + ) + ref_aval = shaped_array_ref(ref_shape, np.float32) + idx_avals = tuple(core.ShapedArray(idx_shape, np.int32) for _ in + range(sum(indexed_dims))) + slice_aval = core.ShapedArray(slice_shape, np.float32) + return IndexParam(ref_aval, ref_shape, indexed_dims, idx_avals, idx_shape, + slice_aval, slice_shape) + +class VmappableIndexParam(NamedTuple): + index_param: IndexParam + ref_bdim: int | None + non_slice_idx_bdims: tuple[int | None, ...] + slice_bdim: int + bat_ref_aval: shaped_array_ref + bat_ref_shape: Shape + bat_non_slice_idx_avals: tuple[core.ShapedArray, ...] + bat_non_slice_idx_shapes: tuple[Shape, ...] + bat_slice_aval: core.ShapedArray + bat_slice_shape: Shape + +def maybe_tuple_insert(t: tuple[Any, ...], idx: int | None, + val: Any) -> tuple[Any, ...]: + if idx is None: + return t + return tuple_insert(t, idx, val) + +@hps.composite +def vmappable_index_params(draw, *, op_type: str): + axis_size = draw(hps.integers(min_value=1, max_value=7), label='axis_size') + index_param: IndexParam = draw(index_params()) + non_slice_idx_bdims = tuple( + draw(hps.one_of( + hps.none(), + hps.integers(min_value=0, max_value=len(index_param.idx_shape)))) + for b in index_param.indexed_dims if b) + bat_non_slice_idx_shapes = tuple( + maybe_tuple_insert(index_param.idx_shape, idx_bdim, axis_size) + for idx_bdim in non_slice_idx_bdims) + if op_type == "swap": + # In a swap, the ref *must* be batched + ref_bdim = draw(hps.integers(min_value=0, + max_value=len(index_param.ref_shape))) + if any(idx_bdim is not None for idx_bdim in non_slice_idx_bdims): + # If it's a swap, if indices are batched, val must be batched. slice_bdim = draw(hps.integers( min_value=0, max_value=len(index_param.slice_shape))) + else: + slice_bdim = draw(hps.one_of(hps.none(), hps.integers( + min_value=0, max_value=len(index_param.slice_shape)))) + elif op_type == "get": + # In a get, the indices must be batched or ref is batched + if all(idx_bdim is None for idx_bdim in non_slice_idx_bdims): + ref_bdim = draw(hps.integers(min_value=0, + max_value=len(index_param.ref_shape))) + else: + ref_bdim = draw(hps.one_of(hps.none(), + hps.integers(min_value=0, max_value=len(index_param.ref_shape)))) + slice_bdim = draw(hps.integers( + min_value=0, max_value=len(index_param.slice_shape))) + + bat_ref_shape = maybe_tuple_insert(index_param.ref_shape, ref_bdim, axis_size) + bat_ref_aval = shaped_array_ref(bat_ref_shape, np.float32) + bat_non_slice_idx_avals = tuple( + core.ShapedArray(shape, np.int32) for shape in bat_non_slice_idx_shapes) + bat_slice_shape = maybe_tuple_insert(index_param.slice_shape, slice_bdim, axis_size) + bat_slice_aval = core.ShapedArray(bat_slice_shape, np.float32) + return VmappableIndexParam(index_param, ref_bdim, non_slice_idx_bdims, + slice_bdim, bat_ref_aval, bat_ref_shape, + bat_non_slice_idx_avals, bat_non_slice_idx_shapes, + bat_slice_aval, bat_slice_shape) + +class GetVmapParams(NamedTuple): + vmap_index_param: VmappableIndexParam + bat_ref: np.ndarray + bat_idxs: tuple[np.ndarray, ...] + +@hps.composite +def get_vmap_params(draw): + vmap_index_param: VmappableIndexParam = draw( + vmappable_index_params(op_type="get")) + bat_ref = draw(hnp.arrays(np.float32, vmap_index_param.bat_ref_shape)) + bat_idx_shapes_ = iter(vmap_index_param.bat_non_slice_idx_shapes) + bat_idxs = tuple( + draw(index_arrays(size, next(bat_idx_shapes_))) + for size, indexed in zip( + vmap_index_param.index_param.ref_shape, + vmap_index_param.index_param.indexed_dims) + if indexed) + assert next(bat_idx_shapes_, None) is None + return GetVmapParams(vmap_index_param, bat_ref, bat_idxs) + +class SetVmapParams(NamedTuple): + vmap_index_param: VmappableIndexParam + bat_ref: np.ndarray + bat_val: np.ndarray + bat_idxs: tuple[np.ndarray, ...] + +@hps.composite +def set_vmap_params(draw): + vmap_index_param: VmappableIndexParam = draw(vmappable_index_params( + op_type="swap")) + bat_ref = draw(hnp.arrays(np.float32, vmap_index_param.bat_ref_shape)) + bat_idx_shapes_ = iter(vmap_index_param.bat_non_slice_idx_shapes) + bat_idxs = tuple( + draw(index_arrays(size, next(bat_idx_shapes_))) + for size, indexed in zip( + vmap_index_param.index_param.ref_shape, + vmap_index_param.index_param.indexed_dims) + if indexed) + assert next(bat_idx_shapes_, None) is None + bat_val = draw(hnp.arrays(np.float32, vmap_index_param.bat_slice_shape)) + return SetVmapParams(vmap_index_param, bat_ref, bat_val, bat_idxs) + +Indexer = tuple[Union[int, slice, np.ndarray]] + +def _unpack_idx(idx: Indexer + ) -> tuple[Sequence[int | np.ndarray], Sequence[bool]]: + indexed_dims = [type(i) != slice for i in idx] + non_slice_idx = [i for i, b in zip(idx, indexed_dims) if b] + return non_slice_idx, indexed_dims + +def _pack_idx(non_slice_idx: Sequence[int | np.ndarray], + indexed_dims: Sequence[bool]) -> Indexer: + idx_ = iter(non_slice_idx) + idx = tuple(next(idx_) if b else slice(None) for b in indexed_dims) + assert next(idx_, None) is None + return idx + +@jtu.thread_unsafe_test_class(condition=not jtu.hypothesis_is_thread_safe()) +class StateHypothesisTest(jtu.JaxTestCase): + + @hp.given(get_vmap_params()) + @hp.settings(deadline=None, + print_blob=True, + max_examples=jtu.NUM_GENERATED_CASES.value, + suppress_health_check=[hp.HealthCheck.too_slow]) + def test_get_vmap(self, get_vmap_param: GetVmapParams): + + indexed_dims = get_vmap_param.vmap_index_param.index_param.indexed_dims + + def f(ref, *non_slice_idx): + idx = _pack_idx(non_slice_idx, indexed_dims) + return [ref_get(ref, idx)] + ref_aval = get_vmap_param.vmap_index_param.index_param.ref_aval + bat_ref_aval = get_vmap_param.vmap_index_param.bat_ref_aval + bat_non_slice_idx_avals = get_vmap_param.vmap_index_param.bat_non_slice_idx_avals + ref_bdim = get_vmap_param.vmap_index_param.ref_bdim + idx_bdims = get_vmap_param.vmap_index_param.non_slice_idx_bdims + out_bdim = get_vmap_param.vmap_index_param.slice_bdim + non_slice_idx = get_vmap_param.bat_idxs + idx_avals = get_vmap_param.vmap_index_param.index_param.idx_avals + ref = get_vmap_param.bat_ref - bat_ref_shape = maybe_tuple_insert(index_param.ref_shape, ref_bdim, axis_size) - bat_ref_aval = shaped_array_ref(bat_ref_shape, np.float32) - bat_non_slice_idx_avals = tuple( - core.ShapedArray(shape, np.int32) for shape in bat_non_slice_idx_shapes) - bat_slice_shape = maybe_tuple_insert(index_param.slice_shape, slice_bdim, axis_size) - bat_slice_aval = core.ShapedArray(bat_slice_shape, np.float32) - return VmappableIndexParam(index_param, ref_bdim, non_slice_idx_bdims, - slice_bdim, bat_ref_aval, bat_ref_shape, - bat_non_slice_idx_avals, bat_non_slice_idx_shapes, - bat_slice_aval, bat_slice_shape) - - class GetVmapParams(NamedTuple): - vmap_index_param: VmappableIndexParam - bat_ref: np.ndarray - bat_idxs: tuple[np.ndarray, ...] - - @hps.composite - def get_vmap_params(draw): - vmap_index_param: VmappableIndexParam = draw( - vmappable_index_params(op_type="get")) - bat_ref = draw(hnp.arrays(np.float32, vmap_index_param.bat_ref_shape)) - bat_idx_shapes_ = iter(vmap_index_param.bat_non_slice_idx_shapes) - bat_idxs = tuple( - draw(index_arrays(size, next(bat_idx_shapes_))) - for size, indexed in zip( - vmap_index_param.index_param.ref_shape, - vmap_index_param.index_param.indexed_dims) - if indexed) - assert next(bat_idx_shapes_, None) is None - return GetVmapParams(vmap_index_param, bat_ref, bat_idxs) - - class SetVmapParams(NamedTuple): - vmap_index_param: VmappableIndexParam - bat_ref: np.ndarray - bat_val: np.ndarray - bat_idxs: tuple[np.ndarray, ...] - - @hps.composite - def set_vmap_params(draw): - vmap_index_param: VmappableIndexParam = draw(vmappable_index_params( - op_type="swap")) - bat_ref = draw(hnp.arrays(np.float32, vmap_index_param.bat_ref_shape)) - bat_idx_shapes_ = iter(vmap_index_param.bat_non_slice_idx_shapes) - bat_idxs = tuple( - draw(index_arrays(size, next(bat_idx_shapes_))) - for size, indexed in zip( - vmap_index_param.index_param.ref_shape, - vmap_index_param.index_param.indexed_dims) - if indexed) - assert next(bat_idx_shapes_, None) is None - bat_val = draw(hnp.arrays(np.float32, vmap_index_param.bat_slice_shape)) - return SetVmapParams(vmap_index_param, bat_ref, bat_val, bat_idxs) - - Indexer = tuple[Union[int, slice, np.ndarray]] - - def _unpack_idx(idx: Indexer - ) -> tuple[Sequence[int | np.ndarray], Sequence[bool]]: - indexed_dims = [type(i) != slice for i in idx] - non_slice_idx = [i for i, b in zip(idx, indexed_dims) if b] - return non_slice_idx, indexed_dims - - def _pack_idx(non_slice_idx: Sequence[int | np.ndarray], - indexed_dims: Sequence[bool]) -> Indexer: - idx_ = iter(non_slice_idx) - idx = tuple(next(idx_) if b else slice(None) for b in indexed_dims) - assert next(idx_, None) is None - return idx - - @jtu.thread_unsafe_test_class() # hypothesis isn't thread-safe - class StateHypothesisTest(jtu.JaxTestCase): - - @hp.given(get_vmap_params()) - @hp.settings(deadline=None, print_blob=True, - max_examples=jtu.NUM_GENERATED_CASES.value) - def test_get_vmap(self, get_vmap_param: GetVmapParams): - - indexed_dims = get_vmap_param.vmap_index_param.index_param.indexed_dims - - def f(ref, *non_slice_idx): - idx = _pack_idx(non_slice_idx, indexed_dims) - return [ref_get(ref, idx)] - ref_aval = get_vmap_param.vmap_index_param.index_param.ref_aval - bat_ref_aval = get_vmap_param.vmap_index_param.bat_ref_aval - bat_non_slice_idx_avals = get_vmap_param.vmap_index_param.bat_non_slice_idx_avals - ref_bdim = get_vmap_param.vmap_index_param.ref_bdim - idx_bdims = get_vmap_param.vmap_index_param.non_slice_idx_bdims - out_bdim = get_vmap_param.vmap_index_param.slice_bdim - non_slice_idx = get_vmap_param.bat_idxs - idx_avals = get_vmap_param.vmap_index_param.index_param.idx_avals - ref = get_vmap_param.bat_ref - - f_batched = jax.vmap(f, in_axes=(ref_bdim, *idx_bdims), out_axes=[out_bdim]) - stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( - wrap_init(f_batched, 1 + len(bat_non_slice_idx_avals)), - [bat_ref_aval, *bat_non_slice_idx_avals]) - jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts) - discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, *non_slice_idx) - - # vmap-of-discharge - stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( - wrap_init(f, 1 + len(idx_avals)), [ref_aval, *idx_avals]) - jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts) - f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_), - in_axes=(ref_bdim, *idx_bdims), - out_axes=[out_bdim, ref_bdim]) - vmap_of_discharge_ans = f_batched(ref, *non_slice_idx) - - self.assertAllClose(discharge_of_vmap_ans, vmap_of_discharge_ans, - check_dtypes=False) - - - @hp.given(set_vmap_params()) - @hp.settings(deadline=None, print_blob=True, - max_examples=jtu.NUM_GENERATED_CASES.value) - def test_set_vmap(self, set_vmap_param: SetVmapParams): - if jtu.test_device_matches(["gpu"]): - self.skipTest("Scatter is nondeterministic on GPU") - indexed_dims = set_vmap_param.vmap_index_param.index_param.indexed_dims - - def f(ref, val, *non_slice_idx): - idx = _pack_idx(non_slice_idx, indexed_dims) - ref_set(ref, idx, val) - return [] - ref_aval = set_vmap_param.vmap_index_param.index_param.ref_aval - bat_ref_aval = set_vmap_param.vmap_index_param.bat_ref_aval - bat_non_slice_idx_avals = set_vmap_param.vmap_index_param.bat_non_slice_idx_avals - ref_bdim = set_vmap_param.vmap_index_param.ref_bdim - idx_bdims = set_vmap_param.vmap_index_param.non_slice_idx_bdims - non_slice_idx = set_vmap_param.bat_idxs - idx_avals = set_vmap_param.vmap_index_param.index_param.idx_avals - ref = set_vmap_param.bat_ref - val = set_vmap_param.bat_val - bat_val_aval = set_vmap_param.vmap_index_param.bat_slice_aval - val_aval = set_vmap_param.vmap_index_param.index_param.slice_aval - val_bdim = set_vmap_param.vmap_index_param.slice_bdim - - f_batched = jax.vmap(f, in_axes=(ref_bdim, val_bdim, *idx_bdims), - out_axes=[]) - stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( - wrap_init(f_batched, 2 + len(bat_non_slice_idx_avals)), - [bat_ref_aval, bat_val_aval, *bat_non_slice_idx_avals]) - jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts) - discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, val, *non_slice_idx) - - # vmap-of-discharge - stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( - wrap_init(f, 2 + len(idx_avals)), [ref_aval, val_aval, *idx_avals]) - jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts) - f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_), - in_axes=(ref_bdim, val_bdim, *idx_bdims), - out_axes=[ref_bdim]) - vmap_of_discharge_ans = f_batched(ref, val, *non_slice_idx) - - self.assertAllClose(discharge_of_vmap_ans, vmap_of_discharge_ans, - check_dtypes=False) - - - @hp.given(set_vmap_params()) - @hp.settings(deadline=None, print_blob=True, - max_examples=jtu.NUM_GENERATED_CASES.value) - def test_addupdate_vmap(self, set_vmap_param: SetVmapParams): - - indexed_dims = set_vmap_param.vmap_index_param.index_param.indexed_dims - - def f(ref, val, *non_slice_idx): - idx = _pack_idx(non_slice_idx, indexed_dims) - ref_addupdate(ref, idx, val) - return [] - ref_aval = set_vmap_param.vmap_index_param.index_param.ref_aval - bat_ref_aval = set_vmap_param.vmap_index_param.bat_ref_aval - bat_non_slice_idx_avals = set_vmap_param.vmap_index_param.bat_non_slice_idx_avals - ref_bdim = set_vmap_param.vmap_index_param.ref_bdim - idx_bdims = set_vmap_param.vmap_index_param.non_slice_idx_bdims - non_slice_idx = set_vmap_param.bat_idxs - idx_avals = set_vmap_param.vmap_index_param.index_param.idx_avals - ref = set_vmap_param.bat_ref - val = set_vmap_param.bat_val - bat_val_aval = set_vmap_param.vmap_index_param.bat_slice_aval - val_aval = set_vmap_param.vmap_index_param.index_param.slice_aval - val_bdim = set_vmap_param.vmap_index_param.slice_bdim - - f_batched = jax.vmap(f, in_axes=(ref_bdim, val_bdim, *idx_bdims), - out_axes=[]) - stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( - wrap_init(f_batched, 2 + len(bat_non_slice_idx_avals)), - [bat_ref_aval, bat_val_aval, *bat_non_slice_idx_avals]) - jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts) - discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, val, *non_slice_idx) - - # vmap-of-discharge - stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( - wrap_init(f, 2 + len(idx_avals)), [ref_aval, val_aval, *idx_avals]) - jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts) - f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_), - in_axes=(ref_bdim, val_bdim, *idx_bdims), - out_axes=[ref_bdim]) - vmap_of_discharge_ans = f_batched(ref, val, *non_slice_idx) - - self.assertAllClose(discharge_of_vmap_ans, vmap_of_discharge_ans, - check_dtypes=False) + f_batched = jax.vmap(f, in_axes=(ref_bdim, *idx_bdims), out_axes=[out_bdim]) + stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic( + wrap_init(f_batched, 1 + len(bat_non_slice_idx_avals)), + [bat_ref_aval, *bat_non_slice_idx_avals]) + jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts) + discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, *non_slice_idx) + # vmap-of-discharge + stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic( + wrap_init(f, 1 + len(idx_avals)), [ref_aval, *idx_avals]) + jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts) + f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_), + in_axes=(ref_bdim, *idx_bdims), + out_axes=[out_bdim, ref_bdim]) + vmap_of_discharge_ans = f_batched(ref, *non_slice_idx) -class StateControlFlowTest(jtu.JaxTestCase): + self.assertAllClose(discharge_of_vmap_ans, vmap_of_discharge_ans, + check_dtypes=False) - def test_simple_cond(self): - def f(pred): - def body(x_ref): - def true_fun(): - x_ref[()] = 1. - def false_fun(): - pass - lax.cond(pred, true_fun, false_fun) - return for_loop.run_state(body, 0.) - jaxpr = jax.make_jaxpr(f)(True).jaxpr - self.assertEmpty(jaxpr.effects) - self.assertAllClose(jax.jit(f)(True), 1.) - self.assertAllClose(jax.jit(f)(False), 0.) - def test_simple_cond_with_return(self): - def f(pred): - def body(refs): - x_ref, y_ref = refs - def true_fun(): - x_ref[()] = 1. - return 4. - def false_fun(): - return 5. - out = lax.cond(pred, true_fun, false_fun) - y_ref[...] = out - return for_loop.run_state(body, (0., 0.)) - jaxpr = jax.make_jaxpr(f)(True).jaxpr - self.assertEmpty(jaxpr.effects) - out = jax.jit(f)(True) - self.assertTupleEqual(out, (1., 4.)) - out = jax.jit(f)(False) - self.assertTupleEqual(out, (0., 5.)) + @hp.given(set_vmap_params()) + @hp.settings(deadline=None, print_blob=True, + max_examples=jtu.NUM_GENERATED_CASES.value, + suppress_health_check=[hp.HealthCheck.too_slow]) + def test_set_vmap(self, set_vmap_param: SetVmapParams): + if jtu.test_device_matches(["gpu"]): + self.skipTest("Scatter is nondeterministic on GPU") + indexed_dims = set_vmap_param.vmap_index_param.index_param.indexed_dims + + def f(ref, val, *non_slice_idx): + idx = _pack_idx(non_slice_idx, indexed_dims) + ref_set(ref, idx, val) + return [] + ref_aval = set_vmap_param.vmap_index_param.index_param.ref_aval + bat_ref_aval = set_vmap_param.vmap_index_param.bat_ref_aval + bat_non_slice_idx_avals = set_vmap_param.vmap_index_param.bat_non_slice_idx_avals + ref_bdim = set_vmap_param.vmap_index_param.ref_bdim + idx_bdims = set_vmap_param.vmap_index_param.non_slice_idx_bdims + non_slice_idx = set_vmap_param.bat_idxs + idx_avals = set_vmap_param.vmap_index_param.index_param.idx_avals + ref = set_vmap_param.bat_ref + val = set_vmap_param.bat_val + bat_val_aval = set_vmap_param.vmap_index_param.bat_slice_aval + val_aval = set_vmap_param.vmap_index_param.index_param.slice_aval + val_bdim = set_vmap_param.vmap_index_param.slice_bdim + + f_batched = jax.vmap(f, in_axes=(ref_bdim, val_bdim, *idx_bdims), + out_axes=[]) + stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic( + wrap_init(f_batched, 2 + len(bat_non_slice_idx_avals)), + [bat_ref_aval, bat_val_aval, *bat_non_slice_idx_avals]) + jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts) + discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, val, *non_slice_idx) + + # vmap-of-discharge + stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic( + wrap_init(f, 2 + len(idx_avals)), [ref_aval, val_aval, *idx_avals]) + jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts) + f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_), + in_axes=(ref_bdim, val_bdim, *idx_bdims), + out_axes=[ref_bdim]) + vmap_of_discharge_ans = f_batched(ref, val, *non_slice_idx) + + self.assertAllClose(discharge_of_vmap_ans, vmap_of_discharge_ans, + check_dtypes=False) + + + @hp.given(set_vmap_params()) + @hp.settings(deadline=None, print_blob=True, + max_examples=jtu.NUM_GENERATED_CASES.value, + suppress_health_check=[hp.HealthCheck.too_slow]) + def test_addupdate_vmap(self, set_vmap_param: SetVmapParams): + + indexed_dims = set_vmap_param.vmap_index_param.index_param.indexed_dims + + def f(ref, val, *non_slice_idx): + idx = _pack_idx(non_slice_idx, indexed_dims) + ref_addupdate(ref, idx, val) + return [] + ref_aval = set_vmap_param.vmap_index_param.index_param.ref_aval + bat_ref_aval = set_vmap_param.vmap_index_param.bat_ref_aval + bat_non_slice_idx_avals = set_vmap_param.vmap_index_param.bat_non_slice_idx_avals + ref_bdim = set_vmap_param.vmap_index_param.ref_bdim + idx_bdims = set_vmap_param.vmap_index_param.non_slice_idx_bdims + non_slice_idx = set_vmap_param.bat_idxs + idx_avals = set_vmap_param.vmap_index_param.index_param.idx_avals + ref = set_vmap_param.bat_ref + val = set_vmap_param.bat_val + bat_val_aval = set_vmap_param.vmap_index_param.bat_slice_aval + val_aval = set_vmap_param.vmap_index_param.index_param.slice_aval + val_bdim = set_vmap_param.vmap_index_param.slice_bdim + + f_batched = jax.vmap(f, in_axes=(ref_bdim, val_bdim, *idx_bdims), + out_axes=[]) + stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic( + wrap_init(f_batched, 2 + len(bat_non_slice_idx_avals)), + [bat_ref_aval, bat_val_aval, *bat_non_slice_idx_avals]) + jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts) + discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, val, *non_slice_idx) + + # vmap-of-discharge + stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic( + wrap_init(f, 2 + len(idx_avals)), [ref_aval, val_aval, *idx_avals]) + jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts) + f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_), + in_axes=(ref_bdim, val_bdim, *idx_bdims), + out_axes=[ref_bdim]) + vmap_of_discharge_ans = f_batched(ref, val, *non_slice_idx) + + self.assertAllClose(discharge_of_vmap_ans, vmap_of_discharge_ans, + check_dtypes=False) + + +class StateControlFlowTest(jtu.JaxTestCase): def test_cond_discharge(self): def f0(pred, x_ref, y_ref): @@ -1139,7 +1128,7 @@ def false_fun(): y_ref[...] = 2. lax.cond(pred, true_fun, false_fun) return x_ref[...], y_ref[...] - ref = lambda x: AbstractRef(core.raise_to_shaped(core.get_aval(x))) + ref = lambda x: AbstractRef(core.get_aval(x)) f_jaxpr = jax.make_jaxpr(f0)(False, ref(3.), ref(4.)) jaxpr, _ = discharge_state(f_jaxpr.jaxpr, (), should_discharge=[False, False, True]) # Effects on y_ref were discharged away but not the effects on x_ref @@ -1287,75 +1276,44 @@ def inner_false_fun(): expected = (4., 0., 0.) self.assertTupleEqual(out, expected) - def test_nested_cond(self): - def f(pred): + def test_while_with_state_in_body(self): + def f(x, y, z): + @run_state def body(x_ref): - def true_fun(): - def true_fun_inner(): - x_ref[()] = 1. - def false_fun_inner(): - pass - return lax.cond(pred, true_fun_inner, false_fun_inner) - def false_fun(): - pass - lax.cond(pred, true_fun, false_fun) - return for_loop.run_state(body, 0.) - jaxpr = jax.make_jaxpr(f)(True).jaxpr + def cond(i): + return i < y + def body(i): + x_ref[...] += z + return i + 1 + lax.while_loop(cond, body, 0) + return body(x) + jaxpr = jax.make_jaxpr(f)(0, 5, 2).jaxpr self.assertEmpty(jaxpr.effects) - self.assertAllClose(jax.jit(f)(True), 1.) - self.assertAllClose(jax.jit(f)(False), 0.) - - def test_cond_jvp_with_state(self): - def f(pred, init_value): - def body(x_ref): - def true_fun(): - x_ref[()] = x_ref[()] ** 2 - def false_fun(): - pass - lax.cond(pred, true_fun, false_fun) - return for_loop.run_state(body, init_value) - - out_primal, out_tangent = jax.jvp(partial(f, True), (3.,), (1.,)) - self.assertAllClose(out_primal, 9.) - self.assertAllClose(out_tangent, 6.) - - out_primal, out_tangent = jax.jvp(partial(f, False), (3.,), (1.,)) - self.assertAllClose(out_primal, 3.) - self.assertAllClose(out_tangent, 1.) - - def test_cond_vmap_not_implemented(self): - @jax.jit - def f(init_value): - def body(x_ref): - def true_fun(): - x_ref[()] = x_ref[()] ** 2 - def false_fun(): - pass - lax.cond(x_ref[()] < 1, true_fun, false_fun) - return for_loop.run_state(body, init_value) - - with self.assertRaises(NotImplementedError): - jax.vmap(f)(jnp.arange(2.)) + self.assertAllClose(jax.jit(f)(0, 5, 2), 10) + self.assertAllClose(jax.jit(f)(1, 2, 3), 7) - def test_cond_grad_not_implemented(self): - @jax.jit - def f(init_value): + def test_while_with_state_in_cond(self): + def f(x, y, z): + @run_state def body(x_ref): - def true_fun(): - x_ref[()] = x_ref[()] ** 2 - def false_fun(): - pass - lax.cond(True, true_fun, false_fun) - return for_loop.run_state(body, init_value) - - with self.assertRaises(NotImplementedError): - jax.grad(f)(3.) + def cond(i): + x_ref[...] += z + return i < y + def body(i): + return i + 1 + lax.while_loop(cond, body, 0) + return body(x) + jaxpr = jax.make_jaxpr(f)(0, 5, 2).jaxpr + self.assertEmpty(jaxpr.effects) + self.assertAllClose(jax.jit(f)(0, 5, 2), 10) + self.assertAllClose(jax.jit(f)(1, 2, 3), 7) - def test_while_with_state_in_body(self): + def test_while_errors_if_same_ref_in_body_and_cond(self): def f(x, y, z): @run_state def body(x_ref): def cond(i): + x_ref[...] += z return i < y def body(i): x_ref[...] += z @@ -1364,8 +1322,9 @@ def body(i): return body(x) jaxpr = jax.make_jaxpr(f)(0, 5, 2).jaxpr self.assertEmpty(jaxpr.effects) - self.assertAllClose(jax.jit(f)(0, 5, 2), 10) - self.assertAllClose(jax.jit(f)(1, 2, 3), 7) + with self.assertRaisesRegex(NotImplementedError, + "Cannot write to the same ref in both cond and body."): + jax.jit(f)(0, 5, 2) def test_scan_with_state_in_body(self): def f(x, w, y, zs): @@ -1414,6 +1373,41 @@ def body(y, z): self.assertAllClose(jax.jit(g)((1, 0, 1, 5, zs))[:3], (13, 35, 11)) self.assertAllClose(jax.jit(g)((1, 1, 1, 2, zs))[:3], (13, 21, 11)) + def test_scan_discharges_into_carry(self): + # we want to discharge scanned-over refs into the carry for aliasing + def body(_, x_ref): + x_ref[...] += 1 + return (), () + + x_ref = jax.new_ref(jnp.arange(3.)) + jaxpr = jax.make_jaxpr(lambda x_ref: jax.lax.scan(body, (), x_ref))(x_ref) + jaxpr, () = discharge_state(jaxpr.jaxpr, jaxpr.consts) + scan_eqn = jaxpr.eqns[0] + self.assertEqual(scan_eqn.params['num_consts'], 0) + self.assertEqual(scan_eqn.params['num_carry'], 2) + a, b = scan_eqn.params['jaxpr'].in_avals + self.assertEqual(a.shape, ()) + self.assertEqual(b.shape, (3,)) + + @parameterized.named_parameters( + ("call_primitive", core.call_p), + ("closed_call_primitive", core.closed_call_p), + ) + def test_call_primitive_discharges(self, prim): + + def g(y_ref, x): + x_ref = jax.new_ref(x) + y_ref[...] = jnp.exp(x_ref[...]) + return [jax.freeze(y_ref)] + + def f(x): + y_ref = jax.new_ref(jnp.zeros_like(x)) + g_ = partial(g, y_ref) + return prim.bind( + lu.wrap_init(g_, debug_info=api_util.debug_info("f", g, (x,), {})), x + )[0] + out = f(4.) + np.testing.assert_array_equal(out, jnp.exp(4.)) class GeneralRefTest(jtu.JaxTestCase): @@ -1423,20 +1417,30 @@ def f(x_ref): x_ref[...] = x ref_addupdate(x_ref, (), x) return [x] - jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( + jaxpr, _, _ = pe.trace_to_jaxpr_dynamic( wrap_init(f, 1), [AbstractRef(core.AbstractToken())]) self.assertIs(type(jaxpr.outvars[0].aval), core.AbstractToken) - def test_ref_of_ref(self): - def f(x_ref_ref): - x_ref = x_ref_ref[...] - return [x_ref] - # Not sure why you'd ever want to do this, but it works! - jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( - wrap_init(f, 1), - [AbstractRef(AbstractRef(core.ShapedArray((), jnp.int32)))]) - self.assertIs(type(jaxpr.outvars[0].aval), AbstractRef) - self.assertIs(type(jaxpr.outvars[0].aval.inner_aval), core.ShapedArray) + def test_reshape(self): + def f(x_ref): + x_ref = x_ref.reshape(4, -1) + x_ref.reshape(-1)[...] = jnp.arange(36) + return [x_ref[...]] + jaxpr, _, _ = pe.trace_to_jaxpr_dynamic( + wrap_init(f, 1), [AbstractRef(core.ShapedArray((12, 3), jnp.int32))]) + self.assertEqual(jaxpr.outvars[0].aval.shape, (4, 9)) + + # NOTE(mattjj): disabled because it's extremely illegal + # def test_ref_of_ref(self): + # def f(x_ref_ref): + # x_ref = x_ref_ref[...] + # return [x_ref] + # # Not sure why you'd ever want to do this, but it works! + # jaxpr, _, _ = pe.trace_to_jaxpr_dynamic( + # wrap_init(f, 1), + # [AbstractRef(AbstractRef(core.ShapedArray((), jnp.int32)))]) + # self.assertIs(type(jaxpr.outvars[0].aval), AbstractRef) + # self.assertIs(type(jaxpr.outvars[0].aval.inner_aval), core.ShapedArray) class RunStateTest(jtu.JaxTestCase): @@ -1510,18 +1514,19 @@ def f(x): self.assertIsNotNone(jaxpr.jaxpr.debug_info) self.assertIsNotNone(jaxpr.jaxpr.debug_info.func_src_info) - def test_can_stage_run_state_leaked_tracer_error(self): - leaks = [] - def f(x): - def my_fun(x): - leaks.append(x) - return None - return run_state(my_fun)(x) - _ = jax.make_jaxpr(f)(2) + # NOTE(mattjj): disabled because the error message changed for the better + # def test_can_stage_run_state_leaked_tracer_error(self): + # leaks = [] + # def f(x): + # def my_fun(x): + # leaks.append(x) + # return None + # return run_state(my_fun)(x) + # _ = jax.make_jaxpr(f)(2) - with self.assertRaisesRegex(jax.errors.UnexpectedTracerError, - "The function being traced when the value leaked was .*my_fun"): - jax.jit(lambda _: leaks[0])(1) + # with self.assertRaisesRegex(jax.errors.UnexpectedTracerError, + # "The function being traced when the value leaked was .*my_fun"): + # jax.jit(lambda _: leaks[0])(1) def test_nested_run_state_captures_effects(self): def f(x): @@ -1565,282 +1570,155 @@ def f(refs): self.assertAllClose(x, np.sin(2.)) self.assertAllClose(x_t, 3 * np.cos(2.)) - def test_linearize_of_run_state(self): - @run_state - def f(refs): - x_ref, y_ref = refs - y_ref[...] = jnp.sin(x_ref[...]) - (x, y), f_lin = jax.linearize(f, (1., 0.)) - self.assertAllClose(x, 1.) - self.assertAllClose(y, np.sin(1.)) - x_t, y_t = f_lin((2., 1.)) - self.assertAllClose(x_t, 2.) - self.assertAllClose(y_t, 2. * np.cos(1.)) - - def test_grad_of_run_state(self): - @run_state - def f(refs): - x_ref, y_ref = refs - y_ref[...] = jnp.sin(x_ref[...]) - - def sin(x): - return f((x, 0.))[1] - - x_g = jax.grad(sin)(1.) - self.assertAllClose(x_g, np.cos(1.)) - - x_g2 = jax.grad(jax.grad(sin))(1.) - self.assertAllClose(x_g2, -np.sin(1.)) - - x_g3 = jax.grad(jax.grad(jax.grad(sin)))(1.) - self.assertAllClose(x_g3, -np.cos(1.)) - - def test_vjp_of_run_state(self): - @run_state - def f(refs): - x_ref, y_ref = refs - y_ref[...] = jnp.sin(x_ref[...]) - - (x, y), f_vjp = jax.vjp(f, (1., 0.)) - self.assertAllClose(x, 1.) - self.assertAllClose(y, np.sin(1.)) - ((x_ct, y_ct),) = f_vjp((0., 1.)) - self.assertAllClose(x_ct, np.cos(1.)) - self.assertAllClose(y_ct, 0.) - - def test_vjp_of_run_state_single(self): - @run_state - def f(x_ref): - x = x_ref[...] - def _body(ref): - ref[...] = jnp.sin(ref[...]) - x = run_state(_body)(x) - x_ref[...] = x - - y, f_lin = jax.linearize(f, 1.) - self.assertAllClose(y, np.sin(1.)) - y_t = f_lin(1.) - self.assertAllClose(y_t, np.cos(1.)) - - y, f_vjp = jax.vjp(f, 1.) - self.assertAllClose(y, np.sin(1.)) - x_ct, = f_vjp(1.) - self.assertAllClose(x_ct, np.cos(1.)) - - jtu.check_grads(f, (0.5,), order=3) - - -if CAN_USE_HYPOTHESIS: - - class FuncSpec(NamedTuple): - fun: Callable[..., Any] - name: str - min_rank: int = 0 - max_rank: int = 4 - min_dim: int = 0 - max_dim: int = 4 - - def call(self, *args): - return run_state(self.fun)(*args) - - def ref(self, *args): - return run_state_reference(self.fun)(*args) - - def sin_stateful(refs): - x_ref, y_ref = refs - y_ref[...] = jnp.sin(x_ref[...]) - - sin_spec = FuncSpec(sin_stateful, "sin") - - def cos_stateful(refs): +class FuncSpec(NamedTuple): + fun: Callable[..., Any] + name: str + min_rank: int = 0 + max_rank: int = 4 + min_dim: int = 0 + max_dim: int = 4 + + def call(self, *args): + return run_state(self.fun)(*args) + + def ref(self, *args): + return run_state_reference(self.fun)(*args) + +def sin_stateful(refs): + x_ref, y_ref = refs + y_ref[...] = jnp.sin(x_ref[...]) + +sin_spec = FuncSpec(sin_stateful, "sin") + +def cos_stateful(refs): + x_ref, y_ref = refs + y_ref[...] = jnp.cos(x_ref[...]) + +cos_spec = FuncSpec(cos_stateful, "cos") + +def mul2_stateful(refs): + x_ref, y_ref = refs + y_ref[...] = x_ref[...] + y_ref[...] = y_ref[...] + x_ref[...] + +mul2_spec = FuncSpec(mul2_stateful, "mul2") + +def mul2_stateful_with_constant(refs): + x_ref, y_ref = refs + y_ref[...] = (2. * np.ones(x_ref.shape, x_ref.dtype)) * x_ref[...] + +mul2_constant_spec = FuncSpec(mul2_stateful_with_constant, "mul2_c") + +def crazy_identity_stateful(refs): + x_ref, y_ref = refs + x = x_ref[...] + x_ref[...] = (x + x) / 2 + y_ref[...] = x_ref[...] + y = y_ref[...] + y_ref[...] = (y + y) / 2 + +crazy_identity_spec = FuncSpec(crazy_identity_stateful, "id") + +def func_spec(depth: int = 4): + raw_specs = hps.sampled_from([sin_spec, cos_spec, mul2_spec, + mul2_constant_spec, crazy_identity_spec]) + if depth > 0: + return hps.one_of([raw_specs, nest_spec(depth - 1), add_spec(depth - 1), + compose_spec(depth - 1)]) + return raw_specs + +@hps.composite +def compose_spec(draw, depth): + f1 = draw(func_spec(depth)) + f2 = draw(func_spec(depth)) + def wrapped_impl(*args): + f1.fun(*args) + f2.fun(*args) + return FuncSpec(wrapped_impl, + f"({f2.name} . {f1.name})", + min_rank=max(f1.min_rank, f2.min_rank), + max_rank=min(f1.max_rank, f2.max_rank), + min_dim=max(f1.min_dim, f2.min_dim), + max_dim=min(f1.max_dim, f2.max_dim)) + +@hps.composite +def nest_spec(draw, depth): + f = draw(func_spec(depth)) + def wrapped_impl(refs): x_ref, y_ref = refs - y_ref[...] = jnp.cos(x_ref[...]) - - cos_spec = FuncSpec(cos_stateful, "cos") - - def mul2_stateful(refs): + x, y = x_ref[...], y_ref[...] + x, y = run_state(f.fun)((x, y)) + x_ref[...], y_ref[...] = x, y + return FuncSpec(wrapped_impl, + f"nest({f.name})", + min_rank=f.min_rank, + max_rank=f.max_rank, + min_dim=f.min_dim, + max_dim=f.max_dim) + + +@hps.composite +def add_spec(draw, depth): + f1 = draw(func_spec(depth)) + f2 = draw(func_spec(depth)) + def wrapped_impl(refs): x_ref, y_ref = refs - y_ref[...] = x_ref[...] - y_ref[...] = y_ref[...] + x_ref[...] - - mul2_spec = FuncSpec(mul2_stateful, "mul2") + x, y = x_ref[...], y_ref[...] + x1, y1 = run_state(f1.fun)((x, y)) + x2, y2 = run_state(f2.fun)((x, y)) + x_ref[...], y_ref[...] = x1 + x2, y1 + y2 + return FuncSpec(wrapped_impl, + f"({f2.name} + {f1.name})", + min_rank=max(f1.min_rank, f2.min_rank), + max_rank=min(f1.max_rank, f2.max_rank), + min_dim=max(f1.min_dim, f2.min_dim), + max_dim=min(f1.max_dim, f2.max_dim)) + +@jtu.thread_unsafe_test_class(condition=not jtu.hypothesis_is_thread_safe()) +class RunStateHypothesisTest(jtu.JaxTestCase): + + @jax.legacy_prng_key('allow') + @hp.given(hps.data()) + @hp.settings(deadline=None, print_blob=True, + max_examples=jtu.NUM_GENERATED_CASES.value) + def test_jvp(self, data): + + spec = data.draw(func_spec()) + + def impl(x): + return spec.call((x, jnp.zeros_like(x)))[1] + + def ref(x): + return spec.ref((x, jnp.zeros_like(x)))[1] + + k1, k2 = random.split(random.PRNGKey(0)) + shape = data.draw(hnp.array_shapes(min_dims=spec.min_rank, + max_dims=spec.max_rank, min_side=spec.min_dim, + max_side=spec.max_dim)) + x = random.normal(k1, shape) + t = random.normal(k2, x.shape) + y, y_t = jax.jvp(impl, (x,), (t,)) + y_ref, y_ref_t = jax.jvp(ref, (x,), (t,)) + self.assertAllClose(y, y_ref) + self.assertAllClose(y_t, y_ref_t) + + +class PinnedBuffersTest(jtu.JaxTestCase): + + def test_pin_unpin_basic(self): + @jax.jit + def f(x): + return unpin(pin(x)) - def mul2_stateful_with_constant(refs): - x_ref, y_ref = refs - y_ref[...] = (2. * np.ones(x_ref.shape, x_ref.dtype)) * x_ref[...] + x = jnp.arange(3.) + txt = f.lower(x).as_text('hlo') + self.assertIn("Pin", txt) - mul2_constant_spec = FuncSpec(mul2_stateful_with_constant, "mul2_c") + if jtu.test_device_matches(['gpu', 'tpu']): + y = f(x) + self.assertAllClose(y, x) - def crazy_identity_stateful(refs): - x_ref, y_ref = refs - x = x_ref[...] - x_ref[...] = (x + x) / 2 - y_ref[...] = x_ref[...] - y = y_ref[...] - y_ref[...] = (y + y) / 2 - - crazy_identity_spec = FuncSpec(crazy_identity_stateful, "id") - - def func_spec(depth: int = 4): - raw_specs = hps.sampled_from([sin_spec, cos_spec, mul2_spec, - mul2_constant_spec, crazy_identity_spec]) - if depth > 0: - return hps.one_of([raw_specs, nest_spec(depth - 1), add_spec(depth - 1), - compose_spec(depth - 1)]) - return raw_specs - - @hps.composite - def compose_spec(draw, depth): - f1 = draw(func_spec(depth)) - f2 = draw(func_spec(depth)) - def wrapped_impl(*args): - f1.fun(*args) - f2.fun(*args) - return FuncSpec(wrapped_impl, - f"({f2.name} . {f1.name})", - min_rank=max(f1.min_rank, f2.min_rank), - max_rank=min(f1.max_rank, f2.max_rank), - min_dim=max(f1.min_dim, f2.min_dim), - max_dim=min(f1.max_dim, f2.max_dim)) - - @hps.composite - def nest_spec(draw, depth): - f = draw(func_spec(depth)) - def wrapped_impl(refs): - x_ref, y_ref = refs - x, y = x_ref[...], y_ref[...] - x, y = run_state(f.fun)((x, y)) - x_ref[...], y_ref[...] = x, y - return FuncSpec(wrapped_impl, - f"nest({f.name})", - min_rank=f.min_rank, - max_rank=f.max_rank, - min_dim=f.min_dim, - max_dim=f.max_dim) - - - @hps.composite - def add_spec(draw, depth): - f1 = draw(func_spec(depth)) - f2 = draw(func_spec(depth)) - def wrapped_impl(refs): - x_ref, y_ref = refs - x, y = x_ref[...], y_ref[...] - x1, y1 = run_state(f1.fun)((x, y)) - x2, y2 = run_state(f2.fun)((x, y)) - x_ref[...], y_ref[...] = x1 + x2, y1 + y2 - return FuncSpec(wrapped_impl, - f"({f2.name} + {f1.name})", - min_rank=max(f1.min_rank, f2.min_rank), - max_rank=min(f1.max_rank, f2.max_rank), - min_dim=max(f1.min_dim, f2.min_dim), - max_dim=min(f1.max_dim, f2.max_dim)) - - @jtu.thread_unsafe_test_class() # because of hypothesis - class RunStateHypothesisTest(jtu.JaxTestCase): - - @jax.legacy_prng_key('allow') - @hp.given(hps.data()) - @hp.settings(deadline=None, print_blob=True, - max_examples=jtu.NUM_GENERATED_CASES.value) - def test_jvp(self, data): - - spec = data.draw(func_spec()) - - def impl(x): - return spec.call((x, jnp.zeros_like(x)))[1] - - def ref(x): - return spec.ref((x, jnp.zeros_like(x)))[1] - - k1, k2 = random.split(random.PRNGKey(0)) - shape = data.draw(hnp.array_shapes(min_dims=spec.min_rank, - max_dims=spec.max_rank, min_side=spec.min_dim, - max_side=spec.max_dim)) - x = random.normal(k1, shape) - t = random.normal(k2, x.shape) - y, y_t = jax.jvp(impl, (x,), (t,)) - y_ref, y_ref_t = jax.jvp(ref, (x,), (t,)) - self.assertAllClose(y, y_ref) - self.assertAllClose(y_t, y_ref_t) - - @jax.legacy_prng_key('allow') - @hp.given(hps.data()) - @hp.settings(deadline=None, print_blob=True, - max_examples=jtu.NUM_GENERATED_CASES.value) - def test_linearize(self, data): - - spec = data.draw(func_spec()) - - def impl(x): - return spec.call((x, jnp.zeros_like(x)))[1] - - def ref(x): - return spec.ref((x, jnp.zeros_like(x)))[1] - - - k1, k2 = random.split(random.PRNGKey(0)) - shape = data.draw(hnp.array_shapes(min_dims=spec.min_rank, - max_dims=spec.max_rank, min_side=spec.min_dim, - max_side=spec.max_dim)) - x = random.normal(k1, shape) - y, impl_lin = jax.linearize(impl, x) - y_ref, ref_lin = jax.linearize(ref, x) - self.assertAllClose(y, y_ref, atol=1e-2, rtol=1e-2) - t = random.normal(k2, x.shape) - self.assertAllClose(impl_lin(t), ref_lin(t), atol=1e-2, rtol=1e-2) - - @jax.legacy_prng_key('allow') - @hp.given(hps.data()) - @hp.settings(deadline=None, print_blob=True, - max_examples=jtu.NUM_GENERATED_CASES.value) - def test_vjp(self, data): - - spec = data.draw(func_spec()) - - def impl(x): - return spec.call((x, jnp.zeros_like(x)))[1] - - def ref(x): - return spec.ref((x, jnp.zeros_like(x)))[1] - - - key, k1, k2 = random.split(random.PRNGKey(0), 3) - shape = data.draw(hnp.array_shapes(min_dims=spec.min_rank, - max_dims=spec.max_rank, min_side=spec.min_dim, - max_side=spec.max_dim)) - x = random.normal(k1, shape) - - # First order - y, impl_lin = jax.linearize(impl, x) - y_ref, ref_lin = jax.linearize(ref, x) - self.assertAllClose(y, y_ref) - t = random.normal(k2, x.shape) - self.assertAllClose(impl_lin(t), ref_lin(t)) - - y, impl_vjp = jax.vjp(impl, x) - y_ref, ref_vjp = jax.vjp(ref, x) - self.assertAllClose(y, y_ref) - t = random.normal(jax.random.clone(k2), x.shape) - y2 = random.normal(jax.random.clone(k1), y.shape) - self.assertAllClose(impl_vjp(t), ref_vjp(t)) - - # Second order - key, k1, k2 = random.split(key, 3) - t2 = random.normal(k2, t.shape) - - (x,), impl_lin2 = jax.linearize(impl_vjp, t2) - (x_ref,), ref_lin2 = jax.linearize(ref_vjp, t2) - self.assertAllClose(x, x_ref) - y2 = random.normal(k1, y.shape) - self.assertAllClose(impl_lin2(y2), ref_lin2(y2)) - - (x,), impl_vjp2 = jax.vjp(impl_vjp, t2) - (x_ref,), ref_vjp2 = jax.vjp(ref_vjp, t2) - self.assertAllClose(x, x_ref) - y2 = random.normal(jax.random.clone(k1), y.shape) - self.assertAllClose(impl_vjp2((y2,)), ref_vjp2((y2,))) if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/stateful_rng_test.py b/tests/stateful_rng_test.py new file mode 100644 index 000000000000..0d9cb30d8636 --- /dev/null +++ b/tests/stateful_rng_test.py @@ -0,0 +1,240 @@ +# Copyright 2026 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest + +import numpy as np + +import jax +import jax.numpy as jnp +from jax.experimental import random as exp_random +from jax._src import config +from jax._src import test_util as jtu + + +config.parse_flags_with_absl() + +class StatefulRNGTest(jtu.JaxTestCase): + + def test_stateful_rng_instantiation(self, seed=547389): + rng = exp_random.stateful_rng(seed) + key = jax.random.key(seed) + + self.assertEqual(key, rng._base_key) + self.assertEqual(rng._counter.shape, ()) + self.assertEqual(0, rng._counter[...]) + + def test_stateful_rng_counter_increment(self, seed=7865943): + rng = exp_random.stateful_rng(seed) + original_key = rng._base_key + self.assertEqual(0, rng._counter[...]) + + _ = jax.jit(rng.key)() # implicit update + + self.assertEqual(original_key, rng._base_key) # base key does not change + self.assertEqual(1, rng._counter[...]) # counter is incremented + + def test_stateful_rng_invalid_instantiation(self): + valid_key = jax.random.key(0) + valid_counter = jax.new_ref(0) + invalid_key = jax.numpy.array([0, 1], dtype='uint32') + invalid_counter = 0 + with self.assertRaisesRegex(ValueError, "Expected base_key to be a typed PRNG key"): + exp_random.StatefulPRNG(invalid_key, valid_counter) + with self.assertRaisesRegex(ValueError, "Expected counter to be a scalar integer ref"): + exp_random.StatefulPRNG(valid_key, invalid_counter) + + def testRepeatedKeys(self, seed=578543): + prng = exp_random.stateful_rng(seed) + self.assertNotEqual(prng.key(), prng.key()) + + def testShapedKeys(self, seed=7589432): + prng = exp_random.stateful_rng(seed) + + keys1 = prng.key(10) + self.assertEqual(keys1.shape, (10,)) + self.assertTrue(jax.dtypes.issubdtype(keys1.dtype, jax.dtypes.prng_key)) + + keys2 = prng.key(10) + self.assertEqual(keys1.shape, (10,)) + self.assertTrue(jax.dtypes.issubdtype(keys2.dtype, jax.dtypes.prng_key)) + + self.assertFalse((keys1 == keys2).any()) + + def testRepeatedDraws(self, seed=328090): + prng = exp_random.stateful_rng(seed) + vals1 = prng.uniform(size=10) + vals2 = prng.uniform(size=10) + self.assertTrue((vals1 != vals2).all()) + + def testRepeatedDrawsJIT(self, seed=328090): + prng = exp_random.stateful_rng(seed) + @jax.jit + def get_values(prng): + return prng.uniform(size=10) + vals1 = get_values(prng) + vals2 = get_values(prng) + self.assertTrue((vals1 != vals2).all()) + + @jtu.sample_product( + size=[None, 2, (5, 2)], + dtype=jtu.dtypes.floating, + ) + def testRandom(self, size, dtype): + rng = exp_random.stateful_rng(578943) + vals = rng.random(size, dtype) + shape = np.broadcast_shapes(size or ()) + + self.assertEqual(vals.shape, shape) + self.assertEqual(vals.dtype, dtype) + self.assertTrue((vals < 1).all()) + self.assertTrue((vals >= 0).all()) + + @jtu.sample_product( + low=[0, 1, np.array([0, 1])], + high=[2, 3, np.array([2, 3])], + size=[None, 2, (5, 2)], + dtype=jtu.dtypes.floating, + ) + @jax.numpy_dtype_promotion('standard') + @jax.numpy_rank_promotion('allow') + def testUniform(self, low, high, size, dtype): + rng = exp_random.stateful_rng(473289) + vals = rng.uniform(low, high, size, dtype=dtype) + shape = np.broadcast_shapes(np.shape(low), np.shape(high), size or ()) + + self.assertEqual(vals.shape, shape) + self.assertEqual(vals.dtype, dtype) + self.assertTrue((vals < high).all()) + self.assertTrue((vals >= low).all()) + + @jtu.sample_product( + loc=[0, 1, np.array([0, 1])], + scale=[2, 3, np.array([2, 3])], + size=[None, 2, (5, 2)], + dtype=jtu.dtypes.floating, + ) + @jax.numpy_dtype_promotion('standard') + @jax.numpy_rank_promotion('allow') + def testNormal(self, loc, scale, size, dtype): + rng = exp_random.stateful_rng(473289) + vals = rng.normal(loc, scale, size, dtype=dtype) + shape = np.broadcast_shapes(np.shape(loc), np.shape(scale), size or ()) + + self.assertEqual(vals.shape, shape) + self.assertEqual(vals.dtype, dtype) + + @jtu.sample_product( + low=[0, 1, np.array([0, 1])], + high=[10, 15, np.array([10, 15])], + size=[None, 2, (5, 2)], + dtype=jtu.dtypes.integer, + ) + @jax.numpy_dtype_promotion('standard') + @jax.numpy_rank_promotion('allow') + def testIntegers(self, low, high, size, dtype): + rng = exp_random.stateful_rng(473289) + vals = rng.integers(low, high, size, dtype=dtype) + shape = np.broadcast_shapes(np.shape(low), np.shape(high), size or ()) + + self.assertEqual(vals.shape, shape) + self.assertEqual(vals.dtype, dtype) + self.assertTrue((vals < high).all()) + self.assertTrue((vals >= low).all()) + + def testSpawn(self): + rng = exp_random.stateful_rng(758943) + rngs = rng.spawn(4) + + for child_rng in rngs: + self.assertNotEqual(rng._base_key, child_rng._base_key) + self.assertEqual(0, child_rng._counter[...]) + + @jtu.sample_product(shape=[4, (5,), (2, 3)]) + def testSplit(self, shape): + rng = exp_random.stateful_rng(758943) + rng_split = rng.split(shape) + + expected_shape = (shape,) if isinstance(shape, int) else shape + + self.assertEqual(rng_split._base_key.dtype, rng._base_key.dtype) + self.assertEqual(rng_split._base_key.shape, expected_shape) + + self.assertIsInstance(rng_split._counter, jax.Ref) + self.assertEqual(rng_split._counter.shape, expected_shape) + + def testVmapMapped(self): + seed = 758943 + N = 4 + x = np.arange(N, dtype=float) + def f(rng, x): + return x + rng.uniform() + + rng = exp_random.stateful_rng(seed) + expected = x + jnp.array([rng.uniform() for rng in rng.spawn(N)]) + + rng = exp_random.stateful_rng(seed) + actual = jax.vmap(f)(rng.split(N), x) + + self.assertArraysEqual(actual, expected) + + def testVmapUnmapped(self): + seed = 758943 + x = np.arange(4, dtype=float) + rng = exp_random.stateful_rng(seed) + def f(rng, x): + return x + rng.uniform() + with self.assertRaisesRegex(Exception, "performing an addupdate operation with vmapped value"): + jax.vmap(f, in_axes=(None, 0))(rng, x) + + def testScanClosure(self): + seed = 432932 + def f1(seed): + rng = exp_random.stateful_rng(seed) + def scan_f(_, __): + return None, rng.uniform() + return jax.lax.scan(scan_f, None, length=10)[1] + + def f2(seed): + rng = exp_random.stateful_rng(seed) + return jax.numpy.array([rng.uniform() for i in range(10)]) + + self.assertArraysAllClose(f1(seed), f2(seed)) + + def testDefaultSeed(self): + rng = exp_random.stateful_rng() + x = rng.uniform(size=10) + self.assertEqual(x.shape, (10,)) + + def testDefaultSeedErrorUnderJIT(self): + def f(): + return exp_random.stateful_rng().uniform(size=10) + with self.assertRaisesRegex(TypeError, "When used within transformed code"): + jax.jit(f)() + + def testDefaultSeedErrorUnderGrad(self): + def f(x): + return x + exp_random.stateful_rng().uniform() + with self.assertRaisesRegex(TypeError, "When used within transformed code"): + jax.grad(f)(1.0) + + def testDefaultSeedErrorUnderVmap(self): + def f(x): + return x + exp_random.stateful_rng().uniform() + with self.assertRaisesRegex(TypeError, "When used within transformed code"): + jax.vmap(f)(jnp.arange(5.0)) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/stax_test.py b/tests/stax_test.py index 8c38820d27a5..0e86c443c457 100644 --- a/tests/stax_test.py +++ b/tests/stax_test.py @@ -17,17 +17,17 @@ import numpy as np import jax +from jax._src import dtypes from jax._src import test_util as jtu from jax import random from jax.example_libraries import stax -from jax import dtypes jax.config.parse_flags_with_absl() def random_inputs(rng, input_shape): if type(input_shape) is tuple: - return rng.randn(*input_shape).astype(dtypes.canonicalize_dtype(float)) + return rng.randn(*input_shape).astype(dtypes.default_float_dtype()) elif type(input_shape) is list: return [random_inputs(rng, shape) for shape in input_shape] else: @@ -40,7 +40,7 @@ def _CheckShapeAgreement(test_case, init_fun, apply_fun, input_shape): result_shape, params = init_fun(init_key, input_shape) inputs = random_inputs(test_case.rng(), input_shape) if params: - inputs = inputs.astype(np.dtype(params[0])) + inputs = inputs.astype(dtypes.dtype(jax.tree.leaves(params)[0])) result = apply_fun(params, inputs, rng=rng_key) test_case.assertEqual(result.shape, result_shape) diff --git a/tests/string_array_test.py b/tests/string_array_test.py index 364c71759023..797f85c63c7b 100644 --- a/tests/string_array_test.py +++ b/tests/string_array_test.py @@ -26,14 +26,6 @@ class StringArrayTest(jtu.JaxTestCase): - def setUp(self): - super().setUp() - if not hasattr(np.dtypes, "StringDType"): - self.skipTest( - "Skipping this test because the numpy.dtype.StringDType is not" - " available." - ) - def make_test_string_array(self, device=None): """Makes and returns a simple 2x1 string array on the first CPU device.""" if device is None: diff --git a/tests/svd_test.py b/tests/svd_test.py index 97f8176f8f94..a04b25046370 100644 --- a/tests/svd_test.py +++ b/tests/svd_test.py @@ -20,7 +20,7 @@ import scipy.linalg as osp_linalg from jax._src import config from jax._src import test_util as jtu -from jax._src.lax import svd +from jax._src.tpu.linalg import svd from absl.testing import absltest @@ -52,10 +52,7 @@ def testSvdvals(self, shape, dtype): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] jnp_fun = jax.numpy.linalg.svdvals - if jtu.numpy_version() < (2, 0, 0): - np_fun = lambda x: np.linalg.svd(x, compute_uv=False) - else: - np_fun = np.linalg.svdvals + np_fun = np.linalg.svdvals self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=_SVD_RTOL, atol=1E-5) self._CompileAndCheck(jnp_fun, args_maker, rtol=_SVD_RTOL) @@ -166,7 +163,7 @@ def testSvdWithOnRankDeficientInputZeroColumns(self, m, r): np.testing.assert_almost_equal(diff, 1e-4, decimal=2) # Check that u and v are orthogonal. self.assertAllClose(u.T.conj() @ u, np.eye(m), atol=10 * _SVD_TEST_EPS) - self.assertAllClose(v.T.conj() @ v, np.eye(m), atol=11 * _SVD_TEST_EPS) + self.assertAllClose(v.T.conj() @ v, np.eye(m), atol=30 * _SVD_TEST_EPS) @jtu.sample_product( [dict(m=m, n=n) for m, n in zip([2, 8, 10, 20], [4, 6, 10, 18])], @@ -189,7 +186,9 @@ def testSingularValues(self, m, n, log_cond, full_matrices): osp_linalg_fn = functools.partial( osp_linalg.svd, full_matrices=full_matrices, compute_uv=compute_uv) - actual_s = svd.svd(a, full_matrices=full_matrices, compute_uv=compute_uv) + actual_s = svd.svd( + a, full_matrices=full_matrices, compute_uv=compute_uv + ).block_until_ready() expected_s = osp_linalg_fn(a) @@ -275,13 +274,16 @@ def lax_fun(a): start=[0, 1, 64, 126, 127], end=[1, 2, 65, 127, 128], ) - @jtu.run_on_devices('tpu') # TODO(rmlarsen: enable on other devices) + @jtu.run_on_devices('tpu', 'rocm') def testSvdSubsetByIndex(self, start, end): if start >= end: return dtype = np.float32 m = 256 n = 128 + # subset_by_index is only implemented for TPU; on ROCm only full range works + if jtu.is_device_rocm() and not (start == 0 and end == min(m, n)): + self.skipTest("subset_by_index not implemented for ROCm") rng = jtu.rand_default(self.rng()) tol = np.maximum(n, 80) * np.finfo(dtype).eps args_maker = lambda: [rng((m, n), dtype)] diff --git a/tests/traceback_test.py b/tests/traceback_test.py new file mode 100644 index 000000000000..499284667dc6 --- /dev/null +++ b/tests/traceback_test.py @@ -0,0 +1,160 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import traceback + +from absl.testing import absltest +import jax +from jax._src import test_util as jtu +from jax._src.lib import _jax +import jax.numpy as jnp +import numpy as np + +Traceback = _jax.Traceback +Frame = _jax.Frame + + +@contextlib.contextmanager +def tracebacks(enabled=True): + """Context manager that enables or disables traceback collection.""" + saved = _jax.tracebacks_enabled() + _jax.set_tracebacks_enabled(enabled) + try: + yield + finally: + _jax.set_tracebacks_enabled(saved) + + +class TracebackTest(absltest.TestCase): + + def testNoTracebacksIfDisabled(self): + with tracebacks(enabled=False): + self.assertEqual(None, Traceback.get_traceback()) + buffer = jnp.array(7, np.int32) + self.assertEqual(None, buffer.traceback) + + e = jax.jit(lambda x: x + 1).lower(1).compile().runtime_executable() + self.assertEqual(None, e.traceback) + + def assertIsTracebackContaining(self, tb, function): + self.assertIsInstance(tb, Traceback) + self.assertIn(function, str(tb)) + self.assertTrue(any(f.function_name == function for f in tb.frames)) + + def testTracebacks(self): + with tracebacks(enabled=True): + fn = "TracebackTest.testTracebacks" + + tb = Traceback.get_traceback() + self.assertIsTracebackContaining(tb, fn) + + buffer = jnp.array(7, np.int32) + self.assertIsTracebackContaining(buffer.traceback, fn) + + e = jax.jit(lambda x: x + 1).lower(1).compile().runtime_executable() + self.assertIsTracebackContaining(e.traceback, fn) + + def testNestedFunction(self): + + def AFunction(): + + def AnotherFunction(): + return Traceback.get_traceback() + + return AnotherFunction() + + with tracebacks(enabled=True): + tb = AFunction() + self.assertIsInstance(tb, Traceback) + frames = tb.frames + fn = "TracebackTest.testNestedFunction..AFunction" + i = next(i for (i, f) in enumerate(frames) if f.function_name == fn) + self.assertEqual( + frames[i - 1].function_name, + "TracebackTest.testNestedFunction..AFunction..AnotherFunction", + ) + self.assertEqual( + frames[i + 1].function_name, "TracebackTest.testNestedFunction" + ) + + def testPythonTracebackHasCorrectLineNumbers(self): + def B(): + return Traceback.get_traceback() + + def A(): + return B() + + tb = A().as_python_traceback() + for frame, lineno in traceback.walk_tb(tb): + if frame.f_code.co_name == "A": + line = A.__code__.co_firstlineno + self.assertBetween(lineno, line, line + 2) + elif frame.f_code.co_name == "B": + line = B.__code__.co_firstlineno + self.assertBetween(lineno, line, line + 2) + + def testAccessingLocalsDoesNotCrash(self): + # https://github.com/google/jax/issues/16027 + tb = Traceback.get_traceback() + python_tb = tb.as_python_traceback() + for frame, _ in traceback.walk_tb(python_tb): + _ = frame.f_locals # should not crash + + def testTracebackFromFrames(self): + def FooFn(x): + return x + 1 + + def BarFn(y): + y = y + 1 + y = y + 2 + return y * 2 + + frame_foo = Frame( + __file__, + FooFn.__code__.co_name, + FooFn.__code__.co_firstlineno, + FooFn.__code__.co_firstlineno + 1, + ) + frame_bar = Frame( + __file__, + BarFn.__code__.co_name, + BarFn.__code__.co_firstlineno, + BarFn.__code__.co_firstlineno + 2, + ) + frames = [frame_foo, frame_bar] + tb = Traceback.traceback_from_frames(frames) + + with self.subTest("WalkDoesNotError"): + for frame, _ in traceback.walk_tb(tb): + _ = frame.f_locals # should not crash + + with self.subTest("TracebackCorrectness"): + tb_string = traceback.format_tb(tb) + # The traceback should have the format: + # File , line N in BarFn + # y = y + 2 + # File , line N in FooFn + # return x + 1 + self.assertLen(tb_string, len(frames)) + bar_frame = tb_string[0].split("\n") + self.assertEndsWith(bar_frame[0], "BarFn") + self.assertEqual(bar_frame[1].strip(), "y = y + 2") + foo_frame = tb_string[1].split("\n") + self.assertEndsWith(foo_frame[0], "FooFn") + self.assertEqual(foo_frame[1].strip(), "return x + 1") + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/tree_util_test.py b/tests/tree_util_test.py index e5e649d43d8a..b9edfe245ed2 100644 --- a/tests/tree_util_test.py +++ b/tests/tree_util_test.py @@ -25,7 +25,8 @@ from jax import flatten_util from jax import tree_util from jax._src import test_util as jtu -from jax._src.tree_util import flatten_one_level, prefix_errors +from jax._src.tree_util import ( + flatten_one_level, prefix_errors, broadcast_flattened_prefix_with_treedef) import jax.numpy as jnp # Easier to read. @@ -394,14 +395,12 @@ def testFlattenUpTo(self, tree, xs, expected): ( {"a": 1}, {"a": 7, "b": 8}, - re.escape( - "Dict key mismatch; expected keys: ['a']; dict: {'a': 7, 'b': 8}." - ), + re.escape("Dict key mismatch; expected keys: ['a'];"), ), ( {"a": 1}, {"b": 7}, - re.escape("Dict key mismatch; expected keys: ['a']; dict: {'b': 7}."), + re.escape("Dict key mismatch; expected keys: ['a'];"), ), ([1], {"a": 7}, re.escape("Expected list, got {'a': 7}.")), ([1], (7,), re.escape("Expected list, got (7,).")), @@ -421,7 +420,7 @@ def testFlattenUpTo(self, tree, xs, expected): ( [{"a": 1}], [{"b": 7}], - re.escape("Dict key mismatch; expected keys: ['a']; dict: {'b': 7}."), + re.escape("Dict key mismatch; expected keys: ['a'];"), ), (([1],), (7,), re.escape("Expected list, got 7.")), (([1],), ((7,),), re.escape("Expected list, got (7,).")), @@ -435,7 +434,7 @@ def testFlattenUpTo(self, tree, xs, expected): ( ({"a": 1},), ({"b": 7},), - re.escape("Dict key mismatch; expected keys: ['a']; dict: {'b': 7}."), + re.escape("Dict key mismatch; expected keys: ['a'];"), ), ({"a": [1]}, {"a": 7}, re.escape("Expected list, got 7.")), ({"a": [1]}, {"a": (7,)}, re.escape("Expected list, got (7,).")), @@ -453,7 +452,7 @@ def testFlattenUpTo(self, tree, xs, expected): ( {"a": {"a": 1}}, {"a": {"b": 7}}, - re.escape("Dict key mismatch; expected keys: ['a']; dict: {'b': 7}."), + re.escape("Dict key mismatch; expected keys: ['a'];"), ), ( [ATuple(foo=1, bar=2)], @@ -470,9 +469,7 @@ def testFlattenUpTo(self, tree, xs, expected): [([1], (2,), {"a": [1]})], re.escape("Custom node type mismatch"), ), - ( - (None, [2], re.escape("Expected None, got [2].")) - ), + ((None, [2], re.escape("Expected None, got [2]."))), ) def testFlattenUpToErrors(self, tree, xs, error): _, tree_def = tree_util.tree_flatten(tree) @@ -499,6 +496,13 @@ def testTreeReduceWithIsLeafArgument(self): is_leaf=lambda l: isinstance(l, tuple)) self.assertEqual(out, (1, 2, 3, 4, 5, 6)) + def testTreeReduceAssociativeWithIsLeafArgument(self): + out = tree_util.tree_reduce_associative( + lambda x, y: x + y, [(1, 2), [(3, 4), (5, 6)]], + is_leaf=lambda l: isinstance(l, tuple), + ) + self.assertEqual(out, (1, 2, 3, 4, 5, 6)) + @parameterized.parameters( tree_util.tree_leaves, lambda tree, is_leaf: tree_util.tree_flatten(tree, is_leaf)[0]) @@ -552,6 +556,18 @@ def testAllLeavesWithTrees(self, tree): def testAllLeavesWithLeaves(self, leaf): self.assertTrue(tree_util.all_leaves([leaf])) + @parameterized.parameters(*TREES) + def testAllLeavesWithTreesAndCustomIsLeaf(self, tree): + def is_leaf(t): + return tree_util.all_leaves([t]) + self.assertFalse(tree_util.all_leaves([tree], is_leaf=is_leaf)) + + @parameterized.parameters(*LEAVES) + def testAllLeavesWithLeavesAndCustomIsLeaf(self, leaf): + def is_leaf(t): + return tree_util.all_leaves([t]) + self.assertTrue(tree_util.all_leaves([leaf], is_leaf=is_leaf)) + @parameterized.parameters(*TREES) def testCompose(self, tree): treedef = tree_util.tree_structure(tree) @@ -615,11 +631,57 @@ def testTransposeWithCustomObject(self): FlatCache({"a": [3, 4], "b": [5, 6]})) self.assertEqual(expected, actual) + @parameterized.parameters(*TREES) + def testBroadcast(self, tree): + if isinstance(tree, FlatCache): + # The tree_map construction below fails for FlatCache, because + # the cached metadata becomes out of sync. + self.skipTest("Test does not work properly for FlatCache.") + def make_inner(x): + return [x, x, x] + nested = tree_util.tree_map(make_inner, tree) + actual = tree_util.tree_broadcast(tree, nested) + self.assertEqual(actual, nested) + + actual_flat = broadcast_flattened_prefix_with_treedef( + *tree_util.tree_flatten(tree), tree_util.tree_structure(nested)) + actual = tree_util.tree_structure(nested).unflatten(actual_flat) + self.assertEqual(actual, nested) + + def testBroadcastSimple(self): + prefix = (1, 2, 3) + full = (0, {'a': 0, 'b': 0}, (0, 0)) + actual = tree_util.tree_broadcast(prefix, full) + expected = (1, {'a': 2, 'b': 2}, (3, 3)) + self.assertEqual(actual, expected) + + def testBroadcastError(self): + prefix = (1, 2, 3) + full = (0, {'a': 0, 'b': 0}) + with self.assertRaisesRegex(ValueError, "pytree structure error"): + tree_util.tree_broadcast(prefix, full) + with self.assertRaises(Exception): + broadcast_flattened_prefix_with_treedef( + *tree_util.tree_flatten(prefix), tree_util.tree_structure(full)) + prefix = (1, 2) + full = (0, {'a': 0, 'b': 0}, (0, 0)) + with self.assertRaisesRegex(ValueError, "pytree structure error"): + tree_util.tree_broadcast(prefix, full) + with self.assertRaises(Exception): + broadcast_flattened_prefix_with_treedef( + *tree_util.tree_flatten(prefix), tree_util.tree_structure(full)) + prefix = (1, {'a': 0}) + full = (0, {'a': 0, 'b': 0}) + with self.assertRaisesRegex(ValueError, "pytree structure error"): + tree_util.tree_broadcast(prefix, full) + with self.assertRaises(Exception): + broadcast_flattened_prefix_with_treedef( + *tree_util.tree_flatten(prefix), tree_util.tree_structure(full)) + @parameterized.parameters([(*t, s) for t, s in zip(TREES, TREE_STRINGS)]) def testStringRepresentation(self, tree, correct_string): """Checks that the string representation of a tree works.""" treedef = tree_util.tree_structure(tree) - print(TREES) self.assertRegex(str(treedef), correct_string) def testTreeDefWithEmptyDictStringRepresentation(self): @@ -746,7 +808,7 @@ def testTreeMapWithPathWithIsLeafArgument(self): y = (([3], jnp.array(0)), ([0], 7, [5, 6])) out = tree_util.tree_map_with_path( lambda kp, *xs: (kp[0].idx, *xs), x, y, - is_leaf=lambda n: isinstance(n, list)) + is_leaf=lambda _, n: isinstance(n, list), is_leaf_takes_path=True) self.assertEqual(out, (((0, 1, [3]), (0, 2, jnp.array(0))), (1, [3, 4, 5], ([0], 7, [5, 6])))) @@ -763,7 +825,11 @@ def is_empty(x): tree1 = {'a': 1, 'sub': [jnp.array((1, 2)), ATuple(foo=(), bar=[None])], 'obj': AnObject2(x=EmptyTuple(), y=0, z='constantdef')} - flattened, _ = tree_util.tree_flatten_with_path(tree1, is_empty) + + is_empty_new = lambda kp, x: is_empty(x) + flattened, _ = tree_util.tree_flatten_with_path( + tree1, is_empty_new, is_leaf_takes_path=True + ) strs = [f"{tree_util.keystr(kp)}: {x}" for kp, x in flattened] self.assertEqual( strs, @@ -777,6 +843,32 @@ def is_empty(x): ], ) + def testTreeFlattenWithPathWithIsLeafWithPathArgument(self): + x = ((1, 2), [3, {4: 4, 5: 5}]) + check_max_depth = lambda kp, _: len(kp) >= 2 + flattened, _ = tree_util.tree_flatten_with_path( + x, is_leaf=check_max_depth, is_leaf_takes_path=True + ) + self.assertEqual( + flattened, + [ + ((SequenceKey(0), SequenceKey(0),), 1), + ((SequenceKey(0), SequenceKey(1),), 2), + ((SequenceKey(1), SequenceKey(0),), 3), + ((SequenceKey(1), SequenceKey(1)), {4: 4, 5: 5}), + ], + ) + + def testTreeMapWithPathWithIsLeafWithPathArgument(self): + x = ((1, 2), [3, 4, 5]) + y = (([3], jnp.array(0)), ([0], 7, [5, 6])) + out = tree_util.tree_map_with_path( + lambda kp, *xs: (kp[0].idx, *xs), x, y, + is_leaf=lambda kp, n: isinstance(n, list), is_leaf_takes_path=True) + self.assertEqual(out, (((0, 1, [3]), + (0, 2, jnp.array(0))), + (1, [3, 4, 5], ([0], 7, [5, 6])))) + def testTreeFlattenWithPathBuiltin(self): x = (1, {"a": 2, "b": 3}) flattened = tree_util.tree_flatten_with_path(x) @@ -1005,6 +1097,24 @@ def testPickle(self): unpickled = pickle.loads(pickle.dumps(key)) self.assertEqual(key, unpickled) + def testEqualityErrorWithArrayAsStaticArg(self): + # Regression test for https://github.com/jax-ml/jax/issues/28659 + @tree_util.register_dataclass + @dataclasses.dataclass + class Tree: + x : jnp.ndarray = dataclasses.field(metadata={'static': True}) + + f = jax.jit(lambda x: x) + + msg = "Exception raised while checking equality of metadata fields of pytree." + + # First call succeeds, because there is no equality check. + f(Tree(jnp.arange(4))) + + # Second fall fails, because arrays are marked static and compared for equality. + with self.assertRaisesRegex(ValueError, msg): + f(Tree(jnp.arange(4))) + class StaticTest(parameterized.TestCase): @@ -1408,6 +1518,23 @@ def test_tree_reduce_is_leaf(self): tree_util.tree_reduce(func, obj, is_leaf=is_leaf), ) + def test_tree_reduce_associative(self): + func = lambda a, b: a + b + obj = [1, 2, (3, 4)] + self.assertEqual( + jax.tree.reduce_associative(func, obj), + tree_util.tree_reduce_associative(func, obj), + ) + + def test_tree_reduce_associative_is_leaf(self): + func = lambda a, b: a + b + obj = [(1, 2), (3, 4)] + is_leaf = lambda x: isinstance(x, tuple) + self.assertEqual( + jax.tree.reduce_associative(func, obj, is_leaf=is_leaf), + tree_util.tree_reduce_associative(func, obj, is_leaf=is_leaf), + ) + def test_tree_structure(self): obj = [1, 2, (3, 4)] self.assertEqual( @@ -1432,6 +1559,13 @@ def test_tree_transpose(self): tree_util.tree_transpose(outer_treedef, inner_treedef, obj) ) + def test_tree_broadcast(self): + prefix = (1, 2, 3) + full = (0, {'a': 0, 'b': 0}, (0, 0)) + actual = jax.tree.broadcast(prefix, full) + expected = (1, {'a': 2, 'b': 2}, (3, 3)) + self.assertEqual(actual, expected) + def test_tree_unflatten(self): leaves, treedef = jax.tree.flatten([1, 2, (3, 4)]) self.assertEqual( @@ -1449,9 +1583,10 @@ def test_tree_flatten_with_path(self): def test_tree_flatten_with_path_is_leaf(self): obj = [1, 2, (3, 4)] is_leaf = lambda x: isinstance(x, tuple) + is_leaf = lambda kp, x: isinstance(x, tuple) self.assertEqual( - jax.tree.flatten_with_path(obj, is_leaf=is_leaf), - tree_util.tree_flatten_with_path(obj, is_leaf=is_leaf), + jax.tree.flatten_with_path(obj, is_leaf, is_leaf_takes_path=True), + tree_util.tree_flatten_with_path(obj, is_leaf, is_leaf_takes_path=True), ) def test_tree_leaves_with_path(self): @@ -1464,9 +1599,14 @@ def test_tree_leaves_with_path(self): def test_tree_leaves_with_path_is_leaf(self): obj = [1, 2, (3, 4)] is_leaf = lambda x: isinstance(x, tuple) + is_leaf = lambda kp, x: isinstance(x, tuple) self.assertEqual( - jax.tree.leaves_with_path(obj, is_leaf=is_leaf), - tree_util.tree_leaves_with_path(obj, is_leaf=is_leaf), + jax.tree.leaves_with_path( + obj, is_leaf=is_leaf, is_leaf_takes_path=True + ), + tree_util.tree_leaves_with_path( + obj, is_leaf=is_leaf, is_leaf_takes_path=True + ), ) def test_tree_map_with_path(self): @@ -1483,9 +1623,14 @@ def test_tree_map_with_path_is_leaf(self): obj = [1, 2, (3, 4)] obj2 = [5, 6, (7, 8)] is_leaf = lambda x: isinstance(x, tuple) + is_leaf = lambda kp, x: isinstance(x, tuple) self.assertEqual( - jax.tree.map_with_path(func, obj, obj2, is_leaf=is_leaf), - tree_util.tree_map_with_path(func, obj, obj2, is_leaf=is_leaf), + jax.tree.map_with_path( + func, obj, obj2, is_leaf=is_leaf, is_leaf_takes_path=True + ), + tree_util.tree_map_with_path( + func, obj, obj2, is_leaf=is_leaf, is_leaf_takes_path=True + ), ) @@ -1553,6 +1698,21 @@ class Foo: Foo, data_fields=["x"], meta_fields=["y", "z"] ) + def test_register_dataclass_overlapping_fields(self): + @dataclasses.dataclass + class Foo: + x: int + y: float + + with self.assertRaisesRegex( + ValueError, + "data_fields and meta_fields must not overlap.*" + "Overlapping fields: {'x'}", + ): + tree_util.register_dataclass( + Foo, data_fields=["x", "y"], meta_fields=["x"] + ) + def test_register_dataclass_drop_fields(self): @dataclasses.dataclass class Foo: diff --git a/tests/typing_test.py b/tests/typing_test.py index 562c6c56d2d9..9368d11c8af8 100644 --- a/tests/typing_test.py +++ b/tests/typing_test.py @@ -24,6 +24,7 @@ import jax from jax._src import core +from jax._src import dtypes from jax._src import test_util as jtu from jax._src import typing from jax import lax @@ -45,7 +46,7 @@ def dtypelike_to_dtype(x: typing.DTypeLike) -> typing.DType: # inputs to jax primitive functions; use convert_element_type here # for simplicity. def arraylike_to_array(x: typing.ArrayLike) -> typing.Array: - return lax.convert_element_type(x, np.result_type(x)) + return lax.convert_element_type(x, dtypes.dtype(x)) class HasDType: @@ -74,6 +75,8 @@ def testDTypeLike(self) -> None: out5: typing.DType = dtypelike_to_dtype(HasDType("float32")) self.assertEqual(out5, float32_dtype) + @jtu.ignore_warning(category=UserWarning, + message="Explicitly requested dtype.*") def testArrayLike(self) -> None: out1: typing.Array = arraylike_to_array(jnp.arange(4)) self.assertArraysEqual(out1, jnp.arange(4)) @@ -81,8 +84,8 @@ def testArrayLike(self) -> None: out2: typing.Array = jax.jit(arraylike_to_array)(jnp.arange(4)) self.assertArraysEqual(out2, jnp.arange(4)) - out3: typing.Array = arraylike_to_array(np.arange(4)) - self.assertArraysEqual(out3, jnp.arange(4), check_dtypes=False) + out3: typing.Array = arraylike_to_array(np.arange(4, dtype=np.int32)) + self.assertArraysEqual(out3, jnp.arange(4, dtype=np.int32)) out4: typing.Array = arraylike_to_array(True) self.assertArraysEqual(out4, jnp.array(True)) @@ -143,11 +146,7 @@ def f(x: Any) -> typing.Array | None: # - Confirm that types from *.pyi files are correctly pulled-in # - Confirm that non-trivial overloads are behaving as expected. # - import sys - if sys.version_info >= (3, 11): - from typing import assert_type # pytype: disable=not-supported-yet # py311-upgrade - else: - from typing_extensions import assert_type # pytype: disable=not-supported-yet + from typing import assert_type # pytype: disable=not-supported-yet # py311-upgrade mat = jnp.zeros((2, 5)) vals = jnp.arange(5) diff --git a/tests/unary_ops_accuracy_test.py b/tests/unary_ops_accuracy_test.py new file mode 100644 index 000000000000..4883b03446b8 --- /dev/null +++ b/tests/unary_ops_accuracy_test.py @@ -0,0 +1,404 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit test for result accuracy for unary ops.""" + +from typing import Any, NamedTuple +from collections.abc import Callable + +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax._src import config +from jax._src import test_util as jtu +from jax._src.lax import lax +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import hlo +import jax.numpy as jnp +import numpy as np + + +config.parse_flags_with_absl() + + +class TolerancePair(NamedTuple): + high: lax.Tolerance | lax.AccuracyMode = lax.AccuracyMode.DEFAULT + low: lax.Tolerance | lax.AccuracyMode = lax.AccuracyMode.DEFAULT + + +def make_unary_test_cases( + testcase_name: str, + op: Callable[..., Any], + x: np.ndarray, + tp: TolerancePair = None, + min_error_val: float = 0.0, +): + """Creates a single test case.""" + return [{ + "testcase_name": testcase_name, + "op": op, + "x": x, + "tp": tp, + "min_error_val": min_error_val, + }] + + +UNARY_OPS = { + "exp": make_unary_test_cases( + "exp", + lax.exp, + np.arange(84.0, 88.0, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=2**-5, rtol=2**-5, ulps=2), + low=lax.Tolerance(atol=1.5 * 2**-8, rtol=2**-18, ulps=2), + ), + ), + "exp2": make_unary_test_cases( + "exp2", + lax.exp2, + np.arange(84.0, 88.0, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=2**-5, rtol=2**-5, ulps=2), + low=lax.Tolerance(atol=1.5 * 2**-8, rtol=2**-18, ulps=2), + ), + ), + "expm1": make_unary_test_cases( + "expm1", + lax.expm1, + np.arange(84.0, 88.0, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=2**-5, rtol=2**-5, ulps=2), + low=lax.Tolerance(atol=1.5 * 2**-8, rtol=2**-18, ulps=2), + ), + ), + "log": make_unary_test_cases( + "log", + lax.log, + np.linspace(1e28, 2e28, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0), + low=lax.Tolerance(atol=2**-16, rtol=2**-20, ulps=0), + ), + 1.0, + ), + "log1p": make_unary_test_cases( + "log1p", + lax.log1p, + np.linspace(-9e-8, -8e-8, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=0, rtol=2**-11, ulps=0), + low=lax.Tolerance(atol=0, rtol=2**-14, ulps=0), + ), + 1.0, + ), + "tanh": make_unary_test_cases( + "tanh", + lax.tanh, + np.linspace(5.83, 5.86, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=2**-12, rtol=0, ulps=0), + low=lax.Tolerance(atol=2**-16, rtol=0, ulps=0), + ), + ), + "cos": make_unary_test_cases( + "cos", + lax.cos, + np.linspace(9.7e22, 9.8e22, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0), + low=lax.Tolerance(atol=0, rtol=2**-30, ulps=0), + ), + ), + "sin": make_unary_test_cases( + "sin", + lax.sin, + np.linspace(9.7e22, 9.8e22, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0), + low=lax.Tolerance(atol=0, rtol=2**-30, ulps=0), + ), + ), + "tan": make_unary_test_cases( + "tan", + lax.tan, + np.linspace(250.0, 252.0, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0), + low=lax.Tolerance(atol=0, rtol=2**-30, ulps=0), + ), + ), + "sqrt": make_unary_test_cases( + "sqrt", + lax.sqrt, + np.linspace(250.0, 252.0, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0), + low=lax.Tolerance(atol=0, rtol=2**-30, ulps=0), + ), + ), + "rsqrt": make_unary_test_cases( + "rsqrt", + lax.rsqrt, + np.linspace(250.0, 252.0, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0), + low=lax.Tolerance(atol=0, rtol=2**-30, ulps=0), + ), + ), +} + + +def generate_test_cases(op_names): + test_cases = [] + for op in op_names: + op_group = UNARY_OPS[op] + if op_group is None: + raise ValueError(f"No test cases found for op: {op}") + test_cases.extend(op_group) + return test_cases + + +class UnaryOpsAccuracyTest(jtu.JaxTestCase): + + def setUp(self): + if not jtu.stablehlo_version_at_least("1.10.0"): + self.skipTest("Test requires StableHLO v1.10.0 or higher.") + if not jtu.is_device_tpu(): + self.skipTest("Skipping test on non TPU devices.") + # TODO(b/412112097): Enable this test on TPU version 7 and above once + # accuracy analysis is done. + if jtu.get_tpu_version() >= 7: + self.skipTest("Accuracy analysis is not yet done on TPU version 7 and above.") + super().setUp() + + def test_result_accuracy_mode_attr(self): + with ir.Context() as context: + hlo.register_dialect(context) + attr = hlo.ResultAccuracyModeAttr.get("DEFAULT") + assert attr is not None + assert attr.value == "DEFAULT" + + def test_result_accuracy_attr(self): + with ir.Context() as context: + hlo.register_dialect(context) + attr = hlo.ResultAccuracyAttr.get( + atol=1e-5, rtol=0.0, ulps=1, mode="TOLERANCE" + ) + assert attr is not None + assert attr.mode == "TOLERANCE" + assert attr.atol == 1e-5 + assert attr.rtol == 0.0 + assert attr.ulps == 1 + + @parameterized.named_parameters( + *generate_test_cases(["exp", "expm1", "exp2", "log", "log1p", "tanh"]) + ) + def test_unary_ops_choose_impl(self, op, x, tp, **kwargs): + @jax.jit + def f_default(x): + y = op(x, accuracy=tp.high) + return y + + @jax.jit + def f_accurate(x): + y = op(x, accuracy=tp.low) + return y + + # Input values that would cause large differences between the two + # implementations. + diff = abs(f_default(x) - f_accurate(x)) + if jtu.get_tpu_version() >= 5 and op in [ + lax.tanh, + jnp.tanh, + lax.log, + jnp.log, + ]: + # From tpu version 5 and onwards, even with tighter tolerance, the high performant + # implementation for tanh is chosen because the chip implementation has improved accuracy. + self.assertTrue(jnp.all(diff == 0)) + else: + self.assertTrue(jnp.any(diff > 0)) + + @parameterized.named_parameters( + *generate_test_cases(["exp", "expm1", "exp2", "log", "log1p", "tanh"]) + ) + def test_unary_vmap(self, op, x, tp, min_error_val): + @jax.jit + def f(x, y): + diff = lambda val: abs( + op(val, accuracy=tp.high) - op(val, accuracy=tp.low) + ) + return diff(x), diff(y) + + diff_x, diff_y = jax.vmap(f, in_axes=(None, 0), out_axes=0)( + min_error_val, x + ) + # diff(min_error_val) should be 0 + self.assertTrue(jnp.all(diff_x == 0)) + # diff(x) should be > 0 + if jtu.get_tpu_version() >= 5 and op in [ + lax.tanh, + jnp.tanh, + lax.log, + jnp.log, + ]: + # From tpu version 5 and onwards, even with tighter tolerance, the high performant + # implementation for tanh and log is chosen because the chip implementation has improved accuracy. + self.assertTrue(jnp.all(diff_y == 0)) + else: + self.assertTrue(jnp.any(diff_y > 0)) + + @parameterized.named_parameters( + *generate_test_cases(["exp", "expm1", "exp2"]) + ) + def test_diff_grad(self, op, x, tp, **kwargs): + @jax.jit + def f_default(x): + default_op = op(x, accuracy=tp.low) + return jnp.sum(default_op) + + f_default_grad = jax.grad(f_default) + + @jax.jit + def f_accurate(x): + high_op = op(x, accuracy=tp.high) + return jnp.sum(high_op) + + f_accurate_grad = jax.grad(f_accurate) + # Accuracy should be carried through to the gradient causing + # a large diff. + diff = abs(f_default_grad(x) - f_accurate_grad(x)) + self.assertTrue(jnp.any(diff > 0)) + + @parameterized.named_parameters( + *generate_test_cases(["log", "log1p", "tanh"]) + ) + def test_grad_unchanged(self, op, x, tp, **kwargs): + @jax.jit + def f(x): + return jnp.sum(op(x)) + + f_grad = jax.grad(f) + + @jax.jit + def f_default(x): + default_op = op(x, accuracy=tp.low) + return jnp.sum(default_op) + + f_default_grad = jax.grad(f_default) + + @jax.jit + def f_accurate(x): + high_op = op(x, accuracy=tp.high) + return jnp.sum(high_op) + + f_accurate_grad = jax.grad(f_accurate) + # Accuracy should be carried through to the gradient causing a large diff. + # Diff between f_default and f_accurate should follow diff(f_grad,f_default_grad). + expected_diff = abs(f_grad(x) - f_default_grad(x)) + if jnp.all(expected_diff > 0): + # Don't expect f_accurate_grad and f_default_grad to be equal. + self.assertFalse( + jnp.all(abs(f_default_grad(x) - f_accurate_grad(x)) == 0) + ) + elif jnp.all(expected_diff == 0): + # f_accurate_grad and f_default_grad should be equal. + diff = abs(f_default_grad(x) - f_accurate_grad(x)) + self.assertTrue(jnp.all(diff == 0)) + else: + raise ValueError("Unexpected diff: ", expected_diff) + + @parameterized.named_parameters( + *generate_test_cases(["cos", "sin", "tan", "sqrt", "rsqrt"]) + ) + def test_single_impl(self, op, x, tp, **kwargs): + @jax.jit + def f_tol(x): + return op(x, accuracy=tp.high) + + @jax.jit + def f(x): + return op(x) + + diff = abs(f_tol(x) - f(x)) + self.assertTrue(jnp.all(diff == 0)) + + @parameterized.named_parameters( + *generate_test_cases(["cos", "sin", "tan", "sqrt", "rsqrt"]) + ) + def test_default_grad(self, op, x, tp, **kwargs): + @jax.jit + def f_tol(x): + return jnp.sum(op(x, accuracy=tp.high)) + + @jax.jit + def f(x): + return jnp.sum(op(x)) + + self.assertTrue(jnp.all(abs(jax.grad(f_tol)(x) - jax.grad(f)(x)) == 0)) + + def test_invalid_accuracy(self): + with self.assertRaisesRegex( + ValueError, "At least one of atol, rtol, or ulps must be set." + ): + lax.exp(1.0, accuracy=lax.Tolerance(atol=0.0, rtol=0.0, ulps=0)) + with self.assertRaisesRegex(ValueError, "Tolerances must be non-negative."): + lax.exp(1.0, accuracy=lax.Tolerance(atol=-4e-10, rtol=0.0, ulps=0)) + + @parameterized.named_parameters( + *generate_test_cases([ + "exp", + "expm1", + "exp2", + "log", + "log1p", + "tanh", + "cos", + "sin", + "tan", + "sqrt", + "rsqrt", + ]) + ) + def test_low_tol(self, op, x, **kwargs): + with self.assertRaisesRegex( + jax.errors.JaxRuntimeError, "impl_type.ok()" + ): + op(x, accuracy=lax.Tolerance(atol=1e-60, rtol=1e-60, ulps=0)) + + def test_accuracy_jaxpr(self): + # Since accuracy is not set, the jaxpr should not contain "accuracy". + self.assertNotIn( + "accuracy", + str( + jax.make_jaxpr(lambda x: lax.exp(x, accuracy=None))( + np.arange(4.0, dtype=np.float32) + ) + ), + ) + # Set accuracy. + self.assertIn( + "accuracy", + str( + jax.make_jaxpr( + lambda x: lax.exp( + x, accuracy=lax.Tolerance(atol=1e-60, rtol=1e-60, ulps=0) + ) + )(np.arange(4.0, dtype=np.float32)) + ), + ) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/util_test.py b/tests/util_test.py index 53414dae977f..384f83e6c102 100644 --- a/tests/util_test.py +++ b/tests/util_test.py @@ -11,17 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import dataclasses +import random +from functools import partial +import gc import operator +import threading -from absl.testing import absltest - +from absl.testing import absltest, parameterized import jax from jax import api_util from jax._src import linear_util as lu from jax._src import test_util as jtu from jax._src import util - from jax._src.util import weakref_lru_cache jax.config.parse_flags_with_absl() @@ -74,6 +76,64 @@ def kw_to_positional(f, store, factor, *args, **kwargs): self.assertEqual(dict(three=6, four=8), scaled_kwargs) self.assertEqual(2, out_thunk()) + def test_wrapped_fun_name(self): + def my_function(): + return + + with self.subTest("function"): + wrapped = lu.wrap_init( + my_function, + debug_info=api_util.debug_info("test", my_function, (), {}), + ) + self.assertEqual(wrapped.__name__, my_function.__name__) + + with self.subTest("default_partial"): + my_partial = partial(my_function) + wrapped = lu.wrap_init( + my_partial, + debug_info=api_util.debug_info("test", my_partial, (), {}), + ) + self.assertEqual(wrapped.__name__, my_function.__name__) + + with self.subTest("nested_default_partial"): + my_partial = partial(partial(my_function)) + wrapped = lu.wrap_init( + my_partial, + debug_info=api_util.debug_info("test", my_partial, (), {}), + ) + self.assertEqual(wrapped.__name__, my_function.__name__) + + with self.subTest("named_partial"): + my_partial = partial(my_function) + my_partial.__name__ = "my_partial" + wrapped = lu.wrap_init( + my_partial, + debug_info=api_util.debug_info("test", my_partial, (), {}), + ) + self.assertEqual(wrapped.__name__, my_partial.__name__) + + with self.subTest("lambda"): + l = lambda: my_function() + wrapped = lu.wrap_init( + l, + debug_info=api_util.debug_info("test", l, (), {}), + ) + self.assertEqual(wrapped.__name__, "") + + with self.subTest("unnamed_callable"): + + class MyCallable: + + def __call__(self): + return + + my_callable = MyCallable() + wrapped = lu.wrap_init( + my_callable, + debug_info=api_util.debug_info("test", my_callable, (), {}), + ) + self.assertEqual(wrapped.__name__, "") + def test_weakref_lru_cache(self): @weakref_lru_cache def example_cached_fn(key): @@ -98,6 +158,128 @@ def reference_loop_generator(x): for _ in range(4097): reference_loop_generator(lambda x: x) + @parameterized.named_parameters( + dict(weakref_count=weakref_count, + testcase_name=f"_{weakref_count=}") + for weakref_count in [0, 1, 3]) + def test_multi_weak_ref_cache(self, *, weakref_count=1): + + class Key: # hashed by id + def __init__(self, x): + self.x = x + + if weakref_count > 0: + util.weakref_cache_key_types.add(Key) + + @partial(util.multi_weakref_lru_cache, trace_context_in_key=False) + def myfun(a, k1, *, k2, k3): + return f"{a=}, {k1=}, {k2=}, {k3=}" + + def check_invariant(expected_live_keys: int): + self.assertLen(myfun._multi_weakref_id_to_key, expected_live_keys if weakref_count > 1 else 0) + for key_id, key in myfun._multi_weakref_id_to_key.items(): + for wr in key.weakrefs: + self.assertIn(wr, myfun._multi_weakref_to_key_ids) + self.assertIn(key_id, myfun._multi_weakref_to_key_ids[wr]) + + k1 = Key(1) + k3 = (k1, k1) if weakref_count > 1 else 4 + util.clear_all_caches() + r1 = myfun(2, k1, k2=3, k3=k3) # miss + c1 = myfun.cache_info() + self.assertEqual((0, 1, 1), (c1.hits, c1.misses, c1.currsize)) + check_invariant(1) + + for i in range(10): + r2 = myfun(2, k1, k2=3, k3=k3) # all hits + self.assertIs(r1, r2) + c2 = myfun.cache_info() + self.assertEqual((1 + i, 1, 1), (c2.hits, c2.misses, c2.currsize)) + check_invariant(1) + + del k1, k3 # expect that the cache entries are removed (if weakref_count > 0) + gc.collect() + c3 = myfun.cache_info() + self.assertEqual(c3.currsize, 0 if weakref_count > 0 else 1) + check_invariant(0) + + k1_2 = Key(2) + k3_2 = (Key(3), Key(3)) if weakref_count > 1 else (3, 3) + r4 = myfun(2, k1_2, k2=3, k3=k3_2) # miss + c4 = myfun.cache_info() + self.assertEqual((10, 2, (1 if weakref_count > 0 else 2)), (c4.hits, c4.misses, c4.currsize)) + check_invariant(1) + + if weakref_count > 1: + del k3_2 # clear the cache entry + gc.collect() + c5 = myfun.cache_info() + self.assertEqual((10, 2, 0), (c5.hits, c5.misses, c5.currsize)) + check_invariant(0) + + k3_3 = (Key(3), Key(3)) + r6 = myfun(2, k1_2, k2=3, k3=k3_3) # miss because Key hashed by it + self.assertIsNot(r4, r6) + c6 = myfun.cache_info() + self.assertEqual((10, 3, 1), (c6.hits, c6.misses, c6.currsize)) + check_invariant(1) + + del k1_2 + gc.collect() + c7 = myfun.cache_info() + self.assertEqual(0 if weakref_count > 0 else 2, c7.currsize ) + check_invariant(0) + + def test_multi_weak_ref_cache_custom_tuple(self): + class MyTuple(tuple): + pass + + class Key: # hashed by id + def __init__(self, x): + self.x = x + + util.weakref_cache_key_types.add(Key) + + @partial(util.multi_weakref_lru_cache, trace_context_in_key=False) + def my_fun(a): + self.assertIsInstance(a, MyTuple) + return str(a) + + key = Key(1) + my_fun(MyTuple([key, key])) + self.assertLen(my_fun._multi_weakref_id_to_key, 0) # key was not picked up + + del key + self.assertEqual(1, my_fun.cache_info().currsize) # cache is not cleaned + + def test_multi_weakref_lru_cache_threads(self): + num_workers = 5 + num_live_keys_per_worker = 16 + size_key_space = 32 + @dataclasses.dataclass(frozen=True) + class WRKey: + f: int + + util.weakref_cache_key_types.add(WRKey) + + @partial(util.multi_weakref_lru_cache, maxsize=size_key_space // 2) + def myfun(k: WRKey): + return None + + def Worker(): + keys = [None] * num_live_keys_per_worker # These are the live keys for this worker + for i in range(1000): + key_idx = random.randint(0, num_live_keys_per_worker - 1) + key = WRKey(random.randint(0, size_key_space)) + myfun(key) + keys[key_idx] = key # Kill some previous key and keep this live + + workers = [threading.Thread(target=Worker()) for _ in range(num_workers)] + for t in workers: + t.start() + for t in workers: + t.join() + class SafeMapTest(jtu.JaxTestCase): @@ -186,17 +368,17 @@ def test_safe_zip_errors(self): util.safe_zip(lambda x: x) with self.assertRaisesRegex( - ValueError, r"safe_zip\(\) argument 2 is longer than argument 1" + ValueError, r"zip\(\) argument 2 is longer than argument 1" ): util.safe_zip(range(3), range(4)) with self.assertRaisesRegex( - ValueError, r"safe_zip\(\) argument 2 is shorter than argument 1" + ValueError, r"zip\(\) argument 2 is shorter than argument 1" ): util.safe_zip(range(7), range(2)) with self.assertRaisesRegex( - ValueError, r"safe_zip\(\) argument 2 is longer than argument 1" + ValueError, r"zip\(\) argument 2 is longer than argument 1" ): util.safe_zip((), range(3)) diff --git a/tests/version_test.py b/tests/version_test.py index b78e61ae024c..14da82df2e3e 100644 --- a/tests/version_test.py +++ b/tests/version_test.py @@ -143,6 +143,7 @@ def testBuildVersionFromEnvironment(self): JAX_NIGHTLY=None, JAXLIB_NIGHTLY=None): with assert_no_subprocess_call(): version = jax.version._get_version_for_build() + self.assertFalse(jax.version._is_prerelease()) self.assertEqual(version, base_version) self.assertValidVersion(version) @@ -150,6 +151,7 @@ def testBuildVersionFromEnvironment(self): JAX_NIGHTLY=None, JAXLIB_NIGHTLY=None): with assert_no_subprocess_call(): version = jax.version._get_version_for_build() + self.assertFalse(jax.version._is_prerelease()) self.assertEqual(version, base_version) self.assertValidVersion(version) @@ -183,6 +185,20 @@ def testBuildVersionFromEnvironment(self): ): with assert_no_subprocess_call(): version = jax.version._get_version_for_build() + self.assertTrue(jax.version._is_prerelease()) + self.assertEqual(version, f"{base_version}rc0") + self.assertValidVersion(version) + + with jtu.set_env( + JAX_RELEASE=None, + JAXLIB_RELEASE="1", + JAX_NIGHTLY=None, + JAXLIB_NIGHTLY=None, + WHEEL_VERSION_SUFFIX="rc0", + ): + with assert_no_subprocess_call(): + version = jax.version._get_version_for_build() + self.assertTrue(jax.version._is_prerelease()) self.assertEqual(version, f"{base_version}rc0") self.assertValidVersion(version) diff --git a/tests/x64_context_test.py b/tests/x64_context_test.py index d4403b7e5e30..f31ed2ebf85c 100644 --- a/tests/x64_context_test.py +++ b/tests/x64_context_test.py @@ -16,20 +16,29 @@ import concurrent.futures from functools import partial import time -import unittest from absl.testing import absltest import numpy as np import jax +from jax import custom_jvp, custom_vjp, grad, jvp from jax import lax from jax import random -from jax.experimental import enable_x64, disable_x64 import jax.numpy as jnp +from jax._src import config import jax._src.test_util as jtu + jax.config.parse_flags_with_absl() +# TODO(jakevdp): remove this check for JAX v0.9.0 and test jax.enable_x64 directly. +with jtu.ignore_warning(message=".* is deprecated", category=DeprecationWarning): + if hasattr(jax.experimental, "enable_x64"): + from jax.experimental import enable_x64, disable_x64 + else: + enable_x64 = jax.enable_x64 + disable_x64 = lambda: jax.enable_x64(False) + class X64ContextTests(jtu.JaxTestCase): @jtu.sample_product(jit=jtu.JIT_IMPLEMENTATION) @@ -113,11 +122,14 @@ def func_x64(): @jax.legacy_prng_key('allow') @jax.debug_key_reuse(False) @jtu.ignore_warning(category=UserWarning, - message="Explicitly requested dtype float64 is not available") + message="Explicitly requested dtype float64 is not available") def test_jit_cache(self): if jtu.test_device_matches(["tpu"]): self.skipTest("64-bit random not available on TPU") + if config.explicit_x64_dtypes.value == config.ExplicitX64Mode.ERROR: + self.skipTest("Test uses float64 which is not available") + f = partial(random.uniform, random.PRNGKey(0), (1,), 'float64', -1, 1) with disable_x64(): for _ in range(2): @@ -126,7 +138,6 @@ def test_jit_cache(self): for _ in range(2): f() - @unittest.skip("test fails, see #8552") def test_convert_element_type(self): # Regression test for part of https://github.com/jax-ml/jax/issues/5982 with enable_x64(): @@ -139,5 +150,120 @@ def test_convert_element_type(self): z = jax.jit(lambda x: x.astype(jnp.int32))(x) self.assertEqual(z.dtype, jnp.int32) + def test_python_scalar(self): + @jax.jit + def f(a): + with enable_x64(): + return 2 + a + self.assertEqual(f(1).dtype, jnp.int64) + + def test_grad(self): + def fun(x): + with enable_x64(True): + return jnp.sin(x) + self.assertEqual( + jax.grad(fun)(0.5).dtype, + jnp.float64 if jax.config.x64_enabled else jnp.float32, + ) + + def test_sin(self): + def fun(x): + with enable_x64(True): + x = jnp.asarray(x, dtype=jnp.float64) + return lax.sin(x) + + self.assertEqual(fun(0.5).dtype, jnp.float64) + + def test_mul(self): + def fun(x, y): + with enable_x64(True): + x = jnp.asarray(x, dtype=jnp.float64) + y = jnp.asarray(y, dtype=jnp.float64) + return lax.mul(x, y) + + self.assertEqual(fun(0.5, 1.5).dtype, jnp.float64) + + @jtu.sample_product(disable_jit=[True, False]) + def test_scan_with_contextmanager(self, disable_jit): + with jax.disable_jit(disable_jit): + def f(a): + def body(carry, _): + with enable_x64(): + y = (carry + a).astype(jnp.int64) + assert y.dtype == jnp.int64 + z = y.astype(jnp.int32) + return carry, (z, y) + return lax.scan(body, jnp.int32(2), jnp.arange(4)) + carry_out, ys_out = f(3) + self.assertEqual(carry_out.dtype, jnp.int32) + self.assertEqual(ys_out[0].dtype, jnp.int32) + self.assertEqual(ys_out[1].dtype, jnp.int64) + + def test_custom_jvp(self): + + @custom_jvp + def f(x): + return x ** 2. + + @f.defjvp + def f_jvp(xs, ts): + x, = xs + t, = ts + self.assertTrue(jax.config.x64_enabled) + return f(x), t * jnp.sin(x) + + def g(x): + with enable_x64(): + x = jnp.array(x, jnp.float64) + return f(x) + + self.assertEqual(g(5.).dtype, jnp.float64) + out_primal, out_tangent = jvp(g, (5.,), (1.,)) + self.assertEqual(out_primal.dtype, jnp.float64) + self.assertEqual(out_tangent.dtype, jnp.float64) + + with enable_x64(False): + self.assertEqual(g(5.).dtype, jnp.float64) + self.assertEqual(grad(g)(5.).dtype, jnp.float32) + self.assertEqual(grad(grad(g))(5.).dtype, jnp.float32) + self.assertEqual(grad(grad(grad(g)))(5.).dtype, jnp.float32) + + with enable_x64(True): + self.assertEqual(g(5.).dtype, jnp.float64) + self.assertEqual(grad(g)(5.).dtype, jnp.float64) + self.assertEqual(grad(grad(g))(5.).dtype, jnp.float64) + self.assertEqual(grad(grad(grad(g)))(5.).dtype, jnp.float64) + + def test_custom_vjp(self): + + @custom_vjp + def f(x): + return x ** 2. + + def f_fwd(x): + return f(x), jnp.sin(x) + + def f_bwd(res, t): + return (res * t,) + f.defvjp(f_fwd, f_bwd) + + def g(x): + with enable_x64(): + x = jnp.array(x, jnp.float64) + return f(x) + + with enable_x64(False): + self.assertEqual(g(5.).dtype, jnp.float64) + self.assertEqual(grad(g)(5.).dtype, jnp.float32) + self.assertEqual(grad(grad(g))(5.).dtype, jnp.float32) + self.assertEqual(grad(grad(grad(g)))(5.).dtype, jnp.float32) + + with enable_x64(True): + self.assertEqual(g(5.).dtype, jnp.float64) + self.assertEqual(grad(g)(5.).dtype, jnp.float64) + self.assertEqual(grad(grad(g))(5.).dtype, jnp.float64) + self.assertEqual(grad(grad(grad(g)))(5.).dtype, jnp.float64) + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/xla_bridge_test.py b/tests/xla_bridge_test.py index 97e8765cc096..c995f874010e 100644 --- a/tests/xla_bridge_test.py +++ b/tests/xla_bridge_test.py @@ -17,12 +17,12 @@ from absl import logging from absl.testing import absltest - from jax import version from jax._src import compiler from jax._src import config from jax._src import test_util as jtu from jax._src import xla_bridge as xb +from jax._src.lib import _profiler from jax._src.lib import xla_client as xc config.parse_flags_with_absl() @@ -35,18 +35,14 @@ class XlaBridgeTest(jtu.JaxTestCase): def test_set_device_assignment_no_partition(self): compile_options = compiler.get_compile_options( num_replicas=4, num_partitions=1, device_assignment=[0, 1, 2, 3]) - expected_device_assignment = ("Computations: 1 Replicas: 4\nComputation 0: " - "0 1 2 3 \n") - self.assertEqual(compile_options.device_assignment.__repr__(), - expected_device_assignment) + self.assertEqual(compile_options.device_assignment.replica_count(), 4) + self.assertEqual(compile_options.device_assignment.computation_count(), 1) def test_set_device_assignment_with_partition(self): compile_options = compiler.get_compile_options( num_replicas=2, num_partitions=2, device_assignment=[[0, 1], [2, 3]]) - expected_device_assignment = ("Computations: 2 Replicas: 2\nComputation 0: " - "0 2 \nComputation 1: 1 3 \n") - self.assertEqual(compile_options.device_assignment.__repr__(), - expected_device_assignment) + self.assertEqual(compile_options.device_assignment.replica_count(), 2) + self.assertEqual(compile_options.device_assignment.computation_count(), 2) def test_set_fdo_profile(self): compile_options = compiler.get_compile_options( @@ -136,13 +132,15 @@ def test_register_plugin(self): "name1:path1,name2:path2,name3" ) with mock.patch.object( - xc.profiler, "register_plugin_profiler", autospec=True + _profiler, "register_plugin_profiler", autospec=True ): xb.register_pjrt_plugin_factories_from_env() registration = xb._backend_factories["name1"] with mock.patch.object(xc, "make_c_api_client", autospec=True) as mock_make: with mock.patch.object( - xc, "pjrt_plugin_initialized", autospec=True, return_vale=True + xc, + "pjrt_plugin_initialized", + autospec=True, ): with mock.patch.object(xc, "initialize_pjrt_plugin", autospec=True): registration.factory() @@ -174,13 +172,15 @@ def test_register_plugin_with_config(self): ) with mock.patch.object(xc, "load_pjrt_plugin_dynamically", autospec=True): with mock.patch.object( - xc.profiler, "register_plugin_profiler", autospec=True + _profiler, "register_plugin_profiler", autospec=True ): xb.register_pjrt_plugin_factories_from_env() registration = xb._backend_factories["name1"] with mock.patch.object(xc, "make_c_api_client", autospec=True) as mock_make: with mock.patch.object( - xc, "pjrt_plugin_initialized", autospec=True, return_vale=True + xc, + "pjrt_plugin_initialized", + autospec=True, ): with mock.patch.object(xc, "initialize_pjrt_plugin", autospec=True): registration.factory() @@ -202,6 +202,42 @@ def test_register_plugin_with_config(self): mock_make.assert_called_once_with("name1", options, None) + def test_register_plugin_with_lazy_config(self): + options = {"bar": "baz"} + + def getopts(): + return options + + def make_c_api_client(plugin_name, new_options, *args, **kwargs): + for k in options: + self.assertEqual(new_options[k], options[k]) + + with mock.patch.object(xc, "load_pjrt_plugin_dynamically", autospec=True): + with mock.patch.object( + _profiler, "register_plugin_profiler", autospec=True + ): + xb.register_plugin("foo", options=getopts, library_path="/dev/null") + with mock.patch.object( + xc, "make_c_api_client", autospec=True, wraps=make_c_api_client + ) as mock_make: + with mock.patch.object(xc, "pjrt_plugin_initialized", autospec=True): + xb._backend_factories["foo"].factory() + mock_make.assert_called_once() + + def test_num_cpu_devices_update(self): + xb.devices() + + current_val = config.config.jax_num_cpu_devices + + config.update("jax_num_cpu_devices", current_val) + + with self.assertRaisesRegex( + RuntimeError, + "jax_num_cpu_devices config should be updated before backends are" + " initialized", + ): + config.update("jax_num_cpu_devices", current_val + 2) + class GetBackendTest(jtu.JaxTestCase): diff --git a/tests/xla_metadata_test.py b/tests/xla_metadata_test.py index d141bc15c249..2e0b57f015e5 100644 --- a/tests/xla_metadata_test.py +++ b/tests/xla_metadata_test.py @@ -17,19 +17,45 @@ correctly propagated to the jaxpr and mlir. """ -from absl.testing import absltest +from absl.testing import absltest, parameterized import jax from jax._src import config from jax._src import test_util as jtu from jax._src.lax import lax from jax.experimental.xla_metadata import set_xla_metadata import jax.numpy as jnp +import numpy as np config.parse_flags_with_absl() class XlaMetadataTest(jtu.JaxTestCase): + def _assert_metadata_appears_once_per_op( + self, + hlo_text: str, + expected_tagged_ops: list[str], + metadata: dict[str, str], + ): + attribute_strings = [f'{k}="{v}"' for k, v in metadata.items()] + op_with_metadata_count = {op: 0 for op in expected_tagged_ops} + + for line in hlo_text.splitlines(): + for op in expected_tagged_ops: + if (str(op + "(") in line and all(attr in line for attr in attribute_strings) + and "frontend_attributes=" in line): + op_with_metadata_count[op] += 1 + + for op in op_with_metadata_count: + self.assertEqual( + op_with_metadata_count[op], + 1, + f"Expected op '{op}' to have the metadata exactly once," + f" but found it {op_with_metadata_count[op]} times\n" + f"Metadata: {metadata}\n" + f"HLO Graph:\n\n{hlo_text}", + ) + def test_f_jitted(self): @jax.jit def f(a, b): @@ -62,6 +88,33 @@ def f(a, b): f_lowered_text = f.lower(1.0, 2.0).as_text() self.assertIn('mhlo.frontend_attributes = {a = "10"}', f_lowered_text) + def test_decorator(self): + @set_xla_metadata(a="b") + @jax.jit + def f(a, b): + return a + b + + f_jaxpr = jax.make_jaxpr(f)(1, 2) + eqns = f_jaxpr.eqns + for eq in eqns[1:]: + self.assertDictEqual(eq.ctx.attributes, {"a": "b"}) + + f_lowered_text = f.lower(1.0, 2.0).as_text() + self.assertIn('mhlo.frontend_attributes = {a = "b"}', f_lowered_text) + + def test_decorator_and_context_manager_nested(self): + @set_xla_metadata(a="b") + @jax.jit + def f(a, b): + with set_xla_metadata(c="d"): + return a + b + + f_lowered_text = f.lower(1.0, 2.0).as_text() + self.assertIn( + 'mhlo.frontend_attributes = {a = "b", c = "d"}', + f_lowered_text, + ) + def test_f_nonjitted(self): def f_add(a, b): return lax.add(a, b) @@ -190,6 +243,39 @@ def while_fn(a): if "stablehlo.add" in line: self.assertIn('mhlo.frontend_attributes = {a = "c"}', line) + def test_cond_annotates_branches(self): + sin = jnp.sin + cos = jnp.cos + + @jax.jit + def f(x): + with set_xla_metadata(a="b"): + return jax.lax.cond(x < 0., sin, cos, x) + + hlo_lines = f.lower(1.).as_text().split("\n") + sin_hlo, = (line for line in hlo_lines if "stablehlo.sine" in line) + cos_hlo, = (line for line in hlo_lines if "stablehlo.cosine" in line) + self.assertIn('mhlo.frontend_attributes = {a = "b"}', sin_hlo) + self.assertIn('mhlo.frontend_attributes = {a = "b"}', cos_hlo) + + def test_cond_annotates_branches_and_none_unsets(self): + sin = jnp.sin + + def cos(x): + with set_xla_metadata(a=None): + return jnp.cos(x) + + @jax.jit + def f(x): + with set_xla_metadata(a="b"): + return jax.lax.cond(x < 0., sin, cos, x) + + hlo_lines = f.lower(1.).as_text().split("\n") + sin_hlo, = (line for line in hlo_lines if "stablehlo.sine" in line) + cos_hlo, = (line for line in hlo_lines if "stablehlo.cosine" in line) + self.assertIn( 'mhlo.frontend_attributes = {a = "b"}', sin_hlo) + self.assertNotIn('mhlo.frontend_attributes = {a = "b"}', cos_hlo) + def test_nested_jit(self): @jax.jit def f(x, y): @@ -255,10 +341,10 @@ def f2(x, y): with set_xla_metadata(a="b"): return (x + y, y * 2.0) - f_vmap_jaxpr = jax.make_jaxpr(jax.vmap(f2, in_axes=(0, None))) + f2_vmap = jax.vmap(f2, in_axes=(0, None)) self.assertIn( 'mhlo.frontend_attributes = {a = "b"}', - f_vmap_jaxpr.lower(jnp.arange(5.0), 1.0).as_text(), + jax.jit(f2_vmap).lower(jnp.arange(5.0), 1.0).as_text(), ) def test_multiple_instructions(self): @@ -284,6 +370,104 @@ def f(x): 'mhlo.frontend_attributes = {a = "b"}', f.lower(jnp.arange(5.0)).as_text() ) + @parameterized.parameters( + ("x*x", lambda x: x * x, "multiply"), + ("sin(x)", jnp.sin, "sin"), + ("tanh(x)", jnp.tanh, "tanh"), + ("1/x", lambda x: 1 / x, "divide"), + ("sinc(x)", jnp.sinc, "call"), + ) + def test_value_tagging(self, name, fn, expected_tagged_op): + metadata = {"test_value_tagging": name} + + def wrapped_fn(x): + return set_xla_metadata(fn(x), **metadata) + + x_scalar = jnp.array(0.7) + text = jax.jit(wrapped_fn).lower(x_scalar).as_text("hlo") + self._assert_metadata_appears_once_per_op( + text, [expected_tagged_op], metadata) + + @parameterized.parameters( + ("x*x", lambda x: x * x, "add"), + ("sin(x)", jnp.sin, "cosine"), + ("tanh(x)", jnp.tanh, "add"), + ("1/x", lambda x: 1 / x, "negate"), + ("sinc(x)", jnp.sinc, "call"), + ) + def test_value_grad_tagging(self, name, fn, expected_tagged_op): + metadata = {"test_value_grad_tagging": name} + + @jax.custom_vjp + def wrapped_fn(x): + return fn(x) + + def fwd(*args): + primal_out, vjp_fn = jax.vjp(fn, *args) + return primal_out, vjp_fn + + def bwd(vjp_fn, cts_in): + cts_out = vjp_fn(cts_in) + cts_out = set_xla_metadata(cts_out, **metadata) + return cts_out + + wrapped_fn.defvjp(fwd, bwd) + + x_scalar = jnp.array(0.7) + text = jax.jit(jax.grad(wrapped_fn)).lower(x_scalar).as_text("hlo") + self._assert_metadata_appears_once_per_op( + text, [expected_tagged_op], metadata) + + def test_vmap_multi_input_output_value_tagging(self): + metadata = {"test_vmap_multi_input_output_value_tagging": "value"} + + fn = lambda x, y, z: (x @ y + z, z - x @ y) + + @jax.vmap + def vmapped_fn(x_item, y_item, z_item): + return set_xla_metadata(fn(x_item, y_item, z_item), **metadata) + + batch_size, num_rows, num_cols = 4, 5, 6 + rng = np.random.default_rng(0) + x_batch = rng.random((batch_size, num_rows, num_cols)).astype(np.float32) + y_batch = rng.random((batch_size, num_cols, num_rows)).astype(np.float32) + z_batch = rng.random((batch_size, num_rows, num_rows)).astype(np.float32) + inputs = (x_batch, y_batch, z_batch) + + text = jax.jit(vmapped_fn).lower(*inputs).as_text("hlo") + self._assert_metadata_appears_once_per_op( + text, ["add", "subtract"], metadata) + + def test_sharding_support_value_tagging(self): + mesh = jtu.create_mesh((1,), "data") + np_inp = np.arange(8, dtype=np.float32) + arr = jax.device_put(np_inp, jax.NamedSharding(mesh, jax.P("data"))) + + metadata = {"test_sharding_support_value_tagging": "value"} + + @jax.jit + def wrapped_fn(x): + return set_xla_metadata(x * 2.0, **metadata) + + text = jax.jit(wrapped_fn).lower(arr).as_text("hlo") + self._assert_metadata_appears_once_per_op(text, ["multiply"], metadata) + + def test_scan_support_value_tagging(self): + metadata = {"test_scan_support_value_tagging": "value"} + + fn = lambda carry, x: (carry + x) * 2.0 + + def scan_body_val_with_metadata(carry, x): + tagged_result = set_xla_metadata(fn(carry, x), **metadata) + return tagged_result, tagged_result + + def scan_fn(init_carry, inputs_arr): + return jax.lax.scan(scan_body_val_with_metadata, init_carry, inputs_arr) + + inputs = (jnp.array(0.0), jnp.arange(1, 4, dtype=jnp.float32)) + text = jax.jit(scan_fn).lower(*inputs).as_text("hlo") + self._assert_metadata_appears_once_per_op(text, ["multiply"], metadata) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/third_party/repo.bzl b/third_party/repo.bzl index 17e0bbb03542..f09875140772 100644 --- a/third_party/repo.bzl +++ b/third_party/repo.bzl @@ -129,7 +129,7 @@ def tf_http_archive(name, sha256, urls, **kwargs): "storage.googleapis.com", )]): fail("The first entry of tf_http_archive(urls) must be a mirror " + - "URL, preferrably mirror.tensorflow.org. Even if you don't have " + + "URL, preferably mirror.tensorflow.org. Even if you don't have " + "permission to mirror the file, please put the correctly " + "formatted mirror URL there anyway, because someone will come " + "along shortly thereafter and mirror the file.") @@ -145,22 +145,3 @@ def tf_http_archive(name, sha256, urls, **kwargs): urls = urls, **kwargs ) - -def _tf_vendored_impl(repository_ctx): - parent_path = repository_ctx.path(repository_ctx.attr.parent).dirname - - # get_child doesn't allow slashes. Yes this is silly. bazel_skylib paths - # doesn't work with path objects. - relpath_parts = repository_ctx.attr.relpath.split("/") - vendored_path = parent_path - for part in relpath_parts: - vendored_path = vendored_path.get_child(part) - repository_ctx.symlink(vendored_path, ".") - -tf_vendored = repository_rule( - implementation = _tf_vendored_impl, - attrs = { - "parent": attr.label(default = "//:WORKSPACE"), - "relpath": attr.string(), - }, -) diff --git a/third_party/xla/revision.bzl b/third_party/xla/revision.bzl new file mode 100644 index 000000000000..c8d30ca63f49 --- /dev/null +++ b/third_party/xla/revision.bzl @@ -0,0 +1,25 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# buildifier: disable=module-docstring + +# To update XLA to a new revision, +# a) update XLA_COMMIT to the new git commit hash +# b) get the sha256 hash of the commit by running: +# curl -L https://api.github.com/repos/openxla/xla/tarball/{git_hash} | sha256sum +# and update XLA_SHA256 with the result. + +# buildifier: disable=module-docstring +XLA_COMMIT = "bb760b047bdbfeff962f0366ad5cc782c98657e0" +XLA_SHA256 = "76a3e1d5118e7bd6e5f275818a0ef0440407d8e946068590ed03c84b93383f0b" diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 73bf2eb3850d..be345f8c6ed3 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -14,22 +14,18 @@ # buildifier: disable=module-docstring load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") - -# To update XLA to a new revision, -# a) update XLA_COMMIT to the new git commit hash -# b) get the sha256 hash of the commit by running: -# curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum -# and update XLA_SHA256 with the result. - -XLA_COMMIT = "df971129bd82e381954da0185b534220e21798a4" -XLA_SHA256 = "11e9a568320cf7e7d61819620fd369927527ecefb68d5d1154b1521456bbdb72" +load("//third_party/xla:revision.bzl", "XLA_COMMIT", "XLA_SHA256") def repo(): tf_http_archive( name = "xla", sha256 = XLA_SHA256, - strip_prefix = "xla-{commit}".format(commit = XLA_COMMIT), - urls = tf_mirror_urls("https://github.com/openxla/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)), + type = "tar.gz", + strip_prefix = "openxla-xla-{commit}".format(commit = XLA_COMMIT[:7]), + # We use an automated tool to update the revision.bzl file. GitHub prohibits the crawling of + # web links (`/archive/`) links so we use the GitHub API endpoint to get the tarball + # instead. + urls = tf_mirror_urls("https://api.github.com/repos/openxla/xla/tarball/{commit}".format(commit = XLA_COMMIT)), ) # For development, one often wants to make changes to the TF repository as well